import math
import random
import string
import zlib
from io import BytesIO

import gradio as gr
import matplotlib
import numpy as np
from PIL import Image

matplotlib.use("Agg")  # headless backend; selected before pyplot is imported
import matplotlib.patches as mpatches  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402

# ── Pure-NumPy TMT simulation (no PyTorch dependency) ────────────────────────
# Module-level generator kept for compatibility; NumpyTMT seeds its own generator.
RNG = np.random.default_rng(42)


def _softmax(x, axis=-1):
    """Numerically stable softmax of *x* along *axis* (row max subtracted first)."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def _normalize(x):
    """L2-normalise *x* along its last axis; norms are clamped to at least 1e-8."""
    n = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(n, 1e-8)


class NumpyTMT:
    """Minimal TMT forward pass in pure NumPy.

    Every weight is drawn from a generator seeded with 42, so the model is
    deterministic but untrained (no checkpoint is loaded).

    Class attributes:
        V: vocabulary size (rows of the embedding table).
        D: model width.
        H: number of attention heads (D must be divisible by H).
        L: number of layers.
        K: mesh neighbours each token may attend to.
        THRESH: exit-gate confidence above which a token is frozen.
    """

    V, D, H, L, K = 1000, 64, 4, 6, 4
    THRESH = 0.70
    # NOTE(review): with the 0.02 init scale below the gate logits (x @ Wg) stay
    # close to 0, so confidences hover near 0.5 and THRESH may never be crossed,
    # leaving the exit visualisations with no early exits -- verify.

    def __init__(self):
        rng = np.random.default_rng(42)
        D, H, L, V = self.D, self.H, self.L, self.V
        self.emb = rng.standard_normal((V, D)).astype(np.float32) * 0.02
        self.Wq = [rng.standard_normal((D, D)).astype(np.float32) * 0.02 for _ in range(L)]
        self.Wk = [rng.standard_normal((D, D)).astype(np.float32) * 0.02 for _ in range(L)]
        self.Wv = [rng.standard_normal((D, D)).astype(np.float32) * 0.02 for _ in range(L)]
        self.Wo = [rng.standard_normal((D, D)).astype(np.float32) * 0.02 for _ in range(L)]
        self.W1 = [rng.standard_normal((D, D * 2)).astype(np.float32) * 0.02 for _ in range(L)]
        self.W2 = [rng.standard_normal((D * 2, D)).astype(np.float32) * 0.02 for _ in range(L)]
        self.Wg = [rng.standard_normal((D, 1)).astype(np.float32) * 0.02 for _ in range(L)]
        # per-layer, per-head temporal decay base in [0.7, 1.0)
        self.decay = rng.uniform(0.7, 1.0, (L, H)).astype(np.float32)

    def _mesh(self, x):
        """Find each token's K most cosine-similar *other* tokens.

        Args:
            x: activations of shape (S, D).

        Returns:
            (idx, sims), both (S, k) with k = min(K, S - 1), ordered by
            decreasing similarity. A token is never its own neighbour.
        """
        xn = _normalize(x)
        sim = xn @ xn.T
        np.fill_diagonal(sim, -1e9)  # exclude self-matches
        k = min(self.K, x.shape[0] - 1)
        return np.argsort(-sim, axis=-1)[:, :k], np.sort(-sim, axis=-1)[:, :k] * -1

    def _layer(self, x, l):
        """Apply layer *l* to activations x of shape (S, D).

        Returns:
            (x_out, conf, attn): updated activations (S, D), per-token
            exit-gate confidence (S,), head-averaged attention weights (S, S).
        """
        S, D = x.shape
        H = self.H
        dk = D // H
        idx, _ = self._mesh(x)

        Q = (x @ self.Wq[l]).reshape(S, H, dk).transpose(1, 0, 2)  # H,S,dk
        K = (x @ self.Wk[l]).reshape(S, H, dk).transpose(1, 0, 2)
        V = (x @ self.Wv[l]).reshape(S, H, dk).transpose(1, 0, 2)
        sc = np.matmul(Q, K.transpose(0, 2, 1)) / math.sqrt(dk)  # H,S,S

        # Temporal decay: head h scales score (i, j) by decay[l, h] ** |i - j|.
        pos = np.arange(S)
        dist = np.abs(pos[:, None] - pos[None, :])
        sc *= self.decay[l][:, None, None] ** dist

        # Mesh masking: keep only each token's top-k neighbours, -1e9 elsewhere.
        rows = pos[:, None]
        mask = np.full_like(sc, -1e9)
        mask[:, rows, idx] = sc[:, rows, idx]
        attn = _softmax(mask, axis=-1)  # H,S,S

        out = np.matmul(attn, V)  # H,S,dk
        out = out.transpose(1, 0, 2).reshape(S, D) @ self.Wo[l]
        x = x + out
        ff = np.maximum(0, x @ self.W1[l]) @ self.W2[l]
        x = x + ff
        conf = 1 / (1 + np.exp(-(x @ self.Wg[l]).squeeze(-1)))  # sigmoid gate
        return x, conf, attn.mean(0)  # attn: S,S

    def forward(self, ids):
        """Run all L layers over token *ids* with per-token early exit.

        A token whose gate confidence exceeds THRESH at some layer is frozen:
        its activations are no longer updated in later layers.

        Returns:
            (exits, confs, attns), three lists with one entry per layer:
            exits -- float (S,) mask, 1.0 where a token exits at that layer;
            confs -- gate confidences (S,);
            attns -- head-averaged attention weights (S, S).
        """
        S = len(ids)
        x = self.emb[ids].copy()
        frozen = np.zeros(S, dtype=bool)
        exits, confs, attns = [], [], []
        for l in range(self.L):
            xn, cf, aw = self._layer(x, l)
            ne = (~frozen) & (cf > self.THRESH)  # tokens exiting at this layer
            frozen = frozen | ne
            # NOTE(review): a token that exits here keeps its pre-layer x, not
            # xn (the output its gate was scored on) -- confirm intended.
            x = np.where(frozen[:, None], x, xn)
            exits.append(ne.astype(float))
            confs.append(cf.copy())
            attns.append(aw)
        return exits, confs, attns


MODEL = NumpyTMT()

# ── Colour helpers ─────────────────────────────────────────────────────────────
BG, CARD, GRID = "#0f172a", "#1e293b", "#334155"
TC = ["#22c55e", "#3b82f6", "#f59e0b", "#ef4444"]
TL = ["Function words", "Common verbs", "Domain terms", "Complex/rare"]
WTYPES = {
    "the": 0, "a": 0, "an": 0, "of": 0, "in": 0, "to": 0, "and": 0, "is": 0, "are": 0, "by": 0,
    "on": 0, "for": 0, "with": 0, "this": 0, "that": 0, "it": 0, "at": 0, "or": 0, "not": 0,
    "learned": 1, "allow": 1, "predict": 1, "require": 1, "adapts": 1, "reduces": 1,
    "operate": 1, "jumps": 1, "represent": 1, "trained": 1, "uses": 1, "focus": 1, "make": 1,
    "neural": 2, "network": 2, "attention": 2, "transformer": 2, "semantic": 2,
    "topology": 2, "graph": 2, "compute": 2, "language": 2, "model": 2, "token": 2, "deep": 2,
    "dynamic": 3, "adaptive": 3, "complex": 3, "patterns": 3, "architecture": 3,
    "routing": 3, "decay": 3, "mechanisms": 3, "structured": 3, "relevant": 3, "rare": 3,
}
SAMPLES = [
    "The neural network learned to represent complex patterns in the data",
    "Attention mechanisms allow transformers to focus on relevant tokens",
    "Dynamic graph topology adapts to the semantic content of the sequence",
    "Adaptive depth routing reduces compute by fifty percent on average",
    "Language models predict the next word given the previous context",
    "Graph neural networks operate over structured relational data",
    "The quick brown fox jumps over the lazy dog near the river",
    "Complex technical terminology requires deeper processing than simple words",
]


def _norm_word(word):
    """Lower-case *word* and strip surrounding ASCII punctuation ("Data." -> "data")."""
    return word.lower().strip(string.punctuation)


def _word_type(word):
    """Return the TC/TL category index for *word*; unknown words map to 1 ("Common verbs")."""
    return WTYPES.get(_norm_word(word), 1)


def tokenize(text):
    """Split *text* into words and map each to a stable token id.

    At most 32 whitespace-separated words are kept. Input shorter than 3 words
    is padded from ["the", "model", "runs"]. Ids lie in [1, V - 2].

    Ids come from CRC-32 of the normalised word rather than built-in hash().
    hash() of str is salted per process (PYTHONHASHSEED), which would change
    the ids, and every plot, on each restart.

    Returns:
        (words, ids): the original word strings and their integer ids.
    """
    words = text.strip().split()[:32]
    if len(words) < 3:
        words = (words + ["the", "model", "runs"])[:3]
    ids = [
        zlib.crc32(_norm_word(w).encode("utf-8", "surrogatepass")) % (NumpyTMT.V - 2) + 1
        for w in words
    ]
    return words, ids


def fig2arr(fig):
    """Render *fig* to PNG and return it as an (H, W, 3) uint8 RGB array.

    The figure is closed even if rendering fails, so pyplot does not keep
    accumulating open figures across requests.
    """
    try:
        with BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=100, bbox_inches="tight", facecolor=BG)
            buf.seek(0)
            with Image.open(buf) as img:
                return np.array(img.convert("RGB"))
    finally:
        plt.close(fig)


def _exit_layers(exits):
    """Return each token's 1-based exit layer, or N (all layers) if it never exits.

    *exits* is the per-layer list of (S,) exit masks returned by NumpyTMT.forward.
    """
    em = np.stack(exits, 0)  # N×S
    n_layers = em.shape[0]
    fired = em.max(axis=0) > 0
    return [int(v) for v in np.where(fired, em.argmax(axis=0) + 1, n_layers)]


# ── Visualisations ─────────────────────────────────────────────────────────────
def plot_heatmap(words, exits, confs):
    """Plot the layer×token exit heatmap (left) and mean gate confidence per layer vs THRESH (right)."""
    S = len(words)
    N = len(exits)
    mat = np.stack(exits, 0)  # N×S
    con = np.stack(confs, 0)  # N×S
    fig, (ax, ax2) = plt.subplots(1, 2, figsize=(13, max(3.5, S * 0.32 + 2)))
    fig.patch.set_facecolor(BG)

    ax.set_facecolor(CARD)
    im = ax.imshow(mat, aspect="auto", cmap="RdYlGn", vmin=0, vmax=1)
    ax.set_yticks(range(N))
    ax.set_yticklabels([f"L{i+1}" for i in range(N)], color="white", fontsize=9)
    ax.set_xticks(range(S))
    ax.set_xticklabels([w[:9] for w in words], rotation=45, ha="right",
                       color="white", fontsize=8)
    ax.set_title("Exit Gate — green = token frozen (compute saved)",
                 color="white", fontsize=11, pad=8)
    plt.colorbar(im, ax=ax, fraction=0.03)
    for sp in ax.spines.values():
        sp.set_color(GRID)

    ax2.set_facecolor(CARD)
    avg = con.mean(axis=1)
    ls = list(range(1, N + 1))
    ax2.plot(ls, avg, "o-", color="#60a5fa", lw=2.5, ms=7)
    ax2.fill_between(ls, avg, alpha=0.2, color="#60a5fa")
    ax2.axhline(NumpyTMT.THRESH, color="#f59e0b", lw=1.5, ls="--",
                label=f"Threshold {NumpyTMT.THRESH}")
    ax2.set_xlabel("Layer", color="white", fontsize=10)
    ax2.set_ylabel("Avg Confidence", color="white", fontsize=10)
    ax2.set_title("Gate Confidence per Layer", color="white", fontsize=11)
    ax2.tick_params(colors="white")
    ax2.legend(fontsize=9, facecolor=CARD, labelcolor="white", edgecolor=GRID)
    for sp in ax2.spines.values():
        sp.set_color(GRID)
    plt.tight_layout()
    return fig2arr(fig)


def plot_graph(words, attns):
    """Draw tokens on a circle with edges to each token's top-k attention targets.

    One panel each is drawn for the first, middle and last layer, or fewer
    panels when there are fewer than 3 layers.
    """
    S = len(words)
    angles = np.linspace(0, 2 * np.pi, S, endpoint=False)
    pos = np.stack([np.cos(angles), np.sin(angles)], 1)

    n_layers = len(attns)
    if n_layers == 1:
        panels = [(0, "Layer 1")]
    elif n_layers == 2:
        panels = [(0, "Layer 1 — Initial"), (-1, f"Layer {n_layers} — Final")]
    else:
        mid = n_layers // 2
        panels = [
            (0, "Layer 1 — Initial"),
            (mid, f"Layer {mid + 1} — Mid"),
            (-1, f"Layer {n_layers} — Final"),
        ]

    fig, axes = plt.subplots(1, len(panels), figsize=(5 * len(panels), 5), squeeze=False)
    fig.patch.set_facecolor(BG)
    k = min(NumpyTMT.K, S - 1)
    for ax, (li, title) in zip(axes[0], panels):
        ax.set_facecolor(CARD)
        aw = attns[li]  # S×S
        for i in range(S):
            for j in np.argsort(aw[i])[::-1][:k]:
                w = float(aw[i, j])
                ax.plot([pos[i, 0], pos[j, 0]], [pos[i, 1], pos[j, 1]],
                        color="#3b82f6", alpha=min(0.9, w * 6 + 0.1), lw=max(0.4, w * 4))
        for i, word in enumerate(words):
            ax.scatter(pos[i, 0], pos[i, 1], c=TC[_word_type(word)], s=220, zorder=5,
                       edgecolors="white", linewidths=1.2)
            ax.text(pos[i, 0] * 1.3, pos[i, 1] * 1.3, word[:8], ha="center",
                    va="center", fontsize=7.5, color="white")
        ax.set_xlim(-1.7, 1.7)
        ax.set_ylim(-1.7, 1.7)
        ax.set_title(title, color="white", fontsize=10, pad=6)
        ax.axis("off")
    fig.legend(handles=[mpatches.Patch(color=TC[i], label=TL[i]) for i in range(4)],
               loc="lower center", ncol=4, fontsize=9, facecolor=CARD,
               labelcolor="white", edgecolor=GRID, bbox_to_anchor=(0.5, -0.05))
    plt.tight_layout()
    return fig2arr(fig)


def plot_depth(words, exits):
    """Bar chart of layers used per token, with the max-depth and average-depth lines."""
    S = len(words)
    N = len(exits)
    el = _exit_layers(exits)
    fig, ax = plt.subplots(figsize=(max(8, S * 0.8), 5))
    fig.patch.set_facecolor(BG)
    ax.set_facecolor(CARD)
    bars = ax.bar(range(S), el, color=[TC[_word_type(w)] for w in words],
                  alpha=0.9, edgecolor="white", linewidth=0.6)
    avg = float(np.mean(el))
    ax.axhline(N, color="#94a3b8", lw=1.5, ls="--", label=f"Max ({N}L)")
    ax.axhline(avg, color="#f59e0b", lw=2, ls="-.",
               label=f"Avg {avg:.1f}L = {avg/N*100:.0f}% compute")
    for bar, val in zip(bars, el):
        ax.text(bar.get_x() + bar.get_width() / 2, val + 0.07, str(val), ha="center",
                va="bottom", fontsize=9, color="white", fontweight="bold")
    ax.set_xticks(range(S))
    ax.set_xticklabels(words, rotation=40, ha="right", color="white", fontsize=9)
    ax.set_ylabel("Layers used", color="white", fontsize=11)
    ax.set_ylim(0, N + 2)
    ax.set_title("Per-Token Compute Depth — Simple exits early, Complex goes deep",
                 color="white", fontsize=12)
    ax.tick_params(colors="white")
    for sp in ax.spines.values():
        sp.set_color(GRID)
    patches = [mpatches.Patch(color=TC[i], label=TL[i]) for i in range(4)]
    patches.append(mpatches.Patch(color="#f59e0b", label=f"Avg {avg:.1f}L"))
    ax.legend(handles=patches, fontsize=9, facecolor=CARD, labelcolor="white",
              edgecolor=GRID, ncol=3)
    plt.tight_layout()
    return fig2arr(fig)


# ── Gradio entry point ────────────────────────────────────────────────────────
def _md_cell(text):
    """Escape '|' so user-supplied words cannot break the Markdown table row."""
    return text.replace("|", "\\|")


def analyse(sentence):
    """Run the model on *sentence* and build the three plots plus a Markdown summary.

    Blank or missing input is replaced with a random entry from SAMPLES.

    Returns:
        (heatmap_img, graph_img, depth_img, markdown_str).
    """
    if not sentence or not sentence.strip():
        sentence = random.choice(SAMPLES)
    words, ids = tokenize(sentence)
    exits, confs, attns = MODEL.forward(ids)
    img1 = plot_heatmap(words, exits, confs)
    img2 = plot_graph(words, attns)
    img3 = plot_depth(words, exits)

    S = len(words)
    N = len(exits)
    el = _exit_layers(exits)
    avg = float(np.mean(el))
    earliest = int(np.argmin(el))
    deepest = int(np.argmax(el))
    md = (
        "### Analysis Results\n\n"
        "| Metric | Value |\n"
        "|:--|:--|\n"
        f"| Tokens | {S} |\n"
        f"| Avg depth | {avg:.1f} / {N} layers |\n"
        f"| Compute used | **{avg/N*100:.0f}%** |\n"
        f"| Compute saved | **{(1-avg/N)*100:.0f}%** |\n"
        f"| Earliest exit | `{_md_cell(words[earliest])}` → layer {el[earliest]} |\n"
        f"| Deepest token | `{_md_cell(words[deepest])}` → layer {el[deepest]} |\n\n"
        "📄 [Paper](https://doi.org/10.5281/zenodo.20287390) · "
        "🤗 [Model](https://huggingface.co/vigneshwar234/TemporalMesh-Transformer) · "
        "💻 [Code](https://github.com/vignesh2027/TemporalMesh-Transformer)"
    )
    return img1, img2, img3, md


demo = gr.Interface(
    fn=analyse,
    inputs=gr.Textbox(
        label="Enter a sentence",
        placeholder="e.g. The neural network learned to represent complex patterns",
        lines=2,
        value=SAMPLES[0],
    ),
    outputs=[
        gr.Image(label="Exit Gate Heatmap + Confidence", type="numpy"),
        gr.Image(label="Dynamic Attention Graph", type="numpy"),
        gr.Image(label="Per-Token Compute Depth", type="numpy"),
        gr.Markdown(label="Stats"),
    ],
    title="🕸️ TemporalMesh Transformer Demo",
    description=(
        "Visualise **dynamic graph attention**, **temporal decay**, and "
        "**per-token adaptive depth routing** on any sentence.\n\n"
        "📄 [Paper (Zenodo)](https://doi.org/10.5281/zenodo.20287390) · "
        "🤗 [Model Card](https://huggingface.co/vigneshwar234/TemporalMesh-Transformer) · "
        "💻 [GitHub](https://github.com/vignesh2027/TemporalMesh-Transformer) · "
        "📊 [Dataset](https://huggingface.co/datasets/vigneshwar234/TMT-Benchmarks)"
    ),
    examples=[[s] for s in SAMPLES[:6]],
    cache_examples=False,
    flagging_mode="never",
)

demo.launch()