# File size: 5,760 Bytes · 7344bef
from __future__ import annotations
import os
from pathlib import Path
from typing import Iterable
import numpy as np
import torch
# Default for the `reference_white_nits` argument of hdr10_zscale_filter,
# where it is formatted into the zscale `npl=` option.
HDR_REFERENCE_WHITE_NITS = 203.0
# Value inserted after `master-display=` by hdr10_x265_params.
# NOTE(review): the chromaticities look like P3-D65 primaries with a D65 white
# point and L(max,min) in 0.0001 cd/m^2 units — confirm against the x265 docs.
HDR10_MASTER_DISPLAY = "G(13250,34500)B(7500,3000)R(34000,16000)WP(15635,16450)L(10000000,1)"
# Value inserted after `max-cll=` by hdr10_x265_params.
# NOTE(review): presumably "<MaxCLL>,<MaxFALL>" in nits — confirm.
HDR10_MAX_CLL = "10000,400"
# Not referenced by any definition visible in this file.
# NOTE(review): presumably a marker character that requests HDR output when
# found in a video prompt — verify against the callers.
VIDEO_PROMPT_HDR_OUTPUT_FLAG = "&"
def hdr10_zscale_filter(*, reference_white_nits: float = HDR_REFERENCE_WHITE_NITS) -> str:
    """Return the ``zscale``/``format`` filter chain used for HDR10 output.

    The chain describes the input as BT.709 primaries, linear transfer, GBR
    matrix, full range, and requests BT.2020 primaries, SMPTE 2084 transfer,
    BT.2020 non-constant-luminance matrix and limited range, followed by a
    conversion to ``yuv420p10le``.

    Args:
        reference_white_nits: Value written to the ``npl=`` option. It is
            coerced to ``float`` and formatted with up to 12 significant
            digits.

    Returns:
        The filter description as a single string.
        NOTE(review): presumably passed to FFmpeg as a video filter — confirm
        against the caller.
    """
    npl = f"{float(reference_white_nits):.12g}"
    source_side = "pin=709:tin=linear:min=gbr:rin=full"
    target_side = f"p=2020:t=smpte2084:m=2020_ncl:r=limited:npl={npl}"
    return f"zscale={source_side}:{target_side},format=yuv420p10le"
def hdr10_x265_params() -> str:
    """Return the colon-separated x265 option string for HDR10 encoding.

    The string enables ``hdr10`` and ``repeat-headers``, embeds the module's
    ``HDR10_MASTER_DISPLAY`` and ``HDR10_MAX_CLL`` values, and sets
    ``log-level=none``.
    """
    options = (
        "hdr10=1",
        "repeat-headers=1",
        f"master-display={HDR10_MASTER_DISPLAY}",
        f"max-cll={HDR10_MAX_CLL}",
        "log-level=none",
    )
    return ":".join(options)
class LogC3:
    """Piecewise log/linear transfer curve between linear light and a [0, 1] code.

    Above ``CUT`` the encoding is ``C * log10(A * x + B) + D``; below it the
    encoding is the straight line ``E * x + F``.

    NOTE(review): the coefficients match the published ARRI LogC3 (EI 800)
    parameters — confirm that this is the intended variant.
    """

    # Coefficients of the logarithmic segment.
    A = 5.555556
    B = 0.052272
    C = 0.247190
    D = 0.385537
    # Slope and offset of the linear segment.
    E = 5.367655
    F = 0.092809
    # Linear-light value at which the two segments meet.
    CUT = 0.010591

    def compress(self, hdr: torch.Tensor) -> torch.Tensor:
        """Encode linear values to the log curve, clamped to [0, 1].

        Negative inputs are treated as 0. The input tensor is not modified.
        """
        linear = torch.clamp(hdr, min=0.0)
        encoded = torch.where(
            linear >= self.CUT,
            self.C * torch.log10(self.A * linear + self.B) + self.D,
            self.E * linear + self.F,
        )
        # `encoded` is a fresh tensor, so the in-place clamp is safe.
        return encoded.clamp_(0.0, 1.0)

    def compress_ldr(self, ldr: torch.Tensor) -> torch.Tensor:
        """Clamp an LDR tensor to [0, 1]; no curve is applied."""
        return torch.clamp(ldr, 0.0, 1.0)

    def decompress(self, logc: torch.Tensor) -> torch.Tensor:
        """Decode a [0, 1] log code back to non-negative linear values.

        The input is clamped to [0, 1] first and is not modified.
        """
        code = torch.clamp(logc, 0.0, 1.0)
        # Encoded value at the breakpoint, computed from the linear segment.
        threshold = self.E * self.CUT + self.F
        exponent = (code - self.D) / self.C
        from_log_segment = (torch.pow(10.0, exponent) - self.B) / self.A
        from_linear_segment = (code - self.F) / self.E
        linear = torch.where(code >= threshold, from_log_segment, from_linear_segment)
        return linear.clamp_(min=0.0)
