# Kamal Nayan Kumar
# feat: revert backend to use sentence-transformers (all-MiniLM-L6-v2) for better embeddings
# 548ff8f
import json
import os
from pathlib import Path
from typing import Any, Dict, List
import numpy as np
import spacy
from spacy import displacy
from fastapi import FastAPI, Query
from fastapi.middleware.cors import CORSMiddleware
from rank_bm25 import BM25Okapi
try:
from sentence_transformers import SentenceTransformer
except Exception:
SentenceTransformer = None
# BASE_DIR is the parent of the directory that contains this file; the corpus
# and the optional pre-computed embeddings both live in its "data" folder.
BASE_DIR = Path(__file__).resolve().parents[1]
_DATA_DIR = BASE_DIR / "data"
CORPUS_PATH = _DATA_DIR / "corpus.json"
EMBEDDINGS_PATH = _DATA_DIR / "embeddings.npy"

# Cap on the number of corpus documents kept in memory, overridable through
# the MAX_DOCS environment variable. The original author notes that Hugging
# Face Spaces offers 16GB of free RAM, so the default is generous; pre-computed
# embeddings are still preferred when present because they shorten boot time.
MAX_DOCS = int(os.getenv("MAX_DOCS", 50000))
def _norm(values: np.ndarray) -> np.ndarray:
if values.size == 0:
return values
v_max = float(np.max(values))
if v_max <= 0:
return np.zeros_like(values, dtype=float)
return values / v_max
def _tokenize(text: str) -> List[str]:
return [tok.lower() for tok in text.split() if tok.strip()]
def _extract_query_roles(query: str, nlp_model) -> Dict[str, str]:
doc = nlp_model(query)
predicate = ""
arg0: List[str] = []
arg1: List[str] = []
argm_tmp: List[str] = []
for tok in doc:
if not predicate and tok.pos_ == "VERB":
predicate = tok.lemma_.lower()
for tok in doc:
if tok.dep_ in {"nsubj", "nsubjpass"}:
arg0.append(" ".join(t.text for t in tok.subtree))
elif tok.dep_ in {"dobj", "obj", "iobj"}:
arg1.append(" ".join(t.text for t in tok.subtree))
elif tok.dep_ in {"obl", "npadvmod"} and tok.ent_type_ in {"DATE", "TIME"}:
argm_tmp.append(" ".join(t.text for t in tok.subtree))
out: Dict[str, str] = {"predicate": predicate}
if arg0:
out["ARG0"] = " ".join(dict.fromkeys(arg0))
if arg1:
out["ARG1"] = " ".join(dict.fromkeys(arg1))
if argm_tmp:
out["ARGM-TMP"] = " ".join(dict.fromkeys(argm_tmp))
return out
def _role_overlap_score(
query_roles: Dict[str, str], doc_roles: Dict[str, str]
) -> float:
keys = [k for k, v in query_roles.items() if v]
if not keys:
return 0.0
scores: List[float] = []
for key in keys:
if key == "predicate":
q_pred = query_roles.get("predicate", "").split(".")[0]
d_pred = doc_roles.get("predicate", "").split(".")[0]
scores.append(1.0 if q_pred and d_pred and q_pred == d_pred else 0.0)
continue
q_tokens = set(_tokenize(query_roles.get(key, "")))
d_tokens = set(_tokenize(doc_roles.get(key, "")))
if not q_tokens or not d_tokens:
scores.append(0.0)
continue
inter = len(q_tokens.intersection(d_tokens))
union = len(q_tokens.union(d_tokens))
scores.append(inter / union if union else 0.0)
return float(np.mean(scores)) if scores else 0.0
def _maybe_load_embedder() -> Any:
    """Instantiate the ``all-MiniLM-L6-v2`` sentence encoder when possible.

    Returns ``None`` when ``SentenceTransformer`` is unavailable (the guarded
    import at the top of the file set it to ``None``) or when constructing
    the model raises; in the latter case the error is printed.
    """
    if SentenceTransformer is not None:
        try:
            return SentenceTransformer("all-MiniLM-L6-v2")
        except Exception as e:
            # Best effort: the service still runs without dense retrieval.
            print(f"Failed to load embedder: {e}")
    return None
