PeterPinetree commited on
Commit
f8ca812
·
verified ·
1 Parent(s): e858eb0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -43
app.py CHANGED
@@ -23,35 +23,29 @@ theme_css = """
23
  --text:#0b0f14; /* near-black */
24
  --muted:#6b7280; /* gray-500 */
25
  --border:#e5e7eb; /* gray-200 */
26
- --pale: rgba(56,189,248,0.15); /* base dots */
27
- --mid: rgba(56,189,248,0.65); /* neighbors */
28
- --bright: rgba(34,211,238,1.0); /* selected */
29
  }
30
  body{ background:var(--bg); color:var(--text);}
31
  h1,h2,h3{ color:var(--text); }
32
  hr{ border-color:var(--border); }
33
- .btn{ background:var(--primary); color:#000; border:1px solid var(--primary); padding:6px 10px; border-radius:10px; }
34
  .badge{ display:inline-block; padding:2px 8px; border:1px solid var(--border); border-radius:999px; margin:2px; }
 
35
  .rowbtn{
36
  width:100%; text-align:left; padding:10px 12px; border-radius:12px;
37
- border:1px solid var(--border); background:#fff;
 
 
38
  }
39
  .rowbtn:hover{ background:#f7fbff; border-color:#c3e8fb; cursor:pointer; }
40
- .rowgrid{ display:grid; grid-template-columns: 70px 120px 1fr; gap:10px; align-items:center; }
41
- .rowprob{ color:#111; font-weight:600; }
42
- .rowid{ color:var(--muted); font-variant-numeric: tabular-nums; }
43
- .rowtok{ color:#111; overflow:hidden; text-overflow:ellipsis; white-space:nowrap; }
44
- .listheader{ color:var(--muted); font-size:12px; margin-bottom:6px;}
45
  """
46
 
47
  # ------------------ App state ------------------
48
  text_rx = solara.reactive("twinkle, twinkle, little ")
49
  preds_rx = solara.reactive(pd.DataFrame(columns=["probs", "id", "tok"]))
50
  selected_token_id_rx = solara.reactive(None)
51
- neighbor_list_rx = solara.reactive([]) # list of (tok, sim)
52
-
53
  notice_rx = solara.reactive("Click a candidate (or hover to preview).")
54
- auto_running_rx = solara.reactive(True) # toggle if you ever want to disable auto-predict
 
55
 
56
  # ------------------ Embedding assets ------------------
57
  ASSETS = Path("assets/embeddings")
@@ -63,16 +57,28 @@ neighbors = {}
63
  ids_set = set()
64
 
65
  if COORDS_PATH.exists() and NEIGH_PATH.exists():
66
- coords = json.loads(COORDS_PATH.read_text("utf-8"))
67
- neighbors = json.loads(NEIGH_PATH.read_text("utf-8"))
68
  ids_set = set(map(int, coords.keys()))
69
  else:
70
  notice_rx.set("Embedding files not found — add assets/embeddings/*.json to enable the map.")
71
 
 
 
 
 
 
 
 
 
 
 
 
72
  # ------------------ Predict ------------------
73
  def predict_top10(prompt: str) -> pd.DataFrame:
74
  if not prompt:
75
  return pd.DataFrame(columns=["probs", "id", "tok"])
 
76
  tokens = tokenizer.encode(prompt, return_tensors="pt")
