from typing import List, Optional, Union
import re
import torch
from PIL import Image
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import (
ImageInput,
make_list_of_images,
valid_images,
to_numpy_array,
)
from transformers.utils import TensorType
# Regex post-processing rules, applied in order: collapse <|sn|>
# (separator) markers, drop unknown/sentence-boundary tokens and the
# \uffff filler, and cap long underscore / dot runs at three characters.
rules = [
    (r'-<\|sn\|>', ''),
    (r' <\|sn\|>', ' '),
    (r'<\|sn\|>', ' '),
    (r'<\|unk\|>', ''),
    (r'<s>', ''),
    (r'</s>', ''),
    (r'\uffff', ''),
    (r'_{4,}', '___'),
    (r'\.{4,}', '...'),
]


def clean_special_tokens(text):
    """Convert raw BPE decoder output into readable plain text.

    Literal spaces are tokenizer artifacts and are stripped first; the
    BPE markers 'Ġ' and 'Ċ' are then restored to space / newline, special
    tokens are removed via the module-level ``rules`` table, and a few
    run-together HTML table attribute spellings are repaired.
    """
    # Step 1: literal replacements (order matters — real spaces go away
    # before 'Ġ' is turned back into a space).
    for marker, replacement in (
        (' ', ''),
        ('Ġ', ' '),
        ('Ċ', '\n'),
        ('<|bos|>', ''),
        ('<|eos|>', ''),
        ('<|pad|>', ''),
    ):
        text = text.replace(marker, replacement)
    # Step 2: regex cleanup rules, in declaration order.
    for pattern, replacement in rules:
        text = re.sub(pattern, replacement, text)
    # Step 3: repair table-cell attributes that lost their space.
    for broken, fixed in (
        ('<tdcolspan=', '<td colspan='),
        ('<tdrowspan=', '<td rowspan='),
        ('"colspan=', '" colspan='),
    ):
        text = text.replace(broken, fixed)
    return text
