#!/usr/bin/env python
"""Constrain per-base nucleotide logits so the decoded sequence encodes a target protein.

Builds an additive mask over (batch, seq, 4) logits: admissible bases get 0,
inadmissible ones get -inf, so softmax/argmax can never select a base that
breaks the protein coding or a fixed backbone position.
"""
import torch
import torch.nn.functional as F

AA_str = 'ACDEFGHIKLMNPQRSTVWY*-'.lower()

# Standard genetic code: amino acid (one-letter code) -> synonymous DNA codons.
AA_TO_CODONS = {"F": ["TTT", "TTC"],
                "L": ["TTA", "TTG", "CTT", "CTC", "CTA", "CTG"],
                "I": ["ATT", "ATC", "ATA"],
                "M": ["ATG"],
                "V": ["GTT", "GTC", "GTA", "GTG"],
                "S": ["TCT", "TCC", "TCA", "TCG", "AGT", "AGC"],
                "P": ["CCT", "CCC", "CCA", "CCG"],
                "T": ["ACT", "ACC", "ACA", "ACG"],
                "A": ["GCT", "GCC", "GCA", "GCG"],
                "Y": ["TAT", "TAC"],
                "H": ["CAT", "CAC"],
                "Q": ["CAA", "CAG"],
                "N": ["AAT", "AAC"],
                "K": ["AAA", "AAG"],
                "D": ["GAT", "GAC"],
                "E": ["GAA", "GAG"],
                "C": ["TGT", "TGC"],
                "W": ["TGG"],
                "R": ["CGT", "CGC", "CGA", "CGG", "AGA", "AGG"],
                "G": ["GGT", "GGC", "GGA", "GGG"],
                "*": ["TAA", "TAG", "TGA"]}


def reverse_dictionary(dictionary):
    """Invert a one-to-many mapping.

    Input:
        dictionary: dict of {key: [value, ...]}
    Output:
        dict of {value: key}. If a value appears under several keys,
        the last key seen wins.
    """
    reverse_dictionary = {}
    for key, values in dictionary.items():
        for value in values:
            reverse_dictionary[value] = key
    return reverse_dictionary


CODON_TO_AA = reverse_dictionary(AA_TO_CODONS)


def create_codon_mask(logits, target_protein, backbone_cds, amino_acid_to_codons,
                      base_map={'A': 0, 'T': 1, 'C': 2, 'G': 3}):
    """Build an additive codon mask for nucleotide logits.

    Args:
        logits: tensor of shape (batch, seq_length, 4); only its shape/dtype
            are used, values are not read.
        target_protein: iterable of one-letter amino-acid codes; amino acid i
            occupies logit positions 3*i .. 3*i+2.
        backbone_cds: string of fixed bases; '_' marks an unconstrained
            position. A codon is admissible only if it matches every fixed
            base in its window.
        amino_acid_to_codons: dict {amino acid: [codon, ...]} (e.g. AA_TO_CODONS).
        base_map: dict mapping base character -> vocabulary index. Read-only;
            the mutable default is never mutated.

    Returns:
        Tensor shaped like `logits`: 0 where a base belongs to some admissible
        codon, -inf elsewhere. Codons that would extend past seq_length are
        skipped (those positions stay all -inf).
    """
    batch_size, seq_length, vocab_size = logits.shape
    mask = torch.full_like(logits, float("-inf"))
    for i, amino_acid in enumerate(target_protein):
        codon_start = i * 3  # each amino acid maps to 3 bases
        codon_end = codon_start + 3
        if codon_end > seq_length:
            continue  # codon would run past the sequence; skip it
        backbone_codon = backbone_cds[codon_start:codon_end]
        # Filter once per amino acid: keep only codons that encode it AND
        # agree with every fixed backbone base in this window.
        compatible_codons = [
            codon
            for codon in amino_acid_to_codons.get(amino_acid, [])
            if all(nt == '_' or codon[j] == nt
                   for j, nt in enumerate(backbone_codon))
        ]
        for codon in compatible_codons:
            for offset, base in enumerate(codon):
                mask[:, codon_start + offset, base_map[base]] = 0
    return mask


if __name__ == '__main__':
    # Target amino-acid sequence.
    target_protein = ['M', 'A', 'L']
    # target_protein = AA_str.upper()

    # Mock logits of shape (batch_size, seq_length, vocab_size):
    # batch_size=1, seq_length=9 (3 codons), vocab_size=4 (A, T, C, G).
    logits = torch.randn(1, len(target_protein) * 3, 4)

    # Backbone constraint string; '_' positions are free to vary.
    backbone_cds = 'AT_G_C_TC'
    base_map = {0: 'A', 1: 'T', 2: 'C', 3: 'G'}
    # Invert index->base into the base->index mapping the mask builder needs.
    mask = create_codon_mask(logits, target_protein, backbone_cds,
                             AA_TO_CODONS,
                             {base: idx for idx, base in base_map.items()})

    # Apply the mask; -inf entries can never win the argmax below.
    masked_logits = mask + logits

    # Decode the most likely base at each position.
    predictions = torch.argmax(masked_logits, dim=-1)
    predicted_sequence = ''.join(base_map[p.item()] for p in predictions[0])
    print("Predicted mRNA sequence:", predicted_sequence)