# MultiModal/model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Optional, Tuple
import math
import torch.utils.checkpoint as checkpoint
from components import RMSNorm
from transformer import OptimizedTransformerBlock
from multimodel_fusion import MultiModalFusionModule
from encoders import (
ImprovedVisionTransformer,
ImprovedAudioEncoder,
ImprovedVideoEncoder
)


class MultiModalDenseTransformer(nn.Module):
def __init__(
self,
model_dim: int = 2048,
vocab_size: int = 30000,
n_layers: int = 48,
n_heads: int = 32,
n_kv_heads: Optional[int] = None,
head_dim: Optional[int] = None,
max_seq_len: int = 8192,
dropout: float = 0.0,
attn_dropout: float = 0.0,
        # MoE configuration
use_moe: bool = False,
num_experts: int = 8,
moe_top_k: int = 2,
moe_layers: Optional[List[int]] = None,
        # PEFT configuration
use_adapter: bool = False,
adapter_dim: int = 64,
use_lora: bool = False,
lora_rank: int = 8,
        # Training configuration
use_gradient_checkpointing: bool = False,
use_parallel_residual: bool = False,
        # Positional encoding
rope_scaling_factor: float = 1.0,
rope_scaling_type: str = "yarn",
sliding_window: Optional[int] = None,
        # Normalization
norm_eps: float = 1e-6,
initializer_range: float = 0.02,
ffn_dim_multiplier: Optional[float] = None,
tie_word_embeddings: bool = True,
        # Multimodal configuration
use_multimodal_fusion: bool = True,
fusion_layers: int = 4,
use_contrastive: bool = True,
vision_depth: int = 24,
audio_depth: int = 12,
video_spatial_depth: int = 12,
video_temporal_depth: int = 4
):
super().__init__()
self.model_dim = model_dim
self.vocab_size = vocab_size
self.n_layers = n_layers
self.max_seq_len = max_seq_len
self.use_gradient_checkpointing = use_gradient_checkpointing
self.tie_word_embeddings = tie_word_embeddings
self.use_multimodal_fusion = use_multimodal_fusion
# Token embedding
self.token_embedding = nn.Embedding(vocab_size, model_dim)
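        # Learned embedding that tags each position with its source modality
        # (text / image / audio / video)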
self.modality_embedding = nn.Embedding(4, model_dim)
self.embed_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.vision_encoder = ImprovedVisionTransformer(
embed_dim=model_dim,
depth=vision_depth,
n_heads=n_heads,
dropout=dropout,
use_adapter=use_adapter,
adapter_dim=adapter_dim
)
self.audio_encoder = ImprovedAudioEncoder(
embed_dim=model_dim,
depth=audio_depth,
n_heads=n_heads,
dropout=dropout,
use_adapter=use_adapter,
adapter_dim=adapter_dim
)
self.video_encoder = ImprovedVideoEncoder(
embed_dim=model_dim,
spatial_depth=video_spatial_depth,
temporal_depth=video_temporal_depth,
n_heads=n_heads,
dropout=dropout,
use_adapter=use_adapter,
adapter_dim=adapter_dim
)
        # Multimodal fusion module
if use_multimodal_fusion:
self.fusion_module = MultiModalFusionModule(
dim=model_dim,
num_fusion_layers=fusion_layers,
n_heads=n_heads,
dropout=dropout,
use_contrastive=use_contrastive
)
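        # Default MoE placement: convert only the second half of the layer stack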
if moe_layers is None and use_moe:
moe_layers = list(range(n_layers // 2, n_layers))
elif moe_layers is None:
moe_layers = []
self.layers = nn.ModuleList([
OptimizedTransformerBlock(
dim=model_dim,
n_heads=n_heads,
n_kv_heads=n_kv_heads,
head_dim=head_dim,
dropout=dropout,
attn_dropout=attn_dropout,
use_moe=(use_moe and i in moe_layers),
num_experts=num_experts,
moe_top_k=moe_top_k,
use_adapter=use_adapter,
adapter_dim=adapter_dim,
use_lora=use_lora,
lora_rank=lora_rank,
use_parallel_residual=use_parallel_residual,
norm_eps=norm_eps,
sliding_window=sliding_window,
ffn_dim_multiplier=ffn_dim_multiplier,
layer_idx=i
)
for i in range(n_layers)
])
self.norm = RMSNorm(model_dim, eps=norm_eps)
self.lm_head = nn.Linear(model_dim, vocab_size, bias=False)
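        # Weight tying: share the input embedding matrix with the output projection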
if tie_word_embeddings:
self.lm_head.weight = self.token_embedding.weight
self.initializer_range = initializer_range
self.apply(self._init_weights)
if not tie_word_embeddings:
self._init_lm_head()
self.n_params = sum(p.numel() for p in self.parameters())
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
print(f"\n{'='*80}")
print(f"Improved Model Configuration:")
print(f" Model Dimension: {model_dim}")
print(f" Vocab Size: {vocab_size}")
print(f" Layers: {n_layers}")
print(f" Attention Heads: {n_heads}")
print(f" KV Heads: {n_kv_heads if n_kv_heads else n_heads}")
print(f" Max Sequence Length: {max_seq_len}")
print(f" Multimodal Fusion: {use_multimodal_fusion}")
print(f" Contrastive Learning: {use_contrastive}")
print(f" MoE: {use_moe} (Experts: {num_experts}, Top-K: {moe_top_k})")
print(f" Total Parameters: {self.n_params / 1e9:.2f}B")
print(f" Trainable Parameters: {trainable_params / 1e9:.2f}B")
print(f"{'='*80}\n")
def _init_weights(self, module):
"""权重初始化"""
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
if hasattr(module, 'padding_idx') and module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _init_lm_head(self):
"""初始化LM head"""
std = self.initializer_range / math.sqrt(2 * self.n_layers)
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=std)
def _encode_modality(self, segment: Dict) -> torch.Tensor:
"""编码单个模态"""
seg_type = segment['type']
seg_data = segment['data']
if seg_type == 'image':
return self.vision_encoder(seg_data)
elif seg_type == 'audio':
return self.audio_encoder(seg_data)
elif seg_type == 'video':
return self.video_encoder(seg_data)
elif seg_type == 'text':
return self.token_embedding(seg_data)
        else:
            # Unknown types are assumed to be pre-encoded features
            return seg_data
def forward(
self,
input_data: Dict,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
return_hidden: bool = False,
use_cache: bool = False,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
compute_contrastive: bool = False
) -> Dict:
"""前向传播"""
device = self.token_embedding.weight.device
        # Encode each modality segment
encoded_segments = []
for segment in input_data.get('segments', []):
encoded = self._encode_modality(segment)
            # Add the modality type embedding
modality_id = segment.get('modality_id', 0)
modality_embeds = self.modality_embedding(
torch.tensor([modality_id], device=device)
).expand(encoded.shape[0], encoded.shape[1], -1)
encoded_segments.append({
'type': segment['type'],
'data': encoded + modality_embeds,
'modality_id': modality_id
})
        # Multimodal fusion
contrastive_losses = {}
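        # Fusion is only meaningful with more than one modality; otherwise the
        # segments are simply concatenated below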
if self.use_multimodal_fusion and len(encoded_segments) > 1:
fusion_output = self.fusion_module(
encoded_segments,
compute_contrastive=compute_contrastive
)
x = fusion_output['fused_features']
contrastive_losses = fusion_output.get('contrastive_losses', {})
else:
            # Fallback: simple concatenation along the sequence dimension
all_embeddings = [seg['data'] for seg in encoded_segments]
x = torch.cat(all_embeddings, dim=1) if all_embeddings else torch.zeros(
1, 1, self.model_dim, device=device
)
x = self.embed_dropout(x)
if position_ids is None:
if past_key_values is not None:
                # Cached length (KV cache shape is [batch, heads, seq_len, head_dim])
past_length = past_key_values[0][0].size(2)
                # Length of the current input
seq_length = x.shape[1]
                # Positions continue from the cache: [past_length, past_length + 1, ...]
position_ids = torch.arange(
past_length, past_length + seq_length, dtype=torch.long, device=device
).unsqueeze(0).expand(x.shape[0], -1)
else:
                # No cache: positions start at 0
seq_length = x.shape[1]
position_ids = torch.arange(
0, seq_length, dtype=torch.long, device=device
).unsqueeze(0).expand(x.shape[0], -1)
        # Transformer layers
present_key_values = [] if use_cache else None
all_hidden_states = [] if output_hidden_states else None
all_attentions = [] if output_attentions else None
moe_aux_loss = torch.tensor(0.0, device=device)
for idx, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states.append(x)
past_kv = past_key_values[idx] if past_key_values is not None else None
if self.use_gradient_checkpointing and self.training:
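                # Gradient checkpointing recomputes activations in backward to save
                # memory; KV caching and attention outputs are disabled in this path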
def create_custom_forward(module):
def custom_forward(*inputs):
return module(
inputs[0],
attention_mask=inputs[1],
position_ids=inputs[2],
use_cache=False,
past_kv=None,
output_attentions=False
)
return custom_forward
layer_outputs = checkpoint.checkpoint(
create_custom_forward(layer),
x,
attention_mask,
position_ids,
use_reentrant=False
)
x = layer_outputs[0]
present_kv = None
attn_weights = None
else:
layer_outputs = layer(
x,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=use_cache,
past_kv=past_kv,
output_attentions=output_attentions
)
x, present_kv, attn_weights = layer_outputs
if use_cache:
present_key_values.append(present_kv)
if output_attentions:
all_attentions.append(attn_weights)
if hasattr(layer, 'moe_aux_loss'):
moe_aux_loss += layer.moe_aux_loss
hidden_states = self.norm(x)
logits = self.lm_head(hidden_states)
if output_hidden_states:
all_hidden_states.append(hidden_states)
        # Assemble outputs
outputs = {
'logits': logits,
'moe_aux_loss': moe_aux_loss,
'contrastive_losses': contrastive_losses
}
if use_cache:
outputs['past_key_values'] = present_key_values
if output_hidden_states:
outputs['hidden_states'] = all_hidden_states
if output_attentions:
outputs['attentions'] = all_attentions
if return_hidden:
outputs['last_hidden_state'] = hidden_states
return outputs
@torch.no_grad()
def generate(
self,
input_data: Dict,
max_new_tokens: int = 100,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 0.9,
eos_token_id: int = 2,
pad_token_id: Optional[int] = None,
use_cache: bool = True,
repetition_penalty: float = 1.0,
length_penalty: float = 1.0,
min_length: int = 0,
do_sample: bool = True,
num_beams: int = 1
) -> torch.Tensor:
"""改进的生成方法"""
self.eval()
device = next(self.parameters()).device
if pad_token_id is None:
pad_token_id = eos_token_id
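        # Generation conditions on the first input segment, which is assumed
        # to contain text token ids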
initial_text_tokens = input_data['segments'][0]['data'].to(device)
batch_size = initial_text_tokens.shape[0]
if 'attention_mask' in input_data:
attention_mask = input_data['attention_mask'].to(device)
else:
attention_mask = torch.ones_like(initial_text_tokens)
initial_seq_len = initial_text_tokens.shape[1]
        position_ids = torch.zeros((batch_size, initial_seq_len), dtype=torch.long, device=device)
        for i in range(batch_size):
            non_pad_mask = attention_mask[i].bool()
            if non_pad_mask.any():
                positions = torch.cumsum(non_pad_mask.long(), dim=0) - 1
                position_ids[i] = positions * non_pad_mask.long()
generated_tokens = []
past_key_values = None
current_tokens = initial_text_tokens
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
for step in range(max_new_tokens):
current_input_data = {
'segments': [{'type': 'text', 'data': current_tokens, 'modality_id': 0}]
}
            if step > 0:
                # Append a mask entry (1) for the newly generated token
                new_mask = torch.ones(batch_size, 1, dtype=torch.long, device=device)
                attention_mask = torch.cat([attention_mask, new_mask], dim=1)
                if use_cache:
                    # Only the new token is fed; its position is the count of
                    # non-padded tokens so far minus one
                    current_position_ids = (attention_mask.sum(dim=1, keepdim=True) - 1).clamp(min=0)
                else:
                    # The full sequence is re-fed, so rebuild positions over the whole mask
                    current_position_ids = (attention_mask.cumsum(dim=1) - 1).clamp(min=0) * attention_mask
            else:
                current_position_ids = position_ids
outputs = self.forward(
current_input_data,
                attention_mask=attention_mask,
                position_ids=current_position_ids,
use_cache=use_cache,
past_key_values=past_key_values
)
logits = outputs['logits']
if use_cache:
past_key_values = outputs['past_key_values']
next_token_logits = logits[:, -1, :] / max(temperature, 1e-5)
            # Repetition penalty: dampen logits of previously generated tokens
if repetition_penalty != 1.0 and len(generated_tokens) > 0:
prev_generated = torch.cat(generated_tokens, dim=1)
score = torch.gather(next_token_logits, 1, prev_generated)
score = torch.where(
score < 0,
score * repetition_penalty,
score / repetition_penalty
)
next_token_logits.scatter_(1, prev_generated, score)
# Min length constraint
if step < min_length:
next_token_logits[:, eos_token_id] = float('-inf')
# Sampling
if do_sample:
if top_k > 0:
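                    # Top-k filtering: mask every logit below the k-th largest value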
top_k_vals, _ = torch.topk(next_token_logits, top_k)
min_val_to_keep = top_k_vals[:, -1].unsqueeze(-1)
next_token_logits[next_token_logits < min_val_to_keep] = float('-inf')
if top_p < 1.0:
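                    # Nucleus (top-p) filtering: sort, take the cumulative softmax, and
                    # drop tokens past the top_p mass; the shift below guarantees the
                    # highest-probability token always survives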
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(next_token_logits, dtype=torch.bool)
indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
next_token_logits[indices_to_remove] = float('-inf')
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
# Apply unfinished mask
next_token = next_token * unfinished_sequences[:, None] + pad_token_id * (1 - unfinished_sequences[:, None])
generated_tokens.append(next_token)
if not use_cache:
initial_text_tokens = torch.cat([initial_text_tokens, next_token], dim=1)
current_tokens = initial_text_tokens
else:
current_tokens = next_token
# Update unfinished sequences
unfinished_sequences = unfinished_sequences.mul(
(next_token.squeeze(-1) != eos_token_id).long()
)
if unfinished_sequences.max() == 0:
break
if not generated_tokens:
return torch.empty(batch_size, 0, dtype=torch.long, device=device)
return torch.cat(generated_tokens, dim=1)
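

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the model definition). It
# uses a deliberately tiny configuration with text-only input; whether the
# vision/audio/video encoders in encoders.py accept these sizes is an
# assumption, and the token ids below are random placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = MultiModalDenseTransformer(
        model_dim=256,
        vocab_size=1000,
        n_layers=2,
        n_heads=4,
        vision_depth=2,
        audio_depth=2,
        video_spatial_depth=2,
        video_temporal_depth=1,
    )
    # A batch of two text-only prompts (random token ids as placeholders)
    prompt = torch.randint(0, 1000, (2, 16))
    input_data = {'segments': [{'type': 'text', 'data': prompt, 'modality_id': 0}]}
    out = model(input_data)
    print(out['logits'].shape)  # torch.Size([2, 16, 1000])
    generated = model.generate(input_data, max_new_tokens=8, do_sample=False)
    print(generated.shape)  # (2, <=8) depending on when EOS is emitted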