---
license: apache-2.0
language:
- en
tags:
- causal-lm
- pytorch
- custom-architecture
- mla
- swiglu
- chat
- instruction-following
pipeline_tag: text-generation
library_name: pytorch
datasets:
- HuggingFaceFW/fineweb
- HuggingFaceFW/fineweb-edu
- bigcode/starcoderdata
- HuggingFaceH4/ultrachat_200k
---
# Zenyx-Chat

Zenyx-Chat is a 214M-parameter causal language model for conversational and instruction-following tasks, trained from scratch on a curated mix of web, code, and chat datasets. Its custom architecture features Multi-head Latent Attention (MLA) and SwiGLU feedforward layers.

> ⚠️ **The model is actively training.** Evaluation metrics will be added as training progresses.
## Model Details
| Property | Value |
|---|---|
| Architecture | Custom Decoder-only Transformer |
| Parameters | ~214M (base) |
| Layers | 16 |
| Hidden Dimension | 1024 |
| Attention Heads | 16 |
| KV Latent Dimension | 256 (MLA compression) |
| MLP Type | SwiGLU |
| Positional Encoding | RoPE (θ = 500,000) |
| Context Length | 2,048 tokens |
| Vocabulary Size | 32,768 |
| Tokenizer | Arko007/zenyx-v2-tokenizer |
| Precision | FP16 (trained), FP32 (inference) |
| Framework | PyTorch |
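The ~214M figure can be sanity-checked from the table above. The sketch below recomputes it from the listed dimensions; since the LM head is tied to the embedding, it adds no parameters, while the two MTP heads used only during training add roughly 67M more (bringing the training-time total near 281M).

```python
# Back-of-the-envelope parameter count from the configuration table.
D_MODEL, N_LAYERS, KV_LATENT_DIM, VOCAB_SIZE = 1024, 16, 256, 32768

embedding = VOCAB_SIZE * D_MODEL                 # tied with the LM head
attn = (D_MODEL * D_MODEL                        # q_proj
        + D_MODEL * KV_LATENT_DIM                # kv_down
        + 2 * KV_LATENT_DIM * D_MODEL            # kv_up_key, kv_up_val
        + D_MODEL * D_MODEL)                     # o_proj
ffn_hidden = int(2 * D_MODEL * 4 / 3)            # 2730
ffn = 3 * D_MODEL * ffn_hidden                   # gate, up, down projections
norms = 2 * D_MODEL                              # two RMSNorms per block

per_layer = attn + ffn + norms
total = embedding + N_LAYERS * per_layer + D_MODEL   # + final RMSNorm
mtp_extra = 2 * VOCAB_SIZE * D_MODEL                 # training-only MTP heads
print(f"{total:,} base, {total + mtp_extra:,} with MTP heads")
```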
## Architecture
Zenyx-Chat is built on a custom transformer decoder with the following key design choices:
- **Multi-head Latent Attention (MLA):** Instead of standard key-value projections, KV representations are compressed into a low-dimensional latent space (`KV_LATENT_DIM=256`) before being projected back to full dimension. This reduces the KV footprint during training while preserving expressiveness.
- **SwiGLU FFN:** Each block uses a gated feedforward layer with a SiLU activation on the gate path and a separate up-projection, following the formulation used in PaLM. The hidden dimension is `int(2 × 1024 × 4/3) = 2730`.
- **RMSNorm:** Pre-normalization with RMSNorm is applied before both the attention and feedforward sublayers, with no bias terms anywhere in the network.
- **Weight Tying:** The token embedding matrix and the LM head share weights, reducing parameter count and improving training stability.
- **Multi-Token Prediction (MTP):** During training, two auxiliary prediction heads supervise the model to predict 2 and 3 tokens ahead simultaneously, improving representation quality. These heads are not used during inference.
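To make the MLA savings above concrete, the sketch below compares the per-layer K/V parameter cost and the per-token cache cost of standard attention against the latent formulation, using this card's dimensions (the inference script in this card recomputes K/V each forward pass rather than caching the latent, so the cache figure is about the design's potential, not this implementation):

```python
D_MODEL, KV_LATENT_DIM = 1024, 256

# Per-layer parameters on the K/V path.
standard_kv = 2 * D_MODEL * D_MODEL                              # k_proj + v_proj
mla_kv = D_MODEL * KV_LATENT_DIM + 2 * KV_LATENT_DIM * D_MODEL   # down + two ups
print(standard_kv, mla_kv)  # 2097152 786432

# Per-token values that would need caching if only the latent were stored.
standard_cache = 2 * D_MODEL      # full K and V: 2048 values per token
latent_cache = KV_LATENT_DIM      # compressed latent: 256 values per token
print(standard_cache / latent_cache)  # 8.0
```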
## Training

### Data Mix
| Dataset | Proportion | Purpose |
|---|---|---|
| HuggingFaceFW/fineweb-edu (10BT sample) | 40% | High-quality educational web text |
| HuggingFaceFW/fineweb (350BT sample) | 25% | Broad general web text |
| HuggingFaceH4/ultrachat_200k | 20% | Multi-turn chat / instruction following |
| bigcode/starcoderdata (Python) | 15% | Python code |
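The proportions above can be realized in several ways; one minimal sketch is per-example source sampling (illustrative only — the actual training pipeline is not published in this card):

```python
import random

# Hypothetical per-example sampling weights matching the data-mix table.
sources = {
    "HuggingFaceFW/fineweb-edu": 0.40,
    "HuggingFaceFW/fineweb": 0.25,
    "HuggingFaceH4/ultrachat_200k": 0.20,
    "bigcode/starcoderdata": 0.15,
}
assert abs(sum(sources.values()) - 1.0) < 1e-9

random.seed(0)  # deterministic for the demo
picks = random.choices(list(sources), weights=list(sources.values()), k=10_000)
frac_edu = picks.count("HuggingFaceFW/fineweb-edu") / len(picks)
print(round(frac_edu, 2))  # close to 0.40
```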
### Training Configuration
| Hyperparameter | Value |
|---|---|
| Max Steps | 50,000 |
| Sequence Length | 2,048 |
| Micro Batch Size | 4 |
| Gradient Accumulation | 8 |
| Effective Batch | 64 seqs / step (2 GPUs) |
| Learning Rate | 3e-4 |
| LR Schedule | Cosine with warmup |
| Warmup Steps | 2,000 |
| Weight Decay | 0.1 |
| Grad Clip | 1.0 |
| Optimizer | AdamW (β₁=0.9, β₂=0.999, ε=1e-6) |
| Precision | FP16 + GradScaler |
| Hardware | 2× NVIDIA T4 (16GB) |
| Gradient Checkpointing | Yes (per-layer) |
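From these settings the token budget works out as follows — a back-of-the-envelope check, assuming all 50,000 steps complete:

```python
micro_batch, grad_accum, gpus, seq_len, max_steps = 4, 8, 2, 2048, 50_000

seqs_per_step = micro_batch * grad_accum * gpus   # 64, as in the table
tokens_per_step = seqs_per_step * seq_len         # 131,072
total_tokens = tokens_per_step * max_steps        # ~6.55B tokens
print(f"{tokens_per_step:,} tokens/step, {total_tokens / 1e9:.2f}B total")
```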
## Usage

### Installation

```shell
pip install torch transformers huggingface_hub
```
### Inference Script

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerFast

# --- CONFIG ---
SEQ_LEN = 2048
D_MODEL = 1024
N_LAYERS = 16
N_HEADS = 16
KV_LATENT_DIM = 256
VOCAB_SIZE = 32768

# --- ARCHITECTURE ---
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def precompute_rope(dim, seq_len, theta=500000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len)
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)  # complex unit rotations

def apply_rope(xq, xk, freqs_cis):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    f = freqs_cis.unsqueeze(0).unsqueeze(2)  # (1, T, 1, d_head/2)
    return (torch.view_as_real(xq_ * f).flatten(3).type_as(xq),
            torch.view_as_real(xk_ * f).flatten(3).type_as(xk))

class MultiHeadLatentAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.d_head = D_MODEL // N_HEADS
        self.q_proj = nn.Linear(D_MODEL, D_MODEL, bias=False)
        # K/V are compressed to a shared latent, then projected back up.
        self.kv_down = nn.Linear(D_MODEL, KV_LATENT_DIM, bias=False)
        self.kv_up_key = nn.Linear(KV_LATENT_DIM, D_MODEL, bias=False)
        self.kv_up_val = nn.Linear(KV_LATENT_DIM, D_MODEL, bias=False)
        self.o_proj = nn.Linear(D_MODEL, D_MODEL, bias=False)

    def forward(self, x, freqs_cis):
        B, T, C = x.size()
        q = self.q_proj(x).view(B, T, N_HEADS, self.d_head)
        kv = self.kv_down(x)
        k = self.kv_up_key(kv).view(B, T, N_HEADS, self.d_head)
        v = self.kv_up_val(kv).view(B, T, N_HEADS, self.d_head)
        q, k = apply_rope(q, k, freqs_cis[:T])
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(y.transpose(1, 2).contiguous().view(B, T, C))

class SwiGLU(nn.Module):
    def __init__(self):
        super().__init__()
        h = int(2 * D_MODEL * 4 / 3)  # 2730
        self.gate = nn.Linear(D_MODEL, h, bias=False)
        self.up = nn.Linear(D_MODEL, h, bias=False)
        self.down = nn.Linear(h, D_MODEL, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))

class TransformerBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln_1 = RMSNorm(D_MODEL)
        self.attn = MultiHeadLatentAttention()
        self.ln_2 = RMSNorm(D_MODEL)
        self.mlp = SwiGLU()

    def forward(self, x, freqs_cis):
        x = x + self.attn(self.ln_1(x), freqs_cis)  # pre-norm residual
        x = x + self.mlp(self.ln_2(x))
        return x

class CustomLLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(VOCAB_SIZE, D_MODEL)
        self.layers = nn.ModuleList([TransformerBlock() for _ in range(N_LAYERS)])
        self.ln_f = RMSNorm(D_MODEL)
        self.lm_head = nn.Linear(D_MODEL, VOCAB_SIZE, bias=False)
        self.lm_head.weight = self.token_emb.weight  # weight tying (see above)
        # MTP heads are defined only so the training checkpoint loads cleanly;
        # they are unused at inference time.
        self.mtp_heads = nn.ModuleList([
            nn.Linear(D_MODEL, VOCAB_SIZE, bias=False) for _ in range(2)
        ])
        self.register_buffer("freqs_cis", precompute_rope(D_MODEL // N_HEADS, SEQ_LEN))

    def forward(self, input_ids):
        x = self.token_emb(input_ids)
        for layer in self.layers:
            x = layer(x, self.freqs_cis)
        x = self.ln_f(x)
        return self.lm_head(x)
```
```python
# --- LOAD ---
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = PreTrainedTokenizerFast.from_pretrained("Arko007/zenyx-v2-tokenizer")
weights_path = hf_hub_download(repo_id="koyelog/chatbotk", filename="pytorch_model.bin")
state_dict = torch.load(weights_path, map_location=device)

model = CustomLLM().to(device)
model.load_state_dict(state_dict["model"] if "model" in state_dict else state_dict)
model.eval()
print("Model loaded!")

# --- GENERATE ---
def generate(prompt, max_new_tokens=200, temperature=0.8, repetition_penalty=1.2):
    input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
    prompt_len = input_ids.shape[1]
    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits = model(input_ids[:, -SEQ_LEN:])
        logits = logits[:, -1, :] / temperature
        # Penalize tokens that have already appeared in the sequence.
        for token_id in set(input_ids[0].tolist()):
            logits[0, token_id] = (
                logits[0, token_id] * repetition_penalty
                if logits[0, token_id] < 0
                else logits[0, token_id] / repetition_penalty
            )
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=1)
        if next_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0, prompt_len:].tolist())

print(generate("Hello, how are you?"))
```
### Generation Parameters
| Parameter | Default | Effect |
|---|---|---|
| temperature | 0.8 | Controls randomness. Lower = more focused. |
| repetition_penalty | 1.2 | Penalizes already-seen tokens. |
| max_new_tokens | 200 | Maximum tokens to generate. |
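To see what the temperature knob in the table does, the toy example below applies softmax to the same logits at two temperatures (illustrative values, not taken from the model):

```python
import math

def softmax(logits, temperature):
    """Numerically stable softmax over temperature-scaled logits."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    return [e / z for e in exps]

logits = [2.0, 1.0, 0.5]
sharp = softmax(logits, temperature=0.5)  # low temperature: peaked distribution
flat = softmax(logits, temperature=2.0)   # high temperature: flatter distribution
print(round(sharp[0], 2), round(flat[0], 2))
```

Lower temperatures concentrate probability on the top token (more focused, less diverse output); higher temperatures flatten the distribution toward uniform sampling.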
## Limitations & Intended Use

- **Intended use:** Research, experimentation, and educational exploration of custom LLM architectures. Not intended for production use or safety-critical applications.
- **Limitations:** This model is undertrained relative to production-grade LLMs and may produce incoherent, factually incorrect, or biased outputs. Metrics will be added as training matures.
- **Not instruction-tuned via RLHF:** Chat capability comes purely from the data mix (UltraChat); no reinforcement learning from human feedback was applied.
- **Language:** English only.
## Citation

If you use this model or find the architecture useful, please cite:

```bibtex
@misc{chatbotk-2026,
  author = {koyelog},
  title = {chatbotk: A Custom 281M Causal LM with MLA and SwiGLU},
  year = {2026},
  publisher = {Hugging Face},
  url = {https://huggingface.co/koyelog/chatbotk}
}
```
## License
- Apache 2.0