# maotao/model/model_exp.py — uploaded by julse (commit "upload AA2CDS", revision 4707555)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Title : model_exp.py
project : minimind_RiboUTR
Created by: julse
Created on: 2025/3/9 00:45
des: expression/TE prediction variants of MiniMindLM (exp/te heads, grouped loss)
"""
import os
import torch
from torch import nn
from transformers.modeling_outputs import CausalLMOutputWithPast
from typing import Any, Optional, Tuple, List
from .model_downstream import ConvNetCodon
from .model_ribo import MiniMindLM, MOEFeedForward, NonLinearHead
from .LMConfig import LMConfig,LMaoTaoConfig
class VocabEmbedding(nn.Module):
    """Token embedding layer followed by dropout (p=0.1).

    Args:
        vocab_size: number of entries in the embedding table.
        embedding_dim: width of each embedding vector.
        padding_idx: optional index whose embedding stays zero.
    """

    def __init__(self, vocab_size, embedding_dim, padding_idx=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Embed integer ids and apply dropout; returns (..., embedding_dim)."""
        return self.dropout(self.embedding(x))
class ExpAdapter(nn.Module):
    """Linear projection of per-position features that zeroes padded positions.

    A position counts as padding when its FIRST input channel equals
    ``padding_idx``; the projected vector at such positions is set to 0.

    Args:
        input_dim: number of input feature channels.
        output_dim: projection width.
        padding_idx: sentinel value marking padding in channel 0.
    """

    def __init__(self, input_dim, output_dim, padding_idx=0):
        super().__init__()
        self.padding_idx = padding_idx
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Project x (batch, seq, input_dim) and zero out padded positions."""
        # Padding decided by the first feature channel only.
        pad_positions = (x == self.padding_idx)[:, :, 0]
        projected = self.linear(x)
        fill_mask = pad_positions.unsqueeze(-1).expand_as(projected)
        return projected.masked_fill(fill_mask, 0)
class UncertaintyWeighting(nn.Module):
    """Combine several losses with learned uncertainty weights.

    Each loss i is weighted by exp(-log_vars[i]) and regularized by adding
    log_vars[i]; log_vars are learnable parameters initialized to zero, so
    the initial combination is a plain sum of the losses.
    """

    def __init__(self, num_losses):
        super().__init__()
        self.log_vars = nn.Parameter(torch.zeros(num_losses))

    def forward(self, losses):
        """Return sum_i exp(-log_vars[i]) * losses[i] + log_vars[i]."""
        weighted_terms = [
            torch.exp(-log_var) * loss + log_var
            for log_var, loss in zip(self.log_vars, losses)
        ]
        return sum(weighted_terms)
class MyModelOutput(CausalLMOutputWithPast):
    """CausalLMOutputWithPast extended with task-specific output fields.

    Extra fields: aux_loss, exp, te, zero_shot, embeddings, input_embedding,
    input_twod_tokens. Tensor fields are automatically moved to the device
    of ``logits`` (see ``_sync_device`` and ``__setattr__``).
    """

    def __init__(self, logits: Optional[torch.FloatTensor] = None,
                 aux_loss: Optional[torch.FloatTensor] = None,
                 past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
                 exp: Optional[torch.FloatTensor] = None,
                 te: Optional[torch.FloatTensor] = None,
                 zero_shot: Optional[torch.FloatTensor] = None,
                 embeddings: Optional[torch.FloatTensor] = None,
                 input_embedding: Optional[torch.FloatTensor] = None,
                 input_twod_tokens: Optional[torch.FloatTensor] = None,
                 hidden_states: Optional[torch.FloatTensor] = None,
                 attentions: Optional[Tuple[torch.FloatTensor]] = None,
                 **kwargs):
        # Initialize the parent-class fields.
        super().__init__(
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=hidden_states,
            attentions=attentions,
            **kwargs
        )
        # Custom fields stored as plain attributes on top of the ModelOutput mapping.
        self.aux_loss = aux_loss
        self.exp = exp
        self.te = te
        self.zero_shot = zero_shot
        self.embeddings = embeddings
        self.input_embedding = input_embedding
        self.input_twod_tokens = input_twod_tokens
        # Move all tensor fields onto the same device as logits.
        self._sync_device()

    def _sync_device(self) -> None:
        """Move known tensor fields to the device of ``logits``.

        NOTE(review): 'exp' and 'zero_shot' are missing from the field list
        below and are only synced through ``__setattr__`` — confirm intended.
        """
        if not hasattr(self, 'logits') or self.logits is None:
            return  # no reference device available; skip
        base_device = self.logits.device
        for field in ['aux_loss', 'te', 'embeddings', 'input_embedding', 'input_twod_tokens']:
            tensor = getattr(self, field)
            if isinstance(tensor, torch.Tensor) and tensor.device != base_device:
                setattr(self, field, tensor.to(base_device))

    def __setattr__(self, name, value):
        """Set the attribute, then re-assign it on the logits device when a
        tensor lands on a different device."""
        super().__setattr__(name, value)
        if isinstance(value, torch.Tensor) and hasattr(self, 'logits') and self.logits is not None:
            if value.device != self.logits.device:
                super().__setattr__(name, value.to(self.logits.device))
class MiniMindLMForExp(MiniMindLM):
    """MiniMindLM variant with expression (exp) and translation-efficiency
    (te) regression heads; the conditioning signals are added to the hidden
    state BEFORE the transformer layers run."""

    def __init__(self, params: LMConfig = None, env_counts=2471):
        super().__init__(params)
        # Extra regression heads on top of the base LM trunk.
        # Projects the 3-channel experimental signal to model width; a
        # position whose first channel is 0 is treated as padding.
        self.exp_adapter = ExpAdapter(3, params.dim, padding_idx=0)
        self.exp_dropout = nn.Dropout(params.dropout)
        # self.env_adapter = nn.Embedding(env_counts, 1)  # experiment-condition indicator
        # Embedding for the experiment/environment id (index 1 = padding).
        self.env_adapter = nn.Embedding(env_counts, params.dim, padding_idx=1)
        # self.feature_adapter = nn.Linear(3,1)
        # Per-position 3-way regression head on the final hidden states.
        self.exp_head = NonLinearHead(params.dim, 3, 'relu', hidden=params.dim // 2)
        # Convolutional head producing the TE prediction.
        self.te_head = ConvNetCodon(params.dim, params.dim // 2, 1)
        # self.regression_head = nn.Linear(params.vocab_size, output_dim)

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                src_exp_data: Optional[torch.Tensor] = None,
                env_ids: Optional[torch.Tensor] = None,
                twod_tokens: Optional[torch.Tensor] = None,
                src_feature: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                **args):
        """Run the trunk with optional experimental/environment conditioning.

        Args:
            input_ids: token ids; id 1 is treated as padding.
            src_exp_data: optional (batch, seq, 3) experimental signal.
            env_ids: optional per-sequence environment/experiment ids.
            twod_tokens: 2-D structure features passed to every layer.
            src_feature: unused here (kept for interface compatibility).
            past_key_values: per-layer KV caches for incremental decoding.
            use_cache: whether layers should return their KV caches.

        Returns:
            self.OUT populated with logits, aux_loss, past_key_values,
            exp, te, zero_shot and embeddings.
        """
        past_key_values = past_key_values or [None] * len(self.layers)
        input_ids = input_ids.to(torch.long)
        start_pos = args.get('start_pos', 0)
        twod_tokens = twod_tokens.to(torch.float32)
        h = self.dropout(self.tok_embeddings(input_ids))
        seq_mask = input_ids == 1  # padding positions (token id 1)
        seq_mask.unsqueeze_(-1)
        h = h.masked_fill_(seq_mask, 0)
        # Add the experimental signal to the embedding (zeroed at padding).
        if src_exp_data is not None:
            src_exp_data = src_exp_data.to(torch.float32)
            src_exp_data = self.exp_dropout(self.exp_adapter(src_exp_data))  # (batch, seq, dim)
            src_exp_data = src_exp_data.masked_fill_(seq_mask, 0)
            h += src_exp_data
            h = h.masked_fill_(seq_mask, 0)
        # Add the environment embedding.
        if env_ids is not None:
            env_ids = env_ids.to(torch.long)
            env_ids = self.env_adapter(env_ids.unsqueeze(-1))
            h += env_ids
            h = h.masked_fill_(seq_mask, 0)
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        past_kvs = []
        for l, layer in enumerate(self.layers):
            h, past_kv = layer(
                h, pos_cis,
                twod_tokens=twod_tokens,
                past_key_value=past_key_values[l],
                use_cache=use_cache
            )
            h = h.masked_fill_(seq_mask, 0)
            past_kvs.append(past_kv)
        h = self.norm(h)
        h = h.masked_fill_(seq_mask, 0)
        logits = self.output(h)
        # Auxiliary load-balancing loss contributed by MoE feed-forward layers.
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
        exp = self.exp_head(h) if hasattr(self, 'exp_head') else None  # head may be deleted for downstream tasks
        if exp is not None:
            exp = exp.masked_fill_(seq_mask, 0)
        h = h.masked_fill_(seq_mask, 0)
        te = self.te_head(h)
        if not h.requires_grad:
            # Eval mode: expose a zero-shot score — the per-sample mean of
            # the non-padding hidden values.
            sum_h = torch.sum(h * ~seq_mask, dim=(1, 2))
            count_h = torch.sum(~seq_mask, dim=(1, 2))
            mean_h = sum_h / count_h
            # Samples with no non-padding elements get mean 0.
            mean_h[count_h == 0] = 0
            zero_shot = mean_h.reshape(-1, 1)
        else:
            zero_shot = None
        # if src_feature is not None:
        #     src_feature = src_feature.to(torch.float32)
        #     te += self.feature_adapter(src_feature)
        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__('aux_loss', aux_loss)
        self.OUT.__setitem__('past_key_values', past_kvs)
        self.OUT.__setitem__('exp', exp)
        self.OUT.__setitem__('te', te)
        self.OUT.__setitem__('zero_shot', zero_shot)  # zero-shot result
        self.OUT.__setitem__('embeddings', h)
        return self.OUT
class MiniMindLMForExp44(MiniMindLM):
    """Variant of MiniMindLMForExp that injects the experimental and
    environment signals BETWEEN layer 4 and layer 5 of the trunk (4+4
    split) and supports feeding precomputed embeddings for generation."""

    def __init__(self, params: LMConfig = None, env_counts=2471):
        super().__init__(params)
        # Extra conditioning adapters and regression heads.
        self.exp_adapter = ExpAdapter(3, params.dim, padding_idx=0)
        self.exp_dropout = nn.Dropout(params.dropout)
        # Embedding for the experiment/environment id (index 1 = padding).
        self.env_adapter = nn.Embedding(env_counts, params.dim, padding_idx=1)
        self.exp_head = NonLinearHead(params.dim, 3, 'relu', hidden=params.dim // 2)
        self.te_head = ConvNetCodon(params.dim, params.dim // 2, 1)

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                src_exp_data: Optional[torch.Tensor] = None,
                env_ids: Optional[torch.Tensor] = None,
                twod_tokens: Optional[torch.Tensor] = None,
                src_feature: Optional[torch.Tensor] = None,
                input_embedding=None,
                input_twod_tokens=None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                **args):
        """Forward pass with mid-trunk conditioning.

        Layers 0-3 run on the token embedding alone; src_exp_data and
        env_ids are then added to the hidden state before layers 4+.

        Args:
            input_ids: token ids; id 1 is padding.
            src_exp_data: optional (batch, seq, 3) experimental signal.
            env_ids: optional per-sequence environment ids.
            twod_tokens: 2-D structure features passed to every layer.
            src_feature: unused (interface compatibility).
            input_embedding: optional precomputed embedding that replaces
                the token embedding (used during generation).
            input_twod_tokens: optional precomputed 2-D features replacing
                twod_tokens (used during generation).
            past_key_values: per-layer KV caches.
            use_cache: whether layers return KV caches.

        Returns:
            self.OUT with logits, aux_loss, past_key_values, exp, te,
            zero_shot, embeddings, input_embedding, input_twod_tokens.
        """
        past_key_values = past_key_values or [None] * len(self.layers)
        input_ids = input_ids.to(torch.long)
        start_pos = args.get('start_pos', 0)
        twod_tokens = twod_tokens.to(torch.float32)
        embedding = self.tok_embeddings(input_ids).to(torch.float32)
        if input_embedding is not None:
            embedding = input_embedding  # + embedding
        if input_twod_tokens is not None:
            # +1e-7: presumably keeps the provided map strictly non-zero — TODO confirm
            twod_tokens = input_twod_tokens + 1e-7
        h = self.dropout(embedding)
        seq_mask = input_ids == 1  # padding positions
        seq_mask.unsqueeze_(-1)
        h = h.masked_fill(seq_mask, 0)
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        past_kvs = []
        # First 4 layers: sequence only.
        for l, layer in enumerate(self.layers[:4]):
            h, past_kv = layer(
                h, pos_cis,
                twod_tokens=twod_tokens,
                past_key_value=past_key_values[l],
                use_cache=use_cache
            )
            h = h.masked_fill(seq_mask, 0)
            past_kvs.append(past_kv)
        # Inject conditioning signals between the two halves of the trunk.
        if src_exp_data is not None:
            src_exp_data = src_exp_data.to(torch.float32)
            src_exp_data = self.exp_dropout(self.exp_adapter(src_exp_data))  # (batch, seq, dim)
            src_exp_data = src_exp_data.masked_fill(seq_mask, 0)
            h += src_exp_data
            h = h.masked_fill(seq_mask, 0)
        if env_ids is not None:
            env_ids = env_ids.to(torch.long)
            env_ids = self.env_adapter(env_ids.unsqueeze(-1))
            h += env_ids
            h = h.masked_fill(seq_mask, 0)
        # Remaining layers (4..end).
        for l, layer in enumerate(self.layers[4:]):
            h, past_kv = layer(
                h, pos_cis,
                twod_tokens=twod_tokens,
                past_key_value=past_key_values[l + 4],  # offset into the cache for the second half
                use_cache=use_cache
            )
            h = h.masked_fill(seq_mask, 0)
            past_kvs.append(past_kv)
        h = self.norm(h)
        h = h.masked_fill(seq_mask, 0)
        logits = self.output(h)
        # Auxiliary load-balancing loss from MoE feed-forward layers.
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
        exp = self.exp_head(h) if hasattr(self, 'exp_head') else None  # head may be deleted downstream
        if exp is not None:
            exp = exp.masked_fill_(seq_mask, 0)
        h = h.masked_fill(seq_mask, 0)
        te = self.te_head(h)
        if not h.requires_grad:
            # Eval mode: per-sample mean of non-padding hidden values.
            sum_h = torch.sum(h * ~seq_mask, dim=(1, 2))
            count_h = torch.sum(~seq_mask, dim=(1, 2))
            mean_h = sum_h / count_h
            # Samples with no valid positions get mean 0.
            mean_h[count_h == 0] = 0
            zero_shot = mean_h.reshape(-1, 1)
        else:
            zero_shot = None
        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__('aux_loss', aux_loss)
        self.OUT.__setitem__('past_key_values', past_kvs)
        self.OUT.__setitem__('exp', exp)
        self.OUT.__setitem__('te', te)
        self.OUT.__setitem__('zero_shot', zero_shot)  # zero-shot result
        self.OUT.__setitem__('embeddings', h)
        self.OUT.__setitem__('input_embedding', embedding)  # for generation
        self.OUT.__setitem__('input_twod_tokens', twod_tokens)  # for generation
        return self.OUT
class MiniMindLMForExp44_4region(MiniMindLM):
    """MiniMindLMForExp44 variant with a split positional encoding.

    Rotary position indices run contiguously up to ``break_position`` and
    then resume ``gap`` (=3000) entries further along ``pos_cis``, so the
    two sequence regions occupy well-separated positional ranges. Only the
    first transformer layer receives ``twod_tokens``; all later layers run
    with ``twod_tokens=None``.
    """

    def __init__(self, params: LMConfig = None, env_counts=2471, break_position=604):
        super().__init__(params)
        # Extra conditioning adapters and regression heads.
        self.exp_adapter = ExpAdapter(3, params.dim, padding_idx=0)
        self.exp_dropout = nn.Dropout(params.dropout)
        # Embedding for the experiment/environment id (index 1 = padding).
        self.env_adapter = nn.Embedding(env_counts, params.dim, padding_idx=1)
        self.exp_head = NonLinearHead(params.dim, 3, 'relu', hidden=params.dim // 2)
        self.te_head = ConvNetCodon(params.dim, params.dim // 2, 1)
        # Sequence index at which the positional encoding is split.
        self.break_position = break_position
        # Output container that keeps tensors on the logits device.
        self.OUT = MyModelOutput()

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                src_exp_data: Optional[torch.Tensor] = None,
                env_ids: Optional[torch.Tensor] = None,
                twod_tokens: Optional[torch.Tensor] = None,
                src_feature: Optional[torch.Tensor] = None,
                input_embedding=None,
                input_twod_tokens=None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                **args):
        """Forward pass with split positional indices and mid-trunk conditioning.

        Args:
            input_ids: token ids; id 1 is padding.
            src_exp_data: optional (batch, seq, 3) experimental signal.
            env_ids: optional per-sequence environment ids.
            twod_tokens: 2-D structure features (consumed by layer 0 only).
            src_feature: unused (interface compatibility).
            input_embedding: optional precomputed embedding replacing the
                token embedding (generation).
            input_twod_tokens: optional precomputed 2-D features (generation).
            past_key_values: per-layer KV caches.
            use_cache: whether layers return KV caches.

        Returns:
            self.OUT with logits, aux_loss, past_key_values, exp, te,
            zero_shot, embeddings, input_embedding, input_twod_tokens.
        """
        past_key_values = past_key_values or [None] * len(self.layers)
        input_ids = input_ids.to(torch.long)
        start_pos = args.get('start_pos', 0)
        twod_tokens = twod_tokens.to(torch.float32)  # e.g. [1, 1, 1205, 1205]
        embedding = self.tok_embeddings(input_ids).to(torch.float32)
        if input_embedding is not None:
            embedding = input_embedding  # + embedding
        if input_twod_tokens is not None:
            # +1e-7: presumably keeps the provided map strictly non-zero — TODO confirm
            twod_tokens = input_twod_tokens + 1e-7
        h = self.dropout(embedding)
        seq_mask = input_ids == 1  # padding positions
        seq_mask.unsqueeze_(-1)
        h = h.masked_fill(seq_mask, 0)
        # gap = 0
        gap = 3000  # positional offset inserted at the break point
        if input_ids.size(1) < self.break_position:
            # Right-align short inputs against the break position.
            start_pos = self.break_position - input_ids.size(1)
        # Positions [start_pos, break_position) followed by positions shifted by gap.
        pos_cis = torch.concat([self.pos_cis[start_pos:self.break_position], self.pos_cis[self.break_position + gap:gap + input_ids.size(1)]], dim=0)
        # self.pos_cis shape: (max_seq_len, dim // n_heads)
        assert input_ids.size(1) <= pos_cis.size(0) + gap, (
            f"Sequence length mismatch: input_ids has length {input_ids.size(1)},gap is {gap} "
            f"but pos_cis has length {self.pos_cis.size(0)}. "
            f"Please ensure that max_seq_len is set to at least {input_ids.size(1)+gap}."
        )
        past_kvs = []
        # Layer 0: the only layer that sees twod_tokens.
        l = 0
        layer = self.layers[l]
        h, past_kv = layer(
            h, pos_cis,
            twod_tokens=twod_tokens,
            past_key_value=past_key_values[l],
            use_cache=use_cache
        )
        h = h.masked_fill(seq_mask, 0)
        past_kvs.append(past_kv)
        # Layers 1-3: sequence only.
        for l, layer in enumerate(self.layers[1:4]):
            h, past_kv = layer(
                h, pos_cis,
                twod_tokens=None,
                past_key_value=past_key_values[l+1],
                use_cache=use_cache
            )
            h = h.masked_fill(seq_mask, 0)
            past_kvs.append(past_kv)
        # Inject conditioning signals between the two halves of the trunk.
        if src_exp_data is not None:
            src_exp_data = src_exp_data.to(torch.float32)
            src_exp_data = self.exp_dropout(self.exp_adapter(src_exp_data))  # (batch, seq, dim)
            src_exp_data = src_exp_data.masked_fill(seq_mask, 0)
            h += src_exp_data
            h = h.masked_fill(seq_mask, 0)
        if env_ids is not None:
            env_ids = env_ids.to(torch.long)
            env_ids = self.env_adapter(env_ids.unsqueeze(-1))
            h += env_ids
            h = h.masked_fill(seq_mask, 0)
        # Remaining layers (4..end).
        for l, layer in enumerate(self.layers[4:]):
            h, past_kv = layer(
                h, pos_cis,
                twod_tokens=None,
                past_key_value=past_key_values[l + 4],  # offset into the cache for the second half
                use_cache=use_cache
            )
            h = h.masked_fill(seq_mask, 0)
            past_kvs.append(past_kv)
        h = self.norm(h)
        h = h.masked_fill(seq_mask, 0)
        logits = self.output(h)
        # Auxiliary load-balancing loss from MoE feed-forward layers (0 when none).
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
        exp = self.exp_head(h) if hasattr(self, 'exp_head') else None  # head may be deleted downstream
        if exp is not None:
            exp = exp.masked_fill_(seq_mask, 0)
        h = h.masked_fill(seq_mask, 0)
        te = self.te_head(h)
        if not h.requires_grad:
            # Eval mode: per-sample mean of non-padding hidden values.
            sum_h = torch.sum(h * ~seq_mask, dim=(1, 2))
            count_h = torch.sum(~seq_mask, dim=(1, 2))
            mean_h = sum_h / count_h
            # Samples with no valid positions get mean 0.
            mean_h[count_h == 0] = 0
            zero_shot = mean_h.reshape(-1, 1)
        else:
            zero_shot = None
        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__('aux_loss', aux_loss)
        self.OUT.__setitem__('past_key_values', past_kvs)
        self.OUT.__setitem__('exp', exp)
        self.OUT.__setitem__('te', te)
        self.OUT.__setitem__('zero_shot', zero_shot)  # zero-shot result
        self.OUT.__setitem__('embeddings', h)
        self.OUT.__setitem__('input_embedding', embedding)  # for generation
        self.OUT.__setitem__('input_twod_tokens', twod_tokens)  # for generation
        return self.OUT
class ConditionalLoss(nn.Module):
    """Group-balanced training loss.

    Computes a cross-entropy (next-nucleotide) loss plus an optional MSE
    (TE) loss, evaluated separately per species group and per truncation
    group, then averaged over groups so over-represented groups do not
    dominate the total.
    """

    def __init__(self, use_te_loss=True, te_loss_weight=1.0):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()
        self.loss_mse = nn.MSELoss()
        # FIX: use_te_loss was accepted but never stored, so forward()
        # crashed with AttributeError at `self.use_te_loss`.
        self.use_te_loss = use_te_loss
        self.te_loss_weight = te_loss_weight  # weight of the TE loss term

    def forward(self, pred_nn: torch.Tensor = None, pred_te: torch.Tensor = None,
                targets_nn: torch.Tensor = None, targets_te: torch.Tensor = None,
                species_idx: torch.Tensor = None, truncated_idx: torch.Tensor = None,
                seq_mask: torch.Tensor = None):
        """Compute the group-averaged loss.

        Args:
            pred_nn: (batch, seq, vocab) logits.
            pred_te: TE predictions matching targets_te element-for-element.
            targets_nn: (batch, seq) integer class targets.
            targets_te: optional TE regression targets (None disables the TE term).
            species_idx: (batch,) species id per sample.
            truncated_idx: (batch,) truncation position per sample.
            seq_mask: (batch, seq) boolean mask; True positions enter the
                species NN loss.
                NOTE(review): confirm the caller passes True for the
                positions that should be scored (elsewhere in this file
                seq_mask marks padding).

        Returns:
            dict with 'total_loss' (group-averaged) and per-group breakdowns;
            TE breakdowns are included when use_te_loss is set.
        """
        total_loss = 0
        batch_size, length, vocab = pred_nn.size()
        # --- group by species ---
        unique_species = torch.unique(species_idx)
        species_losses = {}
        species_nn_losses = {}
        species_te_losses = {}
        for species in unique_species:
            species_mask = species_idx == species
            if species_mask.sum() > 0:  # ensure the group is non-empty
                sel = seq_mask[species_mask].view(-1)
                # FIX: the original used masked_select with a (P,) mask on a
                # (P, vocab) tensor, which cannot broadcast (the author's own
                # "shape is wrong" TODO). Select whole rows by boolean indexing.
                species_nn_loss = self.loss_fn(
                    pred_nn[species_mask].view(-1, vocab)[sel],
                    targets_nn[species_mask].view(-1)[sel]
                )
                species_nn_losses[f'species_{species.item()}'] = species_nn_loss
                # Optional TE loss for this group.
                species_te_loss = 0
                if targets_te is not None:
                    # Flatten both sides so MSE compares element-for-element.
                    species_te_loss = self.loss_mse(pred_te[species_mask].view(-1), targets_te[species_mask].view(-1))
                    species_te_losses[f'species_{species.item()}'] = species_te_loss
                species_loss = species_nn_loss + self.te_loss_weight * species_te_loss
                species_losses[f'species_{species.item()}'] = species_loss
                total_loss += species_loss
        # --- group by truncation position ---
        unique_trunc = torch.unique(truncated_idx)
        trunc_losses = {}
        trunc_nn_losses = {}
        trunc_te_losses = {}
        for trunc_pos in unique_trunc:
            trunc_mask = (truncated_idx == trunc_pos)
            if trunc_mask.sum() > 0:
                trunc_nn_loss = self.loss_fn(pred_nn[trunc_mask].view(-1, vocab), targets_nn[trunc_mask].view(-1))
                trunc_nn_losses[f'trunc_{trunc_pos.item()}'] = trunc_nn_loss
                trunc_te_loss = 0
                # FIX: `if targets_te:` raises "Boolean value of Tensor ...
                # is ambiguous" for multi-element tensors; test for None.
                if targets_te is not None:
                    trunc_te_loss = self.loss_mse(pred_te[trunc_mask].view(-1), targets_te[trunc_mask].view(-1))
                    trunc_te_losses[f'trunc_{trunc_pos.item()}'] = trunc_te_loss
                trunc_loss = trunc_nn_loss + self.te_loss_weight * trunc_te_loss
                trunc_losses[f'trunc_{trunc_pos.item()}'] = trunc_loss
                total_loss += trunc_loss
        # Average over all groups.
        num_groups = len(unique_species) + len(unique_trunc)
        avg_loss = total_loss / num_groups if num_groups > 0 else torch.tensor(0.0)
        result = {
            'total_loss': avg_loss,
            'species_losses': species_losses,
            'trunc_losses': trunc_losses,
            'species_nn_losses': species_nn_losses,
            'trunc_nn_losses': trunc_nn_losses,
        }
        # Include the TE breakdown only when the TE term is enabled.
        if self.use_te_loss:
            result.update({
                'species_te_losses': species_te_losses,
                'trunc_te_losses': trunc_te_losses,
                'te_loss_weight': self.te_loss_weight
            })
        return result
class MiniMindLM_Maotao(MiniMindLM):
    """Codon-level model fusing RNA token embeddings with amino-acid,
    species, truncation and continuous features; predicts a CAI-like score
    via ``te_head`` in addition to the LM logits."""

    def __init__(self, params: LMConfig = None):
        super().__init__(params)
        head_dim = params.dim // params.n_heads
        # Adapters for the amino-acid sequence and per-sequence features.
        # self.aa_embedding_adapter = VocabEmbedding(params.vocab_size, params.dim)
        self.aa_embedding_adapter = VocabEmbedding(params.aa_vocab_size, params.dim, padding_idx=0)
        self.species_feature_adapter = VocabEmbedding(params.species_size, params.dim)
        # truncation category: full, head, tail, boundary, middle
        self.truncated_feature_adapter = VocabEmbedding(params.truncated_size, params.dim)
        # continuous inputs: data['off_start'], data['off_end'], data['full_len'], log(x+1)-scaled
        self.continuous_features = nn.Linear(params.continuous_features_dim, params.dim)
        # Cross-attention from codon frame-3 states onto the AA embedding.
        # NOTE(review): num_heads is set to head_dim (= dim // n_heads), not
        # n_heads — confirm this is intentional.  Also nn.MultiheadAttention
        # defaults to batch_first=False while the inputs below look
        # batch-first — verify the intended layout.
        self.aa_attention = nn.MultiheadAttention(params.dim, num_heads=head_dim)
        self.feature_fusion = nn.Linear(params.dim*3, params.dim)
        # CAI/TE head on the final hidden states.
        self.te_head = ConvNetCodon(params.dim, params.dim//2, 1)
        # Break weight tying with the output head: re-initialize tok_embeddings.
        self.tok_embeddings.weight = nn.Parameter(torch.randn_like(self.tok_embeddings.weight))
        # self.regression_head = nn.Linear(params.vocab_size, output_dim)
        # Output container that keeps tensors on the logits device.
        self.OUT = MyModelOutput()
        # loss for training and evaluation
        # self.loss_model = ConditionalLoss()

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                twod_tokens: Optional[torch.Tensor] = None,
                aa_idx: Optional[torch.Tensor] = None,
                continuous_features: Optional[torch.Tensor] = None,
                species_features: Optional[torch.Tensor] = None,
                truncated_features: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                input_embedding=None,
                input_twod_tokens=None,
                **args):
        """Fuse RNA and AA modalities and run the transformer trunk.

        Layer 0 runs on the RNA embedding with twod_tokens; the AA
        cross-attention and global features are then added to the frame-3
        codon positions before layers 1+ (which run with twod_tokens=None).

        Args:
            input_ids: RNA token ids; id 1 is padding.
            twod_tokens: 2-D structure features (layer 0 only).
            aa_idx: amino-acid ids in the shared rna+aa vocabulary (required).
            continuous_features: per-sequence continuous inputs
                (assumed (batch, continuous_features_dim) — TODO confirm).
            species_features: per-sequence species ids
                (assumed (batch,) — TODO confirm against caller).
            truncated_features: per-sequence truncation-category ids.
            past_key_values: per-layer KV caches.
            use_cache: whether layers return KV caches.
            input_embedding: optional precomputed embedding (generation).
            input_twod_tokens: optional precomputed 2-D features (generation).

        Returns:
            self.OUT with logits, aux_loss, past_key_values, te (CAI),
            zero_shot, embeddings, input_embedding, input_twod_tokens.

        Raises:
            ValueError: if aa_idx is None.
        """
        past_key_values = past_key_values or [None] * len(self.layers)
        if aa_idx is None:
            raise ValueError("aa_idx must be provided as the primary sequence input.")
        '''input
        aa
        nn
        mer3
        '''
        '''input rna'''
        input_ids = input_ids.to(torch.long)
        start_pos = args.get('start_pos', 0)
        twod_tokens = twod_tokens.to(torch.float32)  # e.g. [1, 1, 1205, 1205]
        nn_embedding = self.tok_embeddings(input_ids).to(torch.float32)
        if input_embedding is not None:
            nn_embedding = input_embedding  # + embedding
        if input_twod_tokens is not None:
            # +1e-7: presumably keeps the provided map strictly non-zero — TODO confirm
            twod_tokens = input_twod_tokens + 1e-7
        h = self.dropout(nn_embedding)
        seq_mask = input_ids == 1  # padding positions of the nucleotide sequence
        seq_mask.unsqueeze_(-1)
        # NOTE(review): fill value here is 1, while every other masking in
        # this file fills 0 — confirm this is intentional.
        h = h.masked_fill(seq_mask, 1)
        '''input aa'''
        # Remap ids from the shared rna+aa vocabulary to the compact AA table
        # (padding_idx=0 for protein; 1 is the smallest amino-acid id).
        aa_idx = torch.clamp(aa_idx-9, min=0)
        # from rna aa all {'<s>': 0, '<pad>': 1, '</s>': 2, '<unk>': 3, 'G': 4, 'A': 5, 'U': 6, 'C': 7, 'N': 8, '<mask>': 9, 'T': 6, '_': 1, 'a': 10, 'c': 11, 'd': 12, 'e': 13, 'f': 14, 'g': 15, 'h': 16, 'i': 17, 'k': 18, 'l': 19, 'm': 20, 'n': 21, 'p': 22, 'q': 23, 'r': 24, 's': 25, 't': 26, 'v': 27, 'w': 28, 'y': 29, '*': 30, '-': 31}
        # to {'_': 0, 'A': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6, 'H': 7, 'I': 8, 'K': 9, 'L': 10, 'M': 11, 'N': 12, 'P': 13, 'Q': 14, 'R': 15, 'S': 16, 'T': 17, 'V': 18, 'W': 19, 'Y': 20, '*': 21}
        aa_embedding = self.aa_embedding_adapter(aa_idx.to(torch.long)).to(torch.float32)  # length = rna_len / 3
        species_features = self.species_feature_adapter(species_features).to(torch.float32)
        truncated_features = self.truncated_feature_adapter(truncated_features).to(torch.float32)
        continuous_features = self.continuous_features(continuous_features).to(torch.float32)
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        past_kvs = []
        # Layer 0: the only layer that sees twod_tokens.
        l = 0
        layer = self.layers[l]
        h, past_kv = layer(
            h, pos_cis,
            twod_tokens=twod_tokens,
            past_key_value=past_key_values[l],
            use_cache=use_cache
        )
        h = h.masked_fill(seq_mask, 0)
        past_kvs.append(past_kv)
        '''
        Modality fusion
        '''
        batch_size, seq_len, hidden_dim = h.shape
        # Split codon positions into the three reading-frame offsets.
        frame_1 = h[:, 0::3, :]
        frame_2 = h[:, 1::3, :]
        frame_3 = h[:, 2::3, :]
        # Amino-acid-level adjustment: frame-3 states query the AA features.
        attended_from_aa, _ = self.aa_attention(
            query=frame_3,  # adjustment target: third codon position
            key=aa_embedding,  # adjustment evidence: AA features
            value=aa_embedding  # adjustment content: AA features
        )
        # Global-feature adjustment, broadcast over frame-3 positions.
        global_features = torch.cat([species_features, truncated_features, continuous_features], dim=1)
        global_features = self.feature_fusion(global_features)
        global_features = global_features.unsqueeze(1).expand(-1, frame_3.size(1), -1)
        # Re-interleave the frames back into the original sequence order.
        new_h_reshaped = torch.stack([frame_1, frame_2, frame_3 + attended_from_aa + global_features], dim=2)
        h = new_h_reshaped.reshape(batch_size, -1, hidden_dim)
        h = h.masked_fill(seq_mask, 0)
        # Layers 1-3.
        for l, layer in enumerate(self.layers[1:4]):
            h, past_kv = layer(
                h, pos_cis,
                twod_tokens=None,
                past_key_value=past_key_values[l+1],
                use_cache=use_cache
            )
            h = h.masked_fill(seq_mask, 0)
            past_kvs.append(past_kv)
        # Remaining layers (4..end).
        for l, layer in enumerate(self.layers[4:]):
            h, past_kv = layer(
                h, pos_cis,
                twod_tokens=None,
                past_key_value=past_key_values[l + 4],  # offset into the cache for the second half
                use_cache=use_cache
            )
            h = h.masked_fill(seq_mask, 0)
            past_kvs.append(past_kv)
        h = self.norm(h)
        h = h.masked_fill(seq_mask, 0)
        logits = self.output(h)
        # Auxiliary load-balancing loss from MoE feed-forward layers (0 when none).
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
        h = h.masked_fill(seq_mask, 0)
        cai = self.te_head(h)
        # Eval mode only: per-sample mean of non-padding hidden values.
        zero_shot = self.get_zero_shot(seq_mask, h) if not h.requires_grad else None
        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__('aux_loss', aux_loss)
        self.OUT.__setitem__('past_key_values', past_kvs)
        self.OUT.__setitem__('te', cai)  # CAI prediction stored under 'te'
        self.OUT.__setitem__('zero_shot', zero_shot)  # zero-shot result
        self.OUT.__setitem__('embeddings', h)
        self.OUT.__setitem__('input_embedding', nn_embedding)  # for generation
        self.OUT.__setitem__('input_twod_tokens', twod_tokens)  # for generation
        return self.OUT

    def get_zero_shot(self, seq_mask, h):
        """Return the per-sample mean of non-padding hidden values as (-1, 1).

        Args:
            seq_mask: boolean mask, True at padding positions.
            h: hidden states of shape (batch, seq, dim).
        """
        # Sum of non-padding elements.
        sum_h = torch.sum(h * ~seq_mask, dim=(1, 2))
        # Count of non-padding elements.
        count_h = torch.sum(~seq_mask, dim=(1, 2))
        mean_h = sum_h / count_h
        # Samples with no non-padding elements get mean 0 (overwrites the NaN
        # produced by the 0/0 division above).
        mean_h[count_h == 0] = 0
        zero_shot = mean_h.reshape(-1, 1)
        return zero_shot