import json
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# ---------------------------------------------------------------------------
# Embedding geometry
# ---------------------------------------------------------------------------
ESM_DIM = 1280                      # width of each input embedding vector
COMP_RATIO = 16                     # channel compression factor
COMP_DIM = ESM_DIM // COMP_RATIO    # latent width emitted by the compressor
MAX_SEQ_LEN = 50                    # expected sequence length (dataset only warns on mismatch)

# ---------------------------------------------------------------------------
# Optimisation schedule
# ---------------------------------------------------------------------------
BATCH_SIZE = 32
EPOCHS = 30
BASE_LR = 1e-3                      # peak learning rate, reached after warm-up
LR_MIN = 8e-5                       # floor of the cosine decay
WARMUP_STEPS = 10_000               # optimizer steps of linear warm-up

# ---------------------------------------------------------------------------
# Transformer architecture (shared by encoder and decoder)
# ---------------------------------------------------------------------------
DEPTH = 4                           # each encoder stack uses DEPTH // 2 layers
HEADS = 8
DIM_FF = 4 * ESM_DIM                # feed-forward hidden width
POOLING = True                      # average adjacent token pairs in the bottleneck

class PrecomputedEmbeddingDataset(Dataset):
    """In-memory dataset of pre-computed per-sequence embedding matrices.

    All embeddings are loaded eagerly onto the CPU and stacked into
    ``self.embeddings`` with shape ``(num_sequences, seq_len, dim)``;
    ``dataset[i]`` returns the ``(seq_len, dim)`` matrix of sequence ``i``.
    """

    def __init__(self, embeddings_path):
        """
        Load pre-computed embeddings from the final_sequence_encoder.py output.

        Args:
            embeddings_path: Either a directory containing individual ``.pt``
                files (each expected to hold one 2-D ``(seq_len, dim)``
                tensor), or a single ``.pt`` file holding a 3-D tensor, a
                2-D tensor, or a list/tuple of 2-D tensors.

        Raises:
            ValueError: if no valid 2-D embedding is found, if the embeddings
                do not all share one shape (they must be stackable), or if a
                single file holds an unsupported object.
        """
        print(f"Loading pre-computed embeddings from {embeddings_path}...")

        path = Path(embeddings_path)
        if path.is_file():
            # The CLI default points at a single .pt file; globbing "*.pt"
            # inside a file path finds nothing, so load the file directly.
            self.embeddings = self._load_file(path)
        else:
            # A missing path also lands here: the glob finds no files and
            # _stack raises the "No valid embeddings" ValueError.
            self.embeddings = self._load_directory(path)

        print(f"Loaded {len(self.embeddings)} embeddings with shape {self.embeddings.shape}")

        # Mismatches against the configured geometry are reported but not fatal.
        if self.embeddings.shape[1] != MAX_SEQ_LEN:
            print(f"Warning: Expected sequence length {MAX_SEQ_LEN}, got {self.embeddings.shape[1]}")

        if self.embeddings.shape[2] != ESM_DIM:
            print(f"Warning: Expected embedding dim {ESM_DIM}, got {self.embeddings.shape[2]}")

    @staticmethod
    def _shape_of(obj):
        """Return a tensor's shape, or the type name of anything else (for warnings)."""
        return obj.shape if isinstance(obj, torch.Tensor) else type(obj).__name__

    @staticmethod
    def _stack(embeddings_list, source):
        """Stack equally-shaped 2-D tensors into one 3-D tensor, with clear errors."""
        if not embeddings_list:
            raise ValueError(f"No valid embeddings found in {source}!")
        ref_shape = embeddings_list[0].shape
        for idx, emb in enumerate(embeddings_list):
            if emb.shape != ref_shape:
                # torch.stack would fail here anyway, with a far less helpful message.
                raise ValueError(
                    f"All embeddings must share shape {tuple(ref_shape)} to be stacked, "
                    f"but entry {idx} has shape {tuple(emb.shape)}"
                )
        return torch.stack(embeddings_list)

    @classmethod
    def _load_file(cls, path):
        """Load one ``.pt`` file and return its embeddings as a 3-D CPU tensor."""
        data = torch.load(path, map_location="cpu")
        if isinstance(data, torch.Tensor):
            if data.dim() == 3:
                return data  # already (num_sequences, seq_len, dim); avoid a copy
            data = [data]
        if not isinstance(data, (list, tuple)):
            raise ValueError(
                f"Unsupported content in {path}: expected a tensor or a list/tuple "
                f"of tensors, got {type(data).__name__}"
            )
        kept = []
        for idx, item in enumerate(data):
            if isinstance(item, torch.Tensor) and item.dim() == 2:
                kept.append(item)
            else:
                print(f"Warning: Skipping entry {idx} of {path} - unexpected shape {cls._shape_of(item)}")
        return cls._stack(kept, path)

    @classmethod
    def _load_directory(cls, path):
        """Load every ``*.pt`` file in ``path`` and stack the 2-D tensors found."""
        # Sorted so the sample order is reproducible; glob order depends on the filesystem.
        embedding_files = sorted(path.glob("*.pt"))
        print(f"Found {len(embedding_files)} embedding files")

        embeddings_list = []
        for file_path in embedding_files:
            try:
                # map_location keeps the whole dataset on the CPU even if the
                # tensors were saved from a GPU.
                embedding = torch.load(file_path, map_location="cpu")
            except Exception as e:  # best-effort: one unreadable file should not abort loading
                print(f"Warning: Could not load {file_path}: {e}")
                continue
            if isinstance(embedding, torch.Tensor) and embedding.dim() == 2:
                embeddings_list.append(embedding)
            else:
                print(f"Warning: Skipping {file_path} - unexpected shape {cls._shape_of(embedding)}")
        return cls._stack(embeddings_list, path)

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        return self.embeddings[idx]

