ZhangNy committed
Commit 75db650 · 1 Parent(s): 269a91a

Add Space app files
.gitignore ADDED
@@ -0,0 +1,46 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *.so
+ *.egg-info/
+ dist/
+ build/
+ .pytest_cache/
+
+ # Virtual environments
+ venv/
+ .venv/
+ ENV/
+ env/
+
+ # IDEs
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ *~
+
+ # Local runtime storage (vector DB + sqlite doc store)
+ storage/
+ *.db
+ *.sqlite
+ *.sqlite3
+
+ # Dataset caches / artifacts
+ hf_dataset_prepared/
+ .cache/
+ .huggingface/
+
+ # Raw data (private; never publish to Spaces)
+ old_related_files/
+ private_scripts/
+
+ # Logs
+ *.log
+ logs/
+
+ # Env files
+ .env
+ .env.local
+
+
README.md CHANGED
@@ -11,4 +11,86 @@ license: mit
  short_description: 'Ask questions about thoracic radiology and get answers with '
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ ## Overview
+
+ This repository contains a **Hugging Face Spaces-ready** RAG (Retrieval-Augmented Generation) demo for thoracic radiology Q&A.
+
+ - **Default index (prebuilt)**: `ZhangNy/radiology-index-qwen3-embedding-0.6b`
+ - **Raw public dataset**: `ZhangNy/radiology-dataset`
+ - **No image rendering in the UI**: references link to the original pages, where images can be viewed.
+
+ The Space calls **external APIs** for embeddings, reranking, and the LLM, with credentials supplied via **Secrets**.
+
+ ## Run (local)
+
+ ```bash
+ cd LangGraphAgent/rebuild_1219
+ pip install -r requirements.txt
+
+ export EMBED_API_KEY="..."
+ export LLM_API_KEY="..."
+ # optional:
+ export RERANK_API_KEY="..."
+
+ python app.py --config config/default_config.yaml --host 0.0.0.0 --port 7860
+ ```
+
+ Open `http://localhost:7860`.
+
+ ## Hugging Face Space Secrets
+
+ ### Required
+
+ - **`EMBED_API_KEY`**: embedding API key (OpenAI-compatible)
+ - **`LLM_API_KEY`**: LLM API key (OpenAI-compatible)
+
+ ### Recommended
+
+ - **`RERANK_API_KEY`**: reranker API key (OpenAI-compatible `/rerank` endpoint)
+
+ ### Optional (override defaults)
+
+ - **`EMBED_API_BASE_URL`**, **`EMBED_MODEL_NAME`**
+ - **`RERANK_API_BASE_URL`**, **`RERANK_MODEL_NAME`**
+ - **`LLM_BASE_URL`**, **`LLM_MODEL_NAME`**
+ - **`RAG_INDEX_REPO_ID`** (default: `ZhangNy/radiology-index-qwen3-embedding-0.6b`)
+ - **`RAG_STORAGE_DIR`** (default: `/data/radiology_rag` if `/data` exists, else `./storage`)
+
+ ## Advanced: rebuild your own index (offline)
+
+ Install the dev dependencies:
+
+ ```bash
+ pip install -r requirements-dev.txt
+ ```
+
+ The `scripts/` folder (intended for local use) supports:
+ - Downloading `ZhangNy/radiology-dataset` to `./hf_dataset_prepared`
+ - Building a new index with a different embedding model
+ - Publishing that index as a Hugging Face dataset repo
+
+ ### Fast path (no rebuild): publish your existing local index
+
+ If you already have a built index locally (e.g. `rebuild_1217/storage` contains `chroma_db/` + `doc_store.db`),
+ you can **package it without images** and upload it:
+
+ ```bash
+ python scripts/package_existing_storage.py \
+     --storage /home/zny/codes/radioagent_prepare/LangGraphAgent/rebuild_1217/storage \
+     --output-dir ./index_out \
+     --overwrite
+
+ python scripts/publish_index_to_hf.py \
+     --repo ZhangNy/radiology-index-qwen3-embedding-0.6b \
+     --folder ./index_out \
+     --token $HF_TOKEN
+ ```
+
+ ## Notes
+
+ - **Do not commit API keys**. This repo is configured to read them from environment variables / Space Secrets.
+ - **Index compatibility**: the query-time embedding model must match the embedding model used to build the index; a mismatch degrades retrieval quality.
app.py ADDED
@@ -0,0 +1,66 @@
+ """
+ Hugging Face Spaces entry point (Gradio).
+
+ Run locally:
+     python app.py --config config/default_config.yaml --host 0.0.0.0 --port 7860
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import logging
+ import os
+ from pathlib import Path
+
+ from radiology_rag.gradio_compat import patch_gradio_predict_body_for_pydantic_v2
+ from radiology_rag.ui import RadiologyRAGApp
+
+
+ def _configure_logging() -> None:
+     level = os.getenv("LOG_LEVEL", "INFO").upper()
+     logging.basicConfig(
+         level=getattr(logging, level, logging.INFO),
+         format="%(asctime)s - %(levelname)s - %(message)s",
+     )
+
+
+ def _default_storage_dir() -> str:
+     # Prefer /data on Spaces if persistent storage is enabled.
+     if Path("/data").exists():
+         return "/data/radiology_rag"
+     return "./storage"
+
+
+ def main() -> int:
+     _configure_logging()
+
+     parser = argparse.ArgumentParser(description="Radiology RAG (Spaces-ready)")
+     parser.add_argument("--config", type=str, default="config/default_config.yaml", help="Path to config YAML")
+     parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host")
+     parser.add_argument("--port", type=int, default=int(os.getenv("PORT", "7860")), help="Server port")
+     args = parser.parse_args()
+
+     # Ensure the storage dir env var is set early so config interpolation uses it.
+     if not os.getenv("RAG_STORAGE_DIR"):
+         os.environ["RAG_STORAGE_DIR"] = _default_storage_dir()
+
+     # Optional compatibility patch for Gradio 4.16 + Pydantic v2.
+     if patch_gradio_predict_body_for_pydantic_v2():
+         logging.getLogger(__name__).info("Applied Gradio/Pydantic v2 compatibility patch")
+
+     app = RadiologyRAGApp(config_path=args.config)
+     demo = app.create_interface()
+
+     demo.launch(
+         server_name=args.host,
+         server_port=args.port,
+         share=False,
+         show_error=True,
+     )
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
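The storage-dir fallback above can be exercised standalone; `pick_storage_dir` below is a hypothetical helper mirroring `_default_storage_dir`, not part of the repo:

```python
from pathlib import Path

def pick_storage_dir(data_root: str = "/data") -> str:
    # Mirrors app.py's _default_storage_dir: prefer Spaces persistent storage
    # under /data when it is mounted, otherwise fall back to ./storage.
    if Path(data_root).exists():
        return str(Path(data_root) / "radiology_rag")
    return "./storage"

# A mount point that certainly does not exist falls back to local storage.
print(pick_storage_dir("/no/such/mount"))  # ./storage
```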
config/default_config.yaml ADDED
@@ -0,0 +1,122 @@
+ # Default configuration for the Hugging Face Spaces deployment (API-first).
+ # IMPORTANT: do NOT put any real API keys in this file. Use Spaces Secrets instead.
+
+ # Prebuilt index (vector DB + doc store) stored as a Hugging Face dataset repo.
+ index:
+   repo_id: "${RAG_INDEX_REPO_ID:ZhangNy/radiology-index-qwen3-embedding-0.6b}"
+   revision: "${RAG_INDEX_REVISION:main}"
+
+ # Embedding configuration (query-time embeddings must match the index embedding model).
+ embedding:
+   type: "api"  # Options: "api" or "local" (local requires requirements-dev.txt)
+   api_base_url: "${EMBED_API_BASE_URL:https://api.siliconflow.cn/v1}"
+   api_key: "${EMBED_API_KEY:}"
+   model_name: "${EMBED_MODEL_NAME:Qwen/Qwen3-Embedding-0.6B}"
+   batch_size: 32
+
+ # Reranker configuration (recommended; can be disabled if no key is provided).
+ reranker:
+   enabled: true
+   type: "api"  # Options: "api" or "local" (local requires requirements-dev.txt)
+   api_base_url: "${RERANK_API_BASE_URL:https://api.siliconflow.cn/v1}"
+   api_key: "${RERANK_API_KEY:}"
+   model_name: "${RERANK_MODEL_NAME:BAAI/bge-reranker-v2-m3}"
+   top_k: 10
+
+ # LLM configuration (OpenAI-compatible API).
+ llm:
+   base_url: "${LLM_BASE_URL:https://poloai.top/v1}"
+   api_key: "${LLM_API_KEY:}"
+   model_name: "${LLM_MODEL_NAME:gemini-3-flash-preview}"
+   temperature: 0.7
+   max_tokens: 2000
+
+ # Storage paths (prefer /data on Spaces; app.py defaults RAG_STORAGE_DIR to /data/radiology_rag when available).
+ storage:
+   vector_db_path: "${RAG_STORAGE_DIR:./storage}/chroma_db"
+   doc_store_path: "${RAG_STORAGE_DIR:./storage}/doc_store.db"
+
+ # Text splitting parameters (used by the index build scripts; kept here for transparency).
+ processing:
+   chunk_size: 1024
+   chunk_overlap: 200
+   separators:
+     - "\n\n#### "
+     - "\n\n### "
+     - "\n\n## "
+     - "\n\n"
+     - "\n"
+     - " "
+     - ""
+   keep_separator: true
+
+ # Retrieval configuration
+ retrieval:
+   # Default strategy for this Space:
+   #   - balanced_multi_source: includes Wikipedia (encyclopedia) by default
+   strategy: "balanced_multi_source"
+   top_k: 20
+   source_filters:
+     - "article"
+     - "case"
+     - "tutorial"
+     - "encyclopedia"
+
+   search_type: "similarity"  # "similarity" or "mmr"
+   chunk_fetch_multiplier: 3
+
+   # MMR parameters (used only if search_type == "mmr")
+   mmr_lambda: 0.5
+   mmr_fetch_k: 50
+
+   # Balanced multi-source retrieval policy
+   multi_source:
+     total_top_k: 8
+     sources_priority: ["article", "case", "encyclopedia", "tutorial"]
+     article:
+       candidate_k: 80
+       max_k: 3
+       min_score: 0.15
+       required: true
+     case:
+       candidate_k: 80
+       max_k: 3
+       min_score: 0.15
+       required: true
+     encyclopedia:
+       candidate_k: 8
+       max_k: 2
+       min_score: 0.15
+       required: true
+     tutorial:
+       candidate_k: 20
+       max_k: 2
+       min_score: 0.50
+       required: false
+
+ # Encyclopedia configuration (Wikipedia)
+ encyclopedia:
+   wikipedia:
+     language: "en"
+     user_agent: "RadiologyRAG-Space/1.0"
+     timeout_s: 10
+     max_chars_per_doc: 2000
+
+ # Citation configuration (no images in this Space; references link to the original pages)
+ citation:
+   format: "numbered"
+   max_content_length: 900
+
+ # Gradio UI configuration
+ ui:
+   title: "Thoracic Radiology RAG System"
+   description: "Ask questions about thoracic radiology and get answers with citations (articles, cases, tutorials + Wikipedia)."
+   theme: "soft"
+   show_retrieved_docs: true
+
+ # Logging configuration
+ logging:
+   level: "INFO"
+   format: "%(asctime)s - %(levelname)s - %(message)s"
radiology_rag/__init__.py ADDED
@@ -0,0 +1,12 @@
+ """
+ Radiology RAG (Spaces-ready) - minimal, API-first implementation.
+
+ This package is designed to be deployed on Hugging Face Spaces and load a
+ prebuilt vector index from a public Hugging Face dataset repo.
+ """
+
+ __all__ = ["__version__"]
+
+ __version__ = "0.1.0"
radiology_rag/citations.py ADDED
@@ -0,0 +1,54 @@
+ """Citation helpers (numbered citations like [1], [2], ...)."""
+
+ from __future__ import annotations
+
+ import re
+ from typing import Any, Dict, List, Tuple
+
+
+ class CitationManager:
+     def __init__(self, *, max_content_length: int = 900):
+         self.max_content_length = int(max_content_length)
+         self.documents: List[Dict[str, Any]] = []
+         self.doc_id_to_index: Dict[str, int] = {}
+
+     def clear(self) -> None:
+         self.documents = []
+         self.doc_id_to_index = {}
+
+     def add_document(self, document: Dict[str, Any]) -> int:
+         doc_id = document.get("doc_id") or ""
+         if doc_id in self.doc_id_to_index:
+             return self.doc_id_to_index[doc_id]
+         self.documents.append(document)
+         idx = len(self.documents)
+         self.doc_id_to_index[doc_id] = idx
+         return idx
+
+     def add_documents(self, documents: List[Dict[str, Any]]) -> List[int]:
+         return [self.add_document(d) for d in documents]
+
+     @staticmethod
+     def parse_citations_in_text(text: str) -> List[int]:
+         matches = re.findall(r"\[(\d+)\]", text or "")
+         out: List[int] = []
+         for m in matches:
+             try:
+                 out.append(int(m))
+             except ValueError:
+                 continue
+         return out
+
+     def validate_citations(self, text: str) -> Tuple[bool, List[int]]:
+         cited = self.parse_citations_in_text(text or "")
+         invalid = [i for i in cited if i < 1 or i > len(self.documents)]
+         return (len(invalid) == 0), invalid
+
+     def get_statistics(self) -> Dict[str, Any]:
+         counts: Dict[str, int] = {}
+         for d in self.documents:
+             st = d.get("source_type", "unknown") or "unknown"
+             counts[st] = counts.get(st, 0) + 1
+         return {"total": len(self.documents), "source_type_counts": counts}
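The numbered-citation flow (de-duplication by `doc_id`, 1-based indices, `[n]` parsing) can be sketched standalone; `MiniCitations` below is an illustrative stand-in for `CitationManager`, not the package class:

```python
import re
from typing import Any, Dict, List

class MiniCitations:
    """Illustrative stand-in for CitationManager: 1-based, de-duplicated by doc_id."""
    def __init__(self) -> None:
        self.docs: List[Dict[str, Any]] = []
        self.index: Dict[str, int] = {}

    def add(self, doc: Dict[str, Any]) -> int:
        doc_id = doc.get("doc_id") or ""
        if doc_id in self.index:
            return self.index[doc_id]        # same doc keeps its first number
        self.docs.append(doc)
        self.index[doc_id] = len(self.docs)  # 1-based citation number
        return self.index[doc_id]

    @staticmethod
    def parse(text: str) -> List[int]:
        # Same pattern the module uses: bracketed integers like [1], [2]
        return [int(m) for m in re.findall(r"\[(\d+)\]", text or "")]

cm = MiniCitations()
n1 = cm.add({"doc_id": "a"})
n2 = cm.add({"doc_id": "b"})
n3 = cm.add({"doc_id": "a"})  # duplicate -> reuses index 1
cited = cm.parse("See [1] and [2]; [9] is invalid.")
invalid = [i for i in cited if i < 1 or i > len(cm.docs)]
print(n1, n2, n3, cited, invalid)  # 1 2 1 [1, 2, 9] [9]
```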
radiology_rag/config.py ADDED
@@ -0,0 +1,108 @@
+ """
+ Config loader with `${ENV_VAR}` and `${ENV_VAR:default}` interpolation.
+
+ We intentionally keep config logic lightweight so it works well in Spaces.
+ """
+
+ from __future__ import annotations
+
+ import os
+ import re
+ from pathlib import Path
+ from typing import Any, Dict, Optional
+
+ import yaml
+
+
+ class Config:
+     """Load and access YAML config with env var interpolation."""
+
+     def __init__(self, config_path: str):
+         self.config_path = Path(config_path)
+         self._config = self._load_config()
+         self._config = self._recursive_resolve(self._config)
+
+     def _load_config(self) -> Dict[str, Any]:
+         if not self.config_path.exists():
+             raise FileNotFoundError(f"Config file not found: {self.config_path}")
+         with open(self.config_path, "r", encoding="utf-8") as f:
+             data = yaml.safe_load(f) or {}
+         if not isinstance(data, dict):
+             raise ValueError("Config root must be a mapping/dict")
+         return data
+
+     @staticmethod
+     def _resolve_string(value: str) -> str:
+         # Pattern: ${VAR_NAME} or ${VAR_NAME:default_value}
+         # NOTE: default_value may be empty, e.g. `${API_KEY:}`. Use `*` (not `+`) to allow empty.
+         pattern = r"\$\{([^:}]+)(?::([^}]*))?\}"
+
+         def replace(match: re.Match) -> str:
+             var_name = match.group(1)
+             default_value = match.group(2) if match.group(2) is not None else ""
+             return os.getenv(var_name, default_value)
+
+         return re.sub(pattern, replace, value)
+
+     def _recursive_resolve(self, obj: Any) -> Any:
+         if isinstance(obj, dict):
+             return {k: self._recursive_resolve(v) for k, v in obj.items()}
+         if isinstance(obj, list):
+             return [self._recursive_resolve(v) for v in obj]
+         if isinstance(obj, str):
+             return self._resolve_string(obj)
+         return obj
+
+     def get(self, key: str, default: Any = None) -> Any:
+         keys = key.split(".")
+         value: Any = self._config
+         for k in keys:
+             if isinstance(value, dict) and k in value:
+                 value = value[k]
+             else:
+                 return default
+         return value
+
+     def get_str(self, key: str, default: str = "") -> str:
+         v = self.get(key, default)
+         return default if v is None else str(v)
+
+     def get_int(self, key: str, default: int = 0) -> int:
+         v = self.get(key, default)
+         if v is None:
+             return default
+         if isinstance(v, int):
+             return v
+         try:
+             return int(str(v).strip())
+         except Exception:
+             return default
+
+     def get_float(self, key: str, default: float = 0.0) -> float:
+         v = self.get(key, default)
+         if v is None:
+             return default
+         if isinstance(v, (int, float)):
+             return float(v)
+         try:
+             return float(str(v).strip())
+         except Exception:
+             return default
+
+     def get_bool(self, key: str, default: bool = False) -> bool:
+         v = self.get(key, default)
+         if isinstance(v, bool):
+             return v
+         if v is None:
+             return default
+         s = str(v).strip().lower()
+         if s in {"1", "true", "yes", "y", "on"}:
+             return True
+         if s in {"0", "false", "no", "n", "off"}:
+             return False
+         return default
+
+     def as_dict(self) -> Dict[str, Any]:
+         return dict(self._config)
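The `${VAR}` / `${VAR:default}` interpolation can be tested in isolation with the same regex the loader uses; the hypothetical `resolve` function below mirrors `Config._resolve_string`:

```python
import os
import re

# Same pattern as the loader: ${VAR} or ${VAR:default}; the default may be empty.
_PATTERN = re.compile(r"\$\{([^:}]+)(?::([^}]*))?\}")

def resolve(value: str) -> str:
    def replace(match: re.Match) -> str:
        default = match.group(2) if match.group(2) is not None else ""
        return os.getenv(match.group(1), default)
    return _PATTERN.sub(replace, value)

os.environ["DEMO_STORAGE"] = "/data/radiology_rag"
print(resolve("${DEMO_STORAGE:./storage}/chroma_db"))    # /data/radiology_rag/chroma_db
print(resolve("${DEMO_MISSING:./storage}/doc_store.db")) # ./storage/doc_store.db
print(resolve("${DEMO_EMPTY_KEY:}") == "")               # True: empty default allowed
```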
radiology_rag/doc_store.py ADDED
@@ -0,0 +1,143 @@
+ """
+ SQLite-backed document store.
+
+ We keep the schema compatible with the previous rebuild_1217 implementation:
+ table `documents(doc_id, complete_document, main_content, images, source_type)`.
+
+ In this Space we do NOT use images, but we keep the column for compatibility with
+ existing indexes and to allow advanced users to extend the system.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import logging
+ import os
+ import sqlite3
+ from typing import Any, Iterator, List, Optional, Sequence, Tuple
+
+ logger = logging.getLogger(__name__)
+
+
+ class PersistentDocStore:
+     def __init__(self, db_path: str, *, read_only: bool = False):
+         self.db_path = db_path
+         self.read_only = bool(read_only)
+
+         if not self.read_only:
+             self.init_db()
+
+     def _connect(self) -> sqlite3.Connection:
+         if self.read_only:
+             # Open in read-only mode to avoid accidental writes in the Spaces runtime.
+             return sqlite3.connect(f"file:{self.db_path}?mode=ro", uri=True)
+         return sqlite3.connect(self.db_path)
+
+     def init_db(self) -> None:
+         db_dir = os.path.dirname(self.db_path) or "."
+         os.makedirs(db_dir, exist_ok=True)
+         conn = self._connect()
+         try:
+             cursor = conn.cursor()
+             cursor.execute(
+                 """
+                 CREATE TABLE IF NOT EXISTS documents (
+                     doc_id TEXT PRIMARY KEY,
+                     complete_document TEXT,
+                     main_content TEXT,
+                     images TEXT,
+                     source_type TEXT
+                 )
+                 """
+             )
+             conn.commit()
+         finally:
+             conn.close()
+
+     def mset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None:
+         if self.read_only:
+             raise RuntimeError("DocStore is read-only")
+         conn = self._connect()
+         try:
+             cursor = conn.cursor()
+             for doc_id, content in key_value_pairs:
+                 cursor.execute(
+                     """
+                     INSERT OR REPLACE INTO documents
+                     (doc_id, complete_document, main_content, images, source_type)
+                     VALUES (?, ?, ?, ?, ?)
+                     """,
+                     (
+                         doc_id,
+                         json.dumps(content.get("complete_document", {}), ensure_ascii=False),
+                         content.get("main_content", "") or "",
+                         json.dumps(content.get("images", []), ensure_ascii=False),
+                         content.get("source_type", "") or "",
+                     ),
+                 )
+             conn.commit()
+         finally:
+             conn.close()
+
+     def mget(self, keys: Sequence[str]) -> List[Optional[Any]]:
+         conn = self._connect()
+         try:
+             cursor = conn.cursor()
+             out: List[Optional[Any]] = []
+             for doc_id in keys:
+                 cursor.execute(
+                     "SELECT complete_document, main_content, images, source_type FROM documents WHERE doc_id = ?",
+                     (doc_id,),
+                 )
+                 row = cursor.fetchone()
+                 if not row:
+                     out.append(None)
+                     continue
+                 complete_document, main_content, images, source_type = row
+                 out.append(
+                     {
+                         "complete_document": json.loads(complete_document or "{}"),
+                         "main_content": main_content or "",
+                         "images": json.loads(images or "[]"),
+                         "source_type": source_type or "",
+                     }
+                 )
+             return out
+         finally:
+             conn.close()
+
+     def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
+         conn = self._connect()
+         try:
+             cursor = conn.cursor()
+             if prefix:
+                 cursor.execute("SELECT doc_id FROM documents WHERE doc_id LIKE ?", (f"{prefix}%",))
+             else:
+                 cursor.execute("SELECT doc_id FROM documents")
+             for (doc_id,) in cursor.fetchall():
+                 yield str(doc_id)
+         finally:
+             conn.close()
+
+     def count(self) -> int:
+         conn = self._connect()
+         try:
+             cursor = conn.cursor()
+             cursor.execute("SELECT COUNT(*) FROM documents")
+             return int(cursor.fetchone()[0])
+         finally:
+             conn.close()
+
+     def count_by_source_type(self) -> dict:
+         conn = self._connect()
+         try:
+             cursor = conn.cursor()
+             cursor.execute("SELECT source_type, COUNT(*) FROM documents GROUP BY source_type")
+             counts: dict = {}
+             for source_type, count in cursor.fetchall():
+                 counts[str(source_type)] = int(count)
+             return counts
+         finally:
+             conn.close()
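The schema and upsert behavior can be checked against an in-memory SQLite database (`:memory:` here instead of a real `doc_store.db`); `upsert` is an illustrative helper mirroring `PersistentDocStore.mset`:

```python
import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE IF NOT EXISTS documents (
        doc_id TEXT PRIMARY KEY,
        complete_document TEXT,
        main_content TEXT,
        images TEXT,
        source_type TEXT
    )
    """
)

def upsert(doc_id: str, content: dict) -> None:
    # INSERT OR REPLACE keyed on the doc_id primary key, as in mset.
    conn.execute(
        "INSERT OR REPLACE INTO documents VALUES (?, ?, ?, ?, ?)",
        (
            doc_id,
            json.dumps(content.get("complete_document", {}), ensure_ascii=False),
            content.get("main_content", "") or "",
            json.dumps(content.get("images", []), ensure_ascii=False),
            content.get("source_type", "") or "",
        ),
    )

upsert("case_1", {"main_content": "v1", "source_type": "case"})
upsert("case_1", {"main_content": "v2", "source_type": "case"})  # replaces, no duplicate row
upsert("art_1", {"main_content": "a", "source_type": "article"})

total = conn.execute("SELECT COUNT(*) FROM documents").fetchone()[0]
by_type = dict(conn.execute("SELECT source_type, COUNT(*) FROM documents GROUP BY source_type"))
print(total, by_type)  # total == 2; one 'case' row, one 'article' row
```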
radiology_rag/embedding.py ADDED
@@ -0,0 +1,41 @@
+ """Embedding utilities (OpenAI-compatible, API-first)."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ from langchain_openai import OpenAIEmbeddings
+
+
+ @dataclass(frozen=True)
+ class EmbeddingConfig:
+     base_url: str
+     api_key: str
+     model_name: str
+     batch_size: int = 32
+
+
+ class EmbeddingClient:
+     """Thin wrapper over LangChain OpenAIEmbeddings."""
+
+     def __init__(self, cfg: EmbeddingConfig):
+         self.cfg = cfg
+         self._emb = OpenAIEmbeddings(
+             base_url=cfg.base_url,
+             api_key=cfg.api_key,
+             model=cfg.model_name,
+             chunk_size=int(cfg.batch_size or 32),
+         )
+
+     def embed_query(self, text: str) -> list[float]:
+         return self._emb.embed_query(text)
+
+     def embed_documents(self, texts: list[str]) -> list[list[float]]:
+         return self._emb.embed_documents(texts)
+
+     @property
+     def langchain_embeddings(self) -> OpenAIEmbeddings:
+         """Expose the underlying LangChain embeddings object for Chroma."""
+         return self._emb
radiology_rag/encyclopedia.py ADDED
@@ -0,0 +1,194 @@
+ """Wikipedia encyclopedia retrieval (MediaWiki API)."""
+
+ from __future__ import annotations
+
+ import logging
+ import re
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional
+
+ import requests
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass(frozen=True)
+ class WikipediaConfig:
+     language: str = "en"
+     user_agent: str = "RadiologyRAG-Space/1.0"
+     timeout_s: int = 15
+     max_chars_per_doc: int = 2000
+
+
+ class WikipediaEncyclopediaService:
+     def __init__(self, config: Optional[WikipediaConfig] = None):
+         self.config = config or WikipediaConfig()
+         self._session = requests.Session()
+         self._session.headers.update({"User-Agent": self.config.user_agent})
+
+     @property
+     def api_base(self) -> str:
+         return f"https://{self.config.language}.wikipedia.org/w/api.php"
+
+     @staticmethod
+     def _derive_search_query(user_query: str) -> str:
+         q = (user_query or "").strip()
+         if not q:
+             return ""
+
+         tokens = re.findall(r"[A-Za-z][A-Za-z'\-]*", q.lower())
+         if not tokens:
+             return q
+
+         stop = {
+             "what", "which", "who", "whom", "whose", "when", "where", "why", "how",
+             "is", "are", "was", "were", "be", "been", "being",
+             "do", "does", "did", "can", "could", "should", "would", "may", "might",
+             "will", "shall",
+             "a", "an", "the", "and", "or", "but",
+             "to", "of", "for", "with", "without", "in", "on", "at", "by", "from", "as",
+             "it", "its", "this", "that", "these", "those",
+             "your", "my", "their", "our", "about",
+         }
+
+         keep_short = {"ct", "mr", "mri", "pet", "us", "cxr"}
+         keywords: List[str] = []
+         seen = set()
+         for t in tokens:
+             if t in stop:
+                 continue
+             if len(t) < 3 and t not in keep_short:
+                 continue
+             if t in seen:
+                 continue
+             seen.add(t)
+             keywords.append(t)
+
+         return " ".join(keywords[:8]) if keywords else q
+
+     def retrieve(self, query: str, top_k: int = 5, max_chars_per_doc: Optional[int] = None) -> List[Dict[str, Any]]:
+         q = (query or "").strip()
+         if not q:
+             return []
+
+         search_q = self._derive_search_query(q)
+         if not search_q:
+             return []
+
+         max_chars = int(max_chars_per_doc or self.config.max_chars_per_doc)
+         try:
+             search_params = {
+                 "action": "query",
+                 "list": "search",
+                 "srsearch": search_q,
+                 "srlimit": max(1, min(int(top_k), 20)),
+                 "format": "json",
+             }
+             resp = self._session.get(self.api_base, params=search_params, timeout=self.config.timeout_s)
+             resp.raise_for_status()
+             data = resp.json() or {}
+             hits = (data.get("query", {}) or {}).get("search", []) or []
+
+             # Fall back to the raw query if the rewrite yields no hits.
+             if not hits and search_q != q:
+                 search_params["srsearch"] = q
+                 resp = self._session.get(self.api_base, params=search_params, timeout=self.config.timeout_s)
+                 resp.raise_for_status()
+                 data = resp.json() or {}
+                 hits = (data.get("query", {}) or {}).get("search", []) or []
+
+             if not hits:
+                 return []
+
+             pageids = [str(h.get("pageid")) for h in hits if h.get("pageid") is not None]
+             if not pageids:
+                 return []
+
+             pages_params = {
+                 "action": "query",
+                 "pageids": "|".join(pageids),
+                 "prop": "extracts|info",
+                 "explaintext": 1,
+                 "exintro": 1,
+                 "exchars": max_chars,
+                 "inprop": "url",
+                 "format": "json",
+             }
+             resp2 = self._session.get(self.api_base, params=pages_params, timeout=self.config.timeout_s)
+             resp2.raise_for_status()
+             pages_data = resp2.json() or {}
+             pages = (pages_data.get("query", {}) or {}).get("pages", {}) or {}
+
+             docs: List[Dict[str, Any]] = []
+             for pid in pageids:
+                 page = pages.get(pid) or {}
+                 title = page.get("title") or ""
+                 extract = (page.get("extract") or "").strip()
+                 url = page.get("fullurl") or ""
+                 if not title or not extract:
+                     continue
+                 docs.append(
+                     {
+                         "doc_id": f"encyclopedia_{pid}",
+                         "source_type": "encyclopedia",
+                         "title": title,
+                         "content": extract,
+                         "url": url,
+                         "metadata": {"provider": "wikipedia", "pageid": pid},
+                         "score": 0.0,
+                     }
+                 )
+
+             return docs[: int(top_k)]
+         except Exception as e:
+             logger.warning(f"Wikipedia retrieval failed: {e}")
+             return []
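The query rewriting in `_derive_search_query` is plain stop-word filtering and can be sketched standalone (this sketch uses a smaller stop list than the module's, so results may differ slightly):

```python
import re

STOP = {"what", "is", "the", "of", "a", "an", "in", "on", "for", "and", "how"}
KEEP_SHORT = {"ct", "mr", "mri", "pet", "us", "cxr"}  # radiology acronyms worth keeping

def derive_search_query(user_query: str) -> str:
    """Keep up to 8 informative, de-duplicated keywords for the MediaWiki search."""
    tokens = re.findall(r"[A-Za-z][A-Za-z'\-]*", (user_query or "").lower())
    keywords, seen = [], set()
    for t in tokens:
        if t in STOP or (len(t) < 3 and t not in KEEP_SHORT) or t in seen:
            continue
        seen.add(t)
        keywords.append(t)
    return " ".join(keywords[:8]) if keywords else (user_query or "").strip()

q = derive_search_query("What is the role of CT in pulmonary embolism?")
print(q)  # role ct pulmonary embolism
```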
radiology_rag/gradio_compat.py ADDED
@@ -0,0 +1,60 @@
+ """
+ Compatibility patches for Gradio when running with newer FastAPI/Pydantic versions.
+
+ Background:
+ - Gradio 4.16 defines `gradio.data_classes.PredictBody.request: Optional[fastapi.Request]`.
+ - Under Pydantic v2, a fastapi/starlette `Request` cannot be converted into a JSON schema,
+   which can crash FastAPI request parsing for Gradio's `/run/{api_name}` endpoint.
+
+ This module applies a targeted runtime patch that replaces that field with `Any`.
+ It is intentionally narrow and only runs when we detect the problematic combination.
+ """
+
+ from __future__ import annotations
+
+ from typing import Any
+
+
+ def patch_gradio_predict_body_for_pydantic_v2() -> bool:
+     """Return True if a patch was applied."""
+     try:
+         import pydantic
+
+         major = int(str(pydantic.__version__).split(".", 1)[0])
+         if major < 2:
+             return False
+
+         import gradio.data_classes as gr_data_classes
+         import gradio.routes as gr_routes
+         from pydantic import ConfigDict, create_model
+         from typing import List, Optional
+
+         PredictBody = getattr(gr_data_classes, "PredictBody", None)
+         if PredictBody is None:
+             return False
+
+         ann = getattr(PredictBody, "__annotations__", {}) or {}
+         if "request" not in ann:
+             return False
+
+         # NOTE: Pydantic v2's create_model rejects passing both __base__ and
+         # __config__, so we only pass __config__ (the base defaults to BaseModel).
+         PatchedPredictBody = create_model(
+             "PredictBody",
+             __config__=ConfigDict(arbitrary_types_allowed=True),
+             session_hash=(Optional[str], None),
+             event_id=(Optional[str], None),
+             data=(List[Any], ...),
+             event_data=(Optional[Any], None),
+             fn_index=(Optional[int], None),
+             trigger_id=(Optional[int], None),
+             batched=(Optional[bool], False),
+             request=(Optional[Any], None),
+         )
+
+         gr_data_classes.PredictBody = PatchedPredictBody  # type: ignore[attr-defined]
+         gr_routes.PredictBody = PatchedPredictBody  # type: ignore[attr-defined]
+         return True
+     except Exception:
+         return False
radiology_rag/index_bootstrap.py ADDED
@@ -0,0 +1,160 @@
+ """
+ Index bootstrap utilities for Hugging Face Spaces.
+
+ This Space relies on a prebuilt index stored on Hugging Face Datasets:
+ - ChromaDB persist directory (vector store)
+ - SQLite doc store (parent documents)
+
+ At startup we download the index once and place it into a writable storage dir
+ (preferring /data on Spaces when persistent storage is enabled).
+ """
+
+ from __future__ import annotations
+
+ import json
+ import logging
+ import os
+ import shutil
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any, Dict, Optional, Tuple
+
+ from huggingface_hub import snapshot_download
+
+ logger = logging.getLogger(__name__)
+
+
+ DEFAULT_INDEX_REPO_ID = "ZhangNy/radiology-index-qwen3-embedding-0.6b"
+
+
+ @dataclass(frozen=True)
+ class IndexPaths:
+     vector_db_path: Path
+     doc_store_path: Path
+     manifest_path: Optional[Path]
+     snapshot_dir: Optional[Path]
+
+
+ def resolve_default_storage_dir() -> Path:
+     """
+     Determine a good default storage directory for Spaces.
+
+     Priority:
+     - $RAG_STORAGE_DIR (user override)
+     - /data/radiology_rag (Spaces persistent storage)
+     - ./storage (local)
+     """
+     env = (os.getenv("RAG_STORAGE_DIR") or "").strip()
+     if env:
+         return Path(env)
+     if Path("/data").exists():
+         return Path("/data") / "radiology_rag"
+     return Path("./storage")
+
+
+ def _find_index_artifacts(snapshot_dir: Path) -> Tuple[Path, Path, Optional[Path]]:
+     """
+     Find (chroma_db_dir, doc_store_db, manifest_json) inside a HF snapshot.
+
+     We support either:
+     - chroma_db/, doc_store.db, manifest.json
+     - storage/chroma_db/, storage/doc_store.db, storage/manifest.json
+     """
+     candidates = [
+         (snapshot_dir / "chroma_db", snapshot_dir / "doc_store.db", snapshot_dir / "manifest.json"),
+         (snapshot_dir / "storage" / "chroma_db", snapshot_dir / "storage" / "doc_store.db", snapshot_dir / "storage" / "manifest.json"),
+     ]
+     for chroma_dir, doc_db, manifest in candidates:
+         if chroma_dir.is_dir() and doc_db.is_file():
+             return chroma_dir, doc_db, (manifest if manifest.exists() else None)
+
+     raise FileNotFoundError(
+         "Could not locate index artifacts inside snapshot. "
+         "Expected either {chroma_db/, doc_store.db} or {storage/chroma_db/, storage/doc_store.db}."
+     )
+
+
+ def read_manifest(manifest_path: Optional[Path]) -> Optional[Dict[str, Any]]:
+     if not manifest_path or not manifest_path.exists():
+         return None
+     try:
+         with open(manifest_path, "r", encoding="utf-8") as f:
+             return json.load(f) or {}
+     except Exception as e:
+         logger.warning(f"Failed to read manifest.json: {e}")
+         return None
+
+
+ def ensure_index(
+     *,
+     repo_id: str = DEFAULT_INDEX_REPO_ID,
+     revision: Optional[str] = None,
+     target_vector_db_path: Optional[str] = None,
+     target_doc_store_path: Optional[str] = None,
+     storage_dir: Optional[str] = None,
+     force_download: bool = False,
+ ) -> IndexPaths:
+     """
+     Ensure the index exists locally at the configured storage paths.
+
+     Returns resolved IndexPaths; raises on unrecoverable errors.
+     """
+     # Resolve target paths
+     base_dir = Path(storage_dir) if storage_dir else resolve_default_storage_dir()
+     base_dir.mkdir(parents=True, exist_ok=True)
+
+     vector_db_path = Path(target_vector_db_path) if target_vector_db_path else (base_dir / "chroma_db")
+     doc_store_path = Path(target_doc_store_path) if target_doc_store_path else (base_dir / "doc_store.db")
+
+     # Fast path: the index is already present locally.
+     if (
+         not force_download
+         and vector_db_path.is_dir()
+         and doc_store_path.is_file()
+     ):
+         logger.info(f"Index already present: vector_db={vector_db_path} doc_store={doc_store_path}")
121
+ manifest_path = (base_dir / "manifest.json") if (base_dir / "manifest.json").exists() else None
122
+ return IndexPaths(vector_db_path=vector_db_path, doc_store_path=doc_store_path, manifest_path=manifest_path, snapshot_dir=None)
123
+
124
+ # Download snapshot
125
+ repo_id = (repo_id or "").strip() or DEFAULT_INDEX_REPO_ID
126
+ logger.info(f"Downloading index snapshot from HF dataset repo: {repo_id} (revision={revision or 'main'})")
127
+ snapshot_dir = Path(
128
+ snapshot_download(
129
+ repo_id=repo_id,
130
+ repo_type="dataset",
131
+ revision=revision or None,
132
+ local_files_only=False,
133
+ )
134
+ )
135
+
136
+ src_chroma_dir, src_doc_db, src_manifest = _find_index_artifacts(snapshot_dir)
137
+ logger.info(f"Found index artifacts in snapshot: chroma={src_chroma_dir} doc_store={src_doc_db}")
138
+
139
+ # Copy to writable target locations
140
+ if vector_db_path.exists():
141
+ shutil.rmtree(vector_db_path, ignore_errors=True)
142
+ vector_db_path.parent.mkdir(parents=True, exist_ok=True)
143
+ shutil.copytree(src_chroma_dir, vector_db_path, dirs_exist_ok=False)
144
+
145
+ doc_store_path.parent.mkdir(parents=True, exist_ok=True)
146
+ shutil.copy2(src_doc_db, doc_store_path)
147
+
148
+ manifest_path: Optional[Path] = None
149
+ if src_manifest and src_manifest.exists():
150
+ manifest_path = doc_store_path.parent / "manifest.json"
151
+ try:
152
+ shutil.copy2(src_manifest, manifest_path)
153
+ except Exception as e:
154
+ logger.warning(f"Failed to copy manifest.json: {e}")
155
+ manifest_path = None
156
+
157
+ logger.info(f"Index ready: vector_db={vector_db_path} doc_store={doc_store_path}")
158
+ return IndexPaths(vector_db_path=vector_db_path, doc_store_path=doc_store_path, manifest_path=manifest_path, snapshot_dir=snapshot_dir)
159
+
160
+
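The storage-directory priority in `resolve_default_storage_dir` can be exercised in isolation. This is a minimal, dependency-free sketch that re-implements the same priority order with explicit parameters; the `resolve_storage_dir` helper and its arguments are illustrative, not part of the module:

```python
from pathlib import Path
from typing import Optional

def resolve_storage_dir(env_override: Optional[str], data_mount_exists: bool) -> Path:
    # Mirrors the priority used by resolve_default_storage_dir:
    # $RAG_STORAGE_DIR > /data/radiology_rag > ./storage
    env = (env_override or "").strip()
    if env:
        return Path(env)  # explicit user override always wins
    if data_mount_exists:
        return Path("/data") / "radiology_rag"  # Spaces persistent storage
    return Path("./storage")  # local fallback

resolve_storage_dir("/tmp/custom", True)   # env override wins
resolve_storage_dir(None, True)            # persistent storage
resolve_storage_dir("  ", False)           # whitespace-only override is ignored
```

Keeping the decision in one small pure function like this makes the fallback order easy to unit-test without touching the filesystem.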
radiology_rag/rag.py ADDED
@@ -0,0 +1,257 @@
+ """
+ RAG engine (Spaces-ready, API-first).
+
+ Pipeline:
+ 1) Retrieve documents from the prebuilt Chroma+SQLite index (local) plus Wikipedia (optional)
+ 2) Rerank via API (auto-disabled if the key is missing)
+ 3) Build a prompt with numbered citations [1], [2], ...
+ 4) Call the LLM (OpenAI-compatible) and stream the answer
+ """
+
+ from __future__ import annotations
+
+ import logging
+ import time
+ from typing import Any, Dict, Iterator, List, Optional
+
+ from langchain_core.prompts import PromptTemplate
+ from langchain_openai import ChatOpenAI
+
+ from radiology_rag.config import Config
+ from radiology_rag.citations import CitationManager
+ from radiology_rag.retrieval import MultiSourceRetrievalService, RetrievalService
+ from radiology_rag.reranker import RerankerConfig, RerankerService
+
+ logger = logging.getLogger(__name__)
+
+
+ class RAGEngine:
+     """Retrieval-Augmented Generation engine for radiology queries."""
+
+     RAG_PROMPT_TEMPLATE = """You are a helpful radiology assistant with access to medical literature.
+
+ Answer the user's question based on the provided context.
+
+ **Rules**
+ 1. Use the context as primary evidence.
+ 2. Add numbered citations like [1], [2], ... immediately after the relevant sentences.
+ 3. Do NOT invent citations. Only cite sources that appear in the context.
+ 4. If the context does not contain enough information, say so and provide the best general explanation you can.
+
+ **Context**
+ {context}
+
+ **User Question**
+ {question}
+
+ **Answer (with citations)**
+ """
+
+     def __init__(self, config: Config):
+         self.config = config
+
+         # Retrieval + rerank
+         self.retrieval_service = RetrievalService(config)
+
+         rr_cfg = RerankerConfig(
+             enabled=config.get_bool("reranker.enabled", True),
+             base_url=config.get_str("reranker.api_base_url"),
+             api_key=config.get_str("reranker.api_key"),
+             model_name=config.get_str("reranker.model_name"),
+             top_k=config.get_int("reranker.top_k", 10),
+         )
+         self.reranker_service = RerankerService(rr_cfg)
+
+         self.multi_source_retrieval_service = MultiSourceRetrievalService(
+             config=config,
+             retrieval_service=self.retrieval_service,
+             reranker_service=self.reranker_service,
+         )
+
+         self.citation_manager = CitationManager(max_content_length=config.get_int("citation.max_content_length", 900))
+
+         # LLM (OpenAI-compatible)
+         self.llm = ChatOpenAI(
+             base_url=config.get_str("llm.base_url"),
+             api_key=config.get_str("llm.api_key"),
+             model=config.get_str("llm.model_name"),
+             temperature=config.get_float("llm.temperature", 0.7),
+             max_tokens=config.get_int("llm.max_tokens", 2000),
+         )
+
+         self.prompt = PromptTemplate(template=self.RAG_PROMPT_TEMPLATE, input_variables=["context", "question"])
+
+     @staticmethod
+     def _normalize_retrieval_strategy(strategy: Optional[str]) -> str:
+         s = (strategy or "").strip() or "default"
+         if s not in {"default", "balanced_multi_source"}:
+             logger.warning(f"Unknown retrieval strategy '{s}', falling back to 'default'")
+             return "default"
+         return s
+
+     @staticmethod
+     def _format_context(documents: List[Dict[str, Any]], citation_indices: List[int]) -> str:
+         parts: List[str] = []
+         for doc, idx in zip(documents, citation_indices):
+             source_type = (doc.get("source_type") or "").upper()
+             title = doc.get("title") or "Untitled"
+             content = doc.get("content") or ""
+             url = doc.get("url") or ""
+
+             block = f"[{idx}] **{source_type}: {title}**\n{content}"
+             if url:
+                 block += f"\nURL: {url}"
+             parts.append(block)
+
+         return "\n\n---\n\n".join(parts)
+
+     def _retrieve(self, *, question: str, top_k: Optional[int], source_filters: Optional[List[str]], retrieval_strategy: str):
+         if retrieval_strategy == "balanced_multi_source":
+             final_k = int(top_k or self.config.get_int("retrieval.multi_source.total_top_k", 8))
+             return self.multi_source_retrieval_service.retrieve(
+                 query=question,
+                 total_top_k=final_k,
+                 source_filters=source_filters,
+                 return_debug=True,
+             )
+
+         k = int(top_k or self.config.get_int("retrieval.top_k", 20))
+         docs = self.retrieval_service.retrieve(query=question, top_k=k, source_filters=source_filters)
+         return docs, None
+
+     def query_stream(
+         self,
+         *,
+         question: str,
+         top_k: Optional[int] = None,
+         source_filters: Optional[List[str]] = None,
+         retrieval_strategy: Optional[str] = None,
+         stream_yield_interval_s: float = 0.15,
+         stream_min_chars: int = 80,
+     ) -> Iterator[Dict[str, Any]]:
+         q = (question or "").strip()
+         if not q:
+             yield {
+                 "type": "final",
+                 "answer": "Please enter a question.",
+                 "references": [],
+                 "metadata": {"num_retrieved": 0, "num_reranked": 0, "retrieval_strategy": "default"},
+             }
+             return
+
+         self.citation_manager.clear()
+         retrieval_strategy_n = self._normalize_retrieval_strategy(
+             retrieval_strategy or self.config.get_str("retrieval.strategy", "default")
+         )
+
+         start_time = time.time()
+
+         # Step 1: retrieve
+         docs, retrieval_debug = self._retrieve(
+             question=q, top_k=top_k, source_filters=source_filters, retrieval_strategy=retrieval_strategy_n
+         )
+         if not docs:
+             yield {
+                 "type": "final",
+                 "answer": "I couldn't find any relevant information to answer your question.",
+                 "references": [],
+                 "metadata": {"num_retrieved": 0, "num_reranked": 0, "retrieval_strategy": retrieval_strategy_n},
+             }
+             return
+
+         # Step 2: rerank (default strategy only; balanced already reranks per source)
+         if retrieval_strategy_n == "balanced_multi_source":
+             reranked_docs = sorted(docs, key=lambda d: float(d.get("score", 0.0)), reverse=True)
+             num_retrieved = int(sum((retrieval_debug.get("candidate_counts") or {}).values())) if retrieval_debug else len(docs)
+             retrieved_label = "Recalled (candidates)"
+             reranked_label = "Selected (after rerank)"
+         else:
+             reranked_docs = self.reranker_service.rerank(
+                 query=q, documents=docs, top_k=self.config.get_int("reranker.top_k", 10)
+             )
+             num_retrieved = len(docs)
+             retrieved_label = "Retrieved"
+             reranked_label = "After Reranking"
+
+         # Step 3: citations + context + prompt
+         citation_indices = self.citation_manager.add_documents(reranked_docs)
+         context = self._format_context(reranked_docs, citation_indices)
+         prompt_text = self.prompt.format(context=context, question=q)
+
+         # Step 4: stream LLM
+         answer_parts: List[str] = []
+         buffered = ""
+         last_yield = time.monotonic()
+
+         try:
+             for chunk in self.llm.stream(prompt_text):
+                 delta = getattr(chunk, "content", "") or ""
+                 if not delta:
+                     continue
+                 answer_parts.append(delta)
+                 buffered += delta
+                 now = time.monotonic()
+                 if (now - last_yield) >= float(stream_yield_interval_s) or len(buffered) >= int(stream_min_chars):
+                     yield {"type": "answer", "answer": "".join(answer_parts)}
+                     buffered = ""
+                     last_yield = now
+
+             answer = "".join(answer_parts).strip()
+             if not answer:
+                 response = self.llm.invoke(prompt_text)
+                 answer = (response.content or "").strip()
+         except Exception as e:
+             logger.error(f"LLM streaming failed: {e}", exc_info=True)
+             try:
+                 response = self.llm.invoke(prompt_text)
+                 answer = (response.content or "").strip()
+             except Exception:
+                 answer = "An error occurred while generating the answer. Please try again."
+
+         is_valid, invalid = self.citation_manager.validate_citations(answer)
+         elapsed = time.time() - start_time
+         source_dist = self.citation_manager.get_statistics().get("source_type_counts", {})
+
+         yield {
+             "type": "final",
+             "answer": answer,
+             "references": reranked_docs,
+             "metadata": {
+                 "retrieval_strategy": retrieval_strategy_n,
+                 "num_retrieved": num_retrieved,
+                 "num_reranked": len(reranked_docs),
+                 "retrieved_label": retrieved_label,
+                 "reranked_label": reranked_label,
+                 "citations_valid": is_valid,
+                 "invalid_citations": invalid,
+                 "source_type_distribution": source_dist,
+                 "elapsed_time": elapsed,
+                 "candidate_counts": (retrieval_debug.get("candidate_counts") if retrieval_debug else None),
+                 "gated_counts": (retrieval_debug.get("gated_counts") if retrieval_debug else None),
+                 "selected_counts": (retrieval_debug.get("selected_counts") if retrieval_debug else None),
+             },
+         }
+
+     def query(
+         self,
+         *,
+         question: str,
+         top_k: Optional[int] = None,
+         source_filters: Optional[List[str]] = None,
+         retrieval_strategy: Optional[str] = None,
+     ) -> Dict[str, Any]:
+         """Non-streaming convenience wrapper."""
+         final: Dict[str, Any] = {}
+         for event in self.query_stream(
+             question=question,
+             top_k=top_k,
+             source_filters=source_filters,
+             retrieval_strategy=retrieval_strategy,
+             stream_yield_interval_s=999.0,
+             stream_min_chars=10**9,
+         ):
+             if event.get("type") == "final":
+                 final = event
+         return final
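A caller of `query_stream` sees two event shapes: cumulative `"answer"` events while the LLM streams, then one `"final"` event carrying the full answer, references, and metadata. This sketch stubs the stream with a fake generator (the `fake_query_stream` and `consume` names are illustrative, not part of the module) to show the intended consumption pattern:

```python
from typing import Any, Dict, Iterator, Optional, Tuple

def fake_query_stream() -> Iterator[Dict[str, Any]]:
    # Stub emitting events in the same shape RAGEngine.query_stream yields:
    # "answer" events carry the cumulative text so far, "final" carries everything.
    yield {"type": "answer", "answer": "Ground-glass opacity is"}
    yield {"type": "answer", "answer": "Ground-glass opacity is a hazy increase in lung attenuation"}
    yield {
        "type": "final",
        "answer": "Ground-glass opacity is a hazy increase in lung attenuation [1]",
        "references": [],
        "metadata": {"num_reranked": 0},
    }

def consume(stream: Iterator[Dict[str, Any]]) -> Tuple[str, Optional[Dict[str, Any]]]:
    partial, final = "", None
    for event in stream:
        if event["type"] == "answer":
            partial = event["answer"]  # cumulative, so just overwrite
        elif event["type"] == "final":
            final = event
    return partial, final

partial, final = consume(fake_query_stream())
```

Because each `"answer"` event repeats the full text so far, a UI can simply overwrite its display rather than concatenate deltas.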
radiology_rag/reranker.py ADDED
@@ -0,0 +1,143 @@
+ """
+ Reranker client (OpenAI-compatible API).
+
+ Expected endpoint:
+ POST {base_url}/rerank
+ {
+     "model": "...",
+     "query": "...",
+     "documents": ["...", "..."],
+     "top_n": 10
+ }
+ """
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ import time
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional
+
+ import requests
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass(frozen=True)
+ class RerankerConfig:
+     enabled: bool
+     base_url: str
+     api_key: str
+     model_name: str
+     top_k: int = 10
+
+
+ class NoOpReranker:
+     def rerank(self, query: str, documents: List[Dict[str, Any]], top_k: Optional[int] = None) -> List[Dict[str, Any]]:
+         k = int(top_k or len(documents))
+         out = []
+         for d in documents[:k]:
+             dc = dict(d)
+             dc.setdefault("score", 0.0)
+             out.append(dc)
+         return out
+
+
+ class APIReranker:
+     def __init__(self, cfg: RerankerConfig):
+         self.cfg = cfg
+
+     def rerank(self, query: str, documents: List[Dict[str, Any]], top_k: Optional[int] = None) -> List[Dict[str, Any]]:
+         if not documents:
+             return []
+
+         texts = [doc.get("content", "") for doc in documents]
+         k = int(top_k or self.cfg.top_k or len(documents))
+         k = max(1, min(k, len(documents)))
+
+         max_docs_per_request = int(os.getenv("RERANK_MAX_DOCS_PER_REQUEST", "64"))
+         max_docs_per_request = max(1, max_docs_per_request)
+
+         def _call_once(doc_texts: List[str], top_n: int) -> dict:
+             last_err: Optional[Exception] = None
+             for attempt in range(3):
+                 try:
+                     resp = requests.post(
+                         f"{self.cfg.base_url.rstrip('/')}/rerank",
+                         json={
+                             "model": self.cfg.model_name,
+                             "query": query,
+                             "documents": doc_texts,
+                             "top_n": int(top_n),
+                         },
+                         headers={"Authorization": f"Bearer {self.cfg.api_key}"},
+                         timeout=30,
+                     )
+                     resp.raise_for_status()
+                     return resp.json() or {}
+                 except Exception as e:
+                     last_err = e
+                     if attempt < 2:
+                         time.sleep(0.5 * (attempt + 1))
+             raise last_err or RuntimeError("Unknown reranker API error")
+
+         try:
+             if len(texts) <= max_docs_per_request:
+                 result = _call_once(texts, top_n=k)
+                 reranked_docs: List[Dict[str, Any]] = []
+                 for item in result.get("results", []) or []:
+                     idx = item.get("index")
+                     score = float(item.get("relevance_score", 0.0) or 0.0)
+                     if idx is None:
+                         continue
+                     idx = int(idx)
+                     if idx < 0 or idx >= len(documents):
+                         continue
+                     dc = dict(documents[idx])
+                     dc["score"] = score
+                     reranked_docs.append(dc)
+                 return reranked_docs
+
+             # Chunked: score all docs per chunk, then sort globally.
+             scored: List[Dict[str, Any]] = []
+             for offset in range(0, len(texts), max_docs_per_request):
+                 chunk_texts = texts[offset : offset + max_docs_per_request]
+                 result = _call_once(chunk_texts, top_n=len(chunk_texts))
+                 for item in result.get("results", []) or []:
+                     idx = item.get("index")
+                     if idx is None:
+                         continue
+                     global_idx = offset + int(idx)
+                     if global_idx < 0 or global_idx >= len(documents):
+                         continue
+                     score = float(item.get("relevance_score", 0.0) or 0.0)
+                     dc = dict(documents[global_idx])
+                     dc["score"] = score
+                     scored.append(dc)
+
+             scored.sort(key=lambda x: float(x.get("score", 0.0)), reverse=True)
+             return scored[:k]
+         except Exception as e:
+             logger.warning(f"Reranker API failed; falling back to no-op ordering. Error: {e}")
+             return NoOpReranker().rerank(query, documents, top_k=k)
+
+
+ class RerankerService:
+     """High-level reranker wrapper (auto-disables if misconfigured)."""
+
+     def __init__(self, cfg: RerankerConfig):
+         self.cfg = cfg
+         if not cfg.enabled:
+             self._impl = NoOpReranker()
+             return
+         if not (cfg.api_key or "").strip():
+             logger.warning("Reranker enabled but RERANK_API_KEY is empty; disabling reranker.")
+             self._impl = NoOpReranker()
+             return
+         self._impl = APIReranker(cfg)
+
+     def rerank(self, query: str, documents: List[Dict[str, Any]], top_k: Optional[int] = None) -> List[Dict[str, Any]]:
+         return self._impl.rerank(query, documents, top_k=top_k)
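The subtle part of the chunked path above is index bookkeeping: each API call only sees a slice of the documents, so the `"index"` it returns is local to that slice and must be shifted by the slice offset before it can address the original list. A small sketch of that mapping (the `map_chunk_indices` helper is illustrative, not part of the module):

```python
from typing import List, Tuple

def map_chunk_indices(num_docs: int, chunk_size: int) -> List[Tuple[int, int, int]]:
    # For every document, record (offset, local_index, global_index),
    # where global_index = offset + local_index -- the same shift
    # APIReranker applies to each chunked rerank response.
    mapping: List[Tuple[int, int, int]] = []
    for offset in range(0, num_docs, chunk_size):
        for local_idx in range(min(chunk_size, num_docs - offset)):
            mapping.append((offset, local_idx, offset + local_idx))
    return mapping

map_chunk_indices(5, 2)
# chunks are [0:2], [2:4], [4:5]; local indices restart at 0 in each chunk
```

Getting this shift wrong would silently attach scores to the wrong documents, which is why the code also bounds-checks `global_idx` against the original list.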
radiology_rag/retrieval.py ADDED
@@ -0,0 +1,403 @@
+ """Vector retrieval + balanced multi-source retrieval (with Wikipedia)."""
+
+ from __future__ import annotations
+
+ import logging
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
+
+ from langchain_chroma import Chroma
+
+ from radiology_rag.config import Config
+ from radiology_rag.doc_store import PersistentDocStore
+ from radiology_rag.embedding import EmbeddingClient, EmbeddingConfig
+ from radiology_rag.encyclopedia import WikipediaConfig, WikipediaEncyclopediaService
+ from radiology_rag.reranker import RerankerService, RerankerConfig
+
+ logger = logging.getLogger(__name__)
+
+
+ def _dedupe_by_doc_id(docs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+     seen = set()
+     out = []
+     for d in docs:
+         did = d.get("doc_id")
+         if not did or did in seen:
+             continue
+         seen.add(did)
+         out.append(d)
+     return out
+
+
+ class RetrievalService:
+     """Retrieve parent documents from Chroma + the SQLite doc store."""
+
+     def __init__(self, config: Config):
+         self.config = config
+
+         vector_db_path = config.get_str("storage.vector_db_path")
+         doc_store_path = config.get_str("storage.doc_store_path")
+
+         self.embedding_client = EmbeddingClient(
+             EmbeddingConfig(
+                 base_url=config.get_str("embedding.api_base_url"),
+                 api_key=config.get_str("embedding.api_key"),
+                 model_name=config.get_str("embedding.model_name"),
+                 batch_size=config.get_int("embedding.batch_size", 32),
+             )
+         )
+
+         self.doc_store = PersistentDocStore(doc_store_path, read_only=True)
+         self.vectorstore = Chroma(
+             collection_name="radiology_docs",
+             embedding_function=self.embedding_client.langchain_embeddings,
+             persist_directory=vector_db_path,
+         )
+
+         self.search_type = config.get_str("retrieval.search_type", "similarity")
+         self.chunk_fetch_multiplier = config.get_int("retrieval.chunk_fetch_multiplier", 6)
+
+     def retrieve_candidates_by_vector(
+         self,
+         *,
+         query_embedding: List[float],
+         candidate_k: int,
+         source_type_filter: Optional[str] = None,
+     ) -> List[Dict[str, Any]]:
+         k = max(1, int(candidate_k))
+         chunk_k = max(k * int(self.chunk_fetch_multiplier), k)
+
+         chunk_filter = {"source_type": source_type_filter} if source_type_filter else None
+
+         # Fetch chunks from the vector store using the embedding vector
+         if self.search_type == "mmr":
+             mmr_lambda = self.config.get_float("retrieval.mmr_lambda", 0.5)
+             fetch_k = max(self.config.get_int("retrieval.mmr_fetch_k", 50), chunk_k)
+             try:
+                 chunk_docs = self.vectorstore.max_marginal_relevance_search_by_vector(
+                     query_embedding,
+                     k=chunk_k,
+                     fetch_k=fetch_k,
+                     lambda_mult=mmr_lambda,
+                     filter=chunk_filter,
+                 )
+             except TypeError:
+                 chunk_docs = self.vectorstore.max_marginal_relevance_search_by_vector(
+                     query_embedding,
+                     k=chunk_k,
+                     fetch_k=fetch_k,
+                     lambda_mult=mmr_lambda,
+                 )
+         else:
+             try:
+                 chunk_docs = self.vectorstore.similarity_search_by_vector(
+                     query_embedding,
+                     k=chunk_k,
+                     filter=chunk_filter,
+                 )
+             except TypeError:
+                 chunk_docs = self.vectorstore.similarity_search_by_vector(
+                     query_embedding,
+                     k=chunk_k,
+                 )
+
+         # Unique parent IDs
+         parent_ids: List[str] = []
+         seen = set()
+         for doc in chunk_docs:
+             parent_id = (doc.metadata or {}).get("parent_id")
+             if not parent_id or parent_id in seen:
+                 continue
+             seen.add(parent_id)
+             parent_ids.append(parent_id)
+             if len(parent_ids) >= k:
+                 break
+
+         # Hydrate parents from the doc store
+         parent_docs = self.doc_store.mget(parent_ids)
+         results: List[Dict[str, Any]] = []
+         for doc_id, doc_content in zip(parent_ids, parent_docs):
+             if doc_content is None:
+                 continue
+             complete = doc_content.get("complete_document", {}) or {}
+             results.append(
+                 {
+                     "doc_id": doc_id,
+                     "source_type": doc_content.get("source_type", "") or "",
+                     "title": complete.get("title", "") or "",
+                     "content": doc_content.get("main_content", "") or "",
+                     "url": complete.get("url", "") or "",
+                     "metadata": complete.get("metadata", {}) or {},
+                     "score": 0.0,
+                 }
+             )
+
+         logger.info(
+             f"Retrieved {len(results)} candidate parents (candidate_k={k}, source_type_filter={source_type_filter or 'ALL'})"
+         )
+         return results
+
+     def retrieve(
+         self,
+         *,
+         query: str,
+         top_k: int,
+         source_filters: Optional[List[str]] = None,
+     ) -> List[Dict[str, Any]]:
+         q = (query or "").strip()
+         if not q:
+             return []
+
+         query_embedding = self.embedding_client.embed_query(q)
+         k = max(1, int(top_k))
+
+         if not source_filters:
+             return self.retrieve_candidates_by_vector(query_embedding=query_embedding, candidate_k=k)
+
+         allowed = [s for s in source_filters if s in {"article", "case", "tutorial"}]
+         if not allowed:
+             # If only the encyclopedia is selected, local retrieval yields nothing.
+             return []
+
+         if len(allowed) == 1:
+             return self.retrieve_candidates_by_vector(query_embedding=query_embedding, candidate_k=k, source_type_filter=allowed[0])
+
+         merged: List[Dict[str, Any]] = []
+         for st in allowed:
+             merged.extend(
+                 self.retrieve_candidates_by_vector(query_embedding=query_embedding, candidate_k=k, source_type_filter=st)
+             )
+         merged = _dedupe_by_doc_id(merged)
+         return merged[:k]
+
+     def get_document_by_id(self, doc_id: str) -> Optional[Dict[str, Any]]:
+         doc_id = (doc_id or "").strip()
+         if not doc_id:
+             return None
+         docs = self.doc_store.mget([doc_id])
+         if not docs or not docs[0]:
+             return None
+         doc_content = docs[0]
+         complete = doc_content.get("complete_document", {}) or {}
+         return {
+             "doc_id": doc_id,
+             "source_type": doc_content.get("source_type", "") or "",
+             "title": complete.get("title", "") or "",
+             "content": doc_content.get("main_content", "") or "",
+             "url": complete.get("url", "") or "",
+             "metadata": complete.get("metadata", {}) or {},
+         }
+
+
+ @dataclass
+ class PerSourcePolicy:
+     candidate_k: int
+     max_k: int
+     min_score: float
+     required: bool = False
+
+
+ @dataclass
+ class BalancedMultiSourcePolicy:
+     total_top_k: int = 8
+     sources_priority: Sequence[str] = ("article", "case", "encyclopedia", "tutorial")
+     article: PerSourcePolicy = field(
+         default_factory=lambda: PerSourcePolicy(candidate_k=200, max_k=3, min_score=0.15, required=True)
+     )
+     case: PerSourcePolicy = field(
+         default_factory=lambda: PerSourcePolicy(candidate_k=200, max_k=3, min_score=0.15, required=True)
+     )
+     encyclopedia: PerSourcePolicy = field(
+         default_factory=lambda: PerSourcePolicy(candidate_k=8, max_k=2, min_score=0.15, required=True)
+     )
+     tutorial: PerSourcePolicy = field(
+         default_factory=lambda: PerSourcePolicy(candidate_k=20, max_k=2, min_score=0.50, required=False)
+     )
+
+
+ class MultiSourceRetrievalService:
+     """Per-source recall + per-source rerank + gating + merge (includes Wikipedia)."""
+
+     def __init__(
+         self,
+         config: Config,
+         retrieval_service: Optional[RetrievalService] = None,
+         reranker_service: Optional[RerankerService] = None,
+         encyclopedia_service: Optional[WikipediaEncyclopediaService] = None,
+     ):
+         self.config = config
+         self.retrieval_service = retrieval_service or RetrievalService(config)
+
+         rr_cfg = RerankerConfig(
+             enabled=config.get_bool("reranker.enabled", True),
+             base_url=config.get_str("reranker.api_base_url"),
+             api_key=config.get_str("reranker.api_key"),
+             model_name=config.get_str("reranker.model_name"),
+             top_k=config.get_int("reranker.top_k", 10),
+         )
+         self.reranker_service = reranker_service or RerankerService(rr_cfg)
+
+         wiki_cfg = WikipediaConfig(
+             language=config.get_str("encyclopedia.wikipedia.language", "en"),
+             user_agent=config.get_str("encyclopedia.wikipedia.user_agent", "RadiologyRAG-Space/1.0"),
+             timeout_s=config.get_int("encyclopedia.wikipedia.timeout_s", 15),
+             max_chars_per_doc=config.get_int("encyclopedia.wikipedia.max_chars_per_doc", 2000),
+         )
+         self.encyclopedia_service = encyclopedia_service or WikipediaEncyclopediaService(wiki_cfg)
+
+     def _load_policy(self, total_top_k: Optional[int] = None) -> BalancedMultiSourcePolicy:
+         total = int(total_top_k or self.config.get_int("retrieval.multi_source.total_top_k", 8))
+         total = max(1, total)
+
+         def pol(name: str, default: PerSourcePolicy) -> PerSourcePolicy:
+             base = f"retrieval.multi_source.{name}"
+             return PerSourcePolicy(
+                 candidate_k=self.config.get_int(f"{base}.candidate_k", default.candidate_k),
+                 max_k=self.config.get_int(f"{base}.max_k", default.max_k),
+                 min_score=self.config.get_float(f"{base}.min_score", default.min_score),
+                 required=self.config.get_bool(f"{base}.required", default.required),
+             )
+
+         defaults = BalancedMultiSourcePolicy(total_top_k=total)
+         sources_priority = self.config.get("retrieval.multi_source.sources_priority", defaults.sources_priority)
+         sources_priority = tuple(sources_priority) if isinstance(sources_priority, (list, tuple)) else defaults.sources_priority
+
+         return BalancedMultiSourcePolicy(
+             total_top_k=total,
+             sources_priority=sources_priority,
+             article=pol("article", defaults.article),
+             case=pol("case", defaults.case),
+             encyclopedia=pol("encyclopedia", defaults.encyclopedia),
+             tutorial=pol("tutorial", defaults.tutorial),
+         )
+
+     @staticmethod
+     def _gate_and_trim(docs: List[Dict[str, Any]], policy: PerSourcePolicy) -> List[Dict[str, Any]]:
+         filtered = [d for d in docs if float(d.get("score", 0.0)) >= float(policy.min_score)]
+         return filtered[: max(0, int(policy.max_k))]
+
+     def retrieve(
+         self,
+         *,
+         query: str,
+         total_top_k: Optional[int] = None,
+         source_filters: Optional[List[str]] = None,
+         return_debug: bool = False,
+     ) -> Any:
+         q = (query or "").strip()
+         if not q:
+             return ([], {}) if return_debug else []
+
+         policy = self._load_policy(total_top_k=total_top_k)
+         allowed = set(source_filters) if source_filters else set(policy.sources_priority)
+         allowed = {s for s in allowed if s in {"article", "case", "tutorial", "encyclopedia"}}
+         if not allowed:
+             allowed = set(policy.sources_priority)
+
+         # Compute the query embedding once for all local sources
+         needs_local = any(s in allowed for s in {"article", "case", "tutorial"})
+         query_embedding = self.retrieval_service.embedding_client.embed_query(q) if needs_local else None
+
+         # 1) Recall candidates
+         candidates: Dict[str, List[Dict[str, Any]]] = {}
+         if "article" in allowed and query_embedding is not None:
+             candidates["article"] = self.retrieval_service.retrieve_candidates_by_vector(
+                 query_embedding=query_embedding, candidate_k=policy.article.candidate_k, source_type_filter="article"
+             )
+         if "case" in allowed and query_embedding is not None:
+             candidates["case"] = self.retrieval_service.retrieve_candidates_by_vector(
+                 query_embedding=query_embedding, candidate_k=policy.case.candidate_k, source_type_filter="case"
+             )
+         if "tutorial" in allowed and query_embedding is not None:
+             candidates["tutorial"] = self.retrieval_service.retrieve_candidates_by_vector(
+                 query_embedding=query_embedding, candidate_k=policy.tutorial.candidate_k, source_type_filter="tutorial"
+             )
+         if "encyclopedia" in allowed:
+             candidates["encyclopedia"] = self.encyclopedia_service.retrieve(
+                 q,
+                 top_k=policy.encyclopedia.candidate_k,
+                 max_chars_per_doc=self.config.get_int("encyclopedia.wikipedia.max_chars_per_doc", 2000),
+             )
+
+         # 2) Rerank per source (full ordering; trim later)
+         reranked: Dict[str, List[Dict[str, Any]]] = {}
+         for src, docs in candidates.items():
+             if not docs:
+                 reranked[src] = []
+                 continue
+             rr = self.reranker_service.rerank(query=q, documents=docs, top_k=len(docs))
+             reranked[src] = rr
+
+         # 3) Gating
+         gated: Dict[str, List[Dict[str, Any]]] = {}
+         if "article" in reranked:
+             gated["article"] = self._gate_and_trim(reranked["article"], policy.article)
+         if "case" in reranked:
+             gated["case"] = self._gate_and_trim(reranked["case"], policy.case)
+         if "encyclopedia" in reranked:
+             gated["encyclopedia"] = self._gate_and_trim(reranked["encyclopedia"], policy.encyclopedia)
+
+         if "tutorial" in reranked:
+             tdocs = reranked["tutorial"]
+             best = float(tdocs[0].get("score", 0.0)) if tdocs else 0.0
+             if best >= policy.tutorial.min_score:
+                 gated["tutorial"] = tdocs[: max(0, int(policy.tutorial.max_k))]
+             else:
+                 gated["tutorial"] = []
+
+         # 4) Merge under the global budget
+         selected: List[Dict[str, Any]] = []
+         selected_ids = set()
+
+         def _add(doc: Dict[str, Any]) -> None:
+             did = doc.get("doc_id")
+             if not did or did in selected_ids:
+                 return
+             selected.append(doc)
+             selected_ids.add(did)
+
+         src_to_pol = {
+             "article": policy.article,
+             "case": policy.case,
+             "encyclopedia": policy.encyclopedia,
+             "tutorial": policy.tutorial,
+         }
+
+         # Required pass: give each required source its best doc first
+         for src in policy.sources_priority:
+             if src not in allowed:
+                 continue
+             p = src_to_pol[src]
+             if not p.required:
+                 continue
+             if gated.get(src):
+                 _add(gated[src][0])
+
+         # Fill the remaining budget by global score
+         remaining_pool: List[Dict[str, Any]] = []
+         for docs in gated.values():
+             for d in docs:
+                 if d.get("doc_id") not in selected_ids:
+                     remaining_pool.append(d)
+         remaining_pool.sort(key=lambda x: float(x.get("score", 0.0)), reverse=True)
+
+         for d in remaining_pool:
+             if len(selected) >= policy.total_top_k:
+                 break
+             _add(d)
+
+         debug = {
+             "allowed_sources": sorted(list(allowed)),
+             "candidate_counts": {k: len(v) for k, v in candidates.items()},
+             "gated_counts": {k: len(v) for k, v in gated.items()},
+             "selected_counts": {},
+         }
+         for d in selected:
+             st = d.get("source_type", "unknown")
+             debug["selected_counts"][st] = debug["selected_counts"].get(st, 0) + 1
398
+
399
+ if return_debug:
400
+ return selected, debug
401
+ return selected
402
+
403
+
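The two-pass merge above (a "required" pass that guarantees each required source its top document, then a fill pass by global score up to `total_top_k`) can be sketched in isolation. This is a minimal illustration, not the actual method: `merge_with_budget`, its plain-dict `required` map, and the `priority` list argument are simplifications of the `policy` object used in the real code.

```python
from typing import Any, Dict, List


def merge_with_budget(
    gated: Dict[str, List[Dict[str, Any]]],
    priority: List[str],
    required: Dict[str, bool],
    total_top_k: int,
) -> List[Dict[str, Any]]:
    """Two-pass merge: one doc per required source first, then fill by score."""
    selected: List[Dict[str, Any]] = []
    seen = set()

    def _add(doc: Dict[str, Any]) -> None:
        did = doc.get("doc_id")
        if not did or did in seen:
            return
        selected.append(doc)
        seen.add(did)

    # Required pass: each required source contributes its top-ranked document.
    for src in priority:
        if required.get(src) and gated.get(src):
            _add(gated[src][0])

    # Fill pass: remaining slots go to the globally highest-scoring documents.
    pool = [d for docs in gated.values() for d in docs if d.get("doc_id") not in seen]
    pool.sort(key=lambda d: float(d.get("score", 0.0)), reverse=True)
    for d in pool:
        if len(selected) >= total_top_k:
            break
        _add(d)
    return selected
```

The point of the required pass is that a lower-scoring but required source (e.g. a case report) cannot be crowded out entirely by a single high-scoring source type.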
radiology_rag/ui.py ADDED
@@ -0,0 +1,423 @@
+ """Gradio UI for the Radiology RAG Space."""
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ import re
+ import time
+ from pathlib import Path
+ from typing import Any, Dict, Iterator, List, Optional, Tuple
+
+ import gradio as gr
+
+ from radiology_rag.config import Config
+ from radiology_rag.index_bootstrap import ensure_index, read_manifest
+ from radiology_rag.rag import RAGEngine
+
+ logger = logging.getLogger(__name__)
+
+
+ def _truncate(text: str, max_len: int) -> str:
+     s = (text or "").strip()
+     if len(s) <= max_len:
+         return s
+     return s[: max(0, max_len - 3)] + "..."
+
+
+ def format_error_message(error: str) -> str:
+     return f"**⚠️ Error**\n\n{error}"
+
+
+ def format_loading_message() -> str:
+     return "**🔄 Processing your query...**\n\nRetrieving relevant sources and generating an answer with citations."
+
+
+ def format_reference_card(doc: Dict[str, Any], index: int) -> str:
+     title = doc.get("title", "Untitled") or "Untitled"
+     source_type = (doc.get("source_type") or "").upper()
+     url = doc.get("url", "") or ""
+     content = doc.get("content", "") or ""
+     score = float(doc.get("score", 0.0) or 0.0)
+
+     max_preview_length = 350
+     preview = _truncate(content, max_preview_length).replace("\n", " ")
+
+     type_colors = {
+         "ARTICLE": "#3b82f6",
+         "CASE": "#10b981",
+         "TUTORIAL": "#f59e0b",
+         "ENCYCLOPEDIA": "#8b5cf6",
+     }
+     color = type_colors.get(source_type, "#6b7280")
+
+     score_html = f"<span style='color:#6b7280;font-size:12px;'>Score: {score:.3f}</span>" if score > 0 else ""
+     url_html = (
+         f"<p style='margin:0 0 8px 0;font-size:12px;'><a href='{url}' target='_blank' "
+         f"style='color:#3b82f6;text-decoration:none;'>🔗 View Source</a></p>"
+         if url
+         else ""
+     )
+
+     return f"""
+     <div id="ref-{index}" style="border:1px solid #e5e7eb;border-radius:8px;padding:16px;margin-bottom:16px;background:white;scroll-margin-top:90px;">
+         <div style="display:flex;align-items:center;gap:8px;margin-bottom:12px;flex-wrap:wrap;">
+             <span style="background:{color};color:white;padding:4px 12px;border-radius:12px;font-size:12px;font-weight:600;">
+                 {source_type or "SOURCE"}
+             </span>
+             <span style="background:#f3f4f6;color:#374151;padding:4px 12px;border-radius:12px;font-size:12px;font-weight:600;">
+                 [{index}]
+             </span>
+             {score_html}
+         </div>
+         <h3 style="margin:0 0 8px 0;color:#111827;font-size:18px;">{title}</h3>
+         {url_html}
+         <p style="margin:0;color:#4b5563;font-size:14px;line-height:1.5;">{preview}</p>
+     </div>
+     """
+
+
+ def format_reference_panel(references: List[Dict[str, Any]]) -> str:
+     if not references:
+         return "<p style='color:#6b7280;text-align:center;padding:20px;'>No references available</p>"
+     html_parts = ['<div style="max-height: 600px; overflow-y: auto;">']
+     for i, doc in enumerate(references, 1):
+         html_parts.append(format_reference_card(doc, i))
+     html_parts.append("</div>")
+     return "".join(html_parts)
+
+
+ def format_statistics(metadata: Dict[str, Any]) -> str:
+     num_retrieved = int(metadata.get("num_retrieved", 0) or 0)
+     num_reranked = int(metadata.get("num_reranked", 0) or 0)
+     source_dist = metadata.get("source_type_distribution", {}) or {}
+     retrieved_label = metadata.get("retrieved_label", "Retrieved")
+     reranked_label = metadata.get("reranked_label", "After Reranking")
+     elapsed = float(metadata.get("elapsed_time", 0.0) or 0.0)
+     strategy = metadata.get("retrieval_strategy", "")
+
+     chips = "".join(
+         [
+             f"<span style='display:inline-block;background:#e5e7eb;color:#111827;padding:4px 8px;border-radius:4px;margin-right:8px;font-size:12px;line-height:1.2;'>{k}: {v}</span>"
+             for k, v in source_dist.items()
+         ]
+     )
+
+     return f"""
+     <div style="background:#f9fafb;padding:16px;border-radius:8px;margin-top:16px;">
+         <h4 style="margin:0 0 12px 0;color:#374151;font-size:14px;">📊 Query Statistics</h4>
+         <div style="display:grid;grid-template-columns:repeat(auto-fit,minmax(150px,1fr));gap:12px;">
+             <div>
+                 <p style="margin:0;color:#6b7280;font-size:12px;">{retrieved_label}</p>
+                 <p style="margin:0;color:#111827;font-size:20px;font-weight:600;">{num_retrieved}</p>
+             </div>
+             <div>
+                 <p style="margin:0;color:#6b7280;font-size:12px;">{reranked_label}</p>
+                 <p style="margin:0;color:#111827;font-size:20px;font-weight:600;">{num_reranked}</p>
+             </div>
+             <div>
+                 <p style="margin:0;color:#6b7280;font-size:12px;">Elapsed</p>
+                 <p style="margin:0;color:#111827;font-size:20px;font-weight:600;">{elapsed:.2f}s</p>
+             </div>
+         </div>
+         <div style="margin-top:12px;">
+             <p style="margin:0 0 6px 0;color:#6b7280;font-size:12px;">Retrieval Strategy: <code>{strategy}</code></p>
+             <p style="margin:0 0 4px 0;color:#6b7280;font-size:12px;">Source Distribution:</p>
+             {chips if chips else "<span style='color:#6b7280;font-size:12px;'>N/A</span>"}
+         </div>
+     </div>
+     """
+
+
+ def create_settings_accordion(
+     *,
+     default_strategy: str,
+     default_temperature: float,
+     default_sources: List[str],
+ ) -> Tuple[gr.Radio, gr.Slider, gr.CheckboxGroup]:
+     with gr.Accordion("⚙️ Advanced Settings", open=False):
+         gr.Markdown(
+             "#### Retrieval Strategy\n"
+             "- **default**: one mixed retrieval + single rerank (fast)\n"
+             "- **balanced_multi_source**: per-source recall + per-source rerank + Wikipedia (more diverse)\n"
+         )
+
+         retrieval_strategy = gr.Radio(
+             choices=["default", "balanced_multi_source"],
+             value=default_strategy,
+             label="Retrieval Strategy",
+         )
+
+         temperature_slider = gr.Slider(
+             minimum=0.0,
+             maximum=1.0,
+             value=float(default_temperature),
+             step=0.1,
+             label="LLM Temperature",
+         )
+
+         source_filter = gr.CheckboxGroup(
+             choices=["article", "case", "tutorial", "encyclopedia"],
+             value=default_sources,
+             label="Filter by Source Type",
+         )
+
+     return retrieval_strategy, temperature_slider, source_filter
+
+
+ class RadiologyRAGApp:
+     def __init__(self, config_path: str):
+         self.config = Config(config_path)
+         self.startup_error: Optional[str] = None
+         self.startup_warnings: List[str] = []
+         self.index_manifest: Optional[Dict[str, Any]] = None
+         self.rag_engine: Optional[RAGEngine] = None
+
+         # Validate required secrets
+         missing: List[str] = []
+         if not self.config.get_str("embedding.api_key"):
+             missing.append("EMBED_API_KEY")
+         if not self.config.get_str("llm.api_key"):
+             missing.append("LLM_API_KEY")
+         if missing:
+             self.startup_error = (
+                 "Missing required Hugging Face Space Secrets: "
+                 + ", ".join([f"`{m}`" for m in missing])
+                 + ".\n\nPlease set them in the Space **Settings → Secrets** and restart the Space."
+             )
+             return
+
+         # Reranker is optional; warn if enabled but the key is missing
+         if self.config.get_bool("reranker.enabled", True) and not self.config.get_str("reranker.api_key"):
+             self.startup_warnings.append(
+                 "Reranker is enabled but `RERANK_API_KEY` is missing. Reranking will be disabled (fallback to no-op)."
+             )
+
+         # Ensure the index exists locally (download if needed)
+         try:
+             idx = ensure_index(
+                 repo_id=self.config.get_str("index.repo_id"),
+                 revision=self.config.get_str("index.revision", "main") or None,
+                 target_vector_db_path=self.config.get_str("storage.vector_db_path"),
+                 target_doc_store_path=self.config.get_str("storage.doc_store_path"),
+                 storage_dir=str(Path(self.config.get_str("storage.doc_store_path")).parent),
+             )
+             self.index_manifest = read_manifest(idx.manifest_path)
+
+             # Optional: warn if the embedding model differs
+             if self.index_manifest:
+                 idx_model = (
+                     (self.index_manifest.get("embedding") or {}).get("model_name")
+                     or self.index_manifest.get("embedding_model")
+                     or ""
+                 )
+                 cfg_model = self.config.get_str("embedding.model_name")
+                 if idx_model and cfg_model and idx_model != cfg_model:
+                     self.startup_warnings.append(
+                         f"Index embedding model mismatch: index='{idx_model}' vs config='{cfg_model}'. "
+                         "For best results, rebuild the index with the same embedding model."
+                     )
+
+         except Exception as e:
+             # Try to provide actionable guidance for common HF Hub errors.
+             repo_id = self.config.get_str("index.repo_id")
+             try:
+                 from huggingface_hub.utils import (  # type: ignore
+                     GatedRepoError,
+                     HfHubHTTPError,
+                     RepositoryNotFoundError,
+                 )
+
+                 if isinstance(e, RepositoryNotFoundError):
+                     self.startup_error = (
+                         f"Index dataset repo not found: `{repo_id}`.\n\n"
+                         "If you haven't uploaded the prebuilt index yet, build and publish it locally:\n"
+                         "1) `pip install -r requirements-dev.txt`\n"
+                         "2) `python scripts/build_vector_db.py --config config/default_config.yaml --source huggingface --dataset ZhangNy/radiology-dataset --output-dir ./index_out`\n"
+                         f"3) `python scripts/publish_index_to_hf.py --repo {repo_id} --folder ./index_out --token $HF_TOKEN`\n\n"
+                         "Or set `RAG_INDEX_REPO_ID` to an existing index repo."
+                     )
+                     return
+                 if isinstance(e, GatedRepoError):
+                     self.startup_error = (
+                         f"Index dataset repo is gated/private: `{repo_id}`.\n\n"
+                         "Make sure the repo is public, or provide authentication (HF token) in the environment."
+                     )
+                     return
+                 if isinstance(e, HfHubHTTPError):
+                     self.startup_error = (
+                         f"Failed to download index from `{repo_id}`.\n\n"
+                         f"HF Hub error: {e}"
+                     )
+                     return
+             except Exception:
+                 # If importing HF-specific exceptions fails, fall back to a generic message.
+                 pass
+
+             self.startup_error = (
+                 f"Failed to prepare index from `{repo_id}`.\n\n"
+                 f"Error: {e}"
+             )
+             return
+
+         # Build RAG engine
+         try:
+             self.rag_engine = RAGEngine(self.config)
+         except Exception as e:
+             self.startup_error = f"Failed to initialize RAG engine: {e}"
+             return
+
+     def process_query(
+         self,
+         question: str,
+         temperature: float,
+         source_filters: List[str],
+         retrieval_strategy: str,
+     ) -> Iterator[Tuple[str, str, str]]:
+         if self.startup_error:
+             yield format_error_message(self.startup_error), "", ""
+             return
+         if self.rag_engine is None:
+             yield format_error_message("RAG engine not initialized."), "", ""
+             return
+
+         q = (question or "").strip()
+         if not q:
+             yield format_error_message("Please enter a question."), "", ""
+             return
+
+         # Update LLM temperature on the fly
+         try:
+             self.rag_engine.llm.temperature = float(temperature)
+         except Exception:
+             pass
+
+         sources = source_filters or []
+         loading_md = (
+             f"{format_loading_message()}\n\n"
+             f"**Retrieval Strategy**: `{retrieval_strategy}`\n\n"
+             f"**Sources**: `{', '.join(sources) if sources else 'ALL'}`"
+         )
+         loading_refs = "<p style='color:#6b7280;text-align:center;padding:20px;'>Retrieving & reranking...</p>"
+         loading_stats = "<p style='color:#6b7280;padding:10px;'>Working...</p>"
+         yield loading_md, loading_refs, loading_stats
+
+         start_time = time.time()
+         last_partial = ""
+
+         try:
+             for event in self.rag_engine.query_stream(
+                 question=q,
+                 source_filters=sources if sources else None,
+                 retrieval_strategy=retrieval_strategy,
+             ):
+                 etype = (event or {}).get("type")
+                 if etype == "answer":
+                     partial = (event.get("answer") or "")
+                     if partial and partial != last_partial:
+                         # Make citations clickable: [1] -> [1](#ref-1)
+                         answer_md = re.sub(r"\[(\d+)\](?!\()", r"[\1](#ref-\1)", partial)
+                         last_partial = partial
+                         yield answer_md, loading_refs, loading_stats
+                 elif etype == "final":
+                     meta = event.get("metadata") or {}
+                     # The engine normally populates elapsed_time; fill it in as a fallback.
+                     meta.setdefault("elapsed_time", time.time() - start_time)
+
+                     final_answer = (event.get("answer") or "")
+                     answer_md = re.sub(r"\[(\d+)\](?!\()", r"[\1](#ref-\1)", final_answer)
+                     references_html = format_reference_panel(event.get("references") or [])
+                     stats_html = format_statistics(meta)
+                     yield answer_md, references_html, stats_html
+                     return
+
+             yield format_error_message("No response was generated. Please try again."), "", ""
+         except Exception as e:
+             logger.error(f"Error processing query: {e}", exc_info=True)
+             yield format_error_message(f"An error occurred: {e}"), "", ""
+
+     def create_interface(self) -> gr.Blocks:
+         title = self.config.get_str("ui.title", "Radiology RAG")
+         description = self.config.get_str("ui.description", "")
+         theme = self.config.get_str("ui.theme", "soft")
+         default_strategy = self.config.get_str("retrieval.strategy", "balanced_multi_source")
+         default_sources = self.config.get("retrieval.source_filters", ["article", "case", "tutorial", "encyclopedia"])
+         if not isinstance(default_sources, list):
+             default_sources = ["article", "case", "tutorial", "encyclopedia"]
+         default_temp = self.config.get_float("llm.temperature", 0.7)
+
+         with gr.Blocks(title=title, theme=theme) as interface:
+             gr.Markdown(f"# {title}")
+             if description:
+                 gr.Markdown(description)
+
+             if self.startup_error:
+                 gr.Markdown(format_error_message(self.startup_error))
+                 gr.Markdown(
+                     "### Required Secrets\n"
+                     "- `EMBED_API_KEY`\n"
+                     "- `LLM_API_KEY`\n\n"
+                     "Optional (recommended):\n"
+                     "- `RERANK_API_KEY`\n"
+                 )
+                 return interface
+
+             if self.startup_warnings:
+                 gr.Markdown("### ⚠️ Startup Warnings")
+                 gr.Markdown("\n".join([f"- {w}" for w in self.startup_warnings]))
+
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     gr.Markdown("### Ask a Question")
+                     question_input = gr.Textbox(
+                         label="Your Question",
+                         placeholder="e.g., What is achalasia and how is it diagnosed?",
+                         lines=3,
+                     )
+
+                     retrieval_strategy, temperature_slider, source_filter = create_settings_accordion(
+                         default_strategy=default_strategy,
+                         default_temperature=default_temp,
+                         default_sources=default_sources,
+                     )
+
+                     submit_btn = gr.Button("Search & Answer", variant="primary", size="lg")
+
+                     gr.Markdown("### Example Questions")
+                     gr.Examples(
+                         examples=[
+                             ["What is achalasia and how is it diagnosed on imaging?"],
+                             ["Explain the imaging findings in Barrett's esophagus"],
+                             ["What are the characteristics of a Zenker's diverticulum?"],
+                             ["Describe the CT findings of esophageal cancer"],
+                         ],
+                         inputs=[question_input],
+                         label="Click an example to try it",
+                     )
+
+                 with gr.Column(scale=2):
+                     gr.Markdown("### Answer (with citations)")
+                     answer_output = gr.Markdown(value="*Your answer will appear here...*")
+                     stats_output = gr.HTML(label="Statistics")
+                     gr.Markdown("### Retrieved References")
+                     references_output = gr.HTML(
+                         value="<p style='color:#6b7280;text-align:center;padding:20px;'>References will appear here...</p>"
+                     )
+
+             submit_btn.click(
+                 fn=self.process_query,
+                 inputs=[question_input, temperature_slider, source_filter, retrieval_strategy],
+                 outputs=[answer_output, references_output, stats_output],
+             )
+
+             gr.Markdown("---")
+             with gr.Accordion("About", open=False):
+                 gr.Markdown(
+                     "This Space demonstrates a radiology RAG system using a prebuilt vector index "
+                     f"(`{self.config.get_str('index.repo_id')}`) and external APIs for embeddings/LLM.\n\n"
+                     "**Disclaimer**: Educational use only. Always consult qualified professionals for clinical decisions."
+                 )
+
+         return interface
+
+
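The citation linkifier used twice in `process_query` above relies on a negative lookahead so that citations which are already Markdown links are left untouched. A minimal standalone sketch (the helper name `linkify_citations` is introduced here for illustration; the app applies the same `re.sub` inline):

```python
import re


def linkify_citations(markdown: str) -> str:
    # [1] -> [1](#ref-1); the (?!\() lookahead skips citations that are
    # already followed by a link target, so the substitution is idempotent.
    return re.sub(r"\[(\d+)\](?!\()", r"[\1](#ref-\1)", markdown)
```

Idempotence matters here because the same transformation runs on every streamed partial answer and again on the final answer.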
requirements-dev.txt ADDED
@@ -0,0 +1,19 @@
+ # Dev / offline build dependencies (NOT used by the Spaces runtime).
+ # Includes dataset loading + optional local embedding/rerank tooling.
+
+ -r requirements.txt
+
+ datasets>=2.16.0
+ tqdm>=4.66.0
+ pillow>=10.0.0
+
+ # Needed for index build (text splitting + Document objects)
+ langchain>=0.1.0
+ langchain-text-splitters>=0.0.1
+
+ # Optional: local models for advanced users (large)
+ sentence-transformers>=2.3.0
+ torch>=2.0.0
+ transformers>=4.36.0
+
+
requirements.txt ADDED
@@ -0,0 +1,21 @@
+ # Runtime dependencies for Hugging Face Spaces (CPU-friendly, API-first).
+ # Keep this minimal to improve build stability.
+
+ # Web UI
+ gradio==4.16.0
+ gradio_client==0.8.1
+
+ # RAG core
+ langchain-core>=0.1.0
+ langchain-openai>=0.0.5
+ langchain-chroma>=0.1.0
+ chromadb>=0.4.22
+
+ # Index download (from HF Hub)
+ huggingface-hub>=0.20.0
+
+ # Utilities
+ pyyaml>=6.0
+ requests>=2.31.0
+
+
scripts/build_vector_db.py ADDED
@@ -0,0 +1,219 @@
+ """
+ Build a Chroma + SQLite index for this RAG system (offline / advanced users).
+
+ The index output folder is compatible with the Space runtime bootstrap:
+     <output_dir>/
+         chroma_db/
+         doc_store.db
+         manifest.json
+
+ Examples:
+
+ 1) Build from an HF dataset directly (streaming is not supported for this save_to_disk-based build):
+     python scripts/build_vector_db.py \
+         --config config/default_config.yaml \
+         --source huggingface \
+         --dataset ZhangNy/radiology-dataset \
+         --output-dir ./index_out
+
+ 2) Build from a local saved dataset:
+     python scripts/build_vector_db.py \
+         --config config/default_config.yaml \
+         --source local \
+         --local-path ./hf_dataset_prepared \
+         --output-dir ./index_out
+
+ Notes:
+ - The embedding model used at build time must match the query-time embeddings used in the Space,
+   otherwise retrieval quality will degrade.
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import os
+ import sys
+ import shutil
+ import time
+ from collections import Counter
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Tuple
+
+ # Allow running as `python scripts/*.py` without installing the package.
+ sys.path.append(str(Path(__file__).resolve().parents[1]))
+
+
+ def _clean_text(text: str) -> str:
+     # Remove markdown hyperlinks: [text](url) -> text
+     import re
+
+     t = re.sub(r"\[(.*?)\]\(.*?\)", r"\1", text or "")
+     return t.replace("\xa0", " ")
+
+
+ def main() -> int:
+     parser = argparse.ArgumentParser(description="Build vector index (Chroma + SQLite doc store)")
+     parser.add_argument("--config", type=str, default="config/default_config.yaml", help="Config YAML path")
+     parser.add_argument("--source", choices=["local", "huggingface"], default="huggingface")
+     parser.add_argument("--local-path", type=str, default=None, help="Path to dataset saved via save_to_disk()")
+     parser.add_argument("--dataset", type=str, default="ZhangNy/radiology-dataset", help="HF dataset repo id")
+     parser.add_argument("--split", type=str, default="train")
+     parser.add_argument("--limit", type=int, default=None, help="Limit number of documents (debug)")
+     parser.add_argument("--output-dir", type=str, default="./index_out", help="Output directory for index artifacts")
+     parser.add_argument("--overwrite", action="store_true", help="Overwrite output dir if exists")
+     args = parser.parse_args()
+
+     from datasets import load_dataset, load_from_disk
+     from langchain_chroma import Chroma
+     from langchain_core.documents import Document
+     from langchain_text_splitters import RecursiveCharacterTextSplitter
+
+     from radiology_rag.config import Config
+     from radiology_rag.doc_store import PersistentDocStore
+     from radiology_rag.embedding import EmbeddingClient, EmbeddingConfig
+
+     cfg = Config(args.config)
+
+     out_dir = Path(args.output_dir)
+     chroma_dir = out_dir / "chroma_db"
+     doc_db = out_dir / "doc_store.db"
+     manifest_path = out_dir / "manifest.json"
+
+     if out_dir.exists() and args.overwrite:
+         shutil.rmtree(out_dir)
+     out_dir.mkdir(parents=True, exist_ok=True)
+
+     if chroma_dir.exists() or doc_db.exists():
+         if not args.overwrite:
+             raise SystemExit(f"Output dir already has index artifacts. Use --overwrite. ({out_dir})")
+
+     # Load dataset
+     if args.source == "local":
+         if not args.local_path:
+             raise SystemExit("--local-path is required when --source local")
+         dataset = load_from_disk(args.local_path)
+     else:
+         dataset = load_dataset(args.dataset, split=args.split)
+
+     if args.limit:
+         dataset = dataset.select(range(min(int(args.limit), len(dataset))))
+
+     # Splitter
+     splitter = RecursiveCharacterTextSplitter(
+         chunk_size=cfg.get_int("processing.chunk_size", 1024),
+         chunk_overlap=cfg.get_int("processing.chunk_overlap", 200),
+         separators=cfg.get("processing.separators", ["\n\n", "\n", " "]),
+         keep_separator=cfg.get_bool("processing.keep_separator", True),
+     )
+
+     # Embeddings
+     emb = EmbeddingClient(
+         EmbeddingConfig(
+             base_url=cfg.get_str("embedding.api_base_url"),
+             api_key=cfg.get_str("embedding.api_key"),
+             model_name=cfg.get_str("embedding.model_name"),
+             batch_size=cfg.get_int("embedding.batch_size", 32),
+         )
+     )
+
+     # Storage
+     doc_store = PersistentDocStore(str(doc_db), read_only=False)
+     vectorstore = Chroma(
+         collection_name="radiology_docs",
+         embedding_function=emb.langchain_embeddings,
+         persist_directory=str(chroma_dir),
+     )
+
+     # Build
+     start = time.time()
+     parent_pairs: List[Tuple[str, Dict[str, Any]]] = []
+     child_docs: List[Document] = []
+     counts = Counter()
+
+     for item in dataset:
+         doc_id = (item.get("doc_id") or "").strip()
+         if not doc_id:
+             continue
+         source_type = (item.get("source_type") or "").strip()
+         title = (item.get("title") or "").strip()
+         content = _clean_text(item.get("content") or "")
+         url = (item.get("url") or "").strip()
+         metadata = item.get("metadata") or {}
+
+         counts[source_type or "unknown"] += 1
+
+         # Parent document record
+         parent_pairs.append(
+             (
+                 doc_id,
+                 {
+                     "complete_document": {
+                         "doc_id": doc_id,
+                         "title": title,
+                         "content": content,
+                         "url": url,
+                         "metadata": metadata,
+                     },
+                     "main_content": content,
+                     "images": [],  # not used in this Space
+                     "source_type": source_type,
+                 },
+             )
+         )
+
+         # Child chunks for the vector store
+         chunks = splitter.split_text(content)
+         total = len(chunks)
+         for i, chunk in enumerate(chunks):
+             child_docs.append(
+                 Document(
+                     page_content=chunk,
+                     metadata={
+                         "doc_id": f"{doc_id}_chunk_{i}",
+                         "parent_id": doc_id,
+                         "source_type": source_type,
+                         "title": title,
+                         "chunk_index": i,
+                         "total_chunks": total,
+                     },
+                 )
+             )
+
+     # Persist parent docs
+     doc_store.mset(parent_pairs)
+
+     # Add chunks in batches
+     batch_size = int(cfg.get_int("processing.batch_size", 32))
+     for i in range(0, len(child_docs), batch_size):
+         vectorstore.add_documents(child_docs[i : i + batch_size])
+
+     elapsed = time.time() - start
+
+     # Manifest
+     manifest = {
+         "built_at": time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()),
+         "seconds": elapsed,
+         "dataset": {"source": args.source, "dataset": args.dataset, "split": args.split, "limit": args.limit},
+         "embedding": {"type": "api", "model_name": cfg.get_str("embedding.model_name"), "base_url": cfg.get_str("embedding.api_base_url")},
+         "processing": {
+             "chunk_size": cfg.get_int("processing.chunk_size", 1024),
+             "chunk_overlap": cfg.get_int("processing.chunk_overlap", 200),
+         },
+         "counts_by_source_type": dict(counts),
+         "artifacts": {"chroma_dir": "chroma_db", "doc_store": "doc_store.db"},
+     }
+     with open(manifest_path, "w", encoding="utf-8") as f:
+         json.dump(manifest, f, ensure_ascii=False, indent=2)
+
+     print(f"✓ Index built at: {out_dir}")
+     print(f"  - documents: {sum(counts.values())} (by type: {dict(counts)})")
+     print(f"  - chunks: {len(child_docs)}")
+     print(f"  - elapsed: {elapsed:.1f}s")
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
+
+
scripts/download_hf_dataset.py ADDED
@@ -0,0 +1,43 @@
+ """
+ Download the public dataset from Hugging Face and save it to disk.
+
+ Example:
+     python scripts/download_hf_dataset.py \
+         --dataset ZhangNy/radiology-dataset \
+         --split train \
+         --output ./hf_dataset_prepared
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import sys
+ from pathlib import Path
+
+ # Allow running as `python scripts/*.py` without installing the package.
+ sys.path.append(str(Path(__file__).resolve().parents[1]))
+
+
+ def main() -> int:
+     parser = argparse.ArgumentParser(description="Download HF dataset to local disk")
+     parser.add_argument("--dataset", type=str, default="ZhangNy/radiology-dataset", help="HF dataset repo id")
+     parser.add_argument("--split", type=str, default="train", help="Dataset split")
+     parser.add_argument("--output", type=str, default="./hf_dataset_prepared", help="Output directory (save_to_disk)")
+     parser.add_argument("--cache-dir", type=str, default=None, help="Optional datasets cache dir")
+     args = parser.parse_args()
+
+     from datasets import load_dataset
+
+     out_dir = Path(args.output)
+     out_dir.parent.mkdir(parents=True, exist_ok=True)
+
+     ds = load_dataset(args.dataset, split=args.split, cache_dir=args.cache_dir)
+     ds.save_to_disk(str(out_dir))
+     print(f"✓ Saved dataset to: {out_dir} (rows={len(ds)})")
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
+
+
scripts/package_existing_storage.py ADDED
@@ -0,0 +1,110 @@
+ """
+ Package an existing local index folder (e.g. rebuild_1217/storage) into a clean index folder.
+
+ This is the fastest path if you already built the index locally and want to publish it
+ to Hugging Face without rebuilding embeddings.
+
+ Input (example):
+   /path/to/storage/
+     chroma_db/
+     doc_store.db
+     images/          # optional (ignored)
+
+ Output:
+   ./index_out/
+     chroma_db/
+     doc_store.db
+     manifest.json
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import sys
+ import shutil
+ import sqlite3
+ import time
+ from pathlib import Path
+ from typing import Dict
+
+ # Allow running as `python scripts/*.py` without installing the package.
+ sys.path.append(str(Path(__file__).resolve().parents[1]))
+
+
+ def _count_by_source_type(doc_store_db: Path) -> Dict[str, int]:
+     counts: Dict[str, int] = {}
+     conn = sqlite3.connect(str(doc_store_db))
+     try:
+         cur = conn.cursor()
+         cur.execute("SELECT source_type, COUNT(*) FROM documents GROUP BY source_type")
+         for source_type, count in cur.fetchall():
+             counts[str(source_type)] = int(count)
+     finally:
+         conn.close()
+     return counts
+
+
+ def main() -> int:
+     parser = argparse.ArgumentParser(description="Package existing index storage into index_out (no images)")
+     parser.add_argument("--storage", type=str, required=True, help="Existing storage dir containing chroma_db/ + doc_store.db")
+     parser.add_argument("--output-dir", type=str, default="./index_out", help="Output folder")
+     parser.add_argument("--config", type=str, default="config/default_config.yaml", help="Config YAML (for embedding metadata)")
+     parser.add_argument("--overwrite", action="store_true", help="Overwrite output dir if exists")
+     args = parser.parse_args()
+
+     from radiology_rag.config import Config
+
+     storage = Path(args.storage)
+     src_chroma = storage / "chroma_db"
+     src_doc = storage / "doc_store.db"
+     if not src_chroma.exists() or not src_doc.exists():
+         raise SystemExit(f"Storage missing required files: {src_chroma} / {src_doc}")
+
+     out_dir = Path(args.output_dir)
+     out_chroma = out_dir / "chroma_db"
+     out_doc = out_dir / "doc_store.db"
+     out_manifest = out_dir / "manifest.json"
+
+     if out_dir.exists() and args.overwrite:
+         shutil.rmtree(out_dir)
+     out_dir.mkdir(parents=True, exist_ok=True)
+
+     if out_chroma.exists() or out_doc.exists():
+         if not args.overwrite:
+             raise SystemExit(f"Output already exists. Use --overwrite. ({out_dir})")
+
+     # Copy artifacts (exclude images/)
+     if out_chroma.exists():
+         shutil.rmtree(out_chroma, ignore_errors=True)
+     shutil.copytree(src_chroma, out_chroma, dirs_exist_ok=False)
+     shutil.copy2(src_doc, out_doc)
+
+     cfg = Config(args.config)
+     counts = _count_by_source_type(out_doc)
+     manifest = {
+         "packaged_at": time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()),
+         "source_storage": str(storage),
+         "embedding": {"model_name": cfg.get_str("embedding.model_name"), "type": cfg.get_str("embedding.type", "api")},
+         "processing": {
+             "chunk_size": cfg.get_int("processing.chunk_size", 1024),
+             "chunk_overlap": cfg.get_int("processing.chunk_overlap", 200),
+         },
+         "counts_by_source_type": counts,
+         "artifacts": {"chroma_dir": "chroma_db", "doc_store": "doc_store.db"},
+         "images_included": False,
+     }
+     with open(out_manifest, "w", encoding="utf-8") as f:
+         json.dump(manifest, f, ensure_ascii=False, indent=2)
+
+     print(f"✓ Packaged index to: {out_dir}")
+     print(f"  - chroma_db: {out_chroma}")
+     print(f"  - doc_store: {out_doc}")
+     print(f"  - manifest:  {out_manifest}")
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
+
+
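The `_count_by_source_type` helper above issues a single `GROUP BY` over the `documents` table. A minimal sketch of the same query against an in-memory SQLite database (the table contents are made up for illustration; only the `source_type` column is assumed from the script's own `SELECT`):

```python
import sqlite3

# Tiny in-memory stand-in for doc_store.db, using the column the script queries.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE documents (id INTEGER PRIMARY KEY, source_type TEXT)")
conn.executemany(
    "INSERT INTO documents (source_type) VALUES (?)",
    [("article",), ("article",), ("case",)],
)

# Same aggregation as _count_by_source_type: one row per distinct source_type.
cur = conn.execute("SELECT source_type, COUNT(*) FROM documents GROUP BY source_type")
counts = {str(source_type): int(count) for source_type, count in cur.fetchall()}
conn.close()

print(counts)  # {'article': 2, 'case': 1}
```

These per-type counts are what end up in `manifest.json` under `counts_by_source_type`, which makes it easy to sanity-check a packaged index without opening the database.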
scripts/publish_index_to_hf.py ADDED
@@ -0,0 +1,78 @@
+ """
+ Publish a built index folder to Hugging Face Datasets.
+
+ Example:
+   python scripts/publish_index_to_hf.py \
+     --repo ZhangNy/radiology-index-qwen3-embedding-0.6b \
+     --folder ./index_out \
+     --token $HF_TOKEN
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import sys
+ from pathlib import Path
+
+ # Allow running as `python scripts/*.py` without installing the package.
+ sys.path.append(str(Path(__file__).resolve().parents[1]))
+
+
+ def main() -> int:
+     parser = argparse.ArgumentParser(description="Upload index artifacts to HF datasets repo")
+     parser.add_argument("--repo", type=str, required=True, help="HF dataset repo id, e.g. user/my-index")
+     parser.add_argument("--folder", type=str, required=True, help="Local folder containing chroma_db/ + doc_store.db")
+     parser.add_argument("--token", type=str, default=None, help="HF token (or set HF_TOKEN env)")
+     parser.add_argument("--private", action="store_true", help="Create repo as private")
+     parser.add_argument("--revision", type=str, default="main", help="Target revision/branch")
+     parser.add_argument(
+         "--ignore",
+         type=str,
+         default="",
+         help="Comma-separated ignore patterns for upload_folder (e.g. 'images/**,**/images/**')",
+     )
+     args = parser.parse_args()
+
+     from huggingface_hub import HfApi
+
+     token = args.token or None
+     if token is None:
+         import os
+
+         token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
+
+     if not token:
+         raise SystemExit("Missing token. Provide --token or set HF_TOKEN.")
+
+     folder = Path(args.folder)
+     if not folder.exists():
+         raise SystemExit(f"Folder not found: {folder}")
+
+     api = HfApi()
+     api.create_repo(
+         repo_id=args.repo,
+         repo_type="dataset",
+         private=bool(args.private),
+         exist_ok=True,
+         token=token,
+     )
+
+     api.upload_folder(
+         repo_id=args.repo,
+         repo_type="dataset",
+         folder_path=str(folder),
+         path_in_repo="",
+         token=token,
+         revision=args.revision,
+         commit_message="Upload prebuilt radiology RAG index",
+         ignore_patterns=[p.strip() for p in (args.ignore or "").split(",") if p.strip()] or None,
+     )
+
+     print(f"✓ Uploaded index folder to HF dataset repo: {args.repo}")
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
+
+
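The `--ignore` flag above is turned into `upload_folder`'s `ignore_patterns` argument by a one-line comma-split. Pulled out as a standalone function, that same parsing logic looks like this:

```python
from typing import List, Optional


def parse_ignore_patterns(raw: str) -> Optional[List[str]]:
    """Comma-split, strip whitespace, drop empties; None means 'ignore nothing'."""
    patterns = [p.strip() for p in (raw or "").split(",") if p.strip()]
    return patterns or None


print(parse_ignore_patterns("images/**, **/images/**"))  # ['images/**', '**/images/**']
print(parse_ignore_patterns(""))  # None
```

Returning `None` rather than an empty list matters here: passing `ignore_patterns=None` leaves `upload_folder`'s default behavior untouched, so an empty `--ignore` string uploads the whole folder.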