Spaces:
Running on Zero
| """Wikipedia encyclopedia retrieval (MediaWiki API).""" | |
| from __future__ import annotations | |
| import logging | |
| import re | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Optional | |
| import requests | |
| logger = logging.getLogger(__name__) | |
@dataclass
class WikipediaConfig:
    """Settings for the Wikipedia MediaWiki API client.

    Decorated with ``@dataclass`` (the import already exists in this module)
    so keyword overrides like ``WikipediaConfig(language="de")`` work; the
    bare class previously rejected any constructor arguments.
    """

    # Wiki subdomain, e.g. "en" -> en.wikipedia.org.
    language: str = "en"
    # Identifies this client, per Wikimedia's User-Agent policy.
    user_agent: str = "RadiologyRAG-Space/1.0"
    # Per-request HTTP timeout in seconds.
    timeout_s: int = 15
    # Default cap on the length of each returned extract.
    max_chars_per_doc: int = 2000
class WikipediaEncyclopediaService:
    """Retrieve short encyclopedia extracts from Wikipedia via the MediaWiki API.

    Each query makes at most two HTTP round-trips: a full-text search
    (``list=search``) to find candidate pages, then a batched page fetch
    (``prop=extracts|info``) for plain-text intros and canonical URLs.
    Retrieval is best-effort: any failure is logged and mapped to ``[]``.
    """

    def __init__(self, config: Optional[WikipediaConfig] = None):
        """Initialize with an optional config; defaults are used otherwise."""
        self.config = config or WikipediaConfig()
        # One shared session so TCP connections and the UA header are reused.
        self._session = requests.Session()
        self._session.headers.update({"User-Agent": self.config.user_agent})

    @property
    def api_base(self) -> str:
        """Base URL of the MediaWiki API for the configured language wiki."""
        # Must be a property (not a plain method): retrieve() reads it as
        # `self.api_base` and passes the value directly as the request URL.
        return f"https://{self.config.language}.wikipedia.org/w/api.php"

    @staticmethod
    def _derive_search_query(user_query: str) -> str:
        """Reduce a natural-language question to a compact keyword query.

        Lowercases and tokenizes on alphabetic runs (allowing internal
        apostrophes/hyphens), drops question/function stopwords and tokens
        shorter than 3 chars (except common imaging abbreviations), dedupes
        while preserving order, and keeps at most 8 keywords. Falls back to
        the raw query when nothing survives filtering.
        """
        q = (user_query or "").strip()
        if not q:
            return ""
        # NOTE: character class is letters, apostrophe, hyphen — the previous
        # pattern accidentally admitted a literal backslash via `\\-`.
        tokens = re.findall(r"[A-Za-z][A-Za-z'-]*", q.lower())
        if not tokens:
            return q
        stop = {
            "what", "which", "who", "whom", "whose", "when", "where", "why", "how",
            "is", "are", "was", "were", "be", "been", "being",
            "do", "does", "did", "can", "could", "should", "would",
            "may", "might", "will", "shall",
            "a", "an", "the", "and", "or", "but",
            "to", "of", "for", "with", "without", "in", "on", "at", "by", "from", "as",
            "it", "its", "this", "that", "these", "those",
            "your", "my", "their", "our", "about",
        }
        # Short radiology modality abbreviations that must survive the
        # min-length filter below.
        keep_short = {"ct", "mr", "mri", "pet", "us", "cxr"}
        keywords: List[str] = []
        seen = set()
        for t in tokens:
            if t in stop:
                continue
            if len(t) < 3 and t not in keep_short:
                continue
            if t in seen:
                continue  # dedupe, first occurrence wins (order preserved)
            seen.add(t)
            keywords.append(t)
        return " ".join(keywords[:8]) if keywords else q

    def retrieve(self, query: str, top_k: int = 5, max_chars_per_doc: Optional[int] = None) -> List[Dict[str, Any]]:
        """Search Wikipedia and return up to ``top_k`` intro-extract documents.

        Args:
            query: Natural-language user query; blank input yields ``[]``.
            top_k: Maximum number of documents to return (search limit is
                clamped to 1..20, the anonymous MediaWiki ``srlimit`` cap).
            max_chars_per_doc: Per-document extract length cap; falls back to
                ``config.max_chars_per_doc`` when falsy/None.

        Returns:
            List of dicts with keys ``doc_id``, ``source_type``, ``title``,
            ``content``, ``url``, ``metadata``, ``score``. Empty list on
            blank input or any retrieval failure (never raises).
        """
        q = (query or "").strip()
        if not q:
            return []
        search_q = self._derive_search_query(q)
        if not search_q:
            return []
        max_chars = int(max_chars_per_doc or self.config.max_chars_per_doc)
        try:
            search_params = {
                "action": "query",
                "list": "search",
                "srsearch": search_q,
                "srlimit": max(1, min(int(top_k), 20)),
                "format": "json",
            }
            resp = self._session.get(self.api_base, params=search_params, timeout=self.config.timeout_s)
            resp.raise_for_status()
            data = resp.json() or {}
            hits = (data.get("query", {}) or {}).get("search", []) or []
            # Fall back to the raw query if the keyword rewrite found nothing.
            if not hits and search_q != q:
                search_params["srsearch"] = q
                resp = self._session.get(self.api_base, params=search_params, timeout=self.config.timeout_s)
                resp.raise_for_status()
                data = resp.json() or {}
                hits = (data.get("query", {}) or {}).get("search", []) or []
            if not hits:
                return []
            pageids = [str(h.get("pageid")) for h in hits if h.get("pageid") is not None]
            if not pageids:
                return []
            pages_params = {
                "action": "query",
                "pageids": "|".join(pageids),
                "prop": "extracts|info",
                "explaintext": 1,  # plain text, no HTML markup
                "exintro": 1,      # intro section only
                "exchars": max_chars,
                "inprop": "url",   # adds "fullurl" to each page record
                "format": "json",
            }
            resp2 = self._session.get(self.api_base, params=pages_params, timeout=self.config.timeout_s)
            resp2.raise_for_status()
            pages_data = resp2.json() or {}
            pages = (pages_data.get("query", {}) or {}).get("pages", {}) or {}
            docs: List[Dict[str, Any]] = []
            # Iterate pageids (not the pages dict) to preserve search ranking.
            for pid in pageids:
                page = pages.get(pid) or {}
                title = page.get("title") or ""
                extract = (page.get("extract") or "").strip()
                url = page.get("fullurl") or ""
                if not title or not extract:
                    continue  # skip pages with no usable intro text
                docs.append(
                    {
                        "doc_id": f"encyclopedia_{pid}",
                        "source_type": "encyclopedia",
                        "title": title,
                        "content": extract,
                        "url": url,
                        "metadata": {"provider": "wikipedia", "pageid": pid},
                        "score": 0.0,
                    }
                )
            return docs[: int(top_k)]
        except Exception as e:
            # Best-effort retrieval: never propagate network/parse errors.
            logger.warning("Wikipedia retrieval failed: %s", e)
            return []