# NLP_Project / toxra_core / nlp_pipeline.py
import re
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
# Optional dependency: scikit-learn provides TF-IDF lexical ranking.
try:
    from sklearn.feature_extraction.text import TfidfVectorizer
except Exception:  # pragma: no cover - fallback path for minimal runtime
    # Sentinel checked in _lexical_ranks; when None, a token-overlap
    # fallback scorer is used instead of TF-IDF.
    TfidfVectorizer = None
# Maps an endpoint-module label to lexical query hints (assay names, OECD test
# guideline numbers, study descriptors) used by build_query_families to expand
# retrieval queries for that toxicology endpoint.
ENDPOINT_QUERY_HINTS: Dict[str, List[str]] = {
    "Genotoxicity (OECD TG)": [
        "genotoxicity",
        "mutagenicity",
        "AMES",
        "micronucleus",
        "comet assay",
        "chromosomal aberration",
        "OECD TG 471 473 476 487 490 474 489",
    ],
    "NAMs / In Silico": [
        "in silico",
        "QSAR",
        "read-across",
        "AOP",
        "PBPK",
        "high-throughput",
        "omics",
        "organ-on-chip",
        "microphysiological",
    ],
    "Acute toxicity": ["acute toxicity", "LD50", "LC50", "single dose", "lethality", "mortality"],
    "Repeated dose toxicity": [
        "repeated dose",
        "subchronic",
        "chronic",
        "NOAEL",
        "LOAEL",
        "target organ",
        "90-day",
        "28-day",
    ],
    "Irritation / Sensitization": ["skin irritation", "eye irritation", "sensitization", "LLNA", "Draize"],
    "Repro / Developmental": ["reproductive toxicity", "fertility", "developmental toxicity", "teratogenic", "prenatal", "postnatal"],
    "Carcinogenicity": ["carcinogenicity", "tumor", "neoplasm", "cancer", "two-year bioassay"],
}
# Maps a regulatory-framework label to phrasing characteristic of that
# agency's risk-assessment documents; consumed by build_query_families
# alongside ENDPOINT_QUERY_HINTS.
FRAMEWORK_QUERY_HINTS: Dict[str, List[str]] = {
    "FDA CTP": [
        "genotoxicity hazard identification",
        "carcinogenicity tiering",
        "excess lifetime cancer risk",
        "constituent comparison",
        "weight of evidence",
    ],
    "EPA": [
        "cancer slope factor",
        "inhalation unit risk",
        "lifetime excess cancer risk",
        "mode of action",
        "weight of evidence descriptors",
    ],
}
# Terms that locate quantitative inputs for risk equations (exposure metrics,
# dose units, potency factors) in retrieved text; always included as the
# "equation_inputs" family by build_query_families.
EQUATION_INPUT_HINTS: List[str] = [
    "exposure concentration",
    "daily intake",
    "mg/kg-day",
    "ug/m3",
    "cancer slope factor",
    "inhalation unit risk",
    "body weight",
]
def clean_text(t: str) -> str:
    """Return *t* with NUL bytes removed and all whitespace runs collapsed to
    single spaces, stripped at both ends. ``None`` yields ``""``."""
    without_nuls = (t or "").replace("\x00", " ")
    collapsed = re.sub(r"\s+", " ", without_nuls)
    return collapsed.strip()
def split_sentences(text: str) -> List[str]:
    """Split *text* into sentences at '.', '!' or '?' followed by whitespace.

    The text is normalized first (NULs dropped, whitespace collapsed — same
    contract as clean_text); empty/blank input yields an empty list.
    """
    normalized = re.sub(r"\s+", " ", (text or "").replace("\x00", " ")).strip()
    if not normalized:
        return []
    segments = re.split(r"(?<=[\.!\?])\s+", normalized)
    return [seg.strip() for seg in segments if seg.strip()]
