Siam2315 committed on
Commit
12bf33b
·
verified ·
1 Parent(s): 64b9375

Update CodeFormer/basicsr/utils/misc.py

Browse files
Files changed (1) hide show
  1. CodeFormer/basicsr/utils/misc.py +68 -109
CodeFormer/basicsr/utils/misc.py CHANGED
@@ -3,58 +3,64 @@ import re
3
  import random
4
  import time
5
  import torch
6
- def gpu_is_available():
7
- return torch.cuda.is_available()
8
-
9
- # Note: get_device is also often missing, so it's best practice to add it too.
10
- def get_device(gpu_id=None):
11
- if gpu_id is None:
12
- gpu_str = ''
13
- elif isinstance(gpu_id, int):
14
- gpu_str = f':{gpu_id}'
15
- else:
16
- raise TypeError('Input should be int value. ')
17
-
18
- # Prioritize CUDA (GPU) if available, otherwise use CPU
19
- if torch.cuda.is_available() and torch.backends.cudnn.is_available():
20
- return torch.device(f'cuda{gpu_str}')
21
- return torch.device('cpu')
22
  import numpy as np
23
  from os import path as osp
24
 
25
  from .dist_util import master_only
26
  from .logger import get_root_logger
27
 
28
- IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
29
- torch.__version__)[0][:3])] >= [1, 12, 0]
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  def gpu_is_available():
32
- if IS_HIGH_VERSION:
33
- if torch.backends.mps.is_available():
34
- return True
35
- return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
 
36
 
37
  def get_device(gpu_id=None):
38
- if gpu_id is None:
39
- gpu_str = ''
40
- elif isinstance(gpu_id, int):
41
- gpu_str = f':{gpu_id}'
42
- else:
43
- raise TypeError('Input should be int value.')
 
 
 
 
 
44
 
45
- if IS_HIGH_VERSION:
46
- if torch.backends.mps.is_available():
47
- return torch.device('mps'+gpu_str)
48
- return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
49
 
50
 
 
 
 
 
51
  def set_random_seed(seed):
52
  """Set random seeds."""
53
  random.seed(seed)
54
  np.random.seed(seed)
55
  torch.manual_seed(seed)
56
- torch.cuda.manual_seed(seed)
57
- torch.cuda.manual_seed_all(seed)
 
58
 
59
 
60
  def get_time_str():
@@ -62,112 +68,65 @@ def get_time_str():
62
 
63
 
64
  def mkdir_and_rename(path):
65
- """mkdirs. If path exists, rename it with timestamp and create a new one.
66
-
67
- Args:
68
- path (str): Folder path.
69
- """
70
  if osp.exists(path):
71
  new_name = path + '_archived_' + get_time_str()
72
- print(f'Path already exists. Rename it to {new_name}', flush=True)
73
  os.rename(path, new_name)
74
  os.makedirs(path, exist_ok=True)
75
 
76
 
77
  @master_only
78
  def make_exp_dirs(opt):
79
- """Make dirs for experiments."""
80
  path_opt = opt['path'].copy()
81
  if opt['is_train']:
82
  mkdir_and_rename(path_opt.pop('experiments_root'))
83
  else:
84
  mkdir_and_rename(path_opt.pop('results_root'))
 
85
  for key, path in path_opt.items():
86
  if ('strict_load' not in key) and ('pretrain_network' not in key) and ('resume' not in key):
87
  os.makedirs(path, exist_ok=True)
88
 
89
 
90
  def scandir(dir_path, suffix=None, recursive=False, full_path=False):
91
- """Scan a directory to find the interested files.
92
-
93
- Args:
94
- dir_path (str): Path of the directory.
95
- suffix (str | tuple(str), optional): File suffix that we are
96
- interested in. Default: None.
97
- recursive (bool, optional): If set to True, recursively scan the
98
- directory. Default: False.
99
- full_path (bool, optional): If set to True, include the dir_path.
100
- Default: False.
101
-
102
- Returns:
103
- A generator for all the interested files with relative pathes.
104
- """
105
-
106
- if (suffix is not None) and not isinstance(suffix, (str, tuple)):
107
- raise TypeError('"suffix" must be a string or tuple of strings')
108
-
109
  root = dir_path
110
 
