Spaces:
Runtime error
Runtime error
| """ | |
| Misc | |
| Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) | |
| Please cite our work if the code is helpful to you. | |
| """ | |
| import os | |
| import warnings | |
| from collections import abc | |
| import numpy as np | |
| import torch | |
| from importlib import import_module | |
class AverageMeter:
    """Track the latest value and the running weighted average of a metric.

    Attributes are plain and public:
        val: the value passed to the most recent ``update`` call.
        sum: accumulated ``val * n`` over all updates since the last reset.
        count: accumulated ``n`` over all updates since the last reset.
        avg: ``sum / count`` (``0`` while ``count`` is zero).
    """

    def __init__(self):
        # A fresh meter is simply a reset meter; keep the state in one place.
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the average.

        Args:
            val: the new measurement (anything supporting ``*`` and ``/``).
            n: weight of the measurement, e.g. a batch size. Default: 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        try:
            self.avg = self.sum / self.count
        except ZeroDivisionError:
            # Nothing has been counted yet (e.g. an empty batch with n == 0);
            # report the same average as a freshly reset meter instead of
            # crashing.
            self.avg = 0
def intersection_and_union(output, target, K, ignore_index=-1):
    """Compute per-class intersection, union and target counts (NumPy).

    Args:
        output (np.ndarray): predicted class ids, with 1, 2 or 3 dimensions
            (N, N * L or N * H * W). Values are expected in ``0 .. K - 1``.
        target (np.ndarray): ground-truth class ids, same shape as ``output``.
        K (int): number of classes.
        ignore_index (int): label marking entries to exclude. It must lie
            outside ``0 .. K - 1`` for the exclusion to take effect.
            Default: -1.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: ``(area_intersection,
        area_union, area_target)``, each of length ``K``.

    Raises:
        AssertionError: if ``output`` has an unsupported number of dimensions
            or its shape differs from ``target``.
    """
    assert output.ndim in [1, 2, 3]
    assert output.shape == target.shape
    # Flatten; copy the predictions so the caller's array is never modified.
    output = output.reshape(output.size).copy()
    target = target.reshape(target.size)
    # Predictions at ignored positions take the ignore label as well, so they
    # fall outside every class bin below.
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    # Half-integer edges (-0.5, 0.5, ..., K - 0.5) give one bin per class id.
    # With integer edges np.histogram's last bin is closed on the right, so a
    # value equal to K (e.g. ignore_index == K) would be counted as class
    # K - 1.
    bin_edges = np.arange(K + 1) - 0.5
    area_intersection, _ = np.histogram(intersection, bins=bin_edges)
    area_output, _ = np.histogram(output, bins=bin_edges)
    area_target, _ = np.histogram(target, bins=bin_edges)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target
def intersection_and_union_gpu(output, target, k, ignore_index=-1):
    """Compute per-class intersection, union and target counts (torch).

    Args:
        output (torch.Tensor): predicted class ids, with 1, 2 or 3 dimensions
            (N, N * L or N * H * W). Values are expected in ``0 .. k - 1``.
        target (torch.Tensor): ground-truth class ids, same shape as
            ``output``.
        k (int): number of classes.
        ignore_index (int): label marking entries to exclude. It must lie
            outside ``0 .. k - 1`` for the exclusion to take effect.
            Default: -1.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        ``(area_intersection, area_union, area_target)``, each with ``k``
        elements.

    Raises:
        AssertionError: if ``output`` has an unsupported number of dimensions
            or its shape differs from ``target``.
    """
    assert output.dim() in [1, 2, 3]
    assert output.shape == target.shape
    # reshape (unlike view) also accepts non-contiguous tensors.
    target = target.reshape(-1)
    # Out-of-place masked_fill returns a new tensor, so the caller's
    # prediction tensor is left untouched (matching the NumPy variant, which
    # works on a copy).
    output = output.reshape(-1).masked_fill(target == ignore_index, ignore_index)
    intersection = output[output == target]
    # torch.histc ignores elements outside [min, max], which is what drops
    # the ignore label from every count.
    area_intersection = torch.histc(intersection, bins=k, min=0, max=k - 1)
    area_output = torch.histc(output, bins=k, min=0, max=k - 1)
    area_target = torch.histc(target, bins=k, min=0, max=k - 1)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target
def make_dirs(dir_name):
    """Create ``dir_name`` (and any missing parents) unless the path exists.

    Args:
        dir_name (str): path of the directory to create.
    """
    if os.path.exists(dir_name):
        # Something is already at this path; leave it alone.
        return
    # exist_ok tolerates the directory appearing between the check and here.
    os.makedirs(dir_name, exist_ok=True)
def find_free_port():
    """Ask the OS for a currently unused TCP port and return its number.

    Returns:
        int: a port number that was free at the time of the call.
    """
    import socket

    # The context manager closes the socket even if bind()/getsockname()
    # raises, so no file descriptor is leaked on failure.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 will cause the OS to find an available port for us
        sock.bind(("", 0))
        port = sock.getsockname()[1]
    # NOTE: there is still a chance the port could be taken by other processes.
    return port
def is_seq_of(seq, expected_type, seq_type=None):
    """Tell whether ``seq`` is a sequence whose items all have a given type.

    Args:
        seq (Sequence): The object to inspect.
        expected_type (type): Type every item must be an instance of.
        seq_type (type, optional): Required container type. When omitted,
            any ``collections.abc.Sequence`` is accepted.

    Returns:
        bool: True if the container and every item match, False otherwise.
    """
    if seq_type is not None:
        assert isinstance(seq_type, type)
        container_type = seq_type
    else:
        container_type = abc.Sequence
    # Short-circuit: items are only examined once the container type matches.
    return isinstance(seq, container_type) and all(
        isinstance(element, expected_type) for element in seq
    )
def is_str(x):
    """Return True when ``x`` is a ``str`` instance.

    Note: This helper is deprecated since Python 2 is no longer supported;
    it is equivalent to ``isinstance(x, str)``.
    """
    return isinstance(x, str)
def import_modules_from_strings(imports, allow_failed_imports=False):
    """Import modules from the given list of strings.

    Args:
        imports (list | str | None): The given module names to be imported.
        allow_failed_imports (bool): If True, a failed import emits a
            ``UserWarning`` and yields None for that entry. Otherwise the
            original ImportError is re-raised. Default: False.

    Returns:
        list[module] | module | None: The imported modules. A single string
        input returns a single module; a falsy input returns None.

    Raises:
        TypeError: If ``imports`` is neither a list nor a str, or if one of
            its entries is not a str.
        ImportError: If a module cannot be imported and
            ``allow_failed_imports`` is False.

    Examples:
        >>> osp, sys = import_modules_from_strings(
        ...     ['os.path', 'sys'])
        >>> import os.path as osp_
        >>> import sys as sys_
        >>> assert osp == osp_
        >>> assert sys == sys_
    """
    if not imports:
        return
    single_import = False
    if isinstance(imports, str):
        single_import = True
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(f"custom_imports must be a list but got type {type(imports)}")
    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(f"{imp} is of type {type(imp)} and cannot be imported.")
        try:
            imported_tmp = import_module(imp)
        except ImportError:
            if not allow_failed_imports:
                # Re-raise the original exception so the message and the
                # name of the missing module are preserved for the caller.
                raise
            warnings.warn(f"{imp} failed to import and is ignored.", UserWarning)
            imported_tmp = None
        imported.append(imported_tmp)
    if single_import:
        imported = imported[0]
    return imported
class DummyClass:
    """Empty placeholder class with no attributes or behaviour."""

    def __init__(self):
        """Create an instance; takes no arguments and stores nothing."""