#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Title     : model_exp.py
project   : minimind_RiboUTR
Created by: julse
Created on: 2025/3/9 00:45
des       : MiniMindLM variants with extra output heads -- expression (``exp``),
            translation efficiency / CAI (``te``) and a zero-shot score -- the
            MaoTao codon/amino-acid fusion model, and a per-group training loss.
"""
import os
from typing import Any, List, Optional, Tuple

import torch
from torch import nn
from transformers.modeling_outputs import CausalLMOutputWithPast

from .model_downstream import ConvNetCodon
from .model_ribo import MiniMindLM, MOEFeedForward, NonLinearHead
from .LMConfig import LMConfig, LMaoTaoConfig

# Nucleotide token id treated as padding by every model in this file.
_PAD_ID = 1

# Extra (non-CausalLMOutputWithPast) fields carried by MyModelOutput.
_EXTRA_OUTPUT_FIELDS = (
    'aux_loss', 'exp', 'te', 'zero_shot',
    'embeddings', 'input_embedding', 'input_twod_tokens',
)


# ---------------------------------------------------------------------------
# Private helpers shared by the models below
# ---------------------------------------------------------------------------

def _padding_mask(input_ids: torch.Tensor) -> torch.Tensor:
    """Return ``(input_ids == 1)`` with a trailing singleton axis (True = padding)."""
    return (input_ids == _PAD_ID).unsqueeze(-1)


def _embed_nucleotides(model: nn.Module,
                       input_ids: torch.Tensor,
                       twod_tokens: Optional[torch.Tensor],
                       input_embedding: Optional[torch.Tensor],
                       input_twod_tokens: Optional[torch.Tensor]):
    """Prepare the nucleotide stream shared by the generation-capable models.

    Returns ``(input_ids, twod_tokens, embedding, seq_mask)`` where
    ``input_ids`` is cast to long, ``twod_tokens`` to float32 (or replaced by
    ``input_twod_tokens + 1e-7`` when given), ``embedding`` is the float32 token
    embedding (or ``input_embedding`` verbatim when given, e.g. fed back from a
    previous output during generation) and ``seq_mask`` marks padding.
    """
    input_ids = input_ids.to(torch.long)
    if twod_tokens is not None:
        twod_tokens = twod_tokens.to(torch.float32)
    if input_twod_tokens is not None:
        twod_tokens = input_twod_tokens + 1e-7
    if input_embedding is not None:
        embedding = input_embedding
    else:
        embedding = model.tok_embeddings(input_ids).to(torch.float32)
    return input_ids, twod_tokens, embedding, _padding_mask(input_ids)


def _run_layers(layers, first_index: int, h: torch.Tensor, pos_cis: torch.Tensor,
                seq_mask: torch.Tensor, twod_tokens: Optional[torch.Tensor],
                past_key_values, use_cache: bool):
    """Run ``layers`` in order, re-zeroing padding after each one.

    ``first_index`` is the position of ``layers[0]`` in the full layer stack;
    it selects the matching entry of ``past_key_values``.

    Returns:
        ``(h, new_past_kvs)`` -- the final hidden states and one cache entry
        per layer run.
    """
    new_past_kvs = []
    for idx, layer in enumerate(layers, start=first_index):
        h, past_kv = layer(
            h,
            pos_cis,
            twod_tokens=twod_tokens,
            past_key_value=past_key_values[idx],
            use_cache=use_cache,
        )
        h = h.masked_fill(seq_mask, 0)
        new_past_kvs.append(past_kv)
    return h, new_past_kvs


def _add_experiment_signals(model: nn.Module, h: torch.Tensor,
                            src_exp_data: Optional[torch.Tensor],
                            env_ids: Optional[torch.Tensor],
                            seq_mask: torch.Tensor) -> torch.Tensor:
    """Add the optional experiment-data and environment embeddings to ``h``.

    Out-of-place on purpose: ``h`` may already be saved by autograd.
    """
    if src_exp_data is not None:
        exp_signal = model.exp_dropout(model.exp_adapter(src_exp_data.to(torch.float32)))
        h = (h + exp_signal.masked_fill(seq_mask, 0)).masked_fill(seq_mask, 0)
    if env_ids is not None:
        # One environment embedding per sample, broadcast over the sequence.
        env_signal = model.env_adapter(env_ids.to(torch.long).unsqueeze(-1))
        h = (h + env_signal).masked_fill(seq_mask, 0)
    return h


def _moe_aux_loss(layers) -> Any:
    """Sum ``feed_forward.aux_loss`` over MoE layers (the int 0 if there are none)."""
    return sum(layer.feed_forward.aux_loss for layer in layers
               if isinstance(layer.feed_forward, MOEFeedForward))


def _masked_zero_shot(h: torch.Tensor, seq_mask: torch.Tensor) -> torch.Tensor:
    """Per-sample zero-shot score computed from the final hidden states.

    ``h`` is summed over axes 1 and 2 while ignoring padding, then divided by
    the number of non-padding positions (axes 1 and 2 of ``seq_mask``). That
    is, it is the average over positions of the per-position feature sum.
    Samples that consist only of padding get 0 instead of NaN.

    Args:
        h: hidden states, presumably ``(batch, seq, dim)`` -- TODO confirm.
        seq_mask: boolean mask broadcastable to ``h``; True marks padding.

    Returns:
        Tensor of shape ``(batch, 1)``.
    """
    keep = ~seq_mask
    sum_h = torch.sum(h * keep, dim=(1, 2))
    count_h = torch.sum(keep, dim=(1, 2))
    mean_h = sum_h / count_h
    mean_h[count_h == 0] = 0  # 0/0 -> NaN; report 0 for all-padding samples
    return mean_h.reshape(-1, 1)


def _exp_te_heads(model: nn.Module, h: torch.Tensor, seq_mask: torch.Tensor):
    """Run the ``exp``/``te`` heads and, when ``h`` has no grad, the zero-shot score.

    ``exp`` is None when the model has no ``exp_head`` attribute (it may be
    deleted for downstream tasks).
    """
    exp = model.exp_head(h) if hasattr(model, 'exp_head') else None
    if exp is not None:
        exp = exp.masked_fill(seq_mask, 0)
    te = model.te_head(h)
    zero_shot = None if h.requires_grad else _masked_zero_shot(h, seq_mask)
    return exp, te, zero_shot


def _publish_output(model: nn.Module, **fields) -> 'MyModelOutput':
    """Pack ``fields`` (in order) into a fresh MyModelOutput and return it.

    A new object is created per call, so the result of one forward pass is not
    overwritten by the next. ``model.OUT`` is still updated to point at the
    latest result for code that reads it.
    """
    out = MyModelOutput()
    for key, value in fields.items():
        out[key] = value  # same ModelOutput.__setitem__ path the models used before
    model.OUT = out
    return out


# ---------------------------------------------------------------------------
# Small building blocks
# ---------------------------------------------------------------------------

class VocabEmbedding(nn.Module):
    """Token embedding followed by dropout (default rate 0.1)."""

    def __init__(self, vocab_size, embedding_dim, padding_idx=None, dropout: float = 0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.embedding(x))


class ExpAdapter(nn.Module):
    """Linear projection of per-position features that zeroes padded positions.

    A position is treated as padding when its FIRST input feature equals
    ``padding_idx``.
    """

    def __init__(self, input_dim, output_dim, padding_idx=0):
        super().__init__()
        self.padding_idx = padding_idx
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        # assumes x is (batch, seq, input_dim) -- TODO confirm
        is_pad = x[..., :1] == self.padding_idx  # broadcasts over output features
        return self.linear(x).masked_fill(is_pad, 0)


class UncertaintyWeighting(nn.Module):
    """Combine several losses as ``sum_i exp(-s_i) * loss_i + s_i`` with learned ``s_i``."""

    def __init__(self, num_losses):
        super().__init__()
        self.log_vars = nn.Parameter(torch.zeros(num_losses))

    def forward(self, losses):
        total_loss = 0
        for i, loss in enumerate(losses):
            precision = torch.exp(-self.log_vars[i])
            total_loss += precision * loss + self.log_vars[i]
        return total_loss


class MyModelOutput(CausalLMOutputWithPast):
    """``CausalLMOutputWithPast`` extended with the extra outputs of this file's models.

    Extra fields:
        aux_loss: MoE auxiliary loss.
        exp: expression-head output.
        te: translation-efficiency / CAI head output.
        zero_shot: per-sample zero-shot score (only computed without grad).
        embeddings: final hidden states.
        input_embedding, input_twod_tokens: model inputs, returned so that
            generation can feed them back in.

    Tensors assigned as attributes are moved to the device of ``logits`` once
    ``logits`` is set.
    """

    def __init__(self,
                 logits: Optional[torch.FloatTensor] = None,
                 aux_loss: Optional[torch.FloatTensor] = None,
                 past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
                 exp: Optional[torch.FloatTensor] = None,
                 te: Optional[torch.FloatTensor] = None,
                 zero_shot: Optional[torch.FloatTensor] = None,
                 embeddings: Optional[torch.FloatTensor] = None,
                 input_embedding: Optional[torch.FloatTensor] = None,
                 input_twod_tokens: Optional[torch.FloatTensor] = None,
                 hidden_states: Optional[torch.FloatTensor] = None,
                 attentions: Optional[Tuple[torch.FloatTensor]] = None,
                 **kwargs):
        # Initialise the parent dataclass fields.
        super().__init__(
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=hidden_states,
            attentions=attentions,
            **kwargs
        )
        # Custom fields.
        self.aux_loss = aux_loss
        self.exp = exp
        self.te = te
        self.zero_shot = zero_shot
        self.embeddings = embeddings
        self.input_embedding = input_embedding
        self.input_twod_tokens = input_twod_tokens
        # Put every tensor on the same device as logits.
        self._sync_device()

    def _sync_device(self) -> None:
        """Move every extra tensor field to the device of ``logits`` (no-op without logits)."""
        logits = getattr(self, 'logits', None)
        if logits is None:
            return  # no reference device
        for field in _EXTRA_OUTPUT_FIELDS:
            tensor = getattr(self, field, None)
            if isinstance(tensor, torch.Tensor) and tensor.device != logits.device:
                setattr(self, field, tensor.to(logits.device))

    def __setattr__(self, name, value):
        """Set the attribute, then move tensors to the ``logits`` device if needed."""
        super().__setattr__(name, value)
        logits = getattr(self, 'logits', None)
        if isinstance(value, torch.Tensor) and logits is not None and value.device != logits.device:
            super().__setattr__(name, value.to(logits.device))


# ---------------------------------------------------------------------------
# Expression models
# ---------------------------------------------------------------------------

class MiniMindLMForExp(MiniMindLM):
    """MiniMindLM with expression (``exp``) and translation-efficiency (``te``) heads.

    Experiment data and environment ids are added to the token embeddings
    BEFORE the first transformer layer.
    """

    def __init__(self, params: LMConfig = None, env_counts=2471):
        super().__init__(params)
        # Inputs: per-position experiment data (3 features) and an
        # experiment-indicator embedding (index 1 is padding).
        self.exp_adapter = ExpAdapter(3, params.dim, padding_idx=0)
        self.exp_dropout = nn.Dropout(params.dropout)
        self.env_adapter = nn.Embedding(env_counts, params.dim, padding_idx=1)
        # Regression heads on top of the final hidden states.
        self.exp_head = NonLinearHead(params.dim, 3, 'relu', hidden=params.dim // 2)
        self.te_head = ConvNetCodon(params.dim, params.dim // 2, 1)

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                src_exp_data: Optional[torch.Tensor] = None,
                env_ids: Optional[torch.Tensor] = None,
                twod_tokens: Optional[torch.Tensor] = None,
                src_feature: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                **args):
        """Run the LM and the exp/te heads.

        ``src_feature`` is accepted for interface compatibility but unused.
        ``args['start_pos']`` (default 0) offsets the rotary positions.

        Returns:
            MyModelOutput with logits, aux_loss, past_key_values, exp, te,
            zero_shot (None while gradients are tracked) and embeddings.
        """
        past_key_values = past_key_values or [None] * len(self.layers)
        input_ids = input_ids.to(torch.long)
        start_pos = args.get('start_pos', 0)
        if twod_tokens is not None:
            twod_tokens = twod_tokens.to(torch.float32)
        seq_mask = _padding_mask(input_ids)

        # All masking below is out-of-place: an in-place fill on a tensor that
        # autograd saved (e.g. the input of self.output) breaks backward().
        h = self.dropout(self.tok_embeddings(input_ids)).masked_fill(seq_mask, 0)
        h = _add_experiment_signals(self, h, src_exp_data, env_ids, seq_mask)

        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        h, past_kvs = _run_layers(self.layers, 0, h, pos_cis, seq_mask,
                                  twod_tokens, past_key_values, use_cache)

        h = self.norm(h).masked_fill(seq_mask, 0)
        logits = self.output(h)
        exp, te, zero_shot = _exp_te_heads(self, h, seq_mask)

        return _publish_output(
            self,
            logits=logits,
            aux_loss=_moe_aux_loss(self.layers),
            past_key_values=past_kvs,
            exp=exp,
            te=te,
            zero_shot=zero_shot,  # zero-shot result
            embeddings=h,
        )


class MiniMindLMForExp44(MiniMindLM):
    """Like MiniMindLMForExp, but experiment signals enter between layers 4 and 5.

    The first four layers see only the sequence; experiment data and
    environment embeddings are added before the remaining layers.
    """

    def __init__(self, params: LMConfig = None, env_counts=2471):
        super().__init__(params)
        self.exp_adapter = ExpAdapter(3, params.dim, padding_idx=0)
        self.exp_dropout = nn.Dropout(params.dropout)
        # Experiment-indicator embedding (index 1 is padding).
        self.env_adapter = nn.Embedding(env_counts, params.dim, padding_idx=1)
        self.exp_head = NonLinearHead(params.dim, 3, 'relu', hidden=params.dim // 2)
        self.te_head = ConvNetCodon(params.dim, params.dim // 2, 1)

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                src_exp_data: Optional[torch.Tensor] = None,
                env_ids: Optional[torch.Tensor] = None,
                twod_tokens: Optional[torch.Tensor] = None,
                src_feature: Optional[torch.Tensor] = None,
                input_embedding=None,
                input_twod_tokens=None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                **args):
        """Run the 4 + rest layer stack with mid-stack experiment injection.

        ``input_embedding`` / ``input_twod_tokens`` (e.g. taken from a previous
        output during generation) replace the computed embedding / 2-D tokens.
        ``src_feature`` is unused.

        Returns:
            MyModelOutput with logits, aux_loss, past_key_values, exp, te,
            zero_shot, embeddings, input_embedding and input_twod_tokens.
        """
        past_key_values = past_key_values or [None] * len(self.layers)
        start_pos = args.get('start_pos', 0)
        input_ids, twod_tokens, embedding, seq_mask = _embed_nucleotides(
            self, input_ids, twod_tokens, input_embedding, input_twod_tokens)

        h = self.dropout(embedding).masked_fill(seq_mask, 0)
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]

        # First four layers: sequence only.
        h, past_kvs = _run_layers(self.layers[:4], 0, h, pos_cis, seq_mask,
                                  twod_tokens, past_key_values, use_cache)
        h = _add_experiment_signals(self, h, src_exp_data, env_ids, seq_mask)
        # Remaining layers (cache indices continue at 4).
        h, later_kvs = _run_layers(self.layers[4:], 4, h, pos_cis, seq_mask,
                                   twod_tokens, past_key_values, use_cache)
        past_kvs.extend(later_kvs)

        h = self.norm(h).masked_fill(seq_mask, 0)
        logits = self.output(h)
        exp, te, zero_shot = _exp_te_heads(self, h, seq_mask)

        return _publish_output(
            self,
            logits=logits,
            aux_loss=_moe_aux_loss(self.layers),
            past_key_values=past_kvs,
            exp=exp,
            te=te,
            zero_shot=zero_shot,  # zero-shot result
            embeddings=h,
            input_embedding=embedding,  # for generation
            input_twod_tokens=twod_tokens,  # for generation
        )


class MiniMindLMForExp44_4region(MiniMindLM):
    """Region-aware variant of MiniMindLMForExp44 (``break_position`` default 604).

    NOTE(review): the forward pass of this class is not available in this copy
    of the file -- see ``forward``.
    """

    def __init__(self, params: LMConfig = None, env_counts=2471, break_position=604):
        super().__init__(params)
        self.exp_adapter = ExpAdapter(3, params.dim, padding_idx=0)
        self.exp_dropout = nn.Dropout(params.dropout)
        # Experiment-indicator embedding (index 1 is padding).
        self.env_adapter = nn.Embedding(env_counts, params.dim, padding_idx=1)
        self.exp_head = NonLinearHead(params.dim, 3, 'relu', hidden=params.dim // 2)
        self.te_head = ConvNetCodon(params.dim, params.dim // 2, 1)
        self.break_position = break_position
        self.OUT = MyModelOutput()

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                src_exp_data: Optional[torch.Tensor] = None,
                env_ids: Optional[torch.Tensor] = None,
                twod_tokens: Optional[torch.Tensor] = None,
                src_feature: Optional[torch.Tensor] = None,
                input_embedding=None,
                input_twod_tokens=None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                **args):
        """Not available: the core of this method was lost from the source file.

        NOTE(review): in this copy of the file the text of this method from
        ``gap = 3000 if input_ids.size(1)`` onward was stripped up to a later
        ``>0 else 0``. This removed the region-aware rotary-position
        construction (which used ``gap`` and presumably ``self.break_position``)
        and the transformer-layer loop. The surviving fragments show that the
        method:
          (1) embedded its inputs exactly like MiniMindLMForExp44, and
          (2) ended with the same exp/te/zero_shot heads and output fields as
              MiniMindLMForExp44.
        The lost section cannot be reconstructed from what remains; restore it
        from version control.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "MiniMindLMForExp44_4region.forward: the region-aware position/layer "
            "section of this method was lost from the source file; restore it "
            "from version control before use."
        )


