"""HDR tensor helpers: LogC3 range mapping, sRGB tonemapping, HDR10 encoder strings and EXR frame I/O."""

from __future__ import annotations

import os
from pathlib import Path
from typing import Iterable, Iterator

import numpy as np
import torch

HDR_REFERENCE_WHITE_NITS = 203.0
HDR10_MASTER_DISPLAY = "G(13250,34500)B(7500,3000)R(34000,16000)WP(15635,16450)L(10000000,1)"
HDR10_MAX_CLL = "10000,400"
VIDEO_PROMPT_HDR_OUTPUT_FLAG = "&"


def hdr10_zscale_filter(*, reference_white_nits: float = HDR_REFERENCE_WHITE_NITS) -> str:
    """Build the ``zscale``/``format`` filter string used for HDR10 output.

    Args:
        reference_white_nits: Value written into the ``npl=`` option, formatted
            with up to 12 significant digits.

    Returns:
        A filter chain string ending in ``format=yuv420p10le``.
    """
    return (
        "zscale=pin=709:tin=linear:min=gbr:rin=full:"
        f"p=2020:t=smpte2084:m=2020_ncl:r=limited:npl={float(reference_white_nits):.12g},"
        "format=yuv420p10le"
    )


def hdr10_x265_params() -> str:
    """Return the x265 parameter string built from the module's HDR10 constants."""
    return f"hdr10=1:repeat-headers=1:master-display={HDR10_MASTER_DISPLAY}:max-cll={HDR10_MAX_CLL}:log-level=none"


class LogC3:
    """Piecewise log/linear curve mapping non-negative linear values to ``[0, 1]``.

    NOTE(review): the constants look like the ARRI LogC3 (EI 800) parameters;
    confirm against the ARRI specification.
    """

    # Logarithmic segment: C * log10(A * x + B) + D
    A = 5.555556
    B = 0.052272
    C = 0.247190
    D = 0.385537
    # Linear segment: E * x + F
    E = 5.367655
    F = 0.092809
    # Linear-domain threshold; values >= CUT use the logarithmic segment.
    CUT = 0.010591

    def compress(self, hdr: torch.Tensor) -> torch.Tensor:
        """Encode linear values with the curve.

        Negative inputs are clamped to zero first and the encoded result is
        clamped to ``[0, 1]``. A new tensor is returned; ``hdr`` is not modified.
        """
        x = torch.clamp(hdr, min=0.0)
        log_part = self.C * torch.log10(self.A * x + self.B) + self.D
        lin_part = self.E * x + self.F
        return torch.where(x >= self.CUT, log_part, lin_part).clamp_(0.0, 1.0)

    def compress_ldr(self, ldr: torch.Tensor) -> torch.Tensor:
        """Clamp ``ldr`` to ``[0, 1]`` without applying the curve."""
        return torch.clamp(ldr, 0.0, 1.0)

    def decompress(self, logc: torch.Tensor) -> torch.Tensor:
        """Invert :meth:`compress`.

        The input is clamped to ``[0, 1]`` and the decoded result is clamped to
        be non-negative. A new tensor is returned; ``logc`` is not modified.
        """
        logc = torch.clamp(logc, 0.0, 1.0)
        # Encoded value at the segment boundary (the linear segment evaluated at CUT).
        cut_log = self.E * self.CUT + self.F
        lin_from_log = (torch.pow(10.0, (logc - self.D) / self.C) - self.B) / self.A
        lin_from_lin = (logc - self.F) / self.E
        return torch.where(logc >= cut_log, lin_from_log, lin_from_lin).clamp_(min=0.0)


def _require_supported_transform(transform: str) -> None:
    """Raise ``ValueError`` unless ``transform`` is ``"logc3"``."""
    if transform != "logc3":
        raise ValueError(f"Unsupported HDR transform: {transform}")


