| |
| |
|
|
| import json, os, sys, time, itertools, hashlib |
| from pathlib import Path |
| from collections import defaultdict, Counter |
| from typing import Optional |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
|
|
| |
| |
| |
|
|
# --- Time budgets, in seconds (compared against time.time() deltas) ---------
TTT_TIME_BUDGET = 120               # default per-task budget of TTTSolver.solve
DSL_TIME_BUDGET = 10                # default per-task budget of DSLSolver.solve
TOTAL_TIME_BUDGET = 11 * 60 * 60    # whole-run cut-off checked in main()
TTT_MAX_GRID_SIZE = 30              # grids are clipped/padded to this side length

# --- Test-time-training transformer hyper-parameters -------------------------
TTT_D_MODEL = 64
TTT_NHEAD = 2
TTT_NUM_LAYERS = 2
TTT_DIM_FEEDFORWARD = 256
TTT_LR = 1e-3
TTT_BATCH_SIZE = 16
TTT_MIN_EPOCHS = 20
TTT_MAX_EPOCHS = 100

# --- Token vocabulary: cell colours 0-9 plus three control tokens ------------
PAD_TOKEN = 10
ROW_SEP = 11
EOS_TOKEN = 12
VOCAB_SIZE = 13
|
|
| |
| |
| |
|
|
def encode_grid(grid, max_size=TTT_MAX_GRID_SIZE):
    """Serialise a 2-D grid into a fixed-layout list of integer tokens.

    The grid is clipped to at most ``max_size`` rows and columns.  Each kept
    row is written as its cell values, right-padded with ``PAD_TOKEN`` up to
    ``max_size`` entries and terminated by ``ROW_SEP``.  Missing rows are
    written as all-``PAD_TOKEN`` rows, and one ``EOS_TOKEN`` ends the list,
    so the result always has ``max_size * (max_size + 1) + 1`` tokens.

    Args:
        grid: 2-D array-like of integer cell values.
        max_size: side length of the padded layout.

    Returns:
        list[int]: the token sequence.
    """
    cells = np.array(grid, dtype=np.int32)
    n_rows = min(cells.shape[0], max_size)
    n_cols = min(cells.shape[1], max_size)

    row_tail = [PAD_TOKEN] * (max_size - n_cols) + [ROW_SEP]
    blank_row = [PAD_TOKEN] * max_size + [ROW_SEP]

    tokens = []
    for row in cells[:n_rows]:
        tokens += [int(value) for value in row[:n_cols]]
        tokens += row_tail
    tokens += blank_row * (max_size - n_rows)
    tokens.append(EOS_TOKEN)
    return tokens
|
|
def decode_grid(tokens, h, w):
    """Rebuild an ``h`` x ``w`` grid from a token sequence.

    Tokens are read up to (not including) the first ``EOS_TOKEN``.
    ``ROW_SEP`` moves to the start of the next row; every other token,
    including ``PAD_TOKEN``, occupies one column.  Non-pad tokens that fall
    inside the ``h`` x ``w`` window are stored, clamped to at most 9.

    Tokens beyond column ``w`` are ignored until the next ``ROW_SEP``.  The
    column counter used to wrap back to 0 at ``w``, which let surplus tokens
    in an over-long row overwrite that row's first cells.

    Args:
        tokens: iterable of integer token ids.
        h: number of rows of the grid to produce.
        w: number of columns of the grid to produce.

    Returns:
        np.ndarray: int32 array of shape ``(h, w)``; cells never written
        stay 0.
    """
    grid = np.zeros((h, w), dtype=np.int32)
    r = c = 0
    for t in tokens:
        if t == EOS_TOKEN:
            break
        if t == ROW_SEP:
            r += 1
            c = 0
            if r >= h:
                break
            continue
        # ``r < h`` still matters when h == 0 (no row is ever in range).
        if t != PAD_TOKEN and r < h and c < w:
            grid[r, c] = min(t, 9)
        c += 1
    return grid
|
|
| |
| |
| |
|
|
def find_objects(grid, bg=0):
    """Find 4-connected single-colour components of a grid.

    Cells equal to ``bg`` are ignored.  Components are discovered in
    row-major scan order and each component's pixels are listed in
    breadth-first order from the first cell found.

    Args:
        grid: 2-D array-like of cell colours (lists are accepted).
        bg: background value that never belongs to an object.

    Returns:
        list[dict]: one dict per component with keys ``'color'`` (int),
        ``'pixels'`` (list of ``(row, col)`` tuples), ``'bbox'``
        (``(row_start, row_stop, col_start, col_stop)``, stops exclusive)
        and ``'size'`` (pixel count).  Empty input yields ``[]``.
    """
    grid = np.asarray(grid)
    if grid.size == 0:
        return []
    h, w = grid.shape
    visited = np.zeros((h, w), dtype=bool)
    objects = []
    for r in range(h):
        for c in range(w):
            if visited[r, c] or grid[r, c] == bg:
                continue
            color = grid[r, c]
            visited[r, c] = True
            # ``pixels`` doubles as the FIFO queue: ``head`` walks it while
            # new cells are appended, so each dequeue is O(1) instead of the
            # O(n) ``list.pop(0)``, and the pixel order is unchanged.
            pixels = [(r, c)]
            head = 0
            while head < len(pixels):
                cr, cc = pixels[head]
                head += 1
                for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                    nr, nc = cr + dr, cc + dc
                    if (0 <= nr < h and 0 <= nc < w
                            and not visited[nr, nc] and grid[nr, nc] == color):
                        visited[nr, nc] = True
                        pixels.append((nr, nc))
            rows = [p[0] for p in pixels]
            cols = [p[1] for p in pixels]
            objects.append({
                'color': int(color),
                'pixels': pixels,
                'bbox': (min(rows), max(rows) + 1, min(cols), max(cols) + 1),
                'size': len(pixels),
            })
    return objects
