| import logging |
| import os |
| import threading |
| from pathlib import Path |
| from typing import List, Optional |
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)


# Sentence-transformers model used by the retriever and the embedding helpers.
# Can be overridden through the EMBEDDING_MODEL environment variable.
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "intfloat/multilingual-e5-base")

# Directory of the persistent ChromaDB store (relative to the working directory).
CHROMA_PATH = "models/knowledge_base"

# Name of the ChromaDB collection that holds the knowledge-base chunks.
COLLECTION_NAME = "immuniwatch_kb"

# Number of nearest candidates requested from the collection per query.
TOP_K = 5

# Candidates whose similarity falls below this value are discarded.
SIMILARITY_THRESHOLD = 0.72
|
|
|
|
| |
| |
| |
class EvidenceRecord:
    """
    Structured evidence record returned for a misinformation post.
    Matches system design Section 5.3.1 EvidenceRecord schema.

    Attributes:
        source_title: Title of the source the snippet comes from.
        source_url: URL of that source.
        snippet: Text of the matched knowledge-base chunk (stored in full;
            only to_dict() truncates it).
        similarity: Similarity score, stored as a built-in float rounded
            to 4 decimal places.
        language: Language code of the evidence (defaults to "en").
    """

    def __init__(
        self,
        source_title: str,
        source_url: str,
        snippet: str,
        similarity: float,
        language: str = "en",
    ):
        self.source_title = source_title
        self.source_url = source_url
        self.snippet = snippet
        # Coerce to a built-in float first so that numeric scalar types
        # (e.g. NumPy floats) do not leak into the serialised output.
        self.similarity = round(float(similarity), 4)
        self.language = language

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"source_title={self.source_title!r}, "
            f"source_url={self.source_url!r}, "
            f"similarity={self.similarity!r}, "
            f"language={self.language!r})"
        )

    def to_dict(self) -> dict:
        """Return the record as a plain dict, with the snippet capped at 300 characters."""
        return {
            "source_title": self.source_title,
            "source_url": self.source_url,
            # A missing snippet (None) is serialised as an empty string
            # instead of raising on the slice.
            "snippet": (self.snippet or "")[:300],
            "similarity": self.similarity,
            "language": self.language,
        }
|
|
|
|
| |
| |
| |
class RAGRetriever:
    """
    Retrieves verified health facts from the ChromaDB knowledge base.

    Queries are embedded with the model named by EMBEDDING_MODEL (a
    multilingual E5 model by default), which is intended to allow
    cross-language matching, e.g. a Hausa claim against an English fact.

    Must call ingest_all() from ingestion.py first to populate
    the knowledge base before retrieval works.
    """

    def __init__(self):
        self._collection = None
        self._ready = False
        self._init()

    def _init(self) -> None:
        """Connect to ChromaDB. Fails gracefully if not yet ingested.

        Leaves the retriever in the not-ready state (and logs why) when the
        knowledge-base directory is missing, the collection cannot be
        opened, or the collection is empty. Never raises.
        """
        kb_path = Path(CHROMA_PATH)
        if not kb_path.exists():
            log.warning(
                "Knowledge base not found at %s. "
                "Run: python -m src.intelligence.ingestion",
                CHROMA_PATH,
            )
            return

        try:
            # Imported lazily so the module can be imported without chromadb.
            import chromadb
            from chromadb.utils import embedding_functions

            client = chromadb.PersistentClient(path=CHROMA_PATH)
            ef = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=EMBEDDING_MODEL,
                device="cpu",
            )
            self._collection = client.get_collection(
                name=COLLECTION_NAME,
                embedding_function=ef,
            )
            count = self._collection.count()
            self._ready = count > 0
            if self._ready:
                log.info(
                    "RAG ready — %d chunks in knowledge base", count
                )
            else:
                # An empty collection cannot answer queries; say so instead
                # of reporting "ready" with zero chunks.
                log.warning(
                    "Knowledge base collection %s is empty — RAG not ready. "
                    "Run: python -m src.intelligence.ingestion",
                    COLLECTION_NAME,
                )
        except Exception as e:
            log.error("RAG initialisation failed: %s", e)

    def is_ready(self) -> bool:
        """Return True once a non-empty knowledge-base collection is connected."""
        return self._ready

    def retrieve(
        self,
        post_text: str,
        language: Optional[str] = None,
    ) -> List[EvidenceRecord]:
        """
        Retrieve up to TOP_K verified facts most relevant to the post.

        Args:
            post_text: The misinformation post content.
            language:  Post language code (en/pcm/ha/yo/ig). Currently not
                       used for filtering; matching relies on the
                       multilingual embedding model.

        Returns:
            List of EvidenceRecord objects at or above SIMILARITY_THRESHOLD.
            Empty list if the knowledge base is not ready, the post has
            fewer than 5 non-blank characters, no match is found, or the
            query fails (failures are logged, not raised).
        """
        if not self._ready:
            log.warning("RAG not ready — returning empty evidence")
            return []

        if not post_text or len(post_text.strip()) < 5:
            return []

        try:
            # E5 models expect the "query: " prefix on query inputs.
            query = f"query: {post_text.strip()}"

            results = self._collection.query(
                query_texts=[query],
                n_results=TOP_K,
                include=["documents", "metadatas", "distances"],
            )

            # One query was sent, so take the first (only) result row.
            # `or [[]]` also covers a field that is present but None/empty.
            documents = (results.get("documents") or [[]])[0]
            metadatas = (results.get("metadatas") or [[]])[0]
            distances = (results.get("distances") or [[]])[0]

            evidence = []
            for doc, meta, distance in zip(documents, metadatas, distances):
                # NOTE(review): 1 - distance is a cosine similarity only if
                # the collection was created with cosine distance — confirm
                # against the ingestion code.
                similarity = 1.0 - distance

                if similarity < SIMILARITY_THRESHOLD:
                    continue

                # A record stored without metadata must not abort the whole
                # retrieval and discard the other valid hits.
                meta = meta or {}

                evidence.append(EvidenceRecord(
                    source_title=meta.get("source", "Unknown"),
                    source_url=meta.get("url", ""),
                    snippet=doc or "",
                    similarity=similarity,
                    language=meta.get("language", "en"),
                ))

            log.debug(
                "RAG retrieved %d evidence records for post (threshold=%.2f)",
                len(evidence), SIMILARITY_THRESHOLD,
            )
            return evidence

        except Exception as e:
            log.error("RAG retrieval failed: %s", e)
            return []

    def retrieve_as_dicts(
        self,
        post_text: str,
        language: Optional[str] = None,
    ) -> List[dict]:
        """
        Same as retrieve() but returns plain dicts.
        Used by FastAPI routes to serialise response.
        """
        return [e.to_dict() for e in self.retrieve(post_text, language)]
