| | import torch |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from tqdm import tqdm |
| | import os |
| | from datetime import datetime |
| | |
| | try: |
| | from torchdiffeq import odeint |
| | TORCHDIFFEQ_AVAILABLE = True |
| | print("✓ torchdiffeq available for proper ODE solving") |
| | except ImportError: |
| | TORCHDIFFEQ_AVAILABLE = False |
| | print("⚠️ torchdiffeq not available, using manual Euler integration") |
| |
|
| | |
| | from compressor_with_embeddings import Compressor, Decompressor |
| | from final_flow_model import AMPFlowMatcherCFGConcat, AMPProtFlowPipelineCFG |
| |
|
class AMPGenerator:
    """
    Generate AMP samples using trained ProtFlow model.

    The generator loads a compressor/decompressor pair and a CFG flow model,
    integrates the learned velocity field from t=1 (Gaussian noise) to t=0,
    runs the result through the decompressor and rescales it with the saved
    normalization statistics.
    """

    # Default artefact locations. Kept so existing callers behave as before;
    # every one of them can be overridden through the constructor.
    DEFAULT_COMPRESSOR_PATH = '/data2/edwardsun/flow_amp/models/final_compressor_model.pth'
    DEFAULT_DECOMPRESSOR_PATH = '/data2/edwardsun/flow_amp/models/final_decompressor_model.pth'
    DEFAULT_STATS_PATH = 'normalization_stats.pt'

    # Latent geometry shared by the noise sampler and the flow model.
    SEQ_LEN = 25
    LATENT_DIM = 80

    # Label ids passed to the flow model: AMP_LABEL drives the conditional
    # branch, MASK_LABEL drives the unconditional branch of CFG.
    AMP_LABEL = 0
    MASK_LABEL = 2

    # torchdiffeq solvers that take rtol/atol instead of a fixed step size.
    ADAPTIVE_ODE_METHODS = ('dopri8', 'dopri5', 'bosh3', 'fehlberg2', 'adaptive_heun')

    def __init__(self, model_path, device='cuda', stats_path=None,
                 compressor_path=None, decompressor_path=None):
        """
        Load all models and the normalization statistics.

        Args:
            model_path: Path to the flow-model checkpoint.
            device: Torch device to run on. Falls back to CPU (with a warning)
                when a CUDA device is requested but CUDA is not available.
            stats_path: Optional path to the normalization statistics file;
                defaults to DEFAULT_STATS_PATH.
            compressor_path: Optional compressor weights path; defaults to
                DEFAULT_COMPRESSOR_PATH.
            decompressor_path: Optional decompressor weights path; defaults to
                DEFAULT_DECOMPRESSOR_PATH.
        """
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            print("⚠️ CUDA requested but not available, falling back to CPU")
            device = 'cpu'
        self.device = device

        self.compressor_path = compressor_path or self.DEFAULT_COMPRESSOR_PATH
        self.decompressor_path = decompressor_path or self.DEFAULT_DECOMPRESSOR_PATH

        self._load_models(model_path)

        self.stats = torch.load(stats_path or self.DEFAULT_STATS_PATH,
                                map_location=self.device)

    def _load_models(self, model_path):
        """
        Load the compressor, decompressor and flow model onto self.device.

        Args:
            model_path: Path to a checkpoint dict that provides the keys
                'flow_model_state_dict', 'step' and 'loss'.
        """
        print("Loading trained models...")

        compressor_path = getattr(self, 'compressor_path', self.DEFAULT_COMPRESSOR_PATH)
        decompressor_path = getattr(self, 'decompressor_path', self.DEFAULT_DECOMPRESSOR_PATH)

        self.compressor = Compressor().to(self.device)
        self.decompressor = Decompressor().to(self.device)

        self.compressor.load_state_dict(
            torch.load(compressor_path, map_location=self.device))
        self.decompressor.load_state_dict(
            torch.load(decompressor_path, map_location=self.device))

        self.flow_model = AMPFlowMatcherCFGConcat(
            hidden_dim=480,
            compressed_dim=self.LATENT_DIM,
            n_layers=12,
            n_heads=16,
            dim_ff=3072,
            max_seq_len=self.SEQ_LEN,
            use_cfg=True,
        ).to(self.device)

        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects; only load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

        # Checkpoints saved from a torch.compile-wrapped model carry an
        # '_orig_mod.' prefix on every key; strip it so the keys match the
        # plain module built above.
        prefix = '_orig_mod.'
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in checkpoint['flow_model_state_dict'].items()
        }
        self.flow_model.load_state_dict(state_dict)

        print(f"✓ All models loaded successfully from step {checkpoint['step']}!")
        print(f" Loss at checkpoint: {checkpoint['loss']:.6f}")

        if TORCHDIFFEQ_AVAILABLE:
            print("✓ Enhanced with proper ODE solving (torchdiffeq)")
        else:
            print("⚠️ Using fallback Euler integration")

    def _guided_velocity(self, x, t, cfg_scale):
        """Return the velocity for state x at per-sample times t, applying CFG when cfg_scale > 0."""
        batch_size = x.shape[0]
        mask_labels = torch.full((batch_size,), self.MASK_LABEL,
                                 device=x.device, dtype=torch.long)
        vt_uncond = self.flow_model(x, t, labels=mask_labels)
        if cfg_scale > 0:
            amp_labels = torch.full((batch_size,), self.AMP_LABEL,
                                    device=x.device, dtype=torch.long)
            vt_cond = self.flow_model(x, t, labels=amp_labels)
            # Classifier-free guidance: push the unconditional velocity
            # towards the conditional one.
            return vt_uncond + cfg_scale * (vt_cond - vt_uncond)
        return vt_uncond

    def _create_ode_func(self, cfg_scale=7.5, shape=None):
        """
        Create ODE function for torchdiffeq integration.

        Args:
            cfg_scale: CFG guidance scale; values <= 0 use the unconditional
                velocity only.
            shape: (batch, seq_len, dim) of the un-flattened state. When
                omitted, self.current_shape is read at call time (legacy
                behaviour).

        Returns:
            A callable ode_func(t, x) mapping a flattened state to a
            flattened derivative.
        """

        def ode_func(t, x):
            """
            ODE function: dx/dt = v_theta(x, t)

            Args:
                t: scalar time (float or 0-dim tensor)
                x: state tensor [B*L*D] (flattened)
            Returns:
                dx/dt: derivative [B*L*D] (flattened)
            """
            batch_size, seq_len, dim = shape if shape is not None else self.current_shape
            x = x.view(batch_size, seq_len, dim)

            # Broadcast the scalar time to one entry per sample without
            # pulling the value back to the host.
            t_tensor = torch.as_tensor(t, dtype=x.dtype, device=x.device)
            t_tensor = t_tensor.expand(batch_size).contiguous()

            vt = self._guided_velocity(x, t_tensor, cfg_scale)
            return vt.reshape(-1)

        return ode_func

    def _solve_with_torchdiffeq(self, eps, cfg_scale, num_steps, ode_method, rtol, atol):
        """Integrate eps from t=1 to t=0 with torchdiffeq and return the final state."""
        # Kept as an attribute for code that still reads it.
        self.current_shape = eps.shape

        ode_func = self._create_ode_func(cfg_scale=cfg_scale, shape=eps.shape)
        t_span = torch.tensor([1.0, 0.0], device=eps.device, dtype=eps.dtype)
        y0 = eps.reshape(-1)

        if ode_method in self.ADAPTIVE_ODE_METHODS:
            solution = odeint(
                ode_func, y0, t_span,
                method=ode_method,
                rtol=rtol,
                atol=atol,
                options={'max_num_steps': 1000},
            )
        else:
            # Fixed-step solvers honour num_steps (25 steps -> 0.04).
            solution = odeint(
                ode_func, y0, t_span,
                method=ode_method,
                options={'step_size': 1.0 / num_steps},
            )

        # solution is indexed by the entries of t_span; the last one is t=0.
        return solution[-1].view(eps.shape)

    def _denormalize(self, decompressed):
        """Rescale decompressor output with the saved min/max and mean/std statistics."""
        m, s = self.stats['mean'], self.stats['std']
        mn, mx = self.stats['min'], self.stats['max']
        decompressed = decompressed * (mx - mn + 1e-8) + mn
        return decompressed * s + m

    def generate_amps(self, num_samples=100, num_steps=25, batch_size=32, cfg_scale=7.5,
                      ode_method='dopri5', rtol=1e-5, atol=1e-6):
        """
        Generate AMP samples using flow matching with CFG and improved ODE solving.

        Args:
            num_samples: Number of AMP samples to generate
            num_steps: Number of ODE solving steps for fixed-step solvers and
                the manual Euler path (ignored by adaptive solvers)
            batch_size: Batch size for generation
            cfg_scale: CFG guidance scale (higher = stronger conditioning)
            ode_method: ODE solver method ('dopri5', 'rk4', 'euler', 'adaptive_heun');
                'euler' always uses the manual Euler loop
            rtol: Relative tolerance for adaptive solvers
            atol: Absolute tolerance for adaptive solvers

        Returns:
            CPU tensor holding the generated embeddings of all batches,
            concatenated along dim 0.

        Raises:
            ValueError: If num_samples, batch_size or num_steps is not positive.
        """
        if num_samples <= 0:
            raise ValueError(f"num_samples must be positive, got {num_samples}")
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if num_steps <= 0:
            raise ValueError(f"num_steps must be positive, got {num_steps}")

        use_torchdiffeq = TORCHDIFFEQ_AVAILABLE and ode_method != 'euler'

        method_str = f"{ode_method} ODE solver" if use_torchdiffeq else "manual Euler integration"
        print(f"Generating {num_samples} AMP samples with {method_str} (CFG scale: {cfg_scale})...")
        if use_torchdiffeq:
            print(f" Method: {ode_method}, rtol={rtol}, atol={atol}")

        self.flow_model.eval()
        self.compressor.eval()
        self.decompressor.eval()

        all_generated = []

        with torch.no_grad():
            for i in tqdm(range(0, num_samples, batch_size), desc="Generating with improved ODE"):
                current_batch = min(batch_size, num_samples - i)

                eps = torch.randn(current_batch, self.SEQ_LEN, self.LATENT_DIM,
                                  device=self.device)

                if use_torchdiffeq:
                    try:
                        xt = self._solve_with_torchdiffeq(
                            eps, cfg_scale, num_steps, ode_method, rtol, atol)
                    except Exception as e:
                        # Best effort: a solver failure should not lose the batch.
                        print(f"⚠️ ODE solving failed for batch {i//batch_size + 1}: {e}")
                        print("Falling back to Euler method...")
                        xt = self._generate_with_euler(eps, current_batch, cfg_scale, num_steps)
                else:
                    xt = self._generate_with_euler(eps, current_batch, cfg_scale, num_steps)

                decompressed = self.decompressor(xt)
                all_generated.append(self._denormalize(decompressed).cpu())

        generated_embeddings = torch.cat(all_generated, dim=0)

        print(f"✓ Generated {generated_embeddings.shape[0]} AMP embeddings")
        print(f" Shape: {generated_embeddings.shape}")
        print(f" Stats - Mean: {generated_embeddings.mean():.4f}, Std: {generated_embeddings.std():.4f}")

        return generated_embeddings

    def _generate_with_euler(self, eps, current_batch, cfg_scale, num_steps):
        """
        Fallback Euler integration method (original implementation).

        Args:
            eps: Initial noise; not modified.
            current_batch: Number of samples in eps.
            cfg_scale: CFG guidance scale; values <= 0 disable guidance.
            num_steps: Number of equal steps taken from t=1 down to t=0.

        Returns:
            The state after num_steps explicit Euler updates.
        """
        xt = eps.clone()
        amp_labels = torch.full((current_batch,), self.AMP_LABEL,
                                device=self.device, dtype=torch.long)
        mask_labels = torch.full((current_batch,), self.MASK_LABEL,
                                 device=self.device, dtype=torch.long)

        # Time runs backwards from 1 to 0, hence the negative step.
        dt = -1.0 / num_steps

        for step in range(num_steps):
            t = torch.full((current_batch,), 1.0 - step / num_steps, device=self.device)

            vt = self.flow_model(xt, t, labels=mask_labels)
            if cfg_scale > 0:
                vt_cond = self.flow_model(xt, t, labels=amp_labels)
                vt = vt + cfg_scale * (vt_cond - vt)

            xt = xt + vt * dt

        return xt

    def compare_ode_methods(self, num_samples=20, cfg_scale=7.5):
        """
        Compare different ODE solving methods for quality assessment.

        Args:
            num_samples: Number of samples generated per method.
            cfg_scale: CFG guidance scale used for every method.

        Returns:
            dict mapping method name to a result dict ('embeddings', 'time',
            'mean', 'std', 'success'), or {'success': False, 'error': ...} on
            failure. When torchdiffeq is not installed, the embeddings tensor
            from a single generate_amps call is returned instead.

        Note:
            generate_amps falls back to Euler per batch when a solver raises,
            so a successful entry may include batches from that fallback.
        """
        import time

        if not TORCHDIFFEQ_AVAILABLE:
            print("⚠️ torchdiffeq not available, cannot compare ODE methods")
            return self.generate_amps(num_samples=num_samples, cfg_scale=cfg_scale)

        methods = ['euler', 'rk4', 'dopri5', 'adaptive_heun']
        results = {}

        print("🔬 Comparing ODE solving methods...")

        # CUDA kernels run asynchronously; synchronise around the timed region
        # only when actually running on a GPU so the comparison works on CPU.
        on_cuda = torch.cuda.is_available() and str(self.device).startswith('cuda')

        for method in methods:
            print(f"\n--- Testing {method} ---")
            try:
                if on_cuda:
                    torch.cuda.synchronize()
                started = time.perf_counter()

                embeddings = self.generate_amps(
                    num_samples=num_samples,
                    batch_size=10,
                    cfg_scale=cfg_scale,
                    ode_method=method,
                )

                if on_cuda:
                    torch.cuda.synchronize()
                elapsed_time = time.perf_counter() - started

                mean = embeddings.mean().item()
                std = embeddings.std().item()

                results[method] = {
                    'embeddings': embeddings,
                    'time': elapsed_time,
                    'mean': mean,
                    'std': std,
                    'success': True,
                }

                print(f"✓ {method}: {elapsed_time:.2f}s, mean={mean:.4f}, std={std:.4f}")

            except Exception as e:
                # Record the failure and keep comparing the remaining methods.
                print(f"❌ {method} failed: {e}")
                results[method] = {'success': False, 'error': str(e)}

        return results

    def generate_with_reflow(self, num_samples=100):
        """
        Generate AMP samples using 1-step reflow (if you have reflow model).

        Forces the manual Euler path so that exactly one integration step is
        taken; the default adaptive solver would ignore num_steps.
        """
        print(f"Generating {num_samples} AMP samples with 1-step reflow...")

        return self.generate_amps(num_samples=num_samples, num_steps=1, batch_size=32,
                                  ode_method='euler')
