beanapologist commited on
Commit
b8a8f34
Β·
verified Β·
1 Parent(s): 292cc35

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +354 -323
app.py CHANGED
@@ -1,6 +1,10 @@
1
  """
2
- ARC-AGI-3 Live Puzzle Interface
3
  Hugging Face Space: beanapologist/arc-agi
 
 
 
 
4
  """
5
 
6
  import gradio as gr
@@ -9,15 +13,75 @@ import matplotlib
9
  matplotlib.use('Agg')
10
  import matplotlib.pyplot as plt
11
  from matplotlib.colors import ListedColormap
12
- import io, json, os
13
  from PIL import Image
14
 
 
 
15
  ARC_HEX = ['#000000','#1a6faf','#e03a3a','#3aa63a','#f5c400',
16
  '#c060c0','#d07030','#aaaaaa','#60b8d0','#874010']
17
  ARC_CMAP = ListedColormap(ARC_HEX)
18
  COLOR_NAMES = ['black','blue','red','green','yellow',
19
  'purple','orange','gray','azure','maroon']
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def _cc(mask):
22
  labels=np.zeros_like(mask,dtype=np.int32); cur=0; H,W=mask.shape
23
  for r in range(H):
@@ -57,340 +121,307 @@ def _boundary(grid):
57
  return ((p[1:-1,1:-1]!=p[:-2,1:-1])|(p[1:-1,1:-1]!=p[2:,1:-1])|
58
  (p[1:-1,1:-1]!=p[1:-1,:-2])|(p[1:-1,1:-1]!=p[1:-1,2:])).astype(np.float32)
59
 
