"""
Sudoku Video Dataset Generator - Supports flexible solution count expressions per puzzle.
With checkpoint/resume support via metadata.json.

The *frames* parameter replaces the old max_frames + k pair:
- frames=None → 1 content frame per fill step (variable length)
- frames=N → exactly N content frames; fills distributed evenly
  (slow-motion if N > fills, fast-forward if N < fills)
"""
| | import json |
| | import re |
| | import random |
| | import argparse |
| | from dataclasses import dataclass, asdict |
| | from pathlib import Path |
| | from typing import List, Tuple, Optional, Dict |
| | import numpy as np |
| | import cv2 |
| | from tqdm import tqdm |
| | from sudoku_processor import SudokuProcessor |
| |
|
| |
|
| | |
| |
|
@dataclass
class SolRange:
    """Inclusive constraint on the number of solutions a puzzle may have.

    Instances are normally built with :meth:`parse` from a short expression
    such as ``"3"``, ``"<=3"``, ``">1"`` or ``"2-5"``.
    """
    # Smallest acceptable solution count.
    min_sol: int
    # Largest acceptable solution count; None means "no upper bound".
    max_sol: Optional[int]

    @classmethod
    def parse(cls, expr: str) -> "SolRange":
        """Build a SolRange from a textual expression.

        Supported forms (surrounding whitespace is ignored):
          - ``"A-B"``             → A..B inclusive
          - ``">=N"`` / ``">N"``  → lower bound only (clamped to at least 1)
          - ``"<=N"`` / ``"<N"``  → 1..N / 1..N-1
          - ``"==N"`` or ``"N"``  → exactly N

        Raises:
            ValueError: if the expression is malformed, or if it cannot be
                satisfied by any count >= 1 (e.g. ``"0"``, ``"==0"``,
                ``"<=0"``, ``"<1"``, ``"5-2"``).
        """
        expr = expr.strip()

        m = re.fullmatch(r'(\d+)\s*-\s*(\d+)', expr)
        if m:
            lo, hi = int(m.group(1)), int(m.group(2))
            if lo < 1:
                raise ValueError(f"min_sol must be >= 1, got {lo}")
            if hi < lo:
                raise ValueError(f"Invalid range: {lo}-{hi}")
            return cls(min_sol=lo, max_sol=hi)

        m = re.fullmatch(r'(>=|>|<=|<|==)\s*(\d+)', expr)
        if m:
            op, n = m.group(1), int(m.group(2))
            # Lower-bound-only forms: a bound below 1 is harmless, so clamp it.
            if op == '>=':
                return cls(min_sol=max(1, n), max_sol=None)
            if op == '>':
                return cls(min_sol=max(1, n + 1), max_sol=None)

            # Upper-bounded forms: the bound itself must admit at least one
            # solution, otherwise the range would be empty (or would silently
            # contradict the expression the user wrote).
            if op == '==':
                lo, hi = n, n
            elif op == '<=':
                lo, hi = 1, n
            else:  # op == '<'
                lo, hi = 1, n - 1
            if hi < 1:
                raise ValueError(
                    f"sol_num expression '{expr}' cannot be satisfied: "
                    f"solution counts must be >= 1"
                )
            return cls(min_sol=lo, max_sol=hi)

        m = re.fullmatch(r'(\d+)', expr)
        if m:
            n = int(m.group(1))
            if n < 1:
                raise ValueError(f"sol_num must be >= 1, got {n}")
            return cls(min_sol=n, max_sol=n)

        raise ValueError(f"Invalid sol_num expression: '{expr}'")

    @property
    def is_exact(self) -> bool:
        """True when exactly one solution count is accepted."""
        return self.max_sol is not None and self.min_sol == self.max_sol

    @property
    def is_unique_only(self) -> bool:
        """True when only single-solution puzzles are accepted."""
        return self.is_exact and self.min_sol == 1

    @property
    def allows_unique(self) -> bool:
        """True when a single-solution puzzle satisfies the lower bound."""
        return self.min_sol <= 1

    @property
    def requires_multi(self) -> bool:
        """True when more than one solution is required."""
        return self.min_sol > 1

    @property
    def effective_max(self) -> int:
        """Concrete upper bound: max_sol, or max(min_sol, 10) when unbounded."""
        return self.max_sol if self.max_sol is not None else max(self.min_sol, 10)

    def accepts(self, count: int) -> bool:
        """Return True if *count* lies within [min_sol, max_sol]."""
        if count < self.min_sol:
            return False
        if self.max_sol is not None and count > self.max_sol:
            return False
        return True

    def __repr__(self) -> str:
        if self.is_exact:
            return f"SolRange(=={self.min_sol})"
        if self.max_sol is None:
            return f"SolRange(>={self.min_sol})"
        return f"SolRange({self.min_sol}-{self.max_sol})"
