szxllm commited on
Commit
cd66851
·
verified ·
1 Parent(s): 4f4d205

Upload 20 files

Browse files
Files changed (20) hide show
  1. components.py +387 -0
  2. continual_learning.py +294 -0
  3. contrastive_learning.py +339 -0
  4. data_augmentation.py +366 -0
  5. data_config.py +292 -0
  6. data_loader.py +832 -0
  7. encoders.py +559 -0
  8. gradio1.py +228 -0
  9. grpo.py +630 -0
  10. infer.py +372 -0
  11. infer_sft.py +407 -0
  12. model.py +505 -0
  13. moe.py +460 -0
  14. multimodel_fusion.py +522 -0
  15. peft_.py +213 -0
  16. post.py +532 -0
  17. posttrain.py +554 -0
  18. pretrain.py +502 -0
  19. reward_model.py +189 -0
  20. transformer.py +335 -0
components.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Tuple, Optional, Union
5
+ import math
6
+
7
class YARNScaling:
    """
    YARN (Yet Another RoPE extensioN) scaling strategy.
    Reference: https://arxiv.org/abs/2309.00071
    """
    @staticmethod
    def compute_yarn_parameters(
        original_max_len: int,
        target_max_len: int=8192,
        dim: int=128,
        base: int = 10000,
        beta_fast: int = 32,
        beta_slow: int = 1,
        alpha: float = 1.0,
        device: Optional[torch.device] = None
    ) -> Tuple[torch.Tensor, float]:
        """Return ``(inv_freq, mscale)`` for YaRN-extended RoPE.

        ``inv_freq`` has ``dim // 2`` entries (one per rotary pair) and
        ``mscale`` is the attention-temperature correction for the
        extension ratio.
        """
        extension_ratio = float(target_max_len) / original_max_len

        # Rotary frequencies come in pairs: exponents 0, 2, ..., dim - 2.
        pair_idx = torch.arange(0, dim, 2, dtype=torch.float32, device=device)

        # Original (unscaled) RoPE frequencies.
        freq_orig = 1.0 / (base ** (pair_idx / dim))

        # No extension requested: plain RoPE, no attention rescaling.
        if extension_ratio <= 1.0:
            return freq_orig, 1.0

        # Position-interpolated frequencies for the extended context.
        freq_interp = 1.0 / (extension_ratio * base ** (pair_idx / dim))

        # YaRN band boundaries: map a wavelength constraint (beta) onto a
        # frequency index, per the paper's ramp construction.
        def band_index(beta):
            return dim * math.log(original_max_len / (2 * math.pi * beta)) / (2 * math.log(base))

        low = max(math.floor(band_index(beta_fast)), 0)
        high = min(math.ceil(band_index(beta_slow)), dim // 2 - 1)

        half_idx = torch.arange(0, dim // 2, dtype=torch.float32, device=device)

        # Indices above `high` (long wavelengths exceeding the trained
        # context) use the interpolated frequencies; indices below `low`
        # (short wavelengths) keep the original ones, protected by
        # rotational invariance.
        inv_freq = torch.where(half_idx > high, freq_interp, freq_orig)

        # Indices in [low, high]: linear ramp between the two regimes.
        ramp_mask = (half_idx >= low) & (half_idx <= high)
        if ramp_mask.any():
            span = max(high - low, 1)  # guard against division by zero
            t = (half_idx[ramp_mask] - low) / span
            inv_freq[ramp_mask] = freq_orig[ramp_mask] * (1 - t) + freq_interp[ramp_mask] * t

        return inv_freq, float(YARNScaling.compute_mscale(extension_ratio, alpha))

    @staticmethod
    def compute_mscale(scale: float, alpha: float = 1.0) -> float:
        """Attention temperature correction for extension factor ``scale``."""
        if scale <= 1.0:
            return 1.0
        # Empirical entropy-correction formula from the YaRN paper.
        return 0.1 * math.log(scale) + 1.0
79
+
80
class YARNRotaryEmbedding(nn.Module):
    """
    Rotary position embedding with YARN context extension.

    Handles precision (rotation always in float32), cache management, and
    position_ids that exceed the preallocated cache. This revision also fixes
    a bug where the constructor's ``beta_fast`` / ``beta_slow`` arguments were
    ignored and hard-coded defaults were used instead.
    """
    def __init__(
        self,
        dim: int = 64,
        max_seq_len: int = 8192,
        original_max_len: int = 4096,
        base: int = 10000,
        scaling_factor: float = 1.0,  # reserved interface; extension is driven by YARN
        beta_fast: int = 32,
        beta_slow: int = 1,
        alpha: float = 1.0,
        rope_percentage: float = 1.0,
        device: Optional[torch.device] = None
    ):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.original_max_len = original_max_len
        self.base = base
        self.alpha = alpha
        # BUGFIX: remember the configured band parameters instead of silently
        # falling back to hard-coded defaults in _init_yarn_frequencies.
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow

        # Number of head dimensions that actually receive RoPE (must be even).
        self.rope_dim = int(dim * rope_percentage)
        if self.rope_dim % 2 != 0:
            self.rope_dim -= 1

        # Persistent state: YARN inverse frequencies and attention scale.
        self._init_yarn_frequencies(device)

        # Transient cos/sin caches; persistent=False keeps them out of
        # state_dict and therefore shrinks checkpoints.
        self.register_buffer("cos_cached", None, persistent=False)
        self.register_buffer("sin_cached", None, persistent=False)

    def _init_yarn_frequencies(self, device: Optional[torch.device] = None):
        """Compute and register the YARN frequencies and mscale buffers."""
        inv_freq, mscale = YARNScaling.compute_yarn_parameters(
            self.original_max_len,
            self.max_seq_len,
            self.rope_dim,
            self.base,
            beta_fast=self.beta_fast,  # was hard-coded to 32
            beta_slow=self.beta_slow,  # was hard-coded to 1
            alpha=self.alpha,
            device=device
        )
        self.register_buffer("inv_freq", inv_freq, persistent=True)
        self.register_buffer("mscale", torch.tensor(mscale, dtype=torch.float32, device=device), persistent=True)

    def _compute_cos_sin_cache(
        self,
        needed_len: int,
        device: torch.device,
        dtype: torch.dtype
    ):
        """(Re)build the cos/sin cache; angles are always computed in float32."""
        # Allocate at least max_seq_len; grow when needed_len is larger.
        alloc_len = max(needed_len, self.max_seq_len)

        # Reuse the existing cache when it is large enough and on-device.
        if (self.cos_cached is not None and
                self.cos_cached.shape[2] >= alloc_len and
                self.cos_cached.device == device):
            return

        t = torch.arange(alloc_len, dtype=torch.float32, device=device)

        # Outer product t[i] * inv_freq[j] -> [alloc_len, rope_dim // 2]
        freqs = torch.outer(t, self.inv_freq.to(device))

        # Duplicate to match rotate_half's layout: [theta_0.., theta_0..].
        emb = torch.cat((freqs, freqs), dim=-1)

        # Apply mscale; reshape to [1, 1, alloc_len, rope_dim] for broadcasting.
        cos_cached = (emb.cos() * self.mscale).view(1, 1, alloc_len, self.rope_dim)
        sin_cached = (emb.sin() * self.mscale).view(1, 1, alloc_len, self.rope_dim)

        # The cache may live in half precision to save memory; the actual
        # rotation below upcasts to float32 anyway.
        self.cos_cached = cos_cached.to(dtype)
        self.sin_cached = sin_cached.to(dtype)

    @staticmethod
    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        """Rotate the last dimension: split [x1, x2] -> return [-x2, x1]."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE to q and k ([B, H, S, D]) with bounds-checked caching."""
        bsz, num_heads, seq_len, head_dim = q.shape

        # 1. The cache must cover the largest referenced position index.
        if position_ids is not None:
            max_pos = position_ids.max().item() + 1
            needed_len = max(max_pos, seq_len)
        else:
            needed_len = seq_len

        # 2. Rebuild the cache if missing, too small, or on the wrong device.
        if (self.cos_cached is None or
                self.cos_cached.shape[2] < needed_len or
                self.cos_cached.device != q.device):
            self._compute_cos_sin_cache(needed_len, q.device, q.dtype)

        # 3. Select cos/sin; cos_cached is [1, 1, alloc_len, dim].
        if position_ids is not None:
            # Gather per-position rows -> [bs, 1, seq_len, dim].
            # Note: cos_cached[0, 0] has shape [alloc_len, dim].
            cos = self.cos_cached[0, 0][position_ids].unsqueeze(1)
            sin = self.sin_cached[0, 0][position_ids].unsqueeze(1)
        else:
            # Positions assumed to start at 0.
            cos = self.cos_cached[:, :, :seq_len, :]
            sin = self.sin_cached[:, :, :seq_len, :]

        # 4. Partial RoPE: only the first rope_dim channels get rotated.
        if self.rope_dim < head_dim:
            q_rot = q[..., :self.rope_dim]
            q_pass = q[..., self.rope_dim:]
            k_rot = k[..., :self.rope_dim]
            k_pass = k[..., self.rope_dim:]
        else:
            q_rot = q
            k_rot = k
            q_pass = None
            k_pass = None

        # 5. Rotate in float32 to avoid precision loss in fp16/bf16.
        q_rot_float = q_rot.float()
        k_rot_float = k_rot.float()
        cos_float = cos.float()
        sin_float = sin.float()

        q_embed = (q_rot_float * cos_float) + (self.rotate_half(q_rot_float) * sin_float)
        k_embed = (k_rot_float * cos_float) + (self.rotate_half(k_rot_float) * sin_float)

        # 6. Cast back to the callers' original dtypes.
        q_embed = q_embed.type_as(q)
        k_embed = k_embed.type_as(k)

        if q_pass is not None:
            q_embed = torch.cat([q_embed, q_pass], dim=-1)
            k_embed = torch.cat([k_embed, k_pass], dim=-1)

        return q_embed, k_embed

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.apply_rotary_pos_emb(q, k, position_ids)

    def extra_repr(self) -> str:
        return (f"dim={self.dim}, rope_dim={self.rope_dim}, "
                f"max_seq_len={self.max_seq_len}, original_max_len={self.original_max_len}, "
                f"base={self.base}")
256
+
257
class RMSNorm(nn.Module):
    """
    Root Mean Square layer normalization.

    Statistics are computed in float32 for numerical stability regardless of
    the input dtype, then cast back before the optional learned gain.
    """
    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        elementwise_affine: bool = True
    ):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter('weight', None)

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # x / rms(x), via rsqrt of the mean square plus eps.
        mean_sq = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32, then return to the input dtype.
        y = self._norm(x.float()).type_as(x)

        # Optional per-channel learned gain.
        if self.elementwise_affine and self.weight is not None:
            y = y * self.weight

        return y
293
+
294
class QKNorm(nn.Module):
    """
    Query-Key normalization (ViT-22B / Scaling Transformer style).

    RMS-normalizes queries and keys independently to keep attention logits
    in a stable range.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.query_norm = RMSNorm(dim, eps=eps)
        self.key_norm = RMSNorm(dim, eps=eps)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (normalized_q, normalized_k)."""
        return self.query_norm(q), self.key_norm(k)
312
+
313
class SwiGLU(nn.Module):
    """
    SwiGLU feed-forward block: down(SiLU(gate(x)) * up(x)).

    The hidden size defaults to the LLaMA heuristic (2/3 of 4*dim) and is
    always rounded up to a multiple of ``multiple_of`` for GPU efficiency.
    """
    def __init__(
        self,
        dim: int,
        hidden_dim: Optional[int] = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        dropout: float = 0.0,
        bias: bool = False
    ):
        super().__init__()

        if hidden_dim is None:
            if ffn_dim_multiplier is not None:
                hidden_dim = int(dim * ffn_dim_multiplier)
            else:
                # LLaMA default: 2/3 * 4 * dim = 8/3 * dim.
                hidden_dim = int(2 * dim * 4 / 3)

        # Round up to the next multiple of `multiple_of`.
        hidden_dim = ((hidden_dim + multiple_of - 1) // multiple_of) * multiple_of
        self.hidden_dim = hidden_dim

        # w1: gate, w3: up, w2: down (standard LLaMA naming).
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)
        self.w3 = nn.Linear(dim, hidden_dim, bias=bias)
        self.dropout = nn.Identity() if dropout <= 0 else nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU(x) = W2(SiLU(W1 x) * (W3 x))
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.dropout(self.w2(gate * up))
350
+
351
class ParallelAttentionFFN(nn.Module):
    """
    Parallel attention + feed-forward block (PaLM / GPT-J style):

        y = x + Attention(Norm_a(x)) + MLP(Norm_f(x))

    Two independent RMSNorms are kept (some architectures, e.g. PaLM, share
    one) for flexibility with variants that expect separate norms.
    """
    def __init__(
        self,
        dim: int,
        attn_module: nn.Module,
        ffn_module: nn.Module,
        norm_eps: float = 1e-6
    ):
        super().__init__()
        self.attn_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)
        self.attn = attn_module
        self.ffn = ffn_module

    def forward(
        self,
        x: torch.Tensor,
        **attn_kwargs
    ) -> torch.Tensor:
        """Both branches fork from the same input; attn-specific kwargs are
        forwarded only to the attention module."""
        attn_branch = self.attn(self.attn_norm(x), **attn_kwargs)
        ffn_branch = self.ffn(self.ffn_norm(x))
        # Single residual sum over both branches.
        return x + attn_branch + ffn_branch
continual_learning.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 持续学习模块
3
+ 支持EWC和经验回放
4
+ 修复版本:适配 MultiModalDenseTransformer 和 data_loader.py
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ from torch.utils.data import DataLoader
11
+ from collections import deque
12
+ from typing import List, Dict, Any, Optional, Union
13
+ from tqdm import tqdm
14
+ from dataclasses import dataclass
15
+
16
+ # 假设 model.py 中已有定义,用于类型提示
17
+ # from model import MultiModalDenseTransformer
18
+
19
@dataclass
class ModalityConfig:
    """Pairs a modality name with the integer id used in segment dicts."""
    name: str  # e.g. 'text', 'image', 'audio', 'video'
    modality_id: int  # value stored under 'modality_id' in each segment
23
+
24
class UnifiedMultiModalPreprocessor(nn.Module):
    """
    Unified multimodal preprocessor.

    Only formats raw batch data into the 'segments' structure accepted by
    MultiModalDenseTransformer. It deliberately contains no encoders:
    encoding stays inside the model so that EWC can observe gradients on the
    model's own parameters.
    """
    def __init__(self, model_dim: int = 2048):
        super().__init__()
        # Fixed name -> id mapping: text=0, image=1, audio=2, video=3.
        self.modality_configs = {
            name: ModalityConfig(name, idx)
            for idx, name in enumerate(('text', 'image', 'audio', 'video'))
        }

    def process_batch(self, batch_data: Union[torch.Tensor, List[Any]], modality_type: str) -> List[Dict]:
        """Wrap one modality's batch data as a single segment dict.

        Returns an empty list for unknown modalities or unusable data.
        """
        config = self.modality_configs.get(modality_type)
        if config is None:
            return []

        if isinstance(batch_data, torch.Tensor):
            stacked = batch_data
        elif isinstance(batch_data, list):
            kept = [item for item in batch_data if item is not None]
            if not kept:
                return []
            # Assume homogeneous tensors, e.g. list of (C, H, W) -> (B, C, H, W).
            try:
                stacked = torch.stack(kept)
            except Exception as e:
                print(f"Error stacking modality data: {e}")
                return []
        else:
            return []

        return [{
            'type': modality_type,
            'data': stacked,  # raw data (e.g. pixels); the model encodes it
            'modality_id': config.modality_id
        }]
73
+
74
+
75
class ExperienceReplayBuffer:
    """Experience replay buffer that never pins GPU memory."""
    def __init__(self, max_size: int = 10000):
        # deque(maxlen=...) silently evicts the oldest entries when full.
        self.buffer = deque(maxlen=max_size)

    def add(self, sample: Dict[str, Any]):
        """
        Store one sample.

        Tensors (including tensors nested inside lists) are detached and moved
        to the CPU first, so the buffer cannot leak GPU memory or keep
        autograd graphs alive.
        """
        def to_cpu(value):
            if isinstance(value, torch.Tensor):
                return value.detach().cpu()
            return value

        stored = {}
        for key, value in sample.items():
            if isinstance(value, list):
                stored[key] = [to_cpu(item) for item in value]
            else:
                stored[key] = to_cpu(value)
        self.buffer.append(stored)

    def sample(self, batch_size: int) -> List[Any]:
        """Draw up to ``batch_size`` distinct samples uniformly at random."""
        if not self.buffer:
            return []

        picked = np.random.choice(
            len(self.buffer),
            min(len(self.buffer), batch_size),
            replace=False
        )
        return [self.buffer[i] for i in picked]

    def __len__(self):
        return len(self.buffer)

    def clear(self):
        """Drop every stored sample."""
        self.buffer.clear()
114
+
115
+
116
class EWC:
    """Elastic Weight Consolidation regularizer.

    Snapshots the model's trainable parameters and estimates their importance
    (the diagonal empirical Fisher information) from one pass over
    ``dataloader``; ``penalty()`` then returns a quadratic loss that anchors
    the model to that snapshot, weighted per-parameter by the Fisher values.

    NOTE(review): assumes the model accepts a ``{'segments': [...]}`` input
    and returns a dict with a ``'logits'`` key, and that each batch carries
    ``'instruction'`` / ``'response'`` token-id tensors — confirm against
    model.py / data_loader.py.
    """
    def __init__(
        self,
        model: nn.Module,
        dataloader: DataLoader,
        preprocessor: UnifiedMultiModalPreprocessor,
        importance: float = 1000.0
    ):
        self.model = model
        self.preprocessor = preprocessor
        self.importance = importance
        self.device = next(model.parameters()).device

        # Freeze a copy of the current parameters as the reference point.
        self.params = {
            n: p.clone().detach()
            for n, p in model.named_parameters()
            if p.requires_grad
        }

        self.fisher = self._compute_fisher(dataloader)

    def _compute_fisher(self, dataloader: DataLoader) -> Dict[str, torch.Tensor]:
        """Estimate the diagonal Fisher information matrix (empirical Fisher)."""
        fisher = {
            n: torch.zeros_like(p)
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }

        self.model.eval()
        num_samples = 0

        # tqdm keeps the console output compact.
        pbar = tqdm(dataloader, desc="Computing Fisher Matrix", leave=False)
        for batch in pbar:
            if batch is None: continue

            self.model.zero_grad()

            # 1. Prepare text input.
            instruction_ids = batch['instruction'].to(self.device)
            response_ids = batch['response'].to(self.device)
            # Concatenate: [Instruction, Response]
            input_ids = torch.cat([instruction_ids, response_ids], dim=1)

            # 2. Prepare the multimodal input structure.
            input_data = {'segments': []}

            # Handle extra modality data, if any.
            # batch['modality_data'] may be a list (produced by collate_fn_v2).
            raw_modality_data = batch.get('modality_data')
            if raw_modality_data is not None:
                # Infer the modality type; default to 'image' when the dataset
                # does not specify one. Ideally the dataset should return an
                # explicit 'modality_type'.
                modality_type = batch.get('modality_type', 'image')
                if isinstance(modality_type, list): modality_type = modality_type[0]

                # The preprocessor stacks and formats the raw data.
                mod_segments = self.preprocessor.process_batch(raw_modality_data, modality_type)
                # Only move to device when the data is valid.
                for seg in mod_segments:
                    seg['data'] = seg['data'].to(self.device)
                    input_data['segments'].append(seg)

            # Append the text segment.
            input_data['segments'].append({
                'type': 'text',
                'data': input_ids,
                'modality_id': 0
            })

            # 3. Forward pass.
            output = self.model(input_data)
            logits = output['logits']  # (B, Seq_Len, Vocab)

            # 4. Loss (standard causal LM loss).
            # Shift logits and labels:
            #   input_ids: [I1, I2, R1, R2]
            #   labels:    [I2, R1, R2, EOS]
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()

            # Build a mask so gradients only flow through the response tokens.
            # Instruction length:
            inst_len = instruction_ids.shape[1]
            loss_mask = torch.ones_like(shift_labels, dtype=torch.float)
            if inst_len > 1:
                # Mask out the instruction part (note the shift offset).
                loss_mask[:, :inst_len-1] = 0.0

            # Per-token loss.
            loss_fct = nn.CrossEntropyLoss(reduction='none')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            # Apply the mask and average (epsilon avoids division by zero).
            loss = (loss * loss_mask.view(-1)).sum() / (loss_mask.sum() + 1e-6)

            # 5. Backprop, then accumulate squared gradients.
            loss.backward()

            for n, p in self.model.named_parameters():
                if p.grad is not None and n in fisher:
                    fisher[n] += p.grad.detach() ** 2

            num_samples += input_ids.size(0)

        # Normalize by the number of samples seen.
        if num_samples > 0:
            for n in fisher:
                fisher[n] /= num_samples

        self.model.train()
        return fisher

    def penalty(self, model: Optional[nn.Module] = None) -> torch.Tensor:
        """Return the EWC penalty term (importance-weighted quadratic anchor)."""
        # Compatibility: prefer an explicitly passed model (normally the same
        # object as self.model).
        target_model = model if model is not None else self.model

        loss = torch.tensor(0.0, device=self.device)

        for n, p in target_model.named_parameters():
            if n in self.params and p.requires_grad:
                if n in self.fisher:
                    loss += (self.fisher[n] * (p - self.params[n]) ** 2).sum()

        return self.importance * loss
245
+
246
+
247
class OnlineEWC(EWC):
    """Online EWC — the Fisher matrix is updated continually across tasks."""
    def __init__(
        self,
        model: nn.Module,
        preprocessor: UnifiedMultiModalPreprocessor,
        importance: float = 1000.0,
        gamma: float = 0.9
    ):
        # No Fisher computation at construction time; wait for update_fisher.
        # (Intentionally does NOT call super().__init__, which would require
        # a dataloader and run a full Fisher pass immediately.)
        self.model = model
        self.preprocessor = preprocessor
        self.importance = importance
        self.gamma = gamma  # decay factor for the running Fisher sum
        self.device = next(model.parameters()).device

        self.params = {}
        self.fisher = {}
        self.task_count = 0

    def update_fisher(self, dataloader: DataLoader):
        """Recompute the Fisher matrix after a task and fold it into the running estimate."""
        print(f"Updating Online EWC Fisher Matrix (Task {self.task_count + 1})...")
        new_fisher = self._compute_fisher(dataloader)

        if self.task_count == 0:
            self.fisher = new_fisher
        else:
            for n in self.fisher:
                if n in new_fisher:
                    # Decayed running sum of Fisher information.
                    self.fisher[n] = self.gamma * self.fisher[n] + new_fisher[n]

        # Re-anchor the reference parameters to the post-task weights.
        self.params = {
            n: p.clone().detach()
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }

        self.task_count += 1
        print(f"Online EWC regularizer updated.")

    def penalty(self, model: Optional[nn.Module] = None) -> torch.Tensor:
        """Return the EWC penalty; zero before the first Fisher update."""
        if self.task_count == 0:
            return torch.tensor(0.0, device=self.device)
        return super().penalty(model)
