Phillnet-2 / VideoGen /keyframes.py
ayjays132's picture
Upload 470 files
ad2ce18 verified
from __future__ import annotations
import hashlib
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from PIL import Image
from .storyboard import VideoStoryboard
@dataclass
class KeyframeResult:
    """Bundle of everything produced by one keyframe generation pass.

    ``frames``, ``paths`` and ``metadata`` are parallel lists: entry *i* of
    each one describes the same keyframe.
    """

    # RGB images, one per keyframe prompt.
    frames: list[Image.Image]
    # Location of the PNG backing each frame inside the cache directory.
    paths: list[Path]
    # How many frames were read back from disk instead of being generated.
    cache_hits: int
    # Per-frame details; each dict carries at least "cache_hit" and "saved_path".
    metadata: list[dict[str, Any]]
def cache_key(payload: dict[str, Any]) -> str:
    """Derive a short, stable identifier for a JSON-serialisable payload.

    The payload is serialised with sorted keys and ASCII-only escapes so that
    dict insertion order does not influence the result, then hashed with
    SHA-256. Only the first 24 hex characters of the digest are returned.
    """
    canonical = json.dumps(payload, ensure_ascii=True, sort_keys=True)
    hasher = hashlib.sha256()
    hasher.update(canonical.encode("utf-8"))
    full_digest = hasher.hexdigest()
    return full_digest[:24]
def _json_safe(value: Any) -> Any:
if isinstance(value, (str, int, float, bool)) or value is None:
return value
if isinstance(value, Path):
return str(value)
if isinstance(value, (list, tuple)):
return [_json_safe(item) for item in value]
if isinstance(value, dict):
return {str(key): _json_safe(item) for key, item in sorted(value.items(), key=lambda pair: str(pair[0]))}
return repr(value)
def _load_cached_frame(path: Path) -> Image.Image | None:
    """Return the cached image at *path* as RGB, or ``None`` if it is unusable."""
    try:
        # The context manager releases the file handle; convert() returns an
        # independent, fully loaded copy that outlives the opened image.
        with Image.open(path) as cached:
            return cached.convert("RGB")
    except OSError:
        # Missing, truncated, or undecodable file: report a miss so the caller
        # regenerates the frame and overwrites the bad entry.
        return None


def generate_keyframes(
    router: Any,
    storyboard: VideoStoryboard,
    *,
    cache_dir: str | Path,
    width: int,
    height: int,
    image_steps: int,
    guidance_scale: float,
    seed: int,
    motion: str,
    version: str,
    use_cache: bool = True,
    **kwargs: Any,
) -> KeyframeResult:
    """Produce one RGB keyframe per storyboard prompt, using an on-disk cache.

    For every entry of ``storyboard.keyframe_prompts`` a cache key is derived
    from the prompt, the negative prompt, the image settings, ``motion``,
    ``version`` and the image options. The frame is stored as
    ``<cache_dir>/<key>.png``.

    Args:
        router: Object exposing ``generate_image(...)``. Its result must
            provide ``payload.images`` and may provide ``metadata``.
        storyboard: Supplies ``keyframe_prompts`` and ``negative_prompt``.
        cache_dir: Directory for cached PNGs; created if missing.
        width: Image width forwarded to the router.
        height: Image height forwarded to the router.
        image_steps: Forwarded to the router as ``steps``.
        guidance_scale: Forwarded to the router.
        seed: Base seed; frame *i* uses ``seed + i``.
        motion: Only contributes to the cache key.
        version: Only contributes to the cache key (``videogen_version``).
        use_cache: When true, a readable cached PNG is reused instead of
            calling the router. Generated frames are written to the cache
            either way.
        **kwargs: Extra options forwarded to ``router.generate_image``.
            ``generation_strategy`` defaults to ``"diffusion"``;
            ``use_memory``, ``reference_pass_steps`` and ``unload_after_call``
            default to ``False``, ``0`` and ``False`` unless overridden.

    Returns:
        A ``KeyframeResult`` whose ``frames``, ``paths`` and ``metadata``
        lists hold one entry per prompt, plus the number of cache hits.
    """
    root = Path(cache_dir)
    root.mkdir(parents=True, exist_ok=True)
    generation_strategy = kwargs.pop("generation_strategy", "diffusion")

    # These do not depend on the frame index, so build them once.
    # Caller-supplied kwargs take precedence over the defaults.
    image_kwargs = {
        "use_memory": False,
        "reference_pass_steps": 0,
        "unload_after_call": False,
        **kwargs,
    }
    # cache_key() serialises with sorted keys, so no pre-sorting is needed.
    image_options = {
        "generation_strategy": generation_strategy,
        **{key: _json_safe(value) for key, value in image_kwargs.items()},
    }

    frames: list[Image.Image] = []
    paths: list[Path] = []
    metadata: list[dict[str, Any]] = []
    cache_hits = 0

    for idx, prompt in enumerate(storyboard.keyframe_prompts):
        frame_seed = seed + idx
        # NOTE(review): the negative prompt is part of the cache key but is
        # not forwarded to router.generate_image — confirm this is intended.
        payload = {
            "prompt": prompt,
            "negative": storyboard.negative_prompt,
            "width": width,
            "height": height,
            "steps": image_steps,
            "guidance_scale": guidance_scale,
            "seed": frame_seed,
            "motion": motion,
            "videogen_version": version,
            "image_options": image_options,
        }
        path = root / f"{cache_key(payload)}.png"

        frame = _load_cached_frame(path) if use_cache else None
        if frame is not None:
            cache_hits += 1
            route_meta: dict[str, Any] = {"cache_hit": True, "saved_path": str(path)}
        else:
            result = router.generate_image(
                prompt=prompt,
                height=height,
                width=width,
                steps=image_steps,
                guidance_scale=guidance_scale,
                seed=frame_seed,
                generation_strategy=generation_strategy,
                **image_kwargs,
            )
            frame = result.payload.images[0].convert("RGB")
            frame.save(path)
            # Tolerate results whose metadata is absent or explicitly None.
            route_meta = dict(getattr(result, "metadata", None) or {})
            route_meta["cache_hit"] = False
            route_meta["saved_path"] = str(path)

        frames.append(frame)
        paths.append(path)
        metadata.append(route_meta)

    return KeyframeResult(frames=frames, paths=paths, cache_hits=cache_hits, metadata=metadata)