Upload 6 files
Browse files- config (2).json +25 -0
- configuration_ultrabase.py +38 -0
- generation_config (2).json +8 -0
- model (2).safetensors +3 -0
- modeling_ultrabase.py +187 -0
- tokenizer (1).json +0 -0
config (2).json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"UltraBaseForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "configuration_ultrabase.UltraBaseConfig",
|
| 7 |
+
"AutoModelForCausalLM": "modeling_ultrabase.UltraBaseForCausalLM"
|
| 8 |
+
},
|
| 9 |
+
"bos_token_id": 0,
|
| 10 |
+
"bypass_rate": 0.375,
|
| 11 |
+
"d_ff": 256,
|
| 12 |
+
"d_model": 256,
|
| 13 |
+
"dtype": "float32",
|
| 14 |
+
"eos_token_id": 0,
|
| 15 |
+
"head_dim": 16,
|
| 16 |
+
"latent_dim": 64,
|
| 17 |
+
"model_type": "ultrabase",
|
| 18 |
+
"n_heads": 12,
|
| 19 |
+
"n_layers": 16,
|
| 20 |
+
"num_private_experts": 6,
|
| 21 |
+
"num_shared_experts": 1,
|
| 22 |
+
"tie_word_embeddings": true,
|
| 23 |
+
"transformers_version": "5.12.1",
|
| 24 |
+
"vocab_size": 49152
|
| 25 |
+
}
|
configuration_ultrabase.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
class UltraBaseConfig(PretrainedConfig):
    """Hyper-parameters for the UltraBase causal language model.

    Args:
        vocab_size: Number of rows in the token embedding table and number of
            output logits.
        d_model: Width of the residual stream (embedding dimension).
        n_layers: Number of decoder layers stacked in the model.
        n_heads: Number of attention heads.
        latent_dim: Width of the low-rank latent from which keys and values
            are up-projected.
        head_dim: Per-head width of queries, keys and values.
        bypass_rate: Fraction of the sequence that skips each decoder layer;
            the layer processes ``int(seq_len * (1 - bypass_rate))`` tokens
            (at least one).
        num_private_experts: Number of routed experts per layer (top-1).
        num_shared_experts: Stored for reference.
            NOTE(review): the accompanying modeling code always builds exactly
            one shared expert regardless of this value -- confirm intent.
        d_ff: Hidden width of every expert feed-forward block.
        bos_token_id: Beginning-of-sequence token id.
        eos_token_id: End-of-sequence token id.
        tie_word_embeddings: Forwarded to ``PretrainedConfig``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "ultrabase"

    # Aliases so library utilities that read the standard attribute names
    # resolve to this model's own field names.
    attribute_map = {
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layers",
        "num_attention_heads": "n_heads",
        "intermediate_size": "d_ff",
    }

    def __init__(
        self,
        vocab_size=49152,
        d_model=256,
        n_layers=16,
        n_heads=12,
        latent_dim=64,
        head_dim=16,
        bypass_rate=0.375,
        num_private_experts=6,
        num_shared_experts=1,
        d_ff=256,
        bos_token_id=0,
        eos_token_id=0,
        tie_word_embeddings=True,
        **kwargs,
    ):
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.latent_dim = latent_dim
        self.head_dim = head_dim
        self.bypass_rate = bypass_rate
        self.num_private_experts = num_private_experts
        self.num_shared_experts = num_shared_experts
        self.d_ff = d_ff
|
generation_config (2).json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 0,
|
| 4 |
+
"eos_token_id": 0,
|
| 5 |
+
"output_attentions": false,
|
| 6 |
+
"output_hidden_states": false,
|
| 7 |
+
"transformers_version": "5.12.1"
|
| 8 |
+
}
|
model (2).safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1df47d01ce4454a7b669390da1574b4b9a2602fef8a826c3899b8f2cc448ae0a
|
| 3 |
+
size 168486496
|
modeling_ultrabase.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

# Remote-code loading imports this file as part of a package, so the sibling
# config module must be reachable through a relative import; the absolute
# form is kept as a fallback for running the file from its own directory.
try:
    from .configuration_ultrabase import UltraBaseConfig
except ImportError:
    from configuration_ultrabase import UltraBaseConfig
|
| 9 |
+
|
| 10 |
+
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    Args:
        dim: Size of the last (normalised) dimension.
        eps: Constant added to the mean square before the reciprocal square
            root, for numerical stability.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x / sqrt(mean(x**2) + eps) * weight``.

        For float16 / bfloat16 inputs the statistic is computed in float32:
        squaring in half precision overflows for moderately large
        activations, which would turn the whole row into zeros.
        """
        input_dtype = x.dtype
        reduced = input_dtype in (torch.float16, torch.bfloat16)
        work = x.to(torch.float32) if reduced else x

        variance = work.pow(2).mean(-1, keepdim=True)
        normed = work * torch.rsqrt(variance + self.eps)
        # Cast back before applying the gain so dtype promotion with
        # ``weight`` is the same as it was without the upcast.
        return normed.to(input_dtype) * self.weight
