File size: 7,279 Bytes
d2469d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
from typing import List, Optional, Union
import re
import torch
from PIL import Image
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import (
    ImageInput,
    make_list_of_images,
    valid_images,
    to_numpy_array,
)
from transformers.utils import TensorType


# Ordered post-processing substitutions applied by ``clean_special_tokens``.
# Each entry is (regex pattern, replacement); order matters because the
# ``<|sn|>`` variants overlap.
rules = [
    (r'-<\|sn\|>', ''),
    (r' <\|sn\|>', ' '),
    (r'<\|sn\|>', ' '),
    (r'<\|unk\|>', ''),
    (r'<s>', ''),
    (r'</s>', ''),
    (r'\uffff', ''),
    (r'_{4,}', '___'),
    (r'\.{4,}', '...'),
]


def clean_special_tokens(text):
    """Turn raw tokenizer output into readable text.

    Undoes byte-level BPE rendering (drops literal spaces, maps the
    'Ġ'/'Ċ' markers back to space/newline), strips special tokens,
    applies the regex ``rules``, and restores the space that tokenization
    fused into HTML table attributes.
    """
    # Literal spaces are artifacts; real whitespace is encoded as Ġ / Ċ.
    for marker, plain in (
        (' ', ''),
        ('Ġ', ' '),
        ('Ċ', '\n'),
        ('<|bos|>', ''),
        ('<|eos|>', ''),
        ('<|pad|>', ''),
    ):
        text = text.replace(marker, plain)
    # Regex-based cleanup of separator/unknown tokens and long runs.
    for pattern, replacement in rules:
        text = re.sub(pattern, replacement, text)
    # Re-insert the space lost inside HTML table attributes.
    for fused, fixed in (
        ('<tdcolspan=', '<td colspan='),
        ('<tdrowspan=', '<td rowspan='),
        ('"colspan=', '" colspan='),
    ):
        text = text.replace(fused, fixed)
    return text

