beanapologist commited on
Commit
4c236ce
Β·
verified Β·
1 Parent(s): a9e6886

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +266 -299
app.py CHANGED
@@ -1,9 +1,6 @@
1
  """
2
- ARC-AGI-3 Feature Visualizer
3
  Hugging Face Space: beanapologist/arc-agi
4
-
5
- Interactive demo of the Re/Im symmetry-aware feature extractor.
6
- Draw a grid, see what the CNN sees before any training happens.
7
  """
8
 
9
  import gradio as gr
@@ -11,186 +8,134 @@ import numpy as np
11
  import matplotlib
12
  matplotlib.use('Agg')
13
  import matplotlib.pyplot as plt
14
- import matplotlib.patches as mpatches
15
  from matplotlib.colors import ListedColormap
16
  import io
17
  from PIL import Image
18
 
19
- # ── Inline feature extractor (no external deps) ───────────────────────────────
 
 
 
 
 
 
 
 
 
 
20
 
21
  def _connected_components(mask):
22
  labels = np.zeros_like(mask, dtype=np.int32)
23
- current_label = 0
24
  H, W = mask.shape
25
  for r in range(H):
26
  for c in range(W):
27
- if mask[r, c] and labels[r, c] == 0:
28
- current_label += 1
29
- queue = [(r, c)]
30
- labels[r, c] = current_label
31
- while queue:
32
- y, x = queue.pop()
33
- for dy, dx in [(-1,0),(1,0),(0,-1),(0,1)]:
34
- ny, nx = y+dy, x+dx
35
  if 0<=ny<H and 0<=nx<W and mask[ny,nx] and labels[ny,nx]==0:
36
- labels[ny,nx] = current_label
37
- queue.append((ny,nx))
38
  return labels
39
 
40
- def _distance_transform(mask, max_dist=8):
41
- H, W = mask.shape
42
- dist = np.full((H, W), max_dist, dtype=np.float32)
43
- queue = []
44
- for r in range(H):
45
- for c in range(W):
46
- if mask[r, c]:
47
- dist[r, c] = 0
48
- queue.append((r, c))
49
- head = 0
50
- while head < len(queue):
51
- y, x = queue[head]; head += 1
52
- for dy, dx in [(-1,0),(1,0),(0,-1),(0,1)]:
53
- ny, nx = y+dy, x+dx
54
- if 0<=ny<H and 0<=nx<W and dist[ny,nx] > dist[y,x]+1:
55
- dist[ny,nx] = dist[y,x]+1
56
- queue.append((ny,nx))
57
- return np.clip(dist / max_dist, 0, 1)
58
-
59
- def _sobel(grid_f):
60
- p = np.pad(grid_f, 1, mode='edge')
61
- gx = (-p[:-2,:-2] - 2*p[1:-1,:-2] - p[2:,:-2]
62
- + p[:-2,2:] + 2*p[1:-1,2:] + p[2:,2:]) / 8.0
63
- gy = (-p[:-2,:-2] - 2*p[:-2,1:-1] - p[:-2,2:]
64
- + p[2:,:-2] + 2*p[2:,1:-1] + p[2:,2:]) / 8.0
65
  return gx, gy
66
 
67
- def _reflection_symmetry_map(grid, axis):
68
- H, W = grid.shape
69
- score_map = np.zeros((H, W), dtype=np.float32)
70
  if axis == 'h':
71
  for x in range(W):
72
- reach = min(x, W-1-x)
73
- if reach == 0: score_map[:,x] = 1.0; continue
74
- left = grid[:, x-reach:x]
75
- right = grid[:, x+1:x+reach+1][:, ::-1]
76
- score_map[:,x] = (left == right).mean()
77
  else:
78
  for y in range(H):
