---
license: apache-2.0
language:
  - en
tags:
  - causal-lm
  - pytorch
  - custom-architecture
  - mla
  - swiglu
  - chat
  - instruction-following
pipeline_tag: text-generation
library_name: pytorch
datasets:
  - HuggingFaceFW/fineweb
  - HuggingFaceFW/fineweb-edu
  - bigcode/starcoderdata
  - HuggingFaceH4/ultrachat_200k
---

# Zenyx-Chat

**Zenyx-Chat** is a 214M-parameter causal language model designed for conversational and instruction-following tasks. It was trained from scratch on a curated mix of web, code, and chat datasets, using a custom architecture that features Multi-head Latent Attention (MLA) and SwiGLU feedforward layers.

> ⚠️ **This model is actively training.** Evaluation metrics will be added as training progresses.


## Model Details

| Property | Value |
|---|---|
| Architecture | Custom decoder-only Transformer |
| Parameters | ~214M (base) |
| Layers | 16 |
| Hidden Dimension | 1024 |
| Attention Heads | 16 |
| KV Latent Dimension | 256 (MLA compression) |
| MLP Type | SwiGLU |
| Positional Encoding | RoPE (θ = 500,000) |
| Context Length | 2,048 tokens |
| Vocabulary Size | 32,768 |
| Tokenizer | Arko007/zenyx-v2-tokenizer |
| Precision | FP16 (training), FP32 (inference) |
| Framework | PyTorch |
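As a sanity check, the ~214M figure can be tallied by hand from the table above (a back-of-the-envelope estimate, not an official breakdown; it assumes the weight-tied embedding/LM head is counted once):

```python
D, KV, V, L = 1024, 256, 32768, 16      # hidden dim, KV latent dim, vocab, layers
H = int(2 * D * 4 / 3)                  # SwiGLU hidden dim = 2730

attn  = 2 * D * D + 3 * D * KV          # q_proj, o_proj, kv_down, kv_up_key, kv_up_val
mlp   = 3 * D * H                       # gate, up, down projections
norms = 2 * D                           # two RMSNorms per block
per_layer = attn + mlp + norms

total = L * per_layer + V * D + D       # + tied embedding/LM head + final norm
print(f"{total / 1e6:.1f}M parameters")
```

This lands at roughly 213.9M, consistent with the "~214M (base)" entry.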

## Architecture

Zenyx-Chat is built on a custom transformer decoder with the following key design choices:

- **Multi-head Latent Attention (MLA):** Instead of standard key-value projections, KV representations are compressed into a low-dimensional latent space (`KV_LATENT_DIM = 256`) before being projected back to full dimension. This reduces the KV footprint during training while preserving expressiveness.

- **SwiGLU FFN:** Each block uses a gated feedforward layer with a SiLU activation on the gate path and a separate up-projection, following the formulation used in PaLM. The hidden dimension is `int(2 × 1024 × 4/3) = 2730`.

- **RMSNorm:** Pre-normalization with RMSNorm is applied before both the attention and feedforward sublayers, with no bias terms anywhere in the network.

- **Weight Tying:** The token embedding matrix and the LM head share weights, reducing the parameter count and improving training stability.

- **Multi-Token Prediction (MTP):** During training, two auxiliary prediction heads supervise the model to predict the tokens 2 and 3 positions ahead simultaneously, improving representation quality. These heads are not used at inference.
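The card does not spell out the MTP objective, so the sketch below shows one plausible formulation, assuming the auxiliary heads are trained with cross-entropy against further-shifted targets (the function name, signature, and uniform loss weighting are illustrative, not the actual training code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def mtp_loss(hidden, lm_head, mtp_heads, targets):
    """hidden: (B, T, D); targets: (B, T) next-token ids aligned with lm_head."""
    # Main next-token objective: predicts the token at position t+1.
    loss = F.cross_entropy(lm_head(hidden).transpose(1, 2), targets)
    # Auxiliary heads predict 2 and 3 tokens ahead by shifting targets further.
    for k, head in enumerate(mtp_heads, start=1):
        logits = head(hidden[:, :-k]).transpose(1, 2)   # (B, V, T-k)
        loss = loss + F.cross_entropy(logits, targets[:, k:])
    return loss
```

At inference only `lm_head` is evaluated, which is why the auxiliary heads add no serving cost.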


## Training

### Data Mix

| Dataset | Proportion | Purpose |
|---|---|---|
| HuggingFaceFW/fineweb-edu (10BT sample) | 40% | High-quality educational web text |
| HuggingFaceFW/fineweb (350BT sample) | 25% | Broad general web text |
| HuggingFaceH4/ultrachat_200k | 20% | Multi-turn chat / instruction following |
| bigcode/starcoderdata (Python) | 15% | Python code |
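One simple way to realize such a mixture during data loading is weighted source sampling: each training document is drawn from a source chosen according to the proportions above. The sketch below is illustrative only (the actual data pipeline is not published here; the source labels are shorthand for the datasets in the table):

```python
import random

MIX = {
    "fineweb-edu": 0.40,
    "fineweb": 0.25,
    "ultrachat": 0.20,
    "starcoder-python": 0.15,
}

def sample_sources(n, seed=0):
    """Draw n source labels according to the configured mixture weights."""
    rng = random.Random(seed)
    names, weights = zip(*MIX.items())
    return rng.choices(names, weights=weights, k=n)

draws = sample_sources(10_000)
# Empirical frequencies approach the configured mix as n grows.
```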

### Training Configuration

| Hyperparameter | Value |
|---|---|
| Max Steps | 50,000 |
| Sequence Length | 2,048 |
| Micro Batch Size | 4 |
| Gradient Accumulation | 8 |
| Effective Batch | 64 seqs/step (2 GPUs) |
| Learning Rate | 3e-4 |
| LR Schedule | Cosine with warmup |
| Warmup Steps | 2,000 |
| Weight Decay | 0.1 |
| Grad Clip | 1.0 |
| Optimizer | AdamW (β₁=0.9, β₂=0.999, ε=1e-6) |
| Precision | FP16 + GradScaler |
| Hardware | 2× NVIDIA T4 (16 GB) |
| Gradient Checkpointing | Yes (per-layer) |
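The effective batch follows directly from the table: micro-batch × gradient accumulation × GPUs. Multiplying by the sequence length gives the token budget per optimizer step and over the full run:

```python
micro_batch, grad_accum, gpus = 4, 8, 2
seq_len, max_steps = 2048, 50_000

seqs_per_step   = micro_batch * grad_accum * gpus    # 64 sequences per step
tokens_per_step = seqs_per_step * seq_len            # 131,072 tokens per step
total_tokens    = tokens_per_step * max_steps
print(f"{total_tokens / 1e9:.2f}B tokens over the full run")
```

That works out to roughly 6.55B tokens at the configured 50,000 steps.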

## Usage

### Installation

```bash
pip install torch transformers huggingface_hub
```

### Inference Script

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedTokenizerFast
from huggingface_hub import hf_hub_download

# --- CONFIG ---
SEQ_LEN       = 2048
D_MODEL       = 1024
N_LAYERS      = 16
N_HEADS       = 16
KV_LATENT_DIM = 256
VOCAB_SIZE    = 32768

# --- ARCHITECTURE ---
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps    = eps
        self.weight = nn.Parameter(torch.ones(dim))
    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def precompute_rope(dim, seq_len, theta=500000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t     = torch.arange(seq_len)
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)  # complex rotations

def apply_rope(xq, xk, freqs_cis):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    f   = freqs_cis.unsqueeze(0).unsqueeze(2)  # (1, T, 1, d_head/2)
    return (torch.view_as_real(xq_ * f).flatten(3).type_as(xq),
            torch.view_as_real(xk_ * f).flatten(3).type_as(xk))

class MultiHeadLatentAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.d_head    = D_MODEL // N_HEADS
        self.q_proj    = nn.Linear(D_MODEL, D_MODEL,       bias=False)
        self.kv_down   = nn.Linear(D_MODEL, KV_LATENT_DIM, bias=False)  # compress KV to latent
        self.kv_up_key = nn.Linear(KV_LATENT_DIM, D_MODEL, bias=False)
        self.kv_up_val = nn.Linear(KV_LATENT_DIM, D_MODEL, bias=False)
        self.o_proj    = nn.Linear(D_MODEL, D_MODEL,       bias=False)
    def forward(self, x, freqs_cis):
        B, T, C = x.size()
        q  = self.q_proj(x).view(B, T, N_HEADS, self.d_head)
        kv = self.kv_down(x)
        k  = self.kv_up_key(kv).view(B, T, N_HEADS, self.d_head)
        v  = self.kv_up_val(kv).view(B, T, N_HEADS, self.d_head)
        q, k = apply_rope(q, k, freqs_cis[:T])
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(y.transpose(1, 2).contiguous().view(B, T, C))

class SwiGLU(nn.Module):
    def __init__(self):
        super().__init__()
        h         = int(2 * D_MODEL * 4 / 3)  # 2730
        self.gate = nn.Linear(D_MODEL, h, bias=False)
        self.up   = nn.Linear(D_MODEL, h, bias=False)
        self.down = nn.Linear(h, D_MODEL, bias=False)
    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))

class TransformerBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln_1 = RMSNorm(D_MODEL)
        self.attn = MultiHeadLatentAttention()
        self.ln_2 = RMSNorm(D_MODEL)
        self.mlp  = SwiGLU()
    def forward(self, x, freqs_cis):
        x = x + self.attn(self.ln_1(x), freqs_cis)
        x = x + self.mlp(self.ln_2(x))
        return x

class CustomLLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(VOCAB_SIZE, D_MODEL)
        self.layers    = nn.ModuleList([TransformerBlock() for _ in range(N_LAYERS)])
        self.ln_f      = RMSNorm(D_MODEL)
        self.lm_head   = nn.Linear(D_MODEL, VOCAB_SIZE, bias=False)
        self.lm_head.weight = self.token_emb.weight  # weight tying with the embedding
        self.mtp_heads = nn.ModuleList([  # training-only MTP heads; unused at inference
            nn.Linear(D_MODEL, VOCAB_SIZE, bias=False) for _ in range(2)
        ])
        self.register_buffer("freqs_cis", precompute_rope(D_MODEL // N_HEADS, SEQ_LEN))
    def forward(self, input_ids):
        x = self.token_emb(input_ids)
        for layer in self.layers:
            x = layer(x, self.freqs_cis)
        x = self.ln_f(x)
        return self.lm_head(x)

# --- LOAD ---
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer    = PreTrainedTokenizerFast.from_pretrained("Arko007/zenyx-v2-tokenizer")
weights_path = hf_hub_download(repo_id="koyelog/chatbotk", filename="pytorch_model.bin")
state_dict   = torch.load(weights_path, map_location=device)

model = CustomLLM().to(device)
model.load_state_dict(state_dict["model"] if "model" in state_dict else state_dict)
model.eval()
print("Model loaded!")

# --- GENERATE ---
@torch.no_grad()
def generate(prompt, max_new_tokens=200, temperature=0.8, repetition_penalty=1.2):
    input_ids  = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
    prompt_len = input_ids.shape[1]
    for _ in range(max_new_tokens):
        logits = model(input_ids[:, -SEQ_LEN:])
        logits = logits[:, -1, :] / temperature
        # Penalize every token that already appears in the context.
        for token_id in set(input_ids[0].tolist()):
            logits[0, token_id] = (
                logits[0, token_id] * repetition_penalty
                if logits[0, token_id] < 0
                else logits[0, token_id] / repetition_penalty
            )
        probs      = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids  = torch.cat([input_ids, next_token], dim=1)
        if next_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0, prompt_len:].tolist())

print(generate("Hello, how are you?"))
```

### Generation Parameters

| Parameter | Default | Effect |
|---|---|---|
| temperature | 0.8 | Controls randomness; lower values are more focused. |
| repetition_penalty | 1.2 | Penalizes already-seen tokens. |
| max_new_tokens | 200 | Maximum number of tokens to generate. |
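To see the effect of temperature concretely, here is a tiny standalone demonstration (independent of the model) of how dividing logits by a smaller temperature sharpens the sampling distribution:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.0])
p_warm = F.softmax(logits / 1.2, dim=-1)  # flatter distribution: more diverse samples
p_cool = F.softmax(logits / 0.5, dim=-1)  # sharper: mass concentrates on the top token
```

The cooler distribution assigns strictly more probability to the argmax token, which is why lower temperatures read as more focused but less varied.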

## Limitations & Intended Use

- **Intended use:** Research, experimentation, and educational exploration of custom LLM architectures. Not intended for production use or safety-critical applications.
- **Limitations:** This model is undertrained relative to production-grade LLMs. It may produce incoherent, factually incorrect, or biased outputs. Metrics will be added as training matures.
- **Not instruction-tuned via RLHF:** The chat capability comes purely from the data mix (UltraChat), with no reinforcement learning from human feedback.
- **Language:** English only.

## Citation

If you use this model or find the architecture useful, please cite:

```bibtex
@misc{chatbotk-2026,
  author    = {koyelog},
  title     = {chatbotk: A Custom 214M Causal LM with MLA and SwiGLU},
  year      = {2026},
  publisher = {Hugging Face},
  url       = {https://huggingface.co/koyelog/chatbotk}
}
```

## License

Apache 2.0