Spaces:
Runtime error
Runtime error
| # app.py | |
| import json | |
| from pathlib import Path | |
| import threading, time | |
| import solara | |
| import torch | |
| import torch.nn.functional as F | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# ------------------ Model ------------------
# Hub checkpoint id; the tokenizer and the causal LM are both loaded from it
# so their vocabularies match.
MODEL_ID = "Qwen/Qwen3-0.6B"
# Loaded once at import time and shared by every prediction call below.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# ------------------ Theme ------------------
# Stylesheet injected into the page via solara.Style: colour variables on
# :root, plus the classes used by the prediction rows (.rowbtn, .rowgrid,
# .rowprob, .rowid, .rowtok, .listheader) and the neighbor chips (.badge).
theme_css = """
:root{
--primary:#38bdf8; /* light blue */
--bg:#ffffff; /* white */
--text:#0b0f14; /* near-black */
--muted:#6b7280; /* gray-500 */
--border:#e5e7eb; /* gray-200 */
--pale: rgba(56,189,248,0.15); /* base dots */
--mid: rgba(56,189,248,0.65); /* neighbors */
--bright: rgba(34,211,238,1.0); /* selected */
}
body{ background:var(--bg); color:var(--text);}
h1,h2,h3{ color:var(--text); }
hr{ border-color:var(--border); }
.btn{ background:var(--primary); color:#000; border:1px solid var(--primary); padding:6px 10px; border-radius:10px; }
.badge{ display:inline-block; padding:2px 8px; border:1px solid var(--border); border-radius:999px; margin:2px; }
.rowbtn{
width:100%; text-align:left; padding:10px 12px; border-radius:12px;
border:1px solid var(--border); background:#fff;
}
.rowbtn:hover{ background:#f7fbff; border-color:#c3e8fb; cursor:pointer; }
.rowgrid{ display:grid; grid-template-columns: 70px 120px 1fr; gap:10px; align-items:center; }
.rowprob{ color:#111; font-weight:600; }
.rowid{ color:var(--muted); font-variant-numeric: tabular-nums; }
.rowtok{ color:#111; overflow:hidden; text-overflow:ellipsis; white-space:nowrap; }
.listheader{ color:var(--muted); font-size:12px; margin-bottom:6px;}
"""
# ------------------ App state ------------------
# Prompt text bound to the input box; every prediction is computed from it.
text_rx = solara.reactive("twinkle, twinkle, little ")

# Latest candidate table (formatted probability, token id, decoded token).
preds_rx = solara.reactive(pd.DataFrame(columns=["probs", "id", "tok"]))

# Token id currently highlighted on the map (None until something is chosen).
selected_token_id_rx = solara.reactive(None)

# Chips shown under the map: list of (tok, sim)
neighbor_list_rx = solara.reactive([])

# Status / hint line rendered under the input box.
notice_rx = solara.reactive("Click a candidate (or hover to preview).")

# Toggle if you ever want to disable auto-predict.
auto_running_rx = solara.reactive(True)
# ------------------ Embedding assets ------------------
# Location of the two JSON files that back the embedding map.
ASSETS = Path("assets/embeddings")
COORDS_PATH = ASSETS / "pca_top5k_coords.json"
NEIGH_PATH = ASSETS / "neighbors_top5k_k40.json"

# Defaults used when the files are absent: the map is simply disabled.
coords = {}
neighbors = {}
ids_set = set()

if not (COORDS_PATH.exists() and NEIGH_PATH.exists()):
    notice_rx.set("Embedding files not found — add assets/embeddings/*.json to enable the map.")
else:
    coords = json.loads(COORDS_PATH.read_text("utf-8"))
    neighbors = json.loads(NEIGH_PATH.read_text("utf-8"))
    # JSON object keys are strings; keep an int view for fast membership tests.
    ids_set = {int(key) for key in coords}
# ------------------ Predict ------------------
def predict_top10(prompt: str, k: int = 10) -> pd.DataFrame:
    """Return the ``k`` most likely next tokens for ``prompt``.

    The distribution is read from the logits at the last position of a single
    forward pass. The previous ``generate()`` call combined ``do_sample=False``
    with sampling-only arguments (``temperature=0.0``, ``top_k=1``,
    ``top_p=1.0``), which greedy decoding ignores; a plain forward pass needs
    no generation settings.

    Args:
        prompt: Text to continue. An empty prompt yields an empty frame.
        k: Number of candidates to return (default 10). Values are clamped to
            the range ``[0, vocabulary size]``.

    Returns:
        DataFrame with columns ``probs`` (percentage string, two decimals),
        ``id`` (token id) and ``tok`` (decoded token text), ordered from most
        to least likely.
    """
    if not prompt:
        return pd.DataFrame(columns=["probs", "id", "tok"])

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        next_logits = model(input_ids).logits[0, -1]  # [vocab]

    # Softmax in float32 so half-precision checkpoints do not lose resolution.
    probs = torch.softmax(next_logits.float(), dim=-1)
    top = torch.topk(probs, max(0, min(k, probs.shape[-1])))

    ids = top.indices.tolist()
    return pd.DataFrame({
        "probs": [f"{p:.2%}" for p in top.values.tolist()],
        "id": ids,
        "tok": [tokenizer.decode([i]) for i in ids],
    })
