chen-yingfa commited on
Commit
67967d1
·
verified ·
1 Parent(s): 46d7890

Delete gdn.py

Browse files
Files changed (1) hide show
  1. gdn.py +0 -403
gdn.py DELETED
@@ -1,403 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
-
4
- from __future__ import annotations
5
-
6
- import math
7
- from typing import TYPE_CHECKING, Dict, Optional, Tuple
8
-
9
- import torch
10
- from torch import Tensor, nn
11
- from einops import rearrange, repeat
12
- from torch.nn import functional as F
13
-
14
- from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
15
- from fla.modules.l2norm import l2_norm
16
- from fla.ops.gated_delta_rule import (
17
- chunk_gated_delta_rule,
18
- fused_recurrent_gated_delta_rule,
19
- )
20
- from .configuration_hybrid import HybridConfig
21
- from .modeling_qwen3 import Qwen3Attention, apply_rotary_pos_emb
22
-
23
- if TYPE_CHECKING:
24
- from transformers.processing_utils import Unpack
25
-
26
- from fla.models.utils import Cache
27
-
28
-
29
- def elu_p1(x):
30
- return (F.elu(x, 1., False) + 1.).to(x)
31
-
32
-
33
- def sum_norm(x):
34
- return (x / x.sum(-1, keepdim=True)).to(x)
35
-
36
- # https://github.com/IDSIA/recurrent-fwp/blob/master/algorithmic/layers.py#L86C1-L146C1
37
-
38
-
39
- class GatedDeltaNet(nn.Module):
40
- """
41
- The layer implementaion for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa
42
-
43
- Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters.
44
- Parameter alloation when use_gate=True:
45
- - 0.75 * hidden_size * hidden_size for the q_proj and k_proj each
46
- - 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each
47
- - Others are ignorably small.
48
- - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size
49
- NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim.
50
-
51
- Parameter allocation when use_gate=False:
52
- - 1 * hidden_size * hidden_size for the q_proj and k_proj each
53
- - 2 * hidden_size * hidden_size for the v_proj and o_proj each
54
- - Others are ignorably small.
55
- - In total = 1 * 2 + 2 * 2 = 6 * hidden_size * hidden_size
56
-
57
- Args:
58
- hidden_size (int, Optional):
59
- The hidden size of the input. Default: 2048.
60
- expand_v (float, Optional):
61
- The expansion ratio for the value dim. Default: 2.0.
62
- head_dim (int, Optional):
63
- The dimension of each head. Default: 256.
64
- num_heads (int, Optional):
65
- The number of heads. Default: 4.
66
- mode (str, Optional):
67
- Which Gated DeltaNet kernel to use.
68
- Currently available: `chunk` and `fused_recurrent`.
69
- Default: `chunk`.
70
- use_beta (bool, Optional):
71
- Whether to use beta. Default: `True`.
72
- use_gate (bool, Optional):
73
- Whether to use output gate. Default: `True`.
74
- use_short_conv (bool, Optional):
75
- Whether to use short convolutions. Default: `True`.
76
- conv_size (int, Optional):
77
- The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
78
- conv_bias (bool, Optional):
79
- Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
80
- layer_idx (int, Optional):
81
- The index of the layer. Default: None.
82
- norm_eps (float, Optional):
83
- The epsilon value for the normalization layer. Default: 1e-5.
84
- """
85
-
86
- def __init__(
87
- self,
88
- layer_idx: Optional[int] = None,
89
- hidden_size: int = 2048,
90
- expand_v: float = 2,
91
- # head_dim: int = 256,
92
- key_dim: int = 128,
93
- val_dim: int = 128,
94
- num_heads: int = 32,
95
- num_kv_heads: int = 8,
96
- mode: str = 'chunk',
97
- use_gate: bool = True,
98
- use_short_conv: bool = True,
99
- conv_size: int = 4,
100
- conv_bias: bool = False,
101
- norm_eps: float = 1e-5,
102
- activation: Optional[str] = None,
103
- qk_norm: bool = False,
104
- use_rope: bool = False,
105
- **kwargs,
106
- ):
107
- super().__init__()
108
-
109
- self.mode = mode
110
-
111
- self.hidden_size = hidden_size
112
- self.expand_v = expand_v
113
-
114
- self.use_gate = use_gate
115
- self.use_short_conv = use_short_conv
116
- self.conv_size = conv_size
117
- self.conv_bias = conv_bias
118
-
119
- # self.head_dim = head_dim
120
- self.key_dim = key_dim
121
- self.val_dim = val_dim
122
- self.num_heads = num_heads
123
- self.num_kv_heads = num_kv_heads
124
-
125
- self.k_dim = self.num_kv_heads * key_dim
126
- self.v_dim = self.num_kv_heads * val_dim
127
- self.q_dim = self.num_heads * key_dim
128
- self.layer_idx = layer_idx
129
- self.activation = activation
130
- self.qk_norm = qk_norm
131
- self.use_rope = use_rope
132
- self.silu = nn.SiLU()
133
-
134
- assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
135
-
136
- if self.qk_norm:
137
- self.q_norm = RMSNorm(key_dim, eps=norm_eps)
138
- self.k_norm = RMSNorm(key_dim, eps=norm_eps)
139
- self.q_proj = nn.Linear(hidden_size, self.q_dim, bias=False)
140
- self.k_proj = nn.Linear(hidden_size, self.k_dim, bias=False)
141
- self.v_proj = nn.Linear(hidden_size, self.v_dim, bias=False)
142
- self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
143
- self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
144
- A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
145
- A_log = torch.log(A)
146
- self.A_log = nn.Parameter(A_log)
147
- self.A_log._no_weight_decay = True
148
- # self.D = nn.Parameter(torch.ones(self.num_heads))
149
- # self.D._no_weight_decay = True
150
- # hard coded for now
151
- dt_min = 0.001
152
- dt_max = 0.1
153
- dt_init_floor = 1e-4
154
- dt = torch.exp(
155
- torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min))
156
- + math.log(dt_min)
157
- )
158
- dt = torch.clamp(dt, min=dt_init_floor)
159
- # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
160
- inv_dt = dt + torch.log(-torch.expm1(-dt))
161
- self.dt_bias = nn.Parameter(inv_dt)
162
- # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
163
- # name.endswith("bias") in param_grouping.py
164
- self.dt_bias._no_weight_decay = True
165
-
166
- if use_short_conv:
167
- self.conv_size = conv_size
168
- self.q_conv1d = ShortConvolution(
169
- hidden_size=self.key_dim,
170
- kernel_size=conv_size,
171
- activation='silu',
172
- use_fast_conv1d=False,
173
- )
174
- self.k_conv1d = ShortConvolution(
175
- hidden_size=self.key_dim,
176
- kernel_size=conv_size,
177
- activation='silu',
178
- use_fast_conv1d=False,
179
- )
180
- self.v_conv1d = ShortConvolution(
181
- hidden_size=self.v_dim,
182
- kernel_size=conv_size,
183
- activation='silu',
184
- use_fast_conv1d=False,
185
- )
186
- # else:
187
- # raise UserWarning(
188
- # "ShortConvolution is crucial to the performance. "
189
- # "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
190
- # )
191
- if use_gate:
192
- self.g_proj = nn.Linear(hidden_size, self.num_heads * self.val_dim, bias=False)
193
- self.o_norm = FusedRMSNormSwishGate(self.val_dim, eps=norm_eps)
194
- else:
195
- self.o_norm = RMSNorm(self.val_dim, eps=norm_eps)
196
- self.o_proj = nn.Linear(self.num_heads * self.val_dim, hidden_size, bias=False)
197
- self.apply(self._initialize_weights)
198
-
199
- def _initialize_weights(self, module: nn.Module):
200
- if getattr(module, "_is_hf_initialized", False):
201
- return
202
- if isinstance(module, nn.Linear):
203
- nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
204
- if module.bias is not None:
205
- nn.init.zeros_(module.bias)
206
- module._is_hf_initialized = True
207
-
208
- def forward(
209
- self,
210
- hidden_states: torch.Tensor,
211
- attention_mask: Optional[torch.Tensor] = None,
212
- past_key_values: Optional[Cache] = None,
213
- use_cache: Optional[bool] = False,
214
- output_attentions: Optional[bool] = False,
215
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
216
- **kwargs,
217
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
218
- attention_mask = None
219
- if attention_mask is not None:
220
- assert len(attention_mask.shape) == 2, (
221
- "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
222
- "for padding purposes (0 indicating padding). "
223
- "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
224
- )
225
-
226
- mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
227
- if self.training:
228
- assert mode == 'chunk', "Only chunk mode is supported in training."
229
-
230
- last_state = None
231
- if past_key_values is not None and len(past_key_values) > self.layer_idx:
232
- last_state = past_key_values[self.layer_idx]
233
-
234
- if self.use_short_conv:
235
- conv_state_q, conv_state_k, conv_state_v = None, None, None
236
- if last_state is not None:
237
- conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
238
- conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
239
- q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
240
- mask=conv_mask,
241
- cache=conv_state_q,
242
- output_final_state=use_cache)
243
- k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
244
- mask=conv_mask,
245
- cache=conv_state_k,
246
- output_final_state=use_cache)
247
- v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
248
- mask=conv_mask,
249
- cache=conv_state_v,
250
- output_final_state=use_cache)
251
- else:
252
- q = self.q_proj(hidden_states)
253
- k = self.k_proj(hidden_states)
254
- v = self.v_proj(hidden_states)
255
- if self.activation is not None:
256
- q = self.silu(q)
257
- k = self.silu(k)
258
- v = self.silu(v)
259
-
260
- q = rearrange(q, 'b t (h d) -> b t h d', d=self.key_dim)
261
- k = rearrange(k, 'b t (h d) -> b t h d', d=self.key_dim)
262
- v = rearrange(v, 'b t (h d) -> b t h d', d=self.val_dim)
263
-
264
- if self.qk_norm:
265
- q = self.q_norm(q)
266
- k = self.k_norm(k)
267
-
268
- if self.use_rope:
269
- assert position_embeddings is not None
270
- cos, sin = position_embeddings
271
- q, k = q.transpose(1, 2), k.transpose(1, 2)
272
- q, k = apply_rotary_pos_emb(q, k, cos, sin)
273
- q, k = q.transpose(1, 2), k.transpose(1, 2)
274
-
275
- q = l2_norm(q)
276
- k = l2_norm(k)
277
- # Allow negative eigenvalues
278
- beta = self.b_proj(hidden_states).sigmoid() * 2
279
- g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias)
280
-
281
- # Handle grouped-query, maybe we should untie the weights to go back to MHA?
282
- if self.num_kv_heads < self.num_heads:
283
- group_size = self.num_heads // self.num_kv_heads
284
- k = repeat(k, 'b t h d -> b t (h g) d', g=group_size) # (B, T, nh, dh)
285
- v = repeat(v, 'b t h d -> b t (h g) d', g=group_size) # (B, T, nh, dh)
286
-
287
- # dealing with padding
288
- if attention_mask is not None:
289
- beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])
290
- g = g.mul(attention_mask[:, -g.shape[-2]:, None])
291
-
292
- recurrent_state = last_state['recurrent_state'] if last_state is not None else None
293
- # offsets = kwargs.get('offsets', None)
294
- if mode == 'chunk':
295
- o, recurrent_state = chunk_gated_delta_rule(
296
- q=q,
297
- k=k,
298
- v=v,
299
- g=g,
300
- beta=beta,
301
- initial_state=recurrent_state,
302
- output_final_state=use_cache,
303
- # offsets=offsets,
304
- # head_first=False
305
- )
306
- elif mode == 'fused_recurrent':
307
- o, recurrent_state = fused_recurrent_gated_delta_rule(
308
- q=q,
309
- k=k,
310
- v=v,
311
- g=g,
312
- beta=beta,
313
- initial_state=recurrent_state,
314
- output_final_state=use_cache,
315
- # offsets=offsets,
316
- # head_first=False
317
- )
318
- if past_key_values is not None:
319
- past_key_values.update(
320
- recurrent_state=recurrent_state,
321
- conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
322
- layer_idx=self.layer_idx,
323
- offset=q.shape[2]
324
- )
325
-
326
- if self.use_gate:
327
- g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', h=self.num_heads)
328
- o = self.o_norm(o, g)
329
- else:
330
- o = self.o_norm(o)
331
- o = rearrange(o, 'b t h d -> b t (h d)')
332
- o = self.o_proj(o)
333
-
334
- return o, None, past_key_values
335
-
336
-
337
-
338
- def build_gdn_with_attn(
339
- attn_layer: Qwen3Attention,
340
- layer_idx: int,
341
- config: HybridConfig,
342
- ) -> nn.Module:
343
- """
344
- Initialize a Gated DeltaNet block using the parameters of a Qwen3Attention layer.
345
- We instantiate the GDN block such that the QKVO projections have the same shape,
346
- then copy the weights from the Qwen3Attention layer.
347
- """
348
-
349
- gdn_block = GatedDeltaNet(
350
- hidden_size=config.hidden_size,
351
- layer_idx=layer_idx,
352
- expand_v=1.0,
353
- num_heads=config.gdn_nh,
354
- num_kv_heads=config.gdn_nkv,
355
- key_dim=config.head_dim,
356
- val_dim=config.head_dim,
357
- use_short_conv=config.gdn_use_short_conv,
358
- use_gate=config.gdn_use_gate,
359
- norm_eps=config.rms_norm_eps,
360
- activation=config.gdn_activation,
361
- qk_norm=config.gdn_use_qk_norm,
362
- use_rope=config.gdn_use_rope,
363
- )
364
-
365
- q_proj: nn.Linear = attn_layer.q_proj
366
- k_proj: nn.Linear = attn_layer.k_proj
367
- v_proj: nn.Linear = attn_layer.v_proj
368
- o_proj: nn.Linear = attn_layer.o_proj
369
- # Note that the `.weight.shape` for a projection from d1 to d2 is (d2, d1)
370
- wq: Tensor = q_proj.weight # (nh * dh, d)
371
- wk: Tensor = k_proj.weight # (nkv * dh, d)
372
- wv: Tensor = v_proj.weight # (nkv * dh, d)
373
- wo: Tensor = o_proj.weight # (d, nh * dh)
374
-
375
- if config.expand_kv_proj:
376
- wk = wk.reshape(-1, config.head_dim, config.hidden_size)
377
- wv = wv.reshape(-1, config.head_dim, config.hidden_size)
378
- assert wk.shape[1] == wv.shape[1], wk.shape[1] == config.num_key_value_heads
379
-
380
- # Repeat KV projections to convert it to MHA
381
- target_kv_size = config.lightning_nkv * config.lightning_head_dim
382
- orig_kv_size = config.num_key_value_heads * config.head_dim
383
- expand_size = target_kv_size // orig_kv_size
384
- wk = wk.repeat_interleave(expand_size, dim=0)
385
- wv = wv.repeat_interleave(expand_size, dim=0)
386
-
387
- wk = wk.reshape(-1, config.hidden_size)
388
- wv = wv.reshape(-1, config.hidden_size)
389
-
390
- # ==== Create target module ====
391
- gdn_block.q_proj.weight.data.copy_(wq)
392
- gdn_block.k_proj.weight.data.copy_(wk)
393
- gdn_block.v_proj.weight.data.copy_(wv)
394
- gdn_block.o_proj.weight.data.copy_(wo)
395
-
396
- if hasattr(gdn_block, 'q_norm') and hasattr(attn_layer, 'q_norm'):
397
- gdn_block.q_norm.weight.data.copy_(attn_layer.q_norm.weight.data.clone())
398
-
399
- if hasattr(gdn_block, 'k_norm') and hasattr(attn_layer, 'k_norm'):
400
- gdn_block.k_norm.weight.data.copy_(attn_layer.k_norm.weight.data.clone())
401
-
402
-
403
- return gdn_block