# Source: maotao/model/codon_tables.py (Hugging Face upload "AA2CDS" by julse, revision 4707555)
#!/usr/bin/env python
import torch
import torch.nn.functional as F
# Lower-cased alphabet: the 20 standard amino acids plus stop ('*') and
# gap ('-') symbols. Not referenced by the executable code in this file
# (only by a commented-out line in __main__) — presumably a model
# vocabulary used elsewhere; TODO confirm against callers.
AA_str = 'ACDEFGHIKLMNPQRSTVWY*-'.lower()
# Standard genetic code: one-letter amino-acid code -> list of synonymous
# DNA codons (T, not U). '*' maps to the three stop codons.
AA_TO_CODONS = {"F": ["TTT","TTC"],
"L": ["TTA", "TTG", "CTT", "CTC", "CTA", "CTG"],
"I": ["ATT", "ATC", "ATA"],
"M": ["ATG"],
"V": ["GTT", "GTC", "GTA", "GTG"],
"S": ["TCT", "TCC", "TCA", "TCG", "AGT", "AGC"],
"P": ["CCT", "CCC", "CCA", "CCG"],
"T": ["ACT", "ACC", "ACA", "ACG"],
"A": ["GCT", "GCC", "GCA", "GCG"],
"Y": ["TAT", "TAC"],
"H": ["CAT", "CAC"],
"Q": ["CAA", "CAG"],
"N": ["AAT", "AAC"],
"K": ["AAA", "AAG"],
"D": ["GAT", "GAC"],
"E": ["GAA", "GAG"],
"C": ["TGT", "TGC"],
"W": ["TGG"],
"R": ["CGT", "CGC", "CGA", "CGG", "AGA", "AGG"],
"G": ["GGT", "GGC", "GGA", "GGG"],
"*": ["TAA", "TAG", "TGA"]}
def reverse_dictionary(dictionary):
    """Invert a one-to-many mapping.

    Input:
        dictionary: dict of {key: [value, ->], ->}
    Output:
        dict of {value: key, ->}; if a value appears under several keys,
        the key encountered last wins.
    """
    return {value: key
            for key, values in dictionary.items()
            for value in values}
# Reverse lookup table: codon string -> one-letter amino-acid code.
CODON_TO_AA = reverse_dictionary(AA_TO_CODONS)
# Convert an amino-acid sequence into a per-base codon mask for the logits.
def create_codon_mask(logits, target_protein, backbone_cds, amino_acid_to_codons,
                      base_map={'A': 0, 'T': 1, 'C': 2, 'G': 3}):
    """Build an additive logits mask restricting each nucleotide position to
    bases from codons that (a) encode the target amino acid and (b) agree
    with every fixed base of the backbone CDS.

    Args:
        logits: tensor of shape (batch, seq_length, vocab); only its shape
            and dtype are used to size the mask.
        target_protein: iterable of one-letter amino-acid codes.
        backbone_cds: string over 'A'/'T'/'C'/'G' plus '_' (wildcard = any
            base allowed at that position).
        amino_acid_to_codons: dict mapping amino acid -> list of codons.
        base_map: base letter -> vocab index. NOTE: mutable default, but it
            is only read, never mutated, so sharing it across calls is safe.

    Returns:
        Tensor shaped like `logits` with 0.0 at allowed (position, base)
        entries and -inf elsewhere; add it to `logits` before argmax/softmax.
        If no codon of an amino acid is backbone-compatible, its three
        positions stay all -inf.
    """
    _, seq_length, _ = logits.shape
    mask = torch.full_like(logits, float("-inf"))
    for i, amino_acid in enumerate(target_protein):
        codon_start = i * 3  # each amino acid occupies 3 bases
        codon_end = codon_start + 3
        if codon_end > seq_length:
            continue  # codon would run past the sequence; skip it
        backbone = backbone_cds[codon_start:codon_end]
        # Filter codons against the fixed backbone bases ONCE per amino acid
        # (the original re-checked every codon at each of the 3 positions).
        compatible = [
            codon for codon in amino_acid_to_codons.get(amino_acid, [])
            if all(nt == '_' or codon[j] == nt for j, nt in enumerate(backbone))
        ]
        for offset in range(3):  # position of the base within the codon
            for codon in compatible:
                mask[:, codon_start + offset, base_map[codon[offset]]] = 0
    return mask
if __name__ == '__main__':
    # Demo: constrain random logits to codons encoding M-A-L and decode.
    # Target amino-acid sequence
    target_protein = ['M', 'A', 'L']
    # Dummy logits of shape (batch_size, seq_length, vocab_size):
    # batch_size=1, seq_length=9 (3 codons), vocab_size=4 (A, T, C, G)
    logits = torch.randn(1, len(target_protein) * 3, 4)
    # Backbone constraint: fixed bases; '_' marks free positions
    backbone_cds = 'AT_G_C_TC'
    base_map = {0: 'A', 1: 'T', 2: 'C', 3: 'G'}
    # Build the mask; reverse_dictionary inverts index->base into base->index
    # (iterating a 1-char string yields that char, so the inversion works)
    mask = create_codon_mask(logits, target_protein, backbone_cds, AA_TO_CODONS,
                             reverse_dictionary(base_map))
    # Apply the mask: disallowed bases get -inf logits
    masked_logits = mask + logits
    # Decode by taking the highest-scoring base at each position
    predictions = torch.argmax(masked_logits, dim=-1)
    # Map vocabulary indices back to nucleotide letters
    predicted_sequence = ''.join(base_map[p.item()] for p in predictions[0])
    print("Predicted mRNA sequence:", predicted_sequence)