# Falcon-OCR-bf16 / modeling_falcon_ocr.py
# Uploader: geoHeil
# Commit message: Upload MLX-converted Falcon-OCR (bf16)
# Commit: 28ccccd (verified)
from pathlib import Path
import einops as E
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from PIL import Image
from torch import Tensor as T
from torch import nn
from torch.nn.attention.flex_attention import (
AuxRequest,
BlockMask,
)
from transformers import AutoTokenizer, PreTrainedModel
from .attention import (
compiled_flex_attn_decode,
compiled_flex_attn_prefill,
create_batch_attention_mask,
offset_mask_mod,
)
from .configuration_falcon_ocr import FalconOCRConfig
from .processing_falcon_ocr import load_image, process_batch
from .rope import (
apply_3d_rotary_emb,
apply_golden_freqs_cis_to_visual_pos,
precompute_freqs_cis,
)
# Instruction sentence placed in the OCR prompt for each supported OCR category.
# Callers look categories up with ``.get(..., CATEGORY_PROMPTS["plain"])``, so an
# unknown category silently falls back to the "plain" instruction.
# Note that "plain" and "text" currently share the same instruction text.
CATEGORY_PROMPTS: dict[str, str] = {
    "plain": "Extract the text content from this image.",
    "formula": "Extract the formula content from this image.",
    "table": "Extract the table content from this image.",
    "text": "Extract the text content from this image.",
    "caption": "Extract the caption content from this image.",
    "footnote": "Extract the footnote content from this image.",
    "list-item": "Extract the list-item content from this image.",
    "page-footer": "Extract the page-footer content from this image.",
    "page-header": "Extract the page-header content from this image.",
    "section-header": "Extract the section-header content from this image.",
    "title": "Extract the title content from this image.",
}
# Maps a layout-detector category label (stripped and lower-cased by the caller)
# to the key of CATEGORY_PROMPTS used to OCR that region. A value of ``None``
# marks a region that is skipped; labels missing from this table are skipped too
# because the caller uses ``.get()`` and treats ``None`` as "do not OCR".
LAYOUT_TO_OCR_CATEGORY: dict[str, str | None] = {
    "text": "text",
    "table": "table",
    "formula": "formula",
    "caption": "caption",
    "footnote": "footnote",
    "list-item": "list-item",
    "title": "title",
    "header": "text",
    "footer": "page-footer",
    "number": "text",
    "figure_title": "caption",
    "paragraph_title": "section-header",
    "doc_title": "title",
    "reference_content": "text",
    "reference": "text",
    "abstract": "text",
    "aside_text": "text",
    "content": "text",
    "formula_number": "text",
    "vision_footnote": "footnote",
    "algorithm": "text",
    "page-footer": "page-footer",
    "page-header": "page-header",
    "section-header": "section-header",
    # Skip — no text to extract
    "image": None,
    "picture": None,
    "figure": None,
    "chart": None,
    "seal": None,
}
# Spatial size (height, width) every page is resized to before layout detection.
_LAYOUT_TARGET_H, _LAYOUT_TARGET_W = 800, 800
# Crops whose width or height (before or after the max-dimension downscale)
# is below this many pixels are not sent to OCR.
_MIN_CROP_DIM = 16
def _box_area(bbox):
return max(0, bbox[2] - bbox[0]) * max(0, bbox[3] - bbox[1])
def _intersection_area(a, b):
return max(0, min(a[2], b[2]) - max(a[0], b[0])) * max(0, min(a[3], b[3]) - max(a[1], b[1]))
def _containment_ratio(small, large):
    """Return the fraction of ``small``'s area that overlaps ``large``.

    Returns 0.0 when ``small`` has no positive area, which also avoids a
    division by zero.
    """
    small_area = _box_area(small)
    if small_area > 0:
        return _intersection_area(small, large) / small_area
    return 0.0
def _filter_nested_detections(detections: list[dict], containment_threshold: float = 0.8) -> list[dict]:
    """Drop every detection that lies mostly inside a strictly larger detection.

    Args:
        detections: Dicts that each carry a ``"bbox"`` entry in
            ``[x1, y1, x2, y2]`` form.
        containment_threshold: A box is dropped when more than this fraction
            of its own area is covered by some box of strictly greater area.

    Returns:
        The surviving detections, in their original order.
    """
    areas = [_box_area(det["bbox"]) for det in detections]

    def _is_nested(idx: int) -> bool:
        # Only strictly larger boxes can swallow this one; that also excludes
        # the box itself and keeps both of two equal-area duplicates.
        box = detections[idx]["bbox"]
        return any(
            areas[other_idx] > areas[idx]
            and _containment_ratio(box, other["bbox"]) > containment_threshold
            for other_idx, other in enumerate(detections)
        )

    return [det for idx, det in enumerate(detections) if not _is_nested(idx)]
