"""FastAPI backend comparing BM25, dense, SRL-role-overlap and hybrid retrieval.

At import time the module loads a JSON corpus, builds a BM25 index, optionally
loads/computes sentence embeddings and a spaCy pipeline, and exposes
``/search`` and ``/health`` endpoints.
"""
import json
import os
from pathlib import Path
from typing import Any, Dict, List

import numpy as np
import spacy
from spacy import displacy
from fastapi import FastAPI, Query
from fastapi.middleware.cors import CORSMiddleware
from rank_bm25 import BM25Okapi

try:
    from sentence_transformers import SentenceTransformer
except Exception:  # optional dependency: any import failure disables dense retrieval
    SentenceTransformer = None

# Data files live in <parent of this file's directory>/data/.
BASE_DIR = Path(__file__).resolve().parents[1]
CORPUS_PATH = BASE_DIR / "data" / "corpus.json"
EMBEDDINGS_PATH = BASE_DIR / "data" / "embeddings.npy"

# Maximum number of corpus documents kept in memory (override via the MAX_DOCS
# environment variable). Documents beyond this limit, and the matching rows of a
# pre-computed embedding matrix, are dropped at startup. The default assumes a
# host with plenty of RAM; lower it on small instances.
# Pre-computed embeddings are loaded when present to speed up boot.
MAX_DOCS = max(0, int(os.environ.get("MAX_DOCS", 50000)))


def _norm(values: np.ndarray) -> np.ndarray:
    """Scale ``values`` so that its largest entry becomes 1.0.

    Returns the input unchanged when it is empty and all zeros when the maximum
    is not positive. Negative entries stay negative after scaling.
    """
    if values.size == 0:
        return values
    v_max = float(np.max(values))
    if v_max <= 0:
        return np.zeros_like(values, dtype=float)
    return values / v_max


def _tokenize(text: str) -> List[str]:
    """Lower-cased whitespace tokenization shared by BM25 and role matching.

    ``None`` (e.g. a null field in the corpus JSON) is treated as empty text.
    """
    # str.split() with no argument never yields empty or whitespace-only tokens.
    return [tok.lower() for tok in (text or "").split()]


def _extract_query_roles(query: str, nlp_model) -> Dict[str, str]:
    """Heuristically derive SRL-style roles for ``query`` from a dependency parse.

    Returns a dict that always has ``"predicate"`` (lemma of the first VERB, or
    ``""``) and, when found, ``"ARG0"`` (subject subtrees), ``"ARG1"`` (object
    subtrees) and ``"ARGM-TMP"`` (DATE/TIME-tagged obliques/adverbial NPs).
    Repeated phrases are de-duplicated, keeping first-seen order.
    Several dependency-label variants ("dobj"/"obj", "obl"/"npadvmod") are
    accepted so the heuristic does not depend on one label scheme.
    """
    doc = nlp_model(query)
    predicate = next((tok.lemma_.lower() for tok in doc if tok.pos_ == "VERB"), "")

    phrases: Dict[str, List[str]] = {"ARG0": [], "ARG1": [], "ARGM-TMP": []}
    for tok in doc:
        if tok.dep_ in {"nsubj", "nsubjpass"}:
            role = "ARG0"
        elif tok.dep_ in {"dobj", "obj", "iobj"}:
            role = "ARG1"
        elif tok.dep_ in {"obl", "npadvmod"} and tok.ent_type_ in {"DATE", "TIME"}:
            role = "ARGM-TMP"
        else:
            continue
        phrases[role].append(" ".join(t.text for t in tok.subtree))

    out: Dict[str, str] = {"predicate": predicate}
    for role, found in phrases.items():
        if found:
            # dict.fromkeys de-duplicates while preserving order.
            out[role] = " ".join(dict.fromkeys(found))
    return out


def _predicate_stem(value: Any) -> str:
    """Normalize a predicate such as ``"Run.01"`` to its lower-cased stem ``"run"``."""
    return str(value or "").split(".")[0].strip().lower()


