Jayce-Ping commited on
Commit
2c25848
·
verified ·
1 Parent(s): 7cdb0ca

Add files using upload-large-folder tool

Browse files
frozenlake/data_process.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FrozenLake Video Dataset Generator — generate, eval, verify.
3
+
4
+ Uses plain BFS solver (not networkx) for fast generation at all grid sizes.
5
+
6
+ Usage:
7
+ python frozenlake_video_gen.py generate --output-dir frozenlake \
8
+ --sizes 8 16 32 --num-per-size 100 500 1000 --p 0.8
9
+ python frozenlake_video_gen.py eval result_videos/ --table-dir frozenlake/tables
10
+ python frozenlake_video_gen.py verify results.json --table-dir frozenlake/tables
11
+ """
12
+ import json
13
+ import csv
14
+ import hashlib
15
+ import random
16
+ import re
17
+ import argparse
18
+ from dataclasses import dataclass, asdict
19
+ from pathlib import Path
20
+ from typing import Dict, List, Optional
21
+
22
+ import cv2
23
+ import numpy as np
24
+ from tqdm import tqdm
25
+
26
+ from frozenlake_processor import FrozenLakeProcessor
27
+
28
+
29
+ # ==================== Checkpoint ====================
30
+
31
@dataclass
class GenerationState:
    """Resumable snapshot of dataset-generation progress.

    Persisted inside ``metadata.json`` so an interrupted run can pick up
    where it left off.
    """

    params_hash: str                 # digest of the generation parameters
    size_progress: Dict[int, int]    # grid size -> puzzles generated so far
    seen_fingerprints: List[str]     # layouts already emitted (dedup)
    all_samples: List[Dict]          # per-puzzle metadata records
    completed: bool = False          # True once final outputs were written

    def to_dict(self) -> Dict:
        """Serialise to a plain, JSON-friendly dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: Dict) -> "GenerationState":
        """Rebuild a state object from :meth:`to_dict` output."""
        return cls(**d)
45
+
46
+
47
+ def _params_hash(params: Dict) -> str:
48
+ key = {k: v for k, v in params.items() if k != "output_dir"}
49
+ return hashlib.md5(json.dumps(key, sort_keys=True).encode()).hexdigest()[:12]
50
+
51
+
52
def load_checkpoint(output_dir: Path, params: Dict) -> Optional[GenerationState]:
    """Load a saved generation state, if one exists and matches *params*.

    Returns ``None`` when there is no checkpoint or the stored parameter
    hash differs from the current run's (in which case generation must
    restart from scratch).
    """
    meta_path = output_dir / "metadata.json"
    if not meta_path.exists():
        return None

    with open(meta_path) as fh:
        payload = json.load(fh)
    state = GenerationState.from_dict(payload["state"])

    expected = _params_hash(params)
    if state.params_hash != expected:
        print(f"⚠️ Params changed ({state.params_hash} → {expected}), starting fresh")
        return None

    if state.completed:
        print("✓ Generation already completed")
        return state

    print(f"✓ Resuming: {sum(state.size_progress.values())} puzzles done")
    return state
68
+
69
+
70
def save_checkpoint(output_dir: Path, state: GenerationState, params: Dict):
    """Atomically persist *state* and *params* to ``metadata.json``.

    Writes to a ``.tmp`` sibling first and renames it into place so a
    crash mid-write never leaves a truncated checkpoint.
    """
    meta_path = output_dir / "metadata.json"
    tmp_path = meta_path.with_suffix(".tmp")
    payload = {"params": params, "state": state.to_dict()}
    with open(tmp_path, "w") as fh:
        json.dump(payload, fh, indent=2)
    tmp_path.rename(meta_path)
76
+
77
+
78
+ # ==================== Video I/O ====================
79
+
80
def save_video_cv2(frames: list, path: str, fps: int = 10):
    """Encode a sequence of RGB frames into an mp4 file via OpenCV.

    Frame dimensions are taken from the first frame; all frames are
    assumed to share them (required by ``cv2.VideoWriter``).
    """
    height, width = np.array(frames[0]).shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(str(path), fourcc, fps, (width, height))
    for frame in frames:
        # OpenCV expects BGR channel order.
        writer.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
    writer.release()
87
+
88
+
89
def extract_last_frame(video_path: str) -> Optional[np.ndarray]:
    """Return the final frame of a video as an RGB array, or None on failure."""
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        return None

    # Seek directly to the last frame when the container reports a count;
    # otherwise just read from the current position.
    frame_total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if frame_total > 0:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_total - 1)

    ok, frame = cap.read()
    cap.release()
    if not ok or frame is None:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
101
+
102
+
103
+ # ==================== Helpers ====================
104
+
105
+ def _normalise_list(val, sizes, name="parameter"):
106
+ if isinstance(val, int):
107
+ return [val] * len(sizes)
108
+ if len(val) != len(sizes):
109
+ raise ValueError(f"{name} length ({len(val)}) != sizes ({len(sizes)})")
110
+ return list(val)
111
+
112
+
113
+ # ==================== Generate ====================
114
+
115
def generate_dataset(
    output_dir: str = "frozenlake",
    sizes: List[int] = [8, 16, 32],
    num_per_size: list = [100, 500, 1000],
    p: float = 0.8,
    min_path_ratio: float = 0.3,
    img_size: int = 512,
    prompt: str = "Draw a continuous red line connecting the Start point to the Goal point, avoiding all holes.",
    train_ratio: float = 0.9,
    n_start: int = 2,
    m_end: int = 3,
    frames: Optional[int] = None,
    fps: int = 10,
    seed: int = 42,
    use_gym: bool = True,
    checkpoint_interval: int = 50,
):
    """
    Generate FrozenLake video dataset with checkpoint/resume.

    Layout::

        output_dir/
            images/  videos/  tables/
            train.jsonl  test.jsonl  train.csv  test.csv
            path.json  metadata.json

    Args:
        sizes: grid edge lengths to generate.
        num_per_size: puzzle count per size (int broadcasts to all sizes).
        p: probability that a non-start/goal cell is frozen (safe).
        min_path_ratio: minimum shortest-path length as a fraction of size².
        n_start / m_end: hold frames before/after the animated path.
        frames: content-frame count (None → one per path step).
        checkpoint_interval: puzzles between metadata.json checkpoints.

    NOTE(review): ``sizes`` and ``num_per_size`` use mutable default
    arguments; they are not mutated here, but callers should still pass
    their own lists.
    """
    # Params dict doubles as the checkpoint identity (hashed, minus output_dir).
    params = {
        "sizes": sizes, "num_per_size": num_per_size,
        "p": p, "min_path_ratio": min_path_ratio, "img_size": img_size,
        "prompt": prompt, "train_ratio": train_ratio,
        "n_start": n_start, "m_end": m_end, "frames": frames,
        "fps": fps, "seed": seed, "use_gym": use_gym,
    }

    out = Path(output_dir)
    img_dir, vid_dir, tbl_dir = out / "images", out / "videos", out / "tables"
    for d in (img_dir, vid_dir, tbl_dir):
        d.mkdir(parents=True, exist_ok=True)

    state = load_checkpoint(out, params)
    if state and state.completed:
        return

    num_list = _normalise_list(
        num_per_size[0] if len(num_per_size) == 1 else num_per_size,
        sizes, "num_per_size",
    )
    # Zero-padding width for file names, based on the largest target count.
    num_w = len(str(max(num_list)))
    proc = FrozenLakeProcessor(img_size=img_size)

    if state is None:
        random.seed(seed)
        state = GenerationState(
            params_hash=_params_hash(params),
            size_progress={sz: 0 for sz in sizes},
            seen_fingerprints=[], all_samples=[],
        )
        print(f"Fresh generation: sizes={sizes}, counts={num_list}, p={p}")
        print(f"  frames={'auto' if frames is None else frames}, "
              f"n_start={n_start}, m_end={m_end}, fps={fps}")
    else:
        # Crude RNG fast-forward so a resumed run continues with a stream
        # roughly advanced past the already-generated puzzles.
        # NOTE(review): the *10 factor is heuristic, not an exact replay.
        random.seed(seed)
        for _ in range(sum(state.size_progress.values()) * 10):
            random.random()

    seen = set(state.seen_fingerprints)
    all_samples = list(state.all_samples)
    # JSON round-trips dict keys as strings; restore the int grid sizes.
    progress = {int(k): v for k, v in state.size_progress.items()}
    since_ckpt = 0
    total_target = sum(num_list)

    with tqdm(total=total_target, initial=sum(progress.values()),
              desc="Total", unit="puzzle") as pbar:
        for grid_size, target in zip(sizes, num_list):
            generated = progress.get(grid_size, 0)
            if generated >= target:
                continue

            min_len = max(1, int(grid_size * grid_size * min_path_ratio))

            with tqdm(total=target, initial=generated,
                      desc=f"Size {grid_size:3d}", unit="puzzle", leave=False) as pbar_sz:
                # Attempt cap (20× remaining) guards against pathological
                # parameter combinations that rarely yield unique puzzles.
                for _ in range((target - generated) * 20):
                    if generated >= target:
                        break
                    try:
                        desc, path = proc.generate(grid_size, p=p, min_path_len=min_len)
                    except RuntimeError:
                        continue

                    # Skip duplicate layouts across the whole run.
                    fp = proc.fingerprint(desc)
                    if fp in seen:
                        continue
                    seen.add(fp)

                    base = f"size{grid_size}_{generated:0{num_w}d}"
                    img_name, vid_name, tbl_name = f"{base}.png", f"{base}.mp4", f"{base}.txt"

                    proc.render(desc, use_gym=use_gym).save(str(img_dir / img_name))

                    vid_frames = proc.generate_video_frames(
                        desc, path, n_start=n_start, m_end=m_end,
                        frames=frames, use_gym=use_gym,
                    )
                    save_video_cv2(vid_frames, str(vid_dir / vid_name), fps=fps)
                    proc.save_table(str(tbl_dir / tbl_name), desc)

                    udrl = proc.path_to_udrl(path)
                    all_samples.append({
                        "prompt": prompt, "image": img_name, "video": vid_name,
                        "table": tbl_name, "grid_size": grid_size,
                        "grid_desc": desc, "start": list(proc.find_start(desc)),
                        "path_udrl": udrl, "path_length": len(path),
                        "frame_count": len(vid_frames),
                    })

                    generated += 1
                    progress[grid_size] = generated
                    since_ckpt += 1
                    pbar_sz.update(1)
                    pbar.update(1)

                    # Periodic checkpoint so an interrupted run can resume.
                    if since_ckpt >= checkpoint_interval:
                        state.size_progress = progress
                        state.seen_fingerprints = list(seen)
                        state.all_samples = all_samples
                        save_checkpoint(out, state, params)
                        since_ckpt = 0

            tqdm.write(f"Size {grid_size}: {generated} puzzles")

    # --- Final outputs ---
    # Ground-truth map image -> UDRL string, sorted by image name.
    with open(out / "path.json", "w") as f:
        json.dump(
            dict(sorted((s["image"], s["path_udrl"]) for s in all_samples)),
            f, indent=4,
        )

    # Deterministic shuffle (separate seed) before the train/test split.
    random.seed(seed + 1)
    random.shuffle(all_samples)
    split = int(len(all_samples) * train_ratio)

    def _jsonl(samples, path):
        # One JSON object per line.
        with open(path, "w") as f:
            for s in samples:
                f.write(json.dumps(s) + "\n")

    _jsonl(all_samples[:split], out / "train.jsonl")
    _jsonl(all_samples[split:], out / "test.jsonl")

    # CSV variant with repo-relative media paths.
    for name, samps in [("train", all_samples[:split]), ("test", all_samples[split:])]:
        with open(out / f"{name}.csv", "w", newline="", encoding="utf-8") as f:
            w = csv.writer(f)
            w.writerow(["input_image", "video", "prompt"])
            for s in samps:
                w.writerow([f"images/{s['image']}", f"videos/{s['video']}", s["prompt"]])

    # Mark the checkpoint complete so a re-run exits early.
    state.size_progress = progress
    state.seen_fingerprints = list(seen)
    state.all_samples = all_samples
    state.completed = True
    save_checkpoint(out, state, params)

    lengths = [s["path_length"] for s in all_samples]
    fcounts = [s["frame_count"] for s in all_samples]
    print(f"\n✓ Dataset complete: {out}/")
    print(f"  Sizes: {sizes}, p={p}, Puzzles: {len(all_samples)}")
    print(f"  Train: {split}, Test: {len(all_samples) - split}")
    print(f"  Path lengths: avg={np.mean(lengths):.1f}, min={min(lengths)}, max={max(lengths)}")
    print(f"  Frame counts: avg={np.mean(fcounts):.1f}, min={min(fcounts)}, max={max(fcounts)}")