def on_predict():
    """Recompute candidates for the current text and highlight the top one."""
    predictions = predict_top10(text_rx.value)
    preds_rx.set(predictions)
    if predictions.empty:
        return
    # Highlight top-1 by default.
    preview_token(int(predictions["id"].iloc[0]))
# ------------------ Plotly figure ------------------
def base_scatter():
    """Build the background figure: every known coordinate as a pale dot.

    Returns a Plotly figure with hidden axes and no legend. When no
    coordinates are loaded the figure carries only the layout.
    """
    figure = go.Figure()
    if coords:
        # coords maps token id -> (x, y); transpose into two parallel tuples.
        xs, ys = zip(*coords.values())
        background = go.Scattergl(
            x=xs,
            y=ys,
            mode="markers",
            marker=dict(size=3, opacity=1.0, color="rgba(56,189,248,0.15)"),
            hoverinfo="skip",
        )
        figure.add_trace(background)
    figure.update_layout(
        height=460,
        margin=dict(l=10, r=10, t=10, b=10),
        paper_bgcolor="white",
        plot_bgcolor="white",
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        showlegend=False,
    )
    return figure


# Figure currently shown on the page; starts as the plain background.
fig_rx = solara.reactive(base_scatter())
def get_neighbor_list(token_id: int, k: int = 18):
    """Return up to ``k`` stored neighbors of ``token_id``.

    Each entry is a ``[neighbor_id, similarity]`` pair taken from the
    ``"neighbors"`` table of the loaded JSON. Tokens that are not on the map
    (or a map that was never loaded) give an empty list.
    """
    if not (ids_set and token_id in ids_set):
        return []
    table = neighbors.get("neighbors", {})
    entries = table.get(str(token_id), [])
    return entries[:k]  # [ [nid, sim], ... ]
def highlight(token_id: int):
    """Return a figure with ``token_id`` and its neighbors emphasised.

    Side effect: ``neighbor_list_rx`` is replaced with the decoded
    ``(token, similarity)`` chips for the neighbors, or with an empty list
    when the token is not on the map or has no stored neighbors.
    """
    figure = base_scatter()
    if not (coords and token_id in ids_set):
        neighbor_list_rx.set([])
        return figure

    # Neighbors: mid-bright markers drawn above the pale background.
    nearest = get_neighbor_list(token_id, k=20)
    if nearest:
        points = [coords[str(nid)] for nid, _ in nearest]
        neighbor_trace = go.Scattergl(
            x=[point[0] for point in points],
            y=[point[1] for point in points],
            mode="markers",
            marker=dict(size=6, color="rgba(56,189,248,0.75)", symbol="circle"),
            hoverinfo="skip",
        )
        figure.add_trace(neighbor_trace)

    # Chip list mirrors the neighbors; it is empty when there are none.
    neighbor_list_rx.set(
        [(tokenizer.decode([int(nid)]), float(sim)) for nid, sim in nearest]
    )

    # Target: drawn last so it sits on top of everything else.
    target_x, target_y = coords[str(token_id)]
    target_trace = go.Scattergl(
        x=[target_x],
        y=[target_y],
        mode="markers",
        marker=dict(size=10, color="rgba(34,211,238,1.0)", line=dict(width=1)),
        hoverinfo="skip",
    )
    figure.add_trace(target_trace)
    return figure
def preview_token(token_id: int):
    """Mark ``token_id`` as selected and redraw the map around it."""
    tid = int(token_id)
    selected_token_id_rx.set(tid)
    fig_rx.set(highlight(tid))
def append_token(token_id: int):
    """Append the decoded token to the prompt, highlight it, and re-predict."""
    tid = int(token_id)
    piece = tokenizer.decode([tid])
    text_rx.set(text_rx.value + piece)
    # Highlight the token that was just appended ...
    preview_token(tid)
    # ... then refresh the top-10 for the new text.
    on_predict()