def _as_cfhw(video: torch.Tensor) -> torch.Tensor:
    """Drop a leading size-1 dimension from a 5-D tensor and require a 4-D result."""
    if video.ndim == 5 and video.shape[0] == 1:
        video = video[0]
    if video.ndim != 4:
        raise ValueError(f"Expected [C,F,H,W] HDR tensor, got {tuple(video.shape)}.")
    return video


def hdr_linear_to_vae_range(frames: torch.Tensor, *, transform: str = "logc3") -> torch.Tensor:
    """Encode linear HDR values with LogC3 and rescale from ``[0, 1]`` to ``[-1, 1]``.

    Args:
        frames: Linear values; converted to float32.
        transform: Only ``"logc3"`` is supported.

    Returns:
        A new float32 tensor in ``[-1, 1]``; ``frames`` is not modified.

    Raises:
        ValueError: If ``transform`` is not ``"logc3"``.
    """
    _require_supported_transform(transform)
    encoded = LogC3().compress(frames.to(dtype=torch.float32))
    # ``compress`` returns a freshly allocated tensor, so in-place ops are safe here.
    return encoded.mul_(2.0).sub_(1.0)


def vae_range_to_hdr_linear(frames: torch.Tensor, *, transform: str = "logc3") -> torch.Tensor:
    """Map values from ``[-1, 1]`` back to linear HDR via the LogC3 inverse.

    Args:
        frames: Values nominally in ``[-1, 1]``; out-of-range values are clamped
            after rescaling to ``[0, 1]``.
        transform: Only ``"logc3"`` is supported.

    Returns:
        A new non-negative float32 tensor; ``frames`` is not modified.

    Raises:
        ValueError: If ``transform`` is not ``"logc3"``.
    """
    # Validate before touching any data so a bad transform has no side effects.
    _require_supported_transform(transform)
    # ``Tensor.to`` returns the *same* tensor when it is already float32, so the
    # first arithmetic op must be out-of-place to avoid overwriting the caller's
    # input. The result of ``add`` is a fresh tensor, so the rest may be in-place.
    unit = frames.to(dtype=torch.float32).add(1.0).mul_(0.5).clamp_(0.0, 1.0)
    return LogC3().decompress(unit)


def linear_to_srgb(linear: torch.Tensor) -> torch.Tensor:
    """Apply the piecewise sRGB encoding to values clamped to ``[0, 1]``."""
    linear = torch.clamp(linear, 0.0, 1.0)
    low = linear * 12.92
    high = 1.055 * torch.pow(linear, 1.0 / 2.4) - 0.055
    return torch.where(linear <= 0.0031308, low, high).clamp_(0.0, 1.0)


def tonemap_hdr_tensor_to_uint8(video: torch.Tensor, *, exposure: float = 0.0) -> torch.Tensor:
    """Scale by ``2 ** exposure``, sRGB-encode (clipping to ``[0, 1]``) and quantize to uint8.

    Args:
        video: 4-D tensor, or 5-D with a leading dimension of size 1.
        exposure: Exposure offset in stops.

    Returns:
        A uint8 tensor with the same 4-D shape as the (squeezed) input.

    Raises:
        ValueError: If the tensor is not 4-D after the optional squeeze.
    """
    video = _as_cfhw(video)
    scale = float(2.0 ** float(exposure))
    srgb = linear_to_srgb(video.to(dtype=torch.float32).mul(scale))
    return srgb.mul(255.0).round_().clamp_(0.0, 255.0).to(torch.uint8)


def iter_video_chunks(video: torch.Tensor | Iterable[torch.Tensor]) -> Iterator[torch.Tensor]:
    """Yield ``video`` itself if it is a tensor, else each non-``None`` item of the iterable."""
    if torch.is_tensor(video):
        yield video
        return
    for chunk in video:
        if chunk is not None:
            yield chunk


