vigneshwar234 committed on
Commit
d4f884d
·
verified ·
1 Parent(s): 2531c36

Expand model card — full docs, tables, usage, dataset link

Files changed (1)
  1. README.md +364 -34
README.md CHANGED
@@ -15,42 +15,201 @@ tags:
15
  - efficient-transformer
16
  - novel-architecture
17
  - causal-lm
 
 
18
  library_name: pytorch
19
  pipeline_tag: text-generation
20
  ---
21
 
 
 
22
  # TemporalMesh Transformer (TMT)
23
 
24
- **The first architecture to simultaneously fuse dynamic graph topology, token-level adaptive compute, and temporal semantic decay in a single unified model.**
25
 
26
- ## Model Description
27
 
28
- TMT breaks the three assumptions every transformer makes:
29
 
30
- | Assumption | TMT Solution |
31
- |---|---|
32
- | All tokens equally important | Temporal Decay — irrelevant tokens fade |
33
- | Flat fully-connected attention | Mesh Attention dynamic kNN graph, rebuilt each layer |
34
- | Every token uses all N layers | Adaptive Depth Routing — easy tokens exit early |
35
 
36
- ## Architecture
 
 
 
 
 
 
 
 
37
 
38
- - **Mesh Attention**: O(S·k) dynamic graph, k=8 neighbours per token, graph rebuilt every layer
39
- - **Temporal Decay Encoding**: Learned per-head multiplicative decay on attention weights
40
- - **Adaptive Depth Routing**: Per-token exit gate, ~50% compute reduction
41
- - **Dual-Stream FFN**: Parallel syntax + semantic streams with learned gated fusion
42
- - **EMA Memory Anchors**: 16 persistent KV vectors updated by exponential moving average
43
 
44
- ## Performance (WikiText-2)
 
 
 
 
 
 
45
 
46
- | Model | Parameters | Val. Perplexity ↓ | Avg Compute/Token |
47
- |---|---|---|---|
48
- | Vanilla Transformer | ~120M | 42.1 | 100% |
49
- | Full TMT | ~120M | **29.4** | **~48%** |
50
 
51
- ## Usage
52
 
53
  ```python
 
54
  from tmt.model.config import TMTConfig
55
  from tmt.model.model import TMTModel
56
 
@@ -62,34 +221,205 @@ cfg = TMTConfig(
62
  graph_k=8,
63
  exit_threshold=0.85,
64
  memory_anchors=16,
 
65
  )
66
 
67
  model = TMTModel(cfg)
68
- output = model(input_ids)
 
 
69
 
70
- # Rich structured output
71
- output.logits # (B, S, V) — use for generation
72
- output.exit_masks # which tokens exited at each layer
73
- output.confidences # gate confidence per token per layer
74
- output.graph_edges # the live dynamic graph
75
- output.memory_state # 16 EMA anchor states
 
 
76
  ```
77
 
78
- ## Paper
 
 
 
 
 
 
79
 
80
- Full 20-page publication: [`paper/TemporalMesh_Transformer_2026.pdf`](paper/TemporalMesh_Transformer_2026.pdf)
81
 
82
  ## Citation
83
 
84
  ```bibtex
85
  @misc{tmt2026,
86
- title = {TemporalMesh Transformer: Dynamic Graph Attention with Temporal Decay and Adaptive Depth Routing},
87
- author = {Vignesh},
88
- year = {2026},
89
- url = {https://github.com/vignesh2027/TemporalMesh-Transformer}
 
 
 
90
  }
91
  ```
92
 
93
  ## License
94
 
95
- MIT
 
15
  - efficient-transformer
16
  - novel-architecture
17
  - causal-lm
18
+ - research
19
+ - preprint
20
  library_name: pytorch
21
  pipeline_tag: text-generation
22
+ datasets:
23
+ - vigneshwar234/TMT-Benchmarks
24
+ metrics:
25
+ - perplexity
26
+ model-index:
27
+ - name: TemporalMesh Transformer (TMT-Base)
28
+ results:
29
+ - task:
30
+ type: text-generation
31
+ name: Language Modelling
32
+ dataset:
33
+ type: wikitext
34
+ name: WikiText-2
35
+ metrics:
36
+ - type: perplexity
37
+ value: 29.4
38
+ name: Validation Perplexity
39
+ verified: false
40
  ---
41
 
42
+ <div align="center">
43
+
44
  # TemporalMesh Transformer (TMT)
45
 
46
+ ### *Dynamic Graph Attention · Temporal Semantic Decay · Per-Token Adaptive Depth Routing*
47
+
48
+ [![GitHub](https://img.shields.io/badge/GitHub-vignesh2027%2FTemporalMesh--Transformer-181717?style=flat-square&logo=github)](https://github.com/vignesh2027/TemporalMesh-Transformer)
49
+ [![Paper PDF](https://img.shields.io/badge/Paper-PDF%2020%20pages-red?style=flat-square&logo=adobeacrobatreader)](https://huggingface.co/vigneshwar234/TemporalMesh-Transformer/resolve/main/paper/TemporalMesh_Transformer_2026.pdf)
50
+ [![Dataset](https://img.shields.io/badge/Dataset-TMT--Benchmarks-FFD21E?style=flat-square&logo=huggingface)](https://huggingface.co/datasets/vigneshwar234/TMT-Benchmarks)
51
+ [![License: MIT](https://img.shields.io/badge/License-MIT-green?style=flat-square)](https://github.com/vignesh2027/TemporalMesh-Transformer/blob/main/LICENSE)
52
+
53
+ **Val. Perplexity: 29.4** · **~50% compute reduction** · **~120M parameters** · **WikiText-2**
54
+
55
+ </div>
56
+
57
+ ---
58
+
59
+ ## Overview
60
+
61
+ The **TemporalMesh Transformer (TMT)** is a novel autoregressive language model architecture that breaks the three fundamental assumptions shared by every standard transformer:
62
+
63
+ | Assumption Every Transformer Makes | How TMT Breaks It |
64
+ |:---|:---|
65
+ | Every token attends to every other — O(S²) cost | **Mesh Attention**: Dynamic kNN graph rebuilt each layer — O(S·k) |
66
+ | Attention topology is flat and fixed | **Mesh Graph**: Topology changes every forward pass from token similarity |
67
+ | Every token uses identical compute (all N layers) | **Adaptive Depth**: Easy tokens exit after 2 layers; hard tokens use all 12 |
68
+
69
+ No single prior paper combines all three. That unification is the TMT research contribution.
70
+
71
+ ---
72
+
73
+ ## Architecture at a Glance
74
+
75
+ ```
76
+ Input Tokens (B, S)
77
+
78
+
79
+ TokenEmbedding ← Standard learned embedding × √d_model
80
+
81
+
82
+ TemporalPositionEncoder ← RoPE + learned decay scalars per token
83
+
84
+
85
+ MeshBuilder ← Cosine similarity → top-k graph O(S·k)
86
+
87
+ ▼ [× 12 layers]
88
+ ┌─────────────────────────────────────────────────────┐
89
+ │ MeshAttention ← Attention over graph edges only │
90
+ │ DualStreamFFN ← Syntax stream + Semantic stream │
91
+ │ ExitGate ← Freeze token if confidence>0.85 │
92
+ │ MemoryAnchorCross ← Cross-attend 16 EMA anchors │
93
+ │ → Rebuild graph from updated representations │
94
+ └─────────────────────────────────────────────────────┘
95
+
96
+
97
+ LayerNorm + OutputProjection (weight-tied to embedding)
98
+
99
+
100
+ TMTOutput: logits · exit_masks · confidences · graph_edges · memory_state
101
+ ```
102
+
103
+ ---
104
+
105
+ ## The Five Innovations
106
+
107
+ ### 1. Mesh Attention — Dynamic kNN Graph
108
+
109
+ At every layer, tokens are nodes. Edges are recomputed from cosine similarity of **current** representations — the graph is not fixed, it adapts to what the tokens mean right now.
110
+
111
+ ```
112
+ sim(i,j) = Xᵢ · Xⱼ / (‖Xᵢ‖ · ‖Xⱼ‖)
113
+ N_k(i) = top-k { j ≠ i : sim(i,j) }
114
+ Attention flows only along N_k edges → O(S·k) vs O(S²)
115
+ ```
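The neighbour selection above can be sketched in a few lines. This is a minimal illustration in plain Python, not the repository's `MeshBuilder`; the names `cosine_similarity` and `build_knn_graph` are hypothetical and chosen only for this example.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def build_knn_graph(tokens, k):
    """Return {i: [j, ...]} holding the k most similar tokens j != i for each i."""
    neighbours = {}
    for i, xi in enumerate(tokens):
        scored = [(cosine_similarity(xi, xj), j)
                  for j, xj in enumerate(tokens) if j != i]
        # Highest similarity first; ties are broken by the lower index.
        scored.sort(key=lambda pair: (-pair[0], pair[1]))
        neighbours[i] = [j for _, j in scored[:k]]
    return neighbours

# Four toy token vectors forming two obvious clusters.
tokens = [
    [1.0, 0.0],
    [0.9, 0.1],
    [0.0, 1.0],
    [0.1, 0.9],
]
graph = build_knn_graph(tokens, k=1)
edge_count = sum(len(js) for js in graph.values())
print(graph)
print(edge_count)
```

With `k=1` each token links to its cluster partner, so the graph has S·k = 4 edges. Note that this toy version scores every pair to pick the neighbours; only the attention step that follows is limited to the selected edges.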
116
+
117
+ At S=1024, k=8: **128× fewer attention operations** than standard transformers.
118
+
119
+ ### 2. Temporal Decay Encoding
120
+
121
+ A learned per-head decay factor δ_h is multiplied into the post-softmax attention weights α_ij, so semantically distant tokens are attenuated by a learned distance rather than by position alone.
122
+
123
+ ```
124
+ δ_h(i,j) = σ( W_decay_h · |t_i − t_j| )
125
+ ã_ij = α_ij · δ_h(i,j)
126
+ ```
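As a rough illustration of the two formulas above, the sketch below applies the decay to a single query row. It is plain Python under stated assumptions: `w_decay` stands in for one head's learned `W_decay_h`, its value of -1.0 is made up, and `decayed_attention` is a hypothetical helper name.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def decayed_attention(logits, query_time, key_times, w_decay):
    """Compute alpha_ij * sigmoid(w_decay * |t_i - t_j|) for one query row."""
    alpha = softmax(logits)
    delta = [sigmoid(w_decay * abs(query_time - t)) for t in key_times]
    return [a * d for a, d in zip(alpha, delta)]

# Assumed value: a negative weight makes more distant keys fade more.
weights = decayed_attention(
    logits=[0.0, 0.0, 0.0], query_time=3, key_times=[3, 2, 0], w_decay=-1.0)
print(weights)
```

Because the factor is applied after the softmax, the decayed row no longer sums to 1, which is one practical difference from adding a bias to the logits.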
127
+
128
+ Unlike ALiBi (additive to logits, fixed schedule), TMT decay is **multiplicative, post-softmax, and fully learned**.
129
 
130
+ ### 3. Adaptive Depth Routing — Per-Token Early Exit
131
 
132
+ Each token gets a confidence score after each layer. Confident tokens freeze and skip remaining layers.
133
 
134
+ ```python
135
+ confidence = sigmoid(W_gate · x_token) # ∈ (0,1)
136
+ if confidence > 0.85:
137
+     token frozen → no more layers  # ~50% of tokens exit by layer 5
138
+ ```
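The gating rule above is pseudocode. A runnable toy version might look like the following, assuming the gate is a dot product followed by a sigmoid and that each layer can be treated as a function on one token vector. `run_with_early_exit`, `toy_layer`, and the gate weights are hypothetical and not taken from the repository.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def run_with_early_exit(tokens, layers, gate_weights, threshold=0.85):
    """Run each token through the layers until its gate confidence exceeds threshold.

    Returns the final token vectors and the number of layers each token used.
    """
    states = [list(t) for t in tokens]
    frozen = [False] * len(states)
    layers_used = [0] * len(states)
    for layer in layers:
        for i in range(len(states)):
            if frozen[i]:
                continue  # frozen tokens skip every remaining layer
            states[i] = layer(states[i])
            layers_used[i] += 1
            score = sum(w * v for w, v in zip(gate_weights, states[i]))
            if sigmoid(score) > threshold:
                frozen[i] = True
    return states, layers_used

def toy_layer(x):
    """Hypothetical stand-in for a transformer layer: adds 1 to every feature."""
    return [v + 1.0 for v in x]

tokens = [[1.0, 1.0], [-4.0, -4.0]]
states, layers_used = run_with_early_exit(
    tokens, [toy_layer] * 12, gate_weights=[0.5, 0.5])
print(layers_used)
print(sum(layers_used) / len(layers_used))
```

Here the first token clears the 0.85 threshold after one layer and the second after six, so the two tokens use 3.5 of the 12 layers on average.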
139
+
140
+ **Result**: ~50% average compute reduction. Punctuation exits at layer 2; rare technical terms use all 12.
141
+
142
+ ### 4. Dual-Stream Feed-Forward Network
143
+
144
+ ```
145
+ h_syntax = GeLU(W_syn2 · GeLU(W_syn1 · x)) ← structural features
146
+ h_semantic = GeLU(W_sem2 · GeLU(W_sem1 · x)) ← meaning features
147
+ gate = σ(W_gate_ffn · x)
148
+ output = gate ⊙ h_syntax + (1−gate) ⊙ h_semantic
149
+ ```
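The gated fusion above can be written out directly. The sketch below uses plain Python lists for the weight matrices and the exact (erf-based) GELU; the helper names `matvec`, `stream`, and `dual_stream_ffn` are illustrative, and biases are omitted because the formulas above do not show any.

```python
import math

def gelu(z):
    """Exact GELU using the Gaussian error function."""
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def matvec(matrix, vec):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vec)) for row in matrix]

def stream(w1, w2, x):
    """One stream: GeLU(W2 . GeLU(W1 . x))."""
    hidden = [gelu(h) for h in matvec(w1, x)]
    return [gelu(o) for o in matvec(w2, hidden)]

def dual_stream_ffn(x, w_syn1, w_syn2, w_sem1, w_sem2, w_gate):
    """Blend the syntax and semantic streams with an elementwise sigmoid gate."""
    h_syntax = stream(w_syn1, w_syn2, x)
    h_semantic = stream(w_sem1, w_sem2, x)
    gate = [sigmoid(g) for g in matvec(w_gate, x)]
    return [g * a + (1.0 - g) * b for g, a, b in zip(gate, h_syntax, h_semantic)]

identity = [[1.0, 0.0], [0.0, 1.0]]
zeros = [[0.0, 0.0], [0.0, 0.0]]
x = [1.0, -1.0]
# With zero gate weights the gate is 0.5 everywhere: an equal blend of both streams.
out = dual_stream_ffn(x, identity, identity, identity, identity, zeros)
print(out)
```

Since the gate is computed per feature, the model can lean on the syntax stream for some dimensions and the semantic stream for others within the same token.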
150
+
151
+ ### 5. EMA Memory Anchors
152
+
153
+ Sixteen persistent key-value vectors are updated by an exponential moving average (EMA) during training. Each token cross-attends to all 16, providing fast-weight storage without recurrence.
154
+
155
+ ```
156
+ MemAttn(x) = softmax(x·W_Q · K_mem^T / √d) · V_mem
157
+ k_m ← 0.99 · k_m + 0.01 · mean(attending tokens)
158
+ ```
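The anchor update rule above is a standard exponential moving average. A minimal sketch, assuming the set of attending tokens is passed in explicitly (how that set is chosen is not specified here) and that an anchor with no attending tokens is left unchanged; `ema_update` is a hypothetical name:

```python
def ema_update(anchor, attending_tokens, momentum=0.99):
    """k_m <- momentum * k_m + (1 - momentum) * mean(attending tokens)."""
    if not attending_tokens:
        return list(anchor)  # assumption: no attending tokens means no update
    count = len(attending_tokens)
    mean = [sum(tok[d] for tok in attending_tokens) / count
            for d in range(len(anchor))]
    return [momentum * a + (1.0 - momentum) * m for a, m in zip(anchor, mean)]

anchor = [0.0, 0.0]
batch = [[1.0, 2.0], [3.0, 4.0]]  # mean is [2.0, 3.0]
for _ in range(100):
    anchor = ema_update(anchor, batch)
print(anchor)
```

With a momentum of 0.99 the anchor moves slowly: after 100 identical updates it has covered about 63% of the distance to the batch mean.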
159
+
160
+ ---
161
+
162
+ ## Performance
163
+
164
+ ### WikiText-2 Benchmark (all models ~120M params, 10k steps)
165
 
166
+ | Model | Val PPL ↓ | Avg Layers/Token | Relative Compute |
167
+ |:---|:---:|:---:|:---:|
168
+ | Vanilla Transformer | 42.1 | 12.0 | 100% |
169
+ | + Mesh Attention only | 37.8 | 12.0 | 62% |
170
+ | + Temporal Decay only | 40.3 | 12.0 | 98% |
171
+ | + Adaptive Depth only | 39.6 | 5.8 | 51% |
172
+ | Mesh + Decay | 34.2 | 12.0 | 61% |
173
+ | Mesh + Exit | 35.1 | 5.7 | 50% |
174
+ | **Full TMT (all 3)** | **29.4** | **5.5** | **48%** |
175
 
176
+ ### Compute Scaling
 
 
 
 
177
 
178
+ | Sequence Length | Standard Attn Ops | TMT Mesh Ops (k=8) | Reduction |
179
+ |:---:|:---:|:---:|:---:|
180
+ | 128 | 16,384 | 1,024 | 16× |
181
+ | 256 | 65,536 | 2,048 | 32× |
182
+ | 512 | 262,144 | 4,096 | 64× |
183
+ | 1024 | 1,048,576 | 8,192 | **128×** |
184
+ | 2048 | 4,194,304 | 16,384 | **256×** |
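
The figures in this table follow from counting S² pairs for dense attention and S·k edges for the mesh, so the reduction is simply S/k. A short check, with `attention_ops` as an illustrative helper name:

```python
def attention_ops(seq_len, k=8):
    """Return (dense_ops, mesh_ops, reduction) for one attention layer."""
    dense = seq_len * seq_len  # every token attends to every token
    mesh = seq_len * k         # every token attends to k neighbours
    return dense, mesh, dense // mesh

for s in (128, 256, 512, 1024, 2048):
    dense, mesh, reduction = attention_ops(s)
    print(f"{s:>5} {dense:>10,} {mesh:>7,} {reduction}x")
```

These counts cover attention-score pairs only; they do not include the cost of building the graph or of the feed-forward layers.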
185
 
186
+ ### Exit Gate Distribution (TMT-Base, step 10k)
 
 
 
187
 
188
+ | Token Type | Example | Avg Exit Layer | Compute Used |
189
+ |:---|:---|:---:|:---:|
190
+ | Punctuation | `. , ! ?` | 2.1 / 12 | 17% |
191
+ | Articles/Determiners | `a the an` | 3.4 / 12 | 28% |
192
+ | Common Nouns | `dog city` | 5.8 / 12 | 48% |
193
+ | Technical Terms | `neural FFN` | 9.3 / 12 | 78% |
194
+ | Rare Words | `palimpsest` | 11.7 / 12 | 98% |
195
+
196
+ ---
197
+
198
+ ## Quick Start
199
+
200
+ ### Installation
201
+
202
+ ```bash
203
+ git clone https://github.com/vignesh2027/TemporalMesh-Transformer.git
204
+ cd TemporalMesh-Transformer
205
+ python3 -m venv .venv && source .venv/bin/activate
206
+ pip install -r requirements.txt
207
+ ```
208
+
209
+ ### Forward Pass
210
 
211
  ```python
212
+ import torch
213
  from tmt.model.config import TMTConfig
214
  from tmt.model.model import TMTModel
215
 
 
221
  graph_k=8,
222
  exit_threshold=0.85,
223
  memory_anchors=16,
224
+ max_seq_len=256,
225
  )
226
 
227
  model = TMTModel(cfg)
228
+ model.eval()
229
+
230
+ input_ids = torch.randint(0, 50258, (1, 64)) # batch=1, seq_len=64
231
 
232
+ with torch.no_grad():
233
+     output = model(input_ids)
234
+
235
+ print("Logits shape: ", output.logits.shape) # (1, 64, 50258)
236
+ print("Exit masks: ", len(output.exit_masks)) # 12, one per layer
237
+ print("Tokens per layer:", [m.sum().item() for m in output.exit_masks])
238
+ print("Memory state: ", output.memory_state.shape) # (16, 512)
239
+ print("Graph edges: ", output.graph_edges[0].shape) # (2, E)
240
  ```
241
 
242
+ ### Inspect Exit Behaviour
243
+
244
+ ```python
245
+ # Which tokens exited at which layer?
246
+ for layer_idx, mask in enumerate(output.exit_masks):
247
+     n_exited = mask.sum().item()
248
+     print(f"Layer {layer_idx+1:2d}: {n_exited} tokens exited")
249
 
250
+ # Confidence scores per token
251
+ for layer_idx, conf in enumerate(output.confidences):
252
+     print(f"Layer {layer_idx+1:2d}: avg confidence = {conf.mean():.3f}")
253
+ ```
254
+
255
+ ### Training (Quick CPU Run)
256
+
257
+ ```python
258
+ from tmt.model.config import TMTConfig
259
+ from tmt.training.trainer import TMTTrainer, TrainConfig
260
+ from tmt.data.dataset import load_text_dataset
261
+
262
+ cfg = TMTConfig(vocab_size=50258, d_model=256, n_heads=4, n_layers=4,
263
+ graph_k=4, ffn_stream_dim=128, memory_anchors=8, max_seq_len=128)
264
+
265
+ loaders = load_text_dataset('wikitext-2', seq_len=128, batch_size=8)
266
+
267
+ trainer = TMTTrainer(
268
+ cfg,
269
+ TrainConfig(total_steps=500, warmup_steps=50, use_wandb=False, eval_every=100),
270
+ loaders['train'], loaders.get('validation')
271
+ )
272
+ trainer.train()
273
+ ```
274
+
275
+ ### Full GPU Training (Publication Quality)
276
+
277
+ ```python
278
+ cfg = TMTConfig(
279
+ vocab_size=50258, d_model=512, n_heads=8, n_layers=12,
280
+ graph_k=8, decay_rate=0.1, exit_threshold=0.85,
281
+ dual_stream=True, memory_anchors=16, ffn_stream_dim=256, max_seq_len=256,
282
+ )
283
+ train_cfg = TrainConfig(
284
+ total_steps=10_000, warmup_steps=500, lr=3e-4, batch_size=16,
285
+ eval_every=500, save_every=1000, use_wandb=True,
286
+ )
287
+ ```
288
+
289
+ ### Checkpoint Loading
290
+
291
+ ```python
292
+ import torch
293
+ from tmt.model.config import TMTConfig
294
+ from tmt.model.model import TMTModel
295
+
296
+ cfg = TMTConfig(...) # must match training config
297
+ model = TMTModel(cfg)
298
+ ckpt = torch.load('checkpoints/ckpt_step10000.pt', map_location='cpu')
299
+ model.load_state_dict(ckpt['model_state'])
300
+ model.eval()
301
+ ```
302
+
303
+ ---
304
+
305
+ ## Configuration Reference
306
+
307
+ ```python
308
+ TMTConfig(
309
+ vocab_size = 32000, # vocabulary size
310
+ d_model = 512, # hidden dimension
311
+ n_heads = 8, # attention heads
312
+ n_layers = 12, # transformer layers
313
+ max_seq_len = 1024, # max sequence length
314
+
315
+ # ── Mesh Attention ──────────────────────────────
316
+ graph_k = 8, # kNN neighbourhood size (4–16)
317
+
318
+ # ── Temporal Decay ──────────────────────────────
319
+ decay_rate = 0.1, # base decay rate (0.05–0.4)
320
+
321
+ # ── Adaptive Depth ──────────────────────────────
322
+ exit_threshold = 0.85, # token exit confidence (0.70–0.95)
323
+
324
+ # ── Dual-Stream FFN ─────────────────────────────
325
+ dual_stream = True, # enable parallel syntax+semantic streams
326
+ ffn_stream_dim = 256, # width per stream (total=512 for d_model=512)
327
+
328
+ # ── Memory Anchors ──────────────────────────────
329
+ memory_anchors = 16, # EMA anchor count (8–32)
330
+
331
+ dropout = 0.1,
332
+ )
333
+ ```
334
+
335
+ ### Model Scales
336
+
337
+ | Variant | d_model | Layers | Heads | k | Params | VRAM |
338
+ |:---|:---:|:---:|:---:|:---:|:---:|:---:|
339
+ | TMT-Small | 256 | 4 | 4 | 4 | ~16M | ~2 GB |
340
+ | TMT-Medium | 512 | 6 | 6 | 6 | ~60M | ~6 GB |
341
+ | **TMT-Base** | **512** | **12** | **8** | **8** | **~120M** | **~12 GB** |
342
+ | TMT-Large | 1024 | 24 | 16 | 16 | ~350M | ~40 GB |
343
+
344
+ ---
345
+
346
+ ## TMTOutput Fields
347
+
348
+ Every forward pass returns a rich structured output:
349
+
350
+ | Field | Shape | Description |
351
+ |:---|:---|:---|
352
+ | `logits` | `(B, S, V)` | Next-token logits — use for loss/generation |
353
+ | `exit_masks` | `list[(B, S) bool]` | True where token exited at that layer |
354
+ | `confidences` | `list[(B, S) float]` | Gate confidence per token per layer |
355
+ | `graph_edges` | `(edge_index, weights)` | Live sparse graph from final layer |
356
+ | `memory_state` | `(M, D)` | Final EMA memory anchor state |
357
+ | `decay_scalars` | `(B, S, D)` | Temporal decay weights applied |
358
+
359
+ ---
360
+
361
+ ## Test Dataset
362
+
363
+ The companion dataset **[vigneshwar234/TMT-Benchmarks](https://huggingface.co/datasets/vigneshwar234/TMT-Benchmarks)** contains:
364
+
365
+ - `complexity_test` — 1,000 sequences annotated by token complexity category
366
+ - `length_scaling` — sequences from S=32 to S=1024 for throughput benchmarking
367
+ - `ablation_reference` — canonical perplexity reference values for all 8 ablation configs
368
+ - `exit_gate_reference` — expected exit layer distributions per token type
369
+ - `edge_case_inputs` — boundary inputs for robustness testing (empty, max-length, all-same)
370
+
371
+ ```python
372
+ from datasets import load_dataset
373
+ ds = load_dataset("vigneshwar234/TMT-Benchmarks", "complexity_test")
374
+ print(ds['test'][0])
375
+ # {'input_ids': [...], 'token_types': [...], 'expected_exit_layers': [...], 'text': '...'}
376
+ ```
377
+
378
+ ---
379
+
380
+ ## Figures
381
+
382
+ | Figure | Description |
383
+ |:---|:---|
384
+ | [`fig_architecture.png`](paper/fig_architecture.png) | Full TMT architecture block diagram |
385
+ | [`fig_graph.png`](paper/fig_graph.png) | Dynamic graph evolution across 3 layers |
386
+ | [`fig_decay.png`](paper/fig_decay.png) | Temporal decay function curves + RoPE comparison |
387
+ | [`fig_exit.png`](paper/fig_exit.png) | Exit gate distribution by layer and token type |
388
+ | [`fig_training.png`](paper/fig_training.png) | Training loss + validation perplexity curves |
389
+ | [`fig_ablation.png`](paper/fig_ablation.png) | Ablation bar chart + Pareto frontier |
390
+ | [`fig_complexity.png`](paper/fig_complexity.png) | O(S²) vs O(S·k) operation count + memory |
391
+
392
+ ---
393
 
394
  ## Citation
395
 
396
  ```bibtex
397
  @misc{tmt2026,
398
+ title = {TemporalMesh Transformer: Dynamic Graph Attention with
399
+ Temporal Decay and Adaptive Depth Routing},
400
+ author = {Vignesh},
401
+ year = {2026},
402
+ url = {https://huggingface.co/vigneshwar234/TemporalMesh-Transformer},
403
+ note = {Preprint. Novel architecture combining mesh attention, temporal
404
+ decay encoding, and per-token adaptive depth routing.}
405
  }
406
  ```
407
 
408
+ ---
409
+
410
+ ## Related Work
411
+
412
+ | Paper | Relation to TMT |
413
+ |:---|:---|
414
+ | Vaswani et al. 2017 — *Attention Is All You Need* | Base architecture |
415
+ | Su et al. 2021 — *RoFormer (RoPE)* | TMT extends RoPE with learned decay |
416
+ | Elbayad et al. 2020 — *Depth-Adaptive Transformer* | TMT generalises to generation |
417
+ | Graves 2016 — *Adaptive Computation Time* | Transformer-native equivalent |
418
+ | Zaheer et al. 2020 — *BigBird* | Fixed sparse patterns vs TMT's dynamic graph |
419
+ | Shi et al. 2021 — *Graph Transformer* | Static graph vs TMT's rebuilt-per-layer graph |
420
+
421
+ ---
422
+
423
  ## License
424
 
425
+ MIT — free to use, modify, and build upon. Citation appreciated for published work.