| | from typing import List, Optional, Union |
| | import re |
| | import torch |
| | from PIL import Image |
| | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature |
| | from transformers.image_utils import ( |
| | ImageInput, |
| | make_list_of_images, |
| | valid_images, |
| | to_numpy_array, |
| | ) |
| | from transformers.utils import TensorType |
| |
|
| |
|
# Text-cleanup rules applied in order by clean_special_tokens().
# Patterns are pre-compiled once at import time so the per-call re.sub()
# loop skips the pattern-cache lookup (they run on every decoded string).
# NOTE(review): '<|sn|>' is presumably a soft-newline marker emitted by the
# decoder — confirm against the tokenizer's special-token list.
rules = [
    (re.compile(r'-<\|sn\|>'), ''),    # marker glued to a hyphen: drop both
    (re.compile(r' <\|sn\|>'), ' '),   # space + marker -> single space
    (re.compile(r'<\|sn\|>'), ' '),    # bare marker -> space
    (re.compile(r'<\|unk\|>'), ''),    # unknown-token marker
    (re.compile(r'<s>'), ''),          # sentence-start marker
    (re.compile(r'</s>'), ''),         # sentence-end marker
    (re.compile(r'\uffff'), ''),       # stray non-character code point
    (re.compile(r'_{4,}'), '___'),     # collapse long underscore runs
    (re.compile(r'\.{4,}'), '...'),    # collapse long dot runs
]
| |
|
| |
|
def clean_special_tokens(text):
    """Map raw decoder output back to plain text.

    Undoes byte-level BPE markers ('Ġ' encodes a space, 'Ċ' a newline),
    strips special tokens, applies the module-level ``rules`` regex
    cleanups, then restores the spaces inside HTML table attributes that
    the blanket space-removal destroyed.
    """
    # Literal replacements first: drop real spaces, decode BPE markers,
    # and remove the control tokens.
    for old, new in (
        (' ', ''),
        ('Ġ', ' '),
        ('Ċ', '\n'),
        ('<|bos|>', ''),
        ('<|eos|>', ''),
        ('<|pad|>', ''),
    ):
        text = text.replace(old, new)

    # Regex-based cleanups, applied in declaration order.
    for pattern, replacement in rules:
        text = re.sub(pattern, replacement, text)

    # Re-insert the attribute spaces that the space-stripping above removed.
    for old, new in (
        ('<tdcolspan=', '<td colspan='),
        ('<tdrowspan=', '<td rowspan='),
        ('"colspan=', '" colspan='),
    ):
        text = text.replace(old, new)

    return text