contrastive_learning.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Dict, Optional, Tuple, Union, Literal, List
5
+ import math
6
+ import copy
7
+
8
class CLIPLoss(nn.Module):
    """Symmetric CLIP-style contrastive loss with a learnable logit scale."""
    def __init__(self, temperature: float = 0.07, max_temperature: float = 100.0):
        super().__init__()
        self.temperature = temperature
        self.max_temperature = max_temperature
        # Learnable inverse temperature, stored in log space.
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / temperature))

    def forward(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute (total, image->text, text->image) cross-entropy losses.

        Args:
            image_features: [B, D]
            text_features: [B, D]

        Note: under DDP this is the local-batch loss only; a full
        implementation would gather features across all ranks first.
        """
        img = F.normalize(image_features, dim=-1)
        txt = F.normalize(text_features, dim=-1)

        # Clamp the exponentiated scale for numerical stability.
        scale = self.logit_scale.exp().clamp(max=self.max_temperature)

        # Similarity matrix [B, B]; transposing gives the other direction.
        sim_i2t = scale * img @ txt.T
        sim_t2i = sim_i2t.T

        # Matching pairs sit on the diagonal.
        targets = torch.arange(img.shape[0], device=img.device)

        loss_i2t = F.cross_entropy(sim_i2t, targets)
        loss_t2i = F.cross_entropy(sim_t2i, targets)

        return (loss_i2t + loss_t2i) / 2, loss_i2t, loss_t2i
51
+
52
class SigLIPLoss(nn.Module):
    """
    SigLIP loss with learnable temperature and bias.
    Paper: Sigmoid Loss for Language Image Pre-Training

    Note: SigLIP's standard form converges without gathering global
    negatives, but this is the dense pairwise version — for very large
    batches (8k+) the [B, B] label matrix becomes a memory problem;
    production code should chunk or use a custom kernel.
    """
    def __init__(self, init_temperature: float = 1.0, init_bias: float = -10.0):
        super().__init__()
        # Temperature in log space; bias initialized strongly negative.
        self.t_prime = nn.Parameter(torch.tensor(math.log(init_temperature)))
        self.b = nn.Parameter(torch.tensor(init_bias))

    def forward(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor
    ) -> torch.Tensor:
        """Pairwise sigmoid loss over all image-text pairs in the batch."""
        img = F.normalize(image_features, dim=-1)
        txt = F.normalize(text_features, dim=-1)

        n = img.shape[0]

        # logits = exp(t) * (x yT) + b
        logits = img @ txt.T * self.t_prime.exp() + self.b

        # +1 on the diagonal (positives), -1 everywhere else.
        targets = 2 * torch.eye(n, device=img.device) - torch.ones(n, n, device=img.device)

        # -log sigmoid(target * logit):
        #   target=+1 -> -log sigmoid(z);  target=-1 -> -log(1 - sigmoid(z))
        # i.e. binary cross entropy, summed; normalized by batch size as the
        # SigLIP paper recommends.
        return -F.logsigmoid(targets * logits).sum() / n
94
+
95
class InfoNCELoss(nn.Module):
    """InfoNCE loss supporting explicit negatives or in-batch negatives."""
    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        query: torch.Tensor,
        positive_key: torch.Tensor,
        negative_keys: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            query: [B, D]
            positive_key: [B, D]
            negative_keys: [B, N, D], or None to use in-batch negatives.
        """
        q = F.normalize(query, dim=-1)
        pos = F.normalize(positive_key, dim=-1)

        if negative_keys is None:
            # In-batch negatives: one-directional CLIP-style loss.
            scores = q @ pos.T / self.temperature
            targets = torch.arange(q.shape[0], dtype=torch.long, device=q.device)
        else:
            neg = F.normalize(negative_keys, dim=-1)
            # Positive similarity per row -> [B, 1].
            pos_scores = (q * pos).sum(dim=-1, keepdim=True) / self.temperature
            # Negative similarities -> [B, N].
            neg_scores = torch.einsum('bd,bnd->bn', q, neg) / self.temperature
            # Positive sits at column 0.
            scores = torch.cat([pos_scores, neg_scores], dim=1)
            targets = torch.zeros(q.shape[0], dtype=torch.long, device=q.device)

        return F.cross_entropy(scores, targets)
+ return loss
136
+
137
class ProjectionHead(nn.Module):
    """
    Projection head: pools sequence features (optionally) and projects them.

    Accepts both [B, D] and [B, Seq, D] inputs; for 3D inputs the pooling
    strategy decides how the sequence axis is collapsed before the MLP.
    """
    def __init__(
        self,
        input_dim: int,
        embed_dim: int,
        pooling_type: Literal['cls', 'mean', 'max', 'none'] = 'mean',
        exclude_first_token: bool = False
    ):
        super().__init__()
        self.pooling_type = pooling_type
        self.exclude_first_token = exclude_first_token

        self.net = nn.Sequential(
            nn.Linear(input_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim)
        )

    def _pool(self, x: torch.Tensor) -> torch.Tensor:
        # Collapse [B, Seq, D] -> [B, D] ('none' keeps the sequence axis).
        mode = self.pooling_type
        if mode == 'cls':
            # Index 0 is assumed to be the CLS token (ViT / BERT convention).
            return x[:, 0, :]

        drop_first = self.exclude_first_token and x.shape[1] > 1
        if mode == 'mean':
            # For ViT-style outputs, mean pooling usually excludes CLS.
            return x[:, 1:, :].mean(dim=1) if drop_first else x.mean(dim=1)
        if mode == 'max':
            return x[:, 1:, :].max(dim=1)[0] if drop_first else x.max(dim=1)[0]

        # 'none' (or any unrecognized value): keep [B, Seq, D] for dense
        # prediction / fine-grained contrast.
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() == 3:
            x = self._pool(x)
        return self.net(x)
185
+
186
class MultiModalContrastiveLoss(nn.Module):
    """
    Multimodal contrastive loss supporting dynamic modality sets and
    heterogeneous input dimensions, with one projection head per modality.
    """
    def __init__(
        self,
        embed_dim: int = 512,
        input_dims: Union[int, Dict[str, int]] = 2048,
        temperature: float = 0.07,
        loss_type: str = 'clip',
        modality_config: Optional[Dict[str, str]] = None
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.loss_type = loss_type

        # Pick the pairwise loss; anything other than clip/siglip falls back
        # to InfoNCE.
        if loss_type == 'clip':
            self.loss_fn = CLIPLoss(temperature)
        elif loss_type == 'siglip':
            self.loss_fn = SigLIPLoss()
        else:
            self.loss_fn = InfoNCELoss(temperature)

        if modality_config is None:
            # Sensible defaults. ImprovedVisionTransformer outputs include a
            # CLS token, so image/text default to 'cls'; audio/video streams
            # are pooled with 'mean'.
            modality_config = {
                'text': 'cls',
                'image': 'cls',
                'audio': 'mean',
                'video': 'mean'
            }
        self.modality_config = modality_config

        self.projectors = nn.ModuleDict()
        for modality, pooling in modality_config.items():
            if isinstance(input_dims, dict):
                in_dim = input_dims.get(modality)
                # No dim given for this modality: skip instead of crashing.
                if in_dim is None:
                    continue
            else:
                in_dim = input_dims

            # Heuristic: mean/max pooling on image/text should skip the CLS
            # token; users can override by passing their own config.
            skip_cls = pooling in ('mean', 'max') and modality in ('image', 'text')

            self.projectors[modality] = ProjectionHead(
                input_dim=in_dim,
                embed_dim=embed_dim,
                pooling_type=pooling,
                exclude_first_token=skip_cls
            )

    def forward(
        self,
        features: Dict[str, torch.Tensor],
        modality_pairs: Optional[List[Tuple[str, str]]] = None
    ) -> Dict[str, torch.Tensor]:
        """Return a dict of '<a>_<b>_loss' tensors, one per usable pair."""
        if modality_pairs is None:
            # Default pairing: every non-text modality against text.
            if 'text' not in features:
                return {}
            modality_pairs = [(m, 'text') for m in features if m != 'text']

        losses = {}
        for mod_a, mod_b in modality_pairs:
            usable = (
                mod_a in features and mod_b in features
                and mod_a in self.projectors and mod_b in self.projectors
            )
            if not usable:
                # Missing features or projector for this pair; skip it.
                continue

            emb_a = self.projectors[mod_a](features[mod_a])
            emb_b = self.projectors[mod_b](features[mod_b])

            # CLIPLoss returns a (total, i2t, t2i) tuple; keep only the total.
            if self.loss_type == 'clip':
                pair_loss = self.loss_fn(emb_a, emb_b)[0]
            else:
                pair_loss = self.loss_fn(emb_a, emb_b)

            losses[f'{mod_a}_{mod_b}_loss'] = pair_loss

        return losses
284
+
285
class MomentumEncoder(nn.Module):
    """
    MoCo-style momentum (EMA) twin of an encoder.

    Keeps a frozen deep copy of ``encoder`` whose parameters track the live
    encoder via an exponential moving average; buffers (e.g. BatchNorm
    running stats) are copied verbatim.
    """
    def __init__(self, encoder: nn.Module, momentum: float = 0.999):
        super().__init__()
        self.encoder = encoder
        self.momentum_encoder = self._build_momentum_encoder(encoder)
        self.momentum = momentum

    def _build_momentum_encoder(self, encoder: nn.Module) -> nn.Module:
        """Deep-copy the encoder and freeze its parameters."""
        frozen = copy.deepcopy(encoder)
        for p in frozen.parameters():
            p.requires_grad = False
        return frozen

    @torch.no_grad()
    def _update_momentum_encoder(self):
        """EMA-update the key encoder in place: k = m*k + (1-m)*q."""
        for pq, pk in zip(
            self.encoder.parameters(),
            self.momentum_encoder.parameters()
        ):
            pk.data.mul_(self.momentum).add_(pq.data, alpha=1.0 - self.momentum)

        # Buffers (e.g. BatchNorm running mean/var) are simply copied over;
        # the key encoder runs in eval mode so it never tracks batch stats
        # itself.
        for bq, bk in zip(
            self.encoder.buffers(),
            self.momentum_encoder.buffers()
        ):
            bk.data.copy_(bq.data)

    def forward(self, x: torch.Tensor, use_momentum: bool = False) -> torch.Tensor:
        """
        Args:
            x: input batch
            use_momentum: when True, first EMA-update the momentum encoder and
                then encode with it (typically to produce keys/targets).
        """
        if not use_momentum:
            return self.encoder(x)

        with torch.no_grad():
            self._update_momentum_encoder()
        # The momentum encoder always runs in eval mode.
        self.momentum_encoder.eval()
        return self.momentum_encoder(x)
data_augmentation.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 数据增强模块
3
+ 针对不同模态的高级数据增强策略
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from typing import Optional, Tuple, List
9
+ import random
10
+ import math
11
+
12
class RandAugment(nn.Module):
    """RandAugment for image tensors.

    Applies ``n`` randomly chosen photometric operations per call. ``m`` is
    the nominal magnitude knob kept for API compatibility; the individual ops
    currently draw their own random strength.
    """

    def __init__(self, n: int = 2, m: int = 10):
        super().__init__()
        self.n = n
        self.m = m

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``self.n`` randomly sampled augmentations.

        Accepts a single image [C, H, W] or a batch [B, C, H, W] and returns
        the same rank as the input.
        """
        squeeze_back = x.ndim != 4
        if squeeze_back:
            x = x.unsqueeze(0)

        ops = (
            self._auto_contrast,
            self._equalize,
            self._solarize,
            self._color,
            self._contrast,
            self._brightness,
            self._sharpness,
        )

        # Each round draws its op independently (sampling with replacement).
        for _ in range(self.n):
            x = random.choice(ops)(x)

        return x.squeeze(0) if squeeze_back else x

    def _auto_contrast(self, x: torch.Tensor) -> torch.Tensor:
        """Per-sample, per-channel linear stretch to [0, 1]."""
        b, c = x.shape[0], x.shape[1]
        flat = x.reshape(b, c, -1)
        lo = flat.min(dim=2, keepdim=True)[0].reshape(b, c, 1, 1)
        hi = flat.max(dim=2, keepdim=True)[0].reshape(b, c, 1, 1)
        return (x - lo) / (hi - lo + 1e-8)

    def _equalize(self, x: torch.Tensor) -> torch.Tensor:
        """Histogram equalization via a per-channel CDF lookup table."""
        b, c = x.shape[0], x.shape[1]
        # Discretize pixel values to [0, 255] to index the CDF table.
        idx = (x * 255).long().clamp(0, 255)
        out = torch.zeros_like(x)
        for bi in range(b):
            for ci in range(c):
                hist = torch.histc(x[bi, ci].float(), bins=256, min=0, max=1)
                cdf = hist.cumsum(0)
                cdf = cdf / cdf[-1]  # normalize to [0, 1]
                out[bi, ci] = cdf[idx[bi, ci]]
        return out

    def _solarize(self, x: torch.Tensor) -> torch.Tensor:
        """Invert pixels at or above a random threshold."""
        thr = random.uniform(0.3, 0.7)
        return torch.where(x < thr, x, 1.0 - x)

    def _color(self, x: torch.Tensor) -> torch.Tensor:
        """Saturation jitter around the per-pixel channel mean (dim=1)."""
        strength = 1.0 + (random.random() - 0.5) * 0.4
        gray = x.mean(dim=1, keepdim=True)
        return torch.clamp(gray + strength * (x - gray), 0, 1)

    def _contrast(self, x: torch.Tensor) -> torch.Tensor:
        """Contrast jitter around each image's global mean."""
        strength = 1.0 + (random.random() - 0.5) * 0.4
        mean = x.reshape(x.size(0), -1).mean(dim=1).reshape(-1, 1, 1, 1)
        return torch.clamp(mean + strength * (x - mean), 0, 1)

    def _brightness(self, x: torch.Tensor) -> torch.Tensor:
        """Multiplicative brightness jitter."""
        strength = 1.0 + (random.random() - 0.5) * 0.4
        return torch.clamp(x * strength, 0, 1)

    def _sharpness(self, x: torch.Tensor) -> torch.Tensor:
        """Unsharp masking: x + (f - 1) * (x - blur(x)), blur via 3x3 avg pool."""
        strength = 1.0 + (random.random() - 0.5) * 0.4
        blurred = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        return torch.clamp(x + (strength - 1.0) * (x - blurred), 0, 1)
116
+
117
class MixUp(nn.Module):
    """MixUp data augmentation (Zhang et al., 2018).

    Blends each sample with a randomly paired sample of the same batch:
    ``x' = lam * x + (1 - lam) * x[perm]`` with ``lam ~ Beta(alpha, alpha)``.

    Args:
        alpha: Beta-distribution concentration; ``alpha <= 0`` disables mixing
            (lambda fixed to 1.0).
        num_classes: class count used to one-hot integer labels. Strongly
            recommended; if omitted it is inferred from the current batch,
            which may under-count classes absent from that batch.
    """

    def __init__(self, alpha: float = 1.0, num_classes: Optional[int] = None):
        super().__init__()
        self.alpha = alpha
        self.num_classes = num_classes

    def forward(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], float]:
        """Return ``(mixed_x, mixed_y_or_None, lambda)``.

        Integer / 1-D labels are one-hot encoded before mixing; already
        soft labels are mixed as-is.
        """
        if self.alpha > 0:
            lambda_ = random.betavariate(self.alpha, self.alpha)
        else:
            lambda_ = 1.0

        batch_size = x.shape[0]
        index = torch.randperm(batch_size, device=x.device)

        mixed_x = lambda_ * x + (1 - lambda_) * x[index]

        mixed_y = None
        if y is not None:
            y_a = y
            y_b = y[index]

            if y.dtype == torch.long or y.ndim == 1:
                # Fix: infer the class count locally instead of caching it on
                # the module — the old in-place ``self.num_classes`` caching
                # could pin a too-small value from one batch and make
                # F.one_hot fail on later batches containing higher labels.
                num_classes = self.num_classes
                if num_classes is None:
                    num_classes = int(y.max().item()) + 1

                y_a = F.one_hot(y_a, num_classes=num_classes).float()
                y_b = F.one_hot(y_b, num_classes=num_classes).float()

            mixed_y = lambda_ * y_a + (1 - lambda_) * y_b

        return mixed_x, mixed_y, lambda_
158
+
159
class CutMix(nn.Module):
    """CutMix data augmentation (Yun et al., 2019).

    Replaces a random rectangular region of every image with the same region
    from a randomly permuted copy of the batch, then mixes labels by the
    exact area ratio of the pasted patch.

    Args:
        alpha: Beta-distribution concentration; ``alpha <= 0`` disables
            cutting (lambda fixed to 1.0, zero-area box).
        num_classes: class count used to one-hot integer labels. Recommended;
            if omitted it is inferred from the current batch.
    """

    def __init__(self, alpha: float = 1.0, num_classes: Optional[int] = None):
        super().__init__()
        self.alpha = alpha
        self.num_classes = num_classes

    def _rand_bbox(
        self,
        size: Tuple[int, ...],
        lambda_: float
    ) -> Tuple[int, int, int, int]:
        """Sample a random box covering roughly (1 - lambda_) of the image."""
        W = size[-1]  # supports [B, C, H, W]
        H = size[-2]
        cut_rat = math.sqrt(1.0 - lambda_)
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)

        cx = random.randint(0, W)
        cy = random.randint(0, H)

        # Fix: plain int clamping — the old code round-tripped each scalar
        # through a CPU tensor just to clamp it.
        bbx1 = max(0, min(cx - cut_w // 2, W))
        bby1 = max(0, min(cy - cut_h // 2, H))
        bbx2 = max(0, min(cx + cut_w // 2, W))
        bby2 = max(0, min(cy + cut_h // 2, H))

        return bbx1, bby1, bbx2, bby2

    def forward(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], float]:
        """Return ``(cut-mixed x, mixed labels or None, effective lambda)``."""
        if self.alpha > 0:
            lambda_ = random.betavariate(self.alpha, self.alpha)
        else:
            lambda_ = 1.0

        batch_size = x.shape[0]
        index = torch.randperm(batch_size, device=x.device)

        bbx1, bby1, bbx2, bby2 = self._rand_bbox(x.size(), lambda_)

        # Clone so the caller's tensor is never mutated in place.
        x = x.clone()
        x[:, :, bby1:bby2, bbx1:bbx2] = x[index, :, bby1:bby2, bbx1:bbx2]

        # Recompute lambda as the exact kept-area pixel fraction
        # (convention: H = size[-2], W = size[-1]).
        H, W = x.size()[-2], x.size()[-1]
        lambda_ = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (H * W))

        mixed_y = None
        if y is not None:
            y_a = y
            y_b = y[index]

            if y.dtype == torch.long or y.ndim == 1:
                # Fix: infer the class count per call instead of caching it on
                # the module — the old mutation could lock in a too-small
                # count from an unlucky first batch and crash F.one_hot later.
                num_classes = self.num_classes
                if num_classes is None:
                    num_classes = int(y.max().item()) + 1
                y_a = F.one_hot(y_a, num_classes=num_classes).float()
                y_b = F.one_hot(y_b, num_classes=num_classes).float()

            mixed_y = lambda_ * y_a + (1 - lambda_) * y_b

        return x, mixed_y, lambda_
228
+
229
class SpecAugment(nn.Module):
    """SpecAugment: random frequency and time masking for spectrograms."""

    def __init__(
        self,
        freq_mask_param: int = 27,
        time_mask_param: int = 100,
        num_freq_masks: int = 2,
        num_time_masks: int = 2
    ):
        super().__init__()
        self.freq_mask_param = freq_mask_param
        self.time_mask_param = time_mask_param
        self.num_freq_masks = num_freq_masks
        self.num_time_masks = num_time_masks

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Zero random frequency bands and time spans.

        Args:
            spec: [B, F, T] or [B, C, F, T] spectrogram batch.

        Returns:
            Tensor of the same rank/shape with random bands zeroed
            (operates on a clone; the input is left untouched).
        """
        squeeze_channel = spec.ndim == 3
        if squeeze_channel:
            spec = spec.unsqueeze(1)  # -> [B, 1, F, T]

        n_freq = spec.shape[2]
        n_time = spec.shape[3]
        spec = spec.clone()

        # Frequency masks: zero a random band of rows (never wider than F).
        for _ in range(self.num_freq_masks):
            width = random.randint(0, min(self.freq_mask_param, n_freq))
            start = random.randint(0, max(0, n_freq - width))
            spec[:, :, start:start + width, :] = 0

        # Time masks: zero a random span of columns (never wider than T).
        for _ in range(self.num_time_masks):
            width = random.randint(0, min(self.time_mask_param, n_time))
            start = random.randint(0, max(0, n_time - width))
            spec[:, :, :, start:start + width] = 0

        return spec.squeeze(1) if squeeze_channel else spec
275
+
276
class TemporalMasking(nn.Module):
    """Zero out a random subset of frames in each video of a batch."""

    def __init__(self, mask_ratio: float = 0.15):
        super().__init__()
        self.mask_ratio = mask_ratio

    def forward(self, video: torch.Tensor) -> torch.Tensor:
        """
        Args:
            video: [B, T, C, H, W] clip batch.

        Returns:
            Same-shape tensor with ``int(T * mask_ratio)`` frames zeroed per
            sample; the input is returned untouched when that count is 0.
        """
        B, T, C, H, W = video.shape
        num_mask = int(T * self.mask_ratio)
        if num_mask == 0:
            return video

        masked = video.clone()
        for sample_idx in range(B):
            # Sample distinct frame indices to zero for this clip.
            frame_picks = torch.randperm(T)[:num_mask]
            masked[sample_idx, frame_picks] = 0
        return masked
300
+
301
class MultiModalAugmentation(nn.Module):
    """Unified augmentation front-end for image/audio/video batches.

    Step 1 applies a modality-specific intra-sample augmentation; step 2
    (training mode only, when labels are given) optionally applies one
    inter-sample mixer (CutMix for images, else MixUp) with 50% probability.
    """

    def __init__(
        self,
        image_aug: bool = True,
        audio_aug: bool = True,
        video_aug: bool = True,
        use_mixup: bool = True,
        use_cutmix: bool = True,
        num_classes: Optional[int] = None
    ):
        super().__init__()
        self.image_aug = RandAugment() if image_aug else None
        self.audio_aug = SpecAugment() if audio_aug else None
        self.video_aug = TemporalMasking() if video_aug else None

        self.mixup = MixUp(num_classes=num_classes) if use_mixup else None
        self.cutmix = CutMix(num_classes=num_classes) if use_cutmix else None

    def forward(
        self,
        data: torch.Tensor,
        modality: str,
        labels: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            data: input batch for the given modality.
            modality: one of 'image', 'audio', 'video'.
            labels: optional labels, required for MixUp/CutMix.
        """
        # 1. Modality-specific (intra-sample) augmentation.
        per_modality = {
            'image': self.image_aug,
            'audio': self.audio_aug,
            'video': self.video_aug,
        }.get(modality)
        if per_modality is not None:
            data = per_modality(data)

        # 2. Inter-sample mixing (training only, labels required).
        if self.training and labels is not None:
            roll = random.random()
            mixer = None

            # Mutually exclusive choice: images with CutMix enabled get CutMix
            # half the time, falling back to MixUp otherwise (if enabled);
            # all other cases get MixUp with 50% probability.
            if self.cutmix is not None and modality == 'image':
                if roll < 0.5:
                    mixer = self.cutmix
                elif self.mixup is not None:
                    mixer = self.mixup
            elif self.mixup is not None and roll < 0.5:
                mixer = self.mixup

            if mixer is not None:
                data, labels, _ = mixer(data, labels)

        return data, labels
data_config.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# data_config.py
"""
Dataset configuration for pretraining and post-training.

Each entry describes one HuggingFace dataset: where to load it from
(`hf_path`, optional `config`/`data_files`), which fields hold the payload,
its sampling `weight` inside a mix, and whether to stream it.
"""

PRETRAIN_DATASETS = {
    # --- Text corpora ---
    'the_pile': {
        'type': 'text',
        'hf_path': 'EleutherAI/pile',
        'split': 'train',
        'streaming': True,
        'text_field': 'text',
        'weight': 1.0,
        'description': 'The Pile - 825GB diverse text corpus'
    },
    'c4': {
        'type': 'text',
        'hf_path': 'allenai/c4',
        'config': 'en',
        'split': 'train',
        'streaming': True,
        'text_field': 'text',
        'weight': 0.5,
        'description': 'C4 - Colossal Clean Crawled Corpus'
    },
    # NOTE(review): key says 'wikipedia' but the path points at FineWeb-Edu —
    # intentional substitution, presumably; confirm before relying on the name.
    'wikipedia': {
        'type': 'text',
        'hf_path': 'HuggingFaceFW/fineweb-edu',
        'config': 'sample-10BT',
        'split': 'train',
        'streaming': True,
        'text_field': 'text',
        'weight': 0.3,
        'description': 'FineWeb Edu - High quality educational content'
    },
    # NOTE(review): key says 'bookcorpus' but the path is the SmolLM corpus.
    'bookcorpus': {
        'type': 'text',
        'hf_path': 'HuggingFaceTB/smollm-corpus',
        'config': 'cosmopedia-v2',
        'split': 'train',
        'streaming': True,
        'text_field': 'text',
        'weight': 0.2,
        'description': 'Synthetic textbooks and stories'
    },
    # --- Code corpora ---
    'codeparrot': {
        'type': 'code',
        'hf_path': 'bigcode/the-stack-smol',
        'config': 'default',
        'split': 'train',
        'streaming': True,
        'text_field': 'content',
        'weight': 0.3,
        'description': 'The Stack Smol - code'
    },
    'the_stack': {
        'type': 'code',
        'hf_path': 'bigcode/the-stack-dedup',
        'split': 'train',
        'streaming': True,
        'text_field': 'content',
        'weight': 0.2,
        'description': 'The Stack - deduplicated code'
    },
    # --- Multimodal (image-text) corpora ---
    'laion400m': {
        'type': 'image_text',
        'hf_path': 'laion/laion400m',
        'split': 'train',
        'streaming': True,
        'image_field': 'url',
        'text_field': 'caption',
        'weight': 0.4,
        'description': 'LAION-400M image-text pairs'
    },
    'conceptual_captions': {
        'type': 'image_text',
        'hf_path': 'google-research-datasets/conceptual_captions',
        'split': 'train',
        'streaming': False,
        'image_field': 'image_url',
        'text_field': 'caption',
        'weight': 0.2,
        'description': 'Conceptual Captions 3M'
    },
}

# Post-training datasets (instruction tuning + alignment)
POSTTRAIN_DATASETS = {
    # --- Instruction tuning ---
    'flan_v2': {
        'type': 'instruction',
        'hf_path': 'Muennighoff/flan',
        'split': 'train',
        'streaming': True,
        'instruction_field': 'inputs',
        'response_field': 'targets',
        'weight': 1.0,
        'max_samples': 100000,
        'description': 'FLAN v2 collection'
    },
    'alpaca': {
        'type': 'instruction',
        'hf_path': 'tatsu-lab/alpaca',
        'split': 'train',
        'streaming': False,
        'instruction_field': 'instruction',
        'input_field': 'input',
        'response_field': 'output',
        'weight': 0.5,
        'description': 'Stanford Alpaca 52K'
    },
    'dolly': {
        'type': 'instruction',
        'hf_path': 'databricks/databricks-dolly-15k',
        'split': 'train',
        'streaming': False,
        'instruction_field': 'instruction',
        'context_field': 'context',  # Dolly has an extra context field
        'response_field': 'response',
        'weight': 0.3,
        'description': 'Dolly 15K'
    },
    'oasst1': {
        'type': 'conversation',
        'hf_path': 'OpenAssistant/oasst1',
        'split': 'train',
        'streaming': False,
        'weight': 0.4,
        'description': 'OpenAssistant Conversations',
        # OASST1 is a conversation tree and needs special handling;
        # custom preprocessing is likely required.
    },
    'sharegpt': {
        'type': 'conversation',
        'hf_path': 'anon8231489123/ShareGPT_Vicuna_unfiltered',
        'split': 'train',
        'streaming': False,
        'weight': 0.3,
        'max_samples': 50000,
        'description': 'ShareGPT conversations'
    },
    # --- Code instruction ---
    'code_alpaca': {
        'type': 'code_instruction',
        'hf_path': 'sahil2801/CodeAlpaca-20k',
        'split': 'train',
        'streaming': False,
        'instruction_field': 'instruction',
        'response_field': 'output',
        'weight': 0.3,
        'description': 'Code Alpaca 20K'
    },
    # --- Multimodal instruction ---
    'llava_instruct': {
        'type': 'multimodal_instruction',
        'hf_path': 'liuhaotian/LLaVA-Instruct-150K',
        'split': 'train',
        'streaming': False,
        'image_field': 'image',
        'instruction_field': 'conversations',
        'weight': 0.5,
        'description': 'LLaVA visual instruction tuning'
    },
    # --- Preference data (for RLHF) ---
    'hh_rlhf': {
        'type': 'preference',
        'hf_path': 'Anthropic/hh-rlhf',
        'split': 'train',
        'streaming': False,
        'chosen_field': 'chosen',
        'rejected_field': 'rejected',
        'weight': 1.0,
        'description': 'Anthropic HH-RLHF'
    },
    'ultrafeedback': {
        'type': 'preference',
        'hf_path': 'openbmb/UltraFeedback',
        'split': 'train',
        'streaming': True,
        'chosen_field': 'chosen',  # field configuration added
        'rejected_field': 'rejected',
        'weight': 0.5,
        'max_samples': 50000,
        'description': 'UltraFeedback preferences'
    },
    'debug_water': {
        'type': 'instruction',
        'hf_path': 'json',  # use the generic json loader
        'data_files': 'debug_water.json',  # points at the locally generated file
        'split': 'train',
        'streaming': False,
        'instruction_field': 'instruction',
        'response_field': 'output',
        'weight': 1.0,
        'description': 'Overfitting test for water'
    },
}

# Lightweight test datasets (for quick validation)
TEST_DATASETS = {
    'tiny_shakespeare': {
        'type': 'text',
        'hf_path': 'tiny_shakespeare',
        'split': 'train',
        'streaming': False,
        'text_field': 'text',
        'weight': 1.0,
        'description': 'Tiny Shakespeare for testing'
    },
    'gsm8k': {
        'type': 'instruction',
        'hf_path': 'gsm8k',
        'config': 'main',
        'split': 'train',
        'streaming': False,
        'instruction_field': 'question',
        'response_field': 'answer',
        'weight': 1.0,
        'description': 'GSM8K math problems'
    },
}

# Dataset mixing strategies: each mix pairs dataset keys with sampling weights.
PRETRAIN_MIX = {
    'default': {
        'datasets': ['c4', 'wikipedia', 'bookcorpus', 'codeparrot'],
        'weights': [0.5, 0.2, 0.2, 0.1],
        'description': 'Default pretrain mix'
    },
    'code_heavy': {
        'datasets': ['c4', 'codeparrot', 'the_stack', 'wikipedia'],
        'weights': [0.3, 0.4, 0.2, 0.1],
        'description': 'Code-heavy mix'
    },
    'multimodal': {
        'datasets': ['c4', 'wikipedia', 'laion400m', 'conceptual_captions'],
        'weights': [0.4, 0.2, 0.3, 0.1],
        'description': 'Multimodal mix'
    },
    'text_only': {
        'datasets': ['c4', 'wikipedia', 'bookcorpus'],
        'weights': [0.5, 0.3, 0.2],
        'description': 'Text-only mix for testing'
    },
}

POSTTRAIN_MIX = {
    'default': {
        'datasets': ['flan_v2', 'alpaca', 'dolly', 'oasst1'],
        'weights': [0.4, 0.3, 0.2, 0.1],
        'description': 'Default instruction tuning mix'
    },
    'conversation': {
        'datasets': ['oasst1', 'sharegpt', 'alpaca'],
        'weights': [0.4, 0.4, 0.2],
        'description': 'Conversation-focused mix'
    },
    'code_instruct': {
        'datasets': ['code_alpaca', 'alpaca', 'flan_v2'],
        'weights': [0.5, 0.3, 0.2],
        'description': 'Code instruction mix'
    },
    'simple_instruct': {
        'datasets': ['alpaca', 'dolly'],
        'weights': [0.6, 0.4],
        'description': 'Simple instruction mix for testing'
    },
    'debug_mix': {
        'datasets': ['debug_water'],
        'weights': [1.0],
        'description': 'Debug mix for overfitting'
    },
}

# Download and cache settings
DATASET_CACHE_DIR = "./dataset_cache"
HF_CACHE_DIR = "./hf_cache"
MAX_RETRIES = 3
DOWNLOAD_TIMEOUT = 300  # seconds

# Data preprocessing settings
PREPROCESSING_CONFIG = {
    'max_seq_length': 2048,
    'min_seq_length': 32,
    'num_workers': 4,
    'batch_size': 8,
    'shuffle_buffer_size': 10000,
    'seed': 42,
}
data_loader.py ADDED
@@ -0,0 +1,832 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data_loader.py
2
+ """
3
+ 改进的数据加载器 - 支持预训练和后训练数据集
4
+ """
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch.utils.data import Dataset, DataLoader, IterableDataset
8
+ from datasets import load_dataset, concatenate_datasets, interleave_datasets
9
+ from typing import Dict, List, Optional, Any, Union
10
+ import random
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+ import warnings
14
+ from PIL import Image
15
+ import requests
16
+ from io import BytesIO
17
+ from torchvision import transforms
18
+ import logging
19
+
20
+ # 设置日志
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger(__name__)
23
+
24
+ warnings.filterwarnings("ignore", category=UserWarning)
25
+
26
+ from data_config import (
27
+ PRETRAIN_DATASETS,
28
+ POSTTRAIN_DATASETS,
29
+ TEST_DATASETS,
30
+ PRETRAIN_MIX,
31
+ POSTTRAIN_MIX,
32
+ PREPROCESSING_CONFIG,
33
+ DATASET_CACHE_DIR,
34
+ HF_CACHE_DIR
35
+ )
36
+
37
+ # 图像变换
38
+ image_transform = transforms.Compose([
39
+ transforms.Resize((224, 224)),
40
+ transforms.ToTensor(),
41
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
42
+ ])
43
+
44
class PreTrainDataset(IterableDataset):
    """Pretraining dataset with streaming support and weighted mixture sampling.

    Each yielded sample is drawn from one of several HuggingFace datasets,
    chosen with probability proportional to the configured mix weights.
    Text/code samples are tokenized; image-text samples additionally carry a
    preprocessed image tensor.
    """

    def __init__(
        self,
        mix_name: str = 'default',
        tokenizer=None,
        max_length: int = 2048,
        streaming: bool = True,
        seed: int = 42,
        max_samples: Optional[int] = None
    ):
        # tokenizer: assumed HuggingFace-style — called with truncation/padding
        # kwargs and returning a dict with 'input_ids'/'attention_mask'
        # (confirmed by usage in _process_text_sample below).
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.streaming = streaming
        self.seed = seed
        self.max_samples = max_samples
        self.samples_generated = 0

        # Resolve the mixture configuration.
        if mix_name not in PRETRAIN_MIX:
            raise ValueError(f"Unknown mix: {mix_name}. Available: {list(PRETRAIN_MIX.keys())}")

        mix_config = PRETRAIN_MIX[mix_name]
        dataset_names = mix_config.get('datasets', [])
        weights = mix_config.get('weights', [])

        if not dataset_names:
            raise ValueError(f"No datasets found in mix: {mix_name}")

        logger.info(f"Loading pretrain mix: {mix_name}")
        logger.info(f" Datasets: {dataset_names}")
        logger.info(f" Weights: {weights}")

        # Load each dataset in the mix; failures are logged and skipped.
        self.datasets = []
        self.probabilities = []

        for name, weight in zip(dataset_names, weights):
            if name not in PRETRAIN_DATASETS:
                logger.warning(f"Dataset {name} not found in PRETRAIN_DATASETS, skipping")
                continue

            config = PRETRAIN_DATASETS[name]
            try:
                ds = self._load_dataset(config)
                if ds is not None:
                    self.datasets.append((name, ds, config))
                    self.probabilities.append(weight)
                    logger.info(f" Successfully loaded {name}")
            except Exception as e:
                logger.error(f"Error loading {name}: {e}")
                continue

        if not self.datasets:
            raise ValueError("No datasets loaded successfully")

        # Normalize the surviving weights into a probability distribution.
        total = sum(self.probabilities)
        self.probabilities = [p / total for p in self.probabilities]

        logger.info(f"Successfully loaded {len(self.datasets)} datasets")

    def _load_dataset(self, config: Dict):
        """Load a single HuggingFace dataset described by ``config``; None on failure."""
        try:
            load_kwargs = {
                'path': config['hf_path'],
                'split': config.get('split', 'train'),
                'streaming': config.get('streaming', self.streaming),
                'cache_dir': HF_CACHE_DIR,
            }

            # Forward the optional dataset sub-config (HF ``name`` argument).
            if 'config' in config:
                load_kwargs['name'] = config['config']

            ds = load_dataset(**load_kwargs)
            return ds
        except Exception as e:
            logger.error(f"Failed to load {config.get('hf_path', 'unknown')}: {e}")
            return None

    def _process_text_sample(self, sample: Dict, config: Dict) -> Optional[Dict]:
        """Tokenize a plain-text sample; returns None if it is empty/too short/broken."""
        try:
            text_field = config.get('text_field', 'text')
            text = sample.get(text_field, '')

            if not text or not isinstance(text, str):
                return None

            text = text.strip()
            # Drop near-empty fragments (< 10 characters).
            if len(text) < 10:
                return None

            # Tokenize to fixed length (truncate + pad to max_length).
            encoding = self.tokenizer(
                text,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'type': 'text'
            }
        except Exception as e:
            logger.debug(f"Error processing text sample: {e}")
            return None

    def _process_image_text_sample(self, sample: Dict, config: Dict) -> Optional[Dict]:
        """Fetch/convert the image and tokenize the caption; returns None on any failure."""
        try:
            text_field = config.get('text_field', 'caption')
            image_field = config.get('image_field', 'image')

            text = sample.get(text_field, '')
            image = sample.get(image_field)

            if not text or image is None:
                return None

            # The image may arrive as a URL string or as a PIL image.
            if isinstance(image, str):
                # URL download with a short timeout so a dead link cannot
                # stall the whole data loader.
                try:
                    response = requests.get(image, timeout=5)
                    image = Image.open(BytesIO(response.content)).convert('RGB')
                except Exception as img_error:
                    logger.debug(f"Failed to load image from URL: {img_error}")
                    return None
            elif isinstance(image, Image.Image):
                image = image.convert('RGB')
            else:
                return None

            # Apply the shared torchvision preprocessing pipeline.
            image_tensor = image_transform(image)

            # Tokenize the caption.
            encoding = self.tokenizer(
                text,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'image': image_tensor,
                'type': 'image_text'
            }
        except Exception as e:
            logger.debug(f"Error processing image-text sample: {e}")
            return None

    def __iter__(self):
        """Yield processed samples drawn from the weighted dataset mixture.

        Exhausted sub-datasets are restarted in place; processing errors on
        individual samples are logged at debug level and skipped.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            # Give each DataLoader worker its own random stream.
            # NOTE(review): max_samples is enforced per worker, so the global
            # yield count is multiplied by num_workers — confirm intended.
            random.seed(self.seed + worker_info.id)
            np.random.seed(self.seed + worker_info.id)
        else:
            random.seed(self.seed)
            np.random.seed(self.seed)

        # One live iterator per underlying dataset.
        iterators = [iter(ds) for _, ds, _ in self.datasets]
        self.samples_generated = 0

        while True:
            # Stop once the requested sample budget is exhausted.
            if self.max_samples and self.samples_generated >= self.max_samples:
                break

            try:
                # Pick a dataset according to the mixture probabilities.
                idx = np.random.choice(len(self.datasets), p=self.probabilities)
                name, _, config = self.datasets[idx]

                # Pull the next raw sample from the chosen dataset.
                sample = next(iterators[idx])

                # Dispatch on the configured dataset type.
                processed = None
                if config.get('type') in ['text', 'code']:
                    processed = self._process_text_sample(sample, config)
                elif config.get('type') == 'image_text':
                    processed = self._process_image_text_sample(sample, config)
                else:
                    logger.debug(f"Unknown type: {config.get('type')}")
                    continue

                if processed is not None:
                    self.samples_generated += 1
                    yield processed

            except StopIteration:
                # Dataset exhausted: recreate its iterator so the mixture keeps flowing.
                try:
                    iterators[idx] = iter(self.datasets[idx][1])
                except Exception as e:
                    logger.error(f"Failed to recreate iterator for dataset {idx}: {e}")
                    break
            except Exception as e:
                logger.debug(f"Error in iterator: {e}")
                continue
262
+
263
+
264
+ class PostTrainDataset(Dataset):
265
+ """后训练数据集 - Instruction tuning和对话"""
266
+ def __init__(
267
+ self,
268
+ mix_name: str = 'default',
269
+ tokenizer=None,
270
+ max_length: int = 2048,
271
+ max_samples: Optional[int] = None,
272
+ split: str = 'train'
273
+ ):
274
+ super().__init__()
275
+
276
+ if tokenizer is None:
277
+ raise ValueError("tokenizer cannot be None")
278
+
279
+ self.tokenizer = tokenizer
280
+ self.max_length = max_length
281
+ self.split = split
282
+
283
+ # 获取混合配置
284
+ if mix_name not in POSTTRAIN_MIX:
285
+ raise ValueError(f"Unknown mix: {mix_name}. Available: {list(POSTTRAIN_MIX.keys())}")
286
+
287
+ mix_config = POSTTRAIN_MIX[mix_name]
288
+ dataset_names = mix_config.get('datasets', [])
289
+ weights = mix_config.get('weights', [])
290
+
291
+ if not dataset_names:
292
+ raise ValueError(f"No datasets found in mix: {mix_name}")
293
+
294
+ logger.info(f"Loading posttrain mix: {mix_name}")
295
+ logger.info(f" Datasets: {dataset_names}")
296
+
297
+ # 加载和合并数据集
298
+ all_datasets = []
299
+
300
+ for name in dataset_names:
301
+ if name not in POSTTRAIN_DATASETS:
302
+ logger.warning(f"Dataset {name} not found in POSTTRAIN_DATASETS")
303
+ continue
304
+
305
+ config = POSTTRAIN_DATASETS[name]
306
+ try:
307
+ load_kwargs = {
308
+ 'path': config['hf_path'],
309
+ 'split': split,
310
+ 'streaming': config.get('streaming', False),
311
+ 'cache_dir': HF_CACHE_DIR,
312
+ }
313
+ # [新增] 如果配置里有 data_files,就加进去
314
+ if 'data_files' in config:
315
+ load_kwargs['data_files'] = config['data_files']
316
+ # 添加config参数(如果存在)
317
+ if 'config' in config:
318
+ load_kwargs['name'] = config['config']
319
+
320
+ ds = load_dataset(**load_kwargs)
321
+
322
+ # 限制样本数
323
+ if config.get('max_samples'):
324
+ if hasattr(ds, 'take'):
325
+ ds = ds.take(config['max_samples'])
326
+ elif hasattr(ds, 'select'):
327
+ ds = ds.select(range(min(len(ds), config['max_samples'])))
328
+
329
+ # 添加数据集标识
330
+ def add_source(example):
331
+ example['_source'] = name
332
+ example['_config'] = config
333
+ return example
334
+
335
+ ds = ds.map(add_source)
336
+ all_datasets.append(ds)
337
+
338
+ ds_len = len(ds) if hasattr(ds, '__len__') else 'streaming'
339
+ logger.info(f" Loaded {name}: {ds_len} samples")
340
+
341
+ except Exception as e:
342
+ logger.error(f"Error loading {name}: {e}")
343
+ continue
344
+
345
+ # 合并数据集
346
+ if not all_datasets:
347
+ raise ValueError("No datasets loaded successfully")
348
+
349
+ if len(all_datasets) == 1:
350
+ self.dataset = all_datasets[0]
351
+ else:
352
+ # 交织数据集
353
+ probabilities = [w / sum(weights[:len(all_datasets)])
354
+ for w in weights[:len(all_datasets)]]
355
+ self.dataset = interleave_datasets(
356
+ all_datasets,
357
+ probabilities=probabilities,
358
+ seed=42,
359
+ stopping_strategy='all_exhausted'
360
+ )
361
+
362
+ # 限制总样本数
363
+ if max_samples and hasattr(self.dataset, '__len__'):
364
+ actual_len = min(len(self.dataset), max_samples)
365
+ self.dataset = self.dataset.select(range(actual_len))
366
+
367
+ dataset_len = len(self.dataset) if hasattr(self.dataset, '__len__') else 'streaming'
368
+ logger.info(f"Total samples: {dataset_len}")
369
+
370
    def _format_instruction(self, sample: Dict, config: Dict) -> str:
        """Build the prompt string for one sample based on its dataset config.

        Dispatch is driven by ``config['type']``:
          - 'instruction': "Instruction / Context / Input / Response:" template
          - 'conversation': all turns except the last, then "assistant:"
          - 'code_instruction': Alpaca-style "### Instruction / ### Response:"
          - 'multimodal_instruction': conversation stored under a configurable field
          - anything else: raw text from the configured instruction field

        Returns '' on any error; the caller drops empty samples.
        """
        try:
            data_type = config.get('type', 'instruction')

            if data_type == 'instruction':
                instruction_field = config.get('instruction_field', 'instruction')
                input_field = config.get('input_field', 'input')
                context_field = config.get('context_field', None)

                instruction = sample.get(instruction_field, '')
                input_text = sample.get(input_field, '')
                context = sample.get(context_field, '') if context_field else ''

                # Assemble the prompt; optional sections are skipped when empty.
                prompt_parts = [f"Instruction: {instruction}"]

                if context:
                    prompt_parts.append(f"Context: {context}")

                if input_text:
                    prompt_parts.append(f"Input: {input_text}")

                prompt_parts.append("Response:")
                return "\n".join(prompt_parts)

            elif data_type == 'conversation':
                # Conversation data — several wire formats are supported.
                if 'conversations' in sample:
                    # LLaVA-style format: list of {'from': ..., 'value': ...}
                    conversations = sample['conversations']
                    if isinstance(conversations, list) and len(conversations) > 0:
                        dialogue = []
                        # The last turn is the target response, so exclude it here.
                        for conv in conversations[:-1]:
                            role = conv.get('from', 'user')
                            content = conv.get('value', '')
                            dialogue.append(f"{role}: {content}")
                        return "\n".join(dialogue) + "\nassistant:"

                elif 'messages' in sample:
                    # Standard chat format: list of {'role': ..., 'content': ...}
                    messages = sample['messages']
                    if isinstance(messages, list) and len(messages) > 0:
                        dialogue = []
                        for msg in messages[:-1]:
                            role = msg.get('role', 'user')
                            content = msg.get('content', '')
                            dialogue.append(f"{role}: {content}")
                        return "\n".join(dialogue) + "\nassistant:"

                # No recognised conversation structure: fall back to a plain text field.
                return sample.get('text', '')

            elif data_type == 'code_instruction':
                # Code-instruction format (Alpaca-style section markers).
                instruction_field = config.get('instruction_field', 'instruction')
                instruction = sample.get(instruction_field, '')
                return f"### Instruction:\n{instruction}\n### Response:"

            elif data_type == 'multimodal_instruction':
                # Multimodal instruction: conversation stored under a configurable field.
                instruction_field = config.get('instruction_field', 'conversations')
                conversations = sample.get(instruction_field, [])
                if isinstance(conversations, list) and len(conversations) > 0:
                    # Dialogue history minus the final (target) reply.
                    dialogue = []
                    for conv in conversations[:-1]:
                        role = conv.get('from', 'user')
                        content = conv.get('value', '')
                        dialogue.append(f"{role}: {content}")
                    return "\n".join(dialogue) + "\nassistant:"
                return ""

            else:
                # Unknown type: return the configured field (default 'text') verbatim.
                return sample.get(config.get('instruction_field', 'text'), '')
        except Exception as e:
            logger.debug(f"Error formatting instruction: {e}")
            return ""
448
+
449
+ def _get_response(self, sample: Dict, config: Dict) -> str:
450
+ """获取响应"""
451
+ try:
452
+ data_type = config.get('type', 'instruction')
453
+
454
+ if data_type == 'instruction' or data_type == 'code_instruction':
455
+ response_field = config.get('response_field', 'output')
456
+ return sample.get(response_field, '')
457
+
458
+ elif data_type == 'conversation':
459
+ # 从对话中提取最后一条assistant的回复
460
+ if 'conversations' in sample:
461
+ conversations = sample['conversations']
462
+ if isinstance(conversations, list) and len(conversations) > 0:
463
+ return conversations[-1].get('value', '')
464
+
465
+ elif 'messages' in sample:
466
+ messages = sample['messages']
467
+ if isinstance(messages, list) and len(messages) > 0:
468
+ return messages[-1].get('content', '')
469
+
470
+ return ""
471
+
472
+ elif data_type == 'multimodal_instruction':
473
+ instruction_field = config.get('instruction_field', 'conversations')
474
+ conversations = sample.get(instruction_field, [])
475
+ if isinstance(conversations, list) and len(conversations) > 0:
476
+ return conversations[-1].get('value', '')
477
+ return ""
478
+
479
+ else:
480
+ response_field = config.get('response_field', 'output')
481
+ return sample.get(response_field, '')
482
+ except Exception as e:
483
+ logger.debug(f"Error getting response: {e}")
484
+ return ""
485
+
486
+ def __len__(self):
487
+ return len(self.dataset) if hasattr(self.dataset, '__len__') else 0
488
+
489
    def __getitem__(self, idx):
        """Return one SFT training example, or None on any failure.

        Output dict:
          - 'instruction'      : LongTensor [max_length // 2], right-padded
          - 'response'         : LongTensor [max_length // 2], ends with EOS, right-padded
          - 'instruction_mask' : LongTensor, 1 for real tokens, 0 for padding
          - 'response_mask'    : LongTensor, 1 for real tokens (incl. EOS), 0 for padding
          - 'task'             : source dataset name
          - 'modality_data'    : {'image': tensor} for multimodal samples, else None
        None results are filtered out by the collate function.
        """
        try:
            sample = self.dataset[idx]

            # Per-sample dataset config injected by the loader's add_source map.
            if '_config' not in sample:
                logger.warning(f"Sample at index {idx} missing _config")
                return None

            config = sample['_config']

            # Format the instruction and response texts.
            instruction_text = self._format_instruction(sample, config)
            response_text = self._get_response(sample, config)

            if not instruction_text or not response_text:
                return None

            # Ensure a pad token id exists (fall back to EOS).
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = self.tokenizer.eos_token_id

            # =======================================================
            # 1. Instruction (no EOS — the response follows directly)
            # =======================================================
            instruction_max_len = self.max_length // 2

            # Tokenize without padding; padding is handled manually below.
            instruction_enc = self.tokenizer(
                instruction_text,
                truncation=True,
                max_length=instruction_max_len,
                add_special_tokens=False,  # special tokens are controlled manually
                return_tensors='pt'
            )
            instr_ids = instruction_enc['input_ids'].squeeze(0)

            # Manual right-padding of the instruction to a fixed length.
            instr_len = instr_ids.size(0)
            if instr_len < instruction_max_len:
                # Either left or right padding would work; SFT instructions
                # conventionally use right padding.
                # padding_tensor = torch.full((instruction_max_len - instr_len,), pad_token_id, dtype=torch.long)
                # instr_ids = torch.cat([instr_ids, padding_tensor])
                # Kept consistent with the original logic: right-pad to fixed length.
                padding = torch.full((instruction_max_len - instr_len,), pad_token_id, dtype=torch.long)
                instr_ids = torch.cat([instr_ids, padding])

                # Mask: 1 for real tokens, 0 for padding.
                instr_mask = torch.cat([torch.ones(instr_len, dtype=torch.long), torch.zeros(instruction_max_len - instr_len, dtype=torch.long)])
            else:
                instr_mask = torch.ones(instruction_max_len, dtype=torch.long)

            # =======================================================
            # 2. Response (key fix: an EOS token MUST be appended)
            # =======================================================
            response_max_len = self.max_length // 2

            # Tokenize, reserving one position for the EOS token.
            response_enc = self.tokenizer(
                response_text,
                truncation=True,
                max_length=response_max_len - 1,  # key: leave room for EOS
                add_special_tokens=False,
                return_tensors='pt'
            )
            resp_ids = response_enc['input_ids'].squeeze(0)

            # Force-append the EOS token so generation learns to stop.
            eos_token = torch.tensor([self.tokenizer.eos_token_id], dtype=torch.long)
            resp_ids = torch.cat([resp_ids, eos_token])

            # Manual right-padding of the response.
            curr_resp_len = resp_ids.size(0)
            if curr_resp_len < response_max_len:
                padding = torch.full((response_max_len - curr_resp_len,), pad_token_id, dtype=torch.long)
                resp_ids = torch.cat([resp_ids, padding])

                # Mask: 1 for content + EOS, 0 for padding.
                resp_mask = torch.cat([torch.ones(curr_resp_len, dtype=torch.long), torch.zeros(response_max_len - curr_resp_len, dtype=torch.long)])
            else:
                resp_mask = torch.ones(response_max_len, dtype=torch.long)

            # =======================================================
            # 3. Assemble the result
            # =======================================================
            result = {
                'instruction': instr_ids,
                'response': resp_ids,
                'instruction_mask': instr_mask,
                'response_mask': resp_mask,
                'task': sample.get('_source', 'unknown'),
                'modality_data': None
            }

            # For multimodal samples, attach the preprocessed image tensor.
            if config.get('type') == 'multimodal_instruction' and 'image' in sample:
                try:
                    image = sample['image']
                    if isinstance(image, Image.Image):
                        image = image.convert('RGB')
                        image_tensor = image_transform(image)
                        result['modality_data'] = {'image': image_tensor}
                except Exception as e:
                    logger.debug(f"Error processing image: {e}")

            return result

        except Exception as e:
            logger.debug(f"Error getting item at index {idx}: {e}")
            import traceback
            traceback.print_exc()
            return None
602
+
603
+
604
class PreferenceDataset(Dataset):
    """Preference-pair dataset (chosen vs. rejected) for RLHF training.

    Loads a single HF dataset whose entry in POSTTRAIN_DATASETS has
    ``type == 'preference'``. Each item is a 4-tuple of
    (chosen_ids, rejected_ids, chosen_mask, rejected_mask).
    """
    def __init__(
        self,
        dataset_name: str = 'hh_rlhf',
        tokenizer=None,
        max_length: int = 1024,
        max_samples: Optional[int] = None,
        split: str = 'train'
    ):
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length

        if dataset_name not in POSTTRAIN_DATASETS:
            raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(POSTTRAIN_DATASETS.keys())}")

        config = POSTTRAIN_DATASETS[dataset_name]
        if config.get('type') != 'preference':
            raise ValueError(f"{dataset_name} is not a preference dataset (type: {config.get('type')})")

        logger.info(f"Loading preference dataset: {dataset_name}")

        load_kwargs = {
            'path': config['hf_path'],
            'split': split,
            'cache_dir': HF_CACHE_DIR,
        }

        # Pass the HF dataset config name when one is specified.
        if 'config' in config:
            load_kwargs['name'] = config['config']

        self.dataset = load_dataset(**load_kwargs)

        # Field names for the two sides of each preference pair.
        self.chosen_field = config.get('chosen_field', 'chosen')
        self.rejected_field = config.get('rejected_field', 'rejected')

        if max_samples and len(self.dataset) > max_samples:
            self.dataset = self.dataset.select(range(max_samples))

        logger.info(f"Loaded {len(self.dataset)} preference pairs")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return (chosen_ids, rejected_ids, chosen_mask, rejected_mask),
        or None when either side is empty / on error (filtered in collate)."""
        try:
            sample = self.dataset[idx]

            chosen_text = sample.get(self.chosen_field, '')
            rejected_text = sample.get(self.rejected_field, '')

            if not chosen_text or not rejected_text:
                return None

            # Tokenize both sides with fixed-length padding so tensors stack.
            chosen_enc = self.tokenizer(
                chosen_text,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            rejected_enc = self.tokenizer(
                rejected_text,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            return (
                chosen_enc['input_ids'].squeeze(0),
                rejected_enc['input_ids'].squeeze(0),
                chosen_enc['attention_mask'].squeeze(0),
                rejected_enc['attention_mask'].squeeze(0)
            )

        except Exception as e:
            logger.debug(f"Error getting preference item at index {idx}: {e}")
            return None
691
+
692
+
693
def collate_fn_v2(batch):
    """Collate samples into a batch.

    Handles three shapes of input item:
      - None (a failed sample) — dropped up front;
      - tuples — preference pairs, stacked into chosen/rejected tensors;
      - dicts — SFT samples; known tensor keys are stacked, the optional
        multimodal payload is merged, everything else is kept as a list.
    An all-None batch yields an empty placeholder dict rather than None.
    """
    samples = [s for s in batch if s is not None]

    if not samples:
        logger.warning("Empty batch after filtering None values")
        # Empty placeholder batch instead of None so callers can skip it.
        return {
            'input_ids': torch.empty(0),
            'attention_mask': torch.empty(0)
        }

    # Preference data arrives as tuples.
    if isinstance(samples[0], tuple):
        columns = list(zip(*samples))
        if len(samples[0]) == 4:  # includes attention masks
            names = ('chosen', 'rejected', 'chosen_mask', 'rejected_mask')
            return {n: torch.stack(list(col)) for n, col in zip(names, columns)}
        # Legacy 2-tuple format.
        return {
            'chosen': torch.stack(list(columns[0])),
            'rejected': torch.stack(list(columns[1])),
        }

    # Regular dict samples.
    stackable = ('instruction', 'response', 'instruction_mask',
                 'response_mask', 'input_ids', 'attention_mask')
    collated = {}

    for key in samples[0].keys():
        if key in stackable:
            values = [s[key] for s in samples if s.get(key) is not None]
            collated[key] = torch.stack(values) if values else None
        elif key == 'modality_data':
            # Merge per-sample multimodal payloads (currently images only).
            payloads = [s[key] for s in samples if s.get(key) is not None]
            if payloads and any(p is not None for p in payloads):
                images = [p.get('image') for p in payloads if p and 'image' in p]
                collated[key] = {'image': torch.stack(images)} if images else None
            else:
                collated[key] = None
        else:
            # Non-tensor metadata stays as a plain list.
            collated[key] = [s[key] for s in samples]

    return collated
752
+
753
+
754
def create_pretrain_dataloader(
    mix_name: str = 'default',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 2048,
    max_samples: Optional[int] = None
):
    """Build a DataLoader over the streaming pretraining mixture.

    Thin factory: constructs a streaming PreTrainDataset for the given
    mix and wraps it with the shared collate function.
    """
    ds = PreTrainDataset(
        mix_name=mix_name,
        tokenizer=tokenizer,
        max_length=max_length,
        streaming=True,
        max_samples=max_samples,
    )
    return DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn_v2,
    )
776
+
777
+
778
def create_posttrain_dataloader(
    mix_name: str = 'default',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 2048,
    max_samples: Optional[int] = None,
    split: str = 'train',
    shuffle: bool = True
):
    """Build a DataLoader over a post-training (SFT) mixture.

    Keeps the last partial batch (drop_last=False) and pins memory for
    faster host-to-device transfer.
    """
    ds = PostTrainDataset(
        mix_name=mix_name,
        tokenizer=tokenizer,
        max_length=max_length,
        max_samples=max_samples,
        split=split,
    )
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn_v2,
        pin_memory=True,
        drop_last=False,  # keep the final (possibly smaller) batch
    )
805
+
806
+
807
def create_preference_dataloader(
    dataset_name: str = 'hh_rlhf',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 1024,
    max_samples: Optional[int] = None,
    split: str = 'train',
    shuffle: bool = True
):
    """Build a DataLoader over a preference (chosen/rejected) dataset
    for RLHF-style training."""
    ds = PreferenceDataset(
        dataset_name=dataset_name,
        tokenizer=tokenizer,
        max_length=max_length,
        max_samples=max_samples,
        split=split,
    )
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn_v2,
        pin_memory=True,
    )
encoders.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 改进的多模态编码器 - SOTA级别(修复版)
3
+ 集成最新的视觉、音频、视频编码技术
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from typing import Tuple, Optional
9
+ from components import RMSNorm, SwiGLU
10
+ from transformer import OptimizedTransformerBlock
11
+ import math
12
+
13
class LayerScale(nn.Module):
    """Per-channel learnable scaling (LayerScale) for training stability.

    Each channel starts at a small constant so residual branches
    contribute little early in training.
    """
    def __init__(self, dim: int, init_values: float = 1e-5):
        super().__init__()
        # One learnable scale per channel.
        self.gamma = nn.Parameter(torch.full((dim,), float(init_values)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gamma * x
21
+
22
class StochasticDepth(nn.Module):
    """Stochastic depth (DropPath): randomly drops whole residual
    branches per sample during training; identity at eval time."""

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Identity when dropping is disabled or we are in eval mode.
        if self.drop_prob == 0.0 or not self.training:
            return x

        keep = 1 - self.drop_prob
        # Per-sample Bernoulli mask, broadcast over all non-batch dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = keep + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
        mask.floor_()
        # Rescale the surviving paths so the expected value is unchanged.
        return x / keep * mask
37
+
38
class ImprovedPatchEmbedding(nn.Module):
    """Conv-based image patch embedding.

    Supports overlapping patches by using a stride smaller than the
    kernel (stride = patch_size - overlap, with padding overlap // 2).
    Returns normalised patch tokens plus the spatial grid size.
    """
    def __init__(
        self,
        patch_size: int = 14,
        in_channels: int = 3,
        embed_dim: int = 2048,
        overlap: int = 0
    ):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size - overlap,
            padding=overlap // 2,
        )

        self.norm = RMSNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        feat = self.proj(x)                        # [B, D, H', W']
        grid = (feat.shape[2], feat.shape[3])
        tokens = feat.flatten(2).transpose(1, 2)   # [B, H'*W', D]
        return self.norm(tokens), grid
67
+
68
class ImprovedVisionBlock(nn.Module):
    """Vision Transformer block: pre-norm self-attention + MLP, with
    optional LayerScale, stochastic depth (DropPath) and a residual
    bottleneck adapter.
    """
    def __init__(
        self,
        dim: int,
        n_heads: int,
        dropout: float = 0.0,
        drop_path: float = 0.0,
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_layer_scale: bool = True,
        layer_scale_init: float = 1e-5
    ):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = nn.MultiheadAttention(
            dim, n_heads, dropout=dropout, batch_first=True
        )

        self.norm2 = RMSNorm(dim)
        # Standard 4x-expansion MLP.
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout)
        )

        self.drop_path = StochasticDepth(drop_path) if drop_path > 0 else nn.Identity()

        if use_layer_scale:
            self.ls1 = LayerScale(dim, layer_scale_init)
            self.ls2 = LayerScale(dim, layer_scale_init)
        else:
            self.ls1 = nn.Identity()
            self.ls2 = nn.Identity()

        # Fix: simple local adapter implementation to avoid external dependencies.
        if use_adapter:
            self.adapter = nn.Sequential(
                nn.Linear(dim, adapter_dim),
                nn.GELU(),
                nn.Linear(adapter_dim, dim)
            )
        else:
            self.adapter = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Self-attention sub-layer (pre-norm residual).
        normx = self.norm1(x)
        attn_out, _ = self.attn(normx, normx, normx)
        x = x + self.drop_path(self.ls1(attn_out))

        # MLP sub-layer (pre-norm residual).
        x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))

        # Optional residual adapter.
        if self.adapter is not None:
            x = x + self.adapter(x)

        return x
129
+
130
class ImprovedVisionTransformer(nn.Module):
    """
    Improved Vision Transformer featuring:
    - LayerScale
    - Stochastic depth
    - Interpolatable position embeddings (variable input resolution)
    - Optional register tokens (DINOv2-inspired)

    Token layout in forward: [CLS, register tokens, patch tokens].
    Returns all tokens; callers pick the CLS token or pool as needed.
    """
    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 14,
        in_channels: int = 3,
        embed_dim: int = 2048,
        depth: int = 24,
        n_heads: int = 16,
        dropout: float = 0.0,
        drop_path_rate: float = 0.1,
        use_register_tokens: bool = True,
        num_register_tokens: int = 4,
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_layer_scale: bool = True,
        layer_scale_init: float = 1e-5
    ):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.use_register_tokens = use_register_tokens
        self.num_register_tokens = num_register_tokens if use_register_tokens else 0

        # Patch embedding (non-overlapping).
        self.patch_embed = ImprovedPatchEmbedding(
            patch_size, in_channels, embed_dim, overlap=0
        )

        self.pretrain_img_size = img_size
        n_patches_pretrain = (img_size // patch_size) ** 2

        # CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Register tokens (DINOv2-inspired)
        if use_register_tokens:
            self.register_tokens = nn.Parameter(
                torch.zeros(1, num_register_tokens, embed_dim)
            )

        # Fix: total position embeddings = 1 (CLS) + n_patches + register tokens.
        # (All learned from zero init, so the exact ordering is consistent
        # as long as forward/interpolation agree on [CLS, register, patches].)
        total_tokens = 1 + n_patches_pretrain + self.num_register_tokens
        self.pos_embed = nn.Parameter(
            torch.zeros(1, total_tokens, embed_dim)
        )
        self.pos_drop = nn.Dropout(dropout)

        # Linearly increasing stochastic-depth rates over depth.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        # Transformer blocks
        self.blocks = nn.ModuleList([
            ImprovedVisionBlock(
                embed_dim,
                n_heads,
                dropout,
                drop_path=dpr[i],
                use_adapter=use_adapter,
                adapter_dim=adapter_dim,
                use_layer_scale=use_layer_scale,
                layer_scale_init=layer_scale_init
            )
            for i in range(depth)
        ])

        self.norm = RMSNorm(embed_dim)
        self._init_weights()

    def _init_weights(self):
        """Truncated-normal init for tokens and position embeddings,
        then module-wise init for the rest."""
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.use_register_tokens:
            nn.init.trunc_normal_(self.register_tokens, std=0.02)

        self.apply(self._init_module_weights)

    def _init_module_weights(self, m):
        # Standard ViT-style init: trunc-normal weights, zero biases.
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, RMSNorm):
            if hasattr(m, 'weight') and m.weight is not None:
                nn.init.ones_(m.weight)

    def _interpolate_pos_encoding(
        self,
        patch_tokens: torch.Tensor,
        grid_size: Tuple[int, int]
    ) -> torch.Tensor:
        """
        Fix: improved position-embedding interpolation.
        Only the patch position embeddings are interpolated (bicubic);
        CLS and register-token embeddings are kept untouched.
        (`patch_tokens` is unused; kept for API symmetry.)
        """
        pretrain_grid_h = self.pretrain_img_size // self.patch_size
        pretrain_grid_w = pretrain_grid_h

        # If the grid matches pretraining, return the embedding unchanged.
        if grid_size[0] == pretrain_grid_h and grid_size[1] == pretrain_grid_w:
            return self.pos_embed

        # Split the embedding into its components.
        # pos_embed layout assumed: [CLS(1), register_tokens(n), patches(H*W)]
        num_extra_tokens = 1 + self.num_register_tokens
        cls_register_pos = self.pos_embed[:, :num_extra_tokens, :]  # [1, 1+n, dim]
        patch_pos_embed = self.pos_embed[:, num_extra_tokens:, :]   # [1, H*W, dim]

        # 2D bicubic interpolation of the patch position embeddings.
        patch_pos_embed = patch_pos_embed.reshape(
            1, pretrain_grid_h, pretrain_grid_w, -1
        ).permute(0, 3, 1, 2)

        patch_pos_embed = F.interpolate(
            patch_pos_embed,
            size=grid_size,
            mode='bicubic',
            align_corners=False
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2)

        # Reassemble: extra-token embeddings followed by interpolated patches.
        return torch.cat([cls_register_pos, patch_pos_embed], dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode an image batch [B, C, H, W] into tokens
        [B, 1 + num_register_tokens + n_patches, embed_dim]."""
        B = x.shape[0]

        # Patch embedding
        x, grid_size = self.patch_embed(x)

        # Prepend CLS token.
        cls_tokens = self.cls_token.expand(B, -1, -1)

        # Fix: assemble the token sequence in a fixed order.
        if self.use_register_tokens:
            register_tokens = self.register_tokens.expand(B, -1, -1)
            # Order: [CLS, register_tokens, patches]
            x = torch.cat([cls_tokens, register_tokens, x], dim=1)
        else:
            x = torch.cat([cls_tokens, x], dim=1)

        # Position embeddings (interpolated when resolution differs).
        pos_embed = self._interpolate_pos_encoding(x, grid_size)
        x = self.pos_drop(x + pos_embed)

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        # Return all tokens (caller may use the CLS token or global pooling).
        return x
295
+
296
class ImprovedAudioEncoder(nn.Module):
    """
    Improved audio encoder over log-mel spectrograms.

    - Conv patchification of the [n_mels x target_length] spectrogram
    - Transformer encoding of the patch sequence
    - Optional dual-stream pooling: a temporal stream (frequency-pooled)
      and a frequency stream (time-pooled), fused into one token.

    Output: [B, 1, embed_dim] (a single summary token per clip).
    """
    def __init__(
        self,
        n_mels: int = 128,
        target_length: int = 1024,
        embed_dim: int = 2048,
        depth: int = 12,
        n_heads: int = 16,
        patch_size: int = 16,
        dropout: float = 0.1,
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_dual_stream: bool = True
    ):
        super().__init__()
        self.use_dual_stream = use_dual_stream
        self.patch_size = patch_size

        # Patchify the spectrogram like an image (1 input channel).
        self.patch_embed = nn.Conv2d(
            1, embed_dim, kernel_size=patch_size, stride=patch_size
        )

        # Fix: compute the actual number of patches in each dimension.
        # NOTE(review): assumes n_mels and target_length are divisible by patch_size — confirm at call sites.
        self.n_patches_h = n_mels // patch_size
        self.n_patches_w = target_length // patch_size
        n_patches = self.n_patches_h * self.n_patches_w

        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            OptimizedTransformerBlock(
                embed_dim, n_heads, None, None, dropout,
                use_adapter=use_adapter, adapter_dim=adapter_dim
            )
            for _ in range(depth)
        ])

        # Dual-stream heads: temporal stream and frequency stream.
        if use_dual_stream:
            # Fix: pool with the correct dimensionality (1D over the kept axis).
            self.temporal_pool = nn.AdaptiveAvgPool1d(1)
            self.frequency_pool = nn.AdaptiveAvgPool1d(1)

            self.temporal_proj = nn.Linear(embed_dim, embed_dim)
            self.frequency_proj = nn.Linear(embed_dim, embed_dim)

            self.fusion = nn.Linear(embed_dim * 2, embed_dim)

        self.norm = RMSNorm(embed_dim)
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_module_weights)

    def _init_module_weights(self, m):
        # Trunc-normal weights, zero biases (ViT-style).
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
        """Encode a mel spectrogram [B, n_mels, T] or [B, 1, n_mels, T]
        into a single token [B, 1, embed_dim]."""
        if mel_spec.ndim == 3:
            mel_spec = mel_spec.unsqueeze(1)  # add channel dim

        # Patch embedding
        x = self.patch_embed(mel_spec)  # [B, C, H, W]
        x = x.flatten(2).transpose(1, 2)  # [B, H*W, C]
        x = self.pos_drop(x + self.pos_embed)

        # Transformer encoding
        for block in self.blocks:
            x, _, _ = block(x)

        x = self.norm(x)

        # Fix: dual-stream pooling over the 2D patch grid.
        if self.use_dual_stream:
            B, N, C = x.shape

            # Reshape the token sequence back to the 2D grid.
            x_2d = x.transpose(1, 2).reshape(B, C, self.n_patches_h, self.n_patches_w)

            # Temporal stream: pool over frequency (keep time), then over time.
            temporal = x_2d.mean(dim=2)  # [B, C, W]
            temporal = self.temporal_pool(temporal).squeeze(-1)  # [B, C]
            temporal = self.temporal_proj(temporal).unsqueeze(1)  # [B, 1, C]

            # Frequency stream: pool over time (keep frequency), then over frequency.
            frequency = x_2d.mean(dim=3)  # [B, C, H]
            frequency = self.frequency_pool(frequency).squeeze(-1)  # [B, C]
            frequency = self.frequency_proj(frequency).unsqueeze(1)  # [B, 1, C]

            # Fuse both streams into a single token.
            x = self.fusion(torch.cat([temporal, frequency], dim=-1))
        else:
            # Simple global average pooling.
            x = x.mean(dim=1, keepdim=True)

        return x
409
+
410
class ImprovedVideoEncoder(nn.Module):
    """
    Improved video encoder with two alternative front-ends:

    - 3D-conv path (use_3d_conv=True): joint spatio-temporal
      patchification with temporal stride 2;
    - 2D path (default): per-frame 2D patchification followed by
      separate spatial and temporal transformer stacks.

    Output: [B, T', embed_dim] — one token per (possibly downsampled)
    frame, normalised.
    """
    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 14,
        in_channels: int = 3,
        embed_dim: int = 2048,
        spatial_depth: int = 12,
        temporal_depth: int = 4,
        n_heads: int = 16,
        num_frames: int = 16,
        dropout: float = 0.1,
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_3d_conv: bool = False
    ):
        super().__init__()
        self.num_frames = num_frames
        self.use_3d_conv = use_3d_conv
        self.patch_size = patch_size
        self.img_size = img_size

        if use_3d_conv:
            # 3D convolution over (time, height, width); temporal stride 2.
            self.patch_embed = nn.Conv3d(
                in_channels,
                embed_dim,
                kernel_size=(2, patch_size, patch_size),
                stride=(2, patch_size, patch_size)
            )
            # Fix: dimensions after the 3D convolution.
            self.n_temporal_patches = num_frames // 2
            self.n_spatial_patches = (img_size // patch_size) ** 2
        else:
            # 2D convolution per frame + separate temporal modelling.
            self.patch_embed = ImprovedPatchEmbedding(
                patch_size, in_channels, embed_dim
            )
            self.n_spatial_patches = (img_size // patch_size) ** 2

        # Spatial position embedding (shared across frames).
        self.spatial_pos_embed = nn.Parameter(
            torch.zeros(1, self.n_spatial_patches, embed_dim)
        )
        self.spatial_pos_drop = nn.Dropout(dropout)

        # Spatial encoder (applied per frame).
        self.spatial_blocks = nn.ModuleList([
            OptimizedTransformerBlock(
                embed_dim, n_heads, None, None, dropout,
                use_adapter=use_adapter, adapter_dim=adapter_dim
            )
            for _ in range(spatial_depth)
        ])

        # Temporal position embedding (length depends on the front-end).
        if use_3d_conv:
            self.temporal_pos_embed = nn.Parameter(
                torch.zeros(1, self.n_temporal_patches, embed_dim)
            )
        else:
            self.temporal_pos_embed = nn.Parameter(
                torch.zeros(1, num_frames, embed_dim)
            )
        self.temporal_pos_drop = nn.Dropout(dropout)

        # Temporal encoder over per-frame tokens.
        self.temporal_blocks = nn.ModuleList([
            OptimizedTransformerBlock(
                embed_dim, n_heads, None, None, dropout,
                use_adapter=use_adapter, adapter_dim=adapter_dim
            )
            for _ in range(temporal_depth)
        ])

        self.norm = RMSNorm(embed_dim)
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.spatial_pos_embed, std=0.02)
        nn.init.trunc_normal_(self.temporal_pos_embed, std=0.02)
        self.apply(self._init_module_weights)

    def _init_module_weights(self, m):
        # Trunc-normal weights, zero biases (ViT-style).
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.Conv2d, nn.Conv3d)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a video clip [B, T, C, H, W] into per-frame tokens
        [B, T', embed_dim]."""
        B, T, C, H, W = x.shape

        if self.use_3d_conv:
            # Fix: 3D-convolution path.
            x = x.transpose(1, 2)  # [B, C, T, H, W]
            x = self.patch_embed(x)  # [B, embed_dim, T', H', W']

            # Reshape: [B, D, T', H'*W'] -> [B, T', H'*W', D]
            B, D, T_new, H_new, W_new = x.shape
            x = x.view(B, D, T_new, -1).permute(0, 2, 3, 1)  # [B, T', H'*W', D]

            # Spatial position embedding (broadcast across frames).
            x = x + self.spatial_pos_embed.unsqueeze(1)

            # Per-frame spatial encoding (frames flattened into the batch).
            x_flat = x.reshape(B * T_new, -1, D)
            for block in self.spatial_blocks:
                x_flat, _, _ = block(x_flat)

            # Restore the temporal dimension.
            x = x_flat.view(B, T_new, -1, D)

            # Fix: temporal aggregation via mean pooling over patches
            # (instead of taking only the first token).
            x = x.mean(dim=2)  # [B, T', D]

        else:
            # 2D-convolution path with separated spatio-temporal modelling.
            x_flat = x.view(B * T, C, H, W)
            x_patched, grid_size = self.patch_embed(x_flat)

            # Spatial position embedding.
            x_patched = self.spatial_pos_drop(x_patched + self.spatial_pos_embed)

            # Spatial encoding per frame.
            for block in self.spatial_blocks:
                x_patched, _, _ = block(x_patched)

            # Fix: temporal aggregation — average over all patches of each frame.
            _, N, D = x_patched.shape
            x_spatial = x_patched.view(B, T, N, D)
            x = x_spatial.mean(dim=2)  # [B, T, D] — per-frame patch average

        # Temporal position embedding.
        x = self.temporal_pos_drop(x + self.temporal_pos_embed)

        # Temporal encoding across frames.
        for block in self.temporal_blocks:
            x, _, _ = block(x)

        return self.norm(x)
gradio1.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio 推理界面 - 多模态 Dense Transformer (适配 Qwen Tokenizer 版)
3
+
4
+ 用法:
5
+ pip install -r requirements.txt
6
+ # requirements.txt 至少包含:
7
+ # torch>=1.12, transformers, pillow, gradio
8
+ python app_gradio.py --checkpoint /path/to/final_model.pt --tokenizer Qwen/Qwen2.5-7B-Instruct --port 7860 --share False
9
+ """
10
+
11
+ import os
12
+ import argparse
13
+ from pathlib import Path
14
+ import json
15
+ from typing import Optional
16
+
17
+ import torch
18
+ from PIL import Image
19
+ from transformers import AutoTokenizer
20
+
21
+ # UI
22
+ import gradio as gr
23
+
24
+ # 本项目代码引用(按你的工程结构调整)
25
+ from model import MultiModalDenseTransformer
26
+ from continual_learning import UnifiedMultiModalPreprocessor
27
+
28
+ # 设置国内镜像(如需要)
29
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
30
+
31
+ # ---- 与你原来保持一致的图像预处理 ----
32
+ from torchvision import transforms
33
# ImageNet-style preprocessing pipeline. Must stay identical to the transform
# used during training, otherwise image features will be miscalibrated.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # ImageNet channel statistics.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
39
+
40
+ # -------- ModelInference 类(轻微改写) --------
41
class ModelInference:
    """Inference wrapper for the multimodal Dense Transformer.

    Loads the tokenizer, reconstructs the model architecture from a config
    (file or built-in default), restores weights from a checkpoint and exposes
    `generate_text` for (optionally image-conditioned) text generation.
    """

    def __init__(self, checkpoint_path: str, tokenizer_name: str, config_path: Optional[str] = None, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        # NOTE: the default for `device` is evaluated once at import time.
        self.device = torch.device(device)
        print(f"Using device: {self.device}")
        print(f"Loading tokenizer: {tokenizer_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True, trust_remote_code=True)
        # Qwen tokenizers may ship without a pad token; reuse EOS for padding.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        if config_path and Path(config_path).exists():
            with open(config_path, 'r') as f:
                self.config = json.load(f)
        else:
            # Default architecture config (must match the trained checkpoint,
            # otherwise load_state_dict below will report missing keys).
            self.config = {
                'model_dim': 1536,
                'vocab_size': len(self.tokenizer),
                'n_layers': 12,
                'n_heads': 12,
                'n_kv_heads': 4,
                'head_dim': None,
                'max_seq_len': 512,
                'dropout': 0.0,
                'use_moe': False,
                'use_adapter': False,
                'use_lora': False,
                'rope_scaling_type': "yarn",
                'use_multimodal_fusion': False,
                'use_contrastive': False
            }

        # Instantiate model and multimodal preprocessor.
        print("Initializing model architecture...")
        self.model = MultiModalDenseTransformer(**self.config)
        self.preprocessor = UnifiedMultiModalPreprocessor(model_dim=self.config['model_dim'])

        print(f"Loading checkpoint from {checkpoint_path}...")
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        # Support both a raw state_dict and a dict wrapping 'model_state_dict'.
        state_dict = checkpoint.get('model_state_dict', checkpoint) if isinstance(checkpoint, dict) else checkpoint

        # Strip the DataParallel/DistributedDataParallel 'module.' prefix.
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith('module.'):
                new_state_dict[k[7:]] = v
            else:
                new_state_dict[k] = v

        # strict=False so architecture drift only warns instead of raising.
        missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False)
        if missing:
            print(f"Warning: Missing keys: {len(missing)}")
        if unexpected:
            print(f"Warning: Unexpected keys: {len(unexpected)}")

        self.model.to(self.device)
        self.preprocessor.to(self.device)
        self.model.eval()
        print("Model loaded successfully!")
        print(f"Total parameters: {sum(p.numel() for p in self.model.parameters())/1e6:.2f}M")

    @torch.no_grad()
    def generate_text(self, prompt: str, max_new_tokens: int = 128, temperature: float = 0.7, top_k: int = 10, top_p: float = 0.9, repetition_penalty: float = 1.2, image: Optional[Image.Image] = None) -> str:
        """Generate a response for `prompt`, optionally conditioned on `image`.

        Returns the decoded answer string, or an "Error: ..." string if
        generation fails (errors are caught, not raised).
        """
        # Instruction-following prompt format used at training time.
        formatted_prompt = f"Instruction: {prompt}\nResponse:"
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
        input_ids = inputs['input_ids'].to(self.device)

        # Multimodal input: image segments (if any) precede the text segment.
        input_data = {'segments': []}
        if image is not None:
            try:
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                image_tensor = image_transform(image).unsqueeze(0).to(self.device)
                mod_segments = self.preprocessor.process_batch(image_tensor, 'image')
                for seg in mod_segments:
                    input_data['segments'].append(seg)
            except Exception as e:
                # Best-effort: fall back to text-only generation on failure.
                print(f"Warning: Image processing skipped due to error: {e}")

        input_data['segments'].append({
            'type': 'text',
            'data': input_ids,
            'modality_id': 0
        })

        try:
            generated_ids = self.model.generate(
                input_data,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id
            )

            full_output = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            # Extract the portion after "Response:" and apply stop words.
            if "Response:" in full_output:
                answer = full_output.split("Response:")[-1].strip()
            else:
                answer = full_output

            # Truncate at any known stop marker to avoid prompt-format leakage.
            stop_words = ["Instruction", "Input", "###", "Response", "User:", "Assistant:", "\n\n"]
            for sw in stop_words:
                if sw in answer:
                    answer = answer.split(sw)[0].strip()

            # Drop a leading line that merely echoes the prompt.
            lines = answer.split('\n')
            if len(lines) > 0 and prompt.lower() in lines[0].lower():
                answer = "\n".join(lines[1:]).strip()
            return answer
        except Exception as e:
            import traceback
            traceback.print_exc()
            return f"Error: {e}"
160
+
161
+ # -------- Gradio UI 部分 --------
162
def build_ui(model_instance):
    """Build and return the Gradio Blocks UI around a ModelInference instance.

    Args:
        model_instance: a loaded ModelInference providing `generate_text`.

    Returns:
        The constructed (un-launched) gr.Blocks demo.

    Bug fix: the previous version called ``demo.launch(share=True)`` inside
    this function. ``launch`` blocks by default, so ``main()`` never reached
    its own ``demo.launch(server_port=args.port, share=args.share)`` call and
    the CLI --port/--share options were silently ignored (and a second launch
    would have been attempted on the same app). This function now only builds
    the interface; launching is the caller's responsibility.
    """
    with gr.Blocks(title="MultiModal Dense Transformer - Gradio", css="""
    .gradio-container { max-width: 900px; margin: auto; }
    """) as demo:
        gr.Markdown("## 🚀 多模态在线推理(文本 + 图片)")
        with gr.Row():
            with gr.Column(scale=3):
                txt = gr.Textbox(label="Prompt (Instruction)", placeholder="请输入指令或问题...", lines=5)
                img = gr.Image(type="pil", label="(可选) 上传图片(支持多模态)")
                btn = gr.Button("生成 (Generate)")
            with gr.Column(scale=2):
                max_tokens = gr.Slider(label="Max New Tokens", minimum=16, maximum=1024, step=1, value=128)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, step=0.01, value=0.7)
                top_k = gr.Slider(label="Top-k", minimum=0, maximum=200, step=1, value=40)
                top_p = gr.Slider(label="Top-p", minimum=0.0, maximum=1.0, step=0.01, value=0.9)
                rep_pen = gr.Slider(label="Repetition Penalty", minimum=0.5, maximum=2.0, step=0.01, value=1.1)
                status = gr.Textbox(label="Status", value="Ready", interactive=False)
        output = gr.Textbox(label="Output", lines=12, interactive=False)

        def gr_generate(prompt, image, max_tokens_v, temp_v, topk_v, topp_v, rep_v):
            # Guard against an empty prompt before touching the model.
            if not prompt or str(prompt).strip() == "":
                return "", "请输入 Prompt"
            out = model_instance.generate_text(prompt=prompt,
                                               max_new_tokens=int(max_tokens_v),
                                               temperature=float(temp_v),
                                               top_k=int(topk_v),
                                               top_p=float(topp_v),
                                               repetition_penalty=float(rep_v),
                                               image=image)
            return out, "Done"

        # Two outputs (text + status); the previous third gr.State() output
        # was never read and has been dropped.
        btn.click(fn=gr_generate, inputs=[txt, img, max_tokens, temperature, top_k, top_p, rep_pen], outputs=[output, status])

    return demo
200
+
201
+ # -------- CLI / main --------
202
def main():
    """CLI entry point: parse args, resolve a checkpoint, load the model and
    launch the Gradio UI with the configured port/share options."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default="/root/multimodal/checkpoints/posttrain/final_model.pt")
    parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2.5-7B-Instruct")
    parser.add_argument("--config", type=str, default=None)
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--share", type=lambda x: x.lower() in ("true","1","yes"), default=True)
    args = parser.parse_args()

    # Fall back to the most recent step checkpoint if final_model.pt is absent.
    # Bug fix: Path.glob() yields files in arbitrary order, so taking [-1] of
    # the raw list did NOT select the latest checkpoint; sort by mtime first.
    if not Path(args.checkpoint).exists():
        possible = sorted(
            Path("checkpoints/pretrain").glob("step_*.pt"),
            key=lambda p: p.stat().st_mtime,
        )
        if possible:
            args.checkpoint = str(possible[-1])
            print(f"未找到 final_model.pt,使用最新 checkpoint: {args.checkpoint}")
        else:
            raise FileNotFoundError(f"找不到检查点: {args.checkpoint}")

    global model_instance
    model_instance = ModelInference(args.checkpoint, args.tokenizer, args.config)

    # Launch Gradio (share flag decides whether a public link is created).
    demo = build_ui(model_instance)
    demo.launch(server_port=args.port, share=args.share)

if __name__ == "__main__":
    main()
grpo.py ADDED
@@ -0,0 +1,630 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 改进的 GRPO (Group Relative Policy Optimization) 训练器
3
+ 修复了所有已知问题
4
+ """
5
+
6
+ import torch
7
+ import torch.optim as optim
8
+ import torch.nn.functional as F
9
+ from torch.utils.data import DataLoader, TensorDataset
10
+ from typing import Dict, List, Tuple, Optional
11
+ from tqdm import tqdm
12
+ import numpy as np
13
+ import gc
14
+ import logging
15
+
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ class GRPOTrainer:
21
+ """
22
+ GRPO训练器 - Group Relative Policy Optimization
23
+ 参考 DeepSeekMath/DeepSeek-V3 策略
24
+
25
+ 主要修复:
26
+ 1. 修复了 generate() 返回格式问题
27
+ 2. 修复了 reward_model 输出处理
28
+ 3. 添加了完整的混合精度训练支持
29
+ 4. 改进了 KL 散度计算的数值稳定性
30
+ 5. 修复了 past_key_values 的使用逻辑
31
+ 6. 改进了内存管理和错误处理
32
+ """
33
+
34
    def __init__(
        self,
        actor_model,
        reward_model,
        ref_model,
        tokenizer,
        learning_rate: float = 1e-6,
        kl_coef: float = 0.04,
        group_size: int = 4,
        clip_epsilon: float = 0.2,
        max_grad_norm: float = 1.0,
        grpo_epochs: int = 1,
        update_batch_size: int = 4,
        use_amp: bool = True,
        value_clip: bool = False,
        entropy_coef: float = 0.01,
        advantage_normalization: str = 'group',  # 'group', 'global', 'none'
        kl_estimation_method: str = 'forward'  # 'forward', 'reverse', 'symmetric'
    ):
        """
        Initialize the GRPO trainer.

        Args:
            actor_model: policy model being trained.
            reward_model: reward model (frozen).
            ref_model: reference model for the KL penalty (frozen).
            tokenizer: tokenizer shared by all models.
            learning_rate: AdamW learning rate.
            kl_coef: KL-divergence penalty coefficient.
            group_size: number of samples generated per prompt.
            clip_epsilon: PPO clip range for the probability ratio.
            max_grad_norm: gradient-clipping threshold.
            grpo_epochs: training epochs per batch of experience.
            update_batch_size: mini-batch size during policy updates.
            use_amp: whether to use mixed-precision training.
            value_clip: whether to clip values (currently unused — no value net).
            entropy_coef: entropy regularization coefficient.
            advantage_normalization: advantage normalization scheme.
            kl_estimation_method: KL-divergence estimation method.
        """
        self.actor = actor_model
        self.reward_model = reward_model
        self.ref_model = ref_model
        self.tokenizer = tokenizer

        self.kl_coef = kl_coef
        self.group_size = group_size
        self.clip_epsilon = clip_epsilon
        self.max_grad_norm = max_grad_norm
        self.grpo_epochs = grpo_epochs
        self.update_batch_size = update_batch_size
        self.use_amp = use_amp
        self.entropy_coef = entropy_coef
        self.advantage_normalization = advantage_normalization
        self.kl_estimation_method = kl_estimation_method

        # All models are assumed to live on the same device as the actor.
        self.device = next(actor_model.parameters()).device

        # Freeze the reference and reward models — only the actor is trained.
        self.ref_model.eval()
        self.ref_model.requires_grad_(False)
        self.reward_model.eval()
        self.reward_model.requires_grad_(False)

        # Optimizer over trainable actor parameters only.
        self.optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, actor_model.parameters()),
            lr=learning_rate,
            weight_decay=0.01,
            betas=(0.9, 0.95),
            eps=1e-8
        )

        # Mixed-precision loss scaler (a no-op when use_amp is False).
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)

        # Cumulative training statistics, also persisted in checkpoints.
        self.training_stats = {
            'iterations': 0,
            'total_samples': 0,
            'avg_rewards': [],
            'avg_kl': [],
            'policy_losses': []
        }

        logger.info(f"GRPO Trainer initialized:")
        logger.info(f"  Group Size: {group_size}")
        logger.info(f"  KL Coef: {kl_coef}")
        logger.info(f"  Clip Epsilon: {clip_epsilon}")
        logger.info(f"  Learning Rate: {learning_rate}")
        logger.info(f"  Update Batch Size: {update_batch_size}")
        logger.info(f"  Mixed Precision: {use_amp}")
        logger.info(f"  KL Estimation: {kl_estimation_method}")
127
+
128
+ def _compute_kl_divergence(
129
+ self,
130
+ log_probs: torch.Tensor,
131
+ ref_log_probs: torch.Tensor,
132
+ mask: torch.Tensor
133
+ ) -> torch.Tensor:
134
+ """
135
+ 计算KL散度(改进数值稳定性)
136
+
137
+ Args:
138
+ log_probs: 当前策略的log概率
139
+ ref_log_probs: 参考策略的log概率
140
+ mask: 有效token的mask
141
+
142
+ Returns:
143
+ KL散度(标量)
144
+ """
145
+ if self.kl_estimation_method == 'forward':
146
+ # KL(π||π_ref) = Σ π * log(π/π_ref)
147
+ # ≈ Σ exp(log_π) * (log_π - log_π_ref)
148
+ # 为了数值稳定,使用 log_π - log_π_ref
149
+ kl = log_probs - ref_log_probs
150
+ elif self.kl_estimation_method == 'reverse':
151
+ # KL(π_ref||π) = Σ π_ref * log(π_ref/π)
152
+ kl = ref_log_probs - log_probs
153
+ else: # symmetric
154
+ # 对称KL散度
155
+ forward_kl = log_probs - ref_log_probs
156
+ reverse_kl = ref_log_probs - log_probs
157
+ kl = 0.5 * (forward_kl + reverse_kl)
158
+
159
+ # 应用mask并求和
160
+ kl_penalty = (kl * mask).sum(dim=-1)
161
+ return kl_penalty
162
+
163
    @torch.no_grad()
    def generate_experience(
        self,
        prompts_dataloader: DataLoader,
        max_gen_len: int,
        temperature: float = 1.0,
        top_p: float = 0.9
    ) -> Dict:
        """
        Roll out the current policy to collect GRPO training experience.

        Pipeline: sample `group_size` completions per prompt -> score
        per-token log-probs under actor and reference models -> compute
        KL-penalized rewards -> derive group-relative advantages.

        Args:
            prompts_dataloader: yields batches of prompt token-id tensors
                (or (tensor, ...) tuples; only the first element is used).
            max_gen_len: maximum number of newly generated tokens.
            temperature: sampling temperature.
            top_p: nucleus-sampling threshold.

        Returns:
            Dict of CPU tensors: 'sequences' (prompt+response ids),
            'log_probs' (per-token actor log-probs), 'advantages',
            'prompt_lengths' and 'rewards' (KL-penalized totals).

        Raises:
            RuntimeError: if every batch failed and no sequences were kept.

        NOTE(review): all prompts in a dataloader batch are assumed to share
        the same (padded) length — prompt_len is taken from the batch shape.
        """
        self.actor.eval()

        all_sequences = []
        all_log_probs = []
        all_advantages = []
        all_prompt_lens = []
        all_rewards = []

        logger.info("Generating experience...")

        for prompts in tqdm(prompts_dataloader, desc="Generating Experience"):
            try:
                # Accept both bare tensors and (tensor, ...) dataset tuples.
                if isinstance(prompts, (list, tuple)):
                    prompts = prompts[0]

                prompts = prompts.to(self.device)
                batch_size = prompts.shape[0]

                # Duplicate each prompt group_size times so the group-relative
                # baseline can be computed within each prompt's samples.
                prompts_repeated = prompts.repeat_interleave(self.group_size, dim=0)
                prompt_len = prompts_repeated.shape[1]

                input_data = {
                    'segments': [{
                        'type': 'text',
                        'data': prompts_repeated,
                        'modality_id': 0
                    }]
                }

                # 1. Sample completions. The model's generate() returns only
                # the newly generated tokens (not the prompt).
                with torch.amp.autocast('cuda', enabled=self.use_amp):
                    response_ids = self.actor.generate(
                        input_data,
                        max_new_tokens=max_gen_len,
                        do_sample=True,
                        temperature=temperature,
                        top_p=top_p,
                        eos_token_id=self.tokenizer.eos_token_id,
                        pad_token_id=self.tokenizer.pad_token_id,
                        use_cache=True  # KV cache to speed up decoding
                    )

                # Full sequence = prompt followed by response.
                sequences = torch.cat([prompts_repeated, response_ids], dim=1)

                # Skip degenerate batches with no generated tokens.
                if sequences.shape[1] <= prompt_len:
                    logger.warning("Generated sequence too short, skipping batch")
                    continue

                full_input_data = {
                    'segments': [{
                        'type': 'text',
                        'data': sequences,
                        'modality_id': 0
                    }]
                }

                # 2. Score the full sequences under actor and reference models.
                with torch.amp.autocast('cuda', enabled=self.use_amp):
                    actor_out = self.actor(full_input_data)
                    ref_out = self.ref_model(full_input_data)

                    # Shift: logits at position t predict token t+1.
                    logits = actor_out['logits'][:, :-1, :]
                    ref_logits = ref_out['logits'][:, :-1, :]
                    targets = sequences[:, 1:]

                    # log_softmax for numerically stable log-probabilities.
                    log_probs = F.log_softmax(logits, dim=-1)
                    ref_log_probs = F.log_softmax(ref_logits, dim=-1)

                    # Gather the log-prob of each realized token.
                    per_token_log_probs = torch.gather(
                        log_probs, -1, targets.unsqueeze(-1)
                    ).squeeze(-1)
                    per_token_ref_log_probs = torch.gather(
                        ref_log_probs, -1, targets.unsqueeze(-1)
                    ).squeeze(-1)

                # 3. Response mask: positions >= prompt_len-1 in the shifted
                # frame correspond to generated tokens.
                response_mask = torch.arange(
                    sequences.size(1) - 1, device=self.device
                ) >= (prompt_len - 1)
                response_mask = response_mask.unsqueeze(0).expand_as(per_token_log_probs)
                response_mask = response_mask.float()

                # KL penalty between actor and reference on response tokens.
                kl_penalty = self._compute_kl_divergence(
                    per_token_log_probs,
                    per_token_ref_log_probs,
                    response_mask
                )

                # 4. Environment reward from the reward model.
                with torch.amp.autocast('cuda', enabled=self.use_amp):
                    reward_output = self.reward_model(full_input_data)

                # Reward model may return (B, T) per-position scores — take
                # the last position — or an already-reduced (B, 1)/(B,) score.
                if reward_output.dim() == 2:
                    raw_rewards = reward_output[:, -1]
                else:
                    raw_rewards = reward_output.squeeze(-1)

                # 5. Total reward: R_total = R_env - beta * KL.
                total_rewards = raw_rewards - self.kl_coef * kl_penalty

                # 6. Group-relative advantages (GRPO's critic-free baseline).
                rewards_grouped = total_rewards.view(batch_size, self.group_size)

                if self.advantage_normalization == 'group':
                    # Standardize within each prompt's group of samples.
                    mean_grouped = rewards_grouped.mean(dim=1, keepdim=True)
                    std_grouped = rewards_grouped.std(dim=1, keepdim=True) + 1e-8
                    advantages = (rewards_grouped - mean_grouped) / std_grouped
                elif self.advantage_normalization == 'global':
                    # Standardize over the whole batch.
                    advantages = (rewards_grouped - rewards_grouped.mean()) / (
                        rewards_grouped.std() + 1e-8
                    )
                else:  # 'none'
                    advantages = rewards_grouped - rewards_grouped.mean(dim=1, keepdim=True)

                advantages = advantages.view(-1)

                # Stash everything on CPU to keep GPU memory flat.
                all_sequences.append(sequences.cpu())
                all_log_probs.append(per_token_log_probs.detach().cpu())
                all_advantages.append(advantages.detach().cpu())
                all_prompt_lens.append(
                    torch.full((sequences.size(0),), prompt_len, dtype=torch.long)
                )
                all_rewards.append(total_rewards.detach().cpu())

                # Drop large intermediates before the next batch.
                del logits, ref_logits, actor_out, ref_out
                del log_probs, ref_log_probs, reward_output

            except Exception as e:
                # Best-effort: log and continue with the remaining batches.
                logger.error(f"Error generating experience for batch: {e}")
                import traceback
                traceback.print_exc()
                continue

            finally:
                torch.cuda.empty_cache()

        if not all_sequences:
            raise RuntimeError("No valid sequences generated")

        # Concatenate all per-batch buffers.
        # NOTE(review): torch.cat across batches assumes every batch produced
        # sequences of the same padded length — confirm against generate().
        experience = {
            'sequences': torch.cat(all_sequences, dim=0),
            'log_probs': torch.cat(all_log_probs, dim=0),
            'advantages': torch.cat(all_advantages, dim=0),
            'prompt_lengths': torch.cat(all_prompt_lens, dim=0),
            'rewards': torch.cat(all_rewards, dim=0)
        }

        # Summary statistics.
        logger.info(f"Generated {len(experience['sequences'])} sequences")
        logger.info(f"Avg Reward: {experience['rewards'].mean().item():.4f}")
        logger.info(f"Reward Std: {experience['rewards'].std().item():.4f}")
        logger.info(f"Avg Advantage: {experience['advantages'].mean().item():.4f}")

        return experience
346
+
347
    def grpo_step(
        self,
        dataset: TensorDataset
    ) -> Dict[str, float]:
        """
        Run one optimization pass (one epoch) over a batch of experience.

        Args:
            dataset: TensorDataset of (sequences, old_log_probs, advantages,
                prompt_lengths) as produced by generate_experience().

        Returns:
            Dict of averaged statistics: total_loss, policy_loss, entropy,
            approx_kl, clip_fraction, plus the raw step count.
        """
        self.actor.train()

        dataloader = DataLoader(
            dataset,
            batch_size=self.update_batch_size,
            shuffle=True,
            drop_last=False
        )

        # Running sums; divided by 'steps' at the end.
        epoch_stats = {
            'total_loss': 0.0,
            'policy_loss': 0.0,
            'entropy': 0.0,
            'approx_kl': 0.0,
            'clip_fraction': 0.0,
            'steps': 0
        }

        for batch_data in dataloader:
            sequences, old_log_probs, advantages, prompt_lens = batch_data

            sequences = sequences.to(self.device)
            old_log_probs = old_log_probs.to(self.device)
            advantages = advantages.to(self.device)

            input_data = {
                'segments': [{
                    'type': 'text',
                    'data': sequences,
                    'modality_id': 0
                }]
            }

            # Forward pass under autocast; backward is handled by the scaler.
            with torch.amp.autocast('cuda', enabled=self.use_amp):
                outputs = self.actor(input_data)
                logits = outputs['logits'][:, :-1, :]

                # Recompute log-probs of the stored tokens under the
                # (updated) current policy.
                targets = sequences[:, 1:]
                log_probs_dist = F.log_softmax(logits, dim=-1)
                new_log_probs = torch.gather(
                    log_probs_dist, -1, targets.unsqueeze(-1)
                ).squeeze(-1)

                # Response mask: only generated tokens contribute to the loss.
                mask = torch.zeros_like(new_log_probs)
                for i, pl in enumerate(prompt_lens):
                    mask[i, pl-1:] = 1.0

                # Importance ratio pi_new / pi_old.
                ratio = torch.exp(new_log_probs - old_log_probs)

                # One scalar advantage per sequence, broadcast over tokens.
                adv_expanded = advantages.unsqueeze(-1).expand_as(new_log_probs)

                # PPO clipped surrogate objective.
                surr1 = ratio * adv_expanded
                surr2 = torch.clamp(
                    ratio,
                    1.0 - self.clip_epsilon,
                    1.0 + self.clip_epsilon
                ) * adv_expanded

                # Policy loss (negated objective, averaged over masked tokens).
                policy_loss = -torch.min(surr1, surr2)
                policy_loss = (policy_loss * mask).sum() / (mask.sum() + 1e-8)

                # Entropy bonus to encourage exploration.
                probs = F.softmax(logits, dim=-1)
                entropy = -(probs * log_probs_dist).sum(dim=-1)
                entropy_bonus = (entropy * mask).sum() / (mask.sum() + 1e-8)

                # Total loss.
                loss = policy_loss - self.entropy_coef * entropy_bonus

                # Diagnostics (no gradients needed).
                with torch.no_grad():
                    # k3-style approximate KL between old and new policies.
                    log_ratio = new_log_probs - old_log_probs
                    approx_kl = ((ratio - 1) - log_ratio) * mask
                    approx_kl = approx_kl.sum() / (mask.sum() + 1e-8)

                    # Fraction of tokens where the ratio hit the clip range.
                    clip_fraction = ((ratio > 1 + self.clip_epsilon) |
                                     (ratio < 1 - self.clip_epsilon)).float()
                    clip_fraction = (clip_fraction * mask).sum() / (mask.sum() + 1e-8)

            # Scaled backward pass for mixed precision.
            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()

            # Unscale before clipping so the threshold applies to true grads.
            self.scaler.unscale_(self.optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.actor.parameters(),
                self.max_grad_norm
            )

            self.scaler.step(self.optimizer)
            self.scaler.update()

            # Accumulate statistics.
            epoch_stats['total_loss'] += loss.item()
            epoch_stats['policy_loss'] += policy_loss.item()
            epoch_stats['entropy'] += entropy_bonus.item()
            epoch_stats['approx_kl'] += approx_kl.item()
            epoch_stats['clip_fraction'] += clip_fraction.item()
            epoch_stats['steps'] += 1

        # Convert running sums into per-step averages ('steps' kept raw).
        for key in epoch_stats:
            if key != 'steps':
                epoch_stats[key] /= max(epoch_stats['steps'], 1)

        return epoch_stats
473
+
474
    def train(
        self,
        prompt_dataloader: DataLoader,
        num_iterations: int = 1,
        max_gen_len: int = 50,
        temperature: float = 1.0,
        save_every: int = 5,
        save_path: str = "checkpoints"
    ):
        """
        Full GRPO training loop: rollout -> optimize -> log -> checkpoint.

        Args:
            prompt_dataloader: dataloader providing prompt batches.
            num_iterations: number of rollout/optimize iterations.
            max_gen_len: maximum generation length per rollout.
            temperature: sampling temperature for rollouts.
            save_every: checkpoint interval (in iterations).
            save_path: directory for checkpoints.
        """
        logger.info(f"\n{'='*80}")
        logger.info(f"Starting GRPO Training")
        logger.info(f"  Iterations: {num_iterations}")
        logger.info(f"  Max Gen Length: {max_gen_len}")
        logger.info(f"  Temperature: {temperature}")
        logger.info(f"{'='*80}\n")

        for iteration in range(num_iterations):
            try:
                # 1. Collect fresh experience with the current policy.
                experience = self.generate_experience(
                    prompt_dataloader,
                    max_gen_len,
                    temperature
                )

                dataset = TensorDataset(
                    experience['sequences'],
                    experience['log_probs'],
                    experience['advantages'],
                    experience['prompt_lengths']
                )

                # 2. Optimize the policy for grpo_epochs passes over the data.
                logger.info(f"Optimizing policy for {self.grpo_epochs} epochs...")
                all_epoch_stats = []

                for epoch in range(self.grpo_epochs):
                    stats = self.grpo_step(dataset)
                    all_epoch_stats.append(stats)

                    logger.info(
                        f"  Epoch {epoch+1}/{self.grpo_epochs} | "
                        f"Loss: {stats['total_loss']:.4f} | "
                        f"KL: {stats['approx_kl']:.4f} | "
                        f"Clip%: {stats['clip_fraction']*100:.1f}"
                    )

                # 3. Average the per-epoch statistics.
                avg_stats = {
                    key: np.mean([s[key] for s in all_epoch_stats])
                    for key in all_epoch_stats[0].keys()
                }

                self.training_stats['iterations'] += 1
                self.training_stats['total_samples'] += len(experience['sequences'])
                self.training_stats['avg_rewards'].append(
                    experience['rewards'].mean().item()
                )
                self.training_stats['avg_kl'].append(avg_stats['approx_kl'])
                self.training_stats['policy_losses'].append(avg_stats['policy_loss'])

                # 4. Progress report.
                logger.info(f"\n{'='*80}")
                logger.info(f"Iteration {iteration+1}/{num_iterations} Complete")
                logger.info(f"  Avg Reward: {experience['rewards'].mean():.4f}")
                logger.info(f"  Avg Advantage: {experience['advantages'].mean():.4f}")
                logger.info(f"  Policy Loss: {avg_stats['policy_loss']:.4f}")
                logger.info(f"  Approx KL: {avg_stats['approx_kl']:.4f}")
                logger.info(f"  Entropy: {avg_stats['entropy']:.4f}")
                logger.info(f"  Clip Fraction: {avg_stats['clip_fraction']*100:.1f}%")
                logger.info(f"{'='*80}\n")

                # 5. Periodic checkpointing.
                if (iteration + 1) % save_every == 0:
                    self.save_checkpoint(
                        f"{save_path}/grpo_iter_{iteration+1}.pt"
                    )

                # 6. Release memory before the next rollout.
                del experience, dataset
                gc.collect()
                torch.cuda.empty_cache()

            except Exception as e:
                # Best-effort: a failed iteration is logged and skipped.
                logger.error(f"Error in iteration {iteration+1}: {e}")
                import traceback
                traceback.print_exc()
                continue

        logger.info("GRPO Training Complete!")
        self.print_training_summary()
577
+ def save_checkpoint(self, path: str):
578
+ """保存训练checkpoint"""
579
+ import os
580
+ os.makedirs(os.path.dirname(path), exist_ok=True)
581
+
582
+ checkpoint = {
583
+ 'actor_state_dict': self.actor.state_dict(),
584
+ 'optimizer_state_dict': self.optimizer.state_dict(),
585
+ 'scaler_state_dict': self.scaler.state_dict(), # 修复:保存scaler状态
586
+ 'training_stats': self.training_stats,
587
+ 'config': {
588
+ 'kl_coef': self.kl_coef,
589
+ 'group_size': self.group_size,
590
+ 'clip_epsilon': self.clip_epsilon,
591
+ }
592
+ }
593
+
594
+ torch.save(checkpoint, path)
595
+ logger.info(f"Checkpoint saved to {path}")
596
+
597
+ def load_checkpoint(self, path: str):
598
+ """加载训练checkpoint"""
599
+ checkpoint = torch.load(path, map_location=self.device)
600
+
601
+ self.actor.load_state_dict(checkpoint['actor_state_dict'])
602
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
603
+
604
+ # 修复:加载scaler状态
605
+ if 'scaler_state_dict' in checkpoint and self.use_amp:
606
+ self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
607
+
608
+ self.training_stats = checkpoint['training_stats']
609
+
610
+ logger.info(f"Checkpoint loaded from {path}")
611
+
612
+ def print_training_summary(self):
613
+ """打印训练摘要"""
614
+ logger.info("\n" + "="*80)
615
+ logger.info("Training Summary")
616
+ logger.info("="*80)
617
+ logger.info(f"Total Iterations: {self.training_stats['iterations']}")
618
+ logger.info(f"Total Samples: {self.training_stats['total_samples']}")
619
+
620
+ if self.training_stats['avg_rewards']:
621
+ logger.info(
622
+ f"Final Avg Reward: "
623
+ f"{self.training_stats['avg_rewards'][-1]:.4f}"
624
+ )
625
+ logger.info(
626
+ f"Reward Improvement: "
627
+ f"{self.training_stats['avg_rewards'][-1] - self.training_stats['avg_rewards'][0]:.4f}"
628
+ )
629
+
630
+ logger.info("="*80 + "\n")
infer.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Flask推理界面 - 多模态Dense Transformer (适配 Qwen Tokenizer 版)
3
+ """
4
+
5
+ import os
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from flask import Flask, render_template, request, jsonify
9
+ from transformers import AutoTokenizer
10
+ from PIL import Image
11
+ import json
12
+ import io
13
+ import base64
14
+ from pathlib import Path
15
+ from typing import Optional
16
+
17
+ # 确保引入路径正确,根据你之前的文件结构
18
+ from model import MultiModalDenseTransformer
19
+ # 注意:UnifiedMultiModalPreprocessor 之前是在 continual_learning.py 中定义的
20
+ # 如果你移动了它,请修改这里的导入路径
21
+ from continual_learning import UnifiedMultiModalPreprocessor
22
+ # 如果没有 image_transform,我们需要在这里定义或导入
23
+ from torchvision import transforms
24
+
25
+ # 设置国内镜像
26
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
27
+
28
+ # 定义图像预处理 (与 training 保持一致)
29
+ image_transform = transforms.Compose([
30
+ transforms.Resize((224, 224)),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
33
+ ])
34
+
35
+ class ModelInference:
36
+ """模型推理类"""
37
+
38
+ def __init__(
39
+ self,
40
+ checkpoint_path: str,
41
+ tokenizer_name: str,
42
+ config_path: Optional[str] = None,
43
+ device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
44
+ ):
45
+ self.device = torch.device(device)
46
+ print(f"Using device: {self.device}")
47
+
48
+ # 1. 加载 Tokenizer (与预训练一致)
49
+ print(f"Loading tokenizer: {tokenizer_name}...")
50
+ try:
51
+ self.tokenizer = AutoTokenizer.from_pretrained(
52
+ tokenizer_name,
53
+ use_fast=True,
54
+ trust_remote_code=True
55
+ )
56
+ if self.tokenizer.pad_token is None:
57
+ self.tokenizer.pad_token = self.tokenizer.eos_token
58
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
59
+ except Exception as e:
60
+ print(f"Error loading tokenizer: {e}")
61
+ raise e
62
+
63
+ # 2. 配置模型参数 (必须与 pretrain.py 中的配置完全一致)
64
+ if config_path and Path(config_path).exists():
65
+ with open(config_path, 'r') as f:
66
+ self.config = json.load(f)
67
+ else:
68
+ # [CRITICAL] 这里使用了你在 pretrain.py 中使用的参数
69
+ self.config = {
70
+ 'model_dim': 1536, # 预训练设置
71
+ 'vocab_size': len(self.tokenizer), # 自动适配 Qwen (约 151665)
72
+ 'n_layers': 12, # 预训练设置
73
+ 'n_heads': 12, # 预训练设置
74
+ 'n_kv_heads': 4, # 预训练设置
75
+ 'head_dim': None, # 自动计算
76
+ 'max_seq_len': 512, # 预训练设置
77
+ 'dropout': 0.0, # 推理时关闭 dropout
78
+ 'use_moe': False, # 预训练设置
79
+ 'use_adapter': False, # 预训练未开启 Adapter
80
+ 'use_lora': False, # 预训练未开启 LoRA
81
+ 'rope_scaling_type': "yarn" # 预训练设置
82
+ }
83
+
84
+ # 3. 初始化模型结构
85
+ print("Initializing model architecture...")
86
+ try:
87
+ self.model = MultiModalDenseTransformer(**self.config)
88
+ self.preprocessor = UnifiedMultiModalPreprocessor(
89
+ model_dim=self.config['model_dim']
90
+ )
91
+
92
+ # 4. 加载权重
93
+ print(f"Loading checkpoint from {checkpoint_path}...")
94
+ # weights_only=False 是为了支持加载完整的 checkpoint 字典
95
+ checkpoint = torch.load(
96
+ checkpoint_path,
97
+ map_location=self.device,
98
+ weights_only=False
99
+ )
100
+
101
+ # 提取 state_dict
102
+ if 'model_state_dict' in checkpoint:
103
+ print("Found 'model_state_dict' in checkpoint.")
104
+ state_dict = checkpoint['model_state_dict']
105
+ else:
106
+ state_dict = checkpoint
107
+
108
+ # 处理可能的键名不匹配 (如 DDP 训练产生的 'module.' 前缀)
109
+ new_state_dict = {}
110
+ for k, v in state_dict.items():
111
+ if k.startswith('module.'):
112
+ new_state_dict[k[7:]] = v
113
+ else:
114
+ new_state_dict[k] = v
115
+
116
+ # 加载权重 (strict=False 允许忽略一些非关键的不匹配,如 loss 缓存等)
117
+ missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False)
118
+ if missing:
119
+ print(f"Warning: Missing keys: {len(missing)}")
120
+ if unexpected:
121
+ print(f"Warning: Unexpected keys: {len(unexpected)}")
122
+
123
+ self.model.to(self.device)
124
+ self.preprocessor.to(self.device)
125
+ self.model.eval()
126
+
127
+ print("Model loaded successfully!")
128
+ print(f"Total parameters: {sum(p.numel() for p in self.model.parameters()) / 1e6:.2f}M")
129
+
130
+ except Exception as e:
131
+ print(f"Error initializing model: {e}")
132
+ raise e
133
+
134
+ @torch.no_grad()
135
+ def generate_text(
136
+ self,
137
+ prompt: str,
138
+ max_new_tokens: int = 128,
139
+ temperature: float = 0.7,
140
+ top_k: int = 40,
141
+ top_p: float = 0.9,
142
+ repetition_penalty: float = 1.1,
143
+ image: Optional[Image.Image] = None
144
+ ) -> str:
145
+ """生成文本"""
146
+
147
+ # 编码输入
148
+ inputs = self.tokenizer(prompt, return_tensors="pt")
149
+ input_ids = inputs['input_ids'].to(self.device)
150
+
151
+ # 构建 MultiModalDenseTransformer 需要的输入格式
152
+ input_data = {'segments': []}
153
+
154
+ # 处理图像
155
+ if image is not None:
156
+ if image.mode != 'RGB':
157
+ image = image.convert('RGB')
158
+ # 简单的图像处理
159
+ image_tensor = image_transform(image).unsqueeze(0).to(self.device)
160
+ # 这里假设预处理器能处理这种输入
161
+ try:
162
+ # process_batch 接受 (batch_data, modality_type) 并返回 segments 列表
163
+ mod_segments = self.preprocessor.process_batch(image_tensor, 'image')
164
+ # 将返回的 segment 列表合并到 input_data
165
+ for seg in mod_segments:
166
+ input_data['segments'].append(seg)
167
+ except Exception as e:
168
+ print(f"Warning: Image processing skipped due to error: {e}")
169
+
170
+ # 添加文本段
171
+ input_data['segments'].append({
172
+ 'type': 'text',
173
+ 'data': input_ids,
174
+ 'modality_id': 0
175
+ })
176
+
177
+ # 生成
178
+ try:
179
+ # 使用模型自带的 generate 方法
180
+ generated_ids = self.model.generate(
181
+ input_data,
182
+ max_new_tokens=max_new_tokens,
183
+ temperature=temperature,
184
+ top_k=top_k,
185
+ top_p=top_p,
186
+ repetition_penalty=repetition_penalty,
187
+ do_sample=True,
188
+ eos_token_id=self.tokenizer.eos_token_id,
189
+ pad_token_id=self.tokenizer.pad_token_id
190
+ )
191
+
192
+ # 解码
193
+ # 注意:生成的 ids 可能包含原始输入,或者只包含新生成的 token
194
+ # MultiModalDenseTransformer.generate 通常返回完整的序列
195
+ generated_text = self.tokenizer.decode(
196
+ generated_ids[0],
197
+ skip_special_tokens=True
198
+ )
199
+
200
+ # 如果包含 prompt,可以选择移除它只显示新内容
201
+ # if generated_text.startswith(prompt):
202
+ # generated_text = generated_text[len(prompt):]
203
+
204
+ return generated_text
205
+
206
+ except Exception as e:
207
+ print(f"Generation error: {e}")
208
+ import traceback
209
+ traceback.print_exc()
210
+ return f"Error: {str(e)}"
211
+
212
+
213
+ # 全局模型实例
214
+ model_instance = None
215
+ app = Flask(__name__)
216
+ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
217
+
218
+ @app.route('/')
219
+ def index():
220
+ display_config = model_instance.config.copy() if model_instance else {}
221
+ return render_template('index.html', config=display_config)
222
+
223
+ @app.route('/generate', methods=['POST'])
224
+ def generate():
225
+ try:
226
+ data = request.json
227
+ prompt = data.get('prompt', '')
228
+ if not prompt.strip():
229
+ return jsonify({'error': '请输入提示文本'}), 400
230
+
231
+ max_tokens = int(data.get('max_tokens', 100))
232
+ temperature = float(data.get('temperature', 0.7))
233
+ top_k = int(data.get('top_k', 40))
234
+ top_p = float(data.get('top_p', 0.9))
235
+ repetition_penalty = float(data.get('repetition_penalty', 1.1))
236
+
237
+ image = None
238
+ if 'image' in data and data['image']:
239
+ try:
240
+ image_data = base64.b64decode(data['image'].split(',')[1])
241
+ image = Image.open(io.BytesIO(image_data))
242
+ except Exception as e:
243
+ print(f"Image load error: {e}")
244
+
245
+ output = model_instance.generate_text(
246
+ prompt, max_tokens, temperature, top_k, top_p, repetition_penalty, image
247
+ )
248
+ return jsonify({'output': output})
249
+
250
+ except Exception as e:
251
+ return jsonify({'error': str(e)}), 500
252
+
253
+ def create_html_template():
254
+ """写入HTML模板"""
255
+ html_content = '''
256
+ <!DOCTYPE html>
257
+ <html lang="zh-CN">
258
+ <head>
259
+ <meta charset="UTF-8">
260
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
261
+ <title>Model Inference</title>
262
+ <style>
263
+ body { font-family: sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; background: #f0f2f5; }
264
+ .container { background: white; padding: 30px; border-radius: 12px; box-shadow: 0 4px 6px rgba(0,0,0,0.1); }
265
+ h1 { color: #1a73e8; text-align: center; }
266
+ textarea { width: 100%; padding: 10px; border: 1px solid #ddd; border-radius: 8px; margin: 10px 0; min-height: 100px; }
267
+ .controls { display: grid; grid-template-columns: 1fr 1fr; gap: 15px; margin: 20px 0; background: #f8f9fa; padding: 15px; border-radius: 8px; }
268
+ button { background: #1a73e8; color: white; border: none; padding: 12px 24px; border-radius: 6px; cursor: pointer; width: 100%; font-size: 16px; transition: background 0.3s; }
269
+ button:hover { background: #1557b0; }
270
+ button:disabled { background: #ccc; }
271
+ #output { margin-top: 20px; padding: 20px; background: #f8f9fa; border-radius: 8px; white-space: pre-wrap; min-height: 100px; border: 1px solid #e0e0e0; }
272
+ .loading { color: #666; font-style: italic; }
273
+ </style>
274
+ </head>
275
+ <body>
276
+ <div class="container">
277
+ <h1>🚀 模型在线推理</h1>
278
+
279
+ <div>
280
+ <label><strong>提示词 (Prompt):</strong></label>
281
+ <textarea id="prompt" placeholder="请输入你的问题..."></textarea>
282
+ </div>
283
+
284
+ <div class="controls">
285
+ <div>
286
+ <label>Max Tokens: <span id="maxTokensVal">128</span></label>
287
+ <input type="range" id="maxTokens" min="32" max="1024" value="128" style="width:100%" oninput="document.getElementById('maxTokensVal').innerText=this.value">
288
+ </div>
289
+ <div>
290
+ <label>Temperature: <span id="tempVal">0.7</span></label>
291
+ <input type="range" id="temperature" min="0.1" max="1.5" step="0.1" value="0.7" style="width:100%" oninput="document.getElementById('tempVal').innerText=this.value">
292
+ </div>
293
+ </div>
294
+
295
+ <button id="btn" onclick="generate()">生成 (Generate)</button>
296
+
297
+ <div id="output">结果将显示在这里...</div>
298
+ </div>
299
+
300
+ <script>
301
+ async function generate() {
302
+ const prompt = document.getElementById('prompt').value;
303
+ if(!prompt) return alert("请输入内容");
304
+
305
+ const btn = document.getElementById('btn');
306
+ const out = document.getElementById('output');
307
+
308
+ btn.disabled = true;
309
+ btn.innerText = "生成中...";
310
+ out.innerHTML = '<div class="loading">正在思考中...</div>';
311
+
312
+ try {
313
+ const res = await fetch('/generate', {
314
+ method: 'POST',
315
+ headers: {'Content-Type': 'application/json'},
316
+ body: JSON.stringify({
317
+ prompt: prompt,
318
+ max_tokens: parseInt(document.getElementById('maxTokens').value),
319
+ temperature: parseFloat(document.getElementById('temperature').value)
320
+ })
321
+ });
322
+ const data = await res.json();
323
+ if(data.error) out.innerText = "Error: " + data.error;
324
+ else out.innerText = data.output;
325
+ } catch(e) {
326
+ out.innerText = "请求失败: " + e;
327
+ } finally {
328
+ btn.disabled = false;
329
+ btn.innerText = "生成 (Generate)";
330
+ }
331
+ }
332
+ </script>
333
+ </body>
334
+ </html>
335
+ '''
336
+
337
+ Path('templates').mkdir(exist_ok=True)
338
+ with open('templates/index.html', 'w', encoding='utf-8') as f:
339
+ f.write(html_content)
340
+
341
+ def main():
342
+ import argparse
343
+ parser = argparse.ArgumentParser()
344
+ # 默认指向 pretrain 保存的 checkpoint 路径
345
+ parser.add_argument("--checkpoint", type=str, default="/root/multimodal/checkpoints/pretrain_fixed/step_10000.pt")
346
+ parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2.5-7B-Instruct")
347
+ parser.add_argument("--port", type=int, default=5001)
348
+ parser.add_argument("--host", type=str, default="0.0.0.0")
349
+ args = parser.parse_args()
350
+
351
+ if not Path(args.checkpoint).exists():
352
+ # 尝试找最近的 step checkpoint
353
+ steps = list(Path("checkpoints/pretrain").glob("step_*.pt"))
354
+ if steps:
355
+ print(f"未找到 final_model.pt,尝试使用最新的 checkpoint: {steps[-1]}")
356
+ args.checkpoint = str(steps[-1])
357
+ else:
358
+ print(f"错误: 找不到检查点文件: {args.checkpoint}")
359
+ return
360
+
361
+ create_html_template()
362
+
363
+ global model_instance
364
+ model_instance = ModelInference(args.checkpoint, args.tokenizer)
365
+
366
+ print(f"\n服务已启动: http://{args.host}:{args.port}")
367
+ app.run(host=args.host, port=args.port,
368
+ debug=True, # 开启调试模式
369
+ use_reloader=False)
370
+
371
+ if __name__ == "__main__":
372
+ main()
infer_sft.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Flask推理界面 - 多模态Dense Transformer (适配 Qwen Tokenizer 版)
3
+ """
4
+
5
+ import os
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from flask import Flask, render_template, request, jsonify
9
+ from transformers import AutoTokenizer
10
+ from PIL import Image
11
+ import json
12
+ import io
13
+ import base64
14
+ from pathlib import Path
15
+ from typing import Optional
16
+
17
+ # 确保引入路径正确,根据你之前的文件结构
18
+ from model import MultiModalDenseTransformer
19
+ # 注意:UnifiedMultiModalPreprocessor 之前是在 continual_learning.py 中定义的
20
+ # 如果你移动了它,请修改这里的导入路径
21
+ from continual_learning import UnifiedMultiModalPreprocessor
22
+ # 如果没有 image_transform,我们需要在这里定义或导入
23
+ from torchvision import transforms
24
+
25
+ # 设置国内镜像
26
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
27
+
28
+ # 定义图像预处理 (与 training 保持一致)
29
+ image_transform = transforms.Compose([
30
+ transforms.Resize((224, 224)),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
33
+ ])
34
+
35
+ class ModelInference:
36
+ """模型推理类"""
37
+
38
+ def __init__(
39
+ self,
40
+ checkpoint_path: str,
41
+ tokenizer_name: str,
42
+ config_path: Optional[str] = None,
43
+ device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
44
+ ):
45
+ self.device = torch.device(device)
46
+ print(f"Using device: {self.device}")
47
+
48
+ # 1. 加载 Tokenizer (与预训练一致)
49
+ print(f"Loading tokenizer: {tokenizer_name}...")
50
+ try:
51
+ self.tokenizer = AutoTokenizer.from_pretrained(
52
+ tokenizer_name,
53
+ use_fast=True,
54
+ trust_remote_code=True
55
+ )
56
+ if self.tokenizer.pad_token is None:
57
+ self.tokenizer.pad_token = self.tokenizer.eos_token
58
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
59
+ except Exception as e:
60
+ print(f"Error loading tokenizer: {e}")
61
+ raise e
62
+
63
+ # 2. 配置模型参数 (必须与 pretrain.py 中的配置完全一致)
64
+ if config_path and Path(config_path).exists():
65
+ with open(config_path, 'r') as f:
66
+ self.config = json.load(f)
67
+ else:
68
+ # [CRITICAL] 这里使用了你在 pretrain.py 中使用的参数
69
+ self.config = {
70
+ 'model_dim': 1536, # 预训练设置
71
+ 'vocab_size': len(self.tokenizer), # 自动适配 Qwen (约 151665)
72
+ 'n_layers': 12, # 预训练设置
73
+ 'n_heads': 12, # 预训练设置
74
+ 'n_kv_heads': 4, # 预训练设置
75
+ 'head_dim': None, # 自动计算
76
+ 'max_seq_len': 512, # 预训练设置
77
+ 'dropout': 0.0, # 推理时关闭 dropout
78
+ 'use_moe': False, # 预训练设置
79
+ 'use_adapter': False, # 预训练未开启 Adapter
80
+ 'use_lora': False, # 预训练未开启 LoRA
81
+ 'rope_scaling_type': "yarn", # 预训练设置
82
+ 'use_multimodal_fusion': False,
83
+ 'use_contrastive': False
84
+ }
85
+
86
+ # 3. 初始化模型结构
87
+ print("Initializing model architecture...")
88
+ try:
89
+ self.model = MultiModalDenseTransformer(**self.config)
90
+ self.preprocessor = UnifiedMultiModalPreprocessor(
91
+ model_dim=self.config['model_dim']
92
+ )
93
+
94
+ # 4. 加载权重
95
+ print(f"Loading checkpoint from {checkpoint_path}...")
96
+ # weights_only=False 是为了支持加载完整的 checkpoint 字典
97
+ checkpoint = torch.load(
98
+ checkpoint_path,
99
+ map_location=self.device,
100
+ weights_only=False
101
+ )
102
+
103
+ # 提取 state_dict
104
+ if 'model_state_dict' in checkpoint:
105
+ print("Found 'model_state_dict' in checkpoint.")
106
+ state_dict = checkpoint['model_state_dict']
107
+ else:
108
+ state_dict = checkpoint
109
+
110
+ # 处理可能的键名不匹配 (如 DDP 训练产生的 'module.' 前缀)
111
+ new_state_dict = {}
112
+ for k, v in state_dict.items():
113
+ if k.startswith('module.'):
114
+ new_state_dict[k[7:]] = v
115
+ else:
116
+ new_state_dict[k] = v
117
+
118
+ # 加载权重 (strict=False 允许忽略一些非关键的不匹配,如 loss 缓存等)
119
+ missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False)
120
+ if missing:
121
+ print(f"Warning: Missing keys: {len(missing)}")
122
+ if unexpected:
123
+ print(f"Warning: Unexpected keys: {len(unexpected)}")
124
+
125
+ self.model.to(self.device)
126
+ self.preprocessor.to(self.device)
127
+ self.model.eval()
128
+
129
+ print("Model loaded successfully!")
130
+ print(f"Total parameters: {sum(p.numel() for p in self.model.parameters()) / 1e6:.2f}M")
131
+
132
+ except Exception as e:
133
+ print(f"Error initializing model: {e}")
134
+ raise e
135
+
136
+ @torch.no_grad()
137
+ def generate_text(
138
+ self,
139
+ prompt: str,
140
+ max_new_tokens: int = 128,
141
+ temperature: float = 0.7,
142
+ top_k: int = 10,
143
+ top_p: float = 0.9,
144
+ repetition_penalty: float = 1.2,
145
+ image: Optional[Image.Image] = None
146
+ ) -> str:
147
+ """生成文本"""
148
+ formatted_prompt = f"Instruction: {prompt}\nResponse:"
149
+ # 编码输入
150
+ inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
151
+
152
+ # 编码输入
153
+ #inputs = self.tokenizer(prompt, return_tensors="pt")
154
+ input_ids = inputs['input_ids'].to(self.device)
155
+
156
+ # 构建 MultiModalDenseTransformer 需要的输入格式
157
+ input_data = {'segments': []}
158
+
159
+ # 处理图像
160
+ if image is not None:
161
+ if image.mode != 'RGB':
162
+ image = image.convert('RGB')
163
+ # 简单的图像处理
164
+ image_tensor = image_transform(image).unsqueeze(0).to(self.device)
165
+ # 这里假设预处理器能处理这种输入
166
+ try:
167
+ # process_batch 接受 (batch_data, modality_type) 并返回 segments 列表
168
+ mod_segments = self.preprocessor.process_batch(image_tensor, 'image')
169
+ # 将返回的 segment 列表合并到 input_data
170
+ for seg in mod_segments:
171
+ input_data['segments'].append(seg)
172
+ except Exception as e:
173
+ print(f"Warning: Image processing skipped due to error: {e}")
174
+
175
+ # 添加文本段
176
+ input_data['segments'].append({
177
+ 'type': 'text',
178
+ 'data': input_ids,
179
+ 'modality_id': 0
180
+ })
181
+
182
+ # 生成
183
+ try:
184
+ # 使用模型自带的 generate 方法
185
+ generated_ids = self.model.generate(
186
+ input_data,
187
+ max_new_tokens=max_new_tokens,
188
+ temperature=temperature,
189
+ top_k=top_k,
190
+ top_p=top_p,
191
+ repetition_penalty=repetition_penalty,
192
+ do_sample=True,
193
+ eos_token_id=self.tokenizer.eos_token_id,
194
+ pad_token_id=self.tokenizer.pad_token_id
195
+ )
196
+
197
+ # 3. 解码
198
+ full_output = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
199
+ print(f"\n====== [DEBUG 原始输出] ======\n{full_output}\n==============================\n")
200
+ # 4. [关键修改] 截断逻辑 (Stop Logic)
201
+ # 提取 Response 之后的部分
202
+ if "Response:" in full_output:
203
+ answer = full_output.split("Response:")[-1].strip()
204
+ else:
205
+ answer = full_output
206
+
207
+ # 定义停止词列表 (根据你的图,模型喜欢生成 Instructions: 或 Ingredients:)
208
+ stop_words = [
209
+ "Instruction", "Input", "###", "Response",
210
+ "User:", "Assistant:", "\n\n" # 双换行通常意味着一段结束
211
+ ]
212
+
213
+ for stop_word in stop_words:
214
+ if stop_word in answer:
215
+ answer = answer.split(stop_word)[0].strip()
216
+
217
+ # 3. [新增] 强制去除首行重复 (解决 Echo 问题)
218
+ # 如果模型第一句就是重复 Prompt,去掉它
219
+ lines = answer.split('\n')
220
+ if len(lines) > 0 and prompt.lower() in lines[0].lower():
221
+ answer = "\n".join(lines[1:]).strip()
222
+ return answer
223
+
224
+ except Exception as e:
225
+ print(f"Generation error: {e}")
226
+ import traceback
227
+ traceback.print_exc()
228
+ return f"Error: {str(e)}"
229
+
230
+
231
+ # 全局模型实例
232
+ model_instance = None
233
+ app = Flask(__name__)
234
+ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
235
+
236
+ @app.route('/')
237
+ def index():
238
+ display_config = model_instance.config.copy() if model_instance else {}
239
+ return render_template('index.html', config=display_config)
240
+
241
+ @app.route('/generate', methods=['POST'])
242
+ def generate():
243
+ try:
244
+ data = request.json
245
+ prompt = data.get('prompt', '')
246
+ if not prompt.strip():
247
+ return jsonify({'error': '请输入提示文本'}), 400
248
+
249
+ max_tokens = int(data.get('max_tokens', 100))
250
+ temperature = float(data.get('temperature', 0.7))
251
+ top_k = int(data.get('top_k', 40))
252
+ top_p = float(data.get('top_p', 0.9))
253
+ repetition_penalty = float(data.get('repetition_penalty', 1.1))
254
+
255
+ image = None
256
+ if 'image' in data and data['image']:
257
+ try:
258
+ image_data = base64.b64decode(data['image'].split(',')[1])
259
+ image = Image.open(io.BytesIO(image_data))
260
+ except Exception as e:
261
+ print(f"Image load error: {e}")
262
+
263
+ output = model_instance.generate_text(
264
+ prompt, max_tokens, temperature, top_k, top_p, repetition_penalty, image
265
+ )
266
+ return jsonify({'output': output})
267
+
268
+ except Exception as e:
269
+ return jsonify({'error': str(e)}), 500
270
+
271
+ def create_html_template():
272
+ """写入HTML模板"""
273
+ html_content = '''
274
+ <!DOCTYPE html>
275
+ <html lang="zh-CN">
276
+ <head>
277
+ <meta charset="UTF-8">
278
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
279
+ <title>Model Inference</title>
280
+ <style>
281
+ body { font-family: sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; background: #f0f2f5; }
282
+ .container { background: white; padding: 30px; border-radius: 12px; box-shadow: 0 4px 6px rgba(0,0,0,0.1); }
283
+ h1 { color: #1a73e8; text-align: center; }
284
+ textarea { width: 100%; padding: 10px; border: 1px solid #ddd; border-radius: 8px; margin: 10px 0; min-height: 100px; }
285
+ .controls { display: grid; grid-template-columns: 1fr 1fr; gap: 15px; margin: 20px 0; background: #f8f9fa; padding: 15px; border-radius: 8px; }
286
+ button { background: #1a73e8; color: white; border: none; padding: 12px 24px; border-radius: 6px; cursor: pointer; width: 100%; font-size: 16px; transition: background 0.3s; }
287
+ button:hover { background: #1557b0; }
288
+ button:disabled { background: #ccc; }
289
+ #output { margin-top: 20px; padding: 20px; background: #f8f9fa; border-radius: 8px; white-space: pre-wrap; min-height: 100px; border: 1px solid #e0e0e0; }
290
+ .loading { color: #666; font-style: italic; }
291
+ </style>
292
+ </head>
293
+ <body>
294
+ <div class="container">
295
+ <h1>🚀 模型在线推理</h1>
296
+
297
+ <div>
298
+ <label><strong>提示词 (Prompt):</strong></label>
299
+ <textarea id="prompt" placeholder="请输入你的问题..."></textarea>
300
+ </div>
301
+
302
+ <div class="controls">
303
+ <div>
304
+ <label>Max Tokens: <span id="maxTokensVal">128</span></label>
305
+ <input type="range" id="maxTokens" min="32" max="1024" value="128" style="width:100%" oninput="document.getElementById('maxTokensVal').innerText=this.value">
306
+ </div>
307
+ <div>
308
+ <label>Temperature: <span id="tempVal">0.7</span></label>
309
+ <input type="range" id="temperature" min="0.1" max="1.5" step="0.1" value="0.7" style="width:100%" oninput="document.getElementById('tempVal').innerText=this.value">
310
+ </div>
311
+ </div>
312
+
313
+ <button id="btn" onclick="generate()">生成 (Generate)</button>
314
+
315
+ <div id="output">结果将显示在这里...</div>
316
+ </div>
317
+
318
+ <script>
319
+ async function generate() {
320
+ const prompt = document.getElementById('prompt').value;
321
+ if(!prompt) return alert("请输入内容");
322
+
323
+ const btn = document.getElementById('btn');
324
+ const out = document.getElementById('output');
325
+
326
+ btn.disabled = true;
327
+ btn.innerText = "生成中...";
328
+ out.innerHTML = '<div class="loading">正在思考中...</div>';
329
+
330
+ try {
331
+ const res = await fetch('/generate', {
332
+ method: 'POST',
333
+ headers: {'Content-Type': 'application/json'},
334
+ body: JSON.stringify({
335
+ prompt: prompt,
336
+ max_tokens: parseInt(document.getElementById('maxTokens').value),
337
+ temperature: parseFloat(document.getElementById('temperature').value)
338
+ })
339
+ });
340
+ const data = await res.json();
341
+ if(data.error) out.innerText = "Error: " + data.error;
342
+ else out.innerText = data.output;
343
+ } catch(e) {
344
+ out.innerText = "请求失败: " + e;
345
+ } finally {
346
+ btn.disabled = false;
347
+ btn.innerText = "生成 (Generate)";
348
+ }
349
+ }
350
+ </script>
351
+ </body>
352
+ </html>
353
+ '''
354
+
355
+ Path('templates').mkdir(exist_ok=True)
356
+ with open('templates/index.html', 'w', encoding='utf-8') as f:
357
+ f.write(html_content)
358
+
359
+ def main():
360
+ import argparse
361
+ parser = argparse.ArgumentParser()
362
+ # 默认指向 pretrain 保存的 checkpoint 路径
363
+ parser.add_argument("--checkpoint", type=str, default="/root/multimodal/checkpoints/posttrain/final_model.pt")
364
+ parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2.5-7B-Instruct")
365
+ parser.add_argument("--port", type=int, default=5001)
366
+ parser.add_argument("--host", type=str, default="0.0.0.0")
367
+ args = parser.parse_args()
368
+
369
+ if not Path(args.checkpoint).exists():
370
+ # 尝试找最近的 step checkpoint
371
+ steps = list(Path("checkpoints/pretrain").glob("step_*.pt"))
372
+ if steps:
373
+ print(f"未找到 final_model.pt,尝试使用最新的 checkpoint: {steps[-1]}")
374
+ args.checkpoint = str(steps[-1])
375
+ else:
376
+ print(f"错误: 找不到检查点文件: {args.checkpoint}")
377
+ return
378
+ # ----------------- 新增部分开始 -----------------
379
+ try:
380
+ from pyngrok import ngrok, conf
381
+
382
+ # 如果你在国内,ngrok 连接慢,可以配置 region='ap' (亚太) 或 'au' (澳洲)
383
+ # conf.get_default().region = "ap"
384
+
385
+ # 建立隧道,映射 5001 端口
386
+ public_url = ngrok.connect(args.port).public_url
387
+ print(f"\n========================================")
388
+ print(f"🎉 公网访问地址 (发给朋友): {public_url}")
389
+ print(f"========================================\n")
390
+ except ImportError:
391
+ print("未安装 pyngrok,无法自动生成公网链接。")
392
+ print("提示: pip install pyngrok")
393
+ except Exception as e:
394
+ print(f"Ngrok 启动失败: {e}")
395
+ # ----------------- 新增部分结束 -----------------
396
+ create_html_template()
397
+
398
+ global model_instance
399
+ model_instance = ModelInference(args.checkpoint, args.tokenizer)
400
+
401
+ print(f"\n服务已启动: http://{args.host}:{args.port}")
402
+ app.run(host=args.host, port=args.port,
403
+ debug=True, # 开启调试模式
404
+ use_reloader=False)
405
+
406
+ if __name__ == "__main__":
407
+ main()
model.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 改进的多模态Dense Transformer主模型
3
+ 整合所有SOTA改进
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from typing import List, Dict, Optional, Tuple
9
+ import math
10
+ from components import RMSNorm
11
+ from transformer import OptimizedTransformerBlock
12
+ from multimodel_fusion import MultiModalFusionModule
13
+ from encoders import (
14
+ ImprovedVisionTransformer,
15
+ ImprovedAudioEncoder,
16
+ ImprovedVideoEncoder
17
+ )
18
+
19
+ class MultiModalDenseTransformer(nn.Module):
20
+ """
21
+ 改进的统一多模态Dense Transformer
22
+ 主要改进:
23
+ 1. 深度跨模态融合
24
+ 2. 模态特定的优化编码器
25
+ 3. 对比学习对齐
26
+ 4. 改进的位置编码和注意力机制
27
+ 5. 更好的训练稳定性
28
+ """
29
+ def __init__(
30
+ self,
31
+ model_dim: int = 2048,
32
+ vocab_size: int = 30000,
33
+ n_layers: int = 48,
34
+ n_heads: int = 32,
35
+ n_kv_heads: Optional[int] = None,
36
+ head_dim: Optional[int] = None,
37
+ max_seq_len: int = 8192,
38
+ dropout: float = 0.0,
39
+ attn_dropout: float = 0.0,
40
+
41
+ # MoE配置
42
+ use_moe: bool = False,
43
+ num_experts: int = 8,
44
+ moe_top_k: int = 2,
45
+ moe_layers: Optional[List[int]] = None,
46
+
47
+ # PEFT配置
48
+ use_adapter: bool = False,
49
+ adapter_dim: int = 64,
50
+ use_lora: bool = False,
51
+ lora_rank: int = 8,
52
+
53
+ # 训练配置
54
+ use_gradient_checkpointing: bool = False,
55
+ use_parallel_residual: bool = False,
56
+
57
+ # 位置编码
58
+ rope_scaling_factor: float = 1.0,
59
+ rope_scaling_type: str = "yarn",
60
+ sliding_window: Optional[int] = None,
61
+
62
+ # 规范化
63
+ norm_eps: float = 1e-6,
64
+ initializer_range: float = 0.02,
65
+ ffn_dim_multiplier: Optional[float] = None,
66
+ tie_word_embeddings: bool = True,
67
+
68
+ # 多模态配置
69
+ use_multimodal_fusion: bool = True,
70
+ fusion_layers: int = 4,
71
+ use_contrastive: bool = True,
72
+ vision_depth: int = 24,
73
+ audio_depth: int = 12,
74
+ video_spatial_depth: int = 12,
75
+ video_temporal_depth: int = 4
76
+ ):
77
+ super().__init__()
78
+
79
+ self.model_dim = model_dim
80
+ self.vocab_size = vocab_size
81
+ self.n_layers = n_layers
82
+ self.max_seq_len = max_seq_len
83
+ self.use_gradient_checkpointing = use_gradient_checkpointing
84
+ self.tie_word_embeddings = tie_word_embeddings
85
+ self.use_multimodal_fusion = use_multimodal_fusion
86
+
87
+ # Token embedding
88
+ self.token_embedding = nn.Embedding(vocab_size, model_dim)
89
+ self.modality_embedding = nn.Embedding(4, model_dim)
90
+ self.embed_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
91
+
92
+ # 改进的模态编码器
93
+ self.vision_encoder = ImprovedVisionTransformer(
94
+ embed_dim=model_dim,
95
+ depth=vision_depth,
96
+ n_heads=n_heads,
97
+ dropout=dropout,
98
+ use_adapter=use_adapter,
99
+ adapter_dim=adapter_dim
100
+ )
101
+
102
+ self.audio_encoder = ImprovedAudioEncoder(
103
+ embed_dim=model_dim,
104
+ depth=audio_depth,
105
+ n_heads=n_heads,
106
+ dropout=dropout,
107
+ use_adapter=use_adapter,
108
+ adapter_dim=adapter_dim
109
+ )
110
+
111
+ self.video_encoder = ImprovedVideoEncoder(
112
+ embed_dim=model_dim,
113
+ spatial_depth=video_spatial_depth,
114
+ temporal_depth=video_temporal_depth,
115
+ n_heads=n_heads,
116
+ dropout=dropout,
117
+ use_adapter=use_adapter,
118
+ adapter_dim=adapter_dim
119
+ )
120
+
121
+ # 多模态融合模块
122
+ if use_multimodal_fusion:
123
+ self.fusion_module = MultiModalFusionModule(
124
+ dim=model_dim,
125
+ num_fusion_layers=fusion_layers,
126
+ n_heads=n_heads,
127
+ dropout=dropout,
128
+ use_contrastive=use_contrastive
129
+ )
130
+
131
+ # Transformer layers
132
+ if moe_layers is None and use_moe:
133
+ moe_layers = list(range(n_layers // 2, n_layers))
134
+ elif moe_layers is None:
135
+ moe_layers = []
136
+
137
+ self.layers = nn.ModuleList([
138
+ OptimizedTransformerBlock(
139
+ dim=model_dim,
140
+ n_heads=n_heads,
141
+ n_kv_heads=n_kv_heads,
142
+ head_dim=head_dim,
143
+ dropout=dropout,
144
+ attn_dropout=attn_dropout,
145
+ use_moe=(use_moe and i in moe_layers),
146
+ num_experts=num_experts,
147
+ moe_top_k=moe_top_k,
148
+ use_adapter=use_adapter,
149
+ adapter_dim=adapter_dim,
150
+ use_lora=use_lora,
151
+ lora_rank=lora_rank,
152
+ use_parallel_residual=use_parallel_residual,
153
+ norm_eps=norm_eps,
154
+ sliding_window=sliding_window,
155
+ ffn_dim_multiplier=ffn_dim_multiplier,
156
+ layer_idx=i
157
+ )
158
+ for i in range(n_layers)
159
+ ])
160
+
161
+ self.norm = RMSNorm(model_dim, eps=norm_eps)
162
+ self.lm_head = nn.Linear(model_dim, vocab_size, bias=False)
163
+
164
+ if tie_word_embeddings:
165
+ self.lm_head.weight = self.token_embedding.weight
166
+
167
+ self.initializer_range = initializer_range
168
+ self.apply(self._init_weights)
169
+
170
+ if not tie_word_embeddings:
171
+ self._init_lm_head()
172
+
173
+ self.n_params = sum(p.numel() for p in self.parameters())
174
+ trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
175
+
176
+ print(f"\n{'='*80}")
177
+ print(f"Improved Model Configuration:")
178
+ print(f" Model Dimension: {model_dim}")
179
+ print(f" Vocab Size: {vocab_size}")
180
+ print(f" Layers: {n_layers}")
181
+ print(f" Attention Heads: {n_heads}")
182
+ print(f" KV Heads: {n_kv_heads if n_kv_heads else n_heads}")
183
+ print(f" Max Sequence Length: {max_seq_len}")
184
+ print(f" Multimodal Fusion: {use_multimodal_fusion}")
185
+ print(f" Contrastive Learning: {use_contrastive}")
186
+ print(f" MoE: {use_moe} (Experts: {num_experts}, Top-K: {moe_top_k})")
187
+ print(f" Total Parameters: {self.n_params / 1e9:.2f}B")
188
+ print(f" Trainable Parameters: {trainable_params / 1e9:.2f}B")
189
+ print(f"{'='*80}\n")
190
+
191
+ def _init_weights(self, module):
192
+ """权重初始化"""
193
+ if isinstance(module, nn.Linear):
194
+ torch.nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
195
+ if module.bias is not None:
196
+ torch.nn.init.zeros_(module.bias)
197
+ elif isinstance(module, nn.Embedding):
198
+ torch.nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
199
+ if hasattr(module, 'padding_idx') and module.padding_idx is not None:
200
+ module.weight.data[module.padding_idx].zero_()
201
+
202
+ def _init_lm_head(self):
203
+ """初始化LM head"""
204
+ std = self.initializer_range / math.sqrt(2 * self.n_layers)
205
+ torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=std)
206
+
207
+ def _encode_modality(self, segment: Dict) -> torch.Tensor:
208
+ """编码单个模态"""
209
+ seg_type = segment['type']
210
+ seg_data = segment['data']
211
+
212
+ if seg_type == 'image':
213
+ return self.vision_encoder(seg_data)
214
+ elif seg_type == 'audio':
215
+ return self.audio_encoder(seg_data)
216
+ elif seg_type == 'video':
217
+ return self.video_encoder(seg_data)
218
+ elif seg_type == 'text':
219
+ return self.token_embedding(seg_data)
220
+ else:
221
+ return seg_data
222
+
223
+ def forward(
224
+ self,
225
+ input_data: Dict,
226
+ attention_mask: Optional[torch.Tensor] = None,
227
+ position_ids: Optional[torch.Tensor] = None,
228
+ return_hidden: bool = False,
229
+ use_cache: bool = False,
230
+ past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
231
+ output_attentions: bool = False,
232
+ output_hidden_states: bool = False,
233
+ compute_contrastive: bool = False
234
+ ) -> Dict:
235
+ """前向传播"""
236
+ device = self.token_embedding.weight.device
237
+
238
+ # 编码每个模态
239
+ encoded_segments = []
240
+ for segment in input_data.get('segments', []):
241
+ encoded = self._encode_modality(segment)
242
+
243
+ # 添加模态嵌入
244
+ modality_id = segment.get('modality_id', 0)
245
+ modality_embeds = self.modality_embedding(
246
+ torch.tensor([modality_id], device=device)
247
+ ).expand(encoded.shape[0], encoded.shape[1], -1)
248
+
249
+ encoded_segments.append({
250
+ 'type': segment['type'],
251
+ 'data': encoded + modality_embeds,
252
+ 'modality_id': modality_id
253
+ })
254
+
255
+ # 多模态融合
256
+ contrastive_losses = {}
257
+ if self.use_multimodal_fusion and len(encoded_segments) > 1:
258
+ fusion_output = self.fusion_module(
259
+ encoded_segments,
260
+ compute_contrastive=compute_contrastive
261
+ )
262
+ x = fusion_output['fused_features']
263
+ contrastive_losses = fusion_output.get('contrastive_losses', {})
264
+ else:
265
+ # 简单拼接
266
+ all_embeddings = [seg['data'] for seg in encoded_segments]
267
+ x = torch.cat(all_embeddings, dim=1) if all_embeddings else torch.zeros(
268
+ 1, 1, self.model_dim, device=device
269
+ )
270
+
271
+ x = self.embed_dropout(x)
272
+ # 如果��有传入 position_ids,我们需要根据历史长度生成它
273
+ if position_ids is None:
274
+ if past_key_values is not None:
275
+ # 缓存的长度 (KV cache 的 shape 是 [B, H, SeqLen, D])
276
+ past_length = past_key_values[0][0].size(2)
277
+ # 当前输入的长度
278
+ seq_length = x.shape[1]
279
+ # 生成正确的位置索引: [past_length, past_length + 1, ...]
280
+ position_ids = torch.arange(
281
+ past_length, past_length + seq_length, dtype=torch.long, device=device
282
+ ).unsqueeze(0).expand(x.shape[0], -1)
283
+ else:
284
+ # 如果没有缓存,从 0 开始
285
+ seq_length = x.shape[1]
286
+ position_ids = torch.arange(
287
+ 0, seq_length, dtype=torch.long, device=device
288
+ ).unsqueeze(0).expand(x.shape[0], -1)
289
+ # Transformer层
290
+ present_key_values = [] if use_cache else None
291
+ all_hidden_states = [] if output_hidden_states else None
292
+ all_attentions = [] if output_attentions else None
293
+ moe_aux_loss = torch.tensor(0.0, device=device)
294
+
295
+ for idx, layer in enumerate(self.layers):
296
+ if output_hidden_states:
297
+ all_hidden_states.append(x)
298
+
299
+ past_kv = past_key_values[idx] if past_key_values is not None else None
300
+
301
+ if self.use_gradient_checkpointing and self.training:
302
+ def create_custom_forward(module):
303
+ def custom_forward(*inputs):
304
+ return module(
305
+ inputs[0],
306
+ attention_mask=inputs[1],
307
+ position_ids=inputs[2],
308
+ use_cache=False,
309
+ past_kv=None,
310
+ output_attentions=False
311
+ )
312
+ return custom_forward
313
+
314
+ import torch.utils.checkpoint as checkpoint
315
+ layer_outputs = checkpoint.checkpoint(
316
+ create_custom_forward(layer),
317
+ x,
318
+ attention_mask,
319
+ position_ids,
320
+ use_reentrant=False
321
+ )
322
+ x = layer_outputs[0]
323
+ present_kv = None
324
+ attn_weights = None
325
+ else:
326
+ layer_outputs = layer(
327
+ x,
328
+ attention_mask=attention_mask,
329
+ position_ids=position_ids,
330
+ use_cache=use_cache,
331
+ past_kv=past_kv,
332
+ output_attentions=output_attentions
333
+ )
334
+ x, present_kv, attn_weights = layer_outputs
335
+
336
+ if use_cache:
337
+ present_key_values.append(present_kv)
338
+
339
+ if output_attentions:
340
+ all_attentions.append(attn_weights)
341
+
342
+ if hasattr(layer, 'moe_aux_loss'):
343
+ moe_aux_loss += layer.moe_aux_loss
344
+
345
+ hidden_states = self.norm(x)
346
+ logits = self.lm_head(hidden_states)
347
+
348
+ if output_hidden_states:
349
+ all_hidden_states.append(hidden_states)
350
+
351
+ # 组装输出
352
+ outputs = {
353
+ 'logits': logits,
354
+ 'moe_aux_loss': moe_aux_loss,
355
+ 'contrastive_losses': contrastive_losses
356
+ }
357
+
358
+ if use_cache:
359
+ outputs['past_key_values'] = present_key_values
360
+
361
+ if output_hidden_states:
362
+ outputs['hidden_states'] = all_hidden_states
363
+
364
+ if output_attentions:
365
+ outputs['attentions'] = all_attentions
366
+
367
+ if return_hidden:
368
+ outputs['last_hidden_state'] = hidden_states
369
+
370
+ return outputs
371
+
372
+ @torch.no_grad()
373
+ def generate(
374
+ self,
375
+ input_data: Dict,
376
+ max_new_tokens: int = 100,
377
+ temperature: float = 1.0,
378
+ top_k: int = 50,
379
+ top_p: float = 0.9,
380
+ eos_token_id: int = 2,
381
+ pad_token_id: Optional[int] = None,
382
+ use_cache: bool = True,
383
+ repetition_penalty: float = 1.0,
384
+ length_penalty: float = 1.0,
385
+ min_length: int = 0,
386
+ do_sample: bool = True,
387
+ num_beams: int = 1
388
+ ) -> torch.Tensor:
389
+ """改进的生成方法"""
390
+ self.eval()
391
+ device = next(self.parameters()).device
392
+
393
+ if pad_token_id is None:
394
+ pad_token_id = eos_token_id
395
+
396
+ initial_text_tokens = input_data['segments'][0]['data'].to(device)
397
+ batch_size = initial_text_tokens.shape[0]
398
+
399
+ if 'attention_mask' in input_data:
400
+ attention_mask = input_data['attention_mask'].to(device)
401
+ else:
402
+ attention_mask = torch.ones_like(initial_text_tokens)
403
+ initial_seq_len = initial_text_tokens.shape[1]
404
+ position_ids = torch.zeros((batch_size,initial_seq_len),dtype=torch.long,device=device)
405
+
406
+ for i in range(batch_size):
407
+ non_pad_mask = attention_mask[i].bool()
408
+ if non_pad_mask.any():
409
+ positions = torch.cumsum(non_pad_mask.long(),dim=0) -1
410
+ position_ids[i]=positions * non_pad_mask.long()
411
+
412
+
413
+
414
+ generated_tokens = []
415
+ past_key_values = None
416
+ current_tokens = initial_text_tokens
417
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
418
+
419
+ for step in range(max_new_tokens):
420
+ current_input_data = {
421
+ 'segments': [{'type': 'text', 'data': current_tokens, 'modality_id': 0}]
422
+ }
423
+
424
+ if step > 0 and use_cache:
425
+ # 添加当前 token 的 mask (1)
426
+ new_mask = torch.ones(batch_size,1,dtype=torch.long,device=device)
427
+ attention_mask = torch.cat([attention_mask, new_mask], dim=1)
428
+ current_positions = (attention_mask.sum(dim=1 , keepdim=True) -1).clamp(min=0)
429
+ current_positions_ids=current_positions
430
+ else:
431
+ current_positions_ids=position_ids
432
+ outputs = self.forward(
433
+ current_input_data,
434
+ attention_mask=attention_mask, # <--- 传入 Mask
435
+ position_ids=current_positions_ids,
436
+ use_cache=use_cache,
437
+ past_key_values=past_key_values
438
+ )
439
+
440
+ logits = outputs['logits']
441
+ if use_cache:
442
+ past_key_values = outputs['past_key_values']
443
+
444
+ next_token_logits = logits[:, -1, :] / max(temperature, 1e-5)
445
+
446
+ # Repetition penalty
447
+ if repetition_penalty != 1.0 and len(generated_tokens) > 0:
448
+ prev_generated = torch.cat(generated_tokens, dim=1)
449
+ score = torch.gather(next_token_logits, 1, prev_generated)
450
+ score = torch.where(
451
+ score < 0,
452
+ score * repetition_penalty,
453
+ score / repetition_penalty
454
+ )
455
+ next_token_logits.scatter_(1, prev_generated, score)
456
+
457
+ # Min length constraint
458
+ if step < min_length:
459
+ next_token_logits[:, eos_token_id] = float('-inf')
460
+
461
+ # Sampling
462
+ if do_sample:
463
+ if top_k > 0:
464
+ top_k_vals, _ = torch.topk(next_token_logits, top_k)
465
+ min_val_to_keep = top_k_vals[:, -1].unsqueeze(-1)
466
+ next_token_logits[next_token_logits < min_val_to_keep] = float('-inf')
467
+
468
+ if top_p < 1.0:
469
+ sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
470
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
471
+ sorted_indices_to_remove = cumulative_probs > top_p
472
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
473
+ sorted_indices_to_remove[..., 0] = 0
474
+ indices_to_remove = torch.zeros_like(next_token_logits, dtype=torch.bool)
475
+ indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
476
+ next_token_logits[indices_to_remove] = float('-inf')
477
+
478
+ probs = F.softmax(next_token_logits, dim=-1)
479
+ next_token = torch.multinomial(probs, num_samples=1)
480
+ else:
481
+ next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
482
+
483
+ # Apply unfinished mask
484
+ next_token = next_token * unfinished_sequences[:, None] + pad_token_id * (1 - unfinished_sequences[:, None])
485
+
486
+ generated_tokens.append(next_token)
487
+
488
+ if not use_cache:
489
+ initial_text_tokens = torch.cat([initial_text_tokens, next_token], dim=1)
490
+ current_tokens = initial_text_tokens
491
+ else:
492
+ current_tokens = next_token
493
+
494
+ # Update unfinished sequences
495
+ unfinished_sequences = unfinished_sequences.mul(
496
+ (next_token.squeeze(-1) != eos_token_id).long()
497
+ )
498
+
499
+ if unfinished_sequences.max() == 0:
500
+ break
501
+
502
+ if not generated_tokens:
503
+ return torch.empty(batch_size, 0, dtype=torch.long, device=device)
504
+
505
+ return torch.cat(generated_tokens, dim=1)
moe.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 优化的混合专家系统 (Mixture of Experts)
3
+ 基于Mixtral、Switch Transformer、GLaM的最佳实践
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from typing import Tuple, Optional, List
9
+ import math
10
+
11
class Expert(nn.Module):
    """A single feed-forward expert using the SwiGLU formulation.

    Computes W2(silu(W1 x) * W3 x), optionally followed by dropout.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        dropout: float = 0.0,
        bias: bool = False
    ):
        super().__init__()
        # Gate (w1), down (w2) and up (w3) projections of the SwiGLU MLP.
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)
        self.w3 = nn.Linear(dim, hidden_dim, bias=bias)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self._init_weights()

    def _init_weights(self):
        """Normal(0, 0.02) init for weights; zero biases when present."""
        for proj in (self.w1, self.w2, self.w3):
            nn.init.normal_(proj.weight, mean=0.0, std=0.02)
            if proj.bias is not None:
                nn.init.zeros_(proj.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """SwiGLU: W2(Swish(W1 x) ⊙ W3 x), then dropout."""
        gated = F.silu(self.w1(x)) * self.w3(x)
        return self.dropout(self.w2(gated))
45
+
46
class TopKRouter(nn.Module):
    """Noisy top-k gating network with load-balance and z-loss regularizers.

    At train time Gaussian noise is injected into the routing logits and two
    auxiliary losses are computed: a Switch-Transformer-style load-balance
    term and a z-loss (mean squared logits) to keep logits bounded. At eval
    time routing is deterministic and the auxiliary loss is zero.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int,
        top_k: int = 2,
        capacity_factor: float = 1.25,
        noise_std: float = 1.0,
        use_expert_capacity: bool = True,
        router_z_loss_coef: float = 0.001,
        router_aux_loss_coef: float = 0.01
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.noise_std = noise_std
        self.use_expert_capacity = use_expert_capacity
        self.router_z_loss_coef = router_z_loss_coef
        self.router_aux_loss_coef = router_aux_loss_coef

        self.gate = nn.Linear(dim, num_experts, bias=False)
        nn.init.normal_(self.gate.weight, mean=0.0, std=0.02)

    def _compute_routing_weights(
        self,
        logits: torch.Tensor,
        use_noise: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pick the top-k experts per token.

        Gate values are a softmax over the k selected logits only.

        Args:
            logits: routing logits [num_tokens, num_experts]
            use_noise: inject exploration noise (training only)

        Returns:
            (gates [num_tokens, top_k], expert indices [num_tokens, top_k])
        """
        if use_noise and self.training:
            logits = logits + torch.randn_like(logits) * self.noise_std
        top_vals, top_idx = torch.topk(logits, self.top_k, dim=-1)
        return F.softmax(top_vals, dim=-1), top_idx

    def _compute_auxiliary_loss(
        self,
        logits: torch.Tensor,
        top_k_indices: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute (load_balance_loss, z_loss) for the given routing decision."""
        num_tokens = logits.shape[0]
        router_probs = F.softmax(logits, dim=-1)
        mean_prob = router_probs.mean(dim=0)

        # Fraction of routing slots assigned to each expert.
        selection = F.one_hot(top_k_indices, self.num_experts).float()
        expert_freq = selection.sum(dim=[0, 1]) / (num_tokens * self.top_k)

        balance = self.num_experts * torch.sum(mean_prob * expert_freq)
        z_loss = torch.mean(logits ** 2)
        return balance, z_loss

    def forward(
        self,
        x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route tokens.

        Args:
            x: input features [num_tokens, dim]

        Returns:
            gates [num_tokens, top_k], expert indices [num_tokens, top_k],
            scalar auxiliary loss (zero at eval time).
        """
        logits = self.gate(x)
        gates, indices = self._compute_routing_weights(logits, use_noise=self.training)

        if self.training:
            balance, z_loss = self._compute_auxiliary_loss(logits, indices)
            aux = (
                self.router_aux_loss_coef * balance
                + self.router_z_loss_coef * z_loss
            )
        else:
            aux = torch.tensor(0.0, device=x.device)

        return gates, indices, aux
177
+
178
class MixtureOfExperts(nn.Module):
    """Sparse MoE layer: each token is processed by its top-k routed experts.

    Tokens routed beyond an expert's capacity are randomly dropped
    (Switch-Transformer style). Expert outputs are gate-weighted and summed
    back into token order.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int = 8,
        expert_hidden_dim: Optional[int] = None,
        top_k: int = 2,
        dropout: float = 0.0,
        capacity_factor: float = 1.25,
        use_expert_capacity: bool = True,
        router_z_loss_coef: float = 0.001,
        router_aux_loss_coef: float = 0.01,
        noise_std: float = 1.0,
        ffn_dim_multiplier: Optional[float] = None
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.use_expert_capacity = use_expert_capacity

        if expert_hidden_dim is None:
            if ffn_dim_multiplier is not None:
                expert_hidden_dim = int(dim * ffn_dim_multiplier)
            else:
                # LLaMA-style 8/3 * dim, rounded up to a multiple of 256.
                expert_hidden_dim = int(2 * dim * 4 / 3)
                expert_hidden_dim = 256 * ((expert_hidden_dim + 255) // 256)

        self.experts = nn.ModuleList(
            Expert(dim, expert_hidden_dim, dropout, bias=False)
            for _ in range(num_experts)
        )

        self.router = TopKRouter(
            dim=dim,
            num_experts=num_experts,
            top_k=top_k,
            capacity_factor=capacity_factor,
            noise_std=noise_std,
            use_expert_capacity=use_expert_capacity,
            router_z_loss_coef=router_z_loss_coef,
            router_aux_loss_coef=router_aux_loss_coef
        )

    def _compute_expert_capacity(self, num_tokens: int) -> int:
        """Maximum routed tokens per expert; unbounded when capacity is off."""
        if not self.use_expert_capacity:
            return num_tokens
        cap = int((num_tokens / self.num_experts) * self.capacity_factor * self.top_k)
        return max(cap, 1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the MoE layer.

        Args:
            x: input [batch, seq_len, dim]

        Returns:
            (output [batch, seq_len, dim], scalar auxiliary routing loss)
        """
        bsz, seq_len, dim = x.shape
        flat = x.view(-1, dim)

        gates, indices, aux_loss = self.router(flat)
        combined = torch.zeros_like(flat)
        capacity = self._compute_expert_capacity(bsz * seq_len)

        for eid, expert in enumerate(self.experts):
            # Rows (tokens) and top-k slots routed to this expert.
            rows, slots = torch.where(indices == eid)
            if rows.numel() == 0:
                continue

            # Randomly drop overflow tokens beyond the expert's capacity.
            if self.use_expert_capacity and rows.numel() > capacity:
                keep = torch.randperm(rows.numel(), device=x.device)[:capacity]
                rows, slots = rows[keep], slots[keep]

            weighted = expert(flat[rows]) * gates[rows, slots].unsqueeze(-1)
            combined.index_add_(0, rows, weighted)

        return combined.view(bsz, seq_len, dim), aux_loss
296
+
297
class SparseDispatcher:
    """Helper that scatters tokens to experts and gathers weighted outputs.

    Optional utility for MoE layers: precomputes, for each expert, which
    token rows were routed to it, then provides dispatch / combine views.
    """

    def __init__(
        self,
        num_experts: int,
        gates: torch.Tensor,
        expert_indices: torch.Tensor
    ):
        """
        Args:
            num_experts: number of experts.
            gates: per-token gate weights [batch_size, num_experts].
            expert_indices: chosen expert per token [batch_size].
        """
        self.num_experts = num_experts
        self._gates = gates
        self._expert_indices = expert_indices
        # Token row indices routed to each expert, precomputed once.
        self._expert_masks = [
            (expert_indices == e).nonzero(as_tuple=True)[0]
            for e in range(num_experts)
        ]

    def dispatch(self, inp: torch.Tensor) -> List[torch.Tensor]:
        """Split `inp` [batch_size, dim] into one (possibly empty) tensor per expert."""
        batches = []
        for rows in self._expert_masks:
            if rows.numel() > 0:
                batches.append(inp[rows])
            else:
                batches.append(
                    torch.empty(0, inp.size(-1), device=inp.device, dtype=inp.dtype)
                )
        return batches

    def combine(self, expert_outputs: List[torch.Tensor]) -> torch.Tensor:
        """Gate-weight each expert's output and scatter back to token order.

        Args:
            expert_outputs: per-expert outputs, aligned with dispatch order.

        Returns:
            combined output [batch_size, dim].
        """
        out = torch.zeros(
            (self._gates.size(0), expert_outputs[0].size(-1)),
            device=self._gates.device,
            dtype=expert_outputs[0].dtype
        )
        for eid, expert_out in enumerate(expert_outputs):
            rows = self._expert_masks[eid]
            if rows.numel() > 0:
                out[rows] += expert_out * self._gates[rows, eid].unsqueeze(-1)
        return out

    def expert_to_gates(self) -> List[torch.Tensor]:
        """Per-expert gate weights, aligned with the dispatch order."""
        result = []
        for eid in range(self.num_experts):
            rows = self._expert_masks[eid]
            if rows.numel() > 0:
                result.append(self._gates[rows, eid])
            else:
                result.append(torch.empty(0, device=self._gates.device))
        return result
384
+
385
class MoELayer(nn.Module):
    """Alternative MoE implementation with an inline softmax router.

    Functionally similar to MixtureOfExperts but without noise injection,
    capacity enforcement, or z-loss.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int = 8,
        expert_hidden_dim: Optional[int] = None,
        top_k: int = 2,
        dropout: float = 0.0,
        capacity_factor: float = 1.25
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        if expert_hidden_dim is None:
            # LLaMA-style 8/3 * dim, rounded up to a multiple of 256.
            expert_hidden_dim = int(2 * dim * 4 / 3)
            expert_hidden_dim = 256 * ((expert_hidden_dim + 255) // 256)

        self.experts = nn.ModuleList(
            Expert(dim, expert_hidden_dim, dropout) for _ in range(num_experts)
        )
        self.gate = nn.Linear(dim, num_experts, bias=False)
        nn.init.normal_(self.gate.weight, std=0.02)
        self.capacity_factor = capacity_factor

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the MoE layer.

        Args:
            x: input [batch, seq_len, dim]

        Returns:
            (output [batch, seq_len, dim], scalar load-balance aux loss)
        """
        bsz, seq_len, dim = x.shape
        flat = x.view(-1, dim)

        probs = F.softmax(self.gate(flat), dim=-1)
        topk_gates, topk_idx = torch.topk(probs, self.top_k, dim=-1)
        # NOTE(review): the reference renormalizes the already-softmaxed
        # top-k gates with a second softmax; reproduced for behavioral parity.
        topk_gates = F.softmax(topk_gates, dim=-1)

        # Load-balance loss: expert usage frequency vs. mean routing prob.
        mean_prob = probs.mean(dim=0)
        counts = F.one_hot(topk_idx, self.num_experts).float().sum(dim=[0, 1])
        counts = counts / (bsz * seq_len * self.top_k)
        aux_loss = self.num_experts * torch.sum(mean_prob * counts)

        out = torch.zeros_like(flat)
        for eid, expert in enumerate(self.experts):
            rows, slots = torch.where(topk_idx == eid)
            if rows.numel() == 0:
                continue
            weighted = expert(flat[rows]) * topk_gates[rows, slots].unsqueeze(-1)
            out.index_add_(0, rows, weighted)

        return out.view(bsz, seq_len, dim), aux_loss
multimodel_fusion.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 跨模态融合模块 - SOTA级别
3
+ 支持深度跨模态交互、对比学习、模态对齐
4
+ 修复版本:解决了所有接口不匹配和潜在bug
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from typing import Dict, List, Optional, Tuple, Union
10
+ from components import RMSNorm
11
+ from transformer import GroupedQueryAttention
12
+ import math
13
+ from contrastive_learning import MultiModalContrastiveLoss
14
+
15
+
16
class CrossModalAttention(nn.Module):
    """Multi-head attention where queries attend over another modality's keys/values.

    Queries and keys are RMS-normalized independently before projection;
    values are used as-is. Uses PyTorch's fused scaled_dot_product_attention
    when available, with a manual fallback.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int = 16,
        dropout: float = 0.1,
        qkv_bias: bool = True
    ):
        super().__init__()
        assert dim % n_heads == 0, f"dim {dim} must be divisible by n_heads {n_heads}"

        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.o_proj = nn.Linear(dim, dim)

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # Separate pre-norms for the two modalities' streams.
        self.norm_q = RMSNorm(dim)
        self.norm_k = RMSNorm(dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            query: [B, T_q, D] - querying modality
            key: [B, T_k, D] - key modality
            value: [B, T_v, D] - value modality (usually identical to key)
            attention_mask: additive mask broadcastable to attention scores
        """
        bsz, q_len, dim = query.shape
        kv_len = key.shape[1]

        q_in = self.norm_q(query)
        k_in = self.norm_k(key)

        # Project and split heads: [B, H, T, head_dim].
        q = self.q_proj(q_in).view(bsz, q_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(k_in).view(bsz, kv_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(value).view(bsz, kv_len, self.n_heads, self.head_dim).transpose(1, 2)

        if hasattr(F, 'scaled_dot_product_attention'):
            # Fused kernel path (dropout active only in training).
            ctx = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attention_mask,
                dropout_p=self.attn_dropout.p if self.training else 0.0,
                is_causal=False
            )
        else:
            scores = (q @ k.transpose(-2, -1)) * self.scale
            if attention_mask is not None:
                scores = scores + attention_mask
            weights = self.attn_dropout(F.softmax(scores, dim=-1))
            ctx = weights @ v

        ctx = ctx.transpose(1, 2).contiguous().view(bsz, q_len, dim)
        return self.resid_dropout(self.o_proj(ctx))
+
92
+
93
class ModalityProjector(nn.Module):
    """MLP that projects modality features into a shared embedding space.

    Args:
        input_dim: feature size of the incoming modality.
        output_dim: target (shared) embedding size.
        hidden_dim: intermediate width; defaults to the mean of in/out dims.
        num_layers: number of Linear layers (must be >= 1).
        use_layer_norm: insert RMSNorm before each hidden activation.

    Fix: with num_layers=1 the single layer was previously built as
    Linear(input_dim, hidden_dim) because the `i == 0` branch shadowed the
    last-layer branch, so the projector never reached output_dim.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: Optional[int] = None,
        num_layers: int = 2,
        use_layer_norm: bool = True
    ):
        super().__init__()
        if num_layers < 1:
            raise ValueError("num_layers must be >= 1")
        if hidden_dim is None:
            hidden_dim = (input_dim + output_dim) // 2

        layers = []
        for i in range(num_layers):
            in_dim = input_dim if i == 0 else hidden_dim
            # The last layer always maps to output_dim (fixes num_layers=1).
            out_dim = output_dim if i == num_layers - 1 else hidden_dim
            layers.append(nn.Linear(in_dim, out_dim))

            if i < num_layers - 1:
                if use_layer_norm:
                    layers.append(RMSNorm(hidden_dim))
                layers.append(nn.GELU())

        self.projector = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project features from input_dim to output_dim."""
        return self.projector(x)
+
126
+
127
class ModalityAdapter(nn.Module):
    """Per-modality bottleneck adapters.

    Each modality gets its own Linear → GELU → Linear bottleneck whose output
    projection is zero-initialized, so every adapter starts as an exact
    identity mapping (residual add of zero).
    """

    def __init__(
        self,
        dim: int,
        bottleneck_dim: int = 64,
        num_modalities: int = 4
    ):
        super().__init__()

        def build_adapter() -> nn.Sequential:
            block = nn.Sequential(
                nn.Linear(dim, bottleneck_dim),
                nn.GELU(),
                nn.Linear(bottleneck_dim, dim)
            )
            # Zero the output projection → identity at initialization.
            nn.init.zeros_(block[-1].weight)
            nn.init.zeros_(block[-1].bias)
            return block

        self.adapters = nn.ModuleList(build_adapter() for _ in range(num_modalities))

    def forward(self, x: torch.Tensor, modality_id: int) -> torch.Tensor:
        """Residual adapter for the given modality; pass-through for unknown ids."""
        if modality_id >= len(self.adapters):
            return x
        return x + self.adapters[modality_id](x)
+
154
+
155
class CrossModalFusionLayer(nn.Module):
    """One cross-modal fusion block.

    Pipeline: pre-norm self-attention, optional pre-norm cross-attention over
    a context from the other modalities, feed-forward, and an optional
    per-modality adapter — each attention/FFN step applied residually.
    """
    def __init__(
        self,
        dim: int,
        n_heads: int = 16,
        dropout: float = 0.1,
        use_adapter: bool = True,
        adapter_dim: int = 64
    ):
        super().__init__()
        self.dim = dim
        self.use_adapter = use_adapter

        # Intra-modality self-attention.
        self.self_attn = GroupedQueryAttention(
            dim=dim,
            n_heads=n_heads,
            dropout=dropout,
            attn_dropout=dropout,
        )

        # Attention into the other modalities' context.
        self.cross_attn = CrossModalAttention(
            dim=dim,
            n_heads=n_heads,
            dropout=dropout,
        )

        # Position-wise feed-forward with 4x expansion.
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout),
        )

        # Pre-norms for the three residual sub-blocks.
        self.norm1 = RMSNorm(dim)
        self.norm2 = RMSNorm(dim)
        self.norm3 = RMSNorm(dim)

        # Optional per-modality adapter.
        self.adapter = ModalityAdapter(dim, adapter_dim) if use_adapter else None

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        modality_id: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run one fusion block.

        Args:
            x: current modality features [B, T, D].
            context: other modalities' context [B, T_ctx, D], or None to skip
                cross-attention entirely.
            modality_id: modality id for the adapter (None disables it).
            attention_mask: mask forwarded to self-attention only.

        Returns:
            Updated features [B, T, D].
        """
        # Self-attention; GroupedQueryAttention returns a tuple — take index 0,
        # the attention output.
        x = x + self.self_attn(self.norm1(x), attention_mask=attention_mask)[0]

        # Cross-modal attention only when a context is supplied.
        if context is not None:
            x = x + self.cross_attn(self.norm2(x), context, context, attention_mask=None)

        # Feed-forward.
        x = x + self.ffn(self.norm3(x))

        # Per-modality adapter (no-op when disabled or id is missing).
        if self.use_adapter and modality_id is not None and self.adapter is not None:
            x = self.adapter(x, modality_id)

        return x
243
+
244
+
245
class PerceiverResampler(nn.Module):
    """Perceiver Resampler: compress a variable-length sequence to a fixed
    number of learned latent tokens via repeated cross-attention."""
    def __init__(
        self,
        dim: int,
        depth: int = 6,
        num_latents: int = 64,
        n_heads: int = 16,
        dropout: float = 0.0
    ):
        super().__init__()
        self.num_latents = num_latents
        # Learned latent queries, shared across the batch.
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        self.layers = nn.ModuleList(
            CrossModalFusionLayer(
                dim=dim,
                n_heads=n_heads,
                dropout=dropout,
                use_adapter=False,
            )
            for _ in range(depth)
        )

        self.norm = RMSNorm(dim)

        # Re-initialize the latents with a truncated normal.
        nn.init.trunc_normal_(self.latents, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Resample ``x``.

        Args:
            x: input features [B, T, D].

        Returns:
            Compressed features [B, num_latents, D].
        """
        batch = x.shape[0]
        out = self.latents.unsqueeze(0).expand(batch, -1, -1)

        # Latents repeatedly cross-attend into the full input sequence.
        for block in self.layers:
            out = block(out, context=x)

        return self.norm(out)
289
+
290
+
291
class MultiModalFusionModule(nn.Module):
    """Multi-modal fusion module integrating all fusion strategies.

    Pipeline per modality segment: project into a shared space, optionally
    compress non-text modalities with a Perceiver resampler, run a stack of
    cross-modal fusion layers (each modality attends to the concatenation of
    the others), and optionally compute contrastive alignment losses against
    the text modality.
    """
    def __init__(
        self,
        dim: int = 2048,
        num_fusion_layers: int = 4,
        n_heads: int = 16,
        dropout: float = 0.1,
        use_perceiver: bool = True,
        num_latents: int = 64,
        use_contrastive: bool = True,
        contrastive_loss_type: str = 'siglip',
        contrastive_embed_dim: int = 512
    ):
        super().__init__()
        self.dim = dim
        self.use_perceiver = use_perceiver
        self.use_contrastive = use_contrastive

        # Per-modality projectors into the shared embedding space.
        self.modality_projectors = nn.ModuleDict({
            'image': ModalityProjector(dim, dim),
            'audio': ModalityProjector(dim, dim),
            'video': ModalityProjector(dim, dim),
            'text': ModalityProjector(dim, dim)
        })

        # Stack of cross-modal fusion layers (with per-modality adapters).
        self.fusion_layers = nn.ModuleList([
            CrossModalFusionLayer(
                dim=dim,
                n_heads=n_heads,
                dropout=dropout,
                use_adapter=True
            )
            for _ in range(num_fusion_layers)
        ])

        # Perceiver resampler for compressing non-text modalities.
        if use_perceiver:
            self.perceiver = PerceiverResampler(
                dim=dim,
                depth=4,
                num_latents=num_latents,
                n_heads=n_heads,
                dropout=dropout
            )

        # Contrastive learning module.
        if use_contrastive:
            # Pooling type per modality for the contrastive projection heads.
            modality_config = {
                'text': 'cls',
                'image': 'cls',
                'audio': 'mean',
                'video': 'mean'
            }

            input_dims = {k: dim for k in modality_config.keys()}

            self.contrastive_module = MultiModalContrastiveLoss(
                embed_dim=contrastive_embed_dim,
                input_dims=input_dims,
                temperature=0.07,
                loss_type=contrastive_loss_type,
                modality_config=modality_config
            )

        self.final_norm = RMSNorm(dim)

    def _pool_features(self, features: torch.Tensor) -> torch.Tensor:
        """Mean-pool features to a single vector: [B, T, D] -> [B, D].
        Non-3D inputs are returned unchanged."""
        if features.dim() == 3:
            return features.mean(dim=1)
        return features

    def forward(
        self,
        segments: List[Dict],
        compute_contrastive: bool = False
    ) -> Dict:
        """
        Args:
            segments: list of dicts {'type', 'data', 'modality_id'}
                - type: str, modality type ('image', 'audio', 'video', 'text')
                - data: Tensor [B, T, D], modality features
                - modality_id: int, modality id (0-3)
            compute_contrastive: whether to compute contrastive losses.

        Returns:
            Dict containing:
            - fused_features: fused feature sequence (all modalities concatenated)
            - modality_features: per-modality fused feature dict
            - contrastive_losses: contrastive loss dict (empty when skipped)

        Raises:
            ValueError: if any segment's data tensor is not 3-dimensional.
        """
        # Split segments into per-modality feature/id maps.
        modality_features = {}
        modality_ids = {}

        for seg in segments:
            mod_type = seg['type']
            mod_data = seg['data']
            mod_id = seg['modality_id']

            # Validate the expected [B, T, D] layout.
            if mod_data.dim() != 3:
                raise ValueError(
                    f"Expected 3D tensor [B, T, D] for modality {mod_type}, "
                    f"got shape {mod_data.shape}"
                )

            # Project into the shared space (unknown types pass through).
            if mod_type in self.modality_projectors:
                projected = self.modality_projectors[mod_type](mod_data)
            else:
                projected = mod_data

            # Compress with the Perceiver (optional; text is left full-length).
            if self.use_perceiver and mod_type != 'text':
                projected = self.perceiver(projected)

            modality_features[mod_type] = projected
            modality_ids[mod_type] = mod_id

        # Cross-modal fusion.
        fused_features = {}

        for mod_type, features in modality_features.items():
            # Build a context from every modality except the current one.
            if len(modality_features) > 1:
                other_features = torch.cat([
                    f for k, f in modality_features.items() if k != mod_type
                ], dim=1)
            else:
                other_features = None

            # Run the fusion stack.
            fused = features
            for layer in self.fusion_layers:
                fused = layer(
                    fused,
                    context=other_features,
                    modality_id=modality_ids[mod_type]
                )

            fused_features[mod_type] = self.final_norm(fused)

        # Contrastive losses (optional).
        contrastive_losses = {}
        if compute_contrastive and self.use_contrastive:
            # Keep 3D features — the projection heads pool internally.
            pooled_features = fused_features  # no pooling here; ProjectionHead handles it

            # Pair every non-text modality with text.
            modality_pairs = []
            if 'text' in pooled_features:
                for mod in pooled_features.keys():
                    if mod != 'text':
                        modality_pairs.append((mod, 'text'))

            # Invoke the contrastive module.
            if modality_pairs:
                contrastive_losses = self.contrastive_module(
                    pooled_features,
                    modality_pairs=modality_pairs
                )

        # Concatenate all fused modality sequences along time.
        fused_sequence = torch.cat(list(fused_features.values()), dim=1)

        return {
            'fused_features': fused_sequence,
            'modality_features': fused_features,
            'contrastive_losses': contrastive_losses
        }
466
+
467
+
468
class EarlyFusionModule(nn.Module):
    """Early fusion: concatenate all modalities along time, then project + norm."""
    def __init__(self, dim: int = 2048):
        super().__init__()
        self.fusion_proj = nn.Linear(dim, dim)
        self.norm = RMSNorm(dim)

    def forward(self, segments: List[Dict]) -> torch.Tensor:
        """Concatenate every segment's 'data' tensor on the sequence axis,
        project, and normalize."""
        fused = torch.cat([segment['data'] for segment in segments], dim=1)
        return self.norm(self.fusion_proj(fused))
481
+
482
+
483
class LateFusionModule(nn.Module):
    """Late fusion: pool each independently-processed modality stream and
    combine the pooled vectors by concatenation, attention weighting, or
    averaging."""
    def __init__(
        self,
        dim: int = 2048,
        num_modalities: int = 4,
        fusion_method: str = 'concat'  # 'concat', 'attention', 'average'
    ):
        super().__init__()
        self.fusion_method = fusion_method

        if fusion_method == 'concat':
            self.fusion_proj = nn.Linear(dim * num_modalities, dim)
        elif fusion_method == 'attention':
            self.attention_weights = nn.Linear(dim, 1)

        self.norm = RMSNorm(dim)

    def forward(self, modality_outputs: List[torch.Tensor]) -> torch.Tensor:
        """Fuse per-modality outputs.

        Args:
            modality_outputs: list of [B, T, D] streams, one per modality.

        Returns:
            Fused [B, D] representation.
        """
        # Mean-pool each stream over its time axis first.
        pooled = [stream.mean(dim=1) for stream in modality_outputs]

        if self.fusion_method == 'concat':
            # Concatenate pooled vectors and project back to dim.
            fused = self.fusion_proj(torch.cat(pooled, dim=-1))
        elif self.fusion_method == 'attention':
            # Softmax-weighted sum across modalities.
            stacked = torch.stack(pooled, dim=1)
            weights = F.softmax(self.attention_weights(stacked), dim=1)
            fused = (stacked * weights).sum(dim=1)
        else:  # 'average'
            fused = torch.stack(pooled, dim=1).mean(dim=1)

        return self.norm(fused)
peft_.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 参数高效微调 (PEFT) 模块
3
+ 支持LoRA和Adapter
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ import math
8
+
9
class LoRALayer(nn.Module):
    """Low-rank adaptation (LoRA) delta.

    Computes ``dropout(x @ A @ B) * (alpha / rank)``. ``B`` is
    zero-initialized, so the delta is exactly zero at the start of training.
    """
    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 8,
        alpha: float = 16.0,
        dropout: float = 0.0
    ):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        # A: [in, r], B: [r, out].
        self.lora_A = nn.Parameter(torch.zeros(in_features, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, out_features))

        self.dropout = nn.Identity() if dropout <= 0 else nn.Dropout(dropout)

        # Standard LoRA init: Kaiming for A, zeros for B.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

        self.merged = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scaled low-rank update for ``x``."""
        delta = self.dropout(x @ self.lora_A @ self.lora_B)
        return delta * self.scaling
39
+
40
class LinearWithLoRA(nn.Module):
    """``nn.Linear`` wrapper with an optional, mergeable LoRA branch.

    While unmerged, the forward pass adds the LoRA delta on the fly; after
    :meth:`merge`, the delta is folded into the base weight and the branch is
    skipped.
    """
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        use_lora: bool = False,
        lora_rank: int = 8,
        lora_alpha: float = 16.0,
        lora_dropout: float = 0.0
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.use_lora = use_lora

        self.base_linear = nn.Linear(in_features, out_features, bias=bias)

        if use_lora:
            self.lora = LoRALayer(
                in_features,
                out_features,
                lora_rank,
                lora_alpha,
                lora_dropout,
            )
        else:
            self.lora = None
        self.merged = False

    def merge(self):
        """Fold the LoRA delta into the base weight (no-op if already merged
        or LoRA is disabled)."""
        if self.use_lora and not self.merged:
            delta = (self.lora.lora_A @ self.lora.lora_B) * self.lora.scaling
            self.base_linear.weight.data += delta.T
            self.merged = True

    def unmerge(self):
        """Undo :meth:`merge` (no-op if not merged)."""
        if self.use_lora and self.merged:
            delta = (self.lora.lora_A @ self.lora.lora_B) * self.lora.scaling
            self.base_linear.weight.data -= delta.T
            self.merged = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Base linear output, plus the LoRA delta when active and unmerged."""
        out = self.base_linear(x)
        if self.use_lora and self.lora is not None and not self.merged:
            out = out + self.lora(x)
        return out
94
+
95
class AdapterLayer(nn.Module):
    """Bottleneck adapter with a scaled residual connection.

    Pre-norm -> down-projection -> activation -> dropout -> up-projection ->
    dropout, added back onto the input scaled by ``residual_scale``. The
    up-projection is zero-initialized so the layer starts as an identity.
    """
    def __init__(
        self,
        dim: int,
        bottleneck_dim: int = 64,
        dropout: float = 0.1,
        activation: str = 'gelu',
        residual_scale: float = 1.0
    ):
        super().__init__()
        self.residual_scale = residual_scale

        self.down_proj = nn.Linear(dim, bottleneck_dim)

        # Unknown activation names fall back to GELU, matching the original.
        activation_map = {'gelu': nn.GELU, 'relu': nn.ReLU, 'silu': nn.SiLU}
        self.activation = activation_map.get(activation, nn.GELU)()

        self.up_proj = nn.Linear(bottleneck_dim, dim)
        self.dropout = nn.Dropout(dropout)

        from components import RMSNorm
        self.layer_norm = RMSNorm(dim)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-init the down projection; zero-init biases and the
        up-projection weight so the adapter is initially a no-op."""
        nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
        nn.init.zeros_(self.up_proj.weight)
        if self.down_proj.bias is not None:
            nn.init.zeros_(self.down_proj.bias)
        if self.up_proj.bias is not None:
            nn.init.zeros_(self.up_proj.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the adapter and return ``x + residual_scale * adapter(norm(x))``."""
        hidden = self.layer_norm(x)
        hidden = self.dropout(self.activation(self.down_proj(hidden)))
        hidden = self.dropout(self.up_proj(hidden))
        return x + hidden * self.residual_scale
148
+
149
class PrefixTuning(nn.Module):
    """Prefix Tuning: learned per-layer key/value prefixes for attention.

    Stores one parameter of shape
    ``[num_layers, 2, num_tokens, num_heads, head_dim]`` (the ``2`` axis is
    key/value) and serves per-layer, batch-broadcast views of it.
    """
    def __init__(
        self,
        num_layers: int,
        num_tokens: int,
        dim: int,
        num_heads: int
    ):
        super().__init__()
        self.num_layers = num_layers
        self.num_tokens = num_tokens
        self.dim = dim
        self.num_heads = num_heads

        head_dim = dim // num_heads
        self.prefix = nn.Parameter(
            torch.randn(num_layers, 2, num_tokens, num_heads, head_dim)
        )

        nn.init.normal_(self.prefix, std=0.02)

    def forward(self, layer_idx: int, batch_size: int) -> torch.Tensor:
        """Return the key/value prefix for one layer.

        Args:
            layer_idx: index into the layer axis of the stored prefix.
            batch_size: batch dimension to broadcast over.

        Returns:
            Tensor of shape [2, batch_size, num_heads, num_tokens, head_dim].
        """
        prefix = self.prefix[layer_idx]  # [2, num_tokens, num_heads, head_dim]
        # BUGFIX: the stored layout is [..., num_tokens, num_heads, ...] but the
        # returned layout is [..., num_heads, num_tokens, ...]; the original
        # expand() mixed those two axes and raised whenever
        # num_tokens != num_heads. Swap them explicitly before broadcasting.
        prefix = prefix.permute(0, 2, 1, 3)  # [2, num_heads, num_tokens, head_dim]
        prefix = prefix.unsqueeze(1).expand(
            2, batch_size, self.num_heads, self.num_tokens, -1
        )
        return prefix
179
+
180
class PromptTuning(nn.Module):
    """Prompt Tuning: a bank of learned soft-prompt embeddings.

    Optionally initialized from randomly sampled rows of an existing
    vocabulary embedding table; otherwise drawn from N(0, 0.02^2).
    """
    def __init__(
        self,
        num_tokens: int,
        dim: int,
        init_from_vocab: bool = False,
        vocab_embeddings: nn.Embedding = None
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.dim = dim

        self.prompt_embeddings = nn.Parameter(torch.randn(num_tokens, dim))

        if init_from_vocab and vocab_embeddings is not None:
            # Seed the prompts from randomly chosen vocabulary embeddings.
            picks = torch.randint(0, vocab_embeddings.num_embeddings, (num_tokens,))
            self.prompt_embeddings.data = vocab_embeddings.weight[picks].clone()
        else:
            nn.init.normal_(self.prompt_embeddings, std=0.02)

    def forward(self, batch_size: int) -> torch.Tensor:
        """Return the prompts broadcast over the batch:
        [batch_size, num_tokens, dim]."""
        expanded = self.prompt_embeddings.unsqueeze(0)
        return expanded.expand(batch_size, -1, -1)
204
+
205
class IALayer(nn.Module):
    """(IA)^3 layer: element-wise feature rescaling.

    Holds a single learned vector, initialized to ones, and multiplies the
    input by it — so the layer starts out as an identity.
    """
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale each feature channel of ``x``."""
        return self.scale * x
post.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # posttrain.py
2
+ """
3
+ 后训练脚本 - Instruction tuning和对齐
4
+ """
5
+ import os
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from transformers import AutoTokenizer
9
+ from pathlib import Path
10
+ import logging
11
+ from tqdm import tqdm
12
+ import json
13
+ from datetime import datetime
14
+ import copy
15
+ from model import MultiModalDenseTransformer
16
+
17
+ from data_loader import (
18
+ create_posttrain_dataloader,
19
+ create_preference_dataloader
20
+ )
21
+ from data_config import POSTTRAIN_MIX
22
+ from reward_model import RewardModel, RewardModelTrainer
23
+ from grpo import GRPOTrainer
24
+ from typing import Optional
25
+
26
+ logging.basicConfig(
27
+ level=logging.INFO,
28
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
29
+ )
30
+ logger = logging.getLogger(__name__)
31
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
32
+
33
class PostTrainer:
    """Post-trainer for Supervised Fine-Tuning (SFT).

    Wraps a MultiModalDenseTransformer with an AdamW optimizer, optional CUDA
    AMP (GradScaler + autocast), gradient accumulation, gradient clipping,
    periodic evaluation, and checkpointing.
    """
    def __init__(
        self,
        model: MultiModalDenseTransformer,
        tokenizer,
        learning_rate: float = 1e-5,
        weight_decay: float = 0.01,
        num_epochs: int = 3,
        gradient_accumulation_steps: int = 1,
        max_grad_norm: float = 1.0,
        log_interval: int = 10,
        eval_interval: int = 500,
        save_interval: int = 1000,
        checkpoint_dir: str = "checkpoints/posttrain"
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model.to(self.device)

        # Optimizer (AdamW with LLM-typical betas).
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.95),
            eps=1e-8
        )

        # Mixed precision: enabled only when CUDA is available.
        self.use_amp = torch.cuda.is_available()
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)

        # Training hyperparameters.
        self.num_epochs = num_epochs
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.log_interval = log_interval
        self.eval_interval = eval_interval
        self.save_interval = save_interval

        # Checkpoint management.
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Training state: global_step counts optimizer steps, not batches.
        self.global_step = 0
        self.best_eval_loss = float('inf')

        logger.info(f"PostTrainer initialized:")
        logger.info(f"  Device: {self.device}")
        logger.info(f"  Learning Rate: {learning_rate}")
        logger.info(f"  Num Epochs: {num_epochs}")
        logger.info(f"  Gradient Accumulation: {gradient_accumulation_steps}")

    def train_step(self, batch: dict) -> dict:
        """One forward/backward pass on a batch.

        Expects the batch dict to hold 'instruction', 'response' token id
        tensors and their masks. Loss is computed only over response tokens
        (instruction and padding positions are set to -100). Returns a dict
        with the unscaled 'loss'; does NOT step the optimizer.
        """
        instruction_ids = batch['instruction'].to(self.device)
        response_ids = batch['response'].to(self.device)

        # Masks produced by the DataLoader.
        instruction_mask = batch['instruction_mask'].to(self.device)
        response_mask = batch['response_mask'].to(self.device)
        # Concatenate instruction + response into one sequence.
        input_ids = torch.cat([instruction_ids, response_ids], dim=1)
        attention_mask = torch.cat([instruction_mask, response_mask], dim=1).float()
        # Build labels: only the response part contributes to the loss.
        labels = input_ids.clone()
        instr_len = instruction_ids.shape[1]
        labels[:, :instr_len] = -100
        labels[attention_mask == 0] = -100


        # Wrap ids in the model's segment-based input format.
        input_data = {
            'segments': [{
                'type': 'text',
                'data': input_ids,
                'modality_id': 0
            }]
        }

        # Forward pass under autocast when AMP is enabled.
        with torch.amp.autocast('cuda', enabled=self.use_amp):
            outputs = self.model(input_data,attention_mask=attention_mask)
            logits = outputs['logits']

            # Next-token cross-entropy with the usual one-position shift.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )
        raw_loss = loss.item()
        # Scale down so accumulated gradients average over micro-batches.
        loss = loss / self.gradient_accumulation_steps

        # Backward through the GradScaler.
        self.scaler.scale(loss).backward()

        return {
            'loss': raw_loss
        }

    def optimizer_step(self):
        """Clip gradients, step the optimizer via the scaler, zero grads,
        and advance global_step. Returns the pre-clip gradient norm."""
        self.scaler.unscale_(self.optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.max_grad_norm
        )

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)
        self.global_step += 1
        return grad_norm.item()

    @torch.no_grad()
    def evaluate(self, dataloader, max_batches: int = 50) -> float:
        """Average response-token loss over up to ``max_batches`` batches.

        NOTE(review): unlike train_step, this masks padding via
        tokenizer.pad_token_id and passes no attention_mask to the model —
        confirm that is intentional.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        for i, batch in enumerate(dataloader):
            if i >= max_batches:
                break

            if batch is None:
                continue

            instruction_ids = batch['instruction'].to(self.device)
            response_ids = batch['response'].to(self.device)
            input_ids = torch.cat([instruction_ids, response_ids], dim=1)

            labels = input_ids.clone()
            labels[:, :instruction_ids.shape[1]] = -100
            labels[input_ids == self.tokenizer.pad_token_id] = -100

            input_data = {
                'segments': [{
                    'type': 'text',
                    'data': input_ids,
                    'modality_id': 0
                }]
            }

            with torch.amp.autocast('cuda', enabled=self.use_amp):
                outputs = self.model(input_data)
                logits = outputs['logits']

                shift_logits = logits[:, :-1, :].contiguous()
                shift_labels = labels[:, 1:].contiguous()

                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    ignore_index=-100
                )

            total_loss += loss.item()
            num_batches += 1

        self.model.train()
        return total_loss / max(num_batches, 1)

    def train(
        self,
        train_dataloader,
        eval_dataloader=None,
        resume_from: Optional[str] = None
    ):
        """Main SFT loop: iterates epochs, accumulates gradients, logs,
        evaluates, tracks the best eval loss, and writes checkpoints
        (including a final one at the end)."""
        logger.info("\n" + "="*80)
        logger.info("Starting Post-Training (SFT)")
        logger.info("="*80 + "\n")

        if resume_from:
            self.load_checkpoint(resume_from)

        self.model.train()

        for epoch in range(self.num_epochs):
            logger.info(f"\nEpoch {epoch+1}/{self.num_epochs}")

            progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
            running_loss = 0.0
            step_in_accumulation = 0

            for batch_idx, batch in enumerate(progress_bar):
                if batch is None:
                    continue

                # Forward/backward only; optimizer stepped below.
                stats = self.train_step(batch)
                running_loss += stats['loss']
                step_in_accumulation += 1

                # Step the optimizer once per accumulation window.
                if step_in_accumulation == self.gradient_accumulation_steps:
                    grad_norm = self.optimizer_step()
                    step_in_accumulation = 0

                # Update the progress bar with the latest batch loss.
                progress_bar.set_postfix({'loss': f"{stats['loss']:.4f}"})

                # Periodic logging (keyed on optimizer steps).
                if self.global_step % self.log_interval == 0:
                    avg_loss = running_loss / self.log_interval
                    logger.info(
                        f"Step {self.global_step} | "
                        f"Epoch {epoch+1} | "
                        f"Loss: {avg_loss:.4f}"
                    )
                    running_loss = 0.0

                # Periodic evaluation; track the best model.
                if eval_dataloader and self.global_step % self.eval_interval == 0:
                    eval_loss = self.evaluate(eval_dataloader)
                    logger.info(f"Eval Loss: {eval_loss:.4f}")

                    if eval_loss < self.best_eval_loss:
                        self.best_eval_loss = eval_loss
                        self.save_checkpoint(
                            self.checkpoint_dir / "best_model.pt",
                            is_best=True
                        )

                # Periodic checkpointing.
                if self.global_step % self.save_interval == 0:
                    self.save_checkpoint(
                        self.checkpoint_dir / f"step_{self.global_step}.pt"
                    )

            # End-of-epoch evaluation.
            if eval_dataloader:
                eval_loss = self.evaluate(eval_dataloader)
                logger.info(f"\nEpoch {epoch+1} Eval Loss: {eval_loss:.4f}")

        logger.info("\n" + "="*80)
        logger.info("Post-Training Complete!")
        logger.info(f"  Best Eval Loss: {self.best_eval_loss:.4f}")
        logger.info("="*80 + "\n")

        self.save_checkpoint(self.checkpoint_dir / "final_model.pt")

    def save_checkpoint(self, path: Path, is_best: bool = False):
        """Save model/optimizer/scaler state plus training progress to ``path``."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict() if self.use_amp else None,
            'global_step': self.global_step,
            'best_eval_loss': self.best_eval_loss,
            'timestamp': datetime.now().isoformat()
        }

        torch.save(checkpoint, path)
        logger.info(f"Checkpoint saved to {path}" + (" (BEST)" if is_best else ""))

    def load_checkpoint(self, path: str):
        """Restore model/optimizer/scaler state and training progress from ``path``."""
        checkpoint = torch.load(path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if self.use_amp and checkpoint.get('scaler_state_dict'):
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        self.global_step = checkpoint['global_step']
        self.best_eval_loss = checkpoint['best_eval_loss']

        logger.info(f"Checkpoint loaded from {path}")
312
+
313
def main():
    """Entry point: run SFT (phase 1) and optionally RLHF with GRPO (phase 2).

    Builds tokenizer, model, and dataloaders from an in-function config dict,
    optionally loads a pretraining checkpoint, trains with PostTrainer, then
    (when 'do_rlhf' is set) trains a reward model on preference data and runs
    GRPO against a frozen reference copy of the SFT model.
    """
    # Configuration (edit in place; not read from CLI or file).
    config = {
        # Model configuration
        'model_dim': 1536,
        'vocab_size': 151665,
        'n_layers': 12,
        'n_heads': 12,
        'n_kv_heads': 4,
        'max_seq_len': 512,
        'dropout': 0.0,
        'use_moe': False,
        # Training configuration
        'batch_size': 2,
        'gradient_accumulation_steps': 8,
        'learning_rate': 1e-4,
        'weight_decay': 0.01,
        'num_epochs': 1,
        'max_grad_norm': 1.0,

        # Data configuration
        'data_mix': 'debug_mix',
        'max_samples_train': 1000,
        'max_samples_eval': 1000,
        'max_length': 512,
        'num_workers': 4,

        # RLHF configuration
        'do_rlhf': False,
        'preference_dataset': 'hh_rlhf',
        'grpo_iterations': 3,
        'grpo_kl_coef': 0.04,
        'grpo_group_size': 4,

        # Paths / intervals
        'pretrain_checkpoint': '/root/multimodal/checkpoints/pretrain_fixed/step_10000.pt',
        'checkpoint_dir': 'checkpoints/posttrain',
        'log_interval': 50,
        'eval_interval': 500,
        'save_interval': 1000,
    }

    logger.info("Configuration:")
    logger.info(json.dumps(config, indent=2))

    # Initialize tokenizer.
    logger.info("\nInitializing tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct",
        use_fast=True,
        trust_remote_code=True
    )

    # Fall back to EOS as the padding token if none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Keep the model's vocab in sync with the tokenizer.
    config['vocab_size'] = len(tokenizer)

    # Initialize (or later load) the model.
    logger.info("\nInitializing model...")
    model = MultiModalDenseTransformer(
        model_dim=config['model_dim'],
        vocab_size=config['vocab_size'],
        n_layers=config['n_layers'],
        n_heads=config['n_heads'],
        n_kv_heads=config['n_kv_heads'],
        max_seq_len=config['max_seq_len'],
        dropout=config['dropout'],
        use_moe=config['use_moe'],
        use_gradient_checkpointing=False,
        rope_scaling_type="yarn",
        use_multimodal_fusion=False,
        use_contrastive=False
    )

    # Load pretraining weights when a checkpoint path is configured.
    if config['pretrain_checkpoint']:
        logger.info(f"Loading pretrain checkpoint: {config['pretrain_checkpoint']}")
        checkpoint = torch.load(config['pretrain_checkpoint'])
        model.load_state_dict(checkpoint['model_state_dict'])

    # ===== Phase 1: Supervised Fine-Tuning =====
    logger.info("\n" + "="*80)
    logger.info("PHASE 1: Supervised Fine-Tuning")
    logger.info("="*80)

    # Build the training dataloader.
    train_dataloader = create_posttrain_dataloader(
        mix_name=config['data_mix'],
        tokenizer=tokenizer,
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        max_length=config['max_length'],
        max_samples=config['max_samples_train'],
        split='train',
        shuffle=True
    )

    eval_dataloader = create_posttrain_dataloader(
        mix_name=config['data_mix'],
        tokenizer=tokenizer,
        batch_size=config['batch_size'] * 2,
        num_workers=config['num_workers'],
        max_length=config['max_length'],
        max_samples=config['max_samples_eval'],
        split='train',  # use the tail of the train split as validation
        shuffle=False
    )

    # Build the trainer.
    trainer = PostTrainer(
        model=model,
        tokenizer=tokenizer,
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        num_epochs=config['num_epochs'],
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        max_grad_norm=config['max_grad_norm'],
        log_interval=config['log_interval'],
        eval_interval=config['eval_interval'],
        save_interval=config['save_interval'],
        checkpoint_dir=config['checkpoint_dir']
    )

    # Run SFT.
    trainer.train(train_dataloader, eval_dataloader)

    # ===== Phase 2: RLHF with GRPO =====
    if config['do_rlhf']:
        logger.info("\n" + "="*80)
        logger.info("PHASE 2: RLHF with GRPO")
        logger.info("="*80)

        try:
            # Train the reward model on preference pairs.
            logger.info("\nTraining Reward Model...")

            # Deep-copy the SFT model as the reward model's backbone.
            reward_base_model = copy.deepcopy(model)
            reward_model = RewardModel(reward_base_model, use_value_head=True)

            preference_dataloader = create_preference_dataloader(
                dataset_name=config['preference_dataset'],
                tokenizer=tokenizer,
                batch_size=config['batch_size'],
                num_workers=config['num_workers'],
                max_samples=5000,
                split='train'
            )

            reward_trainer = RewardModelTrainer(
                reward_model=reward_model,
                learning_rate=1e-5
            )

            reward_trainer.train(preference_dataloader, num_epochs=1)

            # GRPO training against a frozen reference copy.
            logger.info("\nStarting GRPO Training...")

            ref_model = copy.deepcopy(model)
            ref_model.eval()

            grpo_trainer = GRPOTrainer(
                actor_model=model,
                reward_model=reward_model,
                ref_model=ref_model,
                tokenizer=tokenizer,
                learning_rate=1e-6,
                kl_coef=config['grpo_kl_coef'],
                group_size=config['grpo_group_size'],
                update_batch_size=2,
                use_amp=True
            )

            # Collect prompts for GRPO rollouts from the SFT data mix.
            prompt_dataloader = create_posttrain_dataloader(
                mix_name=config['data_mix'],
                tokenizer=tokenizer,
                batch_size=4,
                num_workers=2,
                max_samples=1000,
                split='train'
            )

            # Extract up to 200 instruction batches as prompts.
            prompts = []
            for batch in prompt_dataloader:
                if batch and batch.get('instruction') is not None:
                    prompts.append(batch['instruction'])
                if len(prompts) >= 200:
                    break

            if prompts:
                prompt_tensor = torch.cat(prompts[:200], dim=0)
                from torch.utils.data import TensorDataset, DataLoader
                prompt_loader = DataLoader(
                    TensorDataset(prompt_tensor),
                    batch_size=4
                )

                grpo_trainer.train(
                    prompt_loader,
                    num_iterations=config['grpo_iterations'],
                    max_gen_len=50,
                    save_path=config['checkpoint_dir'] + "/grpo"
                )

        except Exception as e:
            # RLHF is best-effort: log the failure but still finish cleanly.
            logger.error(f"Error in RLHF: {e}")
            import traceback
            traceback.print_exc()

    logger.info("\n" + "="*80)
    logger.info("All Training Complete!")
    logger.info("="*80)

if __name__ == "__main__":
    main()
posttrain.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # posttrain.py
2
+ """
3
+ 后训练脚本 - Instruction tuning和对齐
4
+ """
5
+ import os
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from transformers import AutoTokenizer
9
+ from pathlib import Path
10
+ import logging
11
+ from tqdm import tqdm
12
+ import json
13
+ from datetime import datetime
14
+ import copy
15
+ from model import MultiModalDenseTransformer
16
+
17
+ from data_loader import (
18
+ create_posttrain_dataloader,
19
+ create_preference_dataloader
20
+ )
21
+ from data_config import POSTTRAIN_MIX
22
+ from reward_model import RewardModel, RewardModelTrainer
23
+ from grpo import GRPOTrainer
24
+ from typing import Optional
25
+
26
+ logging.basicConfig(
27
+ level=logging.INFO,
28
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
29
+ )
30
+ logger = logging.getLogger(__name__)
31
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
32
+
33
class PostTrainer:
    """Supervised fine-tuning (SFT) trainer.

    Concatenates instruction and response token ids into one sequence, masks
    the instruction prefix and all padding out of the loss, and fine-tunes the
    model with AdamW under optional CUDA mixed precision.

    Hook cadence (``log_interval`` / ``eval_interval`` / ``save_interval``) is
    counted in optimizer steps, i.e. once per ``gradient_accumulation_steps``
    micro-batches.

    Args:
        model: multimodal transformer to fine-tune (moved to CUDA if available).
        tokenizer: tokenizer providing ``pad_token_id`` (used as an evaluation
            fallback when the dataloader supplies no masks).
        learning_rate / weight_decay: AdamW hyper-parameters.
        num_epochs: passes over the training dataloader.
        gradient_accumulation_steps: micro-batches per optimizer step.
        max_grad_norm: gradient clipping threshold.
        checkpoint_dir: directory where checkpoints are written.
    """

    def __init__(
        self,
        model: MultiModalDenseTransformer,
        tokenizer,
        learning_rate: float = 1e-5,
        weight_decay: float = 0.01,
        num_epochs: int = 3,
        gradient_accumulation_steps: int = 1,
        max_grad_norm: float = 1.0,
        log_interval: int = 10,
        eval_interval: int = 500,
        save_interval: int = 1000,
        checkpoint_dir: str = "checkpoints/posttrain"
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model.to(self.device)

        # AdamW with the betas commonly used for transformer LMs.
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.95),
            eps=1e-8
        )

        # Mixed precision only when CUDA is available.
        self.use_amp = torch.cuda.is_available()
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)

        # Training hyper-parameters.
        self.num_epochs = num_epochs
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.log_interval = log_interval
        self.eval_interval = eval_interval
        self.save_interval = save_interval

        # Checkpoint management.
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Training state.
        self.global_step = 0
        self.best_eval_loss = float('inf')

        logger.info(f"PostTrainer initialized:")
        logger.info(f"  Device: {self.device}")
        logger.info(f"  Learning Rate: {learning_rate}")
        logger.info(f"  Num Epochs: {num_epochs}")
        logger.info(f"  Gradient Accumulation: {gradient_accumulation_steps}")

    @staticmethod
    def _build_position_ids(attention_mask: torch.Tensor) -> torch.Tensor:
        """Padding-aware position ids, vectorized over the batch.

        Each real token gets its 0-based index among the real tokens of its
        row; pad positions get 0.  Replaces the original per-sample Python
        loop with a single cumsum (identical output).
        """
        mask = attention_mask.long()
        return (torch.cumsum(mask, dim=1) - 1) * mask

    def train_step(self, batch: dict) -> dict:
        """One SFT micro-batch: forward + scaled backward.

        The loss is computed only on response tokens (instruction prefix and
        padding are set to -100).  Returns the unscaled loss for logging.
        """
        instruction_ids = batch['instruction'].to(self.device)
        response_ids = batch['response'].to(self.device)

        # Dataloader-provided masks distinguish real tokens from padding.
        instruction_mask = batch['instruction_mask'].to(self.device)
        response_mask = batch['response_mask'].to(self.device)

        # Concatenate ids and masks into a single left-to-right sequence.
        input_ids = torch.cat([instruction_ids, response_ids], dim=1)
        attention_mask = torch.cat([instruction_mask, response_mask], dim=1)
        position_ids = self._build_position_ids(attention_mask)

        # Labels: mask the instruction prefix and every padding token.
        labels = input_ids.clone()
        instr_len = instruction_ids.shape[1]
        labels[:, :instr_len] = -100
        labels[attention_mask == 0] = -100

        input_data = {
            'segments': [{
                'type': 'text',
                'data': input_ids,
                'modality_id': 0
            }]
        }

        with torch.amp.autocast('cuda', enabled=self.use_amp):
            # attention_mask must be passed so the transformer ignores padding.
            outputs = self.model(input_data, attention_mask=attention_mask,
                                 position_ids=position_ids)
            logits = outputs['logits']

            # Standard next-token shift.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )

        raw_loss = loss.item()
        # Scale for gradient accumulation before backward.
        loss = loss / self.gradient_accumulation_steps
        self.scaler.scale(loss).backward()

        return {
            'loss': raw_loss
        }

    def optimizer_step(self) -> float:
        """Unscale, clip, and apply the accumulated gradients.

        Returns:
            The pre-clipping gradient norm.
        """
        self.scaler.unscale_(self.optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.max_grad_norm
        )

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)
        self.global_step += 1
        return grad_norm.item()

    @torch.no_grad()
    def evaluate(self, dataloader, max_batches: int = 50) -> float:
        """Average SFT loss over at most ``max_batches`` evaluation batches.

        Fix: the forward pass now receives ``attention_mask`` and
        ``position_ids`` exactly as in ``train_step``; previously evaluation
        attended to padding, skewing the reported loss.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        for i, batch in enumerate(dataloader):
            if i >= max_batches:
                break
            if batch is None:
                continue

            instruction_ids = batch['instruction'].to(self.device)
            response_ids = batch['response'].to(self.device)
            input_ids = torch.cat([instruction_ids, response_ids], dim=1)

            # Prefer dataloader masks; fall back to a pad-token comparison
            # (the fallback mislabels EOS when pad_token == eos_token).
            if 'instruction_mask' in batch and 'response_mask' in batch:
                attention_mask = torch.cat(
                    [batch['instruction_mask'].to(self.device),
                     batch['response_mask'].to(self.device)],
                    dim=1
                )
            else:
                attention_mask = (input_ids != self.tokenizer.pad_token_id).long()
            position_ids = self._build_position_ids(attention_mask)

            labels = input_ids.clone()
            labels[:, :instruction_ids.shape[1]] = -100
            labels[attention_mask == 0] = -100

            input_data = {
                'segments': [{
                    'type': 'text',
                    'data': input_ids,
                    'modality_id': 0
                }]
            }

            with torch.amp.autocast('cuda', enabled=self.use_amp):
                outputs = self.model(input_data, attention_mask=attention_mask,
                                     position_ids=position_ids)
                logits = outputs['logits']

            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )

            total_loss += loss.item()
            num_batches += 1

        self.model.train()
        return total_loss / max(num_batches, 1)

    def train(
        self,
        train_dataloader,
        eval_dataloader=None,
        resume_from: Optional[str] = None
    ):
        """Main SFT loop.

        Fixes over the previous version:
        - logged loss is now averaged over optimizer steps (the old code
          divided a per-micro-batch running sum by ``log_interval`` counted
          in optimizer steps, overstating loss by the accumulation factor);
        - logging / evaluation / saving trigger only immediately after an
          optimizer step, so each ``global_step`` fires each hook at most once
          (previously ``step 0 % interval == 0`` fired every micro-batch);
        - a partial accumulation left over at epoch end is flushed instead of
          leaking stale gradients into the next epoch.
        """
        logger.info("\n" + "="*80)
        logger.info("Starting Post-Training (SFT)")
        logger.info("="*80 + "\n")

        if resume_from:
            self.load_checkpoint(resume_from)

        self.model.train()

        for epoch in range(self.num_epochs):
            logger.info(f"\nEpoch {epoch+1}/{self.num_epochs}")

            progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
            running_loss = 0.0        # sum of per-optimizer-step average losses
            steps_since_log = 0
            accumulated_loss = 0.0    # micro-batch losses in current accumulation
            step_in_accumulation = 0

            for batch_idx, batch in enumerate(progress_bar):
                if batch is None:
                    continue

                stats = self.train_step(batch)
                accumulated_loss += stats['loss']
                step_in_accumulation += 1
                progress_bar.set_postfix({'loss': f"{stats['loss']:.4f}"})

                if step_in_accumulation < self.gradient_accumulation_steps:
                    continue

                # A full accumulation window is complete: apply the update.
                self.optimizer_step()
                running_loss += accumulated_loss / self.gradient_accumulation_steps
                steps_since_log += 1
                step_in_accumulation = 0
                accumulated_loss = 0.0

                if self.global_step % self.log_interval == 0:
                    avg_loss = running_loss / max(steps_since_log, 1)
                    logger.info(
                        f"Step {self.global_step} | "
                        f"Epoch {epoch+1} | "
                        f"Loss: {avg_loss:.4f}"
                    )
                    running_loss = 0.0
                    steps_since_log = 0

                if eval_dataloader and self.global_step % self.eval_interval == 0:
                    eval_loss = self.evaluate(eval_dataloader)
                    logger.info(f"Eval Loss: {eval_loss:.4f}")

                    if eval_loss < self.best_eval_loss:
                        self.best_eval_loss = eval_loss
                        self.save_checkpoint(
                            self.checkpoint_dir / "best_model.pt",
                            is_best=True
                        )

                if self.global_step % self.save_interval == 0:
                    self.save_checkpoint(
                        self.checkpoint_dir / f"step_{self.global_step}.pt"
                    )

            # Flush any partial accumulation so stale gradients do not leak
            # into the next epoch's first update.
            if step_in_accumulation > 0:
                self.optimizer_step()

            # End-of-epoch evaluation.
            if eval_dataloader:
                eval_loss = self.evaluate(eval_dataloader)
                logger.info(f"\nEpoch {epoch+1} Eval Loss: {eval_loss:.4f}")

        logger.info("\n" + "="*80)
        logger.info("Post-Training Complete!")
        logger.info(f"  Best Eval Loss: {self.best_eval_loss:.4f}")
        logger.info("="*80 + "\n")

        self.save_checkpoint(self.checkpoint_dir / "final_model.pt")

    def save_checkpoint(self, path: Path, is_best: bool = False):
        """Serialize model/optimizer/scaler state plus training progress."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict() if self.use_amp else None,
            'global_step': self.global_step,
            'best_eval_loss': self.best_eval_loss,
            'timestamp': datetime.now().isoformat()
        }

        torch.save(checkpoint, path)
        logger.info(f"Checkpoint saved to {path}" + (" (BEST)" if is_best else ""))

    def load_checkpoint(self, path: str):
        """Restore model/optimizer/scaler state and training progress."""
        checkpoint = torch.load(path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if self.use_amp and checkpoint.get('scaler_state_dict'):
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        self.global_step = checkpoint['global_step']
        self.best_eval_loss = checkpoint['best_eval_loss']

        logger.info(f"Checkpoint loaded from {path}")
334
+
335
def main():
    """Entry point: SFT phase followed by optional RLHF (reward model + GRPO)."""
    # Configuration.
    config = {
        # Model
        'model_dim': 1536,
        'vocab_size': 151665,  # overwritten below from the tokenizer
        'n_layers': 12,
        'n_heads': 12,
        'n_kv_heads': 4,
        'max_seq_len': 512,
        'dropout': 0.0,
        'use_moe': False,
        # Training
        'batch_size': 2,
        'gradient_accumulation_steps': 8,
        'learning_rate': 1e-5,
        'weight_decay': 0.01,
        'num_epochs': 3,
        'max_grad_norm': 1.0,

        # Data
        'data_mix': 'simple_instruct',
        'max_samples_train': 20000,
        'max_samples_eval': 1000,
        'max_length': 512,
        'num_workers': 4,

        # RLHF
        'do_rlhf': False,
        'preference_dataset': 'hh_rlhf',
        'grpo_iterations': 3,
        'grpo_kl_coef': 0.04,
        'grpo_group_size': 4,

        # Paths and cadence
        'pretrain_checkpoint': '/root/multimodal/checkpoints/pretrain_fixed/step_10000.pt',
        'checkpoint_dir': 'checkpoints/posttrain',
        'log_interval': 50,
        'eval_interval': 500,
        'save_interval': 1000,
    }

    logger.info("Configuration:")
    logger.info(json.dumps(config, indent=2))

    # Tokenizer.
    logger.info("\nInitializing tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct",
        use_fast=True,
        trust_remote_code=True
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    config['vocab_size'] = len(tokenizer)

    # Model.
    logger.info("\nInitializing model...")
    model = MultiModalDenseTransformer(
        model_dim=config['model_dim'],
        vocab_size=config['vocab_size'],
        n_layers=config['n_layers'],
        n_heads=config['n_heads'],
        n_kv_heads=config['n_kv_heads'],
        max_seq_len=config['max_seq_len'],
        dropout=config['dropout'],
        use_moe=config['use_moe'],
        use_gradient_checkpointing=False,
        rope_scaling_type="yarn",
        use_multimodal_fusion=False,
        use_contrastive=False
    )

    # Load the pre-training checkpoint, if configured.
    if config['pretrain_checkpoint']:
        logger.info(f"Loading pretrain checkpoint: {config['pretrain_checkpoint']}")
        # Fix: map to CPU first so a GPU-saved checkpoint also loads on
        # CPU-only hosts; PostTrainer moves the model to the right device.
        checkpoint = torch.load(config['pretrain_checkpoint'], map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])

    # ===== Phase 1: Supervised Fine-Tuning =====
    logger.info("\n" + "="*80)
    logger.info("PHASE 1: Supervised Fine-Tuning")
    logger.info("="*80)

    # Data loaders.
    train_dataloader = create_posttrain_dataloader(
        mix_name=config['data_mix'],
        tokenizer=tokenizer,
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        max_length=config['max_length'],
        max_samples=config['max_samples_train'],
        split='train',
        shuffle=True
    )

    eval_dataloader = create_posttrain_dataloader(
        mix_name=config['data_mix'],
        tokenizer=tokenizer,
        batch_size=config['batch_size'] * 2,
        num_workers=config['num_workers'],
        max_length=config['max_length'],
        max_samples=config['max_samples_eval'],
        split='train',  # tail of the train split doubles as validation
        shuffle=False
    )

    # Trainer.
    trainer = PostTrainer(
        model=model,
        tokenizer=tokenizer,
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        num_epochs=config['num_epochs'],
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        max_grad_norm=config['max_grad_norm'],
        log_interval=config['log_interval'],
        eval_interval=config['eval_interval'],
        save_interval=config['save_interval'],
        checkpoint_dir=config['checkpoint_dir']
    )

    # Run SFT.
    trainer.train(train_dataloader, eval_dataloader)

    # ===== Phase 2: RLHF with GRPO =====
    if config['do_rlhf']:
        logger.info("\n" + "="*80)
        logger.info("PHASE 2: RLHF with GRPO")
        logger.info("="*80)

        try:
            # Train the reward model on preference pairs.
            logger.info("\nTraining Reward Model...")

            reward_base_model = copy.deepcopy(model)
            reward_model = RewardModel(reward_base_model, use_value_head=True)

            preference_dataloader = create_preference_dataloader(
                dataset_name=config['preference_dataset'],
                tokenizer=tokenizer,
                batch_size=config['batch_size'],
                num_workers=config['num_workers'],
                max_samples=5000,
                split='train'
            )

            reward_trainer = RewardModelTrainer(
                reward_model=reward_model,
                learning_rate=1e-5
            )

            reward_trainer.train(preference_dataloader, num_epochs=1)

            # GRPO training against a frozen reference copy of the actor.
            logger.info("\nStarting GRPO Training...")

            ref_model = copy.deepcopy(model)
            ref_model.eval()

            grpo_trainer = GRPOTrainer(
                actor_model=model,
                reward_model=reward_model,
                ref_model=ref_model,
                tokenizer=tokenizer,
                learning_rate=1e-6,
                kl_coef=config['grpo_kl_coef'],
                group_size=config['grpo_group_size'],
                update_batch_size=2,
                use_amp=True
            )

            # Collect up to 200 prompt batches from the instruction data.
            prompt_dataloader = create_posttrain_dataloader(
                mix_name=config['data_mix'],
                tokenizer=tokenizer,
                batch_size=4,
                num_workers=2,
                max_samples=1000,
                split='train'
            )

            prompts = []
            for batch in prompt_dataloader:
                if batch and batch.get('instruction') is not None:
                    prompts.append(batch['instruction'])
                if len(prompts) >= 200:
                    break

            if prompts:
                prompt_tensor = torch.cat(prompts[:200], dim=0)
                from torch.utils.data import TensorDataset, DataLoader
                prompt_loader = DataLoader(
                    TensorDataset(prompt_tensor),
                    batch_size=4
                )

                grpo_trainer.train(
                    prompt_loader,
                    num_iterations=config['grpo_iterations'],
                    max_gen_len=50,
                    save_path=config['checkpoint_dir'] + "/grpo"
                )

        except Exception as e:
            # Best-effort phase: log the failure but do not abort the run.
            logger.error(f"Error in RLHF: {e}")
            import traceback
            traceback.print_exc()

    logger.info("\n" + "="*80)
    logger.info("All Training Complete!")
    logger.info("="*80)
pretrain.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pretrain.py - 完全修复版本
2
+
3
+ import os
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from transformers import AutoTokenizer
7
+ from pathlib import Path
8
+ import logging
9
+ from tqdm import tqdm
10
+ import json
11
+ from datetime import datetime
12
+ from model import MultiModalDenseTransformer
13
+ from data_loader import create_pretrain_dataloader
14
+
15
+ logging.basicConfig(
16
+ level=logging.INFO,
17
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
18
+ )
19
+ logger = logging.getLogger(__name__)
20
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
21
+
22
+
23
class PreTrainer:
    """Pre-training driver: AdamW, manual warmup+cosine LR, AMP, accumulation.

    Micro-batch losses are accumulated and averaged per optimizer step, the
    true (unscaled) loss is logged, and resuming fast-forwards the streaming
    dataloader past already-consumed batches.

    Args:
        model: multimodal transformer to pre-train (moved to CUDA if available).
        tokenizer: kept for debugging/decoding; not used in the hot path.
        learning_rate: peak LR (decays to 10% of peak by ``max_steps``).
        warmup_steps: linear-warmup length in optimizer steps.
        max_steps: total optimizer steps before training stops.
        gradient_accumulation_steps: micro-batches per optimizer step.
        loss_log_file: plain-text loss log appended to every ``log_interval``.
    """
    def __init__(
        self,
        model: MultiModalDenseTransformer,
        tokenizer,
        learning_rate: float = 3e-4,
        weight_decay: float = 0.1,
        warmup_steps: int = 1000,
        max_steps: int = 100000,
        gradient_accumulation_steps: int = 16,
        max_grad_norm: float = 1.0,
        log_interval: int = 10,
        save_interval: int = 1000,
        checkpoint_dir: str = "checkpoints/pretrain",
        loss_log_file: str = "checkpoints/pretrain/train_loss.log"
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model.to(self.device)

        # Standard AdamW configuration for LM pre-training.
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.95),
            eps=1e-8
        )

        # LR schedule is computed manually in _get_lr (warmup + cosine decay).
        self.warmup_steps = warmup_steps
        self.max_lr = learning_rate
        self.min_lr = learning_rate * 0.1
        self.current_step = 0

        # Mixed precision only when CUDA is available.
        self.use_amp = torch.cuda.is_available()
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)

        # Training hyper-parameters.
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.max_steps = max_steps
        self.log_interval = log_interval
        self.save_interval = save_interval

        # Checkpoint management.
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Plain-text loss log.
        self.loss_log_file = Path(loss_log_file)
        self.loss_log_file.parent.mkdir(parents=True, exist_ok=True)

        # Training state.
        self.global_step = 0
        self.tokens_seen = 0
        self.running_loss = 0.0
        self.best_loss = float('inf')

        logger.info(f"PreTrainer initialized:")
        logger.info(f"  Device: {self.device}")
        logger.info(f"  Learning Rate: {learning_rate}")
        logger.info(f"  Max Steps: {max_steps}")
        logger.info(f"  Gradient Accumulation: {gradient_accumulation_steps}")
        logger.info(f"  Effective Batch Size: {gradient_accumulation_steps}")
        logger.info(f"  Mixed Precision: {self.use_amp}")

    def _get_lr(self) -> float:
        """Learning rate at the current step: linear warmup then cosine decay.

        Fixes: uses math.cos/math.pi on plain floats (the old version called
        torch.cos(torch.tensor(...)), so the returned "float" was actually a
        0-d tensor, and hard-coded pi as 3.14159) and guards against
        zero-length warmup or decay windows.
        """
        import math  # local import: keeps this block self-contained

        if self.warmup_steps > 0 and self.current_step < self.warmup_steps:
            # Linear warmup from 0 to max_lr.
            return self.max_lr * (self.current_step / self.warmup_steps)

        # Cosine decay from max_lr to min_lr, clamped at progress == 1.
        decay_steps = max(self.max_steps - self.warmup_steps, 1)
        progress = min((self.current_step - self.warmup_steps) / decay_steps, 1.0)
        return self.min_lr + (self.max_lr - self.min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

    def _set_lr(self, lr: float):
        """Push an LR value into every optimizer parameter group."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    @staticmethod
    def _build_position_ids(attention_mask: torch.Tensor) -> torch.Tensor:
        """Padding-aware position ids, vectorized over the batch.

        Each real token gets its 0-based index among the real tokens of its
        row; pad positions get 0.  Replaces the original per-sample Python
        loop with a single cumsum (identical output).
        """
        mask = attention_mask.long()
        return (torch.cumsum(mask, dim=1) - 1) * mask

    def train_step(self, batch: dict) -> dict:
        """One pre-training micro-batch: forward + scaled backward.

        The loss is the mask-weighted mean next-token cross-entropy.  The
        value returned under 'loss' is the true, unscaled loss; only the
        backward pass uses the 1/gradient_accumulation_steps scaling.
        """
        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)
        position_ids = self._build_position_ids(attention_mask)

        input_data = {
            'segments': [{
                'type': 'text',
                'data': input_ids,
                'modality_id': 0
            }]
        }

        with torch.amp.autocast('cuda', enabled=self.use_amp):
            outputs = self.model(
                input_data,
                attention_mask=attention_mask,
                position_ids=position_ids)
            logits = outputs['logits']

        # Standard autoregressive shift.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()
        shift_attention_mask = attention_mask[:, 1:].contiguous()

        # Per-token loss, then mask-weighted mean over real tokens only.
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            reduction='none'
        )
        loss = (loss * shift_attention_mask.view(-1)).sum() / (shift_attention_mask.sum() + 1e-8)

        # Scale only for backward; keep the real loss for logging.
        loss_for_backward = loss / self.gradient_accumulation_steps
        self.scaler.scale(loss_for_backward).backward()

        self.tokens_seen += attention_mask.sum().item()

        return {
            'loss': loss.item(),  # true, unscaled loss
            'lr': self.optimizer.param_groups[0]['lr']
        }

    def optimizer_step(self) -> float:
        """Unscale, clip, apply the update, and advance the LR schedule.

        Returns:
            The pre-clipping gradient norm.
        """
        self.scaler.unscale_(self.optimizer)

        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.max_grad_norm
        )

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)

        # Advance the schedule after the update.
        self.current_step += 1
        self.global_step += 1
        lr = self._get_lr()
        self._set_lr(lr)

        return grad_norm.item()

    def _write_loss_to_txt(self, step, avg_loss, lr, tokens_seen):
        """Append one timestamped loss record to the plain-text log file."""
        log_content = (
            f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
            f"Step: {step}/{self.max_steps}, "
            f"Average Loss: {avg_loss:.4f}, "
            f"Learning Rate: {lr:.2e}, "
            f"Tokens Seen: {tokens_seen/1e9:.2f}B\n"
        )
        with open(self.loss_log_file, 'a', encoding='utf-8') as f:
            f.write(log_content)

    def train(self, dataloader, resume_from=None):
        """Main training loop with resume support.

        When resuming, the streaming dataloader is fast-forwarded past
        ``global_step * gradient_accumulation_steps`` already-consumed
        batches so the data state matches the optimizer state.
        """
        logger.info("\n" + "="*80)
        logger.info("Starting Pre-Training (Fixed Version)")
        logger.info("="*80 + "\n")

        if resume_from:
            self.load_checkpoint(resume_from)

        # Initialize the loss log on first run.
        if not self.loss_log_file.exists():
            with open(self.loss_log_file, 'w', encoding='utf-8') as f:
                f.write("🚀 Fixed Training Log (Real Loss Values)\n")
                f.write("="*80 + "\n")

        self.model.train()
        progress_bar = tqdm(total=self.max_steps, initial=self.global_step)

        step_in_accumulation = 0
        accumulated_loss = 0.0  # micro-batch losses within one optimizer step

        batches_to_skip = self.global_step * self.gradient_accumulation_steps

        logger.info(f"Current Global Step: {self.global_step}")
        if batches_to_skip > 0:
            logger.info(f"🔄 Resuming: Need to skip {batches_to_skip} batches to restore data state...")
            logger.info("This might take a while depending on network/disk speed...")

        data_iterator = iter(dataloader)

        # 1. Fast-forward the dataloader when resuming (data only, no compute).
        skipped = 0
        if batches_to_skip > 0:
            with tqdm(total=batches_to_skip, desc="Skipping trained batches", unit="batch") as skip_pbar:
                while skipped < batches_to_skip:
                    try:
                        _ = next(data_iterator)
                        skipped += 1
                        skip_pbar.update(1)
                    except StopIteration:
                        logger.error("Dataset exhausted during skipping! Check your dataset size or max_steps.")
                        return

            logger.info("✅ Data fast-forward complete. Resuming training...")

        # 2. Training loop: continue from the partially consumed iterator
        #    (a plain `for batch in dataloader` would restart it).
        try:
            while True:
                try:
                    batch = next(data_iterator)
                except StopIteration:
                    break  # data exhausted

                if batch is None or batch['input_ids'].size(0) == 0:
                    continue

                stats = self.train_step(batch)
                step_in_accumulation += 1
                accumulated_loss += stats['loss']

                # Full accumulation window complete: apply the update.
                if step_in_accumulation >= self.gradient_accumulation_steps:
                    avg_step_loss = accumulated_loss / self.gradient_accumulation_steps

                    grad_norm = self.optimizer_step()
                    stats['grad_norm'] = grad_norm
                    stats['loss'] = avg_step_loss

                    self.running_loss += avg_step_loss

                    step_in_accumulation = 0
                    accumulated_loss = 0.0

                    progress_bar.update(1)
                    progress_bar.set_postfix({
                        'loss': f"{stats['loss']:.4f}",
                        'lr': f"{stats['lr']:.2e}",
                        'tokens': f"{self.tokens_seen/1e9:.2f}B",
                        'grad': f"{grad_norm:.2f}"
                    })

                    if self.global_step % self.log_interval == 0:
                        avg_loss = self.running_loss / self.log_interval

                        logger.info(
                            f"Step {self.global_step}/{self.max_steps} | "
                            f"Loss: {avg_loss:.4f} | "
                            f"LR: {stats['lr']:.2e} | "
                            f"GradNorm: {grad_norm:.2f} | "
                            f"Tokens: {self.tokens_seen/1e9:.2f}B"
                        )

                        # Early-warning heuristics for a diverging run.
                        if avg_loss > 10.0 and self.global_step > 100:
                            logger.warning(f"⚠️ Loss异常高 ({avg_loss:.2f}),可能存在问题!")

                        if avg_loss < self.best_loss:
                            self.best_loss = avg_loss
                            logger.info(f"✨ New best loss: {self.best_loss:.4f}")

                        self._write_loss_to_txt(
                            step=self.global_step,
                            avg_loss=avg_loss,
                            lr=stats['lr'],
                            tokens_seen=self.tokens_seen
                        )
                        self.running_loss = 0.0

                    if self.global_step % self.save_interval == 0:
                        self.save_checkpoint(
                            self.checkpoint_dir / f"step_{self.global_step}.pt"
                        )

                if self.global_step >= self.max_steps:
                    break

        except KeyboardInterrupt:
            logger.info("\n⚠️ Training interrupted by user")
            self.save_checkpoint(
                self.checkpoint_dir / f"interrupted_step_{self.global_step}.pt"
            )

        finally:
            progress_bar.close()

        logger.info("\n" + "="*80)
        logger.info("Pre-Training Complete!")
        logger.info(f"  Total Steps: {self.global_step}")
        logger.info(f"  Total Tokens: {self.tokens_seen/1e9:.2f}B")
        logger.info(f"  Best Loss: {self.best_loss:.4f}")
        logger.info("="*80 + "\n")

        self.save_checkpoint(self.checkpoint_dir / "final_model.pt")

    def save_checkpoint(self, path: Path):
        """Serialize model/optimizer/scaler state plus training progress."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict() if self.use_amp else None,
            'global_step': self.global_step,
            'current_step': self.current_step,
            'tokens_seen': self.tokens_seen,
            'best_loss': self.best_loss,
            'timestamp': datetime.now().isoformat()
        }

        torch.save(checkpoint, path)
        logger.info(f"💾 Checkpoint saved to {path}")

    def load_checkpoint(self, path: str):
        """Restore model/optimizer/scaler state and training progress."""
        checkpoint = torch.load(path, map_location=self.device, weights_only=True)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if self.use_amp and checkpoint.get('scaler_state_dict'):
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        self.global_step = checkpoint['global_step']
        self.current_step = checkpoint.get('current_step', self.global_step)
        self.tokens_seen = checkpoint['tokens_seen']
        self.best_loss = checkpoint.get('best_loss', float('inf'))

        logger.info(f"📂 Checkpoint loaded from {path}")
        logger.info(f"   Resuming from step {self.global_step}")
        logger.info(f"   Tokens seen: {self.tokens_seen/1e9:.2f}B")
395
+
396
+
397
def main():
    """Entry point: build tokenizer, model, and dataloader, then pre-train."""
    # Tuned configuration.
    config = {
        # Model
        'model_dim': 1536,
        'vocab_size': 151665,  # overwritten below from the tokenizer
        'n_layers': 12,
        'n_heads': 12,
        'n_kv_heads': 4,
        'max_seq_len': 512,  # kept small for throughput
        'dropout': 0.1,
        'use_moe': False,

        # Training
        'batch_size': 4,
        'gradient_accumulation_steps': 8,
        'learning_rate': 3e-4,
        'weight_decay': 0.1,
        'warmup_steps': 500,
        'max_steps': 10000,
        'max_grad_norm': 1.0,

        # Data
        'data_mix': 'text_only',
        'max_length': 512,  # keep equal to max_seq_len
        'num_workers': 2,   # kept low to avoid network issues

        # Logging / checkpoints
        'log_interval': 10,
        'save_interval': 500,
        'checkpoint_dir': 'checkpoints/pretrain_fixed',
        'loss_log_file': 'checkpoints/pretrain_fixed/train_loss.log',

        # Fix: resume checkpoint is now part of the config instead of being
        # hard-coded at the trainer.train() call site; set to None to start
        # a fresh run.
        'resume_from': "/root/step_6500.pt",
    }

    logger.info("="*80)
    logger.info("🔧 Fixed Configuration:")
    logger.info(json.dumps(config, indent=2))
    logger.info("="*80 + "\n")

    # Tokenizer.
    logger.info("Initializing tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct",
        use_fast=True,
        trust_remote_code=True
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    config['vocab_size'] = len(tokenizer)
    logger.info(f"Vocab size: {config['vocab_size']}\n")

    # Model.
    logger.info("Initializing model...")
    model = MultiModalDenseTransformer(
        model_dim=config['model_dim'],
        vocab_size=config['vocab_size'],
        n_layers=config['n_layers'],
        n_heads=config['n_heads'],
        n_kv_heads=config['n_kv_heads'],
        max_seq_len=config['max_seq_len'],
        dropout=config['dropout'],
        use_moe=config['use_moe'],
        use_gradient_checkpointing=True,
        rope_scaling_type="yarn",
        use_multimodal_fusion=False,
        use_contrastive=False
    )

    # Data loader.
    logger.info(f"\nCreating dataloader (mix: {config['data_mix']})...")
    dataloader = create_pretrain_dataloader(
        mix_name=config['data_mix'],
        tokenizer=tokenizer,
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        max_length=config['max_length']
    )

    # Trainer.
    trainer = PreTrainer(
        model=model,
        tokenizer=tokenizer,
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        warmup_steps=config['warmup_steps'],
        max_steps=config['max_steps'],
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        max_grad_norm=config['max_grad_norm'],
        log_interval=config['log_interval'],
        save_interval=config['save_interval'],
        checkpoint_dir=config['checkpoint_dir'],
        loss_log_file=config['loss_log_file']
    )

    # Train (resumes from config['resume_from'] when set).
    logger.info("\n🚀 Starting fresh training with fixes...\n")
    trainer.train(dataloader, resume_from=config['resume_from'])
reward_model.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 奖励模型 - 用于RLHF
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.optim as optim
8
+ from torch.utils.data import DataLoader
9
+ from collections import defaultdict
10
+ from typing import Dict, Tuple, Union, Optional
11
+ from tqdm import tqdm
12
+ from model import MultiModalDenseTransformer
13
+
14
class RewardModel(nn.Module):
    """Reward model for RLHF.

    Wraps a base multimodal transformer and maps its last hidden states to
    scalar scores via a small MLP head. An optional second ("value") head of
    identical shape can serve as a critic in PPO-style training.
    """

    def __init__(
        self,
        base_model: MultiModalDenseTransformer,
        use_value_head: bool = True
    ):
        super().__init__()
        self.base_model = base_model
        self.use_value_head = use_value_head

        def _make_scalar_head() -> nn.Sequential:
            # Two-layer MLP: hidden_dim -> hidden_dim // 2 -> 1.
            return nn.Sequential(
                nn.Linear(base_model.model_dim, base_model.model_dim // 2),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(base_model.model_dim // 2, 1),
            )

        # NOTE: heads are created in the same order as before (reward first,
        # then value) so parameter initialization is unchanged.
        self.reward_head = _make_scalar_head()
        if use_value_head:
            self.value_head = _make_scalar_head()

    def forward(
        self,
        input_data: Dict,
        return_values: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Score ``input_data``.

        Returns per-token rewards (shape ``(B, T)`` after the trailing
        squeeze). When ``return_values`` is True and the value head exists,
        additionally returns per-token value estimates of the same shape.
        """
        base_out = self.base_model(input_data, return_hidden=True)
        hidden = base_out['last_hidden_state']

        token_rewards = self.reward_head(hidden).squeeze(-1)

        if return_values and self.use_value_head:
            token_values = self.value_head(hidden).squeeze(-1)
            return token_rewards, token_values

        return token_rewards
