Spaces:
Sleeping
Sleeping
| import numpy as np | |
| from PIL import Image | |
| from typing import Optional, List | |
| import torch | |
| import torchvision.transforms as T | |
| from transformers.image_processing_base import ImageProcessingMixin | |
class TinyDocImageProcessor(ImageProcessingMixin):
    """
    Image processor for TinyDoc-VLM.

    Handles resizing, normalization, and optional tiling (splitting) of
    document images.

    Args:
        image_size: Side length, in pixels, of every square tile produced.
        mean: Per-channel mean given to ``T.Normalize``. Falls back to
            ``[0.5, 0.5, 0.5]`` when omitted or empty.
        std: Per-channel standard deviation given to ``T.Normalize``. Falls
            back to ``[0.5, 0.5, 0.5]`` when omitted or empty.
        tiling_mode: ``"none"`` always produces a single resized tile. Any
            other value (``"auto"`` by default) splits images that exceed
            ``image_size`` in either dimension into a grid of tiles.
        **kwargs: Forwarded unchanged to ``ImageProcessingMixin.__init__``.
    """

    # Upper bound on grid rows and on grid columns, to limit memory usage
    # (2 -> at most a 2x2 grid, i.e. 4 tiles plus the overview).
    MAX_TILES_PER_SIDE = 2

    # Instance attributes that hold live objects rather than configuration
    # values; they are stripped from the serialized representation.
    _RUNTIME_ONLY_ATTRS = ("transform",)

    def __init__(
        self,
        image_size: int = 384,
        mean: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
        tiling_mode: str = "auto",  # "none", "auto" (split if large)
        **kwargs,
    ):
        self.image_size = image_size
        self.mean = mean or [0.5, 0.5, 0.5]
        self.std = std or [0.5, 0.5, 0.5]
        self.tiling_mode = tiling_mode
        super().__init__(**kwargs)

        # Base torchvision transforms applied to every single tile.
        self.transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=self.mean, std=self.std),
        ])

    def to_dict(self) -> dict:
        """
        Serialize the processor configuration to a plain dictionary.

        The composed torchvision transform is a live object that JSON cannot
        encode, so it is removed here; ``__init__`` rebuilds it from
        ``mean`` and ``std`` whenever the processor is constructed.
        """
        output = super().to_dict()
        for attr_name in self._RUNTIME_ONLY_ATTRS:
            output.pop(attr_name, None)
        return output

    def preprocess(
        self,
        image: Image.Image,
        return_tensors: str = "pt"
    ) -> torch.Tensor:
        """
        Preprocesses a PIL Image into a multi-tile float tensor.

        Args:
            image: Input image; converted to RGB when it is in another mode.
            return_tensors: Currently unused; the result is always a
                ``torch.Tensor``.

        Returns:
            Tensor of shape ``(num_tiles, 3, image_size, image_size)``.
            ``num_tiles`` is 1 when no tiling happens, otherwise
            ``1 + rows * cols`` with the overview thumbnail at index 0
            followed by the grid tiles in row-major order.
        """
        # Ensure RGB
        if image.mode != "RGB":
            image = image.convert("RGB")

        w, h = image.size
        size = self.image_size

        if self.tiling_mode == "none" or (w <= size and h <= size):
            # No tiling needed: resize to image_size x image_size and return
            # a single tile with a leading tile dimension of 1.
            resized = image.resize((size, size), Image.Resampling.BILINEAR)
            return self.transform(resized).unsqueeze(0)

        # Tiling: split the high-res image into a grid of square tiles, plus
        # a downscaled overview thumbnail. The grid is capped per side to
        # prevent excessive memory usage.
        cols = min(int(np.ceil(w / size)), self.MAX_TILES_PER_SIDE)
        rows = min(int(np.ceil(h / size)), self.MAX_TILES_PER_SIDE)

        # Fit the original image onto the grid canvas, keeping its
        # proportions and padding the remainder.
        resized_full = self._resize_and_pad(image, cols * size, rows * size)

        # 1. Overview of the full image (stretched to a square, not padded).
        thumbnail = image.resize((size, size), Image.Resampling.BILINEAR)
        tiles = [self.transform(thumbnail)]

        # 2. Grid tiles, row by row.
        for r in range(rows):
            for c in range(cols):
                box = (c * size, r * size, (c + 1) * size, (r + 1) * size)
                tiles.append(self.transform(resized_full.crop(box)))

        # Stack along a new leading dimension:
        # (1 + rows * cols, 3, image_size, image_size)
        return torch.stack(tiles, dim=0)

    def _resize_and_pad(self, img: Image.Image, target_w: int, target_h: int) -> Image.Image:
        """
        Resizes and pads an image to target dimensions while maintaining
        aspect ratio.

        The image is scaled by the largest factor that keeps it inside
        ``target_w`` x ``target_h`` and pasted centered onto a white RGB
        canvas of exactly that size.

        Args:
            img: Image to fit onto the canvas.
            target_w: Canvas width in pixels.
            target_h: Canvas height in pixels.

        Returns:
            A new RGB image of size ``(target_w, target_h)``.
        """
        w, h = img.size
        ratio = min(target_w / w, target_h / h)

        # Round instead of truncating so floating-point error cannot leave
        # the limiting dimension one pixel short, and clamp to at least one
        # pixel: extreme aspect ratios would otherwise produce a zero-sized
        # dimension, which Image.resize rejects.
        new_w = max(1, min(target_w, round(w * ratio)))
        new_h = max(1, min(target_h, round(h * ratio)))
        resized = img.resize((new_w, new_h), Image.Resampling.BILINEAR)

        # White background canvas with the resized image centered on it.
        padded = Image.new("RGB", (target_w, target_h), (255, 255, 255))
        x_offset = (target_w - new_w) // 2
        y_offset = (target_h - new_h) // 2
        padded.paste(resized, (x_offset, y_offset))
        return padded