vigneshwar234 commited on
Commit
1e98d80
Β·
verified Β·
1 Parent(s): 97f4e30

Add app.py

Browse files
Files changed (1) hide show
  1. app.py +458 -0
app.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TemporalMesh Transformer β€” Interactive Demo Space
3
+ Hugging Face Space: vigneshwar234/TemporalMesh-Transformer-Demo
4
+ """
5
+
6
+ import gradio as gr
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ import matplotlib
11
+ matplotlib.use("Agg")
12
+ import matplotlib.pyplot as plt
13
+ import matplotlib.patches as mpatches
14
+ from io import BytesIO
15
+ from PIL import Image
16
+ import random, math, textwrap
17
+
18
+ # ── Minimal self-contained TMT implementation for the demo ──────────────────
19
+
20
+ class TMTConfig:
21
+ def __init__(self):
22
+ self.vocab_size = 1000
23
+ self.d_model = 128
24
+ self.n_heads = 4
25
+ self.n_layers = 6
26
+ self.max_seq_len = 64
27
+ self.graph_k = 4
28
+ self.exit_threshold = 0.80
29
+ self.memory_anchors = 8
30
+ self.dropout = 0.0
31
+
32
+ class MeshBuilder(torch.nn.Module):
33
+ def __init__(self, k): super().__init__(); self.k = k
34
+ def forward(self, x):
35
+ B, S, D = x.shape
36
+ xn = F.normalize(x, dim=-1)
37
+ sim = torch.bmm(xn, xn.transpose(1,2))
38
+ sim.fill_diagonal_(-1e9)
39
+ topk = sim.topk(min(self.k, S-1), dim=-1)
40
+ return topk.indices, topk.values
41
+
42
+ class MeshAttention(torch.nn.Module):
43
+ def __init__(self, cfg):
44
+ super().__init__()
45
+ self.h = cfg.n_heads
46
+ self.d = cfg.d_model // cfg.n_heads
47
+ self.Wq = torch.nn.Linear(cfg.d_model, cfg.d_model, bias=False)
48
+ self.Wk = torch.nn.Linear(cfg.d_model, cfg.d_model, bias=False)
49
+ self.Wv = torch.nn.Linear(cfg.d_model, cfg.d_model, bias=False)
50
+ self.Wo = torch.nn.Linear(cfg.d_model, cfg.d_model, bias=False)
51
+
52
+ def forward(self, x, edge_idx):
53
+ B, S, D = x.shape
54
+ Q = self.Wq(x).view(B,S,self.h,self.d).transpose(1,2)
55
+ K = self.Wk(x).view(B,S,self.h,self.d).transpose(1,2)
56
+ V = self.Wv(x).view(B,S,self.h,self.d).transpose(1,2)
57
+ attn = torch.matmul(Q, K.transpose(-2,-1)) / math.sqrt(self.d)
58
+ mask = torch.full((B,self.h,S,S), -1e9, device=x.device)
59
+ idx = edge_idx.unsqueeze(1).expand(B,self.h,S,-1)
60
+ src = torch.arange(S,device=x.device).view(1,1,S,1).expand_as(idx)
61
+ mask.scatter_(3, idx, attn.gather(3, idx))
62
+ attn = F.softmax(mask, dim=-1)
63
+ out = torch.matmul(attn, V).transpose(1,2).reshape(B,S,D)
64
+ return self.Wo(out), attn.mean(1)
65
+
66
+ class ExitGate(torch.nn.Module):
67
+ def __init__(self, d): super().__init__(); self.g = torch.nn.Linear(d,1)
68
+ def forward(self, x): return torch.sigmoid(self.g(x)).squeeze(-1)
69
+
70
+ class TMTLayer(torch.nn.Module):
71
+ def __init__(self, cfg):
72
+ super().__init__()
73
+ self.attn = MeshAttention(cfg)
74
+ self.ff = torch.nn.Sequential(
75
+ torch.nn.Linear(cfg.d_model, cfg.d_model*2),
76
+ torch.nn.GELU(),
77
+ torch.nn.Linear(cfg.d_model*2, cfg.d_model),
78
+ )
79
+ self.gate = ExitGate(cfg.d_model)
80
+ self.ln1 = torch.nn.LayerNorm(cfg.d_model)
81
+ self.ln2 = torch.nn.LayerNorm(cfg.d_model)
82
+
83
+ def forward(self, x, edge_idx, frozen):
84
+ a, attn_w = self.attn(self.ln1(x), edge_idx)
85
+ x = x + a
86
+ x = x + self.ff(self.ln2(x))
87
+ conf = self.gate(x)
88
+ return x, conf, attn_w
89
+
90
+ class TMTModel(torch.nn.Module):
91
+ def __init__(self, cfg):
92
+ super().__init__()
93
+ self.cfg = cfg
94
+ self.emb = torch.nn.Embedding(cfg.vocab_size, cfg.d_model)
95
+ self.mesh = MeshBuilder(cfg.graph_k)
96
+ self.layers = torch.nn.ModuleList([TMTLayer(cfg) for _ in range(cfg.n_layers)])
97
+ self.ln = torch.nn.LayerNorm(cfg.d_model)
98
+ self.head = torch.nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
99
+
100
+ def forward(self, ids):
101
+ x = self.emb(ids)
102
+ B, S, D = x.shape
103
+ frozen = torch.zeros(B, S, dtype=torch.bool)
104
+ exits = []
105
+ confs = []
106
+ attns = []
107
+ edge_idx, _ = self.mesh(x)
108
+ for layer in self.layers:
109
+ x_new, conf, attn_w = layer(x, edge_idx, frozen)
110
+ new_exits = (~frozen) & (conf > self.cfg.exit_threshold)
111
+ frozen = frozen | new_exits
112
+ x = torch.where(frozen.unsqueeze(-1), x, x_new)
113
+ exits.append(new_exits.float())
114
+ confs.append(conf)
115
+ attns.append(attn_w)
116
+ edge_idx, _ = self.mesh(x)
117
+ logits = self.head(self.ln(x))
118
+ return logits, exits, confs, attns
119
+
120
+ # Instantiate once at startup
121
+ torch.manual_seed(42)
122
+ CFG = TMTConfig()
123
+ MODEL = TMTModel(CFG)
124
+ MODEL.eval()
125
+
126
+ SAMPLE_SENTENCES = [
127
+ "The neural network learned to represent complex patterns in the data",
128
+ "Attention mechanisms allow transformers to focus on relevant tokens",
129
+ "Dynamic graph topology adapts to the semantic content of the sequence",
130
+ "Machine learning models require large amounts of training data",
131
+ "The quick brown fox jumps over the lazy dog near the river",
132
+ "Adaptive depth routing reduces compute by 50 percent on average",
133
+ "Language models predict the next word given the previous context",
134
+ "Graph neural networks operate over structured relational data",
135
+ ]
136
+
137
+ WORD_TYPES = {
138
+ "the":0,"a":0,"an":0,"of":0,"in":0,"to":0,"and":0,"is":0,"are":0,"by":0,
139
+ "on":0,"at":0,"for":0,"with":0,"this":0,"that":0,"it":0,"its":0,
140
+ "learned":1,"focus":1,"allow":1,"predict":1,"require":1,"adapts":1,
141
+ "reduces":1,"operate":1,"jumps":1,"represent":1,
142
+ "neural":2,"network":2,"attention":2,"transformer":2,"semantic":2,
143
+ "topology":2,"graph":2,"compute":2,"language":2,"model":2,
144
+ "mechanisms":3,"dynamic":3,"adaptive":3,"structured":3,"relational":3,
145
+ "patterns":3,"complex":3,"relevant":3,"previous":3,
146
+ }
147
+ TYPE_COLORS = ["#22c55e","#3b82f6","#f59e0b","#ef4444"]
148
+ TYPE_LABELS = ["Function words","Common verbs","Domain terms","Complex"]
149
+
150
+ def encode(text):
151
+ words = text.lower().split()[:CFG.max_seq_len]
152
+ ids = [hash(w) % (CFG.vocab_size-2) + 1 for w in words]
153
+ return words, torch.tensor([ids])
154
+
155
+ def run_model(text):
156
+ words, ids = encode(text)
157
+ with torch.no_grad():
158
+ logits, exits, confs, attns = MODEL(ids)
159
+ return words, exits, confs, attns
160
+
161
+ # ── FIGURE 1: Exit gate heatmap ─────────────────────────────────────────────
162
+ def plot_exit_heatmap(words, exits, confs):
163
+ S = len(words)
164
+ N = len(exits)
165
+ mat = torch.stack(exits, dim=0).squeeze(1).numpy() # (N, S)
166
+ con = torch.stack(confs, dim=0).squeeze(1).numpy()
167
+
168
+ fig, axes = plt.subplots(1, 2, figsize=(14, max(3, S*0.35+1.5)))
169
+ fig.patch.set_facecolor('#0f172a')
170
+
171
+ # Exit heatmap
172
+ ax = axes[0]
173
+ ax.set_facecolor('#1e293b')
174
+ im = ax.imshow(mat, aspect='auto', cmap='RdYlGn', vmin=0, vmax=1,
175
+ interpolation='nearest')
176
+ ax.set_yticks(range(N)); ax.set_yticklabels([f"L{i+1}" for i in range(N)],
177
+ color='white', fontsize=9)
178
+ ax.set_xticks(range(S)); ax.set_xticklabels(
179
+ [w[:8] for w in words], rotation=45, ha='right', color='white', fontsize=8)
180
+ ax.set_title("Exit Gate β€” Green = token froze at this layer",
181
+ color='white', fontsize=11, pad=8)
182
+ plt.colorbar(im, ax=ax, fraction=0.03)
183
+
184
+ # Confidence line chart
185
+ ax2 = axes[1]
186
+ ax2.set_facecolor('#1e293b')
187
+ avg_conf = con.mean(axis=1)
188
+ layers = range(1, N+1)
189
+ ax2.plot(layers, avg_conf, 'o-', color='#60a5fa', lw=2.5, ms=7)
190
+ ax2.fill_between(layers, avg_conf, alpha=0.2, color='#60a5fa')
191
+ ax2.axhline(CFG.exit_threshold, color='#f59e0b', lw=1.5, ls='--',
192
+ label=f'Exit threshold ({CFG.exit_threshold})')
193
+ ax2.set_xlabel("Layer", color='white', fontsize=10)
194
+ ax2.set_ylabel("Avg Gate Confidence", color='white', fontsize=10)
195
+ ax2.set_title("Confidence per Layer", color='white', fontsize=11)
196
+ ax2.tick_params(colors='white'); ax2.legend(fontsize=9)
197
+ ax2.set_facecolor('#1e293b')
198
+ for spine in ax2.spines.values(): spine.set_color('#334155')
199
+
200
+ plt.tight_layout()
201
+ buf = BytesIO(); fig.savefig(buf, format='png', dpi=130, bbox_inches='tight',
202
+ facecolor='#0f172a'); buf.seek(0)
203
+ img = Image.open(buf); plt.close(fig)
204
+ return img
205
+
206
+ # ── FIGURE 2: Dynamic attention graph ───────────────────────────────────────
207
+ def plot_attention_graph(words, attns):
208
+ S = len(words)
209
+ k = CFG.graph_k
210
+ np.random.seed(42)
211
+
212
+ # Circular layout
213
+ angles = np.linspace(0, 2*np.pi, S, endpoint=False)
214
+ pos = np.stack([np.cos(angles), np.sin(angles)], axis=1)
215
+
216
+ fig, axes = plt.subplots(1, 3, figsize=(15, 5))
217
+ fig.patch.set_facecolor('#0f172a')
218
+ layers_to_show = [0, len(attns)//2, -1]
219
+ titles = ["Layer 1 β€” Initial Graph", f"Layer {len(attns)//2+1} β€” Mid", f"Layer {len(attns)} β€” Final"]
220
+
221
+ for col, (li, title) in enumerate(zip(layers_to_show, titles)):
222
+ ax = axes[col]
223
+ ax.set_facecolor('#1e293b')
224
+ attn_w = attns[li].squeeze(0).detach().numpy() # (S, S)
225
+
226
+ # Draw edges
227
+ for i in range(S):
228
+ top_k = np.argsort(attn_w[i])[::-1][:k]
229
+ for j in top_k:
230
+ w = attn_w[i,j]
231
+ ax.plot([pos[i,0], pos[j,0]], [pos[i,1], pos[j,1]],
232
+ color='#3b82f6', alpha=min(0.9, w*3+0.1), lw=w*3+0.3)
233
+
234
+ # Draw nodes
235
+ for i, word in enumerate(words):
236
+ wtype = WORD_TYPES.get(word.lower(), 1)
237
+ col_node = TYPE_COLORS[wtype]
238
+ ax.scatter(pos[i,0], pos[i,1], c=col_node, s=200, zorder=5,
239
+ edgecolors='white', linewidths=1)
240
+ ax.text(pos[i,0]*1.22, pos[i,1]*1.22, word[:7],
241
+ ha='center', va='center', fontsize=7.5, color='white')
242
+
243
+ ax.set_xlim(-1.5, 1.5); ax.set_ylim(-1.5, 1.5)
244
+ ax.set_title(title, color='white', fontsize=10, pad=6)
245
+ ax.axis('off')
246
+
247
+ # Legend
248
+ legend_patches = [mpatches.Patch(color=TYPE_COLORS[i], label=TYPE_LABELS[i])
249
+ for i in range(4)]
250
+ fig.legend(handles=legend_patches, loc='lower center', ncol=4,
251
+ fontsize=9, facecolor='#1e293b', labelcolor='white',
252
+ edgecolor='#334155', bbox_to_anchor=(0.5, -0.02))
253
+
254
+ plt.tight_layout()
255
+ buf = BytesIO(); fig.savefig(buf, format='png', dpi=130, bbox_inches='tight',
256
+ facecolor='#0f172a'); buf.seek(0)
257
+ img = Image.open(buf); plt.close(fig)
258
+ return img
259
+
260
+ # ── FIGURE 3: Token compute depth ───────────────────────────────────────────
261
+ def plot_token_depth(words, exits, confs):
262
+ S = len(words)
263
+ N = len(exits)
264
+ exit_mat = torch.stack(exits, dim=0).squeeze(1).numpy()
265
+
266
+ exit_layer = []
267
+ for i in range(S):
268
+ col = exit_mat[:, i]
269
+ first = np.argmax(col) + 1 if col.max() > 0 else N
270
+ exit_layer.append(int(first))
271
+
272
+ fig, ax = plt.subplots(figsize=(max(8, S*0.7), 4.5))
273
+ fig.patch.set_facecolor('#0f172a')
274
+ ax.set_facecolor('#1e293b')
275
+
276
+ colors = [TYPE_COLORS[WORD_TYPES.get(w.lower(), 1)] for w in words]
277
+ bars = ax.bar(range(S), exit_layer, color=colors, alpha=0.9,
278
+ edgecolor='white', linewidth=0.6)
279
+ ax.axhline(N, color='#94a3b8', lw=1.5, ls='--', label=f'Max depth ({N} layers)')
280
+ ax.axhline(np.mean(exit_layer), color='#f59e0b', lw=2, ls='-.',
281
+ label=f'Avg depth ({np.mean(exit_layer):.1f} layers = '
282
+ f'{np.mean(exit_layer)/N*100:.0f}% compute)')
283
+
284
+ for bar, val in zip(bars, exit_layer):
285
+ ax.text(bar.get_x()+bar.get_width()/2, val+0.05, str(val),
286
+ ha='center', va='bottom', fontsize=9, color='white', fontweight='bold')
287
+
288
+ ax.set_xticks(range(S))
289
+ ax.set_xticklabels(words, rotation=40, ha='right', color='white', fontsize=9)
290
+ ax.set_ylabel("Layers used", color='white', fontsize=11)
291
+ ax.set_ylim(0, N+1.5)
292
+ ax.set_title("Adaptive Depth β€” Compute per Token\n"
293
+ "Simple tokens exit early Β· Complex tokens go deep",
294
+ color='white', fontsize=12)
295
+ ax.tick_params(colors='white')
296
+ for spine in ax.spines.values(): spine.set_color('#334155')
297
+ legend_patches = [mpatches.Patch(color=TYPE_COLORS[i], label=TYPE_LABELS[i])
298
+ for i in range(4)]
299
+ legend_patches.append(
300
+ mpatches.Patch(color='#f59e0b', label=f'Avg: {np.mean(exit_layer):.1f}L'))
301
+ ax.legend(handles=legend_patches, fontsize=9, facecolor='#1e293b',
302
+ labelcolor='white', edgecolor='#334155', ncol=3)
303
+ plt.tight_layout()
304
+ buf = BytesIO(); fig.savefig(buf, format='png', dpi=130, bbox_inches='tight',
305
+ facecolor='#0f172a'); buf.seek(0)
306
+ img = Image.open(buf); plt.close(fig)
307
+ return img
308
+
309
+ # ── Stats text ───────────────────────────────────────────────────────────────
310
+ def compute_stats(words, exits, confs):
311
+ S = len(words); N = len(exits)
312
+ exit_mat = torch.stack(exits, dim=0).squeeze(1).numpy()
313
+ exit_layers = []
314
+ for i in range(S):
315
+ col = exit_mat[:, i]
316
+ exit_layers.append(int(np.argmax(col)+1) if col.max()>0 else N)
317
+
318
+ avg_depth = np.mean(exit_layers)
319
+ compute_pct = avg_depth / N * 100
320
+ earliest = words[int(np.argmin(exit_layers))]
321
+ deepest = words[int(np.argmax(exit_layers))]
322
+ total_saved = sum(N - e for e in exit_layers)
323
+
324
+ stats = f"""
325
+ ## Analysis Results
326
+
327
+ | Metric | Value |
328
+ |:---|:---|
329
+ | Tokens analysed | {S} |
330
+ | Total layers | {N} |
331
+ | Avg depth used | {avg_depth:.1f} / {N} layers |
332
+ | **Compute used** | **{compute_pct:.0f}% of full depth** |
333
+ | **Compute saved** | **{100-compute_pct:.0f}%** |
334
+ | Layer calls saved | {total_saved} of {S*N} total |
335
+ | Earliest exit token | `{earliest}` (layer {min(exit_layers)}) |
336
+ | Deepest token | `{deepest}` (layer {max(exit_layers)}) |
337
+
338
+ **Graph:** Each token connects to {CFG.graph_k} nearest neighbours by cosine similarity.
339
+ The graph rebuilds after every layer as token representations evolve.
340
+
341
+ **Paper:** [10.5281/zenodo.20287390](https://doi.org/10.5281/zenodo.20287390)
342
+ **Model:** [vigneshwar234/TemporalMesh-Transformer](https://huggingface.co/vigneshwar234/TemporalMesh-Transformer)
343
+ **Code:** [github.com/vignesh2027/TemporalMesh-Transformer](https://github.com/vignesh2027/TemporalMesh-Transformer)
344
+ """
345
+ return stats
346
+
347
+ # ── Main inference function ──────────────────────────────────────────────────
348
+ def analyse(text):
349
+ text = text.strip()
350
+ if not text:
351
+ text = random.choice(SAMPLE_SENTENCES)
352
+ words, exits, confs, attns = run_model(text)
353
+ img1 = plot_exit_heatmap(words, exits, confs)
354
+ img2 = plot_attention_graph(words, attns)
355
+ img3 = plot_token_depth(words, exits, confs)
356
+ stats = compute_stats(words, exits, confs)
357
+ return img1, img2, img3, stats
358
+
359
+ def random_example():
360
+ return random.choice(SAMPLE_SENTENCES)
361
+
362
+ # ── Gradio UI ────────────────────────────────────────────────────────────────
363
+ CSS = """
364
+ .gradio-container { background: #0f172a !important; color: white !important; }
365
+ h1, h2, h3, p, label { color: #e2e8f0 !important; }
366
+ .gr-button { background: #2563eb !important; color: white !important; border: none !important; }
367
+ .gr-button:hover { background: #1d4ed8 !important; }
368
+ footer { display: none !important; }
369
+ """
370
+
371
+ HEADER = """
372
+ <div style="text-align:center; padding: 20px 0 10px 0; background:#0f172a;">
373
+ <h1 style="font-size:2.2em; font-weight:800; color:#58a6ff; margin:0;">
374
+ TemporalMesh Transformer
375
+ </h1>
376
+ <p style="color:#8b949e; font-size:1.05em; margin:6px 0 0 0;">
377
+ Dynamic Graph Attention &nbsp;Β·&nbsp; Temporal Decay &nbsp;Β·&nbsp; Adaptive Depth Routing
378
+ </p>
379
+ <div style="margin-top:12px; display:flex; justify-content:center; gap:10px; flex-wrap:wrap;">
380
+ <a href="https://doi.org/10.5281/zenodo.20287390" target="_blank"
381
+ style="background:#1e3a5f;color:#58a6ff;padding:5px 14px;border-radius:20px;
382
+ text-decoration:none;font-size:0.88em;border:1px solid #2563eb;">
383
+ πŸ“„ Paper (Zenodo DOI)
384
+ </a>
385
+ <a href="https://huggingface.co/vigneshwar234/TemporalMesh-Transformer" target="_blank"
386
+ style="background:#1e3a5f;color:#fbbf24;padding:5px 14px;border-radius:20px;
387
+ text-decoration:none;font-size:0.88em;border:1px solid #f59e0b;">
388
+ πŸ€— Model Card
389
+ </a>
390
+ <a href="https://github.com/vignesh2027/TemporalMesh-Transformer" target="_blank"
391
+ style="background:#1e3a5f;color:#a78bfa;padding:5px 14px;border-radius:20px;
392
+ text-decoration:none;font-size:0.88em;border:1px solid #7c3aed;">
393
+ πŸ’» GitHub Code
394
+ </a>
395
+ <a href="https://huggingface.co/datasets/vigneshwar234/TMT-Benchmarks" target="_blank"
396
+ style="background:#1e3a5f;color:#34d399;padding:5px 14px;border-radius:20px;
397
+ text-decoration:none;font-size:0.88em;border:1px solid #16a34a;">
398
+ πŸ“Š Benchmark Dataset
399
+ </a>
400
+ </div>
401
+ </div>
402
+ """
403
+
404
+ DESCRIPTION = """
405
+ Enter any sentence to see **TMT's three core innovations in action**:
406
+
407
+ - **Exit Gate Heatmap** β€” which tokens freeze at which layer (green = exited early)
408
+ - **Dynamic Attention Graph** β€” how the kNN mesh evolves across layers as token meanings shift
409
+ - **Token Compute Depth** β€” how many layers each word actually uses vs the full 12
410
+
411
+ > TMT achieves **29.4 perplexity** on WikiText-2 at **~48% of standard compute**.
412
+ > No prior architecture combines dynamic graph attention + temporal decay + per-token early exit.
413
+ """
414
+
415
+ with gr.Blocks(css=CSS, title="TemporalMesh Transformer Demo") as demo:
416
+ gr.HTML(HEADER)
417
+ gr.Markdown(DESCRIPTION)
418
+
419
+ with gr.Row():
420
+ with gr.Column(scale=4):
421
+ txt = gr.Textbox(
422
+ label="Input sentence",
423
+ placeholder="Enter any sentence…",
424
+ lines=2,
425
+ value=SAMPLE_SENTENCES[0],
426
+ )
427
+ with gr.Column(scale=1, min_width=140):
428
+ rnd_btn = gr.Button("🎲 Random", variant="secondary")
429
+ run_btn = gr.Button("β–Ά Analyse", variant="primary")
430
+
431
+ stats_out = gr.Markdown(label="Stats")
432
+
433
+ with gr.Row():
434
+ img1 = gr.Image(label="Exit Gate Heatmap + Confidence", type="pil", height=320)
435
+ img3 = gr.Image(label="Token Compute Depth", type="pil", height=320)
436
+
437
+ img2 = gr.Image(label="Dynamic Attention Graph (3 stages)", type="pil", height=340)
438
+
439
+ gr.Examples(
440
+ examples=[[s] for s in SAMPLE_SENTENCES],
441
+ inputs=[txt],
442
+ label="Example sentences",
443
+ )
444
+
445
+ run_btn.click(fn=analyse, inputs=[txt], outputs=[img1, img2, img3, stats_out])
446
+ rnd_btn.click(fn=random_example, outputs=[txt])
447
+ txt.submit(fn=analyse, inputs=[txt], outputs=[img1, img2, img3, stats_out])
448
+
449
+ gr.HTML("""
450
+ <div style="text-align:center;padding:16px 0 8px;color:#64748b;font-size:0.85em;">
451
+ TemporalMesh Transformer Β· Vignesh, 2026 Β· MIT License Β·
452
+ <a href="https://doi.org/10.5281/zenodo.20287390" style="color:#58a6ff;">
453
+ DOI: 10.5281/zenodo.20287390
454
+ </a>
455
+ </div>
456
+ """)
457
+
458
+ demo.launch()