# app.py — NutriWise AI Assistant (Gradio entry point)
"""NutriWise AI Assistant (Advanced RAG) - Gradio app

What this app does:
- Loads your OpenAI key + base URL from config.json (or environment variables).
- Builds (or reuses) a text vector store from Vitamin_and_minerals.pdf
- Builds (or reuses) an image vector store from the 'sources/' images via OpenCLIP embeddings.
- Expands the query, retrieves relevant text chunks, reranks them with a cross-encoder,
  and generates an answer with a chat model.
- Shows top related images.

Folder expectations (recommended):
.
├── app.py
├── requirements.txt
├── config.json              # (NOT committed) contains API_KEY and OPENAI_API_BASE
└── data/
    └── MLS14 - Adv RAG.zip  # the provided zip (or already extracted folder)

Run:
    pip install -r requirements.txt
    python app.py
"""
| from __future__ import annotations | |
| import ast | |
| import json | |
| import os | |
| from dataclasses import dataclass | |
| from functools import lru_cache | |
| from pathlib import Path | |
| from typing import Any, List, Tuple | |
| import gradio as gr | |
# --- Paths ---
BASE_DIR = Path(__file__).resolve().parent  # directory containing this file
DATA_DIR = BASE_DIR / "data"                # where the dataset zip lives / gets extracted
ZIP_NAME = "MLS14 - Adv RAG.zip"            # provided dataset archive name
EXTRACT_DIR = DATA_DIR / "MLS14 - Adv RAG"  # folder produced by extracting the zip
PDF_NAME = "Vitamin_and_minerals.pdf"       # source PDF inside the archive
IMAGES_DIR_NAME = "sources"                 # image folder inside the archive
TEXT_DB_DIR = BASE_DIR / "vectordb_text"    # persisted Chroma text vector store
IMAGE_DB_DIR = BASE_DIR / "my_vectordb"     # persisted ChromaDB image vector store
| # --- Config --- | |
def load_config() -> dict:
    """Read config.json next to this file (if any) and export OpenAI env vars.

    Returns the parsed config dict (empty when config.json is absent). When
    API_KEY / OPENAI_API_BASE keys are present they are mirrored into the
    OPENAI_API_KEY / OPENAI_BASE_URL environment variables.
    """
    config: dict = {}
    config_file = BASE_DIR / "config.json"
    if config_file.exists():
        config = json.loads(config_file.read_text(encoding="utf-8"))
    # Mirror the notebook's behaviour: config values win over pre-set env vars.
    api_key = config.get("API_KEY")
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key
    base_url = config.get("OPENAI_API_BASE")
    if base_url:
        os.environ["OPENAI_BASE_URL"] = base_url
    return config
def ensure_data_ready() -> Tuple[Path, Path]:
    """Ensure the zip is extracted (or already present); return (pdf_path, images_dir).

    Raises FileNotFoundError when neither the extracted data nor the zip can
    be located, or when extraction did not yield the expected files.
    """
    import zipfile

    DATA_DIR.mkdir(exist_ok=True)
    pdf_path = EXTRACT_DIR / PDF_NAME
    images_dir = EXTRACT_DIR / IMAGES_DIR_NAME
    # Fast path: a previous run already extracted everything.
    if pdf_path.exists() and images_dir.exists():
        return pdf_path, images_dir

    # Search ./data, next to app.py, then the current working directory.
    candidates = (DATA_DIR / ZIP_NAME, BASE_DIR / ZIP_NAME, Path.cwd() / ZIP_NAME)
    archive = None
    for candidate in candidates:
        if candidate.exists():
            archive = candidate
            break
    if archive is None:
        raise FileNotFoundError(
            f"Could not find '{ZIP_NAME}'. Put it in '{DATA_DIR}/' or alongside app.py."
        )
    with zipfile.ZipFile(archive, "r") as zf:
        zf.extractall(DATA_DIR)

    # Validate that the archive actually contained what we need.
    if not pdf_path.exists():
        raise FileNotFoundError(f"Expected PDF not found after extraction: {pdf_path}")
    if not images_dir.exists():
        raise FileNotFoundError(f"Expected images folder not found after extraction: {images_dir}")
    return pdf_path, images_dir
