|
|
|
|
|
import argparse |
|
|
import binascii |
|
|
import os |
|
|
import os.path as osp |
|
|
|
|
|
import imageio |
|
|
import torch |
|
|
import torchvision |
|
|
|
|
|
# Names exported by ``from <module> import *``; the rand_name helper is not
# part of this list.
__all__ = [
    "cache_video",
    "cache_image",
    "str2bool",
]
|
|
|
|
|
|
|
|
def rand_name(length=8, suffix=""):
    """Build a random hexadecimal file name, optionally with an extension.

    Args:
        length: Number of random bytes drawn from ``os.urandom``; the hex
            stem is therefore ``2 * length`` characters long.
        suffix: Optional extension. A leading "." is prepended when it is
            missing; a falsy suffix leaves the bare hex stem.

    Returns:
        str: The hex stem, followed by the dotted suffix when one was given.
    """
    stem = binascii.b2a_hex(os.urandom(length)).decode("utf-8")
    if not suffix:
        return stem
    dotted = suffix if suffix.startswith(".") else "." + suffix
    return stem + dotted
|
|
|
|
|
|
|
|
def cache_video(
    tensor,
    save_file=None,
    fps=30,
    suffix=".mp4",
    nrow=8,
    normalize=True,
    value_range=(-1, 1),
    retry=5,
):
    """Tile a tensor into per-step image grids and write them as a video.

    Each slice of ``tensor`` along dim 2 is turned into one image with
    ``torchvision.utils.make_grid``; the grids are stacked in order and
    written frame-by-frame with imageio using the ``libx264`` codec.

    Args:
        tensor: Input tensor. Frames are taken along dim 2 (``unbind(2)``)
            and each slice is passed to ``make_grid``, so presumably the
            layout is (batch, channels, time, height, width) -- TODO confirm
            against callers.
        save_file: Output path. When None, a random name ending in
            ``suffix`` is generated under ``/tmp``.
        fps: Frame rate passed to the imageio writer.
        suffix: Extension used only when ``save_file`` is None.
        nrow: Forwarded to ``make_grid``.
        normalize: Forwarded to ``make_grid``.
        value_range: (low, high) pair; the input is clamped to it and it is
            forwarded to ``make_grid``.
        retry: Number of attempts before giving up.

    Returns:
        The written file path, or None if every attempt raised (the last
        error is printed).
    """
    cache_file = (
        osp.join("/tmp", rand_name(suffix=suffix)) if save_file is None else save_file
    )

    error = None
    for _ in range(retry):
        try:
            # Build fresh locals from the caller's tensor on every attempt.
            # Rebinding ``tensor`` here would make a retry re-grid frames that
            # were already gridded, permuted and cast to uint8.
            clamped = tensor.clamp(min(value_range), max(value_range))
            grid = torch.stack(
                [
                    torchvision.utils.make_grid(
                        u, nrow=nrow, normalize=normalize, value_range=value_range
                    )
                    for u in clamped.unbind(2)
                ],
                dim=1,
            ).permute(1, 2, 3, 0)  # -> (steps, H, W, C), one frame per step
            # NOTE(review): truncating cast; values outside [0, 1] (e.g. when
            # normalize=False) will wrap in uint8 -- confirm callers normalize.
            frames = (grid * 255).type(torch.uint8).cpu()

            writer = imageio.get_writer(cache_file, fps=fps, codec="libx264", quality=8)
            try:
                for frame in frames.numpy():
                    writer.append_data(frame)
            finally:
                # Release the encoder even if a frame fails, so retries do not
                # leak open writers / file handles.
                writer.close()
            return cache_file
        except Exception as e:
            error = e

    print(f"cache_video failed, error: {error}", flush=True)
    return None
|
|
|
|
|
|
|
|
def cache_image(
    tensor, save_file, nrow=8, normalize=True, value_range=(-1, 1), retry=5
):
    """Save a tensor as an image file via ``torchvision.utils.save_image``.

    The tensor is clamped to ``value_range`` and then handed to
    ``save_image`` together with the grid/normalization options.

    Args:
        tensor: Image tensor (or batch of images) accepted by ``save_image``.
        save_file: Destination passed straight to ``save_image``.
        nrow: Forwarded to ``save_image``.
        normalize: Forwarded to ``save_image``.
        value_range: (low, high) pair; the input is clamped to it and it is
            forwarded to ``save_image``.
        retry: Number of attempts before giving up.

    Returns:
        ``save_file`` on success, or None if every attempt raised. The last
        error is printed, matching ``cache_video``'s failure reporting.
    """
    error = None
    for _ in range(retry):
        try:
            # Clamp into a local so the caller's tensor binding is untouched
            # across attempts.
            clamped = tensor.clamp(min(value_range), max(value_range))
            torchvision.utils.save_image(
                clamped,
                save_file,
                nrow=nrow,
                normalize=normalize,
                value_range=value_range,
            )
            return save_file
        except Exception as e:
            error = e

    # Previously the function fell off the end silently; report the failure.
    print(f"cache_image failed, error: {error}", flush=True)
    return None
|
|
|
|
|
|
|
|
def str2bool(v):
    """
    Convert a string to a boolean.

    Supported true values: 'yes', 'true', 't', 'y', '1'
    Supported false values: 'no', 'false', 'f', 'n', '0'

    Matching is case-insensitive and ignores surrounding whitespace. A value
    that is already a ``bool`` is returned unchanged. Any other non-string
    value (e.g. the int ``1`` or ``None``) is converted with ``str()`` first,
    so it either maps to a bool or raises the documented
    ``ArgumentTypeError`` rather than an ``AttributeError`` from ``.lower()``.

    Args:
        v (str | bool): Value to convert.

    Returns:
        bool: Converted boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value cannot be converted to boolean.
    """
    if isinstance(v, bool):
        return v
    v_lower = str(v).strip().lower()
    if v_lower in ("yes", "true", "t", "y", "1"):
        return True
    if v_lower in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected (True/False)")
|
|
|