| from pathlib import Path |
|
|
| import einops as E |
| import torch |
| import torch.nn.functional as F |
| import triton |
| import triton.language as tl |
| from PIL import Image |
| from torch import Tensor as T |
| from torch import nn |
| from torch.nn.attention.flex_attention import ( |
| AuxRequest, |
| BlockMask, |
| ) |
| from transformers import AutoTokenizer, PreTrainedModel |
|
|
| from .attention import ( |
| compiled_flex_attn_decode, |
| compiled_flex_attn_prefill, |
| create_batch_attention_mask, |
| offset_mask_mod, |
| ) |
| from .configuration_falcon_ocr import FalconOCRConfig |
| from .processing_falcon_ocr import load_image, process_batch |
| from .rope import ( |
| apply_3d_rotary_emb, |
| apply_golden_freqs_cis_to_visual_pos, |
| precompute_freqs_cis, |
| ) |
|
|
|
|
| CATEGORY_PROMPTS = { |
| "plain": "Extract the text content from this image.", |
| "formula": "Extract the formula content from this image.", |
| "table": "Extract the table content from this image.", |
| "text": "Extract the text content from this image.", |
| "caption": "Extract the caption content from this image.", |
| "footnote": "Extract the footnote content from this image.", |
| "list-item": "Extract the list-item content from this image.", |
| "page-footer": "Extract the page-footer content from this image.", |
| "page-header": "Extract the page-header content from this image.", |
| "section-header": "Extract the section-header content from this image.", |
| "title": "Extract the title content from this image.", |
| } |
|
|
| LAYOUT_TO_OCR_CATEGORY: dict[str, str | None] = { |
| "text": "text", |
| "table": "table", |
| "formula": "formula", |
| "caption": "caption", |
| "footnote": "footnote", |
| "list-item": "list-item", |
| "title": "title", |
| "header": "text", |
| "footer": "page-footer", |
| "number": "text", |
| "figure_title": "caption", |
| "paragraph_title": "section-header", |
| "doc_title": "title", |
| "reference_content": "text", |
| "reference": "text", |
| "abstract": "text", |
| "aside_text": "text", |
| "content": "text", |
| "formula_number": "text", |
| "vision_footnote": "footnote", |
| "algorithm": "text", |
| "page-footer": "page-footer", |
| "page-header": "page-header", |
| "section-header": "section-header", |
| |
| "image": None, |
| "picture": None, |
| "figure": None, |
| "chart": None, |
| "seal": None, |
| } |
|
|
| _LAYOUT_TARGET_H, _LAYOUT_TARGET_W = 800, 800 |
| _MIN_CROP_DIM = 16 |
|
|
def _box_area(bbox):
    """Return the area of an ``[x1, y1, x2, y2]`` box; inverted extents count as zero."""
    width = max(0, bbox[2] - bbox[0])
    height = max(0, bbox[3] - bbox[1])
    return width * height
|
|
|
|
def _intersection_area(a, b):
    """Return the overlap area of two ``[x1, y1, x2, y2]`` boxes (0 if disjoint)."""
    overlap_w = max(0, min(a[2], b[2]) - max(a[0], b[0]))
    overlap_h = max(0, min(a[3], b[3]) - max(a[1], b[1]))
    return overlap_w * overlap_h
|
|
|
|
def _containment_ratio(small, large):
    """Return the fraction of ``small``'s area that lies inside ``large``.

    A zero-area ``small`` box yields 0.0 instead of dividing by zero.
    """
    small_area = _box_area(small)
    return 0.0 if small_area <= 0 else _intersection_area(small, large) / small_area
|
|
|
|
def _filter_nested_detections(detections: list[dict], containment_threshold: float = 0.8) -> list[dict]:
    """Remove any box that is mostly contained within a strictly larger box.

    Args:
        detections: Dicts with a ``"bbox"`` entry in ``[x1, y1, x2, y2]`` form.
        containment_threshold: A box is dropped when more than this fraction
            of its area lies inside some other box of strictly greater area.

    Returns:
        The surviving detections, in their original order.
    """
    areas = [_box_area(d["bbox"]) for d in detections]

    def _is_nested(i: int) -> bool:
        # Only strictly larger boxes can swallow box i, so equal-area
        # duplicates never eliminate each other.
        return any(
            areas[j] > areas[i]
            and _containment_ratio(detections[i]["bbox"], other["bbox"]) > containment_threshold
            for j, other in enumerate(detections)
            if j != i
        )

    return [det for i, det in enumerate(detections) if not _is_nested(i)]
