|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import math
|
| from typing import Optional, Tuple
|
| from transformers import PreTrainedModel, PretrainedConfig
|
| from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
|
class JiRackMoEConfig(PretrainedConfig):
    """Hyper-parameters for the JiRack ternary mixture-of-experts language model.

    Args:
        vocab_size: Number of token ids (embedding rows and LM-head outputs).
        hidden_size: Model width shared by every block.
        num_hidden_layers: Number of stacked decoder blocks.
        num_attention_heads: Heads per self-attention layer.
        num_experts: Experts available in each MoE layer.
        num_experts_per_tok: How many experts each token is routed to (top-k).
        intermediate_size: Inner width of each expert's feed-forward network.
        max_position_embeddings: Maximum sequence length stored on the config.
            NOTE(review): the modeling code in this file does not read it.
        rms_norm_eps: Normalization epsilon stored on the config.
        author: String whose character codes seed each block's ``SignatureLayer``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "jirack_moe_transformer"

    def __init__(
        self,
        vocab_size=128256,
        hidden_size=12288,
        num_hidden_layers=112,
        num_attention_heads=96,
        num_experts=16,
        num_experts_per_tok=2,
        intermediate_size=12288,
        max_position_embeddings=32768,
        rms_norm_eps=1e-5,
        author="Author: Konstantin Vladimirovich Grabko (CMS Manhattan) 2025",
        **kwargs
    ):
        super().__init__(**kwargs)

        # Vocabulary and backbone geometry.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Mixture-of-experts routing.
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.intermediate_size = intermediate_size

        # Sequence length, numerics and attribution.
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.author = author
|
|
|
|
|
|
|
class JiRackBitLinear(nn.Linear):
    """Linear layer with ternary weights and clamped, mean-centred activations.

    Forward pass:
      * Weights are quantized to ``{-gamma, 0, +gamma}``, where ``gamma`` is the
        mean absolute weight (per tensor).
      * Inputs are centred by subtracting their mean over the last dimension,
        then clamped to ``[-1.2, 1.2]``.

    Both steps use a straight-through estimator, so the backward pass treats
    quantization and clamping as the identity and gradients reach the
    full-precision weights and inputs.

    Args:
        in_features: Size of the last input dimension.
        out_features: Size of the last output dimension.
        bias: Whether to add a learnable bias (default ``False``).
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__(in_features, out_features, bias)
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        w = self.weight

        # Per-tensor absmean scale. A floor (clamp) is used instead of "+ 1e-9":
        # in float16 the constant 1e-9 rounds to 0, so an all-zero weight
        # produced 0 / 0 = NaN.
        gamma = w.abs().mean().clamp(min=1e-5)
        w_ternary = torch.clamp(torch.round(w / gamma), -1, 1)
        # Straight-through estimator: forward uses the ternary weight, backward
        # sees the full-precision weight.
        w_final = w + (w_ternary * gamma - w).detach()

        x_centered = x - x.mean(dim=-1, keepdim=True)
        # Clamp in the forward pass only; gradients pass through unchanged.
        x_clamped = x_centered + (torch.clamp(x_centered, -1.2, 1.2) - x_centered).detach()

        return F.linear(x_clamped, w_final, self.bias)
|
|
|
|
|
|
|
class JiRackExpert(nn.Module):
    """Gated (SwiGLU-style) feed-forward expert built from ternary linear layers.

    Computes ``w2(silu(w1(x)) * w3(x))``.
    """

    def __init__(self, config: JiRackMoEConfig):
        super().__init__()
        d_model, d_inner = config.hidden_size, config.intermediate_size
        # Construction order is kept (w1, w2, w3) so RNG draws and state-dict keys match.
        self.w1 = JiRackBitLinear(d_model, d_inner)  # branch passed through SiLU
        self.w2 = JiRackBitLinear(d_inner, d_model)  # projects back to model width
        self.w3 = JiRackBitLinear(d_model, d_inner)  # multiplicative branch

    def forward(self, x):
        activated = F.silu(self.w1(x))
        mixed = activated * self.w3(x)
        return self.w2(mixed)
|
|
|
|
|
|
|
class JiRackMoEGate(nn.Module):
    """Top-k router over the experts of one MoE layer.

    forward(x) returns ``(weights, indices)``, both with last dimension
    ``num_experts_per_tok``:

    * ``indices`` are the ids of the highest-scoring experts.
    * ``weights`` are a softmax over just those selected scores. The softmax is
      computed in float32, then cast back to ``x``'s dtype.
    """

    def __init__(self, config: JiRackMoEConfig):
        super().__init__()
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.top_k = config.num_experts_per_tok

    def forward(self, x):
        scores = self.gate(x)
        best_scores, best_ids = scores.topk(self.top_k, dim=-1)
        routing = best_scores.float().softmax(dim=-1)
        return routing.type_as(x), best_ids
