"""Query expansion and hybrid lexical/embedding ranking helpers for regulatory document retrieval."""
import re
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

try:
    from sklearn.feature_extraction.text import TfidfVectorizer
except Exception:  # pragma: no cover - fallback path for minimal runtime
    TfidfVectorizer = None

# Seed vocabulary for expanding user queries by toxicology endpoint.
ENDPOINT_QUERY_HINTS: Dict[str, List[str]] = {
    "Genotoxicity (OECD TG)": [
        "genotoxicity",
        "mutagenicity",
        "AMES",
        "micronucleus",
        "comet assay",
        "chromosomal aberration",
        "OECD TG 471 473 476 487 490 474 489",
    ],
    "NAMs / In Silico": [
        "in silico",
        "QSAR",
        "read-across",
        "AOP",
        "PBPK",
        "high-throughput",
        "omics",
        "organ-on-chip",
        "microphysiological",
    ],
    "Acute toxicity": ["acute toxicity", "LD50", "LC50", "single dose", "lethality", "mortality"],
    "Repeated dose toxicity": [
        "repeated dose",
        "subchronic",
        "chronic",
        "NOAEL",
        "LOAEL",
        "target organ",
        "90-day",
        "28-day",
    ],
    "Irritation / Sensitization": ["skin irritation", "eye irritation", "sensitization", "LLNA", "Draize"],
    "Repro / Developmental": ["reproductive toxicity", "fertility", "developmental toxicity", "teratogenic", "prenatal", "postnatal"],
    "Carcinogenicity": ["carcinogenicity", "tumor", "neoplasm", "cancer", "two-year bioassay"],
}

# Framework-specific vocabulary (FDA Center for Tobacco Products, EPA).
FRAMEWORK_QUERY_HINTS: Dict[str, List[str]] = {
    "FDA CTP": [
        "genotoxicity hazard identification",
        "carcinogenicity tiering",
        "excess lifetime cancer risk",
        "constituent comparison",
        "weight of evidence",
    ],
    "EPA": [
        "cancer slope factor",
        "inhalation unit risk",
        "lifetime excess cancer risk",
        "mode of action",
        "weight of evidence descriptors",
    ],
}

# Terms that flag quantitative inputs for dose/risk equations.
EQUATION_INPUT_HINTS: List[str] = [
    "exposure concentration",
    "daily intake",
    "mg/kg-day",
    "ug/m3",
    "cancer slope factor",
    "inhalation unit risk",
    "body weight",
]


def clean_text(t: str) -> str:
    """Strip NUL bytes and collapse all whitespace runs to single spaces."""
    t = (t or "").replace("\x00", " ")
    return re.sub(r"\s+", " ", t).strip()


def split_sentences(text: str) -> List[str]:
    """Naive sentence splitter: break on whitespace following '.', '!', or '?'."""
    t = clean_text(text)
    if not t:
        return []
    return [x.strip() for x in re.split(r"(?<=[\.!\?])\s+", t) if x.strip()]


def _tokenize(s: str) -> List[str]:
    """Lowercased alphanumeric/hyphen tokens at least 3 characters long."""
    return [w for w in re.findall(r"[a-zA-Z0-9\-]+", (s or "").lower()) if len(w) >= 3]


def extract_evidence_span(page_text: str, query: str, page: Optional[int] = None, n_sentences: int = 5) -> Dict[str, Any]:
    """Return a window of about n_sentences sentences around the first query hit."""
    sents = split_sentences(page_text)
    if not sents:
        return {"text": "", "page": page, "start_sentence": 0, "mode": "empty"}
    qwords = _tokenize(query)
    hit_i = None
    for i, s in enumerate(sents):
        sl = s.lower()
        if any(w in sl for w in qwords):
            hit_i = i
            break
    if hit_i is None:
        # No query token found anywhere: fall back to the opening sentences.
        snippet = " ".join(sents[:n_sentences])
        return {"text": snippet, "page": page, "start_sentence": 0, "mode": "fallback"}
    # Center the window on the hit so n_sentences is honored on this path too
    # (with the default of 5 this is a +/-2 sentence window).
    half = max(0, (n_sentences - 1) // 2)
    start = max(0, hit_i - half)
    end = min(len(sents), hit_i + (n_sentences - half))
    snippet = " ".join(sents[start:end])
    return {"text": snippet, "page": page, "start_sentence": start, "mode": "hit"}


def build_query_families(
    base_queries: List[str],
    endpoint_modules: Optional[List[str]] = None,
    frameworks: Optional[List[str]] = None,
) -> Dict[str, List[str]]:
    """Group query terms into named families: base, endpoint, framework, equation_inputs."""
    endpoint_modules = endpoint_modules or []
    frameworks = frameworks or []
    endpoint_terms: List[str] = []
    for ep in endpoint_modules:
        endpoint_terms.extend(ENDPOINT_QUERY_HINTS.get(ep, []))
    framework_terms: List[str] = []
    for fw in frameworks:
        framework_terms.extend(FRAMEWORK_QUERY_HINTS.get(fw, []))
    return {
        "base": [q for q in base_queries if (q or "").strip()],
        "endpoint": endpoint_terms,
        "framework": framework_terms,
        "equation_inputs": EQUATION_INPUT_HINTS,
    }


def expand_regulatory_queries(
    base_queries: List[str],
    endpoint_modules: Optional[List[str]] = None,
    frameworks: Optional[List[str]] = None,
    extra_terms: Optional[List[str]] = None,
) -> Tuple[List[str], Dict[str, List[str]]]:
    """Flatten all query families plus extra terms into one deduplicated query list."""
    families = build_query_families(base_queries, endpoint_modules, frameworks)
    queries: List[str] = []
    for vals in families.values():
        queries.extend(vals)
    queries.extend(extra_terms or [])
    # Case-insensitive dedup that preserves first-seen order and casing.
    deduped: List[str] = []
    seen = set()
    for q in queries:
        x = (q or "").strip()
        if not x:
            continue
        k = x.lower()
        if k in seen:
            continue
        seen.add(k)
        deduped.append(x)
    return deduped, families
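
# Usage sketch (inputs are illustrative):
#   queries, families = expand_regulatory_queries(
#       ["nicotine genotoxicity"],
#       endpoint_modules=["Genotoxicity (OECD TG)"],
#       frameworks=["EPA"],
#   )
# "queries" is the deduplicated union of the base query, the matching hint
# terms, and EQUATION_INPUT_HINTS; "families" keeps the grouped view.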


def _lexical_ranks(texts: List[str], query: str) -> Tuple[List[int], np.ndarray]:
    """Rank texts lexically against the query; returns (descending index order, scores)."""
    if not texts:
        return [], np.array([], dtype=np.float32)
    if TfidfVectorizer is None:
        # Fallback when scikit-learn is unavailable: count query tokens per text.
        q_tokens = set(_tokenize(query))
        sims = []
        for t in texts:
            tl = t.lower()
            sims.append(float(sum(1 for tok in q_tokens if tok in tl)))
        arr = np.array(sims, dtype=np.float32)
        order = list(np.argsort(arr)[::-1])
        return order, arr
    vec = TfidfVectorizer(stop_words="english", ngram_range=(1, 2), max_features=25000)
    x = vec.fit_transform(texts)
    qv = vec.transform([query])
    # TfidfVectorizer L2-normalizes rows by default, so the dot product is cosine similarity.
    sims = (x @ qv.T).toarray().ravel().astype(np.float32)
    order = list(np.argsort(sims)[::-1])
    return order, sims


def _embedding_ranks(item_embeddings: np.ndarray, query_embedding: np.ndarray) -> Tuple[List[int], np.ndarray]:
    """Rank items by cosine similarity of L2-normalized embeddings to the query."""
    if item_embeddings.size == 0:
        return [], np.array([], dtype=np.float32)
    q = np.asarray(query_embedding, dtype=np.float32)
    q = q / (np.linalg.norm(q) + 1e-12)
    mat = np.asarray(item_embeddings, dtype=np.float32)
    mat = mat / (np.linalg.norm(mat, axis=1, keepdims=True) + 1e-12)
    sims = (mat @ q).astype(np.float32)
    order = list(np.argsort(sims)[::-1])
    return order, sims


def _rrf_score(rank_lists: List[List[int]], k: int = 60) -> Dict[int, float]:
    """Reciprocal Rank Fusion: score(i) = sum over lists of 1 / (k + rank(i) + 1)."""
    out: Dict[int, float] = {}
    for rank_list in rank_lists:
        for rank_pos, idx in enumerate(rank_list):
            out[idx] = out.get(idx, 0.0) + (1.0 / (k + rank_pos + 1.0))
    return out
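
# Worked example: with the default k=60, an item ranked first in one list and
# third in another scores 1/(60+0+1) + 1/(60+2+1) = 1/61 + 1/63 ~= 0.0323.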


def _family_coverage_score(texts: List[str], families: Dict[str, List[str]]) -> Dict[str, float]:
    """Per family, the fraction of hint queries with at least one token in the merged texts."""
    merged = " ".join([clean_text(t).lower() for t in texts])
    out: Dict[str, float] = {}
    for family, queries in families.items():
        if not queries:
            out[family] = 0.0
            continue
        hits = 0
        for q in queries:
            tokens = _tokenize(q)
            # A hint query counts as covered if any of its tokens appears as a substring.
            if any(t in merged for t in tokens):
                hits += 1
        out[family] = round(hits / max(1, len(queries)), 4)
    return out
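
# Example: for families={"base": ["LD50", "acute toxicity"]} and selected texts
# containing "Oral LD50 was 300 mg/kg" but neither "acute" nor "toxicity",
# coverage is {"base": 0.5} (1 of the 2 hint queries matched).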


def hybrid_rank_text_items(
    items: List[Dict[str, Any]],
    query: str,
    families: Optional[Dict[str, List[str]]] = None,
    top_k: int = 12,
    item_embeddings: Optional[np.ndarray] = None,
    query_embedding: Optional[np.ndarray] = None,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Rank items lexically, optionally fuse with embedding ranks via RRF, and attach diagnostics."""
    if not items:
        return [], {
            "ranking_method": "empty",
            "selected_indices": [],
            "coverage_by_query_family": families or {},
            "coverage_score": 0.0,
            "component_scores": {},
        }
    texts = [clean_text(i.get("text", "")) for i in items]
    lex_order, lex_scores = _lexical_ranks(texts, query)
    rank_lists = [lex_order]
    method = "lexical_only"
    emb_scores = None
    if item_embeddings is not None and query_embedding is not None:
        try:
            emb_order, emb_scores = _embedding_ranks(item_embeddings, query_embedding)
            rank_lists.append(emb_order)
            method = "hybrid_rrf"
        except Exception:
            # Degrade gracefully to lexical-only ranking on malformed embeddings.
            emb_scores = None
    rrf = _rrf_score(rank_lists)
    final_order = sorted(rrf.keys(), key=lambda idx: rrf[idx], reverse=True)
    selected_indices = [int(x) for x in final_order[: max(1, int(top_k))]]
    selected: List[Dict[str, Any]] = []
    for idx in selected_indices:
        row = dict(items[idx])  # shallow copy so callers' items are not mutated
        row["_nlp_rrf_score"] = float(rrf.get(idx, 0.0))
        row["_nlp_lex_score"] = float(lex_scores[idx]) if len(lex_scores) > idx else 0.0
        if emb_scores is not None and len(emb_scores) > idx:
            row["_nlp_emb_score"] = float(emb_scores[idx])
        selected.append(row)
    fam = families or {"base": [query]}
    cov = _family_coverage_score([x.get("text", "") for x in selected], fam)
    cov_score = round(float(np.mean(list(cov.values()))) if cov else 0.0, 4)
    diagnostics = {
        "ranking_method": method,
        "selected_indices": [int(x) for x in selected_indices],
        "coverage_by_query_family": cov,
        "coverage_score": cov_score,
        "component_scores": {
            "lexical": [float(lex_scores[i]) for i in selected_indices if len(lex_scores) > i],
            "embedding": [float(emb_scores[i]) for i in selected_indices if emb_scores is not None and len(emb_scores) > i],
        },
    }
    return selected, diagnostics
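

if __name__ == "__main__":
    # Minimal end-to-end sketch with made-up documents (illustrative only);
    # real callers would pass document chunks and, optionally, precomputed
    # item/query embeddings to enable the hybrid RRF path.
    docs = [
        {"text": "The Ames test (OECD TG 471) was negative for mutagenicity."},
        {"text": "A 90-day repeated dose study identified a NOAEL of 5 mg/kg-day."},
        {"text": "Packaging and shipping conditions are described elsewhere."},
    ]
    qs, fams = expand_regulatory_queries(["genotoxicity"], endpoint_modules=["Genotoxicity (OECD TG)"])
    ranked, diag = hybrid_rank_text_items(docs, " ".join(qs[:5]), families=fams, top_k=2)
    print(diag["ranking_method"], diag["coverage_score"])
    for row in ranked:
        print(round(row["_nlp_rrf_score"], 4), row["text"][:60])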