# Hiro-MOSS-OCR-0.3B / image_processing.py
# Uploaded by wwjiang ("Upload 3 files", commit 66c5b87, 17.9 kB).
import os
import math
import json
from dataclasses import dataclass
from typing import Dict, Optional
from functools import partial
from concurrent.futures import as_completed, ThreadPoolExecutor
from PIL import Image
import cv2
import numpy as np
import torch
@dataclass
class ImageProcessOutput:
    """Output container of the image processor.

    Fields can be read either as attributes (``out.pixel_values``) or with
    dict-style indexing (``out["pixel_values"]``).
    """

    # Processed image tensor (built with ``torch.from_numpy`` by the processor).
    pixel_values: torch.Tensor

    def __getitem__(self, key):
        """Return the field named ``key``; unknown names raise ``AttributeError``."""
        field_value = getattr(self, key)
        return field_value
class NormalizeHub:
def __init__(self, name: str, params: Optional[Dict] = None):
if params is None:
params = dict()
self.normalize_map = {
"general_norm": self.general_normalize,
"general_norm_v1": self.general_normalize_v1,
"general_norm_v2": self.general_normalize_v2,
"standard_norm": self.standard_normalize
}
assert name in self.normalize_map, f"{name} not in {self.normalize_map.keys()}"
self.normalize_func = partial(self.normalize_map.get(name), **params)
def __call__(self, img: np.ndarray) -> np.ndarray:
return self.normalize_func(img)
@staticmethod
def general_normalize_v2(img: np.ndarray, **kwargs) -> np.ndarray:
img /= 127.5
return (img - 1).astype(np.float32)
@staticmethod
def general_normalize_v1(img: np.ndarray, **kwargs) -> np.ndarray:
img /= 255.0
return img.astype(np.float32)
@staticmethod
def general_normalize(img: np.ndarray, **kwargs) -> np.ndarray:
"""normalize image to [0, 1] by ((img / 255.) - 0.5) / 0.5"""
img = img / 255.0
img -= 0.5
img /= 0.5
return img.astype(np.float32)
@staticmethod
def standard_normalize(
img: np.ndarray,
mean: list | tuple | None = None,
std: list | tuple | None = None,
**kwargs
) -> np.ndarray:
"""
normalize image by ((img / 255.0) - mean) / std
Args:
img: image mode is RGB
mean: RGB's mean
std: RGB's std
Returns: np.ndarray
"""
assert img.ndim == 3
if mean is None:
mean = [0.485, 0.456, 0.406]
if std is None:
std = [0.229, 0.224, 0.225]
img = img / 255.0
img = (img - np.array(mean)) / np.array(std)
return img.astype(np.float32)
def to_numpy(img: "np.ndarray | Image.Image | str | os.PathLike") -> np.ndarray:
    """Convert a supported image input to a numpy array.

    Args:
        img: one of
            - a file path (``str`` or ``os.PathLike``): decoded with OpenCV and
              converted BGR -> RGB; if OpenCV cannot decode it, PIL is tried;
            - an ``np.ndarray``: returned unchanged (no channel-order conversion);
            - a PIL image: converted to RGB.

    Returns:
        np.ndarray: the image; ``(H, W, 3)`` RGB for path and PIL inputs.

    Raises:
        AssertionError: if a path does not exist.
        TypeError: for unsupported input types.
    """
    if isinstance(img, (str, os.PathLike)):
        path = os.fspath(img)
        assert os.path.exists(path), f"Image file not found: {path}"
        bgr = cv2.imread(path)
        if bgr is not None:
            # OpenCV decodes to BGR; cvtColor also yields a C-contiguous array,
            # unlike the negative-stride view produced by ``[..., ::-1]``.
            return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        # cv2.imread signals failure by returning None rather than raising
        # (e.g. unsupported/invalid format); let PIL try before giving up.
        with Image.open(path) as pil_img:
            return np.array(pil_img.convert("RGB"))
    if isinstance(img, np.ndarray):
        return img
    if isinstance(img, Image.Image):
        # convert("RGB") already yields exactly 3 channels.
        return np.array(img.convert("RGB"))
    raise TypeError(f"Unsupported type {type(img)}")