# ------------------ Auto-predict on typing (debounced) ------------------
@solara.component
def AutoPredictWatcher():
    """Invisible component that re-runs the prediction ~250 ms after typing stops.

    It must be a Solara component because it uses the ``use_effect`` hook,
    which only works while a component is being rendered.

    Each change of the text (or of the auto-run toggle) starts a short-lived
    daemon thread. The effect's cleanup cancels the previous thread, so only
    the most recent edit results in a call to ``on_predict``.
    """
    text = text_rx.value
    auto = auto_running_rx.value

    def effect():
        if not auto:
            return None

        # An Event (rather than a plain flag plus sleep) lets a cancelled
        # worker wake up and exit immediately instead of sleeping it out.
        cancelled = threading.Event()

        def worker():
            # wait() returns True when the event was set during the debounce
            # window, i.e. a newer edit superseded this one.
            if cancelled.wait(0.25):
                return
            # Skip the work if the text moved on while we were waiting.
            if text == text_rx.value:
                on_predict()

        threading.Thread(target=worker, daemon=True).start()
        return cancelled.set

    solara.use_effect(effect, [text, auto])
    return solara.Text("", style={"display": "none"})  # nothing visible
# ------------------ UI: custom clickable/hoverable list ------------------
@solara.component
def PredictionRow(rank, token_id: int, tok: str, prob: str):
    """One full-width candidate row: hover previews, click appends.

    ``solara.Button`` documents no ``html`` or ``on_mouse_enter`` parameters,
    so the row markup is supplied as an HTML child and the hover handler is
    attached to the underlying widget with ``solara.v.use_event``. That call
    is a hook, which is why each row is its own component instead of a loop
    body.
    """
    from html import escape

    # Decoded tokens can contain "<", ">" or "&"; escape them so they are
    # shown literally instead of being parsed as markup.
    label = (
        f'<div class="rowgrid"><div class="rowprob">{rank} {prob}</div>'
        f'<div class="rowid">{token_id}</div>'
        f'<div class="rowtok">{escape(str(tok))}</div></div>'
    )
    with solara.Button(
        on_click=lambda: append_token(token_id),
        classes=["rowbtn"],
    ) as row_button:
        solara.HTML(tag="div", unsafe_innerHTML=label)
    # The handler ignores whatever event payload it is called with.
    solara.v.use_event(row_button, "mouseenter", lambda *_: preview_token(token_id))


@solara.component
def PredictionsList():
    """Render the current candidates from ``preds_rx`` as clickable rows."""
    df = preds_rx.value
    with solara.Column(gap="6px", style={"maxWidth": "720px"}):
        solara.Markdown("### Prediction")
        solara.HTML(tag="div", unsafe_innerHTML='<div class="listheader"># probs token id predicted next token</div>')
        for i, row in df.iterrows():
            PredictionRow(i, int(row["id"]), row["tok"], row["probs"])
# ------------------ Page ------------------
@solara.component
def Page():
    """Top-level page: prompt input, candidate list, and neighborhood map.

    Declared as a Solara component so the ``with`` containers and the hook
    used by ``AutoPredictWatcher`` run inside a render context. The function
    is intentionally not called at module level; Solara is expected to pick
    up the component named ``Page`` and render it.
    """
    from html import escape

    solara.Style(theme_css)
    with solara.Column(margin=12, gap="16px"):
        solara.Markdown("# Next-Token Predictor + Semantic Neighborhood")
        solara.Markdown(
            "Type text to see predictions update automatically. "
            "Click a candidate to append it and highlight its **semantic neighborhood**. "
            "Hover a candidate to preview its neighborhood."
        )
        # Input row (no Predict button needed anymore)
        solara.InputText("Enter text", value=text_rx, continuous_update=True, style={"minWidth": "520px"})
        solara.Markdown(f"*{notice_rx.value}*")
        with solara.Row(gap="24px", style={"align-items": "flex-start"}):
            with solara.Column():
                PredictionsList()
            with solara.Column():
                solara.Markdown("### Semantic Neighborhood")
                if not coords:
                    solara.Markdown("> Embedding map unavailable – add `assets/embeddings/*.json`.")
                else:
                    solara.FigurePlotly(fig_rx.value)
                    # Neighbor chips
                    if neighbor_list_rx.value:
                        solara.Markdown("**Nearest neighbors:**")
                        with solara.Row(style={"flex-wrap": "wrap"}):
                            for tok, sim in neighbor_list_rx.value:
                                # Token text goes into raw HTML, so escape it.
                                chip = f'<span class="badge">{escape(str(tok))} {(sim*100):.1f}%</span>'
                                solara.HTML(tag="span", unsafe_innerHTML=chip)
        AutoPredictWatcher()  # invisible component that triggers auto-predict


# Seed initial predictions (runs the model once when the module is imported).
on_predict()