class Compressor(nn.Module):
    """Encode embedding sequences into a narrow latent sequence.

    Pipeline: optional normalisation with dataset statistics -> LayerNorm ->
    ``DEPTH // 2`` transformer encoder layers -> optional pairwise average
    pooling along the sequence axis -> ``DEPTH // 2`` more encoder layers ->
    LayerNorm + Linear projection to ``out_dim`` followed by ``tanh``.
    """

    def __init__(self, in_dim=ESM_DIM, out_dim=COMP_DIM):
        super().__init__()

        def make_stack():
            prototype = nn.TransformerEncoderLayer(
                d_model=in_dim,
                nhead=HEADS,
                dim_feedforward=DIM_FF,
                batch_first=True,
            )
            return nn.TransformerEncoder(prototype, num_layers=DEPTH // 2)

        # Submodules are created in this order (norm, pre_tr, post_tr, proj)
        # so random initialisation draws from the RNG in a fixed sequence.
        self.norm = nn.LayerNorm(in_dim)
        self.pre_tr = make_stack()
        self.post_tr = make_stack()
        self.proj = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, out_dim),
            nn.Tanh(),
        )
        self.pooling = POOLING

    def forward(self, x, stats=None):
        """Compress ``x`` into the latent space.

        Args:
            x: input embeddings; the pooling branch unpacks exactly three
                dims, presumably ``(batch, seq_len, in_dim)``.
            stats: optional mapping with ``'mean'``, ``'std'``, ``'min'`` and
                ``'max'`` tensors. When truthy, ``x`` is z-scored and clipped
                to [-4, 4], then min-max scaled and clipped to [0, 1].

        Returns:
            Tensor whose last dimension is ``out_dim`` and whose values lie in
            [-1, 1] (tanh). With pooling enabled the sequence axis is halved;
            an odd-length sequence drops its final position first.
        """
        if stats:
            mean, std, lo, hi = (
                stats[key].to(x.device) for key in ('mean', 'std', 'min', 'max')
            )
            x = torch.clamp((x - mean) / std, -4, 4)
            x = torch.clamp((x - lo) / (hi - lo + 1e-8), 0, 1)

        h = self.pre_tr(self.norm(x))

        if self.pooling:
            batch, length, width = h.shape
            if length % 2:
                h = h[:, :-1, :]
            # Average each adjacent pair of positions along the sequence axis.
            h = h.view(batch, length // 2, 2, width).mean(2)

        return self.proj(self.post_tr(h))

class Decompressor(nn.Module):
    """Map latent sequences back to the embedding width.

    LayerNorm + Linear lift from ``in_dim`` to ``out_dim``, optional 2x
    upsampling along the sequence axis (every position repeated twice), then
    ``DEPTH // 2`` transformer encoder layers.
    """

    def __init__(self, in_dim=COMP_DIM, out_dim=ESM_DIM):
        super().__init__()
        lift_norm = nn.LayerNorm(in_dim)
        lift_linear = nn.Linear(in_dim, out_dim)
        self.proj = nn.Sequential(lift_norm, lift_linear)

        prototype = nn.TransformerEncoderLayer(
            d_model=out_dim,
            nhead=HEADS,
            dim_feedforward=DIM_FF,
            batch_first=True,
        )
        self.decoder = nn.TransformerEncoder(prototype, num_layers=DEPTH // 2)
        self.pooling = POOLING

    def forward(self, z):
        """Reconstruct ``out_dim``-wide sequences from latent ``z``.

        With pooling enabled, each latent position is repeated twice along
        dim 1, mirroring the pair-averaging done by ``Compressor``.
        """
        h = self.proj(z)
        h = h.repeat_interleave(2, dim=1) if self.pooling else h
        return self.decoder(h)

def _compute_normalization_stats(embeddings):
    """Return the per-feature statistics consumed by ``Compressor.forward``.

    ``mean`` and ``std`` (+1e-8) are taken over every token of every sequence;
    ``min``/``max`` are the extremes of the z-scored values after clipping to
    [-4, 4].
    """
    flat = embeddings.reshape(-1, embeddings.shape[-1])
    mean = flat.mean(0)
    std = flat.std(0) + 1e-8
    # Materialise the normalised copy once and derive both extremes from it.
    z = torch.clamp((flat - mean) / std, -4, 4)
    return {
        'mean': mean,
        'std': std,
        'min': z.min(0).values,
        'max': z.max(0).values,
    }


def train_with_precomputed_embeddings(embeddings_path, device='cuda'):
    """
    Train compressor using pre-computed embeddings from final_sequence_encoder.py

    Fits a Compressor/Decompressor pair end to end with an MSE reconstruction
    loss. Files written to the current working directory:

    * ``normalization_stats.pt``: feature statistics fed to the compressor.
    * ``checkpoint_epoch_{N}.pth`` every 5 epochs: model, optimizer and
      scheduler states, the epoch number and its mean loss.
    * ``compressor_final.pth`` / ``decompressor_final.pth``: final weights.

    Args:
        embeddings_path: forwarded to ``PrecomputedEmbeddingDataset``.
        device: torch device (or device string) to train on.

    Raises:
        ValueError: if the embedding width differs from ``ESM_DIM`` (the
            model is built for that width), or as raised by the dataset.
    """
    ds = PrecomputedEmbeddingDataset(embeddings_path)

    # Fail early: reshaping to a wrong width would silently scramble features.
    if ds.embeddings.shape[-1] != ESM_DIM:
        raise ValueError(
            f"Embedding width {ds.embeddings.shape[-1]} does not match ESM_DIM={ESM_DIM}"
        )

    print("Computing normalization statistics...")
    stats = _compute_normalization_stats(ds.embeddings)
    torch.save(stats, 'normalization_stats.pt')
    print("Saved normalization statistics to normalization_stats.pt")

    # Move once; Compressor.forward calls .to(x.device) on every batch, which
    # is then a no-op instead of a host-to-device copy.
    device_stats = {key: value.to(device) for key, value in stats.items()}

    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)
    steps_per_epoch = len(dl)

    comp = Compressor().to(device)
    decomp = Decompressor().to(device)

    opt = optim.AdamW(list(comp.parameters()) + list(decomp.parameters()), lr=BASE_LR)

    total_steps = EPOCHS * steps_per_epoch
    if total_steps <= WARMUP_STEPS:
        print(f"Warning: only {total_steps} optimizer steps scheduled but WARMUP_STEPS="
              f"{WARMUP_STEPS}; the learning rate will not finish warming up")
    warmup_sched = LinearLR(opt, start_factor=1e-8, end_factor=1.0, total_iters=WARMUP_STEPS)
    # The cosine phase only gets the steps left after warm-up, so T_max must
    # exclude them or the LR never decays all the way to LR_MIN.
    cosine_sched = CosineAnnealingLR(opt, T_max=max(1, total_steps - WARMUP_STEPS), eta_min=LR_MIN)
    sched = SequentialLR(opt, [warmup_sched, cosine_sched], milestones=[WARMUP_STEPS])

    print(f"Starting training for {EPOCHS} epochs...")
    print(f"Device: {device}")
    print(f"Batch size: {BATCH_SIZE}")
    print(f"Total batches per epoch: {steps_per_epoch}")

    for epoch in range(1, EPOCHS + 1):
        total_loss = 0.0
        comp.train()
        decomp.train()

        for batch_idx, x in enumerate(tqdm(dl, desc=f"Epoch {epoch}/{EPOCHS}")):
            x = x.to(device)
            z = comp(x, device_stats)
            xr = decomp(z)
            # NOTE(review): the target is the raw embedding while the
            # compressor sees the normalised one, so the decompressor must
            # learn to undo the normalisation. Confirm this is intended.
            loss = (x - xr).pow(2).mean()

            opt.zero_grad()
            loss.backward()
            opt.step()
            sched.step()

            loss_value = loss.item()
            total_loss += loss_value

            if batch_idx % 100 == 0:
                # tqdm.write avoids breaking the active progress bar.
                tqdm.write(f" Batch {batch_idx}/{steps_per_epoch} - Loss: {loss_value:.6f}")

        avg_loss = total_loss / steps_per_epoch
        print(f"Epoch {epoch}/{EPOCHS} — Average MSE: {avg_loss:.6f}")

        if epoch % 5 == 0:
            torch.save({
                'epoch': epoch,
                'compressor_state_dict': comp.state_dict(),
                'decompressor_state_dict': decomp.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                # Needed to resume the warm-up/cosine schedule at the right step.
                'scheduler_state_dict': sched.state_dict(),
                'loss': avg_loss,
            }, f'checkpoint_epoch_{epoch}.pth')

    torch.save(comp.state_dict(), 'compressor_final.pth')
    torch.save(decomp.state_dict(), 'decompressor_final.pth')
    print("Training completed! Models saved as compressor_final.pth and decompressor_final.pth")

def load_and_test_models(compressor_path, decompressor_path, embeddings_path, device='cuda',
                         stats_path='normalization_stats.pt', batch_size=16):
    """
    Load trained models and test reconstruction quality

    Args:
        compressor_path: path to a saved ``Compressor`` state dict.
        decompressor_path: path to a saved ``Decompressor`` state dict.
        embeddings_path: forwarded to ``PrecomputedEmbeddingDataset``.
        device: torch device (or device string) to evaluate on.
        stats_path: normalisation statistics written during training.
        batch_size: evaluation batch size.

    Returns:
        The per-sample-weighted mean squared reconstruction error (float).
    """
    print("Loading trained models...")
    comp = Compressor().to(device)
    decomp = Decompressor().to(device)

    # map_location lets weights saved from a GPU be restored on the requested
    # device (e.g. CPU when CUDA is unavailable).
    comp.load_state_dict(torch.load(compressor_path, map_location=device))
    decomp.load_state_dict(torch.load(decompressor_path, map_location=device))

    comp.eval()
    decomp.eval()

    ds = PrecomputedEmbeddingDataset(embeddings_path)
    test_loader = DataLoader(ds, batch_size=batch_size, shuffle=False)

    stats = torch.load(stats_path, map_location=device)

    print("Testing reconstruction quality...")
    total_mse = 0.0
    total_samples = 0

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            x = batch.to(device)
            xr = decomp(comp(x, stats))
            mse = (x - xr).pow(2).mean()
            # Weight each batch mean by its size so a short final batch is not over-counted.
            total_mse += mse.item() * len(x)
            total_samples += len(x)

    avg_mse = total_mse / total_samples
    print(f"Average reconstruction MSE: {avg_mse:.6f}")

    return avg_mse

if __name__ == '__main__':
    import argparse

    # Command-line entry point: train by default, or evaluate saved models with --test.
    parser = argparse.ArgumentParser(
        description='Train protein compressor with pre-computed embeddings'
    )
    parser.add_argument(
        '--embeddings',
        type=str,
        default='/data2/edwardsun/flow_project/compressor_dataset/peptide_embeddings.pt',
        help='Path to pre-computed embeddings from final_sequence_encoder.py',
    )
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device to use (cuda/cpu)')
    parser.add_argument('--test', action='store_true',
                        help='Test existing models instead of training')
    args = parser.parse_args()

    # Whatever --device says, fall back to CPU when CUDA is not available.
    requested_device = args.device if torch.cuda.is_available() else 'cpu'
    device = torch.device(requested_device)
    print(f"Using device: {device}")

    if not args.test:
        train_with_precomputed_embeddings(args.embeddings, device)
    else:
        load_and_test_models('compressor_final.pth', 'decompressor_final.pth',
                             args.embeddings, device)