anuj2054 committed on
Commit
cbb23a7
·
verified ·
1 Parent(s): e6d6d14

Create synonymous_logit_processor.py

Browse files
Files changed (1) hide show
  1. synonymous_logit_processor.py +42 -0
synonymous_logit_processor.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class SynonymMaskingLogitsProcessor(LogitsProcessor):
    """Restrict next-token scores to the codons synonymous with one amino acid.

    Every vocabulary position except the token ids of the codons listed for
    ``current_aa`` in ``aa_to_codon`` receives ``-inf`` added to its score,
    so only those codons can be selected.

    Args:
        current_aa: Key looked up in ``aa_to_codon``; the amino acid whose
            codons stay selectable.
        tokenizer: Object exposing ``convert_tokens_to_ids`` (presumably a
            Hugging Face tokenizer whose vocabulary contains codon tokens --
            confirm against the caller).
        aa_to_codon: Mapping from an amino-acid key to its codon token(s).
    """

    def __init__(self, current_aa, tokenizer, aa_to_codon):
        super().__init__()
        self.current_aa = current_aa
        self.tokenizer = tokenizer
        self.aa_to_codon = aa_to_codon

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        """Return ``scores`` with every non-synonymous token set to ``-inf``.

        Args:
            input_ids: Tokens generated so far; not used by this processor.
            scores: Next-token scores, indexed as ``[row, token_id]``.

        Returns:
            A new tensor shaped like ``scores`` in which only the columns of
            the synonymous codon tokens keep their original values.

        Raises:
            ValueError: If no codon of ``current_aa`` resolves to a real
                vocabulary token. Masking everything would otherwise leave
                all scores at ``-inf`` and produce NaN probabilities (or a
                meaningless argmax) further down the generation pipeline.
        """
        synonymous_codons = self.aa_to_codon.get(self.current_aa, [])
        token_ids = self.tokenizer.convert_tokens_to_ids(synonymous_codons)
        # A single string token yields a scalar id rather than a list.
        if not isinstance(token_ids, (list, tuple)):
            token_ids = [token_ids]

        # Codons absent from the vocabulary come back as the unknown-token id
        # (or None when the tokenizer has no unk token); whitelisting those
        # would let the model emit a non-codon token.
        unk_id = getattr(self.tokenizer, "unk_token_id", None)
        allowed_ids = sorted(
            {tid for tid in token_ids if tid is not None and tid != unk_id}
        )
        if not allowed_ids:
            raise ValueError(
                f"No vocabulary tokens found for codons of amino acid "
                f"{self.current_aa!r}: {synonymous_codons!r}"
            )

        mask = torch.full_like(scores, float("-inf"))
        mask[:, allowed_ids] = 0
        return scores + mask
13
+
14
def generate_candidate_codons_with_generate(initial_codons, temperature=1.0, top_k=None, top_p=None):
    """Sample one synonymous replacement codon per input codon.

    For each codon in ``initial_codons`` the codon is translated to its amino
    acid, and the module-level ``model`` samples exactly one further token
    with the vocabulary restricted by ``SynonymMaskingLogitsProcessor`` to
    the entries of ``aa_to_codon_human`` for that amino acid. The sampled
    token is appended to the running context, so each choice is conditioned
    on all previously chosen codons.

    Relies on the module-level names ``tokenizer``, ``model``, ``Seq`` and
    ``aa_to_codon_human``.

    Args:
        initial_codons: Iterable of codon strings accepted by
            ``Seq(codon).translate()``.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k cutoff forwarded to ``model.generate``.
        top_p: Nucleus-sampling cutoff forwarded to ``model.generate``.

    Returns:
        List of upper-cased decoded codon strings, one per input codon, in
        input order.
    """
    optimized_codons = []
    current_sequence_tokens = [tokenizer.bos_token_id]
    # Build inputs on the model's device so generation also works when the
    # model lives on an accelerator; None falls back to torch's default.
    device = getattr(model, "device", None)

    for codon in initial_codons:
        aa = str(Seq(codon).translate())
        logits_processor = [SynonymMaskingLogitsProcessor(aa, tokenizer, aa_to_codon_human)]

        input_ids = torch.tensor([current_sequence_tokens], device=device)

        output = model.generate(
            input_ids,
            # The single-row prompt has no padding; stating that explicitly
            # avoids the mask being guessed while pad_token_id == eos_token_id.
            attention_mask=torch.ones_like(input_ids),
            max_length=len(current_sequence_tokens) + 1,  # exactly one new token
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            logits_processor=logits_processor,
            do_sample=True,  # Ensure sampling is used for temperature, top_k, top_p
        )

        next_token_id = output[0][-1].item()
        predicted_codon = tokenizer.decode([next_token_id])

        optimized_codons.append(predicted_codon.upper())
        current_sequence_tokens.append(next_token_id)

    return optimized_codons