metadata
language:
  - en
license: mit
tags:
  - pytorch
  - transformers
  - text-generation
  - language-model
  - graph-neural-network
  - sparse-attention
  - adaptive-depth
  - temporal-decay
  - mesh-attention
  - efficient-transformer
  - novel-architecture
  - causal-lm
  - research
  - preprint
  - mesh-transformer
  - dynamic-graph
  - early-exit
  - per-token-routing
library_name: pytorch
pipeline_tag: text-generation
datasets:
  - vigneshwar234/TMT-Benchmarks
metrics:
  - perplexity
doi: 10.5281/zenodo.20287390
extra_gated_prompt: |
  Paper DOI: https://doi.org/10.5281/zenodo.20287390
  Zenodo: https://zenodo.org/records/20287390
  GitHub: https://github.com/vignesh2027/TemporalMesh-Transformer
model-index:
  - name: TemporalMesh Transformer (TMT-Base)
    results:
      - task:
          type: text-generation
          name: Language Modelling
        dataset:
          type: wikitext
          name: WikiText-2
          config: wikitext-2-raw-v1
          split: validation
        metrics:
          - type: perplexity
            value: 29.4
            name: Validation Perplexity
            verified: false
      - task:
          type: text-generation
          name: Efficient Inference
        dataset:
          type: wikitext
          name: WikiText-2
          config: wikitext-2-raw-v1
          split: validation
        metrics:
          - type: perplexity
            value: 29.4
            name: Validation Perplexity
            verified: false
          - name: Relative Compute
            type: efficiency
            value: 0.48
            verified: false
          - name: Avg Exit Layer
            type: efficiency
            value: 5.5
            verified: false

TemporalMesh Transformer (TMT)

Dynamic Graph Attention · Temporal Semantic Decay · Per-Token Adaptive Depth Routing


Val. Perplexity: 29.4 · ~50% compute reduction · ~120M parameters · WikiText-2


Overview

The TemporalMesh Transformer (TMT) is a novel autoregressive language model architecture that breaks three assumptions shared by standard transformers:

| Assumption every transformer makes | How TMT breaks it |
|---|---|
| Every token attends to every other token, at O(S²) cost | Mesh Attention: dynamic kNN graph rebuilt each layer, O(S·k) |
| Attention topology is flat and fixed | Mesh Graph: topology changes on every forward pass, derived from token similarity |
| Every token uses identical compute (all N layers) | Adaptive Depth: easy tokens exit after 2 layers; hard tokens use all 12 |

To the author's knowledge, no prior paper combines all three; that unification is TMT's research contribution.


Architecture at a Glance

Input Tokens (B, S)
      │
      ▼
TokenEmbedding           ← Standard learned embedding × √d_model
      │
      ▼
TemporalPositionEncoder  ← RoPE + learned decay scalars per token
      │
      ▼
MeshBuilder              ← Cosine similarity → top-k graph  O(S·k)
      │
      ▼  [× 12 layers]
┌─────────────────────────────────────────────────────┐
│  MeshAttention     ← Attention over graph edges only │
│  DualStreamFFN     ← Syntax stream + Semantic stream │
│  ExitGate          ← Freeze token if confidence>0.85 │
│  MemoryAnchorCross ← Cross-attend 16 EMA anchors     │
│  → Rebuild graph from updated representations        │
└─────────────────────────────────────────────────────┘
      │
      ▼
LayerNorm + OutputProjection (weight-tied to embedding)
      │
      ▼
TMTOutput: logits · exit_masks · confidences · graph_edges · memory_state

The Five Innovations

1. Mesh Attention — Dynamic kNN Graph

At every layer, tokens are nodes. Edges are recomputed from the cosine similarity of the current representations, so the graph is not fixed: it adapts to what the tokens mean at that layer.

sim(i,j) = Xᵢ · Xⱼ / (‖Xᵢ‖ · ‖Xⱼ‖)
N_k(i)   = top-k { j ≠ i : sim(i,j) }
Attention flows only along N_k edges  →  O(S·k) vs O(S²)

At S=1024, k=8: 128× fewer attention operations than standard transformers.
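A minimal, dependency-free sketch of the two steps above follows, written for readability rather than speed. It is not the repository's MeshBuilder/MeshAttention code: the function names are hypothetical, and the `causal` flag (restricting candidates to earlier tokens so an autoregressive model cannot see the future) is an assumption, since the card does not say how causality is enforced. Exact top-k selection as written scores every candidate pair, so the O(S·k) saving applies to the attention step over graph edges rather than to building the graph.

```python
import math


def cosine(u, v):
    """Cosine similarity sim(i, j) between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v + 1e-12)


def build_knn_graph(x, k, causal=True):
    """Return N_k(i) for every token i: indices of its k most similar tokens.

    Assumption: with causal=True only earlier tokens j < i are candidates.
    """
    neighbours = []
    for i in range(len(x)):
        candidates = range(i) if causal else [j for j in range(len(x)) if j != i]
        ranked = sorted(candidates, key=lambda j: cosine(x[i], x[j]), reverse=True)
        neighbours.append(ranked[:k])
    return neighbours


def mesh_attention(q, keys, values, neighbours):
    """Scaled dot-product attention restricted to each token's edges N_k(i)."""
    d = len(q[0])
    out = []
    for i, nbrs in enumerate(neighbours):
        if not nbrs:
            # No candidates (token 0 under the causal assumption):
            # this sketch passes the token's own value through unchanged.
            out.append(list(values[i]))
            continue
        scores = [sum(a * b for a, b in zip(q[i], keys[j])) / math.sqrt(d) for j in nbrs]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        out.append([sum(w * values[j][c] for w, j in zip(weights, nbrs))
                    for c in range(len(values[i]))])
    return out


x = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9], [1.0, 0.05]]
graph = build_knn_graph(x, k=2)
out = mesh_attention(x, x, x, graph)  # q = k = v = x, purely for illustration
```

Each token scores at most k neighbours, so the attention step touches at most S·k query-key pairs instead of S².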

2. Temporal Decay Encoding

Each head multiplies its post-softmax attention weights by a decay factor controlled by a learned per-head scalar, W_decay_h. Semantically distant tokens are attenuated, not by position alone but by a learned semantic distance.

δ_h(i,j) = σ( W_decay_h · |t_i − t_j| )
ã_ij      = α_ij · δ_h(i,j)

Unlike ALiBi (additive to logits, fixed schedule), TMT decay is multiplicative, post-softmax, and fully learned.
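The two decay equations can be sketched for a single head and a single query row as follows, under stated assumptions: `w_decay` stands in for the learned W_decay_h, the positions t are plain integers, and the result is not renormalised because the card does not mention renormalising. Two properties follow directly from the formula: σ(0) = 0.5, so the weight on |t_i − t_j| = 0 is always halved, and distant keys are damped relative to near ones only when W_decay_h is negative (with a positive weight the factor grows with distance).

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def decay_factor(w_decay, t_i, t_j):
    """delta_h(i, j) = sigmoid(W_decay_h * |t_i - t_j|) for one head."""
    return sigmoid(w_decay * abs(t_i - t_j))


def apply_temporal_decay(alpha_row, i, positions, w_decay):
    """Rescale row i of post-softmax weights: a~_ij = alpha_ij * delta_h(i, j).

    The output is deliberately left unnormalised (assumption, see lead-in).
    """
    return [a * decay_factor(w_decay, positions[i], positions[j])
            for j, a in enumerate(alpha_row)]


# Query at position 3 attending uniformly to positions 0..3, with a negative
# (hypothetical) learned weight so that more distant keys are damped harder.
decayed = apply_temporal_decay([0.25, 0.25, 0.25, 0.25], 3, [0, 1, 2, 3], w_decay=-0.5)
```

Because the factor is applied after the softmax, it scales probabilities directly, unlike ALiBi's additive bias on the logits.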

3. Adaptive Depth Routing — Per-Token Early Exit

Each token gets a confidence score after each layer. Confident tokens freeze and skip remaining layers.

confidence = sigmoid(W_gate · x_token)   # ∈ (0,1)
if confidence > 0.85:
    token frozen — no more layers         # ~50% of tokens exit by layer 5

Result: roughly 50% average compute reduction (5.5 of 12 layers per token for the full model on WikiText-2). Punctuation exits around layer 2, while rare words use nearly all 12 layers.
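The loop below sketches per-token early exit with the card's gate, confidence = sigmoid(W_gate · x), checked against exit_threshold = 0.85 after every layer. It is a simplification under stated assumptions: `run_with_early_exit` is a hypothetical helper, the gate is a bias-free linear map, and the stand-in layers act on one token at a time, so the sketch ignores how frozen tokens interact with attention in the real model (the card does not describe that).

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def run_with_early_exit(tokens, layers, gate_w, threshold=0.85):
    """Per-token early exit: a token stops being updated once its gate fires.

    tokens:  list of hidden vectors, one per position
    layers:  list of callables mapping a vector to a new vector (stand-ins for TMT blocks)
    gate_w:  weight vector of a hypothetical linear exit gate
    Returns the final states and the 1-based layer at which each token exited.
    """
    states = [list(t) for t in tokens]
    exit_layer = [len(layers)] * len(tokens)  # tokens that never exit use every layer
    active = [True] * len(tokens)
    for depth, layer in enumerate(layers, start=1):
        for i in range(len(states)):
            if not active[i]:
                continue  # frozen: the state is carried forward unchanged
            states[i] = layer(states[i])
            confidence = sigmoid(sum(w * v for w, v in zip(gate_w, states[i])))
            if confidence > threshold:
                active[i] = False
                exit_layer[i] = depth
    return states, exit_layer


layers = [lambda v: [a + 1.0 for a in v]] * 12   # 12 toy layers that add 1 to every feature
states, exits = run_with_early_exit([[0.0], [-3.0]], layers, gate_w=[1.0])
avg_layers = sum(exits) / len(exits)              # average layers used per token
```

The "easy" token crosses the threshold after 2 layers and the "hard" one after 5, so together they use 3.5 of 12 layers on average; the Avg Layers/Token column in the benchmark table below is this same quantity.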

4. Dual-Stream Feed-Forward Network

h_syntax   = GeLU(W_syn2 · GeLU(W_syn1 · x))   ← structural features
h_semantic = GeLU(W_sem2 · GeLU(W_sem1 · x))   ← meaning features
gate       = σ(W_gate_ffn · x)
output     = gate ⊙ h_syntax + (1−gate) ⊙ h_semantic
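A dependency-free sketch of the dual-stream formulas above, assuming bias-free linear layers and an element-wise gate (one gate value per output dimension, as the ⊙ suggests). The parameter-dict keys are hypothetical names, not the repository's module attributes.

```python
import math


def gelu(z):
    """Exact GeLU: z * Phi(z), written with math.erf."""
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def matvec(w, x):
    """Multiply matrix w (list of rows) by vector x."""
    return [sum(a * b for a, b in zip(row, x)) for row in w]


def dual_stream_ffn(x, p):
    """Gated mix of a 'syntax' and a 'semantic' two-layer GeLU MLP.

    p holds weight matrices under hypothetical keys: syn1/syn2 and sem1/sem2 for
    the two streams, gate for W_gate_ffn (d_model x d_model).
    """
    h_syn = [gelu(v) for v in matvec(p["syn2"], [gelu(v) for v in matvec(p["syn1"], x)])]
    h_sem = [gelu(v) for v in matvec(p["sem2"], [gelu(v) for v in matvec(p["sem1"], x)])]
    gate = [sigmoid(v) for v in matvec(p["gate"], x)]
    return [g * a + (1.0 - g) * b for g, a, b in zip(gate, h_syn, h_sem)]


# Toy sizes: d_model = 2, stream width = 3. Both streams share weights here,
# so the gate cannot change the result (a convenient sanity check).
w1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w2 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
params = {"syn1": w1, "syn2": w2, "sem1": w1, "sem2": w2, "gate": [[0.0, 0.0], [0.0, 0.0]]}
y = dual_stream_ffn([1.0, -1.0], params)
```

With a zero gate matrix every gate value is 0.5, so the output is an even blend of the two streams; in training, W_gate_ffn learns per dimension how much to take from each.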

5. EMA Memory Anchors

16 persistent key-value vectors updated by EMA during training. Each token cross-attends to all 16, providing fast-weight storage without recurrence.

MemAttn(x)  = softmax(x·W_Q · K_mem^T / √d) · V_mem
k_m        ←  0.99 · k_m + 0.01 · mean(attending tokens)
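The read and the EMA update can be sketched as follows. The card gives only the key update, so values are left untouched here, and "attending tokens" is interpreted under a loudly stated assumption: a token attends to anchor m when m receives its largest attention weight. The query is taken as x·W_Q with W_Q set to the identity for illustration; all function names are hypothetical.

```python
import math


def softmax(scores):
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def memory_read(q, k_mem, v_mem):
    """MemAttn for one token: softmax(q . K_mem^T / sqrt(d)) . V_mem, with q = x W_Q."""
    d = len(q)
    weights = softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in k_mem])
    read = [sum(w * v[c] for w, v in zip(weights, v_mem)) for c in range(len(v_mem[0]))]
    return read, weights


def ema_update(k_mem, tokens, weights, decay=0.99):
    """k_m <- decay * k_m + (1 - decay) * mean of the tokens attending to anchor m.

    Assumption: a token counts as attending to anchor m when m gets its largest
    attention weight; anchors that no token picks keep their current key.
    """
    new_k = []
    for m, k_m in enumerate(k_mem):
        picked = [t for t, w in zip(tokens, weights)
                  if max(range(len(w)), key=w.__getitem__) == m]
        if not picked:
            new_k.append(list(k_m))
            continue
        mean = [sum(col) / len(picked) for col in zip(*picked)]
        new_k.append([decay * a + (1 - decay) * b for a, b in zip(k_m, mean)])
    return new_k


k_mem = [[1.0, 0.0], [0.0, 1.0]]
v_mem = [[10.0, 0.0], [0.0, 10.0]]
tokens = [[2.0, 0.0], [4.0, 0.0]]          # W_Q taken as the identity for illustration
reads = [memory_read(t, k_mem, v_mem) for t in tokens]
k_mem = ema_update(k_mem, tokens, [w for _, w in reads])
```

Both tokens favour anchor 0, so its key drifts 1% of the way toward their mean, [3.0, 0.0], while anchor 1 is unchanged.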

Performance

WikiText-2 Benchmark (all models ~120M params, 10k steps)

| Model | Val PPL ↓ | Avg Layers/Token | Relative Compute |
|---|---|---|---|
| Vanilla Transformer | 42.1 | 12.0 | 100% |
| + Mesh Attention only | 37.8 | 12.0 | 62% |
| + Temporal Decay only | 40.3 | 12.0 | 98% |
| + Adaptive Depth only | 39.6 | 5.8 | 51% |
| Mesh + Decay | 34.2 | 12.0 | 61% |
| Mesh + Exit | 35.1 | 5.7 | 50% |
| Full TMT (all 3) | 29.4 | 5.5 | 48% |

Compute Scaling

| Sequence Length (S) | Standard Attn Ops (S²) | TMT Mesh Ops (S·k, k=8) | Reduction |
|---|---|---|---|
| 128 | 16,384 | 1,024 | 16× |
| 256 | 65,536 | 2,048 | 32× |
| 512 | 262,144 | 4,096 | 64× |
| 1024 | 1,048,576 | 8,192 | 128× |
| 2048 | 4,194,304 | 16,384 | 256× |
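The scaling table counts one operation per query-key score, so the reduction simplifies to S / k. A short sketch that reproduces it (it ignores the head dimension and the cost of selecting neighbours, which the table also leaves out):

```python
def attention_ops(seq_len, k=8):
    """Count query-key scores: dense attention vs a k-neighbour mesh."""
    dense = seq_len * seq_len   # every token scores every token: S^2
    mesh = seq_len * k          # every token scores only its k neighbours: S*k
    return dense, mesh, dense // mesh   # equals S // k


for s in (128, 256, 512, 1024, 2048):
    dense, mesh, reduction = attention_ops(s)
    print(f"{s:>5}  {dense:>10,}  {mesh:>7,}  {reduction}x")
```

Doubling S doubles the reduction, which is why the saving grows linearly with sequence length at fixed k.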

Exit Gate Distribution (TMT-Base, step 10k)

| Token Type | Example | Avg Exit Layer | Compute Used |
|---|---|---|---|
| Punctuation | . , ! ? | 2.1 / 12 | 17% |
| Articles/Determiners | a, the, an | 3.4 / 12 | 28% |
| Common Nouns | dog, city | 5.8 / 12 | 48% |
| Technical Terms | neural, FFN | 9.3 / 12 | 78% |
| Rare Words | palimpsest | 11.7 / 12 | 98% |

🚀 Live Demo

Try TMT interactively — no install needed:

👉 huggingface.co/spaces/vigneshwar234/TemporalMesh-Transformer-Demo

Visualise exit gates, dynamic attention graphs, and per-token compute depth on any sentence you type.


Quick Start

Installation

git clone https://github.com/vignesh2027/TemporalMesh-Transformer.git
cd TemporalMesh-Transformer
python3 -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt

Forward Pass

import torch
from tmt.model.config import TMTConfig
from tmt.model.model import TMTModel

cfg = TMTConfig(
    vocab_size=50258,
    d_model=512,
    n_heads=8,
    n_layers=12,
    graph_k=8,
    exit_threshold=0.85,
    memory_anchors=16,
    max_seq_len=256,
)

model = TMTModel(cfg)
model.eval()

input_ids = torch.randint(0, 50258, (1, 64))  # batch=1, seq_len=64

with torch.no_grad():
    output = model(input_ids)

print("Logits shape:    ", output.logits.shape)          # (1, 64, 50258)
print("Exit masks:      ", len(output.exit_masks))       # 12 — one per layer
print("Exited per layer:", [m.sum().item() for m in output.exit_masks])  # True entries in each layer's exit mask
print("Memory state:    ", output.memory_state.shape)    # (16, 512)
print("Graph edges:     ", output.graph_edges[0].shape)  # (2, E)

Inspect Exit Behaviour

# How many tokens exited at each layer? (continues from the forward pass above)
for layer_idx, mask in enumerate(output.exit_masks):
    n_exited = mask.sum().item()
    print(f"Layer {layer_idx+1:2d}: {n_exited} tokens exited")

# Confidence scores per token
for layer_idx, conf in enumerate(output.confidences):
    print(f"Layer {layer_idx+1:2d}: avg confidence = {conf.mean():.3f}")

Training (Quick CPU Run)

from tmt.model.config import TMTConfig
from tmt.training.trainer import TMTTrainer, TrainConfig
from tmt.data.dataset import load_text_dataset

cfg = TMTConfig(vocab_size=50258, d_model=256, n_heads=4, n_layers=4,
                graph_k=4, ffn_stream_dim=128, memory_anchors=8, max_seq_len=128)

loaders = load_text_dataset('wikitext-2', seq_len=128, batch_size=8)

trainer = TMTTrainer(
    cfg,
    TrainConfig(total_steps=500, warmup_steps=50, use_wandb=False, eval_every=100),
    loaders['train'], loaders.get('validation')
)
trainer.train()

Full GPU Training (Publication Quality)

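from tmt.model.config import TMTConfig
from tmt.training.trainer import TMTTrainer, TrainConfig
from tmt.data.dataset import load_text_dataset

cfg = TMTConfig(
    vocab_size=50258, d_model=512, n_heads=8, n_layers=12,
    graph_k=8, decay_rate=0.1, exit_threshold=0.85,
    dual_stream=True, memory_anchors=16, ffn_stream_dim=256, max_seq_len=256,
)
train_cfg = TrainConfig(
    total_steps=10_000, warmup_steps=500, lr=3e-4, batch_size=16,
    eval_every=500, save_every=1000, use_wandb=True,
)

# Same data/trainer calls as the quick CPU run, at the full sequence length
loaders = load_text_dataset('wikitext-2', seq_len=256, batch_size=16)
trainer = TMTTrainer(cfg, train_cfg, loaders['train'], loaders.get('validation'))
trainer.train()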

Checkpoint Loading

import torch
from tmt.model.config import TMTConfig
from tmt.model.model import TMTModel

cfg = TMTConfig(...)   # must match training config
model = TMTModel(cfg)
ckpt = torch.load('checkpoints/ckpt_step10000.pt', map_location='cpu')
model.load_state_dict(ckpt['model_state'])
model.eval()

Configuration Reference

TMTConfig(
    vocab_size      = 32000,   # vocabulary size
    d_model         = 512,     # hidden dimension
    n_heads         = 8,       # attention heads
    n_layers        = 12,      # transformer layers
    max_seq_len     = 1024,    # max sequence length

    # ── Mesh Attention ──────────────────────────────
    graph_k         = 8,       # kNN neighbourhood size (4–16)

    # ── Temporal Decay ──────────────────────────────
    decay_rate      = 0.1,     # base decay rate (0.05–0.4)

    # ── Adaptive Depth ──────────────────────────────
    exit_threshold  = 0.85,    # token exit confidence (0.70–0.95)

    # ── Dual-Stream FFN ─────────────────────────────
    dual_stream     = True,    # enable parallel syntax+semantic streams
    ffn_stream_dim  = 256,     # width per stream (total=512 for d_model=512)

    # ── Memory Anchors ──────────────────────────────
    memory_anchors  = 16,      # EMA anchor count (8–32)

    dropout         = 0.1,
)
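The suggested ranges in the reference comments can be checked before launching a run. The sketch below is a hypothetical helper working on a plain dict, not part of the TMT package; the ranges are copied from the comments above and are guidance, not hard limits.

```python
# Suggested ranges copied from the comments in the TMTConfig reference.
RECOMMENDED = {
    "graph_k": (4, 16),
    "decay_rate": (0.05, 0.4),
    "exit_threshold": (0.70, 0.95),
    "memory_anchors": (8, 32),
}


def check_tmt_settings(settings):
    """Return a warning string for each value outside its suggested range.

    `settings` is a plain dict of TMTConfig keyword arguments (hypothetical helper).
    """
    warnings = []
    for name, (lo, hi) in RECOMMENDED.items():
        value = settings.get(name)
        if value is not None and not lo <= value <= hi:
            warnings.append(f"{name}={value} is outside the suggested range {lo}-{hi}")
    return warnings


print(check_tmt_settings({"graph_k": 8, "exit_threshold": 0.99}))
```

Keys missing from the dict are skipped, so the same helper works for partial configs such as the quick CPU run.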

Model Scales

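| Variant | d_model | Layers | Heads | k | Params | VRAM |
|---|---|---|---|---|---|---|
| TMT-Small | 256 | 4 | 4 | 4 | ~16M | ~2 GB |
| TMT-Medium | 512 | 6 | 6 | 6 | ~60M | ~6 GB |
| TMT-Base | 512 | 12 | 8 | 8 | ~120M | ~12 GB |
| TMT-Large | 1024 | 24 | 16 | 16 | ~350M | ~40 GB |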
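Standard multi-head attention splits d_model evenly across heads. Assuming TMT follows that convention (the card does not say), the per-head width of each variant can be checked as below. The TMT-Medium row (d_model 512 with 6 heads) does not divide evenly, so either that variant uses a different head width or one of its listed values differs.

```python
# (d_model, n_heads) per variant, copied from the Model Scales table.
VARIANTS = {
    "TMT-Small": (256, 4),
    "TMT-Medium": (512, 6),
    "TMT-Base": (512, 8),
    "TMT-Large": (1024, 16),
}


def head_dims(variants):
    """Map each variant to (d_model / n_heads, whether the split is even)."""
    return {name: (d / h, d % h == 0) for name, (d, h) in variants.items()}


for name, (width, even) in head_dims(VARIANTS).items():
    print(f"{name:<11} head width {width:.2f}  {'even split' if even else 'uneven split'}")
```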

TMTOutput Fields

Every forward pass returns a rich structured output:

| Field | Shape | Description |
|---|---|---|
| logits | (B, S, V) | Next-token logits; use for loss/generation |
| exit_masks | list[(B, S) bool] | True where the token exited at that layer |
| confidences | list[(B, S) float] | Gate confidence per token per layer |
| graph_edges | (edge_index, weights) | Live sparse graph from the final layer |
| memory_state | (M, D) | Final EMA memory anchor state |
| decay_scalars | (B, S, D) | Temporal decay weights applied |
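A common follow-up is to turn `exit_masks` into one exit layer per token. The helper below is a hypothetical sketch that works on nested Python lists so it stays dependency-free; with real tensors, call `.tolist()` on each mask first. It takes the first layer whose mask is True, which gives the same answer whether the masks mark only the exit layer or stay True afterwards, and assigns `n_layers` to tokens that never exit.

```python
def exit_layers(exit_masks, n_layers):
    """Return, for each (batch, position), the 1-based layer where the token exited.

    exit_masks: one entry per layer, each a nested list of bools with shape (B, S),
    True where the token exited at that layer. Tokens that never exit get n_layers.
    """
    batch, seq = len(exit_masks[0]), len(exit_masks[0][0])
    result = [[n_layers] * seq for _ in range(batch)]
    for b in range(batch):
        for s in range(seq):
            for layer, mask in enumerate(exit_masks, start=1):
                if mask[b][s]:
                    result[b][s] = layer
                    break  # first True wins
    return result


masks = [[[False, True, False]], [[True, True, False]], [[False, False, False]]]
per_token = exit_layers(masks, n_layers=3)
avg_exit = sum(per_token[0]) / len(per_token[0])
```

Averaging these per-token layers over a corpus gives the Avg Layers/Token figure reported in the benchmark table.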

Test Dataset

The companion dataset vigneshwar234/TMT-Benchmarks contains:

  • complexity_test — 1,000 sequences annotated by token complexity category
  • length_scaling — sequences from S=32 to S=1024 for throughput benchmarking
  • ablation_reference — canonical perplexity reference values for all 8 ablation configs
  • exit_gate_reference — expected exit layer distributions per token type
  • edge_case_inputs — boundary inputs for robustness testing (empty, max-length, all-same)

from datasets import load_dataset
ds = load_dataset("vigneshwar234/TMT-Benchmarks", "complexity_test")
print(ds['test'][0])
# {'input_ids': [...], 'token_types': [...], 'expected_exit_layers': [...], 'text': '...'}
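To compare a model's behaviour with `complexity_test`, one option is to average the expected exit layer per token type. This sketch assumes `token_types` and `expected_exit_layers` are parallel per-token lists, as the sample row suggests; their exact value types are an assumption. Iterating a `datasets` split yields one dict per row, so `mean_exit_by_type(ds['test'])` should work unchanged.

```python
from collections import defaultdict


def mean_exit_by_type(records):
    """Average expected exit layer per token type over dataset records.

    Each record is assumed to carry parallel lists 'token_types' and
    'expected_exit_layers'.
    """
    totals = defaultdict(float)
    counts = defaultdict(int)
    for rec in records:
        for ttype, layer in zip(rec["token_types"], rec["expected_exit_layers"]):
            totals[ttype] += layer
            counts[ttype] += 1
    return {t: totals[t] / counts[t] for t in totals}


toy = [
    {"token_types": ["punct", "noun"], "expected_exit_layers": [2, 6]},
    {"token_types": ["punct"], "expected_exit_layers": [3]},
]
summary = mean_exit_by_type(toy)
```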

Figures

| Figure | Description |
|---|---|
| fig_architecture.png | Full TMT architecture block diagram |
| fig_graph.png | Dynamic graph evolution across 3 layers |
| fig_decay.png | Temporal decay function curves + RoPE comparison |
| fig_exit.png | Exit gate distribution by layer and token type |
| fig_training.png | Training loss + validation perplexity curves |
| fig_ablation.png | Ablation bar chart + Pareto frontier |
| fig_complexity.png | O(S²) vs O(S·k) operation count + memory |

Citation

@misc{tmt2026,
  title        = {TemporalMesh Transformer: Dynamic Graph Attention with
                  Temporal Decay and Adaptive Depth Routing},
  author       = {Vignesh},
  year         = {2026},
  doi          = {10.5281/zenodo.20287390},
  url          = {https://doi.org/10.5281/zenodo.20287390},
  publisher    = {Zenodo},
  note         = {Preprint. Novel architecture combining mesh attention, temporal
                  decay encoding, and per-token adaptive depth routing.
                  Code: https://github.com/vignesh2027/TemporalMesh-Transformer}
}

Related Work

| Paper | Relation to TMT |
|---|---|
| Vaswani et al. 2017, Attention Is All You Need | Base architecture |
| Su et al. 2021, RoFormer (RoPE) | TMT extends RoPE with learned decay |
| Elbayad et al. 2020, Depth-Adaptive Transformer | Per-token exit in translation decoders; TMT applies it to autoregressive language modelling |
| Graves 2016, Adaptive Computation Time | Adaptive depth for RNNs; TMT's exit gate is a transformer-native equivalent |
| Zaheer et al. 2020, BigBird | Fixed sparse patterns vs TMT's dynamic graph |
| Shi et al. 2021, Graph Transformer | Static graph vs TMT's rebuilt-per-layer graph |

License

MIT — free to use, modify, and build upon. Citation appreciated for published work.