PeterPinetree's picture
Update app.py
e858eb0 verified
raw
history blame
9.62 kB
# app.py
import json
from pathlib import Path
import threading, time
import solara
import torch
import torch.nn.functional as F
import pandas as pd
import plotly.graph_objects as go
from transformers import AutoTokenizer, AutoModelForCausalLM
# ------------------ Model ------------------
# Hub identifier of the checkpoint; the tokenizer and the causal LM are both
# loaded from it once, at import time, and shared by every function below.
MODEL_ID = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
# ------------------ Theme ------------------
# Stylesheet injected by Page() through solara.Style. It declares the color
# variables and the classes referenced by the HTML built in this file:
# .listheader/.rowbtn/.rowgrid/.rowprob/.rowid/.rowtok for the prediction rows
# and .badge for the neighbor chips.
theme_css = """
:root{
--primary:#38bdf8; /* light blue */
--bg:#ffffff; /* white */
--text:#0b0f14; /* near-black */
--muted:#6b7280; /* gray-500 */
--border:#e5e7eb; /* gray-200 */
--pale: rgba(56,189,248,0.15); /* base dots */
--mid: rgba(56,189,248,0.65); /* neighbors */
--bright: rgba(34,211,238,1.0); /* selected */
}
body{ background:var(--bg); color:var(--text);}
h1,h2,h3{ color:var(--text); }
hr{ border-color:var(--border); }
.btn{ background:var(--primary); color:#000; border:1px solid var(--primary); padding:6px 10px; border-radius:10px; }
.badge{ display:inline-block; padding:2px 8px; border:1px solid var(--border); border-radius:999px; margin:2px; }
.rowbtn{
width:100%; text-align:left; padding:10px 12px; border-radius:12px;
border:1px solid var(--border); background:#fff;
}
.rowbtn:hover{ background:#f7fbff; border-color:#c3e8fb; cursor:pointer; }
.rowgrid{ display:grid; grid-template-columns: 70px 120px 1fr; gap:10px; align-items:center; }
.rowprob{ color:#111; font-weight:600; }
.rowid{ color:var(--muted); font-variant-numeric: tabular-nums; }
.rowtok{ color:#111; overflow:hidden; text-overflow:ellipsis; white-space:nowrap; }
.listheader{ color:var(--muted); font-size:12px; margin-bottom:6px;}
"""
# ------------------ App state ------------------
# Module-level reactive variables shared by the components below.
# Current prompt text; bound to the input box and extended by append_token().
text_rx = solara.reactive("twinkle, twinkle, little ")
# Latest prediction table (columns: probs, id, tok), written by on_predict().
preds_rx = solara.reactive(pd.DataFrame(columns=["probs", "id", "tok"]))
# Id of the token currently highlighted on the map; set by preview_token().
selected_token_id_rx = solara.reactive(None)
neighbor_list_rx = solara.reactive([]) # list of (tok, sim)
# Status line rendered in italics under the input box by Page().
notice_rx = solara.reactive("Click a candidate (or hover to preview).")
auto_running_rx = solara.reactive(True) # toggle if you ever want to disable auto-predict
# ------------------ Embedding assets ------------------
ASSETS = Path("assets/embeddings")
COORDS_PATH = ASSETS / "pca_top5k_coords.json"
NEIGH_PATH = ASSETS / "neighbors_top5k_k40.json"

# Fallbacks used when the asset files are missing: an empty map and no neighbors.
# coords:    token-id string -> [x, y] position on the 2-D map
# neighbors: JSON document whose "neighbors" entry maps token-id string -> [[nid, sim], ...]
# ids_set:   integer token ids that have a position in coords
coords = {}
neighbors = {}
ids_set = set()

if not (COORDS_PATH.exists() and NEIGH_PATH.exists()):
    # Both files are required; tell the user how to enable the map.
    notice_rx.set("Embedding files not found — add assets/embeddings/*.json to enable the map.")
else:
    coords = json.loads(COORDS_PATH.read_text(encoding="utf-8"))
    neighbors = json.loads(NEIGH_PATH.read_text(encoding="utf-8"))
    ids_set = {int(key) for key in coords}
