Spaces:
Running
Running
| """Graph-aware multimodal DDI foundation model. | |
| This module preserves the historical AdvancedDDINet import path while | |
| replacing the old tabular fusion net with a graph-based, clinically safer | |
| representation learner. The model supports: | |
| - molecular graph encoders for each drug | |
| - pharmacology and interaction graph encoders | |
| - residual MLP branches for legacy features | |
| - gated multi-head fusion with uncertainty and confidence heads | |
| The implementation remains CPU-compatible and can run without PyTorch | |
| Geometric by using the native graph encoder defined in | |
| training.graph_representations. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Any, Optional, Sequence | |
| import torch | |
| import torch.nn as nn | |
| from .graph_representations import GraphEncoder, GraphSample | |
class ResidualBlock(nn.Module):
    """Pre-norm residual feed-forward block that keeps the feature width.

    The input is layer-normalised, expanded to twice its width, passed through
    GELU and dropout, projected back to ``dim``, dropped out once more and
    finally added to the untouched input.

    Args:
        dim: Size of the last (feature) dimension of the input and output.
        dropout: Dropout probability used inside the MLP and on its output.
    """

    def __init__(self, dim: int, dropout: float = 0.2):
        super().__init__()
        expanded = dim * 2
        self.norm = nn.LayerNorm(dim)
        mlp_layers = [
            nn.Linear(dim, expanded),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(expanded, dim),
        ]
        self.block = nn.Sequential(*mlp_layers)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the dropped-out MLP update of its normalised form."""
        normalised = self.norm(x)
        update = self.dropout(self.block(normalised))
        return x + update
class BranchEncoder(nn.Module):
    """Encode one flat feature vector into a ``hidden_dim``-wide representation.

    A linear projection followed by LayerNorm, GELU and dropout maps the input
    to ``hidden_dim``; the result is then refined by ``depth`` residual blocks
    applied one after another.

    Args:
        input_dim: Width of the incoming feature vector.
        hidden_dim: Width of the projected and returned representation.
        depth: Number of ``ResidualBlock`` layers stacked after the projection.
        dropout: Dropout probability shared by the projection and the blocks.
    """

    def __init__(self, input_dim: int, hidden_dim: int, depth: int = 2, dropout: float = 0.2):
        super().__init__()
        stem_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        ]
        self.proj = nn.Sequential(*stem_layers)
        self.blocks = nn.ModuleList()
        for _ in range(depth):
            self.blocks.append(ResidualBlock(hidden_dim, dropout=dropout))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and run it through every residual block in order."""
        hidden = self.proj(x)
        for layer in self.blocks:
            hidden = layer(hidden)
        return hidden
class GatedAttentionFusion(nn.Module):
    """Self-attention block whose attention output is scaled by a learned gate.

    The input is layer-normalised and attended over itself; a sigmoid gate
    computed from the same normalised input multiplies the attention output
    element-wise before it is added back to the input. A feed-forward
    sub-layer with its own residual connection follows. The same LayerNorm
    instance is applied before both sub-layers.

    Args:
        dim: Embedding width of the input tokens.
        num_heads: Number of attention heads.
        dropout: Dropout probability for attention, the feed-forward layer and
            both residual updates.
    """

    def __init__(self, dim: int, num_heads: int = 4, dropout: float = 0.2):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        gate_layers = [
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
            nn.Sigmoid(),
        ]
        self.gate = nn.Sequential(*gate_layers)
        ffn_layers = [
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 2, dim),
        ]
        self.ffn = nn.Sequential(*ffn_layers)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the fused tensor and the attention weights.

        ``x`` is batch-first because the attention layer is built with
        ``batch_first=True``.
        """
        normed = self.norm(x)
        attended, attn_weights = self.attn(
            query=normed,
            key=normed,
            value=normed,
            need_weights=True,
        )
        gate_values = self.gate(normed)
        after_attention = x + self.dropout(gate_values * attended)
        # The same LayerNorm is reused ahead of the feed-forward sub-layer.
        ffn_update = self.ffn(self.norm(after_attention))
        fused = after_attention + self.dropout(ffn_update)
        return fused, attn_weights