# ---------------------------------------------------------------------------
# Loss
# ---------------------------------------------------------------------------

class ConditionalLoss(nn.Module):
    """Cross-entropy (+ optional weighted MSE) loss averaged over data groups.

    The loss is computed separately on the subset of the batch belonging to
    each distinct value of ``species_idx`` and of ``truncated_idx``. The
    returned ``total_loss`` is the mean of all those group losses.
    """

    def __init__(self, use_te_loss=True, te_loss_weight=1.0):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()
        self.loss_mse = nn.MSELoss()
        # Was accepted but never stored before, so forward() raised AttributeError.
        self.use_te_loss = use_te_loss
        self.te_loss_weight = te_loss_weight  # weight of the TE loss

    def _nn_loss(self, pred_nn, targets_nn, rows, seq_mask):
        """Cross-entropy over the selected batch ``rows``.

        When ``seq_mask`` is given, only positions where it is True are kept
        (the same selection ``torch.masked_select`` makes).
        """
        vocab = pred_nn.size(-1)
        logits = pred_nn[rows].reshape(-1, vocab)
        labels = targets_nn[rows].reshape(-1)
        if seq_mask is not None:
            keep = seq_mask[rows].reshape(-1).bool()
            logits, labels = logits[keep], labels[keep]
        return self.loss_fn(logits, labels)

    def _group_losses(self, group_idx, prefix, pred_nn, pred_te, targets_nn,
                      targets_te, seq_mask, with_te):
        """Per-group losses for one grouping variable.

        Returns:
            ``(sum_of_group_losses, n_groups, losses, nn_losses, te_losses)``.
            The dicts are keyed ``f'{prefix}_{value}'``.
        """
        losses, nn_losses, te_losses = {}, {}, {}
        total = 0
        groups = torch.unique(group_idx)
        for group in groups:
            rows = group_idx == group  # non-empty: `group` comes from unique()
            key = f'{prefix}_{group.item()}'
            nn_loss = self._nn_loss(pred_nn, targets_nn, rows, seq_mask)
            nn_losses[key] = nn_loss
            te_loss = 0
            if with_te:
                # Flatten both sides so MSELoss never broadcasts (N,) against (N, 1).
                te_loss = self.loss_mse(pred_te[rows].reshape(-1), targets_te[rows].reshape(-1))
                te_losses[key] = te_loss
            losses[key] = nn_loss + self.te_loss_weight * te_loss
            total = total + losses[key]
        return total, len(groups), losses, nn_losses, te_losses

    def forward(self, pred_nn: torch.Tensor = None, pred_te: torch.Tensor = None,
                targets_nn: torch.Tensor = None, targets_te: torch.Tensor = None,
                species_idx: torch.Tensor = None, truncated_idx: torch.Tensor = None,
                seq_mask: torch.Tensor = None):
        """Compute grouped losses.

        Args:
            pred_nn: token logits; the last axis is the vocabulary.
            pred_te: TE predictions (only used when the TE loss is active).
            targets_nn: token targets, same leading shape as ``pred_nn``.
            targets_te: TE targets, or None to skip the TE loss.
            species_idx / truncated_idx: per-sample group ids.
            seq_mask: optional mask; True positions are included in the NN loss.

        Returns:
            dict with ``total_loss`` and per-group loss dicts. The TE entries are
            added when ``use_te_loss`` is set.
        """
        with_te = self.use_te_loss and targets_te is not None
        common = (pred_nn, pred_te, targets_nn, targets_te, seq_mask, with_te)

        (species_total, n_species, species_losses,
         species_nn_losses, species_te_losses) = self._group_losses(species_idx, 'species', *common)
        (trunc_total, n_trunc, trunc_losses,
         trunc_nn_losses, trunc_te_losses) = self._group_losses(truncated_idx, 'trunc', *common)

        num_groups = n_species + n_trunc
        if num_groups > 0:
            avg_loss = (species_total + trunc_total) / num_groups
        else:
            avg_loss = pred_nn.new_zeros(())

        result = {
            'total_loss': avg_loss,
            'species_losses': species_losses,
            'trunc_losses': trunc_losses,
            'species_nn_losses': species_nn_losses,
            'trunc_nn_losses': trunc_nn_losses,
        }
        if self.use_te_loss:
            result.update({
                'species_te_losses': species_te_losses,
                'trunc_te_losses': trunc_te_losses,
                'te_loss_weight': self.te_loss_weight,
            })
        return result


