vigneshwar234's picture
Remove PyTorch: pure numpy TMT — eliminates startup crash
44bc5b6 verified
import math
import random
import zlib
from io import BytesIO

import gradio as gr
import matplotlib
import numpy as np

matplotlib.use("Agg")  # select the headless backend before pyplot is imported

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
from PIL import Image
# ── Pure-NumPy TMT simulation (no PyTorch dependency) ────────────────────────
# Module-level random generator with a fixed seed of 42.
# NOTE(review): nothing in the visible code reads RNG (NumpyTMT.__init__ builds
# its own generator) -- confirm it is unused before removing it.
RNG = np.random.default_rng(42)
def _softmax(x, axis=-1):
e = np.exp(x - x.max(axis=axis, keepdims=True))
return e / e.sum(axis=axis, keepdims=True)
def _normalize(x):
n = np.linalg.norm(x, axis=-1, keepdims=True)
return x / np.maximum(n, 1e-8)
class NumpyTMT:
    """Minimal TMT forward pass in pure NumPy.

    All weights are drawn once, in ``__init__``, from a generator seeded with
    42, so every instance carries identical parameters.
    """

    # V: vocabulary size (rows of the embedding table)
    # D: model width
    # H: number of attention heads (D is split evenly across them)
    # L: number of layers
    # K: maximum number of neighbours each token may attend to
    V, D, H, L, K = 1000, 64, 4, 6, 4
    # Gate confidence above which a token is frozen (stops being updated).
    THRESH = 0.70

    def __init__(self):
        """Draw the embedding table and per-layer weights (scaled by 0.02)."""
        # NOTE: the order of the draws below determines the weight values;
        # reordering them changes the model.
        rng = np.random.default_rng(42)
        D, H, L, V = self.D, self.H, self.L, self.V
        self.emb = rng.standard_normal((V, D)).astype(np.float32) * 0.02
        self.Wq = [rng.standard_normal((D, D)).astype(np.float32) * 0.02 for _ in range(L)]
        self.Wk = [rng.standard_normal((D, D)).astype(np.float32) * 0.02 for _ in range(L)]
        self.Wv = [rng.standard_normal((D, D)).astype(np.float32) * 0.02 for _ in range(L)]
        self.Wo = [rng.standard_normal((D, D)).astype(np.float32) * 0.02 for _ in range(L)]
        self.W1 = [rng.standard_normal((D, D * 2)).astype(np.float32) * 0.02 for _ in range(L)]
        self.W2 = [rng.standard_normal((D * 2, D)).astype(np.float32) * 0.02 for _ in range(L)]
        self.Wg = [rng.standard_normal((D, 1)).astype(np.float32) * 0.02 for _ in range(L)]
        # One decay base in [0.7, 1.0) per (layer, head).
        self.decay = rng.uniform(0.7, 1.0, (L, H)).astype(np.float32)

    def _mesh(self, x):
        """Return each token's top-k most cosine-similar other tokens.

        Args:
            x: (S, D) array of token states.

        Returns:
            ``(idx, sims)``, both of shape (S, k) with k = min(K, S - 1):
            neighbour indices ordered by descending similarity, and the
            matching similarity values.
        """
        xn = _normalize(x)
        sim = xn @ xn.T
        # A token must never pick itself as a neighbour.
        np.fill_diagonal(sim, -1e9)
        k = max(0, min(self.K, x.shape[0] - 1))
        # Sort once and gather the values instead of sorting a second time.
        idx = np.argsort(-sim, axis=-1)[:, :k]
        return idx, np.take_along_axis(sim, idx, axis=-1)

    def _layer(self, x, l):
        """Run layer ``l`` on the (S, D) states ``x``.

        Returns:
            ``(x_new, conf, attn)``: the updated (S, D) states, the per-token
            gate confidence of shape (S,), and the head-averaged (S, S)
            attention matrix.
        """
        S, D = x.shape
        H = self.H
        dk = D // H
        idx, _ = self._mesh(x)

        Q = (x @ self.Wq[l]).reshape(S, H, dk).transpose(1, 0, 2)  # H,S,dk
        K = (x @ self.Wk[l]).reshape(S, H, dk).transpose(1, 0, 2)
        V = (x @ self.Wv[l]).reshape(S, H, dk).transpose(1, 0, 2)
        sc = np.matmul(Q, K.transpose(0, 2, 1)) / math.sqrt(dk)  # H,S,S

        # Temporal decay: scale each score by decay ** |i - j|.  The distance
        # matrix does not depend on the head, so build it once.
        pos = np.arange(S)
        dist = np.abs(pos[:, None] - pos[None, :])
        for h in range(H):
            sc[h] = sc[h] * (self.decay[l, h] ** dist)

        # Mesh masking: keep only the top-k neighbours' scores; every other
        # entry stays at -1e9 and so gets zero weight after the softmax.
        rows = pos[:, None]
        mask = np.full_like(sc, -1e9)
        mask[:, rows, idx] = sc[:, rows, idx]

        attn = _softmax(mask, axis=-1)  # H,S,S
        out = np.matmul(attn, V)  # H,S,dk
        out = out.transpose(1, 0, 2).reshape(S, D)
        out = out @ self.Wo[l]
        x = x + out  # residual around attention

        ff = np.maximum(0, x @ self.W1[l]) @ self.W2[l]  # ReLU feed-forward
        x = x + ff  # residual around feed-forward

        # Sigmoid gate: per-token confidence in (0, 1).
        conf = 1 / (1 + np.exp(-(x @ self.Wg[l]).squeeze(-1)))
        return x, conf, attn.mean(0)  # attn: S,S

    def forward(self, ids):
        """Run all layers over a token-id sequence with early exit.

        A token whose gate confidence exceeds ``THRESH`` is frozen: from that
        layer on (inclusive) its state is no longer overwritten.

        Args:
            ids: Sequence of integer token ids indexing the embedding table.

        Returns:
            ``(exits, confs, attns)``, three lists with one entry per layer:
            a float (S,) array that is 1.0 where a token first exited at that
            layer, the (S,) gate confidences, and the (S, S) head-averaged
            attention.
        """
        S = len(ids)
        if S == 0:
            # The softmax reduction is undefined on zero-length rows, so
            # return well-shaped empty results rather than raising.
            return (
                [np.zeros(0) for _ in range(self.L)],
                [np.zeros(0, dtype=np.float32) for _ in range(self.L)],
                [np.zeros((0, 0), dtype=np.float32) for _ in range(self.L)],
            )

        x = self.emb[ids].copy()
        frozen = np.zeros(S, dtype=bool)
        exits, confs, attns = [], [], []
        for l in range(self.L):
            xn, cf, aw = self._layer(x, l)
            ne = (~frozen) & (cf > self.THRESH)  # tokens exiting at this layer
            frozen = frozen | ne
            x = np.where(frozen[:, None], x, xn)
            exits.append(ne.astype(float))
            confs.append(cf.copy())
            attns.append(aw)
        return exits, confs, attns