|
|
|
|
| |
| |
| |
| |
| |
| |
# Serialises first-time construction of the shared model below.
_st_lock = threading.Lock()
# Shared SentenceTransformer instance; stays None until it is first loaded.
_st_model = None
|
|
|
|
def preload_embedder() -> None:
    """Warm the embedding model so it is loaded before the first request.

    Meant to be called at startup. Never raises: a failed load is logged
    as a warning.
    """
    try:
        _get_st_model()
        log.info("Embedding model loaded and ready.")
    except Exception as load_error:
        log.warning("Embedding model failed to preload: %s", load_error)
|
|
|
|
def _get_st_model():
    """Return the shared SentenceTransformer, loading it on first use.

    The model named by EMBEDDING_MODEL is constructed once and then reused
    on every later call.
    """
    global _st_model
    if _st_model is None:
        with _st_lock:
            # Check again under the lock: another thread may have finished
            # loading while this one was waiting.
            if _st_model is None:
                from sentence_transformers import SentenceTransformer

                log.info("Loading SentenceTransformer: %s", EMBEDDING_MODEL)
                _st_model = SentenceTransformer(EMBEDDING_MODEL)
    return _st_model
|
|
|
|
def is_embedder_ready() -> bool:
    """Report whether the shared embedding model has already been loaded."""
    loaded = _st_model is not None
    return loaded
|
|
|
|
def embed_text(text: str) -> List[float]:
    """
    Embed one string as a query and return the vector as a plain list.

    The text is sent to the model with the 'query: ' prefix that
    multilingual-e5 models require for query inputs, and the embedding is
    normalised (normalize_embeddings=True).
    With the default multilingual-e5-base model the list has 768 floats.
    """
    embedder = _get_st_model()
    prefixed = f"query: {text}"
    embedding = embedder.encode(prefixed, normalize_embeddings=True)
    return embedding.tolist()
|
|
|
|
def embed_batch(items: List[dict]) -> List[dict]:
    """
    Embed multiple items.

    Each text is sent to the model with the 'query: ' prefix and the
    embeddings are normalised (normalize_embeddings=True); texts are
    encoded in batches of 32.

    Args:
        items: list of {"doc_id": str, "text": str}

    Returns:
        list of {"doc_id": str, "embedding": list[float]}, in input order.
        An empty input returns [] without loading the model.

    Raises:
        KeyError: if an item lacks "doc_id" or "text".
    """
    # Materialise once: the input is walked twice (texts, then the zip),
    # which would silently yield nothing for a one-shot iterator.
    items = list(items)
    if not items:
        # Nothing to embed — skip the (potentially slow) model load.
        return []

    model = _get_st_model()
    texts = [f"query: {it['text']}" for it in items]
    vecs = model.encode(texts, normalize_embeddings=True, batch_size=32)
    return [
        {"doc_id": item["doc_id"], "embedding": vec.tolist()}
        for item, vec in zip(items, vecs)
    ]