# ddi/src/training/advanced_ddi_model.py
# github-actions[bot]
# Deploy from GitHub Actions (fb28c05c54cf19184fc3f14f1bf3297ba5749ea2)
# d29b763
"""Graph-aware multimodal DDI foundation model.
This module preserves the historical AdvancedDDINet import path while
replacing the old tabular fusion net with a graph-based, clinically safer
representation learner. The model supports:
- molecular graph encoders for each drug
- pharmacology and interaction graph encoders
- residual MLP branches for legacy features
- gated multi-head fusion with uncertainty and confidence heads
The implementation remains CPU-compatible and can run without PyTorch
Geometric by using the native graph encoder defined in
training.graph_representations.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Optional, Sequence
import torch
import torch.nn as nn
from .graph_representations import GraphEncoder, GraphSample
class ResidualBlock(nn.Module):
    """Pre-norm residual MLP block computing ``x + Dropout(MLP(LayerNorm(x)))``.

    The inner MLP widens the features to ``2 * dim``, applies GELU and
    dropout, then projects back to ``dim`` so the residual sum is shape-safe.

    Args:
        dim: Feature width of the input and output.
        dropout: Drop probability used inside the MLP and on its output.
    """

    def __init__(self, dim: int, dropout: float = 0.2):
        super().__init__()
        widened = dim * 2
        self.norm = nn.LayerNorm(dim)
        self.block = nn.Sequential(
            nn.Linear(dim, widened),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(widened, dim),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the dropout-regularised MLP update of ``norm(x)``."""
        update = self.block(self.norm(x))
        return x + self.dropout(update)