def _dense_scores(query: str) -> np.ndarray:
    """Return one dense-retrieval score per corpus document for *query*.

    The score is the dot product between each row of ``DOC_EMBEDDINGS`` and
    the encoded query. When no embedder is loaded or there are no document
    embeddings, an all-zero array with one entry per ``CORPUS`` document is
    returned instead.
    """
    dense_available = EMBEDDER is not None and DOC_EMBEDDINGS.size != 0
    if not dense_available:
        return np.zeros(len(CORPUS), dtype=float)

    query_vector = EMBEDDER.encode(query)
    similarities = np.dot(DOC_EMBEDDINGS, query_vector)
    return np.array(similarities, dtype=float)
# Document fields that hold semantic-role annotations. _build_result copies
# the non-empty ones into the "roles" part of each search result.
ROLE_KEYS = {
    # The predicate itself.
    "predicate",
    # Numbered arguments.
    "ARG0", "ARG1", "ARG2", "ARG3", "ARG4",
    # Modifier (ARGM-*) arguments.
    "ARGM-LOC", "ARGM-TMP", "ARGM-MNR", "ARGM-CAU",
    "ARGM-DIS", "ARGM-ADV", "ARGM-MOD", "ARGM-NEG",
}
def _build_result(
    idx: int,
    bm25_scores: np.ndarray,
    dense_scores: np.ndarray,
    srl_scores: np.ndarray,
    hybrid_scores: np.ndarray,
) -> Dict[str, Any]:
    """Assemble the response payload for the corpus document at *idx*.

    Args:
        idx: Position of the document in ``CORPUS`` and in every score array.
        bm25_scores, dense_scores, srl_scores, hybrid_scores: Per-document
            score arrays; the entry at ``idx`` of each is reported.

    Returns:
        A dict with the document ``id`` and ``text``, its four scores as
        plain floats, its non-empty role annotations, and a displaCy
        dependency rendering of the text (empty string when parsing or
        rendering raises).
    """
    record = CORPUS[int(idx)]
    body = record.get("text", "")

    # Keep only the role fields that actually carry a value.
    annotated_roles: Dict[str, Any] = {}
    for field, value in record.items():
        if field in ROLE_KEYS and value:
            annotated_roles[field] = value

    render_options = {
        "compact": True,
        "color": "#e4e4e7",
        "font": "sans-serif",
        "bg": "transparent",
    }
    try:
        markup = displacy.render(
            NLP(body), style="dep", page=False, options=render_options
        )
    except Exception:
        # Visualisation is optional; never fail a search because of it.
        markup = ""

    per_ranker = (
        ("bm25", bm25_scores),
        ("dense", dense_scores),
        ("srl", srl_scores),
        ("hybrid", hybrid_scores),
    )
    return {
        "id": record.get("id"),
        "text": body,
        "scores": {label: float(values[idx]) for label, values in per_ranker},
        "roles": annotated_roles,
        "displacy_html": markup,
    }
# ---------------------------------------------------------------------------
# Startup: load the corpus, build the BM25 index, obtain document embeddings
# and load the spaCy pipeline. All of this runs once, at import time.
# ---------------------------------------------------------------------------
with CORPUS_PATH.open("r", encoding="utf-8") as f:
    CORPUS: List[Dict[str, Any]] = json.load(f)
if len(CORPUS) > MAX_DOCS:
    print(f"Truncating corpus from {len(CORPUS)} to {MAX_DOCS} docs to prevent OOM.")
    CORPUS = CORPUS[:MAX_DOCS]

TEXTS = [doc.get("text", "") for doc in CORPUS]
BM25 = BM25Okapi([_tokenize(text) for text in TEXTS])
EMBEDDER = _maybe_load_embedder()

