MagistrTheOne commited on
Commit
509745d
·
verified ·
1 Parent(s): 33b6c6b

MURZIK-15B init weights

Browse files
config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "MurzikForCausalLM"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "bos_token_id": 1,
7
+ "dtype": "bfloat16",
8
+ "eos_token_id": 2,
9
+ "head_dim": 128,
10
+ "hidden_act": "silu",
11
+ "hidden_size": 5120,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 20480,
14
+ "max_position_embeddings": 8192,
15
+ "model_type": "murzik",
16
+ "num_attention_heads": 40,
17
+ "num_hidden_layers": 32,
18
+ "num_key_value_heads": 8,
19
+ "pad_token_id": null,
20
+ "rms_norm_eps": 1e-06,
21
+ "rope_theta": 1000000.0,
22
+ "tie_word_embeddings": true,
23
+ "transformers_version": "5.6.0",
24
+ "use_cache": true,
25
+ "use_qk_norm": true,
26
+ "vocab_size": 128256,
27
+ "auto_map": {
28
+ "AutoConfig": "murzik.configuration_murzik.MurzikConfig",
29
+ "AutoModelForCausalLM": "murzik.modeling_murzik.MurzikForCausalLM"
30
+ }
31
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:682eabba623085c2c33a5b59251d2033c476dcd432261bfac2a363e8752e71a9
3
+ size 25473255848
murzik/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
2
+
3
+ from .configuration_murzik import MurzikConfig
4
+ from .configuration_murzik_moe import MurzikMoeConfig
5
+ from .modeling_murzik import MurzikForCausalLM
6
+ from .modeling_murzik_moe import MurzikMoeForCausalLM
7
+ from .tokenization_murzik import MurzikTokenizer
8
+
9
+ AutoConfig.register("murzik", MurzikConfig)
10
+ AutoConfig.register("murzik_moe", MurzikMoeConfig)
11
+ AutoModelForCausalLM.register(MurzikConfig, MurzikForCausalLM)
12
+ AutoModelForCausalLM.register(MurzikMoeConfig, MurzikMoeForCausalLM)
13
+ AutoTokenizer.register(MurzikConfig, slow_tokenizer_class=MurzikTokenizer)
14
+
15
+ __all__ = [
16
+ "MurzikConfig",
17
+ "MurzikMoeConfig",
18
+ "MurzikForCausalLM",
19
+ "MurzikMoeForCausalLM",
20
+ "MurzikTokenizer",
21
+ ]
murzik/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.06 kB). View file
 
murzik/__pycache__/configuration_murzik.cpython-311.pyc ADDED
Binary file (2.35 kB). View file
 
murzik/__pycache__/configuration_murzik_moe.cpython-311.pyc ADDED
Binary file (3.01 kB). View file
 
murzik/__pycache__/modeling_murzik.cpython-311.pyc ADDED
Binary file (20.8 kB). View file
 
murzik/__pycache__/modeling_murzik_moe.cpython-311.pyc ADDED
Binary file (15.6 kB). View file
 
murzik/__pycache__/tokenization_murzik.cpython-311.pyc ADDED
Binary file (5.77 kB). View file
 
