PeterPinetree commited on
Commit
08c4a63
·
verified ·
1 Parent(s): 0e43a26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -14
app.py CHANGED
@@ -43,8 +43,8 @@ hr{ border-color:var(--border); }
43
  # ------------------ App state ------------------
44
  text_rx = solara.reactive("twinkle, twinkle, little ")
45
  preds_rx = solara.reactive(pd.DataFrame(columns=["probs", "id", "tok"]))
46
- selected_token_id_rx = solara.reactive(None)
47
- neighbor_list_rx = solara.reactive([]) # [(tok_display, sim), ...]
48
  notice_rx = solara.reactive("Click a candidate (or hover to preview).")
49
  auto_running_rx = solara.reactive(True)
50
  last_hovered_id_rx = solara.reactive(None)
@@ -59,11 +59,11 @@ neighbors = {}
59
  ids_set = set()
60
 
61
  if COORDS_PATH.exists() and NEIGH_PATH.exists():
62
- coords = json.loads(COORDS_PATH.read_text("utf-8")) # { "tid": [x,y], ... }
63
- neighbors = json.loads(NEIGH_PATH.read_text("utf-8")) # { "neighbors": { "tid": [[nid,sim], ...] } }
64
  ids_set = set(map(int, coords.keys()))
65
  else:
66
- notice_rx.set("Embedding files not found — add assets/embeddings/*.json to enable the map.")
67
 
68
  # ------------------ Helpers ------------------
69
  def display_token_from_id(tid: int) -> str:
@@ -109,10 +109,17 @@ def predict_top10(prompt: str) -> pd.DataFrame:
109
  return df
110
 
111
  def on_predict():
 
112
  df = predict_top10(text_rx.value)
113
  preds_rx.set(df)
114
- if len(df) > 0:
115
- preview_token(int(df.iloc[0]["id"])) # highlight top-1 by default
 
 
 
 
 
 
116
 
117
  # ------------------ Plotly figure ------------------
118
  def base_scatter():
@@ -183,8 +190,9 @@ def append_token(token_id: int):
183
  # keep decode() here so spacing stays correct in the prompt
184
  decoded = tokenizer.decode([int(token_id)])
185
  text_rx.set(text_rx.value + decoded)
186
- preview_token(int(token_id)) # highlight appended token
187
- on_predict() # refresh predictions
 
188
 
189
  # ------------------ Auto-predict on typing (debounced) ------------------
190
  @solara.component
@@ -232,18 +240,21 @@ def PredictionsList():
232
  tok_disp = display_token_from_id(tid)
233
  label = fmt_row(i, prob, tid, tok_disp)
234
 
235
- # Use Div so hover works across Solara versions; note (event, ...) signatures
 
236
  with solara.Div(
237
- classes=["rowbtn"],
238
- style={"justifyContent": "flex-start"},
239
  attributes={"tabindex": "0", "role": "button"},
240
- on_click=lambda e=None, tid=tid: append_token(tid),
241
  on_mouse_enter=lambda e=None, tid=tid: preview_token(tid),
242
  on_mouse_over=lambda e=None, tid=tid: preview_token(tid),
243
  on_mouse_move=lambda e=None, tid=tid: preview_token(tid),
244
  on_focus=lambda e=None, tid=tid: preview_token(tid),
245
  ):
246
- solara.Text(label)
 
 
 
 
 
247
 
248
  # ------------------ Page ------------------
249
  @solara.component
 
43
  # ------------------ App state ------------------
44
  text_rx = solara.reactive("twinkle, twinkle, little ")
45
  preds_rx = solara.reactive(pd.DataFrame(columns=["probs", "id", "tok"]))
46
+ selected_token_id_rx = solara.reactive(None) # the token we’re currently highlighting
47
+ neighbor_list_rx = solara.reactive([]) # [(tok_display, sim), ...]
48
  notice_rx = solara.reactive("Click a candidate (or hover to preview).")
49
  auto_running_rx = solara.reactive(True)
50
  last_hovered_id_rx = solara.reactive(None)
 
59
  ids_set = set()
60
 
61
  if COORDS_PATH.exists() and NEIGH_PATH.exists():
62
+ coords = json.loads(COORDS_PATH.read_text("utf-8")) # { "tid": [x,y], ... }
63
+ neighbors = json.loads(NEIGH_PATH.read_text("utf-8")) # { "neighbors": { "tid": [[nid,sim], ...] } }
64
  ids_set = set(map(int, coords.keys()))
65
  else:
66
+ notice_rx.set("Embedding files not found — add assets/embeddings/*.json`.")
67
 
68
  # ------------------ Helpers ------------------
69
  def display_token_from_id(tid: int) -> str:
 
109
  return df
110
 
111
  def on_predict():
112
+ """Update predictions. Keep current highlight unless there is none yet."""
113
  df = predict_top10(text_rx.value)
114
  preds_rx.set(df)
115
+ if len(df) == 0:
116
+ return
117
+ # Only auto-select the top-1 if we don't have a selection yet.
118
+ if selected_token_id_rx.value is None:
119
+ preview_token(int(df.iloc[0]["id"]))
120
+ else:
121
+ # Keep highlighting whatever the user last hovered/clicked.
122
+ fig_rx.set(highlight(int(selected_token_id_rx.value)))
123
 
124
  # ------------------ Plotly figure ------------------
125
  def base_scatter():
 
190
  # keep decode() here so spacing stays correct in the prompt
191
  decoded = tokenizer.decode([int(token_id)])
192
  text_rx.set(text_rx.value + decoded)
193
+ # lock highlight to the clicked token and keep it after re-predict
194
+ preview_token(int(token_id))
195
+ on_predict() # refresh next tokens but do NOT override highlight anymore
196
 
197
  # ------------------ Auto-predict on typing (debounced) ------------------
198
  @solara.component
 
240
  tok_disp = display_token_from_id(tid)
241
  label = fmt_row(i, prob, tid, tok_disp)
242
 
243
+ # Wrap a Button (click) inside a Div (hover). Div events pass (event, ...),
244
+ # Button click is reliable across Solara versions.
245
  with solara.Div(
 
 
246
  attributes={"tabindex": "0", "role": "button"},
 
247
  on_mouse_enter=lambda e=None, tid=tid: preview_token(tid),
248
  on_mouse_over=lambda e=None, tid=tid: preview_token(tid),
249
  on_mouse_move=lambda e=None, tid=tid: preview_token(tid),
250
  on_focus=lambda e=None, tid=tid: preview_token(tid),
251
  ):
252
+ solara.Button(
253
+ label,
254
+ on_click=lambda tid=tid: append_token(tid),
255
+ classes=["rowbtn"],
256
+ style={"justifyContent": "flex-start", "width": "100%"},
257
+ )
258
 
259
  # ------------------ Page ------------------
260
  @solara.component