| # --- LangChain / Chroma initialization --- | |
@dataclass
class RagAssets:
    """Bundle of heavyweight RAG components shared across the pipeline.

    Bug fix: the original class declared annotated attributes but was never
    decorated, so ``RagAssets(llm=..., ...)`` (as done by ``init_assets``)
    raised TypeError. ``@dataclass`` — already imported at module top —
    generates the keyword ``__init__`` the code relies on.
    """

    llm: Any                # chat model for query expansion + answering
    embeddings: Any         # text embedding function backing the Chroma store
    retriever_k: int        # number of chunks fetched per expanded query
    vectorstore: Any        # persisted Chroma text vector store
    crossencoder: Any       # cross-encoder used to rerank retrieved chunks
    image_collection: Any   # ChromaDB collection of OpenCLIP image embeddings
| def _safe_parse_list(text: str) -> List[str]: | |
| """Best-effort parse for a list of queries produced by the LLM.""" | |
| text = text.strip() | |
| # Try literal eval (Python list) | |
| try: | |
| val = ast.literal_eval(text) | |
| if isinstance(val, list) and all(isinstance(x, str) for x in val): | |
| return [x.strip() for x in val if x.strip()] | |
| except Exception: | |
| pass | |
| # Try JSON list | |
| try: | |
| import json as _json | |
| val = _json.loads(text) | |
| if isinstance(val, list) and all(isinstance(x, str) for x in val): | |
| return [x.strip() for x in val if x.strip()] | |
| except Exception: | |
| pass | |
| # Fallback: one per line (strip bullets) | |
| lines = [] | |
| for line in text.splitlines(): | |
| line = line.strip().lstrip("-β’*").strip() | |
| if line: | |
| lines.append(line) | |
| return lines | |
@lru_cache(maxsize=1)
def init_assets() -> RagAssets:
    """Build (or reuse) all heavy RAG components and return them bundled.

    Loads config, extracts the dataset if needed, opens/populates the text
    vector store, loads the cross-encoder reranker, and opens/populates the
    OpenCLIP image collection.

    Fix: decorated with ``lru_cache(maxsize=1)`` (imported at module top but
    previously unused) so the models and vector stores are constructed once
    per process instead of on every chat turn.

    Returns:
        RagAssets with llm, embeddings, retriever_k, vectorstore,
        crossencoder and image_collection populated.
    """
    load_config()
    pdf_path, images_dir = ensure_data_ready()

    # --- LangChain (imported lazily so the UI can still render setup errors) ---
    from langchain_openai import ChatOpenAI, OpenAIEmbeddings
    from langchain_community.document_loaders import PyPDFLoader
    from langchain_community.vectorstores import Chroma
    from langchain_community.cross_encoders import HuggingFaceCrossEncoder

    embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

    # --- Text vector store (persisted on disk) ---
    vectorstore = Chroma(
        collection_name="vitamin_and_minerals",
        embedding_function=embeddings,
        persist_directory=str(TEXT_DB_DIR),
    )
    # Populate only when the persisted collection is empty (first run).
    try:
        count = vectorstore._collection.count()  # type: ignore[attr-defined]
    except Exception:
        count = 0
    if count == 0:
        loader = PyPDFLoader(str(pdf_path))
        docs = loader.load()
        vectorstore.add_documents(docs)
        try:
            # Older langchain versions require an explicit persist call.
            vectorstore.persist()
        except Exception:
            pass  # best-effort: newer Chroma persists automatically

    retriever_k = 8  # chunks fetched per expanded query

    # --- Reranker (cross-encoder) ---
    crossencoder = HuggingFaceCrossEncoder(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2")

    # --- Image vector store (native ChromaDB client + OpenCLIP embeddings) ---
    import chromadb
    from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
    from chromadb.utils.data_loaders import ImageLoader

    client = chromadb.PersistentClient(path=str(IMAGE_DB_DIR))
    image_loader = ImageLoader()
    image_embed_fn = OpenCLIPEmbeddingFunction()
    image_collection = client.get_or_create_collection(
        name="nutrition_images",
        embedding_function=image_embed_fn,
        data_loader=image_loader,
    )
    if image_collection.count() == 0:
        image_paths = sorted(
            [p for p in images_dir.glob("*") if p.suffix.lower() in {".png", ".jpg", ".jpeg", ".webp"}]
        )
        metadatas = []
        for p in image_paths:
            # Infer the vitamin from the filename prefix (e.g. "vitaminA_...").
            vitamin = "?"
            s = p.stem.lower()
            if s.startswith("vitamina"):
                vitamin = "A"
            elif s.startswith("vitaminc"):
                vitamin = "C"
            elif s.startswith("vitamind"):
                vitamin = "D"
            metadatas.append({"vitamin": vitamin, "info": f"Food sources of Vitamin {vitamin}"})
        image_collection.add(
            ids=[str(i) for i in range(len(image_paths))],
            uris=[str(p) for p in image_paths],
            metadatas=metadatas,
        )

    return RagAssets(
        llm=llm,
        embeddings=embeddings,
        retriever_k=retriever_k,
        vectorstore=vectorstore,
        crossencoder=crossencoder,
        image_collection=image_collection,
    )
| # --- RAG logic --- | |
def expand_query(user_query: str, llm: Any) -> List[str]:
    """Ask the LLM for 3-5 rephrasings of *user_query*.

    Returns at most six queries, with the user's original wording guaranteed
    to be included (inserted first when the LLM did not echo it).
    """
    instructions = f"""You are helping query expansion for document retrieval.
Given the question, produce 3 to 5 alternative rephrasings that keep the SAME meaning.
- Use synonyms and alternate phrasing.
- Keep acronyms and unknown terms unchanged.
- Return ONLY a JSON array of strings.
Question: {user_query}
"""
    raw = llm.invoke(instructions).content
    expanded = [q for q in _safe_parse_list(raw) if q.strip()]
    # Always retrieve with the user's exact phrasing as well.
    if user_query not in expanded:
        expanded.insert(0, user_query)
    return expanded[:6]
def retrieve_and_rerank(user_query: str, queries: List[str], assets: RagAssets) -> List[Any]:
    """Gather candidates for every expanded query, dedupe, rerank, keep top 5."""
    retriever = assets.vectorstore.as_retriever(search_kwargs={"k": assets.retriever_k})
    seen_keys = set()
    candidates = []
    for query in queries:
        for doc in retriever.get_relevant_documents(query):
            # Dedupe across expanded queries on (source, page, content prefix).
            key = (doc.metadata.get("source"), doc.metadata.get("page"), doc.page_content[:200])
            if key in seen_keys:
                continue
            seen_keys.add(key)
            candidates.append(doc)
    if not candidates:
        return []
    # Score (question, passage) pairs against the ORIGINAL question only.
    scores = assets.crossencoder.score([[user_query, doc.page_content] for doc in candidates])
    ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)
    return [doc for doc, _score in ranked[:5]]
