# NOTE: non-code page header ("Spaces / Sleeping") removed — web-extraction artifact.
"""
File: benchmark_evaluation.py
------------------------------
Benchmark E. coli protein sequences with ENCOT, generate optimized DNA,
compute metrics (CAI, tAI, GC, CFD, cis-elements), and produce summary tables
and figures.
"""
import sys
import os
import argparse
import pandas as pd
import numpy as np
import torch
import json
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import time
from tqdm import tqdm
from typing import Dict, List, Tuple, Any
from CAI import CAI, relative_adaptiveness
from CodonTransformer.CodonData import (
    download_codon_frequencies_from_kazusa,
    get_codon_frequencies,
)
from CodonTransformer.CodonPrediction import (
    load_model,
    predict_dna_sequence,
)
from CodonTransformer.CodonEvaluation import (
    get_GC_content,
    get_ecoli_tai_weights,
    get_min_max_profile,
    calculate_tAI,
    count_negative_cis_elements,
)
from transformers import AutoTokenizer

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from evaluate_optimizer import translate_dna_to_protein
def find_longest_orf(dna_sequence: str) -> str:
    """
    Find the longest open reading frame (ORF) in a DNA sequence.

    Scans all three forward reading frames. An ORF starts at an ATG and runs
    through the first in-frame stop codon (inclusive); an ORF still open at
    the end of a frame is also considered.

    Args:
        dna_sequence (str): Input DNA sequence (ATCGN characters).

    Returns:
        str: Longest ORF (from start to stop codon), or empty string if none.
    """
    seq = dna_sequence.upper()
    stop_set = {'TAA', 'TAG', 'TGA'}
    best = ""

    for frame in range(3):
        chunks: list = []      # codons of the ORF currently being read
        reading = False
        for pos in range(frame, len(seq) - 2, 3):
            codon = seq[pos:pos + 3]
            if len(codon) != 3:
                break
            if not reading:
                # Outside an ORF: only an ATG opens one.
                if codon == 'ATG':
                    reading = True
                    chunks = [codon]
                continue
            chunks.append(codon)
            if codon in stop_set:
                # ORF closed by a stop codon; keep it if it is the longest so far.
                candidate = ''.join(chunks)
                if len(candidate) > len(best):
                    best = candidate
                reading = False
                chunks = []
        if reading:
            # Frame ended while still inside an ORF (no stop codon found).
            candidate = ''.join(chunks)
            if len(candidate) > len(best):
                best = candidate

    return best
