File size: 12,522 Bytes
404d784
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
"""

Optimize protein sequences using ColiFormer.



This script provides a user-friendly interface for codon optimization,

supporting both single sequences and batch processing via FASTA files.



Usage:

    # Single sequence

    python scripts/optimize_sequence.py --input "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG" --output optimized.fasta



    # Batch processing from FASTA file

    python scripts/optimize_sequence.py --input sequences.fasta --output optimized.fasta --batch



    # With GC content constraints

    python scripts/optimize_sequence.py --input protein.fasta --output optimized.fasta --gc-min 0.45 --gc-max 0.55

"""

import argparse
import os
import sys
from pathlib import Path
from typing import Any, List, Optional, Tuple

# Add parent directory to path to import CodonTransformer
sys.path.insert(0, str(Path(__file__).parent.parent))


def parse_fasta(fasta_path: str) -> List[Tuple[str, str]]:
    """
    Parse a FASTA file into a list of (name, sequence) tuples.

    Header lines start with ``>``; the stripped text after ``>`` becomes the
    record name, or ``sequence_<k>`` (``k`` = 1-based record position) when
    the header is empty. Sequence lines are upper-cased, stripped of all
    whitespace and concatenated. Blank lines are ignored. Sequence lines that
    appear before the first header (e.g. a file holding a bare sequence) form
    a record with a generated name instead of being silently dropped.

    Args:
        fasta_path: Path to FASTA file (read as UTF-8)

    Returns:
        List of (name, sequence) tuples in file order. A header with no
        sequence lines yields an empty sequence string.
    """
    sequences = []
    current_name = None
    current_seq = []

    with open(fasta_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith('>'):
                # Close the previous record before starting a new one.
                if current_name is not None:
                    sequences.append((current_name, ''.join(current_seq)))
                header = line[1:].strip()
                current_name = header or f"sequence_{len(sequences) + 1}"
                current_seq = []
            else:
                if current_name is None:
                    # Data before any header: keep it under a generated name
                    # rather than discarding it.
                    current_name = f"sequence_{len(sequences) + 1}"
                # Remove interior whitespace too (some writers space residues).
                current_seq.append(''.join(line.split()).upper())

    if current_name is not None:
        sequences.append((current_name, ''.join(current_seq)))

    return sequences


def write_fasta(output_path: str, sequences: List[Tuple[str, str]], line_width: int = 60) -> None:
    """
    Write sequences to a FASTA file.

    Each record is written as a ``>name`` header line followed by the
    sequence wrapped at ``line_width`` characters per line. A record with an
    empty sequence produces only its header line.

    Args:
        output_path: Output FASTA file path (overwritten if it exists)
        sequences: List of (name, sequence) tuples
        line_width: Maximum characters per sequence line (default 60)

    Raises:
        ValueError: If ``line_width`` is not positive.
    """
    if line_width <= 0:
        raise ValueError(f"line_width must be positive, got {line_width}")

    with open(output_path, 'w', encoding='utf-8') as f:
        for name, seq in sequences:
            f.write(f">{name}\n")
            f.writelines(
                seq[start:start + line_width] + "\n"
                for start in range(0, len(seq), line_width)
            )


def optimize_single_sequence(
    protein: str,
    model: Any,
    tokenizer: Any,
    device: Any,
    organism: str = "Escherichia coli general",
    gc_min: Optional[float] = None,
    gc_max: Optional[float] = None,
    cai_weights: Optional[dict] = None,
    tai_weights: Optional[dict] = None,
) -> dict:
    """
    Optimize a single protein sequence.

    When at least one GC bound is given, constrained search is requested
    (``use_constrained_search=True``, ``beam_size=20``) with
    ``gc_bounds=(gc_min, gc_max)``. A missing side defaults to 0.0 (min) or
    1.0 (max), so a single bound is still enforced. Without bounds,
    ``beam_size=5`` is passed. Decoding is always requested as deterministic
    with ``match_protein=True``.

    Args:
        protein: Protein sequence string
        model: Loaded ColiFormer model
        tokenizer: Tokenizer
        device: PyTorch device
        organism: Target organism name
        gc_min: Minimum GC content as a fraction (0-1), or None
        gc_max: Maximum GC content as a fraction (0-1), or None
        cai_weights: CAI weights dictionary, or None to skip CAI
        tai_weights: tAI weights dictionary, or None to skip tAI

    Returns:
        Dictionary with keys ``protein``, ``optimized_dna``, ``gc_content``
        (``get_GC_content`` result divided by 100), ``length`` (DNA length),
        ``cai`` and ``tai``. ``cai``/``tai`` are None when the corresponding
        weights are missing or the metric computation raised.

    Raises:
        ValueError: If the GC bounds are outside [0, 1] or gc_min > gc_max.
    """
    # Resolve and validate GC bounds before the heavy lazy imports so that
    # bad input fails fast.
    gc_bounds = None
    if gc_min is not None or gc_max is not None:
        low = 0.0 if gc_min is None else gc_min
        high = 1.0 if gc_max is None else gc_max
        # Chained comparison is also False for NaN, which is rejected too.
        if not 0.0 <= low <= high <= 1.0:
            raise ValueError(
                "GC bounds must satisfy 0 <= gc_min <= gc_max <= 1 "
                f"(got gc_min={gc_min}, gc_max={gc_max})"
            )
        gc_bounds = (low, high)
    use_constrained = gc_bounds is not None

    # Lazy imports so `python scripts/optimize_sequence.py --help` works without ML deps installed.
    from CodonTransformer.CodonPrediction import predict_dna_sequence
    from CodonTransformer.CodonEvaluation import get_GC_content, calculate_tAI
    from CAI import CAI

    # Run optimization
    output = predict_dna_sequence(
        protein=protein,
        organism=organism,
        device=device,
        model=model,
        tokenizer=tokenizer,
        deterministic=True,
        match_protein=True,
        use_constrained_search=use_constrained,
        gc_bounds=gc_bounds,
        beam_size=20 if use_constrained else 5,
    )

    # predict_dna_sequence may return a list of outputs; keep the first one.
    if isinstance(output, list):
        output = output[0]

    optimized_dna = output.predicted_dna

    metrics = {
        'protein': protein,
        'optimized_dna': optimized_dna,
        'gc_content': get_GC_content(optimized_dna) / 100.0,  # percent -> fraction
        'length': len(optimized_dna),
        'cai': None,
        'tai': None,
    }

    # Metrics are best-effort: a failure must not abort the optimization, but
    # only ordinary errors are swallowed (not KeyboardInterrupt/SystemExit).
    if cai_weights:
        try:
            metrics['cai'] = CAI(optimized_dna, weights=cai_weights)
        except Exception:
            metrics['cai'] = None

    if tai_weights:
        try:
            metrics['tai'] = calculate_tAI(optimized_dna, tai_weights)
        except Exception:
            metrics['tai'] = None

    return metrics


def load_reference_data(ref_sequences_path: Optional[str] = None):
    """
    Load the weights used for the CAI and tAI metrics.

    CAI weights are computed with ``CAI.relative_adaptiveness`` from the
    ``dna_sequence`` column of the CSV at ``ref_sequences_path``. tAI weights
    come from ``get_ecoli_tai_weights()``. Each failure is reported as a
    printed warning, and the corresponding weights stay ``None`` so that
    optimization can proceed without that metric.

    Args:
        ref_sequences_path: Path to a CSV with a ``dna_sequence`` column, or
            None/empty to skip CAI weights.

    Returns:
        Tuple of (cai_weights, tai_weights); either may be None.
    """
    # Lazy imports so `--help` works without ML deps installed.
    import pandas as pd
    from CAI import relative_adaptiveness
    from CodonTransformer.CodonEvaluation import get_ecoli_tai_weights

    cai_weights = None
    tai_weights = None

    # Try to load reference sequences for CAI; warn (rather than silently
    # skip) when a path was given but cannot be used.
    if ref_sequences_path and not os.path.exists(ref_sequences_path):
        print(f"Warning: Reference sequences file not found: {ref_sequences_path} (CAI will not be computed)")
    elif ref_sequences_path:
        try:
            df = pd.read_csv(ref_sequences_path)
            if 'dna_sequence' not in df.columns:
                print(f"Warning: No 'dna_sequence' column in {ref_sequences_path} (CAI will not be computed)")
            else:
                # Empty CSV cells are read by pandas as NaN floats, not DNA
                # strings, so drop them before computing weights.
                ref_sequences = df['dna_sequence'].dropna().astype(str).tolist()
                if not ref_sequences:
                    print(f"Warning: No reference sequences in {ref_sequences_path} (CAI will not be computed)")
                else:
                    cai_weights = relative_adaptiveness(sequences=ref_sequences)
                    print(f"Loaded CAI weights from {len(ref_sequences)} reference sequences")
        except Exception as e:
            print(f"Warning: Could not load CAI weights: {e}")

    # Load tAI weights
    try:
        tai_weights = get_ecoli_tai_weights()
        print("Loaded E. coli tAI weights")
    except Exception as e:
        print(f"Warning: Could not load tAI weights: {e}")

    return cai_weights, tai_weights


def _build_arg_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser for this script.

    Returns:
        A configured, not-yet-parsed ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description="Optimize protein sequences using ENCOT",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""

Examples:

    # Single sequence

    python scripts/optimize_sequence.py --input "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG" --output optimized.fasta



    # Batch processing from FASTA file

    python scripts/optimize_sequence.py --input sequences.fasta --output optimized.fasta --batch



    # With GC content constraints

    python scripts/optimize_sequence.py --input protein.fasta --output optimized.fasta --gc-min 0.45 --gc-max 0.55



    # Use custom checkpoint

    python scripts/optimize_sequence.py --input protein.fasta --output optimized.fasta --checkpoint models/my_model.ckpt

        """
    )
    parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="Input protein sequence (string) or FASTA file path"
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output FASTA file path"
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="Path to model checkpoint (default: auto-download from Hugging Face)"
    )
    parser.add_argument(
        "--organism",
        type=str,
        default="Escherichia coli general",
        help="Target organism (default: Escherichia coli general)"
    )
    parser.add_argument(
        "--gc-min",
        type=float,
        default=None,
        help="Minimum GC content (0-1, e.g., 0.45 for 45%%)"
    )
    parser.add_argument(
        "--gc-max",
        type=float,
        default=None,
        help="Maximum GC content (0-1, e.g., 0.55 for 55%%)"
    )
    parser.add_argument(
        "--batch",
        action="store_true",
        help="Process input as FASTA file with multiple sequences"
    )
    parser.add_argument(
        "--ref-sequences",
        type=str,
        default="data/ecoli_processed_genes.csv",
        help="Path to reference sequences CSV for CAI calculation"
    )
    parser.add_argument(
        "--use-gpu",
        action="store_true",
        help="Use GPU if available"
    )
    return parser


def _validate_gc_args(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
    """Exit via ``parser.error`` (status 2) if --gc-min/--gc-max are out of range or inverted."""
    for flag, value in (("--gc-min", args.gc_min), ("--gc-max", args.gc_max)):
        # `not 0 <= v <= 1` is also True for NaN, so NaN is rejected as well.
        if value is not None and not 0.0 <= value <= 1.0:
            parser.error(f"{flag} must be a fraction between 0 and 1 (got {value})")
    if args.gc_min is not None and args.gc_max is not None and args.gc_min > args.gc_max:
        parser.error(f"--gc-min ({args.gc_min}) must not exceed --gc-max ({args.gc_max})")


def _read_input_sequences(input_arg: str, batch: bool) -> List[Tuple[str, str]]:
    """Return (name, protein) records from a FASTA path, or wrap a literal sequence string."""
    if batch or os.path.exists(input_arg):
        # FASTA file
        print(f"Reading sequences from {input_arg}...")
        sequences = parse_fasta(input_arg)
        print(f"Found {len(sequences)} sequences")
        return sequences
    # Single sequence string; drop whitespace so pasted/wrapped input still works.
    return [("sequence_1", "".join(input_arg.split()).upper())]


def _load_model(checkpoint: Optional[str], device: Any) -> Any:
    """Load from ``checkpoint`` if given, else the Hugging Face ColiFormer checkpoint, else base CodonTransformer."""
    # Lazy import so `--help` works without ML deps installed.
    from CodonTransformer.CodonPrediction import load_model

    if checkpoint:
        model = load_model(model_path=checkpoint, device=device)
        print(f"Loaded model from {checkpoint}")
        return model

    # Try to load from Hugging Face
    try:
        from huggingface_hub import hf_hub_download
        checkpoint_path = hf_hub_download(
            repo_id="saketh11/ColiFormer",
            filename="balanced_alm_finetune.ckpt",
            cache_dir="./hf_cache"
        )
        model = load_model(model_path=checkpoint_path, device=device)
        print("Loaded model from Hugging Face (saketh11/ColiFormer)")
        return model
    except Exception as e:
        print(f"Warning: Could not load from Hugging Face: {e}")
        print("Falling back to base CodonTransformer model...")
        from transformers import BigBirdForMaskedLM
        return BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer").to(device)


def main():
    """
    Main entry point for sequence optimization.

    Parses and validates arguments, reads the input sequence(s), loads the
    model, tokenizer and metric weights, optimizes each sequence, writes the
    results as FASTA, and prints per-sequence metrics (plus a summary table
    when more than one sequence was processed). Invalid GC arguments exit with
    status 2 before anything heavy is loaded. Any runtime error is printed
    with a traceback and exits with status 1.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    # Reject bad GC bounds before any heavy import or model download.
    _validate_gc_args(parser, args)

    try:
        # Read input first so a bad path or empty file fails before the slow
        # imports and model download.
        sequences = _read_input_sequences(args.input, args.batch)
        if not sequences:
            raise ValueError(f"No sequences found in {args.input}")
        empty_names = [name for name, seq in sequences if not seq]
        if empty_names:
            raise ValueError(f"Empty protein sequence(s): {', '.join(empty_names)}")

        # Lazy imports so `--help` works without ML deps installed.
        import torch
        from transformers import AutoTokenizer
        import pandas as pd

        # Setup device
        if args.use_gpu and not torch.cuda.is_available():
            print("Warning: --use-gpu was given but CUDA is not available; using CPU")
        device = torch.device("cuda" if torch.cuda.is_available() and args.use_gpu else "cpu")
        print(f"Using device: {device}")

        # Load model
        print("Loading ColiFormer model...")
        model = _load_model(args.checkpoint, device)

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")

        # Load reference data for metrics
        cai_weights, tai_weights = load_reference_data(args.ref_sequences)

        # Optimize sequences
        optimized_sequences = []
        results = []

        for i, (name, protein_seq) in enumerate(sequences, 1):
            print(f"\nOptimizing sequence {i}/{len(sequences)}: {name}")

            metrics = optimize_single_sequence(
                protein=protein_seq,
                model=model,
                tokenizer=tokenizer,
                device=device,
                organism=args.organism,
                gc_min=args.gc_min,
                gc_max=args.gc_max,
                cai_weights=cai_weights,
                tai_weights=tai_weights
            )

            optimized_sequences.append((name, metrics['optimized_dna']))
            results.append({
                'name': name,
                'protein_length': len(protein_seq),
                'dna_length': metrics['length'],
                'gc_content': f"{metrics['gc_content']*100:.2f}%",
                'cai': metrics['cai'],
                'tai': metrics['tai'],
            })

            print(f"  GC content: {metrics['gc_content']*100:.2f}%")
            # Compare against None: a legitimate score of 0.0 must still be shown.
            if metrics['cai'] is not None:
                print(f"  CAI: {metrics['cai']:.3f}")
            if metrics['tai'] is not None:
                print(f"  tAI: {metrics['tai']:.3f}")

        # Write output
        write_fasta(args.output, optimized_sequences)
        print(f"\nOptimized sequences saved to {args.output}")

        # Print summary
        if len(results) > 1:
            print("\n" + "="*60)
            print("Summary Statistics")
            print("="*60)
            df = pd.DataFrame(results)
            print(df.to_string(index=False))
            print("="*60)

    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()