# ifpri-dp-rag / app.py
# (Hugging Face Hub page header retained as a comment: uploaded with
# huggingface_hub, commit 5d71852, verified)
"""
app.py β€” IFPRI Discussion Papers RAG Search Application
Loads a pre-built FAISS index (produced by ingest.py) and provides a
Gradio interface for semantic search and AI-assisted Q&A over 1,681
IFPRI Discussion Papers.
Environment variables:
HF_TOKEN β€” Hugging Face token (required for LLM inference)
LLM_MODEL — HF model ID (default: Qwen/Qwen2.5-7B-Instruct)
EMBED_MODEL β€” Embedding model (default: BAAI/bge-small-en-v1.5)
TOP_K β€” Max papers to retrieve (default: 5)
"""
import csv
import json
import os
from pathlib import Path
import gradio as gr
from huggingface_hub import InferenceClient
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
# ── Config ────────────────────────────────────────────────────────────────────
# Sentence-embedding model; must match the one used by ingest.py to build the index.
EMBED_MODEL = os.getenv("EMBED_MODEL", "BAAI/bge-small-en-v1.5")
# Models confirmed working via HF auto-routing (tested March 2026):
# Qwen/Qwen2.5-7B-Instruct ← default, confirmed working
LLM_MODEL = os.getenv("LLM_MODEL", "Qwen/Qwen2.5-7B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN")  # required for hosted LLM inference
TOP_K = int(os.getenv("TOP_K", "5"))  # initial value of the "Papers to return" slider
INDEX_DIR = Path(os.getenv("INDEX_DIR", "faiss_index"))  # pre-built FAISS index directory
META_FILE = Path(os.getenv("META_FILE", "metadata.json"))  # per-paper metadata records
ITEMS_CSV = Path(os.getenv("ITEMS_CSV", "output-items.csv"))  # DSpace item export
PDFS_CSV = Path(os.getenv("PDFS_CSV", "output-pdfs.csv"))  # PDF bitstream → item mapping
# System prompt steering the LLM toward concise, citation-bearing answers.
SYSTEM_PROMPT = (
    "You are a knowledgeable research assistant for the International Food Policy "
    "Research Institute (IFPRI). You help food policy researchers find relevant "
    "Discussion Papers and synthesize key findings. When answering, cite the "
    "specific paper(s) (by number) you draw on. Be concise and precise."
)
# ── Load resources ────────────────────────────────────────────────────────────
print(f"Loading embedding model: {EMBED_MODEL}")
_embeddings = HuggingFaceEmbeddings(
    model_name=EMBED_MODEL,
    model_kwargs={"device": "cpu"},  # CPU-only host assumed
    encode_kwargs={"normalize_embeddings": True},  # normalised vectors → L2 maps to cosine
)
print(f"Loading FAISS index from: {INDEX_DIR}")
# allow_dangerous_deserialization: the index is a pickle produced by our own
# ingest.py — only ever load indexes built by this project.
_vectorstore = FAISS.load_local(
    str(INDEX_DIR), _embeddings, allow_dangerous_deserialization=True
)
print(f"Loading paper metadata from: {META_FILE}")
with open(META_FILE, encoding="utf-8") as f:
    # paper_id → full metadata record (fields such as "author", "num_pages")
    _papers_by_id: dict[str, dict] = {p["paper_id"]: p for p in json.load(f)}
# Build paper_id → {url, title, doc_type} lookup from output-pdfs.csv + output-items.csv
_paper_urls: dict[str, str] = {}
_paper_titles: dict[str, str] = {}
_paper_types: dict[str, str] = {}
if ITEMS_CSV.exists() and PDFS_CSV.exists():
    print(f"Loading paper metadata from {ITEMS_CSV.name} + {PDFS_CSV.name}")
    with open(ITEMS_CSV, encoding="utf-8") as f:
        # item-uuid → item row; used to join the PDF rows below.
        items_by_uuid = {r["item-uuid"]: r for r in csv.DictReader(f)}
    with open(PDFS_CSV, encoding="utf-8") as f:
        for row in csv.DictReader(f):
            # The PDF filename minus ".pdf" doubles as the paper_id.
            paper_id = row["bitstream-name"].removesuffix(".pdf")
            item = items_by_uuid.get(row["item-uuid"], {})
            if item.get("md_dc_identifier_uri"):
                _paper_urls[paper_id] = item["md_dc_identifier_uri"]
            if item.get("md_dc_title"):
                _paper_titles[paper_id] = item["md_dc_title"]
            if item.get("md_dcterms_type"):
                _paper_types[paper_id] = item["md_dcterms_type"]
    print(f" Loaded metadata for {len(_paper_urls)} papers")