def _role_overlap_score(
    query_roles: Dict[str, str], doc_roles: Dict[str, str]
) -> float:
    """Average per-role agreement between query roles and a document's roles.

    Only roles with a non-empty query value are scored. The predicate scores
    1.0 on an exact stem match (sense suffix after "." ignored, case-insensitive)
    and 0.0 otherwise. Every other role scores the Jaccard similarity of its
    lower-cased token sets, or 0.0 if either side is empty or missing.
    Returns 0.0 when the query has no roles.
    """
    keys = [k for k, v in query_roles.items() if v]
    if not keys:
        return 0.0
    scores: List[float] = []
    for key in keys:
        if key == "predicate":
            q_pred = _predicate_stem(query_roles.get("predicate"))
            d_pred = _predicate_stem(doc_roles.get("predicate"))
            scores.append(1.0 if q_pred and q_pred == d_pred else 0.0)
            continue
        q_tokens = set(_tokenize(query_roles.get(key)))
        d_tokens = set(_tokenize(doc_roles.get(key)))
        if not q_tokens or not d_tokens:
            scores.append(0.0)
            continue
        # Both sets are non-empty here, so the union is never empty.
        scores.append(len(q_tokens & d_tokens) / len(q_tokens | d_tokens))
    return float(np.mean(scores))


def _maybe_load_embedder() -> Any:
    """Return a SentenceTransformer("all-MiniLM-L6-v2") or None if unavailable."""
    if SentenceTransformer is None:
        return None
    try:
        return SentenceTransformer("all-MiniLM-L6-v2")
    except Exception as e:  # best-effort: dense retrieval is optional
        print(f"Failed to load embedder: {e}")
        return None


def _dense_scores(query: str) -> np.ndarray:
    """Dot-product score of ``query``'s embedding against every document.

    Returns zeros (one per corpus document) when no embedder or no document
    embeddings are available. The dot product equals cosine similarity only if
    both sides are unit-normalized (assumed, not checked here).
    """
    if EMBEDDER is None or DOC_EMBEDDINGS.size == 0:
        return np.zeros(len(CORPUS), dtype=float)
    query_embedding = EMBEDDER.encode(query)
    return np.array(np.dot(DOC_EMBEDDINGS, query_embedding), dtype=float)


# Corpus document fields that are surfaced as semantic roles in search results.
ROLE_KEYS = {
    "predicate",
    "ARG0",
    "ARG1",
    "ARG2",
    "ARG3",
    "ARG4",
    "ARGM-LOC",
    "ARGM-TMP",
    "ARGM-MNR",
    "ARGM-CAU",
    "ARGM-DIS",
    "ARGM-ADV",
    "ARGM-MOD",
    "ARGM-NEG",
}


def _build_result(
    idx: int,
    bm25_scores: np.ndarray,
    dense_scores: np.ndarray,
    srl_scores: np.ndarray,
    hybrid_scores: np.ndarray,
) -> Dict[str, Any]:
    """Build the JSON payload for corpus document ``idx``.

    Includes the document id and text, all four scores, its non-empty role
    fields, and a displaCy dependency-parse rendering (``""`` if rendering
    fails).
    """
    doc = CORPUS[int(idx)]
    text = doc.get("text", "")
    roles = {k: v for k, v in doc.items() if k in ROLE_KEYS and v}
    try:
        spacy_doc = NLP(text)
        html = displacy.render(
            spacy_doc,
            style="dep",
            page=False,
            options={
                "compact": True,
                "color": "#e4e4e7",
                "font": "sans-serif",
                "bg": "transparent",
            },
        )
    except Exception:  # visualization is best-effort; never fail the search
        html = ""
    return {
        "id": doc.get("id"),
        "text": text,
        "scores": {
            "bm25": float(bm25_scores[idx]),
            "dense": float(dense_scores[idx]),
            "srl": float(srl_scores[idx]),
            "hybrid": float(hybrid_scores[idx]),
        },
        "roles": roles,
        "displacy_html": html,
    }


# ---------------------------------------------------------------------------
# Startup: load corpus, indexes and models (module-level, runs once on import).
# The corpus JSON is a list of document dicts; this module reads "text", "id"
# and the ROLE_KEYS fields from each.
# ---------------------------------------------------------------------------
with CORPUS_PATH.open("r", encoding="utf-8") as f:
    CORPUS: List[Dict[str, Any]] = json.load(f)

if len(CORPUS) > MAX_DOCS:
    print(f"Truncating corpus from {len(CORPUS)} to {MAX_DOCS} docs to prevent OOM.")
    CORPUS = CORPUS[:MAX_DOCS]

# Null texts are treated as empty strings so tokenization/encoding cannot fail.
TEXTS = [doc.get("text") or "" for doc in CORPUS]

# rank_bm25 divides by the corpus size while computing the average document
# length, so an empty corpus cannot be indexed; search() falls back to zeros.
BM25 = BM25Okapi([_tokenize(text) for text in TEXTS]) if TEXTS else None

