Add Space app files

- .gitignore +46 -0
- README.md +83 -1
- app.py +66 -0
- config/default_config.yaml +122 -0
- radiology_rag/__init__.py +12 -0
- radiology_rag/citations.py +54 -0
- radiology_rag/config.py +108 -0
- radiology_rag/doc_store.py +143 -0
- radiology_rag/embedding.py +41 -0
- radiology_rag/encyclopedia.py +194 -0
- radiology_rag/gradio_compat.py +60 -0
- radiology_rag/index_bootstrap.py +160 -0
- radiology_rag/rag.py +257 -0
- radiology_rag/reranker.py +143 -0
- radiology_rag/retrieval.py +403 -0
- radiology_rag/ui.py +423 -0
- requirements-dev.txt +19 -0
- requirements.txt +21 -0
- scripts/build_vector_db.py +219 -0
- scripts/download_hf_dataset.py +43 -0
- scripts/package_existing_storage.py +110 -0
- scripts/publish_index_to_hf.py +78 -0
.gitignore
ADDED
@@ -0,0 +1,46 @@
```
# Python
__pycache__/
*.py[cod]
*.so
*.egg-info/
dist/
build/
.pytest_cache/

# Virtual environments
venv/
.venv/
ENV/
env/

# IDEs
.vscode/
.idea/
*.swp
*.swo
*~

# Local runtime storage (vector DB + sqlite doc store)
storage/
*.db
*.sqlite
*.sqlite3

# Dataset caches / artifacts
hf_dataset_prepared/
.cache/
.huggingface/

# Raw data (private; never publish to Spaces)
old_related_files/
private_scripts/

# Logs
*.log
logs/

# Env files
.env
.env.local
```
README.md
CHANGED
@@ -11,4 +11,86 @@ license: mit
````markdown
short_description: 'Ask questions about thoracic radiology and get answers with '
---

## Overview

This repository contains a **Hugging Face Spaces-ready** RAG (Retrieval-Augmented Generation) demo for thoracic radiology Q&A.

- **Default index (prebuilt)**: `ZhangNy/radiology-index-qwen3-embedding-0.6b`
- **Raw public dataset**: `ZhangNy/radiology-dataset`
- **No image rendering in UI**: references link to the original pages where images can be viewed.

The Space uses **external APIs** for Embeddings / Reranker / LLM via **Secrets**.

## Run (local)

```bash
cd LangGraphAgent/rebuild_1219
pip install -r requirements.txt

export EMBED_API_KEY="..."
export LLM_API_KEY="..."
# optional:
export RERANK_API_KEY="..."

python app.py --config config/default_config.yaml --host 0.0.0.0 --port 7860
```

Open `http://localhost:7860`.

## Required Hugging Face Space Secrets

### Required

- **`EMBED_API_KEY`**: embedding API key (OpenAI-compatible)
- **`LLM_API_KEY`**: LLM API key (OpenAI-compatible)

### Recommended

- **`RERANK_API_KEY`**: reranker API key (OpenAI-compatible `/rerank` endpoint)

### Optional (override defaults)

- **`EMBED_API_BASE_URL`**, **`EMBED_MODEL_NAME`**
- **`RERANK_API_BASE_URL`**, **`RERANK_MODEL_NAME`**
- **`LLM_BASE_URL`**, **`LLM_MODEL_NAME`**
- **`RAG_INDEX_REPO_ID`** (default: `ZhangNy/radiology-index-qwen3-embedding-0.6b`)
- **`RAG_STORAGE_DIR`** (default: `/data/radiology_rag` if `/data` exists, else `./storage`)

## Advanced: rebuild your own index (offline)

Install dev deps:

```bash
pip install -r requirements-dev.txt
```

The `scripts/` folder (run locally, not on the Space) supports:
- Downloading `ZhangNy/radiology-dataset` to `./hf_dataset_prepared`
- Building a new index with a different embedding model
- Publishing that index as a Hugging Face dataset repo

### Fast path (no rebuild): publish your existing local index

If you already have a built index locally (e.g. `rebuild_1217/storage` contains `chroma_db/` + `doc_store.db`),
you can **package it without images** and upload it:

```bash
python scripts/package_existing_storage.py \
  --storage /home/zny/codes/radioagent_prepare/LangGraphAgent/rebuild_1217/storage \
  --output-dir ./index_out \
  --overwrite

python scripts/publish_index_to_hf.py \
  --repo ZhangNy/radiology-index-qwen3-embedding-0.6b \
  --folder ./index_out \
  --token $HF_TOKEN
```

## Notes

- **Do not commit API keys**. This repo is configured to read them from environment variables / Space Secrets.
- **Index compatibility**: the query-time embedding model must match the embedding model the index was built with; a mismatch degrades retrieval quality.
````
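For local runs, failing fast on the secrets listed above avoids a confusing mid-startup crash. A minimal sketch (the `check_secrets` helper is hypothetical, not part of this repo):

```python
# Hypothetical pre-flight check for the Space secrets documented in the README.
import os
import sys

REQUIRED = ["EMBED_API_KEY", "LLM_API_KEY"]
RECOMMENDED = ["RERANK_API_KEY"]


def check_secrets() -> None:
    # Required keys: exit immediately with a clear message if any is unset.
    missing = [name for name in REQUIRED if not os.getenv(name)]
    if missing:
        sys.exit(f"Missing required secrets: {', '.join(missing)}")
    # Recommended keys: warn only; the reranker auto-disables without its key.
    for name in RECOMMENDED:
        if not os.getenv(name):
            print(f"Note: {name} not set; the reranker will be disabled.")


if __name__ == "__main__":
    check_secrets()
```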
app.py
ADDED
@@ -0,0 +1,66 @@
```python
"""
Hugging Face Spaces entry point (Gradio).

Run locally:
    python app.py --config config/default_config.yaml --host 0.0.0.0 --port 7860
"""

from __future__ import annotations

import argparse
import logging
import os
from pathlib import Path

from radiology_rag.gradio_compat import patch_gradio_predict_body_for_pydantic_v2
from radiology_rag.ui import RadiologyRAGApp


def _configure_logging() -> None:
    level = os.getenv("LOG_LEVEL", "INFO").upper()
    logging.basicConfig(
        level=getattr(logging, level, logging.INFO),
        format="%(asctime)s - %(levelname)s - %(message)s",
    )


def _default_storage_dir() -> str:
    # Prefer /data on Spaces if persistent storage is enabled.
    if Path("/data").exists():
        return "/data/radiology_rag"
    return "./storage"


def main() -> int:
    _configure_logging()

    parser = argparse.ArgumentParser(description="Radiology RAG (Spaces-ready)")
    parser.add_argument("--config", type=str, default="config/default_config.yaml", help="Path to config YAML")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host")
    parser.add_argument("--port", type=int, default=int(os.getenv("PORT", "7860")), help="Server port")
    args = parser.parse_args()

    # Ensure storage dir env is set early so config interpolation uses it.
    if not os.getenv("RAG_STORAGE_DIR"):
        os.environ["RAG_STORAGE_DIR"] = _default_storage_dir()

    # Optional compatibility patch for Gradio 4.16 + Pydantic v2.
    if patch_gradio_predict_body_for_pydantic_v2():
        logging.getLogger(__name__).info("Applied Gradio/Pydantic v2 compatibility patch")

    app = RadiologyRAGApp(config_path=args.config)
    demo = app.create_interface()

    demo.launch(
        server_name=args.host,
        server_port=args.port,
        share=False,
        show_error=True,
    )
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
```
config/default_config.yaml
ADDED
@@ -0,0 +1,122 @@
```yaml
# Default configuration for the Hugging Face Spaces deployment (API-first).
# IMPORTANT: do NOT put any real API keys in this file. Use Spaces Secrets instead.

# Prebuilt index (vector DB + doc store) stored as a Hugging Face dataset repo.
index:
  repo_id: "${RAG_INDEX_REPO_ID:ZhangNy/radiology-index-qwen3-embedding-0.6b}"
  revision: "${RAG_INDEX_REVISION:main}"

# Embedding configuration (query-time embeddings must match the index embedding model).
embedding:
  type: "api"  # Options: "api" or "local" (local requires requirements-dev.txt)
  api_base_url: "${EMBED_API_BASE_URL:https://api.siliconflow.cn/v1}"
  api_key: "${EMBED_API_KEY:}"
  model_name: "${EMBED_MODEL_NAME:Qwen/Qwen3-Embedding-0.6B}"
  batch_size: 32

# Reranker configuration (recommended; can be disabled if no key is provided).
reranker:
  enabled: true
  type: "api"  # Options: "api" or "local" (local requires requirements-dev.txt)
  api_base_url: "${RERANK_API_BASE_URL:https://api.siliconflow.cn/v1}"
  api_key: "${RERANK_API_KEY:}"
  model_name: "${RERANK_MODEL_NAME:BAAI/bge-reranker-v2-m3}"
  top_k: 10

# LLM configuration (OpenAI-compatible API).
llm:
  base_url: "${LLM_BASE_URL:https://poloai.top/v1}"
  api_key: "${LLM_API_KEY:}"
  model_name: "${LLM_MODEL_NAME:gemini-3-flash-preview}"
  temperature: 0.7
  max_tokens: 2000

# Storage paths (prefer /data on Spaces; app.py will default RAG_STORAGE_DIR to /data/radiology_rag when available).
storage:
  vector_db_path: "${RAG_STORAGE_DIR:./storage}/chroma_db"
  doc_store_path: "${RAG_STORAGE_DIR:./storage}/doc_store.db"

# Text splitting parameters (used for index build scripts; kept here for transparency).
processing:
  chunk_size: 1024
  chunk_overlap: 200
  separators:
    - "\n\n#### "
    - "\n\n### "
    - "\n\n## "
    - "\n\n"
    - "\n"
    - " "
    - ""
  keep_separator: true

# Retrieval configuration
retrieval:
  # Default strategy for this Space:
  #   - balanced_multi_source: includes Wikipedia (encyclopedia) by default
  strategy: "balanced_multi_source"
  top_k: 20
  source_filters:
    - "article"
    - "case"
    - "tutorial"
    - "encyclopedia"

  search_type: "similarity"  # "similarity" or "mmr"
  chunk_fetch_multiplier: 3

  # MMR parameters (only if search_type == "mmr")
  mmr_lambda: 0.5
  mmr_fetch_k: 50

  # Balanced multi-source retrieval policy
  multi_source:
    total_top_k: 8
    sources_priority: ["article", "case", "encyclopedia", "tutorial"]
    article:
      candidate_k: 80
      max_k: 3
      min_score: 0.15
      required: true
    case:
      candidate_k: 80
      max_k: 3
      min_score: 0.15
      required: true
    encyclopedia:
      candidate_k: 8
      max_k: 2
      min_score: 0.15
      required: true
    tutorial:
      candidate_k: 20
      max_k: 2
      min_score: 0.50
      required: false

# Encyclopedia configuration (Wikipedia)
encyclopedia:
  wikipedia:
    language: "en"
    user_agent: "RadiologyRAG-Space/1.0"
    timeout_s: 10
    max_chars_per_doc: 2000

# Citation configuration (no images in this Space; references link to original pages)
citation:
  format: "numbered"
  max_content_length: 900

# Gradio UI configuration
ui:
  title: "Thoracic Radiology RAG System"
  description: "Ask questions about thoracic radiology and get answers with citations (articles, cases, tutorials + Wikipedia)."
  theme: "soft"
  show_retrieved_docs: true

# Logging configuration
logging:
  level: "INFO"
  format: "%(asctime)s - %(levelname)s - %(message)s"
```
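A minimal sketch of how the `${VAR:default}` placeholders above resolve, using the `Config` loader from `radiology_rag/config.py` (shown later in this commit). Note that resolution happens once at load time, so the file must be reloaded after changing the environment:

```python
import os

from radiology_rag.config import Config

# Unset -> the default after the colon applies.
os.environ.pop("LLM_MODEL_NAME", None)
cfg = Config("config/default_config.yaml")
print(cfg.get_str("llm.model_name"))  # -> "gemini-3-flash-preview"

# Set -> the environment variable wins; reload to pick it up.
os.environ["LLM_MODEL_NAME"] = "gpt-4o-mini"
print(Config("config/default_config.yaml").get_str("llm.model_name"))  # -> "gpt-4o-mini"
```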
radiology_rag/__init__.py
ADDED
@@ -0,0 +1,12 @@
```python
"""
Radiology RAG (Spaces-ready) - minimal, API-first implementation.

This package is designed to be deployed on Hugging Face Spaces and load a
prebuilt vector index from a public Hugging Face dataset repo.
"""

__all__ = ["__version__"]

__version__ = "0.1.0"
```
radiology_rag/citations.py
ADDED
@@ -0,0 +1,54 @@
```python
"""Citation helpers (numbered citations like [1], [2], ...)."""

from __future__ import annotations

import re
from typing import Any, Dict, List, Tuple


class CitationManager:
    def __init__(self, *, max_content_length: int = 900):
        self.max_content_length = int(max_content_length)
        self.documents: List[Dict[str, Any]] = []
        self.doc_id_to_index: Dict[str, int] = {}

    def clear(self) -> None:
        self.documents = []
        self.doc_id_to_index = {}

    def add_document(self, document: Dict[str, Any]) -> int:
        doc_id = document.get("doc_id") or ""
        if doc_id in self.doc_id_to_index:
            return self.doc_id_to_index[doc_id]
        self.documents.append(document)
        idx = len(self.documents)
        self.doc_id_to_index[doc_id] = idx
        return idx

    def add_documents(self, documents: List[Dict[str, Any]]) -> List[int]:
        return [self.add_document(d) for d in documents]

    @staticmethod
    def parse_citations_in_text(text: str) -> List[int]:
        matches = re.findall(r"\[(\d+)\]", text or "")
        out = []
        for m in matches:
            try:
                out.append(int(m))
            except Exception:
                continue
        return out

    def validate_citations(self, text: str) -> Tuple[bool, List[int]]:
        cited = self.parse_citations_in_text(text or "")
        invalid = [i for i in cited if i < 1 or i > len(self.documents)]
        return (len(invalid) == 0), invalid

    def get_statistics(self) -> Dict[str, Any]:
        counts: Dict[str, int] = {}
        for d in self.documents:
            st = d.get("source_type", "unknown") or "unknown"
            counts[st] = counts.get(st, 0) + 1
        return {"total": len(self.documents), "source_type_counts": counts}
```
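A short usage sketch for `CitationManager`; the sample documents and answer text are illustrative only:

```python
from radiology_rag.citations import CitationManager

cm = CitationManager(max_content_length=900)
idx1 = cm.add_document({"doc_id": "article_1", "source_type": "article", "title": "Pneumothorax"})
idx2 = cm.add_document({"doc_id": "case_7", "source_type": "case", "title": "Tension pneumothorax"})
print(idx1, idx2)  # -> 1 2  (1-based, insertion order; duplicate doc_ids reuse their index)

answer = "A deep sulcus sign suggests pneumothorax [1], confirmed on CT [2]."
print(cm.validate_citations(answer))     # -> (True, [])
print(cm.validate_citations("See [3]"))  # -> (False, [3]); only 2 documents registered
print(cm.get_statistics())  # -> {'total': 2, 'source_type_counts': {'article': 1, 'case': 1}}
```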
radiology_rag/config.py
ADDED
@@ -0,0 +1,108 @@
```python
"""
Config loader with `${ENV_VAR}` and `${ENV_VAR:default}` interpolation.

We intentionally keep config logic lightweight so it works well in Spaces.
"""

from __future__ import annotations

import os
import re
from pathlib import Path
from typing import Any, Dict, Optional

import yaml


class Config:
    """Load and access YAML config with env var interpolation."""

    def __init__(self, config_path: str):
        self.config_path = Path(config_path)
        self._config = self._load_config()
        self._config = self._recursive_resolve(self._config)

    def _load_config(self) -> Dict[str, Any]:
        if not self.config_path.exists():
            raise FileNotFoundError(f"Config file not found: {self.config_path}")
        with open(self.config_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}
        if not isinstance(data, dict):
            raise ValueError("Config root must be a mapping/dict")
        return data

    @staticmethod
    def _resolve_string(value: str) -> str:
        # Pattern: ${VAR_NAME} or ${VAR_NAME:default_value}
        # NOTE: default_value may be empty, e.g. `${API_KEY:}`. Use `*` (not `+`) to allow empty.
        pattern = r"\$\{([^:}]+)(?::([^}]*))?\}"

        def replace(match: re.Match) -> str:
            var_name = match.group(1)
            default_value = match.group(2) if match.group(2) is not None else ""
            return os.getenv(var_name, default_value)

        return re.sub(pattern, replace, value)

    def _recursive_resolve(self, obj: Any) -> Any:
        if isinstance(obj, dict):
            return {k: self._recursive_resolve(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [self._recursive_resolve(v) for v in obj]
        if isinstance(obj, str):
            return self._resolve_string(obj)
        return obj

    def get(self, key: str, default: Any = None) -> Any:
        keys = key.split(".")
        value: Any = self._config
        for k in keys:
            if isinstance(value, dict) and k in value:
                value = value[k]
            else:
                return default
        return value

    def get_str(self, key: str, default: str = "") -> str:
        v = self.get(key, default)
        return default if v is None else str(v)

    def get_int(self, key: str, default: int = 0) -> int:
        v = self.get(key, default)
        if v is None:
            return default
        if isinstance(v, int):
            return v
        try:
            return int(str(v).strip())
        except Exception:
            return default

    def get_float(self, key: str, default: float = 0.0) -> float:
        v = self.get(key, default)
        if v is None:
            return default
        if isinstance(v, (int, float)):
            return float(v)
        try:
            return float(str(v).strip())
        except Exception:
            return default

    def get_bool(self, key: str, default: bool = False) -> bool:
        v = self.get(key, default)
        if isinstance(v, bool):
            return v
        if v is None:
            return default
        s = str(v).strip().lower()
        if s in {"1", "true", "yes", "y", "on"}:
            return True
        if s in {"0", "false", "no", "n", "off"}:
            return False
        return default

    def as_dict(self) -> Dict[str, Any]:
        return dict(self._config)
```
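The empty-default case is the subtle part of this interpolation pattern. A small sketch exercising `_resolve_string` directly:

```python
import os

from radiology_rag.config import Config

# `${API_KEY:}` must resolve to an empty string when API_KEY is unset; the
# regex uses `*` for the default so the empty default still matches.
os.environ.pop("API_KEY", None)
print(repr(Config._resolve_string("${API_KEY:}")))      # -> ''
print(repr(Config._resolve_string("${HOST:0.0.0.0}")))  # -> '0.0.0.0'

os.environ["API_KEY"] = "sk-test"
print(repr(Config._resolve_string("key=${API_KEY:}")))  # -> 'key=sk-test'
```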
radiology_rag/doc_store.py
ADDED
@@ -0,0 +1,143 @@
```python
"""
SQLite-backed document store.

We keep the schema compatible with the previous rebuild_1217 implementation:
table `documents(doc_id, complete_document, main_content, images, source_type)`.

In this Space we do NOT use images, but we keep the column for compatibility with
existing indexes and to allow advanced users to extend the system.
"""

from __future__ import annotations

import json
import logging
import os
import sqlite3
from typing import Any, Iterator, List, Optional, Sequence, Tuple

logger = logging.getLogger(__name__)


class PersistentDocStore:
    def __init__(self, db_path: str, *, read_only: bool = False):
        self.db_path = db_path
        self.read_only = bool(read_only)

        if not self.read_only:
            self.init_db()

    def _connect(self) -> sqlite3.Connection:
        if self.read_only:
            # Open in read-only mode to avoid accidental writes in Spaces runtime.
            return sqlite3.connect(f"file:{self.db_path}?mode=ro", uri=True)
        return sqlite3.connect(self.db_path)

    def init_db(self) -> None:
        db_dir = os.path.dirname(self.db_path) or "."
        os.makedirs(db_dir, exist_ok=True)
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS documents (
                    doc_id TEXT PRIMARY KEY,
                    complete_document TEXT,
                    main_content TEXT,
                    images TEXT,
                    source_type TEXT
                )
                """
            )
            conn.commit()
        finally:
            conn.close()

    def mset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None:
        if self.read_only:
            raise RuntimeError("DocStore is read-only")
        conn = self._connect()
        try:
            cursor = conn.cursor()
            for doc_id, content in key_value_pairs:
                cursor.execute(
                    """
                    INSERT OR REPLACE INTO documents
                    (doc_id, complete_document, main_content, images, source_type)
                    VALUES (?, ?, ?, ?, ?)
                    """,
                    (
                        doc_id,
                        json.dumps(content.get("complete_document", {}), ensure_ascii=False),
                        content.get("main_content", "") or "",
                        json.dumps(content.get("images", []), ensure_ascii=False),
                        content.get("source_type", "") or "",
                    ),
                )
            conn.commit()
        finally:
            conn.close()

    def mget(self, keys: Sequence[str]) -> List[Optional[Any]]:
        conn = self._connect()
        try:
            cursor = conn.cursor()
            out: List[Optional[Any]] = []
            for doc_id in keys:
                cursor.execute(
                    "SELECT complete_document, main_content, images, source_type FROM documents WHERE doc_id = ?",
                    (doc_id,),
                )
                row = cursor.fetchone()
                if not row:
                    out.append(None)
                    continue
                complete_document, main_content, images, source_type = row
                out.append(
                    {
                        "complete_document": json.loads(complete_document or "{}"),
                        "main_content": main_content or "",
                        "images": json.loads(images or "[]"),
                        "source_type": source_type or "",
                    }
                )
            return out
        finally:
            conn.close()

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        conn = self._connect()
        try:
            cursor = conn.cursor()
            if prefix:
                cursor.execute("SELECT doc_id FROM documents WHERE doc_id LIKE ?", (f"{prefix}%",))
            else:
                cursor.execute("SELECT doc_id FROM documents")
            for (doc_id,) in cursor.fetchall():
                yield str(doc_id)
        finally:
            conn.close()

    def count(self) -> int:
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT COUNT(*) FROM documents")
            return int(cursor.fetchone()[0])
        finally:
            conn.close()

    def count_by_source_type(self) -> dict:
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT source_type, COUNT(*) FROM documents GROUP BY source_type")
            counts = {}
            for source_type, count in cursor.fetchall():
                counts[str(source_type)] = int(count)
            return counts
        finally:
            conn.close()
```
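A usage sketch for `PersistentDocStore` against a throwaway path; the sample record is illustrative:

```python
from radiology_rag.doc_store import PersistentDocStore

store = PersistentDocStore("/tmp/doc_store_demo.db")  # creates the table if needed
store.mset([
    ("article_1", {
        "complete_document": {"title": "Pneumothorax"},
        "main_content": "Air in the pleural space...",
        "images": [],
        "source_type": "article",
    }),
])
print(store.mget(["article_1", "missing_id"]))       # second entry is None
print(store.count(), store.count_by_source_type())   # -> 1 {'article': 1}

# At Spaces runtime the store is opened read-only; mset() then raises RuntimeError.
ro = PersistentDocStore("/tmp/doc_store_demo.db", read_only=True)
```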
radiology_rag/embedding.py
ADDED
@@ -0,0 +1,41 @@
```python
"""Embedding utilities (OpenAI-compatible, API-first)."""

from __future__ import annotations

from dataclasses import dataclass

from langchain_openai import OpenAIEmbeddings


@dataclass(frozen=True)
class EmbeddingConfig:
    base_url: str
    api_key: str
    model_name: str
    batch_size: int = 32


class EmbeddingClient:
    """Thin wrapper over LangChain OpenAIEmbeddings."""

    def __init__(self, cfg: EmbeddingConfig):
        self.cfg = cfg
        self._emb = OpenAIEmbeddings(
            base_url=cfg.base_url,
            api_key=cfg.api_key,
            model=cfg.model_name,
            chunk_size=int(cfg.batch_size or 32),
        )

    def embed_query(self, text: str) -> list[float]:
        return self._emb.embed_query(text)

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return self._emb.embed_documents(texts)

    @property
    def langchain_embeddings(self) -> OpenAIEmbeddings:
        """Expose the underlying LangChain embeddings for Chroma."""
        return self._emb
```
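A minimal wiring sketch, assuming a valid `EMBED_API_KEY` and the endpoint/model defaults from `config/default_config.yaml`:

```python
import os

from radiology_rag.embedding import EmbeddingClient, EmbeddingConfig

cfg = EmbeddingConfig(
    base_url=os.getenv("EMBED_API_BASE_URL", "https://api.siliconflow.cn/v1"),
    api_key=os.environ["EMBED_API_KEY"],  # fail loudly if unset
    model_name=os.getenv("EMBED_MODEL_NAME", "Qwen/Qwen3-Embedding-0.6B"),
)
client = EmbeddingClient(cfg)
vec = client.embed_query("ground-glass opacity")
print(len(vec))  # embedding dimensionality, e.g. 1024 for Qwen3-Embedding-0.6B
```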
radiology_rag/encyclopedia.py
ADDED
@@ -0,0 +1,194 @@
```python
"""Wikipedia encyclopedia retrieval (MediaWiki API)."""

from __future__ import annotations

import logging
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import requests

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class WikipediaConfig:
    language: str = "en"
    user_agent: str = "RadiologyRAG-Space/1.0"
    timeout_s: int = 15
    max_chars_per_doc: int = 2000


class WikipediaEncyclopediaService:
    def __init__(self, config: Optional[WikipediaConfig] = None):
        self.config = config or WikipediaConfig()
        self._session = requests.Session()
        self._session.headers.update({"User-Agent": self.config.user_agent})

    @property
    def api_base(self) -> str:
        return f"https://{self.config.language}.wikipedia.org/w/api.php"

    @staticmethod
    def _derive_search_query(user_query: str) -> str:
        q = (user_query or "").strip()
        if not q:
            return ""

        tokens = re.findall(r"[A-Za-z][A-Za-z'\-]*", q.lower())
        if not tokens:
            return q

        stop = {
            "what", "which", "who", "whom", "whose", "when", "where", "why", "how",
            "is", "are", "was", "were", "be", "been", "being",
            "do", "does", "did", "can", "could", "should", "would", "may", "might",
            "will", "shall",
            "a", "an", "the", "and", "or", "but",
            "to", "of", "for", "with", "without", "in", "on", "at", "by", "from", "as",
            "it", "its", "this", "that", "these", "those",
            "your", "my", "their", "our", "about",
        }

        keep_short = {"ct", "mr", "mri", "pet", "us", "cxr"}
        keywords: List[str] = []
        seen = set()
        for t in tokens:
            if t in stop:
                continue
            if len(t) < 3 and t not in keep_short:
                continue
            if t in seen:
                continue
            seen.add(t)
            keywords.append(t)

        return " ".join(keywords[:8]) if keywords else q

    def retrieve(self, query: str, top_k: int = 5, max_chars_per_doc: Optional[int] = None) -> List[Dict[str, Any]]:
        q = (query or "").strip()
        if not q:
            return []

        search_q = self._derive_search_query(q)
        if not search_q:
            return []

        max_chars = int(max_chars_per_doc or self.config.max_chars_per_doc)
        try:
            search_params = {
                "action": "query",
                "list": "search",
                "srsearch": search_q,
                "srlimit": max(1, min(int(top_k), 20)),
                "format": "json",
            }
            resp = self._session.get(self.api_base, params=search_params, timeout=self.config.timeout_s)
            resp.raise_for_status()
            data = resp.json() or {}
            hits = (data.get("query", {}) or {}).get("search", []) or []

            # Fallback to raw query if rewrite yields no hits
            if not hits and search_q != q:
                search_params["srsearch"] = q
                resp = self._session.get(self.api_base, params=search_params, timeout=self.config.timeout_s)
                resp.raise_for_status()
                data = resp.json() or {}
                hits = (data.get("query", {}) or {}).get("search", []) or []

            if not hits:
                return []

            pageids = [str(h.get("pageid")) for h in hits if h.get("pageid") is not None]
            if not pageids:
                return []

            pages_params = {
                "action": "query",
                "pageids": "|".join(pageids),
                "prop": "extracts|info",
                "explaintext": 1,
                "exintro": 1,
                "exchars": max_chars,
                "inprop": "url",
                "format": "json",
            }
            resp2 = self._session.get(self.api_base, params=pages_params, timeout=self.config.timeout_s)
            resp2.raise_for_status()
            pages_data = resp2.json() or {}
            pages = (pages_data.get("query", {}) or {}).get("pages", {}) or {}

            docs: List[Dict[str, Any]] = []
            for pid in pageids:
                page = pages.get(pid) or {}
                title = page.get("title") or ""
                extract = (page.get("extract") or "").strip()
                url = page.get("fullurl") or ""
                if not title or not extract:
                    continue
                docs.append(
                    {
                        "doc_id": f"encyclopedia_{pid}",
                        "source_type": "encyclopedia",
                        "title": title,
                        "content": extract,
                        "url": url,
                        "metadata": {"provider": "wikipedia", "pageid": pid},
                        "score": 0.0,
                    }
                )

            return docs[: int(top_k)]
        except Exception as e:
            logger.warning(f"Wikipedia retrieval failed: {e}")
            return []
```
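A quick sketch of the query rewrite and a live lookup (the second part makes a network call, so results will vary):

```python
from radiology_rag.encyclopedia import WikipediaEncyclopediaService

svc = WikipediaEncyclopediaService()
print(svc._derive_search_query("What is the halo sign on CT?"))
# -> "halo sign ct"  (stop words dropped; "ct" kept via the short-token allowlist)

for doc in svc.retrieve("halo sign CT", top_k=2):
    print(doc["title"], doc["url"])
```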
radiology_rag/gradio_compat.py
ADDED
@@ -0,0 +1,60 @@
```python
"""
Compatibility patches for Gradio when running with newer FastAPI/Pydantic versions.

Background:
- Gradio 4.16 defines `gradio.data_classes.PredictBody.request: Optional[fastapi.Request]`.
- Under Pydantic v2, `fastapi/starlette Request` cannot be converted into a JSON schema,
  which can crash FastAPI request parsing for Gradio's `/run/{api_name}` endpoint.

This module applies a targeted runtime patch that replaces that field with `Any`.
It is intentionally narrow and only runs when we detect the problematic combination.
"""

from __future__ import annotations

from typing import Any


def patch_gradio_predict_body_for_pydantic_v2() -> bool:
    """Return True if a patch was applied."""
    try:
        import pydantic

        major = int(str(pydantic.__version__).split(".", 1)[0])
        if major < 2:
            return False

        import gradio.data_classes as gr_data_classes
        import gradio.routes as gr_routes
        from pydantic import ConfigDict, create_model
        from typing import List, Optional

        PredictBody = getattr(gr_data_classes, "PredictBody", None)
        if PredictBody is None:
            return False

        ann = getattr(PredictBody, "__annotations__", {}) or {}
        if "request" not in ann:
            return False

        # NOTE: pydantic v2 forbids passing __base__ and __config__ together
        # (create_model raises PydanticUserError, which the outer try/except
        # would swallow, silently skipping the patch). BaseModel is already
        # the default base, so only __config__ is supplied here.
        PatchedPredictBody = create_model(  # type: ignore[call-arg]
            "PredictBody",
            __config__=ConfigDict(arbitrary_types_allowed=True),
            session_hash=(Optional[str], None),
            event_id=(Optional[str], None),
            data=(List[Any], ...),
            event_data=(Optional[Any], None),
            fn_index=(Optional[int], None),
            trigger_id=(Optional[int], None),
            batched=(Optional[bool], False),
            request=(Optional[Any], None),
        )

        gr_data_classes.PredictBody = PatchedPredictBody  # type: ignore[attr-defined]
        gr_routes.PredictBody = PatchedPredictBody  # type: ignore[attr-defined]
        return True
    except Exception:
        return False
```
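For context, a minimal reproduction of the incompatibility this patch works around (a sketch assuming FastAPI with Pydantic v2 installed; the exact exception type may vary across versions):

```python
from typing import Optional

import fastapi
from pydantic import BaseModel, ConfigDict


class BodyWithRequest(BaseModel):
    # Arbitrary types make the *model* constructible...
    model_config = ConfigDict(arbitrary_types_allowed=True)
    request: Optional[fastapi.Request] = None


# ...but JSON-schema generation still fails: starlette's Request has no
# schema, so this raises (PydanticInvalidForJsonSchema under pydantic v2),
# which is what breaks FastAPI's request parsing for Gradio's endpoint.
try:
    BodyWithRequest.model_json_schema()
except Exception as e:
    print(type(e).__name__, e)
```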
radiology_rag/index_bootstrap.py
ADDED
@@ -0,0 +1,160 @@
```python
"""
Index bootstrap utilities for Hugging Face Spaces.

This Space relies on a prebuilt index stored on Hugging Face Datasets:
- ChromaDB persist directory (vector store)
- SQLite doc store (parent documents)

At startup we download (once) and place the index into a writable storage dir
(prefer /data on Spaces when persistent storage is enabled).
"""

from __future__ import annotations

import json
import logging
import os
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

from huggingface_hub import snapshot_download

logger = logging.getLogger(__name__)


DEFAULT_INDEX_REPO_ID = "ZhangNy/radiology-index-qwen3-embedding-0.6b"


@dataclass(frozen=True)
class IndexPaths:
    vector_db_path: Path
    doc_store_path: Path
    manifest_path: Optional[Path]
    snapshot_dir: Optional[Path]


def resolve_default_storage_dir() -> Path:
    """
    Determine a good default storage directory for Spaces.

    Priority:
    - $RAG_STORAGE_DIR (user override)
    - /data/radiology_rag (Spaces persistent storage)
    - ./storage (local)
    """
    env = (os.getenv("RAG_STORAGE_DIR") or "").strip()
    if env:
        return Path(env)
    if Path("/data").exists():
        return Path("/data") / "radiology_rag"
    return Path("./storage")


def _find_index_artifacts(snapshot_dir: Path) -> Tuple[Path, Path, Optional[Path]]:
    """
    Find (chroma_db_dir, doc_store_db, manifest_json) inside a HF snapshot.

    We support either:
    - chroma_db/, doc_store.db, manifest.json
    - storage/chroma_db/, storage/doc_store.db, storage/manifest.json
    """
    candidates = [
        (snapshot_dir / "chroma_db", snapshot_dir / "doc_store.db", snapshot_dir / "manifest.json"),
        (snapshot_dir / "storage" / "chroma_db", snapshot_dir / "storage" / "doc_store.db", snapshot_dir / "storage" / "manifest.json"),
    ]
    for chroma_dir, doc_db, manifest in candidates:
        if chroma_dir.exists() and chroma_dir.is_dir() and doc_db.exists() and doc_db.is_file():
            return chroma_dir, doc_db, (manifest if manifest.exists() else None)

    raise FileNotFoundError(
        "Could not locate index artifacts inside snapshot. "
        "Expected either {chroma_db/, doc_store.db} or {storage/chroma_db/, storage/doc_store.db}."
    )


def read_manifest(manifest_path: Optional[Path]) -> Optional[Dict[str, Any]]:
    if not manifest_path or not manifest_path.exists():
        return None
    try:
        with open(manifest_path, "r", encoding="utf-8") as f:
            return json.load(f) or {}
    except Exception as e:
        logger.warning(f"Failed to read manifest.json: {e}")
        return None


def ensure_index(
    *,
    repo_id: str = DEFAULT_INDEX_REPO_ID,
    revision: Optional[str] = None,
    target_vector_db_path: Optional[str] = None,
    target_doc_store_path: Optional[str] = None,
    storage_dir: Optional[str] = None,
    force_download: bool = False,
) -> IndexPaths:
    """
    Ensure the index exists locally at the configured storage paths.

    Returns resolved IndexPaths; raises on unrecoverable errors.
    """
    # Resolve target paths
    if storage_dir:
        base_dir = Path(storage_dir)
    else:
        base_dir = resolve_default_storage_dir()
    base_dir.mkdir(parents=True, exist_ok=True)

    vector_db_path = Path(target_vector_db_path) if target_vector_db_path else (base_dir / "chroma_db")
    doc_store_path = Path(target_doc_store_path) if target_doc_store_path else (base_dir / "doc_store.db")

    # Fast path: already present
    if (
        not force_download
        and vector_db_path.exists()
        and vector_db_path.is_dir()
        and doc_store_path.exists()
        and doc_store_path.is_file()
    ):
        logger.info(f"Index already present: vector_db={vector_db_path} doc_store={doc_store_path}")
        manifest_path = (base_dir / "manifest.json") if (base_dir / "manifest.json").exists() else None
        return IndexPaths(vector_db_path=vector_db_path, doc_store_path=doc_store_path, manifest_path=manifest_path, snapshot_dir=None)

    # Download snapshot
    repo_id = (repo_id or "").strip() or DEFAULT_INDEX_REPO_ID
    logger.info(f"Downloading index snapshot from HF dataset repo: {repo_id} (revision={revision or 'main'})")
    snapshot_dir = Path(
        snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            revision=revision or None,
            local_files_only=False,
        )
    )

    src_chroma_dir, src_doc_db, src_manifest = _find_index_artifacts(snapshot_dir)
    logger.info(f"Found index artifacts in snapshot: chroma={src_chroma_dir} doc_store={src_doc_db}")

    # Copy to writable target locations
    if vector_db_path.exists():
        shutil.rmtree(vector_db_path, ignore_errors=True)
    vector_db_path.parent.mkdir(parents=True, exist_ok=True)
    shutil.copytree(src_chroma_dir, vector_db_path, dirs_exist_ok=False)

    doc_store_path.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(src_doc_db, doc_store_path)

    manifest_path: Optional[Path] = None
    if src_manifest and src_manifest.exists():
        manifest_path = doc_store_path.parent / "manifest.json"
        try:
            shutil.copy2(src_manifest, manifest_path)
        except Exception as e:
            logger.warning(f"Failed to copy manifest.json: {e}")
            manifest_path = None

    logger.info(f"Index ready: vector_db={vector_db_path} doc_store={doc_store_path}")
    return IndexPaths(vector_db_path=vector_db_path, doc_store_path=doc_store_path, manifest_path=manifest_path, snapshot_dir=snapshot_dir)
```
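A bootstrap sketch, including the index-compatibility check the README recommends. The `embedding_model` manifest field is an assumption about how the published manifest.json is laid out, not something this commit guarantees:

```python
import os

from radiology_rag.index_bootstrap import ensure_index, read_manifest

paths = ensure_index()  # downloads the HF dataset snapshot on first run, then no-ops
manifest = read_manifest(paths.manifest_path) or {}

# Hypothetical manifest field: the model the index was built with.
index_model = manifest.get("embedding_model")
query_model = os.getenv("EMBED_MODEL_NAME", "Qwen/Qwen3-Embedding-0.6B")
if index_model and index_model != query_model:
    print(f"Warning: index built with {index_model}, querying with {query_model}")
```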
radiology_rag/rag.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RAG engine (Spaces-ready, API-first).
|
| 3 |
+
|
| 4 |
+
Pipeline:
|
| 5 |
+
1) Retrieve documents from prebuilt Chroma+SQLite index (local) + Wikipedia (optional)
|
| 6 |
+
2) Rerank (API; auto-disables if missing key)
|
| 7 |
+
3) Build a prompt with numbered citations [1], [2], ...
|
| 8 |
+
4) Call LLM (OpenAI-compatible) and stream answer
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import logging
|
| 14 |
+
import time
|
| 15 |
+
from typing import Any, Dict, Iterator, List, Optional, Tuple
|
| 16 |
+
|
| 17 |
+
from langchain_core.prompts import PromptTemplate
|
| 18 |
+
from langchain_openai import ChatOpenAI
|
| 19 |
+
|
| 20 |
+
from radiology_rag.config import Config
|
| 21 |
+
from radiology_rag.citations import CitationManager
|
| 22 |
+
from radiology_rag.retrieval import MultiSourceRetrievalService, RetrievalService
|
| 23 |
+
from radiology_rag.reranker import RerankerConfig, RerankerService
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class RAGEngine:
|
| 29 |
+
"""Retrieval-Augmented Generation engine for radiology queries."""
|
| 30 |
+
|
| 31 |
+
RAG_PROMPT_TEMPLATE = """You are a helpful radiology assistant with access to medical literature.
|
| 32 |
+
|
| 33 |
+
Answer the user's question based on the provided context.
|
| 34 |
+
|
| 35 |
+
**Rules**
|
| 36 |
+
1. Use the context as primary evidence.
|
| 37 |
+
2. Add numbered citations like [1], [2], ... immediately after the relevant sentences.
|
| 38 |
+
3. Do NOT invent citations. Only cite sources that appear in the context.
|
| 39 |
+
4. If the context does not contain enough information, say so and provide the best general explanation you can.
|
| 40 |
+
|
| 41 |
+
**Context**
|
| 42 |
+
{context}
|
| 43 |
+
|
| 44 |
+
**User Question**
|
| 45 |
+
{question}
|
| 46 |
+
|
| 47 |
+
**Answer (with citations)**
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def __init__(self, config: Config):
|
| 51 |
+
self.config = config
|
| 52 |
+
|
| 53 |
+
# Retrieval + rerank
|
| 54 |
+
self.retrieval_service = RetrievalService(config)
|
| 55 |
+
|
| 56 |
+
rr_cfg = RerankerConfig(
|
| 57 |
+
enabled=config.get_bool("reranker.enabled", True),
|
| 58 |
+
base_url=config.get_str("reranker.api_base_url"),
|
| 59 |
+
api_key=config.get_str("reranker.api_key"),
|
| 60 |
+
model_name=config.get_str("reranker.model_name"),
|
| 61 |
+
top_k=config.get_int("reranker.top_k", 10),
|
| 62 |
+
)
|
| 63 |
+
self.reranker_service = RerankerService(rr_cfg)
|
| 64 |
+
|
| 65 |
+
self.multi_source_retrieval_service = MultiSourceRetrievalService(
|
| 66 |
+
config=config,
|
| 67 |
+
retrieval_service=self.retrieval_service,
|
| 68 |
+
reranker_service=self.reranker_service,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
self.citation_manager = CitationManager(max_content_length=config.get_int("citation.max_content_length", 900))
|
| 72 |
+
|
| 73 |
+
# LLM (OpenAI-compatible)
|
| 74 |
+
self.llm = ChatOpenAI(
|
| 75 |
+
base_url=config.get_str("llm.base_url"),
|
| 76 |
+
api_key=config.get_str("llm.api_key"),
|
| 77 |
+
model=config.get_str("llm.model_name"),
|
| 78 |
+
temperature=config.get_float("llm.temperature", 0.7),
|
| 79 |
+
max_tokens=config.get_int("llm.max_tokens", 2000),
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
self.prompt = PromptTemplate(template=self.RAG_PROMPT_TEMPLATE, input_variables=["context", "question"])
|
| 83 |
+
|
| 84 |
+
@staticmethod
|
| 85 |
+
def _normalize_retrieval_strategy(strategy: Optional[str]) -> str:
|
| 86 |
+
s = (strategy or "").strip() or "default"
|
| 87 |
+
if s not in {"default", "balanced_multi_source"}:
|
| 88 |
+
logger.warning(f"Unknown retrieval strategy '{s}', falling back to 'default'")
|
| 89 |
+
return "default"
|
| 90 |
+
return s
|
| 91 |
+
|
| 92 |
+
@staticmethod
|
| 93 |
+
def _format_context(documents: List[Dict[str, Any]], citation_indices: List[int]) -> str:
|
| 94 |
+
parts: List[str] = []
|
| 95 |
+
for doc, idx in zip(documents, citation_indices):
|
| 96 |
+
source_type = (doc.get("source_type") or "").upper()
|
| 97 |
+
title = doc.get("title") or "Untitled"
|
| 98 |
+
content = doc.get("content") or ""
|
| 99 |
+
url = doc.get("url") or ""
|
| 100 |
+
|
| 101 |
+
block = f"[{idx}] **{source_type}: {title}**\n{content}"
|
| 102 |
+
if url:
|
| 103 |
+
block += f"\nURL: {url}"
|
| 104 |
+
parts.append(block)
|
| 105 |
+
|
| 106 |
+
return "\n\n---\n\n".join(parts)
|
| 107 |
+
|
| 108 |
+
def _retrieve(self, *, question: str, top_k: Optional[int], source_filters: Optional[List[str]], retrieval_strategy: str):
|
| 109 |
+
if retrieval_strategy == "balanced_multi_source":
|
| 110 |
+
final_k = int(top_k or self.config.get_int("retrieval.multi_source.total_top_k", 8))
|
| 111 |
+
return self.multi_source_retrieval_service.retrieve(
|
| 112 |
+
query=question,
|
| 113 |
+
total_top_k=final_k,
|
| 114 |
+
source_filters=source_filters,
|
| 115 |
+
return_debug=True,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
k = int(top_k or self.config.get_int("retrieval.top_k", 20))
|
| 119 |
+
docs = self.retrieval_service.retrieve(query=question, top_k=k, source_filters=source_filters)
|
| 120 |
+
return docs, None
|
| 121 |
+
|
| 122 |
+
def query_stream(
|
| 123 |
+
self,
|
| 124 |
+
*,
|
| 125 |
+
question: str,
|
| 126 |
+
top_k: Optional[int] = None,
|
| 127 |
+
source_filters: Optional[List[str]] = None,
|
| 128 |
+
retrieval_strategy: Optional[str] = None,
|
| 129 |
+
stream_yield_interval_s: float = 0.15,
|
| 130 |
+
stream_min_chars: int = 80,
|
| 131 |
+
) -> Iterator[Dict[str, Any]]:
|
| 132 |
+
q = (question or "").strip()
|
| 133 |
+
if not q:
|
| 134 |
+
yield {
|
| 135 |
+
"type": "final",
|
| 136 |
+
"answer": "Please enter a question.",
|
| 137 |
+
"references": [],
|
| 138 |
+
"metadata": {"num_retrieved": 0, "num_reranked": 0, "retrieval_strategy": "default"},
|
| 139 |
+
}
|
| 140 |
+
return
|
| 141 |
+
|
| 142 |
+
self.citation_manager.clear()
|
| 143 |
+
retrieval_strategy_n = self._normalize_retrieval_strategy(
|
| 144 |
+
retrieval_strategy or self.config.get_str("retrieval.strategy", "default")
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
start_time = time.time()
|
| 148 |
+
|
| 149 |
+
# Step 1: retrieve
|
| 150 |
+
docs, retrieval_debug = self._retrieve(
|
| 151 |
+
question=q, top_k=top_k, source_filters=source_filters, retrieval_strategy=retrieval_strategy_n
|
| 152 |
+
)
|
| 153 |
+
if not docs:
|
| 154 |
+
yield {
|
| 155 |
+
"type": "final",
|
| 156 |
+
"answer": "I couldn't find any relevant information to answer your question.",
|
| 157 |
+
"references": [],
|
| 158 |
+
"metadata": {"num_retrieved": 0, "num_reranked": 0, "retrieval_strategy": retrieval_strategy_n},
|
| 159 |
+
}
|
| 160 |
+
return
|
| 161 |
+
|
| 162 |
+
# Step 2: rerank (default strategy only; balanced already reranks per-source)
|
| 163 |
+
if retrieval_strategy_n == "balanced_multi_source":
|
| 164 |
+
reranked_docs = sorted(docs, key=lambda d: float(d.get("score", 0.0)), reverse=True)
|
| 165 |
+
num_retrieved = int(sum((retrieval_debug.get("candidate_counts") or {}).values())) if retrieval_debug else len(docs)
|
| 166 |
+
retrieved_label = "Recalled (candidates)"
|
| 167 |
+
reranked_label = "Selected (after rerank)"
|
| 168 |
+
else:
|
| 169 |
+
reranked_docs = self.reranker_service.rerank(
|
| 170 |
+
query=q, documents=docs, top_k=self.config.get_int("reranker.top_k", 10)
|
| 171 |
+
)
|
| 172 |
+
num_retrieved = len(docs)
|
| 173 |
+
retrieved_label = "Retrieved"
|
| 174 |
+
reranked_label = "After Reranking"
|
| 175 |
+
|
| 176 |
+
# Step 3: citations + context + prompt
|
| 177 |
+
citation_indices = self.citation_manager.add_documents(reranked_docs)
|
| 178 |
+
context = self._format_context(reranked_docs, citation_indices)
|
| 179 |
+
prompt_text = self.prompt.format(context=context, question=q)
|
| 180 |
+
|
| 181 |
+
# Step 4: stream LLM
|
| 182 |
+
answer_parts: List[str] = []
|
| 183 |
+
buffered = ""
|
| 184 |
+
last_yield = time.monotonic()
|
| 185 |
+
|
| 186 |
+
try:
|
| 187 |
+
for chunk in self.llm.stream(prompt_text):
|
| 188 |
+
delta = getattr(chunk, "content", "") or ""
|
| 189 |
+
if not delta:
|
| 190 |
+
continue
|
| 191 |
+
answer_parts.append(delta)
|
| 192 |
+
buffered += delta
|
| 193 |
+
now = time.monotonic()
|
| 194 |
+
if (now - last_yield) >= float(stream_yield_interval_s) or len(buffered) >= int(stream_min_chars):
|
| 195 |
+
yield {"type": "answer", "answer": "".join(answer_parts)}
|
| 196 |
+
buffered = ""
|
| 197 |
+
last_yield = now
|
| 198 |
+
|
| 199 |
+
answer = "".join(answer_parts).strip()
|
| 200 |
+
if not answer:
|
| 201 |
+
response = self.llm.invoke(prompt_text)
|
| 202 |
+
answer = (response.content or "").strip()
|
| 203 |
+
except Exception as e:
|
| 204 |
+
logger.error(f"LLM streaming failed: {e}", exc_info=True)
|
| 205 |
+
try:
|
| 206 |
+
response = self.llm.invoke(prompt_text)
|
| 207 |
+
answer = (response.content or "").strip()
|
| 208 |
+
except Exception:
|
| 209 |
+
answer = "An error occurred while generating the answer. Please try again."
|
| 210 |
+
|
| 211 |
+
is_valid, invalid = self.citation_manager.validate_citations(answer)
|
| 212 |
+
elapsed = time.time() - start_time
|
| 213 |
+
source_dist = self.citation_manager.get_statistics().get("source_type_counts", {})
|
| 214 |
+
|
| 215 |
+
yield {
|
| 216 |
+
"type": "final",
|
| 217 |
+
"answer": answer,
|
| 218 |
+
"references": reranked_docs,
|
| 219 |
+
"metadata": {
|
| 220 |
+
"retrieval_strategy": retrieval_strategy_n,
|
| 221 |
+
"num_retrieved": num_retrieved,
|
| 222 |
+
"num_reranked": len(reranked_docs),
|
| 223 |
+
"retrieved_label": retrieved_label,
|
| 224 |
+
"reranked_label": reranked_label,
|
| 225 |
+
"citations_valid": is_valid,
|
| 226 |
+
"invalid_citations": invalid,
|
| 227 |
+
"source_type_distribution": source_dist,
|
| 228 |
+
"elapsed_time": elapsed,
|
| 229 |
+
"candidate_counts": (retrieval_debug.get("candidate_counts") if retrieval_debug else None),
|
| 230 |
+
"gated_counts": (retrieval_debug.get("gated_counts") if retrieval_debug else None),
|
| 231 |
+
"selected_counts": (retrieval_debug.get("selected_counts") if retrieval_debug else None),
|
| 232 |
+
},
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
def query(
|
| 236 |
+
self,
|
| 237 |
+
*,
|
| 238 |
+
question: str,
|
| 239 |
+
top_k: Optional[int] = None,
|
| 240 |
+
source_filters: Optional[List[str]] = None,
|
| 241 |
+
retrieval_strategy: Optional[str] = None,
|
| 242 |
+
) -> Dict[str, Any]:
|
| 243 |
+
"""Non-streaming convenience wrapper."""
|
| 244 |
+
final: Dict[str, Any] = {}
|
| 245 |
+
for event in self.query_stream(
|
| 246 |
+
question=question,
|
| 247 |
+
top_k=top_k,
|
| 248 |
+
source_filters=source_filters,
|
| 249 |
+
retrieval_strategy=retrieval_strategy,
|
| 250 |
+
stream_yield_interval_s=999.0,
|
| 251 |
+
stream_min_chars=10**9,
|
| 252 |
+
):
|
| 253 |
+
if event.get("type") == "final":
|
| 254 |
+
final = event
|
| 255 |
+
return final
|
| 256 |
+
|
| 257 |
+
|
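
The query pipeline above (retrieve → rerank → cite → stream) is consumed either event-by-event or through the `query()` wrapper. A minimal driver sketch, assuming the repo's `config/default_config.yaml` and valid embedding/LLM API secrets are already resolvable by `Config`:

```python
# Driver sketch for RAGEngine (assumes the Space's config file plus
# EMBED_API_KEY / LLM_API_KEY style secrets are available to Config).
from radiology_rag.config import Config
from radiology_rag.rag import RAGEngine

engine = RAGEngine(Config("config/default_config.yaml"))

# Streaming consumption: "answer" events carry the growing partial text,
# the single "final" event adds references and metadata.
for event in engine.query_stream(question="What is achalasia?"):
    if event["type"] == "answer":
        pass  # e.g. update a UI with event["answer"]
    elif event["type"] == "final":
        print(event["answer"])
        for i, ref in enumerate(event["references"], 1):
            print(f'[{i}] {ref["title"]} ({ref["source_type"]})')

# Non-streaming: query() simply drains the stream and keeps the final event.
result = engine.query(question="What is achalasia?", retrieval_strategy="default")
print(result["metadata"]["num_reranked"])
```
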
radiology_rag/reranker.py
ADDED
@@ -0,0 +1,143 @@
+"""
+Reranker client (OpenAI-compatible API).
+
+Expected endpoint:
+    POST {base_url}/rerank
+    {
+        "model": "...",
+        "query": "...",
+        "documents": ["...", "..."],
+        "top_n": 10
+    }
+"""
+
+from __future__ import annotations
+
+import logging
+import os
+import time
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+import requests
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class RerankerConfig:
+    enabled: bool
+    base_url: str
+    api_key: str
+    model_name: str
+    top_k: int = 10
+
+
+class NoOpReranker:
+    def rerank(self, query: str, documents: List[Dict[str, Any]], top_k: Optional[int] = None) -> List[Dict[str, Any]]:
+        k = int(top_k or len(documents))
+        out = []
+        for d in documents[:k]:
+            dc = dict(d)
+            dc.setdefault("score", 0.0)
+            out.append(dc)
+        return out
+
+
+class APIReranker:
+    def __init__(self, cfg: RerankerConfig):
+        self.cfg = cfg
+
+    def rerank(self, query: str, documents: List[Dict[str, Any]], top_k: Optional[int] = None) -> List[Dict[str, Any]]:
+        if not documents:
+            return []
+
+        texts = [doc.get("content", "") for doc in documents]
+        k = int(top_k or self.cfg.top_k or len(documents))
+        k = max(1, min(k, len(documents)))
+
+        max_docs_per_request = int(os.getenv("RERANK_MAX_DOCS_PER_REQUEST", "64"))
+        max_docs_per_request = max(1, max_docs_per_request)
+
+        def _call_once(doc_texts: List[str], top_n: int) -> dict:
+            last_err: Optional[Exception] = None
+            for attempt in range(3):
+                try:
+                    resp = requests.post(
+                        f"{self.cfg.base_url.rstrip('/')}/rerank",
+                        json={
+                            "model": self.cfg.model_name,
+                            "query": query,
+                            "documents": doc_texts,
+                            "top_n": int(top_n),
+                        },
+                        headers={"Authorization": f"Bearer {self.cfg.api_key}"},
+                        timeout=30,
+                    )
+                    resp.raise_for_status()
+                    return resp.json() or {}
+                except Exception as e:
+                    last_err = e
+                    if attempt < 2:
+                        time.sleep(0.5 * (attempt + 1))
+            raise last_err or RuntimeError("Unknown reranker API error")
+
+        try:
+            if len(texts) <= max_docs_per_request:
+                result = _call_once(texts, top_n=k)
+                reranked_docs: List[Dict[str, Any]] = []
+                for item in result.get("results", []) or []:
+                    idx = item.get("index")
+                    score = float(item.get("relevance_score", 0.0) or 0.0)
+                    if idx is None:
+                        continue
+                    idx = int(idx)
+                    if idx < 0 or idx >= len(documents):
+                        continue
+                    dc = dict(documents[idx])
+                    dc["score"] = score
+                    reranked_docs.append(dc)
+                return reranked_docs
+
+            # Chunked: score all docs per chunk then globally sort.
+            scored: List[Dict[str, Any]] = []
+            for offset in range(0, len(texts), max_docs_per_request):
+                chunk_texts = texts[offset : offset + max_docs_per_request]
+                result = _call_once(chunk_texts, top_n=len(chunk_texts))
+                for item in result.get("results", []) or []:
+                    idx = item.get("index")
+                    if idx is None:
+                        continue
+                    global_idx = offset + int(idx)
+                    if global_idx < 0 or global_idx >= len(documents):
+                        continue
+                    score = float(item.get("relevance_score", 0.0) or 0.0)
+                    dc = dict(documents[global_idx])
+                    dc["score"] = score
+                    scored.append(dc)
+
+            scored.sort(key=lambda x: float(x.get("score", 0.0)), reverse=True)
+            return scored[:k]
+        except Exception as e:
+            logger.warning(f"Reranker API failed; falling back to no-op ordering. Error: {e}")
+            return NoOpReranker().rerank(query, documents, top_k=k)
+
+
+class RerankerService:
+    """High-level reranker wrapper (auto-disables if misconfigured)."""
+
+    def __init__(self, cfg: RerankerConfig):
+        self.cfg = cfg
+        if not cfg.enabled:
+            self._impl = NoOpReranker()
+            return
+        if not (cfg.api_key or "").strip():
+            logger.warning("Reranker enabled but RERANK_API_KEY is empty; disabling reranker.")
+            self._impl = NoOpReranker()
+            return
+        self._impl = APIReranker(cfg)
+
+    def rerank(self, query: str, documents: List[Dict[str, Any]], top_k: Optional[int] = None) -> List[Dict[str, Any]]:
+        return self._impl.rerank(query, documents, top_k=top_k)
+
+
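
A quick usage sketch for the service above. The base URL, API key, and model name below are hypothetical placeholders, not values shipped with this repo; any endpoint matching the `/rerank` contract in the module docstring should work:

```python
# Sketch: reranking a few candidate dicts. The dicts' "content" field is
# what gets sent to the /rerank endpoint; everything else is passed through.
from radiology_rag.reranker import RerankerConfig, RerankerService

svc = RerankerService(
    RerankerConfig(
        enabled=True,
        base_url="https://api.example.com/v1",  # hypothetical endpoint
        api_key="sk-...",                       # normally the RERANK_API_KEY secret
        model_name="my-reranker",               # hypothetical model id
        top_k=2,
    )
)

docs = [
    {"doc_id": "a", "content": "Achalasia is an esophageal motility disorder ..."},
    {"doc_id": "b", "content": "Zenker's diverticulum arises at Killian's triangle ..."},
    {"doc_id": "c", "content": "Barrett's esophagus is a metaplastic change ..."},
]
top = svc.rerank(query="achalasia imaging", documents=docs, top_k=2)
# Each returned dict is a copy of the input with a "score" field attached;
# on API failure the original order is kept (no-op fallback, score 0.0).
```
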
radiology_rag/retrieval.py
ADDED
@@ -0,0 +1,403 @@
+"""Vector retrieval + balanced multi-source retrieval (with Wikipedia)."""
+
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Sequence, Tuple
+
+from langchain_chroma import Chroma
+
+from radiology_rag.config import Config
+from radiology_rag.doc_store import PersistentDocStore
+from radiology_rag.embedding import EmbeddingClient, EmbeddingConfig
+from radiology_rag.encyclopedia import WikipediaConfig, WikipediaEncyclopediaService
+from radiology_rag.reranker import RerankerService, RerankerConfig
+
+logger = logging.getLogger(__name__)
+
+
+def _dedupe_by_doc_id(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    seen = set()
+    out = []
+    for d in docs:
+        did = d.get("doc_id")
+        if not did or did in seen:
+            continue
+        seen.add(did)
+        out.append(d)
+    return out
+
+
+class RetrievalService:
+    """Retrieve parent documents from Chroma + SQLite doc store."""
+
+    def __init__(self, config: Config):
+        self.config = config
+
+        vector_db_path = config.get_str("storage.vector_db_path")
+        doc_store_path = config.get_str("storage.doc_store_path")
+
+        self.embedding_client = EmbeddingClient(
+            EmbeddingConfig(
+                base_url=config.get_str("embedding.api_base_url"),
+                api_key=config.get_str("embedding.api_key"),
+                model_name=config.get_str("embedding.model_name"),
+                batch_size=config.get_int("embedding.batch_size", 32),
+            )
+        )
+
+        self.doc_store = PersistentDocStore(doc_store_path, read_only=True)
+        self.vectorstore = Chroma(
+            collection_name="radiology_docs",
+            embedding_function=self.embedding_client.langchain_embeddings,
+            persist_directory=vector_db_path,
+        )
+
+        self.search_type = config.get_str("retrieval.search_type", "similarity")
+        self.chunk_fetch_multiplier = config.get_int("retrieval.chunk_fetch_multiplier", 6)
+
+    def retrieve_candidates_by_vector(
+        self,
+        *,
+        query_embedding: List[float],
+        candidate_k: int,
+        source_type_filter: Optional[str] = None,
+    ) -> List[Dict[str, Any]]:
+        k = max(1, int(candidate_k))
+        chunk_k = max(k * int(self.chunk_fetch_multiplier), k)
+
+        chunk_filter = {"source_type": source_type_filter} if source_type_filter else None
+
+        # Fetch chunks from vector store using the embedding vector
+        if self.search_type == "mmr":
+            mmr_lambda = self.config.get_float("retrieval.mmr_lambda", 0.5)
+            fetch_k = max(self.config.get_int("retrieval.mmr_fetch_k", 50), chunk_k)
+            try:
+                chunk_docs = self.vectorstore.max_marginal_relevance_search_by_vector(
+                    query_embedding,
+                    k=chunk_k,
+                    fetch_k=fetch_k,
+                    lambda_mult=mmr_lambda,
+                    filter=chunk_filter,
+                )
+            except TypeError:
+                chunk_docs = self.vectorstore.max_marginal_relevance_search_by_vector(
+                    query_embedding,
+                    k=chunk_k,
+                    fetch_k=fetch_k,
+                    lambda_mult=mmr_lambda,
+                )
+        else:
+            try:
+                chunk_docs = self.vectorstore.similarity_search_by_vector(
+                    query_embedding,
+                    k=chunk_k,
+                    filter=chunk_filter,
+                )
+            except TypeError:
+                chunk_docs = self.vectorstore.similarity_search_by_vector(
+                    query_embedding,
+                    k=chunk_k,
+                )
+
+        # Unique parent IDs
+        parent_ids: List[str] = []
+        seen = set()
+        for doc in chunk_docs:
+            parent_id = (doc.metadata or {}).get("parent_id")
+            if not parent_id or parent_id in seen:
+                continue
+            seen.add(parent_id)
+            parent_ids.append(parent_id)
+            if len(parent_ids) >= k:
+                break
+
+        # Hydrate parents from doc store
+        parent_docs = self.doc_store.mget(parent_ids)
+        results: List[Dict[str, Any]] = []
+        for doc_id, doc_content in zip(parent_ids, parent_docs):
+            if doc_content is None:
+                continue
+            complete = doc_content.get("complete_document", {}) or {}
+            results.append(
+                {
+                    "doc_id": doc_id,
+                    "source_type": doc_content.get("source_type", "") or "",
+                    "title": complete.get("title", "") or "",
+                    "content": doc_content.get("main_content", "") or "",
+                    "url": complete.get("url", "") or "",
+                    "metadata": complete.get("metadata", {}) or {},
+                    "score": 0.0,
+                }
+            )
+
+        logger.info(
+            f"Retrieved {len(results)} candidate parents (candidate_k={k}, source_type_filter={source_type_filter or 'ALL'})"
+        )
+        return results
+
+    def retrieve(
+        self,
+        *,
+        query: str,
+        top_k: int,
+        source_filters: Optional[List[str]] = None,
+    ) -> List[Dict[str, Any]]:
+        q = (query or "").strip()
+        if not q:
+            return []
+
+        query_embedding = self.embedding_client.embed_query(q)
+        k = max(1, int(top_k))
+
+        if not source_filters:
+            return self.retrieve_candidates_by_vector(query_embedding=query_embedding, candidate_k=k)
+
+        allowed = [s for s in source_filters if s in {"article", "case", "tutorial"}]
+        if not allowed:
+            # If only encyclopedia is selected, local retrieval yields none.
+            return []
+
+        if len(allowed) == 1:
+            return self.retrieve_candidates_by_vector(query_embedding=query_embedding, candidate_k=k, source_type_filter=allowed[0])
+
+        merged: List[Dict[str, Any]] = []
+        for st in allowed:
+            merged.extend(
+                self.retrieve_candidates_by_vector(query_embedding=query_embedding, candidate_k=k, source_type_filter=st)
+            )
+        merged = _dedupe_by_doc_id(merged)
+        return merged[:k]
+
+    def get_document_by_id(self, doc_id: str) -> Optional[Dict[str, Any]]:
+        doc_id = (doc_id or "").strip()
+        if not doc_id:
+            return None
+        docs = self.doc_store.mget([doc_id])
+        if not docs or not docs[0]:
+            return None
+        doc_content = docs[0]
+        complete = doc_content.get("complete_document", {}) or {}
+        return {
+            "doc_id": doc_id,
+            "source_type": doc_content.get("source_type", "") or "",
+            "title": complete.get("title", "") or "",
+            "content": doc_content.get("main_content", "") or "",
+            "url": complete.get("url", "") or "",
+            "metadata": complete.get("metadata", {}) or {},
+        }
+
+
+@dataclass
+class PerSourcePolicy:
+    candidate_k: int
+    max_k: int
+    min_score: float
+    required: bool = False
+
+
+@dataclass
+class BalancedMultiSourcePolicy:
+    total_top_k: int = 8
+    sources_priority: Sequence[str] = ("article", "case", "encyclopedia", "tutorial")
+    article: PerSourcePolicy = field(
+        default_factory=lambda: PerSourcePolicy(candidate_k=200, max_k=3, min_score=0.15, required=True)
+    )
+    case: PerSourcePolicy = field(
+        default_factory=lambda: PerSourcePolicy(candidate_k=200, max_k=3, min_score=0.15, required=True)
+    )
+    encyclopedia: PerSourcePolicy = field(
+        default_factory=lambda: PerSourcePolicy(candidate_k=8, max_k=2, min_score=0.15, required=True)
+    )
+    tutorial: PerSourcePolicy = field(
+        default_factory=lambda: PerSourcePolicy(candidate_k=20, max_k=2, min_score=0.50, required=False)
+    )
+
+
+class MultiSourceRetrievalService:
+    """Per-source recall + per-source rerank + gating + merge (includes Wikipedia)."""
+
+    def __init__(
+        self,
+        config: Config,
+        retrieval_service: Optional[RetrievalService] = None,
+        reranker_service: Optional[RerankerService] = None,
+        encyclopedia_service: Optional[WikipediaEncyclopediaService] = None,
+    ):
+        self.config = config
+        self.retrieval_service = retrieval_service or RetrievalService(config)
+
+        rr_cfg = RerankerConfig(
+            enabled=config.get_bool("reranker.enabled", True),
+            base_url=config.get_str("reranker.api_base_url"),
+            api_key=config.get_str("reranker.api_key"),
+            model_name=config.get_str("reranker.model_name"),
+            top_k=config.get_int("reranker.top_k", 10),
+        )
+        self.reranker_service = reranker_service or RerankerService(rr_cfg)
+
+        wiki_cfg = WikipediaConfig(
+            language=config.get_str("encyclopedia.wikipedia.language", "en"),
+            user_agent=config.get_str("encyclopedia.wikipedia.user_agent", "RadiologyRAG-Space/1.0"),
+            timeout_s=config.get_int("encyclopedia.wikipedia.timeout_s", 15),
+            max_chars_per_doc=config.get_int("encyclopedia.wikipedia.max_chars_per_doc", 2000),
+        )
+        self.encyclopedia_service = encyclopedia_service or WikipediaEncyclopediaService(wiki_cfg)
+
+    def _load_policy(self, total_top_k: Optional[int] = None) -> BalancedMultiSourcePolicy:
+        total = int(total_top_k or self.config.get_int("retrieval.multi_source.total_top_k", 8))
+        total = max(1, total)
+
+        def pol(name: str, default: PerSourcePolicy) -> PerSourcePolicy:
+            base = f"retrieval.multi_source.{name}"
+            return PerSourcePolicy(
+                candidate_k=self.config.get_int(f"{base}.candidate_k", default.candidate_k),
+                max_k=self.config.get_int(f"{base}.max_k", default.max_k),
+                min_score=self.config.get_float(f"{base}.min_score", default.min_score),
+                required=self.config.get_bool(f"{base}.required", default.required),
+            )
+
+        defaults = BalancedMultiSourcePolicy(total_top_k=total)
+        sources_priority = self.config.get("retrieval.multi_source.sources_priority", defaults.sources_priority)
+        sources_priority = tuple(sources_priority) if isinstance(sources_priority, (list, tuple)) else defaults.sources_priority
+
+        return BalancedMultiSourcePolicy(
+            total_top_k=total,
+            sources_priority=sources_priority,
+            article=pol("article", defaults.article),
+            case=pol("case", defaults.case),
+            encyclopedia=pol("encyclopedia", defaults.encyclopedia),
+            tutorial=pol("tutorial", defaults.tutorial),
+        )
+
+    @staticmethod
+    def _gate_and_trim(docs: List[Dict[str, Any]], policy: PerSourcePolicy) -> List[Dict[str, Any]]:
+        filtered = [d for d in docs if float(d.get("score", 0.0)) >= float(policy.min_score)]
+        return filtered[: max(0, int(policy.max_k))]
+
+    def retrieve(
+        self,
+        *,
+        query: str,
+        total_top_k: Optional[int] = None,
+        source_filters: Optional[List[str]] = None,
+        return_debug: bool = False,
+    ) -> Any:
+        q = (query or "").strip()
+        if not q:
+            return ([], {}) if return_debug else []
+
+        policy = self._load_policy(total_top_k=total_top_k)
+        allowed = set(source_filters) if source_filters else set(policy.sources_priority)
+        allowed = {s for s in allowed if s in {"article", "case", "tutorial", "encyclopedia"}}
+        if not allowed:
+            allowed = set(policy.sources_priority)
+
+        # Compute query embedding once for local sources
+        needs_local = any(s in allowed for s in {"article", "case", "tutorial"})
+        query_embedding = self.retrieval_service.embedding_client.embed_query(q) if needs_local else None
+
+        # 1) Recall candidates
+        candidates: Dict[str, List[Dict[str, Any]]] = {}
+        if "article" in allowed and query_embedding is not None:
+            candidates["article"] = self.retrieval_service.retrieve_candidates_by_vector(
+                query_embedding=query_embedding, candidate_k=policy.article.candidate_k, source_type_filter="article"
+            )
+        if "case" in allowed and query_embedding is not None:
+            candidates["case"] = self.retrieval_service.retrieve_candidates_by_vector(
+                query_embedding=query_embedding, candidate_k=policy.case.candidate_k, source_type_filter="case"
+            )
+        if "tutorial" in allowed and query_embedding is not None:
+            candidates["tutorial"] = self.retrieval_service.retrieve_candidates_by_vector(
+                query_embedding=query_embedding, candidate_k=policy.tutorial.candidate_k, source_type_filter="tutorial"
+            )
+        if "encyclopedia" in allowed:
+            candidates["encyclopedia"] = self.encyclopedia_service.retrieve(
+                q,
+                top_k=policy.encyclopedia.candidate_k,
+                max_chars_per_doc=self.config.get_int("encyclopedia.wikipedia.max_chars_per_doc", 2000),
+            )
+
+        # 2) Rerank per-source (full ordering; trim later)
+        reranked: Dict[str, List[Dict[str, Any]]] = {}
+        for src, docs in candidates.items():
+            if not docs:
+                reranked[src] = []
+                continue
+            rr = self.reranker_service.rerank(query=q, documents=docs, top_k=len(docs))
+            reranked[src] = rr
+
+        # 3) Gating
+        gated: Dict[str, List[Dict[str, Any]]] = {}
+        if "article" in reranked:
+            gated["article"] = self._gate_and_trim(reranked["article"], policy.article)
+        if "case" in reranked:
+            gated["case"] = self._gate_and_trim(reranked["case"], policy.case)
+        if "encyclopedia" in reranked:
+            gated["encyclopedia"] = self._gate_and_trim(reranked["encyclopedia"], policy.encyclopedia)
+
+        if "tutorial" in reranked:
+            tdocs = reranked["tutorial"]
+            best = float(tdocs[0].get("score", 0.0)) if tdocs else 0.0
+            if best >= policy.tutorial.min_score:
+                gated["tutorial"] = tdocs[: max(0, int(policy.tutorial.max_k))]
+            else:
+                gated["tutorial"] = []
+
+        # 4) Merge with global budget
+        selected: List[Dict[str, Any]] = []
+        selected_ids = set()
+
+        def _add(doc: Dict[str, Any]) -> None:
+            did = doc.get("doc_id")
+            if not did or did in selected_ids:
+                return
+            selected.append(doc)
+            selected_ids.add(did)
+
+        src_to_pol = {
+            "article": policy.article,
+            "case": policy.case,
+            "encyclopedia": policy.encyclopedia,
+            "tutorial": policy.tutorial,
+        }
+
+        # required pass
+        for src in policy.sources_priority:
+            if src not in allowed:
+                continue
+            p = src_to_pol[src]
+            if not p.required:
+                continue
+            if gated.get(src):
+                _add(gated[src][0])
+
+        # fill by global score
+        remaining_pool: List[Dict[str, Any]] = []
+        for docs in gated.values():
+            for d in docs:
+                if d.get("doc_id") not in selected_ids:
+                    remaining_pool.append(d)
+        remaining_pool.sort(key=lambda x: float(x.get("score", 0.0)), reverse=True)
+
+        for d in remaining_pool:
+            if len(selected) >= policy.total_top_k:
+                break
+            _add(d)
+
+        debug = {
+            "allowed_sources": sorted(list(allowed)),
+            "candidate_counts": {k: len(v) for k, v in candidates.items()},
+            "gated_counts": {k: len(v) for k, v in gated.items()},
+            "selected_counts": {},
+        }
+        for d in selected:
+            st = d.get("source_type", "unknown")
+            debug["selected_counts"][st] = debug["selected_counts"].get(st, 0) + 1
+
+        if return_debug:
+            return selected, debug
+        return selected
+
+
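
To see the recall → rerank → gate → merge stages concretely, here is a sketch of calling the balanced retriever directly with `return_debug=True` (assuming a built index on disk and the repo's default config):

```python
# Sketch: balanced multi-source retrieval with per-stage counts.
from radiology_rag.config import Config
from radiology_rag.retrieval import MultiSourceRetrievalService

svc = MultiSourceRetrievalService(Config("config/default_config.yaml"))
docs, debug = svc.retrieve(
    query="CT findings of esophageal cancer",
    total_top_k=8,
    source_filters=["article", "case", "encyclopedia"],
    return_debug=True,
)

# debug mirrors what rag.py forwards into its final metadata:
#   candidate_counts -> recalled per source,
#   gated_counts     -> after min_score / max_k trimming,
#   selected_counts  -> after the required pass + global top_k merge.
print(debug["candidate_counts"], debug["gated_counts"], debug["selected_counts"])
for d in docs:
    print(f'[{d["source_type"]}] {d["title"]} (score={d["score"]:.3f})')
```

The "required pass" guarantees each required source contributes its best gated document before the remaining slots are filled purely by score, which is what keeps the final context diverse even when one source dominates the scores.
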
radiology_rag/ui.py
ADDED
@@ -0,0 +1,423 @@
+"""Gradio UI for the Radiology RAG Space."""
+
+from __future__ import annotations
+
+import logging
+import os
+import re
+import time
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Optional, Tuple
+
+import gradio as gr
+
+from radiology_rag.config import Config
+from radiology_rag.index_bootstrap import ensure_index, read_manifest
+from radiology_rag.rag import RAGEngine
+
+logger = logging.getLogger(__name__)
+
+
+def _truncate(text: str, max_len: int) -> str:
+    s = (text or "").strip()
+    if len(s) <= max_len:
+        return s
+    return s[: max(0, max_len - 3)] + "..."
+
+
+def format_error_message(error: str) -> str:
+    return f"**⚠️ Error**\n\n{error}"
+
+
+def format_loading_message() -> str:
+    return "**🔄 Processing your query...**\n\nRetrieving relevant sources and generating an answer with citations."
+
+
+def format_reference_card(doc: Dict[str, Any], index: int) -> str:
+    title = doc.get("title", "Untitled") or "Untitled"
+    source_type = (doc.get("source_type") or "").upper()
+    url = doc.get("url", "") or ""
+    content = doc.get("content", "") or ""
+    score = float(doc.get("score", 0.0) or 0.0)
+
+    max_preview_length = 350
+    preview = _truncate(content, max_preview_length).replace("\n", " ")
+
+    type_colors = {
+        "ARTICLE": "#3b82f6",
+        "CASE": "#10b981",
+        "TUTORIAL": "#f59e0b",
+        "ENCYCLOPEDIA": "#8b5cf6",
+    }
+    color = type_colors.get(source_type, "#6b7280")
+
+    score_html = f"<span style='color:#6b7280;font-size:12px;'>Score: {score:.3f}</span>" if score > 0 else ""
+    url_html = (
+        f"<p style='margin:0 0 8px 0;font-size:12px;'><a href='{url}' target='_blank' "
+        f"style='color:#3b82f6;text-decoration:none;'>🔗 View Source</a></p>"
+        if url
+        else ""
+    )
+
+    return f"""
+<div id="ref-{index}" style="border:1px solid #e5e7eb;border-radius:8px;padding:16px;margin-bottom:16px;background:white;scroll-margin-top:90px;">
+    <div style="display:flex;align-items:center;gap:8px;margin-bottom:12px;flex-wrap:wrap;">
+        <span style="background:{color};color:white;padding:4px 12px;border-radius:12px;font-size:12px;font-weight:600;">
+            {source_type or "SOURCE"}
+        </span>
+        <span style="background:#f3f4f6;color:#374151;padding:4px 12px;border-radius:12px;font-size:12px;font-weight:600;">
+            [{index}]
+        </span>
+        {score_html}
+    </div>
+    <h3 style="margin:0 0 8px 0;color:#111827;font-size:18px;">{title}</h3>
+    {url_html}
+    <p style="margin:0;color:#4b5563;font-size:14px;line-height:1.5;">{preview}</p>
+</div>
+"""
+
+
+def format_reference_panel(references: List[Dict[str, Any]]) -> str:
+    if not references:
+        return "<p style='color:#6b7280;text-align:center;padding:20px;'>No references available</p>"
+    html_parts = ['<div style="max-height: 600px; overflow-y: auto;">']
+    for i, doc in enumerate(references, 1):
+        html_parts.append(format_reference_card(doc, i))
+    html_parts.append("</div>")
+    return "".join(html_parts)
+
+
+def format_statistics(metadata: Dict[str, Any]) -> str:
+    num_retrieved = int(metadata.get("num_retrieved", 0) or 0)
+    num_reranked = int(metadata.get("num_reranked", 0) or 0)
+    source_dist = metadata.get("source_type_distribution", {}) or {}
+    retrieved_label = metadata.get("retrieved_label", "Retrieved")
+    reranked_label = metadata.get("reranked_label", "After Reranking")
+    elapsed = float(metadata.get("elapsed_time", 0.0) or 0.0)
+    strategy = metadata.get("retrieval_strategy", "")
+
+    chips = "".join(
+        [
+            f"<span style='display:inline-block;background:#e5e7eb;color:#111827;padding:4px 8px;border-radius:4px;margin-right:8px;font-size:12px;line-height:1.2;'>{k}: {v}</span>"
+            for k, v in source_dist.items()
+        ]
+    )
+
+    return f"""
+<div style="background:#f9fafb;padding:16px;border-radius:8px;margin-top:16px;">
+    <h4 style="margin:0 0 12px 0;color:#374151;font-size:14px;">📊 Query Statistics</h4>
+    <div style="display:grid;grid-template-columns:repeat(auto-fit,minmax(150px,1fr));gap:12px;">
+        <div>
+            <p style="margin:0;color:#6b7280;font-size:12px;">{retrieved_label}</p>
+            <p style="margin:0;color:#111827;font-size:20px;font-weight:600;">{num_retrieved}</p>
+        </div>
+        <div>
+            <p style="margin:0;color:#6b7280;font-size:12px;">{reranked_label}</p>
+            <p style="margin:0;color:#111827;font-size:20px;font-weight:600;">{num_reranked}</p>
+        </div>
+        <div>
+            <p style="margin:0;color:#6b7280;font-size:12px;">Elapsed</p>
+            <p style="margin:0;color:#111827;font-size:20px;font-weight:600;">{elapsed:.2f}s</p>
+        </div>
+    </div>
+    <div style="margin-top:12px;">
+        <p style="margin:0 0 6px 0;color:#6b7280;font-size:12px;">Retrieval Strategy: <code>{strategy}</code></p>
+        <p style="margin:0 0 4px 0;color:#6b7280;font-size:12px;">Source Distribution:</p>
+        {chips if chips else "<span style='color:#6b7280;font-size:12px;'>N/A</span>"}
+    </div>
+</div>
+"""
+
+
+def create_settings_accordion(
+    *,
+    default_strategy: str,
+    default_temperature: float,
+    default_sources: List[str],
+) -> Tuple[gr.Radio, gr.Slider, gr.CheckboxGroup]:
+    with gr.Accordion("⚙️ Advanced Settings", open=False):
+        gr.Markdown(
+            "#### Retrieval Strategy\n"
+            "- **default**: one mixed retrieval + single rerank (fast)\n"
+            "- **balanced_multi_source**: per-source recall + per-source rerank + Wikipedia (more diverse)\n"
+        )
+
+        retrieval_strategy = gr.Radio(
+            choices=["default", "balanced_multi_source"],
+            value=default_strategy,
+            label="Retrieval Strategy",
+        )
+
+        temperature_slider = gr.Slider(
+            minimum=0.0,
+            maximum=1.0,
+            value=float(default_temperature),
+            step=0.1,
+            label="LLM Temperature",
+        )
+
+        source_filter = gr.CheckboxGroup(
+            choices=["article", "case", "tutorial", "encyclopedia"],
+            value=default_sources,
+            label="Filter by Source Type",
+        )
+
+    return retrieval_strategy, temperature_slider, source_filter
+
+
+class RadiologyRAGApp:
+    def __init__(self, config_path: str):
+        self.config = Config(config_path)
+        self.startup_error: Optional[str] = None
+        self.startup_warnings: List[str] = []
+        self.index_manifest: Optional[Dict[str, Any]] = None
+        self.rag_engine: Optional[RAGEngine] = None
+
+        # Validate required secrets
+        missing: List[str] = []
+        if not self.config.get_str("embedding.api_key"):
+            missing.append("EMBED_API_KEY")
+        if not self.config.get_str("llm.api_key"):
+            missing.append("LLM_API_KEY")
+        if missing:
+            self.startup_error = (
+                "Missing required Hugging Face Space Secrets: "
+                + ", ".join([f"`{m}`" for m in missing])
+                + ".\n\nPlease set them in the Space **Settings → Secrets** and restart the Space."
+            )
+            return
+
+        # Reranker is optional; warn if enabled but missing key
+        if self.config.get_bool("reranker.enabled", True) and not self.config.get_str("reranker.api_key"):
+            self.startup_warnings.append(
+                "Reranker is enabled but `RERANK_API_KEY` is missing. Reranking will be disabled (fallback to no-op)."
+            )
+
+        # Ensure index exists locally (download if needed)
+        try:
+            idx = ensure_index(
+                repo_id=self.config.get_str("index.repo_id"),
+                revision=self.config.get_str("index.revision", "main") or None,
+                target_vector_db_path=self.config.get_str("storage.vector_db_path"),
+                target_doc_store_path=self.config.get_str("storage.doc_store_path"),
+                storage_dir=str(Path(self.config.get_str("storage.doc_store_path")).parent),
+            )
+            self.index_manifest = read_manifest(idx.manifest_path)
+
+            # Optional: warn if embedding model differs
+            if self.index_manifest:
+                idx_model = (
+                    (self.index_manifest.get("embedding") or {}).get("model_name")
+                    or self.index_manifest.get("embedding_model")
+                    or ""
+                )
+                cfg_model = self.config.get_str("embedding.model_name")
+                if idx_model and cfg_model and idx_model != cfg_model:
+                    self.startup_warnings.append(
+                        f"Index embedding model mismatch: index='{idx_model}' vs config='{cfg_model}'. "
+                        "For best results, rebuild the index with the same embedding model."
+                    )
+
+        except Exception as e:
+            # Try to provide actionable guidance for common HF Hub errors.
+            repo_id = self.config.get_str("index.repo_id")
+            try:
+                from huggingface_hub.utils import (  # type: ignore
+                    GatedRepoError,
+                    HfHubHTTPError,
+                    RepositoryNotFoundError,
+                )
+
+                if isinstance(e, RepositoryNotFoundError):
+                    self.startup_error = (
+                        f"Index dataset repo not found: `{repo_id}`.\n\n"
+                        "If you haven't uploaded the prebuilt index yet, build and publish it locally:\n"
+                        "1) `pip install -r requirements-dev.txt`\n"
+                        "2) `python scripts/build_vector_db.py --config config/default_config.yaml --source huggingface --dataset ZhangNy/radiology-dataset --output-dir ./index_out`\n"
+                        f"3) `python scripts/publish_index_to_hf.py --repo {repo_id} --folder ./index_out --token $HF_TOKEN`\n\n"
+                        "Or set `RAG_INDEX_REPO_ID` to an existing index repo."
+                    )
+                    return
+                if isinstance(e, GatedRepoError):
+                    self.startup_error = (
+                        f"Index dataset repo is gated/private: `{repo_id}`.\n\n"
+                        "Make sure the repo is public, or provide authentication (HF token) in the environment."
+                    )
+                    return
+                if isinstance(e, HfHubHTTPError):
+                    self.startup_error = (
+                        f"Failed to download index from `{repo_id}`.\n\n"
+                        f"HF Hub error: {e}"
+                    )
+                    return
+            except Exception:
+                # If importing HF-specific exceptions fails, fall back to generic message.
+                pass
+
+            self.startup_error = (
+                f"Failed to prepare index from `{repo_id}`.\n\n"
+                f"Error: {e}"
+            )
+            return
+
+        # Build RAG engine
+        try:
+            self.rag_engine = RAGEngine(self.config)
+        except Exception as e:
+            self.startup_error = f"Failed to initialize RAG engine: {e}"
+            return
+
+    def process_query(
+        self,
+        question: str,
+        temperature: float,
+        source_filters: List[str],
+        retrieval_strategy: str,
+    ) -> Iterator[Tuple[str, str, str]]:
+        if self.startup_error:
+            yield format_error_message(self.startup_error), "", ""
+            return
+        if self.rag_engine is None:
+            yield format_error_message("RAG engine not initialized."), "", ""
+            return
+
+        q = (question or "").strip()
+        if not q:
+            yield format_error_message("Please enter a question."), "", ""
+            return
+
+        # Update LLM temperature on the fly
+        try:
+            self.rag_engine.llm.temperature = float(temperature)
+        except Exception:
+            pass
+
+        sources = source_filters or []
+        loading_md = (
+            f"{format_loading_message()}\n\n"
+            f"**Retrieval Strategy**: `{retrieval_strategy}`\n\n"
+            f"**Sources**: `{', '.join(sources) if sources else 'ALL'}`"
+        )
+        loading_refs = "<p style='color:#6b7280;text-align:center;padding:20px;'>Retrieving & reranking...</p>"
+        loading_stats = "<p style='color:#6b7280;padding:10px;'>Working...</p>"
+        yield loading_md, loading_refs, loading_stats
+
+        start_time = time.time()
+        last_partial = ""
+
+        try:
+            for event in self.rag_engine.query_stream(
+                question=q,
+                source_filters=sources if sources else None,
+                retrieval_strategy=retrieval_strategy,
+            ):
+                etype = (event or {}).get("type")
+                if etype == "answer":
+                    partial = (event.get("answer") or "")
+                    if partial and partial != last_partial:
+                        # Make citations clickable: [1] -> [1](#ref-1)
+                        answer_md = re.sub(r"\[(\d+)\](?!\()", r"[\1](#ref-\1)", partial)
+                        last_partial = partial
+                        yield answer_md, loading_refs, loading_stats
+                elif etype == "final":
+                    meta = event.get("metadata") or {}
+                    # The engine already populates elapsed_time; setdefault is just a safety net.
+                    meta.setdefault("elapsed_time", time.time() - start_time)
+
+                    final_answer = (event.get("answer") or "")
+                    answer_md = re.sub(r"\[(\d+)\](?!\()", r"[\1](#ref-\1)", final_answer)
+                    references_html = format_reference_panel(event.get("references") or [])
+                    stats_html = format_statistics(meta)
+                    yield answer_md, references_html, stats_html
+                    return
+
+            yield format_error_message("No response was generated. Please try again."), "", ""
+        except Exception as e:
+            logger.error(f"Error processing query: {e}", exc_info=True)
+            yield format_error_message(f"An error occurred: {e}"), "", ""
+
+    def create_interface(self) -> gr.Blocks:
+        title = self.config.get_str("ui.title", "Radiology RAG")
+        description = self.config.get_str("ui.description", "")
+        theme = self.config.get_str("ui.theme", "soft")
+        default_strategy = self.config.get_str("retrieval.strategy", "balanced_multi_source")
+        default_sources = self.config.get("retrieval.source_filters", ["article", "case", "tutorial", "encyclopedia"])
+        if not isinstance(default_sources, list):
+            default_sources = ["article", "case", "tutorial", "encyclopedia"]
+        default_temp = self.config.get_float("llm.temperature", 0.7)
+
+        with gr.Blocks(title=title, theme=theme) as interface:
+            gr.Markdown(f"# {title}")
+            if description:
+                gr.Markdown(description)
+
+            if self.startup_error:
+                gr.Markdown(format_error_message(self.startup_error))
+                gr.Markdown(
+                    "### Required Secrets\n"
+                    "- `EMBED_API_KEY`\n"
+                    "- `LLM_API_KEY`\n\n"
+                    "Optional (recommended):\n"
+                    "- `RERANK_API_KEY`\n"
+                )
+                return interface
+
+            if self.startup_warnings:
+                gr.Markdown("### ⚠️ Startup Warnings")
+                gr.Markdown("\n".join([f"- {w}" for w in self.startup_warnings]))
+
+            with gr.Row():
+                with gr.Column(scale=1):
+                    gr.Markdown("### Ask a Question")
+                    question_input = gr.Textbox(
+                        label="Your Question",
+                        placeholder="e.g., What is achalasia and how is it diagnosed?",
+                        lines=3,
+                    )
+
+                    retrieval_strategy, temperature_slider, source_filter = create_settings_accordion(
+                        default_strategy=default_strategy,
+                        default_temperature=default_temp,
+                        default_sources=default_sources,
+                    )
+
+                    submit_btn = gr.Button("Search & Answer", variant="primary", size="lg")
+
+                    gr.Markdown("### Example Questions")
+                    gr.Examples(
+                        examples=[
+                            ["What is achalasia and how is it diagnosed on imaging?"],
+                            ["Explain the imaging findings in Barrett's esophagus"],
+                            ["What are the characteristics of a Zenker's diverticulum?"],
+                            ["Describe the CT findings of esophageal cancer"],
+                        ],
+                        inputs=[question_input],
+                        label="Click an example to try it",
+                    )
+
+                with gr.Column(scale=2):
+                    gr.Markdown("### Answer (with citations)")
+                    answer_output = gr.Markdown(value="*Your answer will appear here...*")
+                    stats_output = gr.HTML(label="Statistics")
+                    gr.Markdown("### Retrieved References")
+                    references_output = gr.HTML(
+                        value="<p style='color:#6b7280;text-align:center;padding:20px;'>References will appear here...</p>"
+                    )
+
+            submit_btn.click(
+                fn=self.process_query,
+                inputs=[question_input, temperature_slider, source_filter, retrieval_strategy],
+                outputs=[answer_output, references_output, stats_output],
+            )
+
+            gr.Markdown("---")
+            with gr.Accordion("About", open=False):
+                gr.Markdown(
+                    "This Space demonstrates a radiology RAG system using a prebuilt vector index "
+                    f"(`{self.config.get_str('index.repo_id')}`) and external APIs for embeddings/LLM.\n\n"
+                    "**Disclaimer**: Educational use only. Always consult qualified professionals for clinical decisions."
+                )
+
+        return interface
+
+
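
A minimal standalone launch sketch for this UI class; the repo's `app.py` handles the real entry-point wiring, which may differ from this:

```python
# Sketch: launching the Gradio UI directly (config path matches the
# repo's config/default_config.yaml; app.py may wire this differently).
from radiology_rag.ui import RadiologyRAGApp

app = RadiologyRAGApp("config/default_config.yaml")
demo = app.create_interface()
# queue() keeps streamed yields responsive under concurrent Space users.
demo.queue().launch()
```
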
requirements-dev.txt
ADDED
@@ -0,0 +1,19 @@
+# Dev / offline build dependencies (NOT used by Spaces runtime).
+# Includes dataset loading + optional local embedding/rerank tooling.
+
+-r requirements.txt
+
+datasets>=2.16.0
+tqdm>=4.66.0
+pillow>=10.0.0
+
+# Needed for index build (text splitting + Document objects)
+langchain>=0.1.0
+langchain-text-splitters>=0.0.1
+
+# Optional: local models for advanced users (large)
+sentence-transformers>=2.3.0
+torch>=2.0.0
+transformers>=4.36.0
+
+
requirements.txt
ADDED
@@ -0,0 +1,21 @@
+# Runtime dependencies for Hugging Face Spaces (CPU-friendly, API-first).
+# Keep this minimal to improve build stability.
+
+# Web UI
+gradio==4.16.0
+gradio_client==0.8.1
+
+# RAG core
+langchain-core>=0.1.0
+langchain-openai>=0.0.5
+langchain-chroma>=0.1.0
+chromadb>=0.4.22
+
+# Index download (from HF Hub)
+huggingface-hub>=0.20.0
+
+# Utilities
+pyyaml>=6.0
+requests>=2.31.0
+
+
scripts/build_vector_db.py
ADDED
@@ -0,0 +1,219 @@
+"""
+Build a Chroma + SQLite index for this RAG system (offline / advanced users).
+
+The index output folder is compatible with the Space runtime bootstrap:
+    <output_dir>/
+        chroma_db/
+        doc_store.db
+        manifest.json
+
+Examples:
+
+1) Build from HF dataset directly (streaming is not supported for save_to_disk-based build):
+    python scripts/build_vector_db.py \
+        --config config/default_config.yaml \
+        --source huggingface \
+        --dataset ZhangNy/radiology-dataset \
+        --output-dir ./index_out
+
+2) Build from local saved dataset:
+    python scripts/build_vector_db.py \
+        --config config/default_config.yaml \
+        --source local \
+        --local-path ./hf_dataset_prepared \
+        --output-dir ./index_out
+
+Notes:
+- Embedding model used at build time must match query-time embeddings used in the Space,
+  otherwise retrieval quality will degrade.
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import sys
+import shutil
+import time
+from collections import Counter
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+# Allow running as `python scripts/*.py` without installing the package.
+sys.path.append(str(Path(__file__).resolve().parents[1]))
+
+
+def _clean_text(text: str) -> str:
+    # Remove markdown hyperlinks [text](url) -> text
+    import re
+
+    t = re.sub(r"\[(.*?)\]\(.*?\)", r"\1", text or "")
+    return t.replace("\xa0", " ")
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser(description="Build vector index (Chroma + SQLite doc store)")
+    parser.add_argument("--config", type=str, default="config/default_config.yaml", help="Config YAML path")
+    parser.add_argument("--source", choices=["local", "huggingface"], default="huggingface")
+    parser.add_argument("--local-path", type=str, default=None, help="Path to dataset saved via save_to_disk()")
+    parser.add_argument("--dataset", type=str, default="ZhangNy/radiology-dataset", help="HF dataset repo id")
+    parser.add_argument("--split", type=str, default="train")
+    parser.add_argument("--limit", type=int, default=None, help="Limit number of documents (debug)")
+    parser.add_argument("--output-dir", type=str, default="./index_out", help="Output directory for index artifacts")
+    parser.add_argument("--overwrite", action="store_true", help="Overwrite output dir if exists")
+    args = parser.parse_args()
+
+    from datasets import load_dataset, load_from_disk
+    from langchain_chroma import Chroma
+    from langchain_core.documents import Document
+    from langchain_text_splitters import RecursiveCharacterTextSplitter
+
+    from radiology_rag.config import Config
+    from radiology_rag.doc_store import PersistentDocStore
+    from radiology_rag.embedding import EmbeddingClient, EmbeddingConfig
+
+    cfg = Config(args.config)
+
+    out_dir = Path(args.output_dir)
+    chroma_dir = out_dir / "chroma_db"
+    doc_db = out_dir / "doc_store.db"
+    manifest_path = out_dir / "manifest.json"
+
+    if out_dir.exists() and args.overwrite:
+        shutil.rmtree(out_dir)
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    if chroma_dir.exists() or doc_db.exists():
+        if not args.overwrite:
+            raise SystemExit(f"Output dir already has index artifacts. Use --overwrite. ({out_dir})")
+
+    # Load dataset
+    if args.source == "local":
+        if not args.local_path:
+            raise SystemExit("--local-path is required when --source local")
+        dataset = load_from_disk(args.local_path)
+    else:
+        dataset = load_dataset(args.dataset, split=args.split)
+
+    if args.limit:
+        dataset = dataset.select(range(min(int(args.limit), len(dataset))))
+
+    # Splitter
+    splitter = RecursiveCharacterTextSplitter(
+        chunk_size=cfg.get_int("processing.chunk_size", 1024),
+        chunk_overlap=cfg.get_int("processing.chunk_overlap", 200),
+        separators=cfg.get("processing.separators", ["\n\n", "\n", " "]),
+        keep_separator=cfg.get_bool("processing.keep_separator", True),
+    )
+
+    # Embeddings
+    emb = EmbeddingClient(
+        EmbeddingConfig(
+            base_url=cfg.get_str("embedding.api_base_url"),
+            api_key=cfg.get_str("embedding.api_key"),
+            model_name=cfg.get_str("embedding.model_name"),
+            batch_size=cfg.get_int("embedding.batch_size", 32),
+        )
+    )
+
+    # Storage
+    doc_store = PersistentDocStore(str(doc_db), read_only=False)
+    vectorstore = Chroma(
+        collection_name="radiology_docs",
+        embedding_function=emb.langchain_embeddings,
+        persist_directory=str(chroma_dir),
+    )
+
+    # Build
+    start = time.time()
+    parent_pairs: List[Tuple[str, Dict[str, Any]]] = []
+    child_docs: List[Document] = []
+    counts = Counter()
+
+    for item in dataset:
+        doc_id = (item.get("doc_id") or "").strip()
+        if not doc_id:
+            continue
+        source_type = (item.get("source_type") or "").strip()
+        title = (item.get("title") or "").strip()
+        content = _clean_text(item.get("content") or "")
+        url = (item.get("url") or "").strip()
+        metadata = item.get("metadata") or {}
+
+        counts[source_type or "unknown"] += 1
+
+        # Parent document record
+        parent_pairs.append(
+            (
+                doc_id,
+                {
+                    "complete_document": {
+                        "doc_id": doc_id,
+                        "title": title,
+                        "content": content,
+                        "url": url,
+                        "metadata": metadata,
+                    },
+                    "main_content": content,
+                    "images": [],  # not used in this Space
+                    "source_type": source_type,
+                },
+            )
+        )
+
+        # Child chunks for vector store
+        chunks = splitter.split_text(content)
+        total = len(chunks)
+        for i, chunk in enumerate(chunks):
+            child_docs.append(
+                Document(
+                    page_content=chunk,
+                    metadata={
+                        "doc_id": f"{doc_id}_chunk_{i}",
+                        "parent_id": doc_id,
+                        "source_type": source_type,
+                        "title": title,
+                        "chunk_index": i,
+                        "total_chunks": total,
|
| 179 |
+
},
|
| 180 |
+
)
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
# Persist parent docs
|
| 184 |
+
doc_store.mset(parent_pairs)
|
| 185 |
+
|
| 186 |
+
# Add chunks in batches
|
| 187 |
+
batch_size = int(cfg.get_int("processing.batch_size", 32))
|
| 188 |
+
for i in range(0, len(child_docs), batch_size):
|
| 189 |
+
vectorstore.add_documents(child_docs[i : i + batch_size])
|
| 190 |
+
|
| 191 |
+
elapsed = time.time() - start
|
| 192 |
+
|
| 193 |
+
# Manifest
|
| 194 |
+
manifest = {
|
| 195 |
+
"built_at": time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()),
|
| 196 |
+
"seconds": elapsed,
|
| 197 |
+
"dataset": {"source": args.source, "dataset": args.dataset, "split": args.split, "limit": args.limit},
|
| 198 |
+
"embedding": {"type": "api", "model_name": cfg.get_str("embedding.model_name"), "base_url": cfg.get_str("embedding.api_base_url")},
|
| 199 |
+
"processing": {
|
| 200 |
+
"chunk_size": cfg.get_int("processing.chunk_size", 1024),
|
| 201 |
+
"chunk_overlap": cfg.get_int("processing.chunk_overlap", 200),
|
| 202 |
+
},
|
| 203 |
+
"counts_by_source_type": dict(counts),
|
| 204 |
+
"artifacts": {"chroma_dir": "chroma_db", "doc_store": "doc_store.db"},
|
| 205 |
+
}
|
| 206 |
+
with open(manifest_path, "w", encoding="utf-8") as f:
|
| 207 |
+
json.dump(manifest, f, ensure_ascii=False, indent=2)
|
| 208 |
+
|
| 209 |
+
print(f"✓ Index built at: {out_dir}")
|
| 210 |
+
print(f" - documents: {sum(counts.values())} (by type: {dict(counts)})")
|
| 211 |
+
print(f" - chunks: {len(child_docs)}")
|
| 212 |
+
print(f" - elapsed: {elapsed:.1f}s")
|
| 213 |
+
return 0
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
if __name__ == "__main__":
|
| 217 |
+
raise SystemExit(main())
|
| 218 |
+
|
| 219 |
+
|
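A quick way to sanity-check a freshly built index is to reopen both artifacts and run one retrieval round-trip. This is a minimal sketch, assuming the same config and embedding settings used at build time; `PersistentDocStore.mget` and the `read_only=True` flag are assumptions (only `mset` and `read_only=False` appear in the script above):

```python
# Sketch: verify the built index loads and a query round-trips (assumptions noted above).
from langchain_chroma import Chroma

from radiology_rag.config import Config
from radiology_rag.doc_store import PersistentDocStore
from radiology_rag.embedding import EmbeddingClient, EmbeddingConfig

cfg = Config("config/default_config.yaml")
emb = EmbeddingClient(
    EmbeddingConfig(
        base_url=cfg.get_str("embedding.api_base_url"),
        api_key=cfg.get_str("embedding.api_key"),
        model_name=cfg.get_str("embedding.model_name"),
        batch_size=cfg.get_int("embedding.batch_size", 32),
    )
)
vs = Chroma(
    collection_name="radiology_docs",
    embedding_function=emb.langchain_embeddings,
    persist_directory="./index_out/chroma_db",
)
hits = vs.similarity_search("ground-glass opacity", k=3)  # searches child chunks

store = PersistentDocStore("./index_out/doc_store.db", read_only=True)  # read_only assumed
parents = store.mget([h.metadata["parent_id"] for h in hits])  # mget assumed to mirror mset
for h in hits:
    print(h.metadata["parent_id"], h.metadata.get("title"))
```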
scripts/download_hf_dataset.py
ADDED
@@ -0,0 +1,43 @@
"""
Download the public dataset from Hugging Face and save it to disk.

Example:
    python scripts/download_hf_dataset.py \
      --dataset ZhangNy/radiology-dataset \
      --split train \
      --output ./hf_dataset_prepared
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

# Allow running as `python scripts/*.py` without installing the package.
sys.path.append(str(Path(__file__).resolve().parents[1]))


def main() -> int:
    parser = argparse.ArgumentParser(description="Download HF dataset to local disk")
    parser.add_argument("--dataset", type=str, default="ZhangNy/radiology-dataset", help="HF dataset repo id")
    parser.add_argument("--split", type=str, default="train", help="Dataset split")
    parser.add_argument("--output", type=str, default="./hf_dataset_prepared", help="Output directory (save_to_disk)")
    parser.add_argument("--cache-dir", type=str, default=None, help="Optional datasets cache dir")
    args = parser.parse_args()

    from datasets import load_dataset

    out_dir = Path(args.output)
    out_dir.parent.mkdir(parents=True, exist_ok=True)

    ds = load_dataset(args.dataset, split=args.split, cache_dir=args.cache_dir)
    ds.save_to_disk(str(out_dir))
    print(f"✓ Saved dataset to: {out_dir} (rows={len(ds)})")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
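After downloading, it is worth confirming the saved dataset round-trips and exposes the fields the build script reads (`doc_id`, `source_type`, `title`, `content`, `url`, `metadata`). A minimal sketch, using the script's default output path:

```python
# Sketch: inspect the locally saved dataset before building an index.
from datasets import load_from_disk

ds = load_from_disk("./hf_dataset_prepared")
print(len(ds), ds.column_names)  # expect the fields listed above
row = ds[0]
print(row["doc_id"], (row.get("title") or "")[:80])
```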
scripts/package_existing_storage.py
ADDED
@@ -0,0 +1,110 @@
"""
Package an existing local index folder (e.g. rebuild_1217/storage) into a clean, publishable index folder.

This is the fastest path if you already built the index locally and want to publish it
to Hugging Face without rebuilding embeddings.

Input (example):
    /path/to/storage/
      chroma_db/
      doc_store.db
      images/        # optional (ignored)

Output:
    ./index_out/
      chroma_db/
      doc_store.db
      manifest.json
"""

from __future__ import annotations

import argparse
import json
import sys
import shutil
import sqlite3
import time
from pathlib import Path
from typing import Dict

# Allow running as `python scripts/*.py` without installing the package.
sys.path.append(str(Path(__file__).resolve().parents[1]))


def _count_by_source_type(doc_store_db: Path) -> Dict[str, int]:
    counts: Dict[str, int] = {}
    conn = sqlite3.connect(str(doc_store_db))
    try:
        cur = conn.cursor()
        cur.execute("SELECT source_type, COUNT(*) FROM documents GROUP BY source_type")
        for source_type, count in cur.fetchall():
            counts[str(source_type)] = int(count)
    finally:
        conn.close()
    return counts


def main() -> int:
    parser = argparse.ArgumentParser(description="Package existing index storage into index_out (no images)")
    parser.add_argument("--storage", type=str, required=True, help="Existing storage dir containing chroma_db/ + doc_store.db")
    parser.add_argument("--output-dir", type=str, default="./index_out", help="Output folder")
    parser.add_argument("--config", type=str, default="config/default_config.yaml", help="Config YAML (for embedding metadata)")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite output dir if it exists")
    args = parser.parse_args()

    from radiology_rag.config import Config

    storage = Path(args.storage)
    src_chroma = storage / "chroma_db"
    src_doc = storage / "doc_store.db"
    if not src_chroma.exists() or not src_doc.exists():
        raise SystemExit(f"Storage missing required files: {src_chroma} / {src_doc}")

    out_dir = Path(args.output_dir)
    out_chroma = out_dir / "chroma_db"
    out_doc = out_dir / "doc_store.db"
    out_manifest = out_dir / "manifest.json"

    if out_dir.exists() and args.overwrite:
        shutil.rmtree(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    if out_chroma.exists() or out_doc.exists():
        if not args.overwrite:
            raise SystemExit(f"Output already exists. Use --overwrite. ({out_dir})")

    # Copy artifacts (exclude images/)
    if out_chroma.exists():
        shutil.rmtree(out_chroma, ignore_errors=True)
    shutil.copytree(src_chroma, out_chroma, dirs_exist_ok=False)
    shutil.copy2(src_doc, out_doc)

    cfg = Config(args.config)
    counts = _count_by_source_type(out_doc)
    manifest = {
        "packaged_at": time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()),
        "source_storage": str(storage),
        "embedding": {"model_name": cfg.get_str("embedding.model_name"), "type": cfg.get_str("embedding.type", "api")},
        "processing": {
            "chunk_size": cfg.get_int("processing.chunk_size", 1024),
            "chunk_overlap": cfg.get_int("processing.chunk_overlap", 200),
        },
        "counts_by_source_type": counts,
        "artifacts": {"chroma_dir": "chroma_db", "doc_store": "doc_store.db"},
        "images_included": False,
    }
    with open(out_manifest, "w", encoding="utf-8") as f:
        json.dump(manifest, f, ensure_ascii=False, indent=2)

    print(f"✓ Packaged index to: {out_dir}")
    print(f"  - chroma_db: {out_chroma}")
    print(f"  - doc_store: {out_doc}")
    print(f"  - manifest: {out_manifest}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
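Once packaged, `manifest.json` records what went into the folder, so reading it back is a cheap pre-publish check. A minimal sketch against the script's default output path:

```python
# Sketch: read back the packaged manifest before uploading.
import json
from pathlib import Path

manifest = json.loads(Path("./index_out/manifest.json").read_text(encoding="utf-8"))
print(manifest["counts_by_source_type"])
print(manifest["artifacts"])  # {"chroma_dir": "chroma_db", "doc_store": "doc_store.db"}
print(manifest["images_included"])  # False: images are never packaged
```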
scripts/publish_index_to_hf.py
ADDED
@@ -0,0 +1,78 @@
"""
Publish a built index folder to Hugging Face Datasets.

Example:
    python scripts/publish_index_to_hf.py \
      --repo ZhangNy/radiology-index-qwen3-embedding-0.6b \
      --folder ./index_out \
      --token $HF_TOKEN
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

# Allow running as `python scripts/*.py` without installing the package.
sys.path.append(str(Path(__file__).resolve().parents[1]))


def main() -> int:
    parser = argparse.ArgumentParser(description="Upload index artifacts to HF datasets repo")
    parser.add_argument("--repo", type=str, required=True, help="HF dataset repo id, e.g. user/my-index")
    parser.add_argument("--folder", type=str, required=True, help="Local folder containing chroma_db/ + doc_store.db")
    parser.add_argument("--token", type=str, default=None, help="HF token (or set HF_TOKEN env)")
    parser.add_argument("--private", action="store_true", help="Create repo as private")
    parser.add_argument("--revision", type=str, default="main", help="Target revision/branch")
    parser.add_argument(
        "--ignore",
        type=str,
        default="",
        help="Comma-separated ignore patterns for upload_folder (e.g. 'images/**,**/images/**')",
    )
    args = parser.parse_args()

    from huggingface_hub import HfApi

    token = args.token or None
    if token is None:
        import os

        token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")

    if not token:
        raise SystemExit("Missing token. Provide --token or set HF_TOKEN.")

    folder = Path(args.folder)
    if not folder.exists():
        raise SystemExit(f"Folder not found: {folder}")

    api = HfApi()
    api.create_repo(
        repo_id=args.repo,
        repo_type="dataset",
        private=bool(args.private),
        exist_ok=True,
        token=token,
    )

    api.upload_folder(
        repo_id=args.repo,
        repo_type="dataset",
        folder_path=str(folder),
        path_in_repo="",
        token=token,
        revision=args.revision,
        commit_message="Upload prebuilt radiology RAG index",
        ignore_patterns=[p.strip() for p in (args.ignore or "").split(",") if p.strip()] or None,
    )

    print(f"✓ Uploaded index folder to HF dataset repo: {args.repo}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
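On the consuming side, the published dataset repo can be pulled back down with `huggingface_hub` before opening `chroma_db/` and `doc_store.db` locally. The Space's own startup path lives in `radiology_rag/index_bootstrap.py`, so this is only a sketch of the idea:

```python
# Sketch: download the published index artifacts from the dataset repo.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="ZhangNy/radiology-index-qwen3-embedding-0.6b",
    repo_type="dataset",
)
print(local_dir)  # should contain chroma_db/, doc_store.db, manifest.json
```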