def hdr_linear_to_vae_range(frames: torch.Tensor, *, transform: str = "logc3") -> torch.Tensor:
    """Encode linear HDR frames into the [-1, 1] range.

    The frames are cast to float32, compressed to [0, 1] with ``LogC3`` and
    then rescaled as ``2 * x - 1``.

    Args:
        frames: Linear HDR values.
        transform: Name of the curve; only ``"logc3"`` is accepted.

    Returns:
        A float32 tensor with values in [-1, 1].

    Raises:
        ValueError: If ``transform`` is not ``"logc3"``.
    """
    as_float = frames.to(dtype=torch.float32)
    if transform == "logc3":
        encoded = LogC3().compress(as_float)
        # `encoded` is freshly allocated by compress(), so in-place ops are safe.
        return encoded.mul_(2.0).sub_(1.0)
    raise ValueError(f"Unsupported HDR transform: {transform}")
def vae_range_to_hdr_linear(frames: torch.Tensor, *, transform: str = "logc3") -> torch.Tensor:
    """Decode frames from the [-1, 1] range back to linear HDR values.

    The frames are mapped to [0, 1] via ``(x + 1) / 2`` (clamped) and then
    decoded with ``LogC3.decompress``. The input tensor is never modified.

    Args:
        frames: Encoded values, nominally in [-1, 1].
        transform: Name of the curve; only ``"logc3"`` is accepted.

    Returns:
        A float32 tensor of non-negative linear values.

    Raises:
        ValueError: If ``transform`` is not ``"logc3"``.
    """
    # Validate before doing any work so a rejected call has no side effects.
    if transform != "logc3":
        raise ValueError(f"Unsupported HDR transform: {transform}")
    # Tensor.to() returns the *same* tensor when the dtype already matches, so
    # the first arithmetic step must be out-of-place; otherwise a float32 input
    # would be overwritten. The later in-place ops act on the fresh copy.
    unit_range = frames.to(dtype=torch.float32).add(1.0).mul_(0.5).clamp_(0.0, 1.0)
    return LogC3().decompress(unit_range)
def linear_to_srgb(linear: torch.Tensor) -> torch.Tensor:
    """Apply the sRGB transfer curve to linear values.

    The input is clamped to [0, 1]. Values at or below 0.0031308 use the
    linear segment ``12.92 * x``; the rest use ``1.055 * x ** (1 / 2.4) - 0.055``.
    The input tensor is not modified.

    Returns:
        A tensor of sRGB-encoded values in [0, 1].
    """
    clipped = torch.clamp(linear, 0.0, 1.0)
    toe_segment = clipped * 12.92
    gamma_segment = 1.055 * torch.pow(clipped, 1.0 / 2.4) - 0.055
    in_toe = clipped <= 0.0031308
    return torch.where(in_toe, toe_segment, gamma_segment).clamp_(0.0, 1.0)
def tonemap_hdr_tensor_to_uint8(video: torch.Tensor, *, exposure: float = 0.0) -> torch.Tensor:
    """Convert a linear HDR video tensor to 8-bit sRGB.

    A leading batch dimension of size 1 is dropped. The values are scaled by
    ``2 ** exposure``, passed through ``linear_to_srgb`` (which clamps to
    [0, 1]), multiplied by 255, rounded and cast to ``uint8``.

    Args:
        video: 4-D tensor, or 5-D tensor whose first dimension is 1.
        exposure: Exposure adjustment in stops (power of two gain).

    Returns:
        A ``uint8`` tensor with the same 4-D shape as the (unbatched) input.

    Raises:
        ValueError: If the tensor is not 4-D after removing the batch dimension.
    """
    has_unit_batch = video.ndim == 5 and video.shape[0] == 1
    if has_unit_batch:
        video = video[0]
    if video.ndim != 4:
        raise ValueError(f"Expected [C,F,H,W] HDR tensor, got {tuple(video.shape)}.")

    gain = float(2.0 ** float(exposure))
    exposed = video.to(dtype=torch.float32).mul(gain)
    display = linear_to_srgb(exposed)
    quantized = display.mul(255.0).round_().clamp_(0.0, 255.0)
    return quantized.to(torch.uint8)
