# Captured from the Hugging Face Space "Nutri" (app.py, uploaded by
# farfar101, revision 75228b2, ~25 kB). The page-header lines from the
# capture are preserved here as a comment so the module stays valid Python.
"""
Pet Nutrition Knowledge Assistant — Hugging Face Space
=======================================================
ZeroGPU-compatible Gradio app that wraps the DeBERTa + FLAN-T5 hybrid RAG
pipeline over two veterinary nutrition PDFs.
Expected Space file layout
--------------------------
app.py ← this file
requirements.txt
docs/
FEDIAF-Nutritional-Guidelines_2025-ONLINE.pdf
Essential cat and dog nutrition booklet V2 - electronic version.pdf
"""
# ── ZeroGPU / Spaces compatibility ──────────────────────────────────────────
import spaces # must be imported before torch / transformers on ZeroGPU
import io
import os
import re
import time
import unicodedata
from collections import Counter
import faiss
import gradio as gr
import numpy as np
import pandas as pd
import torch
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pypdf import PdfReader
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder, SentenceTransformer
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
pipeline,
)
# ── Constants ────────────────────────────────────────────────────────────────
MODEL_ID = "deepset/deberta-v3-base-squad2"                     # extractive QA model
GENERATOR_MODEL_ID = "google/flan-t5-base"                      # grounded answer generator
EMBEDDING_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"   # dense retriever
RERANKER_MODEL_ID = "cross-encoder/ms-marco-MiniLM-L-6-v2"      # cross-encoder reranker
# NOTE(review): PDFs are resolved next to app.py, but the module docstring and
# the missing-file error below both mention a 'docs/' folder — confirm layout.
DOCS_DIR = os.path.dirname(__file__)
PDF_FILENAMES = [
    "FEDIAF-Nutritional-Guidelines_2025-ONLINE.pdf",
    "Essential cat and dog nutrition booklet V2 - electronic version.pdf",
]
# Splitter for prose blocks only; table blocks use a custom line-based splitter.
PROSE_SPLITTER = RecursiveCharacterTextSplitter(
    chunk_size=700,
    chunk_overlap=120,
    separators=["\n\n", "\n", ". ", "; ", " ", ""],
)
DEFAULT_CONFIDENCE_THRESHOLD = 0.30  # min extractive-QA score before fallback answer
DENSE_INITIAL_K = 28                 # candidates pulled from the FAISS index
LEXICAL_INITIAL_K = 28               # candidates pulled from BM25
RRF_K = 60                           # reciprocal-rank-fusion damping constant
RETRIEVER_CANDIDATE_K = 20           # fused candidates passed to the reranker
SOURCE_MATCH_BONUS = 2.5             # flat score bonus for the inferred preferred source
NEIGHBOR_SCORE_DECAY = 0.12          # score decay per chunk of distance from a seed
MAX_CHUNKS_PER_PAGE = 2              # per-page diversity cap in final results
MMR_LAMBDA = 0.7                     # relevance vs. novelty trade-off in MMR
# Keyword hints used to guess which PDF a query is about (see infer_preferred_source).
SOURCE_HINTS = {
    "FEDIAF-Nutritional-Guidelines_2025-ONLINE.pdf": [
        "fediaf", "publication month", "publication year",
        "recommendation tables", "adult maintenance",
        "growth and reproduction", "per 100 g dm",
        "nutritional maximum", "legal maximum",
        "canned pet food", "dry pet food",
        "body condition score", "optimal body fat",
        "scale of 1 to 9",
    ],
    "Essential cat and dog nutrition booklet V2 - electronic version.pdf": [
        "understanding dogs", "understanding cats", "teeth",
        "obligate carnivore", "bile salts", "amino acids",
        "complete and balanced nutrition", "essential nutrients",
        "most important nutrient", "water", "drinking water",
        "macronutrients", "gross energy", "fat", "protein",
        "carbohydrate", "phosphorus", "light bearing", "ash",
        "dietary minerals",
    ],
}
# ── PDF loading & chunking (runs once at startup, CPU-only) ──────────────────
TABLE_COLUMN_SPLIT_RE = re.compile(r"\s{2,}")
# Patterns hoisted so they are compiled once, not on every line.
# Fixes vs. original: (1) the original wrapped "%" in \b...\b, but \b never
# matches between two non-word characters, so a trailing "%" token could not
# match — "%" is now matched on its own; (2) the original "Β΅g" alternative was
# a mojibake of "µg" and additionally could never match text already passed
# through .lower() — restored to "µg".
_UNIT_RE = re.compile(r"\b(g|mg|kg|iu|kcal|mj|dm|µg|ug)\b|%")
_DIGIT_RE = re.compile(r"\d")