def normalize_shape(
    height: int,
    width: int,
    min_height_size: int = 32,
    min_width_size: int = 32,
    max_height_size: int = 1344,
    max_width_size: int = 1344,
    stride: int = 16
) -> tuple[int, int]:
    """Normalize image dimensions to meet min/max constraints and align to stride.

    Scaling rules:
    - Scale up only when a side is below its minimum.
    - Scale down when a side exceeds its maximum.
    - When both rules apply, downscaling wins.

    Each side is then rounded to the nearest multiple of ``stride``. It is clamped
    into the range of stride multiples that lie inside ``[min, max]``. If that range
    is empty, the raw ``[min, max]`` bounds are used instead.

    Args:
        height, width: source image size in pixels.
        min_height_size, min_width_size: lower bounds for the result.
        max_height_size, max_width_size: upper bounds for the result.
        stride: alignment factor for the result.

    Returns:
        tuple[int, int]: target ``(height, width)``. Non-positive inputs return
        the (aligned) minimum size.
    """
    def _aligned_bounds(lo: int, hi: int) -> tuple[int, int]:
        # Smallest / largest stride multiples inside [lo, hi]. Clamping to the raw
        # bounds would break alignment whenever a bound is not a stride multiple
        # (e.g. min=32 with stride=28).
        aligned_lo = math.ceil(lo / stride) * stride
        aligned_hi = math.floor(hi / stride) * stride
        if aligned_lo > aligned_hi:
            return lo, hi  # no stride multiple fits; keep the raw bounds
        return aligned_lo, aligned_hi

    lo_h, hi_h = _aligned_bounds(min_height_size, max_height_size)
    lo_w, hi_w = _aligned_bounds(min_width_size, max_width_size)

    if height <= 0 or width <= 0:
        return lo_h, lo_w

    # Upscale factor needed to reach the minimums (0 if already large enough).
    scale_min = max(
        min_height_size / height if height < min_height_size else 0.,
        min_width_size / width if width < min_width_size else 0.
    )
    # Downscale factor needed to respect the maximums (inf if already small enough).
    scale_max = min(
        max_height_size / height if height > max_height_size else float('inf'),
        max_width_size / width if width > max_width_size else float('inf')
    )
    # Never shrink just to satisfy a minimum; the maximum cap takes priority.
    scale = min(max(scale_min, 1.0), scale_max)
    if scale != 1.0:
        height = round(height * scale)
        width = round(width * scale)

    height = max(min(round(height / stride) * stride, hi_h), lo_h)
    width = max(min(round(width / stride) * stride, hi_w), lo_w)
    return height, width
