| """ |
| SoRL Arithmetic Dashboard — baseline vs SoRL side-by-side comparison. |
| |
| Deployed as HF Space. Reads from thoughtworks/arithmetic-sorl model repo. |
| """ |
| import json |
| import gradio as gr |
| from gen_summary import build_summary_markdown |
| import pandas as pd |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import io |
| import numpy as np |
| from PIL import Image |
| from huggingface_hub import HfApi, hf_hub_download |
|
|
| MODEL_REPO = "thoughtworks/arithmetic-sorl" |
|
|
| |
| |
| |
|
|
| LATEX_ARITHMETIC_SETUP = r"""\subsection{Case study: six-digit addition and subtraction} |
| \label{sec:arithmetic} |
| |
| Six-digit addition and subtraction provides a setting where the internal |
| reasoning structure is known: \citet{quirke_2024_addsub_preprint} identify |
| ten mutually exclusive subtask types at each answer-digit position - carry generation, carry use, sum-9 boundary detection, borrow cascades, and so on - that a transformer must implement to solve the task |
| (see Table~\ref{tab:quirke-subtasks} for details). |
| This makes arithmetic an ideal testbed for one application of \sorl{}: that abstraction tokens may \emph{externalize} reasoning steps, making them directly observable and intervenable without any activation-level tooling. |
| |
| \sorl{} inserts one abstraction token per answer-digit position ($K{=}1$, |
| codebook size $|\mathcal{A}|{=}30$), so each routing decision is a named, discrete symbol emitted at generation time. |
| Figure~\ref{fig:arithmetic-example} shows this concretely: |
| for $959{,}271 + 040{,}756 = 1{,}000{,}027$ (a four-deep carry cascade), |
| \sorl{} assigns \texttt{t2} at every cascade position (UC/US), |
| \texttt{t6} at the sum-9 boundary (SS), and distinct tokens at the trivial positions (SC/SA) - the carry structure is readable off the token sequence with no probing or patching required. |
| Full training and architecture details are in Appendix~\ref{app:training}. |
| |
| \textbf{Abstraction tokens recover known circuits without supervision.} |
| Analysis of \texttt{2L/1H/128d} shows that \sorl{}'s codebook |
| spontaneously partitions into subtask-specialist tokens: |
| each of the 23 active tokens concentrates on a narrow slice of the Quirke taxonomy - dominant subtask accounts for ${\geq}70\%$ of occurrences for the majority of tokens - and is locked to one or two answer positions. |
| This structure emerges from the info-gain loss alone, with no access to ground-truth subtask labels or carry state. |
| Tokens are also causally necessary: knocking out all tokens collapses accuracy from 95.5\% to 0.1\%, confirming they carry the computation rather than merely annotating it. |
| |
| \textbf{Named tokens improve fine-tuning performance and enable interventions.} |
| Because the routing codes are discrete and named, surgical model edits are possible that have no analog in standard transformers: |
| swapping a single token at one answer position fixes wrong predictions |
| at a 27--31\% rate on carry-heavy examples (cross-operation transplant: |
| 93.5\% vs.\ 75.5\% random baseline). |
| Interpretability here is not merely post-hoc — it translates directly into the ability to correct the model. |
| |
| Correspondingly, finetuned \sorl{} outperforms \sft{} strongly on 12 of 13 tested |
| (architecture, data-size) configurations, and on \emph{all 13} on |
| the hardest 6-deep carry cascades, with gains as large as $+50$\,pp |
| (Table~\ref{tab:undersized-wins}). |
| The margin grows with cascade depth, consistent with explicit carry/borrow routing being the mechanism behind the gain. |
| |
| \begin{tcolorbox}[colback=gray!6, colframe=gray!40, |
| fonttitle=\bfseries\small, title={Finding \#1: \sorl{} increases accuracy on 6-digit arithmetic dramatically}, |
| left=5pt, right=5pt, top=4pt, bottom=4pt] |
| \small \sorl{} increases accuracy on 6-digit arithmetic dramatically, with larger margins on the hard cascade and borrow splits. |
| \end{tcolorbox} |
| |
| See Appendix~\ref{app:arithmetic} for further details on \sorl{} interpretability, including a demonstration of auto-interp~\citep{bills2023language_models_explain_neurons}, token specializations, and polysemantic tokens. |
| """ |
|
|
|
|
|
|
| LATEX_APPENDIX = r"""\section{Arithmetic case study: interpretability analysis} |
| \label{app:arithmetic} |
| |
| Full performance results across all architectures and data sizes are in |
| Table~\ref{tab:undersized-wins} (\S\ref{app:performance}). |
| |
| \begin{table}[h] |
| \centering\small |
| \setlength{\tabcolsep}{6pt} |
| \begin{tabular}{llp{7.8cm}} |
| \toprule |
| & Label & Condition at digit position $n$ \\ |
| \midrule |
| \multirow{5}{*}{\rotatebox[origin=c]{90}{Addition\;}} |
| & \textbf{SA} & $d_1{+}d_2 \leq 8$;\; no carry in or out \\ |
| & \textbf{SC} & $d_1{+}d_2 \geq 10$;\; generates a carry \\ |
| & \textbf{SS} & $d_1{+}d_2 = 9$;\; carry state \emph{uncertain} (cascade boundary) \\ |
| & \textbf{UC} & carry arrives from position $n{-}1$;\; answer digit depends on it \\ |
| & \textbf{US} & carry propagates through a run of SS positions (sum-of-9 cascade) \\ |
| \midrule |
| \multirow{5}{*}{\rotatebox[origin=c]{90}{Subtraction\;}} |
| & \textbf{MD} & $d_1 \geq d_2$;\; no borrow \\ |
| & \textbf{MB} & $d_1 < d_2$;\; generates a borrow \\ |
| & \textbf{ME} & $d_1 = d_2$;\; borrow state \emph{uncertain} \\ |
| & \textbf{UB} & borrow arrives from position $n{-}1$ \\ |
| & \textbf{UD} & borrow propagates through a run of ME positions \\ |
| \bottomrule |
| \end{tabular} |
| \caption{Per-digit subtask labels for six-digit addition and |
| subtraction~\citep{quirke_2024_addsub_preprint}. |
| Cascades (US, UD) require tracking carry/borrow state across |
| multiple positions and are the hardest splits.} |
| \label{tab:quirke-subtasks} |
| \end{table} |
| |
| \paragraph{Setup.} |
| All interpretability analyses use a 2-layer, 1-head, 128-dimensional transformer |
| trained on 100K examples, evaluated on 2{,}600 held-out problems across 26 splits. |
| This model achieves 95.5\% accuracy with \sorl{} abstraction tokens and 0.1\% without, making it the clearest testbed for causal analysis. |
| All results are reproducible from the released code. |
| |
| % βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
| \subsection{Performance: \sorl{} vs.\ \sft{} on undersized architectures} |
| \label{app:performance} |
| |
| \begin{table}[t] |
| \centering\small |
| \begin{tabular}{llrrrr} |
| \toprule |
| Architecture & Data & Baseline & SoRL & Gap & C6 gap \\ |
| \midrule |
| \texttt{1L/2H/256d} & 10K & 10\% & \textbf{19\%} & \textcolor{green!50!black}{\textbf{+9\%}} & \textcolor{green!50!black}{\textbf{+18\%}} \\ |
| & 25K & 32\% & 26\% & $-7\%$ & \textcolor{green!50!black}{\textbf{+10\%}} \\ |
| & 50K & 44\% & \textbf{65\%} & \textcolor{green!50!black}{\textbf{+21\%}} & \textcolor{green!50!black}{\textbf{+34\%}} \\ |
| & 100K & 49\% & \textbf{65\%} & \textcolor{green!50!black}{\textbf{+16\%}} & \textcolor{green!50!black}{\textbf{+31\%}} \\ |
| \midrule |
| \texttt{1L/3H/510d} & 10K & 36\% & \textbf{52\%} & \textcolor{green!50!black}{\textbf{+16\%}} & \textcolor{green!50!black}{\textbf{+30\%}} \\ |
| & 25K & 46\% & \textbf{60\%} & \textcolor{green!50!black}{\textbf{+14\%}} & \textcolor{green!50!black}{\textbf{+22\%}} \\ |
| & 50K & 53\% & \textbf{72\%} & \textcolor{green!50!black}{\textbf{+19\%}} & \textcolor{green!50!black}{\textbf{+38\%}} \\ |
| & 100K & 67\% & \textbf{83\%} & \textcolor{green!50!black}{\textbf{+16\%}} & \textcolor{green!50!black}{\textbf{+26\%}} \\ |
| \midrule |
| \texttt{2L/1H/128d} & 10K & 16\% & \textbf{36\%} & \textcolor{green!50!black}{\textbf{+21\%}} & \textcolor{green!50!black}{\textbf{+39\%}} \\ |
| & 25K & 40\% & \textbf{55\%} & \textcolor{green!50!black}{\textbf{+15\%}} & \textcolor{green!50!black}{\textbf{+23\%}} \\ |
| & 50K & 59\% & \textbf{87\%} & \textcolor{green!50!black}{\textbf{+28\%}} & \textcolor{green!50!black}{\textbf{+50\%}} \\ |
| & 75K & 75\% & \textbf{87\%} & \textcolor{green!50!black}{\textbf{+12\%}} & \textcolor{green!50!black}{\textbf{+5\%}} \\ |
| & 100K & 73\% & \textbf{95\%} & \textcolor{green!50!black}{\textbf{+22\%}} & \textcolor{green!50!black}{\textbf{+33\%}} \\ |
| \bottomrule |
| \end{tabular} |
| \caption{\sorl{} ($K{=}1$, $|\mathcal{A}|{=}30$) vs.\ \sft{} baseline on |
| undersized architectures across data sizes. |
| \textbf{Gap} = overall accuracy gain; \textbf{C6 gap} = gain on |
| 6-deep carry cascades (the hardest split). |
| \sorl{} wins in \textbf{12 of 13} (architecture, data-size) pairs; |
| the single exception is \texttt{1L/2H/256d} at 25K, where the model |
| is undertrained (accuracy still rising at epoch 20). |
| \sorl{} wins on C6 in \textbf{all 13} configurations.} |
| \label{tab:undersized-wins} |
| \end{table} |
| |
| % βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
| \subsection{Causal ablations} |
| \label{app:causal} |
| |
| To confirm that \sorl{} tokens are causally necessary (not merely correlated |
| with correct outputs), we run three intervention conditions on model |
| \texttt{2L/1H/128d} (100K), evaluated across 2{,}600 held-out problems: |
| |
| \begin{itemize}[nosep] |
| \item \textbf{Shuffle}: randomly permute all abstraction tokens within |
| each sequence (token identities preserved; positional assignment destroyed). |
| \item \textbf{Random}: replace each token with a draw uniform over the |
| 30-token codebook (identity and position both destroyed). |
| \item \textbf{Knockout}: replace every token with a fixed \texttt{[UNK]} |
| embedding (strongest intervention; removes all information). |
| \end{itemize} |
| |
| Table~\ref{tab:ablation-splits} shows per-split accuracy under each condition. |
| |
| \begin{table}[h] |
| \centering\small |
| \begin{tabular}{llrrrr} |
| \toprule |
| Family & Split & Baseline & Shuffle & Random & Knockout \\ |
| \midrule |
| \multirow{4}{*}{\textit{Addition (easy)}} & S0 (no carry) & 100\% & 24\% & 28\% & 0\% \\ |
| & S1 & 100\% & 17\% & 9\% & 0\% \\ |
| & S2 & 100\% & 22\% & 10\% & 0\% \\ |
| & random & 100\% & 26\% & 8\% & 0\% \\ |
| \midrule |
| \multirow{4}{*}{\textit{Addition cascade (hard)}} & C3 (3-deep) & 96\% & 28\% & 14\% & 0\% \\ |
| & C4 (4-deep) & 99\% & 25\% & 13\% & 0\% \\ |
| & C5 (5-deep) & 99\% & 23\% & 19\% & 0\% \\ |
| & C6 (6-deep) & 97\% & 27\% & 15\% & 0\% \\ |
| \midrule |
| \multirow{1}{*}{\textit{Subtraction (easy)}} & random & 100\% & 46\% & 12\% & 0\% \\ |
| \midrule |
| \multirow{3}{*}{\textit{Subtraction cascade (hard)}} & M3 (3-deep borrow) & 100\% & 22\% & 1\% & 0\% \\ |
| & M4 (4-deep borrow) & 85\% & 6\% & 0\% & 0\% \\ |
| & M5 (5-deep borrow) & 57\% & 3\% & 0\% & 2\% \\ |
| \midrule |
| \multicolumn{2}{l}{\textbf{Overall}} & \textbf{95.5\%} & 26.6\% & 12.3\% & 0.1\% \\ |
| \bottomrule |
| \end{tabular} |
| \caption{Per-split causal ablation (\texttt{2L/1H/128d}, 100K training examples). |
| \textbf{Shuffle} preserves token identity but destroys positional assignment; |
| \textbf{Random} destroys both; \textbf{Knockout} removes all token information. |
| Generated by \texttt{paper/results/result\_ablation\_splits/run.py}.} |
| \label{tab:ablation-splits} |
| \end{table} |
| |
| \paragraph{Commentary.} |
| Knockout reduces accuracy to $\leq$2\% on every split, confirming that the model has offloaded computation into the routing tokens. |
| Three patterns are notable: |
| |
| \begin{itemize}[nosep] |
| \item \textbf{Shuffle $\approx$ Random on the no-carry split.} |
| On addition S0 (no carry), shuffle yields 24\% vs.\ random's 28\%; the gap is small and reflects that no single position is critical - any token in any position is roughly as bad as another. |
| \item \textbf{Shuffle $>$ Random on cascade splits (C3--C6).} |
| On 3--6-deep carry cascades, shuffle (23--28\%) consistently |
| outperforms random (13--19\%). |
| When tokens are shuffled, a cascade position receives a token |
| from another cascade position — a wrong token but from the |
| ``right family'', producing a systematic off-by-one carry error. |
| Random tokens provide no structural signal at all, making |
| cascade resolution impossible. |
| \item \textbf{Borrow cascades are uniquely sensitive.} |
| Sub-M4 (4-deep borrow) drops from 85\% baseline to 6\% under |
| shuffle and 0\% under random - a 79\,pp collapse from shuffle |
| alone. Sub-M5 (5-deep borrow) is hardest even at baseline (57\%), |
| and all ablations reduce it to $\leq$3\%, showing that deep |
| borrow cascades are the single hardest regime and that \sorl{} |
| tokens are essential for solving them. |
| \end{itemize} |
| |
| \begin{tcolorbox}[colback=gray!6, colframe=gray!40, |
| fonttitle=\bfseries\small, title={Finding \#2: Abstraction tokens are causally necessary}, |
| left=5pt, right=5pt, top=4pt, bottom=4pt] |
| \small Knockout collapses accuracy from 95.5\% to 0.1\%; random replacement hurts more than shuffling on cascades. |
| \end{tcolorbox} |
| |
| % βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
| \subsection{Token-subtask heatmap} |
| \label{app:heatmap} |
| |
| \begin{figure}[h] |
| \centering |
| \includegraphics[width=0.95\linewidth]{Styles/figures/fig_token_subtask.pdf} |
| \caption{Token--subtask specialization heatmap (2-layer, 1-head, 128d; 100K examples). |
| Each cell shows $P(\text{subtask} \mid \text{token})$ over 2{,}600 held-out problems. |
| Rows are the 23 active tokens; columns are the 10 Quirke subtask labels. |
| Most tokens concentrate on 1--2 subtasks: \texttt{t20} fires 84\% on UD |
| (borrow cascade, subtraction) and \texttt{t19} fires 82\% on UD, together |
| forming a dedicated borrow-cascade detector pair. |
| \texttt{t3} specializes in SA (59\%, simple addition with no carry), |
| while \texttt{t6} concentrates on MD (64\%, simple subtraction). |
| Tokens \texttt{t1} and \texttt{t2} are more polysemantic, spreading |
| across UC, US, and MD — acting as general-purpose fallbacks.} |
| \label{fig:token-subtask} |
| \end{figure} |
| |
| Of the 30 tokens in the codebook, 23 appear in the held-out evaluation set. |
| Each active token concentrates on a narrow slice of the subtask space: the dominant subtask accounts for ${\geq}70\%$ of that token's occurrences in the majority of cases. |
| Tokens are also \emph{position-locked}: each token appears predominantly at one or two answer positions ($d_0$-$d_6$), rarely crossing position boundaries. |
| Representative examples are shown in Table~\ref{tab:token-profiles}: |
| token \texttt{t21} fires 93\% of the time on US (sum-9 cascade, addition) |
| with digit sum $\equiv 9 \pmod{10}$ in 95\% of cases; |
| token \texttt{t23} is the subtraction mirror (UD, 88\%, position $d_3$). |
| |
| \begin{table}[h] |
| \centering\small |
| \begin{tabular}{llrrll} |
| \toprule |
| Token & Pos & $n$ & Top subtask & Purity & Op \\ |
| \midrule |
| \texttt{t21} & $d_3$ & 719 & US & 93\% & add (94\%) \\ |
| \texttt{t23} & $d_3$ & 687 & UD & 88\% & sub (93\%) \\ |
| \texttt{t6} & $d_2$ & 1438 & US & 52\% & add (79\%), sum${\equiv}9$: 78\% \\ |
| \texttt{t7} & $d_2$ & 1026 & UB/UD & 90\% & sub (100\%) \\ |
| \texttt{t14} & $d_4$ & 1242 & UD & 73\% & sub (81\%) \\ |
| \bottomrule |
| \end{tabular} |
| \caption{Selected token profiles. Each token specialises in a specific |
| subtask at a specific answer position. ``Purity'' = fraction of |
| occurrences where the top subtask applies.} |
| \label{tab:token-profiles} |
| \end{table} |
| |
| \begin{tcolorbox}[colback=gray!6, colframe=gray!40, |
| fonttitle=\bfseries\small, title={Finding \#3: Tokens spontaneously specialize by subtask and position}, |
| left=5pt, right=5pt, top=4pt, bottom=4pt] |
| \small 23 of 30 tokens are active; each locks to 1--2 Quirke subtasks (${\geq}70\%$ purity) and 1--2 answer positions. |
| \end{tcolorbox} |
| |
| % βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
| \subsection{Guided computation via token intervention} |
| \label{app:guided} |
| |
| If \sorl{} tokens encode the \emph{computational route} rather than just the answer, swapping a token at a mispredicted position should \emph{fix} the prediction — without retraining, without accessing internal activations, and without modifying the weights. |
| |
| We test this with a \emph{surgical swap}: for each wrong prediction, |
| we try replacing the abstraction token at each answer position with |
| every other token in the codebook (29 candidates $\times$ 5 positions = 145 interventions per example) and measure how many wrong predictions become correct — and how many previously-correct predictions break. |
| |
| \paragraph{Results.} |
| At positions $d_0$--$d_2$ (the carry-heavy positions), a fixing swap exists for 27--31\% of mispredicted examples. |
| The best single swap is replacing \texttt{t16} with \texttt{t25} at $d_1$: this fixes 10 wrong predictions while breaking only 5 correct ones (a 2:1 fix-to-break ratio), all on carry cascade splits (C4--C6). |
| Positions $d_3$ and $d_4$ are harder to fix surgically (fix rates 8\% and 2\%), consistent with those positions encoding longer-range carry state that a single-position swap cannot resolve. |
| |
| \begin{tcolorbox}[colback=gray!6, colframe=gray!40, |
| fonttitle=\bfseries\small, title={Finding \#4: Single token swaps fix mispredicted carry-heavy examples}, |
| left=5pt, right=5pt, top=4pt, bottom=4pt] |
| \small Replacing one abstraction token fixes 27--31\% of wrong predictions at carry-heavy positions, with no weight updates. |
| \end{tcolorbox} |
| |
| % βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
| \subsection{\sorl{} tokens recover circuits from Quirke et al.} |
| \label{app:quirke-analogy} |
| |
| \citet{quirke_2024_addsub_preprint} identify, via PCA of internal |
| residual-stream activations, a \emph{tri-state carry classifier} at each digit position $n$: the hidden state encodes which of three carry regimes |
| applies — |
| $\text{ST}_n = 0$ (digit sum $< 9$; carry cannot propagate), |
| $\text{ST}_n = 1$ (digit sum $> 9$; carry will propagate), or |
| $\text{ST}_n = U$ (digit sum $= 9$; carry state \emph{uncertain}, depends |
| on lower positions). |
| This trichotomy is the core circuit for addition; borrow cascades have an analogous structure for subtraction. |
| |
| \sorl{} recovers the same trichotomy \emph{without access to activation |
| data or ground-truth circuit labels}, purely from the info-gain training |
| signal. |
| The correspondence is visible directly in the codebook: |
| |
| \begin{itemize}[nosep] |
| \item $\text{ST}_n = U$ (sum-9 uncertain, addition) $\longleftrightarrow$ |
| \texttt{t21} at $d_3$ (US=93\%, sum${\equiv}9$: 95\%); |
| \texttt{t6} at $d_2$ (US=52\%, sum${\equiv}9$: 78\%). |
| \item $\text{ST}_n = U$ (borrow-uncertain, subtraction) $\longleftrightarrow$ |
| \texttt{t23} at $d_3$ (UD=88\%, sub=93\%); |
| \texttt{t7} at $d_2$ (UB/UD=90\%, sub=100\%). |
| \item $\text{ST}_n = 1$ (carry generated) $\longleftrightarrow$ |
| tokens concentrated on SC/MB subtasks at each position. |
| \item $\text{ST}_n = 0$ (simple digit, no carry) $\longleftrightarrow$ |
| tokens concentrated on SA/MD subtasks. |
| \end{itemize} |
| |
| Crucially, Quirke et al.\ required full activation-level mechanistic interpretability to discover this classifier; \sorl{} surfaces it as a readable token in the output sequence, accessible without any post-hoc analysis. |
| |
| \begin{tcolorbox}[colback=gray!6, colframe=gray!40, |
| fonttitle=\bfseries\small, title={Finding \#5: \sorl{} rediscovers Quirke's circuits without supervision}, |
| left=5pt, right=5pt, top=4pt, bottom=4pt] |
| \small \sorl{} recovers the carry-state tri-classifier ($\{0,U,1\}$) unsupervised — externalizing what \citet{quirke_2024_addsub_preprint} needed PCA to reveal. |
| \end{tcolorbox} |
| |
| % βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
| \subsection{Polysemantic tokens} |
| \label{app:polysemantic} |
| |
| Not all codebook tokens are specialists. Table~\ref{tab:token-polysemanticity} |
| contrasts the most specialist with the most polysemantic tokens. |
| Token \texttt{t21} fires 94\% of the time on a single subtask (US) at a single position ($d_3$, addition only) — a true specialist. |
| Token \texttt{t1}, by contrast, is the highest-frequency token |
| ($n{=}6{,}359$, spanning all five answer positions) with no subtask |
| exceeding 24\% — it acts as a general-purpose fallback, handling |
| whichever position and carry regime was not captured by a specialist token. |
| |
| \begin{table}[h] |
| \centering\small |
| \begin{tabular}{clrrr} |
| \toprule |
| Token & Top subtask & Purity & $n$ & Positions \\ |
| \midrule |
| \multicolumn{5}{l}{\textit{Specialist (high purity)}} \\ |
| \texttt{t21} & US (cascade, add) & 94\% & 719 & 1 \\ |
| \texttt{t23} & UD (cascade, sub) & 88\% & 687 & 3 \\ |
| \texttt{t14} & UD & 74\% & 1242 & 3 \\ |
| \midrule |
| \multicolumn{5}{l}{\textit{Polysemantic (low purity)}} \\ |
| \texttt{t5} & UB & 21\% & 1377 & 4 \\ |
| \texttt{t1} & MD & 24\% & 6359 & 5 \\ |
| \texttt{t20} & UB & 24\% & 283 & 3 \\ |
| \bottomrule |
| \end{tabular} |
| \caption{Specialist vs.\ polysemantic tokens. |
| \textbf{Purity} = fraction of occurrences where the top subtask applies. |
| \textbf{Positions} = number of distinct answer positions ($d_0$-$d_6$) |
| where the token appears. \texttt{t1} is a high-frequency fallback used |
| across all positions; \texttt{t21} is a pure sum-9 cascade detector.} |
| \label{tab:token-polysemanticity} |
| \end{table} |
| |
| Polysemanticity here mirrors the phenomenon described in neural network interpretability~\citep{elhage2022superposition}: a single token encodes multiple distinct roles, likely because the 30-token codebook has spare capacity at the overflow position ($d_0$) where carry state is most |
| variable. The specialist tokens concentrate at mid-sequence positions ($d_2$-$d_4$) where carry propagation is most structured. |
| |
| \begin{tcolorbox}[colback=gray!6, colframe=gray!40, |
| fonttitle=\bfseries\small, title={Finding \#6: The codebook mixes specialist and polysemantic tokens}, |
| left=5pt, right=5pt, top=4pt, bottom=4pt] |
| \small Specialist tokens (\texttt{t21}: 94\% purity) coexist with polysemantic fallbacks (\texttt{t1}: 24\%, five positions). |
| \end{tcolorbox} |
| |
| % βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
| \subsection{Automated token interpretation} |
| \label{app:autointerp} |
| |
| We implement a light version of the automated interpretation procedure of \citet{bills2023language}. |
| For each active token, we collect the $N{=}10$ examples from the evaluation set where the model assigned it with highest softmax confidence, |
| then ask \texttt{claude-haiku} to produce a one-sentence role description. |
| Table~\ref{tab:auto-interp} shows results for the 8 highest-confidence tokens. |
| |
| \begin{table}[ht] |
| \centering\small |
| \begin{tabular}{clrp{5.5cm}} |
| \toprule |
| Token & Top subtask & Conf. & Auto-interpretation \\ |
| \midrule |
| \texttt{t0} & UC (47\%) & 1.00 & Token t0 marks the tens digit position in addition problems, regardless of carry state or sum value. \\ |
| \texttt{t2} & UC (70\%) & 0.99 & Token t2 outputs the ones digit (0) when adding two numbers whose ones digits sum to 10 or more. \\ |
| \texttt{t1} & UC (30\%) & 0.99 & This token routes to the fourth digit position during addition when a carry from the previous position must be incorporated. \\ |
| \texttt{t3} & UC (44\%) & 0.94 & Token t3 routes to the hundreds position (d3) when processing carries from the tens column in addition. \\ |
| \texttt{t5} & MD (65\%) & 0.93 & Token t5 routes cases where the ones digit result is 0, spanning multiple subtasks and operations. \\ |
| \texttt{t8} & MD (26\%) & 0.91 & Token t8 activates when processing the tens digit (d2) across addition/subtraction with various carry states. \\ |
| \texttt{t10} & UB (41\%) & 0.88 & Token t10 routes subtraction problems requiring borrow propagation at mid-to-late digit positions. \\ |
| \texttt{t6} & UC (27\%) & 0.88 & Token t6 routes cases where the ones digit result is 0, regardless of operation or carry state. \\ |
| \bottomrule |
| \end{tabular} |
| \caption{Automated interpretation of the 8 highest-confidence \sorl{} abstraction tokens |
| (\`{a} la \citealt{bills2023language}). |
| For each token, the 10 examples with highest softmax confidence are shown to an LLM, |
| which produces a one-sentence role description. |
| \textbf{Conf.} = mean softmax probability of the assigned token. |
| High-confidence specialists receive crisp, position- and operation-specific descriptions; |
| polysemantic tokens (not shown) produce broader descriptions.} |
| \label{tab:auto-interp} |
| \end{table} |
| |
| \begin{tcolorbox}[colback=gray!6, colframe=gray!40, |
| fonttitle=\bfseries\small, title={Finding \#7: Automated interpretation matches ground-truth subtask labels}, |
| left=5pt, right=5pt, top=4pt, bottom=4pt] |
| \small High-confidence tokens get crisp role descriptions; polysemantic tokens get appropriately vague ones — without accessing ground-truth labels. |
| \end{tcolorbox} |
| % βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
| \subsection{Training and evaluation details} |
| \label{app:training} |
| |
| \begin{table}[h] |
| \centering\small |
| \begin{tabular}{lrrrrr} |
| \toprule |
| Architecture & Layers & Heads & Hidden & FFN & Params \\ |
| \midrule |
| \texttt{1L/2H/256d} & 1 & 2 & 256 & 1024 & ${\sim}$0.3M \\ |
| \texttt{1L/3H/510d} & 1 & 3 & 510 & 2040 & ${\sim}$2.0M \\ |
| \texttt{2L/1H/128d} & 2 & 1 & 128 & 512 & ${\sim}$0.1M \\ |
| \bottomrule |
| \end{tabular} |
| \caption{Undersized architectures (Table~\ref{tab:undersized-wins}). |
| All use pre-norm, GeLU, Qwen3 tokenizer.} |
| \end{table} |
| |
| \begin{table}[h] |
| \centering\small |
| \begin{tabular}{ll} |
| \toprule |
| Hyperparameter & Value \\ |
| \midrule |
| Optimizer & AdamW \\ |
| Learning rate & $8\times10^{-5}$ \\ |
| $(\beta_1,\,\beta_2)$ & $(0.9,\;0.999)$ \\ |
| Weight decay & $0.01$ \\ |
| LR schedule & Linear warmup (3\%) then constant \\ |
| Batch size & 64 \\ |
| Epochs & 20 \\ |
| \sorl{} codebook & $|\mathcal{A}|{=}30$, $K{=}1$ \\ |
| $\alpha_{\text{info-gain}},\,\alpha_{\text{abs}},\,\alpha_{\text{zipf}}$ |
| & $10.0,\;0.1,\;1.0$ \\ |
| \bottomrule |
| \end{tabular} |
| \caption{Shared training hyperparameters.} |
| \end{table} |
| |
| Evaluation uses fixed-length autoregressive decoding (no teacher forcing): |
| the model generates answer digits $d_0 \to d_6$ using its own predictions, with abstraction tokens inserted via the \sorl{} search-then-recurse procedure (matching training). Accuracy is measured on 100 held-out examples per split (seed 42; \texttt{thoughtworks/arithmetic-sorl-data}). |
| |
| """ |
|
|
|
|
| |
| |
| |
|
|
| with gr.Blocks(title="SoRL Arithmetic Dashboard") as app: |
| gr.Markdown("# SoRL Arithmetic Dashboard") |
| gr.Markdown("[Models](https://huggingface.co/thoughtworks/arithmetic-sorl) · " |
| "[Datasets](https://huggingface.co/datasets/thoughtworks/arithmetic-sorl-data) · " |
| "[WandB](https://wandb.ai/nlp_and_interpretability/sorl-arithmetic) · " |
| "[Code](https://github.com/fangyuan-ksgk/mod_gpt/tree/amir/arithmetic) · " |
| "[Quirke et al. 2024](https://arxiv.org/abs/2402.02619)") |
|
|
| models_state = gr.State([]) |
|
|
| with gr.Tabs(): |
| |
| with gr.TabItem("Models"): |
| with gr.Accordion("What is SoRL?", open=False): |
| gr.Markdown(""" |
| **Self-Organized Reinforcement Learning (SoRL)** augments a transformer with learned *abstraction tokens* — |
| a small auxiliary vocabulary (e.g. 30 tokens) inserted at regular intervals (every K positions) into the sequence. |
| |
| ``` |
| Standard SFT: 1 2 3 4 5 6 + 6 5 4 3 2 1 = 7 7 7 7 7 7 |
| |
| SoRL (K=4): 1 2 3 [a] 4 5 6 [a] + 6 5 [a] 4 3 2 [a] 1 = 7 [a] 7 7 7 [a] 7 7 |
| ``` |
| |
| **Training**: insert placeholders → search for best abstraction values → train with info-gain loss |
| (abstractions must reduce prediction uncertainty). **Eval**: autoregressive, errors propagate. |
| """) |
|
|
| with gr.Row(): |
| arch_filter = gr.Dropdown(["All", "2L/3H/510d", "1L/3H/510d", "1L/2H/256d", "2L/1H/128d"], |
| value="All", label="Architecture") |
| refresh_btn = gr.Button("Refresh from HF", variant="primary") |
|
|
| summary_text = gr.Markdown("Click Refresh to load.") |
| queue_status = gr.Markdown("") |
|
|
| summary_md = gr.Markdown(build_summary_markdown()) |
| main_table = gr.Dataframe( |
| headers=["Ops", "Data", "Arch", "Baseline", "SoRL", "Config", "B_hf", "S_hf", "B_wandb", "S_wandb"], |
| datatype=["str", "str", "str", "markdown", "markdown", "str", "markdown", "markdown", "markdown", "markdown"], |
| interactive=False, |
| ) |
|
|
| with gr.Accordion("Hard Split Comparison (C4, C5, C6, M4, M5)", open=False): |
| gr.Markdown("Left = Baseline, Right = SoRL. **Bold** = winner. C-splits = hot carry chains (varied answers).") |
| hard_table = gr.Dataframe( |
| headers=["Ops", "Data", "Arch", "Config", |
| "B_add_C4", "S_add_C4", "B_add_C5", "S_add_C5", |
| "B_add_C6", "S_add_C6", "B_sub_M4", "S_sub_M4", |
| "B_sub_M5", "S_sub_M5"], |
| datatype=["str", "str", "str", "str"] + ["markdown"] * 10, |
| interactive=False, |
| ) |
|
|
| with gr.Accordion("Data Efficiency & Undersized Models", open=False): |
| gr.Image("static_figures/fig_data_efficiency.png") |
| gr.Markdown("At 10K, SoRL K=1 abs30 reaches **96.7%** vs baseline **76.6%** (+20pp). By 50K both hit 100%.") |
| gr.Image("static_figures/fig_undersized.png") |
| gr.Markdown("Undersized 2L/1H/128d: baseline 50% → SoRL **85%** (+35pp). Abstraction tokens compensate for limited capacity.") |
|
|
| with gr.Accordion("Per-Split Detail", open=False): |
| model_selector = gr.Dropdown(label="Model", choices=[], allow_custom_value=True) |
| detail_btn = gr.Button("Show splits") |
| detail_table = gr.Dataframe(headers=["Split", "Accuracy", "N"], interactive=False) |
|
|
| |
| with gr.TabItem("Interpretability"): |
| gr.Markdown("""## SoRL tokens externalize arithmetic circuits |
| |
| ### Background: how multi-digit arithmetic works |
| |
| Adding two 6-digit numbers like `345678 + 657893` requires tracking **carries** — when |
| a column sums to 10 or more, you carry 1 to the next column. A **carry cascade** happens |
| when carries chain through multiple consecutive columns (e.g., `999 + 1 = 1000`). |
| |
| We evaluate on **C-splits** β problems grouped by how many consecutive columns produce carries, |
| with varied (non-zero) answer digits: |
| |
| | Split | Meaning | Example | Why it's hard | |
| |-------|---------|---------|---------------| |
| | **C1** | 1 carry | `345678 + 100921 = 446599` | Single carry — easy | |
| | **C2** | 2 consecutive carries | `503847 + 297162 = 801009` | Carry propagates once | |
| | **C3** | 3 consecutive carries | `145232 + 957868 = 1103100` | Must track 3-step cascade | |
| | **C4** | 4 consecutive carries | `780149 + 819959 = 1600108` | Longer cascade chain | |
| | **C5** | 5 consecutive carries | `553777 + 847927 = 1401704` | Nearly full cascade | |
| | **C6** | 6 carries (max) | `503847 + 996167 = 1500014` | Every column cascades | |
| |
| [Quirke et al. (2024)](https://arxiv.org/abs/2402.02619) showed that transformers learn |
| these carry/borrow circuits internally, but they're hidden in activations — discoverable only |
| through PCA, probing, or ablation at the activation level. |
| |
| ### Quirke's subtask definitions |
| |
| At each digit position, the model must compute one of these operations |
| ([Quirke et al. §3.2-3.3](https://arxiv.org/abs/2402.02619)): |
| |
| **Addition:** |
| | Subtask | Meaning | Quirke eq. | |
| |---------|---------|-----------| |
| | **SA** | Simple Add: `(d₁ + d₂) mod 10` | — | |
| | **SC** | Sum Carry: `d₁ + d₂ ≥ 10`, produces a local carry | — | |
| | **SS** | Sum-of-9: `d₁ + d₂ = 9`, carry state is **uncertain** | eq. 2: STn = U | |
| | **UC** | Use Carry: this digit's answer depends on carry from right | eq. 2: STn = 1 | |
| | **US** | Use Sum-9 cascade: carry propagates through a chain of sum-9 digits | eq. 4-6 | |
| |
| **Subtraction** (same structure with borrows replacing carries): |
| | Subtask | Meaning | Quirke eq. | |
| |---------|---------|-----------| |
| | **MD** | Base Diff: `(dβ - dβ) mod 10` | β | |
| | **MB** | Make Borrow: `dβ < dβ`, produces a local borrow | eq. 7: MBn = 1 | |
| | **ME** | Equal digits: `dβ = dβ`, borrow state is **uncertain** | eq. 7: MBn = U | |
| | **UB** | Use Borrow: answer depends on borrow from right | β | |
| | **UD** | Cascade borrow: borrow propagates through equal-digit chain | β | |
| |
| **SoRL makes these circuits directly observable as tokens.** We show that: |
| 1. More abstraction vocabulary β richer representations |
| 2. Different tokens map to different arithmetic operations |
| 3. Token identity is causally necessary for correct answers |
| 4. These tokens correspond to Quirke's carry/borrow circuits |
| """) |
|
|
| gr.Markdown("""### 1. More vocabulary β higher accuracy and richer representations |
| |
| Increasing the abstraction vocabulary from 10 to 30 tokens improves accuracy, especially |
| on undersized models where capacity is limited. The 2L/1H/128d model gains +16pp |
| going from abs10 to abs30. The standard model saturates at ~100% regardless of vocab size. |
| """) |
| gr.Image("static_figures/fig_vocab_scaling.png") |
|
|
| gr.Markdown("""With more tokens available, the model uses more of them and distributes usage |
| more evenly. abs10 collapses to 5 tokens; abs30 uses 18 with higher entropy. |
| """) |
| gr.Image("static_figures/fig_diversity.png") |
|
|
| gr.Markdown("""### 2. Tokens map to Quirke's subtasks |
| |
| Within addition alone, tokens form a **carry computation spectrum** β sorted by how often |
| they appear at positions with an active carry. Tokens at the top handle "no carry" positions |
| (SA, SS). Tokens at the bottom are pure carry tokens (UC=87-100%). The model has learned |
| a graded vocabulary from "nothing happening" to "full carry cascade." |
| """) |
| gr.Image("static_figures/fig_addition_hierarchy.png") |
|
|
| gr.Markdown("""The full heatmap across both addition and subtraction:""") |
| gr.Image("static_figures/fig_token_subtask.png") |
|
|
| gr.Markdown("""### 3. Three token vignettes |
| |
| Each vignette introduces one token, explains its role across many examples, then |
| walks through a concrete problem showing it in action. |
| |
| --- |
| |
| **Vignette 1: Token t2 β the carry cascade propagator** |
| |
| Across 2,550 occurrences, t2 appears primarily at **UC (Use Carry, 31%)** and **UB (Use Borrow, 14%)** |
| positions β places where the answer digit depends on a carry or borrow propagating from the right. |
| It appears at multiple answer positions (d1=28%, d3=22%, d0=17%), always where cascade state |
| must be tracked. It's the model's way of saying: *"this digit depends on what happened to the right."* |
| |
| Here it is in a 5-carry cascade (`959271 + 040756 = 1000027`): |
| |
| ``` |
| Position: d0 d1 d2 d3 d4 d5 d6 |
| Answer: 1 0 0 0 0 2 7 |
| Subtask: UC US US US US SC SA |
| Token: t2 t2 t6 t2 t1 β β |
| β β β |
| t2 marks every position that must propagate the carry cascade |
| ``` |
| |
| d5-d6 are easy (SC/SA) β no abstraction needed. d1-d4 are all US (sum-of-9): the carry |
| cascades left through four consecutive uncertain positions. t2 appears at d0, d1, and d3 β |
| every position that needs to resolve the cascade. |
| |
| --- |
| |
| **Vignette 2: Token t7 β the borrow cascade specialist** |
| |
| t7 appears 1,026 times and is **100% subtraction** β never appears in addition problems. |
| It splits between **UB (Use Borrow, 47%)** and **UD (cascade borrow, 43%)**, and is locked |
| to position d2 (95%). When the model assigns t7, it means: *"borrow is propagating through |
| this position in a subtraction."* |
| |
| Here it is in a 4-borrow cascade (`698401 - 128406 = 0569995`): |
| |
| ``` |
| Position: d0 d1 d2 d3 d4 d5 d6 |
| Answer: 0 5 6 9 9 9 5 |
| Subtask: MD MD UB UD UD UD MB |
| Token: t1 t5 t7 t7 t1 β β |
| β β |
| t7 marks the borrow cascade positions |
| ``` |
| |
| d6 triggers the borrow (MB). The borrow cascades left through d3-d5 (UD = equal digits, |
| borrow uncertain). t7 appears at d2-d3 where the cascade must be resolved. Note that t7 |
| never appears in the addition vignette above β the model has learned operation-specific tokens. |
| |
| --- |
| |
| **Vignette 3: Token t3 β the simple addition marker** |
| |
| t3 appears 2,084 times, primarily at **UB (37%)** and **US (27%)** positions, but critically |
| at positions where no carry computation is needed. In easy problems (S0), it marks the |
| "nothing interesting here" positions. |
| |
| Here it is in a no-carry addition (`417080 + 531003 = 0948083`): |
| |
| ``` |
| Position: d0 d1 d2 d3 d4 d5 d6 |
| Answer: 0 9 4 8 0 8 3 |
| Subtask: SA SS SA SA SA SA SA |
| Token: t6 t15 t3 t8 t3 β β |
| β β |
| t3 at simple-add positions (no carry) |
| ``` |
| |
| Every position is SA (simple add) β no cascades at all. t3 appears at d2 and d4 (both SA). |
| Other tokens (t6, t15, t8) handle positions with specific digit-sum values, but none of the |
| carry-cascade tokens (t2, t7) appear anywhere. |
| |
| --- |
| |
| **The key insight:** the same token IDs recur consistently across different problems. |
| t2 = carry cascade. t7 = borrow cascade (subtraction only). t3 = no cascade needed. |
| The model has learned a **vocabulary for arithmetic reasoning** that maps directly to |
| Quirke's circuit definitions β without any supervision about carry logic. |
| |
| ### 4. Surgical intervention: the right token fixes hard-case failures |
| |
| *All experiments below use the 1L/3H/510d model with K=1 abs30, trained on 100K examples. |
| This model gets C3-C6 accuracy ~70% β enough correct examples for comparison, enough errors to fix.* |
| |
| **The two carry tokens: t9 vs t21** |
| |
| The model learned two tokens that both appear at carry-related positions, but with different specializations: |
| |
| | | **t9** (n=573) | **t21** (n=719) | |
| |---|---|---| |
| | Position | d2-d4 (spread) | d3 only (100%) | |
| | Sum = 9 | 56% | **95%** | |
| | Carry rate | 22% | 3% | |
| | Difficulty | **easy** (S0=23%, S2=28%) | **hard** (S5=37%, S6=37%) | |
| | Role | "shallow carry, maybe sum-9" | "deep cascade, definitely sum-9" | |
| |
| t9 is a **shallow/ambiguous** token β it appears in easy problems where carry state doesn't matter much. |
| t21 is a **deep cascade specialist** β it appears specifically in 5-6 carry cascades where |
| every digit's sum is exactly 9 (Quirke's uncertain U state, eq. 2). |
| |
| **The failure mode: using t9 at hard-cascade positions** |
| |
| When the model encounters a hard cascade (C5/C6), it sometimes assigns t9 (the shallow token) |
| instead of t21 (the cascade specialist). This is like using the wrong circuit β the model |
| treats a deep cascade as a shallow one and gets the carry propagation wrong. |
| |
| **The fix: globally replacing t9 with t21** |
| |
| We test what happens when we force the model to always use t21 (the cascade specialist) |
| instead of t9: |
| |
| ``` |
| Normal All t9βt21 Effect |
| C3-C6 (hard carries): 70% 93% +23pp β hard cases dramatically improve |
| S5 (5 cascades): 27% 92% +65pp β nearly fixes the hardest split |
| S0 (no carries): 99% 74% -25pp β easy cases get worse |
| ``` |
| |
| Forcing t21 everywhere is like telling the model "always assume deep cascade" β it fixes |
| hard problems (+65pp on S5!) but hurts easy ones where no cascade exists. The model needs |
| to *correctly choose* between t9 and t21 based on the input, and its main failure mode on |
| hard cases is choosing too conservatively (t9 instead of t21). |
| |
| **Individual transplants confirm this:** |
| |
| ``` |
| 014560 + 125450 = 0140010 |
| Wrong: 0139010 β t9 at d2 (shallow token at carry position) |
| Fixed: 0130010 β transplant correct token β d2 fixed β |
| |
| 332200 + 868010 = 1200210 |
| Wrong: 1190210 β t9 at d1 (shallow token at carry position) |
| Fixed: 1100210 β transplant correct token β d1 fixed β |
| ``` |
| |
| Each transplant fixes the specific digit where the wrong carry token was assigned. |
| This proves **token identity causally determines carry computation** β the wrong |
| token = wrong carry state = wrong answer digit. |
| """) |
|
|
| gr.Markdown("""### 3. Tokens spread across digit positions |
| |
| Unlike K=4 (where tokens are locked to fixed positions), K=1 tokens serve multiple |
| digit positions. This means token identity encodes **what computation to perform**, |
| not just **where in the sequence we are**. |
| """) |
| gr.Image("static_figures/fig_token_positions.png") |
|
|
| gr.Markdown("""### 4. Causal verification: token identity is essential for cascades |
| |
| Three interventions test whether tokens carry real information: |
| - **Knockout**: remove all abstraction tokens β **0% accuracy** (total dependence) |
| - **Shuffle**: randomly permute token IDs within the sequence β accuracy drops |
| - **Random**: replace tokens with random IDs β accuracy drops further |
| |
| The key finding: **the harder the cascade, the more token identity matters.** |
| For S5-S6 (5-6 consecutive carries), shuffling drops accuracy by 56-66 percentage points. |
| This directly parallels Quirke's finding that deeper cascades require more complex |
| internal circuits β SoRL externalizes these circuits as token sequences that must be correct. |
| """) |
| gr.Image("static_figures/fig_causal_ablation.png") |
|
|
| gr.Markdown("""### 5. Correspondence with Quirke's tri-state carry classifier |
| |
| Quirke's eq. 2 defines a **tri-state carry classifier** STn for each digit position: |
| - **STn = 0**: digit sum β€ 8, definitely no carry |
| - **STn = 1**: digit sum β₯ 10, definitely carry |
| - **STn = U**: digit sum = 9, carry depends on cascade from right (uncertain) |
| |
| In our K=4 abs30 model (where tokens are position-locked, forcing sharper specialization): |
| - **Token t3**: maps to SA (simple add) with 0% carry β Quirke's **STn = 0** |
| - **Token t6**: maps to UC with input sum mod 10 = 9 in 92% of cases β Quirke's **STn = U** |
| - **Tokens t8, t9**: map to UC with carry = 100% β Quirke's **STn = 1** |
| |
| The model independently discovered the tri-state classifier from the info-gain loss alone, |
| with no supervision about carry logic. This is the same circuit Quirke found via PCA of |
| hidden activations β but here it is readable directly from the token sequence. |
| |
| *Analysis: K=1 abs30 and K=4 abs30, 2L/3H/510d, 100K training examples, 4400 eval examples.* |
| """) |
|
|
| |
with gr.TabItem("LaTeX"):
    # Read-only code panes holding the module-level LaTeX snippets so they can be
    # copied verbatim into a paper.
    gr.Markdown("### LaTeX Scratchpad — copy sections directly into Overleaf")
    gr.Markdown("Two blocks: **main body** (setup + Finding #1 + forward ref) and **appendix** (Findings #2–5, all tables, training details). Copy each into the appropriate place in your paper.")

    gr.Markdown("#### § Main body — §Arithmetic setup (paste into paper body)")
    # language=None: shown as plain text, no syntax highlighting requested.
    gr.Code(value=LATEX_ARITHMETIC_SETUP, label="arithmetic_setup.tex",
            language=None, interactive=False)

    gr.Markdown("#### § Appendix — full analysis (paste into appendix)")
    gr.Markdown("Includes: per-split ablation table, token profiles, guided computation, Quirke analogy, training details. Requires `tcolorbox`, `booktabs`, `xcolor`, `multirow`, `enumitem`.")
    gr.Code(value=LATEX_APPENDIX, label="appendix_arithmetic.tex",
            language=None, interactive=False)