111
- def _scandir(dir_path, suffix, recursive):
112
- for entry in os.scandir(dir_path):
113
- if not entry.name.startswith('.') and entry.is_file():
114
- if full_path:
115
- return_path = entry.path
116
- else:
117
- return_path = osp.relpath(entry.path, root)
 
118
 
119
- if suffix is None:
120
- yield return_path
121
- elif return_path.endswith(suffix):
122
- yield return_path
123
- else:
124
- if recursive:
125
- yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
126
- else:
127
- continue
128
-
129
- return _scandir(dir_path, suffix=suffix, recursive=recursive)
130
 
131
 
132
  def check_resume(opt, resume_iter):
133
- """Check resume states and pretrain_network paths.
134
-
135
- Args:
136
- opt (dict): Options.
137
- resume_iter (int): Resume iteration.
138
- """
139
  logger = get_root_logger()
 
140
  if opt['path']['resume_state']:
141
- # get all the networks
142
- networks = [key for key in opt.keys() if key.startswith('network_')]
143
- flag_pretrain = False
144
- for network in networks:
145
- if opt['path'].get(f'pretrain_{network}') is not None:
146
- flag_pretrain = True
147
  if flag_pretrain:
148
  logger.warning('pretrain_network path will be ignored during resuming.')
149
- # set pretrained model paths
150
  for network in networks:
151
- name = f'pretrain_{network}'
152
  basename = network.replace('network_', '')
153
- if opt['path'].get('ignore_resume_networks') is None or (basename
154
- not in opt['path']['ignore_resume_networks']):
155
- opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
156
- logger.info(f"Set {name} to {opt['path'][name]}")
 
 
 
157
 
158
 
159
  def sizeof_fmt(size, suffix='B'):
160
- """Get human readable file size.
161
-
162
- Args:
163
- size (int): File size.
164
- suffix (str): Suffix. Default: 'B'.
165
-
166
- Return:
167
- str: Formated file siz.
168
- """
169
- for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
170
- if abs(size) < 1024.0:
171
- return f'{size:3.1f} {unit}{suffix}'
172
- size /= 1024.0
173
- return f'{size:3.1f} Y{suffix}'
 
3
  import random
4
  import time
5
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import numpy as np
7
  from os import path as osp
8
 
9
  from .dist_util import master_only
10
  from .logger import get_root_logger
11
 
12
+
13
# ---------------------------
# GPU / MPS Compatibility
# ---------------------------

# The MPS (Apple Silicon) backend requires PyTorch >= 1.12.
try:
    version_match = re.findall(
        r"^([0-9]+)\.([0-9]+)\.([0-9]+)",
        torch.__version__
    )[0]
    IS_HIGH_VERSION = [int(x) for x in version_match] >= [1, 12, 0]
except (IndexError, ValueError):
    # IndexError: version string did not match the pattern at all;
    # ValueError: a captured group was not an integer. A bare `except:`
    # here would also swallow KeyboardInterrupt/SystemExit — avoid that.
    IS_HIGH_VERSION = False
26
+
27
 
28
def gpu_is_available():
    """Return True if an accelerator (Apple MPS or NVIDIA CUDA) can be used."""
    # Prefer MPS on torch builds new enough to ship the backend.
    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return True
    # Otherwise require both CUDA and cuDNN to be usable.
    if torch.cuda.is_available() and torch.backends.cudnn.is_available():
        return True
    return False
33
+
34
 
35
def get_device(gpu_id=None):
    """Return the best available torch device (MPS -> CUDA -> CPU).

    Args:
        gpu_id (int | None): Optional CUDA device index, e.g. ``0`` for
            ``cuda:0``. Default: None (no explicit index).

    Returns:
        torch.device: ``mps``, ``cuda[:gpu_id]`` or ``cpu``.

    Raises:
        TypeError: If ``gpu_id`` is neither None nor an int.
    """
    # Validate explicitly instead of silently ignoring a bad gpu_id
    # (restores the original contract of this function).
    if gpu_id is None:
        gpu_str = ''
    elif isinstance(gpu_id, int):
        gpu_str = f':{gpu_id}'
    else:
        raise TypeError('Input should be int value.')

    # Apple MPS (requires torch >= 1.12); MPS has no device index.
    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return torch.device('mps')

    # NVIDIA CUDA
    if torch.cuda.is_available() and torch.backends.cudnn.is_available():
        return torch.device('cuda' + gpu_str)

    # CPU fallback
    return torch.device('cpu')
 
 
