Image-to-Text
Transformers
Safetensors
falcon_ocr
text-generation
falcon
ocr
vision-language
document-understanding
custom_code
8-bit precision
Instructions to use beaupi/Falcon-OCR-oQ8 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use beaupi/Falcon-OCR-oQ8 with Transformers:
# Use a pipeline as a high-level helper
# Warning: Pipeline type "image-to-text" is no longer supported in transformers v5.
# You must load the model directly (see below) or downgrade to v4.x with:
# 'pip install "transformers<5.0.0"'
from transformers import pipeline
pipe = pipeline("image-to-text", model="beaupi/Falcon-OCR-oQ8", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("beaupi/Falcon-OCR-oQ8", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| import io | |
| import math | |
| import einops as E | |
| import numpy as np | |
| import requests | |
| import torch | |
| from PIL import Image | |
| from transformers.image_processing_utils import BaseImageProcessor | |
| from transformers.image_transforms import convert_to_rgb, resize | |
| from transformers.image_utils import ( | |
| ImageInput, | |
| get_image_size, | |
| infer_channel_dimension_format, | |
| to_numpy_array, | |
| valid_images, | |
| validate_preprocess_arguments, | |
| ) | |
# Default per-channel normalization statistics; ImageProcessor falls back to
# these when it is constructed without explicit image_mean / image_std.
IMAGE_MEAN: list[float] = [0.5, 0.5, 0.5]
IMAGE_STD: list[float] = [0.5, 0.5, 0.5]
def load_image(image):
    """Load one image from any of the supported input forms.

    Supported inputs:
      * ``None``: returned as ``None``.
      * ``PIL.Image.Image``: returned unchanged.
      * ``str`` starting with ``http://`` or ``https://``: downloaded with
        ``requests`` (10 second timeout) and opened from the response body.
      * ``str`` ending in ``.npy``: loaded with ``np.load`` and the array's
        buffer is opened as an image file (presumably the file stores the
        encoded image bytes -- verify against the data producer).
      * any other ``str``: passed to ``Image.open`` as a path.
      * ``bytes`` / ``bytearray`` (``np.bytes_`` is a ``bytes`` subclass):
        opened as an in-memory encoded image.
      * ``np.ndarray``: converted with ``Image.fromarray``.

    Raises:
        TypeError: If ``image`` is none of the above.
        requests.HTTPError: If the download returns an error status.
    """
    if image is None:
        return None
    if isinstance(image, Image.Image):
        return image
    if isinstance(image, str):
        if image.startswith(("http://", "https://")):
            response = requests.get(image, timeout=10)
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content))
        if image.endswith(".npy"):
            img_array = io.BytesIO(np.load(image))
            return Image.open(img_array)
        return Image.open(image)
    # Plain bytes used to fall through to the TypeError below because only
    # np.bytes_ was accepted; checking the base types covers both.
    if isinstance(image, (bytes, bytearray)):
        return Image.open(io.BytesIO(image))
    if isinstance(image, np.ndarray):
        return Image.fromarray(image)
    raise TypeError(f"Unknown image format {image}")
def load_images(images_input, min_dimension: int, max_dimension: int):
    """Load every entry of ``images_input`` and bound its size.

    Each entry is passed through ``load_image`` and then through
    ``resize_image_if_necessary(img, min_dimension, max_dimension)``.

    Entries that load as ``None`` are kept as ``None`` in the result so the
    output stays index-aligned with the input; previously such an entry
    crashed inside the resize step.

    Returns:
        A list with one element per input entry, or an empty list when
        ``images_input`` is ``None``.
    """
    images = []
    if images_input is None:
        return images
    for inp in images_input:
        img = load_image(inp)
        if img is not None:
            img = resize_image_if_necessary(img, min_dimension, max_dimension)
        images.append(img)
    return images
