Spaces:
Running on Zero
| """Gradio UI for the Radiology RAG Space.""" | |
from __future__ import annotations

import html
import logging
import os
import re
import time
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple

import gradio as gr

from radiology_rag.config import Config
from radiology_rag.index_bootstrap import ensure_index, read_manifest
from radiology_rag.rag import RAGEngine
| logger = logging.getLogger(__name__) | |
| def _truncate(text: str, max_len: int) -> str: | |
| s = (text or "").strip() | |
| if len(s) <= max_len: | |
| return s | |
| return s[: max(0, max_len - 3)] + "..." | |
def format_error_message(error: str) -> str:
    """Wrap *error* in a bold Markdown alert header."""
    header = "**⚠️ Error**"
    return header + "\n\n" + error
def format_loading_message() -> str:
    """Return the Markdown placeholder shown while a query is in flight."""
    parts = (
        "**🔄 Processing your query...**",
        "",
        "Retrieving relevant sources and generating an answer with citations.",
    )
    return "\n".join(parts)
def format_reference_card(doc: Dict[str, Any], index: int) -> str:
    """Render one retrieved document as an HTML reference card.

    Args:
        doc: Retrieved document fields (``title``, ``source_type``, ``url``,
            ``content``, ``score``).
        index: 1-based citation number; the card carries ``id="ref-{index}"``
            so in-answer citations ``[n](#ref-n)`` can scroll to it.

    Returns:
        An HTML fragment for the reference panel.

    Fix: title, URL and content preview come from retrieved documents and were
    interpolated into HTML unescaped — any ``<``/``>``/quote in them would
    break the markup or inject HTML. They are now escaped via ``html.escape``
    (quotes included for the ``href`` attribute).
    """
    title = html.escape(doc.get("title", "Untitled") or "Untitled")
    source_type = (doc.get("source_type") or "").upper()
    # quote=True also escapes quotes, which is required inside href='...'.
    url = html.escape(doc.get("url", "") or "", quote=True)
    content = doc.get("content", "") or ""
    score = float(doc.get("score", 0.0) or 0.0)
    max_preview_length = 350
    preview = html.escape(_truncate(content, max_preview_length).replace("\n", " "))
    # Badge color per source type; gray fallback for unknown types.
    type_colors = {
        "ARTICLE": "#3b82f6",
        "CASE": "#10b981",
        "TUTORIAL": "#f59e0b",
        "ENCYCLOPEDIA": "#8b5cf6",
    }
    color = type_colors.get(source_type, "#6b7280")
    score_html = f"<span style='color:#6b7280;font-size:12px;'>Score: {score:.3f}</span>" if score > 0 else ""
    url_html = (
        f"<p style='margin:0 0 8px 0;font-size:12px;'><a href='{url}' target='_blank' "
        f"style='color:#3b82f6;text-decoration:none;'>🔗 View Source</a></p>"
        if url
        else ""
    )
    return f"""
<div id="ref-{index}" style="border:1px solid #e5e7eb;border-radius:8px;padding:16px;margin-bottom:16px;background:white;scroll-margin-top:90px;">
<div style="display:flex;align-items:center;gap:8px;margin-bottom:12px;flex-wrap:wrap;">
<span style="background:{color};color:white;padding:4px 12px;border-radius:12px;font-size:12px;font-weight:600;">
{source_type or "SOURCE"}
</span>
<span style="background:#f3f4f6;color:#374151;padding:4px 12px;border-radius:12px;font-size:12px;font-weight:600;">
[{index}]
</span>
{score_html}
</div>
<h3 style="margin:0 0 8px 0;color:#111827;font-size:18px;">{title}</h3>
{url_html}
<p style="margin:0;color:#4b5563;font-size:14px;line-height:1.5;">{preview}</p>
</div>
"""
def format_reference_panel(references: List[Dict[str, Any]]) -> str:
    """Assemble all reference cards into a single scrollable HTML panel."""
    if not references:
        return "<p style='color:#6b7280;text-align:center;padding:20px;'>No references available</p>"
    cards = [format_reference_card(doc, pos) for pos, doc in enumerate(references, start=1)]
    return '<div style="max-height: 600px; overflow-y: auto;">' + "".join(cards) + "</div>"
