amirali1985 commited on
Commit
7bf43a3
·
verified ·
1 Parent(s): c30f61c

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +58 -22
app.py CHANGED
@@ -288,34 +288,20 @@ activation-level probing or SAEs needed. This is what we test on
288
  detail_btn = gr.Button("Show splits")
289
  detail_table = gr.Dataframe(headers=["Split", "Accuracy", "N"], interactive=False)
290
 
291
- with gr.Accordion("Eval Sets Info", open=False):
292
- gr.Markdown("""
293
- **Fixed eval sets** (seed=42, cached, deterministic — all models evaluated on identical examples):
294
-
295
- | Split Type | Splits | Examples | Description |
296
- |-----------|--------|----------|-------------|
297
- | Quirke cascades (add) | S0–S6 | 250 each | Carry cascade depth 0–6 |
298
- | Quirke cascades (sub) | M0–M5 | 250 each | Borrow cascade depth 0–5 (M6 impossible for 6-digit) |
299
- | Hot carry chains | C3–C6 | 250 each | Varied answer digits (not just 0s) |
300
- | Hot borrow chains | B3–B5 | 250 each | Varied answer digits (not just 9s) |
301
- | Random | add_random, sub_random | 1000 each | Uniform random |
302
-
303
- **Total**: 1400 examples (add_sub), 750 examples (add-only)
304
- """)
305
 
306
  def get_queue_status_text(n_models):
307
  """Show live queue status from HF-uploaded queue_status.json."""
308
- EXPECTED = 90
309
-
310
- # Try to read live queue status
311
  try:
312
  path = hf_hub_download(MODEL_REPO, "queue_status.json",
313
  local_dir="/tmp/hf_dash_cache")
314
  with open(path) as f:
315
  qs = json.load(f)
316
 
317
- total = qs.get("total", EXPECTED)
318
- done = qs.get("done", 0)
319
  failed = qs.get("failed", 0)
320
  running = qs.get("running", 0)
321
  pending = qs.get("pending", 0)
@@ -365,6 +351,55 @@ activation-level probing or SAEs needed. This is what we test on
365
  f"`{bar}`"
366
  )
367
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
  def on_refresh(arch):
369
  models = fetch_all_models()
370
  df = build_comparison_table(models, arch_filter=arch, enriched_only=False)
@@ -372,6 +407,7 @@ activation-level probing or SAEs needed. This is what we test on
372
  n_models = len(models)
373
  summary = f"**{n_models}** models | Arch filter: {arch}"
374
  q_status = get_queue_status_text(n_models)
 
375
 
376
  main_cols = ["Ops", "Data", "Arch", "Baseline", "SoRL", "Config", "B_wandb", "S_wandb"]
377
 
@@ -380,7 +416,7 @@ activation-level probing or SAEs needed. This is what we test on
380
  if (c.startswith("B_") or c.startswith("S_")) and "wandb" not in c]
381
  hard_df = df[["Ops", "Data", "Config"] + hard_cols] if len(df) > 0 else pd.DataFrame()
382
 
383
- return models, summary, q_status, main_df, hard_df
384
 
385
  def on_detail(models, name):
386
  return build_detailed_splits(models, name.strip())
@@ -388,13 +424,13 @@ activation-level probing or SAEs needed. This is what we test on
388
  refresh_btn.click(
389
  on_refresh,
390
  inputs=[arch_filter],
391
- outputs=[models_state, summary_text, queue_status, main_table, hard_table],
392
  )
393
 
394
  arch_filter.change(
395
  on_refresh,
396
  inputs=[arch_filter],
397
- outputs=[models_state, summary_text, queue_status, main_table, hard_table],
398
  )
399
 
400
  detail_btn.click(on_detail, inputs=[models_state, model_selector], outputs=[detail_table])
 
288
  detail_btn = gr.Button("Show splits")
289
  detail_table = gr.Dataframe(headers=["Split", "Accuracy", "N"], interactive=False)
290
 
291
+ with gr.Accordion("About This Study", open=False):
292
+ eval_info_md = gr.Markdown("")
 
 
 
 
 
 
 
 
 
 
 
 
293
 
294
  def get_queue_status_text(n_models):
295
  """Show live queue status from HF-uploaded queue_status.json."""
296
+ # Try to read live queue status (pushed by upload_status daemon)
 
 
297
  try:
298
  path = hf_hub_download(MODEL_REPO, "queue_status.json",
299
  local_dir="/tmp/hf_dash_cache")
300
  with open(path) as f:
301
  qs = json.load(f)
302
 
303
+ total = qs.get("total", n_models) # fall back to HF model count
304
+ done = qs.get("done", n_models)
305
  failed = qs.get("failed", 0)
