Spaces:
Running
on
Zero
Running
on
Zero
| # -*- coding: utf-8 -*- | |
| # @Organization : Alibaba XR-Lab | |
| # @Author : Lingteng Qiu | |
| # @Email : 220019047@link.cuhk.edu.cn | |
| # @Time : 2024-08-30 20:50:27 | |
| # @Function : The class defines bbox, base-seg module | |
| import copy | |
| import cv2 | |
| import numpy as np | |
| import torch | |
class BaseModel(object):
    """Thin wrapper that forwards common module calls to ``self.model``.

    Nothing in this class assigns ``self.model``; subclasses are expected to
    set it and to override ``__call__``. Every forwarding method returns
    ``self`` so calls can be chained.
    """

    def _forward(self, method_name, *args):
        """Call ``method_name(*args)`` on the wrapped model and return ``self``."""
        getattr(self.model, method_name)(*args)
        return self

    def cuda(self):
        """Forward ``cuda()`` to the wrapped model."""
        return self._forward("cuda")

    def cpu(self):
        """Forward ``cpu()`` to the wrapped model."""
        return self._forward("cpu")

    def float(self):
        """Forward ``float()`` to the wrapped model."""
        return self._forward("float")

    def to(self, device):
        """Forward ``to(device)`` to the wrapped model."""
        return self._forward("to", device)

    def eval(self):
        """Forward ``eval()`` to the wrapped model."""
        return self._forward("eval")

    def train(self):
        """Forward ``train()`` to the wrapped model."""
        return self._forward("train")

    def __call__(self, x):
        """Run the model on ``x``; must be implemented by subclasses."""
        raise NotImplementedError

    def __repr__(self):
        wrapped = self.model
        return f"model: \n{wrapped}"
def get_dtype_string(arr):
    """Return this module's short label for ``arr.dtype``.

    Args:
        arr: object exposing a numpy ``dtype`` attribute.

    Returns:
        ``"uint8"``, ``"float32"`` or ``"float"`` (for float64); any other
        dtype yields ``"unknow"``.
    """
    known_labels = (
        (np.uint8, "uint8"),
        (np.float32, "float32"),
        (np.float64, "float"),
    )
    for candidate, label in known_labels:
        if arr.dtype == candidate:
            return label
    return "unknow"
class BaseSeg(BaseModel):
    """Base segmentation wrapper; currently adds nothing to ``BaseModel``."""

    def __init__(self):
        """Create the instance without assigning any state (no ``model``)."""
        return None