| |
|
| |
|
| | |
| |
|
@dataclass
class GenerationState:
    """Tracks generation progress for checkpoint/resume."""
    # Hash of the generation parameters this state was produced with.
    params_hash: str
    # Number of puzzles generated so far, keyed by clue level.
    clue_progress: Dict[int, int]
    # Encoded puzzles already emitted (used for de-duplication).
    seen_grids: List[str]
    # One metadata record per generated video.
    all_samples: List[Dict]
    # True once the whole dataset (including train/test files) is written.
    completed: bool = False

    def to_dict(self) -> Dict:
        """Return a plain-dict copy of the state suitable for json.dump."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: Dict) -> "GenerationState":
        """Rebuild a state from the dict produced by :meth:`to_dict`.

        JSON object keys are always strings, so a state that went through
        json.dump/json.load comes back with ``clue_progress`` keyed by str.
        The keys are converted back to int here so the restored object matches
        its declared ``Dict[int, int]`` type. The input dict is not modified.
        """
        fields = dict(d)
        if "clue_progress" in fields:
            fields["clue_progress"] = {
                int(clue): count for clue, count in fields["clue_progress"].items()
            }
        return cls(**fields)
| |
|
| |
|
def compute_params_hash(params: Dict) -> str:
    """Return a 12-hex-digit fingerprint of the generation parameters.

    The ``output_dir`` entry is left out of the fingerprint, so the same
    settings hash identically regardless of where the dataset is written.
    Keys are serialised in sorted order, so dict ordering does not matter.
    """
    import hashlib

    hashed_view = dict(params)
    hashed_view.pop('output_dir', None)
    canonical = json.dumps(hashed_view, sort_keys=True)
    digest = hashlib.md5(canonical.encode()).hexdigest()
    return digest[:12]
| |
|
| |
|
def load_checkpoint(output_dir: Path, params: Dict) -> Optional[GenerationState]:
    """Return the saved generation state, or None when a fresh start is needed.

    None is returned when ``metadata.json`` does not exist in *output_dir*, or
    when the stored parameter hash differs from the hash of *params*. Otherwise
    the stored state is returned (it may already be marked completed) and a
    short status line is printed.
    """
    checkpoint_file = output_dir / "metadata.json"
    if not checkpoint_file.exists():
        return None

    with open(checkpoint_file) as fh:
        payload = json.load(fh)

    saved = GenerationState.from_dict(payload["state"])
    current_hash = compute_params_hash(params)
    if saved.params_hash != current_hash:
        # A checkpoint made with other settings cannot be continued.
        print(f"⚠️ Parameters changed (hash {saved.params_hash} → {current_hash}), starting fresh")
        return None

    if saved.completed:
        status = "✓ Generation already completed"
    else:
        n_done = sum(saved.clue_progress.values())
        status = f"✓ Resuming from checkpoint: {n_done} puzzles generated"
    print(status)
    return saved
| |
|
| |
|
def save_checkpoint(output_dir: Path, state: GenerationState, params: Dict):
    """Save current generation state to metadata.json.

    The payload ``{"params": ..., "state": ...}`` is first written to a
    sibling ``metadata.tmp`` file and then moved over ``metadata.json``, so a
    crash in the middle of writing never leaves a truncated checkpoint behind.
    """
    meta_path = output_dir / "metadata.json"
    tmp_path = meta_path.with_suffix('.tmp')
    with open(tmp_path, 'w') as f:
        json.dump({"params": params, "state": state.to_dict()}, f, indent=2)
    # Path.replace() overwrites an existing target on every platform;
    # Path.rename() raises FileExistsError on Windows once metadata.json
    # exists, which would break every checkpoint after the first one.
    tmp_path.replace(meta_path)
| |
|
| |
|
| | |
| |
|
def get_fill_order(puzzle, solution):
    """List the cells to fill as (row, col, value) tuples, scanning row by row.

    A cell counts as empty when its entry in *puzzle* equals 0; the value is
    taken from the same position in *solution*.
    """
    pending = []
    for row in range(9):
        for col in range(9):
            if puzzle[row][col] == 0:
                pending.append((row, col, solution[row][col]))
    return pending
| |
|
| |
|
def create_processor(resolution=None):
    """Build a SudokuProcessor, optionally sized for a target resolution.

    With ``resolution=None`` the processor is created with its own defaults.
    Otherwise the smaller of the resolution's values is divided into 9 cells,
    and font scale / line thickness are scaled relative to a 60-pixel cell
    (presumably the processor's default cell size — confirm against
    SudokuProcessor).
    """
    if resolution is not None:
        cell_px = min(resolution) // 9
        scale = cell_px / 60
        return SudokuProcessor(
            cell_size=cell_px,
            font_scale=1.2 * scale,
            thickness=max(1, int(2 * scale)),  # never thinner than 1 px
        )
    return SudokuProcessor()
| |
|
| |
|
def generate_video_frames(proc, puzzle, solution, n_start, m_end, frames=None):
    """
    Generate progressive video frames for a Sudoku solve.

    The *frames* parameter controls the number of **content frames**
    (between the opening and closing holds):

    - frames=None     → 1 content frame per fill step (n_fills total)
    - frames > fills  → multiple frames per fill step (slow-motion)
    - frames < fills  → multiple fills per frame (fast-forward)
    - frames == fills → identical to None

    At least one content frame is always produced (``frames <= 0`` and an
    already-solved puzzle are both treated as one content frame when no
    larger count is requested).

    Total output length = n_start + content_frames + m_end.

    Args:
        proc: SudokuProcessor instance.
        puzzle: 9×9 puzzle grid (0 = empty).
        solution: 9×9 solved grid.
        n_start: Hold frames for puzzle at the beginning.
        m_end: Hold frames for completed solution at the end.
        frames: Desired number of content frames (None = one per fill).

    Returns:
        List of numpy arrays (RGB images).
    """
    fills = get_fill_order(puzzle, solution)
    n_fills = len(fills)

    content_frames = max(1, frames if frames is not None else n_fills)

    if n_fills == 0:
        # Nothing to animate: hold the solved grid for the full length so the
        # n_start + content_frames + m_end contract also holds here.
        img = proc.render(solution, original=puzzle)
        return [img.copy() for _ in range(n_start + content_frames + m_end)]

    result = []
    current = [row[:] for row in puzzle]

    # Opening hold: the untouched puzzle.
    img = proc.render(current)
    result.extend(img.copy() for _ in range(n_start))

    if content_frames >= n_fills:
        # One fill per step. Fill i owns content frames [f_lo, f_hi); the first
        # highlights the new digit, the rest hold the grid without highlight.
        # When content_frames == n_fills every fill owns exactly one frame.
        for i, (r, c, v) in enumerate(fills):
            current[r][c] = v
            f_lo = i * content_frames // n_fills
            f_hi = (i + 1) * content_frames // n_fills
            hold = f_hi - f_lo - 1

            result.append(proc.render(current, highlight_new=(r, c), original=puzzle))
            if hold > 0:
                img = proc.render(current, original=puzzle)
                result.extend(img.copy() for _ in range(hold))
    else:
        # Fast-forward: frame f applies fills [lo, hi). Because
        # n_fills > content_frames, every frame receives at least one fill,
        # so there is always a last-filled cell to highlight.
        for f in range(content_frames):
            lo = f * n_fills // content_frames
            hi = (f + 1) * n_fills // content_frames
            for r, c, v in fills[lo:hi]:
                current[r][c] = v
            last_r, last_c, _ = fills[hi - 1]
            result.append(
                proc.render(current, highlight_new=(last_r, last_c), original=puzzle)
            )

    # Closing hold: the completed solution.
    img = proc.render(solution, original=puzzle)
    result.extend(img.copy() for _ in range(m_end))

    return result
| |
|
| |
|
def save_video(frames, path, fps=10):
    """Save list of numpy RGB frames as mp4.

    Args:
        frames: Non-empty sequence of RGB images; the first frame's height
            and width set the video size.
        path: Destination file path.
        fps: Frames per second.

    Raises:
        ValueError: if *frames* is empty.
        OSError: if the video writer could not be opened for *path*.
    """
    if len(frames) == 0:
        raise ValueError("save_video() requires at least one frame")

    h, w = frames[0].shape[:2]
    writer = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
    try:
        # A writer that failed to open drops every write() silently, which
        # would leave the dataset referencing videos that do not exist.
        if not writer.isOpened():
            raise OSError(f"Could not open video writer for {path}")
        for f in frames:
            writer.write(cv2.cvtColor(f, cv2.COLOR_RGB2BGR))
    finally:
        # Release even if a frame conversion or write fails.
        writer.release()
| |
|
| |
|
def normalize_num_per_clue(num_per_clue, clue_levels):
    """Return one puzzle count per clue level.

    A single int is repeated once for every clue level. A sequence is returned
    unchanged after checking that it has exactly one entry per clue level;
    otherwise ValueError is raised.
    """
    n_levels = len(clue_levels)
    if isinstance(num_per_clue, int):
        return [num_per_clue for _ in range(n_levels)]

    n_counts = len(num_per_clue)
    if n_counts == n_levels:
        return num_per_clue
    raise ValueError(
        f"num_per_clue length ({n_counts}) != clue_levels ({n_levels})"
    )
| |
|
| |
|
| | |
| |
|
def generate_puzzle_with_range(proc, clue, sol_range, min_hamming):
    """Make one attempt at a puzzle whose solution count fits *sol_range*.

    Strategy:
      - range is exactly 1      → generate a unique-solution puzzle;
      - range requires > 1      → one multi-solution attempt, no fallback;
      - otherwise (1 allowed)   → one multi-solution attempt asking for at
        least 2 solutions, falling back to a unique-solution puzzle.

    Returns:
        (puzzle, solutions) on success, or None when the attempt failed.
    """
    def _unique():
        grid, answer = proc.generate(clue, unique=True)
        return grid, [answer]

    def _attempt_multi(lowest):
        # A RuntimeError from the generator just means "no luck this time".
        try:
            grid, found = proc.generate_multi_solution(
                clue,
                min_solutions=lowest,
                max_solutions=sol_range.effective_max,
                max_attempts=1,
                min_hamming=min_hamming,
            )
            if sol_range.accepts(len(found)):
                return grid, found
        except RuntimeError:
            pass
        return None

    if sol_range.is_unique_only:
        return _unique()

    if sol_range.requires_multi:
        return _attempt_multi(sol_range.min_sol)

    outcome = _attempt_multi(max(2, sol_range.min_sol))
    if outcome is not None:
        return outcome

    if sol_range.allows_unique:
        return _unique()
    return None
| |
|
| |
|
| | |
| |
|
def generate_dataset(
    output_dir="sudoku", clue_levels=(20, 30, 40, 50, 60, 70),
    num_per_clue=(15000, 10000, 10000, 5000, 2000, 1000),
    sol_num="<=3", min_hamming=10, train_ratio=0.9,
    prompt="Solve this Sudoku puzzle using red font.",
    n_start=2, m_end=3, frames=None, fps=10,
    resolution=None, seed=42, checkpoint_interval=50
):
    """
    Generate Sudoku video dataset with checkpoint/resume support.

    For every clue level, puzzles are generated until the requested count is
    reached (or 20 attempts per remaining puzzle are used up). Each accepted
    puzzle is written as a PNG under ``images/`` and one MP4 per solution under
    ``videos/``. When all levels are done, the per-video records are split
    per clue level into ``train.jsonl`` / ``test.jsonl`` and the final state
    is stored in ``metadata.json``.

    The *frames* parameter controls the number of **content frames** per video:
    - None → one content frame per fill step (variable length per puzzle)
    - N > 0 → exactly N content frames; fills distributed evenly

    Args:
        output_dir: Directory receiving images/, videos/, the jsonl files and
            metadata.json.
        clue_levels: Clue counts to generate puzzles for.
        num_per_clue: Puzzles per clue level (one int for all levels, or one
            entry per level).
        sol_num: Solution-count expression understood by SolRange.parse.
        min_hamming: Forwarded to the multi-solution generator.
        train_ratio: Fraction of each clue level's videos put in the train set.
        prompt: Prompt text stored with every sample.
        n_start: Hold frames for the puzzle at the start of each video.
        m_end: Hold frames for the solution at the end of each video.
        frames: Content frames per video (see above).
        fps: Frames per second of the written videos.
        resolution: Optional (width, height) used to size the renderer.
        seed: Seed for Python's ``random`` module.
        checkpoint_interval: Save checkpoint every N puzzles (default: 50)
    """
    # The defaults are immutable tuples; work on list copies so the values
    # stored in the checkpoint are the same list shapes as before.
    clue_levels = list(clue_levels)
    if not isinstance(num_per_clue, int):
        num_per_clue = list(num_per_clue)

    params = {
        "clue_levels": clue_levels, "num_per_clue": num_per_clue,
        "sol_num": sol_num, "min_hamming": min_hamming, "train_ratio": train_ratio,
        "prompt": prompt, "n_start": n_start, "m_end": m_end, "frames": frames,
        "fps": fps, "resolution": resolution, "seed": seed
    }

    output_dir = Path(output_dir)
    video_dir = output_dir / "videos"
    image_dir = output_dir / "images"
    video_dir.mkdir(parents=True, exist_ok=True)
    image_dir.mkdir(parents=True, exist_ok=True)

    # Resume from an earlier run when its parameters match.
    state = load_checkpoint(output_dir, params)
    if state and state.completed:
        return

    sol_range = SolRange.parse(str(sol_num))
    proc = create_processor(resolution)
    actual_size = proc.img_size
    num_per_clue_list = normalize_num_per_clue(num_per_clue, clue_levels)
    # default=0 keeps an empty clue_levels list from crashing max().
    max_puzzles = max(num_per_clue_list, default=0)
    num_width = len(str(max_puzzles))  # zero-padding width for file names

    if state is None:
        random.seed(seed)
        state = GenerationState(
            params_hash=compute_params_hash(params),
            clue_progress={clue: 0 for clue in clue_levels},
            seen_grids=[],
            all_samples=[]
        )
        print(f"Starting fresh generation with solution range: {sol_range}")
        print(f" frames={'auto (1 per fill)' if frames is None else frames}, "
              f"n_start={n_start}, m_end={m_end}, fps={fps}")
    else:
        # NOTE(review): this re-seeds and burns 10 random numbers per finished
        # puzzle; it does not restore the RNG state of the interrupted run, so
        # a resumed run is not expected to reproduce an uninterrupted one.
        random.seed(seed)
        for _ in range(sum(state.clue_progress.values()) * 10):
            random.random()

    seen_grids = set(state.seen_grids)
    all_samples = state.all_samples.copy()
    clue_progress = {int(k): v for k, v in state.clue_progress.items()}

    total_target = sum(num_per_clue_list)
    total_done = sum(clue_progress.values())
    # Count puzzles (not videos): only the sol_idx == 0 record of each puzzle.
    stats_unique = sum(1 for s in all_samples if s["total_solutions"] == 1 and s["sol_idx"] == 0)
    stats_multi = sum(1 for s in all_samples if s["total_solutions"] > 1 and s["sol_idx"] == 0)
    puzzles_since_checkpoint = 0

    with tqdm(total=total_target, initial=total_done, desc="Total", unit="puzzle") as pbar_total:
        for clue, target_count in zip(clue_levels, num_per_clue_list):
            generated = clue_progress.get(clue, 0)
            if generated >= target_count:
                continue

            max_attempts = (target_count - generated) * 20

            with tqdm(total=target_count, initial=generated, desc=f"Clue {clue:2d}",
                      unit="puzzle", leave=False) as pbar_clue:
                for _ in range(max_attempts):
                    if generated >= target_count:
                        break

                    result = generate_puzzle_with_range(proc, clue, sol_range, min_hamming)
                    if result is None:
                        continue
                    puzzle, solutions = result

                    # Skip puzzles that were already emitted.
                    fp = proc.encode(puzzle)
                    if fp in seen_grids:
                        continue
                    seen_grids.add(fp)

                    n_sols = len(solutions)
                    if n_sols == 1:
                        stats_unique += 1
                    else:
                        stats_multi += 1

                    img_name = f"clue{clue}_{generated:0{num_width}d}.png"
                    puzzle_img = proc.render(puzzle)
                    cv2.imwrite(
                        str(image_dir / img_name),
                        cv2.cvtColor(puzzle_img, cv2.COLOR_RGB2BGR),
                    )

                    # One video (and one sample record) per solution.
                    for si, sol in enumerate(solutions):
                        vid_name = f"clue{clue}_{generated:0{num_width}d}_sol{si}.mp4"
                        vid_frames = generate_video_frames(
                            proc, puzzle, sol, n_start, m_end, frames
                        )
                        save_video(vid_frames, video_dir / vid_name, fps)

                        hdists = [
                            proc._hamming(sol, solutions[j])
                            for j in range(n_sols) if j != si
                        ]
                        all_samples.append({
                            "prompt": prompt, "video": vid_name, "image": img_name,
                            "clue": clue, "puzzle": fp, "solution": proc.encode(sol),
                            "sol_idx": si, "total_solutions": n_sols,
                            "frame_count": len(vid_frames),
                            "min_hamming_to_others": min(hdists) if hdists else 0,
                        })

                    generated += 1
                    clue_progress[clue] = generated
                    puzzles_since_checkpoint += 1
                    pbar_clue.update(1)
                    pbar_total.update(1)

                    if puzzles_since_checkpoint >= checkpoint_interval:
                        state.clue_progress = clue_progress
                        state.seen_grids = list(seen_grids)
                        state.all_samples = all_samples
                        save_checkpoint(output_dir, state, params)
                        puzzles_since_checkpoint = 0

            tqdm.write(
                f"Clue {clue}: {generated} puzzles, "
                f"{sum(1 for s in all_samples if s['clue'] == clue)} videos"
            )

    # Train/test split, done separately for each clue level. The split is over
    # video records, so two solutions of one puzzle can land on different sides.
    random.seed(seed + 1)
    by_clue: Dict[int, List[Dict]] = {}
    for s in all_samples:
        by_clue.setdefault(s["clue"], []).append(s)

    train_samples, test_samples = [], []
    for clue in sorted(by_clue):
        group = by_clue[clue]
        random.shuffle(group)
        cl_split = int(len(group) * train_ratio)
        train_samples.extend(group[:cl_split])
        test_samples.extend(group[cl_split:])

    random.shuffle(train_samples)
    random.shuffle(test_samples)
    split_idx = len(train_samples)

    def write_jsonl(samples, path):
        with open(path, 'w') as f:
            for s in samples:
                json.dump(s, f)
                f.write('\n')

    write_jsonl(train_samples, output_dir / "train.jsonl")
    write_jsonl(test_samples, output_dir / "test.jsonl")

    # Final checkpoint, marked completed so a re-run returns immediately.
    state.clue_progress = clue_progress
    state.seen_grids = list(seen_grids)
    state.all_samples = all_samples
    state.completed = True
    save_checkpoint(output_dir, state, params)

    print(f"\n✓ Dataset complete: {output_dir}/")
    print(f" Resolution: {actual_size}x{actual_size}")
    print(f" Solution range: {sol_range}")
    print(f" Puzzles: {len(seen_grids)} ({stats_unique} unique, {stats_multi} multi-sol)")
    print(f" Videos: {len(all_samples)}")
    print(f" Train: {split_idx}, Test: {len(all_samples) - split_idx}")

    fcounts = [s["frame_count"] for s in all_samples]
    if fcounts:  # min()/max() raise on an empty list when nothing was generated
        print(f" Frame counts: avg={np.mean(fcounts):.1f}, "
              f"min={min(fcounts)}, max={max(fcounts)}")

    hammings = [s["min_hamming_to_others"] for s in all_samples if s["min_hamming_to_others"] > 0]
    if hammings:
        print(f" Solution diversity: avg={np.mean(hammings):.1f}, "
              f"min={min(hammings)}, max={max(hammings)}")
| |
|
| |
|
def parse_resolution(s):
    """Convert a 'WIDTHxHEIGHT' string (the 'x' may be upper- or lower-case) to an (int, int) tuple."""
    width_text, height_text = s.lower().split('x')
    return (int(width_text), int(height_text))
| |
|
| |
|
def parse_args():
    """Parse the command-line options of the dataset generator.

    Returns the argparse namespace; its attribute names are the option names
    with dashes replaced by underscores.
    """
    parser = argparse.ArgumentParser(
        description="Generate Sudoku video dataset with resume support"
    )

    # (flag, keyword arguments for add_argument), in the order they are listed in --help.
    option_specs = [
        ("--output-dir", dict(type=str, default="sudoku")),
        ("--clue-levels", dict(type=int, nargs="+",
                               default=[20, 30, 40, 50, 60, 70])),
        ("--num-per-clue", dict(type=int, nargs="+",
                                default=[15000, 10000, 10000, 5000, 2000, 1000])),
        ("--sol-num", dict(type=str, default="<=3",
                           help="'1', '3', '>=1', '>1', '<=3', '<3', '2-5'")),
        ("--min-hamming", dict(type=int, default=10)),
        ("--train-ratio", dict(type=float, default=0.9)),
        ("--prompt", dict(type=str,
                          default="Solve this Sudoku puzzle using red font.")),
        ("--n-start", dict(type=int, default=2,
                           help="Hold frames for puzzle at video start")),
        ("--m-end", dict(type=int, default=3,
                         help="Hold frames for completed solution at video end")),
        ("--frames", dict(type=int, default=None,
                          help="Content frames per video. None=1 per fill (auto). "
                               "If > fills: slow-motion. If < fills: fast-forward.")),
        ("--fps", dict(type=int, default=10)),
        ("--resolution", dict(type=str, default="1024x1024")),
        ("--seed", dict(type=int, default=42)),
        ("--checkpoint-interval", dict(type=int, default=50,
                                       help="Save checkpoint every N puzzles (default: 50)")),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args()
| |
|
| |
|
if __name__ == "__main__":
    # The parsed option names match generate_dataset()'s keyword arguments.
    cli_options = vars(parse_args())

    # A single --num-per-clue value applies to every clue level.
    per_clue = cli_options["num_per_clue"]
    if isinstance(per_clue, list) and len(per_clue) == 1:
        cli_options["num_per_clue"] = per_clue[0]

    # "WxH" text → (W, H); an empty value is passed through unchanged.
    resolution_text = cli_options["resolution"]
    if resolution_text:
        cli_options["resolution"] = parse_resolution(resolution_text)

    generate_dataset(**cli_options)