def resize_image_if_necessary(
    image,
    shortest_dimension=224,
    longest_dimension=896,
):
    """Resize ``image`` so its sides fit the given bounds, keeping aspect ratio.

    Args:
        image: Object exposing ``.size`` as ``(width, height)`` and a
            ``.resize((width, height))`` method (a PIL image in this file).
        shortest_dimension: Lower bound applied to the shorter side when the
            image is too small.
        longest_dimension: Upper bound that neither side may exceed.

    Returns:
        ``image`` itself when both sides already lie within
        ``[shortest_dimension, longest_dimension]``; otherwise the result of
        ``image.resize``. Each resized side is at least 1 pixel.
    """
    original_width, original_height = image.size
    aspect_ratio = original_width / original_height

    if (
        shortest_dimension <= original_width <= longest_dimension
        and shortest_dimension <= original_height <= longest_dimension
    ):
        return image

    is_vertical_image = original_width < original_height
    too_small = original_width < shortest_dimension or original_height < shortest_dimension
    # Too-small images are scaled up to the lower bound; otherwise at least
    # one side exceeds the upper bound and the image is scaled down.
    target = shortest_dimension if too_small else longest_dimension

    if is_vertical_image:
        new_width = target
        new_height = int(new_width / aspect_ratio)
    else:
        new_height = target
        new_width = int(new_height * aspect_ratio)

    # Neither side may exceed the upper bound, even after scaling up.
    if new_width > longest_dimension:
        new_width = longest_dimension
        new_height = int(new_width / aspect_ratio)
    if new_height > longest_dimension:
        new_height = longest_dimension
        new_width = int(new_height * aspect_ratio)

    # int() truncation can yield 0 for extreme aspect ratios (e.g. 10000x10
    # becomes 896x0), which is not a valid image size.
    new_width = max(1, new_width)
    new_height = max(1, new_height)

    return image.resize((new_width, new_height))
def smart_resize(
    image,
    factor: int,
    resample,
    input_data_format,
    min_pixels: int = 56 * 56,
    max_pixels: int = 14 * 14 * 4 * 1280,
):
    """Resize ``image`` so both sides are multiples of ``factor``.

    The target size starts as each side rounded to the nearest multiple of
    ``factor``. If the resulting area exceeds ``max_pixels`` the image is
    scaled down (sides floored to a multiple of ``factor``); if it is below
    ``min_pixels`` it is scaled up (sides ceiled to a multiple of ``factor``).

    Args:
        image: Array-like image understood by ``get_image_size`` / ``resize``.
        factor: Required divisor of the output height and width.
        resample: Resampling filter forwarded to ``resize``.
        input_data_format: Channel layout forwarded to ``get_image_size`` and
            ``resize``.
        min_pixels: Minimum output area in pixels.
        max_pixels: Maximum output area in pixels.

    Raises:
        ValueError: If either side is smaller than ``factor`` or the aspect
            ratio exceeds 200.
    """
    height, width = get_image_size(image, channel_dim=input_data_format)
    if height < factor or width < factor:
        raise ValueError(f"{height=} or {width=} must be larger than {factor=}")
    if max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )

    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = np.sqrt((height * width) / max_pixels)
        # Flooring can reach 0 for elongated images with a small max_pixels
        # budget; never go below one `factor`-sized cell per side.
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = np.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor

    image = resize(
        image,
        size=(h_bar, w_bar),
        resample=resample,
        input_data_format=input_data_format,
    )
    return image