|
|
|
|
| |
|
|
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head ``n_rep`` times along the head axis (grouped-query attention).

    ``x`` has shape ``(B, S, H_kv, D)``. The result has shape
    ``(B, S, H_kv * n_rep, D)``, where output head ``h`` is input head
    ``h // n_rep``. When ``n_rep == 1``, ``x`` itself is returned.
    """
    batch, seq, kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    widened = x[:, :, :, None, :].expand(batch, seq, kv_heads, n_rep, head_dim)
    return widened.reshape(batch, seq, kv_heads * n_rep, head_dim)
|
|
|
|
class Attention(nn.Module):
    """Grouped-query self-attention with QK RMS-norm, 3D RoPE and learned attention sinks.

    Q/K/V come from a single fused projection. K/V heads are repeated to match
    the query heads. Attention itself is computed with flex attention (using
    separate compiled kernels for prefill and single-token decode) against the
    keys/values accumulated in the shared KV cache.
    """

    def __init__(self, config: FalconOCRConfig, layer_id: int):
        super().__init__()
        self.layer_id = layer_id
        n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads or n_heads
        self.n_rep = n_heads // self.n_kv_heads
        self.head_dim = config.head_dim or config.dim // n_heads
        self.q_dim = n_heads * self.head_dim
        self.kv_dim = self.n_kv_heads * self.head_dim

        # Fused projection whose output is laid out as [q | k | v].
        self.wqkv = nn.Linear(config.dim, self.q_dim + 2 * self.kv_dim, bias=False)
        self.wo = nn.Linear(self.q_dim, config.dim, bias=False)
        # One sink logit per query head (values are loaded from the checkpoint).
        self.sinks = nn.Parameter(torch.empty((n_heads,)))

    def _pre_attention_qkv(self, x) -> tuple[T, T, T]:
        """Normalise ``x``, project to Q/K/V as (b, s, h, d), QK-norm, expand KV heads."""
        # Pre-norm without a learned scale.
        fused = self.wqkv(F.rms_norm(x, (x.size(-1),)))
        q_flat, k_flat, v_flat = fused.split([self.q_dim, self.kv_dim, self.kv_dim], dim=-1)
        to_heads = "b s (h d) -> b s h d"
        q = E.rearrange(q_flat, to_heads, d=self.head_dim)
        k = E.rearrange(k_flat, to_heads, d=self.head_dim)
        v = E.rearrange(v_flat, to_heads, d=self.head_dim)
        # QK-norm: RMS-normalise every query/key head vector.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        return q, repeat_kv(k, n_rep=self.n_rep), repeat_kv(v, n_rep=self.n_rep)

    def _post_attention(self, output: T, lse: T) -> T:
        """Apply per-head sink rescaling, merge heads and project back to model dim.

        sigmoid(lse - sink) == exp(lse) / (exp(lse) + exp(sink)), i.e. the
        softmax is renormalised as if one extra "sink" logit were present.
        """
        # sinks broadcast as (1, H, 1) against lse (batch, heads, queries).
        sink_scale = torch.sigmoid(lse - self.sinks.view(1, -1, 1))
        rescaled = (output * sink_scale.unsqueeze(-1)).to(output.dtype)
        merged = rescaled.permute(0, 2, 1, 3).contiguous().flatten(2)
        return self.wo(merged)

    def compile_attention(self, *, dynamic: bool = True, mode: str = "default"):
        """Replace the pre/post attention helpers with torch.compile'd versions."""
        self._pre_attention_qkv = torch.compile(self._pre_attention_qkv, dynamic=dynamic, mode=mode)
        self._post_attention = torch.compile(self._post_attention, dynamic=dynamic, mode=mode)

    def forward(
        self, x: T, attention_masks: BlockMask, freqs_cis: T,
        freqs_cis_2d: T | None = None, pos_hw: T | None = None,
        kv_cache=None, input_pos=None, batch_idx=None,
        flex_attn_kernel_options=None,
    ):
        """Attend ``x`` against the cached context and return the projected output.

        ``flex_attn_kernel_options`` is accepted for interface compatibility
        but is not used.
        """
        q, k, v = self._pre_attention_qkv(x)
        q, k = apply_3d_rotary_emb(q, k, freqs_cis, freqs_cis_2d, pos_hw)
        # Flex attention works on (batch, heads, seq, head_dim).
        to_bhsd = "b s h d -> b h s d"
        q = E.rearrange(q, to_bhsd)
        k = E.rearrange(k, to_bhsd)
        v = E.rearrange(v, to_bhsd)
        k, v = kv_cache.insert_kv(self.layer_id, k, v, input_pos=input_pos, batch_idx=batch_idx)
        is_decode_step = q.shape[2] == 1
        attend = compiled_flex_attn_decode if is_decode_step else compiled_flex_attn_prefill
        output, aux = attend(q, k, v, block_mask=attention_masks, return_aux=AuxRequest(lse=True))
        return self._post_attention(output, aux.lse)
