| | |
| | import argparse |
| | import binascii |
| | import os |
| | import os.path as osp |
| |
|
| | import imageio |
| | import torch |
| | import torchvision |
| |
|
# Names exported by ``from <module> import *``; ``rand_name`` is not exported.
__all__ = [
    'cache_video',
    'cache_image',
    'str2bool',
]
| |
|
| |
|
def rand_name(length=8, suffix=''):
    """Return a random hexadecimal file name.

    Args:
        length (int): Number of random bytes drawn from ``os.urandom``; the
            hex part of the name is therefore ``2 * length`` characters.
        suffix (str): Optional extension appended to the name. A leading
            dot is inserted when missing, so ``'mp4'`` and ``'.mp4'`` give
            the same result. A falsy suffix appends nothing.

    Returns:
        str: The generated name, e.g. ``'3f9a0c1d5e7b2a44.mp4'``.
    """
    stem = binascii.hexlify(os.urandom(length)).decode('utf-8')
    if not suffix:
        return stem
    dot = '' if suffix.startswith('.') else '.'
    return stem + dot + suffix
| |
|
| |
|
def _video_frames(tensor, nrow, normalize, value_range):
    """Convert ``tensor`` into a CPU uint8 tensor whose leading dim indexes frames.

    The input is clamped to ``value_range``. Each slice along dim 2 is tiled
    into one image grid with ``torchvision.utils.make_grid``. The grids are
    stacked along dim 1 and then permuted so that the slice index leads and
    the grid channels trail.
    """
    clamped = tensor.clamp(min(value_range), max(value_range))
    grids = [
        torchvision.utils.make_grid(
            u, nrow=nrow, normalize=normalize, value_range=value_range)
        for u in clamped.unbind(2)
    ]
    frames = torch.stack(grids, dim=1).permute(1, 2, 3, 0)
    # NOTE(review): this assumes the grids are in [0, 1] (normalize=True).
    # With normalize=False and a negative value_range, the uint8 cast would
    # wrap negative values. Confirm what callers pass.
    return (frames * 255).type(torch.uint8).cpu()


def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Render a tensor as a video file and return the file path.

    Every slice of ``tensor`` along dim 2 becomes one video frame. The
    tensor is clamped to ``value_range``, each slice is tiled into a grid
    with ``torchvision.utils.make_grid``, the result is scaled to
    ``uint8``, and the frames are written with imageio using the
    ``libx264`` codec.

    Args:
        tensor (torch.Tensor): Video data; dim 2 is iterated as the frame
            axis. Presumably shaped ``(B, C, T, H, W)``; TODO confirm
            against callers.
        save_file (str, optional): Output path. When ``None``, a random
            name ending in ``suffix`` is generated under ``/tmp``.
        fps (int): Frame rate passed to the imageio writer.
        suffix (str): Extension for the generated name; unused when
            ``save_file`` is given.
        nrow (int): Images per grid row, forwarded to ``make_grid``.
        normalize (bool): Forwarded to ``make_grid``.
        value_range (tuple): ``(low, high)`` clamp bounds (order does not
            matter), also forwarded to ``make_grid``.
        retry (int): Maximum number of attempts. Any exception raised during
            an attempt is caught and the attempt is repeated.

    Returns:
        str | None: The written path, or ``None`` if every attempt failed.
        In that case the last error is printed to stdout.
    """
    cache_file = osp.join('/tmp', rand_name(
        suffix=suffix)) if save_file is None else save_file

    # Keep converted frames under their own name. A retry after a write
    # failure then reuses them instead of re-processing an already
    # converted tensor.
    frames = None
    error = None
    for _ in range(retry):
        try:
            if frames is None:
                frames = _video_frames(tensor, nrow, normalize, value_range)
            writer = imageio.get_writer(
                cache_file, fps=fps, codec='libx264', quality=8)
            try:
                for frame in frames.numpy():
                    writer.append_data(frame)
            finally:
                # Release the encoder even when a frame fails to write.
                writer.close()
            return cache_file
        except Exception as e:
            error = e
    print(f'cache_video failed, error: {error}', flush=True)
    return None
| |
|
| |
|
def cache_image(tensor,
                save_file,
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Save a tensor as an image grid and return the destination.

    The tensor is clamped to ``value_range`` and written with
    ``torchvision.utils.save_image``.

    Args:
        tensor (torch.Tensor): Image data accepted by ``save_image``.
        save_file (str): Destination passed to ``save_image``.
        nrow (int): Images per grid row, forwarded to ``save_image``.
        normalize (bool): Forwarded to ``save_image``.
        value_range (tuple): ``(low, high)`` clamp bounds (order does not
            matter), also forwarded to ``save_image``.
        retry (int): Maximum number of attempts. Any exception raised during
            an attempt is caught and the attempt is repeated.

    Returns:
        str | None: ``save_file`` on success, or ``None`` if every attempt
        failed. In that case the last error is printed to stdout.
    """
    error = None
    for _ in range(retry):
        try:
            clamped = tensor.clamp(min(value_range), max(value_range))
            torchvision.utils.save_image(
                clamped,
                save_file,
                nrow=nrow,
                normalize=normalize,
                value_range=value_range)
            return save_file
        except Exception as e:
            error = e
    # Report the last error instead of failing silently.
    print(f'cache_image failed, error: {error}', flush=True)
    return None
| |
|
| |
|
def str2bool(v):
    """
    Convert a string to a boolean.

    Supported true values: 'yes', 'true', 't', 'y', '1'
    Supported false values: 'no', 'false', 'f', 'n', '0'

    Matching is case-insensitive and ignores surrounding whitespace.
    ``bool`` inputs are returned unchanged. Other non-string inputs (for
    example the integers 0 and 1) are converted with ``str`` before
    matching.

    Args:
        v (str | bool): Value to convert, typically a command-line argument.

    Returns:
        bool: Converted boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value cannot be converted to boolean.
    """
    if isinstance(v, bool):
        return v
    # str() keeps non-string inputs (None, ints, ...) on the documented
    # ArgumentTypeError path instead of failing with AttributeError.
    v_lower = str(v).strip().lower()
    if v_lower in ('yes', 'true', 't', 'y', '1'):
        return True
    if v_lower in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
| |
|