PeterPinetree commited on
Commit
931a76f
·
verified ·
1 Parent(s): 921ac50

Update app.py

Browse files

Auto-predict
A tiny AutoPredictWatcher() component debounces text_rx changes (~250 ms) and calls on_predict() automatically. No more Predict button.

Click anywhere on a row
The top-10 is now a clean, custom list of full-width “row buttons” (styled to look like rows). Clicking anywhere in the row calls append_token(tid).

Clearer neighborhood
The Plotly figure now has three layers:

Pale cloud (all tokens): rgba(56,189,248,0.15)

Darker neighbors (10–20): rgba(56,189,248,0.75)

Bright selected token: rgba(34,211,238,1.0)

Hover preview + neighbor chips
Each row has on_mouse_enter that calls preview_token(tid), which updates the plot and populates a right-side chip list with the 10–20 closest neighbors and their cosine %.

Files changed (1) hide show
  1. app.py +182 -137
app.py CHANGED
@@ -1,65 +1,78 @@
1
  # app.py
2
  import json
3
- import random
4
  from pathlib import Path
 
5
 
6
  import solara
7
- import pandas as pd
8
  import torch
9
  import torch.nn.functional as F
 
 
10
  from transformers import AutoTokenizer, AutoModelForCausalLM
11
 
12
- # ---- Model (same as original Space) -----------------------------------------
13
  MODEL_ID = "Qwen/Qwen3-0.6B"
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
15
  model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
16
 
17
- # ---- App state ---------------------------------------------------------------
18
- text_rx = solara.reactive("twinkle, twinkle, little ")
19
- top10_rx = solara.reactive(pd.DataFrame(columns=["probs", "next token ID", "predicted next token"]))
20
- selected_token_id_rx = solara.reactive(None) # for neighborhood focus
21
- notice_rx = solara.reactive("Enter text to see predictions.")
22
- theme_css = solara.reactive("""
23
- <style>
24
- :root {
25
- --primary: #38bdf8; /* light blue */
26
- --bg: #ffffff; /* white */
27
- --text: #000000; /* black */
28
- --muted: #6b7280; /* gray-500 */
29
- --border: #e5e7eb; /* gray-200 */
30
  }
31
- body { background: var(--bg); color: var(--text); }
32
- h1, h2, h3 { color: var(--text); }
33
- table td, table th { border-color: var(--border) !important; }
34
- .solara-dataframe .MuiTableCell-root { font-size: 14px; }
35
- .btn-primary { background: var(--primary); color: #000; border: 1px solid var(--primary); padding: 6px 10px; border-radius: 8px; }
36
- .badge { display:inline-block; padding:2px 8px; border:1px solid var(--border); border-radius:999px; color:var(--text); }
37
- </style>
38
- """)
39
-
40
- # ---- Load embedding assets (your files) --------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  ASSETS = Path("assets/embeddings")
42
  COORDS_PATH = ASSETS / "pca_top5k_coords.json"
43
- NEIGH_PATH = ASSETS / "neighbors_top5k_k40.json"
44
 
45
  coords = {}
46
  neighbors = {}
47
  ids_set = set()
48
 
49
  if COORDS_PATH.exists() and NEIGH_PATH.exists():
50
- with COORDS_PATH.open("r", encoding="utf-8") as f:
51
- coords = json.load(f) # {token_id: [x, y], ...}
52
- with NEIGH_PATH.open("r", encoding="utf-8") as f:
53
- neighbors = json.load(f) # {"neighbors": {token_id: [[nid, sim], ...]}}
54
  ids_set = set(map(int, coords.keys()))
55
  else:
56
- notice_rx.set("Embedding files not found. Add assets/embeddings/*.json to enable the map.")
57
 
58
- # ---- Helpers -----------------------------------------------------------------
59
  def predict_top10(prompt: str) -> pd.DataFrame:
60
  if not prompt:
61
- return pd.DataFrame(columns=["probs", "next token ID", "predicted next token"])
62
-
63
  tokens = tokenizer.encode(prompt, return_tensors="pt")
64
  out = model.generate(
65
  tokens,
@@ -67,148 +80,180 @@ def predict_top10(prompt: str) -> pd.DataFrame:
67
  output_scores=True,
68
  return_dict_in_generate=True,
69
  pad_token_id=tokenizer.eos_token_id,
70
- do_sample=False, # greedy (deterministic)
71
  temperature=0.0,
72
  top_k=1,
73
  top_p=1.0,
74
  )