|
|
| |
with gr.TabItem("About"):
    # Starts empty; it is one of the components listed in all_outputs, so the
    # refresh callback overwrites it with the evaluation info text.
    eval_info_md = gr.Markdown(value="")

    # Static usage notes: how to load a checkpoint, evaluate it, and read out
    # the abstraction tokens for one prompt.
    gr.Markdown(value="""### Using the models

All models are on [HuggingFace](https://huggingface.co/thoughtworks/arithmetic-sorl).
Code is on the [`amir/arithmetic`](https://github.com/fangyuan-ksgk/mod_gpt/tree/amir/arithmetic) branch.

```python
import torch
from arithmetic.data.hub import load_model
from arithmetic.training.evaluate import ArithmeticEvaluator
from transformers import AutoTokenizer

# Load model + tokenizer
model, config, metrics = load_model("add_sub_sorl_v1_abs30_K1_100K", device="cuda")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

# Run full evaluation with per-split accuracy
evaluator = ArithmeticEvaluator(model, tokenizer, device="cuda")
results = evaluator.run(ops="add_sub", K=1, n_per_split=100) # K=None for baseline
evaluator.print_table(results)
```

To inspect abstraction tokens on a single example:

```python
from arithmetic.training.train import QWEN3_TOKEN_MAP, QWEN3_INV_MAP
from sorl.sorl_trainer import infer_insert_mask, insert_tokens_with_padding, expand_prompt_len

base_v = model.vocab_sizes[0].item()

# Encode: 123456+654321=
prompt = [1,2,3,4,5,6, 10, 6,5,4,3,2,1, 12]
qwen_ids = torch.tensor([QWEN3_TOKEN_MAP[t] for t in prompt], device="cuda")

# Pad to full 21 tokens (14 prompt + 7 dummy answer), insert abstractions, recurse
seq = torch.cat([qwen_ids, torch.zeros(7, dtype=torch.long, device="cuda")])
ids = seq.unsqueeze(0)
im = infer_insert_mask(ids, K=1, attention_mask=torch.ones_like(ids))
ep = expand_prompt_len(torch.tensor([14], device="cuda"), im)
ed, ea = insert_tokens_with_padding(ids, torch.ones_like(ids), im, model.vocab_sizes[0], 151643)

data, ppt, logits = model.recursion(ed, ea, max_iterations=2,
memory_span_abs=1792, memory_span_traj=1792, temperature=0.0, prompt_len=ep)

# Separate trajectory vs abstraction tokens
is_abs = data[0] >= base_v
abstractions = data[0][is_abs] - base_v # abstraction token IDs (0-indexed)
print(f"Abstraction tokens: {abstractions.tolist()}")
# Each abstraction token encodes carry/borrow state at that position
```

Token IDs: `0-9` = digits, `10` = `+`, `11` = `-`, `12` = `=`.
Abstraction tokens are integers from 1 to `abs_vocab` (0 is the placeholder before recursion).
""")
|
|
|
|
| |
| |
| |
|
|
def on_refresh(arch):
    """Re-fetch the model list and rebuild every refreshable widget value.

    Args:
        arch: Current value of the architecture filter; forwarded to
            ``build_comparison_table`` as ``arch_filter`` and echoed in the
            summary line.

    Returns:
        An 8-tuple, in the same order as ``all_outputs``: the raw model list,
        the summary line, the queue status text, the main comparison frame,
        the per-split frame, the eval-info markdown, a dropdown update for
        the model selector, and the regenerated summary markdown.
    """
    models = fetch_all_models()
    table = build_comparison_table(models, arch_filter=arch, enriched_only=False)

    model_count = len(models)
    summary = f"**{model_count}** models | Arch filter: {arch}"
    q_status = get_queue_status_text(model_count)
    eval_info = build_eval_info(models)
    new_summary_md = build_summary_markdown()

    # Main table: shown only when every expected column is present; otherwise empty.
    main_cols = ["Ops", "Data", "Arch", "Baseline", "SoRL", "Config",
                 "B_hf", "S_hf", "B_wandb", "S_wandb"]
    absent = [col for col in main_cols if col not in table.columns]
    main_df = pd.DataFrame() if absent else table[main_cols]

    # Per-split table: identifying columns plus every B_*/S_* column that is
    # not a wandb column.
    split_cols = [
        col for col in table.columns
        if col.startswith(("B_", "S_")) and "wandb" not in col
    ]
    if "Arch" in table.columns:
        id_cols = ["Ops", "Data", "Arch", "Config"]
    else:
        id_cols = ["Ops", "Data", "Config"]
    if len(table) > 0:
        hard_df = table[id_cols + split_cols]
    else:
        hard_df = pd.DataFrame()

    # Dropdown choices: subfolder names without the "non_enriched/" prefix, sorted;
    # the first one is preselected, or "" when there are none.
    model_names = sorted(m["subfolder"].removeprefix("non_enriched/") for m in models)
    default_choice = model_names[0] if model_names else ""
    model_dd_update = gr.update(choices=model_names, value=default_choice)

    return (
        models,
        summary,
        q_status,
        main_df,
        hard_df,
        eval_info,
        model_dd_update,
        new_summary_md,
    )
