Upload 20 files
Browse files- components.py +387 -0
- continual_learning.py +294 -0
- contrastive_learning.py +339 -0
- data_augmentation.py +366 -0
- data_config.py +292 -0
- data_loader.py +832 -0
- encoders.py +559 -0
- gradio1.py +228 -0
- grpo.py +630 -0
- infer.py +372 -0
- infer_sft.py +407 -0
- model.py +505 -0
- moe.py +460 -0
- multimodel_fusion.py +522 -0
- peft_.py +213 -0
- post.py +532 -0
- posttrain.py +554 -0
- pretrain.py +502 -0
- reward_model.py +189 -0
- transformer.py +335 -0
components.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from typing import Tuple, Optional, Union
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
class YARNScaling:
    """YaRN (Yet Another RoPE extensioN) frequency scaling.

    Reference: https://arxiv.org/abs/2309.00071
    """

    @staticmethod
    def compute_yarn_parameters(
        original_max_len: int,
        target_max_len: int = 8192,
        dim: int = 128,
        base: int = 10000,
        beta_fast: int = 32,
        beta_slow: int = 1,
        alpha: float = 1.0,
        device: Optional[torch.device] = None
    ) -> Tuple[torch.Tensor, float]:
        """Compute YaRN-adjusted inverse frequencies and the attention scale.

        Args:
            original_max_len: context length the base model was trained on.
            target_max_len: extended context length.
            dim: rotary dimension (produces dim // 2 frequency bands).
            base: RoPE theta base.
            beta_fast / beta_slow: band boundaries from the YaRN paper.
            alpha: reserved; currently not used by the mscale formula.
            device: optional device for the returned tensor.

        Returns:
            (inv_freq, mscale): per-band inverse frequencies of shape
            [dim // 2] and the scalar attention temperature correction.
        """
        scale = float(target_max_len) / original_max_len
        mscale = YARNScaling.compute_mscale(scale, alpha)

        # Exponents for the paired RoPE dims (0, 2, ..., dim - 2).
        band_idx = torch.arange(0, dim, 2, dtype=torch.float32, device=device)

        # Original (extrapolated) RoPE frequencies.
        freq_extra = 1.0 / (base ** (band_idx / dim))

        # No extension requested: plain RoPE, no temperature correction.
        if scale <= 1.0:
            return freq_extra, 1.0

        # Position-interpolated frequencies for the extended context.
        freq_inter = 1.0 / (scale * base ** (band_idx / dim))

        def band_limit(beta: float) -> float:
            # Maps a wavelength constraint (the paper's band boundary)
            # onto a frequency-band index.
            return dim * math.log(original_max_len / (2 * math.pi * beta)) / (2 * math.log(base))

        low = max(math.floor(band_limit(beta_fast)), 0)
        high = min(math.ceil(band_limit(beta_slow)), dim // 2 - 1)

        # Band indices 0, 1, ..., dim // 2 - 1.
        band = torch.arange(0, dim // 2, dtype=torch.float32, device=device)

        inv_freq = freq_extra.clone()

        # 1. Low-frequency bands (index > high): wavelength exceeds the
        #    original context, so fully interpolate.
        inv_freq = torch.where(band > high, freq_inter, inv_freq)

        # 2. High-frequency bands (index < low) keep freq_extra untouched;
        #    they are protected by rotational invariance.

        # 3. Middle bands: linear ramp between the two regimes.
        ramp_mask = (band >= low) & (band <= high)
        if ramp_mask.any():
            span = max(high - low, 1)  # guard against division by zero
            t = (band[ramp_mask] - low) / span
            inv_freq[ramp_mask] = torch.lerp(freq_extra[ramp_mask], freq_inter[ramp_mask], t)

        return inv_freq, float(mscale)

    @staticmethod
    def compute_mscale(scale: float, alpha: float = 1.0) -> float:
        """Attention temperature correction: 0.1 * ln(scale) + 1 for scale > 1."""
        # The 0.1 * ln(scale) + 1.0 form is the empirical entropy fix
        # from the YaRN paper; ``alpha`` is accepted for interface
        # compatibility but not used here.
        return 1.0 if scale <= 1.0 else 0.1 * math.log(scale) + 1.0
|
| 79 |
+
|
| 80 |
+
class YARNRotaryEmbedding(nn.Module):
    """Rotary position embedding with YaRN long-context scaling.

    Computes rotations in float32 for precision, lazily (re)builds the
    cos/sin cache, and supports position_ids beyond the cached length.
    """
    def __init__(
        self,
        dim: int = 64,
        max_seq_len: int = 8192,
        original_max_len: int = 4096,
        base: int = 10000,
        scaling_factor: float = 1.0,  # reserved interface; scaling is governed by the YaRN logic
        beta_fast: int = 32,
        beta_slow: int = 1,
        alpha: float = 1.0,
        rope_percentage: float = 1.0,
        device: Optional[torch.device] = None
    ):
        """
        Args:
            dim: per-head dimension.
            max_seq_len: target (extended) context length.
            original_max_len: context length used during pretraining.
            base: RoPE theta base.
            scaling_factor: reserved interface, currently unused.
            beta_fast / beta_slow: YaRN band boundaries.
            alpha: forwarded to the YaRN parameter computation.
            rope_percentage: fraction of ``dim`` that receives RoPE.
            device: optional device for the frequency buffers.
        """
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.original_max_len = original_max_len
        self.base = base
        self.alpha = alpha
        # Bug fix: beta_fast / beta_slow used to be accepted but silently
        # ignored (_init_yarn_frequencies hard-coded 32 / 1). Store them
        # so the constructor arguments actually take effect.
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow

        # Number of head dims that actually get rotated; must be even
        # because RoPE operates on dimension pairs.
        self.rope_dim = int(dim * rope_percentage)
        if self.rope_dim % 2 != 0:
            self.rope_dim -= 1

        # Persistent state: YaRN inverse frequencies and attention scale.
        self._init_yarn_frequencies(device)

        # Transient cos/sin cache; persistent=False keeps it out of
        # state_dict and therefore out of checkpoints.
        self.register_buffer("cos_cached", None, persistent=False)
        self.register_buffer("sin_cached", None, persistent=False)

    def _init_yarn_frequencies(self, device: Optional[torch.device] = None):
        """Compute and register the YaRN frequency / mscale buffers."""
        inv_freq, mscale = YARNScaling.compute_yarn_parameters(
            self.original_max_len,
            self.max_seq_len,
            self.rope_dim,
            self.base,
            beta_fast=self.beta_fast,  # was hard-coded to 32
            beta_slow=self.beta_slow,  # was hard-coded to 1
            alpha=self.alpha,
            device=device
        )
        self.register_buffer("inv_freq", inv_freq, persistent=True)
        self.register_buffer("mscale", torch.tensor(mscale, dtype=torch.float32, device=device), persistent=True)

    def _compute_cos_sin_cache(
        self,
        needed_len: int,
        device: torch.device,
        dtype: torch.dtype
    ):
        """(Re)build the cos/sin cache in float32, storing it at ``dtype``."""
        # Allocate at least max_seq_len; grow further if needed_len exceeds it.
        alloc_len = max(needed_len, self.max_seq_len)

        # Reuse the existing cache only when it is large enough and matches
        # both device AND dtype (dtype check added: a cache built for fp16
        # callers must not be served unchanged to fp32 callers).
        if (self.cos_cached is not None and
            self.cos_cached.shape[2] >= alloc_len and
            self.cos_cached.device == device and
            self.cos_cached.dtype == dtype):
            return

        t = torch.arange(alloc_len, dtype=torch.float32, device=device)

        # Outer product t[i] * inv_freq[j] -> [alloc_len, rope_dim // 2].
        freqs = torch.outer(t, self.inv_freq.to(device))

        # Duplicate to match rotate_half's layout:
        # [theta_0, theta_1, ..., theta_0, theta_1, ...]
        emb = torch.cat((freqs, freqs), dim=-1)

        # Move mscale next to emb so the multiply never crosses devices.
        mscale = self.mscale.to(device=device, dtype=torch.float32)

        # [alloc_len, rope_dim] -> [1, 1, alloc_len, rope_dim] for broadcasting.
        cos_cached = (emb.cos() * mscale).view(1, 1, alloc_len, self.rope_dim)
        sin_cached = (emb.sin() * mscale).view(1, 1, alloc_len, self.rope_dim)

        # The cache may be stored at reduced precision to save memory; the
        # rotation itself is still performed in float32 (see apply below).
        self.cos_cached = cos_cached.to(dtype)
        self.sin_cached = sin_cached.to(dtype)

    @staticmethod
    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        """Rotate the last dimension: split [x1, x2] and return [-x2, x1]."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE to q/k with float32 math and cache bounds checking.

        Args:
            q, k: [batch, heads, seq_len, head_dim].
            position_ids: optional [batch, seq_len] absolute positions;
                defaults to 0..seq_len-1 when omitted.
        """
        bsz, num_heads, seq_len, head_dim = q.shape

        # 1. Determine how much cache is required.
        if position_ids is not None:
            # Must cover the largest index referenced by position_ids.
            max_pos = position_ids.max().item() + 1
            needed_len = max(max_pos, seq_len)
        else:
            needed_len = seq_len

        # 2. Refresh the cache if missing, too short, or on the wrong
        #    device/dtype (dtype check added to match the cache builder).
        if (self.cos_cached is None or
            self.cos_cached.shape[2] < needed_len or
            self.cos_cached.device != q.device or
            self.cos_cached.dtype != q.dtype):
            self._compute_cos_sin_cache(needed_len, q.device, q.dtype)

        # 3. Slice / gather cos and sin. cos_cached: [1, 1, alloc_len, dim].
        if position_ids is not None:
            # cos_cached[0, 0] is [alloc_len, rope_dim]; fancy-index by
            # position, then add a head axis: [bs, 1, seq_len, rope_dim].
            cos = self.cos_cached[0, 0][position_ids].unsqueeze(1)
            sin = self.sin_cached[0, 0][position_ids].unsqueeze(1)
        else:
            # Positions assumed to start at 0.
            cos = self.cos_cached[:, :, :seq_len, :]
            sin = self.sin_cached[:, :, :seq_len, :]

        # 4. Partial RoPE: only the first rope_dim channels are rotated.
        if self.rope_dim < head_dim:
            q_rot, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:]
            k_rot, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:]
        else:
            q_rot, k_rot = q, k
            q_pass = None
            k_pass = None

        # 5. Rotate in float32 to avoid fp16 precision loss.
        q_rot_f = q_rot.float()
        k_rot_f = k_rot.float()
        cos_f = cos.float()
        sin_f = sin.float()

        q_embed = (q_rot_f * cos_f) + (self.rotate_half(q_rot_f) * sin_f)
        k_embed = (k_rot_f * cos_f) + (self.rotate_half(k_rot_f) * sin_f)

        # 6. Cast back to the caller's dtype.
        q_embed = q_embed.type_as(q)
        k_embed = k_embed.type_as(k)

        if q_pass is not None:
            q_embed = torch.cat([q_embed, q_pass], dim=-1)
            k_embed = torch.cat([k_embed, k_pass], dim=-1)

        return q_embed, k_embed

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Alias for :meth:`apply_rotary_pos_emb`."""
        return self.apply_rotary_pos_emb(q, k, position_ids)

    def extra_repr(self) -> str:
        return (f"dim={self.dim}, rope_dim={self.rope_dim}, "
                f"max_seq_len={self.max_seq_len}, original_max_len={self.original_max_len}, "
                f"base={self.base}")
|
| 256 |
+
|
| 257 |
+
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization.

    Statistics are always computed in float32 for numerical stability
    (avoids fp16 under/overflow) and cast back to the input dtype before
    the optional learned gain is applied.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        elementwise_affine: bool = True
    ):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        # Learned per-channel gain, or no parameter at all.
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter('weight', None)

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # x / sqrt(mean(x^2) + eps), written with rsqrt.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the statistics in float32, then return to the input dtype.
        normed = self._norm(x.float()).type_as(x)
        # Apply the gain when configured (weight is None otherwise).
        if self.weight is not None:
            normed = normed * self.weight
        return normed
|
| 293 |
+
|
| 294 |
+
class QKNorm(nn.Module):
    """Independent RMS normalization of attention queries and keys.

    Stabilizes attention logits (cf. ViT-22B / scaling-transformer work);
    each projection gets its own learned gain.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.query_norm = RMSNorm(dim, eps=eps)
        self.key_norm = RMSNorm(dim, eps=eps)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Normalize q and k separately and return them as a pair."""
        return self.query_norm(q), self.key_norm(k)
|
| 312 |
+
|
| 313 |
+
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: down(SiLU(gate(x)) * up(x)).

    Uses LLaMA weight naming: w1 = gate, w3 = up, w2 = down projection.
    The hidden width is always rounded up to a multiple of ``multiple_of``.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: Optional[int] = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        dropout: float = 0.0,
        bias: bool = False
    ):
        super().__init__()

        if hidden_dim is None:
            # LLaMA default 2/3 * 4 * dim = 8/3 * dim, unless an explicit
            # multiplier is supplied.
            if ffn_dim_multiplier is not None:
                hidden_dim = int(dim * ffn_dim_multiplier)
            else:
                hidden_dim = int(2 * dim * 4 / 3)

        # Round up so the GEMMs hit hardware-friendly tile sizes.
        hidden_dim = -(-hidden_dim // multiple_of) * multiple_of
        self.hidden_dim = hidden_dim

        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)  # gate projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)  # down projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=bias)  # up projection
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU(x) = W2 · (SiLU(W1·x) ⊙ W3·x)
        gated = F.silu(self.w1(x)) * self.w3(x)
        return self.dropout(self.w2(gated))
|
| 350 |
+
|
| 351 |
+
class ParallelAttentionFFN(nn.Module):
    """Parallel attention + FFN residual block (PaLM / GPT-J style).

    Computes ``y = x + Attn(norm_a(x)) + FFN(norm_f(x))``. Two independent
    RMSNorms are kept (some architectures such as PaLM share one norm; the
    separate pair allows each branch to learn its own scale).
    """

    def __init__(
        self,
        dim: int,
        attn_module: nn.Module,
        ffn_module: nn.Module,
        norm_eps: float = 1e-6
    ):
        super().__init__()
        self.attn_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)
        self.attn = attn_module
        self.ffn = ffn_module

    def forward(self, x: torch.Tensor, **attn_kwargs) -> torch.Tensor:
        """Run both branches from the same residual input.

        ``attn_kwargs`` are forwarded to the attention module only; the
        FFN never sees attention-specific arguments.
        """
        branch_attn = self.attn(self.attn_norm(x), **attn_kwargs)
        branch_ffn = self.ffn(self.ffn_norm(x))
        # Single residual connection covering both branches.
        return x + branch_attn + branch_ffn
|
continual_learning.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
持续学习模块
|
| 3 |
+
支持EWC和经验回放
|
| 4 |
+
修复版本:适配 MultiModalDenseTransformer 和 data_loader.py
|
| 5 |
+
"""
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import numpy as np
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
from collections import deque
|
| 12 |
+
from typing import List, Dict, Any, Optional, Union
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
|
| 16 |
+
# 假设 model.py 中已有定义,用于类型提示
|
| 17 |
+
# from model import MultiModalDenseTransformer
|
| 18 |
+
|
| 19 |
+
@dataclass
class ModalityConfig:
    """Descriptor pairing a modality name with the integer id placed in segments."""
    # Modality name, e.g. 'text', 'image', 'audio', 'video'.
    name: str
    # Integer id written into each segment's 'modality_id' field.
    modality_id: int
|
| 23 |
+
|
| 24 |
+
class UnifiedMultiModalPreprocessor(nn.Module):
    """Unified multimodal preprocessor.

    Only formats raw per-modality batches into the 'segments' structure
    accepted by MultiModalDenseTransformer. It deliberately contains no
    encoders: encoding happens inside the model itself so that EWC can
    observe gradients on the model's own parameters.
    """

    def __init__(self, model_dim: int = 2048):
        super().__init__()
        # Known modalities and their fixed ids (text=0 ... video=3).
        self.modality_configs = {
            name: ModalityConfig(name, idx)
            for idx, name in enumerate(('text', 'image', 'audio', 'video'))
        }

    def process_batch(self, batch_data: Union[torch.Tensor, List[Any]], modality_type: str) -> List[Dict]:
        """Wrap one modality's batch as a single segment dict.

        Returns an empty list for unknown modalities or unusable data.
        """
        config = self.modality_configs.get(modality_type)
        if config is None:
            return []

        # Normalize the input to a single batch tensor.
        if isinstance(batch_data, list):
            samples = [item for item in batch_data if item is not None]
            if not samples:
                return []
            # Stack homogeneous per-sample tensors, e.g. a list of
            # (C, H, W) images becomes (B, C, H, W).
            try:
                data_tensor = torch.stack(samples)
            except Exception as e:
                print(f"Error stacking modality data: {e}")
                return []
        elif isinstance(batch_data, torch.Tensor):
            data_tensor = batch_data
        else:
            return []

        # Raw data (e.g. pixels) is kept as-is; the model encodes it later.
        return [{
            'type': modality_type,
            'data': data_tensor,
            'modality_id': config.modality_id
        }]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class ExperienceReplayBuffer:
    """Bounded experience-replay buffer.

    Tensors are detached and copied to CPU on insertion, so stored
    samples never pin GPU memory or keep autograd graphs alive.
    """

    def __init__(self, max_size: int = 10000):
        # deque(maxlen=...) evicts the oldest sample automatically.
        self.buffer = deque(maxlen=max_size)

    def add(self, sample: Dict[str, Any]):
        """Store one sample safely.

        Tensors — including those one level deep inside list values — are
        detached and moved to CPU; everything else is kept as-is.
        """
        def sanitize(value):
            if isinstance(value, torch.Tensor):
                return value.detach().cpu()
            if isinstance(value, list):
                # One level deep: tensor items are sanitized, other items
                # pass through unchanged.
                return [item.detach().cpu() if isinstance(item, torch.Tensor) else item
                        for item in value]
            return value

        self.buffer.append({key: sanitize(value) for key, value in sample.items()})

    def sample(self, batch_size: int) -> List[Any]:
        """Uniformly sample up to ``batch_size`` distinct stored samples."""
        count = len(self.buffer)
        if count == 0:
            return []

        chosen = np.random.choice(count, min(count, batch_size), replace=False)
        return [self.buffer[i] for i in chosen]

    def __len__(self):
        return len(self.buffer)

    def clear(self):
        """Drop all stored samples."""
        self.buffer.clear()
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class EWC:
    """Elastic Weight Consolidation.

    Snapshots the model's trainable parameters and estimates their
    importance via an empirical Fisher matrix computed over a dataloader;
    ``penalty`` then discourages later tasks from moving important weights.
    """
    def __init__(
        self,
        model: nn.Module,
        dataloader: DataLoader,
        preprocessor: UnifiedMultiModalPreprocessor,
        importance: float = 1000.0
    ):
        """
        Args:
            model: multimodal model; assumed to accept a {'segments': [...]}
                dict and return a dict containing 'logits' — confirm against
                the model implementation.
            dataloader: yields batches with 'instruction' / 'response'
                token-id tensors and optionally 'modality_data'.
            preprocessor: formats raw modality data into segments.
            importance: weight (lambda) of the quadratic penalty.
        """
        self.model = model
        self.preprocessor = preprocessor
        self.importance = importance
        self.device = next(model.parameters()).device

        # Freeze a copy of the current trainable parameters as the
        # reference point for the quadratic penalty.
        self.params = {
            n: p.clone().detach()
            for n, p in model.named_parameters()
            if p.requires_grad
        }

        self.fisher = self._compute_fisher(dataloader)

    def _compute_fisher(self, dataloader: DataLoader) -> Dict[str, torch.Tensor]:
        """Compute the (empirical) Fisher information diagonal.

        Accumulates squared gradients of the response-masked causal-LM
        loss over the dataloader, then divides by the number of samples.
        """
        fisher = {
            n: torch.zeros_like(p)
            for n, p in self.model.named_parameters()
            if p.requires_grad
        }

        self.model.eval()
        num_samples = 0

        # tqdm keeps the progress output compact.
        pbar = tqdm(dataloader, desc="Computing Fisher Matrix", leave=False)
        for batch in pbar:
            if batch is None: continue

            self.model.zero_grad()

            # 1. Prepare text input.
            instruction_ids = batch['instruction'].to(self.device)
            response_ids = batch['response'].to(self.device)
            # Concatenate: [Instruction, Response]
            input_ids = torch.cat([instruction_ids, response_ids], dim=1)

            # 2. Prepare the multimodal input structure.
            input_data = {'segments': []}

            # Handle extra modality data if present.
            # batch['modality_data'] may be a list (presumably produced by
            # collate_fn_v2 in data_loader.py — confirm against that file).
            raw_modality_data = batch.get('modality_data')
            if raw_modality_data is not None:
                # Try to determine the modality type; if the dataset does
                # not specify one, default to 'image'. In practice the
                # dataset should return an explicit 'modality_type'.
                modality_type = batch.get('modality_type', 'image')
                if isinstance(modality_type, list): modality_type = modality_type[0]

                # The preprocessor handles stacking and formatting.
                mod_segments = self.preprocessor.process_batch(raw_modality_data, modality_type)
                # Only move to device when the data is valid.
                for seg in mod_segments:
                    seg['data'] = seg['data'].to(self.device)
                    input_data['segments'].append(seg)

            # Append the text segment.
            input_data['segments'].append({
                'type': 'text',
                'data': input_ids,
                'modality_id': 0
            })

            # 3. Forward pass.
            output = self.model(input_data)
            logits = output['logits'] # (B, Seq_Len, Vocab)

            # 4. Standard causal-LM loss with shifted logits/labels:
            # input_ids: [I1, I2, R1, R2]
            # labels:    [I2, R1, R2, EOS]
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()

            # Build a mask so gradients only flow through the response
            # part; instruction tokens are zeroed out.
            inst_len = instruction_ids.shape[1]
            loss_mask = torch.ones_like(shift_labels, dtype=torch.float)
            if inst_len > 1:
                # Mask the instruction span (index shifted by one due to
                # the logits/labels shift above).
                loss_mask[:, :inst_len-1] = 0.0

            # Per-token loss (no reduction yet).
            loss_fct = nn.CrossEntropyLoss(reduction='none')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            # Apply the mask and average over unmasked tokens
            # (+1e-6 guards against an all-masked batch).
            loss = (loss * loss_mask.view(-1)).sum() / (loss_mask.sum() + 1e-6)

            # 5. Backprop and accumulate squared gradients.
            loss.backward()

            for n, p in self.model.named_parameters():
                if p.grad is not None and n in fisher:
                    fisher[n] += p.grad.detach() ** 2

            num_samples += input_ids.size(0)

        # Normalize by the total number of samples seen.
        if num_samples > 0:
            for n in fisher:
                fisher[n] /= num_samples

        self.model.train()
        return fisher

    def penalty(self, model: Optional[nn.Module] = None) -> torch.Tensor:
        """Return importance * sum_i F_i * (theta_i - theta*_i)^2.

        Args:
            model: optional model to evaluate; defaults to ``self.model``
                (kept for compatibility — usually the same object).
        """
        target_model = model if model is not None else self.model

        loss = torch.tensor(0.0, device=self.device)

        # Only parameters present in both the snapshot and the Fisher
        # estimate contribute to the penalty.
        for n, p in target_model.named_parameters():
            if n in self.params and p.requires_grad:
                if n in self.fisher:
                    loss += (self.fisher[n] * (p - self.params[n]) ** 2).sum()

        return self.importance * loss
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class OnlineEWC(EWC):
    """Online EWC: a single regularizer whose Fisher estimate is folded
    across tasks via a decayed running sum (decay ``gamma``).

    Unlike the parent class, no Fisher matrix is computed at construction
    time; call :meth:`update_fisher` after finishing each task.
    """

    def __init__(
        self,
        model: nn.Module,
        preprocessor: UnifiedMultiModalPreprocessor,
        importance: float = 1000.0,
        gamma: float = 0.9
    ):
        # Intentionally does NOT call EWC.__init__: that would require a
        # dataloader and trigger an immediate Fisher computation.
        self.model = model
        self.preprocessor = preprocessor
        self.importance = importance
        self.gamma = gamma
        self.device = next(model.parameters()).device

        # Empty until the first update_fisher call.
        self.params = {}
        self.fisher = {}
        self.task_count = 0

    def update_fisher(self, dataloader: DataLoader):
        """Fold the just-finished task's Fisher into the running estimate
        and snapshot the current parameters as the new anchor."""
        print(f"Updating Online EWC Fisher Matrix (Task {self.task_count + 1})...")
        task_fisher = self._compute_fisher(dataloader)

        if self.task_count == 0:
            self.fisher = task_fisher
        else:
            # Decayed accumulation: old * gamma + new.
            for name, old in self.fisher.items():
                if name in task_fisher:
                    self.fisher[name] = self.gamma * old + task_fisher[name]

        # Anchor the penalty at the parameters reached after this task.
        self.params = {
            name: param.clone().detach()
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }

        self.task_count += 1
        print(f"Online EWC regularizer updated.")

    def penalty(self, model: Optional[nn.Module] = None) -> torch.Tensor:
        """Zero before the first update; otherwise the parent's penalty."""
        if self.task_count == 0:
            return torch.tensor(0.0, device=self.device)
        return super().penalty(model)