else:
    print(" [WARN] output-items.csv / output-pdfs.csv not found — metadata will be omitted")
# Echo only a masked token suffix so the credential never appears in logs.
# BUG FIX: the previous one-liner fell through to printing a short (<5 char)
# token verbatim; now any set-but-short token is fully hidden instead.
if not HF_TOKEN:
    _token_display = "(not set)"
elif len(HF_TOKEN) >= 5:
    _token_display = f"...{HF_TOKEN[-5:]}"
else:
    _token_display = "(set, hidden)"
print(f"HF_TOKEN: {_token_display}")
print(f"LLM model: {LLM_MODEL}")
_llm = InferenceClient(model=LLM_MODEL, token=HF_TOKEN)
print("Ready.\n")
# ── RAG helpers ───────────────────────────────────────────────────────────────
def retrieve(query: str, n_papers: int) -> list[tuple]:
    """Fetch up to `n_papers` unique (doc, score) hits for `query`.

    Over-fetches 4x as many chunks as papers requested, then keeps only the
    first chunk seen per paper_id so multiple chunks from one paper cannot
    crowd out other papers. The result is sorted by ascending L2 distance
    (lower distance = more similar = more relevant).
    """
    candidates = _vectorstore.similarity_search_with_score(query, k=n_papers * 4)
    unique: dict[str, tuple] = {}
    for pair in candidates:
        pid = pair[0].metadata.get("paper_id", "")
        # setdefault keeps the earliest (best-ranked) chunk for each paper.
        unique.setdefault(pid, pair)
        if len(unique) >= n_papers:
            break
    return sorted(unique.values(), key=lambda hit: hit[1])
def build_context(hits: list[tuple]) -> str:
    """Render retrieved hits as a numbered, '---'-separated context string.

    Each entry is a header line ([N] title, paper ID, optional URL) followed
    by a 600-character excerpt of the chunk text for the LLM prompt.
    """
    def _entry(idx: int, doc) -> str:
        pid = doc.metadata.get("paper_id", "")
        # Prefer the CSV-sourced title; fall back to the chunk's own metadata.
        title = _paper_titles.get(pid) or doc.metadata.get("title", "Untitled")
        link = _paper_urls.get(pid, "")
        header = f"[{idx}] {title} (ID: {pid}{f' | {link}' if link else ''})"
        body = doc.page_content.replace("\n", " ").strip()[:600]
        return f"{header}\n{body}"

    return "\n\n---\n\n".join(_entry(i, doc) for i, (doc, _) in enumerate(hits, 1))
def ask_llm(query: str, context: str) -> str:
    """Ask the hosted LLM to answer `query` grounded in `context`.

    Returns the model's answer text, or a user-facing warning string when
    the inference call fails (missing/invalid token, transient outage).
    """
    user_msg = (
        f"Using the IFPRI Discussion Paper excerpts below, answer the "
        f"following research question. Cite papers by their bracketed "
        f"number.\n\nQuestion: {query}\n\nContext:\n{context}"
    )
    try:
        resp = _llm.chat_completion(
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_msg},
            ],
            max_tokens=600,
            temperature=0.3,
        )
        answer = resp.choices[0].message.content.strip()
    except Exception as exc:
        # Degrade gracefully: surface the failure in the UI instead of crashing.
        return (
            f"⚠️ Could not reach the LLM ({exc}).\n\n"
            "Check that **HF_TOKEN** is set and valid, or try again in a moment."
        )
    return answer
# ── Citation linker ───────────────────────────────────────────────────────────
def linkify_citations(text: str, hits: list[tuple]) -> str:
    """Turn bare [N] citation markers in the LLM answer into markdown links.

    Markers whose index has no known paper URL are left untouched.
    """
    import re

    links: dict[int, str] = {}
    for idx, (doc, _) in enumerate(hits, 1):
        target = _paper_urls.get(doc.metadata.get("paper_id", ""), "")
        if target:
            links[idx] = target

    def _to_link(m):
        num = int(m.group(1))
        return f"[[{num}]]({links[num]})" if num in links else m.group(0)

    return re.sub(r"\[(\d+)\]", _to_link, text)
# ── Relevance score helper ────────────────────────────────────────────────────
def cosine_to_pct(score: float) -> str:
    """Map a FAISS L2 distance onto a human-friendly 0–100% relevance label.

    With normalised embeddings the distance is expected in [0, 2]; values
    outside that window are clamped. 0 → "100.0%", 2 (or more) → "0.0%".
    """
    clamped = max(0.0, min(score, 2.0))
    return f"{(1.0 - clamped / 2.0) * 100:.1f}%"
