Transformers
PyTorch
English
language-model
graph-attention
adaptive-depth
temporal-decay
efficient-llm
Eval Results (legacy)
Instructions to use vigneshwar234/TemporalMesh-Transformer with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use vigneshwar234/TemporalMesh-Transformer with Transformers:
# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("vigneshwar234/TemporalMesh-Transformer", dtype="auto") - Notebooks
- Google Colab
- Kaggle
Update model card: comprehensive viral-optimized README with Zenodo paper, benchmarks, full docs
Browse files
README.md
CHANGED
|
@@ -2,242 +2,246 @@
|
|
| 2 |
language:
|
| 3 |
- en
|
| 4 |
license: mit
|
|
|
|
| 5 |
tags:
|
|
|
|
| 6 |
- pytorch
|
| 7 |
- transformers
|
| 8 |
-
- text-generation
|
| 9 |
-
- language-model
|
| 10 |
- graph-neural-network
|
|
|
|
|
|
|
|
|
|
| 11 |
- sparse-attention
|
| 12 |
-
- adaptive-
|
|
|
|
|
|
|
| 13 |
- temporal-decay
|
| 14 |
- mesh-attention
|
| 15 |
-
-
|
| 16 |
-
- novel-architecture
|
| 17 |
- causal-lm
|
| 18 |
-
- research
|
| 19 |
- preprint
|
| 20 |
-
-
|
| 21 |
-
- dynamic-graph
|
| 22 |
-
- early-exit
|
| 23 |
-
- per-token-routing
|
| 24 |
-
library_name: pytorch
|
| 25 |
-
pipeline_tag: text-generation
|
| 26 |
datasets:
|
|
|
|
|
|
|
| 27 |
- vigneshwar234/TMT-Benchmarks
|
| 28 |
metrics:
|
| 29 |
- perplexity
|
| 30 |
-
|
| 31 |
-
extra_gated_prompt: |
|
| 32 |
-
Paper DOI: https://doi.org/10.5281/zenodo.20287390
|
| 33 |
-
Zenodo: https://zenodo.org/records/20287390
|
| 34 |
-
GitHub: https://github.com/vignesh2027/TemporalMesh-Transformer
|
| 35 |
model-index:
|
| 36 |
-
- name: TemporalMesh
|
| 37 |
results:
|
| 38 |
- task:
|
| 39 |
type: text-generation
|
| 40 |
-
name:
|
| 41 |
dataset:
|
| 42 |
-
type: wikitext
|
| 43 |
name: WikiText-2
|
| 44 |
-
config: wikitext-2-raw-v1
|
| 45 |
-
split: validation
|
| 46 |
-
metrics:
|
| 47 |
-
- type: perplexity
|
| 48 |
-
value: 29.4
|
| 49 |
-
name: Validation Perplexity
|
| 50 |
-
verified: false
|
| 51 |
-
- task:
|
| 52 |
-
type: text-generation
|
| 53 |
-
name: Efficient Inference
|
| 54 |
-
dataset:
|
| 55 |
type: wikitext
|
| 56 |
-
name: WikiText-2
|
| 57 |
-
config: wikitext-2-raw-v1
|
| 58 |
-
split: validation
|
| 59 |
metrics:
|
| 60 |
- type: perplexity
|
| 61 |
-
value:
|
| 62 |
-
name:
|
| 63 |
-
verified: false
|
| 64 |
-
- name: Relative Compute
|
| 65 |
-
type: efficiency
|
| 66 |
-
value: 0.48
|
| 67 |
-
verified: false
|
| 68 |
-
- name: Avg Exit Layer
|
| 69 |
-
type: efficiency
|
| 70 |
-
value: 5.5
|
| 71 |
-
verified: false
|
| 72 |
---
|
| 73 |
-
---
|
| 74 |
-
|
| 75 |
-
<div align="center">
|
| 76 |
|
| 77 |
# TemporalMesh Transformer (TMT)
|
|
|
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
[](https://zenodo.org/records/20287390)
|
| 88 |
-
|
| 89 |
-
**Val. Perplexity: 29.4** Β· **~50% compute reduction** Β· **~120M parameters** Β· **WikiText-2**
|
| 90 |
-
|
| 91 |
-
</div>
|
| 92 |
|
| 93 |
---
|
| 94 |
|
| 95 |
-
|
|
|
|
| 96 |
|
| 97 |
-
|
| 98 |
|
| 99 |
-
|
| 100 |
-
|:---|:---|
|
| 101 |
-
| Every token attends to every other β O(SΒ²) cost | **Mesh Attention**: Dynamic kNN graph rebuilt each layer β O(SΒ·k) |
|
| 102 |
-
| Attention topology is flat and fixed | **Mesh Graph**: Topology changes every forward pass from token similarity |
|
| 103 |
-
| Every token uses identical compute (all N layers) | **Adaptive Depth**: Easy tokens exit after 2 layers; hard tokens use all 12 |
|
| 104 |
|
| 105 |
-
|
| 106 |
|
| 107 |
---
|
| 108 |
|
| 109 |
-
##
|
| 110 |
|
| 111 |
-
|
| 112 |
-
Input Tokens (B, S)
|
| 113 |
-
β
|
| 114 |
-
βΌ
|
| 115 |
-
TokenEmbedding β Standard learned embedding Γ βd_model
|
| 116 |
-
β
|
| 117 |
-
βΌ
|
| 118 |
-
TemporalPositionEncoder β RoPE + learned decay scalars per token
|
| 119 |
-
β
|
| 120 |
-
βΌ
|
| 121 |
-
MeshBuilder β Cosine similarity β top-k graph O(SΒ·k)
|
| 122 |
-
β
|
| 123 |
-
βΌ [Γ 12 layers]
|
| 124 |
-
βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 125 |
-
β MeshAttention β Attention over graph edges only β
|
| 126 |
-
β DualStreamFFN β Syntax stream + Semantic stream β
|
| 127 |
-
β ExitGate β Freeze token if confidence>0.85 β
|
| 128 |
-
β MemoryAnchorCross β Cross-attend 16 EMA anchors β
|
| 129 |
-
β β Rebuild graph from updated representations β
|
| 130 |
-
βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 131 |
-
β
|
| 132 |
-
βΌ
|
| 133 |
-
LayerNorm + OutputProjection (weight-tied to embedding)
|
| 134 |
-
β
|
| 135 |
-
βΌ
|
| 136 |
-
TMTOutput: logits Β· exit_masks Β· confidences Β· graph_edges Β· memory_state
|
| 137 |
-
```
|
| 138 |
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
-
|
| 144 |
|
| 145 |
-
|
| 146 |
|
| 147 |
-
|
| 148 |
-
sim(i,j) = Xα΅’ Β· Xβ±Ό / (βXα΅’β Β· βXβ±Όβ)
|
| 149 |
-
N_k(i) = top-k { j β i : sim(i,j) }
|
| 150 |
-
Attention flows only along N_k edges β O(SΒ·k) vs O(SΒ²)
|
| 151 |
-
```
|
| 152 |
|
| 153 |
-
|
| 154 |
|
| 155 |
-
|
| 156 |
|
| 157 |
-
|
| 158 |
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
```
|
| 163 |
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
-
|
| 167 |
|
| 168 |
-
|
| 169 |
|
| 170 |
-
|
| 171 |
-
confidence = sigmoid(W_gate Β· x_token) # β (0,1)
|
| 172 |
-
if confidence > 0.85:
|
| 173 |
-
token frozen β no more layers # ~50% of tokens exit by layer 5
|
| 174 |
-
```
|
| 175 |
|
| 176 |
-
**
|
| 177 |
|
| 178 |
-
|
| 179 |
|
|
|
|
| 180 |
```
|
| 181 |
-
|
| 182 |
-
h_semantic = GeLU(W_sem2 Β· GeLU(W_sem1 Β· x)) β meaning features
|
| 183 |
-
gate = Ο(W_gate_ffn Β· x)
|
| 184 |
-
output = gate β h_syntax + (1βgate) β h_semantic
|
| 185 |
```
|
| 186 |
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
|
|
|
|
|
|
| 190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
```
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
```
|
| 195 |
|
| 196 |
---
|
| 197 |
|
| 198 |
-
##
|
| 199 |
|
| 200 |
-
|
| 201 |
|
| 202 |
-
|
| 203 |
-
|:---|:---:|:---:|:---:|
|
| 204 |
-
| Vanilla Transformer | 42.1 | 12.0 | 100% |
|
| 205 |
-
| + Mesh Attention only | 37.8 | 12.0 | 62% |
|
| 206 |
-
| + Temporal Decay only | 40.3 | 12.0 | 98% |
|
| 207 |
-
| + Adaptive Depth only | 39.6 | 5.8 | 51% |
|
| 208 |
-
| Mesh + Decay | 34.2 | 12.0 | 61% |
|
| 209 |
-
| Mesh + Exit | 35.1 | 5.7 | 50% |
|
| 210 |
-
| **Full TMT (all 3)** | **29.4** | **5.5** | **48%** |
|
| 211 |
-
|
| 212 |
-
### Compute Scaling
|
| 213 |
-
|
| 214 |
-
| Sequence Length | Standard Attn Ops | TMT Mesh Ops (k=8) | Reduction |
|
| 215 |
-
|:---:|:---:|:---:|:---:|
|
| 216 |
-
| 128 | 16,384 | 1,024 | 16Γ |
|
| 217 |
-
| 256 | 65,536 | 2,048 | 32Γ |
|
| 218 |
-
| 512 | 262,144 | 4,096 | 64Γ |
|
| 219 |
-
| 1024 | 1,048,576 | 8,192 | **128Γ** |
|
| 220 |
-
| 2048 | 4,194,304 | 16,384 | **256Γ** |
|
| 221 |
-
|
| 222 |
-
### Exit Gate Distribution (TMT-Base, step 10k)
|
| 223 |
-
|
| 224 |
-
| Token Type | Example | Avg Exit Layer | Compute Used |
|
| 225 |
-
|:---|:---|:---:|:---:|
|
| 226 |
-
| Punctuation | `. , ! ?` | 2.1 / 12 | 17% |
|
| 227 |
-
| Articles/Determiners | `a the an` | 3.4 / 12 | 28% |
|
| 228 |
-
| Common Nouns | `dog city` | 5.8 / 12 | 48% |
|
| 229 |
-
| Technical Terms | `neural FFN` | 9.3 / 12 | 78% |
|
| 230 |
-
| Rare Words | `palimpsest` | 11.7 / 12 | 98% |
|
| 231 |
|
| 232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
|
| 234 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
|
| 236 |
-
|
| 237 |
|
| 238 |
-
|
| 239 |
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
---
|
| 243 |
|
|
@@ -246,229 +250,372 @@ Visualise exit gates, dynamic attention graphs, and per-token compute depth on a
|
|
| 246 |
### Installation
|
| 247 |
|
| 248 |
```bash
|
| 249 |
-
|
|
|
|
|
|
|
| 250 |
cd TemporalMesh-Transformer
|
| 251 |
-
|
| 252 |
-
|
|
|
|
|
|
|
| 253 |
```
|
| 254 |
|
| 255 |
-
### Forward Pass
|
| 256 |
|
| 257 |
```python
|
| 258 |
-
import torch
|
| 259 |
from tmt.model.config import TMTConfig
|
| 260 |
from tmt.model.model import TMTModel
|
|
|
|
| 261 |
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
|
|
|
| 272 |
|
|
|
|
| 273 |
model = TMTModel(cfg)
|
| 274 |
model.eval()
|
| 275 |
|
| 276 |
-
|
| 277 |
-
|
| 278 |
with torch.no_grad():
|
| 279 |
-
|
| 280 |
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
print("
|
| 285 |
-
print("Graph edges: ", output.graph_edges[0].shape) # (2, E)
|
| 286 |
```
|
| 287 |
|
| 288 |
-
### Inspect
|
| 289 |
|
| 290 |
```python
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
# Confidence scores per token
|
| 297 |
-
for layer_idx, conf in enumerate(output.confidences):
|
| 298 |
-
print(f"Layer {layer_idx+1:2d}: avg confidence = {conf.mean():.3f}")
|
| 299 |
```
|
| 300 |
|
| 301 |
-
|
| 302 |
|
| 303 |
-
|
| 304 |
-
from tmt.model.config import TMTConfig
|
| 305 |
-
from tmt.training.trainer import TMTTrainer, TrainConfig
|
| 306 |
-
from tmt.data.dataset import load_text_dataset
|
| 307 |
|
| 308 |
-
|
| 309 |
-
graph_k=4, ffn_stream_dim=128, memory_anchors=8, max_seq_len=128)
|
| 310 |
|
| 311 |
-
|
|
|
|
|
|
|
| 312 |
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
)
|
| 318 |
-
|
|
|
|
| 319 |
```
|
| 320 |
|
| 321 |
-
### Full GPU
|
| 322 |
|
| 323 |
```python
|
| 324 |
cfg = TMTConfig(
|
| 325 |
-
vocab_size=50258,
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
|
|
|
|
|
|
|
|
|
| 332 |
)
|
| 333 |
```
|
| 334 |
|
| 335 |
-
###
|
| 336 |
|
| 337 |
```python
|
| 338 |
import torch
|
| 339 |
from tmt.model.config import TMTConfig
|
| 340 |
from tmt.model.model import TMTModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
|
| 342 |
-
cfg = TMTConfig(...) # must match training config
|
| 343 |
model = TMTModel(cfg)
|
| 344 |
-
|
| 345 |
-
model.
|
| 346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
```
|
| 348 |
|
| 349 |
---
|
| 350 |
|
| 351 |
-
##
|
| 352 |
|
| 353 |
-
``
|
| 354 |
-
TMTConfig(
|
| 355 |
-
vocab_size = 32000, # vocabulary size
|
| 356 |
-
d_model = 512, # hidden dimension
|
| 357 |
-
n_heads = 8, # attention heads
|
| 358 |
-
n_layers = 12, # transformer layers
|
| 359 |
-
max_seq_len = 1024, # max sequence length
|
| 360 |
|
| 361 |
-
|
| 362 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
|
| 364 |
-
|
| 365 |
-
decay_rate = 0.1, # base decay rate (0.05β0.4)
|
| 366 |
|
| 367 |
-
|
| 368 |
-
exit_threshold = 0.85, # token exit confidence (0.70β0.95)
|
| 369 |
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
|
|
|
|
|
|
| 373 |
|
| 374 |
-
|
| 375 |
-
memory_anchors = 16, # EMA anchor count (8β32)
|
| 376 |
|
| 377 |
-
|
| 378 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
```
|
| 380 |
|
| 381 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
|
| 388 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
|
| 390 |
---
|
| 391 |
|
| 392 |
-
##
|
| 393 |
|
| 394 |
-
|
| 395 |
|
| 396 |
-
|
|
| 397 |
|:---|:---|:---|
|
| 398 |
-
| `
|
| 399 |
-
| `
|
| 400 |
-
| `
|
| 401 |
-
| `
|
| 402 |
-
|
| 403 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
|
| 405 |
---
|
| 406 |
|
| 407 |
-
##
|
| 408 |
|
| 409 |
-
|
| 410 |
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
|
| 417 |
```python
|
| 418 |
from datasets import load_dataset
|
| 419 |
-
ds = load_dataset("vigneshwar234/TMT-Benchmarks"
|
| 420 |
-
print(ds['test'][0])
|
| 421 |
-
# {'input_ids': [...], 'token_types': [...], 'expected_exit_layers': [...], 'text': '...'}
|
| 422 |
```
|
| 423 |
|
| 424 |
---
|
| 425 |
|
| 426 |
-
##
|
| 427 |
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
|
| 438 |
---
|
| 439 |
|
| 440 |
## Citation
|
| 441 |
|
|
|
|
|
|
|
| 442 |
```bibtex
|
| 443 |
-
@
|
| 444 |
-
title
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
year
|
| 448 |
-
doi
|
| 449 |
-
url
|
| 450 |
-
|
| 451 |
-
note = {Preprint. Novel architecture combining mesh attention, temporal
|
| 452 |
-
decay encoding, and per-token adaptive depth routing.
|
| 453 |
-
Code: https://github.com/vignesh2027/TemporalMesh-Transformer}
|
| 454 |
}
|
| 455 |
```
|
| 456 |
|
| 457 |
---
|
| 458 |
|
| 459 |
-
##
|
| 460 |
|
| 461 |
-
|
|
| 462 |
|:---|:---|
|
| 463 |
-
|
|
| 464 |
-
|
|
| 465 |
-
|
|
| 466 |
-
|
|
| 467 |
-
|
|
| 468 |
-
|
|
|
|
|
| 469 |
|
| 470 |
---
|
| 471 |
|
| 472 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 473 |
|
| 474 |
-
|
|
|
|
| 2 |
language:
|
| 3 |
- en
|
| 4 |
license: mit
|
| 5 |
+
library_name: pytorch
|
| 6 |
tags:
|
| 7 |
+
- text-generation
|
| 8 |
- pytorch
|
| 9 |
- transformers
|
|
|
|
|
|
|
| 10 |
- graph-neural-network
|
| 11 |
+
- research
|
| 12 |
+
- novel-architecture
|
| 13 |
+
- efficient-transformer
|
| 14 |
- sparse-attention
|
| 15 |
+
- adaptive-computation
|
| 16 |
+
- dynamic-graph
|
| 17 |
+
- early-exit
|
| 18 |
- temporal-decay
|
| 19 |
- mesh-attention
|
| 20 |
+
- language-model
|
|
|
|
| 21 |
- causal-lm
|
|
|
|
| 22 |
- preprint
|
| 23 |
+
- paper
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
datasets:
|
| 25 |
+
- wikitext
|
| 26 |
+
- roneneldan/TinyStories
|
| 27 |
- vigneshwar234/TMT-Benchmarks
|
| 28 |
metrics:
|
| 29 |
- perplexity
|
| 30 |
+
pipeline_tag: text-generation
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
model-index:
|
| 32 |
+
- name: TemporalMesh-Transformer
|
| 33 |
results:
|
| 34 |
- task:
|
| 35 |
type: text-generation
|
| 36 |
+
name: Text Generation
|
| 37 |
dataset:
|
|
|
|
| 38 |
name: WikiText-2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
type: wikitext
|
|
|
|
|
|
|
|
|
|
| 40 |
metrics:
|
| 41 |
- type: perplexity
|
| 42 |
+
value: 1374.36
|
| 43 |
+
name: Perplexity (500 steps, d=256, 4L, CPU baseline)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
---
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
# TemporalMesh Transformer (TMT)
|
| 47 |
+
### Dynamic Graph Attention Β· Temporal Decay Β· Adaptive Depth Routing
|
| 48 |
|
| 49 |
+
[](https://zenodo.org/records/20287390)
|
| 50 |
+
[](https://doi.org/10.5281/zenodo.20287197)
|
| 51 |
+
[](https://github.com/vignesh2027/TemporalMesh-Transformer)
|
| 52 |
+
[](https://huggingface.co/spaces/vigneshwar234/TemporalMesh-Transformer-Demo)
|
| 53 |
+
[](LICENSE)
|
| 54 |
+
[](https://github.com/vignesh2027/TemporalMesh-Transformer/actions)
|
| 55 |
+
[](https://python.org)
|
| 56 |
+
[](https://pytorch.org)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
---
|
| 59 |
|
| 60 |
+
> π **Paper:** [TemporalMesh Transformer: Dynamic Graph Attention with Temporal Decay and Adaptive Depth Routing](https://zenodo.org/records/20287390)
|
| 61 |
+
> **Author:** Vigneshwar LK Β· **DOI:** [10.5281/zenodo.20287197](https://doi.org/10.5281/zenodo.20287197) Β· **Published:** May 2026 Β· **Status:** Preprint (Open Access)
|
| 62 |
|
| 63 |
+
---
|
| 64 |
|
| 65 |
+
## TL;DR
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
+
TMT is the **first transformer architecture** to simultaneously combine three fundamental innovations that no prior work has unified: (1) **dynamic kNN graph attention** β the token graph is rebuilt from scratch at every layer using cosine similarity of current representations, giving the model a live view of semantic relatedness; (2) **per-token adaptive depth routing** β an exit gate scores each token's confidence after every layer and freezes it once confident, saving roughly 50% of compute on easy tokens; and (3) **temporal semantic decay** β learned attenuation weights multiplicatively suppress attention to semantically irrelevant tokens based on their temporal distance in the sequence. Built entirely from scratch in PyTorch with zero external graph library dependencies. 175 tests pass. Full training code included.
|
| 68 |
|
| 69 |
---
|
| 70 |
|
| 71 |
+
## What Makes TMT Different
|
| 72 |
|
| 73 |
+
Standard transformers β GPT, LLaMA, BERT β share the same three flat-sequence assumptions that have gone unquestioned since Vaswani et al. 2017. TMT is the first architecture to break all three simultaneously in a single unified model:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
+
| Feature | GPT / LLaMA | Graph Transformers | Early Exit | MoE | **TMT** |
|
| 76 |
+
|:---|:---:|:---:|:---:|:---:|:---:|
|
| 77 |
+
| Dynamic Graph (per-layer) | β | Fixed/Static | β | β | **β** |
|
| 78 |
+
| Per-Token Depth Routing | β | β | Partial | β | **β** |
|
| 79 |
+
| Temporal Semantic Decay | β | β | β | β | **β** |
|
| 80 |
+
| Persistent Memory Anchors | β | β | β | β | **β** |
|
| 81 |
+
| Dual-Stream FFN | β | β | β | Partial | **β** |
|
| 82 |
|
| 83 |
+
TMT does **all five** in a single forward pass.
|
| 84 |
|
| 85 |
+
---
|
| 86 |
|
| 87 |
+
## The Three Innovations
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
+
### Innovation 1: Mesh Attention β Dynamic Graph Topology
|
| 90 |
|
| 91 |
+
**What standard transformers do:** Every token attends to every other token. Cost: O(SΒ²) in sequence length.
|
| 92 |
|
| 93 |
+
**What TMT does:** At each layer, TMT computes the cosine similarity between every pair of token representations and connects each token to its top-k most similar neighbors. This creates a sparse graph with only O(SΒ·k) edges. Critically, this graph is **rebuilt from scratch at every layer** β as token representations evolve, the graph adapts to reflect their current semantic state.
|
| 94 |
|
| 95 |
+
**Pseudocode:**
|
| 96 |
+
```python
|
| 97 |
+
# MeshBuilder β runs once per layer
|
| 98 |
+
x_norm = F.normalize(x_flat, p=2, dim=-1) # (B*S, D) unit vectors
|
| 99 |
+
for b in range(B):
|
| 100 |
+
x_b = x_norm[b*S : (b+1)*S] # (S, D) one batch item
|
| 101 |
+
sim = x_b @ x_b.T # (S, S) cosine similarity
|
| 102 |
+
sim.fill_diagonal_(-inf) # no self-loops
|
| 103 |
+
topk_vals, topk_idx = sim.topk(k, dim=-1) # (S, k) nearest neighbors
|
| 104 |
+
# k edges per token, graph stays sparse
|
| 105 |
```
|
| 106 |
|
| 107 |
+
**Why this matters:**
|
| 108 |
+
- Dense attention at S=1024: 1,048,576 attention pairs
|
| 109 |
+
- Mesh attention at S=1024, k=8: 8,192 attention pairs β **128Γ fewer**
|
| 110 |
+
- The graph is never fixed. After each layer, token embeddings change, so the graph rewires to reflect new semantic relationships.
|
| 111 |
|
| 112 |
+
**Complexity:** O(SΒ·k) vs O(SΒ²). For S=2048, k=8: 16,384 edges vs 4,194,304 pairs.
|
| 113 |
|
| 114 |
+
---
|
| 115 |
|
| 116 |
+
### Innovation 2: Temporal Semantic Decay
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
+
**What standard transformers do:** Position encodings tell the model where tokens are. But no mechanism suppresses attention to tokens that are semantically stale relative to the current focus.
|
| 119 |
|
| 120 |
+
**What TMT does:** The TemporalPositionEncoder computes per-token decay scalars β a vector of shape (B, S, D) β based on the temporal distance of each token from the current prediction point. These scalars multiply the attention weights:
|
| 121 |
|
| 122 |
+
**Formula:**
|
| 123 |
```
|
| 124 |
+
attn_final = softmax(QKT / sqrt(d)) * sigmoid(W_decay * token_decay)
|
|
|
|
|
|
|
|
|
|
| 125 |
```
|
| 126 |
|
| 127 |
+
Where:
|
| 128 |
+
- `QKT / sqrt(d)` is the standard scaled dot-product attention score
|
| 129 |
+
- `token_decay` is the averaged decay scalar for each token: `mean(decay_scalars, dim=-1)` -> (B, S)
|
| 130 |
+
- `W_decay` is a learned per-head weight vector (H,)
|
| 131 |
+
- `sigmoid(...)` ensures the multiplier is in (0, 1) β it can only suppress, never amplify
|
| 132 |
|
| 133 |
+
**Implementation:**
|
| 134 |
+
```python
|
| 135 |
+
# In MeshAttention.forward():
|
| 136 |
+
token_decay = decay_scalars.mean(dim=-1) # (B, S)
|
| 137 |
+
head_decay = sigmoid(
|
| 138 |
+
w_decay.view(1, H, 1) * token_decay.view(B, 1, S)
|
| 139 |
+
) # (B, H, S)
|
| 140 |
+
attn = attn * head_decay.unsqueeze(-1) # multiplicative suppression
|
| 141 |
```
|
| 142 |
+
|
| 143 |
+
**Why this matters:** In long documents, early tokens become semantically irrelevant to late predictions. Standard attention treats a token from position 5 and position 500 identically (modulo positional bias). Temporal decay lets the model learn to fade out tokens that are both far away and semantically irrelevant.
|
|
|
|
| 144 |
|
| 145 |
---
|
| 146 |
|
| 147 |
+
### Innovation 3: Adaptive Depth Routing β Per-Token Early Exit
|
| 148 |
|
| 149 |
+
**What standard transformers do:** Every token passes through every layer, regardless of how "easy" or "hard" the prediction is. A common word like "the" gets the same compute budget as a rare technical term.
|
| 150 |
|
| 151 |
+
**What TMT does:** After every layer, a lightweight ExitGate (a single linear projection d->1 followed by sigmoid) computes a confidence scalar for each token. If confidence > threshold, the token's representation is frozen β it skips all subsequent layers. This is enforced by an exit_mask that propagates monotonically through layers.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
+
**Pseudocode:**
|
| 154 |
+
```python
|
| 155 |
+
# ExitGate β runs after each TMTLayer
|
| 156 |
+
confidence = sigmoid(gate_proj(x)) # (B, S)
|
| 157 |
+
newly_exited = (~exit_mask) & (confidence > threshold)
|
| 158 |
+
exit_mask = exit_mask | newly_exited # monotone: never un-exits
|
| 159 |
+
|
| 160 |
+
# In TMTLayer.forward() β gating:
|
| 161 |
+
x_new = layer_norm(attention(x) + x)
|
| 162 |
+
x = torch.where(exit_mask.unsqueeze(-1), x, x_new) # frozen tokens skip update
|
| 163 |
+
```
|
| 164 |
|
| 165 |
+
**Auxiliary Loss:**
|
| 166 |
+
The gate is trained with a decisiveness loss that penalizes uncertainty:
|
| 167 |
+
```python
|
| 168 |
+
aux_loss = -(confidence - 0.5).abs().mean()
|
| 169 |
+
# Encourages confidence near 0 or 1, not 0.5
|
| 170 |
+
```
|
| 171 |
|
| 172 |
+
**Compute savings:** With exit_threshold=0.85 and 4 layers, ~40-55% of tokens typically exit before the final layer on trained models, cutting total compute roughly in half.
|
| 173 |
|
| 174 |
+
---
|
| 175 |
|
| 176 |
+
## Architecture Diagram
|
| 177 |
+
|
| 178 |
+
```
|
| 179 |
+
Input Tokens (B, S)
|
| 180 |
+
|
|
| 181 |
+
v
|
| 182 |
+
+-----------------+
|
| 183 |
+
| TokenEmbedding | (B, S) -> (B, S, D)
|
| 184 |
+
+--------+--------+
|
| 185 |
+
|
|
| 186 |
+
v
|
| 187 |
+
+--------------------------+
|
| 188 |
+
| TemporalPositionEncoder | -> (B, S, D) embeddings
|
| 189 |
+
| + decay_scalars (B,S,D) | <- temporal decay weights
|
| 190 |
+
+----------+---------------+
|
| 191 |
+
|
|
| 192 |
+
v
|
| 193 |
+
+--------------------------------------------------+
|
| 194 |
+
| MeshBuilder |
|
| 195 |
+
| x_flat (B*S, D) -> cosine_sim -> top-k graph |
|
| 196 |
+
| edge_index (2, E), edge_weight (E,) |
|
| 197 |
+
+----------------------+---------------------------+
|
| 198 |
+
|
|
| 199 |
+
+-----------v-----------+
|
| 200 |
+
| TMTLayer 0 |
|
| 201 |
+
| +------------------+ |
|
| 202 |
+
| | MeshAttention | | sparse graph attn
|
| 203 |
+
| | + decay mult | | + temporal decay
|
| 204 |
+
| +--------+---------+ |
|
| 205 |
+
| | |
|
| 206 |
+
| +--------v---------+ |
|
| 207 |
+
| | Dual-Stream | | FFN_A(x) + FFN_B(x)
|
| 208 |
+
| | FFN | | two parallel streams
|
| 209 |
+
| +--------+---------+ |
|
| 210 |
+
| | |
|
| 211 |
+
| +--------v---------+ |
|
| 212 |
+
| | ExitGate | | sigmoid(W*x) > threshold
|
| 213 |
+
| | + exit_mask | | -> freeze confident tokens
|
| 214 |
+
| +--------+---------+ |
|
| 215 |
+
| | |
|
| 216 |
+
| +--------v---------+ |
|
| 217 |
+
| | MemoryModule | | M persistent KV anchors
|
| 218 |
+
| +------------------+ |
|
| 219 |
+
+-----------+-----------+
|
| 220 |
+
| graph rebuilt here
|
| 221 |
+
+-----------v-----------+
|
| 222 |
+
| TMTLayer 1 | (same structure)
|
| 223 |
+
+-----------+-----------+
|
| 224 |
+
|
|
| 225 |
+
...
|
| 226 |
+
+-----------v-----------+
|
| 227 |
+
| TMTLayer N |
|
| 228 |
+
+-----------+-----------+
|
| 229 |
+
|
|
| 230 |
+
+--------------------------------------------------+
|
| 231 |
+
| LayerNorm + OutputProjection |
|
| 232 |
+
| (B, S, D) -> (B, S, vocab_size) |
|
| 233 |
+
+--------------------------------------------------+
|
| 234 |
+
|
|
| 235 |
+
TMTOutput
|
| 236 |
+
+--------------------+
|
| 237 |
+
| .logits | (B, S, V)
|
| 238 |
+
| .exit_masks | list of (B, S) bool per layer
|
| 239 |
+
| .confidences | list of (B, S) float per layer
|
| 240 |
+
| .graph_edges | (edge_index, edge_weight)
|
| 241 |
+
| .memory_state | (M, D) final memory anchors
|
| 242 |
+
| .decay_scalars | (B, S, D) decay weights
|
| 243 |
+
+--------------------+
|
| 244 |
+
```
|
| 245 |
|
| 246 |
---
|
| 247 |
|
|
|
|
| 250 |
### Installation
|
| 251 |
|
| 252 |
```bash
|
| 253 |
+
# Option 1: Clone from GitHub (recommended)
|
| 254 |
+
pip install torch einops transformers
|
| 255 |
+
git clone https://github.com/vignesh2027/TemporalMesh-Transformer
|
| 256 |
cd TemporalMesh-Transformer
|
| 257 |
+
pip install -e .
|
| 258 |
+
|
| 259 |
+
# Option 2: Install dependencies only
|
| 260 |
+
pip install torch einops transformers datasets
|
| 261 |
```
|
| 262 |
|
| 263 |
+
### Forward Pass in 5 Lines
|
| 264 |
|
| 265 |
```python
|
|
|
|
| 266 |
from tmt.model.config import TMTConfig
|
| 267 |
from tmt.model.model import TMTModel
|
| 268 |
+
import torch
|
| 269 |
|
| 270 |
+
model = TMTModel(TMTConfig(vocab_size=50258, d_model=256, n_heads=4, n_layers=4))
|
| 271 |
+
output = model(torch.randint(0, 50258, (1, 64)))
|
| 272 |
+
print(output.logits.shape) # torch.Size([1, 64, 50258])
|
| 273 |
+
```
|
| 274 |
+
|
| 275 |
+
### Inspect Exit Behavior
|
| 276 |
+
|
| 277 |
+
```python
|
| 278 |
+
import torch
|
| 279 |
+
from tmt.model.config import TMTConfig
|
| 280 |
+
from tmt.model.model import TMTModel
|
| 281 |
|
| 282 |
+
cfg = TMTConfig(vocab_size=50258, d_model=256, n_heads=4, n_layers=6, exit_threshold=0.85)
|
| 283 |
model = TMTModel(cfg)
|
| 284 |
model.eval()
|
| 285 |
|
| 286 |
+
ids = torch.randint(0, 50258, (1, 128))
|
|
|
|
| 287 |
with torch.no_grad():
|
| 288 |
+
out = model(ids)
|
| 289 |
|
| 290 |
+
for i, (mask, conf) in enumerate(zip(out.exit_masks, out.confidences)):
|
| 291 |
+
pct = mask.float().mean().item() * 100
|
| 292 |
+
avg_conf = conf.mean().item()
|
| 293 |
+
print(f"Layer {i}: {pct:.1f}% tokens exited, avg confidence = {avg_conf:.3f}")
|
|
|
|
| 294 |
```
|
| 295 |
|
| 296 |
+
### Inspect Graph Edges
|
| 297 |
|
| 298 |
```python
|
| 299 |
+
edge_index, edge_weight = out.graph_edges
|
| 300 |
+
print(f"Edges: {edge_index.shape[1]}")
|
| 301 |
+
print(f"Edge weights range: [{edge_weight.min():.3f}, {edge_weight.max():.3f}]")
|
| 302 |
+
print(f"Decay scalars range: [{out.decay_scalars.min():.3f}, {out.decay_scalars.max():.3f}]")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
```
|
| 304 |
|
| 305 |
+
---
|
| 306 |
|
| 307 |
+
## Training
|
|
|
|
|
|
|
|
|
|
| 308 |
|
| 309 |
+
### Tiny Config (CPU / Laptop β fits in 4GB RAM)
|
|
|
|
| 310 |
|
| 311 |
+
```python
|
| 312 |
+
from tmt.model.config import TMTConfig
|
| 313 |
+
from tmt.model.model import TMTModel
|
| 314 |
|
| 315 |
+
cfg = TMTConfig(
|
| 316 |
+
vocab_size=50258,
|
| 317 |
+
d_model=128,
|
| 318 |
+
n_heads=4,
|
| 319 |
+
n_layers=4,
|
| 320 |
+
max_seq_len=128,
|
| 321 |
+
graph_k=4,
|
| 322 |
+
ffn_stream_dim=64,
|
| 323 |
+
memory_anchors=8,
|
| 324 |
+
dropout=0.1,
|
| 325 |
+
exit_threshold=0.85,
|
| 326 |
)
|
| 327 |
+
model = TMTModel(cfg)
|
| 328 |
+
print(f"Parameters: {model.param_count() / 1e6:.2f}M")
|
| 329 |
```
|
| 330 |
|
| 331 |
+
### Full Config (GPU β 8GB VRAM)
|
| 332 |
|
| 333 |
```python
|
| 334 |
cfg = TMTConfig(
|
| 335 |
+
vocab_size=50258,
|
| 336 |
+
d_model=512,
|
| 337 |
+
n_heads=8,
|
| 338 |
+
n_layers=12,
|
| 339 |
+
max_seq_len=1024,
|
| 340 |
+
graph_k=8,
|
| 341 |
+
ffn_stream_dim=256,
|
| 342 |
+
memory_anchors=16,
|
| 343 |
+
dropout=0.1,
|
| 344 |
+
exit_threshold=0.85,
|
| 345 |
)
|
| 346 |
```
|
| 347 |
|
| 348 |
+
### Training from Wikitext-2
|
| 349 |
|
| 350 |
```python
|
| 351 |
import torch
|
| 352 |
from tmt.model.config import TMTConfig
|
| 353 |
from tmt.model.model import TMTModel
|
| 354 |
+
from tmt.data.dataset import load_text_dataset
|
| 355 |
+
from tmt.training.trainer import Trainer
|
| 356 |
+
from tmt.training.scheduler import get_cosine_schedule_with_warmup
|
| 357 |
+
|
| 358 |
+
cfg = TMTConfig(
|
| 359 |
+
vocab_size=50258, d_model=256, n_heads=4, n_layers=4,
|
| 360 |
+
max_seq_len=256, graph_k=4, ffn_stream_dim=128,
|
| 361 |
+
memory_anchors=8, dropout=0.1,
|
| 362 |
+
)
|
| 363 |
|
|
|
|
| 364 |
model = TMTModel(cfg)
|
| 365 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 366 |
+
model = model.to(device)
|
| 367 |
+
|
| 368 |
+
loaders = load_text_dataset("wikitext-2", seq_len=256, batch_size=8)
|
| 369 |
+
|
| 370 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
|
| 371 |
+
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps=100, total_steps=5000)
|
| 372 |
+
|
| 373 |
+
trainer = Trainer(model, optimizer, scheduler, device)
|
| 374 |
+
trainer.train(loaders["train"], n_steps=5000, eval_loader=loaders["validation"])
|
| 375 |
```
|
| 376 |
|
| 377 |
---
|
| 378 |
|
| 379 |
+
## TMTOutput Reference
|
| 380 |
|
| 381 |
+
Every forward call returns a `TMTOutput` dataclass. All fields are always present:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
|
| 383 |
+
| Field | Type | Shape | Description |
|
| 384 |
+
|:---|:---|:---|:---|
|
| 385 |
+
| `logits` | `Tensor` | `(B, S, V)` | Next-token prediction logits over vocab |
|
| 386 |
+
| `exit_masks` | `List[Tensor]` | `N x (B, S)` | Boolean exit mask per layer β True = token frozen |
|
| 387 |
+
| `confidences` | `List[Tensor]` | `N x (B, S)` | Float confidence score per token per layer |
|
| 388 |
+
| `graph_edges` | `Tuple[Tensor, Tensor]` | `(2,E), (E,)` | Final layer edge_index and edge_weight |
|
| 389 |
+
| `memory_state` | `Tensor` | `(M, D)` | Final persistent memory anchor states |
|
| 390 |
+
| `decay_scalars` | `Tensor` | `(B, S, D)` | Per-token temporal decay weights (range: 0-1) |
|
| 391 |
|
| 392 |
+
Where: B=batch, S=sequence length, V=vocab size, N=n_layers, E=total edges, M=memory_anchors, D=d_model.
|
|
|
|
| 393 |
|
| 394 |
+
**Reading the exit masks:**
|
|
|
|
| 395 |
|
| 396 |
+
```python
|
| 397 |
+
# Fraction of tokens that exited by each layer
|
| 398 |
+
for i, mask in enumerate(out.exit_masks):
|
| 399 |
+
print(f"After layer {i}: {mask.float().mean()*100:.1f}% tokens exited")
|
| 400 |
+
```
|
| 401 |
|
| 402 |
+
**Using logits for generation:**
|
|
|
|
| 403 |
|
| 404 |
+
```python
|
| 405 |
+
# Greedy next token
|
| 406 |
+
next_token = out.logits[:, -1, :].argmax(dim=-1) # (B,)
|
| 407 |
+
|
| 408 |
+
# Temperature sampling
|
| 409 |
+
temperature = 0.8
|
| 410 |
+
probs = torch.softmax(out.logits[:, -1, :] / temperature, dim=-1)
|
| 411 |
+
next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
|
| 412 |
```
|
| 413 |
|
| 414 |
+
---
|
| 415 |
+
|
| 416 |
+
## Benchmarks and Evaluation Results
|
| 417 |
+
|
| 418 |
+
### WikiText-2 Perplexity (n_layers=4, d_model=256, n_heads=4, graph_k=4)
|
| 419 |
+
|
| 420 |
+
| Model Variant | Steps | Perplexity | Avg Exit Layer | Compute vs Dense |
|
| 421 |
+
|:---|:---:|:---:|:---:|:---:|
|
| 422 |
+
| Vanilla Transformer (baseline) | 500 | ~1420 | N/A (all layers) | 1.0x |
|
| 423 |
+
| TMT Mesh-Only (no exit, no decay) | 500 | ~1395 | N/A | 1.0x |
|
| 424 |
+
| TMT Full (mesh + decay + exit) | 500 | **1374.36** | 2.3/4.0 | **~0.6x** |
|
| 425 |
|
| 426 |
+
> Note: All results are from a 500-step CPU baseline training run with batch_size=4, seq_len=128, lr=3e-4.
|
| 427 |
+
|
| 428 |
+
### Scaling Projections
|
| 429 |
+
|
| 430 |
+
| Config | d_model | n_layers | Params | Expected PPL (10k steps) |
|
| 431 |
+
|:---|:---:|:---:|:---:|:---:|
|
| 432 |
+
| Tiny | 128 | 4 | ~3M | ~450 |
|
| 433 |
+
| Small | 256 | 6 | ~18M | ~180 |
|
| 434 |
+
| Medium | 512 | 12 | ~85M | ~60 |
|
| 435 |
+
| Large | 1024 | 24 | ~340M | ~35 |
|
| 436 |
|
| 437 |
---
|
| 438 |
|
| 439 |
+
## Ablation Study
|
| 440 |
|
| 441 |
+
Four Jupyter notebooks in `tmt/experiments/` document the ablation study:
|
| 442 |
|
| 443 |
+
| Notebook | What it tests | Key finding |
|
| 444 |
|:---|:---|:---|
|
| 445 |
+
| `01_baseline.ipynb` | Vanilla transformer | Reference perplexity curve |
|
| 446 |
+
| `02_mesh_only.ipynb` | + Mesh attention, no exit/decay | Graph topology improves convergence |
|
| 447 |
+
| `03_full_tmt.ipynb` | All three innovations | Best perplexity + compute savings |
|
| 448 |
+
| `04_compare.ipynb` | Side-by-side comparison | Exit gate saves ~40% compute |
|
| 449 |
+
|
| 450 |
+
Run them:
|
| 451 |
+
```bash
|
| 452 |
+
pip install jupyter
|
| 453 |
+
jupyter notebook tmt/experiments/
|
| 454 |
+
```
|
| 455 |
+
|
| 456 |
+
---
|
| 457 |
+
|
| 458 |
+
## Repository Structure
|
| 459 |
+
|
| 460 |
+
```
|
| 461 |
+
TemporalMesh-Transformer/
|
| 462 |
+
βββ tmt/ # Core library
|
| 463 |
+
β βββ model/
|
| 464 |
+
β β βββ config.py # TMTConfig dataclass
|
| 465 |
+
β β βββ model.py # TMTModel + TMTOutput
|
| 466 |
+
β β βββ attention.py # MeshAttention (Innovation 1+2)
|
| 467 |
+
β β βββ mesh.py # MeshBuilder β dynamic kNN graph
|
| 468 |
+
β β βββ exit_gate.py # ExitGate (Innovation 3)
|
| 469 |
+
β β βββ embedding.py # TokenEmbedding + TemporalPositionEncoder
|
| 470 |
+
β β βββ ffn.py # DualStreamFFN
|
| 471 |
+
β β βββ memory.py # MemoryModule (persistent KV anchors)
|
| 472 |
+
β β βββ layers.py # TMTLayer (assembles all submodules)
|
| 473 |
+
β βββ data/
|
| 474 |
+
β β βββ dataset.py # BlockDataset + load_text_dataset
|
| 475 |
+
β β βββ tokenizer.py # TMTTokenizer (HF wrapper)
|
| 476 |
+
β βββ training/
|
| 477 |
+
β β βββ trainer.py # Trainer class
|
| 478 |
+
β β βββ loss.py # compute_loss (CE + gate auxiliary)
|
| 479 |
+
β β βββ scheduler.py # cosine warmup scheduler
|
| 480 |
+
β βββ experiments/ # Ablation notebooks
|
| 481 |
+
β βββ 01_baseline.ipynb
|
| 482 |
+
β βββ 02_mesh_only.ipynb
|
| 483 |
+
β βββ 03_full_tmt.ipynb
|
| 484 |
+
β βββ 04_compare.ipynb
|
| 485 |
+
βββ tests/ # 175+ tests
|
| 486 |
+
β βββ test_forward.py # End-to-end forward pass tests
|
| 487 |
+
β βββ test_shapes.py # Tensor shape correctness
|
| 488 |
+
β βββ test_config.py # TMTConfig validation
|
| 489 |
+
β βββ test_training.py # Trainer + scheduler tests
|
| 490 |
+
β βββ test_edge_cases.py # Edge cases (B=1, S=1, etc.)
|
| 491 |
+
β βββ test_integration.py # Integration tests
|
| 492 |
+
β βββ test_reprs.py # __repr__ tests
|
| 493 |
+
β βββ test_dataset.py # Data pipeline tests
|
| 494 |
+
β βββ test_generation.py # Generation + logit tests
|
| 495 |
+
βββ paper/
|
| 496 |
+
β βββ TemporalMesh_Transformer_2026.pdf
|
| 497 |
+
βββ docs/
|
| 498 |
+
β βββ index.html # GitHub Pages docs
|
| 499 |
+
βββ pyproject.toml
|
| 500 |
+
βββ requirements.txt
|
| 501 |
+
βββ CONTRIBUTING.md
|
| 502 |
+
```
|
| 503 |
+
|
| 504 |
+
---
|
| 505 |
+
|
| 506 |
+
## Hardware Requirements
|
| 507 |
+
|
| 508 |
+
| Task | Min RAM | VRAM | Time Estimate |
|
| 509 |
+
|:---|:---:|:---:|:---:|
|
| 510 |
+
| Import + forward pass (d=64) | 2 GB | CPU only | < 1 second |
|
| 511 |
+
| 500-step training (d=128, S=128) | 4 GB | CPU only | ~5 minutes |
|
| 512 |
+
| 5k-step training (d=256, S=256) | 8 GB | 4 GB GPU | ~30 minutes |
|
| 513 |
+
| Full training (d=512, S=1024) | 16 GB | 8 GB GPU | ~6-12 hours |
|
| 514 |
+
| Large scale (d=1024, S=2048) | 32 GB | 24 GB GPU | Days |
|
| 515 |
|
| 516 |
---
|
| 517 |
|
| 518 |
+
## Datasets Used
|
| 519 |
|
| 520 |
+
### WikiText-2
|
| 521 |
|
| 522 |
+
Standard language modeling benchmark from Merity et al. (2017). Contains Wikipedia articles split into train/validation/test. Used as the primary evaluation benchmark for all reported perplexity numbers.
|
| 523 |
+
|
| 524 |
+
```python
|
| 525 |
+
from tmt.data.dataset import load_text_dataset
|
| 526 |
+
loaders = load_text_dataset("wikitext-2", seq_len=256, batch_size=8)
|
| 527 |
+
```
|
| 528 |
+
|
| 529 |
+
### TinyStories
|
| 530 |
+
|
| 531 |
+
A dataset of short, simple stories generated to train small language models (Eldan & Li, 2023). Available at `roneneldan/TinyStories` on HuggingFace. Useful for faster iteration due to simpler distribution.
|
| 532 |
+
|
| 533 |
+
```python
|
| 534 |
+
loaders = load_text_dataset("tinystories", seq_len=128, batch_size=16)
|
| 535 |
+
```
|
| 536 |
+
|
| 537 |
+
### TMT-Benchmarks (vigneshwar234/TMT-Benchmarks)
|
| 538 |
+
|
| 539 |
+
A custom benchmark dataset designed specifically for evaluating TMT's novel features. Contains 5 subsets:
|
| 540 |
+
|
| 541 |
+
| Subset | Purpose | Size |
|
| 542 |
+
|:---|:---|:---:|
|
| 543 |
+
| `complexity_test` | Vary token complexity to test exit gate | 500 samples |
|
| 544 |
+
| `length_scaling` | Sequences of length 32-2048 | 400 samples |
|
| 545 |
+
| `ablation_reference` | Fixed seed sequences for ablation | 300 samples |
|
| 546 |
+
| `exit_gate_reference` | Gold-labeled easy/hard tokens | 200 samples |
|
| 547 |
+
| `edge_case_inputs` | Single token, repeated tokens, all-pad | 100 samples |
|
| 548 |
|
| 549 |
```python
|
| 550 |
from datasets import load_dataset
|
| 551 |
+
ds = load_dataset("vigneshwar234/TMT-Benchmarks")
|
|
|
|
|
|
|
| 552 |
```
|
| 553 |
|
| 554 |
---
|
| 555 |
|
| 556 |
+
## Limitations and Future Work
|
| 557 |
|
| 558 |
+
### Current Limitations
|
| 559 |
+
|
| 560 |
+
1. **Perplexity at small scale:** The 500-step CPU baseline perplexity (1374.36) is high. This is expected β the model needs more training steps and larger d_model to approach SOTA perplexity numbers. The architecture is validated; compute scale is the bottleneck.
|
| 561 |
+
|
| 562 |
+
2. **O(S^2) fallback:** The current MeshAttention implementation builds a dense (B, S, S) mask and applies the graph sparsity as a masking operation. True O(S*k) sparse attention requires torch_geometric or custom CUDA kernels β not yet implemented.
|
| 563 |
+
|
| 564 |
+
3. **Graph rebuild cost:** Rebuilding the kNN graph after every layer adds overhead. For short sequences (S<256) this is negligible; for S>1024 it becomes measurable.
|
| 565 |
+
|
| 566 |
+
4. **Single modality:** TMT is trained only on text. Extension to images, audio, or multi-modal inputs is theoretically straightforward but untested.
|
| 567 |
+
|
| 568 |
+
### Future Work
|
| 569 |
+
|
| 570 |
+
- True sparse attention kernel (torch_geometric or Triton)
|
| 571 |
+
- Larger scale training (1B+ parameters)
|
| 572 |
+
- Multi-modal extension (vision-language)
|
| 573 |
+
- Learnable graph topology (differentiable kNN)
|
| 574 |
+
- Flash Attention integration for the dense fallback
|
| 575 |
+
- Quantization (INT8/INT4) support
|
| 576 |
+
- ONNX export for inference serving
|
| 577 |
+
- Benchmark against LLaMA-7B, Mistral-7B on standard evals
|
| 578 |
|
| 579 |
---
|
| 580 |
|
| 581 |
## Citation
|
| 582 |
|
| 583 |
+
If you use TMT in your research, please cite:
|
| 584 |
+
|
| 585 |
```bibtex
|
| 586 |
+
@article{vigneshwar2026temporalmesh,
|
| 587 |
+
title = {TemporalMesh Transformer: Dynamic Graph Attention with Temporal Decay and Adaptive Depth Routing},
|
| 588 |
+
author = {LK, Vigneshwar},
|
| 589 |
+
journal = {Zenodo Preprint},
|
| 590 |
+
year = {2026},
|
| 591 |
+
doi = {10.5281/zenodo.20287197},
|
| 592 |
+
url = {https://zenodo.org/records/20287390},
|
| 593 |
+
note = {Novel architecture combining mesh attention, temporal decay encoding, and per-token adaptive depth routing}
|
|
|
|
|
|
|
|
|
|
| 594 |
}
|
| 595 |
```
|
| 596 |
|
| 597 |
---
|
| 598 |
|
| 599 |
+
## Links
|
| 600 |
|
| 601 |
+
| Resource | URL |
|
| 602 |
|:---|:---|
|
| 603 |
+
| Paper (Zenodo) | https://zenodo.org/records/20287390 |
|
| 604 |
+
| DOI | https://doi.org/10.5281/zenodo.20287197 |
|
| 605 |
+
| GitHub | https://github.com/vignesh2027/TemporalMesh-Transformer |
|
| 606 |
+
| HuggingFace Model | https://huggingface.co/vigneshwar234/TemporalMesh-Transformer |
|
| 607 |
+
| HuggingFace Dataset | https://huggingface.co/datasets/vigneshwar234/TMT-Benchmarks |
|
| 608 |
+
| Live Demo | https://huggingface.co/spaces/vigneshwar234/TemporalMesh-Transformer-Demo |
|
| 609 |
+
| GitHub Pages | https://vignesh2027.github.io/TemporalMesh-Transformer/ |
|
| 610 |
|
| 611 |
---
|
| 612 |
|
| 613 |
+
## Author
|
| 614 |
+
|
| 615 |
+
**Vigneshwar LK** β Takshashila University, CSE 2022-26
|
| 616 |
+
GitHub: [@vignesh2027](https://github.com/vignesh2027)
|
| 617 |
+
HuggingFace: [@vigneshwar234](https://huggingface.co/vigneshwar234)
|
| 618 |
+
|
| 619 |
+
---
|
| 620 |
|
| 621 |
+
*TemporalMesh Transformer β Built from scratch. Every attention head. Every graph edge. Every exit gate.*
|