mmcarpi committed on
Commit
9993140
·
verified ·
1 Parent(s): 4e9d33c

Add flexqwen.py to root for trust_remote_code

Browse files
Files changed (1) hide show
  1. flexqwen.py +674 -0
flexqwen.py ADDED
@@ -0,0 +1,674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from pathlib import Path
3
+ from typing import Optional
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
8
+ from transformers.cache_utils import Cache, DynamicCache
9
+ from transformers.utils import ModelOutput
10
+ from transformers.modeling_outputs import (
11
+ SequenceClassifierOutput,
12
+ CausalLMOutputWithPast,
13
+ )
14
+
15
+ from .common import (
16
+ FeedForward,
17
+ MoEFeedForward,
18
+ RMSNorm,
19
+ compute_rope_params,
20
+ apply_rope,
21
+ )
22
+
23
+
24
class FlexQwenConfig(PretrainedConfig):
    """Hyper-parameters for the FlexQwen model family.

    Every constructor argument is stored as an attribute of the same name.
    ``cls_token_id``, ``pad_token_id``, ``tie_word_embeddings`` and any extra
    keyword arguments are additionally forwarded to ``PretrainedConfig``.

    Args:
        vocab_size: Number of rows of the token embedding table.
        embedding_dim: Width of the token embeddings and of every block.
        hidden_dim: Hidden size handed to the dense ``FeedForward``.
        num_attention_heads: Number of query heads.
        num_kv_groups: Number of key/value heads shared by the query heads.
        head_dim: Per-head width used by attention and by the RoPE tables.
        qk_norm: Whether attention applies RMSNorm to queries and keys.
        moe_num_experts: Number of experts; ``MoEFeedForward`` replaces
            ``FeedForward`` when this is greater than 0.
        moe_num_experts_per_token: Forwarded to ``MoEFeedForward``.
        moe_hidden_dim: Hidden size handed to ``MoEFeedForward``.
        num_hidden_layers: Number of transformer blocks.
        max_position_embeddings: Forwarded to ``compute_rope_params``.
        rms_norm_eps: Epsilon handed to the RMSNorm layers.
        rope_theta: Passed to ``compute_rope_params`` as ``theta_base``.
        initializer_range: Bound of the uniform weight initialisation.
        cls_token_id: Forwarded to ``PretrainedConfig``.
        pad_token_id: Forwarded to ``PretrainedConfig``; the base model also
            uses it as ``padding_idx`` of the embedding table.
        tie_word_embeddings: Whether the LM head shares the embedding weight.
        dropout_rate: Dropout probability of the classification head.
        **kwargs: Remaining options, forwarded to ``PretrainedConfig``.
    """

    model_type = "flexqwen"

    def __init__(
        self,
        vocab_size: int = 64000,
        embedding_dim: int = 1024,
        hidden_dim: int = 2048,
        num_attention_heads: int = 8,
        num_kv_groups: int = 8,
        head_dim: int = 128,
        qk_norm: bool = True,
        moe_num_experts: int = 0,
        moe_num_experts_per_token: int = -1,
        moe_hidden_dim: int = 512,
        num_hidden_layers: int = 32,
        max_position_embeddings: int = 1024,
        rms_norm_eps: float = 1e-6,
        rope_theta: int = 10000,
        initializer_range: float = 0.02,
        cls_token_id: int = 1,
        pad_token_id: int = 3,
        tie_word_embeddings: bool = True,
        dropout_rate: float = 0.0,
        **kwargs,
    ):
        # The base class owns the special-token ids and the tying flag.
        super().__init__(
            cls_token_id=cls_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

        # --- vocabulary and layer widths ---
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        # --- attention ---
        self.num_attention_heads = num_attention_heads
        self.num_kv_groups = num_kv_groups
        self.head_dim = head_dim
        self.qk_norm = qk_norm

        # --- mixture-of-experts feed-forward ---
        self.moe_num_experts = moe_num_experts
        self.moe_num_experts_per_token = moe_num_experts_per_token
        self.moe_hidden_dim = moe_hidden_dim

        # --- depth, positions and normalisation ---
        self.num_hidden_layers = num_hidden_layers
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta

        # --- weight initialisation ---
        self.initializer_range = initializer_range

        # Re-assigned after the base constructor so the explicit argument wins.
        self.tie_word_embeddings = tie_word_embeddings

        # --- classification head ---
        self.dropout_rate = dropout_rate
86
+
87
+
88
# pyrefly: ignore
class FlexQwenPreTrainedModel(PreTrainedModel):
    """Shared base class wiring FlexQwen models into ``PreTrainedModel``."""

    config_class = FlexQwenConfig
    base_model_prefix = "model"
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialise ``nn.Embedding`` / ``nn.Linear`` parameters.

        Weights are drawn uniformly from
        ``[-initializer_range, initializer_range]``; linear biases, when
        present, are zeroed. Any other module type is left untouched.
        """
        if not isinstance(module, (nn.Embedding, nn.Linear)):
            return

        bound = self.config.initializer_range
        module.weight.data.uniform_(-bound, bound)

        # Only linear layers can carry a bias.
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
105
+
106
+
107
class GroupedQueryAttention(nn.Module):
    """Self-attention in which several query heads share one key/value head.

    Queries are projected to ``num_heads * head_dim`` features, while keys and
    values come from a single fused projection of
    ``2 * num_kv_groups * head_dim`` features that is split in half.

    Args:
        in_features: Size of the last dimension of the input (and output).
        num_heads: Number of query heads.
        num_kv_groups: Number of key/value heads; must divide ``num_heads``.
        head_dim: Per-head width. Defaults to ``in_features // num_heads``.
        qk_norm: When truthy, apply a per-head RMSNorm to queries and keys.
        rms_norm_eps: Epsilon used by those RMSNorm layers.
        device: Device for the created parameters.
        dtype: Dtype for the created parameters.
        layer_idx: Index of this layer inside the key/value cache.

    Raises:
        ValueError: If ``num_heads`` is not divisible by ``num_kv_groups``, or
            if ``head_dim`` is omitted and ``in_features`` is not divisible by
            ``num_heads``.
    """

    def __init__(
        self,
        in_features: int,
        num_heads: int,
        num_kv_groups: int,
        head_dim: int | None = None,
        qk_norm: bool = False,
        rms_norm_eps: float = 1e-6,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        layer_idx: int = 0,
    ):
        # Explicit exceptions instead of ``assert`` so the checks survive -O.
        if num_heads % num_kv_groups != 0:
            raise ValueError("num_heads must be divisible by num_kv_groups")
        if head_dim is None:
            if in_features % num_heads != 0:
                raise ValueError("in_features must be divisible by num_heads")
            head_dim = in_features // num_heads

        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)

        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups
        self.head_dim = head_dim
        self.out_features = num_heads * head_dim

        self.wq = nn.Linear(
            in_features, self.out_features, bias=False, **factory_kwargs
        )
        # Keys and values share one projection; forward() splits it in two.
        self.wkv = nn.Linear(
            in_features, 2 * num_kv_groups * head_dim, bias=False, **factory_kwargs
        )
        self.out_proj = nn.Linear(
            self.out_features, in_features, bias=False, **factory_kwargs
        )

        self.qk_norm = qk_norm
        if self.qk_norm:
            self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps, **factory_kwargs)
            self.k_norm = RMSNorm(head_dim, eps=rms_norm_eps, **factory_kwargs)

        self.layer_idx = layer_idx

    def forward(
        self,
        x: torch.FloatTensor,
        cos: torch.FloatTensor,
        sin: torch.FloatTensor,
        attention_mask: Optional[torch.BoolTensor] = None,
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> tuple[torch.FloatTensor, Optional[Cache]]:
        """Attend over ``x`` and, when a cache is given, over cached tokens.

        Args:
            x: Input of shape ``(batch, tokens, in_features)``.
            cos: Cosine table handed to ``apply_rope``.
            sin: Sine table handed to ``apply_rope``.
            attention_mask: Passed unchanged as ``attn_mask`` to
                ``scaled_dot_product_attention``.
            past_key_value: Cache updated in place with this layer's keys and
                values; its ``update`` result is what gets attended over.
            cache_position: When given, its first element is used as the
                position offset for RoPE; otherwise the cached sequence length
                (or 0 without a cache) is used.

        Returns:
            ``(output, past_key_value)`` where ``output`` has shape
            ``(batch, tokens, in_features)``.
        """
        batch_size, num_tokens, _ = x.shape

        query = self.wq(x)
        key, value = self.wkv(x).chunk(2, dim=-1)

        # (batch, tokens, heads, head_dim) -> (batch, heads, tokens, head_dim)
        query = query.view(
            batch_size, num_tokens, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key = key.view(
            batch_size, num_tokens, self.num_kv_groups, self.head_dim
        ).transpose(1, 2)
        value = value.view(
            batch_size, num_tokens, self.num_kv_groups, self.head_dim
        ).transpose(1, 2)

        if self.qk_norm:
            query = self.q_norm(query)
            key = self.k_norm(key)

        # Position of the first new token within the full sequence.
        if cache_position is not None:
            offset = int(cache_position[0].item())
        elif past_key_value is not None:
            offset = past_key_value.get_seq_length(self.layer_idx)
        else:
            offset = 0

        # Keys are rotated before caching, so cached keys are never re-rotated.
        query = apply_rope(query, cos, sin, offset=offset)
        key = apply_rope(key, cos, sin, offset=offset)

        if past_key_value is not None:
            cache_kwargs = {"cache_position": cache_position}
            key, value = past_key_value.update(key, value, self.layer_idx, cache_kwargs)

        # enable_gqa lets SDPA accept fewer key/value heads than query heads.
        attn_output = nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            enable_gqa=True,
        )

        # (batch, heads, tokens, head_dim) -> (batch, tokens, heads * head_dim)
        merged = attn_output.transpose(1, 2).reshape(
            batch_size, num_tokens, self.out_features
        )
        return self.out_proj(merged), past_key_value
217
+
218
+
219
class Transformer(nn.Module):
    """One pre-norm transformer block: attention, then feed-forward.

    Each sub-layer is preceded by an RMSNorm and wrapped in a residual
    connection. The feed-forward is a ``MoEFeedForward`` when
    ``moe_num_experts > 0`` and a dense ``FeedForward`` otherwise.

    Args:
        embedding_dim: Width of the block input and output.
        hidden_dim: Hidden size of the dense ``FeedForward``.
        num_heads: Number of query heads.
        head_dim: Per-head width.
        num_kv_groups: Number of key/value heads.
        qk_norm: Whether attention normalises queries and keys.
        moe_num_experts_per_token: Forwarded to ``MoEFeedForward``.
        moe_num_experts: Number of experts; 0 selects the dense feed-forward.
        moe_hidden_dim: Hidden size of ``MoEFeedForward``.
        rms_norm_eps: Epsilon for every RMSNorm in the block, including the
            query/key norms inside attention.
        device: Device for the created parameters.
        dtype: Dtype for the created parameters.
        layer_idx: Index of this block inside the key/value cache.
    """

    def __init__(
        self,
        embedding_dim: int,
        hidden_dim: int,
        num_heads: int,
        head_dim: int,
        num_kv_groups: int,
        qk_norm: bool = False,
        moe_num_experts_per_token: int = 8,
        moe_num_experts: int = 0,
        moe_hidden_dim: int = 128,
        rms_norm_eps: float = 1e-6,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        layer_idx: int = 0,
    ):
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        self.attn = GroupedQueryAttention(
            in_features=embedding_dim,
            num_heads=num_heads,
            head_dim=head_dim,
            num_kv_groups=num_kv_groups,
            qk_norm=qk_norm,
            # Previously omitted, so the q/k norms silently fell back to the
            # attention default instead of honouring the configured epsilon.
            rms_norm_eps=rms_norm_eps,
            layer_idx=layer_idx,
            **factory_kwargs,
        )

        if moe_num_experts > 0:
            self.ff: MoEFeedForward | FeedForward = MoEFeedForward(
                embedding_dim=embedding_dim,
                hidden_dim=moe_hidden_dim,
                num_experts_per_token=moe_num_experts_per_token,
                num_experts=moe_num_experts,
                device=device,
                dtype=dtype,
            )
        else:
            self.ff = FeedForward(
                embedding_dim, hidden_dim=hidden_dim, **factory_kwargs
            )
        self.norm1 = RMSNorm(embedding_dim, eps=rms_norm_eps, **factory_kwargs)
        self.norm2 = RMSNorm(embedding_dim, eps=rms_norm_eps, **factory_kwargs)

    def forward(
        self,
        x: torch.FloatTensor,
        cos: torch.FloatTensor,
        sin: torch.FloatTensor,
        attention_mask: Optional[torch.BoolTensor] = None,
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> tuple[torch.FloatTensor, Optional[Cache]]:
        """Run the block on ``x``.

        Args:
            x: Block input.
            cos: Cosine table forwarded to attention.
            sin: Sine table forwarded to attention.
            attention_mask: Mask forwarded to attention.
            past_key_value: Key/value cache forwarded to attention.
            cache_position: Cache positions forwarded to attention.

        Returns:
            ``(hidden_states, past_key_value)``; the cache object is whatever
            the attention layer returned.
        """
        attn_out, past_key_value = self.attn(
            self.norm1(x),
            cos,
            sin,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            cache_position=cache_position,
        )
        # Out-of-place adds: same values as `+=`, but never overwrite a
        # sub-layer output that autograd may still need.
        x = x + attn_out
        x = x + self.ff(self.norm2(x))

        return x, past_key_value
291
+
292
+
293
@dataclass
class FlexQwenOutputWithPast(ModelOutput):
    """Output container returned by the ``FlexQwen`` base model."""

    # Tuple of hidden states; the base model fills it with a single entry,
    # the output of the final norm.
    last_hidden_states: tuple[torch.FloatTensor]
    # Never populated by the base model in this file; stays None.
    attentions: tuple[torch.FloatTensor] | None = None
    # Key/value cache, set only when the caller asked for caching.
    past_key_values: Cache | None = None
298
+
299
+
300
class FlexQwen(FlexQwenPreTrainedModel):
    """FlexQwen backbone: embedding, a stack of blocks and a final RMSNorm.

    The RoPE cosine/sine tables are computed once at construction time and
    stored as persistent buffers named ``cos`` and ``sin``.

    Args:
        config: Model hyper-parameters.
        device: Device for the created parameters and buffers.
        dtype: Dtype for the created parameters and buffers.
    """

    config_class = FlexQwenConfig

    def __init__(
        self,
        config: FlexQwenConfig,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__(config)

        self.embed = nn.Embedding(
            config.vocab_size,
            config.embedding_dim,
            padding_idx=config.pad_token_id,
            device=device,
            dtype=dtype,
        )

        self.transformer_blocks = nn.ModuleList(
            [
                Transformer(
                    embedding_dim=config.embedding_dim,
                    hidden_dim=config.hidden_dim,
                    num_heads=config.num_attention_heads,
                    head_dim=config.head_dim,
                    num_kv_groups=config.num_kv_groups,
                    qk_norm=config.qk_norm,
                    moe_num_experts_per_token=config.moe_num_experts_per_token,
                    moe_num_experts=config.moe_num_experts,
                    moe_hidden_dim=config.moe_hidden_dim,
                    rms_norm_eps=config.rms_norm_eps,
                    device=device,
                    dtype=dtype,
                    layer_idx=i,
                )
                for i in range(config.num_hidden_layers)
            ]
        )

        self.final_norm = RMSNorm(
            config.embedding_dim, eps=config.rms_norm_eps, device=device, dtype=dtype
        )

        cos, sin = compute_rope_params(
            head_dim=config.head_dim,
            theta_base=config.rope_theta,
            max_position_embeddings=config.max_position_embeddings,
            dtype=dtype,
            device=device,
        )

        # Persistent buffers: the tables are written to / read from checkpoints.
        self.register_buffer("cos", cos, persistent=True)
        self.register_buffer("sin", sin, persistent=True)
        self.config = config

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        is_causal: bool = True,
        return_dict: bool = True,
        **kwargs,
    ) -> FlexQwenOutputWithPast | tuple:
        """Encode a batch of tokens (or pre-computed embeddings).

        Args:
            input_ids: Token ids; a 1-D tensor is treated as a batch of one.
            inputs_embeds: Embeddings used instead of ``input_ids``.
            attention_mask: Padding mask; positions equal to 1 are attended.
                Its last dimension must match the total key length (cached
                plus current tokens) so it broadcasts against the causal mask.
            past_key_values: Existing key/value cache to extend.
            cache_position: Forwarded to every block.
            use_cache: When truthy, a ``DynamicCache`` is created if none was
                given and the cache is included in the output.
            is_causal: When true, each token only attends to itself and to
                earlier positions.
            return_dict: When false, return ``output.to_tuple()`` instead.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            A ``FlexQwenOutputWithPast`` whose ``last_hidden_states`` is a
            one-element tuple holding the normalised final hidden states.

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are provided.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "Received both input_ids and inputs_embeds. Pass only one."
            )
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Exactly one of input_ids, inputs_embeds is required.")

        if input_ids is not None:
            if input_ids.dim() == 1:
                input_ids = input_ids.unsqueeze(0)
            x = self.embed(input_ids)
        else:
            x = inputs_embeds

        q_len = x.shape[1]
        kv_len = q_len

        # With a cache, keys/values span the cached tokens plus the new ones.
        if past_key_values is not None:
            kv_len += past_key_values.get_seq_length()

        base_mask = torch.ones((q_len, kv_len), dtype=torch.bool, device=x.device)

        if is_causal and q_len > 1:
            # Shift the diagonal so new tokens can still see every cached token.
            base_mask = torch.tril(base_mask, diagonal=kv_len - q_len)

        # (q_len, kv_len) -> (1, 1, q_len, kv_len)
        base_mask = base_mask.unsqueeze(0).unsqueeze(1)

        if attention_mask is not None:
            # (batch, mask_len) -> (batch, 1, 1, mask_len)
            padding_mask = (attention_mask == 1).unsqueeze(1).unsqueeze(2)
            attention_mask = base_mask & padding_mask
        else:
            attention_mask = base_mask

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        for block in self.transformer_blocks:
            x, past_key_values = block(
                x,
                self.cos,
                self.sin,
                attention_mask=attention_mask,
                past_key_value=past_key_values,
                cache_position=cache_position,
            )

        x = self.final_norm(x)

        output = FlexQwenOutputWithPast(
            last_hidden_states=(x,),
            past_key_values=past_key_values if use_cache else None,
        )

        if not return_dict:
            return output.to_tuple()

        return output