def resize_with_padding(
    img: np.ndarray,
    image_shape: tuple[int | None, int | None] | int,
    padding_value: int = 255,
    keep_aspect_ratio: bool = True,
    center_pad: bool = True,
    no_scale_up: bool = True,
    return_crop_info: bool = False,
    interpolation: int = cv2.INTER_LINEAR,  # only influence opencv backend
    backend: str = "opencv"
) -> np.ndarray | tuple[np.ndarray, tuple]:
    """Resize ``img`` to ``image_shape``, letterboxing when the aspect ratio is kept.

    With ``keep_aspect_ratio``:
    - The image is scaled by ``r = min(tgt_h / h, tgt_w / w)``, capped at 1.0 when
      ``no_scale_up`` is set.
    - It is pasted onto a canvas filled with ``padding_value``: centered when
      ``center_pad``, otherwise at the top-left corner.

    Without it, the image is stretched to exactly ``image_shape``.

    Args:
        img: 3-D array ``(H, W, C)``; any channel count is kept in the output.
        image_shape: target ``(height, width)``, or one int for a square target.
            Both entries must be ints (``None`` is not handled).
        padding_value: fill value for the padded area.
        keep_aspect_ratio, center_pad, no_scale_up: see above.
        return_crop_info: also return ``(y0, y1, x0, x1, r_h, r_w)``. This gives the
            location of the resized image inside the output and the scale factors.
        interpolation: OpenCV interpolation flag (opencv backend only).
        backend: ``"opencv"`` or ``"PIL"`` (the PIL path casts to uint8 and uses
            bilinear resampling).

    Returns:
        The resized/padded array with ``img``'s dtype, plus crop info when requested.

    Raises:
        NotImplementedError: unknown ``backend`` (raised when a resize is performed).
    """
    assert img.ndim == 3
    if isinstance(image_shape, int):
        image_shape = (image_shape, image_shape)
    img_dtype = img.dtype
    num_channels = img.shape[2]
    tgt_height, tgt_width = image_shape[:2]
    ori_h, ori_w = img.shape[:2]

    def _resize(src: np.ndarray, height: int, width: int) -> np.ndarray:
        # Both backends take the target size as (width, height).
        if backend == "opencv":
            out = cv2.resize(src, (width, height), interpolation=interpolation)
        elif backend == "PIL":
            out = Image.fromarray(src.astype(np.uint8)).resize((width, height), resample=2)  # default `BILINEAR`
            out = np.array(out).astype(img_dtype)
        else:
            raise NotImplementedError(f"Unsupported backend: {backend}")
        # cv2.resize drops a trailing channel axis of size 1; restore it so the
        # result stays (H, W, C) and can be pasted into the canvas below.
        if out.ndim == 2:
            out = out[..., None]
        return out

    if not keep_aspect_ratio:
        resized = _resize(img, tgt_height, tgt_width)
        if not return_crop_info:
            return resized
        new_h, new_w = resized.shape[:2]
        return resized, (0, new_h, 0, new_w, new_h / ori_h, new_w / ori_w)

    r = min(tgt_height / ori_h, tgt_width / ori_w)
    if no_scale_up:
        r = min(r, 1.0)
    new_h, new_w = max(math.floor(r * ori_h), 1), max(math.floor(r * ori_w), 1)
    if (new_h, new_w) == (ori_h, ori_w):
        new_img = img
    else:
        new_img = _resize(img, new_h, new_w)

    delta_h = tgt_height - new_img.shape[0]
    delta_w = tgt_width - new_img.shape[1]
    if center_pad:
        start_x, start_y = math.floor(delta_w / 2), math.floor(delta_h / 2)
    else:
        start_x, start_y = 0, 0
    # The canvas keeps the input's channel count and dtype (previously hard-coded to 3 channels).
    bg = np.full((tgt_height, tgt_width, num_channels), padding_value, dtype=img_dtype)
    bg[start_y: start_y + new_h, start_x: start_x + new_w, :] = new_img
    if not return_crop_info:
        return bg
    return bg, (start_y, start_y + new_h, start_x, start_x + new_w, r, r)
def normalize_shape_resize_with_padding(
    img: np.ndarray,
    min_height_size: int = 32,
    min_width_size: int = 32,
    max_height_size: int = 1344,
    max_width_size: int = 1344,
    fixed_factor: int = 16,
    padding_value: int = 255,
    keep_aspect_ratio: bool = True,
    center_pad: bool = True,
    no_scale_up: bool = True,
    return_crop_info: bool = False,
    interpolation: int = cv2.INTER_LINEAR,
    backend: str = "opencv",
    image_shape: tuple[int, int] | None = None
) -> np.ndarray | tuple[np.ndarray, tuple]:
    """Resize and pad ``img`` to a bounded, ``fixed_factor``-aligned size.

    If ``image_shape`` is given, it must already lie within the min/max bounds and be
    divisible by ``fixed_factor``; both conditions are checked with ``assert``.
    Otherwise the target size is derived from the image's own height/width by
    ``normalize_shape``. The resize/pad itself is delegated to ``resize_with_padding``
    with the remaining options.
    """
    src_h, src_w = img.shape[:2]

    if image_shape is None:
        target_hw = normalize_shape(
            height=src_h,
            width=src_w,
            min_height_size=min_height_size,
            min_width_size=min_width_size,
            max_height_size=max_height_size,
            max_width_size=max_width_size,
            stride=fixed_factor,
        )
    else:
        h, w = image_shape
        assert min_height_size <= h <= max_height_size and min_width_size <= w <= max_width_size, (
            f"Image shape {h}x{w} is not in the range of "
            f"{min_height_size}x{min_width_size} to {max_height_size}x{max_width_size}"
        )
        assert h % fixed_factor == 0 and w % fixed_factor == 0, \
            f"Image shape {h}x{w} is not divisible by {fixed_factor}"
        target_hw = (h, w)

    resize_options = dict(
        padding_value=padding_value,
        keep_aspect_ratio=keep_aspect_ratio,
        center_pad=center_pad,
        no_scale_up=no_scale_up,
        return_crop_info=return_crop_info,
        interpolation=interpolation,
        backend=backend,
    )
    return resize_with_padding(img, image_shape=target_hw, **resize_options)
