PeterPinetree commited on
Commit
44777fd
·
verified ·
1 Parent(s): c9b807e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -108
app.py CHANGED
@@ -1,25 +1,27 @@
1
  # app.py
2
  import json
 
 
3
  from pathlib import Path
4
- import threading, time
5
- import anywidget
6
- import traitlets as t
7
  import solara
8
  import pandas as pd
9
  import plotly.graph_objects as go
10
  import torch
11
  from transformers import AutoTokenizer, AutoModelForCausalLM
12
 
13
- # ---------- versions (shows up in Space logs) ----------
14
- import plotly
15
- print("VERSIONS:", "solara", solara.__version__, "plotly", plotly.__version__, "torch", torch.__version__)
 
16
 
17
  # ---------- Model ----------
18
- MODEL_ID = "Qwen/Qwen3-0.6B"
19
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
20
  model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
21
 
22
- # ---------- Theme & layout fixes ----------
 
23
  theme_css = """
24
  :root{
25
  --primary:#38bdf8; /* light blue */
@@ -29,30 +31,18 @@ theme_css = """
29
  --border:#e5e7eb; /* gray-200 */
30
  }
31
  body{ background:var(--bg); color:var(--text); }
32
- .badge{ display:inline-block; padding:2px 8px; border:1px solid var(--border); border-radius:999px; margin:2px; }
33
- /* Highlight hovered prediction token */
34
- .badge:hover {
35
- background: var(--primary);
36
- color: white;
37
- border-color: var(--primary);
38
- cursor: pointer;
39
- transition: all 0.2s ease;
40
- }
41
 
42
- /* Optional: style an active hovered token */
43
- .badge.hovered {
44
- background: var(--primary);
45
- color: white;
46
- border-color: var(--primary);
47
  }
48
 
 
 
 
 
49
 
