56m commited on
Commit
2144393
·
verified ·
1 Parent(s): 7212fb5

Upload 6 files

Browse files
config (2).json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "UltraBaseForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_ultrabase.UltraBaseConfig",
7
+ "AutoModelForCausalLM": "modeling_ultrabase.UltraBaseForCausalLM"
8
+ },
9
+ "bos_token_id": 0,
10
+ "bypass_rate": 0.375,
11
+ "d_ff": 256,
12
+ "d_model": 256,
13
+ "dtype": "float32",
14
+ "eos_token_id": 0,
15
+ "head_dim": 16,
16
+ "latent_dim": 64,
17
+ "model_type": "ultrabase",
18
+ "n_heads": 12,
19
+ "n_layers": 16,
20
+ "num_private_experts": 6,
21
+ "num_shared_experts": 1,
22
+ "tie_word_embeddings": true,
23
+ "transformers_version": "5.12.1",
24
+ "vocab_size": 49152
25
+ }
configuration_ultrabase.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+ class UltraBaseConfig(PretrainedConfig):
4
+ model_type = "ultrabase"
5
+
6
+ def __init__(
7
+ self,
8
+ vocab_size=49152,
9
+ d_model=256,
10
+ n_layers=16,
11
+ n_heads=12,
12
+ latent_dim=64,
13
+ head_dim=16,
14
+ bypass_rate=0.375,
15
+ num_private_experts=6,
16
+ num_shared_experts=1,
17
+ d_ff=256,
18
+ bos_token_id=0,
19
+ eos_token_id=0,
20
+ tie_word_embeddings=True,
21
+ **kwargs
22
+ ):
23
+ super().__init__(
24
+ bos_token_id=bos_token_id,
25
+ eos_token_id=eos_token_id,
26
+ tie_word_embeddings=tie_word_embeddings,
27
+ **kwargs
28
+ )
29
+ self.vocab_size = vocab_size
30
+ self.d_model = d_model
31
+ self.n_layers = n_layers
32
+ self.n_heads = n_heads
33
+ self.latent_dim = latent_dim
34
+ self.head_dim = head_dim
35
+ self.bypass_rate = bypass_rate
36
+ self.num_private_experts = num_private_experts
37
+ self.num_shared_experts = num_shared_experts
38
+ self.d_ff = d_ff
generation_config (2).json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 0,
4
+ "eos_token_id": 0,
5
+ "output_attentions": false,
6
+ "output_hidden_states": false,
7
+ "transformers_version": "5.12.1"
8
+ }
model (2).safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1df47d01ce4454a7b669390da1574b4b9a2602fef8a826c3899b8f2cc448ae0a
3
+ size 168486496
modeling_ultrabase.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from transformers import PreTrainedModel
6
+ from transformers.generation import GenerationMixin
7
+ from transformers.modeling_outputs import CausalLMOutputWithPast
8
+ from configuration_ultrabase import UltraBaseConfig
9
+
10
+ class RMSNorm(nn.Module):
11
+ def __init__(self, dim, eps=1e-6):
12
+ super().__init__()
13
+ self.eps = eps
14
+ self.weight = nn.Parameter(torch.ones(dim))
15
+
16
+ def forward(self, x):
17
+ variance = x.pow(2).mean(-1, keepdim=True)
18
+ return x * torch.rsqrt(variance + self.eps) * self.weight
19
+
20
+ class MLA(nn.Module):
21
+ def __init__(self, config):
22
+ super().__init__()
23
+ self.n_heads = config.n_heads
24
+ self.head_dim = config.head_dim
25
+ self.latent_dim = config.latent_dim
26
+ self.d_model = config.d_model
27
+
28
+ self.kv_down_proj = nn.Linear(config.d_model, config.latent_dim, bias=False)
29
+ self.kv_up_proj_k = nn.Linear(config.latent_dim, config.n_heads * config.head_dim, bias=False)
30
+ self.kv_up_proj_v = nn.Linear(config.latent_dim, config.n_heads * config.head_dim, bias=False)
31
+
32
+ self.q_proj = nn.Linear(config.d_model, config.n_heads * config.head_dim, bias=False)
33
+ self.o_proj = nn.Linear(config.n_heads * config.head_dim, config.d_model, bias=False)
34
+
35
+ def forward(self, x):
36
+ B, S, C = x.shape
37
+ q = self.q_proj(x).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
38
+
39
+ latent_kv = self.kv_down_proj(x)
40
+ k = self.kv_up_proj_k(latent_kv).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
41
+ v = self.kv_up_proj_v(latent_kv).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
42
+
43
+ attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
44
+
45
+ mask = torch.full((S, S), float("-inf"), device=x.device)
46
+ mask = torch.triu(mask, diagonal=1)
47
+ attn_scores = attn_scores + mask.unsqueeze(0).unsqueeze(1)
48
+
49
+ attn_weights = F.softmax(attn_scores, dim=-1)
50
+ context = torch.matmul(attn_weights, v)
51
+ context = context.transpose(1, 2).contiguous().view(B, S, -1)
52
+
53
+ return self.o_proj(context)
54
+
55
+ class Expert(nn.Module):
56
+ def __init__(self, d_model, d_ff):
57
+ super().__init__()
58
+ self.w1 = nn.Linear(d_model, d_ff, bias=False)
59
+ self.w2 = nn.Linear(d_ff, d_model, bias=False)
60
+ self.act = nn.SiLU()
61
+
62
+ def forward(self, x):
63
+ return self.w2(self.act(self.w1(x)))
64
+
65
+ class SSPMoE(nn.Module):
66
+ def __init__(self, config):
67
+ super().__init__()
68
+ self.num_private = config.num_private_experts
69
+ self.shared_expert = Expert(config.d_model, config.d_ff)
70
+ self.private_experts = nn.ModuleList([
71
+ Expert(config.d_model, config.d_ff) for _ in range(self.num_private)
72
+ ])
73
+ self.router = nn.Linear(config.d_model, self.num_private, bias=False)
74
+
75
+ def forward(self, x):
76
+ shared_out = self.shared_expert(x)
77
+
78
+ router_logits = self.router(x)
79
+ routing_weights = F.softmax(router_logits, dim=-1)
80
+ top1_weights, top1_indices = torch.topk(routing_weights, k=1, dim=-1)
81
+
82
+ B, S, C = x.shape
83
+ flat_x = x.view(-1, C)
84
+ flat_indices = top1_indices.view(-1)
85
+ flat_weights = top1_weights.view(-1, 1)
86
+
87
+ private_out = torch.zeros_like(flat_x)
88
+ for i in range(self.num_private):
89
+ mask = (flat_indices == i)
90
+ if mask.any():
91
+ expert_in = flat_x[mask]
92
+ expert_out = self.private_experts[i](expert_in)
93
+ private_out[mask] = expert_out * flat_weights[mask]
94
+
95
+ private_out = private_out.view(B, S, C)
96
+ return shared_out + private_out
97
+
98
+ class DecoderLayer(nn.Module):
99
+ def __init__(self, config):
100
+ super().__init__()
101
+ self.active_rate = 1.0 - config.bypass_rate
102
+ self.mod_router = nn.Linear(config.d_model, 1, bias=False)
103
+
104
+ self.pre_rmsnorm = RMSNorm(config.d_model)
105
+ self.mla_block = MLA(config)
106
+ self.ssp_moe_layer = SSPMoE(config)
107
+ self.post_rmsnorm = RMSNorm(config.d_model)
108
+
109
+ def forward(self, x):
110
+ B, S, C = x.shape
111
+ if S < 2:
112
+ h = self.pre_rmsnorm(x)
113
+ h = h + self.mla_block(h)
114
+ h = h + self.ssp_moe_layer(h)
115
+ return self.post_rmsnorm(h)
116
+
117
+ router_logits = self.mod_router(x).squeeze(-1)
118
+ k = int(S * self.active_rate)
119
+ k = max(1, min(k, S))
120
+
121
+ _, topk_indices = torch.topk(router_logits, k, dim=-1)
122
+ out = x.clone()
123
+
124
+ for b in range(B):
125
+ active_idx = topk_indices[b]
126
+ x_active = x[b, active_idx, :].unsqueeze(0)
127
+
128
+ h = self.pre_rmsnorm(x_active)
129
+ h = h + self.mla_block(h)
130
+ h = h + self.ssp_moe_layer(h)
131
+ h = self.post_rmsnorm(h)
132
+
133
+ out[b, active_idx, :] = h.squeeze(0)
134
+
135
+ return out
136
+
137
+ class UltraBasePreTrainedModel(PreTrainedModel):
138
+ config_class = UltraBaseConfig
139
+ base_model_prefix = "model"
140
+ supports_gradient_checkpointing = True
141
+
142
+ def _init_weights(self, module):
143
+ if isinstance(module, nn.Linear):
144
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
145
+ if module.bias is not None:
146
+ torch.nn.init.zeros_(module.bias)
147
+ elif isinstance(module, nn.Embedding):
148
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
149
+
150
+ class UltraBaseForCausalLM(PreTrainedModel, GenerationMixin):
151
+ def __init__(self, config):
152
+ super().__init__(config)
153
+ self.embed = nn.Embedding(config.vocab_size, config.d_model)
154
+ self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.n_layers)])
155
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
156
+
157
+ self.post_init()
158
+
159
+ def get_input_embeddings(self):
160
+ return self.embed
161
+
162
+ def set_input_embeddings(self, value):
163
+ self.embed = value
164
+
165
+ def get_output_embeddings(self):
166
+ return self.lm_head
167
+
168
+ def set_output_embeddings(self, new_embeddings):
169
+ self.lm_head = new_embeddings
170
+
171
+ def forward(self, input_ids, labels=None, **kwargs):
172
+ x = self.embed(input_ids)
173
+ for layer in self.layers:
174
+ x = layer(x)
175
+ logits = self.lm_head(x)
176
+
177
+ loss = None
178
+ if labels is not None:
179
+ shift_logits = logits[..., :-1, :].contiguous()
180
+ shift_labels = labels[..., 1:].contiguous()
181
+ loss_fct = nn.CrossEntropyLoss()
182
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
183
+
184
+ return CausalLMOutputWithPast(loss=loss, logits=logits)
185
+
186
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
187
+ return {"input_ids": input_ids}
tokenizer (1).json ADDED
The diff for this file is too large to render. See raw diff