class ImageProcess:
"""
Base class for image process
processing order: transform -> resize -> normalize -> permute
"""
def __init__(
self,
do_resize: bool = False,
do_permute: bool = False,
do_normalize: bool = False,
resize_config: dict | None = None,
normalize_config: dict | None = None,
padding_value: int = 255,
num_workers: int = 8
):
self.do_resize = do_resize
self.do_permute = do_permute
self.do_normalize = do_normalize
self.resize_config = resize_config if resize_config is not None else {}
self.normalize_config = normalize_config if normalize_config is not None else {}
self.padding_value = padding_value
self.num_workers = min(num_workers, os.cpu_count())
self.normalize_obj = NormalizeHub(**normalize_config)
self.resizer = partial(normalize_shape_resize_with_padding, **resize_config)
def __repr__(self):
return (f"{self.__class__.__name__}\ndo_resize: {self.do_resize}\ndo_permute: {self.do_permute}\n"
f"do_normalize: {self.do_normalize}\nresize_config: {self.resize_config}\n"
f"normalize_config: {self.normalize_config}\n"
f"num_workers: {self.num_workers}\n")
def save_pretrained(self, save_rtpath):
config = dict(
do_resize=self.do_resize,
do_permute=self.do_permute,
do_normalize=self.do_normalize,
resize_config=self.resize_config,
normalize_config=self.normalize_config,
padding_value=self.padding_value,
num_workers=self.num_workers
)
with open(os.path.join(save_rtpath, "img_processor.json"), 'w', encoding="utf-8") as f:
f.write(json.dumps(config, indent=4))
def __call__(
self,
img: list[np.ndarray] | np.ndarray | list[Image.Image] | Image.Image | str | list[str],
image_shape: tuple[int, int] | None = None,
to_continuous: bool = True,
**kwargs
) -> ImageProcessOutput:
"""batch input -> batch output"""
if isinstance(img, list):
img = [to_numpy(i) for i in img]
else:
img = to_numpy(img)
res = self.preprocessing(
img=img, image_shape=image_shape, to_continuous=to_continuous, **kwargs
)
if isinstance(res, tuple):
img = res[0]
else:
img = res
return ImageProcessOutput(pixel_values=torch.from_numpy(img))
def _single_preprocessing(
self,
_img: np.ndarray | Image.Image,
image_shape: tuple[int, int] | None = None,
order_id: int | None = None,
**kwargs
) -> np.ndarray | tuple:
"""
Runner order:
step1: check whether transform
step2: check whether resize
step3: check whether normalize
step4: check whether permute
Args:
_img:
image_shape: tuple[int, int] | None = None,
order_id: int | None
**kwargs:
Returns:
"""
info = None
if kwargs.get("do_resize", getattr(self, "do_resize", False)):
assert hasattr(self, "resizer"), f"{self.__class__} does not have `resizer`"
if image_shape is not None:
_img = getattr(self, "resizer")(_img, image_shape=image_shape)
else:
_img = getattr(self, "resizer")(_img)
if isinstance(_img, tuple):
_img, *info = _img
info = info if len(info) > 1 else info[0]
if kwargs.get("do_normalize", getattr(self, "do_normalize", False)):
assert hasattr(self, "normalize_obj"), f"{self.__class__} does not have `normalize_obj`"
_img = getattr(self, "normalize_obj")(img=_img.astype(np.float32))
else:
_img = _img.astype(np.float32)
if kwargs.get("do_permute", getattr(self, "do_permute", False)):
_img = _img.transpose(2, 0, 1)
if order_id is None:
return _img if info is None else (_img, info)
else:
return (_img, order_id) if info is None else (_img, info, order_id)
def preprocessing(
self,
img: np.ndarray | list[np.ndarray],
image_shape: tuple[int, int] | None = None,
to_continuous: bool = True,
**kwargs
) -> np.ndarray | tuple[np.ndarray, list]:
"""
resize & normalize & permute
Args:
img (Union[np.ndarray, List[np.ndarray]]):\
image_shape: tuple[int, int] | None = None,
to_continuous
timer(Timer | None)
Return:
img_pre (np.ndarray), (B, 3, H, W)
"""
if isinstance(img, np.ndarray):
img = [img]
img_ls = []
img_shape_ls = []
info_ls = []
if len(img) == 1 or self.num_workers <= 1:
for idx, cur_img in enumerate(img):
info = None
cur_img = self._single_preprocessing(cur_img, image_shape=image_shape, **kwargs)
if isinstance(cur_img, tuple):
cur_img, *info = cur_img
info = info if len(info) > 1 else info[0]
info_ls.append(info)
img_ls.append(cur_img)
img_shape_ls.append(cur_img.shape)
else:
with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
futures = [
executor.submit(
self._single_preprocessing, image_shape=image_shape, _img=cur_img, order_id=order_id
) for order_id, cur_img in enumerate(img)
]
results = []
for future in as_completed(futures):
results.append(future.result())
results = [i[0] if len(i) == 2 else i[:-1] for i in sorted(results, key=lambda x: x[-1])]
info = None
for result in results:
if isinstance(result, tuple):
cur_img, *info = result
info = info if len(info) > 1 else info[0]
else:
cur_img = result
info_ls.append(info)
img_ls.append(cur_img)
img_shape_ls.append(cur_img.shape)
if len(set(img_shape_ls)) > 1:
padding_value = kwargs.get("padding_value", getattr(self, "padding_value", None))
assert padding_value is not None, f"You should setting `padding_value`"
if kwargs.get("do_normalize", self.do_normalize):
dtype = np.float32
else:
dtype = img_ls[0].dtype
img_batch = np.ones(
shape=np.max(np.array(img_shape_ls), axis=0).tolist(),
dtype=dtype
) * padding_value
if kwargs.get("do_normalize", getattr(self, "do_normalize", False)):
assert hasattr(self, "normalize_obj"), f"{self.__class__} does not have `normalize_obj`"
if kwargs.get("do_permute", getattr(self, "do_permute", False)):
img_batch = img_batch.transpose(1, 2, 0)
img_batch = getattr(self, "normalize_obj")(img_batch)
img_batch = img_batch.transpose(2, 0, 1)
else:
img_batch = getattr(self, "normalize_obj")(img_batch)
img_batch = np.stack([img_batch] * len(img_shape_ls), axis=0)
for idx, cur_img in enumerate(img_ls):
cur_shape = cur_img.shape
img_batch[idx, :cur_shape[0], :cur_shape[1], :cur_shape[2]] = cur_img
return img_batch if info_ls[0] is None else (img_batch, info_ls)
else:
if to_continuous:
return np.ascontiguousarray(np.stack(img_ls)) if info_ls[0] is None \
else (np.ascontiguousarray(np.stack(img_ls)), info_ls)
else:
return np.stack(img_ls) if info_ls[0] is None else (np.stack(img_ls), info_ls)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "ImageProcess":
subfolder = kwargs.get("subfolder")
if os.path.isfile(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
elif os.path.isdir(pretrained_model_name_or_path):
config_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path
config_file = os.path.join(config_dir, "img_processor.json")
else:
from transformers.utils import cached_file
cache_kwargs = {
key: kwargs[key]
for key in (
"cache_dir",
"force_download",
"local_files_only",
"revision",
"subfolder",
"token",
)
if key in kwargs and kwargs[key] is not None
}
if "use_auth_token" in kwargs and "token" not in cache_kwargs:
cache_kwargs["token"] = kwargs["use_auth_token"]
config_file = cached_file(pretrained_model_name_or_path, "img_processor.json", **cache_kwargs)
assert config_file is not None and os.path.exists(config_file), f"{config_file} does not exist!"
with open(config_file, 'r', encoding="utf-8", errors="ignore") as f:
config = json.load(f, strict=True)
return cls(**config)