class ImageProcessor(BaseImageProcessor):
    """Image preprocessor: RGB conversion, smart resize, rescale, normalize.

    Output height and width are multiples of ``patch_size * merge_size``
    whenever resizing is enabled.
    """

    def __init__(
        self,
        patch_size,
        merge_size,
        do_resize: bool = True,
        resample: Image.Resampling = Image.Resampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255,
        do_normalize: bool = True,
        image_mean: float | list[float] | None = None,
        image_std: float | list[float] | None = None,
        do_convert_rgb: bool = True,
        min_pixels: int = 56 * 56,
        max_pixels: int = 28 * 28 * 1280,
        **kwargs,
    ) -> None:
        """Store the preprocessing configuration and validate it.

        Args:
            patch_size: Multiplied with ``merge_size`` to form the resize factor.
            merge_size: Multiplied with ``patch_size`` to form the resize factor.
            do_resize: Whether ``_preprocess`` applies ``smart_resize``.
            resample: Resampling filter used when resizing.
            do_rescale: Default for whether pixel values are rescaled.
            rescale_factor: Scale applied when rescaling.
            do_normalize: Default for whether pixel values are normalized.
            image_mean: Normalization mean; ``None`` selects ``IMAGE_MEAN``.
            image_std: Normalization std; ``None`` selects ``IMAGE_STD``.
            do_convert_rgb: Whether images are converted to RGB first.
            min_pixels: Minimum area passed to ``smart_resize``.
            max_pixels: Maximum area passed to ``smart_resize``.
            **kwargs: Forwarded to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        # Compare against None rather than truthiness so a legitimate falsy
        # value (e.g. a mean of 0.0) is not silently replaced by the default.
        # The defaults are copied so instances never share the module lists.
        self.image_mean = list(IMAGE_MEAN) if image_mean is None else image_mean
        self.image_std = list(IMAGE_STD) if image_std is None else image_std
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.patch_size = patch_size
        self.merge_size = merge_size
        self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
        self.do_convert_rgb = do_convert_rgb
        validate_preprocess_arguments(
            rescale_factor=self.rescale_factor,
            do_normalize=self.do_normalize,
            image_mean=self.image_mean,
            image_std=self.image_std,
            do_resize=self.do_resize,
            size=self.size,
            resample=self.resample,
        )

    def _preprocess(self, image: ImageInput, do_rescale=None, do_normalize=None):
        """Preprocess a single image and return it as a numpy array.

        Args:
            image: The image to process.
            do_rescale: Per-call override; ``None`` uses ``self.do_rescale``.
            do_normalize: Per-call override; ``None`` uses ``self.do_normalize``.
        """
        # An explicit False must win over the instance default; the former
        # `do_rescale or self.do_rescale` made it impossible to disable a step.
        do_rescale = self.do_rescale if do_rescale is None else do_rescale
        do_normalize = self.do_normalize if do_normalize is None else do_normalize

        if self.do_convert_rgb:
            image = convert_to_rgb(image)
        image = to_numpy_array(image)
        input_data_format = infer_channel_dimension_format(image)
        if self.do_resize:
            image = smart_resize(
                image,
                factor=self.patch_size * self.merge_size,
                resample=self.resample,
                input_data_format=input_data_format,
                min_pixels=self.min_pixels,
                max_pixels=self.max_pixels,
            )
        if do_rescale:
            image = self.rescale(image, scale=self.rescale_factor, input_data_format=input_data_format)
        if do_normalize:
            image = self.normalize(
                image=image, mean=self.image_mean, std=self.image_std,
                input_data_format=input_data_format,
            )
        return image

    def preprocess(self, images: list[ImageInput] | None, do_rescale=None, do_normalize=None, **kwargs):
        """Preprocess a list of images.

        ``None`` entries are dropped. Extra keyword arguments are ignored.

        Returns:
            A list with one array per remaining image, each with a new leading
            axis of length 1; an empty list when ``images`` is ``None``.

        Raises:
            ValueError: If the remaining entries are not valid images.
        """
        del kwargs
        if images is None:
            return []
        images = [item for item in images if item is not None]
        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        pixel_values = []
        for image in images:
            processed_image = self._preprocess(image, do_rescale, do_normalize)
            # Add the leading axis that batch_images_with_mask unpacks as time.
            processed_image = processed_image[None, ...]
            pixel_values.append(processed_image)
        return pixel_values

    def batch_images_with_mask(self, pixel_values, max_image_height, max_image_width):
        """Pad preprocessed images to a common shape and build validity masks.

        Args:
            pixel_values: Iterable of numpy arrays shaped
                ``(time, height, width, 3)``; ``None`` and empty entries are
                skipped.
            max_image_height: Height every image is padded to.
            max_image_width: Width every image is padded to.

        Returns:
            ``None`` when there is nothing to batch, otherwise a dict with
            ``"pixel_values"`` (stacked padded images) and ``"padding_mask"``
            (stacked long masks, 1 over real pixels and 0 over padding).

        Raises:
            ValueError: If an image does not have 3 channels in its last axis.
        """
        if pixel_values is None:
            return None
        pixel_values = [item for item in pixel_values if item is not None and len(item) != 0]
        if len(pixel_values) == 0:
            return None
        pixel_values = [torch.from_numpy(img) for img in pixel_values]
        max_temporal = max(img.shape[0] for img in pixel_values)

        def pad_image_and_mask(img):
            time_steps, height, width, channels = img.shape
            if channels != 3:
                raise ValueError(f"Expected 3-channel RGB images, got {channels} channels.")
            # F.pad takes (before, after) pairs starting from the last axis:
            # channels, width, height, time.
            # NOTE(review): an image larger than the max dimensions gives
            # negative padding amounts -- confirm callers never pass one.
            padding = (0, 0, 0, max_image_width - width, 0, max_image_height - height, 0, max_temporal - time_steps)
            padded_image = torch.nn.functional.pad(img, padding)
            mask = torch.zeros((max_temporal, max_image_height, max_image_width), dtype=torch.long)
            mask[:time_steps, :height, :width] = 1
            return padded_image, mask

        padded_pixel_values, padding_masks = zip(*[pad_image_and_mask(img) for img in pixel_values])
        padded_pixel_values = torch.stack(list(padded_pixel_values))
        padding_masks = torch.stack(list(padding_masks))
        return {"pixel_values": padded_pixel_values, "padding_mask": padding_masks}
| # --------------------------------------------------------------------------- | |
| # Positional encoding helpers | |
| # --------------------------------------------------------------------------- | |
def _compute_image_spatial_positions(
    pixel_mask_THW: torch.Tensor,
    spatial_patch_size: int,
    temporal_patch_size: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute flattened (row, column) coordinates for one image's patches.

    The pixel mask is pooled into patches (a patch counts as present when any
    of its pixels is set). The patch-grid width and height are the largest
    per-row and per-column patch counts. Column coordinates are evenly spaced
    over ``[-sqrt(width / height), sqrt(width / height)]`` and row coordinates
    over ``[-sqrt(height / width), sqrt(height / width)]``.

    Returns:
        ``(hpos, wpos)``: two 1-D tensors of ``height * width`` elements each,
        flattened from a grid with rows along the first axis.
    """
    patch_mask = E.reduce(
        pixel_mask_THW,
        "(t tp) (h hp) (w wp) -> t h w",
        reduction="any",
        tp=temporal_patch_size,
        hp=spatial_patch_size,
        wp=spatial_patch_size,
    )

    patches_per_row = patch_mask.sum(dim=-1).int()
    patches_per_column = patch_mask.sum(dim=-2).int()
    width = E.reduce(patches_per_row, "t h -> ", reduction="max")
    height = E.reduce(patches_per_column, "t w -> ", reduction="max")

    x_extent = torch.sqrt(width / height)
    y_extent = torch.sqrt(height / width)
    x_coords = torch.linspace(-x_extent, x_extent, int(width))
    y_coords = torch.linspace(-y_extent, y_extent, int(height))

    # "xy" indexing yields grids of shape (len(y_coords), len(x_coords)).
    column_grid, row_grid = torch.meshgrid(x_coords, y_coords, indexing="xy")
    return row_grid.flatten(), column_grid.flatten()
