# File-viewer page residue (not part of the program): last commit "fix CI" (2b63102) by Devisri515; file size 11 kB.
import os
import uuid
import gradio as gr
import requests
# Set for the whole process at import time. Presumably this silences the
# Hugging Face tokenizers parallelism/fork warning — TODO confirm it is needed
# in this front-end process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Base URL of the FastAPI backend, overridable through the FASTAPI_URL
# environment variable. Any trailing slash is stripped so the endpoint URLs
# built below never contain a double slash (e.g. "http://host//chat").
FASTAPI_URL = os.getenv("FASTAPI_URL", "http://127.0.0.1:8000").rstrip("/")
CHAT_ENDPOINT = f"{FASTAPI_URL}/chat"
UPLOAD_ENDPOINT = f"{FASTAPI_URL}/upload"
RESET_ENDPOINT = f"{FASTAPI_URL}/reset"
CLEAR_MEMORY_ENDPOINT = f"{FASTAPI_URL}/clear_memory"

# File extensions offered by the upload widget and listed in the UI text.
SUPPORTED_TYPES = [".pdf", ".docx", ".txt", ".md", ".csv"]

# Custom CSS passed to demo.launch(); targets the elem_id values used in the UI.
CSS = """
#submit-btn { margin-top: auto; margin-bottom: auto; height: 80px; }
#input-row { align-items: center; }
"""

# Maps the backend's "source" value to the label shown to the user.
SOURCE_LABELS = {
    "rag": "Uploaded Documents (RAG)",
    "web": "Web Search",
    "rag+web": "Documents + Web Search",
    "unknown": "Unknown",
}
def upload_files(files) -> str:
    """Send the selected files to the backend upload endpoint.

    Args:
        files: The selected files. Each item is either a path string or an
            object exposing the path as ``.name``. May be empty or ``None``.

    Returns:
        A status message for the UI: the ``status`` field of the backend's
        JSON reply on success, otherwise a description of what went wrong.
        This function reports failures as text instead of raising for file
        read errors and for requests exceptions.
    """
    if not files:
        return "No files selected."

    # Read every file before contacting the backend so that an unreadable or
    # vanished file is reported as a normal status message instead of
    # escaping as an unhandled OSError.
    multipart = []
    try:
        for f in files:
            path = f if isinstance(f, str) else f.name
            with open(path, "rb") as fp:
                content = fp.read()
            multipart.append(
                ("files", (os.path.basename(path), content, "application/octet-stream"))
            )
    except OSError as e:
        return f"Upload failed: {e}"

    try:
        # Large PDFs can take minutes to embed on CPU, so allow a generous timeout.
        resp = requests.post(UPLOAD_ENDPOINT, files=multipart, timeout=900)
        resp.raise_for_status()
        return resp.json().get("status", "Files processed.")
    except requests.exceptions.ReadTimeout:
        # Must come before RequestException: ReadTimeout is a subclass of it.
        return ("Still indexing. This file is large and CPU embedding is slow; "
                "wait a moment and try your question, it may already be indexed.")
    except requests.exceptions.RequestException as e:
        return f"Upload failed: {e}"
def reset_documents() -> str:
    """POST to the backend reset endpoint and return a status message.

    Returns:
        The ``status`` field of the backend's JSON reply (falling back to
        "Documents cleared."), or a "Reset failed" message when the request
        raises a requests exception.
    """
    try:
        reply = requests.post(RESET_ENDPOINT, timeout=10)
        reply.raise_for_status()
        body = reply.json()
    except requests.exceptions.RequestException as err:
        return f"Reset failed: {err}"
    return body.get("status", "Documents cleared.")
def _bar(value: float, width: int = 12) -> str:
filled = round(value * width)
return "█" * filled + "░" * (width - filled)
def _fmt(label: str, value: float, hint: str) -> str:
    """Build the Markdown snippet for one metric.

    The snippet is the bold label, then a code span holding the bar and the
    score to two decimals out of 1.0, then the hint in italics.
    """
    gauge = _bar(value)
    score = f"{value:.2f}"
    return f"**{label}** \n`{gauge} {score} / 1.0` \n_{hint}_\n"
