---
language:
- en
license: apache-2.0
library_name: transformers
pipeline_tag: text-generation
tags:
- mixture-of-attentions
- distance-attention
- metric-attention
- mqa
- hyperffn
- router-gating
datasets:
- nvidia/Nemotron-Math-HumanReasoning
- WeMake/Intelligent-Content-Understanding
---
# MoAMetricLM‑100M — Mixture of Attentions (MoA)
**A geometry‑aware Transformer that mixes several attention mechanisms and routes them with a metric‑based router.**
- **Parameters:** ~185 M total (≈ 100 M effectively active per token, owing to the mixture routing)
- **Task:** Causal language modeling (decoder‑only)
- **Library:** 🤗 Transformers
- **KV cache:** Not yet implemented (generation recomputes the full context at every step)
---
## Model card
| **Model ID** | `reaperdoesntknow/MoA-100M` |
|--------------|-------------------------------------|
| **Architecture** | `moa_metric` (custom) |
| **Tokenizer** | GPT‑2 (`gpt2`) – `pad_token` set to `eos_token` |
| **Context length** | 2048 tokens |
| **Training data** | ≈ 256 k tokens from each of the two datasets listed above |
| **Training compute** | CPU‑only (Intel), FP32 |
| **Training hyper‑parameters** | LR = 5e‑4 (AdamW), batch = 4, seq ≤ 512, 500 k total tokens |
| **Final loss** | ≈ 0.30 (train) |
| **License** | Apache‑2.0 |
| **Safety** | No alignment or safety fine‑tuning – outputs may be biased or inaccurate. |
| **Intended use** | Research on geometry‑aware attention, structured sparsity, and mixture‑of‑attention models. |
| **Limitations** | • No KV‑cache → slower generation. <br>• Small token budget → not a general‑purpose LM. <br>• No safety/alignment training. |
| **Out‑of‑scope** | High‑stakes applications (medical, legal, etc.) without further evaluation. |
---
## Overview
MoA replaces the classic dot‑product attention with **metric‑based attention** and blends **four** distinct heads per Transformer block:
| Head type | Description |
|-----------|-------------|
| **LocalConvHead** | Depthwise‑separable 1‑D convolution → captures short‑range context. |
| **Metric Multi‑Head Attention (MetricMHAttention)** | Soft‑min over **L2 / cosine / diagonal‑Mahalanobis** distances: <br> \(\displaystyle \text{attn}_{h}(i,j) \propto \exp\!\big(-\alpha_h\|q_i-k_j\|^2\big)\) |
| **Metric MQA** | Multi‑Query attention (shared K/V) in the same metric space – cheaper than full MHA. |
| **ChannelMixHead** | Per‑token MLP that mixes channel dimensions (no positional mixing). |
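The L2 variant of the score above can be sketched in a few lines (single head, no masking; `metric_attention_scores` is an illustrative name, not the model's actual module):

```python
import torch

def metric_attention_scores(q, k, alpha=1.0):
    # Squared L2 distances via the expansion ||q - k||^2 = ||q||^2 + ||k||^2 - 2 q.k
    q_sq = (q ** 2).sum(-1, keepdim=True)            # (batch, Tq, 1)
    k_sq = (k ** 2).sum(-1).unsqueeze(-2)            # (batch, 1, Tk)
    d2 = (q_sq + k_sq - 2 * q @ k.transpose(-2, -1)).clamp_min(0.0)
    # Soft-min over distances == softmax over negated, scaled distances
    return torch.softmax(-alpha * d2, dim=-1)

q = torch.randn(2, 4, 8)  # (batch, seq, dim)
k = torch.randn(2, 4, 8)
attn = metric_attention_scores(q, k)  # each row sums to 1
```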
A **token‑wise router** decides, for each token, which head(s) to use and applies **feature‑gates** (FiLM‑style) and **router‑bias gates** for up/down‑scaling.
The **FFN** is a **HyperFFN** – three parallel branches (SwiGLU MLP, separable‑conv, low‑rank) combined by a **branch router**. LayerScale and optional DropPath keep training stable.
### Regularisation (optional)
* **Triangle‑inequality (TI) penalty** on sampled triples to encourage true‑metric behaviour.
* **Ball pruning** – each head learns an **origin** \(o_h\) and **radius** \(r_h\); keys outside the ball are masked, giving structured sparsity.
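Both ideas can be sketched as follows (function names and shapes are illustrative, not the repository's actual code; note that plain L2 is a true metric, so the TI penalty only bites for learned distances such as diagonal Mahalanobis):

```python
import torch
import torch.nn.functional as F

def ball_mask(k, origin, radius):
    # Keep only keys inside the learned ball; masked keys would later
    # receive -inf scores before the softmax.
    return torch.linalg.norm(k - origin, dim=-1) <= radius

def ti_penalty(x, num_samples=64):
    # Triangle-inequality violation relu(d(a,c) - d(a,b) - d(b,c))
    # averaged over randomly sampled triples.
    idx = torch.randint(0, x.size(0), (num_samples, 3))
    a, b, c = x[idx[:, 0]], x[idx[:, 1]], x[idx[:, 2]]
    d = lambda u, v: torch.linalg.norm(u - v, dim=-1)
    return F.relu(d(a, c) - d(a, b) - d(b, c)).mean()

keys = torch.randn(2, 6, 8)
mask = ball_mask(keys, origin=torch.zeros(8), radius=3.0)  # (2, 6) bool
penalty = ti_penalty(torch.randn(32, 8))  # ~0 for plain L2
```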
---
## Architecture diagram (high‑level)
```
Input → Embedding → (PreNorm) → Block₁ → … → Blockₙ → LM‑Head → Output
│
├─ LocalConvHead
├─ MetricMHAttention
├─ MetricMQA
└─ ChannelMixHead
(router decides per‑token)
Each Block also contains:
→ HyperFFN (SwiGLU | Conv | Low‑rank) ← branch router
→ LayerScale + DropPath
```
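The branch router in the diagram can be approximated with a token-wise softmax gate over parallel branches (a dense-mixing sketch with linear stand-ins for the SwiGLU/conv/low-rank branches; the model may use top-k routing instead):

```python
import torch
import torch.nn as nn

class BranchRouter(nn.Module):
    """Token-wise softmax gate over parallel FFN branches (illustrative)."""
    def __init__(self, dim, num_branches=3):
        super().__init__()
        self.gate = nn.Linear(dim, num_branches)

    def forward(self, x, branches):
        w = torch.softmax(self.gate(x), dim=-1)              # (B, T, nB)
        ys = torch.stack([b(x) for b in branches], dim=-1)   # (B, T, D, nB)
        return (ys * w.unsqueeze(-2)).sum(-1)                # (B, T, D)

dim = 16
branches = [nn.Linear(dim, dim) for _ in range(3)]  # stand-in branches
y = BranchRouter(dim)(torch.randn(2, 5, dim), branches)
```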
---
## Configuration (example)
```jsonc
{
"model_type": "moa_metric",
"vocab_size": 50257,
"dim": 768,
"num_layers": 12,
"attn_heads": 8,
"mqa_q_heads": 8,
"mixer_hidden": 3072,
"ffn_hidden": 3072,
"metric": "l2", // "l2" | "cosine" | "maha_diag"
"alpha_init": 1.0,
"learn_alpha": true,
"use_balls": true,
"radius_init": 3.0,
"learn_radius": true,
"origin_init_scale": 0.0,
"maha_init": 1.0,
"ti_reg_weight": 0.0,
"ti_reg_samples": 0,
"router_hidden": 128,
"router_dropout": 0.1,
"router_temperature": 1.0,
"attn_drop": 0.1,
"proj_drop": 0.1,
"drop_path": 0.0,
"max_position_embeddings": 2048,
"pad_token_id": 50256,
"bos_token_id": 50256,
"eos_token_id": 50256
}
```
> **Tip:** If you use the GPT‑2 tokenizer, set `pad_token = eos_token` and make sure `vocab_size` matches the tokenizer (50257).
---
## Quick‑start (inference)
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model_id = "reaperdoesntknow/MoA-100M"
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> tokenizer.pad_token = tokenizer.eos_token # needed for the GPT‑2 tokenizer
>>> model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)  # custom "moa_metric" architecture
>>> prompt = "Explain metric‑based attention in simple terms:"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> output_ids = model.generate(
... **inputs,
... max_new_tokens=128,
... do_sample=False, # deterministic; set temperature>0 for sampling
... pad_token_id=tokenizer.pad_token_id,
... )
>>> print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```
*Note:* Because the KV cache is not implemented, every generation step re-runs attention over the full context, so decoding slows noticeably as the sequence grows.
---
## Training (custom loop sketch)
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader
import torch
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
def collate_fn(examples):
    batch = tokenizer(
        [ex["text"] for ex in examples],
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    labels = batch["input_ids"].clone()
    labels[batch["attention_mask"] == 0] = -100  # ignore padding in the loss
    batch["labels"] = labels
    return batch
# dataset = load_dataset(..., split="train") # must contain a 'text' field
# loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
model = AutoModelForCausalLM.from_pretrained(
    "reaperdoesntknow/MoA-100M",
    trust_remote_code=True,  # custom "moa_metric" architecture
)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-4,
    betas=(0.9, 0.95),
    weight_decay=0.01,
)
model.train()
for batch in loader:  # `loader` from the commented-out DataLoader above
    out = model(**batch)
    out.loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.2)
    optimizer.step()
    optimizer.zero_grad()
```
---
## Evaluation checklist
* **Perplexity** on a held‑out split of the two training datasets.
* **Ablation studies** (keep total token budget constant):
* L2 vs. cosine vs. diagonal‑Mahalanobis distance.
* With / without ball pruning.
* With / without HyperFFN branch router.
* With / without TI regulariser.
* **Speed / memory** comparison against a vanilla GPT‑2‑size model (same `dim`/`layers`).
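A minimal perplexity sketch for the first checklist item (a token-weighted average would be slightly more precise with variable-length batches; names are illustrative):

```python
import math
import torch

@torch.no_grad()
def perplexity(model, loader):
    # Mean per-batch cross-entropy, exponentiated.
    model.eval()
    losses = [model(**batch).loss.item() for batch in loader]
    return math.exp(sum(losses) / len(losses))
```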
---
## Efficiency notes
| Feature | What it does |
|---------|--------------|
| **Ball pruning** | Masks keys that lie outside a learned radius → reduces the quadratic attention cost. |
| **Metric MQA** | Shares K/V across heads → fewer projection matrices, lower FLOPs. |
| **HyperFFN branch router** | Token‑wise top‑k routing means only the most useful branch is evaluated per token. |
| **CPU tips** | Set `OMP_NUM_THREADS` / `MKL_NUM_THREADS` to the number of physical cores; use `torch.set_num_threads()` if needed. |
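The CPU tips above can be applied like this (the thread count `8` is illustrative; use your number of physical cores):

```python
import os
# Set thread env vars before importing torch so MKL/OpenMP pick them up.
os.environ.setdefault("OMP_NUM_THREADS", "8")
os.environ.setdefault("MKL_NUM_THREADS", "8")

import torch
torch.set_num_threads(8)  # intra-op parallelism for CPU inference
```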
Future roadmap: metric‑aware KV‑cache, kernelised distance approximations (e.g., Random Fourier Features), quantisation & mixed‑precision inference.
---
## Safety, Bias & Risks
* The model **has not been fine‑tuned for safety or alignment**.
* Outputs may contain **biases, profanity, or factual errors**.
* Do **not** deploy in high‑stakes contexts without additional evaluation, moderation, and possibly further fine‑tuning.
---
## License
Apache‑2.0 – see the `LICENSE` file in the repository.
---
## Citation
```bibtex
@misc{moametriclm185m,
title = {reaperdoesntknow/MoA-100M: A Geometry-Aware Mixture-of-Attentions Language Model},
author = {Colca, Roy Shawn and collaborators},
year = {2025},
url = {https://huggingface.co/reaperdoesntknow/MoA-100M}
}
```
---
## Changelog
| Version | Date | Notes |
|---------|------|-------|
| **v0.2** | 2025‑09‑20 | 500 k‑token CPU run, GPT‑2 tokenizer, LR = 5e‑4, final loss ≈ 0.30. |
| **v0.1** | 2025‑09‑20 | Initial public release: metric heads, MQA, ball pruning, HyperFFN, router & gates; HF‑compatible; no KV cache. |
---
## Maintainers
* **Author:** reaper (Convergent Intelligence LLC)
* **Contact:** convergentintelligencenyc@gmail.com
---
## Special Remarks
- This model is still in an extremely experimental state, as most of them are, but I'm working on stabilizing this one for general inference.
- I design, create, and train all of my models using my mathematical research and a pure disgust for the dot product!
- For those of you who actually read this and use my models: you make my day every time I see another download, so thank you for being awesome!