|
|
def extract_object_mask(grid, obj):
    """Cut one object out of ``grid``, blanking everything else.

    Args:
        grid: 2-D array the object was found in.
        obj: dict with ``'pixels'`` (``(row, col)`` pairs) and ``'bbox'``
            (``(row_start, row_stop, col_start, col_stop)``), in the form
            produced by ``find_objects``.

    Returns:
        np.ndarray: the bounding-box window of a zero canvas onto which only
        the object's own pixels were copied.
    """
    top, bottom, left, right = obj['bbox']
    canvas = np.zeros_like(grid)
    coords = np.array(obj['pixels'], dtype=np.intp).reshape(-1, 2)
    rows, cols = coords[:, 0], coords[:, 1]
    canvas[rows, cols] = grid[rows, cols]
    return canvas[top:bottom, left:right]
|
|
| |
| |
| |
|
|
def _bbox(grid):
    """Return the bounding box of the non-zero cells of a 2-D grid.

    Returns:
        tuple: ``(row_start, row_stop, col_start, col_stop)`` with exclusive
        stops.  When the grid has no non-zero cell the full extent
        ``(0, rows, 0, cols)`` is returned.
    """
    filled = np.nonzero(grid != 0)
    if filled[0].size == 0:
        return 0, grid.shape[0], 0, grid.shape[1]
    rows, cols = filled[0], filled[1]
    return rows.min(), rows.max() + 1, cols.min(), cols.max() + 1
|
|
def _crop(grid):
    """Return the slice of ``grid`` covered by ``_bbox(grid)``."""
    top, bottom, left, right = _bbox(grid)
    return grid[top:bottom, left:right]
|
|
def _pad_sq(grid):
    """Zero-pad a 2-D grid on the bottom/right so it becomes square.

    The original content stays anchored at the top-left corner and the dtype
    is preserved; the side length is the larger of the two dimensions.
    """
    rows, cols = grid.shape
    side = max(rows, cols)
    return np.pad(
        grid,
        ((0, side - rows), (0, side - cols)),
        mode="constant",
        constant_values=0,
    )
|
|
def _border(grid):
    """Keep only the outermost ring of cells, then crop to non-zero content.

    The interior is zeroed on a copy only when both dimensions exceed 2;
    smaller grids are passed to ``_crop`` unchanged.
    """
    outline = grid.copy()
    n_rows, n_cols = outline.shape[0], outline.shape[1]
    if n_rows > 2 and n_cols > 2:
        outline[1:-1, 1:-1] = 0
    return _crop(outline)
|
|
def _fill_holes(grid):
    """Fill enclosed zero regions with the colour of an adjacent cell.

    A zero cell is "exterior" when it is 4-connected, through zero cells, to
    any zero cell on the grid border.  Every other zero cell is a hole and
    takes the value of its first non-zero neighbour, checked in the order
    up, down, left, right.  Hole cells with no non-zero neighbour stay 0.
    Grids with fewer than 3 rows or columns are returned as a plain copy.

    Border seeds use explicit ``h - 1`` / ``w - 1`` coordinates.  The earlier
    ``-1`` indices failed the ``0 <= n`` bounds test during expansion, so the
    flood never spread inward from the bottom or right edge and cavities
    open on those sides were filled as if they were holes.

    Args:
        grid: 2-D integer array.

    Returns:
        np.ndarray: a new array; ``grid`` is not modified.
    """
    h, w = grid.shape
    if h < 3 or w < 3:
        return grid.copy()

    steps = ((-1, 0), (1, 0), (0, -1), (0, 1))
    empty = grid == 0
    ext = np.zeros((h, w), dtype=bool)

    border = (
        [(0, c) for c in range(w)]
        + [(h - 1, c) for c in range(w)]
        + [(r, 0) for r in range(h)]
        + [(r, w - 1) for r in range(h)]
    )
    # Visit order does not affect which cells get marked, so an O(1) stack
    # replaces the O(n) ``list.pop(0)`` queue.
    stack = []
    for cell in border:
        if empty[cell] and not ext[cell]:
            ext[cell] = True
            stack.append(cell)
    while stack:
        cr, cc = stack.pop()
        for dr, dc in steps:
            nr, nc = cr + dr, cc + dc
            if 0 <= nr < h and 0 <= nc < w and empty[nr, nc] and not ext[nr, nc]:
                ext[nr, nc] = True
                stack.append((nr, nc))

    result = grid.copy()
    for r, c in np.argwhere(empty & ~ext):
        for dr, dc in steps:
            nr, nc = r + dr, c + dc
            # Colours are read from the untouched input, never from cells
            # filled earlier in this pass.
            if 0 <= nr < h and 0 <= nc < w and grid[nr, nc] != 0:
                result[r, c] = grid[nr, nc]
                break
    return result
|
|
def _scale(grid, th, tw):
    """Resample a 2-D grid to ``th`` x ``tw`` by nearest-source lookup.

    Target cell ``(i, j)`` copies source cell
    ``(int(i * h / th), int(j * w / tw))``, each index capped at the last
    valid row/column.  The result keeps the input dtype.
    """
    h, w = grid.shape
    # Source indices depend on one axis only, so compute them once per axis.
    src_rows = [min(int(i * h / th), h - 1) for i in range(th)]
    src_cols = [min(int(j * w / tw), w - 1) for j in range(tw)]

    scaled = np.zeros((th, tw), dtype=grid.dtype)
    for i, sr in enumerate(src_rows):
        for j, sc in enumerate(src_cols):
            scaled[i, j] = grid[sr, sc]
    return scaled