def _faith_hint(score: float, source: str) -> str:
what = "web results" if source == "web" else "retrieved documents"
if score >= 0.75:
return f"Answer is well-grounded in the {what}"
if score >= 0.50:
return f"Answer is mostly grounded in the {what}, minor unsupported details possible"
return f"Low grounding: answer may contain content not present in the {what}"
def _relevance_hint(score: float) -> str:
if score >= 0.70:
return "Answer directly addresses the question"
if score >= 0.45:
return "Answer is related but may not fully address the question"
return "Answer may be off-topic or incomplete"
def _accuracy_hint(score: float) -> str:
if score >= 0.75:
return "Strong match with your reference answer"
if score >= 0.40:
return "Partial match: some key points differ from your reference"
return "Low overlap with your reference answer"
def _format_metrics(source: str, faithfulness, answer_relevance, accuracy) -> str:
if faithfulness is None and answer_relevance is None:
return ""
src_label = SOURCE_LABELS.get(source, source)
lines = [f"**Answer source: {src_label}**", "---"]
if faithfulness is not None:
faith_label = "Faithfulness (grounded in documents)" if source != "web" else "Faithfulness (grounded in web results)"
lines.append(_fmt(faith_label, faithfulness, _faith_hint(faithfulness, source)))
if answer_relevance is not None:
lines.append(_fmt("Answer Relevance (addresses the question)", answer_relevance, _relevance_hint(answer_relevance)))
if accuracy is not None:
lines.append(_fmt("Accuracy (matches your reference)", accuracy, _accuracy_hint(accuracy)))
else:
lines.append("_Accuracy: provide a reference answer above to see this score._")
return "\n".join(lines)
def process_query(message: str, api_key: str, reference: str, session_id: str, chat_history: list) -> tuple[list, str, str]:
    """Send one question to the backend chat endpoint and update the chat.

    Args:
        message: The user's question.
        api_key: The user's API key; sent to the backend with the request.
        reference: Optional reference answer; sent only when non-blank.
        session_id: Identifier forwarded to the backend as ``session_id``.
        chat_history: Existing conversation as a list of
            ``{"role": ..., "content": ...}`` dicts. It is appended to in
            place; ``None`` is treated as an empty conversation.

    Returns:
        A ``(chat_history, status, metrics_markdown)`` tuple. On validation
        or request failure the metrics string is empty and the status carries
        the reason; this function does not raise for requests errors or for
        an undecodable JSON body.
    """
    # Defensive normalisation: a None from any input would otherwise raise
    # AttributeError on .strip() / .append() below.
    message = message or ""
    api_key = (api_key or "").strip()
    reference = (reference or "").strip()
    history = chat_history if chat_history is not None else []

    if not api_key:
        return history, "Enter your Google Gemini API key above to start.", ""
    if not message.strip():
        return history, "Please enter a question.", ""

    def _record(reply: str) -> None:
        """Append the user's message and the assistant's reply to the chat."""
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": reply})

    payload = {"query": message, "api_key": api_key, "session_id": session_id}
    if reference:
        payload["reference"] = reference

    try:
        resp = requests.post(CHAT_ENDPOINT, json=payload, timeout=180)
        if resp.status_code in (400, 401, 429):
            # These statuses carry a user-facing explanation in "detail".
            # str() keeps the status box valid if the backend sends a
            # structured (non-string) detail.
            detail = str(resp.json().get("detail", "Request could not be completed."))
            _record(f"Sorry, I can't respond right now:\n\n{detail}")
            return history, detail, ""
        resp.raise_for_status()
        data = resp.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on requests versions whose JSON
        # decode error is not a RequestException subclass.
        error = f"Request failed: {e}"
        _record(f"ERROR: {error}")
        return history, error, ""

    response_text = data.get("response", "")
    source = data.get("source", "unknown")
    citations = data.get("citations", [])
    src_label = SOURCE_LABELS.get(source, source)

    footer = f"\n\n--- Source: {src_label} ---"
    if citations:
        # str() so a non-string citation entry cannot break the join.
        footer += "\n📎 Sources used: " + "; ".join(str(c) for c in citations)
    _record(f"{response_text}{footer}")

    metrics_md = _format_metrics(
        source,
        data.get("faithfulness"),
        data.get("answer_relevance"),
        data.get("accuracy"),
    )
    return history, f"Done. Answered via {src_label}", metrics_md
def clear_chat(session_id: str) -> tuple[list, str, str]:
    """Clear the UI and the server-side conversation memory for this session.

    The UI is always cleared. Clearing the server-side memory stays
    best-effort: a failed request or an HTTP error status never raises, but
    it is reported in the status text instead of being silently ignored, so
    the user knows the backend may still hold the earlier conversation.

    Args:
        session_id: Identifier sent to the backend as the ``session_id``
            query parameter.

    Returns:
        ``(empty chat history, status text, empty metrics markdown)``. The
        status text is empty when the backend call succeeded.
    """
    status = ""
    try:
        resp = requests.post(CLEAR_MEMORY_ENDPOINT, params={"session_id": session_id}, timeout=10)
        # A 4xx/5xx reply also means the memory was not cleared.
        resp.raise_for_status()
    except requests.exceptions.RequestException as e:
        status = f"Chat cleared, but server-side memory could not be reset: {e}"
    return [], status, ""
