Spaces:
Running
Running
| """Advanced biomedical feature engineering for DDI prediction. | |
| This module extends the lightweight 20-feature baseline with optional: | |
| - RDKit Morgan fingerprints and pairwise similarity metrics | |
| - pharmacology feature vectors (CYP450, ATC, targets, transporters, MOA) | |
| - semantic biomedical embeddings with cache-backed fallbacks | |
| - pairwise interaction features for shared pathways and metabolism conflicts | |
| All expensive feature sources are optional. When a source is unavailable, | |
| the module falls back to deterministic hashed vectors so the pipeline remains | |
| deployable on CPU with p99 latency constraints. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass, asdict | |
| import hashlib | |
| import json | |
| import logging | |
| from pathlib import Path | |
| from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple | |
| import joblib | |
| import numpy as np | |
| try: | |
| from rdkit import Chem, DataStructs | |
| from rdkit.Chem import Descriptors | |
| from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator | |
| except Exception: # pragma: no cover - optional dependency | |
| Chem = None # type: ignore | |
| DataStructs = None # type: ignore | |
| GetMorganGenerator = None # type: ignore | |
| Descriptors = None # type: ignore | |
| from chemistry.smiles_recovery import recover_invalid_smiles, validate_smiles, write_smiles_recovery_report | |
| from training.embeddings import EmbeddingService, init_embedding_service | |
| from training.graph_representations import build_drug_graph_bundle, load_drugbank_metadata | |
| from training.molecular_sanitization import ( | |
| InvalidMoleculeTracker, | |
| build_graph_health_metrics, | |
| validate_graph_object, | |
| write_graph_quality_report, | |
| ) | |
# Module-level logger shared by every helper and class in this module.
logger = logging.getLogger("medcare_ddi.advanced_features")
# Two directory levels above the directory that contains this file.
BASE_DIR = Path(__file__).resolve().parents[2]
MODELS_DIR = BASE_DIR / "models"
# Import-time side effect: the models directory is created if it is missing.
MODELS_DIR.mkdir(parents=True, exist_ok=True)
# Default value of AdvancedFeatureConfig.invalid_rate_threshold.
MAX_INVALID_RATE_DEFAULT = 0.15
| def _stable_hash(value: str, modulo: int = 2**31 - 1) -> int: | |
| digest = hashlib.sha256(str(value).encode("utf-8")).hexdigest() | |
| return int(digest[:16], 16) % modulo | |
| def _normalize(value: Any) -> str: | |
| return " ".join(str(value or "").strip().lower().split()) | |
def _safe_mol(smiles: str):
    """Return the ``"mol"`` entry from ``validate_smiles`` or ``None``.

    ``None`` is returned whenever the validation result's ``"valid"`` entry
    is missing or falsy.
    """
    result = validate_smiles(smiles)
    if not result.get("valid"):
        return None
    return result.get("mol")
def _hashed_vector(tokens: Sequence[str], dim: int) -> np.ndarray:
    """Hash ``tokens`` into a ``dim``-sized float32 bag-of-buckets vector.

    Each non-empty token increments the bucket chosen by ``_stable_hash``.
    A non-zero vector is then divided by ``max(1, L2 norm)``.
    """
    buckets = np.zeros(dim, dtype=np.float32)
    for tok in filter(None, tokens):  # falsy tokens are skipped
        buckets[_stable_hash(tok, dim)] += 1.0
    if buckets.sum() > 0:
        norm = float(np.linalg.norm(buckets))
        buckets /= max(1.0, norm)
    return buckets
| def _pair_similarity_features(fp_a: np.ndarray, fp_b: np.ndarray) -> np.ndarray: | |
| both_valid = float(fp_a.sum() > 0 and fp_b.sum() > 0) | |
| any_invalid = 1.0 - both_valid | |
| if not both_valid: | |
| return np.array([0.0, 0.0, 0.0, both_valid, any_invalid], dtype=np.float32) | |
| intersection = float(np.minimum(fp_a, fp_b).sum()) | |
| union = float(np.maximum(fp_a, fp_b).sum()) + 1e-8 | |
| tanimoto = intersection / union | |
| dice = 2.0 * intersection / (fp_a.sum() + fp_b.sum() + 1e-8) | |
| cosine = float(np.dot(fp_a, fp_b) / (np.linalg.norm(fp_a) * np.linalg.norm(fp_b) + 1e-8)) | |
| return np.array([tanimoto, dice, cosine, both_valid, any_invalid], dtype=np.float32) | |
def _unknown_vector(dim: int, namespace: str) -> np.ndarray:
    """Return the float32 placeholder vector for an unknown entity.

    The vector is the hashed encoding of the single token
    ``UNKNOWN_<namespace>``. If that encoding is all zeros, element 0 is set
    to 1 so the placeholder is never the zero vector.
    """
    placeholder = _hashed_vector([f"UNKNOWN_{namespace}"], dim)
    if not placeholder.any():
        placeholder[0] = 1.0
    return placeholder.astype(np.float32)
def _morgan_features(smiles: str, radius: int = 2, n_bits: int = 2048) -> np.ndarray:
    """Return an ``n_bits``-long float32 Morgan fingerprint for ``smiles``.

    Fallbacks:
    - RDKit not importable: a hashed vector of the raw SMILES string.
    - SMILES fails ``validate_smiles``: the ``UNKNOWN_DRUG`` placeholder.
    """
    rdkit_ready = GetMorganGenerator is not None and DataStructs is not None
    if not rdkit_ready:
        return _hashed_vector([smiles], n_bits)

    check = validate_smiles(smiles)
    mol = check.get("mol")
    if not check.get("valid") or mol is None:
        return _unknown_vector(n_bits, "DRUG")

    fingerprint = GetMorganGenerator(radius=radius, fpSize=n_bits).GetFingerprint(mol)
    bits = np.zeros((n_bits,), dtype=np.int8)
    DataStructs.ConvertToNumpyArray(fingerprint, bits)  # fills ``bits`` in place
    return bits.astype(np.float32)
def _descriptor_features(smiles: str) -> np.ndarray:
    """Return 12 RDKit descriptors for ``smiles`` as a float32 vector.

    Order: MolWt, MolLogP, TPSA, NumHDonors, NumHAcceptors, NumRotatableBonds,
    RingCount, atom count, heavy-atom count, bond count, FractionCSP3,
    HeavyAtomMolWt. A zero vector is returned when RDKit is unavailable or the
    SMILES does not validate.
    """
    n_descriptors = 12
    if Chem is None or Descriptors is None:
        return np.zeros(n_descriptors, dtype=np.float32)

    check = validate_smiles(smiles)
    mol = check.get("mol")
    if not check.get("valid") or mol is None:
        return np.zeros(n_descriptors, dtype=np.float32)

    raw_values = [
        Descriptors.MolWt(mol),
        Descriptors.MolLogP(mol),
        Descriptors.TPSA(mol),
        Descriptors.NumHDonors(mol),
        Descriptors.NumHAcceptors(mol),
        Descriptors.NumRotatableBonds(mol),
        Descriptors.RingCount(mol),
        mol.GetNumAtoms(),
        mol.GetNumHeavyAtoms(),
        mol.GetNumBonds(),
        Descriptors.FractionCSP3(mol),
        Descriptors.HeavyAtomMolWt(mol),
    ]
    return np.array([float(item) for item in raw_values], dtype=np.float32)
def _pairwise_molecular_features(smiles_a: str, smiles_b: str) -> np.ndarray:
    """Concatenate similarity and descriptor features for a SMILES pair.

    Layout: 5 similarity values, descriptors of A, descriptors of B, then the
    absolute descriptor difference, all float32. Fingerprints use the default
    radius and bit count of ``_morgan_features``.
    """
    # NOTE(review): with RDKit installed an invalid SMILES yields the non-zero
    # UNKNOWN_DRUG placeholder, so the both_valid / any_invalid flags computed
    # below may report "valid" for it - confirm this is intended.
    fingerprints = [_morgan_features(s) for s in (smiles_a, smiles_b)]
    similarity = _pair_similarity_features(fingerprints[0], fingerprints[1])
    descriptors_a = _descriptor_features(smiles_a)
    descriptors_b = _descriptor_features(smiles_b)
    gap = np.abs(descriptors_a - descriptors_b)
    stacked = np.concatenate([similarity, descriptors_a, descriptors_b, gap], axis=0)
    return stacked.astype(np.float32)
@dataclass
class AdvancedFeatureConfig:
    """Tunable settings for ``AdvancedBiomedicalFeatureEngineer``.

    Declared as a dataclass so individual fields can be overridden through
    keyword arguments (``AdvancedFeatureConfig(semantic_dim=256)``); calling
    it with no arguments still yields the defaults below.
    """

    # Bit length of each per-drug Morgan fingerprint.
    fingerprint_dim: int = 2048
    # Width of hashed / zero semantic vectors produced by the fallback paths.
    semantic_dim: int = 768
    # Width of each per-drug hashed pharmacology vector.
    pharmacology_dim: int = 128
    # Width of the hashed pairwise-token vector.
    pair_dim: int = 64
    # Cache directory; a relative path is resolved against BASE_DIR.
    cache_dir: str = "models/feature_cache"
    # When False, semantic vectors are deterministic hashed vectors.
    use_transformer_embeddings: bool = False
    bio_text_model: str = "pubmedbert"
    # Model name handed to the embedding service for pair semantics.
    bio_semantic_model: str = "pubmedbert"
    smiles_model: str = "seyonec/ChemBERTa-zinc-base-v1"
    # Quality gate: maximum tolerated unrecoverable-molecule rate.
    invalid_rate_threshold: float = MAX_INVALID_RATE_DEFAULT
| class BiomedicalFeatureCache: | |
| def __init__(self, cache_dir: Path): | |
| self.cache_dir = Path(cache_dir) | |
| self.cache_dir.mkdir(parents=True, exist_ok=True) | |
| def _path(self, key: str) -> Path: | |
| digest = hashlib.sha256(key.encode("utf-8")).hexdigest()[:40] | |
| return self.cache_dir / f"{digest}.joblib" | |
| def get(self, key: str) -> Any | None: | |
| path = self._path(key) | |
| return joblib.load(path) if path.exists() else None | |
| def put(self, key: str, value: Any) -> None: | |
| joblib.dump(value, self._path(key)) | |
| def path_for_key(self, key: str) -> Path: | |
| return self._path(key) | |
class AdvancedBiomedicalFeatureEngineer:
    """Build advanced biomedical features for DDI pairs.

    Optional metadata maps can include:
    - smiles_map: drug -> SMILES
    - atc_map: drug -> ATC code or ATC string
    - target_map: drug -> iterable of target proteins
    - cyp_map: drug -> iterable of CYP enzymes / inhibitors
    - moa_map: drug -> mechanism-of-action string
    - transporter_map: drug -> iterable of transporters
    - description_map: drug -> free-text description
    - active_ingredient_map: drug -> active ingredient text
    """

    # Feature blocks whose length must be identical for every kept row.
    _DIM_KEYS = (
        "fingerprint",
        "semantic",
        "pharmacology",
        "pairwise",
        "molecular_pair",
        "graph_summary",
    )

    def __init__(self, config: AdvancedFeatureConfig | None = None, metadata: Optional[Mapping[str, Mapping[str, Any]]] = None):
        """Set up caches and the embedding service.

        Args:
            config: Feature settings; defaults to ``AdvancedFeatureConfig()``.
            metadata: Optional ``drug -> attribute mapping``. When empty, the
                DrugBank metadata loader is used lazily instead.
        """
        self.config = config or AdvancedFeatureConfig()
        self.metadata = metadata or {}
        cache_root = Path(self.config.cache_dir)
        if not cache_root.is_absolute():
            cache_root = BASE_DIR / cache_root
        self.cache = BiomedicalFeatureCache(cache_root)
        self.embedding_service: EmbeddingService = init_embedding_service(cache_dir=str(cache_root / "embeddings"))
        self._graph_metadata_cache: Optional[dict[str, dict[str, Any]]] = None
        self.invalid_tracker = InvalidMoleculeTracker()
        self._smiles_recovery_cache: dict[str, dict[str, Any]] = {}

    def _drug_meta(self, drug: str) -> Mapping[str, Any]:
        """Return the metadata mapping for ``drug`` (empty mapping if unknown).

        The caller-supplied map is consulted first under the normalized name;
        the normalized graph metadata map is the fallback.
        """
        key = _normalize(drug)
        meta = self.metadata.get(key)
        if meta:
            return meta
        return self._graph_metadata().get(key, {})

    def _smiles(self, drug: str) -> str:
        """Return the stripped SMILES string stored for ``drug`` or ``""``."""
        meta = self._drug_meta(drug)
        # Never lowercase SMILES; preserve exact chemistry notation.
        return str(meta.get("smiles") or meta.get("SMILES") or "").strip()

    def _recover_smiles(self, drug: str) -> dict[str, Any]:
        """Run SMILES recovery for ``drug`` once and memoize by normalized name."""
        key = _normalize(drug)
        cached = self._smiles_recovery_cache.get(key)
        if cached is not None:
            return cached
        meta = self._drug_meta(drug)
        raw_smiles = str(meta.get("smiles") or meta.get("SMILES") or meta.get("smiles_raw") or "").strip()
        recovered = recover_invalid_smiles(drug, raw_smiles)
        self._smiles_recovery_cache[key] = recovered
        return recovered

    def _pair_metadata(self, drug_a: str, drug_b: str) -> dict[str, dict[str, Any]]:
        """Return a copy of the graph metadata with recovered SMILES for the pair.

        For each of the two drugs, a successfully recovered canonical SMILES
        replaces the ``"smiles"`` entry in that drug's copied metadata.
        """
        # NOTE(review): this copies the full metadata map for every pair;
        # confirm whether build_drug_graph_bundle needs more than two entries.
        metadata = {key: dict(value) for key, value in self._graph_metadata().items()}
        for drug in (drug_a, drug_b):
            key = _normalize(drug)
            meta = dict(self._drug_meta(drug))
            recovered = self._recover_smiles(drug)
            if recovered.get("valid") and recovered.get("canonical_smiles"):
                meta["smiles"] = str(recovered["canonical_smiles"])
            metadata[key] = meta
        return metadata

    def _graph_metadata(self) -> dict[str, dict[str, Any]]:
        """Return (and memoize) the normalized-key metadata map.

        Uses the caller-supplied metadata when present; otherwise loads
        DrugBank metadata, falling back to an empty map if loading fails.
        """
        if self._graph_metadata_cache is not None:
            return self._graph_metadata_cache
        if self.metadata:
            self._graph_metadata_cache = {_normalize(key): dict(value) for key, value in self.metadata.items()}
            return self._graph_metadata_cache
        try:
            self._graph_metadata_cache = load_drugbank_metadata()
        except Exception as exc:
            logger.warning("Graph metadata fallback to empty map: %s", exc)
            self._graph_metadata_cache = {}
        return self._graph_metadata_cache

    def _text(self, drug: str, keys: Sequence[str]) -> str:
        """Join the normalized metadata values for ``keys`` into one string."""
        meta = self._drug_meta(drug)
        parts: List[str] = []
        for key in keys:
            value = meta.get(key, "")
            if isinstance(value, (list, tuple, set)):
                parts.extend([_normalize(item) for item in value])
            else:
                parts.append(_normalize(value))
        return " ".join(part for part in parts if part)

    def _hashed_semantic_matrix(self, cleaned: Sequence[str]) -> np.ndarray:
        """Return one hashed ``semantic_dim`` row per text, as float32."""
        rows = [_hashed_vector([text], self.config.semantic_dim) for text in cleaned]
        return np.vstack(rows).astype(np.float32)

    def _semantic_vector(self, texts: List[str], model_name: str) -> np.ndarray:
        """Return one embedding row per text, using the on-disk cache.

        The cache key includes the backend (hashed vs transformer) and
        ``semantic_dim`` so that changing the configuration cannot return
        vectors of the wrong kind or size. A hashed fallback produced because
        the embedding service raised is returned but NOT cached, so a
        transient failure does not permanently replace real embeddings.
        """
        cleaned = [_normalize(text) for text in texts]
        use_transformer = bool(self.config.use_transformer_embeddings)
        key = json.dumps(
            {
                "backend": "transformer" if use_transformer else "hashed",
                "dim": int(self.config.semantic_dim),
                "model": model_name,
                "texts": cleaned,
            },
            sort_keys=True,
        )
        cached = self.cache.get(key)
        if cached is not None:
            return cached

        cacheable = True
        if not use_transformer:
            emb = self._hashed_semantic_matrix(cleaned)
        elif any(cleaned):
            try:
                emb = self.embedding_service.get_text_embeddings(cleaned, model_name=model_name, batch_size=8)
            except Exception as exc:
                logger.warning("Embedding fallback used for %s: %s", model_name, exc)
                emb = self._hashed_semantic_matrix(cleaned)
                cacheable = False
        else:
            # Every text is empty: nothing to embed.
            emb = np.zeros((len(cleaned), self.config.semantic_dim), dtype=np.float32)

        if cacheable:
            self.cache.put(key, emb)
        return emb

    @staticmethod
    def _value_tokens(value: Any) -> List[str]:
        """Tokenize one metadata value; ``None`` contributes no tokens."""
        if value is None:
            return []
        if isinstance(value, str):
            return _tokenize_string(value)
        items = value if isinstance(value, (list, tuple, set)) else [value]
        tokens: List[str] = []
        for item in items:
            if item is not None:  # str(None) would inject a bogus "none" token
                tokens.extend(_tokenize_string(str(item)))
        return tokens

    def _pharmacology_vector(self, drug: str) -> np.ndarray:
        """Hash the drug's pharmacology metadata into ``pharmacology_dim`` buckets."""
        meta = self._drug_meta(drug)
        tokens: List[str] = []
        for key in ("atc", "atc_code", "targets", "cyp", "enzymes", "transporters", "mechanism", "moa"):
            tokens.extend(self._value_tokens(meta.get(key, [])))
        for key in ("active_ingredient", "description"):
            if meta.get(key):
                tokens.extend(_tokenize_string(str(meta[key])))
        return _hashed_vector(tokens, self.config.pharmacology_dim)

    def _pair_pharmacology_vector(self, drug_a: str, drug_b: str) -> np.ndarray:
        """Concatenate ``[vec_a, vec_b, elementwise min, absolute difference]``."""
        vec_a = self._pharmacology_vector(drug_a)
        vec_b = self._pharmacology_vector(drug_b)
        shared = np.minimum(vec_a, vec_b)
        delta = np.abs(vec_a - vec_b)
        return np.concatenate([vec_a, vec_b, shared, delta], axis=0).astype(np.float32)

    def _pair_semantic_vector(self, drug_a: str, drug_b: str) -> np.ndarray:
        """Concatenate ``[emb_a, emb_b, |emb_a - emb_b|, emb_a * emb_b]``.

        Each drug's text is its normalized name followed by its description,
        active ingredient and mechanism-of-action metadata.
        """
        text_a = " ".join(
            part for part in [
                _normalize(drug_a),
                self._text(drug_a, ["description", "active_ingredient", "moa"]),
            ] if part
        )
        text_b = " ".join(
            part for part in [
                _normalize(drug_b),
                self._text(drug_b, ["description", "active_ingredient", "moa"]),
            ] if part
        )
        emb = self._semantic_vector([text_a, text_b], self.config.bio_semantic_model)
        return np.concatenate([emb[0], emb[1], np.abs(emb[0] - emb[1]), emb[0] * emb[1]], axis=0).astype(np.float32)

    def pair_features(self, drug_a: str, drug_b: str) -> Dict[str, np.ndarray]:
        """Return every feature block for the pair plus their ``"fused"`` concat.

        Keys: ``fingerprint``, ``semantic``, ``pharmacology``, ``pairwise``,
        ``molecular_pair`` and ``fused`` (the five blocks in that order).
        """
        recovered_a = self._recover_smiles(drug_a)
        recovered_b = self._recover_smiles(drug_b)
        smiles_a = str(recovered_a.get("canonical_smiles") or "")
        smiles_b = str(recovered_b.get("canonical_smiles") or "")
        pair_mol = _pairwise_molecular_features(smiles_a, smiles_b)
        fp_a = _morgan_features(smiles_a, radius=2, n_bits=self.config.fingerprint_dim)
        fp_b = _morgan_features(smiles_b, radius=2, n_bits=self.config.fingerprint_dim)
        pair_fp = np.concatenate([fp_a, fp_b, np.abs(fp_a - fp_b), fp_a * fp_b], axis=0).astype(np.float32)
        pharma = self._pair_pharmacology_vector(drug_a, drug_b)
        semantic = self._pair_semantic_vector(drug_a, drug_b)
        pair_tokens = [
            _normalize(drug_a),
            _normalize(drug_b),
            self._text(drug_a, ["targets", "cyp", "transporters", "moa"]),
            self._text(drug_b, ["targets", "cyp", "transporters", "moa"]),
        ]
        pairwise = _hashed_vector([token for token in pair_tokens if token], self.config.pair_dim)
        return {
            "fingerprint": pair_fp,
            "semantic": semantic,
            "pharmacology": pharma,
            "pairwise": pairwise,
            "molecular_pair": pair_mol,
            "fused": np.concatenate([pair_fp, semantic, pharma, pairwise, pair_mol], axis=0).astype(np.float32),
        }

    def pair_graph_bundle(self, drug_a: str, drug_b: str) -> Dict[str, Any]:
        """Return graph inputs for the pair, using DrugBank-backed metadata when available."""
        metadata = self._pair_metadata(drug_a, drug_b)
        bundle = build_drug_graph_bundle(drug_a, drug_b, metadata=metadata)
        for graph_key in ("drug_a_graph", "drug_b_graph", "pharmacology_graph", "interaction_graph"):
            graph = bundle.get(graph_key)
            if graph is None:
                continue
            errors = validate_graph_object(graph)
            if errors:
                bundle.setdefault("graph_validation_errors", {})[graph_key] = errors
        return bundle

    @staticmethod
    def _summary_from_bundle(bundle: Mapping[str, Any]) -> np.ndarray:
        """Convert a bundle's ``"interaction_summary"`` to a float32 array."""
        summary = bundle["interaction_summary"]
        if hasattr(summary, "detach"):  # tensor-like object exposing detach()
            return summary.detach().cpu().numpy().astype(np.float32)
        return np.asarray(summary, dtype=np.float32)

    def graph_summary(self, drug_a: str, drug_b: str) -> np.ndarray:
        """Return a compact dense summary of the graph bundle for backward-compatible models."""
        return self._summary_from_bundle(self.pair_graph_bundle(drug_a, drug_b))

    def batch_features(self, drug_pairs: Iterable[Tuple[str, str]]) -> Dict[str, np.ndarray]:
        """Stack ``pair_features`` row-wise per key; empty input gives ``{}``."""
        rows = [self.pair_features(a, b) for a, b in drug_pairs]
        keys = rows[0].keys() if rows else []
        return {key: np.vstack([row[key] for row in rows]).astype(np.float32) for key in keys}

    @classmethod
    def _record_dims(cls, record: Mapping[str, Any]) -> dict[str, int]:
        """Return the leading dimension of every gated feature block."""
        return {key: int(record[key].shape[0]) for key in cls._DIM_KEYS}

    @staticmethod
    def _load_cached_row(cache_path: Path) -> Optional[dict]:
        """Load a cached preprocessing row; unreadable or malformed means miss."""
        if not cache_path.exists():
            return None
        try:
            payload = joblib.load(cache_path)
        except Exception as exc:
            logger.warning("Ignoring unreadable pair cache %s: %s", cache_path, exc)
            return None
        if not isinstance(payload, dict) or "graph_bundle" not in payload:
            return None
        if not payload.get("is_quarantined", False) and "record" not in payload:
            return None
        return payload

    def preprocess_pairs_with_quality_gates(
        self,
        df,
        *,
        drug_a_col: str = "Drug_A",
        drug_b_col: str = "Drug_B",
        label_col: str = "Level",
        output_dir: Optional[Path] = None,
        invalid_rate_threshold: Optional[float] = None,
    ):
        """Preprocess and filter invalid chemistry with deterministic caching.

        Args:
            df: DataFrame holding the drug-pair and label columns.
            drug_a_col: Column with the first drug name.
            drug_b_col: Column with the second drug name.
            label_col: Column with the interaction label.
            output_dir: Report directory (path or string); defaults to
                ``MODELS_DIR / "reports" / "chemistry"``.
            invalid_rate_threshold: Overrides ``config.invalid_rate_threshold``.

        Returns:
            ``(filtered_df, metrics)``: the kept rows with feature columns
            and a metrics dictionary including report paths.

        Raises:
            ValueError: if feature dimensions are inconsistent, the
                unrecoverable-molecule rate exceeds the threshold, or graph
                validation errors were recorded. Reports are written before
                the rate and graph gates are evaluated.
        """
        import pandas as pd  # local import: pandas is not imported at module level

        output_dir = Path(output_dir) if output_dir else MODELS_DIR / "reports" / "chemistry"
        output_dir.mkdir(parents=True, exist_ok=True)
        if invalid_rate_threshold is None:
            invalid_rate_threshold = self.config.invalid_rate_threshold
        invalid_rate_threshold = float(invalid_rate_threshold)
        # Directory kept for parity with earlier runs; row entries themselves
        # are stored through self.cache.path_for_key.
        (self.cache.cache_dir / "preprocessed_pairs").mkdir(parents=True, exist_ok=True)

        total_rows = int(len(df))
        kept_records: list[dict[str, Any]] = []
        graph_bundles: list[dict[str, Any]] = []
        recovery_audit_records: list[dict[str, Any]] = []
        expected_dims: dict[str, int] | None = None

        # NOTE(review): self.invalid_tracker lives on the instance, so entries
        # presumably accumulate across repeated calls - confirm this is wanted.
        def _log_dropped_row(row_index: int, drug_name: str, raw_smiles: str, reason: str) -> None:
            if not reason:
                reason = "unknown"
            self.invalid_tracker.add(row_index, raw_smiles, reason, drug_name=drug_name)

        def _append_recovery_audit(row_index: int, drug_name: str, recovery: dict[str, Any]) -> None:
            validation_status = "valid" if recovery.get("valid") else "invalid"
            recovery_audit_records.append(
                {
                    "row_index": int(row_index),
                    "drug_name": str(drug_name),
                    "original_smiles": str(recovery.get("original_smiles") or ""),
                    "repaired_smiles": recovery.get("canonical_smiles"),
                    "recovery_method": str(recovery.get("recovery_method") or "failed_recovery"),
                    "validation_status": validation_status,
                    "failure_reason": None if validation_status == "valid" else str(recovery.get("error") or "recovery_failed"),
                }
            )

        for row_idx, row in df.reset_index(drop=False).iterrows():
            raw_index = int(row.get("index", row_idx))
            drug_a = str(row[drug_a_col])
            drug_b = str(row[drug_b_col])
            recovery_a = self._recover_smiles(drug_a)
            recovery_b = self._recover_smiles(drug_b)
            _append_recovery_audit(raw_index, drug_a, recovery_a)
            _append_recovery_audit(raw_index, drug_b, recovery_b)

            # Quarantine reasons are attributed to drug A when its recovery
            # failed, otherwise to drug B (same rule for cached and fresh rows).
            a_failed = not recovery_a.get("valid")
            blamed_drug = drug_a if a_failed else drug_b
            blamed_raw_key = "smiles_a_raw" if a_failed else "smiles_b_raw"

            cache_key = json.dumps(
                {
                    "v": 1,
                    "a": _normalize(drug_a),
                    "b": _normalize(drug_b),
                    "idx": raw_index,
                },
                sort_keys=True,
            )
            cache_path = self.cache.path_for_key(cache_key)
            cached_row = self._load_cached_row(cache_path)

            if cached_row is not None:
                bundle = cached_row["graph_bundle"]
                if cached_row.get("is_quarantined", False):
                    graph_bundles.append(bundle)
                    for reason in bundle.get("quarantine_reasons") or []:
                        _log_dropped_row(raw_index, blamed_drug, str(bundle.get(blamed_raw_key, "")), str(reason))
                    continue
                record = cached_row["record"]
                try:
                    cached_dims = self._record_dims(record)
                except (KeyError, AttributeError, IndexError, TypeError):
                    cached_dims = None
                if cached_dims is not None and (expected_dims is None or cached_dims == expected_dims):
                    expected_dims = cached_dims
                    # The cache key does not cover the label, so refresh the
                    # row-derived fields from the current dataframe.
                    record = {
                        **record,
                        "row_index": raw_index,
                        drug_a_col: drug_a,
                        drug_b_col: drug_b,
                        label_col: row[label_col],
                    }
                    graph_bundles.append(bundle)
                    kept_records.append(record)
                    continue
                # Cached features do not match the current dimensions (e.g. the
                # config changed): fall through and recompute this row.
                logger.warning("Stale cached features for row %s; recomputing", raw_index)

            bundle = self.pair_graph_bundle(drug_a, drug_b)
            graph_bundles.append(bundle)
            smiles_a_val = bundle.get("smiles_a_validation", {})
            smiles_b_val = bundle.get("smiles_b_validation", {})
            if not bool(smiles_a_val.get("valid", False)):
                _log_dropped_row(raw_index, drug_a, str(bundle.get("smiles_a_raw", "")), str(smiles_a_val.get("error", smiles_a_val.get("reason", "unknown"))))
            if not bool(smiles_b_val.get("valid", False)):
                _log_dropped_row(raw_index, drug_b, str(bundle.get("smiles_b_raw", "")), str(smiles_b_val.get("error", smiles_b_val.get("reason", "unknown"))))

            has_graph_errors = bool(bundle.get("graph_validation_errors"))
            is_quarantined = bool(bundle.get("quarantined", False)) or has_graph_errors
            if is_quarantined:
                for reason in bundle.get("quarantine_reasons", []) or []:
                    _log_dropped_row(raw_index, blamed_drug, str(bundle.get(blamed_raw_key, "")), str(reason))
                joblib.dump(
                    {
                        "is_quarantined": True,
                        "graph_bundle": bundle,
                    },
                    cache_path,
                )
                continue

            features = self.pair_features(drug_a, drug_b)
            # Reuse the bundle built above instead of rebuilding it through
            # graph_summary(), which would construct the graphs a second time.
            graph_summary = self._summary_from_bundle(bundle)
            record = {
                "row_index": raw_index,
                drug_a_col: drug_a,
                drug_b_col: drug_b,
                label_col: row[label_col],
                "fingerprint": features["fingerprint"],
                "semantic": features["semantic"],
                "pharmacology": features["pharmacology"],
                "pairwise": features["pairwise"],
                "molecular_pair": features["molecular_pair"],
                "fused": features["fused"],
                "graph_bundle": bundle,
                "graph_summary": graph_summary,
            }
            current_dims = self._record_dims(record)
            if expected_dims is None:
                expected_dims = current_dims
            elif current_dims != expected_dims:
                _log_dropped_row(raw_index, drug_a, str(bundle.get("smiles_a_raw", "")), "feature_dimension_mismatch")
                raise ValueError(
                    f"Quality gate failed: feature dimensions inconsistent for row {raw_index}. expected={expected_dims} got={current_dims}"
                )
            joblib.dump(
                {
                    "is_quarantined": False,
                    "graph_bundle": bundle,
                    "record": record,
                },
                cache_path,
            )
            kept_records.append(record)

        filtered_df = pd.DataFrame(kept_records)
        kept_rows = int(len(filtered_df))
        invalid_report = output_dir / "invalid_smiles_report.json"
        summary_report = output_dir / "filtered_dataset_summary.json"
        self.invalid_tracker.write_reports(invalid_report, summary_report, total_rows=total_rows, kept_rows=kept_rows)
        # The recovery report is written next to (not inside) output_dir.
        recovery_report_path = output_dir.parent / "smiles_recovery_report.json"
        recovery_report = write_smiles_recovery_report(recovery_audit_records, recovery_report_path)
        health_metrics = build_graph_health_metrics(graph_bundles)
        graph_report = output_dir / "graph_quality_report.md"
        write_graph_quality_report(graph_report, health_metrics)
        sanitized_dataset_path = output_dir / "sanitized_graph_dataset.joblib"
        joblib.dump(filtered_df, sanitized_dataset_path)

        # NOTE(review): two audit records are appended per row, yet this rate
        # divides by the row count while initial_invalid_rate divides by the
        # record count - confirm which denominator "failed_recovery" expects.
        final_invalid_rate = float(recovery_report["summary"]["failed_recovery"] / max(1, total_rows))
        if final_invalid_rate > invalid_rate_threshold:
            raise ValueError(
                f"Quality gate failed: unrecoverable molecule rate {final_invalid_rate:.4f} exceeds threshold {invalid_rate_threshold:.4f}"
            )
        if health_metrics.get("validation_error_counts"):
            raise ValueError(
                "Quality gate failed: graph validation errors detected; see graph_quality_report.md"
            )

        invalid_audits = sum(1 for entry in recovery_audit_records if entry["validation_status"] != "valid")
        metrics = {
            "total_rows": total_rows,
            "kept_rows": kept_rows,
            "removed_rows": int(total_rows - kept_rows),
            "invalid_rate": final_invalid_rate,
            "initial_invalid_rate": float(invalid_audits / max(1, len(recovery_audit_records))),
            "preprocessing_statistics": recovery_report["summary"],
            "reports": {
                "invalid_smiles_report": str(invalid_report),
                "filtered_dataset_summary": str(summary_report),
                "graph_quality_report": str(graph_report),
                "sanitized_graph_dataset": str(sanitized_dataset_path),
                "smiles_recovery_report": str(recovery_report_path),
            },
            "graph_health": health_metrics,
        }
        return filtered_df, metrics
def _tokenize_string(value: str) -> List[str]:
    """Split ``value`` into normalized tokens.

    The text is normalized first; ``/``, ``;`` and ``,`` then act as extra
    separators alongside whitespace. Empty input yields an empty list.
    """
    normalized = _normalize(value)
    if not normalized:
        return []
    for separator in ("/", ";", ","):
        normalized = normalized.replace(separator, " ")
    # str.split() without arguments never produces empty tokens.
    return normalized.split()
| def load_metadata_map(json_path: str | Path | None = None) -> Dict[str, Dict[str, Any]]: | |
| if not json_path: | |
| return {} | |
| path = Path(json_path) | |
| if not path.exists(): | |
| raise FileNotFoundError(f"Metadata map not found at {path}") | |
| data = json.loads(path.read_text(encoding="utf-8")) | |
| return {_normalize(key): value for key, value in data.items()} | |
def build_feature_cache_path(name: str) -> Path:
    """Return the path of ``name`` inside ``MODELS_DIR / "feature_cache"``."""
    return MODELS_DIR.joinpath("feature_cache", name)