Spaces:
Running
Running
| """Graph-based biomedical representation utilities for MEDCARE-DDI. | |
| This module provides a CPU-friendly graph stack that can operate without | |
| PyTorch Geometric while preserving the same conceptual architecture: | |
| - molecular graphs from RDKit SMILES | |
| - relational pharmacology graphs from DrugBank metadata | |
| - interaction neighborhood graphs for DDI topology learning | |
| - graph encoders with residual message passing, layer normalization, | |
| dropout, and Jumping Knowledge aggregation | |
| The implementation is intentionally deterministic and cache-friendly so it can | |
| be used in both offline training and online inference paths. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| import hashlib | |
| import json | |
| import logging | |
| from pathlib import Path | |
| from typing import Any, Iterable, Mapping | |
| import joblib | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| try: # pragma: no cover - optional dependency | |
| from rdkit import Chem | |
| except Exception: # pragma: no cover | |
| Chem = None # type: ignore | |
| from .canonical_drug_mapper import CanonicalDrugMapper, _iter_drugbank_blocks, _normalize_text, _parse_drugbank_block, _compact_key | |
| from .molecular_sanitization import SanitizedMolecule, sanitize_smiles | |
logger = logging.getLogger("medcare_ddi.graph")
# Project base directory: the resolved location of this file, three levels up.
BASE_DIR = Path(__file__).resolve().parent.parent.parent
GRAPH_CACHE_DIR = BASE_DIR.joinpath("models", "feature_cache", "graph")
# Side effect at import time: the graph cache directory is created if missing.
GRAPH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
| def _stable_hash(value: str, modulo: int) -> int: | |
| digest = hashlib.sha1(str(value).encode("utf-8")).hexdigest() | |
| return int(digest[:12], 16) % modulo | |
| def _one_hot(index: int, size: int) -> np.ndarray: | |
| vector = np.zeros(size, dtype=np.float32) | |
| if 0 <= index < size: | |
| vector[index] = 1.0 | |
| return vector | |
# Atomic number -> slot in the element one-hot built by ``_atom_features``
# (H, B, C, N, O, F, P, S, Cl, Br, I).
ATOM_NUMBERS = {
    atomic_number: slot
    for slot, atomic_number in enumerate((1, 5, 6, 7, 8, 9, 15, 16, 17, 35, 53))
}
# Lower-cased RDKit hybridization names recognised by ``_atom_features``.
HYBRIDIZATION_NAMES = ["sp", "sp2", "sp3", "sp3d", "sp3d2"]
# RDKit bond-type names one-hot encoded by ``_bond_features``.
BOND_TYPES = "SINGLE DOUBLE TRIPLE AROMATIC".split()
# Node width: the atom encoding fills 11 + 8 + 5 = 24 slots, zero-padded to 32.
NODE_FEATURE_DIM = 32
# Edge width: 5 bond scalars followed by the 4-slot bond-type one-hot.
EDGE_FEATURE_DIM = 9
@dataclass(eq=False)
class GraphSample:
    """Compact graph container that is easy to serialize and batch.

    ``@dataclass`` generates the keyword constructor every builder in this
    module relies on (``GraphSample(node_features=..., ...)``); without it
    those calls raise ``TypeError``. ``eq=False`` keeps identity-based ``==``
    and hashing: a generated ``__eq__`` would compare the tensor fields
    element-wise and fail for multi-element tensors.
    """

    # Node feature matrix; row i describes node i.
    node_features: torch.Tensor
    # Directed edges; row 0 holds source node indices, row 1 destinations.
    edge_index: torch.Tensor
    # One feature row per column of ``edge_index``.
    edge_features: torch.Tensor
    # Fixed-length graph-level summary vector.
    graph_features: torch.Tensor
    # Optional integer labels for node / edge kinds.
    node_types: torch.Tensor | None = None
    edge_types: torch.Tensor | None = None
    # Validity flag (e.g. False for the invalid-SMILES placeholder graph).
    valid: bool = True

    def to_dict(self) -> dict[str, Any]:
        """Return a plain dict of NumPy arrays (detached and moved to CPU)."""
        return {
            "node_features": self.node_features.detach().cpu().numpy(),
            "edge_index": self.edge_index.detach().cpu().numpy(),
            "edge_features": self.edge_features.detach().cpu().numpy(),
            "graph_features": self.graph_features.detach().cpu().numpy(),
            "node_types": None if self.node_types is None else self.node_types.detach().cpu().numpy(),
            "edge_types": None if self.edge_types is None else self.edge_types.detach().cpu().numpy(),
            "valid": self.valid,
        }

    @classmethod
    def from_dict(cls, payload: Mapping[str, Any]) -> "GraphSample":
        """Rebuild a sample from ``to_dict`` output (or equivalent array-likes).

        Feature arrays become float32 tensors; ``edge_index`` and the optional
        type arrays become int64 (long) tensors. Missing ``node_types`` /
        ``edge_types`` stay ``None`` and ``valid`` defaults to ``True``.
        """

        def _tensor(value: Any) -> torch.Tensor:
            return torch.as_tensor(np.asarray(value), dtype=torch.float32)

        def _long_or_none(value: Any) -> torch.Tensor | None:
            return None if value is None else torch.as_tensor(np.asarray(value), dtype=torch.long)

        return cls(
            node_features=_tensor(payload["node_features"]),
            edge_index=torch.as_tensor(np.asarray(payload["edge_index"]), dtype=torch.long),
            edge_features=_tensor(payload["edge_features"]),
            graph_features=_tensor(payload["graph_features"]),
            node_types=_long_or_none(payload.get("node_types")),
            edge_types=_long_or_none(payload.get("edge_types")),
            valid=bool(payload.get("valid", True)),
        )
def _safe_mol(smiles: str):
    """Parse ``smiles`` with RDKit without ever raising.

    Returns ``None`` when RDKit is not importable or ``Chem.MolFromSmiles``
    raises; otherwise RDKit's own result is passed through unchanged. Falsy
    input is parsed as the empty string.
    """
    if Chem is None:
        return None
    try:
        parsed = Chem.MolFromSmiles(smiles or "")
    except Exception:
        # Best-effort parser: any RDKit failure is reported as "no molecule".
        parsed = None
    return parsed
def _atom_features(atom) -> np.ndarray:
    """Encode one RDKit atom as a flat float32 feature vector.

    Layout: element one-hot over ``ATOM_NUMBERS``, then eight scalars
    (degree, total valence, formal charge, total H count, implicit valence,
    aromatic flag, ring flag, chiral tag), then a hybridization one-hot over
    ``HYBRIDIZATION_NAMES`` (all zeros for unlisted hybridizations).
    ``smiles_to_graph`` pads the result to ``NODE_FEATURE_DIM``.

    NOTE(review): elements missing from ``ATOM_NUMBERS`` fall back to slot 0,
    which is also hydrogen's slot, so unlisted elements are indistinguishable
    from H in the element block. Fixing it changes trained-model inputs.
    NOTE(review): the ``atom is None`` guards do not make ``None`` safe; the
    element lookup below still dereferences ``atom``.
    """
    hybrid_key = ("" if atom is None else str(atom.GetHybridization())).lower()
    implicit_valence = total_valence = 0.0
    if atom is not None:
        try:
            # Valence queried through GetValence(ValenceType) when available.
            implicit_part = atom.GetValence(Chem.ValenceType.IMPLICIT)
            implicit_valence = float(implicit_part)
            total_valence = float(atom.GetValence(Chem.ValenceType.EXPLICIT) + implicit_part)
        except Exception:
            # Fallback for RDKit builds without that API (presumably older releases).
            implicit_valence = float(atom.GetImplicitValence())
            total_valence = float(atom.GetTotalValence())

    element_slot = ATOM_NUMBERS.get(int(atom.GetAtomicNum()), 0)
    hybrid_slot = HYBRIDIZATION_NAMES.index(hybrid_key) if hybrid_key in HYBRIDIZATION_NAMES else -1
    descriptors = np.array(
        [
            float(atom.GetDegree()),
            total_valence,
            float(atom.GetFormalCharge()),
            float(atom.GetTotalNumHs()),
            implicit_valence,
            float(atom.GetIsAromatic()),
            float(atom.IsInRing()),
            float(atom.GetChiralTag()),
        ],
        dtype=np.float32,
    )
    parts = (
        _one_hot(element_slot, len(ATOM_NUMBERS)),
        descriptors,
        _one_hot(hybrid_slot, len(HYBRIDIZATION_NAMES)),
    )
    return np.concatenate(parts, axis=0).astype(np.float32)
def _empty_invalid_graph(reason: str = "invalid_smiles") -> GraphSample:
    """Return a one-node, edge-free placeholder graph marked ``valid=False``.

    The single node has all-zero features of width ``NODE_FEATURE_DIM``; the
    edge tensors are empty with ``EDGE_FEATURE_DIM`` feature columns.
    ``reason`` is accepted but not recorded anywhere.

    NOTE(review): the 12-slot summary is all zeros except the last slot (1.0).
    ``_graph_statistics`` uses slot 10 for validity and slot 11 for "has nodes
    and edges", so this placeholder does not follow that layout. Confirm which
    convention downstream consumers expect.
    """
    summary = torch.zeros(12, dtype=torch.float32)
    summary[-1] = 1.0
    return GraphSample(
        node_features=torch.zeros((1, NODE_FEATURE_DIM), dtype=torch.float32),
        edge_index=torch.zeros((2, 0), dtype=torch.int64),
        edge_features=torch.zeros((0, EDGE_FEATURE_DIM), dtype=torch.float32),
        graph_features=summary,
        valid=False,
    )
| def _pad_vector(vector: np.ndarray, dim: int) -> np.ndarray: | |
| if vector.shape[-1] >= dim: | |
| return vector[..., :dim].astype(np.float32) | |
| padding = np.zeros((dim - vector.shape[-1],), dtype=np.float32) | |
| return np.concatenate([vector.astype(np.float32), padding], axis=0) | |
def _bond_features(bond) -> np.ndarray:
    """Encode one RDKit bond as a flat float32 vector.

    Layout: five scalars (bond type is in ``BOND_TYPES``, conjugated,
    aromatic, bond direction, stereo) followed by a one-hot over
    ``BOND_TYPES`` that is all zeros for any other bond type.
    """
    type_name = "" if bond is None else str(bond.GetBondType())
    is_known_type = type_name in BOND_TYPES
    scalars = np.array(
        [
            float(is_known_type),
            float(bond.GetIsConjugated()),
            float(bond.GetIsAromatic()),
            float(bond.GetBondDir()),
            float(bond.GetStereo()),
        ],
        dtype=np.float32,
    )
    type_slot = BOND_TYPES.index(type_name) if is_known_type else -1
    type_block = _one_hot(type_slot, len(BOND_TYPES))
    return np.concatenate([scalars, type_block], axis=0).astype(np.float32)
| def _graph_statistics(node_features: np.ndarray, edge_index: np.ndarray, edge_features: np.ndarray, valid: bool) -> np.ndarray: | |
| num_nodes = float(node_features.shape[0]) | |
| num_edges = float(edge_index.shape[1]) if edge_index.size else 0.0 | |
| if node_features.size == 0: | |
| return np.zeros(12, dtype=np.float32) | |
| degrees = np.zeros(node_features.shape[0], dtype=np.float32) | |
| if edge_index.size: | |
| np.add.at(degrees, edge_index[0], 1.0) | |
| return np.array( | |
| [ | |
| num_nodes, | |
| num_edges, | |
| float(degrees.mean()) if degrees.size else 0.0, | |
| float(degrees.std()) if degrees.size else 0.0, | |
| float(degrees.max()) if degrees.size else 0.0, | |
| float(node_features[:, 0].sum()), | |
| float(node_features[:, 1].sum()), | |
| float(node_features[:, 5].sum()), | |
| float(edge_features[:, 0].sum()) if edge_features.size else 0.0, | |
| float(edge_features[:, 2].sum()) if edge_features.size else 0.0, | |
| float(valid), | |
| float(num_nodes > 0 and num_edges > 0), | |
| ], | |
| dtype=np.float32, | |
| ) | |
def smiles_to_graph(smiles: str) -> GraphSample | None:
    """Convert a SMILES string into a molecular ``GraphSample``.

    Atoms become nodes (``_atom_features`` padded or truncated to
    ``NODE_FEATURE_DIM``). Every bond becomes two directed edges that share
    one ``_bond_features`` row (padded to ``EDGE_FEATURE_DIM``). Graph-level
    features come from ``_graph_statistics`` with ``valid=True``.

    Args:
        smiles: input SMILES; sanitized with ``sanitize_smiles`` first.

    Returns:
        The molecular graph, or ``None`` when sanitization rejects the input
        or yields no molecule. No placeholder graph is fabricated for invalid
        molecules, so callers must handle ``None``.
    """
    sanitized = sanitize_smiles(smiles)
    if not sanitized.valid or sanitized.mol is None:
        # Deliberately no fake tensors for invalid molecules.
        return None
    mol = sanitized.mol

    atom_rows = [_pad_vector(_atom_features(atom), NODE_FEATURE_DIM) for atom in mol.GetAtoms()]
    if atom_rows:
        node_features = np.vstack(atom_rows).astype(np.float32)
    else:
        # Keep one zero node so node pooling (mean/max over rows) has input.
        node_features = np.zeros((1, NODE_FEATURE_DIM), dtype=np.float32)

    edges: list[tuple[int, int]] = []
    bond_rows: list[np.ndarray] = []
    for bond in mol.GetBonds():
        src = int(bond.GetBeginAtomIdx())
        dst = int(bond.GetEndAtomIdx())
        row = _pad_vector(_bond_features(bond), EDGE_FEATURE_DIM)
        # Undirected bond -> two directed edges with identical features.
        edges.extend([(src, dst), (dst, src)])
        bond_rows.extend([row, row])

    if edges:
        edge_index = np.asarray(edges, dtype=np.int64).T
        edge_features = np.vstack(bond_rows).astype(np.float32)
    else:
        edge_index = np.zeros((2, 0), dtype=np.int64)
        edge_features = np.zeros((0, EDGE_FEATURE_DIM), dtype=np.float32)

    graph_features = _graph_statistics(node_features, edge_index, edge_features, valid=True)
    return GraphSample(
        node_features=torch.from_numpy(node_features),
        edge_index=torch.from_numpy(edge_index),
        edge_features=torch.from_numpy(edge_features),
        graph_features=torch.from_numpy(graph_features),
        valid=True,
    )
def sanitize_smiles_for_graph(smiles: str) -> SanitizedMolecule:
    """Run the shared SMILES sanitizer and return its report unchanged.

    Public hook so training/inference quality gates apply the same
    ``sanitize_smiles`` check that graph construction uses.
    """
    report = sanitize_smiles(smiles)
    return report
| def _parse_smiles_from_block(block: str) -> str: | |
| import re | |
| found = re.search(r"<kind>SMILES</kind>\s*<value>(.*?)</value>", block, flags=re.S | re.I) | |
| if not found: | |
| return "" | |
| raw = found.group(1).replace("\n", "").strip() | |
| # remove stray spaces inside SMILES and collapse internal whitespace | |
| raw = re.sub(r"\s+", "", raw) | |
| # treat obvious placeholders as empty | |
| if "..." in raw or raw.endswith("..."): | |
| return "" | |
| return raw | |
def _register_record_keys(
    metadata: dict[str, dict[str, Any]],
    record: dict[str, Any],
    primary_keys: Iterable[str],
    aliases: Iterable[str],
) -> None:
    """Index ``record`` under its lookup keys, skipping empty keys.

    ``primary_keys`` always (re)point at ``record``. For each alias, its
    normalized and compact forms are added only when not already present, so
    alias keys never overwrite an existing entry.
    """
    for key in primary_keys:
        if key:
            metadata[key] = record
    for alias in aliases:
        for alias_key in (_normalize_text(alias), _compact_key(alias)):
            if alias_key and alias_key not in metadata:
                metadata[alias_key] = record


def _resolve_xml_record(
    metadata: dict[str, dict[str, Any]],
    name_key: str,
    drugbank_id: str,
) -> dict[str, Any]:
    """Return the record a DrugBank XML block should update in place.

    Preference order:
    1. the record already indexed under the block's normalized DrugBank ID, so
       every key sharing that record (canonical id, aliases, ...) sees the
       XML data;
    2. the record under the normalized primary name, unless it carries a
       different DrugBank ID (e.g. the name key is another drug's alias);
    3. a fresh record with empty relation lists.
    """
    if drugbank_id:
        by_id = metadata.get(_normalize_text(drugbank_id))
        if by_id is not None:
            return by_id
    by_name = metadata.get(name_key) if name_key else None
    if by_name is not None and (not drugbank_id or by_name.get("drugbank_id") in (None, "", drugbank_id)):
        return by_name
    return {"targets": [], "enzymes": [], "transporters": [], "carriers": []}


def _drugbank_xml_path() -> Path | None:
    """Return the configured DrugBank XML path, or None if not importable / missing."""
    try:
        from preprocessing.artifact_store import DRUGBANK_XML
    except ImportError:
        return None
    path = Path(DRUGBANK_XML)
    return path if path.exists() else None


def _canonical_smiles_or_empty(raw_smiles: str) -> str:
    """Return the sanitizer's canonical SMILES for ``raw_smiles``, or "" if invalid."""
    try:
        sanitized = sanitize_smiles(raw_smiles)
        return sanitized.canonical_smiles if sanitized.valid else ""
    except Exception:
        # Best-effort: one malformed SMILES must not abort the whole metadata load.
        logger.debug("SMILES sanitization failed for %r", raw_smiles, exc_info=True)
        return ""


def load_drugbank_metadata() -> dict[str, dict[str, Any]]:
    """Load a lightweight DrugBank metadata map from the structured artifacts.

    Records from ``CanonicalDrugMapper.from_structured_artifacts()`` are
    indexed under their normalized and compact primary names, canonical id,
    normalized DrugBank id and (when free) their aliases.

    If ``preprocessing.artifact_store.DRUGBANK_XML`` is importable and exists,
    each parsed XML drug block is merged into the matching record. The record
    is found by DrugBank id first, then by name. The merge adds ``smiles``
    (sanitizer-canonical, "" if invalid) and ``smiles_raw``.

    Many keys share one record object, so an update made through one key is
    visible through all of them.

    Returns:
        Mapping from lookup key to metadata record.
    """
    metadata: dict[str, dict[str, Any]] = {}
    mapper = CanonicalDrugMapper.from_structured_artifacts()
    for entity in mapper.entities:
        record: dict[str, Any] = {
            "drugbank_id": entity.drugbank_id,
            "primary_name": entity.primary_name,
            "aliases": list(entity.aliases),
            "atc_codes": list(entity.atc_codes),
            "categories": list(entity.categories),
            "semantic_tokens": list(entity.semantic_tokens),
            "counts": dict(entity.counts),
            "targets": [],
            "enzymes": [],
            "transporters": [],
            "carriers": [],
            "smiles": "",
        }
        # Compact key and canonical id support robust lookup from dataset names.
        primary_keys = [
            _normalize_text(entity.primary_name),
            _compact_key(entity.primary_name),
            entity.canonical_id,
        ]
        if entity.drugbank_id:
            primary_keys.append(_normalize_text(entity.drugbank_id))
        _register_record_keys(metadata, record, primary_keys, record["aliases"])

    xml_path = _drugbank_xml_path()
    if xml_path is None:
        return metadata

    for block in _iter_drugbank_blocks(xml_path):
        parsed = _parse_drugbank_block(block)
        if parsed is None:
            continue
        name_key = _normalize_text(parsed.primary_name)
        record = _resolve_xml_record(metadata, name_key, parsed.drugbank_id)
        record.update(
            {
                "drugbank_id": parsed.drugbank_id,
                "primary_name": parsed.primary_name,
                "aliases": list(parsed.aliases),
                "atc_codes": list(parsed.atc_codes),
                "categories": list(parsed.categories),
                "semantic_tokens": list(parsed.semantic_tokens),
                "counts": dict(parsed.counts),
            }
        )
        raw_smiles = _parse_smiles_from_block(block)
        # Store only RDKit-validated canonical SMILES; keep the raw value for audit.
        record["smiles"] = _canonical_smiles_or_empty(raw_smiles)
        record["smiles_raw"] = raw_smiles

        primary_keys = [name_key, _compact_key(parsed.primary_name)]
        if parsed.drugbank_id:
            primary_keys += [_normalize_text(parsed.drugbank_id), parsed.drugbank_id]
        _register_record_keys(metadata, record, primary_keys, record["aliases"])
    return metadata
def _find_smiles_in_metadata(metadata: Mapping[str, Mapping[str, Any]], name: str) -> dict[str, Any] | None:
    """Heuristic fallback: find the single SMILES-bearing record for ``name``.

    First tries the normalized name as a direct key. Otherwise it scans
    records for an alias whose normalized form equals it, and only
    SMILES-bearing records count.

    One record object is indexed under many keys (names, compact forms, IDs,
    aliases). Candidates are therefore de-duplicated by DrugBank ID (falling
    back to object identity) before the ambiguity check; otherwise the same
    drug would be counted several times and never returned.

    A match is returned only when exactly one distinct drug qualifies.
    Substring matching is deliberately avoided because product names make it
    noisy.
    """
    if not name:
        return None
    key = _normalize_text(name)
    if not key:
        return None

    direct = metadata.get(key)
    if direct and direct.get("smiles"):
        return direct

    matches: dict[Any, Mapping[str, Any]] = {}
    for record in metadata.values():
        if not record.get("smiles"):
            continue
        identity = record.get("drugbank_id") or id(record)
        if identity in matches:
            continue
        if any(_normalize_text(alias) == key for alias in record.get("aliases", []) or []):
            matches[identity] = record
            if len(matches) > 1:
                return None  # ambiguous: several distinct drugs share this alias
    return next(iter(matches.values()), None)
def _lookup_drug_metadata(metadata: Mapping[str, Mapping[str, Any]], drug_name: str) -> Mapping[str, Any]:
    """Resolve ``drug_name`` to a metadata record, or ``{}`` when not found.

    Tries, in order: the normalized name, the compact key (both indexed by
    ``load_drugbank_metadata``), then the unambiguous-alias fallback
    ``_find_smiles_in_metadata``. Empty keys are never looked up.
    """
    for key in (_normalize_text(drug_name), _compact_key(drug_name)):
        record = metadata.get(key) if key else None
        if record:
            return record
    return _find_smiles_in_metadata(metadata, drug_name) or {}


def _graph_feature_value(graph: GraphSample | None, index: int) -> float:
    """Read ``graph.graph_features[index]`` as a float; 0.0 if missing or unreadable."""
    if graph is None:
        return 0.0
    try:
        return float(graph.graph_features[index].item())
    except Exception:
        # Best-effort summary feature: malformed graph features count as 0.0.
        return 0.0


def build_drug_graph_bundle(
    drug_a: str,
    drug_b: str,
    metadata: Mapping[str, Mapping[str, Any]] | None = None,
) -> dict[str, GraphSample | torch.Tensor | dict[str, Any]]:
    """Build graph inputs for a DDI pair.

    Args:
        drug_a, drug_b: drug names as they appear in the dataset.
        metadata: lookup map as produced by ``load_drugbank_metadata``;
            ``None`` means no metadata (no molecular graphs).

    Returns:
        A dict with:
        - ``drug_a_graph`` / ``drug_b_graph``: molecular graphs, or ``None``
          when no valid SMILES is available;
        - ``pharmacology_graph``: relational concept graph for the pair;
        - ``interaction_graph``: interaction neighborhood graph;
        - ``interaction_summary``: 17-element float32 summary vector;
        - SMILES strings (canonical and raw), per-drug sanitizer reports,
          a ``quarantined`` flag and the non-empty ``quarantine_reasons``.
    """
    metadata = metadata or {}
    meta_a = _lookup_drug_metadata(metadata, drug_a)
    meta_b = _lookup_drug_metadata(metadata, drug_b)

    smiles_a_raw = str(meta_a.get("smiles", ""))
    smiles_b_raw = str(meta_b.get("smiles", ""))
    smiles_a_validation = sanitize_smiles(smiles_a_raw)
    smiles_b_validation = sanitize_smiles(smiles_b_raw)
    smiles_a = smiles_a_validation.canonical_smiles if smiles_a_validation.valid else ""
    smiles_b = smiles_b_validation.canonical_smiles if smiles_b_validation.valid else ""
    graph_a = smiles_to_graph(smiles_a) if smiles_a else None
    graph_b = smiles_to_graph(smiles_b) if smiles_b else None

    concepts_a = _collect_pharmacology_concepts(meta_a)
    concepts_b = _collect_pharmacology_concepts(meta_b)
    pharmacology_graph = _build_concept_graph(drug_a, drug_b, concepts_a, concepts_b)
    interaction_graph = _build_interaction_graph(drug_a, drug_b, meta_a, meta_b)

    shared = concepts_a & concepts_b
    # Molecular slots read from _graph_statistics: 0 nodes, 1 edges,
    # 2 mean out-degree, 5 node column-0 sum, 8 edge column-0 sum.
    interaction_summary = torch.tensor(
        [
            float(bool(graph_a)),
            float(bool(graph_b)),
            float(len(shared) > 0),
            float(len(concepts_a)),
            float(len(concepts_b)),
            float(len(shared)),
            float(len(concepts_a | concepts_b)),
            _graph_feature_value(graph_a, 0),
            _graph_feature_value(graph_b, 0),
            _graph_feature_value(graph_a, 1),
            _graph_feature_value(graph_b, 1),
            _graph_feature_value(graph_a, 2),
            _graph_feature_value(graph_b, 2),
            _graph_feature_value(graph_a, 5),
            _graph_feature_value(graph_b, 5),
            _graph_feature_value(graph_a, 8),
            _graph_feature_value(graph_b, 8),
        ],
        dtype=torch.float32,
    )
    return {
        "drug_a_graph": graph_a,
        "drug_b_graph": graph_b,
        "pharmacology_graph": pharmacology_graph,
        "interaction_graph": interaction_graph,
        "interaction_summary": interaction_summary,
        "smiles_a": smiles_a,
        "smiles_b": smiles_b,
        "smiles_a_raw": smiles_a_raw,
        "smiles_b_raw": smiles_b_raw,
        "smiles_a_validation": smiles_a_validation.to_report_dict(),
        "smiles_b_validation": smiles_b_validation.to_report_dict(),
        "quarantined": bool((not smiles_a_validation.valid) or (not smiles_b_validation.valid)),
        "quarantine_reasons": [
            reason
            for reason in [
                smiles_a_validation.reason if not smiles_a_validation.valid else "",
                smiles_b_validation.reason if not smiles_b_validation.valid else "",
            ]
            if reason
        ],
    }
def _collect_pharmacology_concepts(meta: Mapping[str, Any]) -> set[str]:
    """Gather normalized pharmacology tokens from a drug metadata record.

    Reads ATC codes, categories, targets, enzymes, transporters, carriers and
    semantic tokens. A string field counts as one token and any other
    iterable is read item by item. Missing or non-iterable fields (e.g.
    ``None``) are skipped, and empty normalized tokens are dropped.
    """
    fields = ("atc_codes", "categories", "targets", "enzymes", "transporters", "carriers", "semantic_tokens")
    tokens: set[str] = set()
    for field in fields:
        raw = meta.get(field, [])
        if isinstance(raw, str):
            items = (raw,)
        elif isinstance(raw, Iterable):
            items = raw
        else:
            continue
        tokens.update(token for token in map(_normalize_text, items) if token)
    return tokens
def _build_concept_graph(drug_a: str, drug_b: str, concepts_a: set[str], concepts_b: set[str]) -> GraphSample:
    """Build the relational pharmacology graph for a drug pair.

    Nodes: drug A (index 0) and drug B (index 1), both node type 0, followed
    by every concept of ``concepts_a | concepts_b`` in sorted order (node
    type 1). Node features are hashed one-hots of the node labels.

    Edges, each added in both directions:
    - drug A <-> concept, type 0;
    - drug B <-> concept, type 1;
    - drug A <-> drug B, type 2, repeated once per shared concept.

    ``valid`` is True only when at least one concept exists.
    """
    ordered = sorted(concepts_a | concepts_b)
    labels = [f"drug::{_normalize_text(drug_a)}", f"drug::{_normalize_text(drug_b)}"]
    labels += [f"concept::{concept}" for concept in ordered]
    kinds = [0, 0] + [1] * len(ordered)
    node_matrix = np.vstack([_hashed_node_feature(label) for label in labels]).astype(np.float32)

    pairs: list[tuple[int, int]] = []
    feats: list[np.ndarray] = []
    types: list[int] = []

    def _link(u: int, v: int, kind: int) -> None:
        # One undirected relation stored as two directed edges.
        pairs.extend([(u, v), (v, u)])
        feats.extend([_edge_feature(kind), _edge_feature(kind)])
        types.extend([kind, kind])

    for node_idx, concept in enumerate(ordered, start=2):
        in_a = concept in concepts_a
        in_b = concept in concepts_b
        if in_a:
            _link(0, node_idx, 0)
        if in_b:
            _link(1, node_idx, 1)
        if in_a and in_b:
            _link(0, 1, 2)

    if pairs:
        edge_index = np.asarray(pairs, dtype=np.int64).T
        edge_matrix = np.vstack(feats).astype(np.float32)
        edge_kind_tensor = torch.tensor(types, dtype=torch.long)
    else:
        edge_index = np.zeros((2, 0), dtype=np.int64)
        edge_matrix = np.zeros((0, EDGE_FEATURE_DIM), dtype=np.float32)
        edge_kind_tensor = torch.zeros((0,), dtype=torch.long)

    summary = np.array(
        [
            float(len(ordered)),
            float(len(concepts_a)),
            float(len(concepts_b)),
            float(len(concepts_a & concepts_b)),
            float(len(pairs) > 0),
            float(len(labels)),
            float(sum(kinds)),
            float(len(concepts_a | concepts_b)),
            0.0,
            0.0,
            0.0,
            0.0,
        ],
        dtype=np.float32,
    )
    return GraphSample(
        node_features=torch.from_numpy(node_matrix),
        edge_index=torch.from_numpy(edge_index),
        edge_features=torch.from_numpy(edge_matrix),
        graph_features=torch.from_numpy(summary),
        node_types=torch.tensor(kinds, dtype=torch.long),
        edge_types=edge_kind_tensor,
        valid=bool(ordered),
    )
def _build_interaction_graph(drug_a: str, drug_b: str, meta_a: Mapping[str, Any], meta_b: Mapping[str, Any]) -> GraphSample:
    """Build the interaction neighborhood graph for a drug pair.

    Nodes: the two drugs (indices 0 and 1) followed by the shared-attribute
    context nodes from ``_interaction_context_nodes``. Node features are
    hashed one-hots of the labels.

    Edges: drug A <-> drug B (edge-feature type 3). Every context node is
    linked to drug A (type 4) and drug B (type 5), each in both directions.
    The edge list is therefore never empty.

    NOTE(review): slot 2 hashes the raw ``drug_a + drug_b`` string, so it
    depends on argument order and letter case. Confirm whether that is
    intended for a symmetric interaction.
    """
    context = _interaction_context_nodes(meta_a, meta_b)
    labels = [
        f"interaction::{_normalize_text(drug_a)}",
        f"interaction::{_normalize_text(drug_b)}",
        *context,
    ]
    node_matrix = np.vstack([_hashed_node_feature(label) for label in labels]).astype(np.float32)

    pairs = [(0, 1), (1, 0)]
    kinds = [3, 3]
    for ctx in range(2, len(labels)):
        pairs += [(0, ctx), (ctx, 0), (1, ctx), (ctx, 1)]
        kinds += [4, 4, 5, 5]
    edge_index = np.asarray(pairs, dtype=np.int64).T
    edge_matrix = np.vstack([_pad_vector(_edge_feature(kind), EDGE_FEATURE_DIM) for kind in kinds]).astype(np.float32)

    n_nodes, n_edges, n_context = len(labels), len(pairs), len(context)
    summary = np.array(
        [
            float(n_nodes),
            float(n_edges),
            float(_stable_hash(drug_a + drug_b, 997) / 997.0),
            float(n_context),
            float(bool(labels)),
            float(n_nodes > 2),
            float(n_edges > 0),
            1.0,
            float(n_context > 0),
            float(n_nodes - 2),
            float(n_edges // max(1, n_nodes)),
            0.0,
        ],
        dtype=np.float32,
    )
    return GraphSample(
        node_features=torch.from_numpy(node_matrix),
        edge_index=torch.from_numpy(edge_index),
        edge_features=torch.from_numpy(edge_matrix),
        graph_features=torch.from_numpy(summary),
        valid=True,
    )
def _normalized_field_values(meta: Mapping[str, Any], field: str) -> set[str]:
    """Non-empty normalized values of ``meta[field]``.

    A string is one value; missing or non-iterable fields (e.g. ``None``)
    give an empty set.
    """
    raw = meta.get(field, [])
    if isinstance(raw, str):
        raw = [raw]
    elif not isinstance(raw, Iterable):
        return set()
    values: set[str] = set()
    for item in raw:
        token = _normalize_text(item)
        if token:
            values.add(token)
    return values


def _interaction_context_nodes(meta_a: Mapping[str, Any], meta_b: Mapping[str, Any], *, max_per_field: int = 8) -> list[str]:
    """Labels for pharmacology attributes shared by both drugs.

    For each field in a fixed order (ATC codes, targets, enzymes,
    transporters, carriers, categories), the normalized values present in
    both records are sorted. At most ``max_per_field`` of them are kept,
    labelled ``"<field>::<value>"``.

    A string-valued field counts as a single value, as in
    ``_collect_pharmacology_concepts``. Previously a string was split into
    characters and ``None`` raised ``TypeError``.
    """
    nodes: list[str] = []
    for field in ("atc_codes", "targets", "enzymes", "transporters", "carriers", "categories"):
        shared = sorted(_normalized_field_values(meta_a, field) & _normalized_field_values(meta_b, field))
        nodes.extend(f"{field}::{value}" for value in shared[:max_per_field])
    return nodes
def _hashed_node_feature(node: str, dim: int = 32) -> np.ndarray:
    """Float32 one-hot of length ``dim`` at bucket ``_stable_hash(node, dim)``.

    Gives each node label a deterministic identity feature.
    """
    bucket = _stable_hash(node, dim)
    return _one_hot(bucket, dim)
def _edge_feature(edge_type: int, dim: int = EDGE_FEATURE_DIM) -> np.ndarray:
    """Float32 one-hot encoding of an integer edge type (all zeros if out of range)."""
    encoded = np.zeros(dim, dtype=np.float32)
    if 0 <= edge_type < dim:
        encoded[edge_type] = 1.0
    return encoded
class GraphMessagePassingBlock(nn.Module):
    """Residual message-passing layer with edge-conditioned gating.

    For every directed edge ``src -> dst`` the message is
    ``neigh_proj(x[src]) * edge_gate(edge_attr)``. Messages are summed at each
    destination and divided by its incoming-edge count (clamped to >= 1).

    The layer returns ``x + dropout(ffn(norm(self_proj(x) + mean_message)))``.
    With no edges the message term is simply absent.
    """

    def __init__(self, hidden_dim: int, edge_dim: int, dropout: float = 0.2):
        super().__init__()
        # Attribute names and construction order define the state_dict keys
        # and the parameter-initialisation order; keep both stable.
        self.self_proj = nn.Linear(hidden_dim, hidden_dim)
        self.neigh_proj = nn.Linear(hidden_dim, hidden_dim)
        self.edge_gate = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Sigmoid(),  # per-channel gate in (0, 1)
        )
        self.norm = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
        )
        self.dropout = nn.Dropout(dropout)

    def _residual_update(self, x: torch.Tensor, pre_norm: torch.Tensor) -> torch.Tensor:
        """Apply norm -> FFN -> dropout to ``pre_norm`` and add it to ``x``."""
        return x + self.dropout(self.ffn(self.norm(pre_norm)))

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        """Update node states ``x`` given 2-row ``edge_index`` and per-edge ``edge_attr``."""
        if edge_index.numel() == 0:
            return self._residual_update(x, self.self_proj(x))
        src, dst = edge_index.unbind(0)
        gated = self.neigh_proj(x[src]) * self.edge_gate(edge_attr)
        summed = torch.zeros_like(x).index_add_(0, dst, gated)
        in_degree = x.new_zeros(x.size(0)).index_add_(0, dst, x.new_ones(dst.size(0)))
        mean_message = summed / in_degree.clamp_min(1.0).unsqueeze(-1)
        return self._residual_update(x, self.self_proj(x) + mean_message)
