|
|
""" |
|
|
This module provides transformer-based models for processing hierarchical VCF data |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
import logging |
|
|
from typing import Dict, List, Tuple, Optional, Union, Any |
|
|
from dataclasses import dataclass |
|
|
|
|
|
from transformers import PreTrainedModel, PretrainedConfig |
|
|
from transformers.modeling_outputs import SequenceClassifierOutput |
|
|
from transformers.utils import ModelOutput |
|
|
|
|
|
from config import ModelConfig, ConfigManager |
|
|
from tokenizer import HierarchicalVCFTokenizer |
|
|
|
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class HierarchicalVCFOutput(ModelOutput): |
|
|
""" |
|
|
Args: |
|
|
        loss: Classification or regression loss (if labels provided)
|
|
logits: Classification logits |
|
|
        hidden_states: Pooled per-sample embeddings [batch_size, embed_dim]
|
|
attentions: Attention weights from all layers |
|
|
hierarchical_embeddings: Embeddings at each hierarchical level |
|
|
""" |
|
|
|
|
|
loss: Optional[torch.FloatTensor] = None |
|
|
logits: torch.FloatTensor = None |
|
|
hidden_states: Optional[torch.FloatTensor] = None |
|
|
attentions: Optional[Tuple[torch.FloatTensor]] = None |
|
|
hierarchical_embeddings: Optional[Dict[str, torch.FloatTensor]] = None |
|
|
|
|
|
|
|
|
class HierarchicalVCFConfig(PretrainedConfig):
    """Configuration for HierarchicalVCFModel; see __init__ for the available options."""
|
|
|
|
|
model_type = "hierarchical-vcf" |
|
|
|
|
|
def __init__(self, |
|
|
vocab_sizes: Optional[Dict[str, int]] = None, |
|
|
embed_dim: int = 64, |
|
|
transformer_dim: int = 256, |
|
|
nhead: int = 8, |
|
|
num_layers: int = 3, |
|
|
num_classes: int = 2, |
|
|
                 hidden_dims: Optional[List[int]] = None,
|
|
dropout: float = 0.1, |
|
|
activation: str = "gelu", |
|
|
layer_norm_eps: float = 1e-12, |
|
|
max_position_embeddings: int = 1024, |
|
|
use_hierarchical_attention: bool = True, |
|
|
use_positional_encoding: bool = True, |
|
|
attention_probs_dropout_prob: float = 0.1, |
|
|
hidden_dropout_prob: float = 0.1, |
|
|
classifier_dropout: Optional[float] = None, |
|
|
**kwargs): |
|
|
|
|
|
super().__init__(**kwargs) |
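        # The defaults below are only placeholders; in practice vocab_sizes comes
        # from HierarchicalVCFTokenizer.get_all_vocab_sizes() (see
        # create_model_from_config).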
|
|
|
|
|
self.vocab_sizes = vocab_sizes or { |
|
|
'impact': 10, 'ref': 10, 'alt': 10, |
|
|
'chromosome': 30, 'pathway': 100, 'gene': 1000 |
|
|
} |
|
|
self.embed_dim = embed_dim |
|
|
self.transformer_dim = transformer_dim |
|
|
self.nhead = nhead |
|
|
self.num_layers = num_layers |
|
|
self.num_classes = num_classes |
|
|
self.hidden_dims = hidden_dims or [512, 256] |
|
|
self.dropout = dropout |
|
|
self.activation = activation |
|
|
self.layer_norm_eps = layer_norm_eps |
|
|
self.max_position_embeddings = max_position_embeddings |
|
|
self.use_hierarchical_attention = use_hierarchical_attention |
|
|
self.use_positional_encoding = use_positional_encoding |
|
|
self.attention_probs_dropout_prob = attention_probs_dropout_prob |
|
|
self.hidden_dropout_prob = hidden_dropout_prob |
|
|
self.classifier_dropout = classifier_dropout |
|
|
|
|
|
|
|
|
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding with dropout, added to the input embeddings."""
|
|
|
|
|
def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1): |
|
|
super().__init__() |
|
|
self.dropout = nn.Dropout(p=dropout) |
|
|
|
|
|
pe = torch.zeros(max_len, d_model) |
|
|
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) |
|
|
div_term = torch.exp(torch.arange(0, d_model, 2).float() * |
|
|
(-math.log(10000.0) / d_model)) |
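        # Standard sinusoidal encoding:
        #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))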
|
|
|
|
|
pe[:, 0::2] = torch.sin(position * div_term) |
|
|
pe[:, 1::2] = torch.cos(position * div_term) |
|
|
pe = pe.unsqueeze(0).transpose(0, 1) |
|
|
|
|
|
self.register_buffer('pe', pe) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
x: Tensor of shape [seq_len, batch_size, d_model] |
|
|
""" |
|
|
x = x + self.pe[:x.size(0), :] |
|
|
return self.dropout(x) |
|
|
|
|
|
|
|
|
class MutationEmbedder(nn.Module):
    """Embeds per-mutation fields (impact, ref, alt) and projects them into a shared space."""
|
|
|
|
|
def __init__(self, vocab_sizes: Dict[str, int], embed_dim: int = 64, dropout: float = 0.1): |
|
|
super().__init__() |
|
|
|
|
|
self.embed_dim = embed_dim |
|
|
self.mutation_fields = ['impact', 'ref', 'alt'] |
|
|
|
|
|
|
|
|
self.embed_layers = nn.ModuleDict({ |
|
|
field: nn.Embedding(vocab_sizes.get(field, 100), embed_dim, padding_idx=0) |
|
|
for field in self.mutation_fields |
|
|
}) |
|
|
|
|
|
|
|
|
self.mutation_dim = embed_dim * len(self.mutation_fields) |
|
|
self.projection = nn.Linear(self.mutation_dim, embed_dim) |
|
|
self.layer_norm = nn.LayerNorm(embed_dim) |
|
|
self.dropout = nn.Dropout(dropout) |
|
|
|
|
|
def forward(self, mutation_batch: Dict[str, torch.Tensor]) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
mutation_batch: Dict with tensors for each field |
|
|
|
|
|
Returns: |
|
|
Embedded mutations tensor [batch_size, seq_len, embed_dim] |
|
|
""" |
|
|
embeddings = [] |
|
|
|
|
|
for field in self.mutation_fields: |
|
|
if field in mutation_batch: |
|
|
field_emb = self.embed_layers[field](mutation_batch[field]) |
|
|
embeddings.append(field_emb) |
|
|
|
|
|
if not embeddings: |
|
|
raise ValueError("No valid mutation fields found in input") |
|
|
|
|
|
|
|
|
concat_emb = torch.cat(embeddings, dim=-1) |
|
|
projected_emb = self.projection(concat_emb) |
|
|
|
|
|
|
|
|
output = self.layer_norm(projected_emb) |
|
|
output = self.dropout(output) |
|
|
|
|
|
return output |
|
|
|
|
|
|
|
|
class HierarchicalAttention(nn.Module):
    """Self-attention followed by learned attention pooling over the sequence dimension."""
|
|
|
|
|
def __init__(self, d_model: int, nhead: int = 8, dropout: float = 0.1): |
|
|
super().__init__() |
|
|
|
|
|
self.d_model = d_model |
|
|
self.nhead = nhead |
|
|
|
|
|
|
|
|
self.multihead_attn = nn.MultiheadAttention( |
|
|
d_model, nhead, dropout=dropout, batch_first=True |
|
|
) |
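        # Learned scoring vector used in forward() to pool the self-attended
        # sequence into a single embedding via a softmax-weighted sum.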
|
|
|
|
|
|
|
|
self.attention_weights = nn.Parameter(torch.randn(d_model)) |
|
|
self.layer_norm = nn.LayerNorm(d_model) |
|
|
self.dropout = nn.Dropout(dropout) |
|
|
|
|
|
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
""" |
|
|
Args: |
|
|
x: Input tensor [batch_size, seq_len, d_model] |
|
|
            mask: Key padding mask [batch_size, seq_len]; True marks padded positions
|
|
Returns: |
|
|
Tuple of (pooled_output, attention_weights) |
|
|
""" |
|
|
|
|
|
attn_output, attn_weights = self.multihead_attn(x, x, x, key_padding_mask=mask) |
|
|
attn_output = self.layer_norm(attn_output + x) |
|
|
|
|
|
|
|
|
scores = torch.matmul(attn_output, self.attention_weights) |
|
|
|
|
|
if mask is not None: |
|
|
scores = scores.masked_fill(mask, float('-inf')) |
|
|
|
|
|
attention_probs = F.softmax(scores, dim=-1) |
|
|
pooled_output = torch.sum(attention_probs.unsqueeze(-1) * attn_output, dim=1) |
|
|
|
|
|
pooled_output = self.dropout(pooled_output) |
|
|
|
|
|
return pooled_output, attention_probs |
|
|
|
|
|
|
|
|
class HierarchicalTransformerLayer(nn.Module):
    """Transformer block whose attention stage pools the sequence into a single vector."""
|
|
|
|
|
def __init__(self, d_model: int, nhead: int = 8, dim_feedforward: int = 2048, |
|
|
dropout: float = 0.1, activation: str = "gelu"): |
|
|
super().__init__() |
|
|
|
|
|
self.hierarchical_attention = HierarchicalAttention(d_model, nhead, dropout) |
|
|
|
|
|
|
|
|
self.linear1 = nn.Linear(d_model, dim_feedforward) |
|
|
self.linear2 = nn.Linear(dim_feedforward, d_model) |
|
|
self.norm1 = nn.LayerNorm(d_model) |
|
|
self.norm2 = nn.LayerNorm(d_model) |
|
|
self.dropout1 = nn.Dropout(dropout) |
|
|
self.dropout2 = nn.Dropout(dropout) |
|
|
|
|
|
if activation == "gelu": |
|
|
self.activation = F.gelu |
|
|
elif activation == "relu": |
|
|
self.activation = F.relu |
|
|
else: |
|
|
raise ValueError(f"Unsupported activation: {activation}") |
|
|
|
|
|
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
""" |
|
|
Args: |
|
|
x: Input tensor [batch_size, seq_len, d_model] |
|
|
mask: Attention mask |
|
|
Returns: |
|
|
Tuple of (output, attention_weights) |
|
|
""" |
|
|
|
|
|
attn_output, attn_weights = self.hierarchical_attention(x, mask) |
|
|
x = self.norm1(x.mean(dim=1) + self.dropout1(attn_output)) |
|
|
|
|
|
|
|
|
ff_output = self.linear2(self.dropout2(self.activation(self.linear1(x)))) |
|
|
x = self.norm2(x + ff_output) |
|
|
|
|
|
return x, attn_weights |
|
|
|
|
|
|
|
|
class HierarchicalVCFModel(PreTrainedModel): |
|
|
""" |
|
|
This model processes VCF data in a hierarchical manner: |
|
|
Mutations -> Genes -> Chromosomes -> Pathways -> Sample |
|
|
""" |
|
|
|
|
|
config_class = HierarchicalVCFConfig |
|
|
|
|
|
def __init__(self, config: HierarchicalVCFConfig): |
|
|
super().__init__(config) |
|
|
|
|
|
self.config = config |
|
|
self.num_classes = config.num_classes |
|
|
|
|
|
|
|
|
self.mutation_embedder = MutationEmbedder( |
|
|
vocab_sizes=config.vocab_sizes, |
|
|
embed_dim=config.embed_dim, |
|
|
dropout=config.hidden_dropout_prob |
|
|
) |
|
|
|
|
|
|
|
|
if config.use_positional_encoding: |
|
|
self.pos_encoder = PositionalEncoding( |
|
|
config.embed_dim, |
|
|
max_len=config.max_position_embeddings, |
|
|
dropout=config.hidden_dropout_prob |
|
|
) |
|
|
|
|
|
|
|
|
self.transformer_layers = nn.ModuleList([ |
|
|
HierarchicalTransformerLayer( |
|
|
d_model=config.embed_dim, |
|
|
nhead=config.nhead, |
|
|
dim_feedforward=config.transformer_dim, |
|
|
dropout=config.attention_probs_dropout_prob, |
|
|
activation=config.activation |
|
|
) |
|
|
for _ in range(config.num_layers) |
|
|
]) |
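        # Attention-based pooling, one module per hierarchy level:
        # genes -> chromosome, chromosomes -> pathway, pathways -> sample.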
|
|
|
|
|
|
|
|
        self.sample_aggregator = HierarchicalAttention(config.embed_dim, config.nhead)
|
|
self.chromosome_aggregator = HierarchicalAttention(config.embed_dim, config.nhead) |
|
|
self.pathway_aggregator = HierarchicalAttention(config.embed_dim, config.nhead) |
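        # Classification head: a small MLP over the pooled sample embedding.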
|
|
|
|
|
|
|
|
classifier_layers = [] |
|
|
input_dim = config.embed_dim |
|
|
|
|
|
for hidden_dim in config.hidden_dims: |
|
|
classifier_layers.extend([ |
|
|
nn.Linear(input_dim, hidden_dim), |
|
|
nn.ReLU(), |
|
|
                nn.Dropout(
                    config.classifier_dropout
                    if config.classifier_dropout is not None
                    else config.hidden_dropout_prob
                )
|
|
]) |
|
|
input_dim = hidden_dim |
|
|
|
|
|
classifier_layers.append(nn.Linear(input_dim, config.num_classes)) |
|
|
|
|
|
self.classifier = nn.Sequential(*classifier_layers) |
|
|
|
|
|
|
|
|
self.apply(self._init_weights) |
|
|
|
|
|
    def _init_weights(self, module):
        """Initialize weights: normal(0, 0.02) for Linear/Embedding, identity transform for LayerNorm."""
|
|
|
|
|
if isinstance(module, nn.Linear): |
|
|
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) |
|
|
if module.bias is not None: |
|
|
torch.nn.init.zeros_(module.bias) |
|
|
elif isinstance(module, nn.Embedding): |
|
|
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) |
|
|
elif isinstance(module, nn.LayerNorm): |
|
|
torch.nn.init.zeros_(module.bias) |
|
|
torch.nn.init.ones_(module.weight) |
|
|
|
|
|
def forward(self, |
|
|
input_data: Dict[str, Any], |
|
|
labels: Optional[torch.Tensor] = None, |
|
|
output_attentions: bool = False, |
|
|
output_hidden_states: bool = False, |
|
|
return_dict: bool = True) -> Union[Tuple, HierarchicalVCFOutput]: |
|
|
""" |
|
|
Args: |
|
|
input_data: Hierarchical input data from data collator |
|
|
labels: Labels for supervised learning |
|
|
output_attentions: Whether to output attention weights |
|
|
output_hidden_states: Whether to output hidden states |
|
|
return_dict: Whether to return ModelOutput object |
|
|
Returns: |
|
|
HierarchicalVCFOutput or tuple of outputs |
|
|
""" |
|
|
|
|
|
batch_samples = input_data['samples'] |
|
|
batch_size = len(batch_samples) |
|
|
|
|
|
sample_embeddings = [] |
|
|
all_attentions = [] if output_attentions else None |
|
|
hierarchical_embeddings = {} if output_hidden_states else None |
|
|
|
|
|
        for sample in batch_samples:
|
|
sample_embedding = self._process_sample( |
|
|
sample, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states |
|
|
) |
|
|
|
|
|
            # _process_sample returns a tuple when extra outputs are requested;
            # unpack it according to which flags are set.
            if output_attentions and output_hidden_states:
                sample_embedding, sample_attentions, sample_hierarchical = sample_embedding
            elif output_attentions:
                sample_embedding, sample_attentions = sample_embedding
            elif output_hidden_states:
                sample_embedding, sample_hierarchical = sample_embedding

            if output_attentions:
                all_attentions.append(sample_attentions)

            if output_hidden_states:
                for level, emb in sample_hierarchical.items():
                    hierarchical_embeddings.setdefault(level, []).append(emb)
|
|
|
|
|
sample_embeddings.append(sample_embedding) |
|
|
|
|
|
|
|
|
if sample_embeddings: |
|
|
hidden_states = torch.stack(sample_embeddings) |
|
|
else: |
|
|
hidden_states = torch.zeros(batch_size, self.config.embed_dim, device=self.device) |
|
|
|
|
|
|
|
|
logits = self.classifier(hidden_states) |
|
|
|
|
|
|
|
|
loss = None |
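        # A single-output head is treated as regression (MSE); otherwise the
        # logits are scored with standard cross-entropy classification.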
|
|
if labels is not None: |
|
|
if self.config.num_classes == 1: |
|
|
|
|
|
loss_fct = nn.MSELoss() |
|
|
                loss = loss_fct(logits.squeeze(-1), labels.float().squeeze(-1))
|
|
else: |
|
|
|
|
|
loss_fct = nn.CrossEntropyLoss() |
|
|
loss = loss_fct(logits.view(-1, self.config.num_classes), labels.view(-1)) |
|
|
|
|
|
if not return_dict: |
|
|
output = (logits,) |
|
|
if output_hidden_states: |
|
|
output = output + (hidden_states,) |
|
|
if output_attentions: |
|
|
output = output + (all_attentions,) |
|
|
if loss is not None: |
|
|
output = (loss,) + output |
|
|
return output |
|
|
|
|
|
return HierarchicalVCFOutput( |
|
|
loss=loss, |
|
|
logits=logits, |
|
|
hidden_states=hidden_states, |
|
|
attentions=all_attentions, |
|
|
hierarchical_embeddings=hierarchical_embeddings |
|
|
) |
|
|
|
|
|
def _process_sample(self, |
|
|
sample: Dict[str, Any], |
|
|
output_attentions: bool = False, |
|
|
output_hidden_states: bool = False) -> torch.Tensor: |
|
|
""" |
|
|
Process a single hierarchical sample. |
|
|
Args: |
|
|
sample: Single sample from batch |
|
|
output_attentions: Whether to return attention weights |
|
|
output_hidden_states: Whether to return hierarchical embeddings |
|
|
Returns: |
|
|
Sample embedding tensor or tuple with additional outputs |
|
|
""" |
|
|
|
|
|
pathway_embeddings = [] |
|
|
sample_attentions = {} if output_attentions else None |
|
|
sample_hierarchical = {} if output_hidden_states else None |
|
|
|
|
|
for pathway_token, chromosomes in sample.items(): |
|
|
chromosome_embeddings = [] |
|
|
|
|
|
for chrom_token, genes in chromosomes.items(): |
|
|
gene_embeddings = [] |
|
|
|
|
|
for gene_token, mutations in genes.items(): |
|
|
|
|
|
gene_embedding = self._process_gene_mutations( |
|
|
mutations, |
|
|
output_attentions=output_attentions |
|
|
) |
|
|
|
|
|
if output_attentions: |
|
|
gene_embedding, gene_attentions = gene_embedding |
|
|
if 'gene_level' not in sample_attentions: |
|
|
sample_attentions['gene_level'] = [] |
|
|
sample_attentions['gene_level'].append(gene_attentions) |
|
|
|
|
|
gene_embeddings.append(gene_embedding) |
|
|
|
|
|
if gene_embeddings: |
|
|
|
|
|
gene_tensor = torch.stack(gene_embeddings).unsqueeze(0) |
|
|
chrom_embedding, chrom_attention = self.chromosome_aggregator(gene_tensor) |
|
|
chrom_embedding = chrom_embedding.squeeze(0) |
|
|
|
|
|
chromosome_embeddings.append(chrom_embedding) |
|
|
|
|
|
if output_attentions: |
|
|
if 'chromosome_level' not in sample_attentions: |
|
|
sample_attentions['chromosome_level'] = [] |
|
|
sample_attentions['chromosome_level'].append(chrom_attention) |
|
|
|
|
|
if chromosome_embeddings: |
|
|
|
|
|
chrom_tensor = torch.stack(chromosome_embeddings).unsqueeze(0) |
|
|
pathway_embedding, pathway_attention = self.pathway_aggregator(chrom_tensor) |
|
|
pathway_embedding = pathway_embedding.squeeze(0) |
|
|
|
|
|
pathway_embeddings.append(pathway_embedding) |
|
|
|
|
|
if output_attentions: |
|
|
if 'pathway_level' not in sample_attentions: |
|
|
sample_attentions['pathway_level'] = [] |
|
|
sample_attentions['pathway_level'].append(pathway_attention) |
|
|
|
|
|
if output_hidden_states: |
|
|
sample_hierarchical['pathway_embeddings'] = pathway_embeddings |
|
|
|
|
|
if pathway_embeddings: |
|
|
|
|
|
pathway_tensor = torch.stack(pathway_embeddings).unsqueeze(0) |
|
|
            sample_embedding, sample_attention = self.sample_aggregator(pathway_tensor)
|
|
sample_embedding = sample_embedding.squeeze(0) |
|
|
|
|
|
if output_attentions: |
|
|
sample_attentions['sample_level'] = sample_attention |
|
|
else: |
|
|
|
|
|
sample_embedding = torch.zeros(self.config.embed_dim, device=self.device) |
|
|
|
|
|
|
|
|
result = sample_embedding |
|
|
|
|
|
if output_attentions and output_hidden_states: |
|
|
result = (result, sample_attentions, sample_hierarchical) |
|
|
elif output_attentions: |
|
|
result = (result, sample_attentions) |
|
|
elif output_hidden_states: |
|
|
result = (result, sample_hierarchical) |
|
|
|
|
|
return result |
|
|
|
|
|
def _process_gene_mutations(self, |
|
|
mutations: Dict[str, Any], |
|
|
output_attentions: bool = False) -> torch.Tensor: |
|
|
""" |
|
|
Process mutations for a single gene. |
|
|
Args: |
|
|
            mutations: Per-field mutation data for the gene; each field maps to a
                token list or to a dict with 'tokens' and 'mask' entries
|
|
output_attentions: Whether to return attention weights |
|
|
Returns: |
|
|
Gene embedding tensor |
|
|
""" |
|
|
|
|
|
|
|
|
mutation_tensors = {} |
|
|
attention_mask = None |
|
|
|
|
|
for field in ['impact', 'ref', 'alt']: |
|
|
if field in mutations: |
|
|
if isinstance(mutations[field], dict) and 'tokens' in mutations[field]: |
|
|
|
|
|
mutation_tensors[field] = torch.tensor(mutations[field]['tokens'], device=self.device) |
|
|
                    if attention_mask is None:
                        # nn.MultiheadAttention's key_padding_mask expects True at padded
                        # positions; the tokenizer mask is assumed to use 1 for real tokens,
                        # so invert it and add a batch dimension.
                        attention_mask = ~torch.tensor(
                            mutations[field]['mask'], device=self.device
                        ).bool()
                        attention_mask = attention_mask.unsqueeze(0)
|
|
else: |
|
|
|
|
|
mutation_tensors[field] = torch.tensor(mutations[field], device=self.device) |
|
|
|
|
|
if not mutation_tensors: |
|
|
return torch.zeros(self.config.embed_dim, device=self.device) |
|
|
|
|
|
|
|
|
mutation_embeddings = self.mutation_embedder(mutation_tensors) |
|
|
|
|
|
|
|
|
if self.config.use_positional_encoding: |
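            # PositionalEncoding expects [seq_len, batch, d_model], so add and then
            # remove a singleton batch dimension around the call.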
|
|
mutation_embeddings = mutation_embeddings.unsqueeze(1) |
|
|
mutation_embeddings = self.pos_encoder(mutation_embeddings) |
|
|
mutation_embeddings = mutation_embeddings.squeeze(1) |
|
|
|
|
|
|
|
|
mutation_embeddings = mutation_embeddings.unsqueeze(0) |
|
|
|
|
|
layer_attentions = [] if output_attentions else None |
|
|
|
|
|
        for i, layer in enumerate(self.transformer_layers):
            # Each layer pools the sequence to a single vector per batch element,
            # so re-add a length-1 sequence dimension for the next layer and only
            # apply the padding mask on the first (full-length) pass.
            mutation_embeddings, layer_attention = layer(
                mutation_embeddings, attention_mask if i == 0 else None
            )
            mutation_embeddings = mutation_embeddings.unsqueeze(1)
|
|
|
|
|
if output_attentions: |
|
|
layer_attentions.append(layer_attention) |
|
|
|
|
|
|
|
|
        # After the transformer stack the sequence dimension has already been
        # pooled away (shape [1, 1, embed_dim]), so the gene embedding is simply
        # the remaining vector; padding was handled by the attention mask above.
        gene_embedding = mutation_embeddings.squeeze(1).squeeze(0)
|
|
|
|
|
if output_attentions: |
|
|
return gene_embedding, layer_attentions |
|
|
|
|
|
return gene_embedding |
|
|
|
|
|
@property |
|
|
def device(self) -> torch.device: |
|
|
"""Get model device.""" |
|
|
return next(self.parameters()).device |
|
|
|
|
|
def create_model_from_config(config_manager: ConfigManager, |
|
|
tokenizer: HierarchicalVCFTokenizer) -> HierarchicalVCFModel: |
|
|
""" |
|
|
Args: |
|
|
config_manager: Configuration manager |
|
|
tokenizer: Tokenizer instance |
|
|
task_type: Type of task ('classification', 'regression') |
|
|
Returns: |
|
|
Configured model |
|
|
""" |
|
|
|
|
|
model_config = config_manager.model_config |
|
|
|
|
|
|
|
|
hf_config = HierarchicalVCFConfig( |
|
|
vocab_sizes=tokenizer.get_all_vocab_sizes(), |
|
|
embed_dim=model_config.embed_dim, |
|
|
transformer_dim=model_config.transformer_dim, |
|
|
nhead=model_config.nhead, |
|
|
num_layers=model_config.num_layers, |
|
|
num_classes=model_config.num_classes, |
|
|
hidden_dims=model_config.hidden_dims, |
|
|
dropout=model_config.dropout |
|
|
) |
|
|
|
|
|
|
|
|
model = HierarchicalVCFModel(hf_config) |
|
|
|
|
|
return model |
|
|
|
|
|
|
|
|
|
|
|
class ModelTrainer: |
|
|
""" |
|
|
Training utilities for Hierarchical VCF Model. |
|
|
""" |
|
|
|
|
|
def __init__(self, |
|
|
model: HierarchicalVCFModel, |
|
|
train_dataloader, |
|
|
val_dataloader, |
|
|
optimizer: Optional[torch.optim.Optimizer] = None, |
|
|
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, |
|
|
device: Optional[torch.device] = None): |
|
|
|
|
|
self.model = model |
|
|
self.train_dataloader = train_dataloader |
|
|
self.val_dataloader = val_dataloader |
|
|
self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
|
|
|
|
|
|
self.model.to(self.device) |
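        # Default optimizer: AdamW with a modest learning rate and weight decay;
        # these are reasonable starting values rather than tuned hyper-parameters.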
|
|
|
|
|
|
|
|
if optimizer is None: |
|
|
self.optimizer = torch.optim.AdamW( |
|
|
model.parameters(), |
|
|
lr=1e-4, |
|
|
weight_decay=0.01 |
|
|
) |
|
|
else: |
|
|
self.optimizer = optimizer |
|
|
|
|
|
self.scheduler = scheduler |
|
|
|
|
|
|
|
|
self.train_losses = [] |
|
|
self.val_losses = [] |
|
|
self.val_accuracies = [] |
|
|
|
|
|
def train_epoch(self) -> float: |
|
|
"""Train for one epoch.""" |
|
|
self.model.train() |
|
|
total_loss = 0.0 |
|
|
num_batches = 0 |
|
|
|
|
|
for batch in self.train_dataloader: |
|
|
self.optimizer.zero_grad() |
|
|
|
|
|
|
|
|
            # Without labels there is no loss to backpropagate, so skip the batch
            # (mirrors the behaviour of validate()).
            if 'labels' not in batch:
                continue
            labels = batch['labels'].to(self.device)
|
|
|
|
|
|
|
|
outputs = self.model(batch, labels=labels) |
|
|
loss = outputs.loss if hasattr(outputs, 'loss') else outputs[0] |
|
|
|
|
|
|
|
|
loss.backward() |
|
|
self.optimizer.step() |
|
|
|
|
|
total_loss += loss.item() |
|
|
num_batches += 1 |
|
|
|
|
|
if self.scheduler: |
|
|
self.scheduler.step() |
|
|
|
|
|
avg_loss = total_loss / max(num_batches, 1) |
|
|
self.train_losses.append(avg_loss) |
|
|
|
|
|
return avg_loss |
|
|
|
|
|
def validate(self) -> Tuple[float, float]: |
|
|
"""Validate model.""" |
|
|
self.model.eval() |
|
|
total_loss = 0.0 |
|
|
correct_predictions = 0 |
|
|
total_predictions = 0 |
|
|
num_batches = 0 |
|
|
|
|
|
with torch.no_grad(): |
|
|
for batch in self.val_dataloader: |
|
|
|
|
|
if 'labels' in batch: |
|
|
labels = batch['labels'].to(self.device) |
|
|
else: |
|
|
continue |
|
|
|
|
|
|
|
|
outputs = self.model(batch, labels=labels) |
|
|
loss = outputs.loss if hasattr(outputs, 'loss') else outputs[0] |
|
|
logits = outputs.logits if hasattr(outputs, 'logits') else outputs[1] |
|
|
|
|
|
total_loss += loss.item() |
|
|
|
|
|
|
|
|
predictions = torch.argmax(logits, dim=-1) |
|
|
correct_predictions += (predictions == labels).sum().item() |
|
|
total_predictions += labels.size(0) |
|
|
num_batches += 1 |
|
|
|
|
|
avg_loss = total_loss / max(num_batches, 1) |
|
|
accuracy = correct_predictions / max(total_predictions, 1) |
|
|
|
|
|
self.val_losses.append(avg_loss) |
|
|
self.val_accuracies.append(accuracy) |
|
|
|
|
|
return avg_loss, accuracy |
|
|
|
|
|
def train(self, num_epochs: int, save_path: Optional[str] = None) -> Dict[str, List[float]]: |
|
|
""" |
|
|
Train model for specified number of epochs. |
|
|
|
|
|
Args: |
|
|
num_epochs: Number of training epochs |
|
|
save_path: Path to save best model |
|
|
|
|
|
Returns: |
|
|
Training history |
|
|
""" |
|
|
|
|
|
best_val_loss = float('inf') |
|
|
|
|
|
logger.info(f"Starting training for {num_epochs} epochs...") |
|
|
|
|
|
for epoch in range(num_epochs): |
|
|
|
|
|
train_loss = self.train_epoch() |
|
|
|
|
|
|
|
|
val_loss, val_accuracy = self.validate() |
|
|
|
|
|
logger.info( |
|
|
f"Epoch {epoch+1}/{num_epochs}: " |
|
|
f"Train Loss: {train_loss:.4f}, " |
|
|
f"Val Loss: {val_loss:.4f}, " |
|
|
f"Val Accuracy: {val_accuracy:.4f}" |
|
|
) |
|
|
|
|
|
|
|
|
if save_path and val_loss < best_val_loss: |
|
|
best_val_loss = val_loss |
|
|
self.model.save_pretrained(save_path) |
|
|
logger.info(f"Saved best model to {save_path}") |
|
|
|
|
|
return { |
|
|
'train_losses': self.train_losses, |
|
|
'val_losses': self.val_losses, |
|
|
'val_accuracies': self.val_accuracies |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
from tokenizer import create_tokenizer_from_config |
|
|
from dataset import create_data_module_from_config |
|
|
|
|
|
|
|
|
config_manager = ConfigManager() |
|
|
config_manager.model_config.embed_dim = 32 |
|
|
config_manager.model_config.num_classes = 2 |
|
|
|
|
|
|
|
|
tokenizer = create_tokenizer_from_config(config_manager) |
|
|
|
|
|
|
|
|
example_data = { |
|
|
'sample1': { |
|
|
'pathway1': { |
|
|
'chr1': { |
|
|
'gene1': [ |
|
|
{'impact': 'HIGH', 'reference': 'A', 'alternate': 'T'} |
|
|
] |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
tokenizer.build_vocabulary(example_data) |
|
|
|
|
|
|
|
|
model = create_model_from_config(config_manager, tokenizer) |
|
|
|
|
|
print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters") |
|
|
print(f"Model config: {model.config}") |
|
|
|
|
|
|
|
|
dummy_batch = { |
|
|
'samples': [example_data['sample1']], |
|
|
'batch_size': 1 |
|
|
} |
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = model(dummy_batch) |
|
|
print(f"Output logits shape: {outputs.logits.shape}") |
|
|
print(f"Output logits: {outputs.logits}") |