Spaces:
Sleeping
Sleeping
Kamal Nayan Kumar
feat: revert backend to use sentence-transformers (all-MiniLM-L6-v2) for better embeddings
548ff8f | import json | |
| import os | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
| import numpy as np | |
| import spacy | |
| from spacy import displacy | |
| from fastapi import FastAPI, Query | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from rank_bm25 import BM25Okapi | |
| try: | |
| from sentence_transformers import SentenceTransformer | |
| except Exception: | |
| SentenceTransformer = None | |
# BASE_DIR is the parent of the directory that contains this file; the corpus
# and the optional pre-computed embeddings are looked up under BASE_DIR/data.
BASE_DIR = Path(__file__).resolve().parents[1]
CORPUS_PATH = BASE_DIR.joinpath("data", "corpus.json")
EMBEDDINGS_PATH = BASE_DIR.joinpath("data", "embeddings.npy")

# Hugging Face Spaces provides 16GB of free RAM, so strict truncation is not
# needed; MAX_DOCS is an upper bound on the number of documents kept and can be
# overridden with the MAX_DOCS environment variable.
# Pre-computed embeddings are still loaded when available to speed up boot.
MAX_DOCS = int(os.getenv("MAX_DOCS", 50000))
def _norm(values: np.ndarray) -> np.ndarray:
    """Scale ``values`` by their maximum.

    An empty array is returned unchanged. When the maximum is zero or
    negative, an all-zero float array of the same shape is returned instead
    of dividing.
    """
    if not values.size:
        return values
    peak = float(values.max())
    return np.zeros_like(values, dtype=float) if peak <= 0 else values / peak
def _tokenize(text: str) -> List[str]:
    """Split ``text`` on whitespace and lowercase every piece."""
    # str.split() with no arguments never yields empty or whitespace-only
    # pieces, so no additional filtering is required.
    return list(map(str.lower, text.split()))
def _extract_query_roles(query: str, nlp_model) -> Dict[str, str]:
    """Derive a rough predicate/argument structure for ``query``.

    ``nlp_model`` is called on the query and its tokens are inspected:

    * ``predicate`` – lowercased lemma of the first token tagged ``VERB``
      (empty string when there is none);
    * ``ARG0`` – subtree text of tokens whose dependency is ``nsubj`` or
      ``nsubjpass``;
    * ``ARG1`` – subtree text of tokens whose dependency is ``dobj``, ``obj``
      or ``iobj``;
    * ``ARGM-TMP`` – subtree text of ``obl``/``npadvmod`` tokens whose entity
      type is ``DATE`` or ``TIME``.

    Argument keys are only present when at least one phrase was found;
    duplicate phrases are dropped (first occurrence wins) and the remainder
    joined with single spaces.
    """
    subject_deps = {"nsubj", "nsubjpass"}
    object_deps = {"dobj", "obj", "iobj"}
    temporal_deps = {"obl", "npadvmod"}
    temporal_ents = {"DATE", "TIME"}

    def _phrase(head) -> str:
        # Text of every token in the head's subtree, space separated.
        return " ".join(member.text for member in head.subtree)

    verb_lemma = ""
    subjects: List[str] = []
    objects: List[str] = []
    temporals: List[str] = []

    for token in nlp_model(query):
        if not verb_lemma and token.pos_ == "VERB":
            verb_lemma = token.lemma_.lower()

        relation = token.dep_
        if relation in subject_deps:
            subjects.append(_phrase(token))
        elif relation in object_deps:
            objects.append(_phrase(token))
        elif relation in temporal_deps and token.ent_type_ in temporal_ents:
            temporals.append(_phrase(token))

    roles: Dict[str, str] = {"predicate": verb_lemma}
    for label, phrases in (
        ("ARG0", subjects),
        ("ARG1", objects),
        ("ARGM-TMP", temporals),
    ):
        if phrases:
            # dict.fromkeys de-duplicates while keeping first-seen order.
            roles[label] = " ".join(dict.fromkeys(phrases))
    return roles
def _role_overlap_score(
    query_roles: Dict[str, str], doc_roles: Dict[str, str]
) -> float:
    """Average per-role agreement between a query and a document.

    Only roles that have a non-empty value in ``query_roles`` are scored:

    * ``predicate`` scores 1.0 when the text before the first ``.`` is
      non-empty and identical on both sides, otherwise 0.0;
    * every other role scores the Jaccard overlap of the two token sets
      (0.0 when either side has no tokens).

    Returns the mean of those scores, or 0.0 when the query has no populated
    roles.
    """
    populated = [role for role, value in query_roles.items() if value]
    if not populated:
        return 0.0

    def _stem(roles: Dict[str, str]) -> str:
        # Keep only the part of the predicate before the first dot.
        return roles.get("predicate", "").split(".")[0]

    per_role: List[float] = []
    for role in populated:
        if role == "predicate":
            query_stem = _stem(query_roles)
            doc_stem = _stem(doc_roles)
            matched = bool(query_stem) and bool(doc_stem) and query_stem == doc_stem
            per_role.append(1.0 if matched else 0.0)
        else:
            query_words = set(_tokenize(query_roles.get(role, "")))
            doc_words = set(_tokenize(doc_roles.get(role, "")))
            if query_words and doc_words:
                shared = query_words & doc_words
                combined = query_words | doc_words
                per_role.append(len(shared) / len(combined))
            else:
                per_role.append(0.0)

    return float(np.mean(per_role))
