| import logging | |
| import os | |
| import torch | |
| from torchvision import transforms | |
| import numpy as np | |
| import random | |
| import cv2 | |
| from PIL import Image | |
def path_to_image(path, size=(1024, 1024), color_type="rgb"):
    """Read an image file with OpenCV and return it as a PIL image.

    Args:
        path: Path of the image file to read.
        size: Passed straight to ``cv2.resize`` as the output size (OpenCV
            orders it as width, height). A falsy value such as ``None`` skips
            resizing.
        color_type: ``"rgb"`` or ``"gray"``, compared case-insensitively.

    Returns:
        A PIL image converted to mode ``"RGB"`` (for ``"rgb"``) or ``"L"``
        (for ``"gray"``). Returns ``None`` after printing a hint when
        ``color_type`` is neither value.

    Raises:
        OSError: if OpenCV could not read an image from ``path``.
    """
    mode = color_type.lower()
    if mode == "rgb":
        image = cv2.imread(path)
    elif mode == "gray":
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    else:
        print("Select the color_type to return, either to RGB or gray image.")
        return None
    # cv2.imread reports failure by returning None instead of raising. Fail
    # here with a clear message rather than deep inside resize/cvtColor.
    if image is None:
        raise OSError(f"Could not read an image from {path!r}.")
    if size:
        image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
    if mode == "rgb":
        # OpenCV stores colour images in BGR channel order; swap to RGB for PIL.
        return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).convert("RGB")
    return Image.fromarray(image).convert("L")
def check_state_dict(state_dict, unwanted_prefixes=("_orig_mod.", "module.")):
    """Strip unwanted key prefixes from a state dict, in place.

    The default prefixes are presumably those added by ``torch.compile``
    (``"_orig_mod."``) and by DataParallel/DDP wrappers (``"module."``).
    Confirm against the callers.

    For each key, only the first matching prefix in ``unwanted_prefixes``
    order is removed. A key carrying several nested prefixes therefore loses
    just one of them. Renamed entries are re-inserted, so they move to the end
    of the dict's order. A renamed key that collides with an existing key
    overwrites that entry.

    Args:
        state_dict: Mapping from parameter names to values; mutated in place.
        unwanted_prefixes: Iterable of prefix strings to strip.

    Returns:
        The same ``state_dict`` object, for convenience.
    """
    # Snapshot the keys first, because the dict is mutated inside the loop.
    for key in list(state_dict):
        for prefix in unwanted_prefixes:
            if key.startswith(prefix):
                state_dict[key[len(prefix):]] = state_dict.pop(key)
                break
    return state_dict
def generate_smoothed_gt(gts, epsilon=0.001):
    """Apply label smoothing to ground-truth targets.

    Computes ``(1 - epsilon) * gts + epsilon / 2``. A target of 0 becomes
    ``epsilon / 2`` and a target of 1 becomes ``1 - epsilon / 2``.

    Args:
        gts: Scalar, array or tensor of targets. Values are presumably in
            [0, 1]; confirm at the call sites.
        epsilon: Smoothing strength. The default 0.001 equals the value that
            was previously hard-coded.

    Returns:
        The smoothed targets, with the same container type as ``gts``.
    """
    return (1 - epsilon) * gts + epsilon / 2
class Logger:
    """Send INFO-level messages to a log file and the console.

    All instances share the single named logger ``"BiRefNet"``:
    ``logging.getLogger`` returns the same object for the same name. Each
    instance attaches its own pair of handlers to it. Call ``close()`` to
    detach and release them.

    Args:
        path: Log file path. It is opened in ``"w"`` mode, so an existing file
            is truncated.
    """

    def __init__(self, path="log.txt"):
        self.logger = logging.getLogger("BiRefNet")
        self.file_handler = logging.FileHandler(path, "w")
        # NOTE: StreamHandler() with no argument writes to sys.stderr, despite
        # the attribute name.
        self.stdout_handler = logging.StreamHandler()
        formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
        self.stdout_handler.setFormatter(formatter)
        self.file_handler.setFormatter(formatter)
        self.logger.addHandler(self.file_handler)
        self.logger.addHandler(self.stdout_handler)
        self.logger.setLevel(logging.INFO)
        # Don't forward records to ancestor loggers (e.g. root); that would
        # print every message a second time.
        self.logger.propagate = False

    def info(self, txt):
        """Log ``txt`` at INFO level."""
        self.logger.info(txt)

    def close(self):
        """Detach this instance's handlers from the shared logger and close them."""
        # Removing the handlers matters because the logger outlives this
        # object. Otherwise closed handlers stay attached and later Logger
        # instances stack duplicates on top of them.
        for handler in (self.file_handler, self.stdout_handler):
            self.logger.removeHandler(handler)
            handler.close()
class AverageMeter:
    """Keep the latest value of a metric together with its weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the latest value, mean, running sum and sample count back to zero."""
        self.val, self.avg, self.sum, self.count = 0.0, 0.0, 0.0, 0.0

    def update(self, val, n=1):
        """Record ``val`` as the latest value and fold it into the mean with weight ``n``."""
        self.val = val
        # Accumulate the weighted total and weight, then recompute the mean.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def save_checkpoint(state, path, filename="latest.pth"):
    """Serialize ``state`` with ``torch.save`` to the file ``filename`` inside ``path``.

    No directory is created here; ``path`` is expected to exist already.
    """
    checkpoint_file = os.path.join(path, filename)
    torch.save(state, checkpoint_file)
def save_tensor_img(tenor_im, path):
    """Write an image tensor to ``path`` through torchvision's ``ToPILImage``.

    The tensor is moved to CPU and cloned, so the caller's tensor is left
    untouched. A leading dimension of size 1 (e.g. a batch of one) is then
    squeezed away. The parameter name ``tenor_im`` (sic) is kept so that
    keyword callers keep working.
    """
    to_pil = transforms.ToPILImage()
    pil_image = to_pil(tenor_im.cpu().clone().squeeze(0))
    pil_image.save(path)
def set_seed(seed):
    """Seed the PyTorch (CPU and CUDA), NumPy and Python RNGs with ``seed``.

    Also sets ``torch.backends.cudnn.deterministic = True`` to request
    deterministic cuDNN algorithms.
    """
    # Same seeding order as before: torch, torch CUDA, NumPy, stdlib random.
    for seed_fn in (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    ):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True