#!/usr/bin/env python3
"""
Training Script for Multimodal Glycan BERT v3

Trains a multimodal transformer on glycan sequences, MS spectra, and 3D structures.
Supports automatic checkpointing and resuming from interruptions.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import autocast, GradScaler
import yaml
import json
import sys
import argparse
from pathlib import Path
from typing import Optional
from tqdm import tqdm
from datetime import datetime
import math

# Add parent directory to path so the project packages below can be imported
# when this file is run as a script.
sys.path.insert(0, str(Path(__file__).parent.parent.absolute()))

from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig
from training.multimodal_dataset import MultimodalGlycanDataset, create_multimodal_dataloaders
from training.multimodal_masking import MultimodalMaskingStrategy


class MultimodalTrainer:
    """
    Trainer for Multimodal Glycan BERT v3.

    Features:
    - Best-model checkpoint written whenever validation loss improves
    - Numbered checkpoint every 5 epochs and at early stopping
    - Resume from any checkpoint (auto-detects the most recent one)
    - Progress tracking per modality
    - Early stopping
    - Mixed precision training (CUDA only)
    """

    # Per-modality loss entries read from the model output, in reporting order.
    _COMPONENT_LOSS_KEYS = ('seq_loss', 'ms_loss', 'struct_loss', 'dist_loss')

    def __init__(self, config_path: Path, resume_from: Optional[str] = None, restart: bool = False):
        """
        Initialize trainer.

        Args:
            config_path: Path to multimodal_config.yaml
            resume_from: Path to checkpoint to resume from (optional, auto-detects if None)
            restart: If True, ignore any existing checkpoints and start fresh
        """
        # Load config
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)

        self.config_path = config_path

        # Setup directories first. Output paths are resolved relative to the
        # directory two levels above the config file.
        project_root = Path(config_path).parent.parent
        self.checkpoint_dir = project_root / self.config['output']['checkpoint_dir']
        self.log_dir = project_root / self.config['output']['log_dir']
        # parents=True so nested output directories from the config work.
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # Auto-detect latest checkpoint if not restarting and no explicit checkpoint given
        if not restart and resume_from is None:
            resume_from = self._find_latest_checkpoint()
            if resume_from:
                print(f"✓ Found existing checkpoint: {resume_from}")
                print(" Will resume from this checkpoint (use --restart to start fresh)")

        self.resume_from = resume_from

        # Cache for token IDs read from the data directory (filled lazily).
        self._masking_token_ids = None

        # Setup device
        self.device = self._setup_device()
        print(f"\nUsing device: {self.device}")

        # Create model
        print("\nInitializing model...")
        self.model = self._create_model()

        # Create dataloaders
        print("Loading data...")
        self.train_loader, self.val_loader = self._create_dataloaders()

        # Create optimizer and scheduler
        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()

        # Mixed precision scaler. Gradient scaling is only meaningful on CUDA,
        # so AMP is disabled on other devices even when the config requests it.
        amp_requested = bool(self.config['training']['use_amp'])
        self.scaler = GradScaler() if amp_requested and self.device.type == 'cuda' else None

        # Training state
        self.current_epoch = 0
        self.global_step = 0
        self.best_val_loss = float('inf')
        self.epochs_without_improvement = 0
        self._best_epoch = None

        # Logging
        self.log_file = self.log_dir / f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"

        # Resume from checkpoint if specified
        if self.resume_from:
            self.load_checkpoint(self.resume_from)

    def _find_latest_checkpoint(self) -> Optional[str]:
        """
        Find the latest checkpoint in the checkpoint directory.

        Returns:
            Path to the most recently modified ``checkpoint_*.pt`` file, or
            None if no checkpoints are found.
        """
        if not self.checkpoint_dir.exists():
            return None

        # Look for checkpoint files
        checkpoints = list(self.checkpoint_dir.glob("checkpoint_*.pt"))
        if not checkpoints:
            return None

        # Most recent modification time wins.
        return str(max(checkpoints, key=lambda p: p.stat().st_mtime))

    def _setup_device(self) -> torch.device:
        """Setup compute device (CUDA when available, otherwise CPU)."""
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _create_model(self) -> MultimodalGlycanBERT:
        """Create multimodal model from config and move it to the device."""
        model_cfg = self.config['model']
        seq_cfg = model_cfg['sequence']
        ms_cfg = model_cfg['mass_spectrometry']
        struct_cfg = model_cfg['structure_3d']
        fusion_cfg = model_cfg['fusion']
        loss_weights = self.config['training']['loss_weights']

        model_config = MultimodalGlycanBERTConfig(
            # Sequence config
            seq_vocab_size=seq_cfg['vocab_size'],
            seq_hidden_size=seq_cfg['hidden_size'],
            seq_num_layers=seq_cfg['num_hidden_layers'],
            seq_num_heads=seq_cfg['num_attention_heads'],
            seq_max_length=seq_cfg['max_length'],
            use_cnn_frontend=seq_cfg.get('use_cnn_frontend', True),
            cnn_kernel_size=seq_cfg.get('cnn_kernel_size', 3),
            # MS config
            ms_vocab_size=ms_cfg['vocab_size'],
            ms_hidden_size=ms_cfg['hidden_size'],
            ms_num_layers=ms_cfg['num_hidden_layers'],
            ms_num_heads=ms_cfg['num_attention_heads'],
            ms_max_length=ms_cfg['max_length'],
            # Structure config
            struct_vocab_size=struct_cfg['vocab_size'],
            struct_hidden_size=struct_cfg['hidden_size'],
            struct_num_layers=struct_cfg['num_hidden_layers'],
            struct_num_heads=struct_cfg['num_attention_heads'],
            struct_max_length=struct_cfg['max_length'],
            use_cross_attention=struct_cfg['use_cross_attention'],
            # Fusion config
            fusion_hidden_size=fusion_cfg['fusion_hidden_size'],
            fusion_num_layers=fusion_cfg['fusion_num_layers'],
            # Loss weights
            seq_loss_weight=loss_weights['sequence'],
            dist_loss_weight=loss_weights.get('dist_loss_weight', 0.25),
            ms_loss_weight=loss_weights['ms'],
            struct_loss_weight=loss_weights['structure_3d'],
            # Common config (taken from the sequence section)
            hidden_dropout_prob=seq_cfg['hidden_dropout_prob'],
            attention_probs_dropout_prob=seq_cfg['attention_probs_dropout_prob'],
            layer_norm_eps=seq_cfg['layer_norm_eps'],
            pad_token_id=seq_cfg['pad_token_id'],
            mask_token_id=seq_cfg['mask_token_id']
        )

        model = MultimodalGlycanBERT(model_config)
        model.to(self.device)

        # Print model size
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Model parameters: {total_params:,} total, {trainable_params:,} trainable")

        # Initialize dynamic loss weights (uncertainty-based)
        # Learn log(sigma^2) for each modality - weights = 1/(2*sigma^2)
        # NOTE(review): within this file these parameters are never handed to
        # the optimizer, the model, or the loss, and are not checkpointed —
        # confirm whether they are consumed elsewhere or are dead code.
        if self.config['training'].get('use_dynamic_loss', False):
            self.log_vars = nn.ParameterList([
                nn.Parameter(torch.zeros(1, device=self.device)),  # seq
                nn.Parameter(torch.zeros(1, device=self.device)),  # ms
                nn.Parameter(torch.zeros(1, device=self.device)),  # struct
            ])
            print("Using dynamic loss weighting (uncertainty-based)")
        else:
            self.log_vars = None

        return model

    def _create_dataloaders(self):
        """Create train and validation dataloaders from the paths in the config."""
        base_path = Path(self.config_path).parent.parent
        data_cfg = self.config['data']
        model_cfg = self.config['model']

        train_loader, val_loader = create_multimodal_dataloaders(
            sequences_path=str(base_path / data_cfg['sequences']),
            ms_tokens_path=str(base_path / data_cfg['ms_tokens']),
            structure_data_path=str(base_path / data_cfg['structure_data']),
            batch_size=self.config['training']['batch_size'],
            num_workers=self.config['hardware']['num_workers'],
            max_seq_length=model_cfg['sequence']['max_length'],
            max_ms_length=model_cfg['mass_spectrometry']['max_length'],
            max_struct_length=model_cfg['structure_3d']['max_length']
        )

        return train_loader, val_loader

    def _create_optimizer(self) -> optim.Optimizer:
        """Create the AdamW optimizer over the model parameters."""
        return optim.AdamW(
            self.model.parameters(),
            lr=self.config['training']['learning_rate'],
            weight_decay=self.config['training']['weight_decay'],
            betas=(0.9, 0.999),
            eps=1e-8
        )

    def _create_scheduler(self):
        """Create learning rate scheduler: linear warmup, then cosine decay to zero."""
        warmup_steps = self.config['training']['warmup_steps']
        total_steps = len(self.train_loader) * self.config['training']['max_epochs']

        def lr_lambda(current_step):
            if current_step < warmup_steps:
                return float(current_step) / float(max(1, warmup_steps))
            progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
            # Clamp so that stepping past total_steps keeps the LR at zero
            # instead of riding the cosine back up.
            progress = min(1.0, progress)
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

        return optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

    # ------------------------------------------------------------------
    # Helpers shared by training and validation
    # ------------------------------------------------------------------

    def _load_masking_token_ids(self):
        """Return (special_ids, ambiguous_ids) for sequence masking, reading the JSON files once."""
        cached = getattr(self, '_masking_token_ids', None)
        if cached is not None:
            return cached

        data_dir = Path(self.config_path).parent.parent / "data"

        # Load vocabulary to get special token IDs
        with open(data_dir / "vocabulary.json", 'r', encoding='utf-8') as f:
            vocab = json.load(f)

        # Names of tokens to skip come from the config; unknown names are ignored.
        special_lookup = vocab.get('special_tokens', {})
        skip_names = self.config['training'].get('special_tokens_to_skip', [])
        candidate_ids = (special_lookup.get(name) for name in skip_names)
        special_ids = [token_id for token_id in candidate_ids if token_id is not None]

        # Get ambiguous token IDs (x, X, ?, u, d, o); the file is optional.
        ambiguous_ids = []
        ambig_path = data_dir / "ambiguity_tokens.json"
        if ambig_path.exists():
            with open(ambig_path, 'r', encoding='utf-8') as f:
                ambig_data = json.load(f)
            ambiguous_ids = list(ambig_data.get('ambiguous_tokens', {}).values())

        self._masking_token_ids = (special_ids, ambiguous_ids)
        return self._masking_token_ids

    def _build_masking_strategy(self) -> MultimodalMaskingStrategy:
        """Construct the masking strategy used for both training and validation."""
        model_cfg = self.config['model']
        train_cfg = self.config['training']
        seq_cfg = model_cfg['sequence']
        special_ids, ambiguous_ids = self._load_masking_token_ids()

        return MultimodalMaskingStrategy(
            # Sequence masking (copies keep the cached lists untouched)
            seq_vocab_size=seq_cfg['vocab_size'],
            seq_mask_token_id=seq_cfg['mask_token_id'],
            seq_pad_token_id=seq_cfg['pad_token_id'],
            seq_special_token_ids=list(special_ids),
            seq_ambiguous_token_ids=list(ambiguous_ids),
            seq_mask_prob=train_cfg['mask_prob'],
            # MS masking
            ms_vocab_size=model_cfg['mass_spectrometry']['vocab_size'],
            ms_vocab_offset=model_cfg['mass_spectrometry']['vocab_offset'],
            ms_mask_token_id=seq_cfg['mask_token_id'],  # Use same mask token
            ms_pad_token_id=seq_cfg['pad_token_id'],  # Use same pad token
            ms_special_token_ids=[],
            ms_mask_prob=train_cfg['mask_prob'],
            # Structure masking
            struct_vocab_size=model_cfg['structure_3d']['vocab_size'],
            struct_mask_token_id=1,  # VQ-VAE mask token
            struct_pad_token_id=0,  # VQ-VAE pad token
            struct_special_token_ids=[],
            struct_mask_prob=train_cfg['mask_prob'],
            # Common parameters
            mask_token_prob=train_cfg.get('mask_token_prob', 0.8),
            random_token_prob=train_cfg.get('random_token_prob', 0.1),
            unchanged_prob=train_cfg.get('unchanged_prob', 0.1),
        )

    def _prepare_batch(self, batch, masking_strategy):
        """Move tensors in *batch* to the device and replace token IDs with masked IDs and labels."""
        batch = {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }

        # Apply masking
        masked = masking_strategy.mask_multimodal_batch(
            seq_token_ids=batch['seq_token_ids'],
            ms_token_ids=batch['ms_token_ids'],
            has_ms=batch['has_ms'],
            struct_token_ids=batch['struct_token_ids'],
            has_3d=batch['has_3d']
        )

        # Merge masked results back into batch
        batch['seq_token_ids'] = masked['seq_masked_ids']
        batch['seq_labels'] = masked['seq_labels']
        batch['ms_token_ids'] = masked['ms_masked_ids']
        batch['ms_labels'] = masked['ms_labels']
        batch['struct_token_ids'] = masked['struct_masked_ids']
        batch['struct_labels'] = masked['struct_labels']
        return batch

    def _forward_batch(self, batch):
        """Run the model on a prepared batch and return its output mapping."""
        return self.model(
            seq_token_ids=batch['seq_token_ids'],
            seq_attention_mask=batch['seq_attention_mask'],
            seq_residue_ids=batch['seq_residue_ids'],
            seq_branch_depths=batch.get('seq_branch_depths'),
            seq_linkage_types=batch.get('seq_linkage_types'),
            ms_token_ids=batch.get('ms_token_ids'),
            ms_attention_mask=batch.get('ms_attention_mask'),
            struct_token_ids=batch.get('struct_token_ids'),
            struct_attention_mask=batch.get('struct_attention_mask'),
            struct_residue_ids=batch.get('struct_residue_ids'),
            has_ms=batch['has_ms'],
            has_3d=batch['has_3d'],
            seq_labels=batch['seq_labels'],
            ms_labels=batch.get('ms_labels'),
            struct_labels=batch.get('struct_labels'),
            dist_labels=batch.get('dist_labels')  # Topology labels
        )

    @staticmethod
    def _loss_to_float(value) -> float:
        """Convert a loss entry (tensor, number, or None) to a Python float; None counts as 0."""
        if value is None:
            return 0.0
        if isinstance(value, torch.Tensor):
            return value.item()
        return float(value)

    def _component_losses(self, outputs) -> dict:
        """Extract the per-modality losses from a model output as floats (missing entries are 0)."""
        return {
            key: self._loss_to_float(outputs.get(key))
            for key in self._COMPONENT_LOSS_KEYS
        }

    @classmethod
    def _new_totals(cls) -> dict:
        """Return a zeroed accumulator keyed by 'loss' plus each per-modality loss."""
        return dict.fromkeys(('loss',) + cls._COMPONENT_LOSS_KEYS, 0.0)

    @staticmethod
    def _average_totals(totals: dict, num_batches: int) -> dict:
        """Divide every accumulated total by *num_batches*; all zeros when no batch was seen."""
        if num_batches <= 0:
            return {key: 0 for key in totals}
        return {key: value / num_batches for key, value in totals.items()}

    # ------------------------------------------------------------------
    # Training / validation
    # ------------------------------------------------------------------

    def train_epoch(self, epoch: int):
        """
        Train for one epoch.

        Args:
            epoch: Zero-based epoch index (used for the progress bar label).

        Returns:
            Dict with the mean 'loss', 'seq_loss', 'ms_loss', 'struct_loss'
            and 'dist_loss' over the batches of this epoch.
        """
        self.model.train()

        train_cfg = self.config['training']
        max_grad_norm = train_cfg['max_grad_norm']
        validate_every = train_cfg['validate_every_n_steps']
        masking_strategy = self._build_masking_strategy()

        totals = self._new_totals()
        num_batches = 0

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{train_cfg['max_epochs']}")
        for batch in pbar:
            batch = self._prepare_batch(batch, masking_strategy)

            # DEBUG: Print dist_labels info once
            if not hasattr(self, '_dist_batch_debug'):
                dl = batch.get('dist_labels')
                if dl is not None:
                    valid = (dl != -1).sum().item()
                    print(f"[TRAIN DEBUG] dist_labels in batch: shape={dl.shape}, valid_count={valid}")
                else:
                    print("[TRAIN DEBUG] dist_labels is NOT in batch!")
                self._dist_batch_debug = True

            self.optimizer.zero_grad()

            if self.scaler is not None:
                # Forward pass with mixed precision
                with autocast(device_type=self.device.type):
                    outputs = self._forward_batch(batch)
                    loss = outputs['loss']

                # Backward pass; unscale before clipping so the norm threshold
                # applies to the true gradients.
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                outputs = self._forward_batch(batch)
                loss = outputs['loss']

                # Backward pass
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                self.optimizer.step()

            self.scheduler.step()
            self.global_step += 1

            # Accumulate losses
            loss_val = loss.item()
            components = self._component_losses(outputs)
            totals['loss'] += loss_val
            for key, value in components.items():
                totals[key] += value
            num_batches += 1

            seq_loss_val = components['seq_loss']
            ms_loss_val = components['ms_loss']
            struct_loss_val = components['struct_loss']
            dist_loss_val = components['dist_loss']

            # Update progress bar
            pbar.set_postfix({
                'loss': f"{loss_val:.4f}",
                'seq': f"{seq_loss_val:.4f}",
                'dist': f"{dist_loss_val:.4f}" if dist_loss_val > 0 else "-",
                'ms': f"{ms_loss_val:.4f}" if ms_loss_val > 0 else "-",
                'struct': f"{struct_loss_val:.4f}" if struct_loss_val > 0 else "-",
                'lr': f"{self.scheduler.get_last_lr()[0]:.2e}"
            })

            # Validate periodically (a falsy interval disables this).
            if validate_every and self.global_step % validate_every == 0:
                val_metrics = self.validate()
                self._log(f"Step {self.global_step} validation: {val_metrics}")
                # validate() leaves the model in eval mode.
                self.model.train()

        return self._average_totals(totals, num_batches)

    @torch.no_grad()
    def validate(self):
        """
        Validate on validation set.

        Leaves the model in eval mode; callers that continue training must
        switch it back.

        Returns:
            Dict with the mean 'loss', 'seq_loss', 'ms_loss', 'struct_loss'
            and 'dist_loss' over the validation batches.
        """
        self.model.eval()

        masking_strategy = self._build_masking_strategy()
        totals = self._new_totals()
        num_batches = 0

        for batch in tqdm(self.val_loader, desc="Validating", leave=False):
            batch = self._prepare_batch(batch, masking_strategy)

            # Forward pass
            outputs = self._forward_batch(batch)

            totals['loss'] += outputs['loss'].item()
            for key, value in self._component_losses(outputs).items():
                totals[key] += value
            num_batches += 1

        return self._average_totals(totals, num_batches)

    def train(self):
        """
        Main training loop.

        Runs epochs from ``self.current_epoch`` up to the configured maximum.
        The best-model checkpoint is written whenever validation loss
        improves; a numbered checkpoint is written every 5 epochs and when
        early stopping triggers.
        """
        train_cfg = self.config['training']
        max_epochs = train_cfg['max_epochs']
        patience = train_cfg['early_stopping_patience']
        best_model_path = self.config['output']['best_model_path']

        print("\n" + "="*80)
        print("STARTING TRAINING")
        print("="*80)
        print(f"Epochs: {max_epochs}")
        print(f"Batch size: {train_cfg['batch_size']}")
        print(f"Learning rate: {train_cfg['learning_rate']}")
        print(f"Device: {self.device}")
        print(f"Mixed precision: {self.scaler is not None}")
        print(f"Checkpoints: {self.checkpoint_dir}")
        print(f"Logs: {self.log_dir}")
        print("="*80 + "\n")

        # Number of epochs finished so far (non-zero when resuming).
        epochs_completed = self.current_epoch

        for epoch in range(self.current_epoch, max_epochs):
            self.current_epoch = epoch

            # Train epoch
            train_metrics = self.train_epoch(epoch)

            # Validate
            val_metrics = self.validate()
            epochs_completed = epoch + 1

            # Log metrics
            print(f"\nEpoch {epoch+1} Summary:")
            print(f" Train Loss: {train_metrics['loss']:.4f} (seq: {train_metrics['seq_loss']:.4f}, ms: {train_metrics['ms_loss']:.4f}, struct: {train_metrics['struct_loss']:.4f})")
            print(f" Val Loss: {val_metrics['loss']:.4f} (seq: {val_metrics['seq_loss']:.4f}, ms: {val_metrics['ms_loss']:.4f}, struct: {val_metrics['struct_loss']:.4f})")
            print(f" Best Val Loss: {self.best_val_loss:.4f}")
            print(f" LR: {self.scheduler.get_last_lr()[0]:.2e}")

            self._log(f"Epoch {epoch+1}: Train={train_metrics}, Val={val_metrics}")

            # Check for improvement. The best-model file is written right away
            # so that it always holds the weights that produced best_val_loss.
            val_loss = val_metrics['loss']
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.epochs_without_improvement = 0
                self._best_epoch = epoch + 1  # Track which epoch was best
                print(f"✓ New best! Val loss: {val_loss:.4f}")
                self.save_checkpoint(best_model_path, is_best=True)
            else:
                self.epochs_without_improvement += 1
                print(f" No improvement for {self.epochs_without_improvement} epochs")

            # Early stopping
            if self.epochs_without_improvement >= patience:
                print(f"\nEarly stopping after {epoch+1} epochs (no improvement for {self.epochs_without_improvement} epochs)")
                # Save the final state under a numbered name only; the current
                # weights are not the best ones, so the best-model file is
                # left untouched.
                self.save_checkpoint(f"checkpoint_epoch_{epoch+1}.pt")
                break

            # Save numbered checkpoint every 5 epochs
            if (epoch + 1) % 5 == 0:
                self.save_checkpoint(f"checkpoint_epoch_{epoch+1}.pt")
                print(f"✓ Saved checkpoints at epoch {epoch+1}")

        print("\n" + "="*80)
        print("TRAINING COMPLETE")
        print(f"Best validation loss: {self.best_val_loss:.4f}")
        print(f"Total epochs: {epochs_completed}")
        print(f"Total steps: {self.global_step}")
        print("="*80 + "\n")

    # ------------------------------------------------------------------
    # Checkpointing / logging
    # ------------------------------------------------------------------

    def save_checkpoint(self, filename: str, is_best: bool = False):
        """
        Save model checkpoint.

        Args:
            filename: File name (relative to the checkpoint directory).
            is_best: Only affects the confirmation message that is printed.
        """
        checkpoint = {
            # Index of the epoch that has just been completed.
            'epoch': self.current_epoch,
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_loss': self.best_val_loss,
            'epochs_without_improvement': self.epochs_without_improvement,
            'config': self.config
        }

        if self.scaler is not None:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        save_path = self.checkpoint_dir / filename
        torch.save(checkpoint, save_path)

        if is_best:
            print(f"✓ Saved best model to {save_path}")
        else:
            print(f"✓ Saved checkpoint to {save_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """
        Load checkpoint and resume training.

        If the file does not exist a message is printed and training starts
        from scratch.

        Args:
            checkpoint_path: Path to checkpoint file
        """
        checkpoint_file = Path(checkpoint_path)
        if not checkpoint_file.exists():
            print(f"✗ Checkpoint not found: {checkpoint_path}")
            print(" Starting training from scratch...")
            return

        print(f"Loading checkpoint from {checkpoint_path}...")
        checkpoint = torch.load(checkpoint_file, map_location=self.device)

        # Load model state
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # Load optimizer state
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Load scheduler state
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # Load training state. Checkpoints are written after an epoch has
        # finished, so training continues with the following epoch rather
        # than repeating the stored one.
        self.current_epoch = checkpoint['epoch'] + 1
        self.global_step = checkpoint['global_step']
        self.best_val_loss = checkpoint['best_val_loss']
        self.epochs_without_improvement = checkpoint.get('epochs_without_improvement', 0)

        # Load scaler state if it exists
        if self.scaler is not None and 'scaler_state_dict' in checkpoint:
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        print(f"✓ Resumed from epoch {self.current_epoch + 1}, step {self.global_step}")
        print(f" Best validation loss: {self.best_val_loss:.4f}")
        print(f" Epochs without improvement: {self.epochs_without_improvement}")

    def _log(self, message: str):
        """Append a timestamped message to the training log file."""
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        log_message = f"[{timestamp}] {message}"

        with open(self.log_file, 'a', encoding='utf-8') as f:
            f.write(log_message + '\n')


def main():
    """Main entry point: parse CLI arguments, build the trainer, and train."""
    parser = argparse.ArgumentParser(description='Train Multimodal Glycan BERT v3')
    parser.add_argument('--config', type=str, default='model/multimodal_config.yaml',
                        help='Path to config file')
    parser.add_argument('--restart', action='store_true',
                        help='Start training from scratch, ignoring any existing checkpoints')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to specific checkpoint to resume from (overrides auto-detection)')
    args = parser.parse_args()

    # Relative config paths are resolved against the directory above this script's folder.
    config_path = Path(__file__).parent.parent / args.config

    if not config_path.exists():
        print(f"Error: Config file not found: {config_path}")
        sys.exit(1)

    trainer = MultimodalTrainer(config_path, resume_from=args.resume, restart=args.restart)
    trainer.train()


if __name__ == '__main__':
    main()