79
- reach = min(y, H-1-y)
80
- if reach == 0: score_map[y,:] = 1.0; continue
81
- top = grid[y-reach:y, :]
82
- bottom = grid[y+1:y+reach+1,:][::-1,:]
83
- score_map[y,:] = (top == bottom).mean()
84
- return score_map
85
-
86
- def _winding_proxy(gx, gy):
87
- p_gx = np.pad(gx, 1, mode='edge')
88
- p_gy = np.pad(gy, 1, mode='edge')
89
- dgydx = (p_gy[1:-1,2:] - p_gy[1:-1,:-2]) / 2.0
90
- dgxdy = (p_gx[2:,1:-1] - p_gx[:-2,1:-1]) / 2.0
91
- curl = dgydx - dgxdy
92
- mx = np.abs(curl).max()
93
- if mx > 0: curl = curl / mx
94
- return curl.astype(np.float32)
95
-
96
- def extract_features_visual(grid_2d):
97
- """Extract and return named Im-side feature maps for visualization."""
98
- H, W = grid_2d.shape
99
- color_float = grid_2d.astype(np.float32) / 9.0
100
- gx, gy = _sobel(color_float)
101
-
102
- h_sym = _reflection_symmetry_map(grid_2d, 'h')
103
- v_sym = _reflection_symmetry_map(grid_2d, 'v')
104
- curl = _winding_proxy(gx, gy)
105
-
106
- padded = np.pad(grid_2d, 1, mode='edge')
107
- boundary = (
108
- (padded[1:-1,1:-1] != padded[:-2,1:-1]) |
109
- (padded[1:-1,1:-1] != padded[2:,1:-1]) |
110
- (padded[1:-1,1:-1] != padded[1:-1,:-2]) |
111
- (padded[1:-1,1:-1] != padded[1:-1,2:])
112
- ).astype(np.float32)
113
-
114
- edge_mag = np.sqrt(gx**2 + gy**2)
115
- edge_mag = edge_mag / (edge_mag.max() + 1e-8)
116
-
117
- # Connected component map
118
- global_labels = np.zeros((H, W), dtype=np.int32)
119
- current = 0
120
  for c in range(10):
121
- mask = (grid_2d == c)
122
- if not mask.any(): continue
123
- labels = _connected_components(mask)
124
- for comp_id in range(1, labels.max()+1):
125
- current += 1
126
- global_labels[labels == comp_id] = current
127
- norm_labels = global_labels.astype(np.float32)
128
- if norm_labels.max() > 0: norm_labels /= norm_labels.max()
129
-
130
- return dict(
131
- h_symmetry=h_sym,
132
- v_symmetry=v_sym,
133
- boundary=boundary,
134
- edge_magnitude=edge_mag,
135
- winding_curl=curl,
136
- component_map=norm_labels,
137
- )
138
-
139
- # ── ARC color palette ─────────────────────────────────────────────────────────
140
-
141
- ARC_COLORS = [
142
- '#000000', '#1a6faf', '#e03a3a', '#3aa63a',
143
- '#f5c400', '#c060c0', '#d07030', '#aaaaaa',
144
- '#60b8d0', '#874010'
145
- ]
146
- ARC_CMAP = ListedColormap(ARC_COLORS)
147
-
148
- # ── Preset puzzles ────────────────────────────────────────────────────────────
149
-
150
- PRESETS = {
151
- "Horizontal mirror": """0 1 2 0 0 0 2 1 0
152
- 0 1 2 0 0 0 2 1 0
153
- 0 1 2 0 0 0 2 1 0
154
- 0 0 0 0 0 0 0 0 0
155
- 0 0 0 3 3 3 0 0 0
156
- 0 0 0 3 0 3 0 0 0
157
- 0 0 0 3 3 3 0 0 0""",
158
-
159
- "Colored objects": """0 0 0 0 0 0 0 0 0
160
- 0 1 1 0 0 2 2 0 0
161
- 0 1 1 0 0 2 2 0 0
162
- 0 0 0 0 0 0 0 0 0
163
- 0 3 3 0 0 4 4 0 0
164
- 0 3 3 0 0 4 4 0 0
165
- 0 0 0 0 0 0 0 0 0""",
166
-
167
- "Border pattern": """5 5 5 5 5 5 5
168
- 5 0 0 0 0 0 5
169
- 5 0 3 3 3 0 5
170
- 5 0 3 0 3 0 5
171
- 5 0 3 3 3 0 5
172
- 5 0 0 0 0 0 5
173
- 5 5 5 5 5 5 5""",
174
-
175
- "Diagonal": """1 0 0 0 0 0 0
176
- 0 1 0 0 0 0 0
177
- 0 0 1 0 0 0 0
178
- 0 0 0 2 0 0 0
179
- 0 0 0 0 2 0 0
180
- 0 0 0 0 0 2 0
181
- 0 0 0 0 0 0 3""",
182
-
183
- "Scattered dots": """0 0 0 0 0 0 0 0
184
- 0 1 0 0 0 2 0 0
185
- 0 0 0 0 0 0 0 0
186
- 0 0 0 3 0 0 0 0
187
- 0 0 0 0 0 4 0 0
188
- 0 5 0 0 0 0 0 0
189
- 0 0 0 0 6 0 0 0
190
- 0 0 0 0 0 0 0 0""",
191
- }
192
 
