# codonGPT / synonymous_logit_processor.py
# (provenance from the Hugging Face file page: uploaded by anuj2054,
#  commit fdbfd89 "Update synonymous_logit_processor.py" — kept as a
#  comment because the raw page-chrome lines are not valid Python)
import torch
from transformers import LogitsProcessor
from Bio.Seq import Seq
# Standard (human) genetic code: one-letter amino-acid code -> list of
# synonymous DNA codons.  '*' collects the three stop codons.
aa_to_codon_human = {
    'F': ['TTT', 'TTC'],                              # Phenylalanine
    'L': ['TTA', 'TTG', 'CTT', 'CTC', 'CTA', 'CTG'],  # Leucine
    'S': ['TCT', 'TCC', 'TCA', 'TCG', 'AGT', 'AGC'],  # Serine
    'Y': ['TAT', 'TAC'],                              # Tyrosine
    'C': ['TGT', 'TGC'],                              # Cysteine
    'W': ['TGG'],                                     # Tryptophan
    'P': ['CCT', 'CCC', 'CCA', 'CCG'],                # Proline
    'H': ['CAT', 'CAC'],                              # Histidine
    'Q': ['CAA', 'CAG'],                              # Glutamine
    'R': ['CGT', 'CGC', 'CGA', 'CGG', 'AGA', 'AGG'],  # Arginine
    'I': ['ATT', 'ATC', 'ATA'],                       # Isoleucine
    'M': ['ATG'],                                     # Methionine (start)
    'T': ['ACT', 'ACC', 'ACA', 'ACG'],                # Threonine
    'N': ['AAT', 'AAC'],                              # Asparagine
    'K': ['AAA', 'AAG'],                              # Lysine
    'V': ['GTT', 'GTC', 'GTA', 'GTG'],                # Valine
    'A': ['GCT', 'GCC', 'GCA', 'GCG'],                # Alanine
    'D': ['GAT', 'GAC'],                              # Aspartate
    'E': ['GAA', 'GAG'],                              # Glutamate
    'G': ['GGT', 'GGC', 'GGA', 'GGG'],                # Glycine
    '*': ['TAA', 'TAG', 'TGA'],                       # Stop codons
}
class SynonymMaskingLogitsProcessor(LogitsProcessor):
    """Constrain next-token sampling to codons synonymous with one amino acid.

    At each generation step, every vocabulary entry that is not a codon
    encoding ``current_aa`` is masked to ``-inf`` so the model can only
    choose among synonymous codons.
    """

    def __init__(self, current_aa, tokenizer, aa_to_codon=None):
        # Single-letter amino-acid code ('*' for stop) whose synonyms are allowed.
        self.current_aa = current_aa
        self.tokenizer = tokenizer
        # Fall back to the module-level human genetic code table.
        self.aa_to_codon = aa_to_codon if aa_to_codon is not None else aa_to_codon_human

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        synonymous_codons = self.aa_to_codon.get(self.current_aa, [])
        token_ids = self.tokenizer.convert_tokens_to_ids(synonymous_codons)
        # Drop codons the tokenizer could not resolve: a None id would crash
        # the index assignment, and an unk id would let the unk token through
        # the mask.
        unk_id = getattr(self.tokenizer, "unk_token_id", None)
        valid_ids = [t for t in token_ids if t is not None and t != unk_id]
        if not valid_ids:
            # Unknown amino acid or no resolvable codon token: masking every
            # entry to -inf would make sampling impossible, so pass scores
            # through unchanged instead.
            return scores
        mask = torch.full_like(scores, float("-inf"))
        mask[:, valid_ids] = 0  # keep only synonymous-codon logits
        return scores + mask
def generate_candidate_codons_with_generate(initial_codons, temperature=1.0, top_k=None, top_p=None, aa_to_codon=None, model=None, tokenizer=None):
    """
    Generate synonymous codon alternatives for a given set of codons.

    For each input codon, the model samples exactly one replacement token,
    constrained by SynonymMaskingLogitsProcessor to codons encoding the same
    amino acid.  Sampling is autoregressive: each chosen codon is appended to
    the running context before predicting the next one.

    Args:
        initial_codons: List of codons (3-letter DNA strings) to optimize.
        temperature: Sampling temperature.
        top_k: Top-k sampling parameter (None disables it).
        top_p: Top-p (nucleus) sampling parameter (None disables it).
        aa_to_codon: Amino acid to codon mapping (defaults to human genetic code).
        model: The CodonGPT model (if None, resolved from the caller's scope).
        tokenizer: The codon tokenizer (if None, resolved from the caller's scope).

    Returns:
        List of optimized codons (upper-case strings), one per input codon.

    Raises:
        ValueError: If model/tokenizer are neither passed nor found in the
            caller's namespace.
    """

    def _from_caller(name):
        # Backward-compat shim: resolve `name` from the *caller's* locals or
        # globals (notebook-style usage).  Explicit None checks so that a
        # falsy-but-real local does not silently fall through to globals, as
        # the previous `or`-chain did.
        import inspect
        frame = inspect.currentframe()
        try:
            # Skip _from_caller's own frame and the enclosing function's frame.
            caller = frame.f_back.f_back
            if caller is None:
                return None
            value = caller.f_locals.get(name)
            if value is None:
                value = caller.f_globals.get(name)
            return value
        finally:
            # Break the frame reference cycle (recommended by the inspect docs).
            del frame

    if model is None:
        model = _from_caller('model')
        if model is None:
            raise ValueError("Model not provided and no global 'model' variable found")
    if tokenizer is None:
        tokenizer = _from_caller('tokenizer')
        if tokenizer is None:
            raise ValueError("Tokenizer not provided and no global 'tokenizer' variable found")
    if aa_to_codon is None:
        aa_to_codon = aa_to_codon_human

    # Build inputs on the model's device to avoid CPU/GPU tensor mismatches.
    device = getattr(model, "device", None)

    optimized_codons = []
    current_sequence_tokens = [tokenizer.bos_token_id]
    for codon in initial_codons:
        # Translate the codon to know which synonym set to allow.
        aa = str(Seq(codon).translate())
        logits_processor = [SynonymMaskingLogitsProcessor(aa, tokenizer, aa_to_codon)]
        input_ids = torch.tensor([current_sequence_tokens])
        if device is not None:
            input_ids = input_ids.to(device)
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            output = model.generate(
                input_ids,
                max_new_tokens=1,  # exactly one codon token per step
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                num_return_sequences=1,
                pad_token_id=tokenizer.pad_token_id,
                logits_processor=logits_processor,
                do_sample=True,
            )
        next_token_id = output[0][-1].item()
        predicted_codon = tokenizer.decode([next_token_id])
        optimized_codons.append(predicted_codon.upper())
        current_sequence_tokens.append(next_token_id)
    return optimized_codons