def is_table_like_line(line: str) -> bool:
    """Heuristically decide whether *line* looks like a table row.

    A table-like line has at least one multi-space column gap, is not
    excessively long (<= 180 chars), and contains a number or a
    measurement unit.
    """
    stripped = line.strip()
    if not stripped:
        return False
    has_columns = bool(TABLE_COLUMN_SPLIT_RE.search(stripped))
    has_numeric = bool(_DIGIT_RE.search(stripped))
    has_unit = bool(_UNIT_RE.search(stripped.lower()))
    return has_columns and len(stripped) <= 180 and (has_numeric or has_unit)
def normalize_page_text(raw_text: str) -> str:
    """Clean one page of extracted PDF text.

    Collapses runs of blank lines to a single empty separator (never at
    the start), rewrites table-like rows as " | "-joined columns, and
    squeezes whitespace inside prose lines.
    """
    out = []
    saw_blank = False
    for original in raw_text.splitlines():
        line = original.replace("\x00", "").replace("\t", " ").rstrip()
        stripped = line.strip()
        if not stripped:
            # Keep at most one blank separator between populated lines.
            if out and not saw_blank:
                out.append("")
            saw_blank = True
            continue
        saw_blank = False
        if is_table_like_line(line):
            cells = [
                cell.strip()
                for cell in TABLE_COLUMN_SPLIT_RE.split(stripped)
                if cell.strip()
            ]
            out.append(" | ".join(cells))
        else:
            out.append(re.sub(r"\s+", " ", line).strip())
    return "\n".join(out).strip()
def load_docs(pdf_path: str, source_name: str) -> list[dict]:
    """Extract and normalize text from every page of one PDF.

    Prefers pypdf's layout extraction (it keeps the column spacing that
    table detection relies on) and falls back to plain extraction on
    pypdf versions whose extract_text() lacks extraction_mode. Pages
    that normalize to empty text are dropped.
    """
    page_records = []
    with open(pdf_path, "rb") as handle:
        reader = PdfReader(handle)
        for number, page in enumerate(reader.pages, start=1):
            try:
                text = page.extract_text(extraction_mode="layout") or ""
            except TypeError:
                # Older pypdf: extract_text() takes no extraction_mode.
                text = page.extract_text() or ""
            normalized = normalize_page_text(text)
            if normalized:
                page_records.append(
                    {"source": source_name, "page": number, "text": normalized}
                )
    print(f" '{source_name}': extracted {len(page_records)} pages.")
    return page_records
def is_heading_line(line: str) -> bool:
    """Heuristic heading detector for normalized page text.

    A heading is non-empty, short (<= 120 chars), not a table row (no
    "|"), does not end in sentence punctuation, and is either mostly
    uppercase (>= 55% of letters) or at most ten words long.
    """
    text = line.strip()
    if not text:
        return False
    if len(text) > 120 or "|" in text:
        return False
    if text.endswith((".", ",", ";")):
        return False
    letters = [ch for ch in text if ch.isalpha()]
    if letters:
        upper_ratio = sum(1 for ch in letters if ch.isupper()) / len(letters)
    else:
        upper_ratio = 0.0
    return upper_ratio >= 0.55 or len(text.split()) <= 10