# ── Main search function ──────────────────────────────────────────────────────
def rag_search(query: str, n_papers: int) -> tuple[str, str]:
    """Run the full RAG pipeline for one user query.

    Returns (markdown answer with linked citations, markdown paper cards).
    Empty input or zero hits short-circuit to a message plus empty cards.
    """
    query = query.strip()
    if not query:
        return "Please enter a research question or keyword.", ""
    hits = retrieve(query, n_papers)
    if not hits:
        return "No relevant papers found. Try different keywords.", ""

    answer = linkify_citations(ask_llm(query, build_context(hits)), hits)

    def _card(rank: int, doc, score) -> str:
        # Build one markdown card; optional fields collapse to "" when absent.
        pid = doc.metadata.get("paper_id", "")
        meta = _papers_by_id.get(pid, {})
        title = _paper_titles.get(pid) or doc.metadata.get("title", "Untitled")
        url = _paper_urls.get(pid, "")
        author = meta.get("author", "")
        pages = meta.get("num_pages", "")
        doc_type = _paper_types.get(pid, "")
        snippet = doc.page_content.replace("\n", " ").strip()[:350]
        heading = f"### [{rank}] [{title}]({url})" if url else f"### [{rank}] {title}"
        pieces = [
            f"{heading}\n",
            f"**Type:** {doc_type} \n" if doc_type else "",
            f"**Relevance:** {cosine_to_pct(score)} \n",
            f"**Author(s):** {author} \n" if author else "",
            f"**Pages:** {pages} \n" if pages else "",
            f"**URL:** {url} \n" if url else "",
            f"> {snippet}…",
        ]
        return "".join(pieces)

    cards = [_card(i, doc, score) for i, (doc, score) in enumerate(hits, 1)]
    return answer, "\n\n---\n\n".join(cards)
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Clickable example queries shown under the search box (one-element rows,
# matching gr.Examples' expected shape for a single input component).
EXAMPLES = [
    ["What are the impacts of climate change on food security in Sub-Saharan Africa?"],
    ["How does irrigation investment affect smallholder farmers' income?"],
    ["What is the role of social protection programs in reducing hunger?"],
    ["How do food price shocks affect vulnerable households in developing countries?"],
    ["What policies support women's empowerment in agriculture?"],
    ["How can index-based insurance help farmers manage drought risk?"],
    ["What are the effects of agricultural R&D investment on crop productivity?"],
    ["How does urbanization affect food demand and dietary patterns?"],
]
# BUG FIX: theme= and css= were previously passed to demo.launch(), but
# Blocks.launch() has no such parameters — they are gr.Blocks constructor
# options, so the app raised TypeError at startup. Moved them here.
with gr.Blocks(
    title="IFPRI Discussion Papers Search",
    theme=gr.themes.Soft(primary_hue="green", font=gr.themes.GoogleFont("Inter")),
    css="""
.gradio-container { max-width: 1100px; margin: auto; }
#query-box textarea { font-size: 16px; }
""",
) as demo:
    # Intro / header text.
    gr.Markdown(
        """
# 📚 IFPRI Discussion Papers — AI Search
Search and explore **1,681 IFPRI Discussion Papers** using semantic AI search.
Ask a research question or enter keywords; the app retrieves the most relevant
papers and generates a synthesised answer with citations.
> **Powered by** `BAAI/bge-small-en-v1.5` embeddings · `Qwen/Qwen2.5-7B-Instruct` LLM
"""
    )
    # Query input row: textbox + result-count slider.
    with gr.Row():
        with gr.Column(scale=5):
            query_box = gr.Textbox(
                label="Research question or keywords",
                placeholder=(
                    "e.g. 'What are effective strategies to reduce child "
                    "malnutrition in South Asia?'"
                ),
                lines=2,
                elem_id="query-box",
            )
        with gr.Column(scale=1, min_width=160):
            n_slider = gr.Slider(
                minimum=3, maximum=10, value=TOP_K, step=1,
                label="Papers to return",
            )
    search_btn = gr.Button("🔍 Search", variant="primary", size="lg")
    gr.Markdown("---")
    # Results row: synthesized answer on the left, paper cards on the right.
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### 🤖 AI Answer")
            answer_md = gr.Markdown(value="*Results will appear here after searching.*")
        with gr.Column(scale=3):
            gr.Markdown("### 📄 Relevant Papers")
            papers_md = gr.Markdown(value="")
    gr.Markdown("---")
    gr.Examples(
        examples=EXAMPLES,
        inputs=query_box,
        label="Example queries — click to try",
        examples_per_page=8,
    )
    # Wire up events: button click and Enter in the textbox share one handler.
    search_btn.click(rag_search, [query_box, n_slider], [answer_md, papers_md])
    query_box.submit(rag_search, [query_box, n_slider], [answer_md, papers_md])

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )