import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from PIL import Image
import io
import time
# ARC-AGI color palette: 10 entries, indexed by cell value 0-9
# (index 0 is the black background).
# NOTE(review): entry 9 is labelled "maroon" but '#B10DC9' is a purple hue -
# confirm whether the label or the hex value is the intended one.
ARC_COLORS = [
    '#000000',  # 0: black (background)
    '#0074D9',  # 1: blue
    '#FF4136',  # 2: red
    '#2ECC40',  # 3: green
    '#FFDC00',  # 4: yellow
    '#AAAAAA',  # 5: grey
    '#F012BE',  # 6: magenta
    '#FF851B',  # 7: orange
    '#7FDBFF',  # 8: cyan
    '#B10DC9',  # 9: maroon
]
# Sample ARC puzzles for demonstration.
# Each entry maps a display name (used as a dropdown choice) to:
#   "input"  - the starting grid, as rows of color indices into ARC_COLORS
#   "output" - the target grid, with the same dimensions as "input"
#   "steps"  - how many simulated reasoning steps the animation runs
SAMPLE_PUZZLES = {
    # Four isolated blue cells expand to fill the whole grid.
    "Pattern Fill": {
        "input": [
            [0, 0, 0, 0, 0],
            [0, 1, 0, 1, 0],
            [0, 0, 0, 0, 0],
            [0, 1, 0, 1, 0],
            [0, 0, 0, 0, 0],
        ],
        "output": [
            [1, 1, 1, 1, 1],
            [1, 1, 1, 1, 1],
            [1, 1, 1, 1, 1],
            [1, 1, 1, 1, 1],
            [1, 1, 1, 1, 1],
        ],
        "steps": 8,
    },
    # A single red cell spreads to every cell.
    "Color Spread": {
        "input": [
            [0, 0, 0, 0, 0],
            [0, 0, 2, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
        ],
        "output": [
            [2, 2, 2, 2, 2],
            [2, 2, 2, 2, 2],
            [2, 2, 2, 2, 2],
            [2, 2, 2, 2, 2],
            [2, 2, 2, 2, 2],
        ],
        "steps": 10,
    },
    # A green diagonal in the top-left is mirrored into a diamond outline.
    "Mirror Pattern": {
        "input": [
            [0, 0, 3, 0, 0],
            [0, 3, 0, 0, 0],
            [3, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
        ],
        "output": [
            [0, 0, 3, 0, 0],
            [0, 3, 0, 3, 0],
            [3, 0, 0, 0, 3],
            [0, 3, 0, 3, 0],
            [0, 0, 3, 0, 0],
        ],
        "steps": 6,
    },
}
def grid_to_image(grid, cell_size=40):
    """Convert a grid to a colored image.

    Args:
        grid: 2-D sequence (or array) of integer color indices; each value
            is mapped onto ``ARC_COLORS`` over the fixed range 0-9.
        cell_size: Nominal edge length of one cell, in pixels, used to size
            the figure before the tight bounding box is applied.

    Returns:
        A PIL image opened from an in-memory PNG rendering of the grid.
    """
    # Local import: the module's top level only brings in pyplot.
    from matplotlib.figure import Figure

    grid = np.array(grid)
    h, w = grid.shape
    dpi = 100

    # Build the figure directly instead of through pyplot. pyplot keeps a
    # process-wide registry and a "current figure"; when this function runs
    # for several requests at once, `plt.savefig` could save another
    # request's figure, and an exception before `plt.close` leaked the
    # figure. A standalone Figure has neither problem and needs no close().
    fig = Figure(figsize=(w * cell_size / dpi, h * cell_size / dpi), dpi=dpi)
    ax = fig.subplots()

    cmap = mcolors.ListedColormap(ARC_COLORS)
    # vmin/vmax pin the value-to-color mapping so it does not depend on
    # which values happen to appear in this particular grid.
    ax.imshow(grid, cmap=cmap, vmin=0, vmax=9)

    # Add grid lines on the cell boundaries (cells are centered on integers).
    for i in range(h + 1):
        ax.axhline(i - 0.5, color='white', linewidth=1)
    for j in range(w + 1):
        ax.axvline(j - 0.5, color='white', linewidth=1)

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')
    fig.tight_layout(pad=0)

    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
    buf.seek(0)
    return Image.open(buf)
def interpolate_grids(start, end, progress):
    """Return ``start`` with a leading share of its differing cells set to ``end``.

    Args:
        start: Grid to begin from (nested lists or array).
        end: Grid to move towards; compared element-wise with ``start``.
        progress: Fraction of the differing cells to copy over from ``end``.
            The count is truncated to an integer and capped at the number
            of differing cells.

    Returns:
        The blended grid as nested lists. When the grids are already equal,
        ``end`` is returned as nested lists.
    """
    source = np.array(start)
    target = np.array(end)

    differs = source != target
    total = differs.sum()
    if total == 0:
        # Nothing to blend: every progress value yields the target grid.
        return target.tolist()

    quota = int(total * progress)

    # argwhere lists coordinates in row-major order, so cells are switched
    # from the top-left towards the bottom-right.
    coords = np.argwhere(differs)
    # A negative quota must select nothing (not slice from the end).
    count = max(min(quota, len(coords)), 0)

    blended = source.copy()
    for coord in map(tuple, coords[:count]):
        blended[coord] = target[coord]
    return blended.tolist()
def simulate_reasoning(puzzle_name, progress=gr.Progress()):
    """Simulate the recursive reasoning process with visualization.

    Generator that animates the transition from a puzzle's input grid to
    its output grid, one frame per simulated step.

    Args:
        puzzle_name: Key into ``SAMPLE_PUZZLES``.
        progress: Progress tracker; called with a completion fraction and a
            description before each frame is yielded.

    Yields:
        ``(image, status)`` pairs: the rendered intermediate grid and a
        status line. For an unknown ``puzzle_name`` a single
        ``(None, "Please select a puzzle")`` pair is yielded.
    """
    if puzzle_name not in SAMPLE_PUZZLES:
        yield None, "Please select a puzzle"
        return

    puzzle = SAMPLE_PUZZLES[puzzle_name]
    input_grid = puzzle["input"]
    output_grid = puzzle["output"]
    num_steps = puzzle["steps"]

    progress(0, desc="Initializing model...")
    time.sleep(0.3)

    # Simulate recursive reasoning steps; step 0 shows the untouched input
    # and step num_steps shows the full solution.
    for step in range(num_steps + 1):
        step_progress = step / num_steps

        # Create intermediate state and render it.
        current_grid = interpolate_grids(input_grid, output_grid, step_progress)
        img = grid_to_image(current_grid)

        # NOTE(review): the status emoji were restored from mis-decoded
        # text; the one on the "Refining" line was chosen from context -
        # confirm against the deployed app.
        if step == 0:
            status = f"🧠 Step {step}/{num_steps}: Reading input puzzle..."
        elif step == num_steps:
            status = f"✅ Step {step}/{num_steps}: Solution found!"
        else:
            status = f"🔄 Step {step}/{num_steps}: Refining hypothesis (latent z update)..."

        progress(step_progress, desc=status)
        yield img, status

        # Pause between frames for the visualization effect; there is no
        # frame after the last one, so do not keep the caller waiting.
        if step < num_steps:
            time.sleep(0.4)
def load_puzzle(puzzle_name):
    """Render the selected puzzle's input and target grids.

    Returns a ``(input_image, output_image, message)`` triple. For a name
    that is not in ``SAMPLE_PUZZLES`` both images are ``None`` and the
    message asks the user to pick a puzzle.
    """
    if puzzle_name in SAMPLE_PUZZLES:
        selected = SAMPLE_PUZZLES[puzzle_name]
        # Render the input first, then the expected output.
        input_img, output_img = (
            grid_to_image(selected[key]) for key in ("input", "output")
        )
        return input_img, output_img, f"Puzzle loaded: {puzzle_name}"

    return None, None, "Select a puzzle to begin"
# Build the Gradio interface.
# Layout: a header, then a row with a narrow left column (puzzle picker,
# input grid, expected output) and a wider right column (animated
# hypothesis, status line, start button), then an explanatory section.
# NOTE(review): emoji in the labels below were restored from mis-decoded
# text. Where the original bytes were unrecoverable (Select a Puzzle, Live
# Reasoning State, Start Reasoning Loop, How TinyThink Works) the emoji was
# chosen from context - confirm against the deployed app.
with gr.Blocks(
    title="TinyThink: Glass-Box Reasoning",
    theme=gr.themes.Soft(primary_hue="purple", secondary_hue="blue"),
    css="""
    .header { text-align: center; margin-bottom: 20px; }
    .status-box { font-size: 1.2em; padding: 10px; border-radius: 8px; }
    """,
) as demo:
    gr.HTML("""
    <div class="header">
        <h1>🧠 TinyThink: Glass-Box Recursive Reasoning</h1>
        <p style="font-size: 1.2em; color: #666;">
            Watch a <strong>7M parameter</strong> model solve ARC-AGI puzzles by "thinking" recursively
        </p>
        <p style="font-size: 0.9em; color: #888;">
            Based on "Less is More: Recursive Reasoning with Tiny Networks" (Samsung SAIL Montreal, 2025)
        </p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📋 Select a Puzzle")
            puzzle_dropdown = gr.Dropdown(
                choices=list(SAMPLE_PUZZLES.keys()),
                label="Choose an ARC puzzle",
                value="Pattern Fill",
            )
            load_btn = gr.Button("Load Puzzle", variant="secondary")
            gr.Markdown("### 📥 Input Grid")
            input_display = gr.Image(label="Input", type="pil", height=250)
            gr.Markdown("### 🎯 Expected Output")
            expected_output = gr.Image(label="Target", type="pil", height=250)

        with gr.Column(scale=2):
            gr.Markdown("### 🔍 Live Reasoning State")
            gr.Markdown("*Watch the model iterate through its recursive reasoning loop*")
            output_display = gr.Image(label="Current Hypothesis", type="pil", height=400)
            status_text = gr.Textbox(
                label="Reasoning Status",
                value="Select a puzzle and click 'Start Reasoning' to begin",
                interactive=False,
                elem_classes=["status-box"],
            )
            solve_btn = gr.Button("🚀 Start Reasoning Loop", variant="primary", size="lg")

    # Blank lines between the Markdown blocks matter: a `---` directly under
    # a line of text is read as a heading underline rather than a rule, and
    # a paragraph directly under a list item is folded into that item.
    gr.Markdown("""
    ---

    ### 📖 How TinyThink Works

    The Tiny Recursive Model (TRM) uses a fundamentally different approach than large language models:

    1. **Input Encoding**: The puzzle grid is embedded as tokens
    2. **Recursive Loop**: For N steps, the model updates its latent state `z` given (input, current_answer, current_z)
    3. **Answer Refinement**: After reasoning, the model updates its answer `y` based on the refined latent
    4. **Repeat**: This process repeats K times (typically 16), with each iteration improving the answer

    The key insight is that **depth of reasoning** (recursive iterations) can compensate for **model size**.
    A 7M parameter model thinking for 16 steps outperforms much larger models that only do single-pass inference.

    ---

    *⚠️ This is a visualization demo. The full model requires GPU resources.*
    *See the [GitHub repo](https://github.com/SamsungSAILMontreal/TinyRecursiveModels) for the actual implementation.*
    """)

    # Event handlers
    puzzle_outputs = [input_display, expected_output, status_text]

    load_btn.click(
        load_puzzle,
        inputs=[puzzle_dropdown],
        outputs=puzzle_outputs,
    )
    puzzle_dropdown.change(
        load_puzzle,
        inputs=[puzzle_dropdown],
        outputs=puzzle_outputs,
    )
    # The dropdown starts with a puzzle already selected, so render it when
    # the page opens instead of leaving both preview images empty until the
    # user clicks "Load Puzzle" or changes the selection.
    demo.load(
        load_puzzle,
        inputs=[puzzle_dropdown],
        outputs=puzzle_outputs,
    )
    solve_btn.click(
        simulate_reasoning,
        inputs=[puzzle_dropdown],
        outputs=[output_display, status_text],
    )
# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()