# TinyThink / app.py
# raayraay's picture
# Create app.py
# 5f93d02 verified
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from PIL import Image
import io
import time
# ARC-AGI color palette: ten colors indexed by cell value 0-9. Index 0 (black)
# doubles as the background. The list position IS the cell value, so the order
# of the entries must not change.
ARC_COLORS = [
    '#000000',  # 0: black (background)
    '#0074D9',  # 1: blue
    '#FF4136',  # 2: red
    '#2ECC40',  # 3: green
    '#FFDC00',  # 4: yellow
    '#AAAAAA',  # 5: grey
    '#F012BE',  # 6: magenta
    '#FF851B',  # 7: orange
    '#7FDBFF',  # 8: cyan
    '#870C25',  # 9: maroon (ARC reference value; the label and hex now agree)
]
# Built-in demo puzzles, keyed by the name shown in the puzzle dropdown.
# Each entry holds an "input" grid, the "output" grid it is animated towards,
# and "steps", the number of simulated reasoning iterations for the animation.
def _solid_grid(color, side=5):
    """Return a side x side grid with every cell set to color."""
    return [[color] * side for _ in range(side)]


def _grid_with(cells, color, side=5):
    """Return a side x side grid of zeros with the (row, col) cells set to color."""
    grid = _solid_grid(0, side)
    for row, col in cells:
        grid[row][col] = color
    return grid


SAMPLE_PUZZLES = {
    "Pattern Fill": {
        # Four blue dots on black -> completely blue grid.
        "input": _grid_with([(1, 1), (1, 3), (3, 1), (3, 3)], color=1),
        "output": _solid_grid(1),
        "steps": 8,
    },
    "Color Spread": {
        # One red seed cell -> completely red grid.
        "input": _grid_with([(1, 2)], color=2),
        "output": _solid_grid(2),
        "steps": 10,
    },
    "Mirror Pattern": {
        # Upper-left green diagonal -> full diamond outline.
        "input": _grid_with([(0, 2), (1, 1), (2, 0)], color=3),
        "output": _grid_with(
            [(0, 2), (1, 1), (1, 3), (2, 0), (2, 4), (3, 1), (3, 3), (4, 2)],
            color=3,
        ),
        "steps": 6,
    },
}
def grid_to_image(grid, cell_size=40):
    """Render a 2-D grid of color indices as a PIL image.

    Each cell value is drawn with the matching entry of ARC_COLORS (the color
    scale is fixed to the range 0-9), and white lines separate the cells.

    Args:
        grid: Rectangular 2-D sequence (or array) of integer color indices.
        cell_size: Approximate edge length of one cell in pixels, before the
            tight bounding box and padding applied when saving.

    Returns:
        A PIL image decoded from the in-memory PNG rendering.

    Raises:
        ValueError: If grid is not a non-empty 2-D array.
    """
    grid = np.array(grid)
    if grid.ndim != 2 or grid.size == 0:
        raise ValueError(
            f"grid must be a non-empty 2-D array, got shape {grid.shape}"
        )
    h, w = grid.shape

    # The figure is created at 100 dpi, so inches = pixels / 100.
    fig, ax = plt.subplots(
        figsize=(w * cell_size / 100, h * cell_size / 100), dpi=100
    )
    try:
        cmap = mcolors.ListedColormap(ARC_COLORS)
        ax.imshow(grid, cmap=cmap, vmin=0, vmax=9)

        # Cell centres sit on integer coordinates, so borders are at i - 0.5.
        for i in range(h + 1):
            ax.axhline(i - 0.5, color='white', linewidth=1)
        for j in range(w + 1):
            ax.axvline(j - 0.5, color='white', linewidth=1)

        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect('equal')

        # Operate on this figure explicitly rather than pyplot's implicit
        # "current figure", which is shared global state.
        fig.tight_layout(pad=0)
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
    finally:
        # Always release the figure, even if rendering or saving failed;
        # pyplot otherwise keeps it registered and its memory alive.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