306
  running = qs.get("running", 0)
307
  pending = qs.get("pending", 0)
 
351
  f"`{bar}`"
352
  )
353
 
354
+ def build_eval_info(models):
355
+ """Build eval set description from actual model metadata."""
356
+ # Try to get eval config from first available model
357
+ n_per_split = "?"
358
+ n_digits = 6
359
+ splits = []
360
+ total = "?"
361
+ for m in models:
362
+ metrics = m.get("metrics", {})
363
+ for key in ("sft_eval", "sorl_eval"):
364
+ cfg = metrics.get(key, {}).get("config", {})
365
+ if cfg.get("n_per_split"):
366
+ n_per_split = cfg["n_per_split"]
367
+ n_digits = cfg.get("n_digits", 6)
368
+ total = metrics[key].get("summary", {}).get("total_examples", "?")
369
+ splits = list(metrics[key].get("splits", {}).keys())
370
+ break
371
+ if splits:
372
+ break
373
+
374
+ n_add_cascade = len([s for s in splits if s.startswith("add_S")])
375
+ n_sub_cascade = len([s for s in splits if s.startswith("sub_M")])
376
+ n_hot_carry = len([s for s in splits if s.startswith("add_C")])
377
+ n_hot_borrow = len([s for s in splits if s.startswith("sub_B")])
378
+
379
+ return f"""**Replication of [Quirke et al. 2024](https://arxiv.org/abs/2402.02619)** — \
380
+ understanding addition and subtraction in transformers.
381
+
382
+ We train tiny Qwen3 models (2L/3H/510d, ~8M transformer params) from scratch on \
383
+ {n_digits}-digit arithmetic. SoRL v1 (info-gain loss) adds learnable "abstraction tokens" \
384
+ every K positions. The hypothesis: SoRL externalizes carry/borrow circuits that Quirke \
385
+ found via activation-level analysis as explicit, interpretable tokens.
386
+
387
+ **Eval**: autoregressive (errors propagate, no teacher forcing). Fixed eval sets (seed=42, cached).
388
+
389
+ | Split Type | Splits | Examples | Description |
390
+ |-----------|--------|----------|-------------|
391
+ | Carry cascades | S0–S{n_add_cascade - 1} | {n_per_split} each | Carry cascade depth (Quirke §3.2) |
392
+ | Borrow cascades | M0–M{n_sub_cascade - 1} | {n_per_split} each | Borrow cascade depth (Quirke §3.3) |
393
+ | Hot carry chains | C3–C{2 + n_hot_carry} | {n_per_split} each | Cascades with varied answer digits |
394
+ | Hot borrow chains | B3–B{2 + n_hot_borrow} | {n_per_split} each | Borrow cascades with varied digits |
395
+ | Random | add\\_random, sub\\_random | {n_per_split * 4} each | Uniform random |
396
+
397
+ **Total**: {total} examples per eval. \
398
+ [Paper](https://arxiv.org/abs/2402.02619) · \
399
+ [Models](https://huggingface.co/thoughtworks/arithmetic-sorl) · \
400
+ [Data](https://huggingface.co/datasets/thoughtworks/arithmetic-sorl-data) · \
401
+ [Code](https://github.com/thoughtworks/mod_gpt)"""
402
+
403
  def on_refresh(arch):
404
  models = fetch_all_models()
405
  df = build_comparison_table(models, arch_filter=arch, enriched_only=False)
 
407
  n_models = len(models)
408
  summary = f"**{n_models}** models | Arch filter: {arch}"
409
  q_status = get_queue_status_text(n_models)
410
+ eval_info = build_eval_info(models)
411
 
412
  main_cols = ["Ops", "Data", "Arch", "Baseline", "SoRL", "Config", "B_wandb", "S_wandb"]
413
 
 
416
  if (c.startswith("B_") or c.startswith("S_")) and "wandb" not in c]
417
  hard_df = df[["Ops", "Data", "Config"] + hard_cols] if len(df) > 0 else pd.DataFrame()
418
 
419
+ return models, summary, q_status, main_df, hard_df, eval_info
420
 
421
  def on_detail(models, name):
422
  return build_detailed_splits(models, name.strip())
 
424
  refresh_btn.click(
425
  on_refresh,
426
  inputs=[arch_filter],
427
+ outputs=[models_state, summary_text, queue_status, main_table, hard_table, eval_info_md],
428
  )
429
 
430
  arch_filter.change(
431
  on_refresh,
432
  inputs=[arch_filter],
433
+ outputs=[models_state, summary_text, queue_status, main_table, hard_table, eval_info_md],
434
  )
435
 
436
  detail_btn.click(on_detail, inputs=[models_state, model_selector], outputs=[detail_table])