# Attention
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each head of a 4-D ``(B, S, H, D)`` tensor ``n_rep`` times along the head axis.

    Copies of one head are adjacent in the output (head 0 repeated, then
    head 1 repeated, ...). With ``n_rep == 1`` the input is returned as-is.
    """
    batch, seq_len, n_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    expanded = x.unsqueeze(3).expand(batch, seq_len, n_heads, n_rep, head_dim)
    return expanded.reshape(batch, seq_len, n_heads * n_rep, head_dim)
class Attention(nn.Module):
    """Self-attention layer with a fused QKV projection and per-head sink scaling.

    Queries and keys are RMS-normalised per head, key/value heads are repeated
    to match the number of query heads, and the attention output is rescaled
    with ``sigmoid(lse - sinks)`` before the output projection.
    """

    def __init__(self, config: FalconOCRConfig, layer_id: int):
        super().__init__()
        self.layer_id = layer_id
        n_heads = config.n_heads
        # Fall back to multi-head attention / derived head size when unset.
        self.n_kv_heads = config.n_kv_heads or n_heads
        self.n_rep = n_heads // self.n_kv_heads
        self.head_dim = config.head_dim or config.dim // n_heads
        self.q_dim = n_heads * self.head_dim
        self.kv_dim = self.n_kv_heads * self.head_dim
        self.wqkv = nn.Linear(config.dim, self.q_dim + 2 * self.kv_dim, bias=False)
        self.wo = nn.Linear(self.q_dim, config.dim, bias=False)
        # Left uninitialised here; presumably filled from the checkpoint — TODO confirm.
        self.sinks = nn.Parameter(torch.empty((n_heads,)))

    def _pre_attention_qkv(self, x) -> tuple[T, T, T]:
        """Normalise ``x``, project to Q/K/V, split into heads and expand K/V heads."""
        normed = F.rms_norm(x, (x.size(-1),))
        q_flat, k_flat, v_flat = self.wqkv(normed).split(
            [self.q_dim, self.kv_dim, self.kv_dim], dim=-1
        )
        split_heads = "b s (h d) -> b s h d"
        xq = E.rearrange(q_flat, split_heads, d=self.head_dim)
        xk = E.rearrange(k_flat, split_heads, d=self.head_dim)
        xv = E.rearrange(v_flat, split_heads, d=self.head_dim)
        # QK-norm over the head dimension (values are left un-normalised).
        xq = F.rms_norm(xq, (xq.size(-1),))
        xk = F.rms_norm(xk, (xk.size(-1),))
        return xq, repeat_kv(xk, n_rep=self.n_rep), repeat_kv(xv, n_rep=self.n_rep)

    def _post_attention(self, output: T, lse: T) -> T:
        """Apply sink scaling, merge heads and run the output projection."""
        # Sink-based scaling: sigmoid(lse - sinks) * output
        # equivalent to prepending a sink token to the input
        per_head_sinks = self.sinks.view(1, -1, 1)
        gate = torch.sigmoid(lse - per_head_sinks).unsqueeze(-1)
        scaled = (output * gate).to(output.dtype)
        merged = scaled.permute(0, 2, 1, 3).contiguous().flatten(2)
        return self.wo(merged)

    def compile_attention(self, *, dynamic: bool = True, mode: str = "default"):
        """Replace the pre-/post-attention helpers with ``torch.compile``-d versions."""
        compile_opts = {"dynamic": dynamic, "mode": mode}
        self._pre_attention_qkv = torch.compile(self._pre_attention_qkv, **compile_opts)
        self._post_attention = torch.compile(self._post_attention, **compile_opts)

    def forward(
        self, x: T, attention_masks: BlockMask, freqs_cis: T,
        freqs_cis_2d: T | None = None, pos_hw: T | None = None,
        kv_cache=None, input_pos=None, batch_idx=None,
        flex_attn_kernel_options=None,
    ):
        """Run attention over ``x`` and return the projected output.

        ``kv_cache`` defaults to ``None`` in the signature but is used
        unconditionally, so callers must supply one. ``flex_attn_kernel_options``
        is accepted but not forwarded to the attention kernels.
        """
        q, k, v = self._pre_attention_qkv(x)
        q, k = apply_3d_rotary_emb(q, k, freqs_cis, freqs_cis_2d, pos_hw)

        heads_first = "b s h d -> b h s d"
        q = E.rearrange(q, heads_first)
        k = E.rearrange(k, heads_first)
        v = E.rearrange(v, heads_first)

        # The cache stores the new keys/values and hands back the full prefix.
        k, v = kv_cache.insert_kv(self.layer_id, k, v, input_pos=input_pos, batch_idx=batch_idx)

        # A single query position selects the decode kernel; anything else is prefill.
        if q.shape[2] == 1:
            attn_fn = compiled_flex_attn_decode
        else:
            attn_fn = compiled_flex_attn_prefill
        attn_out, aux = attn_fn(q, k, v, block_mask=attention_masks, return_aux=AuxRequest(lse=True))
        return self._post_attention(attn_out, aux.lse)
# FeedForward
@triton.jit
def _squared_relu_gate_kernel(
    packed_ptr, out_ptr, n_rows, n_cols,
    in_row_stride, in_col_stride, out_row_stride, out_col_stride,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute ``relu(gate)^2 * up`` from a row-wise interleaved [gate, up, ...] input.

    Each program handles ``BLOCK_SIZE`` consecutive output elements of the
    flattened ``n_rows x n_cols`` output; output column ``c`` reads input
    columns ``2c`` (gate) and ``2c + 1`` (up).
    """
    block_start = tl.program_id(0) * BLOCK_SIZE
    flat_idx = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the tail block against reading/writing past the last element.
    in_bounds = flat_idx < n_rows * n_cols

    row = flat_idx // n_cols
    col = flat_idx % n_cols

    gate_off = row * in_row_stride + (2 * col) * in_col_stride
    up_off = row * in_row_stride + (2 * col + 1) * in_col_stride
    out_off = row * out_row_stride + col * out_col_stride

    gate = tl.load(packed_ptr + gate_off, mask=in_bounds)
    up = tl.load(packed_ptr + up_off, mask=in_bounds)

    relu_gate = tl.where(gate > 0, gate, 0.0)
    result = relu_gate * relu_gate * up
    tl.store(out_ptr + out_off, result, mask=in_bounds)