|
|
def _color_map_fn(grid, mapping):
    """Apply a colour substitution table to a copy of ``grid``.

    Masks are taken from the original ``grid``, so substitutions do not
    chain: ``{1: 2, 2: 1}`` swaps the two colours.

    Args:
        grid: integer array; not modified.
        mapping: dict of source colour -> replacement colour.

    Returns:
        np.ndarray: the recoloured copy.  An entry whose replacement cannot
        be stored in the array (wrong type, out of range) is skipped.
    """
    result = grid.copy()
    for source, target in mapping.items():
        try:
            result[grid == source] = target
        except (TypeError, ValueError, OverflowError, IndexError):
            # Best-effort: skip unusable entries.  The previous bare
            # ``except`` also swallowed KeyboardInterrupt and SystemExit.
            continue
    return result
|
|
def _find_cmap(train_in, train_out):
    """Infer a colour substitution table from aligned input/output grids.

    For every non-zero input colour, the output cells at the same positions
    are inspected.  The colour is recorded only when those cells hold one
    single non-zero value.

    Returns:
        dict | None: the colour -> colour table, or ``None`` when any pair
        differs in shape, when two pairs disagree about a colour, or when
        nothing could be recorded.
    """
    cmap = {}
    for source, target in zip(train_in, train_out):
        if source.shape != target.shape:
            return None
        for color in np.unique(source):
            if color == 0:
                continue
            images = np.unique(target[source == color])
            if len(images) != 1 or images[0] == 0:
                continue
            image = images[0]
            if color in cmap and cmap[color] != image:
                return None
            cmap[color] = image
    return cmap or None
|
|
def _consistent(fn, train_in, train_out):
    """Check that ``fn`` reproduces every training output exactly.

    Args:
        fn: callable taking one grid and returning a grid (or ``None``).
        train_in: sequence of input grids.
        train_out: sequence of expected output grids, aligned with
            ``train_in``.

    Returns:
        bool: ``True`` only if, for every pair, ``fn(input)`` is not ``None``
        and is array-equal to the expected output.  An ordinary exception
        raised by ``fn`` counts as a mismatch.  With no pairs the result is
        ``True``.
    """
    for inp, out in zip(train_in, train_out):
        try:
            result = fn(inp)
            if result is None or not np.array_equal(result, out):
                return False
        except Exception:
            # ``Exception`` rather than a bare ``except`` so that
            # KeyboardInterrupt and SystemExit still propagate.
            return False
    return True
|
|
def _safe(fn, grid):
    """Call ``fn(grid)`` and return its result, or ``None`` if it raises.

    Only ordinary exceptions are suppressed; KeyboardInterrupt and
    SystemExit propagate (the previous bare ``except`` swallowed them too).
    """
    try:
        return fn(grid)
    except Exception:
        return None
|
|
def _invert_colors(grid):
    """Reverse the ranking of the non-zero colours present in ``grid``.

    The distinct non-zero values are sorted; the smallest is swapped with
    the largest, the second smallest with the second largest, and so on.
    Zero cells are untouched and a new array is returned.
    """
    inverted = grid.copy()
    palette = sorted({value for value in grid.flat if value != 0})
    for old, new in zip(palette, reversed(palette)):
        # Masks come from the original grid so swaps do not cascade.
        inverted[grid == old] = new
    return inverted
|
|
def _diag_flip(grid):
    """Return a freshly allocated transpose of a 2-D grid (same dtype)."""
    n_rows, n_cols = grid.shape
    flipped = np.zeros((n_cols, n_rows), dtype=grid.dtype)
    flipped[:, :] = grid.T
    return flipped
|
|
def _mirror_h(grid):
    """Complete left/right symmetry by copying into empty (zero) cells.

    Each cell in the left half is paired with its mirror image in the right
    half.  If the right cell is 0 it takes the left value; otherwise, if the
    left cell is 0 it takes the right value.  Pairs where both are non-zero
    are left alone, as is the middle column of an odd-width grid.
    """
    h, w = grid.shape
    out = grid.copy()
    half = w // 2
    left = out[:, :half]              # view: columns 0 .. half-1
    right = out[:, ::-1][:, :half]    # view: columns w-1 .. w-half
    # Both masks are computed before any write, matching the per-pair
    # if/elif of a cell-by-cell loop.
    fill_right = right == 0
    fill_left = (left == 0) & ~fill_right
    right[fill_right] = left[fill_right]
    left[fill_left] = right[fill_left]
    return out
|
|
def _mirror_v(grid):
    """Complete top/bottom symmetry by copying into empty (zero) cells.

    Each cell in the top half is paired with its mirror image in the bottom
    half.  If the bottom cell is 0 it takes the top value; otherwise, if the
    top cell is 0 it takes the bottom value.  Pairs where both are non-zero
    are left alone, as is the middle row of an odd-height grid.
    """
    h, w = grid.shape
    out = grid.copy()
    half = h // 2
    top = out[:half, :]               # view: rows 0 .. half-1
    bottom = out[::-1, :][:half, :]   # view: rows h-1 .. h-half
    # Both masks are computed before any write, matching the per-pair
    # if/elif of a cell-by-cell loop.
    fill_bottom = bottom == 0
    fill_top = (top == 0) & ~fill_bottom
    bottom[fill_bottom] = top[fill_bottom]
    top[fill_top] = bottom[fill_top]
    return out
|
|
def _largest_object(grid):
    """Return the bounding-box crop of the largest object in ``grid``.

    Objects come from ``find_objects`` and the one with the greatest
    ``'size'`` wins (the first such object on ties).  The crop contains only
    that object's pixels on a zero background.  A grid without objects is
    returned as a copy.
    """
    objs = find_objects(grid)
    if not objs:
        return grid.copy()
    biggest = max(objs, key=lambda o: o['size'])
    return extract_object_mask(grid, biggest)
