import torch
from transformers import LogitsProcessor
from Bio.Seq import Seq

# Complete amino acid to codon mapping (the standard genetic code, used by human
# nuclear genes). Keys are one-letter amino acid codes; '*' denotes a stop codon.
aa_to_codon_human = {
    'F': ['TTT', 'TTC'],
    'L': ['TTA', 'TTG', 'CTT', 'CTC', 'CTA', 'CTG'],
    'S': ['TCT', 'TCC', 'TCA', 'TCG', 'AGT', 'AGC'],
    'Y': ['TAT', 'TAC'],
    'C': ['TGT', 'TGC'],
    'W': ['TGG'],
    'P': ['CCT', 'CCC', 'CCA', 'CCG'],
    'H': ['CAT', 'CAC'],
    'Q': ['CAA', 'CAG'],
    'R': ['CGT', 'CGC', 'CGA', 'CGG', 'AGA', 'AGG'],
    'I': ['ATT', 'ATC', 'ATA'],
    'M': ['ATG'],
    'T': ['ACT', 'ACC', 'ACA', 'ACG'],
    'N': ['AAT', 'AAC'],
    'K': ['AAA', 'AAG'],
    'V': ['GTT', 'GTC', 'GTA', 'GTG'],
    'A': ['GCT', 'GCC', 'GCA', 'GCG'],
    'D': ['GAT', 'GAC'],
    'E': ['GAA', 'GAG'],
    'G': ['GGT', 'GGC', 'GGA', 'GGG'],
    '*': ['TAA', 'TAG', 'TGA'],
}


class SynonymMaskingLogitsProcessor(LogitsProcessor):
    """Restrict next-token sampling to codons synonymous with one amino acid.

    Every vocabulary entry except the tokens of the synonymous codons for
    ``current_aa`` receives a score of ``-inf``, so sampling can only select one
    of those codons.

    Args:
        current_aa: One-letter amino acid code ('*' for stop) whose codons are allowed.
        tokenizer: Tokenizer whose vocabulary contains the codon strings as tokens.
        aa_to_codon: Amino acid -> list of codons mapping; defaults to
            ``aa_to_codon_human``.
    """

    def __init__(self, current_aa, tokenizer, aa_to_codon=None):
        self.current_aa = current_aa
        self.tokenizer = tokenizer
        self.aa_to_codon = aa_to_codon if aa_to_codon is not None else aa_to_codon_human

    def _allowed_token_ids(self):
        """Return the vocabulary ids of the synonymous codons the tokenizer knows.

        Raises:
            ValueError: if no synonymous codon maps to a real vocabulary token
                (unknown amino acid, or codons missing from the vocabulary).
        """
        synonymous_codons = self.aa_to_codon.get(self.current_aa, [])
        token_ids = self.tokenizer.convert_tokens_to_ids(synonymous_codons)
        unk_id = getattr(self.tokenizer, 'unk_token_id', None)
        # Codons missing from the vocabulary come back as None or as the unknown
        # token; allowing those would either crash the indexing below or let the
        # model emit something that is not a codon.
        allowed = [tid for tid in token_ids if tid is not None and tid != unk_id]
        if not allowed:
            # An all -inf row would make sampling fail with an opaque NaN error.
            raise ValueError(
                f"No vocabulary tokens for codons of amino acid {self.current_aa!r}"
            )
        return allowed

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        allowed = self._allowed_token_ids()
        mask = torch.full_like(scores, float('-inf'))
        mask[:, allowed] = 0
        return scores + mask


def generate_candidate_codons_with_generate(initial_codons, temperature=1.0, top_k=None, top_p=None,
                                            aa_to_codon=None, model=None, tokenizer=None):
    """
    Generate synonymous codon alternatives for a given set of codons.

    Codons are sampled one at a time, left to right. Each step conditions on the
    codons already chosen, and the choice is restricted to codons that encode
    the same amino acid as the input codon at that position.

    Args:
        initial_codons: List of codons to optimize
        temperature: Sampling temperature
        top_k: Top-k sampling parameter
        top_p: Top-p (nucleus) sampling parameter
        aa_to_codon: Amino acid to codon mapping (defaults to human genetic code)
        model: The CodonGPT model (if None, uses global 'model' variable)
        tokenizer: The codon tokenizer (if None, uses global 'tokenizer' variable)

    Returns:
        List of optimized codons (decoded tokens, upper-cased), one per input codon.

    Raises:
        ValueError: if no model/tokenizer is available, the tokenizer has no BOS
            token, or an input codon translates to an amino acid that has no
            entry in ``aa_to_codon``.
    """
    import builtins

    # Use global variables if not provided as parameters
    if model is None:
        model = getattr(builtins, 'model', globals().get('model'))
        if model is None:
            raise ValueError("Model not provided and no global 'model' variable found")
    if tokenizer is None:
        tokenizer = getattr(builtins, 'tokenizer', globals().get('tokenizer'))
        if tokenizer is None:
            raise ValueError("Tokenizer not provided and no global 'tokenizer' variable found")
    if aa_to_codon is None:
        aa_to_codon = aa_to_codon_human

    if tokenizer.bos_token_id is None:
        raise ValueError("Tokenizer has no bos_token_id to start generation from")

    # Build inputs on the model's device; a CPU tensor fed to a GPU model fails.
    device = getattr(model, 'device', None)

    optimized_codons = []
    current_sequence_tokens = [tokenizer.bos_token_id]

    for codon in initial_codons:
        aa = str(Seq(codon).translate())
        if aa not in aa_to_codon:
            # e.g. 'X' for ambiguous bases, or several residues for a non-triplet.
            raise ValueError(
                f"Codon {codon!r} translates to {aa!r}, which has no entry in aa_to_codon"
            )

        logits_processor = [SynonymMaskingLogitsProcessor(aa, tokenizer, aa_to_codon)]
        input_ids = torch.tensor([current_sequence_tokens], device=device)

        output = model.generate(
            input_ids,
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=1,  # exactly one codon per step
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
            logits_processor=logits_processor,
            do_sample=True,
        )

        next_token_id = output[0][-1].item()
        predicted_codon = tokenizer.decode([next_token_id])
        optimized_codons.append(predicted_codon.upper())
        current_sequence_tokens.append(next_token_id)

    return optimized_codons