@dataclass
class AdvancedModelConfig:
    """Hyper-parameters consumed by ``AdvancedDDINet``.

    Declared as a dataclass so individual fields can be overridden by keyword
    at construction time, e.g. ``AdvancedModelConfig(hidden_dim=128)``.
    """

    # Input widths of the five dense feature branches.
    fingerprint_dim: int = 8192
    semantic_dim: int = 3072
    pharmacology_dim: int = 512
    pair_dim: int = 64
    molecular_dim: int = 41
    # Output width of the fingerprint/semantic/pharmacology branches; the
    # pairwise and molecular branches use ``hidden_dim // 2``.
    hidden_dim: int = 256
    # Width of the fused representation fed to attention and the output heads.
    fusion_dim: int = 256
    # Number of logits produced by the classification head.
    num_classes: int = 4
    # Dropout probability shared by all sub-modules.
    dropout: float = 0.25
    # Number of heads in the fusion attention layer.
    num_heads: int = 4
    # NOTE(review): not read anywhere in this module; the graph branches are
    # built with their own encoder input width.
    graph_input_dim: int = 0
    # Hidden width, output width and layer count handed to each graph branch.
    graph_hidden_dim: int = 128
    graph_output_dim: int = 128
    graph_layers: int = 3
    # When False, the graph-bundle embedding is replaced by zeros.
    use_graph_branches: bool = True
    # Hidden width of the confidence and uncertainty heads.
    confidence_dim: int = 64
class _GraphBranch(nn.Module):
    """Wrap a ``GraphEncoder`` so it accepts one graph, several graphs or ``None``.

    Args:
        output_dim: Forwarded to ``GraphEncoder`` as ``output_dim``.
        hidden_dim: Forwarded to ``GraphEncoder`` as ``hidden_dim``.
        layers: Forwarded to ``GraphEncoder`` as ``num_layers``.
        dropout: Forwarded to ``GraphEncoder`` as ``dropout``.
        input_dim: Forwarded to ``GraphEncoder`` as ``input_dim``. Defaults to
            32, the value that used to be hard-coded here.
    """

    def __init__(
        self,
        output_dim: int,
        hidden_dim: int,
        layers: int,
        dropout: float,
        input_dim: int = 32,
    ):
        super().__init__()
        self.encoder = GraphEncoder(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            num_layers=layers,
            dropout=dropout,
        )

    def _to_device(self, graph: GraphSample, device: torch.device) -> GraphSample:
        """Return ``graph`` with its tensors on ``device``.

        The graph is returned unchanged when its node features already live on
        ``device``; only that one tensor is inspected.
        """
        if graph.node_features.device == device:
            return graph
        return GraphSample(
            node_features=graph.node_features.to(device),
            edge_index=graph.edge_index.to(device),
            edge_features=graph.edge_features.to(device),
            graph_features=graph.graph_features.to(device),
            node_types=None if graph.node_types is None else graph.node_types.to(device),
            edge_types=None if graph.edge_types is None else graph.edge_types.to(device),
            valid=graph.valid,
        )

    def forward(self, graph: GraphSample | Sequence[GraphSample] | None, device: torch.device) -> torch.Tensor:
        """Encode ``graph`` on ``device``.

        Returns:
            A single zero row when ``graph`` is ``None``; the per-graph
            encoder outputs concatenated along dim 0 when ``graph`` is a
            sequence; otherwise the encoder output for the single graph.
        """
        if graph is None:
            # The placeholder width is read from the encoder's last readout
            # layer; this assumes that layer exposes ``normalized_shape``
            # (e.g. a LayerNorm) -- TODO confirm against GraphEncoder.
            width = self.encoder.readout[-1].normalized_shape[0]
            return torch.zeros((1, width), device=device)
        # NOTE(review): if GraphSample is itself a Sequence (for example a
        # NamedTuple) this branch would also catch single graphs -- confirm.
        if isinstance(graph, Sequence):
            embeddings = [self.encoder(self._to_device(item, device)) for item in graph]
            return torch.cat(embeddings, dim=0)
        return self.encoder(self._to_device(graph, device))
