File size: 9,517 Bytes
370f342
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import os
from datetime import datetime

# Import your components
from compressor_with_embeddings import Compressor, Decompressor
from final_flow_model import AMPFlowMatcherCFGConcat, AMPProtFlowPipelineCFG

class AMPGenerator:
    """
    Generate AMP samples using trained ProtFlow model.

    Loads the compressor/decompressor pair, the CFG flow-matching model and the
    normalization statistics. Sampling Euler-integrates the model's velocity
    field backward from Gaussian noise (t=1) to t=0, decompresses the resulting
    latents and reverses the preprocessing to return embeddings.
    """

    # Default artifact locations (previously hard-coded inside the methods).
    DEFAULT_COMPRESSOR_PATH = '/data2/edwardsun/flow_amp/models/final_compressor_model.pth'
    DEFAULT_DECOMPRESSOR_PATH = '/data2/edwardsun/flow_amp/models/final_decompressor_model.pth'
    DEFAULT_STATS_PATH = 'normalization_stats.pt'

    # Latent geometry: used both to configure the flow model and to shape the
    # sampling noise, so the two values cannot drift apart.
    LATENT_SEQ_LEN = 25
    COMPRESSED_DIM = 80  # 1280 // 16

    # Conditioning labels passed to the flow model.
    AMP_LABEL = 0   # 0 = AMP
    MASK_LABEL = 2  # 2 = Mask (the unconditional branch of CFG)

    # Prefix carried by parameter names saved from a compiled model wrapper.
    _COMPILE_PREFIX = '_orig_mod.'

    def __init__(self, model_path, device='cuda', *,
                 stats_path=DEFAULT_STATS_PATH,
                 compressor_path=DEFAULT_COMPRESSOR_PATH,
                 decompressor_path=DEFAULT_DECOMPRESSOR_PATH):
        """
        Load all networks and normalization statistics onto ``device``.

        Args:
            model_path: Flow-model checkpoint; must contain the keys
                'flow_model_state_dict', 'step' and 'loss'.
            device: Torch device string for models and statistics.
            stats_path: Normalization statistics file; a mapping with the keys
                'mean', 'std', 'min' and 'max'.
            compressor_path: state_dict file for the compressor.
            decompressor_path: state_dict file for the decompressor.
        """
        self.device = device

        # Load models
        self._load_models(model_path, compressor_path, decompressor_path)

        # Load preprocessing statistics (used to undo normalization after decoding)
        self.stats = torch.load(stats_path, map_location=device)

    @classmethod
    def _strip_compile_prefix(cls, state_dict):
        """Return a copy of ``state_dict`` with a leading '_orig_mod.' removed from each key."""
        prefix = cls._COMPILE_PREFIX
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def _load_models(self, model_path,
                     compressor_path=DEFAULT_COMPRESSOR_PATH,
                     decompressor_path=DEFAULT_DECOMPRESSOR_PATH):
        """Instantiate all networks on ``self.device`` and load their trained weights."""
        print("Loading trained models...")

        # Load compressor and decompressor
        self.compressor = Compressor().to(self.device)
        self.decompressor = Decompressor().to(self.device)

        self.compressor.load_state_dict(torch.load(compressor_path, map_location=self.device))
        self.decompressor.load_state_dict(torch.load(decompressor_path, map_location=self.device))

        # Load flow matching model with CFG
        self.flow_model = AMPFlowMatcherCFGConcat(
            hidden_dim=480,
            compressed_dim=self.COMPRESSED_DIM,
            n_layers=12,
            n_heads=16,
            dim_ff=3072,
            max_seq_len=self.LATENT_SEQ_LEN,
            use_cfg=True,
        ).to(self.device)

        checkpoint = torch.load(model_path, map_location=self.device)

        # Handle PyTorch compilation wrapper: weights saved from a compiled
        # model may carry the '_orig_mod.' prefix on their keys.
        self.flow_model.load_state_dict(
            self._strip_compile_prefix(checkpoint['flow_model_state_dict'])
        )

        print(f"✓ All models loaded successfully from step {checkpoint['step']}!")
        print(f"  Loss at checkpoint: {checkpoint['loss']:.6f}")

    def generate_amps(self, num_samples=100, num_steps=25, batch_size=32, cfg_scale=7.5):
        """
        Generate AMP samples using flow matching with CFG.

        Args:
            num_samples: Number of AMP samples to generate (>= 1).
            num_steps: Number of Euler ODE steps (25 for good quality, 1 for reflow); >= 1.
            batch_size: Batch size for generation (>= 1).
            cfg_scale: CFG guidance scale (higher = stronger conditioning). Values
                <= 0 skip the conditional pass and use only the mask label.

        Returns:
            CPU tensor of de-normalized embeddings, one row per sample
            (shape [num_samples, L, ESM_DIM] per the decompressor's output).

        Raises:
            ValueError: If num_samples, num_steps or batch_size is below 1.
        """
        # Validate up front: these used to surface as obscure torch.cat / range errors.
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        print(f"Generating {num_samples} AMP samples with {num_steps} steps (CFG scale: {cfg_scale})...")

        self.flow_model.eval()
        self.compressor.eval()
        self.decompressor.eval()

        # Loop invariants: the reverse-preprocessing statistics and the Euler step.
        m, s = self.stats['mean'], self.stats['std']
        mn, mx = self.stats['min'], self.stats['max']
        # Negative dt integrates backward in time, from noise (t=1) to data (t=0).
        dt = -1.0 / num_steps

        all_generated = []

        with torch.no_grad():
            for start in tqdm(range(0, num_samples, batch_size), desc="Generating"):
                current_batch = min(batch_size, num_samples - start)

                # Sample random noise: [B, L', COMP_DIM]
                xt = torch.randn(current_batch, self.LATENT_SEQ_LEN, self.COMPRESSED_DIM,
                                 device=self.device)
                amp_labels = torch.full((current_batch,), self.AMP_LABEL, device=self.device)
                mask_labels = torch.full((current_batch,), self.MASK_LABEL, device=self.device)

                for step in range(num_steps):
                    t = torch.full((current_batch,), 1.0 - step / num_steps, device=self.device)

                    if cfg_scale > 0:
                        # CFG: conditional (AMP) and unconditional (mask) velocities,
                        # extrapolated from the unconditional one by cfg_scale.
                        vt_cond = self.flow_model(xt, t, labels=amp_labels)
                        vt_uncond = self.flow_model(xt, t, labels=mask_labels)
                        vt = vt_uncond + cfg_scale * (vt_cond - vt_uncond)
                    else:
                        # No CFG: equals the formula above at scale 0, with one pass fewer.
                        vt = self.flow_model(xt, t, labels=mask_labels)

                    xt = xt + vt * dt

                # Decompress to get embeddings: [B, L, ESM_DIM]
                decompressed = self.decompressor(xt)

                # Apply reverse preprocessing: undo min-max scaling, then undo
                # standardization (presumably the inverse of the training pipeline).
                decompressed = decompressed * (mx - mn + 1e-8) + mn
                decompressed = decompressed * s + m

                all_generated.append(decompressed.cpu())

        # Concatenate all batches
        generated_embeddings = torch.cat(all_generated, dim=0)

        print(f"✓ Generated {generated_embeddings.shape[0]} AMP embeddings")
        print(f"  Shape: {generated_embeddings.shape}")
        print(f"  Stats - Mean: {generated_embeddings.mean():.4f}, Std: {generated_embeddings.std():.4f}")

        return generated_embeddings

    def generate_with_reflow(self, num_samples=100):
        """
        Generate AMP samples using 1-step reflow (if you have reflow model).

        Currently a placeholder: it runs ordinary single-step generation with
        the default CFG scale and a batch size of 32.
        """
        print(f"Generating {num_samples} AMP samples with 1-step reflow...")

        # This would use the reflow implementation
        # For now, just use 1-step generation
        return self.generate_amps(num_samples=num_samples, num_steps=1, batch_size=32)