77
  out = model.generate(
78
  tokens,
@@ -98,8 +104,7 @@ def on_predict():
98
  df = predict_top10(text_rx.value)
99
  preds_rx.set(df)
100
  if len(df) > 0:
101
- tid = int(df.iloc[0]["id"])
102
- preview_token(tid) # highlight top-1 by default
103
 
104
  # ------------------ Plotly figure ------------------
105
  def base_scatter():
@@ -121,11 +126,11 @@ def base_scatter():
121
 
122
  fig_rx = solara.reactive(base_scatter())
123
 
124
- def get_neighbor_list(token_id: int, k: int = 18):
125
  if not ids_set or token_id not in ids_set:
126
  return []
127
  raw = neighbors.get("neighbors", {}).get(str(token_id), [])
128
- return raw[:k] # [ [nid, sim], ... ]
129
 
130
  def highlight(token_id: int):
131
  """Return figure with neighbors + target highlighted and update neighbor chip list."""
@@ -134,7 +139,6 @@ def highlight(token_id: int):
134
  neighbor_list_rx.set([])
135
  return fig
136
 
137
- # neighbors
138
  nbrs = get_neighbor_list(token_id, k=20)
139
  if nbrs:
140
  nx = [coords[str(nid)][0] for nid, _ in nbrs]
@@ -144,7 +148,6 @@ def highlight(token_id: int):
144
  marker=dict(size=6, color="rgba(56,189,248,0.75)", symbol="circle"),
145
  hoverinfo="skip",
146
  ))
147
- # update chip list
148
  chips = []
149
  for nid, sim in nbrs:
150
  chips.append((tokenizer.decode([int(nid)]), float(sim)))
@@ -152,7 +155,6 @@ def highlight(token_id: int):
152
  else:
153
  neighbor_list_rx.set([])
154
 
155
- # target
156
  tx, ty = coords[str(token_id)]
157
  fig.add_trace(go.Scattergl(
158
  x=[tx], y=[ty], mode="markers",
@@ -162,14 +164,18 @@ def highlight(token_id: int):
162
  return fig
163
 
164
  def preview_token(token_id: int):
165
- selected_token_id_rx.set(int(token_id))
166
- fig_rx.set(highlight(int(token_id)))
 
 
 
 
167
 
168
  def append_token(token_id: int):
169
  decoded = tokenizer.decode([int(token_id)])
170
  text_rx.set(text_rx.value + decoded)
171
  preview_token(int(token_id)) # highlight the one we appended
172
- on_predict() # refresh next top-10 for new text
173
 
174
  # ------------------ Auto-predict on typing (debounced) ------------------
175
  @solara.component
@@ -196,25 +202,30 @@ def AutoPredictWatcher():
196
  return cleanup
197
 
198
  solara.use_effect(effect, [text, auto])
199
- return solara.Text("", style={"display":"none"}) # nothing visible
200
 
201
  # ------------------ UI: custom clickable/hoverable list ------------------
202
  @solara.component
203
  def PredictionsList():
204
  df = preds_rx.value
205
- with solara.Column(gap="6px", style={"maxWidth":"720px"}):
206
  solara.Markdown("### Prediction")
207
- solara.HTML(tag="div", unsafe_innerHTML='<div class="listheader"># &nbsp;&nbsp; probs &nbsp;&nbsp; token id &nbsp;&nbsp; predicted next token</div>')
 
 
 
 
 
 
208
  for i, row in df.iterrows():
209
  tid = int(row["id"]); tok = row["tok"]; prob = row["probs"]
210
- label = f'<div class="rowgrid"><div class="rowprob">{i}&nbsp;&nbsp;{prob}</div><div class="rowid">{tid}</div><div class="rowtok">{tok}</div></div>'
211
- # full-width button acts like a row; hover previews, click appends
212
- solara.Button(None,
213
- on_click=lambda tid=tid: append_token(tid),
214
- on_mouse_enter=lambda tid=tid: preview_token(tid),
215
- html=label,
216
- classes=["rowbtn"],
217
- )
218
 
219
  # ------------------ Page ------------------
220
  @solara.component
@@ -229,11 +240,11 @@ def Page():
229
  "Hover a candidate to preview its neighborhood."
230
  )
231
 
232
- # Input row (no Predict button needed anymore)
233
- solara.InputText("Enter text", value=text_rx, continuous_update=True, style={"minWidth":"520px"})
234
-
235
  solara.Markdown(f"*{notice_rx.value}*")
236
 
 
237
  with solara.Row(gap="24px", style={"align-items": "flex-start"}):
238
  with solara.Column():
239
  PredictionsList()
@@ -245,15 +256,17 @@ def Page():
245
  else:
246
  solara.FigurePlotly(fig_rx.value)
247
 
248
- # Neighbor chips
249
  if neighbor_list_rx.value:
250
  solara.Markdown("**Nearest neighbors:**")
251
  with solara.Row(style={"flex-wrap": "wrap"}):
252
  for tok, sim in neighbor_list_rx.value:
253
- solara.HTML(tag="span", unsafe_innerHTML=f'<span class="badge">{tok} &nbsp; {(sim*100):.1f}%</span>')
 
 
 
254
 
255
- AutoPredictWatcher() # invisible component that triggers auto-predict
256
 
257
- # Seed initial predictions
258
  on_predict()
259
  Page()
 
23
  --text:#0b0f14; /* near-black */
24
  --muted:#6b7280; /* gray-500 */
25
  --border:#e5e7eb; /* gray-200 */
 
 
 
26
  }
