anuj2054 commited on
Commit
00c75b5
·
verified ·
1 Parent(s): 0ffaa8a

Update synonymous_logit_processor.py

Browse files
Files changed (1) hide show
  1. synonymous_logit_processor.py +51 -15
synonymous_logit_processor.py CHANGED
@@ -1,19 +1,24 @@
 
 
 
 
 
1
  aa_to_codon_human = {
2
- 'A': ['GCT', 'GCC', 'GCA', 'GCG'], 'C': ['TGT', 'TGC'], 'D': ['GAT', 'GAC'],
3
- 'E': ['GAA', 'GAG'], 'F': ['TTT', 'TTC'], 'G': ['GGT', 'GGC', 'GGA', 'GGG'],
4
- 'H': ['CAT', 'CAC'], 'I': ['ATT', 'ATC', 'ATA'], 'K': ['AAA', 'AAG'],
5
- 'L': ['TTA', 'TTG', 'CTT', 'CTC', 'CTA', 'CTG'], 'M': ['ATG'],
6
- 'N': ['AAT', 'AAC'], 'P': ['CCT', 'CCC', 'CCA', 'CCG'], 'Q': ['CAA', 'CAG'],
7
- 'R': ['CGT', 'CGC', 'CGA', 'CGG', 'AGA', 'AGG'], 'S': ['TCT', 'TCC', 'TCA', 'TCG', 'AGT', 'AGC'],
8
- 'T': ['ACT', 'ACC', 'ACA', 'ACG'], 'V': ['GTT', 'GTC', 'GTA', 'GTG'],
9
- 'W': ['TGG'], 'Y': ['TAT', 'TAC'], '*': ['TAA', 'TAG', 'TGA']
10
  }
11
 
12
  class SynonymMaskingLogitsProcessor(LogitsProcessor):
13
- def __init__(self, current_aa, tokenizer, aa_to_codon):
14
  self.current_aa = current_aa
15
  self.tokenizer = tokenizer
16
- self.aa_to_codon = aa_to_codon
17
 
18
  def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
19
  synonymous_codons = self.aa_to_codon.get(self.current_aa, [])
@@ -22,15 +27,46 @@ class SynonymMaskingLogitsProcessor(LogitsProcessor):
22
  mask[:, synonym_token_ids] = 0
23
  return scores + mask
24
 
25
- def generate_candidate_codons_with_generate(initial_codons, temperature=1.0, top_k=None, top_p=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  optimized_codons = []
27
  current_sequence_tokens = [tokenizer.bos_token_id]
28
 
29
  for codon in initial_codons:
30
  aa = str(Seq(codon).translate())
31
- logits_processor = [SynonymMaskingLogitsProcessor(aa, tokenizer, aa_to_codon_human)]
32
 
33
- input_ids = torch.tensor([current_sequence_tokens])#.to(device)
34
 
35
  output = model.generate(
36
  input_ids,
@@ -39,9 +75,9 @@ def generate_candidate_codons_with_generate(initial_codons, temperature=1.0, top
39
  top_k=top_k,
40
  top_p=top_p,
41
  num_return_sequences=1,
42
- pad_token_id=tokenizer.eos_token_id,
43
  logits_processor=logits_processor,
44
- do_sample=True # Ensure sampling is used for temperature, top_k, top_p
45
  )
46
 
47
  next_token_id = output[0][-1].item()
 
1
+ import torch
2
+ from transformers import LogitsProcessor
3
+ from Bio.Seq import Seq
4
+
5
+ # Complete amino acid to codon mapping (human genetic code)
6
  aa_to_codon_human = {
7
+ 'F': ['TTT', 'TTC'], 'L': ['TTA', 'TTG', 'CTT', 'CTC', 'CTA', 'CTG'],
8
+ 'S': ['TCT', 'TCC', 'TCA', 'TCG', 'AGT', 'AGC'], 'Y': ['TAT', 'TAC'],
9
+ 'C': ['TGT', 'TGC'], 'W': ['TGG'], 'P': ['CCT', 'CCC', 'CCA', 'CCG'],
10
+ 'H': ['CAT', 'CAC'], 'Q': ['CAA', 'CAG'], 'R': ['CGT', 'CGC', 'CGA', 'CGG', 'AGA', 'AGG'],
11
+ 'I': ['ATT', 'ATC', 'ATA'], 'M': ['ATG'], 'T': ['ACT', 'ACC', 'ACA', 'ACG'],
12
+ 'N': ['AAT', 'AAC'], 'K': ['AAA', 'AAG'], 'V': ['GTT', 'GTC', 'GTA', 'GTG'],
13
+ 'A': ['GCT', 'GCC', 'GCA', 'GCG'], 'D': ['GAT', 'GAC'], 'E': ['GAA', 'GAG'],
14
+ 'G': ['GGT', 'GGC', 'GGA', 'GGG'], '*': ['TAA', 'TAG', 'TGA']
15
  }
16
 
17
  class SynonymMaskingLogitsProcessor(LogitsProcessor):
18
+ def __init__(self, current_aa, tokenizer, aa_to_codon=None):
19
  self.current_aa = current_aa
20
  self.tokenizer = tokenizer
21
+ self.aa_to_codon = aa_to_codon if aa_to_codon is not None else aa_to_codon_human
22
 
23
  def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
24
  synonymous_codons = self.aa_to_codon.get(self.current_aa, [])
 
27
  mask[:, synonym_token_ids] = 0
28
  return scores + mask
29
 
30
+ def generate_candidate_codons_with_generate(initial_codons, temperature=1.0, top_k=None, top_p=None, aa_to_codon=None, model=None, tokenizer=None):
31
+ """
32
+ Generate synonymous codon alternatives for a given set of codons.
33
+
34
+ Args:
35
+ initial_codons: List of codons to optimize
36
+ temperature: Sampling temperature
37
+ top_k: Top-k sampling parameter
38
+ top_p: Top-p (nucleus) sampling parameter
39
+ aa_to_codon: Amino acid to codon mapping (defaults to human genetic code)
40
+ model: The CodonGPT model (if None, uses global 'model' variable)
41
+ tokenizer: The codon tokenizer (if None, uses global 'tokenizer' variable)
42
+
43
+ Returns:
44
+ List of optimized codons
45
+ """
46
+ # Use global variables if not provided as parameters
47
+ if model is None:
48
+ import builtins
49
+ model = getattr(builtins, 'model', globals().get('model'))
50
+ if model is None:
51
+ raise ValueError("Model not provided and no global 'model' variable found")
52
+
53
+ if tokenizer is None:
54
+ import builtins
55
+ tokenizer = getattr(builtins, 'tokenizer', globals().get('tokenizer'))
56
+ if tokenizer is None:
57
+ raise ValueError("Tokenizer not provided and no global 'tokenizer' variable found")
58
+
59
+ if aa_to_codon is None:
60
+ aa_to_codon = aa_to_codon_human
61
+
62
  optimized_codons = []
63
  current_sequence_tokens = [tokenizer.bos_token_id]
64
 
65
  for codon in initial_codons:
66
  aa = str(Seq(codon).translate())
67
+ logits_processor = [SynonymMaskingLogitsProcessor(aa, tokenizer, aa_to_codon)]
68
 
69
+ input_ids = torch.tensor([current_sequence_tokens])
70
 
71
  output = model.generate(
72
  input_ids,
 
75
  top_k=top_k,
76
  top_p=top_p,
77
  num_return_sequences=1,
78
+ pad_token_id=tokenizer.pad_token_id,
79
  logits_processor=logits_processor,
80
+ do_sample=True
81
  )
82
 
83
  next_token_id = output[0][-1].item()