|
|
|
|
| |
|
|
# Elementwise ReLU(gate)**2 * up over an (n_rows, n_cols) output. For output
# column c, gate is read from packed column 2*c and up from packed column
# 2*c + 1. Each program handles BLOCK_SIZE consecutive flattened output
# positions.
@triton.jit
def _squared_relu_gate_kernel(
    packed_ptr, out_ptr, n_rows, n_cols,
    in_row_stride, in_col_stride, out_row_stride, out_col_stride,
    BLOCK_SIZE: tl.constexpr,
):
    # Index math is done in int64: rows * in_row_stride (~ 2 * n_rows * n_cols)
    # can exceed 2**31 - 1 for large batch x sequence x hidden sizes and would
    # silently wrap in int32.
    pid = tl.program_id(0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    rows = offsets // n_cols
    cols = offsets % n_cols
    # Equivalent to offsets < n_rows * n_cols (since cols < n_cols), but it
    # never forms that product in a narrower integer type.
    mask = rows < n_rows
    gate_idx = rows * in_row_stride + (2 * cols) * in_col_stride
    up_idx = rows * in_row_stride + (2 * cols + 1) * in_col_stride
    out_idx = rows * out_row_stride + cols * out_col_stride
    gate = tl.load(packed_ptr + gate_idx, mask=mask)
    up = tl.load(packed_ptr + up_idx, mask=mask)
    gate = tl.where(gate > 0, gate, 0.0)
    out = gate * gate * up
    tl.store(out_ptr + out_idx, out, mask=mask)
|
|
|
|
def squared_relu_gate(packed: T, hidden_dim: int) -> T:
    """Processes interleaved [gate, up, gate, up, ...] from w13; output = ReLU(gate)^2 * up.

    Args:
        packed: Output of the fused ``w13`` projection. Its last dimension
            holds ``hidden_dim`` interleaved (gate, up) pairs, so it must be
            ``2 * hidden_dim`` wide. Leading dimensions are arbitrary (at least one).
        hidden_dim: Number of (gate, up) pairs per row.

    Returns:
        Tensor of shape ``(*packed.shape[:-1], hidden_dim)`` with ``packed``'s
        dtype and device.

    Raises:
        ValueError: if ``packed.shape[-1] != 2 * hidden_dim``. The kernel would
            otherwise read misaligned or out-of-bounds columns.
    """
    if packed.shape[-1] != 2 * hidden_dim:
        raise ValueError(
            f"squared_relu_gate expects last dim {2 * hidden_dim} (2 * hidden_dim), "
            f"got {packed.shape[-1]}"
        )
    packed_2d = packed.flatten(0, -2)
    n_rows = packed_2d.shape[0]
    out_2d = torch.empty((n_rows, hidden_dim), device=packed.device, dtype=packed.dtype)
    n = n_rows * hidden_dim
    if n == 0:
        # Nothing to compute: skip launching a zero-sized grid.
        return out_2d.view(*packed.shape[:-1], hidden_dim)
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    _squared_relu_gate_kernel[grid](
        packed_2d, out_2d, n_rows, hidden_dim,
        packed_2d.stride(0), packed_2d.stride(1),
        out_2d.stride(0), out_2d.stride(1),
        BLOCK_SIZE=1024,
    )
    return out_2d.view(*packed.shape[:-1], hidden_dim)
|
|
|
|
class FeedForward(nn.Module):
    """Pre-norm gated MLP: ``w2(ReLU(gate)**2 * up)``.

    ``gate`` and ``up`` come from one fused ``w13`` projection whose output
    columns are interleaved (gate, up, gate, up, ...).
    """

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w13 = nn.Linear(dim, hidden_dim * 2, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.hidden_dim = hidden_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = F.rms_norm(x, (x.size(-1),))
        gated = squared_relu_gate(self.w13(normed), self.hidden_dim)
        return self.w2(gated)
|
|
|
|
| |
|
|
class TransformerBlock(nn.Module):
    """One decoder layer: ``h = x + attn(x)`` followed by ``h + ffn(h)``.

    Both sub-layers apply their own RMS normalisation to their input.
    """

    def __init__(self, layer_id: int, config: FalconOCRConfig):
        super().__init__()
        self.attention = Attention(config, layer_id)
        self.feed_forward = FeedForward(config.dim, config.ffn_dim)

    def compile(self, *, dynamic: bool = True, mode: str = "default"):
        """Compile the feed-forward module and the attention helpers; returns ``self``."""
        self.feed_forward = torch.compile(self.feed_forward, dynamic=dynamic, mode=mode)
        self.attention.compile_attention(dynamic=dynamic, mode=mode)
        return self

    def forward(
        self, x: T, freqs_cis: T, freqs_cis_2d: T | None = None,
        pos_hw: T | None = None, attention_masks=None, kv_cache=None,
        input_pos=None, batch_idx=None, flex_attn_kernel_options=None,
    ):
        batch, seq, dim = x.shape
        attn_out = self.attention(
            x,
            freqs_cis=freqs_cis,
            freqs_cis_2d=freqs_cis_2d,
            pos_hw=pos_hw,
            attention_masks=attention_masks,
            kv_cache=kv_cache,
            input_pos=input_pos,
            batch_idx=batch_idx,
            flex_attn_kernel_options=flex_attn_kernel_options,
        )
        hidden = x + attn_out
        return (hidden + self.feed_forward(hidden)).reshape(batch, seq, dim)
|
|
|
|
| |
|
|
class KVCache:
    """Preallocated key/value storage shared by all decoder layers.

    Layout: ``(num_layers, 2, max_batch_size, n_heads, max_seq_length, head_dim)``.
    Index 0 of axis 1 holds keys and index 1 holds values. The tensor is
    allocated lazily on the first ``insert_kv``, using that key's dtype and device.

    ``pos`` is the number of positions already written. It only advances once
    the last layer has inserted, so every layer of a step writes at the same offset.
    ``pos_t`` holds the most recent temporal RoPE position(s), advanced once per
    decode step.
    """

    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, num_layers):
        self.kv_shape = (num_layers, 2, max_batch_size, n_heads, max_seq_length, head_dim)
        self.kv_cache: T | None = None
        self.pos = 0
        self.pos_t: T | None = None

    def reset(self):
        """Rewind to an empty sequence; the allocated storage is kept for reuse."""
        self.pos = 0
        self.pos_t = None

    def get_pos(self):
        return self.pos

    def set_pos_t(self, pos_t):
        self.pos_t = pos_t

    def increment_and_get_pos_t(self):
        """Advance ``pos_t`` by one and return it.

        The increment is out-of-place because ``pos_t`` may be a view of a
        caller-owned position tensor, which must not be mutated.

        Raises:
            RuntimeError: if ``pos_t`` has not been set yet.
        """
        if self.pos_t is None:
            raise RuntimeError("KVCache.pos_t must be set before decoding")
        self.pos_t = self.pos_t + 1
        return self.pos_t

    def insert_kv(self, layer_id: int, k: T, v: T, **kwargs):
        """Write ``k``/``v`` (B, H, T_add, D) at the current position for ``layer_id``.

        Extra keyword arguments (e.g. ``input_pos``, ``batch_idx``) are accepted
        and ignored.

        Returns:
            ``(keys, values)`` views over everything cached so far for this layer.

        Raises:
            RuntimeError: if ``pos_t`` has not been set.
            ValueError: if the insert would exceed ``max_seq_length``.
        """
        del kwargs
        if self.pos_t is None:
            raise RuntimeError("KVCache.pos_t must be set before inserting keys/values")
        if self.kv_cache is None:
            self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
        _, _, t_add, _ = k.size()
        t0, t1 = self.pos, self.pos + t_add
        capacity = self.kv_shape[4]
        if t1 > capacity:
            raise ValueError(
                f"KVCache overflow: inserting {t_add} positions at {t0} exceeds capacity {capacity}"
            )
        self.kv_cache[layer_id, 0, :, :, t0:t1] = k
        self.kv_cache[layer_id, 1, :, :, t0:t1] = v
        key_view = self.kv_cache[layer_id, 0, :, :, :t1]
        value_view = self.kv_cache[layer_id, 1, :, :, :t1]
        # Only the last layer advances the shared write position.
        if layer_id == self.kv_cache.size(0) - 1:
            self.pos = t1
        return key_view, value_view
|
|
|
|
| |
|
|
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=0.0, top_k=None):
    """Pick the next token id from ``logits`` of shape (batch, vocab).

    Args:
        logits: Unnormalised scores, one row per sequence.
        rng: ``torch.Generator`` for sampling, or None.
        temperature: 0.0 selects greedy argmax; > 0 samples from
            ``softmax(logits / temperature)``.
        top_k: When a positive int, sample only among the ``top_k`` highest
            logits. None or <= 0 disables the restriction (a zero-sized top-k
            would otherwise crash).

    Returns:
        Long tensor of shape (batch, 1) with the chosen token ids.

    Raises:
        ValueError: if ``temperature`` is negative.
    """
    if temperature < 0.0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    if top_k is not None and top_k > 0:
        k = min(top_k, logits.size(-1))
        vals, idx = torch.topk(logits, k, dim=-1)
        probs = F.softmax(vals / temperature, dim=-1)
        choice = torch.multinomial(probs, num_samples=1, generator=rng)
        # Map positions within the top-k back to vocabulary ids.
        return idx.gather(-1, choice)
    probs = F.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=rng)
