szxllm commited on
Commit
d16a3f0
·
verified ·
1 Parent(s): 9c85325

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +488 -504
model.py CHANGED
@@ -1,505 +1,489 @@
1
- """
2
- 改进的多模态Dense Transformer主模型
3
- 整合所有SOTA改进
4
- """
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from typing import List, Dict, Optional, Tuple
9
- import math
10
- from components import RMSNorm
11
- from transformer import OptimizedTransformerBlock
12
- from multimodel_fusion import MultiModalFusionModule
13
- from encoders import (
14
- ImprovedVisionTransformer,
15
- ImprovedAudioEncoder,
16
- ImprovedVideoEncoder
17
- )
18
-
19
- class MultiModalDenseTransformer(nn.Module):
20
- """
21
- 改进的统一多模态Dense Transformer
22
- 主要改进:
23
- 1. 深度跨模态融合
24
- 2. 模态特定的优化编码器
25
- 3. 对比学习对齐
26
- 4. 改进的位置编码和注意力机制
27
- 5. 更好的训练稳定性
28
- """
29
- def __init__(
30
- self,
31
- model_dim: int = 2048,
32
- vocab_size: int = 30000,
33
- n_layers: int = 48,
34
- n_heads: int = 32,
35
- n_kv_heads: Optional[int] = None,
36
- head_dim: Optional[int] = None,
37
- max_seq_len: int = 8192,
38
- dropout: float = 0.0,
39
- attn_dropout: float = 0.0,
40
-
41
- # MoE配置
42
- use_moe: bool = False,
43
- num_experts: int = 8,
44
- moe_top_k: int = 2,
45
- moe_layers: Optional[List[int]] = None,
46
-
47
- # PEFT配置
48
- use_adapter: bool = False,
49
- adapter_dim: int = 64,
50
- use_lora: bool = False,
51
- lora_rank: int = 8,
52
-
53
- # 训练配置
54
- use_gradient_checkpointing: bool = False,
55
- use_parallel_residual: bool = False,
56
-
57
- # 位置编码
58
- rope_scaling_factor: float = 1.0,
59
- rope_scaling_type: str = "yarn",
60
- sliding_window: Optional[int] = None,
61
-
62
- # 规范化
63
- norm_eps: float = 1e-6,
64
- initializer_range: float = 0.02,
65
- ffn_dim_multiplier: Optional[float] = None,
66
- tie_word_embeddings: bool = True,
67
-
68
- # 多模态配置
69
- use_multimodal_fusion: bool = True,
70
- fusion_layers: int = 4,
71
- use_contrastive: bool = True,
72
- vision_depth: int = 24,
73
- audio_depth: int = 12,
74
- video_spatial_depth: int = 12,
75
- video_temporal_depth: int = 4
76
- ):
77
- super().__init__()
78
-
79
- self.model_dim = model_dim
80
- self.vocab_size = vocab_size
81
- self.n_layers = n_layers
82
- self.max_seq_len = max_seq_len
83
- self.use_gradient_checkpointing = use_gradient_checkpointing
84
- self.tie_word_embeddings = tie_word_embeddings
85
- self.use_multimodal_fusion = use_multimodal_fusion
86
-
87
- # Token embedding
88
- self.token_embedding = nn.Embedding(vocab_size, model_dim)
89
- self.modality_embedding = nn.Embedding(4, model_dim)
90
- self.embed_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
91
-
92
- # 改进的模态编码器
93
- self.vision_encoder = ImprovedVisionTransformer(
94
- embed_dim=model_dim,
95
- depth=vision_depth,
96
- n_heads=n_heads,
97
- dropout=dropout,
98
- use_adapter=use_adapter,
99
- adapter_dim=adapter_dim
100
- )
101
-
102
- self.audio_encoder = ImprovedAudioEncoder(
103
- embed_dim=model_dim,
104
- depth=audio_depth,
105
- n_heads=n_heads,
106
- dropout=dropout,
107
- use_adapter=use_adapter,
108
- adapter_dim=adapter_dim
109
- )
110
-
111
- self.video_encoder = ImprovedVideoEncoder(
112
- embed_dim=model_dim,
113
- spatial_depth=video_spatial_depth,
114
- temporal_depth=video_temporal_depth,
115
- n_heads=n_heads,
116
- dropout=dropout,
117
- use_adapter=use_adapter,
118
- adapter_dim=adapter_dim
119
- )
120
-
121
- # 多模态融合模块
122
- if use_multimodal_fusion:
123
- self.fusion_module = MultiModalFusionModule(
124
- dim=model_dim,
125
- num_fusion_layers=fusion_layers,
126
- n_heads=n_heads,
127
- dropout=dropout,
128
- use_contrastive=use_contrastive
129
- )
130
-
131
- # Transformer layers
132
- if moe_layers is None and use_moe:
133
- moe_layers = list(range(n_layers // 2, n_layers))
134
- elif moe_layers is None:
135
- moe_layers = []
136
-
137
- self.layers = nn.ModuleList([
138
- OptimizedTransformerBlock(
139
- dim=model_dim,
140
- n_heads=n_heads,
141
- n_kv_heads=n_kv_heads,
142
- head_dim=head_dim,
143
- dropout=dropout,
144
- attn_dropout=attn_dropout,
145
- use_moe=(use_moe and i in moe_layers),
146
- num_experts=num_experts,
147
- moe_top_k=moe_top_k,
148
- use_adapter=use_adapter,
149
- adapter_dim=adapter_dim,
150
- use_lora=use_lora,
151
- lora_rank=lora_rank,
152
- use_parallel_residual=use_parallel_residual,
153
- norm_eps=norm_eps,
154
- sliding_window=sliding_window,
155
- ffn_dim_multiplier=ffn_dim_multiplier,
156
- layer_idx=i
157
- )
158
- for i in range(n_layers)
159
- ])
160
-
161
- self.norm = RMSNorm(model_dim, eps=norm_eps)
162
- self.lm_head = nn.Linear(model_dim, vocab_size, bias=False)
163
-
164
- if tie_word_embeddings:
165
- self.lm_head.weight = self.token_embedding.weight
166
-
167
- self.initializer_range = initializer_range
168
- self.apply(self._init_weights)
169
-
170
- if not tie_word_embeddings:
171
- self._init_lm_head()
172
-
173
- self.n_params = sum(p.numel() for p in self.parameters())
174
- trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
175
-
176
- print(f"\n{'='*80}")
177
- print(f"Improved Model Configuration:")
178
- print(f" Model Dimension: {model_dim}")
179
- print(f" Vocab Size: {vocab_size}")
180
- print(f" Layers: {n_layers}")
181
- print(f" Attention Heads: {n_heads}")
182
- print(f" KV Heads: {n_kv_heads if n_kv_heads else n_heads}")
183
- print(f" Max Sequence Length: {max_seq_len}")
184
- print(f" Multimodal Fusion: {use_multimodal_fusion}")
185
- print(f" Contrastive Learning: {use_contrastive}")
186
- print(f" MoE: {use_moe} (Experts: {num_experts}, Top-K: {moe_top_k})")
187
- print(f" Total Parameters: {self.n_params / 1e9:.2f}B")
188
- print(f" Trainable Parameters: {trainable_params / 1e9:.2f}B")
189
- print(f"{'='*80}\n")
190
-
191
- def _init_weights(self, module):
192
- """权重初始化"""
193
- if isinstance(module, nn.Linear):
194
- torch.nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
195
- if module.bias is not None:
196
- torch.nn.init.zeros_(module.bias)
197
- elif isinstance(module, nn.Embedding):
198
- torch.nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
199
- if hasattr(module, 'padding_idx') and module.padding_idx is not None:
200
- module.weight.data[module.padding_idx].zero_()
201
-
202
- def _init_lm_head(self):
203
- """初始化LM head"""
204
- std = self.initializer_range / math.sqrt(2 * self.n_layers)
205
- torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=std)
206
-
207
- def _encode_modality(self, segment: Dict) -> torch.Tensor:
208
- """编码单个模态"""
209
- seg_type = segment['type']
210
- seg_data = segment['data']
211
-
212
- if seg_type == 'image':
213
- return self.vision_encoder(seg_data)
214
- elif seg_type == 'audio':
215
- return self.audio_encoder(seg_data)
216
- elif seg_type == 'video':
217
- return self.video_encoder(seg_data)
218
- elif seg_type == 'text':
219
- return self.token_embedding(seg_data)
220
- else:
221
- return seg_data
222
-
223
- def forward(
224
- self,
225
- input_data: Dict,
226
- attention_mask: Optional[torch.Tensor] = None,
227
- position_ids: Optional[torch.Tensor] = None,
228
- return_hidden: bool = False,
229
- use_cache: bool = False,
230
- past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
231
- output_attentions: bool = False,
232
- output_hidden_states: bool = False,
233
- compute_contrastive: bool = False
234
- ) -> Dict:
235
- """前向传播"""
236
- device = self.token_embedding.weight.device
237
-
238
- # 编码每个模态
239
- encoded_segments = []
240
- for segment in input_data.get('segments', []):
241
- encoded = self._encode_modality(segment)
242
-
243
- # 添加模态嵌入
244
- modality_id = segment.get('modality_id', 0)
245
- modality_embeds = self.modality_embedding(
246
- torch.tensor([modality_id], device=device)
247
- ).expand(encoded.shape[0], encoded.shape[1], -1)
248
-
249
- encoded_segments.append({
250
- 'type': segment['type'],
251
- 'data': encoded + modality_embeds,
252
- 'modality_id': modality_id
253
- })
254
-
255
- # 多模态融合
256
- contrastive_losses = {}
257
- if self.use_multimodal_fusion and len(encoded_segments) > 1:
258
- fusion_output = self.fusion_module(
259
- encoded_segments,
260
- compute_contrastive=compute_contrastive
261
- )
262
- x = fusion_output['fused_features']
263
- contrastive_losses = fusion_output.get('contrastive_losses', {})
264
- else:
265
- # 简单拼接
266
- all_embeddings = [seg['data'] for seg in encoded_segments]
267
- x = torch.cat(all_embeddings, dim=1) if all_embeddings else torch.zeros(
268
- 1, 1, self.model_dim, device=device
269
- )
270
-
271
- x = self.embed_dropout(x)
272
- # 如果没有传入 position_ids,我们需要根据历史长度生成它
273
- if position_ids is None:
274
- if past_key_values is not None:
275
- # 缓存的长度 (KV cache shape 是 [B, H, SeqLen, D])
276
- past_length = past_key_values[0][0].size(2)
277
- # 当前输入的长度
278
- seq_length = x.shape[1]
279
- # 生成正确的位置索引: [past_length, past_length + 1, ...]
280
- position_ids = torch.arange(
281
- past_length, past_length + seq_length, dtype=torch.long, device=device
282
- ).unsqueeze(0).expand(x.shape[0], -1)
283
- else:
284
- # 如果没有缓存,从 0 开始
285
- seq_length = x.shape[1]
286
- position_ids = torch.arange(
287
- 0, seq_length, dtype=torch.long, device=device
288
- ).unsqueeze(0).expand(x.shape[0], -1)
289
- # Transformer层
290
- present_key_values = [] if use_cache else None
291
- all_hidden_states = [] if output_hidden_states else None
292
- all_attentions = [] if output_attentions else None
293
- moe_aux_loss = torch.tensor(0.0, device=device)
294
-
295
- for idx, layer in enumerate(self.layers):
296
- if output_hidden_states:
297
- all_hidden_states.append(x)
298
-
299
- past_kv = past_key_values[idx] if past_key_values is not None else None
300
-
301
- if self.use_gradient_checkpointing and self.training:
302
- def create_custom_forward(module):
303
- def custom_forward(*inputs):
304
- return module(
305
- inputs[0],
306
- attention_mask=inputs[1],
307
- position_ids=inputs[2],
308
- use_cache=False,
309
- past_kv=None,
310
- output_attentions=False
311
- )
312
- return custom_forward
313
-
314
- import torch.utils.checkpoint as checkpoint
315
- layer_outputs = checkpoint.checkpoint(
316
- create_custom_forward(layer),
317
- x,
318
- attention_mask,
319
- position_ids,
320
- use_reentrant=False
321
- )
322
- x = layer_outputs[0]
323
- present_kv = None
324
- attn_weights = None
325
- else:
326
- layer_outputs = layer(
327
- x,
328
- attention_mask=attention_mask,
329
- position_ids=position_ids,
330
- use_cache=use_cache,
331
- past_kv=past_kv,
332
- output_attentions=output_attentions
333
- )
334
- x, present_kv, attn_weights = layer_outputs
335
-
336
- if use_cache:
337
- present_key_values.append(present_kv)
338
-
339
- if output_attentions:
340
- all_attentions.append(attn_weights)
341
-
342
- if hasattr(layer, 'moe_aux_loss'):
343
- moe_aux_loss += layer.moe_aux_loss
344
-
345
- hidden_states = self.norm(x)
346
- logits = self.lm_head(hidden_states)
347
-
348
- if output_hidden_states:
349
- all_hidden_states.append(hidden_states)
350
-
351
- # 组装输出
352
- outputs = {
353
- 'logits': logits,
354
- 'moe_aux_loss': moe_aux_loss,
355
- 'contrastive_losses': contrastive_losses
356
- }
357
-
358
- if use_cache:
359
- outputs['past_key_values'] = present_key_values
360
-
361
- if output_hidden_states:
362
- outputs['hidden_states'] = all_hidden_states
363
-
364
- if output_attentions:
365
- outputs['attentions'] = all_attentions
366
-
367
- if return_hidden:
368
- outputs['last_hidden_state'] = hidden_states
369
-
370
- return outputs
371
-
372
- @torch.no_grad()
373
- def generate(
374
- self,
375
- input_data: Dict,
376
- max_new_tokens: int = 100,
377
- temperature: float = 1.0,
378
- top_k: int = 50,
379
- top_p: float = 0.9,
380
- eos_token_id: int = 2,
381
- pad_token_id: Optional[int] = None,
382
- use_cache: bool = True,
383
- repetition_penalty: float = 1.0,
384
- length_penalty: float = 1.0,
385
- min_length: int = 0,
386
- do_sample: bool = True,
387
- num_beams: int = 1
388
- ) -> torch.Tensor:
389
- """改进的生成方法"""
390
- self.eval()
391
- device = next(self.parameters()).device
392
-
393
- if pad_token_id is None:
394
- pad_token_id = eos_token_id
395
-
396
- initial_text_tokens = input_data['segments'][0]['data'].to(device)
397
- batch_size = initial_text_tokens.shape[0]
398
-
399
- if 'attention_mask' in input_data:
400
- attention_mask = input_data['attention_mask'].to(device)
401
- else:
402
- attention_mask = torch.ones_like(initial_text_tokens)
403
- initial_seq_len = initial_text_tokens.shape[1]
404
- position_ids = torch.zeros((batch_size,initial_seq_len),dtype=torch.long,device=device)
405
-
406
- for i in range(batch_size):
407
- non_pad_mask = attention_mask[i].bool()
408
- if non_pad_mask.any():
409
- positions = torch.cumsum(non_pad_mask.long(),dim=0) -1
410
- position_ids[i]=positions * non_pad_mask.long()
411
-
412
-
413
-
414
- generated_tokens = []
415
- past_key_values = None
416
- current_tokens = initial_text_tokens
417
- unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
418
-
419
- for step in range(max_new_tokens):
420
- current_input_data = {
421
- 'segments': [{'type': 'text', 'data': current_tokens, 'modality_id': 0}]
422
- }
423
-
424
- if step > 0 and use_cache:
425
- # 添加当前 token 的 mask (1)
426
- new_mask = torch.ones(batch_size,1,dtype=torch.long,device=device)
427
- attention_mask = torch.cat([attention_mask, new_mask], dim=1)
428
- current_positions = (attention_mask.sum(dim=1 , keepdim=True) -1).clamp(min=0)
429
- current_positions_ids=current_positions
430
- else:
431
- current_positions_ids=position_ids
432
- outputs = self.forward(
433
- current_input_data,
434
- attention_mask=attention_mask, # <--- 传入 Mask
435
- position_ids=current_positions_ids,
436
- use_cache=use_cache,
437
- past_key_values=past_key_values
438
- )
439
-
440
- logits = outputs['logits']
441
- if use_cache:
442
- past_key_values = outputs['past_key_values']
443
-
444
- next_token_logits = logits[:, -1, :] / max(temperature, 1e-5)
445
-
446
- # Repetition penalty
447
- if repetition_penalty != 1.0 and len(generated_tokens) > 0:
448
- prev_generated = torch.cat(generated_tokens, dim=1)
449
- score = torch.gather(next_token_logits, 1, prev_generated)
450
- score = torch.where(
451
- score < 0,
452
- score * repetition_penalty,
453
- score / repetition_penalty
454
- )
455
- next_token_logits.scatter_(1, prev_generated, score)
456
-
457
- # Min length constraint
458
- if step < min_length:
459
- next_token_logits[:, eos_token_id] = float('-inf')
460
-
461
- # Sampling
462
- if do_sample:
463
- if top_k > 0:
464
- top_k_vals, _ = torch.topk(next_token_logits, top_k)
465
- min_val_to_keep = top_k_vals[:, -1].unsqueeze(-1)
466
- next_token_logits[next_token_logits < min_val_to_keep] = float('-inf')
467
-
468
- if top_p < 1.0:
469
- sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
470
- cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
471
- sorted_indices_to_remove = cumulative_probs > top_p
472
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
473
- sorted_indices_to_remove[..., 0] = 0
474
- indices_to_remove = torch.zeros_like(next_token_logits, dtype=torch.bool)
475
- indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
476
- next_token_logits[indices_to_remove] = float('-inf')
477
-
478
- probs = F.softmax(next_token_logits, dim=-1)
479
- next_token = torch.multinomial(probs, num_samples=1)
480
- else:
481
- next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
482
-
483
- # Apply unfinished mask
484
- next_token = next_token * unfinished_sequences[:, None] + pad_token_id * (1 - unfinished_sequences[:, None])
485
-
486
- generated_tokens.append(next_token)
487
-
488
- if not use_cache:
489
- initial_text_tokens = torch.cat([initial_text_tokens, next_token], dim=1)
490
- current_tokens = initial_text_tokens
491
- else:
492
- current_tokens = next_token
493
-
494
- # Update unfinished sequences
495
- unfinished_sequences = unfinished_sequences.mul(
496
- (next_token.squeeze(-1) != eos_token_id).long()
497
- )
498
-
499
- if unfinished_sequences.max() == 0:
500
- break
501
-
502
- if not generated_tokens:
503
- return torch.empty(batch_size, 0, dtype=torch.long, device=device)
504
-
505
  return torch.cat(generated_tokens, dim=1)
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import List, Dict, Optional, Tuple
5
+ import math
6
+ from components import RMSNorm
7
+ from transformer import OptimizedTransformerBlock
8
+ from multimodel_fusion import MultiModalFusionModule
9
+ from encoders import (
10
+ ImprovedVisionTransformer,
11
+ ImprovedAudioEncoder,
12
+ ImprovedVideoEncoder
13
+ )
14
+
15
+ class MultiModalDenseTransformer(nn.Module):
16
    def __init__(
        self,
        model_dim: int = 2048,
        vocab_size: int = 30000,
        n_layers: int = 48,
        n_heads: int = 32,
        n_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
        max_seq_len: int = 8192,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,

        # MoE configuration
        use_moe: bool = False,
        num_experts: int = 8,
        moe_top_k: int = 2,
        moe_layers: Optional[List[int]] = None,

        # PEFT configuration
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_lora: bool = False,
        lora_rank: int = 8,

        # Training configuration
        use_gradient_checkpointing: bool = False,
        use_parallel_residual: bool = False,

        # Positional encoding
        rope_scaling_factor: float = 1.0,
        rope_scaling_type: str = "yarn",
        sliding_window: Optional[int] = None,

        # Normalization / initialization
        norm_eps: float = 1e-6,
        initializer_range: float = 0.02,
        ffn_dim_multiplier: Optional[float] = None,
        tie_word_embeddings: bool = True,

        # Multimodal configuration
        use_multimodal_fusion: bool = True,
        fusion_layers: int = 4,
        use_contrastive: bool = True,
        vision_depth: int = 24,
        audio_depth: int = 12,
        video_spatial_depth: int = 12,
        video_temporal_depth: int = 4
    ):
        """Build the unified multimodal dense Transformer.

        Assembles: a token embedding (optionally weight-tied to the LM head),
        a learned per-modality embedding, one encoder per modality
        (vision / audio / video), an optional cross-modal fusion module,
        and a stack of ``n_layers`` decoder blocks followed by RMSNorm and
        an LM head.

        NOTE(review): ``rope_scaling_factor`` and ``rope_scaling_type`` are
        accepted but never used in this constructor — presumably they should
        be forwarded to the Transformer blocks; confirm against
        ``OptimizedTransformerBlock``.
        """
        super().__init__()

        self.model_dim = model_dim
        self.vocab_size = vocab_size
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.tie_word_embeddings = tie_word_embeddings
        self.use_multimodal_fusion = use_multimodal_fusion

        # Token embedding plus a 4-slot modality embedding (one id per
        # modality; ids presumably 0=text, 1=image, 2=audio, 3=video —
        # confirm against callers that set 'modality_id').
        self.token_embedding = nn.Embedding(vocab_size, model_dim)
        self.modality_embedding = nn.Embedding(4, model_dim)
        self.embed_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # Modality-specific encoders (all project to ``model_dim``).
        self.vision_encoder = ImprovedVisionTransformer(
            embed_dim=model_dim,
            depth=vision_depth,
            n_heads=n_heads,
            dropout=dropout,
            use_adapter=use_adapter,
            adapter_dim=adapter_dim
        )

        self.audio_encoder = ImprovedAudioEncoder(
            embed_dim=model_dim,
            depth=audio_depth,
            n_heads=n_heads,
            dropout=dropout,
            use_adapter=use_adapter,
            adapter_dim=adapter_dim
        )

        self.video_encoder = ImprovedVideoEncoder(
            embed_dim=model_dim,
            spatial_depth=video_spatial_depth,
            temporal_depth=video_temporal_depth,
            n_heads=n_heads,
            dropout=dropout,
            use_adapter=use_adapter,
            adapter_dim=adapter_dim
        )

        # Cross-modal fusion module (only built when enabled; ``forward``
        # guards every use behind ``self.use_multimodal_fusion``).
        if use_multimodal_fusion:
            self.fusion_module = MultiModalFusionModule(
                dim=model_dim,
                num_fusion_layers=fusion_layers,
                n_heads=n_heads,
                dropout=dropout,
                use_contrastive=use_contrastive
            )

        # Default MoE placement: the upper half of the stack.
        if moe_layers is None and use_moe:
            moe_layers = list(range(n_layers // 2, n_layers))
        elif moe_layers is None:
            moe_layers = []

        self.layers = nn.ModuleList([
            OptimizedTransformerBlock(
                dim=model_dim,
                n_heads=n_heads,
                n_kv_heads=n_kv_heads,
                head_dim=head_dim,
                dropout=dropout,
                attn_dropout=attn_dropout,
                use_moe=(use_moe and i in moe_layers),
                num_experts=num_experts,
                moe_top_k=moe_top_k,
                use_adapter=use_adapter,
                adapter_dim=adapter_dim,
                use_lora=use_lora,
                lora_rank=lora_rank,
                use_parallel_residual=use_parallel_residual,
                norm_eps=norm_eps,
                sliding_window=sliding_window,
                ffn_dim_multiplier=ffn_dim_multiplier,
                layer_idx=i
            )
            for i in range(n_layers)
        ])

        self.norm = RMSNorm(model_dim, eps=norm_eps)
        self.lm_head = nn.Linear(model_dim, vocab_size, bias=False)

        # Weight tying must happen BEFORE ``apply(_init_weights)`` so the
        # shared tensor is initialized once.
        if tie_word_embeddings:
            self.lm_head.weight = self.token_embedding.weight

        self.initializer_range = initializer_range
        self.apply(self._init_weights)

        # Untied LM head gets a smaller, depth-scaled init on top of the
        # generic one applied above.
        if not tie_word_embeddings:
            self._init_lm_head()

        self.n_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        print(f"\n{'='*80}")
        print(f"Improved Model Configuration:")
        print(f"  Model Dimension: {model_dim}")
        print(f"  Vocab Size: {vocab_size}")
        print(f"  Layers: {n_layers}")
        print(f"  Attention Heads: {n_heads}")
        print(f"  KV Heads: {n_kv_heads if n_kv_heads else n_heads}")
        print(f"  Max Sequence Length: {max_seq_len}")
        print(f"  Multimodal Fusion: {use_multimodal_fusion}")
        print(f"  Contrastive Learning: {use_contrastive}")
        print(f"  MoE: {use_moe} (Experts: {num_experts}, Top-K: {moe_top_k})")
        print(f"  Total Parameters: {self.n_params / 1e9:.2f}B")
        print(f"  Trainable Parameters: {trainable_params / 1e9:.2f}B")
        print(f"{'='*80}\n")
175
+
176
+ def _init_weights(self, module):
177
+ """权重初始化"""
178
+ if isinstance(module, nn.Linear):
179
+ torch.nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
180
+ if module.bias is not None:
181
+ torch.nn.init.zeros_(module.bias)
182
+ elif isinstance(module, nn.Embedding):
183
+ torch.nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
184
+ if hasattr(module, 'padding_idx') and module.padding_idx is not None:
185
+ module.weight.data[module.padding_idx].zero_()
186
+
187
+ def _init_lm_head(self):
188
+ """初始化LM head"""
189
+ std = self.initializer_range / math.sqrt(2 * self.n_layers)
190
+ torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=std)
191
+
192
+ def _encode_modality(self, segment: Dict) -> torch.Tensor:
193
+ """编码单个模态"""
194
+ seg_type = segment['type']
195
+ seg_data = segment['data']
196
+
197
+ if seg_type == 'image':
198
+ return self.vision_encoder(seg_data)
199
+ elif seg_type == 'audio':
200
+ return self.audio_encoder(seg_data)
201
+ elif seg_type == 'video':
202
+ return self.video_encoder(seg_data)
203
+ elif seg_type == 'text':
204
+ return self.token_embedding(seg_data)
205
+ else:
206
+ return seg_data
207
+
208
    def forward(
        self,
        input_data: Dict,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        return_hidden: bool = False,
        use_cache: bool = False,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        compute_contrastive: bool = False
    ) -> Dict:
        """Forward pass over a list of multimodal segments.

        ``input_data['segments']`` is a list of dicts with 'type', 'data'
        and optionally 'modality_id'. Each segment is encoded, given its
        modality embedding, then either fused by the fusion module (when
        enabled and more than one segment) or simply concatenated along the
        sequence dimension, and run through the decoder stack.

        Returns a dict with 'logits', 'moe_aux_loss' and
        'contrastive_losses', plus 'past_key_values' / 'hidden_states' /
        'attentions' / 'last_hidden_state' depending on the flags.
        """
        device = self.token_embedding.weight.device

        # Encode every modality segment and add its modality embedding.
        encoded_segments = []
        for segment in input_data.get('segments', []):
            encoded = self._encode_modality(segment)

            # Broadcast a single learned modality vector over (batch, seq).
            modality_id = segment.get('modality_id', 0)
            modality_embeds = self.modality_embedding(
                torch.tensor([modality_id], device=device)
            ).expand(encoded.shape[0], encoded.shape[1], -1)

            encoded_segments.append({
                'type': segment['type'],
                'data': encoded + modality_embeds,
                'modality_id': modality_id
            })

        # Cross-modal fusion only makes sense with 2+ segments; otherwise
        # fall through to plain concatenation below.
        contrastive_losses = {}
        if self.use_multimodal_fusion and len(encoded_segments) > 1:
            fusion_output = self.fusion_module(
                encoded_segments,
                compute_contrastive=compute_contrastive
            )
            x = fusion_output['fused_features']
            contrastive_losses = fusion_output.get('contrastive_losses', {})
        else:
            # Simple concatenation along the sequence axis. With no
            # segments at all, substitute a single zero token (batch 1) so
            # the decoder stack still runs.
            all_embeddings = [seg['data'] for seg in encoded_segments]
            x = torch.cat(all_embeddings, dim=1) if all_embeddings else torch.zeros(
                1, 1, self.model_dim, device=device
            )

        x = self.embed_dropout(x)
        # Derive position ids from the KV-cache length when the caller did
        # not supply them.
        if position_ids is None:
            if past_key_values is not None:
                # Cached length (KV cache shape is [B, H, SeqLen, D]).
                past_length = past_key_values[0][0].size(2)
                # Length of the current input chunk.
                seq_length = x.shape[1]
                # Positions continue where the cache ends:
                # [past_length, past_length + 1, ...]
                position_ids = torch.arange(
                    past_length, past_length + seq_length, dtype=torch.long, device=device
                ).unsqueeze(0).expand(x.shape[0], -1)
            else:
                # No cache: positions start at 0.
                seq_length = x.shape[1]
                position_ids = torch.arange(
                    0, seq_length, dtype=torch.long, device=device
                ).unsqueeze(0).expand(x.shape[0], -1)
        # Decoder stack.
        present_key_values = [] if use_cache else None
        all_hidden_states = [] if output_hidden_states else None
        all_attentions = [] if output_attentions else None
        moe_aux_loss = torch.tensor(0.0, device=device)

        for idx, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states.append(x)

            past_kv = past_key_values[idx] if past_key_values is not None else None

            if self.use_gradient_checkpointing and self.training:
                # Checkpointed path: no KV cache and no attention outputs —
                # both are forced off inside the wrapped call.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(
                            inputs[0],
                            attention_mask=inputs[1],
                            position_ids=inputs[2],
                            use_cache=False,
                            past_kv=None,
                            output_attentions=False
                        )
                    return custom_forward

                import torch.utils.checkpoint as checkpoint
                layer_outputs = checkpoint.checkpoint(
                    create_custom_forward(layer),
                    x,
                    attention_mask,
                    position_ids,
                    use_reentrant=False
                )
                x = layer_outputs[0]
                present_kv = None
                attn_weights = None
            else:
                layer_outputs = layer(
                    x,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    use_cache=use_cache,
                    past_kv=past_kv,
                    output_attentions=output_attentions
                )
                x, present_kv, attn_weights = layer_outputs

            if use_cache:
                present_key_values.append(present_kv)

            if output_attentions:
                all_attentions.append(attn_weights)

            # Accumulate MoE load-balancing loss from layers that expose it.
            if hasattr(layer, 'moe_aux_loss'):
                moe_aux_loss += layer.moe_aux_loss

        hidden_states = self.norm(x)
        logits = self.lm_head(hidden_states)

        if output_hidden_states:
            all_hidden_states.append(hidden_states)

        # Assemble the output dict.
        outputs = {
            'logits': logits,
            'moe_aux_loss': moe_aux_loss,
            'contrastive_losses': contrastive_losses
        }

        if use_cache:
            outputs['past_key_values'] = present_key_values

        if output_hidden_states:
            outputs['hidden_states'] = all_hidden_states

        if output_attentions:
            outputs['attentions'] = all_attentions

        if return_hidden:
            outputs['last_hidden_state'] = hidden_states

        return outputs
355
+
356
+ @torch.no_grad()
357
+ def generate(
358
+ self,
359
+ input_data: Dict,
360
+ max_new_tokens: int = 100,
361
+ temperature: float = 1.0,
362
+ top_k: int = 50,
363
+ top_p: float = 0.9,
364
+ eos_token_id: int = 2,
365
+ pad_token_id: Optional[int] = None,
366
+ use_cache: bool = True,
367
+ repetition_penalty: float = 1.0,
368
+ length_penalty: float = 1.0,
369
+ min_length: int = 0,
370
+ do_sample: bool = True,
371
+ num_beams: int = 1
372
+ ) -> torch.Tensor:
373
+ """改进的生成方法"""
374
+ self.eval()
375
+ device = next(self.parameters()).device
376
+
377
+ if pad_token_id is None:
378
+ pad_token_id = eos_token_id
379
+
380
+ initial_text_tokens = input_data['segments'][0]['data'].to(device)
381
+ batch_size = initial_text_tokens.shape[0]
382
+
383
+ if 'attention_mask' in input_data:
384
+ attention_mask = input_data['attention_mask'].to(device)
385
+ else:
386
+ attention_mask = torch.ones_like(initial_text_tokens)
387
+ initial_seq_len = initial_text_tokens.shape[1]
388
+ position_ids = torch.zeros((batch_size,initial_seq_len),dtype=torch.long,device=device)
389
+
390
+ for i in range(batch_size):
391
+ non_pad_mask = attention_mask[i].bool()
392
+ if non_pad_mask.any():
393
+ positions = torch.cumsum(non_pad_mask.long(),dim=0) -1
394
+ position_ids[i]=positions * non_pad_mask.long()
395
+
396
+
397
+
398
+ generated_tokens = []
399
+ past_key_values = None
400
+ current_tokens = initial_text_tokens
401
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
402
+
403
+ for step in range(max_new_tokens):
404
+ current_input_data = {
405
+ 'segments': [{'type': 'text', 'data': current_tokens, 'modality_id': 0}]
406
+ }
407
+
408
+ if step > 0 and use_cache:
409
+ # 添加当前 token 的 mask (1)
410
+ new_mask = torch.ones(batch_size,1,dtype=torch.long,device=device)
411
+ attention_mask = torch.cat([attention_mask, new_mask], dim=1)
412
+ current_positions = (attention_mask.sum(dim=1 , keepdim=True) -1).clamp(min=0)
413
+ current_positions_ids=current_positions
414
+ else:
415
+ current_positions_ids=position_ids
416
+ outputs = self.forward(
417
+ current_input_data,
418
+ attention_mask=attention_mask, # <--- 传入 Mask
419
+ position_ids=current_positions_ids,
420
+ use_cache=use_cache,
421
+ past_key_values=past_key_values
422
+ )
423
+
424
+ logits = outputs['logits']
425
+ if use_cache:
426
+ past_key_values = outputs['past_key_values']
427
+
428
+ next_token_logits = logits[:, -1, :] / max(temperature, 1e-5)
429
+
430
+ # Repetition penalty
431
+ if repetition_penalty != 1.0 and len(generated_tokens) > 0:
432
+ prev_generated = torch.cat(generated_tokens, dim=1)
433
+ score = torch.gather(next_token_logits, 1, prev_generated)
434
+ score = torch.where(
435
+ score < 0,
436
+ score * repetition_penalty,
437
+ score / repetition_penalty
438
+ )
439
+ next_token_logits.scatter_(1, prev_generated, score)
440
+
441
+ # Min length constraint
442
+ if step < min_length:
443
+ next_token_logits[:, eos_token_id] = float('-inf')
444
+
445
+ # Sampling
446
+ if do_sample:
447
+ if top_k > 0:
448
+ top_k_vals, _ = torch.topk(next_token_logits, top_k)
449
+ min_val_to_keep = top_k_vals[:, -1].unsqueeze(-1)
450
+ next_token_logits[next_token_logits < min_val_to_keep] = float('-inf')
451
+
452
+ if top_p < 1.0:
453
+ sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
454
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
455
+ sorted_indices_to_remove = cumulative_probs > top_p
456
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
457
+ sorted_indices_to_remove[..., 0] = 0
458
+ indices_to_remove = torch.zeros_like(next_token_logits, dtype=torch.bool)
459
+ indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
460
+ next_token_logits[indices_to_remove] = float('-inf')
461
+
462
+ probs = F.softmax(next_token_logits, dim=-1)
463
+ next_token = torch.multinomial(probs, num_samples=1)
464
+ else:
465
+ next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
466
+
467
+ # Apply unfinished mask
468
+ next_token = next_token * unfinished_sequences[:, None] + pad_token_id * (1 - unfinished_sequences[:, None])
469
+
470
+ generated_tokens.append(next_token)
471
+
472
+ if not use_cache:
473
+ initial_text_tokens = torch.cat([initial_text_tokens, next_token], dim=1)
474
+ current_tokens = initial_text_tokens
475
+ else:
476
+ current_tokens = next_token
477
+
478
+ # Update unfinished sequences
479
+ unfinished_sequences = unfinished_sequences.mul(
480
+ (next_token.squeeze(-1) != eos_token_id).long()
481
+ )
482
+
483
+ if unfinished_sequences.max() == 0:
484
+ break
485
+
486
+ if not generated_tokens:
487
+ return torch.empty(batch_size, 0, dtype=torch.long, device=device)
488
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
489
  return torch.cat(generated_tokens, dim=1)