# NewNutriapp / app.py -- Hugging Face Hub upload header ("Upload 6 files", f53107a) kept as a comment; it was not valid Python.
"""NutriWise AI Assistant (Advanced RAG) - Gradio app
What this app does:
- Loads your OpenAI key + base URL from config.json (or environment variables).
- Builds (or reuses) a text vector store from Vitamin_and_minerals.pdf
- Builds (or reuses) an image vector store from the 'sources/' images via OpenCLIP embeddings.
- Expands the query, retrieves relevant text chunks, reranks them with a cross-encoder,
and generates an answer with a chat model.
- Shows top related images.
Folder expectations (recommended):
.
β”œβ”€β”€ app.py
β”œβ”€β”€ requirements.txt
β”œβ”€β”€ config.json # (NOT committed) contains API_KEY and OPENAI_API_BASE
└── data/
└── MLS14 - Adv RAG.zip # the provided zip (or already extracted folder)
Run:
pip install -r requirements.txt
python app.py
"""
from __future__ import annotations

import ast
import json
import os
import re
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any, List, Tuple

import gradio as gr
# --- Paths ---
BASE_DIR = Path(__file__).resolve().parent  # directory containing app.py
DATA_DIR = BASE_DIR / "data"  # where the data zip is expected / extracted
ZIP_NAME = "MLS14 - Adv RAG.zip"  # provided course-material archive
EXTRACT_DIR = DATA_DIR / "MLS14 - Adv RAG"  # zip contents land here
PDF_NAME = "Vitamin_and_minerals.pdf"  # source document for the text store
IMAGES_DIR_NAME = "sources"  # folder (inside the extract) holding the images
TEXT_DB_DIR = BASE_DIR / "vectordb_text"  # persisted Chroma text vector store
IMAGE_DB_DIR = BASE_DIR / "my_vectordb"  # persisted ChromaDB image store
# --- Config ---
def load_config() -> dict:
    """Read ``config.json`` next to app.py (if present) and mirror it into env vars.

    Returns the parsed config dict ({} when no config file exists); relies on
    pre-set environment variables in that case.
    """
    config: dict = {}
    config_path = BASE_DIR / "config.json"
    if config_path.exists():
        config = json.loads(config_path.read_text(encoding="utf-8"))
    # Mirror the notebook's setup: export the key/base URL for the OpenAI SDK.
    api_key = config.get("API_KEY")
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key
    base_url = config.get("OPENAI_API_BASE")
    if base_url:
        os.environ["OPENAI_BASE_URL"] = base_url
    return config
def ensure_data_ready() -> Tuple[Path, Path]:
    """Ensure the zip is extracted (or already present) and return (pdf_path, images_dir).

    Raises FileNotFoundError when the zip cannot be located or the expected
    files are missing after extraction.
    """
    import zipfile

    DATA_DIR.mkdir(exist_ok=True)
    pdf_path = EXTRACT_DIR / PDF_NAME
    images_dir = EXTRACT_DIR / IMAGES_DIR_NAME
    # Fast path: everything already extracted from a previous run.
    if pdf_path.exists() and images_dir.exists():
        return pdf_path, images_dir
    # Look for the archive in ./data, next to app.py, or the working directory.
    candidates = (DATA_DIR / ZIP_NAME, BASE_DIR / ZIP_NAME, Path.cwd() / ZIP_NAME)
    zip_path = None
    for candidate in candidates:
        if candidate.exists():
            zip_path = candidate
            break
    if zip_path is None:
        raise FileNotFoundError(
            f"Could not find '{ZIP_NAME}'. Put it in '{DATA_DIR}/' or alongside app.py."
        )
    with zipfile.ZipFile(zip_path, "r") as archive:
        archive.extractall(DATA_DIR)
    # Validate that extraction produced exactly what the rest of the app needs.
    if not pdf_path.exists():
        raise FileNotFoundError(f"Expected PDF not found after extraction: {pdf_path}")
    if not images_dir.exists():
        raise FileNotFoundError(f"Expected images folder not found after extraction: {images_dir}")
    return pdf_path, images_dir
# --- LangChain / Chroma initialization ---
@dataclass
class RagAssets:
    """Container for all lazily-built RAG components (built once by ``init_assets``)."""

    llm: Any  # chat model used for query expansion and final answer generation
    embeddings: Any  # text embedding function backing the Chroma text store
    retriever_k: int  # candidates fetched per expanded query before reranking
    vectorstore: Any  # persisted Chroma text vector store over the PDF
    crossencoder: Any  # cross-encoder that reranks retrieved chunks
    image_collection: Any  # ChromaDB collection of OpenCLIP-embedded images
