import argparse
import os
import sys
import time
import warnings
from importlib import import_module

import numpy as np
import torch
import torch.nn as nn
from PIL import Image

warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.functional")


def str2bool(v: str, strict=True) -> bool:
    """Parse a command-line style boolean.

    Args:
        v: Value to interpret. A ``bool`` is returned unchanged; a ``str`` is
            compared case-insensitively against the accepted spellings.
        strict: When true, an unrecognised value raises; when false, an
            unrecognised value is treated as ``True``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``strict`` is true and ``v`` is not a
            recognised spelling.
    """
    if isinstance(v, bool):
        return v
    if isinstance(v, str):
        lowered = v.lower()
        # NOTE: every spelling needs its own comma; adjacent string literals
        # are silently concatenated ("on" "t" would become "ont").
        if lowered in ("true", "yes", "on", "t", "y", "1"):
            return True
        if lowered in ("false", "no", "off", "f", "n", "0"):
            return False
    if strict:
        raise argparse.ArgumentTypeError("Unsupported value encountered.")
    return True


def to_cuda(data, device="cuda", exclude_keys: "list[str]" = None):
    """Recursively move tensors inside ``data`` to ``device``.

    Args:
        data: A tensor, a tuple/list/set, a dict, or any other object.
        device: Target passed to ``Tensor.to``.
        exclude_keys: Keys of a top-level dict whose values are left
            untouched. It is not forwarded to nested containers.

    Returns:
        * a tensor: the result of ``data.to(device)``;
        * a tuple/list/set: a new ``list`` with every element converted;
        * a dict: the same dict object, updated in place;
        * anything else: ``data`` unchanged.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, (tuple, list, set)):
        return [to_cuda(item, device) for item in data]
    if isinstance(data, dict):
        skipped = () if exclude_keys is None else exclude_keys
        # Only existing keys are reassigned, so iterating while writing is safe.
        for key in data:
            if key not in skipped:
                data[key] = to_cuda(data[key], device)
    # Unsupported types are passed through rather than rejected.
    return data


class HiddenPrints:
    """Context manager that redirects ``sys.stdout`` to ``os.devnull``."""

    def __enter__(self):
        self._original_stdout = sys.stdout
        # Keep our own handle so __exit__ closes exactly the stream we opened,
        # even if sys.stdout is rebound inside the ``with`` body.
        self._devnull = open(os.devnull, "w")
        sys.stdout = self._devnull

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout = self._original_stdout
        self._devnull.close()


class Logger:
    """Tee-style writer that sends messages to the terminal and a log file."""

    def __init__(self):
        # Stream captured at construction time; used for terminal output.
        self.terminal = sys.stdout
        # Log file handle; stays None until open() is called.
        self.file = None

    def open(self, file, mode=None):
        """Open ``file`` as the log target (mode defaults to ``"w"``)."""
        if mode is None:
            mode = "w"
        # Release a previously opened log file instead of leaking the handle.
        if self.file is not None:
            self.file.close()
        self.file = open(file, mode)

    def write(self, message, is_terminal=1, is_file=1):
        """Write ``message`` to the terminal and/or the log file.

        Messages containing a carriage return are never written to the file.
        If no log file has been opened, the file write is skipped.
        """
        if "\r" in message:
            is_file = 0
        if is_terminal == 1:
            self.terminal.write(message)
            self.terminal.flush()
        if is_file == 1 and self.file is not None:
            self.file.write(message)
            self.file.flush()

    def flush(self):
        # this flush method is needed for python 3 compatibility.
        # this handles the flush command by doing nothing.
        # you might want to specify some extra behavior here.
        pass


def get_network(arch: str, isTrain=False, continue_train=False, init_gain=0.02, pretrained=True):
    """Build a single-output network for the given architecture name.

    Args:
        arch: Name of a constructor in ``networks.resnet``; must contain
            ``"resnet"``.
        isTrain: Whether the model is being built for training.
        continue_train: When training, build with ``num_classes=1`` directly
            instead of swapping the head of a freshly constructed model.
        init_gain: Standard deviation used to initialise the new head weights.
        pretrained: Forwarded to the constructor when training from scratch.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``arch`` does not contain ``"resnet"``.
    """
    if "resnet" not in arch:
        raise ValueError(f"Unsupported arch: {arch}")

    from networks.resnet import ResNet

    resnet = getattr(import_module("networks.resnet"), arch)
    if isTrain and not continue_train:
        model: ResNet = resnet(pretrained=pretrained)
        # Size the new head from the existing one so architectures whose final
        # feature width is not 2048 also work; fall back to 2048 otherwise.
        # NOTE(review): assumes model.fc exposes ``in_features`` like nn.Linear.
        in_features = getattr(model.fc, "in_features", 2048)
        model.fc = nn.Linear(in_features, 1)
        nn.init.normal_(model.fc.weight.data, 0.0, init_gain)
    else:
        model: ResNet = resnet(num_classes=1)
    return model


def pad_img_to_square(img: np.ndarray):
    """Zero-pad an image on the bottom/right so height equals width.

    Args:
        img: Array whose first two axes are treated as height and width; any
            further axes are left unpadded.

    Returns:
        ``img`` itself when it is already square, otherwise a padded copy whose
        first two dimensions both equal ``max(height, width)``.
    """
    H, W = img.shape[:2]
    if H == W:
        return img
    new_size = max(H, W)
    # Pad only the two spatial axes; trailing axes (if any) get zero padding.
    pad_width = [(0, new_size - H), (0, new_size - W)] + [(0, 0)] * (img.ndim - 2)
    img = np.pad(img, pad_width, mode="constant")
    assert img.shape[0] == img.shape[1] == new_size
    return img