Image-Text-to-Text
MLX
Safetensors
English
falcon_ocr
ocr
vision-language
falcon
apple-silicon
custom_code
Eval Results
Instructions to use mlx-community/Falcon-OCR-bf16 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use mlx-community/Falcon-OCR-bf16 with MLX:
```python
# Make sure mlx-vlm is installed
# pip install --upgrade mlx-vlm

from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

# Load the model
model, processor = load("mlx-community/Falcon-OCR-bf16")
config = load_config("mlx-community/Falcon-OCR-bf16")

# Prepare input
image = ["http://images.cocodataset.org/val2017/000000039769.jpg"]
prompt = "Describe this image."

# Apply chat template
formatted_prompt = apply_chat_template(
    processor, config, prompt, num_images=1
)

# Generate output
output = generate(model, processor, formatted_prompt, image)
print(output)
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- LM Studio
| from pathlib import Path | |
| import einops as E | |
| import torch | |
| import torch.nn.functional as F | |
| import triton | |
| import triton.language as tl | |
| from PIL import Image | |
| from torch import Tensor as T | |
| from torch import nn | |
| from torch.nn.attention.flex_attention import ( | |
| AuxRequest, | |
| BlockMask, | |
| ) | |
| from transformers import AutoTokenizer, PreTrainedModel | |
| from .attention import ( | |
| compiled_flex_attn_decode, | |
| compiled_flex_attn_prefill, | |
| create_batch_attention_mask, | |
| offset_mask_mod, | |
| ) | |
| from .configuration_falcon_ocr import FalconOCRConfig | |
| from .processing_falcon_ocr import load_image, process_batch | |
| from .rope import ( | |
| apply_3d_rotary_emb, | |
| apply_golden_freqs_cis_to_visual_pos, | |
| precompute_freqs_cis, | |
| ) | |
# Instruction sentence placed in the OCR prompt for each supported category.
# "plain" maps to the same sentence as "text"; every other category embeds
# its own name in the sentence.
CATEGORY_PROMPTS = {
    "plain": "Extract the text content from this image.",
    **{
        name: f"Extract the {name} content from this image."
        for name in (
            "formula",
            "table",
            "text",
            "caption",
            "footnote",
            "list-item",
            "page-footer",
            "page-header",
            "section-header",
            "title",
        )
    },
}
| LAYOUT_TO_OCR_CATEGORY: dict[str, str | None] = { | |
| "text": "text", | |
| "table": "table", | |
| "formula": "formula", | |
| "caption": "caption", | |
| "footnote": "footnote", | |
| "list-item": "list-item", | |
| "title": "title", | |
| "header": "text", | |
| "footer": "page-footer", | |
| "number": "text", | |
| "figure_title": "caption", | |
| "paragraph_title": "section-header", | |
| "doc_title": "title", | |
| "reference_content": "text", | |
| "reference": "text", | |
| "abstract": "text", | |
| "aside_text": "text", | |
| "content": "text", | |
| "formula_number": "text", | |
| "vision_footnote": "footnote", | |
| "algorithm": "text", | |
| "page-footer": "page-footer", | |
| "page-header": "page-header", | |
| "section-header": "section-header", | |
| # Skip — no text to extract | |
| "image": None, | |
| "picture": None, | |
| "figure": None, | |
| "chart": None, | |
| "seal": None, | |
| } | |
| _LAYOUT_TARGET_H, _LAYOUT_TARGET_W = 800, 800 | |
| _MIN_CROP_DIM = 16 | |
def _box_area(bbox):
    """Return the area of an ``[x1, y1, x2, y2]`` box.

    A negative width or height (inverted box) is clamped to zero, so the
    result is never negative.
    """
    width = max(0, bbox[2] - bbox[0])
    height = max(0, bbox[3] - bbox[1])
    return width * height
def _intersection_area(a, b):
    """Return the overlap area of two ``[x1, y1, x2, y2]`` boxes (0 if disjoint)."""
    left = max(a[0], b[0])
    top = max(a[1], b[1])
    right = min(a[2], b[2])
    bottom = min(a[3], b[3])
    return max(0, right - left) * max(0, bottom - top)
def _containment_ratio(small, large):
    """Return the fraction of ``small``'s area that lies inside ``large``.

    Both arguments are ``[x1, y1, x2, y2]`` boxes. A box with no area yields
    ``0.0`` instead of dividing by zero.
    """
    denominator = _box_area(small)
    return 0.0 if denominator <= 0 else _intersection_area(small, large) / denominator
def _filter_nested_detections(detections: list[dict], containment_threshold: float = 0.8) -> list[dict]:
    """Drop every detection that sits mostly inside a strictly larger one.

    Args:
        detections: Dicts that each carry a ``"bbox"`` entry ``[x1, y1, x2, y2]``.
        containment_threshold: A box is dropped when more than this fraction
            of its area is covered by some box of strictly greater area.

    Returns:
        The surviving detections, in their original order. Boxes of equal
        area never suppress each other.
    """
    boxes = [det["bbox"] for det in detections]
    sizes = [_box_area(box) for box in boxes]

    def _is_swallowed(idx: int) -> bool:
        # Only strictly larger boxes are allowed to absorb this one.
        for other, other_box in enumerate(boxes):
            if other == idx or sizes[other] <= sizes[idx]:
                continue
            if _containment_ratio(boxes[idx], other_box) > containment_threshold:
                return True
        return False

    return [det for idx, det in enumerate(detections) if not _is_swallowed(idx)]
| # Attention | |
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each head of a 4-D ``(B, S, H, D)`` tensor ``n_rep`` times along the head axis.

    Copies of one head are adjacent in the output, which has shape
    ``(B, S, H * n_rep, D)``. With ``n_rep == 1`` the input tensor itself is
    returned.
    """
    batch, seq, heads, dim = x.shape
    if n_rep == 1:
        return x
    widened = x[:, :, :, None, :].expand(batch, seq, heads, n_rep, dim)
    return widened.reshape(batch, seq, heads * n_rep, dim)
class Attention(nn.Module):
    """Self-attention layer with a fused QKV projection and per-head sink scaling.

    Queries use ``config.n_heads`` heads; keys/values use ``config.n_kv_heads``
    heads (falling back to ``n_heads``) and are repeated ``n_rep`` times so
    that all three tensors end up with the same number of heads.
    """

    def __init__(self, config: FalconOCRConfig, layer_id: int):
        super().__init__()
        self.layer_id = layer_id
        self.n_kv_heads = config.n_kv_heads or config.n_heads
        self.n_rep = config.n_heads // self.n_kv_heads
        self.head_dim = config.head_dim or config.dim // config.n_heads
        self.q_dim = config.n_heads * self.head_dim
        self.kv_dim = self.n_kv_heads * self.head_dim
        # One projection produces [q | k | v] concatenated on the last axis.
        self.wqkv = nn.Linear(config.dim, self.q_dim + 2 * self.kv_dim, bias=False)
        self.wo = nn.Linear(config.n_heads * self.head_dim, config.dim, bias=False)
        # One scalar per query head. Allocated uninitialised here, so the
        # values presumably come from a loaded checkpoint — TODO confirm.
        self.sinks = nn.Parameter(torch.empty((config.n_heads,)))

    def _pre_attention_qkv(self, x) -> tuple[T, T, T]:
        """Normalise ``x``, project to q/k/v and split them into heads.

        Returns ``(xq, xk, xv)`` laid out as ``(b, s, h, d)``. q and k are
        RMS-normalised per head; k and v are repeated to match q's head count.
        """
        qkv = self.wqkv(F.rms_norm(x, (x.size(-1),)))
        xq, xk, xv = qkv.split([self.q_dim, self.kv_dim, self.kv_dim], dim=-1)
        xq = E.rearrange(xq, "b s (h d) -> b s h d", d=self.head_dim)
        xk = E.rearrange(xk, "b s (h d) -> b s h d", d=self.head_dim)
        xv = E.rearrange(xv, "b s (h d) -> b s h d", d=self.head_dim)
        xq = F.rms_norm(xq, (xq.size(-1),))
        xk = F.rms_norm(xk, (xk.size(-1),))
        xk = repeat_kv(xk, n_rep=self.n_rep)
        xv = repeat_kv(xv, n_rep=self.n_rep)
        return xq, xk, xv

    def _post_attention(self, output: T, lse: T) -> T:
        """Apply sink scaling to the attention output and project back to model width."""
        # Sink-based scaling: sigmoid(lse - sinks) * output
        # equivalent to prepending a sink token to the input
        sinks_BHS = self.sinks.view(1, -1, 1)
        sink_scale = torch.sigmoid(lse - sinks_BHS)
        output = (output * sink_scale.unsqueeze(-1)).to(output.dtype)
        # (b, h, s, d) -> (b, s, h * d) before the output projection.
        output = output.permute(0, 2, 1, 3).contiguous().flatten(2)
        return self.wo(output)

    def compile_attention(self, *, dynamic: bool = True, mode: str = "default"):
        """Replace the pre/post attention helpers with ``torch.compile``-d versions."""
        self._pre_attention_qkv = torch.compile(self._pre_attention_qkv, dynamic=dynamic, mode=mode)
        self._post_attention = torch.compile(self._post_attention, dynamic=dynamic, mode=mode)

    def forward(
        self, x: T, attention_masks: BlockMask, freqs_cis: T,
        freqs_cis_2d: T | None = None, pos_hw: T | None = None,
        kv_cache=None, input_pos=None, batch_idx=None,
        flex_attn_kernel_options=None,
    ):
        """Run attention over ``x`` and return the projected output.

        Args:
            x: Input activations; the last axis is projected by ``wqkv``.
            attention_masks: Block mask handed to the flex-attention kernel.
            freqs_cis, freqs_cis_2d, pos_hw: Rotary-embedding inputs forwarded
                to ``apply_3d_rotary_emb``.
            kv_cache: Optional cache. When given, the new keys/values are
                inserted and attention runs over the views the cache returns;
                when ``None``, attention runs over the current keys/values only.
            input_pos, batch_idx: Forwarded to ``kv_cache.insert_kv``.
            flex_attn_kernel_options: Accepted for interface compatibility;
                not used by this method.
        """
        xq, xk, xv = self._pre_attention_qkv(x)
        xq, xk = apply_3d_rotary_emb(xq, xk, freqs_cis, freqs_cis_2d, pos_hw)
        xq = E.rearrange(xq, "b s h d -> b h s d")
        xk = E.rearrange(xk, "b s h d -> b h s d")
        xv = E.rearrange(xv, "b s h d -> b h s d")
        # The signature defaults kv_cache to None, so do not dereference it
        # unconditionally.
        if kv_cache is not None:
            xk, xv = kv_cache.insert_kv(self.layer_id, xk, xv, input_pos=input_pos, batch_idx=batch_idx)
        # A single query position selects the decode kernel; anything longer
        # uses the prefill kernel.
        flex_fn = compiled_flex_attn_decode if xq.shape[2] == 1 else compiled_flex_attn_prefill
        output, aux_output = flex_fn(xq, xk, xv, block_mask=attention_masks, return_aux=AuxRequest(lse=True))
        return self._post_attention(output, aux_output.lse)
| # FeedForward | |
# The kernel is launched as ``_squared_relu_gate_kernel[grid](...)``, which
# only works on a Triton JIT function, so the decorator is required.
@triton.jit
def _squared_relu_gate_kernel(
    packed_ptr, out_ptr, n_rows, n_cols,
    in_row_stride, in_col_stride, out_row_stride, out_col_stride,
    BLOCK_SIZE: tl.constexpr,
):
    """Triton kernel computing ``relu(gate)**2 * up`` from interleaved columns.

    The input row holds ``2 * n_cols`` values laid out as
    ``[gate0, up0, gate1, up1, ...]``; the output row holds ``n_cols`` values.
    Each program instance handles ``BLOCK_SIZE`` consecutive output elements of
    the flattened ``n_rows * n_cols`` output.
    """
    pid = tl.program_id(0)
    n_elements = n_rows * n_cols
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Guard the tail block, which may run past the last element.
    mask = offsets < n_elements
    rows = offsets // n_cols
    cols = offsets % n_cols
    # Even input columns are gates, odd input columns are the matching "up" values.
    gate_idx = rows * in_row_stride + (2 * cols) * in_col_stride
    up_idx = rows * in_row_stride + (2 * cols + 1) * in_col_stride
    out_idx = rows * out_row_stride + cols * out_col_stride
    gate = tl.load(packed_ptr + gate_idx, mask=mask)
    up = tl.load(packed_ptr + up_idx, mask=mask)
    gate = tl.where(gate > 0, gate, 0.0)
    out = gate * gate * up
    tl.store(out_ptr + out_idx, out, mask=mask)
def squared_relu_gate(packed: T, hidden_dim: int) -> T:
    """Compute ``ReLU(gate)^2 * up`` from an interleaved ``[gate, up, gate, up, ...]`` tensor.

    Args:
        packed: Output of ``w13``; all leading axes are flattened into rows and
            the last axis holds the interleaved gate/up pairs.
        hidden_dim: Number of gate/up pairs per row, i.e. the output width.

    Returns:
        A tensor with ``packed``'s leading shape and a last axis of ``hidden_dim``.
    """
    rows_view = packed.flatten(0, -2)
    row_count = rows_view.shape[0]
    result = torch.empty((row_count, hidden_dim), device=packed.device, dtype=packed.dtype)
    total = row_count * hidden_dim

    def launch_grid(meta):
        # One program per BLOCK_SIZE output elements, rounded up.
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    _squared_relu_gate_kernel[launch_grid](
        rows_view, result, row_count, hidden_dim,
        rows_view.stride(0), rows_view.stride(1),
        result.stride(0), result.stride(1),
        BLOCK_SIZE=1024,
    )
    return result.view(*packed.shape[:-1], hidden_dim)
class FeedForward(nn.Module):
    """Gated MLP: RMS-norm, fused gate/up projection, squared-ReLU gating, down projection."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        # w13 emits gate and up values together (2 * hidden_dim outputs).
        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.hidden_dim = hidden_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the MLP output for ``x``; the residual add is left to the caller."""
        normed = F.rms_norm(x, (x.size(-1),))
        gated = squared_relu_gate(self.w13(normed), self.hidden_dim)
        return self.w2(gated)
| # TransformerBlock | |
class TransformerBlock(nn.Module):
    """One decoder layer: attention followed by a feed-forward network, each with a residual add."""

    def __init__(self, layer_id: int, config: FalconOCRConfig):
        super().__init__()
        self.attention = Attention(config, layer_id)
        self.feed_forward = FeedForward(config.dim, config.ffn_dim)

    def compile(self, *, dynamic: bool = True, mode: str = "default"):
        """Wrap the feed-forward module and the attention helpers with ``torch.compile``; returns ``self``."""
        compile_opts = {"dynamic": dynamic, "mode": mode}
        self.feed_forward = torch.compile(self.feed_forward, **compile_opts)
        self.attention.compile_attention(**compile_opts)
        return self

    def forward(
        self, x: T, freqs_cis: T, freqs_cis_2d: T | None = None,
        pos_hw: T | None = None, attention_masks=None, kv_cache=None,
        input_pos=None, batch_idx=None, flex_attn_kernel_options=None,
    ):
        """Apply attention and feed-forward to a 3-D input and return a tensor of the same shape."""
        batch, seq, width = x.shape
        attn_out = self.attention(
            x,
            freqs_cis=freqs_cis,
            freqs_cis_2d=freqs_cis_2d,
            pos_hw=pos_hw,
            attention_masks=attention_masks,
            kv_cache=kv_cache,
            input_pos=input_pos,
            batch_idx=batch_idx,
            flex_attn_kernel_options=flex_attn_kernel_options,
        )
        hidden = x + attn_out
        hidden = hidden + self.feed_forward(hidden)
        return hidden.reshape(batch, seq, width)
| # KV Cache | |
class KVCache:
    """Pre-sized key/value store shared by all layers during generation.

    The backing tensor has shape
    ``(num_layers, 2, max_batch_size, n_heads, max_seq_length, head_dim)`` and
    is allocated lazily on the first insert, using the dtype and device of the
    first keys. ``pos`` counts the sequence positions already written and only
    advances after the last layer has inserted its keys/values.
    """

    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, num_layers):
        self.kv_shape = (num_layers, 2, max_batch_size, n_heads, max_seq_length, head_dim)
        self.kv_cache = None
        self.pos = 0
        self.pos_t: T | None = None

    def reset(self):
        """Rewind the write position and forget the temporal position tensor."""
        self.pos, self.pos_t = 0, None

    def get_pos(self):
        """Return the number of sequence positions written so far."""
        return self.pos

    def set_pos_t(self, pos_t):
        """Store the temporal position tensor used for subsequent decode steps."""
        self.pos_t = pos_t

    def increment_and_get_pos_t(self):
        """Advance the stored temporal position by one, in place, and return it."""
        assert self.pos_t is not None
        self.pos_t += 1
        return self.pos_t

    def insert_kv(self, layer_id: int, k: T, v: T, **kwargs):
        """Write ``k``/``v`` for ``layer_id`` at the current position.

        ``k`` and ``v`` are 4-D with the sequence length on axis 2. Extra
        keyword arguments are accepted and ignored.

        Returns:
            ``(keys, values)`` views covering every position written so far
            for this layer, including the ones just inserted.
        """
        del kwargs
        assert self.pos_t is not None
        if self.kv_cache is None:
            self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
        _, _, added, _ = k.size()
        start = self.pos
        stop = start + added
        layer_store = self.kv_cache[layer_id]
        layer_store[0, :, :, start:stop] = k
        layer_store[1, :, :, start:stop] = v
        # All layers write at the same offset; only the final layer moves it.
        if layer_id == self.kv_cache.size(0) - 1:
            self.pos = stop
        return layer_store[0, :, :, :stop], layer_store[1, :, :, :stop]
| # Sampling | |
def sample_next_token(logits, rng, temperature=0.0, top_k=None):
    """Pick one token id per row of ``logits``.

    Args:
        logits: Scores whose last axis is the vocabulary.
        rng: ``torch.Generator`` passed to ``torch.multinomial`` (may be None).
        temperature: ``0.0`` selects the arg-max; larger values sample from
            the temperature-scaled softmax.
        top_k: When set, sampling is restricted to the ``top_k`` highest logits.

    Returns:
        Token ids with a trailing axis of size 1.
    """
    assert temperature >= 0.0
    if temperature == 0.0:
        return logits.argmax(dim=-1, keepdim=True)
    if top_k is None:
        dist = F.softmax(logits / temperature, dim=-1)
        return torch.multinomial(dist, num_samples=1, generator=rng)
    width = min(top_k, logits.size(-1))
    top_vals, top_idx = torch.topk(logits, width, dim=-1)
    dist = F.softmax(top_vals / temperature, dim=-1)
    pick = torch.multinomial(dist, num_samples=1, generator=rng)
    # Map the position inside the top-k slice back to a vocabulary id.
    return top_idx.gather(1, pick)
| # Main Model | |
class FalconOCRForCausalLM(PreTrainedModel):
    """Causal language model that consumes text tokens and projected image patches."""

    config_class = FalconOCRConfig
    _no_split_modules = ["TransformerBlock"]

    def __init__(self, config: FalconOCRConfig):
        """Build the projector, embeddings, transformer layers, output head and rotary buffers."""
        super().__init__(config)

        # Each image patch is flattened to temporal * spatial^2 * channels
        # values before being projected to the model width.
        patch_features = config.temporal_patch_size * config.spatial_patch_size ** 2 * config.channel_size
        self.img_projector = nn.Linear(patch_features, config.dim, bias=False)
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)

        # Layers are keyed by their stringified index.
        self.layers = nn.ModuleDict(
            {str(layer_id): TransformerBlock(layer_id, config) for layer_id in range(config.n_layers)}
        )

        self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        half_head = config.head_dim // 2
        temporal_table = precompute_freqs_cis(half_head, config.max_seq_len, config.rope_theta)
        # Allocated uninitialised and registered as persistent, so it is part
        # of the state dict (presumably filled from the checkpoint — TODO confirm).
        golden_table = torch.empty((config.n_heads, half_head // 2, 2), dtype=torch.float)
        self.register_buffer("freqs_cis", temporal_table, persistent=False)
        self.register_buffer("freqs_cis_golden", golden_table, persistent=True)

        self._weights_fused = False
        self._is_compiled = False
        self.post_init()
| # Weight management | |
def _ensure_device_buffers(self):
    """Recompute non-persistent buffers that HF meta-device loading may discard.

    Runs once: the ``freqs_cis`` table is rebuilt on the device of the token
    embedding weights, ``freqs_cis_golden`` is moved there if needed, and the
    ``_weights_fused`` flag then short-circuits later calls.
    """
    if self._weights_fused:
        return
    cfg = self.config
    target = self.tok_embeddings.weight.device
    table = precompute_freqs_cis(cfg.head_dim // 2, cfg.max_seq_len, cfg.rope_theta)
    self.register_buffer("freqs_cis", table.to(target), persistent=False)
    golden = self.freqs_cis_golden
    if golden.device != target:
        self.freqs_cis_golden = golden.to(target)
    self._weights_fused = True
def compile_model(self):
    """Compile every transformer layer once; later calls are no-ops.

    Also turns off the Inductor Triton CUDA-graphs setting before compiling.
    """
    if not self._is_compiled:
        torch._inductor.config.triton.cudagraphs = False
        for block in self.layers.values():
            block.compile(dynamic=True, mode="default")
        self._is_compiled = True
| # Tokenizer | |
def _get_tokenizer(self):
    """Return the tokenizer for this checkpoint, loading and caching it on first use.

    The tokenizer is loaded from ``config._name_or_path``; when that path
    exists on disk, loading is restricted to local files. Every string-valued
    entry of ``special_tokens_map`` is also exposed on the tokenizer as
    ``<name>`` and ``<name>_id`` attributes.
    """
    if hasattr(self, "_tokenizer"):
        return self._tokenizer

    import os

    source = self.config._name_or_path
    tokenizer = AutoTokenizer.from_pretrained(
        source,
        local_files_only=os.path.exists(source),
        trust_remote_code=True,
    )
    self._tokenizer = tokenizer
    for attr_name, value in tokenizer.special_tokens_map.items():
        if not isinstance(value, str):
            continue
        setattr(tokenizer, attr_name, value)
        setattr(tokenizer, attr_name + "_id", tokenizer.convert_tokens_to_ids(value))
    return self._tokenizer
| # Attention mask | |
def get_attention_mask(self, input_batch: T, max_len: int | None = None):
    """Build the batch attention mask for ``input_batch``.

    Args:
        input_batch: Token ids for the batch.
        max_len: Optional length forwarded to ``create_batch_attention_mask``.

    Returns:
        Whatever ``create_batch_attention_mask`` returns for the given tokens
        and this model's pad / EOS / image-start / image-end token ids.
    """
    # ``_pad_token_id`` is normally set by ``_generate_batch``. Resolve it
    # here when this method is called first, so it does not fail with an
    # AttributeError.
    pad_token_id = getattr(self, "_pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = self._get_tokenizer().convert_tokens_to_ids("<|pad|>")
        self._pad_token_id = pad_token_id
    return create_batch_attention_mask(
        input_batch,
        pad_token_id=pad_token_id,
        eos_token_id=self.config.eos_id,
        soi_token_id=self.config.image_cls_token_id,
        eoi_token_id=self.config.img_end_id,
        max_len=max_len,
    )
| # Embedding helpers | |
def _scatter_img_tokens_with_projector(self, h_BSD, pixel_patches_NLC, pixel_masks_NTHW, tokens_BS):
    """Overwrite the embeddings at image-token positions with projected image patches.

    A patch counts as valid when any pixel inside it is set in
    ``pixel_masks_NTHW``. Valid patches are projected with ``img_projector``
    and written, in order, into the positions of ``h_BSD`` where
    ``tokens_BS`` equals ``config.img_id``. The number of projected values
    must match the number of slots being filled.
    """
    cfg = self.config
    _, _, width = h_BSD.shape

    # Collapse the pixel mask to one flag per patch.
    patch_is_valid = E.reduce(
        pixel_masks_NTHW,
        "n (t pt) (h ph) (w pw) -> (n t h w)",
        reduction="any",
        pt=cfg.temporal_patch_size,
        ph=cfg.spatial_patch_size,
        pw=cfg.spatial_patch_size,
    )
    flat_patches = E.rearrange(pixel_patches_NLC, "n p c -> (n p) c")
    projected = self.img_projector(flat_patches[patch_is_valid])

    is_image_token = tokens_BS == cfg.img_id
    slot_mask = E.repeat(is_image_token, "b s -> b s d", d=width)
    assert projected.numel() == slot_mask.sum()
    return torch.masked_scatter(h_BSD, slot_mask, projected)
| # Core forward | |
def forward(
    self,
    tokens: T,
    attention_mask: BlockMask,
    kv_cache,
    rope_pos_t: T | None = None,
    rope_pos_hw: T | None = None,
    pixel_values: T | None = None,
    pixel_mask: T | None = None,
):
    """Run the model for one prefill or decode step and return logits.

    A call with more than one token per sequence is treated as prefill and
    requires ``rope_pos_t`` and ``rope_pos_hw``; a call with exactly one token
    is treated as a decode step that continues from the position stored in
    ``kv_cache``.

    Args:
        tokens: 2-D token ids ``(batch, seq)``.
        attention_mask: Block mask covering the full padded sequence. Its
            ``seq_lengths`` attribute is overwritten during prefill.
        kv_cache: Cache exposing ``get_pos``, ``pos_t``,
            ``increment_and_get_pos_t`` and ``insert_kv``.
        rope_pos_t, rope_pos_hw: Rotary positions, sliced to the current window.
        pixel_values, pixel_mask: Optional image pixels and validity mask;
            when given, image-token embeddings are replaced by projected patches.

    Returns:
        Logits from the output head, one vector per input position.
    """
    cfg = self.config
    _, seq_len = tokens.size()
    cache_pos = kv_cache.get_pos()

    if seq_len != 1:
        # Prefill: positions come from the caller-provided rotary tables.
        assert rope_pos_t is not None and rope_pos_hw is not None
        window = slice(cache_pos, cache_pos + seq_len)
        pos_t = rope_pos_t[:, window].long()
        # Remember the last temporal position so decode steps can continue from it.
        kv_cache.pos_t = pos_t[:, -1:]
        freqs_cis = self.freqs_cis[pos_t]
        rope_pos_hw = rope_pos_hw[:, window]
        freqs_cis_golden = apply_golden_freqs_cis_to_visual_pos(self.freqs_cis_golden, rope_pos_hw)
        block_mask = attention_mask
        block_mask.seq_lengths = (seq_len, seq_len)
    else:
        # Decode: advance the stored temporal position and use only the mask
        # block row that contains the current query position.
        pos_t = kv_cache.increment_and_get_pos_t()
        freqs_cis = self.freqs_cis[pos_t]
        freqs_cis_golden = None
        block_row = cache_pos // attention_mask.BLOCK_SIZE[0]
        block_mask = attention_mask[:, :, block_row]
        block_mask.seq_lengths = (seq_len, cache_pos + seq_len)
        block_mask.mask_mod = offset_mask_mod(attention_mask.mask_mod, offset=cache_pos)

    h_BSD = self.tok_embeddings(tokens)

    if pixel_values is not None:
        assert pixel_mask is not None
        pixel_values = pixel_values.to(self.dtype)
        pixel_mask = pixel_mask.to(self.dtype)
        # Cut the pixels into patches and flatten each patch to one vector.
        pixel_patches_NLC = E.rearrange(
            pixel_values,
            "n (t pt) (h ph) (w pw) c -> n (t h w) (pt ph pw c)",
            pt=cfg.temporal_patch_size,
            ph=cfg.spatial_patch_size,
            pw=cfg.spatial_patch_size,
        )
        h_BSD = self._scatter_img_tokens_with_projector(h_BSD, pixel_patches_NLC, pixel_mask, tokens)

    for layer in self.layers.values():
        h_BSD = layer(
            h_BSD,
            freqs_cis=freqs_cis,
            freqs_cis_2d=freqs_cis_golden,
            pos_hw=rope_pos_hw,
            attention_masks=block_mask,
            kv_cache=kv_cache,
        )

    return self.output(self.norm(h_BSD))
| # Layout detection | |
def _load_layout_model(self, layout_model: str = "PaddlePaddle/PP-DocLayoutV3_safetensors"):
    """Load the layout detector and its image processor, once per model id.

    The guard checks ``_layout_det_model``, the attribute this method actually
    sets, so repeated calls no longer reload the detector. The loaded model id
    is remembered, and a call with a different ``layout_model`` still triggers
    a reload.

    Args:
        layout_model: Hugging Face model id of the layout detector.
    """
    already_loaded = hasattr(self, "_layout_det_model")
    loaded_name = getattr(self, "_layout_model_name", None)
    if already_loaded and loaded_name in (None, layout_model):
        return

    # Imported lazily so the dependencies are only needed when layout
    # detection is actually used.
    import torchvision.transforms.functional as tvF
    from transformers import AutoModelForObjectDetection, PPDocLayoutV3ImageProcessorFast

    self._layout_processor = PPDocLayoutV3ImageProcessorFast.from_pretrained(layout_model)
    self._layout_det_model = AutoModelForObjectDetection.from_pretrained(
        layout_model, torch_dtype=torch.float16,
    ).to(self.device).eval()
    self._layout_id2label = self._layout_det_model.config.id2label
    self._tvF = tvF
    self._layout_model_name = layout_model
@torch.no_grad()
def _run_layout_detection(
    self, images: list[Image.Image], threshold: float = 0.5,
) -> list[list[dict]]:
    """Run PP-DocLayoutV3 on a batch of PIL images, return per-image detections.

    Gradient tracking is disabled because this is inference only.

    Args:
        images: PIL images; each is converted with ``pil_to_tensor`` and
            resized to the fixed layout-detector input size.
        threshold: Minimum score for a detection to be kept.

    Returns:
        One list per input image. Each detection is a dict with ``category``
        (label string), ``bbox`` (``[x1, y1, x2, y2]`` in original-image
        pixels, rounded to 2 decimals) and ``score`` (rounded to 4 decimals),
        sorted by the detector's reading-order value. An empty input yields
        an empty list.
    """
    if not images:
        return []

    device = self.device
    tvF = self._tvF
    # PIL reports (width, height); the model post-processing wants (height, width).
    target_sizes = torch.tensor([img.size[::-1] for img in images])
    tensors = [tvF.pil_to_tensor(img) for img in images]

    # GPU-accelerated resize + normalize
    result = torch.empty(
        len(tensors), 3, _LAYOUT_TARGET_H, _LAYOUT_TARGET_W,
        dtype=torch.float16, device=device,
    )
    # Images of identical size are stacked so they can be resized in one call.
    size_groups: dict[tuple[int, int], list[int]] = {}
    for i, t in enumerate(tensors):
        size_groups.setdefault((t.shape[1], t.shape[2]), []).append(i)
    for indices in size_groups.values():
        batch = torch.stack([tensors[i] for i in indices])
        batch = batch.to(device=device, dtype=torch.float32, non_blocking=True)
        batch = F.interpolate(
            batch, size=(_LAYOUT_TARGET_H, _LAYOUT_TARGET_W),
            mode="bicubic", align_corners=False, antialias=False,
        )
        # Bicubic interpolation can overshoot [0, 255]; clamp before scaling to [0, 1].
        batch = (batch.clamp_(0, 255) / 255.0).to(torch.float16)
        for j, idx in enumerate(indices):
            result[idx] = batch[j]
        del batch

    outputs = self._layout_det_model(pixel_values=result)
    del result

    # Postprocess on GPU
    logits = outputs.logits
    boxes = outputs.pred_boxes
    order_logits = outputs.order_logits

    # Boxes arrive as (center, size); convert to corners and scale to pixels.
    box_centers, box_dims = boxes.split(2, dim=-1)
    boxes_xyxy = torch.cat([box_centers - 0.5 * box_dims, box_centers + 0.5 * box_dims], dim=-1)
    img_h, img_w = target_sizes.unbind(1)
    scale = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(device, dtype=boxes_xyxy.dtype)
    boxes_xyxy = boxes_xyxy * scale[:, None, :]

    # Keep the num_queries best (query, class) pairs per image.
    num_queries = logits.shape[1]
    num_classes = logits.shape[2]
    scores = logits.sigmoid()
    scores_flat, index = scores.flatten(1).topk(num_queries, dim=-1)
    labels = index % num_classes
    box_indices = index // num_classes
    boxes_xyxy = boxes_xyxy.gather(dim=1, index=box_indices.unsqueeze(-1).expand(-1, -1, 4))

    order_seqs = self._layout_processor._get_order_seqs(order_logits)
    order_seqs = order_seqs.gather(dim=1, index=box_indices)

    batch_results = []
    for s, l, b, o in zip(scores_flat, labels, boxes_xyxy, order_seqs):
        mask = s >= threshold
        _, indices_sorted = o[mask].sort()
        # Transfer each image's results to the host in one go instead of
        # calling .item() per detection.
        kept_scores = s[mask][indices_sorted].tolist()
        kept_labels = l[mask][indices_sorted].tolist()
        kept_boxes = b[mask][indices_sorted].tolist()
        batch_results.append([
            {
                "category": self._layout_id2label[label],
                "bbox": [round(x, 2) for x in box],
                "score": round(score, 4),
            }
            for score, label, box in zip(kept_scores, kept_labels, kept_boxes)
        ])
    return batch_results
| # Core batch decode (shared by generate & generate_with_layout) | |
@torch.no_grad()
def _generate_batch(
    self,
    image_prompt_pairs: list[tuple],
    *,
    max_new_tokens: int,
    temperature: float,
    top_k: int | None,
    min_dimension: int,
    max_dimension: int,
    seed: int | None,
) -> list[str]:
    """Core autoregressive decode for a list of (image, prompt) pairs.

    Gradient tracking is disabled for the whole decode.

    Args:
        image_prompt_pairs: ``(image, prompt)`` tuples handed to ``process_batch``.
        max_new_tokens: Upper bound on the number of sampled tokens per sequence.
        temperature: Sampling temperature (``0.0`` = greedy).
        top_k: Top-k restriction for sampling, or None.
        min_dimension, max_dimension: Image resize bounds for ``process_batch``.
        seed: Seed for the sampling generator; None leaves sampling unseeded.

    Returns:
        One decoded string per pair, with the ``<|end_of_query|>`` and
        ``<|end_of_text|>`` markers removed and surrounding whitespace
        stripped. An empty input yields an empty list.
    """
    if not image_prompt_pairs:
        return []

    device = self.device
    tokenizer = self._get_tokenizer()
    self._pad_token_id = tokenizer.convert_tokens_to_ids("<|pad|>")
    stop_token_ids = [self.config.eos_id, tokenizer.convert_tokens_to_ids("<|end_of_query|>")]

    batch_inputs = process_batch(
        tokenizer, self.config, image_prompt_pairs,
        max_length=4096, min_dimension=min_dimension, max_dimension=max_dimension,
    )
    batch_inputs = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch_inputs.items()}

    tokens = batch_inputs["tokens"]
    B, L = tokens.size()
    # Round the total length up to a multiple of the mask block size.
    block_size = 128
    S = (L + max_new_tokens + block_size - 1) // block_size * block_size
    assert S <= self.config.max_seq_len

    rng = torch.Generator(device).manual_seed(seed) if seed is not None else None

    kv_cache = KVCache(
        max_batch_size=B, max_seq_length=S, n_heads=self.config.n_heads,
        head_dim=self.config.head_dim, num_layers=self.config.n_layers,
    )

    # The mask is built once over the full padded length and reused by every step.
    padded_tokens = torch.full((B, S), self._pad_token_id, dtype=tokens.dtype, device=device)
    padded_tokens[:, :L] = tokens
    attention_mask = self.get_attention_mask(padded_tokens, max_len=S)

    # Prefill over the whole prompt.
    logits_BSV = self.forward(
        tokens=tokens, rope_pos_t=batch_inputs["pos_t"], rope_pos_hw=batch_inputs["pos_hw"],
        attention_mask=attention_mask, kv_cache=kv_cache,
        pixel_values=batch_inputs["pixel_values"], pixel_mask=batch_inputs["pixel_mask"],
    )

    stop_ids = torch.tensor(stop_token_ids).to(device)
    should_stop_B = torch.full((B,), False, dtype=torch.bool, device=device)
    generated_ids: list[list[int]] = [[] for _ in range(B)]

    # S is rounded up to the block size, so the cache bound alone would allow
    # up to block_size - 1 tokens beyond max_new_tokens; count steps explicitly.
    steps = 0
    while (
        steps < max_new_tokens
        and not torch.all(should_stop_B)
        and (pos := kv_cache.get_pos()) < S
    ):
        tokens_B1 = sample_next_token(logits_BSV[:, -1], rng, temperature, top_k)
        if torch.any(should_stop_B):
            # Finished sequences keep being fed padding.
            tokens_B1 = tokens_B1.clone()
            tokens_B1[should_stop_B, :] = self._pad_token_id
        padded_tokens[:, pos] = tokens_B1[:, -1]
        # One host transfer per step instead of one .item() per sequence.
        step_tokens = tokens_B1[:, 0].tolist()
        step_done = should_stop_B.tolist()
        for b, (tok, done) in enumerate(zip(step_tokens, step_done)):
            if not done:
                generated_ids[b].append(tok)
        logits_BSV = self.forward(
            tokens=tokens_B1, attention_mask=attention_mask, kv_cache=kv_cache,
        )
        hit_stop_B = torch.isin(tokens_B1, stop_ids).any(dim=-1)
        should_stop_B = should_stop_B.logical_or(hit_stop_B)
        steps += 1

    results = []
    for ids in generated_ids:
        text = tokenizer.decode(ids, skip_special_tokens=False)
        text = text.replace("<|end_of_query|>", "").replace("<|end_of_text|>", "").strip()
        results.append(text)
    return results
| # Main API: generate | |
def generate(
    self,
    images,
    *,
    category: str | list[str] = "plain",
    max_new_tokens: int = 4096,
    temperature: float = 0.0,
    top_k: int | None = None,
    min_dimension: int = 64,
    max_dimension: int = 1024,
    compile: bool = True,
    seed: int | None = 42,
) -> list[str]:
    """
    Run OCR on one or more document images.

    Args:
        images: A single PIL Image, path or URL, or a list of them.
        category: OCR category, either one string used for every image or a
            list with one entry per image. Recognised values are "plain",
            "text", "table", "formula", "caption", "footnote", "list-item",
            "page-footer", "page-header", "section-header" and "title";
            matching ignores case and surrounding whitespace, and unknown
            values fall back to the "plain" instruction.
        max_new_tokens: Maximum generation steps.
        temperature: Sampling temperature (0.0 = greedy).
        top_k: Top-k sampling (None = disabled).
        min_dimension: Min image side after resize.
        max_dimension: Max image side after resize.
        compile: Whether to torch.compile on first call.
        seed: Random seed for reproducibility (None = non-deterministic).

    Returns:
        List of extracted text strings, one per image.
    """
    self._ensure_device_buffers()
    if compile:
        self.compile_model()

    image_list = [images] if isinstance(images, (str, Path, Image.Image)) else images
    categories = [category] * len(image_list) if isinstance(category, str) else category
    assert len(image_list) == len(categories), "Must provide one category per image"

    def _prompt_for(cat: str) -> str:
        instruction = CATEGORY_PROMPTS.get(cat.strip().lower(), CATEGORY_PROMPTS["plain"])
        return f"<|image|>{instruction}\n<|OCR_PLAIN|>"

    pairs = [(img, _prompt_for(cat)) for img, cat in zip(image_list, categories)]

    return self._generate_batch(
        pairs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        min_dimension=min_dimension,
        max_dimension=max_dimension,
        seed=seed,
    )
| # Main API: generate_with_layout | |
def generate_with_layout(
    self,
    images,
    *,
    max_new_tokens: int = 4096,
    temperature: float = 0.0,
    top_k: int | None = None,
    min_dimension: int = 64,
    max_dimension: int = 1024,
    compile: bool = True,
    seed: int | None = 42,
    layout_threshold: float = 0.3,
    layout_batch_size: int = 4,
    ocr_batch_size: int = 32,
    containment_threshold: float = 0.8,
    layout_model: str = "PaddlePaddle/PP-DocLayoutV3_safetensors",
) -> list[list[dict]]:
    """
    Run layout detection then OCR on each detected region.

    Args:
        images: Single PIL Image (or path/URL) or list of them.
        max_new_tokens: Maximum generation steps per crop.
        temperature: Sampling temperature (0.0 = greedy).
        top_k: Top-k sampling (None = disabled).
        min_dimension: Min crop side after resize for OCR.
        max_dimension: Max crop side after resize for OCR.
        compile: Whether to torch.compile on first call.
        seed: Random seed for reproducibility.
        layout_threshold: Confidence threshold for layout detections.
        layout_batch_size: Batch size for layout detection.
        ocr_batch_size: Batch size for OCR generation (chunks crops).
        containment_threshold: A detected box of any category is dropped when
            more than this fraction of its area lies inside a strictly larger
            detected box (e.g. an inline formula inside a text block).
        layout_model: HuggingFace model ID for layout detection.

    Returns:
        Per-image list of detections, each a dict with keys:
        ``category``, ``bbox`` [x1,y1,x2,y2], ``score``, ``text``.
        An image with no detections, or with a single "image" detection, is
        OCR'd as a whole and reported as one "plain" entry covering the full
        page. Regions whose category has no OCR mapping, or whose crop is
        too small, are omitted from the result.
    """
    self._ensure_device_buffers()
    if compile:
        self.compile_model()
    self._load_layout_model(layout_model)

    if isinstance(images, (str, Path, Image.Image)):
        images = [images]
    pil_images = [load_image(img).convert("RGB") for img in images]

    # --- Layout detection (batched) ---
    all_layout_dets: list[list[dict]] = []
    for i in range(0, len(pil_images), layout_batch_size):
        batch_imgs = pil_images[i : i + layout_batch_size]
        dets = self._run_layout_detection(batch_imgs, threshold=layout_threshold)
        all_layout_dets.extend(dets)

    # --- Filter nested boxes (e.g. inline formulas inside text) ---
    all_layout_dets = [
        _filter_nested_detections(dets, containment_threshold)
        for dets in all_layout_dets
    ]

    # --- Build crops + track origin ---
    flat_crops: list[tuple[Image.Image, str]] = []
    crop_origins: list[tuple[int, int]] = []  # (image_idx, det_idx); det_idx -1 = whole page
    for img_idx, (pil_img, dets) in enumerate(zip(pil_images, all_layout_dets)):
        # Nothing usable detected: fall back to OCR on the full page.
        if not dets or (len(dets) == 1 and dets[0]["category"].strip().lower() == "image"):
            prompt = f"<|image|>{CATEGORY_PROMPTS['plain']}\n<|OCR_PLAIN|>"
            flat_crops.append((pil_img, prompt))
            crop_origins.append((img_idx, -1))
            continue

        img_w, img_h = pil_img.size
        for det_idx, det in enumerate(dets):
            cat_key = det["category"].strip().lower()
            ocr_cat = LAYOUT_TO_OCR_CATEGORY.get(cat_key)
            if ocr_cat is None:
                # Unknown label or a category with no text to extract.
                continue

            # Clip the box to the page; floor the top-left corner, round the
            # bottom-right one.
            x1, y1, x2, y2 = det["bbox"]
            x1 = max(0, int(x1))
            y1 = max(0, int(y1))
            x2 = min(img_w, int(x2 + 0.5))
            y2 = min(img_h, int(y2 + 0.5))
            cw, ch = x2 - x1, y2 - y1
            if cw < _MIN_CROP_DIM or ch < _MIN_CROP_DIM:
                continue

            # Also skip crops whose short side would fall below the minimum
            # once the long side is scaled down to max_dimension.
            short, long = sorted((cw, ch))
            resized_short = short * (max_dimension / long) if long > max_dimension else short
            if resized_short < _MIN_CROP_DIM:
                continue

            crop = pil_img.crop((x1, y1, x2, y2))
            instruction = CATEGORY_PROMPTS.get(ocr_cat, CATEGORY_PROMPTS["plain"])
            prompt = f"<|image|>{instruction}\n<|OCR_PLAIN|>"
            flat_crops.append((crop, prompt))
            crop_origins.append((img_idx, det_idx))

    # --- OCR in chunks ---
    flat_texts: list[str] = []
    for start in range(0, len(flat_crops), ocr_batch_size):
        flat_texts.extend(
            self._generate_batch(
                flat_crops[start : start + ocr_batch_size],
                max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k,
                min_dimension=min_dimension, max_dimension=max_dimension, seed=seed,
            )
        )

    # --- Reassemble per-image results ---
    results: list[list[dict]] = [[] for _ in range(len(pil_images))]
    for (img_idx, det_idx), text in zip(crop_origins, flat_texts):
        if det_idx == -1:
            img_w, img_h = pil_images[img_idx].size
            results[img_idx].append({
                "category": "plain",
                "bbox": [0, 0, img_w, img_h],
                "score": 1.0,
                "text": text,
            })
        else:
            det = all_layout_dets[img_idx][det_idx]
            results[img_idx].append({
                "category": det["category"],
                "bbox": det["bbox"],
                "score": det["score"],
                "text": text,
            })

    return results