428
+
429
+
430
class FlexQwenForCausalLM(FlexQwenPreTrainedModel, GenerationMixin):
    """FlexQwen backbone with a bias-free linear language-model head.

    Args:
        config: Model hyper-parameters.
        device: Device for the created parameters.
        dtype: Dtype for the created parameters.
        **kwargs: Accepted for call compatibility and ignored.
    """

    config_class = FlexQwenConfig
    _tied_weights_keys = {"lm_head.weight": "model.embed.weight"}

    def __init__(
        self,
        config: FlexQwenConfig,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        **kwargs,
    ):
        super().__init__(config)
        self.model = FlexQwen(config, device=device, dtype=dtype)
        self.lm_head = nn.Linear(
            config.embedding_dim,
            config.vocab_size,
            bias=False,
            device=device,
            dtype=dtype,
        )

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the backbone."""
        return self.model.embed

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the backbone."""
        self.model.embed = value

    def get_output_embeddings(self):
        """Return the language-model head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-model head."""
        self.lm_head = new_embeddings

    def tie_weights(
        self, missing_keys: set[str] | None = None, recompute_mapping: bool = True
    ) -> None:
        """Tie the LM head to the input embedding when the config asks for it.

        The parent implementation runs first; the head weight is then pointed
        at the embedding weight explicitly whenever
        ``config.tie_word_embeddings`` is truthy.
        """
        super().tie_weights(
            missing_keys=missing_keys, recompute_mapping=recompute_mapping
        )

        if getattr(self.config, "tie_word_embeddings", False):
            self.lm_head.weight = self.model.embed.weight
            # Logged at debug level rather than printed: this method runs on
            # every construction/load and must not write to stdout.
            import logging

            logging.getLogger(__name__).debug(
                "Weights tied anyway, do not worry, be happy =)"
            )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        is_causal=True,
        **kwargs,
    ) -> CausalLMOutputWithPast | tuple:
        """Compute next-token logits and, optionally, a cross-entropy loss.

        Args:
            input_ids: Token ids forwarded to the backbone.
            attention_mask: Padding mask forwarded to the backbone.
            labels: Targets compared position-by-position with the logits;
                entries equal to -100 are ignored.
            return_dict: Falls back to ``config.use_return_dict`` when None.
            use_cache: Whether to build and return a key/value cache.
            is_causal: Forwarded to the backbone.
            **kwargs: Forwarded to the backbone (for example
                ``inputs_embeds``, ``past_key_values``, ``cache_position``).

        Returns:
            A ``CausalLMOutputWithPast`` (or its tuple form) holding
            ``logits``, ``loss`` and ``past_key_values``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs: FlexQwenOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=use_cache,
            return_dict=True,
            is_causal=is_causal,
            **kwargs,
        )

        logits = self.lm_head(outputs.last_hidden_states[-1])

        loss = None
        if labels is not None:
            if labels.dim() == 1:
                labels = labels.unsqueeze(0)
            # NOTE(review): no one-token shift is applied here, so callers must
            # supply labels that are already aligned with the logits (i.e.
            # pre-shifted). Confirm this matches the training pipeline.
            loss = nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
                ignore_index=-100,
                reduction="mean",
            )

        output = CausalLMOutputWithPast(
            logits=logits,
            # pyrefly: ignore
            loss=loss,
            # TODO: Implement this properly
            # pyrefly: ignore
            past_key_values=outputs.past_key_values if use_cache else None,
        )

        if not return_dict:
            return output.to_tuple()

        return output

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        next_sequence_length: Optional[int] = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        is_first_iteration: Optional[bool] = False,
        **kwargs,
    ) -> dict:
        """Assemble the keyword arguments for one generation step.

        With a cache and after the first iteration only the last token of
        ``input_ids`` is kept. ``inputs_embeds`` is used only when no cache
        object was passed; otherwise ``input_ids`` is used.

        Returns:
            Dict with ``input_ids`` or ``inputs_embeds`` plus
            ``past_key_values``, ``use_cache`` (default True),
            ``attention_mask``, ``cache_position`` and ``is_causal=True``.
        """
        if past_key_values is not None and not is_first_iteration:
            # Earlier tokens are already in the cache.
            input_ids = input_ids[:, -1:]  # pyrefly: ignore

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        # pyrefly: ignore
        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache", True),
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "is_causal": True,
            }
        )
        return model_inputs
556
+
557
+
558
class FlexQwenForSequenceClassification(FlexQwenPreTrainedModel):
    """FlexQwen backbone with a pooled linear classification/regression head.

    The loss is ``CrossEntropyLoss`` when ``config.num_labels > 1`` and
    ``MSELoss`` otherwise.

    Args:
        config: Model hyper-parameters (``num_labels``, ``dropout_rate`` ...).
        device: Device for the created parameters.
        dtype: Dtype for the created parameters.
    """

    config_class = FlexQwenConfig

    def __init__(
        self,
        config: FlexQwenConfig,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = FlexQwen(config, device=device, dtype=dtype)
        self.dropout = nn.Dropout(p=config.dropout_rate)
        self.score = nn.Linear(
            config.embedding_dim,
            self.num_labels,
            bias=True,
            device=device,
            dtype=dtype,
        )
        self.loss_fct = nn.CrossEntropyLoss() if config.num_labels > 1 else nn.MSELoss()

        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.BoolTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        is_causal=True,
        **kwargs,
    ) -> SequenceClassifierOutput | tuple:
        """Classify (or regress on) each sequence in the batch.

        Pooling depends on ``is_causal``:

        * causal: the hidden state of the last attended token of each row
          (the final position when no mask is given);
        * non-causal: the mean of the hidden states over attended tokens.

        Args:
            input_ids: Token ids forwarded to the backbone.
            attention_mask: Padding mask; positions equal to 1 are attended.
            labels: Targets for the loss; omitted means no loss is computed.
            return_dict: Falls back to ``config.use_return_dict`` when None.
            is_causal: Selects causal attention and last-token pooling.
            **kwargs: Forwarded to the backbone.

        Returns:
            A ``SequenceClassifierOutput``, or a tuple
            ``(loss?, logits, last_hidden_states, attentions)`` when
            ``return_dict`` is false.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # pyrefly: ignore
        outputs: FlexQwenOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            is_causal=is_causal,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_states[-1]

        if is_causal:
            if attention_mask is None:
                pooled_states = hidden_states[:, -1]
            else:
                # Index of the right-most attended token per row. For
                # right-padded masks this equals `mask.sum() - 1`, and it
                # stays correct for left-padded or non-contiguous masks.
                seq_len = hidden_states.shape[1]
                valid = (attention_mask == 1).to(hidden_states.device)
                positions = torch.arange(seq_len, device=hidden_states.device)
                last_token = (positions * valid).argmax(dim=-1)
                # Rows with no attended token fall back to the final position.
                last_token = torch.where(valid.any(dim=-1), last_token, seq_len - 1)
                pooled_states = hidden_states[
                    torch.arange(hidden_states.shape[0], device=hidden_states.device),
                    last_token,
                ]
        else:
            if attention_mask is None:
                pooled_states = hidden_states.mean(dim=1)
            else:
                mask = attention_mask.unsqueeze(-1).expand(hidden_states.size())
                masked_hidden_states = torch.where(mask.bool(), hidden_states, 0.0)
                # Clamp so a fully masked row does not divide by zero.
                num_valid_tokens = (
                    attention_mask.sum(dim=1).unsqueeze(-1).clamp(min=1e-9)
                )
                pooled_states = masked_hidden_states.sum(dim=1) / num_valid_tokens

        logits = self.score(self.dropout(pooled_states))

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                targets = labels.squeeze()
                # MSELoss rejects integer targets; cast them to the logits dtype.
                if not targets.is_floating_point():
                    targets = targets.to(logits.dtype)
                loss = self.loss_fct(logits.squeeze(), targets)
            else:
                loss = self.loss_fct(
                    logits.view(-1, self.num_labels),
                    labels.view(-1),
                )

        if not return_dict:
            output = (logits,) + (outputs.last_hidden_states, outputs.attentions)
            return (loss,) + output if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.last_hidden_states,
            attentions=outputs.attentions,
        )