193
- # ── Parse grid text ───────────────────────────────────────────────────────────
194
 
195
  def parse_grid(text):
196
  try:
@@ -198,169 +143,191 @@ def parse_grid(text):
198
  for line in text.strip().split('\n'):
199
  line = line.strip()
200
  if not line: continue
201
- vals = [int(x) for x in line.split()]
202
- rows.append(vals)
203
- if not rows: return None, "No rows found"
204
  W = len(rows[0])
205
- if any(len(r) != W for r in rows):
206
- return None, "Rows have different lengths"
207
- grid = np.array(rows, dtype=np.int64)
208
- if grid.min() < 0 or grid.max() > 9:
209
- return None, "Values must be 0–9"
210
- return grid, None
211
  except Exception as e:
212
  return None, str(e)
213
 
214
- # ── Main analysis function ────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
- def analyze_grid(grid_text):
217
- grid, err = parse_grid(grid_text)
218
- if err:
219
- return None, f"Parse error: {err}", ""
220
-
221
- features = extract_features_visual(grid)
222
- H, W = grid.shape
223
-
224
- fig, axes = plt.subplots(2, 4, figsize=(14, 7))
225
- fig.patch.set_facecolor('#1a1a2e')
226
-
227
- titles = [
228
- ("Input grid", grid, ARC_CMAP, 0, 9, False),
229
- ("H-symmetry\n(mirror axes β†’)", features['h_symmetry'], 'YlOrRd', 0, 1, False),
230
- ("V-symmetry\n(mirror axes ↕)", features['v_symmetry'], 'YlOrRd', 0, 1, False),
231
- ("Boundary contour\n(Cauchy edge)", features['boundary'], 'plasma', 0, 1, False),
232
- ("Edge magnitude\n(Sobel)", features['edge_magnitude'], 'hot', 0, 1, False),
233
- ("Winding / curl\n(Im rotation)", features['winding_curl'], 'RdBu', -1, 1, True),
234
- ("Component map\n(object IDs)", features['component_map'], 'tab20', 0, 1, False),
235
- ]
236
 
