---
library_name: transformers
license: apache-2.0
language:
- en
tags:
- monoid
- causal-lm
- linear-attention
- state-space
- O(1)-inference
- reasoning
pipeline_tag: text-generation
model-index:
- name: Spartacus-1B-Instruct
results: []
---
# Spartacus-1B-Instruct: Causal Monoid Language Model
A 1.3B-parameter language model that replaces softmax attention with **causal monoid state compression**, achieving **O(1) time per token** and **O(1) memory** at inference, regardless of sequence length.
Fine-tuned for enhanced reasoning with structured chain-of-thought data.
## Monoid Attention: Internal Structure
```
MonoidAttention (per layer, per head)

  x_t ∈ R^{2048}
   │
   ├──> q_proj ──> RMSNorm ──> q_t ∈ R^{d}              (query)
   │
   ├──> k_proj ──> RMSNorm ──> SiLU ──> k_t ∈ R^{d}     (key, >= 0)
   │
   ├──> v_proj ──> v_t ∈ R^{d}                          (value)
   │
   └──> decay_proj ──> sigmoid ──> alpha_t ∈ (0,1)      (decay gate)

  k_t (x) v_t
   │        ┌────────────────────────────────┐
   │        │  State Matrix S_t ∈ R^{d x d}  │
   v        │                                │
  S_t = alpha_t * S_{t-1} + k_t (x) v_t      │
   │        │  "Compressed causal history"   │
   │        └────────────────────────────────┘
   v
  o_t = q_t . S_t ──> o_proj ──> output
```
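The dataflow above can be sketched for a single head in a few lines of NumPy. This is an illustrative reimplementation under stated assumptions (random weights, scalar decay gate, function and weight names invented here, not the checkpoint's), not the model's actual code:

```python
import numpy as np

def rmsnorm(v, eps=1e-6):
    # RMSNorm without a learned scale (sketch only)
    return v / np.sqrt(np.mean(v**2) + eps)

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def silu(v):
    return v * sigmoid(v)

def monoid_step(x, S_prev, Wq, Wk, Wv, wd):
    """One O(1) decode step for a single head of dimension d.

    x: (d_model,) token hidden state; S_prev: (d, d) carried state.
    Weight names are illustrative, not the checkpoint's parameter names.
    """
    q = rmsnorm(Wq @ x)                      # query
    k = silu(rmsnorm(Wk @ x))                # SiLU-gated key
    v = Wv @ x                               # value
    alpha = sigmoid(wd @ x)                  # decay gate in (0, 1)
    S = alpha * S_prev + np.outer(k, v)      # state update, O(d^2) regardless of T
    o = q @ S                                # readout: o_t = q_t . S_t
    return o, S

rng = np.random.default_rng(0)
d_model, d = 16, 8
Wq, Wk, Wv = (rng.normal(size=(d, d_model)) / np.sqrt(d_model) for _ in range(3))
wd = rng.normal(size=d_model) / np.sqrt(d_model)

S = np.zeros((d, d))
for _ in range(100):                         # 100 tokens; memory stays (d, d)
    o, S = monoid_step(rng.normal(size=d_model), S, Wq, Wk, Wv, wd)
print(o.shape, S.shape)                      # (8,) (8, 8)
```

Note that the loop's footprint is the single `(d, d)` state, no matter how many tokens are consumed.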
## Monoid State Diagonal: O(1) Compression Contour
The state matrix `S_t` accumulates causal history along its diagonal. Each head maintains an independent `d x d` state that compresses ALL past tokens into a fixed footprint:
```
State Matrix S_t ∈ R^{64 x 64}   (one per head, 32 heads per layer)

              k-dim -->
     0   8   16  24  32  40  48  56  63
    ┌───┬───┬───┬───┬───┬───┬───┬───┐ 0
    │***│** │*  │   │   │   │   │   │    v-dim
    │***│** │*  │.  │   │   │   │   │      |
    ├───┼───┼───┼───┼───┼───┼───┼───┤ 8   |
    │** │***│** │*  │.  │   │   │   │      v
    │*  │***│** │*  │.  │   │   │   │
    ├───┼───┼───┼───┼───┼───┼───┼───┤ 16
    │*  │** │***│** │*  │.  │   │   │
    │.  │*  │***│** │*  │.  │   │   │
    ├───┼───┼───┼───┼───┼───┼───┼───┤ 24
    │   │.  │** │***│** │*  │.  │   │
    │   │   │*  │***│** │*  │.  │   │
    ├───┼───┼───┼───┼───┼───┼───┼───┤ 32
    │   │   │.  │** │***│** │*  │.  │
    │   │   │   │*  │***│** │*  │.  │
    ├───┼───┼───┼───┼───┼───┼───┼───┤ 40
    │   │   │   │.  │** │***│** │*  │
    │   │   │   │   │*  │***│** │*  │
    ├───┼───┼───┼───┼───┼───┼───┼───┤ 48
    │   │   │   │   │.  │** │***│** │
    │   │   │   │   │   │*  │***│** │
    ├───┼───┼───┼───┼───┼───┼───┼───┤ 56
    │   │   │   │   │   │.  │** │***│
    │   │   │   │   │   │   │*  │***│
    └───┴───┴───┴───┴───┴───┴───┴───┘ 63

Legend:  *** = high activation (recent tokens, alpha^0 ~ alpha^2)
         **  = medium          (alpha^3 ~ alpha^5)
         *   = fading          (alpha^6 ~ alpha^10)
         .   = near-zero       (alpha^11+, effectively forgotten)
             = zero            (never reached or fully decayed)

The diagonal band emerges because S_t = SUM_{i<=t} alpha^{t-i} * k_i (x) v_i.
Recent outer products dominate near the diagonal; older ones decay
exponentially via alpha, creating this characteristic contour.
```
## Key Properties
| Property | Transformer (Llama) | Spartacus (Monoid) |
|---|---|---|
| Inference time per token | O(T) -- scans full KV-cache | **O(1)** -- single state update |
| Inference memory per layer | O(T) -- stores all past K,V | **O(1)** -- fixed d x d state matrix |
| Sequence length extrapolation | Degrades beyond training length | **Unlimited** -- state size is constant |
| Causality | Imposed via attention mask | **Built into the recurrence** |
| Training complexity | O(T^2) | **O(T)** via parallel prefix scan |
## The Monoid Recurrence
Standard attention computes:
```
o_t = sum_{i<=t} softmax(q_t . k_i) v_i -- requires O(T) KV-cache
```
Monoid attention compresses the entire causal history into a **fixed-size state matrix** S_t per head:
```
S_t = alpha_t * S_{t-1} + k_t (x) v_t -- explicit causal recurrence
o_t = q_t . S_t -- state readout
```
where `alpha_t = sigmoid(decay_proj(x_t))` is a learned, content-dependent decay gate that controls how fast past information fades.
## Explicit Causal Modeling
Unlike Transformers where causality is a constraint imposed by masking, Spartacus makes causality a **first-class citizen**:
- The decay gate `alpha_t` explicitly controls per-head information retention at every timestep
- The model learns **when to forget** rather than encoding **where tokens are** (no positional encoding needed)
- No attention mask required -- causality is structural, not enforced
## Design Choices
- **SiLU-activated keys**: `k = SiLU(k_proj(x))` keeps keys effectively non-negative (SiLU is bounded below by about -0.28), biasing the state matrix `S` toward positive semi-definiteness. This prevents "feature erasure", where one token's contribution cancels another's
- **Log-space decay**: Working in log-space `log(alpha)` avoids numerical underflow when `alpha^T -> 0` for long sequences
- **Learnable h0**: The initial state `S_0 = h0` is a learnable parameter (zero-initialized), acting as a compressed "system prompt"
## Model Details
| Parameter | Value |
|---|---|
| Model | `NoesisLab/Spartacus-1B-Instruct` |
| Architecture | MonoidForCausalLM |
| Parameters | ~1.34B (tied embeddings) |
| Hidden size | 2048 |
| Intermediate size (MLP) | 8192 |
| Layers | 16 |
| Attention heads | 32 |
| Head dimension | 64 |
| State matrix per head | 64 x 64 = 4096 floats |
| Vocabulary | 128,256 (Llama-3.2 tokenizer) |
| Precision | bfloat16 |
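From the table above, the model's entire recurrent state is small and, unlike a KV cache, independent of context length. A back-of-the-envelope comparison against a Transformer of the same shape (16 layers, hidden size 2048, bf16), using assumed round numbers:

```python
layers, heads, head_dim, bytes_bf16 = 16, 32, 64, 2
hidden = heads * head_dim                          # 2048

# Monoid: one d x d state per head per layer, constant in sequence length
state_bytes = layers * heads * head_dim * head_dim * bytes_bf16
print(f"monoid state: {state_bytes / 2**20:.0f} MiB, for any T")

# Transformer: K and V vectors cached for every past token
kv_per_token = layers * 2 * hidden * bytes_bf16    # 128 KiB per token
for T in (2_048, 32_768, 131_072):
    print(f"KV cache at T={T:>7}: {T * kv_per_token / 2**20:>6,.0f} MiB")
```

The fixed 4 MiB state replaces a cache that would grow to gigabytes at long contexts.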
## Benchmarks (0-shot)
| Task | Metric | Value | Stderr |
|---|---|---|---|
| ARC-Challenge | acc_norm | 0.3063 | ±0.0135 |
| ARC-Easy | acc | 0.5518 | ±0.0102 |
| HellaSwag | acc_norm | 0.4610 | ±0.0050 |
| PIQA | acc_norm | 0.6915 | ±0.0108 |
| WinoGrande | acc | 0.5225 | ±0.0140 |
### Comparison with ~1B Baselines (acc_norm, 0-shot)
| Task | Spartacus-1B-Instruct | TinyLlama-1.1B | Llama 3.2-1B | Mamba-1.4B | RWKV-6-1.6B |
|---|---|---|---|---|---|
| ARC-C | **0.3063** | 0.3268 | ~0.359 | 0.284 | ~0.301 |
| ARC-E | **0.5518** | 0.5547 | ~0.752 | 0.512 | ~0.530 |
| HellaSwag | **0.4610** | 0.4670 | ~0.546 | 0.435 | ~0.450 |
| PIQA | **0.6915** | 0.7210 | ~0.740 | 0.655 | ~0.670 |
| WinoGrande | **0.5225** | 0.5040 | ~0.592 | 0.510 | ~0.515 |
> Spartacus performs competitively with other sub-quadratic ~1B models (Mamba, RWKV) while maintaining **O(1) inference time and memory per token**. Scores marked with ~ are approximate community-reported values.
## Training
### Stage 1: General SFT
- **Base weights**: Transferred from Llama-3.2-1B-Instruct (embeddings, MLP, norms)
- **Data**: Capybara + smol-smoltalk (general conversation)
- **Training**: Full-parameter SFT
### Stage 2: Reasoning Enhancement
- **Data mix**: 60% Qwen3-Short-Reasoning + 20% Capybara + 20% smol-smoltalk
- **Steps**: 2,000
- **Learning rate**: 2e-5 (cosine schedule, 50 warmup steps)
- **Batch size**: 8
- **Sequence length**: 2,048
- **Precision**: bfloat16
- **Optimizer**: AdamW (weight decay 0.01, max grad norm 1.0)
The reasoning data uses structured "Thought + Solution" format to strengthen chain-of-thought capabilities while the general data prevents catastrophic forgetting.
## Parallel Scan Implementation
The `monoid_scan_cuda.py` module provides a Triton JIT-compiled parallel prefix scan:
- **Forward**: Sequential scan along T, parallelized across B x H x D on GPU via Triton kernels
- **Backward**: Reverse-order adjoint scan computes gradients for both values and log-decay gates
- **Fallback**: Pure PyTorch sequential scan for CPU/MPS
- **Auto-dispatch**: CUDA -> Triton kernel, otherwise -> PyTorch fallback
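A parallel prefix scan applies here because the update is a monoid action: pairs `(alpha, B)` with `B = k (x) v` compose associatively, so partial products can be grouped arbitrarily across GPU blocks. A NumPy check of that algebra (this is the math behind the kernel, not the Triton code itself):

```python
import numpy as np
from functools import reduce

def combine(first, then):
    """Compose two updates S -> a*S + B, applying `first`, then `then`."""
    a1, B1 = first
    a2, B2 = then
    return a1 * a2, a2 * B1 + B2

rng = np.random.default_rng(2)
d, T = 3, 8
elems = [(rng.uniform(0.5, 1.0), rng.normal(size=(d, d))) for _ in range(T)]

# Associativity: (x . y) . z == x . (y . z) -- what a parallel scan requires
x, y, z = elems[:3]
l = combine(combine(x, y), z)
r = combine(x, combine(y, z))
assert np.isclose(l[0], r[0]) and np.allclose(l[1], r[1])

# Folding all elements in order reproduces the sequential recurrence
a_total, S_fold = reduce(combine, elems)
S_seq = np.zeros((d, d))
for a, B in elems:
    S_seq = a * S_seq + B
assert np.allclose(S_fold, S_seq)
```

Because `combine` is associative, the T-step fold can be evaluated as a tree in O(log T) depth during training while the O(1) sequential form is used at inference.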
## Usage
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained(
"NoesisLab/Spartacus-1B-Instruct",
trust_remote_code=True,
torch_dtype="bfloat16",
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("NoesisLab/Spartacus-1B-Instruct")
messages = [{"role": "user", "content": "Hello!"}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=512)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
## File Structure
```
MonoidForCausalLM.py # Model architecture (MonoidConfig, MonoidAttention, MonoidForCausalLM)
monoid_scan_cuda.py # Triton JIT parallel prefix scan + PyTorch fallback
model.safetensors # Model weights (bfloat16)
config.json # Model configuration
tokenizer.json # Llama-3.2 tokenizer
```
## Citation
```bibtex
@software{spartacus2025,
title={Spartacus: Causal Monoid Language Model with O(1) Inference},
author={NoesisLab},
year={2025},
url={https://huggingface.co/NoesisLab/Spartacus-1B-Instruct},
description={Replaces softmax attention with monoid state compression for constant-time, constant-memory autoregressive generation}
}
```
## License
Apache 2.0