75
- scores = F.softmax(out.scores[0], dim=-1) # [1, vocab]
76
- top_10 = torch.topk(scores, 10)
77
-
78
- df = pd.DataFrame()
79
- df["probs"] = top_10.values[0].detach().cpu().numpy()
80
- df["probs"] = [f"{p:.2%}" for p in df["probs"]]
81
- ids = [int(top_10.indices[0][i].detach().cpu().item()) for i in range(10)]
82
- df["next token ID"] = ids
83
- df["predicted next token"] = [tokenizer.decode([i]) for i in ids]
84
  return df
85
 
86
- def get_neighbor_list(token_id: int, k: int = 18):
87
- if not ids_set or token_id not in ids_set:
88
- return []
89
- raw = neighbors.get("neighbors", {}).get(str(token_id), [])
90
- # raw item is [nid, sim]; keep top k
91
- return raw[:k]
92
-
93
- # ---- Plot (Plotly scatter) ---------------------------------------------------
94
- # We generate a static "all points" scatter once, then reuse it with highlights.
95
- import plotly.graph_objects as go
96
 
 
97
  def base_scatter():
98
- if not coords:
99
- return go.Figure().update_layout(
100
- height=440, margin=dict(l=10, r=10, t=10, b=10),
101
- paper_bgcolor="white", plot_bgcolor="white",
102
- )
103
- # unpack coordinates
104
- xs, ys, tids = [], [], []
105
- for tid_str, pt in coords.items():
106
- xs.append(pt[0]); ys.append(pt[1]); tids.append(int(tid_str))
107
  fig = go.Figure()
108
- fig.add_trace(go.Scattergl(
109
- x=xs, y=ys, mode="markers",
110
- marker=dict(size=3, opacity=0.85),
111
- text=[f"id {t}" for t in tids],
112
- hoverinfo="skip", # keep hover minimal; we’ll show neighbors explicitly
113
- ))
 
114
  fig.update_layout(
115
- height=440, margin=dict(l=10, r=10, t=10, b=10),
116
  paper_bgcolor="white", plot_bgcolor="white",
117
  xaxis=dict(visible=False), yaxis=dict(visible=False),
 
118
  )
119
  return fig
120
 
121
- base_fig = base_scatter()
122
- fig_rx = solara.reactive(base_fig)
123
 
124
- def highlight(token_id: int):
125
- """Return a new figure with neighbors + target highlighted."""
126
- fig = base_fig.to_dict() # detach copy
127
- fig = go.Figure(fig)
 
128
 
 
 
 
129
  if not coords or token_id not in ids_set:
 
130
  return fig
131
 
132
- # Target
133
- tx, ty = coords[str(token_id)]
134
- fig.add_trace(go.Scattergl(
135
- x=[tx], y=[ty], mode="markers",
136
- marker=dict(size=8, line=dict(width=1), symbol="circle"),
137
- name="target",
138
- ))
139
-
140
- # Neighbors
141
- nbrs = get_neighbor_list(token_id)
142
  if nbrs:
143
  nx = [coords[str(nid)][0] for nid, _ in nbrs]
144
  ny = [coords[str(nid)][1] for nid, _ in nbrs]
145
  fig.add_trace(go.Scattergl(
146
  x=nx, y=ny, mode="markers",
147
- marker=dict(size=6, symbol="circle-open"),
148
- name="neighbors",
149
  ))