|
|
def on_detail(models, name):
    """Return the per-split detail rows for the model chosen in the selector.

    Args:
        models: Model list held in the dashboard state.
        name: Selected model name; a falsy value is passed on as "" and any
            other value is stripped of surrounding whitespace first.

    Returns:
        Whatever ``build_detailed_splits`` returns for that model name.
    """
    if name:
        wanted = name.strip()
    else:
        wanted = ""
    return build_detailed_splits(models, wanted)
|
|
# Components rewritten by on_refresh, listed in the same order as its return tuple.
all_outputs = [
    models_state,
    summary_text,
    queue_status,
    main_table,
    hard_table,
    eval_info_md,
    model_selector,
    summary_md,
]

# Manual refresh and filter changes both rebuild the whole dashboard.
refresh_btn.click(fn=on_refresh, inputs=[arch_filter], outputs=all_outputs)
arch_filter.change(fn=on_refresh, inputs=[arch_filter], outputs=all_outputs)

# The detail table is driven only by its own button.
detail_btn.click(
    fn=on_detail,
    inputs=[models_state, model_selector],
    outputs=[detail_table],
)

# Periodic refresh (timer value 300) plus one refresh when the page loads.
timer = gr.Timer(300)
timer.tick(fn=on_refresh, inputs=[arch_filter], outputs=all_outputs)
app.load(fn=on_refresh, inputs=[arch_filter], outputs=all_outputs)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve on all interfaces at port 7860 with SSR turned off
    # and the Soft theme applied.
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False,
        theme=gr.themes.Soft(),
    )
|
|