| import os |
| import math |
| import json |
| from dataclasses import dataclass |
| from typing import Dict, Optional |
| from functools import partial |
| from concurrent.futures import as_completed, ThreadPoolExecutor |
| from PIL import Image |
| import cv2 |
| import numpy as np |
| import torch |
|
|
|
|
@dataclass
class ImageProcessOutput:
    """Result container of an image processor.

    Fields can be read as attributes (``out.pixel_values``) or with
    string keys (``out["pixel_values"]``).
    """

    # Tensor holding the processed image data.
    pixel_values: torch.Tensor

    def __getitem__(self, key):
        """Return the attribute whose name is ``key``.

        Raises:
            AttributeError: if the instance has no attribute named ``key``.
        """
        value = getattr(self, key)
        return value
|
|
|
|
| class NormalizeHub: |
| def __init__(self, name: str, params: Optional[Dict] = None): |
| if params is None: |
| params = dict() |
| self.normalize_map = { |
| "general_norm": self.general_normalize, |
| "general_norm_v1": self.general_normalize_v1, |
| "general_norm_v2": self.general_normalize_v2, |
| "standard_norm": self.standard_normalize |
| } |
| assert name in self.normalize_map, f"{name} not in {self.normalize_map.keys()}" |
| self.normalize_func = partial(self.normalize_map.get(name), **params) |
|
|
| def __call__(self, img: np.ndarray) -> np.ndarray: |
| return self.normalize_func(img) |
|
|
| @staticmethod |
| def general_normalize_v2(img: np.ndarray, **kwargs) -> np.ndarray: |
| img /= 127.5 |
| return (img - 1).astype(np.float32) |
|
|
| @staticmethod |
| def general_normalize_v1(img: np.ndarray, **kwargs) -> np.ndarray: |
| img /= 255.0 |
| return img.astype(np.float32) |
|
|
| @staticmethod |
| def general_normalize(img: np.ndarray, **kwargs) -> np.ndarray: |
| """normalize image to [0, 1] by ((img / 255.) - 0.5) / 0.5""" |
| img = img / 255.0 |
| img -= 0.5 |
| img /= 0.5 |
| return img.astype(np.float32) |
|
|
| @staticmethod |
| def standard_normalize( |
| img: np.ndarray, |
| mean: list | tuple | None = None, |
| std: list | tuple | None = None, |
| **kwargs |
| ) -> np.ndarray: |
| """ |
| normalize image by ((img / 255.0) - mean) / std |
| Args: |
| img: image mode is RGB |
| mean: RGB's mean |
| std: RGB's std |
| |
| Returns: np.ndarray |
| """ |
| assert img.ndim == 3 |
| if mean is None: |
| mean = [0.485, 0.456, 0.406] |
| if std is None: |
| std = [0.229, 0.224, 0.225] |
| img = img / 255.0 |
| img = (img - np.array(mean)) / np.array(std) |
| return img.astype(np.float32) |
|
|
|
|
def to_numpy(img: np.ndarray | Image.Image | str) -> np.ndarray:
    """Convert a supported image input to a numpy array.

    Args:
        img: one of
            * ``str``: path of an image file; it is read with ``cv2.imread``
              and the last (channel) axis is reversed, BGR -> RGB;
            * ``np.ndarray``: returned as-is (no copy, no channel reordering);
            * ``PIL.Image.Image``: converted to RGB, then to an array.

    Returns: np.ndarray

    Raises:
        AssertionError: ``img`` is a path that does not exist.
        ValueError: the file exists but OpenCV could not decode it.
        TypeError: ``img`` is of an unsupported type.
    """
    if isinstance(img, str):
        assert os.path.exists(img), f"{img} does not exist!"
        bgr = cv2.imread(img)
        # cv2.imread signals failure by returning None instead of raising;
        # without this check the slice below dies with an opaque TypeError.
        if bgr is None:
            raise ValueError(f"Failed to decode image file: {img}")
        return bgr[..., ::-1]
    if isinstance(img, np.ndarray):
        return img
    if isinstance(img, Image.Image):
        return np.array(img.convert('RGB'))[..., :3]
    raise TypeError(f"Unsupported type {type(img)}")
