Spaces:
Build error
Build error
| # -*- encoding: utf-8 -*- | |
| # @Author: SWHL | |
| # @Contact: liekkaskono@163.com | |
| from io import BytesIO | |
| from pathlib import Path | |
| from typing import Any, Union | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image, UnidentifiedImageError | |
# Absolute directory that contains this module.
root_dir = Path(__file__).resolve().parents[0]

# Input types accepted by LoadImage; ``InputType.__args__`` drives the
# runtime isinstance check and is echoed in its error message.
InputType = Union[(str, np.ndarray, bytes, Path, Image.Image)]
class LoadImage:
    """Load an image from a path, encoded bytes, an ndarray or a PIL image.

    Calling an instance returns a 3-channel ``np.ndarray`` in BGR channel
    order (the layout OpenCV functions expect).
    """

    def __init__(self):
        pass

    def __call__(self, img: InputType) -> np.ndarray:
        """Load *img* and normalise it to a 3-channel BGR array.

        Args:
            img: a file path (``str`` / ``Path``), encoded image ``bytes``,
                an ``np.ndarray`` or a ``PIL.Image.Image``.

        Returns:
            A 3-channel ``np.ndarray`` in BGR order.

        Raises:
            LoadImageError: if the input type is unsupported, a path does not
                exist or cannot be identified as an image, or the decoded
                array has an unsupported ndim / channel count.
            PIL.UnidentifiedImageError: if ``bytes`` input cannot be decoded.
        """
        if not isinstance(img, InputType.__args__):
            raise LoadImageError(
                f"The img type {type(img)} does not in {InputType.__args__}"
            )

        # The caller's type decides whether 3-channel data is RGB (decoded by
        # PIL) or already BGR (an ndarray supplied directly).
        origin_img_type = type(img)
        img = self.load_img(img)
        img = self.convert_img(img, origin_img_type)
        return img

    def load_img(self, img: InputType) -> np.ndarray:
        """Decode *img* into an ``np.ndarray`` without normalising channels.

        Paths and bytes are decoded through PIL, ndarrays are returned
        unchanged, and PIL images are converted with ``img_to_ndarray``.

        Raises:
            LoadImageError: for a missing or unidentifiable path, or an
                unsupported input type.
        """
        if isinstance(img, (str, Path)):
            self.verify_exist(img)
            try:
                pil_img = Image.open(img)
            except UnidentifiedImageError as e:
                raise LoadImageError(f"cannot identify image file {img}") from e
            # Release the file handle PIL keeps for lazy loading as soon as
            # the pixels have been copied into the array.
            with pil_img:
                return self.img_to_ndarray(pil_img)

        if isinstance(img, bytes):
            with Image.open(BytesIO(img)) as pil_img:
                return self.img_to_ndarray(pil_img)

        if isinstance(img, np.ndarray):
            return img

        if isinstance(img, Image.Image):
            return self.img_to_ndarray(img)

        raise LoadImageError(f"{type(img)} is not supported!")

    def img_to_ndarray(self, img: Image.Image) -> np.ndarray:
        """Convert a PIL image to an ``np.ndarray``.

        Modes whose raw pixel data would be misread downstream are converted
        first:

        - ``"1"`` (bilevel) becomes ``"L"``.
        - ``"P"`` (palette indices) becomes ``"RGBA"`` when the image has
          transparency information, and ``"RGB"`` otherwise.
        - ``"CMYK"`` and ``"YCbCr"`` become ``"RGB"``, so their channels are
          not treated as RGBA / RGB.

        The input image is never modified; ``convert`` returns a new image.
        """
        if img.mode == "1":
            img = img.convert("L")
        elif img.mode == "P":
            img = img.convert("RGBA" if "transparency" in img.info else "RGB")
        elif img.mode in ("CMYK", "YCbCr"):
            img = img.convert("RGB")
        return np.array(img)

    def convert_img(self, img: np.ndarray, origin_img_type: Any) -> np.ndarray:
        """Normalise a decoded array to 3-channel BGR.

        Args:
            img: a 2-D array, or a 3-D array whose last axis holds 1-4
                channels.
            origin_img_type: the type originally passed to ``__call__``.

        Returns:
            A 3-channel BGR ``np.ndarray``.

        Raises:
            LoadImageError: if ``img.ndim`` is not 2 or 3, or the channel
                count is not in [1, 2, 3, 4].
        """
        if img.ndim == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        if img.ndim == 3:
            channel = img.shape[2]
            if channel == 1:
                return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

            if channel == 2:
                return self.cvt_two_to_three(img)

            if channel == 3:
                # PIL decodes to RGB; a caller-supplied ndarray is assumed to
                # already be BGR.
                if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
                    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
                return img

            if channel == 4:
                # NOTE(review): 4-channel data is always treated as RGBA, even
                # for a directly supplied ndarray (unlike the 3-channel case)
                # — confirm this is intended.
                return self.cvt_four_to_three(img)

            raise LoadImageError(
                f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
            )

        raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")

    @staticmethod
    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
        """gray + alpha → BGR

        Pixels with non-zero alpha keep their gray value. ``255 - alpha`` is
        then added (saturating), so fully transparent pixels become white.
        """
        img_gray = img[..., 0]
        img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)

        img_alpha = img[..., 1]
        not_a = cv2.bitwise_not(img_alpha)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
        new_img = cv2.add(new_img, not_a)
        return new_img

    @staticmethod
    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
        """RGBA → BGR

        The colour channels are reordered to BGR, and pixels with zero alpha
        are zeroed. Then:

        - If the masked image is entirely black (mean <= 0), ``255 - alpha``
          is added, which whitens the transparent background.
        - Otherwise the masked image is bitwise-inverted.
        """
        r, g, b, a = cv2.split(img)
        new_img = cv2.merge((b, g, r))

        not_a = cv2.bitwise_not(a)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        new_img = cv2.bitwise_and(new_img, new_img, mask=a)

        # NOTE(review): the inversion branch also applies to fully opaque
        # RGBA images — confirm the intended heuristic.
        mean_color = np.mean(new_img)
        if mean_color <= 0.0:
            new_img = cv2.add(new_img, not_a)
        else:
            new_img = cv2.bitwise_not(new_img)
        return new_img

    @staticmethod
    def verify_exist(file_path: Union[str, Path]):
        """Raise ``LoadImageError`` if *file_path* does not exist."""
        if not Path(file_path).exists():
            raise LoadImageError(f"{file_path} does not exist.")
class LoadImageError(Exception):
    """Raised when LoadImage cannot load or normalise an input image."""