"""
Training Script for Multimodal Glycan BERT v3

Trains a multimodal transformer on glycan sequences, MS spectra, and 3D structures.
Supports automatic checkpointing and resuming from interruptions.
"""
|
|
import argparse
import json
import math
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from torch.amp import autocast, GradScaler
from tqdm import tqdm

# Put the directory two levels above this file on sys.path *before* the project
# imports below, so that the `model` and `training` packages can be resolved.
sys.path.insert(0, str(Path(__file__).parent.parent.absolute()))

from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig
from training.multimodal_dataset import MultimodalGlycanDataset, create_multimodal_dataloaders
from training.multimodal_masking import MultimodalMaskingStrategy
|
|
|
|
| class MultimodalTrainer: |
| """ |
| Trainer for Multimodal Glycan BERT v3. |
| |
| Features: |
| - Automatic checkpointing every N steps and epochs |
| - Resume from any checkpoint |
| - Detailed progress tracking per modality |
| - Early stopping |
| - Mixed precision training |
| """ |
| |
| def __init__(self, config_path: Path, resume_from: str = None, restart: bool = False): |
| """ |
| Initialize trainer. |
| |
| Args: |
| config_path: Path to multimodal_config.yaml |
| resume_from: Path to checkpoint to resume from (optional, auto-detects if None) |
| restart: If True, ignore any existing checkpoints and start fresh |
| """ |
| |
| with open(config_path, 'r') as f: |
| self.config = yaml.safe_load(f) |
| |
| self.config_path = config_path |
| |
| |
| self.checkpoint_dir = Path(config_path).parent.parent / self.config['output']['checkpoint_dir'] |
| self.log_dir = Path(config_path).parent.parent / self.config['output']['log_dir'] |
| self.checkpoint_dir.mkdir(exist_ok=True) |
| self.log_dir.mkdir(exist_ok=True) |
| |
| |
| if not restart and resume_from is None: |
| resume_from = self._find_latest_checkpoint() |
| if resume_from: |
| print(f"✓ Found existing checkpoint: {resume_from}") |
| print(" Will resume from this checkpoint (use --restart to start fresh)") |
| |
| self.resume_from = resume_from |
| |
| |
| self.device = self._setup_device() |
| print(f"\nUsing device: {self.device}") |
| |
| |
| print("\nInitializing model...") |
| self.model = self._create_model() |
| |
| |
| print("Loading data...") |
| self.train_loader, self.val_loader = self._create_dataloaders() |
| |
| |
| self.optimizer = self._create_optimizer() |
| self.scheduler = self._create_scheduler() |
| |
| |
| self.scaler = GradScaler() if self.config['training']['use_amp'] else None |
| |
| |
| self.current_epoch = 0 |
| self.global_step = 0 |
| self.best_val_loss = float('inf') |
| self.epochs_without_improvement = 0 |
| |
| |
| self.log_file = self.log_dir / f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" |
| |
| |
| if self.resume_from: |
| self.load_checkpoint(self.resume_from) |
| |
| def _find_latest_checkpoint(self) -> str: |
| """ |
| Find the latest checkpoint in the checkpoint directory. |
| |
| Returns: |
| Path to latest checkpoint or None if no checkpoints found |
| """ |
| if not self.checkpoint_dir.exists(): |
| return None |
| |
| |
| checkpoints = list(self.checkpoint_dir.glob("checkpoint_*.pt")) |
| |
| if not checkpoints: |
| return None |
| |
| |
| checkpoints.sort(key=lambda x: x.stat().st_mtime, reverse=True) |
| |
| return str(checkpoints[0]) |
| |
| def _setup_device(self) -> torch.device: |
| """Setup compute device.""" |
| if torch.cuda.is_available(): |
| return torch.device('cuda') |
| else: |
| return torch.device('cpu') |
| |
| def _create_model(self) -> MultimodalGlycanBERT: |
| """Create multimodal model from config.""" |
| |
| model_cfg = self.config['model'] |
| |
| model_config = MultimodalGlycanBERTConfig( |
| |
| seq_vocab_size=model_cfg['sequence']['vocab_size'], |
| seq_hidden_size=model_cfg['sequence']['hidden_size'], |
| seq_num_layers=model_cfg['sequence']['num_hidden_layers'], |
| seq_num_heads=model_cfg['sequence']['num_attention_heads'], |
| seq_max_length=model_cfg['sequence']['max_length'], |
| use_cnn_frontend=model_cfg['sequence'].get('use_cnn_frontend', True), |
| cnn_kernel_size=model_cfg['sequence'].get('cnn_kernel_size', 3), |
| |
| |
| ms_vocab_size=model_cfg['mass_spectrometry']['vocab_size'], |
| ms_hidden_size=model_cfg['mass_spectrometry']['hidden_size'], |
| ms_num_layers=model_cfg['mass_spectrometry']['num_hidden_layers'], |
| ms_num_heads=model_cfg['mass_spectrometry']['num_attention_heads'], |
| ms_max_length=model_cfg['mass_spectrometry']['max_length'], |
| |
| |
| struct_vocab_size=model_cfg['structure_3d']['vocab_size'], |
| struct_hidden_size=model_cfg['structure_3d']['hidden_size'], |
| struct_num_layers=model_cfg['structure_3d']['num_hidden_layers'], |
| struct_num_heads=model_cfg['structure_3d']['num_attention_heads'], |
| struct_max_length=model_cfg['structure_3d']['max_length'], |
| use_cross_attention=model_cfg['structure_3d']['use_cross_attention'], |
| |
| |
| fusion_hidden_size=model_cfg['fusion']['fusion_hidden_size'], |
| fusion_num_layers=model_cfg['fusion']['fusion_num_layers'], |
| |
| |
| seq_loss_weight=self.config['training']['loss_weights']['sequence'], |
| dist_loss_weight=self.config['training']['loss_weights'].get('dist_loss_weight', 0.25), |
| ms_loss_weight=self.config['training']['loss_weights']['ms'], |
| struct_loss_weight=self.config['training']['loss_weights']['structure_3d'], |
| |
| |
| hidden_dropout_prob=model_cfg['sequence']['hidden_dropout_prob'], |
| attention_probs_dropout_prob=model_cfg['sequence']['attention_probs_dropout_prob'], |
| layer_norm_eps=model_cfg['sequence']['layer_norm_eps'], |
| pad_token_id=model_cfg['sequence']['pad_token_id'], |
| mask_token_id=model_cfg['sequence']['mask_token_id'] |
| ) |
| |
| model = MultimodalGlycanBERT(model_config) |
| model.to(self.device) |
| |
| |
| total_params = sum(p.numel() for p in model.parameters()) |
| trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| print(f"Model parameters: {total_params:,} total, {trainable_params:,} trainable") |
| |
| |
| |
| if self.config['training'].get('use_dynamic_loss', False): |
| self.log_vars = nn.ParameterList([ |
| nn.Parameter(torch.zeros(1, device=self.device)), |
| nn.Parameter(torch.zeros(1, device=self.device)), |
| nn.Parameter(torch.zeros(1, device=self.device)), |
| ]) |
| print("Using dynamic loss weighting (uncertainty-based)") |
| else: |
| self.log_vars = None |
| |
| return model |
| |
| def _create_dataloaders(self): |
| """Create train and validation dataloaders.""" |
| base_path = Path(self.config_path).parent.parent |
| |
| train_loader, val_loader = create_multimodal_dataloaders( |
| sequences_path=str(base_path / self.config['data']['sequences']), |
| ms_tokens_path=str(base_path / self.config['data']['ms_tokens']), |
| structure_data_path=str(base_path / self.config['data']['structure_data']), |
| batch_size=self.config['training']['batch_size'], |
| num_workers=self.config['hardware']['num_workers'], |
| max_seq_length=self.config['model']['sequence']['max_length'], |
| max_ms_length=self.config['model']['mass_spectrometry']['max_length'], |
| max_struct_length=self.config['model']['structure_3d']['max_length'] |
| ) |
| |
| return train_loader, val_loader |
| |
| def _create_optimizer(self) -> optim.Optimizer: |
| """Create optimizer.""" |
| return optim.AdamW( |
| self.model.parameters(), |
| lr=self.config['training']['learning_rate'], |
| weight_decay=self.config['training']['weight_decay'], |
| betas=(0.9, 0.999), |
| eps=1e-8 |
| ) |
| |
| def _create_scheduler(self): |
| """Create learning rate scheduler with warmup.""" |
| warmup_steps = self.config['training']['warmup_steps'] |
| total_steps = len(self.train_loader) * self.config['training']['max_epochs'] |
| |
| def lr_lambda(current_step): |
| if current_step < warmup_steps: |
| return float(current_step) / float(max(1, warmup_steps)) |
| progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps)) |
| return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) |
| |
| return optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda) |
| |
| def train_epoch(self, epoch: int): |
| """Train for one epoch.""" |
| self.model.train() |
| total_loss = 0 |
| total_seq_loss = 0 |
| total_ms_loss = 0 |
| total_struct_loss = 0 |
| total_dist_loss = 0 |
| num_batches = 0 |
| |
| |
| model_cfg = self.config['model'] |
| train_cfg = self.config['training'] |
| |
| |
| base_path = Path(self.config_path).parent.parent |
| vocab_path = base_path / "data" / "vocabulary.json" |
| with open(vocab_path, 'r') as f: |
| vocab = json.load(f) |
| |
| |
| special_tokens_to_skip = train_cfg.get('special_tokens_to_skip', []) |
| seq_special_token_ids = [] |
| for token_name in special_tokens_to_skip: |
| token_id = vocab.get('special_tokens', {}).get(token_name) |
| if token_id is not None: |
| seq_special_token_ids.append(token_id) |
| |
| |
| ambig_path = base_path / "data" / "ambiguity_tokens.json" |
| seq_ambiguous_token_ids = [] |
| if ambig_path.exists(): |
| with open(ambig_path, 'r') as f: |
| ambig_data = json.load(f) |
| for token_name, token_id in ambig_data.get('ambiguous_tokens', {}).items(): |
| seq_ambiguous_token_ids.append(token_id) |
| |
| masking_strategy = MultimodalMaskingStrategy( |
| |
| seq_vocab_size=model_cfg['sequence']['vocab_size'], |
| seq_mask_token_id=model_cfg['sequence']['mask_token_id'], |
| seq_pad_token_id=model_cfg['sequence']['pad_token_id'], |
| seq_special_token_ids=seq_special_token_ids, |
| seq_ambiguous_token_ids=seq_ambiguous_token_ids, |
| seq_mask_prob=train_cfg['mask_prob'], |
| |
| |
| ms_vocab_size=model_cfg['mass_spectrometry']['vocab_size'], |
| ms_vocab_offset=model_cfg['mass_spectrometry']['vocab_offset'], |
| ms_mask_token_id=model_cfg['sequence']['mask_token_id'], |
| ms_pad_token_id=model_cfg['sequence']['pad_token_id'], |
| ms_special_token_ids=[], |
| ms_mask_prob=train_cfg['mask_prob'], |
| |
| |
| struct_vocab_size=model_cfg['structure_3d']['vocab_size'], |
| struct_mask_token_id=1, |
| struct_pad_token_id=0, |
| struct_special_token_ids=[], |
| struct_mask_prob=train_cfg['mask_prob'], |
| |
| |
| mask_token_prob=train_cfg.get('mask_token_prob', 0.8), |
| random_token_prob=train_cfg.get('random_token_prob', 0.1), |
| unchanged_prob=train_cfg.get('unchanged_prob', 0.1), |
| ) |
| |
| total_loss = 0 |
| total_seq_loss = 0 |
| total_ms_loss = 0 |
| total_struct_loss = 0 |
| total_dist_loss = 0 |
| num_batches = 0 |
| |
| pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.config['training']['max_epochs']}") |
| |
| for batch in pbar: |
| |
| batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v |
| for k, v in batch.items()} |
| |
| |
| masked_batch = masking_strategy.mask_multimodal_batch( |
| seq_token_ids=batch['seq_token_ids'], |
| ms_token_ids=batch['ms_token_ids'], |
| has_ms=batch['has_ms'], |
| struct_token_ids=batch['struct_token_ids'], |
| has_3d=batch['has_3d'] |
| ) |
| |
| |
| batch['seq_token_ids'] = masked_batch['seq_masked_ids'] |
| batch['seq_labels'] = masked_batch['seq_labels'] |
| batch['ms_token_ids'] = masked_batch['ms_masked_ids'] |
| batch['ms_labels'] = masked_batch['ms_labels'] |
| batch['struct_token_ids'] = masked_batch['struct_masked_ids'] |
| batch['struct_labels'] = masked_batch['struct_labels'] |
| |
| |
| if not hasattr(self, '_dist_batch_debug'): |
| dl = batch.get('dist_labels') |
| if dl is not None: |
| valid = (dl != -1).sum().item() |
| print(f"[TRAIN DEBUG] dist_labels in batch: shape={dl.shape}, valid_count={valid}") |
| else: |
| print("[TRAIN DEBUG] dist_labels is NOT in batch!") |
| self._dist_batch_debug = True |
| |
| |
| if self.scaler: |
| with autocast(device_type='cuda'): |
| outputs = self.model( |
| seq_token_ids=batch['seq_token_ids'], |
| seq_attention_mask=batch['seq_attention_mask'], |
| seq_residue_ids=batch['seq_residue_ids'], |
| seq_branch_depths=batch.get('seq_branch_depths'), |
| seq_linkage_types=batch.get('seq_linkage_types'), |
| ms_token_ids=batch.get('ms_token_ids'), |
| ms_attention_mask=batch.get('ms_attention_mask'), |
| struct_token_ids=batch.get('struct_token_ids'), |
| struct_attention_mask=batch.get('struct_attention_mask'), |
| struct_residue_ids=batch.get('struct_residue_ids'), |
| has_ms=batch['has_ms'], |
| has_3d=batch['has_3d'], |
| seq_labels=batch['seq_labels'], |
| ms_labels=batch.get('ms_labels'), |
| struct_labels=batch.get('struct_labels'), |
| dist_labels=batch.get('dist_labels') |
| ) |
| loss = outputs['loss'] |
| |
| |
| self.optimizer.zero_grad() |
| self.scaler.scale(loss).backward() |
| self.scaler.unscale_(self.optimizer) |
| torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['training']['max_grad_norm']) |
| self.scaler.step(self.optimizer) |
| self.scaler.update() |
| else: |
| outputs = self.model( |
| seq_token_ids=batch['seq_token_ids'], |
| seq_attention_mask=batch['seq_attention_mask'], |
| seq_residue_ids=batch['seq_residue_ids'], |
| seq_branch_depths=batch.get('seq_branch_depths'), |
| seq_linkage_types=batch.get('seq_linkage_types'), |
| ms_token_ids=batch.get('ms_token_ids'), |
| ms_attention_mask=batch.get('ms_attention_mask'), |
| struct_token_ids=batch.get('struct_token_ids'), |
| struct_attention_mask=batch.get('struct_attention_mask'), |
| struct_residue_ids=batch.get('struct_residue_ids'), |
| has_ms=batch['has_ms'], |
| has_3d=batch['has_3d'], |
| seq_labels=batch['seq_labels'], |
| ms_labels=batch.get('ms_labels'), |
| struct_labels=batch.get('struct_labels'), |
| dist_labels=batch.get('dist_labels') |
| ) |
| loss = outputs['loss'] |
| |
| |
| self.optimizer.zero_grad() |
| loss.backward() |
| torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['training']['max_grad_norm']) |
| self.optimizer.step() |
| |
| self.scheduler.step() |
| self.global_step += 1 |
| |
| |
| total_loss += loss.item() |
| seq_loss_val = outputs.get('seq_loss', 0) |
| ms_loss_val = outputs.get('ms_loss', 0) |
| struct_loss_val = outputs.get('struct_loss', 0) |
| dist_loss_val = outputs.get('dist_loss') or 0 |
| |
| |
| if isinstance(seq_loss_val, torch.Tensor): |
| seq_loss_val = seq_loss_val.item() |
| if isinstance(ms_loss_val, torch.Tensor): |
| ms_loss_val = ms_loss_val.item() |
| if isinstance(struct_loss_val, torch.Tensor): |
| struct_loss_val = struct_loss_val.item() |
| if isinstance(dist_loss_val, torch.Tensor): |
| dist_loss_val = dist_loss_val.item() |
| |
| total_seq_loss += seq_loss_val |
| total_ms_loss += ms_loss_val |
| total_struct_loss += struct_loss_val |
| total_dist_loss += dist_loss_val |
| num_batches += 1 |
| |
| |
| pbar.set_postfix({ |
| 'loss': f"{loss.item():.4f}", |
| 'seq': f"{seq_loss_val:.4f}", |
| 'dist': f"{dist_loss_val:.4f}" if dist_loss_val > 0 else "-", |
| 'ms': f"{ms_loss_val:.4f}" if ms_loss_val > 0 else "-", |
| 'struct': f"{struct_loss_val:.4f}" if struct_loss_val > 0 else "-", |
| 'lr': f"{self.scheduler.get_last_lr()[0]:.2e}" |
| }) |
| |
| |
| if self.global_step % self.config['training']['validate_every_n_steps'] == 0: |
| val_metrics = self.validate() |
| self._log(f"Step {self.global_step} validation: {val_metrics}") |
| self.model.train() |
| |
| avg_loss = total_loss / num_batches if num_batches > 0 else 0 |
| avg_seq_loss = total_seq_loss / num_batches if num_batches > 0 else 0 |
| avg_ms_loss = total_ms_loss / num_batches if num_batches > 0 else 0 |
| avg_struct_loss = total_struct_loss / num_batches if num_batches > 0 else 0 |
| avg_dist_loss = total_dist_loss / num_batches if num_batches > 0 else 0 |
| |
| return { |
| 'loss': avg_loss, |
| 'seq_loss': avg_seq_loss, |
| 'ms_loss': avg_ms_loss, |
| 'struct_loss': avg_struct_loss, |
| 'dist_loss': avg_dist_loss |
| } |
| |
| @torch.no_grad() |
| def validate(self): |
| """Validate on validation set.""" |
| self.model.eval() |
| total_loss = 0 |
| total_seq_loss = 0 |
| total_ms_loss = 0 |
| total_struct_loss = 0 |
| total_dist_loss = 0 |
| num_batches = 0 |
| |
| |
| model_cfg = self.config['model'] |
| train_cfg = self.config['training'] |
| |
| |
| base_path = Path(self.config_path).parent.parent |
| vocab_path = base_path / "data" / "vocabulary.json" |
| with open(vocab_path, 'r') as f: |
| vocab = json.load(f) |
| |
| |
| special_tokens_to_skip = train_cfg.get('special_tokens_to_skip', []) |
| seq_special_token_ids = [] |
| for token_name in special_tokens_to_skip: |
| token_id = vocab.get('special_tokens', {}).get(token_name) |
| if token_id is not None: |
| seq_special_token_ids.append(token_id) |
| |
| |
| ambig_path = base_path / "data" / "ambiguity_tokens.json" |
| seq_ambiguous_token_ids = [] |
| if ambig_path.exists(): |
| with open(ambig_path, 'r') as f: |
| ambig_data = json.load(f) |
| for token_name, token_id in ambig_data.get('ambiguous_tokens', {}).items(): |
| seq_ambiguous_token_ids.append(token_id) |
| |
| masking_strategy = MultimodalMaskingStrategy( |
| |
| seq_vocab_size=model_cfg['sequence']['vocab_size'], |
| seq_mask_token_id=model_cfg['sequence']['mask_token_id'], |
| seq_pad_token_id=model_cfg['sequence']['pad_token_id'], |
| seq_special_token_ids=seq_special_token_ids, |
| seq_ambiguous_token_ids=seq_ambiguous_token_ids, |
| seq_mask_prob=train_cfg['mask_prob'], |
| |
| |
| ms_vocab_size=model_cfg['mass_spectrometry']['vocab_size'], |
| ms_vocab_offset=model_cfg['mass_spectrometry']['vocab_offset'], |
| ms_mask_token_id=model_cfg['sequence']['mask_token_id'], |
| ms_pad_token_id=model_cfg['sequence']['pad_token_id'], |
| ms_special_token_ids=[], |
| ms_mask_prob=train_cfg['mask_prob'], |
| |
| |
| struct_vocab_size=model_cfg['structure_3d']['vocab_size'], |
| struct_mask_token_id=1, |
| struct_pad_token_id=0, |
| struct_special_token_ids=[], |
| struct_mask_prob=train_cfg['mask_prob'], |
| |
| |
| mask_token_prob=train_cfg.get('mask_token_prob', 0.8), |
| random_token_prob=train_cfg.get('random_token_prob', 0.1), |
| unchanged_prob=train_cfg.get('unchanged_prob', 0.1), |
| ) |
| |
| for batch in tqdm(self.val_loader, desc="Validating", leave=False): |
| |
| batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v |
| for k, v in batch.items()} |
| |
| |
| masked_batch = masking_strategy.mask_multimodal_batch( |
| seq_token_ids=batch['seq_token_ids'], |
| ms_token_ids=batch['ms_token_ids'], |
| has_ms=batch['has_ms'], |
| struct_token_ids=batch['struct_token_ids'], |
| has_3d=batch['has_3d'] |
| ) |
| |
| |
| batch['seq_token_ids'] = masked_batch['seq_masked_ids'] |
| batch['seq_labels'] = masked_batch['seq_labels'] |
| batch['ms_token_ids'] = masked_batch['ms_masked_ids'] |
| batch['ms_labels'] = masked_batch['ms_labels'] |
| batch['struct_token_ids'] = masked_batch['struct_masked_ids'] |
| batch['struct_labels'] = masked_batch['struct_labels'] |
| |
| |
| outputs = self.model( |
| seq_token_ids=batch['seq_token_ids'], |
| seq_attention_mask=batch['seq_attention_mask'], |
| seq_residue_ids=batch['seq_residue_ids'], |
| seq_branch_depths=batch.get('seq_branch_depths'), |
| seq_linkage_types=batch.get('seq_linkage_types'), |
| ms_token_ids=batch.get('ms_token_ids'), |
| ms_attention_mask=batch.get('ms_attention_mask'), |
| struct_token_ids=batch.get('struct_token_ids'), |
| struct_attention_mask=batch.get('struct_attention_mask'), |
| struct_residue_ids=batch.get('struct_residue_ids'), |
| has_ms=batch['has_ms'], |
| has_3d=batch['has_3d'], |
| seq_labels=batch['seq_labels'], |
| ms_labels=batch.get('ms_labels'), |
| struct_labels=batch.get('struct_labels'), |
| dist_labels=batch.get('dist_labels') |
| ) |
| |
| total_loss += outputs['loss'].item() |
| |
| seq_loss_val = outputs.get('seq_loss', 0) |
| ms_loss_val = outputs.get('ms_loss', 0) |
| struct_loss_val = outputs.get('struct_loss', 0) |
| dist_loss_val = outputs.get('dist_loss') or 0 |
| |
| |
| if isinstance(seq_loss_val, torch.Tensor): |
| seq_loss_val = seq_loss_val.item() |
| if isinstance(ms_loss_val, torch.Tensor): |
| ms_loss_val = ms_loss_val.item() |
| if isinstance(struct_loss_val, torch.Tensor): |
| struct_loss_val = struct_loss_val.item() |
| if isinstance(dist_loss_val, torch.Tensor): |
| dist_loss_val = dist_loss_val.item() |
| |
| total_seq_loss += seq_loss_val |
| total_ms_loss += ms_loss_val |
| total_struct_loss += struct_loss_val |
| total_dist_loss += dist_loss_val |
| num_batches += 1 |
| |
| avg_loss = total_loss / num_batches if num_batches > 0 else 0 |
| avg_seq_loss = total_seq_loss / num_batches if num_batches > 0 else 0 |
| avg_ms_loss = total_ms_loss / num_batches if num_batches > 0 else 0 |
| avg_struct_loss = total_struct_loss / num_batches if num_batches > 0 else 0 |
| avg_dist_loss = total_dist_loss / num_batches if num_batches > 0 else 0 |
| |
| return { |
| 'loss': avg_loss, |
| 'seq_loss': avg_seq_loss, |
| 'ms_loss': avg_ms_loss, |
| 'struct_loss': avg_struct_loss, |
| 'dist_loss': avg_dist_loss |
| } |
|
|
| |
| def train(self): |
| """Main training loop.""" |
| print("\n" + "="*80) |
| print("STARTING TRAINING") |
| print("="*80) |
| print(f"Epochs: {self.config['training']['max_epochs']}") |
| print(f"Batch size: {self.config['training']['batch_size']}") |
| print(f"Learning rate: {self.config['training']['learning_rate']}") |
| print(f"Device: {self.device}") |
| print(f"Mixed precision: {self.config['training']['use_amp']}") |
| print(f"Checkpoints: {self.checkpoint_dir}") |
| print(f"Logs: {self.log_dir}") |
| print("="*80 + "\n") |
| |
| for epoch in range(self.current_epoch, self.config['training']['max_epochs']): |
| self.current_epoch = epoch |
| |
| |
| train_metrics = self.train_epoch(epoch) |
| |
| |
| val_metrics = self.validate() |
| |
| |
| print(f"\nEpoch {epoch+1} Summary:") |
| print(f" Train Loss: {train_metrics['loss']:.4f} (seq: {train_metrics['seq_loss']:.4f}, ms: {train_metrics['ms_loss']:.4f}, struct: {train_metrics['struct_loss']:.4f})") |
| print(f" Val Loss: {val_metrics['loss']:.4f} (seq: {val_metrics['seq_loss']:.4f}, ms: {val_metrics['ms_loss']:.4f}, struct: {val_metrics['struct_loss']:.4f})") |
| print(f" Best Val Loss: {self.best_val_loss:.4f}") |
| print(f" LR: {self.scheduler.get_last_lr()[0]:.2e}") |
| |
| self._log(f"Epoch {epoch+1}: Train={train_metrics}, Val={val_metrics}") |
| |
| |
| val_loss = val_metrics['loss'] |
| if val_loss < self.best_val_loss: |
| self.best_val_loss = val_loss |
| self.epochs_without_improvement = 0 |
| self._best_epoch = epoch + 1 |
| print(f"✓ New best! Val loss: {val_loss:.4f}") |
| else: |
| self.epochs_without_improvement += 1 |
| print(f" No improvement for {self.epochs_without_improvement} epochs") |
| |
| |
| if self.epochs_without_improvement >= self.config['training']['early_stopping_patience']: |
| print(f"\nEarly stopping after {epoch+1} epochs (no improvement for {self.epochs_without_improvement} epochs)") |
| |
| self.save_checkpoint(self.config['output']['best_model_path'], is_best=True) |
| self.save_checkpoint(f"checkpoint_epoch_{epoch+1}.pt") |
| break |
| |
| |
| if (epoch + 1) % 5 == 0: |
| |
| self.save_checkpoint(self.config['output']['best_model_path'], is_best=True) |
| |
| self.save_checkpoint(f"checkpoint_epoch_{epoch+1}.pt") |
| print(f"✓ Saved checkpoints at epoch {epoch+1}") |
| |
| print("\n" + "="*80) |
| print("TRAINING COMPLETE") |
| print(f"Best validation loss: {self.best_val_loss:.4f}") |
| print(f"Total epochs: {self.current_epoch + 1}") |
| print(f"Total steps: {self.global_step}") |
| print("="*80 + "\n") |
| |
| def save_checkpoint(self, filename: str, is_best: bool = False): |
| """Save model checkpoint.""" |
| checkpoint = { |
| 'epoch': self.current_epoch, |
| 'global_step': self.global_step, |
| 'model_state_dict': self.model.state_dict(), |
| 'optimizer_state_dict': self.optimizer.state_dict(), |
| 'scheduler_state_dict': self.scheduler.state_dict(), |
| 'best_val_loss': self.best_val_loss, |
| 'epochs_without_improvement': self.epochs_without_improvement, |
| 'config': self.config |
| } |
| |
| if self.scaler: |
| checkpoint['scaler_state_dict'] = self.scaler.state_dict() |
| |
| save_path = self.checkpoint_dir / filename |
| torch.save(checkpoint, save_path) |
| |
| if is_best: |
| print(f"✓ Saved best model to {save_path}") |
| else: |
| print(f"✓ Saved checkpoint to {save_path}") |
| |
| def load_checkpoint(self, checkpoint_path: str): |
| """ |
| Load checkpoint and resume training. |
| |
| Args: |
| checkpoint_path: Path to checkpoint file |
| """ |
| checkpoint_file = Path(checkpoint_path) |
| |
| if not checkpoint_file.exists(): |
| print(f"✗ Checkpoint not found: {checkpoint_path}") |
| print(" Starting training from scratch...") |
| return |
| |
| print(f"Loading checkpoint from {checkpoint_path}...") |
| checkpoint = torch.load(checkpoint_file, map_location=self.device) |
| |
| |
| self.model.load_state_dict(checkpoint['model_state_dict']) |
| |
| |
| self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) |
| |
| |
| self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) |
| |
| |
| self.current_epoch = checkpoint['epoch'] |
| self.global_step = checkpoint['global_step'] |
| self.best_val_loss = checkpoint['best_val_loss'] |
| self.epochs_without_improvement = checkpoint.get('epochs_without_improvement', 0) |
| |
| |
| if self.scaler and 'scaler_state_dict' in checkpoint: |
| self.scaler.load_state_dict(checkpoint['scaler_state_dict']) |
| |
| print(f"✓ Resumed from epoch {self.current_epoch + 1}, step {self.global_step}") |
| print(f" Best validation loss: {self.best_val_loss:.4f}") |
| print(f" Epochs without improvement: {self.epochs_without_improvement}") |
| |
| def _log(self, message: str): |
| """Log message to file and console.""" |
| timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') |
| log_message = f"[{timestamp}] {message}" |
| |
| with open(self.log_file, 'a') as f: |
| f.write(log_message + '\n') |
|
|
|
|
def main():
    """Parse the command line, resolve the config file, and run training."""
    cli = argparse.ArgumentParser(description='Train Multimodal Glycan BERT v3')
    cli.add_argument(
        '--config',
        type=str,
        default='model/multimodal_config.yaml',
        help='Path to config file',
    )
    cli.add_argument(
        '--restart',
        action='store_true',
        help='Start training from scratch, ignoring any existing checkpoints',
    )
    cli.add_argument(
        '--resume',
        type=str,
        default=None,
        help='Path to specific checkpoint to resume from (overrides auto-detection)',
    )
    options = cli.parse_args()

    # The config path is interpreted relative to the directory two levels above this file.
    project_root = Path(__file__).parent.parent
    config_file = project_root / options.config

    if not config_file.exists():
        print(f"Error: Config file not found: {config_file}")
        sys.exit(1)

    MultimodalTrainer(
        config_file,
        resume_from=options.resume,
        restart=options.restart,
    ).train()
|
|
|
|
# Run training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|
|