def squared_relu_gate(packed: T, hidden_dim: int) -> T:
    """Apply the squared-ReLU gate to an interleaved ``[gate, up, gate, up, ...]`` tensor.

    Args:
        packed: Output of ``w13``; its last axis holds interleaved gate/up pairs.
        hidden_dim: Number of output features (half of ``packed``'s last axis).

    Returns:
        Tensor shaped like ``packed`` with the last axis replaced by
        ``hidden_dim``, holding ``ReLU(gate)^2 * up``.
    """
    rows = packed.flatten(0, -2)  # collapse all leading axes into one row axis
    n_rows, n_cols = rows.shape[0], hidden_dim
    result = torch.empty((n_rows, n_cols), device=packed.device, dtype=packed.dtype)
    total = n_rows * n_cols

    def grid(meta):
        # One program per BLOCK_SIZE output elements.
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    _squared_relu_gate_kernel[grid](
        rows, result, n_rows, n_cols,
        rows.stride(0), rows.stride(1),
        result.stride(0), result.stride(1),
        BLOCK_SIZE=1024,
    )
    return result.view(*packed.shape[:-1], hidden_dim)
class FeedForward(nn.Module):
    """Pre-normalised gated MLP: ``w2(ReLU(gate)^2 * up)`` with gate/up fused in ``w13``."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        # w13 emits 2 * hidden_dim features that squared_relu_gate reads as
        # interleaved (gate, up) pairs.
        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.hidden_dim = hidden_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """RMS-normalise ``x`` over its last axis, then apply the gated MLP."""
        normed = F.rms_norm(x, (x.size(-1),))
        gated = squared_relu_gate(self.w13(normed), self.hidden_dim)
        return self.w2(gated)
# TransformerBlock
class TransformerBlock(nn.Module):
    """One decoder layer: residual attention followed by a residual feed-forward."""

    def __init__(self, layer_id: int, config: FalconOCRConfig):
        super().__init__()
        self.attention = Attention(config, layer_id)
        self.feed_forward = FeedForward(config.dim, config.ffn_dim)

    def compile(self, *, dynamic: bool = True, mode: str = "default"):
        """Wrap the feed-forward and the attention helpers with ``torch.compile``.

        Returns ``self`` so calls can be chained.
        """
        self.feed_forward = torch.compile(self.feed_forward, dynamic=dynamic, mode=mode)
        self.attention.compile_attention(dynamic=dynamic, mode=mode)
        return self

    def forward(
        self, x: T, freqs_cis: T, freqs_cis_2d: T | None = None,
        pos_hw: T | None = None, attention_masks=None, kv_cache=None,
        input_pos=None, batch_idx=None, flex_attn_kernel_options=None,
    ):
        """Apply attention and feed-forward with residual connections.

        ``x`` must be 3-D; the result is reshaped back to ``x``'s shape.
        """
        batch, seq_len, dim = x.shape
        attn_out = self.attention(
            x,
            attention_masks=attention_masks,
            freqs_cis=freqs_cis,
            freqs_cis_2d=freqs_cis_2d,
            pos_hw=pos_hw,
            kv_cache=kv_cache,
            input_pos=input_pos,
            batch_idx=batch_idx,
            flex_attn_kernel_options=flex_attn_kernel_options,
        )
        hidden = x + attn_out
        hidden = hidden + self.feed_forward(hidden)
        return hidden.reshape(batch, seq_len, dim)
# KV Cache
class KVCache:
    """Lazily allocated key/value cache shared by all layers.

    Storage has shape ``(num_layers, 2, max_batch_size, n_heads,
    max_seq_length, head_dim)``; index 0 of the second axis holds keys and
    index 1 holds values. ``pos`` is the number of cached positions and only
    advances once the last layer has written its entries.
    """

    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, num_layers):
        self.kv_shape = (num_layers, 2, max_batch_size, n_heads, max_seq_length, head_dim)
        self.kv_cache = None  # allocated on first insert_kv, using k's dtype/device
        self.pos = 0
        self.pos_t: T | None = None

    def reset(self):
        """Rewind the write position and forget ``pos_t`` (storage is kept)."""
        self.pos, self.pos_t = 0, None

    def get_pos(self):
        """Return the number of positions written by the last layer so far."""
        return self.pos

    def set_pos_t(self, pos_t):
        """Store the temporal-position tensor used for decode steps."""
        self.pos_t = pos_t

    def increment_and_get_pos_t(self):
        """Advance ``pos_t`` by one in place and return it."""
        assert self.pos_t is not None
        self.pos_t += 1
        return self.pos_t

    def insert_kv(self, layer_id: int, k: T, v: T, **kwargs):
        """Write ``k``/``v`` for ``layer_id`` at the current position.

        Extra keyword arguments are accepted and ignored.

        Returns:
            Views of this layer's cached keys and values covering every
            position up to and including the newly written ones.
        """
        del kwargs
        assert self.pos_t is not None
        if self.kv_cache is None:
            self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)

        _, _, n_new, _ = k.size()
        start = self.pos
        end = start + n_new

        layer_keys = self.kv_cache[layer_id, 0]
        layer_values = self.kv_cache[layer_id, 1]
        layer_keys[:, :, start:end] = k
        layer_values[:, :, start:end] = v

        # Every layer writes at the same offset; only the final layer commits it.
        is_last_layer = layer_id == self.kv_cache.size(0) - 1
        if is_last_layer:
            self.pos = end
        return layer_keys[:, :, :end], layer_values[:, :, :end]
# Sampling
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=0.0, top_k=None):
    """Pick the next token id from ``logits`` over the last axis.

    Args:
        logits: Unnormalised scores; the vocabulary is the last axis. Both
            1-D ``(V,)`` and 2-D ``(B, V)`` inputs are supported on every path.
        rng: Optional ``torch.Generator`` used for sampling (ignored when greedy).
        temperature: ``0.0`` selects greedy argmax; larger values flatten the
            distribution before sampling.
        top_k: When given, sample only among the ``top_k`` highest logits
            (clamped to the vocabulary size).

    Returns:
        Integer tensor of token ids with a trailing axis of size 1.

    Raises:
        ValueError: If ``temperature`` is negative (or NaN).
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not temperature >= 0.0:
        raise ValueError(f"temperature must be >= 0.0, got {temperature}")

    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)

    if top_k is not None:
        k = min(top_k, logits.size(-1))
        top_vals, top_idx = torch.topk(logits, k, dim=-1)
        probs = F.softmax(top_vals / temperature, dim=-1)
        choice = torch.multinomial(probs, num_samples=1, generator=rng)
        # Map positions inside the top-k slice back to vocabulary ids. Gathering
        # on the last axis (rather than axis 1) keeps 1-D logits working.
        return top_idx.gather(-1, choice)

    probs = F.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=rng)