# ------------------ Predict ------------------
def predict_top10(prompt: str, k: int = 10) -> pd.DataFrame:
    """Return the ``k`` most likely next tokens for ``prompt``.

    The next-token distribution is read from a single forward pass of the
    model (softmax over the logits at the last position). This replaces a
    one-step ``generate`` call that mixed greedy decoding with sampling-only
    flags (``temperature=0.0``, ``top_k``, ``top_p``) merely to obtain scores.

    Args:
        prompt: Text to continue. An empty prompt yields an empty table.
        k: Number of candidates to return (default 10, matching the original
            behavior). Clamped to the vocabulary size.

    Returns:
        DataFrame with columns ``probs`` (probability formatted as a
        percentage string, e.g. ``"12.34%"``), ``id`` (token id) and ``tok``
        (decoded token text), ordered from most to least likely.
    """
    columns = ["probs", "id", "tok"]
    if not prompt:
        return pd.DataFrame(columns=columns)

    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    if input_ids.shape[-1] == 0:
        # Nothing to condition on; there is no "last position" to read.
        return pd.DataFrame(columns=columns)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        next_logits = model(input_ids).logits[0, -1]  # logits for the next token

    probs = torch.softmax(next_logits.float(), dim=-1)
    top = torch.topk(probs, max(0, min(k, probs.shape[-1])))
    ids = [int(i) for i in top.indices.tolist()]

    return pd.DataFrame({
        "probs": [f"{p:.2%}" for p in top.values.tolist()],
        "id": ids,
        "tok": [tokenizer.decode([i]) for i in ids],
    })
def on_predict():
    """Refresh the prediction table for the current text and preview its top row."""
    predictions = predict_top10(text_rx.value)
    preds_rx.set(predictions)
    if len(predictions) == 0:
        return
    # Highlight the most likely candidate by default.
    best_id = int(predictions.iloc[0]["id"])
    preview_token(best_id)
# ------------------ Plotly figure ------------------
def base_scatter():
    """Build the un-highlighted map: one pale dot per entry in ``coords``.

    When ``coords`` is empty the figure has no traces; the layout (fixed
    height, white background, hidden axes, no legend) is applied either way.
    """
    figure = go.Figure()
    if coords:
        xs, ys = zip(*coords.values())
        background = go.Scattergl(
            x=xs,
            y=ys,
            mode="markers",
            marker={"size": 3, "opacity": 1.0, "color": "rgba(56,189,248,0.15)"},
            hoverinfo="skip",
        )
        figure.add_trace(background)
    figure.update_layout(
        height=460,
        margin={"l": 10, "r": 10, "t": 10, "b": 10},
        paper_bgcolor="white",
        plot_bgcolor="white",
        xaxis={"visible": False},
        yaxis={"visible": False},
        showlegend=False,
    )
    return figure
# Figure currently shown in the map panel; starts as the plain base scatter and
# is replaced by preview_token() with a highlighted version.
fig_rx = solara.reactive(base_scatter())
def get_neighbor_list(token_id: int, k: int = 18):
    """Return up to ``k`` stored neighbor entries for ``token_id``.

    Entries are taken in stored order from ``neighbors["neighbors"]`` and have
    the form ``[neighbor_id, similarity]``. An empty list is returned when no
    embedding ids are loaded, when the token is not on the map, or when it has
    no neighbor entry.
    """
    on_map = bool(ids_set) and token_id in ids_set
    if not on_map:
        return []
    table = neighbors.get("neighbors", {})
    entries = table.get(str(token_id), [])
    return entries[:k]  # [ [nid, sim], ... ]
def highlight(token_id: int):
    """Return the map with ``token_id`` and its neighbors emphasized.

    Side effect: ``neighbor_list_rx`` is updated to the ``(token text,
    similarity)`` pairs of the drawn neighbors, or to an empty list when the
    token is not on the map or has no neighbors.
    """
    figure = base_scatter()
    if not (coords and token_id in ids_set):
        neighbor_list_rx.set([])
        return figure

    # Neighbor layer (drawn first so the target dot ends up on top).
    nearby = get_neighbor_list(token_id, k=20)
    if not nearby:
        neighbor_list_rx.set([])
    else:
        positions = [coords[str(nid)] for nid, _ in nearby]
        figure.add_trace(go.Scattergl(
            x=[pos[0] for pos in positions],
            y=[pos[1] for pos in positions],
            mode="markers",
            marker={"size": 6, "color": "rgba(56,189,248,0.75)", "symbol": "circle"},
            hoverinfo="skip",
        ))
        # Chip list shown under the map.
        neighbor_list_rx.set(
            [(tokenizer.decode([int(nid)]), float(sim)) for nid, sim in nearby]
        )

    # Target layer.
    target_x, target_y = coords[str(token_id)]
    figure.add_trace(go.Scattergl(
        x=[target_x],
        y=[target_y],
        mode="markers",
        marker={"size": 10, "color": "rgba(34,211,238,1.0)", "line": {"width": 1}},
        hoverinfo="skip",
    ))
    return figure
def preview_token(token_id: int):
    """Mark ``token_id`` as selected and publish its highlighted map figure."""
    tid = int(token_id)
    selected_token_id_rx.set(tid)
    highlighted = highlight(tid)
    fig_rx.set(highlighted)