| |
|
def main():
    """
    Main generation function.

    Checks that the best flow-model checkpoint exists, builds an AMPGenerator,
    optionally compares ODE solvers, then generates and saves one batch of
    embeddings per CFG scale. Returns None; returns early (after printing a
    hint) when the checkpoint file is missing.
    """
    print("=== AMP Generation Pipeline with CFG ===")

    model_path = '/data2/edwardsun/flow_checkpoints/amp_flow_model_best_optimized.pth'

    try:
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects; only load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
    except FileNotFoundError:
        # Only a missing file means "not trained yet"; any other failure
        # (corrupt file, interrupt, ...) should surface with its real cause.
        print(f"❌ Best model not found: {model_path}")
        print("Please train the flow matching model first using amp_flow_training.py")
        return

    print(f"✓ Found best model at step {checkpoint['step']} with loss {checkpoint['loss']:.6f}")
    # These two entries are informational only, so tolerate their absence.
    if 'global_step' in checkpoint:
        print(f" Global step: {checkpoint['global_step']}")
    if 'total_samples' in checkpoint:
        print(f" Total samples: {checkpoint['total_samples']:,}")
    # The generator loads the checkpoint again; free this copy first.
    del checkpoint

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    generator = AMPGenerator(model_path, device=device)

    if TORCHDIFFEQ_AVAILABLE:
        print("\n🔬 Comparing ODE solving methods...")
        generator.compare_ode_methods(num_samples=10, cfg_scale=7.5)

        best_method = 'dopri5'
        print(f"\n🚀 Using {best_method} for main generation...")
    else:
        best_method = 'euler'
        print("\n⚠️ Using fallback Euler integration...")

    # (file tag, CFG scale, generation note, saved-file note)
    cfg_runs = [
        ('no_cfg', 0.0, 'no conditioning', 'no conditioning'),
        ('weak_cfg', 3.0, 'weak conditioning', 'weak CFG'),
        ('strong_cfg', 7.5, 'strong conditioning', 'strong CFG'),
        ('very_strong_cfg', 15.0, 'very strong conditioning', 'very strong CFG'),
    ]

    # Create the output directory up front so a permissions problem fails
    # before the expensive generation rather than after it.
    output_dir = '/data2/edwardsun/generated_samples'
    os.makedirs(output_dir, exist_ok=True)

    today = datetime.now().strftime('%Y%m%d')

    saved = []
    for index, (tag, cfg_scale, gen_note, file_note) in enumerate(cfg_runs, start=1):
        print(f"\n{index}. Generating with CFG scale {cfg_scale} ({gen_note})...")
        samples = generator.generate_amps(
            num_samples=20, num_steps=25, cfg_scale=cfg_scale, ode_method=best_method)

        # Save immediately so earlier results survive a later failure.
        filename = f'generated_amps_best_model_{tag}_{today}.pt'
        torch.save(samples, os.path.join(output_dir, filename))
        saved.append((filename, file_note))

    print("\n✓ Generation complete!")
    print(f"Generated samples saved (Date: {today}):")
    for filename, file_note in saved:
        print(f" - {filename} ({file_note})")

    print("\nCFG Analysis:")
    print(" - CFG scale 0.0: No conditioning, generates diverse sequences")
    print(" - CFG scale 3.0: Weak AMP conditioning")
    print(" - CFG scale 7.5: Strong AMP conditioning (recommended)")
    print(" - CFG scale 15.0: Very strong AMP conditioning (may be too restrictive)")

    print("\nNext steps:")
    print("1. Decode embeddings back to sequences using ESM-2 decoder")
    print("2. Evaluate with ProtFlow metrics (FPD, MMD, ESM-2 perplexity)")
    print("3. Compare sequences generated with different CFG scales")
    print("4. Evaluate AMP properties (antimicrobial activity, toxicity)")
    if TORCHDIFFEQ_AVAILABLE:
        print(f"5. ✓ Enhanced generation with {best_method} ODE solver")
    else:
        print("5. Install torchdiffeq for improved ODE solving: pip install torchdiffeq")
| |
|
# Script entry point: run the generation pipeline only when this file is executed directly, not on import.
| | if __name__ == "__main__": |
| | main() |