EMBEDDER = _maybe_load_embedder()

if EMBEDDINGS_PATH.exists():
    print(f"Loading pre-computed embeddings from {EMBEDDINGS_PATH}")
    DOC_EMBEDDINGS = np.load(str(EMBEDDINGS_PATH))
    if len(DOC_EMBEDDINGS) > MAX_DOCS:
        DOC_EMBEDDINGS = DOC_EMBEDDINGS[:MAX_DOCS]
elif EMBEDDER is not None and TEXTS:
    print("Computing embeddings on the fly... (This may cause OOM on small instances)")
    DOC_EMBEDDINGS = np.array(EMBEDDER.encode(TEXTS))
else:
    DOC_EMBEDDINGS = np.array([])

# Dense scores are combined element-wise with per-document BM25/SRL scores, so
# the embedding matrix must have exactly one row per corpus document; a stale
# or mismatched file would otherwise make every /search request fail.
if DOC_EMBEDDINGS.size and (
    DOC_EMBEDDINGS.ndim != 2 or DOC_EMBEDDINGS.shape[0] != len(CORPUS)
):
    print(
        f"Ignoring embeddings with shape {DOC_EMBEDDINGS.shape}: expected one row "
        f"per document ({len(CORPUS)}). Dense retrieval disabled."
    )
    DOC_EMBEDDINGS = np.array([])

try:
    NLP = spacy.load("en_core_web_sm")
except Exception as e:
    # Without a tagger/parser no query roles are extracted, so SRL scores are 0.
    print(f"Failed to load spaCy model 'en_core_web_sm' ({e}); using a blank pipeline.")
    NLP = spacy.blank("en")

app = FastAPI(title="Semantic IR Backend")

# NOTE(review): wildcard origins combined with allow_credentials=True is
# maximally permissive; restrict allow_origins to the frontend's origin(s) if
# this service is publicly reachable.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/search")
def search(
    q: str = Query(..., min_length=1), top_k: int = Query(10, ge=1, le=50)
) -> Dict[str, Any]:
    """Rank the corpus for ``q`` with BM25, dense, SRL-overlap and hybrid scoring.

    Hybrid = 0.5 * BM25 + 0.3 * dense + 0.2 * SRL, each max-normalized first.
    Every ranking returns at most ``top_k`` documents, and only documents whose
    score under that ranking is strictly positive.
    """
    if BM25 is not None:
        bm25_scores = np.array(BM25.get_scores(_tokenize(q)), dtype=float)
    else:
        bm25_scores = np.zeros(len(CORPUS), dtype=float)
    dense_scores = _dense_scores(q)
    query_roles = _extract_query_roles(q, NLP)
    srl_scores = np.array(
        [_role_overlap_score(query_roles, doc) for doc in CORPUS], dtype=float
    )

    bm25_n = _norm(bm25_scores)
    dense_n = _norm(dense_scores)
    srl_n = _norm(srl_scores)
    hybrid_scores = 0.5 * bm25_n + 0.3 * dense_n + 0.2 * srl_n

    # A document often appears in several rankings; its payload (including the
    # spaCy parse + displaCy render) is identical in each, so build it once.
    built: Dict[int, Dict[str, Any]] = {}

    def _top_results(score_arr: np.ndarray) -> List[Dict[str, Any]]:
        """Top-``top_k`` positive-scoring results for one score vector."""
        results: List[Dict[str, Any]] = []
        # Descending order (NaNs sort last), so the first score that is not
        # strictly positive ends the useful prefix.
        for idx in np.argsort(-score_arr):
            if not score_arr[idx] > 0.0:
                break
            i = int(idx)
            if i not in built:
                built[i] = _build_result(
                    i, bm25_scores, dense_scores, srl_scores, hybrid_scores
                )
            results.append(built[i])
            if len(results) >= top_k:
                break
        return results

    return {
        "bm25_results": _top_results(bm25_scores),
        "dense_results": _top_results(dense_scores),
        "srl_results": _top_results(srl_scores),
        "hybrid_results": _top_results(hybrid_scores),
    }


@app.get("/health")
def health() -> Dict[str, Any]:
    """Report liveness, corpus size and whether dense retrieval is active."""
    return {
        "status": "ok",
        "documents": len(CORPUS),
        # Mirrors the condition under which _dense_scores returns real scores.
        "dense_enabled": EMBEDDER is not None and DOC_EMBEDDINGS.size > 0,
    }