|
|
|
|
|
|
|
|
""" |
|
|
Title : model_exp.py |
|
|
project : minimind_RiboUTR |
|
|
Created by: julse |
|
|
Created on: 2025/3/9 00:45 |
|
|
des: TODO |
|
|
""" |
|
|
import os |
|
|
import torch |
|
|
from torch import nn |
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
from typing import Any, Optional, Tuple, List |
|
|
|
|
|
from .model_downstream import ConvNetCodon |
|
|
from .model_ribo import MiniMindLM, MOEFeedForward, NonLinearHead |
|
|
from .LMConfig import LMConfig,LMaoTaoConfig |
|
|
|
|
|
|
|
|
class VocabEmbedding(nn.Module):
    """Token-embedding lookup followed by a fixed 10% dropout.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: width of each embedding vector.
        padding_idx: optional index whose embedding stays zero and receives
            no gradient (forwarded to ``nn.Embedding``).
    """

    def __init__(self, vocab_size, embedding_dim, padding_idx=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Embed integer ids ``x`` and apply dropout; returns (..., embedding_dim)."""
        return self.dropout(self.embedding(x))
|
|
class ExpAdapter(nn.Module):
    """Project per-position expression features to the model width, zeroing padding.

    A position counts as padding when its *first* input channel equals
    ``padding_idx``; the whole output vector at that position is set to 0.
    """

    def __init__(self, input_dim, output_dim, padding_idx=0):
        super().__init__()
        self.padding_idx = padding_idx
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Map (B, L, input_dim) -> (B, L, output_dim) with padded positions zeroed."""
        # Padding is detected on the raw input, before projection.
        pad = (x[:, :, 0] == self.padding_idx).unsqueeze(-1)
        projected = self.linear(x)
        # masked_fill broadcasts the (B, L, 1) mask over the feature dim,
        # which is equivalent to repeating it output_dim times.
        return projected.masked_fill(pad, 0)
|
|
class UncertaintyWeighting(nn.Module):
    """Uncertainty-based multi-task loss weighting (Kendall et al. style).

    Each task loss ``L_i`` contributes ``exp(-s_i) * L_i + s_i`` where
    ``s_i`` is a learnable log-variance, so the model can down-weight
    noisier tasks while the ``+ s_i`` term prevents collapse to zero weight.
    """

    def __init__(self, num_losses):
        super().__init__()
        # One learnable log-variance per task, initialized to 0 (unit weight).
        self.log_vars = nn.Parameter(torch.zeros(num_losses))

    def forward(self, losses):
        """Return the combined loss ``sum_i exp(-s_i) * losses[i] + s_i``."""
        terms = (torch.exp(-self.log_vars[i]) * loss + self.log_vars[i]
                 for i, loss in enumerate(losses))
        return sum(terms)
|
|
|
|
|
class MyModelOutput(CausalLMOutputWithPast):
    """CausalLMOutputWithPast extended with expression/TE prediction fields.

    Extra fields carried alongside the standard LM outputs:
      aux_loss          -- MoE load-balancing loss summed over MoE layers
      exp / te          -- per-position expression head and TE head outputs
      zero_shot         -- pooled per-sequence embedding (inference only)
      embeddings        -- final hidden states
      input_embedding   -- token embeddings fed to the first layer
      input_twod_tokens -- secondary-structure token input echoed back

    Tensor fields are kept on the same device as ``logits``.
    """

    def __init__(self, logits: Optional[torch.FloatTensor] = None,
                 aux_loss: Optional[torch.FloatTensor] = None,
                 past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
                 exp: Optional[torch.FloatTensor] = None,
                 te: Optional[torch.FloatTensor] = None,
                 zero_shot: Optional[torch.FloatTensor] = None,
                 embeddings: Optional[torch.FloatTensor] = None,
                 input_embedding: Optional[torch.FloatTensor] = None,
                 input_twod_tokens: Optional[torch.FloatTensor] = None,
                 hidden_states: Optional[torch.FloatTensor] = None,
                 attentions: Optional[Tuple[torch.FloatTensor]] = None,
                 **kwargs):
        # Standard fields go through the transformers ModelOutput machinery.
        super().__init__(
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=hidden_states,
            attentions=attentions,
            **kwargs
        )

        # Custom fields are plain attributes on top of the ModelOutput dict.
        self.aux_loss = aux_loss
        self.exp = exp
        self.te = te
        self.zero_shot = zero_shot
        self.embeddings = embeddings
        self.input_embedding = input_embedding
        self.input_twod_tokens = input_twod_tokens

        self._sync_device()

    def _sync_device(self) -> None:
        """Move listed tensor fields onto the device of ``logits``."""
        # Nothing to sync against when logits is absent/None.
        if not hasattr(self, 'logits') or self.logits is None:
            return

        base_device = self.logits.device
        # NOTE(review): 'exp' and 'zero_shot' are not in this list, so they
        # are only synced via __setattr__ below — confirm intentional.
        for field in ['aux_loss', 'te', 'embeddings', 'input_embedding', 'input_twod_tokens']:
            tensor = getattr(self, field)
            if isinstance(tensor, torch.Tensor) and tensor.device != base_device:
                setattr(self, field, tensor.to(base_device))

    def __setattr__(self, name, value):
        """Override attribute assignment so newly set tensors follow logits' device."""
        super().__setattr__(name, value)
        # Re-assign through super() to avoid recursing into this override.
        if isinstance(value, torch.Tensor) and hasattr(self, 'logits') and self.logits is not None:
            if value.device != self.logits.device:
                super().__setattr__(name, value.to(self.logits.device))
|
|
class MiniMindLMForExp(MiniMindLM):
    """MiniMindLM extended for expression modelling.

    Adds on top of the base LM:
      * ``exp_adapter`` -- injects 3-channel per-position expression data,
      * ``env_adapter`` -- embedding for environment/condition ids,
      * ``exp_head``    -- per-position 3-way expression head,
      * ``te_head``     -- convolutional TE (translation efficiency) head.
    """

    def __init__(self, params: LMConfig = None, env_counts: int = 2471):
        super().__init__(params)

        # 3-channel expression features -> model dim; a zero first channel
        # marks padding positions (see ExpAdapter).
        self.exp_adapter = ExpAdapter(3, params.dim,padding_idx=0)
        self.exp_dropout = nn.Dropout(params.dropout)

        # Environment/condition embedding; id 1 is the padding index.
        self.env_adapter = nn.Embedding(env_counts, params.dim,padding_idx=1)

        # Prediction heads over the final hidden states.
        self.exp_head = NonLinearHead(params.dim, 3,'relu', hidden=params.dim//2)
        self.te_head = ConvNetCodon(params.dim,params.dim//2,1)

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                src_exp_data: Optional[torch.Tensor] = None,
                env_ids: Optional[torch.Tensor] = None,
                twod_tokens: Optional[torch.Tensor] = None,
                src_feature: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                **args):
        """Run the expression-aware transformer stack.

        Args:
            input_ids: (batch, seq) token ids; id 1 is treated as padding.
            src_exp_data: optional per-position expression features, added to
                the token embeddings (assumes (batch, seq, 3) — TODO confirm).
            env_ids: optional environment ids; their embedding is broadcast
                over the sequence.
            twod_tokens: secondary-structure tokens passed to every layer.
            src_feature: accepted but unused in this forward.
            past_key_values / use_cache: incremental-decoding cache plumbing.

        Returns:
            ``self.OUT`` populated with logits, aux_loss, past_key_values,
            exp, te, zero_shot and embeddings.
        """
        past_key_values = past_key_values or [None] * len(self.layers)

        input_ids = input_ids.to(torch.long)
        start_pos = args.get('start_pos', 0)
        twod_tokens = twod_tokens.to(torch.float32)

        h = self.dropout(self.tok_embeddings(input_ids))
        # Token id 1 marks padding — TODO confirm against the tokenizer.
        seq_mask = input_ids == 1
        seq_mask.unsqueeze_(-1)
        # Padding positions are zeroed after every additive step below so the
        # mask survives the whole stack.
        h = h.masked_fill_(seq_mask, 0)

        if src_exp_data is not None:
            src_exp_data = src_exp_data.to(torch.float32)
            src_exp_data = self.exp_dropout(self.exp_adapter(src_exp_data))
            src_exp_data = src_exp_data.masked_fill_(seq_mask, 0)
            h+=src_exp_data
            h = h.masked_fill_(seq_mask, 0)

        if env_ids is not None:
            env_ids = env_ids.to(torch.long)
            # (batch,) -> (batch, 1, dim): one embedding broadcast over seq.
            env_ids = self.env_adapter(env_ids.unsqueeze(-1))
            h+=env_ids
            h = h.masked_fill_(seq_mask, 0)

        # Rotary position frequencies for the current window.
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        past_kvs = []
        for l, layer in enumerate(self.layers):
            h, past_kv = layer(
                h, pos_cis,
                twod_tokens=twod_tokens,
                past_key_value=past_key_values[l],
                use_cache=use_cache
            )
            h = h.masked_fill_(seq_mask, 0)
            past_kvs.append(past_kv)
        h = self.norm(h)
        h = h.masked_fill_(seq_mask, 0)
        logits = self.output(h)
        # MoE auxiliary (load-balancing) loss summed over MoE layers only.
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))

        exp = self.exp_head(h) if hasattr(self,'exp_head') else None
        if exp is not None:exp = exp.masked_fill_(seq_mask, 0)
        h = h.masked_fill_(seq_mask, 0)

        te = self.te_head(h)

        # Zero-shot pooled embedding is only built in no-grad (inference) mode.
        if not h.requires_grad:
            # Sum over hidden dims of non-padding positions...
            sum_h = torch.sum(h * ~seq_mask, dim=(1, 2))
            # ...divided by the number of non-padding positions.
            count_h = torch.sum(~seq_mask, dim=(1, 2))
            mean_h = sum_h / count_h
            # Replace NaNs from fully-padded sequences (0/0) with 0.
            mean_h[count_h == 0] = 0
            zero_shot = mean_h.reshape(-1, 1)
        else:
            zero_shot = None

        # Dict-style assignment keeps ModelOutput's dict and attrs in sync.
        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__('aux_loss', aux_loss)
        self.OUT.__setitem__('past_key_values', past_kvs)
        self.OUT.__setitem__('exp', exp)
        self.OUT.__setitem__('te', te)
        self.OUT.__setitem__('zero_shot', zero_shot)
        self.OUT.__setitem__('embeddings', h)

        return self.OUT
