"""Gradio front-end for the Agentic RAG Knowledge Search app.

The UI talks to a FastAPI backend over HTTP (base URL taken from the
``FASTAPI_URL`` environment variable) using four endpoints: ``/chat``,
``/upload``, ``/reset`` and ``/clear_memory``.  All network helpers return
plain strings / lists so they can be wired directly to Gradio components.
"""

import os
import uuid

import gradio as gr
import requests

# NOTE(review): presumably silences the HF tokenizers fork warning; this
# module itself never imports tokenizers -- confirm it is still needed.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

FASTAPI_URL = os.getenv("FASTAPI_URL", "http://127.0.0.1:8000")
CHAT_ENDPOINT = f"{FASTAPI_URL}/chat"
UPLOAD_ENDPOINT = f"{FASTAPI_URL}/upload"
RESET_ENDPOINT = f"{FASTAPI_URL}/reset"
CLEAR_MEMORY_ENDPOINT = f"{FASTAPI_URL}/clear_memory"

SUPPORTED_TYPES = [".pdf", ".docx", ".txt", ".md", ".csv"]

CSS = """
#submit-btn {
    margin-top: auto;
    margin-bottom: auto;
    height: 80px;
}
#input-row {
    align-items: center;
}
"""

SOURCE_LABELS = {
    "rag": "Uploaded Documents (RAG)",
    "web": "Web Search",
    "rag+web": "Documents + Web Search",
    "unknown": "Unknown",
}

# Glyphs are written as escapes so the file survives being re-saved in a
# non-UTF-8 encoding (the literals had previously been corrupted that way).
_BAR_FULL = "\u2588"    # full block
_BAR_EMPTY = "\u2591"   # light shade
_PAPERCLIP = "\U0001F4CE"


def _json_field(resp, key: str, default):
    """Return ``resp.json()[key]``, or *default* if the body is unusable.

    A body that is not valid JSON, or that is JSON but not an object, yields
    *default* instead of raising.
    """
    try:
        body = resp.json()
    except ValueError:  # covers json and requests JSONDecodeError alike
        return default
    if isinstance(body, dict):
        return body.get(key, default)
    return default


def upload_files(files) -> str:
    """Send the selected files to the upload endpoint and return a status line.

    Args:
        files: Items from the ``gr.File`` component; each is either a path
            string or an object exposing the path as ``.name``.

    Returns:
        The backend's ``status`` field on success, otherwise a human-readable
        failure message.  Never raises for I/O or HTTP errors.
    """
    if not files:
        return "No files selected."

    multipart = []
    try:
        for f in files:
            path = f if isinstance(f, str) else f.name
            with open(path, "rb") as fp:
                multipart.append(
                    ("files", (os.path.basename(path), fp.read(), "application/octet-stream"))
                )
    except OSError as e:
        # A temp file that vanished or is unreadable should produce a status
        # message, not an unhandled exception in the UI.
        return f"Upload failed: could not read file ({e})"

    try:
        # Large PDFs can take minutes to embed on CPU, so allow a generous timeout.
        resp = requests.post(UPLOAD_ENDPOINT, files=multipart, timeout=900)
        resp.raise_for_status()
    except requests.exceptions.ReadTimeout:
        return ("Still indexing. This file is large and CPU embedding is slow; "
                "wait a moment and try your question, it may already be indexed.")
    except requests.exceptions.RequestException as e:
        return f"Upload failed: {e}"
    return _json_field(resp, "status", "Files processed.")


def reset_documents() -> str:
    """Ask the backend to drop uploaded documents; return a status line."""
    try:
        resp = requests.post(RESET_ENDPOINT, timeout=10)
        resp.raise_for_status()
    except requests.exceptions.RequestException as e:
        return f"Reset failed: {e}"
    return _json_field(resp, "status", "Documents cleared.")


def _bar(value: float, width: int = 12) -> str:
    """Render *value* (expected in [0, 1]) as a text bar exactly *width* cells wide."""
    # Clamp so out-of-range scores cannot produce a negative repeat count
    # (which would silently change the bar's total length).
    filled = max(0, min(width, round(value * width)))
    return _BAR_FULL * filled + _BAR_EMPTY * (width - filled)


def _fmt(label: str, value: float, hint: str) -> str:
    """Format one metric as three Markdown lines: label, bar + score, hint."""
    # Two trailing spaces before "\n" are a Markdown hard line break; with a
    # single space the three parts would collapse into one paragraph.
    return f"**{label}**  \n`{_bar(value)} {value:.2f} / 1.0`  \n_{hint}_\n"


def _faith_hint(score: float, source: str) -> str:
    """Explain a faithfulness score, naming the evidence type for *source*."""
    what = "web results" if source == "web" else "retrieved documents"
    if score >= 0.75:
        return f"Answer is well-grounded in the {what}"
    if score >= 0.50:
        return f"Answer is mostly grounded in the {what}, minor unsupported details possible"
    return f"Low grounding: answer may contain content not present in the {what}"


