# Falcon-OCR / processing_falcon_ocr.py
# Uploaded by yasserDahou via huggingface_hub (revision f22f735, verified).
import io
import math
import einops as E
import numpy as np
import requests
import torch
from PIL import Image
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_transforms import convert_to_rgb, resize
from transformers.image_utils import (
ImageInput,
get_image_size,
infer_channel_dimension_format,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
# Per-channel normalization constants; with rescale_factor 1/255 this maps
# pixel values from [0, 1] into [-1, 1].
IMAGE_MEAN = [0.5, 0.5, 0.5]
IMAGE_STD = [0.5, 0.5, 0.5]
def load_image(image):
    """Load an image from a variety of input formats into a PIL Image.

    Supported inputs:
      - ``None`` (passed through unchanged),
      - an already-open ``PIL.Image.Image``,
      - a string: http(s) URL, path to a ``.npy`` file, or a local file path,
      - raw encoded image bytes (``bytes``, including ``np.bytes_``),
      - a numpy array of decoded pixel data.

    Returns:
        A ``PIL.Image.Image``, or ``None`` when *image* is ``None``.

    Raises:
        TypeError: if the input matches none of the formats above.
        requests.HTTPError: if an HTTP download returns an error status.
    """
    if image is None:
        return None
    if isinstance(image, Image.Image):
        return image
    if isinstance(image, str):
        if image.startswith(("http://", "https://")):
            response = requests.get(image, timeout=10)
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content))
        if image.endswith(".npy"):
            # NOTE(review): assumes the .npy file stores the *encoded* image
            # bytes (e.g. PNG/JPEG) as a uint8 array, not decoded pixels —
            # Image.open needs a format header. Confirm with the data source.
            img_array = io.BytesIO(np.load(image))
            return Image.open(img_array)
        return Image.open(image)
    # Fix: accept plain ``bytes`` too — ``np.bytes_`` is a subclass of
    # ``bytes``, so this check is strictly more general than the previous
    # ``np.bytes_``-only test and remains backward compatible.
    if isinstance(image, bytes):
        return Image.open(io.BytesIO(image))
    if isinstance(image, np.ndarray):
        return Image.fromarray(image)
    raise TypeError(f"Unknown image format {image}")
def load_images(images_input, min_dimension: int, max_dimension: int):
    """Load every entry of *images_input* and normalize its size.

    Returns a list of loaded, size-bounded images; an empty list when
    *images_input* is ``None``.
    """
    if images_input is None:
        return []
    return [
        resize_image_if_necessary(load_image(entry), min_dimension, max_dimension)
        for entry in images_input
    ]
def resize_image_if_necessary(
    image,
    shortest_dimension=224,
    longest_dimension=896,
):
    """Rescale *image* so both sides fall inside the allowed range, keeping
    the aspect ratio.

    Images already within [shortest_dimension, longest_dimension] on both
    sides are returned untouched.
    """
    width, height = image.size
    ratio = width / height
    within_bounds = (
        shortest_dimension <= width <= longest_dimension
        and shortest_dimension <= height <= longest_dimension
    )
    if within_bounds:
        return image
    portrait = width < height
    # Upscale to the minimum when any side is too small, otherwise fit the
    # larger target; in both cases the smaller side drives the computation.
    too_small = width < shortest_dimension or height < shortest_dimension
    target = shortest_dimension if too_small else longest_dimension
    if portrait:
        out_w = target
        out_h = int(out_w / ratio)
    else:
        out_h = target
        out_w = int(out_h * ratio)
    # Clamp back in case the aspect ratio pushed the other side past the cap.
    if out_w > longest_dimension:
        out_w = longest_dimension
        out_h = int(out_w / ratio)
    if out_h > longest_dimension:
        out_h = longest_dimension
        out_w = int(out_h * ratio)
    return image.resize((out_w, out_h))
