|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib.colors as mcolors |
|
|
from PIL import Image |
|
|
import io |
|
|
import time |
|
|
|
|
|
|
|
|
# Color palette for grid rendering: a cell holding the integer k is painted
# with ARC_COLORS[k], so the list order is significant.
ARC_COLORS = [
    "#000000",  # 0: black
    "#0074D9",  # 1: blue
    "#FF4136",  # 2: red
    "#2ECC40",  # 3: green
    "#FFDC00",  # 4: yellow
    "#AAAAAA",  # 5: gray
    "#F012BE",  # 6: magenta
    "#FF851B",  # 7: orange
    "#7FDBFF",  # 8: light blue
    "#B10DC9",  # 9: purple
]
|
|
|
|
|
|
|
|
# Built-in demo puzzles keyed by display name (insertion order is the order
# shown in the dropdown). Each entry holds:
#   "input"  - the starting 5x5 grid of color indices,
#   "output" - the 5x5 grid the animation ends on,
#   "steps"  - how many animation steps separate the two.
SAMPLE_PUZZLES = {
    "Pattern Fill": {
        "input": [
            [0, 0, 0, 0, 0],
            [0, 1, 0, 1, 0],
            [0, 0, 0, 0, 0],
            [0, 1, 0, 1, 0],
            [0, 0, 0, 0, 0],
        ],
        # Uniform grid of 1s; the comprehension keeps every row a distinct list.
        "output": [[1] * 5 for _ in range(5)],
        "steps": 8,
    },
    "Color Spread": {
        "input": [
            [0, 0, 0, 0, 0],
            [0, 0, 2, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
        ],
        # Uniform grid of 2s.
        "output": [[2] * 5 for _ in range(5)],
        "steps": 10,
    },
    "Mirror Pattern": {
        "input": [
            [0, 0, 3, 0, 0],
            [0, 3, 0, 0, 0],
            [3, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
        ],
        "output": [
            [0, 0, 3, 0, 0],
            [0, 3, 0, 3, 0],
            [3, 0, 0, 0, 3],
            [0, 3, 0, 3, 0],
            [0, 0, 3, 0, 0],
        ],
        "steps": 6,
    },
}
|
|
|
|
|
|
|
|
def grid_to_image(grid, cell_size=40):
    """Render a 2-D grid of color indices as a PIL image.

    Args:
        grid: 2-D array-like (list of equal-length rows or ndarray) of integer
            color indices; values 0-9 map onto ARC_COLORS.
        cell_size: Nominal size of one cell in pixels. The figure is created
            at ``w * cell_size`` by ``h * cell_size`` pixels (100 dpi) before
            the tight bounding-box crop and padding applied on save.

    Returns:
        A ``PIL.Image.Image`` opened from an in-memory PNG of the grid, with
        white separator lines drawn between cells.

    Raises:
        ValueError: If ``grid`` is not two-dimensional.
    """
    cells = np.array(grid)
    if cells.ndim != 2:
        raise ValueError(
            f"grid must be 2-dimensional, got array with shape {cells.shape}"
        )
    h, w = cells.shape

    fig, ax = plt.subplots(
        figsize=(w * cell_size / 100, h * cell_size / 100), dpi=100
    )
    try:
        # Pin vmin/vmax so a value always maps to the same palette entry,
        # regardless of which values happen to be present in this grid.
        cmap = mcolors.ListedColormap(ARC_COLORS)
        ax.imshow(cells, cmap=cmap, vmin=0, vmax=9)

        # Cell centers sit on integer coordinates, so borders are at k - 0.5.
        for i in range(h + 1):
            ax.axhline(i - 0.5, color='white', linewidth=1)
        for j in range(w + 1):
            ax.axvline(j - 0.5, color='white', linewidth=1)

        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect('equal')

        # Work on `fig` explicitly instead of pyplot's implicit "current
        # figure", which is global state and may not be this figure if
        # another one is created in the meantime.
        fig.tight_layout(pad=0)

        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
    finally:
        # Always release the figure, even if rendering fails, so pyplot does
        # not accumulate open figures across calls.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
|
|
|
|
|
|
|
|
def interpolate_grids(start, end, progress):
    """Return an intermediate grid between ``start`` and ``end``.

    The cells where the two grids differ are visited in row-major order and
    the first ``floor(n_differences * progress)`` of them are switched to
    their ``end`` value; all other cells keep their ``start`` value.

    Args:
        start: Array-like grid representing the initial state.
        end: Array-like grid of the same shape representing the final state.
        progress: Fraction of differing cells to switch. Values below 0 are
            treated as 0 and values above 1 as 1.

    Returns:
        The intermediate grid as nested Python lists. If the grids are
        already identical, ``end`` is returned (as nested lists).

    Raises:
        ValueError: If ``start`` and ``end`` do not have the same shape.
    """
    start_arr = np.array(start)
    end_arr = np.array(end)
    if start_arr.shape != end_arr.shape:
        raise ValueError(
            f"start and end must have the same shape, "
            f"got {start_arr.shape} and {end_arr.shape}"
        )

    # Coordinates of every differing cell, in row-major order.
    changed = np.argwhere(start_arr != end_arr)
    total = len(changed)
    if total == 0:
        return end_arr.tolist()

    fraction = min(max(progress, 0.0), 1.0)
    # The small epsilon keeps products that are mathematically whole but land
    # a hair below the integer in floating point (e.g. 0.29 * 100) from being
    # truncated to one change too few.
    count = int(total * fraction + 1e-9)

    result = start_arr.copy()
    # Transposing the (count, ndim) coordinate list gives one index array per
    # axis, so all selected cells are copied in a single assignment.
    selected = tuple(changed[:count].T)
    result[selected] = end_arr[selected]
    return result.tolist()
|
|
|
|
|
|
|
|
def simulate_reasoning(puzzle_name, progress=gr.Progress()):
    """Animate a puzzle's transition from its input grid to its output grid.

    This is a generator used as a streaming Gradio handler. No model is run:
    each frame is produced by interpolating between the stored input and
    output grids of the selected puzzle.

    Args:
        puzzle_name: Key into SAMPLE_PUZZLES selecting the puzzle to animate.
        progress: Gradio progress tracker. The ``gr.Progress()`` default is
            the form Gradio expects in a handler signature in order to
            inject its tracker.

    Yields:
        ``(image, status)`` tuples: the rendered grid for the current step
        and a human-readable status line. If ``puzzle_name`` is unknown, a
        single ``(None, "Please select a puzzle")`` is yielded instead.
    """
    if puzzle_name not in SAMPLE_PUZZLES:
        yield None, "Please select a puzzle"
        return

    puzzle = SAMPLE_PUZZLES[puzzle_name]
    input_grid = puzzle["input"]
    output_grid = puzzle["output"]
    # A step count below 1 would divide by zero when computing the per-step
    # fraction, so fall back to a single input -> output transition.
    num_steps = max(1, int(puzzle["steps"]))

    progress(0, desc="Initializing model...")
    time.sleep(0.3)

    for step in range(num_steps + 1):
        step_progress = step / num_steps

        current_grid = interpolate_grids(input_grid, output_grid, step_progress)
        img = grid_to_image(current_grid)

        if step == 0:
            status = f"🧠 Step {step}/{num_steps}: Reading input puzzle..."
        elif step == num_steps:
            status = f"✅ Step {step}/{num_steps}: Solution found!"
        else:
            status = f"🔄 Step {step}/{num_steps}: Refining hypothesis (latent z update)..."

        progress(step_progress, desc=status)
        yield img, status

        # Pause between frames so the animation is visible; there is nothing
        # left to show after the final frame, so do not delay completion.
        if step < num_steps:
            time.sleep(0.4)
|
|
|
|
|
|
|
|
def load_puzzle(puzzle_name):
    """Render the input and expected-output grids of the chosen puzzle.

    Args:
        puzzle_name: Key into SAMPLE_PUZZLES.

    Returns:
        A 3-tuple ``(input_image, output_image, status_message)``. When
        ``puzzle_name`` is not a known puzzle, both images are ``None`` and
        the message prompts the user to select one.
    """
    entry = SAMPLE_PUZZLES.get(puzzle_name)
    if entry is None:
        return None, None, "Select a puzzle to begin"

    # Render in the same order the values are returned: input, then output.
    rendered_input, rendered_output = (
        grid_to_image(entry[key]) for key in ("input", "output")
    )
    return rendered_input, rendered_output, f"Puzzle loaded: {puzzle_name}"
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI definition: a puzzle picker with input/target previews on the
# left, and the animated "current hypothesis" view with its status line on
# the right, followed by an explanatory footer and the event wiring.
with gr.Blocks(
    title="TinyThink: Glass-Box Reasoning",
    theme=gr.themes.Soft(primary_hue="purple", secondary_hue="blue"),
    css="""
    .header { text-align: center; margin-bottom: 20px; }
    .status-box { font-size: 1.2em; padding: 10px; border-radius: 8px; }
    """
) as demo:

    gr.HTML("""
    <div class="header">
        <h1>🧠 TinyThink: Glass-Box Recursive Reasoning</h1>
        <p style="font-size: 1.2em; color: #666;">
            Watch a <strong>7M parameter</strong> model solve ARC-AGI puzzles by "thinking" recursively
        </p>
        <p style="font-size: 0.9em; color: #888;">
            Based on "Less is More: Recursive Reasoning with Tiny Networks" (Samsung SAIL Montreal, 2025)
        </p>
    </div>
    """)

    with gr.Row():
        # Left column: puzzle selection plus static input/target previews.
        with gr.Column(scale=1):
            gr.Markdown("### 📋 Select a Puzzle")
            puzzle_dropdown = gr.Dropdown(
                choices=list(SAMPLE_PUZZLES.keys()),
                label="Choose an ARC puzzle",
                value="Pattern Fill"
            )
            load_btn = gr.Button("Load Puzzle", variant="secondary")

            gr.Markdown("### 📥 Input Grid")
            input_display = gr.Image(label="Input", type="pil", height=250)

            gr.Markdown("### 🎯 Expected Output")
            expected_output = gr.Image(label="Target", type="pil", height=250)

        # Right column: the animated grid and status streamed by
        # simulate_reasoning.
        with gr.Column(scale=2):
            gr.Markdown("### 🔄 Live Reasoning State")
            gr.Markdown("*Watch the model iterate through its recursive reasoning loop*")

            output_display = gr.Image(label="Current Hypothesis", type="pil", height=400)
            status_text = gr.Textbox(
                label="Reasoning Status",
                value="Select a puzzle and click 'Start Reasoning' to begin",
                interactive=False,
                elem_classes=["status-box"]
            )

            solve_btn = gr.Button("🚀 Start Reasoning Loop", variant="primary", size="lg")

    gr.Markdown("""
    ---
    ### 📖 How TinyThink Works

    The Tiny Recursive Model (TRM) uses a fundamentally different approach than large language models:

    1. **Input Encoding**: The puzzle grid is embedded as tokens
    2. **Recursive Loop**: For N steps, the model updates its latent state `z` given (input, current_answer, current_z)
    3. **Answer Refinement**: After reasoning, the model updates its answer `y` based on the refined latent
    4. **Repeat**: This process repeats K times (typically 16), with each iteration improving the answer

    The key insight is that **depth of reasoning** (recursive iterations) can compensate for **model size**.
    A 7M parameter model thinking for 16 steps outperforms much larger models that only do single-pass inference.

    ---
    *⚠️ This is a visualization demo. The full model requires GPU resources.*
    *See the [GitHub repo](https://github.com/SamsungSAILMontreal/TinyRecursiveModels) for the actual implementation.*
    """)

    # Both pressing "Load Puzzle" and changing the dropdown selection refresh
    # the input/target previews and the status line in the same way.
    for load_trigger in (load_btn.click, puzzle_dropdown.change):
        load_trigger(
            load_puzzle,
            inputs=[puzzle_dropdown],
            outputs=[input_display, expected_output, status_text]
        )

    # simulate_reasoning is a generator, so each yielded (image, status) pair
    # updates the hypothesis view and status line as it is produced.
    solve_btn.click(
        simulate_reasoning,
        inputs=[puzzle_dropdown],
        outputs=[output_display, status_text]
    )


if __name__ == "__main__":
    demo.launch()