PeterPinetree commited on
Commit
bec627d
·
verified ·
1 Parent(s): 65a9c4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -84
app.py CHANGED
@@ -4,18 +4,21 @@ from pathlib import Path
4
  import threading, time
5
 
6
  import solara
7
- import torch
8
- import torch.nn.functional as F
9
  import pandas as pd
10
  import plotly.graph_objects as go
 
11
  from transformers import AutoTokenizer, AutoModelForCausalLM
12
 
13
- # ------------------ Model ------------------
 
 
 
 
14
  MODEL_ID = "Qwen/Qwen3-0.6B"
15
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
16
  model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
17
 
18
- # ------------------ Theme ------------------
19
  theme_css = """
20
  :root{
21
  --primary:#38bdf8; /* light blue */
@@ -24,18 +27,15 @@ theme_css = """
24
  --muted:#6b7280; /* gray-500 */
25
  --border:#e5e7eb; /* gray-200 */
26
  }
27
- body{ background:var(--bg); color:var(--text);}
28
- h1,h2,h3{ color:var(--text); }
29
- hr{ border-color:var(--border); }
30
  .badge{ display:inline-block; padding:2px 8px; border:1px solid var(--border); border-radius:999px; margin:2px; }
31
 
32
- /* Keep the predictions list above the plot so it can receive pointer events */
33
  .predictions-panel { position: relative; z-index: 5; }
34
- .plot-panel { position: relative; z-index: 1; }
35
-
36
- /* Safety: prevent any wide Plotly overlay from swallowing events on the left */
37
  .plot-panel .js-plotly-plot { position: relative; z-index: 1; }