def smart_resize(
    image,
    factor: int,
    resample,
    input_data_format,
    min_pixels: int = 56 * 56,
    max_pixels: int = 14 * 14 * 4 * 1280,
):
    """Resize *image* so both sides are multiples of *factor* and the total
    pixel count lies within [min_pixels, max_pixels], roughly preserving the
    aspect ratio.

    Raises:
        ValueError: if either side is smaller than *factor*, or the aspect
            ratio exceeds 200.
    """
    height, width = get_image_size(image, channel_dim=input_data_format)
    if height < factor or width < factor:
        raise ValueError(f"{height=} or {width=} must be larger than {factor=}")
    if max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )
    # Snap each side to the nearest multiple of `factor`.
    new_h = round(height / factor) * factor
    new_w = round(width / factor) * factor
    area = new_h * new_w
    if area > max_pixels:
        # Too many pixels: shrink both sides; floor keeps the area under cap.
        scale = np.sqrt((height * width) / max_pixels)
        new_h = math.floor(height / scale / factor) * factor
        new_w = math.floor(width / scale / factor) * factor
    elif area < min_pixels:
        # Too few pixels: grow both sides; ceil keeps the area above floor.
        scale = np.sqrt(min_pixels / (height * width))
        new_h = math.ceil(height * scale / factor) * factor
        new_w = math.ceil(width * scale / factor) * factor
    return resize(
        image,
        size=(new_h, new_w),
        resample=resample,
        input_data_format=input_data_format,
    )
