| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from typing import List, Dict, Optional, Tuple |
| | import math |
| | from components import RMSNorm |
| | from transformer import OptimizedTransformerBlock |
| | from multimodel_fusion import MultiModalFusionModule |
| | from encoders import ( |
| | ImprovedVisionTransformer, |
| | ImprovedAudioEncoder, |
| | ImprovedVideoEncoder |
| | ) |
| |
|
class MultiModalDenseTransformer(nn.Module):
    """Dense decoder-only transformer with pluggable multimodal encoders.

    Text tokens are embedded directly; image/audio/video inputs go through
    dedicated encoders. Every segment receives a learned modality embedding,
    segments are optionally merged by a cross-modal fusion module, and the
    result is processed by a stack of transformer blocks topped by RMSNorm
    and a (optionally weight-tied) LM head.
    """

    def __init__(
        self,
        model_dim: int = 2048,
        vocab_size: int = 30000,
        n_layers: int = 48,
        n_heads: int = 32,
        n_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
        max_seq_len: int = 8192,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,

        # --- Mixture-of-Experts options ---
        use_moe: bool = False,
        num_experts: int = 8,
        moe_top_k: int = 2,
        moe_layers: Optional[List[int]] = None,

        # --- Parameter-efficient fine-tuning options ---
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_lora: bool = False,
        lora_rank: int = 8,

        # --- Training / architecture toggles ---
        use_gradient_checkpointing: bool = False,
        use_parallel_residual: bool = False,

        # --- Positional / attention options (consumed by the blocks) ---
        rope_scaling_factor: float = 1.0,
        rope_scaling_type: str = "yarn",
        sliding_window: Optional[int] = None,

        # --- Initialization / FFN / tying ---
        norm_eps: float = 1e-6,
        initializer_range: float = 0.02,
        ffn_dim_multiplier: Optional[float] = None,
        tie_word_embeddings: bool = True,

        # --- Multimodal options ---
        use_multimodal_fusion: bool = True,
        fusion_layers: int = 4,
        use_contrastive: bool = True,
        vision_depth: int = 24,
        audio_depth: int = 12,
        video_spatial_depth: int = 12,
        video_temporal_depth: int = 4
    ):
        """Build embeddings, per-modality encoders, the transformer stack,
        and the LM head, then initialize weights and print a config summary.

        NOTE(review): ``rope_scaling_factor``/``rope_scaling_type`` and
        ``max_seq_len`` are stored/accepted but not forwarded to
        ``OptimizedTransformerBlock`` here — presumably the block reads
        them elsewhere; verify against the block's constructor.
        """
        super().__init__()

        self.model_dim = model_dim
        self.vocab_size = vocab_size
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.tie_word_embeddings = tie_word_embeddings
        self.use_multimodal_fusion = use_multimodal_fusion

        # Token embedding plus a small (4-slot) embedding that tags each
        # segment with its modality id.
        self.token_embedding = nn.Embedding(vocab_size, model_dim)
        self.modality_embedding = nn.Embedding(4, model_dim)
        self.embed_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        self.vision_encoder = ImprovedVisionTransformer(
            embed_dim=model_dim,
            depth=vision_depth,
            n_heads=n_heads,
            dropout=dropout,
            use_adapter=use_adapter,
            adapter_dim=adapter_dim
        )

        self.audio_encoder = ImprovedAudioEncoder(
            embed_dim=model_dim,
            depth=audio_depth,
            n_heads=n_heads,
            dropout=dropout,
            use_adapter=use_adapter,
            adapter_dim=adapter_dim
        )

        self.video_encoder = ImprovedVideoEncoder(
            embed_dim=model_dim,
            spatial_depth=video_spatial_depth,
            temporal_depth=video_temporal_depth,
            n_heads=n_heads,
            dropout=dropout,
            use_adapter=use_adapter,
            adapter_dim=adapter_dim
        )

        # Fusion module is only created when enabled; forward() guards on
        # the same flag before using it.
        if use_multimodal_fusion:
            self.fusion_module = MultiModalFusionModule(
                dim=model_dim,
                num_fusion_layers=fusion_layers,
                n_heads=n_heads,
                dropout=dropout,
                use_contrastive=use_contrastive
            )

        # Default MoE placement: the second half of the stack.
        if moe_layers is None and use_moe:
            moe_layers = list(range(n_layers // 2, n_layers))
        elif moe_layers is None:
            moe_layers = []

        self.layers = nn.ModuleList([
            OptimizedTransformerBlock(
                dim=model_dim,
                n_heads=n_heads,
                n_kv_heads=n_kv_heads,
                head_dim=head_dim,
                dropout=dropout,
                attn_dropout=attn_dropout,
                use_moe=(use_moe and i in moe_layers),
                num_experts=num_experts,
                moe_top_k=moe_top_k,
                use_adapter=use_adapter,
                adapter_dim=adapter_dim,
                use_lora=use_lora,
                lora_rank=lora_rank,
                use_parallel_residual=use_parallel_residual,
                norm_eps=norm_eps,
                sliding_window=sliding_window,
                ffn_dim_multiplier=ffn_dim_multiplier,
                layer_idx=i
            )
            for i in range(n_layers)
        ])

        self.norm = RMSNorm(model_dim, eps=norm_eps)
        self.lm_head = nn.Linear(model_dim, vocab_size, bias=False)

        # Tying shares the Parameter object, so the later apply() below
        # initializes the shared weight once for both views.
        if tie_word_embeddings:
            self.lm_head.weight = self.token_embedding.weight

        self.initializer_range = initializer_range
        self.apply(self._init_weights)

        # Untied heads get depth-scaled initialization on top of apply().
        if not tie_word_embeddings:
            self._init_lm_head()

        self.n_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        print(f"\n{'='*80}")
        print(f"Improved Model Configuration:")
        print(f"  Model Dimension: {model_dim}")
        print(f"  Vocab Size: {vocab_size}")
        print(f"  Layers: {n_layers}")
        print(f"  Attention Heads: {n_heads}")
        print(f"  KV Heads: {n_kv_heads if n_kv_heads else n_heads}")
        print(f"  Max Sequence Length: {max_seq_len}")
        print(f"  Multimodal Fusion: {use_multimodal_fusion}")
        print(f"  Contrastive Learning: {use_contrastive}")
        print(f"  MoE: {use_moe} (Experts: {num_experts}, Top-K: {moe_top_k})")
        print(f"  Total Parameters: {self.n_params / 1e9:.2f}B")
        print(f"  Trainable Parameters: {trainable_params / 1e9:.2f}B")
        print(f"{'='*80}\n")
| |
|
| | def _init_weights(self, module): |
| | """权重初始化""" |
| | if isinstance(module, nn.Linear): |
| | torch.nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range) |
| | if module.bias is not None: |
| | torch.nn.init.zeros_(module.bias) |
| | elif isinstance(module, nn.Embedding): |
| | torch.nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range) |
| | if hasattr(module, 'padding_idx') and module.padding_idx is not None: |
| | module.weight.data[module.padding_idx].zero_() |
| |
|
| | def _init_lm_head(self): |
| | """初始化LM head""" |
| | std = self.initializer_range / math.sqrt(2 * self.n_layers) |
| | torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=std) |
| |
|
| | def _encode_modality(self, segment: Dict) -> torch.Tensor: |
| | """编码单个模态""" |
| | seg_type = segment['type'] |
| | seg_data = segment['data'] |
| | |
| | if seg_type == 'image': |
| | return self.vision_encoder(seg_data) |
| | elif seg_type == 'audio': |
| | return self.audio_encoder(seg_data) |
| | elif seg_type == 'video': |
| | return self.video_encoder(seg_data) |
| | elif seg_type == 'text': |
| | return self.token_embedding(seg_data) |
| | else: |
| | return seg_data |
| |
|
    def forward(
        self,
        input_data: Dict,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        return_hidden: bool = False,
        use_cache: bool = False,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        compute_contrastive: bool = False
    ) -> Dict:
        """Forward pass over a list of multimodal segments.

        ``input_data['segments']`` is a list of dicts with keys 'type',
        'data', and optionally 'modality_id' (defaults to 0). Each segment
        is encoded, tagged with a modality embedding, optionally fused
        across modalities, and then run through the transformer stack.

        Returns a dict with 'logits', 'moe_aux_loss', and
        'contrastive_losses', plus (when the corresponding flags are set)
        'past_key_values', 'hidden_states', 'attentions', and
        'last_hidden_state'.

        NOTE(review): with gradient checkpointing active, KV caches and
        attention weights are not produced for that layer (present_kv and
        attn_weights are forced to None).
        """
        device = self.token_embedding.weight.device

        # Encode each segment and add its learned modality embedding.
        encoded_segments = []
        for segment in input_data.get('segments', []):
            encoded = self._encode_modality(segment)

            # assumes encoded is (batch, seq, model_dim) — TODO confirm
            # for all encoders.
            modality_id = segment.get('modality_id', 0)
            modality_embeds = self.modality_embedding(
                torch.tensor([modality_id], device=device)
            ).expand(encoded.shape[0], encoded.shape[1], -1)

            encoded_segments.append({
                'type': segment['type'],
                'data': encoded + modality_embeds,
                'modality_id': modality_id
            })

        # Cross-modal fusion only kicks in when more than one segment is
        # present; otherwise segments are simply concatenated along the
        # sequence axis (with a 1x1 zero placeholder if there are none).
        contrastive_losses = {}
        if self.use_multimodal_fusion and len(encoded_segments) > 1:
            fusion_output = self.fusion_module(
                encoded_segments,
                compute_contrastive=compute_contrastive
            )
            x = fusion_output['fused_features']
            contrastive_losses = fusion_output.get('contrastive_losses', {})
        else:
            all_embeddings = [seg['data'] for seg in encoded_segments]
            x = torch.cat(all_embeddings, dim=1) if all_embeddings else torch.zeros(
                1, 1, self.model_dim, device=device
            )

        x = self.embed_dropout(x)
        if position_ids is None:
            if past_key_values is not None:
                # Continue positions after the cached prefix; assumes the
                # cached key tensor is (batch, heads, seq, head_dim) so
                # size(2) is the past length — TODO confirm against
                # OptimizedTransformerBlock's cache layout.
                past_length = past_key_values[0][0].size(2)

                seq_length = x.shape[1]

                position_ids = torch.arange(
                    past_length, past_length + seq_length, dtype=torch.long, device=device
                ).unsqueeze(0).expand(x.shape[0], -1)
            else:
                # Fresh sequence: positions 0..seq_len-1 for every batch row.
                seq_length = x.shape[1]
                position_ids = torch.arange(
                    0, seq_length, dtype=torch.long, device=device
                ).unsqueeze(0).expand(x.shape[0], -1)

        present_key_values = [] if use_cache else None
        all_hidden_states = [] if output_hidden_states else None
        all_attentions = [] if output_attentions else None
        moe_aux_loss = torch.tensor(0.0, device=device)

        for idx, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states.append(x)

            past_kv = past_key_values[idx] if past_key_values is not None else None

            if self.use_gradient_checkpointing and self.training:
                # Closure so checkpoint() only sees tensor positional args;
                # cache and attention outputs are disabled under
                # checkpointing (they would not survive recomputation).
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(
                            inputs[0],
                            attention_mask=inputs[1],
                            position_ids=inputs[2],
                            use_cache=False,
                            past_kv=None,
                            output_attentions=False
                        )
                    return custom_forward

                # Local import is cheap after the first call (module cache).
                import torch.utils.checkpoint as checkpoint
                layer_outputs = checkpoint.checkpoint(
                    create_custom_forward(layer),
                    x,
                    attention_mask,
                    position_ids,
                    use_reentrant=False
                )
                x = layer_outputs[0]
                present_kv = None
                attn_weights = None
            else:
                layer_outputs = layer(
                    x,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    use_cache=use_cache,
                    past_kv=past_kv,
                    output_attentions=output_attentions
                )
                x, present_kv, attn_weights = layer_outputs

            if use_cache:
                present_key_values.append(present_kv)

            if output_attentions:
                all_attentions.append(attn_weights)

            # MoE layers expose their load-balancing loss as an attribute;
            # accumulate it across the stack.
            if hasattr(layer, 'moe_aux_loss'):
                moe_aux_loss += layer.moe_aux_loss

        hidden_states = self.norm(x)
        logits = self.lm_head(hidden_states)

        if output_hidden_states:
            all_hidden_states.append(hidden_states)

        outputs = {
            'logits': logits,
            'moe_aux_loss': moe_aux_loss,
            'contrastive_losses': contrastive_losses
        }

        if use_cache:
            outputs['past_key_values'] = present_key_values

        if output_hidden_states:
            outputs['hidden_states'] = all_hidden_states

        if output_attentions:
            outputs['attentions'] = all_attentions

        if return_hidden:
            outputs['last_hidden_state'] = hidden_states

        return outputs