|
|
|
|
def normalize_shape(
        height: int,
        width: int,
        min_height_size: int = 32,
        min_width_size: int = 32,
        max_height_size: int = 1344,
        max_width_size: int = 1344,
        stride: int = 16
) -> tuple[int, int]:
    """Normalize image dimensions to meet min/max constraints and align to stride.

    One common scale factor is applied to both sides (enlarging when a side is
    below its minimum, shrinking when a side exceeds its maximum; the maximum
    wins if both apply). Each side is then rounded to the nearest multiple of
    ``stride`` and clamped into its ``[min, max]`` range.

    Args:
        height: source height.
        width: source width.
        min_height_size: lower bound for the returned height.
        min_width_size: lower bound for the returned width.
        max_height_size: upper bound for the returned height.
        max_width_size: upper bound for the returned width.
        stride: the returned sides are multiples of this value whenever a
            multiple exists inside the corresponding ``[min, max]`` range.

    Returns:
        ``(height, width)``; ``(min_height_size, min_width_size)`` when either
        input side is not positive.
    """
    if height <= 0 or width <= 0:
        return min_height_size, min_width_size

    # Factor needed to lift undersized sides to their minimum (0. if none).
    scale_min = max(
        min_height_size / height if height < min_height_size else 0.,
        min_width_size / width if width < min_width_size else 0.
    )
    # Factor needed to shrink oversized sides to their maximum (inf if none).
    scale_max = min(
        max_height_size / height if height > max_height_size else float('inf'),
        max_width_size / width if width > max_width_size else float('inf')
    )

    scale = max(scale_min, 1.0)
    if scale_max != float('inf'):
        scale = min(scale, scale_max)

    if scale != 1.0:
        height = round(height * scale)
        width = round(width * scale)

    height = _snap_to_stride(height, min_height_size, max_height_size, stride)
    width = _snap_to_stride(width, min_width_size, max_width_size, stride)
    return height, width


def _snap_to_stride(size: int, lower: int, upper: int, stride: int) -> int:
    """Round ``size`` to a multiple of ``stride`` and clamp it into ``[lower, upper]``."""
    snapped = round(size / stride) * stride
    # Clamp against the stride-aligned bounds so that clamping cannot undo the
    # alignment when `lower`/`upper` themselves are not multiples of `stride`.
    aligned_lower = math.ceil(lower / stride) * stride
    aligned_upper = math.floor(upper / stride) * stride
    if aligned_lower > aligned_upper:
        # No multiple of `stride` fits inside the range: plain clamping.
        return max(min(snapped, upper), lower)
    return max(min(snapped, aligned_upper), aligned_lower)
