| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.optim as optim |
| | from torch.utils.data import DataLoader |
| | from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR |
| | import numpy as np |
| | from tqdm import tqdm |
| | import json |
| | import os |
| | import argparse |
| | import time |
| | from torch.cuda.amp import autocast, GradScaler |
| | import wandb |
| |
|
| | |
| | from compressor_with_embeddings import Compressor, Decompressor, PrecomputedEmbeddingDataset |
| | from final_flow_model import AMPFlowMatcherCFGConcat, SinusoidalTimeEmbedding |
| | from cfg_dataset import CFGFlowDataset, create_cfg_dataloader |
| |
|
| | |
# --- Model / data dimensions -------------------------------------------------
ESM_DIM = 1_280                    # embedding width (last tensor dimension)
COMP_RATIO = 16                    # feature-width reduction factor
COMP_DIM = ESM_DIM // COMP_RATIO   # compressed feature width handed to the flow model
MAX_SEQ_LEN = 50                   # passed to CFGFlowDataset

# --- Optimization schedule ---------------------------------------------------
BATCH_SIZE = 512
EPOCHS = 2_000
BASE_LR = 8e-4                     # peak LR reached at the end of linear warmup
LR_MIN = 4e-4                      # floor of the cosine annealing phase
WARMUP_STEPS = 4_000               # counted in optimizer steps
GPU_ID = 0                         # CUDA device index

# --- Training-loop behaviour -------------------------------------------------
USE_MIXED_PRECISION = True         # bfloat16 autocast for the flow model
GRADIENT_CLIP_NORM = 0.5
WEIGHT_DECAY = 0.01
VALIDATION_INTERVAL = 5_000        # counted in optimizer steps
CHECKPOINT_INTERVAL = 300          # counted in epochs
NUM_WORKERS = 32                   # passed to create_cfg_dataloader

