| import gradio as gr |
| import numpy as np |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import matplotlib.patches as mpatches |
| from io import BytesIO |
| from PIL import Image |
| import math, random |
|
|
| |
|
|
| RNG = np.random.default_rng(42) |
|
|
| def _softmax(x, axis=-1): |
| e = np.exp(x - x.max(axis=axis, keepdims=True)) |
| return e / e.sum(axis=axis, keepdims=True) |
|
|
| def _normalize(x): |
| n = np.linalg.norm(x, axis=-1, keepdims=True) |
| return x / np.maximum(n, 1e-8) |
|
|
| class NumpyTMT: |
| """Minimal TMT forward pass in pure NumPy.""" |
| V, D, H, L, K = 1000, 64, 4, 6, 4 |
| THRESH = 0.70 |
|
|
| def __init__(self): |
| rng = np.random.default_rng(42) |
| D, H, L, V = self.D, self.H, self.L, self.V |
| self.emb = rng.standard_normal((V, D)).astype(np.float32) * 0.02 |
| self.Wq = [rng.standard_normal((D, D)).astype(np.float32) * 0.02 for _ in range(L)] |
| self.Wk = [rng.standard_normal((D, D)).astype(np.float32) * 0.02 for _ in range(L)] |
| self.Wv = [rng.standard_normal((D, D)).astype(np.float32) * 0.02 for _ in range(L)] |
| self.Wo = [rng.standard_normal((D, D)).astype(np.float32) * 0.02 for _ in range(L)] |
| self.W1 = [rng.standard_normal((D, D*2)).astype(np.float32) * 0.02 for _ in range(L)] |
| self.W2 = [rng.standard_normal((D*2, D)).astype(np.float32) * 0.02 for _ in range(L)] |
| self.Wg = [rng.standard_normal((D, 1)).astype(np.float32) * 0.02 for _ in range(L)] |
| self.decay = rng.uniform(0.7, 1.0, (L, H)).astype(np.float32) |
|
|
| def _mesh(self, x): |
| xn = _normalize(x) |
| sim = xn @ xn.T |
| np.fill_diagonal(sim, -1e9) |
| k = min(self.K, x.shape[0]-1) |
| return np.argsort(-sim, axis=-1)[:, :k], np.sort(-sim, axis=-1)[:, :k] * -1 |
|
|
| def _layer(self, x, l): |
| S, D = x.shape |
| H = self.H |
| dk = D // H |
| idx, _ = self._mesh(x) |
|
|
| Q = (x @ self.Wq[l]).reshape(S, H, dk).transpose(1,0,2) |
| K = (x @ self.Wk[l]).reshape(S, H, dk).transpose(1,0,2) |
| V = (x @ self.Wv[l]).reshape(S, H, dk).transpose(1,0,2) |
|
|
| sc = np.matmul(Q, K.transpose(0,2,1)) / math.sqrt(dk) |
| |
| for h in range(H): |
| dist = np.abs(np.arange(S)[:,None] - np.arange(S)[None,:]) |
| sc[h] = sc[h] * (self.decay[l,h] ** dist) |
|
|
| |
| mask = np.full_like(sc, -1e9) |
| for h in range(H): |
| mask[h, np.arange(S)[:,None], idx] = sc[h, np.arange(S)[:,None], idx] |
| attn = _softmax(mask, axis=-1) |
|
|
| out = np.matmul(attn, V) |
| out = out.transpose(1,0,2).reshape(S, D) |
| out = out @ self.Wo[l] |
|
|
| x = x + out |
| ff = np.maximum(0, x @ self.W1[l]) @ self.W2[l] |
| x = x + ff |
| conf = 1 / (1 + np.exp(-(x @ self.Wg[l]).squeeze(-1))) |
| return x, conf, attn.mean(0) |
|
|
| def forward(self, ids): |
| S = len(ids) |
| x = self.emb[ids].copy() |
| frozen = np.zeros(S, dtype=bool) |
| exits, confs, attns = [], [], [] |
| for l in range(self.L): |
| xn, cf, aw = self._layer(x, l) |
| ne = (~frozen) & (cf > self.THRESH) |
| frozen = frozen | ne |
| x = np.where(frozen[:,None], x, xn) |
| exits.append(ne.astype(float)) |
| confs.append(cf.copy()) |
| attns.append(aw) |
| return exits, confs, attns |
|
|
| MODEL = NumpyTMT() |
|
|
| |
|
|
| BG, CARD, GRID = "#0f172a", "#1e293b", "#334155" |
| TC = ["#22c55e", "#3b82f6", "#f59e0b", "#ef4444"] |
| TL = ["Function words", "Common verbs", "Domain terms", "Complex/rare"] |
| WTYPES = { |
| "the":0,"a":0,"an":0,"of":0,"in":0,"to":0,"and":0,"is":0,"are":0,"by":0, |
| "on":0,"for":0,"with":0,"this":0,"that":0,"it":0,"at":0,"or":0,"not":0, |
| "learned":1,"allow":1,"predict":1,"require":1,"adapts":1,"reduces":1, |
| "operate":1,"jumps":1,"represent":1,"trained":1,"uses":1,"focus":1,"make":1, |
| "neural":2,"network":2,"attention":2,"transformer":2,"semantic":2, |
| "topology":2,"graph":2,"compute":2,"language":2,"model":2,"token":2,"deep":2, |
| "dynamic":3,"adaptive":3,"complex":3,"patterns":3,"architecture":3, |
| "routing":3,"decay":3,"mechanisms":3,"structured":3,"relevant":3,"rare":3, |
| } |
|
|
| SAMPLES = [ |
| "The neural network learned to represent complex patterns in the data", |
| "Attention mechanisms allow transformers to focus on relevant tokens", |
| "Dynamic graph topology adapts to the semantic content of the sequence", |
| "Adaptive depth routing reduces compute by fifty percent on average", |
| "Language models predict the next word given the previous context", |
| "Graph neural networks operate over structured relational data", |
| "The quick brown fox jumps over the lazy dog near the river", |
| "Complex technical terminology requires deeper processing than simple words", |
| ] |
|
|
| def tokenize(text): |
| words = text.strip().split()[:32] |
| if len(words) < 3: |
| words = (words + ["the", "model", "runs"])[:3] |
| ids = [abs(hash(w.lower())) % (NumpyTMT.V - 2) + 1 for w in words] |
| return words, ids |
|
|
| def fig2arr(fig): |
| buf = BytesIO() |
| fig.savefig(buf, format="png", dpi=100, bbox_inches="tight", facecolor=BG) |
| buf.seek(0) |
| arr = np.array(Image.open(buf).convert("RGB")) |
| plt.close(fig); buf.close() |
| return arr |
|
|
| |
|
|
| def plot_heatmap(words, exits, confs): |
| S = len(words); N = len(exits) |
| mat = np.stack(exits, 0) |
| con = np.stack(confs, 0) |
|
|
| fig, (ax, ax2) = plt.subplots(1, 2, figsize=(13, max(3.5, S*0.32+2))) |
| fig.patch.set_facecolor(BG) |
|
|
| ax.set_facecolor(CARD) |
| im = ax.imshow(mat, aspect="auto", cmap="RdYlGn", vmin=0, vmax=1) |
| ax.set_yticks(range(N)); ax.set_yticklabels([f"L{i+1}" for i in range(N)], color="white", fontsize=9) |
| ax.set_xticks(range(S)); ax.set_xticklabels([w[:9] for w in words], rotation=45, ha="right", color="white", fontsize=8) |
| ax.set_title("Exit Gate β green = token frozen (compute saved)", color="white", fontsize=11, pad=8) |
| plt.colorbar(im, ax=ax, fraction=0.03) |
| for sp in ax.spines.values(): sp.set_color(GRID) |
|
|
| ax2.set_facecolor(CARD) |
| avg = con.mean(axis=1); ls = list(range(1, N+1)) |
| ax2.plot(ls, avg, "o-", color="#60a5fa", lw=2.5, ms=7) |
| ax2.fill_between(ls, avg, alpha=0.2, color="#60a5fa") |
| ax2.axhline(NumpyTMT.THRESH, color="#f59e0b", lw=1.5, ls="--", label=f"Threshold {NumpyTMT.THRESH}") |
| ax2.set_xlabel("Layer", color="white", fontsize=10) |
| ax2.set_ylabel("Avg Confidence", color="white", fontsize=10) |
| ax2.set_title("Gate Confidence per Layer", color="white", fontsize=11) |
| ax2.tick_params(colors="white") |
| ax2.legend(fontsize=9, facecolor=CARD, labelcolor="white", edgecolor=GRID) |
| for sp in ax2.spines.values(): sp.set_color(GRID) |
| plt.tight_layout() |
| return fig2arr(fig) |
|
|
|
|
| def plot_graph(words, attns): |
| S = len(words) |
| angles = np.linspace(0, 2*np.pi, S, endpoint=False) |
| pos = np.stack([np.cos(angles), np.sin(angles)], 1) |
| n = min(3, len(attns)) |
| idxs = [0] if n==1 else ([0,-1] if n==2 else [0, len(attns)//2, -1]) |
| titles = (["Layer 1"] if n==1 |
| else ["Layer 1 β Initial", f"Layer {len(attns)} β Final"] if n==2 |
| else ["Layer 1 β Initial", f"Layer {len(attns)//2+1} β Mid", |
| f"Layer {len(attns)} β Final"]) |
|
|
| fig, axes = plt.subplots(1, n, figsize=(5*n, 5)) |
| if n == 1: axes = [axes] |
| fig.patch.set_facecolor(BG) |
|
|
| for ax, li, title in zip(axes, idxs, titles): |
| ax.set_facecolor(CARD) |
| aw = attns[li] |
| k = min(NumpyTMT.K, S-1) |
| for i in range(S): |
| for j in np.argsort(aw[i])[::-1][:k]: |
| w = float(aw[i, j]) |
| ax.plot([pos[i,0], pos[j,0]], [pos[i,1], pos[j,1]], |
| color="#3b82f6", alpha=min(0.9, w*6+0.1), lw=max(0.4, w*4)) |
| for i, word in enumerate(words): |
| c = TC[WTYPES.get(word.lower(), 1)] |
| ax.scatter(pos[i,0], pos[i,1], c=c, s=220, zorder=5, edgecolors="white", linewidths=1.2) |
| ax.text(pos[i,0]*1.3, pos[i,1]*1.3, word[:8], |
| ha="center", va="center", fontsize=7.5, color="white") |
| ax.set_xlim(-1.7, 1.7); ax.set_ylim(-1.7, 1.7) |
| ax.set_title(title, color="white", fontsize=10, pad=6); ax.axis("off") |
|
|
| fig.legend(handles=[mpatches.Patch(color=TC[i], label=TL[i]) for i in range(4)], |
| loc="lower center", ncol=4, fontsize=9, |
| facecolor=CARD, labelcolor="white", edgecolor=GRID, bbox_to_anchor=(0.5, -0.05)) |
| plt.tight_layout() |
| return fig2arr(fig) |
|
|
|
|
| def plot_depth(words, exits): |
| S = len(words); N = len(exits) |
| em = np.stack(exits, 0) |
| el = [int(np.argmax(em[:, i])+1) if em[:, i].max() > 0 else N for i in range(S)] |
|
|
| fig, ax = plt.subplots(figsize=(max(8, S*0.8), 5)) |
| fig.patch.set_facecolor(BG); ax.set_facecolor(CARD) |
|
|
| bars = ax.bar(range(S), el, |
| color=[TC[WTYPES.get(w.lower(), 1)] for w in words], |
| alpha=0.9, edgecolor="white", linewidth=0.6) |
| avg = float(np.mean(el)) |
| ax.axhline(N, color="#94a3b8", lw=1.5, ls="--", label=f"Max ({N}L)") |
| ax.axhline(avg, color="#f59e0b", lw=2, ls="-.", label=f"Avg {avg:.1f}L = {avg/N*100:.0f}% compute") |
| for bar, val in zip(bars, el): |
| ax.text(bar.get_x()+bar.get_width()/2, val+0.07, str(val), |
| ha="center", va="bottom", fontsize=9, color="white", fontweight="bold") |
| ax.set_xticks(range(S)); ax.set_xticklabels(words, rotation=40, ha="right", color="white", fontsize=9) |
| ax.set_ylabel("Layers used", color="white", fontsize=11); ax.set_ylim(0, N+2) |
| ax.set_title("Per-Token Compute Depth β Simple exits early, Complex goes deep", |
| color="white", fontsize=12) |
| ax.tick_params(colors="white") |
| for sp in ax.spines.values(): sp.set_color(GRID) |
| patches = [mpatches.Patch(color=TC[i], label=TL[i]) for i in range(4)] |
| patches.append(mpatches.Patch(color="#f59e0b", label=f"Avg {avg:.1f}L")) |
| ax.legend(handles=patches, fontsize=9, facecolor=CARD, labelcolor="white", edgecolor=GRID, ncol=3) |
| plt.tight_layout() |
| return fig2arr(fig) |
|
|
| |
|
|
| def analyse(sentence): |
| if not sentence or not sentence.strip(): |
| sentence = random.choice(SAMPLES) |
| words, ids = tokenize(sentence) |
| exits, confs, attns = MODEL.forward(ids) |
|
|
| img1 = plot_heatmap(words, exits, confs) |
| img2 = plot_graph(words, attns) |
| img3 = plot_depth(words, exits) |
|
|
| S = len(words); N = len(exits) |
| em = np.stack(exits, 0) |
| el = [int(np.argmax(em[:, i])+1) if em[:, i].max() > 0 else N for i in range(S)] |
| avg = float(np.mean(el)) |
|
|
| md = ( |
| "### Analysis Results\n\n" |
| "| Metric | Value |\n" |
| "|:--|:--|\n" |
| f"| Tokens | {S} |\n" |
| f"| Avg depth | {avg:.1f} / {N} layers |\n" |
| f"| Compute used | **{avg/N*100:.0f}%** |\n" |
| f"| Compute saved | **{(1-avg/N)*100:.0f}%** |\n" |
| f"| Earliest exit | `{words[int(np.argmin(el))]}` β layer {min(el)} |\n" |
| f"| Deepest token | `{words[int(np.argmax(el))]}` β layer {max(el)} |\n\n" |
| "π [Paper](https://doi.org/10.5281/zenodo.20287390) Β· " |
| "π€ [Model](https://huggingface.co/vigneshwar234/TemporalMesh-Transformer) Β· " |
| "π» [Code](https://github.com/vignesh2027/TemporalMesh-Transformer)" |
| ) |
| return img1, img2, img3, md |
|
|
|
|
| demo = gr.Interface( |
| fn=analyse, |
| inputs=gr.Textbox( |
| label="Enter a sentence", |
| placeholder="e.g. The neural network learned to represent complex patterns", |
| lines=2, |
| value=SAMPLES[0], |
| ), |
| outputs=[ |
| gr.Image(label="Exit Gate Heatmap + Confidence", type="numpy"), |
| gr.Image(label="Dynamic Attention Graph", type="numpy"), |
| gr.Image(label="Per-Token Compute Depth", type="numpy"), |
| gr.Markdown(label="Stats"), |
| ], |
| title="πΈοΈ TemporalMesh Transformer Demo", |
| description=( |
| "Visualise **dynamic graph attention**, **temporal decay**, and " |
| "**per-token adaptive depth routing** on any sentence.\n\n" |
| "π [Paper (Zenodo)](https://doi.org/10.5281/zenodo.20287390) Β· " |
| "π€ [Model Card](https://huggingface.co/vigneshwar234/TemporalMesh-Transformer) Β· " |
| "π» [GitHub](https://github.com/vignesh2027/TemporalMesh-Transformer) Β· " |
| "π [Dataset](https://huggingface.co/datasets/vigneshwar234/TMT-Benchmarks)" |
| ), |
| examples=[ |
| ["The neural network learned to represent complex patterns in the data"], |
| ["Attention mechanisms allow transformers to focus on relevant tokens"], |
| ["Dynamic graph topology adapts to the semantic content of the sequence"], |
| ["Adaptive depth routing reduces compute by fifty percent on average"], |
| ["Language models predict the next word given the previous context"], |
| ["Graph neural networks operate over structured relational data"], |
| ], |
| cache_examples=False, |
| flagging_mode="never", |
| ) |
|
|
| demo.launch() |
|
|