Spaces:
Running
Running
| """ | |
| app.py β IFPRI Discussion Papers RAG Search Application | |
| Loads a pre-built FAISS index (produced by ingest.py) and provides a | |
| Gradio interface for semantic search and AI-assisted Q&A over 1,681 | |
| IFPRI Discussion Papers. | |
| Environment variables: | |
| HF_TOKEN β Hugging Face token (required for LLM inference) | |
| LLM_MODEL β HF model ID (default: mistralai/Mistral-7B-Instruct-v0.3) | |
| EMBED_MODEL β Embedding model (default: BAAI/bge-small-en-v1.5) | |
| TOP_K β Max papers to retrieve (default: 5) | |
| """ | |
| import csv | |
| import json | |
| import os | |
| from pathlib import Path | |
| import gradio as gr | |
| from huggingface_hub import InferenceClient | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
# ── Config ────────────────────────────────────────────────────────────────────
# All settings are overridable via environment variables (Space secrets).
EMBED_MODEL = os.getenv("EMBED_MODEL", "BAAI/bge-small-en-v1.5")
# Models confirmed working via HF auto-routing (tested March 2026):
#   Qwen/Qwen2.5-7B-Instruct — default, confirmed working
LLM_MODEL = os.getenv("LLM_MODEL", "Qwen/Qwen2.5-7B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN")  # required for LLM inference; embedding/search works without it
TOP_K = int(os.getenv("TOP_K", "5"))  # default paper count shown in the UI slider
INDEX_DIR = Path(os.getenv("INDEX_DIR", "faiss_index"))
META_FILE = Path(os.getenv("META_FILE", "metadata.json"))
ITEMS_CSV = Path(os.getenv("ITEMS_CSV", "output-items.csv"))
PDFS_CSV = Path(os.getenv("PDFS_CSV", "output-pdfs.csv"))

# System prompt sent with every chat completion.
SYSTEM_PROMPT = (
    "You are a knowledgeable research assistant for the International Food Policy "
    "Research Institute (IFPRI). You help food policy researchers find relevant "
    "Discussion Papers and synthesize key findings. When answering, cite the "
    "specific paper(s) (by number) you draw on. Be concise and precise."
)
# ── Load resources ────────────────────────────────────────────────────────────
print(f"Loading embedding model: {EMBED_MODEL}")
_embeddings = HuggingFaceEmbeddings(
    model_name=EMBED_MODEL,
    model_kwargs={"device": "cpu"},
    # Unit-norm vectors so FAISS L2 distance maps cleanly onto relevance
    # (see cosine_to_pct below).
    encode_kwargs={"normalize_embeddings": True},
)

print(f"Loading FAISS index from: {INDEX_DIR}")
# allow_dangerous_deserialization is acceptable here: the pickle inside the
# index directory is produced by our own ingest.py, not untrusted input.
_vectorstore = FAISS.load_local(
    str(INDEX_DIR), _embeddings, allow_dangerous_deserialization=True
)

print(f"Loading paper metadata from: {META_FILE}")
with open(META_FILE, encoding="utf-8") as f:
    _papers_by_id: dict[str, dict] = {p["paper_id"]: p for p in json.load(f)}

# Build paper_id -> {url, title, doc_type} lookups from output-pdfs.csv + output-items.csv.
_paper_urls: dict[str, str] = {}
_paper_titles: dict[str, str] = {}
_paper_types: dict[str, str] = {}
if ITEMS_CSV.exists() and PDFS_CSV.exists():
    print(f"Loading paper metadata from {ITEMS_CSV.name} + {PDFS_CSV.name}")
    with open(ITEMS_CSV, encoding="utf-8") as f:
        items_by_uuid = {r["item-uuid"]: r for r in csv.DictReader(f)}
    with open(PDFS_CSV, encoding="utf-8") as f:
        for row in csv.DictReader(f):
            # The PDF bitstream name (minus ".pdf") doubles as the paper id.
            paper_id = row["bitstream-name"].removesuffix(".pdf")
            item = items_by_uuid.get(row["item-uuid"], {})
            if item.get("md_dc_identifier_uri"):
                _paper_urls[paper_id] = item["md_dc_identifier_uri"]
            if item.get("md_dc_title"):
                _paper_titles[paper_id] = item["md_dc_title"]
            if item.get("md_dcterms_type"):
                _paper_types[paper_id] = item["md_dcterms_type"]
    print(f"  Loaded metadata for {len(_paper_urls)} papers")
else:
    print("  [WARN] output-items.csv / output-pdfs.csv not found — metadata will be omitted")

