mmcarpi committed
Commit 4e1b142 · verified · 1 parent: f73cf40

Upload custom model with source code and tokenizer

common.py ADDED
@@ -0,0 +1,172 @@
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+
+
+ class CastedLinear(nn.Linear):
+     def forward(self, x: torch.FloatTensor):
+         if self.weight.device.type == "meta":
+             return nn.functional.linear(x, self.weight)
+         return nn.functional.linear(x, self.weight.type_as(x))  # cast the weight to the activation dtype
+
+
+ class FeedForward(nn.Module):
+     def __init__(
+         self,
+         embedding_dim: int,
+         hidden_dim: int,
+         device: torch.device,
+         dtype: torch.dtype | None = None,
+     ):
+         factory_kwargs = dict(device=device, dtype=dtype)
+         super().__init__()
+
+         self.fc1 = CastedLinear(embedding_dim, hidden_dim, bias=False, **factory_kwargs)
+         self.fc2 = CastedLinear(embedding_dim, hidden_dim, bias=False, **factory_kwargs)
+         self.fc3 = CastedLinear(hidden_dim, embedding_dim, bias=False, **factory_kwargs)
+
+     def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+         x_fc1 = self.fc1(x)
+         x_fc2 = self.fc2(x)
+
+         x = nn.functional.silu(x_fc1) * x_fc2  # SwiGLU gating
+         x = self.fc3(x)
+         return x
+
+
+ class MoEFeedForward(nn.Module):
+     def __init__(
+         self,
+         embedding_dim: int,
+         hidden_dim: int,
+         num_experts_per_token: int,
+         num_experts: int,
+         device: torch.device,
+         dtype: torch.dtype | None = None,
+     ):
+         assert num_experts > 0, "num_experts should be greater than zero"
+         assert num_experts >= num_experts_per_token > 0, (
+             "num_experts_per_token should be greater than zero and less than or equal to num_experts"
+         )
+         super().__init__()
+         self.num_experts_per_token = num_experts_per_token
+         self.num_experts = num_experts
+         meta_device = torch.device("meta")
+
+         self.gate = CastedLinear(
+             embedding_dim, num_experts, bias=False, device=device, dtype=dtype
+         )
+         self.ff = nn.ModuleList(
+             [
+                 FeedForward(
+                     embedding_dim,
+                     hidden_dim,
+                     device=meta_device,
+                     dtype=dtype,
+                 )
+                 for _ in range(num_experts)
+             ]
+         )
+
+     def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+         scores = self.gate(x)
+         topk_scores, topk_indices = torch.topk(
+             scores, self.num_experts_per_token, dim=-1
+         )
+         topk_probs = torch.softmax(topk_scores, dim=-1)
+
+         expert_outputs = []
+         for i in range(self.num_experts):
+             out = self.ff[i](x)
+             expert_outputs.append(out.unsqueeze(-2))
+         expert_outputs = torch.cat(expert_outputs, dim=-2)
+
+         gating_probs = torch.zeros_like(scores)
+         for i in range(self.num_experts_per_token):
+             indices = topk_indices[..., i : i + 1]
+             prob = topk_probs[..., i : i + 1]
+             gating_probs.scatter_(dim=-1, index=indices, src=prob)
+         gating_probs = gating_probs.unsqueeze(-1)
+         y = (gating_probs * expert_outputs).sum(dim=-2)
+         return y
+
+
+ class RMSNorm(nn.Module):
+     def __init__(
+         self,
+         embedding_dim: int,
+         eps: float = 1e-6,
+         bias: bool = False,
+         device: torch.device | None = None,
+         dtype: torch.dtype | None = None,
+     ):
+         factory_kwargs = dict(device=device, dtype=dtype)
+         super().__init__()
+         self.scale = nn.Parameter(torch.ones(embedding_dim, **factory_kwargs))
+         self.eps = eps
+         self.shift = (
+             nn.Parameter(torch.zeros(embedding_dim, **factory_kwargs)) if bias else None
+         )
+         self.dtype = dtype
+
+     def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+         input_dtype = x.dtype
+
+         variance = x.to(self.dtype).pow(2).mean(dim=-1, keepdim=True)
+         norm_x = x * torch.rsqrt(variance + self.eps)
+         norm_x = norm_x * self.scale
+
+         if self.shift is not None:
+             norm_x = norm_x + self.shift
+
+         return norm_x.to(input_dtype)
+
+
+ def compute_rope_params(
+     head_dim: int,
+     theta_base: int = 10_000,
+     context_length: int = 4096,
+     dtype: Optional[torch.dtype] = torch.float32,
+     device: Optional[torch.device] = None,
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
+     assert head_dim % 2 == 0, "Embedding dim (head_dim) must be even"
+
+     inv_freq = 1.0 / (
+         theta_base
+         ** (
+             torch.arange(0, head_dim, 2, dtype=dtype, device=device)[
+                 : head_dim // 2
+             ].float()
+             / head_dim
+         )
+     )
+
+     positions = torch.arange(context_length, dtype=dtype, device=device)
+     angles = positions[:, None] * inv_freq[None, :]
+     angles = torch.cat([angles, angles], dim=1)
+
+     cos = torch.cos(angles)
+     sin = torch.sin(angles)
+     return cos, sin
+
+
+ def apply_rope(
+     x: torch.FloatTensor,
+     cos: torch.FloatTensor,
+     sin: torch.FloatTensor,
+     offset: int = 0,
+ ) -> torch.FloatTensor:
+     assert x.dim() == 4, "expected tensor of dimension 4 (B, NH, S, H)"
+     _, _, seq_len, head_dim = x.shape
+     assert head_dim % 2 == 0, "head_dim must be even"
+
+     x1 = x[..., : head_dim // 2]  # first half of the head dimension
+     x2 = x[..., head_dim // 2 :]  # second half of the head dimension
+     cos = cos[offset : offset + seq_len, :].unsqueeze(0).unsqueeze(0)
+     sin = sin[offset : offset + seq_len, :].unsqueeze(0).unsqueeze(0)
+     rotated = torch.cat((-x2, x1), dim=-1)
+     x_rotated = (x * cos) + (rotated * sin)
+     x_rotated = x_rotated.type_as(x)
+
+     return x_rotated
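
A minimal sketch (not part of the commit) of how the two RoPE helpers above compose, assuming common.py is importable from the working directory; the shapes follow the (B, NH, S, H) layout that apply_rope asserts:

import torch
from common import compute_rope_params, apply_rope

# precompute the cos/sin tables once for the full context length
cos, sin = compute_rope_params(head_dim=128, theta_base=10_000, context_length=4096)

# dummy query tensor: batch=2, num_heads=8, seq_len=16, head_dim=128
q = torch.randn(2, 8, 16, 128)

# offset=0 for a fresh prompt; during cached decoding it would be the number
# of tokens already held in the KV cache
q_rot = apply_rope(q, cos, sin, offset=0)
print(q_rot.shape)  # torch.Size([2, 8, 16, 128])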
config.json ADDED
@@ -0,0 +1,43 @@
+ {
+   "architectures": [
+     "FlexQwenForCausalLM"
+   ],
+   "attention_bias": false,
+   "attention_dropout": 0.0,
+   "auto_map": {
+     "AutoModel": "qwen.FlexQwen",
+     "AutoModelForCausalLM": "qwen.FlexQwenForCausalLM",
+     "AutoModelForSequenceClassification": "qwen.FlexQwenForSequenceClassification"
+   },
+   "cls_token_id": 1,
+   "context_length": 4096,
+   "embedding_dim": 1024,
+   "head_dim": 128,
+   "hidden_act": "silu",
+   "hidden_dim": 2048,
+   "hidden_size": 4096,
+   "initializer_range": 0.02,
+   "intermediate_size": 22016,
+   "max_position_embeddings": 32768,
+   "max_window_layers": 28,
+   "model_type": "qwen3",
+   "moe_hidden_dim": 512,
+   "moe_num_experts": 0,
+   "moe_num_experts_per_token": -1,
+   "num_attention_heads": 8,
+   "num_hidden_layers": 32,
+   "num_key_value_heads": 32,
+   "num_kv_groups": 8,
+   "pad_token_id": 3,
+   "qk_norm": true,
+   "rms_norm_eps": 1e-06,
+   "rope_scaling": null,
+   "rope_theta": 10000,
+   "sliding_window": 4096,
+   "tie_word_embeddings": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.51.3",
+   "use_cache": true,
+   "use_sliding_window": false,
+   "vocab_size": 64000
+ }
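
The auto_map block above points the Auto classes at qwen.py, so the repository is meant to be loaded with remote code enabled. A hedged sketch; the repo id below is a placeholder, since the actual id is not stated in this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-username/your-repo"  # placeholder repo id, not from this commit

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,  # lets auto_map resolve qwen.FlexQwenForCausalLM
)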
generation_config.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "_from_model_config": true,
+   "pad_token_id": 3,
+   "transformers_version": "4.51.3"
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e340dba542c92d7d93cbfd27702a8e3d188af47e21cfe39873ea91228061e223
+ size 1866802096
qwen.py ADDED
@@ -0,0 +1,600 @@
+ from dataclasses import dataclass
+ from typing import Optional
+ import torch
+ import torch.nn as nn
+
+ from transformers import PreTrainedModel, Qwen3Config, GenerationMixin
+ from transformers.utils import ModelOutput
+ from transformers.modeling_outputs import (
+     SequenceClassifierOutput,
+     CausalLMOutputWithPast,
+ )
+
+ from .common import (
+     FeedForward,
+     MoEFeedForward,
+     RMSNorm,
+     compute_rope_params,
+     apply_rope,
+     CastedLinear,
+ )
+
+
+ class FlexQwenConfig(Qwen3Config):
+     def __init__(
+         self,
+         vocab_size: int = 64000,
+         embedding_dim: int = 1024,
+         hidden_dim: int = 2048,
+         num_attention_heads: int = 8,
+         num_kv_groups: int = 8,
+         head_dim: int = 128,
+         qk_norm: bool = True,
+         moe_num_experts: int = 0,
+         moe_num_experts_per_token: int = -1,
+         moe_hidden_dim: int = 512,
+         num_hidden_layers: int = 32,
+         context_length: int = 1024,
+         rms_norm_eps: float = 1e-6,
+         rope_theta: int = 10000,
+         initializer_range: float = 0.02,
+         cls_token_id: int = 1,
+         pad_token_id: int = 3,
+         tie_word_embeddings: bool = False,
+         **kwargs,
+     ):
+         super().__init__(
+             cls_token_id=cls_token_id,
+             pad_token_id=pad_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+
+         # Vocab & Embeddings
+         self.vocab_size = vocab_size
+         self.embedding_dim = embedding_dim
+         self.hidden_dim = hidden_dim
+
+         # Attention Mechanism
+         self.num_attention_heads = num_attention_heads
+         self.num_kv_groups = num_kv_groups
+         self.head_dim = head_dim
+         self.qk_norm = qk_norm
+
+         # Feed-Forward & MoE
+         self.moe_num_experts = moe_num_experts
+         self.moe_num_experts_per_token = moe_num_experts_per_token
+         self.moe_hidden_dim = moe_hidden_dim
+
+         # General Architecture
+         self.num_hidden_layers = num_hidden_layers
+         self.context_length = context_length
+         self.rms_norm_eps = rms_norm_eps
+         self.rope_theta = rope_theta
+
+         # Initialization
+         self.initializer_range = initializer_range
+
+         # Standard HF Config params
+         self.tie_word_embeddings = tie_word_embeddings
+
+
+ class FlexQwenPreTrainedModel(PreTrainedModel):
+     config_class = FlexQwenConfig
+     _supports_cache_class = True
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Embedding):
+             module.weight.data.uniform_(
+                 -self.config.initializer_range, self.config.initializer_range
+             )
+         # elif isinstance(module, CastedLinear):
+         #     module.weight.data.uniform_()
+
+
+ class GroupedQueryAttention(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         num_heads: int,
+         num_kv_groups: int,
+         head_dim: int | None = None,
+         qk_norm: bool = False,
+         rms_norm_eps: float = 1e-6,
+         device: torch.device | None = None,
+         dtype: torch.dtype | None = None,
+     ):
+         assert num_heads % num_kv_groups == 0, (
+             "num_heads must be divisible by num_kv_groups"
+         )
+         factory_kwargs = dict(device=device, dtype=dtype)
+         super().__init__()
+
+         self.num_heads = num_heads
+         self.num_kv_groups = num_kv_groups
+         self.group_size = num_heads // num_kv_groups
+
+         if head_dim is None:
+             assert in_features % num_heads == 0, (
+                 "input_dim must be divisible by num_heads"
+             )
+             head_dim = in_features // num_heads
+
+         self.head_dim = head_dim
+         self.out_features = num_heads * head_dim
+
+         self.wq = CastedLinear(
+             in_features, self.out_features, bias=False, **factory_kwargs
+         )
+         self.wkv = CastedLinear(
+             in_features, 2 * num_kv_groups * head_dim, bias=False, **factory_kwargs
+         )
+
+         self.out_proj = CastedLinear(
+             self.out_features, in_features, bias=False, **factory_kwargs
+         )
+
+         if qk_norm:
+             self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps, **factory_kwargs)
+             self.k_norm = RMSNorm(head_dim, eps=rms_norm_eps, **factory_kwargs)
+         else:
+             self.q_norm = self.k_norm = None
+
+     def forward(
+         self,
+         x: torch.FloatTensor,
+         cos: torch.FloatTensor,
+         sin: torch.FloatTensor,
+         attention_mask: Optional[torch.BoolTensor] = None,
+         past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> tuple[torch.FloatTensor, tuple[torch.Tensor, torch.Tensor]]:
+         batch_size, num_tokens, _ = x.shape
+
+         query = self.wq(x)
+         key, value = self.wkv(x).chunk(2, dim=-1)
+
+         query = query.view(
+             batch_size, num_tokens, self.num_heads, self.head_dim
+         ).transpose(1, 2)
+
+         key = key.view(
+             batch_size, num_tokens, self.num_kv_groups, self.head_dim
+         ).transpose(1, 2)
+
+         value = value.view(
+             batch_size, num_tokens, self.num_kv_groups, self.head_dim
+         ).transpose(1, 2)
+
+         if self.q_norm:
+             query = self.q_norm(query)
+         if self.k_norm:
+             key = self.k_norm(key)
+
+         offset = 0
+         if cache_position is None:
+             kv_seq_len = key.shape[-2]
+             if past_key_value is not None:
+                 kv_seq_len += past_key_value[0].shape[2]
+             offset = kv_seq_len - num_tokens  # RoPE offset = tokens already in the cache
+         else:
+             offset = cache_position[0].item()
+
+         query = apply_rope(query, cos, sin, offset=offset)
+         key = apply_rope(key, cos, sin, offset=offset)
+
+         if past_key_value is not None:
+             past_key, past_value = past_key_value
+             key = torch.cat([past_key, key], dim=-2)
+             value = torch.cat([past_value, value], dim=-2)
+
+         present_key_value = (key, value)
+
+         attn_output = nn.functional.scaled_dot_product_attention(
+             query,
+             key,
+             value,
+             attn_mask=attention_mask,
+             dropout_p=0.0,
+             enable_gqa=True,  # broadcast the kv heads across query-head groups
+         )
+         out = self.out_proj(
+             attn_output.transpose(1, 2).reshape(
+                 batch_size, num_tokens, self.out_features
+             )
+         )
+         return out, present_key_value
+
+
+ class Transformer(nn.Module):
+     def __init__(
+         self,
+         embedding_dim: int,
+         hidden_dim: int,
+         num_heads: int,
+         head_dim: int,
+         num_kv_groups: int,
+         qk_norm: bool = False,
+         moe_num_experts_per_token: int = 8,
+         moe_num_experts: int = 0,
+         moe_hidden_dim: int = 128,
+         rms_norm_eps: float = 1e-6,
+         device: torch.device | None = None,
+         dtype: torch.dtype | None = None,
+     ):
+         factory_kwargs = dict(device=device, dtype=dtype)
+         super().__init__()
+         self.attn = GroupedQueryAttention(
+             in_features=embedding_dim,
+             num_heads=num_heads,
+             head_dim=head_dim,
+             num_kv_groups=num_kv_groups,
+             qk_norm=qk_norm,
+             **factory_kwargs,
+         )
+
+         if moe_num_experts > 0:
+             self.ff = MoEFeedForward(
+                 embedding_dim=embedding_dim,
+                 hidden_dim=moe_hidden_dim,
+                 num_experts_per_token=moe_num_experts_per_token,
+                 num_experts=moe_num_experts,
+                 **factory_kwargs,
+             )
+         else:
+             self.ff = FeedForward(
+                 embedding_dim, hidden_dim=hidden_dim, **factory_kwargs
+             )
+         self.norm1 = RMSNorm(embedding_dim, eps=rms_norm_eps, **factory_kwargs)
+         self.norm2 = RMSNorm(embedding_dim, eps=rms_norm_eps, **factory_kwargs)
+
+     def forward(
+         self,
+         x: torch.FloatTensor,
+         cos: torch.FloatTensor,
+         sin: torch.FloatTensor,
+         attention_mask: Optional[torch.BoolTensor] = None,
+         past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor]]:
+         residual = x
+         x = self.norm1(x)
+         x, present_key_value = self.attn(
+             x,
+             cos,
+             sin,
+             attention_mask=attention_mask,
+             past_key_value=past_key_value,
+             cache_position=cache_position,
+         )
+         x += residual
+
+         residual = x
+         x = self.norm2(x)
+         x = self.ff(x)
+         x += residual
+
+         return x, present_key_value
+
+
+ @dataclass
+ class FlexQwenOutputWithPast(ModelOutput):
+     last_hidden_state: torch.FloatTensor
+     past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None
+
+
+ class FlexQwen(FlexQwenPreTrainedModel):
+     config_class = FlexQwenConfig
+
+     def __init__(
+         self,
+         config: FlexQwenConfig,
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+     ):
+         super().__init__(config)
+
+         self.embed = nn.Embedding(
+             config.vocab_size,
+             config.embedding_dim,
+             padding_idx=config.pad_token_id,
+             device=device,
+             dtype=dtype,
+         )
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 Transformer(
+                     embedding_dim=config.embedding_dim,
+                     hidden_dim=config.hidden_dim,
+                     num_heads=config.num_attention_heads,
+                     head_dim=config.head_dim,
+                     num_kv_groups=config.num_kv_groups,
+                     qk_norm=config.qk_norm,
+                     moe_num_experts_per_token=config.moe_num_experts_per_token,
+                     moe_num_experts=config.moe_num_experts,
+                     moe_hidden_dim=config.moe_hidden_dim,
+                     rms_norm_eps=config.rms_norm_eps,
+                     device=device,
+                     dtype=dtype,
+                 )
+                 for _ in range(config.num_hidden_layers)
+             ]
+         )
+
+         self.final_norm = RMSNorm(
+             config.embedding_dim, eps=config.rms_norm_eps, device=device, dtype=dtype
+         )
+
+         cos, sin = compute_rope_params(
+             head_dim=config.head_dim,
+             theta_base=config.rope_theta,
+             context_length=config.context_length,
+             dtype=dtype,
+             device=device,
+         )
+
+         self.register_buffer("cos", cos, persistent=False)
+         self.register_buffer("sin", sin, persistent=False)
+         self.config = config
+         self.current_pos = 0
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.BoolTensor] = None,
+         past_key_values: Optional[tuple[torch.FloatTensor, torch.FloatTensor]] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         is_causal: bool = True,
+         return_dict: bool = True,
+     ) -> FlexQwenOutputWithPast:
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("Received both input_ids and inputs_embeds. Pass only one.")
+         if input_ids is None and inputs_embeds is None:
+             raise ValueError("Exactly one of input_ids, inputs_embeds is required.")
+
+         if input_ids is not None:
+             if input_ids.dim() == 1:
+                 input_ids = input_ids.unsqueeze(0)
+             x = self.embed(input_ids)
+         else:
+             x = inputs_embeds
+
+         seq_length = x.shape[1]
+         base_mask = torch.ones(
+             (seq_length, seq_length), dtype=torch.bool, device=x.device
+         )
+
+         if is_causal:
+             base_mask = torch.tril(base_mask)
+         # non-causal: keep the all-True mask; boolean SDPA masks treat True as "may attend"
+
+         if attention_mask is not None:
+             # padded key positions (attention_mask == 0) must be False in the boolean mask
+             padding_mask = attention_mask.to(torch.bool).unsqueeze(1).unsqueeze(2)
+             attention_mask = base_mask.unsqueeze(0).unsqueeze(1) & padding_mask
+         else:
+             attention_mask = base_mask.unsqueeze(0).unsqueeze(1)
+
+         next_kv_cache = [] if use_cache else None
+         for i, block in enumerate(self.transformer_blocks):
+             past_kv_cache_block = (
+                 past_key_values[i]
+                 if past_key_values is not None and len(past_key_values) > 0
+                 else None
+             )
+             x, block_present_kv_cache = block(
+                 x,
+                 self.cos,
+                 self.sin,
+                 attention_mask=attention_mask,
+                 past_key_value=past_kv_cache_block,
+                 cache_position=cache_position,
+             )
+             if use_cache:
+                 next_kv_cache.append(block_present_kv_cache)
+
+         x = self.final_norm(x)
+         output = FlexQwenOutputWithPast(
+             last_hidden_state=x,
+             past_key_values=tuple(next_kv_cache) if use_cache else None,
+         )
+
+         if not return_dict:
+             return output.to_tuple()
+
+         return output
+
+
+ class FlexQwenForCausalLM(FlexQwenPreTrainedModel, GenerationMixin):
+     config_class = FlexQwenConfig
+
+     def __init__(
+         self,
+         config: FlexQwenConfig,
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+         **kwargs,
+     ):
+         super().__init__(config)
+         self.model = FlexQwen(config, device=device, dtype=dtype)
+         self.lm_head = CastedLinear(
+             config.embedding_dim,
+             config.vocab_size,
+             bias=False,
+             device=device,
+             dtype=dtype,
+         )
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         return_dict: bool = True,
+         use_cache: Optional[bool] = None,
+         **kwargs,
+     ) -> CausalLMOutputWithPast:
+         outputs: FlexQwenOutputWithPast = self.model(
+             input_ids=input_ids,
+             is_causal=True,
+             use_cache=use_cache,
+             return_dict=True,
+             **kwargs,
+         )
+
+         logits = self.lm_head(outputs.last_hidden_state).to(torch.float32)
+         loss = None
+         if labels is not None:
+             if labels.dim() == 1:
+                 labels = labels.unsqueeze(0)
+             loss = nn.functional.cross_entropy(
+                 logits.view(-1, logits.size(-1)),
+                 labels.view(-1),
+                 ignore_index=-100,
+                 reduction="sum" if self.training else "mean",
+             )
+
+         output = CausalLMOutputWithPast(
+             logits=logits,
+             loss=loss,
+             past_key_values=outputs.past_key_values if use_cache else None,
+         )
+
+         if not return_dict:
+             return output.to_tuple()
+
+         return output
+
+
+ class FlexQwenForSequenceClassification(FlexQwenPreTrainedModel):
+     config_class = FlexQwenConfig
+
+     def __init__(
+         self,
+         config: FlexQwenConfig,
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+     ):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.model = FlexQwen(config, device=device, dtype=dtype)
+         self.score = CastedLinear(config.embedding_dim, self.num_labels, bias=False)
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.BoolTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs,
+     ) -> SequenceClassifierOutput:
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         outputs: FlexQwenOutputWithPast = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             return_dict=True,
+             **kwargs,
+         )
+
+         sequence_lengths = (
+             torch.eq(attention_mask, 1).int().argmax(-1)
+             if attention_mask is not None
+             else -1
+         )
+
+         hidden_states = outputs.last_hidden_state
+         pooled_states = hidden_states[
+             torch.arange(hidden_states.shape[0], device=hidden_states.device),
+             sequence_lengths,
+         ]
+         logits = self.score(pooled_states)
+
+         loss = None
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
+             loss = loss_fct(
+                 logits.view(-1, self.num_labels),
+                 labels.view(-1),
+             )
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+         )
+
+
+ # def check_grad(is_causal):
+ #     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ #     config = FlexQwenConfig(vocab_size=2**10)
+ #     model = FlexQwenForCausalLM(config=config, device=device)
+ #     x = torch.randn(
+ #         1,
+ #         config.context_length,
+ #         config.embedding_dim,
+ #         requires_grad=True,
+ #         device=device,
+ #     )
+ #     output = model(inputs_embeds=x, attention_mask=None, is_causal=is_causal)
+ #     output = output.logits
+ #     t = config.context_length // 2
+ #     loss = output[:, t, :].sum()
+ #     loss.backward()
+ #     grad_up_to_t = x.grad[:, : t + 1, :]
+ #     has_grad_past = torch.all(grad_up_to_t != 0)
+ #     grad_after_t = x.grad[:, t + 1 :, :]
+ #     has_grad_future = torch.any(grad_after_t != 0)
+
+ #     print(f"{is_causal=} {has_grad_past=} {has_grad_future=}")
+
+
+ # if __name__ == "__main__":
+ #     device = torch.device("cuda:0")
+ #     config = FlexQwenConfig(vocab_size=2**10)
+
+ #     model_lm = FlexQwenForCausalLM(config=config, device=device)
+ #     input_ids = torch.arange(
+ #         start=0,
+ #         end=config.context_length - 1,
+ #         device=device,
+ #     ).unsqueeze(0)
+ #     labels_seq = torch.arange(
+ #         start=1,
+ #         end=config.context_length,
+ #         device=device,
+ #     ).unsqueeze(0)
+
+ #     output_lm: FlexQwenOutputWithPast = model_lm(
+ #         input_ids, labels=labels_seq, is_causal=True
+ #     )
+ #     print(f"LM Logits shape: {output_lm.logits.shape}")
+ #     print(f"LM Loss: {output_lm.loss.item()}")
+
+ #     config.num_labels = 3
+ #     model_seq = FlexQwenForSequenceClassification(config=config, device=device)
+ #     input_ids = torch.randint(0, config.vocab_size, (4, 16), device=device)
+ #     attention_mask = torch.ones_like(input_ids)
+
+ #     attention_mask[2, 10:] = 0
+ #     labels_seq = torch.randint(0, config.num_labels, (4,), device=device)
+ #     output_seq = model_seq(
+ #         input_ids=input_ids, attention_mask=attention_mask, labels=labels_seq
+ #     )
+
+ #     print(f"Seq Logits shape: {output_seq.logits.shape}")
+ #     print(f"Seq Loss: {output_seq.loss.item()}")
+
+ #     peak_memory_allocated = torch.cuda.max_memory_allocated() // 1024 // 1024
+ #     reserved_memory = torch.cuda.max_memory_reserved() // 1024 // 1024
+
+ #     print(f"Peak memory allocated: {peak_memory_allocated} MB")
+ #     print(f"Reserved memory: {reserved_memory} MB")
+ # check_grad(is_causal=True)
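
For a quick local smoke test of the classes above (no checkpoint needed), something along the lines of the commented-out __main__ block works. A trimmed sketch on CPU with a tiny config; it assumes qwen.py and common.py sit in a package named flexqwen (hypothetical name) so the relative import resolves, and a PyTorch recent enough for the enable_gqa flag used in GroupedQueryAttention:

import torch
from flexqwen.qwen import FlexQwenConfig, FlexQwenForCausalLM  # hypothetical package layout

config = FlexQwenConfig(vocab_size=1024, num_hidden_layers=2, context_length=128)
model = FlexQwenForCausalLM(config=config, device=torch.device("cpu"))
model.eval()

# next-token setup mirroring the commented example: predict token t+1 from token t
input_ids = torch.arange(0, 16).unsqueeze(0)
labels = torch.arange(1, 17).unsqueeze(0)

with torch.no_grad():
    out = model(input_ids, labels=labels)

print(out.logits.shape)  # (1, 16, 1024)
print(out.loss.item())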
special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
+ {
+   "cls_token": {
+     "content": "[CLS]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "mask_token": {
+     "content": "[MASK]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "[PAD]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "sep_token": {
+     "content": "[SEP]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "[UNK]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,53 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "4": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": false,
+   "cls_token": "[CLS]",
+   "extra_special_tokens": {},
+   "mask_token": "[MASK]",
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "tokenizer_class": "PreTrainedTokenizer",
+   "unk_token": "[UNK]"
+ }
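
The special-token ids above ([UNK]=0, [CLS]=1, [SEP]=2, [PAD]=3, [MASK]=4) line up with cls_token_id=1 and pad_token_id=3 in config.json. A small sketch of loading the tokenizer from a local copy of these files; the directory path is an assumption:

from transformers import AutoTokenizer

# directory holding tokenizer.json, tokenizer_config.json and special_tokens_map.json
tok = AutoTokenizer.from_pretrained("./")
batch = tok(["first example", "a slightly longer second example"], padding=True, return_tensors="pt")
print(batch["input_ids"].shape, batch["attention_mask"].shape)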