import torch
import torch.nn as nn
import torch.nn.functional as F
import math
| |
|
class SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal time embedding as used in ProtFlow paper.

    Maps a batch of scalar times to ``(B, dim)`` embeddings using the
    standard transformer sin/cos frequency ladder.
    """

    def __init__(self, dim):
        """
        Args:
            dim: Output embedding dimension. Should be even and > 2; the
                first half carries sine components, the second half cosine.
        """
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """
        Args:
            time: Tensor of times in any shape that flattens to (B,),
                e.g. (B,), (B, 1), or (B, 1, 1).
        Returns:
            (B, dim) tensor of sinusoidal time embeddings.
        """
        device = time.device
        half_dim = self.dim // 2

        # Geometric frequency ladder spanning 1 .. 1/10000, matching the
        # original transformer positional-encoding formulation.
        scale = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -scale)

        # Flatten explicitly instead of bare .squeeze(): squeeze() also
        # collapses the batch dimension when B == 1, which made this
        # module return (dim,) instead of (1, dim) for single samples.
        time = time.reshape(-1)

        angles = time[:, None] * freqs[None, :]            # (B, half_dim)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)  # (B, dim)
| |
|
class LabelMLP(nn.Module):
    """
    MLP for processing class labels into embeddings.
    This approach processes labels separately from time embeddings.
    """

    def __init__(self, num_classes=3, hidden_dim=480, mlp_dim=256):
        """
        Args:
            num_classes: Size of the label vocabulary (default 3).
            hidden_dim: Output embedding width.
            mlp_dim: Width of the intermediate MLP layers.
        """
        super().__init__()
        self.num_classes = num_classes

        # Embedding lookup followed by a small MLP that lifts the label
        # representation up to the model width.
        stages = [
            nn.Embedding(num_classes, mlp_dim),
            nn.Linear(mlp_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.label_mlp = nn.Sequential(*stages)

        # Small-std initialization for the embedding table.
        nn.init.normal_(self.label_mlp[0].weight, std=0.02)

    def forward(self, labels):
        """
        Args:
            labels: (B,) tensor of class labels
                - 0: AMP (MIC < 100)
                - 1: Non-AMP (MIC >= 100)
                - 2: Mask (Unknown MIC)
        Returns:
            embeddings: (B, hidden_dim) tensor of processed label embeddings
        """
        return self.label_mlp(labels)
| |
|
class AMPFlowMatcherCFGConcat(nn.Module):
    """
    Flow Matching model with Classifier-Free Guidance using concatenation approach.
    - 12-layer transformer with skip connections feeding each layer's input forward
    - Time embedding + MLP-processed label embedding (concatenated then projected)
    - Optimized for peptide sequences (max length 50)
    """

    def __init__(self, hidden_dim=480, compressed_dim=30, n_layers=12, n_heads=16,
                 dim_ff=3072, dropout=0.1, max_seq_len=25, use_cfg=True):
        """
        Args:
            hidden_dim: Transformer model width.
            compressed_dim: Dimensionality of the compressed latent input/output.
            n_layers: Number of transformer encoder layers.
            n_heads: Attention heads per layer.
            dim_ff: Feed-forward width inside each encoder layer.
            dropout: Dropout probability in the encoder layers.
            max_seq_len: Maximum supported sequence length (positional table size).
            use_cfg: Build the label-conditioning pathway for classifier-free guidance.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.compressed_dim = compressed_dim
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.use_cfg = use_cfg

        # Sinusoidal time features refined by a two-layer MLP.
        self.time_embed = nn.Sequential(
            SinusoidalTimeEmbedding(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        if use_cfg:
            # Label branch plus a projection fusing the concatenated
            # (time, label) embedding back down to hidden_dim.
            self.label_mlp = LabelMLP(num_classes=3, hidden_dim=hidden_dim)
            self.condition_proj = nn.Sequential(
                nn.Linear(hidden_dim * 2, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, hidden_dim)
            )

        # Projections between the compressed latent space and model width.
        # NOTE: decompress_proj is not used in forward (output_proj is);
        # kept for state-dict compatibility with trained checkpoints.
        self.compress_proj = nn.Linear(compressed_dim, hidden_dim)
        self.decompress_proj = nn.Linear(hidden_dim, compressed_dim)

        # Learned positional embedding, shape (1, max_seq_len, hidden_dim).
        self.pos_embed = nn.Parameter(torch.randn(1, max_seq_len, hidden_dim))

        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=n_heads,
                dim_feedforward=dim_ff,
                dropout=dropout,
                activation='gelu',
                batch_first=True
            ) for _ in range(n_layers)
        ])

        # One projection per stored skip feature (layers 1..n-2 consume one).
        self.skip_projections = nn.ModuleList([
            nn.Linear(hidden_dim, hidden_dim) for _ in range(n_layers - 1)
        ])

        self.output_proj = nn.Linear(hidden_dim, compressed_dim)

    def forward(self, x, t, labels=None, mask=None):
        """
        Args:
            x: compressed latent (B, L, compressed_dim) - AMP embeddings
            t: time scalar (B,) or (B, 1)
            labels: class labels (B,) for CFG - 0=AMP, 1=Non-AMP, 2=Mask
            mask: attention mask (B, L) if needed (True marks padding positions)
        Returns:
            Predicted velocity field, shape (B, L, compressed_dim).
        Raises:
            ValueError: if L exceeds max_seq_len.
        """
        B, L, D = x.shape

        # Lift compressed latents into the transformer width.
        x = self.compress_proj(x)

        # Previously, sequences longer than max_seq_len silently skipped the
        # positional embedding; fail loudly instead of producing bad output.
        if L > self.max_seq_len:
            raise ValueError(
                f"sequence length {L} exceeds max_seq_len {self.max_seq_len}"
            )
        x = x + self.pos_embed[:, :L, :]

        # Normalize t to shape (B, 1) regardless of how it was passed,
        # replacing the fragile dim()/squeeze() chain.
        t = t.reshape(B, -1)[:, :1]

        t_emb = self.time_embed(t)
        # Force (B, hidden_dim): guards against the batch dimension being
        # squeezed away upstream when B == 1.
        t_emb = t_emb.reshape(B, self.hidden_dim)
        t_emb = t_emb.unsqueeze(1).expand(-1, L, -1)

        if self.use_cfg and labels is not None:
            # Broadcast the per-sample label embedding across positions,
            # concatenate with the time embedding, and project back down.
            label_emb = self.label_mlp(labels)
            label_emb = label_emb.unsqueeze(1).expand(-1, L, -1)

            combined_emb = torch.cat([t_emb, label_emb], dim=-1)
            projected_emb = self.condition_proj(combined_emb)
        else:
            projected_emb = t_emb

        # skip_features[i] stores the (pre-conditioning) input of layer i;
        # layer i+1 receives a projected copy added to its own input.
        skip_features = []

        for i, layer in enumerate(self.layers):
            if i > 0 and i < len(self.layers) - 1:
                skip_feat = skip_features[i - 1]
                skip_feat = self.skip_projections[i - 1](skip_feat)
                x = x + skip_feat

            if i < len(self.layers) - 1:
                skip_features.append(x.clone())

            # Conditioning is re-injected before every layer.
            x = x + projected_emb

            x = layer(x, src_key_padding_mask=mask)

        # Project back to the compressed latent space (velocity prediction).
        x = self.output_proj(x)

        return x