def _safe_parse_list(text: str) -> List[str]:
"""Best-effort parse for a list of queries produced by the LLM."""
text = text.strip()
# Try literal eval (Python list)
try:
val = ast.literal_eval(text)
if isinstance(val, list) and all(isinstance(x, str) for x in val):
return [x.strip() for x in val if x.strip()]
except Exception:
pass
# Try JSON list
try:
import json as _json
val = _json.loads(text)
if isinstance(val, list) and all(isinstance(x, str) for x in val):
return [x.strip() for x in val if x.strip()]
except Exception:
pass
# Fallback: one per line (strip bullets)
lines = []
for line in text.splitlines():
line = line.strip().lstrip("-β€’*").strip()
if line:
lines.append(line)
return lines
@lru_cache(maxsize=1)
def init_assets() -> RagAssets:
    """Build (once per process) every component the RAG pipeline needs.

    Cached with ``lru_cache`` so repeated chat turns reuse the same models and
    vector stores.  Side effects on first call: exports env vars from
    config.json, extracts the data zip if needed, and populates both vector
    stores on disk when they are empty.
    """
    load_config()
    pdf_path, images_dir = ensure_data_ready()
    # --- LangChain ---
    # Imported lazily so the module itself can load (and the UI can display a
    # setup banner) even when these heavy dependencies are slow to import.
    from langchain_openai import ChatOpenAI, OpenAIEmbeddings
    from langchain_community.document_loaders import PyPDFLoader
    from langchain_community.vectorstores import Chroma
    from langchain_community.cross_encoders import HuggingFaceCrossEncoder
    embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)  # temperature=0: deterministic answers
    # --- Text vector store (persisted) ---
    vectorstore = Chroma(
        collection_name="vitamin_and_minerals",
        embedding_function=embeddings,
        persist_directory=str(TEXT_DB_DIR),
    )
    # Populate if empty
    try:
        # NOTE(review): `_collection` is a private Chroma attribute; used here
        # only to check whether the persisted store already holds documents.
        count = vectorstore._collection.count()  # type: ignore[attr-defined]
    except Exception:
        count = 0
    if count == 0:
        loader = PyPDFLoader(str(pdf_path))
        docs = loader.load()  # one Document per PDF page (no further chunking)
        vectorstore.add_documents(docs)
        try:
            # Best-effort: presumably a no-op on newer langchain-chroma
            # versions that auto-persist -- TODO confirm against pinned deps.
            vectorstore.persist()
        except Exception:
            pass
    retriever_k = 8  # candidates fetched per expanded query before reranking
    # --- Reranker (cross-encoder) ---
    crossencoder = HuggingFaceCrossEncoder(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2")
    # --- Image vector store (ChromaDB client) ---
    import chromadb
    from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
    from chromadb.utils.data_loaders import ImageLoader
    client = chromadb.PersistentClient(path=str(IMAGE_DB_DIR))
    image_loader = ImageLoader()
    image_embed_fn = OpenCLIPEmbeddingFunction()
    image_collection = client.get_or_create_collection(
        name="nutrition_images",
        embedding_function=image_embed_fn,
        data_loader=image_loader,
    )
    # Index the images only on first run (collection persisted on disk).
    if image_collection.count() == 0:
        image_paths = sorted(
            [p for p in images_dir.glob("*") if p.suffix.lower() in {".png", ".jpg", ".jpeg", ".webp"}]
        )
        metadatas = []
        for p in image_paths:
            # Infer which vitamin an image depicts from its filename prefix
            # (e.g. "vitaminc_oranges.jpg" -> "C"); "?" when unrecognized.
            vitamin = "?"
            s = p.stem.lower()
            if s.startswith("vitamina"):
                vitamin = "A"
            elif s.startswith("vitaminc"):
                vitamin = "C"
            elif s.startswith("vitamind"):
                vitamin = "D"
            metadatas.append({"vitamin": vitamin, "info": f"Food sources of Vitamin {vitamin}"})
        image_collection.add(
            ids=[str(i) for i in range(len(image_paths))],
            uris=[str(p) for p in image_paths],
            metadatas=metadatas,
        )
    return RagAssets(
        llm=llm,
        embeddings=embeddings,
        retriever_k=retriever_k,
        vectorstore=vectorstore,
        crossencoder=crossencoder,
        image_collection=image_collection,
    )
# --- RAG logic ---
def expand_query(user_query: str, llm: Any) -> List[str]:
    """Ask the LLM for 3-5 same-meaning rephrasings of *user_query*.

    Guarantees the original query is present and caps the result at 6 queries.
    """
    prompt = (
        "You are helping query expansion for document retrieval.\n"
        "Given the question, produce 3 to 5 alternative rephrasings that keep the SAME meaning.\n"
        "- Use synonyms and alternate phrasing.\n"
        "- Keep acronyms and unknown terms unchanged.\n"
        "- Return ONLY a JSON array of strings.\n"
        f"Question: {user_query}\n"
    )
    raw = llm.invoke(prompt).content
    expansions = [candidate for candidate in _safe_parse_list(raw) if candidate.strip()]
    # Make sure the user's own wording is always retrieved against.
    if user_query not in expansions:
        expansions.insert(0, user_query)
    return expansions[:6]
def retrieve_and_rerank(user_query: str, queries: List[str], assets: RagAssets) -> List[Any]:
    """Retrieve candidates for every expanded query, dedupe, and rerank.

    Scoring uses the cross-encoder against the ORIGINAL user query; returns
    the top 5 documents (empty list when nothing was retrieved).
    """
    retriever = assets.vectorstore.as_retriever(search_kwargs={"k": assets.retriever_k})
    seen_fingerprints = set()
    unique_docs = []
    for query in queries:
        for doc in retriever.get_relevant_documents(query):
            # Dedupe across expansions on (source, page, content prefix).
            fingerprint = (doc.metadata.get("source"), doc.metadata.get("page"), doc.page_content[:200])
            if fingerprint in seen_fingerprints:
                continue
            seen_fingerprints.add(fingerprint)
            unique_docs.append(doc)
    if not unique_docs:
        return []
    scores = assets.crossencoder.score([[user_query, doc.page_content] for doc in unique_docs])
    best_first = sorted(zip(unique_docs, scores), key=lambda pair: pair[1], reverse=True)
    return [doc for doc, _score in best_first[:5]]
def build_prompt(user_query: str, docs: List[Any]) -> str:
    """Assemble the grounded-answer prompt from the reranked documents.

    Each document is numbered and labelled with its source file name and page;
    an explicit placeholder is used when retrieval came back empty.
    """
    if docs:
        numbered = [
            f"[{idx}] (source={Path(doc.metadata.get('source', 'document')).name}, "
            f"page={doc.metadata.get('page', '?')})\n{doc.page_content}"
            for idx, doc in enumerate(docs, 1)
        ]
        context = "\n\n".join(numbered)
    else:
        context = "(no context retrieved)"
    return (
        "You are an expert assistant for nutrition Q&A.\n"
        "Answer the user's question using ONLY the context below.\n"
        "If the context is insufficient, say what is missing and what you'd look up next.\n"
        f"Question: {user_query}\n"
        "Context:\n"
        f"{context}\n"
    )
def retrieve_images(user_query: str, assets: RagAssets, n: int = 3) -> List[Tuple[str, str]]:
    """Return up to *n* (image_uri, caption) pairs matching the query.

    Captions come from each image's "info" metadata; missing metadata yields
    an empty caption.
    """
    result = assets.image_collection.query(
        query_texts=[user_query],
        n_results=n,
        include=["uris", "metadatas"],
    )
    if not result:
        return []
    uris = result.get("uris", [[]])[0]
    metadatas = result.get("metadatas", [[]])[0]
    pairs = []
    for uri, metadata in zip(uris, metadatas):
        caption = (metadata or {}).get("info", "")
        pairs.append((uri, caption))
    return pairs
def chat(user_message: str, history: List[Tuple[str, str]]) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
    """Run one full RAG turn: expand, retrieve, rerank, answer, fetch images.

    Returns the chat history extended with this turn plus the related images
    as (uri, caption) pairs for the gallery.
    """
    assets = init_assets()
    expanded_queries = expand_query(user_message, assets.llm)
    top_docs = retrieve_and_rerank(user_message, expanded_queries, assets)
    answer = assets.llm.invoke(build_prompt(user_message, top_docs)).content
    gallery_items = retrieve_images(user_message, assets, n=3)
    return history + [(user_message, answer)], gallery_items
# --- Gradio UI ---
def build_ui() -> gr.Blocks:
    """Create the Gradio Blocks UI.

    Bug fix vs. the original wiring: a ``gr.State([])`` was passed as the
    history input but the updated history returned by ``chat()`` was never
    written back to it, so every turn reached ``chat()`` with an empty history
    and the chatbot showed only the latest exchange.  The chatbot component
    itself now serves as the history -- its value is exactly the list of
    (user, assistant) tuples that ``chat()`` accepts and returns.
    """
    error_banner = ""
    try:
        load_config()
        ensure_data_ready()
    except Exception as e:
        error_banner = (
            "⚠️ **Setup issue:**\n\n"
            f"- {e}\n\n"
            "**Fix:** Put `MLS14 - Adv RAG.zip` into `./data/` (recommended) or next to `app.py`, "
            "and ensure `config.json` exists next to `app.py` (or set `OPENAI_API_KEY` + `OPENAI_BASE_URL`)."
        )
    with gr.Blocks(title="NutriWise AI Assistant") as demo:
        gr.Markdown("# NutriWise AI Assistant\nAdvanced RAG over a nutrition PDF + image retrieval")
        if error_banner:
            gr.Markdown(error_banner)
        chatbot = gr.Chatbot(label="Chat", height=360)
        gallery = gr.Gallery(label="Relevant Images", columns=3, height=240, preview=True)
        with gr.Row():
            msg = gr.Textbox(label="Ask a question", placeholder="e.g., What are the benefits of vitamin C?", scale=5)
            send = gr.Button("Send", scale=1)

        def _submit(message, history):
            # `history` is the chatbot's current value, so context accumulates
            # across turns; guard against None on the very first turn.
            return chat(message, history or [])

        # Feed the chatbot back in as the history input (see docstring).
        send.click(_submit, inputs=[msg, chatbot], outputs=[chatbot, gallery]).then(lambda: "", None, msg)
        msg.submit(_submit, inputs=[msg, chatbot], outputs=[chatbot, gallery]).then(lambda: "", None, msg)
        gr.Markdown("Tip: first run may take a bit while it builds vector stores (saved locally).")
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app and start the Gradio server.
    demo = build_ui()
    demo.launch()