class Bbox:
    """Axis-aligned bounding box stored in one of two layouts.

    ``mode="whwh"`` stores ``[left, top, right, bottom]`` and ``mode="xywh"``
    stores ``[center_x, center_y, width, height]``.
    """

    def __init__(self, box, mode="whwh"):
        """Store ``box`` (a 4-element sequence) together with its layout ``mode``.

        Raises:
            AssertionError: if ``box`` does not have 4 elements or ``mode`` is
                neither ``"whwh"`` nor ``"xywh"``.
        """
        assert len(box) == 4
        assert mode in ["whwh", "xywh"]
        self.box = box
        self.mode = mode

    def to_xywh(self):
        """Return the box in center/size layout (``self`` if already ``xywh``)."""
        if self.mode != "whwh":
            return self
        l, t, r, b = self.box
        center_x = (l + r) / 2
        center_y = (t + b) / 2
        return Bbox([center_x, center_y, r - l, b - t], mode="xywh")

    def to_whwh(self):
        """Return the box in corner layout (``self`` if already ``whwh``).

        The half extents are exact (``w / 2``), not floored, so this is the
        inverse of ``to_xywh``: flooring them against the exact center that
        ``to_xywh`` produces shifted odd-sized boxes by half a pixel.
        """
        if self.mode == "whwh":
            return self
        cx, cy, w, h = self.box
        half_w = w / 2
        half_h = h / 2
        return Bbox(
            [cx - half_w, cy - half_h, cx + half_w, cy + half_h], mode="whwh"
        )

    def area(self):
        """Return ``width * height`` of the box."""
        _, __, w, h = self.to_xywh().box
        return w * h

    def get_box(self):
        """Return the four stored values truncated to ``int``, in the current layout."""
        return list(map(int, self.box))

    def scale(self, scale, width, height):
        """Return a corner-layout box scaled about its center and clamped.

        Args:
            scale: factor applied to the box width and height.
            width: upper clamp for the right edge.
            height: upper clamp for the bottom edge.

        Returns:
            A new ``whwh`` ``Bbox`` with integer coordinates; left/top are
            clamped to 0 and right/bottom to ``width``/``height``.
        """
        # get_box() truncates the center to an int; pairing that with floored
        # half extents reproduces integer corners exactly when scale == 1.
        cx, cy, w, h = self.to_xywh().get_box()
        w = w * scale
        h = h * scale
        l = cx - w // 2
        t = cy - h // 2
        r = cx + w - (w // 2)
        b = cy + h - (h // 2)
        l = int(max(l, 0))
        t = int(max(t, 0))
        r = int(min(r, width))
        b = int(min(b, height))
        return Bbox([l, t, r, b], mode="whwh")

    def __repr__(self):
        l, t, r, b = self.to_whwh().box
        return f"BBox(left={l}, top={t}, right={r}, bottom={b})"
class Image:
    """In-memory image with a channel order and a dtype label.

    Attributes set by ``__init__``:
        data: numpy array returned by ``read_image`` or ``get_image``.
        order: requested channel order (``"RGB"`` or ``"BGR"``).
        type_mode: requested dtype label, one of ``TYPE_ORDER``.

    Conversions in this class treat non-float data as 0-255 and floating
    data as 0-1. Channel flipping assumes 3-channel data.
    """

    TYPE_ORDER = ["uint8", "float32", "float"]
    ORDER = ["RGB", "BGR"]
    MODE = ["numpy"]
    # numpy dtype behind each label of TYPE_ORDER; "float" means float64,
    # the same mapping get_dtype_string uses.
    _DTYPES = {"uint8": np.uint8, "float32": np.float32, "float": np.float64}

    def __init__(self, input, order="RGB", type_mode="uint8"):
        """Build an image from a file path, a numpy array or another ``Image``.

        Args:
            input: ``str`` path (loaded with ``read_image``), or an
                ``np.ndarray`` / ``Image`` (converted with ``get_image``).
            order: channel order of the stored data.
            type_mode: dtype label of the stored data.
        """
        if isinstance(input, str):
            self.data = self.read_image(input, type_mode, order)
        else:
            self.data = self.get_image(input, type_mode, order)
        self.order = order
        self.type_mode = type_mode

    def get_image(self, input, type_mode, order):
        """Convert an ``Image`` or numpy array to the requested dtype and order.

        A raw numpy array is taken to be in RGB order.

        Raises:
            NotImplementedError: for any other input type.
        """
        if isinstance(input, Image):
            return input.to_numpy(type_mode, order)
        if isinstance(input, np.ndarray):
            # Describe the raw array first so to_numpy knows what it converts
            # from; __init__ overwrites order/type_mode afterwards.
            self.data = input
            self.order = "RGB"  # default
            self.type_mode = get_dtype_string(input)
            return self.to_numpy(type_mode, order)
        raise NotImplementedError(
            f"Unsupported image input type: {type(input).__name__}"
        )

    def to_numpy(self, type_mode="uint8", order="RGB"):
        """Return a converted copy of the data; ``self.data`` is not modified.

        Both conversions are applied: the channel axis is reversed when
        ``order`` differs from ``self.order`` (RGB <-> BGR only), and the dtype
        is changed to the one named by ``type_mode``. Non-float -> float
        divides by 255; float -> uint8 multiplies by 255, rounds and clips.

        Args:
            type_mode: target dtype label, one of ``TYPE_ORDER``.
            order: target channel order.

        Returns:
            A C-contiguous ``np.ndarray`` of the requested dtype.

        Raises:
            ValueError: if ``type_mode`` is not one of ``TYPE_ORDER``.
        """
        if type_mode not in self._DTYPES:
            raise ValueError(f"Unknown type_mode {type_mode}")
        target = np.dtype(self._DTYPES[type_mode])

        data = copy.deepcopy(self.data)
        if order != self.order:
            data = data[..., ::-1]  # only support RGB -> BGR or BGR -> RGB

        # Decide from the actual array dtype, so a stale type_mode label can
        # never make the conversion fall through and return None.
        if data.dtype != target:
            source_is_float = np.issubdtype(data.dtype, np.floating)
            target_is_float = np.issubdtype(target, np.floating)
            if target_is_float and not source_is_float:
                data = (data / 255.0).astype(target)
            elif source_is_float and not target_is_float:
                data = np.clip(np.rint(data * 255.0), 0, 255).astype(target)
            else:
                data = data.astype(target)

        # The reversed view has negative strides, which torch.from_numpy
        # rejects; hand back a contiguous array instead.
        return np.ascontiguousarray(data)

    def to_tensor(self, order):
        """Return the data as a float32 torch tensor in channel order ``order``."""
        data = self.to_numpy(type_mode="float32", order=order)
        return torch.from_numpy(data)

    def read_image(
        self,
        path,
        mode,
        order,
    ):
        """Read an image file with OpenCV into a numpy array.

        Args:
            path (str): path to the image file.
            mode (str): one of ``TYPE_ORDER``.
                uint8: uint8 array, values scaled to 0-255;
                float32 / float: float32 array, values scaled to 0-1 (use
                ``to_numpy("float")`` to obtain float64).
                PIL and torch outputs are not supported by this method.
            order (str): ``"RGB"``, ``"RGBA"``, ``"BGR"`` or ``"BGRA"``.

        Note:
            A 4-channel image is blended onto a white background unless
            ``order`` contains ``"A"``, in which case alpha is kept.

        Returns:
            np.ndarray: the image array.

        Raises:
            ValueError: if ``mode`` is not one of ``TYPE_ORDER``.
            OSError: if OpenCV cannot read ``path``.
        """
        # Validate before touching the file. The former "pil" branch called
        # Image.open, which resolves to this class and could never succeed.
        if mode not in self.TYPE_ORDER:
            raise ValueError(f"Unknown read_image mode {mode}")

        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread reports failure by returning None, not by raising.
            raise OSError(f"Failed to read image file: {path}")

        if len(img.shape) == 3:  # ignore if gray scale
            # OpenCV loads BGR(A); swap channels when RGB(A) is requested.
            if order in ["RGB", "RGBA"]:
                if img.shape[-1] == 4:
                    img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
                elif img.shape[-1] == 3:
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # Blend onto a white background when alpha is not requested.
            if img.shape[-1] == 4 and "A" not in order:
                img = img.astype(np.float32) / 255
                img = img[..., :3] * img[..., 3:] + (1 - img[..., 3:])

        if mode == "uint8":
            if img.dtype != np.uint8:
                img = (img * 255).astype(np.uint8)
        elif img.dtype == np.uint8:
            # "float32" and "float" both produce float32 here.
            img = img.astype(np.float32) / 255
        return img