|
|
import torch |
|
|
import numpy as np |
|
|
import json |
|
|
import os |
|
|
from tqdm import tqdm |
|
|
import warnings |
|
|
from datetime import datetime |
|
|
warnings.filterwarnings('ignore') |
|
|
|
|
|
|
|
|
from generate_amps import AMPGenerator |
|
|
from compressor_with_embeddings import Compressor, Decompressor |
|
|
from final_sequence_decoder import EmbeddingToSequenceConverter |
|
|
|
|
|
|
|
|
# The local APEX predictor is an optional dependency: when the wrapper
# module is missing we fall back to generation-only mode instead of failing.
APEX_AVAILABLE = False
try:
    from local_apex_wrapper import LocalAPEXWrapper
except ImportError as import_err:
    print(f"Warning: Local APEX not available: {import_err}")
else:
    APEX_AVAILABLE = True
|
|
|
|
|
class PeptideTester:
    """
    Generate peptides and test them using APEX for antimicrobial activity.

    Pipeline: flow-model embedding generation -> embedding-to-sequence
    decoding -> (optional) APEX MIC prediction, analysis and JSON export.
    """

    def __init__(self, model_path='amp_flow_model_final.pth', device='cuda'):
        """
        Load the generator, the sequence converter and (if available) APEX.

        Args:
            model_path: path to the trained flow-model checkpoint.
            device: device string forwarded to the sub-models.
        """
        self.device = device
        self.model_path = model_path

        print("Initializing peptide generator...")
        self.generator = AMPGenerator(model_path, device)

        print("Initializing embedding to sequence converter...")
        self.converter = EmbeddingToSequenceConverter(device)

        # APEX is optional: without it we can still generate sequences,
        # we just cannot score them.
        if APEX_AVAILABLE:
            print("Initializing local APEX predictor...")
            self.apex = LocalAPEXWrapper()
            print("✓ Local APEX loaded successfully!")
        else:
            self.apex = None
            print("⚠ Local APEX not available - will only generate sequences")

    def generate_peptides(self, num_samples=100, num_steps=25, batch_size=32):
        """
        Generate peptide sequences using the trained flow model.

        Args:
            num_samples: number of embeddings to sample.
            num_steps: number of integration steps for the flow sampler.
            batch_size: generation batch size.

        Returns:
            List of decoded sequences that passed the converter's validity
            filter (may contain fewer than num_samples entries).
        """
        print(f"\n=== Generating {num_samples} Peptide Sequences ===")

        generated_embeddings = self.generator.generate_amps(
            num_samples=num_samples,
            num_steps=num_steps,
            batch_size=batch_size
        )

        print(f"Generated embeddings shape: {generated_embeddings.shape}")

        sequences = self.converter.batch_embedding_to_sequences(generated_embeddings)

        # Drop sequences the converter deems invalid.
        sequences = self.converter.filter_valid_sequences(sequences)

        return sequences

    def test_with_apex(self, sequences):
        """
        Test generated sequences using APEX for antimicrobial activity.

        Args:
            sequences: iterable of peptide strings.

        Returns:
            List of per-sequence result dicts, or None when APEX is
            unavailable. Sequences that raise during prediction are
            skipped with a warning rather than aborting the batch.
        """
        if not APEX_AVAILABLE:
            print("⚠ APEX not available - skipping activity prediction")
            return None

        print(f"\n=== Testing {len(sequences)} Sequences with APEX ===")

        results = []

        # total= is required for a real progress bar: enumerate() has no len().
        for i, seq in tqdm(enumerate(sequences), total=len(sequences),
                           desc="Testing with APEX"):
            try:
                avg_mic = self.apex.predict_single(seq)
                is_amp = self.apex.is_amp(seq, threshold=32.0)

                result = {
                    'sequence': seq,
                    'sequence_id': f'generated_{i:04d}',
                    'apex_score': avg_mic,
                    'is_amp': is_amp,
                    'length': len(seq)
                }
                results.append(result)

            except Exception as e:
                # Best-effort: one bad sequence must not kill the run.
                print(f"Error testing sequence {i}: {e}")
                continue

        return results

    def analyze_results(self, results):
        """
        Print summary statistics and the top candidates.

        Args:
            results: list of result dicts from test_with_apex().

        Returns:
            The unchanged results list, or None when it is empty/None.
        """
        if not results:
            print("No results to analyze")
            return

        print(f"\n=== Analysis of {len(results)} Generated Peptides ===")

        scores = [r['apex_score'] for r in results]
        amp_count = sum(1 for r in results if r['is_amp'])

        print(f"Total sequences tested: {len(results)}")
        print(f"Predicted AMPs: {amp_count} ({amp_count/len(results)*100:.1f}%)")
        print(f"Average MIC: {np.mean(scores):.2f} μg/mL")
        print(f"MIC range: {np.min(scores):.2f} - {np.max(scores):.2f} μg/mL")
        print(f"MIC std: {np.std(scores):.2f} μg/mL")

        # BUGFIX: a lower MIC means a more potent peptide, so the best
        # candidates have the SMALLEST apex_score. The previous
        # reverse=True sort listed the worst candidates (and disagreed
        # with main(), which sorts ascending).
        top_candidates = sorted(results, key=lambda x: x['apex_score'])[:10]

        print(f"\n=== Top 10 Candidates ===")
        for i, candidate in enumerate(top_candidates):
            print(f"{i+1:2d}. MIC: {candidate['apex_score']:.2f} μg/mL | "
                  f"Length: {candidate['length']:2d} | "
                  f"Sequence: {candidate['sequence']}")

        return results

    def save_results(self, results, filename='generated_peptides_results.json'):
        """
        Save results to JSON file.

        Args:
            results: list of result dicts (nothing is written when empty).
            filename: destination path for the JSON report.
        """
        if not results:
            print("No results to save")
            return

        output = {
            'metadata': {
                'model_path': self.model_path,
                'num_sequences': len(results),
                # BUGFIX: record a real wall-clock timestamp. The old
                # str(torch.cuda.Event()) only produced an object repr,
                # not a time.
                'generation_timestamp': datetime.now().isoformat(),
                'apex_available': APEX_AVAILABLE
            },
            'results': results
        }

        with open(filename, 'w') as f:
            json.dump(output, f, indent=2)

        # BUGFIX: report the actual output path instead of the literal
        # placeholder "(unknown)".
        print(f"✓ Results saved to {filename}")

    def run_full_pipeline(self, num_samples=100, save_results=True):
        """
        Run the complete pipeline: generate peptides and test with APEX.

        Args:
            num_samples: number of peptides to generate.
            save_results: when True, write the JSON report to disk.

        Returns:
            The APEX results list, or None when APEX is unavailable.
        """
        print("🚀 Starting Full Peptide Generation and Testing Pipeline")
        print("=" * 60)

        sequences = self.generate_peptides(num_samples=num_samples)

        results = self.test_with_apex(sequences)

        if results:
            self.analyze_results(results)

        # save_results() itself handles the empty/None case.
        if save_results:
            self.save_results(results)

        return results
