| """ |
| Misc |
| |
| Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) |
| Please cite our work if the code is helpful to you. |
| """ |
|
|
| import os |
| import warnings |
| from collections import abc |
| import numpy as np |
| import torch |
| from importlib import import_module |
|
|
|
|
class AverageMeter(object):
    """Track the latest value and the running weighted mean of a stream.

    Attributes:
        val: the value passed to the most recent :meth:`update` call.
        avg: ``sum / count`` as of the most recent update (0 before any).
        sum: accumulated ``val * n`` over all updates.
        count: accumulated weight ``n`` over all updates.
    """

    def __init__(self):
        # Initial state is identical to a freshly reset meter.
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the running mean."""
        self.val = val
        # In-place accumulation is kept on purpose (arrays/tensors accumulate in place).
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
|
|
|
|
def _count_labels(labels, K):
    """Count occurrences of each label in ``[0, K)``; values outside are dropped."""
    labels = labels[(labels >= 0) & (labels < K)]
    # Truncation matches the old histogram binning for any in-range value.
    return np.bincount(labels.astype(np.int64), minlength=K)


def intersection_and_union(output, target, K, ignore_index=-1):
    """Compute per-class intersection, union and target areas on NumPy arrays.

    Predictions at positions whose ``target`` equals ``ignore_index`` are
    overwritten with ``ignore_index`` (on a private copy). Only labels in
    ``[0, K)`` are counted. Any other value is dropped, including
    ``ignore_index`` when it is negative or ``>= K``. This matches
    ``intersection_and_union_gpu``. The previous ``np.histogram`` version
    wrongly folded the value ``K`` into class ``K - 1``, because the last
    histogram bin is closed.

    Args:
        output (np.ndarray): predicted labels, 1-3 dimensional. Not modified.
        target (np.ndarray): ground-truth labels with the same shape as ``output``.
        K (int): number of classes.
        ignore_index (int): value in ``target`` marking positions to ignore.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: ``(area_intersection,
        area_union, area_target)``, each an integer array of length ``K``.
    """
    assert output.ndim in [1, 2, 3]
    assert output.shape == target.shape
    # Flatten; copy so the masking below never touches the caller's array.
    output = output.reshape(output.size).copy()
    target = target.reshape(target.size)
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    area_intersection = _count_labels(intersection, K)
    area_output = _count_labels(output, K)
    area_target = _count_labels(target, K)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target
|
|
|
|
def intersection_and_union_gpu(output, target, k, ignore_index=-1):
    """Compute per-class intersection, union and target areas on torch tensors.

    Predictions at positions whose ``target`` equals ``ignore_index`` are
    overwritten with ``ignore_index`` on a private copy. The caller's
    ``output`` tensor is no longer modified in place; previously ``view``
    shared its storage. Counting uses ``torch.histc`` over ``[0, k - 1]``
    with ``k`` bins. Values outside that range are ignored by ``histc``.

    NOTE(review): ``torch.histc`` may not support integer dtypes on CPU in
    some PyTorch versions; confirm for CPU callers.

    Args:
        output (torch.Tensor): predicted labels, 1-3 dimensional.
        target (torch.Tensor): ground-truth labels with the same shape.
        k (int): number of classes.
        ignore_index (int): value in ``target`` marking positions to ignore.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ``(area_intersection,
        area_union, area_target)``, each a length-``k`` histogram.
    """
    assert output.dim() in [1, 2, 3]
    assert output.shape == target.shape
    # reshape (unlike view) accepts non-contiguous tensors; clone so the
    # masking write below cannot alias the caller's tensor.
    output = output.reshape(-1).clone()
    target = target.reshape(-1)
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    area_intersection = torch.histc(intersection, bins=k, min=0, max=k - 1)
    area_output = torch.histc(output, bins=k, min=0, max=k - 1)
    area_target = torch.histc(target, bins=k, min=0, max=k - 1)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target
|
|
|
|
def make_dirs(dir_name):
    """Create ``dir_name`` and any missing parents unless the path already exists.

    The call does nothing when anything exists at ``dir_name``, whether a
    directory or not.
    """
    if os.path.exists(dir_name):
        return
    os.makedirs(dir_name, exist_ok=True)
|
|
|
|
def find_free_port():
    """Ask the OS for a currently unused TCP port and return its number.

    A TCP socket is bound to port 0 on all interfaces (``""``), which lets
    the kernel choose a port. The socket is then closed. The port is only
    known to be free at the moment of the check; another process may take it
    before the caller binds it.
    """
    import socket

    # The context manager closes the socket even when bind() raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]
|
|
|
|
def is_seq_of(seq, expected_type, seq_type=None):
    """Report whether ``seq`` is a sequence whose every item is ``expected_type``.

    Args:
        seq (Sequence): the object to inspect.
        expected_type (type): type (or tuple of types) each item must satisfy.
        seq_type (type, optional): required container type. Defaults to
            ``collections.abc.Sequence``.

    Returns:
        bool: True only if the container check and every item check pass.
    """
    if seq_type is not None:
        assert isinstance(seq_type, type)
    container_type = abc.Sequence if seq_type is None else seq_type
    return isinstance(seq, container_type) and all(
        isinstance(element, expected_type) for element in seq
    )
|
|
|
|
def is_str(x):
    """Return True when ``x`` is a ``str`` (or subclass) instance.

    Note: deprecated; it exists only for backward compatibility now that
    Python 2 is no longer supported.
    """
    return isinstance(x, (str,))
|
|
|
|
def import_modules_from_strings(imports, allow_failed_imports=False):
    """Import modules from the given list of strings.

    Args:
        imports (list | str | None): The given module names to be imported.
        allow_failed_imports (bool): If True, modules that fail to import are
            returned as None and a UserWarning is emitted. Otherwise an
            ImportError naming the module is raised. Default: False.

    Returns:
        list[module] | module | None: The imported modules. A single module
        is returned when ``imports`` is a str. None is returned when
        ``imports`` is empty or None.

    Raises:
        TypeError: If ``imports`` is not a str/list, or an entry is not a str.
        ImportError: If an import fails and ``allow_failed_imports`` is False.

    Examples:
        >>> osp, sys = import_modules_from_strings(
        ...     ['os.path', 'sys'])
        >>> import os.path as osp_
        >>> import sys as sys_
        >>> assert osp == osp_
        >>> assert sys == sys_
    """
    if not imports:
        return None
    single_import = isinstance(imports, str)
    if single_import:
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(f"custom_imports must be a list but got type {type(imports)}")
    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(f"{imp} is of type {type(imp)} and cannot be imported.")
        try:
            imported_tmp = import_module(imp)
        except ImportError as err:
            if not allow_failed_imports:
                # Name the module and keep the original failure as __cause__;
                # a bare `raise ImportError` carried no message at all.
                raise ImportError(f"Failed to import {imp}") from err
            warnings.warn(
                f"{imp} failed to import and is ignored.", UserWarning, stacklevel=2
            )
            imported_tmp = None
        imported.append(imported_tmp)
    return imported[0] if single_import else imported
|
|
|
|
class DummyClass:
    """Empty placeholder class; instances carry no state."""

    def __init__(self):
        """Construct the placeholder; no attributes are set."""
|
|