# Single model instance, built once at import time and reused by analyse()
# for every request.
MODEL = NumpyTMT()
# ── Colour helpers ─────────────────────────────────────────────────────────────
# Dark-theme palette shared by every figure.
BG = "#0f172a"
CARD = "#1e293b"
GRID = "#334155"

# Colour and legend label for each word tier; the list index is the tier id.
TC = [
    "#22c55e",
    "#3b82f6",
    "#f59e0b",
    "#ef4444",
]
TL = [
    "Function words",
    "Common verbs",
    "Domain terms",
    "Complex/rare",
]

# Lower-cased word -> tier id (index into TC / TL).  Words that are missing
# here are looked up with a default of tier 1 by the plotting functions.
WTYPES = {
    word: tier
    for tier, vocabulary in enumerate((
        "the a an of in to and is are by on for with this that it at or not",
        "learned allow predict require adapts reduces "
        "operate jumps represent trained uses focus make",
        "neural network attention transformer semantic "
        "topology graph compute language model token deep",
        "dynamic adaptive complex patterns architecture "
        "routing decay mechanisms structured relevant rare",
    ))
    for word in vocabulary.split()
}

# Sentences used when the user submits an empty input.
SAMPLES = [
    "The neural network learned to represent complex patterns in the data",
    "Attention mechanisms allow transformers to focus on relevant tokens",
    "Dynamic graph topology adapts to the semantic content of the sequence",
    "Adaptive depth routing reduces compute by fifty percent on average",
    "Language models predict the next word given the previous context",
    "Graph neural networks operate over structured relational data",
    "The quick brown fox jumps over the lazy dog near the river",
    "Complex technical terminology requires deeper processing than simple words",
]
def tokenize(text, vocab_size=None):
    """Split ``text`` on whitespace and map each word to a stable token id.

    At most the first 32 words are kept.  Inputs with fewer than three words
    are padded from ("the", "model", "runs") up to three words.

    Ids are a CRC-32 of the lower-cased word rather than the built-in
    ``hash()``: string hashes are salted per process, which made the same
    sentence produce different ids (and different plots) after every restart.

    Args:
        text: Input sentence.
        vocab_size: Size of the embedding table; defaults to ``NumpyTMT.V``.

    Returns:
        ``(words, ids)`` where every id lies in ``[1, vocab_size - 2]``.
    """
    if vocab_size is None:
        vocab_size = NumpyTMT.V
    words = text.strip().split()[:32]
    if len(words) < 3:
        words = (words + ["the", "model", "runs"])[:3]
    ids = [
        zlib.crc32(w.lower().encode("utf-8", "replace")) % (vocab_size - 2) + 1
        for w in words
    ]
    return words, ids