50
- /* Make sure the prediction list can receive pointer events even if Plotly expands */
51
- .predictions-panel { position: relative; z-index: 5; }
52
- .plot-panel { position: relative; z-index: 1; }
53
- .plot-panel .js-plotly-plot { position: relative; z-index: 1; }
54
-
55
- /* Row style */
56
  .rowbtn{
57
  width:100%; padding:10px 12px; border-radius:12px;
58
  border:1px solid var(--border); background:#fff; color:var(--text);
@@ -64,15 +54,17 @@ body{ background:var(--bg); color:var(--text); }
64
  .rowbtn:hover{ background:#f7fbff; border-color:#c3e8fb; }
65
  """
66
 
67
- # ---------- App state ----------
 
68
  text_rx = solara.reactive("twinkle, twinkle, little ")
69
- preds_rx = solara.reactive(pd.DataFrame(columns=["probs","id","tok"]))
70
  selected_token_id_rx = solara.reactive(None)
71
  neighbor_list_rx = solara.reactive([])
72
  last_hovered_id_rx = solara.reactive(None)
73
  notice_rx = solara.reactive("Click a candidate (or hover to preview).")
74
  auto_running_rx = solara.reactive(True)
75
 
 
76
  # ---------- Embedding assets ----------
77
  ASSETS = Path("assets/embeddings")
78
  COORDS_PATH = ASSETS / "pca_top5k_coords.json"
@@ -89,6 +81,7 @@ if COORDS_PATH.exists() and NEIGH_PATH.exists():
89
  else:
90
  notice_rx.set("Embedding files not found — add assets/embeddings/*.json to enable the map.")
91
 
 
92
  # ---------- Helpers ----------
93
  def display_token_from_id(tid: int) -> str:
94
  toks = tokenizer.convert_ids_to_tokens([int(tid)], skip_special_tokens=True)
@@ -97,48 +90,47 @@ def display_token_from_id(tid: int) -> str:
97
  if t.startswith(lead):
98
  t = t[len(lead):]
99
  t = t.replace("\n","↵")
100
- if t.strip() == "":
101
- return "␠"
102
- return t
103
 
104
  def fmt_row(idx: int, prob: str, tid: int, tok_disp: str) -> str:
 
105
  return f"{idx:<2} {prob:<7} {tid:<6} {tok_disp}"
106
 
107
- # ---------- Predict ----------
 
108
  def predict_top10(prompt: str) -> pd.DataFrame:
109
  if not prompt:
110
- return pd.DataFrame(columns=["probs","id","tok"])
111
- tokens = tokenizer.encode(prompt, return_tensors="pt")
112
  out = model.generate(
113
- tokens,
114
  max_new_tokens=1,
115
  output_scores=True,
116
  return_dict_in_generate=True,
117
  pad_token_id=tokenizer.eos_token_id,
118
- do_sample=False, temperature=0.0, top_k=1, top_p=1.0,
119
  )
120
  scores = torch.softmax(out.scores[0], dim=-1)
121
  topk = torch.topk(scores, 10)
122
  ids = [int(topk.indices[0, i]) for i in range(10)]
123
  probs = [float(topk.values[0, i]) for i in range(10)]
124
- toks = [tokenizer.decode([i]) for i in ids] # used for append only
125
  df = pd.DataFrame({"probs": probs, "id": ids, "tok": toks})
126
  df["probs"] = df["probs"].map(lambda p: f"{p:.2%}")
127
  return df
128
 
129
  def on_predict():
130
- """Update predictions; keep current highlight unless none yet."""
131
  df = predict_top10(text_rx.value)
132
  preds_rx.set(df)
133
  if len(df) == 0:
134
  return
135
  if selected_token_id_rx.value is None:
136
- preview_token(int(df.iloc[0]["id"])) # first time only
137
  else:
138
- # keep the user's last selection/hover
139
- fig_rx.set(highlight(int(selected_token_id_rx.value)))
140
 
141
- # ---------- Plot ----------
142
  def base_scatter():
143
  fig = go.Figure()
144
  if coords:
@@ -192,9 +184,9 @@ def highlight(token_id: int):
192
  ))
193
  return fig
194
 
 
195
  def preview_token(token_id: int):
196
- # DEBUG: confirm events reach Python
197
- print("preview ->", token_id)
198
  token_id = int(token_id)
199
  if last_hovered_id_rx.value == token_id:
200
  return
@@ -203,14 +195,14 @@ def preview_token(token_id: int):
203
  fig_rx.set(highlight(token_id))
204
 
205
  def append_token(token_id: int):
206
- # DEBUG
207
- print("append ->", token_id)
208
  decoded = tokenizer.decode([int(token_id)])
209
  text_rx.set(text_rx.value + decoded)
210
  preview_token(int(token_id))
211
  on_predict()
212
 
213
- # ---------- Auto-predict (debounced) ----------
 
214
  @solara.component
215
  def AutoPredictWatcher():
216
  text = text_rx.value
@@ -237,29 +229,30 @@ def AutoPredictWatcher():
237
  solara.use_effect(effect, [text, auto])
238
  return solara.Text("", style={"display": "none"})
239
 
 
 
240
  class HoverList(anywidget.AnyWidget):
241
  """
242
- Renders the prediction rows in the browser and streams hover/click
243
- events back to Python via synced traitlets.
244
  """
245
- # Browser code: builds the list and wires events
246
  _esm = """
247
  export function render({ model, el }) {
248
- const make = () => {
249
  const items = model.get('items') || [];
250
  el.innerHTML = "";
251
  const wrap = document.createElement('div');
252
  wrap.style.display = 'flex';
253
  wrap.style.flexDirection = 'column';
 
254
  items.forEach(({tid, label}) => {
255
  const btn = document.createElement('button');
256
  btn.textContent = label;
257
- btn.className = 'rowbtn'; // your existing CSS
258
  btn.setAttribute('type', 'button');
259
  btn.setAttribute('role', 'button');
260
  btn.setAttribute('tabindex', '0');
261
 
262
- // hover → preview
263
  const preview = () => {
264
  model.set('hovered_id', tid);
265
  model.save_changes();
@@ -269,7 +262,6 @@ class HoverList(anywidget.AnyWidget):
269
  btn.addEventListener('mousemove', preview);
270
  btn.addEventListener('focus', preview);
271
 
272
- // click → append
273
  btn.addEventListener('click', () => {
274
  model.set('clicked_id', tid);
275
  model.save_changes();
@@ -277,26 +269,23 @@ class HoverList(anywidget.AnyWidget):
277
 
278
  wrap.appendChild(btn);
279
  });
 
280
  el.appendChild(wrap);
281
  };
282
 
283
- // initial render
284
- make();
285
-
286
- // re-render when items change
287
- model.on('change:items', make);
288
  }
289
  """
290
- # Data flowing between JS and Python
291
- items = t.List(trait=t.Dict()).tag(sync=True) # [{tid:int, label:str}, ...]
292
- hovered_id = t.Int(allow_none=True).tag(sync=True)
293
- clicked_id = t.Int(allow_none=True).tag(sync=True)
294
 
295
 
296
- # ---------- Predictions list ----------
297
  @solara.component
298
  def PredictionsList():
299
- df = preds_rx.value # your DataFrame with columns: probs, id, tok
300
  with solara.Column(gap="6px", style={"maxWidth": "720px"}):
301
  solara.Markdown("### Prediction")
302
  solara.Text(
@@ -307,40 +296,33 @@ def PredictionsList():
307
  },
308
  )
309
 
 
 
310
  for i, row in df.iterrows():
311
  tid = int(row["id"])
312
- prob = row["probs"] # already formatted like "3.21%"
313
  tok_disp = display_token_from_id(tid)
314
- row_label = fmt_row(i, prob, tid, tok_disp)
315
-
316
- # Wrapper DIV handles hover reliably
317
- with solara.Div(
318
- classes=["rowbtn"], # styling on wrapper
319
- style={"justifyContent": "flex-start", "width": "100%"},
320
- attributes={"tabindex": "0", "role": "button"},
321
- # --- HOVER = preview neighborhood ---
322
- on_mouse_enter=lambda *args, tid=tid: preview_token(tid),
323
- on_mouse_over=lambda *args, tid=tid: preview_token(tid),
324
- on_mouse_move=lambda *args, tid=tid: preview_token(tid),
325
- on_pointer_enter=lambda *args, tid=tid: preview_token(tid),
326
- on_pointer_move=lambda *args, tid=tid: preview_token(tid),
327
- on_focus=lambda *args, tid=tid: preview_token(tid), # keyboard
328
- ):
329
- # Inner BUTTON handles click-to-append (and also binds hover for extra safety)
330
- solara.Button(
331
- row_label,
332
- classes=[], # keep wrapper styled; button unstyled
333
- style={"justifyContent": "flex-start", "width": "100%"},
334
- # --- CLICK = append token to text ---
335
- on_click=lambda *args, tid=tid: append_token(tid),
336
- # redundant hover hooks (helps on some builds)
337
- on_mouse_enter=lambda *args, tid=tid: preview_token(tid),
338
- on_mouse_over=lambda *args, tid=tid: preview_token(tid),
339
- on_mouse_move=lambda *args, tid=tid: preview_token(tid),
340
- on_pointer_enter=lambda *args, tid=tid: preview_token(tid),
341
- on_pointer_move=lambda *args, tid=tid: preview_token(tid),
342
- on_focus=lambda *args, tid=tid: preview_token(tid),
343
- )
344
 
345
  # ---------- Page ----------
346
  @solara.component
@@ -354,34 +336,31 @@ def Page():
354
  "Click a candidate to append it and highlight its **semantic neighborhood**. "
355
  "Hover a candidate to preview its neighborhood."
356
  )
357
-
358
  solara.InputText("Enter text", value=text_rx, continuous_update=True, style={"minWidth":"520px"})
359
  solara.Markdown(f"*{notice_rx.value}*")
360
 
361
  with solara.Row(classes=["app-row"]):
362
- # Left column: predictions list (fixed width, sits above plot for events)
363
  with solara.Column(classes=["predictions-panel"]):
364
  PredictionsList()
365
-
366
- # Right column: plot + neighbor chips
367
  with solara.Column(classes=["plot-panel"]):
368
  solara.Markdown("### Semantic Neighborhood")
369
  if not coords:
370
  solara.Markdown("> Embedding map unavailable – add `assets/embeddings/*.json`.")
371
  else:
372
  solara.FigurePlotly(fig_rx.value)
373
-
374
  if neighbor_list_rx.value:
375
  solara.Markdown("**Nearest neighbors:**")
376
- with solara.Row(style={"flex-wrap": "wrap"}):
377
  for tok, sim in neighbor_list_rx.value:
378
- solara.HTML(
379
- tag="span",
380
- unsafe_innerHTML=f'<span class="badge">{tok} &nbsp; {(sim*100):.1f}%</span>',
381
- )
382
 
383
  AutoPredictWatcher()
384
 
 
385
  # ---------- Kickoff ----------
386
  on_predict()
387
  Page()
 
1
  # app.py
2
  import json
3
+ import threading
4
+ import time
5
  from pathlib import Path
6
+
 
 
7
  import solara
8
  import pandas as pd
9
  import plotly.graph_objects as go
10
  import torch
11
  from transformers import AutoTokenizer, AutoModelForCausalLM
12
 
13
+ # for robust hover/click from the browser
14
+ import anywidget
15
+ import traitlets as t
16
+
17
 
18
  # ---------- Model ----------
19
+ MODEL_ID = "Qwen/Qwen3-0.6B" # same as the working HF Space
20
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
21
  model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
22
 
23
+
24
+ # ---------- Theme & layout (light blue / white / black accents) ----------
25
  theme_css = """
26
  :root{
27
  --primary:#38bdf8; /* light blue */
 
31
  --border:#e5e7eb; /* gray-200 */
32
  }
33
  body{ background:var(--bg); color:var(--text); }
 
 
 
 
 
 
 
 
 
34
 
35
+ .badge{
36
+ display:inline-block; padding:2px 8px; border:1px solid var(--border);
37
+ border-radius:999px; margin:2px;
 
 
38
  }
39
 
40
+ /* Two-column layout with clear stacking (predictions above plot for events) */
41
+ .app-row { display:flex; align-items:flex-start; gap:24px; }
42
+ .predictions-panel { flex:0 0 360px; position:relative; z-index:10; }
43
+ .plot-panel { flex:1 1 auto; position:relative; z-index:1; overflow:hidden; }
44
 
45
+ /* Prediction rows (styled on a button or wrapper div) */
 
 
 
 
 
46
  .rowbtn{
47
  width:100%; padding:10px 12px; border-radius:12px;
48
  border:1px solid var(--border); background:#fff; color:var(--text);
 
54
  .rowbtn:hover{ background:#f7fbff; border-color:#c3e8fb; }
55
  """
56
 
57
+
58
+ # ---------- Reactive state ----------
59
  text_rx = solara.reactive("twinkle, twinkle, little ")
60
+ preds_rx = solara.reactive(pd.DataFrame(columns=["probs", "id", "tok"]))
61
  selected_token_id_rx = solara.reactive(None)
62
  neighbor_list_rx = solara.reactive([])
63
  last_hovered_id_rx = solara.reactive(None)
64
  notice_rx = solara.reactive("Click a candidate (or hover to preview).")
65
  auto_running_rx = solara.reactive(True)
66
 
67
+
68
  # ---------- Embedding assets ----------
69
  ASSETS = Path("assets/embeddings")
70
  COORDS_PATH = ASSETS / "pca_top5k_coords.json"
 
81
  else:
82
  notice_rx.set("Embedding files not found — add assets/embeddings/*.json to enable the map.")
83
 
84
+
85
  # ---------- Helpers ----------
86
  def display_token_from_id(tid: int) -> str:
87
  toks = tokenizer.convert_ids_to_tokens([int(tid)], skip_special_tokens=True)
 
90
  if t.startswith(lead):
91
  t = t[len(lead):]
92
  t = t.replace("\n","↵")
93
+ return t if t.strip() else ""
 
 
94
 
95
  def fmt_row(idx: int, prob: str, tid: int, tok_disp: str) -> str:
96
+ # columns: index, probability, token id, token text
97
  return f"{idx:<2} {prob:<7} {tid:<6} {tok_disp}"
98
 
99
+
100
+ # ---------- Prediction ----------
101
  def predict_top10(prompt: str) -> pd.DataFrame:
102
  if not prompt:
103
+ return pd.DataFrame(columns=["probs", "id", "tok"])
104
+ tokens = tokenizer(prompt, return_tensors="pt", padding=False)
105
  out = model.generate(
106
+ **tokens,
107
  max_new_tokens=1,
108
  output_scores=True,
109
  return_dict_in_generate=True,
110
  pad_token_id=tokenizer.eos_token_id,
111
+ do_sample=False, # greedy; temp/top_k are ignored (by design)
112
  )
113
  scores = torch.softmax(out.scores[0], dim=-1)
114
  topk = torch.topk(scores, 10)
115
  ids = [int(topk.indices[0, i]) for i in range(10)]
116
  probs = [float(topk.values[0, i]) for i in range(10)]
117
+ toks = [tokenizer.decode([i]) for i in ids] # for append
118
  df = pd.DataFrame({"probs": probs, "id": ids, "tok": toks})
119
  df["probs"] = df["probs"].map(lambda p: f"{p:.2%}")
120
  return df
121
 
122
  def on_predict():
 
123
  df = predict_top10(text_rx.value)
124
  preds_rx.set(df)
125
  if len(df) == 0:
126
  return
127
  if selected_token_id_rx.value is None:
128
+ preview_token(int(df.iloc[0]["id"])) # only first time
129
  else:
130
+ fig_rx.set(highlight(int(selected_token_id_rx.value))) # preserve selection
131
+
132
 
133
+ # ---------- Plot / neighborhood ----------
134
  def base_scatter():
135
  fig = go.Figure()
136
  if coords:
 
184
  ))
185
  return fig
186
 
187
+
188
  def preview_token(token_id: int):
189
+ # print("preview ->", token_id) # enable for debugging in Space logs
 
190
  token_id = int(token_id)
191
  if last_hovered_id_rx.value == token_id:
192
  return
 
195
  fig_rx.set(highlight(token_id))
196
 
197
  def append_token(token_id: int):
198
+ # print("append ->", token_id)
 
199
  decoded = tokenizer.decode([int(token_id)])
200
  text_rx.set(text_rx.value + decoded)
201
  preview_token(int(token_id))
202
  on_predict()
203
 
204
+
205
+ # ---------- Debounced auto-predict ----------
206
  @solara.component
207
  def AutoPredictWatcher():
208
  text = text_rx.value
 
229
  solara.use_effect(effect, [text, auto])
230
  return solara.Text("", style={"display": "none"})
231
 
232
+
233
+ # ---------- Hover-enabled list (browser) ----------
234
  class HoverList(anywidget.AnyWidget):
235
  """
236
+ Renders the prediction rows in the browser and streams hover/click events
237
+ back to Python via synced traitlets.
238
  """
 
239
  _esm = """
240
  export function render({ model, el }) {
241
+ const renderList = () => {
242
  const items = model.get('items') || [];
243
  el.innerHTML = "";
244
  const wrap = document.createElement('div');
245
  wrap.style.display = 'flex';
246
  wrap.style.flexDirection = 'column';
247
+
248
  items.forEach(({tid, label}) => {
249
  const btn = document.createElement('button');
250
  btn.textContent = label;
251
+ btn.className = 'rowbtn';
252
  btn.setAttribute('type', 'button');
253
  btn.setAttribute('role', 'button');
254
  btn.setAttribute('tabindex', '0');
255
 
 
256
  const preview = () => {
257
  model.set('hovered_id', tid);
258
  model.save_changes();
 
262
  btn.addEventListener('mousemove', preview);
263
  btn.addEventListener('focus', preview);
264
 
 
265
  btn.addEventListener('click', () => {
266
  model.set('clicked_id', tid);
267
  model.save_changes();
 
269
 
270
  wrap.appendChild(btn);
271
  });
272
+
273
  el.appendChild(wrap);
274
  };
275
 
276
+ renderList();
277
+ model.on('change:items', renderList);
 
 
 
278
  }
279
  """
280
+ items = t.List(trait=t.Dict()).tag(sync=True) # [{tid:int, label:str}, ...]
281
+ hovered_id = t.Int(allow_none=True).tag(sync=True)
282
+ clicked_id = t.Int(allow_none=True).tag(sync=True)
 
283
 
284
 
285
+ # ---------- Predictions list (uses HoverList) ----------
286
  @solara.component
287
  def PredictionsList():
288
+ df = preds_rx.value
289
  with solara.Column(gap="6px", style={"maxWidth": "720px"}):
290
  solara.Markdown("### Prediction")
291
  solara.Text(
 
296
  },
297
  )
298
 
299
+ # Build items for the browser widget
300
+ items = []
301
  for i, row in df.iterrows():
302
  tid = int(row["id"])
303
+ prob = row["probs"] # already a formatted string like "4.12%"
304
  tok_disp = display_token_from_id(tid)
305
+ items.append({"tid": tid, "label": fmt_row(i, prob, tid, tok_disp)})
306
+
307
+ w = HoverList()
308
+ w.items = items
309
+
310
+ # Hover preview (updates plot + neighbor chips)
311
+ def _on_hover(change):
312
+ tid = change["new"]
313
+ if tid is not None:
314
+ preview_token(int(tid))
315
+ w.observe(_on_hover, names="hovered_id")
316
+
317
+ # Click append
318
+ def _on_click(change):
319
+ tid = change["new"]
320
+ if tid is not None:
321
+ append_token(int(tid))
322
+ w.observe(_on_click, names="clicked_id")
323
+
324
+ solara.display(w)
325
+
 
 
 
 
 
 
 
 
 
326
 
327
  # ---------- Page ----------
328
  @solara.component
 
336
  "Click a candidate to append it and highlight its **semantic neighborhood**. "
337
  "Hover a candidate to preview its neighborhood."
338
  )
339
+
340
  solara.InputText("Enter text", value=text_rx, continuous_update=True, style={"minWidth":"520px"})
341
  solara.Markdown(f"*{notice_rx.value}*")
342
 
343
  with solara.Row(classes=["app-row"]):
 
344
  with solara.Column(classes=["predictions-panel"]):
345
  PredictionsList()
346
+
 
347
  with solara.Column(classes=["plot-panel"]):
348
  solara.Markdown("### Semantic Neighborhood")
349
  if not coords:
350
  solara.Markdown("> Embedding map unavailable – add `assets/embeddings/*.json`.")
351
  else:
352
  solara.FigurePlotly(fig_rx.value)
353
+
354
  if neighbor_list_rx.value:
355
  solara.Markdown("**Nearest neighbors:**")
356
+ with solara.Row(style={"flex-wrap":"wrap"}):
357
  for tok, sim in neighbor_list_rx.value:
358
+ solara.HTML(tag="span",
359
+ unsafe_innerHTML=f'<span class="badge">{tok} &nbsp; {(sim*100):.1f}%</span>')
 
 
360
 
361
  AutoPredictWatcher()
362
 
363
+
364
  # ---------- Kickoff ----------
365
  on_predict()
366
  Page()