56
+
57
class RewardModelTrainer:
    """Trainer for a pairwise (chosen vs. rejected) reward model.

    Freezes the base transformer except for its last two layers, then
    optimizes those layers together with the reward/value heads using a
    Bradley-Terry style ranking loss: ``-log sigmoid(r_chosen - r_rejected - margin)``.
    """

    def __init__(
        self,
        reward_model: RewardModel,
        learning_rate: float = 1e-5,
        margin: float = 0.0
    ):
        self.reward_model = reward_model
        self.margin = margin  # minimum desired score gap between chosen and rejected

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.reward_model.to(self.device)

        # Freeze the entire base model...
        for param in self.reward_model.base_model.parameters():
            param.requires_grad = False

        # ...then re-enable only its last two transformer layers.
        for layer in self.reward_model.base_model.layers[-2:]:
            for param in layer.parameters():
                param.requires_grad = True

        # Fix: a `trainable_params` list of the head parameters was built here
        # but never used — the optimizer below already selects every parameter
        # with requires_grad=True (the heads plus the unfrozen layers). The
        # dead code has been removed.
        self.optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, self.reward_model.parameters()),
            lr=learning_rate
        )

    def train_step(self, chosen_batch: Dict, rejected_batch: Dict) -> Dict:
        """Run one optimization step on a (chosen, rejected) batch pair.

        Returns a dict with the scalar 'loss' and pairwise 'accuracy'
        (fraction of pairs where the chosen sequence outscored the rejected).
        """
        self.reward_model.train()
        self.optimizer.zero_grad()

        # Sequence score = reward at the final position.
        # NOTE(review): with right-padding the last position may be a pad
        # token — confirm the dataloader left-pads or that pads are masked.
        chosen_rewards = self.reward_model(chosen_batch)[:, -1]
        rejected_rewards = self.reward_model(rejected_batch)[:, -1]

        # Pairwise ranking loss with optional margin.
        loss = -F.logsigmoid(chosen_rewards - rejected_rewards - self.margin).mean()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.reward_model.parameters(), 1.0)
        self.optimizer.step()

        accuracy = (chosen_rewards > rejected_rewards).float().mean().item()

        return {
            'loss': loss.item(),
            'accuracy': accuracy
        }

    def train(
        self,
        dataloader: DataLoader,
        num_epochs: int = 1,
        log_interval: int = 10
    ):
        """Training loop over pairwise (chosen_ids, rejected_ids) batches.

        Progress-bar stats are averaged over each `log_interval` window and
        reset afterwards.
        """
        print(f"Starting reward model training on {self.device}...")

        for epoch in range(num_epochs):
            total_stats = defaultdict(float)
            num_steps = 0
            progress_bar = tqdm(
                dataloader,
                desc=f"Reward Model Epoch {epoch+1}/{num_epochs}"
            )

            for batch_idx, (chosen_ids, rejected_ids) in enumerate(progress_bar):
                # Wrap raw token ids into the segment format the base model expects.
                chosen_batch = {
                    'segments': [{'type': 'text', 'data': chosen_ids.to(self.device), 'modality_id': 0}]
                }
                rejected_batch = {
                    'segments': [{'type': 'text', 'data': rejected_ids.to(self.device), 'modality_id': 0}]
                }

                stats = self.train_step(chosen_batch, rejected_batch)

                for k, v in stats.items():
                    total_stats[k] += v
                num_steps += 1

                if (batch_idx + 1) % log_interval == 0:
                    avg_stats = {
                        k: v / num_steps
                        for k, v in total_stats.items()
                    }
                    progress_bar.set_postfix(avg_stats)
                    total_stats = defaultdict(float)

        print("Reward model training complete!")

    def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
        """Evaluate the reward model; returns mean loss and pairwise accuracy."""
        self.reward_model.eval()
        total_stats = defaultdict(float)
        num_batches = 0

        with torch.no_grad():
            for chosen_ids, rejected_ids in dataloader:
                chosen_batch = {
                    'segments': [{'type': 'text', 'data': chosen_ids.to(self.device), 'modality_id': 0}]
                }
                rejected_batch = {
                    'segments': [{'type': 'text', 'data': rejected_ids.to(self.device), 'modality_id': 0}]
                }

                # Same last-position scoring and loss as train_step, without grads.
                chosen_rewards = self.reward_model(chosen_batch)[:, -1]
                rejected_rewards = self.reward_model(rejected_batch)[:, -1]

                loss = -F.logsigmoid(chosen_rewards - rejected_rewards - self.margin).mean()
                accuracy = (chosen_rewards > rejected_rewards).float().mean().item()

                total_stats['loss'] += loss.item()
                total_stats['accuracy'] += accuracy
                num_batches += 1

        return {k: v / num_batches for k, v in total_stats.items()}

    def save_checkpoint(self, path: str):
        """Save model and optimizer state to `path`."""
        torch.save({
            'model_state_dict': self.reward_model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)

    def load_checkpoint(self, path: str):
        """Restore model and optimizer state from `path` onto self.device."""
        checkpoint = torch.load(path, map_location=self.device)
        self.reward_model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
transformer.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 优化的Transformer架构
3
+ 支持GQA/MQA、滑动窗口注意力、Flash Attention 2、YARN位置编码
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from typing import Optional, Tuple, List
9
+ import math
10
+ from components import RMSNorm, SwiGLU, YARNRotaryEmbedding, QKNorm
11
+ from peft_ import LinearWithLoRA, AdapterLayer
12
+ from moe import MixtureOfExperts
13
+
14
class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention (GQA) with YARN RoPE.

    Supports MQA/GQA via ``n_kv_heads``, optional QK-norm, ALiBi position
    bias (mutually exclusive with RoPE), sliding-window masking,
    LoRA-wrapped projections, KV caching and an SDPA (Flash Attention)
    fast path.
    """
    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        use_flash: bool = True,
        qkv_bias: bool = False,
        use_lora: bool = False,
        lora_rank: int = 8,
        max_seq_len: int = 8192,
        rope_scaling_factor: float = 1.0,
        rope_scaling_type: str = "yarn",
        use_qk_norm: bool = False,
        sliding_window: Optional[int] = None,
        use_alibi: bool = False
    ):
        super().__init__()

        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads

        assert n_heads % self.n_kv_heads == 0, \
            f"n_heads ({n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"

        # Each KV head is shared by n_rep query heads.
        self.n_rep = n_heads // self.n_kv_heads
        self.head_dim = head_dim if head_dim is not None else dim // n_heads
        self.scale = self.head_dim ** -0.5

        self.use_flash = use_flash and hasattr(F, 'scaled_dot_product_attention')
        self.sliding_window = sliding_window

        self.q_proj = LinearWithLoRA(
            dim, n_heads * self.head_dim,
            bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
        )
        self.k_proj = LinearWithLoRA(
            dim, self.n_kv_heads * self.head_dim,
            bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
        )
        self.v_proj = LinearWithLoRA(
            dim, self.n_kv_heads * self.head_dim,
            bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
        )
        self.o_proj = LinearWithLoRA(
            n_heads * self.head_dim, dim,
            bias=False, use_lora=use_lora, lora_rank=lora_rank
        )

        self.attn_dropout = nn.Dropout(attn_dropout) if attn_dropout > 0 else nn.Identity()
        self.resid_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        self.use_qk_norm = use_qk_norm
        if use_qk_norm:
            self.q_norm = QKNorm(self.head_dim)
            self.k_norm = QKNorm(self.head_dim)

        self.use_alibi = use_alibi
        if use_alibi:
            # ALiBi replaces RoPE entirely when enabled.
            self.register_buffer(
                "alibi_slopes",
                self._get_alibi_slopes(n_heads),
                persistent=False
            )
        else:
            self.rotary_emb = YARNRotaryEmbedding(
                self.head_dim,
                max_seq_len=max_seq_len,
                original_max_len=4096,
                scaling_factor=rope_scaling_factor,
                rope_percentage=1.0
            )

    def _get_alibi_slopes(self, n_heads: int) -> torch.Tensor:
        """Compute per-head ALiBi slopes (geometric series; handles non-power-of-2 head counts)."""
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio ** i for i in range(n)]

        if math.log2(n_heads).is_integer():
            slopes = get_slopes_power_of_2(n_heads)
        else:
            # Interpolate extra slopes for head counts that are not powers of two.
            closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
            slopes = get_slopes_power_of_2(closest_power_of_2)
            extra_slopes = get_slopes_power_of_2(2 * closest_power_of_2)[::2]
            slopes.extend(extra_slopes[:n_heads - closest_power_of_2])

        return torch.tensor(slopes).view(n_heads, 1, 1)

    def repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat KV heads n_rep times so K/V match the number of Q heads."""
        if self.n_rep == 1:
            return x

        B, n_kv_heads, seq_len, head_dim = x.shape
        return x[:, :, None, :, :].expand(
            B, n_kv_heads, self.n_rep, seq_len, head_dim
        ).reshape(B, n_kv_heads * self.n_rep, seq_len, head_dim)

    def _apply_sliding_window_mask(
        self,
        attn_scores: torch.Tensor,
        seq_len: int
    ) -> torch.Tensor:
        """Mask attention scores outside a band of width `sliding_window` (causal side)."""
        if self.sliding_window is None or seq_len <= self.sliding_window:
            return attn_scores

        # Band of True entries: positions within the window, at or before the query.
        mask = torch.ones(seq_len, seq_len, device=attn_scores.device, dtype=torch.bool)
        mask = torch.triu(mask, diagonal=-self.sliding_window + 1)
        mask = torch.tril(mask, diagonal=0)

        attn_scores = attn_scores.masked_fill(~mask, float('-inf'))
        return attn_scores

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[torch.Tensor]]:
        """Attention forward pass.

        Returns (output, present_kv, attention_weights); the last two are
        None unless `use_cache` / `output_attentions` are set.
        """
        B, T, C = x.shape

        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        if self.use_qk_norm:
            # Normalize per head_dim; flatten head dims, normalize, restore shape.
            q_shape = q.shape
            k_shape = k.shape
            q = self.q_norm.query_norm(q.view(-1, self.head_dim)).view(q_shape)
            k = self.k_norm.key_norm(k.view(-1, self.head_dim)).view(k_shape)

        if not self.use_alibi:
            q, k = self.rotary_emb(q, k, position_ids)

        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        present_kv = (k, v) if use_cache else None

        k = self.repeat_kv(k)
        v = self.repeat_kv(v)

        seq_len_k = k.size(2)

        # SDPA fast path.
        # Fix 1: previously this path was taken even with use_alibi or
        # sliding_window enabled, silently skipping both; it is now
        # restricted to the plain causal case.
        use_sdpa = (
            self.use_flash
            and not output_attentions
            and attention_mask is None
            and not self.use_alibi
            and self.sliding_window is None
        )

        if use_sdpa:
            dropout_p = self.attn_dropout.p if isinstance(self.attn_dropout, nn.Dropout) and self.training else 0.0
            # Fix 2: with a KV cache (T < seq_len_k), is_causal=True misaligns
            # SDPA's square causal mask; during single-token decode every
            # cached key must stay visible, so causal only applies to prefill.
            attn_output = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=dropout_p,
                is_causal=(T == seq_len_k and T > 1)
            )
            attention_weights = None
        else:
            attn_scores = (q @ k.transpose(-2, -1)) * self.scale

            if self.use_alibi:
                # Additive linear position bias per head.
                position_bias = self.alibi_slopes.to(x.device) * torch.arange(
                    seq_len_k, device=x.device
                ).view(1, 1, -1)
                attn_scores = attn_scores + position_bias

            if self.sliding_window is not None:
                attn_scores = self._apply_sliding_window_mask(attn_scores, seq_len_k)

            if attention_mask is not None:
                if attention_mask.dim() == 2:
                    attention_mask = attention_mask[:, None, None, :]
                if attention_mask.dtype != torch.float:
                    # Assume 1 = keep, 0 = mask; convert to additive mask.
                    extended_mask = (1.0 - attention_mask) * torch.finfo(attn_scores.dtype).min
                else:
                    # Assume an additive mask (0 / -inf) was passed in.
                    extended_mask = attention_mask

                attn_scores = attn_scores + extended_mask

            is_causal = seq_len_k > 1
            if is_causal:
                causal_mask = torch.triu(
                    torch.ones(seq_len_k, seq_len_k, device=x.device, dtype=torch.bool),
                    diagonal=1
                )
                # Keep only the rows that correspond to the current queries:
                # with a KV cache the queries occupy the LAST T positions.
                causal_mask = causal_mask[-q.shape[2]:, :]
                attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))

            # Softmax in fp32 for numerical stability, then cast back.
            attention_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
            attention_weights = self.attn_dropout(attention_weights)

            attn_output = attention_weights @ v

        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, -1)
        output = self.resid_dropout(self.o_proj(attn_output))

        return output, present_kv, attention_weights if output_attentions else None
