|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from typing import List, Dict, Optional, Tuple |
|
|
import math |
|
|
from components import RMSNorm |
|
|
from transformer import OptimizedTransformerBlock |
|
|
from multimodel_fusion import MultiModalFusionModule |
|
|
from encoders import ( |
|
|
ImprovedVisionTransformer, |
|
|
ImprovedAudioEncoder, |
|
|
ImprovedVideoEncoder |
|
|
) |
|
|
|
|
|
class MultiModalDenseTransformer(nn.Module):
    """Dense multimodal transformer language model.

    Text / image / audio / video segments are encoded by per-modality
    encoders, tagged with a learned modality embedding, optionally fused by
    a cross-modal fusion module, then run through a stack of
    ``OptimizedTransformerBlock`` layers and projected to vocabulary logits.

    NOTE(review): input segment tensors are assumed to already be shaped as
    each encoder expects (e.g. text as (batch, seq) token ids) — confirm
    against the encoder implementations, which are defined elsewhere.
    """

    def __init__(
        self,
        model_dim: int = 2048,
        vocab_size: int = 30000,
        n_layers: int = 48,
        n_heads: int = 32,
        n_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
        max_seq_len: int = 8192,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        # Mixture-of-Experts options
        use_moe: bool = False,
        num_experts: int = 8,
        moe_top_k: int = 2,
        moe_layers: Optional[List[int]] = None,
        # Parameter-efficient fine-tuning options
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_lora: bool = False,
        lora_rank: int = 8,
        # Training / architecture switches
        use_gradient_checkpointing: bool = False,
        use_parallel_residual: bool = False,
        # RoPE / attention-window options (forwarded to the blocks)
        rope_scaling_factor: float = 1.0,
        rope_scaling_type: str = "yarn",
        sliding_window: Optional[int] = None,
        # Normalization / init / head options
        norm_eps: float = 1e-6,
        initializer_range: float = 0.02,
        ffn_dim_multiplier: Optional[float] = None,
        tie_word_embeddings: bool = True,
        # Multimodal options
        use_multimodal_fusion: bool = True,
        fusion_layers: int = 4,
        use_contrastive: bool = True,
        vision_depth: int = 24,
        audio_depth: int = 12,
        video_spatial_depth: int = 12,
        video_temporal_depth: int = 4
    ):
        super().__init__()

        self.model_dim = model_dim
        self.vocab_size = vocab_size
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.tie_word_embeddings = tie_word_embeddings
        self.use_multimodal_fusion = use_multimodal_fusion

        # Token embedding plus a 4-way modality tag embedding
        # (presumably 0=text, 1=image, 2=audio, 3=video — callers supply
        # 'modality_id' per segment; confirm the convention upstream).
        self.token_embedding = nn.Embedding(vocab_size, model_dim)
        self.modality_embedding = nn.Embedding(4, model_dim)
        self.embed_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # Per-modality encoders all project into the shared model_dim space.
        self.vision_encoder = ImprovedVisionTransformer(
            embed_dim=model_dim,
            depth=vision_depth,
            n_heads=n_heads,
            dropout=dropout,
            use_adapter=use_adapter,
            adapter_dim=adapter_dim
        )

        self.audio_encoder = ImprovedAudioEncoder(
            embed_dim=model_dim,
            depth=audio_depth,
            n_heads=n_heads,
            dropout=dropout,
            use_adapter=use_adapter,
            adapter_dim=adapter_dim
        )

        self.video_encoder = ImprovedVideoEncoder(
            embed_dim=model_dim,
            spatial_depth=video_spatial_depth,
            temporal_depth=video_temporal_depth,
            n_heads=n_heads,
            dropout=dropout,
            use_adapter=use_adapter,
            adapter_dim=adapter_dim
        )

        # Cross-modal fusion module is only created when enabled; forward()
        # guards on self.use_multimodal_fusion before touching it.
        if use_multimodal_fusion:
            self.fusion_module = MultiModalFusionModule(
                dim=model_dim,
                num_fusion_layers=fusion_layers,
                n_heads=n_heads,
                dropout=dropout,
                use_contrastive=use_contrastive
            )

        # Default MoE placement: the upper half of the stack.
        if moe_layers is None and use_moe:
            moe_layers = list(range(n_layers // 2, n_layers))
        elif moe_layers is None:
            moe_layers = []

        self.layers = nn.ModuleList([
            OptimizedTransformerBlock(
                dim=model_dim,
                n_heads=n_heads,
                n_kv_heads=n_kv_heads,
                head_dim=head_dim,
                dropout=dropout,
                attn_dropout=attn_dropout,
                use_moe=(use_moe and i in moe_layers),
                num_experts=num_experts,
                moe_top_k=moe_top_k,
                use_adapter=use_adapter,
                adapter_dim=adapter_dim,
                use_lora=use_lora,
                lora_rank=lora_rank,
                use_parallel_residual=use_parallel_residual,
                norm_eps=norm_eps,
                sliding_window=sliding_window,
                ffn_dim_multiplier=ffn_dim_multiplier,
                layer_idx=i
            )
            for i in range(n_layers)
        ])

        self.norm = RMSNorm(model_dim, eps=norm_eps)
        self.lm_head = nn.Linear(model_dim, vocab_size, bias=False)

        # Weight tying: head shares the token-embedding matrix.
        if tie_word_embeddings:
            self.lm_head.weight = self.token_embedding.weight

        self.initializer_range = initializer_range
        self.apply(self._init_weights)

        # An untied head gets a smaller, depth-scaled init on top of the
        # generic nn.Linear init applied above.
        if not tie_word_embeddings:
            self._init_lm_head()

        self.n_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        print(f"\n{'='*80}")
        print(f"Improved Model Configuration:")
        print(f"  Model Dimension: {model_dim}")
        print(f"  Vocab Size: {vocab_size}")
        print(f"  Layers: {n_layers}")
        print(f"  Attention Heads: {n_heads}")
        print(f"  KV Heads: {n_kv_heads if n_kv_heads else n_heads}")
        print(f"  Max Sequence Length: {max_seq_len}")
        print(f"  Multimodal Fusion: {use_multimodal_fusion}")
        print(f"  Contrastive Learning: {use_contrastive}")
        print(f"  MoE: {use_moe} (Experts: {num_experts}, Top-K: {moe_top_k})")
        print(f"  Total Parameters: {self.n_params / 1e9:.2f}B")
        print(f"  Trainable Parameters: {trainable_params / 1e9:.2f}B")
        print(f"{'='*80}\n")

    def _init_weights(self, module):
        """Initialize Linear and Embedding weights with a truncated-style
        normal (std = initializer_range); zero biases and padding rows."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
            if hasattr(module, 'padding_idx') and module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _init_lm_head(self):
        """Initialize the (untied) LM head with a depth-scaled std,
        following the common 1/sqrt(2*n_layers) residual-scaling scheme."""
        std = self.initializer_range / math.sqrt(2 * self.n_layers)
        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=std)

    def _encode_modality(self, segment: Dict) -> torch.Tensor:
        """Encode a single segment dict ({'type': str, 'data': Tensor})
        with the encoder matching its type.

        Unknown types pass 'data' through unchanged — it is assumed to
        already be an embedding of shape (batch, seq, model_dim).
        """
        seg_type = segment['type']
        seg_data = segment['data']

        if seg_type == 'image':
            return self.vision_encoder(seg_data)
        elif seg_type == 'audio':
            return self.audio_encoder(seg_data)
        elif seg_type == 'video':
            return self.video_encoder(seg_data)
        elif seg_type == 'text':
            return self.token_embedding(seg_data)
        else:
            # Pre-embedded segment: pass through untouched.
            return seg_data

    def forward(
        self,
        input_data: Dict,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        return_hidden: bool = False,
        use_cache: bool = False,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        compute_contrastive: bool = False
    ) -> Dict:
        """Run the full multimodal forward pass.

        Args:
            input_data: dict with key 'segments' — a list of
                {'type', 'data', optional 'modality_id'} dicts.
            attention_mask: optional (batch, seq) mask forwarded to blocks.
            position_ids: optional explicit positions; derived from the KV
                cache length (or 0) when omitted.
            return_hidden / output_*: control which extras appear in the
                returned dict.
            use_cache: collect per-layer present KV tuples.
            past_key_values: per-layer cached KV from a previous step.
            compute_contrastive: forwarded to the fusion module.

        Returns:
            dict with 'logits', 'moe_aux_loss', 'contrastive_losses' and,
            depending on flags, 'past_key_values', 'hidden_states',
            'attentions', 'last_hidden_state'.
        """
        # Hoisted out of the layer loop (was re-imported every iteration).
        import torch.utils.checkpoint as checkpoint

        device = self.token_embedding.weight.device

        # --- encode each segment and add its modality tag embedding ---
        encoded_segments = []
        for segment in input_data.get('segments', []):
            encoded = self._encode_modality(segment)

            modality_id = segment.get('modality_id', 0)
            modality_embeds = self.modality_embedding(
                torch.tensor([modality_id], device=device)
            ).expand(encoded.shape[0], encoded.shape[1], -1)

            encoded_segments.append({
                'type': segment['type'],
                'data': encoded + modality_embeds,
                'modality_id': modality_id
            })

        # --- fuse modalities (or just concatenate along the sequence) ---
        contrastive_losses = {}
        if self.use_multimodal_fusion and len(encoded_segments) > 1:
            fusion_output = self.fusion_module(
                encoded_segments,
                compute_contrastive=compute_contrastive
            )
            x = fusion_output['fused_features']
            contrastive_losses = fusion_output.get('contrastive_losses', {})
        else:
            all_embeddings = [seg['data'] for seg in encoded_segments]
            # Empty input degrades to a single zero embedding so the stack
            # still runs (callers get logits of shape (1, 1, vocab)).
            x = torch.cat(all_embeddings, dim=1) if all_embeddings else torch.zeros(
                1, 1, self.model_dim, device=device
            )

        x = self.embed_dropout(x)

        # --- derive position ids when not supplied ---
        if position_ids is None:
            if past_key_values is not None:
                # Assumes cached keys are (batch, heads, seq, head_dim) so
                # dim 2 is the cached length — TODO confirm against
                # OptimizedTransformerBlock's cache layout.
                past_length = past_key_values[0][0].size(2)
                seq_length = x.shape[1]
                position_ids = torch.arange(
                    past_length, past_length + seq_length, dtype=torch.long, device=device
                ).unsqueeze(0).expand(x.shape[0], -1)
            else:
                seq_length = x.shape[1]
                position_ids = torch.arange(
                    0, seq_length, dtype=torch.long, device=device
                ).unsqueeze(0).expand(x.shape[0], -1)

        present_key_values = [] if use_cache else None
        all_hidden_states = [] if output_hidden_states else None
        all_attentions = [] if output_attentions else None
        moe_aux_loss = torch.tensor(0.0, device=device)

        # --- transformer stack ---
        for idx, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states.append(x)

            past_kv = past_key_values[idx] if past_key_values is not None else None

            if self.use_gradient_checkpointing and self.training:
                # Checkpointed path: caching and attention outputs are
                # disabled (the recomputed activations are discarded), so
                # present_kv / attn_weights are None for these layers.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(
                            inputs[0],
                            attention_mask=inputs[1],
                            position_ids=inputs[2],
                            use_cache=False,
                            past_kv=None,
                            output_attentions=False
                        )
                    return custom_forward

                layer_outputs = checkpoint.checkpoint(
                    create_custom_forward(layer),
                    x,
                    attention_mask,
                    position_ids,
                    use_reentrant=False
                )
                x = layer_outputs[0]
                present_kv = None
                attn_weights = None
            else:
                layer_outputs = layer(
                    x,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    use_cache=use_cache,
                    past_kv=past_kv,
                    output_attentions=output_attentions
                )
                x, present_kv, attn_weights = layer_outputs

            if use_cache:
                present_key_values.append(present_kv)

            if output_attentions:
                all_attentions.append(attn_weights)

            # Accumulate MoE load-balancing loss from MoE-enabled layers.
            if hasattr(layer, 'moe_aux_loss'):
                moe_aux_loss += layer.moe_aux_loss

        hidden_states = self.norm(x)
        logits = self.lm_head(hidden_states)

        if output_hidden_states:
            all_hidden_states.append(hidden_states)

        outputs = {
            'logits': logits,
            'moe_aux_loss': moe_aux_loss,
            'contrastive_losses': contrastive_losses
        }

        if use_cache:
            outputs['past_key_values'] = present_key_values

        if output_hidden_states:
            outputs['hidden_states'] = all_hidden_states

        if output_attentions:
            outputs['attentions'] = all_attentions

        if return_hidden:
            outputs['last_hidden_state'] = hidden_states

        return outputs

    @torch.no_grad()
    def generate(
        self,
        input_data: Dict,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.9,
        eos_token_id: int = 2,
        pad_token_id: Optional[int] = None,
        use_cache: bool = True,
        repetition_penalty: float = 1.0,
        length_penalty: float = 1.0,
        min_length: int = 0,
        do_sample: bool = True,
        num_beams: int = 1
    ) -> torch.Tensor:
        """Autoregressive text generation from a text-only prompt.

        The prompt is taken from input_data['segments'][0]['data'], which
        must be a (batch, seq) LongTensor of token ids. Supports greedy
        and sampled decoding with top-k / top-p filtering and a
        repetition penalty over previously generated tokens.

        NOTE: ``length_penalty`` and ``num_beams`` are accepted for API
        compatibility but are currently unused (no beam search).

        Returns:
            (batch, n_generated) LongTensor of newly generated tokens
            (the prompt is not included).
        """
        self.eval()
        device = next(self.parameters()).device

        if pad_token_id is None:
            pad_token_id = eos_token_id

        initial_text_tokens = input_data['segments'][0]['data'].to(device)
        batch_size = initial_text_tokens.shape[0]

        if 'attention_mask' in input_data:
            attention_mask = input_data['attention_mask'].to(device)
        else:
            attention_mask = torch.ones_like(initial_text_tokens)

        # Positions count only non-pad tokens (left-padding safe).
        # Pad positions come out as 0: cumsum-1 is -1 at leading pads, and
        # multiplying by the mask zeroes every pad slot.
        mask_long = attention_mask.long()
        position_ids = (torch.cumsum(mask_long, dim=1) - 1) * mask_long

        generated_tokens = []
        past_key_values = None
        current_tokens = initial_text_tokens
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)

        for step in range(max_new_tokens):
            current_input_data = {
                'segments': [{'type': 'text', 'data': current_tokens, 'modality_id': 0}]
            }

            # Extend mask/positions for the token appended last step.
            # BUGFIX: the no-cache path previously kept the prompt-length
            # attention_mask and position_ids while re-feeding the grown
            # sequence, causing a length mismatch; both are now extended
            # every step regardless of caching.
            if step > 0:
                new_mask = torch.ones(batch_size, 1, dtype=torch.long, device=device)
                attention_mask = torch.cat([attention_mask, new_mask], dim=1)
                next_positions = (attention_mask.sum(dim=1, keepdim=True) - 1).clamp(min=0)
                if use_cache:
                    # Cached decoding feeds only the newest token.
                    current_position_ids = next_positions
                else:
                    # Full-sequence decoding needs the whole position row.
                    position_ids = torch.cat([position_ids, next_positions], dim=1)
                    current_position_ids = position_ids
            else:
                current_position_ids = position_ids

            outputs = self.forward(
                current_input_data,
                attention_mask=attention_mask,
                position_ids=current_position_ids,
                use_cache=use_cache,
                past_key_values=past_key_values
            )

            logits = outputs['logits']
            if use_cache:
                past_key_values = outputs['past_key_values']

            # Temperature floor avoids division by zero at temperature=0.
            next_token_logits = logits[:, -1, :] / max(temperature, 1e-5)

            # Repetition penalty over tokens generated so far (CTRL-style:
            # divide positive scores, multiply negative ones).
            if repetition_penalty != 1.0 and len(generated_tokens) > 0:
                prev_generated = torch.cat(generated_tokens, dim=1)
                score = torch.gather(next_token_logits, 1, prev_generated)
                score = torch.where(
                    score < 0,
                    score * repetition_penalty,
                    score / repetition_penalty
                )
                next_token_logits.scatter_(1, prev_generated, score)

            # Forbid EOS until min_length new tokens are produced.
            if step < min_length:
                next_token_logits[:, eos_token_id] = float('-inf')

            if do_sample:
                # Top-k filtering.
                if top_k > 0:
                    top_k_vals, _ = torch.topk(next_token_logits, top_k)
                    min_val_to_keep = top_k_vals[:, -1].unsqueeze(-1)
                    next_token_logits[next_token_logits < min_val_to_keep] = float('-inf')

                # Nucleus (top-p) filtering; the shift keeps at least the
                # single most-probable token.
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = torch.zeros_like(next_token_logits, dtype=torch.bool)
                    indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits[indices_to_remove] = float('-inf')

                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            # Finished sequences keep emitting pad_token_id.
            next_token = next_token * unfinished_sequences[:, None] + pad_token_id * (1 - unfinished_sequences[:, None])

            generated_tokens.append(next_token)

            if not use_cache:
                # Without a KV cache the whole sequence is re-fed.
                initial_text_tokens = torch.cat([initial_text_tokens, next_token], dim=1)
                current_tokens = initial_text_tokens
            else:
                current_tokens = next_token

            # Mark sequences that just emitted EOS as finished.
            unfinished_sequences = unfinished_sequences.mul(
                (next_token.squeeze(-1) != eos_token_id).long()
            )

            if unfinished_sequences.max() == 0:
                break

        if not generated_tokens:
            return torch.empty(batch_size, 0, dtype=torch.long, device=device)

        return torch.cat(generated_tokens, dim=1)