class JumpingKnowledgePooling(nn.Module):
    """Attention-weighted combination of per-layer node embeddings.

    Each layer's embedding is scored by a shared linear head. Scores are
    softmax-normalised across layers (separately for every node), and the
    layers are summed with those weights.
    """

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.attn = nn.Linear(hidden_dim, 1)

    def forward(self, layer_outputs: list[torch.Tensor]) -> torch.Tensor:
        """Fuse equally-shaped layer outputs into one tensor of the same shape."""
        layers = torch.stack(layer_outputs, dim=0)
        scores = self.attn(layers).squeeze(-1)
        weights = scores.softmax(dim=0).unsqueeze(-1)
        return torch.sum(layers * weights, dim=0)
class GraphEncoder(nn.Module):
    """Encode one ``GraphSample`` into a single embedding row.

    Pipeline:
    1. input projection;
    2. ``num_layers`` residual message-passing blocks;
    3. attention-weighted Jumping Knowledge over the projection and every
       block output;
    4. mean and max pooling over nodes, concatenated with the graph-level
       feature vector;
    5. readout MLP.

    Args:
        input_dim: node feature width.
        hidden_dim: width of node states inside the encoder.
        output_dim: width of the returned embedding.
        num_layers: number of message-passing blocks.
        dropout: dropout probability used throughout.
        edge_dim: edge feature width.
        graph_feature_dim: length of ``graph.graph_features``. The default 12
            matches the graph builders in this module, so existing callers
            and checkpoints are unaffected.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int = 3,
        dropout: float = 0.2,
        edge_dim: int = EDGE_FEATURE_DIM,
        graph_feature_dim: int = 12,
    ):
        super().__init__()
        # Plain attributes: not part of the state_dict.
        self.edge_dim = edge_dim
        self.graph_feature_dim = graph_feature_dim
        self.input_proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        self.blocks = nn.ModuleList([
            GraphMessagePassingBlock(hidden_dim, edge_dim=edge_dim, dropout=dropout)
            for _ in range(num_layers)
        ])
        self.jk = JumpingKnowledgePooling(hidden_dim)
        self.readout = nn.Sequential(
            nn.Linear(hidden_dim * 2 + graph_feature_dim, output_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(output_dim),
        )

    def forward(self, graph: GraphSample) -> torch.Tensor:
        """Return a ``(1, output_dim)`` embedding for ``graph``."""
        x = graph.node_features
        if x.dim() == 1:
            # A single feature vector is treated as a one-node graph.
            x = x.unsqueeze(0)
        edge_index = graph.edge_index.long().to(x.device)
        # Match the node dtype so mixed-precision inputs don't break the Linear layers.
        edge_attr = graph.edge_features.to(device=x.device, dtype=x.dtype)
        if edge_attr.numel() == 0:
            # getattr keeps modules pickled before ``edge_dim`` was stored loadable.
            edge_dim = getattr(self, "edge_dim", EDGE_FEATURE_DIM)
            edge_attr = torch.zeros((0, edge_dim), device=x.device, dtype=x.dtype)
        h = self.input_proj(x)
        layer_outputs = [h]
        for block in self.blocks:
            h = block(h, edge_index, edge_attr)
            layer_outputs.append(h)
        jk = self.jk(layer_outputs)
        pooled_mean = jk.mean(dim=0, keepdim=True)
        pooled_max = jk.max(dim=0, keepdim=True).values
        graph_features = graph.graph_features.to(device=x.device, dtype=x.dtype).reshape(1, -1)
        return self.readout(torch.cat([pooled_mean, pooled_max, graph_features], dim=-1))
def cache_graph_bundle(bundle: Mapping[str, Any], path: Path) -> None:
    """Persist ``bundle`` to ``path`` with joblib, creating parent directories first."""
    parent_dir = path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    joblib.dump(bundle, path)
def load_graph_bundle(path: Path) -> dict[str, Any]:
    """Load a bundle previously written with ``joblib`` (see ``cache_graph_bundle``).

    SECURITY(review): ``joblib.load`` unpickles its input, and unpickling can
    execute arbitrary code. Only load cache files this application wrote
    itself; never point this at untrusted or user-supplied paths.
    """
    bundle = joblib.load(path)
    return bundle