def format_statistics(metadata: Dict[str, Any]) -> str:
    """Render retrieval/runtime statistics from *metadata* as an HTML card."""
    retrieved_count = int(metadata.get("num_retrieved", 0) or 0)
    reranked_count = int(metadata.get("num_reranked", 0) or 0)
    distribution = metadata.get("source_type_distribution", {}) or {}
    retrieved_caption = metadata.get("retrieved_label", "Retrieved")
    reranked_caption = metadata.get("reranked_label", "After Reranking")
    seconds = float(metadata.get("elapsed_time", 0.0) or 0.0)
    strategy_name = metadata.get("retrieval_strategy", "")
    # One chip per source type, e.g. "article: 3".
    chip_template = (
        "<span style='display:inline-block;background:#e5e7eb;color:#111827;"
        "padding:4px 8px;border-radius:4px;margin-right:8px;font-size:12px;"
        "line-height:1.2;'>{}: {}</span>"
    )
    chips_html = "".join(chip_template.format(k, v) for k, v in distribution.items())
    dist_html = chips_html or "<span style='color:#6b7280;font-size:12px;'>N/A</span>"
    return f"""
<div style="background:#f9fafb;padding:16px;border-radius:8px;margin-top:16px;">
<h4 style="margin:0 0 12px 0;color:#374151;font-size:14px;">📊 Query Statistics</h4>
<div style="display:grid;grid-template-columns:repeat(auto-fit,minmax(150px,1fr));gap:12px;">
<div>
<p style="margin:0;color:#6b7280;font-size:12px;">{retrieved_caption}</p>
<p style="margin:0;color:#111827;font-size:20px;font-weight:600;">{retrieved_count}</p>
</div>
<div>
<p style="margin:0;color:#6b7280;font-size:12px;">{reranked_caption}</p>
<p style="margin:0;color:#111827;font-size:20px;font-weight:600;">{reranked_count}</p>
</div>
<div>
<p style="margin:0;color:#6b7280;font-size:12px;">Elapsed</p>
<p style="margin:0;color:#111827;font-size:20px;font-weight:600;">{seconds:.2f}s</p>
</div>
</div>
<div style="margin-top:12px;">
<p style="margin:0 0 6px 0;color:#6b7280;font-size:12px;">Retrieval Strategy: <code>{strategy_name}</code></p>
<p style="margin:0 0 4px 0;color:#6b7280;font-size:12px;">Source Distribution:</p>
{dist_html}
</div>
</div>
"""
def create_settings_accordion(
    *,
    default_strategy: str,
    default_temperature: float,
    default_sources: List[str],
) -> Tuple[gr.Radio, gr.Slider, gr.CheckboxGroup]:
    """Build the "Advanced Settings" accordion and return its three inputs.

    Returns the (strategy radio, temperature slider, source-type checkbox
    group) widgets so the caller can wire them into the query handler.
    """
    with gr.Accordion("⚙️ Advanced Settings", open=False):
        gr.Markdown(
            "#### Retrieval Strategy\n"
            "- **default**: one mixed retrieval + single rerank (fast)\n"
            "- **balanced_multi_source**: per-source recall + per-source rerank + Wikipedia (more diverse)\n"
        )
        strategy_radio = gr.Radio(
            label="Retrieval Strategy",
            choices=["default", "balanced_multi_source"],
            value=default_strategy,
        )
        temp_slider = gr.Slider(
            label="LLM Temperature",
            minimum=0.0,
            maximum=1.0,
            step=0.1,
            value=float(default_temperature),
        )
        sources_group = gr.CheckboxGroup(
            label="Filter by Source Type",
            choices=["article", "case", "tutorial", "encyclopedia"],
            value=default_sources,
        )
    return strategy_radio, temp_slider, sources_group