# ---------------------------------------------------------------------------
# MaoTao: nucleotide LM fused with amino-acid and sequence-level features
# ---------------------------------------------------------------------------

class MiniMindLM_Maotao(MiniMindLM):
    """MiniMindLM that fuses amino-acid and sequence-level features after layer 0.

    After the first transformer layer the hidden states are split into the
    three codon frames. The third frame is adjusted by cross-attention to the
    amino-acid embeddings plus a fused species/truncation/continuous feature
    vector. The remaining layers then run without 2-D tokens. The ``te`` output
    is produced by ``te_head`` (called "cai" in the code).
    """

    def __init__(self, params: LMConfig = None):
        super().__init__(params)
        head_dim = params.dim // params.n_heads
        # Embeddings for the extra inputs: amino acids (index 0 = padding),
        # species id, truncation type (full/head/tail/boundary/middle) and
        # continuous features (off_start, off_end, full_len; log(x+1) per the
        # author's note).
        self.aa_embedding_adapter = VocabEmbedding(params.aa_vocab_size, params.dim, padding_idx=0)
        self.species_feature_adapter = VocabEmbedding(params.species_size, params.dim)
        self.truncated_feature_adapter = VocabEmbedding(params.truncated_size, params.dim)
        self.continuous_features = nn.Linear(params.continuous_features_dim, params.dim)
        # Cross-attention from codon positions to the amino-acid sequence.
        # batch_first=True because inputs here are (batch, seq, dim). With the
        # default (seq, batch, dim) layout, samples in a batch attended to one
        # another. Parameter shapes do not depend on this flag.
        # NOTE(review): num_heads is dim // n_heads (named head_dim), not n_heads -- confirm intent.
        self.aa_attention = nn.MultiheadAttention(params.dim, num_heads=head_dim, batch_first=True)
        self.feature_fusion = nn.Linear(params.dim * 3, params.dim)
        self.te_head = ConvNetCodon(params.dim, params.dim // 2, 1)
        # Re-initialise tok_embeddings with a fresh Parameter, breaking any weight sharing.
        self.tok_embeddings.weight = nn.Parameter(torch.randn_like(self.tok_embeddings.weight))
        self.OUT = MyModelOutput()

    def _fuse_modalities(self, h, aa_embedding, species_emb, truncated_emb, continuous_emb):
        """Add amino-acid and global features to every third position of ``h``.

        ``h`` is split into frames ``0::3``, ``1::3`` and ``2::3``; only the
        third frame is changed, then the frames are re-interleaved. The three
        frames must have equal length (``torch.stack``), i.e. the sequence
        length must be a multiple of 3.
        """
        batch_size, _, hidden_dim = h.shape
        frame_1, frame_2, frame_3 = h[:, 0::3, :], h[:, 1::3, :], h[:, 2::3, :]
        # Third codon position (query) attends to the amino-acid embeddings.
        attended_from_aa, _ = self.aa_attention(query=frame_3, key=aa_embedding, value=aa_embedding)
        # One fused sequence-level vector per sample, broadcast over codons.
        global_features = self.feature_fusion(
            torch.cat([species_emb, truncated_emb, continuous_emb], dim=1))
        global_features = global_features.unsqueeze(1).expand(-1, frame_3.size(1), -1)
        fused = torch.stack([frame_1, frame_2, frame_3 + attended_from_aa + global_features], dim=2)
        return fused.reshape(batch_size, -1, hidden_dim)

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                twod_tokens: Optional[torch.Tensor] = None,
                aa_idx: Optional[torch.Tensor] = None,
                continuous_features: Optional[torch.Tensor] = None,
                species_features: Optional[torch.Tensor] = None,
                truncated_features: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                input_embedding=None,
                input_twod_tokens=None,
                **args):
        """Run the fused model.

        Raises:
            ValueError: if ``aa_idx`` is None.

        Returns:
            MyModelOutput with logits, aux_loss, past_key_values, te (CAI head),
            zero_shot, embeddings, input_embedding and input_twod_tokens.
        """
        past_key_values = past_key_values or [None] * len(self.layers)
        if aa_idx is None:
            raise ValueError("aa_idx must be provided as the primary sequence input.")

        # --- nucleotide (RNA) stream ---
        start_pos = args.get('start_pos', 0)
        input_ids, twod_tokens, nn_embedding, seq_mask = _embed_nucleotides(
            self, input_ids, twod_tokens, input_embedding, input_twod_tokens)
        # NOTE(review): padding is filled with 1 here but with 0 everywhere else
        # (and after every layer) -- confirm this is intended.
        h = self.dropout(nn_embedding).masked_fill(seq_mask, 1)

        # --- amino-acid stream ---
        # Shift the RNA-vocabulary amino-acid ids (lower-case letters start at
        # id 10) down by 9 so the first amino acid becomes 1. Every id <= 9
        # collapses to 0, the protein padding index.
        aa_idx = torch.clamp(aa_idx - 9, min=0)
        aa_embedding = self.aa_embedding_adapter(aa_idx.to(torch.long)).to(torch.float32)

        # --- sequence-level features ---
        species_emb = self.species_feature_adapter(species_features).to(torch.float32)
        truncated_emb = self.truncated_feature_adapter(truncated_features).to(torch.float32)
        continuous_emb = self.continuous_features(continuous_features).to(torch.float32)

        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]

        # Layer 0 is the only layer that sees the 2-D tokens.
        h, past_kvs = _run_layers(self.layers[:1], 0, h, pos_cis, seq_mask,
                                  twod_tokens, past_key_values, use_cache)
        h = self._fuse_modalities(h, aa_embedding, species_emb, truncated_emb, continuous_emb)
        h = h.masked_fill(seq_mask, 0)
        h, later_kvs = _run_layers(self.layers[1:], 1, h, pos_cis, seq_mask,
                                   None, past_key_values, use_cache)
        past_kvs.extend(later_kvs)

        h = self.norm(h).masked_fill(seq_mask, 0)
        logits = self.output(h)
        cai = self.te_head(h)
        zero_shot = None if h.requires_grad else self.get_zero_shot(seq_mask, h)

        return _publish_output(
            self,
            logits=logits,
            aux_loss=_moe_aux_loss(self.layers),  # 0 for non-MoE models
            past_key_values=past_kvs,
            te=cai,  # CAI prediction
            zero_shot=zero_shot,  # zero-shot result
            embeddings=h,
            input_embedding=nn_embedding,  # for generation
            input_twod_tokens=twod_tokens,  # for generation
        )

    def get_zero_shot(self, seq_mask, h):
        """Per-sample zero-shot score; ``seq_mask`` is True at padding.

        Sums ``h`` over axes 1 and 2 ignoring padding, divides by the number of
        non-padding positions (0 for all-padding samples) and returns a
        ``(batch, 1)`` tensor.
        """
        return _masked_zero_shot(h, seq_mask)