def build_prompt(user_query: str, docs: List[Any]) -> str:
    """Format the reranked chunks into a grounded question-answering prompt."""
    if docs:
        sections = []
        for idx, doc in enumerate(docs, 1):
            source = doc.metadata.get("source", "document")
            page = doc.metadata.get("page", "?")
            # Cite each chunk so the answer can be traced back to a page.
            header = f"[{idx}] (source={Path(source).name}, page={page})"
            sections.append(header + "\n" + doc.page_content)
        context = "\n\n".join(sections)
    else:
        context = "(no context retrieved)"
    return f"""You are an expert assistant for nutrition Q&A.
Answer the user's question using ONLY the context below.
If the context is insufficient, say what is missing and what you'd look up next.
Question: {user_query}
Context:
{context}
"""
def retrieve_images(user_query: str, assets: RagAssets, n: int = 3) -> List[Tuple[str, str]]:
    """Query the image collection with the raw question; return (uri, caption) pairs."""
    result = assets.image_collection.query(
        query_texts=[user_query],
        n_results=n,
        include=["uris", "metadatas"],
    )
    if not result:
        return []
    # Chroma nests results per query; we sent a single query, so take index 0.
    uri_list = result.get("uris", [[]])[0]
    meta_list = result.get("metadatas", [[]])[0]
    pairs = []
    for uri, meta in zip(uri_list, meta_list):
        caption = (meta or {}).get("info", "")
        pairs.append((uri, caption))
    return pairs