|
|
|
|
|
|
|
|
class MiniMindLMForExp44(MiniMindLM):
    """Variant of MiniMindLMForExp with a 4+N split transformer stack.

    The first four layers run on token embeddings alone; expression and
    environment features are injected *between* layers 4 and 5 instead of
    before the stack. Also supports overriding the input embedding and 2D
    tokens directly (``input_embedding`` / ``input_twod_tokens``), e.g. for
    embedding-space optimization.
    """

    def __init__(self, params: LMConfig = None, env_counts: int = 2471):
        super().__init__(params)

        # 3-channel expression features -> model dim; zero first channel marks padding.
        self.exp_adapter = ExpAdapter(3, params.dim,padding_idx=0)
        self.exp_dropout = nn.Dropout(params.dropout)

        # Environment/condition embedding; id 1 is the padding index.
        self.env_adapter = nn.Embedding(env_counts, params.dim,padding_idx=1)

        # Prediction heads over the final hidden states.
        self.exp_head = NonLinearHead(params.dim, 3,'relu', hidden=params.dim//2)
        self.te_head = ConvNetCodon(params.dim,params.dim//2,1)

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                src_exp_data: Optional[torch.Tensor] = None,
                env_ids: Optional[torch.Tensor] = None,
                twod_tokens: Optional[torch.Tensor] = None,
                src_feature: Optional[torch.Tensor] = None,
                input_embedding=None,
                input_twod_tokens = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                **args):
        """Forward pass: layers[:4] on embeddings, inject features, layers[4:].

        Args:
            input_ids: (batch, seq) token ids; id 1 is treated as padding.
            src_exp_data: optional per-position expression features injected
                mid-stack.
            env_ids: optional environment ids, broadcast over the sequence.
            twod_tokens: secondary-structure tokens passed to the layers.
            src_feature: accepted but unused in this forward.
            input_embedding: if given, replaces the token embeddings entirely.
            input_twod_tokens: if given, replaces ``twod_tokens``.
            past_key_values / use_cache: incremental-decoding cache plumbing.

        Returns:
            ``self.OUT`` with logits, aux_loss, past_key_values, exp, te,
            zero_shot, embeddings, plus the (possibly overridden) inputs.
        """
        past_key_values = past_key_values or [None] * len(self.layers)

        input_ids = input_ids.to(torch.long)
        start_pos = args.get('start_pos', 0)
        twod_tokens = twod_tokens.to(torch.float32)
        embedding = self.tok_embeddings(input_ids).to(torch.float32)

        # Caller-supplied embedding overrides the token lookup.
        if input_embedding is not None:
            embedding = input_embedding

        if input_twod_tokens is not None:
            # Tiny epsilon presumably keeps the override distinguishable /
            # numerically safe downstream — TODO confirm intent.
            twod_tokens = input_twod_tokens + 1e-7
        h = self.dropout(embedding)
        # Token id 1 marks padding — TODO confirm against the tokenizer.
        seq_mask = input_ids == 1
        seq_mask.unsqueeze_(-1)
        h = h.masked_fill(seq_mask, 0)

        # Rotary position frequencies for the current window.
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        past_kvs = []
        # First four layers: tokens + 2D structure only.
        for l, layer in enumerate(self.layers[:4]):
            h, past_kv = layer(
                h, pos_cis,
                twod_tokens=twod_tokens,
                past_key_value=past_key_values[l],
                use_cache=use_cache
            )
            h = h.masked_fill(seq_mask, 0)
            past_kvs.append(past_kv)
        # Mid-stack injection of expression features.
        if src_exp_data is not None:
            src_exp_data = src_exp_data.to(torch.float32)
            src_exp_data = self.exp_dropout(self.exp_adapter(src_exp_data))
            src_exp_data = src_exp_data.masked_fill(seq_mask, 0)
            h+=src_exp_data
            h = h.masked_fill(seq_mask, 0)

        # Mid-stack injection of the environment embedding (broadcast over seq).
        if env_ids is not None:
            env_ids = env_ids.to(torch.long)
            env_ids = self.env_adapter(env_ids.unsqueeze(-1))
            h+=env_ids
            h = h.masked_fill(seq_mask, 0)

        # Remaining layers; cache index is offset by the 4 layers already run.
        for l, layer in enumerate(self.layers[4:]):
            h, past_kv = layer(
                h, pos_cis,
                twod_tokens=twod_tokens,
                past_key_value=past_key_values[l + 4],
                use_cache=use_cache
            )
            h = h.masked_fill(seq_mask, 0)
            past_kvs.append(past_kv)
        h = self.norm(h)

        h = h.masked_fill(seq_mask, 0)
        logits = self.output(h)
        # MoE auxiliary (load-balancing) loss summed over MoE layers only.
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))

        exp = self.exp_head(h) if hasattr(self,'exp_head') else None
        if exp is not None:exp = exp.masked_fill_(seq_mask, 0)
        h = h.masked_fill(seq_mask, 0)

        te = self.te_head(h)

        # Zero-shot pooled embedding is only built in no-grad (inference) mode.
        if not h.requires_grad:
            # Sum over hidden dims of non-padding positions...
            sum_h = torch.sum(h * ~seq_mask, dim=(1, 2))
            # ...divided by the number of non-padding positions.
            count_h = torch.sum(~seq_mask, dim=(1, 2))
            mean_h = sum_h / count_h
            # Replace NaNs from fully-padded sequences (0/0) with 0.
            mean_h[count_h == 0] = 0
            zero_shot = mean_h.reshape(-1, 1)
        else:
            zero_shot = None

        # Dict-style assignment keeps ModelOutput's dict and attrs in sync.
        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__('aux_loss', aux_loss)
        self.OUT.__setitem__('past_key_values', past_kvs)
        self.OUT.__setitem__('exp', exp)
        self.OUT.__setitem__('te', te)
        self.OUT.__setitem__('zero_shot', zero_shot)
        self.OUT.__setitem__('embeddings', h)
        self.OUT.__setitem__('input_embedding', embedding)
        self.OUT.__setitem__('input_twod_tokens', twod_tokens)

        return self.OUT
