""" This module provides transformer-based models for processing hierarchical VCF data """ import torch import torch.nn as nn import torch.nn.functional as F import math import logging from typing import Dict, List, Tuple, Optional, Union, Any from dataclasses import dataclass from transformers import PreTrainedModel, PretrainedConfig from transformers.modeling_outputs import SequenceClassifierOutput from transformers.utils import ModelOutput from config import ModelConfig, ConfigManager from tokenizer import HierarchicalVCFTokenizer # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @dataclass class HierarchicalVCFOutput(ModelOutput): """ Args: loss: Classification loss (if labels provided) logits: Classification logits hidden_states: Last hidden states attentions: Attention weights from all layers hierarchical_embeddings: Embeddings at each hierarchical level """ loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None hidden_states: Optional[torch.FloatTensor] = None attentions: Optional[Tuple[torch.FloatTensor]] = None hierarchical_embeddings: Optional[Dict[str, torch.FloatTensor]] = None class HierarchicalVCFConfig(PretrainedConfig): model_type = "hierarchical-vcf" def __init__(self, vocab_sizes: Optional[Dict[str, int]] = None, embed_dim: int = 64, transformer_dim: int = 256, nhead: int = 8, num_layers: int = 3, num_classes: int = 2, hidden_dims: List[int] = None, dropout: float = 0.1, activation: str = "gelu", layer_norm_eps: float = 1e-12, max_position_embeddings: int = 1024, use_hierarchical_attention: bool = True, use_positional_encoding: bool = True, attention_probs_dropout_prob: float = 0.1, hidden_dropout_prob: float = 0.1, classifier_dropout: Optional[float] = None, **kwargs): super().__init__(**kwargs) self.vocab_sizes = vocab_sizes or { 'impact': 10, 'ref': 10, 'alt': 10, 'chromosome': 30, 'pathway': 100, 'gene': 1000 } self.embed_dim = embed_dim self.transformer_dim = transformer_dim self.nhead = nhead self.num_layers = num_layers self.num_classes = num_classes self.hidden_dims = hidden_dims or [512, 256] self.dropout = dropout self.activation = activation self.layer_norm_eps = layer_norm_eps self.max_position_embeddings = max_position_embeddings self.use_hierarchical_attention = use_hierarchical_attention self.use_positional_encoding = use_positional_encoding self.attention_probs_dropout_prob = attention_probs_dropout_prob self.hidden_dropout_prob = hidden_dropout_prob self.classifier_dropout = classifier_dropout class PositionalEncoding(nn.Module): def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1): super().__init__() self.dropout = nn.Dropout(p=dropout) pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0).transpose(0, 1) self.register_buffer('pe', pe) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: Tensor of shape [seq_len, batch_size, d_model] """ x = x + self.pe[:x.size(0), :] return self.dropout(x) class MutationEmbedder(nn.Module): def __init__(self, vocab_sizes: Dict[str, int], embed_dim: int = 64, dropout: float = 0.1): super().__init__() self.embed_dim = embed_dim self.mutation_fields = ['impact', 'ref', 'alt'] # Create embedding layers for each field self.embed_layers = nn.ModuleDict({ field: nn.Embedding(vocab_sizes.get(field, 100), embed_dim, padding_idx=0) for field in self.mutation_fields }) # Projection layer to combine embeddings self.mutation_dim = embed_dim * len(self.mutation_fields) self.projection = nn.Linear(self.mutation_dim, embed_dim) self.layer_norm = nn.LayerNorm(embed_dim) self.dropout = nn.Dropout(dropout) def forward(self, mutation_batch: Dict[str, torch.Tensor]) -> torch.Tensor: """ Args: mutation_batch: Dict with tensors for each field Returns: Embedded mutations tensor [batch_size, seq_len, embed_dim] """ embeddings = [] for field in self.mutation_fields: if field in mutation_batch: field_emb = self.embed_layers[field](mutation_batch[field]) embeddings.append(field_emb) if not embeddings: raise ValueError("No valid mutation fields found in input") # Concatenate and project concat_emb = torch.cat(embeddings, dim=-1) projected_emb = self.projection(concat_emb) # Apply layer norm and dropout output = self.layer_norm(projected_emb) output = self.dropout(output) return output class HierarchicalAttention(nn.Module): def __init__(self, d_model: int, nhead: int = 8, dropout: float = 0.1): super().__init__() self.d_model = d_model self.nhead = nhead # Multi-head attention self.multihead_attn = nn.MultiheadAttention( d_model, nhead, dropout=dropout, batch_first=True ) # Attention pooling self.attention_weights = nn.Parameter(torch.randn(d_model)) self.layer_norm = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: x: Input tensor [batch_size, seq_len, d_model] mask: Attention mask [batch_size, seq_len] Returns: Tuple of (pooled_output, attention_weights) """ # Self-attention attn_output, attn_weights = self.multihead_attn(x, x, x, key_padding_mask=mask) attn_output = self.layer_norm(attn_output + x) # Residual connection # Attention pooling scores = torch.matmul(attn_output, self.attention_weights) # [batch_size, seq_len] if mask is not None: scores = scores.masked_fill(mask, float('-inf')) attention_probs = F.softmax(scores, dim=-1) # [batch_size, seq_len] pooled_output = torch.sum(attention_probs.unsqueeze(-1) * attn_output, dim=1) # [batch_size, d_model] pooled_output = self.dropout(pooled_output) return pooled_output, attention_probs class HierarchicalTransformerLayer(nn.Module): def __init__(self, d_model: int, nhead: int = 8, dim_feedforward: int = 2048, dropout: float = 0.1, activation: str = "gelu"): super().__init__() self.hierarchical_attention = HierarchicalAttention(d_model, nhead, dropout) # Feed-forward network self.linear1 = nn.Linear(d_model, dim_feedforward) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) if activation == "gelu": self.activation = F.gelu elif activation == "relu": self.activation = F.relu else: raise ValueError(f"Unsupported activation: {activation}") def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: x: Input tensor [batch_size, seq_len, d_model] mask: Attention mask Returns: Tuple of (output, attention_weights) """ # Hierarchical attention attn_output, attn_weights = self.hierarchical_attention(x, mask) x = self.norm1(x.mean(dim=1) + self.dropout1(attn_output)) # Pool input for residual # Feed-forward ff_output = self.linear2(self.dropout2(self.activation(self.linear1(x)))) x = self.norm2(x + ff_output) return x, attn_weights class HierarchicalVCFModel(PreTrainedModel): """ This model processes VCF data in a hierarchical manner: Mutations -> Genes -> Chromosomes -> Pathways -> Sample """ config_class = HierarchicalVCFConfig def __init__(self, config: HierarchicalVCFConfig): super().__init__(config) self.config = config self.num_classes = config.num_classes # Embedding layers self.mutation_embedder = MutationEmbedder( vocab_sizes=config.vocab_sizes, embed_dim=config.embed_dim, dropout=config.hidden_dropout_prob ) # Positional encoding if config.use_positional_encoding: self.pos_encoder = PositionalEncoding( config.embed_dim, max_len=config.max_position_embeddings, dropout=config.hidden_dropout_prob ) # Hierarchical transformer layers self.transformer_layers = nn.ModuleList([ HierarchicalTransformerLayer( d_model=config.embed_dim, nhead=config.nhead, dim_feedforward=config.transformer_dim, dropout=config.attention_probs_dropout_prob, activation=config.activation ) for _ in range(config.num_layers) ]) # Hierarchical aggregation layers self.gene_aggregator = HierarchicalAttention(config.embed_dim, config.nhead) self.chromosome_aggregator = HierarchicalAttention(config.embed_dim, config.nhead) self.pathway_aggregator = HierarchicalAttention(config.embed_dim, config.nhead) # Classification head classifier_layers = [] input_dim = config.embed_dim for hidden_dim in config.hidden_dims: classifier_layers.extend([ nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Dropout(config.classifier_dropout or config.hidden_dropout_prob) ]) input_dim = hidden_dim classifier_layers.append(nn.Linear(input_dim, config.num_classes)) self.classifier = nn.Sequential(*classifier_layers) # Initialize weights self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) elif isinstance(module, nn.LayerNorm): torch.nn.init.zeros_(module.bias) torch.nn.init.ones_(module.weight) def forward(self, input_data: Dict[str, Any], labels: Optional[torch.Tensor] = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True) -> Union[Tuple, HierarchicalVCFOutput]: """ Args: input_data: Hierarchical input data from data collator labels: Labels for supervised learning output_attentions: Whether to output attention weights output_hidden_states: Whether to output hidden states return_dict: Whether to return ModelOutput object Returns: HierarchicalVCFOutput or tuple of outputs """ batch_samples = input_data['samples'] batch_size = len(batch_samples) sample_embeddings = [] all_attentions = [] if output_attentions else None hierarchical_embeddings = {} if output_hidden_states else None for sample_idx, sample in enumerate(batch_samples): sample_embedding = self._process_sample( sample, output_attentions=output_attentions, output_hidden_states=output_hidden_states ) if output_attentions: sample_embedding, sample_attentions = sample_embedding all_attentions.append(sample_attentions) if output_hidden_states: sample_embedding, sample_hierarchical = sample_embedding for level, emb in sample_hierarchical.items(): if level not in hierarchical_embeddings: hierarchical_embeddings[level] = [] hierarchical_embeddings[level].append(emb) sample_embeddings.append(sample_embedding) # Stack sample embeddings if sample_embeddings: hidden_states = torch.stack(sample_embeddings) # [batch_size, embed_dim] else: hidden_states = torch.zeros(batch_size, self.config.embed_dim, device=self.device) # Classification logits = self.classifier(hidden_states) # Compute loss if labels provided loss = None if labels is not None: if self.config.num_classes == 1: # Regression loss_fct = nn.MSELoss() loss = loss_fct(logits.squeeze(), labels.squeeze()) else: # Classification loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.config.num_classes), labels.view(-1)) if not return_dict: output = (logits,) if output_hidden_states: output = output + (hidden_states,) if output_attentions: output = output + (all_attentions,) if loss is not None: output = (loss,) + output return output return HierarchicalVCFOutput( loss=loss, logits=logits, hidden_states=hidden_states, attentions=all_attentions, hierarchical_embeddings=hierarchical_embeddings ) def _process_sample(self, sample: Dict[str, Any], output_attentions: bool = False, output_hidden_states: bool = False) -> torch.Tensor: """ Process a single hierarchical sample. Args: sample: Single sample from batch output_attentions: Whether to return attention weights output_hidden_states: Whether to return hierarchical embeddings Returns: Sample embedding tensor or tuple with additional outputs """ pathway_embeddings = [] sample_attentions = {} if output_attentions else None sample_hierarchical = {} if output_hidden_states else None for pathway_token, chromosomes in sample.items(): chromosome_embeddings = [] for chrom_token, genes in chromosomes.items(): gene_embeddings = [] for gene_token, mutations in genes.items(): # Process mutations for this gene gene_embedding = self._process_gene_mutations( mutations, output_attentions=output_attentions ) if output_attentions: gene_embedding, gene_attentions = gene_embedding if 'gene_level' not in sample_attentions: sample_attentions['gene_level'] = [] sample_attentions['gene_level'].append(gene_attentions) gene_embeddings.append(gene_embedding) if gene_embeddings: # Aggregate genes to chromosome level gene_tensor = torch.stack(gene_embeddings).unsqueeze(0) # [1, num_genes, embed_dim] chrom_embedding, chrom_attention = self.chromosome_aggregator(gene_tensor) chrom_embedding = chrom_embedding.squeeze(0) # [embed_dim] chromosome_embeddings.append(chrom_embedding) if output_attentions: if 'chromosome_level' not in sample_attentions: sample_attentions['chromosome_level'] = [] sample_attentions['chromosome_level'].append(chrom_attention) if chromosome_embeddings: # Aggregate chromosomes to pathway level chrom_tensor = torch.stack(chromosome_embeddings).unsqueeze(0) # [1, num_chroms, embed_dim] pathway_embedding, pathway_attention = self.pathway_aggregator(chrom_tensor) pathway_embedding = pathway_embedding.squeeze(0) # [embed_dim] pathway_embeddings.append(pathway_embedding) if output_attentions: if 'pathway_level' not in sample_attentions: sample_attentions['pathway_level'] = [] sample_attentions['pathway_level'].append(pathway_attention) if output_hidden_states: sample_hierarchical['pathway_embeddings'] = pathway_embeddings if pathway_embeddings: # Aggregate pathways to sample level pathway_tensor = torch.stack(pathway_embeddings).unsqueeze(0) # [1, num_pathways, embed_dim] sample_embedding, sample_attention = self.gene_aggregator(pathway_tensor) # Reuse gene aggregator sample_embedding = sample_embedding.squeeze(0) # [embed_dim] if output_attentions: sample_attentions['sample_level'] = sample_attention else: # Handle empty sample sample_embedding = torch.zeros(self.config.embed_dim, device=self.device) # Prepare return value result = sample_embedding if output_attentions and output_hidden_states: result = (result, sample_attentions, sample_hierarchical) elif output_attentions: result = (result, sample_attentions) elif output_hidden_states: result = (result, sample_hierarchical) return result def _process_gene_mutations(self, mutations: Dict[str, Any], output_attentions: bool = False) -> torch.Tensor: """ Process mutations for a single gene. Args: mutations: Mutation data for gene output_attentions: Whether to return attention weights Returns: Gene embedding tensor """ # Handle masked format from data collator mutation_tensors = {} attention_mask = None for field in ['impact', 'ref', 'alt']: if field in mutations: if isinstance(mutations[field], dict) and 'tokens' in mutations[field]: # Masked format mutation_tensors[field] = torch.tensor(mutations[field]['tokens'], device=self.device) if attention_mask is None: attention_mask = torch.tensor(mutations[field]['mask'], device=self.device).bool() else: # Direct format mutation_tensors[field] = torch.tensor(mutations[field], device=self.device) if not mutation_tensors: return torch.zeros(self.config.embed_dim, device=self.device) # Embed mutations mutation_embeddings = self.mutation_embedder(mutation_tensors) # [seq_len, embed_dim] # Add positional encoding if enabled if self.config.use_positional_encoding: mutation_embeddings = mutation_embeddings.unsqueeze(1) # [seq_len, 1, embed_dim] mutation_embeddings = self.pos_encoder(mutation_embeddings) mutation_embeddings = mutation_embeddings.squeeze(1) # [seq_len, embed_dim] # Apply transformer layers mutation_embeddings = mutation_embeddings.unsqueeze(0) # [1, seq_len, embed_dim] layer_attentions = [] if output_attentions else None for layer in self.transformer_layers: mutation_embeddings, layer_attention = layer(mutation_embeddings, attention_mask) mutation_embeddings = mutation_embeddings.unsqueeze(1) # Add seq dim back if output_attentions: layer_attentions.append(layer_attention) # Pool to get gene representation if attention_mask is not None: # Masked pooling mask_expanded = attention_mask.unsqueeze(-1).expand_as(mutation_embeddings.squeeze(0)) masked_embeddings = mutation_embeddings.squeeze(0) * mask_expanded.float() gene_embedding = masked_embeddings.sum(dim=0) / mask_expanded.sum(dim=0).clamp(min=1) else: # Simple mean pooling gene_embedding = mutation_embeddings.mean(dim=1).squeeze(0) if output_attentions: return gene_embedding, layer_attentions return gene_embedding @property def device(self) -> torch.device: """Get model device.""" return next(self.parameters()).device def create_model_from_config(config_manager: ConfigManager, tokenizer: HierarchicalVCFTokenizer) -> HierarchicalVCFModel: """ Args: config_manager: Configuration manager tokenizer: Tokenizer instance task_type: Type of task ('classification', 'regression') Returns: Configured model """ model_config = config_manager.model_config # Create Hugging Face config hf_config = HierarchicalVCFConfig( vocab_sizes=tokenizer.get_all_vocab_sizes(), embed_dim=model_config.embed_dim, transformer_dim=model_config.transformer_dim, nhead=model_config.nhead, num_layers=model_config.num_layers, num_classes=model_config.num_classes, hidden_dims=model_config.hidden_dims, dropout=model_config.dropout ) # Create model based on task type model = HierarchicalVCFModel(hf_config) return model # Model utilities class ModelTrainer: """ Training utilities for Hierarchical VCF Model. """ def __init__(self, model: HierarchicalVCFModel, train_dataloader, val_dataloader, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, device: Optional[torch.device] = None): self.model = model self.train_dataloader = train_dataloader self.val_dataloader = val_dataloader self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Move model to device self.model.to(self.device) # Default optimizer if optimizer is None: self.optimizer = torch.optim.AdamW( model.parameters(), lr=1e-4, weight_decay=0.01 ) else: self.optimizer = optimizer self.scheduler = scheduler # Training metrics self.train_losses = [] self.val_losses = [] self.val_accuracies = [] def train_epoch(self) -> float: """Train for one epoch.""" self.model.train() total_loss = 0.0 num_batches = 0 for batch in self.train_dataloader: self.optimizer.zero_grad() # Move data to device if 'labels' in batch: labels = batch['labels'].to(self.device) else: labels = None # Forward pass outputs = self.model(batch, labels=labels) loss = outputs.loss if hasattr(outputs, 'loss') else outputs[0] # Backward pass loss.backward() self.optimizer.step() total_loss += loss.item() num_batches += 1 if self.scheduler: self.scheduler.step() avg_loss = total_loss / max(num_batches, 1) self.train_losses.append(avg_loss) return avg_loss def validate(self) -> Tuple[float, float]: """Validate model.""" self.model.eval() total_loss = 0.0 correct_predictions = 0 total_predictions = 0 num_batches = 0 with torch.no_grad(): for batch in self.val_dataloader: # Move data to device if 'labels' in batch: labels = batch['labels'].to(self.device) else: continue # Skip if no labels # Forward pass outputs = self.model(batch, labels=labels) loss = outputs.loss if hasattr(outputs, 'loss') else outputs[0] logits = outputs.logits if hasattr(outputs, 'logits') else outputs[1] total_loss += loss.item() # Calculate accuracy predictions = torch.argmax(logits, dim=-1) correct_predictions += (predictions == labels).sum().item() total_predictions += labels.size(0) num_batches += 1 avg_loss = total_loss / max(num_batches, 1) accuracy = correct_predictions / max(total_predictions, 1) self.val_losses.append(avg_loss) self.val_accuracies.append(accuracy) return avg_loss, accuracy def train(self, num_epochs: int, save_path: Optional[str] = None) -> Dict[str, List[float]]: """ Train model for specified number of epochs. Args: num_epochs: Number of training epochs save_path: Path to save best model Returns: Training history """ best_val_loss = float('inf') logger.info(f"Starting training for {num_epochs} epochs...") for epoch in range(num_epochs): # Train train_loss = self.train_epoch() # Validate val_loss, val_accuracy = self.validate() logger.info( f"Epoch {epoch+1}/{num_epochs}: " f"Train Loss: {train_loss:.4f}, " f"Val Loss: {val_loss:.4f}, " f"Val Accuracy: {val_accuracy:.4f}" ) # Save best model if save_path and val_loss < best_val_loss: best_val_loss = val_loss self.model.save_pretrained(save_path) logger.info(f"Saved best model to {save_path}") return { 'train_losses': self.train_losses, 'val_losses': self.val_losses, 'val_accuracies': self.val_accuracies } # Example usage and testing if __name__ == "__main__": from tokenizer import create_tokenizer_from_config from dataset import create_data_module_from_config # Create configuration config_manager = ConfigManager() config_manager.model_config.embed_dim = 32 config_manager.model_config.num_classes = 2 # Create tokenizer and model tokenizer = create_tokenizer_from_config(config_manager) # Build vocabulary with example data example_data = { 'sample1': { 'pathway1': { 'chr1': { 'gene1': [ {'impact': 'HIGH', 'reference': 'A', 'alternate': 'T'} ] } } } } tokenizer.build_vocabulary(example_data) # Create model model = create_model_from_config(config_manager, tokenizer) print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters") print(f"Model config: {model.config}") # Test forward pass with dummy data dummy_batch = { 'samples': [example_data['sample1']], 'batch_size': 1 } with torch.no_grad(): outputs = model(dummy_batch) print(f"Output logits shape: {outputs.logits.shape}") print(f"Output logits: {outputs.logits}")