223
+
224
class OptimizedTransformerBlock(nn.Module):
    """Pre-norm transformer block with GQA attention and SwiGLU/MoE FFN.

    Supports an optional parallel-residual layout (attention and FFN both
    read the same input, GPT-NeoX style) as well as the standard sequential
    layout, plus optional adapters for parameter-efficient fine-tuning.
    """
    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        use_moe: bool = False,
        num_experts: int = 8,
        moe_top_k: int = 2,
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_lora: bool = False,
        lora_rank: int = 8,
        use_parallel_residual: bool = False,
        norm_eps: float = 1e-6,
        sliding_window: Optional[int] = None,
        ffn_dim_multiplier: Optional[float] = None,
        layer_idx: int = 0
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.use_moe = use_moe
        self.use_adapter = use_adapter
        self.use_parallel_residual = use_parallel_residual

        self.attention = GroupedQueryAttention(
            dim=dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            head_dim=head_dim,
            dropout=dropout,
            attn_dropout=attn_dropout,
            use_lora=use_lora,
            lora_rank=lora_rank,
            sliding_window=sliding_window,
            rope_scaling_type="yarn"
        )

        if use_moe:
            self.ffn = MixtureOfExperts(
                dim=dim,
                num_experts=num_experts,
                top_k=moe_top_k,
                dropout=dropout,
                ffn_dim_multiplier=ffn_dim_multiplier
            )
        else:
            self.ffn = SwiGLU(
                dim=dim,
                dropout=dropout,
                ffn_dim_multiplier=ffn_dim_multiplier
            )

        if use_adapter:
            self.adapter = AdapterLayer(dim, adapter_dim, dropout)

        self.attention_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)

        # Auxiliary MoE load-balancing loss from the most recent forward pass;
        # zero when MoE is disabled.
        self.moe_aux_loss = torch.tensor(0.0)

    def _run_ffn(self, ffn_input: torch.Tensor) -> torch.Tensor:
        """Run the FFN (plain or MoE), recording the MoE auxiliary loss."""
        if self.use_moe:
            ffn_out, aux_loss = self.ffn(ffn_input)
            self.moe_aux_loss = aux_loss
        else:
            ffn_out = self.ffn(ffn_input)
            self.moe_aux_loss = torch.tensor(0.0, device=ffn_input.device)
        return ffn_out

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[torch.Tensor]]:
        """Block forward pass; returns (hidden, present_kv, attn_weights)."""

        attn_out, present_kv, attn_weights = self.attention(
            self.attention_norm(x),
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            past_kv=past_kv,
            output_attentions=output_attentions
        )

        if self.use_parallel_residual:
            # Parallel residual: attention and FFN both read the SAME input x,
            # and their outputs are summed into a single residual update.
            # Fix: return here so the sequential FFN section below cannot run
            # a second FFN pass on top of the parallel one.
            ffn_out = self._run_ffn(self.ffn_norm(x))
            x = x + attn_out + ffn_out
            return x, present_kv, attn_weights

        # Sequential (standard pre-norm) layout.
        x = x + attn_out

        if self.use_adapter:
            x = self.adapter(x)

        ffn_out = self._run_ffn(self.ffn_norm(x))
        x = x + ffn_out

        return x, present_kv, attn_weights