| def _detect_columns(df: pd.DataFrame, name_hint: str | None = None, seq_hint: str | None = None) -> tuple[str | None, str]: | |
| """ | |
| Detect name and sequence columns in a case-insensitive, robust way. | |
| Args: | |
| df (pd.DataFrame): Input DataFrame read from Excel. | |
| name_hint (str | None): Optional override for name/label column (case-insensitive). | |
| seq_hint (str | None): Optional override for sequence column (case-insensitive). | |
| Returns: | |
| tuple[str | None, str]: Detected (name_column or None, sequence_column). | |
| Raises: | |
| ValueError: If a sequence-like column cannot be found. | |
| """ | |
| cols = list(df.columns) | |
| low_map = {c.lower().strip(): c for c in cols} | |
| # If hints are provided and exist (case-insensitive), honor them | |
| if name_hint: | |
| nh = name_hint.lower().strip() | |
| if nh in low_map: | |
| name_col = low_map[nh] | |
| else: | |
| name_col = None | |
| else: | |
| name_col = None | |
| if seq_hint: | |
| sh = seq_hint.lower().strip() | |
| if sh in low_map: | |
| seq_col = low_map[sh] | |
| else: | |
| seq_col = None | |
| else: | |
| seq_col = None | |
| # If not found, try candidates | |
| if name_col is None: | |
| name_candidates = [ | |
| 'name','id','title','gene','protein','description','label','accession','locus','entry','uniprot','ncbi','protein name' | |
| ] | |
| for k in name_candidates: | |
| if k in low_map: | |
| name_col = low_map[k] | |
| break | |
| if seq_col is None: | |
| seq_candidates = [ | |
| # protein-first | |
| 'protein sequence','protein_sequence','protein','aa sequence','aa_sequence','aa','amino acid sequence','amino_acid_sequence', | |
| # generic | |
| 'sequence','seq', | |
| # dna/cds | |
| 'cds','dna','coding sequence','coding_sequence','cds sequence','cds_sequence' | |
| ] | |
| for k in seq_candidates: | |
| if k in low_map: | |
| seq_col = low_map[k] | |
| break | |
| if not seq_col: | |
| raise ValueError(f"Could not detect sequence column. Available columns: {cols}") | |
| return name_col, seq_col | |
def parse_excel_sequences(excel_path: str, name_col: str | None = None, seq_col: str | None = None, sheet_name: str | int | None = None) -> List[Dict[str, str]]:
    """
    Parse sequences from the benchmark Excel file and auto-detect relevant columns.

    Args:
        excel_path (str): Path to the Excel file.
        name_col (str | None): Optional override for sequence name column.
        seq_col (str | None): Optional override for sequence column.
        sheet_name (str | int | None): Sheet name or index (default: first sheet).

    Returns:
        List[Dict[str, str]]: List of standardized sequence records with fields:
            id, name, protein_sequence, original_sequence (DNA or None), is_dna.

    Raises:
        ValueError: If a sequence column cannot be detected.
    """
    sn = sheet_name
    # Allow numeric sheet indices passed as strings (e.g. from the CLI).
    if isinstance(sn, str) and sn.isdigit():
        sn = int(sn)
    if sn is None:
        sn = 0
    df_or_dict = pd.read_excel(excel_path, sheet_name=sn)
    # pd.read_excel returns a dict of DataFrames when several sheets match;
    # in that case fall back to the first sheet.
    if isinstance(df_or_dict, dict):
        first_title, df = next(iter(df_or_dict.items()))
        print(f"Using sheet: {first_title}")
    else:
        df = df_or_dict
    sequences = []
    detected_name_col, detected_seq_col = _detect_columns(df, name_col, seq_col)
    print(f"Detected columns -> name: {detected_name_col or '[generated]'}, sequence: {detected_seq_col}")
    for idx, row in df.iterrows():
        sequence = str(row[detected_seq_col]).strip()
        if detected_name_col:
            name = str(row[detected_name_col]).strip()
        else:
            # No label column detected: generate a stable synthetic name.
            name = f"seq_{idx}"
        # Strip a leading FASTA-style '>' from names.
        if name.startswith('>'):
            name = name[1:].strip()
        # Keep letters only (drops whitespace, digits, '*', etc.).
        sequence = ''.join(filter(str.isalpha, sequence))
        # Heuristic: >95% A/T/C/G/N characters => treat the entry as DNA.
        dna_chars = sum(1 for c in sequence.upper() if c in 'ATCGN')
        is_dna = (dna_chars / len(sequence)) > 0.95 if len(sequence) > 0 else False
        if is_dna:
            longest_orf = find_longest_orf(sequence)
            if longest_orf and len(longest_orf) >= 30:
                # Use the longest ORF (>= 10 codons) as the coding sequence.
                original_dna = longest_orf
                protein_seq = translate_dna_to_protein(longest_orf)
            else:
                # No usable ORF: fall back to the raw sequence truncated
                # to a whole number of codons.
                truncated_len = (len(sequence) // 3) * 3
                if truncated_len >= 30:
                    original_dna = sequence[:truncated_len]
                    protein_seq = translate_dna_to_protein(original_dna)
                else:
                    # Too short to be a meaningful CDS; skip the record.
                    continue
            if '*' in protein_seq:
                stop_pos = protein_seq.find('*')
                if stop_pos >= 10:
                    # Trim at the first stop codon, keeping protein and DNA in sync.
                    protein_seq = protein_seq[:stop_pos]
                    original_dna = original_dna[:stop_pos*3]
                else:
                    # Premature stop too close to the start; skip the record.
                    continue
        else:
            # Protein input: normalize case and drop any stop markers.
            protein_seq = sequence.upper()
            protein_seq = protein_seq.replace('*', '')
            original_dna = None
        if len(protein_seq) < 10:
            continue
        sequences.append({
            'id': idx,
            'name': name,
            'protein_sequence': protein_seq,
            'original_sequence': original_dna,
            'is_dna': is_dna
        })
    return sequences
def calculate_cfd(dna_sequence: str, codon_frequencies: Dict) -> float:
    """
    Calculate Codon Frequency Distribution (CFD) similarity to a reference.

    The score averages ``1 - |seq_freq - ref_freq|`` over the union of codons
    observed in the sequence and present in the reference.

    Args:
        dna_sequence (str): Input DNA sequence.
        codon_frequencies (Dict): Reference frequencies; accepts flattened mapping
            or an amino2codon structure (will be flattened).

    Returns:
        float: Similarity score in [0, 1] where higher is more similar.
    """
    if not dna_sequence:
        return 0.0
    # Count in-frame codons of the input sequence.
    codon_count = {}
    total_codons = 0
    for i in range(0, len(dna_sequence) - 2, 3):
        codon = dna_sequence[i:i+3].upper()
        if len(codon) == 3:
            codon_count[codon] = codon_count.get(codon, 0) + 1
            total_codons += 1
    seq_freq = {}
    if total_codons > 0:
        for codon, count in codon_count.items():
            seq_freq[codon] = count / total_codons
    # Flatten amino2codon frequencies if needed.
    # Bug fix: guard against an empty reference dict — the previous code called
    # next(iter(...)) unconditionally and raised StopIteration on {}.
    flat_codon_freq = {}
    if isinstance(codon_frequencies, dict) and codon_frequencies:
        first_val = codon_frequencies[next(iter(codon_frequencies.keys()))]
        if isinstance(first_val, tuple) and len(first_val) == 2:
            # amino2codon structure: {amino_acid: (codons, freqs)}
            for amino, (codons, freqs) in codon_frequencies.items():
                for codon, freq in zip(codons, freqs):
                    flat_codon_freq[codon] = freq
        else:
            flat_codon_freq = codon_frequencies
    similarity = 0.0
    count = 0
    for codon in set(list(seq_freq.keys()) + list(flat_codon_freq.keys())):
        seq_f = seq_freq.get(codon, 0.0)
        ref_f = flat_codon_freq.get(codon, 0.0)
        similarity += 1 - abs(seq_f - ref_f)
        count += 1
    return similarity / count if count > 0 else 0.0
def run_model_on_sequences(
    sequences: List[Dict],
    model,
    tokenizer,
    device,
    cai_weights: Dict,
    tai_weights: Dict,
    codon_frequencies: Dict,
    reference_profile: List[float],
    output_dir: str
) -> pd.DataFrame:
    """
    Run ColiFormer on protein sequences and compute metrics for optimized DNA.

    Args:
        sequences (List[Dict]): Parsed sequence records.
        model: Loaded ColiFormer model.
        tokenizer: Tokenizer used by the model.
        device: Torch device.
        cai_weights (Dict): CAI weights.
        tai_weights (Dict): tAI weights.
        codon_frequencies (Dict): Reference codon frequencies.
        reference_profile (List[float]): Reserved for DTW profile (unused here).
        output_dir (str): Directory for outputs (not written here).

    Returns:
        pd.DataFrame: Per-sequence metrics and optimized DNA.
    """
    rows = []
    print(f"Processing {len(sequences)} sequences...")
    for record in tqdm(sequences, desc="Optimizing sequences"):
        protein_seq = record['protein_sequence']
        if len(protein_seq) < 10:
            continue
        try:
            t0 = time.time()
            prediction = predict_dna_sequence(
                protein=protein_seq,
                organism="Escherichia coli general",
                device=device,
                model=model,
                deterministic=True,
                match_protein=True,
            )
            elapsed = time.time() - t0
            # predict_dna_sequence may return a list of outputs; keep the first.
            first = prediction[0] if isinstance(prediction, list) else prediction
            optimized_dna = first.predicted_dna
            row = {
                'id': record['id'],
                'name': record['name'],
                'protein_sequence': protein_seq,
                'protein_length': len(protein_seq),
                'optimized_dna': optimized_dna,
            }
            # Baseline metrics are only available when the input was a natural CDS.
            if record['is_dna'] and record['original_sequence']:
                natural = record['original_sequence'].upper()
                row.update({
                    'original_cai': CAI(natural, weights=cai_weights),
                    'original_gc': get_GC_content(natural),
                    'original_tai': calculate_tAI(natural, tai_weights),
                    'original_cfd': calculate_cfd(natural, codon_frequencies),
                    'original_neg_cis': count_negative_cis_elements(natural),
                })
            row.update({
                'optimized_cai': CAI(optimized_dna, weights=cai_weights),
                'optimized_gc': get_GC_content(optimized_dna),
                'optimized_tai': calculate_tAI(optimized_dna, tai_weights),
                'optimized_cfd': calculate_cfd(optimized_dna, codon_frequencies),
                'optimized_neg_cis': count_negative_cis_elements(optimized_dna),
                'runtime': elapsed,
            })
            rows.append(row)
        except Exception as e:
            # Best-effort benchmark: log and keep going on per-sequence failures.
            print(f"Error processing sequence {record['id']}: {str(e)}")
            continue
    return pd.DataFrame(rows)
def generate_visualizations(results_df: pd.DataFrame, output_dir: str):
    """
    Generate visualizations and a metrics summary table.

    Saves:
        - CAI before/after bar plot
        - Median CAI comparison
        - Metrics distribution panel
        - CSV summary table

    Args:
        results_df (pd.DataFrame): Results from optimization.
        output_dir (str): Output directory root.

    Returns:
        pd.DataFrame: Summary table of aggregate metrics.
    """
    plt.style.use('seaborn-v0_8-darkgrid')
    sns.set_palette("husl")
    fig_dir = os.path.join(output_dir, 'figures')
    os.makedirs(fig_dir, exist_ok=True)
    # 1. Before/After CAI graph (only when original-DNA metrics are present)
    if 'original_cai' in results_df.columns:
        before_cai = results_df['original_cai'].dropna()
        # Align "after" values with the rows that actually have a "before" value.
        after_cai = results_df.loc[before_cai.index, 'optimized_cai']
        x = np.arange(len(before_cai))
        width = 0.35
        # Bug fix: the original called plt.figure(figsize=(12, 8)) and then
        # immediately plt.subplots(), leaking an unused open figure per call.
        fig, ax = plt.subplots(figsize=(14, 8))
        bars1 = ax.bar(x - width/2, before_cai, width, label='Before Optimization', alpha=0.8)
        bars2 = ax.bar(x + width/2, after_cai, width, label='After Optimization', alpha=0.8)
        ax.set_xlabel('Sequence Index', fontsize=12)
        ax.set_ylabel('CAI Score', fontsize=12)
        ax.set_title('ENCOT: CAI Before and After Optimization', fontsize=14, fontweight='bold')
        ax.set_xticks(x[::5])  # Show every 5th label
        ax.set_xticklabels(x[::5])
        ax.legend()
        ax.grid(axis='y', alpha=0.3)
        avg_before = before_cai.mean()
        avg_after = after_cai.mean()
        improvement = ((avg_after - avg_before) / avg_before) * 100
        ax.text(0.02, 0.98, f'Average CAI Before: {avg_before:.3f}\nAverage CAI After: {avg_after:.3f}\nImprovement: {improvement:.1f}%',
                transform=ax.transAxes, fontsize=10, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        plt.tight_layout()
        plt.savefig(os.path.join(fig_dir, 'cai_before_after.png'), dpi=300, bbox_inches='tight')
        plt.close()
        print(f"CAI Before/After graph saved to {os.path.join(fig_dir, 'cai_before_after.png')}")
        # 1b. Median CAI Before/After graph
        plt.figure(figsize=(8, 6))
        median_before = before_cai.median()
        median_after = after_cai.median()
        categories = ['Before Optimization', 'After Optimization']
        medians = [median_before, median_after]
        colors = ['#ff7f0e', '#2ca02c']
        bars = plt.bar(categories, medians, color=colors, alpha=0.8, width=0.6)
        plt.ylabel('Median CAI Score', fontsize=12)
        plt.title('ENCOT: Median CAI Before and After Optimization', fontsize=14, fontweight='bold')
        plt.ylim(0, max(medians) * 1.2)
        for bar, median in zip(bars, medians):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                     f'{median:.3f}', ha='center', va='bottom', fontweight='bold')
        improvement_pct = ((median_after - median_before) / median_before) * 100
        plt.text(0.5, max(medians) * 0.95, f'Improvement: {improvement_pct:.1f}%',
                 ha='center', transform=plt.gca().transData, fontsize=12,
                 bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.7))
        plt.grid(axis='y', alpha=0.3)
        plt.tight_layout()
        plt.savefig(os.path.join(fig_dir, 'median_cai_comparison.png'), dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Median CAI comparison graph saved to {os.path.join(fig_dir, 'median_cai_comparison.png')}")
    # 2. Summary metrics table
    metrics_summary = {}
    if 'original_cai' in results_df.columns:
        # Paired before/after aggregates (DNA inputs with a natural baseline).
        metrics_summary['CAI'] = {
            'Before': results_df['original_cai'].mean(),
            'After': results_df['optimized_cai'].mean(),
            'Improvement': ((results_df['optimized_cai'].mean() - results_df['original_cai'].mean()) / results_df['original_cai'].mean()) * 100
        }
        metrics_summary['GC Content (%)'] = {
            'Before': results_df['original_gc'].mean(),
            'After': results_df['optimized_gc'].mean(),
            'Difference': results_df['optimized_gc'].mean() - results_df['original_gc'].mean()
        }
        metrics_summary['tAI'] = {
            'Before': results_df['original_tai'].mean(),
            'After': results_df['optimized_tai'].mean(),
            'Improvement': ((results_df['optimized_tai'].mean() - results_df['original_tai'].mean()) / results_df['original_tai'].mean()) * 100
        }
        metrics_summary['CFD'] = {
            'Before': results_df['original_cfd'].mean(),
            'After': results_df['optimized_cfd'].mean(),
            'Improvement': ((results_df['optimized_cfd'].mean() - results_df['original_cfd'].mean()) / results_df['original_cfd'].mean()) * 100
        }
        metrics_summary['Negative Cis Elements'] = {
            'Before': results_df['original_neg_cis'].mean(),
            'After': results_df['optimized_neg_cis'].mean(),
            'Reduction': results_df['original_neg_cis'].mean() - results_df['optimized_neg_cis'].mean()
        }
    else:
        # Protein-only inputs: no baseline, report optimized stats only.
        metrics_summary['CAI'] = {
            'Optimized': results_df['optimized_cai'].mean(),
            'Std Dev': results_df['optimized_cai'].std()
        }
        metrics_summary['GC Content (%)'] = {
            'Optimized': results_df['optimized_gc'].mean(),
            'Std Dev': results_df['optimized_gc'].std()
        }
        metrics_summary['tAI'] = {
            'Optimized': results_df['optimized_tai'].mean(),
            'Std Dev': results_df['optimized_tai'].std()
        }
        metrics_summary['CFD'] = {
            'Optimized': results_df['optimized_cfd'].mean(),
            'Std Dev': results_df['optimized_cfd'].std()
        }
        metrics_summary['Negative Cis Elements'] = {
            'Optimized': results_df['optimized_neg_cis'].mean(),
            'Std Dev': results_df['optimized_neg_cis'].std()
        }
    metrics_summary['Runtime (seconds)'] = {
        'Mean': results_df['runtime'].mean(),
        'Median': results_df['runtime'].median(),
        'Total': results_df['runtime'].sum()
    }
    summary_df = pd.DataFrame(metrics_summary).T
    summary_df = summary_df.round(4)
    summary_df.to_csv(os.path.join(output_dir, 'metrics_summary.csv'))
    print(f"\nMetrics Summary saved to {os.path.join(output_dir, 'metrics_summary.csv')}")
    print("\n" + "="*60)
    print("METRICS SUMMARY:")
    print("="*60)
    print(summary_df.to_string())
    # 3. Distribution panel of optimized-sequence metrics
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()
    metrics_to_plot = [
        ('optimized_cai', 'CAI Distribution'),
        ('optimized_gc', 'GC Content Distribution (%)'),
        ('optimized_tai', 'tAI Distribution'),
        ('optimized_cfd', 'CFD Distribution'),
        ('optimized_neg_cis', 'Negative Cis Elements'),
        ('runtime', 'Runtime Distribution (seconds)')
    ]
    for idx, (col, title) in enumerate(metrics_to_plot):
        if col in results_df.columns:
            axes[idx].hist(results_df[col].dropna(), bins=20, edgecolor='black', alpha=0.7)
            axes[idx].set_title(title, fontsize=10, fontweight='bold')
            axes[idx].set_xlabel(col.replace('optimized_', '').replace('_', ' ').title())
            axes[idx].set_ylabel('Frequency')
            axes[idx].grid(axis='y', alpha=0.3)
            mean_val = results_df[col].mean()
            axes[idx].axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_val:.3f}')
            axes[idx].legend()
    plt.suptitle('ENCOT: Optimization Metrics Distribution', fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.savefig(os.path.join(fig_dir, 'metrics_distribution.png'), dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Metrics distribution plot saved to {os.path.join(fig_dir, 'metrics_distribution.png')}")
    return summary_df
def main():
    """CLI entrypoint to run the ENCOT benchmark workflow."""
    parser = argparse.ArgumentParser(description="Benchmark ENCOT on E. coli sequences")
    parser.add_argument("--excel_path", type=str, default="Benchmark 80 sequences.xlsx",
                        help="Path to benchmark Excel file")
    parser.add_argument("--checkpoint_path", type=str, default="models/ecoli-codon-optimizer/finetune_best.ckpt",
                        help="Path to fine-tuned model checkpoint")
    parser.add_argument("--natural_sequences_path", type=str, default="data/ecoli_processed_genes.csv",
                        help="Path to natural E. coli sequences for CAI calculation")
    parser.add_argument("--output_dir", type=str, default="benchmark_results",
                        help="Directory to save results")
    parser.add_argument("--use_gpu", action="store_true", help="Use GPU if available")
    parser.add_argument("--name_col", type=str, default=None, help="Optional: column name for sequence label (case-insensitive)")
    parser.add_argument("--seq_col", type=str, default=None, help="Optional: column name for sequence (case-insensitive)")
    parser.add_argument("--sheet_name", type=str, default=None, help="Optional: Excel sheet name or index")
    args = parser.parse_args()
    # Each run writes into a timestamped subdirectory of --output_dir.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(args.output_dir, f"run_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)
    print("="*60)
    print("ENCOT BENCHMARK EVALUATION")
    print("="*60)
    # GPU is opt-in via --use_gpu and only used when CUDA is actually available.
    device = torch.device("cuda" if torch.cuda.is_available() and args.use_gpu else "cpu")
    print(f"Using device: {device}")
    print(f"\nLoading sequences from {args.excel_path}...")
    sequences = parse_excel_sequences(
        args.excel_path,
        name_col=args.name_col,
        seq_col=args.seq_col,
        sheet_name=args.sheet_name,
    )
    print(f"Loaded {len(sequences)} sequences")
    print("\nLoading ENCOT model...")
    model = load_model(model_path=args.checkpoint_path, device=device)
    tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
    print("Model loaded successfully")
    print("\nPreparing evaluation utilities...")
    # CAI weights are derived from the natural E. coli reference CDS set.
    natural_df = pd.read_csv(args.natural_sequences_path)
    ref_sequences = natural_df['dna_sequence'].tolist()
    cai_weights = relative_adaptiveness(sequences=ref_sequences)
    print("CAI weights generated")
    tai_weights = get_ecoli_tai_weights()
    print("tAI weights loaded")
    try:
        # Taxonomy id 83333 = E. coli K-12; fall back to local frequencies offline.
        codon_frequencies = download_codon_frequencies_from_kazusa(taxonomy_id=83333)
        print("Codon frequencies loaded from Kazusa")
    except Exception as e:
        print(f"Warning: Kazusa download failed ({e}). Using local frequencies.")
        codon_frequencies = get_codon_frequencies(
            ref_sequences, organism="Escherichia coli general"
        )
    # Placeholder: DTW reference profile is currently unused downstream.
    reference_profile = []
    print("\n" + "="*60)
    print("RUNNING OPTIMIZATION...")
    print("="*60)
    results_df = run_model_on_sequences(
        sequences=sequences,
        model=model,
        tokenizer=tokenizer,
        device=device,
        cai_weights=cai_weights,
        tai_weights=tai_weights,
        codon_frequencies=codon_frequencies,
        reference_profile=reference_profile,
        output_dir=output_dir
    )
    results_path = os.path.join(output_dir, 'optimization_results.csv')
    results_df.to_csv(results_path, index=False)
    print(f"\nRaw results saved to {results_path}")
    # Export a compact per-sequence table alongside the raw results.
    optimized_sequences = results_df[['id', 'name', 'protein_sequence', 'optimized_dna']].copy()
    optimized_sequences['protein_length'] = results_df['protein_length']
    optimized_sequences['dna_length'] = optimized_sequences['optimized_dna'].apply(len)
    optimized_sequences['optimized_cai'] = results_df['optimized_cai']
    optimized_sequences['optimized_gc'] = results_df['optimized_gc']
    optimized_sequences['optimized_tai'] = results_df['optimized_tai']
    if 'original_cai' in results_df.columns:
        # Baseline CAI is only present when inputs were natural DNA sequences.
        optimized_sequences['original_cai'] = results_df['original_cai']
        optimized_sequences['cai_improvement'] = ((results_df['optimized_cai'] - results_df['original_cai']) / results_df['original_cai'] * 100).round(2)
    optimized_sequences_path = os.path.join(output_dir, 'optimized_dna_sequences.csv')
    optimized_sequences.to_csv(optimized_sequences_path, index=False)
    print(f"Optimized DNA sequences saved to {optimized_sequences_path}")
    print("\n" + "="*60)
    print("GENERATING VISUALIZATIONS...")
    print("="*60)
    summary_df = generate_visualizations(results_df, output_dir)
    print("\n" + "="*60)
    print("BENCHMARK EVALUATION COMPLETE")
    print("="*60)
    print(f"Results saved to: {output_dir}")
    print(f"Total sequences processed: {len(results_df)}")
    print(f"Average runtime per sequence: {results_df['runtime'].mean():.2f} seconds")
    print(f"Total runtime: {results_df['runtime'].sum():.2f} seconds")


if __name__ == "__main__":
    main()