| |
| import argparse |
| import binascii |
| import os |
| import os.path as osp |
|
|
| import imageio |
| import torch |
| import torchvision |
|
|
| __all__ = ['cache_video', 'cache_image', 'str2bool'] |
|
|
|
|
def rand_name(length=8, suffix=''):
    """Build a random lowercase-hex name with an optional file suffix.

    Args:
        length: Number of random bytes read from ``os.urandom``. Each byte
            becomes two hex characters in the result.
        suffix: Optional extension to append. A leading ``.`` is inserted
            when the given suffix does not already start with one.

    Returns:
        str: The hex string, followed by the normalized suffix if given.
    """
    random_bytes = os.urandom(length)
    stem = binascii.b2a_hex(random_bytes).decode('utf-8')

    # Nothing to append: the bare hex string is the whole name.
    if not suffix:
        return stem

    extension = suffix if suffix.startswith('.') else '.' + suffix
    return stem + extension
|
|
|
|
def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Encode a video tensor to a file, retrying on failure.

    Each slice of ``tensor`` along dim 2 is tiled into one image with
    ``torchvision.utils.make_grid``; the resulting frames are scaled by 255,
    cast to uint8 and written with imageio using the ``libx264`` codec.

    Args:
        tensor: Video batch. Frames are taken along dim 2, so this is
            presumably laid out as (batch, channels, frames, height, width)
            -- confirm against callers.
        save_file: Output path. When ``None``, a random name ending in
            ``suffix`` is generated under ``/tmp``.
        fps: Frame rate passed to the imageio writer.
        suffix: Extension used only when ``save_file`` is ``None``.
        nrow: Forwarded to ``make_grid``.
        normalize: Forwarded to ``make_grid``.
        value_range: Pair of bounds; the tensor is clamped to
            ``[min(value_range), max(value_range)]`` and the pair is also
            forwarded to ``make_grid``.
        retry: Maximum number of attempts.

    Returns:
        The path written to on success, or ``None`` if every attempt failed
        (the last error is printed in that case).
    """
    if save_file is None:
        cache_file = osp.join('/tmp', rand_name(suffix=suffix))
    else:
        cache_file = save_file

    # Processed frames are kept under their own name so the caller's tensor
    # is never rebound: a retry after a failed write must not push already
    # converted uint8 frames through clamp/make_grid a second time.
    frames = None
    error = None
    for _ in range(retry):
        try:
            if frames is None:
                clamped = tensor.clamp(min(value_range), max(value_range))
                grids = torch.stack(
                    [
                        torchvision.utils.make_grid(
                            u,
                            nrow=nrow,
                            normalize=normalize,
                            value_range=value_range)
                        for u in clamped.unbind(2)
                    ],
                    dim=1).permute(1, 2, 3, 0)
                frames = (grids * 255).type(torch.uint8).cpu().numpy()

            writer = imageio.get_writer(
                cache_file, fps=fps, codec='libx264', quality=8)
            try:
                for frame in frames:
                    writer.append_data(frame)
            finally:
                # Always release the encoder, even if a frame write fails.
                writer.close()
            return cache_file
        except Exception as e:
            # Best-effort: remember the failure and try again.
            error = e

    print(f'cache_video failed, error: {error}', flush=True)
    return None
|
|
|
|
def cache_image(tensor,
                save_file,
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Save an image tensor to a file, retrying on failure.

    The tensor is clamped to ``value_range`` and handed to
    ``torchvision.utils.save_image``. No explicit format is passed, so the
    format selection is left to ``save_image``.

    Args:
        tensor: Image batch accepted by ``torchvision.utils.save_image``
            (presumably (batch, channels, height, width) -- confirm against
            callers).
        save_file: Destination passed straight through to ``save_image``.
        nrow: Forwarded to ``save_image``.
        normalize: Forwarded to ``save_image``.
        value_range: Pair of bounds; the tensor is clamped to
            ``[min(value_range), max(value_range)]`` and the pair is also
            forwarded to ``save_image``.
        retry: Maximum number of attempts.

    Returns:
        ``save_file`` on success, or ``None`` if every attempt failed (the
        last error is printed in that case).
    """
    error = None
    for _ in range(retry):
        try:
            # Clamp into a separate name so the caller's tensor is not rebound.
            clamped = tensor.clamp(min(value_range), max(value_range))
            torchvision.utils.save_image(
                clamped,
                save_file,
                nrow=nrow,
                normalize=normalize,
                value_range=value_range)
            return save_file
        except Exception as e:
            # Best-effort: remember the failure and try again.
            error = e

    # Report the last error instead of failing silently.
    print(f'cache_image failed, error: {error}', flush=True)
    return None
|
|
|
|
def str2bool(v):
    """
    Convert a string to a boolean.

    Supported true values: 'yes', 'true', 't', 'y', '1'
    Supported false values: 'no', 'false', 'f', 'n', '0'

    Matching is case-insensitive and ignores surrounding whitespace.
    A value that is already a ``bool`` is returned unchanged.

    Args:
        v (str | bool): Value to convert.

    Returns:
        bool: Converted boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value cannot be converted to
            boolean, including values that are neither ``str`` nor ``bool``.
    """
    if isinstance(v, bool):
        return v
    # Reject non-strings with the documented exception type rather than
    # letting ``.lower()`` fail with AttributeError.
    if not isinstance(v, str):
        raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
    v_lower = v.strip().lower()
    if v_lower in ('yes', 'true', 't', 'y', '1'):
        return True
    if v_lower in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
|
|