from __future__ import annotations

import hashlib
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from PIL import Image

from .storyboard import VideoStoryboard


@dataclass
class KeyframeResult:
    """Keyframes rendered for a storyboard, plus per-frame cache bookkeeping."""

    frames: list[Image.Image]
    paths: list[Path]
    cache_hits: int
    metadata: list[dict[str, Any]]


def cache_key(payload: dict[str, Any]) -> str:
    """Return a stable 24-hex-character key for a JSON-serializable payload."""
    # sort_keys makes the key independent of dict insertion order.
    raw = json.dumps(payload, sort_keys=True, ensure_ascii=True)
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()[:24]


def _json_safe(value: Any) -> Any:
    """Convert value into something json.dumps accepts, falling back to repr()."""
    if isinstance(value, (str, int, float, bool)) or value is None:
        return value
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    if isinstance(value, dict):
        return {
            str(key): _json_safe(item)
            for key, item in sorted(value.items(), key=lambda pair: str(pair[0]))
        }
    # repr() of arbitrary objects can embed memory addresses, in which case the
    # resulting cache key will differ from run to run.
    return repr(value)


def generate_keyframes(
    router: Any,
    storyboard: VideoStoryboard,
    *,
    cache_dir: str | Path,
    width: int,
    height: int,
    image_steps: int,
    guidance_scale: float,
    seed: int,
    motion: str,
    version: str,
    use_cache: bool = True,
    **kwargs: Any,
) -> KeyframeResult:
    """Render one image per storyboard keyframe prompt, reusing cached PNGs.

    Each frame is stored under a key derived from every input that feeds it, so
    changing the prompt, size, seed, motion, version, or any image option
    forces a re-render instead of a cache hit.
    """
    root = Path(cache_dir)
    root.mkdir(parents=True, exist_ok=True)
    generation_strategy = kwargs.pop("generation_strategy", "diffusion")

    # Defaults for the router call; caller-supplied kwargs override them.
    # These do not depend on the frame index, so build them once.
    image_kwargs = {
        "use_memory": False,
        "reference_pass_steps": 0,
        "unload_after_call": False,
        **kwargs,
    }
    image_options = {
        "generation_strategy": generation_strategy,
        **{
            str(key): _json_safe(value)
            for key, value in sorted(image_kwargs.items(), key=lambda pair: str(pair[0]))
        },
    }

    frames: list[Image.Image] = []
    paths: list[Path] = []
    metadata: list[dict[str, Any]] = []
    cache_hits = 0

    for idx, prompt in enumerate(storyboard.keyframe_prompts):
        frame_seed = seed + idx  # distinct, reproducible seed per keyframe
        # Note: the negative prompt, motion, and version feed the cache key only;
        # none of them is passed to router.generate_image below.
        payload = {
            "prompt": prompt,
            "negative": storyboard.negative_prompt,
            "width": width,
            "height": height,
            "steps": image_steps,
            "guidance_scale": guidance_scale,
            "seed": frame_seed,
            "motion": motion,
            "videogen_version": version,
            "image_options": image_options,
        }
        path = root / f"{cache_key(payload)}.png"

        if use_cache and path.exists():
            # Cache hit: load the stored PNG; the context manager closes the file.
            with Image.open(path) as cached:
                frame = cached.convert("RGB")
            cache_hits += 1
            route_meta: dict[str, Any] = {"cache_hit": True, "saved_path": str(path)}
        else:
            result = router.generate_image(
                prompt=prompt,
                height=height,
                width=width,
                steps=image_steps,
                guidance_scale=guidance_scale,
                seed=frame_seed,
                generation_strategy=generation_strategy,
                **image_kwargs,
            )
            frame = result.payload.images[0].convert("RGB")
            frame.save(path)
            # Tolerate results with no metadata attribute or metadata=None.
            route_meta = dict(getattr(result, "metadata", None) or {})
            route_meta["cache_hit"] = False
            route_meta["saved_path"] = str(path)

        frames.append(frame)
        paths.append(path)
        metadata.append(route_meta)

    return KeyframeResult(frames=frames, paths=paths, cache_hits=cache_hits, metadata=metadata)
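

# Sketch under stated assumptions, not part of the original pipeline:
# generate_keyframes writes each PNG straight to its final cache path, and a
# later run treats any existing file at that path as a cache hit. If a run is
# interrupted mid-save, a truncated PNG can therefore be "hit" and then fail to
# load. A common remedy is to write to a temporary file in the same directory
# and os.replace() it into place, which is atomic on POSIX filesystems when
# both paths are on the same filesystem. The helper name _atomic_write_bytes
# is hypothetical and is not called by generate_keyframes; to use it, one
# could encode the frame first, e.g.
#     buf = io.BytesIO(); frame.save(buf, format="PNG")
#     _atomic_write_bytes(path, buf.getvalue())
import os
import tempfile
from pathlib import Path


def _atomic_write_bytes(path: Path, data: bytes) -> None:
    """Write data so readers see either no file/the old file or the complete new one."""
    path = Path(path)
    # Create the temp file beside the target so os.replace stays on one filesystem.
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, prefix=path.name + ".", suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as handle:
            handle.write(data)
            handle.flush()
            # Push the bytes to disk before the rename makes them visible.
            os.fsync(handle.fileno())
        # Swap the finished file into its final location in one step.
        os.replace(tmp_name, path)
    except BaseException:
        # On any failure, remove the partial temp file, then re-raise.
        try:
            os.unlink(tmp_name)
        except FileNotFoundError:
            pass
        raise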