# Provenance (captured web-UI artifact, converted to a comment so the module
# parses): uploaded by ZhangNy, commit 75db650 ("Add Space app files").
"""Gradio UI for the Radiology RAG Space."""
from __future__ import annotations

import html
import logging
import os
import re
import time
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple

import gradio as gr

from radiology_rag.config import Config
from radiology_rag.index_bootstrap import ensure_index, read_manifest
from radiology_rag.rag import RAGEngine
logger = logging.getLogger(__name__)
def _truncate(text: str, max_len: int) -> str:
s = (text or "").strip()
if len(s) <= max_len:
return s
return s[: max(0, max_len - 3)] + "..."
def format_error_message(error: str) -> str:
    """Wrap *error* in a bold warning header suitable for Markdown display."""
    return "**⚠️ Error**\n\n{}".format(error)
def format_loading_message() -> str:
    """Return the static Markdown shown while a query is being processed."""
    header = "**🔄 Processing your query...**"
    detail = "Retrieving relevant sources and generating an answer with citations."
    return f"{header}\n\n{detail}"
def format_reference_card(doc: Dict[str, Any], index: int) -> str:
    """Render one retrieved document as a self-contained HTML card.

    Args:
        doc: Retrieved document; optional keys ``title``, ``source_type``,
            ``url``, ``content`` and ``score``.
        index: 1-based citation number; also used as the ``ref-N`` anchor id
            that in-answer citation links scroll to.

    Returns:
        An HTML fragment. All document-derived text (title, type, preview,
        URL) is HTML-escaped: the corpus content is external input and must
        not be injected into the page verbatim — unescaped ``<``/``&``/quotes
        would break the markup and allow HTML injection.
    """
    title = doc.get("title", "Untitled") or "Untitled"
    source_type = (doc.get("source_type") or "").upper()
    url = doc.get("url", "") or ""
    content = doc.get("content", "") or ""
    score = float(doc.get("score", 0.0) or 0.0)

    # Truncate the preview, then flatten newlines so it renders as one paragraph.
    max_preview_length = 350
    preview = content.strip()
    if len(preview) > max_preview_length:
        preview = preview[: max_preview_length - 3] + "..."
    preview = preview.replace("\n", " ")

    type_colors = {
        "ARTICLE": "#3b82f6",
        "CASE": "#10b981",
        "TUTORIAL": "#f59e0b",
        "ENCYCLOPEDIA": "#8b5cf6",
    }
    # Unknown source types fall back to a neutral gray badge.
    color = type_colors.get(source_type, "#6b7280")

    # Escape everything document-derived before interpolation into HTML.
    safe_title = html.escape(title)
    safe_type = html.escape(source_type or "SOURCE")
    safe_preview = html.escape(preview)
    safe_url = html.escape(url, quote=True)

    # Score badge only when a positive relevance score is available.
    score_html = f"<span style='color:#6b7280;font-size:12px;'>Score: {score:.3f}</span>" if score > 0 else ""
    url_html = (
        f"<p style='margin:0 0 8px 0;font-size:12px;'><a href='{safe_url}' target='_blank' "
        f"style='color:#3b82f6;text-decoration:none;'>🔗 View Source</a></p>"
        if url
        else ""
    )
    return f"""
<div id="ref-{index}" style="border:1px solid #e5e7eb;border-radius:8px;padding:16px;margin-bottom:16px;background:white;scroll-margin-top:90px;">
<div style="display:flex;align-items:center;gap:8px;margin-bottom:12px;flex-wrap:wrap;">
<span style="background:{color};color:white;padding:4px 12px;border-radius:12px;font-size:12px;font-weight:600;">
{safe_type}
</span>
<span style="background:#f3f4f6;color:#374151;padding:4px 12px;border-radius:12px;font-size:12px;font-weight:600;">
[{index}]
</span>
{score_html}
</div>
<h3 style="margin:0 0 8px 0;color:#111827;font-size:18px;">{safe_title}</h3>
{url_html}
<p style="margin:0;color:#4b5563;font-size:14px;line-height:1.5;">{safe_preview}</p>
</div>
"""
def format_reference_panel(references: List[Dict[str, Any]]) -> str:
    """Render all retrieved documents as a scrollable list of cards.

    Returns a centered placeholder paragraph when *references* is empty.
    """
    if not references:
        return "<p style='color:#6b7280;text-align:center;padding:20px;'>No references available</p>"
    cards = [format_reference_card(doc, pos) for pos, doc in enumerate(references, start=1)]
    return '<div style="max-height: 600px; overflow-y: auto;">' + "".join(cards) + "</div>"