|
contrastive_learning.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from typing import Dict, Optional, Tuple, Union, Literal, List
|
| 5 |
+
import math
|
| 6 |
+
import copy
|
| 7 |
+
|
| 8 |
+
class CLIPLoss(nn.Module):
    """Symmetric InfoNCE loss in the style of CLIP.

    Holds a learnable log-temperature (``logit_scale``); matched
    image/text pairs are the diagonal of the similarity matrix.

    NOTE(review): under DDP this computes the loss over the local batch
    only; a full implementation would all-gather features first.
    """

    def __init__(self, temperature: float = 0.07, max_temperature: float = 100.0):
        super().__init__()
        self.temperature = temperature
        self.max_temperature = max_temperature
        # Learnable inverse temperature, stored in log space.
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / temperature))

    def forward(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (mean loss, image->text loss, text->image loss).

        Args:
            image_features: [B, D] image embeddings (unnormalized).
            text_features: [B, D] text embeddings (unnormalized).
        """
        img = F.normalize(image_features, dim=-1)
        txt = F.normalize(text_features, dim=-1)

        # Clamp the exponentiated scale for numerical stability.
        scale = self.logit_scale.exp().clamp(max=self.max_temperature)

        # [B, B] similarity matrices, transposes of each other.
        sim_i2t = scale * img @ txt.T
        sim_t2i = sim_i2t.T

        # Positives live on the diagonal.
        targets = torch.arange(img.shape[0], device=img.device)

        # Symmetric cross-entropy in both retrieval directions.
        loss_i2t = F.cross_entropy(sim_i2t, targets)
        loss_t2i = F.cross_entropy(sim_t2i, targets)

        return (loss_i2t + loss_t2i) / 2, loss_i2t, loss_t2i
|
| 51 |
+
|
| 52 |
+
class SigLIPLoss(nn.Module):
    """Pairwise sigmoid contrastive loss (SigLIP).

    Learns a log-space temperature ``t_prime`` and a bias ``b``; every
    (i, j) pair is scored independently through a binary sigmoid
    objective, so no batch-wide softmax is needed.
    Paper: "Sigmoid Loss for Language Image Pre-Training".
    """

    def __init__(self, init_temperature: float = 1.0, init_bias: float = -10.0):
        super().__init__()
        self.t_prime = nn.Parameter(torch.tensor(math.log(init_temperature)))
        self.b = nn.Parameter(torch.tensor(init_bias))

    def forward(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor
    ) -> torch.Tensor:
        """Dense [B, B] pairwise loss over the local batch.

        Note: SigLIP converges without gathering global negatives, but this
        dense formulation builds a [B, B] label matrix — for very large
        batches (8k+) prefer a chunked/custom-kernel implementation.
        """
        img = F.normalize(image_features, dim=-1)
        txt = F.normalize(text_features, dim=-1)
        n = img.shape[0]

        # logits = exp(t') * cosine + b
        pair_logits = img @ txt.T * self.t_prime.exp() + self.b

        # +1 on the diagonal (matching pairs), -1 everywhere else.
        signs = 2.0 * torch.eye(n, device=img.device) - torch.ones(n, n, device=img.device)

        # -log sigmoid(sign * logit): standard binary cross-entropy summed
        # over all pairs, normalized by the batch size as in the paper.
        return -F.logsigmoid(signs * pair_logits).sum() / n
|
| 94 |
+
|
| 95 |
+
class InfoNCELoss(nn.Module):
    """InfoNCE loss supporting explicit or in-batch negatives."""

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        query: torch.Tensor,
        positive_key: torch.Tensor,
        negative_keys: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            query:         [B, D] anchor embeddings.
            positive_key:  [B, D] positive embeddings (row-aligned with query).
            negative_keys: optional [B, N, D] explicit negatives. When None,
                           the other batch rows act as negatives
                           (CLIP-style, single direction).
        """
        q = F.normalize(query, dim=-1)
        pos = F.normalize(positive_key, dim=-1)
        n = q.shape[0]

        if negative_keys is None:
            # In-batch negatives: row i's positive sits on the diagonal.
            scores = q @ pos.T / self.temperature
            target = torch.arange(n, dtype=torch.long, device=q.device)
        else:
            # Explicit negatives.
            # Positive similarity per row: [B]
            pos_score = (q * pos).sum(dim=-1) / self.temperature
            neg = F.normalize(negative_keys, dim=-1)
            # Negative similarities: [B, N]
            neg_score = (q.unsqueeze(1) * neg).sum(dim=-1) / self.temperature
            # Positive logit occupies column 0: [B, 1 + N]
            scores = torch.cat([pos_score.unsqueeze(1), neg_score], dim=1)
            target = torch.zeros(n, dtype=torch.long, device=q.device)

        return F.cross_entropy(scores, target)
|
| 136 |
+
|
| 137 |
+
class ProjectionHead(nn.Module):
    """Two-layer MLP projector with optional sequence pooling.

    3-D inputs [B, S, D] are reduced to [B, D] according to
    ``pooling_type`` before projection; 'none' keeps the sequence axis,
    yielding [B, S, embed_dim] (useful for dense / fine-grained
    contrastive objectives).
    """

    def __init__(
        self,
        input_dim: int,
        embed_dim: int,
        pooling_type: Literal['cls', 'mean', 'max', 'none'] = 'mean',
        exclude_first_token: bool = False
    ):
        super().__init__()
        self.pooling_type = pooling_type
        self.exclude_first_token = exclude_first_token

        self.net = nn.Sequential(
            nn.Linear(input_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim)
        )

    def _pool(self, seq: torch.Tensor) -> torch.Tensor:
        """Collapse [B, S, D] per ``pooling_type`` ('none' passes through)."""
        # For ViT-style sequences, mean/max pooling usually excludes the
        # leading CLS token.
        drop_cls = self.exclude_first_token and seq.shape[1] > 1
        if self.pooling_type == 'cls':
            # Index 0 is assumed to be the CLS token (standard ViT / BERT).
            return seq[:, 0, :]
        if self.pooling_type == 'mean':
            return (seq[:, 1:, :] if drop_cls else seq).mean(dim=1)
        if self.pooling_type == 'max':
            return (seq[:, 1:, :] if drop_cls else seq).max(dim=1)[0]
        # 'none' (or any unrecognized type): keep the full sequence.
        return seq

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() == 3:
            x = self._pool(x)
        return self.net(x)
|
| 185 |
+
|
| 186 |
+
class MultiModalContrastiveLoss(nn.Module):
    """Contrastive objective over a dynamic set of modalities.

    Each configured modality gets its own ProjectionHead mapping encoder
    features (possibly of heterogeneous dimensionality) into a shared
    embedding space; selected modality pairs are then scored with a CLIP,
    SigLIP, or InfoNCE criterion.
    """

    def __init__(
        self,
        embed_dim: int = 512,
        input_dims: Union[int, Dict[str, int]] = 2048,
        temperature: float = 0.07,
        loss_type: str = 'clip',
        modality_config: Optional[Dict[str, str]] = None
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.loss_type = loss_type

        # Pick the pairwise criterion ('clip' | 'siglip' | anything else -> InfoNCE).
        if loss_type == 'clip':
            self.loss_fn = CLIPLoss(temperature)
        elif loss_type == 'siglip':
            self.loss_fn = SigLIPLoss()
        else:
            self.loss_fn = InfoNCELoss(temperature)

        if modality_config is None:
            # Defaults: ViT/BERT-style encoders expose a CLS token; audio and
            # video streams are mean-pooled over their sequence axis.
            modality_config = {
                'text': 'cls',
                'image': 'cls',
                'audio': 'mean',
                'video': 'mean'
            }
        self.modality_config = modality_config

        # One projector per configured modality.
        self.projectors = nn.ModuleDict()
        for name, pooling in modality_config.items():
            if isinstance(input_dims, dict):
                feature_dim = input_dims.get(name)
                # No dimension supplied for this modality: skip it rather
                # than crash at construction time.
                if feature_dim is None:
                    continue
            else:
                feature_dim = input_dims

            # Heuristic: when mean/max pooling image or text sequences, drop
            # the leading CLS token from the pool (user may override).
            drop_cls = name in ['image', 'text'] and pooling in ['mean', 'max']

            self.projectors[name] = ProjectionHead(
                input_dim=feature_dim,
                embed_dim=embed_dim,
                pooling_type=pooling,
                exclude_first_token=drop_cls
            )

    def forward(
        self,
        features: Dict[str, torch.Tensor],
        modality_pairs: Optional[List[Tuple[str, str]]] = None
    ) -> Dict[str, torch.Tensor]:
        """Return a dict mapping '<a>_<b>_loss' to each pairwise loss tensor."""
        # Default pairing: every non-text modality against text.
        if modality_pairs is None:
            if 'text' not in features:
                return {}
            modality_pairs = [(mod, 'text') for mod in features if mod != 'text']

        losses: Dict[str, torch.Tensor] = {}
        for mod_a, mod_b in modality_pairs:
            if mod_a not in features or mod_b not in features:
                continue
            # Pairs involving a modality without a projector are skipped.
            if mod_a not in self.projectors or mod_b not in self.projectors:
                continue

            emb_a = self.projectors[mod_a](features[mod_a])
            emb_b = self.projectors[mod_b](features[mod_b])

            # CLIPLoss returns (total, i2t, t2i); the others return a scalar.
            if self.loss_type == 'clip':
                pair_loss, _, _ = self.loss_fn(emb_a, emb_b)
            else:
                pair_loss = self.loss_fn(emb_a, emb_b)

            losses[f'{mod_a}_{mod_b}_loss'] = pair_loss

        return losses
|
| 284 |
+
|
| 285 |
+
class MomentumEncoder(nn.Module):
    """MoCo-style momentum (EMA) wrapper around an encoder.

    Keeps a frozen deep copy of ``encoder`` whose parameters track the live
    encoder via an exponential moving average; buffers (e.g. BatchNorm
    running stats) are copied over verbatim.

    Fix vs. the previous version: the EMA update now only runs while the
    module is in training mode, so encoding keys during evaluation no
    longer silently mutates the key encoder's weights.
    """

    def __init__(self, encoder: nn.Module, momentum: float = 0.999):
        super().__init__()
        self.encoder = encoder
        self.momentum_encoder = self._build_momentum_encoder(encoder)
        self.momentum = momentum

    def _build_momentum_encoder(self, encoder: nn.Module) -> nn.Module:
        """Deep-copy the encoder and freeze its parameters."""
        momentum_encoder = copy.deepcopy(encoder)
        for param in momentum_encoder.parameters():
            param.requires_grad = False
        return momentum_encoder

    @torch.no_grad()
    def _update_momentum_encoder(self) -> None:
        """In-place EMA update: k <- m * k + (1 - m) * q."""
        for param_q, param_k in zip(
            self.encoder.parameters(),
            self.momentum_encoder.parameters()
        ):
            param_k.data.mul_(self.momentum).add_(param_q.data, alpha=1.0 - self.momentum)

        # Buffers (e.g. BatchNorm running mean/var) are copied directly:
        # the key encoder runs in eval mode and never tracks its own stats.
        for buffer_q, buffer_k in zip(
            self.encoder.buffers(),
            self.momentum_encoder.buffers()
        ):
            buffer_k.data.copy_(buffer_q.data)

    def forward(self, x: torch.Tensor, use_momentum: bool = False) -> torch.Tensor:
        """
        Args:
            x: encoder input.
            use_momentum: when True, encode with the EMA (key) encoder;
                in training mode this first refreshes the EMA weights.
        """
        if not use_momentum:
            return self.encoder(x)

        with torch.no_grad():
            # Only let the EMA drift while training; evaluation must not
            # modify model state.
            if self.training:
                self._update_momentum_encoder()
            # The momentum encoder always runs in eval mode.
            self.momentum_encoder.eval()
            return self.momentum_encoder(x)
|
data_augmentation.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
数据增强模块
|
| 3 |
+
针对不同模态的高级数据增强策略
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from typing import Optional, Tuple, List
|
| 9 |
+
import random
|
| 10 |
+
import math
|
| 11 |
+
|
| 12 |
+
class RandAugment(nn.Module):
    """RandAugment-style augmentation for image tensors.

    Applies ``n`` randomly chosen photometric operations per call.
    Assumes float inputs in [0, 1], shaped [C, H, W] or [B, C, H, W]
    — TODO confirm callers normalize before this module.

    NOTE(review): the magnitude parameter ``m`` is currently unused — each
    op draws its own random strength; confirm whether ``m`` should drive
    op magnitudes as in the RandAugment paper.
    """
    def __init__(self, n: int = 2, m: int = 10):
        super().__init__()
        self.n = n  # number of ops applied per forward pass
        self.m = m  # nominal magnitude (currently unused, see class note)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``self.n`` randomly selected augmentation ops."""
        # Promote [C, H, W] to a batch of one so every op can assume 4-D.
        is_batched = x.ndim == 4
        if not is_batched:
            x = x.unsqueeze(0)

        augmentations = [
            self._auto_contrast,
            self._equalize,
            self._solarize,
            self._color,
            self._contrast,
            self._brightness,
            self._sharpness,
        ]

        # Ops are re-sampled independently on every round, not fixed up front.
        for _ in range(self.n):
            aug = random.choice(augmentations)
            x = aug(x)

        if not is_batched:
            x = x.squeeze(0)

        return x

    def _auto_contrast(self, x: torch.Tensor) -> torch.Tensor:
        """Auto-contrast: linearly stretch each channel to [0, 1]."""
        # Per-sample, per-channel min/max.
        # x: [B, C, H, W]
        B, C, H, W = x.shape
        x_flat = x.view(B, C, -1)
        min_val = x_flat.min(dim=2, keepdim=True)[0].view(B, C, 1, 1)
        max_val = x_flat.max(dim=2, keepdim=True)[0].view(B, C, 1, 1)
        # Epsilon guards against division by zero on constant channels.
        return (x - min_val) / (max_val - min_val + 1e-8)

    def _equalize(self, x: torch.Tensor) -> torch.Tensor:
        """Histogram equalization (simplified per-channel CDF remap)."""
        # A faithful PyTorch implementation is involved; this version
        # discretizes to 256 bins and remaps pixels through the CDF.
        B, C, H, W = x.shape
        # Quantize pixels to [0, 255] integers to index the CDF lookup table.
        x_int = (x * 255).long().clamp(0, 255)

        out = torch.zeros_like(x)

        # NOTE(review): Python-level B*C loop — slow for large batches.
        for b in range(B):
            for c in range(C):
                hist = torch.histc(x[b, c].float(), bins=256, min=0, max=1)
                cdf = hist.cumsum(0)
                cdf = cdf / cdf[-1]  # normalize the CDF to [0, 1]
                # Use the CDF as a lookup table.
                out[b, c] = cdf[x_int[b, c]]

        return out

    def _solarize(self, x: torch.Tensor) -> torch.Tensor:
        """Solarize: invert pixels at/above a random threshold."""
        threshold = random.uniform(0.3, 0.7)
        return torch.where(x < threshold, x, 1.0 - x)

    def _color(self, x: torch.Tensor) -> torch.Tensor:
        """Color (saturation) jitter with a random factor in [0.8, 1.2]."""
        factor = 1.0 + (random.random() - 0.5) * 0.4
        # Grayscale approximation: simple mean over channels.
        # x is [B, C, H, W], dim=1 is channels.
        mean = x.mean(dim=1, keepdim=True)
        return torch.clamp(mean + factor * (x - mean), 0, 1)

    def _contrast(self, x: torch.Tensor) -> torch.Tensor:
        """Contrast jitter with a random factor in [0.8, 1.2]."""
        factor = 1.0 + (random.random() - 0.5) * 0.4
        # Per-image global mean, keeping the batch dimension:
        # view(B, -1) -> mean(1) -> view(B, 1, 1, 1)
        mean = x.view(x.size(0), -1).mean(dim=1).view(-1, 1, 1, 1)
        return torch.clamp(mean + factor * (x - mean), 0, 1)

    def _brightness(self, x: torch.Tensor) -> torch.Tensor:
        """Brightness jitter with a random factor in [0.8, 1.2]."""
        factor = 1.0 + (random.random() - 0.5) * 0.4
        return torch.clamp(x * factor, 0, 1)

    def _sharpness(self, x: torch.Tensor) -> torch.Tensor:
        """Sharpness: unsharp-mask blend against an average-pool blur."""
        factor = 1.0 + (random.random() - 0.5) * 0.4
        # Average pooling stands in for a Gaussian blur.
        kernel_size = 3
        pad = kernel_size // 2
        blurred = F.avg_pool2d(x, kernel_size=kernel_size, stride=1, padding=pad)
        # Unsharp masking: x + alpha * (x - blurred);
        # factor > 1 sharpens, factor < 1 blurs.
        return torch.clamp(x + (factor - 1.0) * (x - blurred), 0, 1)
|
| 116 |
+
|
| 117 |
+
class MixUp(nn.Module):
    """MixUp augmentation (Zhang et al., 2018).

    Draws lambda ~ Beta(alpha, alpha) and blends each sample (and its
    label) with a randomly permuted partner from the same batch.
    """

    def __init__(self, alpha: float = 1.0, num_classes: Optional[int] = None):
        super().__init__()
        self.alpha = alpha
        # Class count used to one-hot integer labels. When None it is
        # inferred from the first labelled batch and cached on the module.
        self.num_classes = num_classes

    def forward(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], float]:
        """Return (mixed inputs, mixed soft labels or None, lambda).

        Args:
            x: [B, ...] input batch.
            y: optional labels; 1-D integer class ids are one-hot encoded,
               anything else (e.g. already-soft float labels) is mixed as-is.
        """
        if self.alpha > 0:
            lambda_ = random.betavariate(self.alpha, self.alpha)
        else:
            lambda_ = 1.0

        batch_size = x.shape[0]
        index = torch.randperm(batch_size, device=x.device)

        mixed_x = lambda_ * x + (1 - lambda_) * x[index]

        mixed_y = None
        if y is not None:
            y_a = y
            y_b = y[index]

            # Fix: only one-hot 1-D *integer* class labels. The previous
            # condition (`dtype == long or ndim == 1`) crashed on 1-D float
            # labels (F.one_hot requires integers) and produced a bogus 3-D
            # tensor for 2-D long labels.
            if y.ndim == 1 and not torch.is_floating_point(y):
                if self.num_classes is None:
                    # Inferring from one batch is fragile (a later batch may
                    # contain a larger class id); prefer passing num_classes.
                    self.num_classes = int(y.max().item()) + 1
                y_a = F.one_hot(y_a.long(), num_classes=self.num_classes).float()
                y_b = F.one_hot(y_b.long(), num_classes=self.num_classes).float()

            mixed_y = lambda_ * y_a + (1 - lambda_) * y_b

        return mixed_x, mixed_y, lambda_
|
| 158 |
+
|
| 159 |
+
class CutMix(nn.Module):
    """CutMix augmentation: paste a random rectangle from a shuffled batch."""

    def __init__(self, alpha: float = 1.0, num_classes: Optional[int] = None):
        super().__init__()
        self.alpha = alpha
        self.num_classes = num_classes

    def _rand_bbox(
        self,
        size: Tuple[int, ...],
        lambda_: float
    ) -> Tuple[int, int, int, int]:
        """Sample a box whose area is roughly (1 - lambda) of the image."""
        # Compatible with [..., H, W] layouts: W = size[-1], H = size[-2].
        width, height = size[-1], size[-2]
        cut_ratio = math.sqrt(1.0 - lambda_)
        patch_w = int(width * cut_ratio)
        patch_h = int(height * cut_ratio)

        # Random box center.
        cx = random.randint(0, width)
        cy = random.randint(0, height)

        # Clamp the box to the image bounds.
        x1 = min(max(cx - patch_w // 2, 0), width)
        y1 = min(max(cy - patch_h // 2, 0), height)
        x2 = min(max(cx + patch_w // 2, 0), width)
        y2 = min(max(cy + patch_h // 2, 0), height)

        return x1, y1, x2, y2

    def forward(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], float]:
        """Return (cut-mixed inputs, mixed labels or None, effective lambda)."""
        lambda_ = random.betavariate(self.alpha, self.alpha) if self.alpha > 0 else 1.0

        batch_size = x.shape[0]
        index = torch.randperm(batch_size, device=x.device)

        bbx1, bby1, bbx2, bby2 = self._rand_bbox(x.size(), lambda_)

        # Clone so the caller's tensor is never modified in place.
        x = x.clone()
        x[:, :, bby1:bby2, bbx1:bbx2] = x[index, :, bby1:bby2, bbx1:bbx2]

        # Recompute lambda as the exact fraction of untouched pixels
        # (H = size[-2], W = size[-1], consistent with _rand_bbox).
        H, W = x.size()[-2], x.size()[-1]
        lambda_ = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (H * W))

        mixed_y = None
        if y is not None:
            y_a, y_b = y, y[index]

            if y.dtype == torch.long or y.ndim == 1:
                if self.num_classes is None:
                    # Inferred from the batch; passing num_classes up front is safer.
                    self.num_classes = int(y.max().item()) + 1
                y_a = F.one_hot(y_a, num_classes=self.num_classes).float()
                y_b = F.one_hot(y_b, num_classes=self.num_classes).float()

            mixed_y = lambda_ * y_a + (1 - lambda_) * y_b

        return x, mixed_y, lambda_
|
| 228 |
+
|
| 229 |
+
class SpecAugment(nn.Module):
    """SpecAugment for audio spectrograms (frequency + time masking).

    Fix: the original unpacked the frequency axis into a local named ``F``,
    shadowing the module-level ``torch.nn.functional as F`` import inside
    ``forward``; the dimensions now use non-conflicting names.
    """

    def __init__(
        self,
        freq_mask_param: int = 27,
        time_mask_param: int = 100,
        num_freq_masks: int = 2,
        num_time_masks: int = 2
    ):
        super().__init__()
        self.freq_mask_param = freq_mask_param  # max width of one frequency mask
        self.time_mask_param = time_mask_param  # max width of one time mask
        self.num_freq_masks = num_freq_masks
        self.num_time_masks = num_time_masks

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Args:
            spec: [B, F, T] or [B, C, F, T] spectrogram batch.

        Returns:
            A masked copy with the same shape as the input.
        """
        input_ndim = spec.ndim
        if input_ndim == 3:
            spec = spec.unsqueeze(1)  # [B, 1, F, T]

        # Named to avoid clobbering `F` (torch.nn.functional).
        _, _, n_freq, n_time = spec.shape
        spec = spec.clone()  # never mask the caller's tensor in place

        # Frequency masks: zero `f` consecutive frequency bins.
        for _ in range(self.num_freq_masks):
            # Cap the mask width at the axis size.
            f_param = min(self.freq_mask_param, n_freq)
            f = random.randint(0, f_param)
            f0 = random.randint(0, max(0, n_freq - f))
            spec[:, :, f0:f0 + f, :] = 0

        # Time masks: zero `t` consecutive frames.
        for _ in range(self.num_time_masks):
            t_param = min(self.time_mask_param, n_time)
            t = random.randint(0, t_param)
            t0 = random.randint(0, max(0, n_time - t))
            spec[:, :, :, t0:t0 + t] = 0

        if input_ndim == 3:
            return spec.squeeze(1)
        return spec
|
| 275 |
+
|
| 276 |
+
class TemporalMasking(nn.Module):
    """Zero out a random subset of frames in each video clip."""

    def __init__(self, mask_ratio: float = 0.15):
        super().__init__()
        self.mask_ratio = mask_ratio

    def forward(self, video: torch.Tensor) -> torch.Tensor:
        """
        Args:
            video: [B, T, C, H, W] batch of clips.

        Returns:
            A copy with floor(T * mask_ratio) frames per clip zeroed, or
            the input unchanged when that count rounds down to zero.
        """
        B, T, C, H, W = video.shape
        drop_count = int(T * self.mask_ratio)
        if drop_count == 0:
            return video

        masked = video.clone()
        for sample in range(B):
            # Frames to drop are sampled independently per clip.
            drop_idx = torch.randperm(T)[:drop_count]
            masked[sample, drop_idx] = 0

        return masked
|
| 300 |
+
|
| 301 |
+
class MultiModalAugmentation(nn.Module):
    """Single entry point routing each modality to its augmentation chain."""

    def __init__(
        self,
        image_aug: bool = True,
        audio_aug: bool = True,
        video_aug: bool = True,
        use_mixup: bool = True,
        use_cutmix: bool = True,
        num_classes: Optional[int] = None
    ):
        super().__init__()
        self.image_aug = RandAugment() if image_aug else None
        self.audio_aug = SpecAugment() if audio_aug else None
        self.video_aug = TemporalMasking() if video_aug else None

        self.mixup = MixUp(num_classes=num_classes) if use_mixup else None
        self.cutmix = CutMix(num_classes=num_classes) if use_cutmix else None

    def forward(
        self,
        data: torch.Tensor,
        modality: str,
        labels: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply per-sample augmentation, then (training only) MixUp/CutMix.

        Args:
            data: input batch.
            modality: one of 'image', 'audio', 'video' (others pass through).
            labels: optional labels, required for the mixing stage.
        """
        # Stage 1: intra-sample, modality-specific augmentation.
        per_modality = {
            'image': self.image_aug,
            'audio': self.audio_aug,
            'video': self.video_aug,
        }.get(modality)
        if per_modality is not None:
            data = per_modality(data)

        # Stage 2: inter-sample mixing (training mode with labels only).
        if self.training and labels is not None:
            roll = random.random()
            chosen = None

            # Mutually exclusive choice: images with CutMix enabled get a
            # 50/50 split between CutMix and MixUp (when available); other
            # modalities apply MixUp with probability 0.5.
            if self.cutmix is not None and modality == 'image':
                if roll < 0.5:
                    chosen = self.cutmix
                elif self.mixup is not None:
                    chosen = self.mixup
            elif self.mixup is not None and roll < 0.5:
                chosen = self.mixup

            if chosen is not None:
                data, labels, _ = chosen(data, labels)

        return data, labels
|
data_config.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# data_config.py
|
| 2 |
+
"""
|
| 3 |
+
预训练和后训练数据集配置
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
# Pretraining dataset registry.
# Each entry describes one Hugging Face dataset: `hf_path` / `config` / `split`
# identify the source, `text_field` (plus `image_field` for multimodal sets)
# names the column(s) to read, `streaming` selects lazy iteration, and
# `weight` is the relative sampling weight when mixing sources.
PRETRAIN_DATASETS = {
    # --- Text corpora ---
    'the_pile': {
        'type': 'text',
        'hf_path': 'EleutherAI/pile',
        'split': 'train',
        'streaming': True,
        'text_field': 'text',
        'weight': 1.0,
        'description': 'The Pile - 825GB diverse text corpus'
    },
    'c4': {
        'type': 'text',
        'hf_path': 'allenai/c4',
        'config': 'en',
        'split': 'train',
        'streaming': True,
        'text_field': 'text',
        'weight': 0.5,
        'description': 'C4 - Colossal Clean Crawled Corpus'
    },
    # NOTE(review): key says 'wikipedia' but the path points at FineWeb-Edu —
    # confirm whether the key or the path is intended.
    'wikipedia': {
        'type': 'text',
        'hf_path': 'HuggingFaceFW/fineweb-edu',
        'config': 'sample-10BT',
        'split': 'train',
        'streaming': True,
        'text_field': 'text',
        'weight': 0.3,
        'description': 'FineWeb Edu - High quality educational content'
    },
    # NOTE(review): key says 'bookcorpus' but the path is SmolLM/Cosmopedia
    # synthetic data — confirm the intended source.
    'bookcorpus': {
        'type': 'text',
        'hf_path': 'HuggingFaceTB/smollm-corpus',
        'config': 'cosmopedia-v2',
        'split': 'train',
        'streaming': True,
        'text_field': 'text',
        'weight': 0.2,
        'description': 'Synthetic textbooks and stories'
    },
    # --- Code corpora ---
    'codeparrot': {
        'type': 'code',
        'hf_path': 'bigcode/the-stack-smol',
        'config': 'default',
        'split': 'train',
        'streaming': True,
        'text_field': 'content',
        'weight': 0.3,
        'description': 'The Stack Smol - code'
    },
    'the_stack': {
        'type': 'code',
        'hf_path': 'bigcode/the-stack-dedup',
        'split': 'train',
        'streaming': True,
        'text_field': 'content',
        'weight': 0.2,
        'description': 'The Stack - deduplicated code'
    },
    # --- Multimodal (image-text) corpora ---
    'laion400m': {
        'type': 'image_text',
        'hf_path': 'laion/laion400m',
        'split': 'train',
        'streaming': True,
        # `image_field` holds a URL, not raw pixels; the loader must fetch it.
        'image_field': 'url',
        'text_field': 'caption',
        'weight': 0.4,
        'description': 'LAION-400M image-text pairs'
    },
    'conceptual_captions': {
        'type': 'image_text',
        'hf_path': 'google-research-datasets/conceptual_captions',
        'split': 'train',
        'streaming': False,
        'image_field': 'image_url',
        'text_field': 'caption',
        'weight': 0.2,
        'description': 'Conceptual Captions 3M'
    },
}
|
| 89 |
+
|
| 90 |
+
# Post-training dataset configuration (instruction tuning + alignment)
POSTTRAIN_DATASETS = {
    # Instruction tuning datasets
    'flan_v2': {
        'type': 'instruction',
        'hf_path': 'Muennighoff/flan',
        'split': 'train',
        'streaming': True,
        'instruction_field': 'inputs',
        'response_field': 'targets',
        'weight': 1.0,
        'max_samples': 100000,
        'description': 'FLAN v2 collection'
    },
    'alpaca': {
        'type': 'instruction',
        'hf_path': 'tatsu-lab/alpaca',
        'split': 'train',
        'streaming': False,
        'instruction_field': 'instruction',
        'input_field': 'input',
        'response_field': 'output',
        'weight': 0.5,
        'description': 'Stanford Alpaca 52K'
    },
    'dolly': {
        'type': 'instruction',
        'hf_path': 'databricks/databricks-dolly-15k',
        'split': 'train',
        'streaming': False,
        'instruction_field': 'instruction',
        'context_field': 'context',  # Dolly provides an extra context field
        'response_field': 'response',
        'weight': 0.3,
        'description': 'Dolly 15K'
    },
    'oasst1': {
        'type': 'conversation',
        'hf_path': 'OpenAssistant/oasst1',
        'split': 'train',
        'streaming': False,
        'weight': 0.4,
        'description': 'OpenAssistant Conversations',
        # OASST1 is stored as a message tree, so it needs special handling;
        # custom preprocessing may be required.
    },
    'sharegpt': {
        'type': 'conversation',
        'hf_path': 'anon8231489123/ShareGPT_Vicuna_unfiltered',
        'split': 'train',
        'streaming': False,
        'weight': 0.3,
        'max_samples': 50000,
        'description': 'ShareGPT conversations'
    },
    # Code instruction datasets
    'code_alpaca': {
        'type': 'code_instruction',
        'hf_path': 'sahil2801/CodeAlpaca-20k',
        'split': 'train',
        'streaming': False,
        'instruction_field': 'instruction',
        'response_field': 'output',
        'weight': 0.3,
        'description': 'Code Alpaca 20K'
    },
    # Multimodal instruction datasets
    'llava_instruct': {
        'type': 'multimodal_instruction',
        'hf_path': 'liuhaotian/LLaVA-Instruct-150K',
        'split': 'train',
        'streaming': False,
        'image_field': 'image',
        'instruction_field': 'conversations',
        'weight': 0.5,
        'description': 'LLaVA visual instruction tuning'
    },
    # Preference datasets (used for RLHF)
    'hh_rlhf': {
        'type': 'preference',
        'hf_path': 'Anthropic/hh-rlhf',
        'split': 'train',
        'streaming': False,
        'chosen_field': 'chosen',
        'rejected_field': 'rejected',
        'weight': 1.0,
        'description': 'Anthropic HH-RLHF'
    },
    'ultrafeedback': {
        'type': 'preference',
        'hf_path': 'openbmb/UltraFeedback',
        'split': 'train',
        'streaming': True,
        'chosen_field': 'chosen',  # explicit preference field mapping
        'rejected_field': 'rejected',
        'weight': 0.5,
        'max_samples': 50000,
        'description': 'UltraFeedback preferences'
    },
    'debug_water': {
        'type': 'instruction',
        'hf_path': 'json',  # use the generic json loader
        'data_files': 'debug_water.json',  # points at the locally generated file
        'split': 'train',
        'streaming': False,
        'instruction_field': 'instruction',
        'response_field': 'output',
        'weight': 1.0,
        'description': 'Overfitting test for water'
    },
}
|
| 201 |
+
|
| 202 |
+
# Lightweight test datasets (for quick validation)
TEST_DATASETS = {
    'tiny_shakespeare': {
        'type': 'text',
        'hf_path': 'tiny_shakespeare',
        'split': 'train',
        'streaming': False,
        'text_field': 'text',
        'weight': 1.0,
        'description': 'Tiny Shakespeare for testing'
    },
    'gsm8k': {
        'type': 'instruction',
        'hf_path': 'gsm8k',
        'config': 'main',
        'split': 'train',
        'streaming': False,
        'instruction_field': 'question',
        'response_field': 'answer',
        'weight': 1.0,
        'description': 'GSM8K math problems'
    },
}
|
| 225 |
+
|
| 226 |
+
# Dataset mixing strategies: each mix names its member datasets (keys into
# PRETRAIN_DATASETS) and parallel sampling weights.
PRETRAIN_MIX = {
    'default': {
        'datasets': ['c4', 'wikipedia', 'bookcorpus', 'codeparrot'],
        'weights': [0.5, 0.2, 0.2, 0.1],
        'description': 'Default pretrain mix'
    },
    'code_heavy': {
        'datasets': ['c4', 'codeparrot', 'the_stack', 'wikipedia'],
        'weights': [0.3, 0.4, 0.2, 0.1],
        'description': 'Code-heavy mix'
    },
    'multimodal': {
        'datasets': ['c4', 'wikipedia', 'laion400m', 'conceptual_captions'],
        'weights': [0.4, 0.2, 0.3, 0.1],
        'description': 'Multimodal mix'
    },
    'text_only': {
        'datasets': ['c4', 'wikipedia', 'bookcorpus'],
        'weights': [0.5, 0.3, 0.2],
        'description': 'Text-only mix for testing'
    },
}
|
| 249 |
+
|
| 250 |
+
# Post-training mixes: datasets are keys into POSTTRAIN_DATASETS, with
# parallel sampling weights.
POSTTRAIN_MIX = {
    'default': {
        'datasets': ['flan_v2', 'alpaca', 'dolly', 'oasst1'],
        'weights': [0.4, 0.3, 0.2, 0.1],
        'description': 'Default instruction tuning mix'
    },
    'conversation': {
        'datasets': ['oasst1', 'sharegpt', 'alpaca'],
        'weights': [0.4, 0.4, 0.2],
        'description': 'Conversation-focused mix'
    },
    'code_instruct': {
        'datasets': ['code_alpaca', 'alpaca', 'flan_v2'],
        'weights': [0.5, 0.3, 0.2],
        'description': 'Code instruction mix'
    },
    'simple_instruct': {
        'datasets': ['alpaca', 'dolly'],
        'weights': [0.6, 0.4],
        'description': 'Simple instruction mix for testing'
    },
    'debug_mix': {
        'datasets': ['debug_water'],
        'weights': [1.0],
        'description': 'Debug mix for overfitting'
    },
}
|
| 277 |
+
|
| 278 |
+
# Download and cache configuration
DATASET_CACHE_DIR = "./dataset_cache"  # local cache for processed datasets
HF_CACHE_DIR = "./hf_cache"  # HuggingFace hub cache directory (passed as cache_dir)
MAX_RETRIES = 3  # download retry attempts
DOWNLOAD_TIMEOUT = 300  # download timeout, presumably seconds — confirm at call site
|
| 283 |
+
|
| 284 |
+
# Data preprocessing configuration
PREPROCESSING_CONFIG = {
    'max_seq_length': 2048,
    'min_seq_length': 32,
    'num_workers': 4,
    'batch_size': 8,
    'shuffle_buffer_size': 10000,
    'seed': 42,
}
|
data_loader.py
ADDED
|
@@ -0,0 +1,832 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# data_loader.py
|
| 2 |
+
"""
|
| 3 |
+
改进的数据加载器 - 支持预训练和后训练数据集
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader, IterableDataset
|
| 8 |
+
from datasets import load_dataset, concatenate_datasets, interleave_datasets
|
| 9 |
+
from typing import Dict, List, Optional, Any, Union
|
| 10 |
+
import random
|
| 11 |
+
import numpy as np
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
import warnings
|
| 14 |
+
from PIL import Image
|
| 15 |
+
import requests
|
| 16 |
+
from io import BytesIO
|
| 17 |
+
from torchvision import transforms
|
| 18 |
+
import logging
|
| 19 |
+
|
| 20 |
+
# 设置日志
|
| 21 |
+
logging.basicConfig(level=logging.INFO)
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 25 |
+
|
| 26 |
+
from data_config import (
|
| 27 |
+
PRETRAIN_DATASETS,
|
| 28 |
+
POSTTRAIN_DATASETS,
|
| 29 |
+
TEST_DATASETS,
|
| 30 |
+
PRETRAIN_MIX,
|
| 31 |
+
POSTTRAIN_MIX,
|
| 32 |
+
PREPROCESSING_CONFIG,
|
| 33 |
+
DATASET_CACHE_DIR,
|
| 34 |
+
HF_CACHE_DIR
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
# Image preprocessing pipeline: resize to 224x224, convert to tensor, and
# normalize with the standard ImageNet mean/std statistics.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
| 43 |
+
|
| 44 |
+
class PreTrainDataset(IterableDataset):
    """Pretraining dataset supporting streaming and weighted mixture sampling.

    Draws samples from several HuggingFace datasets according to the
    normalized weights of the chosen mix (a key of PRETRAIN_MIX), tokenizes
    them, and yields model-ready dicts.
    """

    def __init__(
        self,
        mix_name: str = 'default',
        tokenizer=None,
        max_length: int = 2048,
        streaming: bool = True,
        seed: int = 42,
        max_samples: Optional[int] = None
    ):
        """Load every dataset of the mix and normalize the sampling weights.

        Args:
            mix_name: Key into PRETRAIN_MIX selecting the dataset mixture.
            tokenizer: HF-style tokenizer callable; must not be None.
            max_length: Maximum token length per sample (padded/truncated).
            streaming: Default streaming flag when a dataset config omits it.
            seed: Base random seed (offset per dataloader worker).
            max_samples: Optional cap on the number of yielded samples.

        Raises:
            ValueError: If tokenizer is None, the mix is unknown/empty, or no
                dataset could be loaded.
        """
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.streaming = streaming
        self.seed = seed
        self.max_samples = max_samples
        self.samples_generated = 0

        # Resolve the mixture configuration
        if mix_name not in PRETRAIN_MIX:
            raise ValueError(f"Unknown mix: {mix_name}. Available: {list(PRETRAIN_MIX.keys())}")

        mix_config = PRETRAIN_MIX[mix_name]
        dataset_names = mix_config.get('datasets', [])
        weights = mix_config.get('weights', [])

        if not dataset_names:
            raise ValueError(f"No datasets found in mix: {mix_name}")

        logger.info(f"Loading pretrain mix: {mix_name}")
        logger.info(f" Datasets: {dataset_names}")
        logger.info(f" Weights: {weights}")

        # Load the datasets; failures are logged and skipped, keeping each
        # successful dataset paired with its own weight.
        self.datasets = []
        self.probabilities = []

        for name, weight in zip(dataset_names, weights):
            if name not in PRETRAIN_DATASETS:
                logger.warning(f"Dataset {name} not found in PRETRAIN_DATASETS, skipping")
                continue

            config = PRETRAIN_DATASETS[name]
            try:
                ds = self._load_dataset(config)
                if ds is not None:
                    self.datasets.append((name, ds, config))
                    self.probabilities.append(weight)
                    logger.info(f" Successfully loaded {name}")
            except Exception as e:
                logger.error(f"Error loading {name}: {e}")
                continue

        if not self.datasets:
            raise ValueError("No datasets loaded successfully")

        # Normalize the sampling probabilities so they sum to 1
        total = sum(self.probabilities)
        self.probabilities = [p / total for p in self.probabilities]

        logger.info(f"Successfully loaded {len(self.datasets)} datasets")

    def _load_dataset(self, config: Dict):
        """Load one HuggingFace dataset described by `config`; None on failure."""
        try:
            load_kwargs = {
                'path': config['hf_path'],
                'split': config.get('split', 'train'),
                'streaming': config.get('streaming', self.streaming),
                'cache_dir': HF_CACHE_DIR,
            }

            # Pass the dataset sub-config ("name" in HF terms) when present
            if 'config' in config:
                load_kwargs['name'] = config['config']

            ds = load_dataset(**load_kwargs)
            return ds
        except Exception as e:
            logger.error(f"Failed to load {config.get('hf_path', 'unknown')}: {e}")
            return None

    def _process_text_sample(self, sample: Dict, config: Dict) -> Optional[Dict]:
        """Tokenize a text/code sample; returns None if the sample is unusable.

        Samples that are empty, non-string, or shorter than 10 characters
        after stripping are dropped.
        """
        try:
            text_field = config.get('text_field', 'text')
            text = sample.get(text_field, '')

            if not text or not isinstance(text, str):
                return None

            text = text.strip()
            if len(text) < 10:
                return None

            # Tokenize, padding/truncating to the fixed max_length
            encoding = self.tokenizer(
                text,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'type': 'text'
            }
        except Exception as e:
            logger.debug(f"Error processing text sample: {e}")
            return None

    def _process_image_text_sample(self, sample: Dict, config: Dict) -> Optional[Dict]:
        """Process an image-text pair; returns None if either part is unusable.

        The image field may be a URL (fetched with a 5s timeout) or a PIL
        image; anything else is rejected.
        """
        try:
            text_field = config.get('text_field', 'caption')
            image_field = config.get('image_field', 'image')

            text = sample.get(text_field, '')
            image = sample.get(image_field)

            if not text or image is None:
                return None

            # Resolve the image
            if isinstance(image, str):
                # URL — fetch with timeout and error handling
                try:
                    response = requests.get(image, timeout=5)
                    image = Image.open(BytesIO(response.content)).convert('RGB')
                except Exception as img_error:
                    logger.debug(f"Failed to load image from URL: {img_error}")
                    return None
            elif isinstance(image, Image.Image):
                image = image.convert('RGB')
            else:
                return None

            # Apply the shared image transform (resize + normalize)
            image_tensor = image_transform(image)

            # Tokenize the caption text
            encoding = self.tokenizer(
                text,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'image': image_tensor,
                'type': 'image_text'
            }
        except Exception as e:
            logger.debug(f"Error processing image-text sample: {e}")
            return None

    def __iter__(self):
        """Yield processed samples, drawn by weighted random dataset choice.

        Exhausted member datasets are restarted in place, so iteration only
        ends at max_samples (if set) or on an unrecoverable error.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            # Give each dataloader worker a distinct random seed
            random.seed(self.seed + worker_info.id)
            np.random.seed(self.seed + worker_info.id)
        else:
            random.seed(self.seed)
            np.random.seed(self.seed)

        # One iterator per underlying dataset
        iterators = [iter(ds) for _, ds, _ in self.datasets]
        self.samples_generated = 0

        while True:
            # Stop once the requested number of samples has been produced
            if self.max_samples and self.samples_generated >= self.max_samples:
                break

            try:
                # Pick a dataset according to the mixture probabilities
                idx = np.random.choice(len(self.datasets), p=self.probabilities)
                name, _, config = self.datasets[idx]

                # Draw the next sample from the chosen dataset
                sample = next(iterators[idx])

                # Dispatch on the dataset's declared sample type
                processed = None
                if config.get('type') in ['text', 'code']:
                    processed = self._process_text_sample(sample, config)
                elif config.get('type') == 'image_text':
                    processed = self._process_image_text_sample(sample, config)
                else:
                    logger.debug(f"Unknown type: {config.get('type')}")
                    continue

                if processed is not None:
                    self.samples_generated += 1
                    yield processed

            except StopIteration:
                # Dataset exhausted — recreate its iterator to cycle again
                try:
                    iterators[idx] = iter(self.datasets[idx][1])
                except Exception as e:
                    logger.error(f"Failed to recreate iterator for dataset {idx}: {e}")
                    break
            except Exception as e:
                logger.debug(f"Error in iterator: {e}")
                continue
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class PostTrainDataset(Dataset):
|
| 265 |
+
"""后训练数据集 - Instruction tuning和对话"""
|
| 266 |
+
def __init__(
    self,
    mix_name: str = 'default',
    tokenizer=None,
    max_length: int = 2048,
    max_samples: Optional[int] = None,
    split: str = 'train'
):
    """Load and merge the instruction-tuning datasets of a named mix.

    Args:
        mix_name: Key into POSTTRAIN_MIX selecting which datasets to combine.
        tokenizer: HF-style tokenizer used later by __getitem__; not None.
        max_length: Maximum token length per example.
        max_samples: Optional cap on the total number of samples.
        split: Dataset split to load (e.g. 'train').

    Raises:
        ValueError: If the tokenizer is missing, the mix is unknown, the mix
            lists no datasets, or no dataset could be loaded.
    """
    super().__init__()

    if tokenizer is None:
        raise ValueError("tokenizer cannot be None")

    self.tokenizer = tokenizer
    self.max_length = max_length
    self.split = split

    # Resolve the mixture configuration
    if mix_name not in POSTTRAIN_MIX:
        raise ValueError(f"Unknown mix: {mix_name}. Available: {list(POSTTRAIN_MIX.keys())}")

    mix_config = POSTTRAIN_MIX[mix_name]
    dataset_names = mix_config.get('datasets', [])
    weights = mix_config.get('weights', [])

    if not dataset_names:
        raise ValueError(f"No datasets found in mix: {mix_name}")

    logger.info(f"Loading posttrain mix: {mix_name}")
    logger.info(f" Datasets: {dataset_names}")

    # Load and merge the member datasets.
    all_datasets = []
    # BUGFIX: record each successfully loaded dataset's own weight. The
    # original used weights[:len(all_datasets)], which silently misassigns
    # weights whenever a dataset is skipped or fails to load.
    loaded_weights = []

    for name, weight in zip(dataset_names, weights):
        if name not in POSTTRAIN_DATASETS:
            logger.warning(f"Dataset {name} not found in POSTTRAIN_DATASETS")
            continue

        config = POSTTRAIN_DATASETS[name]
        try:
            load_kwargs = {
                'path': config['hf_path'],
                'split': split,
                'streaming': config.get('streaming', False),
                'cache_dir': HF_CACHE_DIR,
            }
            # Local-file loaders (e.g. 'json') need data_files
            if 'data_files' in config:
                load_kwargs['data_files'] = config['data_files']
            # Pass the dataset sub-config ("name" in HF terms) when present
            if 'config' in config:
                load_kwargs['name'] = config['config']

            ds = load_dataset(**load_kwargs)

            # Cap the per-dataset sample count
            if config.get('max_samples'):
                if hasattr(ds, 'take'):
                    ds = ds.take(config['max_samples'])
                elif hasattr(ds, 'select'):
                    ds = ds.select(range(min(len(ds), config['max_samples'])))

            # Tag every example with its source dataset and config.
            # BUGFIX: bind name/config as default arguments so a lazily
            # evaluated map (streaming datasets) does not late-bind the
            # loop variables.
            def add_source(example, _name=name, _config=config):
                example['_source'] = _name
                example['_config'] = _config
                return example

            ds = ds.map(add_source)
            all_datasets.append(ds)
            loaded_weights.append(weight)

            ds_len = len(ds) if hasattr(ds, '__len__') else 'streaming'
            logger.info(f" Loaded {name}: {ds_len} samples")

        except Exception as e:
            logger.error(f"Error loading {name}: {e}")
            continue

    if not all_datasets:
        raise ValueError("No datasets loaded successfully")

    if len(all_datasets) == 1:
        self.dataset = all_datasets[0]
    else:
        # Interleave according to the normalized weights of the datasets
        # that actually loaded.
        total = sum(loaded_weights)
        probabilities = [w / total for w in loaded_weights]
        self.dataset = interleave_datasets(
            all_datasets,
            probabilities=probabilities,
            seed=42,
            stopping_strategy='all_exhausted'
        )

    # Cap the overall sample count (map-style datasets only)
    if max_samples and hasattr(self.dataset, '__len__'):
        actual_len = min(len(self.dataset), max_samples)
        self.dataset = self.dataset.select(range(actual_len))

    dataset_len = len(self.dataset) if hasattr(self.dataset, '__len__') else 'streaming'
    logger.info(f"Total samples: {dataset_len}")
|
| 369 |
+
|
| 370 |
+
def _format_instruction(self, sample: Dict, config: Dict) -> str:
|
| 371 |
+
"""格式化instruction"""
|
| 372 |
+
try:
|
| 373 |
+
data_type = config.get('type', 'instruction')
|
| 374 |
+
|
| 375 |
+
if data_type == 'instruction':
|
| 376 |
+
instruction_field = config.get('instruction_field', 'instruction')
|
| 377 |
+
input_field = config.get('input_field', 'input')
|
| 378 |
+
context_field = config.get('context_field', None)
|
| 379 |
+
|
| 380 |
+
instruction = sample.get(instruction_field, '')
|
| 381 |
+
input_text = sample.get(input_field, '')
|
| 382 |
+
context = sample.get(context_field, '') if context_field else ''
|
| 383 |
+
|
| 384 |
+
# 构建prompt
|
| 385 |
+
prompt_parts = [f"Instruction: {instruction}"]
|
| 386 |
+
|
| 387 |
+
if context:
|
| 388 |
+
prompt_parts.append(f"Context: {context}")
|
| 389 |
+
|
| 390 |
+
if input_text:
|
| 391 |
+
prompt_parts.append(f"Input: {input_text}")
|
| 392 |
+
|
| 393 |
+
prompt_parts.append("Response:")
|
| 394 |
+
return "\n".join(prompt_parts)
|
| 395 |
+
|
| 396 |
+
elif data_type == 'conversation':
|
| 397 |
+
# 处理对话格式 - 支持不同的对话格式
|
| 398 |
+
if 'conversations' in sample:
|
| 399 |
+
# LLaVA格式
|
| 400 |
+
conversations = sample['conversations']
|
| 401 |
+
if isinstance(conversations, list) and len(conversations) > 0:
|
| 402 |
+
dialogue = []
|
| 403 |
+
for conv in conversations[:-1]:
|
| 404 |
+
role = conv.get('from', 'user')
|
| 405 |
+
content = conv.get('value', '')
|
| 406 |
+
dialogue.append(f"{role}: {content}")
|
| 407 |
+
return "\n".join(dialogue) + "\nassistant:"
|
| 408 |
+
|
| 409 |
+
elif 'messages' in sample:
|
| 410 |
+
# 标准消息格式
|
| 411 |
+
messages = sample['messages']
|
| 412 |
+
if isinstance(messages, list) and len(messages) > 0:
|
| 413 |
+
dialogue = []
|
| 414 |
+
for msg in messages[:-1]:
|
| 415 |
+
role = msg.get('role', 'user')
|
| 416 |
+
content = msg.get('content', '')
|
| 417 |
+
dialogue.append(f"{role}: {content}")
|
| 418 |
+
return "\n".join(dialogue) + "\nassistant:"
|
| 419 |
+
|
| 420 |
+
# 如果没有标准格式,尝试使用text字段
|
| 421 |
+
return sample.get('text', '')
|
| 422 |
+
|
| 423 |
+
elif data_type == 'code_instruction':
|
| 424 |
+
# 代码instruction格式
|
| 425 |
+
instruction_field = config.get('instruction_field', 'instruction')
|
| 426 |
+
instruction = sample.get(instruction_field, '')
|
| 427 |
+
return f"### Instruction:\n{instruction}\n### Response:"
|
| 428 |
+
|
| 429 |
+
elif data_type == 'multimodal_instruction':
|
| 430 |
+
# 多模态instruction
|
| 431 |
+
instruction_field = config.get('instruction_field', 'conversations')
|
| 432 |
+
conversations = sample.get(instruction_field, [])
|
| 433 |
+
if isinstance(conversations, list) and len(conversations) > 0:
|
| 434 |
+
# 提取对话历史(除了最后一条回复)
|
| 435 |
+
dialogue = []
|
| 436 |
+
for conv in conversations[:-1]:
|
| 437 |
+
role = conv.get('from', 'user')
|
| 438 |
+
content = conv.get('value', '')
|
| 439 |
+
dialogue.append(f"{role}: {content}")
|
| 440 |
+
return "\n".join(dialogue) + "\nassistant:"
|
| 441 |
+
return ""
|
| 442 |
+
|
| 443 |
+
else:
|
| 444 |
+
return sample.get(config.get('instruction_field', 'text'), '')
|
| 445 |
+
except Exception as e:
|
| 446 |
+
logger.debug(f"Error formatting instruction: {e}")
|
| 447 |
+
return ""
|
| 448 |
+
|
| 449 |
+
def _get_response(self, sample: Dict, config: Dict) -> str:
|
| 450 |
+
"""获取响应"""
|
| 451 |
+
try:
|
| 452 |
+
data_type = config.get('type', 'instruction')
|
| 453 |
+
|
| 454 |
+
if data_type == 'instruction' or data_type == 'code_instruction':
|
| 455 |
+
response_field = config.get('response_field', 'output')
|
| 456 |
+
return sample.get(response_field, '')
|
| 457 |
+
|
| 458 |
+
elif data_type == 'conversation':
|
| 459 |
+
# 从对话中提取最后一条assistant的回复
|
| 460 |
+
if 'conversations' in sample:
|
| 461 |
+
conversations = sample['conversations']
|
| 462 |
+
if isinstance(conversations, list) and len(conversations) > 0:
|
| 463 |
+
return conversations[-1].get('value', '')
|
| 464 |
+
|
| 465 |
+
elif 'messages' in sample:
|
| 466 |
+
messages = sample['messages']
|
| 467 |
+
if isinstance(messages, list) and len(messages) > 0:
|
| 468 |
+
return messages[-1].get('content', '')
|
| 469 |
+
|
| 470 |
+
return ""
|
| 471 |
+
|
| 472 |
+
elif data_type == 'multimodal_instruction':
|
| 473 |
+
instruction_field = config.get('instruction_field', 'conversations')
|
| 474 |
+
conversations = sample.get(instruction_field, [])
|
| 475 |
+
if isinstance(conversations, list) and len(conversations) > 0:
|
| 476 |
+
return conversations[-1].get('value', '')
|
| 477 |
+
return ""
|
| 478 |
+
|
| 479 |
+
else:
|
| 480 |
+
response_field = config.get('response_field', 'output')
|
| 481 |
+
return sample.get(response_field, '')
|
| 482 |
+
except Exception as e:
|
| 483 |
+
logger.debug(f"Error getting response: {e}")
|
| 484 |
+
return ""
|
| 485 |
+
|
| 486 |
+
def __len__(self):
|
| 487 |
+
return len(self.dataset) if hasattr(self.dataset, '__len__') else 0
|
| 488 |
+
|
| 489 |
+
    def __getitem__(self, idx):
        """Build one SFT training example from the mixed dataset.

        Returns a dict with fixed-length ``instruction`` / ``response`` id
        tensors (each padded or truncated to ``max_length // 2``), their
        attention masks, the originating task name, and optional image
        data — or ``None`` when the sample is unusable (the collate
        function is expected to filter ``None`` out).
        """
        try:
            sample = self.dataset[idx]

            # Every sample must carry the per-source config injected at load time.
            if '_config' not in sample:
                logger.warning(f"Sample at index {idx} missing _config")
                return None

            config = sample['_config']

            # Format the instruction and extract the response text.
            instruction_text = self._format_instruction(sample, config)
            response_text = self._get_response(sample, config)

            if not instruction_text or not response_text:
                return None

            # Ensure a pad token id exists (fall back to EOS if undefined).
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = self.tokenizer.eos_token_id

            # =======================================================
            # 1. Instruction (no EOS needed — the response follows it)
            # =======================================================
            instruction_max_len = self.max_length // 2

            # Tokenize without padding; padding is handled manually below.
            instruction_enc = self.tokenizer(
                instruction_text,
                truncation=True,
                max_length=instruction_max_len,
                add_special_tokens=False,  # special tokens are controlled manually
                return_tensors='pt'
            )
            instr_ids = instruction_enc['input_ids'].squeeze(0)

            # Manual right-padding of the instruction to a fixed length.
            instr_len = instr_ids.size(0)
            if instr_len < instruction_max_len:
                padding = torch.full((instruction_max_len - instr_len,), pad_token_id, dtype=torch.long)
                instr_ids = torch.cat([instr_ids, padding])

                # Mask: 1 for real tokens, 0 for padding.
                instr_mask = torch.cat([torch.ones(instr_len, dtype=torch.long), torch.zeros(instruction_max_len - instr_len, dtype=torch.long)])
            else:
                instr_mask = torch.ones(instruction_max_len, dtype=torch.long)

            # =======================================================
            # 2. Response (must always end with an EOS token)
            # =======================================================
            response_max_len = self.max_length // 2

            # Tokenize, reserving one slot for the EOS token.
            response_enc = self.tokenizer(
                response_text,
                truncation=True,
                max_length=response_max_len - 1,  # keep one position free for EOS
                add_special_tokens=False,
                return_tensors='pt'
            )
            resp_ids = response_enc['input_ids'].squeeze(0)

            # Force-append EOS so the model learns where generation stops.
            eos_token = torch.tensor([self.tokenizer.eos_token_id], dtype=torch.long)
            resp_ids = torch.cat([resp_ids, eos_token])

            # Manual right-padding of the response.
            curr_resp_len = resp_ids.size(0)
            if curr_resp_len < response_max_len:
                padding = torch.full((response_max_len - curr_resp_len,), pad_token_id, dtype=torch.long)
                resp_ids = torch.cat([resp_ids, padding])

                # Mask: real content + EOS are 1, padding is 0.
                resp_mask = torch.cat([torch.ones(curr_resp_len, dtype=torch.long), torch.zeros(response_max_len - curr_resp_len, dtype=torch.long)])
            else:
                resp_mask = torch.ones(response_max_len, dtype=torch.long)

            # =======================================================
            # 3. Assemble the result
            # =======================================================
            result = {
                'instruction': instr_ids,
                'response': resp_ids,
                'instruction_mask': instr_mask,
                'response_mask': resp_mask,
                'task': sample.get('_source', 'unknown'),
                'modality_data': None
            }

            # For multimodal samples, attach the transformed image tensor.
            if config.get('type') == 'multimodal_instruction' and 'image' in sample:
                try:
                    image = sample['image']
                    if isinstance(image, Image.Image):
                        image = image.convert('RGB')
                        image_tensor = image_transform(image)
                        result['modality_data'] = {'image': image_tensor}
                except Exception as e:
                    logger.debug(f"Error processing image: {e}")

            return result

        except Exception as e:
            logger.debug(f"Error getting item at index {idx}: {e}")
            import traceback
            traceback.print_exc()
            return None
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
class PreferenceDataset(Dataset):
    """Preference-pair dataset for RLHF reward-model / DPO-style training.

    Loads a Hugging Face dataset registered in ``POSTTRAIN_DATASETS`` with
    ``type == 'preference'`` and yields, per index, a 4-tuple of tensors:
    (chosen_ids, rejected_ids, chosen_mask, rejected_mask), each of length
    ``max_length`` — or ``None`` when a pair is unusable.
    """
    def __init__(
        self,
        dataset_name: str = 'hh_rlhf',
        tokenizer=None,
        max_length: int = 1024,
        max_samples: Optional[int] = None,
        split: str = 'train'
    ):
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length

        # The dataset must be registered and declared as preference data.
        if dataset_name not in POSTTRAIN_DATASETS:
            raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(POSTTRAIN_DATASETS.keys())}")

        config = POSTTRAIN_DATASETS[dataset_name]
        if config.get('type') != 'preference':
            raise ValueError(f"{dataset_name} is not a preference dataset (type: {config.get('type')})")

        logger.info(f"Loading preference dataset: {dataset_name}")

        load_kwargs = {
            'path': config['hf_path'],
            'split': split,
            'cache_dir': HF_CACHE_DIR,
        }

        # Pass the HF subset/config name through when the registry defines one.
        if 'config' in config:
            load_kwargs['name'] = config['config']

        self.dataset = load_dataset(**load_kwargs)

        # Field names for the preferred / dispreferred completions.
        self.chosen_field = config.get('chosen_field', 'chosen')
        self.rejected_field = config.get('rejected_field', 'rejected')

        # Optionally cap the dataset size (useful for debugging runs).
        if max_samples and len(self.dataset) > max_samples:
            self.dataset = self.dataset.select(range(max_samples))

        logger.info(f"Loaded {len(self.dataset)} preference pairs")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        try:
            sample = self.dataset[idx]

            chosen_text = sample.get(self.chosen_field, '')
            rejected_text = sample.get(self.rejected_field, '')

            # Skip pairs with a missing side; collate must filter None.
            if not chosen_text or not rejected_text:
                return None

            # Tokenize both sides to the same fixed length.
            chosen_enc = self.tokenizer(
                chosen_text,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            rejected_enc = self.tokenizer(
                rejected_text,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            return (
                chosen_enc['input_ids'].squeeze(0),
                rejected_enc['input_ids'].squeeze(0),
                chosen_enc['attention_mask'].squeeze(0),
                rejected_enc['attention_mask'].squeeze(0)
            )

        except Exception as e:
            logger.debug(f"Error getting preference item at index {idx}: {e}")
            return None
|
| 691 |
+
|
| 692 |
+
|
| 693 |
+
def collate_fn_v2(batch):
    """Collate a batch of SFT dicts or preference tuples.

    ``None`` entries (unusable samples) are dropped first. Preference
    samples arrive as 4-tuples (with masks) or legacy 2-tuples and are
    stacked into a dict; SFT samples arrive as dicts whose tensor-valued
    keys are stacked, whose ``modality_data`` images are stacked into a
    sub-dict, and whose remaining keys are collected into lists.
    """
    batch = [sample for sample in batch if sample is not None]

    if not batch:
        logger.warning("Empty batch after filtering None values")
        # Return an empty placeholder batch rather than None so callers
        # can detect and skip it uniformly.
        return {
            'input_ids': torch.empty(0),
            'attention_mask': torch.empty(0)
        }

    first = batch[0]

    # Preference data arrives as tuples.
    if isinstance(first, tuple):
        if len(first) == 4:  # includes attention masks
            return {
                'chosen': torch.stack([pair[0] for pair in batch]),
                'rejected': torch.stack([pair[1] for pair in batch]),
                'chosen_mask': torch.stack([pair[2] for pair in batch]),
                'rejected_mask': torch.stack([pair[3] for pair in batch]),
            }
        # Legacy format without masks.
        return {
            'chosen': torch.stack([pair[0] for pair in batch]),
            'rejected': torch.stack([pair[1] for pair in batch]),
        }

    # Regular (dict) data.
    tensor_keys = {'instruction', 'response', 'instruction_mask',
                   'response_mask', 'input_ids', 'attention_mask'}
    collated = {}

    for key in first.keys():
        if key in tensor_keys:
            values = [sample[key] for sample in batch if sample.get(key) is not None]
            collated[key] = torch.stack(values) if values else None
        elif key == 'modality_data':
            # Gather per-sample modality dicts and stack any images.
            entries = [sample[key] for sample in batch if sample.get(key) is not None]
            if entries and any(entry is not None for entry in entries):
                images = [entry.get('image') for entry in entries if entry and 'image' in entry]
                collated[key] = {'image': torch.stack(images)} if images else None
            else:
                collated[key] = None
        else:
            # Non-tensor metadata (e.g. task names) stays as a plain list.
            collated[key] = [sample[key] for sample in batch]

    return collated
|
| 752 |
+
|
| 753 |
+
|
| 754 |
+
def create_pretrain_dataloader(
    mix_name: str = 'default',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 2048,
    max_samples: Optional[int] = None
):
    """Build a streaming DataLoader over the pre-training mixture ``mix_name``."""
    stream_dataset = PreTrainDataset(
        mix_name=mix_name,
        tokenizer=tokenizer,
        max_length=max_length,
        streaming=True,
        max_samples=max_samples,
    )
    # Streaming (iterable) dataset: no shuffle/pin_memory options here.
    return DataLoader(
        stream_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn_v2,
    )
|
| 776 |
+
|
| 777 |
+
|
| 778 |
+
def create_posttrain_dataloader(
    mix_name: str = 'default',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 2048,
    max_samples: Optional[int] = None,
    split: str = 'train',
    shuffle: bool = True
):
    """Build a DataLoader over the post-training (SFT) mixture ``mix_name``."""
    sft_dataset = PostTrainDataset(
        mix_name=mix_name,
        tokenizer=tokenizer,
        max_length=max_length,
        max_samples=max_samples,
        split=split,
    )
    return DataLoader(
        sft_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn_v2,
        pin_memory=True,
        drop_last=False,  # keep the final (possibly short) batch
    )
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
def create_preference_dataloader(
    dataset_name: str = 'hh_rlhf',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 1024,
    max_samples: Optional[int] = None,
    split: str = 'train',
    shuffle: bool = True
):
    """Build a DataLoader over a registered RLHF preference-pair dataset."""
    pref_dataset = PreferenceDataset(
        dataset_name=dataset_name,
        tokenizer=tokenizer,
        max_length=max_length,
        max_samples=max_samples,
        split=split,
    )
    return DataLoader(
        pref_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn_v2,
        pin_memory=True,
    )
|
encoders.py
ADDED
|
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
改进的多模态编码器 - SOTA级别(修复版)
|
| 3 |
+
集成最新的视觉、音频、视频编码技术
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from typing import Tuple, Optional
|
| 9 |
+
from components import RMSNorm, SwiGLU
|
| 10 |
+
from transformer import OptimizedTransformerBlock
|
| 11 |
+
import math
|
| 12 |
+
|
| 13 |
+
class LayerScale(nn.Module):
    """Per-channel learnable scaling of the last dimension (CaiT-style).

    The scale vector is initialised to a small constant so residual
    branches start near-identity, which stabilises training of deep
    transformer stacks.
    """

    def __init__(self, dim: int, init_values: float = 1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.full((dim,), init_values))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gamma * x
|
| 21 |
+
|
| 22 |
+
class StochasticDepth(nn.Module):
    """Drop-path regularisation: randomly zero out whole residual branches.

    In training mode each batch element is kept with probability
    ``1 - drop_prob`` and rescaled by ``1 / keep_prob`` so the expected
    output is unchanged. In eval mode (or with drop_prob == 0) this is a
    no-op.
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0.0 or not self.training:
            return x

        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over remaining dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device).add_(keep_prob).floor_()
        return x.div(keep_prob) * mask
|
| 37 |
+
|
| 38 |
+
class ImprovedPatchEmbedding(nn.Module):
    """Image-to-patch embedding with optional overlapping patches.

    A strided Conv2d projects each (possibly overlapping) patch to
    ``embed_dim`` channels; the feature map is flattened into a token
    sequence and RMS-normalised. Returns the tokens together with the
    (H', W') patch grid so callers can interpolate position embeddings.
    """

    def __init__(
        self,
        patch_size: int = 14,
        in_channels: int = 3,
        embed_dim: int = 2048,
        overlap: int = 0
    ):
        super().__init__()
        self.patch_size = patch_size
        # Overlap shrinks the stride and adds half the overlap as padding.
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size - overlap,
            padding=overlap // 2,
        )
        self.norm = RMSNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        _batch, _c, _h, _w = x.shape  # validates a 4-D [B, C, H, W] input
        feat = self.proj(x)  # [B, D, H', W']
        grid = (feat.shape[2], feat.shape[3])
        tokens = feat.flatten(2).transpose(1, 2)  # [B, H'*W', D]
        return self.norm(tokens), grid
|
| 67 |
+
|
| 68 |
+
class ImprovedVisionBlock(nn.Module):
    """Pre-norm ViT block: multi-head self-attention + 4x MLP.

    Both residual branches go through optional LayerScale and stochastic
    depth; an optional bottleneck adapter is applied residually on the
    block output. Submodule attribute names match the original layout so
    checkpoints stay compatible.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        dropout: float = 0.0,
        drop_path: float = 0.0,
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_layer_scale: bool = True,
        layer_scale_init: float = 1e-5
    ):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, dropout=dropout, batch_first=True)

        self.norm2 = RMSNorm(dim)
        hidden = dim * 4
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )

        # Shared drop-path for both residual branches (identity if disabled).
        self.drop_path = StochasticDepth(drop_path) if drop_path > 0 else nn.Identity()

        if use_layer_scale:
            self.ls1 = LayerScale(dim, layer_scale_init)
            self.ls2 = LayerScale(dim, layer_scale_init)
        else:
            self.ls1 = nn.Identity()
            self.ls2 = nn.Identity()

        # Lightweight bottleneck adapter, implemented locally to avoid
        # external dependencies.
        self.adapter = (
            nn.Sequential(
                nn.Linear(dim, adapter_dim),
                nn.GELU(),
                nn.Linear(adapter_dim, dim),
            )
            if use_adapter
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Self-attention branch (pre-norm).
        h = self.norm1(x)
        h, _ = self.attn(h, h, h)
        x = x + self.drop_path(self.ls1(h))

        # Feed-forward branch (pre-norm).
        x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))

        # Optional residual adapter.
        if self.adapter is not None:
            x = x + self.adapter(x)
        return x
|
| 129 |
+
|
| 130 |
+
class ImprovedVisionTransformer(nn.Module):
    """
    Vision Transformer with modern training refinements:
    - LayerScale on residual branches
    - Stochastic depth with a linear rate schedule
    - Position-embedding interpolation for variable input resolutions
    - Optional register tokens (DINOv2-inspired)

    ``forward`` returns the full token sequence
    [CLS, register_tokens..., patch_tokens...] after the final norm;
    callers choose between the CLS token and pooled patch tokens.
    """
    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 14,
        in_channels: int = 3,
        embed_dim: int = 2048,
        depth: int = 24,
        n_heads: int = 16,
        dropout: float = 0.0,
        drop_path_rate: float = 0.1,
        use_register_tokens: bool = True,
        num_register_tokens: int = 4,
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_layer_scale: bool = True,
        layer_scale_init: float = 1e-5
    ):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.use_register_tokens = use_register_tokens
        self.num_register_tokens = num_register_tokens if use_register_tokens else 0

        # Patch embedding (non-overlapping patches).
        self.patch_embed = ImprovedPatchEmbedding(
            patch_size, in_channels, embed_dim, overlap=0
        )

        self.pretrain_img_size = img_size
        n_patches_pretrain = (img_size // patch_size) ** 2

        # CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Register tokens (DINOv2-inspired scratch tokens).
        if use_register_tokens:
            self.register_tokens = nn.Parameter(
                torch.zeros(1, num_register_tokens, embed_dim)
            )

        # Total position embeddings = 1 (CLS) + patches + register tokens.
        total_tokens = 1 + n_patches_pretrain + self.num_register_tokens
        self.pos_embed = nn.Parameter(
            torch.zeros(1, total_tokens, embed_dim)
        )
        self.pos_drop = nn.Dropout(dropout)

        # Stochastic depth: drop rate increases linearly with depth.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        # Transformer blocks
        self.blocks = nn.ModuleList([
            ImprovedVisionBlock(
                embed_dim,
                n_heads,
                dropout,
                drop_path=dpr[i],
                use_adapter=use_adapter,
                adapter_dim=adapter_dim,
                use_layer_scale=use_layer_scale,
                layer_scale_init=layer_scale_init
            )
            for i in range(depth)
        ])

        self.norm = RMSNorm(embed_dim)
        self._init_weights()

    def _init_weights(self):
        # Truncated-normal init for learned tokens and position embeddings.
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.use_register_tokens:
            nn.init.trunc_normal_(self.register_tokens, std=0.02)

        self.apply(self._init_module_weights)

    def _init_module_weights(self, m):
        # Standard ViT init: trunc-normal weights, zero biases, unit norms.
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, RMSNorm):
            if hasattr(m, 'weight') and m.weight is not None:
                nn.init.ones_(m.weight)

    def _interpolate_pos_encoding(
        self,
        patch_tokens: torch.Tensor,
        grid_size: Tuple[int, int]
    ) -> torch.Tensor:
        """
        Interpolate position embeddings for a different input resolution.

        Only the patch position embeddings are 2D-interpolated; the CLS
        and register-token embeddings are left untouched.

        NOTE(review): assumes a square pretraining grid (H == W derived
        from ``pretrain_img_size``); ``patch_tokens`` is unused here.
        """
        pretrain_grid_h = self.pretrain_img_size // self.patch_size
        pretrain_grid_w = pretrain_grid_h

        # Fast path: input matches the pretraining grid, no interpolation.
        if grid_size[0] == pretrain_grid_h and grid_size[1] == pretrain_grid_w:
            return self.pos_embed

        # Split the embedding table.
        # pos_embed layout: [CLS(1), register_tokens(n), patches(H*W)]
        num_extra_tokens = 1 + self.num_register_tokens
        cls_register_pos = self.pos_embed[:, :num_extra_tokens, :]  # [1, 1+n, dim]
        patch_pos_embed = self.pos_embed[:, num_extra_tokens:, :]   # [1, H*W, dim]

        # Reshape to a 2D grid and bicubic-interpolate to the new size.
        patch_pos_embed = patch_pos_embed.reshape(
            1, pretrain_grid_h, pretrain_grid_w, -1
        ).permute(0, 3, 1, 2)

        patch_pos_embed = F.interpolate(
            patch_pos_embed,
            size=grid_size,
            mode='bicubic',
            align_corners=False
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2)

        # Re-attach the unmodified CLS/register embeddings.
        return torch.cat([cls_register_pos, patch_pos_embed], dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B = x.shape[0]

        # Patch embedding
        x, grid_size = self.patch_embed(x)

        # Prepend the CLS token.
        cls_tokens = self.cls_token.expand(B, -1, -1)

        # Assemble the token sequence.
        if self.use_register_tokens:
            register_tokens = self.register_tokens.expand(B, -1, -1)
            # Order: [CLS, register_tokens, patches]
            x = torch.cat([cls_tokens, register_tokens, x], dim=1)
        else:
            x = torch.cat([cls_tokens, x], dim=1)

        # Position embeddings (interpolated for non-default resolutions).
        pos_embed = self._interpolate_pos_encoding(x, grid_size)
        x = self.pos_drop(x + pos_embed)

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        # Return all tokens (caller may use CLS token or pool over patches).
        return x
|
| 295 |
+
|
| 296 |
+
class ImprovedAudioEncoder(nn.Module):
    """
    Audio encoder over mel spectrograms:
    - Transformer over patchified spectrogram
    - Optional dual-stream pooling (temporal stream + frequency stream)
      fused into a single summary token

    Input: mel spectrogram of shape [B, n_mels, target_length] or
    [B, 1, n_mels, target_length] — assumed to match the configured
    ``n_mels`` / ``target_length`` (the position embedding size is fixed
    at construction; TODO confirm against the feature pipeline).
    Output: a single token per clip, shape [B, 1, embed_dim].
    """
    def __init__(
        self,
        n_mels: int = 128,
        target_length: int = 1024,
        embed_dim: int = 2048,
        depth: int = 12,
        n_heads: int = 16,
        patch_size: int = 16,
        dropout: float = 0.1,
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_dual_stream: bool = True
    ):
        super().__init__()
        self.use_dual_stream = use_dual_stream
        self.patch_size = patch_size

        # Main encoder: non-overlapping 2D patches over the spectrogram.
        self.patch_embed = nn.Conv2d(
            1, embed_dim, kernel_size=patch_size, stride=patch_size
        )

        # Actual patch-grid dimensions (frequency x time).
        self.n_patches_h = n_mels // patch_size
        self.n_patches_w = target_length // patch_size
        n_patches = self.n_patches_h * self.n_patches_w

        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            OptimizedTransformerBlock(
                embed_dim, n_heads, None, None, dropout,
                use_adapter=use_adapter, adapter_dim=adapter_dim
            )
            for _ in range(depth)
        ])

        # Dual-stream heads: one temporal, one frequency.
        if use_dual_stream:
            self.temporal_pool = nn.AdaptiveAvgPool1d(1)
            self.frequency_pool = nn.AdaptiveAvgPool1d(1)

            self.temporal_proj = nn.Linear(embed_dim, embed_dim)
            self.frequency_proj = nn.Linear(embed_dim, embed_dim)

            self.fusion = nn.Linear(embed_dim * 2, embed_dim)

        self.norm = RMSNorm(embed_dim)
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_module_weights)

    def _init_module_weights(self, m):
        # Trunc-normal weights, zero biases for linear/conv layers.
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
        # Add a channel dimension for 3-D [B, H, W] input.
        if mel_spec.ndim == 3:
            mel_spec = mel_spec.unsqueeze(1)

        # Patch embedding
        x = self.patch_embed(mel_spec)  # [B, C, H, W]
        x = x.flatten(2).transpose(1, 2)  # [B, H*W, C]
        x = self.pos_drop(x + self.pos_embed)

        # Transformer encoding (blocks return extra outputs that are unused here).
        for block in self.blocks:
            x, _, _ = block(x)

        x = self.norm(x)

        # Dual-stream pooling.
        if self.use_dual_stream:
            B, N, C = x.shape

            # Restore the 2D (frequency x time) patch grid.
            x_2d = x.transpose(1, 2).reshape(B, C, self.n_patches_h, self.n_patches_w)

            # Temporal stream: pool over frequency (keep time), then over time.
            temporal = x_2d.mean(dim=2)  # [B, C, W]
            temporal = self.temporal_pool(temporal).squeeze(-1)  # [B, C]
            temporal = self.temporal_proj(temporal).unsqueeze(1)  # [B, 1, C]

            # Frequency stream: pool over time (keep frequency), then over frequency.
            frequency = x_2d.mean(dim=3)  # [B, C, H]
            frequency = self.frequency_pool(frequency).squeeze(-1)  # [B, C]
            frequency = self.frequency_proj(frequency).unsqueeze(1)  # [B, 1, C]

            # Fuse both streams into a single summary token.
            x = self.fusion(torch.cat([temporal, frequency], dim=-1))
        else:
            # Plain global average pooling.
            x = x.mean(dim=1, keepdim=True)

        return x
|
| 409 |
+
|
| 410 |
+
class ImprovedVideoEncoder(nn.Module):
|
| 411 |
+
"""
|
| 412 |
+
改进的视频编码器
|
| 413 |
+
- 因果时序建模
|
| 414 |
+
- 时空分离注意力
|
| 415 |
+
- 可选的3D卷积
|
| 416 |
+
"""
|
| 417 |
+
def __init__(
|
| 418 |
+
self,
|
| 419 |
+
img_size: int = 224,
|
| 420 |
+
patch_size: int = 14,
|
| 421 |
+
in_channels: int = 3,
|
| 422 |
+
embed_dim: int = 2048,
|
| 423 |
+
spatial_depth: int = 12,
|
| 424 |
+
temporal_depth: int = 4,
|
| 425 |
+
n_heads: int = 16,
|
| 426 |
+
num_frames: int = 16,
|
| 427 |
+
dropout: float = 0.1,
|
| 428 |
+
use_adapter: bool = False,
|
| 429 |
+
adapter_dim: int = 64,
|
| 430 |
+
use_3d_conv: bool = False
|
| 431 |
+
):
|
| 432 |
+
super().__init__()
|
| 433 |
+
self.num_frames = num_frames
|
| 434 |
+
self.use_3d_conv = use_3d_conv
|
| 435 |
+
self.patch_size = patch_size
|
| 436 |
+
self.img_size = img_size
|
| 437 |
+
|
| 438 |
+
if use_3d_conv:
|
| 439 |
+
# 3D卷积处理时空信息
|
| 440 |
+
self.patch_embed = nn.Conv3d(
|
| 441 |
+
in_channels,
|
| 442 |
+
embed_dim,
|
| 443 |
+
kernel_size=(2, patch_size, patch_size),
|
| 444 |
+
stride=(2, patch_size, patch_size)
|
| 445 |
+
)
|
| 446 |
+
# 修复:计算3D卷积后的尺寸
|
| 447 |
+
self.n_temporal_patches = num_frames // 2
|
| 448 |
+
self.n_spatial_patches = (img_size // patch_size) ** 2
|
| 449 |
+
else:
|
| 450 |
+
# 2D卷积 + 时序建模
|
| 451 |
+
self.patch_embed = ImprovedPatchEmbedding(
|
| 452 |
+
patch_size, in_channels, embed_dim
|
| 453 |
+
)
|
| 454 |
+
self.n_spatial_patches = (img_size // patch_size) ** 2
|
| 455 |
+
|
| 456 |
+
# 空间位置编码
|
| 457 |
+
self.spatial_pos_embed = nn.Parameter(
|
| 458 |
+
torch.zeros(1, self.n_spatial_patches, embed_dim)
|
| 459 |
+
)
|
| 460 |
+
self.spatial_pos_drop = nn.Dropout(dropout)
|
| 461 |
+
|
| 462 |
+
# 空间编码器
|
| 463 |
+
self.spatial_blocks = nn.ModuleList([
|
| 464 |
+
OptimizedTransformerBlock(
|
| 465 |
+
embed_dim, n_heads, None, None, dropout,
|
| 466 |
+
use_adapter=use_adapter, adapter_dim=adapter_dim
|
| 467 |
+
)
|
| 468 |
+
for _ in range(spatial_depth)
|
| 469 |
+
])
|
| 470 |
+
|
| 471 |
+
# 时间位置编码
|
| 472 |
+
if use_3d_conv:
|
| 473 |
+
self.temporal_pos_embed = nn.Parameter(
|
| 474 |
+
torch.zeros(1, self.n_temporal_patches, embed_dim)
|
| 475 |
+
)
|
| 476 |
+
else:
|
| 477 |
+
self.temporal_pos_embed = nn.Parameter(
|
| 478 |
+
torch.zeros(1, num_frames, embed_dim)
|
| 479 |
+
)
|
| 480 |
+
self.temporal_pos_drop = nn.Dropout(dropout)
|
| 481 |
+
|
| 482 |
+
# 时序编码器
|
| 483 |
+
self.temporal_blocks = nn.ModuleList([
|
| 484 |
+
OptimizedTransformerBlock(
|
| 485 |
+
embed_dim, n_heads, None, None, dropout,
|
| 486 |
+
use_adapter=use_adapter, adapter_dim=adapter_dim
|
| 487 |
+
)
|
| 488 |
+
for _ in range(temporal_depth)
|
| 489 |
+
])
|
| 490 |
+
|
| 491 |
+
self.norm = RMSNorm(embed_dim)
|
| 492 |
+
self._init_weights()
|
| 493 |
+
|
| 494 |
+
def _init_weights(self):
    """Initialize the learned position embeddings and all submodules.

    Both positional embeddings use a truncated normal init (std=0.02),
    following the ViT convention; per-module init is delegated to
    ``_init_module_weights`` via ``Module.apply``.
    """
    for pos_embed in (self.spatial_pos_embed, self.temporal_pos_embed):
        nn.init.trunc_normal_(pos_embed, std=0.02)
    self.apply(self._init_module_weights)
|
| 498 |
+
|
| 499 |
+
def _init_module_weights(self, m):
    """Per-module weight initializer used by ``self.apply``.

    Linear, Conv2d and Conv3d layers all receive the same treatment:
    truncated-normal weights (std=0.02) and zero biases. Other module
    types are left untouched.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
|
| 508 |
+
|
| 509 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Encode a video clip into a sequence of per-frame features.

    Args:
        x: video tensor of shape [B, T, C, H, W]
           (batch, frames, channels, height, width).

    Returns:
        RMS-normalized temporal features of shape [B, T', embed_dim],
        where T' == T // 2 on the 3D-conv path (temporal stride 2 in the
        Conv3d patch embed) and T' == T on the 2D-conv path.
    """
    B, T, C, H, W = x.shape

    if self.use_3d_conv:
        # 3D-conv path: Conv3d patchifies space and time jointly.
        x = x.transpose(1, 2)  # [B, C, T, H, W]
        x = self.patch_embed(x)  # [B, embed_dim, T', H', W']

        # Reshape: [B, D, T', H'*W'] -> [B, T', H'*W', D]
        B, D, T_new, H_new, W_new = x.shape
        x = x.view(B, D, T_new, -1).permute(0, 2, 3, 1)  # [B, T', H'*W', D]

        # Spatial position embedding, broadcast across frames.
        # NOTE(review): unlike the 2D path, no spatial_pos_drop is applied
        # here — confirm whether that asymmetry is intentional.
        x = x + self.spatial_pos_embed.unsqueeze(1)

        # Frame-wise spatial encoding: fold frames into the batch dimension
        # so each frame's patches are attended over independently.
        x_flat = x.reshape(B * T_new, -1, D)
        for block in self.spatial_blocks:
            x_flat, _, _ = block(x_flat)

        # Restore the temporal dimension.
        x = x_flat.view(B, T_new, -1, D)

        # Temporal aggregation: mean-pool each frame's patch tokens
        # (rather than taking only the first token).
        x = x.mean(dim=2)  # [B, T', D]

    else:
        # 2D-conv path: patchify every frame independently, then model
        # the temporal dimension separately.
        x_flat = x.view(B * T, C, H, W)
        x_patched, grid_size = self.patch_embed(x_flat)

        # Spatial position embedding (with dropout).
        x_patched = self.spatial_pos_drop(x_patched + self.spatial_pos_embed)

        # Spatial encoding, with all frames folded into the batch dim.
        for block in self.spatial_blocks:
            x_patched, _, _ = block(x_patched)

        # Temporal aggregation: global average pooling over each frame's
        # patch tokens.
        _, N, D = x_patched.shape
        x_spatial = x_patched.view(B, T, N, D)
        x = x_spatial.mean(dim=2)  # [B, T, D] - mean over all patches per frame

    # Temporal position embedding (with dropout).
    x = self.temporal_pos_drop(x + self.temporal_pos_embed)

    # Temporal encoding across frames.
    for block in self.temporal_blocks:
        x, _, _ = block(x)

    return self.norm(x)
|
gradio1.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio 推理界面 - 多模态 Dense Transformer (适配 Qwen Tokenizer 版)
|
| 3 |
+
|
| 4 |
+
用法:
|
| 5 |
+
pip install -r requirements.txt
|
| 6 |
+
# requirements.txt 至少包含:
|
| 7 |
+
# torch>=1.12, transformers, pillow, gradio
|
| 8 |
+
python app_gradio.py --checkpoint /path/to/final_model.pt --tokenizer Qwen/Qwen2.5-7B-Instruct --port 7860 --share False
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import argparse
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
import json
|
| 15 |
+
from typing import Optional
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from PIL import Image
|
| 19 |
+
from transformers import AutoTokenizer
|
| 20 |
+
|
| 21 |
+
# UI
|
| 22 |
+
import gradio as gr
|
| 23 |
+
|
| 24 |
+
# 本项目代码引用(按你的工程结构调整)
|
| 25 |
+
from model import MultiModalDenseTransformer
|
| 26 |
+
from continual_learning import UnifiedMultiModalPreprocessor
|
| 27 |
+
|
| 28 |
+
# 设置国内镜像(如需要)
|
| 29 |
+
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
| 30 |
+
|
| 31 |
+
# ---- Image preprocessing (kept identical to the training pipeline) ----
from torchvision import transforms

# Standard ImageNet preprocessing: resize to the model's fixed 224x224
# input, convert to a float tensor, and normalize with the usual ImageNet
# per-channel mean/std.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
|
| 39 |
+
|
| 40 |
+
# -------- ModelInference 类(轻微改写) --------
|
| 41 |
+
class ModelInference:
    """Wraps model + tokenizer loading and text generation for the UI.

    Loads a Qwen tokenizer, instantiates the project's
    ``MultiModalDenseTransformer`` and ``UnifiedMultiModalPreprocessor``,
    restores weights from a checkpoint (optionally DataParallel-prefixed),
    and exposes :meth:`generate_text` for (text + optional image) prompts.
    """

    def __init__(self, checkpoint_path: str, tokenizer_name: str, config_path: Optional[str] = None, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        """Load tokenizer, build the model, and restore checkpoint weights.

        Args:
            checkpoint_path: path to a ``.pt`` checkpoint; may be either a raw
                state dict or a dict containing ``'model_state_dict'``.
            tokenizer_name: HF hub name or local path of the tokenizer.
            config_path: optional JSON file with model hyperparameters;
                falls back to the hard-coded defaults below.
            device: 'cuda' when available, else 'cpu'.
        """
        self.device = torch.device(device)
        print(f"Using device: {self.device}")
        print(f"Loading tokenizer: {tokenizer_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True, trust_remote_code=True)
        # Qwen tokenizers may ship without a pad token; reuse EOS so padding works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        if config_path and Path(config_path).exists():
            with open(config_path, 'r') as f:
                self.config = json.load(f)
        else:
            # Default config mirroring the original training script (adjust as needed).
            self.config = {
                'model_dim': 1536,
                'vocab_size': len(self.tokenizer),
                'n_layers': 12,
                'n_heads': 12,
                'n_kv_heads': 4,
                'head_dim': None,
                'max_seq_len': 512,
                'dropout': 0.0,
                'use_moe': False,
                'use_adapter': False,
                'use_lora': False,
                'rope_scaling_type': "yarn",
                'use_multimodal_fusion': False,
                'use_contrastive': False
            }

        # Initialize model + preprocessor.
        print("Initializing model architecture...")
        self.model = MultiModalDenseTransformer(**self.config)
        self.preprocessor = UnifiedMultiModalPreprocessor(model_dim=self.config['model_dim'])

        print(f"Loading checkpoint from {checkpoint_path}...")
        # NOTE(review): torch.load without weights_only unpickles arbitrary
        # objects — only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        # Support checkpoints that wrap the weights in 'model_state_dict'.
        state_dict = checkpoint.get('model_state_dict', checkpoint) if isinstance(checkpoint, dict) else checkpoint

        # Strip a DataParallel/DDP 'module.' prefix if present.
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith('module.'):
                new_state_dict[k[7:]] = v
            else:
                new_state_dict[k] = v

        # strict=False tolerates architecture drift; surface the mismatch counts.
        missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False)
        if missing:
            print(f"Warning: Missing keys: {len(missing)}")
        if unexpected:
            print(f"Warning: Unexpected keys: {len(unexpected)}")

        self.model.to(self.device)
        self.preprocessor.to(self.device)
        self.model.eval()
        print("Model loaded successfully!")
        print(f"Total parameters: {sum(p.numel() for p in self.model.parameters())/1e6:.2f}M")

    @torch.no_grad()
    def generate_text(self, prompt: str, max_new_tokens: int = 128, temperature: float = 0.7, top_k: int = 10, top_p: float = 0.9, repetition_penalty: float = 1.2, image: Optional[Image.Image] = None) -> str:
        """Generate a response for a prompt, optionally conditioned on an image.

        Args:
            prompt: user instruction; wrapped in an Instruction/Response template.
            max_new_tokens: maximum number of tokens to generate.
            temperature / top_k / top_p / repetition_penalty: sampling controls.
            image: optional PIL image; preprocessed into multimodal segments.
                Image failures are logged and skipped (text-only fallback).

        Returns:
            The cleaned model answer, or an ``"Error: ..."`` string on failure.
        """
        formatted_prompt = f"Instruction: {prompt}\nResponse:"
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
        input_ids = inputs['input_ids'].to(self.device)

        input_data = {'segments': []}
        if image is not None:
            try:
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                image_tensor = image_transform(image).unsqueeze(0).to(self.device)
                mod_segments = self.preprocessor.process_batch(image_tensor, 'image')
                for seg in mod_segments:
                    input_data['segments'].append(seg)
            except Exception as e:
                # Best-effort: degrade to text-only generation on image errors.
                print(f"Warning: Image processing skipped due to error: {e}")

        # Text segment is appended last so it follows any image segments.
        input_data['segments'].append({
            'type': 'text',
            'data': input_ids,
            'modality_id': 0
        })

        try:
            generated_ids = self.model.generate(
                input_data,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id
            )

            full_output = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            # Extract the part after "Response:" and apply stop-word trimming.
            if "Response:" in full_output:
                answer = full_output.split("Response:")[-1].strip()
            else:
                answer = full_output

            # Truncate at the first template keyword the model emits.
            stop_words = ["Instruction", "Input", "###", "Response", "User:", "Assistant:", "\n\n"]
            for sw in stop_words:
                if sw in answer:
                    answer = answer.split(sw)[0].strip()

            # Drop a possible echo of the prompt in the first line.
            lines = answer.split('\n')
            if len(lines) > 0 and prompt.lower() in lines[0].lower():
                answer = "\n".join(lines[1:]).strip()
            return answer
        except Exception as e:
            import traceback
            traceback.print_exc()
            return f"Error: {e}"
|
| 160 |
+
|
| 161 |
+
# -------- Gradio UI 部分 --------
|
| 162 |
+
def build_ui(model_instance):
    """Build the Gradio Blocks UI for multimodal (text + optional image) inference.

    Args:
        model_instance: a loaded ``ModelInference`` used to run generation.

    Returns:
        The constructed (un-launched) ``gr.Blocks`` demo.

    Fix: the original called ``demo.launch(share=True)`` inside this function,
    which launched the app before ``main()`` could apply ``--port``/``--share``
    and then launched it a second time. Construction and launching are now
    separated — the caller controls ``demo.launch(...)``.
    """
    with gr.Blocks(title="MultiModal Dense Transformer - Gradio", css="""
    .gradio-container { max-width: 900px; margin: auto; }
    """) as demo:
        gr.Markdown("## 🚀 多模态在线推理(文本 + 图片)")
        with gr.Row():
            with gr.Column(scale=3):
                txt = gr.Textbox(label="Prompt (Instruction)", placeholder="请输入指令或问题...", lines=5)
                img = gr.Image(type="pil", label="(可选) 上传图片(支持多模态)")
                btn = gr.Button("生成 (Generate)")
            with gr.Column(scale=2):
                max_tokens = gr.Slider(label="Max New Tokens", minimum=16, maximum=1024, step=1, value=128)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, step=0.01, value=0.7)
                top_k = gr.Slider(label="Top-k", minimum=0, maximum=200, step=1, value=40)
                top_p = gr.Slider(label="Top-p", minimum=0.0, maximum=1.0, step=0.01, value=0.9)
                rep_pen = gr.Slider(label="Repetition Penalty", minimum=0.5, maximum=2.0, step=0.01, value=1.1)
                status = gr.Textbox(label="Status", value="Ready", interactive=False)
        output = gr.Textbox(label="Output", lines=12, interactive=False)

        def gr_generate(prompt, image, max_tokens_v, temp_v, topk_v, topp_v, rep_v):
            # Reject empty prompts without touching the model.
            if not prompt or str(prompt).strip() == "":
                return "", "请输入 Prompt", ""
            # Delegate to the model wrapper; sliders arrive as floats, so cast.
            out = model_instance.generate_text(prompt=prompt,
                                               max_new_tokens=int(max_tokens_v),
                                               temperature=float(temp_v),
                                               top_k=int(topk_v),
                                               top_p=float(topp_v),
                                               repetition_penalty=float(rep_v),
                                               image=image)
            return out, "Done", ""

        btn.click(fn=gr_generate, inputs=[txt, img, max_tokens, temperature, top_k, top_p, rep_pen], outputs=[output, status, gr.State()])

    return demo
|
| 200 |
+
|
| 201 |
+
# -------- CLI / main --------
|
| 202 |
+
def _step_number(path: Path) -> int:
    """Extract the numeric step from a 'step_<N>.pt' filename (-1 if malformed)."""
    suffix = path.stem.split("_")[-1]
    return int(suffix) if suffix.isdigit() else -1


def main():
    """CLI entry point: resolve the checkpoint, load the model, launch the UI."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default="/root/multimodal/checkpoints/posttrain/final_model.pt")
    parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2.5-7B-Instruct")
    parser.add_argument("--config", type=str, default=None)
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--share", type=lambda x: x.lower() in ("true","1","yes"), default=True)
    args = parser.parse_args()

    # If the default final_model is missing, fall back to the most recent
    # step checkpoint. Fix: Path.glob() order is filesystem-dependent, so
    # "latest" must be chosen by sorting on the numeric step, not by taking
    # the last glob result.
    if not Path(args.checkpoint).exists():
        possible = sorted(Path("checkpoints/pretrain").glob("step_*.pt"), key=_step_number)
        if possible:
            args.checkpoint = str(possible[-1])
            print(f"未找到 final_model.pt,使用最新 checkpoint: {args.checkpoint}")
        else:
            raise FileNotFoundError(f"找不到检查点: {args.checkpoint}")

    global model_instance
    model_instance = ModelInference(args.checkpoint, args.tokenizer, args.config)

    # Launch Gradio; --share decides whether to create a public link.
    demo = build_ui(model_instance)
    demo.launch(server_port=args.port, share=args.share)


if __name__ == "__main__":
    main()
|
grpo.py
ADDED
|
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
改进的 GRPO (Group Relative Policy Optimization) 训练器
|
| 3 |
+
修复了所有已知问题
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.optim as optim
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 10 |
+
from typing import Dict, List, Tuple, Optional
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
import numpy as np
|
| 13 |
+
import gc
|
| 14 |
+
import logging
|
| 15 |
+
|
| 16 |
+
# Module-level logger setup.
# NOTE(review): basicConfig at import time configures the root logger for the
# whole process; for library use, consider leaving configuration to the caller.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class GRPOTrainer:
|
| 21 |
+
"""
|
| 22 |
+
GRPO训练器 - Group Relative Policy Optimization
|
| 23 |
+
参考 DeepSeekMath/DeepSeek-V3 策略
|
| 24 |
+
|
| 25 |
+
主要修复:
|
| 26 |
+
1. 修复了 generate() 返回格式问题
|
| 27 |
+
2. 修复了 reward_model 输出处理
|
| 28 |
+
3. 添加了完整的混合精度训练支持
|
| 29 |
+
4. 改进了 KL 散度计算的数值稳定性
|
| 30 |
+
5. 修复了 past_key_values 的使用逻辑
|
| 31 |
+
6. 改进了内存管理和错误处理
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(
    self,
    actor_model,
    reward_model,
    ref_model,
    tokenizer,
    learning_rate: float = 1e-6,
    kl_coef: float = 0.04,
    group_size: int = 4,
    clip_epsilon: float = 0.2,
    max_grad_norm: float = 1.0,
    grpo_epochs: int = 1,
    update_batch_size: int = 4,
    use_amp: bool = True,
    value_clip: bool = False,
    entropy_coef: float = 0.01,
    advantage_normalization: str = 'group', # 'group', 'global', 'none'
    kl_estimation_method: str = 'forward' # 'forward', 'reverse', 'symmetric'
):
    """Initialize the GRPO trainer.

    Args:
        actor_model: policy model being trained
        reward_model: reward model (frozen)
        ref_model: reference model (frozen)
        tokenizer: tokenizer
        learning_rate: learning rate for the actor optimizer
        kl_coef: KL-divergence penalty coefficient
        group_size: number of samples generated per prompt
        clip_epsilon: PPO clip range
        max_grad_norm: gradient clipping threshold
        grpo_epochs: training epochs per batch of experience
        update_batch_size: mini-batch size during updates
        use_amp: whether to use mixed-precision training
        value_clip: whether to clip values — currently accepted but unused
            (no value network is wired in)
        entropy_coef: entropy regularization coefficient
        advantage_normalization: advantage normalization mode
        kl_estimation_method: KL-divergence estimation method
    """
    self.actor = actor_model
    self.reward_model = reward_model
    self.ref_model = ref_model
    self.tokenizer = tokenizer

    self.kl_coef = kl_coef
    self.group_size = group_size
    self.clip_epsilon = clip_epsilon
    self.max_grad_norm = max_grad_norm
    self.grpo_epochs = grpo_epochs
    self.update_batch_size = update_batch_size
    self.use_amp = use_amp
    self.entropy_coef = entropy_coef
    self.advantage_normalization = advantage_normalization
    self.kl_estimation_method = kl_estimation_method

    # Device is inferred from the actor's parameters.
    self.device = next(actor_model.parameters()).device

    # Freeze the reference and reward models — only the actor trains.
    self.ref_model.eval()
    self.ref_model.requires_grad_(False)
    self.reward_model.eval()
    self.reward_model.requires_grad_(False)

    # Optimizer configuration (only trainable actor params are passed in).
    self.optimizer = optim.AdamW(
        filter(lambda p: p.requires_grad, actor_model.parameters()),
        lr=learning_rate,
        weight_decay=0.01,
        betas=(0.9, 0.95),
        eps=1e-8
    )

    # Mixed precision: GradScaler pairs with the autocast regions used in
    # generation/optimization; a no-op when use_amp is False.
    self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)

    # Running training statistics.
    self.training_stats = {
        'iterations': 0,
        'total_samples': 0,
        'avg_rewards': [],
        'avg_kl': [],
        'policy_losses': []
    }

    logger.info(f"GRPO Trainer initialized:")
    logger.info(f" Group Size: {group_size}")
    logger.info(f" KL Coef: {kl_coef}")
    logger.info(f" Clip Epsilon: {clip_epsilon}")
    logger.info(f" Learning Rate: {learning_rate}")
    logger.info(f" Update Batch Size: {update_batch_size}")
    logger.info(f" Mixed Precision: {use_amp}")
    logger.info(f" KL Estimation: {kl_estimation_method}")
|
| 127 |
+
|
| 128 |
+
def _compute_kl_divergence(
    self,
    log_probs: torch.Tensor,
    ref_log_probs: torch.Tensor,
    mask: torch.Tensor
) -> torch.Tensor:
    """Per-sequence KL penalty between the current and reference policies.

    Args:
        log_probs: per-token log-probs under the current policy, [B, L]
        ref_log_probs: per-token log-probs under the reference policy, [B, L]
        mask: 1.0 on valid (response) tokens, 0.0 elsewhere, [B, L]

    Returns:
        KL penalty summed over valid tokens, shape [B].
    """
    log_ratio = log_probs - ref_log_probs
    if self.kl_estimation_method == 'forward':
        # Single-sample estimate of KL(pi || pi_ref): tokens are sampled
        # from pi, so the estimator is simply log pi - log pi_ref.
        kl = log_ratio
    elif self.kl_estimation_method == 'reverse':
        # Single-sample estimate of KL(pi_ref || pi).
        kl = -log_ratio
    else:  # symmetric
        # Bug fix: the original averaged the forward and reverse
        # single-sample estimators, 0.5*((lp-rlp)+(rlp-lp)), which is
        # identically ZERO — the 'symmetric' mode silently disabled the
        # KL penalty. Use the low-variance quadratic estimator
        # k2 = 0.5 * (log ratio)^2 instead (Schulman, "Approximating KL
        # Divergence"): symmetric to second order and always >= 0.
        kl = 0.5 * log_ratio.pow(2)

    # Mask out prompt/padding tokens and sum over the sequence dimension.
    kl_penalty = (kl * mask).sum(dim=-1)
    return kl_penalty
|
| 162 |
+
|
| 163 |
+
@torch.no_grad()
def generate_experience(
    self,
    prompts_dataloader: DataLoader,
    max_gen_len: int,
    temperature: float = 1.0,
    top_p: float = 0.9
) -> Dict:
    """Generate experience: sample -> compute log-probs -> rewards (incl. KL).

    For each prompt batch: repeat prompts ``group_size`` times, sample
    responses from the actor, score full sequences under both the actor
    and the frozen reference model, subtract a KL penalty from the reward
    model's score, and compute group-relative advantages.

    Args:
        prompts_dataloader: yields prompt token-id batches (or tuples whose
            first element is the batch).
        max_gen_len: maximum number of new tokens to sample per response.
        temperature: sampling temperature.
        top_p: nucleus sampling threshold.

    Returns:
        Dict of CPU tensors: 'sequences', 'log_probs', 'advantages',
        'prompt_lengths', 'rewards'.

    NOTE(review): the final torch.cat across batches assumes every batch
    yields sequences of the same length (fixed-length prompts and padded
    generation) — confirm against the dataloader/generate contract.
    """
    self.actor.eval()

    all_sequences = []
    all_log_probs = []
    all_advantages = []
    all_prompt_lens = []
    all_rewards = []

    logger.info("Generating experience...")

    for prompts in tqdm(prompts_dataloader, desc="Generating Experience"):
        try:
            # Handle dataloaders that yield (tensor,) tuples.
            if isinstance(prompts, (list, tuple)):
                prompts = prompts[0]

            prompts = prompts.to(self.device)
            batch_size = prompts.shape[0]

            # Repeat each prompt group_size times to sample a group per prompt.
            prompts_repeated = prompts.repeat_interleave(self.group_size, dim=0)
            prompt_len = prompts_repeated.shape[1]

            input_data = {
                'segments': [{
                    'type': 'text',
                    'data': prompts_repeated,
                    'modality_id': 0
                }]
            }

            # 1. Sample responses (generate returns only the NEW tokens).
            with torch.amp.autocast('cuda', enabled=self.use_amp):
                response_ids = self.actor.generate(
                    input_data,
                    max_new_tokens=max_gen_len,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    eos_token_id=self.tokenizer.eos_token_id,
                    pad_token_id=self.tokenizer.pad_token_id,
                    use_cache=True  # KV cache speeds up generation
                )

            # Concatenate into the full sequence (prompt + response).
            sequences = torch.cat([prompts_repeated, response_ids], dim=1)

            # Sanity-check the generated length.
            if sequences.shape[1] <= prompt_len:
                logger.warning("Generated sequence too short, skipping batch")
                continue

            full_input_data = {
                'segments': [{
                    'type': 'text',
                    'data': sequences,
                    'modality_id': 0
                }]
            }

            # 2. Score the full sequences under the actor and reference policies.
            with torch.amp.autocast('cuda', enabled=self.use_amp):
                actor_out = self.actor(full_input_data)
                ref_out = self.ref_model(full_input_data)

            # Shift: logits at position t predict token t+1.
            logits = actor_out['logits'][:, :-1, :]
            ref_logits = ref_out['logits'][:, :-1, :]
            targets = sequences[:, 1:]

            # Log-probabilities (log_softmax for numerical stability).
            log_probs = F.log_softmax(logits, dim=-1)
            ref_log_probs = F.log_softmax(ref_logits, dim=-1)

            # Pick out the log-prob of each realized token.
            per_token_log_probs = torch.gather(
                log_probs, -1, targets.unsqueeze(-1)
            ).squeeze(-1)
            per_token_ref_log_probs = torch.gather(
                ref_log_probs, -1, targets.unsqueeze(-1)
            ).squeeze(-1)

            # 3. KL divergence, restricted to the response tokens only.
            response_mask = torch.arange(
                sequences.size(1) - 1, device=self.device
            ) >= (prompt_len - 1)
            response_mask = response_mask.unsqueeze(0).expand_as(per_token_log_probs)
            response_mask = response_mask.float()

            kl_penalty = self._compute_kl_divergence(
                per_token_log_probs,
                per_token_ref_log_probs,
                response_mask
            )

            # 4. Environment reward from the frozen reward model.
            with torch.amp.autocast('cuda', enabled=self.use_amp):
                reward_output = self.reward_model(full_input_data)

            # Reward model may return (batch, seq_len); take the final position.
            if reward_output.dim() == 2:
                raw_rewards = reward_output[:, -1]
            else:
                raw_rewards = reward_output.squeeze(-1)

            # 5. Total reward: R_total = R_env - beta * KL.
            total_rewards = raw_rewards - self.kl_coef * kl_penalty

            # 6. Group-relative advantages (GRPO's core idea).
            rewards_grouped = total_rewards.view(batch_size, self.group_size)

            if self.advantage_normalization == 'group':
                # Standardize within each prompt's group.
                mean_grouped = rewards_grouped.mean(dim=1, keepdim=True)
                std_grouped = rewards_grouped.std(dim=1, keepdim=True) + 1e-8
                advantages = (rewards_grouped - mean_grouped) / std_grouped
            elif self.advantage_normalization == 'global':
                # Standardize over the whole batch.
                advantages = (rewards_grouped - rewards_grouped.mean()) / (
                    rewards_grouped.std() + 1e-8
                )
            else: # 'none'
                advantages = rewards_grouped - rewards_grouped.mean(dim=1, keepdim=True)

            advantages = advantages.view(-1)

            # Stash everything on CPU to free GPU memory between batches.
            all_sequences.append(sequences.cpu())
            all_log_probs.append(per_token_log_probs.detach().cpu())
            all_advantages.append(advantages.detach().cpu())
            all_prompt_lens.append(
                torch.full((sequences.size(0),), prompt_len, dtype=torch.long)
            )
            all_rewards.append(total_rewards.detach().cpu())

            # Drop large intermediates eagerly.
            del logits, ref_logits, actor_out, ref_out
            del log_probs, ref_log_probs, reward_output

        except Exception as e:
            # Best-effort: log and skip the failing batch rather than abort.
            logger.error(f"Error generating experience for batch: {e}")
            import traceback
            traceback.print_exc()
            continue

        finally:
            torch.cuda.empty_cache()

    if not all_sequences:
        raise RuntimeError("No valid sequences generated")

    # Merge all per-batch tensors into one experience buffer.
    experience = {
        'sequences': torch.cat(all_sequences, dim=0),
        'log_probs': torch.cat(all_log_probs, dim=0),
        'advantages': torch.cat(all_advantages, dim=0),
        'prompt_lengths': torch.cat(all_prompt_lens, dim=0),
        'rewards': torch.cat(all_rewards, dim=0)
    }

    # Summary statistics for monitoring.
    logger.info(f"Generated {len(experience['sequences'])} sequences")
    logger.info(f"Avg Reward: {experience['rewards'].mean().item():.4f}")
    logger.info(f"Reward Std: {experience['rewards'].std().item():.4f}")
    logger.info(f"Avg Advantage: {experience['advantages'].mean().item():.4f}")

    return experience
|
| 346 |
+
|
| 347 |
+
    def grpo_step(
        self,
        dataset: TensorDataset
    ) -> Dict[str, float]:
        """Run one optimization pass (one epoch) of GRPO over cached experience.

        Args:
            dataset: TensorDataset of (sequences, old_log_probs, advantages,
                prompt_lengths) as produced by ``generate_experience``.

        Returns:
            Dict of per-epoch averages: ``total_loss``, ``policy_loss``,
            ``entropy``, ``approx_kl``, ``clip_fraction``, plus the raw
            ``steps`` count.

        Author's fix notes (translated from Chinese):
            1. Use GradScaler for mixed-precision training.
            2. Improved loss computation.
            3. Better statistics collection.
        """
        self.actor.train()

        dataloader = DataLoader(
            dataset,
            batch_size=self.update_batch_size,
            shuffle=True,
            drop_last=False
        )

        # Running sums; converted to means over 'steps' at the end.
        epoch_stats = {
            'total_loss': 0.0,
            'policy_loss': 0.0,
            'entropy': 0.0,
            'approx_kl': 0.0,
            'clip_fraction': 0.0,
            'steps': 0
        }

        for batch_data in dataloader:
            sequences, old_log_probs, advantages, prompt_lens = batch_data

            sequences = sequences.to(self.device)
            old_log_probs = old_log_probs.to(self.device)
            advantages = advantages.to(self.device)

            # Wrap token ids in the actor's multimodal input format
            # (a single text segment, modality 0).
            input_data = {
                'segments': [{
                    'type': 'text',
                    'data': sequences,
                    'modality_id': 0
                }]
            }

            # Forward pass under autocast for mixed-precision training.
            with torch.amp.autocast('cuda', enabled=self.use_amp):
                outputs = self.actor(input_data)
                # Shift: logits at position t predict token t+1.
                logits = outputs['logits'][:, :-1, :]

                # New per-token log-probabilities of the sampled tokens.
                targets = sequences[:, 1:]
                log_probs_dist = F.log_softmax(logits, dim=-1)
                new_log_probs = torch.gather(
                    log_probs_dist, -1, targets.unsqueeze(-1)
                ).squeeze(-1)

                # Response mask: only tokens generated after the prompt count.
                # pl-1 compensates for the one-position shift above.
                mask = torch.zeros_like(new_log_probs)
                for i, pl in enumerate(prompt_lens):
                    mask[i, pl-1:] = 1.0

                # Importance ratio pi_new / pi_old per token.
                ratio = torch.exp(new_log_probs - old_log_probs)

                # One (group-relative) advantage per sequence, broadcast to
                # every token of that sequence.
                adv_expanded = advantages.unsqueeze(-1).expand_as(new_log_probs)

                # PPO-style clipped surrogate objective.
                surr1 = ratio * adv_expanded
                surr2 = torch.clamp(
                    ratio,
                    1.0 - self.clip_epsilon,
                    1.0 + self.clip_epsilon
                ) * adv_expanded

                # Policy loss (minimize the negated objective), masked mean
                # over response tokens; +1e-8 guards an all-zero mask.
                policy_loss = -torch.min(surr1, surr2)
                policy_loss = (policy_loss * mask).sum() / (mask.sum() + 1e-8)

                # Entropy bonus to encourage exploration.
                probs = F.softmax(logits, dim=-1)
                entropy = -(probs * log_probs_dist).sum(dim=-1)
                entropy_bonus = (entropy * mask).sum() / (mask.sum() + 1e-8)

                # Total loss.
                loss = policy_loss - self.entropy_coef * entropy_bonus

                # Diagnostics only; no gradients needed.
                with torch.no_grad():
                    # k3-style KL estimator: (r - 1) - log r.
                    log_ratio = new_log_probs - old_log_probs
                    approx_kl = ((ratio - 1) - log_ratio) * mask
                    approx_kl = approx_kl.sum() / (mask.sum() + 1e-8)

                    clip_fraction = ((ratio > 1 + self.clip_epsilon) |
                                     (ratio < 1 - self.clip_epsilon)).float()
                    clip_fraction = (clip_fraction * mask).sum() / (mask.sum() + 1e-8)

            # Backward via GradScaler (author's fix for AMP correctness).
            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()

            # Unscale before clipping so the norm is computed on true grads.
            self.scaler.unscale_(self.optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.actor.parameters(),
                self.max_grad_norm
            )

            self.scaler.step(self.optimizer)
            self.scaler.update()

            # Accumulate per-step statistics.
            epoch_stats['total_loss'] += loss.item()
            epoch_stats['policy_loss'] += policy_loss.item()
            epoch_stats['entropy'] += entropy_bonus.item()
            epoch_stats['approx_kl'] += approx_kl.item()
            epoch_stats['clip_fraction'] += clip_fraction.item()
            epoch_stats['steps'] += 1

        # Convert sums to means (guard against an empty dataloader).
        for key in epoch_stats:
            if key != 'steps':
                epoch_stats[key] /= max(epoch_stats['steps'], 1)

        return epoch_stats
|
| 473 |
+
|
| 474 |
+
    def train(
        self,
        prompt_dataloader: DataLoader,
        num_iterations: int = 1,
        max_gen_len: int = 50,
        temperature: float = 1.0,
        save_every: int = 5,
        save_path: str = "checkpoints"
    ):
        """
        Full GRPO training loop: alternate rollout collection and policy
        optimization, with periodic checkpointing.

        Args:
            prompt_dataloader: DataLoader that yields prompt batches.
            num_iterations: number of generate/optimize iterations.
            max_gen_len: maximum number of tokens to generate per rollout.
            temperature: sampling temperature for rollouts.
            save_every: checkpoint every N iterations.
            save_path: directory for checkpoint files.
        """
        logger.info(f"\n{'='*80}")
        logger.info(f"Starting GRPO Training")
        logger.info(f"  Iterations: {num_iterations}")
        logger.info(f"  Max Gen Length: {max_gen_len}")
        logger.info(f"  Temperature: {temperature}")
        logger.info(f"{'='*80}\n")

        for iteration in range(num_iterations):
            try:
                # 1. Roll out the current policy to collect experience.
                experience = self.generate_experience(
                    prompt_dataloader,
                    max_gen_len,
                    temperature
                )

                dataset = TensorDataset(
                    experience['sequences'],
                    experience['log_probs'],
                    experience['advantages'],
                    experience['prompt_lengths']
                )

                # 2. Optimize the policy for several epochs on that batch.
                logger.info(f"Optimizing policy for {self.grpo_epochs} epochs...")
                all_epoch_stats = []

                for epoch in range(self.grpo_epochs):
                    stats = self.grpo_step(dataset)
                    all_epoch_stats.append(stats)

                    logger.info(
                        f"  Epoch {epoch+1}/{self.grpo_epochs} | "
                        f"Loss: {stats['total_loss']:.4f} | "
                        f"KL: {stats['approx_kl']:.4f} | "
                        f"Clip%: {stats['clip_fraction']*100:.1f}"
                    )

                # 3. Aggregate per-epoch statistics into iteration averages.
                avg_stats = {
                    key: np.mean([s[key] for s in all_epoch_stats])
                    for key in all_epoch_stats[0].keys()
                }

                self.training_stats['iterations'] += 1
                self.training_stats['total_samples'] += len(experience['sequences'])
                self.training_stats['avg_rewards'].append(
                    experience['rewards'].mean().item()
                )
                self.training_stats['avg_kl'].append(avg_stats['approx_kl'])
                self.training_stats['policy_losses'].append(avg_stats['policy_loss'])

                # 4. Progress report.
                logger.info(f"\n{'='*80}")
                logger.info(f"Iteration {iteration+1}/{num_iterations} Complete")
                logger.info(f"  Avg Reward: {experience['rewards'].mean():.4f}")
                logger.info(f"  Avg Advantage: {experience['advantages'].mean():.4f}")
                logger.info(f"  Policy Loss: {avg_stats['policy_loss']:.4f}")
                logger.info(f"  Approx KL: {avg_stats['approx_kl']:.4f}")
                logger.info(f"  Entropy: {avg_stats['entropy']:.4f}")
                logger.info(f"  Clip Fraction: {avg_stats['clip_fraction']*100:.1f}%")
                logger.info(f"{'='*80}\n")

                # 5. Periodic checkpoint.
                if (iteration + 1) % save_every == 0:
                    self.save_checkpoint(
                        f"{save_path}/grpo_iter_{iteration+1}.pt"
                    )

                # 6. Free rollout memory before the next iteration.
                del experience, dataset
                gc.collect()
                torch.cuda.empty_cache()

            except Exception as e:
                # Best effort: a failed iteration is logged and skipped so the
                # remaining iterations can still run.
                logger.error(f"Error in iteration {iteration+1}: {e}")
                import traceback
                traceback.print_exc()
                continue

        logger.info("GRPO Training Complete!")
        self.print_training_summary()
|
| 576 |
+
|
| 577 |
+
def save_checkpoint(self, path: str):
|
| 578 |
+
"""保存训练checkpoint"""
|
| 579 |
+
import os
|
| 580 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
| 581 |
+
|
| 582 |
+
checkpoint = {
|
| 583 |
+
'actor_state_dict': self.actor.state_dict(),
|
| 584 |
+
'optimizer_state_dict': self.optimizer.state_dict(),
|
| 585 |
+
'scaler_state_dict': self.scaler.state_dict(), # 修复:保存scaler状态
|
| 586 |
+
'training_stats': self.training_stats,
|
| 587 |
+
'config': {
|
| 588 |
+
'kl_coef': self.kl_coef,
|
| 589 |
+
'group_size': self.group_size,
|
| 590 |
+
'clip_epsilon': self.clip_epsilon,
|
| 591 |
+
}
|
| 592 |
+
}
|
| 593 |
+
|
| 594 |
+
torch.save(checkpoint, path)
|
| 595 |
+
logger.info(f"Checkpoint saved to {path}")
|
| 596 |
+
|
| 597 |
+
def load_checkpoint(self, path: str):
|
| 598 |
+
"""加载训练checkpoint"""
|
| 599 |
+
checkpoint = torch.load(path, map_location=self.device)
|
| 600 |
+
|
| 601 |
+
self.actor.load_state_dict(checkpoint['actor_state_dict'])
|
| 602 |
+
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 603 |
+
|
| 604 |
+
# 修复:加载scaler状态
|
| 605 |
+
if 'scaler_state_dict' in checkpoint and self.use_amp:
|
| 606 |
+
self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
|
| 607 |
+
|
| 608 |
+
self.training_stats = checkpoint['training_stats']
|
| 609 |
+
|
| 610 |
+
logger.info(f"Checkpoint loaded from {path}")
|
| 611 |
+
|
| 612 |
+
def print_training_summary(self):
|
| 613 |
+
"""打印训练摘要"""
|
| 614 |
+
logger.info("\n" + "="*80)
|
| 615 |
+
logger.info("Training Summary")
|
| 616 |
+
logger.info("="*80)
|
| 617 |
+
logger.info(f"Total Iterations: {self.training_stats['iterations']}")
|
| 618 |
+
logger.info(f"Total Samples: {self.training_stats['total_samples']}")
|
| 619 |
+
|
| 620 |
+
if self.training_stats['avg_rewards']:
|
| 621 |
+
logger.info(
|
| 622 |
+
f"Final Avg Reward: "
|
| 623 |
+
f"{self.training_stats['avg_rewards'][-1]:.4f}"
|
| 624 |
+
)
|
| 625 |
+
logger.info(
|
| 626 |
+
f"Reward Improvement: "
|
| 627 |
+
f"{self.training_stats['avg_rewards'][-1] - self.training_stats['avg_rewards'][0]:.4f}"
|
| 628 |
+
)
|
| 629 |
+
|
| 630 |
+
logger.info("="*80 + "\n")
|
infer.py
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Flask推理界面 - 多模态Dense Transformer (适配 Qwen Tokenizer 版)
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from flask import Flask, render_template, request, jsonify
|
| 9 |
+
from transformers import AutoTokenizer
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import json
|
| 12 |
+
import io
|
| 13 |
+
import base64
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Optional
|
| 16 |
+
|
| 17 |
+
# 确保引入路径正确,根据你之前的文件结构
|
| 18 |
+
from model import MultiModalDenseTransformer
|
| 19 |
+
# 注意:UnifiedMultiModalPreprocessor 之前是在 continual_learning.py 中定义的
|
| 20 |
+
# 如果你移动了它,请修改这里的导入路径
|
| 21 |
+
from continual_learning import UnifiedMultiModalPreprocessor
|
| 22 |
+
# 如果没有 image_transform,我们需要在这里定义或导入
|
| 23 |
+
from torchvision import transforms
|
| 24 |
+
|
| 25 |
+
# Route Hugging Face downloads through the China mirror.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

# Image preprocessing (must stay consistent with training):
# resize to 224x224, convert to tensor, normalize with ImageNet statistics.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
| 34 |
+
|
| 35 |
+
class ModelInference:
    """Wraps the Qwen tokenizer + MultiModalDenseTransformer for web-served inference."""

    def __init__(
        self,
        checkpoint_path: str,
        tokenizer_name: str,
        config_path: Optional[str] = None,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ):
        self.device = torch.device(device)
        print(f"Using device: {self.device}")

        # 1. Load the tokenizer (must match the one used during pretraining).
        print(f"Loading tokenizer: {tokenizer_name}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name,
                use_fast=True,
                trust_remote_code=True
            )
            # Qwen tokenizers ship without a pad token; reuse EOS for padding.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        except Exception as e:
            print(f"Error loading tokenizer: {e}")
            raise e

        # 2. Model hyperparameters. These MUST match the pretrain.py
        #    configuration exactly, or the state dict will not load.
        if config_path and Path(config_path).exists():
            with open(config_path, 'r') as f:
                self.config = json.load(f)
        else:
            # [CRITICAL] Hard-coded to the values used in pretrain.py.
            self.config = {
                'model_dim': 1536,                  # pretraining setting
                'vocab_size': len(self.tokenizer),  # auto-fits Qwen (~151665)
                'n_layers': 12,                     # pretraining setting
                'n_heads': 12,                      # pretraining setting
                'n_kv_heads': 4,                    # pretraining setting
                'head_dim': None,                   # derived automatically
                'max_seq_len': 512,                 # pretraining setting
                'dropout': 0.0,                     # disable dropout at inference
                'use_moe': False,                   # pretraining setting
                'use_adapter': False,               # Adapter not used in pretraining
                'use_lora': False,                  # LoRA not used in pretraining
                'rope_scaling_type': "yarn"         # pretraining setting
            }

        # 3. Build the model architecture and the multimodal preprocessor.
        print("Initializing model architecture...")
        try:
            self.model = MultiModalDenseTransformer(**self.config)
            self.preprocessor = UnifiedMultiModalPreprocessor(
                model_dim=self.config['model_dim']
            )

            # 4. Load pretrained weights.
            print(f"Loading checkpoint from {checkpoint_path}...")
            # weights_only=False so that full checkpoint dicts (with extra
            # training metadata) can be unpickled.
            checkpoint = torch.load(
                checkpoint_path,
                map_location=self.device,
                weights_only=False
            )

            # Extract the state dict from either a full checkpoint or a raw
            # state-dict file.
            if 'model_state_dict' in checkpoint:
                print("Found 'model_state_dict' in checkpoint.")
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint

            # Strip the 'module.' prefix that DDP training adds to key names.
            new_state_dict = {}
            for k, v in state_dict.items():
                if k.startswith('module.'):
                    new_state_dict[k[7:]] = v
                else:
                    new_state_dict[k] = v

            # strict=False tolerates non-critical key mismatches (e.g. loss
            # buffers); mismatches are reported below rather than raising.
            missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False)
            if missing:
                print(f"Warning: Missing keys: {len(missing)}")
            if unexpected:
                print(f"Warning: Unexpected keys: {len(unexpected)}")

            self.model.to(self.device)
            self.preprocessor.to(self.device)
            self.model.eval()

            print("Model loaded successfully!")
            print(f"Total parameters: {sum(p.numel() for p in self.model.parameters()) / 1e6:.2f}M")

        except Exception as e:
            print(f"Error initializing model: {e}")
            raise e

    @torch.no_grad()
    def generate_text(
        self,
        prompt: str,
        max_new_tokens: int = 128,
        temperature: float = 0.7,
        top_k: int = 40,
        top_p: float = 0.9,
        repetition_penalty: float = 1.1,
        image: Optional[Image.Image] = None
    ) -> str:
        """Generate text, optionally conditioned on an image; returns the decoded string."""

        # Tokenize the prompt.
        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs['input_ids'].to(self.device)

        # Build the segment-based input that MultiModalDenseTransformer expects.
        input_data = {'segments': []}

        # Optional image conditioning.
        if image is not None:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            # Standard resize/normalize pipeline, batched to size 1.
            image_tensor = image_transform(image).unsqueeze(0).to(self.device)
            # NOTE(review): assumes the preprocessor accepts this tensor form.
            try:
                # process_batch takes (batch_data, modality_type) and returns
                # a list of segment dicts.
                mod_segments = self.preprocessor.process_batch(image_tensor, 'image')
                # Merge the returned segments into the model input.
                for seg in mod_segments:
                    input_data['segments'].append(seg)
            except Exception as e:
                # Best effort: generation continues text-only on failure.
                print(f"Warning: Image processing skipped due to error: {e}")

        # Append the text segment (modality 0 = text).
        input_data['segments'].append({
            'type': 'text',
            'data': input_ids,
            'modality_id': 0
        })

        # Generation.
        try:
            # Use the model's own generate() method.
            generated_ids = self.model.generate(
                input_data,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id
            )

            # Decode.
            # NOTE(review): the returned ids may include the original prompt or
            # only the new tokens; MultiModalDenseTransformer.generate is
            # assumed to return the full sequence — confirm against model.py.
            generated_text = self.tokenizer.decode(
                generated_ids[0],
                skip_special_tokens=True
            )

            # To show only the continuation, the prompt could be stripped:
            # if generated_text.startswith(prompt):
            #     generated_text = generated_text[len(prompt):]

            return generated_text

        except Exception as e:
            print(f"Generation error: {e}")
            import traceback
            traceback.print_exc()
            return f"Error: {str(e)}"
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# Global model instance, populated by main() before the server starts.
model_instance = None
app = Flask(__name__)
# Cap request payload size at 16 MB (base64 images arrive in the JSON body).
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
|
| 217 |
+
|
| 218 |
+
@app.route('/')
def index():
    """Serve the single-page UI, exposing the model config to the template."""
    if model_instance:
        display_config = model_instance.config.copy()
    else:
        display_config = {}
    return render_template('index.html', config=display_config)
|
| 222 |
+
|
| 223 |
+
@app.route('/generate', methods=['POST'])
def generate():
    """JSON endpoint: run text generation with an optional base64-encoded image."""
    try:
        payload = request.json

        prompt = payload.get('prompt', '')
        if not prompt.strip():
            return jsonify({'error': '请输入提示文本'}), 400

        # Sampling parameters, with the same defaults the UI uses.
        max_tokens = int(payload.get('max_tokens', 100))
        temperature = float(payload.get('temperature', 0.7))
        top_k = int(payload.get('top_k', 40))
        top_p = float(payload.get('top_p', 0.9))
        repetition_penalty = float(payload.get('repetition_penalty', 1.1))

        # Optional image arrives as a data URL ("data:image/...;base64,<data>").
        image = None
        if payload.get('image'):
            try:
                raw_bytes = base64.b64decode(payload['image'].split(',')[1])
                image = Image.open(io.BytesIO(raw_bytes))
            except Exception as e:
                # Best effort: a broken image does not abort text generation.
                print(f"Image load error: {e}")

        output = model_instance.generate_text(
            prompt, max_tokens, temperature, top_k, top_p, repetition_penalty, image
        )
        return jsonify({'output': output})

    except Exception as e:
        # Any unexpected failure is surfaced as a 500 with the message.
        return jsonify({'error': str(e)}), 500
|
| 252 |
+
|
| 253 |
+
def create_html_template():
    """Write the single-page inference UI to templates/index.html.

    Flask's render_template looks in ./templates, so the directory is created
    if needed and the file is (re)written on every startup.
    """
    html_content = '''
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Model Inference</title>
    <style>
        body { font-family: sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; background: #f0f2f5; }
        .container { background: white; padding: 30px; border-radius: 12px; box-shadow: 0 4px 6px rgba(0,0,0,0.1); }
        h1 { color: #1a73e8; text-align: center; }
        textarea { width: 100%; padding: 10px; border: 1px solid #ddd; border-radius: 8px; margin: 10px 0; min-height: 100px; }
        .controls { display: grid; grid-template-columns: 1fr 1fr; gap: 15px; margin: 20px 0; background: #f8f9fa; padding: 15px; border-radius: 8px; }
        button { background: #1a73e8; color: white; border: none; padding: 12px 24px; border-radius: 6px; cursor: pointer; width: 100%; font-size: 16px; transition: background 0.3s; }
        button:hover { background: #1557b0; }
        button:disabled { background: #ccc; }
        #output { margin-top: 20px; padding: 20px; background: #f8f9fa; border-radius: 8px; white-space: pre-wrap; min-height: 100px; border: 1px solid #e0e0e0; }
        .loading { color: #666; font-style: italic; }
    </style>
</head>
<body>
    <div class="container">
        <h1>🚀 模型在线推理</h1>

        <div>
            <label><strong>提示词 (Prompt):</strong></label>
            <textarea id="prompt" placeholder="请输入你的问题..."></textarea>
        </div>

        <div class="controls">
            <div>
                <label>Max Tokens: <span id="maxTokensVal">128</span></label>
                <input type="range" id="maxTokens" min="32" max="1024" value="128" style="width:100%" oninput="document.getElementById('maxTokensVal').innerText=this.value">
            </div>
            <div>
                <label>Temperature: <span id="tempVal">0.7</span></label>
                <input type="range" id="temperature" min="0.1" max="1.5" step="0.1" value="0.7" style="width:100%" oninput="document.getElementById('tempVal').innerText=this.value">
            </div>
        </div>

        <button id="btn" onclick="generate()">生成 (Generate)</button>

        <div id="output">结果将显示在这里...</div>
    </div>

    <script>
        async function generate() {
            const prompt = document.getElementById('prompt').value;
            if(!prompt) return alert("请输入内容");

            const btn = document.getElementById('btn');
            const out = document.getElementById('output');

            btn.disabled = true;
            btn.innerText = "生成中...";
            out.innerHTML = '<div class="loading">正在思考中...</div>';

            try {
                const res = await fetch('/generate', {
                    method: 'POST',
                    headers: {'Content-Type': 'application/json'},
                    body: JSON.stringify({
                        prompt: prompt,
                        max_tokens: parseInt(document.getElementById('maxTokens').value),
                        temperature: parseFloat(document.getElementById('temperature').value)
                    })
                });
                const data = await res.json();
                if(data.error) out.innerText = "Error: " + data.error;
                else out.innerText = data.output;
            } catch(e) {
                out.innerText = "请求失败: " + e;
            } finally {
                btn.disabled = false;
                btn.innerText = "生成 (Generate)";
            }
        }
    </script>
</body>
</html>
'''

    Path('templates').mkdir(exist_ok=True)
    with open('templates/index.html', 'w', encoding='utf-8') as f:
        f.write(html_content)
|
| 340 |
+
|
| 341 |
+
def main():
    """CLI entry point: locate a checkpoint, write the UI template, start Flask.

    Command-line flags:
        --checkpoint  path to a pretrain checkpoint (.pt)
        --tokenizer   HF tokenizer name (must match pretraining)
        --port/--host Flask bind address
    """
    import argparse
    parser = argparse.ArgumentParser()
    # Default points at the checkpoint produced by pretrain.py.
    parser.add_argument("--checkpoint", type=str, default="/root/multimodal/checkpoints/pretrain_fixed/step_10000.pt")
    parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2.5-7B-Instruct")
    parser.add_argument("--port", type=int, default=5001)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    args = parser.parse_args()

    if not Path(args.checkpoint).exists():
        # Fall back to the most recent step checkpoint.
        # Fix: Path.glob() yields entries in filesystem-dependent order, so
        # sort numerically by step number instead of trusting steps[-1].
        steps = sorted(
            Path("checkpoints/pretrain").glob("step_*.pt"),
            key=lambda p: int(p.stem.split("_")[-1]),
        )
        if steps:
            # Fix: the original message claimed final_model.pt was missing,
            # but that file is never searched for — report the real path.
            print(f"未找到 {args.checkpoint},尝试使用最新的 checkpoint: {steps[-1]}")
            args.checkpoint = str(steps[-1])
        else:
            print(f"错误: 找不到检查点文件: {args.checkpoint}")
            return

    create_html_template()

    global model_instance
    model_instance = ModelInference(args.checkpoint, args.tokenizer)

    print(f"\n服务已启动: http://{args.host}:{args.port}")
    # NOTE(review): debug=True enables the Werkzeug interactive debugger,
    # which allows arbitrary code execution — disable before exposing this
    # server beyond localhost.
    app.run(host=args.host, port=args.port,
            debug=True,  # debug mode (auto-reloader disabled below)
            use_reloader=False)
|
| 370 |
+
|
| 371 |
+
if __name__ == "__main__":
|
| 372 |
+
main()
|
infer_sft.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Flask推理界面 - 多模态Dense Transformer (适配 Qwen Tokenizer 版)
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from flask import Flask, render_template, request, jsonify
|
| 9 |
+
from transformers import AutoTokenizer
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import json
|
| 12 |
+
import io
|
| 13 |
+
import base64
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Optional
|
| 16 |
+
|
| 17 |
+
# 确保引入路径正确,根据你之前的文件结构
|
| 18 |
+
from model import MultiModalDenseTransformer
|
| 19 |
+
# 注意:UnifiedMultiModalPreprocessor 之前是在 continual_learning.py 中定义的
|
| 20 |
+
# 如果你移动了它,请修改这里的导入路径
|
| 21 |
+
from continual_learning import UnifiedMultiModalPreprocessor
|
| 22 |
+
# 如果没有 image_transform,我们需要在这里定义或导入
|
| 23 |
+
from torchvision import transforms
|
| 24 |
+
|
| 25 |
+
# 设置国内镜像
|
| 26 |
+
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
| 27 |
+
|
| 28 |
+
# 定义图像预处理 (与 training 保持一致)
|
| 29 |
+
image_transform = transforms.Compose([
|
| 30 |
+
transforms.Resize((224, 224)),
|
| 31 |
+
transforms.ToTensor(),
|
| 32 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 33 |
+
])
|
| 34 |
+
|
| 35 |
+
class ModelInference:
    """Inference wrapper: loads the tokenizer, rebuilds the multimodal dense
    transformer, restores checkpoint weights, and runs guarded text generation
    for the Flask UI."""

    def __init__(
        self,
        checkpoint_path: str,
        tokenizer_name: str,
        config_path: Optional[str] = None,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ):
        """Build the model and restore weights.

        Args:
            checkpoint_path: Path to a ``.pt`` checkpoint — either a full
                training dict (with ``model_state_dict``) or a bare state dict.
            tokenizer_name: HF hub id / local path of the pretraining tokenizer.
            config_path: Optional JSON file overriding the built-in config.
            device: Target device string; defaults to CUDA when available.
        """
        self.device = torch.device(device)
        print(f"Using device: {self.device}")

        # 1. Tokenizer — must be the same one used during pretraining.
        print(f"Loading tokenizer: {tokenizer_name}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name,
                use_fast=True,
                trust_remote_code=True
            )
            if self.tokenizer.pad_token is None:
                # Qwen-style tokenizers ship without a pad token; reuse EOS.
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        except Exception as e:
            print(f"Error loading tokenizer: {e}")
            raise  # bare raise preserves the original traceback

        # 2. Model configuration — must match pretrain.py exactly.
        if config_path and Path(config_path).exists():
            with open(config_path, 'r') as f:
                self.config = json.load(f)
        else:
            # [CRITICAL] Hard-coded mirror of the pretraining hyperparameters.
            self.config = {
                'model_dim': 1536,
                'vocab_size': len(self.tokenizer),  # adapts to Qwen (~151665)
                'n_layers': 12,
                'n_heads': 12,
                'n_kv_heads': 4,
                'head_dim': None,           # derived inside the model
                'max_seq_len': 512,
                'dropout': 0.0,             # dropout disabled at inference
                'use_moe': False,
                'use_adapter': False,
                'use_lora': False,
                'rope_scaling_type': "yarn",
                'use_multimodal_fusion': False,
                'use_contrastive': False
            }

        # 3. Instantiate the architecture and restore weights.
        print("Initializing model architecture...")
        try:
            self.model = MultiModalDenseTransformer(**self.config)
            self.preprocessor = UnifiedMultiModalPreprocessor(
                model_dim=self.config['model_dim']
            )

            print(f"Loading checkpoint from {checkpoint_path}...")
            # NOTE(review): weights_only=False unpickles arbitrary objects —
            # only load checkpoints from trusted sources.
            checkpoint = torch.load(
                checkpoint_path,
                map_location=self.device,
                weights_only=False
            )

            # A training checkpoint wraps the weights; a bare dump is the weights.
            if 'model_state_dict' in checkpoint:
                print("Found 'model_state_dict' in checkpoint.")
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint

            # Strip the 'module.' prefix that DDP training prepends to keys.
            new_state_dict = {
                (k[len('module.'):] if k.startswith('module.') else k): v
                for k, v in state_dict.items()
            }

            # strict=False tolerates non-critical mismatches (e.g. loss buffers).
            missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False)
            if missing:
                print(f"Warning: Missing keys: {len(missing)}")
            if unexpected:
                print(f"Warning: Unexpected keys: {len(unexpected)}")

            self.model.to(self.device)
            self.preprocessor.to(self.device)
            self.model.eval()

            print("Model loaded successfully!")
            print(f"Total parameters: {sum(p.numel() for p in self.model.parameters()) / 1e6:.2f}M")

        except Exception as e:
            print(f"Error initializing model: {e}")
            raise

    @torch.no_grad()
    def generate_text(
        self,
        prompt: str,
        max_new_tokens: int = 128,
        temperature: float = 0.7,
        top_k: int = 10,
        top_p: float = 0.9,
        repetition_penalty: float = 1.2,
        image: Optional[Image.Image] = None
    ) -> str:
        """Generate a response for *prompt*, optionally conditioned on *image*.

        Returns the decoded answer with the prompt echo and hallucinated
        instruction blocks stripped, or an ``"Error: ..."`` string on failure.
        """
        # Wrap the user prompt in the instruction template used at train time.
        formatted_prompt = f"Instruction: {prompt}\nResponse:"
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
        input_ids = inputs['input_ids'].to(self.device)

        # Segment-based input format expected by MultiModalDenseTransformer.
        input_data = {'segments': []}

        if image is not None:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image_tensor = image_transform(image).unsqueeze(0).to(self.device)
            try:
                # process_batch(batch_data, modality_type) returns a list of
                # ready-made segments; image segments precede the text segment.
                mod_segments = self.preprocessor.process_batch(image_tensor, 'image')
                input_data['segments'].extend(mod_segments)
            except Exception as e:
                # Best-effort: degrade to text-only generation.
                print(f"Warning: Image processing skipped due to error: {e}")

        input_data['segments'].append({
            'type': 'text',
            'data': input_ids,
            'modality_id': 0
        })

        try:
            generated_ids = self.model.generate(
                input_data,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id
            )

            full_output = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            print(f"\n====== [DEBUG 原始输出] ======\n{full_output}\n==============================\n")

            # Keep only the text after the "Response:" marker.
            if "Response:" in full_output:
                answer = full_output.split("Response:")[-1].strip()
            else:
                answer = full_output

            # The model tends to hallucinate fresh instruction sections; cut
            # at the first occurrence of any known section marker.
            stop_words = [
                "Instruction", "Input", "###", "Response",
                "User:", "Assistant:", "\n\n"  # double newline = section end
            ]
            for stop_word in stop_words:
                if stop_word in answer:
                    answer = answer.split(stop_word)[0].strip()

            # Drop a leading line that merely echoes the prompt.
            lines = answer.split('\n')
            if len(lines) > 0 and prompt.lower() in lines[0].lower():
                answer = "\n".join(lines[1:]).strip()
            return answer

        except Exception as e:
            print(f"Generation error: {e}")
            import traceback
            traceback.print_exc()
            return f"Error: {str(e)}"
| 230 |
+
|
| 231 |
+
# 全局模型实例
|
| 232 |
+
# Global model instance; populated once in main() before the server starts.
model_instance = None
app = Flask(__name__)
# Cap request bodies (base64-encoded images) at 16 MB.
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
| 235 |
+
|
| 236 |
+
@app.route('/')
def index():
    """Serve the main page, exposing the current model config to the template."""
    if model_instance:
        shown_config = model_instance.config.copy()
    else:
        shown_config = {}
    return render_template('index.html', config=shown_config)
| 240 |
+
|
| 241 |
+
@app.route('/generate', methods=['POST'])
def generate():
    """JSON API endpoint: run text generation and return ``{'output': ...}``.

    Expects a JSON body with ``prompt`` plus optional sampling parameters
    (``max_tokens``, ``temperature``, ``top_k``, ``top_p``,
    ``repetition_penalty``) and an optional base64 data-URL under ``image``.
    Returns 400 on empty prompt, 503 while the model is not loaded yet,
    500 on unexpected errors.
    """
    try:
        # Guard: the server may be reachable before main() finished loading.
        if model_instance is None:
            return jsonify({'error': '模型尚未加载完成'}), 503

        # silent=True returns None (instead of raising) on a non-JSON body.
        data = request.get_json(silent=True) or {}
        prompt = data.get('prompt', '')
        if not prompt.strip():
            return jsonify({'error': '请输入提示文本'}), 400

        max_tokens = int(data.get('max_tokens', 100))
        temperature = float(data.get('temperature', 0.7))
        top_k = int(data.get('top_k', 40))
        top_p = float(data.get('top_p', 0.9))
        repetition_penalty = float(data.get('repetition_penalty', 1.1))

        image = None
        if 'image' in data and data['image']:
            try:
                # Strip the "data:image/...;base64," prefix before decoding.
                image_data = base64.b64decode(data['image'].split(',')[1])
                image = Image.open(io.BytesIO(image_data))
            except Exception as e:
                # Best-effort: fall back to text-only generation.
                print(f"Image load error: {e}")

        output = model_instance.generate_text(
            prompt, max_tokens, temperature, top_k, top_p, repetition_penalty, image
        )
        return jsonify({'output': output})

    except Exception as e:
        return jsonify({'error': str(e)}), 500
| 270 |
+
|
| 271 |
+
def create_html_template():
    """Write the inline HTML template to ``templates/index.html``.

    Flask's ``render_template`` looks in the ``templates/`` directory, which
    is created if missing. The page is self-contained: styles and the fetch()
    call to ``/generate`` are embedded, so no static assets are needed.
    """
    html_content = '''
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Model Inference</title>
    <style>
        body { font-family: sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; background: #f0f2f5; }
        .container { background: white; padding: 30px; border-radius: 12px; box-shadow: 0 4px 6px rgba(0,0,0,0.1); }
        h1 { color: #1a73e8; text-align: center; }
        textarea { width: 100%; padding: 10px; border: 1px solid #ddd; border-radius: 8px; margin: 10px 0; min-height: 100px; }
        .controls { display: grid; grid-template-columns: 1fr 1fr; gap: 15px; margin: 20px 0; background: #f8f9fa; padding: 15px; border-radius: 8px; }
        button { background: #1a73e8; color: white; border: none; padding: 12px 24px; border-radius: 6px; cursor: pointer; width: 100%; font-size: 16px; transition: background 0.3s; }
        button:hover { background: #1557b0; }
        button:disabled { background: #ccc; }
        #output { margin-top: 20px; padding: 20px; background: #f8f9fa; border-radius: 8px; white-space: pre-wrap; min-height: 100px; border: 1px solid #e0e0e0; }
        .loading { color: #666; font-style: italic; }
    </style>
</head>
<body>
    <div class="container">
        <h1>🚀 模型在线推理</h1>

        <div>
            <label><strong>提示词 (Prompt):</strong></label>
            <textarea id="prompt" placeholder="请输入你的问题..."></textarea>
        </div>

        <div class="controls">
            <div>
                <label>Max Tokens: <span id="maxTokensVal">128</span></label>
                <input type="range" id="maxTokens" min="32" max="1024" value="128" style="width:100%" oninput="document.getElementById('maxTokensVal').innerText=this.value">
            </div>
            <div>
                <label>Temperature: <span id="tempVal">0.7</span></label>
                <input type="range" id="temperature" min="0.1" max="1.5" step="0.1" value="0.7" style="width:100%" oninput="document.getElementById('tempVal').innerText=this.value">
            </div>
        </div>

        <button id="btn" onclick="generate()">生成 (Generate)</button>

        <div id="output">结果将显示在这里...</div>
    </div>

    <script>
        async function generate() {
            const prompt = document.getElementById('prompt').value;
            if(!prompt) return alert("请输入内容");

            const btn = document.getElementById('btn');
            const out = document.getElementById('output');

            btn.disabled = true;
            btn.innerText = "生成中...";
            out.innerHTML = '<div class="loading">正在思考中...</div>';

            try {
                const res = await fetch('/generate', {
                    method: 'POST',
                    headers: {'Content-Type': 'application/json'},
                    body: JSON.stringify({
                        prompt: prompt,
                        max_tokens: parseInt(document.getElementById('maxTokens').value),
                        temperature: parseFloat(document.getElementById('temperature').value)
                    })
                });
                const data = await res.json();
                if(data.error) out.innerText = "Error: " + data.error;
                else out.innerText = data.output;
            } catch(e) {
                out.innerText = "请求失败: " + e;
            } finally {
                btn.disabled = false;
                btn.innerText = "生成 (Generate)";
            }
        }
    </script>
</body>
</html>
    '''

    Path('templates').mkdir(exist_ok=True)
    with open('templates/index.html', 'w', encoding='utf-8') as f:
        f.write(html_content)
| 359 |
+
def main():
    """CLI entry point: resolve a checkpoint, optionally open an ngrok tunnel,
    write the HTML template, load the model, and start the Flask server."""
    import argparse
    parser = argparse.ArgumentParser()
    # Defaults point at the post-training checkpoint output path.
    parser.add_argument("--checkpoint", type=str, default="/root/multimodal/checkpoints/posttrain/final_model.pt")
    parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2.5-7B-Instruct")
    parser.add_argument("--port", type=int, default=5001)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    args = parser.parse_args()

    if not Path(args.checkpoint).exists():
        # Fall back to the newest step checkpoint. glob() order is
        # filesystem-dependent, so sort by modification time explicitly —
        # otherwise steps[-1] is NOT guaranteed to be the latest.
        steps = sorted(
            Path("checkpoints/pretrain").glob("step_*.pt"),
            key=lambda p: p.stat().st_mtime
        )
        if steps:
            print(f"未找到 final_model.pt,尝试使用最新的 checkpoint: {steps[-1]}")
            args.checkpoint = str(steps[-1])
        else:
            print(f"错误: 找不到检查点文件: {args.checkpoint}")
            return

    # Optional public tunnel via ngrok (best effort; server works without it).
    try:
        from pyngrok import ngrok, conf

        # For users in mainland China the 'ap' (Asia-Pacific) region is faster:
        # conf.get_default().region = "ap"

        public_url = ngrok.connect(args.port).public_url
        print(f"\n========================================")
        print(f"🎉 公网访问地址 (发给朋友): {public_url}")
        print(f"========================================\n")
    except ImportError:
        print("未安装 pyngrok,无法自动生成公网链接。")
        print("提示: pip install pyngrok")
    except Exception as e:
        print(f"Ngrok 启动失败: {e}")

    create_html_template()

    global model_instance
    model_instance = ModelInference(args.checkpoint, args.tokenizer)

    print(f"\n服务已启动: http://{args.host}:{args.port}")
    # NOTE(review): debug=True exposes the Werkzeug debugger — disable it for
    # anything publicly reachable (e.g. through the ngrok tunnel above).
    app.run(host=args.host, port=args.port,
            debug=True,
            use_reloader=False)

if __name__ == "__main__":
    main()
model.py
ADDED
|
@@ -0,0 +1,505 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
改进的多模态Dense Transformer主模型
|
| 3 |
+
整合所有SOTA改进
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from typing import List, Dict, Optional, Tuple
|
| 9 |
+
import math
|
| 10 |
+
from components import RMSNorm
|
| 11 |
+
from transformer import OptimizedTransformerBlock
|
| 12 |
+
from multimodel_fusion import MultiModalFusionModule
|
| 13 |
+
from encoders import (
|
| 14 |
+
ImprovedVisionTransformer,
|
| 15 |
+
ImprovedAudioEncoder,
|
| 16 |
+
ImprovedVideoEncoder
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
class MultiModalDenseTransformer(nn.Module):
|
| 20 |
+
"""
|
| 21 |
+
改进的统一多模态Dense Transformer
|
| 22 |
+
主要改进:
|
| 23 |
+
1. 深度跨模态融合
|
| 24 |
+
2. 模态特定的优化编码器
|
| 25 |
+
3. 对比学习对齐
|
| 26 |
+
4. 改进的位置编码和注意力机制
|
| 27 |
+
5. 更好的训练稳定性
|
| 28 |
+
"""
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
model_dim: int = 2048,
|
| 32 |
+
vocab_size: int = 30000,
|
| 33 |
+
n_layers: int = 48,
|
| 34 |
+
n_heads: int = 32,
|
| 35 |
+
n_kv_heads: Optional[int] = None,
|
| 36 |
+
head_dim: Optional[int] = None,
|
| 37 |
+
max_seq_len: int = 8192,
|
| 38 |
+
dropout: float = 0.0,
|
| 39 |
+
attn_dropout: float = 0.0,
|
| 40 |
+
|
| 41 |
+
# MoE配置
|
| 42 |
+
use_moe: bool = False,
|
| 43 |
+
num_experts: int = 8,
|
| 44 |
+
moe_top_k: int = 2,
|
| 45 |
+
moe_layers: Optional[List[int]] = None,
|
| 46 |
+
|
| 47 |
+
# PEFT配置
|
| 48 |
+
use_adapter: bool = False,
|
| 49 |
+
adapter_dim: int = 64,
|
| 50 |
+
use_lora: bool = False,
|
| 51 |
+
lora_rank: int = 8,
|
| 52 |
+
|
| 53 |
+
# 训练配置
|
| 54 |
+
use_gradient_checkpointing: bool = False,
|
| 55 |
+
use_parallel_residual: bool = False,
|
| 56 |
+
|
| 57 |
+
# 位置编码
|
| 58 |
+
rope_scaling_factor: float = 1.0,
|
| 59 |
+
rope_scaling_type: str = "yarn",
|
| 60 |
+
sliding_window: Optional[int] = None,
|
| 61 |
+
|
| 62 |
+
# 规范化
|
| 63 |
+
norm_eps: float = 1e-6,
|
| 64 |
+
initializer_range: float = 0.02,
|
| 65 |
+
ffn_dim_multiplier: Optional[float] = None,
|
| 66 |
+
tie_word_embeddings: bool = True,
|
| 67 |
+
|
| 68 |
+
# 多模态配置
|
| 69 |
+
use_multimodal_fusion: bool = True,
|
| 70 |
+
fusion_layers: int = 4,
|
| 71 |
+
use_contrastive: bool = True,
|
| 72 |
+
vision_depth: int = 24,
|
| 73 |
+
audio_depth: int = 12,
|
| 74 |
+
video_spatial_depth: int = 12,
|
| 75 |
+
video_temporal_depth: int = 4
|
| 76 |
+
):
|
| 77 |
+
super().__init__()
|
| 78 |
+
|
| 79 |
+
self.model_dim = model_dim
|
| 80 |
+
self.vocab_size = vocab_size
|
| 81 |
+
self.n_layers = n_layers
|
| 82 |
+
self.max_seq_len = max_seq_len
|
| 83 |
+
self.use_gradient_checkpointing = use_gradient_checkpointing
|
| 84 |
+
self.tie_word_embeddings = tie_word_embeddings
|
| 85 |
+
self.use_multimodal_fusion = use_multimodal_fusion
|
| 86 |
+
|
| 87 |
+
# Token embedding
|
| 88 |
+
self.token_embedding = nn.Embedding(vocab_size, model_dim)
|
| 89 |
+
self.modality_embedding = nn.Embedding(4, model_dim)
|
| 90 |
+
self.embed_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
| 91 |
+
|
| 92 |
+
# 改进的模态编码器
|
| 93 |
+
self.vision_encoder = ImprovedVisionTransformer(
|
| 94 |
+
embed_dim=model_dim,
|
| 95 |
+
depth=vision_depth,
|
| 96 |
+
n_heads=n_heads,
|
| 97 |
+
dropout=dropout,
|
| 98 |
+
use_adapter=use_adapter,
|
| 99 |
+
adapter_dim=adapter_dim
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
self.audio_encoder = ImprovedAudioEncoder(
|
| 103 |
+
embed_dim=model_dim,
|
| 104 |
+
depth=audio_depth,
|
| 105 |
+
n_heads=n_heads,
|
| 106 |
+
dropout=dropout,
|
| 107 |
+
use_adapter=use_adapter,
|
| 108 |
+
adapter_dim=adapter_dim
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
self.video_encoder = ImprovedVideoEncoder(
|
| 112 |
+
embed_dim=model_dim,
|
| 113 |
+
spatial_depth=video_spatial_depth,
|
| 114 |
+
temporal_depth=video_temporal_depth,
|
| 115 |
+
n_heads=n_heads,
|
| 116 |
+
dropout=dropout,
|
| 117 |
+
use_adapter=use_adapter,
|
| 118 |
+
adapter_dim=adapter_dim
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
# 多模态融合模块
|
| 122 |
+
if use_multimodal_fusion:
|
| 123 |
+
self.fusion_module = MultiModalFusionModule(
|
| 124 |
+
dim=model_dim,
|
| 125 |
+
num_fusion_layers=fusion_layers,
|
| 126 |
+
n_heads=n_heads,
|
| 127 |
+
dropout=dropout,
|
| 128 |
+
use_contrastive=use_contrastive
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
# Transformer layers
|
| 132 |
+
if moe_layers is None and use_moe:
|
| 133 |
+
moe_layers = list(range(n_layers // 2, n_layers))
|
| 134 |
+
elif moe_layers is None:
|
| 135 |
+
moe_layers = []
|
| 136 |
+
|
| 137 |
+
self.layers = nn.ModuleList([
|
| 138 |
+
OptimizedTransformerBlock(
|
| 139 |
+
dim=model_dim,
|
| 140 |
+
n_heads=n_heads,
|
| 141 |
+
n_kv_heads=n_kv_heads,
|
| 142 |
+
head_dim=head_dim,
|
| 143 |
+
dropout=dropout,
|
| 144 |
+
attn_dropout=attn_dropout,
|
| 145 |
+
use_moe=(use_moe and i in moe_layers),
|
| 146 |
+
num_experts=num_experts,
|
| 147 |
+
moe_top_k=moe_top_k,
|
| 148 |
+
use_adapter=use_adapter,
|
| 149 |
+
adapter_dim=adapter_dim,
|
| 150 |
+
use_lora=use_lora,
|
| 151 |
+
lora_rank=lora_rank,
|
| 152 |
+
use_parallel_residual=use_parallel_residual,
|
| 153 |
+
norm_eps=norm_eps,
|
| 154 |
+
sliding_window=sliding_window,
|
| 155 |
+
ffn_dim_multiplier=ffn_dim_multiplier,
|
| 156 |
+
layer_idx=i
|
| 157 |
+
)
|
| 158 |
+
for i in range(n_layers)
|
| 159 |
+
])
|
| 160 |
+
|
| 161 |
+
self.norm = RMSNorm(model_dim, eps=norm_eps)
|
| 162 |
+
self.lm_head = nn.Linear(model_dim, vocab_size, bias=False)
|
| 163 |
+
|
| 164 |
+
if tie_word_embeddings:
|
| 165 |
+
self.lm_head.weight = self.token_embedding.weight
|
| 166 |
+
|
| 167 |
+
self.initializer_range = initializer_range
|
| 168 |
+
self.apply(self._init_weights)
|
| 169 |
+
|
| 170 |
+
if not tie_word_embeddings:
|
| 171 |
+
self._init_lm_head()
|
| 172 |
+
|
| 173 |
+
self.n_params = sum(p.numel() for p in self.parameters())
|
| 174 |
+
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 175 |
+
|
| 176 |
+
print(f"\n{'='*80}")
|
| 177 |
+
print(f"Improved Model Configuration:")
|
| 178 |
+
print(f" Model Dimension: {model_dim}")
|
| 179 |
+
print(f" Vocab Size: {vocab_size}")
|
| 180 |
+
print(f" Layers: {n_layers}")
|
| 181 |
+
print(f" Attention Heads: {n_heads}")
|
| 182 |
+
print(f" KV Heads: {n_kv_heads if n_kv_heads else n_heads}")
|
| 183 |
+
print(f" Max Sequence Length: {max_seq_len}")
|
| 184 |
+
print(f" Multimodal Fusion: {use_multimodal_fusion}")
|
| 185 |
+
print(f" Contrastive Learning: {use_contrastive}")
|
| 186 |
+
print(f" MoE: {use_moe} (Experts: {num_experts}, Top-K: {moe_top_k})")
|
| 187 |
+
print(f" Total Parameters: {self.n_params / 1e9:.2f}B")
|
| 188 |
+
print(f" Trainable Parameters: {trainable_params / 1e9:.2f}B")
|
| 189 |
+
print(f"{'='*80}\n")
|
| 190 |
+
|
| 191 |
+
def _init_weights(self, module):
|
| 192 |
+
"""权重初始化"""
|
| 193 |
+
if isinstance(module, nn.Linear):
|
| 194 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
|
| 195 |
+
if module.bias is not None:
|
| 196 |
+
torch.nn.init.zeros_(module.bias)
|
| 197 |
+
elif isinstance(module, nn.Embedding):
|
| 198 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
|
| 199 |
+
if hasattr(module, 'padding_idx') and module.padding_idx is not None:
|
| 200 |
+
module.weight.data[module.padding_idx].zero_()
|
| 201 |
+
|
| 202 |
+
def _init_lm_head(self):
|
| 203 |
+
"""初始化LM head"""
|
| 204 |
+
std = self.initializer_range / math.sqrt(2 * self.n_layers)
|
| 205 |
+
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=std)
|
| 206 |
+
|
| 207 |
+
def _encode_modality(self, segment: Dict) -> torch.Tensor:
|
| 208 |
+
"""编码单个模态"""
|
| 209 |
+
seg_type = segment['type']
|
| 210 |
+
seg_data = segment['data']
|
| 211 |
+
|
| 212 |
+
if seg_type == 'image':
|
| 213 |
+
return self.vision_encoder(seg_data)
|
| 214 |
+
elif seg_type == 'audio':
|
| 215 |
+
return self.audio_encoder(seg_data)
|
| 216 |
+
elif seg_type == 'video':
|
| 217 |
+
return self.video_encoder(seg_data)
|
| 218 |
+
elif seg_type == 'text':
|
| 219 |
+
return self.token_embedding(seg_data)
|
| 220 |
+
else:
|
| 221 |
+
return seg_data
|
| 222 |
+
|
| 223 |
+
    def forward(
        self,
        input_data: Dict,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        return_hidden: bool = False,
        use_cache: bool = False,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        compute_contrastive: bool = False
    ) -> Dict:
        """Forward pass over a multimodal segment dict.

        Args:
            input_data: ``{'segments': [{'type', 'data', 'modality_id'}, ...]}``.
            attention_mask: Optional mask forwarded to each transformer layer.
            position_ids: Optional explicit positions; derived from the KV
                cache length when omitted.
            return_hidden / output_*: Control which extras appear in the output.
            use_cache / past_key_values: Incremental-decoding KV cache.
            compute_contrastive: Ask the fusion module for contrastive losses.

        Returns:
            Dict with 'logits', 'moe_aux_loss', 'contrastive_losses', and
            optionally 'past_key_values', 'hidden_states', 'attentions',
            'last_hidden_state'.
        """
        device = self.token_embedding.weight.device

        # Encode each segment with its modality-specific encoder.
        encoded_segments = []
        for segment in input_data.get('segments', []):
            encoded = self._encode_modality(segment)

            # Add a learned modality-type embedding to every position.
            modality_id = segment.get('modality_id', 0)
            modality_embeds = self.modality_embedding(
                torch.tensor([modality_id], device=device)
            ).expand(encoded.shape[0], encoded.shape[1], -1)

            encoded_segments.append({
                'type': segment['type'],
                'data': encoded + modality_embeds,
                'modality_id': modality_id
            })

        # Cross-modal fusion (only engaged when >1 modality is present).
        contrastive_losses = {}
        if self.use_multimodal_fusion and len(encoded_segments) > 1:
            fusion_output = self.fusion_module(
                encoded_segments,
                compute_contrastive=compute_contrastive
            )
            x = fusion_output['fused_features']
            contrastive_losses = fusion_output.get('contrastive_losses', {})
        else:
            # Simple concatenation along the sequence dimension; an empty
            # segment list degenerates to a single zero vector.
            all_embeddings = [seg['data'] for seg in encoded_segments]
            x = torch.cat(all_embeddings, dim=1) if all_embeddings else torch.zeros(
                1, 1, self.model_dim, device=device
            )

        x = self.embed_dropout(x)
        # If no position_ids were passed in, derive them from the cache length.
        if position_ids is None:
            if past_key_values is not None:
                # Cached length (KV cache shape is [B, H, SeqLen, D]).
                past_length = past_key_values[0][0].size(2)
                # Current input length.
                seq_length = x.shape[1]
                # Positions continue after the cache: [past, past+1, ...].
                position_ids = torch.arange(
                    past_length, past_length + seq_length, dtype=torch.long, device=device
                ).unsqueeze(0).expand(x.shape[0], -1)
            else:
                # No cache: positions start from 0.
                seq_length = x.shape[1]
                position_ids = torch.arange(
                    0, seq_length, dtype=torch.long, device=device
                ).unsqueeze(0).expand(x.shape[0], -1)
        # Transformer layers
        present_key_values = [] if use_cache else None
        all_hidden_states = [] if output_hidden_states else None
        all_attentions = [] if output_attentions else None
        moe_aux_loss = torch.tensor(0.0, device=device)

        for idx, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states.append(x)

            past_kv = past_key_values[idx] if past_key_values is not None else None

            if self.use_gradient_checkpointing and self.training:
                # Checkpointed path: cache and attention outputs are disabled
                # because only tensors can cross the checkpoint boundary.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(
                            inputs[0],
                            attention_mask=inputs[1],
                            position_ids=inputs[2],
                            use_cache=False,
                            past_kv=None,
                            output_attentions=False
                        )
                    return custom_forward

                import torch.utils.checkpoint as checkpoint
                layer_outputs = checkpoint.checkpoint(
                    create_custom_forward(layer),
                    x,
                    attention_mask,
                    position_ids,
                    use_reentrant=False
                )
                x = layer_outputs[0]
                present_kv = None
                attn_weights = None
            else:
                layer_outputs = layer(
                    x,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    use_cache=use_cache,
                    past_kv=past_kv,
                    output_attentions=output_attentions
                )
                x, present_kv, attn_weights = layer_outputs

            if use_cache:
                present_key_values.append(present_kv)

            if output_attentions:
                all_attentions.append(attn_weights)

            # Accumulate the MoE load-balancing auxiliary loss, if this layer
            # exposes one (only MoE-enabled layers do).
            if hasattr(layer, 'moe_aux_loss'):
                moe_aux_loss += layer.moe_aux_loss

        hidden_states = self.norm(x)
        logits = self.lm_head(hidden_states)

        if output_hidden_states:
            all_hidden_states.append(hidden_states)

        # Assemble the output dict.
        outputs = {
            'logits': logits,
            'moe_aux_loss': moe_aux_loss,
            'contrastive_losses': contrastive_losses
        }

        if use_cache:
            outputs['past_key_values'] = present_key_values

        if output_hidden_states:
            outputs['hidden_states'] = all_hidden_states

        if output_attentions:
            outputs['attentions'] = all_attentions

        if return_hidden:
            outputs['last_hidden_state'] = hidden_states

        return outputs
|
| 371 |
+
|
| 372 |
+
@torch.no_grad()
def generate(
    self,
    input_data: Dict,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.9,
    eos_token_id: int = 2,
    pad_token_id: Optional[int] = None,
    use_cache: bool = True,
    repetition_penalty: float = 1.0,
    length_penalty: float = 1.0,
    min_length: int = 0,
    do_sample: bool = True,
    num_beams: int = 1
) -> torch.Tensor:
    """Autoregressive generation with optional KV-cache.

    Supports temperature scaling, top-k / top-p (nucleus) filtering,
    CTRL-style repetition penalty and a minimum-length constraint.
    Finished sequences keep emitting ``pad_token_id`` until all rows
    are done or ``max_new_tokens`` is reached.

    Args:
        input_data: dict with ``segments[0]['data']`` holding the prompt
            token ids ``[B, T]`` and optionally an ``attention_mask``.
        max_new_tokens: maximum number of tokens to generate.
        temperature: sampling temperature (clamped to >= 1e-5).
        top_k: keep only the k highest-probability tokens (0 disables).
        top_p: nucleus sampling threshold (1.0 disables).
        eos_token_id: token that marks a sequence as finished.
        pad_token_id: filler for finished rows; defaults to ``eos_token_id``.
        use_cache: reuse past key/values and feed only the newest token.
        repetition_penalty: >1.0 discourages previously generated tokens.
        length_penalty: accepted for API compatibility — currently unused.
        min_length: EOS is forbidden before this many new tokens.
        do_sample: sample from the filtered distribution vs. greedy argmax.
        num_beams: accepted for API compatibility — currently unused.

    Returns:
        Generated token ids ``[B, <=max_new_tokens]`` (prompt excluded).
    """
    self.eval()
    device = next(self.parameters()).device

    if pad_token_id is None:
        pad_token_id = eos_token_id

    initial_text_tokens = input_data['segments'][0]['data'].to(device)
    batch_size = initial_text_tokens.shape[0]

    if 'attention_mask' in input_data:
        attention_mask = input_data['attention_mask'].to(device)
    else:
        attention_mask = torch.ones_like(initial_text_tokens)

    # Positions count only non-pad tokens (left-padding safe); pad slots get 0.
    position_ids = (torch.cumsum(attention_mask.long(), dim=1) - 1).clamp(min=0)
    position_ids = position_ids * attention_mask.long()

    generated_tokens = []
    past_key_values = None
    current_tokens = initial_text_tokens
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)

    for step in range(max_new_tokens):
        current_input_data = {
            'segments': [{'type': 'text', 'data': current_tokens, 'modality_id': 0}]
        }

        if step > 0:
            # Bugfix: the mask must grow by one column every step, with or
            # without the KV cache; the original only extended it when
            # use_cache was True, so the no-cache path fed a growing
            # sequence against a stale, shorter mask.
            new_mask = torch.ones(batch_size, 1, dtype=torch.long, device=device)
            attention_mask = torch.cat([attention_mask, new_mask], dim=1)

        if step > 0 and use_cache:
            # Only the newest token is fed: its position is the number of
            # real (non-pad) tokens so far, minus one.
            current_position_ids = (attention_mask.sum(dim=1, keepdim=True) - 1).clamp(min=0)
        elif step > 0:
            # Bugfix: the whole sequence is re-fed, so positions must be
            # recomputed for the full (extended) mask instead of reusing
            # the prompt-length position_ids.
            current_position_ids = (torch.cumsum(attention_mask.long(), dim=1) - 1).clamp(min=0)
            current_position_ids = current_position_ids * attention_mask.long()
        else:
            current_position_ids = position_ids

        outputs = self.forward(
            current_input_data,
            attention_mask=attention_mask,
            position_ids=current_position_ids,
            use_cache=use_cache,
            past_key_values=past_key_values
        )

        logits = outputs['logits']
        if use_cache:
            past_key_values = outputs['past_key_values']

        next_token_logits = logits[:, -1, :] / max(temperature, 1e-5)

        # Repetition penalty (CTRL-style): shrink positive scores of seen
        # tokens, amplify negative ones.
        if repetition_penalty != 1.0 and generated_tokens:
            prev_generated = torch.cat(generated_tokens, dim=1)
            score = torch.gather(next_token_logits, 1, prev_generated)
            score = torch.where(
                score < 0,
                score * repetition_penalty,
                score / repetition_penalty
            )
            next_token_logits.scatter_(1, prev_generated, score)

        # Forbid EOS until min_length new tokens have been produced.
        if step < min_length:
            next_token_logits[:, eos_token_id] = float('-inf')

        if do_sample:
            if top_k > 0:
                top_k_vals, _ = torch.topk(next_token_logits, top_k)
                min_val_to_keep = top_k_vals[:, -1].unsqueeze(-1)
                next_token_logits[next_token_logits < min_val_to_keep] = float('-inf')

            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token above the threshold is kept.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = torch.zeros_like(next_token_logits, dtype=torch.bool)
                indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
                next_token_logits[indices_to_remove] = float('-inf')

            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

        # Finished rows keep emitting pad_token_id.
        next_token = next_token * unfinished_sequences[:, None] + pad_token_id * (1 - unfinished_sequences[:, None])

        generated_tokens.append(next_token)

        if not use_cache:
            initial_text_tokens = torch.cat([initial_text_tokens, next_token], dim=1)
            current_tokens = initial_text_tokens
        else:
            current_tokens = next_token

        unfinished_sequences = unfinished_sequences.mul(
            (next_token.squeeze(-1) != eos_token_id).long()
        )

        if unfinished_sequences.max() == 0:
            break

    if not generated_tokens:
        return torch.empty(batch_size, 0, dtype=torch.long, device=device)

    return torch.cat(generated_tokens, dim=1)
|
moe.py
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
优化的混合专家系统 (Mixture of Experts)
|
| 3 |
+
基于Mixtral、Switch Transformer、GLaM的最佳实践
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from typing import Tuple, Optional, List
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
class Expert(nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
单个专家网络
|
| 14 |
+
使用SwiGLU激活函数以获得更好的性能
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
dim: int,
|
| 20 |
+
hidden_dim: int,
|
| 21 |
+
dropout: float = 0.0,
|
| 22 |
+
bias: bool = False
|
| 23 |
+
):
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.w1 = nn.Linear(dim, hidden_dim, bias=bias)
|
| 26 |
+
self.w2 = nn.Linear(hidden_dim, dim, bias=bias)
|
| 27 |
+
self.w3 = nn.Linear(dim, hidden_dim, bias=bias)
|
| 28 |
+
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
| 29 |
+
|
| 30 |
+
self._init_weights()
|
| 31 |
+
|
| 32 |
+
def _init_weights(self):
|
| 33 |
+
"""改进的权重初始化"""
|
| 34 |
+
for module in [self.w1, self.w2, self.w3]:
|
| 35 |
+
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 36 |
+
if module.bias is not None:
|
| 37 |
+
nn.init.zeros_(module.bias)
|
| 38 |
+
|
| 39 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 40 |
+
"""
|
| 41 |
+
前向传播
|
| 42 |
+
SwiGLU: (Swish(W1·x) ⊙ W3·x) W2
|
| 43 |
+
"""
|
| 44 |
+
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
|
| 45 |
+
|
| 46 |
+
class TopKRouter(nn.Module):
|
| 47 |
+
"""
|
| 48 |
+
Top-K路由器 - 改进版
|
| 49 |
+
改进点:
|
| 50 |
+
1. 专家容量管理
|
| 51 |
+
2. 负载均衡
|
| 52 |
+
3. 训练时的噪声注入
|
| 53 |
+
4. Z-loss防止logits爆炸
|
| 54 |
+
|
| 55 |
+
参考:
|
| 56 |
+
- Switch Transformer
|
| 57 |
+
- Mixtral 8x7B
|
| 58 |
+
- ST-MoE
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
def __init__(
|
| 62 |
+
self,
|
| 63 |
+
dim: int,
|
| 64 |
+
num_experts: int,
|
| 65 |
+
top_k: int = 2,
|
| 66 |
+
capacity_factor: float = 1.25,
|
| 67 |
+
noise_std: float = 1.0,
|
| 68 |
+
use_expert_capacity: bool = True,
|
| 69 |
+
router_z_loss_coef: float = 0.001,
|
| 70 |
+
router_aux_loss_coef: float = 0.01
|
| 71 |
+
):
|
| 72 |
+
super().__init__()
|
| 73 |
+
self.num_experts = num_experts
|
| 74 |
+
self.top_k = top_k
|
| 75 |
+
self.capacity_factor = capacity_factor
|
| 76 |
+
self.noise_std = noise_std
|
| 77 |
+
self.use_expert_capacity = use_expert_capacity
|
| 78 |
+
self.router_z_loss_coef = router_z_loss_coef
|
| 79 |
+
self.router_aux_loss_coef = router_aux_loss_coef
|
| 80 |
+
|
| 81 |
+
self.gate = nn.Linear(dim, num_experts, bias=False)
|
| 82 |
+
|
| 83 |
+
nn.init.normal_(self.gate.weight, mean=0.0, std=0.02)
|
| 84 |
+
|
| 85 |
+
def _compute_routing_weights(
|
| 86 |
+
self,
|
| 87 |
+
logits: torch.Tensor,
|
| 88 |
+
use_noise: bool = True
|
| 89 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 90 |
+
"""
|
| 91 |
+
计算路由权重
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
logits: 路由logits [batch*seq_len, num_experts]
|
| 95 |
+
use_noise: 是否添加噪声
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
top_k_gates: Top-K门控值 [batch*seq_len, top_k]
|
| 99 |
+
top_k_indices: Top-K专家索引 [batch*seq_len, top_k]
|
| 100 |
+
"""
|
| 101 |
+
if use_noise and self.training:
|
| 102 |
+
noise = torch.randn_like(logits) * self.noise_std
|
| 103 |
+
logits = logits + noise
|
| 104 |
+
|
| 105 |
+
top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
|
| 106 |
+
|
| 107 |
+
top_k_gates = F.softmax(top_k_logits, dim=-1)
|
| 108 |
+
|
| 109 |
+
return top_k_gates, top_k_indices
|
| 110 |
+
|
| 111 |
+
def _compute_auxiliary_loss(
|
| 112 |
+
self,
|
| 113 |
+
logits: torch.Tensor,
|
| 114 |
+
top_k_indices: torch.Tensor
|
| 115 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 116 |
+
"""
|
| 117 |
+
计算辅助损失
|
| 118 |
+
|
| 119 |
+
包括:
|
| 120 |
+
1. 负载均衡损失(确保专家被均匀使用)
|
| 121 |
+
2. Z-loss(防止logits过大)
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
logits: 路由logits [batch*seq_len, num_experts]
|
| 125 |
+
top_k_indices: 选中的专家索引 [batch*seq_len, top_k]
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
load_balance_loss: 负载均衡损失
|
| 129 |
+
z_loss: Z-loss
|
| 130 |
+
"""
|
| 131 |
+
num_tokens = logits.shape[0]
|
| 132 |
+
|
| 133 |
+
router_probs = F.softmax(logits, dim=-1)
|
| 134 |
+
|
| 135 |
+
expert_probs = router_probs.mean(dim=0)
|
| 136 |
+
|
| 137 |
+
expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
|
| 138 |
+
expert_freq = expert_mask.sum(dim=[0, 1]) / (num_tokens * self.top_k)
|
| 139 |
+
|
| 140 |
+
load_balance_loss = self.num_experts * torch.sum(expert_probs * expert_freq)
|
| 141 |
+
|
| 142 |
+
z_loss = torch.mean(logits ** 2)
|
| 143 |
+
|
| 144 |
+
return load_balance_loss, z_loss
|
| 145 |
+
|
| 146 |
+
def forward(
|
| 147 |
+
self,
|
| 148 |
+
x: torch.Tensor
|
| 149 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 150 |
+
"""
|
| 151 |
+
前向传播
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
x: 输入 [batch*seq_len, dim]
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
top_k_gates: 门控权重 [batch*seq_len, top_k]
|
| 158 |
+
top_k_indices: 专家索引 [batch*seq_len, top_k]
|
| 159 |
+
auxiliary_loss: 辅助损失(标量)
|
| 160 |
+
"""
|
| 161 |
+
logits = self.gate(x)
|
| 162 |
+
|
| 163 |
+
top_k_gates, top_k_indices = self._compute_routing_weights(
|
| 164 |
+
logits, use_noise=self.training
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
if self.training:
|
| 168 |
+
load_balance_loss, z_loss = self._compute_auxiliary_loss(logits, top_k_indices)
|
| 169 |
+
auxiliary_loss = (
|
| 170 |
+
self.router_aux_loss_coef * load_balance_loss +
|
| 171 |
+
self.router_z_loss_coef * z_loss
|
| 172 |
+
)
|
| 173 |
+
else:
|
| 174 |
+
auxiliary_loss = torch.tensor(0.0, device=x.device)
|
| 175 |
+
|
| 176 |
+
return top_k_gates, top_k_indices, auxiliary_loss
|
| 177 |
+
|
| 178 |
+
class MixtureOfExperts(nn.Module):
    """Sparse Mixture-of-Experts layer.

    Each token is routed to its top-k experts; expert outputs are scaled
    by the gate weights and scattered back into a dense output. Optional
    per-expert capacity randomly drops overflow tokens (their residual
    contribution is simply zero).

    References: Mixtral 8x7B, Switch Transformer, GShard.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int = 8,
        expert_hidden_dim: Optional[int] = None,
        top_k: int = 2,
        dropout: float = 0.0,
        capacity_factor: float = 1.25,
        use_expert_capacity: bool = True,
        router_z_loss_coef: float = 0.001,
        router_aux_loss_coef: float = 0.01,
        noise_std: float = 1.0,
        ffn_dim_multiplier: Optional[float] = None
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.use_expert_capacity = use_expert_capacity

        if expert_hidden_dim is None:
            if ffn_dim_multiplier is not None:
                expert_hidden_dim = int(dim * ffn_dim_multiplier)
            else:
                # LLaMA-style 8/3 expansion, rounded up to a multiple of 256.
                expert_hidden_dim = int(2 * dim * 4 / 3)
                expert_hidden_dim = 256 * ((expert_hidden_dim + 255) // 256)

        self.experts = nn.ModuleList([
            Expert(dim, expert_hidden_dim, dropout, bias=False)
            for _ in range(num_experts)
        ])

        self.router = TopKRouter(
            dim=dim,
            num_experts=num_experts,
            top_k=top_k,
            capacity_factor=capacity_factor,
            noise_std=noise_std,
            use_expert_capacity=use_expert_capacity,
            router_z_loss_coef=router_z_loss_coef,
            router_aux_loss_coef=router_aux_loss_coef
        )

    def _compute_expert_capacity(self, num_tokens: int) -> int:
        """Per-expert token budget; unlimited when capacity is disabled."""
        if not self.use_expert_capacity:
            return num_tokens

        budget = int(
            (num_tokens / self.num_experts) * self.capacity_factor * self.top_k
        )
        return max(budget, 1)

    def forward(
        self,
        x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE layer.

        Args:
            x: input ``[batch, seq_len, dim]``.

        Returns:
            output ``[batch, seq_len, dim]`` and the scalar router
            auxiliary loss.
        """
        bsz, seqlen, dim = x.shape
        total_tokens = bsz * seqlen

        flat = x.view(-1, dim)

        gates, expert_ids, aux_loss = self.router(flat)

        combined = torch.zeros_like(flat)

        capacity = self._compute_expert_capacity(total_tokens)

        for eid, expert in enumerate(self.experts):
            # Rows routed to this expert and which of their k slots chose it.
            rows, slots = torch.where(expert_ids == eid)

            if rows.numel() == 0:
                continue

            if self.use_expert_capacity and rows.numel() > capacity:
                # Over capacity: keep a random subset, drop the rest.
                keep = torch.randperm(rows.numel(), device=x.device)[:capacity]
                rows = rows[keep]
                slots = slots[keep]

            weights = gates[rows, slots].unsqueeze(-1)
            combined.index_add_(0, rows, expert(flat[rows]) * weights)

        return combined.view(bsz, seqlen, dim), aux_loss
|
| 296 |
+
|
| 297 |
+
class SparseDispatcher:
    """Helper that routes token rows to experts and merges them back.

    Given per-token gate weights and a single expert assignment per
    token, precomputes each expert's row indices once so that dispatch
    and combine are simple gathers/scatters.
    """

    def __init__(
        self,
        num_experts: int,
        gates: torch.Tensor,
        expert_indices: torch.Tensor
    ):
        """
        Args:
            num_experts: number of experts.
            gates: gate weights ``[batch_size, num_experts]``.
            expert_indices: chosen expert per row ``[batch_size]``.
        """
        self.num_experts = num_experts
        self._gates = gates
        self._expert_indices = expert_indices

        # Per expert, the row indices assigned to it (possibly empty).
        self._expert_masks = [
            (expert_indices == i).nonzero(as_tuple=True)[0]
            for i in range(num_experts)
        ]

    def dispatch(self, inp: torch.Tensor) -> List[torch.Tensor]:
        """Split the input rows among experts.

        Args:
            inp: input tensor ``[batch_size, dim]``.

        Returns:
            one input tensor per expert (empty tensor when unused).
        """
        feature_dim = inp.size(-1)
        batches = []
        for rows in self._expert_masks:
            if rows.numel() > 0:
                batches.append(inp[rows])
            else:
                batches.append(
                    torch.empty(0, feature_dim, device=inp.device, dtype=inp.dtype)
                )
        return batches

    def combine(self, expert_outputs: List[torch.Tensor]) -> torch.Tensor:
        """Merge gate-weighted expert outputs back into row order.

        Args:
            expert_outputs: one output tensor per expert.

        Returns:
            combined output ``[batch_size, dim]``.
        """
        merged = torch.zeros(
            (self._gates.size(0), expert_outputs[0].size(-1)),
            device=self._gates.device,
            dtype=expert_outputs[0].dtype
        )

        for eid, produced in enumerate(expert_outputs):
            rows = self._expert_masks[eid]
            if rows.numel() > 0:
                merged[rows] += produced * self._gates[rows, eid].unsqueeze(-1)

        return merged

    def expert_to_gates(self) -> List[torch.Tensor]:
        """Return, per expert, the gate weights of its assigned rows.

        Returns:
            one 1-D gate tensor per expert (empty when unused).
        """
        per_expert = []
        for eid in range(self.num_experts):
            rows = self._expert_masks[eid]
            if rows.numel() > 0:
                per_expert.append(self._gates[rows, eid])
            else:
                per_expert.append(torch.empty(0, device=self._gates.device))
        return per_expert
|
| 384 |
+
|
| 385 |
+
class MoELayer(nn.Module):
    """Alternative Mixture-of-Experts layer.

    NOTE: despite the earlier docstring, this layer does NOT use
    SparseDispatcher — it performs a direct per-expert gather/scatter
    loop, mirroring MixtureOfExperts but with a bare linear gate and an
    always-computed load-balance auxiliary loss.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int = 8,
        expert_hidden_dim: Optional[int] = None,
        top_k: int = 2,
        dropout: float = 0.0,
        capacity_factor: float = 1.25
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        if expert_hidden_dim is None:
            # LLaMA-style 8/3 expansion, rounded up to a multiple of 256.
            expert_hidden_dim = int(2 * dim * 4 / 3)
            expert_hidden_dim = 256 * ((expert_hidden_dim + 255) // 256)

        self.experts = nn.ModuleList([
            Expert(dim, expert_hidden_dim, dropout)
            for _ in range(num_experts)
        ])

        self.gate = nn.Linear(dim, num_experts, bias=False)
        nn.init.normal_(self.gate.weight, std=0.02)

        # Stored for API parity; this implementation does not drop tokens.
        self.capacity_factor = capacity_factor

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the MoE layer.

        Args:
            x: input ``[batch, seq_len, dim]``.

        Returns:
            output ``[batch, seq_len, dim]`` and the scalar
            load-balancing auxiliary loss.
        """
        B, T, D = x.shape
        x_flat = x.view(-1, D)

        gates = F.softmax(self.gate(x_flat), dim=-1)

        top_k_gates, top_k_indices = torch.topk(gates, self.top_k, dim=-1)
        # Bugfix: the original applied softmax a second time to these
        # already-softmaxed probabilities, which squashes the top-k
        # distribution toward uniform. The correct renormalization is to
        # divide by their sum so the kept weights again sum to 1.
        top_k_gates = top_k_gates / top_k_gates.sum(dim=-1, keepdim=True)

        # Load-balance loss: mean routing probability x routing frequency.
        expert_probs = gates.mean(dim=0)
        expert_counts = F.one_hot(top_k_indices, self.num_experts).float().sum(dim=[0, 1])
        expert_counts = expert_counts / (B * T * self.top_k)
        aux_loss = self.num_experts * torch.sum(expert_probs * expert_counts)

        output = torch.zeros_like(x_flat)

        for expert_idx, expert in enumerate(self.experts):
            # Rows routed to this expert and which of their k slots chose it.
            token_indices, topk_positions = torch.where(top_k_indices == expert_idx)

            if len(token_indices) == 0:
                continue

            expert_input = x_flat[token_indices]
            expert_gates = top_k_gates[token_indices, topk_positions]

            expert_output = expert(expert_input) * expert_gates.unsqueeze(-1)

            output.index_add_(0, token_indices, expert_output)

        return output.view(B, T, D), aux_loss
|
multimodel_fusion.py
ADDED
|
@@ -0,0 +1,522 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
跨模态融合模块 - SOTA级别
|
| 3 |
+
支持深度跨模态交互、对比学习、模态对齐
|
| 4 |
+
修复版本:解决了所有接口不匹配和潜在bug
|
| 5 |
+
"""
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 10 |
+
from components import RMSNorm
|
| 11 |
+
from transformer import GroupedQueryAttention
|
| 12 |
+
import math
|
| 13 |
+
from contrastive_learning import MultiModalContrastiveLoss
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class CrossModalAttention(nn.Module):
    """Multi-head attention where the query comes from one modality and
    the keys/values come from another, enabling cross-modal interaction.

    Queries and keys are RMS-normalized before projection; attention
    uses PyTorch's fused scaled_dot_product_attention when available.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int = 16,
        dropout: float = 0.1,
        qkv_bias: bool = True
    ):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5

        assert dim % n_heads == 0, f"dim {dim} must be divisible by n_heads {n_heads}"

        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.o_proj = nn.Linear(dim, dim)

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        self.norm_q = RMSNorm(dim)
        self.norm_k = RMSNorm(dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: query-modality features ``[B, T_q, D]``.
            key: key-modality features ``[B, T_k, D]``.
            value: value features ``[B, T_k, D]`` (usually same as key).
            attention_mask: additive mask (or bool mask for the fused path).

        Returns:
            attended features ``[B, T_q, D]``.
        """
        B, T_q, D = query.shape
        T_k = key.shape[1]

        # Pre-normalize queries and keys (values stay raw).
        query = self.norm_q(query)
        key = self.norm_k(key)

        def split_heads(t: torch.Tensor, length: int) -> torch.Tensor:
            # [B, L, D] -> [B, heads, L, head_dim]
            return t.view(B, length, self.n_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(query), T_q)
        k = split_heads(self.k_proj(key), T_k)
        v = split_heads(self.v_proj(value), T_k)

        if hasattr(F, 'scaled_dot_product_attention'):
            # Fused kernel path (dropout only while training).
            ctx = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attention_mask,
                dropout_p=self.attn_dropout.p if self.training else 0.0,
                is_causal=False
            )
        else:
            # Manual fallback: scaled scores + additive mask + softmax.
            scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
            if attention_mask is not None:
                scores = scores + attention_mask
            weights = self.attn_dropout(F.softmax(scores, dim=-1))
            ctx = torch.matmul(weights, v)

        merged = ctx.transpose(1, 2).contiguous().view(B, T_q, D)
        return self.resid_dropout(self.o_proj(merged))
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class ModalityProjector(nn.Module):
|
| 94 |
+
"""模态投影器 - 将不同模态投影到统一空间"""
|
| 95 |
+
def __init__(
|
| 96 |
+
self,
|
| 97 |
+
input_dim: int,
|
| 98 |
+
output_dim: int,
|
| 99 |
+
hidden_dim: Optional[int] = None,
|
| 100 |
+
num_layers: int = 2,
|
| 101 |
+
use_layer_norm: bool = True
|
| 102 |
+
):
|
| 103 |
+
super().__init__()
|
| 104 |
+
if hidden_dim is None:
|
| 105 |
+
hidden_dim = (input_dim + output_dim) // 2
|
| 106 |
+
|
| 107 |
+
layers = []
|
| 108 |
+
for i in range(num_layers):
|
| 109 |
+
if i == 0:
|
| 110 |
+
layers.append(nn.Linear(input_dim, hidden_dim))
|
| 111 |
+
elif i == num_layers - 1:
|
| 112 |
+
layers.append(nn.Linear(hidden_dim, output_dim))
|
| 113 |
+
else:
|
| 114 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim))
|
| 115 |
+
|
| 116 |
+
if i < num_layers - 1:
|
| 117 |
+
if use_layer_norm:
|
| 118 |
+
layers.append(RMSNorm(hidden_dim))
|
| 119 |
+
layers.append(nn.GELU())
|
| 120 |
+
|
| 121 |
+
self.projector = nn.Sequential(*layers)
|
| 122 |
+
|
| 123 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 124 |
+
return self.projector(x)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class ModalityAdapter(nn.Module):
    """Per-modality residual bottleneck adapters.

    Each modality gets a small down-project / GELU / up-project branch
    added residually. The up-projection is zero-initialized, so every
    adapter starts as an exact identity mapping.
    """

    def __init__(
        self,
        dim: int,
        bottleneck_dim: int = 64,
        num_modalities: int = 4
    ):
        super().__init__()
        self.adapters = nn.ModuleList([
            nn.Sequential(
                nn.Linear(dim, bottleneck_dim),
                nn.GELU(),
                nn.Linear(bottleneck_dim, dim)
            )
            for _ in range(num_modalities)
        ])
        # Zero the output projection so training starts from identity.
        for branch in self.adapters:
            nn.init.zeros_(branch[-1].weight)
            nn.init.zeros_(branch[-1].bias)

    def forward(self, x: torch.Tensor, modality_id: int) -> torch.Tensor:
        """Apply the residual adapter for ``modality_id``.

        Unknown modality ids (>= num_modalities) pass through unchanged.
        """
        if modality_id >= len(self.adapters):
            return x
        return x + self.adapters[modality_id](x)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class CrossModalFusionLayer(nn.Module):
    """Cross-modal fusion layer.

    A pre-norm transformer block: self-attention, optional cross-attention
    onto another modality's context, a feed-forward network, and an optional
    per-modality adapter at the end.
    """
    def __init__(
        self,
        dim: int,
        n_heads: int = 16,
        dropout: float = 0.1,
        use_adapter: bool = True,
        adapter_dim: int = 64
    ):
        super().__init__()
        self.dim = dim
        self.use_adapter = use_adapter

        # Intra-modality self-attention.
        self.self_attn = GroupedQueryAttention(
            dim=dim,
            n_heads=n_heads,
            dropout=dropout,
            attn_dropout=dropout
        )

        # Attention from this modality's tokens onto other-modality tokens.
        self.cross_attn = CrossModalAttention(
            dim=dim,
            n_heads=n_heads,
            dropout=dropout
        )

        # Position-wise feed-forward network with 4x expansion.
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout),
        )

        # One pre-norm per sub-block.
        self.norm1 = RMSNorm(dim)
        self.norm2 = RMSNorm(dim)
        self.norm3 = RMSNorm(dim)

        # Optional per-modality adapter (identity at init).
        self.adapter = ModalityAdapter(dim, adapter_dim) if use_adapter else None

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        modality_id: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: current modality features [B, T, D]
            context: other modalities' context tokens [B, T_ctx, D]
            modality_id: modality index used to select the adapter
            attention_mask: mask forwarded to the self-attention step
        """
        # Self-attention; GroupedQueryAttention returns a tuple whose first
        # element is the attended output.
        x = x + self.self_attn(self.norm1(x), attention_mask=attention_mask)[0]

        # Cross-modal attention only when context from other modalities exists.
        if context is not None:
            x = x + self.cross_attn(self.norm2(x), context, context, attention_mask=None)

        # Feed-forward block.
        x = x + self.ffn(self.norm3(x))

        # Per-modality adapter (skipped when disabled or id not supplied).
        if self.use_adapter and modality_id is not None and self.adapter is not None:
            x = self.adapter(x, modality_id)

        return x
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class PerceiverResampler(nn.Module):
    """Perceiver Resampler - compresses a variable-length feature sequence
    into a fixed number of learned latent tokens via repeated cross-attention.
    """
    def __init__(
        self,
        dim: int,
        depth: int = 6,
        num_latents: int = 64,
        n_heads: int = 16,
        dropout: float = 0.0
    ):
        super().__init__()
        self.num_latents = num_latents
        # Learned latent queries, shared across the batch.
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        # Fusion layers without adapters: latents attend to the input sequence.
        self.layers = nn.ModuleList(
            CrossModalFusionLayer(
                dim=dim,
                n_heads=n_heads,
                dropout=dropout,
                use_adapter=False
            )
            for _ in range(depth)
        )

        self.norm = RMSNorm(dim)

        # Small-std init keeps the latents well-scaled at the start.
        nn.init.trunc_normal_(self.latents, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, D] input features
        Returns:
            [B, num_latents, D] compressed features
        """
        batch = x.shape[0]
        # Broadcast the shared latents over the batch dimension.
        queries = self.latents.unsqueeze(0).expand(batch, -1, -1)

        # Each layer refines the latents by attending to the full input.
        for block in self.layers:
            queries = block(queries, context=x)

        return self.norm(queries)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class MultiModalFusionModule(nn.Module):
    """Multimodal fusion module - integrates all fusion strategies.

    Pipeline: per-modality projection -> optional Perceiver compression for
    non-text modalities -> stacked cross-modal fusion layers (each modality
    attends to the concatenation of all others) -> optional contrastive
    losses between text and every other modality.
    """
    def __init__(
        self,
        dim: int = 2048,
        num_fusion_layers: int = 4,
        n_heads: int = 16,
        dropout: float = 0.1,
        use_perceiver: bool = True,
        num_latents: int = 64,
        use_contrastive: bool = True,
        contrastive_loss_type: str = 'siglip',
        contrastive_embed_dim: int = 512
    ):
        super().__init__()
        self.dim = dim
        self.use_perceiver = use_perceiver
        self.use_contrastive = use_contrastive

        # Modality projectors (dim -> dim; one per supported modality).
        self.modality_projectors = nn.ModuleDict({
            'image': ModalityProjector(dim, dim),
            'audio': ModalityProjector(dim, dim),
            'video': ModalityProjector(dim, dim),
            'text': ModalityProjector(dim, dim)
        })

        # Cross-modal fusion layers (adapters enabled, keyed by modality id).
        self.fusion_layers = nn.ModuleList([
            CrossModalFusionLayer(
                dim=dim,
                n_heads=n_heads,
                dropout=dropout,
                use_adapter=True
            )
            for _ in range(num_fusion_layers)
        ])

        # Perceiver Resampler - compresses non-text modalities to num_latents
        # tokens before fusion.
        if use_perceiver:
            self.perceiver = PerceiverResampler(
                dim=dim,
                depth=4,
                num_latents=num_latents,
                n_heads=n_heads,
                dropout=dropout
            )

        # Contrastive-learning module.
        if use_contrastive:
            # Per-modality pooling type used by the contrastive heads.
            modality_config = {
                'text': 'cls',
                'image': 'cls',
                'audio': 'mean',
                'video': 'mean'
            }

            # All modalities share the fusion width `dim` at this point.
            input_dims = {k: dim for k in modality_config.keys()}

            self.contrastive_module = MultiModalContrastiveLoss(
                embed_dim=contrastive_embed_dim,
                input_dims=input_dims,
                temperature=0.07,
                loss_type=contrastive_loss_type,
                modality_config=modality_config
            )

        self.final_norm = RMSNorm(dim)

    def _pool_features(self, features: torch.Tensor) -> torch.Tensor:
        """Mean-pool features to a single vector: [B, T, D] -> [B, D].

        NOTE(review): currently unused by forward() — contrastive features are
        passed un-pooled so the projection head can pool them itself.
        """
        if features.dim() == 3:
            return features.mean(dim=1)
        return features

    def forward(
        self,
        segments: List[Dict],
        compute_contrastive: bool = False
    ) -> Dict:
        """
        Args:
            segments: list of dicts, each containing {'type', 'data', 'modality_id'}
                - type: str, modality type ('image', 'audio', 'video', 'text')
                - data: Tensor [B, T, D], modality features
                - modality_id: int, modality ID (0-3)
            compute_contrastive: whether to compute contrastive losses

        Returns:
            Dict containing:
                - fused_features: fused feature sequence (all modalities concatenated)
                - modality_features: dict of per-modality fused features
                - contrastive_losses: dict of contrastive losses (may be empty)
        """
        # Split segments by modality.
        modality_features = {}
        modality_ids = {}

        for seg in segments:
            mod_type = seg['type']
            mod_data = seg['data']
            mod_id = seg['modality_id']

            # Validate the expected [B, T, D] layout.
            if mod_data.dim() != 3:
                raise ValueError(
                    f"Expected 3D tensor [B, T, D] for modality {mod_type}, "
                    f"got shape {mod_data.shape}"
                )

            # Project into the shared space (unknown types pass through).
            if mod_type in self.modality_projectors:
                projected = self.modality_projectors[mod_type](mod_data)
            else:
                projected = mod_data

            # Compress non-text modalities with the Perceiver (optional).
            if self.use_perceiver and mod_type != 'text':
                projected = self.perceiver(projected)

            modality_features[mod_type] = projected
            modality_ids[mod_type] = mod_id

        # Cross-modal fusion: each modality attends to all the others.
        fused_features = {}

        for mod_type, features in modality_features.items():
            # Context = concatenation of every OTHER modality's tokens.
            if len(modality_features) > 1:
                other_features = torch.cat([
                    f for k, f in modality_features.items() if k != mod_type
                ], dim=1)
            else:
                other_features = None

            # Run the modality's tokens through every fusion layer.
            fused = features
            for layer in self.fusion_layers:
                fused = layer(
                    fused,
                    context=other_features,
                    modality_id=modality_ids[mod_type]
                )

            fused_features[mod_type] = self.final_norm(fused)

        # Contrastive losses (only when requested and enabled).
        contrastive_losses = {}
        if compute_contrastive and self.use_contrastive:
            # Keep the 3D features - the projection heads do their own pooling.
            pooled_features = fused_features

            # Contrast each non-text modality against text.
            modality_pairs = []
            if 'text' in pooled_features:
                for mod in pooled_features.keys():
                    if mod != 'text':
                        modality_pairs.append((mod, 'text'))

            # Invoke the contrastive module.
            if modality_pairs:
                contrastive_losses = self.contrastive_module(
                    pooled_features,
                    modality_pairs=modality_pairs
                )

        # Concatenate every modality's fused tokens into one sequence.
        # NOTE: relies on dict insertion order (segments order) for layout.
        fused_sequence = torch.cat(list(fused_features.values()), dim=1)

        return {
            'fused_features': fused_sequence,
            'modality_features': fused_features,
            'contrastive_losses': contrastive_losses
        }
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
class EarlyFusionModule(nn.Module):
    """Early fusion - joins the modalities at a shallow stage by simple
    concatenation followed by a linear projection and normalisation."""
    def __init__(self, dim: int = 2048):
        super().__init__()
        self.fusion_proj = nn.Linear(dim, dim)
        self.norm = RMSNorm(dim)

    def forward(self, segments: List[Dict]) -> torch.Tensor:
        """Concatenate every segment's features along time, then project."""
        fused = torch.cat([segment['data'] for segment in segments], dim=1)
        return self.norm(self.fusion_proj(fused))
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
class LateFusionModule(nn.Module):
    """Late fusion - each modality is processed independently and only the
    pooled outputs are combined ('concat', 'attention' or 'average')."""
    def __init__(
        self,
        dim: int = 2048,
        num_modalities: int = 4,
        fusion_method: str = 'concat'  # 'concat', 'attention', 'average'
    ):
        super().__init__()
        self.fusion_method = fusion_method

        # Only the concat/attention variants need learned parameters.
        if fusion_method == 'concat':
            self.fusion_proj = nn.Linear(dim * num_modalities, dim)
        elif fusion_method == 'attention':
            self.attention_weights = nn.Linear(dim, 1)

        self.norm = RMSNorm(dim)

    def forward(self, modality_outputs: List[torch.Tensor]) -> torch.Tensor:
        """
        Args:
            modality_outputs: list of independently processed outputs [B, T, D]
        Returns:
            [B, D] fused representation
        """
        method = self.fusion_method

        if method == 'concat':
            # Mean-pool each modality over time, concatenate, project down.
            pooled = [feats.mean(dim=1) for feats in modality_outputs]
            fused = self.fusion_proj(torch.cat(pooled, dim=-1))
        elif method == 'attention':
            # Learn a scalar weight per modality; take the weighted sum.
            stacked = torch.stack([feats.mean(dim=1) for feats in modality_outputs], dim=1)
            weights = F.softmax(self.attention_weights(stacked), dim=1)
            fused = (stacked * weights).sum(dim=1)
        else:  # 'average'
            stacked = torch.stack([feats.mean(dim=1) for feats in modality_outputs], dim=1)
            fused = stacked.mean(dim=1)

        return self.norm(fused)
|
peft_.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
参数高效微调 (PEFT) 模块
|
| 3 |
+
支持LoRA和Adapter
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
class LoRALayer(nn.Module):
    """Low-Rank Adaptation (LoRA) delta: scaling * dropout(x @ A @ B).

    A is Kaiming-initialised and B is zero-initialised, so the layer outputs
    exactly zero at the start of training.
    """
    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 8,
        alpha: float = 16.0,
        dropout: float = 0.0
    ):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        # Standard LoRA scaling factor alpha / r.
        self.scaling = alpha / rank

        self.lora_A = nn.Parameter(torch.zeros(in_features, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, out_features))

        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # A gets a Kaiming init; B stays zero => initial delta is 0.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

        self.merged = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scaled low-rank update for input x."""
        delta = self.dropout(x @ self.lora_A @ self.lora_B)
        return delta * self.scaling
|
| 39 |
+
|
| 40 |
+
class LinearWithLoRA(nn.Module):
    """A standard linear layer with an optional additive LoRA branch.

    When LoRA is enabled, the low-rank update can be folded into the base
    weight for inference (`merge`) and removed again (`unmerge`).
    """
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        use_lora: bool = False,
        lora_rank: int = 8,
        lora_alpha: float = 16.0,
        lora_dropout: float = 0.0
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.use_lora = use_lora

        self.base_linear = nn.Linear(in_features, out_features, bias=bias)

        # The LoRA branch exists only when enabled; `merged` tracks whether
        # its update currently lives inside base_linear.weight.
        self.lora = (
            LoRALayer(in_features, out_features, lora_rank, lora_alpha, lora_dropout)
            if use_lora
            else None
        )
        self.merged = False

    def merge(self):
        """Fold the LoRA update into the base weight (no-op if already merged)."""
        if self.use_lora and not self.merged:
            delta = (self.lora.lora_A @ self.lora.lora_B) * self.lora.scaling
            # nn.Linear stores weight as [out_features, in_features],
            # hence the transpose of the [in, out] delta.
            self.base_linear.weight.data += delta.T
            self.merged = True

    def unmerge(self):
        """Remove a previously merged LoRA update from the base weight."""
        if self.use_lora and self.merged:
            delta = (self.lora.lora_A @ self.lora.lora_B) * self.lora.scaling
            self.base_linear.weight.data -= delta.T
            self.merged = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Base linear output plus the LoRA delta (skipped when merged)."""
        out = self.base_linear(x)
        if self.use_lora and self.lora is not None and not self.merged:
            out = out + self.lora(x)
        return out
|
| 94 |
+
|
| 95 |
+
class AdapterLayer(nn.Module):
    """Bottleneck adapter for lightweight fine-tuning.

    norm -> down-project -> activation -> dropout -> up-project -> dropout,
    added back to the input through a scaled residual connection.  The
    up-projection is zero-initialised so the layer starts as identity.
    """
    def __init__(
        self,
        dim: int,
        bottleneck_dim: int = 64,
        dropout: float = 0.1,
        activation: str = 'gelu',
        residual_scale: float = 1.0
    ):
        super().__init__()
        self.residual_scale = residual_scale

        self.down_proj = nn.Linear(dim, bottleneck_dim)

        # Map the activation name to a module; unknown names fall back to GELU.
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'silu':
            self.activation = nn.SiLU()
        else:  # 'gelu' and any unrecognised value
            self.activation = nn.GELU()

        self.up_proj = nn.Linear(bottleneck_dim, dim)
        self.dropout = nn.Dropout(dropout)

        # Imported locally — presumably to avoid a circular import with
        # components.py; TODO confirm.
        from components import RMSNorm
        self.layer_norm = RMSNorm(dim)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-init the down projection; zero-init the up projection so
        the adapter contributes nothing at the start of training."""
        nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
        nn.init.zeros_(self.up_proj.weight)
        if self.down_proj.bias is not None:
            nn.init.zeros_(self.down_proj.bias)
        if self.up_proj.bias is not None:
            nn.init.zeros_(self.up_proj.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the bottleneck and add it to the (scaled) residual path."""
        residual = x
        hidden = self.layer_norm(x)
        hidden = self.down_proj(hidden)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        hidden = self.up_proj(hidden)
        hidden = self.dropout(hidden)
        return residual + hidden * self.residual_scale
|
| 148 |
+
|
| 149 |
+
class PrefixTuning(nn.Module):
    """Prefix Tuning: learned per-layer key/value prefix tokens.

    Stores a parameter of shape
    [num_layers, 2 (key/value), num_tokens, num_heads, head_dim] and serves
    it per layer, broadcast over the batch.
    """
    def __init__(
        self,
        num_layers: int,
        num_tokens: int,
        dim: int,
        num_heads: int
    ):
        super().__init__()
        self.num_layers = num_layers
        self.num_tokens = num_tokens
        self.dim = dim
        self.num_heads = num_heads

        head_dim = dim // num_heads
        # [num_layers, 2, num_tokens, num_heads, head_dim]
        self.prefix = nn.Parameter(
            torch.randn(num_layers, 2, num_tokens, num_heads, head_dim)
        )

        nn.init.normal_(self.prefix, std=0.02)

    def forward(self, layer_idx: int, batch_size: int) -> torch.Tensor:
        """Return the key/value prefix for one layer.

        Returns:
            Tensor of shape [2, batch_size, num_heads, num_tokens, head_dim].
        """
        prefix = self.prefix[layer_idx]          # [2, T, H, Dh]
        # Reorder to [2, H, T, Dh] BEFORE expanding.  The previous code
        # expanded [2, 1, T, H, Dh] straight to [2, B, H, T, Dh], which swaps
        # the heads/tokens axes and raises a runtime error whenever
        # num_tokens != num_heads (expand cannot grow non-size-1 dims).
        prefix = prefix.permute(0, 2, 1, 3)      # [2, H, T, Dh]
        prefix = prefix.unsqueeze(1).expand(
            2, batch_size, self.num_heads, self.num_tokens, -1
        )
        return prefix
|
| 179 |
+
|
| 180 |
+
class PromptTuning(nn.Module):
    """Prompt Tuning: learnable soft-prompt embeddings prepended to the input."""
    def __init__(
        self,
        num_tokens: int,
        dim: int,
        init_from_vocab: bool = False,
        vocab_embeddings: nn.Embedding = None
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.dim = dim

        self.prompt_embeddings = nn.Parameter(torch.randn(num_tokens, dim))

        if init_from_vocab and vocab_embeddings is not None:
            # Initialise each soft token from a randomly sampled vocab embedding.
            picks = torch.randint(0, vocab_embeddings.num_embeddings, (num_tokens,))
            self.prompt_embeddings.data = vocab_embeddings.weight[picks].clone()
        else:
            nn.init.normal_(self.prompt_embeddings, std=0.02)

    def forward(self, batch_size: int) -> torch.Tensor:
        """Return the prompt embeddings broadcast to [batch_size, num_tokens, dim]."""
        return self.prompt_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
|
| 204 |
+
|
| 205 |
+
class IALayer(nn.Module):
    """(IA)^3 layer: a learned per-feature multiplicative rescaling,
    initialised to ones (identity)."""
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Rescale the last feature dimension element-wise."""
        return self.scale * x
|
post.py
ADDED
|
@@ -0,0 +1,532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# posttrain.py
|
| 2 |
+
"""
|
| 3 |
+
后训练脚本 - Instruction tuning和对齐
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from transformers import AutoTokenizer
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import logging
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
import json
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
import copy
|
| 15 |
+
from model import MultiModalDenseTransformer
|
| 16 |
+
|
| 17 |
+
from data_loader import (
|
| 18 |
+
create_posttrain_dataloader,
|
| 19 |
+
create_preference_dataloader
|
| 20 |
+
)
|
| 21 |
+
from data_config import POSTTRAIN_MIX
|
| 22 |
+
from reward_model import RewardModel, RewardModelTrainer
|
| 23 |
+
from grpo import GRPOTrainer
|
| 24 |
+
from typing import Optional
|
| 25 |
+
|
| 26 |
+
logging.basicConfig(
|
| 27 |
+
level=logging.INFO,
|
| 28 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 29 |
+
)
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
| 32 |
+
|
| 33 |
+
class PostTrainer:
|
| 34 |
+
"""后训练器 - Supervised Fine-Tuning"""
|
| 35 |
+
    def __init__(
        self,
        model: MultiModalDenseTransformer,
        tokenizer,
        learning_rate: float = 1e-5,
        weight_decay: float = 0.01,
        num_epochs: int = 3,
        gradient_accumulation_steps: int = 1,
        max_grad_norm: float = 1.0,
        log_interval: int = 10,
        eval_interval: int = 500,
        save_interval: int = 1000,
        checkpoint_dir: str = "checkpoints/posttrain"
    ):
        """Set up the optimizer, AMP scaler, hyper-parameters and checkpoint
        directory for supervised fine-tuning.

        Args:
            model: the multimodal transformer to fine-tune (moved to device).
            tokenizer: tokenizer; its pad_token_id is used during evaluation.
            learning_rate: AdamW learning rate.
            weight_decay: AdamW weight decay.
            num_epochs: number of passes over the training data.
            gradient_accumulation_steps: micro-batches per optimizer update.
            max_grad_norm: gradient clipping threshold.
            log_interval: steps between training-loss log lines.
            eval_interval: steps between evaluations.
            save_interval: steps between checkpoints.
            checkpoint_dir: directory where checkpoints are written
                (created if missing).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model.to(self.device)

        # Optimizer (AdamW with LLM-style betas).
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.95),
            eps=1e-8
        )

        # Mixed precision: only enabled when CUDA is available.
        self.use_amp = torch.cuda.is_available()
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)

        # Training hyper-parameters.
        self.num_epochs = num_epochs
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.log_interval = log_interval
        self.eval_interval = eval_interval
        self.save_interval = save_interval

        # Checkpoint management.
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Training state.
        self.global_step = 0
        self.best_eval_loss = float('inf')

        logger.info(f"PostTrainer initialized:")
        logger.info(f"  Device: {self.device}")
        logger.info(f"  Learning Rate: {learning_rate}")
        logger.info(f"  Num Epochs: {num_epochs}")
        logger.info(f"  Gradient Accumulation: {gradient_accumulation_steps}")
| 89 |
+
|
| 90 |
+
def train_step(self, batch: dict) -> dict:
|
| 91 |
+
"""单步训练"""
|
| 92 |
+
instruction_ids = batch['instruction'].to(self.device)
|
| 93 |
+
response_ids = batch['response'].to(self.device)
|
| 94 |
+
|
| 95 |
+
# 获取 DataLoader 返回的掩码
|
| 96 |
+
instruction_mask = batch['instruction_mask'].to(self.device)
|
| 97 |
+
response_mask = batch['response_mask'].to(self.device)
|
| 98 |
+
# 拼接输入
|
| 99 |
+
input_ids = torch.cat([instruction_ids, response_ids], dim=1)
|
| 100 |
+
attention_mask = torch.cat([instruction_mask, response_mask], dim=1).float()
|
| 101 |
+
# 创建标签(只计算response部分的损失)
|
| 102 |
+
labels = input_ids.clone()
|
| 103 |
+
instr_len = instruction_ids.shape[1]
|
| 104 |
+
labels[:, :instr_len] = -100
|
| 105 |
+
labels[attention_mask == 0] = -100
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# 准备输入数据
|
| 109 |
+
input_data = {
|
| 110 |
+
'segments': [{
|
| 111 |
+
'type': 'text',
|
| 112 |
+
'data': input_ids,
|
| 113 |
+
'modality_id': 0
|
| 114 |
+
}]
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
# 前向传播
|
| 118 |
+
with torch.amp.autocast('cuda', enabled=self.use_amp):
|
| 119 |
+
outputs = self.model(input_data,attention_mask=attention_mask)
|
| 120 |
+
logits = outputs['logits']
|
| 121 |
+
|
| 122 |
+
# 计算损失
|
| 123 |
+
shift_logits = logits[:, :-1, :].contiguous()
|
| 124 |
+
shift_labels = labels[:, 1:].contiguous()
|
| 125 |
+
|
| 126 |
+
loss = F.cross_entropy(
|
| 127 |
+
shift_logits.view(-1, shift_logits.size(-1)),
|
| 128 |
+
shift_labels.view(-1),
|
| 129 |
+
ignore_index=-100
|
| 130 |
+
)
|
| 131 |
+
raw_loss = loss.item()
|
| 132 |
+
loss = loss / self.gradient_accumulation_steps
|
| 133 |
+
|
| 134 |
+
# 反向传播
|
| 135 |
+
self.scaler.scale(loss).backward()
|
| 136 |
+
|
| 137 |
+
return {
|
| 138 |
+
'loss': raw_loss
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
def optimizer_step(self):
|
| 142 |
+
"""优化器步骤"""
|
| 143 |
+
self.scaler.unscale_(self.optimizer)
|
| 144 |
+
grad_norm = torch.nn.utils.clip_grad_norm_(
|
| 145 |
+
self.model.parameters(),
|
| 146 |
+
self.max_grad_norm
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
self.scaler.step(self.optimizer)
|
| 150 |
+
self.scaler.update()
|
| 151 |
+
self.optimizer.zero_grad(set_to_none=True)
|
| 152 |
+
self.global_step += 1
|
| 153 |
+
return grad_norm.item()
|
| 154 |
+
|
| 155 |
+
@torch.no_grad()
|
| 156 |
+
def evaluate(self, dataloader, max_batches: int = 50) -> float:
|
| 157 |
+
"""评估"""
|
| 158 |
+
self.model.eval()
|
| 159 |
+
total_loss = 0.0
|
| 160 |
+
num_batches = 0
|
| 161 |
+
|
| 162 |
+
for i, batch in enumerate(dataloader):
|
| 163 |
+
if i >= max_batches:
|
| 164 |
+
break
|
| 165 |
+
|
| 166 |
+
if batch is None:
|
| 167 |
+
continue
|
| 168 |
+
|
| 169 |
+
instruction_ids = batch['instruction'].to(self.device)
|
| 170 |
+
response_ids = batch['response'].to(self.device)
|
| 171 |
+
input_ids = torch.cat([instruction_ids, response_ids], dim=1)
|
| 172 |
+
|
| 173 |
+
labels = input_ids.clone()
|
| 174 |
+
labels[:, :instruction_ids.shape[1]] = -100
|
| 175 |
+
labels[input_ids == self.tokenizer.pad_token_id] = -100
|
| 176 |
+
|
| 177 |
+
input_data = {
|
| 178 |
+
'segments': [{
|
| 179 |
+
'type': 'text',
|
| 180 |
+
'data': input_ids,
|
| 181 |
+
'modality_id': 0
|
| 182 |
+
}]
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
with torch.amp.autocast('cuda', enabled=self.use_amp):
|
| 186 |
+
outputs = self.model(input_data)
|
| 187 |
+
logits = outputs['logits']
|
| 188 |
+
|
| 189 |
+
shift_logits = logits[:, :-1, :].contiguous()
|
| 190 |
+
shift_labels = labels[:, 1:].contiguous()
|
| 191 |
+
|
| 192 |
+
loss = F.cross_entropy(
|
| 193 |
+
shift_logits.view(-1, shift_logits.size(-1)),
|
| 194 |
+
shift_labels.view(-1),
|
| 195 |
+
ignore_index=-100
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
total_loss += loss.item()
|
| 199 |
+
num_batches += 1
|
| 200 |
+
|
| 201 |
+
self.model.train()
|
| 202 |
+
return total_loss / max(num_batches, 1)
|
| 203 |
+
|
| 204 |
+
    def train(
        self,
        train_dataloader,
        eval_dataloader=None,
        resume_from: Optional[str] = None
    ):
        """Main SFT training loop.

        Runs ``num_epochs`` passes over ``train_dataloader``, accumulating
        gradients over ``gradient_accumulation_steps`` micro-batches before
        each optimizer step, with periodic logging, evaluation, and
        checkpointing driven by ``global_step``.

        Args:
            train_dataloader: Iterable of training batches (None entries skipped).
            eval_dataloader: Optional loader for periodic evaluation; also used
                for a final evaluation at the end of each epoch.
            resume_from: Optional checkpoint path to restore state from first.
        """
        logger.info("\n" + "="*80)
        logger.info("Starting Post-Training (SFT)")
        logger.info("="*80 + "\n")

        if resume_from:
            self.load_checkpoint(resume_from)

        self.model.train()

        for epoch in range(self.num_epochs):
            logger.info(f"\nEpoch {epoch+1}/{self.num_epochs}")

            progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
            running_loss = 0.0
            step_in_accumulation = 0

            for batch_idx, batch in enumerate(progress_bar):
                if batch is None:
                    continue

                # Training step (forward + backward; no optimizer update yet)
                stats = self.train_step(batch)
                running_loss += stats['loss']
                step_in_accumulation += 1

                # Optimizer update once a full accumulation window is done
                if step_in_accumulation == self.gradient_accumulation_steps:
                    grad_norm = self.optimizer_step()
                    step_in_accumulation = 0

                # Update progress bar with the latest micro-batch loss
                progress_bar.set_postfix({'loss': f"{stats['loss']:.4f}"})

                # Logging.
                # NOTE(review): global_step only advances on optimizer steps,
                # so this condition can fire on several consecutive
                # micro-batches of the same window; the divisor is
                # log_interval rather than the number of micro-batches
                # accumulated since the last log, so the reported average is
                # approximate under gradient accumulation.
                if self.global_step % self.log_interval == 0:
                    avg_loss = running_loss / self.log_interval
                    logger.info(
                        f"Step {self.global_step} | "
                        f"Epoch {epoch+1} | "
                        f"Loss: {avg_loss:.4f}"
                    )
                    running_loss = 0.0

                # Periodic evaluation; tracks and checkpoints the best model
                if eval_dataloader and self.global_step % self.eval_interval == 0:
                    eval_loss = self.evaluate(eval_dataloader)
                    logger.info(f"Eval Loss: {eval_loss:.4f}")

                    if eval_loss < self.best_eval_loss:
                        self.best_eval_loss = eval_loss
                        self.save_checkpoint(
                            self.checkpoint_dir / "best_model.pt",
                            is_best=True
                        )

                # Periodic checkpoint
                if self.global_step % self.save_interval == 0:
                    self.save_checkpoint(
                        self.checkpoint_dir / f"step_{self.global_step}.pt"
                    )

            # NOTE(review): a partial accumulation window left over at epoch
            # end is never flushed with an optimizer step; its gradients carry
            # into the next epoch's first window.

            # End-of-epoch evaluation
            if eval_dataloader:
                eval_loss = self.evaluate(eval_dataloader)
                logger.info(f"\nEpoch {epoch+1} Eval Loss: {eval_loss:.4f}")

        logger.info("\n" + "="*80)
        logger.info("Post-Training Complete!")
        logger.info(f" Best Eval Loss: {self.best_eval_loss:.4f}")
        logger.info("="*80 + "\n")

        self.save_checkpoint(self.checkpoint_dir / "final_model.pt")
|
| 283 |
+
|
| 284 |
+
def save_checkpoint(self, path: Path, is_best: bool = False):
|
| 285 |
+
"""保存checkpoint"""
|
| 286 |
+
checkpoint = {
|
| 287 |
+
'model_state_dict': self.model.state_dict(),
|
| 288 |
+
'optimizer_state_dict': self.optimizer.state_dict(),
|
| 289 |
+
'scaler_state_dict': self.scaler.state_dict() if self.use_amp else None,
|
| 290 |
+
'global_step': self.global_step,
|
| 291 |
+
'best_eval_loss': self.best_eval_loss,
|
| 292 |
+
'timestamp': datetime.now().isoformat()
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
torch.save(checkpoint, path)
|
| 296 |
+
logger.info(f"Checkpoint saved to {path}" + (" (BEST)" if is_best else ""))
|
| 297 |
+
|
| 298 |
+
def load_checkpoint(self, path: str):
|
| 299 |
+
"""加载checkpoint"""
|
| 300 |
+
checkpoint = torch.load(path, map_location=self.device)
|
| 301 |
+
|
| 302 |
+
self.model.load_state_dict(checkpoint['model_state_dict'])
|
| 303 |
+
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 304 |
+
|
| 305 |
+
if self.use_amp and checkpoint.get('scaler_state_dict'):
|
| 306 |
+
self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
|
| 307 |
+
|
| 308 |
+
self.global_step = checkpoint['global_step']
|
| 309 |
+
self.best_eval_loss = checkpoint['best_eval_loss']
|
| 310 |
+
|
| 311 |
+
logger.info(f"Checkpoint loaded from {path}")
|
| 312 |
+
|
| 313 |
+
def main():
    """Entry point: SFT fine-tuning followed by optional RLHF (GRPO).

    Phase 1 always runs supervised fine-tuning on the configured data mix.
    Phase 2 (reward-model training + GRPO) runs only when config['do_rlhf']
    is True and is wrapped in a try/except so RLHF failures do not abort the
    already-completed SFT run.
    """
    # Debug-sized configuration (small data mix, single epoch).
    config = {
        # Model configuration
        'model_dim': 1536,
        'vocab_size': 151665,  # overwritten below from the tokenizer
        'n_layers': 12,
        'n_heads': 12,
        'n_kv_heads': 4,
        'max_seq_len': 512,
        'dropout': 0.0,
        'use_moe': False,
        # Training configuration
        'batch_size': 2,
        'gradient_accumulation_steps': 8,
        'learning_rate': 1e-4,
        'weight_decay': 0.01,
        'num_epochs': 1,
        'max_grad_norm': 1.0,

        # Data configuration
        'data_mix': 'debug_mix',
        'max_samples_train': 1000,
        'max_samples_eval': 1000,
        'max_length': 512,
        'num_workers': 4,

        # RLHF configuration
        'do_rlhf': False,
        'preference_dataset': 'hh_rlhf',
        'grpo_iterations': 3,
        'grpo_kl_coef': 0.04,
        'grpo_group_size': 4,

        # Paths and intervals
        'pretrain_checkpoint': '/root/multimodal/checkpoints/pretrain_fixed/step_10000.pt',
        'checkpoint_dir': 'checkpoints/posttrain',
        'log_interval': 50,
        'eval_interval': 500,
        'save_interval': 1000,
    }

    logger.info("Configuration:")
    logger.info(json.dumps(config, indent=2))

    # Initialize tokenizer
    logger.info("\nInitializing tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct",
        use_fast=True,
        trust_remote_code=True
    )

    # Qwen tokenizers may lack a pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Keep vocab size in sync with the actual tokenizer.
    config['vocab_size'] = len(tokenizer)

    # Initialize model
    logger.info("\nInitializing model...")
    model = MultiModalDenseTransformer(
        model_dim=config['model_dim'],
        vocab_size=config['vocab_size'],
        n_layers=config['n_layers'],
        n_heads=config['n_heads'],
        n_kv_heads=config['n_kv_heads'],
        max_seq_len=config['max_seq_len'],
        dropout=config['dropout'],
        use_moe=config['use_moe'],
        use_gradient_checkpointing=False,
        rope_scaling_type="yarn",
        use_multimodal_fusion=False,
        use_contrastive=False
    )

    # Load pretrain checkpoint (if configured).
    # map_location='cpu' so loading works on CPU-only hosts and regardless of
    # which device the checkpoint was saved on; the trainer moves the model
    # to the target device afterwards.
    if config['pretrain_checkpoint']:
        logger.info(f"Loading pretrain checkpoint: {config['pretrain_checkpoint']}")
        checkpoint = torch.load(config['pretrain_checkpoint'], map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])

    # ===== Phase 1: Supervised Fine-Tuning =====
    logger.info("\n" + "="*80)
    logger.info("PHASE 1: Supervised Fine-Tuning")
    logger.info("="*80)

    # Data loaders
    train_dataloader = create_posttrain_dataloader(
        mix_name=config['data_mix'],
        tokenizer=tokenizer,
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        max_length=config['max_length'],
        max_samples=config['max_samples_train'],
        split='train',
        shuffle=True
    )

    eval_dataloader = create_posttrain_dataloader(
        mix_name=config['data_mix'],
        tokenizer=tokenizer,
        batch_size=config['batch_size'] * 2,
        num_workers=config['num_workers'],
        max_length=config['max_length'],
        max_samples=config['max_samples_eval'],
        split='train',  # tail of the train split doubles as validation
        shuffle=False
    )

    # Trainer
    trainer = PostTrainer(
        model=model,
        tokenizer=tokenizer,
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        num_epochs=config['num_epochs'],
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        max_grad_norm=config['max_grad_norm'],
        log_interval=config['log_interval'],
        eval_interval=config['eval_interval'],
        save_interval=config['save_interval'],
        checkpoint_dir=config['checkpoint_dir']
    )

    # Run SFT
    trainer.train(train_dataloader, eval_dataloader)

    # ===== Phase 2: RLHF with GRPO =====
    if config['do_rlhf']:
        logger.info("\n" + "="*80)
        logger.info("PHASE 2: RLHF with GRPO")
        logger.info("="*80)

        try:
            # Train the reward model on a copy of the SFT model
            logger.info("\nTraining Reward Model...")

            reward_base_model = copy.deepcopy(model)
            reward_model = RewardModel(reward_base_model, use_value_head=True)

            preference_dataloader = create_preference_dataloader(
                dataset_name=config['preference_dataset'],
                tokenizer=tokenizer,
                batch_size=config['batch_size'],
                num_workers=config['num_workers'],
                max_samples=5000,
                split='train'
            )

            reward_trainer = RewardModelTrainer(
                reward_model=reward_model,
                learning_rate=1e-5
            )

            reward_trainer.train(preference_dataloader, num_epochs=1)

            # GRPO training against a frozen reference copy
            logger.info("\nStarting GRPO Training...")

            ref_model = copy.deepcopy(model)
            ref_model.eval()

            grpo_trainer = GRPOTrainer(
                actor_model=model,
                reward_model=reward_model,
                ref_model=ref_model,
                tokenizer=tokenizer,
                learning_rate=1e-6,
                kl_coef=config['grpo_kl_coef'],
                group_size=config['grpo_group_size'],
                update_batch_size=2,
                use_amp=True
            )

            # Collect up to 200 prompt batches for GRPO rollouts
            prompt_dataloader = create_posttrain_dataloader(
                mix_name=config['data_mix'],
                tokenizer=tokenizer,
                batch_size=4,
                num_workers=2,
                max_samples=1000,
                split='train'
            )

            prompts = []
            for batch in prompt_dataloader:
                if batch and batch.get('instruction') is not None:
                    prompts.append(batch['instruction'])
                if len(prompts) >= 200:
                    break

            if prompts:
                prompt_tensor = torch.cat(prompts[:200], dim=0)
                from torch.utils.data import TensorDataset, DataLoader
                prompt_loader = DataLoader(
                    TensorDataset(prompt_tensor),
                    batch_size=4
                )

                grpo_trainer.train(
                    prompt_loader,
                    num_iterations=config['grpo_iterations'],
                    max_gen_len=50,
                    save_path=config['checkpoint_dir'] + "/grpo"
                )

        except Exception as e:
            # RLHF is best-effort: log the failure, keep the SFT result.
            logger.error(f"Error in RLHF: {e}")
            import traceback
            traceback.print_exc()

    logger.info("\n" + "="*80)
    logger.info("All Training Complete!")
    logger.info("="*80)

if __name__ == "__main__":
    main()
|
posttrain.py
ADDED
|
@@ -0,0 +1,554 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# posttrain.py
|
| 2 |
+
"""
|
| 3 |
+
后训练脚本 - Instruction tuning和对齐
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from transformers import AutoTokenizer
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import logging
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
import json
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
import copy
|
| 15 |
+
from model import MultiModalDenseTransformer
|
| 16 |
+
|
| 17 |
+
from data_loader import (
|
| 18 |
+
create_posttrain_dataloader,
|
| 19 |
+
create_preference_dataloader
|
| 20 |
+
)
|
| 21 |
+
from data_config import POSTTRAIN_MIX
|
| 22 |
+
from reward_model import RewardModel, RewardModelTrainer
|
| 23 |
+
from grpo import GRPOTrainer
|
| 24 |
+
from typing import Optional
|
| 25 |
+
|
| 26 |
+
logging.basicConfig(
|
| 27 |
+
level=logging.INFO,
|
| 28 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 29 |
+
)
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
| 32 |
+
|
| 33 |
+
class PostTrainer:
    """Supervised fine-tuning (SFT) trainer for MultiModalDenseTransformer.

    Wraps the optimizer, AMP scaler, gradient accumulation, periodic
    evaluation, and checkpointing for instruction-tuning runs. The loss is
    computed only over response tokens; instruction and padding positions are
    masked out with ignore_index=-100.
    """

    def __init__(
        self,
        model: MultiModalDenseTransformer,
        tokenizer,
        learning_rate: float = 1e-5,
        weight_decay: float = 0.01,
        num_epochs: int = 3,
        gradient_accumulation_steps: int = 1,
        max_grad_norm: float = 1.0,
        log_interval: int = 10,
        eval_interval: int = 500,
        save_interval: int = 1000,
        checkpoint_dir: str = "checkpoints/posttrain"
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model.to(self.device)

        # Optimizer (betas/eps follow common LLM fine-tuning settings)
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.95),
            eps=1e-8
        )

        # Mixed precision only when CUDA is available
        self.use_amp = torch.cuda.is_available()
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)

        # Training hyper-parameters
        self.num_epochs = num_epochs
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.log_interval = log_interval
        self.eval_interval = eval_interval
        self.save_interval = save_interval

        # Checkpoint management
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Training state: global_step counts optimizer steps, not micro-batches
        self.global_step = 0
        self.best_eval_loss = float('inf')

        logger.info("PostTrainer initialized:")
        logger.info(f" Device: {self.device}")
        logger.info(f" Learning Rate: {learning_rate}")
        logger.info(f" Num Epochs: {num_epochs}")
        logger.info(f" Gradient Accumulation: {gradient_accumulation_steps}")

    def _prepare_inputs(self, batch: dict):
        """Build input_ids, attention_mask, position_ids, and labels for a batch.

        Instruction and response sequences (and their masks) are concatenated;
        position ids are consecutive over non-pad tokens and 0 at padding;
        labels mask out the instruction prefix and all padded positions.
        """
        instruction_ids = batch['instruction'].to(self.device)
        response_ids = batch['response'].to(self.device)
        instruction_mask = batch['instruction_mask'].to(self.device)
        response_mask = batch['response_mask'].to(self.device)

        input_ids = torch.cat([instruction_ids, response_ids], dim=1)
        attention_mask = torch.cat([instruction_mask, response_mask], dim=1)

        # Vectorized replacement for the original per-sample Python loop:
        # cumsum over the mask yields 1..k on real tokens; subtracting 1 and
        # multiplying by the mask gives 0-based positions with 0 at padding.
        mask_long = attention_mask.long()
        position_ids = (mask_long.cumsum(dim=1) - 1) * mask_long

        # Supervise only the response tokens
        labels = input_ids.clone()
        labels[:, :instruction_ids.shape[1]] = -100
        labels[attention_mask == 0] = -100

        return input_ids, attention_mask, position_ids, labels

    def _lm_loss(self, logits, labels):
        """Shifted next-token cross-entropy with ignore_index=-100."""
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        return F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100
        )

    def train_step(self, batch: dict) -> dict:
        """Forward + backward on one micro-batch (no optimizer update).

        The loss is scaled by 1/gradient_accumulation_steps before backward()
        so accumulated gradients average over the window.

        Returns:
            dict with the unscaled scalar 'loss' for logging.
        """
        input_ids, attention_mask, position_ids, labels = self._prepare_inputs(batch)

        input_data = {
            'segments': [{
                'type': 'text',
                'data': input_ids,
                'modality_id': 0
            }]
        }

        with torch.amp.autocast('cuda', enabled=self.use_amp):
            # attention_mask is required so the transformer ignores padding
            outputs = self.model(input_data, attention_mask=attention_mask,
                                 position_ids=position_ids)
            logits = outputs['logits']
            loss = self._lm_loss(logits, labels)

        raw_loss = loss.item()
        loss = loss / self.gradient_accumulation_steps

        self.scaler.scale(loss).backward()

        return {
            'loss': raw_loss
        }

    def optimizer_step(self):
        """Unscale, clip, step the optimizer, and advance global_step.

        Returns:
            The pre-clip gradient norm as a float.
        """
        self.scaler.unscale_(self.optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.max_grad_norm
        )

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)
        self.global_step += 1
        return grad_norm.item()

    @torch.no_grad()
    def evaluate(self, dataloader, max_batches: int = 50) -> float:
        """Average next-token loss over up to ``max_batches`` batches.

        Fix: uses the same attention_mask / position_ids construction as
        train_step (the original forwarded without a mask), so the eval loss
        is directly comparable to the training loss under padding.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        for i, batch in enumerate(dataloader):
            if i >= max_batches:
                break

            if batch is None:
                continue

            input_ids, attention_mask, position_ids, labels = self._prepare_inputs(batch)

            input_data = {
                'segments': [{
                    'type': 'text',
                    'data': input_ids,
                    'modality_id': 0
                }]
            }

            with torch.amp.autocast('cuda', enabled=self.use_amp):
                outputs = self.model(input_data, attention_mask=attention_mask,
                                     position_ids=position_ids)
                loss = self._lm_loss(outputs['logits'], labels)

            total_loss += loss.item()
            num_batches += 1

        self.model.train()
        return total_loss / max(num_batches, 1)

    def train(
        self,
        train_dataloader,
        eval_dataloader=None,
        resume_from: Optional[str] = None
    ):
        """Main SFT training loop.

        Fixes over the original:
        - Logging/eval/save checks run only right after an optimizer step
          (previously they could fire repeatedly on every micro-batch of the
          same accumulation window, since global_step does not advance there).
        - The logged average divides by the number of micro-batches actually
          accumulated since the last log, not by log_interval.
        - A partial accumulation window left at epoch end is flushed with an
          optimizer step instead of leaking gradients into the next epoch.
        """
        logger.info("\n" + "="*80)
        logger.info("Starting Post-Training (SFT)")
        logger.info("="*80 + "\n")

        if resume_from:
            self.load_checkpoint(resume_from)

        self.model.train()

        for epoch in range(self.num_epochs):
            logger.info(f"\nEpoch {epoch+1}/{self.num_epochs}")

            progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
            running_loss = 0.0
            batches_since_log = 0
            step_in_accumulation = 0

            for batch_idx, batch in enumerate(progress_bar):
                if batch is None:
                    continue

                # Micro-batch forward/backward
                stats = self.train_step(batch)
                running_loss += stats['loss']
                batches_since_log += 1
                step_in_accumulation += 1

                progress_bar.set_postfix({'loss': f"{stats['loss']:.4f}"})

                if step_in_accumulation < self.gradient_accumulation_steps:
                    continue  # keep accumulating gradients

                grad_norm = self.optimizer_step()
                step_in_accumulation = 0

                # Logging (once per matching optimizer step)
                if self.global_step % self.log_interval == 0:
                    avg_loss = running_loss / max(batches_since_log, 1)
                    logger.info(
                        f"Step {self.global_step} | "
                        f"Epoch {epoch+1} | "
                        f"Loss: {avg_loss:.4f}"
                    )
                    running_loss = 0.0
                    batches_since_log = 0

                # Periodic evaluation + best-model checkpoint
                if eval_dataloader and self.global_step % self.eval_interval == 0:
                    eval_loss = self.evaluate(eval_dataloader)
                    logger.info(f"Eval Loss: {eval_loss:.4f}")

                    if eval_loss < self.best_eval_loss:
                        self.best_eval_loss = eval_loss
                        self.save_checkpoint(
                            self.checkpoint_dir / "best_model.pt",
                            is_best=True
                        )

                # Periodic checkpoint
                if self.global_step % self.save_interval == 0:
                    self.save_checkpoint(
                        self.checkpoint_dir / f"step_{self.global_step}.pt"
                    )

            # Flush a leftover partial accumulation window at epoch end
            if step_in_accumulation > 0:
                self.optimizer_step()

            # End-of-epoch evaluation
            if eval_dataloader:
                eval_loss = self.evaluate(eval_dataloader)
                logger.info(f"\nEpoch {epoch+1} Eval Loss: {eval_loss:.4f}")

        logger.info("\n" + "="*80)
        logger.info("Post-Training Complete!")
        logger.info(f" Best Eval Loss: {self.best_eval_loss:.4f}")
        logger.info("="*80 + "\n")

        self.save_checkpoint(self.checkpoint_dir / "final_model.pt")

    def save_checkpoint(self, path: Path, is_best: bool = False):
        """Serialize model/optimizer/scaler state and progress counters to *path*."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict() if self.use_amp else None,
            'global_step': self.global_step,
            'best_eval_loss': self.best_eval_loss,
            'timestamp': datetime.now().isoformat()
        }

        torch.save(checkpoint, path)
        logger.info(f"Checkpoint saved to {path}" + (" (BEST)" if is_best else ""))

    def load_checkpoint(self, path: str):
        """Restore training state from *path* (mapped onto this trainer's device)."""
        checkpoint = torch.load(path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if self.use_amp and checkpoint.get('scaler_state_dict'):
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        self.global_step = checkpoint['global_step']
        self.best_eval_loss = checkpoint['best_eval_loss']

        logger.info(f"Checkpoint loaded from {path}")
|
| 334 |
+
|
| 335 |
+
def main():
    """Entry point: SFT fine-tuning followed by optional RLHF (GRPO).

    Phase 1 always runs supervised fine-tuning on the configured data mix.
    Phase 2 (reward-model training + GRPO) runs only when config['do_rlhf']
    is True and is wrapped in a try/except so RLHF failures do not abort the
    already-completed SFT run.
    """
    config = {
        # Model configuration
        'model_dim': 1536,
        'vocab_size': 151665,  # overwritten below from the tokenizer
        'n_layers': 12,
        'n_heads': 12,
        'n_kv_heads': 4,
        'max_seq_len': 512,
        'dropout': 0.0,
        'use_moe': False,
        # Training configuration
        'batch_size': 2,
        'gradient_accumulation_steps': 8,
        'learning_rate': 1e-5,
        'weight_decay': 0.01,
        'num_epochs': 3,
        'max_grad_norm': 1.0,

        # Data configuration
        'data_mix': 'simple_instruct',
        'max_samples_train': 20000,
        'max_samples_eval': 1000,
        'max_length': 512,
        'num_workers': 4,

        # RLHF configuration
        'do_rlhf': False,
        'preference_dataset': 'hh_rlhf',
        'grpo_iterations': 3,
        'grpo_kl_coef': 0.04,
        'grpo_group_size': 4,

        # Paths and intervals
        'pretrain_checkpoint': '/root/multimodal/checkpoints/pretrain_fixed/step_10000.pt',
        'checkpoint_dir': 'checkpoints/posttrain',
        'log_interval': 50,
        'eval_interval': 500,
        'save_interval': 1000,
    }

    logger.info("Configuration:")
    logger.info(json.dumps(config, indent=2))

    # Initialize tokenizer
    logger.info("\nInitializing tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct",
        use_fast=True,
        trust_remote_code=True
    )

    # Qwen tokenizers may lack a pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Keep vocab size in sync with the actual tokenizer.
    config['vocab_size'] = len(tokenizer)

    # Initialize model
    logger.info("\nInitializing model...")
    model = MultiModalDenseTransformer(
        model_dim=config['model_dim'],
        vocab_size=config['vocab_size'],
        n_layers=config['n_layers'],
        n_heads=config['n_heads'],
        n_kv_heads=config['n_kv_heads'],
        max_seq_len=config['max_seq_len'],
        dropout=config['dropout'],
        use_moe=config['use_moe'],
        use_gradient_checkpointing=False,
        rope_scaling_type="yarn",
        use_multimodal_fusion=False,
        use_contrastive=False
    )

    # Load pretrain checkpoint (if configured).
    # map_location='cpu' so loading works on CPU-only hosts and regardless of
    # which device the checkpoint was saved on; the trainer moves the model
    # to the target device afterwards.
    if config['pretrain_checkpoint']:
        logger.info(f"Loading pretrain checkpoint: {config['pretrain_checkpoint']}")
        checkpoint = torch.load(config['pretrain_checkpoint'], map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])

    # ===== Phase 1: Supervised Fine-Tuning =====
    logger.info("\n" + "="*80)
    logger.info("PHASE 1: Supervised Fine-Tuning")
    logger.info("="*80)

    # Data loaders
    train_dataloader = create_posttrain_dataloader(
        mix_name=config['data_mix'],
        tokenizer=tokenizer,
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        max_length=config['max_length'],
        max_samples=config['max_samples_train'],
        split='train',
        shuffle=True
    )

    eval_dataloader = create_posttrain_dataloader(
        mix_name=config['data_mix'],
        tokenizer=tokenizer,
        batch_size=config['batch_size'] * 2,
        num_workers=config['num_workers'],
        max_length=config['max_length'],
        max_samples=config['max_samples_eval'],
        split='train',  # tail of the train split doubles as validation
        shuffle=False
    )

    # Trainer
    trainer = PostTrainer(
        model=model,
        tokenizer=tokenizer,
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        num_epochs=config['num_epochs'],
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        max_grad_norm=config['max_grad_norm'],
        log_interval=config['log_interval'],
        eval_interval=config['eval_interval'],
        save_interval=config['save_interval'],
        checkpoint_dir=config['checkpoint_dir']
    )

    # Run SFT
    trainer.train(train_dataloader, eval_dataloader)

    # ===== Phase 2: RLHF with GRPO =====
    if config['do_rlhf']:
        logger.info("\n" + "="*80)
        logger.info("PHASE 2: RLHF with GRPO")
        logger.info("="*80)

        try:
            # Train the reward model on a copy of the SFT model
            logger.info("\nTraining Reward Model...")

            reward_base_model = copy.deepcopy(model)
            reward_model = RewardModel(reward_base_model, use_value_head=True)

            preference_dataloader = create_preference_dataloader(
                dataset_name=config['preference_dataset'],
                tokenizer=tokenizer,
                batch_size=config['batch_size'],
                num_workers=config['num_workers'],
                max_samples=5000,
                split='train'
            )

            reward_trainer = RewardModelTrainer(
                reward_model=reward_model,
                learning_rate=1e-5
            )

            reward_trainer.train(preference_dataloader, num_epochs=1)

            # GRPO training against a frozen reference copy
            logger.info("\nStarting GRPO Training...")

            ref_model = copy.deepcopy(model)
            ref_model.eval()

            grpo_trainer = GRPOTrainer(
                actor_model=model,
                reward_model=reward_model,
                ref_model=ref_model,
                tokenizer=tokenizer,
                learning_rate=1e-6,
                kl_coef=config['grpo_kl_coef'],
                group_size=config['grpo_group_size'],
                update_batch_size=2,
                use_amp=True
            )

            # Collect up to 200 prompt batches for GRPO rollouts
            prompt_dataloader = create_posttrain_dataloader(
                mix_name=config['data_mix'],
                tokenizer=tokenizer,
                batch_size=4,
                num_workers=2,
                max_samples=1000,
                split='train'
            )

            prompts = []
            for batch in prompt_dataloader:
                if batch and batch.get('instruction') is not None:
                    prompts.append(batch['instruction'])
                if len(prompts) >= 200:
                    break

            if prompts:
                prompt_tensor = torch.cat(prompts[:200], dim=0)
                from torch.utils.data import TensorDataset, DataLoader
                prompt_loader = DataLoader(
                    TensorDataset(prompt_tensor),
                    batch_size=4
                )

                grpo_trainer.train(
                    prompt_loader,
                    num_iterations=config['grpo_iterations'],
                    max_gen_len=50,
                    save_path=config['checkpoint_dir'] + "/grpo"
                )

        except Exception as e:
            # RLHF is best-effort: log the failure, keep the SFT result.
            logger.error(f"Error in RLHF: {e}")
            import traceback
            traceback.print_exc()

    logger.info("\n" + "="*80)
    logger.info("All Training Complete!")
    logger.info("="*80)

if __name__ == "__main__":
    main()
|
pretrain.py
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pretrain.py - 完全修复版本
|
| 2 |
+
|
| 3 |
+
import json
import logging
import math
import os
from datetime import datetime
from pathlib import Path

import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer

from model import MultiModalDenseTransformer
from data_loader import create_pretrain_dataloader
|
| 14 |
+
|
| 15 |
+
logging.basicConfig(
|
| 16 |
+
level=logging.INFO,
|
| 17 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 18 |
+
)
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class PreTrainer:
    """Autoregressive pre-trainer for MultiModalDenseTransformer.

    Responsibilities:
      * AdamW optimization with a manual linear-warmup + cosine-decay LR schedule
      * mixed-precision training (AMP) when CUDA is available
      * gradient accumulation with correct loss averaging for logging
      * checkpoint save/load, including resume with dataloader fast-forward
      * periodic loss logging to a plain-text file
    """

    def __init__(
        self,
        model: MultiModalDenseTransformer,
        tokenizer,
        learning_rate: float = 3e-4,
        weight_decay: float = 0.1,
        warmup_steps: int = 1000,
        max_steps: int = 100000,
        gradient_accumulation_steps: int = 16,
        max_grad_norm: float = 1.0,
        log_interval: int = 10,
        save_interval: int = 1000,
        checkpoint_dir: str = "checkpoints/pretrain",
        loss_log_file: str = "checkpoints/pretrain/train_loss.log"
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model.to(self.device)

        # Standard AdamW hyper-parameters for LLM pre-training.
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.95),
            eps=1e-8
        )

        # Manual LR schedule state (see _get_lr): linear warmup to max_lr,
        # then cosine decay down to min_lr (10% of the peak).
        self.warmup_steps = warmup_steps
        self.max_lr = learning_rate
        self.min_lr = learning_rate * 0.1
        self.current_step = 0

        # Mixed precision only makes sense on CUDA.
        self.use_amp = torch.cuda.is_available()
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)

        # Training hyper-parameters.
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.max_steps = max_steps
        self.log_interval = log_interval
        self.save_interval = save_interval

        # Checkpoint management.
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Plain-text loss log.
        self.loss_log_file = Path(loss_log_file)
        self.loss_log_file.parent.mkdir(parents=True, exist_ok=True)

        # Training state.
        self.global_step = 0
        self.tokens_seen = 0
        self.running_loss = 0.0
        self.best_loss = float('inf')

        logger.info("PreTrainer initialized:")
        logger.info(f"  Device: {self.device}")
        logger.info(f"  Learning Rate: {learning_rate}")
        logger.info(f"  Max Steps: {max_steps}")
        logger.info(f"  Gradient Accumulation: {gradient_accumulation_steps}")
        logger.info(f"  Effective Batch Size: {gradient_accumulation_steps}")
        logger.info(f"  Mixed Precision: {self.use_amp}")

    def _get_lr(self) -> float:
        """Compute the current learning rate (linear warmup + cosine decay).

        Returns a plain Python float. The previous implementation used
        ``torch.cos(torch.tensor(...))`` in the decay branch, which returned a
        0-dim tensor that was then stored in the optimizer's param groups and
        serialized into checkpoints; it also hard-coded pi as 3.14159.
        """
        if self.current_step < self.warmup_steps:
            # Linear warmup from 0 to max_lr.
            return self.max_lr * (self.current_step / self.warmup_steps)
        # Cosine decay from max_lr down to min_lr. Clamp progress so steps
        # past max_steps never push the LR below min_lr; guard the division
        # against warmup_steps == max_steps.
        decay_steps = max(1, self.max_steps - self.warmup_steps)
        progress = min(1.0, (self.current_step - self.warmup_steps) / decay_steps)
        return self.min_lr + (self.max_lr - self.min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))

    def _set_lr(self, lr: float):
        """Apply *lr* to every optimizer param group."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def train_step(self, batch: dict) -> dict:
        """Run one micro-batch forward/backward pass.

        The loss used for ``backward`` is divided by
        ``gradient_accumulation_steps`` so accumulated gradients average
        correctly; the returned ``'loss'`` is the real (unscaled) per-token
        cross-entropy of this micro-batch.

        Args:
            batch: dict with ``'input_ids'`` and ``'attention_mask'`` tensors
                of shape (batch, seq).

        Returns:
            dict with keys ``'loss'`` (float) and ``'lr'`` (current LR).
        """
        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)

        # Position ids count non-pad tokens along the sequence; padded
        # positions are forced to 0. Vectorized equivalent of the per-row
        # cumsum loop (rows that are all padding stay all-zero).
        mask = attention_mask.long()
        position_ids = (mask.cumsum(dim=-1) - 1).clamp(min=0) * mask

        # The model consumes a segment list; plain text is modality 0.
        input_data = {
            'segments': [{
                'type': 'text',
                'data': input_ids,
                'modality_id': 0
            }]
        }

        # Forward pass under autocast when AMP is enabled.
        with torch.amp.autocast('cuda', enabled=self.use_amp):
            outputs = self.model(
                input_data,
                attention_mask=attention_mask,
                position_ids=position_ids)
            logits = outputs['logits']

            # Standard next-token (autoregressive) loss.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()
            shift_attention_mask = attention_mask[:, 1:].contiguous()

            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                reduction='none'
            )

            # Mask out padding positions; epsilon avoids division by zero.
            loss = (loss * shift_attention_mask.view(-1)).sum() / (shift_attention_mask.sum() + 1e-8)

        # Scale only the loss that feeds backward; report the real loss.
        loss_for_backward = loss / self.gradient_accumulation_steps
        self.scaler.scale(loss_for_backward).backward()

        # running_loss is accumulated once per full optimizer step in train(),
        # not here, so logging reflects per-step averages.
        self.tokens_seen += attention_mask.sum().item()

        return {
            'loss': loss.item(),  # real, unscaled micro-batch loss
            'lr': self.optimizer.param_groups[0]['lr']
        }

    def optimizer_step(self):
        """Apply the accumulated gradients: clip, step, and advance the LR.

        Returns:
            The pre-clipping gradient norm as a float.
        """
        # Unscale before clipping so the norm is in real units.
        self.scaler.unscale_(self.optimizer)

        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.max_grad_norm
        )

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)

        # Advance the schedule and push the new LR into the optimizer.
        self.current_step += 1
        self.global_step += 1
        lr = self._get_lr()
        self._set_lr(lr)

        return grad_norm.item()

    def _write_loss_to_txt(self, step, avg_loss, lr, tokens_seen):
        """Append one timestamped line to the plain-text loss log."""
        log_content = (
            f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
            f"Step: {step}/{self.max_steps}, "
            f"Average Loss: {avg_loss:.4f}, "
            f"Learning Rate: {lr:.2e}, "
            f"Tokens Seen: {tokens_seen/1e9:.2f}B\n"
        )
        with open(self.loss_log_file, 'a', encoding='utf-8') as f:
            f.write(log_content)

    def train(self, dataloader, resume_from=None):
        """Main training loop.

        Args:
            dataloader: yields dicts with 'input_ids' and 'attention_mask'.
            resume_from: optional checkpoint path; when given, model/optimizer
                state is restored and the dataloader is fast-forwarded past
                the batches already consumed.
        """
        logger.info("\n" + "="*80)
        logger.info("Starting Pre-Training (Fixed Version)")
        logger.info("="*80 + "\n")

        if resume_from:
            self.load_checkpoint(resume_from)

        # Initialize the loss log on first run.
        if not self.loss_log_file.exists():
            with open(self.loss_log_file, 'w', encoding='utf-8') as f:
                f.write("🚀 Fixed Training Log (Real Loss Values)\n")
                f.write("="*80 + "\n")

        self.model.train()
        progress_bar = tqdm(total=self.max_steps, initial=self.global_step)

        step_in_accumulation = 0
        accumulated_loss = 0.0  # micro-batch losses within the current optimizer step

        # When resuming, skip the batches already consumed so the (streaming)
        # data picks up where the previous run stopped.
        batches_to_skip = self.global_step * self.gradient_accumulation_steps

        logger.info(f"Current Global Step: {self.global_step}")
        if batches_to_skip > 0:
            logger.info(f"🔄 Resuming: Need to skip {batches_to_skip} batches to restore data state...")
            logger.info("This might take a while depending on network/disk speed...")

        data_iterator = iter(dataloader)

        # 1. Fast-forward past already-trained batches.
        skipped = 0
        if batches_to_skip > 0:
            with tqdm(total=batches_to_skip, desc="Skipping trained batches", unit="batch") as skip_pbar:
                while skipped < batches_to_skip:
                    try:
                        # Pull data only -- no forward pass, no gradients.
                        _ = next(data_iterator)
                        skipped += 1
                        skip_pbar.update(1)
                    except StopIteration:
                        logger.error("Dataset exhausted during skipping! Check your dataset size or max_steps.")
                        return

            logger.info("✅ Data fast-forward complete. Resuming training...")

        # 2. Main training loop.
        try:
            # Keep consuming the partially-used iterator; a fresh
            # `for batch in dataloader` would restart from the beginning.
            while True:
                try:
                    batch = next(data_iterator)
                except StopIteration:
                    break  # dataset exhausted

                if batch is None or batch['input_ids'].size(0) == 0:
                    continue

                stats = self.train_step(batch)
                step_in_accumulation += 1
                accumulated_loss += stats['loss']  # real micro-batch loss

                # A full accumulation window is done: apply the update.
                if step_in_accumulation >= self.gradient_accumulation_steps:
                    # Average loss over the micro-batches of this step.
                    avg_step_loss = accumulated_loss / self.gradient_accumulation_steps

                    grad_norm = self.optimizer_step()
                    stats['grad_norm'] = grad_norm
                    stats['loss'] = avg_step_loss

                    # Accumulate for periodic logging.
                    self.running_loss += avg_step_loss

                    step_in_accumulation = 0
                    accumulated_loss = 0.0

                    progress_bar.update(1)
                    progress_bar.set_postfix({
                        'loss': f"{stats['loss']:.4f}",
                        'lr': f"{stats['lr']:.2e}",
                        'tokens': f"{self.tokens_seen/1e9:.2f}B",
                        'grad': f"{grad_norm:.2f}"
                    })

                    # Periodic logging.
                    if self.global_step % self.log_interval == 0:
                        avg_loss = self.running_loss / self.log_interval

                        logger.info(
                            f"Step {self.global_step}/{self.max_steps} | "
                            f"Loss: {avg_loss:.4f} | "
                            f"LR: {stats['lr']:.2e} | "
                            f"GradNorm: {grad_norm:.2f} | "
                            f"Tokens: {self.tokens_seen/1e9:.2f}B"
                        )

                        # Sanity check: loss this high after 100 steps usually
                        # signals a data or LR problem.
                        if avg_loss > 10.0 and self.global_step > 100:
                            logger.warning(f"⚠️ Loss异常高 ({avg_loss:.2f}),可能存在问题!")

                        if avg_loss < self.best_loss:
                            self.best_loss = avg_loss
                            logger.info(f"✨ New best loss: {self.best_loss:.4f}")

                        self._write_loss_to_txt(
                            step=self.global_step,
                            avg_loss=avg_loss,
                            lr=stats['lr'],
                            tokens_seen=self.tokens_seen
                        )
                        self.running_loss = 0.0

                    # Periodic checkpointing.
                    if self.global_step % self.save_interval == 0:
                        self.save_checkpoint(
                            self.checkpoint_dir / f"step_{self.global_step}.pt"
                        )

                    # Training budget reached.
                    if self.global_step >= self.max_steps:
                        break

        except KeyboardInterrupt:
            logger.info("\n⚠️ Training interrupted by user")
            self.save_checkpoint(
                self.checkpoint_dir / f"interrupted_step_{self.global_step}.pt"
            )

        finally:
            progress_bar.close()

        logger.info("\n" + "="*80)
        logger.info("Pre-Training Complete!")
        logger.info(f"  Total Steps: {self.global_step}")
        logger.info(f"  Total Tokens: {self.tokens_seen/1e9:.2f}B")
        logger.info(f"  Best Loss: {self.best_loss:.4f}")
        logger.info("="*80 + "\n")

        # Always persist the final weights.
        self.save_checkpoint(self.checkpoint_dir / "final_model.pt")

    def save_checkpoint(self, path: Path):
        """Serialize model/optimizer/scaler state and training counters."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict() if self.use_amp else None,
            'global_step': self.global_step,
            'current_step': self.current_step,
            'tokens_seen': self.tokens_seen,
            'best_loss': self.best_loss,
            'timestamp': datetime.now().isoformat()
        }

        torch.save(checkpoint, path)
        logger.info(f"💾 Checkpoint saved to {path}")

    def load_checkpoint(self, path: str):
        """Restore state saved by :meth:`save_checkpoint`."""
        checkpoint = torch.load(path, map_location=self.device, weights_only=True)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if self.use_amp and checkpoint.get('scaler_state_dict'):
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        self.global_step = checkpoint['global_step']
        # Older checkpoints may predate 'current_step'; fall back sensibly.
        self.current_step = checkpoint.get('current_step', self.global_step)
        self.tokens_seen = checkpoint['tokens_seen']
        self.best_loss = checkpoint.get('best_loss', float('inf'))

        logger.info(f"📂 Checkpoint loaded from {path}")
        logger.info(f"  Resuming from step {self.global_step}")
        logger.info(f"  Tokens seen: {self.tokens_seen/1e9:.2f}B")
+
|
| 396 |
+
|
| 397 |
+
def main():
    """Entry point: build tokenizer, model and dataloader, then pre-train."""
    # Tuned configuration.
    config = {
        # Model configuration.
        'model_dim': 1536,
        'vocab_size': 151665,
        'n_layers': 12,
        'n_heads': 12,
        'n_kv_heads': 4,
        'max_seq_len': 512,      # reduced for throughput
        'dropout': 0.1,
        'use_moe': False,

        # Training configuration.
        'batch_size': 4,
        'gradient_accumulation_steps': 8,
        'learning_rate': 3e-4,   # standard pre-training peak LR
        'weight_decay': 0.1,
        'warmup_steps': 500,     # faster warmup
        'max_steps': 10000,
        'max_grad_norm': 1.0,

        # Data configuration.
        'data_mix': 'text_only',
        'max_length': 512,       # keep consistent with max_seq_len
        'num_workers': 2,        # fewer workers to avoid network issues

        # Logging and checkpointing.
        'log_interval': 10,
        'save_interval': 500,    # save more frequently
        'checkpoint_dir': 'checkpoints/pretrain_fixed',
        'loss_log_file': 'checkpoints/pretrain_fixed/train_loss.log'
    }

    logger.info("="*80)
    logger.info("🔧 Fixed Configuration:")
    logger.info(json.dumps(config, indent=2))
    logger.info("="*80 + "\n")

    # Tokenizer.
    logger.info("Initializing tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct",
        use_fast=True,
        trust_remote_code=True
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # The authoritative vocab size comes from the tokenizer, not the default.
    config['vocab_size'] = len(tokenizer)
    logger.info(f"Vocab size: {config['vocab_size']}\n")

    # Model.
    logger.info("Initializing model...")
    model = MultiModalDenseTransformer(
        model_dim=config['model_dim'],
        vocab_size=config['vocab_size'],
        n_layers=config['n_layers'],
        n_heads=config['n_heads'],
        n_kv_heads=config['n_kv_heads'],
        max_seq_len=config['max_seq_len'],
        dropout=config['dropout'],
        use_moe=config['use_moe'],
        use_gradient_checkpointing=True,
        rope_scaling_type="yarn",
        use_multimodal_fusion=False,
        use_contrastive=False
    )

    # Dataloader.
    logger.info(f"\nCreating dataloader (mix: {config['data_mix']})...")
    dataloader = create_pretrain_dataloader(
        mix_name=config['data_mix'],
        tokenizer=tokenizer,
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        max_length=config['max_length']
    )

    # Trainer.
    trainer = PreTrainer(
        model=model,
        tokenizer=tokenizer,
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        warmup_steps=config['warmup_steps'],
        max_steps=config['max_steps'],
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        max_grad_norm=config['max_grad_norm'],
        log_interval=config['log_interval'],
        save_interval=config['save_interval'],
        checkpoint_dir=config['checkpoint_dir'],
        loss_log_file=config['loss_log_file']
    )

    # Resume from the checkpoint if it exists; otherwise start fresh.
    # (The original unconditionally passed a hard-coded path and crashed
    # with FileNotFoundError when the file was absent.)
    resume_path = "/root/step_6500.pt"
    resume_from = resume_path if Path(resume_path).exists() else None
    if resume_from is None:
        logger.info("\n🚀 Starting fresh training with fixes...\n")
    else:
        logger.info(f"\n🚀 Resuming training from {resume_from}...\n")
    trainer.train(dataloader, resume_from=resume_from)


if __name__ == "__main__":
    main()
|
reward_model.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
奖励模型 - 用于RLHF
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torch.optim as optim
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
from collections import defaultdict
|
| 10 |
+
from typing import Dict, Tuple, Union, Optional
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
from model import MultiModalDenseTransformer
|
| 13 |
+
|
| 14 |
+
class RewardModel(nn.Module):
    """Scalar reward model for RLHF.

    Wraps a base transformer and maps each per-token hidden state to a
    scalar reward score; optionally also produces per-token value estimates
    through a second, independently initialized head.
    """

    def __init__(
        self,
        base_model: MultiModalDenseTransformer,
        use_value_head: bool = True
    ):
        super().__init__()
        self.base_model = base_model
        self.use_value_head = use_value_head

        width = base_model.model_dim
        self.reward_head = self._make_head(width)
        if use_value_head:
            self.value_head = self._make_head(width)

    @staticmethod
    def _make_head(width: int) -> nn.Sequential:
        """Two-layer MLP (width -> width/2 -> 1) with ReLU and dropout."""
        return nn.Sequential(
            nn.Linear(width, width // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(width // 2, 1),
        )

    def forward(
        self,
        input_data: Dict,
        return_values: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Score every token position of the input.

        Returns a (batch, seq) reward tensor, or a (rewards, values) pair
        when ``return_values`` is set and the value head exists.
        """
        hidden = self.base_model(input_data, return_hidden=True)['last_hidden_state']
        rewards = self.reward_head(hidden).squeeze(-1)

        if not (return_values and self.use_value_head):
            return rewards

        values = self.value_head(hidden).squeeze(-1)
        return rewards, values
| 56 |
+
|
| 57 |
+
class RewardModelTrainer:
    """Trainer for :class:`RewardModel` using a pairwise ranking loss.

    The base transformer is frozen except for its last two layers; those
    layers plus the reward (and optional value) heads are optimized jointly.
    """

    def __init__(
        self,
        reward_model: RewardModel,
        learning_rate: float = 1e-5,
        margin: float = 0.0
    ):
        self.reward_model = reward_model
        self.margin = margin

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.reward_model.to(self.device)

        # Freeze the whole backbone, then re-enable its last two layers.
        for param in self.reward_model.base_model.parameters():
            param.requires_grad = False

        for layer in self.reward_model.base_model.layers[-2:]:
            for param in layer.parameters():
                param.requires_grad = True

        # Optimize every parameter that is still trainable: the unfrozen
        # backbone layers plus the reward/value heads. (The original also
        # built a `trainable_params` list here that was never used.)
        self.optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, self.reward_model.parameters()),
            lr=learning_rate
        )

    def train_step(self, chosen_batch: Dict, rejected_batch: Dict) -> Dict:
        """One optimization step on a (chosen, rejected) pair of batches.

        Returns:
            dict with the pairwise loss and the fraction of pairs where the
            chosen response out-scored the rejected one.
        """
        self.reward_model.train()
        self.optimizer.zero_grad()

        # The score at the final token position is used as the sequence
        # reward. NOTE(review): assumes the last position holds a real token;
        # with right padding this reads a pad position -- confirm against the
        # preference collate function.
        chosen_rewards = self.reward_model(chosen_batch)[:, -1]
        rejected_rewards = self.reward_model(rejected_batch)[:, -1]

        # Bradley-Terry pairwise ranking loss with an optional margin.
        loss = -F.logsigmoid(chosen_rewards - rejected_rewards - self.margin).mean()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.reward_model.parameters(), 1.0)
        self.optimizer.step()

        accuracy = (chosen_rewards > rejected_rewards).float().mean().item()

        return {
            'loss': loss.item(),
            'accuracy': accuracy
        }

    def train(
        self,
        dataloader: DataLoader,
        num_epochs: int = 1,
        log_interval: int = 10
    ):
        """Training loop over (chosen_ids, rejected_ids) tensor pairs."""
        print(f"Starting reward model training on {self.device}...")

        for epoch in range(num_epochs):
            total_stats = defaultdict(float)
            num_steps = 0
            progress_bar = tqdm(
                dataloader,
                desc=f"Reward Model Epoch {epoch+1}/{num_epochs}"
            )

            for batch_idx, (chosen_ids, rejected_ids) in enumerate(progress_bar):
                # Wrap raw token ids into the model's segment format.
                chosen_batch = {
                    'segments': [{'type': 'text', 'data': chosen_ids.to(self.device), 'modality_id': 0}]
                }

                rejected_batch = {
                    'segments': [{'type': 'text', 'data': rejected_ids.to(self.device), 'modality_id': 0}]
                }

                stats = self.train_step(chosen_batch, rejected_batch)

                for k, v in stats.items():
                    total_stats[k] += v
                num_steps += 1

                # Show a running average, then reset the window.
                if (batch_idx + 1) % log_interval == 0:
                    avg_stats = {
                        k: v / num_steps
                        for k, v in total_stats.items()
                    }
                    progress_bar.set_postfix(avg_stats)
                    total_stats = defaultdict(float)

        print("Reward model training complete!")

    def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
        """Evaluate pairwise loss and ranking accuracy without gradients."""
        self.reward_model.eval()
        total_stats = defaultdict(float)
        num_batches = 0

        with torch.no_grad():
            for chosen_ids, rejected_ids in dataloader:
                chosen_batch = {
                    'segments': [{'type': 'text', 'data': chosen_ids.to(self.device), 'modality_id': 0}]
                }

                rejected_batch = {
                    'segments': [{'type': 'text', 'data': rejected_ids.to(self.device), 'modality_id': 0}]
                }

                # Same last-token sequence score as in train_step.
                chosen_rewards = self.reward_model(chosen_batch)[:, -1]
                rejected_rewards = self.reward_model(rejected_batch)[:, -1]

                loss = -F.logsigmoid(chosen_rewards - rejected_rewards - self.margin).mean()
                accuracy = (chosen_rewards > rejected_rewards).float().mean().item()

                total_stats['loss'] += loss.item()
                total_stats['accuracy'] += accuracy
                num_batches += 1

        return {k: v / num_batches for k, v in total_stats.items()}

    def save_checkpoint(self, path: str):
        """Serialize reward-model and optimizer state."""
        torch.save({
            'model_state_dict': self.reward_model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)

    def load_checkpoint(self, path: str):
        """Restore state saved by :meth:`save_checkpoint`."""
        checkpoint = torch.load(path, map_location=self.device)
        self.reward_model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
transformer.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
优化的Transformer架构
|
| 3 |
+
支持GQA/MQA、滑动窗口注意力、Flash Attention 2、YARN位置编码
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from typing import Optional, Tuple, List
|
| 9 |
+
import math
|
| 10 |
+
from components import RMSNorm, SwiGLU, YARNRotaryEmbedding, QKNorm
|
| 11 |
+
from peft_ import LinearWithLoRA, AdapterLayer
|
| 12 |
+
from moe import MixtureOfExperts
|
| 13 |
+
|
| 14 |
+
class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention (GQA) with YARN RoPE or ALiBi position bias.

    ``n_kv_heads < n_heads`` shares each key/value head across several query
    heads (``n_kv_heads == 1`` is MQA).  Optional features: sliding-window
    causal masking, QK-norm, LoRA-wrapped projections, and PyTorch SDPA
    ("flash") acceleration when no feature requires an explicit score matrix.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        use_flash: bool = True,
        qkv_bias: bool = False,
        use_lora: bool = False,
        lora_rank: int = 8,
        max_seq_len: int = 8192,
        rope_scaling_factor: float = 1.0,
        rope_scaling_type: str = "yarn",  # accepted for API symmetry; only YARN is implemented
        use_qk_norm: bool = False,
        sliding_window: Optional[int] = None,
        use_alibi: bool = False
    ):
        super().__init__()

        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads

        assert n_heads % self.n_kv_heads == 0, \
            f"n_heads ({n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"

        self.n_rep = n_heads // self.n_kv_heads
        self.head_dim = head_dim if head_dim is not None else dim // n_heads
        self.scale = self.head_dim ** -0.5

        self.use_flash = use_flash and hasattr(F, 'scaled_dot_product_attention')
        self.sliding_window = sliding_window

        self.q_proj = LinearWithLoRA(
            dim, n_heads * self.head_dim,
            bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
        )
        self.k_proj = LinearWithLoRA(
            dim, self.n_kv_heads * self.head_dim,
            bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
        )
        self.v_proj = LinearWithLoRA(
            dim, self.n_kv_heads * self.head_dim,
            bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
        )
        self.o_proj = LinearWithLoRA(
            n_heads * self.head_dim, dim,
            bias=False, use_lora=use_lora, lora_rank=lora_rank
        )

        self.attn_dropout = nn.Dropout(attn_dropout) if attn_dropout > 0 else nn.Identity()
        self.resid_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        self.use_qk_norm = use_qk_norm
        if use_qk_norm:
            self.q_norm = QKNorm(self.head_dim)
            self.k_norm = QKNorm(self.head_dim)

        self.use_alibi = use_alibi
        if use_alibi:
            # ALiBi replaces RoPE entirely: position information enters as an
            # additive bias on the attention scores.
            self.register_buffer(
                "alibi_slopes",
                self._get_alibi_slopes(n_heads),
                persistent=False
            )
        else:
            self.rotary_emb = YARNRotaryEmbedding(
                self.head_dim,
                max_seq_len=max_seq_len,
                original_max_len=4096,
                scaling_factor=rope_scaling_factor,
                rope_percentage=1.0
            )

    def _get_alibi_slopes(self, n_heads: int) -> torch.Tensor:
        """Compute per-head ALiBi slopes, shape (n_heads, 1, 1).

        For power-of-two head counts this is the geometric sequence from the
        ALiBi paper; otherwise slopes are interleaved from the next power of
        two, as in the reference implementation.
        """
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio ** i for i in range(n)]

        if math.log2(n_heads).is_integer():
            slopes = get_slopes_power_of_2(n_heads)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
            slopes = get_slopes_power_of_2(closest_power_of_2)
            extra_slopes = get_slopes_power_of_2(2 * closest_power_of_2)[::2]
            slopes.extend(extra_slopes[:n_heads - closest_power_of_2])

        return torch.tensor(slopes).view(n_heads, 1, 1)

    def repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat KV heads ``n_rep`` times so their count matches the Q heads.

        Input/output shape: (B, n_kv_heads, T, head_dim) ->
        (B, n_kv_heads * n_rep, T, head_dim).  No-op (same object) for MHA.
        """
        if self.n_rep == 1:
            return x

        B, n_kv_heads, seq_len, head_dim = x.shape
        return x[:, :, None, :, :].expand(
            B, n_kv_heads, self.n_rep, seq_len, head_dim
        ).reshape(B, n_kv_heads * self.n_rep, seq_len, head_dim)

    def _apply_sliding_window_mask(
        self,
        attn_scores: torch.Tensor,
        seq_len: int
    ) -> torch.Tensor:
        """Mask scores outside a causal window of ``sliding_window`` keys.

        ``seq_len`` is the key length.  The (k, k) window mask is cropped to
        the last q_len rows so it also broadcasts correctly during cached
        decoding (FIX: the previous version always produced a (k, k) mask,
        which could not broadcast against (..., q_len, k_len) scores when
        q_len < k_len).
        """
        if self.sliding_window is None or seq_len <= self.sliding_window:
            return attn_scores

        keep = torch.ones(seq_len, seq_len, device=attn_scores.device, dtype=torch.bool)
        keep = torch.triu(keep, diagonal=-self.sliding_window + 1)  # drop keys too far in the past
        keep = torch.tril(keep, diagonal=0)                         # drop future keys
        keep = keep[-attn_scores.size(-2):, :]                      # align to the newest query rows

        return attn_scores.masked_fill(~keep, float('-inf'))

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[torch.Tensor]]:
        """Attention forward pass.

        Args:
            x: input hidden states, shape (B, T, dim).
            attention_mask: either a (B, T) keep-mask (bool/int, 1 = keep,
                0 = mask) or a floating additive mask (0 / -inf) broadcastable
                to the score matrix.
            position_ids: optional positions forwarded to the rotary embedding.
            use_cache: if True, return (k, v) for incremental decoding.
            past_kv: cached (k, v) from previous steps; concatenated on the
                sequence dimension.
            output_attentions: if True, also return the softmax weights
                (forces the non-flash path).

        Returns:
            Tuple of (output, present_kv, attention_weights); the last two
            may be None.
        """
        B, T, C = x.shape

        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        if self.use_qk_norm:
            # Normalize per head_dim; flatten to 2D so the norm sees one
            # head vector per row, then restore the original shape.
            q_shape = q.shape
            k_shape = k.shape
            q = self.q_norm.query_norm(q.view(-1, self.head_dim)).view(q_shape)
            k = self.k_norm.key_norm(k.view(-1, self.head_dim)).view(k_shape)

        if not self.use_alibi:
            q, k = self.rotary_emb(q, k, position_ids)

        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        present_kv = (k, v) if use_cache else None

        k = self.repeat_kv(k)
        v = self.repeat_kv(v)

        seq_len_k = k.size(2)

        # SDPA ("flash") is only valid when nothing requires an explicit
        # score matrix.  FIX: the old condition ignored ALiBi and the sliding
        # window, silently dropping both whenever attention_mask was None.
        use_flash_path = (
            self.use_flash
            and not output_attentions
            and attention_mask is None
            and not self.use_alibi
            and (self.sliding_window is None or seq_len_k <= self.sliding_window)
        )

        if use_flash_path:
            dropout_p = self.attn_dropout.p if isinstance(self.attn_dropout, nn.Dropout) and self.training else 0.0
            # FIX: with a KV cache (q shorter than k), `is_causal=True` uses a
            # top-left-aligned triangular mask that hides valid history from
            # the new token; a single-step decode must attend to all cached
            # keys, so causality is only requested for the full-square case.
            attn_output = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=q.size(2) == seq_len_k
            )
            attention_weights = None
        else:
            attn_scores = (q @ k.transpose(-2, -1)) * self.scale

            if self.use_alibi:
                # Linear bias per absolute key position.  Per query row this
                # differs from the relative-distance form only by a constant,
                # which the row-wise softmax cancels.
                position_bias = self.alibi_slopes.to(x.device) * torch.arange(
                    seq_len_k, device=x.device
                ).view(1, 1, -1)
                attn_scores = attn_scores + position_bias

            if self.sliding_window is not None:
                attn_scores = self._apply_sliding_window_mask(attn_scores, seq_len_k)

            if attention_mask is not None:
                if attention_mask.dim() == 2:
                    # (B, T) -> (B, 1, 1, T) for broadcasting over heads/queries.
                    attention_mask = attention_mask[:, None, None, :]
                # FIX: treat any non-floating mask (bool/int, 1 = keep) as a
                # keep-mask and convert it to additive form; floating masks
                # (including fp16/bf16) are assumed to already be additive.
                # The old `dtype != torch.float` test misread half-precision
                # additive masks as keep-masks.
                if not attention_mask.is_floating_point():
                    extended_mask = (1.0 - attention_mask.to(attn_scores.dtype)) * torch.finfo(attn_scores.dtype).min
                else:
                    extended_mask = attention_mask

                attn_scores = attn_scores + extended_mask

            is_causal = seq_len_k > 1
            if is_causal:
                causal_mask = torch.triu(
                    torch.ones(seq_len_k, seq_len_k, device=x.device, dtype=torch.bool),
                    diagonal=1
                )
                # Keep only the last q_len rows so a cached decode step is
                # aligned with its absolute position in the sequence.
                causal_mask = causal_mask[-q.shape[2]:, :]
                attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))

            # Softmax in fp32 for numerical stability, then cast back.
            attention_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
            attention_weights = self.attn_dropout(attention_weights)

            attn_output = attention_weights @ v

        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, -1)
        output = self.resid_dropout(self.o_proj(attn_output))

        return output, present_kv, attention_weights if output_attentions else None