def main():
    """
    Main generation function.

    Checks that the best flow-matching checkpoint exists and is loadable. Then
    generates 20 AMP embeddings (25 ODE steps) at each of four CFG scales and
    saves each set to a dated ``.pt`` file in the output directory. Returns
    early, after printing why, if the checkpoint is missing or unreadable.
    """
    print("=== AMP Generation Pipeline with CFG ===")

    # Use the best model from training
    model_path = '/data2/edwardsun/flow_amp/checkpoints/amp_flow_model_best_optimized.pth'

    # Check if checkpoint exists (explicitly, so other failures are not misreported as "not found")
    if not os.path.isfile(model_path):
        print(f"❌ Best model not found: {model_path}")
        print("Please train the flow matching model first using amp_flow_training.py")
        return

    try:
        checkpoint = torch.load(model_path, map_location='cpu')
    except Exception as err:  # top-level boundary: report the real cause and stop
        print(f"❌ Could not load checkpoint {model_path}: {err}")
        return

    print(f"✓ Found best model at step {checkpoint['step']} with loss {checkpoint['loss']:.6f}")
    # These bookkeeping fields are only informational here, so treat them as optional.
    if 'global_step' in checkpoint:
        print(f"  Global step: {checkpoint['global_step']}")
    if 'total_samples' in checkpoint:
        print(f"  Total samples: {checkpoint['total_samples']:,}")
    # AMPGenerator reloads the checkpoint itself; release this inspection copy.
    del checkpoint

    # Initialize generator (fall back to CPU when no GPU is available)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    generator = AMPGenerator(model_path, device=device)

    # One entry per run: (file tag, CFG scale, progress description, summary description)
    cfg_runs = [
        ('no_cfg', 0.0, 'no conditioning', 'no conditioning'),
        ('weak_cfg', 3.0, 'weak conditioning', 'weak CFG'),
        ('strong_cfg', 7.5, 'strong conditioning', 'strong CFG'),
        ('very_strong_cfg', 15.0, 'very strong conditioning', 'very strong CFG'),
    ]

    # Generate samples with different CFG scales
    samples = {}
    for idx, (tag, scale, description, _) in enumerate(cfg_runs, start=1):
        print(f"\n{idx}. Generating with CFG scale {scale} ({description})...")
        samples[tag] = generator.generate_amps(num_samples=20, num_steps=25, cfg_scale=scale)

    # Create output directory if it doesn't exist
    output_dir = '/data2/edwardsun/generated_samples'
    os.makedirs(output_dir, exist_ok=True)

    # Get today's date for filename
    today = datetime.now().strftime('%Y%m%d')

    # Save generated samples with date
    filenames = {tag: f'generated_amps_best_model_{tag}_{today}.pt' for tag, *_ in cfg_runs}
    for tag, filename in filenames.items():
        torch.save(samples[tag], os.path.join(output_dir, filename))

    print("\n✓ Generation complete!")
    print(f"Generated samples saved (Date: {today}):")
    for tag, _, _, summary in cfg_runs:
        print(f"  - {filenames[tag]} ({summary})")

    print("\nCFG Analysis:")
    print("  - CFG scale 0.0: No conditioning, generates diverse sequences")
    print("  - CFG scale 3.0: Weak AMP conditioning")
    print("  - CFG scale 7.5: Strong AMP conditioning (recommended)")
    print("  - CFG scale 15.0: Very strong AMP conditioning (may be too restrictive)")

    print("\nNext steps:")
    print("1. Decode embeddings back to sequences using ESM-2 decoder")
    print("2. Evaluate AMP properties (antimicrobial activity, toxicity)")
    print("3. Compare sequences generated with different CFG scales")
    print("4. Implement conditioning for specific properties")


if __name__ == "__main__":
    main()