def _get_image_token_masks(tokens, config):
    """Build the two boolean masks that describe image tokens in ``tokens``.

    Returns:
        ``(spatial_mask, no_increase_mask)``. ``spatial_mask`` is true where a
        token equals ``config.img_id``. ``no_increase_mask`` is additionally
        true for the four image register tokens and ``config.img_end_id``.
    """
    spatial_mask = tokens == config.img_id

    extra_token_ids = (
        config.image_reg_1_token_id,
        config.image_reg_2_token_id,
        config.image_reg_3_token_id,
        config.image_reg_4_token_id,
        config.img_end_id,
    )
    no_increase_mask = spatial_mask
    for token_id in extra_token_ids:
        # Out-of-place OR, so spatial_mask itself is never modified.
        no_increase_mask = no_increase_mask | (tokens == token_id)

    return spatial_mask, no_increase_mask
def get_pos_thw(
    tokens: torch.Tensor,
    pixel_masks_NTHW: torch.Tensor,
    config,
    spatial_patch_size: int,
    temporal_patch_size: int = 1,
    pad_token_id: int | None = None,
):
    """Compute temporal and spatial position ids for a padded token batch.

    Args:
        tokens: 2-D tensor of token ids.
        pixel_masks_NTHW: 4-D tensor of per-image pixel masks.
        config: Object providing the image-related token ids.
        spatial_patch_size: Patch edge length used to pool each pixel mask.
        temporal_patch_size: Temporal patch length used to pool each mask.
        pad_token_id: Id of the padding token; must be provided.

    Returns:
        ``(tpos, hw_pos)``. ``tpos`` is a long tensor shaped like ``tokens``:
        a running count that does not advance on image, register and
        image-end tokens, and is set to 0 at padding positions. ``hw_pos`` is
        a float tensor of shape ``tokens.shape + (2,)`` holding
        ``(row, column)`` coordinates at spatial image tokens and NaN
        elsewhere.

    Raises:
        AssertionError: If ``pad_token_id`` is missing, an input has the wrong
            rank, or the number of spatial image tokens does not match the
            number of positions derived from the pixel masks.
    """
    assert pad_token_id is not None
    assert tokens.ndim == 2
    assert pixel_masks_NTHW.ndim == 4

    spatial_img_token_mask_BS, no_increase_idx_img_token_mask_BS = _get_image_token_masks(tokens, config)

    hpos_parts, wpos_parts = [], []
    for pixel_mask_THW in pixel_masks_NTHW:
        h, w = _compute_image_spatial_positions(pixel_mask_THW, spatial_patch_size, temporal_patch_size)
        hpos_parts.append(h)
        wpos_parts.append(w)

    hpos_N = torch.cat(hpos_parts) if hpos_parts else torch.empty(0)
    wpos_N = torch.cat(wpos_parts) if wpos_parts else torch.empty(0)
    # The positions are created on the default device/dtype; align them with
    # the destination so the scatter below also works when `tokens` lives on
    # another device. This is a no-op when they already match.
    hpos_N = hpos_N.to(device=tokens.device, dtype=torch.float)
    wpos_N = wpos_N.to(device=tokens.device, dtype=torch.float)

    expected_tokens = spatial_img_token_mask_BS.sum().item()
    actual_tokens = hpos_N.numel()
    assert actual_tokens == expected_tokens, (
        f"Mismatch between spatial image tokens ({expected_tokens}) and generated positions ({actual_tokens})."
    )

    # NaN marks positions that are not spatial image tokens.
    hpos_BS = torch.full_like(tokens, fill_value=torch.nan, dtype=torch.float, device=tokens.device)
    wpos_BS = torch.full_like(tokens, fill_value=torch.nan, dtype=torch.float, device=tokens.device)
    hpos_BS = hpos_BS.masked_scatter_(spatial_img_token_mask_BS, hpos_N)
    wpos_BS = wpos_BS.masked_scatter_(spatial_img_token_mask_BS, wpos_N)

    # Every token advances the count by one except those in the no-increase
    # mask; the cumulative sum is shifted so counting starts at 0.
    tpos_BS = torch.ones_like(tokens, dtype=torch.float, device=tokens.device)
    tpos_BS[no_increase_idx_img_token_mask_BS] = 0
    tpos_BS = torch.cumsum(tpos_BS, dim=1) - 1
    tpos_BS[tokens == pad_token_id] = 0

    hw_pos_BS2 = torch.stack([hpos_BS, wpos_BS], dim=-1)
    return tpos_BS.long(), hw_pos_BS2