def split_page_into_blocks(page_text: str) -> list[dict]:
    """Segment a normalized page into typed blocks.

    Lines are classified as "table" (contain "|"), "heading", or
    "prose"; consecutive lines of the same type form one raw block and a
    blank line always closes the current block. Headings are then folded
    into the next non-heading block as leading text; a trailing heading
    with nothing after it is kept as its own block.
    """
    raw = []
    buf = []
    buf_type = "prose"

    def emit():
        nonlocal buf, buf_type
        if buf:
            raw.append({"type": buf_type, "text": "\n".join(buf).strip()})
        buf, buf_type = [], "prose"

    for raw_line in page_text.splitlines():
        text = raw_line.strip()
        if not text:
            emit()
            continue
        if "|" in text:
            kind = "table"
        elif is_heading_line(text):
            kind = "heading"
        else:
            kind = "prose"
        if buf and kind != buf_type:
            emit()
        if not buf:
            buf_type = kind
        buf.append(text)
    emit()

    result = []
    heading = None
    for block in raw:
        if block["type"] == "heading":
            # Accumulate consecutive headings until a body block arrives.
            heading = block["text"] if heading is None else f"{heading}\n{block['text']}"
            continue
        body = f"{heading}\n{block['text']}" if heading else block["text"]
        heading = None
        result.append({"type": block["type"], "text": body})
    if heading:
        result.append({"type": "heading", "text": heading})
    return result
def chunk_block_text(block_text: str, block_type: str) -> list[str]:
    """Split one block's text into retrieval-sized chunks.

    Tables are split on line boundaries at roughly 700 characters per
    chunk, carrying the last one or two rows into the next chunk as
    overlap; prose is delegated to the shared recursive splitter.
    """
    if block_type != "table":
        return PROSE_SPLITTER.split_text(block_text)

    rows = [row.strip() for row in block_text.splitlines() if row.strip()]
    out = []
    buf = []
    used = 0
    for row in rows:
        cost = len(row) + 1
        if buf and used + cost > 700:
            out.append("\n".join(buf))
            # Carry trailing row(s) over so adjacent chunks share context.
            carry = buf[-2:] if len(buf) > 2 else buf[-1:]
            buf = list(carry)
            used = sum(len(r) + 1 for r in buf)
        buf.append(row)
        used += cost
    if buf:
        out.append("\n".join(buf))
    return out
def build_retrieval_text(chunk: dict) -> str:
    """Compose the text that is embedded/indexed for one chunk.

    Prefixes the raw chunk text with lightweight metadata (source file
    stem, page number, block type, and a content-style hint) so both the
    dense and lexical retrievers can match against it.
    """
    stem = re.sub(r"\.pdf$", "", chunk["source"], flags=re.IGNORECASE)
    if chunk["block_type"] == "table":
        style = "content style: table values, nutrient units, and label-value pairs"
    else:
        style = "content style: prose explanation and definitions"
    return "\n".join([
        f"source: {stem}",
        f"page: {chunk['page']}",
        f"block type: {chunk['block_type']}",
        style,
        chunk["text"],
    ])
def build_chunks(all_pages: list[dict]) -> tuple[list[dict], dict]:
    """Turn page records into indexed chunk records.

    Returns the flat chunk list plus a (source, page) -> [chunks] lookup.
    Each chunk also receives its position within its page so neighbor
    expansion can locate adjacent chunks later.
    """
    records = []
    for page in all_pages:
        for b_idx, block in enumerate(split_page_into_blocks(page["text"])):
            pieces = chunk_block_text(block["text"], block["type"])
            for c_idx, piece in enumerate(pieces):
                record = {
                    "source": page["source"],
                    "page": page["page"],
                    "chunk_id": f"{page['source']}_p{page['page']}_b{b_idx}_c{c_idx}",
                    "block_type": block["type"],
                    "block_index": b_idx,
                    "chunk_index": c_idx,
                    "text": piece,
                }
                record["retrieval_text"] = build_retrieval_text(record)
                records.append(record)
    lookup: dict = {}
    for record in records:
        lookup.setdefault((record["source"], record["page"]), []).append(record)
    for same_page in lookup.values():
        for position, record in enumerate(same_page):
            record["page_chunk_position"] = position
    return records, lookup
