tl-hyungguk committed on
Commit 665a63f · verified · 1 Parent(s): e3fd6f9

Upload TridaForDLM

config.json ADDED
@@ -0,0 +1,70 @@
+ {
+   "architectures": [
+     "TridaForDLM"
+   ],
+   "attention_dropout": 0.0,
+   "auto_map": {
+     "AutoConfig": "configuration_trida.TridaConfig",
+     "AutoModel": "modeling_trida.TridaForDLM",
+     "AutoModelForCausalLM": "modeling_trida.TridaForDLM"
+   },
+   "bd_size": 32,
+   "bos_token_id": 0,
+   "dtype": "float32",
+   "eos_token_id": 128001,
+   "fuse_cross_entropy": false,
+   "head_dim": 128,
+   "hidden_act": "silu",
+   "hidden_size": 4096,
+   "initializer_range": 0.02,
+   "intermediate_size": 11008,
+   "layer_types": [
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention",
+     "full_attention"
+   ],
+   "max_position_embeddings": 4096,
+   "max_window_layers": 28,
+   "model_type": "Trida",
+   "num_attention_heads": 32,
+   "num_hidden_layers": 32,
+   "num_key_value_heads": 32,
+   "rms_norm_eps": 1e-05,
+   "rope_scaling": null,
+   "rope_theta": 100000.0,
+   "sliding_window": null,
+   "tie_word_embeddings": false,
+   "transformers_version": "4.57.1",
+   "use_cache": false,
+   "use_sliding_window": false,
+   "vocab_size": 128256
+ }
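Because "auto_map" routes AutoConfig, AutoModel, and AutoModelForCausalLM to the bundled configuration_trida.py / modeling_trida.py, loading this checkpoint goes through the custom code path and needs trust_remote_code=True. A minimal loading sketch (the repo id below is a placeholder, not taken from this commit):

    from transformers import AutoConfig, AutoModel

    repo_id = "<namespace>/<repo-name>"  # placeholder for the repository this commit belongs to
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)  # resolves to TridaConfig
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)    # resolves to TridaForDLM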
configuration_trida.py ADDED
@@ -0,0 +1,107 @@
+ """Trida model configuration"""
+
+ from transformers.configuration_utils import PretrainedConfig
+ try:
+     from transformers.configuration_utils import layer_type_validation
+ except Exception:
+     def layer_type_validation(layer_types):
+         return
+ from transformers.modeling_rope_utils import rope_config_validation
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class TridaConfig(PretrainedConfig):
+
+     model_type = "Trida"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     # Default tensor parallel plan for base model `Trida`
+     base_model_tp_plan = {
+         "layers.*.self_attn.q_proj": "colwise",
+         "layers.*.self_attn.k_proj": "colwise",
+         "layers.*.self_attn.v_proj": "colwise",
+         "layers.*.self_attn.o_proj": "rowwise",
+         "layers.*.mlp.gate_proj": "colwise",
+         "layers.*.mlp.up_proj": "colwise",
+         "layers.*.mlp.down_proj": "rowwise",
+     }
+     base_model_pp_plan = {
+         "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+         "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+         "norm": (["hidden_states"], ["hidden_states"]),
+     }
+
+     def __init__(
+         self,
+         vocab_size=151936,
+         hidden_size=4096,
+         intermediate_size=22016,
+         num_hidden_layers=32,
+         num_attention_heads=32,
+         num_key_value_heads=32,
+         hidden_act="silu",
+         max_position_embeddings=32768,
+         initializer_range=0.02,
+         rms_norm_eps=1e-6,
+         use_cache=True,
+         tie_word_embeddings=False,
+         rope_theta=10000.0,
+         rope_scaling=None,
+         use_sliding_window=False,
+         sliding_window=4096,
+         max_window_layers=28,
+         layer_types=None,
+         attention_dropout=0.0,
+         bd_size=32,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.use_sliding_window = use_sliding_window
+         self.sliding_window = sliding_window if self.use_sliding_window else None
+         self.max_window_layers = max_window_layers
+
+         # for backward compatibility
+         if num_key_value_heads is None:
+             num_key_value_heads = num_attention_heads
+
+         self.num_key_value_heads = num_key_value_heads
+         self.hidden_act = hidden_act
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.use_cache = use_cache
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self.attention_dropout = attention_dropout
+         self.bd_size = bd_size
+         # Validate the correctness of rotary position embeddings parameters
+         # BC: if there is a 'type' field, move it to 'rope_type'.
+         if self.rope_scaling is not None and "type" in self.rope_scaling:
+             self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+         rope_config_validation(self)
+
+         self.layer_types = layer_types
+         if self.layer_types is None:
+             self.layer_types = [
+                 "sliding_attention"
+                 if self.sliding_window is not None and i >= self.max_window_layers
+                 else "full_attention"
+                 for i in range(self.num_hidden_layers)
+             ]
+         layer_type_validation(self.layer_types)
+
+         ##########################################################
+         self.head_dim = 128
+         ##########################################################
+
+         super().__init__(
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
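Two details of this config class are worth noting: __init__ pins self.head_dim = 128 regardless of hidden_size // num_attention_heads (the two coincide here, 4096 / 32 = 128), and sliding_window is forced to None whenever use_sliding_window is false, which is why config.json stores "sliding_window": null and all 32 layer_types default to "full_attention". A small sketch, assuming the file above is importable as configuration_trida:

    from configuration_trida import TridaConfig

    config = TridaConfig(
        vocab_size=128256,
        intermediate_size=11008,
        max_position_embeddings=4096,
        rope_theta=100000.0,
        bd_size=32,
    )
    assert config.head_dim == 128                         # hard-coded after validation
    assert config.sliding_window is None                  # use_sliding_window defaults to False
    assert config.layer_types == ["full_attention"] * 32  # no sliding window, so every layer is full attention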
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 0,
+   "eos_token_id": 128001,
+   "transformers_version": "4.57.1",
+   "use_cache": false
+ }
model-00001-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9eb1b2fcdd5c5d59694dc8a0b2aab057c45a2630ca3b496ee23a4000ee9b3386
+ size 4978740968
model-00002-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:23746158af4e0daa9bdb19831fd3760b5a89d20e1547572723955b183532fb93
+ size 4857206848
model-00003-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:46d9be098ec0902122ac1c282bea7c3019d95146a67b36bfdf41f320eeb10676
+ size 4857206896
model-00004-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:619313938b5d6191233be9d11219d87c887312649870ff6beb3acc54efa04194
+ size 4857206896
model-00005-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1f09180fd1b694278e63bb12ae4d26cd3ad2b7a28ba0640f00194f5aad879076
+ size 4857206896
model-00006-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a295d83bf29357b935393c01450b8298d80cbfa9d3f810fdecdbd4410ed6a765
+ size 3598897776
model-00007-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7e9a3a3e44db08e60da2b4dc1c6903552e2376648207db0058cf4d76de80f4eb
+ size 2101346432
model.safetensors.index.json ADDED
@@ -0,0 +1,299 @@
1
+ {
2
+ "metadata": {
3
+ "total_parameters": 7526944768,
4
+ "total_size": 30107779072
5
+ },
6
+ "weight_map": {
7
+ "lm_head.weight": "model-00007-of-00007.safetensors",
8
+ "model.embed_tokens.weight": "model-00001-of-00007.safetensors",
9
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00007.safetensors",
10
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
11
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
12
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00007.safetensors",
13
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00007.safetensors",
14
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00007.safetensors",
15
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00007.safetensors",
16
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00007.safetensors",
17
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00007.safetensors",
18
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00007.safetensors",
19
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
20
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
21
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00007.safetensors",
22
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00007.safetensors",
23
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00007.safetensors",
24
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00007.safetensors",
25
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00007.safetensors",
26
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00007.safetensors",
27
+ "model.layers.10.input_layernorm.weight": "model-00003-of-00007.safetensors",
28
+ "model.layers.10.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
29
+ "model.layers.10.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
30
+ "model.layers.10.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
31
+ "model.layers.10.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
32
+ "model.layers.10.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
33
+ "model.layers.10.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
34
+ "model.layers.10.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
35
+ "model.layers.10.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
36
+ "model.layers.11.input_layernorm.weight": "model-00003-of-00007.safetensors",
37
+ "model.layers.11.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
38
+ "model.layers.11.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
39
+ "model.layers.11.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
40
+ "model.layers.11.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
41
+ "model.layers.11.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
42
+ "model.layers.11.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
43
+ "model.layers.11.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
44
+ "model.layers.11.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
45
+ "model.layers.12.input_layernorm.weight": "model-00003-of-00007.safetensors",
46
+ "model.layers.12.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
47
+ "model.layers.12.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
48
+ "model.layers.12.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
49
+ "model.layers.12.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
50
+ "model.layers.12.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
51
+ "model.layers.12.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
52
+ "model.layers.12.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
53
+ "model.layers.12.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
54
+ "model.layers.13.input_layernorm.weight": "model-00003-of-00007.safetensors",
55
+ "model.layers.13.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
56
+ "model.layers.13.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
57
+ "model.layers.13.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
58
+ "model.layers.13.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
59
+ "model.layers.13.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
60
+ "model.layers.13.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
61
+ "model.layers.13.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
62
+ "model.layers.13.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
63
+ "model.layers.14.input_layernorm.weight": "model-00003-of-00007.safetensors",
64
+ "model.layers.14.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
65
+ "model.layers.14.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
66
+ "model.layers.14.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
67
+ "model.layers.14.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
68
+ "model.layers.14.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
69
+ "model.layers.14.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
70
+ "model.layers.14.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
71
+ "model.layers.14.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
72
+ "model.layers.15.input_layernorm.weight": "model-00004-of-00007.safetensors",
73
+ "model.layers.15.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
74
+ "model.layers.15.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
75
+ "model.layers.15.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
76
+ "model.layers.15.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
77
+ "model.layers.15.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
78
+ "model.layers.15.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
79
+ "model.layers.15.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
80
+ "model.layers.15.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
81
+ "model.layers.16.input_layernorm.weight": "model-00004-of-00007.safetensors",
82
+ "model.layers.16.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
83
+ "model.layers.16.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
84
+ "model.layers.16.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
85
+ "model.layers.16.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
86
+ "model.layers.16.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
87
+ "model.layers.16.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
88
+ "model.layers.16.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
89
+ "model.layers.16.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
90
+ "model.layers.17.input_layernorm.weight": "model-00004-of-00007.safetensors",
91
+ "model.layers.17.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
92
+ "model.layers.17.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
93
+ "model.layers.17.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
94
+ "model.layers.17.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
95
+ "model.layers.17.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
96
+ "model.layers.17.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
97
+ "model.layers.17.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
98
+ "model.layers.17.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
99
+ "model.layers.18.input_layernorm.weight": "model-00004-of-00007.safetensors",
100
+ "model.layers.18.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
101
+ "model.layers.18.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
102
+ "model.layers.18.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
103
+ "model.layers.18.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
104
+ "model.layers.18.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
105
+ "model.layers.18.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
106
+ "model.layers.18.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
107
+ "model.layers.18.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
108
+ "model.layers.19.input_layernorm.weight": "model-00004-of-00007.safetensors",
109
+ "model.layers.19.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
110
+ "model.layers.19.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
111
+ "model.layers.19.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
112
+ "model.layers.19.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
113
+ "model.layers.19.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
114
+ "model.layers.19.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
115
+ "model.layers.19.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
116
+ "model.layers.19.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
117
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00007.safetensors",
118
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
119
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
120
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00007.safetensors",
121
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00007.safetensors",
122
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00007.safetensors",
123
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00007.safetensors",
124
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00007.safetensors",
125
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00007.safetensors",
126
+ "model.layers.20.input_layernorm.weight": "model-00004-of-00007.safetensors",
127
+ "model.layers.20.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
128
+ "model.layers.20.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
129
+ "model.layers.20.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
130
+ "model.layers.20.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
131
+ "model.layers.20.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
132
+ "model.layers.20.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
133
+ "model.layers.20.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
134
+ "model.layers.20.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
135
+ "model.layers.21.input_layernorm.weight": "model-00005-of-00007.safetensors",
136
+ "model.layers.21.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
137
+ "model.layers.21.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
138
+ "model.layers.21.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
139
+ "model.layers.21.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
140
+ "model.layers.21.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
141
+ "model.layers.21.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
142
+ "model.layers.21.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
143
+ "model.layers.21.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
144
+ "model.layers.22.input_layernorm.weight": "model-00005-of-00007.safetensors",
145
+ "model.layers.22.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
146
+ "model.layers.22.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
147
+ "model.layers.22.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
148
+ "model.layers.22.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
149
+ "model.layers.22.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
150
+ "model.layers.22.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
151
+ "model.layers.22.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
152
+ "model.layers.22.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
153
+ "model.layers.23.input_layernorm.weight": "model-00005-of-00007.safetensors",
154
+ "model.layers.23.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
155
+ "model.layers.23.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
156
+ "model.layers.23.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
157
+ "model.layers.23.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
158
+ "model.layers.23.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
159
+ "model.layers.23.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
160
+ "model.layers.23.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
161
+ "model.layers.23.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
162
+ "model.layers.24.input_layernorm.weight": "model-00005-of-00007.safetensors",
163
+ "model.layers.24.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
164
+ "model.layers.24.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
165
+ "model.layers.24.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
166
+ "model.layers.24.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
167
+ "model.layers.24.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
168
+ "model.layers.24.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
169
+ "model.layers.24.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
170
+ "model.layers.24.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
171
+ "model.layers.25.input_layernorm.weight": "model-00005-of-00007.safetensors",
172
+ "model.layers.25.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
173
+ "model.layers.25.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
174
+ "model.layers.25.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
175
+ "model.layers.25.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
176
+ "model.layers.25.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
177
+ "model.layers.25.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
178
+ "model.layers.25.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
179
+ "model.layers.25.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
180
+ "model.layers.26.input_layernorm.weight": "model-00005-of-00007.safetensors",
181
+ "model.layers.26.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
182
+ "model.layers.26.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
183
+ "model.layers.26.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
184
+ "model.layers.26.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
185
+ "model.layers.26.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
186
+ "model.layers.26.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
187
+ "model.layers.26.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
188
+ "model.layers.26.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
189
+ "model.layers.27.input_layernorm.weight": "model-00006-of-00007.safetensors",
190
+ "model.layers.27.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
191
+ "model.layers.27.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
192
+ "model.layers.27.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
193
+ "model.layers.27.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
194
+ "model.layers.27.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
195
+ "model.layers.27.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
196
+ "model.layers.27.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
197
+ "model.layers.27.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
198
+ "model.layers.28.input_layernorm.weight": "model-00006-of-00007.safetensors",
199
+ "model.layers.28.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
200
+ "model.layers.28.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
201
+ "model.layers.28.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
202
+ "model.layers.28.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
203
+ "model.layers.28.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
204
+ "model.layers.28.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
205
+ "model.layers.28.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
206
+ "model.layers.28.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
207
+ "model.layers.29.input_layernorm.weight": "model-00006-of-00007.safetensors",
208
+ "model.layers.29.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
209
+ "model.layers.29.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
210
+ "model.layers.29.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
211
+ "model.layers.29.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
212
+ "model.layers.29.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
213
+ "model.layers.29.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
214
+ "model.layers.29.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
215
+ "model.layers.29.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
216
+ "model.layers.3.input_layernorm.weight": "model-00002-of-00007.safetensors",
217
+ "model.layers.3.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
218
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
219
+ "model.layers.3.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
220
+ "model.layers.3.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
221
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00007.safetensors",
222
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00007.safetensors",
223
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00007.safetensors",
224
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00007.safetensors",
225
+ "model.layers.30.input_layernorm.weight": "model-00006-of-00007.safetensors",
226
+ "model.layers.30.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
227
+ "model.layers.30.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
228
+ "model.layers.30.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
229
+ "model.layers.30.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
230
+ "model.layers.30.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
231
+ "model.layers.30.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
232
+ "model.layers.30.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
233
+ "model.layers.30.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
234
+ "model.layers.31.input_layernorm.weight": "model-00006-of-00007.safetensors",
235
+ "model.layers.31.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
236
+ "model.layers.31.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
237
+ "model.layers.31.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
238
+ "model.layers.31.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
239
+ "model.layers.31.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
240
+ "model.layers.31.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
241
+ "model.layers.31.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
242
+ "model.layers.31.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
243
+ "model.layers.4.input_layernorm.weight": "model-00002-of-00007.safetensors",
244
+ "model.layers.4.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
245
+ "model.layers.4.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
246
+ "model.layers.4.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
247
+ "model.layers.4.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
248
+ "model.layers.4.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
249
+ "model.layers.4.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
250
+ "model.layers.4.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
251
+ "model.layers.4.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
252
+ "model.layers.5.input_layernorm.weight": "model-00002-of-00007.safetensors",
253
+ "model.layers.5.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
254
+ "model.layers.5.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
255
+ "model.layers.5.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
256
+ "model.layers.5.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
257
+ "model.layers.5.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
258
+ "model.layers.5.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
259
+ "model.layers.5.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
260
+ "model.layers.5.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
261
+ "model.layers.6.input_layernorm.weight": "model-00002-of-00007.safetensors",
262
+ "model.layers.6.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
263
+ "model.layers.6.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
264
+ "model.layers.6.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
265
+ "model.layers.6.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
266
+ "model.layers.6.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
267
+ "model.layers.6.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
268
+ "model.layers.6.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
269
+ "model.layers.6.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
270
+ "model.layers.7.input_layernorm.weight": "model-00002-of-00007.safetensors",
271
+ "model.layers.7.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
272
+ "model.layers.7.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
273
+ "model.layers.7.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
274
+ "model.layers.7.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
275
+ "model.layers.7.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
276
+ "model.layers.7.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
277
+ "model.layers.7.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
278
+ "model.layers.7.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
279
+ "model.layers.8.input_layernorm.weight": "model-00002-of-00007.safetensors",
280
+ "model.layers.8.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
281
+ "model.layers.8.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
282
+ "model.layers.8.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
283
+ "model.layers.8.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
284
+ "model.layers.8.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
285
+ "model.layers.8.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
286
+ "model.layers.8.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
287
+ "model.layers.8.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
288
+ "model.layers.9.input_layernorm.weight": "model-00003-of-00007.safetensors",
289
+ "model.layers.9.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
290
+ "model.layers.9.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
291
+ "model.layers.9.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
292
+ "model.layers.9.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
293
+ "model.layers.9.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
294
+ "model.layers.9.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
295
+ "model.layers.9.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
296
+ "model.layers.9.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
297
+ "model.norm.weight": "model-00006-of-00007.safetensors"
298
+ }
299
+ }
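As a consistency check, the metadata above lines up with config.json: with tie_word_embeddings false, the embedding matrix and lm_head are counted separately, and a float32 checkpoint stores 4 bytes per parameter. A quick sketch of the arithmetic (all values taken from config.json and this index):

    vocab, hidden, inter, layers = 128256, 4096, 11008, 32

    per_layer = 4 * hidden * hidden          # q/k/v/o projections (32 heads x head_dim 128 = 4096)
    per_layer += 3 * hidden * inter          # gate/up/down projections
    per_layer += 2 * hidden                  # input and post-attention RMSNorm weights

    total = layers * per_layer + 2 * vocab * hidden + hidden  # + embed_tokens, lm_head, final norm
    assert total == 7_526_944_768            # "total_parameters"
    assert total * 4 == 30_107_779_072       # "total_size" in bytes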
modeling_trida.py ADDED
@@ -0,0 +1,785 @@
1
+ from typing import Callable, Optional, Union
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+ from functools import partial
8
+
9
+ from transformers.activations import ACT2FN
10
+ from transformers.cache_utils import Cache, DynamicCache
11
+ from transformers.generation import GenerationMixin
12
+ from transformers.integrations import use_kernel_forward_from_hub
13
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
14
+ from transformers.modeling_layers import GradientCheckpointingLayer
15
+ from transformers.modeling_outputs import (
16
+ BaseModelOutputWithPast,
17
+ CausalLMOutputWithPast,
18
+ )
19
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
20
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
21
+ from transformers.processing_utils import Unpack
22
+ from transformers.utils import auto_docstring, can_return_tuple, logging
23
+ from .configuration_trida import TridaConfig
24
+ from torch.nn.attention.flex_attention import flex_attention, create_block_mask
25
+ from einops import rearrange, repeat
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ @dataclass
31
+ class CausalLMOutputWithPastAndBlockCache(CausalLMOutputWithPast):
32
+ block_past_key_values: Optional[Cache] = None
33
+
34
+ @dataclass
35
+ class BaseModelOutputWithPastAndBlockCache(BaseModelOutputWithPast):
36
+ block_past_key_values: Optional[Cache] = None
37
+
38
+
39
+ @torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
40
+ def fused_flex_attention(q, k, v, mask=None):
41
+ return flex_attention(q, k, v, block_mask=mask, enable_gqa=True)
42
+
43
+ def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
44
+ """
45
+ Constructs the specialized block diffusion attention mask for training
46
+ composed of three masks:
47
+ - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
48
+ - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
49
+ - **Block Causal Mask (M_BC)**: Attention to update x0
50
+
51
+ Args:
52
+ b, h: Batch and head indices (ignored for mask logic).
53
+ q_idx, kv_idx: Query and Key indices.
54
+ n: Length of the noised sequence x_t; query/key indices >= n refer to the clean x0 tokens.
55
+ block_size: Defines the block structure.
56
+
57
+ Returns:
58
+ A boolean attention mask.
59
+ """
60
+ # Indicate whether token belongs to xt or x0
61
+ x0_flag_q = (q_idx >= n)
62
+ x0_flag_kv = (kv_idx >= n)
63
+
64
+ # Compute block indices
65
+ block_q = torch.where(x0_flag_q == 1,
66
+ (q_idx - n) // block_size,
67
+ q_idx // block_size)
68
+ block_kv = torch.where(x0_flag_kv == 1,
69
+ (kv_idx - n) // block_size,
70
+ kv_idx // block_size)
71
+
72
+ # **1. Block Diagonal Mask (M_BD) **
73
+ block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
74
+
75
+ # **2. Offset Block-Causal Mask (M_OBC) **
76
+ offset_block_causal = (
77
+ (block_q > block_kv)
78
+ & (x0_flag_kv == 1)
79
+ & (x0_flag_q == 0)
80
+ )
81
+
82
+ # **3. Block-Causal Mask (M_BC) **
83
+ block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
84
+
85
+ # **4. Combine Masks **
86
+ return block_diagonal | offset_block_causal | block_causal
87
+
88
+ def eval_block_diff_mask(q_idx, kv_idx, block_size=None):
89
+ # Compute block indices
90
+ block_q = q_idx // block_size
91
+ block_kv = kv_idx // block_size
92
+
93
+ return block_q >= block_kv
94
+
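For intuition, a tiny worked example of the training mask above (toy sizes, assuming block_diff_mask is imported from this file): with n = 4 noisy tokens followed by their 4 clean x0 tokens and block_size = 2, a noisy query attends to its own noised block plus the clean blocks strictly before it, while a clean query attends causally over clean blocks only.

    import torch

    n, bs = 4, 2
    idx = torch.arange(2 * n)
    mask = block_diff_mask(None, None, idx[:, None], idx[None, :], block_size=bs, n=n)

    # noisy query 2 (x_t, block 1): own noised block {2, 3} plus clean block 0 {4, 5}
    assert mask[2].nonzero().flatten().tolist() == [2, 3, 4, 5]
    # clean query 6 (x0, block 1): clean blocks 0 and 1 only, i.e. {4, 5, 6, 7}
    assert mask[6].nonzero().flatten().tolist() == [4, 5, 6, 7]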
95
+ class TridaMLP(nn.Module):
96
+ def __init__(self, config):
97
+ super().__init__()
98
+ self.config = config
99
+ self.hidden_size = config.hidden_size
100
+ self.intermediate_size = config.intermediate_size
101
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
102
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
103
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
104
+ self.act_fn = ACT2FN[config.hidden_act]
105
+
106
+ def forward(self, x):
107
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
108
+ return down_proj
109
+
110
+
111
+ def rotate_half(x):
112
+ """Rotates half the hidden dims of the input."""
113
+ x1 = x[..., : x.shape[-1] // 2]
114
+ x2 = x[..., x.shape[-1] // 2 :]
115
+ return torch.cat((-x2, x1), dim=-1)
116
+
117
+
118
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
119
+ """Applies Rotary Position Embedding to the query and key tensors.
120
+
121
+ Args:
122
+ q (`torch.Tensor`): The query tensor.
123
+ k (`torch.Tensor`): The key tensor.
124
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
125
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
126
+ position_ids (`torch.Tensor`, *optional*):
127
+ Deprecated and unused.
128
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
129
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
130
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
131
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
132
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
133
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
134
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
135
+ Returns:
136
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
137
+ """
138
+ cos = cos.unsqueeze(unsqueeze_dim)
139
+ sin = sin.unsqueeze(unsqueeze_dim)
140
+ q_embed = (q * cos) + (rotate_half(q) * sin)
141
+ k_embed = (k * cos) + (rotate_half(k) * sin)
142
+ return q_embed, k_embed
143
+
144
+
145
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
146
+ """
147
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
148
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
149
+ """
150
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
151
+ if n_rep == 1:
152
+ return hidden_states
153
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
154
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
155
+
156
+
157
+ class TridaAttention(nn.Module):
158
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
159
+
160
+ def __init__(self, config: TridaConfig, layer_idx: int):
161
+ super().__init__()
162
+ self.config = config
163
+ self.layer_idx = layer_idx
164
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
165
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
166
+ self.scaling = self.head_dim**-0.5
167
+ self.attention_dropout = config.attention_dropout
168
+ self.is_causal = True
169
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
170
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
171
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
172
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
173
+ self.sliding_window = None
174
+
175
+ def forward(
176
+ self,
177
+ hidden_states: torch.Tensor,
178
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
179
+ attention_mask: Optional[torch.Tensor],
180
+ past_key_value: Optional[Cache] = None,
181
+ cache_position: Optional[torch.LongTensor] = None,
182
+ update_past_key_values: Optional[bool] = False,
183
+ block_past_key_values: Optional[Cache] = None,
184
+ replace_position: Optional[int] = None,
185
+ **kwargs: Unpack[FlashAttentionKwargs],
186
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
187
+ input_shape = hidden_states.shape[:-1]
188
+ hidden_shape = (*input_shape, -1, self.head_dim)
189
+
190
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
191
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
192
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
193
+
194
+ cos, sin = position_embeddings
195
+ # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
196
+ if self.training:
197
+ #split q into two parts
198
+ q_1 = query_states[:,:,:query_states.shape[2]//2]
199
+ q_2 = query_states[:,:,query_states.shape[2]//2:]
200
+ #split k into two parts
201
+ k_1 = key_states[:,:,:key_states.shape[2]//2]
202
+ k_2 = key_states[:,:,key_states.shape[2]//2:]
203
+ q_1, k_1 = apply_rotary_pos_emb(q_1, k_1, cos, sin)
204
+ q_2, k_2 = apply_rotary_pos_emb(q_2, k_2, cos, sin)
205
+ query_states = torch.cat((q_1, q_2), dim=-2)
206
+ key_states = torch.cat((k_1, k_2), dim=-2)
207
+ else:
208
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
209
+
210
+ if block_past_key_values is not None:
211
+ if len(block_past_key_values) <= self.layer_idx:
212
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
213
+ key_states, value_states = block_past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
214
+ else:
215
+ block_cache_key_states = block_past_key_values[self.layer_idx][0]
216
+ block_cache_value_states = block_past_key_values[self.layer_idx][1]
217
+
218
+ block_cache_key_states[:, :, replace_position:replace_position+key_states.shape[2]] = key_states
219
+ block_cache_value_states[:, :, replace_position:replace_position+value_states.shape[2]] = value_states
220
+ key_states = block_cache_key_states
221
+ value_states = block_cache_value_states
222
+
223
+ if past_key_value is not None:
224
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
225
+ if update_past_key_values:
226
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
227
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
228
+ elif len(past_key_value) > self.layer_idx:
229
+ key_states = torch.cat((past_key_value[self.layer_idx][0], key_states), dim=-2)
230
+ value_states = torch.cat((past_key_value[self.layer_idx][1], value_states), dim=-2)
231
+
232
+ if self.training:
233
+ attn_output = fused_flex_attention(query_states, key_states, value_states, mask=attention_mask)
234
+ attn_output = attn_output.transpose(1, 2).contiguous()
235
+ else:
236
+ attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
237
+
238
+ attn_output, attn_weights = attention_interface(
239
+ self,
240
+ query_states,
241
+ key_states,
242
+ value_states,
243
+ attention_mask,
244
+ is_causal=False,
245
+ dropout=0.0 if not self.training else self.attention_dropout,
246
+ scaling=self.scaling,
247
+ sliding_window=self.sliding_window, # main diff with Llama
248
+ **kwargs,
249
+ )
250
+
251
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
252
+ attn_output = self.o_proj(attn_output)
253
+ return attn_output
254
+
255
+ @use_kernel_forward_from_hub("RMSNorm")
256
+ class TridaRMSNorm(nn.Module):
257
+ def __init__(self, hidden_size, eps=1e-6):
258
+ """
259
+ TridaRMSNorm is equivalent to T5LayerNorm
260
+ """
261
+ super().__init__()
262
+ self.weight = nn.Parameter(torch.ones(hidden_size))
263
+ self.variance_epsilon = eps
264
+
265
+ def forward(self, hidden_states):
266
+ input_dtype = hidden_states.dtype
267
+ hidden_states = hidden_states.to(torch.float32)
268
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
269
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
270
+ return self.weight * hidden_states.to(input_dtype)
271
+
272
+ def extra_repr(self):
273
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
274
+
275
+
276
+ class TridaDecoderLayer(GradientCheckpointingLayer):
277
+ def __init__(self, config: TridaConfig, layer_idx: int):
278
+ super().__init__()
279
+ self.hidden_size = config.hidden_size
280
+
281
+ self.self_attn = TridaAttention(config=config, layer_idx=layer_idx)
282
+
283
+ self.mlp = TridaMLP(config)
284
+ self.input_layernorm = TridaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
285
+ self.post_attention_layernorm = TridaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
286
+ self.attention_type = config.layer_types[layer_idx]
287
+
288
+ def forward(
289
+ self,
290
+ hidden_states: torch.Tensor,
291
+ attention_mask: Optional[torch.Tensor] = None,
292
+ position_ids: Optional[torch.LongTensor] = None,
293
+ past_key_value: Optional[Cache] = None,
294
+ use_cache: Optional[bool] = False,
295
+ cache_position: Optional[torch.LongTensor] = None,
296
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
297
+ update_past_key_values: Optional[bool] = False,
298
+ use_block_cache: Optional[bool] = False,
299
+ block_past_key_values: Optional[Cache] = None,
300
+ replace_position: Optional[int] = None,
301
+ **kwargs
302
+ ) -> tuple[torch.Tensor]:
303
+ residual = hidden_states
304
+ hidden_states = self.input_layernorm(hidden_states)
305
+ # Self Attention
306
+ hidden_states = self.self_attn(
307
+ hidden_states=hidden_states,
308
+ attention_mask=attention_mask,
309
+ position_ids=position_ids,
310
+ past_key_value=past_key_value,
311
+ use_cache=use_cache,
312
+ cache_position=cache_position,
313
+ position_embeddings=position_embeddings,
314
+ update_past_key_values=update_past_key_values,
315
+ use_block_cache=use_block_cache,
316
+ block_past_key_values=block_past_key_values,
317
+ replace_position=replace_position,
318
+ **kwargs,
319
+ )
320
+ hidden_states = residual + hidden_states
321
+
322
+ # Fully Connected
323
+ residual = hidden_states
324
+ hidden_states = self.post_attention_layernorm(hidden_states)
325
+ hidden_states = self.mlp(hidden_states)
326
+ hidden_states = residual + hidden_states
327
+ return hidden_states
328
+
329
+
330
+
331
+ class TridaPreTrainedModel(PreTrainedModel):
332
+ config_class = TridaConfig
333
+ base_model_prefix = "model"
334
+ supports_gradient_checkpointing = True
335
+ _no_split_modules = ["TridaDecoderLayer"]
336
+ _skip_keys_device_placement = ["past_key_values"]
337
+ _supports_flash_attn_2 = True
338
+ _supports_sdpa = True
339
+ _supports_flex_attn = True
340
+ _supports_cache_class = True
341
+ _supports_quantized_cache = True
342
+ _supports_static_cache = True
343
+ _supports_attention_backend = True
344
+ _can_record_outputs = {
345
+ "hidden_states": TridaDecoderLayer,
346
+ "attentions": TridaAttention,
347
+ }
348
+
349
+ def _init_weights(self, module):
350
+ std = self.config.initializer_range
351
+ if isinstance(module, nn.Linear):
352
+ module.weight.data.normal_(mean=0.0, std=std)
353
+ if module.bias is not None:
354
+ module.bias.data.zero_()
355
+ elif isinstance(module, nn.Embedding):
356
+ module.weight.data.normal_(mean=0.0, std=std)
357
+ if module.padding_idx is not None:
358
+ module.weight.data[module.padding_idx].zero_()
359
+ elif isinstance(module, TridaRMSNorm):
360
+ module.weight.data.fill_(1.0)
361
+
362
+
363
+ class TridaRotaryEmbedding(nn.Module):
364
+ def __init__(self, config: TridaConfig, device=None):
365
+ super().__init__()
366
+ # BC: "rope_type" was originally "type"
367
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
368
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
369
+ else:
370
+ self.rope_type = "default"
371
+ self.max_seq_len_cached = config.max_position_embeddings
372
+ self.original_max_seq_len = config.max_position_embeddings
373
+
374
+ self.config = config
375
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
376
+
377
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
378
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
379
+ self.original_inv_freq = self.inv_freq
380
+
381
+ @torch.no_grad()
382
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
383
+ def forward(self, x, position_ids):
384
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
385
+ position_ids_expanded = position_ids[:, None, :].float()
386
+
387
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
388
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
389
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
390
+ emb = torch.cat((freqs, freqs), dim=-1)
391
+ cos = emb.cos() * self.attention_scaling
392
+ sin = emb.sin() * self.attention_scaling
393
+
394
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
395
+
396
+
397
+
398
+ class TridaModel(TridaPreTrainedModel):
399
+ def __init__(self, config: TridaConfig):
400
+ super().__init__(config)
401
+ self.padding_idx = config.pad_token_id
402
+ self.vocab_size = config.vocab_size
403
+ self.bd_size = config.bd_size
404
+
405
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
406
+ self.layers = nn.ModuleList(
407
+ [TridaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
408
+ )
409
+ self.norm = TridaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
410
+ self.rotary_emb = TridaRotaryEmbedding(config=config)
411
+ self.gradient_checkpointing = True
412
+
413
+ # Initialize weights and apply final processing
414
+ self.post_init()
415
+
416
+ def get_input_embeddings(self):
417
+ return self.embed_tokens
418
+
419
+ def set_input_embeddings(self, value):
420
+ self.embed_tokens = value
421
+
422
+
423
+ def eval_mask(self, seqlen, block_size, cache_seq_len):
424
+ q_indices = torch.arange(seqlen) + cache_seq_len
425
+ k_indices = torch.arange(seqlen + cache_seq_len)
426
+ mask = eval_block_diff_mask(
427
+ q_idx=q_indices[:, None],
428
+ kv_idx=k_indices[None, :],
429
+ block_size=block_size
430
+ )
431
+ return mask
432
+
433
+ def gen_mask(self, seqlen, block_size, B, H):
434
+ mask = create_block_mask(
435
+ partial(block_diff_mask, block_size=block_size, n=seqlen),
436
+ B=B, H=H, Q_LEN=seqlen*2, KV_LEN=seqlen*2)
437
+
438
+ return mask
439
+
440
+ def forward(
441
+ self,
442
+ input_ids: Optional[torch.LongTensor] = None,
443
+ labels: Optional[torch.LongTensor] = None,
444
+ attention_mask: Optional[torch.Tensor] = None,
445
+ position_ids: Optional[torch.LongTensor] = None,
446
+ past_key_values: Optional[Cache] = None,
447
+ inputs_embeds: Optional[torch.FloatTensor] = None,
448
+ use_cache: Optional[bool] = None,
449
+ cache_position: Optional[torch.LongTensor] = None,
450
+ update_past_key_values: Optional[bool] = False,
451
+ block_size: Optional[int] = 32,
452
+ use_block_cache: Optional[bool] = False,
453
+ block_past_key_values: Optional[Cache] = None,
454
+ replace_position: Optional[int] = None,
455
+ **kwargs
456
+ ) -> BaseModelOutputWithPast:
457
+ if (input_ids is None) ^ (inputs_embeds is not None):
458
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
459
+
460
+ if inputs_embeds is None:
461
+ inputs_embeds = self.embed_tokens(input_ids)
462
+
463
+ if use_cache and past_key_values is None:
464
+ past_key_values = DynamicCache()
465
+
466
+ if use_block_cache and block_past_key_values is None:
467
+ block_past_key_values = DynamicCache()
468
+
469
+ if cache_position is None:
470
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
471
+ if self.training:
472
+ cache_position = torch.arange(
473
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]//2, device=inputs_embeds.device
474
+ )
475
+ else:
476
+ if use_block_cache:
477
+ block_start_position = past_seen_tokens+replace_position if replace_position is not None else past_seen_tokens
478
+ cache_position = torch.arange(
479
+ block_start_position, block_start_position + inputs_embeds.shape[1], device=inputs_embeds.device
480
+ )
481
+ else:
482
+ cache_position = torch.arange(
483
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] if not self.training else inputs_embeds.shape[1]//2, device=inputs_embeds.device
484
+ )
485
+
486
+ if position_ids is None:
487
+ position_ids = cache_position.unsqueeze(0)
488
+
489
+ if self.training:
490
+ attention_mask = self.gen_mask(labels.shape[1], self.bd_size, labels.shape[0], self.config.num_attention_heads).to(device=inputs_embeds.device)
491
+ else:
492
+ if use_block_cache and block_past_key_values.get_seq_length() != 0:
493
+ attention_mask = None
494
+ else:
495
+ attention_mask = self.eval_mask(input_ids.shape[1], block_size, past_key_values.get_seq_length() if past_key_values is not None else 0).to(device=inputs_embeds.device)
496
+
497
+ hidden_states = inputs_embeds
498
+
499
+ # create position embeddings to be shared across the decoder layers
500
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
501
+
502
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
503
+ hidden_states = decoder_layer(
504
+ hidden_states,
505
+ attention_mask=attention_mask,
506
+ position_ids=position_ids,
507
+ past_key_value=past_key_values,
508
+ use_cache=use_cache,
509
+ cache_position=cache_position,
510
+ position_embeddings=position_embeddings,
511
+ update_past_key_values=update_past_key_values,
512
+ use_block_cache=use_block_cache,
513
+ block_past_key_values=block_past_key_values,
514
+ replace_position=replace_position,
515
+ **kwargs,
516
+ )
517
+
518
+ hidden_states = self.norm(hidden_states)
519
+ return BaseModelOutputWithPastAndBlockCache(
520
+ last_hidden_state=hidden_states,
521
+ past_key_values=past_key_values if use_cache else None,
522
+ block_past_key_values=block_past_key_values if use_block_cache else None,
523
+ )
524
+
525
+
526
+ class TridaForDLM(TridaPreTrainedModel, GenerationMixin):
527
+ _tied_weights_keys = ["lm_head.weight"]
528
+ _tp_plan = {"lm_head": "colwise_rep"}
529
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
530
+
531
+ def __init__(self, config):
532
+ super().__init__(config)
533
+ self.model = TridaModel(config)
534
+ self.vocab_size = config.vocab_size
535
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
536
+
537
+ # Initialize weights and apply final processing
538
+ self.post_init()
539
+
540
+ def get_input_embeddings(self):
541
+ return self.model.embed_tokens
542
+
543
+ def set_input_embeddings(self, value):
544
+ self.model.embed_tokens = value
545
+
546
+ def get_output_embeddings(self):
547
+ return self.lm_head
548
+
549
+ def set_output_embeddings(self, new_embeddings):
550
+ self.lm_head = new_embeddings
551
+
552
+ def set_decoder(self, decoder):
553
+ self.model = decoder
554
+
555
+ def get_decoder(self):
556
+ return self.model
557
+
558
+ @can_return_tuple
559
+ def forward(
560
+ self,
561
+ input_ids: Optional[torch.LongTensor] = None,
562
+ attention_mask: Optional[torch.Tensor] = None,
563
+ position_ids: Optional[torch.LongTensor] = None,
564
+ past_key_values: Optional[Cache] = None,
565
+ inputs_embeds: Optional[torch.FloatTensor] = None,
566
+ labels: Optional[torch.LongTensor] = None,
567
+ use_cache: Optional[bool] = None,
568
+ cache_position: Optional[torch.LongTensor] = None,
569
+ logits_to_keep: Union[int, torch.Tensor] = 0,
570
+ update_past_key_values: Optional[bool] = False,
571
+ block_size: Optional[int] = 32,
572
+ use_block_cache: Optional[bool] = False,
573
+ block_past_key_values: Optional[Cache] = None,
574
+ replace_position: Optional[int] = None,
575
+ mask_id: Optional[int] = 128012,
576
+ **kwargs
577
+ ) -> CausalLMOutputWithPastAndBlockCache:
578
+
579
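+ # Training applies the discrete-diffusion forward process: the batch is duplicated with complementary
+ # mask patterns, so each supervised position is masked (and scored) in exactly one of the two copies,
+ # and every copy is the concatenation [noisy sequence, clean sequence] along the sequence dimension.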
+ if self.training:
580
+ original_labels = labels.clone()
581
+ original_input_ids = input_ids.clone()
582
+
583
+ noisy_input_ids = input_ids.clone()
584
+
585
+ input_ids = input_ids.reshape(input_ids.shape[0] * input_ids.shape[1] // self.model.bd_size, self.model.bd_size)
586
+ b, l = input_ids.shape
587
+ t = torch.rand((b,), device=input_ids.device)
588
+ eps = 1e-3
589
+ p_mask = (1 - eps) * t + eps
590
+ p_mask = p_mask[:, None].repeat(1, l)
591
+
592
+ mask_indices = torch.rand((b, l), device=input_ids.device) < p_mask
593
+ x_t = torch.where(mask_indices, mask_id, input_ids).reshape(labels.shape)
594
+ noisy_input_ids[labels != -100] = x_t[labels != -100]
595
+ mask = (noisy_input_ids != mask_id)
596
+ labels[mask] = -100
597
+ input_ids = torch.cat([noisy_input_ids, input_ids.reshape(labels.shape)], dim=1)
598
+
599
+ complementary_noisy_input_ids = original_input_ids.clone()
600
+ complementary_labels = original_labels.clone()
601
+
602
+ complementary_input_ids = original_input_ids.reshape(original_input_ids.shape[0] * original_input_ids.shape[1] // self.model.bd_size, self.model.bd_size)
603
+
604
+ complementary_mask_indices = ~mask_indices
605
+ complementary_x_t = torch.where(complementary_mask_indices, mask_id, complementary_input_ids).reshape(labels.shape)
606
+ complementary_noisy_input_ids[complementary_labels != -100] = complementary_x_t[complementary_labels != -100]
607
+ complementary_mask = (complementary_noisy_input_ids != mask_id)
608
+ complementary_labels[complementary_mask] = -100
609
+ complementary_input_ids = torch.cat([complementary_noisy_input_ids, complementary_input_ids.reshape(complementary_labels.shape)], dim=1)
610
+
611
+ input_ids = torch.cat([input_ids, complementary_input_ids], dim=0)
612
+ labels = torch.cat([labels, complementary_labels], dim=0)
613
+
614
+ outputs: BaseModelOutputWithPastAndBlockCache = self.model(
615
+ input_ids=input_ids,
616
+ labels=labels,
617
+ attention_mask=attention_mask,
618
+ position_ids=position_ids,
619
+ past_key_values=past_key_values,
620
+ inputs_embeds=inputs_embeds,
621
+ use_cache=use_cache,
622
+ cache_position=cache_position,
623
+ update_past_key_values=update_past_key_values,
624
+ block_size=block_size,
625
+ use_block_cache=use_block_cache,
626
+ block_past_key_values=block_past_key_values,
627
+ replace_position=replace_position,
628
+ **kwargs,
629
+ )
630
+
631
+ hidden_states = outputs.last_hidden_state
632
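+ # Training inputs are [noisy, clean] concatenated along the sequence dimension, so only the first (noisy) half is kept for the loss.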
+ if self.training:
633
+ hidden_states = hidden_states[:, :hidden_states.shape[1]//2, :]
634
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
635
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
636
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
637
+
638
+ loss = None
639
+ if labels is not None:
640
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
641
+
642
+ return CausalLMOutputWithPastAndBlockCache(
643
+ loss=loss,
644
+ logits=logits,
645
+ past_key_values=outputs.past_key_values,
646
+ hidden_states=outputs.hidden_states,
647
+ attentions=outputs.attentions,
648
+ block_past_key_values=outputs.block_past_key_values,
649
+ )
650
+
651
+ @torch.no_grad()
652
+ def generate(
653
+ self,
654
+ input_ids,
655
+ max_new_tokens,
656
+ mask_id=128012,
657
+ threshold=1,
658
+ small_block_size=8,
659
+ block_size=32,
660
+ stop_token=128001,
661
+ stopping_criteria=None,
662
+ top_p=0.95,
663
+ temperature=0,
664
+ use_block_cache=False,
665
+ **kwargs
666
+ ):
667
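+ # Block-wise diffusion decoding: tokens are produced block_size at a time; within a block, small_block_size
+ # spans are iteratively unmasked, highest-confidence tokens first, until no mask_id tokens remain.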
+ num_blocks = max_new_tokens // block_size
668
+ original_input_length = input_ids.shape[1]
669
+
670
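+ # Prefill: run whole blocks of the prompt through the model to populate the KV cache; when the prompt length
+ # is an exact multiple of block_size, also take the greedy first token of the next block.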
+ if input_ids.shape[1] > block_size:
671
+ output = self.forward(input_ids=input_ids[:, :(input_ids.shape[1] // block_size * block_size)], use_cache=True, update_past_key_values=True, block_size=block_size)
672
+ logits, past_key_values = output.logits, output.past_key_values
673
+ if input_ids.shape[1] % block_size == 0:
674
+ next_token = logits[:, -1:, :].argmax(dim=-1)
675
+ input_ids = torch.cat([input_ids, next_token], dim=1)
676
+ else:
677
+ past_key_values = None
678
+
679
+ num_small_blocks = block_size // small_block_size
680
+
681
+ for block_idx in range(num_blocks):
682
+ if stop_token in input_ids[:, original_input_length:]:
683
+ break
684
+ prompt_length = input_ids.shape[1]
685
+ # Initialize x_init with mask_id
686
+ x_init = mask_id * torch.ones((input_ids.shape[0], block_size - prompt_length % block_size), device=self.device, dtype=torch.long)
687
+ x_init = torch.cat([input_ids, x_init], dim=1)
688
+
689
+ x_t = x_init.clone()
690
+ block_past_key_values = None
691
+ while True:
692
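+ # Leave the refinement loop once a stop_token appears with no mask_id tokens left before it
+ # (the nonzero() indexing below assumes batch size 1).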
+ if stop_token in x_t[:, prompt_length:]:
693
+ stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
694
+ if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
695
+ break
696
+ mask_idx = (x_t[:, -block_size:] == mask_id)
697
+ # Decode a complete block, update cache, and generate the next token
698
+ if mask_idx.sum() == 0:
699
+ output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=True, block_size=block_size)
700
+ logits, past_key_values = output.logits, output.past_key_values
701
+ next_token = logits[:, -1:, :].argmax(dim=-1)
702
+ x_t = torch.cat([x_t, next_token], dim=1)
703
+ break
704
+ for small_block_idx in range(num_small_blocks):
705
+ small_block_start_idx = small_block_idx * small_block_size
706
+ small_block_end_idx = small_block_start_idx + small_block_size
707
+
708
+ start = -block_size + small_block_start_idx
709
+ end = None if block_size == small_block_end_idx else -block_size + small_block_end_idx
710
+ while True:
711
+ mask_idx = (x_t[:, -block_size:] == mask_id)
712
+ if mask_idx[:, start:end].sum() == 0:
713
+ break
714
+ if stop_token in x_t[:, prompt_length:]:
715
+ stop_token_idx = (x_t[:, prompt_length:] == stop_token).nonzero()[0][1]
716
+ if (x_t[:, prompt_length:prompt_length+stop_token_idx] == mask_id).sum() == 0:
717
+ break
718
+
719
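+ # The model emits next-token predictions, so the logits are shifted right by one below (torch.cat)
+ # to align prediction i with position i inside the block.
+ # With use_block_cache, a full-block pass builds the block cache; later passes re-run only the current
+ # small block and overwrite its cached entries via replace_position.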
+ if use_block_cache:
720
+ if block_past_key_values is None or (x_t[:, -block_size+small_block_start_idx] == mask_id).any():
721
+ output = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True)
722
+ logits, block_past_key_values = output.logits, output.block_past_key_values
723
+ logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
724
+ logits = logits[:, start:end]
725
+ else:
726
+ logits = self.forward(input_ids=x_t[:,start:end], use_cache=True, past_key_values=past_key_values, update_past_key_values=False, use_block_cache=True, block_past_key_values=block_past_key_values, replace_position=small_block_start_idx).logits
727
+ logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
728
+ else:
729
+ logits = self.forward(input_ids=x_t[:, -block_size:], use_cache=True, past_key_values=past_key_values, update_past_key_values=False).logits
730
+ logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
731
+ logits = logits[:, start:end]
732
+
733
+
734
+ x_1, p_1t = self.sample_with_top_p(logits, top_p=top_p, temperature=temperature)
735
+ # Unmask tokens whose predicted probability exceeds threshold; the highest-probability masked token is always unmasked so each pass makes progress
736
+ x1_p = torch.squeeze(torch.gather(p_1t, dim=-1, index=torch.unsqueeze(x_1, -1)), -1)
737
+ x1_p = torch.where(mask_idx[:, start:end], x1_p, -torch.inf)
738
+
739
+ unmask_idx = (x1_p > threshold)
740
+ max_prob_idx = x1_p.argmax(dim=-1)
741
+ unmask_idx[torch.arange(x_1.shape[0]), max_prob_idx] = True
742
+ unmask_idx = unmask_idx & mask_idx[:, start:end]
743
+
744
+ x_t[:, start:end][unmask_idx] = x_1[unmask_idx]
745
+
746
+ input_ids = x_t
747
+ # Truncate the output just after the first stop_token
748
+ if stop_token in input_ids[:, original_input_length:]:
749
+ stop_token_idx = (input_ids[:, original_input_length:] == stop_token).nonzero()[0][1]
750
+ input_ids = input_ids[:, :stop_token_idx+original_input_length+1]
751
+ return input_ids
752
+
753
+ def sample_with_top_p(self, logits, top_p=0.95, temperature=1.0):
754
+ # Calculate probabilities
755
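+ # temperature == 0 short-circuits to greedy decoding: argmax over the plain softmax, with no top-p filtering.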
+ if temperature > 0:
756
+ scaled_logits = logits / temperature
757
+ else:
758
+ p_1t = torch.softmax(logits, dim=-1)
759
+ x_1 = p_1t.argmax(dim=-1)
760
+ return x_1, p_1t
761
+
762
+ probs = F.softmax(scaled_logits, dim=-1)
763
+
764
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True)
765
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
766
+
767
+ sorted_indices_to_remove = cumulative_probs > top_p
768
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
769
+ sorted_indices_to_remove[..., 0] = 0
770
+
771
+ indices_to_remove = torch.zeros_like(probs, dtype=torch.bool).scatter_(
772
+ dim=-1, index=sorted_indices, src=sorted_indices_to_remove
773
+ )
774
+
775
+ probs[indices_to_remove] = 0
776
+
777
+ # Renormalize so that the probabilities of remaining tokens sum to 1
778
+ # Top-p filtering always keeps the highest-probability token, so the sum is strictly positive
779
+ probs_sum = torch.sum(probs, dim=-1, keepdim=True)
780
+ normalized_probs = probs / probs_sum
781
+
782
+ p_1t = normalized_probs
783
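+ # Draw one token per position; indexing p_1t[0] assumes batch size 1, consistent with generate() above.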
+ x_1 = torch.multinomial(p_1t[0], num_samples=1).unsqueeze(0).squeeze(-1)
784
+
785
+ return x_1, p_1t
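
For reference, a minimal usage sketch of the class defined above (not part of the uploaded files; the repository id, prompt, and generation settings are placeholders, and a compatible tokenizer is assumed to ship with the checkpoint):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "<repo-id>"  # placeholder: replace with the actual Hub repository
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,  # the Trida model code lives inside the repo
).eval()

prompt = "Explain diffusion language models in one sentence."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids  # batch size 1, as generate() assumes

# Calls the block-diffusion generate() defined in this file, not GenerationMixin.generate()
output_ids = model.generate(
    input_ids,
    max_new_tokens=128,
    block_size=32,       # defaults shown in the code above
    small_block_size=8,
    temperature=0,       # 0 -> greedy unmasking
    use_block_cache=False,
)
print(tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True))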