def fig2arr(fig):
    """Rasterise a Matplotlib figure to an RGB NumPy array and close it.

    The figure is saved as a 100-dpi PNG into an in-memory buffer (tight
    bounding box, ``BG`` as face colour) and decoded with PIL.  The figure
    and the buffer are released even when saving or decoding raises, so a
    failed request does not leave figures open in pyplot.

    Args:
        fig: The Matplotlib figure to render.

    Returns:
        The rendered image as an array of RGB pixels.
    """
    try:
        with BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=100, bbox_inches="tight", facecolor=BG)
            buf.seek(0)
            # convert() decodes the image, and np.array() copies the pixels,
            # while the buffer is still open.
            return np.array(Image.open(buf).convert("RGB"))
    finally:
        plt.close(fig)
# ── Visualisations ─────────────────────────────────────────────────────────────
def plot_heatmap(words, exits, confs):
    """Draw the exit-gate heatmap next to the per-layer mean gate confidence.

    Args:
        words: Token strings; used as x tick labels (cut to 9 characters).
        exits: One array per layer, each with one value per token; drawn on a
            0..1 red-to-green scale.
        confs: One array per layer with the gate confidence of every token;
            averaged over tokens for the line plot.

    Returns:
        The rendered figure as an RGB array (see ``fig2arr``).
    """
    S = len(words)
    N = len(exits)
    mat = np.stack(exits, 0)  # N x S
    con = np.stack(confs, 0)  # N x S

    fig, (ax, ax2) = plt.subplots(1, 2, figsize=(13, max(3.5, S * 0.32 + 2)))
    fig.patch.set_facecolor(BG)

    # Left panel: layers (rows) x tokens (columns).
    ax.set_facecolor(CARD)
    im = ax.imshow(mat, aspect="auto", cmap="RdYlGn", vmin=0, vmax=1)
    ax.set_yticks(range(N))
    ax.set_yticklabels([f"L{i+1}" for i in range(N)], color="white", fontsize=9)
    ax.set_xticks(range(S))
    ax.set_xticklabels([w[:9] for w in words], rotation=45, ha="right", color="white", fontsize=8)
    # \u2014 is an em dash; written as an escape so it survives re-encoding.
    ax.set_title("Exit Gate \u2014 green = token frozen (compute saved)", color="white", fontsize=11, pad=8)
    # Use the figure's own methods: pyplot's "current figure" is global state
    # and may belong to another request when handlers run concurrently.
    fig.colorbar(im, ax=ax, fraction=0.03)
    for sp in ax.spines.values():
        sp.set_color(GRID)

    # Right panel: mean confidence per layer against the exit threshold.
    ax2.set_facecolor(CARD)
    avg = con.mean(axis=1)
    ls = list(range(1, N + 1))
    ax2.plot(ls, avg, "o-", color="#60a5fa", lw=2.5, ms=7)
    ax2.fill_between(ls, avg, alpha=0.2, color="#60a5fa")
    ax2.axhline(NumpyTMT.THRESH, color="#f59e0b", lw=1.5, ls="--", label=f"Threshold {NumpyTMT.THRESH}")
    ax2.set_xlabel("Layer", color="white", fontsize=10)
    ax2.set_ylabel("Avg Confidence", color="white", fontsize=10)
    ax2.set_title("Gate Confidence per Layer", color="white", fontsize=11)
    ax2.tick_params(colors="white")
    ax2.legend(fontsize=9, facecolor=CARD, labelcolor="white", edgecolor=GRID)
    for sp in ax2.spines.values():
        sp.set_color(GRID)

    fig.tight_layout()
    return fig2arr(fig)