def _maybe_load_embedder() -> Any:
    """Return an ``all-MiniLM-L6-v2`` SentenceTransformer, or ``None``.

    ``None`` is returned when ``SentenceTransformer`` is unavailable (the
    module-level name is ``None``) or when constructing the model raises; in
    the latter case the error is printed rather than propagated.
    """
    embedder = None
    if SentenceTransformer is not None:
        try:
            embedder = SentenceTransformer("all-MiniLM-L6-v2")
        except Exception as e:
            print(f"Failed to load embedder: {e}")
    return embedder
def _dense_scores(query: str) -> np.ndarray:
    """Score every document against ``query`` using the stored embeddings.

    The query is encoded with ``EMBEDDER`` and the result is the dot product
    of ``DOC_EMBEDDINGS`` with that vector, as a float array. When there is
    no embedder or no stored embeddings, a zero array with one entry per
    corpus document is returned.
    """
    dense_ready = EMBEDDER is not None and DOC_EMBEDDINGS.size != 0
    if not dense_ready:
        return np.zeros(len(CORPUS), dtype=float)
    query_vector = EMBEDDER.encode(query)
    similarities = np.dot(DOC_EMBEDDINGS, query_vector)
    return np.array(similarities, dtype=float)
# Keys of a corpus record that are treated as semantic-role fields; only
# these keys are copied into the "roles" part of a search result.
ROLE_KEYS = {
    # predicate and numbered arguments
    "predicate", "ARG0", "ARG1", "ARG2", "ARG3", "ARG4",
    # modifier arguments
    "ARGM-LOC", "ARGM-TMP", "ARGM-MNR", "ARGM-CAU",
    "ARGM-DIS", "ARGM-ADV", "ARGM-MOD", "ARGM-NEG",
}
def _build_result(
    idx: int,
    bm25_scores: np.ndarray,
    dense_scores: np.ndarray,
    srl_scores: np.ndarray,
    hybrid_scores: np.ndarray,
) -> Dict[str, Any]:
    """Assemble the response payload for the corpus document at ``idx``.

    The payload carries the document id and text, its score under each of
    the four rankers, the non-empty role fields of the record (keys listed
    in ``ROLE_KEYS``) and a displaCy dependency rendering of the text.
    Rendering is best-effort: any failure yields an empty string.
    """
    record = CORPUS[int(idx)]
    body = record.get("text", "")

    render_options = {
        "compact": True,
        "color": "#e4e4e7",
        "font": "sans-serif",
        "bg": "transparent",
    }
    try:
        markup = displacy.render(
            NLP(body), style="dep", page=False, options=render_options
        )
    except Exception:
        markup = ""

    score_sources = {
        "bm25": bm25_scores,
        "dense": dense_scores,
        "srl": srl_scores,
        "hybrid": hybrid_scores,
    }
    return {
        "id": record.get("id"),
        "text": body,
        "scores": {name: float(arr[idx]) for name, arr in score_sources.items()},
        "roles": {
            key: value
            for key, value in record.items()
            if key in ROLE_KEYS and value
        },
        "displacy_html": markup,
    }
# ---------------------------------------------------------------------------
# Startup: everything below runs once, at import time.
# ---------------------------------------------------------------------------

# Load the corpus and cap it at MAX_DOCS documents.
with CORPUS_PATH.open("r", encoding="utf-8") as f:
    CORPUS: List[Dict[str, Any]] = json.load(f)
if len(CORPUS) > MAX_DOCS:
    print(f"Truncating corpus from {len(CORPUS)} to {MAX_DOCS} docs to prevent OOM.")
    CORPUS = CORPUS[:MAX_DOCS]

TEXTS = [doc.get("text", "") for doc in CORPUS]
BM25 = BM25Okapi([_tokenize(text) for text in TEXTS])
EMBEDDER = _maybe_load_embedder()

# Document embeddings must have exactly one row per retained document:
# the per-document score arrays are summed element-wise at query time, so a
# row-count mismatch would make every search fail. Pre-computed embeddings
# are therefore cut to the corpus length (not merely to MAX_DOCS) and are
# discarded when they cannot cover the corpus.
DOC_EMBEDDINGS = np.array([])
if EMBEDDINGS_PATH.exists():
    print(f"Loading pre-computed embeddings from {EMBEDDINGS_PATH}")
    _precomputed = np.load(str(EMBEDDINGS_PATH))
    if _precomputed.ndim == 2 and _precomputed.shape[0] >= len(CORPUS):
        # Assumes rows are in corpus order, as the original prefix
        # truncation did.
        DOC_EMBEDDINGS = _precomputed[: len(CORPUS)]
    else:
        print(
            f"Ignoring pre-computed embeddings with shape {_precomputed.shape}: "
            f"expected a 2-D array with at least {len(CORPUS)} rows."
        )
    del _precomputed