|
|
|
|
def resize_with_padding(
        img: np.ndarray,
        image_shape: tuple[int | None, int | None] | int,
        padding_value: int = 255,
        keep_aspect_ratio: bool = True,
        center_pad: bool = True,
        no_scale_up: bool = True,
        return_crop_info: bool = False,
        interpolation: int = cv2.INTER_LINEAR,
        backend: str = "opencv"
) -> np.ndarray | tuple[np.ndarray, tuple]:
    """Resize ``img`` to a target shape, padding when the aspect ratio is kept.

    Args:
        img: 3-dimensional image array; the first two axes are height and width.
        image_shape: target ``(height, width)``; a single int means a square.
        padding_value: fill value of the area not covered by the image.
        keep_aspect_ratio: if False the image is stretched to exactly
            ``image_shape`` and no padding is applied.
        center_pad: place the resized image in the centre of the canvas;
            otherwise in the top-left corner.
        no_scale_up: never enlarge the image (scale factor capped at 1.0);
            only used when ``keep_aspect_ratio`` is True.
        return_crop_info: also return where the image sits in the output.
        interpolation: OpenCV interpolation flag; ignored by the "PIL" backend.
        backend: "opencv" or "PIL".

    Returns:
        The output image, or ``(image, crop_info)`` when ``return_crop_info``
        is set, where ``crop_info`` is
        ``(start_y, end_y, start_x, end_x, scale_h, scale_w)``.

    Raises:
        NotImplementedError: a resize is required and ``backend`` is unknown.
    """
    assert img.ndim == 3
    if isinstance(image_shape, int):
        image_shape = (image_shape, image_shape)
    tgt_height, tgt_width = image_shape[:2]
    ori_h, ori_w = img.shape[:2]

    if not keep_aspect_ratio:
        # Plain stretch to the target size; the image fills the whole output.
        stretched = _resize_image(img, (tgt_width, tgt_height), interpolation, backend)
        if not return_crop_info:
            return stretched
        new_h, new_w = stretched.shape[:2]
        return stretched, (0, new_h, 0, new_w, new_h / ori_h, new_w / ori_w)

    # Largest uniform scale at which the image still fits inside the target.
    r = min(tgt_height / ori_h, tgt_width / ori_w)
    if no_scale_up:
        r = min(r, 1.0)
    new_h = max(math.floor(r * ori_h), 1)
    new_w = max(math.floor(r * ori_w), 1)

    if (new_h, new_w) == (ori_h, ori_w):
        new_img = img
    else:
        new_img = _resize_image(img, (new_w, new_h), interpolation, backend)

    delta_h = tgt_height - new_img.shape[0]
    delta_w = tgt_width - new_img.shape[1]
    if center_pad:
        start_x, start_y = math.floor(delta_w / 2), math.floor(delta_h / 2)
    else:
        start_x, start_y = 0, 0

    # The padded canvas always has 3 channels.
    bg = np.ones((tgt_height, tgt_width, 3), dtype=img.dtype) * padding_value
    bg[start_y: start_y + new_h, start_x: start_x + new_w, :] = new_img
    if not return_crop_info:
        return bg
    return bg, (start_y, start_y + new_h, start_x, start_x + new_w, r, r)


def _resize_image(img: np.ndarray, size_wh: tuple, interpolation: int, backend: str) -> np.ndarray:
    """Resize ``img`` to ``size_wh`` (width, height) using the chosen backend."""
    if backend == "opencv":
        return cv2.resize(img, size_wh, interpolation=interpolation)
    if backend == "PIL":
        # PIL works on uint8 data; resample=2 is Pillow's bilinear filter.
        # The result is cast back to the dtype of the input array.
        resized = Image.fromarray(img.astype(np.uint8)).resize(size_wh, resample=2)
        return np.array(resized).astype(img.dtype)
    raise NotImplementedError(f"Unsupported backend: {backend}")
|
|
|
|
def normalize_shape_resize_with_padding(
        img: np.ndarray,
        min_height_size: int = 32,
        min_width_size: int = 32,
        max_height_size: int = 1344,
        max_width_size: int = 1344,
        fixed_factor: int = 16,
        padding_value: int = 255,
        keep_aspect_ratio: bool = True,
        center_pad: bool = True,
        no_scale_up: bool = True,
        return_crop_info: bool = False,
        interpolation: int = cv2.INTER_LINEAR,
        backend: str = "opencv",
        image_shape: tuple[int, int] | None = None
) -> np.ndarray | tuple[np.ndarray, tuple]:
    """Choose a target shape for ``img`` and hand over to ``resize_with_padding``.

    The target ``(height, width)`` is ``image_shape`` when given (it must lie
    within the min/max limits and be divisible by ``fixed_factor``); otherwise
    it is computed by ``normalize_shape`` from the image's own size, using
    ``fixed_factor`` as the stride.

    All remaining arguments are forwarded unchanged to ``resize_with_padding``,
    whose return value is returned.

    Raises:
        AssertionError: ``image_shape`` is out of range or not divisible by
            ``fixed_factor``.
    """
    src_h, src_w = img.shape[:2]

    if image_shape is None:
        limits = dict(
            min_height_size=min_height_size,
            min_width_size=min_width_size,
            max_height_size=max_height_size,
            max_width_size=max_width_size,
            stride=fixed_factor,
        )
        h, w = normalize_shape(height=src_h, width=src_w, **limits)
    else:
        h, w = image_shape
        assert min_height_size <= h <= max_height_size and min_width_size <= w <= max_width_size, \
            f"Image shape {h}x{w} is not in the range of {min_height_size}x{min_width_size} to {max_height_size}x{max_width_size}"
        assert h % fixed_factor == 0 and w % fixed_factor == 0, f"Image shape {h}x{w} is not divisible by {fixed_factor}"

    forwarded = dict(
        padding_value=padding_value,
        keep_aspect_ratio=keep_aspect_ratio,
        center_pad=center_pad,
        no_scale_up=no_scale_up,
        return_crop_info=return_crop_info,
        interpolation=interpolation,
        backend=backend,
    )
    return resize_with_padding(img, image_shape=(h, w), **forwarded)
