# FlowAMP / test_generated_peptides.py
# Uploaded by esunAI — "Initial FlowAMP upload: Complete project with all essential files" (commit 370f342)
import torch
import numpy as np
import json
import os
from tqdm import tqdm
import warnings
from datetime import datetime
warnings.filterwarnings('ignore')
# Import our components
from generate_amps import AMPGenerator
from compressor_with_embeddings import Compressor, Decompressor
from final_sequence_decoder import EmbeddingToSequenceConverter
# Optional dependency: the local APEX wrapper supplies MIC predictions.
# APEX_AVAILABLE records whether it could be imported, so the rest of the
# module can fall back to generation-only mode instead of failing at import.
try:
    from local_apex_wrapper import LocalAPEXWrapper
except ImportError as exc:
    APEX_AVAILABLE = False
    print(f"Warning: Local APEX not available: {exc}")
else:
    APEX_AVAILABLE = True
class PeptideTester:
    """
    Generate peptides and test them using APEX for antimicrobial activity.

    Combines three components:
      * ``AMPGenerator`` - samples embeddings from the trained flow model,
      * ``EmbeddingToSequenceConverter`` - converts embeddings to sequences,
      * ``LocalAPEXWrapper`` - predicts MIC values (``self.apex`` is ``None``
        when the wrapper could not be imported).

    APEX scores are MIC values in μg/mL, so *lower is better* throughout.
    """

    def __init__(self, model_path='amp_flow_model_final.pth', device='cuda'):
        """
        Args:
            model_path: Path of the trained flow-model checkpoint, passed to
                ``AMPGenerator``.
            device: Device identifier (e.g. ``'cuda'``) forwarded to the
                generator and the converter.
        """
        self.device = device
        self.model_path = model_path

        print("Initializing peptide generator...")
        self.generator = AMPGenerator(model_path, device)

        print("Initializing embedding to sequence converter...")
        self.converter = EmbeddingToSequenceConverter(device)

        if APEX_AVAILABLE:
            print("Initializing local APEX predictor...")
            self.apex = LocalAPEXWrapper()
            print("✓ Local APEX loaded successfully!")
        else:
            self.apex = None
            print("⚠ Local APEX not available - will only generate sequences")

    def generate_peptides(self, num_samples=100, num_steps=25, batch_size=32):
        """
        Generate peptide sequences using the trained flow model.

        Args:
            num_samples: Number of embeddings to sample.
            num_steps: Number of sampling steps passed to the generator.
            batch_size: Sampling batch size passed to the generator.

        Returns:
            The decoded sequences kept by
            ``EmbeddingToSequenceConverter.filter_valid_sequences``. This may be
            fewer than ``num_samples``.
        """
        print(f"\n=== Generating {num_samples} Peptide Sequences ===")
        generated_embeddings = self.generator.generate_amps(
            num_samples=num_samples,
            num_steps=num_steps,
            batch_size=batch_size,
        )
        print(f"Generated embeddings shape: {generated_embeddings.shape}")

        sequences = self.converter.batch_embedding_to_sequences(generated_embeddings)
        return self.converter.filter_valid_sequences(sequences)

    def test_with_apex(self, sequences):
        """
        Test generated sequences using APEX for antimicrobial activity.

        Args:
            sequences: Sized iterable of peptide sequence strings.

        Returns:
            ``None`` when no APEX predictor is loaded. Otherwise, a list of result
            dicts with keys ``sequence``, ``sequence_id``, ``apex_score`` (MIC,
            lower is better), ``is_amp`` and ``length``. A sequence whose
            prediction raises is reported and skipped, so one bad sequence does
            not abort the whole batch.
        """
        # Check this instance's predictor rather than the module flag, so the
        # method reflects how the object was actually constructed.
        if self.apex is None:
            print("⚠ APEX not available - skipping activity prediction")
            return None

        print(f"\n=== Testing {len(sequences)} Sequences with APEX ===")
        results = []
        for i, seq in tqdm(enumerate(sequences), total=len(sequences), desc="Testing with APEX"):
            try:
                # Coerce to builtin types: json.dump (used by save_results)
                # cannot serialize numpy/torch scalars such as float32 or
                # bool_, in case the wrapper returns them.
                avg_mic = float(self.apex.predict_single(seq))
                is_amp = bool(self.apex.is_amp(seq, threshold=32.0))  # MIC threshold
            except Exception as e:
                print(f"Error testing sequence {i}: {e}")
                continue
            results.append({
                'sequence': seq,
                'sequence_id': f'generated_{i:04d}',
                'apex_score': avg_mic,  # Lower MIC = better activity
                'is_amp': is_amp,
                'length': len(seq),
            })
        return results

    def analyze_results(self, results, top_k=10):
        """
        Print summary statistics and the best candidates from APEX testing.

        Args:
            results: Result dicts as produced by ``test_with_apex``.
            top_k: Number of best (lowest-MIC) candidates to list.

        Returns:
            ``results`` unchanged, or ``None`` when ``results`` is empty.
        """
        if not results:
            print("No results to analyze")
            return

        print(f"\n=== Analysis of {len(results)} Generated Peptides ===")
        scores = [r['apex_score'] for r in results]
        amp_count = sum(1 for r in results if r['is_amp'])
        print(f"Total sequences tested: {len(results)}")
        print(f"Predicted AMPs: {amp_count} ({amp_count/len(results)*100:.1f}%)")
        print(f"Average MIC: {np.mean(scores):.2f} μg/mL")
        print(f"MIC range: {np.min(scores):.2f} - {np.max(scores):.2f} μg/mL")
        print(f"MIC std: {np.std(scores):.2f} μg/mL")

        # Lower MIC means stronger predicted activity, so the best candidates
        # have the smallest scores (ascending sort). A previous version sorted
        # descending and listed the weakest peptides as "top" candidates.
        top_candidates = sorted(results, key=lambda r: r['apex_score'])[:top_k]
        print(f"\n=== Top {len(top_candidates)} Candidates ===")
        for rank, candidate in enumerate(top_candidates, start=1):
            print(f"{rank:2d}. MIC: {candidate['apex_score']:.2f} μg/mL | "
                  f"Length: {candidate['length']:2d} | "
                  f"Sequence: {candidate['sequence']}")
        return results

    def save_results(self, results, filename='generated_peptides_results.json'):
        """
        Save results to a JSON file with a small metadata header.

        Args:
            results: Result dicts as produced by ``test_with_apex``.
            filename: Output path. An existing file is overwritten.
        """
        if not results:
            print("No results to save")
            return

        output = {
            'metadata': {
                'model_path': self.model_path,
                'num_sequences': len(results),
                # ISO-8601 local wall-clock time at which the file is written.
                'generation_timestamp': datetime.now().isoformat(),
                'apex_available': self.apex is not None,
            },
            'results': results,
        }
        with open(filename, 'w') as f:
            json.dump(output, f, indent=2)
        print(f"✓ Results saved to {filename}")

    def run_full_pipeline(self, num_samples=100, save_results=True):
        """
        Run the complete pipeline: generate peptides and test them with APEX.

        Args:
            num_samples: Number of peptides to generate.
            save_results: Whether to write the results JSON. This local flag
                shadows the method name only inside this function; the method
                is still reached through ``self``.

        Returns:
            The APEX result list (possibly empty), or ``None`` when APEX is
            unavailable. In that case the generated sequences are neither
            analyzed nor saved.
        """
        print("🚀 Starting Full Peptide Generation and Testing Pipeline")
        print("=" * 60)

        sequences = self.generate_peptides(num_samples=num_samples)
        results = self.test_with_apex(sequences)

        if results:
            self.analyze_results(results)
            if save_results:
                self.save_results(results)
        return results