|
|
|
|
| |
|
|
class FalconOCRForCausalLM(PreTrainedModel):
    """Falcon OCR: a causal transformer decoder over interleaved image patches and text.

    Pixel patches are linearly projected (``img_projector``) into the embedding
    positions occupied by ``config.img_id`` tokens. The sequence then runs
    through ``config.n_layers`` :class:`TransformerBlock` layers that share a
    :class:`KVCache`. Public entry points are :meth:`generate` (whole-image OCR)
    and :meth:`generate_with_layout` (layout detection, then per-region OCR).
    """

    config_class = FalconOCRConfig
    _no_split_modules = ["TransformerBlock"]

    def __init__(self, config: FalconOCRConfig):
        super().__init__(config)
        # One patch flattens temporal_patch_size x spatial_patch_size**2 x channel_size values.
        img_in_dim = config.temporal_patch_size * config.spatial_patch_size ** 2 * config.channel_size
        self.img_projector = nn.Linear(img_in_dim, config.dim, bias=False)
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)

        self.layers = nn.ModuleDict()
        for layer_id in range(config.n_layers):
            self.layers[str(layer_id)] = TransformerBlock(layer_id, config)

        self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        # freqs_cis can be recomputed, so it is non-persistent (it is rebuilt in
        # _ensure_device_buffers). freqs_cis_golden is a persistent buffer created
        # with torch.empty, so its values must come from the checkpoint.
        rope_dim = config.head_dim // 2
        freqs_cis = precompute_freqs_cis(rope_dim, config.max_seq_len, config.rope_theta)
        freqs_cis_golden = torch.empty((config.n_heads, rope_dim // 2, 2), dtype=torch.float)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
        self.register_buffer("freqs_cis_golden", freqs_cis_golden, persistent=True)

        # One-shot guards: _weights_fused for _ensure_device_buffers,
        # _is_compiled for compile_model.
        self._weights_fused = False
        self._is_compiled = False

        self.post_init()

    @staticmethod
    def _ocr_prompt(instruction: str) -> str:
        """Wrap an OCR instruction in the image/task prompt template."""
        return f"<|image|>{instruction}\n<|OCR_PLAIN|>"

    def _ensure_device_buffers(self):
        """Recompute non-persistent buffers that HF meta-device loading may discard.

        Runs once: rebuilds ``freqs_cis`` on the embedding device and moves
        ``freqs_cis_golden`` there if it lives elsewhere.
        """
        if self._weights_fused:
            return
        device = self.tok_embeddings.weight.device
        c = self.config
        rope_dim = c.head_dim // 2
        freqs_cis = precompute_freqs_cis(rope_dim, c.max_seq_len, c.rope_theta).to(device)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
        if self.freqs_cis_golden.device != device:
            self.freqs_cis_golden = self.freqs_cis_golden.to(device)
        self._weights_fused = True

    def compile_model(self):
        """torch.compile every transformer block once, with dynamic shapes."""
        if self._is_compiled:
            return
        # NOTE: this flips a process-global inductor setting, not a per-model one.
        torch._inductor.config.triton.cudagraphs = False
        for layer in self.layers.values():
            layer.compile(dynamic=True, mode="default")
        self._is_compiled = True

    def _get_tokenizer(self):
        """Load and cache the tokenizer from ``config._name_or_path`` on first use.

        Every string-valued special token is also exposed as
        ``tokenizer.<name>``, and its id as ``tokenizer.<name>_id``.
        """
        if not hasattr(self, "_tokenizer"):
            import os

            path = self.config._name_or_path
            # Forbid hub access only when the path exists on disk.
            is_local = os.path.exists(path)
            self._tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=is_local, trust_remote_code=True)
            for token_name, token in self._tokenizer.special_tokens_map.items():
                if isinstance(token, str):
                    setattr(self._tokenizer, token_name, token)
                    setattr(
                        self._tokenizer, token_name + "_id",
                        self._tokenizer.convert_tokens_to_ids(token),
                    )
        return self._tokenizer

    def get_attention_mask(self, input_batch: T, max_len: int | None = None):
        """Build the attention BlockMask for a padded token batch.

        Requires ``self._pad_token_id``, which ``_generate_batch`` sets.
        """
        return create_batch_attention_mask(
            input_batch,
            pad_token_id=self._pad_token_id,
            eos_token_id=self.config.eos_id,
            soi_token_id=self.config.image_cls_token_id,
            eoi_token_id=self.config.img_end_id,
            max_len=max_len,
        )

    def _scatter_img_tokens_with_projector(self, h_BSD, pixel_patches_NLC, pixel_masks_NTHW, tokens_BS):
        """Project valid pixel patches and write them into the image-token slots of ``h_BSD``.

        A patch counts as valid if any pixel inside it is set in
        ``pixel_masks_NTHW``. The number of valid patches must equal the number
        of ``config.img_id`` tokens in ``tokens_BS``.
        """
        _, _, D = h_BSD.shape
        pixel_patch_mask = E.reduce(
            pixel_masks_NTHW,
            "n (t pt) (h ph) (w pw) -> (n t h w)",
            reduction="any",
            pt=self.config.temporal_patch_size,
            ph=self.config.spatial_patch_size,
            pw=self.config.spatial_patch_size,
        )
        pixel_patches_flat = E.rearrange(pixel_patches_NLC, "n p c -> (n p) c")
        valid_feats = self.img_projector(pixel_patches_flat[pixel_patch_mask])
        img_mask_h_BSD = E.repeat(tokens_BS == self.config.img_id, "b s -> b s d", d=D)
        assert valid_feats.numel() == img_mask_h_BSD.sum(), "image patches do not match image-token slots"
        return torch.masked_scatter(h_BSD, img_mask_h_BSD, valid_feats)

    def forward(
        self,
        tokens: T,
        attention_mask: BlockMask,
        kv_cache,
        rope_pos_t: T | None = None,
        rope_pos_hw: T | None = None,
        pixel_values: T | None = None,
        pixel_mask: T | None = None,
    ):
        """Run one prefill (S > 1) or decode (S == 1) step and return logits (B, S, V).

        Prefill requires ``rope_pos_t``/``rope_pos_hw`` and sets
        ``attention_mask.seq_lengths`` in place. Decode derives its temporal
        position from ``kv_cache`` and slices the block mask at the current
        query block.
        """
        _, S = tokens.size()
        c = self.config
        block_mask = attention_mask

        T_pos = kv_cache.get_pos()
        is_prefill = S != 1

        if is_prefill:
            assert rope_pos_t is not None and rope_pos_hw is not None
            pos_t = rope_pos_t[:, T_pos:T_pos + S].long()
            # Clone: .long() is a no-op on int64 input, so the slice would
            # otherwise alias the caller's rope_pos_t.
            kv_cache.pos_t = pos_t[:, -1:].clone()
            freqs_cis = self.freqs_cis[pos_t]
            rope_pos_hw = rope_pos_hw[:, T_pos:T_pos + S]
            freqs_cis_golden = apply_golden_freqs_cis_to_visual_pos(self.freqs_cis_golden, rope_pos_hw)
            block_mask.seq_lengths = (S, S)
        else:
            pos_t = kv_cache.increment_and_get_pos_t()
            freqs_cis = self.freqs_cis[pos_t]
            freqs_cis_golden = None
            # Select the query block containing the current position and shift
            # the mask_mod so query index 0 corresponds to absolute position T_pos.
            block_idx = T_pos // block_mask.BLOCK_SIZE[0]
            block_mask = block_mask[:, :, block_idx]
            block_mask.seq_lengths = (S, T_pos + S)
            block_mask.mask_mod = offset_mask_mod(attention_mask.mask_mod, offset=T_pos)

        h_BSD = self.tok_embeddings(tokens)

        if pixel_values is not None:
            assert pixel_mask is not None
            pixel_values = pixel_values.to(self.dtype)
            pixel_mask = pixel_mask.to(self.dtype)
            pixel_patches_NLC = E.rearrange(
                pixel_values,
                "n (t pt) (h ph) (w pw) c -> n (t h w) (pt ph pw c)",
                pt=c.temporal_patch_size, ph=c.spatial_patch_size, pw=c.spatial_patch_size,
            )
            h_BSD = self._scatter_img_tokens_with_projector(h_BSD, pixel_patches_NLC, pixel_mask, tokens)

        for layer in self.layers.values():
            h_BSD = layer(
                h_BSD, freqs_cis=freqs_cis, freqs_cis_2d=freqs_cis_golden,
                pos_hw=rope_pos_hw, attention_masks=block_mask, kv_cache=kv_cache,
            )

        h_BSD = self.norm(h_BSD)
        return self.output(h_BSD)

    def _load_layout_model(self, layout_model: str = "PaddlePaddle/PP-DocLayoutV3_safetensors"):
        """Load the layout detector and its image processor on first use.

        The loaded checkpoint id is remembered, so later calls with the same
        ``layout_model`` are no-ops and a different id triggers a reload.
        """
        if getattr(self, "_layout_model_id", None) == layout_model:
            return
        import torchvision.transforms.functional as tvF
        from transformers import AutoModelForObjectDetection, PPDocLayoutV3ImageProcessorFast

        self._layout_processor = PPDocLayoutV3ImageProcessorFast.from_pretrained(layout_model)
        # NOTE(review): assigning an nn.Module attribute registers the detector
        # as a submodule of this model (it then shows up in state_dict() and
        # parameters() and follows .to()); confirm this is intended.
        self._layout_det_model = AutoModelForObjectDetection.from_pretrained(
            layout_model, torch_dtype=torch.float16,
        ).to(self.device).eval()
        self._layout_id2label = self._layout_det_model.config.id2label
        self._tvF = tvF
        # Set last so a failed load is retried on the next call.
        self._layout_model_id = layout_model

    @torch.inference_mode()
    def _run_layout_detection(
        self, images: list[Image.Image], threshold: float = 0.5,
    ) -> list[list[dict]]:
        """Run PP-DocLayoutV3 on a batch of PIL images, return per-image detections.

        Each detection is ``{"category", "bbox": [x1, y1, x2, y2] in pixels,
        "score"}``. Only scores >= ``threshold`` are kept, ordered by the
        detector's predicted reading order.
        """
        device = self.device
        tvF = self._tvF

        # PIL .size is (width, height); reversed to (height, width).
        target_sizes = torch.tensor([img.size[::-1] for img in images])
        tensors = [tvF.pil_to_tensor(img) for img in images]

        # Resize every image to the detector's fixed input size. Images sharing
        # a shape are stacked so each group needs a single interpolate call.
        result = torch.empty(
            len(tensors), 3, _LAYOUT_TARGET_H, _LAYOUT_TARGET_W,
            dtype=torch.float16, device=device,
        )
        size_groups: dict[tuple[int, int], list[int]] = {}
        for i, t in enumerate(tensors):
            size_groups.setdefault((t.shape[1], t.shape[2]), []).append(i)

        for indices in size_groups.values():
            batch = torch.stack([tensors[i] for i in indices])
            batch = batch.to(device=device, dtype=torch.float32, non_blocking=True)
            batch = F.interpolate(
                batch, size=(_LAYOUT_TARGET_H, _LAYOUT_TARGET_W),
                mode="bicubic", align_corners=False, antialias=False,
            )
            batch = (batch.clamp_(0, 255) / 255.0).to(torch.float16)
            for j, idx in enumerate(indices):
                result[idx] = batch[j]
            del batch

        outputs = self._layout_det_model(pixel_values=result)
        del result

        logits = outputs.logits
        boxes = outputs.pred_boxes
        order_logits = outputs.order_logits

        # (cx, cy, w, h) -> (x1, y1, x2, y2), then scale by image width/height
        # (assumes pred_boxes are normalised to [0, 1]; the scaling implies it).
        box_centers, box_dims = boxes.split(2, dim=-1)
        boxes_xyxy = torch.cat([box_centers - 0.5 * box_dims, box_centers + 0.5 * box_dims], dim=-1)
        img_h, img_w = target_sizes.unbind(1)
        scale = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(device, dtype=boxes_xyxy.dtype)
        boxes_xyxy = boxes_xyxy * scale[:, None, :]

        # Rank every (query, class) pair by score and keep the best num_queries.
        num_queries = logits.shape[1]
        num_classes = logits.shape[2]
        scores = logits.sigmoid()
        scores_flat, index = scores.flatten(1).topk(num_queries, dim=-1)
        labels = index % num_classes
        box_indices = index // num_classes
        boxes_xyxy = boxes_xyxy.gather(dim=1, index=box_indices.unsqueeze(-1).expand(-1, -1, 4))

        order_seqs = self._layout_processor._get_order_seqs(order_logits)
        order_seqs = order_seqs.gather(dim=1, index=box_indices)

        batch_results = []
        for s, lab, b, o in zip(scores_flat, labels, boxes_xyxy, order_seqs):
            mask = s >= threshold
            _, indices_sorted = o[mask].sort()
            detections = []
            for si, li, bi in zip(s[mask][indices_sorted], lab[mask][indices_sorted], b[mask][indices_sorted]):
                detections.append({
                    "category": self._layout_id2label[li.item()],
                    "bbox": [round(x, 2) for x in bi.tolist()],
                    "score": round(si.item(), 4),
                })
            batch_results.append(detections)

        return batch_results

    def _generate_batch(
        self,
        image_prompt_pairs: list[tuple],
        *,
        max_new_tokens: int,
        temperature: float,
        top_k: int | None,
        min_dimension: int,
        max_dimension: int,
        seed: int | None,
    ) -> list[str]:
        """Core autoregressive decode for a list of (image, prompt) pairs.

        Each sequence gets at most ``max_new_tokens`` new tokens and stops early
        on EOS or ``<|end_of_query|>``. Returns one decoded string per pair,
        with those stop markers removed.

        Raises:
            ValueError: if prompt + ``max_new_tokens`` (rounded up to the
                128-token block size) exceeds ``config.max_seq_len``.
        """
        if not image_prompt_pairs:
            return []

        device = self.device
        tokenizer = self._get_tokenizer()
        self._pad_token_id = tokenizer.convert_tokens_to_ids("<|pad|>")
        stop_token_ids = [self.config.eos_id, tokenizer.convert_tokens_to_ids("<|end_of_query|>")]

        batch_inputs = process_batch(
            tokenizer, self.config, image_prompt_pairs,
            max_length=4096, min_dimension=min_dimension, max_dimension=max_dimension,
        )
        batch_inputs = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch_inputs.items()}

        tokens = batch_inputs["tokens"]
        B, L = tokens.size()
        # Cache/mask length: prompt + budget, rounded up to whole attention blocks.
        block_size = 128
        S = (L + max_new_tokens + block_size - 1) // block_size * block_size
        if S > self.config.max_seq_len:
            raise ValueError(
                f"prompt length {L} + max_new_tokens {max_new_tokens} needs {S} positions, "
                f"exceeding config.max_seq_len={self.config.max_seq_len}"
            )

        rng = torch.Generator(device).manual_seed(seed) if seed is not None else None

        kv_cache = KVCache(
            max_batch_size=B, max_seq_length=S, n_heads=self.config.n_heads,
            head_dim=self.config.head_dim, num_layers=self.config.n_layers,
        )

        # Generated tokens are written back into padded_tokens below. Presumably
        # the mask built from it reads these tokens; confirm in .attention.
        padded_tokens = torch.full((B, S), self._pad_token_id, dtype=tokens.dtype, device=device)
        padded_tokens[:, :L] = tokens
        attention_mask = self.get_attention_mask(padded_tokens, max_len=S)

        logits_BSV = self.forward(
            tokens=tokens, rope_pos_t=batch_inputs["pos_t"], rope_pos_hw=batch_inputs["pos_hw"],
            attention_mask=attention_mask, kv_cache=kv_cache,
            pixel_values=batch_inputs["pixel_values"], pixel_mask=batch_inputs["pixel_mask"],
        )

        stop_ids = torch.tensor(stop_token_ids, device=device)
        should_stop_B = torch.zeros(B, dtype=torch.bool, device=device)
        generated_ids: list[list[int]] = [[] for _ in range(B)]

        # Decode exactly up to the token budget (S is rounded up, so it alone
        # would allow up to block_size - 1 extra tokens).
        decode_end = min(S, L + max_new_tokens)
        for pos in range(kv_cache.get_pos(), decode_end):
            tokens_B1 = sample_next_token(logits_BSV[:, -1], rng, temperature, top_k)
            if torch.any(should_stop_B):
                tokens_B1 = tokens_B1.clone()
                tokens_B1[should_stop_B, :] = self._pad_token_id
            padded_tokens[:, pos] = tokens_B1[:, -1]

            # Copy once to the host per step rather than once per sequence.
            for ids, token, stopped in zip(generated_ids, tokens_B1[:, 0].tolist(), should_stop_B.tolist()):
                if not stopped:
                    ids.append(token)

            should_stop_B = should_stop_B.logical_or(torch.isin(tokens_B1, stop_ids).any(dim=-1))
            # Skip the forward pass whose logits would never be sampled.
            if pos + 1 >= decode_end or torch.all(should_stop_B):
                break
            # Each decode forward advances kv_cache's position by one, keeping it equal to pos + 1.
            logits_BSV = self.forward(
                tokens=tokens_B1, attention_mask=attention_mask, kv_cache=kv_cache,
            )

        results = []
        for ids in generated_ids:
            text = tokenizer.decode(ids, skip_special_tokens=False)
            text = text.replace("<|end_of_query|>", "").replace("<|end_of_text|>", "").strip()
            results.append(text)
        return results

    @torch.inference_mode()
    def generate(
        self,
        images,
        *,
        category: str | list[str] = "plain",
        max_new_tokens: int = 4096,
        temperature: float = 0.0,
        top_k: int | None = None,
        min_dimension: int = 64,
        max_dimension: int = 1024,
        compile: bool = True,
        seed: int | None = 42,
    ) -> list[str]:
        """
        Extract text from document images.

        Args:
            images: Single PIL Image (or path/URL) or list of them.
            category: OCR category — one of "plain", "text", "table", "formula",
                "caption", "footnote", "list-item", "page-footer", "page-header",
                "section-header", "title". Can be a single string (applied to all
                images) or a list (one per image). Unknown categories fall back
                to "plain".
            max_new_tokens: Maximum generation steps.
            temperature: Sampling temperature (0.0 = greedy).
            top_k: Top-k sampling (None = disabled).
            min_dimension: Min image side after resize.
            max_dimension: Max image side after resize.
            compile: Whether to torch.compile on first call.
            seed: Random seed for reproducibility (None = non-deterministic).

        Returns:
            List of extracted text strings, one per image.

        Raises:
            ValueError: if a category list does not have one entry per image.
        """
        self._ensure_device_buffers()
        if compile:
            self.compile_model()

        if isinstance(images, (str, Path, Image.Image)):
            images = [images]
        if isinstance(category, str):
            category = [category] * len(images)
        if len(images) != len(category):
            raise ValueError("Must provide one category per image")

        image_prompt_pairs = []
        for img, cat in zip(images, category):
            instruction = CATEGORY_PROMPTS.get(cat.strip().lower(), CATEGORY_PROMPTS["plain"])
            image_prompt_pairs.append((img, self._ocr_prompt(instruction)))

        return self._generate_batch(
            image_prompt_pairs,
            max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k,
            min_dimension=min_dimension, max_dimension=max_dimension, seed=seed,
        )

    @torch.inference_mode()
    def generate_with_layout(
        self,
        images,
        *,
        max_new_tokens: int = 4096,
        temperature: float = 0.0,
        top_k: int | None = None,
        min_dimension: int = 64,
        max_dimension: int = 1024,
        compile: bool = True,
        seed: int | None = 42,
        layout_threshold: float = 0.3,
        layout_batch_size: int = 4,
        ocr_batch_size: int = 32,
        containment_threshold: float = 0.8,
        layout_model: str = "PaddlePaddle/PP-DocLayoutV3_safetensors",
    ) -> list[list[dict]]:
        """
        Run layout detection then OCR on each detected region.

        Images with no detections, or with a single "image" detection, are
        OCR'd whole and reported as one "plain" region covering the page.

        Args:
            images: Single PIL Image (or path/URL) or list of them.
            max_new_tokens: Maximum generation steps per crop.
            temperature: Sampling temperature (0.0 = greedy).
            top_k: Top-k sampling (None = disabled).
            min_dimension: Min crop side after resize for OCR.
            max_dimension: Max crop side after resize for OCR.
            compile: Whether to torch.compile on first call.
            seed: Random seed for reproducibility.
            layout_threshold: Confidence threshold for layout detections.
            layout_batch_size: Batch size for layout detection.
            ocr_batch_size: Batch size for OCR generation (chunks crops).
            containment_threshold: Drop any detection whose box has more than
                this fraction of its area inside a strictly larger detection box.
            layout_model: HuggingFace model ID for layout detection.

        Returns:
            Per-image list of detections, each a dict with keys:
            ``category``, ``bbox`` [x1,y1,x2,y2], ``score``, ``text``.
        """
        self._ensure_device_buffers()
        if compile:
            self.compile_model()
        self._load_layout_model(layout_model)

        if isinstance(images, (str, Path, Image.Image)):
            images = [images]
        pil_images = [load_image(img).convert("RGB") for img in images]

        # 1) Layout detection, in batches.
        all_layout_dets: list[list[dict]] = []
        for i in range(0, len(pil_images), layout_batch_size):
            batch_imgs = pil_images[i:i + layout_batch_size]
            all_layout_dets.extend(self._run_layout_detection(batch_imgs, threshold=layout_threshold))

        # 2) Drop boxes that are mostly inside a strictly larger box.
        all_layout_dets = [
            _filter_nested_detections(dets, containment_threshold)
            for dets in all_layout_dets
        ]

        # 3) Build (crop, prompt) pairs. crop_origins[k] = (image index,
        #    detection index), where -1 marks a whole-page fallback.
        flat_crops: list[tuple[Image.Image, str]] = []
        crop_origins: list[tuple[int, int]] = []

        for img_idx, (pil_img, dets) in enumerate(zip(pil_images, all_layout_dets)):
            if not dets or (len(dets) == 1 and dets[0]["category"].strip().lower() == "image"):
                flat_crops.append((pil_img, self._ocr_prompt(CATEGORY_PROMPTS["plain"])))
                crop_origins.append((img_idx, -1))
                continue

            img_w, img_h = pil_img.size
            for det_idx, det in enumerate(dets):
                ocr_cat = LAYOUT_TO_OCR_CATEGORY.get(det["category"].strip().lower())
                if ocr_cat is None:
                    continue

                # Snap to integer pixels (end coordinates rounded) inside the image.
                x1, y1, x2, y2 = det["bbox"]
                x1 = max(0, int(x1))
                y1 = max(0, int(y1))
                x2 = min(img_w, int(x2 + 0.5))
                y2 = min(img_h, int(y2 + 0.5))
                cw, ch = x2 - x1, y2 - y1
                if cw < _MIN_CROP_DIM or ch < _MIN_CROP_DIM:
                    continue
                # Also skip crops whose short side would fall below the minimum
                # once the long side is downscaled to max_dimension.
                short_side, long_side = sorted((cw, ch))
                resized_short = short_side * (max_dimension / long_side) if long_side > max_dimension else short_side
                if resized_short < _MIN_CROP_DIM:
                    continue

                crop = pil_img.crop((x1, y1, x2, y2))
                instruction = CATEGORY_PROMPTS.get(ocr_cat, CATEGORY_PROMPTS["plain"])
                flat_crops.append((crop, self._ocr_prompt(instruction)))
                crop_origins.append((img_idx, det_idx))

        # 4) OCR all crops in chunks of ocr_batch_size.
        flat_texts: list[str] = []
        for i in range(0, len(flat_crops), ocr_batch_size):
            flat_texts.extend(self._generate_batch(
                flat_crops[i:i + ocr_batch_size],
                max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k,
                min_dimension=min_dimension, max_dimension=max_dimension, seed=seed,
            ))

        # 5) Group the OCR results back into per-image lists.
        results: list[list[dict]] = [[] for _ in range(len(pil_images))]
        for (img_idx, det_idx), text in zip(crop_origins, flat_texts):
            if det_idx == -1:
                img_w, img_h = pil_images[img_idx].size
                results[img_idx].append({
                    "category": "plain",
                    "bbox": [0, 0, img_w, img_h],
                    "score": 1.0,
                    "text": text,
                })
            else:
                det = all_layout_dets[img_idx][det_idx]
                results[img_idx].append({
                    "category": det["category"],
                    "bbox": det["bbox"],
                    "score": det["score"],
                    "text": text,
                })

        return results
|
|