Upload FlaxTPULlamaForCausalLM
Browse files- config.json +3 -2
- configuration_tpu_llama.py +3 -0
- modelling_flax_tpu_llama.py +34 -5
config.json
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
{
|
|
|
|
| 2 |
"architectures": [
|
| 3 |
"TPULlamaForCausalLM"
|
| 4 |
],
|
|
@@ -36,7 +37,7 @@
|
|
| 36 |
"torch_dtype": "float32",
|
| 37 |
"transformers_version": "4.52.3",
|
| 38 |
"use_cache": true,
|
|
|
|
| 39 |
"use_sliding_window": false,
|
| 40 |
-
"vocab_size": 151936
|
| 41 |
-
"add_qk_norm": true
|
| 42 |
}
|
|
|
|
| 1 |
{
|
| 2 |
+
"add_qk_norm": false,
|
| 3 |
"architectures": [
|
| 4 |
"TPULlamaForCausalLM"
|
| 5 |
],
|
|
|
|
| 37 |
"torch_dtype": "float32",
|
| 38 |
"transformers_version": "4.52.3",
|
| 39 |
"use_cache": true,
|
| 40 |
+
"use_qk_norm": true,
|
| 41 |
"use_sliding_window": false,
|
| 42 |
+
"vocab_size": 151936
|
|
|
|
| 43 |
}
|
configuration_tpu_llama.py
CHANGED
|
@@ -147,6 +147,7 @@ class TPULlamaConfig(PretrainedConfig):
|
|
| 147 |
attention_dropout=0.0,
|
| 148 |
mlp_bias=False,
|
| 149 |
head_dim=None,
|
|
|
|
| 150 |
expand_input_ids=False, # Transformers-native PyTorch generation support
|
| 151 |
expand_input_ids_maxlen=None,
|
| 152 |
expand_input_ids_vocab_size=None,
|
|
@@ -183,6 +184,8 @@ class TPULlamaConfig(PretrainedConfig):
|
|
| 183 |
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
| 184 |
rope_config_validation(self)
|
| 185 |
|
|
|
|
|
|
|
| 186 |
self.expand_input_ids = expand_input_ids
|
| 187 |
self.expand_input_ids_maxlen = expand_input_ids_maxlen
|
| 188 |
self.expand_input_ids_vocab_size = expand_input_ids_vocab_size
|
|
|
|
| 147 |
attention_dropout=0.0,
|
| 148 |
mlp_bias=False,
|
| 149 |
head_dim=None,
|
| 150 |
+
add_qk_norm=False, # Qwen3 compatibility
|
| 151 |
expand_input_ids=False, # Transformers-native PyTorch generation support
|
| 152 |
expand_input_ids_maxlen=None,
|
| 153 |
expand_input_ids_vocab_size=None,
|
|
|
|
| 184 |
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
| 185 |
rope_config_validation(self)
|
| 186 |
|
| 187 |
+
self.add_qk_norm = add_qk_norm # Qwen3 compatibility
|
| 188 |
+
|
| 189 |
self.expand_input_ids = expand_input_ids
|
| 190 |
self.expand_input_ids_maxlen = expand_input_ids_maxlen
|
| 191 |
self.expand_input_ids_vocab_size = expand_input_ids_vocab_size
|
modelling_flax_tpu_llama.py
CHANGED
|
@@ -273,10 +273,16 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
|
| 273 |
class FlaxTPULlamaRMSNorm(nn.Module):
|
| 274 |
config: TPULlamaConfig
|
| 275 |
dtype: jnp.dtype = jnp.float32
|
|
|
|
| 276 |
|
| 277 |
def setup(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
self.epsilon = self.config.rms_norm_eps
|
| 279 |
-
self.weight = self.param("weight", lambda _, shape: jnp.ones(shape),
|
| 280 |
|
| 281 |
def __call__(self, hidden_states):
|
| 282 |
variance = jnp.asarray(hidden_states, dtype=jnp.float32)
|
|
@@ -350,6 +356,11 @@ class FlaxTPULlamaAttention(nn.Module):
|
|
| 350 |
self.k_proj = dense(self.num_key_value_heads * self.head_dim)
|
| 351 |
self.v_proj = dense(self.num_key_value_heads * self.head_dim)
|
| 352 |
self.o_proj = dense(self.embed_dim)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
self.causal_mask = make_causal_mask(
|
| 354 |
jnp.ones(
|
| 355 |
(1, getattr(config, "max_length", config.max_position_embeddings)),
|
|
@@ -357,7 +368,6 @@ class FlaxTPULlamaAttention(nn.Module):
|
|
| 357 |
),
|
| 358 |
dtype="bool",
|
| 359 |
)
|
| 360 |
-
self.rotary_emb = FlaxTPULlamaRotaryEmbedding(config, dtype=self.dtype)
|
| 361 |
|
| 362 |
def _split_heads(self, hidden_states, num_heads):
|
| 363 |
return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))
|
|
@@ -401,6 +411,7 @@ class FlaxTPULlamaAttention(nn.Module):
|
|
| 401 |
def __call__(
|
| 402 |
self,
|
| 403 |
hidden_states,
|
|
|
|
| 404 |
attention_mask,
|
| 405 |
position_ids,
|
| 406 |
deterministic: bool = True,
|
|
@@ -415,9 +426,19 @@ class FlaxTPULlamaAttention(nn.Module):
|
|
| 415 |
key = self._split_heads(raw_key, self.num_key_value_heads)
|
| 416 |
value = self._split_heads(raw_value, self.num_key_value_heads)
|
| 417 |
|
| 418 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
query, key = apply_rotary_pos_emb(query, key, cos, sin)
|
| 420 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
query_length, key_length = query.shape[1], key.shape[1]
|
| 422 |
|
| 423 |
if self.has_variable("cache", "cached_key"):
|
|
@@ -519,6 +540,7 @@ class FlaxTPULlamaFlashAttention(FlaxTPULlamaAttention):
|
|
| 519 |
def __call__(
|
| 520 |
self,
|
| 521 |
hidden_states,
|
|
|
|
| 522 |
attention_mask,
|
| 523 |
position_ids,
|
| 524 |
deterministic: bool = True,
|
|
@@ -533,7 +555,7 @@ class FlaxTPULlamaFlashAttention(FlaxTPULlamaAttention):
|
|
| 533 |
key = self._split_heads(raw_key, self.num_key_value_heads)
|
| 534 |
value = self._split_heads(raw_value, self.num_key_value_heads)
|
| 535 |
|
| 536 |
-
cos, sin =
|
| 537 |
query, key = apply_rotary_pos_emb(query, key, cos, sin)
|
| 538 |
|
| 539 |
query_length, key_length = query.shape[1], key.shape[1]
|
|
@@ -647,6 +669,7 @@ class FlaxTPULlamaDecoderLayer(nn.Module):
|
|
| 647 |
def __call__(
|
| 648 |
self,
|
| 649 |
hidden_states,
|
|
|
|
| 650 |
attention_mask=None,
|
| 651 |
position_ids=None,
|
| 652 |
deterministic: bool = True,
|
|
@@ -660,6 +683,7 @@ class FlaxTPULlamaDecoderLayer(nn.Module):
|
|
| 660 |
hidden_states = self.input_layernorm(hidden_states)
|
| 661 |
outputs = self.self_attn(
|
| 662 |
hidden_states,
|
|
|
|
| 663 |
attention_mask=attention_mask,
|
| 664 |
position_ids=position_ids,
|
| 665 |
deterministic=deterministic,
|
|
@@ -865,8 +889,10 @@ class FlaxTPULlamaLayerCollection(nn.Module):
|
|
| 865 |
gradient_checkpointing: bool = False
|
| 866 |
|
| 867 |
def setup(self):
|
|
|
|
|
|
|
| 868 |
if self.gradient_checkpointing:
|
| 869 |
-
FlaxTPULlamaDecoderCheckpointLayer = remat(FlaxTPULlamaDecoderLayer, static_argnums=(
|
| 870 |
self.blocks = [
|
| 871 |
FlaxTPULlamaDecoderCheckpointLayer(self.config, dtype=self.dtype, name=str(i))
|
| 872 |
for i in range(self.config.num_hidden_layers)
|
|
@@ -891,6 +917,8 @@ class FlaxTPULlamaLayerCollection(nn.Module):
|
|
| 891 |
all_attentions = () if output_attentions else None
|
| 892 |
all_hidden_states = [(), ()] if output_hidden_states else None
|
| 893 |
|
|
|
|
|
|
|
| 894 |
if output_hidden_states:
|
| 895 |
all_hidden_states[0] += (hidden_states,)
|
| 896 |
all_hidden_states[1] += (hidden_states,)
|
|
@@ -898,6 +926,7 @@ class FlaxTPULlamaLayerCollection(nn.Module):
|
|
| 898 |
for block_idx, block in enumerate(self.blocks):
|
| 899 |
layer_outputs = block(
|
| 900 |
hidden_states,
|
|
|
|
| 901 |
attention_mask,
|
| 902 |
position_ids,
|
| 903 |
deterministic,
|
|
|
|
| 273 |
class FlaxTPULlamaRMSNorm(nn.Module):
|
| 274 |
config: TPULlamaConfig
|
| 275 |
dtype: jnp.dtype = jnp.float32
|
| 276 |
+
override_dim: int = None
|
| 277 |
|
| 278 |
def setup(self):
|
| 279 |
+
if self.override_dim is not None:
|
| 280 |
+
dim = self.override_dim
|
| 281 |
+
else:
|
| 282 |
+
dim = self.config.hidden_size
|
| 283 |
+
|
| 284 |
self.epsilon = self.config.rms_norm_eps
|
| 285 |
+
self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), dim)
|
| 286 |
|
| 287 |
def __call__(self, hidden_states):
|
| 288 |
variance = jnp.asarray(hidden_states, dtype=jnp.float32)
|
|
|
|
| 356 |
self.k_proj = dense(self.num_key_value_heads * self.head_dim)
|
| 357 |
self.v_proj = dense(self.num_key_value_heads * self.head_dim)
|
| 358 |
self.o_proj = dense(self.embed_dim)
|
| 359 |
+
|
| 360 |
+
if self.config.add_qk_norm:
|
| 361 |
+
self.q_norm = FlaxTPULlamaRMSNorm(self.config, dtype=self.dtype, override_dim=self.head_dim)
|
| 362 |
+
self.k_norm = FlaxTPULlamaRMSNorm(self.config, dtype=self.dtype, override_dim=self.head_dim)
|
| 363 |
+
|
| 364 |
self.causal_mask = make_causal_mask(
|
| 365 |
jnp.ones(
|
| 366 |
(1, getattr(config, "max_length", config.max_position_embeddings)),
|
|
|
|
| 368 |
),
|
| 369 |
dtype="bool",
|
| 370 |
)
|
|
|
|
| 371 |
|
| 372 |
def _split_heads(self, hidden_states, num_heads):
|
| 373 |
return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))
|
|
|
|
| 411 |
def __call__(
|
| 412 |
self,
|
| 413 |
hidden_states,
|
| 414 |
+
position_embeddings,
|
| 415 |
attention_mask,
|
| 416 |
position_ids,
|
| 417 |
deterministic: bool = True,
|
|
|
|
| 426 |
key = self._split_heads(raw_key, self.num_key_value_heads)
|
| 427 |
value = self._split_heads(raw_value, self.num_key_value_heads)
|
| 428 |
|
| 429 |
+
if self.config.add_qk_norm:
|
| 430 |
+
query = self.q_norm(query)
|
| 431 |
+
key = self.k_norm(key)
|
| 432 |
+
|
| 433 |
+
print(query.sum(), key.sum(), value.sum())
|
| 434 |
+
|
| 435 |
+
cos, sin = position_embeddings
|
| 436 |
query, key = apply_rotary_pos_emb(query, key, cos, sin)
|
| 437 |
|
| 438 |
+
print(query.sum(), key.sum())
|
| 439 |
+
print()
|
| 440 |
+
print()
|
| 441 |
+
|
| 442 |
query_length, key_length = query.shape[1], key.shape[1]
|
| 443 |
|
| 444 |
if self.has_variable("cache", "cached_key"):
|
|
|
|
| 540 |
def __call__(
|
| 541 |
self,
|
| 542 |
hidden_states,
|
| 543 |
+
position_embeddings,
|
| 544 |
attention_mask,
|
| 545 |
position_ids,
|
| 546 |
deterministic: bool = True,
|
|
|
|
| 555 |
key = self._split_heads(raw_key, self.num_key_value_heads)
|
| 556 |
value = self._split_heads(raw_value, self.num_key_value_heads)
|
| 557 |
|
| 558 |
+
cos, sin = position_embeddings
|
| 559 |
query, key = apply_rotary_pos_emb(query, key, cos, sin)
|
| 560 |
|
| 561 |
query_length, key_length = query.shape[1], key.shape[1]
|
|
|
|
| 669 |
def __call__(
|
| 670 |
self,
|
| 671 |
hidden_states,
|
| 672 |
+
position_embeddings,
|
| 673 |
attention_mask=None,
|
| 674 |
position_ids=None,
|
| 675 |
deterministic: bool = True,
|
|
|
|
| 683 |
hidden_states = self.input_layernorm(hidden_states)
|
| 684 |
outputs = self.self_attn(
|
| 685 |
hidden_states,
|
| 686 |
+
position_embeddings,
|
| 687 |
attention_mask=attention_mask,
|
| 688 |
position_ids=position_ids,
|
| 689 |
deterministic=deterministic,
|
|
|
|
| 889 |
gradient_checkpointing: bool = False
|
| 890 |
|
| 891 |
def setup(self):
|
| 892 |
+
self.rotary_emb = FlaxTPULlamaRotaryEmbedding(self.config, dtype=self.dtype)
|
| 893 |
+
|
| 894 |
if self.gradient_checkpointing:
|
| 895 |
+
FlaxTPULlamaDecoderCheckpointLayer = remat(FlaxTPULlamaDecoderLayer, static_argnums=(4, 5, 6))
|
| 896 |
self.blocks = [
|
| 897 |
FlaxTPULlamaDecoderCheckpointLayer(self.config, dtype=self.dtype, name=str(i))
|
| 898 |
for i in range(self.config.num_hidden_layers)
|
|
|
|
| 917 |
all_attentions = () if output_attentions else None
|
| 918 |
all_hidden_states = [(), ()] if output_hidden_states else None
|
| 919 |
|
| 920 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 921 |
+
|
| 922 |
if output_hidden_states:
|
| 923 |
all_hidden_states[0] += (hidden_states,)
|
| 924 |
all_hidden_states[1] += (hidden_states,)
|
|
|
|
| 926 |
for block_idx, block in enumerate(self.blocks):
|
| 927 |
layer_outputs = block(
|
| 928 |
hidden_states,
|
| 929 |
+
position_embeddings,
|
| 930 |
attention_mask,
|
| 931 |
position_ids,
|
| 932 |
deterministic,
|