"""
Sudoku Video Dataset Generator - Supports flexible solution count expressions per puzzle.
With checkpoint/resume support via metadata.json.

The *frames* parameter replaces the old max_frames + k pair:
  - frames=None → 1 content frame per fill step (variable length)
  - frames=N    → exactly N content frames; fills distributed evenly
                  (slow-motion if N > fills, fast-forward if N < fills)
"""

import argparse
import hashlib
import json
import random
import re
from dataclasses import asdict, dataclass, fields
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np
from tqdm import tqdm

from sudoku_processor import SudokuProcessor


# Defaults shared by generate_dataset() and the CLI (tuples: never mutated).
DEFAULT_CLUE_LEVELS = (20, 30, 40, 50, 60, 70)
DEFAULT_NUM_PER_CLUE = (15000, 10000, 10000, 5000, 2000, 1000)


# ==================== Solution Range ====================

@dataclass
class SolRange:
    """Flexible solution count constraint for puzzle generation.

    ``min_sol`` is the inclusive lower bound on the number of solutions;
    ``max_sol`` is the inclusive upper bound, or ``None`` for "unbounded".
    """
    min_sol: int
    max_sol: Optional[int]

    @classmethod
    def parse(cls, expr: str) -> "SolRange":
        """Parse a solution-count expression.

        Accepted forms: ``"N"``, ``"lo-hi"``, ``">=N"``, ``">N"``, ``"<=N"``,
        ``"<N"`` and ``"==N"`` (surrounding whitespace is ignored).

        Raises:
            ValueError: if *expr* is malformed, or describes a range that no
                puzzle can satisfy (every puzzle has at least one solution).
        """
        expr = expr.strip()

        m = re.fullmatch(r'(\d+)\s*-\s*(\d+)', expr)
        if m:
            lo, hi = int(m.group(1)), int(m.group(2))
            if lo < 1:
                raise ValueError(f"min_sol must be >= 1, got {lo}")
            if hi < lo:
                raise ValueError(f"Invalid range: {lo}-{hi}")
            return cls(min_sol=lo, max_sol=hi)

        m = re.fullmatch(r'(>=|>|<=|<|==)\s*(\d+)', expr)
        if m:
            op, n = m.group(1), int(m.group(2))
            if op == '>=':
                return cls(min_sol=max(1, n), max_sol=None)
            if op == '>':
                return cls(min_sol=max(1, n + 1), max_sol=None)
            # Upper bounds / exact counts below 1 used to be silently coerced
            # ('<1' became '==1') or produced an unsatisfiable range
            # ('<=0', '==0'); reject them explicitly instead.
            if op == '<=':
                if n < 1:
                    raise ValueError(f"Empty range: '{expr}' admits no puzzle")
                return cls(min_sol=1, max_sol=n)
            if op == '<':
                if n < 2:
                    raise ValueError(f"Empty range: '{expr}' admits no puzzle")
                return cls(min_sol=1, max_sol=n - 1)
            # op == '=='
            if n < 1:
                raise ValueError(f"sol_num must be >= 1, got {n}")
            return cls(min_sol=n, max_sol=n)

        m = re.fullmatch(r'(\d+)', expr)
        if m:
            n = int(m.group(1))
            if n < 1:
                raise ValueError(f"sol_num must be >= 1, got {n}")
            return cls(min_sol=n, max_sol=n)

        raise ValueError(f"Invalid sol_num expression: '{expr}'")

    @property
    def is_exact(self):
        """True when the range pins exactly one solution count."""
        return self.max_sol is not None and self.min_sol == self.max_sol

    @property
    def is_unique_only(self):
        """True when only single-solution puzzles are acceptable."""
        return self.is_exact and self.min_sol == 1

    @property
    def allows_unique(self):
        """True when the lower bound does not exclude single-solution puzzles."""
        return self.min_sol <= 1

    @property
    def requires_multi(self):
        """True when every acceptable puzzle must have more than one solution."""
        return self.min_sol > 1

    @property
    def effective_max(self):
        """Upper bound handed to the generator (at least 10 when unbounded)."""
        return self.max_sol if self.max_sol is not None else max(self.min_sol, 10)

    def accepts(self, count):
        """Return True if *count* solutions lies inside this range."""
        if count < self.min_sol:
            return False
        if self.max_sol is not None and count > self.max_sol:
            return False
        return True

    def __repr__(self):
        if self.is_exact:
            return f"SolRange(=={self.min_sol})"
        if self.max_sol is None:
            return f"SolRange(>={self.min_sol})"
        return f"SolRange({self.min_sol}-{self.max_sol})"


