|
|
import argparse |
|
|
import os |
|
|
import sys |
|
|
import time |
|
|
import warnings |
|
|
from importlib import import_module |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from PIL import Image |
|
|
|
|
|
warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.functional") |
|
|
|
|
|
|
|
|
def str2bool(v: str, strict=True) -> bool: |
|
|
if isinstance(v, bool): |
|
|
return v |
|
|
elif isinstance(v, str): |
|
|
if v.lower() in ("true", "yes", "on" "t", "y", "1"): |
|
|
return True |
|
|
elif v.lower() in ("false", "no", "off", "f", "n", "0"): |
|
|
return False |
|
|
if strict: |
|
|
raise argparse.ArgumentTypeError("Unsupported value encountered.") |
|
|
else: |
|
|
return True |
|
|
|
|
|
|
|
|
def to_cuda(data, device="cuda", exclude_keys: "list[str]" = None): |
|
|
if isinstance(data, torch.Tensor): |
|
|
data = data.to(device) |
|
|
elif isinstance(data, (tuple, list, set)): |
|
|
data = [to_cuda(b, device) for b in data] |
|
|
elif isinstance(data, dict): |
|
|
if exclude_keys is None: |
|
|
exclude_keys = [] |
|
|
for k in data.keys(): |
|
|
if k not in exclude_keys: |
|
|
data[k] = to_cuda(data[k], device) |
|
|
else: |
|
|
|
|
|
data = data |
|
|
return data |
|
|
|
|
|
|
|
|
class HiddenPrints: |
|
|
def __enter__(self): |
|
|
self._original_stdout = sys.stdout |
|
|
sys.stdout = open(os.devnull, "w") |
|
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb): |
|
|
sys.stdout.close() |
|
|
sys.stdout = self._original_stdout |
|
|
|
|
|
|
|
|
class Logger(object): |
|
|
def __init__(self): |
|
|
self.terminal = sys.stdout |
|
|
self.file = None |
|
|
|
|
|
def open(self, file, mode=None): |
|
|
if mode is None: |
|
|
mode = "w" |
|
|
self.file = open(file, mode) |
|
|
|
|
|
def write(self, message, is_terminal=1, is_file=1): |
|
|
if "\r" in message: |
|
|
is_file = 0 |
|
|
if is_terminal == 1: |
|
|
self.terminal.write(message) |
|
|
self.terminal.flush() |
|
|
if is_file == 1: |
|
|
self.file.write(message) |
|
|
self.file.flush() |
|
|
|
|
|
def flush(self): |
|
|
|
|
|
|
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
def get_network(arch: str, isTrain=False, continue_train=False, init_gain=0.02, pretrained=True): |
|
|
if "resnet" in arch: |
|
|
from networks.resnet import ResNet |
|
|
|
|
|
resnet = getattr(import_module("networks.resnet"), arch) |
|
|
if isTrain: |
|
|
if continue_train: |
|
|
model: ResNet = resnet(num_classes=1) |
|
|
else: |
|
|
model: ResNet = resnet(pretrained=pretrained) |
|
|
model.fc = nn.Linear(2048, 1) |
|
|
nn.init.normal_(model.fc.weight.data, 0.0, init_gain) |
|
|
else: |
|
|
model: ResNet = resnet(num_classes=1) |
|
|
return model |
|
|
else: |
|
|
raise ValueError(f"Unsupported arch: {arch}") |
|
|
|
|
|
|
|
|
def pad_img_to_square(img: np.ndarray): |
|
|
H, W = img.shape[:2] |
|
|
if H != W: |
|
|
new_size = max(H, W) |
|
|
img = np.pad(img, ((0, new_size - H), (0, new_size - W), (0, 0)), mode="constant") |
|
|
assert img.shape[0] == img.shape[1] == new_size |
|
|
return img |
|
|
|