Quick Links

TemporalMesh Transformer (TMT)

Dynamic Graph Attention · Temporal Semantic Decay · Per-Token Adaptive Depth Routing

DOI Space GitHub Paper PDF Dataset License: MIT Zenodo

Val. perplexity: 29.4 · ~50% compute reduction · ~120M parameters · WikiText-2


Overview

The TemporalMesh Transformer (TMT) is a novel autoregressive language model architecture that breaks the three fundamental assumptions shared by every standard transformer:

| Assumption every transformer makes | How TMT breaks it |
|---|---|
| Every token attends to every other token: O(S²) cost | Mesh Attention: dynamic kNN graph rebuilt each layer, O(S·k) |
| Attention topology is flat and fixed | Mesh Graph: topology changes every forward pass, driven by token similarity |
| Every token uses identical compute (all N layers) | Adaptive Depth: easy tokens exit after 2 layers; hard tokens use all 12 |

No single prior paper combines all three. That unification is the TMT research contribution.


Architecture at a Glance

Input Tokens (B, S)
      │
      ▼
TokenEmbedding           ← Standard learned embedding × √d_model
      │
      ▼
TemporalPositionEncoder  ← RoPE + learned decay scalars per token
      │
      ▼
MeshBuilder              ← Cosine similarity → top-k graph  O(S·k)
      │
      ▼  [× 12 layers]
┌──────────────────────────────────────────────────────┐
│  MeshAttention     ← Attention over graph edges only │
│  DualStreamFFN     ← Syntax stream + Semantic stream │
│  ExitGate          ← Freeze token if confidence>0.85 │
│  MemoryAnchorCross ← Cross-attend 16 EMA anchors     │
│  → Rebuild graph from updated representations        │
└──────────────────────────────────────────────────────┘
      │
      ▼
LayerNorm + OutputProjection (weight-tied to embedding)
      │
      ▼
TMTOutput: logits · exit_masks · confidences · graph_edges · memory_state

The Five Innovations

1. Mesh Attention: Dynamic kNN Graph

At every layer, tokens are nodes. Edges are recomputed from the cosine similarity of the current representations, so the graph is not fixed: it adapts to what the tokens mean at that layer.

sim(i,j) = Xᵢ · Xⱼ / (‖Xᵢ‖ · ‖Xⱼ‖)
N_k(i)   = top-k { j ≠ i : sim(i,j) }
Attention flows only along N_k edges  →  O(S·k) vs O(S²)

At S=1024, k=8: 1024² / (1024 · 8) = 128× fewer query-key attention scores than standard full attention.
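Below is a minimal pure-Python sketch of the graph-building step. It assumes each token's representation is a plain list of floats and applies the N_k(i) rule as written above: the top-k cosine neighbours, with j ≠ i. The names `cosine` and `build_knn_graph` are illustrative, not the repository's API. A real implementation would batch this on the GPU.

```python
import math


def cosine(u, v):
    """sim(i, j) = X_i · X_j / (‖X_i‖ · ‖X_j‖); returns 0.0 for a zero vector."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def build_knn_graph(X, k):
    """Return N_k(i) for every token i: indices of the k most similar tokens j != i.

    X is a list of S vectors, one per token. Calling this again on each
    layer's updated representations is what makes the mesh dynamic.
    """
    neighbours = []
    for i, xi in enumerate(X):
        scored = [(cosine(xi, xj), j) for j, xj in enumerate(X) if j != i]
        scored.sort(key=lambda pair: pair[0], reverse=True)  # most similar first
        neighbours.append([j for _, j in scored[:k]])
    return neighbours


X = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
print(build_knn_graph(X, k=1))  # [[1], [0], [3], [2]]
```

Each token ends up with exactly k outgoing edges, so attention over the graph costs S·k scores per head. No causal masking is applied in this sketch.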

2. Temporal Decay Encoding

A learned per-head scalar is multiplied into the post-softmax attention weights. Semantically distant tokens are attenuated, not by position alone, but by a learned semantic distance.

δ_h(i,j) = σ( W_decay_h · |t_i − t_j| )
ã_ij      = α_ij · δ_h(i,j)

Unlike ALiBi (additive to logits, fixed schedule), TMT decay is multiplicative, post-softmax, and fully learned.
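The next block is a pure-Python sketch of the two formulas for a single query row. The source does not define t_i, so here it is assumed to be the token's position index. W_decay_h is a plain float rather than a learned parameter. The decay is applied after the softmax, and the formulas show no renormalisation. As a result, the decayed weights in a row generally no longer sum to 1.

```python
import math


def softmax(xs):
    m = max(xs)                              # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def decayed_attention_row(scores, i, t, w_decay):
    """Apply TMT-style multiplicative decay to the attention row of query i.

    scores:  raw attention logits of query i against every key j
    t:       per-token coordinates t_j (ASSUMPTION: plain position indices)
    w_decay: the head's W_decay_h, a plain float here instead of a learned parameter
    """
    alpha = softmax(scores)                                     # α_ij, post-softmax
    delta = [1.0 / (1.0 + math.exp(-w_decay * abs(t[i] - tj)))  # δ_h(i,j) = σ(W · |t_i − t_j|)
             for tj in t]
    return [a * d for a, d in zip(alpha, delta)]                # ã_ij = α_ij · δ_h(i,j)


row = decayed_attention_row([0.0, 0.0, 0.0, 0.0], i=3, t=[0, 1, 2, 3], w_decay=-1.0)
print([round(w, 4) for w in row])  # [0.0119, 0.0298, 0.0672, 0.125]
```

With a negative W_decay_h the factor shrinks as |t_i − t_j| grows. Because σ(0) = 0.5, under this formula even the nearest token is scaled by at most one half in that regime.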

3. Adaptive Depth Routing: Per-Token Early Exit

Each token gets a confidence score after each layer. Confident tokens freeze and skip remaining layers.

confidence = sigmoid(W_gate · x_token)   # ∈ (0,1)
if confidence > 0.85:
    freeze token: skip remaining layers    # ~50% of tokens exit by layer 5

Result: ~50% average compute reduction. Punctuation exits around layer 2 on average; rare words use nearly all 12 layers.
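Here is a toy sketch of the per-token exit loop. To fit in a few lines, it assumes scalar hidden states and layers that act on each token independently. In the real model, tokens interact through attention. The source does not describe how frozen tokens are handled, for example whether they still serve as keys for active tokens. `run_with_early_exit` and `gate_weight` are illustrative names.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def run_with_early_exit(xs, layers, gate_weight, threshold=0.85):
    """Per-token early exit over a layer stack (toy version with scalar states).

    xs:          one scalar hidden state per token (stand-in for a d_model vector)
    layers:      callables mapping a hidden state to a new hidden state
    gate_weight: hypothetical scalar standing in for W_gate
    Returns (states, exit_layer); exit_layer[i] is the 1-based layer after
    which token i froze, or len(layers) if it ran the full stack.
    """
    states = list(xs)
    exit_layer = [len(layers)] * len(xs)
    active = [True] * len(xs)
    for depth, layer in enumerate(layers, start=1):
        for i in range(len(states)):
            if not active[i]:
                continue                      # frozen tokens skip this layer entirely
            states[i] = layer(states[i])
            if sigmoid(gate_weight * states[i]) > threshold:
                active[i] = False             # confident enough: freeze the token
                exit_layer[i] = depth
        if not any(active):
            break                             # every token has exited
    return states, exit_layer


layers = [lambda h: h + 1.0] * 12
states, exits = run_with_early_exit([1.0, -3.0, -20.0], layers, gate_weight=1.0)
print(exits)                                     # [1, 5, 12]
print(sum(exits) / (len(exits) * len(layers)))   # 0.5
```

The last line prints the fraction of token-layer applications actually executed. This is a rough proxy for the compute savings reported in the Performance section.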

4. Dual-Stream Feed-Forward Network

h_syntax   = GeLU(W_syn2 · GeLU(W_syn1 · x))   ← structural features
h_semantic = GeLU(W_sem2 · GeLU(W_sem1 · x))   ← meaning features
gate       = σ(W_gate_ffn · x)
output     = gate ⊙ h_syntax + (1 − gate) ⊙ h_semantic
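Below is a dependency-free sketch of the dual-stream block for a single token vector, written with plain lists. The matrix shapes are assumptions consistent with the config reference: each stream projects d_model → ffn_stream_dim → d_model. GELU uses the exact erf form. The gate is per-dimension because it is combined element-wise (⊙).

```python
import math


def gelu(v):
    # exact GELU, x * Φ(x), applied element-wise
    return [0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0))) for x in v]


def sigmoid(v):
    return [1.0 / (1.0 + math.exp(-z)) for z in v]


def matvec(W, x):
    # W is a list of rows, so len(W) is the output dimension
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def dual_stream_ffn(x, W_syn1, W_syn2, W_sem1, W_sem2, W_gate_ffn):
    """Gated mix of a syntax stream and a semantic stream for one token vector x.

    Assumed shapes: W_*1 is (stream_dim × d_model), W_*2 is (d_model × stream_dim),
    W_gate_ffn is (d_model × d_model).
    """
    h_syntax = gelu(matvec(W_syn2, gelu(matvec(W_syn1, x))))
    h_semantic = gelu(matvec(W_sem2, gelu(matvec(W_sem1, x))))
    gate = sigmoid(matvec(W_gate_ffn, x))   # one mixing weight in (0, 1) per output dimension
    return [g * a + (1.0 - g) * b for g, a, b in zip(gate, h_syntax, h_semantic)]


x = [1.0, -0.5]
W_syn1 = [[0.2, 0.1], [0.0, 0.3], [-0.4, 0.5]]
W_syn2 = [[0.1, 0.2, 0.3], [0.3, -0.2, 0.1]]
W_sem1 = [[0.5, -0.1], [0.2, 0.2], [0.1, 0.4]]
W_sem2 = [[-0.2, 0.4, 0.1], [0.2, 0.1, -0.3]]
zero_gate = [[0.0, 0.0], [0.0, 0.0]]   # σ(0) = 0.5: an even blend of the two streams
print(dual_stream_ffn(x, W_syn1, W_syn2, W_sem1, W_sem2, zero_gate))
```

With all gate weights at zero, the output is the average of the two streams. If both streams share weights, the gate has no effect.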

5. EMA Memory Anchors

16 persistent key-value vectors updated by EMA during training. Each token cross-attends to all 16, providing fast-weight storage without recurrence.

MemAttn(x)  = softmax(x·W_Q · K_mem^T / √d) · V_mem
k_m        ←  0.99 · k_m + 0.01 · mean(attending tokens)
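A pure-Python sketch of both equations follows, with the M anchors held as lists. The query is assumed to be already projected (q = x·W_Q). The source does not say which tokens count as "attending" to anchor m, so `ema_update` takes them as an argument. Only the key update shown above is implemented.

```python
import math


def mem_attn(q, K_mem, V_mem):
    """softmax(q · K_memᵀ / √d) · V_mem for one already-projected query q = x · W_Q."""
    d = len(q)
    logits = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K_mem]
    m = max(logits)
    weights = [math.exp(z - m) for z in logits]
    total = sum(weights)
    weights = [w / total for w in weights]   # one attention weight per anchor
    return [sum(w * v[c] for w, v in zip(weights, V_mem)) for c in range(len(V_mem[0]))]


def ema_update(k_m, attending, decay=0.99):
    """k_m ← decay · k_m + (1 − decay) · mean(attending tokens)."""
    if not attending:
        return list(k_m)                     # nothing attended: keep the anchor as is
    n = len(attending)
    mean = [sum(tok[c] for tok in attending) / n for c in range(len(k_m))]
    return [decay * a + (1.0 - decay) * b for a, b in zip(k_m, mean)]


K_mem = [[1.0, 0.0], [0.0, 1.0]]             # M = 2 anchors, d = 2
V_mem = [[2.0, 4.0], [4.0, 8.0]]
print(mem_attn([0.5, 0.5], K_mem, V_mem))    # [3.0, 6.0]: equal logits give the mean of V_mem
print(ema_update([0.0, 0.0], [[1.0, 1.0], [3.0, 3.0]]))
```

Because the decay is 0.99, each update moves an anchor only 1% of the way toward the mean of its attending tokens. This makes the anchors slow-moving summaries.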

Performance

WikiText-2 Benchmark (all models ~120M params, 10k steps)

| Model | Val PPL ↓ | Avg layers/token | Relative compute |
|---|---|---|---|
| Vanilla Transformer | 42.1 | 12.0 | 100% |
| + Mesh Attention only | 37.8 | 12.0 | 62% |
| + Temporal Decay only | 40.3 | 12.0 | 98% |
| + Adaptive Depth only | 39.6 | 5.8 | 51% |
| Mesh + Decay | 34.2 | 12.0 | 61% |
| Mesh + Exit | 35.1 | 5.7 | 50% |
| Full TMT (all 3) | 29.4 | 5.5 | 48% |
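The table reports validation perplexity. The source does not spell out its evaluation protocol. Under the standard definition, perplexity is the exponential of the mean per-token cross-entropy. On that basis, the drop from 42.1 to 29.4 corresponds to roughly 0.36 nats per token. A small sketch of the conversion:

```python
import math


def perplexity(token_nlls):
    """exp(mean per-token negative log-likelihood), with NLLs in nats."""
    return math.exp(sum(token_nlls) / len(token_nlls))


def mean_nll(ppl):
    """Average per-token cross-entropy (nats) implied by a perplexity."""
    return math.log(ppl)


for name, ppl in [("Vanilla Transformer", 42.1), ("Full TMT (all 3)", 29.4)]:
    print(f"{name:<20} PPL {ppl:5.1f}  ->  {mean_nll(ppl):.3f} nats/token")
```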

Compute Scaling

| Sequence length S | Standard attention ops (S²) | TMT mesh ops (S·k, k=8) | Reduction |
|---|---|---|---|
| 128 | 16,384 | 1,024 | 16× |
| 256 | 65,536 | 2,048 | 32× |
| 512 | 262,144 | 4,096 | 64× |
| 1024 | 1,048,576 | 8,192 | 128× |
| 2048 | 4,194,304 | 16,384 | 256× |
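The scaling table follows directly from the two formulas. Per head, dense attention computes S² query-key scores while mesh attention computes S·k, so the reduction is S/k and doubles with every doubling of S. The sketch below regenerates the rows. These counts cover only the attention scores over the graph. If the kNN graph is built by comparing every pair of tokens, as the exact N_k(i) definition implies, that step still touches all S² pairs unless an approximate or indexed neighbour search is used. End-to-end savings therefore depend on how the repository builds the graph.

```python
def attention_scores(seq_len, k=8):
    """Query-key scores per head: dense attention (S²) vs. mesh attention (S·k)."""
    dense = seq_len * seq_len
    mesh = seq_len * k
    return dense, mesh, dense // mesh          # reduction factor = S / k


for S in (128, 256, 512, 1024, 2048):
    dense, mesh, factor = attention_scores(S)
    print(f"{S:>5}  {dense:>10,}  {mesh:>7,}  {factor:>4}×")
```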

Exit Gate Distribution (TMT-Base, step 10k)

| Token type | Example | Avg exit layer | Compute used |
|---|---|---|---|
| Punctuation | . , ! ? | 2.1 / 12 | 17% |
| Articles/Determiners | a, the, an | 3.4 / 12 | 28% |
| Common nouns | dog, city | 5.8 / 12 | 48% |
| Technical terms | neural, FFN | 9.3 / 12 | 78% |
| Rare words | palimpsest | 11.7 / 12 | 98% |

🚀 Live Demo

Try TMT interactively, no install needed:

👉 huggingface.co/spaces/vigneshwar234/TemporalMesh-Transformer-Demo

Visualise exit gates, dynamic attention graphs, and per-token compute depth on any sentence you type.


Quick Start

Installation

git clone https://github.com/vignesh2027/TemporalMesh-Transformer.git
cd TemporalMesh-Transformer
python3 -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt

Forward Pass

import torch
from tmt.model.config import TMTConfig
from tmt.model.model import TMTModel

cfg = TMTConfig(
    vocab_size=50258,
    d_model=512,
    n_heads=8,
    n_layers=12,
    graph_k=8,
    exit_threshold=0.85,
    memory_anchors=16,
    max_seq_len=256,
)

model = TMTModel(cfg)
model.eval()

input_ids = torch.randint(0, cfg.vocab_size, (1, 64))  # batch=1, seq_len=64

with torch.no_grad():
    output = model(input_ids)

print("Logits shape:    ", output.logits.shape)          # (1, 64, 50258)
print("Exit masks:      ", len(output.exit_masks))       # 12, one per layer
print("Exited per layer:", [m.sum().item() for m in output.exit_masks])
print("Memory state:    ", output.memory_state.shape)    # (16, 512)
print("Graph edges:     ", output.graph_edges[0].shape)  # (2, E)

Inspect Exit Behaviour

# Which tokens exited at which layer?
for layer_idx, mask in enumerate(output.exit_masks):
    n_exited = mask.sum().item()
    print(f"Layer {layer_idx+1:2d}: {n_exited} tokens exited")

# Confidence scores per token
for layer_idx, conf in enumerate(output.confidences):
    print(f"Layer {layer_idx+1:2d}: avg confidence = {conf.mean():.3f}")
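The helper below turns the per-layer masks into a single exit layer per token, the quantity summarised in the exit-gate table. It is a sketch, not part of the tmt package. It works on plain nested lists, e.g. `[m.tolist() for m in output.exit_masks]`. It keeps the first layer at which each token is marked as exited and assigns the final layer to tokens that never exit.

```python
def exit_layers_from_masks(exit_masks):
    """Map per-layer exit masks to one 1-based exit layer per token.

    exit_masks: one entry per layer, each a (B, S) nested list of bools that is
    True where a token exited at that layer. The first True wins, so cumulative
    masks work too; tokens that never exit are assigned the final layer.
    """
    n_layers = len(exit_masks)
    first = [[None] * len(row) for row in exit_masks[0]]
    for layer_no, mask in enumerate(exit_masks, start=1):
        for b, row in enumerate(mask):
            for s, exited in enumerate(row):
                if exited and first[b][s] is None:
                    first[b][s] = layer_no
    return [[n_layers if e is None else e for e in row] for row in first]


def compute_fraction(exit_layers, n_layers):
    """Share of the full (tokens × layers) budget actually used."""
    flat = [e for row in exit_layers for e in row]
    return sum(flat) / (len(flat) * n_layers)


masks = [[[True, False, False]], [[False, True, False]], [[False, False, False]]]
per_token = exit_layers_from_masks(masks)
print(per_token)                               # [[1, 2, 3]]
print(compute_fraction(per_token, len(masks)))
```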

Training (Quick CPU Run)

from tmt.model.config import TMTConfig
from tmt.training.trainer import TMTTrainer, TrainConfig
from tmt.data.dataset import load_text_dataset

cfg = TMTConfig(vocab_size=50258, d_model=256, n_heads=4, n_layers=4,
                graph_k=4, ffn_stream_dim=128, memory_anchors=8, max_seq_len=128)

loaders = load_text_dataset('wikitext-2', seq_len=128, batch_size=8)

trainer = TMTTrainer(
    cfg,
    TrainConfig(total_steps=500, warmup_steps=50, use_wandb=False, eval_every=100),
    loaders['train'], loaders.get('validation')
)
trainer.train()

Full GPU Training (Publication Quality)

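from tmt.model.config import TMTConfig
from tmt.training.trainer import TMTTrainer, TrainConfig
from tmt.data.dataset import load_text_dataset

cfg = TMTConfig(
    vocab_size=50258, d_model=512, n_heads=8, n_layers=12,
    graph_k=8, decay_rate=0.1, exit_threshold=0.85,
    dual_stream=True, memory_anchors=16, ffn_stream_dim=256, max_seq_len=256,
)
train_cfg = TrainConfig(
    total_steps=10_000, warmup_steps=500, lr=3e-4, batch_size=16,
    eval_every=500, save_every=1000, use_wandb=True,
)

# Same trainer API as the quick CPU run; seq_len and batch_size mirror the configs above
loaders = load_text_dataset('wikitext-2', seq_len=256, batch_size=16)
trainer = TMTTrainer(cfg, train_cfg, loaders['train'], loaders.get('validation'))
trainer.train()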

Checkpoint Loading

import torch
from tmt.model.config import TMTConfig
from tmt.model.model import TMTModel

cfg = TMTConfig(...)   # must match training config
model = TMTModel(cfg)
ckpt = torch.load('checkpoints/ckpt_step10000.pt', map_location='cpu')
model.load_state_dict(ckpt['model_state'])
model.eval()

Configuration Reference

TMTConfig(
    vocab_size      = 32000,   # vocabulary size
    d_model         = 512,     # hidden dimension
    n_heads         = 8,       # attention heads
    n_layers        = 12,      # transformer layers
    max_seq_len     = 1024,    # max sequence length

    # ── Mesh Attention ──────────────────────────────
    graph_k         = 8,       # kNN neighbourhood size (4–16)

    # ── Temporal Decay ──────────────────────────────
    decay_rate      = 0.1,     # base decay rate (0.05–0.4)

    # ── Adaptive Depth ──────────────────────────────
    exit_threshold  = 0.85,    # token exit confidence (0.70–0.95)

    # ── Dual-Stream FFN ─────────────────────────────
    dual_stream     = True,    # enable parallel syntax+semantic streams
    ffn_stream_dim  = 256,     # width per stream (total=512 for d_model=512)

    # ── Memory Anchors ──────────────────────────────
    memory_anchors  = 16,      # EMA anchor count (8–32)

    dropout         = 0.1,
)

Model Scales

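| Variant | d_model | Layers | Heads | k | Params | VRAM |
|---|---|---|---|---|---|---|
| TMT-Small | 256 | 4 | 4 | 4 | ~16M | ~2 GB |
| TMT-Medium | 512 | 6 | 6 | 6 | ~60M | ~6 GB |
| TMT-Base | 512 | 12 | 8 | 8 | ~120M | ~12 GB |
| TMT-Large | 1024 | 24 | 16 | 16 | ~350M | ~40 GB |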

TMTOutput Fields

Every forward pass returns a rich structured output:

| Field | Shape | Description |
|---|---|---|
| logits | (B, S, V) | Next-token logits; use for loss/generation |
| exit_masks | list[(B, S) bool] | True where the token exited at that layer |
| confidences | list[(B, S) float] | Gate confidence per token per layer |
| graph_edges | (edge_index, weights) | Live sparse graph from the final layer |
| memory_state | (M, D) | Final EMA memory anchor state |
| decay_scalars | (B, S, D) | Temporal decay weights applied |

Test Dataset

The companion dataset vigneshwar234/TMT-Benchmarks contains:

  • complexity_test: 1,000 sequences annotated by token-complexity category
  • length_scaling: sequences from S=32 to S=1024 for throughput benchmarking
  • ablation_reference: canonical perplexity reference values for all 8 ablation configs
  • exit_gate_reference: expected exit-layer distributions per token type
  • edge_case_inputs: boundary inputs for robustness testing (empty, max-length, all-same)
from datasets import load_dataset
ds = load_dataset("vigneshwar234/TMT-Benchmarks", "complexity_test")
print(ds['test'][0])
# {'input_ids': [...], 'token_types': [...], 'expected_exit_layers': [...], 'text': '...'}

Figures

| Figure | Description |
|---|---|
| fig_architecture.png | Full TMT architecture block diagram |
| fig_graph.png | Dynamic graph evolution across 3 layers |
| fig_decay.png | Temporal decay function curves + RoPE comparison |
| fig_exit.png | Exit gate distribution by layer and token type |
| fig_training.png | Training loss + validation perplexity curves |
| fig_ablation.png | Ablation bar chart + Pareto frontier |
| fig_complexity.png | O(S²) vs O(S·k) operation count + memory |

Citation

@misc{tmt2026,
  title        = {TemporalMesh Transformer: Dynamic Graph Attention with
                  Temporal Decay and Adaptive Depth Routing},
  author       = {Vignesh},
  year         = {2026},
  doi          = {10.5281/zenodo.20287390},
  url          = {https://doi.org/10.5281/zenodo.20287390},
  publisher    = {Zenodo},
  note         = {Preprint. Novel architecture combining mesh attention, temporal
                  decay encoding, and per-token adaptive depth routing.
                  Code: https://github.com/vignesh2027/TemporalMesh-Transformer}
}

Related Work

| Paper | Relation to TMT |
|---|---|
| Vaswani et al. 2017, Attention Is All You Need | Base architecture |
| Su et al. 2021, RoFormer (RoPE) | TMT extends RoPE with learned decay |
| Elbayad et al. 2020, Depth-Adaptive Transformer | TMT generalises to generation |
| Graves 2016, Adaptive Computation Time | TMT provides a transformer-native equivalent |
| Zaheer et al. 2020, BigBird | Fixed sparse patterns vs TMT's dynamic graph |
| Shi et al. 2021, Graph Transformer | Static graph vs TMT's rebuilt-per-layer graph |

License

MIT β€” free to use, modify, and build upon. Citation appreciated for published work.