# Main Model
class FalconOCRForCausalLM(PreTrainedModel):
    """Falcon-OCR causal language model over interleaved text tokens and image patches."""

    config_class = FalconOCRConfig
    _no_split_modules = ["TransformerBlock"]

    def __init__(self, config: FalconOCRConfig):
        """Build embeddings, transformer layers, output head and rotary buffers."""
        super().__init__(config)

        # Flattened size of one image patch fed to the projector.
        patch_dim = config.temporal_patch_size * config.spatial_patch_size ** 2 * config.channel_size
        self.img_projector = nn.Linear(patch_dim, config.dim, bias=False)
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)

        # Layers are keyed by their string index ("0", "1", ...).
        self.layers = nn.ModuleDict(
            {str(idx): TransformerBlock(idx, config) for idx in range(config.n_layers)}
        )
        self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        rope_dim = config.head_dim // 2
        # The 1-D rotary table is recomputable, so it is not saved with the weights.
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(rope_dim, config.max_seq_len, config.rope_theta),
            persistent=False,
        )
        # Persistent and left uninitialised here; presumably loaded from the
        # checkpoint — TODO confirm.
        self.register_buffer(
            "freqs_cis_golden",
            torch.empty((config.n_heads, rope_dim // 2, 2), dtype=torch.float),
            persistent=True,
        )

        self._weights_fused = False
        self._is_compiled = False
        self.post_init()
# Weight management
def _ensure_device_buffers(self):
    """Rebuild the non-persistent rotary table on the weights' device, once.

    HF meta-device loading may discard non-persistent buffers, so
    ``freqs_cis`` is recomputed and ``freqs_cis_golden`` is moved next to the
    token-embedding weights. Despite its name, ``_weights_fused`` only records
    that this has already been done.
    """
    if self._weights_fused:
        return

    cfg = self.config
    target_device = self.tok_embeddings.weight.device

    rope_table = precompute_freqs_cis(cfg.head_dim // 2, cfg.max_seq_len, cfg.rope_theta)
    self.register_buffer("freqs_cis", rope_table.to(target_device), persistent=False)

    if self.freqs_cis_golden.device != target_device:
        self.freqs_cis_golden = self.freqs_cis_golden.to(target_device)

    self._weights_fused = True
def compile_model(self):
    """Compile every transformer block once; later calls are no-ops.

    Also switches off Inductor's Triton CUDA-graph option, which is a
    process-wide setting.
    """
    if not self._is_compiled:
        torch._inductor.config.triton.cudagraphs = False
        for block in self.layers.values():
            block.compile(dynamic=True, mode="default")
        self._is_compiled = True
# Tokenizer
def _get_tokenizer(self):
    """Return the tokenizer for this checkpoint, loading and caching it on first use.

    After loading, every string-valued special token is exposed on the
    tokenizer both as ``<name>`` and as ``<name>_id``.
    """
    if hasattr(self, "_tokenizer"):
        return self._tokenizer

    import os

    source = self.config._name_or_path
    # Stay offline when the checkpoint path exists on disk.
    # NOTE(review): trust_remote_code=True executes code shipped with the checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(
        source, local_files_only=os.path.exists(source), trust_remote_code=True,
    )
    self._tokenizer = tokenizer

    for attr_name, token in tokenizer.special_tokens_map.items():
        if not isinstance(token, str):
            continue
        setattr(tokenizer, attr_name, token)
        setattr(tokenizer, attr_name + "_id", tokenizer.convert_tokens_to_ids(token))
    return tokenizer
# Attention mask
def get_attention_mask(self, input_batch: T, max_len: int | None = None):
    """Build the batch attention mask for ``input_batch``.

    Reads ``self._pad_token_id``, which is only assigned inside
    ``_generate_batch``; calling this earlier raises ``AttributeError``.
    """
    cfg = self.config
    special_ids = {
        "pad_token_id": self._pad_token_id,
        "eos_token_id": cfg.eos_id,
        "soi_token_id": cfg.image_cls_token_id,
        "eoi_token_id": cfg.img_end_id,
    }
    return create_batch_attention_mask(input_batch, max_len=max_len, **special_ids)
# Embedding helpers
def _scatter_img_tokens_with_projector(self, h_BSD, pixel_patches_NLC, pixel_masks_NTHW, tokens_BS):
    """Replace image-token embeddings in ``h_BSD`` with projected image patches.

    A patch counts as valid when any pixel inside it is set in
    ``pixel_masks_NTHW``. Valid patches are projected and written, in order,
    into the positions where ``tokens_BS`` equals the image token id.

    Returns:
        A new tensor shaped like ``h_BSD``.
    """
    cfg = self.config
    _, _, hidden_dim = h_BSD.shape

    # One boolean per patch, flattened in the same order as the patches below.
    patch_is_valid = E.reduce(
        pixel_masks_NTHW,
        "n (t pt) (h ph) (w pw) -> (n t h w)",
        reduction="any",
        pt=cfg.temporal_patch_size,
        ph=cfg.spatial_patch_size,
        pw=cfg.spatial_patch_size,
    )
    all_patches = E.rearrange(pixel_patches_NLC, "n p c -> (n p) c")
    projected = self.img_projector(all_patches[patch_is_valid])

    is_img_token = tokens_BS == cfg.img_id
    scatter_mask = E.repeat(is_img_token, "b s -> b s d", d=hidden_dim)
    # The number of projected values must match the image-token slots exactly.
    assert projected.numel() == scatter_mask.sum()
    return torch.masked_scatter(h_BSD, scatter_mask, projected)
# Core forward
def forward(
    self,
    tokens: T,
    attention_mask: BlockMask,
    kv_cache,
    rope_pos_t: T | None = None,
    rope_pos_hw: T | None = None,
    pixel_values: T | None = None,
    pixel_mask: T | None = None,
):
    """Run one prefill or single-token decode step and return the logits.

    A call with more than one token per sequence is treated as prefill and
    requires ``rope_pos_t`` and ``rope_pos_hw``; a call with exactly one token
    is a decode step that continues from ``kv_cache``.

    Side effects: ``kv_cache`` is advanced, and during prefill the
    ``seq_lengths`` attribute of the passed ``attention_mask`` is overwritten.

    Returns:
        Logits produced by the output head for every input position.
    """
    _, seq_len = tokens.size()
    cfg = self.config
    cache_pos = kv_cache.get_pos()
    mask = attention_mask

    if seq_len != 1:
        # Prefill: positions come from the precomputed rope index tensors.
        assert rope_pos_t is not None and rope_pos_hw is not None
        window = slice(cache_pos, cache_pos + seq_len)
        pos_t = rope_pos_t[:, window].long()
        # Remember the last temporal position so decode steps can count up from it.
        kv_cache.pos_t = pos_t[:, -1:]
        freqs_cis = self.freqs_cis[pos_t]
        rope_pos_hw = rope_pos_hw[:, window]
        freqs_cis_golden = apply_golden_freqs_cis_to_visual_pos(self.freqs_cis_golden, rope_pos_hw)
        mask.seq_lengths = (seq_len, seq_len)
    else:
        # Decode: advance the temporal position; no 2-D (visual) rotary part.
        pos_t = kv_cache.increment_and_get_pos_t()
        freqs_cis = self.freqs_cis[pos_t]
        freqs_cis_golden = None
        # Select the mask row-block containing the current position and shift
        # mask_mod so its query index refers to the absolute position.
        row_block = cache_pos // mask.BLOCK_SIZE[0]
        mask = mask[:, :, row_block]
        mask.seq_lengths = (seq_len, cache_pos + seq_len)
        mask.mask_mod = offset_mask_mod(attention_mask.mask_mod, offset=cache_pos)

    hidden = self.tok_embeddings(tokens)

    if pixel_values is not None:
        assert pixel_mask is not None
        pixel_values = pixel_values.to(self.dtype)
        pixel_mask = pixel_mask.to(self.dtype)
        patches = E.rearrange(
            pixel_values,
            "n (t pt) (h ph) (w pw) c -> n (t h w) (pt ph pw c)",
            pt=cfg.temporal_patch_size,
            ph=cfg.spatial_patch_size,
            pw=cfg.spatial_patch_size,
        )
        hidden = self._scatter_img_tokens_with_projector(hidden, patches, pixel_mask, tokens)

    for block in self.layers.values():
        hidden = block(
            hidden,
            freqs_cis=freqs_cis,
            freqs_cis_2d=freqs_cis_golden,
            pos_hw=rope_pos_hw,
            attention_masks=mask,
            kv_cache=kv_cache,
        )

    return self.output(self.norm(hidden))
# Layout detection
def _load_layout_model(self, layout_model: str = "PaddlePaddle/PP-DocLayoutV3_safetensors"):
    """Load the layout-detection model and its processor, caching them on ``self``.

    The previous guard tested ``hasattr(self, "_layout_model")``, an attribute
    that is never assigned, so the detector was re-downloaded/re-instantiated
    on every call. The cache is now keyed on the model ID: repeated calls with
    the same ``layout_model`` return immediately, while a different ID still
    triggers a reload.

    Args:
        layout_model: Hugging Face model ID (or path) of the layout detector.
    """
    already_loaded = (
        hasattr(self, "_layout_det_model")
        and getattr(self, "_layout_model_id", None) == layout_model
    )
    if already_loaded:
        return

    # Imported lazily so plain OCR does not require the layout dependencies.
    import torchvision.transforms.functional as tvF
    from transformers import AutoModelForObjectDetection, PPDocLayoutV3ImageProcessorFast

    self._layout_processor = PPDocLayoutV3ImageProcessorFast.from_pretrained(layout_model)
    # NOTE(review): assigning an nn.Module attribute on this nn.Module registers
    # the detector as a submodule — confirm that is intended for save/load.
    self._layout_det_model = AutoModelForObjectDetection.from_pretrained(
        layout_model, torch_dtype=torch.float16,
    ).to(self.device).eval()
    self._layout_id2label = self._layout_det_model.config.id2label
    self._tvF = tvF
    # Record what was loaded only after everything above succeeded.
    self._layout_model_id = layout_model
@torch.inference_mode()
def _run_layout_detection(
    self, images: list[Image.Image], threshold: float = 0.5,
) -> list[list[dict]]:
    """Run the layout detector on a batch of PIL images.

    Args:
        images: Non-empty list of PIL images (the caller converts them to RGB).
        threshold: Minimum sigmoid score for a detection to be kept.

    Returns:
        One list per image of dicts with ``category``, ``bbox``
        (``[x1, y1, x2, y2]`` in original-image pixels, rounded to 2 decimals)
        and ``score`` (rounded to 4 decimals), sorted by the detector's
        predicted order value.
    """
    device = self.device
    to_tensor = self._tvF.pil_to_tensor

    # PIL reports (width, height); keep (height, width) for rescaling boxes later.
    original_hw = torch.tensor([img.size[::-1] for img in images])
    raw = [to_tensor(img) for img in images]

    # GPU-accelerated resize + normalize
    pixel_values = torch.empty(
        len(raw), 3, _LAYOUT_TARGET_H, _LAYOUT_TARGET_W,
        dtype=torch.float16, device=device,
    )
    # Images of identical size are stacked so each group is resized in one call.
    by_shape: dict[tuple[int, int], list[int]] = {}
    for pos, tensor in enumerate(raw):
        by_shape.setdefault((tensor.shape[1], tensor.shape[2]), []).append(pos)

    for members in by_shape.values():
        group = torch.stack([raw[pos] for pos in members])
        group = group.to(device=device, dtype=torch.float32, non_blocking=True)
        group = F.interpolate(
            group, size=(_LAYOUT_TARGET_H, _LAYOUT_TARGET_W),
            mode="bicubic", align_corners=False, antialias=False,
        )
        # Bicubic interpolation can overshoot [0, 255]; clamp before scaling to [0, 1].
        group = (group.clamp_(0, 255) / 255.0).to(torch.float16)
        for slot, pos in enumerate(members):
            pixel_values[pos] = group[slot]
        del group

    outputs = self._layout_det_model(pixel_values=pixel_values)
    del pixel_values

    # Postprocess on GPU
    class_logits = outputs.logits
    centers, sizes = outputs.pred_boxes.split(2, dim=-1)
    # (cx, cy, w, h) -> (x1, y1, x2, y2)
    xyxy = torch.cat([centers - 0.5 * sizes, centers + 0.5 * sizes], dim=-1)

    heights, widths = original_hw.unbind(1)
    to_pixels = torch.stack([widths, heights, widths, heights], dim=1).to(device, dtype=xyxy.dtype)
    xyxy = xyxy * to_pixels[:, None, :]

    num_queries = class_logits.shape[1]
    num_classes = class_logits.shape[2]
    # Take the num_queries best (query, class) pairs per image.
    top_scores, flat_index = class_logits.sigmoid().flatten(1).topk(num_queries, dim=-1)
    labels = flat_index % num_classes
    query_index = flat_index // num_classes
    xyxy = xyxy.gather(dim=1, index=query_index.unsqueeze(-1).expand(-1, -1, 4))

    reading_order = self._layout_processor._get_order_seqs(outputs.order_logits)
    reading_order = reading_order.gather(dim=1, index=query_index)

    batch_results = []
    for img_scores, img_labels, img_boxes, img_order in zip(top_scores, labels, xyxy, reading_order):
        keep = img_scores >= threshold
        rank = img_order[keep].sort().indices
        kept_scores = img_scores[keep][rank].tolist()
        kept_labels = img_labels[keep][rank].tolist()
        kept_boxes = img_boxes[keep][rank].tolist()
        batch_results.append([
            {
                "category": self._layout_id2label[label],
                "bbox": [round(coord, 2) for coord in box],
                "score": round(score, 4),
            }
            for score, label, box in zip(kept_scores, kept_labels, kept_boxes)
        ])
    return batch_results
# Core batch decode (shared by generate & generate_with_layout)
def _generate_batch(
    self,
    image_prompt_pairs: list[tuple],
    *,
    max_new_tokens: int,
    temperature: float,
    top_k: int | None,
    min_dimension: int,
    max_dimension: int,
    seed: int | None,
) -> list[str]:
    """Core autoregressive decode for a list of (image, prompt) pairs.

    Args:
        image_prompt_pairs: ``(image, prompt)`` tuples handed to ``process_batch``.
        max_new_tokens: Hard upper bound on the number of tokens generated
            per sequence.
        temperature: Sampling temperature (0.0 = greedy).
        top_k: Top-k sampling cut-off (None = disabled).
        min_dimension: Forwarded to ``process_batch``.
        max_dimension: Forwarded to ``process_batch``.
        seed: Seed for the sampling generator (None = no explicit generator).

    Returns:
        One decoded string per input pair, with the stop-token texts removed
        and surrounding whitespace stripped.

    Raises:
        ValueError: If prompt length plus ``max_new_tokens`` (rounded up to
            the attention block size) exceeds ``config.max_seq_len``.
    """
    device = self.device
    tokenizer = self._get_tokenizer()
    # get_attention_mask() reads this attribute, so it must be set first.
    self._pad_token_id = tokenizer.convert_tokens_to_ids("<|pad|>")
    stop_token_ids = [self.config.eos_id, tokenizer.convert_tokens_to_ids("<|end_of_query|>")]

    batch_inputs = process_batch(
        tokenizer, self.config, image_prompt_pairs,
        max_length=4096, min_dimension=min_dimension, max_dimension=max_dimension,
    )
    batch_inputs = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch_inputs.items()}
    tokens = batch_inputs["tokens"]
    B, L = tokens.size()

    # The cache/mask length is rounded up to a whole number of attention blocks.
    block_size = 128
    S = (L + max_new_tokens + block_size - 1) // block_size * block_size
    if S > self.config.max_seq_len:
        raise ValueError(
            f"Prompt length ({L}) + max_new_tokens ({max_new_tokens}) needs {S} positions, "
            f"which exceeds config.max_seq_len ({self.config.max_seq_len})"
        )
    # Generation itself must stop after max_new_tokens, not at the padded
    # length S (which may be up to block_size - 1 positions longer).
    max_pos = L + max_new_tokens

    rng = torch.Generator(device).manual_seed(seed) if seed is not None else None
    kv_cache = KVCache(
        max_batch_size=B, max_seq_length=S, n_heads=self.config.n_heads,
        head_dim=self.config.head_dim, num_layers=self.config.n_layers,
    )
    padded_tokens = torch.full((B, S), self._pad_token_id, dtype=tokens.dtype, device=device)
    padded_tokens[:, :L] = tokens
    attention_mask = self.get_attention_mask(padded_tokens, max_len=S)

    # Prefill over the whole prompt.
    logits_BSV = self.forward(
        tokens=tokens, rope_pos_t=batch_inputs["pos_t"], rope_pos_hw=batch_inputs["pos_hw"],
        attention_mask=attention_mask, kv_cache=kv_cache,
        pixel_values=batch_inputs["pixel_values"], pixel_mask=batch_inputs["pixel_mask"],
    )

    stop_ids = torch.tensor(stop_token_ids).to(device)
    should_stop_B = torch.full((B,), False, dtype=torch.bool, device=device)
    generated_ids: list[list[int]] = [[] for _ in range(B)]

    while not torch.all(should_stop_B) and (pos := kv_cache.get_pos()) < max_pos:
        tokens_B1 = sample_next_token(logits_BSV[:, -1], rng, temperature, top_k)
        if torch.any(should_stop_B):
            # Finished sequences keep decoding in lock-step but are fed padding.
            tokens_B1 = tokens_B1.clone()
            tokens_B1[should_stop_B, :] = self._pad_token_id
        padded_tokens[:, pos] = tokens_B1[:, -1]

        # One host transfer per step instead of one ``.item()`` per sequence.
        step_tokens = tokens_B1[:, 0].tolist()
        step_stopped = should_stop_B.tolist()
        for b, (token_id, stopped) in enumerate(zip(step_tokens, step_stopped)):
            if not stopped:
                generated_ids[b].append(token_id)

        logits_BSV = self.forward(
            tokens=tokens_B1, attention_mask=attention_mask, kv_cache=kv_cache,
        )
        hit_stop_B = torch.isin(tokens_B1, stop_ids).any(dim=-1)
        should_stop_B = should_stop_B.logical_or(hit_stop_B)

    results = []
    for ids in generated_ids:
        # The stop token itself is part of ``ids``; strip its text form here.
        text = tokenizer.decode(ids, skip_special_tokens=False)
        text = text.replace("<|end_of_query|>", "").replace("<|end_of_text|>", "").strip()
        results.append(text)
    return results
# Main API: generate
@torch.inference_mode()
def generate(
    self,
    images,
    *,
    category: str | list[str] = "plain",
    max_new_tokens: int = 4096,
    temperature: float = 0.0,
    top_k: int | None = None,
    min_dimension: int = 64,
    max_dimension: int = 1024,
    compile: bool = True,
    seed: int | None = 42,
) -> list[str]:
    """
    Extract text from document images.
    Args:
        images: Single PIL Image (or path/URL) or list of them.
        category: OCR category — one of "plain", "text", "table", "formula",
            "caption", "footnote", "list-item", "page-footer", "page-header",
            "section-header", "title". Either one string applied to every
            image or a list with one entry per image. Names are stripped and
            lower-cased; an unrecognised name uses the "plain" instruction.
        max_new_tokens: Maximum generation steps.
        temperature: Sampling temperature (0.0 = greedy).
        top_k: Top-k sampling (None = disabled).
        min_dimension: Min image side after resize.
        max_dimension: Max image side after resize.
        compile: Whether to torch.compile on first call.
        seed: Random seed for reproducibility (None = non-deterministic).
    Returns:
        List of extracted text strings, one per image.
    """
    self._ensure_device_buffers()
    if compile:
        self.compile_model()

    # Normalise a single image / single category to per-image lists.
    if isinstance(images, (str, Path, Image.Image)):
        images = [images]
    categories = [category] * len(images) if isinstance(category, str) else category
    assert len(images) == len(categories), "Must provide one category per image"

    def _prompt_for(cat: str) -> str:
        instruction = CATEGORY_PROMPTS.get(cat.strip().lower(), CATEGORY_PROMPTS["plain"])
        return f"<|image|>{instruction}\n<|OCR_PLAIN|>"

    image_prompt_pairs = [(img, _prompt_for(cat)) for img, cat in zip(images, categories)]

    return self._generate_batch(
        image_prompt_pairs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        min_dimension=min_dimension,
        max_dimension=max_dimension,
        seed=seed,
    )
# Main API: generate_with_layout
@torch.inference_mode()
def generate_with_layout(
    self,
    images,
    *,
    max_new_tokens: int = 4096,
    temperature: float = 0.0,
    top_k: int | None = None,
    min_dimension: int = 64,
    max_dimension: int = 1024,
    compile: bool = True,
    seed: int | None = 42,
    layout_threshold: float = 0.3,
    layout_batch_size: int = 4,
    ocr_batch_size: int = 32,
    containment_threshold: float = 0.8,
    layout_model: str = "PaddlePaddle/PP-DocLayoutV3_safetensors",
) -> list[list[dict]]:
    """
    Run layout detection then OCR on each detected region.
    Args:
        images: Single PIL Image (or path/URL) or list of them.
        max_new_tokens: Maximum generation steps per crop.
        temperature: Sampling temperature (0.0 = greedy).
        top_k: Top-k sampling (None = disabled).
        min_dimension: Min crop side after resize for OCR.
        max_dimension: Max crop side after resize for OCR.
        compile: Whether to torch.compile on first call.
        seed: Random seed for reproducibility; the same seed is reused for
            every OCR chunk.
        layout_threshold: Confidence threshold for layout detections.
        layout_batch_size: Batch size for layout detection.
        ocr_batch_size: Batch size for OCR generation (chunks crops).
        containment_threshold: Drop any detected box whose area is covered by
            more than this fraction by a strictly larger box.
        layout_model: HuggingFace model ID for layout detection.
    Returns:
        Per-image list of detections, each a dict with keys:
        ``category``, ``bbox`` [x1,y1,x2,y2], ``score``, ``text``.
        Pages with no detections (or a single "image" detection) yield one
        full-page entry with category ``"plain"`` and score 1.0. Regions with
        a skipped category or a too-small crop are omitted.
    """
    self._ensure_device_buffers()
    if compile:
        self.compile_model()
    self._load_layout_model(layout_model)

    if isinstance(images, (str, Path, Image.Image)):
        images = [images]
    pages = [load_image(img).convert("RGB") for img in images]

    # --- Layout detection (batched) ---
    page_dets: list[list[dict]] = []
    for start in range(0, len(pages), layout_batch_size):
        page_dets.extend(
            self._run_layout_detection(
                pages[start : start + layout_batch_size], threshold=layout_threshold,
            )
        )

    # --- Filter nested boxes (e.g. inline formulas inside text) ---
    page_dets = [_filter_nested_detections(dets, containment_threshold) for dets in page_dets]

    def _prompt(instruction: str) -> str:
        return f"<|image|>{instruction}\n<|OCR_PLAIN|>"

    # --- Build crops + track origin ---
    crops: list[tuple[Image.Image, str]] = []
    origins: list[tuple[int, int]] = []  # (image_idx, det_idx); det_idx -1 = whole page
    for page_idx, (page, dets) in enumerate(zip(pages, page_dets)):
        only_an_image = len(dets) == 1 and dets[0]["category"].strip().lower() == "image"
        if not dets or only_an_image:
            # Nothing usable from layout detection: OCR the full page instead.
            crops.append((page, _prompt(CATEGORY_PROMPTS["plain"])))
            origins.append((page_idx, -1))
            continue

        page_w, page_h = page.size
        for det_idx, det in enumerate(dets):
            ocr_cat = LAYOUT_TO_OCR_CATEGORY.get(det["category"].strip().lower())
            if ocr_cat is None:
                continue

            # Clamp to the page; floor the top-left corner, round the bottom-right.
            left, top, right, bottom = det["bbox"]
            left = max(0, int(left))
            top = max(0, int(top))
            right = min(page_w, int(right + 0.5))
            bottom = min(page_h, int(bottom + 0.5))
            crop_w, crop_h = right - left, bottom - top
            if crop_w < _MIN_CROP_DIM or crop_h < _MIN_CROP_DIM:
                continue

            # Skip crops whose short side would fall below the minimum once the
            # long side is scaled down to max_dimension.
            short_side, long_side = min(crop_w, crop_h), max(crop_w, crop_h)
            if long_side > max_dimension:
                short_after_resize = short_side * (max_dimension / long_side)
            else:
                short_after_resize = short_side
            if short_after_resize < _MIN_CROP_DIM:
                continue

            instruction = CATEGORY_PROMPTS.get(ocr_cat, CATEGORY_PROMPTS["plain"])
            crops.append((page.crop((left, top, right, bottom)), _prompt(instruction)))
            origins.append((page_idx, det_idx))

    # --- OCR in chunks ---
    texts: list[str] = []
    for start in range(0, len(crops), ocr_batch_size):
        texts.extend(
            self._generate_batch(
                crops[start : start + ocr_batch_size],
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                min_dimension=min_dimension,
                max_dimension=max_dimension,
                seed=seed,
            )
        )

    # --- Reassemble per-image results ---
    results: list[list[dict]] = [[] for _ in pages]
    for (page_idx, det_idx), text in zip(origins, texts):
        if det_idx == -1:
            page_w, page_h = pages[page_idx].size
            entry = {
                "category": "plain",
                "bbox": [0, 0, page_w, page_h],
                "score": 1.0,
                "text": text,
            }
        else:
            det = page_dets[page_idx][det_idx]
            entry = {
                "category": det["category"],
                "bbox": det["bbox"],
                "score": det["score"],
                "text": text,
            }
        results[page_idx].append(entry)
    return results