class ImageProcessor(BaseImageProcessor):
    """Preprocessor producing patch-aligned, rescaled, normalized pixel arrays.

    Resizing goes through ``smart_resize`` so both output sides are multiples
    of ``patch_size * merge_size`` and the pixel count stays within
    [min_pixels, max_pixels].
    """
    def __init__(
        self,
        patch_size,
        merge_size,
        do_resize: bool = True,
        resample: Image.Resampling = Image.Resampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255,
        do_normalize: bool = True,
        image_mean: float | list[float] | None = None,
        image_std: float | list[float] | None = None,
        do_convert_rgb: bool = True,
        min_pixels: int = 56 * 56,
        max_pixels: int = 28 * 28 * 1280,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        # Fall back to the module-level defaults ([0.5] * 3) when not given.
        self.image_mean = image_mean or IMAGE_MEAN
        self.image_std = image_std or IMAGE_STD
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.patch_size = patch_size
        self.merge_size = merge_size
        self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
        self.do_convert_rgb = do_convert_rgb
        validate_preprocess_arguments(
            rescale_factor=self.rescale_factor,
            do_normalize=self.do_normalize,
            image_mean=self.image_mean,
            image_std=self.image_std,
            do_resize=self.do_resize,
            size=self.size,
            resample=self.resample,
        )
    def _preprocess(self, image: ImageInput, do_rescale=None, do_normalize=None):
        """Convert, resize, rescale and normalize one image to a numpy array.

        NOTE(review): ``do_rescale or self.do_rescale`` means an explicit
        ``do_rescale=False`` cannot disable rescaling when the instance
        default is True (same for normalize) — confirm whether caller-override
        semantics were intended.
        """
        if self.do_convert_rgb:
            image = convert_to_rgb(image)
        image = to_numpy_array(image)
        input_data_format = infer_channel_dimension_format(image)
        if self.do_resize:
            # Both output sides become multiples of patch_size * merge_size.
            image = smart_resize(
                image,
                factor=self.patch_size * self.merge_size,
                resample=self.resample,
                input_data_format=input_data_format,
                min_pixels=self.min_pixels,
                max_pixels=self.max_pixels,
            )
        if do_rescale or self.do_rescale:
            image = self.rescale(image, scale=self.rescale_factor, input_data_format=input_data_format)
        if do_normalize or self.do_normalize:
            image = self.normalize(
                image=image, mean=self.image_mean, std=self.image_std,
                input_data_format=input_data_format,
            )
        return image
    def preprocess(self, images: list[ImageInput] | None, do_rescale=None, do_normalize=None, **kwargs):
        """Preprocess a list of images.

        ``None`` input (or an all-``None`` list) yields an empty list;
        otherwise returns one numpy array per image with a leading singleton
        axis prepended.
        """
        del kwargs
        if images is None:
            return []
        images = [item for item in images if item is not None]
        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        pixel_values = []
        for image in images:
            processed_image = self._preprocess(image, do_rescale, do_normalize)
            # Prepend a singleton axis so batching can treat each image as a
            # length-1 temporal sequence.
            processed_image = processed_image[None, ...]
            pixel_values.append(processed_image)
        return pixel_values
    def batch_images_with_mask(self, pixel_values, max_image_height, max_image_width):
        """Pad a list of (T, H, W, 3) arrays to shared dims and stack them.

        Returns a dict with ``pixel_values`` of shape
        (N, T_max, max_image_height, max_image_width, 3) and a long
        ``padding_mask`` of shape (N, T_max, H, W) with 1 on real pixels,
        or ``None`` when there is nothing to batch.
        """
        if pixel_values is None:
            return None
        pixel_values = [item for item in pixel_values if item is not None and len(item) != 0]
        if len(pixel_values) == 0:
            return None
        pixel_values = [torch.from_numpy(img) for img in pixel_values]
        max_temporal = max(img.shape[0] for img in pixel_values)
        def pad_image_and_mask(img):
            # F.pad's tuple lists (before, after) pairs from the LAST dim
            # backwards: channels, width, height, time — only "after" padding
            # is used here, so content stays top-left aligned.
            time_steps, height, width, channels = img.shape
            if channels != 3:
                raise ValueError(f"Expected 3-channel RGB images, got {channels} channels.")
            padding = (0, 0, 0, max_image_width - width, 0, max_image_height - height, 0, max_temporal - time_steps)
            padded_image = torch.nn.functional.pad(img, padding)
            mask = torch.zeros((max_temporal, max_image_height, max_image_width), dtype=torch.long)
            mask[:time_steps, :height, :width] = 1
            return padded_image, mask
        padded_pixel_values, padding_masks = zip(*[pad_image_and_mask(img) for img in pixel_values])
        padded_pixel_values = torch.stack(list(padded_pixel_values))
        padding_masks = torch.stack(list(padding_masks))
        return {"pixel_values": padded_pixel_values, "padding_mask": padding_masks}
# ---------------------------------------------------------------------------
# Positional encoding helpers
# ---------------------------------------------------------------------------
def _compute_image_spatial_positions(
    pixel_mask_THW: torch.Tensor,
    spatial_patch_size: int,
    temporal_patch_size: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute normalized (h, w) coordinates for each valid patch of one image.

    Args:
        pixel_mask_THW: (T, H, W) 0/1 mask marking real (unpadded) pixels.
        spatial_patch_size: pixels per patch along H and W.
        temporal_patch_size: frames per patch along T.

    Returns:
        Flattened (hpos, wpos) tensors with one entry per patch of the valid
        height x width patch grid, laid out row-major (all w for the first h,
        then the next h, ...).
    """
    # A patch counts as valid when ANY pixel inside it is valid.
    mask_thw = E.reduce(
        pixel_mask_THW,
        "(t tp) (h hp) (w wp) -> t h w",
        reduction="any",
        tp=temporal_patch_size,
        hp=spatial_patch_size,
        wp=spatial_patch_size,
    )
    # Widest row / tallest column of valid patches across all frames.
    width = E.reduce(mask_thw.sum(dim=-1).int(), "t h -> ", reduction="max")
    height = E.reduce(mask_thw.sum(dim=-2).int(), "t w -> ", reduction="max")
    # Coordinate extents chosen so xlim * ylim == 1, i.e. the position grid
    # preserves the patch grid's aspect ratio.
    xlim = torch.sqrt(width / height)
    ylim = torch.sqrt(height / width)
    xpos = torch.linspace(-xlim, xlim, int(width))
    ypos = torch.linspace(-ylim, ylim, int(height))
    wpos, hpos = torch.meshgrid(xpos, ypos, indexing="xy")
    return hpos.flatten(), wpos.flatten()
def _get_image_token_masks(tokens, config):
spatial_mask = tokens == config.img_id
no_increase_mask = (
spatial_mask
| (tokens == config.image_reg_1_token_id)
| (tokens == config.image_reg_2_token_id)
| (tokens == config.image_reg_3_token_id)
| (tokens == config.image_reg_4_token_id)
| (tokens == config.img_end_id)
)
return spatial_mask, no_increase_mask
def get_pos_thw(
    tokens: torch.Tensor,
    pixel_masks_NTHW: torch.Tensor,
    config,
    spatial_patch_size: int,
    temporal_patch_size: int = 1,
    pad_token_id: int = None,
):
    """Build temporal and spatial (h, w) positions for a batch of token ids.

    Args:
        tokens: (B, S) input ids containing the image placeholder tokens.
        pixel_masks_NTHW: (N, T, H, W) per-image pixel validity masks; N is
            the total number of images across the whole batch, in the order
            their placeholder tokens appear in *tokens*.
        config: model config providing the special image token ids.
        spatial_patch_size / temporal_patch_size: patch geometry.
        pad_token_id: required; padded positions get temporal position 0.

    Returns:
        tpos_BS: (B, S) long tensor of temporal positions; text tokens count
            up, image spatial/register/end tokens do not advance the counter.
        hw_pos_BS2: (B, S, 2) float tensor of (h, w) positions; NaN on every
            non-spatial-image token.
    """
    assert pad_token_id is not None
    assert tokens.ndim == 2
    assert pixel_masks_NTHW.ndim == 4
    spatial_img_token_mask_BS, no_increase_idx_img_token_mask_BS = _get_image_token_masks(tokens, config)
    # One (h, w) pair per spatial patch, concatenated over all images in the
    # order they appear in pixel_masks_NTHW.
    hpos_parts, wpos_parts = [], []
    for i in range(pixel_masks_NTHW.shape[0]):
        h, w = _compute_image_spatial_positions(pixel_masks_NTHW[i], spatial_patch_size, temporal_patch_size)
        hpos_parts.append(h)
        wpos_parts.append(w)
    hpos_N = torch.cat(hpos_parts) if hpos_parts else torch.empty(0)
    wpos_N = torch.cat(wpos_parts) if wpos_parts else torch.empty(0)
    # Sanity check: the token stream must contain exactly one <img> placeholder
    # per generated patch position.
    expected_tokens = spatial_img_token_mask_BS.sum().item()
    actual_tokens = hpos_N.numel()
    assert actual_tokens == expected_tokens, (
        f"Mismatch between spatial image tokens ({expected_tokens}) and generated positions ({actual_tokens})."
    )
    # NaN marks "no spatial position"; masked_scatter_ fills the image slots
    # in row-major order with the concatenated per-patch positions.
    hpos_BS = torch.full_like(tokens, fill_value=torch.nan, dtype=torch.float, device=tokens.device)
    wpos_BS = torch.full_like(tokens, fill_value=torch.nan, dtype=torch.float, device=tokens.device)
    hpos_BS = hpos_BS.masked_scatter_(spatial_img_token_mask_BS, hpos_N)
    wpos_BS = wpos_BS.masked_scatter_(spatial_img_token_mask_BS, wpos_N)
    # Temporal positions: +1 per token, except image tokens contribute 0 so
    # the whole image block shares one temporal step; pads are forced to 0.
    tpos_BS = torch.ones_like(tokens, dtype=torch.float, device=tokens.device)
    tpos_BS[no_increase_idx_img_token_mask_BS] = 0
    tpos_BS = torch.cumsum(tpos_BS, dim=1) - 1
    tpos_BS[tokens == pad_token_id] = 0
    hw_pos_BS2 = torch.stack([hpos_BS, wpos_BS], dim=-1)
    return tpos_BS.long(), hw_pos_BS2
def calculate_image_tokens(image, patch_size, merge_size):
    """Number of merged patch tokens the given image will occupy."""
    height, width = get_image_size(image)
    patch_area = patch_size * patch_size * merge_size * merge_size
    return int((height * width) / patch_area)
def tokenize_inputs(prompt, images, tokenizer, config, patch_size, merge_size, max_length):
    """Tokenize *prompt*, expanding each image placeholder into its token block.

    Each image placeholder in the prompt is replaced by the block
    ``[cls, reg1..reg4, <img> * n_patches, img_end]``. Image blocks that would
    push the sequence past *max_length* are dropped together with their image.

    Returns:
        (input_ids, selected_images): the LongTensor of token ids and the
        subset of *images* that were actually inserted.
    """
    img_reg_ids = [
        config.image_reg_1_token_id,
        config.image_reg_2_token_id,
        config.image_reg_3_token_id,
        config.image_reg_4_token_id,
    ]
    if images is not None and len(images) > 0:
        image_token_counts = [calculate_image_tokens(image, patch_size, merge_size) for image in images]
    else:
        image_token_counts = []
    image_token = tokenizer.convert_ids_to_tokens(config.img_id)
    # Split the prompt on the image placeholder; tokenize each text chunk.
    prompt_chunks = [tokenizer.encode(chunk) for chunk in prompt.split(image_token)]
    def insert_separator(X, sep):
        # Interleave chunks with separators: [x0, s0, x1, s1, ...] without the
        # trailing separator. zip() truncates to the shorter input.
        return [ele for sublist in zip(X, sep) for ele in sublist][:-1]
    input_ids = []
    offset = 0
    bos_id = getattr(tokenizer, "bos_token_id", None)
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and bos_id is not None and prompt_chunks[0][0] == bos_id:
        # Keep a single BOS at the front and skip it when re-extending chunks.
        # NOTE(review): the offset is applied to EVERY text chunk below, so
        # chunks after the first also lose their first token unless the
        # tokenizer prepends BOS on each encode() call — confirm.
        offset = 1
        input_ids.append(prompt_chunks[0][0])
    separators = []
    for count in image_token_counts:
        tokens = [config.img_id] * count
        image_block = [config.image_cls_token_id, *img_reg_ids, *tokens, config.img_end_id]
        separators.append(image_block)
    # Pad the separator list so the trailing text chunk survives the
    # zip-truncation inside insert_separator; the duplicate is cut by [:-1].
    if len(separators) != 0 and len(separators) != len(prompt_chunks):
        separators.append(separators[-1])
    selected_images = []
    if len(separators) == 0:
        # Text-only prompt: single chunk, no image blocks to weave in.
        input_ids = prompt_chunks[0]
    else:
        for index, x in enumerate(insert_separator(prompt_chunks, separators)):
            if index % 2 != 0:
                # Odd entries are image blocks: keep only if they still fit.
                if (len(input_ids) + len(x)) < max_length:
                    input_ids.extend(x)
                    selected_images.append(images[index // 2])
            elif index % 2 == 0:
                # Even entries are text chunks (offset skips the added BOS).
                input_ids.extend(x[offset:])
    input_ids = torch.LongTensor(input_ids)
    return input_ids, selected_images
def process_batch(
    tokenizer,
    config,
    image_prompt_pairs,
    max_length,
    min_dimension,
    max_dimension,
    patch_size=16,
    merge_size=1,
):
    """
    Process a batch of images with text prompts.
    Uses LEFT PADDING for proper batch generation with causal models.

    Args:
        tokenizer: tokenizer providing encode / id conversions and a
            "<|pad|>" token.
        config: model config with the special image token ids.
        image_prompt_pairs: iterable of (image_input, prompt) pairs; the image
            entry may be None for text-only samples.
        max_length: maximum token sequence length; image blocks that do not
            fit are dropped (see tokenize_inputs).
        min_dimension / max_dimension: resize bounds; max_dimension is also
            used as the padded batch H/W.
        patch_size / merge_size: vision patch geometry.

    Returns:
        dict with left-padded "tokens", batched "pixel_values" and
        "pixel_mask", positional tensors "pos_t" / "pos_hw", and the
        "pad_token_id" used for padding.
    """
    all_input_ids = []
    all_selected_images = []
    processor_local = ImageProcessor(patch_size, merge_size)
    for img_input, prompt in image_prompt_pairs:
        img = load_image(img_input)
        if img is not None:
            img = resize_image_if_necessary(img, min_dimension, max_dimension)
        images = processor_local.preprocess(images=[img] if img else [])
        input_ids, selected_images = tokenize_inputs(
            prompt, images, tokenizer, config, patch_size, merge_size, max_length,
        )
        all_input_ids.append(input_ids)
        all_selected_images.extend(selected_images)
    pad_token_id = tokenizer.convert_tokens_to_ids("<|pad|>")
    # Left padding keeps every prompt flush against the generation position.
    padded_input_ids = torch.nn.utils.rnn.pad_sequence(
        all_input_ids, batch_first=True, padding_value=pad_token_id, padding_side="left",
    )
    # NOTE(review): batch_images_with_mask returns None when no images were
    # selected, so the assert below rejects text-only batches — confirm that
    # is intended.
    processed = processor_local.batch_images_with_mask(all_selected_images, max_dimension, max_dimension)
    assert processed is not None
    pos_t, pos_hw = get_pos_thw(
        padded_input_ids, processed["padding_mask"], config, patch_size, pad_token_id=pad_token_id,
    )
    return {
        "tokens": padded_input_ids,
        "pixel_values": processed["pixel_values"],
        "pixel_mask": processed["padding_mask"],
        "pos_t": pos_t,
        "pos_hw": pos_hw,
        "pad_token_id": pad_token_id,
    }