# Fall back to encoding the corpus when no usable embeddings were loaded.
if DOC_EMBEDDINGS.size == 0 and EMBEDDER is not None and TEXTS:
    print("Computing embeddings on the fly... (This may cause OOM on small instances)")
    DOC_EMBEDDINGS = np.array(EMBEDDER.encode(TEXTS))

# A blank pipeline keeps the service up when the trained model is missing,
# but it only tokenizes, so say so instead of degrading silently.
try:
    NLP = spacy.load("en_core_web_sm")
except Exception as e:
    print(f"Failed to load spaCy model 'en_core_web_sm', using a blank pipeline: {e}")
    NLP = spacy.blank("en")
# FastAPI application object. The CORS middleware below accepts requests from
# any origin, with any method and any header.
# NOTE(review): wildcard origins/methods/headers are combined with
# allow_credentials=True here. FastAPI's CORS documentation advises against
# that combination — confirm whether the frontend really sends credentials;
# if not, allow_credentials should be False, otherwise list origins explicitly.
app = FastAPI(title="Semantic IR Backend")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# NOTE(review): no route decorator is visible on this function in the reviewed
# source — confirm that it is registered on `app`.
def search(
    q: str = Query(..., min_length=1), top_k: int = Query(10, ge=1, le=50)
) -> Dict[str, Any]:
    """Rank the corpus for ``q`` under four scorers and return each top list.

    Scorers: BM25 over whitespace tokens, dense dot-product similarity,
    semantic-role overlap, and a hybrid that mixes the three after each has
    been scaled by its own maximum (weights 0.5 / 0.3 / 0.2).

    Args:
        q: Non-empty query string.
        top_k: Maximum number of hits per ranking (1-50).

    Returns:
        A dict with ``bm25_results``, ``dense_results``, ``srl_results`` and
        ``hybrid_results``; each is a list of at most ``top_k`` payloads from
        ``_build_result``, best first, containing only documents whose score
        under that ranking is strictly positive.
    """
    bm25_scores = np.array(BM25.get_scores(_tokenize(q)), dtype=float)
    dense_scores = _dense_scores(q)
    query_roles = _extract_query_roles(q, NLP)
    srl_scores = np.array(
        [_role_overlap_score(query_roles, doc) for doc in CORPUS], dtype=float
    )

    bm25_n = _norm(bm25_scores)
    dense_n = _norm(dense_scores)
    srl_n = _norm(srl_scores)
    hybrid_scores = 0.5 * bm25_n + 0.3 * dense_n + 0.2 * srl_n

    # A document usually shows up in several rankings. Building its payload
    # parses and renders the text, so build it once per request and reuse it.
    built: Dict[int, Dict[str, Any]] = {}

    def _build_results(
        idx_array: np.ndarray, score_arr: np.ndarray
    ) -> List[Dict[str, Any]]:
        results: List[Dict[str, Any]] = []
        for idx in idx_array:
            # idx_array is ordered best-first, so the first score that is not
            # strictly positive ends the useful part of the ranking. Written
            # as "not > 0" so NaN scores (sorted last) are excluded as well.
            if not score_arr[idx] > 0.0:
                break
            doc_idx = int(idx)
            if doc_idx not in built:
                built[doc_idx] = _build_result(
                    doc_idx, bm25_scores, dense_scores, srl_scores, hybrid_scores
                )
            results.append(built[doc_idx])
            if len(results) >= top_k:
                break
        return results

    bm25_results = _build_results(np.argsort(-bm25_scores), bm25_scores)
    dense_results = _build_results(np.argsort(-dense_scores), dense_scores)
    srl_results = _build_results(np.argsort(-srl_scores), srl_scores)
    hybrid_results = _build_results(np.argsort(-hybrid_scores), hybrid_scores)

    return {
        "bm25_results": bm25_results,
        "dense_results": dense_results,
        "srl_results": srl_results,
        "hybrid_results": hybrid_results,
    }
# NOTE(review): no route decorator is visible on this function in the reviewed
# source — confirm that it is registered on `app`.
def health() -> Dict[str, Any]:
    """Report service status, corpus size and whether dense scoring is usable.

    ``dense_enabled`` uses the same condition as ``_dense_scores``: an
    embedder must be loaded *and* document embeddings must be present.
    Reporting only the embedder would claim dense retrieval is on while every
    dense score is zero.
    """
    dense_enabled = EMBEDDER is not None and DOC_EMBEDDINGS.size > 0
    return {
        "status": "ok",
        "documents": len(CORPUS),
        "dense_enabled": dense_enabled,
    }