# Probability that a training batch has its labels replaced by the
# unconditional label (classifier-free guidance dropout).
CFG_DROPOUT_RATE = 0.15
| |
|
class AMPFlowTrainerSingleGPUFullData:
    """
    Optimized Single GPU training pipeline for AMP generation using ProtFlow methodology.
    Uses ALL available data with H100-optimized settings for overnight training.

    ``__init__`` wires the pipeline up in this order:

    1. load the raw peptide embeddings and compute normalization statistics;
    2. load the pretrained, frozen compressor/decompressor pair and build the
       CFG flow model;
    3. build the dataset/dataloader, AdamW and a linear-warmup + cosine LR
       schedule.

    ``train_flow_matching`` then optimizes only the flow model on compressed
    embeddings. Hyper-parameters are read from the module-level constants
    (``BATCH_SIZE``, ``EPOCHS``, ``BASE_LR``, ...) when the methods run.
    """

    # Directory every checkpoint written by _save_checkpoint goes to.
    CHECKPOINT_DIR = '/data2/edwardsun/flow_checkpoints'

    def __init__(self, embeddings_path, cfg_data_path, use_wandb=False):
        """Set up data, models, optimizer and (optionally) wandb logging.

        Args:
            embeddings_path: Directory holding the precomputed peptide
                embeddings (``all_peptide_embeddings.pt`` or individual
                ``*.pt`` files); also handed to ``CFGFlowDataset``.
            cfg_data_path: Path of the CFG label data handed to
                ``CFGFlowDataset``.
            use_wandb: If True, start a wandb run and log metrics to it.
        """
        self.device = torch.device(f'cuda:{GPU_ID}')
        self.embeddings_path = embeddings_path
        self.cfg_data_path = cfg_data_path
        self.use_wandb = use_wandb

        # Allow TF32 tensor-core math for fp32 matmuls and cuDNN convolutions.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        print(f"Using GPU {GPU_ID} for optimized H100 training")
        print(f"Mixed precision: {USE_MIXED_PRECISION}")
        print(f"Batch size: {BATCH_SIZE}")
        print(f"Target epochs: {EPOCHS}")
        print(f"Learning rate: {BASE_LR} -> {LR_MIN}")

        if USE_MIXED_PRECISION:
            # NOTE: autocast runs in bfloat16, which shares fp32's exponent
            # range, so loss scaling is not strictly needed; the scaler is
            # kept so the AMP step sequence is unchanged.
            self.scaler = GradScaler()
            print("✓ Mixed precision training enabled (BF16)")

        if self.use_wandb:
            wandb.init(
                project="amp-flow-training",
                config={
                    "batch_size": BATCH_SIZE,
                    "epochs": EPOCHS,
                    "base_lr": BASE_LR,
                    "lr_min": LR_MIN,
                    "warmup_steps": WARMUP_STEPS,
                    "mixed_precision": USE_MIXED_PRECISION,
                    "gradient_clip": GRADIENT_CLIP_NORM,
                    "weight_decay": WEIGHT_DECAY
                }
            )

        print(f"Loading ALL AMP embeddings from {embeddings_path}...")
        self._load_all_embeddings()

        print("Computing preprocessing statistics...")
        self._compute_preprocessing_stats()

        # Order matters: the LR schedule length depends on len(self.dataloader).
        self._initialize_models()
        self._initialize_data()
        self._initialize_optimizer()

        print("✓ Optimized Single GPU training setup complete with FULL DATA!")

    def _load_all_embeddings(self):
        """Load ALL peptide embeddings into ``self.embeddings``.

        Prefers the combined ``all_peptide_embeddings.pt`` (loaded directly
        onto ``self.device``). Otherwise every ``*.pt`` file in
        ``self.embeddings_path`` holding a 2-D tensor is loaded and stacked;
        unreadable or non-2-D files are skipped with a warning.

        Raises:
            ValueError: If the fallback path finds no usable tensor.
        """
        combined_path = os.path.join(self.embeddings_path, "all_peptide_embeddings.pt")

        if os.path.exists(combined_path):
            print(f"Loading combined embeddings from {combined_path}...")
            self.embeddings = torch.load(combined_path, map_location=self.device)
            print(f"✓ Loaded ALL embeddings: {self.embeddings.shape}")
            return

        print("Combined embeddings file not found, loading individual files...")

        import glob

        # The glob only yields *.pt paths, so the combined file is the only
        # entry that needs excluding.
        embedding_files = [
            f for f in glob.glob(os.path.join(self.embeddings_path, "*.pt"))
            if not f.endswith('all_peptide_embeddings.pt')
        ]

        print(f"Found {len(embedding_files)} individual embedding files")

        embeddings_list = []
        for file_path in embedding_files:
            # Best-effort loading: one bad file is reported and skipped.
            try:
                embedding = torch.load(file_path)
                if embedding.dim() == 2:
                    embeddings_list.append(embedding)
                else:
                    print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")
            except Exception as e:
                print(f"Warning: Could not load {file_path}: {e}")

        if not embeddings_list:
            raise ValueError("No valid embeddings found!")

        # NOTE(review): torch.stack needs every file to share one shape;
        # variable-length per-peptide tensors would raise here -- confirm the
        # files are pre-padded.
        self.embeddings = torch.stack(embeddings_list)
        print(f"Loaded {len(self.embeddings)} embeddings from individual files")

    def _compute_preprocessing_stats(self):
        """Compute normalization statistics and save them to ``normalization_stats.pt``.

        ``self.stats`` receives the per-feature mean/std over all flattened
        positions plus the global min/max of the raw embeddings.
        """
        flat_embeddings = self.embeddings.reshape(-1, ESM_DIM)

        mean = flat_embeddings.mean(dim=0)
        std = flat_embeddings.std(dim=0)
        # NOTE(review): min/max are measured on the *raw* embeddings, yet
        # _preprocess_batch applies them *after* z-scoring. Presumably this
        # matches how the pretrained compressor was trained, so it is kept
        # as-is -- confirm before changing, as it alters the compressor input.
        min_val = flat_embeddings.min()
        max_val = flat_embeddings.max()

        self.stats = {
            'mean': mean,
            'std': std,
            'min': min_val,
            'max': max_val
        }

        # Written to the current working directory.
        torch.save(self.stats, 'normalization_stats.pt')
        print(f"✓ Statistics computed and saved:")
        print(f"  Total embeddings: {len(self.embeddings):,}")
        print(f"  Mean: {mean.mean():.4f} ± {mean.std():.4f}")
        print(f"  Std: {std.mean():.4f} ± {std.std():.4f}")
        print(f"  Range: [{min_val:.4f}, {max_val:.4f}]")

    def _initialize_models(self):
        """Load the frozen compressor/decompressor and build the flow model."""
        print("Initializing models...")

        self.compressor = Compressor().to(self.device)
        self.decompressor = Decompressor().to(self.device)

        self.compressor.load_state_dict(torch.load('final_compressor_model.pth', map_location=self.device))
        self.decompressor.load_state_dict(torch.load('final_decompressor_model.pth', map_location=self.device))

        # The autoencoder is pretrained and frozen: only the flow model's
        # parameters are optimized and the compressor runs under no_grad.
        # Eval mode keeps any dropout / batch-norm layers deterministic and
        # stops batch-norm running statistics from drifting during training.
        for frozen in (self.compressor, self.decompressor):
            frozen.eval()
            frozen.requires_grad_(False)

        # NOTE(review): max_seq_len=25 differs from MAX_SEQ_LEN (50);
        # presumably the compressor halves the sequence length -- confirm.
        self.flow_model = AMPFlowMatcherCFGConcat(
            hidden_dim=480,
            compressed_dim=COMP_DIM,
            n_layers=12,
            n_heads=16,
            dim_ff=3072,
            max_seq_len=25,
            use_cfg=True
        ).to(self.device)

        # Optional speed-up; fall back to eager mode if wrapping fails.
        # NOTE: torch.compile compiles lazily, so some failures only surface
        # on the first forward call, outside this try.
        try:
            self.flow_model = torch.compile(self.flow_model, mode="reduce-overhead")
            print("✓ Model compiled with torch.compile for speedup")
        except Exception as e:
            print(f"⚠️ Model compilation failed: {e}")

        self.flow_model.train()

        print(f"✓ Models initialized:")
        print(f"  Compressor parameters: {sum(p.numel() for p in self.compressor.parameters()):,}")
        print(f"  Decompressor parameters: {sum(p.numel() for p in self.decompressor.parameters()):,}")
        print(f"  Flow model parameters: {sum(p.numel() for p in self.flow_model.parameters()):,}")

    def _initialize_data(self):
        """Build the CFG dataset/dataloader and derive the total step count."""
        print("Initializing datasets with FULL data...")

        self.cfg_dataset = CFGFlowDataset(
            embeddings_path=self.embeddings_path,
            cfg_data_path=self.cfg_data_path,
            use_masked_labels=True,
            max_seq_len=MAX_SEQ_LEN,
            device=self.device
        )

        self.dataloader = create_cfg_dataloader(
            self.cfg_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=NUM_WORKERS
        )

        # One optimizer/scheduler step per batch, so this sizes the LR schedule.
        self.total_steps = len(self.dataloader) * EPOCHS
        self.validation_steps = VALIDATION_INTERVAL

        print(f"✓ Dataset initialized with FULL data:")
        print(f"  Total samples: {len(self.cfg_dataset):,}")
        print(f"  Batch size: {BATCH_SIZE}")
        print(f"  Batches per epoch: {len(self.dataloader):,}")
        print(f"  Total training steps: {self.total_steps:,}")
        print(f"  Validation every: {self.validation_steps:,} steps")

    def _initialize_optimizer(self):
        """Create AdamW over the flow model and a warmup + cosine LR schedule.

        The LR ramps linearly from ``0.1 * BASE_LR`` to ``BASE_LR`` over
        ``WARMUP_STEPS`` steps, then cosine-anneals towards ``LR_MIN`` over
        the remaining ``self.total_steps``.
        """
        print("Initializing optimizer and scheduler...")

        self.optimizer = optim.AdamW(
            self.flow_model.parameters(),
            lr=BASE_LR,
            weight_decay=WEIGHT_DECAY,
            betas=(0.9, 0.98),
            eps=1e-6
        )

        warmup_scheduler = LinearLR(
            self.optimizer,
            start_factor=0.1,
            end_factor=1.0,
            total_iters=WARMUP_STEPS
        )

        # CosineAnnealingLR divides by T_max, which would be <= 0 whenever
        # the whole run is no longer than the warmup; clamp it to 1.
        cosine_steps = max(1, self.total_steps - WARMUP_STEPS)
        main_scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=cosine_steps,
            eta_min=LR_MIN
        )

        self.scheduler = SequentialLR(
            self.optimizer,
            schedulers=[warmup_scheduler, main_scheduler],
            milestones=[WARMUP_STEPS]
        )

        print(f"✓ Optimizer initialized:")
        print(f"  Base LR: {BASE_LR}")
        print(f"  Min LR: {LR_MIN}")
        print(f"  Warmup steps: {WARMUP_STEPS}")
        print(f"  Weight decay: {WEIGHT_DECAY}")
        print(f"  Gradient clip norm: {GRADIENT_CLIP_NORM}")

    def _preprocess_batch(self, batch):
        """Normalize a raw batch and encode it with the frozen compressor.

        Args:
            batch: Mapping with ``'embeddings'`` and ``'labels'`` tensors.

        Returns:
            ``(compressed, labels)`` on ``self.device``; ``compressed`` is the
            compressor output, computed without gradient tracking.
        """
        embeddings = batch['embeddings'].to(self.device)
        labels = batch['labels'].to(self.device)

        m, s = self.stats['mean'].to(self.device), self.stats['std'].to(self.device)
        mn, mx = self.stats['min'].to(self.device), self.stats['max'].to(self.device)

        # Per-feature z-score followed by the global min/max rescale (see the
        # NOTE in _compute_preprocessing_stats about what min/max measure).
        embeddings = (embeddings - m) / (s + 1e-8)
        embeddings = (embeddings - mn) / (mx - mn + 1e-8)

        with torch.no_grad():
            compressed = self.compressor(embeddings)

        return compressed, labels

    def _flow_matching_loss(self, compressed, labels):
        """Return the flow-matching MSE loss for one batch.

        Draws ``t ~ U[0, 1)`` per example and Gaussian noise ``eps``, forms
        ``x_t = (1 - t) * x0 + t * eps`` and regresses the flow model's
        output onto the velocity ``eps - x0``.

        Args:
            compressed: Clean latents ``x0``; must be 3-D ``(B, L, D)``.
            labels: Conditioning labels forwarded to the flow model.
        """
        B, _, _ = compressed.shape  # also enforces a 3-D input
        t = torch.rand(B, device=self.device)
        eps = torch.randn_like(compressed)

        t_b = t[:, None, None]  # broadcast t over the (L, D) axes
        xt = (1 - t_b) * compressed + t_b * eps

        vt_pred = self.flow_model(xt, t, labels=labels)
        vt_target = eps - compressed
        return F.mse_loss(vt_pred, vt_target)

    def _compute_validation_metrics(self, max_samples=1000, batch_size=None):
        """Return the mean flow-matching loss over a random dataset subset.

        NOTE(review): examples come from ``self.cfg_dataset``, the same
        dataset the training dataloader iterates, and ``t``/noise are
        resampled on every call -- this is a noisy training-loss estimate,
        not a held-out validation score.

        Args:
            max_samples: Upper bound on the number of examples evaluated.
            batch_size: Evaluation batch size; ``None`` means ``BATCH_SIZE``.

        Returns:
            Mean of the per-batch losses, or NaN if the dataset is empty.
        """
        if batch_size is None:
            batch_size = BATCH_SIZE

        self.flow_model.eval()
        val_losses = []

        val_samples = min(max_samples, len(self.cfg_dataset))
        # Plain ints, as a DataLoader sampler would pass to __getitem__.
        val_indices = torch.randperm(len(self.cfg_dataset))[:val_samples].tolist()

        try:
            with torch.no_grad():
                for start in range(0, val_samples, batch_size):
                    batch_data = [self.cfg_dataset[idx] for idx in val_indices[start:start + batch_size]]

                    # Items are read with singular keys and repackaged under
                    # the plural keys _preprocess_batch expects.
                    embeddings = torch.stack([item['embedding'] for item in batch_data])
                    labels = torch.stack([item['label'] for item in batch_data])

                    compressed, labels = self._preprocess_batch({
                        'embeddings': embeddings,
                        'labels': labels
                    })

                    loss = self._flow_matching_loss(compressed, labels)
                    val_losses.append(loss.item())
        finally:
            self.flow_model.train()

        if not val_losses:
            return float('nan')
        return np.mean(val_losses)

    def train_flow_matching(self):
        """Train the flow model for ``EPOCHS`` epochs over the full dataset.

        Every ``self.validation_steps`` optimizer steps the validation loss is
        computed and the best-so-far model saved; every
        ``CHECKPOINT_INTERVAL`` epochs a step-numbered checkpoint is saved;
        the final model is saved when training ends.

        Returns:
            ``(losses, val_losses)``: per-epoch mean training losses and the
            validation losses in the order they were computed.
        """
        print(f"🚀 Starting Optimized Single GPU Flow Matching Training with FULL DATA")
        print(f"GPU: {GPU_ID}")
        print(f"Total epochs: {EPOCHS}")
        print(f"Batch size: {BATCH_SIZE}")
        print(f"Total samples: {len(self.cfg_dataset):,}")
        print(f"Mixed precision: {USE_MIXED_PRECISION}")
        print(f"Estimated time: ~8-10 hours (overnight training with ALL data)")
        print("=" * 60)

        best_loss = float('inf')
        losses = []
        val_losses = []
        global_step = 0
        start_time = time.time()

        for epoch in tqdm(range(EPOCHS), desc="Training Flow Model"):
            epoch_losses = []
            epoch_start_time = time.time()

            for batch_idx, batch in enumerate(self.dataloader):
                compressed, labels = self._preprocess_batch(batch)

                # Classifier-free guidance dropout: with probability
                # CFG_DROPOUT_RATE the whole batch is trained with label -1.
                # NOTE(review): labels are dropped per *batch*, not per
                # sample, and -1 is assumed to be the model's null label --
                # confirm against AMPFlowMatcherCFGConcat.
                if torch.rand(1).item() < CFG_DROPOUT_RATE:
                    labels = torch.full_like(labels, fill_value=-1)

                self.optimizer.zero_grad()
                if USE_MIXED_PRECISION:
                    with autocast(dtype=torch.bfloat16):
                        loss = self._flow_matching_loss(compressed, labels)
                    self.scaler.scale(loss).backward()

                    # Unscale first so clipping sees the true gradient norm.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.flow_model.parameters(), max_norm=GRADIENT_CLIP_NORM)

                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    loss = self._flow_matching_loss(compressed, labels)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.flow_model.parameters(), max_norm=GRADIENT_CLIP_NORM)
                    self.optimizer.step()

                # The LR schedule is expressed in optimizer steps.
                self.scheduler.step()

                loss_value = loss.item()
                epoch_losses.append(loss_value)
                global_step += 1

                if batch_idx % 100 == 0:
                    current_lr = self.scheduler.get_last_lr()[0]
                    elapsed_time = max(time.time() - start_time, 1e-9)
                    steps_per_sec = global_step / elapsed_time
                    eta_hours = (self.total_steps - global_step) / steps_per_sec / 3600

                    print(f"Epoch {epoch:4d} | Step {global_step:6d}/{self.total_steps:6d} | "
                          f"Loss: {loss_value:.6f} | LR: {current_lr:.2e} | "
                          f"Speed: {steps_per_sec:.1f} steps/s | ETA: {eta_hours:.1f}h")

                    if self.use_wandb:
                        wandb.log({
                            'train/loss': loss_value,
                            'train/learning_rate': current_lr,
                            'train/steps_per_sec': steps_per_sec,
                            'train/global_step': global_step
                        })

                if global_step % self.validation_steps == 0:
                    val_loss = self._compute_validation_metrics()
                    val_losses.append(val_loss)

                    print(f"Validation at step {global_step}: Loss = {val_loss:.6f}")

                    if self.use_wandb:
                        wandb.log({
                            'val/loss': val_loss,
                            'val/global_step': global_step
                        })

                    if val_loss < best_loss:
                        best_loss = val_loss
                        self._save_checkpoint(epoch, val_loss, global_step, is_final=False, is_best=True)

            avg_loss = np.mean(epoch_losses) if epoch_losses else float('nan')
            losses.append(avg_loss)
            epoch_time = time.time() - epoch_start_time

            print(f"Epoch {epoch:4d} | Avg Loss: {avg_loss:.6f} | "
                  f"LR: {self.scheduler.get_last_lr()[0]:.2e} | "
                  f"Time: {epoch_time:.1f}s | Samples: {len(self.cfg_dataset):,}")

            # Periodic snapshots get their own step-numbered file so they do
            # not overwrite (or masquerade as) the final model.
            if (epoch + 1) % CHECKPOINT_INTERVAL == 0:
                self._save_checkpoint(epoch, avg_loss, global_step, is_final=False)

        final_loss = losses[-1] if losses else float('nan')
        self._save_checkpoint(EPOCHS - 1, final_loss, global_step, is_final=True)

        total_time = time.time() - start_time
        print("=" * 60)
        print("🎉 Optimized Training Complete with FULL DATA!")
        print(f"Best validation loss: {best_loss:.6f}")
        print(f"Total training time: {total_time/3600:.1f} hours")
        print(f"Total samples used: {len(self.cfg_dataset):,}")
        print(f"Final model saved as: {os.path.join(self.CHECKPOINT_DIR, 'amp_flow_model_final_optimized.pth')}")

        return losses, val_losses

    def _save_checkpoint(self, step, loss, global_step, is_final=False, is_best=False):
        """Save model, optimizer, scheduler and normalization state.

        Args:
            step: Epoch index (every caller in this class passes the epoch);
                stored under ``'step'`` and used in the step-numbered name.
            loss: Loss value recorded in the checkpoint and printed.
            global_step: Number of optimizer steps taken so far.
            is_final: Write ``amp_flow_model_final_optimized.pth``.
            is_best: Write ``amp_flow_model_best_optimized.pth``; takes
                precedence over ``is_final``.

        With neither flag set the file is
        ``amp_flow_checkpoint_optimized_step_{step:04d}.pth``.
        """
        output_dir = self.CHECKPOINT_DIR
        os.makedirs(output_dir, exist_ok=True)

        if is_best:
            filename = os.path.join(output_dir, 'amp_flow_model_best_optimized.pth')
        elif is_final:
            filename = os.path.join(output_dir, 'amp_flow_model_final_optimized.pth')
        else:
            filename = os.path.join(output_dir, f'amp_flow_checkpoint_optimized_step_{step:04d}.pth')

        # NOTE(review): when torch.compile succeeded, self.flow_model is the
        # compiled wrapper and its state_dict keys carry an "_orig_mod."
        # prefix; loaders must strip it or load into a compiled model.
        checkpoint = {
            'step': step,
            'global_step': global_step,
            'loss': loss,
            'flow_model_state_dict': self.flow_model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'stats': self.stats,
            'total_samples': len(self.cfg_dataset),
            'config': {
                'batch_size': BATCH_SIZE,
                'epochs': EPOCHS,
                'base_lr': BASE_LR,
                'lr_min': LR_MIN,
                'warmup_steps': WARMUP_STEPS,
                'mixed_precision': USE_MIXED_PRECISION,
                'gradient_clip': GRADIENT_CLIP_NORM,
                'weight_decay': WEIGHT_DECAY
            }
        }

        torch.save(checkpoint, filename)
        print(f"✓ Checkpoint saved: {filename} (loss: {loss:.6f}, step: {global_step})")
