szxllm commited on
Commit
02d752c
·
verified ·
1 Parent(s): 5c5e75b

Update transformer.py

Browse files
Files changed (1) hide show
  1. transformer.py +329 -334
transformer.py CHANGED
@@ -1,335 +1,330 @@
1
- """
2
- 优化的Transformer架构
3
- 支持GQA/MQA、滑动窗口注意力、Flash Attention 2、YARN位置编码
4
- """
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from typing import Optional, Tuple, List
9
- import math
10
- from components import RMSNorm, SwiGLU, YARNRotaryEmbedding, QKNorm
11
- from peft_ import LinearWithLoRA, AdapterLayer
12
- from moe import MixtureOfExperts
13
-
14
- class GroupedQueryAttention(nn.Module):
15
- """分组查询注意力 (GQA) - 优化版 with YARN"""
16
- def __init__(
17
- self,
18
- dim: int,
19
- n_heads: int,
20
- n_kv_heads: Optional[int] = None,
21
- head_dim: Optional[int] = None,
22
- dropout: float = 0.0,
23
- attn_dropout: float = 0.0,
24
- use_flash: bool = True,
25
- qkv_bias: bool = False,
26
- use_lora: bool = False,
27
- lora_rank: int = 8,
28
- max_seq_len: int = 8192,
29
- rope_scaling_factor: float = 1.0,
30
- rope_scaling_type: str = "yarn",
31
- use_qk_norm: bool = False,
32
- sliding_window: Optional[int] = None,
33
- use_alibi: bool = False
34
- ):
35
- super().__init__()
36
-
37
- self.dim = dim
38
- self.n_heads = n_heads
39
- self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
40
-
41
- assert n_heads % self.n_kv_heads == 0, \
42
- f"n_heads ({n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
43
-
44
- self.n_rep = n_heads // self.n_kv_heads
45
- self.head_dim = head_dim if head_dim is not None else dim // n_heads
46
- self.scale = self.head_dim ** -0.5
47
-
48
- self.use_flash = use_flash and hasattr(F, 'scaled_dot_product_attention')
49
- self.sliding_window = sliding_window
50
-
51
- self.q_proj = LinearWithLoRA(
52
- dim, n_heads * self.head_dim,
53
- bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
54
- )
55
- self.k_proj = LinearWithLoRA(
56
- dim, self.n_kv_heads * self.head_dim,
57
- bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
58
- )
59
- self.v_proj = LinearWithLoRA(
60
- dim, self.n_kv_heads * self.head_dim,
61
- bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
62
- )
63
- self.o_proj = LinearWithLoRA(
64
- n_heads * self.head_dim, dim,
65
- bias=False, use_lora=use_lora, lora_rank=lora_rank
66
- )
67
-
68
- self.attn_dropout = nn.Dropout(attn_dropout) if attn_dropout > 0 else nn.Identity()
69
- self.resid_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
70
-
71
- self.use_qk_norm = use_qk_norm
72
- if use_qk_norm:
73
- self.q_norm = QKNorm(self.head_dim)
74
- self.k_norm = QKNorm(self.head_dim)
75
-
76
- self.use_alibi = use_alibi
77
- if use_alibi:
78
- self.register_buffer(
79
- "alibi_slopes",
80
- self._get_alibi_slopes(n_heads),
81
- persistent=False
82
- )
83
- else:
84
- self.rotary_emb = YARNRotaryEmbedding(
85
- self.head_dim,
86
- max_seq_len=max_seq_len,
87
- original_max_len=4096,
88
- scaling_factor=rope_scaling_factor,
89
- rope_percentage=1.0
90
- )
91
-
92
- def _get_alibi_slopes(self, n_heads: int) -> torch.Tensor:
93
- """计算ALiBi斜率"""
94
- def get_slopes_power_of_2(n):
95
- start = 2 ** (-(2 ** -(math.log2(n) - 3)))
96
- ratio = start
97
- return [start * ratio ** i for i in range(n)]
98
-
99
- if math.log2(n_heads).is_integer():
100
- slopes = get_slopes_power_of_2(n_heads)
101
- else:
102
- closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
103
- slopes = get_slopes_power_of_2(closest_power_of_2)
104
- extra_slopes = get_slopes_power_of_2(2 * closest_power_of_2)[::2]
105
- slopes.extend(extra_slopes[:n_heads - closest_power_of_2])
106
-
107
- return torch.tensor(slopes).view(n_heads, 1, 1)
108
-
109
- def repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
110
- """重复KV heads以匹配Q heads"""
111
- if self.n_rep == 1:
112
- return x
113
-
114
- B, n_kv_heads, seq_len, head_dim = x.shape
115
- return x[:, :, None, :, :].expand(
116
- B, n_kv_heads, self.n_rep, seq_len, head_dim
117
- ).reshape(B, n_kv_heads * self.n_rep, seq_len, head_dim)
118
-
119
- def _apply_sliding_window_mask(
120
- self,
121
- attn_scores: torch.Tensor,
122
- seq_len: int
123
- ) -> torch.Tensor:
124
- """应用滑动窗口mask"""
125
- if self.sliding_window is None or seq_len <= self.sliding_window:
126
- return attn_scores
127
-
128
- mask = torch.ones(seq_len, seq_len, device=attn_scores.device, dtype=torch.bool)
129
- mask = torch.triu(mask, diagonal=-self.sliding_window + 1)
130
- mask = torch.tril(mask, diagonal=0)
131
-
132
- attn_scores = attn_scores.masked_fill(~mask, float('-inf'))
133
- return attn_scores
134
-
135
- def forward(
136
- self,
137
- x: torch.Tensor,
138
- attention_mask: Optional[torch.Tensor] = None,
139
- position_ids: Optional[torch.Tensor] = None,
140
- use_cache: bool = False,
141
- past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
142
- output_attentions: bool = False
143
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[torch.Tensor]]:
144
- """前向传播"""
145
- B, T, C = x.shape
146
-
147
- q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
148
- k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
149
- v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
150
-
151
- if self.use_qk_norm:
152
- q_shape = q.shape
153
- k_shape = k.shape
154
- q = self.q_norm.query_norm(q.view(-1, self.head_dim)).view(q_shape)
155
- k = self.k_norm.key_norm(k.view(-1, self.head_dim)).view(k_shape)
156
-
157
- if not self.use_alibi:
158
- q, k = self.rotary_emb(q, k, position_ids)
159
-
160
- if past_kv is not None:
161
- past_k, past_v = past_kv
162
- k = torch.cat([past_k, k], dim=2)
163
- v = torch.cat([past_v, v], dim=2)
164
-
165
- present_kv = (k, v) if use_cache else None
166
-
167
- k = self.repeat_kv(k)
168
- v = self.repeat_kv(v)
169
-
170
- seq_len_k = k.size(2)
171
-
172
- if self.use_flash and not output_attentions and attention_mask is None:
173
- dropout_p = self.attn_dropout.p if isinstance(self.attn_dropout, nn.Dropout) and self.training else 0.0
174
- attn_output = F.scaled_dot_product_attention(
175
- q, k, v,
176
- attn_mask=attention_mask,
177
- dropout_p=dropout_p,
178
- is_causal=True if attention_mask is None else False
179
- )
180
- attention_weights = None
181
- else:
182
- attn_scores = (q @ k.transpose(-2, -1)) * self.scale
183
-
184
- if self.use_alibi:
185
- position_bias = self.alibi_slopes.to(x.device) * torch.arange(
186
- seq_len_k, device=x.device
187
- ).view(1, 1, -1)
188
- attn_scores = attn_scores + position_bias
189
-
190
- if self.sliding_window is not None:
191
- attn_scores = self._apply_sliding_window_mask(attn_scores, seq_len_k)
192
-
193
- if attention_mask is not None:
194
- if attention_mask.dim() == 2:
195
- attention_mask = attention_mask[:, None, None, :]
196
- if attention_mask.dtype != torch.float:
197
- # 假设传入的是 1(Keep)/0(Mask)
198
- extended_mask = (1.0 - attention_mask) * torch.finfo(attn_scores.dtype).min
199
- else:
200
- # 假设传入的已经是加性 mask (0/-inf)
201
- extended_mask = attention_mask
202
-
203
- attn_scores = attn_scores + extended_mask
204
-
205
- is_causal = seq_len_k > 1
206
- if is_causal:
207
- causal_mask = torch.triu(
208
- torch.ones(seq_len_k, seq_len_k, device=x.device, dtype=torch.bool),
209
- diagonal=1
210
- )
211
- causal_mask = causal_mask[-q.shape[2]:, :]#还没懂
212
- attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))
213
-
214
- attention_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
215
- attention_weights = self.attn_dropout(attention_weights)
216
-
217
- attn_output = attention_weights @ v
218
-
219
- attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, -1)
220
- output = self.resid_dropout(self.o_proj(attn_output))
221
-
222
- return output, present_kv, attention_weights if output_attentions else None
223
-
224
- class OptimizedTransformerBlock(nn.Module):
225
- """优化的Transformer块"""
226
- def __init__(
227
- self,
228
- dim: int,
229
- n_heads: int,
230
- n_kv_heads: Optional[int] = None,
231
- head_dim: Optional[int] = None,
232
- dropout: float = 0.0,
233
- attn_dropout: float = 0.0,
234
- use_moe: bool = False,
235
- num_experts: int = 8,
236
- moe_top_k: int = 2,
237
- use_adapter: bool = False,
238
- adapter_dim: int = 64,
239
- use_lora: bool = False,
240
- lora_rank: int = 8,
241
- use_parallel_residual: bool = False,
242
- norm_eps: float = 1e-6,
243
- sliding_window: Optional[int] = None,
244
- ffn_dim_multiplier: Optional[float] = None,
245
- layer_idx: int = 0
246
- ):
247
- super().__init__()
248
- self.layer_idx = layer_idx
249
- self.use_moe = use_moe
250
- self.use_adapter = use_adapter
251
- self.use_parallel_residual = use_parallel_residual
252
-
253
- self.attention = GroupedQueryAttention(
254
- dim=dim,
255
- n_heads=n_heads,
256
- n_kv_heads=n_kv_heads,
257
- head_dim=head_dim,
258
- dropout=dropout,
259
- attn_dropout=attn_dropout,
260
- use_lora=use_lora,
261
- lora_rank=lora_rank,
262
- sliding_window=sliding_window,
263
- rope_scaling_type="yarn"
264
- )
265
-
266
- if use_moe:
267
- self.ffn = MixtureOfExperts(
268
- dim=dim,
269
- num_experts=num_experts,
270
- top_k=moe_top_k,
271
- dropout=dropout,
272
- ffn_dim_multiplier=ffn_dim_multiplier
273
- )
274
- else:
275
- self.ffn = SwiGLU(
276
- dim=dim,
277
- dropout=dropout,
278
- ffn_dim_multiplier=ffn_dim_multiplier
279
- )
280
-
281
- if use_adapter:
282
- self.adapter = AdapterLayer(dim, adapter_dim, dropout)
283
-
284
- self.attention_norm = RMSNorm(dim, eps=norm_eps)
285
- self.ffn_norm = RMSNorm(dim, eps=norm_eps)
286
-
287
- self.moe_aux_loss = torch.tensor(0.0)
288
-
289
- def forward(
290
- self,
291
- x: torch.Tensor,
292
- attention_mask: Optional[torch.Tensor] = None,
293
- position_ids: Optional[torch.Tensor] = None,
294
- use_cache: bool = False,
295
- past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
296
- output_attentions: bool = False
297
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[torch.Tensor]]:
298
- """前向传播"""
299
-
300
- attn_out, present_kv, attn_weights = self.attention(
301
- self.attention_norm(x),
302
- attention_mask=attention_mask,
303
- position_ids=position_ids,
304
- use_cache=use_cache,
305
- past_kv=past_kv,
306
- output_attentions=output_attentions
307
- )
308
-
309
- if self.use_parallel_residual:
310
- ffn_input = self.ffn_norm(x)
311
-
312
- if self.use_moe:
313
- ffn_out, aux_loss = self.ffn(ffn_input)
314
- self.moe_aux_loss = aux_loss
315
- else:
316
- ffn_out = self.ffn(ffn_input)
317
- self.moe_aux_loss = torch.tensor(0.0, device=x.device)
318
-
319
- x = x + attn_out + ffn_out
320
- else:
321
- x = x + attn_out
322
-
323
- if self.use_adapter:
324
- x = self.adapter(x)
325
-
326
- ffn_input = self.ffn_norm(x)
327
- if self.use_moe:
328
- ffn_out, aux_loss = self.ffn(ffn_input)
329
- x = x + ffn_out
330
- self.moe_aux_loss = aux_loss
331
- else:
332
- x = x + self.ffn(ffn_input)
333
- self.moe_aux_loss = torch.tensor(0.0, device=x.device)
334
-
335
  return x, present_kv, attn_weights
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Optional, Tuple, List
5
+ import math
6
+ from components import RMSNorm, SwiGLU, YARNRotaryEmbedding, QKNorm
7
+ from peft_ import LinearWithLoRA, AdapterLayer
8
+ from moe import MixtureOfExperts
9
+
10
+ class GroupedQueryAttention(nn.Module):
11
+ def __init__(
12
+ self,
13
+ dim: int,
14
+ n_heads: int,
15
+ n_kv_heads: Optional[int] = None,
16
+ head_dim: Optional[int] = None,
17
+ dropout: float = 0.0,
18
+ attn_dropout: float = 0.0,
19
+ use_flash: bool = True,
20
+ qkv_bias: bool = False,
21
+ use_lora: bool = False,
22
+ lora_rank: int = 8,
23
+ max_seq_len: int = 8192,
24
+ rope_scaling_factor: float = 1.0,
25
+ rope_scaling_type: str = "yarn",
26
+ use_qk_norm: bool = False,
27
+ sliding_window: Optional[int] = None,
28
+ use_alibi: bool = False
29
+ ):
30
+ super().__init__()
31
+
32
+ self.dim = dim
33
+ self.n_heads = n_heads
34
+ self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
35
+
36
+ assert n_heads % self.n_kv_heads == 0, \
37
+ f"n_heads ({n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
38
+
39
+ self.n_rep = n_heads // self.n_kv_heads
40
+ self.head_dim = head_dim if head_dim is not None else dim // n_heads
41
+ self.scale = self.head_dim ** -0.5
42
+
43
+ self.use_flash = use_flash and hasattr(F, 'scaled_dot_product_attention')
44
+ self.sliding_window = sliding_window
45
+
46
+ self.q_proj = LinearWithLoRA(
47
+ dim, n_heads * self.head_dim,
48
+ bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
49
+ )
50
+ self.k_proj = LinearWithLoRA(
51
+ dim, self.n_kv_heads * self.head_dim,
52
+ bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
53
+ )
54
+ self.v_proj = LinearWithLoRA(
55
+ dim, self.n_kv_heads * self.head_dim,
56
+ bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
57
+ )
58
+ self.o_proj = LinearWithLoRA(
59
+ n_heads * self.head_dim, dim,
60
+ bias=False, use_lora=use_lora, lora_rank=lora_rank
61
+ )
62
+
63
+ self.attn_dropout = nn.Dropout(attn_dropout) if attn_dropout > 0 else nn.Identity()
64
+ self.resid_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
65
+
66
+ self.use_qk_norm = use_qk_norm
67
+ if use_qk_norm:
68
+ self.q_norm = QKNorm(self.head_dim)
69
+ self.k_norm = QKNorm(self.head_dim)
70
+
71
+ self.use_alibi = use_alibi
72
+ if use_alibi:
73
+ self.register_buffer(
74
+ "alibi_slopes",
75
+ self._get_alibi_slopes(n_heads),
76
+ persistent=False
77
+ )
78
+ else:
79
+ self.rotary_emb = YARNRotaryEmbedding(
80
+ self.head_dim,
81
+ max_seq_len=max_seq_len,
82
+ original_max_len=4096,
83
+ scaling_factor=rope_scaling_factor,
84
+ rope_percentage=1.0
85
+ )
86
+
87
+ def _get_alibi_slopes(self, n_heads: int) -> torch.Tensor:
88
+ """计算ALiBi斜率"""
89
+ def get_slopes_power_of_2(n):
90
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
91
+ ratio = start
92
+ return [start * ratio ** i for i in range(n)]
93
+
94
+ if math.log2(n_heads).is_integer():
95
+ slopes = get_slopes_power_of_2(n_heads)
96
+ else:
97
+ closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
98
+ slopes = get_slopes_power_of_2(closest_power_of_2)
99
+ extra_slopes = get_slopes_power_of_2(2 * closest_power_of_2)[::2]
100
+ slopes.extend(extra_slopes[:n_heads - closest_power_of_2])
101
+
102
+ return torch.tensor(slopes).view(n_heads, 1, 1)
103
+
104
+ def repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
105
+ """重复KV heads以匹配Q heads"""
106
+ if self.n_rep == 1:
107
+ return x
108
+
109
+ B, n_kv_heads, seq_len, head_dim = x.shape
110
+ return x[:, :, None, :, :].expand(
111
+ B, n_kv_heads, self.n_rep, seq_len, head_dim
112
+ ).reshape(B, n_kv_heads * self.n_rep, seq_len, head_dim)
113
+
114
+ def _apply_sliding_window_mask(
115
+ self,
116
+ attn_scores: torch.Tensor,
117
+ seq_len: int
118
+ ) -> torch.Tensor:
119
+ """应用滑动窗口mask"""
120
+ if self.sliding_window is None or seq_len <= self.sliding_window:
121
+ return attn_scores
122
+
123
+ mask = torch.ones(seq_len, seq_len, device=attn_scores.device, dtype=torch.bool)
124
+ mask = torch.triu(mask, diagonal=-self.sliding_window + 1)
125
+ mask = torch.tril(mask, diagonal=0)
126
+
127
+ attn_scores = attn_scores.masked_fill(~mask, float('-inf'))
128
+ return attn_scores
129
+
130
+ def forward(
131
+ self,
132
+ x: torch.Tensor,
133
+ attention_mask: Optional[torch.Tensor] = None,
134
+ position_ids: Optional[torch.Tensor] = None,
135
+ use_cache: bool = False,
136
+ past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
137
+ output_attentions: bool = False
138
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[torch.Tensor]]:
139
+ """前向传播"""
140
+ B, T, C = x.shape
141
+
142
+ q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
143
+ k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
144
+ v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
145
+
146
+ if self.use_qk_norm:
147
+ q_shape = q.shape
148
+ k_shape = k.shape
149
+ q = self.q_norm.query_norm(q.view(-1, self.head_dim)).view(q_shape)
150
+ k = self.k_norm.key_norm(k.view(-1, self.head_dim)).view(k_shape)
151
+
152
+ if not self.use_alibi:
153
+ q, k = self.rotary_emb(q, k, position_ids)
154
+
155
+ if past_kv is not None:
156
+ past_k, past_v = past_kv
157
+ k = torch.cat([past_k, k], dim=2)
158
+ v = torch.cat([past_v, v], dim=2)
159
+
160
+ present_kv = (k, v) if use_cache else None
161
+
162
+ k = self.repeat_kv(k)
163
+ v = self.repeat_kv(v)
164
+
165
+ seq_len_k = k.size(2)
166
+
167
+ if self.use_flash and not output_attentions and attention_mask is None:
168
+ dropout_p = self.attn_dropout.p if isinstance(self.attn_dropout, nn.Dropout) and self.training else 0.0
169
+ attn_output = F.scaled_dot_product_attention(
170
+ q, k, v,
171
+ attn_mask=attention_mask,
172
+ dropout_p=dropout_p,
173
+ is_causal=True if attention_mask is None else False
174
+ )
175
+ attention_weights = None
176
+ else:
177
+ attn_scores = (q @ k.transpose(-2, -1)) * self.scale
178
+
179
+ if self.use_alibi:
180
+ position_bias = self.alibi_slopes.to(x.device) * torch.arange(
181
+ seq_len_k, device=x.device
182
+ ).view(1, 1, -1)
183
+ attn_scores = attn_scores + position_bias
184
+
185
+ if self.sliding_window is not None:
186
+ attn_scores = self._apply_sliding_window_mask(attn_scores, seq_len_k)
187
+
188
+ if attention_mask is not None:
189
+ if attention_mask.dim() == 2:
190
+ attention_mask = attention_mask[:, None, None, :]
191
+ if attention_mask.dtype != torch.float:
192
+ # 假设传入的是 1(Keep)/0(Mask)
193
+ extended_mask = (1.0 - attention_mask) * torch.finfo(attn_scores.dtype).min
194
+ else:
195
+ # 假设传入的已经是加性 mask (0/-inf)
196
+ extended_mask = attention_mask
197
+
198
+ attn_scores = attn_scores + extended_mask
199
+
200
+ is_causal = seq_len_k > 1
201
+ if is_causal:
202
+ causal_mask = torch.triu(
203
+ torch.ones(seq_len_k, seq_len_k, device=x.device, dtype=torch.bool),
204
+ diagonal=1
205
+ )
206
+ causal_mask = causal_mask[-q.shape[2]:, :]#还没懂
207
+ attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))
208
+
209
+ attention_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
210
+ attention_weights = self.attn_dropout(attention_weights)
211
+
212
+ attn_output = attention_weights @ v
213
+
214
+ attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, -1)
215
+ output = self.resid_dropout(self.o_proj(attn_output))
216
+
217
+ return output, present_kv, attention_weights if output_attentions else None
218
+
219
+ class OptimizedTransformerBlock(nn.Module):
220
+ """优化的Transformer块"""
221
+ def __init__(
222
+ self,
223
+ dim: int,
224
+ n_heads: int,
225
+ n_kv_heads: Optional[int] = None,
226
+ head_dim: Optional[int] = None,
227
+ dropout: float = 0.0,
228
+ attn_dropout: float = 0.0,
229
+ use_moe: bool = False,
230
+ num_experts: int = 8,
231
+ moe_top_k: int = 2,
232
+ use_adapter: bool = False,
233
+ adapter_dim: int = 64,
234
+ use_lora: bool = False,
235
+ lora_rank: int = 8,
236
+ use_parallel_residual: bool = False,
237
+ norm_eps: float = 1e-6,
238
+ sliding_window: Optional[int] = None,
239
+ ffn_dim_multiplier: Optional[float] = None,
240
+ layer_idx: int = 0
241
+ ):
242
+ super().__init__()
243
+ self.layer_idx = layer_idx
244
+ self.use_moe = use_moe
245
+ self.use_adapter = use_adapter
246
+ self.use_parallel_residual = use_parallel_residual
247
+
248
+ self.attention = GroupedQueryAttention(
249
+ dim=dim,
250
+ n_heads=n_heads,
251
+ n_kv_heads=n_kv_heads,
252
+ head_dim=head_dim,
253
+ dropout=dropout,
254
+ attn_dropout=attn_dropout,
255
+ use_lora=use_lora,
256
+ lora_rank=lora_rank,
257
+ sliding_window=sliding_window,
258
+ rope_scaling_type="yarn"
259
+ )
260
+
261
+ if use_moe:
262
+ self.ffn = MixtureOfExperts(
263
+ dim=dim,
264
+ num_experts=num_experts,
265
+ top_k=moe_top_k,
266
+ dropout=dropout,
267
+ ffn_dim_multiplier=ffn_dim_multiplier
268
+ )
269
+ else:
270
+ self.ffn = SwiGLU(
271
+ dim=dim,
272
+ dropout=dropout,
273
+ ffn_dim_multiplier=ffn_dim_multiplier
274
+ )
275
+
276
+ if use_adapter:
277
+ self.adapter = AdapterLayer(dim, adapter_dim, dropout)
278
+
279
+ self.attention_norm = RMSNorm(dim, eps=norm_eps)
280
+ self.ffn_norm = RMSNorm(dim, eps=norm_eps)
281
+
282
+ self.moe_aux_loss = torch.tensor(0.0)
283
+
284
+ def forward(
285
+ self,
286
+ x: torch.Tensor,
287
+ attention_mask: Optional[torch.Tensor] = None,
288
+ position_ids: Optional[torch.Tensor] = None,
289
+ use_cache: bool = False,
290
+ past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
291
+ output_attentions: bool = False
292
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[torch.Tensor]]:
293
+ """前向传播"""
294
+
295
+ attn_out, present_kv, attn_weights = self.attention(
296
+ self.attention_norm(x),
297
+ attention_mask=attention_mask,
298
+ position_ids=position_ids,
299
+ use_cache=use_cache,
300
+ past_kv=past_kv,
301
+ output_attentions=output_attentions
302
+ )
303
+
304
+ if self.use_parallel_residual:
305
+ ffn_input = self.ffn_norm(x)
306
+
307
+ if self.use_moe:
308
+ ffn_out, aux_loss = self.ffn(ffn_input)
309
+ self.moe_aux_loss = aux_loss
310
+ else:
311
+ ffn_out = self.ffn(ffn_input)
312
+ self.moe_aux_loss = torch.tensor(0.0, device=x.device)
313
+
314
+ x = x + attn_out + ffn_out
315
+ else:
316
+ x = x + attn_out
317
+
318
+ if self.use_adapter:
319
+ x = self.adapter(x)
320
+
321
+ ffn_input = self.ffn_norm(x)
322
+ if self.use_moe:
323
+ ffn_out, aux_loss = self.ffn(ffn_input)
324
+ x = x + ffn_out
325
+ self.moe_aux_loss = aux_loss
326
+ else:
327
+ x = x + self.ffn(ffn_input)
328
+ self.moe_aux_loss = torch.tensor(0.0, device=x.device)
329
+
 
 
 
 
 
330
  return x, present_kv, attn_weights