# ==================== Checkpoint Management ====================

@dataclass
class GenerationState:
    """Tracks generation progress for checkpoint/resume."""
    params_hash: str
    # clue_level -> generated_count (keys become str after a JSON round-trip)
    clue_progress: Dict[int, int]
    # Encoded puzzles already emitted, used for de-duplication.
    seen_grids: List[str]
    # One record per (puzzle, solution) video.
    all_samples: List[Dict]
    completed: bool = False
    # JSON-friendly snapshot of the `random` module state at checkpoint time;
    # None for fresh states and for checkpoints written by older versions.
    rng_state: Optional[list] = None

    def to_dict(self) -> Dict:
        return asdict(self)

    @classmethod
    def from_dict(cls, d: Dict) -> "GenerationState":
        """Build a state from *d*, ignoring keys this version does not know."""
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in d.items() if k in known})


def _rng_state_to_json() -> list:
    """Return ``random.getstate()`` as a JSON-serializable nested list."""
    version, internal, gauss_next = random.getstate()
    return [version, list(internal), gauss_next]


def _restore_rng_state(saved: list) -> None:
    """Restore the `random` module from a :func:`_rng_state_to_json` snapshot."""
    version, internal, gauss_next = saved
    random.setstate((version, tuple(internal), gauss_next))


def compute_params_hash(params: Dict) -> str:
    """Compute hash of generation parameters for consistency check.

    ``output_dir`` is excluded so a dataset directory can be moved/renamed
    without invalidating its checkpoint.
    """
    key_params = {k: v for k, v in params.items() if k != 'output_dir'}
    payload = json.dumps(key_params, sort_keys=True).encode()
    return hashlib.md5(payload).hexdigest()[:12]


def load_checkpoint(output_dir: Path, params: Dict) -> Optional[GenerationState]:
    """Load checkpoint if exists and params match.

    Returns None (start fresh) when there is no checkpoint, when it cannot be
    parsed, or when it was written with different generation parameters.
    """
    meta_path = output_dir / "metadata.json"
    if not meta_path.exists():
        return None

    try:
        with open(meta_path, encoding="utf-8") as f:
            data = json.load(f)
        state = GenerationState.from_dict(data["state"])
    except (json.JSONDecodeError, KeyError, TypeError, AttributeError) as err:
        print(f"⚠️ Unreadable checkpoint {meta_path} ({err}), starting fresh")
        return None

    expected_hash = compute_params_hash(params)
    if state.params_hash != expected_hash:
        print(f"⚠️ Parameters changed (hash {state.params_hash} → {expected_hash}), starting fresh")
        return None

    if state.completed:
        print("✓ Generation already completed")
        return state

    print(f"✓ Resuming from checkpoint: {sum(state.clue_progress.values())} puzzles generated")
    return state


def save_checkpoint(output_dir: Path, state: GenerationState, params: Dict):
    """Save current generation state to metadata.json (write-then-swap)."""
    meta_path = output_dir / "metadata.json"
    tmp_path = meta_path.with_suffix('.tmp')
    with open(tmp_path, 'w', encoding="utf-8") as f:
        json.dump({"params": params, "state": state.to_dict()}, f, indent=2)
    # Path.replace overwrites an existing target on POSIX and Windows alike;
    # Path.rename raises FileExistsError on Windows once metadata.json exists.
    tmp_path.replace(meta_path)


# ==================== Core Functions ====================

def get_fill_order(puzzle, solution):
    """Return list of (row, col, value) for empty cells in row-major order."""
    return [(i, j, solution[i][j]) for i in range(9) for j in range(9) if puzzle[i][j] == 0]


def create_processor(resolution=None):
    """Create a SudokuProcessor with optional custom resolution.

    Args:
        resolution: Optional (w, h); the smaller side is split into 9 cells
            and font size / stroke thickness scale from a 60 px reference cell.

    Raises:
        ValueError: if the resolution leaves less than 1 px per cell.
    """
    if resolution is None:
        return SudokuProcessor()
    cell_size = min(resolution) // 9
    if cell_size < 1:
        raise ValueError(f"resolution {resolution} too small: need at least 9 px per side")
    sf = cell_size / 60
    return SudokuProcessor(
        cell_size=cell_size,
        font_scale=1.2 * sf,
        thickness=max(1, int(2 * sf)),
    )