|
|
def _color_histogram(grid):
    """Return a 1-row grid of per-colour cell counts.

    Non-zero colours are ordered ascending; each entry is that colour's
    occurrence count, capped at 9.  A grid with no non-zero cell yields
    ``np.array([[0]])`` (default integer dtype); otherwise the row is int32.
    """
    tally = Counter(value for value in grid.flat if value != 0)
    if not tally:
        return np.array([[0]])
    capped = [min(tally[color], 9) for color in sorted(tally)]
    return np.array([capped], dtype=np.int32)
|
|
def _unique_colors(grid):
    """Return a 1-row grid listing the distinct non-zero colours, ascending.

    A grid with no non-zero cell yields ``np.array([[0]])`` (default integer
    dtype); otherwise the row is int32.
    """
    palette = sorted({value for value in grid.flat if value != 0})
    if not palette:
        return np.array([[0]])
    return np.array([palette], dtype=np.int32)
|
|
# Candidate single-step grid transforms as (name, callable) pairs.  Each
# callable takes one 2-D array and returns a new array.  DSLSolver tries
# them in this order, so earlier entries win when several fit.
PRIMITIVES = [
    # Identity and rigid motions.
    ("id", lambda grid: grid.copy()),
    ("rot90", np.rot90),
    ("rot180", lambda grid: np.rot90(grid, 2)),
    ("rot270", lambda grid: np.rot90(grid, 3)),
    ("fliplr", np.fliplr),
    ("flipud", np.flipud),
    ("transpose", lambda grid: grid.T.copy()),
    ("diag_flip", _diag_flip),
    # Cropping, padding and structural edits.
    ("crop", _crop),
    ("pad_sq", _pad_sq),
    ("border", _border),
    ("fill_holes", _fill_holes),
    ("mirror_h", _mirror_h),
    ("mirror_v", _mirror_v),
    # Colour-based transforms.
    ("invert_colors", _invert_colors),
    ("unique_colors", _unique_colors),
    # Resampling to a fixed target size (the arguments are target rows/cols).
    ("scale2x2", lambda grid: _scale(grid, 2, 2)),
    ("scale3x3", lambda grid: _scale(grid, 3, 3)),
    ("scale5x5", lambda grid: _scale(grid, 5, 5)),
    # Hand-picked two-step combinations.
    ("rot90_crop", lambda grid: _crop(np.rot90(grid))),
    ("crop_rot90", lambda grid: np.rot90(_crop(grid))),
    ("fliplr_crop", lambda grid: _crop(np.fliplr(grid))),
    ("flipud_crop", lambda grid: _crop(np.flipud(grid))),
    ("transpose_crop", lambda grid: _crop(grid.T.copy())),
    ("crop_pad", lambda grid: _pad_sq(_crop(grid))),
    ("border_fliplr", lambda grid: np.fliplr(_border(grid))),
    ("fill_crop", lambda grid: _crop(_fill_holes(grid))),
    ("border_rot90", lambda grid: np.rot90(_border(grid))),
    # Object-level transforms.
    ("largest_obj", _largest_object),
    ("color_hist", _color_histogram),
    ("largest_fliplr", lambda grid: np.fliplr(_largest_object(grid))),
    ("largest_rot90", lambda grid: np.rot90(_largest_object(grid))),
]
|
|
class DSLSolver:
    """Search a small transform vocabulary for one that explains a task.

    A candidate is accepted only if it maps every training input exactly
    onto its training output.
    """

    # Transforms that are paired up (in both orders) for two-step search.
    _COMPOSABLE = [
        ("rot90", lambda g: np.rot90(g)),
        ("rot180", lambda g: np.rot90(g, 2)),
        ("fliplr", lambda g: np.fliplr(g)),
        ("flipud", lambda g: np.flipud(g)),
        ("transpose", lambda g: g.T.copy()),
        ("crop", _crop),
        ("pad_sq", _pad_sq),
        ("border", _border),
        ("fill_holes", _fill_holes),
        ("mirror_h", _mirror_h),
        ("mirror_v", _mirror_v),
    ]

    @staticmethod
    def _cmap_explains(cmap, train_in, train_out):
        """Return True if ``cmap`` reproduces every training output."""
        return all(
            np.array_equal(_color_map_fn(inp, cmap), out)
            for inp, out in zip(train_in, train_out)
        )

    def solve(self, train_pairs, test_inputs, budget=DSL_TIME_BUDGET):
        """Predict test outputs with the first transform that fits.

        Stages, in order:
          1. each entry of ``PRIMITIVES`` (until 30% of ``budget`` is used);
          2. a colour map with at least two entries, when there are at least
             two training pairs;
          3. ordered pairs of distinct ``_COMPOSABLE`` transforms (until 85%
             of ``budget`` is used);
          4. any colour map.

        Args:
            train_pairs: list of dicts with ``"input"`` and ``"output"``
                grids.
            test_inputs: list of test input grids (array-likes).
            budget: time allowance in seconds.

        Returns:
            list: one entry per test input, each an ndarray or ``None``.
            A fully non-``None`` list means a transform matched all training
            pairs.  Partial results left by stages 1 or 3 may be returned
            when nothing fits completely.
        """
        started = time.time()
        train_in = [np.array(p["input"]) for p in train_pairs]
        train_out = [np.array(p["output"]) for p in train_pairs]
        tests = [np.array(t) for t in test_inputs]
        preds = [None] * len(tests)

        # Stage 1: single primitives.
        for _name, fn in PRIMITIVES:
            if time.time() - started > budget * 0.3:
                break
            if _consistent(fn, train_in, train_out):
                preds = [_safe(fn, t) for t in tests]
                if all(p is not None for p in preds):
                    return preds

        # Stage 2: multi-entry colour map.
        if len(train_pairs) >= 2:
            cmap = _find_cmap(train_in, train_out)
            if cmap and len(cmap) >= 2:
                if self._cmap_explains(cmap, train_in, train_out):
                    return [_color_map_fn(t, cmap) for t in tests]
                preds = [None] * len(tests)

        # Stage 3: two-step compositions.
        for (n1, f1), (n2, f2) in itertools.product(self._COMPOSABLE, repeat=2):
            if n1 == n2:
                continue
            if time.time() - started > budget * 0.85:
                break

            def comp(g, _first=f1, _second=f2):
                return _safe(_second, _first(g))

            if _consistent(comp, train_in, train_out):
                # ``_safe`` also guards the first stage, which could raise on
                # a test grid even though it worked on every training grid.
                preds = [_safe(comp, t) for t in tests]
                if all(p is not None for p in preds):
                    return preds

        # Stage 4: any colour map, but only if it reproduces the training
        # outputs.  Previously its predictions were returned even when that
        # check failed, so callers saw a complete answer and skipped the
        # remaining solvers.
        cmap = _find_cmap(train_in, train_out)
        if cmap and self._cmap_explains(cmap, train_in, train_out):
            return [_color_map_fn(t, cmap) for t in tests]
        return preds