50
 
51
 
52
+ # ---------------------------
53
+ # Utilities
54
+ # ---------------------------
55
+
56
def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # CUDA generators only exist when CUDA is available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
64
 
65
 
66
  def get_time_str():
 
68
 
69
 
70
def mkdir_and_rename(path):
    """mkdirs. If ``path`` exists, archive it with a timestamp suffix first.

    Args:
        path (str): Folder path.
    """
    if osp.exists(path):
        archived = path + '_archived_' + get_time_str()
        print(f'Path already exists. Renamed to {archived}', flush=True)
        os.rename(path, archived)
    os.makedirs(path, exist_ok=True)
76
 
77
 
78
@master_only
def make_exp_dirs(opt):
    """Create the experiment/result directory tree from ``opt['path']``."""
    paths = opt['path'].copy()
    # The root dir is archived-and-recreated; the rest are plain mkdirs.
    root_key = 'experiments_root' if opt['is_train'] else 'results_root'
    mkdir_and_rename(paths.pop(root_key))

    # Skip entries that are options/paths to existing files, not dirs to make.
    skip_markers = ('strict_load', 'pretrain_network', 'resume')
    for key, path in paths.items():
        if not any(marker in key for marker in skip_markers):
            os.makedirs(path, exist_ok=True)
89
 
90
 
91
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Hidden files (name starting with '.') are skipped.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): Only yield files ending with
            this suffix (or one of these suffixes). Default: None (all files).
        recursive (bool, optional): If set to True, recursively scan
            subdirectories. Default: False.
        full_path (bool, optional): If set to True, include ``dir_path`` in
            yielded paths; otherwise yield paths relative to it.
            Default: False.

    Returns:
        A generator over the matching file paths.

    Raises:
        TypeError: If ``suffix`` is neither None, a string, nor a tuple of
            strings.
    """
    # Validate eagerly (before the generator is first consumed) so callers
    # get the error at call time — this check was dropped in a refactor.
    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scan(path):
        for entry in os.scandir(path):
            if entry.is_file() and not entry.name.startswith('.'):
                file_path = entry.path if full_path else osp.relpath(entry.path, root)
                if suffix is None or file_path.endswith(suffix):
                    yield file_path
            elif entry.is_dir() and recursive:
                yield from _scan(entry.path)

    return _scan(dir_path)
 
 
 
 
 
 
 
 
 
 
104
 
105
 
106
def check_resume(opt, resume_iter):
    """Check resume states and pretrain_network paths.

    Args:
        opt (dict): Options.
        resume_iter (int): Resume iteration.
    """
    logger = get_root_logger()

    if opt['path']['resume_state']:
        # All networks configured in the options dict.
        networks = [key for key in opt.keys() if key.startswith('network_')]

        # Use `is not None` so an explicitly-set (even empty/falsy) pretrain
        # path still triggers the warning — truthiness would hide it.
        flag_pretrain = any(
            opt['path'].get(f'pretrain_{network}') is not None for network in networks
        )
        if flag_pretrain:
            logger.warning('pretrain_network path will be ignored during resuming.')

        # Point every (non-ignored) network at its resume checkpoint.
        ignore = opt['path'].get('ignore_resume_networks')
        for network in networks:
            basename = network.replace('network_', '')
            if ignore is None or basename not in ignore:
                name = f'pretrain_{network}'
                opt['path'][name] = osp.join(
                    opt['path']['models'], f'net_{basename}_{resume_iter}.pth'
                )
                # Log the resolved path so the resume target is auditable.
                logger.info(f"Set {name} to {opt['path'][name]}")
125
 
126
 
127
def sizeof_fmt(size, suffix='B'):
    """Get human readable file size.

    Args:
        size (int): File size in bytes. May be negative.
        suffix (str): Unit suffix. Default: 'B'.

    Returns:
        str: Formatted file size, e.g. '1.5 KB'.
    """
    # 'E' and 'Z' must be in the list, otherwise exa/zetta-range values
    # fall through to the 'Y' prefix several orders of magnitude too early.
    for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
        # abs() so negative sizes keep scaling instead of matching at once.
        if abs(size) < 1024.0:
            return f'{size:3.1f} {unit}{suffix}'
        size /= 1024.0
    return f'{size:3.1f} Y{suffix}'