# Log only the token tail so the secret never leaks into Space logs.
_token_display = f"...{HF_TOKEN[-5:]}" if HF_TOKEN and len(HF_TOKEN) >= 5 else ("(not set)" if not HF_TOKEN else HF_TOKEN)
print(f"HF_TOKEN: {_token_display}")
print(f"LLM model: {LLM_MODEL}")
_llm = InferenceClient(model=LLM_MODEL, token=HF_TOKEN)
print("Ready.\n")
| # ββ RAG helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def retrieve(query: str, n_papers: int) -> list[tuple]:
    """
    Return up to n_papers unique (doc, score) pairs, deduplicated by paper_id.

    Over-fetches 4x the requested number of chunks (the code uses
    ``k=n_papers * 4``; the old docstring incorrectly said 3x) so that, after
    collapsing multiple chunks of the same paper, enough unique papers remain.
    """
    raw = _vectorstore.similarity_search_with_score(query, k=n_papers * 4)
    seen: dict[str, tuple] = {}
    for doc, score in raw:
        pid = doc.metadata.get("paper_id", "")
        if pid not in seen:
            seen[pid] = (doc, score)
        if len(seen) >= n_papers:
            break
    # Sort ascending by L2 distance (lower = more similar = higher relevance).
    return sorted(seen.values(), key=lambda x: x[1])
def build_context(hits: list[tuple]) -> str:
    """Format retrieved (doc, score) pairs as a numbered context string for the LLM."""
    sections: list[str] = []
    for idx, (doc, _score) in enumerate(hits, start=1):
        pid = doc.metadata.get("paper_id", "")
        title = _paper_titles.get(pid) or doc.metadata.get("title", "Untitled")
        link = _paper_urls.get(pid, "")
        # Flatten newlines and cap each excerpt at 600 chars to bound prompt size.
        excerpt = doc.page_content.replace("\n", " ").strip()[:600]
        link_suffix = f" | {link}" if link else ""
        sections.append(f"[{idx}] {title} (ID: {pid}{link_suffix})\n{excerpt}")
    return "\n\n---\n\n".join(sections)
def ask_llm(query: str, context: str) -> str:
    """
    Ask the hosted LLM to answer `query` grounded in the retrieved `context`.

    Returns the model's answer text, or a user-facing markdown error message
    if the inference call fails (e.g. missing/invalid HF_TOKEN, model busy).
    Never raises.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": (
                f"Using the IFPRI Discussion Paper excerpts below, answer the "
                f"following research question. Cite papers by their bracketed "
                f"number.\n\nQuestion: {query}\n\nContext:\n{context}"
            ),
        },
    ]
    try:
        resp = _llm.chat_completion(
            messages=messages,
            max_tokens=600,
            temperature=0.3,  # low temperature: factual synthesis, not creativity
        )
        return resp.choices[0].message.content.strip()
    except Exception as exc:  # broad by design: surface any inference failure in the UI
        return (
            f"⚠️ Could not reach the LLM ({exc}).\n\n"
            "Check that **HF_TOKEN** is set and valid, or try again in a moment."
        )
| # ββ Citation linker βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def linkify_citations(text: str, hits: list[tuple]) -> str:
    """Replace [N] citation markers in LLM answer with markdown hyperlinks."""
    import re

    # Map citation number -> paper URL, skipping papers with no known URL.
    links: dict[int, str] = {}
    for idx, (doc, _score) in enumerate(hits, start=1):
        target = _paper_urls.get(doc.metadata.get("paper_id", ""), "")
        if target:
            links[idx] = target

    def _substitute(m):
        num = int(m.group(1))
        target = links.get(num)
        # Leave markers without a known URL untouched.
        return f"[[{num}]]({target})" if target else m.group(0)

    return re.sub(r"\[(\d+)\]", _substitute, text)
| # ββ Relevance score helper ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def cosine_to_pct(score: float) -> str:
    """
    Map a FAISS L2-distance score onto an intuitive 0-100% relevance string.

    With normalised embeddings the distance falls in [0, 2]: 0 means the
    vectors coincide (100%) and 2 means they are opposite (0%). Scores
    outside that interval are clamped before scaling.
    """
    clamped = max(0.0, min(score, 2.0))
    relevance = (1.0 - clamped / 2.0) * 100
    return f"{relevance:.1f}%"
| # ββ Main search function ββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def rag_search(query: str, n_papers: int) -> tuple[str, str]:
    """
    Run the full RAG pipeline for one query.

    Returns an (answer_markdown, paper_cards_markdown) pair for the two
    Gradio output panes; either element may be a user-facing message when
    the query is blank or retrieval finds nothing.
    """
    query = query.strip()
    if not query:
        return "Please enter a research question or keyword.", ""

    hits = retrieve(query, n_papers)
    if not hits:
        return "No relevant papers found. Try different keywords.", ""

    context = build_context(hits)
    answer = linkify_citations(ask_llm(query, context), hits)

    # Format one markdown card per retrieved paper.
    cards = []
    for i, (doc, score) in enumerate(hits, 1):
        pid = doc.metadata.get("paper_id", "")
        meta = _papers_by_id.get(pid, {})  # hoisted: was looked up twice per paper
        title = _paper_titles.get(pid) or doc.metadata.get("title", "Untitled")
        doc_type = _paper_types.get(pid, "")
        author = meta.get("author", "")
        pages = meta.get("num_pages", "")
        url = _paper_urls.get(pid, "")
        rel = cosine_to_pct(score)
        snippet = doc.page_content.replace("\n", " ").strip()[:350]
        # Optional lines are emitted only when the corresponding field exists.
        title_line = f"### [{i}] [{title}]({url})" if url else f"### [{i}] {title}"
        type_line = f"**Type:** {doc_type} \n" if doc_type else ""
        author_line = f"**Author(s):** {author} \n" if author else ""
        pages_line = f"**Pages:** {pages} \n" if pages else ""
        url_line = f"**URL:** {url} \n" if url else ""
        cards.append(
            f"{title_line}\n"
            f"{type_line}"
            f"**Relevance:** {rel} \n"
            f"{author_line}"
            f"{pages_line}"
            f"{url_line}"
            f"> {snippet}…"
        )

    return answer, "\n\n---\n\n".join(cards)
| # ββ Gradio UI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Canned queries shown under the search box. Each entry is a one-element list
# because gr.Examples expects one value per input component.
EXAMPLES = [
    ["What are the impacts of climate change on food security in Sub-Saharan Africa?"],
    ["How does irrigation investment affect smallholder farmers' income?"],
    ["What is the role of social protection programs in reducing hunger?"],
    ["How do food price shocks affect vulnerable households in developing countries?"],
    ["What policies support women's empowerment in agriculture?"],
    ["How can index-based insurance help farmers manage drought risk?"],
    ["What are the effects of agricultural R&D investment on crop productivity?"],
    ["How does urbanization affect food demand and dietary patterns?"],
]
with gr.Blocks(
    title="IFPRI Discussion Papers Search",
    # BUGFIX: theme= and css= are gr.Blocks() constructor arguments; they were
    # previously passed to demo.launch(), which has no such parameters and
    # raises TypeError at startup (so the styling never applied either).
    theme=gr.themes.Soft(primary_hue="green", font=gr.themes.GoogleFont("Inter")),
    css="""
    .gradio-container { max-width: 1100px; margin: auto; }
    #query-box textarea { font-size: 16px; }
    """,
) as demo:
    gr.Markdown(
        """
        # 📚 IFPRI Discussion Papers — AI Search

        Search and explore **1,681 IFPRI Discussion Papers** using semantic AI search.
        Ask a research question or enter keywords; the app retrieves the most relevant
        papers and generates a synthesised answer with citations.

        > **Powered by** `BAAI/bge-small-en-v1.5` embeddings · `Qwen/Qwen2.5-7B-Instruct` LLM
        """
    )
    with gr.Row():
        with gr.Column(scale=5):
            query_box = gr.Textbox(
                label="Research question or keywords",
                placeholder=(
                    "e.g. 'What are effective strategies to reduce child "
                    "malnutrition in South Asia?'"
                ),
                lines=2,
                elem_id="query-box",
            )
        with gr.Column(scale=1, min_width=160):
            n_slider = gr.Slider(
                minimum=3, maximum=10, value=TOP_K, step=1,
                label="Papers to return",
            )
    search_btn = gr.Button("🔍 Search", variant="primary", size="lg")
    gr.Markdown("---")
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### 🤖 AI Answer")
            answer_md = gr.Markdown(value="*Results will appear here after searching.*")
        with gr.Column(scale=3):
            gr.Markdown("### 📄 Relevant Papers")
            papers_md = gr.Markdown(value="")
    gr.Markdown("---")
    gr.Examples(
        examples=EXAMPLES,
        inputs=query_box,
        label="Example queries — click to try",
        examples_per_page=8,
    )

    # Wire up events: both the button and Enter in the textbox trigger a search.
    search_btn.click(rag_search, [query_box, n_slider], [answer_md, papers_md])
    query_box.submit(rag_search, [query_box, n_slider], [answer_md, papers_md])

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # bind all interfaces (required inside a Space container)
        server_port=7860,
        share=False,
    )