# ── Build indices (CPU, runs once at startup) ─────────────────────────────────
# Load every configured PDF once at import time; fail fast with a clear
# message if a document is missing from the repository.
# NOTE(review): the error text says 'docs/' but DOCS_DIR points at the app.py
# directory — confirm which layout is intended.
print("Loading PDFs …")
all_pages = []
for fname in PDF_FILENAMES:
    path = os.path.join(DOCS_DIR, fname)
    if not os.path.exists(path):
        raise FileNotFoundError(
            f"PDF not found: {path}\n"
            "Place both PDFs inside a 'docs/' folder in your Space repository."
        )
    all_pages.extend(load_docs(path, fname))

print("Building chunks …")
chunks, page_chunk_lookup = build_chunks(all_pages)
print(f"Total chunks: {len(chunks)}")
# Dense retriever: encode every chunk's retrieval_text once with a
# sentence-transformer. normalize_embeddings=True makes inner product
# equal cosine similarity for the FAISS IndexFlatIP built below.
print(f"Loading embedding model: {EMBEDDING_MODEL_ID} …")
embedder = SentenceTransformer(EMBEDDING_MODEL_ID)
retrieval_corpus_texts = [c.get("retrieval_text", c["text"]) for c in chunks]
print("Encoding chunks …")
corpus_embeddings = embedder.encode(
    retrieval_corpus_texts,
    batch_size=64,
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True,
)
def bm25_tokenize(text: str) -> list[str]:
    """Lowercase *text* and split it into BM25 tokens.

    Any run of characters outside [a-z0-9%+.-] acts as a delimiter, so
    numbers, percentages, and unit-like tokens survive intact.
    """
    cleaned = re.sub(r"[^a-z0-9%+.-]+", " ", text.lower())
    return cleaned.split()
# Lexical retriever over the same corpus texts as the dense index.
bm25_corpus_tokens = [bm25_tokenize(t) for t in retrieval_corpus_texts]
bm25_index = BM25Okapi(bm25_corpus_tokens)

# Exact inner-product FAISS index (embeddings are already L2-normalized,
# so inner product behaves as cosine similarity).
embedding_dim = corpus_embeddings.shape[1]
faiss_index = faiss.IndexFlatIP(embedding_dim)
faiss_index.add(corpus_embeddings)
print(f"FAISS index built ({faiss_index.ntotal} vectors).")

# Cross-encoder used by retrieve() to rerank fused candidates.
print(f"Loading reranker: {RERANKER_MODEL_ID} …")
reranker = CrossEncoder(RERANKER_MODEL_ID)
print("Reranker ready.")
# ── Models loaded with @spaces.GPU ───────────────────────────────────────────
# QA pipeline and generator are initialised inside the inference function
# so ZeroGPU can attach a GPU just-in-time. We pre-load tokenizer/model
# weights to CPU here so the first request isn't slow for _loading_.
print(f"Pre-loading QA model weights: {MODEL_ID} …")
from transformers import AutoModelForQuestionAnswering

_qa_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
_qa_model = AutoModelForQuestionAnswering.from_pretrained(MODEL_ID)

print(f"Pre-loading generator weights: {GENERATOR_MODEL_ID} …")
_gen_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL_ID)
_gen_model = AutoModelForSeq2SeqLM.from_pretrained(GENERATOR_MODEL_ID)
print("All weights pre-loaded to CPU. ZeroGPU will move them on first call.")
# ── Retriever helpers (CPU, no GPU needed) ────────────────────────────────────
def canonicalize_source_name(name):
    """Normalize a source filename for comparison; None/empty -> None."""
    if not name:
        return None
    return name.strip().lower()
def infer_preferred_source(query: str):
    """Guess which PDF a query targets by counting keyword hits.

    Returns the SOURCE_HINTS filename with the most keyword matches in
    the lowercased query, or None when no keyword appears at all.
    """
    lowered = query.lower()
    best_source = None
    best_hits = 0
    for source, keywords in SOURCE_HINTS.items():
        hits = sum(kw in lowered for kw in keywords)
        if hits > best_hits:
            best_source, best_hits = source, hits
    return best_source