|
|
|
|
|
|
|
class SignatureLayer(nn.Module):
    """Gated pass-through plus a fixed-seed random "signature" projection.

    ``forward(x) = x * sigmoid(gate) + tanh(x @ signage_cms.T)``

    Args:
        dim: Size of the last dimension of ``x``.
        author_name: String whose summed character code points seed the
            initial value of ``signage_cms``, making it reproducible.
    """

    def __init__(self, dim, author_name):
        super().__init__()
        self.gate = nn.Parameter(torch.ones(dim))
        seed = sum(ord(c) for c in author_name)
        # Seed a forked copy of the CPU RNG instead of calling torch.manual_seed.
        # The global reseed reset the RNG for every module constructed afterwards,
        # so each later layer started from the same RNG state. It also silently
        # overrode the caller's seed (CPU and CUDA). fork_rng restores the
        # CPU RNG state on exit.
        with torch.random.fork_rng(devices=[]):
            torch.default_generator.manual_seed(seed)
            signature = torch.randn(dim, dim) * 0.001
        self.signage_cms = nn.Parameter(signature)

    def forward(self, x):
        sig = torch.tanh(F.linear(x, self.signage_cms))
        return x * torch.sigmoid(self.gate) + sig
|
|
|
|
|
|
|
class JiRackMoEBlock(nn.Module):
    """Pre-norm decoder layer: causal self-attention, then a top-k MoE feed-forward.

    Each sub-layer is wrapped in a residual connection. The MoE output passes
    through ``SignatureLayer`` before being added back to the residual stream.
    """

    def __init__(self, config: JiRackMoEConfig):
        super().__init__()
        # Attention sub-layer.
        self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attn = nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
        # Mixture-of-experts sub-layer.
        self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gate = JiRackMoEGate(config)
        self.experts = nn.ModuleList([JiRackExpert(config) for _ in range(config.num_experts)])
        self.signature = SignatureLayer(config.hidden_size, config.author)

    def forward(self, x):
        """Apply the layer to hidden states ``x``.

        Assumes ``x`` is (batch, seq_len, hidden_size) because the attention
        uses ``batch_first=True``. Unbatched (seq_len, hidden_size) input is
        also accepted by ``nn.MultiheadAttention``.
        """
        residual = x
        h = self.ln1(x)
        # Boolean mask where True marks a *disallowed* query->key pair. Without
        # it every position could attend to future tokens, leaking the targets
        # of a causal LM.
        seq_len = h.size(-2)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=h.device), diagonal=1
        )
        h, _ = self.attn(h, h, h, attn_mask=causal_mask, need_weights=False)
        x = residual + h

        residual = x
        h = self.ln2(x)
        return residual + self.signature(self._mix_experts(h))

    def _mix_experts(self, h):
        """Route every token to its top-k experts and sum their weighted outputs."""
        weights, indices = self.gate(h)
        flat_h = h.reshape(-1, h.size(-1))
        flat_weights = weights.reshape(-1, weights.size(-1))
        flat_indices = indices.reshape(-1, indices.size(-1))
        out = torch.zeros_like(flat_h)

        for expert_id, expert in enumerate(self.experts):
            # (token row, top-k slot) pairs that selected this expert.
            token_pos, slot = torch.nonzero(flat_indices == expert_id, as_tuple=True)
            if token_pos.numel() == 0:
                continue
            contrib = expert(flat_h[token_pos]) * flat_weights[token_pos, slot].unsqueeze(-1)
            # index_add_ requires matching dtypes. Under autocast the expert
            # output may differ from the LayerNorm output's dtype.
            out.index_add_(0, token_pos, contrib.to(out.dtype))

        return out.view_as(h)
|
|
|
|
|
|
|
class JiRackTernaryMoE405B(PreTrainedModel):
    """Decoder-only LM: token embedding -> JiRackMoEBlock stack -> final norm -> LM head.

    NOTE(review): no positional encoding is applied and
    ``config.max_position_embeddings`` is never read, so the model has no
    explicit position information.
    """

    config_class = JiRackMoEConfig

    def __init__(self, config: JiRackMoEConfig):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList([JiRackMoEBlock(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def forward(self, input_ids, labels=None):
        """Run the model and optionally compute the next-token loss.

        Args:
            input_ids: Integer token ids.
            labels: Optional targets aligned with ``input_ids``, following the
                Hugging Face causal-LM convention. The shift by one position
                happens inside the loss, so ``labels=input_ids`` is valid.
                Positions set to -100 are ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` and ``loss`` (``None``
            when no labels are given).
        """
        # Local import: ``import torch`` alone does not guarantee that the
        # ``torch.utils.checkpoint`` submodule has been loaded.
        from torch.utils.checkpoint import checkpoint

        x = self.token_emb(input_ids)
        for block in self.blocks:
            if self.training:
                # Activation checkpointing trades recomputation for memory.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)

        logits = self.lm_head(self.ln_f(x))

        loss = None
        if labels is not None:
            loss = self._next_token_loss(logits, labels)

        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def _next_token_loss(self, logits, labels):
        """Cross-entropy where the logits at position t predict ``labels[t + 1]``."""
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:].to(logits.device)
        # reshape (not view) because the shifted slices are non-contiguous.
        # Upcast to float32 so the softmax over the vocabulary is computed
        # accurately even for half-precision logits.
        return F.cross_entropy(
            shift_logits.reshape(-1, self.config.vocab_size).float(),
            shift_labels.reshape(-1),
            ignore_index=-100,
        )