def calculate_image_tokens(image, patch_size, merge_size):
    """Return how many image tokens ``image`` occupies.

    The count is the image area divided by the area of one token cell,
    ``(patch_size * merge_size) ** 2``, truncated to an integer.
    """
    height, width = get_image_size(image)
    pixels_per_token = patch_size * patch_size * merge_size * merge_size
    return int((height * width) / pixels_per_token)
def tokenize_inputs(prompt, images, tokenizer, config, patch_size, merge_size, max_length):
    """Tokenize ``prompt`` and splice in one token block per image placeholder.

    The prompt is split on the textual form of ``config.img_id`` and each
    piece is encoded separately. Between consecutive pieces an image block
    ``[cls, reg1..reg4, img_id * n, img_end]`` is inserted, where ``n`` comes
    from ``calculate_image_tokens``. A block is skipped (and its image not
    selected) when adding it would make the sequence reach ``max_length``.
    When the first piece starts with the tokenizer's BOS id, that id is kept
    once at the front and the first token of every piece is dropped.

    Returns:
        ``(input_ids, selected_images)``: a ``torch.LongTensor`` of token ids
        and the list of images whose blocks were actually inserted.
    """
    register_ids = [
        config.image_reg_1_token_id,
        config.image_reg_2_token_id,
        config.image_reg_3_token_id,
        config.image_reg_4_token_id,
    ]

    token_counts = []
    if images is not None and len(images) > 0:
        token_counts = [calculate_image_tokens(img, patch_size, merge_size) for img in images]

    placeholder = tokenizer.convert_ids_to_tokens(config.img_id)
    prompt_chunks = [tokenizer.encode(piece) for piece in prompt.split(placeholder)]

    input_ids = []
    offset = 0
    bos_id = getattr(tokenizer, "bos_token_id", None)
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0:
        if bos_id is not None and prompt_chunks[0][0] == bos_id:
            offset = 1
            input_ids.append(prompt_chunks[0][0])

    image_blocks = [
        [config.image_cls_token_id, *register_ids, *([config.img_id] * count), config.img_end_id]
        for count in token_counts
    ]
    if len(image_blocks) != 0 and len(image_blocks) != len(prompt_chunks):
        image_blocks.append(image_blocks[-1])

    # Text-only case: the first encoded piece is the whole sequence.
    if len(image_blocks) == 0:
        return torch.LongTensor(prompt_chunks[0]), []

    selected_images = []
    pairs = list(zip(prompt_chunks, image_blocks))
    last_pair = len(pairs) - 1
    for pair_index, (chunk, block) in enumerate(pairs):
        input_ids.extend(chunk[offset:])
        if pair_index == last_pair:
            # The block paired with the final piece is never emitted.
            break
        if (len(input_ids) + len(block)) < max_length:
            input_ids.extend(block)
            selected_images.append(images[pair_index])

    return torch.LongTensor(input_ids), selected_images