|
|
|
|
|
class MiniMindLMForExp44_4region(MiniMindLM):
    """MiniMindLMForExp44 variant with a positional gap at ``break_position``.

    Position embeddings are taken from two disjoint ranges of ``pos_cis``
    separated by a fixed ``gap`` (3000), so tokens after ``break_position``
    are positioned as if a long region had been skipped. Only layer 0 sees
    the 2D-structure tokens; all later layers run with ``twod_tokens=None``.
    """

    def __init__(self, params: LMConfig = None, env_counts: int = 2471, break_position: int = 604):
        super().__init__(params)

        # 3-channel expression features -> model dim; zero first channel marks padding.
        self.exp_adapter = ExpAdapter(3, params.dim,padding_idx=0)
        self.exp_dropout = nn.Dropout(params.dropout)

        # Environment/condition embedding; id 1 is the padding index.
        self.env_adapter = nn.Embedding(env_counts, params.dim,padding_idx=1)

        # Prediction heads over the final hidden states.
        self.exp_head = NonLinearHead(params.dim, 3,'relu', hidden=params.dim//2)
        self.te_head = ConvNetCodon(params.dim,params.dim//2,1)
        # Sequence index where the positional gap is inserted.
        self.break_position = break_position

        self.OUT = MyModelOutput()

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                src_exp_data: Optional[torch.Tensor] = None,
                env_ids: Optional[torch.Tensor] = None,
                twod_tokens: Optional[torch.Tensor] = None,
                src_feature: Optional[torch.Tensor] = None,
                input_embedding=None,
                input_twod_tokens = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                **args):
        """Forward pass with gapped positions; see class docstring.

        Args mirror MiniMindLMForExp44; ``src_feature`` is accepted but
        unused. Returns ``self.OUT`` with logits, aux_loss, past_key_values,
        exp, te, zero_shot, embeddings and the (possibly overridden) inputs.
        """
        past_key_values = past_key_values or [None] * len(self.layers)

        input_ids = input_ids.to(torch.long)
        start_pos = args.get('start_pos', 0)
        twod_tokens = twod_tokens.to(torch.float32)
        embedding = self.tok_embeddings(input_ids).to(torch.float32)

        # Caller-supplied embedding overrides the token lookup.
        if input_embedding is not None:
            embedding = input_embedding

        if input_twod_tokens is not None:
            # Tiny epsilon presumably keeps the override numerically safe — TODO confirm.
            twod_tokens = input_twod_tokens + 1e-7
        h = self.dropout(embedding)
        # Token id 1 marks padding — TODO confirm against the tokenizer.
        seq_mask = input_ids == 1
        seq_mask.unsqueeze_(-1)
        h = h.masked_fill(seq_mask, 0)

        # Width of the positional gap inserted at break_position.
        gap = 3000

        # Short inputs are right-aligned so they end exactly at break_position.
        if input_ids.size(1)<self.break_position:
            start_pos = self.break_position-input_ids.size(1)

        # Two disjoint position ranges: [start_pos, break) and, after skipping
        # `gap` positions, enough to cover the rest of the input.
        pos_cis = torch.concat([self.pos_cis[start_pos:self.break_position],self.pos_cis[self.break_position+gap:gap+input_ids.size(1)]],dim=0)

        assert input_ids.size(1) <= pos_cis.size(0)+gap,(
            f"Sequence length mismatch: input_ids has length {input_ids.size(1)},gap is {gap} "
            f"but pos_cis has length {self.pos_cis.size(0)}. "
            f"Please ensure that max_seq_len is set to at least {input_ids.size(1)+gap}."
        )
        past_kvs = []

        # Layer 0 is the only layer that receives the 2D-structure tokens.
        l = 0
        layer = self.layers[l]
        h, past_kv = layer(
            h, pos_cis,
            twod_tokens=twod_tokens,
            past_key_value=past_key_values[l],
            use_cache=use_cache
        )
        h = h.masked_fill(seq_mask, 0)
        past_kvs.append(past_kv)

        # Layers 1-3 run without 2D tokens; cache index offset by 1.
        for l, layer in enumerate(self.layers[1:4]):
            h, past_kv = layer(
                h, pos_cis,
                twod_tokens=None,
                past_key_value=past_key_values[l+1],
                use_cache=use_cache
            )
            h = h.masked_fill(seq_mask, 0)
            past_kvs.append(past_kv)
        # Mid-stack injection of expression features.
        if src_exp_data is not None:
            src_exp_data = src_exp_data.to(torch.float32)
            src_exp_data = self.exp_dropout(self.exp_adapter(src_exp_data))
            src_exp_data = src_exp_data.masked_fill(seq_mask, 0)
            h+=src_exp_data
            h = h.masked_fill(seq_mask, 0)

        # Mid-stack injection of the environment embedding (broadcast over seq).
        if env_ids is not None:
            env_ids = env_ids.to(torch.long)
            env_ids = self.env_adapter(env_ids.unsqueeze(-1))
            h+=env_ids
            h = h.masked_fill(seq_mask, 0)

        # Remaining layers; cache index offset by the 4 layers already run.
        for l, layer in enumerate(self.layers[4:]):
            h, past_kv = layer(
                h, pos_cis,
                twod_tokens=None,
                past_key_value=past_key_values[l + 4],
                use_cache=use_cache
            )
            h = h.masked_fill(seq_mask, 0)
            past_kvs.append(past_kv)
        h = self.norm(h)

        h = h.masked_fill(seq_mask, 0)
        logits = self.output(h)

        # MoE auxiliary (load-balancing) loss summed over MoE layers only.
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))

        exp = self.exp_head(h) if hasattr(self,'exp_head') else None
        if exp is not None:exp = exp.masked_fill_(seq_mask, 0)
        h = h.masked_fill(seq_mask, 0)

        te = self.te_head(h)

        # Zero-shot pooled embedding is only built in no-grad (inference) mode.
        if not h.requires_grad:
            # Sum over hidden dims of non-padding positions...
            sum_h = torch.sum(h * ~seq_mask, dim=(1, 2))
            # ...divided by the number of non-padding positions.
            count_h = torch.sum(~seq_mask, dim=(1, 2))
            mean_h = sum_h / count_h
            # Replace NaNs from fully-padded sequences (0/0) with 0.
            mean_h[count_h == 0] = 0
            zero_shot = mean_h.reshape(-1, 1)
        else:
            zero_shot = None

        # Dict-style assignment keeps ModelOutput's dict and attrs in sync.
        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__('aux_loss', aux_loss)
        self.OUT.__setitem__('past_key_values', past_kvs)
        self.OUT.__setitem__('exp', exp)
        self.OUT.__setitem__('te', te)
        self.OUT.__setitem__('zero_shot', zero_shot)
        self.OUT.__setitem__('embeddings', h)
        self.OUT.__setitem__('input_embedding', embedding)
        self.OUT.__setitem__('input_twod_tokens', twod_tokens)

        return self.OUT