649
+
650
+
651
def load_model(
    checkpoint_dir: str | Path, device: str | torch.device = "cpu"
) -> FlexQwenForCausalLM:
    """Build a ``FlexQwenForCausalLM`` from a local checkpoint directory.

    The directory must contain a config readable by ``AutoConfig`` and a
    ``model.safetensors`` file. Weights are loaded non-strictly, the LM head
    is re-tied and the model is moved to ``device``.

    Args:
        checkpoint_dir: Directory holding the config and ``model.safetensors``.
        device: Target device for the returned model.

    Returns:
        The loaded model on ``device``.

    Raises:
        FileNotFoundError: If ``model.safetensors`` is missing.
    """
    import logging

    from transformers import AutoConfig
    from safetensors.torch import load_file

    checkpoint_dir = Path(checkpoint_dir)

    # exist_ok keeps repeated calls from failing on the second registration.
    AutoConfig.register("flexqwen", FlexQwenConfig, exist_ok=True)

    config = AutoConfig.from_pretrained(checkpoint_dir)

    # Fail before allocating the model if there are no weights to load.
    safetensors_path = checkpoint_dir / "model.safetensors"
    if not safetensors_path.exists():
        raise FileNotFoundError(f"Could not find {safetensors_path}.")

    model = FlexQwenForCausalLM(config)  # pyrefly: ignore

    disk_dict = load_file(safetensors_path)
    load_result = model.load_state_dict(disk_dict, strict=False)

    # Loading stays non-strict, but mismatches are reported instead of hidden.
    # A missing head weight is expected when it is tied to the embedding.
    tolerated = (
        {"lm_head.weight"} if getattr(config, "tie_word_embeddings", False) else set()
    )
    missing = [key for key in load_result.missing_keys if key not in tolerated]
    unexpected = list(load_result.unexpected_keys)
    if missing or unexpected:
        logging.getLogger(__name__).warning(
            "Checkpoint %s did not match the model exactly "
            "(missing keys: %s; unexpected keys: %s).",
            safetensors_path,
            missing,
            unexpected,
        )

    model.tie_weights()

    return model.to(device)