"""ARC-AGI-2 Video Answer Evaluator.

Extracts the test output grid from the last frame of a generated video,
then compares it against the ground-truth answer.

Color recovery pipeline:
  1. Match pixel RGB against the canonical ARC_COLORS palette → permuted color index
  2. Apply inverse permutation → original color index
  3. Compare with ground truth

Usage:
    python video_evaluate.py --video_dir videos --data_dir data --output results.json
"""

import argparse
import json
import random
from collections import defaultdict
from pathlib import Path

import cv2
import numpy as np
from tqdm import tqdm

# ── ARC Color Palette (RGB) ───────────────────────────────────────────────────

ARC_COLORS = np.array([
    [0x00, 0x00, 0x00],  # 0: black
    [0x00, 0x74, 0xD9],  # 1: blue
    [0xFF, 0x41, 0x36],  # 2: red
    [0x2E, 0xCC, 0x40],  # 3: green
    [0xFF, 0xDC, 0x00],  # 4: yellow
    [0xAA, 0xAA, 0xAA],  # 5: grey
    [0xF0, 0x12, 0xBE],  # 6: magenta
    [0xFF, 0x85, 0x1B],  # 7: orange
    [0x7F, 0xDB, 0xFF],  # 8: light blue
    [0x87, 0x0C, 0x25],  # 9: maroon
], dtype=np.uint8)
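
# Worked example of the recovery pipeline (illustrative numbers, not from a
# real run): if the generation-time permutation maps original color 0 to 3,
# color 0 is rendered with ARC_COLORS[3] (green); nearest-neighbour matching
# on that pixel returns index 3, and the inverse permutation maps 3 back to 0.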


# ── Color Permutation Utilities ────────────────────────────────────────────────

def generate_color_permutation(seed: int) -> list[int]:
    """Reproduce the same permutation used during video generation."""
    rng = random.Random(seed)
    perm = list(range(10))
    rng.shuffle(perm)
    return perm


def invert_permutation(perm: list[int]) -> list[int]:
    """Compute inverse permutation: inv[perm[i]] = i."""
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    return inv
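

def _check_permutation_round_trip(seed: int = 0) -> None:
    """Hedged sanity sketch (added for illustration, not part of the original
    pipeline): a permutation composed with its inverse is the identity."""
    perm = generate_color_permutation(seed)
    inv = invert_permutation(perm)
    assert [inv[p] for p in perm] == list(range(10))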


# ── Layout Computation (mirrors video_generate.py exactly) ─────────────────────

def compute_test_output_bbox(task: dict, canvas_h: int, canvas_w: int) -> dict:
    """Compute pixel bounding box of the test output grid region.

    Replicates _compute_layout + render_frame positioning from video_generate.py.
    """
    n_cols = len(task["train"]) + 1
    n_rows = 2
    padding = 12
    outer_margin = 16
    label_h = 20

    usable_w = canvas_w - 2 * outer_margin - (n_cols - 1) * padding
    usable_h = canvas_h - 2 * outer_margin - (n_rows - 1) * padding
    cell_w = usable_w // n_cols
    cell_h = usable_h // n_rows

    total_block_w = cell_w * n_cols + (n_cols - 1) * padding
    total_block_h = cell_h * n_rows + (n_rows - 1) * padding
    margin_x = (canvas_w - total_block_w) // 2
    margin_y = (canvas_h - total_block_h) // 2

    # Test output: last column, second row
    col = n_cols - 1
    x0 = margin_x + col * (cell_w + padding)
    y0 = margin_y + cell_h + padding

    test_out = np.array(task["test"][0]["output"])
    gr, gc = test_out.shape

    return {
        "grid_rows": gr,
        "grid_cols": gc,
        "grid_x0": x0,
        "grid_y0": y0 + label_h,
        "grid_w": cell_w,
        "grid_h": cell_h - label_h,
    }
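
# Worked numbers for the layout math above (illustration only): with 3 train
# pairs on a 1280x720 canvas, n_cols = 4, usable_w = 1280 - 32 - 3*12 = 1212,
# cell_w = 1212 // 4 = 303, and cell_h = (720 - 32 - 12) // 2 = 338; the test
# output grid area is then a 303x318 px box whose top-left corner lands at
# (grid_x0, grid_y0) = (961, 386).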


# ── Frame Extraction ───────────────────────────────────────────────────────────

def extract_last_frame(video_path: str) -> np.ndarray:
    """Extract the last frame from a video as an RGB numpy array."""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise FileNotFoundError(f"Cannot open video: {video_path}")

    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.set(cv2.CAP_PROP_POS_FRAMES, max(0, total - 1))
    ret, frame = cap.read()
    if not ret:
        # CAP_PROP_FRAME_COUNT is unreliable for some containers; fall back
        # to decoding the stream sequentially and keeping the final frame.
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
        while True:
            ok, f = cap.read()
            if not ok:
                break
            ret, frame = True, f
    cap.release()

    if not ret:
        raise RuntimeError(f"Failed to read last frame from {video_path}")
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


# ── Grid Extraction ────────────────────────────────────────────────────────────

def extract_grid_from_frame(
    frame: np.ndarray,
    grid_x0: int,
    grid_y0: int,
    grid_w: int,
    grid_h: int,
    grid_rows: int,
    grid_cols: int,
) -> list[list[int]]:
    """Extract ARC grid by sampling cell centers and matching to ARC_COLORS.

    Always matches against the canonical ARC_COLORS palette. The returned
    indices are the permuted color values as rendered in the video.

    Args:
        frame: RGB image (H, W, 3).
        grid_x0, grid_y0: Top-left of grid area (below label).
        grid_w, grid_h: Grid area dimensions.
        grid_rows, grid_cols: Expected grid shape.

    Returns:
        Grid of permuted color indices (apply inverse perm to get originals).
    """
    cell_h = grid_h / grid_rows
    cell_w = grid_w / grid_cols

    grid = []
    for r in range(grid_rows):
        row = []
        cy = int(grid_y0 + (r + 0.5) * cell_h)
        for c in range(grid_cols):
            cx = int(grid_x0 + (c + 0.5) * cell_w)
            # 3x3 patch average for codec artifact robustness
            patch = frame[max(0, cy - 1): cy + 2, max(0, cx - 1): cx + 2]
            avg = patch.mean(axis=(0, 1)).astype(np.uint8)
            dists = np.sum((ARC_COLORS.astype(int) - avg.astype(int)) ** 2, axis=1)
            row.append(int(np.argmin(dists)))
        grid.append(row)
    return grid
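
# Hedged self-check for the sampler above (illustrative, synthetic input):
# a frame painted a single palette color should decode to that color's index.
#   frame = np.full((40, 40, 3), ARC_COLORS[3], dtype=np.uint8)  # solid green
#   extract_grid_from_frame(frame, grid_x0=0, grid_y0=0, grid_w=40, grid_h=40,
#                           grid_rows=1, grid_cols=1)  # -> [[3]]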


# ── Evaluation ─────────────────────────────────────────────────────────────────

def evaluate_video(
    video_path: str,
    task: dict,
    perm: list[int],
    canvas_h: int = 720,
    canvas_w: int = 1280,
) -> dict:
    """Evaluate a single video against ground truth.

    Pipeline:
      1. Extract last frame (full answer revealed)
      2. Locate test output region via layout math
      3. Sample cell centers → match to ARC_COLORS → get permuted color indices
      4. Apply inverse permutation → recover original color indices
      5. Compare with ground truth

    Returns:
        Dict with 'correct', 'predicted_grid', 'ground_truth', 'pixel_accuracy'.
    """
    frame = extract_last_frame(video_path)
    bbox = compute_test_output_bbox(task, canvas_h, canvas_w)

    # Pipeline step 3: extract permuted color indices from rendered pixels
    permuted_grid = extract_grid_from_frame(frame, **bbox)

    # Pipeline step 4: invert permutation to recover original values
    inv = invert_permutation(perm)
    predicted = [[inv[cell] for cell in row] for row in permuted_grid]

    # Pipeline step 5: compare with ground truth
    gt = task["test"][0]["output"]
    correct = (predicted == gt)

    gt_flat = [c for row in gt for c in row]
    pred_flat = [c for row in predicted for c in row]
    n_match = sum(a == b for a, b in zip(gt_flat, pred_flat))
    pixel_acc = n_match / max(len(gt_flat), 1)

    return {
        "correct": correct,
        "predicted_grid": predicted,
        "ground_truth": gt,
        "pixel_accuracy": pixel_acc,
    }
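
# Hedged single-video usage sketch; "0a1b2c3d" and seed 42 are hypothetical
# placeholders following the {task_id}_{seed}.mp4 convention parsed below:
#   with open("data/training/0a1b2c3d.json") as f:
#       task = json.load(f)
#   perm = generate_color_permutation(42)
#   result = evaluate_video("videos/0a1b2c3d_42.mp4", task, perm)
#   print(result["correct"], result["pixel_accuracy"])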


# ── Batch Evaluation ───────────────────────────────────────────────────────────

def evaluate_all(
    video_dir: str = "videos",
    data_dir: str = "data",
    output_file: str = "results.json",
) -> None:
    """Evaluate all videos against ground-truth tasks.

    Recovers the color permutation from the seed in the filename
    ({task_id}_{seed}.mp4) using the same RNG as video_generate.py.
    """
    video_path = Path(video_dir)
    data_path = Path(data_dir)

    # Build task file lookup
    task_files: dict[str, Path] = {}
    for subdir in ["training", "evaluation"]:
        d = data_path / subdir
        if d.exists():
            for fp in d.glob("*.json"):
                task_files[fp.stem] = fp

    videos = sorted(video_path.glob("*.mp4"))
    if not videos:
        print(f"No videos found in {video_dir}")
        return

    # Auto-detect resolution from the first video
    cap = cv2.VideoCapture(str(videos[0]))
    if not cap.isOpened():
        raise FileNotFoundError(f"Cannot open video: {videos[0]}")
    canvas_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    canvas_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    cap.release()
    print(f"Detected resolution: {canvas_w}x{canvas_h}")

    results = {}
    total_correct = 0
    total_count = 0

    for vp in tqdm(videos, desc="Evaluating"):
        stem = vp.stem
        parts = stem.rsplit("_", 1)
        if len(parts) != 2:
            continue
        task_id, seed_str = parts

        if task_id not in task_files:
            tqdm.write(f"Skip {stem}: task not found")
            continue

        with open(task_files[task_id]) as f:
            task = json.load(f)

        if not task.get("test") or "output" not in task["test"][0]:
            continue

        # Recover the exact permutation from the seed encoded in the filename
        try:
            seed = int(seed_str)
        except ValueError:
            tqdm.write(f"Skip {stem}: non-numeric seed suffix")
            continue
        perm = generate_color_permutation(seed)

        try:
            result = evaluate_video(str(vp), task, perm, canvas_h, canvas_w)
            results[stem] = {
                "correct": result["correct"],
                "pixel_accuracy": result["pixel_accuracy"],
                "task_id": task_id,
                "seed": seed_str,
            }
            total_count += 1
            if result["correct"]:
                total_correct += 1
        except Exception as e:
            tqdm.write(f"Error {stem}: {e}")
            results[stem] = {"error": str(e), "task_id": task_id}

    acc = total_correct / max(total_count, 1)

    # Per-task pixel accuracy aggregation
    task_pixels: dict[str, list[float]] = defaultdict(list)
    for v in results.values():
        if "pixel_accuracy" in v:
            task_pixels[v["task_id"]].append(v["pixel_accuracy"])

    per_task_pixel_acc = {
        tid: round(sum(accs) / len(accs), 4)
        for tid, accs in sorted(task_pixels.items())
    }

    summary = {
        "total_videos": total_count,
        "correct": total_correct,
        "accuracy": round(acc, 4),
        "mean_pixel_accuracy": round(
            sum(per_task_pixel_acc.values()) / max(len(per_task_pixel_acc), 1), 4
        ),
        "per_task_pixel_accuracy": per_task_pixel_acc,
        "results": results,
    }

    with open(output_file, "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\nResults: {total_correct}/{total_count} correct ({acc:.2%})")
    print(f"Mean pixel accuracy (per-task avg): {summary['mean_pixel_accuracy']:.2%}")
    print(f"Saved to {output_file}")


# ── CLI ────────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    p = argparse.ArgumentParser(description="ARC Video Evaluator")
    p.add_argument("--video_dir", type=str, default="videos")
    p.add_argument("--data_dir", type=str, default="data")
    p.add_argument("--output", type=str, default="results.json")
    args = p.parse_args()
    evaluate_all(args.video_dir, args.data_dir, args.output)