237
- flat_axes = axes.flatten()
238
- for i, ax in enumerate(flat_axes):
239
- ax.set_facecolor('#0d0d1a')
240
- if i < len(titles):
241
- title, data, cmap, vmin, vmax, center = titles[i]
242
- if cmap == ARC_CMAP:
243
- im = ax.imshow(data, cmap=cmap, vmin=vmin, vmax=vmax,
244
- interpolation='nearest', aspect='equal')
245
- for r in range(H):
246
- for c in range(W):
247
- v = int(data[r, c])
248
- color = 'white' if v in [0,6,7,8] else 'black'
249
- ax.text(c, r, str(v), ha='center', va='center',
250
- fontsize=max(6, 10 - max(H, W)//2),
251
- color=color, fontweight='bold')
252
- else:
253
- kwargs = dict(cmap=cmap, vmin=vmin, vmax=vmax,
254
- interpolation='nearest', aspect='equal')
255
- im = ax.imshow(data, **kwargs)
256
- plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
257
-
258
- ax.set_title(title, color='white', fontsize=9, pad=4)
259
- ax.tick_params(colors='#666', labelsize=7)
260
- for spine in ax.spines.values():
261
- spine.set_edgecolor('#333')
262
- else:
263
- ax.axis('off')
264
-
265
- flat_axes[-1].axis('off')
266
- flat_axes[-1].set_facecolor('#0d0d1a')
267
-
268
- plt.tight_layout(pad=1.5)
269
 
270
- buf = io.BytesIO()
271
- plt.savefig(buf, format='png', dpi=120, bbox_inches='tight',
272
- facecolor='#1a1a2e')
273
- buf.seek(0)
274
- img = Image.open(buf)
275
- plt.close()
276
 
277
- # Stats summary
278
- h_max = features['h_symmetry'].max()
279
- v_max = features['v_symmetry'].max()
280
- boundary_count = int((features['boundary'] > 0.5).sum())
281
- curl_max = float(np.abs(features['winding_curl']).max())
282
- n_components = int(features['component_map'].max() * 20) if features['component_map'].max() > 0 else 0
283
 
284
- dominant = "H-symmetry" if h_max >= v_max else "V-symmetry"
285
- sym_score = max(h_max, v_max)
 
 
286
 
287
- summary = f"""**Grid:** {H}Γ—{W} | **Colors used:** {len(np.unique(grid))-1} (excl. background)
 
 
288
 
289
- **Im-side (orientation / structure):**
290
- - Strongest symmetry axis: {dominant} = {sym_score:.2f} {'β˜… strong' if sym_score > 0.8 else ''}
291
- - Boundary pixels (Cauchy edge): {boundary_count}
292
- - Winding / curl max: {curl_max:.3f}
293
- - Connected components: ~{n_components}
294
 
295
- **What the CNN sees before training:**
296
- The 56-channel extractor separates *how big* (Re: color counts, distances) from *where pointed* (Im: symmetry axes, boundaries, rotation). High symmetry scores tell the CNN which actions are likely to preserve structure; high boundary counts indicate complex objects worth clicking on precisely."""
297
 
298
- return img, summary, ""
299
 
 
 
 
 
 
 
 
300
 
301
- def load_preset(preset_name):
302
- return PRESETS.get(preset_name, "")
 
303
 
 
 
304
 
305
- # ── Gradio UI ─────────────────────────────────────────────────────────────────
306
 
307
- CSS = """
308
- .gradio-container { max-width: 1100px !important; }
309
- #grid-input textarea { font-family: monospace; font-size: 14px; }
310
- """
311
 
312
- with gr.Blocks(css=CSS, title="ARC-AGI-3 Feature Visualizer") as demo:
 
 
 
 
313
 
314
  gr.Markdown("""
315
- # ARC-AGI-3 Feature Visualizer
316
- ### Re/Im duality β€” what the CNN sees before any training
317
-
318
- Enter a grid of integers (0–9, space-separated rows) or pick a preset.
319
- The extractor computes **56 feature channels** split into:
320
- - **Re side** (local / multiplicative): one-hot colors, object sizes, distance maps
321
- - **Im side** (global / angular): symmetry axes, boundary contours, winding structure
322
  """)
323
 
324
  with gr.Row():
325
- with gr.Column(scale=1):
326
- preset_dd = gr.Dropdown(
327
- choices=list(PRESETS.keys()),
328
- label="Load a preset puzzle",
329
- value=None,
330
- )
331
- grid_input = gr.Textbox(
332
- label="Grid (space-separated integers 0–9, one row per line)",
333
- placeholder="0 1 2\n3 0 1\n2 3 0",
334
- lines=10,
335
- elem_id="grid-input",
336
- value=PRESETS["Horizontal mirror"],
337
- )
338
- analyze_btn = gr.Button("Analyze", variant="primary")
339
-
340
- with gr.Column(scale=2):
341
- output_img = gr.Image(label="Feature maps", type="pil")
342
- output_md = gr.Markdown()
343
- error_box = gr.Textbox(label="", visible=False)
344
 
345
- preset_dd.change(load_preset, inputs=preset_dd, outputs=grid_input)
346
- analyze_btn.click(analyze_grid,
347
- inputs=grid_input,
348
- outputs=[output_img, output_md, error_box])
 
 
349
 
350
- gr.Markdown("""
351
- ---
352
- **How this works in ARC-AGI-3:** The agent receives a 64Γ—64 grid frame from the game engine.
353
- Instead of feeding raw one-hot colors to the CNN, we pre-compute these 56 channels β€”
354
- giving the model geometric priors about symmetry and structure *before* it takes a single action.
355
 
356
- Source: [Kaggle competition](https://www.kaggle.com/competitions/arc-prize-2026-arc-agi-3) |
357
- Agent based on [StochasticGoose](https://github.com/DriesSmit/ARC3-solution) by Dries Smit & Jack Cole (Tufa Labs)
358
- """)
 
 
 
 
359
 
360
- demo.load(
361
- fn=lambda: analyze_grid(PRESETS["Horizontal mirror"]),
362
- outputs=[output_img, output_md, error_box],
363
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
 
365
  if __name__ == "__main__":
366
- demo.launch()
 
 
 
1
  """
2
+ ARC-AGI-3 Puzzle Interface + Feature Visualizer
3
  Hugging Face Space: beanapologist/arc-agi
 
 
 
4
  """
5
 
6
  import gradio as gr
 
8
  import matplotlib
9
  matplotlib.use('Agg')
10
  import matplotlib.pyplot as plt
 
11
  from matplotlib.colors import ListedColormap
12
  import io
13
  from PIL import Image
14
 
15
+ # ── ARC color palette ─────────────────────────────────────────────────────────
16
+
17
+ ARC_HEX = [
18
+ '#000000','#1a6faf','#e03a3a','#3aa63a','#f5c400',
19
+ '#c060c0','#d07030','#aaaaaa','#60b8d0','#874010'
20
+ ]
21
+ ARC_CMAP = ListedColormap(ARC_HEX)
22
+ COLOR_NAMES = ['black','blue','red','green','yellow',
23
+ 'purple','orange','gray','azure','maroon']
24
+
25
+ # ── Feature extractor ─────────────────────────────────────────────────────────
26
 
27
  def _connected_components(mask):
28
  labels = np.zeros_like(mask, dtype=np.int32)
29
+ cur = 0
30
  H, W = mask.shape
31
  for r in range(H):
32
  for c in range(W):
33
+ if mask[r,c] and labels[r,c]==0:
34
+ cur += 1
35
+ q = [(r,c)]; labels[r,c] = cur
36
+ while q:
37
+ y,x = q.pop()
38
+ for dy,dx in [(-1,0),(1,0),(0,-1),(0,1)]:
39
+ ny,nx = y+dy,x+dx
 
40
  if 0<=ny<H and 0<=nx<W and mask[ny,nx] and labels[ny,nx]==0:
41
+ labels[ny,nx]=cur; q.append((ny,nx))
 
42
  return labels
43
 
44
+ def _sobel(f):
45
+ p = np.pad(f, 1, mode='edge')
46
+ gx = (-p[:-2,:-2]-2*p[1:-1,:-2]-p[2:,:-2]+p[:-2,2:]+2*p[1:-1,2:]+p[2:,2:])/8
47
+ gy = (-p[:-2,:-2]-2*p[:-2,1:-1]-p[:-2,2:]+p[2:,:-2]+2*p[2:,1:-1]+p[2:,2:])/8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  return gx, gy
49
 
50
+ def _sym_map(grid, axis):
51
+ H,W = grid.shape
52
+ s = np.zeros((H,W), np.float32)
53
  if axis == 'h':
54
  for x in range(W):
55
+ r = min(x, W-1-x)
56
+ if r==0: s[:,x]=1.; continue
57
+ s[:,x] = (grid[:,x-r:x]==grid[:,x+1:x+r+1][:,::-1]).mean()
 
 
58
  else:
59
  for y in range(H):
60
+ r = min(y, H-1-y)
61
+ if r==0: s[y,:]=1.; continue
62
+ s[y,:] = (grid[y-r:y,:]==grid[y+1:y+r+1,:][::-1,:]).mean()
63
+ return s
64
+
65
+ def _boundary(grid):
66
+ p = np.pad(grid, 1, mode='edge')
67
+ return ((p[1:-1,1:-1]!=p[:-2,1:-1])|(p[1:-1,1:-1]!=p[2:,1:-1])|
68
+ (p[1:-1,1:-1]!=p[1:-1,:-2])|(p[1:-1,1:-1]!=p[1:-1,2:])).astype(np.float32)
69
+
70
+ def compute_scores(grid):
71
+ h_sym = _sym_map(grid,'h').max()
72
+ v_sym = _sym_map(grid,'v').max()
73
+ b_cnt = int(_boundary(grid).sum())
74
+ gx,gy = _sobel(grid.astype(np.float32)/9)
75
+ edge = float(np.sqrt(gx**2+gy**2).mean())
76
+ n_comp = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  for c in range(10):
78
+ mask = (grid==c)
79
+ if mask.any(): n_comp += _connected_components(mask).max()
80
+ return dict(h_sym=float(h_sym), v_sym=float(v_sym),
81
+ boundary=b_cnt, edge=round(edge,4), components=int(n_comp))
82
+
83
+ # ── Grid rendering ────────────────────────────────────────────────────────────
84
+
85
+ def render_grid(grid, title='', highlight_wrong=None):
86
+ H,W = grid.shape
87
+ cell = max(32, min(60, 280//max(H,W)))
88
+ fig,ax = plt.subplots(figsize=((W*cell+4)/72, (H*cell+24)/72), dpi=72)
89
+ fig.patch.set_facecolor('#1e1e2e')
90
+ ax.set_facecolor('#1e1e2e')
91
+ ax.imshow(grid, cmap=ARC_CMAP, vmin=0, vmax=9,
92
+ interpolation='nearest', aspect='equal')
93
+ for x in range(W+1): ax.axvline(x-0.5, color='#555', lw=0.5)
94
+ for y in range(H+1): ax.axhline(y-0.5, color='#555', lw=0.5)
95
+ for r in range(H):
96
+ for c in range(W):
97
+ v = int(grid[r,c])
98
+ col = 'white' if v in [0,1,2,3,5,6,9] else 'black'
99
+ ax.text(c, r, str(v), ha='center', va='center',
100
+ fontsize=max(7, cell//5), color=col,
101
+ fontweight='bold', fontfamily='monospace')
102
+ if highlight_wrong is not None and highlight_wrong[r,c]:
103
+ ax.add_patch(plt.Rectangle(
104
+ (c-0.5,r-0.5),1,1, fill=False, edgecolor='#ff4444', lw=2))
105
+ ax.set_xlim(-0.5, W-0.5); ax.set_ylim(H-0.5, -0.5); ax.axis('off')
106
+ if title:
107
+ ax.set_title(title, color='#cdd6f4', fontsize=9, pad=3)
108
+ plt.tight_layout(pad=0.3)
109
+ buf = io.BytesIO()
110
+ plt.savefig(buf, format='png', dpi=72, bbox_inches='tight', facecolor='#1e1e2e')
111
+ buf.seek(0); img = Image.open(buf).copy(); plt.close()
112
+ return img
113
+
114
+ def render_feature_maps(grid):
115
+ gx,gy = _sobel(grid.astype(np.float32)/9)
116
+ maps = [
117
+ ('H-symmetry', _sym_map(grid,'h'), 'YlOrRd'),
118
+ ('V-symmetry', _sym_map(grid,'v'), 'YlOrRd'),
119
+ ('Boundary', _boundary(grid), 'plasma'),
120
+ ('Edge mag', np.sqrt(gx**2+gy**2), 'hot'),
121
+ ]
122
+ fig,axes = plt.subplots(1,4,figsize=(12,2.8))
123
+ fig.patch.set_facecolor('#1e1e2e')
124
+ for ax,(title,data,cmap) in zip(axes,maps):
125
+ ax.set_facecolor('#0d0d1a')
126
+ vmax = 1 if data.max()<=1 else data.max()
127
+ im = ax.imshow(data, cmap=cmap, vmin=0, vmax=vmax,
128
+ interpolation='nearest', aspect='equal')
129
+ ax.set_title(title, color='white', fontsize=9, pad=3)
130
+ ax.axis('off')
131
+ plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
132
+ plt.tight_layout(pad=1)
133
+ buf = io.BytesIO()
134
+ plt.savefig(buf, format='png', dpi=100, bbox_inches='tight', facecolor='#1e1e2e')
135
+ buf.seek(0); img = Image.open(buf).copy(); plt.close()
136
+ return img
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
+ # ── Parse ─────────────────────────────────────────────────────────────────────
139
 
140
  def parse_grid(text):
141
  try:
 
143
  for line in text.strip().split('\n'):
144
  line = line.strip()
145
  if not line: continue
146
+ rows.append([int(x) for x in line.split()])
147
+ if not rows: return None, "Empty"
 
148
  W = len(rows[0])
149
+ if any(len(r)!=W for r in rows): return None, "Rows have different lengths"
150
+ g = np.array(rows, dtype=np.int64)
151
+ if g.min()<0 or g.max()>9: return None, "Values must be 0–9"
152
+ return g, None
 
 
153
  except Exception as e:
154
  return None, str(e)
155
 
156
+ def blank_like(grid_txt):
157
+ g, err = parse_grid(grid_txt)
158
+ if err or g is None: return ""
159
+ return "\n".join(" ".join(["0"]*g.shape[1]) for _ in range(g.shape[0]))
160
+
161
+ # ── Puzzles ───────────────────────────────────────────────────────────────────
162
+
163
+ PUZZLES = {
164
+ "Mirror complete": {
165
+ "desc": "Complete the right half to mirror the left half.",
166
+ "train_in": "0 1 2 3 0 0 0 0\n0 4 0 3 0 0 0 0\n0 1 2 3 0 0 0 0",
167
+ "train_out": "0 1 2 3 3 2 1 0\n0 4 0 3 3 0 4 0\n0 1 2 3 3 2 1 0",
168
+ "test_in": "0 2 1 0 0 0 0 0\n0 2 3 0 0 0 0 0\n0 0 1 0 0 0 0 0",
169
+ "answer": "0 2 1 0 0 1 2 0\n0 2 3 0 0 3 2 0\n0 0 1 0 0 1 0 0",
170
+ },
171
+ "Color shift +1": {
172
+ "desc": "Each nonzero color increases by 1 (9 stays 9).",
173
+ "train_in": "0 1 0\n2 0 3\n0 4 0",
174
+ "train_out": "0 2 0\n3 0 4\n0 5 0",
175
+ "test_in": "0 3 1\n0 0 2\n4 0 0",
176
+ "answer": "0 4 2\n0 0 3\n5 0 0",
177
+ },
178
+ "Border only": {
179
+ "desc": "Keep only the outer border, fill interior with 0.",
180
+ "train_in": "2 2 2 2 2\n2 2 2 2 2\n2 2 2 2 2\n2 2 2 2 2\n2 2 2 2 2",
181
+ "train_out": "2 2 2 2 2\n2 0 0 0 2\n2 0 0 0 2\n2 0 0 0 2\n2 2 2 2 2",
182
+ "test_in": "3 3 3 3 3\n3 3 3 3 3\n3 3 3 3 3\n3 3 3 3 3\n3 3 3 3 3",
183
+ "answer": "3 3 3 3 3\n3 0 0 0 3\n3 0 0 0 3\n3 0 0 0 3\n3 3 3 3 3",
184
+ },
185
+ "Gravity down": {
186
+ "desc": "Each colored pixel falls to the bottom of its column.",
187
+ "train_in": "1 0 2\n0 0 0\n0 3 0",
188
+ "train_out": "0 0 0\n0 0 2\n1 3 0",
189
+ "test_in": "4 0 0\n0 5 0\n0 0 6",
190
+ "answer": "0 0 0\n0 0 0\n4 5 6",
191
+ },
192
+ "Count β†’ block": {
193
+ "desc": "Count the 1s; fill an NΓ—N block of 2s in the bottom-left.",
194
+ "train_in": "0 1 0 1 0\n0 0 0 0 0\n0 0 0 0 0\n0 0 0 0 0\n0 0 0 0 0",
195
+ "train_out": "0 0 0 0 0\n0 0 0 0 0\n0 0 0 0 0\n2 2 0 0 0\n2 2 0 0 0",
196
+ "test_in": "0 1 0 1 1\n0 0 0 0 0\n0 0 0 0 0\n0 0 0 0 0\n0 0 0 0 0",
197
+ "answer": "0 0 0 0 0\n0 0 0 0 0\n2 2 2 0 0\n2 2 2 0 0\n2 2 2 0 0",
198
+ },
199
+ }
200
 
201
+ # ── Handlers ──────────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
+ def on_load_puzzle(name):
204
+ p = PUZZLES[name]
205
+ ti_img, to_img, test_img = [
206
+ render_grid(parse_grid(t)[0], title=lbl)
207
+ for t, lbl in [(p["train_in"],"Training input"),
208
+ (p["train_out"],"Training output"),
209
+ (p["test_in"],"Test input β€” solve this")]
210
+ ]
211
+ blank = blank_like(p["answer"])
212
+ return (p["train_in"], p["train_out"], p["test_in"],
213
+ p["desc"], p["answer"],
214
+ blank,
215
+ ti_img, to_img, test_img,
216
+ None, None, None,
217
+ "Edit the answer grid above and click Submit.")
218
+
219
+ def on_submit(answer_txt, answer_hidden, test_in_txt):
220
+ answer_g, err = parse_grid(answer_txt)
221
+ if err:
222
+ return None, None, None, f"**Parse error:** {err}"
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
+ target_g, err2 = parse_grid(answer_hidden)
225
+ if err2:
226
+ return None, None, None, f"**Target parse error:** {err2}"
 
 
 
227
 
228
+ if answer_g.shape != target_g.shape:
229
+ return None, None, None, (
230
+ f"**Wrong shape:** your answer is {answer_g.shape}, "
231
+ f"target is {target_g.shape}. Check row count and width.")
 
 
232
 
233
+ correct = (answer_g == target_g)
234
+ pct = correct.mean() * 100
235
+ solved = bool(correct.all())
236
+ wrong = ~correct
237
 
238
+ answer_img = render_grid(answer_g, title="Your answer", highlight_wrong=wrong)
239
+ target_img = render_grid(target_g, title="Target output")
240
+ feat_img = render_feature_maps(answer_g)
241
 
242
+ sc = compute_scores(answer_g)
243
+ tsc = compute_scores(target_g)
 
 
 
244
 
245
+ emoji = "βœ…" if solved else "❌"
246
+ score_md = f"""### {emoji} {"Solved!" if solved else "Not yet β€” keep trying"}
247
 
248
+ **Pixel accuracy: {pct:.1f}%** &nbsp;Β·&nbsp; {int(correct.sum())} / {answer_g.size} cells correct
249
 
250
+ | Im-side feature | Your answer | Target |
251
+ |---|---|---|
252
+ | H-symmetry | {sc['h_sym']:.2f} | {tsc['h_sym']:.2f} |
253
+ | V-symmetry | {sc['v_sym']:.2f} | {tsc['v_sym']:.2f} |
254
+ | Boundary pixels | {sc['boundary']} | {tsc['boundary']} |
255
+ | Components | {sc['components']} | {tsc['components']} |
256
+ | Edge magnitude | {sc['edge']:.4f} | {tsc['edge']:.4f} |
257
 
258
+ *These Im-side signals are exactly what the CNN reads from your grid before taking any action.*
259
+ """
260
+ return answer_img, target_img, feat_img, score_md
261
 
262
+ def on_reset(answer_hidden):
263
+ return blank_like(answer_hidden)
264
 
265
+ # ── UI ────────────────────────────────────────────────────────────────────────
266
 
267
+ with gr.Blocks(title="ARC-AGI-3 Puzzle Interface") as demo:
 
 
 
268
 
269
+ # State
270
+ s_train_in = gr.State("")
271
+ s_train_out = gr.State("")
272
+ s_test_in = gr.State("")
273
+ s_answer = gr.State("")
274
 
275
  gr.Markdown("""
276
+ # ARC-AGI-3 Puzzle Interface
277
+ Study the training example, figure out the rule, then complete the test input.
278
+ Your pixel score updates on every submission β€” just like the real competition.
 
 
 
 
279
  """)
280
 
281
  with gr.Row():
282
+ puzzle_dd = gr.Dropdown(choices=list(PUZZLES.keys()),
283
+ value=list(PUZZLES.keys())[0],
284
+ label="Select puzzle", scale=2)
285
+ desc_md = gr.Markdown(scale=3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
+ gr.Markdown("### Training example")
288
+ with gr.Row():
289
+ train_in_img = gr.Image(label="Input", type="pil", interactive=False, height=180)
290
+ train_out_img = gr.Image(label="Output", type="pil", interactive=False, height=180)
291
+ test_in_img = gr.Image(label="Test input β€” what should the output be?",
292
+ type="pil", interactive=False, height=180)
293
 
294
+ gr.Markdown("---\n### Your answer")
295
+ gr.Markdown("Enter space-separated integers 0–9, one row per line. "
296
+ "**Colors:** " + " Β· ".join(f"`{i}`={n}" for i,n in enumerate(COLOR_NAMES)))
 
 
297
 
298
+ with gr.Row():
299
+ with gr.Column(scale=1):
300
+ answer_box = gr.Textbox(label="Answer grid", lines=8,
301
+ placeholder="0 0 0\n0 0 0")
302
+ with gr.Row():
303
+ submit_btn = gr.Button("Submit", variant="primary")
304
+ reset_btn = gr.Button("Reset to zeros")
305
 
306
+ with gr.Column(scale=2):
307
+ with gr.Row():
308
+ answer_img = gr.Image(label="Your answer", type="pil",
309
+ interactive=False, height=200)
310
+ target_img = gr.Image(label="Target", type="pil",
311
+ interactive=False, height=200)
312
+ score_md = gr.Markdown("*Submit your answer to see your score.*")
313
+
314
+ gr.Markdown("---\n### Re/Im feature maps β€” what the CNN reads from your answer")
315
+ feat_img = gr.Image(label="Feature maps", type="pil", interactive=False)
316
+
317
+ # Wire up
318
+ all_outputs = [s_train_in, s_train_out, s_test_in, desc_md, s_answer,
319
+ answer_box,
320
+ train_in_img, train_out_img, test_in_img,
321
+ answer_img, target_img, feat_img, score_md]
322
+
323
+ puzzle_dd.change(on_load_puzzle, inputs=puzzle_dd, outputs=all_outputs)
324
+ submit_btn.click(on_submit,
325
+ inputs=[answer_box, s_answer, s_test_in],
326
+ outputs=[answer_img, target_img, feat_img, score_md])
327
+ reset_btn.click(on_reset, inputs=s_answer, outputs=answer_box)
328
+ demo.load(on_load_puzzle, inputs=puzzle_dd, outputs=all_outputs)
329
 
330
  if __name__ == "__main__":
331
+ demo.launch(css="""
332
+ #answer-box textarea { font-family: monospace; font-size: 13px; }
333
+ """)