---
license: apache-2.0
language:
- en
tags:
- causal-lm
- pytorch
- custom-architecture
- mla
- swiglu
- chat
- instruction-following
pipeline_tag: text-generation
library_name: pytorch
datasets:
- HuggingFaceFW/fineweb
- HuggingFaceFW/fineweb-edu
- bigcode/starcoderdata
- HuggingFaceH4/ultrachat_200k
---

# Zenyx-Chat

**Zenyx-Chat** is a 214M-parameter causal language model for conversational and instruction-following tasks. It is trained from scratch on a curated mix of web, code, and chat data, using a custom architecture that combines Multi-head Latent Attention (MLA) and SwiGLU feedforward layers.

> ⚠️ **Model is actively training.** Evaluation metrics will be added as training progresses.

---

## Model Details
31
+
32
+ | Property | Value |
33
+ |---|---|
34
+ | **Architecture** | Custom Decoder-only Transformer |
35
+ | **Parameters** | ~214M (base) |
36
+ | **Layers** | 16 |
37
+ | **Hidden Dimension** | 1024 |
38
+ | **Attention Heads** | 16 |
39
+ | **KV Latent Dimension** | 256 (MLA compression) |
40
+ | **MLP Type** | SwiGLU |
41
+ | **Positional Encoding** | RoPE (θ = 500,000) |
42
+ | **Context Length** | 2,048 tokens |
43
+ | **Vocabulary Size** | 32,768 |
44
+ | **Tokenizer** | `Arko007/zenyx-v2-tokenizer` |
45
+ | **Precision** | FP16 (trained), FP32 (inference) |
46
+ | **Framework** | PyTorch |
47
+
48
+ ---

## Architecture

Zenyx-Chat is built on a custom transformer decoder with the following key design choices:

**Multi-head Latent Attention (MLA):** Instead of standard key-value projections, KV representations are compressed into a low-dimensional latent space (`KV_LATENT_DIM=256`) and then projected back up to full dimension. This reduces the KV memory footprint while preserving expressiveness.
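
As a rough illustration of the savings (a back-of-the-envelope sketch assuming fp16 storage, not a measured benchmark):

```python
# Approximate KV memory per token per layer, using the dimensions above
# and assuming 2-byte (fp16) storage.
D_MODEL, KV_LATENT_DIM, BYTES = 1024, 256, 2

standard_kv = 2 * D_MODEL * BYTES   # separate full-width K and V vectors
mla_latent = KV_LATENT_DIM * BYTES  # one shared latent, re-expanded on the fly

print(standard_kv, mla_latent, standard_kv // mla_latent)  # 4096 512 8
```

Under these assumptions the latent representation is 8× smaller than storing full K and V vectors.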

**SwiGLU FFN:** Each block uses a gated feedforward layer: a SiLU activation on the gate path is multiplied elementwise with a separate up-projection, following the formulation used in [PaLM](https://arxiv.org/abs/2204.02311). The hidden dimension is `int(2 × 1024 × 4/3) = 2730`.

**RMSNorm:** Pre-normalization with RMSNorm is applied before both the attention and feedforward sublayers, with no bias terms anywhere in the network.

**Weight Tying:** The token embedding matrix and the LM head share weights, reducing the parameter count and improving training stability.
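
In PyTorch, tying amounts to assigning one shared parameter to both modules (a minimal sketch with the dimensions from the table above):

```python
import torch.nn as nn

# Minimal illustration of weight tying: the embedding and LM head
# share a single parameter tensor, so gradients accumulate into it
# from both the input and output side.
emb = nn.Embedding(32768, 1024)
head = nn.Linear(1024, 32768, bias=False)
head.weight = emb.weight  # both modules now reference the same tensor
```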

**Multi-Token Prediction (MTP):** During training, two auxiliary prediction heads additionally supervise the model to predict the tokens 2 and 3 positions ahead, improving representation quality. These heads are not used during inference.
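
The training loss is not published here; a plausible sketch of an MTP-style objective (the `aux_weight` of 0.3 and the exact shifting scheme are assumptions) looks like:

```python
import torch
import torch.nn.functional as F

# Hypothetical MTP loss sketch: the main head predicts the next token;
# auxiliary head i predicts the token (i+1) steps ahead. The aux_weight
# value is an assumption, not the model's actual training configuration.
def mtp_loss(hidden, targets, lm_head, mtp_heads, aux_weight=0.3):
    # hidden: (B, T, D) final hidden states; targets: (B, T) next-token ids
    loss = F.cross_entropy(lm_head(hidden).transpose(1, 2), targets)
    for i, head in enumerate(mtp_heads, start=1):
        # shift so position t is supervised by the token (i+1) steps ahead
        h, t = hidden[:, :-i], targets[:, i:]
        loss = loss + aux_weight * F.cross_entropy(head(h).transpose(1, 2), t)
    return loss
```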

---
65
+
66
+ ## Training
67
+
68
+ ### Data Mix
69
+
70
+ | Dataset | Proportion | Purpose |
71
+ |---|---|---|
72
+ | `HuggingFaceFW/fineweb-edu` (10BT sample) | 40% | High-quality educational web text |
73
+ | `HuggingFaceFW/fineweb` (350BT sample) | 25% | Broad general web text |
74
+ | `HuggingFaceH4/ultrachat_200k` | 20% | Multi-turn chat / instruction following |
75
+ | `bigcode/starcoderdata` (Python) | 15% | Python code |
76
+
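
One simple way to realize these proportions is per-sequence source sampling. The actual mixing code is not published; this is an illustrative sketch using only the proportions from the table:

```python
import random

# Mix proportions from the data-mix table; the sampling scheme itself
# is an illustrative assumption, not the published training pipeline.
MIX = {
    "HuggingFaceFW/fineweb-edu": 0.40,
    "HuggingFaceFW/fineweb": 0.25,
    "HuggingFaceH4/ultrachat_200k": 0.20,
    "bigcode/starcoderdata": 0.15,
}

def next_source(rng: random.Random) -> str:
    """Pick which dataset the next training sequence is drawn from."""
    names, weights = zip(*MIX.items())
    return rng.choices(names, weights=weights, k=1)[0]
```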

### Training Configuration

| Hyperparameter | Value |
|---|---|
| Max Steps | 50,000 |
| Sequence Length | 2,048 |
| Micro Batch Size | 4 |
| Gradient Accumulation | 8 |
| Effective Batch | 64 seqs/step (2 GPUs) |
| Learning Rate | 3e-4 |
| LR Schedule | Cosine with warmup |
| Warmup Steps | 2,000 |
| Weight Decay | 0.1 |
| Grad Clip | 1.0 |
| Optimizer | AdamW (β₁=0.9, β₂=0.999, ε=1e-6) |
| Precision | FP16 + GradScaler |
| Hardware | 2× NVIDIA T4 (16GB) |
| Gradient Checkpointing | Yes (per-layer) |

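
A cosine schedule with linear warmup, using the values from the table, can be sketched as follows (the decay floor `min_lr = 0` is an assumption; the card does not state one):

```python
import math

# Linear warmup followed by cosine decay, using the table's hyperparameters.
# min_lr = 0 is an assumed floor, not a documented value.
MAX_STEPS, WARMUP_STEPS, PEAK_LR = 50_000, 2_000, 3e-4

def lr_at(step: int, min_lr: float = 0.0) -> float:
    if step < WARMUP_STEPS:
        return PEAK_LR * step / WARMUP_STEPS  # linear warmup to peak
    progress = (step - WARMUP_STEPS) / (MAX_STEPS - WARMUP_STEPS)
    return min_lr + 0.5 * (PEAK_LR - min_lr) * (1.0 + math.cos(math.pi * progress))
```

The learning rate rises linearly to 3e-4 over the first 2,000 steps, then decays along a half-cosine to the floor at step 50,000.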
96
+ ---
97
+
98
+ ## Usage
99
+
100
+ ### Installation
101
+
102
+ ```bash
103
+ pip install torch transformers huggingface_hub
104
+ ```

```python
# Inference script (standalone; re-implements the architecture to load raw weights)

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedTokenizerFast
from huggingface_hub import hf_hub_download

# --- CONFIG ---
SEQ_LEN = 2048
D_MODEL = 1024
N_LAYERS = 16
N_HEADS = 16
KV_LATENT_DIM = 256
VOCAB_SIZE = 32768

# --- ARCHITECTURE ---
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def precompute_rope(dim, seq_len, theta=500000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len)
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)  # complex rotations

def apply_rope(xq, xk, freqs_cis):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    f = freqs_cis.unsqueeze(0).unsqueeze(2)  # (1, T, 1, d_head/2) for broadcasting
    return (torch.view_as_real(xq_ * f).flatten(3).type_as(xq),
            torch.view_as_real(xk_ * f).flatten(3).type_as(xk))

class MultiHeadLatentAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.d_head = D_MODEL // N_HEADS
        self.q_proj = nn.Linear(D_MODEL, D_MODEL, bias=False)
        self.kv_down = nn.Linear(D_MODEL, KV_LATENT_DIM, bias=False)    # compress to latent
        self.kv_up_key = nn.Linear(KV_LATENT_DIM, D_MODEL, bias=False)  # expand latent to K
        self.kv_up_val = nn.Linear(KV_LATENT_DIM, D_MODEL, bias=False)  # expand latent to V
        self.o_proj = nn.Linear(D_MODEL, D_MODEL, bias=False)

    def forward(self, x, freqs_cis):
        B, T, C = x.size()
        q = self.q_proj(x).view(B, T, N_HEADS, self.d_head)
        kv = self.kv_down(x)
        k = self.kv_up_key(kv).view(B, T, N_HEADS, self.d_head)
        v = self.kv_up_val(kv).view(B, T, N_HEADS, self.d_head)
        q, k = apply_rope(q, k, freqs_cis[:T])
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(y.transpose(1, 2).contiguous().view(B, T, C))

class SwiGLU(nn.Module):
    def __init__(self):
        super().__init__()
        h = int(2 * D_MODEL * 4 / 3)  # 2730
        self.gate = nn.Linear(D_MODEL, h, bias=False)
        self.up = nn.Linear(D_MODEL, h, bias=False)
        self.down = nn.Linear(h, D_MODEL, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))

class TransformerBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln_1 = RMSNorm(D_MODEL)
        self.attn = MultiHeadLatentAttention()
        self.ln_2 = RMSNorm(D_MODEL)
        self.mlp = SwiGLU()

    def forward(self, x, freqs_cis):
        x = x + self.attn(self.ln_1(x), freqs_cis)
        x = x + self.mlp(self.ln_2(x))
        return x

class CustomLLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(VOCAB_SIZE, D_MODEL)
        self.layers = nn.ModuleList([TransformerBlock() for _ in range(N_LAYERS)])
        self.ln_f = RMSNorm(D_MODEL)
        self.lm_head = nn.Linear(D_MODEL, VOCAB_SIZE, bias=False)
        self.lm_head.weight = self.token_emb.weight  # weight tying (see Architecture)
        self.mtp_heads = nn.ModuleList([
            nn.Linear(D_MODEL, VOCAB_SIZE, bias=False) for _ in range(2)
        ])  # training-only MTP heads; unused at inference
        self.register_buffer("freqs_cis", precompute_rope(D_MODEL // N_HEADS, SEQ_LEN))

    def forward(self, input_ids):
        x = self.token_emb(input_ids)
        for layer in self.layers:
            x = layer(x, self.freqs_cis)
        x = self.ln_f(x)
        return self.lm_head(x)

# --- LOAD ---
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = PreTrainedTokenizerFast.from_pretrained("Arko007/zenyx-v2-tokenizer")
weights_path = hf_hub_download(repo_id="koyelog/chatbotk", filename="pytorch_model.bin")
state_dict = torch.load(weights_path, map_location=device)

model = CustomLLM().to(device)
model.load_state_dict(state_dict["model"] if "model" in state_dict else state_dict)
model.eval()
print("Model loaded!")

# --- GENERATE ---
def generate(prompt, max_new_tokens=200, temperature=0.8, repetition_penalty=1.2):
    input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
    prompt_len = input_ids.shape[1]
    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits = model(input_ids[:, -SEQ_LEN:])
        logits = logits[:, -1, :] / temperature
        # Penalize every token already present in the context.
        for token_id in set(input_ids[0].tolist()):
            logits[0, token_id] = (
                logits[0, token_id] * repetition_penalty
                if logits[0, token_id] < 0
                else logits[0, token_id] / repetition_penalty
            )
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=1)
        if next_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0, prompt_len:].tolist())

print(generate("Hello, how are you?"))
```

## Generation Parameters

| Parameter | Default | Effect |
|---|---|---|
| `temperature` | 0.8 | Controls randomness; lower = more focused. |
| `repetition_penalty` | 1.2 | Penalizes already-seen tokens. |
| `max_new_tokens` | 200 | Maximum number of tokens to generate. |
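
The effect of temperature can be seen directly on a toy distribution (illustrative values, not model logits):

```python
import torch
import torch.nn.functional as F

# Lower temperature sharpens the sampling distribution;
# higher temperature flattens it toward uniform.
logits = torch.tensor([2.0, 1.0, 0.5])  # arbitrary example logits
for t in (0.5, 0.8, 1.5):
    probs = F.softmax(logits / t, dim=-1)
    print(f"temperature={t}: max prob = {probs.max().item():.3f}")
```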

## Limitations & Intended Use

- **Intended Use**: Research, experimentation, and educational exploration of custom LLM architectures. Not intended for production use or safety-critical applications.
- **Limitations**: This model is undertrained relative to production-grade LLMs. It may produce incoherent, factually incorrect, or biased outputs. Metrics will be added as training matures.
- **Not instruction-tuned via RLHF**: The chat capability comes purely from the data mix (UltraChat); no reinforcement learning from human feedback was applied.
- **Language**: English only.

## Citation

If you use this model or find the architecture useful, please cite:

```bibtex
@misc{chatbotk-2026,
  author    = {koyelog},
  title     = {chatbotk: A Custom 214M Causal LM with MLA and SwiGLU},
  year      = {2026},
  publisher = {Hugging Face},
  url       = {https://huggingface.co/koyelog/chatbotk}
}
```

## License

Apache-2.0