"""Graph-aware multimodal DDI foundation model. This module preserves the historical AdvancedDDINet import path while replacing the old tabular fusion net with a graph-based, clinically safer representation learner. The model supports: - molecular graph encoders for each drug - pharmacology and interaction graph encoders - residual MLP branches for legacy features - gated multi-head fusion with uncertainty and confidence heads The implementation remains CPU-compatible and can run without PyTorch Geometric by using the native graph encoder defined in training.graph_representations. """ from __future__ import annotations from dataclasses import dataclass from typing import Any, Optional, Sequence import torch import torch.nn as nn from .graph_representations import GraphEncoder, GraphSample class ResidualBlock(nn.Module): def __init__(self, dim: int, dropout: float = 0.2): super().__init__() self.norm = nn.LayerNorm(dim) self.block = nn.Sequential( nn.Linear(dim, dim * 2), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim * 2, dim), ) self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: return x + self.dropout(self.block(self.norm(x))) class BranchEncoder(nn.Module): def __init__(self, input_dim: int, hidden_dim: int, depth: int = 2, dropout: float = 0.2): super().__init__() self.proj = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(), nn.Dropout(dropout), ) self.blocks = nn.ModuleList([ResidualBlock(hidden_dim, dropout=dropout) for _ in range(depth)]) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) for block in self.blocks: x = block(x) return x class GatedAttentionFusion(nn.Module): def __init__(self, dim: int, num_heads: int = 4, dropout: float = 0.2): super().__init__() self.norm = nn.LayerNorm(dim) self.attn = nn.MultiheadAttention(dim, num_heads=num_heads, batch_first=True, dropout=dropout) self.gate = nn.Sequential( nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim), nn.Sigmoid(), ) self.ffn = nn.Sequential( nn.Linear(dim, dim * 2), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim * 2, dim), ) self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: x_norm = self.norm(x) attn_out, attn_weights = self.attn(x_norm, x_norm, x_norm, need_weights=True) gated = self.gate(x_norm) * attn_out x = x + self.dropout(gated) x = x + self.dropout(self.ffn(self.norm(x))) return x, attn_weights @dataclass(frozen=True) class AdvancedModelConfig: fingerprint_dim: int = 8192 semantic_dim: int = 3072 pharmacology_dim: int = 512 pair_dim: int = 64 molecular_dim: int = 41 hidden_dim: int = 256 fusion_dim: int = 256 num_classes: int = 4 dropout: float = 0.25 num_heads: int = 4 graph_input_dim: int = 0 graph_hidden_dim: int = 128 graph_output_dim: int = 128 graph_layers: int = 3 use_graph_branches: bool = True confidence_dim: int = 64 class _GraphBranch(nn.Module): def __init__(self, output_dim: int, hidden_dim: int, layers: int, dropout: float): super().__init__() self.encoder = GraphEncoder(input_dim=32, hidden_dim=hidden_dim, output_dim=output_dim, num_layers=layers, dropout=dropout) def _to_device(self, graph: GraphSample, device: torch.device) -> GraphSample: if graph.node_features.device == device: return graph return GraphSample( node_features=graph.node_features.to(device), edge_index=graph.edge_index.to(device), edge_features=graph.edge_features.to(device), graph_features=graph.graph_features.to(device), node_types=None if graph.node_types is None else graph.node_types.to(device), edge_types=None if graph.edge_types is None else graph.edge_types.to(device), valid=graph.valid, ) def forward(self, graph: GraphSample | Sequence[GraphSample] | None, device: torch.device) -> torch.Tensor: if graph is None: return torch.zeros((1, self.encoder.readout[-1].normalized_shape[0]), device=device) if isinstance(graph, Sequence): embeddings = [self.encoder(self._to_device(item, device)) for item in graph] return torch.cat(embeddings, dim=0) return self.encoder(self._to_device(graph, device)) class AdvancedDDINet(nn.Module): """Graph-aware multimodal DDI model with selective-prediction outputs.""" def __init__(self, config: AdvancedModelConfig): super().__init__() self.config = config self.fingerprint_branch = BranchEncoder(config.fingerprint_dim, config.hidden_dim, depth=3, dropout=config.dropout) self.semantic_branch = BranchEncoder(config.semantic_dim, config.hidden_dim, depth=2, dropout=config.dropout) self.pharmacology_branch = BranchEncoder(config.pharmacology_dim, config.hidden_dim, depth=2, dropout=config.dropout) self.pairwise_branch = BranchEncoder(config.pair_dim, config.hidden_dim // 2, depth=1, dropout=config.dropout) self.molecular_branch = BranchEncoder(config.molecular_dim, config.hidden_dim // 2, depth=1, dropout=config.dropout) self.graph_dense_proj = nn.Sequential( nn.LazyLinear(config.graph_output_dim), nn.GELU(), nn.Dropout(config.dropout), nn.LayerNorm(config.graph_output_dim), ) self.use_graph_branches = config.use_graph_branches self.molecular_graph_branch = _GraphBranch(config.graph_output_dim, config.graph_hidden_dim, config.graph_layers, config.dropout) self.pharmacology_graph_branch = _GraphBranch(config.graph_output_dim, config.graph_hidden_dim, config.graph_layers, config.dropout) self.interaction_graph_branch = _GraphBranch(config.graph_output_dim, config.graph_hidden_dim, config.graph_layers, config.dropout) self.graph_bundle_proj = nn.Sequential( nn.LazyLinear(config.graph_output_dim), nn.GELU(), nn.Dropout(config.dropout), nn.LayerNorm(config.graph_output_dim), ) fusion_input_dim = config.hidden_dim * 4 + config.graph_output_dim * 2 self.fusion_proj = nn.Linear(fusion_input_dim, config.fusion_dim) self.fusion_attention = GatedAttentionFusion(config.fusion_dim, num_heads=config.num_heads, dropout=config.dropout) self.residual_fusion = nn.Sequential( nn.LayerNorm(config.fusion_dim), nn.Linear(config.fusion_dim, config.fusion_dim), nn.GELU(), nn.Dropout(config.dropout), nn.Linear(config.fusion_dim, config.fusion_dim), ) self.confidence_head = nn.Sequential( nn.LayerNorm(config.fusion_dim), nn.Linear(config.fusion_dim, config.confidence_dim), nn.GELU(), nn.Dropout(config.dropout), nn.Linear(config.confidence_dim, 1), ) self.uncertainty_head = nn.Sequential( nn.LayerNorm(config.fusion_dim), nn.Linear(config.fusion_dim, config.confidence_dim), nn.GELU(), nn.Dropout(config.dropout), nn.Linear(config.confidence_dim, 1), ) self.head = nn.Sequential( nn.LayerNorm(config.fusion_dim), nn.Linear(config.fusion_dim, config.fusion_dim), nn.GELU(), nn.Dropout(config.dropout), nn.Linear(config.fusion_dim, config.num_classes), ) def _ensure_batch(self, tensor: torch.Tensor) -> torch.Tensor: return tensor.unsqueeze(0) if tensor.dim() == 1 else tensor def _graph_bundle_embedding(self, graph_bundle: Optional[dict[str, Any] | Sequence[dict[str, Any]]], device: torch.device) -> torch.Tensor: if not self.use_graph_branches or graph_bundle is None: return torch.zeros((1, self.config.graph_output_dim * 4), device=device) if isinstance(graph_bundle, Sequence): return torch.cat([self._graph_bundle_embedding(item, device) for item in graph_bundle], dim=0) drug_a = graph_bundle.get("drug_a_graph") drug_b = graph_bundle.get("drug_b_graph") pharmacology_graph = graph_bundle.get("pharmacology_graph") interaction_graph = graph_bundle.get("interaction_graph") interaction_summary = graph_bundle.get("interaction_summary") emb_a = self.molecular_graph_branch(drug_a, device) emb_b = self.molecular_graph_branch(drug_b, device) emb_pharma = self.pharmacology_graph_branch(pharmacology_graph, device) emb_interaction = self.interaction_graph_branch(interaction_graph, device) if interaction_summary is None: summary = torch.zeros((1, self.config.graph_output_dim), device=device) else: summary = self.graph_dense_proj(self._ensure_batch(interaction_summary.to(device))) pair_graph = torch.cat([emb_a, emb_b, torch.abs(emb_a - emb_b), emb_a * emb_b], dim=-1) bundle = torch.cat([pair_graph, emb_pharma, emb_interaction, summary], dim=-1) return self.graph_bundle_proj(bundle) def forward( self, fingerprint: torch.Tensor, semantic: torch.Tensor, pharmacology: torch.Tensor, pairwise: torch.Tensor, molecular: torch.Tensor, graph_features: Optional[torch.Tensor] = None, graph_bundle: Optional[dict[str, Any]] = None, ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: device = fingerprint.device fp = self.fingerprint_branch(self._ensure_batch(fingerprint)) sem = self.semantic_branch(self._ensure_batch(semantic)) ph = self.pharmacology_branch(self._ensure_batch(pharmacology)) pair = self.pairwise_branch(self._ensure_batch(pairwise)) mol = self.molecular_branch(self._ensure_batch(molecular)) if graph_features is not None: graph_dense = self.graph_dense_proj(self._ensure_batch(graph_features.to(device))) else: graph_dense = torch.zeros((fp.shape[0], self.config.graph_output_dim), device=device) graph_bundle_embedding = self._graph_bundle_embedding(graph_bundle, device) if graph_bundle_embedding.shape[0] != fp.shape[0]: graph_bundle_embedding = graph_bundle_embedding.expand(fp.shape[0], -1) if graph_dense.shape[0] != fp.shape[0]: graph_dense = graph_dense.expand(fp.shape[0], -1) branches = [fp, sem, ph, pair, mol, graph_dense, graph_bundle_embedding] fused = torch.cat(branches, dim=-1) fused = self.fusion_proj(fused).unsqueeze(1) fused, attn_weights = self.fusion_attention(fused) fused = fused.squeeze(1) fused = fused + self.residual_fusion(fused) logits = self.head(fused) confidence = torch.sigmoid(self.confidence_head(fused)).squeeze(-1) uncertainty = torch.sigmoid(self.uncertainty_head(fused)).squeeze(-1) severe_probability = torch.softmax(logits, dim=-1)[..., -1] return logits, { "attention": attn_weights, "fused_embedding": fused, "confidence_embedding": confidence.unsqueeze(-1), "uncertainty_embedding": uncertainty.unsqueeze(-1), "severe_probability": severe_probability.unsqueeze(-1), }