def generate_video_frames(proc, puzzle, solution, n_start, m_end, frames=None):
    """
    Generate progressive video frames for a Sudoku solve.

    The *frames* parameter controls the number of **content frames**
    (between the opening and closing holds):
      - frames=None   → 1 content frame per fill step (n_fills total)
      - frames > fills → multiple frames per fill step (slow-motion)
      - frames < fills → multiple fills per frame (fast-forward)
      - frames == fills → identical to None
    Values < 1 are clamped to 1.

    Total output length = n_start + content_frames + m_end
    (n_start + m_end + 1 when the puzzle has no empty cells).

    Args:
        proc: SudokuProcessor instance.
        puzzle: 9×9 puzzle grid (0 = empty).
        solution: 9×9 solved grid.
        n_start: Hold frames for puzzle at the beginning.
        m_end: Hold frames for completed solution at the end.
        frames: Desired number of content frames (None = one per fill).

    Returns:
        List of numpy arrays (RGB images).
    """
    fills = get_fill_order(puzzle, solution)
    n_fills = len(fills)

    if n_fills == 0:
        # Nothing to fill in: a static clip of the completed grid.
        img = proc.render(solution, original=puzzle)
        return [img.copy() for _ in range(n_start + m_end + 1)]

    content_frames = max(1, frames if frames is not None else n_fills)

    result = []
    current = [row[:] for row in puzzle]

    # --- opening hold ---
    img = proc.render(current)
    result.extend(img.copy() for _ in range(n_start))

    # --- content frames ---
    if content_frames >= n_fills:
        # 1:1 or slow-motion: fill step i owns frames
        # [i*C//N, (i+1)*C//N), which is at least one frame because C >= N.
        # The first frame highlights the new cell; the rest hold without it.
        for i, (r, c, v) in enumerate(fills):
            current[r][c] = v
            count = (i + 1) * content_frames // n_fills - i * content_frames // n_fills
            result.append(proc.render(current, highlight_new=(r, c), original=puzzle))
            if count > 1:
                img = proc.render(current, original=puzzle)
                result.extend(img.copy() for _ in range(count - 1))
    else:
        # Fast-forward: frame f applies fills [f*N//C, (f+1)*N//C) at once and
        # highlights the last cell of that batch.
        for f in range(content_frames):
            lo = f * n_fills // content_frames
            hi = (f + 1) * n_fills // content_frames
            for r, c, v in fills[lo:hi]:
                current[r][c] = v
            if hi > lo:
                last_r, last_c, _ = fills[hi - 1]
                result.append(
                    proc.render(current, highlight_new=(last_r, last_c), original=puzzle)
                )
            else:
                # Defensive: unreachable while C < N, kept for safety.
                result.append(proc.render(current, original=puzzle))

    # --- closing hold ---
    img = proc.render(solution, original=puzzle)
    result.extend(img.copy() for _ in range(m_end))
    return result


