Add files using upload-large-folder tool
Browse files- frozenlake/data_process.py +480 -0
- frozenlake/frozenlake_processor.py +466 -0
- maze/data_process.py +651 -0
- maze/maze_processor.py +543 -0
- sudoku/generate_dataset.py +198 -102
- sudoku/jsonl_to_csv.py +7 -4
frozenlake/data_process.py
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FrozenLake Video Dataset Generator — generate, eval, verify.
|
| 3 |
+
|
| 4 |
+
Uses plain BFS solver (not networkx) for fast generation at all grid sizes.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python frozenlake_video_gen.py generate --output-dir frozenlake \
|
| 8 |
+
--sizes 8 16 32 --num-per-size 100 500 1000 --p 0.8
|
| 9 |
+
python frozenlake_video_gen.py eval result_videos/ --table-dir frozenlake/tables
|
| 10 |
+
python frozenlake_video_gen.py verify results.json --table-dir frozenlake/tables
|
| 11 |
+
"""
|
| 12 |
+
import json
|
| 13 |
+
import csv
|
| 14 |
+
import hashlib
|
| 15 |
+
import random
|
| 16 |
+
import re
|
| 17 |
+
import argparse
|
| 18 |
+
from dataclasses import dataclass, asdict
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Dict, List, Optional
|
| 21 |
+
|
| 22 |
+
import cv2
|
| 23 |
+
import numpy as np
|
| 24 |
+
from tqdm import tqdm
|
| 25 |
+
|
| 26 |
+
from frozenlake_processor import FrozenLakeProcessor
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# ==================== Checkpoint ====================
|
| 30 |
+
|
| 31 |
+
@dataclass
class GenerationState:
    """Resumable progress of a dataset generation run.

    The state is serialised into ``metadata.json`` via :meth:`to_dict` and
    rebuilt with :meth:`from_dict`.
    """

    # Short hash of the generation parameters the state was created with.
    params_hash: str
    # Number of puzzles already generated, keyed by grid size.
    size_progress: Dict[int, int]
    # Fingerprints of every layout accepted so far (used for de-duplication).
    seen_fingerprints: List[str]
    # Metadata records of the generated samples.
    all_samples: List[Dict]
    # True once the final train/test files have been written.
    completed: bool = False

    def to_dict(self) -> Dict:
        """Return the state as a plain dict suitable for ``json.dump``."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: Dict) -> "GenerationState":
        """Rebuild a state from a dict produced by :meth:`to_dict`.

        JSON object keys are always strings, so a state that went through
        ``json.dump``/``json.load`` arrives with ``"8"`` instead of ``8`` as
        ``size_progress`` keys.  They are converted back to ``int`` here so
        the field matches its annotation.  The input dict is not mutated.
        """
        fields = dict(d)
        if "size_progress" in fields:
            fields["size_progress"] = {
                int(size): count for size, count in fields["size_progress"].items()
            }
        return cls(**fields)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _params_hash(params: Dict) -> str:
|
| 48 |
+
key = {k: v for k, v in params.items() if k != "output_dir"}
|
| 49 |
+
return hashlib.md5(json.dumps(key, sort_keys=True).encode()).hexdigest()[:12]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def load_checkpoint(output_dir: Path, params: Dict) -> Optional[GenerationState]:
    """Load the saved generation state from ``output_dir/metadata.json``.

    Returns:
        The stored state when the file exists and was written with the same
        parameters (compared through ``_params_hash``); ``None`` when there is
        no checkpoint or the parameters changed.  A completed state is
        returned as well, so callers must inspect ``state.completed``.
    """
    meta_path = output_dir / "metadata.json"
    if not meta_path.exists():
        return None

    with open(meta_path) as handle:
        saved = json.load(handle)["state"]
    state = GenerationState.from_dict(saved)

    current_hash = _params_hash(params)
    if state.params_hash == current_hash:
        if state.completed:
            message = "✓ Generation already completed"
        else:
            message = f"✓ Resuming: {sum(state.size_progress.values())} puzzles done"
        print(message)
        return state

    # A different parameter set invalidates the checkpoint entirely.
    print(f"⚠️ Params changed ({state.params_hash} → {current_hash}), starting fresh")
    return None
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def save_checkpoint(output_dir: Path, state: GenerationState, params: Dict):
    """Write *params* and *state* to ``output_dir/metadata.json``.

    The JSON is first written to ``metadata.tmp`` and then moved over the real
    file, so an interrupted write does not leave a truncated checkpoint.

    ``Path.replace`` is used instead of ``Path.rename``: ``rename`` raises
    ``FileExistsError`` on Windows when the destination already exists, which
    is the case for every checkpoint after the first one.
    """
    meta = output_dir / "metadata.json"
    tmp = meta.with_suffix(".tmp")
    with open(tmp, "w") as f:
        json.dump({"params": params, "state": state.to_dict()}, f, indent=2)
    tmp.replace(meta)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# ==================== Video I/O ====================
