Tophness2022 commited on
Commit
b28cdfa
·
1 Parent(s): 4be1cfd

fix pytorch version extraction

Browse files
preprocessing/matanyone/tools/misc.py CHANGED
@@ -52,9 +52,14 @@ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None
52
  initialized_logger[logger_name] = True
53
  return logger
54
 
55
-
56
- IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
57
- torch.__version__)[0][:3])] >= [1, 12, 0]
 
 
 
 
 
58
 
59
  def gpu_is_available():
60
  if IS_HIGH_VERSION:
 
52
  initialized_logger[logger_name] = True
53
  return logger
54
 
55
+ match = re.match(r"^([0-9]+)\.([0-9]+)\.([0-9]+)", torch.__version__)
56
+ if match:
57
+ version_tuple = match.groups()
58
+ IS_HIGH_VERSION = [int(v) for v in version_tuple] >= [1, 12, 0]
59
+ else:
60
+ logger = get_root_logger()
61
+ logger.warning(f"Could not parse torch version '{torch.__version__}'. Assuming it's not a high version >= 1.12.0.")
62
+ IS_HIGH_VERSION = False
63
 
64
  def gpu_is_available():
65
  if IS_HIGH_VERSION: