benjamin committed
Commit d2bb4be · verified · 1 Parent(s): 66763b8

Upload FlaxTPULlamaForCausalLM

config.json CHANGED
@@ -1,4 +1,5 @@
 {
+  "add_qk_norm": false,
   "architectures": [
     "TPULlamaForCausalLM"
   ],
@@ -36,7 +37,7 @@
   "torch_dtype": "float32",
   "transformers_version": "4.52.3",
   "use_cache": true,
+  "use_qk_norm": true,
   "use_sliding_window": false,
-  "vocab_size": 151936,
-  "add_qk_norm": true
+  "vocab_size": 151936
 }
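For context, a minimal sketch of how the two new JSON keys surface when this config is loaded through transformers. The repo id is a placeholder, and trust_remote_code=True assumes the repository maps its custom TPULlamaConfig class via auto_map; extra keys such as use_qk_norm are normally kept as plain attributes on the loaded config.

from transformers import AutoConfig

# Hypothetical repo id; substitute the Hub repository this commit belongs to.
config = AutoConfig.from_pretrained("user/flax-tpu-llama", trust_remote_code=True)

print(config.add_qk_norm)   # False in this commit's config.json
print(config.use_qk_norm)   # True (unrecognized keys become config attributes)
print(config.vocab_size)    # 151936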
configuration_tpu_llama.py CHANGED
@@ -147,6 +147,7 @@ class TPULlamaConfig(PretrainedConfig):
         attention_dropout=0.0,
         mlp_bias=False,
         head_dim=None,
+        add_qk_norm=False, # Qwen3 compatibility
         expand_input_ids=False, # Transformers-native PyTorch generation support
         expand_input_ids_maxlen=None,
         expand_input_ids_vocab_size=None,
@@ -183,6 +184,8 @@ class TPULlamaConfig(PretrainedConfig):
             self.rope_scaling["rope_type"] = self.rope_scaling["type"]
         rope_config_validation(self)
 
+        self.add_qk_norm = add_qk_norm # Qwen3 compatibility
+
         self.expand_input_ids = expand_input_ids
         self.expand_input_ids_maxlen = expand_input_ids_maxlen
         self.expand_input_ids_vocab_size = expand_input_ids_vocab_size
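A quick sketch of the new constructor argument in isolation, assuming configuration_tpu_llama.py is importable from the working directory; the flag defaults to False and is stored on the config instance as shown in the hunk above.

from configuration_tpu_llama import TPULlamaConfig

# Hypothetical direct construction; all other fields fall back to their defaults.
cfg = TPULlamaConfig(add_qk_norm=True)
assert cfg.add_qk_norm is True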
modelling_flax_tpu_llama.py CHANGED
@@ -273,10 +273,16 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
 class FlaxTPULlamaRMSNorm(nn.Module):
     config: TPULlamaConfig
     dtype: jnp.dtype = jnp.float32
+    override_dim: int = None
 
     def setup(self):
+        if self.override_dim is not None:
+            dim = self.override_dim
+        else:
+            dim = self.config.hidden_size
+
         self.epsilon = self.config.rms_norm_eps
-        self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.hidden_size)
+        self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), dim)
 
     def __call__(self, hidden_states):
         variance = jnp.asarray(hidden_states, dtype=jnp.float32)
@@ -350,6 +356,11 @@ class FlaxTPULlamaAttention(nn.Module):
         self.k_proj = dense(self.num_key_value_heads * self.head_dim)
         self.v_proj = dense(self.num_key_value_heads * self.head_dim)
         self.o_proj = dense(self.embed_dim)
+
+        if self.config.add_qk_norm:
+            self.q_norm = FlaxTPULlamaRMSNorm(self.config, dtype=self.dtype, override_dim=self.head_dim)
+            self.k_norm = FlaxTPULlamaRMSNorm(self.config, dtype=self.dtype, override_dim=self.head_dim)
+
         self.causal_mask = make_causal_mask(
             jnp.ones(
                 (1, getattr(config, "max_length", config.max_position_embeddings)),
@@ -357,7 +368,6 @@ class FlaxTPULlamaAttention(nn.Module):
             ),
             dtype="bool",
         )
-        self.rotary_emb = FlaxTPULlamaRotaryEmbedding(config, dtype=self.dtype)
 
     def _split_heads(self, hidden_states, num_heads):
         return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))
@@ -401,6 +411,7 @@ class FlaxTPULlamaAttention(nn.Module):
     def __call__(
         self,
         hidden_states,
+        position_embeddings,
         attention_mask,
         position_ids,
         deterministic: bool = True,
@@ -415,9 +426,19 @@ class FlaxTPULlamaAttention(nn.Module):
         key = self._split_heads(raw_key, self.num_key_value_heads)
         value = self._split_heads(raw_value, self.num_key_value_heads)
 
-        cos, sin = self.rotary_emb(value, position_ids)
+        if self.config.add_qk_norm:
+            query = self.q_norm(query)
+            key = self.k_norm(key)
+
+        print(query.sum(), key.sum(), value.sum())
+
+        cos, sin = position_embeddings
         query, key = apply_rotary_pos_emb(query, key, cos, sin)
 
+        print(query.sum(), key.sum())
+        print()
+        print()
+
         query_length, key_length = query.shape[1], key.shape[1]
 
         if self.has_variable("cache", "cached_key"):
@@ -519,6 +540,7 @@ class FlaxTPULlamaFlashAttention(FlaxTPULlamaAttention):
     def __call__(
         self,
         hidden_states,
+        position_embeddings,
         attention_mask,
         position_ids,
         deterministic: bool = True,
@@ -533,7 +555,7 @@ class FlaxTPULlamaFlashAttention(FlaxTPULlamaAttention):
         key = self._split_heads(raw_key, self.num_key_value_heads)
         value = self._split_heads(raw_value, self.num_key_value_heads)
 
-        cos, sin = self.rotary_emb(value, position_ids)
+        cos, sin = position_embeddings
         query, key = apply_rotary_pos_emb(query, key, cos, sin)
 
         query_length, key_length = query.shape[1], key.shape[1]
@@ -647,6 +669,7 @@ class FlaxTPULlamaDecoderLayer(nn.Module):
     def __call__(
         self,
         hidden_states,
+        position_embeddings,
         attention_mask=None,
         position_ids=None,
         deterministic: bool = True,
@@ -660,6 +683,7 @@ class FlaxTPULlamaDecoderLayer(nn.Module):
         hidden_states = self.input_layernorm(hidden_states)
         outputs = self.self_attn(
             hidden_states,
+            position_embeddings,
             attention_mask=attention_mask,
             position_ids=position_ids,
             deterministic=deterministic,
@@ -865,8 +889,10 @@ class FlaxTPULlamaLayerCollection(nn.Module):
     gradient_checkpointing: bool = False
 
     def setup(self):
+        self.rotary_emb = FlaxTPULlamaRotaryEmbedding(self.config, dtype=self.dtype)
+
         if self.gradient_checkpointing:
-            FlaxTPULlamaDecoderCheckpointLayer = remat(FlaxTPULlamaDecoderLayer, static_argnums=(3, 4, 5))
+            FlaxTPULlamaDecoderCheckpointLayer = remat(FlaxTPULlamaDecoderLayer, static_argnums=(4, 5, 6))
             self.blocks = [
                 FlaxTPULlamaDecoderCheckpointLayer(self.config, dtype=self.dtype, name=str(i))
                 for i in range(self.config.num_hidden_layers)
@@ -891,6 +917,8 @@ class FlaxTPULlamaLayerCollection(nn.Module):
         all_attentions = () if output_attentions else None
         all_hidden_states = [(), ()] if output_hidden_states else None
 
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
         if output_hidden_states:
             all_hidden_states[0] += (hidden_states,)
             all_hidden_states[1] += (hidden_states,)
@@ -898,6 +926,7 @@ class FlaxTPULlamaLayerCollection(nn.Module):
         for block_idx, block in enumerate(self.blocks):
             layer_outputs = block(
                 hidden_states,
+                position_embeddings,
                 attention_mask,
                 position_ids,
                 deterministic,
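To make the modelling change concrete, here is a self-contained sketch (not the repository's module) of the per-head RMS normalization this diff wires in: when add_qk_norm is set, the split query and key heads are RMS-normalized over head_dim before the rotary embedding is applied. Class name, shapes, and eps below are illustrative assumptions.

import jax
import jax.numpy as jnp
import flax.linen as nn


class HeadRMSNorm(nn.Module):
    """RMSNorm over the trailing head_dim axis, with one learned scale per channel."""
    dim: int            # head_dim
    eps: float = 1e-6

    @nn.compact
    def __call__(self, x):
        weight = self.param("weight", nn.initializers.ones, (self.dim,))
        x32 = x.astype(jnp.float32)
        variance = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
        normed = x32 * jax.lax.rsqrt(variance + self.eps)
        return (weight * normed).astype(x.dtype)


# Usage on split heads of shape (batch, seq, num_heads, head_dim), mirroring where
# the diff applies q_norm/k_norm just before apply_rotary_pos_emb.
query = jnp.ones((1, 4, 8, 128))
norm = HeadRMSNorm(dim=128)
params = norm.init(jax.random.PRNGKey(0), query)
query = norm.apply(params, query)   # same shape, per-head unit RMS up to the learned scale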