| |
|
| | @torch.no_grad() |
| | def generate( |
| | self, |
| | input_data: Dict, |
| | max_new_tokens: int = 100, |
| | temperature: float = 1.0, |
| | top_k: int = 50, |
| | top_p: float = 0.9, |
| | eos_token_id: int = 2, |
| | pad_token_id: Optional[int] = None, |
| | use_cache: bool = True, |
| | repetition_penalty: float = 1.0, |
| | length_penalty: float = 1.0, |
| | min_length: int = 0, |
| | do_sample: bool = True, |
| | num_beams: int = 1 |
| | ) -> torch.Tensor: |
| | """改进的生成方法""" |
| | self.eval() |
| | device = next(self.parameters()).device |
| | |
| | if pad_token_id is None: |
| | pad_token_id = eos_token_id |
| | |
| | initial_text_tokens = input_data['segments'][0]['data'].to(device) |
| | batch_size = initial_text_tokens.shape[0] |
| | |
| | if 'attention_mask' in input_data: |
| | attention_mask = input_data['attention_mask'].to(device) |
| | else: |
| | attention_mask = torch.ones_like(initial_text_tokens) |
| | initial_seq_len = initial_text_tokens.shape[1] |
| | position_ids = torch.zeros((batch_size,initial_seq_len),dtype=torch.long,device=device) |
| | |
| | for i in range(batch_size): |
| | non_pad_mask = attention_mask[i].bool() |
| | if non_pad_mask.any(): |
| | positions = torch.cumsum(non_pad_mask.long(),dim=0) -1 |
| | position_ids[i]=positions * non_pad_mask.long() |
| | |
| |
|
| |
|
| | generated_tokens = [] |
| | past_key_values = None |
| | current_tokens = initial_text_tokens |
| | unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device) |
| | |
| | for step in range(max_new_tokens): |
| | current_input_data = { |
| | 'segments': [{'type': 'text', 'data': current_tokens, 'modality_id': 0}] |
| | } |
| | |
| | if step > 0 and use_cache: |
| | |
| | new_mask = torch.ones(batch_size,1,dtype=torch.long,device=device) |
| | attention_mask = torch.cat([attention_mask, new_mask], dim=1) |
| | current_positions = (attention_mask.sum(dim=1 , keepdim=True) -1).clamp(min=0) |
| | current_positions_ids=current_positions |
| | else: |
| | current_positions_ids=position_ids |
| | outputs = self.forward( |
| | current_input_data, |
| | attention_mask=attention_mask, |
| | position_ids=current_positions_ids, |
| | use_cache=use_cache, |
| | past_key_values=past_key_values |
| | ) |
| | |
| | logits = outputs['logits'] |
| | if use_cache: |
| | past_key_values = outputs['past_key_values'] |
| | |
| | next_token_logits = logits[:, -1, :] / max(temperature, 1e-5) |
| | |
| | |
| | if repetition_penalty != 1.0 and len(generated_tokens) > 0: |
| | prev_generated = torch.cat(generated_tokens, dim=1) |
| | score = torch.gather(next_token_logits, 1, prev_generated) |
| | score = torch.where( |
| | score < 0, |
| | score * repetition_penalty, |
| | score / repetition_penalty |
| | ) |
| | next_token_logits.scatter_(1, prev_generated, score) |
| | |
| | |
| | if step < min_length: |
| | next_token_logits[:, eos_token_id] = float('-inf') |
| | |
| | |
| | if do_sample: |
| | if top_k > 0: |
| | top_k_vals, _ = torch.topk(next_token_logits, top_k) |
| | min_val_to_keep = top_k_vals[:, -1].unsqueeze(-1) |
| | next_token_logits[next_token_logits < min_val_to_keep] = float('-inf') |
| | |
| | if top_p < 1.0: |
| | sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True) |
| | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) |
| | sorted_indices_to_remove = cumulative_probs > top_p |
| | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() |
| | sorted_indices_to_remove[..., 0] = 0 |
| | indices_to_remove = torch.zeros_like(next_token_logits, dtype=torch.bool) |
| | indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove) |
| | next_token_logits[indices_to_remove] = float('-inf') |
| | |
| | probs = F.softmax(next_token_logits, dim=-1) |
| | next_token = torch.multinomial(probs, num_samples=1) |
| | else: |
| | next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) |
| | |
| | |
| | next_token = next_token * unfinished_sequences[:, None] + pad_token_id * (1 - unfinished_sequences[:, None]) |
| | |
| | generated_tokens.append(next_token) |
| | |
| | if not use_cache: |
| | initial_text_tokens = torch.cat([initial_text_tokens, next_token], dim=1) |
| | current_tokens = initial_text_tokens |
| | else: |
| | current_tokens = next_token |
| | |
| | |
| | unfinished_sequences = unfinished_sequences.mul( |
| | (next_token.squeeze(-1) != eos_token_id).long() |
| | ) |
| | |
| | if unfinished_sequences.max() == 0: |
| | break |
| | |
| | if not generated_tokens: |
| | return torch.empty(batch_size, 0, dtype=torch.long, device=device) |
| | |
| | return torch.cat(generated_tokens, dim=1) |