# Gradio UI definition: builds the page (API key box, upload section, chat
# section, metrics panel) and wires the buttons to the handler functions above.
# NOTE(review): leading indentation is missing in this copy of the file, so the
# nesting of components inside the gr.Blocks / gr.Group / gr.Row contexts
# cannot be read from the text — restore it from the original before running.
with gr.Blocks(title="Agentic RAG Knowledge Search") as demo:
session_id = gr.State() # per-browser id for server-side memory
gr.Markdown("# Agentic RAG Knowledge Search")
gr.Markdown(
"Upload your own documents and ask questions. "
"The agent searches your files via RAG first, then falls back to web search if needed. "
"It remembers the conversation, so follow-up questions work. "
"**No files uploaded?** The agent uses a built-in legal/policy document as the default knowledge base."
)
# --- API key section: password-type textbox whose value is passed to process_query.
with gr.Group():
gr.Markdown(
"### Your Gemini API Key (required)\n"
"This app uses **your own** Google Gemini key; it is sent only with your requests and never stored. "
"Get a free key at [Google AI Studio](https://aistudio.google.com/apikey)."
)
api_key_input = gr.Textbox(
label="Google Gemini API Key",
placeholder="Paste your API key here (starts with AIza...)",
type="password",
lines=1,
)
# --- Upload section: multi-file picker limited to SUPPORTED_TYPES, plus the
# index/reset buttons and a read-only status box they both write to.
with gr.Group():
gr.Markdown(
f"### Upload Documents\n"
f"Supported: **{', '.join(SUPPORTED_TYPES)}**. Multiple files allowed; new uploads add to the index. \n"
"_Large PDFs (100s of pages) can take a few minutes to index on the free CPU, so please be patient._"
)
file_input = gr.File(label="Select Files", file_count="multiple", file_types=SUPPORTED_TYPES)
with gr.Row():
upload_btn = gr.Button("Process & Index Files", variant="primary")
reset_btn = gr.Button("Clear Uploaded Documents", variant="secondary")
upload_status = gr.Textbox(label="Upload Status", interactive=False, lines=2,
placeholder="Upload status will appear here...")
gr.Markdown("---")
# --- Chat section: conversation view, question box + submit button (styled via
# the "input-row" / "submit-btn" elem_ids referenced in CSS), optional reference
# answer, clear button and a read-only status line.
with gr.Group():
gr.Markdown(
"### Ask a Question\n"
"> Press **Enter** or click **Submit**. The agent decides whether to search your documents, the web, or both."
)
chatbot = gr.Chatbot(label="Conversation", height=420)
with gr.Row(elem_id="input-row"):
user_input = gr.Textbox(
placeholder="Ask anything about your documents or the web...",
lines=2, scale=5, show_label=False, container=False,
)
submit_btn = gr.Button("Submit", variant="primary", scale=1, elem_id="submit-btn")
reference_input = gr.Textbox(
label="Reference Answer (optional)",
placeholder="Paste an expected correct answer here to also see an Accuracy score...",
lines=2,
)
with gr.Row():
clear_btn = gr.Button("Clear Chat", variant="secondary", scale=1)
status_output = gr.Textbox(label="Status", interactive=False, lines=1,
placeholder="Status will appear here...")
gr.Markdown("---")
# --- Metrics section: static explanation table plus the Markdown panel that
# receives the third value returned by process_query / clear_chat.
with gr.Group():
gr.Markdown(
"### Evaluation Metrics\n"
"Computed automatically after every response, with **no extra API calls and no reference needed** for the first two.\n\n"
"| Metric | Always shown? | What it measures |\n"
"|---|---|---|\n"
"| **Faithfulness** | Yes | Is the answer grounded in the source it used? (docs or web results) |\n"
"| **Answer Relevance** | Yes | Does the answer actually address your question? |\n"
"| **Accuracy** | Only with reference | Does the answer match an expected correct answer? |"
)
metrics_output = gr.Markdown()
# --- Event wiring.
# The load event stores a freshly generated uuid4 string in the session_id state.
demo.load(fn=lambda: str(uuid.uuid4()), inputs=[], outputs=[session_id])
upload_btn.click(fn=upload_files, inputs=[file_input], outputs=[upload_status])
reset_btn.click(fn=reset_documents, inputs=[], outputs=[upload_status])
# Clicking Submit and submitting the textbox are wired identically: run
# process_query, then set the question box to an empty string.
submit_btn.click(
fn=process_query,
inputs=[user_input, api_key_input, reference_input, session_id, chatbot],
outputs=[chatbot, status_output, metrics_output],
).then(fn=lambda: "", inputs=[], outputs=[user_input])
user_input.submit(
fn=process_query,
inputs=[user_input, api_key_input, reference_input, session_id, chatbot],
outputs=[chatbot, status_output, metrics_output],
).then(fn=lambda: "", inputs=[], outputs=[user_input])
# clear_chat resets the chat view, status line and metrics panel in one call.
clear_btn.click(fn=clear_chat, inputs=[session_id], outputs=[chatbot, status_output, metrics_output])
if __name__ == "__main__":
    # Runs only when the file is executed as a script, not when it is imported.
    # Serves on all interfaces at port 7860 with share=False, using the Soft
    # theme and the module-level CSS.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "theme": gr.themes.Soft(),
        "css": CSS,
    }
    demo.launch(**launch_options)