def format_statistics(metadata: Dict[str, Any]) -> str:
    """Render query metadata (counts, timing, strategy, source mix) as HTML."""
    retrieved_count = int(metadata.get("num_retrieved", 0) or 0)
    reranked_count = int(metadata.get("num_reranked", 0) or 0)
    distribution = metadata.get("source_type_distribution", {}) or {}
    label_retrieved = metadata.get("retrieved_label", "Retrieved")
    label_reranked = metadata.get("reranked_label", "After Reranking")
    seconds = float(metadata.get("elapsed_time", 0.0) or 0.0)
    strategy_name = metadata.get("retrieval_strategy", "")

    # One pill-style chip per source type; fall back to "N/A" when empty.
    chip_template = (
        "<span style='display:inline-block;background:#e5e7eb;color:#111827;"
        "padding:4px 8px;border-radius:4px;margin-right:8px;font-size:12px;"
        "line-height:1.2;'>{}: {}</span>"
    )
    chips = "".join(chip_template.format(key, count) for key, count in distribution.items())
    chips_block = chips if chips else "<span style='color:#6b7280;font-size:12px;'>N/A</span>"

    return f"""
<div style="background:#f9fafb;padding:16px;border-radius:8px;margin-top:16px;">
<h4 style="margin:0 0 12px 0;color:#374151;font-size:14px;">📊 Query Statistics</h4>
<div style="display:grid;grid-template-columns:repeat(auto-fit,minmax(150px,1fr));gap:12px;">
<div>
<p style="margin:0;color:#6b7280;font-size:12px;">{label_retrieved}</p>
<p style="margin:0;color:#111827;font-size:20px;font-weight:600;">{retrieved_count}</p>
</div>
<div>
<p style="margin:0;color:#6b7280;font-size:12px;">{label_reranked}</p>
<p style="margin:0;color:#111827;font-size:20px;font-weight:600;">{reranked_count}</p>
</div>
<div>
<p style="margin:0;color:#6b7280;font-size:12px;">Elapsed</p>
<p style="margin:0;color:#111827;font-size:20px;font-weight:600;">{seconds:.2f}s</p>
</div>
</div>
<div style="margin-top:12px;">
<p style="margin:0 0 6px 0;color:#6b7280;font-size:12px;">Retrieval Strategy: <code>{strategy_name}</code></p>
<p style="margin:0 0 4px 0;color:#6b7280;font-size:12px;">Source Distribution:</p>
{chips_block}
</div>
</div>
"""
def create_settings_accordion(
    *,
    default_strategy: str,
    default_temperature: float,
    default_sources: List[str],
) -> Tuple[gr.Radio, gr.Slider, gr.CheckboxGroup]:
    """Build the collapsed "Advanced Settings" accordion.

    Returns the three interactive controls (strategy radio, temperature
    slider, source-type filter) so the caller can wire them into events.
    """
    with gr.Accordion("⚙️ Advanced Settings", open=False):
        gr.Markdown(
            "#### Retrieval Strategy\n"
            "- **default**: one mixed retrieval + single rerank (fast)\n"
            "- **balanced_multi_source**: per-source recall + per-source rerank + Wikipedia (more diverse)\n"
        )
        strategy_radio = gr.Radio(
            choices=["default", "balanced_multi_source"],
            value=default_strategy,
            label="Retrieval Strategy",
        )
        temp_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=float(default_temperature),
            step=0.1,
            label="LLM Temperature",
        )
        source_checkboxes = gr.CheckboxGroup(
            choices=["article", "case", "tutorial", "encyclopedia"],
            value=default_sources,
            label="Filter by Source Type",
        )
    return strategy_radio, temp_slider, source_checkboxes
