import torch
import numpy as np
import json
import os
from tqdm import tqdm
import warnings
from datetime import datetime

warnings.filterwarnings('ignore')

# Import our components
from generate_amps import AMPGenerator
from compressor_with_embeddings import Compressor, Decompressor
from final_sequence_decoder import EmbeddingToSequenceConverter

# Import local APEX wrapper
try:
    from local_apex_wrapper import LocalAPEXWrapper
    APEX_AVAILABLE = True
except ImportError as e:
    print(f"Warning: Local APEX not available: {e}")
    APEX_AVAILABLE = False


def _cfg_slug(cfg_name):
    """Turn a CFG display name like 'Weak CFG (3.0)' into a filename-safe slug.

    Lowercases and strips spaces/parentheses/dots, e.g.
    'Weak CFG (3.0)' -> 'weak_cfg_30'.
    """
    return (cfg_name.lower()
            .replace(' ', '_')
            .replace('(', '')
            .replace(')', '')
            .replace('.', ''))


class PeptideTester:
    """
    Generate peptides and test them using APEX for antimicrobial activity.
    """

    def __init__(self, model_path='amp_flow_model_final.pth', device='cuda'):
        """Load the flow generator, the sequence decoder, and (if available) APEX.

        Args:
            model_path: path to the trained flow-model checkpoint.
            device: torch device string passed to the generator/converter.
        """
        self.device = device
        self.model_path = model_path

        # Initialize generator
        print("Initializing peptide generator...")
        self.generator = AMPGenerator(model_path, device)

        # Initialize embedding to sequence converter
        print("Initializing embedding to sequence converter...")
        self.converter = EmbeddingToSequenceConverter(device)

        # Initialize APEX if available
        if APEX_AVAILABLE:
            print("Initializing local APEX predictor...")
            self.apex = LocalAPEXWrapper()
            print("✓ Local APEX loaded successfully!")
        else:
            self.apex = None
            print("⚠ Local APEX not available - will only generate sequences")

    def generate_peptides(self, num_samples=100, num_steps=25, batch_size=32):
        """
        Generate peptide sequences using the trained flow model.

        Args:
            num_samples: number of embeddings to sample from the flow model.
            num_steps: number of integration steps for generation.
            batch_size: generation batch size.

        Returns:
            List of valid peptide sequence strings (invalid decodes filtered out).
        """
        print(f"\n=== Generating {num_samples} Peptide Sequences ===")

        # Generate embeddings
        generated_embeddings = self.generator.generate_amps(
            num_samples=num_samples,
            num_steps=num_steps,
            batch_size=batch_size
        )
        print(f"Generated embeddings shape: {generated_embeddings.shape}")

        # Convert embeddings to sequences using the converter
        sequences = self.converter.batch_embedding_to_sequences(generated_embeddings)

        # Filter valid sequences
        sequences = self.converter.filter_valid_sequences(sequences)

        return sequences

    def test_with_apex(self, sequences):
        """
        Test generated sequences using APEX for antimicrobial activity.

        Args:
            sequences: list of peptide sequence strings.

        Returns:
            List of per-sequence result dicts, or None if APEX is unavailable.
            Sequences that raise during prediction are skipped with a warning.
        """
        if not APEX_AVAILABLE:
            print("⚠ APEX not available - skipping activity prediction")
            return None

        print(f"\n=== Testing {len(sequences)} Sequences with APEX ===")

        results = []
        # total= is required for tqdm to render a real progress bar: enumerate()
        # has no __len__, so without it tqdm shows only a counter.
        for i, seq in tqdm(enumerate(sequences), total=len(sequences),
                           desc="Testing with APEX"):
            try:
                # Predict antimicrobial activity using local APEX.
                # Cast to plain Python types so the results stay JSON-serializable
                # (APEX may return numpy scalars, which json.dump rejects —
                # NOTE(review): confirm predict_single's return type).
                avg_mic = float(self.apex.predict_single(seq))
                is_amp = bool(self.apex.is_amp(seq, threshold=32.0))  # MIC threshold

                result = {
                    'sequence': seq,
                    'sequence_id': f'generated_{i:04d}',
                    'apex_score': avg_mic,  # Lower MIC = better activity
                    'is_amp': is_amp,
                    'length': len(seq)
                }
                results.append(result)
            except Exception as e:
                print(f"Error testing sequence {i}: {e}")
                continue

        return results

    def analyze_results(self, results):
        """
        Analyze the results of APEX testing.

        Prints summary statistics (AMP rate, MIC mean/range/std) and the ten
        best candidates. Returns the results list unchanged, or None when there
        is nothing to analyze.
        """
        if not results:
            print("No results to analyze")
            return

        print(f"\n=== Analysis of {len(results)} Generated Peptides ===")

        # Extract scores
        scores = [r['apex_score'] for r in results]
        amp_count = sum(1 for r in results if r['is_amp'])

        print(f"Total sequences tested: {len(results)}")
        print(f"Predicted AMPs: {amp_count} ({amp_count/len(results)*100:.1f}%)")
        print(f"Average MIC: {np.mean(scores):.2f} μg/mL")
        print(f"MIC range: {np.min(scores):.2f} - {np.max(scores):.2f} μg/mL")
        print(f"MIC std: {np.std(scores):.2f} μg/mL")

        # BUG FIX: lower MIC = better activity (see the sorting in main() and the
        # 'Lower MIC = better activity' comments), so the best candidates have the
        # SMALLEST apex_score. The original used reverse=True and therefore listed
        # the ten WORST peptides as the top candidates.
        top_candidates = sorted(results, key=lambda x: x['apex_score'])[:10]

        print(f"\n=== Top 10 Candidates ===")
        for i, candidate in enumerate(top_candidates):
            print(f"{i+1:2d}. MIC: {candidate['apex_score']:.2f} μg/mL | "
                  f"Length: {candidate['length']:2d} | "
                  f"Sequence: {candidate['sequence']}")

        return results

    def save_results(self, results, filename='generated_peptides_results.json'):
        """
        Save results to JSON file.

        Args:
            results: list of result dicts (no-op when empty/None).
            filename: output JSON path.
        """
        if not results:
            print("No results to save")
            return

        output = {
            'metadata': {
                'model_path': self.model_path,
                'num_sequences': len(results),
                # BUG FIX: the original stored str(torch.cuda.Event()), which is
                # just an object repr — not a timestamp. Record a real ISO time.
                'generation_timestamp': datetime.now().isoformat(),
                'apex_available': APEX_AVAILABLE
            },
            'results': results
        }

        with open(filename, 'w') as f:
            json.dump(output, f, indent=2)

        # BUG FIX: the original printed a literal placeholder instead of the path.
        print(f"✓ Results saved to {filename}")

    def run_full_pipeline(self, num_samples=100, save_results=True):
        """
        Run the complete pipeline: generate peptides and test with APEX.

        Args:
            num_samples: number of peptides to generate.
            save_results: if True, write the results JSON to disk.

        Returns:
            The APEX results list, or None if APEX is unavailable.
        """
        print("🚀 Starting Full Peptide Generation and Testing Pipeline")
        print("=" * 60)

        # Step 1: Generate peptides
        sequences = self.generate_peptides(num_samples=num_samples)

        # Step 2: Test with APEX
        results = self.test_with_apex(sequences)

        # Step 3: Analyze results
        if results:
            self.analyze_results(results)

        # Step 4: Save results (save_results() itself handles results=None)
        if save_results:
            self.save_results(results)

        return results