|
|
|
|
class ImageProcess:
    """
    Base class for image process
    processing order: resize -> normalize -> permute

    Each step is optional; it is switched on by the ``do_*`` constructor flags
    and can be overridden per call with keyword arguments of the same name.
    """

    def __init__(
            self,
            do_resize: bool = False,
            do_permute: bool = False,
            do_normalize: bool = False,
            resize_config: dict | None = None,
            normalize_config: dict | None = None,
            padding_value: int = 255,
            num_workers: int = 8
    ):
        """
        Args:
            do_resize: resize each image with ``normalize_shape_resize_with_padding``.
            do_permute: transpose each image from (H, W, C) to (C, H, W).
            do_normalize: normalize each image with a ``NormalizeHub``.
            resize_config: keyword arguments bound to the resizer.
            normalize_config: keyword arguments for ``NormalizeHub`` (``name``, ``params``).
            padding_value: fill value used when a batch contains different shapes.
            num_workers: thread count for batches, capped at the CPU count.
        """
        self.do_resize = do_resize
        self.do_permute = do_permute
        self.do_normalize = do_normalize

        self.resize_config = resize_config if resize_config is not None else {}
        self.normalize_config = normalize_config if normalize_config is not None else {}

        self.padding_value = padding_value

        # os.cpu_count() may return None when the count cannot be determined.
        self.num_workers = min(num_workers, os.cpu_count() or 1)

        # Build the helpers from the None-safe copies stored above: unpacking
        # the raw arguments fails with `**None` when a config is left out.
        # Without any normalize config there is nothing to build; the
        # `hasattr` checks at call time report a missing normalizer.
        if self.do_normalize or self.normalize_config:
            self.normalize_obj = NormalizeHub(**self.normalize_config)
        self.resizer = partial(normalize_shape_resize_with_padding, **self.resize_config)

    def __repr__(self):
        return (f"{self.__class__.__name__}\ndo_resize: {self.do_resize}\ndo_permute: {self.do_permute}\n"
                f"do_normalize: {self.do_normalize}\nresize_config: {self.resize_config}\n"
                f"normalize_config: {self.normalize_config}\n"
                f"num_workers: {self.num_workers}\n")

    def save_pretrained(self, save_rtpath):
        """Write the constructor arguments to ``<save_rtpath>/img_processor.json``."""
        config = dict(
            do_resize=self.do_resize,
            do_permute=self.do_permute,
            do_normalize=self.do_normalize,
            resize_config=self.resize_config,
            normalize_config=self.normalize_config,
            padding_value=self.padding_value,
            num_workers=self.num_workers
        )
        with open(os.path.join(save_rtpath, "img_processor.json"), 'w', encoding="utf-8") as f:
            f.write(json.dumps(config, indent=4))

    def __call__(
            self,
            img: list[np.ndarray] | np.ndarray | list[Image.Image] | Image.Image | str | list[str],
            image_shape: tuple[int, int] | None = None,
            to_continuous: bool = True,
            **kwargs
    ) -> ImageProcessOutput:
        """batch input -> batch output

        Args:
            img: one image or a list of images (array, PIL image or file path).
            image_shape: optional target (height, width) passed to the resizer.
            to_continuous: forwarded to ``preprocessing``.
            **kwargs: per-call overrides (``do_resize``, ``do_normalize``,
                ``do_permute``, ``padding_value``).

        Returns:
            ImageProcessOutput whose ``pixel_values`` is the batched tensor;
            resize info returned by ``preprocessing`` is discarded.
        """
        if isinstance(img, list):
            arrays = [to_numpy(item) for item in img]
        else:
            arrays = to_numpy(img)
        res = self.preprocessing(
            img=arrays, image_shape=image_shape, to_continuous=to_continuous, **kwargs
        )
        batch = res[0] if isinstance(res, tuple) else res
        return ImageProcessOutput(pixel_values=torch.from_numpy(batch))

    def _resolve_flag(self, name: str, overrides: dict) -> bool:
        """Return the per-call override of ``name`` if present, else the instance setting."""
        return overrides.get(name, getattr(self, name, False))

    def _single_preprocessing(
            self,
            _img: np.ndarray | Image.Image,
            image_shape: tuple[int, int] | None = None,
            order_id: int | None = None,
            **kwargs
    ) -> np.ndarray | tuple:
        """
        Runner order:
            step1: check whether resize
            step2: check whether normalize (otherwise only cast to float32)
            step3: check whether permute
        Args:
            _img: single image
            image_shape: optional target (height, width) for the resizer
            order_id: optional tag appended to the result
            **kwargs: per-call overrides of do_resize / do_normalize / do_permute

        Returns:
            ``img`` or ``(img, info)`` when the resizer returned extra info;
            ``order_id`` is appended as last element when it is not None.
        """
        info = None

        if self._resolve_flag("do_resize", kwargs):
            assert hasattr(self, "resizer"), f"{self.__class__} does not have `resizer`"
            if image_shape is not None:
                resized = self.resizer(_img, image_shape=image_shape)
            else:
                # Do not pass image_shape=None explicitly: it would override
                # an `image_shape` bound through `resize_config`.
                resized = self.resizer(_img)
            if isinstance(resized, tuple):
                _img, *info = resized
                info = info if len(info) > 1 else info[0]
            else:
                _img = resized

        if self._resolve_flag("do_normalize", kwargs):
            assert hasattr(self, "normalize_obj"), f"{self.__class__} does not have `normalize_obj`"
            _img = self.normalize_obj(img=_img.astype(np.float32))
        else:
            _img = _img.astype(np.float32)

        if self._resolve_flag("do_permute", kwargs):
            _img = _img.transpose(2, 0, 1)

        if order_id is None:
            return _img if info is None else (_img, info)
        return (_img, order_id) if info is None else (_img, info, order_id)

    def preprocessing(
            self,
            img: np.ndarray | list[np.ndarray],
            image_shape: tuple[int, int] | None = None,
            to_continuous: bool = True,
            **kwargs
    ) -> np.ndarray | tuple[np.ndarray, list]:
        """
        resize & normalize & permute
        Args:
            img (Union[np.ndarray, List[np.ndarray]]): one image or a list of images
            image_shape: optional target (height, width) for the resizer
            to_continuous: wrap the stacked batch in np.ascontiguousarray
                (only applies when all processed images share one shape)
            **kwargs: per-call overrides (do_resize, do_normalize, do_permute,
                padding_value)
        Return:
            img_pre (np.ndarray), (B, C, H, W) when permuted, else (B, H, W, C);
            or (img_pre, info_ls) when the resizer returned extra info.
            Images of different shapes are placed in the top-left corner of a
            canvas filled with the padding value (normalized like the images
            when normalization is on).
        """
        if isinstance(img, np.ndarray):
            img = [img]

        if len(img) == 1 or self.num_workers <= 1:
            outputs = [
                self._single_preprocessing(cur_img, image_shape=image_shape, **kwargs)
                for cur_img in img
            ]
        else:
            # Forward **kwargs here as well so per-call overrides behave the
            # same for batches as for single images. `executor.map` yields
            # results in input order, so no order tag is needed.
            worker = partial(self._single_preprocessing, image_shape=image_shape, **kwargs)
            with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                outputs = list(executor.map(worker, img))

        img_ls = []
        info_ls = []
        for output in outputs:
            if isinstance(output, tuple):
                cur_img, *info = output
                info = info if len(info) > 1 else info[0]
            else:
                cur_img, info = output, None
            img_ls.append(cur_img)
            info_ls.append(info)
        img_shape_ls = [cur_img.shape for cur_img in img_ls]

        if len(set(img_shape_ls)) > 1:
            padding_value = kwargs.get("padding_value", getattr(self, "padding_value", None))
            assert padding_value is not None, f"You should setting `padding_value`"

            do_normalize = self._resolve_flag("do_normalize", kwargs)
            dtype = np.float32 if do_normalize else img_ls[0].dtype

            # Canvas as large as the biggest image along every axis.
            img_batch = np.ones(
                shape=np.max(np.array(img_shape_ls), axis=0).tolist(),
                dtype=dtype
            ) * padding_value

            if do_normalize:
                assert hasattr(self, "normalize_obj"), f"{self.__class__} does not have `normalize_obj`"
                # Normalize the padding like the images. The normalizer sees
                # channel-last data, so undo/redo the permutation around it.
                if self._resolve_flag("do_permute", kwargs):
                    img_batch = img_batch.transpose(1, 2, 0)
                    img_batch = self.normalize_obj(img_batch)
                    img_batch = img_batch.transpose(2, 0, 1)
                else:
                    img_batch = self.normalize_obj(img_batch)

            img_batch = np.stack([img_batch] * len(img_shape_ls), axis=0)
            for idx, cur_img in enumerate(img_ls):
                cur_shape = cur_img.shape
                img_batch[idx, :cur_shape[0], :cur_shape[1], :cur_shape[2]] = cur_img
            return img_batch if info_ls[0] is None else (img_batch, info_ls)

        stacked = np.stack(img_ls)
        if to_continuous:
            stacked = np.ascontiguousarray(stacked)
        return stacked if info_ls[0] is None else (stacked, info_ls)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "ImageProcess":
        """Build a processor from a saved ``img_processor.json``.

        The config file is located as follows:
            * an existing file path is used directly;
            * an existing directory is searched for ``img_processor.json``
              (inside ``subfolder`` when that keyword is given);
            * anything else is resolved through ``transformers.utils.cached_file``
              (presumably a hub/cache lookup; the cache-related keywords are
              forwarded, ``use_auth_token`` is passed as ``token`` when no
              ``token`` is given).

        Returns:
            ``cls(**config)`` with the JSON content as constructor arguments.

        Raises:
            AssertionError: the resolved config file does not exist.
        """
        subfolder = kwargs.get("subfolder")
        if os.path.isfile(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        elif os.path.isdir(pretrained_model_name_or_path):
            config_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path
            config_file = os.path.join(config_dir, "img_processor.json")
        else:
            # Imported lazily: only needed when the name is not a local path.
            from transformers.utils import cached_file

            cache_kwargs = {
                key: kwargs[key]
                for key in (
                    "cache_dir",
                    "force_download",
                    "local_files_only",
                    "revision",
                    "subfolder",
                    "token",
                )
                if key in kwargs and kwargs[key] is not None
            }
            if "use_auth_token" in kwargs and "token" not in cache_kwargs:
                cache_kwargs["token"] = kwargs["use_auth_token"]
            config_file = cached_file(pretrained_model_name_or_path, "img_processor.json", **cache_kwargs)
        assert config_file is not None and os.path.exists(config_file), f"{config_file} does not exist!"
        with open(config_file, 'r', encoding="utf-8", errors="ignore") as f:
            config = json.load(f, strict=True)
        return cls(**config)
|
|