class BranchEncoder(nn.Module):
    """Encode one feature modality into a ``hidden_dim`` embedding.

    The input is first projected (Linear -> LayerNorm -> GELU -> Dropout) and
    then refined by ``depth`` stacked ``ResidualBlock`` modules.

    Args:
        input_dim: Width of the incoming feature vector.
        hidden_dim: Width of the produced embedding.
        depth: Number of residual blocks applied after the projection.
        dropout: Drop probability for the projection and residual blocks.
    """

    def __init__(self, input_dim: int, hidden_dim: int, depth: int = 2, dropout: float = 0.2):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        residual_stack = [ResidualBlock(hidden_dim, dropout=dropout) for _ in range(depth)]
        self.blocks = nn.ModuleList(residual_stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and pass it through every residual block in order."""
        hidden = self.proj(x)
        for residual in self.blocks:
            hidden = residual(hidden)
        return hidden
class GatedAttentionFusion(nn.Module):
    """Gated pre-norm self-attention followed by a pre-norm feed-forward residual.

    The self-attention output is multiplied element-wise by a sigmoid gate
    computed from the same normalised input before being added back to the
    residual stream. Inputs follow ``batch_first=True`` layout, i.e.
    ``(batch, seq, dim)``.

    Args:
        dim: Embedding width (must be divisible by ``num_heads``).
        num_heads: Number of attention heads.
        dropout: Drop probability for attention, FFN and residual updates.
    """

    def __init__(self, dim: int, num_heads: int = 4, dropout: float = 0.2):
        super().__init__()
        # NOTE(review): this single LayerNorm is applied both before attention
        # and before the FFN; a separate norm per sublayer is more common, but
        # introducing one would change the state_dict layout.
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads=num_heads, batch_first=True, dropout=dropout)
        self.gate = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
            nn.Sigmoid(),
        )
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 2, dim),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the fused sequence and the (head-averaged) attention weights."""
        normed = self.norm(x)
        context, weights = self.attn(normed, normed, normed, need_weights=True)
        x = x + self.dropout(self.gate(normed) * context)
        ffn_out = self.ffn(self.norm(x))
        return x + self.dropout(ffn_out), weights
@dataclass(frozen=True)
class AdvancedModelConfig:
    """Immutable hyper-parameters for ``AdvancedDDINet``.

    Field order is part of the public interface (positional construction),
    so new fields must only ever be appended.

    Raises:
        ValueError: from ``__post_init__`` when a value would otherwise make
            model construction or the forward pass fail (see that method).
    """

    # Input widths of the five dense feature branches.
    fingerprint_dim: int = 8192
    semantic_dim: int = 3072
    pharmacology_dim: int = 512
    pair_dim: int = 64
    molecular_dim: int = 41
    # Width of the fingerprint/semantic/pharmacology embeddings; the pairwise
    # and molecular branches use hidden_dim // 2.
    hidden_dim: int = 256
    # Width of the fused representation; fed to MultiheadAttention, so it must
    # be divisible by num_heads.
    fusion_dim: int = 256
    num_classes: int = 4
    dropout: float = 0.25
    num_heads: int = 4
    # NOTE(review): not read anywhere in this module (the graph encoders use a
    # fixed node-feature width) — confirm whether it is consumed elsewhere.
    graph_input_dim: int = 0
    graph_hidden_dim: int = 128
    graph_output_dim: int = 128
    graph_layers: int = 3
    use_graph_branches: bool = True
    # Hidden width of the confidence and uncertainty heads.
    confidence_dim: int = 64

    def __post_init__(self) -> None:
        """Reject values that would only fail later with an opaque error.

        Only constraints the model demonstrably relies on are checked:
        attention needs ``fusion_dim`` divisible by a positive ``num_heads``,
        ``nn.Dropout`` needs a probability in ``[0, 1]``, and the
        severe-probability output indexes the last class, so at least one
        class is required.
        """
        for name in ("fusion_dim", "num_heads", "num_classes"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")
        if self.fusion_dim % self.num_heads:
            raise ValueError(
                f"fusion_dim ({self.fusion_dim}) must be divisible by num_heads ({self.num_heads})"
            )
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {self.dropout}")
class _GraphBranch(nn.Module):
    """Wrap a ``GraphEncoder`` so it accepts one graph, a sequence of graphs, or ``None``.

    Graphs are moved to the requested device before encoding. A missing graph
    (``None`` at top level or as a sequence entry) becomes a single all-zero
    row of width ``output_dim`` so downstream concatenation keeps a fixed
    feature width.

    Args:
        output_dim: Embedding width produced by the encoder.
        hidden_dim: Hidden width passed to ``GraphEncoder``.
        layers: Number of encoder layers passed to ``GraphEncoder``.
        dropout: Dropout passed to ``GraphEncoder``.
    """

    def __init__(self, output_dim: int, hidden_dim: int, layers: int, dropout: float):
        super().__init__()
        # Kept so the zero placeholder does not depend on GraphEncoder internals.
        self.output_dim = output_dim
        # NOTE(review): node-feature width is fixed at 32 here — confirm it
        # matches the featurizer in training.graph_representations.
        self.encoder = GraphEncoder(input_dim=32, hidden_dim=hidden_dim, output_dim=output_dim, num_layers=layers, dropout=dropout)

    def _empty(self, device: torch.device) -> torch.Tensor:
        """Return the ``(1, output_dim)`` zero embedding used for a missing graph."""
        return torch.zeros((1, self.output_dim), device=device)

    def _to_device(self, graph: GraphSample, device: torch.device) -> GraphSample:
        """Return ``graph`` on ``device``, copying its tensors only when needed.

        Only ``node_features`` is inspected to decide whether a move is needed.
        """
        if graph.node_features.device == device:
            return graph
        return GraphSample(
            node_features=graph.node_features.to(device),
            edge_index=graph.edge_index.to(device),
            edge_features=graph.edge_features.to(device),
            graph_features=graph.graph_features.to(device),
            node_types=None if graph.node_types is None else graph.node_types.to(device),
            edge_types=None if graph.edge_types is None else graph.edge_types.to(device),
            valid=graph.valid,
        )

    def forward(self, graph: GraphSample | Sequence[GraphSample | None] | None, device: torch.device) -> torch.Tensor:
        """Encode ``graph`` into embeddings stacked along dim 0.

        A sequence yields one encoder output per entry (``None`` entries give a
        zero row), concatenated along the batch dimension.
        """
        if graph is None:
            return self._empty(device)
        # NOTE(review): assumes GraphSample is not itself a Sequence (e.g. not a
        # NamedTuple); otherwise a single graph would be iterated field by field.
        if isinstance(graph, Sequence):
            embeddings = [
                self._empty(device) if item is None else self.encoder(self._to_device(item, device))
                for item in graph
            ]
            return torch.cat(embeddings, dim=0)
        return self.encoder(self._to_device(graph, device))
class AdvancedDDINet(nn.Module):
    """Graph-aware multimodal DDI model with selective-prediction outputs.

    Seven embeddings are concatenated and fused:

    * fingerprint, semantic and pharmacology features (``hidden_dim`` each),
    * pairwise and molecular features (``hidden_dim // 2`` each),
    * a dense graph-feature projection (``graph_output_dim``),
    * a graph-bundle embedding built from two molecular graphs, a
      pharmacology graph, an interaction graph and an optional interaction
      summary vector (``graph_output_dim``).

    The concatenation is projected to ``fusion_dim``, passed through gated
    attention and a residual MLP, then decoded by a classification head and
    two sigmoid-activated scalar heads (confidence and uncertainty).
    """

    def __init__(self, config: AdvancedModelConfig):
        super().__init__()
        self.config = config
        dropout = config.dropout
        half_hidden = config.hidden_dim // 2
        self.fingerprint_branch = BranchEncoder(config.fingerprint_dim, config.hidden_dim, depth=3, dropout=dropout)
        self.semantic_branch = BranchEncoder(config.semantic_dim, config.hidden_dim, depth=2, dropout=dropout)
        self.pharmacology_branch = BranchEncoder(config.pharmacology_dim, config.hidden_dim, depth=2, dropout=dropout)
        self.pairwise_branch = BranchEncoder(config.pair_dim, half_hidden, depth=1, dropout=dropout)
        self.molecular_branch = BranchEncoder(config.molecular_dim, half_hidden, depth=1, dropout=dropout)
        # LazyLinear infers its input width on first use. NOTE(review): this one
        # projection serves both ``graph_features`` (forward) and
        # ``interaction_summary`` (graph bundle), so both must share a width.
        self.graph_dense_proj = nn.Sequential(
            nn.LazyLinear(config.graph_output_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(config.graph_output_dim),
        )
        self.use_graph_branches = config.use_graph_branches
        self.molecular_graph_branch = _GraphBranch(config.graph_output_dim, config.graph_hidden_dim, config.graph_layers, dropout)
        self.pharmacology_graph_branch = _GraphBranch(config.graph_output_dim, config.graph_hidden_dim, config.graph_layers, dropout)
        self.interaction_graph_branch = _GraphBranch(config.graph_output_dim, config.graph_hidden_dim, config.graph_layers, dropout)
        # Maps the concatenated graph bundle down to graph_output_dim.
        self.graph_bundle_proj = nn.Sequential(
            nn.LazyLinear(config.graph_output_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(config.graph_output_dim),
        )
        # Width of the concatenation built in forward(), summed per branch so an
        # odd hidden_dim (where 2 * (hidden_dim // 2) != hidden_dim) still matches.
        fusion_input_dim = config.hidden_dim * 3 + half_hidden * 2 + config.graph_output_dim * 2
        self.fusion_proj = nn.Linear(fusion_input_dim, config.fusion_dim)
        self.fusion_attention = GatedAttentionFusion(config.fusion_dim, num_heads=config.num_heads, dropout=dropout)
        self.residual_fusion = nn.Sequential(
            nn.LayerNorm(config.fusion_dim),
            nn.Linear(config.fusion_dim, config.fusion_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(config.fusion_dim, config.fusion_dim),
        )
        self.confidence_head = nn.Sequential(
            nn.LayerNorm(config.fusion_dim),
            nn.Linear(config.fusion_dim, config.confidence_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(config.confidence_dim, 1),
        )
        self.uncertainty_head = nn.Sequential(
            nn.LayerNorm(config.fusion_dim),
            nn.Linear(config.fusion_dim, config.confidence_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(config.confidence_dim, 1),
        )
        self.head = nn.Sequential(
            nn.LayerNorm(config.fusion_dim),
            nn.Linear(config.fusion_dim, config.fusion_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(config.fusion_dim, config.num_classes),
        )

    def _ensure_batch(self, tensor: torch.Tensor) -> torch.Tensor:
        """Add a leading batch dimension to 1-D tensors; leave others unchanged."""
        return tensor.unsqueeze(0) if tensor.dim() == 1 else tensor

    @staticmethod
    def _align_batch(tensors: Sequence[torch.Tensor]) -> list[torch.Tensor]:
        """Expand batch-size-1 tensors to the largest leading dim among ``tensors``.

        Tensors whose leading dim is neither 1 nor the target are returned
        unchanged, so ``torch.cat`` still reports a genuine mismatch.
        """
        target = max(t.shape[0] for t in tensors)
        return [t.expand(target, *t.shape[1:]) if t.shape[0] == 1 and target != 1 else t for t in tensors]

    def _graph_bundle_embedding(self, graph_bundle: Optional[dict[str, Any] | Sequence[dict[str, Any]]], device: torch.device) -> torch.Tensor:
        """Embed a graph bundle (or a sequence of bundles) to ``graph_output_dim``.

        Recognised keys: ``drug_a_graph``, ``drug_b_graph``,
        ``pharmacology_graph``, ``interaction_graph`` and
        ``interaction_summary``; any may be missing. A sequence of bundles is
        embedded item by item and concatenated along dim 0.

        Returns:
            Tensor of shape ``(n, graph_output_dim)``; a single zero row when
            graph branches are disabled or no bundle is given.
        """
        if not self.use_graph_branches or graph_bundle is None:
            # Same width as graph_bundle_proj's output, which fusion_proj expects.
            return torch.zeros((1, self.config.graph_output_dim), device=device)
        if isinstance(graph_bundle, Sequence):
            return torch.cat([self._graph_bundle_embedding(item, device) for item in graph_bundle], dim=0)
        emb_a = self.molecular_graph_branch(graph_bundle.get("drug_a_graph"), device)
        emb_b = self.molecular_graph_branch(graph_bundle.get("drug_b_graph"), device)
        emb_pharma = self.pharmacology_graph_branch(graph_bundle.get("pharmacology_graph"), device)
        emb_interaction = self.interaction_graph_branch(graph_bundle.get("interaction_graph"), device)
        interaction_summary = graph_bundle.get("interaction_summary")
        if interaction_summary is None:
            summary = torch.zeros((1, self.config.graph_output_dim), device=device)
        else:
            summary = self.graph_dense_proj(self._ensure_batch(interaction_summary.to(device)))
        # Missing graphs yield single zero rows; broadcast them to the batch of
        # the graphs that are present so the concatenations below line up.
        emb_a, emb_b, emb_pharma, emb_interaction, summary = self._align_batch(
            [emb_a, emb_b, emb_pharma, emb_interaction, summary]
        )
        # Symmetric pair features: both embeddings, |difference| and product.
        pair_graph = torch.cat([emb_a, emb_b, torch.abs(emb_a - emb_b), emb_a * emb_b], dim=-1)
        bundle = torch.cat([pair_graph, emb_pharma, emb_interaction, summary], dim=-1)
        return self.graph_bundle_proj(bundle)

    def forward(
        self,
        fingerprint: torch.Tensor,
        semantic: torch.Tensor,
        pharmacology: torch.Tensor,
        pairwise: torch.Tensor,
        molecular: torch.Tensor,
        graph_features: Optional[torch.Tensor] = None,
        graph_bundle: Optional[dict[str, Any] | Sequence[dict[str, Any]]] = None,
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Run the multimodal model.

        Feature tensors may be 1-D (one sample) or batched with a leading batch
        dimension; 1-D inputs gain a batch dim of 1. Graph inputs whose batch
        dimension is 1 are expanded to the fingerprint batch size.

        Returns:
            ``(logits, extras)`` where ``extras`` contains ``attention`` (fusion
            attention weights), ``fused_embedding``, ``confidence_embedding``
            and ``uncertainty_embedding`` (sigmoid scores, trailing dim 1) and
            ``severe_probability`` (softmax probability of the last class,
            trailing dim 1). NOTE(review): the last class is assumed to be the
            most severe — confirm against the label encoding.
        """
        device = fingerprint.device
        fp = self.fingerprint_branch(self._ensure_batch(fingerprint))
        sem = self.semantic_branch(self._ensure_batch(semantic))
        ph = self.pharmacology_branch(self._ensure_batch(pharmacology))
        pair = self.pairwise_branch(self._ensure_batch(pairwise))
        mol = self.molecular_branch(self._ensure_batch(molecular))
        batch = fp.shape[0]
        if graph_features is not None:
            graph_dense = self.graph_dense_proj(self._ensure_batch(graph_features.to(device)))
        else:
            graph_dense = torch.zeros((batch, self.config.graph_output_dim), device=device)
        graph_bundle_embedding = self._graph_bundle_embedding(graph_bundle, device)
        if graph_bundle_embedding.shape[0] != batch:
            graph_bundle_embedding = graph_bundle_embedding.expand(batch, -1)
        if graph_dense.shape[0] != batch:
            graph_dense = graph_dense.expand(batch, -1)
        branches = [fp, sem, ph, pair, mol, graph_dense, graph_bundle_embedding]
        fused = torch.cat(branches, dim=-1)
        # The fused vector is treated as a length-1 sequence, so attention has a
        # single key per query (weights are all ones).
        fused = self.fusion_proj(fused).unsqueeze(1)
        fused, attn_weights = self.fusion_attention(fused)
        fused = fused.squeeze(1)
        fused = fused + self.residual_fusion(fused)
        logits = self.head(fused)
        confidence = torch.sigmoid(self.confidence_head(fused)).squeeze(-1)
        uncertainty = torch.sigmoid(self.uncertainty_head(fused)).squeeze(-1)
        severe_probability = torch.softmax(logits, dim=-1)[..., -1]
        return logits, {
            "attention": attn_weights,
            "fused_embedding": fused,
            "confidence_embedding": confidence.unsqueeze(-1),
            "uncertainty_embedding": uncertainty.unsqueeze(-1),
            "severe_probability": severe_probability.unsqueeze(-1),
        }