---
language:
- en
license: mit
tags:
- pytorch
- transformers
- text-generation
- language-model
- graph-neural-network
- sparse-attention
- adaptive-depth
- temporal-decay
- mesh-attention
- efficient-transformer
- novel-architecture
- causal-lm
- research
- preprint
- mesh-transformer
- dynamic-graph
- early-exit
- per-token-routing
library_name: pytorch
pipeline_tag: text-generation
datasets:
- vigneshwar234/TMT-Benchmarks
metrics:
- perplexity
doi: 10.5281/zenodo.20287390
extra_gated_prompt: |
  Paper DOI: https://doi.org/10.5281/zenodo.20287390
  Zenodo: https://zenodo.org/records/20287390
  GitHub: https://github.com/vignesh2027/TemporalMesh-Transformer
model-index:
- name: TemporalMesh Transformer (TMT-Base)
  results:
  - task:
      type: text-generation
      name: Language Modelling
    dataset:
      type: wikitext
      name: WikiText-2
      config: wikitext-2-raw-v1
      split: validation
    metrics:
    - type: perplexity
      value: 29.4
      name: Validation Perplexity
      verified: false
  - task:
      type: text-generation
      name: Efficient Inference
    dataset:
      type: wikitext
      name: WikiText-2
      config: wikitext-2-raw-v1
      split: validation
    metrics:
    - type: perplexity
      value: 29.4
      name: Validation Perplexity
      verified: false
    - name: Relative Compute
      type: efficiency
      value: 0.48
      verified: false
    - name: Avg Exit Layer
      type: efficiency
      value: 5.5
      verified: false
---

---
# TemporalMesh Transformer (TMT)

### *Dynamic Graph Attention · Temporal Semantic Decay · Per-Token Adaptive Depth Routing*

[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.20287390.svg)](https://doi.org/10.5281/zenodo.20287390)
[![Space](https://img.shields.io/badge/🤗%20Space-Live%20Demo-orange?style=flat-square)](https://huggingface.co/spaces/vigneshwar234/TemporalMesh-Transformer-Demo)
[![GitHub](https://img.shields.io/badge/GitHub-vignesh2027%2FTemporalMesh--Transformer-181717?style=flat-square&logo=github)](https://github.com/vignesh2027/TemporalMesh-Transformer)
[![Paper PDF](https://img.shields.io/badge/Paper-PDF%2020%20pages-red?style=flat-square&logo=adobeacrobatreader)](https://doi.org/10.5281/zenodo.20287390)
[![Dataset](https://img.shields.io/badge/Dataset-TMT--Benchmarks-FFD21E?style=flat-square&logo=huggingface)](https://huggingface.co/datasets/vigneshwar234/TMT-Benchmarks)
[![License: MIT](https://img.shields.io/badge/License-MIT-green?style=flat-square)](https://github.com/vignesh2027/TemporalMesh-Transformer/blob/main/LICENSE)
[![Zenodo](https://img.shields.io/badge/Zenodo-Published-blue?style=flat-square&logo=zenodo)](https://zenodo.org/records/20287390)

**Val. Perplexity: 29.4** · **~50% compute reduction** · **~120M parameters** · **WikiText-2**

---

## Overview

The **TemporalMesh Transformer (TMT)** is a novel autoregressive language model architecture that breaks three fundamental assumptions shared by standard transformers:

| Assumption Every Transformer Makes | How TMT Breaks It |
|:---|:---|
| Every token attends to every other: O(S²) cost | **Mesh Attention**: dynamic kNN graph rebuilt each layer, O(S·k) |
| Attention topology is flat and fixed | **Mesh Graph**: topology changes every forward pass from token similarity |
| Every token uses identical compute (all N layers) | **Adaptive Depth**: easy tokens exit after 2 layers; hard tokens use all 12 |

No single prior paper combines all three; that unification is TMT's research contribution.

---

## Architecture at a Glance

```
Input Tokens (B, S)
       │
       ▼
TokenEmbedding           ← Standard learned embedding × √d_model
       │
       ▼
TemporalPositionEncoder  ← RoPE + learned decay scalars per token
       │
       ▼
MeshBuilder              ← Cosine similarity → top-k graph O(S·k)
       │
       ▼  [× 12 layers]
┌──────────────────────────────────────────────────────┐
│ MeshAttention      ← Attention over graph edges only │
│ DualStreamFFN      ← Syntax stream + Semantic stream │
│ ExitGate           ← Freeze token if confidence>0.85 │
│ MemoryAnchorCross  ← Cross-attend 16 EMA anchors     │
│ → Rebuild graph from updated representations         │
└──────────────────────────────────────────────────────┘
       │
       ▼
LayerNorm + OutputProjection (weight-tied to embedding)
       │
       ▼
TMTOutput: logits · exit_masks · confidences · graph_edges · memory_state
```

---

## The Five Innovations

### 1. Mesh Attention: Dynamic kNN Graph

At every layer, tokens are nodes. Edges are recomputed from the cosine similarity of **current** representations, so the graph is not fixed: it adapts to what the tokens mean at that layer.

```
sim(i,j) = Xᵢ · Xⱼ / (‖Xᵢ‖ · ‖Xⱼ‖)
N_k(i)   = top-k { j ≠ i : sim(i,j) }
Attention flows only along N_k edges → O(S·k) vs O(S²)
```

At S=1024, k=8 this gives **128× fewer attention operations** than dense attention, since S²/(S·k) = S/k = 128.

### 2. Temporal Decay Encoding

A learned per-head scalar is multiplied into the post-softmax attention weights. Semantically distant tokens are attenuated, not by position alone but by a learned semantic distance.

```
δ_h(i,j) = σ( W_decay_h · |t_i − t_j| )
ã_ij     = α_ij · δ_h(i,j)
```

Unlike ALiBi (additive to logits, fixed schedule), TMT decay is **multiplicative, post-softmax, and fully learned**.

### 3. Adaptive Depth Routing: Per-Token Early Exit

Each token gets a confidence score after each layer. Confident tokens freeze and skip the remaining layers.

```python
import torch

x_token = torch.randn(512)   # one token's hidden state (d_model = 512)
w_gate = torch.randn(512)    # gate weights: random here, learned in TMT

confidence = torch.sigmoid(w_gate @ x_token)   # ∈ (0, 1)
frozen = bool(confidence > 0.85)               # frozen tokens skip all remaining layers
# ~50% of tokens exit by layer 5
```

**Result**: ~50% average compute reduction. Punctuation exits around layer 2; rare words use nearly all 12 layers.

### 4. Dual-Stream Feed-Forward Network

```
h_syntax   = GeLU(W_syn2 · GeLU(W_syn1 · x))   ← structural features
h_semantic = GeLU(W_sem2 · GeLU(W_sem1 · x))   ← meaning features
gate       = σ(W_gate_ffn · x)
output     = gate ⊙ h_syntax + (1−gate) ⊙ h_semantic
```

### 5. EMA Memory Anchors

Sixteen persistent key-value vectors are updated by an exponential moving average (EMA) during training. Each token cross-attends to all 16, which provides fast-weight storage without recurrence.
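
The anchor read and EMA update described above can be sketched in a few lines. This is a minimal NumPy illustration under stated assumptions, not the repository's implementation: `read_memory`, `ema_update`, and `w_q` are hypothetical names, the weights are random, and "mean of attending tokens" is interpreted here as an attention-weighted mean over the sequence.

```python
import numpy as np

def softmax(z, axis=-1):
    """Numerically stable softmax along one axis."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def read_memory(x, w_q, k_mem, v_mem):
    """Cross-attend S tokens to M anchors; returns (output (S, D), attn (S, M))."""
    d = k_mem.shape[-1]
    attn = softmax((x @ w_q) @ k_mem.T / np.sqrt(d))
    return attn @ v_mem, attn

def ema_update(k_mem, x, attn, momentum=0.99):
    """Move each anchor key towards the attention-weighted mean of the tokens.

    Assumption: the weighting scheme is a guess at "mean(attending tokens)".
    """
    weights = attn / attn.sum(axis=0, keepdims=True)   # each column sums to 1
    token_mean = weights.T @ x                          # (M, D)
    return momentum * k_mem + (1.0 - momentum) * token_mean

rng = np.random.default_rng(0)
S, M, D = 10, 16, 32                     # tokens, anchors, hidden size (toy values)
x = rng.standard_normal((S, D))
w_q = rng.standard_normal((D, D)) / np.sqrt(D)
k_mem = rng.standard_normal((M, D))
v_mem = rng.standard_normal((M, D))

out, attn = read_memory(x, w_q, k_mem, v_mem)
new_k = ema_update(k_mem, x, attn)
print(out.shape, attn.shape, new_k.shape)   # (10, 32) (10, 16) (16, 32)
```

With a momentum of 0.99, each update moves an anchor only 1% of the way towards the current tokens, so the anchors change slowly across batches. In the README's own notation, the same two steps are: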
```
MemAttn(x) = softmax(x·W_Q · K_mem^T / √d) · V_mem
k_m ← 0.99 · k_m + 0.01 · mean(attending tokens)
```

---

## Performance

### WikiText-2 Benchmark (all models ~120M params, 10k steps)

| Model | Val PPL ↓ | Avg Layers/Token | Relative Compute |
|:---|:---:|:---:|:---:|
| Vanilla Transformer | 42.1 | 12.0 | 100% |
| + Mesh Attention only | 37.8 | 12.0 | 62% |
| + Temporal Decay only | 40.3 | 12.0 | 98% |
| + Adaptive Depth only | 39.6 | 5.8 | 51% |
| Mesh + Decay | 34.2 | 12.0 | 61% |
| Mesh + Exit | 35.1 | 5.7 | 50% |
| **Full TMT (all 3)** | **29.4** | **5.5** | **48%** |

### Compute Scaling

| Sequence Length | Standard Attn Ops | TMT Mesh Ops (k=8) | Reduction |
|:---:|:---:|:---:|:---:|
| 128 | 16,384 | 1,024 | 16× |
| 256 | 65,536 | 2,048 | 32× |
| 512 | 262,144 | 4,096 | 64× |
| 1024 | 1,048,576 | 8,192 | **128×** |
| 2048 | 4,194,304 | 16,384 | **256×** |

### Exit Gate Distribution (TMT-Base, step 10k)

| Token Type | Example | Avg Exit Layer | Compute Used |
|:---|:---|:---:|:---:|
| Punctuation | `. , ! ?` | 2.1 / 12 | 17% |
| Articles/Determiners | `a the an` | 3.4 / 12 | 28% |
| Common Nouns | `dog city` | 5.8 / 12 | 48% |
| Technical Terms | `neural FFN` | 9.3 / 12 | 78% |
| Rare Words | `palimpsest` | 11.7 / 12 | 98% |

---

## 🚀 Live Demo

Try TMT interactively, no install needed:

👉 **[huggingface.co/spaces/vigneshwar234/TemporalMesh-Transformer-Demo](https://huggingface.co/spaces/vigneshwar234/TemporalMesh-Transformer-Demo)**

Visualise exit gates, dynamic attention graphs, and per-token compute depth on any sentence you type.
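
The Compute Scaling table in the Performance section can also be checked offline. The sketch below is a minimal NumPy illustration, not the repository's `MeshBuilder`: `knn_edges` and `attention_ops` are hypothetical helper names and the token vectors are random. It builds a top-k cosine-similarity graph and counts query-key pairs per layer. It counts attention pairs only; building the graph from a dense similarity matrix, as done here, is itself quadratic in S.

```python
import numpy as np

def knn_edges(x: np.ndarray, k: int) -> np.ndarray:
    """Return an (S, k) array: row i holds the k most cosine-similar tokens j != i."""
    unit = x / np.linalg.norm(x, axis=1, keepdims=True)
    sim = unit @ unit.T                 # dense (S, S) cosine similarity
    np.fill_diagonal(sim, -np.inf)      # exclude self-edges (j != i)
    # argpartition selects the top-k per row without a full sort
    return np.argpartition(-sim, k - 1, axis=1)[:, :k]

def attention_ops(seq_len: int, k: int = 8) -> tuple[int, int, int]:
    """Return (dense pairs, mesh pairs, reduction factor) for one attention layer."""
    dense = seq_len * seq_len
    mesh = seq_len * k
    return dense, mesh, dense // mesh

rng = np.random.default_rng(0)
tokens = rng.standard_normal((128, 64))   # S=128 toy token vectors, d=64
neighbours = knn_edges(tokens, k=8)
print(neighbours.shape)                   # (128, 8), i.e. S*k = 1,024 edges

for s in (128, 256, 512, 1024, 2048):
    dense, mesh, factor = attention_ops(s)
    print(f"S={s:5d}  dense={dense:9,d}  mesh={mesh:6,d}  reduction={factor}x")
```

With k fixed, the reduction factor is S/k, which is why it doubles each time the sequence length doubles.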

---

## Quick Start

### Installation

```bash
git clone https://github.com/vignesh2027/TemporalMesh-Transformer.git
cd TemporalMesh-Transformer
python3 -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt
```

### Forward Pass

```python
import torch
from tmt.model.config import TMTConfig
from tmt.model.model import TMTModel

cfg = TMTConfig(
    vocab_size=50258,
    d_model=512,
    n_heads=8,
    n_layers=12,
    graph_k=8,
    exit_threshold=0.85,
    memory_anchors=16,
    max_seq_len=256,
)
model = TMTModel(cfg)
model.eval()

input_ids = torch.randint(0, 50258, (1, 64))  # batch=1, seq_len=64

with torch.no_grad():
    output = model(input_ids)

print("Logits shape:   ", output.logits.shape)            # (1, 64, 50258)
print("Exit masks:     ", len(output.exit_masks))         # 12, one per layer
print("Exits per layer:", [m.sum().item() for m in output.exit_masks])
print("Memory state:   ", output.memory_state.shape)      # (16, 512)
print("Graph edges:    ", output.graph_edges[0].shape)    # (2, E)
```

### Inspect Exit Behaviour

```python
# Which tokens exited at which layer?
for layer_idx, mask in enumerate(output.exit_masks):
    n_exited = mask.sum().item()
    print(f"Layer {layer_idx+1:2d}: {n_exited} tokens exited")

# Confidence scores per token
for layer_idx, conf in enumerate(output.confidences):
    print(f"Layer {layer_idx+1:2d}: avg confidence = {conf.mean().item():.3f}")
```

### Training (Quick CPU Run)

```python
from tmt.model.config import TMTConfig
from tmt.training.trainer import TMTTrainer, TrainConfig
from tmt.data.dataset import load_text_dataset

cfg = TMTConfig(vocab_size=50258, d_model=256, n_heads=4, n_layers=4,
                graph_k=4, ffn_stream_dim=128, memory_anchors=8, max_seq_len=128)

loaders = load_text_dataset('wikitext-2', seq_len=128, batch_size=8)

trainer = TMTTrainer(
    cfg,
    TrainConfig(total_steps=500, warmup_steps=50, use_wandb=False, eval_every=100),
    loaders['train'],
    loaders.get('validation'),
)
trainer.train()
```

### Full GPU Training (Publication Quality)

```python
from tmt.model.config import TMTConfig
from tmt.training.trainer import TrainConfig

cfg = TMTConfig(
    vocab_size=50258,
    d_model=512,
    n_heads=8,
    n_layers=12,
    graph_k=8,
    decay_rate=0.1,
    exit_threshold=0.85,
    dual_stream=True,
    memory_anchors=16,
    ffn_stream_dim=256,
    max_seq_len=256,
)
train_cfg = TrainConfig(
    total_steps=10_000,
    warmup_steps=500,
    lr=3e-4,
    batch_size=16,
    eval_every=500,
    save_every=1000,
    use_wandb=True,
)
```

### Checkpoint Loading

```python
import torch
from tmt.model.config import TMTConfig
from tmt.model.model import TMTModel

cfg = TMTConfig(...)
# must match training config
model = TMTModel(cfg)
ckpt = torch.load('checkpoints/ckpt_step10000.pt', map_location='cpu')
model.load_state_dict(ckpt['model_state'])
model.eval()
```

---

## Configuration Reference

```python
TMTConfig(
    vocab_size     = 32000,  # vocabulary size
    d_model        = 512,    # hidden dimension
    n_heads        = 8,      # attention heads
    n_layers       = 12,     # transformer layers
    max_seq_len    = 1024,   # max sequence length

    # ── Mesh Attention ──────────────────────────────
    graph_k        = 8,      # kNN neighbourhood size (4–16)

    # ── Temporal Decay ──────────────────────────────
    decay_rate     = 0.1,    # base decay rate (0.05–0.4)

    # ── Adaptive Depth ──────────────────────────────
    exit_threshold = 0.85,   # token exit confidence (0.70–0.95)

    # ── Dual-Stream FFN ─────────────────────────────
    dual_stream    = True,   # enable parallel syntax+semantic streams
    ffn_stream_dim = 256,    # width per stream (total=512 for d_model=512)

    # ── Memory Anchors ──────────────────────────────
    memory_anchors = 16,     # EMA anchor count (8–32)

    dropout        = 0.1,
)
```

### Model Scales

| Variant | d_model | Layers | Heads | k | Params | VRAM |
|:---|:---:|:---:|:---:|:---:|:---:|:---:|
| TMT-Small | 256 | 4 | 4 | 4 | ~16M | ~2 GB |
| TMT-Medium | 512 | 6 | 6 | 6 | ~60M | ~6 GB |
| **TMT-Base** | **512** | **12** | **8** | **8** | **~120M** | **~12 GB** |
| TMT-Large | 1024 | 24 | 16 | 16 | ~350M | ~40 GB |

---

## TMTOutput Fields

Every forward pass returns a rich structured output:

| Field | Shape | Description |
|:---|:---|:---|
| `logits` | `(B, S, V)` | Next-token logits; use for loss/generation |
| `exit_masks` | `list[(B, S) bool]` | True where token exited at that layer |
| `confidences` | `list[(B, S) float]` | Gate confidence per token per layer |
| `graph_edges` | `(edge_index, weights)` | Live sparse graph from final layer |
| `memory_state` | `(M, D)` | Final EMA memory anchor state |
| `decay_scalars` | `(B, S, D)` | Temporal decay weights applied |

---

## Test Dataset

The companion dataset
**[vigneshwar234/TMT-Benchmarks](https://huggingface.co/datasets/vigneshwar234/TMT-Benchmarks)** contains:

- `complexity_test`: 1,000 sequences annotated by token complexity category
- `length_scaling`: sequences from S=32 to S=1024 for throughput benchmarking
- `ablation_reference`: canonical perplexity reference values for all 8 ablation configs
- `exit_gate_reference`: expected exit layer distributions per token type
- `edge_case_inputs`: boundary inputs for robustness testing (empty, max-length, all-same)

```python
from datasets import load_dataset

ds = load_dataset("vigneshwar234/TMT-Benchmarks", "complexity_test")
print(ds['test'][0])
# {'input_ids': [...], 'token_types': [...], 'expected_exit_layers': [...], 'text': '...'}
```

---

## Figures

| Figure | Description |
|:---|:---|
| [`fig_architecture.png`](paper/fig_architecture.png) | Full TMT architecture block diagram |
| [`fig_graph.png`](paper/fig_graph.png) | Dynamic graph evolution across 3 layers |
| [`fig_decay.png`](paper/fig_decay.png) | Temporal decay function curves + RoPE comparison |
| [`fig_exit.png`](paper/fig_exit.png) | Exit gate distribution by layer and token type |
| [`fig_training.png`](paper/fig_training.png) | Training loss + validation perplexity curves |
| [`fig_ablation.png`](paper/fig_ablation.png) | Ablation bar chart + Pareto frontier |
| [`fig_complexity.png`](paper/fig_complexity.png) | O(S²) vs O(S·k) operation count + memory |

---

## Citation

```bibtex
@misc{tmt2026,
  title     = {TemporalMesh Transformer: Dynamic Graph Attention with Temporal Decay and Adaptive Depth Routing},
  author    = {Vignesh},
  year      = {2026},
  doi       = {10.5281/zenodo.20287390},
  url       = {https://doi.org/10.5281/zenodo.20287390},
  publisher = {Zenodo},
  note      = {Preprint. Novel architecture combining mesh attention, temporal decay encoding, and per-token adaptive depth routing.
               Code: https://github.com/vignesh2027/TemporalMesh-Transformer}
}
```

---

## Related Work

| Paper | Relation to TMT |
|:---|:---|
| Vaswani et al. 2017, *Attention Is All You Need* | Base architecture |
| Su et al. 2021, *RoFormer (RoPE)* | TMT extends RoPE with learned decay |
| Elbayad et al. 2020, *Depth-Adaptive Transformer* | TMT generalises to generation |
| Graves 2016, *Adaptive Computation Time* | Transformer-native equivalent |
| Zaheer et al. 2020, *BigBird* | Fixed sparse patterns vs TMT's dynamic graph |
| Shi et al. 2021, *Graph Transformer* | Static graph vs TMT's rebuilt-per-layer graph |

---

## License

MIT: free to use, modify, and build upon. Citation is appreciated for published work.