def interpolate_grids(start, end, progress):
    """Blend two grids by copying a leading share of the differing cells.

    The cells where start and end disagree are visited in row-major order
    (top-left to bottom-right). The first int(count * progress) of them take
    their value from end; every other cell keeps its value from start.

    Args:
        start: Grid the transition begins from.
        end: Grid the transition finishes at.
        progress: Fraction of the differing cells to switch over, nominally
            between 0 and 1.

    Returns:
        The intermediate grid as nested lists.
    """
    source = np.array(start)
    target = np.array(end)

    # Coordinates of every cell that still has to change, in row-major order.
    pending = np.argwhere(source != target)
    total = len(pending)
    if total == 0:
        return target.tolist()

    # Truncate towards zero; a negative quota means "change nothing".
    quota = max(int(total * progress), 0)

    snapshot = source.copy()
    # One index array per axis, so all selected cells are written at once.
    picked = tuple(pending[:quota].T)
    snapshot[picked] = target[picked]
    return snapshot.tolist()
def simulate_reasoning(puzzle_name, progress=gr.Progress()):
    """Animate a simulated reasoning run for one of the sample puzzles.

    This is a generator meant to be used as a Gradio event handler. For each
    step it renders an intermediate grid between the puzzle's input and
    output and yields it together with a status line, pausing briefly
    between frames so the transition is visible.

    Args:
        puzzle_name: Key into SAMPLE_PUZZLES.
        progress: Gradio progress tracker; the gr.Progress() default is how
            Gradio is told to supply one to the handler.

    Yields:
        (image, status) tuples. A single (None, "Please select a puzzle")
        tuple is yielded when puzzle_name is not a known puzzle.
    """
    if puzzle_name not in SAMPLE_PUZZLES:
        yield None, "Please select a puzzle"
        return

    puzzle = SAMPLE_PUZZLES[puzzle_name]
    input_grid = puzzle["input"]
    output_grid = puzzle["output"]
    # At least one step, so the progress fraction below never divides by zero.
    num_steps = max(1, puzzle["steps"])

    progress(0, desc="Initializing model...")
    time.sleep(0.3)

    # Steps 0..num_steps inclusive: step 0 shows the raw input, the last step
    # shows the finished output.
    for step in range(num_steps + 1):
        step_progress = step / num_steps

        current_grid = interpolate_grids(input_grid, output_grid, step_progress)
        img = grid_to_image(current_grid)

        if step == 0:
            status = f"🧠 Step {step}/{num_steps}: Reading input puzzle..."
        elif step == num_steps:
            status = f"✅ Step {step}/{num_steps}: Solution found!"
        else:
            status = f"🔄 Step {step}/{num_steps}: Refining hypothesis (latent z update)..."

        progress(step_progress, desc=status)
        yield img, status

        # Pause between frames for the visual effect. Nothing follows the
        # final frame, so sleeping there would only delay completion.
        if step < num_steps:
            time.sleep(0.4)
def load_puzzle(puzzle_name):
    """Render the input and expected-output grids of a sample puzzle.

    Args:
        puzzle_name: Key into SAMPLE_PUZZLES.

    Returns:
        (input_image, output_image, status_message). Both images are None and
        the message prompts for a selection when the name is not a known
        puzzle.
    """
    selected = SAMPLE_PUZZLES.get(puzzle_name)
    if selected is None:
        return None, None, "Select a puzzle to begin"

    rendered_input, rendered_output = (
        grid_to_image(selected[side]) for side in ("input", "output")
    )
    return rendered_input, rendered_output, f"Puzzle loaded: {puzzle_name}"