|
|
|
|
|
def main():
    """
    Test existing decoded sequence files with APEX.

    For each CFG setting, loads today's decoded-sequence file, predicts
    MIC values with the local APEX model, prints per-setting and overall
    summaries, and writes JSON reports under /data2/edwardsun/apex_results.
    Missing input files are skipped with a warning.
    """
    print("🧬 AMP Flow Model - Testing Decoded Sequences with APEX")
    print("=" * 60)

    if not APEX_AVAILABLE:
        print("❌ Local APEX not available - cannot test sequences")
        print("Please ensure local_apex_wrapper.py is properly set up")
        return

    print("Initializing APEX predictor...")
    apex = LocalAPEXWrapper()
    print("✓ Local APEX loaded successfully!")

    # Input file names embed the run date, so only today's outputs are used.
    today = datetime.now().strftime('%Y%m%d')

    cfg_files = {
        'No CFG (0.0)': f'/data2/edwardsun/decoded_sequences/decoded_sequences_no_cfg_00_{today}.txt',
        'Weak CFG (3.0)': f'/data2/edwardsun/decoded_sequences/decoded_sequences_weak_cfg_30_{today}.txt',
        'Strong CFG (7.5)': f'/data2/edwardsun/decoded_sequences/decoded_sequences_strong_cfg_75_{today}.txt',
        'Very Strong CFG (15.0)': f'/data2/edwardsun/decoded_sequences/decoded_sequences_very_strong_cfg_150_{today}.txt'
    }

    all_results = {}

    for cfg_name, file_path in cfg_files.items():
        print(f"\n{'='*60}")
        print(f"Testing {cfg_name} sequences...")
        print(f"Loading: {file_path}")

        if not os.path.exists(file_path):
            print(f"❌ File not found: {file_path}")
            continue

        # Files are tab-separated "<id>\t<sequence>" lines; lines starting
        # with '#' are comments.
        sequences = []
        with open(file_path, 'r') as f:
            for line in f:
                line = line.strip()
                if line and not line.startswith('#') and '\t' in line:
                    parts = line.split('\t')
                    if len(parts) >= 2:
                        seq = parts[1].strip()
                        if seq and len(seq) > 0:
                            sequences.append(seq)

        print(f"✓ Loaded {len(sequences)} sequences from {file_path}")

        results = []
        print(f"Testing {len(sequences)} sequences with APEX...")

        # total= is required for a real progress bar: enumerate() has no len().
        for i, seq in tqdm(enumerate(sequences), total=len(sequences),
                           desc=f"Testing {cfg_name}"):
            try:
                avg_mic = apex.predict_single(seq)
                is_amp = apex.is_amp(seq, threshold=32.0)

                result = {
                    'sequence': seq,
                    'sequence_id': f'{cfg_name.lower().replace(" ", "_").replace("(", "").replace(")", "").replace(".", "")}_{i:03d}',
                    'cfg_setting': cfg_name,
                    'apex_score': avg_mic,
                    'is_amp': is_amp,
                    'length': len(seq)
                }
                results.append(result)

            except Exception as e:
                # One failing sequence must not abort the whole file.
                print(f"Warning: Error testing sequence {i}: {e}")
                continue

        if results:
            print(f"\n=== Analysis of {cfg_name} ===")
            scores = [r['apex_score'] for r in results]
            amp_count = sum(1 for r in results if r['is_amp'])

            print(f"Total sequences tested: {len(results)}")
            print(f"Predicted AMPs: {amp_count} ({amp_count/len(results)*100:.1f}%)")
            print(f"Average MIC: {np.mean(scores):.2f} μg/mL")
            print(f"MIC range: {np.min(scores):.2f} - {np.max(scores):.2f} μg/mL")
            print(f"MIC std: {np.std(scores):.2f} μg/mL")

            # Lower MIC = more potent, so ascending sort picks the best.
            top_candidates = sorted(results, key=lambda x: x['apex_score'])[:5]

            print(f"\n=== Top 5 Candidates ({cfg_name}) ===")
            for i, candidate in enumerate(top_candidates):
                print(f"{i+1:2d}. MIC: {candidate['apex_score']:.2f} μg/mL | "
                      f"Length: {candidate['length']:2d} | "
                      f"Sequence: {candidate['sequence']}")

        # Kept unconditional: downstream summaries guard on `if results:`
        # themselves, and an empty per-setting report is still informative.
        all_results[cfg_name] = results

        output_dir = '/data2/edwardsun/apex_results'
        os.makedirs(output_dir, exist_ok=True)

        output_file = os.path.join(output_dir, f"apex_results_{cfg_name.lower().replace(' ', '_').replace('(', '').replace(')', '').replace('.', '')}_{today}.json")
        with open(output_file, 'w') as f:
            json.dump({
                'metadata': {
                    'cfg_setting': cfg_name,
                    'num_sequences': len(results),
                    'apex_available': APEX_AVAILABLE
                },
                'results': results
            }, f, indent=2)
        print(f"✓ Results saved to {output_file}")

    # ---- cross-setting comparison ------------------------------------
    print(f"\n{'='*60}")
    print("OVERALL COMPARISON ACROSS CFG SETTINGS")
    print(f"{'='*60}")

    for cfg_name, results in all_results.items():
        if results:
            scores = [r['apex_score'] for r in results]
            amp_count = sum(1 for r in results if r['is_amp'])
            print(f"\n{cfg_name}:")
            print(f"  Total: {len(results)} | AMPs: {amp_count} ({amp_count/len(results)*100:.1f}%)")
            print(f"  Avg MIC: {np.mean(scores):.2f} μg/mL | Best MIC: {np.min(scores):.2f} μg/mL")

    all_candidates = []
    for cfg_name, results in all_results.items():
        all_candidates.extend(results)

    if all_candidates:
        print(f"\n{'='*60}")
        print("TOP 10 OVERALL CANDIDATES (All CFG Settings)")
        print(f"{'='*60}")

        # Ascending by MIC: the most potent peptides come first.
        top_overall = sorted(all_candidates, key=lambda x: x['apex_score'])[:10]
        for i, candidate in enumerate(top_overall):
            print(f"{i+1:2d}. MIC: {candidate['apex_score']:.2f} μg/mL | "
                  f"CFG: {candidate['cfg_setting']} | "
                  f"Sequence: {candidate['sequence']}")

        output_dir = '/data2/edwardsun/apex_results'
        os.makedirs(output_dir, exist_ok=True)

        overall_results_file = os.path.join(output_dir, f'apex_results_all_cfg_comparison_{today}.json')
        with open(overall_results_file, 'w') as f:
            json.dump({
                'metadata': {
                    'date': today,
                    'total_sequences': len(all_candidates),
                    'apex_available': APEX_AVAILABLE,
                    'cfg_settings_tested': list(all_results.keys())
                },
                'results': all_candidates
            }, f, indent=2)
        print(f"\n✓ Overall results saved to {overall_results_file}")

        # Compact per-setting MIC statistics for downstream plotting.
        mic_summary_file = os.path.join(output_dir, f'mic_summary_{today}.json')
        mic_summary = {
            'date': today,
            'summary_by_cfg': {},
            'all_mics': [r['apex_score'] for r in all_candidates],
            'amp_count': sum(1 for r in all_candidates if r['is_amp']),
            'total_sequences': len(all_candidates)
        }

        for cfg_name, results in all_results.items():
            if results:
                scores = [r['apex_score'] for r in results]
                amp_count = sum(1 for r in results if r['is_amp'])
                mic_summary['summary_by_cfg'][cfg_name] = {
                    'num_sequences': len(results),
                    'amp_count': amp_count,
                    'amp_percentage': amp_count/len(results)*100,
                    'avg_mic': np.mean(scores),
                    'min_mic': np.min(scores),
                    'max_mic': np.max(scores),
                    'std_mic': np.std(scores),
                    'all_mics': scores
                }

        with open(mic_summary_file, 'w') as f:
            json.dump(mic_summary, f, indent=2)
        print(f"✓ MIC summary saved to {mic_summary_file}")

    print(f"\n✅ APEX testing completed successfully!")
    print(f"Tested {len(all_candidates)} total sequences across all CFG settings")
|
|
|
|
|
# Script entry point: run the APEX testing pipeline only when this file
# is executed directly (not when imported as a module).
if __name__ == "__main__":
    main()