|
|
class ObjectSolver:
    """Object-level heuristics built on ``find_objects``."""

    @staticmethod
    def _largest_crop(grid):
        """Return the crop of the largest object, or None without objects."""
        objs = find_objects(grid)
        if not objs:
            return None
        return extract_object_mask(grid, max(objs, key=lambda o: o['size']))

    @staticmethod
    def _color_row(grid, width):
        """Return a 1-row grid of the sorted object colours, zero-padded.

        The row has ``max(width, number_of_colours)`` columns.
        """
        colors = sorted({o['color'] for o in find_objects(grid)})
        row = np.zeros((1, max(width, len(colors))), dtype=np.int32)
        row[0, :len(colors)] = colors
        return row

    def solve(self, train_pairs, test_inputs, budget=30):
        """Predict test outputs with one of two object rules.

        Rule 1: the output is the bounding-box crop of the largest object.
        Rule 2: every training output is a single row of the same width, and
        the output is the sorted list of object colours, zero-padded to that
        width.

        Both rules are used only when they reproduce every training output.
        Rule 2 used to be applied unchecked, so any task with 1-row outputs
        was reported as solved.

        Args:
            train_pairs: list of dicts with ``"input"`` and ``"output"``
                grids.
            test_inputs: list of test input grids (lists or arrays).
            budget: accepted for signature parity with the other solvers;
                this method does not consult it.

        Returns:
            list: one ndarray per test input if a rule fits, otherwise a
            list of ``None`` of the same length.
        """
        train_in = [np.array(p["input"]) for p in train_pairs]
        train_out = [np.array(p["output"]) for p in train_pairs]
        tests = [np.array(t) for t in test_inputs]

        # Rule 1: largest object crop.
        rule1_fits = bool(train_pairs)
        for inp, out in zip(train_in, train_out):
            crop = self._largest_crop(inp)
            if crop is None or not np.array_equal(crop, out):
                rule1_fits = False
                break
        if rule1_fits:
            preds = [self._largest_crop(t) for t in tests]
            if all(p is not None for p in preds):
                return preds

        # Rule 2: row of object colours.
        shapes = {out.shape for out in train_out}
        if len(shapes) == 1:
            shape = next(iter(shapes))
            if len(shape) == 2 and shape[0] == 1:
                width = shape[1]
                if all(
                    np.array_equal(self._color_row(inp, width), out)
                    for inp, out in zip(train_in, train_out)
                ):
                    return [self._color_row(t, width) for t in tests]

        return [None] * len(tests)