class RadiologyRAGApp:
    """Top-level application object for the Radiology RAG Gradio Space.

    Construction validates required API-key secrets, downloads/verifies the
    prebuilt vector index from the Hugging Face Hub, and initializes the RAG
    engine. Setup failures never raise out of ``__init__``; they are stored
    in ``self.startup_error`` (and non-fatal issues in
    ``self.startup_warnings``) so the UI can render actionable guidance
    instead of crashing the Space.
    """

    def __init__(self, config_path: str):
        """Load config from *config_path* and prepare the index + engine.

        Any failure short-circuits with ``self.startup_error`` set; no
        exception escapes this constructor.
        """
        self.config = Config(config_path)
        # Human-readable fatal setup failure, rendered by create_interface().
        self.startup_error: Optional[str] = None
        # Non-fatal issues worth surfacing to the user.
        self.startup_warnings: List[str] = []
        self.index_manifest: Optional[Dict[str, Any]] = None
        self.rag_engine: Optional[RAGEngine] = None
        # Validate required secrets
        missing: List[str] = []
        if not self.config.get_str("embedding.api_key"):
            missing.append("EMBED_API_KEY")
        if not self.config.get_str("llm.api_key"):
            missing.append("LLM_API_KEY")
        if missing:
            self.startup_error = (
                "Missing required Hugging Face Space Secrets: "
                + ", ".join([f"`{m}`" for m in missing])
                + ".\n\nPlease set them in the Space **Settings → Secrets** and restart the Space."
            )
            return
        # Reranker is optional; warn if enabled but missing key
        if self.config.get_bool("reranker.enabled", True) and not self.config.get_str("reranker.api_key"):
            self.startup_warnings.append(
                "Reranker is enabled but `RERANK_API_KEY` is missing. Reranking will be disabled (fallback to no-op)."
            )
        # Ensure index exists locally (download if needed)
        try:
            idx = ensure_index(
                repo_id=self.config.get_str("index.repo_id"),
                revision=self.config.get_str("index.revision", "main") or None,
                target_vector_db_path=self.config.get_str("storage.vector_db_path"),
                target_doc_store_path=self.config.get_str("storage.doc_store_path"),
                storage_dir=str(Path(self.config.get_str("storage.doc_store_path")).parent),
            )
            self.index_manifest = read_manifest(idx.manifest_path)
            # Optional: warn if embedding model differs
            if self.index_manifest:
                # Manifest schema varies: prefer nested embedding.model_name,
                # fall back to the flat embedding_model key.
                idx_model = (
                    (self.index_manifest.get("embedding") or {}).get("model_name")
                    or self.index_manifest.get("embedding_model")
                    or ""
                )
                cfg_model = self.config.get_str("embedding.model_name")
                if idx_model and cfg_model and idx_model != cfg_model:
                    self.startup_warnings.append(
                        f"Index embedding model mismatch: index='{idx_model}' vs config='{cfg_model}'. "
                        "For best results, rebuild the index with the same embedding model."
                    )
        except Exception as e:
            # Try to provide actionable guidance for common HF Hub errors.
            repo_id = self.config.get_str("index.repo_id")
            try:
                # Imported lazily: huggingface_hub may not expose these in all
                # versions, and the generic message below covers that case.
                from huggingface_hub.utils import (  # type: ignore
                    GatedRepoError,
                    HfHubHTTPError,
                    RepositoryNotFoundError,
                )
                if isinstance(e, RepositoryNotFoundError):
                    self.startup_error = (
                        f"Index dataset repo not found: `{repo_id}`.\n\n"
                        "If you haven't uploaded the prebuilt index yet, build and publish it locally:\n"
                        "1) `pip install -r requirements-dev.txt`\n"
                        "2) `python scripts/build_vector_db.py --config config/default_config.yaml --source huggingface --dataset ZhangNy/radiology-dataset --output-dir ./index_out`\n"
                        f"3) `python scripts/publish_index_to_hf.py --repo {repo_id} --folder ./index_out --token $HF_TOKEN`\n\n"
                        "Or set `RAG_INDEX_REPO_ID` to an existing index repo."
                    )
                    return
                if isinstance(e, GatedRepoError):
                    self.startup_error = (
                        f"Index dataset repo is gated/private: `{repo_id}`.\n\n"
                        "Make sure the repo is public, or provide authentication (HF token) in the environment."
                    )
                    return
                if isinstance(e, HfHubHTTPError):
                    self.startup_error = (
                        f"Failed to download index from `{repo_id}`.\n\n"
                        f"HF Hub error: {e}"
                    )
                    return
            except Exception:
                # If importing HF-specific exceptions fails, fall back to generic message.
                pass
            self.startup_error = (
                f"Failed to prepare index from `{repo_id}`.\n\n"
                f"Error: {e}"
            )
            return
        # Build RAG engine
        try:
            self.rag_engine = RAGEngine(self.config)
        except Exception as e:
            self.startup_error = f"Failed to initialize RAG engine: {e}"
            return

    def process_query(
        self,
        question: str,
        temperature: float,
        source_filters: List[str],
        retrieval_strategy: str,
    ) -> Iterator[Tuple[str, str, str]]:
        """Stream UI updates for one query as (answer_md, refs_html, stats_html).

        Yields an initial loading state, then partial answers as the engine
        streams, and finally the complete answer with the reference panel and
        statistics. Errors (startup or runtime) are yielded as formatted
        Markdown rather than raised, so Gradio always has something to show.
        """
        if self.startup_error:
            yield format_error_message(self.startup_error), "", ""
            return
        if self.rag_engine is None:
            yield format_error_message("RAG engine not initialized."), "", ""
            return
        q = (question or "").strip()
        if not q:
            yield format_error_message("Please enter a question."), "", ""
            return
        # Update LLM temperature on the fly
        try:
            self.rag_engine.llm.temperature = float(temperature)
        except Exception:
            # Best-effort: keep the engine's current temperature on failure.
            pass
        sources = source_filters or []
        loading_md = (
            f"{format_loading_message()}\n\n"
            f"**Retrieval Strategy**: `{retrieval_strategy}`\n\n"
            f"**Sources**: `{', '.join(sources) if sources else 'ALL'}`"
        )
        loading_refs = "<p style='color:#6b7280;text-align:center;padding:20px;'>Retrieving & reranking...</p>"
        loading_stats = "<p style='color:#6b7280;padding:10px;'>Working...</p>"
        yield loading_md, loading_refs, loading_stats
        start_time = time.time()
        # Tracks the last streamed partial so identical chunks aren't re-yielded.
        last_partial = ""
        try:
            for event in self.rag_engine.query_stream(
                question=q,
                source_filters=sources if sources else None,
                retrieval_strategy=retrieval_strategy,
            ):
                etype = (event or {}).get("type")
                if etype == "answer":
                    partial = (event.get("answer") or "")
                    if partial and partial != last_partial:
                        # Make citations clickable: [1] -> [1](#ref-1)
                        answer_md = re.sub(r"\[(\d+)\](?!\()", r"[\1](#ref-\1)", partial)
                        last_partial = partial
                        yield answer_md, loading_refs, loading_stats
                elif etype == "final":
                    meta = event.get("metadata") or {}
                    # If engine didn't populate elapsed_time (it does), we fill it.
                    meta.setdefault("elapsed_time", time.time() - start_time)
                    final_answer = (event.get("answer") or "")
                    answer_md = re.sub(r"\[(\d+)\](?!\()", r"[\1](#ref-\1)", final_answer)
                    references_html = format_reference_panel(event.get("references") or [])
                    stats_html = format_statistics(meta)
                    yield answer_md, references_html, stats_html
                    return
            # Stream ended without a "final" event.
            yield format_error_message("No response was generated. Please try again."), "", ""
        except Exception as e:
            logger.error(f"Error processing query: {e}", exc_info=True)
            yield format_error_message(f"An error occurred: {e}"), "", ""

    def create_interface(self) -> gr.Blocks:
        """Build and return the Gradio Blocks UI.

        When a startup error exists, a minimal page explaining the required
        secrets is returned instead of the full query interface.
        """
        title = self.config.get_str("ui.title", "Radiology RAG")
        description = self.config.get_str("ui.description", "")
        theme = self.config.get_str("ui.theme", "soft")
        default_strategy = self.config.get_str("retrieval.strategy", "balanced_multi_source")
        default_sources = self.config.get("retrieval.source_filters", ["article", "case", "tutorial", "encyclopedia"])
        if not isinstance(default_sources, list):
            # Guard against a malformed config value (e.g. a string).
            default_sources = ["article", "case", "tutorial", "encyclopedia"]
        default_temp = self.config.get_float("llm.temperature", 0.7)
        with gr.Blocks(title=title, theme=theme) as interface:
            gr.Markdown(f"# {title}")
            if description:
                gr.Markdown(description)
            if self.startup_error:
                # Fatal setup problem: render guidance only, no query UI.
                gr.Markdown(format_error_message(self.startup_error))
                gr.Markdown(
                    "### Required Secrets\n"
                    "- `EMBED_API_KEY`\n"
                    "- `LLM_API_KEY`\n\n"
                    "Optional (recommended):\n"
                    "- `RERANK_API_KEY`\n"
                )
                return interface
            if self.startup_warnings:
                gr.Markdown("### ⚠️ Startup Warnings")
                gr.Markdown("\n".join([f"- {w}" for w in self.startup_warnings]))
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Ask a Question")
                    question_input = gr.Textbox(
                        label="Your Question",
                        placeholder="e.g., What is achalasia and how is it diagnosed?",
                        lines=3,
                    )
                    retrieval_strategy, temperature_slider, source_filter = create_settings_accordion(
                        default_strategy=default_strategy,
                        default_temperature=default_temp,
                        default_sources=default_sources,
                    )
                    submit_btn = gr.Button("Search & Answer", variant="primary", size="lg")
                    gr.Markdown("### Example Questions")
                    gr.Examples(
                        examples=[
                            ["What is achalasia and how is it diagnosed on imaging?"],
                            ["Explain the imaging findings in Barrett's esophagus"],
                            ["What are the characteristics of a Zenker's diverticulum?"],
                            ["Describe the CT findings of esophageal cancer"],
                        ],
                        inputs=[question_input],
                        label="Click an example to try it",
                    )
                with gr.Column(scale=2):
                    gr.Markdown("### Answer (with citations)")
                    answer_output = gr.Markdown(value="*Your answer will appear here...*")
                    stats_output = gr.HTML(label="Statistics")
                    gr.Markdown("### Retrieved References")
                    references_output = gr.HTML(
                        value="<p style='color:#6b7280;text-align:center;padding:20px;'>References will appear here...</p>"
                    )
            # process_query is a generator, so Gradio streams each yielded tuple.
            submit_btn.click(
                fn=self.process_query,
                inputs=[question_input, temperature_slider, source_filter, retrieval_strategy],
                outputs=[answer_output, references_output, stats_output],
            )
            gr.Markdown("---")
            with gr.Accordion("About", open=False):
                gr.Markdown(
                    "This Space demonstrates a radiology RAG system using a prebuilt vector index "
                    f"(`{self.config.get_str('index.repo_id')}`) and external APIs for embeddings/LLM.\n\n"
                    "**Disclaimer**: Educational use only. Always consult qualified professionals for clinical decisions."
                )
        return interface