27
  body{ background:var(--bg); color:var(--text);}
28
  h1,h2,h3{ color:var(--text); }
29
  hr{ border-color:var(--border); }
 
30
  .badge{ display:inline-block; padding:2px 8px; border:1px solid var(--border); border-radius:999px; margin:2px; }
31
+
32
  .rowbtn{
33
  width:100%; text-align:left; padding:10px 12px; border-radius:12px;
34
+ border:1px solid var(--border); background:#fff; color:var(--text);
35
+ font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
36
+ letter-spacing: .2px;
37
  }
38
  .rowbtn:hover{ background:#f7fbff; border-color:#c3e8fb; cursor:pointer; }
 
 
 
 
 
39
  """
40
 
41
  # ------------------ App state ------------------
42
  text_rx = solara.reactive("twinkle, twinkle, little ")
43
  preds_rx = solara.reactive(pd.DataFrame(columns=["probs", "id", "tok"]))
44
  selected_token_id_rx = solara.reactive(None)
45
+ neighbor_list_rx = solara.reactive([]) # [(tok, sim), ...]
 
46
  notice_rx = solara.reactive("Click a candidate (or hover to preview).")
47
+ auto_running_rx = solara.reactive(True)
48
+ last_hovered_id_rx = solara.reactive(None)
49
 
50
  # ------------------ Embedding assets ------------------
51
  ASSETS = Path("assets/embeddings")
 
57
  ids_set = set()
58
 
59
  if COORDS_PATH.exists() and NEIGH_PATH.exists():
60
+ coords = json.loads(COORDS_PATH.read_text("utf-8")) # { "tid": [x,y], ... }
61
+ neighbors = json.loads(NEIGH_PATH.read_text("utf-8")) # { "neighbors": { "tid": [[nid,sim], ...] } }
62
  ids_set = set(map(int, coords.keys()))
63
  else:
64
  notice_rx.set("Embedding files not found — add assets/embeddings/*.json to enable the map.")
65
 
66
+ # ------------------ Helpers ------------------
67
+ def show_token(tok: str) -> str:
68
+ """Make invisible whitespace visible but readable in UI."""
69
+ if tok == "":
70
+ return "⟂"
71
+ return tok.replace(" ", "␠")
72
+
73
+ def fmt_row(idx: int, prob: str, tid: int, tok: str) -> str:
74
+ tok_disp = show_token(tok)
75
+ return f"{idx:>2} {prob:>6} {tid:<6} {tok_disp}"
76
+
77
  # ------------------ Predict ------------------
78
  def predict_top10(prompt: str) -> pd.DataFrame:
79
  if not prompt:
80
  return pd.DataFrame(columns=["probs", "id", "tok"])
81
+
82
  tokens = tokenizer.encode(prompt, return_tensors="pt")
83
  out = model.generate(
84
  tokens,
 
104
  df = predict_top10(text_rx.value)
105
  preds_rx.set(df)
106
  if len(df) > 0:
107
+ preview_token(int(df.iloc[0]["id"])) # highlight top-1 by default
 
108
 
109
  # ------------------ Plotly figure ------------------
110
  def base_scatter():
 
126
 
127
  fig_rx = solara.reactive(base_scatter())
128
 
129
+ def get_neighbor_list(token_id: int, k: int = 20):
130
  if not ids_set or token_id not in ids_set:
131
  return []
132
  raw = neighbors.get("neighbors", {}).get(str(token_id), [])
133
+ return raw[:k]
134
 
135
  def highlight(token_id: int):
136
  """Return figure with neighbors + target highlighted and update neighbor chip list."""
 
139
  neighbor_list_rx.set([])
140
  return fig
141
 
 
142
  nbrs = get_neighbor_list(token_id, k=20)
143
  if nbrs:
144
  nx = [coords[str(nid)][0] for nid, _ in nbrs]
 
148
  marker=dict(size=6, color="rgba(56,189,248,0.75)", symbol="circle"),
149
  hoverinfo="skip",
150
  ))
 
151
  chips = []
152
  for nid, sim in nbrs:
153
  chips.append((tokenizer.decode([int(nid)]), float(sim)))
 
155
  else:
156
  neighbor_list_rx.set([])
157
 
 
158
  tx, ty = coords[str(token_id)]
159
  fig.add_trace(go.Scattergl(
160
  x=[tx], y=[ty], mode="markers",
 
164
  return fig
165
 
166
  def preview_token(token_id: int):
167
+ token_id = int(token_id)
168
+ if last_hovered_id_rx.value == token_id:
169
+ return
170
+ last_hovered_id_rx.set(token_id)
171
+ selected_token_id_rx.set(token_id)
172
+ fig_rx.set(highlight(token_id))
173
 
174
  def append_token(token_id: int):
175
  decoded = tokenizer.decode([int(token_id)])
176
  text_rx.set(text_rx.value + decoded)
177
  preview_token(int(token_id)) # highlight the one we appended
178
+ on_predict() # refresh next top-10 for new text
179
 
180
  # ------------------ Auto-predict on typing (debounced) ------------------
181
  @solara.component
 
202
  return cleanup
203
 
204
  solara.use_effect(effect, [text, auto])
205
+ return solara.Text("", style={"display": "none"})
206
 
207
  # ------------------ UI: custom clickable/hoverable list ------------------
208
  @solara.component
209
  def PredictionsList():
210
  df = preds_rx.value
211
+ with solara.Column(gap="6px", style={"maxWidth": "720px"}):
212
  solara.Markdown("### Prediction")
213
+ solara.Text(
214
+ " # probs token predicted next token",
215
+ style={
216
+ "color": "var(--muted)",
217
+ "fontFamily": 'ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace',
218
+ },
219
+ )
220
  for i, row in df.iterrows():
221
  tid = int(row["id"]); tok = row["tok"]; prob = row["probs"]
222
+ label = fmt_row(i, prob, tid, tok)
223
+ solara.Button(
224
+ label,
225
+ on_click=lambda tid=tid: append_token(tid),
226
+ on_mouse_enter=lambda tid=tid: preview_token(tid),
227
+ classes=["rowbtn"],
228
+ )
 
229
 
230
  # ------------------ Page ------------------
231
  @solara.component
 
240
  "Hover a candidate to preview its neighborhood."
241
  )
242
 
243
+ # Input (auto-predict is handled by watcher)
244
+ solara.InputText("Enter text", value=text_rx, continuous_update=True, style={"minWidth": "520px"})
 
245
  solara.Markdown(f"*{notice_rx.value}*")
246
 
247
+ # Two columns
248
  with solara.Row(gap="24px", style={"align-items": "flex-start"}):
249
  with solara.Column():
250
  PredictionsList()
 
256
  else:
257
  solara.FigurePlotly(fig_rx.value)
258
 
 
259
  if neighbor_list_rx.value:
260
  solara.Markdown("**Nearest neighbors:**")
261
  with solara.Row(style={"flex-wrap": "wrap"}):
262
  for tok, sim in neighbor_list_rx.value:
263
+ solara.HTML(
264
+ tag="span",
265
+ unsafe_innerHTML=f'<span class="badge">{show_token(tok)} &nbsp; {(sim*100):.1f}%</span>',
266
+ )
267
 
268
+ AutoPredictWatcher()
269
 
270
+ # Seed initial predictions and mount
271
  on_predict()
272
  Page()