| |
|
class UniRecImageProcessor(BaseImageProcessor):
    """Image processor for UniRec OCR-style models.

    Per-image pipeline:
      1. Optionally resize so the image fits inside ``max_side`` while
         preserving aspect ratio, then snap both dimensions down to
         multiples of ``divided_factor`` (minimum 64 px per side).
      2. Optionally rescale pixel values by ``rescale_factor``.
      3. Optionally normalize with ``image_mean`` / ``image_std``.

    ``preprocess`` returns ``pixel_values`` plus a per-image
    ``valid_ratio`` — the fraction of the target width covered by the
    original content, capped at 1.0.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        max_side: Optional[List[int]] = None,
        divided_factor: Optional[List[int]] = None,
        do_resize: bool = True,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255.0,
        do_normalize: bool = True,
        image_mean: Union[float, List[float], None] = None,
        image_std: Union[float, List[float], None] = None,
        resample: int = Image.BICUBIC,
        **kwargs,
    ):
        """
        Args:
            max_side: ``[max_width, max_height]`` bounding box used when
                resizing. Defaults to ``[960, 1408]`` (64*15 x 64*22).
            divided_factor: ``[w_factor, h_factor]`` that the output
                dimensions are floored to. Defaults to ``[64, 64]``.
            do_resize: whether to resize to the computed target size.
            do_rescale: whether to multiply pixels by ``rescale_factor``.
            rescale_factor: scale applied when rescaling; 1/255 maps uint8
                pixels into [0, 1].
            do_normalize: whether to normalize with mean/std.
            image_mean: per-channel mean; defaults to 0.5 per channel.
            image_std: per-channel std; defaults to 0.5 per channel
                (together with the mean this maps [0, 1] to [-1, 1]).
            resample: PIL resampling filter passed to ``Image.resize``.
        """
        super().__init__(**kwargs)
        # The original signature used mutable list defaults, which are
        # shared across all instances (classic Python pitfall). Keep the
        # same effective values but build a fresh list per instance.
        self.max_side = [64 * 15, 64 * 22] if max_side is None else max_side
        self.divided_factor = [64, 64] if divided_factor is None else divided_factor
        self.do_resize = do_resize
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = [0.5, 0.5, 0.5] if image_mean is None else image_mean
        self.image_std = [0.5, 0.5, 0.5] if image_std is None else image_std
        self.resample = resample

    def _calculate_target_size(self, original_width, original_height):
        """Replicate the project's custom ``resize_image`` + divisibility logic.

        Returns the ``(width, height)`` to resize to: shrink-to-fit inside
        ``max_side`` with the aspect ratio preserved, then floor each side
        to a multiple of ``divided_factor`` (never below 64 px).
        """
        max_width, max_height = self.max_side[0], self.max_side[1]

        aspect_ratio = original_width / original_height

        if original_width > max_width or original_height > max_height:
            # Fit along the binding dimension; derive the other side from
            # the aspect ratio.
            if (max_width / max_height) >= aspect_ratio:
                new_height = max_height
                new_width = int(new_height * aspect_ratio)
            else:
                new_width = max_width
                new_height = int(new_width / aspect_ratio)
        else:
            new_width, new_height = original_width, original_height

        div_w, div_h = self.divided_factor[0], self.divided_factor[1]

        # Floor to the divisibility grid; clamp so tiny images still yield
        # at least one 64-px cell per side (may upscale very small inputs).
        final_width = max(int(new_width // div_w * div_w), 64)
        final_height = max(int(new_height // div_h * div_h), 64)

        return (final_width, final_height)

    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        do_rescale: Optional[bool] = None,
        do_normalize: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[str] = "channels_first",
        input_data_format: Optional[str] = None,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess one image or a batch of images.

        Args:
            images: a single image or a list of images (PIL, numpy, tensor).
            do_resize / do_rescale / do_normalize: per-call overrides for
                the instance defaults.
            return_tensors: output tensor type forwarded to ``BatchFeature``.
            data_format: ``"channels_first"`` transposes HWC -> CHW; any
                other value leaves the array channels-last.
            input_data_format: forwarded to ``rescale`` / ``normalize``.

        Returns:
            ``BatchFeature`` with ``pixel_values`` and ``valid_ratio``.

        Raises:
            ValueError: if ``images`` is not a valid image type.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError("Invalid image type. Must be PIL Image, numpy array, or tensor.")

        pixel_values = []
        valid_ratios = []

        for image in images:
            # Route everything through PIL so .size and .resize are available.
            if not isinstance(image, Image.Image):
                image = to_numpy_array(image)
                image = Image.fromarray(image)

            original_width, original_height = image.size

            if do_resize:
                target_size = self._calculate_target_size(original_width, original_height)

                # Fraction of the target width covered by real content after
                # snapping to the divisibility grid; capped at 1.0 for
                # images that get upscaled.
                valid_ratio = min(1.0, float(target_size[0] / original_width))

                image = image.resize(target_size, resample=self.resample)
            else:
                valid_ratio = 1.0

            # HWC array; keep only the first 3 channels (drops alpha).
            # NOTE(review): assumes a 3-D array — a single-channel (2-D)
            # image would raise here; confirm callers always pass RGB(A).
            image = to_numpy_array(image)[:, :, :3]

            if do_rescale:
                image = self.rescale(image, scale=self.rescale_factor, input_data_format=input_data_format)

            if do_normalize:
                image = self.normalize(image, mean=self.image_mean, std=self.image_std,
                                       input_data_format=input_data_format)

            if data_format == "channels_first":
                image = image.transpose((2, 0, 1))  # HWC -> CHW

            pixel_values.append(image)
            valid_ratios.append(valid_ratio)

        data = {"pixel_values": pixel_values, "valid_ratio": valid_ratios}

        return BatchFeature(data=data, tensor_type=return_tensors)
| |
|
| |
|
| | if __name__ == "__main__": |
| | |
| | processor = UniRecImageProcessor( |
| | max_side=[960, 1408], |
| | divided_factor=[64, 64] |
| | ) |
| |
|
| | |
| | img_path = "/mnt/bn/dykdataa800/workspace/openocrdoc/OpenOCR/crop_img_hand/Snipaste_2025-04-13_20-46-06.png" |
| | image = Image.open(img_path).convert("RGB") |
| |
|
| | |
| | inputs = processor(image, return_tensors="pt") |
| |
|
| | print("Keys:", inputs.keys()) |
| | print("Shape:", inputs["pixel_values"].shape) |
| | print("Valid Ratio:", inputs["valid_ratio"]) |
| |
|
| | |
| | processor.save_pretrained("./unirec_0_1b_mbart") |
| |
|
| | |
| | loaded_processor = UniRecImageProcessor.from_pretrained("./unirec_0_1b_mbart") |
| |
|
| | |
| | print(loaded_processor.max_side) |
| | print(loaded_processor.divided_factor) |
| |
|
| | result = loaded_processor(image, return_tensors="pt") |
| | print(torch.equal(inputs["pixel_values"], result["pixel_values"])) |