|
|
|
|
|
|
|
|
class ConditionalLoss(nn.Module):
    """Group-wise loss: token cross-entropy plus optional TE regression MSE.

    Losses are computed separately for every species group (distinct value in
    ``species_idx``) and every truncation group (distinct value in
    ``truncated_idx``). Each group contributes
    ``nn_loss + te_loss_weight * te_loss`` and the reported ``total_loss`` is
    the mean over all groups.

    Args:
        use_te_loss: include the TE-loss dictionaries in the returned result.
        te_loss_weight: multiplier applied to each group's TE (MSE) loss.
    """

    def __init__(self, use_te_loss=True, te_loss_weight=1.0):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()
        self.loss_mse = nn.MSELoss()
        # BUGFIX: use_te_loss was read in forward() but never stored, so the
        # first call raised AttributeError when building the result dict.
        self.use_te_loss = use_te_loss
        self.te_loss_weight = te_loss_weight

    def forward(self, pred_nn: torch.Tensor = None, pred_te: torch.Tensor = None,
                targets_nn: torch.Tensor = None, targets_te: torch.Tensor = None,
                species_idx: torch.Tensor = None, truncated_idx: torch.Tensor = None,
                seq_mask: torch.Tensor = None):
        """Compute per-group and averaged losses.

        Args:
            pred_nn: (batch, length, vocab) token logits.
            pred_te: TE predictions, flattened per sample before the MSE.
            targets_nn: (batch, length) integer token targets.
            targets_te: optional TE targets; ``None`` disables all TE terms.
            species_idx: (batch,) species label per sample.
            truncated_idx: (batch,) truncation label per sample.
            seq_mask: (batch, length) boolean mask; True keeps a position in
                the species cross-entropy.

        Returns:
            dict with 'total_loss' (mean over groups) plus per-group
            dictionaries; TE dictionaries are included when use_te_loss.
        """
        total_loss = 0
        batch_size, length, vocab = pred_nn.size()

        # ---- per-species losses -------------------------------------------
        unique_species = torch.unique(species_idx)
        species_losses = {}
        species_nn_losses = {}
        species_te_losses = {}

        for species in unique_species:
            species_mask = species_idx == species
            if species_mask.sum() > 0:
                # BUGFIX: the original called torch.masked_select with a
                # (N*length,) mask on a (N*length, vocab) logits tensor (and
                # an (N, length) target tensor), which is not broadcastable
                # and raised at runtime. Boolean row indexing selects the
                # masked positions as intended.
                flat_mask = seq_mask[species_mask].reshape(-1).bool()
                species_nn_loss = self.loss_fn(
                    pred_nn[species_mask].reshape(-1, vocab)[flat_mask],
                    targets_nn[species_mask].reshape(-1)[flat_mask]
                )
                species_nn_losses[f'species_{species.item()}'] = species_nn_loss

                species_te_loss = 0
                if targets_te is not None:
                    species_te_loss = self.loss_mse(pred_te[species_mask].reshape(-1), targets_te[species_mask])
                    species_te_losses[f'species_{species.item()}'] = species_te_loss

                species_loss = species_nn_loss + self.te_loss_weight * species_te_loss
                species_losses[f'species_{species.item()}'] = species_loss
                total_loss += species_loss

        # ---- per-truncation losses ----------------------------------------
        unique_trunc = torch.unique(truncated_idx)
        trunc_losses = {}
        trunc_nn_losses = {}
        trunc_te_losses = {}

        for trunc_pos in unique_trunc:
            trunc_mask = (truncated_idx == trunc_pos)
            if trunc_mask.sum() > 0:
                # NOTE(review): the truncation cross-entropy is unmasked (all
                # positions counted) in the original code — kept as-is.
                trunc_nn_loss = self.loss_fn(pred_nn[trunc_mask].reshape(-1, vocab), targets_nn[trunc_mask].reshape(-1))
                trunc_nn_losses[f'trunc_{trunc_pos.item()}'] = trunc_nn_loss

                trunc_te_loss = 0
                # BUGFIX: `if targets_te:` raises "Boolean value of Tensor
                # ... is ambiguous" for multi-element tensors; test for None
                # explicitly, matching the species branch above.
                if targets_te is not None:
                    trunc_te_loss = self.loss_mse(pred_te[trunc_mask].reshape(-1), targets_te[trunc_mask])
                    trunc_te_losses[f'trunc_{trunc_pos.item()}'] = trunc_te_loss

                trunc_loss = trunc_nn_loss + self.te_loss_weight * trunc_te_loss
                trunc_losses[f'trunc_{trunc_pos.item()}'] = trunc_loss
                total_loss += trunc_loss

        # Average over all species + truncation groups.
        num_groups = len(unique_species) + len(unique_trunc)
        avg_loss = total_loss / num_groups if num_groups > 0 else torch.tensor(0.0)

        result = {
            'total_loss': avg_loss,
            'species_losses': species_losses,
            'trunc_losses': trunc_losses,
            'species_nn_losses': species_nn_losses,
            'trunc_nn_losses': trunc_nn_losses,
        }

        if self.use_te_loss:
            result.update({
                'species_te_losses': species_te_losses,
                'trunc_te_losses': trunc_te_losses,
                'te_loss_weight': self.te_loss_weight
            })

        return result
