raayraay commited on
Commit
5f93d02
Β·
verified Β·
1 Parent(s): 1713838

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +277 -0
app.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import matplotlib.colors as mcolors
5
+ from PIL import Image
6
+ import io
7
+ import time
8
+
9
+ # ARC-AGI color palette (10 colors + background)
10
+ ARC_COLORS = [
11
+ '#000000', # 0: black (background)
12
+ '#0074D9', # 1: blue
13
+ '#FF4136', # 2: red
14
+ '#2ECC40', # 3: green
15
+ '#FFDC00', # 4: yellow
16
+ '#AAAAAA', # 5: grey
17
+ '#F012BE', # 6: magenta
18
+ '#FF851B', # 7: orange
19
+ '#7FDBFF', # 8: cyan
20
+ '#B10DC9', # 9: maroon
21
+ ]
22
+
23
+ # Sample ARC puzzles for demonstration
24
+ SAMPLE_PUZZLES = {
25
+ "Pattern Fill": {
26
+ "input": [
27
+ [0, 0, 0, 0, 0],
28
+ [0, 1, 0, 1, 0],
29
+ [0, 0, 0, 0, 0],
30
+ [0, 1, 0, 1, 0],
31
+ [0, 0, 0, 0, 0],
32
+ ],
33
+ "output": [
34
+ [1, 1, 1, 1, 1],
35
+ [1, 1, 1, 1, 1],
36
+ [1, 1, 1, 1, 1],
37
+ [1, 1, 1, 1, 1],
38
+ [1, 1, 1, 1, 1],
39
+ ],
40
+ "steps": 8,
41
+ },
42
+ "Color Spread": {
43
+ "input": [
44
+ [0, 0, 0, 0, 0],
45
+ [0, 0, 2, 0, 0],
46
+ [0, 0, 0, 0, 0],
47
+ [0, 0, 0, 0, 0],
48
+ [0, 0, 0, 0, 0],
49
+ ],
50
+ "output": [
51
+ [2, 2, 2, 2, 2],
52
+ [2, 2, 2, 2, 2],
53
+ [2, 2, 2, 2, 2],
54
+ [2, 2, 2, 2, 2],
55
+ [2, 2, 2, 2, 2],
56
+ ],
57
+ "steps": 10,
58
+ },
59
+ "Mirror Pattern": {
60
+ "input": [
61
+ [0, 0, 3, 0, 0],
62
+ [0, 3, 0, 0, 0],
63
+ [3, 0, 0, 0, 0],
64
+ [0, 0, 0, 0, 0],
65
+ [0, 0, 0, 0, 0],
66
+ ],
67
+ "output": [
68
+ [0, 0, 3, 0, 0],
69
+ [0, 3, 0, 3, 0],
70
+ [3, 0, 0, 0, 3],
71
+ [0, 3, 0, 3, 0],
72
+ [0, 0, 3, 0, 0],
73
+ ],
74
+ "steps": 6,
75
+ },
76
+ }
77
+
78
+
79
+ def grid_to_image(grid, cell_size=40):
80
+ """Convert a grid to a colored image."""
81
+ grid = np.array(grid)
82
+ h, w = grid.shape
83
+
84
+ fig, ax = plt.subplots(figsize=(w * cell_size / 100, h * cell_size / 100), dpi=100)
85
+
86
+ cmap = mcolors.ListedColormap(ARC_COLORS)
87
+ ax.imshow(grid, cmap=cmap, vmin=0, vmax=9)
88
+
89
+ # Add grid lines
90
+ for i in range(h + 1):
91
+ ax.axhline(i - 0.5, color='white', linewidth=1)
92
+ for j in range(w + 1):
93
+ ax.axvline(j - 0.5, color='white', linewidth=1)
94
+
95
+ ax.set_xticks([])
96
+ ax.set_yticks([])
97
+ ax.set_aspect('equal')
98
+
99
+ plt.tight_layout(pad=0)
100
+
101
+ buf = io.BytesIO()
102
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
103
+ plt.close(fig)
104
+ buf.seek(0)
105
+
106
+ return Image.open(buf)
107
+
108
+
109
+ def interpolate_grids(start, end, progress):
110
+ """Create an intermediate grid state based on progress (0 to 1)."""
111
+ start = np.array(start)
112
+ end = np.array(end)
113
+
114
+ # Find cells that need to change
115
+ diff_mask = start != end
116
+ num_changes = diff_mask.sum()
117
+
118
+ if num_changes == 0:
119
+ return end.tolist()
120
+
121
+ # Determine how many changes to apply based on progress
122
+ changes_to_apply = int(num_changes * progress)
123
+
124
+ result = start.copy()
125
+ change_indices = np.argwhere(diff_mask)
126
+
127
+ # Apply changes in order (top-left to bottom-right)
128
+ for i in range(min(changes_to_apply, len(change_indices))):
129
+ idx = tuple(change_indices[i])
130
+ result[idx] = end[idx]
131
+
132
+ return result.tolist()
133
+
134
+
135
+ def simulate_reasoning(puzzle_name, progress=gr.Progress()):
136
+ """Simulate the recursive reasoning process with visualization."""
137
+ if puzzle_name not in SAMPLE_PUZZLES:
138
+ yield None, "Please select a puzzle"
139
+ return
140
+
141
+ puzzle = SAMPLE_PUZZLES[puzzle_name]
142
+ input_grid = puzzle["input"]
143
+ output_grid = puzzle["output"]
144
+ num_steps = puzzle["steps"]
145
+
146
+ progress(0, desc="Initializing model...")
147
+ time.sleep(0.3)
148
+
149
+ # Simulate recursive reasoning steps
150
+ for step in range(num_steps + 1):
151
+ step_progress = step / num_steps
152
+
153
+ # Create intermediate state
154
+ current_grid = interpolate_grids(input_grid, output_grid, step_progress)
155
+
156
+ # Generate image
157
+ img = grid_to_image(current_grid)
158
+
159
+ if step == 0:
160
+ status = f"🧠 Step {step}/{num_steps}: Reading input puzzle..."
161
+ elif step == num_steps:
162
+ status = f"βœ… Step {step}/{num_steps}: Solution found!"
163
+ else:
164
+ status = f"πŸ”„ Step {step}/{num_steps}: Refining hypothesis (latent z update)..."
165
+
166
+ progress(step_progress, desc=status)
167
+ yield img, status
168
+
169
+ # Add delay for visualization effect
170
+ time.sleep(0.4)
171
+
172
+
173
+ def load_puzzle(puzzle_name):
174
+ """Load and display the input puzzle."""
175
+ if puzzle_name not in SAMPLE_PUZZLES:
176
+ return None, None, "Select a puzzle to begin"
177
+
178
+ puzzle = SAMPLE_PUZZLES[puzzle_name]
179
+ input_img = grid_to_image(puzzle["input"])
180
+ output_img = grid_to_image(puzzle["output"])
181
+
182
+ return input_img, output_img, f"Puzzle loaded: {puzzle_name}"
183
+
184
+
185
+ # Build the Gradio interface
186
+ with gr.Blocks(
187
+ title="TinyThink: Glass-Box Reasoning",
188
+ theme=gr.themes.Soft(primary_hue="purple", secondary_hue="blue"),
189
+ css="""
190
+ .header { text-align: center; margin-bottom: 20px; }
191
+ .status-box { font-size: 1.2em; padding: 10px; border-radius: 8px; }
192
+ """
193
+ ) as demo:
194
+
195
+ gr.HTML("""
196
+ <div class="header">
197
+ <h1>🧠 TinyThink: Glass-Box Recursive Reasoning</h1>
198
+ <p style="font-size: 1.2em; color: #666;">
199
+ Watch a <strong>7M parameter</strong> model solve ARC-AGI puzzles by "thinking" recursively
200
+ </p>
201
+ <p style="font-size: 0.9em; color: #888;">
202
+ Based on "Less is More: Recursive Reasoning with Tiny Networks" (Samsung SAIL Montreal, 2025)
203
+ </p>
204
+ </div>
205
+ """)
206
+
207
+ with gr.Row():
208
+ with gr.Column(scale=1):
209
+ gr.Markdown("### πŸ“‹ Select a Puzzle")
210
+ puzzle_dropdown = gr.Dropdown(
211
+ choices=list(SAMPLE_PUZZLES.keys()),
212
+ label="Choose an ARC puzzle",
213
+ value="Pattern Fill"
214
+ )
215
+ load_btn = gr.Button("Load Puzzle", variant="secondary")
216
+
217
+ gr.Markdown("### πŸ“₯ Input Grid")
218
+ input_display = gr.Image(label="Input", type="pil", height=250)
219
+
220
+ gr.Markdown("### 🎯 Expected Output")
221
+ expected_output = gr.Image(label="Target", type="pil", height=250)
222
+
223
+ with gr.Column(scale=2):
224
+ gr.Markdown("### πŸ”„ Live Reasoning State")
225
+ gr.Markdown("*Watch the model iterate through its recursive reasoning loop*")
226
+
227
+ output_display = gr.Image(label="Current Hypothesis", type="pil", height=400)
228
+ status_text = gr.Textbox(
229
+ label="Reasoning Status",
230
+ value="Select a puzzle and click 'Start Reasoning' to begin",
231
+ interactive=False,
232
+ elem_classes=["status-box"]
233
+ )
234
+
235
+ solve_btn = gr.Button("πŸš€ Start Reasoning Loop", variant="primary", size="lg")
236
+
237
+ gr.Markdown("""
238
+ ---
239
+ ### πŸ“– How TinyThink Works
240
+
241
+ The Tiny Recursive Model (TRM) uses a fundamentally different approach than large language models:
242
+
243
+ 1. **Input Encoding**: The puzzle grid is embedded as tokens
244
+ 2. **Recursive Loop**: For N steps, the model updates its latent state `z` given (input, current_answer, current_z)
245
+ 3. **Answer Refinement**: After reasoning, the model updates its answer `y` based on the refined latent
246
+ 4. **Repeat**: This process repeats K times (typically 16), with each iteration improving the answer
247
+
248
+ The key insight is that **depth of reasoning** (recursive iterations) can compensate for **model size**.
249
+ A 7M parameter model thinking for 16 steps outperforms much larger models that only do single-pass inference.
250
+
251
+ ---
252
+ *⚠️ This is a visualization demo. The full model requires GPU resources.*
253
+ *See the [GitHub repo](https://github.com/SamsungSAILMontreal/TinyRecursiveModels) for the actual implementation.*
254
+ """)
255
+
256
+ # Event handlers
257
+ load_btn.click(
258
+ load_puzzle,
259
+ inputs=[puzzle_dropdown],
260
+ outputs=[input_display, expected_output, status_text]
261
+ )
262
+
263
+ puzzle_dropdown.change(
264
+ load_puzzle,
265
+ inputs=[puzzle_dropdown],
266
+ outputs=[input_display, expected_output, status_text]
267
+ )
268
+
269
+ solve_btn.click(
270
+ simulate_reasoning,
271
+ inputs=[puzzle_dropdown],
272
+ outputs=[output_display, status_text]
273
+ )
274
+
275
+ # Launch the app
276
+ if __name__ == "__main__":
277
+ demo.launch()