import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import numpy as np
from tqdm import tqdm
import json
import os
import argparse

from compressor_with_embeddings import Compressor, Decompressor, PrecomputedEmbeddingDataset
from final_flow_model import AMPFlowMatcherCFGConcat, SinusoidalTimeEmbedding
from cfg_dataset import CFGFlowDataset, create_cfg_dataloader

ESM_DIM = 1280
COMP_RATIO = 16
COMP_DIM = ESM_DIM // COMP_RATIO
MAX_SEQ_LEN = 50
BATCH_SIZE = 64
EPOCHS = 5000
BASE_LR = 1e-4
LR_MIN = 2e-5
WARMUP_STEPS = 100

def setup_distributed():
    """Setup distributed training."""
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        local_rank = int(os.environ['LOCAL_RANK'])
    else:
        print('Not using distributed mode')
        return None, None, None

    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    dist.barrier()

    return rank, world_size, local_rank

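# Example launch (a sketch; the script file name below is a placeholder, substitute
# the actual file name). `torchrun` exports the RANK / WORLD_SIZE / LOCAL_RANK
# environment variables that setup_distributed() reads:
#
#   torchrun --nproc_per_node=4 train_amp_flow_multigpu.py \
#       --cfg_data_path /data2/edwardsun/flow_project/test_uniprot_processed/uniprot_processed_data.json
#
# The older torch.distributed.launch utility also sets RANK and WORLD_SIZE and passes
# a --local_rank argument, which is why main() still accepts that flag.
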
class AMPFlowTrainerMultiGPU:
    """
    Multi-GPU training pipeline for AMP generation using ProtFlow methodology.
    """

    def __init__(self, embeddings_path, cfg_data_path, rank, world_size, local_rank):
        self.rank = rank
        self.world_size = world_size
        self.local_rank = local_rank
        self.device = torch.device(f'cuda:{local_rank}')
        self.embeddings_path = embeddings_path
        self.cfg_data_path = cfg_data_path

        if self.rank == 0:
            print(f"Loading ALL AMP embeddings from {embeddings_path}...")

        combined_path = os.path.join(embeddings_path, "all_peptide_embeddings.pt")

        if os.path.exists(combined_path):
            print(f"Loading combined embeddings from {combined_path} (FULL DATA)...")
            self.embeddings = torch.load(combined_path, map_location=self.device)
            print(f"✓ Loaded ALL embeddings: {self.embeddings.shape}")
        else:
            print("Combined embeddings file not found, loading individual files...")

            import glob

            embedding_files = glob.glob(os.path.join(embeddings_path, "*.pt"))
            embedding_files = [f for f in embedding_files if not f.endswith('metadata.json') and not f.endswith('sequence_ids.json') and not f.endswith('all_peptide_embeddings.pt')]

            print(f"Found {len(embedding_files)} individual embedding files")

            embeddings_list = []
            for file_path in embedding_files:
                try:
                    embedding = torch.load(file_path)
                    if embedding.dim() == 2:
                        embeddings_list.append(embedding)
                    else:
                        print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")
                except Exception as e:
                    print(f"Warning: Could not load {file_path}: {e}")

            if not embeddings_list:
                raise ValueError("No valid embeddings found!")

            self.embeddings = torch.stack(embeddings_list)
            print(f"Loaded {len(self.embeddings)} embeddings from individual files")

        # Compute normalization statistics on the main process only, then
        # broadcast them to the other ranks below.
        if self.rank == 0:
            print("Computing preprocessing statistics...")
            self._compute_preprocessing_stats()

        if self.rank == 0:
            stats_tensor = torch.stack([
                self.stats['mean'], self.stats['std'],
                self.stats['min'], self.stats['max']
            ]).to(self.device)
        else:
            stats_tensor = torch.zeros(4, ESM_DIM, device=self.device)

        dist.broadcast(stats_tensor, src=0)

        if self.rank != 0:
            self.stats = {
                'mean': stats_tensor[0],
                'std': stats_tensor[1],
                'min': stats_tensor[2],
                'max': stats_tensor[3]
            }

        self._initialize_models()

    def _compute_preprocessing_stats(self):
        """Compute feature-wise preprocessing statistics (called on the main process only)."""
        flat = self.embeddings.view(-1, ESM_DIM)

        # Per-feature z-score statistics
        feat_mean = flat.mean(0)
        feat_std = flat.std(0) + 1e-8

        # Clamp standardized values to suppress outliers before min-max scaling
        z_score_normalized = (flat - feat_mean) / feat_std
        z_score_clamped = torch.clamp(z_score_normalized, -4, 4)

        # Min/max of the clamped values, used to scale features into [0, 1]
        feat_min = z_score_clamped.min(0)[0]
        feat_max = z_score_clamped.max(0)[0]

        self.stats = {
            'mean': feat_mean,
            'std': feat_std,
            'min': feat_min,
            'max': feat_max
        }

        # Persist the statistics so generation scripts can invert the preprocessing
        torch.save(self.stats, 'normalization_stats.pt')
        if self.rank == 0:
            print("✓ Preprocessing statistics computed and saved to normalization_stats.pt")

    def _initialize_models(self):
        """Initialize models for distributed training."""
        self.compressor = Compressor().to(self.device)
        self.decompressor = Decompressor().to(self.device)

        self.compressor.load_state_dict(torch.load('final_compressor_model.pth', map_location=self.device))
        self.decompressor.load_state_dict(torch.load('final_decompressor_model.pth', map_location=self.device))

        self.flow_model = AMPFlowMatcherCFGConcat(
            hidden_dim=480,
            compressed_dim=COMP_DIM,
            n_layers=12,
            n_heads=16,
            dim_ff=3072,
            max_seq_len=25,
            use_cfg=True
        ).to(self.device)

        self.flow_model = DDP(self.flow_model, device_ids=[self.local_rank], find_unused_parameters=True)

        if self.rank == 0:
            print("✓ Initialized models for distributed training")
            print(f" - Flow model parameters: {sum(p.numel() for p in self.flow_model.parameters()):,}")
            print(f" - Using {self.world_size} GPUs")

    def _preprocess_batch(self, batch):
        """Apply preprocessing to a batch of embeddings."""
        # Per-feature z-score normalization
        h_norm = (batch - self.stats['mean'].to(batch.device)) / self.stats['std'].to(batch.device)

        # Clamp outliers
        h_trunc = torch.clamp(h_norm, min=-4.0, max=4.0)

        # Min-max scale into [0, 1] using the precomputed statistics
        h_min = self.stats['min'].to(batch.device)
        h_max = self.stats['max'].to(batch.device)
        h_scaled = (h_trunc - h_min) / (h_max - h_min + 1e-8)
        h_scaled = torch.clamp(h_scaled, 0.0, 1.0)

        return h_scaled

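    # For reference (a rough sketch, mirroring what _validate does below): the inverse
    # of this preprocessing, used to map generated embeddings back to ESM space, is
    #
    #   h = h_scaled * (stats['max'] - stats['min'] + 1e-8) + stats['min']  # undo min-max
    #   h = h * stats['std'] + stats['mean']                                # undo z-score
    #
    # The clamping steps above are lossy, so this round trip is only approximate for
    # outlier features.
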
    def train_flow_matching(self):
        """Train the flow matching model using distributed training."""
        if self.rank == 0:
            print("Step 3: Training Flow Matching model (Multi-GPU)...")

        try:
            dataset = CFGFlowDataset(
                embeddings_path=self.embeddings_path,
                cfg_data_path=self.cfg_data_path,
                use_masked_labels=True,
                max_seq_len=MAX_SEQ_LEN,
                device=self.device
            )
            print("✓ Using CFG dataset with real labels")
        except Exception as e:
            print(f"Warning: Could not load CFG dataset: {e}")
            print("Falling back to random labels (not recommended for CFG)")
            dataset = PrecomputedEmbeddingDataset(self.embeddings_path)

        sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank)
        dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=4)

        optimizer = optim.AdamW(
            self.flow_model.parameters(),
            lr=BASE_LR,
            betas=(0.9, 0.98),
            weight_decay=0.01,
            eps=1e-6
        )

        warmup_sched = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=WARMUP_STEPS)
        cosine_sched = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=LR_MIN)
        scheduler = SequentialLR(optimizer, [warmup_sched, cosine_sched], milestones=[WARMUP_STEPS])

        self.flow_model.train()
        total_steps = 0

        if self.rank == 0:
            print(f"Starting training for {EPOCHS} iterations with FULL DATA...")
            print(f"Total batch size: {BATCH_SIZE * self.world_size}")
            print(f"Steps per epoch: {len(dataloader)}")
            print(f"Total samples: {len(dataset):,}")
            print("Estimated time: ~30-45 minutes (using ALL data)")

        for epoch in range(EPOCHS):
            sampler.set_epoch(epoch)

            for batch_idx, batch_data in enumerate(dataloader):
                if isinstance(batch_data, dict) and 'embeddings' in batch_data:
                    # CFG dataset: real embeddings with real labels
                    x = batch_data['embeddings'].to(self.device)
                    labels = batch_data['labels'].to(self.device)
                else:
                    # Fallback dataset: assign random labels
                    x = batch_data.to(self.device)
                    labels = torch.randint(0, 3, (x.shape[0],), device=self.device)

                batch_size = x.shape[0]

                # Normalize embeddings and compress them into the latent space
                x_processed = self._preprocess_batch(x)
                with torch.no_grad():
                    z = self.compressor(x_processed, self.stats)

                # Flow matching: interpolate between data z (t = 0) and noise eps (t = 1)
                eps = torch.randn_like(z)
                t = torch.rand(batch_size, device=self.device)
                xt = t.view(batch_size, 1, 1) * eps + (1 - t.view(batch_size, 1, 1)) * z

                # Target velocity u_t = d x_t / dt = eps - z
                ut = eps - z

                # Predict the velocity field and regress it onto the target
                vt_pred = self.flow_model(xt, t, labels=labels)
                loss = ((vt_pred - ut) ** 2).mean()

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.flow_model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

                total_steps += 1

                if self.rank == 0 and total_steps % 10 == 0:
                    progress = (total_steps / EPOCHS) * 100
                    label_dist = torch.bincount(labels, minlength=3)
                    print(f"Step {total_steps}/{EPOCHS} ({progress:.1f}%): Loss = {loss.item():.6f}, LR = {scheduler.get_last_lr()[0]:.2e}, Labels: AMP={label_dist[0]}, Non-AMP={label_dist[1]}, Mask={label_dist[2]}")

                if self.rank == 0 and total_steps % 100 == 0:
                    self._save_checkpoint(total_steps, loss.item())

                if self.rank == 0 and total_steps % 200 == 0:
                    self._validate()

                # EPOCHS doubles as the total optimizer-step budget (it is also the
                # cosine schedule's T_max), so stop once that many steps have run.
                if total_steps >= EPOCHS:
                    break
            if total_steps >= EPOCHS:
                break

        if self.rank == 0:
            self._save_checkpoint(total_steps, loss.item(), is_final=True)
            print("✓ Flow matching training completed!")

    def _save_checkpoint(self, step, loss, is_final=False):
        """Save training checkpoint (only on main process)."""
        # Save the unwrapped module's weights so the checkpoint loads without DDP
        model_state_dict = self.flow_model.module.state_dict()

        checkpoint = {
            'step': step,
            'flow_model_state_dict': model_state_dict,
            'loss': loss,
        }

        if is_final:
            torch.save(checkpoint, 'amp_flow_model_final_full_data.pth')
            print("✓ Final model saved: amp_flow_model_final_full_data.pth")
        else:
            torch.save(checkpoint, f'amp_flow_checkpoint_full_data_step_{step}.pth')
            print(f"✓ Checkpoint saved: amp_flow_checkpoint_full_data_step_{step}.pth")

    def _validate(self):
        """Validate the model by generating a few samples."""
        print("Generating validation samples...")
        self.flow_model.eval()

        with torch.no_grad():
            # Start from pure noise in the compressed latent space (t = 1)
            eps = torch.randn(4, 25, COMP_DIM, device=self.device)
            xt = eps.clone()

            # Condition on the AMP label and integrate the ODE from t = 1 to t = 0.
            # Call the underlying module directly: this runs on rank 0 only, so the
            # DDP wrapper (and its collective ops) must be bypassed.
            labels = torch.full((4,), 0, device=self.device)
            n_steps = 25
            dt = 1.0 / n_steps
            for step in range(n_steps):
                t = torch.ones(4, device=self.device) * (1.0 - step / n_steps)
                vt = self.flow_model.module(xt, t, labels=labels)
                # The learned velocity is d x_t / dt = eps - z, so moving from noise
                # toward data (decreasing t) means stepping against it.
                xt = xt - vt * dt

            # Decompress back to the ESM embedding dimension
            decompressed = self.decompressor(xt)

            # Invert the preprocessing: undo min-max scaling, then the z-score
            m, s = self.stats['mean'].to(self.device), self.stats['std'].to(self.device)
            mn, mx = self.stats['min'].to(self.device), self.stats['max'].to(self.device)
            decompressed = decompressed * (mx - mn + 1e-8) + mn
            decompressed = decompressed * s + m

            print(f" Generated samples shape: {decompressed.shape}")
            print(f" Sample stats - Mean: {decompressed.mean():.4f}, Std: {decompressed.std():.4f}")

        self.flow_model.train()

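# Sketch only (not used by the training loop): classifier-free guidance at sampling
# time. This assumes label 2 is the "mask"/unconditional class, as suggested by the
# label printout in train_flow_matching, and that AMPFlowMatcherCFGConcat accepts it
# like any other label; the guidance weight `w` is a free parameter.
#
#   def guided_velocity(model, xt, t, cond_labels, w=2.0):
#       mask_labels = torch.full_like(cond_labels, 2)   # unconditional branch
#       v_cond = model(xt, t, labels=cond_labels)       # conditional velocity
#       v_uncond = model(xt, t, labels=mask_labels)     # unconditional velocity
#       return v_uncond + w * (v_cond - v_uncond)       # CFG combination
#
# _validate above uses the purely conditional velocity; a helper like this is how
# guidance would typically be applied at generation time.
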
def main():
    """Main training function with distributed setup."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--cfg_data_path', type=str, default='/data2/edwardsun/flow_project/test_uniprot_processed/uniprot_processed_data.json',
                        help='Path to FULL CFG training data with real labels')
    args = parser.parse_args()

    rank, world_size, local_rank = setup_distributed()
    if rank is None:
        print("This script requires a distributed launch (e.g. via torchrun). Exiting.")
        return

    if rank == 0:
        print("=== Multi-GPU AMP Flow Matching Training Pipeline with FULL DATA ===")
        print("This implements the complete ProtFlow methodology for AMP generation.")
        print("Training for 5,000 iterations (~30-45 minutes) using ALL available data.")
        print()

        required_files = [
            'final_compressor_model.pth',
            'final_decompressor_model.pth',
            '/data2/edwardsun/flow_project/peptide_embeddings/'
        ]

        for file in required_files:
            if not os.path.exists(file):
                print(f"❌ Missing required file: {file}")
                print("Please ensure you have:")
                print("1. Run final_sequence_encoder.py to generate embeddings")
                print("2. Run compressor_with_embeddings.py to train compressor/decompressor")
                return

        if not os.path.exists(args.cfg_data_path):
            print(f"⚠️ CFG data not found: {args.cfg_data_path}")
            print("Training will use random labels (not recommended for CFG)")
            print("To use real labels, run uniprot_data_processor.py first")
        else:
            print(f"✓ CFG data found: {args.cfg_data_path}")

        print("✓ All required files found!")
        print()

    trainer = AMPFlowTrainerMultiGPU(
        embeddings_path='/data2/edwardsun/flow_project/peptide_embeddings/',
        cfg_data_path=args.cfg_data_path,
        rank=rank,
        world_size=world_size,
        local_rank=local_rank
    )

    trainer.train_flow_matching()

    if rank == 0:
        print("\n=== Multi-GPU Training Complete with FULL DATA ===")
        print("Your AMP flow matching model was trained on ALL available data!")
        print("Next steps:")
        print("1. Test the model: python generate_amps.py")
        print("2. Compare performance with previous model")
        print("3. Implement reflow for 1-step generation")
        print("4. Add conditioning for toxicity (future project)")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
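
# Example (a minimal sketch) of reloading the final checkpoint for generation. The
# checkpoint stores the unwrapped module's weights under 'flow_model_state_dict'
# (see _save_checkpoint), so it can be loaded without DDP:
#
#   checkpoint = torch.load('amp_flow_model_final_full_data.pth', map_location='cpu')
#   flow_model = AMPFlowMatcherCFGConcat(
#       hidden_dim=480, compressed_dim=COMP_DIM, n_layers=12, n_heads=16,
#       dim_ff=3072, max_seq_len=25, use_cfg=True,
#   )
#   flow_model.load_state_dict(checkpoint['flow_model_state_dict'])
#   flow_model.eval()
#
# generate_amps.py (referenced in the "Next steps" above) presumably combines this
# kind of loading with a sampling loop like the one in _validate.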