def reciprocal_rank_fusion(rankings, k=RRF_K):
    """Fuse several rankings with reciprocal-rank fusion.

    Each element of *rankings* is an ordered sequence of corpus indices;
    an item at rank r contributes 1/(k + r + 1) to its fused score.
    Returns (index, score) pairs sorted by fused score, highest first.
    """
    fused = {}
    for ranking in rankings:
        for position, index in enumerate(ranking):
            fused[index] = fused.get(index, 0.0) + 1.0 / (k + position + 1)
    return sorted(fused.items(), key=lambda item: item[1], reverse=True)
def chunk_distance(a, b):
    """Distance between two chunks in within-page order.

    Chunks from different sources or pages are treated as effectively
    infinitely far apart (sentinel 10_000); otherwise the absolute
    difference of their page positions is returned.
    """
    same_page = a["source"] == b["source"] and a["page"] == b["page"]
    if not same_page:
        return 10_000
    pos_a = int(a.get("page_chunk_position", 0))
    pos_b = int(b.get("page_chunk_position", 0))
    return abs(pos_a - pos_b)
def expand_with_neighbors(ranked, preferred_source=None, seed_count=8):
    """Augment reranked chunks with their immediate page neighbors.

    For each of the top *seed_count* chunks, the chunk directly before
    and after it on the same page (via the module-level
    page_chunk_lookup) is added, scored as the seed's score minus
    NEIGHBOR_SCORE_DECAY per position of distance. An already-present
    entry is only replaced when the neighbor-derived score is higher.
    Returns every chunk sorted by (score, rerank_score, rrf_score)
    descending.
    """
    expanded = {c["chunk_id"]: c.copy() for c in ranked}
    pref_norm = canonicalize_source_name(preferred_source)
    for seed in ranked[:seed_count]:
        page_key = (seed["source"], seed["page"])
        page_ch = page_chunk_lookup.get(page_key, [])
        pos = int(seed.get("page_chunk_position", -1))
        if pos < 0:
            # Seed lacks a page position; its neighbors cannot be located.
            continue
        for offset in (-1, 1):
            np_ = pos + offset
            if np_ < 0 or np_ >= len(page_ch):
                continue
            nb = page_ch[np_].copy()
            dist = chunk_distance(seed, nb)
            if dist >= 10_000:
                # Different page/source sentinel — not a true neighbor.
                continue
            nb.update({
                # Keep any retrieval scores the neighbor already carries;
                # default the rest to 0.0 so downstream sorting never KeyErrors.
                "faiss_score": nb.get("faiss_score", 0.0),
                "bm25_score": nb.get("bm25_score", 0.0),
                "rrf_score": nb.get("rrf_score", 0.0),
                "source_match": int(pref_norm is not None and
                                    canonicalize_source_name(nb["source"]) == pref_norm),
                "neighbor_seed": seed["chunk_id"],
                "neighbor_distance": dist,
            })
            # Neighbors inherit the seed's score, decayed by distance.
            nb_score = round(seed["score"] - NEIGHBOR_SCORE_DECAY * dist, 4)
            nb["rerank_score"] = nb.get("rerank_score", nb_score)
            nb["final_score"] = nb_score
            nb["score"] = nb_score
            existing = expanded.get(nb["chunk_id"])
            if existing is None or nb["score"] > existing.get("score", float("-inf")):
                expanded[nb["chunk_id"]] = nb
    return sorted(
        expanded.values(),
        key=lambda c: (c.get("score", 0), c.get("rerank_score", 0), c.get("rrf_score", 0)),
        reverse=True,
    )