def iter_video_chunks(video: torch.Tensor | Iterable[torch.Tensor]):
if torch.is_tensor(video):
yield video
return
for chunk in video:
if chunk is not None:
yield chunk
def iter_hdr_gbrpf32_frames(video: torch.Tensor | Iterable[torch.Tensor]):
    """Yield each frame of ``video`` as raw float32 bytes with planes reordered.

    Every chunk produced by ``iter_video_chunks`` must be 4-D (or 5-D with a
    leading dimension of 1, which is dropped). Dimension 1 is iterated as the
    frame axis, and channels 1, 2, 0 of dimension 0 are emitted in that order.

    NOTE(review): assumes the channel order is R, G, B so the output is planar
    G, B, R (as the ``gbrpf32`` name suggests) — confirm against the consumer.

    Raises:
        ValueError: If a chunk is not 4-D after removing the batch dimension.
    """
    plane_order = [1, 2, 0]
    for chunk in iter_video_chunks(video):
        if chunk is None:
            continue
        if chunk.ndim == 5 and chunk.shape[0] == 1:
            chunk = chunk[0]
        if chunk.ndim != 4:
            raise ValueError(f"Expected [C,F,H,W] HDR tensor, got {tuple(chunk.shape)}.")
        host_chunk = chunk.detach().cpu().to(dtype=torch.float32)
        # unbind(dim=1) walks the frame axis, giving one [C, H, W] view per frame.
        for frame in host_chunk.unbind(dim=1):
            planes = frame[plane_order].contiguous().numpy()
            yield planes.astype(np.float32, copy=False).tobytes()
def write_hdr_exr_frames(
    video: torch.Tensor,
    output_dir: str | os.PathLike[str],
    *,
    start_index: int = 0,
    exr_half: bool = True,
) -> int:
    """Write every frame of ``video`` to ``output_dir`` as an EXR file.

    Files are named ``frame_{index:06d}.exr``, numbered from ``start_index``.
    The directory is created if needed. Channels are reversed along the last
    axis before writing (RGB -> BGR for OpenCV).

    Args:
        video: 4-D tensor, or 5-D tensor whose first dimension is 1.
        output_dir: Destination directory.
        start_index: Number used for the first frame's file name.
        exr_half: Request half-float EXR storage when this OpenCV build
            exposes the corresponding flags; otherwise no flags are passed.

    Returns:
        The number of frames (size of dimension 1 of the unbatched tensor).

    Raises:
        ValueError: If the tensor is not 4-D after removing the batch dimension.
        RuntimeError: If ``cv2.imwrite`` reports failure for a frame.
    """
    # Set before importing cv2; setdefault leaves an existing value untouched.
    os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "1")
    import cv2

    if video.ndim == 5 and video.shape[0] == 1:
        video = video[0]
    if video.ndim != 4:
        raise ValueError(f"Expected [C,F,H,W] HDR tensor, got {tuple(video.shape)}.")

    Path(output_dir).mkdir(parents=True, exist_ok=True)
    frame_count = int(video.shape[1])

    use_half = exr_half and hasattr(cv2, "IMWRITE_EXR_TYPE") and hasattr(cv2, "IMWRITE_EXR_TYPE_HALF")
    write_flags: list[int] = (
        [int(cv2.IMWRITE_EXR_TYPE), int(cv2.IMWRITE_EXR_TYPE_HALF)] if use_half else []
    )

    # Move the channel axis last so each frame is a height x width x channel image.
    channels_last = video.detach().cpu().to(dtype=torch.float32).permute(1, 2, 3, 0).contiguous()
    index = int(start_index)
    for frame in channels_last:
        rgb = frame.numpy().astype(np.float32, copy=False)
        bgr = np.ascontiguousarray(rgb[..., ::-1])
        path = os.path.join(os.fspath(output_dir), f"frame_{index:06d}.exr")
        if not cv2.imwrite(path, bgr, write_flags):
            raise RuntimeError(f"Failed to write HDR EXR frame: {path}")
        index += 1
    return frame_count
def read_hdr_exr_frames(
    output_dir: str | os.PathLike[str],
    *,
    start_index: int,
    frame_count: int,
) -> torch.Tensor | None:
    """Load a run of ``frame_{index:06d}.exr`` files into one tensor.

    Args:
        output_dir: Directory containing the EXR files.
        start_index: Number of the first frame to read.
        frame_count: How many consecutive frames to read.

    Returns:
        A float32 tensor with the channel axis first and the frame axis second
        (the last axis of each image reversed, BGR -> RGB), or ``None`` when
        any file is missing or unreadable, or when no frames were requested.
    """
    # Set before importing cv2; setdefault leaves an existing value untouched.
    os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "1")
    import cv2

    first = int(start_index)
    last_exclusive = first + int(frame_count)
    loaded = []
    index = first
    while index < last_exclusive:
        path = os.path.join(os.fspath(output_dir), f"frame_{index:06d}.exr")
        if not os.path.isfile(path):
            return None
        bgr = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if bgr is None:
            return None
        rgb = np.ascontiguousarray(bgr[..., ::-1]).astype(np.float32, copy=False)
        loaded.append(torch.from_numpy(rgb))
        index += 1

    if len(loaded) == 0:
        return None
    stacked = torch.stack(loaded, dim=0)
    # [frame, height, width, channel] -> [channel, frame, height, width]
    return stacked.permute(3, 0, 1, 2).contiguous()
|