Siam2315 commited on
Commit
54519d5
·
verified ·
1 Parent(s): f484856

Update basicsr/utils/misc.py

Browse files
Files changed (1) hide show
  1. basicsr/utils/misc.py +68 -93
basicsr/utils/misc.py CHANGED
@@ -9,36 +9,58 @@ from os import path as osp
9
  from .dist_util import master_only
10
  from .logger import get_root_logger
11
 
12
- IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
13
- torch.__version__)[0][:3])] >= [1, 12, 0]
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def gpu_is_available():
16
- if IS_HIGH_VERSION:
17
- if torch.backends.mps.is_available():
18
- return True
19
- return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
 
20
 
21
  def get_device(gpu_id=None):
22
- if gpu_id is None:
23
- gpu_str = ''
24
- elif isinstance(gpu_id, int):
25
- gpu_str = f':{gpu_id}'
26
- else:
27
- raise TypeError('Input should be int value.')
28
 
29
- if IS_HIGH_VERSION:
30
- if torch.backends.mps.is_available():
31
- return torch.device('mps'+gpu_str)
32
- return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
33
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  def set_random_seed(seed):
36
  """Set random seeds."""
37
  random.seed(seed)
38
  np.random.seed(seed)
39
  torch.manual_seed(seed)
40
- torch.cuda.manual_seed(seed)
41
- torch.cuda.manual_seed_all(seed)
 
42
 
43
 
44
  def get_time_str():
@@ -46,112 +68,65 @@ def get_time_str():
46
 
47
 
48
  def mkdir_and_rename(path):
49
- """mkdirs. If path exists, rename it with timestamp and create a new one.
50
-
51
- Args:
52
- path (str): Folder path.
53
- """
54
  if osp.exists(path):
55
  new_name = path + '_archived_' + get_time_str()
56
- print(f'Path already exists. Rename it to {new_name}', flush=True)
57
  os.rename(path, new_name)
58
  os.makedirs(path, exist_ok=True)
59
 
60
 
61
  @master_only
62
  def make_exp_dirs(opt):
63
- """Make dirs for experiments."""
64
  path_opt = opt['path'].copy()
65
  if opt['is_train']:
66
  mkdir_and_rename(path_opt.pop('experiments_root'))
67
  else:
68
  mkdir_and_rename(path_opt.pop('results_root'))
 
69
  for key, path in path_opt.items():
70
  if ('strict_load' not in key) and ('pretrain_network' not in key) and ('resume' not in key):
71
  os.makedirs(path, exist_ok=True)
72
 
73
 
74
  def scandir(dir_path, suffix=None, recursive=False, full_path=False):
75
- """Scan a directory to find the interested files.
76
-
77
- Args:
78
- dir_path (str): Path of the directory.
79
- suffix (str | tuple(str), optional): File suffix that we are
80
- interested in. Default: None.
81
- recursive (bool, optional): If set to True, recursively scan the
82
- directory. Default: False.
83
- full_path (bool, optional): If set to True, include the dir_path.
84
- Default: False.
85
-
86
- Returns:
87
- A generator for all the interested files with relative pathes.
88
- """
89
-
90
- if (suffix is not None) and not isinstance(suffix, (str, tuple)):
91
- raise TypeError('"suffix" must be a string or tuple of strings')
92
-
93
  root = dir_path
94
 
95
- def _scandir(dir_path, suffix, recursive):
96
- for entry in os.scandir(dir_path):
97
- if not entry.name.startswith('.') and entry.is_file():
98
- if full_path:
99
- return_path = entry.path
100
- else:
101
- return_path = osp.relpath(entry.path, root)
102
-
103
- if suffix is None:
104
- yield return_path
105
- elif return_path.endswith(suffix):
106
- yield return_path
107
- else:
108
- if recursive:
109
- yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
110
- else:
111
- continue
112
 
113
- return _scandir(dir_path, suffix=suffix, recursive=recursive)
114
 
115
 
116
  def check_resume(opt, resume_iter):
117
- """Check resume states and pretrain_network paths.
118
-
119
- Args:
120
- opt (dict): Options.
121
- resume_iter (int): Resume iteration.
122
- """
123
  logger = get_root_logger()
 
124
  if opt['path']['resume_state']:
125
- # get all the networks
126
- networks = [key for key in opt.keys() if key.startswith('network_')]
127
- flag_pretrain = False
128
- for network in networks:
129
- if opt['path'].get(f'pretrain_{network}') is not None:
130
- flag_pretrain = True
131
  if flag_pretrain:
132
  logger.warning('pretrain_network path will be ignored during resuming.')
133
- # set pretrained model paths
134
  for network in networks:
135
- name = f'pretrain_{network}'
136
  basename = network.replace('network_', '')
137
- if opt['path'].get('ignore_resume_networks') is None or (basename
138
- not in opt['path']['ignore_resume_networks']):
139
- opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
140
- logger.info(f"Set {name} to {opt['path'][name]}")
 
 
 
141
 
142
 
143
  def sizeof_fmt(size, suffix='B'):