def mmr_select(query_embedding, candidates, top_k, lambda_mult=MMR_LAMBDA):
    """Greedy maximal-marginal-relevance selection over *candidates*.

    Trades each candidate's precomputed "score" against its maximum
    similarity to the already-selected set (embeddings are normalized,
    so the dot product is cosine similarity). Returns *candidates*
    unchanged when no selection is needed.

    NOTE(review): *query_embedding* is currently unused — relevance
    comes from the "score" field instead; confirm this is intended.
    """
    if len(candidates) <= top_k:
        return candidates
    embeddings = embedder.encode(
        [c.get("retrieval_text", c["text"]) for c in candidates],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    relevance = np.array([float(c.get("score", 0)) for c in candidates])
    chosen = []
    pool = list(range(len(candidates)))
    while pool and len(chosen) < top_k:
        if chosen:
            chosen_embs = embeddings[chosen]
            pick, pick_val = None, float("-inf")
            for idx in pool:
                redundancy = float(np.max(chosen_embs @ embeddings[idx]))
                mmr = lambda_mult * relevance[idx] - (1 - lambda_mult) * redundancy
                if mmr > pick_val:
                    pick_val, pick = mmr, idx
        else:
            # First pick: pure relevance.
            pick = max(pool, key=lambda idx: relevance[idx])
        chosen.append(pick)
        pool.remove(pick)
    return [candidates[idx] for idx in chosen]
def cap_page_duplicates(candidates, top_k, max_per_page=MAX_CHUNKS_PER_PAGE):
    """Limit how many chunks any single (source, page) contributes.

    Walks *candidates* in order, keeping at most *max_per_page* per
    page; if that yields fewer than *top_k* results, skipped chunks are
    appended back (still in order) to fill the quota. Returns at most
    *top_k* chunks.
    """
    kept, overflow = [], []
    per_page = {}
    for cand in candidates:
        page_key = (cand["source"], cand["page"])
        seen = per_page.get(page_key, 0)
        if seen >= max_per_page:
            overflow.append(cand)
        else:
            kept.append(cand)
            per_page[page_key] = seen + 1
    if len(kept) < top_k:
        # Backfill with skipped chunks to honor the requested top_k.
        for cand in overflow:
            if cand in kept:
                continue
            kept.append(cand)
            if len(kept) >= top_k:
                break
    return kept[:top_k]
def retrieve(query: str, top_k: int = 5, candidate_k: int = RETRIEVER_CANDIDATE_K,
             preferred_source=None) -> list[dict]:
    """Hybrid retrieval: dense FAISS + BM25, fused, reranked, diversified.

    Pipeline: (1) dense and lexical search, (2) reciprocal-rank fusion,
    (3) cross-encoder reranking with a flat bonus for chunks from the
    inferred preferred source, (4) neighbor expansion, (5) MMR
    diversification, (6) per-page duplicate capping. Returns at most
    *top_k* enriched chunk dicts (may be empty).
    """
    candidate_k = max(candidate_k, top_k)
    preferred_source = preferred_source or infer_preferred_source(query)
    dense_k = max(candidate_k, DENSE_INITIAL_K)
    lexical_k = max(candidate_k, LEXICAL_INITIAL_K)
    q_vec = embedder.encode([query], convert_to_numpy=True, normalize_embeddings=True)
    q_emb = q_vec[0]
    # Dense retrieval (inner product over normalized vectors = cosine).
    d_scores, d_idxs = faiss_index.search(q_vec, dense_k)
    dense_ranking = [i for i in d_idxs[0] if i != -1]  # -1 marks empty FAISS slots
    dense_score_map = {i: round(float(s), 4) for s, i in zip(d_scores[0], d_idxs[0]) if i != -1}
    # Lexical retrieval over the whole corpus.
    lex_scores = bm25_index.get_scores(bm25_tokenize(query))
    lex_ranking = list(np.argsort(lex_scores)[::-1][:lexical_k])
    lex_score_map = {i: round(float(lex_scores[i]), 4) for i in lex_ranking}
    # Fuse both rankings, keep the top candidate_k indices.
    fused = reciprocal_rank_fusion([dense_ranking, lex_ranking])
    fused_map = dict(fused)
    fused_idxs = [i for i, _ in fused[:candidate_k]]
    candidates = []
    for idx in fused_idxs:
        c = chunks[idx].copy()
        c["faiss_score"] = dense_score_map.get(idx, 0.0)
        c["bm25_score"] = lex_score_map.get(idx, 0.0)
        c["rrf_score"] = round(fused_map.get(idx, 0.0), 6)
        c["source_match"] = int(
            preferred_source is not None and
            canonicalize_source_name(c["source"]) == canonicalize_source_name(preferred_source)
        )
        candidates.append(c)
    if not candidates:
        return []
    # Cross-encoder rerank; final score adds SOURCE_MATCH_BONUS on match.
    pairs = [(query, c.get("retrieval_text", c["text"])) for c in candidates]
    rr_scores = reranker.predict(pairs)
    for c, rs in zip(candidates, rr_scores):
        c["rerank_score"] = round(float(rs), 4)
        c["final_score"] = round(c["rerank_score"] + SOURCE_MATCH_BONUS * c["source_match"], 4)
        c["score"] = c["final_score"]
    candidates.sort(key=lambda c: (c["final_score"], c["rerank_score"], c["rrf_score"]), reverse=True)
    expanded = expand_with_neighbors(candidates, preferred_source, seed_count=min(len(candidates), max(top_k, 8)))
    # Over-select (2x top_k) before the page cap trims back to top_k.
    diversified = mmr_select(q_emb, expanded, top_k=max(top_k * 2, top_k))
    return cap_page_duplicates(diversified, top_k)
def assemble_context(retrieved: list[dict], max_chars: int = 2000) -> str:
    """Concatenate retrieved chunk texts into one QA context string.

    Chunks are joined with single spaces up to *max_chars*; the first
    chunk that would overflow is truncated to the remaining budget, and
    only kept when more than 50 characters of budget remain.
    """
    pieces = []
    used = 0
    for chunk in retrieved:
        body = chunk["text"]
        if used + len(body) > max_chars:
            budget = max_chars - used
            if budget > 50:
                pieces.append(body[:budget])
            break
        pieces.append(body)
        used += len(body)
    return " ".join(pieces)
# ── Inference: decorated with @spaces.GPU so ZeroGPU attaches a GPU ──────────
@spaces.GPU
def run_rag(query: str, confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD) -> dict:
    """Full RAG pipeline. Runs on GPU when available via ZeroGPU.

    Returns a dict with keys: answer, confidence, answer_mode
    ("generative" | "extractive" | "fallback"), sources, latency_ms
    and, when any context was retrieved, extractive_answer /
    generative_answer as well.
    """
    device = 0 if torch.cuda.is_available() else -1
    # Move models to the active device for this GPU slot
    if torch.cuda.is_available():
        _qa_model.to("cuda")
        _gen_model.to("cuda")
    qa_pipe = pipeline(
        "question-answering",
        model=_qa_model,
        tokenizer=_qa_tokenizer,
        device=device,
    )
    t0 = time.time()
    preferred_source = infer_preferred_source(query)
    retrieved = retrieve(query, top_k=5, preferred_source=preferred_source)
    context = assemble_context(retrieved, max_chars=2000)
    fallback_msg = (
        "The system could not extract a sufficiently supported answer "
        "from the retrieved documents. Please rephrase your query or "
        "consult the source manuals directly."
    )
    if not context.strip():
        # Nothing retrieved: short-circuit to the fallback answer.
        return {
            "answer": fallback_msg, "confidence": 0.0, "answer_mode": "fallback",
            "sources": [], "latency_ms": round((time.time() - t0) * 1000, 1),
        }
    # Extractive answer + confidence from the DeBERTa QA head.
    qa_out = qa_pipe(question=query, context=context)
    confidence = round(float(qa_out["score"]), 4)
    extractive = qa_out["answer"].strip() or "No answer extracted."
    # Grounded generative answer via FLAN-T5
    generative = ""
    try:
        evidence_blocks = []
        for rank, c in enumerate(retrieved[:5], 1):
            evidence_blocks.append(f"[{rank}] Source: {c['source']} | Page: {c['page']}\n{c['text']}")
        evidence_text = "\n\n".join(evidence_blocks) or "No evidence retrieved."
        prompt = (
            "You are a grounded pet nutrition assistant.\n"
            "Use only the evidence excerpts below. Do not use outside knowledge.\n"
            "Your job is to turn the extractive evidence into a natural chatbot answer "
            "while staying faithful to the sources.\n"
            f"If the evidence does not clearly support an answer, reply exactly with:\n{fallback_msg}\n\n"
            f"Question: {query}\n"
            f"Extractive hint: {extractive}\n"
            f"Extractive confidence: {confidence:.4f}\n\n"
            f"Evidence:\n{evidence_text}\n\n"
            "Write a concise answer in 1-3 sentences.\n"
            "Prefer natural conversational wording, but keep every claim grounded in the evidence.\n"
            "Do not invent unsupported facts or recommendations."
        )
        inputs = _gen_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}
        with torch.no_grad():
            # Greedy decoding keeps the answer deterministic.
            out_ids = _gen_model.generate(**inputs, max_new_tokens=96, do_sample=False)
        generative = _gen_tokenizer.decode(out_ids[0], skip_special_tokens=True).strip()
        generative = re.sub(r"\s+", " ", generative).strip()
    except Exception as exc:
        # Best-effort generation: on failure fall through to extractive/fallback.
        print(f"Generator failed: {exc}")
        generative = ""
    # Prefer the generative answer unless it echoed the fallback verbatim;
    # otherwise use the extractive span if it clears the threshold.
    if generative and generative != fallback_msg:
        answer, mode = generative, "generative"
    elif confidence >= confidence_threshold:
        answer, mode = extractive, "extractive"
    else:
        answer, mode = fallback_msg, "fallback"
    sources = list({(c["source"], c["page"]) for c in retrieved})  # de-duplicated
    latency_ms = round((time.time() - t0) * 1000, 1)
    return {
        "answer": answer, "confidence": confidence, "answer_mode": mode,
        "sources": sources, "latency_ms": latency_ms,
        "extractive_answer": extractive, "generative_answer": generative,
    }