def plot_graph(words, attns):
    """Draw the attention graph of up to three layers (first, middle, last).

    Tokens are placed evenly on a unit circle.  For every token, edges are
    drawn to the ``min(NumpyTMT.K, S - 1)`` targets with the highest attention
    weight; edge opacity and width grow with that weight.  Nodes are coloured
    by word tier (``WTYPES``, default tier 1).

    Args:
        words: Token strings (labels are cut to 8 characters).
        attns: One (S, S) attention matrix per layer.

    Returns:
        The rendered figure as an RGB array (see ``fig2arr``).
    """
    S = len(words)
    angles = np.linspace(0, 2 * np.pi, S, endpoint=False)
    pos = np.stack([np.cos(angles), np.sin(angles)], 1)

    total = len(attns)
    n = min(3, total)
    # \u2014 is an em dash; written as an escape so it survives re-encoding.
    if n == 1:
        idxs = [0]
        titles = ["Layer 1"]
    elif n == 2:
        idxs = [0, -1]
        titles = ["Layer 1 \u2014 Initial", f"Layer {total} \u2014 Final"]
    else:
        idxs = [0, total // 2, -1]
        titles = [
            "Layer 1 \u2014 Initial",
            f"Layer {total // 2 + 1} \u2014 Mid",
            f"Layer {total} \u2014 Final",
        ]

    # squeeze=False always yields a 2-D array of axes, even for one panel.
    fig, axes = plt.subplots(1, n, figsize=(5 * n, 5), squeeze=False)
    fig.patch.set_facecolor(BG)
    k = min(NumpyTMT.K, S - 1)

    for ax, li, title in zip(axes[0], idxs, titles):
        ax.set_facecolor(CARD)
        aw = attns[li]  # S x S
        for i in range(S):
            # Strongest k outgoing attention weights of token i.
            for j in np.argsort(aw[i])[::-1][:k]:
                w = float(aw[i, j])
                ax.plot([pos[i, 0], pos[j, 0]], [pos[i, 1], pos[j, 1]],
                        color="#3b82f6", alpha=min(0.9, w * 6 + 0.1), lw=max(0.4, w * 4))
        for i, word in enumerate(words):
            c = TC[WTYPES.get(word.lower(), 1)]
            ax.scatter(pos[i, 0], pos[i, 1], c=c, s=220, zorder=5, edgecolors="white", linewidths=1.2)
            # Labels sit just outside the circle of nodes.
            ax.text(pos[i, 0] * 1.3, pos[i, 1] * 1.3, word[:8],
                    ha="center", va="center", fontsize=7.5, color="white")
        ax.set_xlim(-1.7, 1.7)
        ax.set_ylim(-1.7, 1.7)
        ax.set_title(title, color="white", fontsize=10, pad=6)
        ax.axis("off")

    fig.legend(handles=[mpatches.Patch(color=TC[i], label=TL[i]) for i in range(4)],
               loc="lower center", ncol=4, fontsize=9,
               facecolor=CARD, labelcolor="white", edgecolor=GRID, bbox_to_anchor=(0.5, -0.05))
    # fig.tight_layout() instead of plt.tight_layout(): pyplot's current
    # figure is global and may belong to a concurrent request.
    fig.tight_layout()
    return fig2arr(fig)
def plot_depth(words, exits):
    """Draw a bar chart of how many layers each token used.

    A token's depth is the 1-based index of the first layer whose exit value
    is the column maximum, or the full layer count ``N`` when it never exits.

    Args:
        words: Token strings, used as x tick labels and for tier colours.
        exits: One array per layer with one exit value per token.

    Returns:
        The rendered figure as an RGB array (see ``fig2arr``).
    """
    S = len(words)
    N = len(exits)
    em = np.stack(exits, 0)  # N x S
    el = [int(np.argmax(em[:, i]) + 1) if em[:, i].max() > 0 else N for i in range(S)]

    fig, ax = plt.subplots(figsize=(max(8, S * 0.8), 5))
    fig.patch.set_facecolor(BG)
    ax.set_facecolor(CARD)
    bars = ax.bar(range(S), el,
                  color=[TC[WTYPES.get(w.lower(), 1)] for w in words],
                  alpha=0.9, edgecolor="white", linewidth=0.6)

    avg = float(np.mean(el))
    ax.axhline(N, color="#94a3b8", lw=1.5, ls="--", label=f"Max ({N}L)")
    ax.axhline(avg, color="#f59e0b", lw=2, ls="-.", label=f"Avg {avg:.1f}L = {avg/N*100:.0f}% compute")
    # Print each bar's depth just above it.
    for bar, val in zip(bars, el):
        ax.text(bar.get_x() + bar.get_width() / 2, val + 0.07, str(val),
                ha="center", va="bottom", fontsize=9, color="white", fontweight="bold")

    ax.set_xticks(range(S))
    ax.set_xticklabels(words, rotation=40, ha="right", color="white", fontsize=9)
    ax.set_ylabel("Layers used", color="white", fontsize=11)
    ax.set_ylim(0, N + 2)
    # \u2014 is an em dash; written as an escape so it survives re-encoding.
    ax.set_title("Per-Token Compute Depth \u2014 Simple exits early, Complex goes deep",
                 color="white", fontsize=12)
    ax.tick_params(colors="white")
    for sp in ax.spines.values():
        sp.set_color(GRID)

    patches = [mpatches.Patch(color=TC[i], label=TL[i]) for i in range(4)]
    patches.append(mpatches.Patch(color="#f59e0b", label=f"Avg {avg:.1f}L"))
    ax.legend(handles=patches, fontsize=9, facecolor=CARD, labelcolor="white", edgecolor=GRID, ncol=3)
    # fig.tight_layout() instead of plt.tight_layout(): pyplot's current
    # figure is global and may belong to a concurrent request.
    fig.tight_layout()
    return fig2arr(fig)
# ── Gradio entry point ────────────────────────────────────────────────────────
def analyse(sentence):
    """Gradio handler: run the model on ``sentence`` and build all outputs.

    An empty or whitespace-only input is replaced by a random entry of
    ``SAMPLES``.

    Args:
        sentence: Text entered by the user (may be empty or None).

    Returns:
        ``(heatmap, graph, depth, markdown)``: three RGB image arrays and a
        Markdown summary table with links.
    """
    if not sentence or not sentence.strip():
        sentence = random.choice(SAMPLES)
    words, ids = tokenize(sentence)
    exits, confs, attns = MODEL.forward(ids)

    img1 = plot_heatmap(words, exits, confs)
    img2 = plot_graph(words, attns)
    img3 = plot_depth(words, exits)

    S = len(words)
    N = len(exits)
    em = np.stack(exits, 0)
    # Depth per token: first layer flagged as an exit, or N if it never exits.
    el = [int(np.argmax(em[:, i]) + 1) if em[:, i].max() > 0 else N for i in range(S)]
    avg = float(np.mean(el))

    # Non-ASCII characters are written as escapes so they survive re-encoding:
    # \u2192 right arrow, \u00b7 middle dot, \U0001F4C4 page, \U0001F917
    # hugging face, \U0001F4BB laptop.
    md = (
        "### Analysis Results\n\n"
        "| Metric | Value |\n"
        "|:--|:--|\n"
        f"| Tokens | {S} |\n"
        f"| Avg depth | {avg:.1f} / {N} layers |\n"
        f"| Compute used | **{avg/N*100:.0f}%** |\n"
        f"| Compute saved | **{(1-avg/N)*100:.0f}%** |\n"
        f"| Earliest exit | `{words[int(np.argmin(el))]}` \u2192 layer {min(el)} |\n"
        f"| Deepest token | `{words[int(np.argmax(el))]}` \u2192 layer {max(el)} |\n\n"
        "\U0001F4C4 [Paper](https://doi.org/10.5281/zenodo.20287390) \u00b7 "
        "\U0001F917 [Model](https://huggingface.co/vigneshwar234/TemporalMesh-Transformer) \u00b7 "
        "\U0001F4BB [Code](https://github.com/vignesh2027/TemporalMesh-Transformer)"
    )
    return img1, img2, img3, md
# Gradio UI: one text box in, three images and a Markdown summary out.
# Non-ASCII characters are written as escapes so they survive re-encoding:
# \U0001F578\uFE0F spider web, \U0001F4C4 page, \U0001F917 hugging face,
# \U0001F4BB laptop, \U0001F4CA bar chart, \u00b7 middle dot.
demo = gr.Interface(
    fn=analyse,
    inputs=gr.Textbox(
        label="Enter a sentence",
        placeholder="e.g. The neural network learned to represent complex patterns",
        lines=2,
        value=SAMPLES[0],
    ),
    outputs=[
        gr.Image(label="Exit Gate Heatmap + Confidence", type="numpy"),
        gr.Image(label="Dynamic Attention Graph", type="numpy"),
        gr.Image(label="Per-Token Compute Depth", type="numpy"),
        gr.Markdown(label="Stats"),
    ],
    title="\U0001F578\uFE0F TemporalMesh Transformer Demo",
    description=(
        "Visualise **dynamic graph attention**, **temporal decay**, and "
        "**per-token adaptive depth routing** on any sentence.\n\n"
        "\U0001F4C4 [Paper (Zenodo)](https://doi.org/10.5281/zenodo.20287390) \u00b7 "
        "\U0001F917 [Model Card](https://huggingface.co/vigneshwar234/TemporalMesh-Transformer) \u00b7 "
        "\U0001F4BB [GitHub](https://github.com/vignesh2027/TemporalMesh-Transformer) \u00b7 "
        "\U0001F4CA [Dataset](https://huggingface.co/datasets/vigneshwar234/TMT-Benchmarks)"
    ),
    # The first six sample sentences double as clickable examples.
    examples=[[sample] for sample in SAMPLES[:6]],
    cache_examples=False,
    flagging_mode="never",
)
demo.launch()