| | |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| |
|
# Lower-cased one-letter residue alphabet: the 20 standard amino acids
# followed by the two extra symbols '*' and '-'.
AA_str = 'ACDEFGHIKLMNPQRSTVWY*-'.lower()

# Standard genetic code, keyed by one-letter amino-acid code. Codons are
# spelled with T; the '*' entry lists the three stop codons.
AA_TO_CODONS = {
    "F": ["TTT", "TTC"],
    "L": ["TTA", "TTG", "CTT", "CTC", "CTA", "CTG"],
    "I": ["ATT", "ATC", "ATA"],
    "M": ["ATG"],
    "V": ["GTT", "GTC", "GTA", "GTG"],
    "S": ["TCT", "TCC", "TCA", "TCG", "AGT", "AGC"],
    "P": ["CCT", "CCC", "CCA", "CCG"],
    "T": ["ACT", "ACC", "ACA", "ACG"],
    "A": ["GCT", "GCC", "GCA", "GCG"],
    "Y": ["TAT", "TAC"],
    "H": ["CAT", "CAC"],
    "Q": ["CAA", "CAG"],
    "N": ["AAT", "AAC"],
    "K": ["AAA", "AAG"],
    "D": ["GAT", "GAC"],
    "E": ["GAA", "GAG"],
    "C": ["TGT", "TGC"],
    "W": ["TGG"],
    "R": ["CGT", "CGC", "CGA", "CGG", "AGA", "AGG"],
    "G": ["GGT", "GGC", "GGA", "GGG"],
    "*": ["TAA", "TAG", "TGA"],
}
| |
|
| |
|
def reverse_dictionary(dictionary):
    """Return dict of {value: key, ...}

    Input:
        dictionary: dict of {key: [value, ...], ...}. Each value collection
            is iterated, so a plain string is walked character by character.
    Output:
        dict of {value: key, ...}. If the same value is listed under more
        than one key, the key that comes last in iteration order wins.

    """
    # Flatten the one-to-many mapping; later keys overwrite earlier ones.
    return {
        value: key
        for key, values in dictionary.items()
        for value in values
    }
| |
|
# Flat codon -> one-letter amino-acid lookup, derived from AA_TO_CODONS.
CODON_TO_AA = reverse_dictionary(AA_TO_CODONS)
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | |
def create_codon_mask(logits, target_protein, backbone_cds, amino_acid_to_codons,
                      base_map=None):
    """Build an additive mask that limits each nucleotide position to the
    bases of codons that encode the target protein and fit the backbone.

    Input:
        logits: 3-D tensor indexed as (batch, sequence position, base index).
        target_protein: iterable of amino-acid keys; residue i owns sequence
            positions 3*i .. 3*i+2.
        backbone_cds: sequence aligned with the nucleotide positions. '_'
            leaves a position unconstrained; any other character must equal
            the codon's base at that position.
        amino_acid_to_codons: dict of {amino_acid: [codon, ...], ...}.
            Residues missing from the dict contribute no allowed bases.
        base_map: dict of {base: index on the last axis of logits, ...}.
            Defaults to {'A': 0, 'T': 1, 'C': 2, 'G': 3}.
    Output:
        Tensor shaped like logits (same dtype and device) holding 0 where a
        base is allowed and -inf everywhere else. Positions not covered by a
        residue, or for which no codon fits the backbone, stay -inf for
        every base.

    """
    if base_map is None:
        base_map = {'A': 0, 'T': 1, 'C': 2, 'G': 3}

    # Unpacking into three names keeps the requirement that logits is 3-D.
    _, seq_length, _ = logits.shape
    mask = torch.full_like(logits, float("-inf"))

    for i, amino_acid in enumerate(target_protein):
        codon_start = i * 3
        codon_end = codon_start + 3

        if codon_end > seq_length:
            # Every later residue starts further right, so none can fit.
            break

        backbone_codon = backbone_cds[codon_start:codon_end]

        for codon in amino_acid_to_codons.get(amino_acid, []):
            # Decide once per codon whether it agrees with the fixed
            # backbone bases, instead of once per nucleotide position.
            fits_backbone = all(
                nt == '_' or nt == base
                for base, nt in zip(codon, backbone_codon)
            )
            if not fits_backbone:
                continue

            for offset in range(3):
                mask[:, codon_start + offset, base_map[codon[offset]]] = 0

    return mask
| |
|
| |
|
if __name__ == '__main__':
    # Small demo: constrain random logits so the decoded sequence encodes
    # the target protein while keeping the bases fixed by the backbone.
    target_protein = ['M', 'A', 'L']

    # One row of four base logits per nucleotide, three nucleotides per residue.
    logits = torch.randn(1, len(target_protein) * 3, 4)

    # '_' marks positions the model may choose freely.
    backbone_cds = 'AT_G_C_TC'

    # Decoding needs index -> base; create_codon_mask needs base -> index.
    index_to_base = {0: 'A', 1: 'T', 2: 'C', 3: 'G'}
    base_to_index = {base: index for index, base in index_to_base.items()}

    mask = create_codon_mask(logits, target_protein, backbone_cds,
                             AA_TO_CODONS, base_to_index)

    # Adding the mask pushes every disallowed base to -inf before the argmax.
    masked_logits = mask + logits
    predictions = torch.argmax(masked_logits, dim=-1)

    predicted_sequence = ''.join(
        index_to_base[index] for index in predictions[0].tolist()
    )

    print("Predicted mRNA sequence:", predicted_sequence)
| |
|
| |
|