|
|
class PositionalEncoding(nn.Module):
    """Additive sinusoidal position encoding for sequence-first tensors.

    A ``(max_len, 1, d_model)`` table is stored as the non-trainable buffer
    ``pe``: even feature columns hold sines, odd columns hold cosines.
    """

    def __init__(self, d_model, max_len=2000):
        """Precompute the table.

        Args:
            d_model: feature dimension.  Odd values are supported; they
                previously failed because the cosine table has one column
                more than the odd slots available.
            max_len: longest sequence the table can serve.
        """
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = pos * div
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(angles)[:, : d_model // 2]
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Add the encodings for the first ``x.size(0)`` positions to ``x``.

        ``x`` is indexed by position along dim 0 and must have ``d_model``
        as its last dimension.
        """
        return x + self.pe[:x.size(0)]
|
|
class ARCTTTModel(nn.Module):
    """Small encoder-decoder transformer over grid token sequences.

    Tensors are sequence-first (``batch_first=False``): token ids have shape
    ``(seq_len, batch)``.
    """

    def __init__(self, vocab_size=VOCAB_SIZE, d_model=TTT_D_MODEL, nhead=TTT_NHEAD,
                 num_enc=TTT_NUM_LAYERS, num_dec=TTT_NUM_LAYERS,
                 dim_ff=TTT_DIM_FEEDFORWARD, dropout=0.1, max_seq=2000):
        super().__init__()
        self.src_emb = nn.Embedding(vocab_size, d_model, padding_idx=PAD_TOKEN)
        self.tgt_emb = nn.Embedding(vocab_size, d_model, padding_idx=PAD_TOKEN)
        self.pe = PositionalEncoding(d_model, max_seq)
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead, num_encoder_layers=num_enc,
            num_decoder_layers=num_dec, dim_feedforward=dim_ff,
            dropout=dropout, batch_first=False)
        self.out = nn.Linear(d_model, vocab_size)

    def _embed(self, embedding, tokens):
        """Embed token ids, scale by sqrt(d_model), add position encodings."""
        return self.pe(embedding(tokens) * np.sqrt(embedding.embedding_dim))

    def forward(self, src, tgt, spm=None, tpm=None):
        """Return next-token logits for every decoder position.

        Args:
            src: ``(S, N)`` source token ids.
            tgt: ``(T, N)`` decoder input token ids.
            spm: optional ``(N, S)`` bool mask, True at source padding.
            tpm: optional ``(N, T)`` bool mask, True at target padding.

        Returns:
            Tensor of shape ``(T, N, vocab_size)``.
        """
        src = self._embed(self.src_emb, src)
        tgt = self._embed(self.tgt_emb, tgt)
        tm = self.transformer.generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)
        # The source padding mask is also applied to cross-attention so the
        # decoder does not attend to padded memory positions.
        o = self.transformer(src, tgt, tgt_mask=tm,
                             src_key_padding_mask=spm, tgt_key_padding_mask=tpm,
                             memory_key_padding_mask=spm)
        return self.out(o)

    @torch.no_grad()
    def generate(self, src, max_len, device):
        """Greedily decode one sequence.

        Decoding starts from a single ``ROW_SEP`` token and stops at
        ``EOS_TOKEN`` or after ``max_len`` steps (also capped by the
        positional table length).

        The earlier version unsqueezed both token tensors once too often,
        producing 4-D embeddings that the transformer rejects, so every call
        raised.  It also re-ran the encoder on every step.

        Args:
            src: source token ids for ONE sequence, shaped ``(S,)``,
                ``(S, 1)`` or ``(1, S)``.
            max_len: maximum number of tokens to generate.
            device: device to run on.

        Returns:
            list[int]: generated tokens without the start token; the final
            ``EOS_TOKEN`` is included when it was produced.
        """
        self.eval()
        src = src.to(device).reshape(-1, 1)            # (S, 1)
        spm = (src == PAD_TOKEN).transpose(0, 1)       # (1, S), as in training
        # The source never changes while decoding, so encode it once.
        memory = self.transformer.encoder(
            self._embed(self.src_emb, src), src_key_padding_mask=spm)

        steps = min(max_len, self.pe.pe.size(0) - 1)
        gen = [ROW_SEP]
        for _ in range(steps):
            tgt = torch.tensor(gen, dtype=torch.long, device=device).unsqueeze(1)  # (T, 1)
            tm = self.transformer.generate_square_subsequent_mask(tgt.size(0)).to(device)
            dec = self.transformer.decoder(
                self._embed(self.tgt_emb, tgt), memory,
                tgt_mask=tm, memory_key_padding_mask=spm)
            nxt = self.out(dec[-1, 0]).argmax().item()
            gen.append(nxt)
            if nxt == EOS_TOKEN:
                break
        return gen[1:]
|
|
def augment_pair(inp, out):
    """Expand one (input, output) grid pair into 16 augmented pairs.

    Order of the result: the original pair (copied); rotations by 90, 180
    and 270 degrees; left-right flip; up-down flip; left-right flips of the
    90- and 270-degree rotations; then 8 colour permutations of the values
    0-9 drawn from ``np.random.RandomState(42)``.  The same transform is
    applied to both grids of a pair, and the permutations are identical on
    every call.

    Args:
        inp: 2-D integer array with values in 0-9.
        out: 2-D integer array with values in 0-9.

    Returns:
        list[tuple[np.ndarray, np.ndarray]]: the 16 pairs.
    """
    def both(transform):
        return transform(inp), transform(out)

    pairs = [(inp.copy(), out.copy())]
    pairs.extend(both(lambda g, k=k: np.rot90(g, k)) for k in (1, 2, 3))
    pairs.append(both(np.fliplr))
    pairs.append(both(np.flipud))
    pairs.extend(both(lambda g, k=k: np.fliplr(np.rot90(g, k))) for k in (1, 3))

    rng = np.random.RandomState(42)
    for _ in range(8):
        lookup = rng.permutation(10).astype(np.int32)
        pairs.append((lookup[inp], lookup[out]))
    return pairs
|
|
class TTTSolver:
    """Test-time training: fit a fresh ``ARCTTTModel`` per task, then decode."""

    def __init__(self, device):
        self.device = device

    def _stack(self, sequences):
        """Right-pad token lists with PAD_TOKEN into a (batch, len) tensor."""
        width = max(len(seq) for seq in sequences)
        out = torch.full((len(sequences), width), PAD_TOKEN, dtype=torch.long)
        for row, seq in enumerate(sequences):
            out[row, :len(seq)] = torch.tensor(seq)
        return out.to(self.device)

    def _fit(self, model, dataset, t0, budget):
        """Train ``model`` on ``dataset`` within 75% of ``budget`` seconds."""
        opt = optim.AdamW(model.parameters(), lr=TTT_LR, weight_decay=1e-4)
        sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=TTT_MAX_EPOCHS)
        crit = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
        train_budget = budget * 0.75
        for epoch in range(TTT_MAX_EPOCHS):
            if time.time() - t0 > train_budget:
                break
            model.train()
            order = np.random.permutation(len(dataset))
            for start in range(0, len(order), TTT_BATCH_SIZE):
                batch = [dataset[i] for i in order[start:start + TTT_BATCH_SIZE]]
                sb = self._stack([s for s, _ in batch])
                tb = self._stack([t for _, t in batch])
                # Teacher forcing: the decoder sees tokens [0, T-1) and is
                # scored on tokens [1, T).
                dec_in, dec_out = tb[:, :-1], tb[:, 1:]
                logits = model(sb.T, dec_in.T, sb == PAD_TOKEN, dec_in == PAD_TOKEN)
                loss = crit(logits.reshape(-1, VOCAB_SIZE), dec_out.T.reshape(-1))
                opt.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()
                if time.time() - t0 > train_budget:
                    break
            sched.step()
            if epoch + 1 >= TTT_MIN_EPOCHS and (epoch + 1) >= min(len(dataset), TTT_MAX_EPOCHS):
                break

    @staticmethod
    def _target_shape(grids, test_grid):
        """Choose the output shape to decode for one test grid.

        If every training output has the shape of its input, the test
        input's shape is used; otherwise the most common training output
        shape is used.  Previously an arbitrary element of the set of
        output shapes was taken.
        """
        if all(inp.shape == out.shape for inp, out in grids):
            return test_grid.shape
        return Counter(out.shape for _, out in grids).most_common(1)[0][0]

    def solve(self, train_pairs, test_inputs, budget=TTT_TIME_BUDGET):
        """Train on the augmented training pairs and decode each test input.

        Args:
            train_pairs: list of dicts with ``"input"`` and ``"output"``
                grids.
            test_inputs: list of test input grids (lists or arrays).
            budget: time allowance in seconds; training stops at 75% of it
                and test inputs reached after the full budget get ``None``.

        Returns:
            list: one int32 ndarray or ``None`` per test input.  ``None``
            means there was no training data, the budget ran out, or
            decoding failed (the failure is reported on stderr).
        """
        if not train_pairs:
            return [None] * len(test_inputs)
        t0 = time.time()
        grids = [(np.array(p["input"]), np.array(p["output"])) for p in train_pairs]
        # Targets start with ROW_SEP because ``generate`` seeds decoding with
        # that token.  Without it the model never learned to predict the
        # first cell and its first training input differed from inference.
        dataset = [
            (encode_grid(aug_in), [ROW_SEP] + encode_grid(aug_out))
            for inp, out in grids
            for aug_in, aug_out in augment_pair(inp, out)
        ]
        if not dataset:
            return [None] * len(test_inputs)

        model = ARCTTTModel().to(self.device)
        self._fit(model, dataset, t0, budget)

        preds = []
        model.eval()
        with torch.no_grad():
            for tin in test_inputs:
                if time.time() - t0 > budget:
                    preds.append(None)
                    continue
                try:
                    grid = np.array(tin)
                    h, w = self._target_shape(grids, grid)
                    src = torch.tensor([encode_grid(grid)], device=self.device)
                    max_tgt = (h + 1) * (max(TTT_MAX_GRID_SIZE, w) + 1) + 2
                    gen = model.generate(src.T, max_tgt, self.device)
                    preds.append(decode_grid(gen, h, w))
                except Exception as exc:
                    # Still best-effort, but the failure is now visible.
                    print(f"[TTT] generation failed: {exc!r}", file=sys.stderr)
                    preds.append(None)
        return preds
|
|
class EnsembleSolver:
    """Run the DSL, object and TTT solvers in order of increasing cost."""

    def __init__(self, device):
        self.dsl = DSLSolver()
        self.obj = ObjectSolver()
        self.ttt = TTTSolver(device)

    def solve(self, train_pairs, test_inputs, task_id=""):
        """Return one prediction per test input.

        The DSL solver's answer is returned if it is complete (no ``None``),
        then the object solver's.  Otherwise the TTT solver also runs and
        each test input takes the first non-``None`` prediction in the order
        DSL, object, TTT.  If all three are ``None``, the test input itself
        (as an array) is used.

        ``task_id`` is accepted but not used.
        """
        def complete(candidates):
            return all(c is not None for c in candidates)

        dsl_preds = self.dsl.solve(train_pairs, test_inputs)
        if complete(dsl_preds):
            return dsl_preds

        obj_preds = self.obj.solve(train_pairs, test_inputs)
        if complete(obj_preds):
            return obj_preds

        ttt_preds = self.ttt.solve(train_pairs, test_inputs)

        merged = []
        for idx, options in enumerate(zip(dsl_preds, obj_preds, ttt_preds)):
            choice = next((opt for opt in options if opt is not None), None)
            if choice is None:
                choice = np.array(test_inputs[idx])
            merged.append(choice)
        return merged
|
|
| |
| |
| |
|
|
| import multiprocessing as mp |
| import traceback |
|
|
def _worker_solve(args):
    """Pool worker: run ``TTTSolver`` for one task on one CUDA device.

    Args:
        args: tuple ``(gpu_id, task_id, train_pairs, test_arrays)``.

    Returns:
        tuple: ``(task_id, preds, None)`` on success, or
        ``(task_id, None, message)`` where ``message`` is the exception text
        followed by its traceback.
    """
    gpu_id, task_id, train_pairs, test_arrays = args
    try:
        solver = TTTSolver(torch.device(f"cuda:{gpu_id}"))
        return task_id, solver.solve(train_pairs, test_arrays), None
    except Exception as exc:
        return task_id, None, f"{exc}\n{traceback.format_exc()}"
|
|
|
|
def main():
    """Solve every test task and write ``submission.json``.

    Phase 1 runs the DSL and object solvers in-process.  Tasks they do not
    fully solve are queued for test-time training: one worker per GPU when
    CUDA devices exist, otherwise sequentially on the CPU.  Any prediction
    that is missing (error, ``None`` or global time budget exceeded) falls
    back to the test input itself.

    NOTE(review): each task is written as a plain list of grids; confirm
    this matches the submission schema the competition expects.
    """
    input_dir = Path("/kaggle/input/arc-prize-2026")
    output_dir = Path("/kaggle/working")
    output_dir.mkdir(parents=True, exist_ok=True)

    test_path = input_dir / "arc-agi_test_challenges.json"
    if not test_path.exists():
        print(f"ERROR: {test_path} not found")
        sys.exit(1)

    with open(test_path) as f:
        test_challenges = json.load(f)

    # Optional extra training pairs for tasks that also appear in the
    # public training/evaluation files.
    all_train = {}
    for fname in ("arc-agi_training_challenges.json", "arc-agi_evaluation_challenges.json"):
        extra_path = input_dir / fname
        if extra_path.exists():
            with open(extra_path) as f:
                all_train.update(json.load(f))

    num_gpus = torch.cuda.device_count()
    print(f"Test tasks: {len(test_challenges)}, Training: {len(all_train)}, GPUs: {num_gpus}")

    def test_arrays(task_id):
        """Return the task's test inputs as a list of arrays."""
        return [
            np.array(td["input"]) if isinstance(td, dict) else np.array(td)
            for td in test_challenges[task_id].get("test", [])
        ]

    def digest(pair):
        """Fingerprint a training pair for de-duplication."""
        return hashlib.md5(str(pair).encode()).hexdigest()

    def with_fallback(preds, fallback):
        """Serialise predictions, using the test input where one is None."""
        return [
            p.tolist() if p is not None else fallback[i].tolist()
            for i, p in enumerate(preds)
        ]

    dsl_solver = DSLSolver()
    obj_solver = ObjectSolver()
    submission = {}
    ttt_queue = []
    t_start = time.time()

    # ---- Phase 1: cheap solvers -------------------------------------------
    for task_id in sorted(test_challenges):
        task_data = test_challenges[task_id]
        train_pairs = list(task_data.get("train", []))

        if task_id in all_train:
            seen = {digest(tp) for tp in train_pairs}
            for tp in all_train[task_id].get("train", []):
                key = digest(tp)
                if key not in seen:
                    train_pairs.append(tp)
                    seen.add(key)

        test_inputs = test_arrays(task_id)

        for tag, solver in (("DSL", dsl_solver), ("OBJECT", obj_solver)):
            preds = solver.solve(train_pairs, test_inputs)
            if all(p is not None for p in preds):
                submission[task_id] = [p.tolist() for p in preds]
                print(f"[{tag}] {task_id}: OK ({len(test_inputs)} tests)")
                break
        else:
            ttt_queue.append((task_id, train_pairs, test_inputs))

    dsl_time = time.time() - t_start
    print(f"\nPhase 1 (DSL+Object): {dsl_time:.1f}s")
    print(f"Solved: {len(submission)}, Remaining for TTT: {len(ttt_queue)}")

    # ---- Phase 2: test-time training --------------------------------------
    if ttt_queue and num_gpus > 0:
        print(f"Launching TTT workers across {num_gpus} GPUs...")
        # Round-robin assignment of tasks to GPUs.
        work_items = [
            (i % num_gpus, task_id, train_pairs, test_inputs)
            for i, (task_id, train_pairs, test_inputs) in enumerate(ttt_queue)
        ]

        batch_size = num_gpus
        ttt_start = time.time()

        for batch_start in range(0, len(work_items), batch_size):
            if time.time() - t_start > TOTAL_TIME_BUDGET:
                print(f"WARNING: Budget exceeded. Identity fallback for {len(work_items)-batch_start} remaining.")
                for _, task_id, _, test_inputs in work_items[batch_start:]:
                    submission[task_id] = [tin.tolist() for tin in test_inputs]
                break

            batch = work_items[batch_start:batch_start + batch_size]
            with mp.Pool(processes=min(len(batch), num_gpus)) as pool:
                results = pool.map(_worker_solve, batch)

            for task_id, preds, error in results:
                test_arrs = test_arrays(task_id)
                if error:
                    print(f"[TTT] {task_id}: ERROR {error[:100]}")
                    submission[task_id] = [tin.tolist() for tin in test_arrs]
                else:
                    final = with_fallback(preds, test_arrs)
                    submission[task_id] = final
                    print(f"[TTT] {task_id}: OK ({len(final)} tests)")

            batch_time = time.time() - ttt_start
            done = batch_start + len(batch)
            print(f" Batch {done}/{len(work_items)} done ({batch_time:.1f}s, ETA {batch_time*len(work_items)/done/3600:.1f}h)")

        ttt_time = time.time() - ttt_start
        print(f"TTT phase: {ttt_time:.1f}s ({ttt_time/3600:.2f}h)")

    elif ttt_queue:
        print("Running TTT sequentially on CPU...")
        ttt_solver = TTTSolver(torch.device("cpu"))
        for task_id, train_pairs, test_inputs in ttt_queue:
            if time.time() - t_start > TOTAL_TIME_BUDGET:
                submission[task_id] = [tin.tolist() for tin in test_inputs]
                continue
            t_task = time.time()
            preds = ttt_solver.solve(train_pairs, test_inputs)
            dt = time.time() - t_task
            submission[task_id] = with_fallback(preds, test_inputs)
            print(f"[TTT-seq] {task_id}: {dt:.1f}s")

    # ---- Write results ------------------------------------------------------
    sub_path = output_dir / "submission.json"
    with open(sub_path, "w") as f:
        json.dump(submission, f)

    total_t = time.time() - t_start
    print(f"\nSaved: {sub_path}")
    print(f"Total: {total_t/3600:.1f}h, Tasks: {len(submission)}/{len(test_challenges)}")
|
|
|
|
if __name__ == "__main__":
    # Force the "spawn" start method before main() creates worker pools.
    mp.set_start_method("spawn", force=True)
    main()
|
|