---
language:
- en
license: apache-2.0
library_name: transformers
pipeline_tag: text-generation
tags:
- mixture-of-attentions
- distance-attention
- metric-attention
- mqa
- hyperffn
- router-gating
datasets:
- nvidia/Nemotron-Math-HumanReasoning
- WeMake/Intelligent-Content-Understanding
---

# MoAMetricLM‑100M — Mixture of Attentions (MoA)

**A geometry‑aware Transformer that mixes several attention mechanisms and routes them with a metric‑based router.**  
- **Parameters:** ~185 M (≈ 100 M effective due to the mixture)  
- **Task:** Causal language modeling (decoder‑only)  
- **Library:** 🤗 Transformers  
- **KV cache:** Not yet implemented (generation recomputes the full context at every step)  

---

## Model card

| **Model ID** | `reaperdoesntknow/MoA-100M` |
|--------------|-------------------------------------|
| **Architecture** | `moa_metric` (custom) |
| **Tokenizer** | GPT‑2 (`gpt2`) – `pad_token` set to `eos_token` |
| **Context length** | 2048 tokens |
| **Training data** | 2 × ≈ 256 k tokens from the datasets listed above |
| **Training compute** | CPU‑only (Intel), FP32 |
| **Training hyper‑parameters** | LR = 5e‑4 (AdamW), batch = 4, seq ≤ 512, 500 k total tokens |
| **Final loss** | ≈ 0.30 (train) |
| **License** | Apache‑2.0 |
| **Safety** | No alignment or safety fine‑tuning – outputs may be biased or inaccurate. |
| **Intended use** | Research on geometry‑aware attention, structured sparsity, and mixture‑of‑attention models. |
| **Limitations** | • No KV‑cache → slower generation. <br>• Small token budget → not a general‑purpose LM. <br>• No safety/alignment training. |
| **Out‑of‑scope** | High‑stakes applications (medical, legal, etc.) without further evaluation. |

---

## Overview

MoA replaces the classic dot‑product attention with **metric‑based attention** and blends **four** distinct heads per Transformer block:

| Head type | Description |
|-----------|-------------|
| **LocalConvHead** | Depthwise‑separable 1‑D convolution → captures short‑range context. |
| **Metric Multi‑Head Attention (MetricMHAttention)** | Soft‑min over **L2 / cosine / diagonal‑Mahalanobis** distances: <br> \(\displaystyle \text{attn}_{h}(i,j) \propto \exp\!\big(-\alpha_h\|q_i-k_j\|^2\big)\) |
| **Metric MQA** | Multi‑Query attention (shared K/V) in the same metric space – cheaper than full MHA. |
| **ChannelMixHead** | Per‑token MLP that mixes channel dimensions (no positional mixing). |

A **token‑wise router** decides, for each token, which head(s) to use and applies **feature‑gates** (FiLM‑style) and **router‑bias gates** for up/down‑scaling.
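As a minimal illustration of the score formula above, the L2 variant can be sketched in NumPy (single head, no batching; `metric_attention` is a hypothetical name, not the model's actual module):

```python
import numpy as np

def metric_attention(q, k, v, alpha=1.0):
    """Soft-min attention over squared L2 distances:
    attn(i, j) ∝ exp(-alpha * ||q_i - k_j||^2)."""
    # Pairwise squared distances between queries and keys: shape (T_q, T_k)
    d2 = ((q[:, None, :] - k[None, :, :]) ** 2).sum(axis=-1)
    scores = -alpha * d2
    # Numerically stable softmax over keys
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v
```

With a large `alpha`, each query attends almost exclusively to its nearest key, which is the soft-min behaviour the table describes; `alpha = 0` degenerates to uniform averaging.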

The **FFN** is a **HyperFFN** – three parallel branches (SwiGLU MLP, separable‑conv, low‑rank) combined by a **branch router**. LayerScale and optional DropPath keep training stable.
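The branch router can be pictured as a token-wise softmax mixture over the parallel branch outputs (a sketch with invented names; the real HyperFFN's gating details may differ):

```python
import numpy as np

def mix_branches(branch_outputs, router_logits, temperature=1.0):
    """Combine parallel FFN branches per token.
    branch_outputs: (T, n_branches, D); router_logits: (T, n_branches)."""
    z = router_logits / temperature
    w = np.exp(z - z.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)           # softmax over branches
    return (branch_outputs * w[..., None]).sum(axis=1)
```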

### Regularisation (optional)

* **Triangle‑inequality (TI) penalty** on sampled triples to encourage true‑metric behaviour.  
* **Ball pruning** – each head learns an **origin** \(o_h\) and **radius** \(r_h\); keys outside the ball are masked, giving structured sparsity.
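The pruning predicate amounts to a boolean mask over keys (hypothetical helper; in the model, \(o_h\) and \(r_h\) are learned per head):

```python
import numpy as np

def ball_mask(keys, origin, radius):
    """True for keys inside the learned ball: ||k_j - origin|| <= radius."""
    return np.linalg.norm(keys - origin, axis=-1) <= radius

# Masked keys are excluded before the soft-min, e.g.:
# scores = np.where(ball_mask(k, o, r)[None, :], scores, -np.inf)
```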

---

## Architecture diagram (high‑level)

```
Input → Embedding → (PreNorm) → Block₁ → … → Blockₙ → LM‑Head → Output

                     ├─ LocalConvHead
                     ├─ MetricMHAttention
                     ├─ MetricMQA
                     └─ ChannelMixHead
                     (router decides per‑token)

Each Block also contains:
  → HyperFFN (SwiGLU | Conv | Low‑rank)  ← branch router
  → LayerScale + DropPath
```

---

## Configuration (example)

```json
{
  "model_type": "moa_metric",
  "vocab_size": 50257,
  "dim": 768,
  "num_layers": 12,
  "attn_heads": 8,
  "mqa_q_heads": 8,
  "mixer_hidden": 3072,
  "ffn_hidden": 3072,
  "metric": "l2",                     // "l2" | "cosine" | "maha_diag"
  "alpha_init": 1.0,
  "learn_alpha": true,
  "use_balls": true,
  "radius_init": 3.0,
  "learn_radius": true,
  "origin_init_scale": 0.0,
  "maha_init": 1.0,
  "ti_reg_weight": 0.0,
  "ti_reg_samples": 0,
  "router_hidden": 128,
  "router_dropout": 0.1,
  "router_temperature": 1.0,
  "attn_drop": 0.1,
  "proj_drop": 0.1,
  "drop_path": 0.0,
  "max_position_embeddings": 2048,
  "pad_token_id": 50256,
  "bos_token_id": 50256,
  "eos_token_id": 50256
}
```

> **Tip:** If you use the GPT‑2 tokenizer, set `pad_token = eos_token` and make sure `vocab_size` matches the tokenizer (50257).

---

## Quick‑start (inference)

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> model_id = "reaperdoesntknow/MoA-100M"
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> tokenizer.pad_token = tokenizer.eos_token   # needed for the GPT‑2 tokenizer

>>> model = AutoModelForCausalLM.from_pretrained(model_id)

>>> prompt = "Explain metric‑based attention in simple terms:"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> output_ids = model.generate(
...     **inputs,
...     max_new_tokens=128,
...     do_sample=False,          # deterministic; set temperature>0 for sampling
...     pad_token_id=tokenizer.pad_token_id,
... )
>>> print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

*Note:* Because the KV cache is not implemented, every generation step re-encodes the full context, so per-token latency grows with the context length.

---

## Training (custom loop sketch)

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader
import torch

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

def collate_fn(examples):
    batch = tokenizer(
        [ex["text"] for ex in examples],
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    labels = batch["input_ids"].clone()
    labels[batch["attention_mask"] == 0] = -100
    batch["labels"] = labels
    return batch

# dataset = load_dataset(..., split="train")  # must contain a 'text' field
# loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

model = AutoModelForCausalLM.from_pretrained("reaperdoesntknow/MoA-100M")
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-4,
    betas=(0.9, 0.95),
    weight_decay=0.01,
)

model.train()
for batch in loader:  # `loader` comes from the commented-out lines above
    out = model(**batch)  # loss is computed because `labels` is in the batch
    out.loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.2)
    optimizer.step()
    optimizer.zero_grad()
```

---

## Evaluation checklist

* **Perplexity** on a held‑out split of the two training datasets.  
* **Ablation studies** (keep total token budget constant):
  * L2 vs. cosine vs. diagonal‑Mahalanobis distance.
  * With / without ball pruning.
  * With / without HyperFFN branch router.
  * With / without TI regulariser.
* **Speed / memory** comparison against a vanilla GPT‑2‑size model (same `dim`/`layers`).  
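For the perplexity item, a minimal helper (assuming you accumulate the summed token-level negative log-likelihood over the held-out split):

```python
import math

def perplexity(total_nll, num_tokens):
    """Perplexity = exp of the mean per-token negative log-likelihood."""
    return math.exp(total_nll / num_tokens)
```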

---

## Efficiency notes

| Feature | What it does |
|---------|--------------|
| **Ball pruning** | Masks keys that lie outside a learned radius → reduces the quadratic attention cost. |
| **Metric MQA** | Shares K/V across heads → fewer projection matrices, lower FLOPs. |
| **HyperFFN branch router** | Token‑wise top‑k routing means only the most useful branch is evaluated per token. |
| **CPU tips** | Set `OMP_NUM_THREADS` / `MKL_NUM_THREADS` to the number of physical cores; use `torch.set_num_threads()` if needed. |
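The CPU settings in the last row can be applied like this (a sketch; 8 is only an example value, use your machine's physical core count):

```python
import os

# Set thread-pool sizes BEFORE importing numpy/torch so the BLAS/OpenMP
# runtimes pick them up.
os.environ["OMP_NUM_THREADS"] = "8"
os.environ["MKL_NUM_THREADS"] = "8"
```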

Future roadmap: metric‑aware KV‑cache, kernelised distance approximations (e.g., Random Fourier Features), quantisation & mixed‑precision inference.

---

## Safety, Bias & Risks

* The model **has not been fine‑tuned for safety or alignment**.  
* Outputs may contain **biases, profanity, or factual errors**.  
* Do **not** deploy in high‑stakes contexts without additional evaluation, moderation, and possibly further fine‑tuning.

---

## License

Apache‑2.0 – see the `LICENSE` file in the repository.

---

## Citation

```bibtex
@misc{moametriclm185m,
  title   = {reaperdoesntknow/MoA-100M: A Geometry-Aware Mixture-of-Attentions Language Model},
  author  = {Colca, Roy Shawn and collaborators},
  year    = {2025},
  url     = {https://huggingface.co/reaperdoesntknow/MoA-100M}
}
```

---

## Changelog

| Version | Date | Notes |
|---------|------|-------|
| **v0.2** | 2025‑09‑20 | 500 k‑token CPU run, GPT‑2 tokenizer, LR = 5e‑4, final loss ≈ 0.30. |
| **v0.1** | 2025‑09‑20 | Initial public release: metric heads, MQA, ball pruning, HyperFFN, router & gates; HF‑compatible; no KV cache. |

---

## Maintainers

* **Author:** reaper (Convergent Intelligence LLC)  
* **Contact:** convergentintelligencenyc@gmail.com


---

## Special Remarks

- This model is still in an extremely experimental state, as most of mine are, but I'm working on stabilizing this one for general inference.
- I design, create, and train all of my models using my own mathematical research and a pure disgust for the dot product!
- For those of you who actually read this and use my models: you make my day every time I see another download, so thank you for being awesome!