def process_batch(
    tokenizer,
    config,
    image_prompt_pairs,
    max_length,
    min_dimension,
    max_dimension,
    patch_size=16,
    merge_size=1,
):
    """
    Process a batch of images with text prompts.
    Uses LEFT PADDING for proper batch generation with causal models.

    Each ``(image, prompt)`` pair is loaded, size-bounded, preprocessed and
    tokenized. Token sequences are left-padded with the ``<|pad|>`` token id
    and the selected images are padded to ``max_dimension`` on both sides.

    Returns:
        A dict with keys ``"tokens"``, ``"pixel_values"``, ``"pixel_mask"``,
        ``"pos_t"``, ``"pos_hw"`` and ``"pad_token_id"``.

    NOTE(review): the assertion below fails when no image was selected for the
    whole batch -- confirm text-only batches are not expected here.
    NOTE(review): token counts use ``patch_size * merge_size`` while positions
    are derived with ``patch_size`` only; these agree for the default
    ``merge_size=1`` -- verify for other values.
    """
    processor_local = ImageProcessor(patch_size, merge_size)

    all_input_ids = []
    all_selected_images = []
    for img_input, prompt in image_prompt_pairs:
        pil_image = load_image(img_input)
        if pil_image is not None:
            pil_image = resize_image_if_necessary(pil_image, min_dimension, max_dimension)
        raw_images = [pil_image] if pil_image else []
        pixel_arrays = processor_local.preprocess(images=raw_images)

        input_ids, selected_images = tokenize_inputs(
            prompt, pixel_arrays, tokenizer, config, patch_size, merge_size, max_length,
        )
        all_input_ids.append(input_ids)
        all_selected_images.extend(selected_images)

    pad_token_id = tokenizer.convert_tokens_to_ids("<|pad|>")
    padded_input_ids = torch.nn.utils.rnn.pad_sequence(
        all_input_ids, batch_first=True, padding_value=pad_token_id, padding_side="left",
    )

    processed = processor_local.batch_images_with_mask(all_selected_images, max_dimension, max_dimension)
    assert processed is not None
    pixel_values = processed["pixel_values"]
    pixel_mask = processed["padding_mask"]

    pos_t, pos_hw = get_pos_thw(
        padded_input_ids, pixel_mask, config, patch_size, pad_token_id=pad_token_id,
    )

    return {
        "tokens": padded_input_ids,
        "pixel_values": pixel_values,
        "pixel_mask": pixel_mask,
        "pos_t": pos_t,
        "pos_hw": pos_hw,
        "pad_token_id": pad_token_id,
    }