|
|
--- |
|
|
language: |
|
|
- en |
|
|
license: apache-2.0 |
|
|
library_name: transformers |
|
|
pipeline_tag: text-generation |
|
|
tags: |
|
|
- mixture-of-attentions |
|
|
- distance-attention |
|
|
- metric-attention |
|
|
- mqa |
|
|
- hyperffn |
|
|
- router-gating |
|
|
datasets: |
|
|
- nvidia/Nemotron-Math-HumanReasoning |
|
|
- WeMake/Intelligent-Content-Understanding |
|
|
--- |
|
|
|
|
|
# MoAMetricLM‑100M — Mixture of Attentions (MoA) |
|
|
|
|
|
**A geometry‑aware Transformer that mixes several attention mechanisms and routes them with a metric‑based router.** |
|
|
- **Parameters:** ~185 M (≈ 100 M effective due to the mixture) |
|
|
- **Task:** Causal language modeling (decoder‑only) |
|
|
- **Library:** 🤗 Transformers |
|
|
- **KV cache:** Not yet implemented (generation recomputes the full context at every step) |
|
|
|
|
|
--- |
|
|
|
|
|
## Model card |
|
|
|
|
|
| **Model ID** | `reaperdoesntknow/MoA-100M` |
|--------------|-------------------------------------|
| **Architecture** | `moa_metric` (custom) |
| **Tokenizer** | GPT‑2 (`gpt2`) – `pad_token` set to `eos_token` |
| **Context length** | 2048 tokens |
| **Training data** | 2 × ≈ 256 k tokens from the datasets listed above |
| **Training compute** | CPU‑only (Intel), FP32 |
| **Training hyper‑parameters** | LR = 5e‑4 (AdamW), batch = 4, seq ≤ 512, 500 k total tokens |
| **Final loss** | ≈ 0.30 (train) |
| **License** | Apache‑2.0 |
| **Safety** | No alignment or safety fine‑tuning – outputs may be biased or inaccurate. |
| **Intended use** | Research on geometry‑aware attention, structured sparsity, and mixture‑of‑attention models. |
| **Limitations** | • No KV‑cache → slower generation. <br>• Small token budget → not a general‑purpose LM. <br>• No safety/alignment training. |
| **Out‑of‑scope** | High‑stakes applications (medical, legal, etc.) without further evaluation. |
|
|
|
|
--- |
|
|
|
|
|
## Overview |
|
|
|
|
|
MoA replaces the classic dot‑product attention with **metric‑based attention** and blends **four** distinct heads per Transformer block: |
|
|
|
|
|
| Head type | Description |
|-----------|-------------|
| **LocalConvHead** | Depthwise‑separable 1‑D convolution → captures short‑range context. |
| **Metric Multi‑Head Attention (MetricMHAttention)** | Soft‑min over **L2 / cosine / diagonal‑Mahalanobis** distances: <br> \(\displaystyle \text{attn}_{h}(i,j) \propto \exp\!\big(-\alpha_h\|q_i-k_j\|^2\big)\) |
| **Metric MQA** | Multi‑Query attention (shared K/V) in the same metric space – cheaper than full MHA. |
| **ChannelMixHead** | Per‑token MLP that mixes channel dimensions (no positional mixing). |
|
|
|
|
A **token‑wise router** decides, for each token, which head(s) to use and applies **feature‑gates** (FiLM‑style) and **router‑bias gates** for up/down‑scaling. |
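
For intuition, here is a minimal PyTorch sketch of the L2 variant of the scoring rule in the table above; the function and tensor names are illustrative, not the repository's actual implementation:

```python
import torch
import torch.nn.functional as F

def l2_metric_attention(q, k, v, alpha, causal_mask=None):
    """Soft-min attention: attn(i, j) ∝ exp(-alpha * ||q_i - k_j||^2)."""
    dist2 = torch.cdist(q, k, p=2).pow(2)        # (B, T_q, T_k) pairwise squared L2 distances
    scores = -alpha * dist2                      # closer keys get larger scores
    if causal_mask is not None:
        scores = scores.masked_fill(causal_mask, float("-inf"))
    weights = F.softmax(scores, dim=-1)          # softmax of negative distances == soft-min
    return weights @ v                           # (B, T_q, D)

# Toy check: 2 sequences, 5 tokens, 16-dim per head.
q = k = v = torch.randn(2, 5, 16)
causal = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)  # hide future positions
out = l2_metric_attention(q, k, v, alpha=1.0, causal_mask=causal)
print(out.shape)  # torch.Size([2, 5, 16])
```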
|
|
|
|
|
The **FFN** is a **HyperFFN** – three parallel branches (SwiGLU MLP, separable‑conv, low‑rank) combined by a **branch router**. LayerScale and optional DropPath keep training stable. |
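
To make the branch-router idea concrete, below is a minimal soft-routing sketch with three toy branches (SwiGLU, depthwise conv, low‑rank). The class and layer names are hypothetical; the real HyperFFN may differ (e.g., top‑k routing, LayerScale, DropPath):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyBranchRoutedFFN(nn.Module):
    """Token-wise soft mixture over three parallel FFN branches (illustrative only)."""
    def __init__(self, dim, hidden, rank=32):
        super().__init__()
        self.swiglu_in = nn.Linear(dim, 2 * hidden)
        self.swiglu_out = nn.Linear(hidden, dim)
        self.conv = nn.Conv1d(dim, dim, kernel_size=3, padding=2, groups=dim)  # depthwise conv
        self.low_rank = nn.Sequential(nn.Linear(dim, rank), nn.Linear(rank, dim))
        self.router = nn.Linear(dim, 3)  # one logit per branch, per token

    def forward(self, x):                                   # x: (B, T, D)
        a, b = self.swiglu_in(x).chunk(2, dim=-1)
        swiglu = self.swiglu_out(F.silu(a) * b)
        conv = self.conv(x.transpose(1, 2))[..., : x.size(1)].transpose(1, 2)  # trim to length T
        low = self.low_rank(x)
        weights = self.router(x).softmax(dim=-1)             # (B, T, 3) branch weights
        branches = torch.stack([swiglu, conv, low], dim=-1)  # (B, T, D, 3)
        return (branches * weights.unsqueeze(2)).sum(dim=-1)

# y = TinyBranchRoutedFFN(64, 256)(torch.randn(2, 8, 64))  # -> (2, 8, 64)
```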
|
|
|
|
|
### Regularisation (optional) |
|
|
|
|
|
* **Triangle‑inequality (TI) penalty** on sampled triples to encourage true‑metric behaviour. |
|
|
* **Ball pruning** – each head learns an **origin** \(o_h\) and **radius** \(r_h\); keys outside the ball are masked, giving structured sparsity. |
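
As a rough sketch of the ball-pruning step for a single head (shapes and names are hypothetical; the actual masking is done inside the metric heads):

```python
import torch

def ball_prune_mask(k, origin, radius):
    """Mask keys whose distance from the head's learned origin exceeds its learned radius.

    k:      (B, T, D) keys for one head
    origin: (D,)      learned origin o_h
    radius: scalar    learned radius r_h
    Returns a boolean mask (B, T) that is True for keys to drop.
    """
    dist = (k - origin).norm(dim=-1)   # (B, T) distance of each key from the origin
    return dist > radius               # keys outside the ball are excluded from attention

# scores = scores.masked_fill(mask.unsqueeze(1), float("-inf"))  # applied before the softmax
```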
|
|
|
|
|
--- |
|
|
|
|
|
## Architecture diagram (high‑level) |
|
|
|
|
|
```
Input → Embedding → (PreNorm) → Block₁ → … → Blockₙ → LM‑Head → Output
                                  │
                                  ├─ LocalConvHead
                                  ├─ MetricMHAttention
                                  ├─ MetricMQA
                                  └─ ChannelMixHead
                                     (router decides per‑token)

Each Block also contains:
  → HyperFFN (SwiGLU | Conv | Low‑rank) ← branch router
  → LayerScale + DropPath
```
|
|
|
|
|
--- |
|
|
|
|
|
## Configuration (example) |
|
|
|
|
|
```json
{
  "model_type": "moa_metric",
  "vocab_size": 50257,
  "dim": 768,
  "num_layers": 12,
  "attn_heads": 8,
  "mqa_q_heads": 8,
  "mixer_hidden": 3072,
  "ffn_hidden": 3072,
  "metric": "l2",
  "alpha_init": 1.0,
  "learn_alpha": true,
  "use_balls": true,
  "radius_init": 3.0,
  "learn_radius": true,
  "origin_init_scale": 0.0,
  "maha_init": 1.0,
  "ti_reg_weight": 0.0,
  "ti_reg_samples": 0,
  "router_hidden": 128,
  "router_dropout": 0.1,
  "router_temperature": 1.0,
  "attn_drop": 0.1,
  "proj_drop": 0.1,
  "drop_path": 0.0,
  "max_position_embeddings": 2048,
  "pad_token_id": 50256,
  "bos_token_id": 50256,
  "eos_token_id": 50256
}
```
|
|
|
|
|
> **Tip:** If you use the GPT‑2 tokenizer, set `pad_token = eos_token` and make sure `vocab_size` matches the tokenizer (50257). The `metric` field accepts `"l2"`, `"cosine"`, or `"maha_diag"`.
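
If you want to experiment with these fields, the config can be loaded and overridden before instantiating a fresh model. A minimal sketch, assuming the repository ships its custom code so that `trust_remote_code=True` resolves the `moa_metric` architecture:

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Load the published config, then override fields from the example above.
config = AutoConfig.from_pretrained("reaperdoesntknow/MoA-100M", trust_remote_code=True)
config.metric = "cosine"          # "l2" | "cosine" | "maha_diag"
config.router_temperature = 0.7   # sharpen the token-wise router

# Fresh (randomly initialised) model with the modified geometry.
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
```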
|
|
|
|
|
--- |
|
|
|
|
|
## Quick‑start (inference) |
|
|
|
|
|
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> model_id = "reaperdoesntknow/MoA-100M"
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> tokenizer.pad_token = tokenizer.eos_token  # needed for the GPT‑2 tokenizer

>>> model = AutoModelForCausalLM.from_pretrained(model_id)  # add trust_remote_code=True if the custom `moa_metric` code ships on the Hub

>>> prompt = "Explain metric‑based attention in simple terms:"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> output_ids = model.generate(
...     **inputs,
...     max_new_tokens=128,
...     do_sample=False,  # greedy decoding; pass do_sample=True (with temperature/top_p) to sample
...     pad_token_id=tokenizer.pad_token_id,
... )
>>> print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```
|
|
|
|
|
*Note:* Because the KV‑cache is not implemented, every generation step re‑runs the forward pass over the entire context, so per‑token latency grows with the context length instead of staying roughly constant.
|
|
|
|
|
--- |
|
|
|
|
|
## Training (custom loop sketch) |
|
|
|
|
|
```python
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
# from datasets import load_dataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

def collate_fn(examples):
    batch = tokenizer(
        [ex["text"] for ex in examples],
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    labels = batch["input_ids"].clone()
    labels[batch["attention_mask"] == 0] = -100  # ignore padding positions in the loss
    batch["labels"] = labels
    return batch

# dataset = load_dataset(..., split="train")  # must contain a 'text' field
# loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

model = AutoModelForCausalLM.from_pretrained("reaperdoesntknow/MoA-100M")
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-4,
    betas=(0.9, 0.95),
    weight_decay=0.01,
)

model.train()
for batch in loader:
    out = model(**batch)
    out.loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.2)
    optimizer.step()
    optimizer.zero_grad()
```
|
|
|
|
|
--- |
|
|
|
|
|
## Evaluation checklist |
|
|
|
|
|
* **Perplexity** on a held‑out split of the two training datasets (see the sketch after this list).
|
|
* **Ablation studies** (keep total token budget constant): |
|
|
* L2 vs. cosine vs. diagonal‑Mahalanobis distance. |
|
|
* With / without ball pruning. |
|
|
* With / without HyperFFN branch router. |
|
|
* With / without TI regulariser. |
|
|
* **Speed / memory** comparison against a vanilla GPT‑2‑size model (same `dim`/`layers`). |
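
A minimal perplexity sketch for the first checklist item; the `held_out_texts` list is a placeholder, and the code assumes the model returns a standard causal‑LM loss when given `labels`:

```python
import math
import torch

@torch.no_grad()
def perplexity(model, tokenizer, texts, max_length=512):
    """Token-weighted perplexity: exp of the mean cross-entropy over all target tokens."""
    model.eval()
    total_nll, total_tokens = 0.0, 0
    for text in texts:
        enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
        out = model(**enc, labels=enc["input_ids"])
        n = enc["input_ids"].numel() - 1        # loss is averaged over the shifted targets
        total_nll += out.loss.item() * n
        total_tokens += n
    return math.exp(total_nll / total_tokens)

# held_out_texts = [...]  # e.g., a held-out split of the training datasets
# print(perplexity(model, tokenizer, held_out_texts))
```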
|
|
|
|
|
--- |
|
|
|
|
|
## Efficiency notes |
|
|
|
|
|
| Feature | What it does |
|---------|--------------|
| **Ball pruning** | Masks keys that lie outside a learned radius → reduces the quadratic attention cost. |
| **Metric MQA** | Shares K/V across heads → fewer projection matrices, lower FLOPs. |
| **HyperFFN branch router** | Token‑wise top‑k routing means only the most useful branch is evaluated per token. |
| **CPU tips** | Set `OMP_NUM_THREADS` / `MKL_NUM_THREADS` to the number of physical cores; use `torch.set_num_threads()` if needed. |
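
Following the CPU tips above, a small snippet for pinning thread counts (the core-count heuristic is illustrative; ideally export the environment variables in the shell before launching Python):

```python
import os
import torch

# OMP_NUM_THREADS / MKL_NUM_THREADS should ideally be exported before Python starts;
# torch.set_num_threads() can still be adjusted at runtime.
physical_cores = max(1, (os.cpu_count() or 2) // 2)  # rough guess: half the logical cores
torch.set_num_threads(physical_cores)                # intra-op parallelism for CPU inference
print(torch.get_num_threads())
```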
|
|
|
|
|
Future roadmap: metric‑aware KV‑cache, kernelised distance approximations (e.g., Random Fourier Features), quantisation & mixed‑precision inference. |
|
|
|
|
|
--- |
|
|
|
|
|
## Safety, Bias & Risks |
|
|
|
|
|
* The model **has not been fine‑tuned for safety or alignment**. |
|
|
* Outputs may contain **biases, profanity, or factual errors**. |
|
|
* Do **not** deploy in high‑stakes contexts without additional evaluation, moderation, and possibly further fine‑tuning. |
|
|
|
|
|
--- |
|
|
|
|
|
## License |
|
|
|
|
|
Apache‑2.0 – see the `LICENSE` file in the repository. |
|
|
|
|
|
--- |
|
|
|
|
|
## Citation |
|
|
|
|
|
```bibtex
@misc{moametriclm185m,
  title  = {reaperdoesntknow/MoA-100M: A Geometry-Aware Mixture-of-Attentions Language Model},
  author = {Colca, Roy Shawn and collaborators},
  year   = {2025},
  url    = {https://huggingface.co/reaperdoesntknow/MoA-100M}
}
```
|
|
|
|
|
--- |
|
|
|
|
|
## Changelog |
|
|
|
|
|
| Version | Date | Notes |
|---------|------|-------|
| **v0.2** | 2025‑09‑20 | 500 k‑token CPU run, GPT‑2 tokenizer, LR = 5e‑4, final loss ≈ 0.30. |
| **v0.1** | 2025‑09‑20 | Initial public release: metric heads, MQA, ball pruning, HyperFFN, router & gates; HF‑compatible; no KV cache. |
|
|
|
|
|
--- |
|
|
|
|
|
## Maintainers |
|
|
|
|
|
* **Author:** reaper (Convergent Intelligence LLC) |
|
|
* **Contact:** convergentintelligencenyc@gmail.com
|
|
|
|
|
|
|
|
--- |
|
|
|
|
|
## Special Remarks |
|
|
|
|
|
- This model is still in an extremely experimental state, as most of my models are, but I'm working on stabilizing this one for general inference.

- I design, create, and train all of my models using my own mathematical research and a pure disgust for the dot product!

- For those of you who actually read this and use my models: you make my day every time I see another download, so thank you for being awesome!
|
|
|
|
|
|