| |
|
def main():
    """Parse command-line options and run one full training session.

    ``--batch_size`` and ``--epochs`` overwrite the module-level
    ``BATCH_SIZE`` / ``EPOCHS`` globals. The trainer's methods read those
    globals when they run, so they are set before the trainer is built.
    """
    global BATCH_SIZE, EPOCHS

    parser = argparse.ArgumentParser(description='Optimized Single GPU AMP Flow Training with FULL DATA')
    parser.add_argument('--embeddings', default='/data2/edwardsun/flow_project/peptide_embeddings/',
                        help='Path to peptide embeddings directory')
    parser.add_argument('--cfg_data', default='/data2/edwardsun/flow_project/test_uniprot_processed/uniprot_processed_data.json',
                        help='Path to FULL CFG data file')
    parser.add_argument('--use_wandb', action='store_true', help='Use wandb for logging')
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=EPOCHS, help='Number of training epochs')

    args = parser.parse_args()

    # Reject non-positive values up front with a usage message rather than
    # letting setup or training fail later.
    if args.batch_size < 1:
        parser.error('--batch_size must be a positive integer')
    if args.epochs < 1:
        parser.error('--epochs must be a positive integer')

    BATCH_SIZE = args.batch_size
    EPOCHS = args.epochs

    print(f"Starting optimized training with batch_size={BATCH_SIZE}, epochs={EPOCHS}")

    trainer = AMPFlowTrainerSingleGPUFullData(args.embeddings, args.cfg_data, args.use_wandb)
    trainer.train_flow_matching()

    print("Optimized training completed successfully with FULL DATA!")
| |
|
if __name__ == '__main__':
    # Start training only when run as a script, never on import.
    main()