38
 
 
39
  .rowbtn{
40
  width:100%; padding:10px 12px; border-radius:12px;
41
  border:1px solid var(--border); background:#fff; color:var(--text);
@@ -47,16 +47,16 @@ hr{ border-color:var(--border); }
47
  .rowbtn:hover{ background:#f7fbff; border-color:#c3e8fb; }
48
  """
49
 
50
- # ------------------ App state ------------------
51
  text_rx = solara.reactive("twinkle, twinkle, little ")
52
- preds_rx = solara.reactive(pd.DataFrame(columns=["probs", "id", "tok"]))
53
- selected_token_id_rx = solara.reactive(None) # currently highlighted token id
54
- neighbor_list_rx = solara.reactive([]) # [(tok_display, sim), ...]
 
55
  notice_rx = solara.reactive("Click a candidate (or hover to preview).")
56
  auto_running_rx = solara.reactive(True)
57
- last_hovered_id_rx = solara.reactive(None)
58
 
59
- # ------------------ Embedding assets ------------------
60
  ASSETS = Path("assets/embeddings")
61
  COORDS_PATH = ASSETS / "pca_top5k_coords.json"
62
  NEIGH_PATH = ASSETS / "neighbors_top5k_k40.json"
@@ -66,34 +66,31 @@ neighbors = {}
66
  ids_set = set()
67
 
68
  if COORDS_PATH.exists() and NEIGH_PATH.exists():
69
- coords = json.loads(COORDS_PATH.read_text("utf-8")) # { "tid": [x,y], ... }
70
- neighbors = json.loads(NEIGH_PATH.read_text("utf-8")) # { "neighbors": { "tid": [[nid,sim], ...] } }
71
  ids_set = set(map(int, coords.keys()))
72
  else:
73
- notice_rx.set("Embedding files not found — add assets/embeddings/*.json`.")
74
 
75
- # ------------------ Helpers ------------------
76
  def display_token_from_id(tid: int) -> str:
77
- """Readable label for a single token id (no leading tokenizer markers)."""
78
  toks = tokenizer.convert_ids_to_tokens([int(tid)], skip_special_tokens=True)
79
  t = toks[0] if toks else ""
80
  for lead in ("▁", "Ġ"):
81
  if t.startswith(lead):
82
  t = t[len(lead):]
83
- t = t.replace("\n", "↵")
84
  if t.strip() == "":
85
- return "␠" # visible space marker for pure whitespace
86
  return t
87
 
88
  def fmt_row(idx: int, prob: str, tid: int, tok_disp: str) -> str:
89
- # left-justified simple columns
90
  return f"{idx:<2} {prob:<7} {tid:<6} {tok_disp}"
91
 
92
- # ------------------ Predict ------------------
93
  def predict_top10(prompt: str) -> pd.DataFrame:
94
  if not prompt:
95
- return pd.DataFrame(columns=["probs", "id", "tok"])
96
-
97
  tokens = tokenizer.encode(prompt, return_tensors="pt")
98
  out = model.generate(
99
  tokens,
@@ -101,43 +98,41 @@ def predict_top10(prompt: str) -> pd.DataFrame:
101
  output_scores=True,
102
  return_dict_in_generate=True,
103
  pad_token_id=tokenizer.eos_token_id,
104
- do_sample=False,
105
- temperature=0.0,
106
- top_k=1,
107
- top_p=1.0,
108
  )
109
- scores = torch.softmax(out.scores[0], dim=-1) # [1, vocab]
110
  topk = torch.topk(scores, 10)
111
  ids = [int(topk.indices[0, i]) for i in range(10)]
112
  probs = [float(topk.values[0, i]) for i in range(10)]
113
- toks = [tokenizer.decode([i]) for i in ids] # for append; display uses display_token_from_id
114
  df = pd.DataFrame({"probs": probs, "id": ids, "tok": toks})
115
  df["probs"] = df["probs"].map(lambda p: f"{p:.2%}")
116
  return df
117
 
118
  def on_predict():
119
- """Update predictions. Keep current highlight unless there is none yet."""
120
  df = predict_top10(text_rx.value)
121
  preds_rx.set(df)
122
  if len(df) == 0:
123
  return
124
  if selected_token_id_rx.value is None:
125
- preview_token(int(df.iloc[0]["id"]))
126
  else:
 
127
  fig_rx.set(highlight(int(selected_token_id_rx.value)))
128
 
129
- # ------------------ Plotly figure ------------------
130
  def base_scatter():
131
  fig = go.Figure()
132
  if coords:
133
  xs, ys = zip(*[coords[k] for k in coords.keys()])
134
  fig.add_trace(go.Scattergl(
135
  x=xs, y=ys, mode="markers",
136
- marker=dict(size=3, opacity=1.0, color="rgba(56,189,248,0.15)"), # pale cloud
137
  hoverinfo="skip",
138
  ))
139
  fig.update_layout(
140
- height=460, margin=dict(l=10, r=10, t=10, b=10),
141
  paper_bgcolor="white", plot_bgcolor="white",
142
  xaxis=dict(visible=False), yaxis=dict(visible=False),
143
  showlegend=False,
@@ -153,7 +148,6 @@ def get_neighbor_list(token_id: int, k: int = 20):
153
  return raw[:k]
154
 
155
  def highlight(token_id: int):
156
- """Return figure with neighbors + target highlighted and update neighbor chip list."""
157
  fig = base_scatter()
158
  if not coords or token_id not in ids_set:
159
  neighbor_list_rx.set([])
@@ -161,16 +155,14 @@ def highlight(token_id: int):
161
 
162
  nbrs = get_neighbor_list(token_id, k=20)
163
  if nbrs:
164
- nx = [coords[str(nid)][0] for nid, _ in nbrs]
165
- ny = [coords[str(nid)][1] for nid, _ in nbrs]
166
  fig.add_trace(go.Scattergl(
167
  x=nx, y=ny, mode="markers",
168
- marker=dict(size=6, color="rgba(56,189,248,0.75)", symbol="circle"), # darker neighbors
169
  hoverinfo="skip",
170
  ))
171
- chips = []
172
- for nid, sim in nbrs:
173
- chips.append((display_token_from_id(int(nid)), float(sim)))
174
  neighbor_list_rx.set(chips)
175
  else:
176
  neighbor_list_rx.set([])
@@ -178,16 +170,15 @@ def highlight(token_id: int):
178
  tx, ty = coords[str(token_id)]
179
  fig.add_trace(go.Scattergl(
180
  x=[tx], y=[ty], mode="markers",
181
- marker=dict(size=10, color="rgba(34,211,238,1.0)", line=dict(width=1)), # bright target
182
  hoverinfo="skip",
183
  ))
184
  return fig
185
 
186
  def preview_token(token_id: int):
187
- print("preview ->", token_id) # TEMP: check Logs
188
- token_id = int(token_id)
189
- # TEMP DEBUG: verify hover fires in Space logs
190
  print("preview ->", token_id)
 
191
  if last_hovered_id_rx.value == token_id:
192
  return
193
  last_hovered_id_rx.set(token_id)
@@ -195,14 +186,14 @@ def preview_token(token_id: int):
195
  fig_rx.set(highlight(token_id))
196
 
197
  def append_token(token_id: int):
198
- # keep decode() here so spacing stays correct in the prompt
199
- print("append ->", token_id) # TEMP: check Logs
200
  decoded = tokenizer.decode([int(token_id)])
201
  text_rx.set(text_rx.value + decoded)
202
- preview_token(int(token_id)) # keep highlight on clicked token
203
- on_predict() # refresh predictions, preserve selection
204
 
205
- # ------------------ Auto-predict on typing (debounced) ------------------
206
  @solara.component
207
  def AutoPredictWatcher():
208
  text = text_rx.value
@@ -215,7 +206,7 @@ def AutoPredictWatcher():
215
  snap = text
216
 
217
  def worker():
218
- time.sleep(0.25) # ~250ms debounce
219
  if not cancelled and snap == text_rx.value:
220
  on_predict()
221
 
@@ -229,41 +220,39 @@ def AutoPredictWatcher():
229
  solara.use_effect(effect, [text, auto])
230
  return solara.Text("", style={"display": "none"})
231
 
232
- # ------------------ UI: rows as Div (hover + click here) ------------------
233
  @solara.component
234
  def PredictionsList():
235
  df = preds_rx.value
236
- with solara.Column(gap="6px", style={"maxWidth": "720px"}):
237
  solara.Markdown("### Prediction")
238
  solara.Text(
239
  " # probs token predicted next token",
240
  style={
241
- "color": "var(--muted)",
242
- "fontFamily": 'ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace',
243
  },
244
  )
245
  for i, row in df.iterrows():
246
- tid = int(row["id"])
247
- prob = row["probs"]
248
  tok_disp = display_token_from_id(tid)
249
  label = fmt_row(i, prob, tid, tok_disp)
250
 
251
- # One Div per row: both hover and click handlers live here.
252
  with solara.Div(
253
  classes=["rowbtn"],
254
- style={"justifyContent": "flex-start", "width": "100%"},
255
- attributes={"tabindex": "0", "role": "button"},
256
- on_click=lambda e=None, tid=tid: append_token(tid), # click to append
257
- on_mouse_enter=lambda e=None, tid=tid: preview_token(tid), # hover preview
258
- on_mouse_over=lambda e=None, tid=tid: preview_token(tid),
259
- on_mouse_move=lambda e=None, tid=tid: preview_token(tid),
260
- on_pointer_enter=lambda e=None, tid=tid: preview_token(tid),
261
- on_focus=lambda e=None, tid=tid: preview_token(tid),
262
  ):
263
  solara.Text(label)
264
 
265
-
266
- # ------------------ Page ------------------
267
  @solara.component
268
  def Page():
269
  solara.Style(theme_css)
@@ -276,12 +265,10 @@ def Page():
276
  "Hover a candidate to preview its neighborhood."
277
  )
278
 
279
- # Input (auto-predict handled by watcher)
280
- solara.InputText("Enter text", value=text_rx, continuous_update=True, style={"minWidth": "520px"})
281
  solara.Markdown(f"*{notice_rx.value}*")
282
 
283
- # Two columns
284
- with solara.Row(gap="24px", style={"align-items": "flex-start"}):
285
  with solara.Column(classes=["predictions-panel"]):
286
  PredictionsList()
287
 
@@ -294,15 +281,13 @@ def Page():
294
 
295
  if neighbor_list_rx.value:
296
  solara.Markdown("**Nearest neighbors:**")
297
- with solara.Row(style={"flex-wrap": "wrap"}):
298
  for tok, sim in neighbor_list_rx.value:
299
- solara.HTML(
300
- tag="span",
301
- unsafe_innerHTML=f'<span class="badge">{tok} &nbsp; {(sim*100):.1f}%</span>',
302
- )
303
 
304
  AutoPredictWatcher()
305
 
306
- # Seed initial predictions and mount
307
  on_predict()
308
- Page()
 
4
  import threading, time
5
 
6
  import solara
 
 
7
  import pandas as pd
8
  import plotly.graph_objects as go
9
+ import torch
10
  from transformers import AutoTokenizer, AutoModelForCausalLM
11
 
12
+ # ---------- versions (shows up in Space logs) ----------
13
+ import plotly
14
+ print("VERSIONS:", "solara", solara.__version__, "plotly", plotly.__version__, "torch", torch.__version__)
15
+
16
+ # ---------- Model ----------
17
  MODEL_ID = "Qwen/Qwen3-0.6B"
18
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
19
  model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
20
 
21
+ # ---------- Theme & layout fixes ----------
22
  theme_css = """
23
  :root{
24
  --primary:#38bdf8; /* light blue */
 
27
  --muted:#6b7280; /* gray-500 */
28
  --border:#e5e7eb; /* gray-200 */
29
  }
30
+ body{ background:var(--bg); color:var(--text); }
 
 
31
  .badge{ display:inline-block; padding:2px 8px; border:1px solid var(--border); border-radius:999px; margin:2px; }
32
 
33
+ /* Make sure the prediction list can receive pointer events even if Plotly expands */
34
  .predictions-panel { position: relative; z-index: 5; }
35
+ .plot-panel { position: relative; z-index: 1; }
 
 
36
  .plot-panel .js-plotly-plot { position: relative; z-index: 1; }
37
 
38
+ /* Row style */
39
  .rowbtn{
40
  width:100%; padding:10px 12px; border-radius:12px;
41
  border:1px solid var(--border); background:#fff; color:var(--text);
 
47
  .rowbtn:hover{ background:#f7fbff; border-color:#c3e8fb; }
48
  """
49
 
50
+ # ---------- App state ----------
51
  text_rx = solara.reactive("twinkle, twinkle, little ")
52
+ preds_rx = solara.reactive(pd.DataFrame(columns=["probs","id","tok"]))
53
+ selected_token_id_rx = solara.reactive(None)
54
+ neighbor_list_rx = solara.reactive([])
55
+ last_hovered_id_rx = solara.reactive(None)
56
  notice_rx = solara.reactive("Click a candidate (or hover to preview).")
57
  auto_running_rx = solara.reactive(True)
 
58
 
59
+ # ---------- Embedding assets ----------
60
  ASSETS = Path("assets/embeddings")
61
  COORDS_PATH = ASSETS / "pca_top5k_coords.json"
62
  NEIGH_PATH = ASSETS / "neighbors_top5k_k40.json"
 
66
  ids_set = set()
67
 
68
  if COORDS_PATH.exists() and NEIGH_PATH.exists():
69
+ coords = json.loads(COORDS_PATH.read_text("utf-8"))
70
+ neighbors = json.loads(NEIGH_PATH.read_text("utf-8"))
71
  ids_set = set(map(int, coords.keys()))
72
  else:
73
+ notice_rx.set("Embedding files not found — add assets/embeddings/*.json to enable the map.")
74
 
75
+ # ---------- Helpers ----------
76
  def display_token_from_id(tid: int) -> str:
 
77
  toks = tokenizer.convert_ids_to_tokens([int(tid)], skip_special_tokens=True)
78
  t = toks[0] if toks else ""
79
  for lead in ("▁", "Ġ"):
80
  if t.startswith(lead):
81
  t = t[len(lead):]
82
+ t = t.replace("\n","↵")
83
  if t.strip() == "":
84
+ return "␠"
85
  return t
86
 
87
  def fmt_row(idx: int, prob: str, tid: int, tok_disp: str) -> str:
 
88
  return f"{idx:<2} {prob:<7} {tid:<6} {tok_disp}"
89
 
90
+ # ---------- Predict ----------
91
  def predict_top10(prompt: str) -> pd.DataFrame:
92
  if not prompt:
93
+ return pd.DataFrame(columns=["probs","id","tok"])
 
94
  tokens = tokenizer.encode(prompt, return_tensors="pt")
95
  out = model.generate(
96
  tokens,
 
98
  output_scores=True,
99
  return_dict_in_generate=True,
100
  pad_token_id=tokenizer.eos_token_id,
101
+ do_sample=False, temperature=0.0, top_k=1, top_p=1.0,
 
 
 
102
  )
103
+ scores = torch.softmax(out.scores[0], dim=-1)
104
  topk = torch.topk(scores, 10)
105
  ids = [int(topk.indices[0, i]) for i in range(10)]
106
  probs = [float(topk.values[0, i]) for i in range(10)]
107
+ toks = [tokenizer.decode([i]) for i in ids] # used for append only
108
  df = pd.DataFrame({"probs": probs, "id": ids, "tok": toks})
109
  df["probs"] = df["probs"].map(lambda p: f"{p:.2%}")
110
  return df
111
 
112
  def on_predict():
113
+ """Update predictions; keep current highlight unless none yet."""
114
  df = predict_top10(text_rx.value)
115
  preds_rx.set(df)
116
  if len(df) == 0:
117
  return
118
  if selected_token_id_rx.value is None:
119
+ preview_token(int(df.iloc[0]["id"])) # first time only
120
  else:
121
+ # keep the user's last selection/hover
122
  fig_rx.set(highlight(int(selected_token_id_rx.value)))
123
 
124
+ # ---------- Plot ----------
125
  def base_scatter():
126
  fig = go.Figure()
127
  if coords:
128
  xs, ys = zip(*[coords[k] for k in coords.keys()])
129
  fig.add_trace(go.Scattergl(
130
  x=xs, y=ys, mode="markers",
131
+ marker=dict(size=3, opacity=1.0, color="rgba(56,189,248,0.15)"),
132
  hoverinfo="skip",
133
  ))
134
  fig.update_layout(
135
+ height=460, margin=dict(l=10,r=10,t=10,b=10),
136
  paper_bgcolor="white", plot_bgcolor="white",
137
  xaxis=dict(visible=False), yaxis=dict(visible=False),
138
  showlegend=False,
 
148
  return raw[:k]
149
 
150
  def highlight(token_id: int):
 
151
  fig = base_scatter()
152
  if not coords or token_id not in ids_set:
153
  neighbor_list_rx.set([])
 
155
 
156
  nbrs = get_neighbor_list(token_id, k=20)
157
  if nbrs:
158
+ nx = [coords[str(nid)][0] for nid,_ in nbrs]
159
+ ny = [coords[str(nid)][1] for nid,_ in nbrs]
160
  fig.add_trace(go.Scattergl(
161
  x=nx, y=ny, mode="markers",
162
+ marker=dict(size=6, color="rgba(56,189,248,0.75)"),
163
  hoverinfo="skip",
164
  ))
165
+ chips = [(display_token_from_id(int(nid)), float(sim)) for nid,sim in nbrs]
 
 
166
  neighbor_list_rx.set(chips)
167
  else:
168
  neighbor_list_rx.set([])
 
170
  tx, ty = coords[str(token_id)]
171
  fig.add_trace(go.Scattergl(
172
  x=[tx], y=[ty], mode="markers",
173
+ marker=dict(size=10, color="rgba(34,211,238,1.0)", line=dict(width=1)),
174
  hoverinfo="skip",
175
  ))
176
  return fig
177
 
178
  def preview_token(token_id: int):
179
+ # DEBUG: confirm events reach Python
 
 
180
  print("preview ->", token_id)
181
+ token_id = int(token_id)
182
  if last_hovered_id_rx.value == token_id:
183
  return
184
  last_hovered_id_rx.set(token_id)
 
186
  fig_rx.set(highlight(token_id))
187
 
188
  def append_token(token_id: int):
189
+ # DEBUG
190
+ print("append ->", token_id)
191
  decoded = tokenizer.decode([int(token_id)])
192
  text_rx.set(text_rx.value + decoded)
193
+ preview_token(int(token_id))
194
+ on_predict()
195
 
196
+ # ---------- Auto-predict (debounced) ----------
197
  @solara.component
198
  def AutoPredictWatcher():
199
  text = text_rx.value
 
206
  snap = text
207
 
208
  def worker():
209
+ time.sleep(0.25)
210
  if not cancelled and snap == text_rx.value:
211
  on_predict()
212
 
 
220
  solara.use_effect(effect, [text, auto])
221
  return solara.Text("", style={"display": "none"})
222
 
223
+ # ---------- Predictions list ----------
224
  @solara.component
225
  def PredictionsList():
226
  df = preds_rx.value
227
+ with solara.Column(gap="6px", style={"maxWidth":"720px"}):
228
  solara.Markdown("### Prediction")
229
  solara.Text(
230
  " # probs token predicted next token",
231
  style={
232
+ "color":"var(--muted)",
233
+ "fontFamily":'ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace',
234
  },
235
  )
236
  for i, row in df.iterrows():
237
+ tid = int(row["id"]); prob = row["probs"]
 
238
  tok_disp = display_token_from_id(tid)
239
  label = fmt_row(i, prob, tid, tok_disp)
240
 
241
+ # Use Div so pointer events are reliable; accept *args to handle any signature
242
  with solara.Div(
243
  classes=["rowbtn"],
244
+ style={"justifyContent":"flex-start","width":"100%"},
245
+ attributes={"tabindex":"0","role":"button"},
246
+ on_click=lambda *args, tid=tid: append_token(tid),
247
+ on_mouse_enter=lambda *args, tid=tid: preview_token(tid),
248
+ on_mouse_over=lambda *args, tid=tid: preview_token(tid),
249
+ on_mouse_move=lambda *args, tid=tid: preview_token(tid),
250
+ on_pointer_enter=lambda *args, tid=tid: preview_token(tid),
251
+ on_focus=lambda *args, tid=tid: preview_token(tid),
252
  ):
253
  solara.Text(label)
254
 
255
+ # ---------- Page ----------
 
256
  @solara.component
257
  def Page():
258
  solara.Style(theme_css)
 
265
  "Hover a candidate to preview its neighborhood."
266
  )
267
 
268
+ solara.InputText("Enter text", value=text_rx, continuous_update=True, style={"minWidth":"520px"})
 
269
  solara.Markdown(f"*{notice_rx.value}*")
270
 
271
+ with solara.Row(gap="24px", style={"align-items":"flex-start"}):
 
272
  with solara.Column(classes=["predictions-panel"]):
273
  PredictionsList()
274
 
 
281
 
282
  if neighbor_list_rx.value:
283
  solara.Markdown("**Nearest neighbors:**")
284
+ with solara.Row(style={"flex-wrap":"wrap"}):
285
  for tok, sim in neighbor_list_rx.value:
286
+ solara.HTML(tag="span",
287
+ unsafe_innerHTML=f'<span class="badge">{tok} &nbsp; {(sim*100):.1f}%</span>')
 
 
288
 
289
  AutoPredictWatcher()
290
 
291
+ # ---------- Kickoff ----------
292
  on_predict()
293
+ Page()