def main():
    """
    Main function to test existing decoded sequence files with APEX.

    Loads today's decoded-sequence files for each CFG setting, scores every
    sequence with APEX, prints per-setting and overall summaries, and writes
    per-setting results, an overall comparison, and a MIC summary as JSON.
    """
    print("🧬 AMP Flow Model - Testing Decoded Sequences with APEX")
    print("=" * 60)

    # Check if APEX is available
    if not APEX_AVAILABLE:
        print("❌ Local APEX not available - cannot test sequences")
        print("Please ensure local_apex_wrapper.py is properly set up")
        return

    # Initialize tester (we only need APEX, not the generator)
    print("Initializing APEX predictor...")
    apex = LocalAPEXWrapper()
    print("✓ Local APEX loaded successfully!")

    # Get today's date for filename
    today = datetime.now().strftime('%Y%m%d')

    # Define the decoded sequence files to test (using today's generated sequences)
    cfg_files = {
        'No CFG (0.0)': f'/data2/edwardsun/decoded_sequences/decoded_sequences_no_cfg_00_{today}.txt',
        'Weak CFG (3.0)': f'/data2/edwardsun/decoded_sequences/decoded_sequences_weak_cfg_30_{today}.txt',
        'Strong CFG (7.5)': f'/data2/edwardsun/decoded_sequences/decoded_sequences_strong_cfg_75_{today}.txt',
        'Very Strong CFG (15.0)': f'/data2/edwardsun/decoded_sequences/decoded_sequences_very_strong_cfg_150_{today}.txt'
    }

    all_results = {}

    for cfg_name, file_path in cfg_files.items():
        print(f"\n{'='*60}")
        print(f"Testing {cfg_name} sequences...")
        print(f"Loading: {file_path}")

        if not os.path.exists(file_path):
            print(f"❌ File not found: {file_path}")
            continue

        # Read sequences from file (tab-separated: id<TAB>sequence; skip comments)
        sequences = []
        with open(file_path, 'r') as f:
            for line in f:
                line = line.strip()
                if line and not line.startswith('#') and '\t' in line:
                    # Parse sequence from tab-separated format
                    parts = line.split('\t')
                    if len(parts) >= 2:
                        seq = parts[1].strip()
                        if seq and len(seq) > 0:
                            sequences.append(seq)

        print(f"✓ Loaded {len(sequences)} sequences from {file_path}")

        # Test sequences with APEX
        results = []
        cfg_slug = _cfg_slug(cfg_name)  # reused for sequence IDs and the output filename
        print(f"Testing {len(sequences)} sequences with APEX...")
        # total= lets tqdm render a real progress bar (enumerate has no __len__).
        for i, seq in tqdm(enumerate(sequences), total=len(sequences),
                           desc=f"Testing {cfg_name}"):
            try:
                # Predict antimicrobial activity using local APEX. Cast to plain
                # Python types so json.dump below cannot choke on numpy scalars.
                avg_mic = float(apex.predict_single(seq))
                is_amp = bool(apex.is_amp(seq, threshold=32.0))  # MIC threshold

                result = {
                    'sequence': seq,
                    'sequence_id': f'{cfg_slug}_{i:03d}',
                    'cfg_setting': cfg_name,
                    'apex_score': avg_mic,  # Lower MIC = better activity
                    'is_amp': is_amp,
                    'length': len(seq)
                }
                results.append(result)
            except Exception as e:
                print(f"Warning: Error testing sequence {i}: {e}")
                continue

        # Analyze results for this CFG setting
        if results:
            print(f"\n=== Analysis of {cfg_name} ===")
            scores = [r['apex_score'] for r in results]
            amp_count = sum(1 for r in results if r['is_amp'])

            print(f"Total sequences tested: {len(results)}")
            print(f"Predicted AMPs: {amp_count} ({amp_count/len(results)*100:.1f}%)")
            print(f"Average MIC: {np.mean(scores):.2f} μg/mL")
            print(f"MIC range: {np.min(scores):.2f} - {np.max(scores):.2f} μg/mL")
            print(f"MIC std: {np.std(scores):.2f} μg/mL")

            # Show top 5 candidates for this CFG setting
            top_candidates = sorted(results, key=lambda x: x['apex_score'])[:5]  # Lower MIC is better
            print(f"\n=== Top 5 Candidates ({cfg_name}) ===")
            for i, candidate in enumerate(top_candidates):
                print(f"{i+1:2d}. MIC: {candidate['apex_score']:.2f} μg/mL | "
                      f"Length: {candidate['length']:2d} | "
                      f"Sequence: {candidate['sequence']}")

        all_results[cfg_name] = results

        # Create output directory if it doesn't exist
        output_dir = '/data2/edwardsun/apex_results'
        os.makedirs(output_dir, exist_ok=True)

        # Save individual results with date
        output_file = os.path.join(output_dir, f"apex_results_{cfg_slug}_{today}.json")
        with open(output_file, 'w') as f:
            json.dump({
                'metadata': {
                    'cfg_setting': cfg_name,
                    'num_sequences': len(results),
                    'apex_available': APEX_AVAILABLE
                },
                'results': results
            }, f, indent=2)
        print(f"✓ Results saved to {output_file}")

    # Overall comparison
    print(f"\n{'='*60}")
    print("OVERALL COMPARISON ACROSS CFG SETTINGS")
    print(f"{'='*60}")

    for cfg_name, results in all_results.items():
        if results:
            scores = [r['apex_score'] for r in results]
            amp_count = sum(1 for r in results if r['is_amp'])
            print(f"\n{cfg_name}:")
            print(f"  Total: {len(results)} | AMPs: {amp_count} ({amp_count/len(results)*100:.1f}%)")
            print(f"  Avg MIC: {np.mean(scores):.2f} μg/mL | Best MIC: {np.min(scores):.2f} μg/mL")

    # Find best overall candidates
    all_candidates = []
    for cfg_name, results in all_results.items():
        all_candidates.extend(results)

    if all_candidates:
        print(f"\n{'='*60}")
        print("TOP 10 OVERALL CANDIDATES (All CFG Settings)")
        print(f"{'='*60}")
        top_overall = sorted(all_candidates, key=lambda x: x['apex_score'])[:10]
        for i, candidate in enumerate(top_overall):
            print(f"{i+1:2d}. MIC: {candidate['apex_score']:.2f} μg/mL | "
                  f"CFG: {candidate['cfg_setting']} | "
                  f"Sequence: {candidate['sequence']}")

    # Create output directory if it doesn't exist
    output_dir = '/data2/edwardsun/apex_results'
    os.makedirs(output_dir, exist_ok=True)

    # Save overall results with date
    overall_results_file = os.path.join(output_dir, f'apex_results_all_cfg_comparison_{today}.json')
    with open(overall_results_file, 'w') as f:
        json.dump({
            'metadata': {
                'date': today,
                'total_sequences': len(all_candidates),
                'apex_available': APEX_AVAILABLE,
                'cfg_settings_tested': list(all_results.keys())
            },
            'results': all_candidates
        }, f, indent=2)
    print(f"\n✓ Overall results saved to {overall_results_file}")

    # Save comprehensive MIC summary
    mic_summary_file = os.path.join(output_dir, f'mic_summary_{today}.json')
    mic_summary = {
        'date': today,
        'summary_by_cfg': {},
        'all_mics': [r['apex_score'] for r in all_candidates],
        'amp_count': sum(1 for r in all_candidates if r['is_amp']),
        'total_sequences': len(all_candidates)
    }

    for cfg_name, results in all_results.items():
        if results:
            scores = [r['apex_score'] for r in results]
            amp_count = sum(1 for r in results if r['is_amp'])
            mic_summary['summary_by_cfg'][cfg_name] = {
                'num_sequences': len(results),
                'amp_count': amp_count,
                'amp_percentage': amp_count / len(results) * 100,
                # BUG FIX: np.mean/np.min/np.max/np.std return numpy scalars
                # (np.float64), which json.dump cannot serialize — it raises
                # TypeError. Cast to plain Python floats before dumping.
                'avg_mic': float(np.mean(scores)),
                'min_mic': float(np.min(scores)),
                'max_mic': float(np.max(scores)),
                'std_mic': float(np.std(scores)),
                'all_mics': scores
            }

    with open(mic_summary_file, 'w') as f:
        json.dump(mic_summary, f, indent=2)
    print(f"✓ MIC summary saved to {mic_summary_file}")

    print(f"\n✅ APEX testing completed successfully!")
    print(f"Tested {len(all_candidates)} total sequences across all CFG settings")


if __name__ == "__main__":
    main()