286
+
287
+
288
+ # ==================== Eval ====================
289
+
290
def eval_videos(
    video_dir: str,
    table_dir: str,
    output_json: Optional[str] = None,
    gt_json: Optional[str] = None,
    use_gym: bool = True,
):
    """Evaluate result videos: last frame → red path → verify.

    For each ``*.mp4`` in *video_dir*, loads the matching ground-truth
    table from *table_dir*, extracts the red path drawn in the video's
    last frame, and verifies it. Extracted paths are dumped to
    *output_json* (default ``<video_dir>/0_result.json``). If *gt_json*
    is given, also reports accuracy bucketed by ground-truth path length.
    """
    proc = FrozenLakeProcessor()
    vid_root, tbl_root = Path(video_dir), Path(table_dir)
    if output_json is None:
        output_json = str(vid_root / "0_result.json")

    # Natural sort: split stems on digit runs so size10 comes after size2.
    videos = sorted(
        vid_root.glob("*.mp4"),
        key=lambda p: [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", p.stem)],
    )
    if not videos:
        print(f"No .mp4 in {vid_root}")
        return

    print(f"Found {len(videos)} videos, table_dir={tbl_root}")

    # Phase 1: extract a UDRL string from each video's final frame.
    extracted: Dict[str, str] = {}
    missing_tbl = missing_frame = 0

    for vp in tqdm(videos, desc="Extracting"):
        stem = vp.stem
        desc = proc.load_table(str(tbl_root / f"{stem}.txt"))
        if desc is None:
            missing_tbl += 1
            continue
        start = proc.find_start(desc)
        if start is None:
            # Table parsed but contains no start cell — count as missing.
            missing_tbl += 1
            continue
        lf = extract_last_frame(str(vp))
        if lf is None:
            missing_frame += 1
            continue
        extracted[f"{stem}.png"] = proc.extract_path_from_pixels(
            lf, len(desc), len(desc[0]), start, desc
        )

    with open(output_json, "w") as f:
        json.dump(extracted, f, indent=4)
    print(f"Saved: {output_json}")

    # Phase 2: verify each extracted path against its grid.
    correct = total_valid = 0
    correctly_solved: List[Dict] = []
    size_stats: Dict[int, Dict[str, int]] = {}

    verify_fn = proc.verify_path_gym if use_gym else proc.verify_path_sim

    for name, udrl in extracted.items():
        desc = proc.load_table(str(tbl_root / f"{name.replace('.png', '')}.txt"))
        if desc is None:
            continue
        total_valid += 1
        sz = len(desc)
        size_stats.setdefault(sz, {"total": 0, "correct": 0})
        size_stats[sz]["total"] += 1
        if verify_fn(desc, udrl):
            correct += 1
            size_stats[sz]["correct"] += 1
            correctly_solved.append({"name": name, "length": len(udrl)})

    acc = correct / total_valid * 100 if total_valid else 0
    print(f"\n{'='*50}\nEvaluation Summary\n{'='*50}")
    print(f"Videos: {len(videos)}, Missing tables: {missing_tbl}, "
          f"Failed frames: {missing_frame}")
    print(f"Evaluated: {total_valid}, Correct: {correct}, Accuracy: {acc:.2f}%")

    if size_stats:
        print("\nBy size:")
        for sz in sorted(size_stats):
            s = size_stats[sz]
            print(f" {sz:3d}: {s['correct']}/{s['total']} "
                  f"({s['correct']/s['total']*100:.1f}%)")

    # Highlight the longest correctly-solved paths (hardest successes).
    correctly_solved.sort(key=lambda x: x["length"], reverse=True)
    for i, item in enumerate(correctly_solved[:3]):
        print(f" Top {i+1}: {item['name']} (len={item['length']})")

    if gt_json:
        _gt_bins(extracted, gt_json, tbl_root, proc, verify_fn)
    print(f"{'='*50}")
378
+
379
+
380
def _gt_bins(extracted, gt_path, tbl_root, proc, verify_fn):
    """Print accuracy bucketed by ground-truth path length (bins of 10)."""
    try:
        with open(gt_path) as fh:
            gt = json.load(fh)
    except Exception:
        # Best-effort report: a missing/corrupt GT file just disables it.
        return

    bins: Dict[str, Dict[str, int]] = {}
    for name, pred in extracted.items():
        if name not in gt:
            continue
        lo = (len(gt[name]) // 10) * 10
        label = f"{lo:3d}-{lo+9:3d}"
        bucket = bins.setdefault(label, {"total": 0, "correct": 0})
        bucket["total"] += 1
        desc = proc.load_table(str(tbl_root / f"{name.replace('.png','')}.txt"))
        if desc and verify_fn(desc, pred):
            bucket["correct"] += 1

    if not bins:
        return
    print("\nBy GT path length:")
    for label in sorted(bins):
        b = bins[label]
        print(f" {label}: {b['correct']}/{b['total']} "
              f"({b['correct']/b['total']*100:.1f}%)")
403
+
404
+
405
+ # ==================== Verify ====================
406
+
407
def verify_results(json_file: str, table_dir: str, use_gym: bool = True):
    """Verify a ``{image_name: UDRL}`` JSON against ground-truth tables.

    Entries whose table file cannot be loaded are counted as skipped and
    excluded from the accuracy denominator.
    """
    proc = FrozenLakeProcessor()
    with open(json_file) as fh:
        solutions = json.load(fh)

    check = proc.verify_path_gym if use_gym else proc.verify_path_sim
    tbl_root = Path(table_dir)
    correct = skipped = valid = 0

    for name, udrl in solutions.items():
        desc = proc.load_table(str(tbl_root / f"{name.replace('.png','')}.txt"))
        if desc is None:
            skipped += 1
            continue
        valid += 1
        if check(desc, udrl):
            correct += 1

    acc = correct / valid * 100 if valid else 0
    print(f"\n{'='*40}\nVerification: {correct}/{valid} ({acc:.2f}%)")
    if skipped:
        print(f"Skipped: {skipped}")
    print(f"{'='*40}")
426
+
427
+
428
+ # ==================== CLI ====================
429
+
430
+ def parse_args():
431
+ p = argparse.ArgumentParser(description="FrozenLake video dataset")
432
+ sub = p.add_subparsers(dest="command")
433
+
434
+ gen = sub.add_parser("generate")
435
+ gen.add_argument("--output-dir", default="frozenlake")
436
+ gen.add_argument("--sizes", type=int, nargs="+", default=[8, 16, 32])
437
+ gen.add_argument("--num-per-size", type=int, nargs="+", default=[100, 500, 1000])
438
+ gen.add_argument("--p", type=float, default=0.8)
439
+ gen.add_argument("--min-path-ratio", type=float, default=0.1,
440
+ help="Min path length as fraction of size² (default 0.1; "
441
+ "FrozenLake paths are much shorter than maze paths)")
442
+ gen.add_argument("--img-size", type=int, default=1024)
443
+ gen.add_argument("--prompt", default="Draw a continuous red line connecting the Start point to the Goal point, avoiding all holes.")
444
+ gen.add_argument("--train-ratio", type=float, default=0.9)
445
+ gen.add_argument("--n-start", type=int, default=2)
446
+ gen.add_argument("--m-end", type=int, default=3)
447
+ gen.add_argument("--frames", type=int, default=None)
448
+ gen.add_argument("--fps", type=int, default=10)
449
+ gen.add_argument("--seed", type=int, default=42)
450
+ gen.add_argument("--no-gym", action="store_true")
451
+ gen.add_argument("--checkpoint-interval", type=int, default=50)
452
+
453
+ ev = sub.add_parser("eval")
454
+ ev.add_argument("video_dir")
455
+ ev.add_argument("--table-dir", required=True)
456
+ ev.add_argument("--output-json", default=None)
457
+ ev.add_argument("--gt-json", default=None)
458
+ ev.add_argument("--no-gym", action="store_true")
459
+
460
+ ver = sub.add_parser("verify")
461
+ ver.add_argument("json_file")
462
+ ver.add_argument("--table-dir", required=True)
463
+ ver.add_argument("--no-gym", action="store_true")
464
+
465
+ return p.parse_args()
466
+
467
+
468
+ if __name__ == "__main__":
469
+ args = parse_args()
470
+ if args.command == "generate":
471
+ kw = {k: v for k, v in vars(args).items() if k not in ("command", "no_gym")}
472
+ kw["use_gym"] = not args.no_gym
473
+ generate_dataset(**kw)
474
+ elif args.command == "eval":
475
+ eval_videos(args.video_dir, args.table_dir, args.output_json,
476
+ args.gt_json, not args.no_gym)
477
+ elif args.command == "verify":
478
+ verify_results(args.json_file, args.table_dir, not args.no_gym)
479
+ else:
480
+ print("Usage: python frozenlake_video_gen.py {generate|eval|verify} ...")
frozenlake/frozenlake_processor.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FrozenLakeProcessor - FrozenLake puzzle generation, solving, rendering, and evaluation.
3
+
4
+ Grid cells: S=Start, F=Frozen(safe), H=Hole(death), G=Goal
5
+ Table chars: @=Start, _=Frozen, #=Hole, *=Goal
6
+
7
+ Performance notes vs original DiffThinker code:
8
+ - Solving uses plain BFS (O(n²)) instead of networkx graph construction
9
+ which had massive overhead from add_node/add_edge Python calls.
10
+ - Gym renderer is cached per puzzle to avoid repeated pygame init.
11
+ """
12
+ import os
13
+ import random
14
+ import warnings
15
+ from collections import deque
16
+ from typing import List, Tuple, Optional
17
+
18
+ import numpy as np
19
+ from PIL import Image, ImageDraw
20
+
21
+ try:
22
+ os.environ.setdefault("SDL_AUDIODRIVER", "dummy")
23
+ warnings.filterwarnings("ignore", category=UserWarning, module="pygame")
24
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
25
+ import gymnasium as gym
26
+
27
+ HAS_GYM = True
28
+ except ImportError:
29
+ HAS_GYM = False
30
+
31
+ # Table ↔ Grid mapping
32
+ TABLE_TO_GRID = {"@": "S", "_": "F", "#": "H", "*": "G"}
33
+ GRID_TO_TABLE = {v: k for k, v in TABLE_TO_GRID.items()}
34
+
35
+ MOVES = {"U": (-1, 0), "D": (1, 0), "L": (0, -1), "R": (0, 1)}
36
+ GYM_ACTION_MAP = {"L": 0, "D": 1, "R": 2, "U": 3}
37
+
38
+ GridDesc = List[str]
39
+
40
+
41
class FrozenLakeProcessor:
    """FrozenLake generation, BFS solving, rendering, and evaluation."""

    def __init__(self, img_size: int = 512):
        # Edge length in pixels of rendered square images.
        self.img_size = img_size
        # Color used for the drawn solution line and for path extraction.
        self.path_color = "red"

    # ==================== Generation ====================

    def generate(
        self,
        size: int,
        p: float = 0.8,
        min_path_len: int = 1,
        max_attempts: int = 500,
    ) -> Tuple[GridDesc, List[Tuple[int, int]]]:
        """
        Generate a solvable FrozenLake grid with shortest path >= *min_path_len* moves.

        Rejection-samples random layouts until one is solvable with a
        sufficiently long shortest path, or *max_attempts* is exhausted.

        Returns:
            (desc, path) — desc is list[str], path is list[(r,c)].

        Raises:
            RuntimeError: if no qualifying layout is found.
        """
        for _ in range(max_attempts):
            desc = self._random_layout(size, p)
            path = self.solve(desc)
            # path includes both endpoints, so moves = len(path) - 1.
            if path is not None and (len(path) - 1) >= min_path_len:
                return desc, path
        raise RuntimeError(
            f"Failed after {max_attempts} attempts "
            f"(size={size}, p={p}, min_path_len={min_path_len})."
        )

    @staticmethod
    def _random_layout(size: int, p: float = 0.8) -> GridDesc:
        """Random grid with one S and one G at random positions.

        Each remaining cell is frozen with probability *p*, a hole otherwise.
        """
        all_coords = [(r, c) for r in range(size) for c in range(size)]
        start, goal = random.sample(all_coords, 2)
        grid = []
        for r in range(size):
            row = []
            for c in range(size):
                if (r, c) == start:
                    row.append("S")
                elif (r, c) == goal:
                    row.append("G")
                else:
                    row.append("F" if random.random() < p else "H")
            grid.append("".join(row))
        return grid

    # ==================== Solving (plain BFS — fast) ====================

    @staticmethod
    def solve(desc: GridDesc) -> Optional[List[Tuple[int, int]]]:
        """
        BFS shortest path from S to G, avoiding H.

        ~100× faster than networkx for typical grid sizes because it avoids
        Python-level graph object construction entirely.

        Returns:
            List of (r, c) or None.
        """
        rows, cols = len(desc), len(desc[0])
        start = goal = None
        for r in range(rows):
            for c in range(cols):
                if desc[r][c] == "S":
                    start = (r, c)
                elif desc[r][c] == "G":
                    goal = (r, c)
        if start is None or goal is None:
            return None

        visited = [[False] * cols for _ in range(rows)]
        visited[start[0]][start[1]] = True
        # Queue holds (cell, path-so-far); BFS guarantees shortest path.
        queue: deque = deque([(start, [start])])

        while queue:
            (r, c), path = queue.popleft()
            if (r, c) == goal:
                return path
            for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                nr, nc = r + dr, c + dc
                if 0 <= nr < rows and 0 <= nc < cols and not visited[nr][nc]:
                    ch = desc[nr][nc]
                    if ch != "H":
                        visited[nr][nc] = True
                        queue.append(((nr, nc), path + [(nr, nc)]))
        return None

    # ==================== Path ↔ UDRL ====================

    @staticmethod
    def path_to_udrl(path: List[Tuple[int, int]]) -> str:
        """Convert coordinate path to UDRL string.

        Assumes consecutive cells are 4-adjacent (as produced by solve()).
        """
        moves = []
        for i in range(len(path) - 1):
            r1, c1 = path[i]
            r2, c2 = path[i + 1]
            if r2 < r1:
                moves.append("U")
            elif r2 > r1:
                moves.append("D")
            elif c2 < c1:
                moves.append("L")
            else:
                moves.append("R")
        return "".join(moves)

    # ==================== Verification ====================

    def verify_path_sim(self, desc: GridDesc, udrl: str) -> bool:
        """Verify UDRL via grid simulation (no dependencies).

        Non-UDRL characters are ignored; stepping out of bounds or into a
        hole fails immediately; success is reaching G (mid-string or at end).
        """
        rows, cols = len(desc), len(desc[0])
        start = self.find_start(desc)
        if start is None:
            return False

        r, c = start
        # Tolerate separators and an "Action plan" preamble in model output.
        clean = udrl.replace(",", "").replace(" ", "").strip()
        if "Action plan" in clean:
            clean = clean.rsplit("Action plan", 1)[-1]

        for ch in clean:
            if ch not in MOVES:
                continue
            dr, dc = MOVES[ch]
            nr, nc = r + dr, c + dc
            if not (0 <= nr < rows and 0 <= nc < cols):
                return False
            cell = desc[nr][nc]
            if cell == "H":
                return False
            r, c = nr, nc
            if cell == "G":
                return True
        return desc[r][c] == "G"

    def verify_path_gym(self, desc: GridDesc, udrl: str) -> bool:
        """Verify via gymnasium (falls back to sim if unavailable).

        Runs the deterministic (non-slippery) FrozenLake-v1 env and checks
        the episode terminates with positive reward.
        """
        if not HAS_GYM:
            return self.verify_path_sim(desc, udrl)
        rows, cols = len(desc), len(desc[0])
        try:
            env = gym.make(
                "FrozenLake-v1", desc=desc,
                map_name=f"{rows}x{cols}", is_slippery=False, render_mode=None,
            )
            env.reset(seed=42)
            success = False
            clean = udrl.replace(",", "").replace(" ", "").strip()
            if "Action plan" in clean:
                clean = clean.rsplit("Action plan", 1)[-1]
            for ch in clean:
                if ch not in GYM_ACTION_MAP:
                    continue
                _, reward, terminated, truncated, _ = env.step(GYM_ACTION_MAP[ch])
                if terminated or truncated:
                    # Positive reward on termination means the goal was reached.
                    success = reward > 0
                    break
            env.close()
            return success
        except Exception:
            # Any env failure degrades gracefully to the pure-Python check.
            return self.verify_path_sim(desc, udrl)

    # ==================== Table Text I/O ====================

    def encode_table(self, desc: GridDesc) -> str:
        """Encode to pipe-delimited table format (S/F/H/G → @/_/#/*)."""
        size = len(desc)
        lines = ["| | " + " | ".join(f"Col {i+1}" for i in range(size)) + " |"]
        for r in range(size):
            mapped = [GRID_TO_TABLE[ch] for ch in desc[r]]
            lines.append(f"| Row {r+1} | " + " | ".join(mapped) + " |")
        return "\n".join(lines)

    def decode_table(self, text: str) -> Optional[GridDesc]:
        """Parse table text back to GridDesc.

        Lenient: skips header/separator lines and any cell characters that
        are not in the table alphabet. Returns None on failure or empty input.
        """
        try:
            rows = []
            for line in text.strip().splitlines():
                line = line.strip()
                if not line or "Col" in line or "---" in line:
                    continue
                parts = [p.strip() for p in line.split("|")]
                clean = [p for p in parts if p]
                if len(clean) < 2:
                    continue
                # clean[0] is the "Row N" label; the rest are cells.
                row_str = "".join(
                    TABLE_TO_GRID[ch] for ch in clean[1:] if ch in TABLE_TO_GRID
                )
                if row_str:
                    rows.append(row_str)
            return rows if rows else None
        except Exception:
            return None

    def save_table(self, filepath: str, desc: GridDesc) -> None:
        """Write the encoded table to *filepath*."""
        with open(filepath, "w") as f:
            f.write(self.encode_table(desc))

    def load_table(self, filepath: str) -> Optional[GridDesc]:
        """Read and decode a table file; None if missing or unparseable."""
        try:
            with open(filepath) as f:
                return self.decode_table(f.read())
        except Exception:
            return None

    def find_start(self, desc: GridDesc) -> Optional[Tuple[int, int]]:
        """Return the (row, col) of the first 'S' cell, or None."""
        for r, row in enumerate(desc):
            for c, ch in enumerate(row):
                if ch == "S":
                    return (r, c)
        return None

    def fingerprint(self, desc: GridDesc) -> str:
        """Canonical string identity of a layout (rows concatenated)."""
        return "".join(desc)

    # ==================== Rendering ====================

    def render_gym(self, desc: GridDesc) -> Optional[Image.Image]:
        """Render via gymnasium (creates a pygame window — slow)."""
        if not HAS_GYM:
            return None
        try:
            env = gym.make(
                "FrozenLake-v1", desc=desc,
                is_slippery=False, render_mode="rgb_array",
            )
            env.reset()
            rgb = env.render()
            env.close()
            # NEAREST keeps cell edges crisp when upscaling.
            return Image.fromarray(rgb).resize(
                (self.img_size, self.img_size), Image.NEAREST
            )
        except Exception:
            return None

    def render_simple(self, desc: GridDesc) -> Image.Image:
        """Fast PIL-only renderer (no pygame dependency)."""
        size = len(desc)
        cell = self.img_size // size
        img = Image.new("RGB", (self.img_size, self.img_size), (255, 255, 255))
        draw = ImageDraw.Draw(img)
        # Cell fill colors per grid character.
        colors = {
            "S": (0, 0, 255), "F": (200, 220, 255),
            "H": (80, 80, 80), "G": (0, 200, 0),
        }
        for r in range(size):
            for c in range(size):
                x0, y0 = c * cell, r * cell
                draw.rectangle(
                    [x0, y0, x0 + cell - 1, y0 + cell - 1],
                    fill=colors.get(desc[r][c], (200, 220, 255)),
                )
        # Grid lines on top of the filled cells.
        for i in range(size + 1):
            draw.line([(i * cell, 0), (i * cell, self.img_size)], fill="black", width=1)
            draw.line([(0, i * cell), (self.img_size, i * cell)], fill="black", width=1)
        return img

    def render(self, desc: GridDesc, use_gym: bool = True) -> Image.Image:
        """Render with gym when requested and available, else PIL fallback."""
        if use_gym:
            img = self.render_gym(desc)
            if img is not None:
                return img
        return self.render_simple(desc)

    def draw_solution_line(
        self, image: Image.Image, path: List[Tuple[int, int]], grid_size: int,
    ) -> Image.Image:
        """Draw red line on *image* (modifies in-place)."""
        draw = ImageDraw.Draw(image)
        w, h = image.size
        cw, ch_ = w / grid_size, h / grid_size
        # Line passes through the center of each cell on the path.
        pts = [(c * cw + cw / 2, r * ch_ + ch_ / 2) for r, c in path]
        draw.line(pts, fill=self.path_color, width=max(1, int(cw / 4)), joint="curve")
        return image

    # ==================== Video Frames ====================

    def generate_video_frames(
        self,
        desc: GridDesc,
        path: List[Tuple[int, int]],
        n_start: int = 5,
        m_end: int = 5,
        frames: Optional[int] = None,
        use_gym: bool = True,
    ) -> List[Image.Image]:
        """
        Progressive red-line video frames.

        *frames* controls content frames between holds:
        None → 1 per step, >steps → slow-mo, <steps → fast-fwd.
        """
        size = len(desc)
        n_steps = len(path) - 1
        base_img = self.render(desc, use_gym=use_gym)

        # Degenerate path: just hold the static board.
        if n_steps <= 0:
            return [base_img] * (n_start + m_end + 1)

        content = frames if frames is not None else n_steps
        content = max(1, content)
        result: List[Image.Image] = []

        # Opening hold
        result.extend([base_img.copy() for _ in range(n_start)])

        def _partial(steps: int) -> Image.Image:
            # Board with the first *steps* moves of the path drawn.
            return self.draw_solution_line(base_img.copy(), path[: steps + 1], size)

        if content == n_steps:
            # One frame per path step.
            for s in range(1, n_steps + 1):
                result.append(_partial(s))
        elif content > n_steps:
            # Slow motion: duplicate frames so each step covers its share
            # of the content-frame budget.
            for s in range(1, n_steps + 1):
                lo = (s - 1) * content // n_steps
                hi = s * content // n_steps
                frame = _partial(s)
                result.append(frame)
                for _ in range(hi - lo - 1):
                    result.append(frame.copy())
        else:
            # Fast forward: sample *content* evenly-spaced partial paths.
            for f in range(content):
                result.append(_partial((f + 1) * n_steps // content))

        # Closing hold
        final = _partial(n_steps)
        result.extend([final.copy() for _ in range(m_end)])
        return result

    # ==================== Red-Path Extraction ====================

    def extract_path_from_pixels(
        self,
        pixels: np.ndarray,
        rows: int,
        cols: int,
        start: Tuple[int, int],
        desc: Optional[GridDesc] = None,
        pixel_threshold: float = 0.01,
    ) -> str:
        """Detect red path in RGB array, return UDRL.

        A cell counts as "on the path" when more than *pixel_threshold* of
        its pixels are predominantly red. The path is then recovered by a
        greedy walk from *start*; NOTE(review): the greedy walk can
        dead-end on self-crossing or branching red marks — it is not a
        full graph search.
        """
        img = Image.fromarray(pixels)
        w, h = img.size
        px = np.array(img, dtype=float)
        r_ch, g_ch, b_ch = px[:, :, 0], px[:, :, 1], px[:, :, 2]
        # "Red" = bright red channel clearly dominating green and blue.
        red_mask = (r_ch > 100) & (r_ch > g_ch * 1.2) & (r_ch > b_ch * 1.2)

        cell_h, cell_w = h // rows, w // cols
        path_grid = np.zeros((rows, cols), dtype=bool)
        for r in range(rows):
            for c in range(cols):
                sub = red_mask[r * cell_h : (r + 1) * cell_h,
                               c * cell_w : (c + 1) * cell_w]
                if sub.size > 0 and np.mean(sub) > pixel_threshold:
                    path_grid[r, c] = True

        # Greedy walk
        visited = {start}
        cr, cc = start
        actions: List[str] = []
        # Step cap guards against cycles in a noisy mask.
        for _ in range(rows * cols * 2):
            found = False
            for act, (dr, dc) in [("R", (0, 1)), ("D", (1, 0)), ("L", (0, -1)), ("U", (-1, 0))]:
                nr, nc = cr + dr, cc + dc
                if 0 <= nr < rows and 0 <= nc < cols:
                    if path_grid[nr, nc] and (nr, nc) not in visited:
                        visited.add((nr, nc))
                        actions.append(act)
                        cr, cc = nr, nc
                        found = True
                        break
            if not found:
                break
        return "".join(actions)

    def extract_path_from_image(
        self, img_path: str, rows: int, cols: int, start: Tuple, desc=None,
    ) -> str:
        """Extract UDRL from an image file (empty string on any failure)."""
        try:
            pixels = np.array(Image.open(img_path).convert("RGB"))
            return self.extract_path_from_pixels(pixels, rows, cols, start, desc)
        except Exception:
            return ""
+
430
+
431
+ if __name__ == "__main__":
432
+ import time
433
+
434
+ proc = FrozenLakeProcessor(img_size=512)
435
+
436
+ # Benchmark BFS vs problem sizes
437
+ for sz in [8, 16, 32, 64]:
438
+ t0 = time.perf_counter()
439
+ count = 0
440
+ for _ in range(100):
441
+ desc = proc._random_layout(sz, p=0.8)
442
+ path = proc.solve(desc)
443
+ if path:
444
+ count += 1
445
+ elapsed = time.perf_counter() - t0
446
+ print(f"Size {sz:3d}: 100 BFS solves in {elapsed:.3f}s "
447
+ f"({count} solvable, {elapsed/100*1000:.1f}ms/solve)")
448
+
449
+ # Functional test
450
+ desc, path = proc.generate(size=16, p=0.8, min_path_len=20)
451
+ udrl = proc.path_to_udrl(path)
452
+ print(f"\nGenerate 16×16: path={len(path)}, UDRL={udrl[:40]}...")
453
+ print(f"Verify (sim): {proc.verify_path_sim(desc, udrl)}")
454
+
455
+ # Table round-trip
456
+ decoded = proc.decode_table(proc.encode_table(desc))
457
+ assert decoded == desc
458
+ print("Table round-trip: ✓")
459
+
460
+ # Render + extract round-trip
461
+ img = proc.render(desc, use_gym=False)
462
+ sol = proc.draw_solution_line(img.copy(), path, len(desc))
463
+ start = proc.find_start(desc)
464
+ extracted = proc.extract_path_from_pixels(np.array(sol), len(desc), len(desc[0]), start)
465
+ print(f"Extract round-trip verify: {proc.verify_path_sim(desc, extracted)}")
466
+ print("All tests passed ✓")
maze/data_process.py ADDED
@@ -0,0 +1,651 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Maze Video Dataset Generator — generates maze puzzle images and solution videos
3
+ with checkpoint/resume support, train/test splitting, and JSONL metadata.
4
+
5
+ Includes an ``eval`` subcommand that takes a directory of result videos,
6
+ extracts the last frame from each, parses the red path, and verifies it
7
+ against the ground-truth maze text files.
8
+
9
+ Usage:
10
+ # Generate
11
+ python maze_video_gen.py generate --output-dir maze --sizes 8 16 32 \
12
+ --num-per-size 100 500 1000 --min-path-ratio 0.3 \
13
+ --n-start 5 --m-end 5 --frames 50 --fps 10 --seed 42
14
+
15
+ # Evaluate result videos
16
+ python maze_video_gen.py eval result_videos/ --text-dir maze/texts
17
+
18
+ # Verify a pre-extracted JSON
19
+ python maze_video_gen.py verify results.json --text-dir maze/texts
20
+ """
21
+ import json
22
+ import csv
23
+ import hashlib
24
+ import random
25
+ import re
26
+ import argparse
27
+ from dataclasses import dataclass, asdict
28
+ from pathlib import Path
29
+ from typing import Dict, List, Optional
30
+
31
+ import cv2
32
+ import numpy as np
33
+ from tqdm import tqdm
34
+
35
+ from maze_processor import MazeProcessor
36
+
37
+
38
+ # ==================== Checkpoint Management ====================
39
+
40
@dataclass
class GenerationState:
    """Serializable snapshot of generation progress for checkpoint/resume.

    params_hash:       hash of the generation parameters this run used
    size_progress:     mazes generated so far, keyed by maze size
    seen_fingerprints: fingerprints of emitted mazes (deduplication)
    all_samples:       metadata records for every emitted sample
    completed:         True once the full dataset has been written
    """
    params_hash: str
    size_progress: Dict[int, int]
    seen_fingerprints: List[str]
    all_samples: List[Dict]
    completed: bool = False

    def to_dict(self) -> Dict:
        """Plain-dict form suitable for ``json.dump``."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: Dict) -> "GenerationState":
        """Rebuild a state object from its :meth:`to_dict` form."""
        return cls(**d)
55
+
56
+
57
def _params_hash(params: Dict) -> str:
    """Deterministic 12-char hash of the generation parameters.

    ``output_dir`` is excluded so moving/renaming the dataset directory
    does not invalidate an existing checkpoint.
    """
    relevant = dict(params)
    relevant.pop("output_dir", None)
    canonical = json.dumps(relevant, sort_keys=True).encode()
    return hashlib.md5(canonical).hexdigest()[:12]
61
+
62
+
63
def load_checkpoint(output_dir: Path, params: Dict) -> Optional[GenerationState]:
    """Return the saved generation state, or None when absent or stale.

    A checkpoint is only reused when its parameter hash matches *params*;
    otherwise generation must restart from scratch.
    """
    meta_path = output_dir / "metadata.json"
    if not meta_path.exists():
        return None
    with open(meta_path) as fh:
        payload = json.load(fh)
    state = GenerationState.from_dict(payload["state"])

    expected = _params_hash(params)
    if state.params_hash != expected:
        print(f"⚠️ Parameters changed ({state.params_hash} → {expected}), starting fresh")
        return None
    if state.completed:
        print("✓ Generation already completed")
        return state

    done = sum(state.size_progress.values())
    print(f"✓ Resuming from checkpoint: {done} mazes generated")
    return state
81
+
82
+
83
def save_checkpoint(output_dir: Path, state: GenerationState, params: Dict):
    """Atomically persist the checkpoint: write a temp file, then rename."""
    target = output_dir / "metadata.json"
    scratch = target.with_suffix(".tmp")
    payload = {"params": params, "state": state.to_dict()}
    with open(scratch, "w") as fh:
        json.dump(payload, fh, indent=2)
    # rename() is atomic on POSIX, so readers never see a half-written file
    scratch.rename(target)
90
+
91
+
92
+ # ==================== Video I/O ====================
93
+
94
def save_video_cv2(frames: list, path: str, fps: int = 10):
    """Encode a list of PIL Images as an mp4 (mp4v codec) at *fps*."""
    height, width = np.array(frames[0]).shape[:2]
    writer = cv2.VideoWriter(
        str(path), cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
    )
    for img in frames:
        # OpenCV expects BGR channel order
        writer.write(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
    writer.release()
104
+
105
+
106
def extract_last_frame(video_path: str) -> Optional[np.ndarray]:
    """Return the final frame of *video_path* as an (H, W, 3) uint8 RGB array.

    None is returned when the file cannot be opened or decoded.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        return None

    n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if n_frames > 0:
        # Seek straight to the last frame instead of decoding the whole clip
        cap.set(cv2.CAP_PROP_POS_FRAMES, n_frames - 1)

    ok, bgr = cap.read()
    cap.release()

    if not ok or bgr is None:
        return None
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
127
+
128
+
129
+ # ==================== Normalisation Helpers ====================
130
+
131
def _normalise_list(val, sizes, name="parameter"):
    """Broadcast a scalar int to len(sizes) entries, or validate a list.

    Raises ValueError when a sequence of the wrong length is supplied.
    """
    if isinstance(val, int):
        return [val for _ in sizes]
    if len(val) != len(sizes):
        raise ValueError(f"{name} length ({len(val)}) != sizes length ({len(sizes)})")
    return list(val)
138
+
139
+
140
+ # ==================== Core Dataset Generation ====================
141
+
142
def generate_dataset(
    output_dir: str = "maze",
    sizes: List[int] = (8, 16, 32),
    num_per_size: list = (100, 500, 1000),
    min_path_ratio: float = 0.3,
    img_size: int = 1024,
    prompt: str = "Draw a continuous red line from the yellow dot to the blue dot, avoiding all walls.",
    train_ratio: float = 0.9,
    n_start: int = 5,
    m_end: int = 5,
    frames: Optional[int] = None,
    fps: int = 10,
    seed: int = 42,
    checkpoint_interval: int = 50,
):
    """
    Generate maze video dataset with checkpoint/resume support.

    Sequence defaults are tuples (immutable) rather than lists, so no caller
    can mutate a shared default between calls; they are normalised to lists
    internally.

    The *frames* parameter controls content frames per video:
        - None → one content frame per path step (variable length)
        - N > 0 → exactly N content frames (slow-mo / fast-fwd as needed)

    Directory layout::

        output_dir/
            images/ — puzzle PNG (no solution line)
            videos/ — solution MP4 (progressive red line)
            texts/ — maze text files (bitmask format)
            train.jsonl / test.jsonl
            train.csv / test.csv
            path.json — UDRL answer key
            metadata.json — checkpoint state
    """
    # Normalise sequence parameters (tuples serialize identically to lists
    # in JSON, so the checkpoint parameter hash is unaffected).
    sizes = list(sizes)
    num_per_size = list(num_per_size)

    # Everything that affects output identity goes into the checkpoint hash.
    params = {
        "sizes": sizes, "num_per_size": num_per_size,
        "min_path_ratio": min_path_ratio, "img_size": img_size,
        "prompt": prompt, "train_ratio": train_ratio,
        "n_start": n_start, "m_end": m_end, "frames": frames,
        "fps": fps, "seed": seed,
    }

    out = Path(output_dir)
    img_dir = out / "images"
    vid_dir = out / "videos"
    txt_dir = out / "texts"
    for d in (img_dir, vid_dir, txt_dir):
        d.mkdir(parents=True, exist_ok=True)

    state = load_checkpoint(out, params)
    if state and state.completed:
        return

    # A single count broadcasts to every size.
    num_list = _normalise_list(
        num_per_size[0] if len(num_per_size) == 1 else num_per_size,
        sizes, "num_per_size",
    )
    max_puzzles = max(num_list)
    num_w = len(str(max_puzzles))  # zero-pad width for filenames
    proc = MazeProcessor(img_size=img_size)

    if state is None:
        random.seed(seed)
        state = GenerationState(
            params_hash=_params_hash(params),
            size_progress={sz: 0 for sz in sizes},
            seen_fingerprints=[],
            all_samples=[],
        )
        print(f"Starting fresh generation: sizes={sizes}, counts={num_list}")
        print(f" frames={'auto (1 per step)' if frames is None else frames}, "
              f"n_start={n_start}, m_end={m_end}, fps={fps}")
    else:
        # Best-effort RNG replay so a resumed run diverges less from a
        # fresh one (not an exact reconstruction of the RNG stream).
        random.seed(seed)
        for _ in range(sum(state.size_progress.values()) * 10):
            random.random()

    seen = set(state.seen_fingerprints)
    all_samples = list(state.all_samples)
    # JSON round-trips dict keys as strings; restore int keys.
    progress = {int(k): v for k, v in state.size_progress.items()}
    since_ckpt = 0

    total_target = sum(num_list)
    total_done = sum(progress.values())

    with tqdm(total=total_target, initial=total_done, desc="Total", unit="maze") as pbar:
        for maze_size, target in zip(sizes, num_list):
            generated = progress.get(maze_size, 0)
            if generated >= target:
                continue

            min_len = max(1, int(maze_size * maze_size * min_path_ratio))
            # Cap attempts so an unsatisfiable config cannot loop forever.
            max_attempts = (target - generated) * 20

            with tqdm(
                total=target, initial=generated, desc=f"Size {maze_size:3d}",
                unit="maze", leave=False,
            ) as pbar_sz:
                for _ in range(max_attempts):
                    if generated >= target:
                        break

                    try:
                        grid, start, end, path = proc.generate(
                            maze_size, min_path_len=min_len
                        )
                    except RuntimeError:
                        continue  # this attempt produced no qualifying maze

                    fp = proc.fingerprint(grid, start, end)
                    if fp in seen:
                        continue  # exact duplicate puzzle
                    seen.add(fp)

                    idx = generated
                    base = f"size{maze_size}_{idx:0{num_w}d}"
                    img_name = f"{base}.png"
                    vid_name = f"{base}.mp4"
                    txt_name = f"{base}.txt"

                    # Puzzle image (no solution overlay)
                    puzzle_img = proc.render(grid, start, end)
                    puzzle_img.save(str(img_dir / img_name))

                    # Solution video (progressively drawn red path)
                    vid_frames = proc.generate_video_frames(
                        grid, start, end, path,
                        n_start=n_start, m_end=m_end, frames=frames,
                    )
                    save_video_cv2(vid_frames, str(vid_dir / vid_name), fps=fps)

                    proc.save_text(str(txt_dir / txt_name), grid, start, end)

                    udrl = proc.path_to_udrl(path)

                    all_samples.append({
                        "prompt": prompt,
                        "image": img_name,
                        "video": vid_name,
                        "text": txt_name,
                        "maze_size": maze_size,
                        "start": list(start),
                        "end": list(end),
                        "path_udrl": udrl,
                        "path_length": len(path),
                        "frame_count": len(vid_frames),
                    })

                    generated += 1
                    progress[maze_size] = generated
                    since_ckpt += 1
                    pbar_sz.update(1)
                    pbar.update(1)

                    if since_ckpt >= checkpoint_interval:
                        state.size_progress = progress
                        state.seen_fingerprints = list(seen)
                        state.all_samples = all_samples
                        save_checkpoint(out, state, params)
                        since_ckpt = 0

            tqdm.write(
                f"Size {maze_size}: {generated} mazes, "
                f"{sum(1 for s in all_samples if s['maze_size'] == maze_size)} samples"
            )

    # ==================== Final outputs ====================

    # UDRL answer key keyed by puzzle image name
    path_answers = {s["image"]: s["path_udrl"] for s in all_samples}
    with open(out / "path.json", "w") as f:
        json.dump(dict(sorted(path_answers.items())), f, indent=4)

    # Deterministic shuffle (seed independent of the generation RNG state)
    random.seed(seed + 1)
    random.shuffle(all_samples)
    split = int(len(all_samples) * train_ratio)

    def _write_jsonl(samples, path):
        # One JSON object per line
        with open(path, "w") as f:
            for s in samples:
                f.write(json.dumps(s) + "\n")

    _write_jsonl(all_samples[:split], out / "train.jsonl")
    _write_jsonl(all_samples[split:], out / "test.jsonl")

    for name, samples in [("train", all_samples[:split]), ("test", all_samples[split:])]:
        with open(out / f"{name}.csv", "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["input_image", "video", "prompt"])
            for s in samples:
                writer.writerow([
                    f"images/{s['image']}", f"videos/{s['video']}", s["prompt"]
                ])

    state.size_progress = progress
    state.seen_fingerprints = list(seen)
    state.all_samples = all_samples
    state.completed = True
    save_checkpoint(out, state, params)

    print(f"\n✓ Dataset complete: {out}/")
    print(f" Sizes: {sizes}")
    print(f" Mazes: {len(all_samples)}")
    print(f" Train: {split}, Test: {len(all_samples) - split}")
    lengths = [s["path_length"] for s in all_samples]
    fcounts = [s["frame_count"] for s in all_samples]
    print(f" Path lengths: avg={np.mean(lengths):.1f}, "
          f"min={min(lengths)}, max={max(lengths)}")
    print(f" Frame counts: avg={np.mean(fcounts):.1f}, "
          f"min={min(fcounts)}, max={max(fcounts)}")
348
+
349
+
350
+ # ==================== Eval: Video → Last Frame → Verify ====================
351
+
352
def eval_videos(
    video_dir: str,
    text_dir: str,
    output_json: Optional[str] = None,
    gt_json: Optional[str] = None,
):
    """
    Evaluate a directory of result videos against ground-truth mazes.

    Pipeline per video:
        1. Extract last frame from .mp4
        2. Detect red path via pixel analysis
        3. Convert to UDRL action string
        4. Verify against maze .txt (wall-respecting walk from start to end)

    Matching convention:
        Video ``<stem>.mp4`` → Text ``<stem>.txt`` in *text_dir*.
        Common stems: ``size8_000``, ``size16_042``, etc.

    Each maze is loaded and verified exactly once; the result is reused for
    the overall summary and the per-size breakdown (the previous version
    re-read every .txt file and re-ran verification a second time).

    Args:
        video_dir: Directory containing result .mp4 files.
        text_dir: Directory containing ground-truth maze .txt files.
        output_json: Path to save extracted paths as JSON (default: video_dir/0_result.json).
        gt_json: Optional ground-truth answer JSON for accuracy by path length.
    """
    proc = MazeProcessor()
    vid_root = Path(video_dir)
    txt_root = Path(text_dir)

    if output_json is None:
        output_json = str(vid_root / "0_result.json")

    # Collect videos in natural (numeric-aware) order
    videos = sorted(
        vid_root.glob("*.mp4"),
        key=lambda p: [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", p.stem)],
    )

    if not videos:
        print(f"No .mp4 files found in {vid_root}")
        return

    print(f"Found {len(videos)} result videos in {vid_root}")
    print(f"Text dir: {txt_root}")

    # --- Phase 1: extract paths and verify once per video ---
    extracted: Dict[str, str] = {}
    records: List[Dict] = []  # one per evaluated video: name/size/length/correct
    missing_txt = 0
    missing_frame = 0

    for vpath in tqdm(videos, desc="Extracting paths"):
        stem = vpath.stem  # e.g. "size8_000"
        txt_path = txt_root / f"{stem}.txt"

        if not txt_path.exists():
            missing_txt += 1
            continue

        maze = proc.load_text(str(txt_path))
        if maze is None:
            missing_txt += 1
            continue

        last_frame = extract_last_frame(str(vpath))
        if last_frame is None:
            missing_frame += 1
            continue

        udrl = proc.extract_path_from_pixels(
            last_frame,
            grid_raw=maze["grid_raw"],
            size=maze["size"],
            start=maze["start"],
        )
        name = f"{stem}.png"  # keyed by image name for consistency
        extracted[name] = udrl
        records.append({
            "name": name,
            "size": maze["size"],
            "length": len(udrl),
            "correct": proc.verify_path(maze["grid"], maze["start"], maze["end"], udrl),
        })

    # Save extracted paths
    with open(output_json, "w", encoding="utf-8") as f:
        json.dump(extracted, f, indent=4)
    print(f"\nExtracted paths saved to: {output_json}")

    # --- Phase 2: aggregate (verification already done above) ---
    total_valid = len(records)
    correct = sum(1 for r in records if r["correct"])
    correctly_solved = [
        {"name": r["name"], "length": r["length"]} for r in records if r["correct"]
    ]

    acc = (correct / total_valid * 100) if total_valid else 0

    print(f"\n{'=' * 50}")
    print("Evaluation Summary")
    print(f"{'=' * 50}")
    print(f"Total Videos : {len(videos)}")
    print(f"Missing .txt : {missing_txt}")
    print(f"Failed Frame Read : {missing_frame}")
    print(f"Evaluated : {total_valid}")
    print(f"Correctly Solved : {correct}")
    print(f"Accuracy : {acc:.2f}%")
    print(f"{'-' * 50}")

    # Breakdown by maze size (reuses per-video verdicts)
    size_stats: Dict[int, Dict[str, int]] = {}
    for r in records:
        stats = size_stats.setdefault(r["size"], {"total": 0, "correct": 0})
        stats["total"] += 1
        if r["correct"]:
            stats["correct"] += 1

    if size_stats:
        print("\nAccuracy by maze size:")
        for sz in sorted(size_stats):
            s = size_stats[sz]
            sz_acc = s["correct"] / s["total"] * 100 if s["total"] else 0
            print(f" Size {sz:3d}: {s['correct']:4d}/{s['total']:4d} ({sz_acc:.2f}%)")

    # Top longest correct
    correctly_solved.sort(key=lambda x: x["length"], reverse=True)
    if correctly_solved:
        print(f"\nTop 3 Longest Correct Paths:")
        for i, item in enumerate(correctly_solved[:3]):
            print(f" {i+1}. {item['name']} (length: {item['length']})")

    # Optional: compare with ground-truth JSON for path-length-binned accuracy
    if gt_json:
        _compare_with_gt(extracted, gt_json, txt_root, proc)

    print(f"{'=' * 50}")
496
+
497
+
498
def _compare_with_gt(
    extracted: Dict[str, str],
    gt_json_path: str,
    txt_root: Path,
    proc: MazeProcessor,
):
    """Report accuracy bucketed by ground-truth path length (decade bins)."""
    try:
        with open(gt_json_path) as fh:
            gt = json.load(fh)
    except Exception:
        print(f" Warning: could not load ground-truth JSON: {gt_json_path}")
        return

    # Bin label (e.g. " 10- 19") -> {total, correct}
    bins: Dict[str, Dict[str, int]] = {}
    for name, pred_udrl in extracted.items():
        if name not in gt:
            continue
        gt_len = len(gt[name])

        # Decade bin for the ground-truth path length
        decade = (gt_len // 10) * 10
        label = f"{decade:3d}-{decade + 9:3d}"
        tally = bins.setdefault(label, {"total": 0, "correct": 0})
        tally["total"] += 1

        stem = name.replace(".png", "")
        maze = proc.load_text(str(txt_root / f"{stem}.txt"))
        if maze and proc.verify_path(maze["grid"], maze["start"], maze["end"], pred_udrl):
            tally["correct"] += 1

    if bins:
        print("\nAccuracy by GT path length:")
        for label in sorted(bins):
            b = bins[label]
            pct = b["correct"] / b["total"] * 100 if b["total"] else 0
            print(f" Length {label}: {b['correct']:4d}/{b['total']:4d} ({pct:.2f}%)")
538
+
539
+
540
+ # ==================== Verify: Pre-extracted JSON ====================
541
+
542
def verify_results(json_file: str, text_dir: str):
    """
    Check pre-extracted UDRL predictions against ground-truth maze files.

    Args:
        json_file: Path to JSON with {name: udrl_string} predictions.
        text_dir: Directory containing maze .txt files.
    """
    proc = MazeProcessor()
    txt_root = Path(text_dir)

    with open(Path(json_file)) as fh:
        predictions = json.load(fh)

    correct = skipped = valid = 0

    for name, udrl in predictions.items():
        stem = name.replace(".png", "")
        maze = proc.load_text(str(txt_root / f"{stem}.txt"))
        if maze is None:
            skipped += 1
            continue
        valid += 1
        if proc.verify_path(maze["grid"], maze["start"], maze["end"], udrl):
            correct += 1

    acc = (correct / valid * 100) if valid else 0
    print(f"\n{'='*40}")
    print(f"Verification: {correct}/{valid} correct ({acc:.2f}%)")
    if skipped:
        print(f"Skipped: {skipped}")
    print(f"{'='*40}")
576
+
577
+
578
+ # ==================== CLI ====================
579
+
580
def parse_args():
    """Build the CLI (generate / eval / verify) and parse sys.argv."""
    parser = argparse.ArgumentParser(
        description="Maze video dataset: generate, eval, verify"
    )
    subparsers = parser.add_subparsers(dest="command", help="Sub-command")

    # --- generate ---
    gen_cmd = subparsers.add_parser("generate", help="Generate dataset")
    gen_cmd.add_argument("--output-dir", type=str, default="maze")
    gen_cmd.add_argument("--sizes", type=int, nargs="+", default=[8, 16, 24, 32])
    gen_cmd.add_argument("--num-per-size", type=int, nargs="+", default=[100, 500, 1000, 2000])
    gen_cmd.add_argument("--min-path-ratio", type=float, default=0.3,
                         help="Min path length as fraction of size²")
    gen_cmd.add_argument("--img-size", type=int, default=1024)
    gen_cmd.add_argument("--prompt", type=str,
                         default="Draw a continuous red line from the yellow dot "
                                 "to the blue dot, avoiding all walls.")
    gen_cmd.add_argument("--train-ratio", type=float, default=0.9)
    gen_cmd.add_argument("--n-start", type=int, default=2,
                         help="Hold frames at video start (blank puzzle)")
    gen_cmd.add_argument("--m-end", type=int, default=3,
                         help="Hold frames at video end (completed solution)")
    gen_cmd.add_argument("--frames", type=int, default=None,
                         help="Content frames per video (None=auto 1 per step)")
    gen_cmd.add_argument("--fps", type=int, default=10)
    gen_cmd.add_argument("--seed", type=int, default=42)
    gen_cmd.add_argument("--checkpoint-interval", type=int, default=50)

    # --- eval ---
    eval_cmd = subparsers.add_parser(
        "eval",
        help="Evaluate result videos (last frame → extract → verify)")
    eval_cmd.add_argument("video_dir", type=str,
                          help="Directory containing result .mp4 files")
    eval_cmd.add_argument("--text-dir", type=str, required=True,
                          help="Directory with ground-truth maze .txt files")
    eval_cmd.add_argument("--output-json", type=str, default=None,
                          help="Output JSON for extracted paths (default: video_dir/0_result.json)")
    eval_cmd.add_argument("--gt-json", type=str, default=None,
                          help="Optional ground-truth path.json for length-binned accuracy")

    # --- verify ---
    verify_cmd = subparsers.add_parser(
        "verify", help="Verify a pre-extracted JSON of UDRL paths")
    verify_cmd.add_argument("json_file", type=str)
    verify_cmd.add_argument("--text-dir", type=str, required=True,
                            help="Directory with maze .txt files")

    return parser.parse_args()
627
+
628
+
629
if __name__ == "__main__":
    cli_args = parse_args()
    command = cli_args.command

    if command == "generate":
        # Everything except the sub-command name maps onto generate_dataset()
        options = dict(vars(cli_args))
        options.pop("command")
        generate_dataset(**options)

    elif command == "eval":
        eval_videos(
            video_dir=cli_args.video_dir,
            text_dir=cli_args.text_dir,
            output_json=cli_args.output_json,
            gt_json=cli_args.gt_json,
        )

    elif command == "verify":
        verify_results(cli_args.json_file, cli_args.text_dir)

    else:
        # No sub-command given: show a short usage reminder
        print("Usage: python maze_video_gen.py {generate|eval|verify} [options]")
        print(" python maze_video_gen.py generate --help")
        print(" python maze_video_gen.py eval --help")
        print(" python maze_video_gen.py verify --help")
maze/maze_processor.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MazeProcessor - Maze generation, solving, rendering, and video frame generation.
3
+
4
+ Mirrors the SudokuProcessor pattern: a single class encapsulating all maze logic
5
+ including DFS generation, BFS solving, image/video rendering, path verification,
6
+ and text serialization.
7
+ """
8
+ import random
9
+ from collections import deque
10
+ from typing import List, Tuple, Optional, Dict
11
+
12
+ import numpy as np
13
+ from PIL import Image, ImageDraw
14
+
15
# Wall bitmask encoding (matches text file format)
WALL_MASKS = {"N": 1, "S": 2, "W": 4, "E": 8}
# Opposite of each cardinal direction — used when clearing the shared wall
# stored on both sides of a passage.
OPPOSITE = {"N": "S", "S": "N", "E": "W", "W": "E"}
# UDRL action -> (row delta, col delta, wall that must be open to move).
MOVES = {
    "U": (-1, 0, "N"),
    "D": (1, 0, "S"),
    "L": (0, -1, "W"),
    "R": (0, 1, "E"),
}
# Direction -> (row delta, col delta) for neighbour iteration.
NEIGHBORS = {
    "N": (-1, 0),
    "S": (1, 0),
    "E": (0, 1),
    "W": (0, -1),
}


# ======================== Grid Type ========================
# grid[r][c] = {"N": bool, "S": bool, "W": bool, "E": bool}
# True => wall present, False => passage open

Grid = List[List[Dict[str, bool]]]
37
+
38
+
39
+ class MazeProcessor:
40
+ """Handles maze generation, solving, image rendering, and video frame creation."""
41
+
42
+ def __init__(self, img_size: int = 512):
43
+ self.img_size = img_size
44
+
45
+ # Rendering colours (RGB)
46
+ self.bg_color = "black"
47
+ self.cell_color = "white"
48
+ self.wall_color = "black"
49
+ self.grid_color = (224, 224, 224)
50
+ self.start_color = "yellow"
51
+ self.end_color = "blue"
52
+ self.path_color = "red"
53
+
54
+ # ==================== Generation (DFS) ====================
55
+
56
+ @staticmethod
57
+ def _empty_grid(n: int) -> Grid:
58
+ """Create an n×n grid with all walls present."""
59
+ return [
60
+ [{"N": True, "E": True, "S": True, "W": True} for _ in range(n)]
61
+ for _ in range(n)
62
+ ]
63
+
64
+ @staticmethod
65
+ def _remove_wall(grid: Grid, r: int, c: int, direction: str) -> None:
66
+ """Remove wall between (r,c) and its neighbour in *direction*."""
67
+ dr, dc = NEIGHBORS[direction]
68
+ grid[r][c][direction] = False
69
+ grid[r + dr][c + dc][OPPOSITE[direction]] = False
70
+
71
+ def generate(
72
+ self, size: int, min_path_len: int = 1, max_attempts: int = 200
73
+ ) -> Tuple[Grid, Tuple[int, int], Tuple[int, int], np.ndarray]:
74
+ """
75
+ Generate a perfect maze and a start/end pair whose shortest path
76
+ length >= *min_path_len*.
77
+
78
+ Returns:
79
+ (grid, start, end, path) where path is an (L, 2) int array.
80
+ """
81
+ for _ in range(max_attempts):
82
+ grid = self._gen_dfs(size)
83
+ nodes = [(r, c) for r in range(size) for c in range(size)]
84
+ start, end = random.sample(nodes, 2)
85
+ path = self.solve_bfs(grid, start, end)
86
+ if path is not None and len(path) >= min_path_len:
87
+ return grid, tuple(start), tuple(end), path
88
+ raise RuntimeError(
89
+ f"Failed to generate maze (size={size}, min_path_len={min_path_len}) "
90
+ f"after {max_attempts} attempts."
91
+ )
92
+
93
+ def _gen_dfs(self, n: int) -> Grid:
94
+ """Randomised DFS (iterative) to carve a perfect maze."""
95
+ grid = self._empty_grid(n)
96
+ visited = [[False] * n for _ in range(n)]
97
+ sr, sc = random.randrange(n), random.randrange(n)
98
+ visited[sr][sc] = True
99
+ stack = [(sr, sc)]
100
+
101
+ while stack:
102
+ r, c = stack[-1]
103
+ nbrs = []
104
+ for d, (dr, dc) in NEIGHBORS.items():
105
+ nr, nc = r + dr, c + dc
106
+ if 0 <= nr < n and 0 <= nc < n and not visited[nr][nc]:
107
+ nbrs.append((d, nr, nc))
108
+ if nbrs:
109
+ d, nr, nc = random.choice(nbrs)
110
+ self._remove_wall(grid, r, c, d)
111
+ visited[nr][nc] = True
112
+ stack.append((nr, nc))
113
+ else:
114
+ stack.pop()
115
+ return grid
116
+
117
+ # ==================== Solving (BFS) ====================
118
+
119
+ def solve_bfs(
120
+ self, grid: Grid, start: Tuple[int, int], end: Tuple[int, int]
121
+ ) -> Optional[np.ndarray]:
122
+ """BFS shortest path. Returns (L,2) int ndarray or None."""
123
+ n = len(grid)
124
+ q: deque = deque([(start, [start])])
125
+ visited = {start}
126
+
127
+ while q:
128
+ (r, c), path = q.popleft()
129
+ if (r, c) == end:
130
+ return np.array(path, dtype=int)
131
+ cell = grid[r][c]
132
+ for d, (dr, dc) in NEIGHBORS.items():
133
+ nr, nc = r + dr, c + dc
134
+ if (
135
+ 0 <= nr < n
136
+ and 0 <= nc < n
137
+ and not cell[d]
138
+ and (nr, nc) not in visited
139
+ ):
140
+ visited.add((nr, nc))
141
+ q.append(((nr, nc), path + [(nr, nc)]))
142
+ return None
143
+
144
+ # ==================== Path ↔ UDRL ====================
145
+
146
+ @staticmethod
147
+ def path_to_udrl(path) -> str:
148
+ """Convert coordinate path to UDRL string."""
149
+ moves = []
150
+ for i in range(len(path) - 1):
151
+ r1, c1 = path[i]
152
+ r2, c2 = path[i + 1]
153
+ if r2 < r1:
154
+ moves.append("U")
155
+ elif r2 > r1:
156
+ moves.append("D")
157
+ elif c2 < c1:
158
+ moves.append("L")
159
+ else:
160
+ moves.append("R")
161
+ return "".join(moves)
162
+
163
+ # ==================== Verification ====================
164
+
165
+ def verify_path(self, grid: Grid, start: Tuple, end: Tuple, udrl: str) -> bool:
166
+ """Verify that *udrl* is a wall-respecting walk from *start* to *end*."""
167
+ n = len(grid)
168
+ r, c = start
169
+ for ch in udrl.replace(",", "").replace(" ", "").strip():
170
+ if ch not in MOVES:
171
+ continue
172
+ dr, dc, wall = MOVES[ch]
173
+ if grid[r][c][wall]:
174
+ return False
175
+ nr, nc = r + dr, c + dc
176
+ if not (0 <= nr < n and 0 <= nc < n):
177
+ return False
178
+ r, c = nr, nc
179
+ return (r, c) == end
180
+
181
+ # ==================== Text Encoding ====================
182
+
183
+ def encode_grid(self, grid: Grid) -> str:
184
+ """Encode grid to compact bitmask string (one int per cell, row-major)."""
185
+ rows = []
186
+ for row in grid:
187
+ vals = []
188
+ for cell in row:
189
+ v = 0
190
+ for d, mask in WALL_MASKS.items():
191
+ if cell[d]:
192
+ v |= mask
193
+ vals.append(str(v))
194
+ rows.append(" ".join(vals))
195
+ return "\n".join(rows)
196
+
197
+ def decode_grid(self, text_lines: List[str]) -> Grid:
198
+ """Decode bitmask text lines back to grid dicts."""
199
+ grid = []
200
+ for line in text_lines:
201
+ row = []
202
+ for val_s in line.split():
203
+ val = int(val_s)
204
+ row.append({d: bool(val & mask) for d, mask in WALL_MASKS.items()})
205
+ grid.append(row)
206
+ return grid
207
+
208
+ def save_text(self, filepath, grid: Grid, start: Tuple, end: Tuple) -> None:
209
+ """Save maze to compact text file."""
210
+ n = len(grid)
211
+ with open(filepath, "w") as f:
212
+ f.write(f"{n}\n{start[0]} {start[1]}\n{end[0]} {end[1]}\n")
213
+ f.write(self.encode_grid(grid) + "\n")
214
+
215
+ def load_text(self, filepath) -> Optional[Dict]:
216
+ """
217
+ Load maze from text file.
218
+
219
+ Returns dict with keys: size, start, end, grid (dict-based),
220
+ grid_raw (list[list[int]] bitmask). None on failure.
221
+ """
222
+ try:
223
+ with open(filepath) as f:
224
+ lines = [l.strip() for l in f if l.strip()]
225
+ n = int(lines[0])
226
+ sr, sc = map(int, lines[1].split())
227
+ er, ec = map(int, lines[2].split())
228
+ grid = self.decode_grid(lines[3 : 3 + n])
229
+ grid_raw: List[List[int]] = []
230
+ for r in range(n):
231
+ grid_raw.append(list(map(int, lines[3 + r].split())))
232
+ return {
233
+ "size": n,
234
+ "start": (sr, sc),
235
+ "end": (er, ec),
236
+ "grid": grid,
237
+ "grid_raw": grid_raw,
238
+ }
239
+ except Exception:
240
+ return None
241
+
242
+ def fingerprint(self, grid: Grid, start: Tuple, end: Tuple) -> str:
243
+ """Content fingerprint for deduplication."""
244
+ n = len(grid)
245
+ parts = [f"{n},{start[0]},{start[1]},{end[0]},{end[1]}"]
246
+ for row in grid:
247
+ for cell in row:
248
+ v = sum(WALL_MASKS[d] for d in WALL_MASKS if cell[d])
249
+ parts.append(str(v))
250
+ return "|".join(parts)
251
+
252
+ # ==================== Image Rendering ====================
253
+
254
+ def _layout(self, n: int):
255
+ """Compute rendering layout parameters."""
256
+ cell_f = float(self.img_size) / n
257
+ wall_f = cell_f / 4.0
258
+ half_f = wall_f / 2.0
259
+ grid_w = max(1, int(cell_f / 16.0))
260
+ return cell_f, wall_f, half_f, grid_w
261
+
262
+ def render(
263
+ self,
264
+ grid: Grid,
265
+ start: Tuple[int, int],
266
+ end: Tuple[int, int],
267
+ path: Optional[np.ndarray] = None,
268
+ path_steps: Optional[int] = None,
269
+ ) -> Image.Image:
270
+ """
271
+ Render maze as a PIL image.
272
+
273
+ Args:
274
+ grid: The maze grid.
275
+ start, end: Coordinates of start/end cells.
276
+ path: Full solution path (L, 2).
277
+ path_steps: Draw only the first *path_steps* segments (for video).
278
+
279
+ Returns:
280
+ PIL.Image (RGB, img_size × img_size).
281
+ """
282
+ n = len(grid)
283
+ cell_f, wall_f, half_f, grid_w = self._layout(n)
284
+
285
+ img = Image.new("RGB", (self.img_size, self.img_size), self.bg_color)
286
+ draw = ImageDraw.Draw(img)
287
+
288
+ # --- fill cells & open passages ---
289
+ for r in range(n):
290
+ for c in range(n):
291
+ x1 = c * cell_f + half_f
292
+ y1 = r * cell_f + half_f
293
+ x2 = (c + 1) * cell_f - half_f
294
+ y2 = (r + 1) * cell_f - half_f
295
+ draw.rectangle([(x1, y1), (x2, y2)], fill=self.cell_color)
296
+ cell = grid[r][c]
297
+ if not cell["S"] and r < n - 1:
298
+ draw.rectangle(
299
+ [(x1, y2), (x2, y2 + wall_f)], fill=self.cell_color
300
+ )
301
+ if not cell["E"] and c < n - 1:
302
+ draw.rectangle(
303
+ [(x2, y1), (x2 + wall_f, y2)], fill=self.cell_color
304
+ )
305
+
306
+ # --- subtle grid lines on open passages ---
307
+ for r in range(n):
308
+ for c in range(n):
309
+ if r < n - 1 and not grid[r][c]["S"]:
310
+ y = (r + 1) * cell_f
311
+ draw.line(
312
+ [(c * cell_f + half_f, y), ((c + 1) * cell_f - half_f, y)],
313
+ fill=self.grid_color, width=grid_w,
314
+ )
315
+ if c < n - 1 and not grid[r][c]["E"]:
316
+ x = (c + 1) * cell_f
317
+ draw.line(
318
+ [(x, r * cell_f + half_f), (x, (r + 1) * cell_f - half_f)],
319
+ fill=self.grid_color, width=grid_w,
320
+ )
321
+
322
+ # --- start / end dots ---
323
+ def _dot(rc, color):
324
+ rr, cc = rc
325
+ cx = cc * cell_f + cell_f / 2
326
+ cy = rr * cell_f + cell_f / 2
327
+ rad = max(2, int((cell_f - wall_f) * 0.25))
328
+ draw.ellipse([cx - rad, cy - rad, cx + rad, cy + rad], fill=color)
329
+
330
+ _dot(start, self.start_color)
331
+ _dot(end, self.end_color)
332
+
333
+ # --- solution path (optionally partial) ---
334
+ if path is not None and len(path) >= 2:
335
+ end_idx = (
336
+ len(path) if path_steps is None
337
+ else min(path_steps + 1, len(path))
338
+ )
339
+ if end_idx >= 2:
340
+ pts = [
341
+ (c * cell_f + cell_f / 2, r * cell_f + cell_f / 2)
342
+ for r, c in path[:end_idx]
343
+ ]
344
+ draw.line(
345
+ pts, fill=self.path_color,
346
+ width=max(1, int(wall_f)), joint="curve",
347
+ )
348
+
349
+ return img
350
+
351
+ # ==================== Video Frame Generation ====================
352
+
353
+ def generate_video_frames(
354
+ self,
355
+ grid: Grid,
356
+ start: Tuple[int, int],
357
+ end: Tuple[int, int],
358
+ path: np.ndarray,
359
+ n_start: int = 5,
360
+ m_end: int = 5,
361
+ frames: Optional[int] = None,
362
+ ) -> List[Image.Image]:
363
+ """
364
+ Generate progressive video frames showing the red line growing.
365
+
366
+ *frames* controls the number of **content frames** between holds:
367
+ - None → 1 per path step
368
+ - frames > steps → slow-motion
369
+ - frames < steps → fast-forward
370
+
371
+ Total length = n_start + content_frames + m_end.
372
+ """
373
+ n_steps = len(path) - 1
374
+ if n_steps <= 0:
375
+ blank = self.render(grid, start, end)
376
+ return [blank] * (n_start + m_end + 1)
377
+
378
+ content_frames = frames if frames is not None else n_steps
379
+ content_frames = max(1, content_frames)
380
+
381
+ result: List[Image.Image] = []
382
+
383
+ # Opening hold
384
+ blank = self.render(grid, start, end)
385
+ result.extend([blank.copy() for _ in range(n_start)])
386
+
387
+ # Content frames
388
+ if content_frames == n_steps:
389
+ for step in range(1, n_steps + 1):
390
+ result.append(
391
+ self.render(grid, start, end, path=path, path_steps=step)
392
+ )
393
+ elif content_frames > n_steps:
394
+ for step in range(1, n_steps + 1):
395
+ f_lo = (step - 1) * content_frames // n_steps
396
+ f_hi = step * content_frames // n_steps
397
+ count = f_hi - f_lo
398
+ frame_img = self.render(
399
+ grid, start, end, path=path, path_steps=step
400
+ )
401
+ result.append(frame_img)
402
+ if count > 1:
403
+ result.extend([frame_img.copy() for _ in range(count - 1)])
404
+ else:
405
+ for f in range(content_frames):
406
+ step = (f + 1) * n_steps // content_frames
407
+ result.append(
408
+ self.render(grid, start, end, path=path, path_steps=step)
409
+ )
410
+
411
+ # Closing hold
412
+ final = self.render(grid, start, end, path=path)
413
+ result.extend([final.copy() for _ in range(m_end)])
414
+
415
+ return result
416
+
417
+ # ==================== Red-Path Extraction ====================
418
+
419
+ def extract_path_from_pixels(
420
+ self,
421
+ pixels: np.ndarray,
422
+ grid_raw: List[List[int]],
423
+ size: int,
424
+ start: Tuple[int, int],
425
+ pixel_threshold: float = 0.01,
426
+ ) -> str:
427
+ """
428
+ Detect red path in an RGB pixel array and return UDRL.
429
+
430
+ Uses **floating-point** cell boundaries matching the renderer to avoid
431
+ misalignment on sizes that don't evenly divide the image (e.g. 24, 48).
432
+
433
+ Args:
434
+ pixels: (H, W, 3) uint8 RGB array.
435
+ grid_raw: Bitmask grid as list[list[int]].
436
+ size: Maze dimension n.
437
+ start: Start coordinate (r, c).
438
+ pixel_threshold: Min red-pixel fraction to mark a cell.
439
+
440
+ Returns:
441
+ UDRL action string.
442
+ """
443
+ img = Image.fromarray(pixels)
444
+ w, h = img.size
445
+ px = np.array(img, dtype=float)
446
+
447
+ r_ch, g_ch, b_ch = px[:, :, 0], px[:, :, 1], px[:, :, 2]
448
+ red_mask = (r_ch > 100) & (r_ch > g_ch * 1.2) & (r_ch > b_ch * 1.2)
449
+
450
+ # Use FLOAT cell size to match render() coordinate system exactly.
451
+ # Integer division (h // size) drifts by up to (size-1) * fractional
452
+ # pixels, causing the last cells to be completely misaligned.
453
+ cell_h_f = h / size
454
+ cell_w_f = w / size
455
+
456
+ path_grid = np.zeros((size, size), dtype=bool)
457
+ for r in range(size):
458
+ y0 = int(round(r * cell_h_f))
459
+ y1 = int(round((r + 1) * cell_h_f))
460
+ for c in range(size):
461
+ x0 = int(round(c * cell_w_f))
462
+ x1 = int(round((c + 1) * cell_w_f))
463
+ # Small inward margin to avoid wall / neighbour bleed-over
464
+ ch = y1 - y0
465
+ cw = x1 - x0
466
+ margin_y = max(1, int(ch * 0.15))
467
+ margin_x = max(1, int(cw * 0.15))
468
+ sub = red_mask[y0 + margin_y : y1 - margin_y,
469
+ x0 + margin_x : x1 - margin_x]
470
+ if sub.size > 0 and np.mean(sub) > pixel_threshold:
471
+ path_grid[r, c] = True
472
+
473
+ # Greedy walk from start, respecting maze walls
474
+ directions = [
475
+ ("R", MOVES["R"]),
476
+ ("D", MOVES["D"]),
477
+ ("L", MOVES["L"]),
478
+ ("U", MOVES["U"]),
479
+ ]
480
+ visited = {start}
481
+ cr, cc = start
482
+ actions: List[str] = []
483
+ for _ in range(size * size * 2):
484
+ found = False
485
+ wval = grid_raw[cr][cc]
486
+ for act, (dr, dc, wall_ch) in directions:
487
+ nr, nc = cr + dr, cc + dc
488
+ if 0 <= nr < size and 0 <= nc < size:
489
+ if (wval & WALL_MASKS[wall_ch]) != 0:
490
+ continue
491
+ if path_grid[nr, nc] and (nr, nc) not in visited:
492
+ visited.add((nr, nc))
493
+ actions.append(act)
494
+ cr, cc = nr, nc
495
+ found = True
496
+ break
497
+ if not found:
498
+ break
499
+ return "".join(actions)
500
+
501
+ def extract_path_from_image(
502
+ self, img_path: str, grid_raw: List[List[int]], size: int, start: Tuple
503
+ ) -> str:
504
+ """Extract UDRL from an image file (convenience wrapper)."""
505
+ try:
506
+ pixels = np.array(Image.open(img_path).convert("RGB"))
507
+ return self.extract_path_from_pixels(pixels, grid_raw, size, start)
508
+ except Exception:
509
+ return ""
510
+
511
+
512
if __name__ == "__main__":
    proc = MazeProcessor(img_size=512)

    # Smoke test: generate one small maze and round-trip its solution.
    grid, start, end, path = proc.generate(size=8, min_path_len=10)
    n_steps = len(path) - 1
    print(f"Maze 8×8 | path length {len(path)} | steps {n_steps}")
    print(f"UDRL: {proc.path_to_udrl(path)}")
    print(f"Verify: {proc.verify_path(grid, start, end, proc.path_to_udrl(path))}")

    # Render with and without the solution overlay.
    proc.render(grid, start, end).save("test_maze.png")
    proc.render(grid, start, end, path=path).save("test_maze_solution.png")

    # Exercise all three video pacing modes and check frame counts.
    f1 = proc.generate_video_frames(grid, start, end, path, n_start=3, m_end=3)
    assert len(f1) == 3 + n_steps + 3

    f2 = proc.generate_video_frames(
        grid, start, end, path, n_start=3, m_end=3, frames=n_steps * 3
    )
    assert len(f2) == 3 + n_steps * 3 + 3

    half = max(1, n_steps // 2)
    f3 = proc.generate_video_frames(
        grid, start, end, path, n_start=3, m_end=3, frames=half
    )
    assert len(f3) == 3 + half + 3

    print(f"frames=None → {len(f1)} total ({n_steps} content)")
    print(f"frames={n_steps*3:<4d} → {len(f2)} total (slow-mo)")
    print(f"frames={half:<4d} → {len(f3)} total (fast-fwd)")
    print("All assertions passed ✓")
sudoku/generate_dataset.py CHANGED
@@ -1,6 +1,11 @@
1
  """
2
  Sudoku Video Dataset Generator - Supports flexible solution count expressions per puzzle.
3
  With checkpoint/resume support via metadata.json.
 
 
 
 
 
4
  """
5
  import json
6
  import re
@@ -8,7 +13,7 @@ import random
8
  import argparse
9
  from dataclasses import dataclass, asdict
10
  from pathlib import Path
11
- from typing import List, Tuple, Optional, Union, Dict, Any
12
  import numpy as np
13
  import cv2
14
  from tqdm import tqdm
@@ -22,7 +27,7 @@ class SolRange:
22
  """Flexible solution count constraint for puzzle generation."""
23
  min_sol: int
24
  max_sol: Optional[int]
25
-
26
  @classmethod
27
  def parse(cls, expr: str) -> "SolRange":
28
  expr = expr.strip()
@@ -46,7 +51,7 @@ class SolRange:
46
  if n < 1: raise ValueError(f"sol_num must be >= 1, got {n}")
47
  return cls(min_sol=n, max_sol=n)
48
  raise ValueError(f"Invalid sol_num expression: '{expr}'")
49
-
50
  @property
51
  def is_exact(self): return self.max_sol is not None and self.min_sol == self.max_sol
52
  @property
@@ -77,10 +82,10 @@ class GenerationState:
77
  seen_grids: List[str]
78
  all_samples: List[Dict]
79
  completed: bool = False
80
-
81
  def to_dict(self) -> Dict:
82
  return asdict(self)
83
-
84
  @classmethod
85
  def from_dict(cls, d: Dict) -> "GenerationState":
86
  return cls(**d)
@@ -89,9 +94,7 @@ class GenerationState:
89
  def compute_params_hash(params: Dict) -> str:
90
  """Compute hash of generation parameters for consistency check."""
91
  import hashlib
92
- # Only hash parameters that affect generation logic
93
- key_params = {k: v for k, v in params.items()
94
- if k not in ['output_dir']} # output_dir can differ
95
  return hashlib.md5(json.dumps(key_params, sort_keys=True).encode()).hexdigest()[:12]
96
 
97
 
@@ -100,21 +103,16 @@ def load_checkpoint(output_dir: Path, params: Dict) -> Optional[GenerationState]
100
  meta_path = output_dir / "metadata.json"
101
  if not meta_path.exists():
102
  return None
103
-
104
  with open(meta_path) as f:
105
  data = json.load(f)
106
-
107
  state = GenerationState.from_dict(data["state"])
108
  expected_hash = compute_params_hash(params)
109
-
110
  if state.params_hash != expected_hash:
111
  print(f"⚠️ Parameters changed (hash {state.params_hash} → {expected_hash}), starting fresh")
112
  return None
113
-
114
  if state.completed:
115
  print("✓ Generation already completed")
116
  return state
117
-
118
  print(f"✓ Resuming from checkpoint: {sum(state.clue_progress.values())} puzzles generated")
119
  return state
120
 
@@ -122,65 +120,136 @@ def load_checkpoint(output_dir: Path, params: Dict) -> Optional[GenerationState]
122
  def save_checkpoint(output_dir: Path, state: GenerationState, params: Dict):
123
  """Save current generation state to metadata.json."""
124
  meta_path = output_dir / "metadata.json"
125
- data = {
126
- "params": params,
127
- "state": state.to_dict()
128
- }
129
- # Atomic write
130
  tmp_path = meta_path.with_suffix('.tmp')
131
  with open(tmp_path, 'w') as f:
132
- json.dump(data, f, indent=2)
133
  tmp_path.rename(meta_path)
134
 
135
 
136
  # ==================== Core Functions ====================
137
 
138
  def get_fill_order(puzzle, solution):
 
139
  return [(i, j, solution[i][j]) for i in range(9) for j in range(9) if puzzle[i][j] == 0]
140
 
 
141
  def create_processor(resolution=None):
142
- if resolution is None: return SudokuProcessor()
 
 
143
  target_size = min(resolution)
144
  cell_size = target_size // 9
145
  sf = cell_size / 60
146
- return SudokuProcessor(cell_size=cell_size, font_scale=1.2*sf, thickness=max(1, int(2*sf)))
 
 
 
147
 
148
- def generate_video_frames(proc, puzzle, solution, n_start, m_end, k=1, max_frames=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  fills = get_fill_order(puzzle, solution)
150
  n_fills = len(fills)
151
- effective_k = k
152
- if max_frames is not None and n_start + n_fills * k + m_end > max_frames:
153
- avail = max_frames - n_start - m_end
154
- effective_k = max(1, avail // n_fills) if avail > 0 and n_fills > 0 else 1
155
-
156
- frames = []
 
 
 
157
  current = [row[:] for row in puzzle]
 
 
158
  img = proc.render(current)
159
- frames.extend([img.copy() for _ in range(n_start)])
160
-
161
- for r, c, v in fills:
162
- current[r][c] = v
163
- frames.append(proc.render(current, highlight_new=(r, c), original=puzzle))
164
- if effective_k > 1:
165
- img = proc.render(current, original=puzzle)
166
- frames.extend([img.copy() for _ in range(effective_k - 1)])
167
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  img = proc.render(solution, original=puzzle)
169
- frames.extend([img.copy() for _ in range(m_end)])
170
- if max_frames is not None and len(frames) > max_frames:
171
- frames = frames[:max_frames]
172
- return frames
173
 
174
  def save_video(frames, path, fps=10):
 
175
  h, w = frames[0].shape[:2]
176
  writer = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
177
- for f in frames: writer.write(cv2.cvtColor(f, cv2.COLOR_RGB2BGR))
 
178
  writer.release()
179
 
 
180
  def normalize_num_per_clue(num_per_clue, clue_levels):
181
- if isinstance(num_per_clue, int): return [num_per_clue] * len(clue_levels)
 
 
182
  if len(num_per_clue) != len(clue_levels):
183
- raise ValueError(f"num_per_clue length ({len(num_per_clue)}) != clue_levels ({len(clue_levels)})")
 
 
184
  return num_per_clue
185
 
186
 
@@ -191,7 +260,7 @@ def generate_puzzle_with_range(proc, clue, sol_range, min_hamming):
191
  if sol_range.is_unique_only:
192
  puzzle, solution = proc.generate(clue, unique=True)
193
  return puzzle, [solution]
194
-
195
  if sol_range.requires_multi:
196
  try:
197
  puzzle, solutions = proc.generate_multi_solution(
@@ -204,7 +273,7 @@ def generate_puzzle_with_range(proc, clue, sol_range, min_hamming):
204
  except RuntimeError:
205
  pass
206
  return None
207
-
208
  try:
209
  puzzle, solutions = proc.generate_multi_solution(
210
  clue, min_solutions=max(2, sol_range.min_sol),
@@ -215,7 +284,7 @@ def generate_puzzle_with_range(proc, clue, sol_range, min_hamming):
215
  return puzzle, solutions
216
  except RuntimeError:
217
  pass
218
-
219
  if sol_range.allows_unique:
220
  puzzle, solution = proc.generate(clue, unique=True)
221
  return puzzle, [solution]
@@ -225,45 +294,48 @@ def generate_puzzle_with_range(proc, clue, sol_range, min_hamming):
225
  # ==================== Dataset Generation ====================
226
 
227
  def generate_dataset(
228
- output_dir="sudoku_video", clue_levels=[30,40,50,60], num_per_clue=50,
229
- sol_num="1", min_hamming=10, train_ratio=0.8,
 
230
  prompt="Solve this Sudoku puzzle using red font.",
231
- n_start=10, m_end=10, k=1, max_frames=None, fps=10,
232
  resolution=None, seed=42, checkpoint_interval=50
233
  ):
234
  """
235
  Generate Sudoku video dataset with checkpoint/resume support.
236
-
 
 
 
 
237
  Args:
238
  checkpoint_interval: Save checkpoint every N puzzles (default: 50)
239
  """
240
- # Prepare params dict for hashing
241
  params = {
242
  "clue_levels": clue_levels, "num_per_clue": num_per_clue,
243
  "sol_num": sol_num, "min_hamming": min_hamming, "train_ratio": train_ratio,
244
- "prompt": prompt, "n_start": n_start, "m_end": m_end, "k": k,
245
- "max_frames": max_frames, "fps": fps, "resolution": resolution, "seed": seed
246
  }
247
-
248
  output_dir = Path(output_dir)
249
  video_dir = output_dir / "videos"
250
  image_dir = output_dir / "images"
251
  video_dir.mkdir(parents=True, exist_ok=True)
252
  image_dir.mkdir(parents=True, exist_ok=True)
253
-
254
  # Try to resume from checkpoint
255
  state = load_checkpoint(output_dir, params)
256
-
257
  if state and state.completed:
258
- return # Already done
259
-
260
  sol_range = SolRange.parse(str(sol_num))
261
  proc = create_processor(resolution)
262
  actual_size = proc.img_size
263
  num_per_clue_list = normalize_num_per_clue(num_per_clue, clue_levels)
264
  max_puzzles = max(num_per_clue_list)
265
  num_width = len(str(max_puzzles))
266
-
267
  # Initialize or restore state
268
  if state is None:
269
  random.seed(seed)
@@ -274,138 +346,162 @@ def generate_dataset(
274
  all_samples=[]
275
  )
276
  print(f"Starting fresh generation with solution range: {sol_range}")
 
 
277
  else:
278
- # Restore RNG state approximately by fast-forwarding
279
  random.seed(seed)
280
  for _ in range(sum(state.clue_progress.values()) * 10):
281
  random.random()
282
-
283
  seen_grids = set(state.seen_grids)
284
  all_samples = state.all_samples.copy()
285
  clue_progress = {int(k): v for k, v in state.clue_progress.items()}
286
-
287
  total_target = sum(num_per_clue_list)
288
  total_done = sum(clue_progress.values())
289
  stats_unique = sum(1 for s in all_samples if s["total_solutions"] == 1 and s["sol_idx"] == 0)
290
  stats_multi = sum(1 for s in all_samples if s["total_solutions"] > 1 and s["sol_idx"] == 0)
291
  puzzles_since_checkpoint = 0
292
-
293
  with tqdm(total=total_target, initial=total_done, desc="Total", unit="puzzle") as pbar_total:
294
  for clue, target_count in zip(clue_levels, num_per_clue_list):
295
  generated = clue_progress.get(clue, 0)
296
  if generated >= target_count:
297
- continue # This clue level is done
298
-
299
  max_attempts = (target_count - generated) * 20
300
-
301
- with tqdm(total=target_count, initial=generated, desc=f"Clue {clue:2d}",
302
  unit="puzzle", leave=False) as pbar_clue:
303
  for _ in range(max_attempts):
304
  if generated >= target_count:
305
  break
306
-
307
  result = generate_puzzle_with_range(proc, clue, sol_range, min_hamming)
308
  if result is None:
309
  continue
310
  puzzle, solutions = result
311
-
312
  fp = proc.encode(puzzle)
313
  if fp in seen_grids:
314
  continue
315
  seen_grids.add(fp)
316
-
317
  n_sols = len(solutions)
318
  if n_sols == 1:
319
  stats_unique += 1
320
  else:
321
  stats_multi += 1
322
-
323
  img_name = f"clue{clue}_{generated:0{num_width}d}.png"
324
  puzzle_img = proc.render(puzzle)
325
- cv2.imwrite(str(image_dir / img_name), cv2.cvtColor(puzzle_img, cv2.COLOR_RGB2BGR))
326
-
 
 
 
327
  for si, sol in enumerate(solutions):
328
  vid_name = f"clue{clue}_{generated:0{num_width}d}_sol{si}.mp4"
329
- frames = generate_video_frames(proc, puzzle, sol, n_start, m_end, k, max_frames)
330
- save_video(frames, video_dir / vid_name, fps)
331
-
332
- hdists = [proc._hamming(sol, solutions[j]) for j in range(n_sols) if j != si]
 
 
 
 
 
333
  all_samples.append({
334
  "prompt": prompt, "video": vid_name, "image": img_name,
335
  "clue": clue, "puzzle": fp, "solution": proc.encode(sol),
336
  "sol_idx": si, "total_solutions": n_sols,
337
- "frame_count": len(frames),
338
- "min_hamming_to_others": min(hdists) if hdists else 0
339
  })
340
-
341
  generated += 1
342
  clue_progress[clue] = generated
343
  puzzles_since_checkpoint += 1
344
  pbar_clue.update(1)
345
  pbar_total.update(1)
346
-
347
- # Periodic checkpoint
348
  if puzzles_since_checkpoint >= checkpoint_interval:
349
  state.clue_progress = clue_progress
350
  state.seen_grids = list(seen_grids)
351
  state.all_samples = all_samples
352
  save_checkpoint(output_dir, state, params)
353
  puzzles_since_checkpoint = 0
354
-
355
- tqdm.write(f"Clue {clue}: {generated} puzzles, "
356
- f"{sum(1 for s in all_samples if s['clue'] == clue)} videos")
357
-
 
 
358
  # Final output
359
- random.seed(seed + 1) # Deterministic shuffle
360
  random.shuffle(all_samples)
361
  split_idx = int(len(all_samples) * train_ratio)
362
-
363
  def write_jsonl(samples, path):
364
  with open(path, 'w') as f:
365
  for s in samples:
366
  json.dump(s, f)
367
  f.write('\n')
368
-
369
  write_jsonl(all_samples[:split_idx], output_dir / "train.jsonl")
370
  write_jsonl(all_samples[split_idx:], output_dir / "test.jsonl")
371
-
372
  # Mark as completed
373
  state.clue_progress = clue_progress
374
  state.seen_grids = list(seen_grids)
375
  state.all_samples = all_samples
376
  state.completed = True
377
  save_checkpoint(output_dir, state, params)
378
-
379
  print(f"\n✓ Dataset complete: {output_dir}/")
380
  print(f" Resolution: {actual_size}x{actual_size}")
381
  print(f" Solution range: {sol_range}")
382
  print(f" Puzzles: {len(seen_grids)} ({stats_unique} unique, {stats_multi} multi-sol)")
383
  print(f" Videos: {len(all_samples)}")
384
  print(f" Train: {split_idx}, Test: {len(all_samples) - split_idx}")
385
-
 
 
 
 
386
  hammings = [s["min_hamming_to_others"] for s in all_samples if s["min_hamming_to_others"] > 0]
387
  if hammings:
388
- print(f" Solution diversity: avg={np.mean(hammings):.1f}, min={min(hammings)}, max={max(hammings)}")
 
389
 
390
 
391
  def parse_resolution(s):
392
  w, h = map(int, s.lower().split('x'))
393
  return (w, h)
394
 
 
395
  def parse_args():
396
- p = argparse.ArgumentParser(description="Generate Sudoku video dataset with resume support")
 
 
397
  p.add_argument("--output-dir", type=str, default="sudoku")
398
- p.add_argument("--clue-levels", type=int, nargs="+", default=[20,30,40,50,60,70])
399
- p.add_argument("--num-per-clue", type=int, nargs="+", default=[15000,10000,10000,5000,2000,1000])
 
 
400
  p.add_argument("--sol-num", type=str, default="<=3",
401
  help="'1', '3', '>=1', '>1', '<=3', '<3', '2-5'")
402
  p.add_argument("--min-hamming", type=int, default=10)
403
  p.add_argument("--train-ratio", type=float, default=0.9)
404
- p.add_argument("--prompt", type=str, default="Solve this Sudoku puzzle using red font.")
405
- p.add_argument("--n-start", type=int, default=2)
406
- p.add_argument("--m-end", type=int, default=3)
407
- p.add_argument("--k", type=int, default=1)
408
- p.add_argument("--max-frames", type=int, default=None)
 
 
 
 
409
  p.add_argument("--fps", type=int, default=10)
410
  p.add_argument("--resolution", type=str, default="1024x1024")
411
  p.add_argument("--seed", type=int, default=42)
 
1
  """
2
  Sudoku Video Dataset Generator - Supports flexible solution count expressions per puzzle.
3
  With checkpoint/resume support via metadata.json.
4
+
5
+ The *frames* parameter replaces the old max_frames + k pair:
6
+ - frames=None → 1 content frame per fill step (variable length)
7
+ - frames=N → exactly N content frames; fills distributed evenly
8
+ (slow-motion if N > fills, fast-forward if N < fills)
9
  """
10
  import json
11
  import re
 
13
  import argparse
14
  from dataclasses import dataclass, asdict
15
  from pathlib import Path
16
+ from typing import List, Tuple, Optional, Dict
17
  import numpy as np
18
  import cv2
19
  from tqdm import tqdm
 
27
  """Flexible solution count constraint for puzzle generation."""
28
  min_sol: int
29
  max_sol: Optional[int]
30
+
31
  @classmethod
32
  def parse(cls, expr: str) -> "SolRange":
33
  expr = expr.strip()
 
51
  if n < 1: raise ValueError(f"sol_num must be >= 1, got {n}")
52
  return cls(min_sol=n, max_sol=n)
53
  raise ValueError(f"Invalid sol_num expression: '{expr}'")
54
+
55
  @property
56
  def is_exact(self): return self.max_sol is not None and self.min_sol == self.max_sol
57
  @property
 
82
  seen_grids: List[str]
83
  all_samples: List[Dict]
84
  completed: bool = False
85
+
86
  def to_dict(self) -> Dict:
87
  return asdict(self)
88
+
89
  @classmethod
90
  def from_dict(cls, d: Dict) -> "GenerationState":
91
  return cls(**d)
 
94
def compute_params_hash(params: Dict) -> str:
    """Compute a short hash of generation parameters for consistency checks.

    ``output_dir`` is excluded so the same generation run can resume from a
    different location without invalidating the checkpoint.
    """
    import hashlib
    relevant = {k: v for k, v in params.items() if k != 'output_dir'}
    digest = hashlib.md5(json.dumps(relevant, sort_keys=True).encode())
    return digest.hexdigest()[:12]
99
 
100
 
 
103
  meta_path = output_dir / "metadata.json"
104
  if not meta_path.exists():
105
  return None
 
106
  with open(meta_path) as f:
107
  data = json.load(f)
 
108
  state = GenerationState.from_dict(data["state"])
109
  expected_hash = compute_params_hash(params)
 
110
  if state.params_hash != expected_hash:
111
  print(f"⚠️ Parameters changed (hash {state.params_hash} → {expected_hash}), starting fresh")
112
  return None
 
113
  if state.completed:
114
  print("✓ Generation already completed")
115
  return state
 
116
  print(f"✓ Resuming from checkpoint: {sum(state.clue_progress.values())} puzzles generated")
117
  return state
118
 
 
120
def save_checkpoint(output_dir: Path, state: GenerationState, params: Dict):
    """Atomically persist the current generation state to metadata.json.

    Writes to a ``.tmp`` file first and renames it over the target, so a
    crash mid-write cannot leave a truncated metadata.json behind.
    """
    meta_path = output_dir / "metadata.json"
    tmp_path = meta_path.with_suffix('.tmp')
    payload = {"params": params, "state": state.to_dict()}
    with open(tmp_path, 'w') as fh:
        json.dump(payload, fh, indent=2)
    tmp_path.rename(meta_path)
127
 
128
 
129
  # ==================== Core Functions ====================
130
 
131
def get_fill_order(puzzle, solution):
    """Return (row, col, value) triples for every empty cell, row-major."""
    order = []
    for i in range(9):
        for j in range(9):
            if puzzle[i][j] == 0:
                order.append((i, j, solution[i][j]))
    return order
134
 
135
+
136
def create_processor(resolution=None):
    """Create a SudokuProcessor, scaling geometry to *resolution*.

    Args:
        resolution: Optional (width, height); the smaller dimension is
            divided into 9 cells and fonts/line widths scale accordingly.
            None uses the processor defaults.
    """
    if resolution is None:
        return SudokuProcessor()
    cell_size = min(resolution) // 9
    # 60 px is the processor's reference cell size for font/thickness.
    sf = cell_size / 60
    return SudokuProcessor(
        cell_size=cell_size, font_scale=1.2 * sf, thickness=max(1, int(2 * sf))
    )
146
+
147
 
148
def generate_video_frames(proc, puzzle, solution, n_start, m_end, frames=None):
    """
    Generate progressive video frames for a Sudoku solve.

    The *frames* parameter controls the number of **content frames**
    (between the opening and closing holds):

    - frames=None → 1 content frame per fill step (n_fills total)
    - frames > fills → multiple frames per fill step (slow-motion)
    - frames < fills → multiple fills per frame (fast-forward)
    - frames == fills → identical to None

    Total output length = n_start + content_frames + m_end.

    Args:
        proc: SudokuProcessor instance.
        puzzle: 9×9 puzzle grid (0 = empty).
        solution: 9×9 solved grid.
        n_start: Hold frames for the puzzle at the beginning.
        m_end: Hold frames for the completed solution at the end.
        frames: Desired number of content frames (None = one per fill).

    Returns:
        List of numpy arrays (RGB images).
    """
    fills = get_fill_order(puzzle, solution)
    n_fills = len(fills)

    if n_fills == 0:
        # Nothing to animate: hold the solved board the whole time.
        still = proc.render(solution, original=puzzle)
        return [still.copy() for _ in range(n_start + m_end + 1)]

    n_content = n_fills if frames is None else max(1, frames)

    out = []
    board = [row[:] for row in puzzle]

    # --- opening hold: untouched puzzle ---
    opening = proc.render(board)
    out.extend(opening.copy() for _ in range(n_start))

    # --- content frames ---
    if n_content == n_fills:
        # Exact 1:1 mapping.
        for r, c, v in fills:
            board[r][c] = v
            out.append(proc.render(board, highlight_new=(r, c), original=puzzle))

    elif n_content > n_fills:
        # Slow-motion: integer-division bounds distribute n_content frames
        # across the fills so the repeats sum to exactly n_content.
        for i, (r, c, v) in enumerate(fills):
            board[r][c] = v
            span = (i + 1) * n_content // n_fills - i * n_content // n_fills

            # First frame of this step shows the highlight.
            out.append(proc.render(board, highlight_new=(r, c), original=puzzle))
            # Remaining hold frames (no highlight).
            if span > 1:
                hold = proc.render(board, original=puzzle)
                out.extend(hold.copy() for _ in range(span - 1))

    else:
        # Fast-forward: each content frame absorbs several fills.
        for f in range(n_content):
            lo = f * n_fills // n_content
            hi = (f + 1) * n_fills // n_content
            last = None
            for r, c, v in fills[lo:hi]:
                board[r][c] = v
                last = (r, c)
            if last is not None:
                out.append(
                    proc.render(board, highlight_new=last, original=puzzle)
                )
            else:
                out.append(proc.render(board, original=puzzle))

    # --- closing hold: fully solved board ---
    closing = proc.render(solution, original=puzzle)
    out.extend(closing.copy() for _ in range(m_end))

    return out
234
+
235
 
236
def save_video(frames, path, fps=10):
    """Encode a list of numpy RGB frames as an mp4 file at *path*."""
    height, width = frames[0].shape[:2]
    writer = cv2.VideoWriter(
        str(path), cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)
    )
    # OpenCV expects BGR ordering.
    for frame in frames:
        writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    writer.release()
243
 
244
+
245
def normalize_num_per_clue(num_per_clue, clue_levels):
    """Broadcast a single int across all clue levels, or validate list length.

    Raises:
        ValueError: if a sequence is given whose length differs from
            ``clue_levels``.
    """
    if isinstance(num_per_clue, int):
        return [num_per_clue for _ in clue_levels]
    if len(num_per_clue) == len(clue_levels):
        return num_per_clue
    raise ValueError(
        f"num_per_clue length ({len(num_per_clue)}) != clue_levels ({len(clue_levels)})"
    )
254
 
255
 
 
260
  if sol_range.is_unique_only:
261
  puzzle, solution = proc.generate(clue, unique=True)
262
  return puzzle, [solution]
263
+
264
  if sol_range.requires_multi:
265
  try:
266
  puzzle, solutions = proc.generate_multi_solution(
 
273
  except RuntimeError:
274
  pass
275
  return None
276
+
277
  try:
278
  puzzle, solutions = proc.generate_multi_solution(
279
  clue, min_solutions=max(2, sol_range.min_sol),
 
284
  return puzzle, solutions
285
  except RuntimeError:
286
  pass
287
+
288
  if sol_range.allows_unique:
289
  puzzle, solution = proc.generate(clue, unique=True)
290
  return puzzle, [solution]
 
294
  # ==================== Dataset Generation ====================
295
 
296
  def generate_dataset(
297
+ output_dir="sudoku", clue_levels=[20, 30, 40, 50, 60, 70],
298
+ num_per_clue=[15000, 10000, 10000, 5000, 2000, 1000],
299
+ sol_num="<=3", min_hamming=10, train_ratio=0.9,
300
  prompt="Solve this Sudoku puzzle using red font.",
301
+ n_start=2, m_end=3, frames=None, fps=10,
302
  resolution=None, seed=42, checkpoint_interval=50
303
  ):
304
  """
305
  Generate Sudoku video dataset with checkpoint/resume support.
306
+
307
+ The *frames* parameter controls the number of **content frames** per video:
308
+ - None → one content frame per fill step (variable length per puzzle)
309
+ - N > 0 → exactly N content frames; fills distributed evenly
310
+
311
  Args:
312
  checkpoint_interval: Save checkpoint every N puzzles (default: 50)
313
  """
 
314
  params = {
315
  "clue_levels": clue_levels, "num_per_clue": num_per_clue,
316
  "sol_num": sol_num, "min_hamming": min_hamming, "train_ratio": train_ratio,
317
+ "prompt": prompt, "n_start": n_start, "m_end": m_end, "frames": frames,
318
+ "fps": fps, "resolution": resolution, "seed": seed
319
  }
320
+
321
  output_dir = Path(output_dir)
322
  video_dir = output_dir / "videos"
323
  image_dir = output_dir / "images"
324
  video_dir.mkdir(parents=True, exist_ok=True)
325
  image_dir.mkdir(parents=True, exist_ok=True)
326
+
327
  # Try to resume from checkpoint
328
  state = load_checkpoint(output_dir, params)
 
329
  if state and state.completed:
330
+ return
331
+
332
  sol_range = SolRange.parse(str(sol_num))
333
  proc = create_processor(resolution)
334
  actual_size = proc.img_size
335
  num_per_clue_list = normalize_num_per_clue(num_per_clue, clue_levels)
336
  max_puzzles = max(num_per_clue_list)
337
  num_width = len(str(max_puzzles))
338
+
339
  # Initialize or restore state
340
  if state is None:
341
  random.seed(seed)
 
346
  all_samples=[]
347
  )
348
  print(f"Starting fresh generation with solution range: {sol_range}")
349
+ print(f" frames={'auto (1 per fill)' if frames is None else frames}, "
350
+ f"n_start={n_start}, m_end={m_end}, fps={fps}")
351
  else:
 
352
  random.seed(seed)
353
  for _ in range(sum(state.clue_progress.values()) * 10):
354
  random.random()
355
+
356
  seen_grids = set(state.seen_grids)
357
  all_samples = state.all_samples.copy()
358
  clue_progress = {int(k): v for k, v in state.clue_progress.items()}
359
+
360
  total_target = sum(num_per_clue_list)
361
  total_done = sum(clue_progress.values())
362
  stats_unique = sum(1 for s in all_samples if s["total_solutions"] == 1 and s["sol_idx"] == 0)
363
  stats_multi = sum(1 for s in all_samples if s["total_solutions"] > 1 and s["sol_idx"] == 0)
364
  puzzles_since_checkpoint = 0
365
+
366
  with tqdm(total=total_target, initial=total_done, desc="Total", unit="puzzle") as pbar_total:
367
  for clue, target_count in zip(clue_levels, num_per_clue_list):
368
  generated = clue_progress.get(clue, 0)
369
  if generated >= target_count:
370
+ continue
371
+
372
  max_attempts = (target_count - generated) * 20
373
+
374
+ with tqdm(total=target_count, initial=generated, desc=f"Clue {clue:2d}",
375
  unit="puzzle", leave=False) as pbar_clue:
376
  for _ in range(max_attempts):
377
  if generated >= target_count:
378
  break
379
+
380
  result = generate_puzzle_with_range(proc, clue, sol_range, min_hamming)
381
  if result is None:
382
  continue
383
  puzzle, solutions = result
384
+
385
  fp = proc.encode(puzzle)
386
  if fp in seen_grids:
387
  continue
388
  seen_grids.add(fp)
389
+
390
  n_sols = len(solutions)
391
  if n_sols == 1:
392
  stats_unique += 1
393
  else:
394
  stats_multi += 1
395
+
396
  img_name = f"clue{clue}_{generated:0{num_width}d}.png"
397
  puzzle_img = proc.render(puzzle)
398
+ cv2.imwrite(
399
+ str(image_dir / img_name),
400
+ cv2.cvtColor(puzzle_img, cv2.COLOR_RGB2BGR),
401
+ )
402
+
403
  for si, sol in enumerate(solutions):
404
  vid_name = f"clue{clue}_{generated:0{num_width}d}_sol{si}.mp4"
405
+ vid_frames = generate_video_frames(
406
+ proc, puzzle, sol, n_start, m_end, frames
407
+ )
408
+ save_video(vid_frames, video_dir / vid_name, fps)
409
+
410
+ hdists = [
411
+ proc._hamming(sol, solutions[j])
412
+ for j in range(n_sols) if j != si
413
+ ]
414
  all_samples.append({
415
  "prompt": prompt, "video": vid_name, "image": img_name,
416
  "clue": clue, "puzzle": fp, "solution": proc.encode(sol),
417
  "sol_idx": si, "total_solutions": n_sols,
418
+ "frame_count": len(vid_frames),
419
+ "min_hamming_to_others": min(hdists) if hdists else 0,
420
  })
421
+
422
  generated += 1
423
  clue_progress[clue] = generated
424
  puzzles_since_checkpoint += 1
425
  pbar_clue.update(1)
426
  pbar_total.update(1)
427
+
 
428
  if puzzles_since_checkpoint >= checkpoint_interval:
429
  state.clue_progress = clue_progress
430
  state.seen_grids = list(seen_grids)
431
  state.all_samples = all_samples
432
  save_checkpoint(output_dir, state, params)
433
  puzzles_since_checkpoint = 0
434
+
435
+ tqdm.write(
436
+ f"Clue {clue}: {generated} puzzles, "
437
+ f"{sum(1 for s in all_samples if s['clue'] == clue)} videos"
438
+ )
439
+
440
  # Final output
441
+ random.seed(seed + 1)
442
  random.shuffle(all_samples)
443
  split_idx = int(len(all_samples) * train_ratio)
444
+
445
  def write_jsonl(samples, path):
446
  with open(path, 'w') as f:
447
  for s in samples:
448
  json.dump(s, f)
449
  f.write('\n')
450
+
451
  write_jsonl(all_samples[:split_idx], output_dir / "train.jsonl")
452
  write_jsonl(all_samples[split_idx:], output_dir / "test.jsonl")
453
+
454
  # Mark as completed
455
  state.clue_progress = clue_progress
456
  state.seen_grids = list(seen_grids)
457
  state.all_samples = all_samples
458
  state.completed = True
459
  save_checkpoint(output_dir, state, params)
460
+
461
  print(f"\n✓ Dataset complete: {output_dir}/")
462
  print(f" Resolution: {actual_size}x{actual_size}")
463
  print(f" Solution range: {sol_range}")
464
  print(f" Puzzles: {len(seen_grids)} ({stats_unique} unique, {stats_multi} multi-sol)")
465
  print(f" Videos: {len(all_samples)}")
466
  print(f" Train: {split_idx}, Test: {len(all_samples) - split_idx}")
467
+
468
+ fcounts = [s["frame_count"] for s in all_samples]
469
+ print(f" Frame counts: avg={np.mean(fcounts):.1f}, "
470
+ f"min={min(fcounts)}, max={max(fcounts)}")
471
+
472
  hammings = [s["min_hamming_to_others"] for s in all_samples if s["min_hamming_to_others"] > 0]
473
  if hammings:
474
+ print(f" Solution diversity: avg={np.mean(hammings):.1f}, "
475
+ f"min={min(hammings)}, max={max(hammings)}")
476
 
477
 
478
def parse_resolution(s):
    """Parse a 'WxH' string (case-insensitive, e.g. '1024x768') into (width, height)."""
    parts = s.lower().split('x')
    width, height = (int(p) for p in parts)
    return (width, height)
481
 
482
+
483
  def parse_args():
484
+ p = argparse.ArgumentParser(
485
+ description="Generate Sudoku video dataset with resume support"
486
+ )
487
  p.add_argument("--output-dir", type=str, default="sudoku")
488
+ p.add_argument("--clue-levels", type=int, nargs="+",
489
+ default=[20, 30, 40, 50, 60, 70])
490
+ p.add_argument("--num-per-clue", type=int, nargs="+",
491
+ default=[15000, 10000, 10000, 5000, 2000, 1000])
492
  p.add_argument("--sol-num", type=str, default="<=3",
493
  help="'1', '3', '>=1', '>1', '<=3', '<3', '2-5'")
494
  p.add_argument("--min-hamming", type=int, default=10)
495
  p.add_argument("--train-ratio", type=float, default=0.9)
496
+ p.add_argument("--prompt", type=str,
497
+ default="Solve this Sudoku puzzle using red font.")
498
+ p.add_argument("--n-start", type=int, default=2,
499
+ help="Hold frames for puzzle at video start")
500
+ p.add_argument("--m-end", type=int, default=3,
501
+ help="Hold frames for completed solution at video end")
502
+ p.add_argument("--frames", type=int, default=None,
503
+ help="Content frames per video. None=1 per fill (auto). "
504
+ "If > fills: slow-motion. If < fills: fast-forward.")
505
  p.add_argument("--fps", type=int, default=10)
506
  p.add_argument("--resolution", type=str, default="1024x1024")
507
  p.add_argument("--seed", type=int, default=42)
sudoku/jsonl_to_csv.py CHANGED
@@ -2,11 +2,11 @@ import json
2
  import csv
3
  from pathlib import Path
4
 
5
- dataset='sudoku'
6
- split='train'
7
 
8
  # Load test data
9
- with open(f'{dataset}/{split}_info.jsonl', 'r') as f:
10
  data = [json.loads(line) for line in f]
11
 
12
  # Write to CSV
@@ -19,4 +19,7 @@ with open(f'{dataset}/{split}.csv', 'w', newline='', encoding='utf-8') as f:
19
  'images/' + item['image'],
20
  'videos/' + item['video'],
21
  item['prompt'],
22
- ])
 
 
 
 
2
  import csv
3
  from pathlib import Path
4
 
5
+ dataset='sudoku_large'
6
+ split='test'
7
 
8
  # Load test data
9
+ with open(f'{dataset}/{split}.jsonl', 'r') as f:
10
  data = [json.loads(line) for line in f]
11
 
12
  # Write to CSV
 
19
  'images/' + item['image'],
20
  'videos/' + item['video'],
21
  item['prompt'],
22
+ ])
23
+
24
+ # Rename `{split}.jsonl' to `{split}_info.jsonl`
25
+ Path(f'{dataset}/{split}.jsonl').rename(Path(f'{dataset}/{split}_info.jsonl'))