def _cfg_slug(cfg_name):
    """Return a filename/ID-safe slug for a CFG label, e.g. 'Weak CFG (3.0)' -> 'weak_cfg_30'.

    Lower-cases the label, turns spaces into underscores and drops parentheses and dots.
    """
    return (
        cfg_name.lower()
        .replace(" ", "_")
        .replace("(", "")
        .replace(")", "")
        .replace(".", "")
    )


def _load_decoded_sequences(file_path):
    """Read sequences from a tab-separated decoded-sequence file.

    Skips blank lines, lines starting with '#', and lines without a tab. For every
    other line, the stripped second tab-separated field is the sequence; empty
    fields are skipped.
    """
    sequences = []
    with open(file_path, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith('#') or '\t' not in line:
                continue
            # A tab in the stripped line guarantees at least two fields.
            seq = line.split('\t')[1].strip()
            if seq:
                sequences.append(seq)
    return sequences


def _score_sequences(apex, sequences, cfg_name, threshold=32.0):
    """Predict MIC and AMP status for every sequence with the APEX wrapper.

    Returns one result dict per successfully scored sequence. A sequence whose
    prediction raises is reported and skipped (best effort), so one bad
    sequence does not abort the whole run.
    """
    id_prefix = _cfg_slug(cfg_name)
    results = []
    for i, seq in tqdm(enumerate(sequences), total=len(sequences), desc=f"Testing {cfg_name}"):
        try:
            # Coerce to builtin types: json.dump cannot serialize numpy/torch
            # scalars such as float32 or bool_, in case the wrapper returns them.
            avg_mic = float(apex.predict_single(seq))
            is_amp = bool(apex.is_amp(seq, threshold=threshold))  # MIC threshold
        except Exception as e:
            print(f"Warning: Error testing sequence {i}: {e}")
            continue
        results.append({
            'sequence': seq,
            'sequence_id': f'{id_prefix}_{i:03d}',
            'cfg_setting': cfg_name,
            'apex_score': avg_mic,  # Lower MIC = better activity
            'is_amp': is_amp,
            'length': len(seq),
        })
    return results


def _mic_stats(results):
    """Aggregate MIC statistics for a NON-EMPTY list of result dicts (JSON-safe floats)."""
    scores = [r['apex_score'] for r in results]
    amp_count = sum(1 for r in results if r['is_amp'])
    return {
        'num_sequences': len(results),
        'amp_count': amp_count,
        'amp_percentage': amp_count / len(results) * 100,
        'avg_mic': float(np.mean(scores)),
        'min_mic': float(np.min(scores)),
        'max_mic': float(np.max(scores)),
        'std_mic': float(np.std(scores)),
        'all_mics': scores,
    }


def _print_mic_stats(stats):
    """Print the per-setting summary lines for a ``_mic_stats`` dict."""
    print(f"Total sequences tested: {stats['num_sequences']}")
    print(f"Predicted AMPs: {stats['amp_count']} ({stats['amp_percentage']:.1f}%)")
    print(f"Average MIC: {stats['avg_mic']:.2f} μg/mL")
    print(f"MIC range: {stats['min_mic']:.2f} - {stats['max_mic']:.2f} μg/mL")
    print(f"MIC std: {stats['std_mic']:.2f} μg/mL")


def _write_json(path, payload):
    """Write ``payload`` to ``path`` as indented JSON, overwriting any existing file."""
    with open(path, 'w') as f:
        json.dump(payload, f, indent=2)


def main(decoded_dir='/data2/edwardsun/decoded_sequences',
         output_dir='/data2/edwardsun/apex_results',
         date_str=None):
    """
    Test existing decoded sequence files with APEX.

    For each CFG setting, loads ``decoded_sequences_<slug>_<date>.txt`` from
    *decoded_dir*, predicts MIC for every sequence, and prints per-setting and
    overall summaries. It then writes to *output_dir*:
      * ``apex_results_<slug>_<date>.json`` for each input file found,
      * ``apex_results_all_cfg_comparison_<date>.json`` with all results,
      * ``mic_summary_<date>.json`` with aggregate MIC statistics.
    The last two files are written only when at least one sequence was scored.

    Args:
        decoded_dir: Directory holding the decoded sequence files.
        output_dir: Directory for the JSON outputs (created if missing).
        date_str: ``YYYYMMDD`` stamp used in input and output file names.
            Defaults to today's date.
    """
    print("🧬 AMP Flow Model - Testing Decoded Sequences with APEX")
    print("=" * 60)

    if not APEX_AVAILABLE:
        print("❌ Local APEX not available - cannot test sequences")
        print("Please ensure local_apex_wrapper.py is properly set up")
        return

    # Only the APEX predictor is needed here, not the generator.
    print("Initializing APEX predictor...")
    apex = LocalAPEXWrapper()
    print("✓ Local APEX loaded successfully!")

    today = date_str or datetime.now().strftime('%Y%m%d')

    # Input files are named decoded_sequences_<slug>_<date>.txt, where <slug>
    # is _cfg_slug(label), e.g. 'Strong CFG (7.5)' -> strong_cfg_75.
    cfg_names = ['No CFG (0.0)', 'Weak CFG (3.0)', 'Strong CFG (7.5)', 'Very Strong CFG (15.0)']
    cfg_files = {
        name: os.path.join(decoded_dir, f'decoded_sequences_{_cfg_slug(name)}_{today}.txt')
        for name in cfg_names
    }

    os.makedirs(output_dir, exist_ok=True)

    all_results = {}  # label -> result list, for every input file that exists
    cfg_stats = {}    # label -> _mic_stats(...), only for non-empty result lists

    for cfg_name, file_path in cfg_files.items():
        print(f"\n{'='*60}")
        print(f"Testing {cfg_name} sequences...")
        print(f"Loading: {file_path}")
        if not os.path.exists(file_path):
            print(f"❌ File not found: {file_path}")
            continue

        sequences = _load_decoded_sequences(file_path)
        print(f"✓ Loaded {len(sequences)} sequences from {file_path}")

        print(f"Testing {len(sequences)} sequences with APEX...")
        results = _score_sequences(apex, sequences, cfg_name)

        if results:
            stats = _mic_stats(results)
            cfg_stats[cfg_name] = stats
            print(f"\n=== Analysis of {cfg_name} ===")
            _print_mic_stats(stats)

            # Lower MIC is better, so sort ascending.
            top_candidates = sorted(results, key=lambda r: r['apex_score'])[:5]
            print(f"\n=== Top 5 Candidates ({cfg_name}) ===")
            for rank, candidate in enumerate(top_candidates, start=1):
                print(f"{rank:2d}. MIC: {candidate['apex_score']:.2f} μg/mL | "
                      f"Length: {candidate['length']:2d} | "
                      f"Sequence: {candidate['sequence']}")

        all_results[cfg_name] = results

        output_file = os.path.join(output_dir, f"apex_results_{_cfg_slug(cfg_name)}_{today}.json")
        _write_json(output_file, {
            'metadata': {
                'cfg_setting': cfg_name,
                'num_sequences': len(results),
                'apex_available': APEX_AVAILABLE,
            },
            'results': results,
        })
        print(f"✓ Results saved to {output_file}")

    # Overall comparison
    print(f"\n{'='*60}")
    print("OVERALL COMPARISON ACROSS CFG SETTINGS")
    print(f"{'='*60}")
    for cfg_name, stats in cfg_stats.items():
        print(f"\n{cfg_name}:")
        print(f" Total: {stats['num_sequences']} | AMPs: {stats['amp_count']} ({stats['amp_percentage']:.1f}%)")
        print(f" Avg MIC: {stats['avg_mic']:.2f} μg/mL | Best MIC: {stats['min_mic']:.2f} μg/mL")

    all_candidates = [r for results in all_results.values() for r in results]

    if all_candidates:
        print(f"\n{'='*60}")
        print("TOP 10 OVERALL CANDIDATES (All CFG Settings)")
        print(f"{'='*60}")
        top_overall = sorted(all_candidates, key=lambda r: r['apex_score'])[:10]
        for rank, candidate in enumerate(top_overall, start=1):
            print(f"{rank:2d}. MIC: {candidate['apex_score']:.2f} μg/mL | "
                  f"CFG: {candidate['cfg_setting']} | "
                  f"Sequence: {candidate['sequence']}")

        overall_results_file = os.path.join(output_dir, f'apex_results_all_cfg_comparison_{today}.json')
        _write_json(overall_results_file, {
            'metadata': {
                'date': today,
                'total_sequences': len(all_candidates),
                'apex_available': APEX_AVAILABLE,
                'cfg_settings_tested': list(all_results.keys()),
            },
            'results': all_candidates,
        })
        print(f"\n✓ Overall results saved to {overall_results_file}")

        mic_summary_file = os.path.join(output_dir, f'mic_summary_{today}.json')
        _write_json(mic_summary_file, {
            'date': today,
            'summary_by_cfg': cfg_stats,
            'all_mics': [r['apex_score'] for r in all_candidates],
            'amp_count': sum(1 for r in all_candidates if r['is_amp']),
            'total_sequences': len(all_candidates),
        })
        print(f"✓ MIC summary saved to {mic_summary_file}")
    else:
        # Avoid overwriting a previous run's overall files with empty data.
        print("\n⚠ No sequences were scored - overall result files not written")

    print(f"\n✅ APEX testing completed successfully!")
    print(f"Tested {len(all_candidates)} total sequences across all CFG settings")
# Run the APEX evaluation only when executed as a script, never on import.
if __name__ == "__main__":
    main()