|
|
import torch |
|
|
from transformers import LogitsProcessor |
|
|
from Bio.Seq import Seq |
|
|
|
|
|
|
|
|
# Standard (human/nuclear) genetic code, keyed by one-letter amino-acid
# symbol ('*' = stop). Each value lists every synonymous codon, 64 in total.
aa_to_codon_human = {
    'F': ['TTT', 'TTC'],
    'L': ['TTA', 'TTG', 'CTT', 'CTC', 'CTA', 'CTG'],
    'S': ['TCT', 'TCC', 'TCA', 'TCG', 'AGT', 'AGC'],
    'Y': ['TAT', 'TAC'],
    'C': ['TGT', 'TGC'],
    'W': ['TGG'],
    'P': ['CCT', 'CCC', 'CCA', 'CCG'],
    'H': ['CAT', 'CAC'],
    'Q': ['CAA', 'CAG'],
    'R': ['CGT', 'CGC', 'CGA', 'CGG', 'AGA', 'AGG'],
    'I': ['ATT', 'ATC', 'ATA'],
    'M': ['ATG'],
    'T': ['ACT', 'ACC', 'ACA', 'ACG'],
    'N': ['AAT', 'AAC'],
    'K': ['AAA', 'AAG'],
    'V': ['GTT', 'GTC', 'GTA', 'GTG'],
    'A': ['GCT', 'GCC', 'GCA', 'GCG'],
    'D': ['GAT', 'GAC'],
    'E': ['GAA', 'GAG'],
    'G': ['GGT', 'GGC', 'GGA', 'GGG'],
    '*': ['TAA', 'TAG', 'TGA'],
}
|
|
|
|
|
class SynonymMaskingLogitsProcessor(LogitsProcessor):
    """Constrain next-token sampling to codons synonymous with one amino acid.

    On each call, every vocabulary position that does not correspond to a
    codon encoding ``current_aa`` is pushed to -inf, so sampling can only
    select a synonymous codon token.
    """

    def __init__(self, current_aa, tokenizer, aa_to_codon=None):
        # One-letter amino-acid symbol whose synonymous codons remain allowed.
        self.current_aa = current_aa
        self.tokenizer = tokenizer
        # Default to the module-level human codon table when none is supplied.
        self.aa_to_codon = aa_to_codon_human if aa_to_codon is None else aa_to_codon

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        allowed_codons = self.aa_to_codon.get(self.current_aa, [])
        allowed_ids = self.tokenizer.convert_tokens_to_ids(allowed_codons)
        # Additive mask: -inf everywhere, 0 at the allowed codon columns.
        penalty = torch.full_like(scores, -float('inf'))
        penalty[:, allowed_ids] = 0
        return scores + penalty
|
|
|
|
|
def generate_candidate_codons_with_generate(initial_codons, temperature=1.0, top_k=None, top_p=None, aa_to_codon=None, model=None, tokenizer=None):
    """
    Generate synonymous codon alternatives for a given set of codons.

    Each input codon is translated to its amino acid, then the model samples
    one replacement codon restricted (via SynonymMaskingLogitsProcessor) to
    the synonymous codons for that amino acid, conditioning on the codons
    chosen so far.

    Args:
        initial_codons: List of codons to optimize
        temperature: Sampling temperature
        top_k: Top-k sampling parameter
        top_p: Top-p (nucleus) sampling parameter
        aa_to_codon: Amino acid to codon mapping (defaults to human genetic code)
        model: The CodonGPT model (if None, uses global 'model' variable)
        tokenizer: The codon tokenizer (if None, uses global 'tokenizer' variable)

    Returns:
        List of optimized codons (uppercase)

    Raises:
        ValueError: if model/tokenizer cannot be resolved, or an input codon
            is not present in the aa_to_codon mapping.
    """
    # Legacy convenience: fall back to a 'model'/'tokenizer' variable in the
    # caller's scope (notebook-style usage). Kept for backward compatibility.
    if model is None or tokenizer is None:
        import inspect

        caller = inspect.currentframe().f_back
        try:
            if model is None:
                model = caller.f_locals.get('model') or caller.f_globals.get('model')
                if model is None:
                    raise ValueError("Model not provided and no global 'model' variable found")
            if tokenizer is None:
                tokenizer = caller.f_locals.get('tokenizer') or caller.f_globals.get('tokenizer')
                if tokenizer is None:
                    raise ValueError("Tokenizer not provided and no global 'tokenizer' variable found")
        finally:
            # Break the frame reference cycle (recommended by the inspect docs).
            del caller

    if aa_to_codon is None:
        aa_to_codon = aa_to_codon_human

    # Invert the table once so translation is consistent with whatever
    # mapping is in use. (Previously Bio.Seq translated with the *standard*
    # genetic code, which could disagree with a custom aa_to_codon table.)
    codon_to_aa = {codon: aa for aa, codons in aa_to_codon.items() for codon in codons}

    optimized_codons = []
    current_sequence_tokens = [tokenizer.bos_token_id]

    for codon in initial_codons:
        try:
            aa = codon_to_aa[codon.upper()]
        except KeyError:
            raise ValueError(f"Codon {codon!r} not found in aa_to_codon mapping") from None

        # Restrict sampling to codons synonymous with this amino acid.
        logits_processor = [SynonymMaskingLogitsProcessor(aa, tokenizer, aa_to_codon)]

        input_ids = torch.tensor([current_sequence_tokens])

        output = model.generate(
            input_ids,
            # Exactly one new token per input codon.
            max_length=len(current_sequence_tokens) + 1,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
            logits_processor=logits_processor,
            do_sample=True,
        )

        next_token_id = output[0][-1].item()
        predicted_codon = tokenizer.decode([next_token_id])

        optimized_codons.append(predicted_codon.upper())
        # Condition the next step on everything sampled so far.
        current_sequence_tokens.append(next_token_id)

    return optimized_codons