---
language:
- en
license: mit
tags:
- pytorch
- transformers
- text-generation
- language-model
- graph-neural-network
- sparse-attention
- adaptive-depth
- temporal-decay
- mesh-attention
- efficient-transformer
- novel-architecture
- causal-lm
- research
- preprint
- mesh-transformer
- dynamic-graph
- early-exit
- per-token-routing
library_name: pytorch
pipeline_tag: text-generation
datasets:
- vigneshwar234/TMT-Benchmarks
metrics:
- perplexity
doi: 10.5281/zenodo.20287390
extra_gated_prompt: |
  Paper DOI: https://doi.org/10.5281/zenodo.20287390
  Zenodo: https://zenodo.org/records/20287390
  GitHub: https://github.com/vignesh2027/TemporalMesh-Transformer
model-index:
- name: TemporalMesh Transformer (TMT-Base)
  results:
  - task:
      type: text-generation
      name: Language Modelling
    dataset:
      type: wikitext
      name: WikiText-2
      config: wikitext-2-raw-v1
      split: validation
    metrics:
    - type: perplexity
      value: 29.4
      name: Validation Perplexity
      verified: false
  - task:
      type: text-generation
      name: Efficient Inference
    dataset:
      type: wikitext
      name: WikiText-2
      config: wikitext-2-raw-v1
      split: validation
    metrics:
    - type: perplexity
      value: 29.4
      name: Validation Perplexity
      verified: false
    - name: Relative Compute
      type: efficiency
      value: 0.48
      verified: false
    - name: Avg Exit Layer
      type: efficiency
      value: 5.5
      verified: false
---
---
# TemporalMesh Transformer (TMT)
### *Dynamic Graph Attention · Temporal Semantic Decay · Per-Token Adaptive Depth Routing*
[Paper (DOI)](https://doi.org/10.5281/zenodo.20287390) · [Live Demo](https://huggingface.co/spaces/vigneshwar234/TemporalMesh-Transformer-Demo) · [GitHub](https://github.com/vignesh2027/TemporalMesh-Transformer) · [Benchmark Dataset](https://huggingface.co/datasets/vigneshwar234/TMT-Benchmarks) · [MIT License](https://github.com/vignesh2027/TemporalMesh-Transformer/blob/main/LICENSE) · [Zenodo Record](https://zenodo.org/records/20287390)
**Val. Perplexity: 29.4** · **~50% compute reduction** · **~120M parameters** · **WikiText-2**
---
## Overview
The **TemporalMesh Transformer (TMT)** is an autoregressive language model architecture that relaxes three assumptions built into the standard transformer:
| Standard Transformer Assumption | How TMT Relaxes It |
|:---|:---|
| Every token attends to every other — O(S²) cost | **Mesh Attention**: Dynamic kNN graph rebuilt each layer — O(S·k) |
| Attention topology is flat and fixed | **Mesh Graph**: Topology changes every forward pass from token similarity |
| Every token uses identical compute (all N layers) | **Adaptive Depth**: Easy tokens exit after 2 layers; hard tokens use all 12 |
To the author's knowledge, no prior paper combines all three; that unification is TMT's research contribution.
---
## Architecture at a Glance
```
Input Tokens (B, S)
        │
        ▼
TokenEmbedding            ← Standard learned embedding × √d_model
        │
        ▼
TemporalPositionEncoder   ← RoPE + learned decay scalars per token
        │
        ▼
MeshBuilder               ← Cosine similarity → top-k graph O(S·k)
        │
        ▼ [× 12 layers]
┌──────────────────────────────────────────────────────────┐
│ MeshAttention      ← Attention over graph edges only     │
│ DualStreamFFN      ← Syntax stream + Semantic stream     │
│ ExitGate           ← Freeze token if confidence > 0.85   │
│ MemoryAnchorCross  ← Cross-attend 16 EMA anchors         │
│ → Rebuild graph from updated representations             │
└──────────────────────────────────────────────────────────┘
        │
        ▼
LayerNorm + OutputProjection (weight-tied to embedding)
        │
        ▼
TMTOutput: logits · exit_masks · confidences · graph_edges · memory_state
```
---
## The Five Innovations
### 1. Mesh Attention — Dynamic kNN Graph
At every layer, tokens are nodes. Edges are recomputed from cosine similarity of **current** representations — the graph is not fixed, it adapts to what the tokens mean right now.
```
sim(i,j) = Xᵢ · Xⱼ / (‖Xᵢ‖ · ‖Xⱼ‖)
N_k(i) = top-k { j ≠ i : sim(i,j) }
Attention flows only along N_k edges → O(S·k) vs O(S²)
```
At S=1024, k=8, attention runs over S·k = 8,192 edges instead of S² = 1,048,576 token pairs: **128× fewer attention operations** than dense attention.
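The graph construction and edge-restricted attention above can be sketched in a few lines of PyTorch. This is a minimal single-head illustration, not the repository's `MeshBuilder` or `MeshAttention` code: the function names `knn_graph` and `mesh_attention` are hypothetical, query/key/value projections and causal masking are omitted, and the exact all-pairs similarity used here still costs O(S²) to build the graph (only the attention step is O(S·k)).

```python
import torch
import torch.nn.functional as F

def knn_graph(x: torch.Tensor, k: int) -> torch.Tensor:
    """Return (S, k) indices of each token's top-k cosine neighbours, self excluded."""
    x_norm = F.normalize(x, dim=-1)          # unit-length rows
    sim = x_norm @ x_norm.T                  # (S, S) cosine similarities
    sim.fill_diagonal_(float("-inf"))        # enforce j != i
    return sim.topk(k, dim=-1).indices       # (S, k)

def mesh_attention(q, k_mat, v, neighbours):
    """Single-head attention restricted to each token's k neighbours."""
    d = q.shape[-1]
    k_nb = k_mat[neighbours]                               # (S, k, d)
    v_nb = v[neighbours]                                   # (S, k, d)
    scores = (q.unsqueeze(1) * k_nb).sum(-1) / d ** 0.5    # (S, k)
    weights = scores.softmax(dim=-1)                       # normalise over neighbours only
    return (weights.unsqueeze(-1) * v_nb).sum(dim=1)       # (S, d)

torch.manual_seed(0)
S, d, k = 16, 32, 4
x = torch.randn(S, d)                 # stand-in for token representations
nb = knn_graph(x, k)
out = mesh_attention(x, x, x, nb)     # q = k = v = x for brevity
print(nb.shape, out.shape)
```

Each token's softmax runs over only `k` scores, which is where the S·k operation count comes from.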
### 2. Temporal Decay Encoding
Each head learns a decay scalar that is multiplied into the post-softmax attention weights. Semantically distant tokens are attenuated by a learned semantic distance rather than by position alone.
```
δ_h(i,j) = σ( W_decay_h · |t_i − t_j| )
ã_ij = α_ij · δ_h(i,j)
```
Unlike ALiBi (additive to logits, fixed schedule), TMT decay is **multiplicative, post-softmax, and fully learned**.
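A minimal sketch of the multiplicative decay, under two assumptions that the formula leaves open: `t_i` is taken to be the integer token position, and `W_decay_h` is a single scalar per head. The function name `apply_temporal_decay` and the example weights are hypothetical.

```python
import torch

def apply_temporal_decay(attn, positions, w_decay):
    """attn: (H, S, S) post-softmax weights; positions: (S,); w_decay: (H,) learned scalars."""
    dist = (positions[:, None] - positions[None, :]).abs().float()   # (S, S), |t_i - t_j|
    delta = torch.sigmoid(w_decay[:, None, None] * dist)             # (H, S, S), values in (0, 1)
    return attn * delta                                              # multiplicative, after softmax

torch.manual_seed(0)
H, S = 2, 5
attn = torch.randn(H, S, S).softmax(dim=-1)
positions = torch.arange(S)
w_decay = torch.tensor([-0.5, -2.0])   # hypothetical values; negative so delta shrinks with distance
decayed = apply_temporal_decay(attn, positions, w_decay)
print(decayed.sum(dim=-1))             # rows no longer sum to 1
```

Two properties follow directly from the formula: at zero distance σ(0) = 0.5, so every weight is scaled by at most one half, and because the scaling happens after the softmax the rows are no longer normalised.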
### 3. Adaptive Depth Routing — Per-Token Early Exit
Each token gets a confidence score after each layer. Confident tokens freeze and skip remaining layers.
```python
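import torch

x_token, W_gate = torch.randn(512), torch.randn(512)  # placeholder hidden state and gate weights
confidence = torch.sigmoid(W_gate @ x_token)          # scalar in (0, 1)
token_frozen = bool(confidence > 0.85)                # if True, the token skips all remaining layers
# ~50% of tokens exit by layer 5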
```
**Result**: ~50% average compute reduction. Punctuation exits around layer 2 on average; rare words use nearly all 12 layers.
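The routing loop can be sketched as follows. This is an illustration under stated assumptions, not the repository's `ExitGate`: plain `nn.Linear` layers stand in for the TMT blocks, and the gate is assumed to be one linear projection per layer. Note that this masked version still computes every token at every layer and then discards the updates for frozen tokens; real savings require gathering only the active tokens before each layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_layers, d_model, threshold = 4, 16, 0.85
layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_layers)])  # stand-ins for TMT blocks
gates = nn.ModuleList([nn.Linear(d_model, 1) for _ in range(n_layers)])         # assumed gate shape

x = torch.randn(2, 8, d_model)                  # (B, S, D)
active = torch.ones(2, 8, dtype=torch.bool)     # tokens still being processed
exit_masks = []

with torch.no_grad():
    for layer, gate in zip(layers, gates):
        updated = torch.tanh(layer(x))
        # Frozen tokens keep their old representation; active ones take the update.
        x = torch.where(active.unsqueeze(-1), updated, x)
        confidence = torch.sigmoid(gate(x)).squeeze(-1)     # (B, S)
        exits_now = active & (confidence > threshold)       # tokens freezing at this layer
        exit_masks.append(exits_now)
        active = active & ~exits_now

print([int(m.sum()) for m in exit_masks])
```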
### 4. Dual-Stream Feed-Forward Network
```
h_syntax = GeLU(W_syn2 · GeLU(W_syn1 · x)) ← structural features
h_semantic = GeLU(W_sem2 · GeLU(W_sem1 · x)) ← meaning features
gate = σ(W_gate_ffn · x)
output = gate ⊙ h_syntax + (1−gate) ⊙ h_semantic
```
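The equations above map onto a small PyTorch module. The sketch below assumes each stream projects `d_model → stream_dim → d_model` and that the gate is per-feature; the class name `DualStreamFFNSketch` and these shapes are illustrative assumptions, not the repository's `DualStreamFFN`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DualStreamFFNSketch(nn.Module):
    def __init__(self, d_model: int, stream_dim: int):
        super().__init__()
        self.syn1, self.syn2 = nn.Linear(d_model, stream_dim), nn.Linear(stream_dim, d_model)
        self.sem1, self.sem2 = nn.Linear(d_model, stream_dim), nn.Linear(stream_dim, d_model)
        self.gate = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h_syntax = F.gelu(self.syn2(F.gelu(self.syn1(x))))      # structural stream
        h_semantic = F.gelu(self.sem2(F.gelu(self.sem1(x))))    # meaning stream
        gate = torch.sigmoid(self.gate(x))                      # per-feature mixing weight in (0, 1)
        return gate * h_syntax + (1 - gate) * h_semantic

torch.manual_seed(0)
ffn = DualStreamFFNSketch(d_model=512, stream_dim=256)
x = torch.randn(2, 10, 512)
y = ffn(x)
print(y.shape)
```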
### 5. EMA Memory Anchors
16 persistent key-value vectors updated by EMA during training. Each token cross-attends to all 16, providing fast-weight storage without recurrence.
```
MemAttn(x) = softmax(x·W_Q · K_mem^T / √d) · V_mem
k_m ← 0.99 · k_m + 0.01 · mean(attending tokens)
```
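A sketch of one read and one EMA write, assuming that "mean(attending tokens)" means the attention-weighted mean of token representations for each anchor. That interpretation, the random `W_Q`, and the variable names are assumptions made for illustration.

```python
import torch

torch.manual_seed(0)
M, S, d, momentum = 16, 10, 32, 0.99
k_mem, v_mem = torch.randn(M, d), torch.randn(M, d)   # anchor keys and values
W_Q = torch.randn(d, d) / d ** 0.5                    # stand-in query projection
x = torch.randn(S, d)                                 # token representations

# Read: every token cross-attends to all M anchors.
attn = ((x @ W_Q) @ k_mem.T / d ** 0.5).softmax(dim=-1)   # (S, M)
read = attn @ v_mem                                       # (S, d)

# Write: EMA update of anchor keys toward the tokens that attended to them.
col = attn / attn.sum(dim=0, keepdim=True)    # normalise over tokens, (S, M)
token_mean = col.T @ x                        # (M, d) attention-weighted token mean per anchor
k_mem_new = momentum * k_mem + (1 - momentum) * token_mean
print(read.shape, k_mem_new.shape)
```

With momentum 0.99 each update moves an anchor only 1% of the way toward the current batch statistics, so the anchors change slowly across steps.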
---
## Performance
### WikiText-2 Benchmark (all models ~120M params, 10k steps)
| Model | Val PPL ↓ | Avg Layers/Token | Relative Compute |
|:---|:---:|:---:|:---:|
| Vanilla Transformer | 42.1 | 12.0 | 100% |
| + Mesh Attention only | 37.8 | 12.0 | 62% |
| + Temporal Decay only | 40.3 | 12.0 | 98% |
| + Adaptive Depth only | 39.6 | 5.8 | 51% |
| Mesh + Decay | 34.2 | 12.0 | 61% |
| Mesh + Exit | 35.1 | 5.7 | 50% |
| **Full TMT (all 3)** | **29.4** | **5.5** | **48%** |
### Compute Scaling
| Sequence Length | Standard Attn Ops | TMT Mesh Ops (k=8) | Reduction |
|:---:|:---:|:---:|:---:|
| 128 | 16,384 | 1,024 | 16× |
| 256 | 65,536 | 2,048 | 32× |
| 512 | 262,144 | 4,096 | 64× |
| 1024 | 1,048,576 | 8,192 | **128×** |
| 2048 | 4,194,304 | 16,384 | **256×** |
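The table rows follow from S² versus S·k, so the reduction factor is simply S/k. A short script reproduces them (the helper name `attention_ops` is illustrative; the counts cover attention over edges only, not graph construction):

```python
def attention_ops(seq_len: int, k: int = 8) -> tuple[int, int, int]:
    """Return (dense pair count, mesh edge count, reduction factor)."""
    dense = seq_len * seq_len
    mesh = seq_len * k
    return dense, mesh, dense // mesh

for s in (128, 256, 512, 1024, 2048):
    dense, mesh, ratio = attention_ops(s)
    print(f"S={s:5d}  dense={dense:>9,}  mesh={mesh:>6,}  reduction={ratio}x")
```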
### Exit Gate Distribution (TMT-Base, step 10k)
| Token Type | Example | Avg Exit Layer | Compute Used |
|:---|:---|:---:|:---:|
| Punctuation | `. , ! ?` | 2.1 / 12 | 17% |
| Articles/Determiners | `a the an` | 3.4 / 12 | 28% |
| Common Nouns | `dog city` | 5.8 / 12 | 48% |
| Technical Terms | `neural FFN` | 9.3 / 12 | 78% |
| Rare Words | `palimpsest` | 11.7 / 12 | 98% |
---
## 🚀 Live Demo
Try TMT interactively — no install needed:
👉 **[huggingface.co/spaces/vigneshwar234/TemporalMesh-Transformer-Demo](https://huggingface.co/spaces/vigneshwar234/TemporalMesh-Transformer-Demo)**
Visualise exit gates, dynamic attention graphs, and per-token compute depth on any sentence you type.
---
## Quick Start
### Installation
```bash
git clone https://github.com/vignesh2027/TemporalMesh-Transformer.git
cd TemporalMesh-Transformer
python3 -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt
```
### Forward Pass
```python
import torch
from tmt.model.config import TMTConfig
from tmt.model.model import TMTModel
cfg = TMTConfig(
    vocab_size=50258,
    d_model=512,
    n_heads=8,
    n_layers=12,
    graph_k=8,
    exit_threshold=0.85,
    memory_anchors=16,
    max_seq_len=256,
)
model = TMTModel(cfg)
model.eval()
input_ids = torch.randint(0, 50258, (1, 64))  # batch=1, seq_len=64
with torch.no_grad():
    output = model(input_ids)
print("Logits shape:    ", output.logits.shape)          # (1, 64, 50258)
print("Exit masks:      ", len(output.exit_masks))       # 12, one per layer
print("Exits per layer: ", [m.sum().item() for m in output.exit_masks])
print("Memory state:    ", output.memory_state.shape)    # (16, 512)
print("Graph edges:     ", output.graph_edges[0].shape)  # (2, E)
```
### Inspect Exit Behaviour
```python
# Which tokens exited at which layer?
for layer_idx, mask in enumerate(output.exit_masks):
    n_exited = mask.sum().item()
    print(f"Layer {layer_idx+1:2d}: {n_exited} tokens exited")
# Confidence scores per token
for layer_idx, conf in enumerate(output.confidences):
    print(f"Layer {layer_idx+1:2d}: avg confidence = {conf.mean():.3f}")
```
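The per-layer masks can also be reduced to a single average exit layer, the quantity reported in the performance tables. The helper below is a hypothetical sketch, not part of the `tmt` package. It assumes each mask is a `(B, S)` boolean tensor that is `True` where a token exited at that layer, and it counts tokens that never exit as using all layers. Stand-in masks are used so the snippet runs without a trained model.

```python
import torch

def average_exit_layer(exit_masks: list[torch.Tensor]) -> float:
    """Mean 1-based exit layer; tokens that never exit count as using all layers."""
    n_layers = len(exit_masks)
    stacked = torch.stack(exit_masks).bool()                    # (L, B, S)
    layer_ids = torch.arange(1, n_layers + 1).view(-1, 1, 1)    # (L, 1, 1)
    never = torch.full_like(stacked, n_layers, dtype=torch.long)
    # First layer at which each token exited, or n_layers if it never did.
    exit_layer = torch.where(stacked, layer_ids, never).min(dim=0).values
    return exit_layer.float().mean().item()

# Stand-in for output.exit_masks: 3 layers, batch of 1, 3 tokens.
masks = [
    torch.tensor([[True, False, False]]),    # token 0 exits at layer 1
    torch.tensor([[False, True, False]]),    # token 1 exits at layer 2
    torch.tensor([[False, False, False]]),   # token 2 never exits, so it uses all 3 layers
]
print(average_exit_layer(masks))  # 2.0
```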
### Training (Quick CPU Run)
```python
from tmt.model.config import TMTConfig
from tmt.training.trainer import TMTTrainer, TrainConfig
from tmt.data.dataset import load_text_dataset
cfg = TMTConfig(vocab_size=50258, d_model=256, n_heads=4, n_layers=4,
                graph_k=4, ffn_stream_dim=128, memory_anchors=8, max_seq_len=128)
loaders = load_text_dataset('wikitext-2', seq_len=128, batch_size=8)
trainer = TMTTrainer(
    cfg,
    TrainConfig(total_steps=500, warmup_steps=50, use_wandb=False, eval_every=100),
    loaders['train'], loaders.get('validation')
)
trainer.train()
```
### Full GPU Training (Publication Quality)
```python
from tmt.model.config import TMTConfig
from tmt.training.trainer import TrainConfig

cfg = TMTConfig(
    vocab_size=50258, d_model=512, n_heads=8, n_layers=12,
    graph_k=8, decay_rate=0.1, exit_threshold=0.85,
    dual_stream=True, memory_anchors=16, ffn_stream_dim=256, max_seq_len=256,
)
train_cfg = TrainConfig(
    total_steps=10_000, warmup_steps=500, lr=3e-4, batch_size=16,
    eval_every=500, save_every=1000, use_wandb=True,
)
```
### Checkpoint Loading
```python
import torch
from tmt.model.config import TMTConfig
from tmt.model.model import TMTModel
cfg = TMTConfig(...) # must match training config
model = TMTModel(cfg)
ckpt = torch.load('checkpoints/ckpt_step10000.pt', map_location='cpu')
model.load_state_dict(ckpt['model_state'])
model.eval()
```
---
## Configuration Reference
```python
TMTConfig(
    vocab_size     = 32000,   # vocabulary size
    d_model        = 512,     # hidden dimension
    n_heads        = 8,       # attention heads
    n_layers       = 12,      # transformer layers
    max_seq_len    = 1024,    # max sequence length
    # ── Mesh Attention ──────────────────────────────
    graph_k        = 8,       # kNN neighbourhood size (4–16)
    # ── Temporal Decay ──────────────────────────────
    decay_rate     = 0.1,     # base decay rate (0.05–0.4)
    # ── Adaptive Depth ──────────────────────────────
    exit_threshold = 0.85,    # token exit confidence (0.70–0.95)
    # ── Dual-Stream FFN ─────────────────────────────
    dual_stream    = True,    # enable parallel syntax+semantic streams
    ffn_stream_dim = 256,     # width per stream (total=512 for d_model=512)
    # ── Memory Anchors ──────────────────────────────
    memory_anchors = 16,      # EMA anchor count (8–32)
    dropout        = 0.1,
)
```
### Model Scales
| Variant | d_model | Layers | Heads | k | Params | VRAM |
|:---|:---:|:---:|:---:|:---:|:---:|:---:|
| TMT-Small | 256 | 4 | 4 | 4 | ~16M | ~2 GB |
| TMT-Medium | 512 | 6 | 6 | 6 | ~60M | ~6 GB |
| **TMT-Base** | **512** | **12** | **8** | **8** | **~120M** | **~12 GB** |
| TMT-Large | 1024 | 24 | 16 | 16 | ~350M | ~40 GB |
---
## TMTOutput Fields
Every forward pass returns a rich structured output:
| Field | Shape | Description |
|:---|:---|:---|
| `logits` | `(B, S, V)` | Next-token logits — use for loss/generation |
| `exit_masks` | `list[(B, S) bool]` | True where token exited at that layer |
| `confidences` | `list[(B, S) float]` | Gate confidence per token per layer |
| `graph_edges` | `(edge_index, weights)` | Live sparse graph from final layer |
| `memory_state` | `(M, D)` | Final EMA memory anchor state |
| `decay_scalars` | `(B, S, D)` | Temporal decay weights applied |
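
As a sketch of how `logits` can be used for loss, the snippet below computes the standard next-token cross-entropy and the corresponding perplexity. Random tensors stand in for `output.logits` and `input_ids` so it runs without the `tmt` package; the shapes follow the table above, and nothing here is taken from the repository's trainer.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, S, V = 2, 8, 100
logits = torch.randn(B, S, V)              # stand-in for output.logits
input_ids = torch.randint(0, V, (B, S))    # stand-in for the model input

# Position t predicts token t+1, so drop the last logit and the first target.
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, V),
    input_ids[:, 1:].reshape(-1),
)
perplexity = loss.exp()
print(f"loss={loss.item():.3f}  perplexity={perplexity.item():.1f}")
```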
---
## Test Dataset
The companion dataset **[vigneshwar234/TMT-Benchmarks](https://huggingface.co/datasets/vigneshwar234/TMT-Benchmarks)** contains:
- `complexity_test` — 1,000 sequences annotated by token complexity category
- `length_scaling` — sequences from S=32 to S=1024 for throughput benchmarking
- `ablation_reference` — canonical perplexity reference values for all 8 ablation configs
- `exit_gate_reference` — expected exit layer distributions per token type
- `edge_case_inputs` — boundary inputs for robustness testing (empty, max-length, all-same)
```python
from datasets import load_dataset
ds = load_dataset("vigneshwar234/TMT-Benchmarks", "complexity_test")
print(ds['test'][0])
# {'input_ids': [...], 'token_types': [...], 'expected_exit_layers': [...], 'text': '...'}
```
---
## Figures
| Figure | Description |
|:---|:---|
| [`fig_architecture.png`](paper/fig_architecture.png) | Full TMT architecture block diagram |
| [`fig_graph.png`](paper/fig_graph.png) | Dynamic graph evolution across 3 layers |
| [`fig_decay.png`](paper/fig_decay.png) | Temporal decay function curves + RoPE comparison |
| [`fig_exit.png`](paper/fig_exit.png) | Exit gate distribution by layer and token type |
| [`fig_training.png`](paper/fig_training.png) | Training loss + validation perplexity curves |
| [`fig_ablation.png`](paper/fig_ablation.png) | Ablation bar chart + Pareto frontier |
| [`fig_complexity.png`](paper/fig_complexity.png) | O(S²) vs O(S·k) operation count + memory |
---
## Citation
```bibtex
@misc{tmt2026,
title = {TemporalMesh Transformer: Dynamic Graph Attention with
Temporal Decay and Adaptive Depth Routing},
author = {Vignesh},
year = {2026},
doi = {10.5281/zenodo.20287390},
url = {https://doi.org/10.5281/zenodo.20287390},
publisher = {Zenodo},
note = {Preprint. Novel architecture combining mesh attention, temporal
decay encoding, and per-token adaptive depth routing.
Code: https://github.com/vignesh2027/TemporalMesh-Transformer}
}
```
---
## Related Work
| Paper | Relation to TMT |
|:---|:---|
| Vaswani et al. 2017 — *Attention Is All You Need* | Base architecture |
| Su et al. 2021 — *RoFormer (RoPE)* | TMT extends RoPE with learned decay |
| Elbayad et al. 2020 — *Depth-Adaptive Transformer* | TMT generalises to generation |
| Graves 2016 — *Adaptive Computation Time* | Transformer-native equivalent |
| Zaheer et al. 2020 — *BigBird* | Fixed sparse patterns vs TMT's dynamic graph |
| Shi et al. 2021 — *Graph Transformer* | Static graph vs TMT's rebuilt-per-layer graph |
---
## License
MIT — free to use, modify, and build upon. Citation appreciated for published work.