# ── Gradio UI ────────────────────────────────────────────────────────────────
def chat_fn(message: str, history: list) -> str:
    """Gradio chat callback: run the RAG pipeline and format the reply.

    Any exception degrades to a fixed fallback message (with the error
    shown in italics) so the chat UI never raises.

    Fix: the bullet and dash in the source list were mojibake
    ("β€’"/"β€”" — UTF-8 read as Latin-1); restored to "•" and "—".
    """
    try:
        result = run_rag(message)
        answer = result["answer"]
        confidence = result["confidence"]
        mode = result["answer_mode"]
        sources = result["sources"]
        latency = result["latency_ms"]
        src_lines = "\n".join(f" • {src} — p.{pg}" for src, pg in sorted(sources))
        return (
            f"{answer}\n\n"
            f"---\n"
            f"**Mode:** {mode} &nbsp;|&nbsp; **Confidence:** {confidence:.4f} &nbsp;|&nbsp; "
            f"**Latency:** {latency} ms\n\n"
            f"**Sources:**\n{src_lines}"
        )
    except Exception as exc:
        return (
            "The system could not extract a sufficiently supported answer from the "
            "retrieved documents. Please rephrase your question or consult the source manuals.\n\n"
            f"*(Error: {exc})*"
        )
# Chat front-end: a Markdown header plus a ChatInterface wired to chat_fn.
with gr.Blocks(theme=gr.themes.Ocean(), title="Pet Nutrition Knowledge Assistant 🐾") as demo:
    # Static header describing the grounding sources and model stack.
    gr.Markdown(
        """
# 🐾 Pet Nutrition Knowledge Assistant
Ask questions about dog and cat nutrition.
Answers are grounded in two authoritative veterinary nutrition sources:
- **FEDIAF Nutritional Guidelines (2025)**
- **WALTHAM Essential Cat and Dog Nutrition Booklet (v2)**
*Powered by DeBERTa-v3 extractive QA + FLAN-T5 generative synthesis over a hybrid FAISS + BM25 retriever.*
"""
    )
    # Example prompts are not cached: answers depend on runtime-loaded models.
    chatbot = gr.ChatInterface(
        fn=chat_fn,
        examples=[
            "Proteins include a total of how many different amino acids?",
            "Dogs have how many teeth?",
            "What does MER stand for?",
            "How much water is in dry pet food?",
            "what is wet pet food?",
            "What are the effects of phosphorus deficiency?",
            "What is the common name for the remaining material made up of dietary minerals?",
            "What does L stand for",
            "How many essential nutrients are required by cats and dogs?",
        ],
        cache_examples=False,
    )

demo.launch()