import math
from typing import List, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from components import RMSNorm
from transformer import OptimizedTransformerBlock
from multimodel_fusion import MultiModalFusionModule
from encoders import (
    ImprovedVisionTransformer,
    ImprovedAudioEncoder,
    ImprovedVideoEncoder
)


class MultiModalDenseTransformer(nn.Module):
    def __init__(
        self,
        model_dim: int = 2048,
        vocab_size: int = 30000,
        n_layers: int = 48,
        n_heads: int = 32,
        n_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
        max_seq_len: int = 8192,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        # MoE configuration
        use_moe: bool = False,
        num_experts: int = 8,
        moe_top_k: int = 2,
        moe_layers: Optional[List[int]] = None,
        # PEFT configuration
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_lora: bool = False,
        lora_rank: int = 8,
        # Training configuration
        use_gradient_checkpointing: bool = False,
        use_parallel_residual: bool = False,
        # Positional encoding
        rope_scaling_factor: float = 1.0,
        rope_scaling_type: str = "yarn",
        sliding_window: Optional[int] = None,
        # Normalization
        norm_eps: float = 1e-6,
        initializer_range: float = 0.02,
        ffn_dim_multiplier: Optional[float] = None,
        tie_word_embeddings: bool = True,
        # Multimodal configuration
        use_multimodal_fusion: bool = True,
        fusion_layers: int = 4,
        use_contrastive: bool = True,
        vision_depth: int = 24,
        audio_depth: int = 12,
        video_spatial_depth: int = 12,
        video_temporal_depth: int = 4
    ):
        super().__init__()

        self.model_dim = model_dim
        self.vocab_size = vocab_size
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.tie_word_embeddings = tie_word_embeddings
        self.use_multimodal_fusion = use_multimodal_fusion

        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, model_dim)
        # One learned embedding per modality (text / image / audio / video)
        self.modality_embedding = nn.Embedding(4, model_dim)
        self.embed_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        self.vision_encoder = ImprovedVisionTransformer(
            embed_dim=model_dim,
            depth=vision_depth,
            n_heads=n_heads,
            dropout=dropout,
            use_adapter=use_adapter,
            adapter_dim=adapter_dim
        )
        self.audio_encoder = ImprovedAudioEncoder(
            embed_dim=model_dim,
            depth=audio_depth,
            n_heads=n_heads,
            dropout=dropout,
            use_adapter=use_adapter,
            adapter_dim=adapter_dim
        )
        self.video_encoder = ImprovedVideoEncoder(
            embed_dim=model_dim,
            spatial_depth=video_spatial_depth,
            temporal_depth=video_temporal_depth,
            n_heads=n_heads,
            dropout=dropout,
            use_adapter=use_adapter,
            adapter_dim=adapter_dim
        )

        # Multimodal fusion module
        if use_multimodal_fusion:
            self.fusion_module = MultiModalFusionModule(
                dim=model_dim,
                num_fusion_layers=fusion_layers,
                n_heads=n_heads,
                dropout=dropout,
                use_contrastive=use_contrastive
            )

        # By default, place MoE layers in the second half of the stack
        if moe_layers is None and use_moe:
            moe_layers = list(range(n_layers // 2, n_layers))
        elif moe_layers is None:
            moe_layers = []

        self.layers = nn.ModuleList([
            OptimizedTransformerBlock(
                dim=model_dim,
                n_heads=n_heads,
                n_kv_heads=n_kv_heads,
                head_dim=head_dim,
                dropout=dropout,
                attn_dropout=attn_dropout,
                use_moe=(use_moe and i in moe_layers),
                num_experts=num_experts,
                moe_top_k=moe_top_k,
                use_adapter=use_adapter,
                adapter_dim=adapter_dim,
                use_lora=use_lora,
                lora_rank=lora_rank,
                use_parallel_residual=use_parallel_residual,
                norm_eps=norm_eps,
                sliding_window=sliding_window,
                ffn_dim_multiplier=ffn_dim_multiplier,
                layer_idx=i
            )
            for i in range(n_layers)
        ])

        self.norm = RMSNorm(model_dim, eps=norm_eps)
        self.lm_head = nn.Linear(model_dim, vocab_size, bias=False)
        if tie_word_embeddings:
            self.lm_head.weight = self.token_embedding.weight

        self.initializer_range = initializer_range
        self.apply(self._init_weights)
        if not tie_word_embeddings:
            self._init_lm_head()

        self.n_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        print(f"\n{'='*80}")
        print("Improved Model Configuration:")
        print(f"  Model Dimension: {model_dim}")
        print(f"  Vocab Size: {vocab_size}")
        print(f"  Layers: {n_layers}")
        print(f"  Attention Heads: {n_heads}")
        print(f"  KV Heads: {n_kv_heads if n_kv_heads else n_heads}")
        print(f"  Max Sequence Length: {max_seq_len}")
        print(f"  Multimodal Fusion: {use_multimodal_fusion}")
        print(f"  Contrastive Learning: {use_contrastive}")
        print(f"  MoE: {use_moe} (Experts: {num_experts}, Top-K: {moe_top_k})")
        print(f"  Total Parameters: {self.n_params / 1e9:.2f}B")
        print(f"  Trainable Parameters: {trainable_params / 1e9:.2f}B")
        print(f"{'='*80}\n")

    def _init_weights(self, module):
        """Weight initialization."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
            if hasattr(module, 'padding_idx') and module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _init_lm_head(self):
        """Initialize the LM head, scaling the std down with depth."""
        std = self.initializer_range / math.sqrt(2 * self.n_layers)
        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=std)

    def _encode_modality(self, segment: Dict) -> torch.Tensor:
        """Encode a single modality segment."""
        seg_type = segment['type']
        seg_data = segment['data']

        if seg_type == 'image':
            return self.vision_encoder(seg_data)
        elif seg_type == 'audio':
            return self.audio_encoder(seg_data)
        elif seg_type == 'video':
            return self.video_encoder(seg_data)
        elif seg_type == 'text':
            return self.token_embedding(seg_data)
        else:
            # Pre-encoded features pass through unchanged
            return seg_data

    def forward(
        self,
        input_data: Dict,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        return_hidden: bool = False,
        use_cache: bool = False,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        compute_contrastive: bool = False
    ) -> Dict:
        """Forward pass."""
        device = self.token_embedding.weight.device

        # Encode each modality
        encoded_segments = []
        for segment in input_data.get('segments', []):
            encoded = self._encode_modality(segment)

            # Add the modality embedding
            modality_id = segment.get('modality_id', 0)
            modality_embeds = self.modality_embedding(
                torch.tensor([modality_id], device=device)
            ).expand(encoded.shape[0], encoded.shape[1], -1)

            encoded_segments.append({
                'type': segment['type'],
                'data': encoded + modality_embeds,
                'modality_id': modality_id
            })

        # Multimodal fusion
        contrastive_losses = {}
        if self.use_multimodal_fusion and len(encoded_segments) > 1:
            fusion_output = self.fusion_module(
                encoded_segments,
                compute_contrastive=compute_contrastive
            )
            x = fusion_output['fused_features']
            contrastive_losses = fusion_output.get('contrastive_losses', {})
        else:
            # Simple concatenation along the sequence dimension
            all_embeddings = [seg['data'] for seg in encoded_segments]
            x = torch.cat(all_embeddings, dim=1) if all_embeddings else torch.zeros(
                1, 1, self.model_dim, device=device
            )

        x = self.embed_dropout(x)

        if position_ids is None:
            if past_key_values is not None:
                # Cached length (KV cache shape is [B, n_heads, seq_len, head_dim])
                past_length = past_key_values[0][0].size(2)
                # Length of the current input
                seq_length = x.shape[1]
                # Build position indices continuing from the cache:
                # [past_length, past_length + 1, ...]
                position_ids = torch.arange(
                    past_length, past_length + seq_length,
                    dtype=torch.long, device=device
                ).unsqueeze(0).expand(x.shape[0], -1)
            else:
                # Without a cache, positions start from 0
                seq_length = x.shape[1]
                position_ids = torch.arange(
                    0, seq_length, dtype=torch.long, device=device
                ).unsqueeze(0).expand(x.shape[0], -1)

        # Transformer layers
        present_key_values = [] if use_cache else None
        all_hidden_states = [] if output_hidden_states else None
        all_attentions = [] if output_attentions else None
        moe_aux_loss = torch.tensor(0.0, device=device)

        for idx, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states.append(x)

            past_kv = past_key_values[idx] if past_key_values is not None else None

            if self.use_gradient_checkpointing and self.training:
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(
                            inputs[0],
                            attention_mask=inputs[1],
                            position_ids=inputs[2],
                            use_cache=False,
                            past_kv=None,
                            output_attentions=False
                        )
                    return custom_forward

                layer_outputs = checkpoint.checkpoint(
                    create_custom_forward(layer),
                    x, attention_mask, position_ids,
                    use_reentrant=False
                )
                x = layer_outputs[0]
                present_kv = None
                attn_weights = None
            else:
                layer_outputs = layer(
                    x,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    use_cache=use_cache,
                    past_kv=past_kv,
                    output_attentions=output_attentions
                )
                x, present_kv, attn_weights = layer_outputs

            if use_cache:
                present_key_values.append(present_kv)
            if output_attentions:
                all_attentions.append(attn_weights)
            if hasattr(layer, 'moe_aux_loss'):
                moe_aux_loss += layer.moe_aux_loss

        hidden_states = self.norm(x)
        logits = self.lm_head(hidden_states)

        if output_hidden_states:
            all_hidden_states.append(hidden_states)

        # Assemble outputs
        outputs = {
            'logits': logits,
            'moe_aux_loss': moe_aux_loss,
            'contrastive_losses': contrastive_losses
        }
        if use_cache:
            outputs['past_key_values'] = present_key_values
        if output_hidden_states:
            outputs['hidden_states'] = all_hidden_states
        if output_attentions:
            outputs['attentions'] = all_attentions
        if return_hidden:
            outputs['last_hidden_state'] = hidden_states

        return outputs

    @torch.no_grad()
    def generate(
        self,
        input_data: Dict,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.9,
        eos_token_id: int = 2,
        pad_token_id: Optional[int] = None,
        use_cache: bool = True,
        repetition_penalty: float = 1.0,
        length_penalty: float = 1.0,
        min_length: int = 0,
        do_sample: bool = True,
        num_beams: int = 1
    ) -> torch.Tensor:
        """Improved generation method."""
        self.eval()
        device = next(self.parameters()).device

        if pad_token_id is None:
            pad_token_id = eos_token_id

        initial_text_tokens = input_data['segments'][0]['data'].to(device)
        batch_size = initial_text_tokens.shape[0]

        if 'attention_mask' in input_data:
            attention_mask = input_data['attention_mask'].to(device)
        else:
            attention_mask = torch.ones_like(initial_text_tokens)

        # Padding-aware positions: each sequence counts up from 0 over its
        # non-pad tokens
        initial_seq_len = initial_text_tokens.shape[1]
        position_ids = torch.zeros((batch_size, initial_seq_len), dtype=torch.long, device=device)
        for i in range(batch_size):
            non_pad_mask = attention_mask[i].bool()
            if non_pad_mask.any():
                positions = torch.cumsum(non_pad_mask.long(), dim=0) - 1
                position_ids[i] = positions * non_pad_mask.long()

        generated_tokens = []
        past_key_values = None
        current_tokens = initial_text_tokens
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)

        for step in range(max_new_tokens):
            current_input_data = {
                'segments': [{'type': 'text', 'data': current_tokens, 'modality_id': 0}]
            }

            if step > 0:
                # Append an attention-mask entry (1) for the token generated last step
                new_mask = torch.ones(batch_size, 1, dtype=torch.long, device=device)
                attention_mask = torch.cat([attention_mask, new_mask], dim=1)
                next_position = (attention_mask.sum(dim=1, keepdim=True) - 1).clamp(min=0)
                if use_cache:
                    # Only the new token is fed in, so only its position is needed
                    current_position_ids = next_position
                else:
                    # The full growing sequence is re-fed, so extend the position ids
                    position_ids = torch.cat([position_ids, next_position], dim=1)
                    current_position_ids = position_ids
            else:
                current_position_ids = position_ids

            outputs = self.forward(
                current_input_data,
                attention_mask=attention_mask,
                position_ids=current_position_ids,
                use_cache=use_cache,
                past_key_values=past_key_values
            )

            logits = outputs['logits']
            if use_cache:
                past_key_values = outputs['past_key_values']

            next_token_logits = logits[:, -1, :] / max(temperature, 1e-5)

            # Repetition penalty
            if repetition_penalty != 1.0 and len(generated_tokens) > 0:
                prev_generated = torch.cat(generated_tokens, dim=1)
                score = torch.gather(next_token_logits, 1, prev_generated)
                score = torch.where(
                    score < 0,
                    score * repetition_penalty,
                    score / repetition_penalty
                )
                next_token_logits.scatter_(1, prev_generated, score)

            # Min length constraint
            if step < min_length:
                next_token_logits[:, eos_token_id] = float('-inf')

            # Sampling
            if do_sample:
                if top_k > 0:
                    top_k_vals, _ = torch.topk(next_token_logits, top_k)
                    min_val_to_keep = top_k_vals[:, -1].unsqueeze(-1)
                    next_token_logits[next_token_logits < min_val_to_keep] = float('-inf')

                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the first token crossing the threshold is kept
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = torch.zeros_like(next_token_logits, dtype=torch.bool)
                    indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits[indices_to_remove] = float('-inf')

                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            # Finished sequences keep emitting the pad token
            next_token = next_token * unfinished_sequences[:, None] + pad_token_id * (1 - unfinished_sequences[:, None])
            generated_tokens.append(next_token)

            if not use_cache:
                initial_text_tokens = torch.cat([initial_text_tokens, next_token], dim=1)
                current_tokens = initial_text_tokens
            else:
                current_tokens = next_token

            # Update unfinished sequences
            unfinished_sequences = unfinished_sequences.mul(
                (next_token.squeeze(-1) != eos_token_id).long()
            )
            if unfinished_sequences.max() == 0:
                break

        if not generated_tokens:
            return torch.empty(batch_size, 0, dtype=torch.long, device=device)

        return torch.cat(generated_tokens, dim=1)
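

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model definition).
# It assumes the companion modules (`components`, `transformer`,
# `multimodel_fusion`, `encoders`) are importable and that
# OptimizedTransformerBlock supports KV caching as used in generate(). The
# tiny configuration and random token ids below are made-up values for a
# text-only smoke test, not a recommended setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = MultiModalDenseTransformer(
        model_dim=256,
        vocab_size=1000,
        n_layers=2,
        n_heads=4,
        max_seq_len=128,
        vision_depth=2,
        audio_depth=2,
        video_spatial_depth=2,
        video_temporal_depth=1,
        use_multimodal_fusion=False  # single text segment, so no fusion needed
    )

    # A single text segment: batch of 2 sequences of 16 random token ids
    tokens = torch.randint(0, 1000, (2, 16))
    input_data = {'segments': [{'type': 'text', 'data': tokens, 'modality_id': 0}]}

    outputs = model(input_data)
    print(outputs['logits'].shape)  # expected: torch.Size([2, 16, 1000])

    generated = model.generate(input_data, max_new_tokens=8, do_sample=False)
    print(generated.shape)  # up to (2, 8); shorter if EOS is sampled early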