# unirec_0_1b / processing_unirec.py
# Hugging Face Hub listing metadata (uploaded by topdu via huggingface_hub,
# commit d2469d0, verified) — kept as a comment so the module stays importable.
from typing import List, Optional, Union
import re
import torch
from PIL import Image
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import (
ImageInput,
make_list_of_images,
valid_images,
to_numpy_array,
)
from transformers.utils import TensorType
# Regex cleanup rules applied in order: drop sentence/unknown markers,
# strip sentence tags and the U+FFFF filler, and squash long runs of
# underscores / dots down to three characters.
rules = [
    (r'-<\|sn\|>', ''),
    (r' <\|sn\|>', ' '),
    (r'<\|sn\|>', ' '),
    (r'<\|unk\|>', ''),
    (r'<s>', ''),
    (r'</s>', ''),
    (r'\uffff', ''),
    (r'_{4,}', '___'),
    (r'\.{4,}', '...'),
]


def clean_special_tokens(text):
    """Decode BPE whitespace markers and strip special tokens from model output.

    Literal spaces are removed first (they are artifacts), then the BPE
    markers 'Ġ'/'Ċ' are mapped back to space/newline, control tokens are
    dropped, the regex ``rules`` are applied, and table-cell attributes
    fused during tokenization are repaired.
    """
    # Pass 1: plain substring substitutions, order matters — real spaces
    # must be removed before 'Ġ' is turned back into a space.
    for old, new in (
        (' ', ''),
        ('Ġ', ' '),
        ('Ċ', '\n'),
        ('<|bos|>', ''),
        ('<|eos|>', ''),
        ('<|pad|>', ''),
    ):
        text = text.replace(old, new)
    # Pass 2: ordered regex cleanup.
    for pattern, repl in rules:
        text = re.sub(pattern, repl, text)
    # Pass 3: re-insert the space lost between tag name and attribute.
    for old, new in (
        ('<tdcolspan=', '<td colspan='),
        ('<tdrowspan=', '<td rowspan='),
        ('"colspan=', '" colspan='),
    ):
        text = text.replace(old, new)
    return text