| |
|
class AMPProtFlowPipelineCFG:
    """
    Complete ProtFlow pipeline for AMP generation with CFG.
    """

    def __init__(self, compressor, decompressor, flow_model, device='cuda'):
        """
        Args:
            compressor: Encoder into the compressed latent space (kept for
                symmetry; not used during generation).
            decompressor: Maps compressed latents back to full embeddings.
            flow_model: Trained AMPFlowMatcherCFGConcat velocity model.
            device: Device string for sampling tensors.
        """
        self.compressor = compressor
        self.decompressor = decompressor
        self.flow_model = flow_model
        self.device = device

        # Normalization statistics saved during training.
        # Expected keys: 'mean', 'std', 'min', 'max'.
        self.stats = torch.load('normalization_stats.pt', map_location=device)

    @torch.no_grad()  # sampling is inference-only; without this the Euler
    # loop accumulates an autograd graph over every step and batch.
    def generate_amps_cfg(self, num_samples=100, num_steps=25, cfg_scale=7.5,
                          condition_label=0):
        """
        Generate AMP samples using CFG.

        Args:
            num_samples: Number of samples to generate
            num_steps: Number of ODE solving steps
            cfg_scale: CFG guidance scale (higher = stronger conditioning)
            condition_label: 0=AMP, 1=Non-AMP, 2=Mask
        Returns:
            Tuple of (generated, decompressed):
            generated: (num_samples, max_seq_len, compressed_dim) latents
            decompressed: denormalized output of the decompressor
        """
        print(f"Generating {num_samples} samples with CFG (label={condition_label}, scale={cfg_scale})...")

        batch_size = min(num_samples, 32)
        all_samples = []

        for i in range(0, num_samples, batch_size):
            current_batch = min(batch_size, num_samples - i)

            # Start from standard Gaussian noise at t = 1.
            eps = torch.randn(current_batch, self.flow_model.max_seq_len,
                              self.flow_model.compressed_dim, device=self.device)

            # Euler integration of the learned velocity field from t=1 to t=0.
            xt = eps.clone()
            for step in range(num_steps):
                t = torch.ones(current_batch, device=self.device) * (1.0 - step / num_steps)

                if cfg_scale > 0:
                    # Conditional velocity (desired label).
                    vt_cond = self.flow_model(
                        xt, t,
                        labels=torch.full((current_batch,), condition_label,
                                          device=self.device))

                    # Unconditional velocity (label 2 = mask/unknown).
                    vt_uncond = self.flow_model(
                        xt, t,
                        labels=torch.full((current_batch,), 2,
                                          device=self.device))

                    # Classifier-free guidance extrapolation.
                    vt = vt_uncond + cfg_scale * (vt_cond - vt_uncond)
                else:
                    vt = self.flow_model(
                        xt, t,
                        labels=torch.full((current_batch,), 2,
                                          device=self.device))

                # Negative step: time decreases from 1 toward 0.
                dt = -1.0 / num_steps
                xt = xt + vt * dt

            all_samples.append(xt)

        generated = torch.cat(all_samples, dim=0)

        # Decode latents and undo the training-time normalization.
        decompressed = self.decompressor(generated)

        # Reverses min-max scaling first, then standardization — assumes the
        # forward pass standardized then min-max scaled with the same 1e-8
        # epsilon. NOTE(review): confirm against the training normalization code.
        m, s, mn, mx = self.stats['mean'], self.stats['std'], self.stats['min'], self.stats['max']
        decompressed = decompressed * (mx - mn + 1e-8) + mn
        decompressed = decompressed * s + m

        return generated, decompressed
| |
|
| | |
if __name__ == "__main__":
    # Smoke-test: build the model and run one forward pass on dummy data.
    model = AMPFlowMatcherCFGConcat(
        hidden_dim=480,
        compressed_dim=30,
        n_layers=12,
        n_heads=16,
        dim_ff=3072,
        max_seq_len=25,
        use_cfg=True,
    )

    n_params = sum(p.numel() for p in model.parameters())
    print(f"FINAL AMP Flow Model with CFG (Concat+Proj) parameters: {n_params:,}")

    # Dummy batch: 4 sequences of length 20 in the 30-dim compressed space.
    demo_x = torch.randn(4, 20, 30)
    demo_t = torch.rand(4)
    demo_labels = torch.randint(0, 3, (4,))

    with torch.no_grad():
        demo_out = model(demo_x, demo_t, labels=demo_labels)

    print(f"Input shape: {demo_x.shape}")
    print(f"Output shape: {demo_out.shape}")
    print(f"Time embedding shape: {demo_t.shape}")
    print(f"Labels: {demo_labels}")

    print("🎯 FINAL AMP Flow Model with CFG (Concat+Proj) ready for training!")