szxllm commited on
Commit
c13bd26
·
verified ·
1 Parent(s): a64e147

Update components.py

Browse files
Files changed (1) hide show
  1. components.py +314 -386
components.py CHANGED
@@ -1,387 +1,315 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from typing import Tuple, Optional, Union
5
- import math
6
-
7
- class YARNScaling:
8
- """
9
- YARN (Yet Another RoPE extensioN) 缩放策略
10
- 实现参考: https://arxiv.org/abs/2309.00071
11
- """
12
- @staticmethod
13
- def compute_yarn_parameters(
14
- original_max_len: int,
15
- target_max_len: int=8192,
16
- dim: int=128,
17
- base: int = 10000,
18
- beta_fast: int = 32,
19
- beta_slow: int = 1,
20
- alpha: float = 1.0,
21
- device: Optional[torch.device] = None
22
- ) -> Tuple[torch.Tensor, float]:
23
- scale = float(target_max_len) / original_max_len
24
- mscale = YARNScaling.compute_mscale(scale, alpha)
25
-
26
- # 确保 dim 为 float 以进行除法运算
27
- # RoPE 频率是成对的 (0, 2, ..., d-2)
28
- freqs_idx = torch.arange(0, dim, 2, dtype=torch.float32, device=device)
29
-
30
- # 基础频率 (Original RoPE)
31
- freq_extra = 1.0 / (base ** (freqs_idx / dim))
32
-
33
- # 如果不需要缩放,直接返回基础频率
34
- if scale <= 1.0:
35
- return freq_extra, 1.0
36
-
37
- # 插值频率 (Interpolated for extension)
38
- freq_inter = 1.0 / (scale * base ** (freqs_idx / dim))
39
-
40
- # 计算 YARN 阈值 (基于波长/索引)
41
- # 对应 paper 中的 band constraints
42
- # 这里的公式将频率索引 i 映射到阈值
43
- def get_limit(beta):
44
- return dim * math.log(original_max_len / (2 * math.pi * beta)) / (2 * math.log(base))
45
-
46
- low = max(math.floor(get_limit(beta_fast)), 0)
47
- high = min(math.ceil(get_limit(beta_slow)), dim // 2 - 1)
48
-
49
- # indices: 0, 1, ..., dim/2 - 1
50
- indices = torch.arange(0, dim // 2, dtype=torch.float32, device=device)
51
-
52
- inv_freq = freq_extra.clone()
53
-
54
- # 1. 低频部分 (Long wavelengths, Indices > high): 使用插值频率
55
- # 这些频率对应的波长已经超过了原始上下文长度,需要拉伸
56
- mask_low_freq = indices > high
57
- inv_freq[mask_low_freq] = freq_inter[mask_low_freq]
58
-
59
- # 2. 高频部分 (Short wavelengths, Indices < low): 保持原频率 (freq_extra)
60
- # 这些部分受旋转不变性保护,不需要插值
61
-
62
- # 3. 中间部分: 线性平滑混合 (Ramp function)
63
- mid_mask = (indices >= low) & (indices <= high)
64
- if mid_mask.any():
65
- # 避免除以 0
66
- denom = max(high - low, 1)
67
- t = (indices[mid_mask] - low) / denom
68
- inv_freq[mid_mask] = freq_extra[mid_mask] * (1 - t) + freq_inter[mid_mask] * t
69
-
70
- return inv_freq, float(mscale)
71
-
72
- @staticmethod
73
- def compute_mscale(scale: float, alpha: float = 1.0) -> float:
74
- """计算注意力缩放因子 (Temperature scaling)"""
75
- if scale <= 1.0:
76
- return 1.0
77
- # 0.1 * ln(scale) + 1.0 是经验公式,用于修正熵值
78
- return 0.1 * math.log(scale) + 1.0
79
-
80
- class YARNRotaryEmbedding(nn.Module):
81
- """
82
- 集成 YARN 的旋转位置编码
83
- 修复了精度问题、缓存管理以及 position_ids 越界问题
84
- """
85
- def __init__(
86
- self,
87
- dim: int = 64,
88
- max_seq_len: int = 8192,
89
- original_max_len: int = 4096,
90
- base: int = 10000,
91
- scaling_factor: float = 1.0, # 预留接口,暂未使用,由 yarn 逻辑控制
92
- beta_fast: int = 32,
93
- beta_slow: int = 1,
94
- alpha: float = 1.0,
95
- rope_percentage: float = 1.0,
96
- device: Optional[torch.device] = None
97
- ):
98
- super().__init__()
99
- self.dim = dim
100
- self.max_seq_len = max_seq_len
101
- self.original_max_len = original_max_len
102
- self.base = base
103
- self.alpha = alpha
104
-
105
- # 计算实际应用 RoPE 的维度
106
- self.rope_dim = int(dim * rope_percentage)
107
- # 确保是偶数
108
- if self.rope_dim % 2 != 0:
109
- self.rope_dim -= 1
110
-
111
- # 初始化频率 (Persistent state)
112
- self._init_yarn_frequencies(device)
113
-
114
- # 缓存 cos/sin (Transient state)
115
- # persistent=False 意味着不会保存到 state_dict,减少 checkpoint 大小
116
- self.register_buffer("cos_cached", None, persistent=False)
117
- self.register_buffer("sin_cached", None, persistent=False)
118
-
119
- def _init_yarn_frequencies(self, device: Optional[torch.device] = None):
120
- """初始化 YARN 频率"""
121
- inv_freq, mscale = YARNScaling.compute_yarn_parameters(
122
- self.original_max_len,
123
- self.max_seq_len,
124
- self.rope_dim,
125
- self.base,
126
- beta_fast=32, # 这里通常使用默认值或传入参数,此处修正为使用硬编码默认值保持一致,或应改为 self.beta_fast
127
- beta_slow=1,
128
- alpha=self.alpha,
129
- device=device
130
- )
131
- # 注册 buffer
132
- self.register_buffer("inv_freq", inv_freq, persistent=True)
133
- self.register_buffer("mscale", torch.tensor(mscale, dtype=torch.float32, device=device), persistent=True)
134
-
135
- def _compute_cos_sin_cache(
136
- self,
137
- needed_len: int,
138
- device: torch.device,
139
- dtype: torch.dtype
140
- ):
141
- """预计算 cos 和 sin 缓存,始终使用 float32 计算以保证精度"""
142
- # 至少分配 max_seq_len,如果 needed_len 更大则扩展
143
- alloc_len = max(needed_len, self.max_seq_len)
144
-
145
- # 如果已有缓存且足够大且设备匹配,则不重新计算 (可选优化,这里选择简单逻辑:不够就重算)
146
- if (self.cos_cached is not None and
147
- self.cos_cached.shape[2] >= alloc_len and
148
- self.cos_cached.device == device):
149
- return
150
-
151
- t = torch.arange(alloc_len, dtype=torch.float32, device=device)
152
-
153
- # freqs: [alloc_len, dim // 2]
154
- # outer product: t[i] * inv_freq[j]
155
- freqs = torch.outer(t, self.inv_freq.to(device))
156
-
157
- # 拼接以匹配 rotate_half 的逻辑: [theta_0, theta_1, ..., theta_0, theta_1, ...]
158
- emb = torch.cat((freqs, freqs), dim=-1)
159
-
160
- # 应用 mscale 并计算 cos/sin
161
- # [alloc_len, rope_dim] -> [1, 1, alloc_len, rope_dim] 用于广播
162
- cos_cached = (emb.cos() * self.mscale).view(1, 1, alloc_len, self.rope_dim)
163
- sin_cached = (emb.sin() * self.mscale).view(1, 1, alloc_len, self.rope_dim)
164
-
165
- self.cos_cached = cos_cached.to(dtype) # 缓存可以存为半精度以省显存,但计算时建议 float32
166
- self.sin_cached = sin_cached.to(dtype)
167
-
168
- @staticmethod
169
- def rotate_half(x: torch.Tensor) -> torch.Tensor:
170
- """
171
- 旋转输入的后半部分
172
- Input: [..., d] -> Split into x1, x2 -> Output [-x2, x1]
173
- """
174
- x1, x2 = x.chunk(2, dim=-1)
175
- return torch.cat((-x2, x1), dim=-1)
176
-
177
- def apply_rotary_pos_emb(
178
- self,
179
- q: torch.Tensor,
180
- k: torch.Tensor,
181
- position_ids: Optional[torch.Tensor] = None
182
- ) -> Tuple[torch.Tensor, torch.Tensor]:
183
- """应用 RoPE,包含精度修正和边界检查"""
184
- bsz, num_heads, seq_len, head_dim = q.shape
185
-
186
- # 1. 确定需要的缓存长度
187
- if position_ids is not None:
188
- # 必须覆盖 position_ids 中的最大索引
189
- max_pos = position_ids.max().item() + 1
190
- needed_len = max(max_pos, seq_len)
191
- else:
192
- needed_len = seq_len
193
-
194
- # 2. 检查并更新缓存
195
- if (self.cos_cached is None or
196
- self.cos_cached.shape[2] < needed_len or
197
- self.cos_cached.device != q.device):
198
- self._compute_cos_sin_cache(needed_len, q.device, q.dtype)
199
-
200
- # 3. 获取对应的 cos/sin
201
- # cos_cached: [1, 1, alloc_len, dim]
202
- if position_ids is not None:
203
- # position_ids: [bs, seq_len]
204
- # 选取对应的 pos embedding -> [bs, 1, seq_len, dim]
205
- # 注意: cos_cached[0, 0] 形状为 [alloc_len, dim]
206
- cos = self.cos_cached[0, 0][position_ids].unsqueeze(1)
207
- sin = self.sin_cached[0, 0][position_ids].unsqueeze(1)
208
- else:
209
- # 默认假设从 0 开始
210
- cos = self.cos_cached[:, :, :seq_len, :]
211
- sin = self.sin_cached[:, :, :seq_len, :]
212
-
213
- # 4. 处理部分 RoPE (如果 rope_dim < head_dim)
214
- if self.rope_dim < head_dim:
215
- q_rot = q[..., :self.rope_dim]
216
- q_pass = q[..., self.rope_dim:]
217
- k_rot = k[..., :self.rope_dim]
218
- k_pass = k[..., self.rope_dim:]
219
- else:
220
- q_rot = q
221
- k_rot = k
222
- q_pass = None
223
- k_pass = None
224
-
225
- # 5. 执行旋转 (强制 float32 计算以避免精度溢出)
226
- q_rot_float = q_rot.float()
227
- k_rot_float = k_rot.float()
228
- cos_float = cos.float()
229
- sin_float = sin.float()
230
-
231
- q_embed = (q_rot_float * cos_float) + (self.rotate_half(q_rot_float) * sin_float)
232
- k_embed = (k_rot_float * cos_float) + (self.rotate_half(k_rot_float) * sin_float)
233
-
234
- # 6. 转回原始类型
235
- q_embed = q_embed.type_as(q)
236
- k_embed = k_embed.type_as(k)
237
-
238
- if q_pass is not None:
239
- q_embed = torch.cat([q_embed, q_pass], dim=-1)
240
- k_embed = torch.cat([k_embed, k_pass], dim=-1)
241
-
242
- return q_embed, k_embed
243
-
244
- def forward(
245
- self,
246
- q: torch.Tensor,
247
- k: torch.Tensor,
248
- position_ids: Optional[torch.Tensor] = None
249
- ) -> Tuple[torch.Tensor, torch.Tensor]:
250
- return self.apply_rotary_pos_emb(q, k, position_ids)
251
-
252
- def extra_repr(self) -> str:
253
- return (f"dim={self.dim}, rope_dim={self.rope_dim}, "
254
- f"max_seq_len={self.max_seq_len}, original_max_len={self.original_max_len}, "
255
- f"base={self.base}")
256
-
257
- class RMSNorm(nn.Module):
258
- """
259
- Root Mean Square Layer Normalization
260
- 包含 float32 强制转换以确保数值稳定性
261
- """
262
- def __init__(
263
- self,
264
- dim: int,
265
- eps: float = 1e-6,
266
- elementwise_affine: bool = True
267
- ):
268
- super().__init__()
269
- self.eps = eps
270
- self.elementwise_affine = elementwise_affine
271
-
272
- if self.elementwise_affine:
273
- self.weight = nn.Parameter(torch.ones(dim))
274
- else:
275
- self.register_parameter('weight', None)
276
-
277
- def _norm(self, x: torch.Tensor) -> torch.Tensor:
278
- # 始终在 float32 下计算 RMS,防止 FP16 下溢或溢出
279
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
280
-
281
- def forward(self, x: torch.Tensor) -> torch.Tensor:
282
- # 1. 转换为 float32 进行统计量计算
283
- output = self._norm(x.float())
284
-
285
- # 2. 转回原始类型
286
- output = output.type_as(x)
287
-
288
- # 3. 应用权重 (如果存在)
289
- if self.elementwise_affine and self.weight is not None:
290
- output = output * self.weight
291
-
292
- return output
293
-
294
- class QKNorm(nn.Module):
295
- """
296
- Query-Key Normalization (ViT-22B / Scaling Transformer)
297
- 用于稳定注意力矩阵的 logits
298
- """
299
- def __init__(self, dim: int, eps: float = 1e-6):
300
- super().__init__()
301
- self.query_norm = RMSNorm(dim, eps=eps)
302
- self.key_norm = RMSNorm(dim, eps=eps)
303
-
304
- def forward(
305
- self,
306
- q: torch.Tensor,
307
- k: torch.Tensor
308
- ) -> Tuple[torch.Tensor, torch.Tensor]:
309
- q = self.query_norm(q)
310
- k = self.key_norm(k)
311
- return q, k
312
-
313
- class SwiGLU(nn.Module):
314
- """
315
- SwiGLU 激活前馈网络
316
- 结构: Down(SiLU(Gate) * Up)
317
- """
318
- def __init__(
319
- self,
320
- dim: int,
321
- hidden_dim: Optional[int] = None,
322
- multiple_of: int = 256,
323
- ffn_dim_multiplier: Optional[float] = None,
324
- dropout: float = 0.0,
325
- bias: bool = False
326
- ):
327
- super().__init__()
328
-
329
- if hidden_dim is None:
330
- if ffn_dim_multiplier is not None:
331
- hidden_dim = int(dim * ffn_dim_multiplier)
332
- else:
333
- # 默认: 2/3 * 4 * dim = 8/3 * dim (LLaMA standard)
334
- hidden_dim = int(2 * dim * 4 / 3)
335
-
336
- # 确保 hidden_dim 是 multiple_of 的倍数 (通常为了 GPU 核心优化)
337
- hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
338
-
339
- self.hidden_dim = hidden_dim
340
-
341
- # W1: Gate, W3: Up, W2: Down (Standard LLaMA naming conventions)
342
- self.w1 = nn.Linear(dim, hidden_dim, bias=bias)
343
- self.w2 = nn.Linear(hidden_dim, dim, bias=bias)
344
- self.w3 = nn.Linear(dim, hidden_dim, bias=bias)
345
- self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
346
-
347
- def forward(self, x: torch.Tensor) -> torch.Tensor:
348
- # SwiGLU(x) = (SiLU(W1·x) ⊙ W3·x) · W2
349
- return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
350
-
351
- class ParallelAttentionFFN(nn.Module):
352
- """
353
- 并行注意力与前馈网络 (PaLM / GPT-J 风格)
354
- y = x + Attention(LN(x)) + MLP(LN(x))
355
- """
356
- def __init__(
357
- self,
358
- dim: int,
359
- attn_module: nn.Module,
360
- ffn_module: nn.Module,
361
- norm_eps: float = 1e-6
362
- ):
363
- super().__init__()
364
- # 注意: 某些架构(如 PaLM)可能共用一个 LayerNorm,
365
- # 但这里为了灵活性保留两个独立的 Norm (如 CodeLlama 某些变体)
366
- self.attn_norm = RMSNorm(dim, eps=norm_eps)
367
- self.ffn_norm = RMSNorm(dim, eps=norm_eps)
368
- self.attn = attn_module
369
- self.ffn = ffn_module
370
-
371
- def forward(
372
- self,
373
- x: torch.Tensor,
374
- **attn_kwargs
375
- ) -> torch.Tensor:
376
- # 并行计算:从同一个 x (normalize 后) 分叉
377
- attn_input = self.attn_norm(x)
378
- ffn_input = self.ffn_norm(x)
379
-
380
- # 计算注意力
381
- attn_out = self.attn(attn_input, **attn_kwargs)
382
-
383
- # 计算 FFN (确保不传递 attn 特定的 kwargs)
384
- ffn_out = self.ffn(ffn_input)
385
-
386
- # 一次性残差连接
387
  return x + attn_out + ffn_out
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Tuple, Optional, Union
5
+ import math
6
+
7
class YARNScaling:
    """
    YARN (Yet Another RoPE extensioN) frequency scaling.

    Reference: https://arxiv.org/abs/2309.00071
    """

    @staticmethod
    def compute_yarn_parameters(
        original_max_len: int,
        target_max_len: int = 8192,
        dim: int = 128,
        base: int = 10000,
        beta_fast: int = 32,
        beta_slow: int = 1,
        alpha: float = 1.0,
        device: Optional[torch.device] = None
    ) -> Tuple[torch.Tensor, float]:
        """
        Return ``(inv_freq, mscale)`` for YARN-extended RoPE.

        ``inv_freq`` is a float32 tensor of shape [dim // 2]; ``mscale`` is the
        attention temperature correction factor (1.0 when no extension is needed).
        """
        scale = float(target_max_len) / original_max_len
        mscale = YARNScaling.compute_mscale(scale, alpha)

        # RoPE frequencies come in pairs, indexed 0, 2, ..., dim - 2.
        pair_idx = torch.arange(0, dim, 2, dtype=torch.float32, device=device)

        # Base (extrapolation) frequencies of the original RoPE.
        base_freq = 1.0 / (base ** (pair_idx / dim))

        # No context extension requested: plain RoPE, no temperature scaling.
        if scale <= 1.0:
            return base_freq, 1.0

        # Position-interpolated frequencies for the extended context.
        interp_freq = 1.0 / (scale * base ** (pair_idx / dim))

        # Band limits from the paper: the frequency index at which the
        # wavelength equals original_max_len / beta.
        def band_limit(beta: float) -> float:
            return dim * math.log(original_max_len / (2 * math.pi * beta)) / (2 * math.log(base))

        low = max(math.floor(band_limit(beta_fast)), 0)
        high = min(math.ceil(band_limit(beta_slow)), dim // 2 - 1)

        half_idx = torch.arange(0, dim // 2, dtype=torch.float32, device=device)

        # Start from the original frequencies; high-frequency bands
        # (index < low) are kept untouched.
        inv_freq = base_freq.clone()

        # Low-frequency bands (index > high): wavelength exceeds the original
        # context, so use the fully interpolated frequencies.
        low_freq_mask = half_idx > high
        inv_freq[low_freq_mask] = interp_freq[low_freq_mask]

        # Transition band: linear ramp between original and interpolated.
        ramp_mask = (half_idx >= low) & (half_idx <= high)
        if ramp_mask.any():
            span = max(high - low, 1)  # guard against division by zero
            t = (half_idx[ramp_mask] - low) / span
            inv_freq[ramp_mask] = base_freq[ramp_mask] * (1 - t) + interp_freq[ramp_mask] * t

        return inv_freq, float(mscale)

    @staticmethod
    def compute_mscale(scale: float, alpha: float = 1.0) -> float:
        """
        Attention temperature scaling factor (empirical entropy correction).

        ``alpha`` is accepted for interface compatibility but currently unused.
        """
        return 1.0 if scale <= 1.0 else 0.1 * math.log(scale) + 1.0
65
class YARNRotaryEmbedding(nn.Module):
    """
    Rotary position embedding with YARN long-context extension.

    Precomputes cos/sin tables from YARN-scaled inverse frequencies
    (see YARNScaling) and applies them to q/k tensors of shape
    [batch, num_heads, seq_len, head_dim]. Supports partial RoPE
    (rope_percentage < 1) and explicit position_ids.
    """

    def __init__(
        self,
        dim: int = 64,
        max_seq_len: int = 8192,
        original_max_len: int = 4096,
        base: int = 10000,
        scaling_factor: float = 1.0,  # reserved interface; extension is driven by the YARN parameters
        beta_fast: int = 32,
        beta_slow: int = 1,
        alpha: float = 1.0,
        rope_percentage: float = 1.0,
        device: Optional[torch.device] = None
    ):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.original_max_len = original_max_len
        self.base = base
        self.alpha = alpha
        self.scaling_factor = scaling_factor
        # BUGFIX: beta_fast / beta_slow were previously accepted but silently
        # ignored — the frequency init hard-coded 32 and 1. Store and use them.
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow

        # Number of head dimensions that receive RoPE; forced even because
        # frequencies come in (cos, sin) pairs.
        self.rope_dim = int(dim * rope_percentage)
        if self.rope_dim % 2 != 0:
            self.rope_dim -= 1

        # Persistent frequency state (inv_freq / mscale buffers).
        self._init_yarn_frequencies(device)

        # Transient cos/sin caches; persistent=False keeps them out of the
        # state_dict (smaller checkpoints).
        self.register_buffer("cos_cached", None, persistent=False)
        self.register_buffer("sin_cached", None, persistent=False)

    def _init_yarn_frequencies(self, device: Optional[torch.device] = None):
        """Compute the YARN inv_freq / mscale and register them as buffers."""
        inv_freq, mscale = YARNScaling.compute_yarn_parameters(
            self.original_max_len,
            self.max_seq_len,
            self.rope_dim,
            self.base,
            beta_fast=self.beta_fast,  # was hard-coded to 32
            beta_slow=self.beta_slow,  # was hard-coded to 1
            alpha=self.alpha,
            device=device
        )
        self.register_buffer("inv_freq", inv_freq, persistent=True)
        self.register_buffer("mscale", torch.tensor(mscale, dtype=torch.float32, device=device), persistent=True)

    def _compute_cos_sin_cache(
        self,
        needed_len: int,
        device: torch.device,
        dtype: torch.dtype
    ):
        """
        (Re)build the cos/sin cache; angles are computed in float32, then the
        tables are cast to `dtype` for storage.

        The cache covers at least max_seq_len positions and is only rebuilt
        when it is missing, too short, or on the wrong device/dtype.
        """
        alloc_len = max(needed_len, self.max_seq_len)

        # dtype check added: previously a stale cache in the wrong dtype
        # (e.g. fp16 after an fp32 request) was silently reused.
        if (self.cos_cached is not None and
            self.cos_cached.shape[2] >= alloc_len and
            self.cos_cached.device == device and
            self.cos_cached.dtype == dtype):
            return

        t = torch.arange(alloc_len, dtype=torch.float32, device=device)

        # Outer product t[i] * inv_freq[j] -> [alloc_len, rope_dim // 2]
        freqs = torch.outer(t, self.inv_freq.to(device))

        # Duplicate to match rotate_half's layout: [theta_0, theta_1, ..., theta_0, theta_1, ...]
        emb = torch.cat((freqs, freqs), dim=-1)

        # Apply the YARN attention temperature and shape for broadcasting:
        # [alloc_len, rope_dim] -> [1, 1, alloc_len, rope_dim]
        cos_cached = (emb.cos() * self.mscale).view(1, 1, alloc_len, self.rope_dim)
        sin_cached = (emb.sin() * self.mscale).view(1, 1, alloc_len, self.rope_dim)

        self.cos_cached = cos_cached.to(dtype)
        self.sin_cached = sin_cached.to(dtype)

    @staticmethod
    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        """Split the last dim into halves [x1, x2] and return [-x2, x1]."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply RoPE to q/k of shape [bsz, heads, seq, head_dim].

        Rotation is done in float32 for precision; position_ids (if given)
        may address any position, the cache grows to cover it.
        """
        bsz, num_heads, seq_len, head_dim = q.shape

        # 1. The cache must cover the largest referenced position.
        if position_ids is not None:
            max_pos = position_ids.max().item() + 1
            needed_len = max(max_pos, seq_len)
        else:
            needed_len = seq_len

        # 2. Rebuild the cache if missing, too short, or on the wrong
        #    device/dtype (dtype check added alongside the one in
        #    _compute_cos_sin_cache).
        if (self.cos_cached is None or
            self.cos_cached.shape[2] < needed_len or
            self.cos_cached.device != q.device or
            self.cos_cached.dtype != q.dtype):
            self._compute_cos_sin_cache(needed_len, q.device, q.dtype)

        # 3. Gather cos/sin for the requested positions.
        if position_ids is not None:
            # cos_cached[0, 0]: [alloc_len, rope_dim]
            # gather by position_ids [bs, seq] -> [bs, 1, seq, rope_dim]
            cos = self.cos_cached[0, 0][position_ids].unsqueeze(1)
            sin = self.sin_cached[0, 0][position_ids].unsqueeze(1)
        else:
            # Positions are assumed to start at 0.
            cos = self.cos_cached[:, :, :seq_len, :]
            sin = self.sin_cached[:, :, :seq_len, :]

        # 4. Partial RoPE: only the first rope_dim channels are rotated.
        if self.rope_dim < head_dim:
            q_rot = q[..., :self.rope_dim]
            q_pass = q[..., self.rope_dim:]
            k_rot = k[..., :self.rope_dim]
            k_pass = k[..., self.rope_dim:]
        else:
            q_rot = q
            k_rot = k
            q_pass = None
            k_pass = None

        # 5. Rotate in float32 to avoid fp16 precision loss.
        q_rot_float = q_rot.float()
        k_rot_float = k_rot.float()
        cos_float = cos.float()
        sin_float = sin.float()

        q_embed = (q_rot_float * cos_float) + (self.rotate_half(q_rot_float) * sin_float)
        k_embed = (k_rot_float * cos_float) + (self.rotate_half(k_rot_float) * sin_float)

        # 6. Cast back and re-attach any pass-through channels.
        q_embed = q_embed.type_as(q)
        k_embed = k_embed.type_as(k)

        if q_pass is not None:
            q_embed = torch.cat([q_embed, q_pass], dim=-1)
            k_embed = torch.cat([k_embed, k_pass], dim=-1)

        return q_embed, k_embed

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convenience wrapper around apply_rotary_pos_emb."""
        return self.apply_rotary_pos_emb(q, k, position_ids)

    def extra_repr(self) -> str:
        return (f"dim={self.dim}, rope_dim={self.rope_dim}, "
                f"max_seq_len={self.max_seq_len}, original_max_len={self.original_max_len}, "
                f"base={self.base}")
208
class RMSNorm(nn.Module):
    """
    Root-mean-square layer normalization (no mean centering).

    Statistics are computed in float32 for numerical stability, then cast
    back to the input dtype before the optional learnable gain is applied.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        elementwise_affine: bool = True
    ):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if elementwise_affine:
            # Learnable per-channel gain, initialized to identity.
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter('weight', None)

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # x / sqrt(mean(x^2) + eps), expressed via rsqrt.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32, then restore the caller's dtype.
        normed = self._norm(x.float()).type_as(x)

        # Apply the gain only when affine parameters exist.
        if self.elementwise_affine and self.weight is not None:
            normed = normed * self.weight

        return normed
236
class QKNorm(nn.Module):
    """
    Query-key normalization: independent RMSNorm applied to q and k.

    Stabilizes attention logits (cf. ViT-22B style QK-norm).
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.query_norm = RMSNorm(dim, eps=eps)
        self.key_norm = RMSNorm(dim, eps=eps)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Each projection gets its own normalizer.
        return self.query_norm(q), self.key_norm(k)
251
class SwiGLU(nn.Module):
    """
    SwiGLU feed-forward block: down(SiLU(gate(x)) * up(x)).

    w1 = gate, w3 = up, w2 = down (standard LLaMA naming). The hidden size
    is rounded up to a multiple of `multiple_of` for hardware-friendly shapes.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: Optional[int] = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        dropout: float = 0.0,
        bias: bool = False
    ):
        super().__init__()

        if hidden_dim is None:
            if ffn_dim_multiplier is not None:
                hidden_dim = int(dim * ffn_dim_multiplier)
            else:
                # LLaMA default: 2/3 * (4 * dim) = 8/3 * dim.
                hidden_dim = int(2 * dim * 4 / 3)

        # Round up to the nearest multiple of `multiple_of`.
        hidden_dim = ((hidden_dim + multiple_of - 1) // multiple_of) * multiple_of
        self.hidden_dim = hidden_dim

        # Creation order (w1, w2, w3) is kept stable for RNG reproducibility.
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)
        self.w3 = nn.Linear(dim, hidden_dim, bias=bias)
        self.dropout = nn.Identity() if dropout <= 0 else nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU(x) = Down(SiLU(Gate(x)) ⊙ Up(x))
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.dropout(self.w2(gate * up))
285
class ParallelAttentionFFN(nn.Module):
    """
    Parallel attention + feed-forward block (PaLM / GPT-J style).

    y = x + Attention(Norm(x)) + FFN(Norm(x)); both branches read the same
    residual input but use separate RMSNorms for flexibility.
    """

    def __init__(
        self,
        dim: int,
        attn_module: nn.Module,
        ffn_module: nn.Module,
        norm_eps: float = 1e-6
    ):
        super().__init__()
        # Two independent norms (some architectures share one; keeping them
        # separate covers both cases).
        self.attn_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)
        self.attn = attn_module
        self.ffn = ffn_module

    def forward(
        self,
        x: torch.Tensor,
        **attn_kwargs
    ) -> torch.Tensor:
        # Both branches fork from the same residual-stream input x.
        attn_branch = self.attn(self.attn_norm(x), **attn_kwargs)
        # Attention-specific kwargs are deliberately not forwarded to the FFN.
        ffn_branch = self.ffn(self.ffn_norm(x))
        # Single residual add combining both branches.
        return x + attn_branch + ffn_branch