def _tokenize(s: str) -> List[str]:
return [w for w in re.findall(r"[a-zA-Z0-9\-]+", (s or "").lower()) if len(w) >= 3]
def extract_evidence_span(
    page_text: str,
    query: str,
    page: Optional[int] = None,
    n_sentences: int = 5,
    context: int = 2,
) -> Dict[str, Any]:
    """Extract a short evidence snippet from *page_text* relevant to *query*.

    Finds the first sentence containing any query token and returns it with
    ``context`` sentences of surrounding context on each side (default 2,
    matching the previous hard-coded window). If no sentence matches, falls
    back to the first ``n_sentences`` sentences.

    Args:
        page_text: Raw page text; sentence-split via split_sentences.
        query: Query string; tokenized via _tokenize for matching.
        page: Optional page number, echoed back in the result.
        n_sentences: Number of leading sentences used in the fallback snippet.
        context: Sentences of context kept before/after the matching sentence.

    Returns:
        Dict with keys "text" (snippet), "page", "start_sentence" (index of
        the snippet's first sentence), and "mode" ("hit", "fallback", or
        "empty" when no sentences were found).
    """
    sents = split_sentences(page_text)
    if not sents:
        return {"text": "", "page": page, "start_sentence": 0, "mode": "empty"}
    qwords = _tokenize(query)
    hit_i = None
    for i, s in enumerate(sents):
        sl = s.lower()
        # Substring containment (not word-boundary) match, consistent with
        # the lexical fallback scorer in _lexical_ranks.
        if any(w in sl for w in qwords):
            hit_i = i
            break
    if hit_i is None:
        snippet = " ".join(sents[:n_sentences])
        return {"text": snippet, "page": page, "start_sentence": 0, "mode": "fallback"}
    start = max(0, hit_i - context)
    end = min(len(sents), hit_i + context + 1)
    snippet = " ".join(sents[start:end])
    return {"text": snippet, "page": page, "start_sentence": start, "mode": "hit"}
def build_query_families(
    base_queries: List[str],
    endpoint_modules: Optional[List[str]] = None,
    frameworks: Optional[List[str]] = None,
) -> Dict[str, List[str]]:
    """Group retrieval query terms into named families.

    Returns a dict with keys "base" (non-blank caller queries), "endpoint"
    (hints for the selected endpoint modules), "framework" (hints for the
    selected regulatory frameworks), and "equation_inputs" (the static
    EQUATION_INPUT_HINTS list). Unknown module/framework names are ignored.
    """
    endpoint_terms: List[str] = []
    for module_name in (endpoint_modules or []):
        endpoint_terms.extend(ENDPOINT_QUERY_HINTS.get(module_name, []))

    framework_terms: List[str] = []
    for framework_name in (frameworks or []):
        framework_terms.extend(FRAMEWORK_QUERY_HINTS.get(framework_name, []))

    non_blank_base = [q for q in base_queries if (q or "").strip()]
    return {
        "base": non_blank_base,
        "endpoint": endpoint_terms,
        "framework": framework_terms,
        "equation_inputs": EQUATION_INPUT_HINTS,
    }
def expand_regulatory_queries(
    base_queries: List[str],
    endpoint_modules: Optional[List[str]] = None,
    frameworks: Optional[List[str]] = None,
    extra_terms: Optional[List[str]] = None,
) -> Tuple[List[str], Dict[str, List[str]]]:
    """Flatten all query families plus *extra_terms* into one deduplicated list.

    Deduplication is case-insensitive and keeps the first-seen spelling;
    blank entries are dropped. Returns (deduped queries, families dict).
    """
    families = build_query_families(base_queries, endpoint_modules, frameworks)
    candidates: List[str] = [q for terms in families.values() for q in terms]
    candidates.extend(extra_terms or [])

    deduped: List[str] = []
    seen_lower = set()
    for candidate in candidates:
        term = (candidate or "").strip()
        if not term:
            continue
        key = term.lower()
        if key not in seen_lower:
            seen_lower.add(key)
            deduped.append(term)
    return deduped, families