| class AdvancedDDINet(nn.Module): | |
| """Graph-aware multimodal DDI model with selective-prediction outputs.""" | |
| def __init__(self, config: AdvancedModelConfig): | |
| super().__init__() | |
| self.config = config | |
| self.fingerprint_branch = BranchEncoder(config.fingerprint_dim, config.hidden_dim, depth=3, dropout=config.dropout) | |
| self.semantic_branch = BranchEncoder(config.semantic_dim, config.hidden_dim, depth=2, dropout=config.dropout) | |
| self.pharmacology_branch = BranchEncoder(config.pharmacology_dim, config.hidden_dim, depth=2, dropout=config.dropout) | |
| self.pairwise_branch = BranchEncoder(config.pair_dim, config.hidden_dim // 2, depth=1, dropout=config.dropout) | |
| self.molecular_branch = BranchEncoder(config.molecular_dim, config.hidden_dim // 2, depth=1, dropout=config.dropout) | |
| self.graph_dense_proj = nn.Sequential( | |
| nn.LazyLinear(config.graph_output_dim), | |
| nn.GELU(), | |
| nn.Dropout(config.dropout), | |
| nn.LayerNorm(config.graph_output_dim), | |
| ) | |
| self.use_graph_branches = config.use_graph_branches | |
| self.molecular_graph_branch = _GraphBranch(config.graph_output_dim, config.graph_hidden_dim, config.graph_layers, config.dropout) | |
| self.pharmacology_graph_branch = _GraphBranch(config.graph_output_dim, config.graph_hidden_dim, config.graph_layers, config.dropout) | |
| self.interaction_graph_branch = _GraphBranch(config.graph_output_dim, config.graph_hidden_dim, config.graph_layers, config.dropout) | |
| self.graph_bundle_proj = nn.Sequential( | |
| nn.LazyLinear(config.graph_output_dim), | |
| nn.GELU(), | |
| nn.Dropout(config.dropout), | |
| nn.LayerNorm(config.graph_output_dim), | |
| ) | |
| fusion_input_dim = config.hidden_dim * 4 + config.graph_output_dim * 2 | |
| self.fusion_proj = nn.Linear(fusion_input_dim, config.fusion_dim) | |
| self.fusion_attention = GatedAttentionFusion(config.fusion_dim, num_heads=config.num_heads, dropout=config.dropout) | |
| self.residual_fusion = nn.Sequential( | |
| nn.LayerNorm(config.fusion_dim), | |
| nn.Linear(config.fusion_dim, config.fusion_dim), | |
| nn.GELU(), | |
| nn.Dropout(config.dropout), | |
| nn.Linear(config.fusion_dim, config.fusion_dim), | |
| ) | |
| self.confidence_head = nn.Sequential( | |
| nn.LayerNorm(config.fusion_dim), | |
| nn.Linear(config.fusion_dim, config.confidence_dim), | |
| nn.GELU(), | |
| nn.Dropout(config.dropout), | |
| nn.Linear(config.confidence_dim, 1), | |
| ) | |
| self.uncertainty_head = nn.Sequential( | |
| nn.LayerNorm(config.fusion_dim), | |
| nn.Linear(config.fusion_dim, config.confidence_dim), | |
| nn.GELU(), | |
| nn.Dropout(config.dropout), | |
| nn.Linear(config.confidence_dim, 1), | |
| ) | |
| self.head = nn.Sequential( | |
| nn.LayerNorm(config.fusion_dim), | |
| nn.Linear(config.fusion_dim, config.fusion_dim), | |
| nn.GELU(), | |
| nn.Dropout(config.dropout), | |
| nn.Linear(config.fusion_dim, config.num_classes), | |
| ) | |
| def _ensure_batch(self, tensor: torch.Tensor) -> torch.Tensor: | |
| return tensor.unsqueeze(0) if tensor.dim() == 1 else tensor | |
| def _graph_bundle_embedding(self, graph_bundle: Optional[dict[str, Any] | Sequence[dict[str, Any]]], device: torch.device) -> torch.Tensor: | |
| if not self.use_graph_branches or graph_bundle is None: | |
| return torch.zeros((1, self.config.graph_output_dim * 4), device=device) | |
| if isinstance(graph_bundle, Sequence): | |
| return torch.cat([self._graph_bundle_embedding(item, device) for item in graph_bundle], dim=0) | |
| drug_a = graph_bundle.get("drug_a_graph") | |
| drug_b = graph_bundle.get("drug_b_graph") | |
| pharmacology_graph = graph_bundle.get("pharmacology_graph") | |
| interaction_graph = graph_bundle.get("interaction_graph") | |
| interaction_summary = graph_bundle.get("interaction_summary") | |
| emb_a = self.molecular_graph_branch(drug_a, device) | |
| emb_b = self.molecular_graph_branch(drug_b, device) | |
| emb_pharma = self.pharmacology_graph_branch(pharmacology_graph, device) | |
| emb_interaction = self.interaction_graph_branch(interaction_graph, device) | |
| if interaction_summary is None: | |
| summary = torch.zeros((1, self.config.graph_output_dim), device=device) | |
| else: | |
| summary = self.graph_dense_proj(self._ensure_batch(interaction_summary.to(device))) | |
| pair_graph = torch.cat([emb_a, emb_b, torch.abs(emb_a - emb_b), emb_a * emb_b], dim=-1) | |
| bundle = torch.cat([pair_graph, emb_pharma, emb_interaction, summary], dim=-1) | |
| return self.graph_bundle_proj(bundle) | |
| def forward( | |
| self, | |
| fingerprint: torch.Tensor, | |
| semantic: torch.Tensor, | |
| pharmacology: torch.Tensor, | |
| pairwise: torch.Tensor, | |
| molecular: torch.Tensor, | |
| graph_features: Optional[torch.Tensor] = None, | |
| graph_bundle: Optional[dict[str, Any]] = None, | |
| ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: | |
| device = fingerprint.device | |
| fp = self.fingerprint_branch(self._ensure_batch(fingerprint)) | |
| sem = self.semantic_branch(self._ensure_batch(semantic)) | |
| ph = self.pharmacology_branch(self._ensure_batch(pharmacology)) | |
| pair = self.pairwise_branch(self._ensure_batch(pairwise)) | |
| mol = self.molecular_branch(self._ensure_batch(molecular)) | |
| if graph_features is not None: | |
| graph_dense = self.graph_dense_proj(self._ensure_batch(graph_features.to(device))) | |
| else: | |
| graph_dense = torch.zeros((fp.shape[0], self.config.graph_output_dim), device=device) | |
| graph_bundle_embedding = self._graph_bundle_embedding(graph_bundle, device) | |
| if graph_bundle_embedding.shape[0] != fp.shape[0]: | |
| graph_bundle_embedding = graph_bundle_embedding.expand(fp.shape[0], -1) | |
| if graph_dense.shape[0] != fp.shape[0]: | |
| graph_dense = graph_dense.expand(fp.shape[0], -1) | |
| branches = [fp, sem, ph, pair, mol, graph_dense, graph_bundle_embedding] | |
| fused = torch.cat(branches, dim=-1) | |
| fused = self.fusion_proj(fused).unsqueeze(1) | |
| fused, attn_weights = self.fusion_attention(fused) | |
| fused = fused.squeeze(1) | |
| fused = fused + self.residual_fusion(fused) | |
| logits = self.head(fused) | |
| confidence = torch.sigmoid(self.confidence_head(fused)).squeeze(-1) | |
| uncertainty = torch.sigmoid(self.uncertainty_head(fused)).squeeze(-1) | |
| severe_probability = torch.softmax(logits, dim=-1)[..., -1] | |
| return logits, { | |
| "attention": attn_weights, | |
| "fused_embedding": fused, | |
| "confidence_embedding": confidence.unsqueeze(-1), | |
| "uncertainty_embedding": uncertainty.unsqueeze(-1), | |
| "severe_probability": severe_probability.unsqueeze(-1), | |
| } | |