144
- """Get human readable file size.
145
-
146
- Args:
147
- size (int): File size.
148
- suffix (str): Suffix. Default: 'B'.
149
-
150
- Return:
151
- str: Formated file siz.
152
- """
153
- for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
154
- if abs(size) < 1024.0:
155
- return f'{size:3.1f} {unit}{suffix}'
156
- size /= 1024.0
157
- return f'{size:3.1f} Y{suffix}'
 
9
  from .dist_util import master_only
10
  from .logger import get_root_logger
11
 
12
+
13
+ # ---------------------------
14
+ # GPU / MPS Compatibility
15
+ # ---------------------------
16
+
17
+ # Check if PyTorch ≥ 1.12 for MPS (Apple Silicon)
18
+ try:
19
+ version_match = re.findall(
20
+ r"^([0-9]+)\.([0-9]+)\.([0-9]+)",
21
+ torch.__version__
22
+ )[0]
23
+ IS_HIGH_VERSION = [int(x) for x in version_match] >= [1, 12, 0]
24
+ except:
25
+ IS_HIGH_VERSION = False
26
+
27
 
28
  def gpu_is_available():
29
+ """Return True if CUDA or MPS is available."""
30
+ if IS_HIGH_VERSION and torch.backends.mps.is_available():
31
+ return True
32
+ return torch.cuda.is_available() and torch.backends.cudnn.is_available()
33
+
34
 
35
  def get_device(gpu_id=None):
36
+ """Return the best available device (MPS → CUDA → CPU)."""
37
+
38
+ gpu_str = f":{gpu_id}" if isinstance(gpu_id, int) else ""
 
 
 
39
 
40
+ # Apple MPS
41
+ if IS_HIGH_VERSION and torch.backends.mps.is_available():
42
+ return torch.device("mps")
 
43
 
44
+ # NVIDIA CUDA
45
+ if torch.cuda.is_available() and torch.backends.cudnn.is_available():
46
+ return torch.device("cuda" + gpu_str)
47
+
48
+ # CPU fallback
49
+ return torch.device("cpu")
50
+
51
+
52
+ # ---------------------------
53
+ # Utilities
54
+ # ---------------------------
55
 
56
  def set_random_seed(seed):
57
  """Set random seeds."""
58
  random.seed(seed)
59
  np.random.seed(seed)
60
  torch.manual_seed(seed)
61
+ if torch.cuda.is_available():
62
+ torch.cuda.manual_seed(seed)
63
+ torch.cuda.manual_seed_all(seed)
64
 
65
 
66
  def get_time_str():
 
68
 
69
 
70
  def mkdir_and_rename(path):
 
 
 
 
 
71
  if osp.exists(path):
72
  new_name = path + '_archived_' + get_time_str()
73
+ print(f'Path already exists. Renamed to {new_name}', flush=True)
74
  os.rename(path, new_name)
75
  os.makedirs(path, exist_ok=True)
76
 
77
 
78
  @master_only
79
  def make_exp_dirs(opt):
 
80
  path_opt = opt['path'].copy()
81
  if opt['is_train']:
82
  mkdir_and_rename(path_opt.pop('experiments_root'))
83
  else:
84
  mkdir_and_rename(path_opt.pop('results_root'))
85
+
86
  for key, path in path_opt.items():
87
  if ('strict_load' not in key) and ('pretrain_network' not in key) and ('resume' not in key):
88
  os.makedirs(path, exist_ok=True)
89
 
90
 
91
  def scandir(dir_path, suffix=None, recursive=False, full_path=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  root = dir_path
93
 
94
+ def _scan(path):
95
+ for entry in os.scandir(path):
96
+ if entry.is_file() and not entry.name.startswith('.'):
97
+ file_path = entry.path if full_path else osp.relpath(entry.path, root)
98
+ if suffix is None or file_path.endswith(suffix):
99
+ yield file_path
100
+ elif entry.is_dir() and recursive:
101
+ yield from _scan(entry.path)
 
 
 
 
 
 
 
 
 
102
 
103
+ return _scan(dir_path)
104
 
105
 
106
  def check_resume(opt, resume_iter):
 
 
 
 
 
 
107
  logger = get_root_logger()
108
+
109
  if opt['path']['resume_state']:
110
+ networks = [k for k in opt.keys() if k.startswith('network_')]
111
+ flag_pretrain = any(opt['path'].get(f'pretrain_{n}') for n in networks)
112
+
 
 
 
113
  if flag_pretrain:
114
  logger.warning('pretrain_network path will be ignored during resuming.')
115
+
116
  for network in networks:
 
117
  basename = network.replace('network_', '')
118
+ if opt['path'].get('ignore_resume_networks') is None or (
119
+ basename not in opt['path']['ignore_resume_networks']
120
+ ):
121
+ opt['path'][f'pretrain_{network}'] = osp.join(
122
+ opt['path']['models'], f'net_{basename}_{resume_iter}.pth'
123
+ )
124
+ logger.info(f"Set pretrain for {network}")
125
 
126
 
127
  def sizeof_fmt(size, suffix='B'):
128
+ for unit in ['', 'K', 'M', 'G', 'T', 'P']:
129
+ if size < 1024:
130
+ return f"{size:3.1f} {unit}{suffix}"
131
+ size /= 1024
132
+ return f"{size:3.1f} Y{suffix}"