def _relevance_hint(score: float) -> str:
    """Explain an answer-relevance score."""
    if score >= 0.70:
        return "Answer directly addresses the question"
    if score >= 0.45:
        return "Answer is related but may not fully address the question"
    return "Answer may be off-topic or incomplete"


def _accuracy_hint(score: float) -> str:
    """Explain an accuracy score (similarity to the user's reference answer)."""
    if score >= 0.75:
        return "Strong match with your reference answer"
    if score >= 0.40:
        return "Partial match: some key points differ from your reference"
    return "Low overlap with your reference answer"


def _format_metrics(source: str, faithfulness, answer_relevance, accuracy) -> str:
    """Build the Markdown shown in the metrics panel.

    Each score may be ``None`` when the backend did not supply it.  Returns an
    empty string when neither faithfulness nor answer relevance is available.
    """
    if faithfulness is None and answer_relevance is None:
        return ""

    src_label = SOURCE_LABELS.get(source, source)
    lines = [f"**Answer source: {src_label}**", "---"]

    if faithfulness is not None:
        faith_label = (
            "Faithfulness (grounded in documents)"
            if source != "web"
            else "Faithfulness (grounded in web results)"
        )
        lines.append(_fmt(faith_label, faithfulness, _faith_hint(faithfulness, source)))

    if answer_relevance is not None:
        lines.append(
            _fmt(
                "Answer Relevance (addresses the question)",
                answer_relevance,
                _relevance_hint(answer_relevance),
            )
        )

    if accuracy is not None:
        lines.append(_fmt("Accuracy (matches your reference)", accuracy, _accuracy_hint(accuracy)))
    else:
        lines.append("_Accuracy: provide a reference answer above to see this score._")

    return "\n".join(lines)


def _with_exchange(history: list, user_text: str, assistant_text: str) -> list:
    """Return *history* extended by one user/assistant message pair."""
    return history + [
        {"role": "user", "content": user_text},
        {"role": "assistant", "content": assistant_text},
    ]


def process_query(message: str, api_key: str, reference: str, session_id: str,
                  chat_history: list) -> tuple[list, str, str]:
    """Send one question to the chat endpoint and update the conversation.

    Args:
        message: The user's question.
        api_key: Gemini API key forwarded to the backend with the request.
        reference: Optional expected answer; sent only when non-blank.
        session_id: Identifier forwarded so the backend can key its memory.
        chat_history: Current chatbot value (list of role/content dicts).

    Returns:
        ``(chat_history, status_text, metrics_markdown)``.  HTTP and transport
        failures are reported in the chat and status text rather than raised.
    """
    # Work on a copy and tolerate None so an empty component value cannot crash.
    history = list(chat_history or [])
    api_key = (api_key or "").strip()
    message = message or ""
    reference = (reference or "").strip()

    if not api_key:
        return history, "Enter your Google Gemini API key above to start.", ""
    if not message.strip():
        return history, "Please enter a question.", ""

    payload = {"query": message, "api_key": api_key, "session_id": session_id}
    if reference:
        payload["reference"] = reference

    try:
        resp = requests.post(CHAT_ENDPOINT, json=payload, timeout=180)
        if resp.status_code in (400, 401, 429):
            # These statuses carry a user-facing explanation in "detail";
            # str() guards against a structured (list/dict) detail value.
            detail = str(_json_field(resp, "detail", "Request could not be completed."))
            reply = f"Sorry, I can't respond right now:\n\n{detail}"
            return _with_exchange(history, message, reply), detail, ""
        resp.raise_for_status()
        data = resp.json()
        if not isinstance(data, dict):
            raise ValueError("unexpected response format")
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers non-JSON bodies on requests versions where the
        # decode error is not a RequestException.
        error = f"Request failed: {e}"
        return _with_exchange(history, message, f"ERROR: {error}"), error, ""

    response_text = data.get("response", "")
    source = data.get("source", "unknown")
    citations = data.get("citations", [])
    src_label = SOURCE_LABELS.get(source, source)

    footer = f"\n\n--- Source: {src_label} ---"
    if citations:
        footer += f"\n{_PAPERCLIP} Sources used: " + "; ".join(str(c) for c in citations)

    metrics_md = _format_metrics(
        source,
        data.get("faithfulness"),
        data.get("answer_relevance"),
        data.get("accuracy"),
    )
    history = _with_exchange(history, message, f"{response_text}{footer}")
    return history, f"Done. Answered via {src_label}", metrics_md


def clear_chat(session_id: str) -> tuple[list, str, str]:
    """Clear the UI and the server-side conversation memory for this session.

    The UI is always cleared.  If the backend cannot be reached, the status
    text says so, because follow-up questions may then still see old context.
    """
    status = ""
    try:
        requests.post(CLEAR_MEMORY_ENDPOINT, params={"session_id": session_id}, timeout=10)
    except requests.exceptions.RequestException as e:
        status = f"Chat cleared here, but server-side memory could not be cleared: {e}"
    return [], status, ""