|
| 223 |
+
|
| 224 |
+
class OptimizedTransformerBlock(nn.Module):
    """Pre-norm Transformer block: GQA attention + (dense SwiGLU or MoE) FFN.

    Supports a parallel-residual layout (attention and FFN both read the same
    normalized input, GPT-NeoX style), optional bottleneck adapters, and
    LoRA-wrapped projections.  The auxiliary (load-balancing) loss of the most
    recent forward pass is exposed as ``self.moe_aux_loss`` for the trainer
    to accumulate; it is zero when MoE is disabled.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        use_moe: bool = False,
        num_experts: int = 8,
        moe_top_k: int = 2,
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_lora: bool = False,
        lora_rank: int = 8,
        use_parallel_residual: bool = False,
        norm_eps: float = 1e-6,
        sliding_window: Optional[int] = None,
        ffn_dim_multiplier: Optional[float] = None,
        layer_idx: int = 0
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.use_moe = use_moe
        self.use_adapter = use_adapter
        self.use_parallel_residual = use_parallel_residual

        self.attention = GroupedQueryAttention(
            dim=dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            head_dim=head_dim,
            dropout=dropout,
            attn_dropout=attn_dropout,
            use_lora=use_lora,
            lora_rank=lora_rank,
            sliding_window=sliding_window,
            rope_scaling_type="yarn"
        )

        # FFN: token-routed mixture of experts, or a single dense SwiGLU.
        if use_moe:
            self.ffn = MixtureOfExperts(
                dim=dim,
                num_experts=num_experts,
                top_k=moe_top_k,
                dropout=dropout,
                ffn_dim_multiplier=ffn_dim_multiplier
            )
        else:
            self.ffn = SwiGLU(
                dim=dim,
                dropout=dropout,
                ffn_dim_multiplier=ffn_dim_multiplier
            )

        if use_adapter:
            self.adapter = AdapterLayer(dim, adapter_dim, dropout)

        self.attention_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)

        # Placeholder until the first forward pass; replaced each forward
        # with a tensor on the input's device.
        self.moe_aux_loss = torch.tensor(0.0)

    def _run_ffn(self, ffn_input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the FFN and return ``(output, aux_loss)``.

        A dense FFN has no auxiliary loss, so a zero on the input's device is
        returned to keep downstream accumulation device-consistent.  (FIX:
        this logic was previously duplicated across both residual layouts.)
        """
        if self.use_moe:
            return self.ffn(ffn_input)
        return self.ffn(ffn_input), torch.tensor(0.0, device=ffn_input.device)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[torch.Tensor]]:
        """Block forward pass.

        Args:
            x: hidden states, shape (B, T, dim).
            attention_mask / position_ids / use_cache / past_kv /
            output_attentions: forwarded unchanged to the attention layer.

        Returns:
            (hidden_states, present_kv, attention_weights); the last two may
            be None.  Side effect: updates ``self.moe_aux_loss``.
        """
        attn_out, present_kv, attn_weights = self.attention(
            self.attention_norm(x),
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            past_kv=past_kv,
            output_attentions=output_attentions
        )

        if self.use_parallel_residual:
            # Attention and FFN branches both read the pre-block input and
            # are summed into a single residual update.
            # NOTE(review): the adapter is not applied on this path — this
            # matches the original behavior, but confirm it is intentional.
            ffn_out, self.moe_aux_loss = self._run_ffn(self.ffn_norm(x))
            x = x + attn_out + ffn_out
        else:
            x = x + attn_out

            if self.use_adapter:
                x = self.adapter(x)

            ffn_out, self.moe_aux_loss = self._run_ffn(self.ffn_norm(x))
            x = x + ffn_out

        return x, present_kv, attn_weights
|