def save_video(frames, path, fps=10):
    """Save list of numpy RGB frames as mp4.

    Raises:
        ValueError: if *frames* is empty.
        OSError: if OpenCV cannot open a writer for *path*.
    """
    if not frames:
        raise ValueError("save_video() requires at least one frame")
    h, w = frames[0].shape[:2]
    writer = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
    if not writer.isOpened():
        raise OSError(f"Could not open video writer for {path}")
    try:
        for f in frames:
            writer.write(cv2.cvtColor(f, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()


def normalize_num_per_clue(num_per_clue, clue_levels):
    """Broadcast single int to list, or validate list length.

    Raises:
        ValueError: if a sequence is given whose length differs from clue_levels.
    """
    if isinstance(num_per_clue, int):
        return [num_per_clue] * len(clue_levels)
    if len(num_per_clue) != len(clue_levels):
        raise ValueError(
            f"num_per_clue length ({len(num_per_clue)}) != clue_levels ({len(clue_levels)})"
        )
    return num_per_clue


# ==================== Puzzle Generation with SolRange ====================

def generate_puzzle_with_range(proc, clue, sol_range, min_hamming):
    """Generate one puzzle respecting sol_range.

    Unique-only ranges go straight to the unique generator. Otherwise one
    multi-solution attempt is made; if it fails or falls outside the range,
    ranges that admit single-solution puzzles fall back to a unique puzzle.

    Returns:
        (puzzle, solutions) or None if this attempt produced nothing usable.
    """
    if sol_range.is_unique_only:
        puzzle, solution = proc.generate(clue, unique=True)
        return puzzle, [solution]

    try:
        puzzle, solutions = proc.generate_multi_solution(
            clue,
            min_solutions=max(2, sol_range.min_sol),
            max_solutions=sol_range.effective_max,
            max_attempts=1,
            min_hamming=min_hamming,
        )
    except RuntimeError:
        # A failed single attempt is expected; the caller retries.
        pass
    else:
        if sol_range.accepts(len(solutions)):
            return puzzle, solutions

    if sol_range.requires_multi:
        return None

    # Only fall back to a unique puzzle if the range really admits 1 solution.
    if sol_range.allows_unique and sol_range.accepts(1):
        puzzle, solution = proc.generate(clue, unique=True)
        return puzzle, [solution]
    return None


# ==================== Dataset Generation ====================

def _stratified_split(samples, train_ratio, seed):
    """Split *samples* into (train, test), stratified by clue level.

    Re-seeds `random` with *seed* so the split is reproducible.
    NOTE(review): the split is per video, so different solutions of the same
    multi-solution puzzle can land in both train and test; group by
    ``sample["puzzle"]`` instead if that leakage matters.
    """
    random.seed(seed)
    by_clue: Dict[int, List[Dict]] = {}
    for s in samples:
        by_clue.setdefault(s["clue"], []).append(s)

    train, test = [], []
    for clue in sorted(by_clue):
        group = by_clue[clue]
        random.shuffle(group)
        cut = int(len(group) * train_ratio)
        train.extend(group[:cut])
        test.extend(group[cut:])

    random.shuffle(train)
    random.shuffle(test)
    return train, test


def _write_jsonl(samples, path):
    """Write one JSON object per line to *path*."""
    with open(path, 'w', encoding="utf-8") as f:
        for s in samples:
            json.dump(s, f)
            f.write('\n')


def generate_dataset(
    output_dir="sudoku",
    clue_levels=None,
    num_per_clue=None,
    sol_num="<=3",
    min_hamming=10,
    train_ratio=0.9,
    prompt="Solve this Sudoku puzzle using red font.",
    n_start=2,
    m_end=3,
    frames=None,
    fps=10,
    resolution=None,
    seed=42,
    checkpoint_interval=50,
):
    """
    Generate Sudoku video dataset with checkpoint/resume support.

    The *frames* parameter controls the number of **content frames** per video:
      - None  → one content frame per fill step (variable length per puzzle)
      - N > 0 → exactly N content frames; fills distributed evenly

    Writes ``images/`` (one PNG per puzzle), ``videos/`` (one MP4 per
    solution), ``train.jsonl``, ``test.jsonl`` and ``metadata.json`` under
    *output_dir*.

    Args:
        output_dir: Destination directory (created if missing).
        clue_levels: Clue counts to generate (default: DEFAULT_CLUE_LEVELS).
        num_per_clue: Puzzles per clue level, a single int or one per level
            (default: DEFAULT_NUM_PER_CLUE).
        sol_num: Solution-count expression, see SolRange.parse.
        min_hamming: Forwarded to the processor's multi-solution generator.
        train_ratio: Fraction of each clue level's videos put in train.
        prompt: Text stored in every sample record.
        n_start, m_end, frames: See generate_video_frames.
        fps: Output video frame rate.
        resolution: Optional (w, h), see create_processor.
        seed: Seed for generation; seed + 1 seeds the train/test split.
        checkpoint_interval: Save checkpoint every N puzzles (default: 50)

    Raises:
        ValueError: invalid sol_num, mismatched num_per_clue, or a resolution
            that is too small.
    """
    if clue_levels is None:
        clue_levels = list(DEFAULT_CLUE_LEVELS)
    if num_per_clue is None:
        num_per_clue = list(DEFAULT_NUM_PER_CLUE)

    params = {
        "clue_levels": clue_levels, "num_per_clue": num_per_clue,
        "sol_num": sol_num, "min_hamming": min_hamming,
        "train_ratio": train_ratio, "prompt": prompt,
        "n_start": n_start, "m_end": m_end, "frames": frames,
        "fps": fps, "resolution": resolution, "seed": seed,
    }

    # Validate cheap inputs before touching the filesystem.
    sol_range = SolRange.parse(str(sol_num))
    num_per_clue_list = normalize_num_per_clue(num_per_clue, clue_levels)

    output_dir = Path(output_dir)
    video_dir = output_dir / "videos"
    image_dir = output_dir / "images"
    video_dir.mkdir(parents=True, exist_ok=True)
    image_dir.mkdir(parents=True, exist_ok=True)

    # Try to resume from checkpoint
    state = load_checkpoint(output_dir, params)
    if state is not None and state.completed:
        return

    proc = create_processor(resolution)
    actual_size = proc.img_size
    num_width = len(str(max(num_per_clue_list, default=0)))

    # Initialize or restore state (including the RNG position)
    if state is None:
        random.seed(seed)
        state = GenerationState(
            params_hash=compute_params_hash(params),
            clue_progress={clue: 0 for clue in clue_levels},
            seen_grids=[],
            all_samples=[],
        )
        print(f"Starting fresh generation with solution range: {sol_range}")
        print(f"  frames={'auto (1 per fill)' if frames is None else frames}, "
              f"n_start={n_start}, m_end={m_end}, fps={fps}")
    else:
        restored = False
        if state.rng_state is not None:
            try:
                _restore_rng_state(state.rng_state)
                restored = True
            except (TypeError, ValueError):
                restored = False
        if not restored:
            # Legacy checkpoint without an RNG snapshot: approximate the old
            # position by re-seeding and discarding draws (not exact).
            random.seed(seed)
            for _ in range(sum(state.clue_progress.values()) * 10):
                random.random()

    seen_grids = set(state.seen_grids)
    all_samples = list(state.all_samples)
    clue_progress = {int(k): v for k, v in state.clue_progress.items()}

    total_target = sum(num_per_clue_list)
    total_done = sum(clue_progress.values())
    # Count puzzles (not videos) by looking only at each puzzle's first solution.
    first_videos = [s for s in all_samples if s["sol_idx"] == 0]
    stats_unique = sum(1 for s in first_videos if s["total_solutions"] == 1)
    stats_multi = sum(1 for s in first_videos if s["total_solutions"] > 1)

    def _save(completed=False):
        """Sync the live progress into `state` and persist it."""
        state.clue_progress = clue_progress
        state.seen_grids = list(seen_grids)
        state.all_samples = all_samples
        state.rng_state = _rng_state_to_json()
        state.completed = completed
        save_checkpoint(output_dir, state, params)

    puzzles_since_checkpoint = 0

    with tqdm(total=total_target, initial=total_done, desc="Total", unit="puzzle") as pbar_total:
        for clue, target_count in zip(clue_levels, num_per_clue_list):
            generated = clue_progress.get(clue, 0)
            if generated >= target_count:
                continue

            # Generous retry budget: attempts may fail or hit duplicates.
            max_attempts = (target_count - generated) * 20

            with tqdm(total=target_count, initial=generated, desc=f"Clue {clue:2d}",
                      unit="puzzle", leave=False) as pbar_clue:
                for _ in range(max_attempts):
                    if generated >= target_count:
                        break

                    result = generate_puzzle_with_range(proc, clue, sol_range, min_hamming)
                    if result is None:
                        continue

                    puzzle, solutions = result
                    fp = proc.encode(puzzle)
                    if fp in seen_grids:
                        continue
                    seen_grids.add(fp)

                    n_sols = len(solutions)
                    if n_sols == 1:
                        stats_unique += 1
                    else:
                        stats_multi += 1

                    stem = f"clue{clue}_{generated:0{num_width}d}"
                    img_name = f"{stem}.png"
                    puzzle_img = proc.render(puzzle)
                    cv2.imwrite(
                        str(image_dir / img_name),
                        cv2.cvtColor(puzzle_img, cv2.COLOR_RGB2BGR),
                    )

                    for si, sol in enumerate(solutions):
                        vid_name = f"{stem}_sol{si}.mp4"
                        vid_frames = generate_video_frames(
                            proc, puzzle, sol, n_start, m_end, frames
                        )
                        save_video(vid_frames, video_dir / vid_name, fps)

                        hdists = [
                            proc._hamming(sol, other)
                            for j, other in enumerate(solutions) if j != si
                        ]
                        all_samples.append({
                            "prompt": prompt,
                            "video": vid_name,
                            "image": img_name,
                            "clue": clue,
                            "puzzle": fp,
                            "solution": proc.encode(sol),
                            "sol_idx": si,
                            "total_solutions": n_sols,
                            "frame_count": len(vid_frames),
                            "min_hamming_to_others": min(hdists) if hdists else 0,
                        })

                    generated += 1
                    clue_progress[clue] = generated
                    puzzles_since_checkpoint += 1
                    pbar_clue.update(1)
                    pbar_total.update(1)

                    if puzzles_since_checkpoint >= checkpoint_interval:
                        _save()
                        puzzles_since_checkpoint = 0

            if generated < target_count:
                tqdm.write(
                    f"⚠️ Clue {clue}: only {generated}/{target_count} puzzles "
                    f"after {max_attempts} attempts"
                )
            tqdm.write(
                f"Clue {clue}: {generated} puzzles, "
                f"{sum(1 for s in all_samples if s['clue'] == clue)} videos"
            )

    # Stratified split: ensure each clue level is proportionally represented
    train_samples, test_samples = _stratified_split(all_samples, train_ratio, seed + 1)
    _write_jsonl(train_samples, output_dir / "train.jsonl")
    _write_jsonl(test_samples, output_dir / "test.jsonl")

    # Mark as completed
    _save(completed=True)

    print(f"\n✓ Dataset complete: {output_dir}/")
    print(f"  Resolution: {actual_size}x{actual_size}")
    print(f"  Solution range: {sol_range}")
    print(f"  Puzzles: {len(seen_grids)} ({stats_unique} unique, {stats_multi} multi-sol)")
    print(f"  Videos: {len(all_samples)}")
    print(f"  Train: {len(train_samples)}, Test: {len(test_samples)}")

    # Guard: min()/mean() of an empty list would crash/warn when nothing was generated.
    fcounts = [s["frame_count"] for s in all_samples]
    if fcounts:
        print(f"  Frame counts: avg={np.mean(fcounts):.1f}, "
              f"min={min(fcounts)}, max={max(fcounts)}")

    hammings = [s["min_hamming_to_others"] for s in all_samples
                if s["min_hamming_to_others"] > 0]
    if hammings:
        print(f"  Solution diversity: avg={np.mean(hammings):.1f}, "
              f"min={min(hammings)}, max={max(hammings)}")


def parse_resolution(s) -> Tuple[int, int]:
    """Parse ``"WxH"`` (case-insensitive), or a single ``"N"`` meaning N×N.

    Raises:
        ValueError: if *s* is malformed or a dimension is not positive.
    """
    parts = s.strip().lower().split('x')
    if len(parts) == 1:
        parts = parts * 2
    if len(parts) != 2:
        raise ValueError(f"Invalid resolution '{s}', expected WxH")
    w, h = (int(p.strip()) for p in parts)
    if w <= 0 or h <= 0:
        raise ValueError(f"Resolution must be positive, got '{s}'")
    return (w, h)


def parse_args():
    """Build and parse the command-line interface."""
    p = argparse.ArgumentParser(
        description="Generate Sudoku video dataset with resume support"
    )
    p.add_argument("--output-dir", type=str, default="sudoku")
    p.add_argument("--clue-levels", type=int, nargs="+",
                   default=list(DEFAULT_CLUE_LEVELS))
    p.add_argument("--num-per-clue", type=int, nargs="+",
                   default=list(DEFAULT_NUM_PER_CLUE))
    p.add_argument("--sol-num", type=str, default="<=3",
                   help="'1', '3', '>=1', '>1', '<=3', '<3', '2-5'")
    p.add_argument("--min-hamming", type=int, default=10)
    p.add_argument("--train-ratio", type=float, default=0.9)
    p.add_argument("--prompt", type=str,
                   default="Solve this Sudoku puzzle using red font.")
    p.add_argument("--n-start", type=int, default=2,
                   help="Hold frames for puzzle at video start")
    p.add_argument("--m-end", type=int, default=3,
                   help="Hold frames for completed solution at video end")
    p.add_argument("--frames", type=int, default=None,
                   help="Content frames per video. None=1 per fill (auto). "
                        "If > fills: slow-motion. If < fills: fast-forward.")
    p.add_argument("--fps", type=int, default=10)
    p.add_argument("--resolution", type=str, default="1024x1024")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--checkpoint-interval", type=int, default=50,
                   help="Save checkpoint every N puzzles (default: 50)")
    return p.parse_args()


if __name__ == "__main__":
    args = parse_args()
    kwargs = vars(args)
    # A single --num-per-clue value is broadcast to every clue level.
    if isinstance(kwargs["num_per_clue"], list) and len(kwargs["num_per_clue"]) == 1:
        kwargs["num_per_clue"] = kwargs["num_per_clue"][0]
    # An empty --resolution means "processor default".
    kwargs["resolution"] = (
        parse_resolution(kwargs["resolution"]) if kwargs["resolution"] else None
    )
    generate_dataset(**kwargs)