|
| 19 |
+
|
| 20 |
+
class MLA(nn.Module):
    """Causal multi-head attention with latent-compressed keys and values.

    The input is projected down to ``latent_dim`` once, and keys and values
    are both up-projected from that shared latent. Queries are projected
    directly from the input.

    Args:
        config: Object exposing ``n_heads``, ``head_dim``, ``latent_dim`` and
            ``d_model``.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.latent_dim = config.latent_dim
        self.d_model = config.d_model

        inner_dim = config.n_heads * config.head_dim

        self.kv_down_proj = nn.Linear(config.d_model, config.latent_dim, bias=False)
        self.kv_up_proj_k = nn.Linear(config.latent_dim, inner_dim, bias=False)
        self.kv_up_proj_v = nn.Linear(config.latent_dim, inner_dim, bias=False)

        self.q_proj = nn.Linear(config.d_model, inner_dim, bias=False)
        self.o_proj = nn.Linear(inner_dim, config.d_model, bias=False)

    def _split_heads(self, t, batch, seq):
        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        return t.view(batch, seq, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(batch, seq, d_model)``.

        Returns:
            Tensor of shape ``(batch, seq, d_model)``.
        """
        B, S, _ = x.shape
        q = self._split_heads(self.q_proj(x), B, S)

        latent_kv = self.kv_down_proj(x)
        k = self._split_heads(self.kv_up_proj_k(latent_kv), B, S)
        v = self._split_heads(self.kv_up_proj_v(latent_kv), B, S)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # A boolean mask keeps the scores in their own dtype. Adding a float32
        # "-inf" matrix would promote half/bfloat16 scores to float32 and make
        # the matmul with ``v`` fail on mismatched dtypes.
        future = torch.ones(S, S, dtype=torch.bool, device=x.device).triu(diagonal=1)
        attn_scores = attn_scores.masked_fill(future, float("-inf"))

        attn_weights = F.softmax(attn_scores, dim=-1)
        context = torch.matmul(attn_weights, v)
        context = context.transpose(1, 2).contiguous().view(B, S, -1)

        return self.o_proj(context)
