PeterPinetree committed on
Commit
775aa53
·
verified ·
1 Parent(s): f8ca812

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -27
app.py CHANGED
@@ -1,4 +1,5 @@
1
  # app.py
 
2
  import json
3
  from pathlib import Path
4
  import threading, time
@@ -30,10 +31,12 @@ hr{ border-color:var(--border); }
30
  .badge{ display:inline-block; padding:2px 8px; border:1px solid var(--border); border-radius:999px; margin:2px; }
31
 
32
  .rowbtn{
33
- width:100%; text-align:left; padding:10px 12px; border-radius:12px;
34
  border:1px solid var(--border); background:#fff; color:var(--text);
 
 
35
  font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
36
- letter-spacing: .2px;
37
  }
38
  .rowbtn:hover{ background:#f7fbff; border-color:#c3e8fb; cursor:pointer; }
39
  """
@@ -42,7 +45,7 @@ hr{ border-color:var(--border); }
42
  text_rx = solara.reactive("twinkle, twinkle, little ")
43
  preds_rx = solara.reactive(pd.DataFrame(columns=["probs", "id", "tok"]))
44
  selected_token_id_rx = solara.reactive(None)
45
- neighbor_list_rx = solara.reactive([]) # [(tok, sim), ...]
46
  notice_rx = solara.reactive("Click a candidate (or hover to preview).")
47
  auto_running_rx = solara.reactive(True)
48
  last_hovered_id_rx = solara.reactive(None)
@@ -64,15 +67,21 @@ else:
64
  notice_rx.set("Embedding files not found — add assets/embeddings/*.json to enable the map.")
65
 
66
  # ------------------ Helpers ------------------
67
- def show_token(tok: str) -> str:
68
- """Make invisible whitespace visible but readable in UI."""
69
- if tok == "":
70
- return ""
71
- return tok.replace(" ", "")
72
-
73
- def fmt_row(idx: int, prob: str, tid: int, tok: str) -> str:
74
- tok_disp = show_token(tok)
75
- return f"{idx:>2} {prob:>6} {tid:<6} {tok_disp}"
 
 
 
 
 
 
76
 
77
  # ------------------ Predict ------------------
78
  def predict_top10(prompt: str) -> pd.DataFrame:
@@ -95,7 +104,7 @@ def predict_top10(prompt: str) -> pd.DataFrame:
95
  topk = torch.topk(scores, 10)
96
  ids = [int(topk.indices[0, i]) for i in range(10)]
97
  probs = [float(topk.values[0, i]) for i in range(10)]
98
- toks = [tokenizer.decode([i]) for i in ids]
99
  df = pd.DataFrame({"probs": probs, "id": ids, "tok": toks})
100
  df["probs"] = df["probs"].map(lambda p: f"{p:.2%}")
101
  return df
@@ -113,7 +122,7 @@ def base_scatter():
113
  xs, ys = zip(*[coords[k] for k in coords.keys()])
114
  fig.add_trace(go.Scattergl(
115
  x=xs, y=ys, mode="markers",
116
- marker=dict(size=3, opacity=1.0, color="rgba(56,189,248,0.15)"),
117
  hoverinfo="skip",
118
  ))
119
  fig.update_layout(
@@ -145,12 +154,12 @@ def highlight(token_id: int):
145
  ny = [coords[str(nid)][1] for nid, _ in nbrs]
146
  fig.add_trace(go.Scattergl(
147
  x=nx, y=ny, mode="markers",
148
- marker=dict(size=6, color="rgba(56,189,248,0.75)", symbol="circle"),
149
  hoverinfo="skip",
150
  ))
151
  chips = []
152
  for nid, sim in nbrs:
153
- chips.append((tokenizer.decode([int(nid)]), float(sim)))
154
  neighbor_list_rx.set(chips)
155
  else:
156
  neighbor_list_rx.set([])
@@ -158,7 +167,7 @@ def highlight(token_id: int):
158
  tx, ty = coords[str(token_id)]
159
  fig.add_trace(go.Scattergl(
160
  x=[tx], y=[ty], mode="markers",
161
- marker=dict(size=10, color="rgba(34,211,238,1.0)", line=dict(width=1)),
162
  hoverinfo="skip",
163
  ))
164
  return fig
@@ -172,10 +181,11 @@ def preview_token(token_id: int):
172
  fig_rx.set(highlight(token_id))
173
 
174
  def append_token(token_id: int):
 
175
  decoded = tokenizer.decode([int(token_id)])
176
  text_rx.set(text_rx.value + decoded)
177
- preview_token(int(token_id)) # highlight the one we appended
178
- on_predict() # refresh next top-10 for new text
179
 
180
  # ------------------ Auto-predict on typing (debounced) ------------------
181
  @solara.component
@@ -190,7 +200,7 @@ def AutoPredictWatcher():
190
  snap = text
191
 
192
  def worker():
193
- time.sleep(0.25) # debounce ~250ms
194
  if not cancelled and snap == text_rx.value:
195
  on_predict()
196
 
@@ -211,20 +221,25 @@ def PredictionsList():
211
  with solara.Column(gap="6px", style={"maxWidth": "720px"}):
212
  solara.Markdown("### Prediction")
213
  solara.Text(
214
- " # probs token predicted next token",
215
  style={
216
  "color": "var(--muted)",
217
  "fontFamily": 'ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace',
218
  },
219
  )
220
  for i, row in df.iterrows():
221
- tid = int(row["id"]); tok = row["tok"]; prob = row["probs"]
222
- label = fmt_row(i, prob, tid, tok)
 
 
223
  solara.Button(
224
  label,
225
  on_click=lambda tid=tid: append_token(tid),
226
- on_mouse_enter=lambda tid=tid: preview_token(tid),
 
 
227
  classes=["rowbtn"],
 
228
  )
229
 
230
  # ------------------ Page ------------------
@@ -240,7 +255,7 @@ def Page():
240
  "Hover a candidate to preview its neighborhood."
241
  )
242
 
243
- # Input (auto-predict is handled by watcher)
244
  solara.InputText("Enter text", value=text_rx, continuous_update=True, style={"minWidth": "520px"})
245
  solara.Markdown(f"*{notice_rx.value}*")
246
 
@@ -262,11 +277,11 @@ def Page():
262
  for tok, sim in neighbor_list_rx.value:
263
  solara.HTML(
264
  tag="span",
265
- unsafe_innerHTML=f'<span class="badge">{show_token(tok)} &nbsp; {(sim*100):.1f}%</span>',
266
  )
267
 
268
  AutoPredictWatcher()
269
 
270
  # Seed initial predictions and mount
271
  on_predict()
272
- Page()
 
1
  # app.py
2
+ # app.py
3
  import json
4
  from pathlib import Path
5
  import threading, time
 
31
  .badge{ display:inline-block; padding:2px 8px; border:1px solid var(--border); border-radius:999px; margin:2px; }
32
 
33
  .rowbtn{
34
+ width:100%; padding:10px 12px; border-radius:12px;
35
  border:1px solid var(--border); background:#fff; color:var(--text);
36
+ display:flex; justify-content:flex-start; align-items:center;
37
+ text-align:left;
38
  font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
39
+ letter-spacing:.2px;
40
  }
41
  .rowbtn:hover{ background:#f7fbff; border-color:#c3e8fb; cursor:pointer; }
42
  """
 
45
  text_rx = solara.reactive("twinkle, twinkle, little ")
46
  preds_rx = solara.reactive(pd.DataFrame(columns=["probs", "id", "tok"]))
47
  selected_token_id_rx = solara.reactive(None)
48
+ neighbor_list_rx = solara.reactive([]) # [(tok_display, sim), ...]
49
  notice_rx = solara.reactive("Click a candidate (or hover to preview).")
50
  auto_running_rx = solara.reactive(True)
51
  last_hovered_id_rx = solara.reactive(None)
 
67
  notice_rx.set("Embedding files not found — add assets/embeddings/*.json to enable the map.")
68
 
69
  # ------------------ Helpers ------------------
70
def display_token_from_id(tid: int) -> str:
    """Return a readable UI label for a single token id.

    The vocabulary entry is looked up with
    ``tokenizer.convert_ids_to_tokens`` (special tokens skipped), leading
    word-boundary markers are removed, newlines are shown as ``↵``, and a
    label that is empty or whitespace-only is shown as ``␠``.

    Args:
        tid: Token id; coerced with ``int()`` before the lookup.

    Returns:
        The cleaned label, or ``"␠"`` when nothing visible remains.
    """
    pieces = tokenizer.convert_ids_to_tokens([int(tid)], skip_special_tokens=True)
    # The lookup yields an empty list when the id was skipped as special.
    label = pieces[0] if pieces else ""

    # NOTE(review): the first marker was an empty string in the reviewed
    # text, which strips nothing. "▁" (the SentencePiece word-boundary
    # marker) is presumably what was intended -- confirm against the
    # tokenizer actually loaded by this app.
    for marker in ("▁", "Ġ"):  # common leading markers
        if label.startswith(marker):
            label = label[len(marker):]

    label = label.replace("\n", "↵")
    if not label.strip():
        return "␠"  # visible space marker for pure whitespace
    return label
81
+
82
def fmt_row(idx: int, prob: str, tid: int, tok_disp: str) -> str:
    """Build the label for one prediction row.

    The row index, probability text and token id are left-justified to
    widths 2, 7 and 6 respectively. Fields are joined by a single space and
    the token display text is appended without padding.
    """
    columns = (
        format(idx, "<2"),
        format(prob, "<7"),
        format(tid, "<6"),
        format(tok_disp, ""),
    )
    return " ".join(columns)
85
 
86
  # ------------------ Predict ------------------
87
  def predict_top10(prompt: str) -> pd.DataFrame:
 
104
  topk = torch.topk(scores, 10)
105
  ids = [int(topk.indices[0, i]) for i in range(10)]
106
  probs = [float(topk.values[0, i]) for i in range(10)]
107
+ toks = [tokenizer.decode([i]) for i in ids] # used for append; display uses display_token_from_id
108
  df = pd.DataFrame({"probs": probs, "id": ids, "tok": toks})
109
  df["probs"] = df["probs"].map(lambda p: f"{p:.2%}")
110
  return df
 
122
  xs, ys = zip(*[coords[k] for k in coords.keys()])
123
  fig.add_trace(go.Scattergl(
124
  x=xs, y=ys, mode="markers",
125
+ marker=dict(size=3, opacity=1.0, color="rgba(56,189,248,0.15)"), # pale cloud
126
  hoverinfo="skip",
127
  ))
128
  fig.update_layout(
 
154
  ny = [coords[str(nid)][1] for nid, _ in nbrs]
155
  fig.add_trace(go.Scattergl(
156
  x=nx, y=ny, mode="markers",
157
+ marker=dict(size=6, color="rgba(56,189,248,0.75)", symbol="circle"), # darker neighbors
158
  hoverinfo="skip",
159
  ))
160
  chips = []
161
  for nid, sim in nbrs:
162
+ chips.append((display_token_from_id(int(nid)), float(sim)))
163
  neighbor_list_rx.set(chips)
164
  else:
165
  neighbor_list_rx.set([])
 
167
  tx, ty = coords[str(token_id)]
168
  fig.add_trace(go.Scattergl(
169
  x=[tx], y=[ty], mode="markers",
170
+ marker=dict(size=10, color="rgba(34,211,238,1.0)", line=dict(width=1)), # bright target
171
  hoverinfo="skip",
172
  ))
173
  return fig
 
181
  fig_rx.set(highlight(token_id))
182
 
183
def append_token(token_id: int):
    """Append the decoded text of ``token_id`` to the prompt and refresh the UI.

    In order: the prompt in ``text_rx`` is extended with the decoded token,
    ``preview_token`` is called for the appended token, and ``on_predict``
    is called to refresh the predictions.
    """
    tid = int(token_id)
    # Use decode() (not the display label) so spacing stays correct in the prompt.
    piece = tokenizer.decode([tid])
    extended = text_rx.value + piece
    text_rx.set(extended)
    preview_token(tid)  # highlight the token that was just appended
    on_predict()  # refresh predictions for the new prompt
189
 
190
  # ------------------ Auto-predict on typing (debounced) ------------------
191
  @solara.component
 
200
  snap = text
201
 
202
  def worker():
203
+ time.sleep(0.25) # ~250ms debounce
204
  if not cancelled and snap == text_rx.value:
205
  on_predict()
206
 
 
221
  with solara.Column(gap="6px", style={"maxWidth": "720px"}):
222
  solara.Markdown("### Prediction")
223
  solara.Text(
224
+ " # probs token predicted next token",
225
  style={
226
  "color": "var(--muted)",
227
  "fontFamily": 'ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace',
228
  },
229
  )
230
  for i, row in df.iterrows():
231
+ tid = int(row["id"])
232
+ prob = row["probs"]
233
+ tok_disp = display_token_from_id(tid) # clean label
234
+ label = fmt_row(i, prob, tid, tok_disp)
235
  solara.Button(
236
  label,
237
  on_click=lambda tid=tid: append_token(tid),
238
+ on_mouse_enter=lambda tid=tid: preview_token(tid), # newer solara
239
+ on_mouse_over=lambda tid=tid: preview_token(tid), # older solara
240
+ on_focus=lambda tid=tid: preview_token(tid), # keyboard
241
  classes=["rowbtn"],
242
+ style={"justifyContent": "flex-start"},
243
  )
244
 
245
  # ------------------ Page ------------------
 
255
  "Hover a candidate to preview its neighborhood."
256
  )
257
 
258
+ # Input (auto-predict handled by watcher)
259
  solara.InputText("Enter text", value=text_rx, continuous_update=True, style={"minWidth": "520px"})
260
  solara.Markdown(f"*{notice_rx.value}*")
261
 
 
277
  for tok, sim in neighbor_list_rx.value:
278
  solara.HTML(
279
  tag="span",
280
+ unsafe_innerHTML=f'<span class="badge">{tok} &nbsp; {(sim*100):.1f}%</span>',
281
  )
282
 
283
  AutoPredictWatcher()
284
 
285
  # Seed initial predictions and mount
286
  on_predict()
287
+ Page()