|
|
class MiniMindLM_Maotao(MiniMindLM):
    """Multi-modal MiniMindLM fusing RNA tokens, amino-acid sequence and
    global sample features.

    Layer 0 processes RNA token embeddings (+2D structure); the hidden states
    are then split into the three codon frames, the third frame is enriched
    with amino-acid cross-attention and fused global features, and the
    remaining layers run on the recombined sequence.
    """

    def __init__(self, params: LMConfig = None):
        super().__init__(params)

        head_dim = params.dim // params.n_heads

        # Adapters for the auxiliary modalities.
        self.aa_embedding_adapter = VocabEmbedding(params.aa_vocab_size, params.dim,padding_idx=0)
        self.species_feature_adapter = VocabEmbedding(params.species_size, params.dim)
        self.truncated_feature_adapter = VocabEmbedding(params.truncated_size, params.dim)
        self.continuous_features = nn.Linear(params.continuous_features_dim, params.dim)

        # NOTE(review): num_heads is set to head_dim (= dim // n_heads), not
        # n_heads — it divides dim so it runs, but confirm it is intentional.
        self.aa_attention = nn.MultiheadAttention(params.dim, num_heads= head_dim)
        # Fuses [species | truncated | continuous] (3*dim) down to dim.
        self.feature_fusion = nn.Linear(params.dim*3, params.dim)

        # TE regression head over the final hidden states.
        self.te_head = ConvNetCodon(params.dim,params.dim//2,1)

        # Re-initializes the token embedding table randomly, discarding any
        # pretrained weights — TODO confirm this is deliberate.
        self.tok_embeddings.weight = nn.Parameter(torch.randn_like(self.tok_embeddings.weight))

        self.OUT = MyModelOutput()

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                twod_tokens: Optional[torch.Tensor] = None,
                aa_idx: Optional[torch.Tensor] = None,
                continuous_features: Optional[torch.Tensor] = None,
                species_features: Optional[torch.Tensor] = None,
                truncated_features: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                input_embedding=None,
                input_twod_tokens=None,
                **args):
        """Multi-modal forward pass; see class docstring.

        Args:
            input_ids: (batch, seq) RNA token ids; id 1 is treated as padding.
            twod_tokens: secondary-structure tokens (layer 0 only).
            aa_idx: amino-acid ids; required.
            continuous_features / species_features / truncated_features:
                global per-sample features fused into codon frame 3.
            input_embedding / input_twod_tokens: optional direct overrides.

        Returns:
            ``self.OUT`` with logits, aux_loss, past_key_values, te,
            zero_shot, embeddings and the (possibly overridden) inputs.
        """
        past_key_values = past_key_values or [None] * len(self.layers)

        # The amino-acid track is mandatory for this model.
        if aa_idx is None:
            raise ValueError("aa_idx must be provided as the primary sequence input.")
        '''input
        aa
        nn
        mer3
        '''

        '''input rna'''
        input_ids = input_ids.to(torch.long)
        start_pos = args.get('start_pos', 0)
        twod_tokens = twod_tokens.to(torch.float32)
        nn_embedding = self.tok_embeddings(input_ids).to(torch.float32)

        # Caller-supplied embedding overrides the token lookup.
        if input_embedding is not None:
            nn_embedding = input_embedding

        if input_twod_tokens is not None:
            # Tiny epsilon presumably keeps the override numerically safe — TODO confirm.
            twod_tokens = input_twod_tokens + 1e-7

        h = self.dropout(nn_embedding)

        # Token id 1 marks padding — TODO confirm against the tokenizer.
        seq_mask = input_ids == 1
        seq_mask.unsqueeze_(-1)

        # NOTE(review): padding is filled with 1 here (0 everywhere else in
        # this file) — confirm the fill value is intentional.
        h = h.masked_fill(seq_mask, 1)

        '''input aa'''
        # Shift aa ids down by 9 (presumably skipping shared special tokens)
        # and clamp into the adapter's vocab — TODO confirm offset.
        aa_idx = torch.clamp(aa_idx-9,min=0)

        aa_embedding = self.aa_embedding_adapter(aa_idx.to(torch.long)).to(torch.float32)
        species_features = self.species_feature_adapter(species_features).to(torch.float32)
        truncated_features = self.truncated_feature_adapter(truncated_features).to(torch.float32)
        continuous_features = self.continuous_features(continuous_features).to(torch.float32)
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        past_kvs = []

        # Layer 0 is the only layer that receives the 2D-structure tokens.
        l = 0
        layer = self.layers[l]
        h, past_kv = layer(
            h, pos_cis,
            twod_tokens=twod_tokens,
            past_key_value=past_key_values[l],
            use_cache=use_cache
        )
        h = h.masked_fill(seq_mask, 0)
        past_kvs.append(past_kv)

        '''
        模态融合
        '''
        # (modality fusion) Split hidden states into the three codon frames.
        batch_size, seq_len, hidden_dim = h.shape
        frame_1 = h[:, 0::3, :]
        frame_2 = h[:, 1::3, :]
        frame_3 = h[:, 2::3, :]

        # Cross-attention from frame 3 onto the amino-acid embeddings.
        # NOTE(review): nn.MultiheadAttention defaults to batch_first=False,
        # while these inputs are batch-first — confirm intended layout.
        attended_from_aa, _ = self.aa_attention(
            query=frame_3,
            key=aa_embedding,
            value=aa_embedding
        )

        # Fuse the three global feature vectors and broadcast over frame 3.
        global_features = torch.cat([species_features, truncated_features, continuous_features], dim=1)
        global_features = self.feature_fusion(global_features)
        global_features = global_features.unsqueeze(1).expand(-1, frame_3.size(1), -1)

        # Re-interleave the frames (frame 3 enriched) back to full length.
        new_h_reshaped = torch.stack([frame_1, frame_2, frame_3+ attended_from_aa + global_features], dim=2)
        h = new_h_reshaped.reshape(batch_size, -1, hidden_dim)
        h = h.masked_fill(seq_mask, 0)

        # Layers 1-3 run without 2D tokens; cache index offset by 1.
        for l, layer in enumerate(self.layers[1:4]):
            h, past_kv = layer(
                h, pos_cis,
                twod_tokens=None,
                past_key_value=past_key_values[l+1],
                use_cache=use_cache
            )
            h = h.masked_fill(seq_mask, 0)
            past_kvs.append(past_kv)

        # Remaining layers; cache index offset by the 4 layers already run.
        for l, layer in enumerate(self.layers[4:]):
            h, past_kv = layer(
                h, pos_cis,
                twod_tokens=None,
                past_key_value=past_key_values[l + 4],
                use_cache=use_cache
            )
            h = h.masked_fill(seq_mask, 0)
            past_kvs.append(past_kv)
        h = self.norm(h)

        h = h.masked_fill(seq_mask, 0)
        logits = self.output(h)

        # MoE auxiliary (load-balancing) loss summed over MoE layers only.
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))

        h = h.masked_fill(seq_mask, 0)

        # CAI/TE prediction from the final hidden states.
        cai = self.te_head(h)
        # Zero-shot pooled embedding only in no-grad (inference) mode.
        zero_shot = self.get_zero_shot(seq_mask,h) if not h.requires_grad else None

        # Dict-style assignment keeps ModelOutput's dict and attrs in sync.
        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__('aux_loss', aux_loss)
        self.OUT.__setitem__('past_key_values', past_kvs)
        self.OUT.__setitem__('te', cai)
        self.OUT.__setitem__('zero_shot', zero_shot)
        self.OUT.__setitem__('embeddings', h)
        self.OUT.__setitem__('input_embedding', nn_embedding)
        self.OUT.__setitem__('input_twod_tokens', twod_tokens)

        return self.OUT

    def get_zero_shot(self,seq_mask,h):
        """Pool hidden states into one scalar per sequence.

        Sums h over non-padding positions and the hidden dim, divides by the
        number of non-padding positions, and zeroes fully-padded sequences.
        Returns shape (batch, 1).
        """
        sum_h = torch.sum(h * ~seq_mask, dim=(1, 2))

        count_h = torch.sum(~seq_mask, dim=(1, 2))
        mean_h = sum_h / count_h

        # Replace NaNs from fully-padded sequences (0/0) with 0.
        mean_h[count_h == 0] = 0

        zero_shot = mean_h.reshape(-1, 1)

        return zero_shot
|
|
|
|
|
|
|
|
|