def chat(user_message: str, history: List[Tuple[str, str]]) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
    """Run one full RAG turn for *user_message*.

    Returns the chat history extended with the new (question, answer) pair,
    plus up to three (image_uri, caption) pairs for the gallery.
    """
    assets = init_assets()
    # Pipeline: query expansion -> retrieval + rerank -> grounded generation.
    expanded = expand_query(user_message, assets.llm)
    docs = retrieve_and_rerank(user_message, expanded, assets)
    answer = assets.llm.invoke(build_prompt(user_message, docs)).content
    gallery_items = retrieve_images(user_message, assets, n=3)
    return history + [(user_message, answer)], gallery_items
| # --- Gradio UI --- | |
def build_ui() -> gr.Blocks:
    """Construct and return the Gradio Blocks UI.

    Fixes vs. the original:
    - The conversation ``state`` was wired as an input but never written back,
      so every turn started from an empty history and the chat displayed only
      the latest exchange. ``_submit`` now also returns the updated history
      into ``state``.
    - The mojibake at the start of the setup-issue banner is replaced with the
      intended warning emoji.
    """
    error_banner = ""
    try:
        load_config()
        ensure_data_ready()
    except Exception as e:
        # Surface setup problems in the UI instead of crashing at import time.
        error_banner = (
            "⚠️ **Setup issue:**\n\n"
            f"- {e}\n\n"
            "**Fix:** Put `MLS14 - Adv RAG.zip` into `./data/` (recommended) or next to `app.py`, "
            "and ensure `config.json` exists next to `app.py` (or set `OPENAI_API_KEY` + `OPENAI_BASE_URL`)."
        )
    with gr.Blocks(title="NutriWise AI Assistant") as demo:
        gr.Markdown("# NutriWise AI Assistant\nAdvanced RAG over a nutrition PDF + image retrieval")
        if error_banner:
            gr.Markdown(error_banner)
        chatbot = gr.Chatbot(label="Chat", height=360)
        gallery = gr.Gallery(label="Relevant Images", columns=3, height=240, preview=True)
        state = gr.State([])
        with gr.Row():
            msg = gr.Textbox(label="Ask a question", placeholder="e.g., What are the benefits of vitamin C?", scale=5)
            send = gr.Button("Send", scale=1)

        def _submit(message, history):
            new_history, images = chat(message, history)
            # Return the history twice: once for display, once to persist in state.
            return new_history, images, new_history

        send.click(_submit, inputs=[msg, state], outputs=[chatbot, gallery, state]).then(lambda: "", None, msg)
        msg.submit(_submit, inputs=[msg, state], outputs=[chatbot, gallery, state]).then(lambda: "", None, msg)
        gr.Markdown("Tip: first run may take a bit while it builds vector stores (saved locally).")
    return demo
if __name__ == "__main__":
    # Build the UI and serve it on Gradio's default local port.
    build_ui().launch()