|
| 54 |
+
|
| 55 |
+
class Expert(nn.Module):
    """Bias-free two-layer feed-forward block: ``w2(SiLU(w1(x)))``.

    Args:
        d_model: Input and output width.
        d_ff: Hidden width between the two projections.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Creation order is kept (w1, then w2) so parameter registration and
        # random initialisation order stay the same.
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.act = nn.SiLU()

    def forward(self, x):
        """Expand to ``d_ff``, apply SiLU, and project back to ``d_model``."""
        hidden = self.w1(x)
        activated = self.act(hidden)
        return self.w2(activated)
|
| 64 |
+
|
| 65 |
+
class SSPMoE(nn.Module):
    """Mixture of experts with one always-on shared expert plus top-1 routing.

    Every token passes through ``shared_expert``. In addition, a linear router
    picks one of ``num_private_experts`` private experts per token, and that
    expert's output is scaled by its softmax routing probability.

    Args:
        config: Object exposing ``num_private_experts``, ``d_model`` and
            ``d_ff``.

    NOTE(review): ``config.num_shared_experts`` is not read here; exactly one
    shared expert is built. Honouring it would rename state-dict keys, so it
    is left as is.
    """

    def __init__(self, config):
        super().__init__()
        self.num_private = config.num_private_experts
        self.shared_expert = Expert(config.d_model, config.d_ff)
        self.private_experts = nn.ModuleList([
            Expert(config.d_model, config.d_ff) for _ in range(self.num_private)
        ])
        self.router = nn.Linear(config.d_model, self.num_private, bias=False)

    def forward(self, x):
        """Return ``shared_expert(x)`` plus the weighted top-1 private expert output.

        Args:
            x: Tensor of shape ``(batch, seq, d_model)``; it does not need to
                be contiguous.

        Returns:
            Tensor of shape ``(batch, seq, d_model)``.
        """
        shared_out = self.shared_expert(x)

        routing_weights = F.softmax(self.router(x), dim=-1)
        top1_weights, top1_indices = torch.topk(routing_weights, k=1, dim=-1)

        B, S, C = x.shape
        # reshape (not view) so transposed or sliced inputs are accepted.
        flat_x = x.reshape(-1, C)
        flat_indices = top1_indices.reshape(-1)
        flat_weights = top1_weights.reshape(-1, 1)

        private_out = torch.zeros_like(flat_x)
        for expert_id, expert in enumerate(self.private_experts):
            mask = flat_indices == expert_id
            if mask.any():
                weighted = expert(flat_x[mask]) * flat_weights[mask]
                # Indexed assignment needs matching dtypes; under autocast the
                # expert output can differ from the buffer's dtype.
                private_out[mask] = weighted.to(private_out.dtype)

        return shared_out + private_out.reshape(B, S, C)
|
| 97 |
+
|
| 98 |
+
class DecoderLayer(nn.Module):
    """Decoder layer that processes only a routed subset of the sequence.

    A linear router scores every position. The ``k`` highest-scoring
    positions go through norm -> attention -> MoE -> norm, and all other
    positions are passed through unchanged.

    Args:
        config: Object exposing ``bypass_rate`` and ``d_model`` plus whatever
            the attention and MoE sub-modules read.

    NOTE(review): the router scores are used only to pick positions; they do
    not scale the output, so ``mod_router`` receives no gradient through this
    forward pass.
    """

    def __init__(self, config):
        super().__init__()
        self.active_rate = 1.0 - config.bypass_rate
        self.mod_router = nn.Linear(config.d_model, 1, bias=False)

        self.pre_rmsnorm = RMSNorm(config.d_model)
        self.mla_block = MLA(config)
        self.ssp_moe_layer = SSPMoE(config)
        self.post_rmsnorm = RMSNorm(config.d_model)

    def _transform(self, x):
        # Sub-layer stack shared by the dense and the routed path.
        h = self.pre_rmsnorm(x)
        h = h + self.mla_block(h)
        h = h + self.ssp_moe_layer(h)
        return self.post_rmsnorm(h)

    def forward(self, x):
        """Run the layer on the selected positions of ``x``.

        Args:
            x: Tensor of shape ``(batch, seq, d_model)``.

        Returns:
            Tensor of the same shape; unselected positions equal the input.
        """
        B, S, C = x.shape
        k = max(1, min(int(S * self.active_rate), S))
        if k >= S:
            # Nothing to bypass (single token, or bypass_rate <= 0).
            return self._transform(x)

        router_logits = self.mod_router(x).squeeze(-1)
        _, topk_indices = torch.topk(router_logits, k, dim=-1)
        # topk returns indices ordered by score. The attention block applies a
        # causal mask along the order it receives, so positional order must be
        # restored or earlier tokens would attend to later ones.
        active_idx, _ = torch.sort(topk_indices, dim=-1)

        # Every sample selects the same k, so the batch is handled in one
        # gather/scatter instead of a per-sample Python loop.
        gather_idx = active_idx.unsqueeze(-1).expand(-1, -1, C)
        h = self._transform(torch.gather(x, 1, gather_idx))
        return x.scatter(1, gather_idx, h.to(x.dtype))
|
| 136 |
+
|
| 137 |
+
class UltraBasePreTrainedModel(PreTrainedModel):
    """Shared base class binding the UltraBase config and weight initialisation."""

    config_class = UltraBaseConfig
    base_model_prefix = "model"
    # NOTE(review): no checkpointing logic is visible in the layers shown in
    # this file -- confirm this flag is actually honoured.
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)
|
| 149 |
+
|
| 150 |
+
class UltraBaseForCausalLM(UltraBasePreTrainedModel, GenerationMixin):
    """UltraBase decoder stack with a language-modelling head.

    Inherits from ``UltraBasePreTrainedModel`` so the class picks up
    ``config_class`` and the project's ``_init_weights``; deriving straight
    from ``PreTrainedModel`` left both unused.

    NOTE(review): the config enables ``tie_word_embeddings`` but this class
    declares no tied-weight keys, so whether ``lm_head`` and ``embed`` share
    storage depends on the installed transformers version. Confirm the intent
    against the published checkpoint before changing it.
    """

    def __init__(self, config):
        super().__init__(config)
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.n_layers)])
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed = value

    def get_output_embeddings(self):
        """Return the output projection (LM head)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the output projection (LM head)."""
        self.lm_head = new_embeddings

    def forward(self, input_ids, labels=None, **kwargs):
        """Compute next-token logits and, if ``labels`` is given, the LM loss.

        Args:
            input_ids: Integer tensor of token ids, ``(batch, seq)``.
            labels: Optional integer tensor shaped like ``input_ids``. It is
                shifted by one position internally, and entries equal to
                ``-100`` are ignored by the loss.
            **kwargs: Accepted for API compatibility and ignored (no
                attention mask or cache is consumed).

        Returns:
            ``CausalLMOutputWithPast`` carrying ``loss`` (or ``None``) and
            ``logits`` of shape ``(batch, seq, vocab_size)``.
        """
        x = self.embed(input_ids)
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1. The loss is taken in float32
            # and on the logits' device so reduced-precision or sharded runs
            # behave the same as the plain fp32 case.
            shift_logits = logits[..., :-1, :].float().contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Pass the full ``input_ids`` to every step; no cache is kept."""
        return {"input_ids": input_ids}
|
tokenizer (1).json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|