def _lexical_ranks(texts: List[str], query: str) -> Tuple[List[int], np.ndarray]:
if not texts:
return [], np.array([], dtype=np.float32)
if TfidfVectorizer is None:
q_tokens = set(_tokenize(query))
sims = []
for t in texts:
tl = t.lower()
sims.append(float(sum(1 for tok in q_tokens if tok in tl)))
arr = np.array(sims, dtype=np.float32)
order = list(np.argsort(arr)[::-1])
return order, arr
vec = TfidfVectorizer(stop_words="english", ngram_range=(1, 2), max_features=25000)
x = vec.fit_transform(texts)
qv = vec.transform([query])
sims = (x @ qv.T).toarray().ravel().astype(np.float32)
order = list(np.argsort(sims)[::-1])
return order, sims
def _embedding_ranks(item_embeddings: np.ndarray, query_embedding: np.ndarray) -> Tuple[List[int], np.ndarray]:
if item_embeddings.size == 0:
return [], np.array([], dtype=np.float32)
q = np.asarray(query_embedding, dtype=np.float32)
qn = np.linalg.norm(q) + 1e-12
q = q / qn
mat = np.asarray(item_embeddings, dtype=np.float32)
norms = np.linalg.norm(mat, axis=1, keepdims=True) + 1e-12
mat = mat / norms
sims = (mat @ q).astype(np.float32)
order = list(np.argsort(sims)[::-1])
return order, sims
def _rrf_score(rank_lists: List[List[int]], k: int = 60) -> Dict[int, float]:
out: Dict[int, float] = {}
for rank_list in rank_lists:
for rank_pos, idx in enumerate(rank_list):
out[idx] = out.get(idx, 0.0) + (1.0 / (k + rank_pos + 1.0))
return out
def _family_coverage_score(texts: List[str], families: Dict[str, List[str]]) -> Dict[str, float]:
merged = " ".join([clean_text(t).lower() for t in texts])
out: Dict[str, float] = {}
for family, queries in families.items():
if not queries:
out[family] = 0.0
continue
hits = 0
for q in queries:
tokens = _tokenize(q)
if any(t in merged for t in tokens):
hits += 1
out[family] = round(hits / max(1, len(queries)), 4)
return out
def hybrid_rank_text_items(
    items: List[Dict[str, Any]],
    query: str,
    families: Optional[Dict[str, List[str]]] = None,
    top_k: int = 12,
    item_embeddings: Optional[np.ndarray] = None,
    query_embedding: Optional[np.ndarray] = None,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Select the top items for *query* by fusing lexical and embedding rankings.

    Lexical ranking always runs; embedding ranking is added when both
    ``item_embeddings`` and ``query_embedding`` are provided (failures there
    are swallowed and the method degrades to lexical-only). Rankings are fused
    with RRF and the top ``top_k`` (at least 1) items are returned, each copy
    annotated with ``_nlp_rrf_score``, ``_nlp_lex_score`` and — when available
    — ``_nlp_emb_score``. Also returns a diagnostics dict with the method,
    selected indices, per-family coverage, mean coverage score, and the
    component scores of the selected items.
    """
    if not items:
        return [], {
            "ranking_method": "empty",
            "selected_indices": [],
            "coverage_by_query_family": families or {},
            "coverage_score": 0.0,
            "component_scores": {},
        }

    texts = [clean_text(item.get("text", "")) for item in items]
    lex_order, lex_scores = _lexical_ranks(texts, query)

    orderings = [lex_order]
    method = "lexical_only"
    emb_scores = None
    if item_embeddings is not None and query_embedding is not None:
        try:
            emb_order, emb_scores = _embedding_ranks(item_embeddings, query_embedding)
            orderings.append(emb_order)
            method = "hybrid_rrf"
        except Exception:
            # Best-effort: fall back to lexical-only on any embedding failure.
            emb_scores = None

    fused = _rrf_score(orderings)
    ranked = sorted(fused, key=fused.get, reverse=True)
    keep = max(1, int(top_k))
    selected_indices = [int(i) for i in ranked[:keep]]

    selected: List[Dict[str, Any]] = []
    for idx in selected_indices:
        enriched = dict(items[idx])  # shallow copy so callers' items stay untouched
        enriched["_nlp_rrf_score"] = float(fused.get(idx, 0.0))
        enriched["_nlp_lex_score"] = float(lex_scores[idx]) if len(lex_scores) > idx else 0.0
        if emb_scores is not None and len(emb_scores) > idx:
            enriched["_nlp_emb_score"] = float(emb_scores[idx])
        selected.append(enriched)

    fam = families or {"base": [query]}
    coverage = _family_coverage_score([row.get("text", "") for row in selected], fam)
    coverage_score = round(float(np.mean(list(coverage.values()))) if coverage else 0.0, 4)

    diagnostics = {
        "ranking_method": method,
        "selected_indices": list(selected_indices),
        "coverage_by_query_family": coverage,
        "coverage_score": coverage_score,
        "component_scores": {
            "lexical": [float(lex_scores[i]) for i in selected_indices if len(lex_scores) > i],
            "embedding": [
                float(emb_scores[i])
                for i in selected_indices
                if emb_scores is not None and len(emb_scores) > i
            ],
        },
    }
    return selected, diagnostics