Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -13,6 +13,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
| 13 |
# for robust hover/click from the browser
|
| 14 |
import anywidget
|
| 15 |
import traitlets as t
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
# ---------- Model ----------
|
|
@@ -83,7 +84,7 @@ selected_token_id_rx = solara.reactive(None)
|
|
| 83 |
neighbor_list_rx = solara.reactive([])
|
| 84 |
last_hovered_id_rx = solara.reactive(None)
|
| 85 |
auto_running_rx = solara.reactive(True)
|
| 86 |
-
|
| 87 |
|
| 88 |
# ---------- Embedding assets ----------
|
| 89 |
ASSETS = Path("assets/embeddings")
|
|
@@ -178,11 +179,20 @@ def get_neighbor_list(token_id: int, k: int = 20):
|
|
| 178 |
|
| 179 |
def highlight(token_id: int):
|
| 180 |
fig = base_scatter()
|
|
|
|
|
|
|
| 181 |
if not coords or token_id not in ids_set:
|
| 182 |
neighbor_list_rx.set([])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
return fig
|
| 184 |
|
|
|
|
|
|
|
| 185 |
nbrs = get_neighbor_list(token_id, k=20)
|
|
|
|
| 186 |
if nbrs:
|
| 187 |
nx = [coords[str(nid)][0] for nid,_ in nbrs]
|
| 188 |
ny = [coords[str(nid)][1] for nid,_ in nbrs]
|
|
@@ -328,7 +338,17 @@ def PredictionsList():
|
|
| 328 |
tid = int(row["id"])
|
| 329 |
prob = row["probs"] # already a formatted string like "4.12%"
|
| 330 |
tok_disp = display_token_from_id(tid)
|
| 331 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
|
| 333 |
w = HoverList()
|
| 334 |
w.items = items
|
|
@@ -382,6 +402,8 @@ def Page():
|
|
| 382 |
for tok, sim in neighbor_list_rx.value:
|
| 383 |
solara.HTML(tag="span",
|
| 384 |
unsafe_innerHTML=f'<span class="badge">{tok} {(sim*100):.1f}%</span>')
|
|
|
|
|
|
|
| 385 |
|
| 386 |
AutoPredictWatcher()
|
| 387 |
|
|
|
|
| 13 |
# for robust hover/click from the browser
|
| 14 |
import anywidget
|
| 15 |
import traitlets as t
|
| 16 |
+
import html # for escaping token text in the HTML label
|
| 17 |
|
| 18 |
|
| 19 |
# ---------- Model ----------
|
|
|
|
| 84 |
neighbor_list_rx = solara.reactive([])
|
| 85 |
last_hovered_id_rx = solara.reactive(None)
|
| 86 |
auto_running_rx = solara.reactive(True)
|
| 87 |
+
neigh_msg_rx = solara.reactive("") # message shown when no neighborhood is available
|
| 88 |
|
| 89 |
# ---------- Embedding assets ----------
|
| 90 |
ASSETS = Path("assets/embeddings")
|
|
|
|
| 179 |
|
| 180 |
def highlight(token_id: int):
|
| 181 |
fig = base_scatter()
|
| 182 |
+
|
| 183 |
+
# Not in map (or missing map) → clear chips and show message
|
| 184 |
if not coords or token_id not in ids_set:
|
| 185 |
neighbor_list_rx.set([])
|
| 186 |
+
if not coords:
|
| 187 |
+
neigh_msg_rx.set("Embedding map unavailable – add `assets/embeddings/*.json`.")
|
| 188 |
+
else:
|
| 189 |
+
neigh_msg_rx.set("Neighborhood unavailable for this token (not in the top-5k set).")
|
| 190 |
return fig
|
| 191 |
|
| 192 |
+
# In map → clear message and draw neighbors/target
|
| 193 |
+
neigh_msg_rx.set("")
|
| 194 |
nbrs = get_neighbor_list(token_id, k=20)
|
| 195 |
+
|
| 196 |
if nbrs:
|
| 197 |
nx = [coords[str(nid)][0] for nid,_ in nbrs]
|
| 198 |
ny = [coords[str(nid)][1] for nid,_ in nbrs]
|
|
|
|
| 338 |
tid = int(row["id"])
|
| 339 |
prob = row["probs"] # already a formatted string like "4.12%"
|
| 340 |
tok_disp = display_token_from_id(tid)
|
| 341 |
+
tok_safe = html.escape(tok_disp) # protect the HTML label
|
| 342 |
+
|
| 343 |
+
label_html = (
|
| 344 |
+
f'<div class="rowbtn-grid">'
|
| 345 |
+
f' <span class="c0">{i}</span>'
|
| 346 |
+
f' <span class="c1">{prob}</span>'
|
| 347 |
+
f' <span class="c2">{tid}</span>'
|
| 348 |
+
f' <span class="c3">{tok_safe}</span>'
|
| 349 |
+
f'</div>'
|
| 350 |
+
)
|
| 351 |
+
items.append({"tid": tid, "label_html": label_html}) # <-- note label_html
|
| 352 |
|
| 353 |
w = HoverList()
|
| 354 |
w.items = items
|
|
|
|
| 402 |
for tok, sim in neighbor_list_rx.value:
|
| 403 |
solara.HTML(tag="span",
|
| 404 |
unsafe_innerHTML=f'<span class="badge">{tok} {(sim*100):.1f}%</span>')
|
| 405 |
+
elif neigh_msg_rx.value:
|
| 406 |
+
solara.Text(neigh_msg_rx.value, style={"color":"var(--muted)"})
|
| 407 |
|
| 408 |
AutoPredictWatcher()
|
| 409 |
|