amirali1985 committed on
Commit
af89cf1
·
1 Parent(s): a746c39

sync dashboard to scratchpad; add figure

Files changed (1)
  1. app.py +89 -408
app.py CHANGED
@@ -21,7 +21,18 @@ MODEL_REPO = "thoughtworks/arithmetic-sorl"
21
  # LaTeX scratchpad content - edit here, copy from dashboard into Overleaf
22
  # ═══════════════════════════════════════════════════════════════════
23
 
24
- LATEX_ARITHMETIC_SETUP = r"""% ── Arithmetic case study: task setup + Quirke subtask definitions ──────────
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  \subsection{Case study: six-digit addition and subtraction}
27
  \label{sec:arithmetic}
@@ -47,13 +58,14 @@ each of the 23 active tokens concentrates on a narrow slice of the Quirke taxono
47
  This structure emerges from the info-gain loss alone, with no access to ground-truth subtask labels or carry state.
48
  Tokens are also causally necessary: knocking out all tokens collapses accuracy from 95.5\% to 0.1\%, confirming they carry the computation rather than merely annotating it.
49
 
50
- \textbf{Named tokens enable targeted intervention and better performance.}
51
  Because the routing codes are discrete and named, surgical model edits are possible that have no analog in standard transformers:
52
  swapping a single token at one answer position fixes wrong predictions
53
  at a 27--31\% rate on carry-heavy examples (cross-operation transplant:
54
  93.5\% vs.\ 75.5\% random baseline).
55
  Interpretability here is not merely post-hoc: it translates directly into the ability to correct the model.
56
- Correspondingly, \sorl{} outperforms \sft{} on 12 of 13 tested
 
57
  (architecture, data-size) configurations, and on \emph{all 13} on
58
  the hardest 6-deep carry cascades, with gains as large as $+50$\,pp
59
  (Table~\ref{tab:undersized-wins}).
@@ -62,10 +74,10 @@ The margin grows with cascade depth, consistent with explicit carry/borrow routi
62
  \begin{tcolorbox}[colback=gray!6, colframe=gray!40,
63
  fonttitle=\bfseries\small, title={Finding \#1: \sorl{} increases accuracy on 6-digit arithmetic dramatically},
64
  left=5pt, right=5pt, top=4pt, bottom=4pt]
65
- \small \sorl{} increases accuracy on 6-digit arithmetic dramatically, winning by most on hard cascade splits.
66
  \end{tcolorbox}
67
 
68
- See Appendix \ref{app:arithmetic} further details on SORL interpretability, including a demonstration of auto-interp ~\citep{bills2023language_models_explain_neurons}, token specializations and polysemantic tokens.
69
  """
70
 
71
  LATEX_FIGURE_EXAMPLE = r"""% fig_arithmetic_example.tex
@@ -254,31 +266,7 @@ LATEX_TABLE_UNDERSIZED = r"""% tab:undersized-wins - SoRL vs SFT on undersized
254
  \end{table}
255
  """
256
 
257
- LATEX_APPENDIX = r"""% ═══════════════════════════════════════════════════════════════════════════
258
- % APPENDIX β€” Arithmetic case study (Findings #2–5 + training details)
259
- % ═══════════════════════════════════════════════════════════════════════════
260
- % Paste this entire block into your appendix .tex file.
261
- %
262
- % Required packages:
263
- % \usepackage{tcolorbox}
264
- % \usepackage{booktabs}
265
- % \usepackage{xcolor}
266
- % \usepackage{multirow}
267
- % \usepackage{enumitem} % for [nosep]
268
- %
269
- % Required macros (put in preamble if not already defined):
270
- % \providecommand{\sorl}{\textsc{DLR}}
271
- % \providecommand{\sft}{\textsc{SFT}}
272
- % ═══════════════════════════════════════════════════════════════════════════
273
-
274
- % ── Finding box macro (put in preamble) ──────────────────────────────────────
275
- % \usepackage{tcolorbox}
276
- % \newcommand{\finding}[2]{%
277
- % \begin{tcolorbox}[colback=gray!6,colframe=gray!40,
278
- % fonttitle=\bfseries\small,title={Finding \##1},
279
- % left=5pt,right=5pt,top=4pt,bottom=4pt]\small#2\end{tcolorbox}}
280
-
281
- \section{Arithmetic case study: interpretability analysis}
282
  \label{app:arithmetic}
283
 
284
  \begin{table}[h]
@@ -310,13 +298,12 @@ LATEX_APPENDIX = r"""% ═══════════════════
310
  \label{tab:quirke-subtasks}
311
  \end{table}
312
 
313
- \textbf{Setup.}
314
  All interpretability analyses use model
315
  \texttt{add\_sub\_sorl\_v1\_abs30\_K1\_100K\_2L1H128d}
316
  (\texttt{2L/1H/128d}, 2 layers, 1 head, hidden size 128; trained on 100K examples),
317
  evaluated on 2{,}600 held-out problems across 26 splits.
318
- This model achieves 95.5\% accuracy with \sorl{} abstraction tokens
319
- and 0.1\% without, making it the clearest test-bed for causal analysis.
320
  Experimental code: \texttt{arithmetic/experiments/} (all results reproducible).
321
 
322
  % ─────────────────────────────────────────────────────────────────────────────
@@ -370,16 +357,13 @@ Table~\ref{tab:ablation-splits} shows per-split accuracy under each condition.
370
  \label{tab:ablation-splits}
371
  \end{table}
372
 
373
- \textbf{Commentary.}
374
- Knockout reduces accuracy to $\leq$2\% on every split, confirming that
375
- the model has offloaded computation into the routing tokens.
376
  Three patterns are notable:
377
 
378
  \begin{itemize}[nosep]
379
  \item \textbf{Shuffle $\approx$ Random on easy splits.}
380
- On addition S0 (no carry), shuffle yields 24\% vs.\ random's 28\%;
381
- the gap is small and reflects that no single position is critical:
382
- any token in any position is roughly as bad as another.
383
  \item \textbf{Shuffle $>$ Random on cascade splits (C3--C6).}
384
  On 4--6-deep carry cascades, shuffle (23--28\%) consistently
385
  outperforms random (13--19\%).
@@ -390,7 +374,7 @@ Three patterns are notable:
390
  cascade resolution impossible.
391
  \item \textbf{Borrow cascades are uniquely sensitive.}
392
  Sub-M4 (4-deep borrow) drops from 85\% baseline to 6\% under
393
- shuffle and 0\% under random, a 79\,pp collapse from shuffle
394
  alone. Sub-M5 (5-deep borrow) is hardest even at baseline (57\%),
395
  and all ablations reduce it to $\leq$3\%, showing that deep
396
  borrow cascades are the single hardest regime and that \sorl{}
@@ -398,32 +382,25 @@ Three patterns are notable:
398
  \end{itemize}
399
 
400
  \begin{tcolorbox}[colback=gray!6, colframe=gray!40,
401
- fonttitle=\bfseries\small, title={Finding \#2: Abstraction tokens are causally necessary},
402
  left=5pt, right=5pt, top=4pt, bottom=4pt]
403
- \small \sorl{} abstraction tokens are causally necessary for correct computation.
 
 
404
  \end{tcolorbox}
405
 
406
  % ─────────────────────────────────────────────────────────────────────────────
407
- \subsection{Token--subtask heatmap}
408
  \label{app:heatmap}
409
 
410
- \begin{figure}[h]
411
- \centering
412
- \includegraphics[width=0.9\linewidth]{experiments/03_token_subtask_heatmap/fig_token_subtask.png}
413
- \caption{Token--subtask heatmap for \texttt{add\_sub\_sorl\_v1\_abs30\_K1\_100K\_2L1H128d}.
414
- Each cell shows $P(\text{subtask} \mid \text{token})$ over the held-out evaluation set.
415
- Rows are active tokens (23 of 30 appear in eval); columns are the 10 Quirke subtask labels.
416
- Tokens are sorted by dominant subtask.}
417
- \label{fig:heatmap}
418
- \end{figure}
419
 
420
- Of the 30 tokens in the codebook, 23 appear in the held-out evaluation set.
421
- Each active token concentrates on a narrow slice of the subtask space:
422
- the dominant subtask accounts for ${\geq}70\%$ of that token's occurrences
423
- in the majority of cases.
424
- Tokens are also \emph{position-locked}: each token appears predominantly
425
- at one or two answer positions ($d_0$--$d_6$), rarely crossing position
426
- boundaries.
427
  Representative examples are shown in Table~\ref{tab:token-profiles}:
428
  token \texttt{t21} fires 93\% of the time on US (sum-9 cascade, addition)
429
  with digit sum $\equiv 9 \pmod{10}$ in 95\% of cases;
@@ -449,40 +426,34 @@ token \texttt{t23} is the subtraction mirror (UD, 88\%, position $d_3$).
449
  \end{table}
450
 
451
  \begin{tcolorbox}[colback=gray!6, colframe=gray!40,
452
- fonttitle=\bfseries\small, title={Finding \#3: Tokens spontaneously specialize by subtask and position},
453
  left=5pt, right=5pt, top=4pt, bottom=4pt]
454
- \small \sorl{} tokens spontaneously specialize by subtask and position without supervision.
 
 
455
  \end{tcolorbox}
456
 
457
  % ─────────────────────────────────────────────────────────────────────────────
458
  \subsection{Guided computation via token intervention}
459
  \label{app:guided}
460
 
461
- If \sorl{} tokens encode the \emph{computational route} rather than just
462
- the answer, swapping a token at a mispredicted position should
463
- \emph{fix} the prediction, without retraining, without accessing
464
- internal activations, and without modifying the weights.
465
 
466
  We test this with \emph{surgical swap}: for each wrong prediction,
467
  we try replacing the abstraction token at each answer position with
468
- every other token in the codebook (29 candidates $\times$ 5 positions = 145
469
- interventions per example) and measure how many wrong predictions become
470
- correct, and how many previously-correct predictions break.
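The intervention count above follows from enumerating every single-token replacement. A minimal sketch, assuming tokens are integer ids 0..29 and that each example carries one abstraction token at each of five answer positions; `candidate_swaps` and the example assignment are hypothetical illustrations, not code from the repository:

```python
from typing import Iterator, Sequence, Tuple

CODEBOOK_SIZE = 30  # codebook size stated in the text


def candidate_swaps(assigned: Sequence[int],
                    codebook_size: int = CODEBOOK_SIZE) -> Iterator[Tuple[int, int]]:
    """Yield (position, replacement_token) for every single-token swap.

    `assigned` holds the token id the model chose at each answer position.
    Each position is tried with every *other* token in the codebook.
    """
    for pos, current in enumerate(assigned):
        for tok in range(codebook_size):
            if tok != current:
                yield pos, tok


# Illustrative (made-up) token assignment for one example with 5 positions.
example_tokens = [16, 2, 21, 5, 1]
swaps = list(candidate_swaps(example_tokens))
print(len(swaps))  # 29 candidates x 5 positions = 145
```

Each yielded pair would then be applied to the input and the model re-decoded to check whether the prediction flips; that step depends on the model interface and is left out here.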
471
-
472
- \textbf{Results.}
473
- At positions $d_0$--$d_2$ (the carry-heavy positions), a fixing swap exists
474
- for 27--31\% of mispredicted examples.
475
- The best single swap is replacing \texttt{t16} with \texttt{t25} at $d_1$:
476
- this fixes 10 wrong predictions while breaking only 5 correct ones
477
- (a 2:1 fix-to-break ratio), all on carry cascade splits (C4--C6).
478
- Position $d_3$ and $d_4$ are harder to fix surgically (fix rates 8\% and 2\%),
479
- consistent with those positions encoding longer-range carry state that a
480
- single-position swap cannot resolve.
481
 
482
  \begin{tcolorbox}[colback=gray!6, colframe=gray!40,
483
- fonttitle=\bfseries\small, title={Finding \#4: Single token swaps fix mispredicted carry-heavy examples},
484
  left=5pt, right=5pt, top=4pt, bottom=4pt]
485
- \small Single token swaps fix mispredicted carry-heavy examples, enabling targeted model correction.
 
 
486
  \end{tcolorbox}
487
 
488
  % ─────────────────────────────────────────────────────────────────────────────
@@ -490,15 +461,13 @@ single-position swap cannot resolve.
490
  \label{app:quirke-analogy}
491
 
492
  \citet{quirke_2024_addsub_preprint} identify, via PCA of internal
493
- residual-stream activations, a \emph{tri-state carry classifier} at each
494
- digit position $n$: the hidden state encodes which of three carry regimes
495
  applies:
496
  $\text{ST}_n = 0$ (digit sum $< 9$; carry cannot propagate),
497
  $\text{ST}_n = 1$ (digit sum $> 9$; carry will propagate), or
498
  $\text{ST}_n = U$ (digit sum $= 9$; carry state \emph{uncertain}, depends
499
  on lower positions).
500
- This trichotomy is the core circuit for addition; borrow cascades have
501
- an analogous structure for subtraction.
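The three regimes can be made concrete with a small sketch. It assumes digits are processed least-significant first and uses hypothetical helper names (`tri_state`, `carry_states`, `resolve_carries`); it only illustrates the definition above and is not Quirke et al.'s analysis code or the repository's:

```python
from typing import List


def tri_state(a_digit: int, b_digit: int) -> str:
    """Classify one digit column: '0' (sum < 9), '1' (sum > 9), 'U' (sum == 9)."""
    s = a_digit + b_digit
    if s < 9:
        return "0"
    if s > 9:
        return "1"
    return "U"


def carry_states(a: int, b: int, n_digits: int = 6) -> List[str]:
    """Per-column tri-state labels, least-significant column first."""
    states = []
    for _ in range(n_digits):
        states.append(tri_state(a % 10, b % 10))
        a //= 10
        b //= 10
    return states


def resolve_carries(states: List[str]) -> List[int]:
    """Carry out of each column; a 'U' column passes on its incoming carry."""
    carries = []
    carry_in = 0
    for st in states:
        if st == "1":
            carry_out = 1
        elif st == "0":
            carry_out = 0
        else:  # 'U': digit sum is 9, so the carry is decided by lower columns
            carry_out = carry_in
        carries.append(carry_out)
        carry_in = carry_out
    return carries


states = carry_states(959271, 40756)
carries = resolve_carries(states)
print(states, carries)
```

For 959271 + 040756 the states are 0, 1, U, U, U, U: one definite carry at the tens column propagates through four uncertain columns, which is why the leading digits of the sum cannot be determined without looking at lower positions.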
502
 
503
  \sorl{} recovers the same trichotomy \emph{without access to activation
504
  data or ground-truth circuit labels}, purely from the info-gain training
@@ -518,15 +487,21 @@ The correspondence is visible directly in the codebook:
518
  tokens concentrated on SA/MD subtasks.
519
  \end{itemize}
520
 
521
- Crucially, Quirke et al.\ required full activation-level mechanistic
522
- interpretability to discover this classifier; \sorl{} surfaces it as a
523
- readable token in the output sequence, accessible without any
524
- post-hoc analysis.
525
 
526
  \begin{tcolorbox}[colback=gray!6, colframe=gray!40,
527
- fonttitle=\bfseries\small, title={Finding \#5: \sorl{} rediscovers known arithmetic circuits without supervision},
528
  left=5pt, right=5pt, top=4pt, bottom=4pt]
529
- \small \sorl{} rediscovers known arithmetic circuits without supervision or activation access.
 
 
 
 
 
 
 
 
 
530
  \end{tcolorbox}
531
 
532
  % ─────────────────────────────────────────────────────────────────────────────
@@ -535,8 +510,7 @@ post-hoc analysis.
535
 
536
  Not all codebook tokens are specialists. Table~\ref{tab:token-polysemanticity}
537
  contrasts the most specialist with the most polysemantic tokens.
538
- Token \texttt{t21} fires 94\% of the time on a single subtask (US) at a
539
- single position ($d_3$, addition only), a true specialist.
540
  Token \texttt{t1}, by contrast, is the highest-frequency token
541
  ($n{=}6{,}359$, spanning all five answer positions) with no subtask
542
  exceeding 24\%; it acts as a general-purpose fallback, handling
@@ -561,66 +535,48 @@ whichever position and carry regime was not captured by a specialist token.
561
  \end{tabular}
562
  \caption{Specialist vs.\ polysemantic tokens.
563
  \textbf{Purity} = fraction of occurrences where the top subtask applies.
564
- \textbf{Positions} = number of distinct answer positions ($d_0$--$d_6$)
565
  where the token appears. \texttt{t1} is a high-frequency fallback used
566
  across all positions; \texttt{t21} is a pure sum-9 cascade detector.}
567
  \label{tab:token-polysemanticity}
568
  \end{table}
569
 
570
- Polysemanticity here mirrors the phenomenon described in neural network
571
- interpretability~\citep{elhage2022superposition}: a single token encodes
572
- multiple distinct roles, likely because the 30-token codebook has spare
573
- capacity at the overflow position ($d_0$) where carry state is most
574
- variable. The specialist tokens concentrate at mid-sequence positions
575
- ($d_2$--$d_4$) where carry propagation is most structured.
576
 
577
  \begin{tcolorbox}[colback=gray!6, colframe=gray!40,
578
- fonttitle=\bfseries\small, title={Finding \#6: The codebook mixes specialist and polysemantic tokens},
579
  left=5pt, right=5pt, top=4pt, bottom=4pt]
580
- \small The codebook mixes highly specialist tokens with polysemantic fallback tokens.
 
 
581
  \end{tcolorbox}
582
 
583
  % ─────────────────────────────────────────────────────────────────────────────
584
  \subsection{Automated token interpretation}
585
  \label{app:autointerp}
586
 
587
- We implement a light version of the automated interpretation procedure
588
- of \citet{bills2023language}.
589
- For each active token, we collect the $N{=}10$ examples from the evaluation
590
- set where the model assigned it with highest softmax confidence,
591
  then ask \texttt{claude-haiku} to produce a one-sentence role description.
592
- The full procedure is in \texttt{experiments/11\_auto\_interp/run.py}.
593
 
594
- \begin{table}[h]
595
- \centering\small
596
- \begin{tabular}{clrp{5.5cm}}
597
- \toprule
598
- Token & Top subtask & Conf. & Auto-interpretation \\
599
- \midrule
600
- \texttt{t0} & UC (47\%) & 1.00 & Token t0 marks the tens digit position in addition problems, regardless of carry state or sum value. \\
601
- \texttt{t2} & UC (70\%) & 0.99 & Token t2 outputs the ones digit (0) when adding two numbers whose ones digits sum to 10 or more. \\
602
- \texttt{t1} & UC (30\%) & 0.99 & This token routes to the fourth digit position during addition when a carry from the previous position must be incorporated. \\
603
- \texttt{t3} & UC (44\%) & 0.94 & Token t3 routes to the hundreds position (d3) when processing carries from the tens column in addition. \\
604
- \texttt{t5} & MD (65\%) & 0.93 & Token t5 routes cases where the ones digit result is 0, spanning multiple subtasks and operations. \\
605
- \texttt{t8} & MD (26\%) & 0.91 & Token t8 activates when processing the tens digit (d2) across addition/subtraction with various carry/borrow states. \\
606
- \texttt{t10} & UB (41\%) & 0.88 & Token t10 routes subtraction problems requiring borrow propagation at mid-to-late digit positions. \\
607
- \texttt{t6} & UC (27\%) & 0.88 & Token t6 routes cases where the ones digit result is 0, regardless of operation or carry state. \\
608
- \bottomrule
609
- \end{tabular}
610
- \caption{Auto-interpretation of \sorl{} abstraction tokens (\`{a} la \citealt{bills2023language}).
611
- For each token, the 10 examples where the model assigned it
612
- with highest softmax confidence are shown to \texttt{claude-haiku},
613
- which produces a one-sentence role description.
614
- \textbf{Conf.}\ = mean softmax probability of the assigned token.}
615
- \label{tab:auto-interp}
616
- \end{table}
617
 
618
  \begin{tcolorbox}[colback=gray!6, colframe=gray!40,
619
- fonttitle=\bfseries\small, title={Finding \#7: Automated interpretation produces human-readable token roles},
620
  left=5pt, right=5pt, top=4pt, bottom=4pt]
621
- \small Automated interpretation produces human-readable role descriptions that match ground-truth subtask labels.
 
 
 
 
 
622
  \end{tcolorbox}
623
-
624
  % ─────────────────────────────────────────────────────────────────────────────
625
  \subsection{Training and evaluation details}
626
  \label{app:training}
@@ -662,284 +618,9 @@ The full procedure is in \texttt{experiments/11\_auto\_interp/run.py}.
662
  \end{table}
663
 
664
  Evaluation uses fixed-length autoregressive decoding (no teacher forcing):
665
- the model generates answer digits $d_0 \to d_6$ using its own predictions,
666
- with abstraction tokens inserted via the \sorl{} search-then-recurse
667
- procedure (matching training). Accuracy is measured on 100 held-out
668
- examples per split (seed 42; \texttt{thoughtworks/arithmetic-sorl-data}).
669
  """
670
 
671
- HARD_SPLITS = ["add_C4", "add_C5", "add_C6", "sub_M4", "sub_M5"]
672
- ALL_SPLITS = [
673
- "add_S0", "add_S1", "add_S2", "add_S3", "add_S4", "add_S5", "add_S6", "add_random",
674
- "add_C3", "add_C4", "add_C5", "add_C6",
675
- "sub_M0", "sub_M1", "sub_M2", "sub_M3", "sub_M4", "sub_M5", "sub_random",
676
- "sub_B3", "sub_B4", "sub_B5",
677
- ]
678
-
679
-
680
- def fetch_all_models():
681
- """Pull VALID models from HF using the model catalog for filtering."""
682
- # Load model_catalog.json to filter by status (VALID/SUPERSEDED)
683
- catalog = None
684
- try:
685
- cat_local = hf_hub_download(MODEL_REPO, "model_catalog.json",
686
- local_dir="/tmp/hf_dash_cache",
687
- force_download=True)
688
- catalog = {e["name"]: e for e in json.load(open(cat_local))}
689
- except Exception:
690
- pass
691
-
692
- api = HfApi()
693
- all_files = api.list_repo_files(MODEL_REPO)
694
- config_files = sorted([f for f in all_files if f.endswith("train_config.json")
695
- and not f.startswith("interp_results/")])
696
- metrics_files = set(f for f in all_files if f.endswith("metrics.json"))
697
-
698
- models = []
699
- for cf in config_files:
700
- subfolder = cf.rsplit("/", 1)[0]
701
-
702
- # Skip non-VALID models if catalog is available
703
- if catalog is not None:
704
- cat_entry = catalog.get(subfolder, {})
705
- if cat_entry.get("status", "VALID") != "VALID":
706
- continue
707
-
708
- try:
709
- local = hf_hub_download(MODEL_REPO, cf, local_dir="/tmp/hf_dash_cache")
710
- config = json.load(open(local))
711
- metrics = {}
712
- mf = f"{subfolder}/metrics.json"
713
- if mf in metrics_files:
714
- try:
715
- ml = hf_hub_download(MODEL_REPO, mf, local_dir="/tmp/hf_dash_cache")
716
- metrics = json.load(open(ml))
717
- except Exception:
718
- pass
719
- models.append({
720
- "subfolder": subfolder,
721
- "config": config,
722
- "metrics": metrics,
723
- "enriched": not subfolder.startswith("non_enriched/"),
724
- })
725
- except Exception:
726
- pass
727
- return models
728
-
729
-
730
- def fmt_pct(v):
731
-     if v is None:
732
-         return "—"
733
-     return f"{v:.0%}"
734
-
735
-
736
- def bold_winner(base_val, sorl_val):
737
-     if base_val is None and sorl_val is None:
738
-         return "—", "—"
739
-     if base_val is None:
740
-         return "—", f"**{sorl_val:.0%}**"
741
-     if sorl_val is None:
742
-         return f"**{base_val:.0%}**", "—"
743
-     b = f"{base_val:.0%}"
744
-     s = f"{sorl_val:.0%}"
745
-     if sorl_val > base_val + 0.005:
746
-         return b, f"**{s}**"
747
-     elif base_val > sorl_val + 0.005:
748
-         return f"**{b}**", s
749
-     return b, s
750
-
751
-
752
- def get_split_acc(metrics, eval_key, split):
753
-     ev = metrics.get(eval_key, {})
754
-     s = ev.get("splits", {}).get(split, {})
755
-     return s.get("full_accuracy") if s else None
756
-
757
-
758
- def build_comparison_table(models, arch_filter="All", enriched_only=True):
759
- filtered = [m for m in models if (not enriched_only or m["enriched"])]
760
- if arch_filter != "All":
761
- filtered = [m for m in filtered if
762
- f"{m['config'].get('n_layer')}L/{m['config'].get('n_head')}H/{m['config'].get('n_embd')}d" == arch_filter]
763
-
764
- baselines = {}
765
- sorls = {}
766
- for m in filtered:
767
- cfg = m["config"]
768
- ops = cfg.get("ops", "?")
769
- ds = cfg.get("dataset_size", 0)
770
- arch = f"{cfg.get('n_layer')}L/{cfg.get('n_head')}H/{cfg.get('n_embd')}d"
771
- key = (ops, ds, arch)
772
-
773
- if cfg.get("mode") == "baseline":
774
- baselines[key] = m
775
- elif cfg.get("mode") == "sorl":
776
- K = cfg.get("K", 0)
777
- vocab = cfg.get("abs_vocab", 0)
778
- sorl_key = (ops, ds, arch, K, vocab)
779
- sorls[sorl_key] = m
780
-
781
- rows = []
782
- all_keys = set()
783
- for k in baselines:
784
- all_keys.add(k)
785
- for ops, ds, arch, K, vocab in sorls:
786
- all_keys.add((ops, ds, arch))
787
-
788
- for ops, ds, arch in sorted(all_keys):
789
- base = baselines.get((ops, ds, arch))
790
- base_acc = base["config"].get("final_accuracy") if base else None
791
- matching_sorl = {(K, v): m for (o, d, a, K, v), m in sorls.items()
792
- if o == ops and d == ds and a == arch}
793
-
794
- if not matching_sorl and base is None:
795
- continue
796
-
797
- ds_label = f"{ds // 1000}K"
798
- raw_base_wandb = base["config"].get("wandb_url", "") if base else ""
799
- base_wandb = f"[wandb]({raw_base_wandb})" if raw_base_wandb else ""
800
-
801
- HF_BASE = "https://huggingface.co/thoughtworks/arithmetic-sorl/tree/main"
802
- base_hf = f"[model]({HF_BASE}/{base['subfolder']})" if base else ""
803
-
804
- if not matching_sorl:
805
- base_hard = {s: get_split_acc(base["metrics"], "sft_eval", s) for s in HARD_SPLITS} if base else {}
806
- row = {
807
- "Ops": ops, "Data": ds_label, "Arch": arch,
808
- "Baseline": fmt_pct(base_acc), "SoRL": "pending", "Config": "pending",
809
- "B_wandb": base_wandb, "S_wandb": "pending",
810
- "B_hf": base_hf, "S_hf": "pending",
811
- }
812
- for s in HARD_SPLITS:
813
- row[f"B_{s}"] = fmt_pct(base_hard.get(s))
814
- row[f"S_{s}"] = "pending"
815
- rows.append(row)
816
- else:
817
- for (K, vocab), sorl_m in sorted(matching_sorl.items()):
818
- sorl_cfg = sorl_m["config"]
819
- sorl_acc = sorl_cfg.get("final_accuracy")
820
- raw_sorl_wandb = sorl_cfg.get("wandb_url", "")
821
- sorl_wandb = f"[wandb]({raw_sorl_wandb})" if raw_sorl_wandb else ""
822
- eval_key = "sorl_eval"
823
-
824
- sorl_hf = f"[model]({HF_BASE}/{sorl_m['subfolder']})"
825
-
826
- b_str, s_str = bold_winner(base_acc, sorl_acc)
827
- row = {
828
- "Ops": ops, "Data": ds_label, "Arch": arch,
829
- "Baseline": b_str, "SoRL": s_str,
830
- "Config": f"K={K} v={vocab}",
831
- "B_wandb": base_wandb, "S_wandb": sorl_wandb,
832
- "B_hf": base_hf, "S_hf": sorl_hf,
833
- }
834
- for s in HARD_SPLITS:
835
- bv = get_split_acc(base["metrics"], "sft_eval", s) if base else None
836
- sv = get_split_acc(sorl_m["metrics"], eval_key, s)
837
- b_s, s_s = bold_winner(bv, sv)
838
- row[f"B_{s}"] = b_s
839
- row[f"S_{s}"] = s_s
840
- rows.append(row)
841
-
842
- return pd.DataFrame(rows)
843
-
844
-
845
- def build_detailed_splits(models, model_name):
846
- for m in models:
847
- name = m["subfolder"].removeprefix("non_enriched/")
848
- if model_name in name or name in model_name or name == model_name:
849
- cfg = m["config"]
850
- eval_key = "sorl_eval" if cfg.get("mode") == "sorl" else "sft_eval"
851
- ev = m["metrics"].get(eval_key, {})
852
- splits = ev.get("splits", {})
853
- rows = []
854
- for s in ALL_SPLITS:
855
- if s in splits:
856
- rows.append({
857
- "Split": s,
858
- "Accuracy": fmt_pct(splits[s].get("full_accuracy")),
859
- "N": splits[s].get("n_examples", 0),
860
- })
861
- return pd.DataFrame(rows) if rows else pd.DataFrame({"Split": ["No data"], "Accuracy": ["—"], "N": [0]})
862
- return pd.DataFrame({"Split": ["Model not found"], "Accuracy": ["—"], "N": [0]})
863
-
864
-
865
- def get_queue_status_text(n_models):
866
- try:
867
- path = hf_hub_download(MODEL_REPO, "queue_status.json", local_dir="/tmp/hf_dash_cache")
868
- with open(path) as f:
869
- qs = json.load(f)
870
-
871
- total = qs.get("total", n_models)
872
- done = qs.get("done", n_models)
873
- # Exclude killed jobs from failed count (exit_code -9 or 0s runtime)
874
- all_jobs = qs.get("jobs", [])
875
- real_failed = [j for j in all_jobs if j.get("status") == "failed"
876
- and j.get("exit_code") not in (-9, None) and j.get("elapsed", 999) > 10]
877
- failed = len(real_failed)
878
- running = qs.get("running", 0)
879
- pending = qs.get("pending", 0)
880
-
881
- # Adjust total to exclude killed jobs
882
- killed = len([j for j in all_jobs if j.get("status") == "failed"
883
- and (j.get("exit_code") == -9 or j.get("elapsed", 999) <= 10)])
884
- effective_total = total - killed
885
-
886
- pct = done / effective_total * 100 if effective_total else 0
887
- bar_len = 30
888
- filled = int(bar_len * done / effective_total) if effective_total else 0
889
- bar = "█" * filled + "░" * (bar_len - filled)
890
- status = "COMPLETE" if done >= effective_total else "RUNNING"
891
-
892
- lines = [
893
- f"### Queue: {done}/{effective_total} done ({pct:.0f}%) — {status}",
894
- f"`{bar}`",
895
- ]
896
- if running or pending or failed:
897
- parts = []
898
- if running:
899
- parts.append(f"🟢 Running: {running}")
900
- if pending:
901
- parts.append(f"⏳ Pending: {pending}")
902
- if failed:
903
- parts.append(f"❌ Failed: {failed}")
904
- lines.append(" | ".join(parts))
905
-
906
- return "\n".join(lines)
907
- except Exception:
908
- return f"**{n_models}** models on HF"
909
-
910
-
911
- def build_eval_info(models):
912
- n_per_split = "?"
913
- n_digits = 6
914
- splits = []
915
- total = "?"
916
- for m in models:
917
- metrics = m.get("metrics", {})
918
- for key in ("sft_eval", "sorl_eval"):
919
- cfg = metrics.get(key, {}).get("config", {})
920
- if cfg.get("n_per_split"):
921
- n_per_split = cfg["n_per_split"]
922
- n_digits = cfg.get("n_digits", 6)
923
- total = metrics[key].get("summary", {}).get("total_examples", "?")
924
- splits = list(metrics[key].get("splits", {}).keys())
925
- break
926
- if splits:
927
- break
928
-
929
- return f"""**Replication of [Quirke et al. 2024](https://arxiv.org/abs/2402.02619)** — \
930
- understanding addition and subtraction in transformers.
931
-
932
- We train tiny Qwen3 models (2L/3H/510d, ~8M transformer params) from scratch on \
933
- {n_digits}-digit arithmetic. SoRL v1 (info-gain loss) adds learnable "abstraction tokens" \
934
- every K positions.
935
-
936
- **Eval**: autoregressive (errors propagate, no teacher forcing). Fixed eval sets (seed=42, {n_per_split}/split, {total} total).
937
-
938
- [Paper](https://arxiv.org/abs/2402.02619) · \
939
- [Models](https://huggingface.co/thoughtworks/arithmetic-sorl) · \
940
- [Data](https://huggingface.co/datasets/thoughtworks/arithmetic-sorl-data) · \
941
- [Code](https://github.com/fangyuan-ksgk/mod_gpt/tree/amir/arithmetic)"""
942
-
943
 
944
  # ═══════════════════════════════════════════════════════════════════
945
  # App
 
21
  # LaTeX scratchpad content - edit here, copy from dashboard into Overleaf
22
  # ═══════════════════════════════════════════════════════════════════
23
 
24
+ LATEX_ARITHMETIC_SETUP = r"""% ── Arithmetic case study ──────────────────────────────────────────────────
25
+
26
+ \begin{figure}[t]
27
+ \centering
28
+ \includegraphics[width=\linewidth]{figures/fig_arithmetic_example.pdf}
29
+ \caption{Addition $959{,}271 + 040{,}756 = 1{,}000{,}027$: a four-deep carry cascade.
30
+ Each answer-digit position is annotated with its Quirke subtask label (coloured box)
31
+ and the \sorl{} abstraction token assigned by the model (dashed purple box).
32
+ The carry chain ($d_1$--$d_5$) is highlighted; \sorl{} uses a consistent token
33
+ (\texttt{t2}) at cascade positions and distinct tokens elsewhere.}
34
+ \label{fig:arithmetic-example}
35
+ \end{figure}
36
 
37
  \subsection{Case study: six-digit addition and subtraction}
38
  \label{sec:arithmetic}
 
58
  This structure emerges from the info-gain loss alone, with no access to ground-truth subtask labels or carry state.
59
  Tokens are also causally necessary: knocking out all tokens collapses accuracy from 95.5\% to 0.1\%, confirming they carry the computation rather than merely annotating it.
60
 
61
+ \textbf{Named tokens enable better fine-tuning performance and targeted interventions.}
62
  Because the routing codes are discrete and named, surgical model edits are possible that have no analog in standard transformers:
63
  swapping a single token at one answer position fixes wrong predictions
64
  at a 27--31\% rate on carry-heavy examples (cross-operation transplant:
65
  93.5\% vs.\ 75.5\% random baseline).
66
  Interpretability here is not merely post-hoc: it translates directly into the ability to correct the model.
67
+
68
+ Correspondingly, fine-tuned \sorl{} strongly outperforms \sft{} on 12 of 13 tested
69
  (architecture, data-size) configurations, and on \emph{all 13} on
70
  the hardest 6-deep carry cascades, with gains as large as $+50$\,pp
71
  (Table~\ref{tab:undersized-wins}).
 
74
  \begin{tcolorbox}[colback=gray!6, colframe=gray!40,
75
  fonttitle=\bfseries\small, title={Finding \#1: \sorl{} increases accuracy on 6-digit arithmetic dramatically},
76
  left=5pt, right=5pt, top=4pt, bottom=4pt]
77
+ \small \sorl{} increases accuracy on 6-digit arithmetic dramatically, with the largest gains on hard cascade and borrow splits.
78
  \end{tcolorbox}
79
 
80
+ See Appendix~\ref{app:arithmetic} for further details on SORL interpretability, including a demonstration of auto-interp~\citep{bills2023language_models_explain_neurons}, token specializations, and polysemantic tokens.
81
  """
82
 
83
  LATEX_FIGURE_EXAMPLE = r"""% fig_arithmetic_example.tex
 
266
  \end{table}
267
  """
268
 
269
+ LATEX_APPENDIX = r"""\section{Arithmetic case study: interpretability analysis}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  \label{app:arithmetic}
271
 
272
  \begin{table}[h]
 
298
  \label{tab:quirke-subtasks}
299
  \end{table}
300
 
301
+ \paragraph{Setup.}
302
  All interpretability analyses use model
303
  \texttt{add\_sub\_sorl\_v1\_abs30\_K1\_100K\_2L1H128d}
304
  (\texttt{2L/1H/128d}, 2 layers, 1 head, hidden size 128; trained on 100K examples),
305
  evaluated on 2{,}600 held-out problems across 26 splits.
306
+ This model achieves 95.5\% accuracy with \sorl{} abstraction tokens and 0.1\% without, making it the clearest test-bed for causal analysis.
 
307
  Experimental code: \texttt{arithmetic/experiments/} (all results reproducible).
308
 
309
  % ─────────────────────────────────────────────────────────────────────────────
 
357
  \label{tab:ablation-splits}
358
  \end{table}
359
 
360
+ \paragraph{Commentary.}
361
+ Knockout reduces accuracy to $\leq$2\% on every split, confirming that the model has offloaded computation into the routing tokens.
 
362
  Three patterns are notable:
363
 
364
  \begin{itemize}[nosep]
365
  \item \textbf{Shuffle $\approx$ Random on easy splits.}
366
+ On addition S0 (no carry), shuffle yields 24\% vs.\ random's 28\%; the gap is small and reflects that no single position is critical: any token in any position is roughly as bad as another.
 
 
367
  \item \textbf{Shuffle $>$ Random on cascade splits (C3--C6).}
368
  On 4--6-deep carry cascades, shuffle (23--28\%) consistently
369
  outperforms random (13--19\%).
 
374
  cascade resolution impossible.
375
  \item \textbf{Borrow cascades are uniquely sensitive.}
376
  Sub-M4 (4-deep borrow) drops from 85\% baseline to 6\% under
377
+ shuffle and 0\% under random, a 79\,pp collapse from shuffle
378
  alone. Sub-M5 (5-deep borrow) is hardest even at baseline (57\%),
379
  and all ablations reduce it to $\leq$3\%, showing that deep
380
  borrow cascades are the single hardest regime and that \sorl{}
 
382
  \end{itemize}
383
 
384
  \begin{tcolorbox}[colback=gray!6, colframe=gray!40,
385
+ fonttitle=\bfseries\small, title={Finding \#2},
386
  left=5pt, right=5pt, top=4pt, bottom=4pt]
387
+ \small
388
+ \sorl{} abstraction tokens are causally necessary: knockout collapses accuracy from 95.5\% to 0.1\% overall, and to $\leq$3\% on the hardest borrow-cascade splits (M4--M5).
389
+ Shuffle (identity-preserving, position-destroying) is less harmful than random on cascade splits: wrong-position tokens from the same structural family cause systematic carry errors, while random tokens cause broader incoherence.
390
  \end{tcolorbox}
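
The three ablations compared above can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the code in \texttt{arithmetic/experiments/}: abstraction tokens are assumed to be integer codebook indices, knockout is modelled by replacing each token with \texttt{None}, and all function names are hypothetical.

```python
import random


def knockout(tokens):
    """Drop every abstraction token (modelled here as None placeholders)."""
    return [None] * len(tokens)


def shuffle_positions(tokens, rng):
    """Keep the same tokens (identity-preserving) but permute their positions."""
    permuted = list(tokens)
    rng.shuffle(permuted)
    return permuted


def random_tokens(tokens, codebook_size, rng):
    """Replace each token with one drawn uniformly from the codebook."""
    return [rng.randrange(codebook_size) for _ in tokens]


rng = random.Random(0)
original = [21, 6, 1, 16, 25]  # made-up token ids, one per answer position
ablated = {
    "knockout": knockout(original),
    "shuffle": shuffle_positions(original, rng),
    "random": random_tokens(original, 30, rng),  # 30-token codebook
}
```

Shuffle keeps the multiset of tokens intact and only destroys their alignment with answer positions, whereas random replacement destroys both identity and position; that distinction is what the shuffle-versus-random comparison relies on.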
391
 
392
  % ─────────────────────────────────────────────────────────────────────────────
393
+ \subsection{Token-subtask heatmap}
394
  \label{app:heatmap}
395
 
396
+ % [PLACEHOLDER: insert token-subtask heatmap figure here]
397
+ % Figure: P(subtask | token) heatmap for all active tokens $\times$ 10 subtask labels.
398
+ % Generate with: python experiments/03_token_subtask_heatmap/run.py \
399
+ % -model add_sub_sorl_v1_abs30_K1_100K_2L1H128d
 
 
 
 
 
400
 
401
+ Of the 30 tokens in the codebook, 18 appear in the held-out evaluation set.
402
+ Each active token concentrates on a narrow slice of the subtask space: the dominant subtask accounts for ${\geq}70\%$ of that token's occurrences in the majority of cases.
403
+ Tokens are also \emph{position-locked}: each token appears predominantly at one or two answer positions ($d_0$--$d_6$), rarely crossing position boundaries.
 
 
 
 
404
  Representative examples are shown in Table~\ref{tab:token-profiles}:
405
  token \texttt{t21} fires 93\% of the time on US (sum-9 cascade, addition)
406
  with digit sum $\equiv 9 \pmod{10}$ in 95\% of cases;
 
426
  \end{table}
427
 
428
  \begin{tcolorbox}[colback=gray!6, colframe=gray!40,
429
+ fonttitle=\bfseries\small, title={Finding \#3},
430
  left=5pt, right=5pt, top=4pt, bottom=4pt]
431
+ \small
432
+ \sorl{} spontaneously learns position-locked, subtask-specialised routing:
433
+ 18 of 30 codebook tokens are active; each concentrates on 1--2 of the 10 Quirke subtasks (purity ${\geq}70\%$ for most), and each is tied to one or two answer positions. The codebook partitions the arithmetic computation into an interpretable registry of specialist tokens.
434
  \end{tcolorbox}
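
The heatmap entries $P(\text{subtask} \mid \text{token})$ and the purity statistic can be computed from co-occurrence counts. The sketch below assumes the evaluation log can be reduced to (token, subtask) pairs; the helper names and the toy data are illustrative only and do not reproduce \texttt{experiments/03\_token\_subtask\_heatmap/run.py}.

```python
from collections import Counter, defaultdict


def subtask_distribution(pairs):
    """Return P(subtask | token) as nested dicts from (token, subtask) pairs."""
    counts = defaultdict(Counter)
    for token, subtask in pairs:
        counts[token][subtask] += 1
    return {
        token: {s: n / sum(c.values()) for s, n in c.items()}
        for token, c in counts.items()
    }


def top_subtask(dist):
    """Map each token to (dominant subtask, purity)."""
    return {
        token: max(probs.items(), key=lambda kv: kv[1])
        for token, probs in dist.items()
    }


# Toy, made-up observations: t21 is nearly pure, t1 is mixed.
pairs = [("t21", "US")] * 9 + [("t21", "SA")]
pairs += [("t1", "SA"), ("t1", "MD"), ("t1", "US"), ("t1", "SA")]

dist = subtask_distribution(pairs)
top = top_subtask(dist)
print(top)
```

Purity here is simply the largest entry in a token's row of the heatmap, which matches the definition used in the specialist-versus-polysemantic comparison.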
435
 
436
  % ─────────────────────────────────────────────────────────────────────────────
437
  \subsection{Guided computation via token intervention}
438
  \label{app:guided}
439
 
440
+ If \sorl{} tokens encode the \emph{computational route} rather than just the answer, swapping a token at a mispredicted position should \emph{fix} the prediction without retraining, without accessing internal activations, and without modifying the weights.
 
 
 
441
 
442
  We test this with \emph{surgical swap}: for each wrong prediction,
443
  we try replacing the abstraction token at each answer position with
444
+ every other token in the codebook (29 candidates $\times$ 5 positions = 145 interventions per example) and measure how many wrong predictions become correct, and how many previously correct predictions break.
445
+
446
+ \paragraph{Results.}
447
+ At positions $d_0$--$d_2$ (the carry-heavy positions), a fixing swap exists for 27--31\% of mispredicted examples.
448
+ The best single swap is replacing \texttt{t16} with \texttt{t25} at $d_1$: this fixes 10 wrong predictions while breaking only 5 correct ones (a 2:1 fix-to-break ratio), all on carry cascade splits (C4--C6).
449
+ Positions $d_3$ and $d_4$ are harder to fix surgically (fix rates 8\% and 2\%), consistent with those positions encoding longer-range carry state that a single-position swap cannot resolve.
 
 
 
 
 
 
 
450
 
451
  \begin{tcolorbox}[colback=gray!6, colframe=gray!40,
452
+ fonttitle=\bfseries\small, title={Finding \#4},
453
  left=5pt, right=5pt, top=4pt, bottom=4pt]
454
+ \small
455
+ Token interventions enable \emph{guided computation}: replacing a single abstraction token in the sequence fixes a wrong prediction in 27--31\% of mispredicted examples at carry-heavy positions, with no weight updates and no access to internal activations.
456
+ This is only possible because the tokens are a human-readable interface to the model's routing decisions.
457
  \end{tcolorbox}
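
The surgical-swap search is a small exhaustive enumeration. The sketch below shows one way to generate the 145 single-token interventions and keep those that repair a prediction; \texttt{predict} is a hypothetical stand-in for running the model with an edited token sequence, and the toy rule inside it is invented purely so the example runs.

```python
from itertools import product


def candidate_swaps(tokens, codebook_size):
    """Yield (position, new_token, edited_sequence) for every single-token swap."""
    for pos, alt in product(range(len(tokens)), range(codebook_size)):
        if alt != tokens[pos]:
            edited = list(tokens)
            edited[pos] = alt
            yield pos, alt, edited


def fixing_swaps(tokens, codebook_size, predict, target):
    """Return the swaps whose edited sequence makes `predict` return `target`."""
    return [
        (pos, alt)
        for pos, alt, edited in candidate_swaps(tokens, codebook_size)
        if predict(edited) == target
    ]


def predict(tokens):
    # Hypothetical stand-in for the model: "correct" only when position 1 holds 25.
    return "correct" if tokens[1] == 25 else "wrong"


tokens = [3, 16, 1, 1, 8]  # made-up abstraction tokens at 5 answer positions
fixes = fixing_swaps(tokens, 30, predict, "correct")
print(len(list(candidate_swaps(tokens, 30))), fixes)
```

Running the same enumeration over examples that were originally correct, and counting swaps that change the output, would give the break count that the fix-to-break ratio is measured against.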
458
 
459
  % ─────────────────────────────────────────────────────────────────────────────
 
461
  \label{app:quirke-analogy}
462
 
463
  \citet{quirke_2024_addsub_preprint} identify, via PCA of internal
464
+ residual-stream activations, a \emph{tri-state carry classifier} at each digit position $n$: the hidden state encodes which of three carry regimes
 
465
  applies:
466
  $\text{ST}_n = 0$ (digit sum $< 9$; carry cannot propagate),
467
  $\text{ST}_n = 1$ (digit sum $> 9$; carry will propagate), or
468
  $\text{ST}_n = U$ (digit sum $= 9$; carry state \emph{uncertain}, depends
469
  on lower positions).
470
+ This trichotomy is the core circuit for addition; borrow cascades have an analogous structure for subtraction.
 
471
 
472
  \sorl{} recovers the same trichotomy \emph{without access to activation
473
  data or ground-truth circuit labels}, purely from the info-gain training
 
487
  tokens concentrated on SA/MD subtasks.
488
  \end{itemize}
489
 
490
+ Crucially, Quirke et al.\ required full activation-level mechanistic interpretability to discover this classifier; \sorl{} surfaces it as a readable token in the output sequence, accessible without any post-hoc analysis.
 
 
 
491
 
492
  \begin{tcolorbox}[colback=gray!6, colframe=gray!40,
493
+ fonttitle=\bfseries\small, title={Finding \#5},
494
  left=5pt, right=5pt, top=4pt, bottom=4pt]
495
+ \small
496
+ \sorl{} independently rediscovers the carry-state tri-classifier
497
+ ($\text{ST}_n \in \{0, U, 1\}$) identified by \citet{quirke_2024_addsub_preprint}
498
+ via internal circuit analysis, with no access to ground-truth circuit labels.
499
+ The three carry regimes map onto disjoint token clusters (e.g.\
500
+ \texttt{t21}/\texttt{t6} for sum-9 uncertain in addition;
501
+ \texttt{t23}/\texttt{t7} for borrow-uncertain in subtraction),
502
+ and an analogous structure appears for subtraction borrow cascades.
503
+ What Quirke et al.\ needed PCA of hidden activations to reveal,
504
+ \sorl{} externalises as a readable routing token.
505
  \end{tcolorbox}
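
The tri-state rule and the way an uncertain state is resolved by lower positions can be written out directly. This is a plain arithmetic illustration of the definitions of $\text{ST}_n$ given above, not a description of how the model or the \sorl{} tokens compute it; the function names are illustrative.

```python
def carry_state(a_digit, b_digit):
    """Classify one digit position: '0' (sum < 9), '1' (sum > 9), 'U' (sum == 9)."""
    total = a_digit + b_digit
    if total < 9:
        return "0"
    if total > 9:
        return "1"
    return "U"


def resolve_carries(a_digits, b_digits):
    """Return the carry out of each position; digits are least-significant first."""
    carry = 0
    carries = []
    for a, b in zip(a_digits, b_digits):
        state = carry_state(a, b)
        if state == "1":
            carry = 1
        elif state == "0":
            carry = 0
        # state 'U': the incoming carry passes through unchanged
        carries.append(carry)
    return carries


# 199 + 801, written least-significant digit first.
a_digits, b_digits = [9, 9, 1], [1, 0, 8]
states = [carry_state(a, b) for a, b in zip(a_digits, b_digits)]
carries = resolve_carries(a_digits, b_digits)
print(states, carries)
```

In this example the lowest position generates a carry and the two sum-9 positions above it merely propagate it, which is the cascade behaviour that makes the uncertain state depend on lower positions.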
506
 
507
  % ─────────────────────────────────────────────────────────────────────────────
 
510
 
511
  Not all codebook tokens are specialists. Table~\ref{tab:token-polysemanticity}
512
  contrasts the most specialist with the most polysemantic tokens.
513
+ Token \texttt{t21} fires 94\% of the time on a single subtask (US) at a single position ($d_3$, addition only), making it a true specialist.
 
514
  Token \texttt{t1}, by contrast, is the highest-frequency token
515
  ($n{=}6{,}359$, spanning all five answer positions) with no subtask
516
  exceeding 24\%; it acts as a general-purpose fallback, handling
 
535
  \end{tabular}
536
  \caption{Specialist vs.\ polysemantic tokens.
537
  \textbf{Purity} = fraction of occurrences where the top subtask applies.
538
+ \textbf{Positions} = number of distinct answer positions ($d_0$--$d_6$)
539
  where the token appears. \texttt{t1} is a high-frequency fallback used
540
  across all positions; \texttt{t21} is a pure sum-9 cascade detector.}
541
  \label{tab:token-polysemanticity}
542
  \end{table}
543
 
544
+ Polysemanticity here mirrors the phenomenon described in neural network interpretability~\citep{elhage2022superposition}: a single token encodes multiple distinct roles, likely because the 30-token codebook has spare capacity at the overflow position ($d_0$) where carry state is most
545
+ variable. The specialist tokens concentrate at mid-sequence positions ($d_2$--$d_4$) where carry propagation is most structured.
 
 
 
 
546
 
547
  \begin{tcolorbox}[colback=gray!6, colframe=gray!40,
548
+ fonttitle=\bfseries\small, title={Finding \#6},
549
  left=5pt, right=5pt, top=4pt, bottom=4pt]
550
+ \small
551
+ The 30-token codebook is not uniformly specialist: high-purity tokens (e.g.\ \texttt{t21}, 94\% US, single position) coexist with polysemantic fallback tokens (e.g.\ \texttt{t1}, top purity 24\%, all five positions).
552
+ Polysemanticity concentrates at overflow positions where carry state is most variable; specialist tokens dominate the structured mid-sequence carry-propagation positions.
553
  \end{tcolorbox}
554
 
555
  % ─────────────────────────────────────────────────────────────────────────────
556
  \subsection{Automated token interpretation}
557
  \label{app:autointerp}
558
 
559
+ We implement a light version of the automated interpretation procedure of \citet{bills2023language}.
560
+ For each active token, we collect the $N{=}10$ examples from the evaluation set where the model assigned it with highest softmax confidence,
 
 
561
  then ask \texttt{claude-haiku} to produce a one-sentence role description.
 
562
 
563
+ % [PLACEHOLDER: run experiments/11_auto_interp/run.py to generate table.tex]
564
+ % Then paste the output of table.tex here:
565
+ %
566
+ % \input{experiments/11_auto_interp/table.tex}
567
+ %
568
+ % Expected columns: Token | Top subtask (purity) | Mean conf. | Auto-interpretation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
569
 
570
  \begin{tcolorbox}[colback=gray!6, colframe=gray!40,
571
+ fonttitle=\bfseries\small, title={Finding \#7},
572
  left=5pt, right=5pt, top=4pt, bottom=4pt]
573
+ \small
574
+ Automated interpretation (\`{a} la \citealt{bills2023language}) applied to the top-10 highest-confidence examples per token produces human-readable role descriptions that match the ground-truth Quirke subtask labels, without access to those labels.
575
+ Specialist tokens receive crisp single-sentence descriptions
576
+ (e.g.\ ``detects sum-9 boundary in addition cascade'');
577
+ polysemantic tokens produce broader descriptions reflecting their
578
+ mixed roles.
579
  \end{tcolorbox}
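
The example-selection step of the automated interpretation procedure can be sketched as follows. The record layout (token, confidence, example text) and the prompt wording are assumptions made for illustration; the sketch stops before any model call and does not reproduce \texttt{experiments/11\_auto\_interp/run.py}.

```python
import heapq
from collections import defaultdict


def top_examples(records, n=10):
    """Group (token, confidence, example) records and keep the n most confident."""
    by_token = defaultdict(list)
    for token, confidence, example in records:
        by_token[token].append((confidence, example))
    return {
        token: [ex for _, ex in heapq.nlargest(n, items, key=lambda it: it[0])]
        for token, items in by_token.items()
    }


def build_prompt(token, examples):
    """Assemble a hypothetical prompt asking for a one-sentence role description."""
    listing = "\n".join(f"- {ex}" for ex in examples)
    return (
        f"Token {token} was emitted on these problems:\n{listing}\n"
        "Describe its role in one sentence."
    )


# Made-up records for illustration.
records = [
    ("t21", 0.99, "A"),
    ("t21", 0.50, "B"),
    ("t21", 0.75, "C"),
    ("t1", 0.40, "D"),
]
selected = top_examples(records, n=2)
prompt = build_prompt("t21", selected["t21"])
print(prompt)
```

Selecting by confidence rather than at random biases the examples toward a token's most typical usage, which is worth bearing in mind when reading the descriptions of polysemantic tokens.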
 
580
  % ─────────────────────────────────────────────────────────────────────────────
581
  \subsection{Training and evaluation details}
582
  \label{app:training}
 
618
  \end{table}
619
 
620
  Evaluation uses fixed-length autoregressive decoding (no teacher forcing):
621
+ the model generates answer digits $d_0 \to d_6$ using its own predictions, with abstraction tokens inserted via the \sorl{} search-then-recurse procedure (matching training). Accuracy is measured on 100 held-out examples per split (seed 42; \texttt{thoughtworks/arithmetic-sorl-data}).
 
 
 
622
  """
623
 
624
 
625
  # ═══════════════════════════════════════════════════════════════════
626
  # App