# Dense scores are combined element-wise with the BM25 and SRL scores, so
# DOC_EMBEDDINGS must have exactly one row per corpus document. Start empty
# (dense retrieval disabled) and only fill it with a matching matrix.
DOC_EMBEDDINGS = np.array([])
if EMBEDDINGS_PATH.exists():
    print(f"Loading pre-computed embeddings from {EMBEDDINGS_PATH}")
    _precomputed = np.load(str(EMBEDDINGS_PATH))
    if len(_precomputed) >= len(CORPUS):
        # Keep the leading rows, mirroring the prefix truncation applied to
        # the corpus above.
        DOC_EMBEDDINGS = _precomputed[: len(CORPUS)]
    else:
        print(
            f"Ignoring pre-computed embeddings: {len(_precomputed)} rows "
            f"for {len(CORPUS)} docs."
        )
if DOC_EMBEDDINGS.size == 0 and EMBEDDER is not None and TEXTS:
    print("Computing embeddings on the fly... (This may cause OOM on small instances)")
    DOC_EMBEDDINGS = np.array(EMBEDDER.encode(TEXTS))

try:
    NLP = spacy.load("en_core_web_sm")
except Exception as e:
    # Same best-effort policy as the embedder: keep serving, but say so.
    # A blank pipeline only tokenises, so tokens carry no tags or parse.
    print(f"Failed to load spaCy model 'en_core_web_sm', using a blank pipeline: {e}")
    NLP = spacy.blank("en")
app = FastAPI(title="Semantic IR Backend")

# The API is open to every origin. A wildcard origin must not be combined
# with credentialed requests (cookies / Authorization headers), so credential
# support is switched off; the endpoints in this file read neither.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/search")
def search(
    q: str = Query(..., min_length=1), top_k: int = Query(10, ge=1, le=50)
) -> Dict[str, Any]:
    """Rank the corpus for *q* with BM25, dense, SRL and hybrid scoring.

    Args:
        q: Query text (at least one character).
        top_k: Maximum number of results per ranking (1 to 50).

    Returns:
        A dict with ``bm25_results``, ``dense_results``, ``srl_results`` and
        ``hybrid_results``. Each is a list of at most ``top_k`` payloads from
        ``_build_result``, ordered by descending score and limited to
        documents whose score in that ranking is strictly positive. The
        hybrid score is ``0.5 * bm25 + 0.3 * dense + 0.2 * srl`` over the
        max-normalised individual scores.
    """
    bm25_scores = np.array(BM25.get_scores(_tokenize(q)), dtype=float)
    dense_scores = _dense_scores(q)
    query_roles = _extract_query_roles(q, NLP)
    srl_scores = np.array(
        [_role_overlap_score(query_roles, doc) for doc in CORPUS], dtype=float
    )

    bm25_n = _norm(bm25_scores)
    dense_n = _norm(dense_scores)
    srl_n = _norm(srl_scores)
    hybrid_scores = 0.5 * bm25_n + 0.3 * dense_n + 0.2 * srl_n

    # A payload depends only on the document index, not on the ranking that
    # selected it, so build (parse + render) each document once per request.
    payloads: Dict[int, Dict[str, Any]] = {}

    def _top_results(score_arr: np.ndarray) -> List[Dict[str, Any]]:
        """Return payloads for the best positive-scoring documents."""
        results: List[Dict[str, Any]] = []
        for raw_idx in np.argsort(-score_arr):
            idx = int(raw_idx)
            if score_arr[idx] <= 0.0:
                # Indices are in descending score order, so nothing after
                # this point can be positive.
                break
            if idx not in payloads:
                payloads[idx] = _build_result(
                    idx, bm25_scores, dense_scores, srl_scores, hybrid_scores
                )
            results.append(payloads[idx])
            if len(results) >= top_k:
                break
        return results

    return {
        "bm25_results": _top_results(bm25_scores),
        "dense_results": _top_results(dense_scores),
        "srl_results": _top_results(srl_scores),
        "hybrid_results": _top_results(hybrid_scores),
    }
@app.get("/health")
def health() -> Dict[str, Any]:
    """Report service liveness and basic index statistics.

    ``dense_enabled`` is true only when dense scoring can actually run: an
    embedder is loaded and there are document embeddings. This is the same
    condition ``_dense_scores`` checks before scoring.
    """
    return {
        "status": "ok",
        "documents": len(CORPUS),
        "dense_enabled": EMBEDDER is not None and DOC_EMBEDDINGS.size > 0,
    }