class UniRecImageProcessor(BaseImageProcessor):
    """Image processor for the UniRec OCR model.

    Reproduces the model's original preprocessing pipeline:

    1. Resize each image to fit within ``max_side`` (aspect ratio kept),
       then snap width/height down to multiples of ``divided_factor``
       with a 64 px floor.
    2. Rescale pixel values by ``rescale_factor`` (the ``T.ToTensor``
       step: [0, 255] -> [0.0, 1.0]).
    3. Normalize with ``image_mean`` / ``image_std``
       (``T.Normalize(0.5, 0.5)`` -> ``(x - 0.5) / 0.5``).

    ``preprocess`` also returns a per-image ``valid_ratio`` =
    resized width / original width, capped at 1.0 — presumably used by
    the model to mask horizontal padding; confirm against the caller.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        max_side: List[int] = [64 * 15, 64 * 22],  # [960, 1408] as (w, h)
        divided_factor: List[int] = [64, 64],
        do_resize: bool = True,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255.0,
        do_normalize: bool = True,
        image_mean: Union[float, List[float]] = [0.5, 0.5, 0.5],
        image_std: Union[float, List[float]] = [0.5, 0.5, 0.5],
        resample: int = Image.BICUBIC,  # matches T.InterpolationMode.BICUBIC
        **kwargs,
    ):
        """
        Args:
            max_side: Maximum [width, height] after resizing.
            divided_factor: [w, h] factors the output dimensions must be
                divisible by (floored, minimum 64 px per side).
            do_resize: Whether to resize by default.
            do_rescale: Whether to rescale pixel values by default.
            rescale_factor: Multiplier applied during rescaling.
            do_normalize: Whether to normalize by default.
            image_mean: Per-channel mean (or scalar) for normalization.
            image_std: Per-channel std (or scalar) for normalization.
            resample: PIL resampling filter passed to ``Image.resize``.
        """
        super().__init__(**kwargs)
        # Copy list-valued arguments so instances never alias the shared
        # mutable defaults (or the caller's list): mutating one
        # instance's config must not leak into other instances.
        self.max_side = list(max_side)
        self.divided_factor = list(divided_factor)
        self.do_resize = do_resize
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = list(image_mean) if isinstance(image_mean, (list, tuple)) else image_mean
        self.image_std = list(image_std) if isinstance(image_std, (list, tuple)) else image_std
        self.resample = resample

    def _calculate_target_size(self, original_width, original_height):
        """Compute the resize target, replicating the original custom
        ``resize_image`` + divisibility logic.

        Returns:
            ``(width, height)``: scaled down to fit ``max_side`` when
            the image exceeds it, then floored to the nearest multiple
            of ``divided_factor`` with a minimum of 64 px per side.
        """
        max_width, max_height = self.max_side[0], self.max_side[1]
        # 1. Fit within the bounding box, keeping the aspect ratio.
        #    NOTE(review): assumes original_height > 0 — confirm callers
        #    never pass degenerate images.
        aspect_ratio = original_width / original_height
        if original_width > max_width or original_height > max_height:
            if (max_width / max_height) >= aspect_ratio:
                # Image is relatively tall: height is the binding limit.
                new_height = max_height
                new_width = int(new_height * aspect_ratio)
            else:
                # Image is relatively wide: width is the binding limit.
                new_width = max_width
                new_height = int(new_width / aspect_ratio)
        else:
            new_width, new_height = original_width, original_height
        # 2. Snap each side down to its divisibility factor, but never
        #    below 64 px (original: h_r = max(int(h_r // f * f), 64)).
        #    PIL sizes are (W, H); max_side is interpreted the same way.
        div_w, div_h = self.divided_factor[0], self.divided_factor[1]
        final_width = max(int(new_width // div_w * div_w), 64)
        final_height = max(int(new_height // div_h * div_h), 64)
        return (final_width, final_height)

    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        do_rescale: Optional[bool] = None,
        do_normalize: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[str] = "channels_first",
        input_data_format: Optional[str] = None,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess one image or a batch of images.

        Args:
            images: PIL image(s), numpy array(s) or tensor(s).
            do_resize / do_rescale / do_normalize: Per-call overrides of
                the instance defaults (``None`` means "use the default").
            return_tensors: Tensor type for the returned ``BatchFeature``.
            data_format: ``"channels_first"`` transposes output to CHW.
            input_data_format: Forwarded to ``rescale`` / ``normalize``.

        Returns:
            ``BatchFeature`` with ``pixel_values`` and per-image
            ``valid_ratio`` (resized width / original width, capped at 1).

        Raises:
            ValueError: If ``images`` is not a valid image type.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        images = make_list_of_images(images)
        if not valid_images(images):
            raise ValueError("Invalid image type. Must be PIL Image, numpy array, or tensor.")
        pixel_values = []
        valid_ratios = []
        for image in images:
            # Ensure a PIL Image so we can use PIL's resize with the
            # configured resampling filter.
            if not isinstance(image, Image.Image):
                image = to_numpy_array(image)
                image = Image.fromarray(image)
            original_width, original_height = image.size
            # --- 1. Resize (custom fit + divisibility logic) ---
            if do_resize:
                target_size = self._calculate_target_size(original_width, original_height)
                # valid_ratio uses the resized width over the original
                # width (original logic: min(1.0, float(w_r / w))).
                valid_ratio = min(1.0, float(target_size[0] / original_width))
                image = image.resize(target_size, resample=self.resample)
            else:
                valid_ratio = 1.0
            # --- 2. To numpy; keep the first 3 channels (drops alpha) ---
            # NOTE(review): assumes an HWC array with >= 3 channels; a
            # grayscale (2-D) array would raise here — confirm inputs.
            image = to_numpy_array(image)[:, :, :3]  # (H, W, C)
            if do_rescale:
                # T.ToTensor equivalent: [0, 255] -> [0.0, 1.0].
                image = self.rescale(image, scale=self.rescale_factor, input_data_format=input_data_format)
            # --- 3. Normalize: (x - mean) / std ---
            if do_normalize:
                image = self.normalize(image, mean=self.image_mean, std=self.image_std,
                                       input_data_format=input_data_format)
            # --- 4. HWC -> CHW, the layout the model expects ---
            if data_format == "channels_first":
                image = image.transpose((2, 0, 1))
            pixel_values.append(image)
            valid_ratios.append(valid_ratio)
        data = {"pixel_values": pixel_values, "valid_ratio": valid_ratios}
        return BatchFeature(data=data, tensor_type=return_tensors)
if __name__ == "__main__":
    # Smoke test: requires a local image file; not runnable elsewhere.
    # 1. Instantiate the processor
    processor = UniRecImageProcessor(
        max_side=[960, 1408],
        divided_factor=[64, 64]
    )
    # 2. Load a test image
    img_path = "/mnt/bn/dykdataa800/workspace/openocrdoc/OpenOCR/crop_img_hand/Snipaste_2025-04-13_20-46-06.png"
    image = Image.open(img_path).convert("RGB")
    # 3. Process the image (return PyTorch tensors)
    inputs = processor(image, return_tensors="pt")
    print("Keys:", inputs.keys())
    print("Shape:", inputs["pixel_values"].shape)
    print("Valid Ratio:", inputs["valid_ratio"])
    # Save the processor config to a local directory
    processor.save_pretrained("./unirec_0_1b_mbart")
    # Reload it from that directory
    loaded_processor = UniRecImageProcessor.from_pretrained("./unirec_0_1b_mbart")
    # Verify the parameters round-trip through save/load
    print(loaded_processor.max_side)  # [960, 1408]
    print(loaded_processor.divided_factor)  # [64, 64]
    result = loaded_processor(image, return_tensors="pt")
    print(torch.equal(inputs["pixel_values"], result["pixel_values"]))