class UniRecImageProcessor(BaseImageProcessor):
    """Image processor for UniRec-style OCR models.

    Per-image pipeline:
      1. (optional) shrink the image to fit inside ``max_side`` while
         preserving aspect ratio, then snap each side down to a multiple
         of ``divided_factor`` (hard floor of 64 px);
      2. (optional) rescale pixel values by ``rescale_factor``;
      3. (optional) normalize with ``image_mean`` / ``image_std``;
      4. transpose HWC -> CHW when ``data_format == "channels_first"``.

    Besides ``pixel_values`` the output carries a per-image
    ``valid_ratio`` = resized width / original width, capped at 1.0
    (presumably consumed by downstream width masking — confirm against
    the model that uses this processor).
    """

    model_input_names = ["pixel_values"]

    def __init__(
            self,
            max_side: Optional[List[int]] = None,
            divided_factor: Optional[List[int]] = None,
            do_resize: bool = True,
            do_rescale: bool = True,
            rescale_factor: float = 1 / 255.0,
            do_normalize: bool = True,
            image_mean: Union[float, List[float], None] = None,
            image_std: Union[float, List[float], None] = None,
            resample: int = Image.BICUBIC,  # matches T.InterpolationMode.BICUBIC
            **kwargs,
    ):
        """
        Args:
            max_side: maximum [width, height]; defaults to
                [64 * 15, 64 * 22] == [960, 1408].
            divided_factor: [w, h] grid each output side is snapped down
                to; defaults to [64, 64].
            do_resize: whether to apply the aspect-preserving resize.
            do_rescale: whether to multiply pixel values by
                ``rescale_factor``.
            rescale_factor: 1/255 maps [0, 255] -> [0.0, 1.0], mirroring
                ``T.ToTensor``.
            do_normalize: whether to apply mean/std normalization.
            image_mean: per-channel mean; defaults to [0.5, 0.5, 0.5].
            image_std: per-channel std; defaults to [0.5, 0.5, 0.5].
            resample: PIL resampling filter used for the resize.
        """
        super().__init__(**kwargs)
        # None sentinels instead of mutable list defaults, so instances
        # never share (and cannot accidentally co-mutate) one list object.
        self.max_side = max_side if max_side is not None else [64 * 15, 64 * 22]
        self.divided_factor = divided_factor if divided_factor is not None else [64, 64]
        self.do_resize = do_resize
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
        self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
        self.resample = resample

    def _calculate_target_size(self, original_width, original_height):
        """Compute the output (width, height) for the custom resize rule.

        Shrinks (never enlarges) the image to fit inside ``max_side``
        while preserving aspect ratio, then rounds each side down to a
        multiple of ``divided_factor`` with a hard floor of 64 px.
        """
        max_width, max_height = self.max_side[0], self.max_side[1]

        aspect_ratio = original_width / original_height

        if original_width > max_width or original_height > max_height:
            if (max_width / max_height) >= aspect_ratio:
                # Height is the binding constraint.
                new_height = max_height
                new_width = int(new_height * aspect_ratio)
            else:
                # Width is the binding constraint.
                new_width = max_width
                new_height = int(new_width / aspect_ratio)
        else:
            # Already fits; only the grid snapping below applies.
            new_width, new_height = original_width, original_height

        # Snap down to the divisibility grid. PIL sizes are (W, H), so
        # max_side[0] / divided_factor[0] are the width-axis values.
        div_w, div_h = self.divided_factor[0], self.divided_factor[1]
        final_width = max(int(new_width // div_w * div_w), 64)
        final_height = max(int(new_height // div_h * div_h), 64)

        return (final_width, final_height)

    def preprocess(
            self,
            images: ImageInput,
            do_resize: Optional[bool] = None,
            do_rescale: Optional[bool] = None,
            do_normalize: Optional[bool] = None,
            return_tensors: Optional[Union[str, TensorType]] = None,
            data_format: Optional[str] = "channels_first",
            input_data_format: Optional[str] = None,
            **kwargs,
    ) -> BatchFeature:
        """Preprocess one image or a batch of images.

        Args:
            images: a single image or a list (PIL / numpy / tensor).
            do_resize / do_rescale / do_normalize: per-call overrides of
                the instance defaults.
            return_tensors: e.g. "pt" for stacked torch tensors.
            data_format: "channels_first" transposes HWC -> CHW.
            input_data_format: forwarded to the base-class rescale and
                normalize helpers.

        Returns:
            BatchFeature with "pixel_values" and "valid_ratio" entries.

        Raises:
            ValueError: if ``images`` contains an unsupported type.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError("Invalid image type. Must be PIL Image, numpy array, or tensor.")

        pixel_values = []
        valid_ratios = []

        for image in images:
            # Work on PIL images so we can use PIL's resize.
            if not isinstance(image, Image.Image):
                image = Image.fromarray(to_numpy_array(image))

            # Robustness fix: grayscale ("L") or paletted ("P") images
            # convert to 2-D arrays, which would crash the HWC channel
            # slice below. Promote everything except RGB/RGBA (alpha is
            # dropped by the slice) to 3-channel RGB.
            if image.mode not in ("RGB", "RGBA"):
                image = image.convert("RGB")

            original_width, original_height = image.size

            # --- 1. Resize (custom fit + snap-to-grid logic) ---
            if do_resize:
                target_size = self._calculate_target_size(original_width, original_height)
                # valid_ratio: resized width / original width, capped at 1.0.
                valid_ratio = min(1.0, float(target_size[0] / original_width))
                image = image.resize(target_size, resample=self.resample)
            else:
                valid_ratio = 1.0

            # --- 2. To numpy (H, W, C), dropping any alpha channel ---
            image = to_numpy_array(image)[:, :, :3]

            # --- 3. Rescale ([0, 255] -> [0.0, 1.0], like T.ToTensor) ---
            if do_rescale:
                image = self.rescale(image, scale=self.rescale_factor, input_data_format=input_data_format)

            # --- 4. Normalize, e.g. (x - 0.5) / 0.5 ---
            if do_normalize:
                image = self.normalize(image, mean=self.image_mean, std=self.image_std,
                                       input_data_format=input_data_format)

            # --- 5. HWC -> CHW for the usual HF "channels_first" layout ---
            if data_format == "channels_first":
                image = image.transpose((2, 0, 1))

            pixel_values.append(image)
            valid_ratios.append(valid_ratio)

        data = {"pixel_values": pixel_values, "valid_ratio": valid_ratios}

        return BatchFeature(data=data, tensor_type=return_tensors)


if __name__ == "__main__":
    # Smoke test: relies on a hard-coded local image path, so it only
    # runs in the original author's environment.

    # 1. Instantiate the processor
    processor = UniRecImageProcessor(
        max_side=[960, 1408],
        divided_factor=[64, 64]
    )

    # 2. Load a test image
    img_path = "/mnt/bn/dykdataa800/workspace/openocrdoc/OpenOCR/crop_img_hand/Snipaste_2025-04-13_20-46-06.png"
    image = Image.open(img_path).convert("RGB")

    # 3. Process the image (returns PyTorch tensors)
    inputs = processor(image, return_tensors="pt")

    print("Keys:", inputs.keys())
    print("Shape:", inputs["pixel_values"].shape)
    print("Valid Ratio:", inputs["valid_ratio"])

    # Save the processor config to a local directory
    processor.save_pretrained("./unirec_0_1b_mbart")

    # Reload from the local directory
    loaded_processor = UniRecImageProcessor.from_pretrained("./unirec_0_1b_mbart")

    # Verify the parameters survived the save/load round trip
    print(loaded_processor.max_side)  # [960, 1408]
    print(loaded_processor.divided_factor)  # [64, 64]

    result = loaded_processor(image, return_tensors="pt")
    print(torch.equal(inputs["pixel_values"], result["pixel_values"]))