def append_token(token_id: int):
    """Append the decoded token to the prompt, highlight it, then re-predict.

    Note that ``on_predict`` itself previews the top candidate of the new
    prediction table whenever that table is non-empty.
    """
    tid = int(token_id)
    piece = tokenizer.decode([tid])
    text_rx.set(f"{text_rx.value}{piece}")
    preview_token(tid)  # highlight the token that was just appended
    on_predict()  # refresh the candidate list for the extended text
# ------------------ Auto-predict on typing (debounced) ------------------
@solara.component
def AutoPredictWatcher():
    """Invisible component that re-runs prediction shortly after the text changes.

    Each change of the text (or of the auto flag) starts a daemon thread that
    waits 0.25 s and then calls ``on_predict`` only if the effect has not been
    cleaned up in the meantime and the text is still the one it was started for.
    """
    current_text = text_rx.value
    enabled = auto_running_rx.value

    def schedule_prediction():
        if not enabled:
            return None

        # Mutable holder so the cleanup callback can signal the waiting thread.
        state = {"cancelled": False}

        def cancel():
            state["cancelled"] = True

        def delayed_predict():
            time.sleep(0.25)  # debounce ~250ms
            if state["cancelled"]:
                return
            if current_text == text_rx.value:
                on_predict()

        threading.Thread(target=delayed_predict, daemon=True).start()
        return cancel

    solara.use_effect(schedule_prediction, [current_text, enabled])
    return solara.Text("", style={"display":"none"})  # nothing visible
# ------------------ UI: custom clickable/hoverable list ------------------
@solara.component
def PredictionsList():
    """Render the current predictions as a column of full-width row buttons.

    Each row shows rank, probability, token id and token text. Hovering a row
    previews the token on the map; clicking appends it to the prompt.
    """
    # Local import: decoded tokens are arbitrary text ("<", "&", "</div>", ...)
    # and must be escaped before being placed inside raw HTML.
    from html import escape

    df = preds_rx.value
    with solara.Column(gap="6px", style={"maxWidth":"720px"}):
        solara.Markdown("### Prediction")
        solara.HTML(tag="div", unsafe_innerHTML='<div class="listheader"># &nbsp;&nbsp; probs &nbsp;&nbsp; token id &nbsp;&nbsp; predicted next token</div>')
        for i, row in df.iterrows():
            tid = int(row["id"])
            tok = escape(str(row["tok"]))
            prob = row["probs"]
            label = f'<div class="rowgrid"><div class="rowprob">{i}&nbsp;&nbsp;{prob}</div><div class="rowid">{tid}</div><div class="rowtok">{tok}</div></div>'
            # full-width button acts like a row; hover previews, click appends.
            # `tid=tid` binds the current id so every row keeps its own token.
            solara.Button(
                None,
                on_click=lambda tid=tid: append_token(tid),
                on_mouse_enter=lambda tid=tid: preview_token(tid),
                html=label,
                classes=["rowbtn"],
            )
# ------------------ Page ------------------
@solara.component
def Page():
    """Top-level page: text input, prediction list, embedding map and neighbor chips."""
    # Local import: neighbor token text is arbitrary and is inserted into raw
    # HTML below, so it must be escaped.
    from html import escape

    solara.Style(theme_css)
    with solara.Column(margin=12, gap="16px"):
        solara.Markdown("# Next-Token Predictor + Semantic Neighborhood")
        solara.Markdown(
            "Type text to see predictions update automatically. "
            "Click a candidate to append it and highlight its **semantic neighborhood**. "
            "Hover a candidate to preview its neighborhood."
        )
        # Input row (no Predict button needed anymore)
        solara.InputText("Enter text", value=text_rx, continuous_update=True, style={"minWidth":"520px"})
        solara.Markdown(f"*{notice_rx.value}*")
        with solara.Row(gap="24px", style={"align-items": "flex-start"}):
            with solara.Column():
                PredictionsList()
            with solara.Column():
                solara.Markdown("### Semantic Neighborhood")
                if not coords:
                    solara.Markdown("> Embedding map unavailable – add `assets/embeddings/*.json`.")
                else:
                    solara.FigurePlotly(fig_rx.value)
                # Neighbor chips
                if neighbor_list_rx.value:
                    solara.Markdown("**Nearest neighbors:**")
                    with solara.Row(style={"flex-wrap": "wrap"}):
                        for tok, sim in neighbor_list_rx.value:
                            solara.HTML(tag="span", unsafe_innerHTML=f'<span class="badge">{escape(str(tok))} &nbsp; {(sim*100):.1f}%</span>')
        AutoPredictWatcher()  # invisible component that triggers auto-predict
# Seed initial predictions
# Executed once when the module is imported: fills preds_rx for the default
# text and previews the top candidate.
on_predict()
# NOTE(review): this builds a Page element at import time; presumably kept so
# the app also displays when the file is run as a notebook cell — confirm.
Page()