murzik/configuration_murzik.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Murzik dense config (pilot 1B/15B)."""
2
+
3
+ from transformers import PretrainedConfig
4
+
5
+
6
+ class MurzikConfig(PretrainedConfig):
7
+ model_type = "murzik"
8
+
9
+ def __init__(
10
+ self,
11
+ vocab_size: int = 128256,
12
+ hidden_size: int = 1536,
13
+ intermediate_size: int = 6144,
14
+ num_hidden_layers: int = 24,
15
+ num_attention_heads: int = 16,
16
+ num_key_value_heads: int = 4,
17
+ head_dim: int | None = None,
18
+ hidden_act: str = "silu",
19
+ max_position_embeddings: int = 32768,
20
+ initializer_range: float = 0.02,
21
+ rms_norm_eps: float = 1e-6,
22
+ use_cache: bool = True,
23
+ tie_word_embeddings: bool = True,
24
+ rope_theta: float = 1_000_000.0,
25
+ attention_dropout: float = 0.0,
26
+ use_qk_norm: bool = True,
27
+ pad_token_id: int | None = None,
28
+ bos_token_id: int = 1,
29
+ eos_token_id: int = 2,
30
+ **kwargs,
31
+ ):
32
+ self.vocab_size = vocab_size
33
+ self.hidden_size = hidden_size
34
+ self.intermediate_size = intermediate_size
35
+ self.num_hidden_layers = num_hidden_layers
36
+ self.num_attention_heads = num_attention_heads
37
+ self.num_key_value_heads = num_key_value_heads
38
+ self.head_dim = head_dim or hidden_size // num_attention_heads
39
+ self.hidden_act = hidden_act
40
+ self.max_position_embeddings = max_position_embeddings
41
+ self.initializer_range = initializer_range
42
+ self.rms_norm_eps = rms_norm_eps
43
+ self.use_cache = use_cache
44
+ self.rope_theta = rope_theta
45
+ self.attention_dropout = attention_dropout
46
+ self.use_qk_norm = use_qk_norm
47
+ super().__init__(
48
+ pad_token_id=pad_token_id,
49
+ bos_token_id=bos_token_id,
50
+ eos_token_id=eos_token_id,
51
+ tie_word_embeddings=tie_word_embeddings,
52
+ **kwargs,
53
+ )
murzik/configuration_murzik_moe.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MurzikMoE config (32B/64B/100B)."""
2
+
3
+ from transformers import PretrainedConfig
4
+
5
+
6
+ class MurzikMoeConfig(PretrainedConfig):
7
+ model_type = "murzik_moe"
8
+
9
+ def __init__(
10
+ self,
11
+ vocab_size: int = 128256,
12
+ hidden_size: int = 2560,
13
+ intermediate_size: int = 9728,
14
+ num_hidden_layers: int = 40,
15
+ num_attention_heads: int = 32,
16
+ num_key_value_heads: int = 8,
17
+ head_dim: int | None = None,
18
+ hidden_act: str = "silu",
19
+ max_position_embeddings: int = 32768,
20
+ initializer_range: float = 0.006,
21
+ rms_norm_eps: float = 1e-6,
22
+ use_cache: bool = True,
23
+ tie_word_embeddings: bool = True,
24
+ rope_theta: float = 1_000_000.0,
25
+ attention_dropout: float = 0.0,
26
+ use_qk_norm: bool = True,
27
+ decoder_sparse_step: int = 1,
28
+ moe_intermediate_size: int = 2432,
29
+ num_experts: int = 96,
30
+ num_experts_per_tok: int = 6,
31
+ num_shared_experts: int = 2,
32
+ first_k_dense_replace: int = 2,
33
+ router_aux_loss_coef: float = 0.001,
34
+ expert_bias_update_speed: float = 0.001,
35
+ pad_token_id: int | None = None,
36
+ bos_token_id: int = 1,
37
+ eos_token_id: int = 2,
38
+ **kwargs,
39
+ ):
40
+ self.vocab_size = vocab_size
41
+ self.hidden_size = hidden_size
42
+ self.intermediate_size = intermediate_size
43
+ self.num_hidden_layers = num_hidden_layers
44
+ self.num_attention_heads = num_attention_heads
45
+ self.num_key_value_heads = num_key_value_heads
46
+ self.head_dim = head_dim or hidden_size // num_attention_heads
47
+ self.hidden_act = hidden_act
48
+ self.max_position_embeddings = max_position_embeddings
49
+ self.initializer_range = initializer_range
50
+ self.rms_norm_eps = rms_norm_eps
51
+ self.use_cache = use_cache
52
+ self.rope_theta = rope_theta
53
+ self.attention_dropout = attention_dropout
54
+ self.use_qk_norm = use_qk_norm
55
+ self.decoder_sparse_step = decoder_sparse_step
56
+ self.moe_intermediate_size = moe_intermediate_size
57
+ self.num_experts = num_experts
58
+ self.num_experts_per_tok = num_experts_per_tok
59
+ self.num_shared_experts = num_shared_experts
60
+ self.first_k_dense_replace = first_k_dense_replace
61
+ self.router_aux_loss_coef = router_aux_loss_coef
62
+ self.expert_bias_update_speed = expert_bias_update_speed
63
+ super().__init__(
64
+ pad_token_id=pad_token_id,
65
+ bos_token_id=bos_token_id,
66
+ eos_token_id=eos_token_id,
67
+ tie_word_embeddings=tie_word_embeddings,
68
+ **kwargs,
69
+ )
murzik/modeling_murzik.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Murzik dense decoder (pilot). GQA + RoPE + SwiGLU + RMSNorm."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ from typing import Optional
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+ from transformers import PreTrainedModel
12
+ from transformers.modeling_outputs import CausalLMOutputWithPast
13
+ from transformers.utils import logging
14
+
15
+ from .configuration_murzik import MurzikConfig
16
+
17
+ logger = logging.get_logger(__name__)
18
+
19
+
20
+ class MurzikRMSNorm(nn.Module):
21
+ def __init__(self, hidden_size: int, eps: float = 1e-6):
22
+ super().__init__()
23
+ self.weight = nn.Parameter(torch.ones(hidden_size))
24
+ self.variance_epsilon = eps
25
+
26
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
27
+ input_dtype = hidden_states.dtype
28
+ hidden_states = hidden_states.to(torch.float32)
29
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
30
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
31
+ return self.weight * hidden_states.to(input_dtype)
32
+
33
+
34
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
35
+ x1, x2 = x.chunk(2, dim=-1)
36
+ return torch.cat((-x2, x1), dim=-1)
37
+
38
+
39
+ def apply_rotary_pos_emb(q, k, cos, sin):
40
+ q_embed = (q * cos) + (rotate_half(q) * sin)
41
+ k_embed = (k * cos) + (rotate_half(k) * sin)
42
+ return q_embed, k_embed
43
+
44
+
45
+ class MurzikRotaryEmbedding(nn.Module):
46
+ def __init__(self, dim: int, max_position_embeddings: int, base: float, device=None):
47
+ super().__init__()
48
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
49
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
50
+ self.max_seq_len_cached = max_position_embeddings
51
+ t = torch.arange(max_position_embeddings, device=device, dtype=torch.int64).type_as(self.inv_freq)
52
+ freqs = torch.outer(t, self.inv_freq)
53
+ emb = torch.cat((freqs, freqs), dim=-1)
54
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
55
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
56
+
57
+ def forward(self, x: torch.Tensor, seq_len: int):
58
+ return (
59
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
60
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
61
+ )
62
+
63
+
64
+ class MurzikMLP(nn.Module):
65
+ def __init__(self, config: MurzikConfig):
66
+ super().__init__()
67
+ self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
68
+ self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
69
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
70
+
71
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
72
+ return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
73
+
74
+
75
+ class MurzikAttention(nn.Module):
76
+ def __init__(self, config: MurzikConfig, layer_idx: int):
77
+ super().__init__()
78
+ self.layer_idx = layer_idx
79
+ self.hidden_size = config.hidden_size
80
+ self.num_heads = config.num_attention_heads
81
+ self.num_kv_heads = config.num_key_value_heads
82
+ self.head_dim = config.head_dim
83
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
84
+
85
+ self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
86
+ self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
87
+ self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
88
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
89
+ self.q_norm = MurzikRMSNorm(self.head_dim, eps=config.rms_norm_eps) if config.use_qk_norm else None
90
+ self.k_norm = MurzikRMSNorm(self.head_dim, eps=config.rms_norm_eps) if config.use_qk_norm else None
91
+ self.dropout = nn.Dropout(config.attention_dropout)
92
+
93
+ def forward(
94
+ self,
95
+ hidden_states: torch.Tensor,
96
+ attention_mask: Optional[torch.Tensor],
97
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
98
+ past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
99
+ use_cache: bool = False,
100
+ ):
101
+ bsz, q_len, _ = hidden_states.size()
102
+ q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
103
+ k = self.k_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
104
+ v = self.v_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
105
+
106
+ if self.q_norm is not None:
107
+ q = self.q_norm(q)
108
+ if self.k_norm is not None:
109
+ k = self.k_norm(k)
110
+
111
+ cos, sin = position_embeddings
112
+ q, k = apply_rotary_pos_emb(q, k, cos, sin)
113
+
114
+ if past_key_value is not None:
115
+ k = torch.cat([past_key_value[0], k], dim=2)
116
+ v = torch.cat([past_key_value[1], v], dim=2)
117
+ past = (k, v) if use_cache else None
118
+
119
+ k = k.repeat_interleave(self.num_kv_groups, dim=1)
120
+ v = v.repeat_interleave(self.num_kv_groups, dim=1)
121
+
122
+ attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
123
+ if attention_mask is not None:
124
+ attn_weights = attn_weights + attention_mask
125
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
126
+ attn_weights = self.dropout(attn_weights)
127
+ attn_output = torch.matmul(attn_weights, v)
128
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
129
+ return self.o_proj(attn_output), past
130
+
131
+
132
+ class MurzikDecoderLayer(nn.Module):
133
+ def __init__(self, config: MurzikConfig, layer_idx: int):
134
+ super().__init__()
135
+ self.self_attn = MurzikAttention(config, layer_idx)
136
+ self.mlp = MurzikMLP(config)
137
+ self.input_layernorm = MurzikRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
138
+ self.post_attention_layernorm = MurzikRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
139
+
140
+ def forward(self, hidden_states, attention_mask, position_embeddings, past_key_value=None, use_cache=False):
141
+ residual = hidden_states
142
+ hidden_states = self.input_layernorm(hidden_states)
143
+ hidden_states, present = self.self_attn(
144
+ hidden_states, attention_mask, position_embeddings, past_key_value, use_cache
145
+ )
146
+ hidden_states = residual + hidden_states
147
+ residual = hidden_states
148
+ hidden_states = self.post_attention_layernorm(hidden_states)
149
+ hidden_states = residual + self.mlp(hidden_states)
150
+ return hidden_states, present
151
+
152
+
153
+ class MurzikPreTrainedModel(PreTrainedModel):
154
+ config_class = MurzikConfig
155
+ base_model_prefix = "model"
156
+ supports_gradient_checkpointing = True
157
+ _no_split_modules = ["MurzikDecoderLayer"]
158
+
159
+ def _init_weights(self, module):
160
+ std = self.config.initializer_range
161
+ if isinstance(module, nn.Linear):
162
+ module.weight.data.normal_(mean=0.0, std=std)
163
+ if module.bias is not None:
164
+ module.bias.data.zero_()
165
+ elif isinstance(module, nn.Embedding):
166
+ module.weight.data.normal_(mean=0.0, std=std)
167
+
168
+
169
+ class MurzikModel(MurzikPreTrainedModel):
170
+ def __init__(self, config: MurzikConfig):
171
+ super().__init__(config)
172
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
173
+ self.layers = nn.ModuleList([MurzikDecoderLayer(config, i) for i in range(config.num_hidden_layers)])
174
+ self.norm = MurzikRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
175
+ self.rotary_emb = MurzikRotaryEmbedding(
176
+ config.head_dim, config.max_position_embeddings, config.rope_theta
177
+ )
178
+ self.gradient_checkpointing = False
179
+ self.post_init()
180
+
181
+ def forward(
182
+ self,
183
+ input_ids: torch.LongTensor,
184
+ attention_mask: Optional[torch.Tensor] = None,
185
+ past_key_values: Optional[list] = None,
186
+ use_cache: bool = False,
187
+ **kwargs,
188
+ ):
189
+ bsz, seq_len = input_ids.shape
190
+ hidden_states = self.embed_tokens(input_ids)
191
+ cos, sin = self.rotary_emb(hidden_states, seq_len)
192
+ position_embeddings = (cos, sin)
193
+
194
+ if attention_mask is None:
195
+ attention_mask = torch.triu(
196
+ torch.full((seq_len, seq_len), float("-inf"), device=input_ids.device),
197
+ diagonal=1,
198
+ ).unsqueeze(0).unsqueeze(0)
199
+ else:
200
+ attention_mask = attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
201
+ attention_mask = (1.0 - attention_mask) * torch.finfo(hidden_states.dtype).min
202
+
203
+ presents = [] if use_cache else None
204
+ for idx, layer in enumerate(self.layers):
205
+ past = past_key_values[idx] if past_key_values is not None else None
206
+ if self.gradient_checkpointing and self.training:
207
+ hidden_states, present = self._checkpoint_layer(
208
+ layer, hidden_states, attention_mask, position_embeddings, past, use_cache
209
+ )
210
+ else:
211
+ hidden_states, present = layer(
212
+ hidden_states, attention_mask, position_embeddings, past, use_cache
213
+ )
214
+ if use_cache:
215
+ presents.append(present)
216
+
217
+ hidden_states = self.norm(hidden_states)
218
+ return hidden_states, presents
219
+
220
+ def _checkpoint_layer(self, layer, hidden_states, attention_mask, position_embeddings, past, use_cache):
221
+ def custom_forward(hs):
222
+ out, pr = layer(hs, attention_mask, position_embeddings, past, use_cache)
223
+ return out, pr
224
+
225
+ return torch.utils.checkpoint.checkpoint(custom_forward, hidden_states, use_reentrant=False)
226
+
227
+
228
+ class MurzikForCausalLM(MurzikPreTrainedModel):
229
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
230
+
231
+ def __init__(self, config: MurzikConfig):
232
+ super().__init__(config)
233
+ self.model = MurzikModel(config)
234
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
235
+ self.post_init()
236
+
237
+ def get_input_embeddings(self):
238
+ return self.model.embed_tokens
239
+
240
+ def set_input_embeddings(self, value):
241
+ self.model.embed_tokens = value
242
+
243
+ def get_output_embeddings(self):
244
+ return self.lm_head
245
+
246
+ def set_output_embeddings(self, new_embeddings):
247
+ self.lm_head = new_embeddings
248
+
249
+ def forward(
250
+ self,
251
+ input_ids: torch.LongTensor,
252
+ attention_mask: Optional[torch.Tensor] = None,
253
+ labels: Optional[torch.LongTensor] = None,
254
+ past_key_values: Optional[list] = None,
255
+ use_cache: bool = False,
256
+ **kwargs,
257
+ ) -> CausalLMOutputWithPast:
258
+ hidden_states, past_key_values = self.model(
259
+ input_ids=input_ids,
260
+ attention_mask=attention_mask,
261
+ past_key_values=past_key_values,
262
+ use_cache=use_cache,
263
+ )
264
+ logits = self.lm_head(hidden_states)
265
+
266
+ loss = None
267
+ if labels is not None:
268
+ shift_logits = logits[..., :-1, :].contiguous()
269
+ shift_labels = labels[..., 1:].contiguous()
270
+ loss = F.cross_entropy(
271
+ shift_logits.view(-1, shift_logits.size(-1)),
272
+ shift_labels.view(-1),
273
+ ignore_index=-100,
274
+ )
275
+
276
+ return CausalLMOutputWithPast(
277
+ loss=loss,
278
+ logits=logits,
279
+ past_key_values=past_key_values,
280
+ )
murzik/modeling_murzik_moe.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MurzikMoE — sparse MoE FFN on top of Murzik decoder blocks."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+ from transformers.modeling_outputs import CausalLMOutputWithPast
11
+
12
+ from .configuration_murzik_moe import MurzikMoeConfig
13
+ from .modeling_murzik import (
14
+ MurzikAttention,
15
+ MurzikMLP,
16
+ MurzikPreTrainedModel,
17
+ MurzikRMSNorm,
18
+ MurzikRotaryEmbedding,
19
+ )
20
+
21
+
22
+ class MurzikMoeMLP(nn.Module):
23
+ """Single expert SwiGLU block."""
24
+
25
+ def __init__(self, config: MurzikMoeConfig, intermediate_size: int):
26
+ super().__init__()
27
+ self.gate_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
28
+ self.up_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
29
+ self.down_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)
30
+
31
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
32
+ return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
33
+
34
+
35
+ class MurzikSparseMoeBlock(nn.Module):
36
+ def __init__(self, config: MurzikMoeConfig):
37
+ super().__init__()
38
+ self.num_experts = config.num_experts
39
+ self.top_k = config.num_experts_per_tok
40
+ self.hidden_size = config.hidden_size
41
+
42
+ self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
43
+ self.experts = nn.ModuleList(
44
+ [MurzikMoeMLP(config, config.moe_intermediate_size) for _ in range(config.num_experts)]
45
+ )
46
+ self.shared_experts = nn.ModuleList(
47
+ [MurzikMoeMLP(config, config.moe_intermediate_size) for _ in range(config.num_shared_experts)]
48
+ )
49
+ self.register_buffer("expert_bias", torch.zeros(config.num_experts), persistent=True)
50
+ self.router_aux_loss_coef = config.router_aux_loss_coef
51
+
52
+ def forward(self, hidden_states: torch.Tensor):
53
+ batch_size, seq_len, hidden_dim = hidden_states.shape
54
+ flat = hidden_states.view(-1, hidden_dim)
55
+ router_logits = self.gate(flat)
56
+ routing_weights = F.softmax(router_logits + self.expert_bias, dim=-1, dtype=torch.float32)
57
+ routing_weights, selected = torch.topk(routing_weights, self.top_k, dim=-1)
58
+ routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
59
+ routing_weights = routing_weights.to(flat.dtype)
60
+
61
+ out = torch.zeros_like(flat)
62
+ for expert_idx, expert in enumerate(self.experts):
63
+ mask = (selected == expert_idx).any(dim=-1)
64
+ if not mask.any():
65
+ continue
66
+ idx = mask.nonzero(as_tuple=True)[0]
67
+ expert_input = flat[idx]
68
+ expert_out = expert(expert_input)
69
+ weight = (selected[idx] == expert_idx).float() * routing_weights[idx]
70
+ weight = weight.sum(dim=-1, keepdim=True)
71
+ out[idx] += expert_out * weight
72
+
73
+ for shared in self.shared_experts:
74
+ out += shared(flat)
75
+
76
+ aux_loss = self._aux_loss(router_logits, selected)
77
+ return out.view(batch_size, seq_len, hidden_dim), aux_loss
78
+
79
+ def _aux_loss(self, router_logits: torch.Tensor, selected: torch.Tensor) -> torch.Tensor:
80
+ probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
81
+ one_hot = F.one_hot(selected, num_classes=self.num_experts).float().sum(dim=1)
82
+ load = one_hot.mean(dim=0)
83
+ balance = probs.mean(dim=0)
84
+ aux = self.num_experts * (load * balance).sum()
85
+ return aux * self.router_aux_loss_coef
86
+
87
+
88
+ class MurzikMoeDecoderLayer(nn.Module):
89
+ def __init__(self, config: MurzikMoeConfig, layer_idx: int):
90
+ super().__init__()
91
+ self.layer_idx = layer_idx
92
+ self.self_attn = MurzikAttention(config, layer_idx)
93
+ self.input_layernorm = MurzikRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
94
+ self.post_attention_layernorm = MurzikRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
95
+
96
+ use_moe = layer_idx >= config.first_k_dense_replace
97
+ if use_moe:
98
+ self.mlp = MurzikSparseMoeBlock(config)
99
+ self.is_moe = True
100
+ else:
101
+ self.mlp = MurzikMLP(config)
102
+ self.is_moe = False
103
+
104
+ def forward(self, hidden_states, attention_mask, position_embeddings, past_key_value=None, use_cache=False):
105
+ residual = hidden_states
106
+ hidden_states = self.input_layernorm(hidden_states)
107
+ hidden_states, present = self.self_attn(
108
+ hidden_states, attention_mask, position_embeddings, past_key_value, use_cache
109
+ )
110
+ hidden_states = residual + hidden_states
111
+
112
+ residual = hidden_states
113
+ hidden_states = self.post_attention_layernorm(hidden_states)
114
+ aux_loss = None
115
+ if self.is_moe:
116
+ hidden_states, aux_loss = self.mlp(hidden_states)
117
+ else:
118
+ hidden_states = self.mlp(hidden_states)
119
+ hidden_states = residual + hidden_states
120
+ return hidden_states, present, aux_loss
121
+
122
+
123
+ class MurzikMoePreTrainedModel(MurzikPreTrainedModel):
124
+ config_class = MurzikMoeConfig
125
+ _no_split_modules = ["MurzikMoeDecoderLayer"]
126
+
127
+
128
+ class MurzikMoeModel(MurzikMoePreTrainedModel):
129
+ def __init__(self, config: MurzikMoeConfig):
130
+ super().__init__(config)
131
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
132
+ self.layers = nn.ModuleList(
133
+ [MurzikMoeDecoderLayer(config, i) for i in range(config.num_hidden_layers)]
134
+ )
135
+ self.norm = MurzikRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
136
+ self.rotary_emb = MurzikRotaryEmbedding(
137
+ config.head_dim, config.max_position_embeddings, config.rope_theta
138
+ )
139
+ self.gradient_checkpointing = False
140
+ self.post_init()
141
+
142
+ def forward(
143
+ self,
144
+ input_ids: torch.LongTensor,
145
+ attention_mask: Optional[torch.Tensor] = None,
146
+ past_key_values: Optional[list] = None,
147
+ use_cache: bool = False,
148
+ **kwargs,
149
+ ):
150
+ bsz, seq_len = input_ids.shape
151
+ hidden_states = self.embed_tokens(input_ids)
152
+ cos, sin = self.rotary_emb(hidden_states, seq_len)
153
+ position_embeddings = (cos, sin)
154
+
155
+ if attention_mask is None:
156
+ attention_mask = torch.triu(
157
+ torch.full((seq_len, seq_len), float("-inf"), device=input_ids.device),
158
+ diagonal=1,
159
+ ).unsqueeze(0).unsqueeze(0)
160
+ else:
161
+ attention_mask = attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
162
+ attention_mask = (1.0 - attention_mask) * torch.finfo(hidden_states.dtype).min
163
+
164
+ presents = [] if use_cache else None
165
+ aux_loss = torch.tensor(0.0, device=input_ids.device)
166
+ for idx, layer in enumerate(self.layers):
167
+ past = past_key_values[idx] if past_key_values is not None else None
168
+ hidden_states, present, layer_aux = layer(
169
+ hidden_states, attention_mask, position_embeddings, past, use_cache
170
+ )
171
+ if layer_aux is not None:
172
+ aux_loss = aux_loss + layer_aux
173
+ if use_cache:
174
+ presents.append(present)
175
+
176
+ hidden_states = self.norm(hidden_states)
177
+ return hidden_states, presents, aux_loss
178
+
179
+
180
+ class MurzikMoeForCausalLM(MurzikMoePreTrainedModel):
181
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
182
+
183
+ def __init__(self, config: MurzikMoeConfig):
184
+ super().__init__(config)
185
+ self.model = MurzikMoeModel(config)
186
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
187
+ self.post_init()
188
+
189
+ def get_input_embeddings(self):
190
+ return self.model.embed_tokens
191
+
192
+ def set_input_embeddings(self, value):
193
+ self.model.embed_tokens = value
194
+
195
+ def get_output_embeddings(self):
196
+ return self.lm_head
197
+
198
+ def set_output_embeddings(self, new_embeddings):
199
+ self.lm_head = new_embeddings
200
+
201
+ def forward(
202
+ self,
203
+ input_ids: torch.LongTensor,
204
+ attention_mask: Optional[torch.Tensor] = None,
205
+ labels: Optional[torch.LongTensor] = None,
206
+ past_key_values: Optional[list] = None,
207
+ use_cache: bool = False,
208
+ **kwargs,
209
+ ) -> CausalLMOutputWithPast:
210
+ hidden_states, past_key_values, aux_loss = self.model(
211
+ input_ids=input_ids,
212
+ attention_mask=attention_mask,
213
+ past_key_values=past_key_values,
214
+ use_cache=use_cache,
215
+ )
216
+ logits = self.lm_head(hidden_states)
217
+
218
+ loss = None
219
+ if labels is not None:
220
+ shift_logits = logits[..., :-1, :].contiguous()
221
+ shift_labels = labels[..., 1:].contiguous()
222
+ loss = F.cross_entropy(
223
+ shift_logits.view(-1, shift_logits.size(-1)),
224
+ shift_labels.view(-1),
225
+ ignore_index=-100,
226
+ )
227
+ if aux_loss is not None:
228
+ loss = loss + aux_loss
229
+
230
+ return CausalLMOutputWithPast(
231
+ loss=loss,
232
+ logits=logits,
233
+ past_key_values=past_key_values,
234
+ )
murzik/tokenization_murzik.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Murzik tokenizer — SentencePiece wrapper for Hugging Face."""
2
+
3
+ from pathlib import Path
4
+ from typing import Optional
5
+
6
+ import sentencepiece as spm
7
+ from transformers import PreTrainedTokenizer
8
+
9
+ # Special tokens (must match SFT template)
10
+ SPECIAL_TOKENS = {
11
+ "pad_token": "<|pad|>",
12
+ "bos_token": "<|murzik|>",
13
+ "eos_token": "<|end|>",
14
+ "unk_token": "<|unk|>",
15
+ "additional_special_tokens": [
16
+ "<|user|>",
17
+ "<|assistant|>",
18
+ "<|system|>",
19
+ ],
20
+ }
21
+
22
+
23
+ class MurzikTokenizer(PreTrainedTokenizer):
24
+ model_input_names = ["input_ids", "attention_mask"]
25
+
26
+ def __init__(
27
+ self,
28
+ vocab_file: str,
29
+ bos_token: str = SPECIAL_TOKENS["bos_token"],
30
+ eos_token: str = SPECIAL_TOKENS["eos_token"],
31
+ pad_token: str = SPECIAL_TOKENS["pad_token"],
32
+ unk_token: str = SPECIAL_TOKENS["unk_token"],
33
+ **kwargs,
34
+ ):
35
+ self.vocab_file = vocab_file
36
+ self.sp_model = spm.SentencePieceProcessor()
37
+ if vocab_file and Path(vocab_file).exists():
38
+ self.sp_model.Load(vocab_file)
39
+ super().__init__(
40
+ bos_token=bos_token,
41
+ eos_token=eos_token,
42
+ pad_token=pad_token,
43
+ unk_token=unk_token,
44
+ **kwargs,
45
+ )
46
+
47
+ @property
48
+ def vocab_size(self) -> int:
49
+ return self.sp_model.get_piece_size()
50
+
51
+ def get_vocab(self):
52
+ return {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
53
+
54
+ def _tokenize(self, text: str) -> list[str]:
55
+ return self.sp_model.encode(text, out_type=str)
56
+
57
+ def _convert_token_to_id(self, token: str) -> int:
58
+ return self.sp_model.piece_to_id(token)
59
+
60
+ def _convert_id_to_token(self, index: int) -> str:
61
+ return self.sp_model.id_to_piece(index)
62
+
63
+ def convert_tokens_to_string(self, tokens: list[str]) -> str:
64
+ return self.sp_model.decode(tokens)
65
+
66
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
67
+ if token_ids_1 is None:
68
+ return token_ids_0
69
+ return token_ids_0 + token_ids_1
70
+
71
+ def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
72
+ if already_has_special_tokens:
73
+ return super().get_special_tokens_mask(
74
+ token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
75
+ )
76
+ if token_ids_1 is not None:
77
+ return ([0] * len(token_ids_0)) + ([1] + [0] * (len(token_ids_1) - 1))
78
+ return [0] * len(token_ids_0)
79
+
80
+ def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
81
+ if token_ids_1 is None:
82
+ return len(token_ids_0) * [0]
83
+ return [0] * (len(token_ids_0) + len(token_ids_1))
84
+
85
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
86
+ out = Path(save_directory) / f"{filename_prefix or ''}murzik.model"
87
+ if self.vocab_file:
88
+ import shutil
89
+ shutil.copy(self.vocab_file, out)
90
+ return (str(out),)
tokenizer.json ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "1.0",
3
+ "truncation": null,
4
+ "padding": null,
5
+ "added_tokens": [
6
+ {
7
+ "id": 0,
8
+ "content": "<|pad|>",
9
+ "single_word": false,
10
+ "lstrip": false,
11
+ "rstrip": false,
12
+ "normalized": false,
13
+ "special": true
14
+ },
15
+ {
16
+ "id": 1,
17
+ "content": "<|murzik|>",
18
+ "single_word": false,
19
+ "lstrip": false,
20
+ "rstrip": false,
21
+ "normalized": false,
22
+ "special": true
23
+ },
24
+ {
25
+ "id": 2,
26
+ "content": "<|end|>",
27
+ "single_word": false,
28
+ "lstrip": false,
29
+ "rstrip": false,
30
+ "normalized": false,
31
+ "special": true
32
+ },
33
+ {
34
+ "id": 3,
35
+ "content": "<|unk|>",
36
+ "single_word": false,
37
+ "lstrip": false,
38
+ "rstrip": false,
39
+ "normalized": false,
40
+ "special": true
41
+ },
42
+ {
43
+ "id": 4,
44
+ "content": "<|user|>",
45
+ "single_word": false,
46
+ "lstrip": false,
47
+ "rstrip": false,
48
+ "normalized": false,
49
+ "special": true
50
+ },
51
+ {
52
+ "id": 5,
53
+ "content": "<|assistant|>",
54
+ "single_word": false,
55
+ "lstrip": false,
56
+ "rstrip": false,
57
+ "normalized": false,
58
+ "special": true
59
+ },
60
+ {
61
+ "id": 6,
62
+ "content": "<|system|>",
63
+ "single_word": false,
64
+ "lstrip": false,
65
+ "rstrip": false,
66
+ "normalized": false,
67
+ "special": true
68
+ }
69
+ ],
70
+ "normalizer": null,
71
+ "pre_tokenizer": {
72
+ "type": "ByteLevel",
73
+ "add_prefix_space": true,
74
+ "trim_offsets": true,
75
+ "use_regex": true
76
+ },
77
+ "post_processor": {
78
+ "type": "TemplateProcessing",
79
+ "single": [
80
+ {
81
+ "Sequence": {
82
+ "id": "A",
83
+ "type_id": 0
84
+ }
85
+ }
86
+ ],
87
+ "pair": [
88
+ {
89
+ "Sequence": {
90
+ "id": "A",
91
+ "type_id": 0
92
+ }
93
+ },
94
+ {
95
+ "Sequence": {
96
+ "id": "B",
97
+ "type_id": 1
98
+ }
99
+ }
100
+ ],
101
+ "special_tokens": {}
102
+ },
103
+ "decoder": null,
104
+ "model": {
105
+ "type": "BPE",
106
+ "dropout": null,
107
+ "unk_token": null,
108
+ "continuing_subword_prefix": null,
109
+ "end_of_word_suffix": null,
110
+ "fuse_unk": false,
111
+ "byte_fallback": false,
112
+ "ignore_merges": false,
113
+ "vocab": {},
114
+ "merges": []
115
+ }
116
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backend": "tokenizers",
3
+ "bos_token": "<|murzik|>",
4
+ "eos_token": "<|end|>",
5
+ "model_max_length": 8192,
6
+ "pad_token": "<|pad|>",
7
+ "tokenizer_class": "TokenizersBackend",
8
+ "unk_token": "<|unk|>"
9
+ }