# Build the Gradio interface.
# Layout: a header, then a row with a narrow left column (puzzle picker, input
# grid, expected output) and a wide right column (animated reasoning state,
# status line, start button), followed by an explanatory section.
# NOTE: the multi-line strings are kept flush-left on purpose so the text
# handed to Gradio carries no leading indentation.
with gr.Blocks(
    title="TinyThink: Glass-Box Reasoning",
    theme=gr.themes.Soft(primary_hue="purple", secondary_hue="blue"),
    css="""
.header { text-align: center; margin-bottom: 20px; }
.status-box { font-size: 1.2em; padding: 10px; border-radius: 8px; }
"""
) as demo:
    gr.HTML("""
<div class="header">
<h1>🧠 TinyThink: Glass-Box Recursive Reasoning</h1>
<p style="font-size: 1.2em; color: #666;">
Watch a <strong>7M parameter</strong> model solve ARC-AGI puzzles by "thinking" recursively
</p>
<p style="font-size: 0.9em; color: #888;">
Based on "Less is More: Recursive Reasoning with Tiny Networks" (Samsung SAIL Montreal, 2025)
</p>
</div>
""")

    with gr.Row():
        # Left column: puzzle selection plus static input / target previews.
        with gr.Column(scale=1):
            gr.Markdown("### 📋 Select a Puzzle")
            puzzle_dropdown = gr.Dropdown(
                choices=list(SAMPLE_PUZZLES.keys()),
                label="Choose an ARC puzzle",
                value="Pattern Fill"
            )
            load_btn = gr.Button("Load Puzzle", variant="secondary")
            gr.Markdown("### 📥 Input Grid")
            input_display = gr.Image(label="Input", type="pil", height=250)
            gr.Markdown("### 🎯 Expected Output")
            expected_output = gr.Image(label="Target", type="pil", height=250)

        # Right column: the animated hypothesis and its status line.
        with gr.Column(scale=2):
            gr.Markdown("### 🔄 Live Reasoning State")
            gr.Markdown("*Watch the model iterate through its recursive reasoning loop*")
            output_display = gr.Image(label="Current Hypothesis", type="pil", height=400)
            status_text = gr.Textbox(
                label="Reasoning Status",
                value="Select a puzzle and click 'Start Reasoning' to begin",
                interactive=False,
                elem_classes=["status-box"]
            )
            solve_btn = gr.Button("🚀 Start Reasoning Loop", variant="primary", size="lg")

    # Blank lines between the Markdown elements matter: without them a "---"
    # directly under a paragraph turns that paragraph into a heading, and text
    # directly under a list item is absorbed into that item.
    gr.Markdown("""
---

### 📖 How TinyThink Works

The Tiny Recursive Model (TRM) uses a fundamentally different approach than large language models:

1. **Input Encoding**: The puzzle grid is embedded as tokens
2. **Recursive Loop**: For N steps, the model updates its latent state `z` given (input, current_answer, current_z)
3. **Answer Refinement**: After reasoning, the model updates its answer `y` based on the refined latent
4. **Repeat**: This process repeats K times (typically 16), with each iteration improving the answer

The key insight is that **depth of reasoning** (recursive iterations) can compensate for **model size**.
A 7M parameter model thinking for 16 steps outperforms much larger models that only do single-pass inference.

---

*⚠️ This is a visualization demo. The full model requires GPU resources.*
*See the [GitHub repo](https://github.com/SamsungSAILMontreal/TinyRecursiveModels) for the actual implementation.*
""")

    # Event handlers
    # Both the explicit button and a dropdown change refresh the previews.
    load_btn.click(
        load_puzzle,
        inputs=[puzzle_dropdown],
        outputs=[input_display, expected_output, status_text]
    )
    puzzle_dropdown.change(
        load_puzzle,
        inputs=[puzzle_dropdown],
        outputs=[input_display, expected_output, status_text]
    )
    # simulate_reasoning is a generator, so each yielded (image, status) pair
    # is streamed to the two outputs as the animation advances.
    solve_btn.click(
        simulate_reasoning,
        inputs=[puzzle_dropdown],
        outputs=[output_display, status_text]
    )

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()