def iter_hdr_gbrpf32_frames(video: torch.Tensor | Iterable[torch.Tensor]) -> Iterator[bytes]:
    """Yield one raw float32 byte string per frame, with channels reordered to ``[1, 2, 0]``.

    Each chunk is indexed as ``[C, F, H, W]`` (a leading size-1 dimension is
    dropped). Presumably the input channels are R, G, B, which makes the emitted
    plane order G, B, R as the function name suggests; confirm against callers.

    Raises:
        ValueError: If a chunk is not 4-D after the optional squeeze.
    """
    # ``iter_video_chunks`` already filters out ``None`` entries.
    for chunk in iter_video_chunks(video):
        chunk = _as_cfhw(chunk)
        frames = chunk.detach().cpu().to(dtype=torch.float32)
        # Iterate over the frame axis: [C, F, H, W] -> [F, C, H, W].
        for frame in frames.permute(1, 0, 2, 3):
            yield frame[[1, 2, 0]].contiguous().numpy().astype(np.float32, copy=False).tobytes()


def write_hdr_exr_frames(
    video: torch.Tensor,
    output_dir: str | os.PathLike[str],
    *,
    start_index: int = 0,
    exr_half: bool = True,
) -> int:
    """Write each frame of ``video`` to ``output_dir`` as ``frame_{index:06d}.exr``.

    Args:
        video: 4-D tensor indexed as ``[C, F, H, W]``, or 5-D with a leading
            dimension of size 1.
        output_dir: Destination directory; created if missing.
        start_index: Index used in the first file name.
        exr_half: Request half-float EXR output when the installed OpenCV
            exposes the corresponding flags.

    Returns:
        The number of frames in ``video``.

    Raises:
        ValueError: If the tensor is not 4-D after the optional squeeze.
        RuntimeError: If OpenCV reports a failed write.
    """
    # Must be set before cv2 is imported for the first time to take effect.
    os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "1")
    import cv2

    video = _as_cfhw(video)
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    frame_count = int(video.shape[1])

    params: list[int] = []
    if exr_half and hasattr(cv2, "IMWRITE_EXR_TYPE") and hasattr(cv2, "IMWRITE_EXR_TYPE_HALF"):
        params = [int(cv2.IMWRITE_EXR_TYPE), int(cv2.IMWRITE_EXR_TYPE_HALF)]

    # [C, F, H, W] -> [F, H, W, C] so each frame is a channels-last image.
    frames = video.detach().cpu().to(dtype=torch.float32).permute(1, 2, 3, 0).contiguous()
    for idx, frame in enumerate(frames, start=int(start_index)):
        rgb = frame.numpy().astype(np.float32, copy=False)
        # Reverse the channel axis for OpenCV (RGB -> BGR, going by the local names).
        bgr = np.ascontiguousarray(rgb[..., ::-1])
        path = os.path.join(os.fspath(output_dir), f"frame_{idx:06d}.exr")
        if not cv2.imwrite(path, bgr, params):
            raise RuntimeError(f"Failed to write HDR EXR frame: {path}")
    return frame_count


def read_hdr_exr_frames(
    output_dir: str | os.PathLike[str],
    *,
    start_index: int,
    frame_count: int,
) -> torch.Tensor | None:
    """Read ``frame_count`` consecutive EXR frames starting at ``start_index``.

    Returns:
        A float32 tensor built by stacking the frames and moving the channel
        axis first (``[C, F, H, W]`` for 3-D images), or ``None`` if any file is
        missing, fails to decode, or ``frame_count`` yields no frames.
    """
    # Must be set before cv2 is imported for the first time to take effect.
    os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "1")
    import cv2

    first = int(start_index)
    frames = []
    for idx in range(first, first + int(frame_count)):
        path = os.path.join(os.fspath(output_dir), f"frame_{idx:06d}.exr")
        if not os.path.isfile(path):
            return None
        bgr = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if bgr is None:
            return None
        # Reverse the channel axis back (BGR -> RGB, going by the local names).
        rgb = np.ascontiguousarray(bgr[..., ::-1]).astype(np.float32, copy=False)
        frames.append(torch.from_numpy(rgb))
    if not frames:
        return None
    # [F, H, W, C] -> [C, F, H, W]
    return torch.stack(frames, dim=0).permute(3, 0, 1, 2).contiguous()