|
| 79 |
+
|
| 80 |
+
def save_video_cv2(frames: list, path: str, fps: int = 10):
    """Encode *frames* as an ``mp4v`` video at *path*.

    Args:
        frames: Non-empty sequence of RGB images; each one is converted with
            ``np.array`` and then RGB -> BGR before being written.  The video
            size is taken from the first frame.
        path: Destination file.
        fps: Frames per second of the output video.

    Raises:
        IndexError: If *frames* is empty.
        OSError: If OpenCV could not open the writer (previously this failed
            silently and no video was produced).
    """
    first = np.array(frames[0])
    h, w = first.shape[:2]
    writer = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
    try:
        if not writer.isOpened():
            raise OSError(f"Could not open video writer for {path}")
        for frame in frames:
            writer.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
    finally:
        # Always release so the file handle is not leaked when a frame fails.
        writer.release()
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def extract_last_frame(video_path: str) -> Optional[np.ndarray]:
    """Return the last frame of a video as an RGB array, or ``None``.

    The fast path seeks to ``frame_count - 1`` and reads one frame.  When the
    container reports no frame count, or the seek/read fails, the video is
    reopened and decoded sequentially and the last successfully decoded frame
    is kept.  (Previously an unknown frame count returned the *first* frame
    and a failed seek returned ``None``.)

    Returns:
        The last frame converted BGR -> RGB, or ``None`` when the file cannot
        be opened or contains no decodable frame.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        return None

    frame = None
    try:
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total > 0:
            cap.set(cv2.CAP_PROP_POS_FRAMES, total - 1)
            ret, candidate = cap.read()
            if ret and candidate is not None:
                frame = candidate
    finally:
        cap.release()

    if frame is None:
        # Fallback: decode from the start and keep the last good frame.
        cap = cv2.VideoCapture(str(video_path))
        try:
            while True:
                ret, candidate = cap.read()
                if not ret or candidate is None:
                    break
                frame = candidate
        finally:
            cap.release()

    if frame is None:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# ==================== Helpers ====================
|
| 104 |
+
|
| 105 |
+
def _normalise_list(val, sizes, name="parameter"):
|
| 106 |
+
if isinstance(val, int):
|
| 107 |
+
return [val] * len(sizes)
|
| 108 |
+
if len(val) != len(sizes):
|
| 109 |
+
raise ValueError(f"{name} length ({len(val)}) != sizes ({len(sizes)})")
|
| 110 |
+
return list(val)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# ==================== Generate ====================
|
| 114 |
+
|
| 115 |
+
def generate_dataset(
    output_dir: str = "frozenlake",
    sizes: Optional[List[int]] = None,
    num_per_size: Optional[list] = None,
    p: float = 0.8,
    min_path_ratio: float = 0.3,
    img_size: int = 512,
    prompt: str = "Draw a continuous red line connecting the Start point to the Goal point, avoiding all holes.",
    train_ratio: float = 0.9,
    n_start: int = 2,
    m_end: int = 3,
    frames: Optional[int] = None,
    fps: int = 10,
    seed: int = 42,
    use_gym: bool = True,
    checkpoint_interval: int = 50,
):
    """
    Generate FrozenLake video dataset with checkpoint/resume.

    Layout::

        output_dir/
            images/ videos/ tables/
            train.jsonl test.jsonl train.csv test.csv
            path.json metadata.json

    Args:
        output_dir: Root directory of the dataset (created if missing).
        sizes: Grid sizes to generate; defaults to ``[8, 16, 32]``.
        num_per_size: Puzzles per grid size; defaults to ``[100, 500, 1000]``.
            A single-element list is applied to every size.
        p: Forwarded to ``FrozenLakeProcessor.generate``.
        min_path_ratio: Minimum solution length as a fraction of ``size²``
            (at least 1 move).
        img_size: Passed to the ``FrozenLakeProcessor`` constructor.
        prompt: Text stored with every sample.
        train_ratio: Fraction of the shuffled samples written to the train split.
        n_start, m_end, frames: Forwarded to ``generate_video_frames``.
        fps: Frame rate of the saved videos.
        seed: Seed for the ``random`` module (``seed + 1`` for the split shuffle).
        use_gym: Forwarded to the renderer.
        checkpoint_interval: Save ``metadata.json`` every this many puzzles.

    Raises:
        ValueError: If ``num_per_size`` and ``sizes`` differ in length.
    """
    # None sentinels instead of mutable list defaults.
    if sizes is None:
        sizes = [8, 16, 32]
    if num_per_size is None:
        num_per_size = [100, 500, 1000]

    params = {
        "sizes": sizes, "num_per_size": num_per_size,
        "p": p, "min_path_ratio": min_path_ratio, "img_size": img_size,
        "prompt": prompt, "train_ratio": train_ratio,
        "n_start": n_start, "m_end": m_end, "frames": frames,
        "fps": fps, "seed": seed, "use_gym": use_gym,
    }

    out = Path(output_dir)
    img_dir, vid_dir, tbl_dir = out / "images", out / "videos", out / "tables"
    for d in (img_dir, vid_dir, tbl_dir):
        d.mkdir(parents=True, exist_ok=True)

    state = load_checkpoint(out, params)
    if state and state.completed:
        return

    num_list = _normalise_list(
        num_per_size[0] if len(num_per_size) == 1 else num_per_size,
        sizes, "num_per_size",
    )
    num_w = len(str(max(num_list)))  # zero-padding width of the sample index
    proc = FrozenLakeProcessor(img_size=img_size)

    if state is None:
        random.seed(seed)
        state = GenerationState(
            params_hash=_params_hash(params),
            size_progress={sz: 0 for sz in sizes},
            seen_fingerprints=[], all_samples=[],
        )
        print(f"Fresh generation: sizes={sizes}, counts={num_list}, p={p}")
        print(f" frames={'auto' if frames is None else frames}, "
              f"n_start={n_start}, m_end={m_end}, fps={fps}")
    else:
        # NOTE(review): this only advances the RNG by a fixed 10 draws per
        # finished puzzle; it presumably does not reproduce the exact RNG state
        # of the interrupted run.  Duplicates are still rejected via `seen`.
        random.seed(seed)
        for _ in range(sum(state.size_progress.values()) * 10):
            random.random()

    seen = set(state.seen_fingerprints)
    all_samples = list(state.all_samples)
    # Keys come back as strings after a JSON round trip.
    progress = {int(k): v for k, v in state.size_progress.items()}
    since_ckpt = 0
    total_target = sum(num_list)

    with tqdm(total=total_target, initial=sum(progress.values()),
              desc="Total", unit="puzzle") as pbar:
        for grid_size, target in zip(sizes, num_list):
            generated = progress.get(grid_size, 0)
            if generated >= target:
                continue

            min_len = max(1, int(grid_size * grid_size * min_path_ratio))

            with tqdm(total=target, initial=generated,
                      desc=f"Size {grid_size:3d}", unit="puzzle", leave=False) as pbar_sz:
                # Attempt budget: 20 tries per missing puzzle.
                for _ in range((target - generated) * 20):
                    if generated >= target:
                        break
                    try:
                        desc, path = proc.generate(grid_size, p=p, min_path_len=min_len)
                    except RuntimeError:
                        continue

                    fp = proc.fingerprint(desc)
                    if fp in seen:
                        continue
                    seen.add(fp)

                    base = f"size{grid_size}_{generated:0{num_w}d}"
                    img_name, vid_name, tbl_name = f"{base}.png", f"{base}.mp4", f"{base}.txt"

                    proc.render(desc, use_gym=use_gym).save(str(img_dir / img_name))

                    vid_frames = proc.generate_video_frames(
                        desc, path, n_start=n_start, m_end=m_end,
                        frames=frames, use_gym=use_gym,
                    )
                    save_video_cv2(vid_frames, str(vid_dir / vid_name), fps=fps)
                    proc.save_table(str(tbl_dir / tbl_name), desc)

                    udrl = proc.path_to_udrl(path)
                    all_samples.append({
                        "prompt": prompt, "image": img_name, "video": vid_name,
                        "table": tbl_name, "grid_size": grid_size,
                        "grid_desc": desc, "start": list(proc.find_start(desc)),
                        "path_udrl": udrl, "path_length": len(path),
                        "frame_count": len(vid_frames),
                    })

                    generated += 1
                    progress[grid_size] = generated
                    since_ckpt += 1
                    pbar_sz.update(1)
                    pbar.update(1)

                    if since_ckpt >= checkpoint_interval:
                        state.size_progress = progress
                        state.seen_fingerprints = list(seen)
                        state.all_samples = all_samples
                        save_checkpoint(out, state, params)
                        since_ckpt = 0

            if generated < target:
                # The attempt budget ran out; make the shortfall visible.
                tqdm.write(f"⚠️ Size {grid_size}: only {generated}/{target} puzzles "
                           f"generated (attempt budget exhausted)")
            tqdm.write(f"Size {grid_size}: {generated} puzzles")

    # --- Final outputs ---
    with open(out / "path.json", "w") as f:
        json.dump(
            dict(sorted((s["image"], s["path_udrl"]) for s in all_samples)),
            f, indent=4,
        )

    random.seed(seed + 1)
    random.shuffle(all_samples)
    split = int(len(all_samples) * train_ratio)

    def _jsonl(samples, path):
        with open(path, "w") as f:
            for s in samples:
                f.write(json.dumps(s) + "\n")

    _jsonl(all_samples[:split], out / "train.jsonl")
    _jsonl(all_samples[split:], out / "test.jsonl")

    for name, samps in [("train", all_samples[:split]), ("test", all_samples[split:])]:
        with open(out / f"{name}.csv", "w", newline="", encoding="utf-8") as f:
            w = csv.writer(f)
            w.writerow(["input_image", "video", "prompt"])
            for s in samps:
                w.writerow([f"images/{s['image']}", f"videos/{s['video']}", s["prompt"]])

    state.size_progress = progress
    state.seen_fingerprints = list(seen)
    state.all_samples = all_samples
    state.completed = True
    save_checkpoint(out, state, params)

    print(f"\n✓ Dataset complete: {out}/")
    print(f" Sizes: {sizes}, p={p}, Puzzles: {len(all_samples)}")
    print(f" Train: {split}, Test: {len(all_samples) - split}")
    if all_samples:
        lengths = [s["path_length"] for s in all_samples]
        fcounts = [s["frame_count"] for s in all_samples]
        print(f" Path lengths: avg={np.mean(lengths):.1f}, min={min(lengths)}, max={max(lengths)}")
        print(f" Frame counts: avg={np.mean(fcounts):.1f}, min={min(fcounts)}, max={max(fcounts)}")
    else:
        # min()/max() of an empty list would raise; report the situation instead.
        print(" No puzzles were generated; consider lowering min_path_ratio.")
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# ==================== Eval ====================
|
| 289 |
+
|
| 290 |
+
def eval_videos(
    video_dir: str,
    table_dir: str,
    output_json: Optional[str] = None,
    gt_json: Optional[str] = None,
    use_gym: bool = True,
):
    """Evaluate result videos: last frame → red path → verify.

    For every ``*.mp4`` in *video_dir* (natural-sorted by stem) the table
    ``<stem>.txt`` is loaded from *table_dir*, the path is extracted from the
    last frame and stored under ``<stem>.png``.  The extracted paths are
    written to *output_json* and then verified against their tables.

    Args:
        video_dir: Directory containing the result videos.
        table_dir: Directory containing the puzzle tables.
        output_json: Destination for the extracted paths; defaults to
            ``<video_dir>/0_result.json``.
        gt_json: Optional ground-truth JSON; enables the per-path-length
            breakdown.
        use_gym: Verify with ``verify_path_gym`` when True, else
            ``verify_path_sim``.
    """
    proc = FrozenLakeProcessor()
    vid_root, tbl_root = Path(video_dir), Path(table_dir)
    if output_json is None:
        output_json = str(vid_root / "0_result.json")

    videos = sorted(
        vid_root.glob("*.mp4"),
        # Natural sort: digit runs compare numerically ("size8_2" < "size8_10").
        key=lambda p: [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", p.stem)],
    )
    if not videos:
        print(f"No .mp4 in {vid_root}")
        return

    print(f"Found {len(videos)} videos, table_dir={tbl_root}")

    extracted: Dict[str, str] = {}
    # Parsed tables of the extracted entries, kept so the verification pass
    # does not have to read every table from disk a second time.
    tables: Dict[str, List[str]] = {}
    missing_tbl = missing_frame = 0

    for vp in tqdm(videos, desc="Extracting"):
        stem = vp.stem
        desc = proc.load_table(str(tbl_root / f"{stem}.txt"))
        if desc is None:
            missing_tbl += 1
            continue
        start = proc.find_start(desc)
        if start is None:
            # A table without a start cell is counted as missing.
            missing_tbl += 1
            continue
        lf = extract_last_frame(str(vp))
        if lf is None:
            missing_frame += 1
            continue
        name = f"{stem}.png"
        extracted[name] = proc.extract_path_from_pixels(
            lf, len(desc), len(desc[0]), start, desc
        )
        tables[name] = desc

    with open(output_json, "w") as f:
        json.dump(extracted, f, indent=4)
    print(f"Saved: {output_json}")

    # Verify
    correct = total_valid = 0
    correctly_solved: List[Dict] = []
    size_stats: Dict[int, Dict[str, int]] = {}

    verify_fn = proc.verify_path_gym if use_gym else proc.verify_path_sim

    for name, udrl in extracted.items():
        desc = tables[name]
        total_valid += 1
        sz = len(desc)
        size_stats.setdefault(sz, {"total": 0, "correct": 0})
        size_stats[sz]["total"] += 1
        if verify_fn(desc, udrl):
            correct += 1
            size_stats[sz]["correct"] += 1
            correctly_solved.append({"name": name, "length": len(udrl)})

    acc = correct / total_valid * 100 if total_valid else 0
    print(f"\n{'='*50}\nEvaluation Summary\n{'='*50}")
    print(f"Videos: {len(videos)}, Missing tables: {missing_tbl}, "
          f"Failed frames: {missing_frame}")
    print(f"Evaluated: {total_valid}, Correct: {correct}, Accuracy: {acc:.2f}%")

    if size_stats:
        print("\nBy size:")
        for sz in sorted(size_stats):
            s = size_stats[sz]
            print(f" {sz:3d}: {s['correct']}/{s['total']} "
                  f"({s['correct']/s['total']*100:.1f}%)")

    # Show the three longest correctly solved paths.
    correctly_solved.sort(key=lambda x: x["length"], reverse=True)
    for i, item in enumerate(correctly_solved[:3]):
        print(f" Top {i+1}: {item['name']} (len={item['length']})")

    if gt_json:
        _gt_bins(extracted, gt_json, tbl_root, proc, verify_fn)
    print(f"{'='*50}")
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def _gt_bins(extracted, gt_path, tbl_root, proc, verify_fn):
|
| 381 |
+
try:
|
| 382 |
+
with open(gt_path) as f:
|
| 383 |
+
gt = json.load(f)
|
| 384 |
+
except Exception:
|
| 385 |
+
return
|
| 386 |
+
bins: Dict[str, Dict[str, int]] = {}
|
| 387 |
+
for name, pred in extracted.items():
|
| 388 |
+
if name not in gt:
|
| 389 |
+
continue
|
| 390 |
+
lo = (len(gt[name]) // 10) * 10
|
| 391 |
+
label = f"{lo:3d}-{lo+9:3d}"
|
| 392 |
+
bins.setdefault(label, {"total": 0, "correct": 0})
|
| 393 |
+
bins[label]["total"] += 1
|
| 394 |
+
desc = proc.load_table(str(tbl_root / f"{name.replace('.png','')}.txt"))
|
| 395 |
+
if desc and verify_fn(desc, pred):
|
| 396 |
+
bins[label]["correct"] += 1
|
| 397 |
+
if bins:
|
| 398 |
+
print("\nBy GT path length:")
|
| 399 |
+
for label in sorted(bins):
|
| 400 |
+
b = bins[label]
|
| 401 |
+
print(f" {label}: {b['correct']}/{b['total']} "
|
| 402 |
+
f"({b['correct']/b['total']*100:.1f}%)")
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
# ==================== Verify ====================
|
| 406 |
+
|
| 407 |
+
def verify_results(json_file: str, table_dir: str, use_gym: bool = True):
    """Verify a JSON file of ``image name -> UDRL`` solutions and print a summary.

    Each solution is checked against the table ``<stem>.txt`` in *table_dir*;
    entries whose table cannot be loaded are counted as skipped and do not
    enter the accuracy.
    """
    proc = FrozenLakeProcessor()
    with open(json_file) as handle:
        solutions = json.load(handle)
    checker = proc.verify_path_gym if use_gym else proc.verify_path_sim

    tables = Path(table_dir)
    outcomes = []  # one bool per solution that had a loadable table
    skipped = 0
    for name, udrl in solutions.items():
        desc = proc.load_table(str(tables / f"{name.replace('.png','')}.txt"))
        if desc is None:
            skipped += 1
        else:
            outcomes.append(bool(checker(desc, udrl)))

    valid = len(outcomes)
    correct = sum(outcomes)
    acc = correct / valid * 100 if valid else 0
    print(f"\n{'='*40}\nVerification: {correct}/{valid} ({acc:.2f}%)")
    if skipped:
        print(f"Skipped: {skipped}")
    print(f"{'='*40}")
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
# ==================== CLI ====================
|
| 429 |
+
|
| 430 |
+
def parse_args():
    """Build the ``generate`` / ``eval`` / ``verify`` CLI and parse ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``; ``command`` is ``None`` when no
        sub-command was given.
    """
    parser = argparse.ArgumentParser(description="FrozenLake video dataset")
    commands = parser.add_subparsers(dest="command")

    # (flag, add_argument keyword options) in declaration order.
    generate_options = (
        ("--output-dir", {"default": "frozenlake"}),
        ("--sizes", {"type": int, "nargs": "+", "default": [8, 16, 32]}),
        ("--num-per-size", {"type": int, "nargs": "+", "default": [100, 500, 1000]}),
        ("--p", {"type": float, "default": 0.8}),
        ("--min-path-ratio", {
            "type": float, "default": 0.1,
            "help": "Min path length as fraction of size² (default 0.1; "
                    "FrozenLake paths are much shorter than maze paths)",
        }),
        ("--img-size", {"type": int, "default": 1024}),
        ("--prompt", {"default": "Draw a continuous red line connecting the Start point to the Goal point, avoiding all holes."}),
        ("--train-ratio", {"type": float, "default": 0.9}),
        ("--n-start", {"type": int, "default": 2}),
        ("--m-end", {"type": int, "default": 3}),
        ("--frames", {"type": int, "default": None}),
        ("--fps", {"type": int, "default": 10}),
        ("--seed", {"type": int, "default": 42}),
        ("--no-gym", {"action": "store_true"}),
        ("--checkpoint-interval", {"type": int, "default": 50}),
    )
    eval_options = (
        ("video_dir", {}),
        ("--table-dir", {"required": True}),
        ("--output-json", {"default": None}),
        ("--gt-json", {"default": None}),
        ("--no-gym", {"action": "store_true"}),
    )
    verify_options = (
        ("json_file", {}),
        ("--table-dir", {"required": True}),
        ("--no-gym", {"action": "store_true"}),
    )

    for command, options in (
        ("generate", generate_options),
        ("eval", eval_options),
        ("verify", verify_options),
    ):
        sub_parser = commands.add_parser(command)
        for flag, settings in options:
            sub_parser.add_argument(flag, **settings)

    return parser.parse_args()
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
if __name__ == "__main__":
    # Dispatch the parsed sub-command to its implementation.
    cli = parse_args()
    if cli.command == "verify":
        verify_results(cli.json_file, cli.table_dir, not cli.no_gym)
    elif cli.command == "eval":
        eval_videos(
            cli.video_dir,
            cli.table_dir,
            cli.output_json,
            cli.gt_json,
            not cli.no_gym,
        )
    elif cli.command == "generate":
        # Every remaining CLI option maps 1:1 to a generate_dataset keyword;
        # --no-gym is inverted into use_gym.
        options = {
            key: value
            for key, value in vars(cli).items()
            if key != "command" and key != "no_gym"
        }
        generate_dataset(use_gym=not cli.no_gym, **options)
    else:
        print("Usage: python frozenlake_video_gen.py {generate|eval|verify} ...")
|
frozenlake/frozenlake_processor.py
ADDED
|
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FrozenLakeProcessor - FrozenLake puzzle generation, solving, rendering, and evaluation.
|
| 3 |
+
|
| 4 |
+
Grid cells: S=Start, F=Frozen(safe), H=Hole(death), G=Goal
|
| 5 |
+
Table chars: @=Start, _=Frozen, #=Hole, *=Goal
|
| 6 |
+
|
| 7 |
+
Performance notes vs original DiffThinker code:
|
| 8 |
+
- Solving uses plain BFS (O(n²)) instead of networkx graph construction
|
| 9 |
+
which had massive overhead from add_node/add_edge Python calls.
|
| 10 |
+
- Gym renderer is cached per puzzle to avoid repeated pygame init.
|
| 11 |
+
"""
|
| 12 |
+
import os
|
| 13 |
+
import random
|
| 14 |
+
import warnings
|
| 15 |
+
from collections import deque
|
| 16 |
+
from typing import List, Tuple, Optional
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
from PIL import Image, ImageDraw
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
os.environ.setdefault("SDL_AUDIODRIVER", "dummy")
|
| 23 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="pygame")
|
| 24 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 25 |
+
import gymnasium as gym
|
| 26 |
+
|
| 27 |
+
HAS_GYM = True
|
| 28 |
+
except ImportError:
|
| 29 |
+
HAS_GYM = False
|
| 30 |
+
|
| 31 |
+
# Table ↔ Grid mapping
|
| 32 |
+
TABLE_TO_GRID = {"@": "S", "_": "F", "#": "H", "*": "G"}
|
| 33 |
+
GRID_TO_TABLE = {v: k for k, v in TABLE_TO_GRID.items()}
|
| 34 |
+
|
| 35 |
+
MOVES = {"U": (-1, 0), "D": (1, 0), "L": (0, -1), "R": (0, 1)}
|
| 36 |
+
GYM_ACTION_MAP = {"L": 0, "D": 1, "R": 2, "U": 3}
|
| 37 |
+
|
| 38 |
+
GridDesc = List[str]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class FrozenLakeProcessor:
|
| 42 |
+
"""FrozenLake generation, BFS solving, rendering, and evaluation."""
|
| 43 |
+
|
| 44 |
+
def __init__(self, img_size: int = 512):
|
| 45 |
+
self.img_size = img_size
|
| 46 |
+
self.path_color = "red"
|
| 47 |
+
|
| 48 |
+
# ==================== Generation ====================
|
| 49 |
+
|
| 50 |
+
def generate(
|
| 51 |
+
self,
|
| 52 |
+
size: int,
|
| 53 |
+
p: float = 0.8,
|
| 54 |
+
min_path_len: int = 1,
|
| 55 |
+
max_attempts: int = 500,
|
| 56 |
+
) -> Tuple[GridDesc, List[Tuple[int, int]]]:
|
| 57 |
+
"""
|
| 58 |
+
Generate a solvable FrozenLake grid with shortest path >= *min_path_len* moves.
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
(desc, path) — desc is list[str], path is list[(r,c)].
|
| 62 |
+
"""
|
| 63 |
+
for _ in range(max_attempts):
|
| 64 |
+
desc = self._random_layout(size, p)
|
| 65 |
+
path = self.solve(desc)
|
| 66 |
+
if path is not None and (len(path) - 1) >= min_path_len:
|
| 67 |
+
return desc, path
|
| 68 |
+
raise RuntimeError(
|
| 69 |
+
f"Failed after {max_attempts} attempts "
|
| 70 |
+
f"(size={size}, p={p}, min_path_len={min_path_len})."
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
@staticmethod
|
| 74 |
+
def _random_layout(size: int, p: float = 0.8) -> GridDesc:
|
| 75 |
+
"""Random grid with one S and one G at random positions."""
|
| 76 |
+
all_coords = [(r, c) for r in range(size) for c in range(size)]
|
| 77 |
+
start, goal = random.sample(all_coords, 2)
|
| 78 |
+
grid = []
|
| 79 |
+
for r in range(size):
|
| 80 |
+
row = []
|
| 81 |
+
for c in range(size):
|
| 82 |
+
if (r, c) == start:
|
| 83 |
+
row.append("S")
|
| 84 |
+
elif (r, c) == goal:
|
| 85 |
+
row.append("G")
|
| 86 |
+
else:
|
| 87 |
+
row.append("F" if random.random() < p else "H")
|
| 88 |
+
grid.append("".join(row))
|
| 89 |
+
return grid
|
| 90 |
+
|
| 91 |
+
# ==================== Solving (plain BFS — fast) ====================
|
| 92 |
+
|
| 93 |
+
@staticmethod
|
| 94 |
+
def solve(desc: GridDesc) -> Optional[List[Tuple[int, int]]]:
|
| 95 |
+
"""
|
| 96 |
+
BFS shortest path from S to G, avoiding H.
|
| 97 |
+
|
| 98 |
+
~100× faster than networkx for typical grid sizes because it avoids
|
| 99 |
+
Python-level graph object construction entirely.
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
List of (r, c) or None.
|
| 103 |
+
"""
|
| 104 |
+
rows, cols = len(desc), len(desc[0])
|
| 105 |
+
start = goal = None
|
| 106 |
+
for r in range(rows):
|
| 107 |
+
for c in range(cols):
|
| 108 |
+
if desc[r][c] == "S":
|
| 109 |
+
start = (r, c)
|
| 110 |
+
elif desc[r][c] == "G":
|
| 111 |
+
goal = (r, c)
|
| 112 |
+
if start is None or goal is None:
|
| 113 |
+
return None
|
| 114 |
+
|
| 115 |
+
visited = [[False] * cols for _ in range(rows)]
|
| 116 |
+
visited[start[0]][start[1]] = True
|
| 117 |
+
queue: deque = deque([(start, [start])])
|
| 118 |
+
|
| 119 |
+
while queue:
|
| 120 |
+
(r, c), path = queue.popleft()
|
| 121 |
+
if (r, c) == goal:
|
| 122 |
+
return path
|
| 123 |
+
for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
|
| 124 |
+
nr, nc = r + dr, c + dc
|
| 125 |
+
if 0 <= nr < rows and 0 <= nc < cols and not visited[nr][nc]:
|
| 126 |
+
ch = desc[nr][nc]
|
| 127 |
+
if ch != "H":
|
| 128 |
+
visited[nr][nc] = True
|
| 129 |
+
queue.append(((nr, nc), path + [(nr, nc)]))
|
| 130 |
+
return None
|
| 131 |
+
|
| 132 |
+
# ==================== Path ↔ UDRL ====================
|
| 133 |
+
|
| 134 |
+
@staticmethod
|
| 135 |
+
def path_to_udrl(path: List[Tuple[int, int]]) -> str:
|
| 136 |
+
"""Convert coordinate path to UDRL string."""
|
| 137 |
+
moves = []
|
| 138 |
+
for i in range(len(path) - 1):
|
| 139 |
+
r1, c1 = path[i]
|
| 140 |
+
r2, c2 = path[i + 1]
|
| 141 |
+
if r2 < r1:
|
| 142 |
+
moves.append("U")
|
| 143 |
+
elif r2 > r1:
|
| 144 |
+
moves.append("D")
|
| 145 |
+
elif c2 < c1:
|
| 146 |
+
moves.append("L")
|
| 147 |
+
else:
|
| 148 |
+
moves.append("R")
|
| 149 |
+
return "".join(moves)
|
| 150 |
+
|
| 151 |
+
# ==================== Verification ====================
|
| 152 |
+
|
| 153 |
+
def verify_path_sim(self, desc: GridDesc, udrl: str) -> bool:
    """Verify a UDRL answer by simulating it on the grid (no dependencies).

    The walk begins at the ``S`` cell. It fails as soon as a move leaves
    the grid or lands on an ``H`` cell, and succeeds as soon as a move
    lands on ``G``.

    Args:
        desc: Grid rows; each cell is one character.
        udrl: Move string. Commas and spaces are ignored, as is any
            character that is not a key of ``MOVES``. When the text
            contains ``"Action plan"``, only what follows its last
            occurrence is simulated.

    Returns:
        True if the walk reaches ``G``; False otherwise (including when
        the grid is empty or has no ``S`` cell).
    """
    if not desc or not desc[0]:
        return False
    rows, cols = len(desc), len(desc[0])
    start = self.find_start(desc)
    if start is None:
        return False

    # Cut at the marker BEFORE deleting spaces: the marker itself contains
    # a space, so searching for it in the space-stripped text never matches
    # and letters from the preamble would be executed as moves.
    answer = udrl.rsplit("Action plan", 1)[-1]
    clean = answer.replace(",", "").replace(" ", "").strip()

    r, c = start
    for ch in clean:
        if ch not in MOVES:
            continue
        dr, dc = MOVES[ch]
        nr, nc = r + dr, c + dc
        if not (0 <= nr < rows and 0 <= nc < cols):
            return False
        cell = desc[nr][nc]
        if cell == "H":
            return False
        r, c = nr, nc
        if cell == "G":
            return True
    return desc[r][c] == "G"
|
| 179 |
+
|
| 180 |
+
def verify_path_gym(self, desc: GridDesc, udrl: str) -> bool:
    """Verify a UDRL answer via gymnasium (falls back to sim if unavailable).

    A non-slippery ``FrozenLake-v1`` environment is built from *desc* and
    the moves are stepped through it. The result is decided at the first
    step that terminates or truncates the episode: success means that
    step's reward was positive. If the moves run out first, the answer is
    rejected.

    Args:
        desc: Grid rows; each cell is one character.
        udrl: Move string. Commas and spaces are ignored, as is any
            character that is not a key of ``GYM_ACTION_MAP``. When the
            text contains ``"Action plan"``, only what follows its last
            occurrence is executed.

    Returns:
        True if the episode ended with a positive reward. Any exception
        raised while using gymnasium delegates to ``verify_path_sim``.
    """
    if not HAS_GYM:
        return self.verify_path_sim(desc, udrl)
    rows, cols = len(desc), len(desc[0])
    try:
        env = gym.make(
            "FrozenLake-v1", desc=desc,
            map_name=f"{rows}x{cols}", is_slippery=False, render_mode=None,
        )
        try:
            env.reset(seed=42)
            success = False
            # Cut at the marker BEFORE deleting spaces: the marker contains
            # a space, so it can never be found in space-stripped text.
            answer = udrl.rsplit("Action plan", 1)[-1]
            clean = answer.replace(",", "").replace(" ", "").strip()
            for ch in clean:
                if ch not in GYM_ACTION_MAP:
                    continue
                _, reward, terminated, truncated, _ = env.step(GYM_ACTION_MAP[ch])
                if terminated or truncated:
                    success = reward > 0
                    break
        finally:
            # Release the environment even when reset/step raised.
            env.close()
        return success
    except Exception:
        # Best-effort: any gymnasium failure falls back to the pure simulator.
        return self.verify_path_sim(desc, udrl)
|
| 206 |
+
|
| 207 |
+
# ==================== Table Text I/O ====================
|
| 208 |
+
|
| 209 |
+
def encode_table(self, desc: GridDesc) -> str:
    """Encode a grid as a pipe-delimited table.

    The first line is a header with one ``Col k`` label per grid column.
    Each following line is ``| Row k | ... |`` with every cell character
    translated through ``GRID_TO_TABLE``.

    Args:
        desc: Grid rows; each cell is one character.

    Returns:
        The table text, lines joined with ``"\\n"``.

    Raises:
        KeyError: If a cell character is not a key of ``GRID_TO_TABLE``.
    """
    # Take the column count from the first row rather than the row count,
    # so the header is also right for non-square grids.
    n_cols = len(desc[0]) if desc else 0
    header = "| | " + " | ".join(f"Col {i + 1}" for i in range(n_cols)) + " |"
    lines = [header]
    for number, row in enumerate(desc, start=1):
        mapped = [GRID_TO_TABLE[ch] for ch in row]
        lines.append(f"| Row {number} | " + " | ".join(mapped) + " |")
    return "\n".join(lines)
|
| 217 |
+
|
| 218 |
+
def decode_table(self, text: str) -> Optional[GridDesc]:
    """Parse pipe-delimited table text back into grid rows.

    Blank lines and lines containing ``Col`` or ``---`` are skipped. For
    every other line, the first non-empty cell (the row label) is dropped
    and the remaining cells are translated through ``TABLE_TO_GRID``;
    cells that are not keys of that mapping are ignored.

    Returns:
        The list of decoded row strings, or None when no row could be
        decoded or any exception occurred while parsing.
    """
    try:
        decoded = []
        for raw in text.strip().splitlines():
            stripped = raw.strip()
            if stripped == "" or "Col" in stripped or "---" in stripped:
                continue
            cells = [
                cell
                for cell in (piece.strip() for piece in stripped.split("|"))
                if cell
            ]
            if len(cells) < 2:
                continue
            symbols = (
                TABLE_TO_GRID[cell] for cell in cells[1:] if cell in TABLE_TO_GRID
            )
            row = "".join(symbols)
            if row:
                decoded.append(row)
        return decoded or None
    except Exception:
        return None
|
| 238 |
+
|
| 239 |
+
def save_table(self, filepath: str, desc: GridDesc) -> None:
    """Write the table encoding of *desc* to *filepath*, replacing its contents."""
    with open(filepath, "w") as handle:
        table_text = self.encode_table(desc)
        handle.write(table_text)
|
| 242 |
+
|
| 243 |
+
def load_table(self, filepath: str) -> Optional[GridDesc]:
    """Read *filepath* and decode it with ``decode_table``.

    Returns:
        The decoded grid, or None when the file cannot be read or
        decoding yields nothing.
    """
    try:
        with open(filepath) as handle:
            contents = handle.read()
            return self.decode_table(contents)
    except Exception:
        return None
|
| 249 |
+
|
| 250 |
+
def find_start(self, desc: GridDesc) -> Optional[Tuple[int, int]]:
    """Return ``(row, col)`` of the first ``S`` cell in row-major order, or None."""
    hits = (
        (r, c)
        for r, row in enumerate(desc)
        for c, ch in enumerate(row)
        if ch == "S"
    )
    return next(hits, None)
|
| 256 |
+
|
| 257 |
+
def fingerprint(self, desc: GridDesc) -> str:
    """Return a string key for a layout: its rows concatenated in order."""
    key = "".join(desc)
    return key
|
| 259 |
+
|
| 260 |
+
# ==================== Rendering ====================
|
| 261 |
+
|
| 262 |
+
def render_gym(self, desc: GridDesc) -> Optional[Image.Image]:
    """Render the grid via gymnasium (creates a pygame window — slow).

    Builds a non-slippery ``FrozenLake-v1`` environment in ``rgb_array``
    mode, resets it, grabs one frame, and resizes that frame to
    ``img_size`` x ``img_size`` with nearest-neighbour sampling.

    Returns:
        The rendered image, or None when gymnasium is unavailable or any
        step of the rendering raised.
    """
    if not HAS_GYM:
        return None
    try:
        env = gym.make(
            "FrozenLake-v1", desc=desc,
            is_slippery=False, render_mode="rgb_array",
        )
        try:
            env.reset()
            rgb = env.render()
        finally:
            # Close the environment even when reset/render raised, so a
            # failed render does not leak it.
            env.close()
        return Image.fromarray(rgb).resize(
            (self.img_size, self.img_size), Image.NEAREST
        )
    except Exception:
        # Best-effort renderer: callers treat None as "use the fallback".
        return None
|
| 279 |
+
|
| 280 |
+
def render_simple(self, desc: GridDesc) -> Image.Image:
    """Draw the grid with PIL only (no pygame dependency).

    The canvas is ``img_size`` x ``img_size``. Each cell is a filled
    square of side ``img_size // len(desc)``, coloured by its character
    (unknown characters use the ``F`` colour), and black grid lines are
    drawn on top.
    """
    n = len(desc)
    side = self.img_size
    step = side // n
    palette = {
        "S": (0, 0, 255), "F": (200, 220, 255),
        "H": (80, 80, 80), "G": (0, 200, 0),
    }
    fallback = (200, 220, 255)

    canvas = Image.new("RGB", (side, side), (255, 255, 255))
    pen = ImageDraw.Draw(canvas)

    # Cell fills.
    for r in range(n):
        top = r * step
        for c in range(n):
            left = c * step
            pen.rectangle(
                [left, top, left + step - 1, top + step - 1],
                fill=palette.get(desc[r][c], fallback),
            )

    # Grid lines: one vertical and one horizontal per boundary.
    for k in range(n + 1):
        offset = k * step
        pen.line([(offset, 0), (offset, side)], fill="black", width=1)
        pen.line([(0, offset), (side, offset)], fill="black", width=1)
    return canvas
|
| 301 |
+
|
| 302 |
+
def render(self, desc: GridDesc, use_gym: bool = True) -> Image.Image:
    """Render the grid, preferring gymnasium when requested.

    ``render_gym`` is tried only when *use_gym* is true; if it is skipped
    or returns None, ``render_simple`` produces the image instead.
    """
    gym_img = self.render_gym(desc) if use_gym else None
    if gym_img is None:
        return self.render_simple(desc)
    return gym_img
|
| 308 |
+
|
| 309 |
+
def draw_solution_line(
    self, image: Image.Image, path: List[Tuple[int, int]], grid_size: int,
) -> Image.Image:
    """Draw the path as a polyline on *image* (modifies it in place).

    Args:
        image: Canvas to draw on; it is mutated and also returned.
        path: ``(row, col)`` cells; the line runs through cell centres.
        grid_size: Number of cells along each image axis.

    Returns:
        The same *image* object, with the line in ``self.path_color`` and
        a stroke width of a quarter cell width (at least 1 pixel).
    """
    width_px, height_px = image.size
    cell_w = width_px / grid_size
    cell_h = height_px / grid_size
    centres = [
        (col * cell_w + cell_w / 2, row * cell_h + cell_h / 2)
        for row, col in path
    ]
    stroke = max(1, int(cell_w / 4))
    pen = ImageDraw.Draw(image)
    pen.line(centres, fill=self.path_color, width=stroke, joint="curve")
    return image
|
| 319 |
+
|
| 320 |
+
# ==================== Video Frames ====================
|
| 321 |
+
|
| 322 |
+
def generate_video_frames(
    self,
    desc: GridDesc,
    path: List[Tuple[int, int]],
    n_start: int = 5,
    m_end: int = 5,
    frames: Optional[int] = None,
    use_gym: bool = True,
) -> List[Image.Image]:
    """
    Build progressive red-line video frames.

    The result is an opening hold of *n_start* puzzle frames, then the
    content frames with the line growing along *path*, then a closing
    hold of *m_end* frames showing the complete line.

    *frames* controls the number of content frames between the holds:
        None → 1 per step, >steps → slow-mo, <steps → fast-fwd.
    Values below 1 are treated as 1.

    Args:
        desc: Grid rows; ``len(desc)`` is used as the grid size.
        path: ``(row, col)`` cells of the solution.
        n_start: Number of opening hold frames.
        m_end: Number of closing hold frames.
        frames: Requested content-frame count, or None.
        use_gym: Forwarded to ``render``.

    Returns:
        A list of independent images. When *path* has fewer than two
        cells there is nothing to animate and the list holds
        ``n_start + m_end + 1`` copies of the puzzle image.
    """
    size = len(desc)
    n_steps = len(path) - 1
    base_img = self.render(desc, use_gym=use_gym)

    if n_steps <= 0:
        # Return distinct copies, as every other branch does. Repeating
        # one object would make a later in-place edit of a single frame
        # show up in all of them.
        return [base_img.copy() for _ in range(n_start + m_end + 1)]

    content = frames if frames is not None else n_steps
    content = max(1, content)
    result: List[Image.Image] = []

    # Opening hold
    result.extend([base_img.copy() for _ in range(n_start)])

    def _partial(steps: int) -> Image.Image:
        # Puzzle image with the first *steps* moves of the path drawn.
        return self.draw_solution_line(base_img.copy(), path[: steps + 1], size)

    if content == n_steps:
        # Exactly one frame per move.
        for s in range(1, n_steps + 1):
            result.append(_partial(s))
    elif content > n_steps:
        # Slow-mo: move s owns frame slots [lo, hi) of the content budget;
        # render once and repeat it to fill the remaining slots.
        for s in range(1, n_steps + 1):
            lo = (s - 1) * content // n_steps
            hi = s * content // n_steps
            frame = _partial(s)
            result.append(frame)
            for _ in range(hi - lo - 1):
                result.append(frame.copy())
    else:
        # Fast-forward: sample the path evenly; the last frame is the
        # complete path.
        for f in range(content):
            result.append(_partial((f + 1) * n_steps // content))

    # Closing hold
    final = _partial(n_steps)
    result.extend([final.copy() for _ in range(m_end)])
    return result
|
| 373 |
+
|
| 374 |
+
# ==================== Red-Path Extraction ====================
|
| 375 |
+
|
| 376 |
+
def extract_path_from_pixels(
    self,
    pixels: np.ndarray,
    rows: int,
    cols: int,
    start: Tuple[int, int],
    desc: Optional[GridDesc] = None,
    pixel_threshold: float = 0.01,
) -> str:
    """Detect a red path in an RGB array and return it as a UDRL string.

    A pixel counts as red when its red channel is above 100 and more than
    1.2 times both the green and the blue channel. The image is divided
    into ``rows`` x ``cols`` cells of ``h // rows`` by ``w // cols``
    pixels, and a cell is on the path when its fraction of red pixels
    exceeds *pixel_threshold*. From *start*, the walk repeatedly steps to
    the first unvisited path cell among the right, down, left and up
    neighbours (in that order) and stops when none is left.

    Args:
        pixels: Array indexed as ``[y, x, channel]`` with at least three
            channels in R, G, B order.
        rows: Number of grid rows.
        cols: Number of grid columns.
        start: ``(row, col)`` of the cell where the walk begins.
        desc: Accepted for interface compatibility; not used here.
        pixel_threshold: Minimum red-pixel fraction for a path cell.

    Returns:
        The moves of the greedy walk; empty when no neighbour of *start*
        is a path cell.
    """
    # Work on the array directly. Converting through a PIL image only
    # copied the data twice and rejected non-uint8 input.
    px = np.asarray(pixels, dtype=float)
    h, w = px.shape[:2]
    r_ch, g_ch, b_ch = px[:, :, 0], px[:, :, 1], px[:, :, 2]
    red_mask = (r_ch > 100) & (r_ch > g_ch * 1.2) & (r_ch > b_ch * 1.2)

    cell_h, cell_w = h // rows, w // cols
    path_grid = np.zeros((rows, cols), dtype=bool)
    for r in range(rows):
        for c in range(cols):
            sub = red_mask[r * cell_h : (r + 1) * cell_h,
                           c * cell_w : (c + 1) * cell_w]
            # sub is empty when the image is smaller than the grid.
            if sub.size > 0 and np.mean(sub) > pixel_threshold:
                path_grid[r, c] = True

    # Greedy walk. Unpack first so a list-typed start (e.g. loaded from
    # JSON) works too; the visited set needs hashable tuples.
    cr, cc = start
    visited = {(cr, cc)}
    actions: List[str] = []
    neighbours = (("R", (0, 1)), ("D", (1, 0)), ("L", (0, -1)), ("U", (-1, 0)))
    # Hard cap on iterations, kept as a safety net.
    for _ in range(rows * cols * 2):
        found = False
        for act, (dr, dc) in neighbours:
            nr, nc = cr + dr, cc + dc
            if 0 <= nr < rows and 0 <= nc < cols:
                if path_grid[nr, nc] and (nr, nc) not in visited:
                    visited.add((nr, nc))
                    actions.append(act)
                    cr, cc = nr, nc
                    found = True
                    break
        if not found:
            break
    return "".join(actions)
|
| 419 |
+
|
| 420 |
+
def extract_path_from_image(
    self, img_path: str, rows: int, cols: int, start: Tuple, desc=None,
) -> str:
    """Extract a UDRL string from an image file.

    The file is opened, converted to RGB and handed to
    ``extract_path_from_pixels``. Any exception along the way yields the
    empty string.
    """
    try:
        rgb_image = Image.open(img_path).convert("RGB")
        rgb_array = np.array(rgb_image)
        return self.extract_path_from_pixels(rgb_array, rows, cols, start, desc)
    except Exception:
        return ""
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
if __name__ == "__main__":
    import time

    proc = FrozenLakeProcessor(img_size=512)

    # Benchmark: time 100 random-layout BFS solves per grid size and count
    # how many layouts were solvable.
    for sz in [8, 16, 32, 64]:
        t0 = time.perf_counter()
        count = sum(
            1 for _ in range(100) if proc.solve(proc._random_layout(sz, p=0.8))
        )
        elapsed = time.perf_counter() - t0
        print(f"Size {sz:3d}: 100 BFS solves in {elapsed:.3f}s "
              f"({count} solvable, {elapsed/100*1000:.1f}ms/solve)")

    # Functional check: generate one puzzle and verify its own solution.
    desc, path = proc.generate(size=16, p=0.8, min_path_len=20)
    udrl = proc.path_to_udrl(path)
    print(f"\nGenerate 16×16: path={len(path)}, UDRL={udrl[:40]}...")
    print(f"Verify (sim): {proc.verify_path_sim(desc, udrl)}")

    # Table round-trip: encoding then decoding must reproduce the grid.
    table_text = proc.encode_table(desc)
    decoded = proc.decode_table(table_text)
    assert decoded == desc
    print("Table round-trip: ✓")

    # Render + extract round-trip: draw the solution, read it back from
    # the pixels, and report whether the recovered moves still verify.
    img = proc.render(desc, use_gym=False)
    sol = proc.draw_solution_line(img.copy(), path, len(desc))
    start = proc.find_start(desc)
    extracted = proc.extract_path_from_pixels(
        np.array(sol), len(desc), len(desc[0]), start
    )
    print(f"Extract round-trip verify: {proc.verify_path_sim(desc, extracted)}")
    print("All tests passed ✓")
|
maze/data_process.py
ADDED
|
@@ -0,0 +1,651 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Maze Video Dataset Generator — generates maze puzzle images and solution videos
|
| 3 |
+
with checkpoint/resume support, train/test splitting, and JSONL metadata.
|
| 4 |
+
|
| 5 |
+
Includes an ``eval`` subcommand that takes a directory of result videos,
|
| 6 |
+
extracts the last frame from each, parses the red path, and verifies it
|
| 7 |
+
against the ground-truth maze text files.
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
# Generate
|
| 11 |
+
python data_process.py generate --output-dir maze --sizes 8 16 32 \
|
| 12 |
+
--num-per-size 100 500 1000 --min-path-ratio 0.3 \
|
| 13 |
+
--n-start 5 --m-end 5 --frames 50 --fps 10 --seed 42
|
| 14 |
+
|
| 15 |
+
# Evaluate result videos
|
| 16 |
+
python data_process.py eval result_videos/ --text-dir maze/texts
|
| 17 |
+
|
| 18 |
+
# Verify a pre-extracted JSON
|
| 19 |
+
python data_process.py verify results.json --text-dir maze/texts
|
| 20 |
+
"""
|
| 21 |
+
import json
|
| 22 |
+
import csv
|
| 23 |
+
import hashlib
|
| 24 |
+
import random
|
| 25 |
+
import re
|
| 26 |
+
import argparse
|
| 27 |
+
from dataclasses import dataclass, asdict
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
from typing import Dict, List, Optional
|
| 30 |
+
|
| 31 |
+
import cv2
|
| 32 |
+
import numpy as np
|
| 33 |
+
from tqdm import tqdm
|
| 34 |
+
|
| 35 |
+
from maze_processor import MazeProcessor
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ==================== Checkpoint Management ====================
|
| 39 |
+
|
| 40 |
+
@dataclass
class GenerationState:
    """Tracks generation progress for checkpoint/resume."""
    # Hash of the generation parameters this state belongs to.
    params_hash: str
    # Maze size -> number of mazes generated so far for that size.
    size_progress: Dict[int, int]
    # Fingerprints of the mazes generated so far (used for de-duplication).
    seen_fingerprints: List[str]
    # One metadata record per generated maze.
    all_samples: List[Dict]
    # True once the final outputs have been written.
    completed: bool = False

    def to_dict(self) -> Dict:
        """Return the state as a plain dict of its fields."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: Dict) -> "GenerationState":
        """Rebuild a state from a dict produced by ``to_dict``.

        JSON object keys are always strings, so a state that went through
        ``json.dump``/``json.load`` arrives with string ``size_progress``
        keys. They are converted back to ``int`` here so the field matches
        its declared type. *d* itself is not modified.
        """
        fields = dict(d)
        if "size_progress" in fields:
            fields["size_progress"] = {
                int(size): count for size, count in fields["size_progress"].items()
            }
        return cls(**fields)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _params_hash(params: Dict) -> str:
|
| 58 |
+
"""Deterministic hash of generation parameters (excluding output_dir)."""
|
| 59 |
+
key = {k: v for k, v in params.items() if k != "output_dir"}
|
| 60 |
+
return hashlib.md5(json.dumps(key, sort_keys=True).encode()).hexdigest()[:12]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def load_checkpoint(output_dir: Path, params: Dict) -> Optional[GenerationState]:
    """Load the checkpoint from ``metadata.json`` if it matches *params*.

    Returns:
        None when the file does not exist or was written for different
        parameters (the stored hash differs from the hash of *params*);
        otherwise the stored state, whether completed or in progress.
        A status line is printed in every case except the missing file.
    """
    meta = output_dir / "metadata.json"
    if not meta.exists():
        return None

    with open(meta) as f:
        payload = json.load(f)
    state = GenerationState.from_dict(payload["state"])

    expected = _params_hash(params)
    if state.params_hash != expected:
        print(f"⚠️ Parameters changed ({state.params_hash} → {expected}), starting fresh")
        return None

    if state.completed:
        print("✓ Generation already completed")
    else:
        done = sum(state.size_progress.values())
        print(f"✓ Resuming from checkpoint: {done} mazes generated")
    return state
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def save_checkpoint(output_dir: Path, state: GenerationState, params: Dict):
    """Atomically write the checkpoint to ``metadata.json``.

    The payload ``{"params": ..., "state": ...}`` is first written to
    ``metadata.tmp`` in the same directory and then moved over the real
    file, so a crash mid-write cannot leave a truncated checkpoint behind.

    Args:
        output_dir: Directory that holds ``metadata.json``.
        state: Progress to persist; serialised through ``state.to_dict()``.
        params: Generation parameters stored next to the state.
    """
    meta = output_dir / "metadata.json"
    tmp = meta.with_suffix(".tmp")
    with open(tmp, "w") as f:
        json.dump({"params": params, "state": state.to_dict()}, f, indent=2)
    # replace() overwrites an existing target on every platform, whereas
    # rename() raises FileExistsError on Windows from the second
    # checkpoint onwards.
    tmp.replace(meta)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# ==================== Video I/O ====================
|
| 93 |
+
|
| 94 |
+
def save_video_cv2(frames: list, path: str, fps: int = 10):
    """Save a list of PIL Images as an mp4 video.

    The frame size is taken from the first frame. Every frame is
    converted from RGB to BGR before it is handed to the writer.

    Args:
        frames: Non-empty list of RGB images convertible with ``np.array``.
        path: Destination file path.
        fps: Frames per second of the output video.

    Raises:
        IndexError: If *frames* is empty.
    """
    first = np.array(frames[0])
    h, w = first.shape[:2]
    writer = cv2.VideoWriter(
        str(path), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)
    )
    try:
        for frame in frames:
            writer.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
    finally:
        # Release the writer even when a frame conversion raises.
        writer.release()
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def extract_last_frame(video_path: str) -> Optional[np.ndarray]:
    """
    Extract the last frame from a video file as an RGB numpy array.

    The fast path seeks to ``frame_count - 1`` and reads one frame. When
    the file reports no frame count, or when the read after the seek
    fails, the video is reopened and decoded from the beginning, keeping
    the last frame that could be read.

    Returns:
        (H, W, 3) uint8 RGB array, or None on failure.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        cap.release()
        return None

    frame = None
    try:
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total > 0:
            cap.set(cv2.CAP_PROP_POS_FRAMES, total - 1)
            ret, candidate = cap.read()
            if ret and candidate is not None:
                frame = candidate
    finally:
        cap.release()

    if frame is None:
        # Fallback: the reported count was missing, or the seek landed on
        # a frame that could not be read. Walk the stream sequentially so
        # we still return the LAST frame instead of the first or nothing.
        cap = cv2.VideoCapture(str(video_path))
        try:
            while True:
                ret, candidate = cap.read()
                if not ret or candidate is None:
                    break
                frame = candidate
        finally:
            cap.release()

    if frame is None:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# ==================== Normalisation Helpers ====================
|
| 130 |
+
|
| 131 |
+
def _normalise_list(val, sizes, name="parameter"):
|
| 132 |
+
"""Broadcast a single int to a list, or validate list length."""
|
| 133 |
+
if isinstance(val, int):
|
| 134 |
+
return [val] * len(sizes)
|
| 135 |
+
if len(val) != len(sizes):
|
| 136 |
+
raise ValueError(f"{name} length ({len(val)}) != sizes length ({len(sizes)})")
|
| 137 |
+
return list(val)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# ==================== Core Dataset Generation ====================
|
| 141 |
+
|
| 142 |
+
def generate_dataset(
    output_dir: str = "maze",
    sizes: Optional[List[int]] = None,
    num_per_size: Optional[list] = None,
    min_path_ratio: float = 0.3,
    img_size: int = 1024,
    prompt: str = "Draw a continuous red line from the yellow dot to the blue dot, avoiding all walls.",
    train_ratio: float = 0.9,
    n_start: int = 5,
    m_end: int = 5,
    frames: Optional[int] = None,
    fps: int = 10,
    seed: int = 42,
    checkpoint_interval: int = 50,
):
    """
    Generate maze video dataset with checkpoint/resume support.

    The *frames* parameter controls content frames per video:
    - None → one content frame per path step (variable length)
    - N > 0 → exactly N content frames (slow-mo / fast-fwd as needed)

    Directory layout::

        output_dir/
            images/         — puzzle PNG (no solution line)
            videos/         — solution MP4 (progressive red line)
            texts/          — maze text files (bitmask format)
            train.jsonl / test.jsonl
            train.csv / test.csv
            path.json       — UDRL answer key
            metadata.json   — checkpoint state

    Args:
        output_dir: Root directory of the dataset.
        sizes: Maze sizes to generate; defaults to ``[8, 16, 32]``.
        num_per_size: Target count per size, as a list matching *sizes*,
            a one-element list, or a single int (the last two apply to
            every size); defaults to ``[100, 500, 1000]``.
        min_path_ratio: Minimum solution length as a fraction of
            ``size * size`` (at least 1).
        img_size: Forwarded to ``MazeProcessor``.
        prompt: Prompt text stored with every sample.
        train_ratio: Fraction of the shuffled samples written to the
            train split.
        n_start: Opening hold frames per video.
        m_end: Closing hold frames per video.
        frames: Content frames per video (see above).
        fps: Frames per second of the saved videos.
        seed: Seed for the ``random`` module; ``seed + 1`` seeds the
            train/test shuffle.
        checkpoint_interval: Number of new mazes between checkpoint writes.

    Raises:
        ValueError: If *num_per_size* has more than one entry but a
            different length than *sizes*.
    """
    # Materialise the defaults per call instead of sharing mutable default
    # objects between calls.
    if sizes is None:
        sizes = [8, 16, 32]
    if num_per_size is None:
        num_per_size = [100, 500, 1000]

    params = {
        "sizes": sizes, "num_per_size": num_per_size,
        "min_path_ratio": min_path_ratio, "img_size": img_size,
        "prompt": prompt, "train_ratio": train_ratio,
        "n_start": n_start, "m_end": m_end, "frames": frames,
        "fps": fps, "seed": seed,
    }

    out = Path(output_dir)
    img_dir = out / "images"
    vid_dir = out / "videos"
    txt_dir = out / "texts"
    for d in (img_dir, vid_dir, txt_dir):
        d.mkdir(parents=True, exist_ok=True)

    state = load_checkpoint(out, params)
    if state and state.completed:
        return

    # A bare int or a one-element list applies to every size. The int case
    # is checked first because len() on an int raises TypeError.
    if isinstance(num_per_size, int) or len(num_per_size) != 1:
        per_size = num_per_size
    else:
        per_size = num_per_size[0]
    num_list = _normalise_list(per_size, sizes, "num_per_size")
    max_puzzles = max(num_list)
    num_w = len(str(max_puzzles))  # zero-pad width for file indices
    proc = MazeProcessor(img_size=img_size)

    if state is None:
        random.seed(seed)
        state = GenerationState(
            params_hash=_params_hash(params),
            size_progress={sz: 0 for sz in sizes},
            seen_fingerprints=[],
            all_samples=[],
        )
        print(f"Starting fresh generation: sizes={sizes}, counts={num_list}")
        print(f" frames={'auto (1 per step)' if frames is None else frames}, "
              f"n_start={n_start}, m_end={m_end}, fps={fps}")
    else:
        # NOTE(review): this advances the RNG by a fixed 10 draws per maze
        # already generated. It looks like a heuristic to move away from
        # the initial stream, not an exact replay — confirm that a resumed
        # run is not expected to reproduce an uninterrupted one.
        random.seed(seed)
        for _ in range(sum(state.size_progress.values()) * 10):
            random.random()

    seen = set(state.seen_fingerprints)
    all_samples = list(state.all_samples)
    # Sizes are normalised to int: JSON stores dict keys as strings.
    progress = {int(k): v for k, v in state.size_progress.items()}
    since_ckpt = 0

    def _flush(completed: bool = False) -> None:
        # Copy the working collections into the state and persist it.
        state.size_progress = progress
        state.seen_fingerprints = list(seen)
        state.all_samples = all_samples
        if completed:
            state.completed = True
        save_checkpoint(out, state, params)

    total_target = sum(num_list)
    total_done = sum(progress.values())

    with tqdm(total=total_target, initial=total_done, desc="Total", unit="maze") as pbar:
        for maze_size, target in zip(sizes, num_list):
            generated = progress.get(maze_size, 0)
            if generated >= target:
                continue

            min_len = max(1, int(maze_size * maze_size * min_path_ratio))
            # Bounded retries: duplicates and failed generations consume
            # attempts, so a size may end with fewer mazes than requested.
            max_attempts = (target - generated) * 20

            with tqdm(
                total=target, initial=generated, desc=f"Size {maze_size:3d}",
                unit="maze", leave=False,
            ) as pbar_sz:
                for _ in range(max_attempts):
                    if generated >= target:
                        break

                    try:
                        grid, start, end, path = proc.generate(
                            maze_size, min_path_len=min_len
                        )
                    except RuntimeError:
                        continue

                    fp = proc.fingerprint(grid, start, end)
                    if fp in seen:
                        continue
                    seen.add(fp)

                    idx = generated
                    base = f"size{maze_size}_{idx:0{num_w}d}"
                    img_name = f"{base}.png"
                    vid_name = f"{base}.mp4"
                    txt_name = f"{base}.txt"

                    puzzle_img = proc.render(grid, start, end)
                    puzzle_img.save(str(img_dir / img_name))

                    vid_frames = proc.generate_video_frames(
                        grid, start, end, path,
                        n_start=n_start, m_end=m_end, frames=frames,
                    )
                    save_video_cv2(vid_frames, str(vid_dir / vid_name), fps=fps)

                    proc.save_text(str(txt_dir / txt_name), grid, start, end)

                    udrl = proc.path_to_udrl(path)

                    all_samples.append({
                        "prompt": prompt,
                        "image": img_name,
                        "video": vid_name,
                        "text": txt_name,
                        "maze_size": maze_size,
                        "start": list(start),
                        "end": list(end),
                        "path_udrl": udrl,
                        "path_length": len(path),
                        "frame_count": len(vid_frames),
                    })

                    generated += 1
                    progress[maze_size] = generated
                    since_ckpt += 1
                    pbar_sz.update(1)
                    pbar.update(1)

                    if since_ckpt >= checkpoint_interval:
                        _flush()
                        since_ckpt = 0

            tqdm.write(
                f"Size {maze_size}: {generated} mazes, "
                f"{sum(1 for s in all_samples if s['maze_size'] == maze_size)} samples"
            )

    # ==================== Final outputs ====================

    # Answer key, written before the shuffle and sorted by image name.
    path_answers = {s["image"]: s["path_udrl"] for s in all_samples}
    with open(out / "path.json", "w") as f:
        json.dump(dict(sorted(path_answers.items())), f, indent=4)

    random.seed(seed + 1)
    random.shuffle(all_samples)
    split = int(len(all_samples) * train_ratio)
    train_samples = all_samples[:split]
    test_samples = all_samples[split:]

    def _write_jsonl(samples, path):
        with open(path, "w") as f:
            for s in samples:
                f.write(json.dumps(s) + "\n")

    _write_jsonl(train_samples, out / "train.jsonl")
    _write_jsonl(test_samples, out / "test.jsonl")

    for name, samples in [("train", train_samples), ("test", test_samples)]:
        with open(out / f"{name}.csv", "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["input_image", "video", "prompt"])
            for s in samples:
                writer.writerow([
                    f"images/{s['image']}", f"videos/{s['video']}", s["prompt"]
                ])

    _flush(completed=True)

    print(f"\n✓ Dataset complete: {out}/")
    print(f" Sizes: {sizes}")
    print(f" Mazes: {len(all_samples)}")
    print(f" Train: {split}, Test: {len(all_samples) - split}")
    # min()/max() raise on an empty sequence, so skip the statistics when
    # nothing was generated.
    if all_samples:
        lengths = [s["path_length"] for s in all_samples]
        fcounts = [s["frame_count"] for s in all_samples]
        print(f" Path lengths: avg={np.mean(lengths):.1f}, "
              f"min={min(lengths)}, max={max(lengths)}")
        print(f" Frame counts: avg={np.mean(fcounts):.1f}, "
              f"min={min(fcounts)}, max={max(fcounts)}")
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
# ==================== Eval: Video → Last Frame → Verify ====================
|
| 351 |
+
|
| 352 |
+
def eval_videos(
    video_dir: str,
    text_dir: str,
    output_json: Optional[str] = None,
    gt_json: Optional[str] = None,
):
    """
    Evaluate a directory of result videos against ground-truth mazes.

    Pipeline per video:
    1. Extract last frame from .mp4
    2. Detect red path via pixel analysis
    3. Convert to UDRL action string
    4. Verify against maze .txt (wall-respecting walk from start to end)

    Matching convention:
    Video ``<stem>.mp4`` → Text ``<stem>.txt`` in *text_dir*.
    Common stems: ``size8_000``, ``size16_042``, etc.

    Args:
        video_dir: Directory containing result .mp4 files.
        text_dir: Directory containing ground-truth maze .txt files.
        output_json: Path to save extracted paths as JSON (default: video_dir/0_result.json).
        gt_json: Optional ground-truth answer JSON for accuracy by path length.

    Returns:
        None. Results are printed and the extracted paths are written to
        *output_json*.
    """
    proc = MazeProcessor()
    vid_root = Path(video_dir)
    txt_root = Path(text_dir)

    if output_json is None:
        output_json = str(vid_root / "0_result.json")

    # Collect videos in natural order (digit runs compare as integers).
    videos = sorted(
        vid_root.glob("*.mp4"),
        key=lambda p: [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", p.stem)],
    )

    if not videos:
        print(f"No .mp4 files found in {vid_root}")
        return

    print(f"Found {len(videos)} result videos in {vid_root}")
    print(f"Text dir: {txt_root}")

    # --- Phase 1: Extract paths from last frames ---
    extracted: Dict[str, str] = {}
    missing_txt = 0
    missing_frame = 0

    for vpath in tqdm(videos, desc="Extracting paths"):
        stem = vpath.stem  # e.g. "size8_000"
        txt_path = txt_root / f"{stem}.txt"

        if not txt_path.exists():
            missing_txt += 1
            continue

        maze = proc.load_text(str(txt_path))
        if maze is None:
            # Present but unparsable counts as missing ground truth.
            missing_txt += 1
            continue

        last_frame = extract_last_frame(str(vpath))
        if last_frame is None:
            missing_frame += 1
            continue

        udrl = proc.extract_path_from_pixels(
            last_frame,
            grid_raw=maze["grid_raw"],
            size=maze["size"],
            start=maze["start"],
        )
        extracted[f"{stem}.png"] = udrl  # keyed by image name for consistency

    # Save extracted paths before verifying, so they survive a later failure.
    with open(output_json, "w", encoding="utf-8") as f:
        json.dump(extracted, f, indent=4)
    print(f"\nExtracted paths saved to: {output_json}")

    # --- Phase 2: Verify (single pass feeds overall and per-size stats) ---
    correct = 0
    total_valid = 0
    correctly_solved: List[Dict] = []
    size_stats: Dict[int, Dict[str, int]] = {}

    for name, udrl in extracted.items():
        stem = name.replace(".png", "")
        maze = proc.load_text(str(txt_root / f"{stem}.txt"))
        if maze is None:
            continue
        total_valid += 1
        per_size = size_stats.setdefault(maze["size"], {"total": 0, "correct": 0})
        per_size["total"] += 1
        if proc.verify_path(maze["grid"], maze["start"], maze["end"], udrl):
            correct += 1
            per_size["correct"] += 1
            correctly_solved.append({"name": name, "length": len(udrl)})

    acc = (correct / total_valid * 100) if total_valid else 0

    print(f"\n{'=' * 50}")
    print("Evaluation Summary")
    print(f"{'=' * 50}")
    print(f"Total Videos : {len(videos)}")
    print(f"Missing .txt : {missing_txt}")
    print(f"Failed Frame Read : {missing_frame}")
    print(f"Evaluated : {total_valid}")
    print(f"Correctly Solved : {correct}")
    print(f"Accuracy : {acc:.2f}%")
    print(f"{'-' * 50}")

    # Breakdown by maze size
    if size_stats:
        print("\nAccuracy by maze size:")
        for sz in sorted(size_stats):
            s = size_stats[sz]
            sz_acc = s["correct"] / s["total"] * 100 if s["total"] else 0
            print(f" Size {sz:3d}: {s['correct']:4d}/{s['total']:4d} ({sz_acc:.2f}%)")

    # Top longest correct
    correctly_solved.sort(key=lambda x: x["length"], reverse=True)
    if correctly_solved:
        print("\nTop 3 Longest Correct Paths:")
        for i, item in enumerate(correctly_solved[:3]):
            print(f" {i+1}. {item['name']} (length: {item['length']})")

    # Optional: compare with ground-truth JSON for path-length-binned accuracy
    if gt_json:
        _compare_with_gt(extracted, gt_json, txt_root, proc)

    print(f"{'=' * 50}")
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
def _compare_with_gt(
|
| 499 |
+
extracted: Dict[str, str],
|
| 500 |
+
gt_json_path: str,
|
| 501 |
+
txt_root: Path,
|
| 502 |
+
proc: MazeProcessor,
|
| 503 |
+
):
|
| 504 |
+
"""Print accuracy binned by ground-truth path length."""
|
| 505 |
+
try:
|
| 506 |
+
with open(gt_json_path) as f:
|
| 507 |
+
gt = json.load(f)
|
| 508 |
+
except Exception:
|
| 509 |
+
print(f" Warning: could not load ground-truth JSON: {gt_json_path}")
|
| 510 |
+
return
|
| 511 |
+
|
| 512 |
+
bins: Dict[str, Dict[str, int]] = {} # "10-19" -> {total, correct}
|
| 513 |
+
for name, pred_udrl in extracted.items():
|
| 514 |
+
if name not in gt:
|
| 515 |
+
continue
|
| 516 |
+
gt_udrl = gt[name]
|
| 517 |
+
gt_len = len(gt_udrl)
|
| 518 |
+
|
| 519 |
+
# Bin by path length (decades)
|
| 520 |
+
lo = (gt_len // 10) * 10
|
| 521 |
+
hi = lo + 9
|
| 522 |
+
label = f"{lo:3d}-{hi:3d}"
|
| 523 |
+
if label not in bins:
|
| 524 |
+
bins[label] = {"total": 0, "correct": 0}
|
| 525 |
+
bins[label]["total"] += 1
|
| 526 |
+
|
| 527 |
+
stem = name.replace(".png", "")
|
| 528 |
+
maze = proc.load_text(str(txt_root / f"{stem}.txt"))
|
| 529 |
+
if maze and proc.verify_path(maze["grid"], maze["start"], maze["end"], pred_udrl):
|
| 530 |
+
bins[label]["correct"] += 1
|
| 531 |
+
|
| 532 |
+
if bins:
|
| 533 |
+
print("\nAccuracy by GT path length:")
|
| 534 |
+
for label in sorted(bins):
|
| 535 |
+
b = bins[label]
|
| 536 |
+
b_acc = b["correct"] / b["total"] * 100 if b["total"] else 0
|
| 537 |
+
print(f" Length {label}: {b['correct']:4d}/{b['total']:4d} ({b_acc:.2f}%)")
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
# ==================== Verify: Pre-extracted JSON ====================
|
| 541 |
+
|
| 542 |
+
def verify_results(json_file: str, text_dir: str):
    """
    Verify pre-extracted UDRL paths (from a JSON file) against maze .txt files.

    Each entry is checked against ``<name without .png>.txt`` in *text_dir*;
    entries whose maze file cannot be loaded are counted as skipped. A short
    summary is printed.

    Args:
        json_file: Path to JSON with {name: udrl_string} predictions.
        text_dir: Directory containing maze .txt files.
    """
    proc = MazeProcessor()
    txt_root = Path(text_dir)

    with Path(json_file).open() as handle:
        solutions = json.load(handle)

    correct = 0
    skipped = 0
    valid = 0
    for name, udrl in solutions.items():
        maze_file = txt_root / f"{name.replace('.png', '')}.txt"
        maze = proc.load_text(str(maze_file))
        if maze is None:
            skipped += 1
            continue
        valid += 1
        if proc.verify_path(maze["grid"], maze["start"], maze["end"], udrl):
            correct += 1

    acc = (correct / valid * 100) if valid else 0
    rule = "=" * 40
    print(f"\n{rule}")
    print(f"Verification: {correct}/{valid} correct ({acc:.2f}%)")
    if skipped:
        print(f"Skipped: {skipped}")
    print(rule)
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
# ==================== CLI ====================
|
| 579 |
+
|
| 580 |
+
def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Three sub-commands are defined, stored in ``command``: ``generate``,
    ``eval`` and ``verify``. ``command`` is ``None`` when no sub-command is
    given.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Maze video dataset: generate, eval, verify"
    )
    commands = parser.add_subparsers(dest="command", help="Sub-command")

    # --- generate ---
    generate_cmd = commands.add_parser("generate", help="Generate dataset")
    generate_options = [
        ("--output-dir", dict(type=str, default="maze")),
        ("--sizes", dict(type=int, nargs="+", default=[8, 16, 24, 32])),
        ("--num-per-size", dict(type=int, nargs="+", default=[100, 500, 1000, 2000])),
        ("--min-path-ratio", dict(type=float, default=0.3,
                                  help="Min path length as fraction of size²")),
        ("--img-size", dict(type=int, default=1024)),
        ("--prompt", dict(type=str,
                          default="Draw a continuous red line from the yellow dot "
                                  "to the blue dot, avoiding all walls.")),
        ("--train-ratio", dict(type=float, default=0.9)),
        ("--n-start", dict(type=int, default=2,
                           help="Hold frames at video start (blank puzzle)")),
        ("--m-end", dict(type=int, default=3,
                         help="Hold frames at video end (completed solution)")),
        ("--frames", dict(type=int, default=None,
                          help="Content frames per video (None=auto 1 per step)")),
        ("--fps", dict(type=int, default=10)),
        ("--seed", dict(type=int, default=42)),
        ("--checkpoint-interval", dict(type=int, default=50)),
    ]
    for flag, spec in generate_options:
        generate_cmd.add_argument(flag, **spec)

    # --- eval ---
    eval_cmd = commands.add_parser(
        "eval", help="Evaluate result videos (last frame → extract → verify)"
    )
    eval_cmd.add_argument(
        "video_dir", type=str, help="Directory containing result .mp4 files"
    )
    eval_cmd.add_argument(
        "--text-dir", type=str, required=True,
        help="Directory with ground-truth maze .txt files",
    )
    eval_cmd.add_argument(
        "--output-json", type=str, default=None,
        help="Output JSON for extracted paths (default: video_dir/0_result.json)",
    )
    eval_cmd.add_argument(
        "--gt-json", type=str, default=None,
        help="Optional ground-truth path.json for length-binned accuracy",
    )

    # --- verify ---
    verify_cmd = commands.add_parser(
        "verify", help="Verify a pre-extracted JSON of UDRL paths"
    )
    verify_cmd.add_argument("json_file", type=str)
    verify_cmd.add_argument(
        "--text-dir", type=str, required=True, help="Directory with maze .txt files"
    )

    return parser.parse_args()
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
if __name__ == "__main__":
    # Script entry point: dispatch to the handler for the chosen sub-command.
    args = parse_args()
    command = args.command

    if command == "generate":
        # Every parsed option except the sub-command name is forwarded as a keyword.
        options = dict(vars(args))
        del options["command"]
        generate_dataset(**options)

    elif command == "eval":
        eval_videos(
            video_dir=args.video_dir,
            text_dir=args.text_dir,
            output_json=args.output_json,
            gt_json=args.gt_json,
        )

    elif command == "verify":
        verify_results(args.json_file, args.text_dir)

    else:
        # No sub-command given: show a short usage reminder.
        for usage_line in (
            "Usage: python maze_video_gen.py {generate|eval|verify} [options]",
            " python maze_video_gen.py generate --help",
            " python maze_video_gen.py eval --help",
            " python maze_video_gen.py verify --help",
        ):
            print(usage_line)
|
maze/maze_processor.py
ADDED
|
@@ -0,0 +1,543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MazeProcessor - Maze generation, solving, rendering, and video frame generation.
|
| 3 |
+
|
| 4 |
+
Mirrors the SudokuProcessor pattern: a single class encapsulating all maze logic
|
| 5 |
+
including DFS generation, BFS solving, image/video rendering, path verification,
|
| 6 |
+
and text serialization.
|
| 7 |
+
"""
|
| 8 |
+
import random
|
| 9 |
+
from collections import deque
|
| 10 |
+
from typing import List, Tuple, Optional, Dict
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
from PIL import Image, ImageDraw
|
| 14 |
+
|
| 15 |
+
# Wall bitmask encoding (matches text file format)
# Each cell is stored as the sum of the masks of the walls that are present.
WALL_MASKS = {"N": 1, "S": 2, "W": 4, "E": 8}
# Direction seen from the neighbouring cell across a shared wall.
OPPOSITE = {"N": "S", "S": "N", "E": "W", "W": "E"}
# UDRL action letter -> (row delta, column delta, wall of the current cell
# that must be open for the move to be legal).
MOVES = {
    "U": (-1, 0, "N"),
    "D": (1, 0, "S"),
    "L": (0, -1, "W"),
    "R": (0, 1, "E"),
}
# Compass direction -> (row delta, column delta) of the adjacent cell.
# Iteration order of this dict determines the neighbour order used by the
# DFS generator and the BFS solver.
NEIGHBORS = {
    "N": (-1, 0),
    "S": (1, 0),
    "E": (0, 1),
    "W": (0, -1),
}


# ======================== Grid Type ========================
# grid[r][c] = {"N": bool, "S": bool, "W": bool, "E": bool}
# True => wall present, False => passage open

Grid = List[List[Dict[str, bool]]]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class MazeProcessor:
|
| 40 |
+
"""Handles maze generation, solving, image rendering, and video frame creation."""
|
| 41 |
+
|
| 42 |
+
def __init__(self, img_size: int = 512):
|
| 43 |
+
self.img_size = img_size
|
| 44 |
+
|
| 45 |
+
# Rendering colours (RGB)
|
| 46 |
+
self.bg_color = "black"
|
| 47 |
+
self.cell_color = "white"
|
| 48 |
+
self.wall_color = "black"
|
| 49 |
+
self.grid_color = (224, 224, 224)
|
| 50 |
+
self.start_color = "yellow"
|
| 51 |
+
self.end_color = "blue"
|
| 52 |
+
self.path_color = "red"
|
| 53 |
+
|
| 54 |
+
# ==================== Generation (DFS) ====================
|
| 55 |
+
|
| 56 |
+
@staticmethod
|
| 57 |
+
def _empty_grid(n: int) -> Grid:
|
| 58 |
+
"""Create an n×n grid with all walls present."""
|
| 59 |
+
return [
|
| 60 |
+
[{"N": True, "E": True, "S": True, "W": True} for _ in range(n)]
|
| 61 |
+
for _ in range(n)
|
| 62 |
+
]
|
| 63 |
+
|
| 64 |
+
@staticmethod
def _remove_wall(grid: Grid, r: int, c: int, direction: str) -> None:
    """Open the passage between (r, c) and the adjacent cell in *direction*.

    Both sides of the shared wall are cleared, mutating *grid* in place.
    No bounds check is done here; the caller must pick a neighbour that
    lies inside the grid.
    """
    row_step, col_step = NEIGHBORS[direction]
    grid[r][c][direction] = False
    neighbour = grid[r + row_step][c + col_step]
    neighbour[OPPOSITE[direction]] = False
|
| 70 |
+
|
| 71 |
+
def generate(
|
| 72 |
+
self, size: int, min_path_len: int = 1, max_attempts: int = 200
|
| 73 |
+
) -> Tuple[Grid, Tuple[int, int], Tuple[int, int], np.ndarray]:
|
| 74 |
+
"""
|
| 75 |
+
Generate a perfect maze and a start/end pair whose shortest path
|
| 76 |
+
length >= *min_path_len*.
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
(grid, start, end, path) where path is an (L, 2) int array.
|
| 80 |
+
"""
|
| 81 |
+
for _ in range(max_attempts):
|
| 82 |
+
grid = self._gen_dfs(size)
|
| 83 |
+
nodes = [(r, c) for r in range(size) for c in range(size)]
|
| 84 |
+
start, end = random.sample(nodes, 2)
|
| 85 |
+
path = self.solve_bfs(grid, start, end)
|
| 86 |
+
if path is not None and len(path) >= min_path_len:
|
| 87 |
+
return grid, tuple(start), tuple(end), path
|
| 88 |
+
raise RuntimeError(
|
| 89 |
+
f"Failed to generate maze (size={size}, min_path_len={min_path_len}) "
|
| 90 |
+
f"after {max_attempts} attempts."
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
def _gen_dfs(self, n: int) -> Grid:
    """Carve a perfect n×n maze with an iterative randomised depth-first search.

    Starting from a random cell, the walk repeatedly moves to a random
    unvisited neighbour and removes the wall in between; it backtracks
    when the current cell has no unvisited neighbour.
    """
    maze = self._empty_grid(n)
    seen = [[False] * n for _ in range(n)]

    row0, col0 = random.randrange(n), random.randrange(n)
    seen[row0][col0] = True
    trail = [(row0, col0)]

    while trail:
        r, c = trail[-1]
        candidates = [
            (d, r + dr, c + dc)
            for d, (dr, dc) in NEIGHBORS.items()
            if 0 <= r + dr < n and 0 <= c + dc < n and not seen[r + dr][c + dc]
        ]
        if not candidates:
            # Dead end: step back along the trail.
            trail.pop()
            continue
        d, nr, nc = random.choice(candidates)
        self._remove_wall(maze, r, c, d)
        seen[nr][nc] = True
        trail.append((nr, nc))
    return maze
|
| 116 |
+
|
| 117 |
+
# ==================== Solving (BFS) ====================
|
| 118 |
+
|
| 119 |
+
def solve_bfs(
    self, grid: Grid, start: Tuple[int, int], end: Tuple[int, int]
) -> Optional[np.ndarray]:
    """BFS shortest path. Returns (L,2) int ndarray or None.

    Args:
        grid: Maze grid; a move in direction d from a cell is allowed only
            when that cell's wall d is open.
        start: (row, col) of the first cell; must be hashable.
        end: (row, col) of the target cell.

    Returns:
        The cells from *start* to *end* inclusive as an (L, 2) int array,
        or None when *end* cannot be reached.
    """
    n = len(grid)
    # Predecessor map doubles as the visited set. Storing one parent per
    # cell avoids copying the whole path into every queue entry.
    parent = {start: None}
    queue: deque = deque([start])

    while queue:
        current = queue.popleft()
        r, c = current
        if (r, c) == end:
            # Walk the predecessor chain back to the start, then flip it.
            path = []
            node = current
            while node is not None:
                path.append(node)
                node = parent[node]
            path.reverse()
            return np.array(path, dtype=int)
        cell = grid[r][c]
        for d, (dr, dc) in NEIGHBORS.items():
            nr, nc = r + dr, c + dc
            if (
                0 <= nr < n
                and 0 <= nc < n
                and not cell[d]
                and (nr, nc) not in parent
            ):
                parent[(nr, nc)] = current
                queue.append((nr, nc))
    return None
|
| 143 |
+
|
| 144 |
+
# ==================== Path ↔ UDRL ====================
|
| 145 |
+
|
| 146 |
+
@staticmethod
|
| 147 |
+
def path_to_udrl(path) -> str:
|
| 148 |
+
"""Convert coordinate path to UDRL string."""
|
| 149 |
+
moves = []
|
| 150 |
+
for i in range(len(path) - 1):
|
| 151 |
+
r1, c1 = path[i]
|
| 152 |
+
r2, c2 = path[i + 1]
|
| 153 |
+
if r2 < r1:
|
| 154 |
+
moves.append("U")
|
| 155 |
+
elif r2 > r1:
|
| 156 |
+
moves.append("D")
|
| 157 |
+
elif c2 < c1:
|
| 158 |
+
moves.append("L")
|
| 159 |
+
else:
|
| 160 |
+
moves.append("R")
|
| 161 |
+
return "".join(moves)
|
| 162 |
+
|
| 163 |
+
# ==================== Verification ====================
|
| 164 |
+
|
| 165 |
+
def verify_path(self, grid: Grid, start: Tuple, end: Tuple, udrl: str) -> bool:
    """Check that *udrl* walks from *start* to *end* without crossing a wall.

    Commas and spaces are removed first, and any remaining character that
    is not a known move letter is ignored. The walk fails as soon as a move
    hits a wall of the current cell or would leave the grid.

    Returns:
        True when the walk is legal and finishes exactly on *end*.
    """
    n = len(grid)
    r, c = start
    actions = udrl.replace(",", "").replace(" ", "").strip()
    for ch in actions:
        move = MOVES.get(ch)
        if move is None:
            continue
        dr, dc, wall = move
        if grid[r][c][wall]:
            return False
        nr = r + dr
        nc = c + dc
        inside = 0 <= nr < n and 0 <= nc < n
        if not inside:
            return False
        r, c = nr, nc
    return (r, c) == end
|
| 180 |
+
|
| 181 |
+
# ==================== Text Encoding ====================
|
| 182 |
+
|
| 183 |
+
def encode_grid(self, grid: Grid) -> str:
    """Serialise *grid* as text: one bitmask integer per cell, row-major.

    Cells within a row are separated by single spaces and rows by newlines.
    A cell's value is the sum of the masks of the walls it has.
    """
    def cell_code(cell) -> str:
        return str(sum(mask for d, mask in WALL_MASKS.items() if cell[d]))

    return "\n".join(" ".join(cell_code(cell) for cell in row) for row in grid)
|
| 196 |
+
|
| 197 |
+
def decode_grid(self, text_lines: List[str]) -> Grid:
    """Rebuild the wall-dict grid from lines of whitespace-separated bitmasks.

    Each integer becomes a dict with one boolean per wall direction, True
    when the matching mask bit is set.
    """
    decoded = []
    for line in text_lines:
        codes = [int(token) for token in line.split()]
        decoded.append(
            [{d: bool(code & mask) for d, mask in WALL_MASKS.items()} for code in codes]
        )
    return decoded
|
| 207 |
+
|
| 208 |
+
def save_text(self, filepath, grid: Grid, start: Tuple, end: Tuple) -> None:
    """Write the maze to *filepath* in the compact text format.

    Layout: grid size, then the start ``row col``, then the end ``row col``,
    each on its own line, followed by the bitmask-encoded grid rows.
    """
    cell_count = len(grid)
    with open(filepath, "w") as out:
        header_lines = (
            f"{cell_count}",
            f"{start[0]} {start[1]}",
            f"{end[0]} {end[1]}",
        )
        out.write("\n".join(header_lines) + "\n")
        out.write(self.encode_grid(grid) + "\n")
|
| 214 |
+
|
| 215 |
+
def load_text(self, filepath) -> Optional[Dict]:
|
| 216 |
+
"""
|
| 217 |
+
Load maze from text file.
|
| 218 |
+
|
| 219 |
+
Returns dict with keys: size, start, end, grid (dict-based),
|
| 220 |
+
grid_raw (list[list[int]] bitmask). None on failure.
|
| 221 |
+
"""
|
| 222 |
+
try:
|
| 223 |
+
with open(filepath) as f:
|
| 224 |
+
lines = [l.strip() for l in f if l.strip()]
|
| 225 |
+
n = int(lines[0])
|
| 226 |
+
sr, sc = map(int, lines[1].split())
|
| 227 |
+
er, ec = map(int, lines[2].split())
|
| 228 |
+
grid = self.decode_grid(lines[3 : 3 + n])
|
| 229 |
+
grid_raw: List[List[int]] = []
|
| 230 |
+
for r in range(n):
|
| 231 |
+
grid_raw.append(list(map(int, lines[3 + r].split())))
|
| 232 |
+
return {
|
| 233 |
+
"size": n,
|
| 234 |
+
"start": (sr, sc),
|
| 235 |
+
"end": (er, ec),
|
| 236 |
+
"grid": grid,
|
| 237 |
+
"grid_raw": grid_raw,
|
| 238 |
+
}
|
| 239 |
+
except Exception:
|
| 240 |
+
return None
|
| 241 |
+
|
| 242 |
+
def fingerprint(self, grid: Grid, start: Tuple, end: Tuple) -> str:
    """Build a string that identifies a maze by content, for deduplication.

    The string joins, with ``|``, a header of size and start/end
    coordinates followed by every cell's wall bitmask in row-major order.
    """
    header = f"{len(grid)},{start[0]},{start[1]},{end[0]},{end[1]}"
    codes = [
        str(sum(mask for d, mask in WALL_MASKS.items() if cell[d]))
        for row in grid
        for cell in row
    ]
    return "|".join([header, *codes])
|
| 251 |
+
|
| 252 |
+
# ==================== Image Rendering ====================
|
| 253 |
+
|
| 254 |
+
def _layout(self, n: int):
|
| 255 |
+
"""Compute rendering layout parameters."""
|
| 256 |
+
cell_f = float(self.img_size) / n
|
| 257 |
+
wall_f = cell_f / 4.0
|
| 258 |
+
half_f = wall_f / 2.0
|
| 259 |
+
grid_w = max(1, int(cell_f / 16.0))
|
| 260 |
+
return cell_f, wall_f, half_f, grid_w
|
| 261 |
+
|
| 262 |
+
def render(
    self,
    grid: Grid,
    start: Tuple[int, int],
    end: Tuple[int, int],
    path: Optional[np.ndarray] = None,
    path_steps: Optional[int] = None,
) -> Image.Image:
    """
    Draw the maze, its start/end markers and optionally a solution path.

    Cells and open passages are painted over the background, so any wall
    that is present stays in the background colour.

    Args:
        grid: The maze grid.
        start, end: Coordinates of start/end cells.
        path: Full solution path (L, 2).
        path_steps: Draw only the first *path_steps* segments (for video).

    Returns:
        PIL.Image (RGB, img_size × img_size).
    """
    n = len(grid)
    cell_f, wall_f, half_f, grid_w = self._layout(n)

    canvas = Image.new("RGB", (self.img_size, self.img_size), self.bg_color)
    pen = ImageDraw.Draw(canvas)

    # --- fill cells & open passages ---
    for r in range(n):
        for c in range(n):
            left = c * cell_f + half_f
            top = r * cell_f + half_f
            right = (c + 1) * cell_f - half_f
            bottom = (r + 1) * cell_f - half_f
            pen.rectangle([(left, top), (right, bottom)], fill=self.cell_color)

            walls = grid[r][c]
            # Only south and east are bridged here; the north/west side of
            # a passage is painted by the neighbouring cell.
            if not walls["S"] and r < n - 1:
                pen.rectangle(
                    [(left, bottom), (right, bottom + wall_f)],
                    fill=self.cell_color,
                )
            if not walls["E"] and c < n - 1:
                pen.rectangle(
                    [(right, top), (right + wall_f, bottom)],
                    fill=self.cell_color,
                )

    # --- subtle grid lines on open passages ---
    for r in range(n):
        for c in range(n):
            if r < n - 1 and not grid[r][c]["S"]:
                boundary_y = (r + 1) * cell_f
                pen.line(
                    [
                        (c * cell_f + half_f, boundary_y),
                        ((c + 1) * cell_f - half_f, boundary_y),
                    ],
                    fill=self.grid_color, width=grid_w,
                )
            if c < n - 1 and not grid[r][c]["E"]:
                boundary_x = (c + 1) * cell_f
                pen.line(
                    [
                        (boundary_x, r * cell_f + half_f),
                        (boundary_x, (r + 1) * cell_f - half_f),
                    ],
                    fill=self.grid_color, width=grid_w,
                )

    # --- start / end dots ---
    for (dot_r, dot_c), colour in ((start, self.start_color), (end, self.end_color)):
        centre_x = dot_c * cell_f + cell_f / 2
        centre_y = dot_r * cell_f + cell_f / 2
        radius = max(2, int((cell_f - wall_f) * 0.25))
        pen.ellipse(
            [centre_x - radius, centre_y - radius, centre_x + radius, centre_y + radius],
            fill=colour,
        )

    # --- solution path (optionally partial) ---
    if path is not None and len(path) >= 2:
        if path_steps is None:
            stop = len(path)
        else:
            # k segments need k + 1 points, capped at the full path.
            stop = min(path_steps + 1, len(path))
        if stop >= 2:
            centres = [
                (pc * cell_f + cell_f / 2, pr * cell_f + cell_f / 2)
                for pr, pc in path[:stop]
            ]
            pen.line(
                centres, fill=self.path_color,
                width=max(1, int(wall_f)), joint="curve",
            )

    return canvas
|
| 350 |
+
|
| 351 |
+
# ==================== Video Frame Generation ====================
|
| 352 |
+
|
| 353 |
+
def generate_video_frames(
|
| 354 |
+
self,
|
| 355 |
+
grid: Grid,
|
| 356 |
+
start: Tuple[int, int],
|
| 357 |
+
end: Tuple[int, int],
|
| 358 |
+
path: np.ndarray,
|
| 359 |
+
n_start: int = 5,
|
| 360 |
+
m_end: int = 5,
|
| 361 |
+
frames: Optional[int] = None,
|
| 362 |
+
) -> List[Image.Image]:
|
| 363 |
+
"""
|
| 364 |
+
Generate progressive video frames showing the red line growing.
|
| 365 |
+
|
| 366 |
+
*frames* controls the number of **content frames** between holds:
|
| 367 |
+
- None → 1 per path step
|
| 368 |
+
- frames > steps → slow-motion
|
| 369 |
+
- frames < steps → fast-forward
|
| 370 |
+
|
| 371 |
+
Total length = n_start + content_frames + m_end.
|
| 372 |
+
"""
|
| 373 |
+
n_steps = len(path) - 1
|
| 374 |
+
if n_steps <= 0:
|
| 375 |
+
blank = self.render(grid, start, end)
|
| 376 |
+
return [blank] * (n_start + m_end + 1)
|
| 377 |
+
|
| 378 |
+
content_frames = frames if frames is not None else n_steps
|
| 379 |
+
content_frames = max(1, content_frames)
|
| 380 |
+
|
| 381 |
+
result: List[Image.Image] = []
|
| 382 |
+
|
| 383 |
+
# Opening hold
|
| 384 |
+
blank = self.render(grid, start, end)
|
| 385 |
+
result.extend([blank.copy() for _ in range(n_start)])
|
| 386 |
+
|
| 387 |
+
# Content frames
|
| 388 |
+
if content_frames == n_steps:
|
| 389 |
+
for step in range(1, n_steps + 1):
|
| 390 |
+
result.append(
|
| 391 |
+
self.render(grid, start, end, path=path, path_steps=step)
|
| 392 |
+
)
|
| 393 |
+
elif content_frames > n_steps:
|
| 394 |
+
for step in range(1, n_steps + 1):
|
| 395 |
+
f_lo = (step - 1) * content_frames // n_steps
|
| 396 |
+
f_hi = step * content_frames // n_steps
|
| 397 |
+
count = f_hi - f_lo
|
| 398 |
+
frame_img = self.render(
|
| 399 |
+
grid, start, end, path=path, path_steps=step
|
| 400 |
+
)
|
| 401 |
+
result.append(frame_img)
|
| 402 |
+
if count > 1:
|
| 403 |
+
result.extend([frame_img.copy() for _ in range(count - 1)])
|
| 404 |
+
else:
|
| 405 |
+
for f in range(content_frames):
|
| 406 |
+
step = (f + 1) * n_steps // content_frames
|
| 407 |
+
result.append(
|
| 408 |
+
self.render(grid, start, end, path=path, path_steps=step)
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
# Closing hold
|
| 412 |
+
final = self.render(grid, start, end, path=path)
|
| 413 |
+
result.extend([final.copy() for _ in range(m_end)])
|
| 414 |
+
|
| 415 |
+
return result
|
| 416 |
+
|
| 417 |
+
# ==================== Red-Path Extraction ====================
|
| 418 |
+
|
| 419 |
+
def extract_path_from_pixels(
    self,
    pixels: np.ndarray,
    grid_raw: List[List[int]],
    size: int,
    start: Tuple[int, int],
    pixel_threshold: float = 0.01,
) -> str:
    """
    Recover the drawn (red) path from an RGB pixel array as a UDRL string.

    The image is divided into ``size`` x ``size`` cells using floating-point
    cell edges that are rounded per edge, so the cell boundaries do not
    accumulate drift when the image side is not a multiple of ``size``.
    A cell counts as "on the path" when the fraction of red pixels inside
    its inner region exceeds ``pixel_threshold``. The action string is then
    built by a greedy walk from ``start`` over marked, unvisited cells that
    are not separated from the current cell by a wall.

    Args:
        pixels: Image data; channels 0, 1, 2 are read as R, G, B.
        grid_raw: Per-cell wall bitmask, indexed ``grid_raw[r][c]``.
        size: Maze dimension n.
        start: Start coordinate (r, c).
        pixel_threshold: Min red-pixel fraction to mark a cell.

    Returns:
        UDRL action string (empty when no neighbouring path cell is found).
    """
    img = Image.fromarray(pixels)
    w, h = img.size
    rgb = np.array(img, dtype=float)

    red, green, blue = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
    red_mask = (red > 100) & (red > green * 1.2) & (red > blue * 1.2)

    # Round each edge from the float cell size instead of multiplying an
    # integer cell size, so edge k is always round(k * side / size).
    cell_h_f = h / size
    cell_w_f = w / size
    y_edges = [int(round(k * cell_h_f)) for k in range(size + 1)]
    x_edges = [int(round(k * cell_w_f)) for k in range(size + 1)]

    path_grid = np.zeros((size, size), dtype=bool)
    for r in range(size):
        top, bottom = y_edges[r], y_edges[r + 1]
        # Trim 15% (at least one pixel) per side to keep walls and
        # neighbouring cells out of the sample.
        inset_y = max(1, int((bottom - top) * 0.15))
        for c in range(size):
            left, right = x_edges[c], x_edges[c + 1]
            inset_x = max(1, int((right - left) * 0.15))
            window = red_mask[top + inset_y : bottom - inset_y,
                              left + inset_x : right - inset_x]
            if window.size > 0 and np.mean(window) > pixel_threshold:
                path_grid[r, c] = True

    # Greedy walk from start; neighbours are probed in the fixed order
    # R, D, L, U and the first usable one is taken.
    probes = [(key, MOVES[key]) for key in ("R", "D", "L", "U")]
    visited = {start}
    row, col = start
    actions: List[str] = []
    for _ in range(size * size * 2):
        walls_here = grid_raw[row][col]
        for act, (d_row, d_col, wall_key) in probes:
            nxt = (row + d_row, col + d_col)
            if not (0 <= nxt[0] < size and 0 <= nxt[1] < size):
                continue
            if (walls_here & WALL_MASKS[wall_key]) != 0:
                continue
            if path_grid[nxt[0], nxt[1]] and nxt not in visited:
                visited.add(nxt)
                actions.append(act)
                row, col = nxt
                break
        else:
            # No neighbour qualified: the traced path ends here.
            break
    return "".join(actions)
|
| 500 |
+
|
| 501 |
+
def extract_path_from_image(
    self, img_path: str, grid_raw: List[List[int]], size: int, start: Tuple[int, int]
) -> str:
    """
    Extract UDRL from an image file (convenience wrapper).

    Opens ``img_path``, converts it to RGB and delegates to
    ``extract_path_from_pixels`` with the default pixel threshold.

    Args:
        img_path: Path of the image file to read.
        grid_raw: Per-cell wall bitmask, indexed ``grid_raw[r][c]``.
        size: Maze dimension n.
        start: Start coordinate (r, c).

    Returns:
        UDRL action string, or ``""`` if the image cannot be read or the
        extraction raises.
    """
    try:
        # Image.open is lazy and keeps the file handle open; the context
        # manager closes it as soon as the pixel data has been copied out.
        with Image.open(img_path) as img:
            pixels = np.array(img.convert("RGB"))
        return self.extract_path_from_pixels(pixels, grid_raw, size, start)
    except Exception:
        # Deliberate best-effort: callers treat an unreadable or malformed
        # image the same as "no path found".
        return ""
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
if __name__ == "__main__":
    # Smoke test: build one 8x8 maze, save the puzzle and solution images,
    # then check the three frame-count modes of generate_video_frames.
    proc = MazeProcessor(img_size=512)

    grid, start, end, path = proc.generate(size=8, min_path_len=10)
    n_steps = len(path) - 1
    print(f"Maze 8×8 | path length {len(path)} | steps {n_steps}")
    print(f"UDRL: {proc.path_to_udrl(path)}")
    print(f"Verify: {proc.verify_path(grid, start, end, proc.path_to_udrl(path))}")

    for out_name, extra in (
        ("test_maze.png", {}),
        ("test_maze_solution.png", {"path": path}),
    ):
        proc.render(grid, start, end, **extra).save(out_name)

    # Every clip below uses 3 opening and 3 closing hold frames.
    holds = 3 + 3

    def _clip(**timing):
        return proc.generate_video_frames(
            grid, start, end, path, n_start=3, m_end=3, **timing
        )

    # Default: one content frame per path step.
    f1 = _clip()
    assert len(f1) == holds + n_steps

    # Slow-motion: three content frames per step.
    f2 = _clip(frames=n_steps * 3)
    assert len(f2) == holds + n_steps * 3

    # Fast-forward: about half as many content frames as steps.
    half = max(1, n_steps // 2)
    f3 = _clip(frames=half)
    assert len(f3) == holds + half

    print(f"frames=None → {len(f1)} total ({n_steps} content)")
    print(f"frames={n_steps*3:<4d} → {len(f2)} total (slow-mo)")
    print(f"frames={half:<4d} → {len(f3)} total (fast-fwd)")
    print("All assertions passed ✓")
|
sudoku/generate_dataset.py
CHANGED
|
@@ -1,6 +1,11 @@
|
|
| 1 |
"""
|
| 2 |
Sudoku Video Dataset Generator - Supports flexible solution count expressions per puzzle.
|
| 3 |
With checkpoint/resume support via metadata.json.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
import json
|
| 6 |
import re
|
|
@@ -8,7 +13,7 @@ import random
|
|
| 8 |
import argparse
|
| 9 |
from dataclasses import dataclass, asdict
|
| 10 |
from pathlib import Path
|
| 11 |
-
from typing import List, Tuple, Optional,
|
| 12 |
import numpy as np
|
| 13 |
import cv2
|
| 14 |
from tqdm import tqdm
|
|
@@ -22,7 +27,7 @@ class SolRange:
|
|
| 22 |
"""Flexible solution count constraint for puzzle generation."""
|
| 23 |
min_sol: int
|
| 24 |
max_sol: Optional[int]
|
| 25 |
-
|
| 26 |
@classmethod
|
| 27 |
def parse(cls, expr: str) -> "SolRange":
|
| 28 |
expr = expr.strip()
|
|
@@ -46,7 +51,7 @@ class SolRange:
|
|
| 46 |
if n < 1: raise ValueError(f"sol_num must be >= 1, got {n}")
|
| 47 |
return cls(min_sol=n, max_sol=n)
|
| 48 |
raise ValueError(f"Invalid sol_num expression: '{expr}'")
|
| 49 |
-
|
| 50 |
@property
|
| 51 |
def is_exact(self): return self.max_sol is not None and self.min_sol == self.max_sol
|
| 52 |
@property
|
|
@@ -77,10 +82,10 @@ class GenerationState:
|
|
| 77 |
seen_grids: List[str]
|
| 78 |
all_samples: List[Dict]
|
| 79 |
completed: bool = False
|
| 80 |
-
|
| 81 |
def to_dict(self) -> Dict:
|
| 82 |
return asdict(self)
|
| 83 |
-
|
| 84 |
@classmethod
|
| 85 |
def from_dict(cls, d: Dict) -> "GenerationState":
|
| 86 |
return cls(**d)
|
|
@@ -89,9 +94,7 @@ class GenerationState:
|
|
| 89 |
def compute_params_hash(params: Dict) -> str:
|
| 90 |
"""Compute hash of generation parameters for consistency check."""
|
| 91 |
import hashlib
|
| 92 |
-
|
| 93 |
-
key_params = {k: v for k, v in params.items()
|
| 94 |
-
if k not in ['output_dir']} # output_dir can differ
|
| 95 |
return hashlib.md5(json.dumps(key_params, sort_keys=True).encode()).hexdigest()[:12]
|
| 96 |
|
| 97 |
|
|
@@ -100,21 +103,16 @@ def load_checkpoint(output_dir: Path, params: Dict) -> Optional[GenerationState]
|
|
| 100 |
meta_path = output_dir / "metadata.json"
|
| 101 |
if not meta_path.exists():
|
| 102 |
return None
|
| 103 |
-
|
| 104 |
with open(meta_path) as f:
|
| 105 |
data = json.load(f)
|
| 106 |
-
|
| 107 |
state = GenerationState.from_dict(data["state"])
|
| 108 |
expected_hash = compute_params_hash(params)
|
| 109 |
-
|
| 110 |
if state.params_hash != expected_hash:
|
| 111 |
print(f"⚠️ Parameters changed (hash {state.params_hash} → {expected_hash}), starting fresh")
|
| 112 |
return None
|
| 113 |
-
|
| 114 |
if state.completed:
|
| 115 |
print("✓ Generation already completed")
|
| 116 |
return state
|
| 117 |
-
|
| 118 |
print(f"✓ Resuming from checkpoint: {sum(state.clue_progress.values())} puzzles generated")
|
| 119 |
return state
|
| 120 |
|
|
@@ -122,65 +120,136 @@ def load_checkpoint(output_dir: Path, params: Dict) -> Optional[GenerationState]
|
|
| 122 |
def save_checkpoint(output_dir: Path, state: GenerationState, params: Dict):
|
| 123 |
"""Save current generation state to metadata.json."""
|
| 124 |
meta_path = output_dir / "metadata.json"
|
| 125 |
-
data = {
|
| 126 |
-
"params": params,
|
| 127 |
-
"state": state.to_dict()
|
| 128 |
-
}
|
| 129 |
-
# Atomic write
|
| 130 |
tmp_path = meta_path.with_suffix('.tmp')
|
| 131 |
with open(tmp_path, 'w') as f:
|
| 132 |
-
json.dump(
|
| 133 |
tmp_path.rename(meta_path)
|
| 134 |
|
| 135 |
|
| 136 |
# ==================== Core Functions ====================
|
| 137 |
|
| 138 |
def get_fill_order(puzzle, solution):
|
|
|
|
| 139 |
return [(i, j, solution[i][j]) for i in range(9) for j in range(9) if puzzle[i][j] == 0]
|
| 140 |
|
|
|
|
| 141 |
def create_processor(resolution=None):
|
| 142 |
-
|
|
|
|
|
|
|
| 143 |
target_size = min(resolution)
|
| 144 |
cell_size = target_size // 9
|
| 145 |
sf = cell_size / 60
|
| 146 |
-
return SudokuProcessor(
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
-
def generate_video_frames(proc, puzzle, solution, n_start, m_end,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
fills = get_fill_order(puzzle, solution)
|
| 150 |
n_fills = len(fills)
|
| 151 |
-
|
| 152 |
-
if
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
| 157 |
current = [row[:] for row in puzzle]
|
|
|
|
|
|
|
| 158 |
img = proc.render(current)
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
img = proc.render(solution, original=puzzle)
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
|
| 174 |
def save_video(frames, path, fps=10):
|
|
|
|
| 175 |
h, w = frames[0].shape[:2]
|
| 176 |
writer = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
| 177 |
-
for f in frames:
|
|
|
|
| 178 |
writer.release()
|
| 179 |
|
|
|
|
| 180 |
def normalize_num_per_clue(num_per_clue, clue_levels):
|
| 181 |
-
|
|
|
|
|
|
|
| 182 |
if len(num_per_clue) != len(clue_levels):
|
| 183 |
-
raise ValueError(
|
|
|
|
|
|
|
| 184 |
return num_per_clue
|
| 185 |
|
| 186 |
|
|
@@ -191,7 +260,7 @@ def generate_puzzle_with_range(proc, clue, sol_range, min_hamming):
|
|
| 191 |
if sol_range.is_unique_only:
|
| 192 |
puzzle, solution = proc.generate(clue, unique=True)
|
| 193 |
return puzzle, [solution]
|
| 194 |
-
|
| 195 |
if sol_range.requires_multi:
|
| 196 |
try:
|
| 197 |
puzzle, solutions = proc.generate_multi_solution(
|
|
@@ -204,7 +273,7 @@ def generate_puzzle_with_range(proc, clue, sol_range, min_hamming):
|
|
| 204 |
except RuntimeError:
|
| 205 |
pass
|
| 206 |
return None
|
| 207 |
-
|
| 208 |
try:
|
| 209 |
puzzle, solutions = proc.generate_multi_solution(
|
| 210 |
clue, min_solutions=max(2, sol_range.min_sol),
|
|
@@ -215,7 +284,7 @@ def generate_puzzle_with_range(proc, clue, sol_range, min_hamming):
|
|
| 215 |
return puzzle, solutions
|
| 216 |
except RuntimeError:
|
| 217 |
pass
|
| 218 |
-
|
| 219 |
if sol_range.allows_unique:
|
| 220 |
puzzle, solution = proc.generate(clue, unique=True)
|
| 221 |
return puzzle, [solution]
|
|
@@ -225,45 +294,48 @@ def generate_puzzle_with_range(proc, clue, sol_range, min_hamming):
|
|
| 225 |
# ==================== Dataset Generation ====================
|
| 226 |
|
| 227 |
def generate_dataset(
|
| 228 |
-
output_dir="
|
| 229 |
-
|
|
|
|
| 230 |
prompt="Solve this Sudoku puzzle using red font.",
|
| 231 |
-
n_start=
|
| 232 |
resolution=None, seed=42, checkpoint_interval=50
|
| 233 |
):
|
| 234 |
"""
|
| 235 |
Generate Sudoku video dataset with checkpoint/resume support.
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
Args:
|
| 238 |
checkpoint_interval: Save checkpoint every N puzzles (default: 50)
|
| 239 |
"""
|
| 240 |
-
# Prepare params dict for hashing
|
| 241 |
params = {
|
| 242 |
"clue_levels": clue_levels, "num_per_clue": num_per_clue,
|
| 243 |
"sol_num": sol_num, "min_hamming": min_hamming, "train_ratio": train_ratio,
|
| 244 |
-
"prompt": prompt, "n_start": n_start, "m_end": m_end, "
|
| 245 |
-
"
|
| 246 |
}
|
| 247 |
-
|
| 248 |
output_dir = Path(output_dir)
|
| 249 |
video_dir = output_dir / "videos"
|
| 250 |
image_dir = output_dir / "images"
|
| 251 |
video_dir.mkdir(parents=True, exist_ok=True)
|
| 252 |
image_dir.mkdir(parents=True, exist_ok=True)
|
| 253 |
-
|
| 254 |
# Try to resume from checkpoint
|
| 255 |
state = load_checkpoint(output_dir, params)
|
| 256 |
-
|
| 257 |
if state and state.completed:
|
| 258 |
-
return
|
| 259 |
-
|
| 260 |
sol_range = SolRange.parse(str(sol_num))
|
| 261 |
proc = create_processor(resolution)
|
| 262 |
actual_size = proc.img_size
|
| 263 |
num_per_clue_list = normalize_num_per_clue(num_per_clue, clue_levels)
|
| 264 |
max_puzzles = max(num_per_clue_list)
|
| 265 |
num_width = len(str(max_puzzles))
|
| 266 |
-
|
| 267 |
# Initialize or restore state
|
| 268 |
if state is None:
|
| 269 |
random.seed(seed)
|
|
@@ -274,138 +346,162 @@ def generate_dataset(
|
|
| 274 |
all_samples=[]
|
| 275 |
)
|
| 276 |
print(f"Starting fresh generation with solution range: {sol_range}")
|
|
|
|
|
|
|
| 277 |
else:
|
| 278 |
-
# Restore RNG state approximately by fast-forwarding
|
| 279 |
random.seed(seed)
|
| 280 |
for _ in range(sum(state.clue_progress.values()) * 10):
|
| 281 |
random.random()
|
| 282 |
-
|
| 283 |
seen_grids = set(state.seen_grids)
|
| 284 |
all_samples = state.all_samples.copy()
|
| 285 |
clue_progress = {int(k): v for k, v in state.clue_progress.items()}
|
| 286 |
-
|
| 287 |
total_target = sum(num_per_clue_list)
|
| 288 |
total_done = sum(clue_progress.values())
|
| 289 |
stats_unique = sum(1 for s in all_samples if s["total_solutions"] == 1 and s["sol_idx"] == 0)
|
| 290 |
stats_multi = sum(1 for s in all_samples if s["total_solutions"] > 1 and s["sol_idx"] == 0)
|
| 291 |
puzzles_since_checkpoint = 0
|
| 292 |
-
|
| 293 |
with tqdm(total=total_target, initial=total_done, desc="Total", unit="puzzle") as pbar_total:
|
| 294 |
for clue, target_count in zip(clue_levels, num_per_clue_list):
|
| 295 |
generated = clue_progress.get(clue, 0)
|
| 296 |
if generated >= target_count:
|
| 297 |
-
continue
|
| 298 |
-
|
| 299 |
max_attempts = (target_count - generated) * 20
|
| 300 |
-
|
| 301 |
-
with tqdm(total=target_count, initial=generated, desc=f"Clue {clue:2d}",
|
| 302 |
unit="puzzle", leave=False) as pbar_clue:
|
| 303 |
for _ in range(max_attempts):
|
| 304 |
if generated >= target_count:
|
| 305 |
break
|
| 306 |
-
|
| 307 |
result = generate_puzzle_with_range(proc, clue, sol_range, min_hamming)
|
| 308 |
if result is None:
|
| 309 |
continue
|
| 310 |
puzzle, solutions = result
|
| 311 |
-
|
| 312 |
fp = proc.encode(puzzle)
|
| 313 |
if fp in seen_grids:
|
| 314 |
continue
|
| 315 |
seen_grids.add(fp)
|
| 316 |
-
|
| 317 |
n_sols = len(solutions)
|
| 318 |
if n_sols == 1:
|
| 319 |
stats_unique += 1
|
| 320 |
else:
|
| 321 |
stats_multi += 1
|
| 322 |
-
|
| 323 |
img_name = f"clue{clue}_{generated:0{num_width}d}.png"
|
| 324 |
puzzle_img = proc.render(puzzle)
|
| 325 |
-
cv2.imwrite(
|
| 326 |
-
|
|
|
|
|
|
|
|
|
|
| 327 |
for si, sol in enumerate(solutions):
|
| 328 |
vid_name = f"clue{clue}_{generated:0{num_width}d}_sol{si}.mp4"
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
all_samples.append({
|
| 334 |
"prompt": prompt, "video": vid_name, "image": img_name,
|
| 335 |
"clue": clue, "puzzle": fp, "solution": proc.encode(sol),
|
| 336 |
"sol_idx": si, "total_solutions": n_sols,
|
| 337 |
-
"frame_count": len(
|
| 338 |
-
"min_hamming_to_others": min(hdists) if hdists else 0
|
| 339 |
})
|
| 340 |
-
|
| 341 |
generated += 1
|
| 342 |
clue_progress[clue] = generated
|
| 343 |
puzzles_since_checkpoint += 1
|
| 344 |
pbar_clue.update(1)
|
| 345 |
pbar_total.update(1)
|
| 346 |
-
|
| 347 |
-
# Periodic checkpoint
|
| 348 |
if puzzles_since_checkpoint >= checkpoint_interval:
|
| 349 |
state.clue_progress = clue_progress
|
| 350 |
state.seen_grids = list(seen_grids)
|
| 351 |
state.all_samples = all_samples
|
| 352 |
save_checkpoint(output_dir, state, params)
|
| 353 |
puzzles_since_checkpoint = 0
|
| 354 |
-
|
| 355 |
-
tqdm.write(
|
| 356 |
-
|
| 357 |
-
|
|
|
|
|
|
|
| 358 |
# Final output
|
| 359 |
-
random.seed(seed + 1)
|
| 360 |
random.shuffle(all_samples)
|
| 361 |
split_idx = int(len(all_samples) * train_ratio)
|
| 362 |
-
|
| 363 |
def write_jsonl(samples, path):
|
| 364 |
with open(path, 'w') as f:
|
| 365 |
for s in samples:
|
| 366 |
json.dump(s, f)
|
| 367 |
f.write('\n')
|
| 368 |
-
|
| 369 |
write_jsonl(all_samples[:split_idx], output_dir / "train.jsonl")
|
| 370 |
write_jsonl(all_samples[split_idx:], output_dir / "test.jsonl")
|
| 371 |
-
|
| 372 |
# Mark as completed
|
| 373 |
state.clue_progress = clue_progress
|
| 374 |
state.seen_grids = list(seen_grids)
|
| 375 |
state.all_samples = all_samples
|
| 376 |
state.completed = True
|
| 377 |
save_checkpoint(output_dir, state, params)
|
| 378 |
-
|
| 379 |
print(f"\n✓ Dataset complete: {output_dir}/")
|
| 380 |
print(f" Resolution: {actual_size}x{actual_size}")
|
| 381 |
print(f" Solution range: {sol_range}")
|
| 382 |
print(f" Puzzles: {len(seen_grids)} ({stats_unique} unique, {stats_multi} multi-sol)")
|
| 383 |
print(f" Videos: {len(all_samples)}")
|
| 384 |
print(f" Train: {split_idx}, Test: {len(all_samples) - split_idx}")
|
| 385 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
hammings = [s["min_hamming_to_others"] for s in all_samples if s["min_hamming_to_others"] > 0]
|
| 387 |
if hammings:
|
| 388 |
-
print(f" Solution diversity: avg={np.mean(hammings):.1f},
|
|
|
|
| 389 |
|
| 390 |
|
| 391 |
def parse_resolution(s):
|
| 392 |
w, h = map(int, s.lower().split('x'))
|
| 393 |
return (w, h)
|
| 394 |
|
|
|
|
| 395 |
def parse_args():
|
| 396 |
-
p = argparse.ArgumentParser(
|
|
|
|
|
|
|
| 397 |
p.add_argument("--output-dir", type=str, default="sudoku")
|
| 398 |
-
p.add_argument("--clue-levels", type=int, nargs="+",
|
| 399 |
-
|
|
|
|
|
|
|
| 400 |
p.add_argument("--sol-num", type=str, default="<=3",
|
| 401 |
help="'1', '3', '>=1', '>1', '<=3', '<3', '2-5'")
|
| 402 |
p.add_argument("--min-hamming", type=int, default=10)
|
| 403 |
p.add_argument("--train-ratio", type=float, default=0.9)
|
| 404 |
-
p.add_argument("--prompt", type=str,
|
| 405 |
-
|
| 406 |
-
p.add_argument("--
|
| 407 |
-
|
| 408 |
-
p.add_argument("--
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
p.add_argument("--fps", type=int, default=10)
|
| 410 |
p.add_argument("--resolution", type=str, default="1024x1024")
|
| 411 |
p.add_argument("--seed", type=int, default=42)
|
|
|
|
| 1 |
"""
|
| 2 |
Sudoku Video Dataset Generator - Supports flexible solution count expressions per puzzle.
|
| 3 |
With checkpoint/resume support via metadata.json.
|
| 4 |
+
|
| 5 |
+
The *frames* parameter replaces the old max_frames + k pair:
|
| 6 |
+
- frames=None → 1 content frame per fill step (variable length)
|
| 7 |
+
- frames=N → exactly N content frames; fills distributed evenly
|
| 8 |
+
(slow-motion if N > fills, fast-forward if N < fills)
|
| 9 |
"""
|
| 10 |
import json
|
| 11 |
import re
|
|
|
|
| 13 |
import argparse
|
| 14 |
from dataclasses import dataclass, asdict
|
| 15 |
from pathlib import Path
|
| 16 |
+
from typing import List, Tuple, Optional, Dict
|
| 17 |
import numpy as np
|
| 18 |
import cv2
|
| 19 |
from tqdm import tqdm
|
|
|
|
| 27 |
"""Flexible solution count constraint for puzzle generation."""
|
| 28 |
min_sol: int
|
| 29 |
max_sol: Optional[int]
|
| 30 |
+
|
| 31 |
@classmethod
|
| 32 |
def parse(cls, expr: str) -> "SolRange":
|
| 33 |
expr = expr.strip()
|
|
|
|
| 51 |
if n < 1: raise ValueError(f"sol_num must be >= 1, got {n}")
|
| 52 |
return cls(min_sol=n, max_sol=n)
|
| 53 |
raise ValueError(f"Invalid sol_num expression: '{expr}'")
|
| 54 |
+
|
| 55 |
@property
|
| 56 |
def is_exact(self): return self.max_sol is not None and self.min_sol == self.max_sol
|
| 57 |
@property
|
|
|
|
| 82 |
seen_grids: List[str]
|
| 83 |
all_samples: List[Dict]
|
| 84 |
completed: bool = False
|
| 85 |
+
|
| 86 |
def to_dict(self) -> Dict:
|
| 87 |
return asdict(self)
|
| 88 |
+
|
| 89 |
@classmethod
|
| 90 |
def from_dict(cls, d: Dict) -> "GenerationState":
|
| 91 |
return cls(**d)
|
|
|
|
| 94 |
def compute_params_hash(params: Dict) -> str:
    """Return a 12-hex-digit fingerprint of the generation parameters.

    ``output_dir`` is left out of the fingerprint, so the same settings hash
    identically regardless of where the dataset is written. The remaining
    entries are serialized as JSON with sorted keys and hashed with MD5.
    """
    import hashlib

    hashed = dict(params.items())
    hashed.pop('output_dir', None)
    canonical = json.dumps(hashed, sort_keys=True)
    digest = hashlib.md5(canonical.encode()).hexdigest()
    return digest[:12]
|
| 99 |
|
| 100 |
|
|
|
|
| 103 |
meta_path = output_dir / "metadata.json"
|
| 104 |
if not meta_path.exists():
|
| 105 |
return None
|
|
|
|
| 106 |
with open(meta_path) as f:
|
| 107 |
data = json.load(f)
|
|
|
|
| 108 |
state = GenerationState.from_dict(data["state"])
|
| 109 |
expected_hash = compute_params_hash(params)
|
|
|
|
| 110 |
if state.params_hash != expected_hash:
|
| 111 |
print(f"⚠️ Parameters changed (hash {state.params_hash} → {expected_hash}), starting fresh")
|
| 112 |
return None
|
|
|
|
| 113 |
if state.completed:
|
| 114 |
print("✓ Generation already completed")
|
| 115 |
return state
|
|
|
|
| 116 |
print(f"✓ Resuming from checkpoint: {sum(state.clue_progress.values())} puzzles generated")
|
| 117 |
return state
|
| 118 |
|
|
|
|
| 120 |
def save_checkpoint(output_dir: Path, state: GenerationState, params: Dict):
    """Save current generation state to metadata.json.

    The payload ``{"params": ..., "state": ...}`` is first written to a
    sibling ``metadata.tmp`` file and then moved over ``metadata.json``, so
    an interrupted write cannot leave a truncated checkpoint behind.

    Args:
        output_dir: Directory that holds ``metadata.json``.
        state: Object whose ``to_dict()`` result is stored under "state".
        params: JSON-serializable generation parameters, stored under "params".
    """
    meta_path = output_dir / "metadata.json"
    tmp_path = meta_path.with_suffix('.tmp')
    with open(tmp_path, 'w') as f:
        json.dump({"params": params, "state": state.to_dict()}, f, indent=2)
    # Path.replace overwrites an existing target on every platform, whereas
    # Path.rename raises FileExistsError on Windows once metadata.json exists
    # (i.e. on every checkpoint after the first).
    tmp_path.replace(meta_path)
|
| 127 |
|
| 128 |
|
| 129 |
# ==================== Core Functions ====================
|
| 130 |
|
| 131 |
def get_fill_order(puzzle, solution):
    """List the cells to fill as (row, col, value), scanning row by row.

    A cell is included when its entry in ``puzzle`` equals 0; the value is
    taken from the same position in ``solution``.
    """
    order = []
    for row in range(9):
        for col in range(9):
            if puzzle[row][col] == 0:
                order.append((row, col, solution[row][col]))
    return order
|
| 134 |
|
| 135 |
+
|
| 136 |
def create_processor(resolution=None):
    """Build a SudokuProcessor, optionally sized for a target resolution.

    With ``resolution=None`` the processor is created with its own defaults.
    Otherwise the smaller entry of ``resolution`` is split into 9 cells
    (integer division), and font scale and line thickness are scaled
    relative to a 60-pixel reference cell.
    """
    if resolution is not None:
        side = min(resolution)
        cell = side // 9
        scale = cell / 60
        stroke = max(1, int(2 * scale))
        return SudokuProcessor(
            cell_size=cell, font_scale=1.2 * scale, thickness=stroke
        )
    return SudokuProcessor()
|
| 146 |
+
|
| 147 |
|
| 148 |
+
def generate_video_frames(proc, puzzle, solution, n_start, m_end, frames=None):
    """
    Build the frame sequence for a progressive Sudoku solve.

    The clip is: ``n_start`` copies of the unsolved puzzle, then the content
    frames, then ``m_end`` copies of the finished solution. *frames* sets the
    number of content frames (values below 1 are raised to 1):

    - frames=None or frames == fills → one frame per filled cell
    - frames > fills → slow-motion: each fill gets one highlighted frame
      followed by un-highlighted hold frames
    - frames < fills → fast-forward: each frame applies several fills and
      highlights the last one applied

    Total output length = n_start + content_frames + m_end. When the puzzle
    has no empty cells, the result is n_start + m_end + 1 copies of the
    rendered solution.

    Args:
        proc: SudokuProcessor instance.
        puzzle: 9×9 puzzle grid (0 = empty).
        solution: 9×9 solved grid.
        n_start: Hold frames for puzzle at the beginning.
        m_end: Hold frames for completed solution at the end.
        frames: Desired number of content frames (None = one per fill).

    Returns:
        List of numpy arrays (RGB images).
    """
    pending = get_fill_order(puzzle, solution)
    total = len(pending)

    if total == 0:
        solved = proc.render(solution, original=puzzle)
        return [solved.copy() for _ in range(n_start + m_end + 1)]

    n_content = max(1, total if frames is None else frames)

    # Working copy of the grid that is filled in as the clip progresses.
    board = [row[:] for row in puzzle]

    # Opening hold on the untouched puzzle.
    opening = proc.render(board)
    clip = [opening.copy() for _ in range(n_start)]

    if n_content == total:
        # One frame per fill.
        for row, col, digit in pending:
            board[row][col] = digit
            clip.append(proc.render(board, highlight_new=(row, col), original=puzzle))

    elif n_content > total:
        # Slow-motion: fill k owns frame indices [k*n/total, (k+1)*n/total),
        # which is at least one frame because n_content > total.
        for k, (row, col, digit) in enumerate(pending):
            board[row][col] = digit
            span = (k + 1) * n_content // total - k * n_content // total

            clip.append(proc.render(board, highlight_new=(row, col), original=puzzle))
            if span > 1:
                held = proc.render(board, original=puzzle)
                clip.extend(held.copy() for _ in range(span - 1))

    else:
        # Fast-forward: frame k applies fills [k*total/n, (k+1)*total/n).
        for k in range(n_content):
            batch = pending[k * total // n_content:(k + 1) * total // n_content]
            for row, col, digit in batch:
                board[row][col] = digit
            if batch:
                last_row, last_col, _ = batch[-1]
                clip.append(
                    proc.render(board, highlight_new=(last_row, last_col), original=puzzle)
                )
            else:
                clip.append(proc.render(board, original=puzzle))

    # Closing hold on the finished solution.
    closing = proc.render(solution, original=puzzle)
    clip.extend(closing.copy() for _ in range(m_end))

    return clip
|
| 234 |
+
|
| 235 |
|
| 236 |
def save_video(frames, path, fps=10):
    """Save list of numpy RGB frames as mp4.

    The frame size is taken from the first frame, so ``frames`` must be
    non-empty (an empty list raises IndexError). Each frame is converted
    from RGB to BGR before being handed to the OpenCV writer.

    Args:
        frames: Sequence of image arrays; the first one defines (height, width).
        path: Output file path (converted with ``str``).
        fps: Frames per second passed to the writer.
    """
    h, w = frames[0].shape[:2]
    writer = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
    try:
        for f in frames:
            writer.write(cv2.cvtColor(f, cv2.COLOR_RGB2BGR))
    finally:
        # Release even when a frame fails to convert or write, so the
        # output file handle is not left open.
        writer.release()
|
| 243 |
|
| 244 |
+
|
| 245 |
def normalize_num_per_clue(num_per_clue, clue_levels):
    """Return one puzzle count per clue level.

    An ``int`` is repeated for every entry of ``clue_levels``. Any other
    value is returned unchanged after checking that its length matches.

    Raises:
        ValueError: If a non-int ``num_per_clue`` has a different length
            than ``clue_levels``.
    """
    n_levels = len(clue_levels)
    if isinstance(num_per_clue, int):
        return [num_per_clue] * n_levels
    if len(num_per_clue) == n_levels:
        return num_per_clue
    raise ValueError(
        f"num_per_clue length ({len(num_per_clue)}) != clue_levels ({len(clue_levels)})"
    )
|
| 254 |
|
| 255 |
|
|
|
|
| 260 |
if sol_range.is_unique_only:
|
| 261 |
puzzle, solution = proc.generate(clue, unique=True)
|
| 262 |
return puzzle, [solution]
|
| 263 |
+
|
| 264 |
if sol_range.requires_multi:
|
| 265 |
try:
|
| 266 |
puzzle, solutions = proc.generate_multi_solution(
|
|
|
|
| 273 |
except RuntimeError:
|
| 274 |
pass
|
| 275 |
return None
|
| 276 |
+
|
| 277 |
try:
|
| 278 |
puzzle, solutions = proc.generate_multi_solution(
|
| 279 |
clue, min_solutions=max(2, sol_range.min_sol),
|
|
|
|
| 284 |
return puzzle, solutions
|
| 285 |
except RuntimeError:
|
| 286 |
pass
|
| 287 |
+
|
| 288 |
if sol_range.allows_unique:
|
| 289 |
puzzle, solution = proc.generate(clue, unique=True)
|
| 290 |
return puzzle, [solution]
|
|
|
|
| 294 |
# ==================== Dataset Generation ====================
|
| 295 |
|
| 296 |
def generate_dataset(
|
| 297 |
+
output_dir="sudoku", clue_levels=[20, 30, 40, 50, 60, 70],
|
| 298 |
+
num_per_clue=[15000, 10000, 10000, 5000, 2000, 1000],
|
| 299 |
+
sol_num="<=3", min_hamming=10, train_ratio=0.9,
|
| 300 |
prompt="Solve this Sudoku puzzle using red font.",
|
| 301 |
+
n_start=2, m_end=3, frames=None, fps=10,
|
| 302 |
resolution=None, seed=42, checkpoint_interval=50
|
| 303 |
):
|
| 304 |
"""
|
| 305 |
Generate Sudoku video dataset with checkpoint/resume support.
|
| 306 |
+
|
| 307 |
+
The *frames* parameter controls the number of **content frames** per video:
|
| 308 |
+
- None → one content frame per fill step (variable length per puzzle)
|
| 309 |
+
- N > 0 → exactly N content frames; fills distributed evenly
|
| 310 |
+
|
| 311 |
Args:
|
| 312 |
checkpoint_interval: Save checkpoint every N puzzles (default: 50)
|
| 313 |
"""
|
|
|
|
| 314 |
params = {
|
| 315 |
"clue_levels": clue_levels, "num_per_clue": num_per_clue,
|
| 316 |
"sol_num": sol_num, "min_hamming": min_hamming, "train_ratio": train_ratio,
|
| 317 |
+
"prompt": prompt, "n_start": n_start, "m_end": m_end, "frames": frames,
|
| 318 |
+
"fps": fps, "resolution": resolution, "seed": seed
|
| 319 |
}
|
| 320 |
+
|
| 321 |
output_dir = Path(output_dir)
|
| 322 |
video_dir = output_dir / "videos"
|
| 323 |
image_dir = output_dir / "images"
|
| 324 |
video_dir.mkdir(parents=True, exist_ok=True)
|
| 325 |
image_dir.mkdir(parents=True, exist_ok=True)
|
| 326 |
+
|
| 327 |
# Try to resume from checkpoint
|
| 328 |
state = load_checkpoint(output_dir, params)
|
|
|
|
| 329 |
if state and state.completed:
|
| 330 |
+
return
|
| 331 |
+
|
| 332 |
sol_range = SolRange.parse(str(sol_num))
|
| 333 |
proc = create_processor(resolution)
|
| 334 |
actual_size = proc.img_size
|
| 335 |
num_per_clue_list = normalize_num_per_clue(num_per_clue, clue_levels)
|
| 336 |
max_puzzles = max(num_per_clue_list)
|
| 337 |
num_width = len(str(max_puzzles))
|
| 338 |
+
|
| 339 |
# Initialize or restore state
|
| 340 |
if state is None:
|
| 341 |
random.seed(seed)
|
|
|
|
| 346 |
all_samples=[]
|
| 347 |
)
|
| 348 |
print(f"Starting fresh generation with solution range: {sol_range}")
|
| 349 |
+
print(f" frames={'auto (1 per fill)' if frames is None else frames}, "
|
| 350 |
+
f"n_start={n_start}, m_end={m_end}, fps={fps}")
|
| 351 |
else:
|
|
|
|
| 352 |
random.seed(seed)
|
| 353 |
for _ in range(sum(state.clue_progress.values()) * 10):
|
| 354 |
random.random()
|
| 355 |
+
|
| 356 |
seen_grids = set(state.seen_grids)
|
| 357 |
all_samples = state.all_samples.copy()
|
| 358 |
clue_progress = {int(k): v for k, v in state.clue_progress.items()}
|
| 359 |
+
|
| 360 |
total_target = sum(num_per_clue_list)
|
| 361 |
total_done = sum(clue_progress.values())
|
| 362 |
stats_unique = sum(1 for s in all_samples if s["total_solutions"] == 1 and s["sol_idx"] == 0)
|
| 363 |
stats_multi = sum(1 for s in all_samples if s["total_solutions"] > 1 and s["sol_idx"] == 0)
|
| 364 |
puzzles_since_checkpoint = 0
|
| 365 |
+
|
| 366 |
with tqdm(total=total_target, initial=total_done, desc="Total", unit="puzzle") as pbar_total:
|
| 367 |
for clue, target_count in zip(clue_levels, num_per_clue_list):
|
| 368 |
generated = clue_progress.get(clue, 0)
|
| 369 |
if generated >= target_count:
|
| 370 |
+
continue
|
| 371 |
+
|
| 372 |
max_attempts = (target_count - generated) * 20
|
| 373 |
+
|
| 374 |
+
with tqdm(total=target_count, initial=generated, desc=f"Clue {clue:2d}",
|
| 375 |
unit="puzzle", leave=False) as pbar_clue:
|
| 376 |
for _ in range(max_attempts):
|
| 377 |
if generated >= target_count:
|
| 378 |
break
|
| 379 |
+
|
| 380 |
result = generate_puzzle_with_range(proc, clue, sol_range, min_hamming)
|
| 381 |
if result is None:
|
| 382 |
continue
|
| 383 |
puzzle, solutions = result
|
| 384 |
+
|
| 385 |
fp = proc.encode(puzzle)
|
| 386 |
if fp in seen_grids:
|
| 387 |
continue
|
| 388 |
seen_grids.add(fp)
|
| 389 |
+
|
| 390 |
n_sols = len(solutions)
|
| 391 |
if n_sols == 1:
|
| 392 |
stats_unique += 1
|
| 393 |
else:
|
| 394 |
stats_multi += 1
|
| 395 |
+
|
| 396 |
img_name = f"clue{clue}_{generated:0{num_width}d}.png"
|
| 397 |
puzzle_img = proc.render(puzzle)
|
| 398 |
+
cv2.imwrite(
|
| 399 |
+
str(image_dir / img_name),
|
| 400 |
+
cv2.cvtColor(puzzle_img, cv2.COLOR_RGB2BGR),
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
for si, sol in enumerate(solutions):
|
| 404 |
vid_name = f"clue{clue}_{generated:0{num_width}d}_sol{si}.mp4"
|
| 405 |
+
vid_frames = generate_video_frames(
|
| 406 |
+
proc, puzzle, sol, n_start, m_end, frames
|
| 407 |
+
)
|
| 408 |
+
save_video(vid_frames, video_dir / vid_name, fps)
|
| 409 |
+
|
| 410 |
+
hdists = [
|
| 411 |
+
proc._hamming(sol, solutions[j])
|
| 412 |
+
for j in range(n_sols) if j != si
|
| 413 |
+
]
|
| 414 |
all_samples.append({
|
| 415 |
"prompt": prompt, "video": vid_name, "image": img_name,
|
| 416 |
"clue": clue, "puzzle": fp, "solution": proc.encode(sol),
|
| 417 |
"sol_idx": si, "total_solutions": n_sols,
|
| 418 |
+
"frame_count": len(vid_frames),
|
| 419 |
+
"min_hamming_to_others": min(hdists) if hdists else 0,
|
| 420 |
})
|
| 421 |
+
|
| 422 |
generated += 1
|
| 423 |
clue_progress[clue] = generated
|
| 424 |
puzzles_since_checkpoint += 1
|
| 425 |
pbar_clue.update(1)
|
| 426 |
pbar_total.update(1)
|
| 427 |
+
|
|
|
|
| 428 |
if puzzles_since_checkpoint >= checkpoint_interval:
|
| 429 |
state.clue_progress = clue_progress
|
| 430 |
state.seen_grids = list(seen_grids)
|
| 431 |
state.all_samples = all_samples
|
| 432 |
save_checkpoint(output_dir, state, params)
|
| 433 |
puzzles_since_checkpoint = 0
|
| 434 |
+
|
| 435 |
+
tqdm.write(
|
| 436 |
+
f"Clue {clue}: {generated} puzzles, "
|
| 437 |
+
f"{sum(1 for s in all_samples if s['clue'] == clue)} videos"
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
# Final output
|
| 441 |
+
random.seed(seed + 1)
|
| 442 |
random.shuffle(all_samples)
|
| 443 |
split_idx = int(len(all_samples) * train_ratio)
|
| 444 |
+
|
| 445 |
def write_jsonl(samples, path):
|
| 446 |
with open(path, 'w') as f:
|
| 447 |
for s in samples:
|
| 448 |
json.dump(s, f)
|
| 449 |
f.write('\n')
|
| 450 |
+
|
| 451 |
write_jsonl(all_samples[:split_idx], output_dir / "train.jsonl")
|
| 452 |
write_jsonl(all_samples[split_idx:], output_dir / "test.jsonl")
|
| 453 |
+
|
| 454 |
# Mark as completed
|
| 455 |
state.clue_progress = clue_progress
|
| 456 |
state.seen_grids = list(seen_grids)
|
| 457 |
state.all_samples = all_samples
|
| 458 |
state.completed = True
|
| 459 |
save_checkpoint(output_dir, state, params)
|
| 460 |
+
|
| 461 |
print(f"\n✓ Dataset complete: {output_dir}/")
|
| 462 |
print(f" Resolution: {actual_size}x{actual_size}")
|
| 463 |
print(f" Solution range: {sol_range}")
|
| 464 |
print(f" Puzzles: {len(seen_grids)} ({stats_unique} unique, {stats_multi} multi-sol)")
|
| 465 |
print(f" Videos: {len(all_samples)}")
|
| 466 |
print(f" Train: {split_idx}, Test: {len(all_samples) - split_idx}")
|
| 467 |
+
|
| 468 |
+
fcounts = [s["frame_count"] for s in all_samples]
|
| 469 |
+
print(f" Frame counts: avg={np.mean(fcounts):.1f}, "
|
| 470 |
+
f"min={min(fcounts)}, max={max(fcounts)}")
|
| 471 |
+
|
| 472 |
hammings = [s["min_hamming_to_others"] for s in all_samples if s["min_hamming_to_others"] > 0]
|
| 473 |
if hammings:
|
| 474 |
+
print(f" Solution diversity: avg={np.mean(hammings):.1f}, "
|
| 475 |
+
f"min={min(hammings)}, max={max(hammings)}")
|
| 476 |
|
| 477 |
|
| 478 |
def parse_resolution(s):
    """Parse a ``"WIDTHxHEIGHT"`` string into a ``(width, height)`` tuple.

    The separator is matched case-insensitively, so ``"1024x768"`` and
    ``"1024X768"`` give the same result.

    Args:
        s: Resolution string such as ``"1024x1024"``.

    Returns:
        A ``(width, height)`` tuple of ints.

    Raises:
        ValueError: If ``s`` is not exactly two integers separated by ``x``.
    """
    parts = s.lower().split('x')
    try:
        # Unpacking inside the ``try`` also catches a wrong number of parts,
        # which would otherwise surface as an opaque "values to unpack" error.
        w, h = map(int, parts)
    except ValueError:
        raise ValueError(
            f"Invalid resolution {s!r}: expected 'WIDTHxHEIGHT', e.g. '1024x1024'"
        ) from None
    return (w, h)
|
| 481 |
|
| 482 |
+
|
| 483 |
def parse_args():
|
| 484 |
+
p = argparse.ArgumentParser(
|
| 485 |
+
description="Generate Sudoku video dataset with resume support"
|
| 486 |
+
)
|
| 487 |
p.add_argument("--output-dir", type=str, default="sudoku")
|
| 488 |
+
p.add_argument("--clue-levels", type=int, nargs="+",
|
| 489 |
+
default=[20, 30, 40, 50, 60, 70])
|
| 490 |
+
p.add_argument("--num-per-clue", type=int, nargs="+",
|
| 491 |
+
default=[15000, 10000, 10000, 5000, 2000, 1000])
|
| 492 |
p.add_argument("--sol-num", type=str, default="<=3",
|
| 493 |
help="'1', '3', '>=1', '>1', '<=3', '<3', '2-5'")
|
| 494 |
p.add_argument("--min-hamming", type=int, default=10)
|
| 495 |
p.add_argument("--train-ratio", type=float, default=0.9)
|
| 496 |
+
p.add_argument("--prompt", type=str,
|
| 497 |
+
default="Solve this Sudoku puzzle using red font.")
|
| 498 |
+
p.add_argument("--n-start", type=int, default=2,
|
| 499 |
+
help="Hold frames for puzzle at video start")
|
| 500 |
+
p.add_argument("--m-end", type=int, default=3,
|
| 501 |
+
help="Hold frames for completed solution at video end")
|
| 502 |
+
p.add_argument("--frames", type=int, default=None,
|
| 503 |
+
help="Content frames per video. None=1 per fill (auto). "
|
| 504 |
+
"If > fills: slow-motion. If < fills: fast-forward.")
|
| 505 |
p.add_argument("--fps", type=int, default=10)
|
| 506 |
p.add_argument("--resolution", type=str, default="1024x1024")
|
| 507 |
p.add_argument("--seed", type=int, default=42)
|
sudoku/jsonl_to_csv.py
CHANGED
|
@@ -2,11 +2,11 @@ import json
|
|
| 2 |
import csv
|
| 3 |
from pathlib import Path
|
| 4 |
|
| 5 |
-
dataset='
|
| 6 |
-
split='
|
| 7 |
|
| 8 |
# Load test data
|
| 9 |
-
with open(f'{dataset}/{split}
|
| 10 |
data = [json.loads(line) for line in f]
|
| 11 |
|
| 12 |
# Write to CSV
|
|
@@ -19,4 +19,7 @@ with open(f'{dataset}/{split}.csv', 'w', newline='', encoding='utf-8') as f:
|
|
| 19 |
'images/' + item['image'],
|
| 20 |
'videos/' + item['video'],
|
| 21 |
item['prompt'],
|
| 22 |
-
])
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import csv
|
| 3 |
from pathlib import Path
|
| 4 |
|
| 5 |
+
dataset='sudoku_large'
|
| 6 |
+
split='test'
|
| 7 |
|
| 8 |
# Load test data
|
| 9 |
+
with open(f'{dataset}/{split}.jsonl', 'r') as f:
|
| 10 |
data = [json.loads(line) for line in f]
|
| 11 |
|
| 12 |
# Write to CSV
|
|
|
|
| 19 |
'images/' + item['image'],
|
| 20 |
'videos/' + item['video'],
|
| 21 |
item['prompt'],
|
| 22 |
+
])
|
| 23 |
+
|
| 24 |
+
# Rename `{split}.jsonl` to `{split}_info.jsonl`
|
| 25 |
+
Path(f'{dataset}/{split}.jsonl').rename(Path(f'{dataset}/{split}_info.jsonl'))
|