class RadiologyRAGApp:
    """Application state for the Space: config validation, index bootstrap,
    RAG engine construction, and the Gradio UI.

    Startup failures are captured in ``startup_error`` (and rendered by the
    UI) instead of raised, so the Space still boots with actionable guidance.
    """

    def __init__(self, config_path: str) -> None:
        """Load config and eagerly initialize all dependencies.

        Order: validate secrets -> warn on optional reranker key ->
        download/verify the prebuilt index -> build the RAG engine. Each
        stage that fails sets ``self.startup_error`` and returns early,
        leaving ``self.rag_engine`` as None.
        """
        self.config = Config(config_path)
        # Set on any startup failure; checked by process_query/create_interface.
        self.startup_error: Optional[str] = None
        # Non-fatal issues surfaced in the UI (e.g. model mismatch).
        self.startup_warnings: List[str] = []
        self.index_manifest: Optional[Dict[str, Any]] = None
        self.rag_engine: Optional[RAGEngine] = None
        # Validate required secrets before doing any expensive work.
        missing: List[str] = []
        if not self.config.get_str("embedding.api_key"):
            missing.append("EMBED_API_KEY")
        if not self.config.get_str("llm.api_key"):
            missing.append("LLM_API_KEY")
        if missing:
            self.startup_error = (
                "Missing required Hugging Face Space Secrets: "
                + ", ".join([f"`{m}`" for m in missing])
                + ".\n\nPlease set them in the Space **Settings → Secrets** and restart the Space."
            )
            return
        # Reranker is optional; warn (don't fail) if enabled but key missing.
        if self.config.get_bool("reranker.enabled", True) and not self.config.get_str("reranker.api_key"):
            self.startup_warnings.append(
                "Reranker is enabled but `RERANK_API_KEY` is missing. Reranking will be disabled (fallback to no-op)."
            )
        # Ensure the prebuilt index exists locally (downloads from HF Hub if needed).
        try:
            idx = ensure_index(
                repo_id=self.config.get_str("index.repo_id"),
                revision=self.config.get_str("index.revision", "main") or None,
                target_vector_db_path=self.config.get_str("storage.vector_db_path"),
                target_doc_store_path=self.config.get_str("storage.doc_store_path"),
                storage_dir=str(Path(self.config.get_str("storage.doc_store_path")).parent),
            )
            self.index_manifest = read_manifest(idx.manifest_path)
            # Optional sanity check: warn if the index was built with a
            # different embedding model than the one configured here.
            if self.index_manifest:
                idx_model = (
                    (self.index_manifest.get("embedding") or {}).get("model_name")
                    or self.index_manifest.get("embedding_model")
                    or ""
                )
                cfg_model = self.config.get_str("embedding.model_name")
                if idx_model and cfg_model and idx_model != cfg_model:
                    self.startup_warnings.append(
                        f"Index embedding model mismatch: index='{idx_model}' vs config='{cfg_model}'. "
                        "For best results, rebuild the index with the same embedding model."
                    )
        except Exception as e:
            # Try to provide actionable guidance for common HF Hub errors.
            repo_id = self.config.get_str("index.repo_id")
            try:
                # Imported lazily: these exception classes may be absent in
                # some huggingface_hub versions; fall back to generic handling.
                from huggingface_hub.utils import (  # type: ignore
                    GatedRepoError,
                    HfHubHTTPError,
                    RepositoryNotFoundError,
                )
                if isinstance(e, RepositoryNotFoundError):
                    self.startup_error = (
                        f"Index dataset repo not found: `{repo_id}`.\n\n"
                        "If you haven't uploaded the prebuilt index yet, build and publish it locally:\n"
                        "1) `pip install -r requirements-dev.txt`\n"
                        "2) `python scripts/build_vector_db.py --config config/default_config.yaml --source huggingface --dataset ZhangNy/radiology-dataset --output-dir ./index_out`\n"
                        f"3) `python scripts/publish_index_to_hf.py --repo {repo_id} --folder ./index_out --token $HF_TOKEN`\n\n"
                        "Or set `RAG_INDEX_REPO_ID` to an existing index repo."
                    )
                    return
                if isinstance(e, GatedRepoError):
                    self.startup_error = (
                        f"Index dataset repo is gated/private: `{repo_id}`.\n\n"
                        "Make sure the repo is public, or provide authentication (HF token) in the environment."
                    )
                    return
                if isinstance(e, HfHubHTTPError):
                    self.startup_error = (
                        f"Failed to download index from `{repo_id}`.\n\n"
                        f"HF Hub error: {e}"
                    )
                    return
            except Exception:
                # If importing HF-specific exceptions fails, fall back to generic message.
                pass
            self.startup_error = (
                f"Failed to prepare index from `{repo_id}`.\n\n"
                f"Error: {e}"
            )
            return
        # Build the RAG engine last, once secrets and index are in place.
        try:
            self.rag_engine = RAGEngine(self.config)
        except Exception as e:
            self.startup_error = f"Failed to initialize RAG engine: {e}"
            return

    def process_query(
        self,
        question: str,
        temperature: float,
        source_filters: List[str],
        retrieval_strategy: str,
    ) -> Iterator[Tuple[str, str, str]]:
        """Stream UI updates for one query.

        Yields ``(answer_markdown, references_html, stats_html)`` tuples:
        first a loading state, then incremental answers as the LLM streams,
        and finally the full answer with reference cards and statistics.
        Errors are yielded as formatted messages, never raised.
        """
        # Guard clauses: bad startup state or empty input short-circuits.
        if self.startup_error:
            yield format_error_message(self.startup_error), "", ""
            return
        if self.rag_engine is None:
            yield format_error_message("RAG engine not initialized."), "", ""
            return
        q = (question or "").strip()
        if not q:
            yield format_error_message("Please enter a question."), "", ""
            return
        # Apply the UI-selected temperature on the fly; best-effort only.
        try:
            self.rag_engine.llm.temperature = float(temperature)
        except Exception:
            pass
        sources = source_filters or []
        loading_md = (
            f"{format_loading_message()}\n\n"
            f"**Retrieval Strategy**: `{retrieval_strategy}`\n\n"
            f"**Sources**: `{', '.join(sources) if sources else 'ALL'}`"
        )
        loading_refs = "<p style='color:#6b7280;text-align:center;padding:20px;'>Retrieving & reranking...</p>"
        loading_stats = "<p style='color:#6b7280;padding:10px;'>Working...</p>"
        yield loading_md, loading_refs, loading_stats
        start_time = time.time()
        # Tracks the last streamed partial to avoid re-yielding identical text.
        last_partial = ""
        try:
            for event in self.rag_engine.query_stream(
                question=q,
                source_filters=sources if sources else None,
                retrieval_strategy=retrieval_strategy,
            ):
                etype = (event or {}).get("type")
                if etype == "answer":
                    # Incremental answer chunk from the LLM stream.
                    partial = (event.get("answer") or "")
                    if partial and partial != last_partial:
                        # Make citations clickable: [1] -> [1](#ref-1); the
                        # (?!\() guard skips citations already linkified.
                        answer_md = re.sub(r"\[(\d+)\](?!\()", r"[\1](#ref-\1)", partial)
                        last_partial = partial
                        yield answer_md, loading_refs, loading_stats
                elif etype == "final":
                    # Terminal event: full answer plus references and metadata.
                    meta = event.get("metadata") or {}
                    # If engine didn't populate elapsed_time (it does), we fill it.
                    meta.setdefault("elapsed_time", time.time() - start_time)
                    final_answer = (event.get("answer") or "")
                    answer_md = re.sub(r"\[(\d+)\](?!\()", r"[\1](#ref-\1)", final_answer)
                    references_html = format_reference_panel(event.get("references") or [])
                    stats_html = format_statistics(meta)
                    yield answer_md, references_html, stats_html
                    return
            # Stream ended without a "final" event.
            yield format_error_message("No response was generated. Please try again."), "", ""
        except Exception as e:
            logger.error(f"Error processing query: {e}", exc_info=True)
            yield format_error_message(f"An error occurred: {e}"), "", ""

    def create_interface(self) -> gr.Blocks:
        """Build and return the Gradio Blocks UI.

        If startup failed, renders only the error plus setup instructions;
        otherwise renders the full query form, outputs, and examples.
        """
        title = self.config.get_str("ui.title", "Radiology RAG")
        description = self.config.get_str("ui.description", "")
        theme = self.config.get_str("ui.theme", "soft")
        default_strategy = self.config.get_str("retrieval.strategy", "balanced_multi_source")
        default_sources = self.config.get("retrieval.source_filters", ["article", "case", "tutorial", "encyclopedia"])
        # Defend against malformed config: fall back to the full source list.
        if not isinstance(default_sources, list):
            default_sources = ["article", "case", "tutorial", "encyclopedia"]
        default_temp = self.config.get_float("llm.temperature", 0.7)
        with gr.Blocks(title=title, theme=theme) as interface:
            gr.Markdown(f"# {title}")
            if description:
                gr.Markdown(description)
            # Degraded mode: show the startup error and required secrets only.
            if self.startup_error:
                gr.Markdown(format_error_message(self.startup_error))
                gr.Markdown(
                    "### Required Secrets\n"
                    "- `EMBED_API_KEY`\n"
                    "- `LLM_API_KEY`\n\n"
                    "Optional (recommended):\n"
                    "- `RERANK_API_KEY`\n"
                )
                return interface
            if self.startup_warnings:
                gr.Markdown("### ⚠️ Startup Warnings")
                gr.Markdown("\n".join([f"- {w}" for w in self.startup_warnings]))
            with gr.Row():
                # Left column: question input, settings, examples.
                with gr.Column(scale=1):
                    gr.Markdown("### Ask a Question")
                    question_input = gr.Textbox(
                        label="Your Question",
                        placeholder="e.g., What is achalasia and how is it diagnosed?",
                        lines=3,
                    )
                    retrieval_strategy, temperature_slider, source_filter = create_settings_accordion(
                        default_strategy=default_strategy,
                        default_temperature=default_temp,
                        default_sources=default_sources,
                    )
                    submit_btn = gr.Button("Search & Answer", variant="primary", size="lg")
                    gr.Markdown("### Example Questions")
                    gr.Examples(
                        examples=[
                            ["What is achalasia and how is it diagnosed on imaging?"],
                            ["Explain the imaging findings in Barrett's esophagus"],
                            ["What are the characteristics of a Zenker's diverticulum?"],
                            ["Describe the CT findings of esophageal cancer"],
                        ],
                        inputs=[question_input],
                        label="Click an example to try it",
                    )
                # Right column (wider): streamed answer, stats, references.
                with gr.Column(scale=2):
                    gr.Markdown("### Answer (with citations)")
                    answer_output = gr.Markdown(value="*Your answer will appear here...*")
                    stats_output = gr.HTML(label="Statistics")
                    gr.Markdown("### Retrieved References")
                    references_output = gr.HTML(
                        value="<p style='color:#6b7280;text-align:center;padding:20px;'>References will appear here...</p>"
                    )
            # process_query is a generator, so outputs stream progressively.
            submit_btn.click(
                fn=self.process_query,
                inputs=[question_input, temperature_slider, source_filter, retrieval_strategy],
                outputs=[answer_output, references_output, stats_output],
            )
            gr.Markdown("---")
            with gr.Accordion("About", open=False):
                gr.Markdown(
                    "This Space demonstrates a radiology RAG system using a prebuilt vector index "
                    f"(`{self.config.get_str('index.repo_id')}`) and external APIs for embeddings/LLM.\n\n"
                    "**Disclaimer**: Educational use only. Always consult qualified professionals for clinical decisions."
                )
        return interface