class UniRecImageProcessor(BaseImageProcessor):
    """Image processor for UniRec.

    Pipeline per image: (1) shrink to fit within ``max_side`` while keeping
    the aspect ratio, then floor both sides to multiples of
    ``divided_factor`` (never below 64); (2) rescale pixel values by
    ``rescale_factor`` (default maps [0, 255] -> [0, 1]); (3) normalize with
    ``image_mean``/``image_std`` (defaults 0.5/0.5, i.e. values end up in
    [-1, 1]); (4) transpose HWC -> CHW. A per-image ``valid_ratio``
    (resized width / original width, capped at 1.0) is returned alongside
    the pixel values.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        max_side: List[int] = [64 * 15, 64 * 22],  # [960, 1408] = (max width, max height)
        divided_factor: List[int] = [64, 64],  # output sides are floored to multiples of (w, h) factors
        do_resize: bool = True,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255.0,
        do_normalize: bool = True,
        image_mean: Union[float, List[float]] = [0.5, 0.5, 0.5],
        image_std: Union[float, List[float]] = [0.5, 0.5, 0.5],
        resample: int = Image.BICUBIC,  # corresponds to T.InterpolationMode.BICUBIC
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.max_side = max_side
        self.divided_factor = divided_factor
        self.do_resize = do_resize
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std
        self.resample = resample

    def _calculate_target_size(self, original_width: int, original_height: int):
        """Replicate the custom ``resize_image`` + divisibility logic.

        Returns a ``(width, height)`` tuple: the image is shrunk (never
        enlarged) to fit within ``max_side`` preserving aspect ratio, then
        both sides are floored to multiples of ``divided_factor`` with a
        hard minimum of 64 pixels per side.
        """
        max_width, max_height = self.max_side[0], self.max_side[1]
        # 1. Aspect-ratio-preserving fit: only shrink when a side exceeds
        # its maximum; pick the binding dimension by comparing ratios.
        aspect_ratio = original_width / original_height
        if original_width > max_width or original_height > max_height:
            if (max_width / max_height) >= aspect_ratio:
                # Height is the binding constraint.
                new_height = max_height
                new_width = int(new_height * aspect_ratio)
            else:
                # Width is the binding constraint.
                new_width = max_width
                new_height = int(new_width / aspect_ratio)
        else:
            new_width, new_height = original_width, original_height
        # 2. Divisibility: floor each side to its factor, with a 64 px floor.
        # NOTE: whether max_side[0] is width or height depends on how
        # imgW/imgH were defined upstream; PIL ``size`` is (W, H).
        # Original code: h_r = max(int(h_r // factor * factor), 64)
        div_w, div_h = self.divided_factor[0], self.divided_factor[1]
        final_width = max(int(new_width // div_w * div_w), 64)
        final_height = max(int(new_height // div_h * div_h), 64)
        return (final_width, final_height)

    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        do_rescale: Optional[bool] = None,
        do_normalize: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[str] = "channels_first",
        input_data_format: Optional[str] = None,
        **kwargs,
    ) -> BatchFeature:
        """Run the full preprocessing pipeline on one image or a batch.

        Per-call flags default to the values set in ``__init__``. Returns a
        :class:`BatchFeature` with ``pixel_values`` (list of CHW arrays when
        ``data_format == "channels_first"``) and ``valid_ratio`` (one float
        per image).
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        images = make_list_of_images(images)
        if not valid_images(images):
            raise ValueError("Invalid image type. Must be PIL Image, numpy array, or tensor.")
        # Result containers, one entry per input image.
        pixel_values = []
        valid_ratios = []
        for image in images:
            # Ensure a PIL Image so we can use PIL's resize below.
            if not isinstance(image, Image.Image):
                image = to_numpy_array(image)
                image = Image.fromarray(image)
            original_width, original_height = image.size
            # --- 1. Resize (custom logic) ---
            if do_resize:
                target_size = self._calculate_target_size(original_width, original_height)
                # valid_ratio replicates the original logic:
                # min(1.0, resized_width / original_width).
                valid_ratio = min(1.0, float(target_size[0] / original_width))
                image = image.resize(target_size, resample=self.resample)
            else:
                valid_ratio = 1.0
            # --- 2. Convert to numpy & rescale (the ToTensor part) ---
            # T.ToTensor() would map PIL [0, 255] to float [0.0, 1.0];
            # here the rescale step does the division.
            # NOTE(review): the [:, :, :3] slice assumes a channels-last
            # array with >= 3 channels (drops alpha); a 2-D grayscale
            # array would raise here — confirm inputs are RGB/RGBA.
            image = to_numpy_array(image)[:, :, :3]  # (H, W, C)
            if do_rescale:
                image = self.rescale(image, scale=self.rescale_factor, input_data_format=input_data_format)
            # --- 3. Normalize ---
            # T.Normalize(0.5, 0.5) -> (x - 0.5) / 0.5
            if do_normalize:
                image = self.normalize(image, mean=self.image_mean, std=self.image_std,
                                       input_data_format=input_data_format)
            # --- 4. Transpose (HWC -> CHW) ---
            # HuggingFace models conventionally expect channels-first.
            if data_format == "channels_first":
                image = image.transpose((2, 0, 1))
            pixel_values.append(image)
            valid_ratios.append(valid_ratio)
        # Pack and (optionally) tensorize the results.
        data = {"pixel_values": pixel_values, "valid_ratio": valid_ratios}
        return BatchFeature(data=data, tensor_type=return_tensors)
if __name__ == "__main__":
    # Smoke test: run one image through the processor and verify the
    # configuration round-trips through save_pretrained / from_pretrained.
    import sys

    # 1. Instantiate the processor with the default limits.
    processor = UniRecImageProcessor(
        max_side=[960, 1408],
        divided_factor=[64, 64],
    )

    # 2. Load a test image. The hard-coded path only exists on the original
    # author's machine, so allow overriding it from the command line
    # (python processing_unirec.py /path/to/image.png).
    default_img_path = "/mnt/bn/dykdataa800/workspace/openocrdoc/OpenOCR/crop_img_hand/Snipaste_2025-04-13_20-46-06.png"
    img_path = sys.argv[1] if len(sys.argv) > 1 else default_img_path
    image = Image.open(img_path).convert("RGB")

    # 3. Process the image into PyTorch tensors.
    inputs = processor(image, return_tensors="pt")
    print("Keys:", inputs.keys())
    print("Shape:", inputs["pixel_values"].shape)
    print("Valid Ratio:", inputs["valid_ratio"])

    # 4. Round-trip the configuration through disk.
    processor.save_pretrained("./unirec_0_1b_mbart")
    loaded_processor = UniRecImageProcessor.from_pretrained("./unirec_0_1b_mbart")
    print(loaded_processor.max_side)  # [960, 1408]
    print(loaded_processor.divided_factor)  # [64, 64]

    # 5. The reloaded processor must produce identical tensors.
    result = loaded_processor(image, return_tensors="pt")
    print(torch.equal(inputs["pixel_values"], result["pixel_values"]))