60
- def compute_scores(grid):
61
- h=float(_sym(grid,'h').max()); v=float(_sym(grid,'v').max())
62
- b=int(_boundary(grid).sum())
63
- gx,gy=_sobel(grid.astype(np.float32)/9)
64
- edge=round(float(np.sqrt(gx**2+gy**2).mean()),4)
65
- nc=sum(int(_cc(grid==c).max()) for c in range(10) if (grid==c).any())
66
- return dict(h_sym=h,v_sym=v,boundary=b,edge=edge,components=nc)
67
-
68
- def render_grid(grid,title='',highlight_wrong=None):
69
- H,W=grid.shape; cell=max(28,min(56,280//max(H,W)))
70
- fig,ax=plt.subplots(figsize=((W*cell+4)/72,(H*cell+20)/72),dpi=72)
71
- fig.patch.set_facecolor('#1e1e2e'); ax.set_facecolor('#1e1e2e')
72
- ax.imshow(grid,cmap=ARC_CMAP,vmin=0,vmax=9,interpolation='nearest',aspect='equal')
73
- for x in range(W+1): ax.axvline(x-.5,color='#555',lw=.5)
74
- for y in range(H+1): ax.axhline(y-.5,color='#555',lw=.5)
75
- for r in range(H):
76
- for c in range(W):
77
- v=int(grid[r,c])
78
- col='white' if v in [0,1,2,3,5,6,9] else 'black'
79
- ax.text(c,r,str(v),ha='center',va='center',
80
- fontsize=max(7,cell//5),color=col,fontweight='bold',fontfamily='monospace')
81
- if highlight_wrong is not None and highlight_wrong[r,c]:
82
- ax.add_patch(plt.Rectangle((c-.5,r-.5),1,1,fill=False,edgecolor='#ff4444',lw=2))
83
- ax.set_xlim(-.5,W-.5); ax.set_ylim(H-.5,-.5); ax.axis('off')
84
- if title: ax.set_title(title,color='#cdd6f4',fontsize=9,pad=3)
85
- plt.tight_layout(pad=.3)
86
- buf=io.BytesIO()
87
- plt.savefig(buf,format='png',dpi=72,bbox_inches='tight',facecolor='#1e1e2e')
88
- buf.seek(0); img=Image.open(buf).copy(); plt.close()
89
- return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
- def render_feature_maps(grid):
92
- gx,gy=_sobel(grid.astype(np.float32)/9)
93
- maps=[('H-sym',_sym(grid,'h'),'YlOrRd'),('V-sym',_sym(grid,'v'),'YlOrRd'),
94
- ('Boundary',_boundary(grid),'plasma'),('Edge',np.sqrt(gx**2+gy**2),'hot')]
95
- fig,axes=plt.subplots(1,4,figsize=(12,2.8)); fig.patch.set_facecolor('#1e1e2e')
96
- for ax,(t,d,cm) in zip(axes,maps):
97
- ax.set_facecolor('#0d0d1a')
98
- im=ax.imshow(d,cmap=cm,vmin=0,vmax=max(1,float(d.max())),
99
- interpolation='nearest',aspect='equal')
100
- ax.set_title(t,color='white',fontsize=9,pad=3); ax.axis('off')
101
- plt.colorbar(im,ax=ax,fraction=.046,pad=.04)
102
- plt.tight_layout(pad=1)
103
- buf=io.BytesIO()
104
- plt.savefig(buf,format='png',dpi=100,bbox_inches='tight',facecolor='#1e1e2e')
105
- buf.seek(0); img=Image.open(buf).copy(); plt.close()
106
- return img
107
 
108
- def parse_grid(text):
109
- try:
110
- rows=[]
111
- for line in text.strip().split('\n'):
112
- line=line.strip()
113
- if not line: continue
114
- rows.append([int(x) for x in line.split()])
115
- if not rows: return None,"Empty"
116
- W=len(rows[0])
117
- if any(len(r)!=W for r in rows): return None,"Unequal rows"
118
- g=np.array(rows,dtype=np.int64)
119
- if g.min()<0 or g.max()>9: return None,"Values 0-9 only"
120
- return g,None
121
- except Exception as e: return None,str(e)
122
-
123
- BUILTIN_PUZZLES = {
124
- "Mirror complete": {
125
- "desc": "Complete the right half β€” mirror the left half horizontally.",
126
- "train_in": "0 1 2 3 0 0 0 0\n0 4 0 3 0 0 0 0\n0 1 2 3 0 0 0 0",
127
- "train_out": "0 1 2 3 3 2 1 0\n0 4 0 3 3 0 4 0\n0 1 2 3 3 2 1 0",
128
- "test_in": "0 2 1 0 0 0 0 0\n0 2 3 0 0 0 0 0\n0 0 1 0 0 0 0 0",
129
- "answer": "0 2 1 0 0 1 2 0\n0 2 3 0 0 3 2 0\n0 0 1 0 0 1 0 0",
130
- },
131
- "Color shift +1": {
132
- "desc": "Each nonzero color increases by 1.",
133
- "train_in": "0 1 0\n2 0 3\n0 4 0",
134
- "train_out": "0 2 0\n3 0 4\n0 5 0",
135
- "test_in": "0 3 1\n0 0 2\n4 0 0",
136
- "answer": "0 4 2\n0 0 3\n5 0 0",
137
- },
138
- "Border only": {
139
- "desc": "Keep the outer border, fill interior with 0.",
140
- "train_in": "2 2 2 2 2\n2 2 2 2 2\n2 2 2 2 2\n2 2 2 2 2\n2 2 2 2 2",
141
- "train_out": "2 2 2 2 2\n2 0 0 0 2\n2 0 0 0 2\n2 0 0 0 2\n2 2 2 2 2",
142
- "test_in": "3 3 3 3 3\n3 3 3 3 3\n3 3 3 3 3\n3 3 3 3 3\n3 3 3 3 3",
143
- "answer": "3 3 3 3 3\n3 0 0 0 3\n3 0 0 0 3\n3 0 0 0 3\n3 3 3 3 3",
144
- },
145
- "Gravity down": {
146
- "desc": "Each colored pixel falls to the bottom of its column.",
147
- "train_in": "1 0 2\n0 0 0\n0 3 0",
148
- "train_out": "0 0 0\n0 0 2\n1 3 0",
149
- "test_in": "4 0 0\n0 5 0\n0 0 6",
150
- "answer": "0 0 0\n0 0 0\n4 5 6",
151
- },
152
- "Count to block": {
153
- "desc": "Count the 1s; output an NxN block of 2s at bottom-left.",
154
- "train_in": "0 1 0 1 0\n0 0 0 0 0\n0 0 0 0 0\n0 0 0 0 0\n0 0 0 0 0",
155
- "train_out": "0 0 0 0 0\n0 0 0 0 0\n0 0 0 0 0\n2 2 0 0 0\n2 2 0 0 0",
156
- "test_in": "0 1 0 1 1\n0 0 0 0 0\n0 0 0 0 0\n0 0 0 0 0\n0 0 0 0 0",
157
- "answer": "0 0 0 0 0\n0 0 0 0 0\n2 2 2 0 0\n2 2 2 0 0\n2 2 2 0 0",
158
- },
159
- }
160
-
161
- def make_grid_html(grid_list, selected_color=0):
162
- H=len(grid_list); W=len(grid_list[0])
163
- cell=min(52,max(24,300//max(H,W)))
164
- swatches="".join(
165
- f'<div onclick="selectColor({i})" title="{i}:{COLOR_NAMES[i]}" '
166
- f'style="width:26px;height:26px;background:{ARC_HEX[i]};'
167
- f'border:{"3px solid white" if i==selected_color else "2px solid #555"};'
168
- f'border-radius:4px;cursor:pointer;display:flex;align-items:center;'
169
- f'justify-content:center;color:{"white" if i in [0,1,2,3,5,6,9] else "black"};'
170
- f'font-size:10px;font-weight:bold;">{i}</div>'
171
- for i in range(10))
172
- rows="".join(
173
- "<tr>"+"".join(
174
- f'<td onclick="paintCell({r},{c})" id="cell-{r}-{c}" '
175
- f'style="width:{cell}px;height:{cell}px;background:{ARC_HEX[grid_list[r][c]]};'
176
- f'color:{"white" if grid_list[r][c] in [0,1,2,3,5,6,9] else "black"};'
177
- f'text-align:center;vertical-align:middle;font-size:{max(8,cell//5)}px;'
178
- f'font-weight:bold;font-family:monospace;cursor:pointer;border:1px solid #444;'
179
- f'user-select:none;">{grid_list[r][c]}</td>'
180
- for c in range(W))+"</tr>"
181
- for r in range(H))
182
- return f"""
183
- <div style="font-family:sans-serif;background:#1e1e2e;padding:12px;border-radius:8px;">
184
- <div style="font-size:11px;color:#888;margin-bottom:6px;">Pick color β†’ click cells β†’ Submit</div>
185
- <div style="display:flex;gap:5px;margin-bottom:8px;flex-wrap:wrap;">{swatches}</div>
186
- <table style="border-collapse:collapse;margin-bottom:8px;">{rows}</table>
187
- <div style="display:flex;gap:8px;">
188
- <button onclick="submitGrid()"
189
- style="padding:7px 18px;background:#4a9eff;color:white;border:none;
190
- border-radius:6px;cursor:pointer;font-size:13px;font-weight:bold;">Submit</button>
191
- <button onclick="resetGrid()"
192
- style="padding:7px 14px;background:#333;color:#ccc;border:1px solid #555;
193
- border-radius:6px;cursor:pointer;font-size:13px;">Reset</button>
194
- </div>
195
- </div>
196
- <script>
197
- var gridState={json.dumps(grid_list)};
198
- var selColor={selected_color};
199
- var HEX={json.dumps(ARC_HEX)};
200
- var DARK=[true,true,true,true,false,true,true,false,false,true];
201
- function selectColor(c){{
202
- selColor=c;
203
- document.querySelectorAll('[onclick^="selectColor"]').forEach(function(el,i){{
204
- el.style.border=i===c?'3px solid white':'2px solid #555';
205
- }});
206
- }}
207
- function paintCell(r,c){{
208
- gridState[r][c]=selColor;
209
- var td=document.getElementById('cell-'+r+'-'+c);
210
- td.style.background=HEX[selColor];
211
- td.style.color=DARK[selColor]?'white':'black';
212
- td.textContent=selColor;
213
- }}
214
- function resetGrid(){{
215
- for(var r=0;r<gridState.length;r++)
216
- for(var c=0;c<gridState[r].length;c++){{
217
- gridState[r][c]=0;
218
- var td=document.getElementById('cell-'+r+'-'+c);
219
- td.style.background=HEX[0];td.style.color='white';td.textContent='0';
220
- }}
221
- }}
222
- function submitGrid(){{
223
- var txt=gridState.map(function(row){{return row.join(' ');}}).join('\\n');
224
- var ta=document.querySelector('#grid-state-box textarea');
225
- if(ta){{ta.value=txt;ta.dispatchEvent(new Event('input',{{bubbles:true}}));}}
226
- }}
227
- </script>"""
228
-
229
- def get_arcade(api_key=""):
230
  try:
231
  import arc_agi
232
- key=api_key or os.environ.get("ARC_API_KEY","")
233
- return arc_agi.Arcade(arc_api_key=key),None
234
- except Exception as e: return None,str(e)
235
-
236
- def on_load_builtin(name):
237
- p=BUILTIN_PUZZLES[name]
238
- ti_img=render_grid(parse_grid(p["train_in"])[0], title="Training input")
239
- to_img=render_grid(parse_grid(p["train_out"])[0], title="Training output")
240
- test_img=render_grid(parse_grid(p["test_in"])[0], title="Test input")
241
- ans_g=parse_grid(p["answer"])[0]
242
- blank=np.zeros_like(ans_g).tolist()
243
- return (p["desc"], p["answer"], ti_img, to_img, test_img,
244
- make_grid_html(blank), None, None, None,
245
- "*Paint your answer and click Submit.*")
246
-
247
- def on_grid_submit(grid_txt, answer_hidden):
248
- if not grid_txt or not grid_txt.strip():
249
- return None,None,None,"*Paint something first.*"
250
- g,err=parse_grid(grid_txt)
251
- if err: return None,None,None,f"**Parse error:** {err}"
252
- feat_img=render_feature_maps(g)
253
- sc=compute_scores(g)
254
- if answer_hidden and answer_hidden.strip():
255
- target_g,terr=parse_grid(answer_hidden)
256
- if terr is None and target_g.shape==g.shape:
257
- correct=(g==target_g); pct=correct.mean()*100
258
- solved=bool(correct.all()); wrong=~correct
259
- answer_img=render_grid(g,title="Your answer",highlight_wrong=wrong)
260
- target_img=render_grid(target_g,title="Target")
261
- tsc=compute_scores(target_g)
262
- emoji="βœ…" if solved else "❌"
263
- return answer_img,target_img,feat_img,f"""### {emoji} {"Solved!" if solved else "Not yet!"}
264
- **Pixel accuracy: {pct:.1f}%** Β· {int(correct.sum())}/{g.size} cells correct
265
-
266
- | Feature | Yours | Target |
267
- |---|---|---|
268
- | H-symmetry | {sc['h_sym']:.2f} | {tsc['h_sym']:.2f} |
269
- | V-symmetry | {sc['v_sym']:.2f} | {tsc['v_sym']:.2f} |
270
- | Boundary px | {sc['boundary']} | {tsc['boundary']} |
271
- | Components | {sc['components']} | {tsc['components']} |
272
- """
273
- answer_img=render_grid(g,title="Your answer")
274
- return answer_img,None,feat_img,f"""### Feature analysis
275
- | Feature | Value |
276
- |---|---|
277
- | H-symmetry | {sc['h_sym']:.2f} |
278
- | V-symmetry | {sc['v_sym']:.2f} |
279
- | Boundary pixels | {sc['boundary']} |
280
- | Components | {sc['components']} |
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  """
 
282
 
283
- def on_refresh_games(api_key):
284
- arc,err=get_arcade(api_key)
285
- if err: return gr.Dropdown(choices=[]),f"**Error:** {err}"
286
- try:
287
- envs=arc.get_environments()
288
- ids=[e.game_id for e in envs]
289
- return gr.Dropdown(choices=ids,value=ids[0] if ids else None),f"Found **{len(ids)}** games."
290
- except Exception as e:
291
- return gr.Dropdown(choices=[]),f"**Error:** {e}"
292
 
293
- def on_load_live(game_id,api_key):
294
- if not game_id: return None,None,"Enter a game ID."
295
- arc,err=get_arcade(api_key)
296
- if err: return None,None,f"**API error:** {err}"
297
- try:
298
- env=arc.make(game_id,include_frame_data=True)
299
- frame=env.reset()
300
- raw=np.array(frame.frame,dtype=np.int64)
301
- grid=raw[-1] if raw.ndim==3 else raw
302
- blank=np.zeros_like(grid).tolist()
303
- frame_img=render_grid(grid,title=f"{game_id} β€” current frame")
304
- return frame_img,make_grid_html(blank),(
305
- f"**Game:** `{game_id}` | Levels: {frame.levels_completed} | "
306
- f"Grid: {grid.shape[0]}Γ—{grid.shape[1]}")
307
- except Exception as e:
308
- return None,None,f"**Error:** {e}"
309
-
310
- with gr.Blocks(title="ARC-AGI-3 Puzzle Interface") as demo:
311
 
312
- s_answer=gr.State("")
313
 
314
  gr.Markdown("""
315
- # ARC-AGI-3 Puzzle Interface
316
- Paint cells to solve ARC puzzles. Use **Practice puzzles** to learn,
317
- or connect to the **Live ARC API** with your key from [docs.arcprize.org/api-keys](https://docs.arcprize.org/api-keys).
318
  """)
319
 
320
- with gr.Tabs():
321
-
322
- with gr.Tab("Practice puzzles"):
323
- with gr.Row():
324
- with gr.Column(scale=2):
325
- builtin_dd=gr.Dropdown(choices=list(BUILTIN_PUZZLES.keys()),
326
- value=list(BUILTIN_PUZZLES.keys())[0],
327
- label="Select puzzle")
328
- with gr.Column(scale=3):
329
- builtin_desc=gr.Markdown()
330
-
331
- gr.Markdown("### Training example")
332
- with gr.Row():
333
- b_train_in =gr.Image(label="Input", type="pil",interactive=False,height=160)
334
- b_train_out=gr.Image(label="Output", type="pil",interactive=False,height=160)
335
- b_test_in =gr.Image(label="Test input β€” solve this",
336
- type="pil",interactive=False,height=160)
337
-
338
- gr.Markdown("### Your answer β€” click cells to paint")
339
- b_html=gr.HTML()
340
- b_grid_box=gr.Textbox(label="Grid state",elem_id="grid-state-box",
341
- lines=2,placeholder="Paint above then click Submit")
342
- gr.Markdown("### Score")
343
  with gr.Row():
344
- b_ans_img=gr.Image(label="Your answer",type="pil",interactive=False,height=180)
345
- b_tgt_img=gr.Image(label="Target", type="pil",interactive=False,height=180)
346
- b_score=gr.Markdown("*Paint and click Submit to score.*")
347
- b_feat =gr.Image(label="Re/Im feature maps",type="pil",interactive=False)
348
-
349
- b_outs=[builtin_desc,s_answer,b_train_in,b_train_out,b_test_in,
350
- b_html,b_ans_img,b_tgt_img,b_feat,b_score]
351
- builtin_dd.change(on_load_builtin,inputs=builtin_dd,outputs=b_outs)
352
- b_grid_box.change(on_grid_submit,inputs=[b_grid_box,s_answer],
353
- outputs=[b_ans_img,b_tgt_img,b_feat,b_score])
354
-
355
- with gr.Tab("Live ARC API"):
356
- gr.Markdown("Enter your API key then click **Fetch games** to see available games.")
357
- with gr.Row():
358
- with gr.Column(scale=3):
359
- api_box=gr.Textbox(label="ARC API key",type="password",
360
- value=os.environ.get("ARC_API_KEY",""),
361
- placeholder="arc-key-... (or set ARC_API_KEY as HF Secret)")
362
- with gr.Column(scale=1):
363
- refresh_btn=gr.Button("Fetch games")
364
- with gr.Row():
365
- with gr.Column(scale=2):
366
- live_dd=gr.Dropdown(label="Game ID",choices=[])
367
- with gr.Column(scale=1):
368
- load_btn=gr.Button("Load game",variant="primary")
369
- api_status=gr.Markdown()
370
- live_desc=gr.Markdown()
371
-
372
- gr.Markdown("### Current frame")
373
- l_frame=gr.Image(label="Live frame",type="pil",interactive=False,height=220)
374
-
375
- gr.Markdown("### Paint your answer")
376
- l_html=gr.HTML()
377
- l_grid_box=gr.Textbox(label="Grid state",elem_id="grid-state-box",
378
- lines=2,placeholder="Paint above then click Submit")
379
- with gr.Row():
380
- l_ans_img=gr.Image(label="Your answer",type="pil",interactive=False,height=180)
381
- l_tgt_img=gr.Image(label="Target (if available)",
382
- type="pil",interactive=False,height=180)
383
- l_score=gr.Markdown()
384
- l_feat =gr.Image(label="Re/Im feature maps",type="pil",interactive=False)
385
-
386
- refresh_btn.click(on_refresh_games,inputs=api_box,
387
- outputs=[live_dd,api_status])
388
- load_btn.click(on_load_live,inputs=[live_dd,api_box],
389
- outputs=[l_frame,l_html,live_desc])
390
- l_grid_box.change(on_grid_submit,inputs=[l_grid_box,s_answer],
391
- outputs=[l_ans_img,l_tgt_img,l_feat,l_score])
392
-
393
- demo.load(on_load_builtin,inputs=builtin_dd,outputs=b_outs)
394
 
395
  if __name__ == "__main__":
396
  demo.launch()
 
1
  """
2
+ ARC-AGI-3 Agent Spectator
3
  Hugging Face Space: beanapologist/arc-agi
4
+
5
+ Watch the Re/Im agent explore live ARC-AGI-3 games in real time.
6
+ Enter your API key, pick a game, hit Watch β€” the agent plays and
7
+ the grid updates every step.
8
  """
9
 
10
  import gradio as gr
 
13
  matplotlib.use('Agg')
14
  import matplotlib.pyplot as plt
15
  from matplotlib.colors import ListedColormap
16
+ import io, json, os, time, threading, queue
17
  from PIL import Image
18
 
19
+ # ── Palette ───────────────────────────────────────────────────────────────────
20
+
21
  ARC_HEX = ['#000000','#1a6faf','#e03a3a','#3aa63a','#f5c400',
22
  '#c060c0','#d07030','#aaaaaa','#60b8d0','#874010']
23
  ARC_CMAP = ListedColormap(ARC_HEX)
24
  COLOR_NAMES = ['black','blue','red','green','yellow',
25
  'purple','orange','gray','azure','maroon']
26
 
27
+ # ── Rendering ─────────────────────────────────────────────────────────────────
28
+
29
+ def render_grid(grid, title='', highlight_diff=None):
30
+ if grid is None: return None
31
+ H, W = grid.shape
32
+ cell = max(28, min(60, 360 // max(H, W)))
33
+ fig, ax = plt.subplots(figsize=((W*cell+4)/72, (H*cell+22)/72), dpi=72)
34
+ fig.patch.set_facecolor('#1e1e2e'); ax.set_facecolor('#1e1e2e')
35
+ ax.imshow(grid, cmap=ARC_CMAP, vmin=0, vmax=9,
36
+ interpolation='nearest', aspect='equal')
37
+ for x in range(W+1): ax.axvline(x-.5, color='#444', lw=.5)
38
+ for y in range(H+1): ax.axhline(y-.5, color='#444', lw=.5)
39
+ for r in range(H):
40
+ for c in range(W):
41
+ v = int(grid[r, c])
42
+ col = 'white' if v in [0,1,2,3,5,6,9] else 'black'
43
+ ax.text(c, r, str(v), ha='center', va='center',
44
+ fontsize=max(7, cell//5), color=col,
45
+ fontweight='bold', fontfamily='monospace')
46
+ if highlight_diff is not None and highlight_diff[r, c]:
47
+ ax.add_patch(plt.Rectangle(
48
+ (c-.5,r-.5), 1, 1, fill=True,
49
+ facecolor='#ffffff', alpha=0.25, lw=0))
50
+ ax.set_xlim(-.5, W-.5); ax.set_ylim(H-.5, -.5); ax.axis('off')
51
+ if title:
52
+ ax.set_title(title, color='#cdd6f4', fontsize=10, pad=4)
53
+ plt.tight_layout(pad=.3)
54
+ buf = io.BytesIO()
55
+ plt.savefig(buf, format='png', dpi=72, bbox_inches='tight',
56
+ facecolor='#1e1e2e')
57
+ buf.seek(0); img = Image.open(buf).copy(); plt.close()
58
+ return img
59
+
60
+ def render_action_bar(action_counts, total):
61
+ """Horizontal bar chart of action frequency."""
62
+ if not action_counts or total == 0: return None
63
+ labels = [f"A{k}" for k in sorted(action_counts)]
64
+ vals = [action_counts[k] for k in sorted(action_counts)]
65
+ fig, ax = plt.subplots(figsize=(5, 1.4))
66
+ fig.patch.set_facecolor('#1e1e2e'); ax.set_facecolor('#1e1e2e')
67
+ colors = ['#4a9eff','#e05050','#50c050','#f5c400','#c060c0','#d07030','#60b8d0']
68
+ bars = ax.barh(labels, vals, color=colors[:len(labels)], height=0.6)
69
+ for bar, v in zip(bars, vals):
70
+ ax.text(bar.get_width()+.3, bar.get_y()+bar.get_height()/2,
71
+ str(v), va='center', color='white', fontsize=8)
72
+ ax.set_xlim(0, max(vals)*1.25)
73
+ ax.tick_params(colors='#888', labelsize=8)
74
+ ax.spines[:].set_visible(False)
75
+ ax.set_facecolor('#1e1e2e')
76
+ plt.tight_layout(pad=.4)
77
+ buf = io.BytesIO()
78
+ plt.savefig(buf, format='png', dpi=90, bbox_inches='tight',
79
+ facecolor='#1e1e2e')
80
+ buf.seek(0); img = Image.open(buf).copy(); plt.close()
81
+ return img
82
+
83
+ # ── Minimal inline agent (no file import needed) ──────────────────────────────
84
+
85
  def _cc(mask):
86
  labels=np.zeros_like(mask,dtype=np.int32); cur=0; H,W=mask.shape
87
  for r in range(H):
 
121
  return ((p[1:-1,1:-1]!=p[:-2,1:-1])|(p[1:-1,1:-1]!=p[2:,1:-1])|
122
  (p[1:-1,1:-1]!=p[1:-1,:-2])|(p[1:-1,1:-1]!=p[1:-1,2:])).astype(np.float32)
123
 
124
+ def extract_features_fast(grid, num_colours=10):
125
+ """Lightweight version of the 56-channel extractor for the Space demo."""
126
+ import torch, torch.nn.functional as F
127
+ H, W = grid.shape
128
+ one_hot = np.zeros((num_colours, H, W), dtype=np.float32)
129
+ for c in range(num_colours):
130
+ one_hot[c] = (grid==c).astype(np.float32)
131
+ gx, gy = _sobel(grid.astype(np.float32)/9)
132
+ h_sym = _sym(grid,'h')[np.newaxis]
133
+ v_sym = _sym(grid,'v')[np.newaxis]
134
+ bound = _boundary(grid)[np.newaxis]
135
+ edge = np.sqrt(gx**2+gy**2)[np.newaxis].astype(np.float32)
136
+ stacked = np.concatenate([one_hot, h_sym, v_sym, bound, edge], axis=0)
137
+ t = torch.from_numpy(stacked).float().unsqueeze(0)
138
+ if H != 64 or W != 64:
139
+ t = F.interpolate(t, size=(64,64), mode='bilinear', align_corners=False)
140
+ return t.squeeze(0) # (14, 64, 64)
141
+
142
+ class TinyAgent:
143
+ """Stripped-down CNN agent for the spectator demo."""
144
+ def __init__(self):
145
+ import torch, torch.nn as nn
146
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
147
+ self.model = self._make_model().to(self.device)
148
+ self.opt = torch.optim.Adam(self.model.parameters(), lr=1e-4)
149
+ self.buf = []
150
+ self.prev_feat = None
151
+ self.prev_action = None
152
+ self.step_count = 0
153
+ self.action_counts = {}
154
+
155
+ def _make_model(self):
156
+ import torch.nn as nn
157
+ return nn.Sequential(
158
+ nn.Conv2d(14, 32, 3, padding=1), nn.ReLU(),
159
+ nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
160
+ nn.Conv2d(64, 128,3, padding=1), nn.ReLU(),
161
+ nn.AdaptiveAvgPool2d(8),
162
+ nn.Flatten(),
163
+ nn.Linear(128*8*8, 256), nn.ReLU(),
164
+ nn.Linear(256, 6), # actions 1-6
165
+ )
166
+
167
+ def reset(self):
168
+ import torch, torch.nn as nn
169
+ self.model = self._make_model().to(self.device)
170
+ self.opt = torch.optim.Adam(self.model.parameters(), lr=1e-4)
171
+ self.buf = []; self.prev_feat = None; self.prev_action = None
172
+ self.action_counts = {}
173
+
174
+ def choose(self, grid, available_actions=None):
175
+ import torch, torch.nn.functional as F
176
+ feat = extract_features_fast(grid).to(self.device)
177
+
178
+ # Store experience
179
+ if self.prev_feat is not None:
180
+ changed = not np.array_equal(
181
+ self.prev_feat.cpu().numpy(),
182
+ feat.cpu().numpy())
183
+ self.buf.append((self.prev_feat, self.prev_action, 1.0 if changed else 0.0))
184
+ if len(self.buf) > 500: self.buf.pop(0)
185
+
186
+ # Train every 10 steps
187
+ if self.step_count % 10 == 0 and len(self.buf) >= 16:
188
+ self._train()
189
+
190
+ # Sample action
191
+ with torch.no_grad():
192
+ logits = self.model(feat.unsqueeze(0)).squeeze(0)
193
+ mask = list(range(1,7))
194
+ if available_actions:
195
+ mask = [int(a.value if hasattr(a,'value') else a) for a in available_actions
196
+ if int(a.value if hasattr(a,'value') else a) <= 6]
197
+ indices = [m-1 for m in mask if 1 <= m <= 6]
198
+ masked = torch.full((6,), float('-inf'))
199
+ for i in indices: masked[i] = logits[i]
200
+ probs = torch.softmax(masked, dim=0).cpu().numpy()
201
+ probs = np.nan_to_num(probs, nan=1/len(indices))
202
+ if probs.sum() == 0: probs[indices] = 1/len(indices)
203
+ probs = probs / probs.sum()
204
+ action_idx = np.random.choice(6, p=probs)
205
+
206
+ self.prev_feat = feat
207
+ self.prev_action = action_idx
208
+ self.step_count += 1
209
+ a_id = action_idx + 1
210
+ self.action_counts[a_id] = self.action_counts.get(a_id, 0) + 1
211
+
212
+ from arcengine import GameAction
213
+ return GameAction(a_id), dict(probs=probs.tolist())
214
+
215
+ def _train(self):
216
+ import torch, torch.nn.functional as F
217
+ import random
218
+ batch = random.sample(self.buf, min(16, len(self.buf)))
219
+ states = torch.stack([b[0] for b in batch]).to(self.device)
220
+ actions = torch.tensor([b[1] for b in batch], dtype=torch.long, device=self.device)
221
+ rewards = torch.tensor([b[2] for b in batch], dtype=torch.float32, device=self.device)
222
+ self.opt.zero_grad()
223
+ logits = self.model(states)
224
+ loss = F.binary_cross_entropy_with_logits(
225
+ logits.gather(1, actions.unsqueeze(1)).squeeze(1), rewards)
226
+ loss.backward(); self.opt.step()
227
+
228
+
229
+ # ── Session state ─────────────────────────────────────────────────────────────
230
+ # One agent per Space instance β€” shared across users viewing the demo.
231
+ _agent = TinyAgent()
232
+ _stop_flag = threading.Event()
233
+ _run_thread = None
234
+ _frame_queue = queue.Queue(maxsize=30)
235
+
236
+ def _run_agent(game_id, api_key, max_steps):
237
+ """Background thread: run agent, push frames to queue."""
238
+ import arc_agi
239
+ from arcengine import GameState
240
+ try:
241
+ arc = arc_agi.Arcade(arc_api_key=api_key)
242
+ env = arc.make(game_id, include_frame_data=True)
243
+ frame = env.reset()
244
+ _agent.reset()
245
+ prev_grid = None
246
+ step = 0
247
+
248
+ while not _stop_flag.is_set() and step < max_steps:
249
+ if frame is None: break
250
+ raw = np.array(frame.frame, dtype=np.int64)
251
+ grid = raw[-1] if raw.ndim == 3 else raw
252
+
253
+ avail = getattr(frame, 'available_actions', None)
254
+ action, info = _agent.choose(grid, avail)
255
+
256
+ diff = (grid != prev_grid) if prev_grid is not None else None
257
+ prev_grid = grid.copy()
258
+
259
+ state_str = str(getattr(frame, 'state', ''))
260
+ levels = getattr(frame, 'levels_completed', 0)
261
+
262
+ _frame_queue.put({
263
+ 'grid': grid,
264
+ 'diff': diff,
265
+ 'step': step,
266
+ 'action': int(action.value),
267
+ 'levels': levels,
268
+ 'state': state_str,
269
+ 'probs': info['probs'],
270
+ 'counts': dict(_agent.action_counts),
271
+ }, block=True, timeout=5)
272
+
273
+ if 'WIN' in state_str or 'GAME_OVER' in state_str:
274
+ break
275
+
276
+ frame = env.step(action)
277
+ step += 1
278
+ time.sleep(0.05) # ~20 fps max
279
+
280
+ _frame_queue.put({'done': True, 'step': step})
281
+ except Exception as e:
282
+ _frame_queue.put({'error': str(e)})
283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
 
285
+ # ── Gradio handlers ───────────────────────────────────────────────────────────
286
+
287
+ def fetch_games(api_key):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  try:
289
  import arc_agi
290
+ arc = arc_agi.Arcade(arc_api_key=api_key)
291
+ envs = arc.get_environments()
292
+ ids = [e.game_id for e in envs]
293
+ return gr.Dropdown(choices=ids, value=ids[0] if ids else None), \
294
+ f"Found **{len(ids)}** games."
295
+ except Exception as e:
296
+ return gr.Dropdown(choices=[]), f"**Error:** {e}"
297
+
298
+ def start_agent(game_id, api_key, max_steps):
299
+ global _run_thread, _stop_flag
300
+ if not game_id: return "Select a game first."
301
+ if not api_key: return "Enter your API key first."
302
+ _stop_flag.set()
303
+ if _run_thread and _run_thread.is_alive():
304
+ _run_thread.join(timeout=3)
305
+ while not _frame_queue.empty():
306
+ try: _frame_queue.get_nowait()
307
+ except: break
308
+ _stop_flag.clear()
309
+ _run_thread = threading.Thread(
310
+ target=_run_agent,
311
+ args=(game_id, api_key, int(max_steps)),
312
+ daemon=True)
313
+ _run_thread.start()
314
+ return f"Agent started on **{game_id}** for {int(max_steps)} steps."
315
+
316
+ def stop_agent():
317
+ _stop_flag.set()
318
+ return "Agent stopped."
319
+
320
+ def stream_frames():
321
+ """Generator: yield (grid_img, bar_img, status_md) for each frame."""
322
+ while True:
323
+ try:
324
+ data = _frame_queue.get(timeout=1)
325
+ except queue.Empty:
326
+ yield None, None, "*Waiting for agent...*"
327
+ continue
328
+
329
+ if 'error' in data:
330
+ yield None, None, f"**Error:** {data['error']}"
331
+ return
332
+ if data.get('done'):
333
+ yield None, None, f"**Done** β€” {data['step']} steps completed."
334
+ return
335
+
336
+ grid = data['grid']
337
+ diff = data['diff']
338
+ step = data['step']
339
+ action_id = data['action']
340
+ levels = data['levels']
341
+ state_str = data['state']
342
+ counts = data['counts']
343
+ probs = data['probs']
344
+
345
+ grid_img = render_grid(grid,
346
+ title=f"Step {step} | Action {action_id} | Levels {levels}",
347
+ highlight_diff=diff)
348
+ bar_img = render_action_bar(counts, sum(counts.values()))
349
+
350
+ action_names = {1:'A1',2:'A2',3:'A3',4:'A4',5:'A5',6:'A6(click)'}
351
+ prob_str = " ".join(
352
+ f"**{action_names.get(i+1,str(i+1))}** {p:.2f}"
353
+ for i,p in enumerate(probs))
354
+
355
+ status = f"""**Step:** {step} &nbsp;|&nbsp; **Action:** {action_id} &nbsp;|&nbsp; **Levels:** {levels} &nbsp;|&nbsp; **State:** {state_str}
356
+
357
+ Action probabilities: {prob_str}
358
  """
359
+ yield grid_img, bar_img, status
360
 
 
 
 
 
 
 
 
 
 
361
 
362
+ # ── UI ────────────────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
 
364
+ with gr.Blocks(title="ARC-AGI-3 Agent Spectator") as demo:
365
 
366
  gr.Markdown("""
367
+ # ARC-AGI-3 Agent Spectator
368
+ Watch the Re/Im CNN agent explore live ARC-AGI-3 games in real time.
369
+ Get your API key at [docs.arcprize.org/api-keys](https://docs.arcprize.org/api-keys).
370
  """)
371
 
372
+ with gr.Row():
373
+ with gr.Column(scale=3):
374
+ api_box = gr.Textbox(
375
+ label="ARC API key",
376
+ type="password",
377
+ value=os.environ.get("ARC_API_KEY",""),
378
+ placeholder="arc-key-... (or set ARC_API_KEY as HF Space secret)")
379
+ with gr.Column(scale=1):
380
+ fetch_btn = gr.Button("Fetch games")
381
+
382
+ with gr.Row():
383
+ with gr.Column(scale=2):
384
+ game_dd = gr.Dropdown(label="Game", choices=[])
385
+ with gr.Column(scale=1):
386
+ steps_sl = gr.Slider(label="Max steps", minimum=20,
387
+ maximum=500, value=100, step=10)
388
+ with gr.Column(scale=1):
 
 
 
 
 
 
389
  with gr.Row():
390
+ start_btn = gr.Button("β–Ά Watch", variant="primary")
391
+ stop_btn = gr.Button("β–  Stop")
392
+
393
+ api_status = gr.Markdown()
394
+ run_status = gr.Markdown("*Press Fetch games, select a game, then Watch.*")
395
+
396
+ gr.Markdown("---")
397
+
398
+ with gr.Row():
399
+ grid_img = gr.Image(label="Current frame", type="pil",
400
+ interactive=False, height=320)
401
+ bar_img = gr.Image(label="Action frequency", type="pil",
402
+ interactive=False, height=320)
403
+
404
+ stream_btn = gr.Button("⟳ Refresh frame", variant="secondary")
405
+
406
+ fetch_btn.click(fetch_games, inputs=api_box,
407
+ outputs=[game_dd, api_status])
408
+ start_btn.click(start_agent, inputs=[game_dd, api_box, steps_sl],
409
+ outputs=run_status)
410
+ stop_btn.click(stop_agent, outputs=run_status)
411
+ stream_btn.click(
412
+ lambda: next(stream_frames()),
413
+ outputs=[grid_img, bar_img, run_status])
414
+
415
+ gr.Markdown("""
416
+ ---
417
+ **How it works:** The agent encodes each frame as 14 feature channels
418
+ (10 one-hot colors + H-symmetry + V-symmetry + boundary contour + edge magnitude)
419
+ and feeds them through a tiny CNN. It learns online: reward = 1 if the action
420
+ changed the frame, 0 if not. The action frequency chart shows which actions
421
+ the CNN is favouring as it learns.
422
+
423
+ For the full 56-channel extractor used in the Kaggle submission, see the agent code.
424
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425
 
426
  if __name__ == "__main__":
427
  demo.launch()