# File size: 18,820 Bytes
# d16a3f0 cd66851
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Optional, Tuple
import math
from components import RMSNorm
from transformer import OptimizedTransformerBlock
from multimodel_fusion import MultiModalFusionModule
from encoders import (
ImprovedVisionTransformer,
ImprovedAudioEncoder,
ImprovedVideoEncoder
)
class MultiModalDenseTransformer(nn.Module):
    """Transformer language model over mixed text / image / audio / video input.

    Each input segment is encoded by its modality-specific encoder (or the
    token embedding for text) and shifted by a learned modality embedding.
    The encoded segments are then either fused by ``MultiModalFusionModule``
    (when enabled and more than one segment is present) or concatenated
    along the sequence axis. The result runs through a stack of
    ``OptimizedTransformerBlock`` layers, a final ``RMSNorm`` and a linear
    LM head that produces vocabulary logits.
    """

    def __init__(
        self,
        model_dim: int = 2048,
        vocab_size: int = 30000,
        n_layers: int = 48,
        n_heads: int = 32,
        n_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
        max_seq_len: int = 8192,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        # MoE configuration
        use_moe: bool = False,
        num_experts: int = 8,
        moe_top_k: int = 2,
        moe_layers: Optional[List[int]] = None,
        # PEFT configuration
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_lora: bool = False,
        lora_rank: int = 8,
        # Training configuration
        use_gradient_checkpointing: bool = False,
        use_parallel_residual: bool = False,
        # Positional encoding
        rope_scaling_factor: float = 1.0,
        rope_scaling_type: str = "yarn",
        sliding_window: Optional[int] = None,
        # Normalization / initialization
        norm_eps: float = 1e-6,
        initializer_range: float = 0.02,
        ffn_dim_multiplier: Optional[float] = None,
        tie_word_embeddings: bool = True,
        # Multimodal configuration
        use_multimodal_fusion: bool = True,
        fusion_layers: int = 4,
        use_contrastive: bool = True,
        vision_depth: int = 24,
        audio_depth: int = 12,
        video_spatial_depth: int = 12,
        video_temporal_depth: int = 4
    ):
        """Build all sub-modules, initialize weights and print a summary.

        ``rope_scaling_factor``, ``rope_scaling_type`` and ``max_seq_len``
        are not forwarded to the transformer blocks here (``max_seq_len``
        is only stored and printed).

        ``moe_layers`` defaults to the second half of the stack when
        ``use_moe`` is set.
        """
        super().__init__()
        self.model_dim = model_dim
        self.vocab_size = vocab_size
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.tie_word_embeddings = tie_word_embeddings
        self.use_multimodal_fusion = use_multimodal_fusion

        # Token embedding plus one learned vector per modality id (0-3);
        # generate() tags text segments with modality id 0.
        self.token_embedding = nn.Embedding(vocab_size, model_dim)
        self.modality_embedding = nn.Embedding(4, model_dim)
        self.embed_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        self.vision_encoder = ImprovedVisionTransformer(
            embed_dim=model_dim,
            depth=vision_depth,
            n_heads=n_heads,
            dropout=dropout,
            use_adapter=use_adapter,
            adapter_dim=adapter_dim
        )
        self.audio_encoder = ImprovedAudioEncoder(
            embed_dim=model_dim,
            depth=audio_depth,
            n_heads=n_heads,
            dropout=dropout,
            use_adapter=use_adapter,
            adapter_dim=adapter_dim
        )
        self.video_encoder = ImprovedVideoEncoder(
            embed_dim=model_dim,
            spatial_depth=video_spatial_depth,
            temporal_depth=video_temporal_depth,
            n_heads=n_heads,
            dropout=dropout,
            use_adapter=use_adapter,
            adapter_dim=adapter_dim
        )

        # Multimodal fusion module
        if use_multimodal_fusion:
            self.fusion_module = MultiModalFusionModule(
                dim=model_dim,
                num_fusion_layers=fusion_layers,
                n_heads=n_heads,
                dropout=dropout,
                use_contrastive=use_contrastive
            )

        if moe_layers is None and use_moe:
            moe_layers = list(range(n_layers // 2, n_layers))
        elif moe_layers is None:
            moe_layers = []

        self.layers = nn.ModuleList([
            OptimizedTransformerBlock(
                dim=model_dim,
                n_heads=n_heads,
                n_kv_heads=n_kv_heads,
                head_dim=head_dim,
                dropout=dropout,
                attn_dropout=attn_dropout,
                use_moe=(use_moe and i in moe_layers),
                num_experts=num_experts,
                moe_top_k=moe_top_k,
                use_adapter=use_adapter,
                adapter_dim=adapter_dim,
                use_lora=use_lora,
                lora_rank=lora_rank,
                use_parallel_residual=use_parallel_residual,
                norm_eps=norm_eps,
                sliding_window=sliding_window,
                ffn_dim_multiplier=ffn_dim_multiplier,
                layer_idx=i
            )
            for i in range(n_layers)
        ])

        self.norm = RMSNorm(model_dim, eps=norm_eps)
        self.lm_head = nn.Linear(model_dim, vocab_size, bias=False)
        if tie_word_embeddings:
            # Share one Parameter between the input embedding and the output head.
            self.lm_head.weight = self.token_embedding.weight

        # Must be set before apply(): _init_weights reads it. apply() visits
        # every submodule, so encoder/fusion Linear and Embedding weights
        # are re-initialized too.
        self.initializer_range = initializer_range
        self.apply(self._init_weights)
        if not tie_word_embeddings:
            self._init_lm_head()

        # parameters() yields a shared (tied) Parameter only once.
        self.n_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"\n{'='*80}")
        print("Improved Model Configuration:")
        print(f" Model Dimension: {model_dim}")
        print(f" Vocab Size: {vocab_size}")
        print(f" Layers: {n_layers}")
        print(f" Attention Heads: {n_heads}")
        print(f" KV Heads: {n_kv_heads if n_kv_heads else n_heads}")
        print(f" Max Sequence Length: {max_seq_len}")
        print(f" Multimodal Fusion: {use_multimodal_fusion}")
        print(f" Contrastive Learning: {use_contrastive}")
        print(f" MoE: {use_moe} (Experts: {num_experts}, Top-K: {moe_top_k})")
        print(f" Total Parameters: {self.n_params / 1e9:.2f}B")
        print(f" Trainable Parameters: {trainable_params / 1e9:.2f}B")
        print(f"{'='*80}\n")

    def _init_weights(self, module):
        """Initialize one module's weights (used with ``nn.Module.apply``).

        Linear and Embedding weights get N(0, initializer_range). Linear
        biases are zeroed, and an Embedding's ``padding_idx`` row is zeroed.
        Other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
            if hasattr(module, 'padding_idx') and module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _init_lm_head(self):
        """Re-initialize an untied LM head with std scaled by 1/sqrt(2 * n_layers)."""
        std = self.initializer_range / math.sqrt(2 * self.n_layers)
        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=std)

    def _encode_modality(self, segment: Dict) -> torch.Tensor:
        """Encode one segment with the encoder matching ``segment['type']``.

        ``'image'``, ``'audio'`` and ``'video'`` go through their encoders,
        and ``'text'`` through the token embedding. Any other type returns
        ``segment['data']`` unchanged (assumed to be pre-computed embeddings).
        """
        seg_type = segment['type']
        seg_data = segment['data']
        if seg_type == 'image':
            return self.vision_encoder(seg_data)
        elif seg_type == 'audio':
            return self.audio_encoder(seg_data)
        elif seg_type == 'video':
            return self.video_encoder(seg_data)
        elif seg_type == 'text':
            return self.token_embedding(seg_data)
        else:
            return seg_data

    def forward(
        self,
        input_data: Dict,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        return_hidden: bool = False,
        use_cache: bool = False,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        compute_contrastive: bool = False
    ) -> Dict:
        """Run the model on a list of modality segments.

        Args:
            input_data: dict whose optional ``'segments'`` entry is a list of
                dicts with keys ``'type'``, ``'data'`` and optional
                ``'modality_id'`` (index into ``modality_embedding``,
                default 0). Encoded segments are expected to be
                ``(batch, seq_len, model_dim)`` and are joined along dim 1.
            attention_mask: forwarded unchanged to every transformer layer.
            position_ids: explicit positions. If ``None``, each batch row gets
                ``arange(past_length, past_length + seq_len)``.
            return_hidden: also return the final normalized hidden states
                under ``'last_hidden_state'``.
            use_cache: ask layers for their key/value cache, returned under
                ``'past_key_values'``.
            past_key_values: per-layer caches from a previous call.
            output_attentions: collect per-layer attention weights.
            output_hidden_states: collect the input of every layer plus the
                final normalized hidden states.
            compute_contrastive: forwarded to the fusion module.

        Returns:
            dict with ``'logits'``, ``'moe_aux_loss'`` (sum of the layers'
            non-None ``moe_aux_loss`` attributes) and ``'contrastive_losses'``.
            When requested, it also holds ``'past_key_values'``,
            ``'hidden_states'``, ``'attentions'`` and ``'last_hidden_state'``.
        """
        device = self.token_embedding.weight.device

        # Encode every segment and add its modality embedding.
        encoded_segments = []
        for segment in input_data.get('segments', []):
            encoded = self._encode_modality(segment)
            modality_id = segment.get('modality_id', 0)
            modality_embeds = self.modality_embedding(
                torch.tensor([modality_id], device=device)
            ).expand(encoded.shape[0], encoded.shape[1], -1)
            encoded_segments.append({
                'type': segment['type'],
                'data': encoded + modality_embeds,
                'modality_id': modality_id
            })

        # Fuse when there are several segments; otherwise concatenate.
        contrastive_losses = {}
        if self.use_multimodal_fusion and len(encoded_segments) > 1:
            fusion_output = self.fusion_module(
                encoded_segments,
                compute_contrastive=compute_contrastive
            )
            x = fusion_output['fused_features']
            contrastive_losses = fusion_output.get('contrastive_losses', {})
        elif encoded_segments:
            x = torch.cat([seg['data'] for seg in encoded_segments], dim=1)
        else:
            # No segments at all: a single all-zero position.
            x = torch.zeros(1, 1, self.model_dim, device=device)
        x = self.embed_dropout(x)

        if position_ids is None:
            # Continue numbering after the cached positions. Assumes each
            # cached key tensor is laid out as [B, H, SeqLen, D].
            past_length = past_key_values[0][0].size(2) if past_key_values else 0
            seq_length = x.shape[1]
            position_ids = torch.arange(
                past_length, past_length + seq_length, dtype=torch.long, device=device
            ).unsqueeze(0).expand(x.shape[0], -1)

        checkpointing = self.use_gradient_checkpointing and self.training
        if checkpointing:
            # Hoisted out of the layer loop: imported and defined once per call.
            from torch.utils.checkpoint import checkpoint as run_checkpoint

            def cache_free_forward(module):
                # The checkpointed path never reads or builds a KV cache and
                # never returns attention weights.
                def run(hidden, mask, positions):
                    return module(
                        hidden,
                        attention_mask=mask,
                        position_ids=positions,
                        use_cache=False,
                        past_kv=None,
                        output_attentions=False
                    )
                return run

        # Transformer layers
        present_key_values = [] if use_cache else None
        all_hidden_states = [] if output_hidden_states else None
        all_attentions = [] if output_attentions else None
        moe_aux_loss = torch.tensor(0.0, device=device)

        for idx, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states.append(x)

            if checkpointing:
                layer_outputs = run_checkpoint(
                    cache_free_forward(layer),
                    x,
                    attention_mask,
                    position_ids,
                    use_reentrant=False
                )
                x = layer_outputs[0]
                present_kv = None
                attn_weights = None
            else:
                past_kv = past_key_values[idx] if past_key_values is not None else None
                x, present_kv, attn_weights = layer(
                    x,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    use_cache=use_cache,
                    past_kv=past_kv,
                    output_attentions=output_attentions
                )

            if use_cache:
                present_key_values.append(present_kv)
            if output_attentions:
                all_attentions.append(attn_weights)

            # Layers may expose an auxiliary MoE loss; skip it when it is
            # absent or None. Out-of-place add keeps autograd bookkeeping simple.
            layer_aux_loss = getattr(layer, 'moe_aux_loss', None)
            if layer_aux_loss is not None:
                moe_aux_loss = moe_aux_loss + layer_aux_loss

        hidden_states = self.norm(x)
        logits = self.lm_head(hidden_states)
        if output_hidden_states:
            all_hidden_states.append(hidden_states)

        # Assemble outputs
        outputs = {
            'logits': logits,
            'moe_aux_loss': moe_aux_loss,
            'contrastive_losses': contrastive_losses
        }
        if use_cache:
            outputs['past_key_values'] = present_key_values
        if output_hidden_states:
            outputs['hidden_states'] = all_hidden_states
        if output_attentions:
            outputs['attentions'] = all_attentions
        if return_hidden:
            outputs['last_hidden_state'] = hidden_states
        return outputs

    @staticmethod
    def _positions_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
        """Return 0-based position ids that count only non-padding tokens.

        Padding positions (mask == 0) get position 0. This works for both
        left- and right-padded batches. The result is a long tensor shaped
        like ``attention_mask``.
        """
        non_pad = attention_mask.bool().long()
        return (non_pad.cumsum(dim=-1) - 1) * non_pad

    @staticmethod
    def _top_k_top_p_filter(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
        """Set logits outside the top-k / nucleus (top-p) set to -inf, in place.

        ``top_k`` is clamped to the vocabulary size, and ``top_k <= 0``
        disables it. ``top_p >= 1.0`` disables nucleus filtering. The
        highest-scoring token is always kept. Returns ``logits``.
        """
        if top_k > 0:
            k = min(top_k, logits.size(-1))
            kth_largest = torch.topk(logits, k).values[:, -1, None]
            logits[logits < kth_largest] = float('-inf')
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_to_remove = cumulative_probs > top_p
            # Shift right so the token that crosses top_p is still kept.
            sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
            sorted_to_remove[..., 0] = False
            indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
            indices_to_remove.scatter_(1, sorted_indices, sorted_to_remove)
            logits[indices_to_remove] = float('-inf')
        return logits

    @torch.no_grad()
    def generate(
        self,
        input_data: Dict,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.9,
        eos_token_id: int = 2,
        pad_token_id: Optional[int] = None,
        use_cache: bool = True,
        repetition_penalty: float = 1.0,
        length_penalty: float = 1.0,
        min_length: int = 0,
        do_sample: bool = True,
        num_beams: int = 1
    ) -> torch.Tensor:
        """Generate up to ``max_new_tokens`` text tokens.

        Only ``input_data['segments'][0]['data']`` is used. It is treated as
        a ``(batch, seq_len)`` tensor of prompt token ids and fed back as a
        ``'text'`` segment; other segments are ignored. An optional
        ``input_data['attention_mask']`` marks padding (0), so that
        positions count only real tokens.

        Args:
            max_new_tokens: upper bound on generated tokens per row.
            temperature: logits are divided by ``max(temperature, 1e-5)``.
            top_k: keep the ``top_k`` highest logits when sampling (clamped
                to the vocabulary size; ``<= 0`` disables).
            top_p: nucleus-sampling threshold (``>= 1.0`` disables).
            eos_token_id: rows that emit it stop. It is banned while fewer
                than ``min_length`` tokens have been generated.
            pad_token_id: emitted by finished rows; defaults to ``eos_token_id``.
            use_cache: feed only the newest token each step and reuse the
                returned ``past_key_values``. Otherwise the full sequence is
                re-run every step.
            repetition_penalty: for tokens generated so far (prompt tokens
                are not included), positive logits are divided and negative
                logits multiplied by this value.
            length_penalty: accepted for API compatibility; currently unused.
            min_length: minimum number of new tokens before EOS is allowed.
            do_sample: sample from the filtered distribution; otherwise
                greedy argmax.
            num_beams: accepted for API compatibility; currently unused.

        Returns:
            Long tensor ``(batch, n_generated)`` with
            ``n_generated <= max_new_tokens``, or ``(batch, 0)`` if nothing
            was generated.

        The model runs in eval mode during the call. Every submodule's
        previous train/eval flag is restored afterwards.
        """
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            device = next(self.parameters()).device
            if pad_token_id is None:
                pad_token_id = eos_token_id

            prompt = input_data['segments'][0]['data'].to(device)
            batch_size = prompt.shape[0]
            if 'attention_mask' in input_data:
                attention_mask = input_data['attention_mask'].to(device)
            else:
                attention_mask = torch.ones_like(prompt)
            position_ids = self._positions_from_mask(attention_mask)

            generated_tokens = []
            past_key_values = None
            full_sequence = prompt   # prompt + generated; used without cache
            step_tokens = prompt     # what is fed to forward() this step
            unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)

            for step in range(max_new_tokens):
                if step > 0:
                    # The token produced last step is real content. Extend the
                    # mask in BOTH modes so it always matches the sequence
                    # the layers see.
                    new_mask = torch.ones(batch_size, 1, dtype=torch.long, device=device)
                    attention_mask = torch.cat([attention_mask, new_mask], dim=1)
                    if use_cache:
                        # Only the new token is fed: its position is the count
                        # of real tokens before it.
                        position_ids = (attention_mask.sum(dim=1, keepdim=True) - 1).clamp(min=0)
                    else:
                        # The whole sequence is re-fed: recompute all positions.
                        position_ids = self._positions_from_mask(attention_mask)

                outputs = self.forward(
                    {'segments': [{'type': 'text', 'data': step_tokens, 'modality_id': 0}]},
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    use_cache=use_cache,
                    past_key_values=past_key_values
                )
                if use_cache:
                    past_key_values = outputs['past_key_values']

                next_token_logits = outputs['logits'][:, -1, :] / max(temperature, 1e-5)

                # Repetition penalty over the tokens generated so far.
                if repetition_penalty != 1.0 and generated_tokens:
                    prev_generated = torch.cat(generated_tokens, dim=1)
                    score = torch.gather(next_token_logits, 1, prev_generated)
                    score = torch.where(
                        score < 0,
                        score * repetition_penalty,
                        score / repetition_penalty
                    )
                    next_token_logits.scatter_(1, prev_generated, score)

                # Min length constraint
                if step < min_length:
                    next_token_logits[:, eos_token_id] = float('-inf')

                if do_sample:
                    next_token_logits = self._top_k_top_p_filter(next_token_logits, top_k, top_p)
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                # Finished rows keep emitting pad_token_id.
                keep = unfinished_sequences[:, None]
                next_token = next_token * keep + pad_token_id * (1 - keep)
                generated_tokens.append(next_token)

                if use_cache:
                    step_tokens = next_token
                else:
                    full_sequence = torch.cat([full_sequence, next_token], dim=1)
                    step_tokens = full_sequence

                unfinished_sequences = unfinished_sequences.mul(
                    (next_token.squeeze(-1) != eos_token_id).long()
                )
                if unfinished_sequences.max() == 0:
                    break

            if not generated_tokens:
                return torch.empty(batch_size, 0, dtype=torch.long, device=device)
            return torch.cat(generated_tokens, dim=1)
        finally:
            for module, was_training in previous_modes:
                module.training = was_training