# NOTE(review): component nesting below was reconstructed from a
# whitespace-collapsed source; confirm Row/Group membership against the
# intended layout.
with gr.Blocks(title="Agentic RAG Knowledge Search") as demo:
    session_id = gr.State()  # per-browser id for server-side memory

    gr.Markdown("# Agentic RAG Knowledge Search")
    gr.Markdown(
        "Upload your own documents and ask questions. "
        "The agent searches your files via RAG first, then falls back to web search if needed. "
        "It remembers the conversation, so follow-up questions work. "
        "**No files uploaded?** The agent uses a built-in legal/policy document as the default knowledge base."
    )

    with gr.Group():
        gr.Markdown(
            "### Your Gemini API Key (required)\n"
            "This app uses **your own** Google Gemini key; it is sent only with your requests and never stored. "
            "Get a free key at [Google AI Studio](https://aistudio.google.com/apikey)."
        )
        api_key_input = gr.Textbox(
            label="Google Gemini API Key",
            placeholder="Paste your API key here (starts with AIza...)",
            type="password",
            lines=1,
        )

    with gr.Group():
        gr.Markdown(
            f"### Upload Documents\n"
            # Two trailing spaces force a Markdown line break before the note.
            f"Supported: **{', '.join(SUPPORTED_TYPES)}**. Multiple files allowed; new uploads add to the index.  \n"
            "_Large PDFs (100s of pages) can take a few minutes to index on the free CPU, so please be patient._"
        )
        file_input = gr.File(label="Select Files", file_count="multiple", file_types=SUPPORTED_TYPES)
        with gr.Row():
            upload_btn = gr.Button("Process & Index Files", variant="primary")
            reset_btn = gr.Button("Clear Uploaded Documents", variant="secondary")
        upload_status = gr.Textbox(
            label="Upload Status",
            interactive=False,
            lines=2,
            placeholder="Upload status will appear here...",
        )

    gr.Markdown("---")

    with gr.Group():
        gr.Markdown(
            "### Ask a Question\n"
            "> Press **Enter** or click **Submit**. The agent decides whether to search your documents, the web, or both."
        )
        chatbot = gr.Chatbot(label="Conversation", height=420)
        with gr.Row(elem_id="input-row"):
            user_input = gr.Textbox(
                placeholder="Ask anything about your documents or the web...",
                lines=2,
                scale=5,
                show_label=False,
                container=False,
            )
            submit_btn = gr.Button("Submit", variant="primary", scale=1, elem_id="submit-btn")
        reference_input = gr.Textbox(
            label="Reference Answer (optional)",
            placeholder="Paste an expected correct answer here to also see an Accuracy score...",
            lines=2,
        )
        with gr.Row():
            clear_btn = gr.Button("Clear Chat", variant="secondary", scale=1)
        status_output = gr.Textbox(
            label="Status",
            interactive=False,
            lines=1,
            placeholder="Status will appear here...",
        )

    gr.Markdown("---")

    with gr.Group():
        gr.Markdown(
            "### Evaluation Metrics\n"
            "Computed automatically after every response, with **no extra API calls and no reference needed** for the first two.\n\n"
            "| Metric | Always shown? | What it measures |\n"
            "|---|---|---|\n"
            "| **Faithfulness** | Yes | Is the answer grounded in the source it used? (docs or web results) |\n"
            "| **Answer Relevance** | Yes | Does the answer actually address your question? |\n"
            "| **Accuracy** | Only with reference | Does the answer match an expected correct answer? |"
        )
        metrics_output = gr.Markdown()

    # Give every page load its own id so conversations do not share memory.
    demo.load(fn=lambda: str(uuid.uuid4()), inputs=[], outputs=[session_id])

    upload_btn.click(fn=upload_files, inputs=[file_input], outputs=[upload_status])
    reset_btn.click(fn=reset_documents, inputs=[], outputs=[upload_status])

    # Button click and Enter in the textbox run the same pipeline, then
    # empty the input box.
    query_inputs = [user_input, api_key_input, reference_input, session_id, chatbot]
    query_outputs = [chatbot, status_output, metrics_output]
    for register in (submit_btn.click, user_input.submit):
        register(
            fn=process_query,
            inputs=query_inputs,
            outputs=query_outputs,
        ).then(fn=lambda: "", inputs=[], outputs=[user_input])

    clear_btn.click(
        fn=clear_chat,
        inputs=[session_id],
        outputs=[chatbot, status_output, metrics_output],
    )


if __name__ == "__main__":
    # NOTE(review): passing theme/css to launch() relies on a Gradio release
    # whose launch() accepts them -- confirm against the pinned version.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, theme=gr.themes.Soft(), css=CSS)