Spaces:
Runtime error
Runtime error
Update basicsr/utils/misc.py
Browse files — basicsr/utils/misc.py (+68, −93 lines)
basicsr/utils/misc.py
CHANGED
|
@@ -9,36 +9,58 @@ from os import path as osp
|
|
| 9 |
from .dist_util import master_only
|
| 10 |
from .logger import get_root_logger
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
def gpu_is_available():
|
| 16 |
-
if
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
return
|
|
|
|
| 20 |
|
| 21 |
def get_device(gpu_id=None):
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
gpu_str = f':{gpu_id}'
|
| 26 |
-
else:
|
| 27 |
-
raise TypeError('Input should be int value.')
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
def set_random_seed(seed):
|
| 36 |
"""Set random seeds."""
|
| 37 |
random.seed(seed)
|
| 38 |
np.random.seed(seed)
|
| 39 |
torch.manual_seed(seed)
|
| 40 |
-
torch.cuda.
|
| 41 |
-
|
|
|
|
| 42 |
|
| 43 |
|
| 44 |
def get_time_str():
|
|
@@ -46,112 +68,65 @@ def get_time_str():
|
|
| 46 |
|
| 47 |
|
| 48 |
def mkdir_and_rename(path):
|
| 49 |
-
"""mkdirs. If path exists, rename it with timestamp and create a new one.
|
| 50 |
-
|
| 51 |
-
Args:
|
| 52 |
-
path (str): Folder path.
|
| 53 |
-
"""
|
| 54 |
if osp.exists(path):
|
| 55 |
new_name = path + '_archived_' + get_time_str()
|
| 56 |
-
print(f'Path already exists.
|
| 57 |
os.rename(path, new_name)
|
| 58 |
os.makedirs(path, exist_ok=True)
|
| 59 |
|
| 60 |
|
| 61 |
@master_only
|
| 62 |
def make_exp_dirs(opt):
|
| 63 |
-
"""Make dirs for experiments."""
|
| 64 |
path_opt = opt['path'].copy()
|
| 65 |
if opt['is_train']:
|
| 66 |
mkdir_and_rename(path_opt.pop('experiments_root'))
|
| 67 |
else:
|
| 68 |
mkdir_and_rename(path_opt.pop('results_root'))
|
|
|
|
| 69 |
for key, path in path_opt.items():
|
| 70 |
if ('strict_load' not in key) and ('pretrain_network' not in key) and ('resume' not in key):
|
| 71 |
os.makedirs(path, exist_ok=True)
|
| 72 |
|
| 73 |
|
| 74 |
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
|
| 75 |
-
"""Scan a directory to find the interested files.
|
| 76 |
-
|
| 77 |
-
Args:
|
| 78 |
-
dir_path (str): Path of the directory.
|
| 79 |
-
suffix (str | tuple(str), optional): File suffix that we are
|
| 80 |
-
interested in. Default: None.
|
| 81 |
-
recursive (bool, optional): If set to True, recursively scan the
|
| 82 |
-
directory. Default: False.
|
| 83 |
-
full_path (bool, optional): If set to True, include the dir_path.
|
| 84 |
-
Default: False.
|
| 85 |
-
|
| 86 |
-
Returns:
|
| 87 |
-
A generator for all the interested files with relative pathes.
|
| 88 |
-
"""
|
| 89 |
-
|
| 90 |
-
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
|
| 91 |
-
raise TypeError('"suffix" must be a string or tuple of strings')
|
| 92 |
-
|
| 93 |
root = dir_path
|
| 94 |
|
| 95 |
-
def
|
| 96 |
-
for entry in os.scandir(
|
| 97 |
-
if not entry.name.startswith('.')
|
| 98 |
-
if full_path
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
if suffix is None:
|
| 104 |
-
yield return_path
|
| 105 |
-
elif return_path.endswith(suffix):
|
| 106 |
-
yield return_path
|
| 107 |
-
else:
|
| 108 |
-
if recursive:
|
| 109 |
-
yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
|
| 110 |
-
else:
|
| 111 |
-
continue
|
| 112 |
|
| 113 |
-
return
|
| 114 |
|
| 115 |
|
| 116 |
def check_resume(opt, resume_iter):
|
| 117 |
-
"""Check resume states and pretrain_network paths.
|
| 118 |
-
|
| 119 |
-
Args:
|
| 120 |
-
opt (dict): Options.
|
| 121 |
-
resume_iter (int): Resume iteration.
|
| 122 |
-
"""
|
| 123 |
logger = get_root_logger()
|
|
|
|
| 124 |
if opt['path']['resume_state']:
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
for network in networks:
|
| 129 |
-
if opt['path'].get(f'pretrain_{network}') is not None:
|
| 130 |
-
flag_pretrain = True
|
| 131 |
if flag_pretrain:
|
| 132 |
logger.warning('pretrain_network path will be ignored during resuming.')
|
| 133 |
-
|
| 134 |
for network in networks:
|
| 135 |
-
name = f'pretrain_{network}'
|
| 136 |
basename = network.replace('network_', '')
|
| 137 |
-
if opt['path'].get('ignore_resume_networks') is None or (
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
|
| 143 |
def sizeof_fmt(size, suffix='B'):
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
size
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
Return:
|
| 151 |
-
str: Formated file siz.
|
| 152 |
-
"""
|
| 153 |
-
for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
|
| 154 |
-
if abs(size) < 1024.0:
|
| 155 |
-
return f'{size:3.1f} {unit}{suffix}'
|
| 156 |
-
size /= 1024.0
|
| 157 |
-
return f'{size:3.1f} Y{suffix}'
|
|
|
|
| 9 |
from .dist_util import master_only
|
| 10 |
from .logger import get_root_logger
|
| 11 |
|
| 12 |
# ---------------------------
# GPU / MPS Compatibility
# ---------------------------

# The MPS (Apple Silicon) backend requires PyTorch >= 1.12. Parse only the
# leading "major.minor.patch" so local suffixes such as "1.12.0+cu113" or
# "2.1.0.dev20230901" do not break the comparison.
try:
    version_match = re.findall(
        r"^([0-9]+)\.([0-9]+)\.([0-9]+)",
        torch.__version__
    )[0]
    IS_HIGH_VERSION = [int(x) for x in version_match] >= [1, 12, 0]
except (IndexError, ValueError):
    # Unparseable version string: be conservative and assume no MPS support.
    # (Narrowed from a bare `except:`, which also swallowed KeyboardInterrupt
    # and masked unrelated bugs.)
    IS_HIGH_VERSION = False
|
| 28 |
def gpu_is_available():
    """Return True when an accelerator (Apple MPS or NVIDIA CUDA) is usable."""
    # MPS is checked first; the flag is only meaningful on torch >= 1.12.
    has_mps = IS_HIGH_VERSION and torch.backends.mps.is_available()
    if has_mps:
        return True
    # CUDA counts only when cuDNN is also functional.
    has_cuda = torch.cuda.is_available() and torch.backends.cudnn.is_available()
    return has_cuda
|
| 33 |
+
|
| 34 |
|
| 35 |
def get_device(gpu_id=None):
    """Return the best available torch device (MPS -> CUDA -> CPU).

    Args:
        gpu_id (int, optional): CUDA device index. Only used when CUDA is the
            selected backend. Default: None (unindexed 'cuda').

    Returns:
        torch.device: ``'mps'``, ``'cuda[:gpu_id]'`` or ``'cpu'``.

    Raises:
        TypeError: If ``gpu_id`` is neither None nor an int.
    """
    if gpu_id is None:
        gpu_str = ''
    elif isinstance(gpu_id, int):
        gpu_str = f':{gpu_id}'
    else:
        # Reject e.g. string ids instead of silently ignoring them
        # (restores the original contract of this helper).
        raise TypeError('Input should be int value.')

    # Apple MPS (requires torch >= 1.12)
    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return torch.device("mps")

    # NVIDIA CUDA
    if torch.cuda.is_available() and torch.backends.cudnn.is_available():
        return torch.device("cuda" + gpu_str)

    # CPU fallback
    return torch.device("cpu")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ---------------------------
|
| 53 |
+
# Utilities
|
| 54 |
+
# ---------------------------
|
| 55 |
|
| 56 |
def set_random_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # CUDA generators exist only when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
|
| 64 |
|
| 65 |
|
| 66 |
def get_time_str():
|
|
|
|
| 68 |
|
| 69 |
|
| 70 |
def mkdir_and_rename(path):
    """Make a directory, archiving any existing one first.

    An existing ``path`` is renamed to ``<path>_archived_<timestamp>`` and a
    fresh empty directory is created in its place.

    Args:
        path (str): Folder path.
    """
    already_there = osp.exists(path)
    if already_there:
        new_name = path + '_archived_' + get_time_str()
        print(f'Path already exists. Renamed to {new_name}', flush=True)
        os.rename(path, new_name)
    os.makedirs(path, exist_ok=True)
|
| 76 |
|
| 77 |
|
| 78 |
@master_only
def make_exp_dirs(opt):
    """Create the experiment/result directory tree described by ``opt``."""
    path_opt = opt['path'].copy()
    root_key = 'experiments_root' if opt['is_train'] else 'results_root'
    mkdir_and_rename(path_opt.pop(root_key))

    # Remaining entries are plain directories; bookkeeping keys are skipped.
    skip_markers = ('strict_load', 'pretrain_network', 'resume')
    for key, path in path_opt.items():
        if not any(marker in key for marker in skip_markers):
            os.makedirs(path, exist_ok=True)
|
| 89 |
|
| 90 |
|
| 91 |
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files with relative paths.

    Raises:
        TypeError: If ``suffix`` is neither None, a str nor a tuple of str.
    """
    # Validate eagerly so misuse fails at call time, not on first iteration.
    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scan(path):
        for entry in os.scandir(path):
            # Skip hidden entries (dotfiles and dot-directories such as .git).
            if entry.name.startswith('.'):
                continue
            if entry.is_file():
                file_path = entry.path if full_path else osp.relpath(entry.path, root)
                if suffix is None or file_path.endswith(suffix):
                    yield file_path
            elif entry.is_dir() and recursive:
                yield from _scan(entry.path)

    return _scan(dir_path)
|
| 104 |
|
| 105 |
|
| 106 |
def check_resume(opt, resume_iter):
    """Check resume states and redirect pretrain_network paths.

    When a resume state is set, every ``pretrain_<network>`` entry in
    ``opt['path']`` — except networks listed in ``ignore_resume_networks`` —
    is pointed at the checkpoint saved for ``resume_iter``.

    Args:
        opt (dict): Options.
        resume_iter (int): Resume iteration.
    """
    logger = get_root_logger()

    if not opt['path']['resume_state']:
        return

    networks = [key for key in opt if key.startswith('network_')]
    flag_pretrain = any(opt['path'].get(f'pretrain_{net}') for net in networks)
    if flag_pretrain:
        logger.warning('pretrain_network path will be ignored during resuming.')

    ignore = opt['path'].get('ignore_resume_networks')
    for network in networks:
        basename = network.replace('network_', '')
        if ignore is None or basename not in ignore:
            opt['path'][f'pretrain_{network}'] = osp.join(
                opt['path']['models'], f'net_{basename}_{resume_iter}.pth'
            )
            logger.info(f"Set pretrain for {network}")
|
| 125 |
|
| 126 |
|
| 127 |
def sizeof_fmt(size, suffix='B'):
    """Get human readable file size.

    Args:
        size (int | float): File size in bytes.
        suffix (str): Suffix appended to the unit. Default: 'B'.

    Return:
        str: Formatted file size, e.g. ``'1.5 GB'``.
    """
    # abs() keeps negative sizes in the correct magnitude bucket; the full
    # E/Z unit ladder is needed so the 'Y' fallback is reached at the right
    # scale (the truncated ['', ..., 'P'] list mislabelled >= 1024**6 values).
    for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
        if abs(size) < 1024.0:
            return f'{size:3.1f} {unit}{suffix}'
        size /= 1024.0
    return f'{size:3.1f} Y{suffix}'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|