150
- fig.update_layout(showlegend=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  return fig
152
 
153
- # ---- UI actions --------------------------------------------------------------
154
- def on_append_cell(column, row_index):
155
- # append chosen next token to the text input
156
- df = top10_rx.value
157
- if row_index < len(df):
158
- token_id = int(df.iloc[row_index]["next token ID"])
159
- decoded = tokenizer.decode([token_id])
160
- text_rx.set(text_rx.value + decoded)
161
- selected_token_id_rx.set(token_id)
162
- # Update plot
163
- fig_rx.set(highlight(token_id))
164
 
165
- cell_actions = [solara.CellAction(icon="mdi-plus", name="Append & highlight", on_click=on_append_cell)]
 
 
 
 
166
 
167
- def on_predict():
168
- df = predict_top10(text_rx.value)
169
- top10_rx.set(df)
170
- notice_rx.set("Click a candidate to append it and highlight its neighborhood.")
171
- # also set selected to the top-1 for convenience
172
- if len(df) > 0:
173
- tid = int(df.iloc[0]["next token ID"])
174
- selected_token_id_rx.set(tid)
175
- fig_rx.set(highlight(tid))
176
-
177
- def on_show_neighborhood():
178
- # take last token in the prompt (if any), otherwise do nothing
179
- ids = tokenizer.encode(text_rx.value)
180
- if ids:
181
- token_id = int(ids[-1])
182
- selected_token_id_rx.set(token_id)
183
- fig_rx.set(highlight(token_id))
184
-
185
- # ---- Page --------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  @solara.component
187
  def Page():
188
- solara.Style(theme_css.value)
189
 
190
  with solara.Column(margin=12, gap="16px"):
191
  solara.Markdown("# Next-Token Predictor + Semantic Neighborhood")
192
  solara.Markdown(
193
- "Type text, then **Predict** to see the next-token distribution. "
194
- "Click a candidate to append it and highlight its **semantic neighborhood**."
 
195
  )
196
- with solara.Row(gap="8px"):
197
- solara.InputText("Enter text", value=text_rx, continuous_update=True, style={"minWidth": "520px"})
198
- solara.Button("Predict", on_click=on_predict, classes=["btn-primary"])
199
- solara.Button("Show neighborhood of last token", on_click=on_show_neighborhood)
200
 
201
  solara.Markdown(f"*{notice_rx.value}*")
202
 
203
- # Top-10 table
204
- solara.Markdown("### Prediction")
205
- solara.DataFrame(top10_rx.value, items_per_page=10, cell_actions=cell_actions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
- # Neighborhood panel
208
- solara.Markdown("### Semantic Neighborhood")
209
- if not coords:
210
- solara.Markdown("> Embedding map unavailable – add `assets/embeddings/*.json`.")
211
- else:
212
- solara.FigurePlotly(fig_rx.value)
213
 
 
 
214
  Page()
 
1
  # app.py
2
  import json
 
3
  from pathlib import Path
4
+ import threading, time
5
 
6
  import solara
 
7
  import torch
8
  import torch.nn.functional as F
9
+ import pandas as pd
10
+ import plotly.graph_objects as go
11
  from transformers import AutoTokenizer, AutoModelForCausalLM
12
 
13
+ # ------------------ Model ------------------
14
  MODEL_ID = "Qwen/Qwen3-0.6B"
15
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
16
  model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
17
 
18
+ # ------------------ Theme ------------------
19
+ theme_css = """
20
+ :root{
21
+ --primary:#38bdf8; /* light blue */
22
+ --bg:#ffffff; /* white */
23
+ --text:#0b0f14; /* near-black */
24
+ --muted:#6b7280; /* gray-500 */
25
+ --border:#e5e7eb; /* gray-200 */
26
+ --pale: rgba(56,189,248,0.15); /* base dots */
27
+ --mid: rgba(56,189,248,0.65); /* neighbors */
28
+ --bright: rgba(34,211,238,1.0); /* selected */
 
 
29
  }
30
+ body{ background:var(--bg); color:var(--text);}
31
+ h1,h2,h3{ color:var(--text); }
32
+ hr{ border-color:var(--border); }
33
+ .btn{ background:var(--primary); color:#000; border:1px solid var(--primary); padding:6px 10px; border-radius:10px; }
34
+ .badge{ display:inline-block; padding:2px 8px; border:1px solid var(--border); border-radius:999px; margin:2px; }
35
+ .rowbtn{
36
+ width:100%; text-align:left; padding:10px 12px; border-radius:12px;
37
+ border:1px solid var(--border); background:#fff;
38
+ }
39
+ .rowbtn:hover{ background:#f7fbff; border-color:#c3e8fb; cursor:pointer; }
40
+ .rowgrid{ display:grid; grid-template-columns: 70px 120px 1fr; gap:10px; align-items:center; }
41
+ .rowprob{ color:#111; font-weight:600; }
42
+ .rowid{ color:var(--muted); font-variant-numeric: tabular-nums; }
43
+ .rowtok{ color:#111; overflow:hidden; text-overflow:ellipsis; white-space:nowrap; }
44
+ .listheader{ color:var(--muted); font-size:12px; margin-bottom:6px;}
45
+ """
46
+
47
+ # ------------------ App state ------------------
48
+ text_rx = solara.reactive("twinkle, twinkle, little ")
49
+ preds_rx = solara.reactive(pd.DataFrame(columns=["probs", "id", "tok"]))
50
+ selected_token_id_rx = solara.reactive(None)
51
+ neighbor_list_rx = solara.reactive([]) # list of (tok, sim)
52
+
53
+ notice_rx = solara.reactive("Click a candidate (or hover to preview).")
54
+ auto_running_rx = solara.reactive(True) # toggle if you ever want to disable auto-predict
55
+
56
+ # ------------------ Embedding assets ------------------
57
  ASSETS = Path("assets/embeddings")
58
  COORDS_PATH = ASSETS / "pca_top5k_coords.json"
59
+ NEIGH_PATH = ASSETS / "neighbors_top5k_k40.json"
60
 
61
  coords = {}
62
  neighbors = {}
63
  ids_set = set()
64
 
65
  if COORDS_PATH.exists() and NEIGH_PATH.exists():
66
+ coords = json.loads(COORDS_PATH.read_text("utf-8"))
67
+ neighbors = json.loads(NEIGH_PATH.read_text("utf-8"))
 
 
68
  ids_set = set(map(int, coords.keys()))
69
  else:
70
+ notice_rx.set("Embedding files not found add assets/embeddings/*.json to enable the map.")
71
 
72
+ # ------------------ Predict ------------------
73
  def predict_top10(prompt: str) -> pd.DataFrame:
74
  if not prompt:
75
+ return pd.DataFrame(columns=["probs", "id", "tok"])
 
76
  tokens = tokenizer.encode(prompt, return_tensors="pt")
77
  out = model.generate(
78
  tokens,
 
80
  output_scores=True,
81
  return_dict_in_generate=True,
82
  pad_token_id=tokenizer.eos_token_id,
83
+ do_sample=False,
84
  temperature=0.0,
85
  top_k=1,
86
  top_p=1.0,
87
  )
88
+ scores = torch.softmax(out.scores[0], dim=-1) # [1, vocab]
89
+ topk = torch.topk(scores, 10)
90
+ ids = [int(topk.indices[0, i]) for i in range(10)]
91
+ probs = [float(topk.values[0, i]) for i in range(10)]
92
+ toks = [tokenizer.decode([i]) for i in ids]
93
+ df = pd.DataFrame({"probs": probs, "id": ids, "tok": toks})
94
+ df["probs"] = df["probs"].map(lambda p: f"{p:.2%}")
 
 
95
  return df
96
 
97
+ def on_predict():
98
+ df = predict_top10(text_rx.value)
99
+ preds_rx.set(df)
100
+ if len(df) > 0:
101
+ tid = int(df.iloc[0]["id"])
102
+ preview_token(tid) # highlight top-1 by default
 
 
 
 
103
 
104
+ # ------------------ Plotly figure ------------------
105
  def base_scatter():
 
 
 
 
 
 
 
 
 
106
  fig = go.Figure()
107
+ if coords:
108
+ xs, ys = zip(*[coords[k] for k in coords.keys()])
109
+ fig.add_trace(go.Scattergl(
110
+ x=xs, y=ys, mode="markers",
111
+ marker=dict(size=3, opacity=1.0, color="rgba(56,189,248,0.15)"),
112
+ hoverinfo="skip",
113
+ ))
114
  fig.update_layout(
115
+ height=460, margin=dict(l=10, r=10, t=10, b=10),
116
  paper_bgcolor="white", plot_bgcolor="white",
117
  xaxis=dict(visible=False), yaxis=dict(visible=False),
118
+ showlegend=False,
119
  )
120
  return fig
121
 
122
+ fig_rx = solara.reactive(base_scatter())
 
123
 
124
+ def get_neighbor_list(token_id: int, k: int = 18):
125
+ if not ids_set or token_id not in ids_set:
126
+ return []
127
+ raw = neighbors.get("neighbors", {}).get(str(token_id), [])
128
+ return raw[:k] # [ [nid, sim], ... ]
129
 
130
+ def highlight(token_id: int):
131
+ """Return figure with neighbors + target highlighted and update neighbor chip list."""
132
+ fig = base_scatter()
133
  if not coords or token_id not in ids_set:
134
+ neighbor_list_rx.set([])
135
  return fig
136
 
137
+ # neighbors
138
+ nbrs = get_neighbor_list(token_id, k=20)
 
 
 
 
 
 
 
 
139
  if nbrs:
140
  nx = [coords[str(nid)][0] for nid, _ in nbrs]
141
  ny = [coords[str(nid)][1] for nid, _ in nbrs]
142
  fig.add_trace(go.Scattergl(
143
  x=nx, y=ny, mode="markers",
144
+ marker=dict(size=6, color="rgba(56,189,248,0.75)", symbol="circle"),
145
+ hoverinfo="skip",
146
  ))
147
+ # update chip list
148
+ chips = []
149
+ for nid, sim in nbrs:
150
+ chips.append((tokenizer.decode([int(nid)]), float(sim)))
151
+ neighbor_list_rx.set(chips)
152
+ else:
153
+ neighbor_list_rx.set([])
154
+
155
+ # target
156
+ tx, ty = coords[str(token_id)]
157
+ fig.add_trace(go.Scattergl(
158
+ x=[tx], y=[ty], mode="markers",
159
+ marker=dict(size=10, color="rgba(34,211,238,1.0)", line=dict(width=1)),
160
+ hoverinfo="skip",
161
+ ))
162
  return fig
163
 
164
+ def preview_token(token_id: int):
165
+ selected_token_id_rx.set(int(token_id))
166
+ fig_rx.set(highlight(int(token_id)))
 
 
 
 
 
 
 
 
167
 
168
+ def append_token(token_id: int):
169
+ decoded = tokenizer.decode([int(token_id)])
170
+ text_rx.set(text_rx.value + decoded)
171
+ preview_token(int(token_id)) # highlight the one we appended
172
+ on_predict() # refresh next top-10 for new text
173
 
174
+ # ------------------ Auto-predict on typing (debounced) ------------------
175
+ @solara.component
176
+ def AutoPredictWatcher():
177
+ text = text_rx.value
178
+ auto = auto_running_rx.value
179
+
180
+ def effect():
181
+ if not auto:
182
+ return
183
+ cancelled = False
184
+ snap = text
185
+
186
+ def worker():
187
+ time.sleep(0.25) # debounce ~250ms
188
+ if not cancelled and snap == text_rx.value:
189
+ on_predict()
190
+
191
+ threading.Thread(target=worker, daemon=True).start()
192
+
193
+ def cleanup():
194
+ nonlocal cancelled
195
+ cancelled = True
196
+ return cleanup
197
+
198
+ solara.use_effect(effect, [text, auto])
199
+ return solara.Text("", style={"display":"none"}) # nothing visible
200
+
201
+ # ------------------ UI: custom clickable/hoverable list ------------------
202
+ @solara.component
203
+ def PredictionsList():
204
+ df = preds_rx.value
205
+ with solara.Column(gap="6px", style={"maxWidth":"720px"}):
206
+ solara.Markdown("### Prediction")
207
+ solara.HTML(tag="div", unsafe_innerHTML='<div class="listheader"># &nbsp;&nbsp; probs &nbsp;&nbsp; token id &nbsp;&nbsp; predicted next token</div>')
208
+ for i, row in df.iterrows():
209
+ tid = int(row["id"]); tok = row["tok"]; prob = row["probs"]
210
+ label = f'<div class="rowgrid"><div class="rowprob">{i}&nbsp;&nbsp;{prob}</div><div class="rowid">{tid}</div><div class="rowtok">{tok}</div></div>'
211
+ # full-width button acts like a row; hover previews, click appends
212
+ solara.Button(None,
213
+ on_click=lambda tid=tid: append_token(tid),
214
+ on_mouse_enter=lambda tid=tid: preview_token(tid),
215
+ html=label,
216
+ classes=["rowbtn"],
217
+ )
218
+
219
+ # ------------------ Page ------------------
220
  @solara.component
221
  def Page():
222
+ solara.Style(theme_css)
223
 
224
  with solara.Column(margin=12, gap="16px"):
225
  solara.Markdown("# Next-Token Predictor + Semantic Neighborhood")
226
  solara.Markdown(
227
+ "Type text to see predictions update automatically. "
228
+ "Click a candidate to append it and highlight its **semantic neighborhood**. "
229
+ "Hover a candidate to preview its neighborhood."
230
  )
231
+
232
+ # Input row (no Predict button needed anymore)
233
+ solara.InputText("Enter text", value=text_rx, continuous_update=True, style={"minWidth":"520px"})
 
234
 
235
  solara.Markdown(f"*{notice_rx.value}*")
236
 
237
+ with solara.Row(align_items="start", gap="24px"):
238
+ with solara.Column():
239
+ PredictionsList()
240
+
241
+ with solara.Column():
242
+ solara.Markdown("### Semantic Neighborhood")
243
+ if not coords:
244
+ solara.Markdown("> Embedding map unavailable – add `assets/embeddings/*.json`.")
245
+ else:
246
+ solara.FigurePlotly(fig_rx.value)
247
+
248
+ # Neighbor chips
249
+ if neighbor_list_rx.value:
250
+ solara.Markdown("**Nearest neighbors:**")
251
+ with solara.Row(wrap=True):
252
+ for tok, sim in neighbor_list_rx.value:
253
+ solara.HTML(tag="span", unsafe_innerHTML=f'<span class="badge">{tok} &nbsp; {(sim*100):.1f}%</span>')
254
 
255
+ AutoPredictWatcher() # invisible component that triggers auto-predict
 
 
 
 
 
256
 
257
+ # Seed initial predictions
258
+ on_predict()
259
  Page()