Commit ·
ed18782
0
Parent(s):
Duplicate from tiiuae/Falcon-OCR
Browse files
Co-authored-by: Yasser Dahou <yasserDahou@users.noreply.huggingface.co>
- .eval_results/olmocrbench.yaml +89 -0
- .gitattributes +35 -0
- README.md +456 -0
- attention.py +129 -0
- config.json +32 -0
- configuration_falcon_ocr.py +65 -0
- model.safetensors +3 -0
- model_args.json +37 -0
- modeling_falcon_ocr.py +845 -0
- processing_falcon_ocr.py +423 -0
- rope.py +127 -0
- special_tokens_map.json +390 -0
- tokenizer.json +0 -0
- tokenizer_config.json +110 -0
.eval_results/olmocrbench.yaml
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
- dataset:
|
| 2 |
+
id: allenai/olmOCR-bench
|
| 3 |
+
task_id: overall
|
| 4 |
+
value: 80.3
|
| 5 |
+
source:
|
| 6 |
+
url: https://huggingface.co/tiiuae/Falcon-OCR
|
| 7 |
+
name: Falcon-OCR Model Card
|
| 8 |
+
user: nielsr
|
| 9 |
+
notes: English-subset only
|
| 10 |
+
|
| 11 |
+
- dataset:
|
| 12 |
+
id: allenai/olmOCR-bench
|
| 13 |
+
task_id: arxiv_math
|
| 14 |
+
value: 80.5
|
| 15 |
+
source:
|
| 16 |
+
url: https://huggingface.co/tiiuae/Falcon-OCR
|
| 17 |
+
name: Falcon-OCR Model Card
|
| 18 |
+
user: nielsr
|
| 19 |
+
notes: English-subset only
|
| 20 |
+
|
| 21 |
+
- dataset:
|
| 22 |
+
id: allenai/olmOCR-bench
|
| 23 |
+
task_id: old_scans_math
|
| 24 |
+
value: 69.2
|
| 25 |
+
source:
|
| 26 |
+
url: https://huggingface.co/tiiuae/Falcon-OCR
|
| 27 |
+
name: Falcon-OCR Model Card
|
| 28 |
+
user: nielsr
|
| 29 |
+
notes: English-subset only
|
| 30 |
+
|
| 31 |
+
- dataset:
|
| 32 |
+
id: allenai/olmOCR-bench
|
| 33 |
+
task_id: table_tests
|
| 34 |
+
value: 90.3
|
| 35 |
+
source:
|
| 36 |
+
url: https://huggingface.co/tiiuae/Falcon-OCR
|
| 37 |
+
name: Falcon-OCR Model Card
|
| 38 |
+
user: nielsr
|
| 39 |
+
notes: English-subset only
|
| 40 |
+
|
| 41 |
+
- dataset:
|
| 42 |
+
id: allenai/olmOCR-bench
|
| 43 |
+
task_id: old_scans
|
| 44 |
+
value: 43.5
|
| 45 |
+
source:
|
| 46 |
+
url: https://huggingface.co/tiiuae/Falcon-OCR
|
| 47 |
+
name: Falcon-OCR Model Card
|
| 48 |
+
user: nielsr
|
| 49 |
+
notes: English-subset only
|
| 50 |
+
|
| 51 |
+
- dataset:
|
| 52 |
+
id: allenai/olmOCR-bench
|
| 53 |
+
task_id: headers_footers
|
| 54 |
+
value: 94.0
|
| 55 |
+
source:
|
| 56 |
+
url: https://huggingface.co/tiiuae/Falcon-OCR
|
| 57 |
+
name: Falcon-OCR Model Card
|
| 58 |
+
user: nielsr
|
| 59 |
+
notes: English-subset only
|
| 60 |
+
|
| 61 |
+
- dataset:
|
| 62 |
+
id: allenai/olmOCR-bench
|
| 63 |
+
task_id: multi_column
|
| 64 |
+
value: 87.1
|
| 65 |
+
source:
|
| 66 |
+
url: https://huggingface.co/tiiuae/Falcon-OCR
|
| 67 |
+
name: Falcon-OCR Model Card
|
| 68 |
+
user: nielsr
|
| 69 |
+
notes: English-subset only
|
| 70 |
+
|
| 71 |
+
- dataset:
|
| 72 |
+
id: allenai/olmOCR-bench
|
| 73 |
+
task_id: long_tiny_text
|
| 74 |
+
value: 78.5
|
| 75 |
+
source:
|
| 76 |
+
url: https://huggingface.co/tiiuae/Falcon-OCR
|
| 77 |
+
name: Falcon-OCR Model Card
|
| 78 |
+
user: nielsr
|
| 79 |
+
notes: English-subset only
|
| 80 |
+
|
| 81 |
+
- dataset:
|
| 82 |
+
id: allenai/olmOCR-bench
|
| 83 |
+
task_id: baseline
|
| 84 |
+
value: 99.5
|
| 85 |
+
source:
|
| 86 |
+
url: https://huggingface.co/tiiuae/Falcon-OCR
|
| 87 |
+
name: Falcon-OCR Model Card
|
| 88 |
+
user: nielsr
|
| 89 |
+
notes: English-subset only
|
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,456 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
pipeline_tag: image-to-text
|
| 3 |
+
library_name: transformers
|
| 4 |
+
tags:
|
| 5 |
+
- falcon
|
| 6 |
+
- ocr
|
| 7 |
+
- vision-language
|
| 8 |
+
- document-understanding
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
<div style="width: 480px; text-align: left;">
|
| 13 |
+
<img src="https://cdn-uploads.huggingface.co/production/uploads/663c9939c1b4f7297c4ae6f6/YIuxzgDiV5T2ZuSB4bam9.png" alt="Falcon OCR Logo" style="max-width: 100%; height: auto;">
|
| 14 |
+
</div>
|
| 15 |
+
|
| 16 |
+
# Falcon OCR
|
| 17 |
+
|
| 18 |
+
Falcon OCR is a 300M parameter early-fusion vision-language model for document OCR. Given an image, it can produce plain text, LaTeX for formulas, or HTML for tables, depending on the requested output format.
|
| 19 |
+
|
| 20 |
+
Most OCR VLM systems are built as a pipeline with a vision encoder feeding a separate text decoder, plus additional task-specific glue. Falcon OCR takes a different approach: a single Transformer processes image patches and text tokens in a shared parameter space from the first layer, using a hybrid attention mask where image tokens attend bidirectionally and text tokens decode causally conditioned on the image.
|
| 21 |
+
|
| 22 |
+
We built it this way for two practical reasons. First, it keeps the interface simple: one backbone, one decoding path, and task switching through prompts rather than a growing set of modules. Second, a 0.3B model has a lower latency and cost footprint than 0.9B-class OCR VLMs, and in our vLLM-based serving setup this translates into higher throughput, often 2–3× faster depending on sequence lengths and batch configuration. To our knowledge, this is one of the first attempts to apply this early-fusion single-stack recipe directly to competitive document OCR at this scale.
|
| 23 |
+
|
| 24 |
+
### Links
|
| 25 |
+
|
| 26 |
+
- Code and inference engine: [https://github.com/tiiuae/Falcon-Perception](https://github.com/tiiuae/Falcon-Perception)
|
| 27 |
+
- Tech report: arXiv link coming soon
|
| 28 |
+
- Perception model: `tiiuae/falcon-perception`
|
| 29 |
+
- vLLM/Docker: [https://ghcr.io/tiiuae/falcon-ocr:latest](https://ghcr.io/tiiuae/falcon-ocr:latest)
|
| 30 |
+
|
| 31 |
+
## Quickstart
|
| 32 |
+
|
| 33 |
+
### Installation
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
pip install "torch>=2.5" transformers pillow einops
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
Falcon OCR requires PyTorch 2.5 or newer for FlexAttention. The first call may be slower as `torch.compile` builds optimized kernels.
|
| 40 |
+
|
| 41 |
+
### Single-Image OCR
|
| 42 |
+
|
| 43 |
+
```python
|
| 44 |
+
import torch
|
| 45 |
+
from PIL import Image
|
| 46 |
+
from transformers import AutoModelForCausalLM
|
| 47 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 48 |
+
"tiiuae/Falcon-OCR",
|
| 49 |
+
trust_remote_code=True,
|
| 50 |
+
torch_dtype=torch.bfloat16,
|
| 51 |
+
device_map="auto",
|
| 52 |
+
)
|
| 53 |
+
image = Image.open("document.png")
|
| 54 |
+
texts = model.generate(image) # default category is "plain"
|
| 55 |
+
print(texts[0])
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
### Choose an output format with `category`
|
| 59 |
+
|
| 60 |
+
```python
|
| 61 |
+
texts = model.generate(image, category="text") # plain text
|
| 62 |
+
texts = model.generate(image, category="formula") # LaTeX
|
| 63 |
+
texts = model.generate(image, category="table") # HTML table
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
## API
|
| 67 |
+
|
| 68 |
+
### `model.generate(images, category="plain", **kwargs)`
|
| 69 |
+
- **Inputs**:
|
| 70 |
+
- `images`: a `PIL.Image.Image` or a list of images
|
| 71 |
+
- `category`: one of `plain`, `text`, `table`, `formula`, `caption`, `footnote`, `list-item`, `page-footer`, `page-header`, `section-header`, `title`
|
| 72 |
+
- **Returns**: `list[str]`, one extracted string per image
|
| 73 |
+
|
| 74 |
+
## Layout OCR (Two-Stage Pipeline)
|
| 75 |
+
For sparse documents, running OCR on the whole image can work well. For dense documents with heterogeneous regions (multi-column layouts, interleaved tables and formulas, small captions), we provide an optional two-stage pipeline:
|
| 76 |
+
1. A layout detector finds regions on the page.
|
| 77 |
+
2. Falcon OCR runs independently on each crop with a category-specific prompt.
|
| 78 |
+
We use [PP-DocLayoutV3](https://huggingface.co/PaddlePaddle/PP-DocLayoutV3_safetensors) as the layout detector.
|
| 79 |
+
```python
|
| 80 |
+
results = model.generate_with_layout(image)
|
| 81 |
+
for det in results[0]:
|
| 82 |
+
print(f"[{det['category']}] {det['text'][:100]}...")
|
| 83 |
+
```
|
| 84 |
+
Batch mode:
|
| 85 |
+
```python
|
| 86 |
+
results = model.generate_with_layout(
|
| 87 |
+
[Image.open("page1.png"), Image.open("page2.png")],
|
| 88 |
+
ocr_batch_size=32,
|
| 89 |
+
)
|
| 90 |
+
```
|
| 91 |
+
The layout model is loaded lazily on the first `generate_with_layout()` call and runs on the same GPU as the OCR model.
|
| 92 |
+
**Returns**: `list[list[dict]]`, one list per image, in reading order:
|
| 93 |
+
```python
|
| 94 |
+
{
|
| 95 |
+
"category": "text", # layout category
|
| 96 |
+
"bbox": [x1, y1, x2, y2], # in original image pixels
|
| 97 |
+
"score": 0.93, # detection confidence
|
| 98 |
+
"text": "..." # extracted text
|
| 99 |
+
}
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
## When to Use What
|
| 103 |
+
|
| 104 |
+
| Mode | Best for | How |
|
| 105 |
+
|------|----------|-----|
|
| 106 |
+
| **Plain OCR** | Simple documents, real-world photos, slides, receipts, invoices | `model.generate(image)` |
|
| 107 |
+
| **Layout + OCR** | Complex multi-column documents, academic papers, reports, dense pages like newspapers | `model.generate_with_layout(image)` |
|
| 108 |
+
|
| 109 |
+
## Benchmark Results
|
| 110 |
+
<details name="benchmarks" open>
|
| 111 |
+
<summary><b>olmOCR Benchmark</b></summary>
|
| 112 |
+
|
| 113 |
+
Category-wise performance comparison of FalconOCR against state-of-the-art OCR models. We report accuracy (%) across all category splits.
|
| 114 |
+
|
| 115 |
+
<table>
|
| 116 |
+
<tr><th>Model</th><th>Average</th><th>ArXiv Math</th><th>Base</th><th>Hdr/Ftr</th><th>TinyTxt</th><th>MultCol</th><th>OldScan</th><th>OldMath</th><th>Tables</th></tr>
|
| 117 |
+
<tr><td>Mistral OCR 3</td><td>81.7</td><td><b>85.4</b></td><td><b>99.9</b></td><td>93.8</td><td>88.9</td><td>82.1</td><td>48.8</td><td>68.3</td><td>86.1</td></tr>
|
| 118 |
+
<tr><td>Chandra</td><td><b>82.0</b></td><td>81.4</td><td>99.8</td><td>88.8</td><td><b>91.9</b></td><td>82.9</td><td><b>49.2</b></td><td>73.6</td><td>88.2</td></tr>
|
| 119 |
+
<tr><td>Gemini 3 Pro</td><td>80.2</td><td>70.6</td><td>99.8</td><td>84.0</td><td>90.3</td><td>79.2</td><td>47.5</td><td>84.9</td><td>84.9</td></tr>
|
| 120 |
+
<tr><td>PaddleOCR VL 1.5</td><td>79.3</td><td><b>85.4</b></td><td>98.8</td><td><b>96.9</b></td><td>80.8</td><td>82.6</td><td>39.2</td><td>66.4</td><td>84.1</td></tr>
|
| 121 |
+
<tr><td>PaddleOCR VL</td><td>79.2</td><td><b>85.4</b></td><td>98.6</td><td><b>96.9</b></td><td>80.8</td><td>82.5</td><td>38.8</td><td>66.4</td><td>83.9</td></tr>
|
| 122 |
+
<tr><td>DeepSeek OCR v2</td><td>78.8</td><td>81.9</td><td>99.8</td><td>95.6</td><td>88.7</td><td>83.6</td><td>33.7</td><td>68.8</td><td>78.1</td></tr>
|
| 123 |
+
<tr><td>Gemini 3 Flash</td><td>77.5</td><td>66.5</td><td>99.8</td><td>83.8</td><td>88.2</td><td>73.7</td><td>46.0</td><td><b>85.8</b></td><td>75.9</td></tr>
|
| 124 |
+
<tr><td>GPT 5.2</td><td>69.8</td><td>61.0</td><td>99.8</td><td>75.6</td><td>62.2</td><td>70.2</td><td>34.6</td><td>75.8</td><td>79.0</td></tr>
|
| 125 |
+
<tr style="background:#a358e5; color:white"><td><b>FalconOCR</b></td><td>80.3</td><td>80.5</td><td>99.5</td><td>94.0</td><td>78.5</td><td><b>87.1</b></td><td>43.5</td><td>69.2</td><td><b>90.3</b></td></tr>
|
| 126 |
+
</table>
|
| 127 |
+
|
| 128 |
+
</details>
|
| 129 |
+
|
| 130 |
+
<details name="benchmarks">
|
| 131 |
+
<summary><b>OmniDocBench</b></summary>
|
| 132 |
+
|
| 133 |
+
Performance comparison on full-page document parsing. Overall↑ aggregates the three sub-metrics. Edit↓ measures text edit distance (lower is better). CDM↑ evaluates formula recognition accuracy. TEDS↑ measures table structure similarity.
|
| 134 |
+
|
| 135 |
+
<table>
|
| 136 |
+
<tr><th>Model</th><th>Overall↑</th><th>Edit↓</th><th>CDM↑</th><th>TEDS↑</th></tr>
|
| 137 |
+
<tr><td>PaddleOCR VL 1.5</td><td><b>94.37</b></td><td>0.025</td><td><b>94.4</b></td><td><b>91.1</b></td></tr>
|
| 138 |
+
<tr><td>PaddleOCR VL</td><td>91.76</td><td><b>0.024</b></td><td>91.7</td><td>85.9</td></tr>
|
| 139 |
+
<tr><td>Chandra</td><td>88.97</td><td>0.046</td><td>88.1</td><td>89.5</td></tr>
|
| 140 |
+
<tr><td>DeepSeek OCR v2</td><td>87.66</td><td>0.037</td><td>89.2</td><td>77.5</td></tr>
|
| 141 |
+
<tr><td>GPT 5.2</td><td>86.56</td><td>0.061</td><td>88.0</td><td>77.7</td></tr>
|
| 142 |
+
<tr><td>Mistral OCR 3</td><td>85.20</td><td>0.053</td><td>84.3</td><td>76.1</td></tr>
|
| 143 |
+
<tr style="background:#a358e5; color:white"><td><b>FalconOCR</b></td><td>88.64</td><td>0.055</td><td>86.8</td><td>84.6</td></tr>
|
| 144 |
+
</table>
|
| 145 |
+
|
| 146 |
+
</details>
|
| 147 |
+
|
| 148 |
+
### Results Analysis
|
| 149 |
+
|
| 150 |
+
First, a compact model can be competitive when the interface is simple and the training signal is targeted. On olmOCR, Falcon OCR performs strongly on multi-column documents and tables, and is competitive overall against substantially larger systems. Second, evaluation on full-page parsing is sensitive to matching and representation details. On OmniDocBench, the table and formula metrics depend not only on recognition quality but also on how predicted elements are matched to ground truth and how output structure is normalized.
|
| 151 |
+
|
| 152 |
+
More broadly, these results suggest that an early-fusion single-stack Transformer can be a viable alternative to the common "vision encoder plus text decoder" recipe for OCR. We do not view this as a finished answer, but as a promising direction: one early-fusion backbone, a shared parameter space between text and images, a single decoding interface, and better data and training signals, rather than increasingly complex pipelines. To our knowledge, this is among the first demonstrations that this early-fusion recipe can reach competitive document OCR accuracy at this scale, and we hope it encourages further work in this direction.
|
| 153 |
+
|
| 154 |
+
## Serving Throughput
|
| 155 |
+
|
| 156 |
+
Measured on a single A100-80GB GPU with vLLM, processing document images from olmOCR-Bench under high concurrency for optimal vLLM utilization.
|
| 157 |
+
|
| 158 |
+
<!-- We benchmark two modes to isolate different parts of the pipeline: -->
|
| 159 |
+
|
| 160 |
+
<!-- - **Cropped regions** — A layout detector is run offline first to extract all regions from every page. Only the resulting crops are sent to the vLLM server. This measures pure vLLM throughput with no layout overhead. -->
|
| 161 |
+
- **Layout + OCR** — The full end-to-end pipeline: layout detection finds regions on each page, crops them, and vLLM runs OCR on every crop. This represents the real-world serving throughput, inclusive of both layout detection and OCR time.
|
| 162 |
+
|
| 163 |
+
| Mode | tok/s | img/s | Description |
|
| 164 |
+
|------|------:|------:|-------------|
|
| 165 |
+
| **Layout + OCR** | 5,825 | 2.9 | Full pipeline: layout detection → crop → per-region OCR |
|
| 166 |
+
<!-- | **Plain OCR** | 6,076 | 43.7 | plain OCR, no layout step | -->
|
| 167 |
+
|
| 168 |
+
At 0.3B parameters, Falcon OCR is roughly 3× smaller than 0.9B-class OCR VLMs (e.g., PaddleOCR VL), which translates directly into higher serving throughput at competitive accuracy.
|
| 169 |
+
|
| 170 |
+
## Limitations
|
| 171 |
+
|
| 172 |
+
- **Old scans and tiny text**: Heavily degraded scans and very small glyphs remain challenging. These cases often require higher effective resolution and better coverage in the training mixture.
|
| 173 |
+
- **Non-unique table representations**: Visually identical tables can be encoded in structurally different HTML forms, which can affect tree-based metrics.
|
| 174 |
+
- **Formula matching sensitivity**: LaTeX and Unicode conventions can be penalized differently depending on the benchmark normalization and matching pipeline.
|
| 175 |
+
|
| 176 |
+
## Examples
|
| 177 |
+
|
| 178 |
+
*Click each section below to expand.*
|
| 179 |
+
|
| 180 |
+
<details name="ocr-examples" open>
|
| 181 |
+
<summary><b>Handwriting and Real World Images</b></summary>
|
| 182 |
+
<p align="center">
|
| 183 |
+
<img src="https://cdn-uploads.huggingface.co/production/uploads/62fe441427c98b09b503a4e3/51Fj1wxxtAV_jwubml6sa.png" width="600" alt="Handwriting and real world OCR examples" />
|
| 184 |
+
</p>
|
| 185 |
+
</details>
|
| 186 |
+
|
| 187 |
+
<details name="ocr-examples">
|
| 188 |
+
<summary><b>Tables</b></summary>
|
| 189 |
+
<p align="center">
|
| 190 |
+
<img src="https://cdn-uploads.huggingface.co/production/uploads/62fe441427c98b09b503a4e3/2yZjZJAEHVVpd_jfyyDcQ.png" width="600" alt="Table OCR examples" />
|
| 191 |
+
</p>
|
| 192 |
+
</details>
|
| 193 |
+
|
| 194 |
+
<details name="ocr-examples">
|
| 195 |
+
<summary><b>Formulas</b></summary>
|
| 196 |
+
<p align="center">
|
| 197 |
+
<img src="https://cdn-uploads.huggingface.co/production/uploads/62fe441427c98b09b503a4e3/__XMb0GyGO02IPKlQsPQx.png" width="600" alt="Formula OCR examples" />
|
| 198 |
+
</p>
|
| 199 |
+
</details>
|
| 200 |
+
|
| 201 |
+
<details name="ocr-examples">
|
| 202 |
+
<summary><b>Complex Layout</b></summary>
|
| 203 |
+
<p align="center">
|
| 204 |
+
<img src="https://cdn-uploads.huggingface.co/production/uploads/62fe441427c98b09b503a4e3/kTR7nI7ogEqdI1SQtXpTu.png" width="600" alt="Complex layout OCR examples" />
|
| 205 |
+
</p>
|
| 206 |
+
</details>
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
---
|
| 210 |
+
|
| 211 |
+
## vLLM Server
|
| 212 |
+
We also provide a Docker-based vLLM-backed inference server capable of serving approximately 6,000 tokens per second.
|
| 213 |
+
|
| 214 |
+
Single Docker image with two services:
|
| 215 |
+
|
| 216 |
+
| Service | Default Port | Description |
|
| 217 |
+
|---------|-------------|-------------|
|
| 218 |
+
| **vLLM** | 8000 | Falcon-OCR vision-language model (OpenAI-compatible API) |
|
| 219 |
+
| **Pipeline** | 5002 | Full document parsing: layout detection → crop → OCR → markdown |
|
| 220 |
+
|
| 221 |
+
The layout model runs inside the pipeline process — it is not a standalone service.
|
| 222 |
+
|
| 223 |
+
### Quick Start
|
| 224 |
+
|
| 225 |
+
```bash
|
| 226 |
+
docker run -d --name falcon-ocr \
|
| 227 |
+
--gpus '"device=0,1"' \
|
| 228 |
+
-e EXPOSED_GPU_IDS=0,1 \
|
| 229 |
+
-e VLLM_GPU=0 \
|
| 230 |
+
-e PIPELINE_GPU=1 \
|
| 231 |
+
-e VLLM_GPU_MEM_UTIL=0.90 \
|
| 232 |
+
-p 8000:8000 \
|
| 233 |
+
-p 5002:5002 \
|
| 234 |
+
ghcr.io/tiiuae/falcon-ocr:latest
|
| 235 |
+
```
|
| 236 |
+
|
| 237 |
+
### API
|
| 238 |
+
|
| 239 |
+
<details name="api" open>
|
| 240 |
+
<summary><b>Health Checks</b></summary>
|
| 241 |
+
|
| 242 |
+
```bash
|
| 243 |
+
curl http://localhost:8000/health # vLLM
|
| 244 |
+
curl http://localhost:5002/health # Pipeline
|
| 245 |
+
```
|
| 246 |
+
</details>
|
| 247 |
+
|
| 248 |
+
<details name="api">
|
| 249 |
+
<summary><b>Upload</b> (multipart file upload — images and PDFs)</summary>
|
| 250 |
+
|
| 251 |
+
The easiest way to send files. Supports images and multi-page PDFs:
|
| 252 |
+
|
| 253 |
+
```bash
|
| 254 |
+
# Single image
|
| 255 |
+
curl -X POST http://localhost:5002/falconocr/upload \
|
| 256 |
+
-F "files=@photo.jpg;type=image/jpeg"
|
| 257 |
+
# PDF document
|
| 258 |
+
curl -X POST http://localhost:5002/falconocr/upload \
|
| 259 |
+
-F "files=@document.pdf;type=application/pdf"
|
| 260 |
+
```
|
| 261 |
+
</details>
|
| 262 |
+
|
| 263 |
+
<details name="api">
|
| 264 |
+
<summary><b>Parse</b> (full pipeline: layout + OCR)</summary>
|
| 265 |
+
|
| 266 |
+
Send base64-encoded images for layout detection, cropping, and OCR:
|
| 267 |
+
|
| 268 |
+
```bash
|
| 269 |
+
curl -X POST http://localhost:5002/falconocr/parse \
|
| 270 |
+
-H "Content-Type: application/json" \
|
| 271 |
+
-d '{
|
| 272 |
+
"images": ["data:image/jpeg;base64,<...>"],
|
| 273 |
+
"skip_layout": false
|
| 274 |
+
}'
|
| 275 |
+
```
|
| 276 |
+
|
| 277 |
+
Response:
|
| 278 |
+
|
| 279 |
+
```json
|
| 280 |
+
{
|
| 281 |
+
"json_result": [[{
|
| 282 |
+
"index": 0,
|
| 283 |
+
"mapped_label": "text",
|
| 284 |
+
"content": "The Manuscript",
|
| 285 |
+
"bbox": [273, 273, 937, 380],
|
| 286 |
+
"score": 0.3145
|
| 287 |
+
}]],
|
| 288 |
+
"markdown_result": "The Manuscript",
|
| 289 |
+
"total_output_tokens": 93,
|
| 290 |
+
"processing_time_ms": 414
|
| 291 |
+
}
|
| 292 |
+
```
|
| 293 |
+
</details>
|
| 294 |
+
|
| 295 |
+
<details name="api">
|
| 296 |
+
<summary><b>Parse</b> (direct VLM, no layout)</summary>
|
| 297 |
+
|
| 298 |
+
Skip layout detection and send the full image directly to the VLM:
|
| 299 |
+
|
| 300 |
+
```bash
|
| 301 |
+
curl -X POST http://localhost:5002/falconocr/parse \
|
| 302 |
+
-H "Content-Type: application/json" \
|
| 303 |
+
-d '{
|
| 304 |
+
"images": ["data:image/jpeg;base64,<...>"],
|
| 305 |
+
"skip_layout": true
|
| 306 |
+
}'
|
| 307 |
+
```
|
| 308 |
+
</details>
|
| 309 |
+
|
| 310 |
+
<details name="api">
|
| 311 |
+
<summary><b>Direct vLLM</b> (OpenAI-compatible)</summary>
|
| 312 |
+
|
| 313 |
+
```bash
|
| 314 |
+
curl -X POST http://localhost:8000/v1/chat/completions \
|
| 315 |
+
-H "Content-Type: application/json" \
|
| 316 |
+
-d '{
|
| 317 |
+
"model": "falcon-ocr",
|
| 318 |
+
"messages": [{"role": "user", "content": [
|
| 319 |
+
{"type": "image_url", "image_url": {"url": "data:image/png;base64,<...>"}},
|
| 320 |
+
{"type": "text", "text": "Extract the text content from this image.\n<|OCR_PLAIN|>"}
|
| 321 |
+
]}],
|
| 322 |
+
"max_tokens": 2048
|
| 323 |
+
}'
|
| 324 |
+
```
|
| 325 |
+
</details>
|
| 326 |
+
|
| 327 |
+
### Configuration
|
| 328 |
+
|
| 329 |
+
All settings are controlled via environment variables at `docker run` time.
|
| 330 |
+
|
| 331 |
+
<details name="config" open>
|
| 332 |
+
<summary><b>GPU Assignment</b></summary>
|
| 333 |
+
|
| 334 |
+
| Variable | Default | Description |
|
| 335 |
+
|----------|---------|-------------|
|
| 336 |
+
| `VLLM_GPU` | `0` | Host GPU ID for the vLLM process |
|
| 337 |
+
| `PIPELINE_GPU` | `0` | Host GPU ID for the pipeline (layout model) |
|
| 338 |
+
| `EXPOSED_GPU_IDS` | *(all visible)* | Comma-separated host GPU IDs passed via `--gpus` (for index remapping) |
|
| 339 |
+
</details>
|
| 340 |
+
|
| 341 |
+
<details name="config">
|
| 342 |
+
<summary><b>Port Assignment</b></summary>
|
| 343 |
+
|
| 344 |
+
| Variable | Default | Description |
|
| 345 |
+
|----------|---------|-------------|
|
| 346 |
+
| `VLLM_PORT` | `8000` | Port for the vLLM OpenAI-compatible API |
|
| 347 |
+
| `PIPELINE_PORT` | `5002` | Port for the pipeline API |
|
| 348 |
+
</details>
|
| 349 |
+
|
| 350 |
+
<details name="config">
|
| 351 |
+
<summary><b>vLLM Tuning</b></summary>
|
| 352 |
+
|
| 353 |
+
| Variable | Default | Description |
|
| 354 |
+
|----------|---------|-------------|
|
| 355 |
+
| `VLLM_GPU_MEM_UTIL` | `0.90` | Fraction of GPU memory vLLM can use |
|
| 356 |
+
| `MAX_NUM_SEQS` | `2048` | Max concurrent sequences in vLLM |
|
| 357 |
+
| `MAX_MODEL_LEN` | `8192` | Max model context length |
|
| 358 |
+
| `DTYPE` | `bfloat16` | Model dtype |
|
| 359 |
+
| `MAX_NUM_BATCHED_TOKENS` | *(auto)* | Max batched tokens per iteration |
|
| 360 |
+
| `CHUNKED_PREFILL` | `false` | Enable chunked prefill |
|
| 361 |
+
</details>
|
| 362 |
+
|
| 363 |
+
<details name="config">
|
| 364 |
+
<summary><b>Layout Model Tuning</b></summary>
|
| 365 |
+
|
| 366 |
+
| Variable | Default | Description |
|
| 367 |
+
|----------|---------|-------------|
|
| 368 |
+
| `LAYOUT_BATCH_SIZE` | `64` | Batch size for layout detection inference |
|
| 369 |
+
</details>
|
| 370 |
+
|
| 371 |
+
<details name="config">
|
| 372 |
+
<summary><b>Model Paths</b></summary>
|
| 373 |
+
|
| 374 |
+
| Variable | Default | Description |
|
| 375 |
+
|----------|---------|-------------|
|
| 376 |
+
| `FALCON_OCR_MODEL` | `/models/Falcon-OCR` | Path to Falcon-OCR VLM weights (inside container) |
|
| 377 |
+
| `SERVED_MODEL_NAME` | `falcon-ocr` | Model name exposed by vLLM API |
|
| 378 |
+
</details>
|
| 379 |
+
|
| 380 |
+
### Deployment Modes
|
| 381 |
+
|
| 382 |
+
<details name="deploy" open>
|
| 383 |
+
<summary><b>Two GPUs</b> (best throughput)</summary>
|
| 384 |
+
|
| 385 |
+
vLLM on one GPU, layout model on another — zero GPU contention:
|
| 386 |
+
|
| 387 |
+
```bash
|
| 388 |
+
docker run -d --name falcon-ocr \
|
| 389 |
+
--gpus '"device=3,4"' \
|
| 390 |
+
-e EXPOSED_GPU_IDS=3,4 \
|
| 391 |
+
-e VLLM_GPU=3 \
|
| 392 |
+
-e PIPELINE_GPU=4 \
|
| 393 |
+
-e VLLM_GPU_MEM_UTIL=0.90 \
|
| 394 |
+
-p 8000:8000 \
|
| 395 |
+
-p 5002:5002 \
|
| 396 |
+
ghcr.io/tiiuae/falcon-ocr:latest
|
| 397 |
+
```
|
| 398 |
+
</details>
|
| 399 |
+
|
| 400 |
+
<details name="deploy">
|
| 401 |
+
<summary><b>Single GPU</b> (memory sharing)</summary>
|
| 402 |
+
|
| 403 |
+
Both services share one GPU — tune `VLLM_GPU_MEM_UTIL` to leave room for the layout model:
|
| 404 |
+
|
| 405 |
+
```bash
|
| 406 |
+
docker run -d --name falcon-ocr \
|
| 407 |
+
--gpus '"device=0"' \
|
| 408 |
+
-e EXPOSED_GPU_IDS=0 \
|
| 409 |
+
-e VLLM_GPU=0 \
|
| 410 |
+
-e PIPELINE_GPU=0 \
|
| 411 |
+
-e VLLM_GPU_MEM_UTIL=0.55 \
|
| 412 |
+
-e LAYOUT_BATCH_SIZE=32 \
|
| 413 |
+
-e MAX_NUM_SEQS=512 \
|
| 414 |
+
-p 8000:8000 \
|
| 415 |
+
-p 5002:5002 \
|
| 416 |
+
ghcr.io/tiiuae/falcon-ocr:latest
|
| 417 |
+
```
|
| 418 |
+
</details>
|
| 419 |
+
|
| 420 |
+
<details name="deploy">
|
| 421 |
+
<summary><b>Custom Ports</b></summary>
|
| 422 |
+
|
| 423 |
+
```bash
|
| 424 |
+
docker run -d --name falcon-ocr \
|
| 425 |
+
--gpus '"device=0,1"' \
|
| 426 |
+
-e EXPOSED_GPU_IDS=0,1 \
|
| 427 |
+
-e VLLM_GPU=0 \
|
| 428 |
+
-e PIPELINE_GPU=1 \
|
| 429 |
+
-e VLLM_PORT=18000 \
|
| 430 |
+
-e PIPELINE_PORT=15002 \
|
| 431 |
+
-p 18000:18000 \
|
| 432 |
+
-p 15002:15002 \
|
| 433 |
+
ghcr.io/tiiuae/falcon-ocr:latest
|
| 434 |
+
```
|
| 435 |
+
|
| 436 |
+
In the Two GPUs example above, Docker `--gpus "device=3,4"` makes the container see GPUs as local indices `0,1`.
|
| 437 |
+
`EXPOSED_GPU_IDS=3,4` allows you to reference host GPU IDs (`VLLM_GPU=3`, `PIPELINE_GPU=4`);
|
| 438 |
+
the entrypoint remaps them to the correct container-local indices.
|
| 439 |
+
</details>
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
## Citation
|
| 443 |
+
|
| 444 |
+
If you use Falcon OCR, please cite:
|
| 445 |
+
|
| 446 |
+
```bibtex
|
| 447 |
+
@misc{falconocr2026,
|
| 448 |
+
title = {Falcon OCR},
|
| 449 |
+
author = {TII Falcon Vision Team},
|
| 450 |
+
year = {2026},
|
| 451 |
+
howpublished = {arXiv preprint, link forthcoming},
|
| 452 |
+
note = {Code: https://github.com/tiiuae/Falcon-Perception},
|
| 453 |
+
}
|
| 454 |
+
```
|
| 455 |
+
|
| 456 |
+
|
attention.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import Tensor as T
|
| 3 |
+
from torch.nn.attention.flex_attention import (
|
| 4 |
+
BlockMask,
|
| 5 |
+
_mask_mod_signature,
|
| 6 |
+
and_masks,
|
| 7 |
+
create_block_mask,
|
| 8 |
+
flex_attention,
|
| 9 |
+
or_masks,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
# ---------------------------------------------------------------------------
# Two compiled variants of flex_attention
# ---------------------------------------------------------------------------
# _decode: fullgraph=True, static shapes.
#   Used for decode steps (S_q == 1) where shapes are fixed and
#   the call will be captured inside a CUDA graph. fullgraph=True
#   avoids graph breaks that would corrupt the capture.
#
# _prefill: dynamic=True, symbolic shapes.
#   Used for prefill steps (S_q > 1) where the sequence length
#   varies per image. dynamic=True lets one compiled graph handle
#   all lengths without recompilation. Prefill is never inside a
#   CUDA graph, so symbolic shape guards are fine.
compiled_flex_attn_decode = torch.compile(flex_attention, fullgraph=True)
compiled_flex_attn_prefill = torch.compile(flex_attention, dynamic=True)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def offset_mask_mod(mask_mod: _mask_mod_signature, offset: int):
|
| 30 |
+
"""Get a mask mod function with an offset applied to the query positions."""
|
| 31 |
+
|
| 32 |
+
def _mask_mod(b, h, q, kv):
|
| 33 |
+
return mask_mod(b, h, q + offset, kv)
|
| 34 |
+
|
| 35 |
+
return _mask_mod
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_causal_mask_mod() -> _mask_mod_signature:
|
| 39 |
+
"""Causal mask that prevents attention to future tokens."""
|
| 40 |
+
|
| 41 |
+
def _causal_mask(b: T, h: T, q_idx: T, kv_idx: T) -> T:
|
| 42 |
+
return q_idx >= kv_idx
|
| 43 |
+
|
| 44 |
+
return _causal_mask
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_document_mask_mod(batch: T, eos_id: int) -> _mask_mod_signature:
|
| 48 |
+
"""Document mask: prevents attention across document boundaries (token IDs [B, S])."""
|
| 49 |
+
eos_mask = batch == eos_id
|
| 50 |
+
eos_mask[:, -1] = True
|
| 51 |
+
cumulative_mask = torch.cumsum(torch.where(eos_mask, 1, 0), dim=1)
|
| 52 |
+
sequence_indices = torch.zeros_like(cumulative_mask, dtype=torch.int32)
|
| 53 |
+
sequence_indices[:, 1:] = cumulative_mask[:, :-1]
|
| 54 |
+
|
| 55 |
+
def document_mask(b: T, h: T, q_idx: T, kv_idx: T) -> T:
|
| 56 |
+
return sequence_indices[b, q_idx] == sequence_indices[b, kv_idx]
|
| 57 |
+
|
| 58 |
+
return document_mask
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_non_left_pad_mask_mod(batch: T, pad_id: int) -> _mask_mod_signature:
|
| 62 |
+
"""Prevent model from attending to the left-padded token required for correct batch inference."""
|
| 63 |
+
|
| 64 |
+
non_pad_mask_id = torch.cumsum(batch != pad_id, dim=1)
|
| 65 |
+
|
| 66 |
+
# Left-most pad tokens have cumulative id == 0.
|
| 67 |
+
def mask_mod(b, h, q_idx, kv_idx):
|
| 68 |
+
return non_pad_mask_id[b, kv_idx] > 0
|
| 69 |
+
|
| 70 |
+
return mask_mod
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def get_image_prefix_mask_mod(
|
| 74 |
+
batch: T, soi_id: int, eoi_id: int
|
| 75 |
+
) -> _mask_mod_signature:
|
| 76 |
+
"""Image-prefix mask: tokens between SOI and EOI attend only within same image."""
|
| 77 |
+
soi_mask = batch == soi_id
|
| 78 |
+
eoi_mask = batch == eoi_id
|
| 79 |
+
acc_soi_mask = torch.cumsum(soi_mask, dim=1)
|
| 80 |
+
acc_eoi_mask = torch.cumsum(eoi_mask, dim=1)
|
| 81 |
+
img_mask = (acc_soi_mask - acc_eoi_mask) > 0
|
| 82 |
+
img_indices = acc_soi_mask * img_mask
|
| 83 |
+
|
| 84 |
+
def image_prefix_mask_mod(b, h, q_idx, kv_idx):
|
| 85 |
+
is_img_tokens = img_mask[b, q_idx] & img_mask[b, kv_idx]
|
| 86 |
+
is_same_image = img_indices[b, q_idx] == img_indices[b, kv_idx]
|
| 87 |
+
return is_img_tokens & is_same_image
|
| 88 |
+
|
| 89 |
+
return image_prefix_mask_mod
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# Compiled once at import with dynamic=True so a single graph serves every
# (B, S) mask shape.
# reduce-overhead mode breaks manual CUDA graph capture (private streams)
_compiled_create_block_mask = torch.compile(
    create_block_mask, dynamic=True
)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@torch.inference_mode()
def create_attention_mask(*args, **kwargs) -> BlockMask:
    """Compiled for large masks; inference_mode avoids grad_mode recompiles.

    Thin forwarding wrapper: all arguments are passed unchanged to the
    compiled ``create_block_mask`` (mask_mod, B, H, Q_LEN, KV_LEN, ...).
    """
    return _compiled_create_block_mask(*args, **kwargs)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def create_batch_attention_mask(
    input_batch: T,
    *,
    pad_token_id: int,
    eos_token_id: int,
    soi_token_id: int,
    eoi_token_id: int,
    max_len: int | None = None,
) -> BlockMask:
    """Build the combined FlexAttention mask for the batch engine.

    Composes causal + document + non-left-pad + image-prefix masks:
    text tokens follow block-causal rules, while image tokens attend
    bidirectionally within their own image.
    """
    batch_size, seq_len = input_batch.size()
    text_mask_mod = and_masks(
        get_causal_mask_mod(),
        get_document_mask_mod(input_batch, eos_token_id),
        get_non_left_pad_mask_mod(input_batch, pad_token_id),
    )
    image_mask_mod = get_image_prefix_mask_mod(
        batch=input_batch,
        soi_id=soi_token_id,
        eoi_id=eoi_token_id,
    )
    # Image positions are allowed by EITHER rule set.
    combined_mask_mod = or_masks(image_mask_mod, text_mask_mod)
    effective_len = max_len or seq_len
    return create_attention_mask(combined_mask_mod, batch_size, None, effective_len, effective_len)
|
config.json
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"FalconOCRForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "configuration_falcon_ocr.FalconOCRConfig",
|
| 7 |
+
"AutoModelForCausalLM": "modeling_falcon_ocr.FalconOCRForCausalLM"
|
| 8 |
+
},
|
| 9 |
+
"model_type": "falcon_ocr",
|
| 10 |
+
"torch_dtype": "float32",
|
| 11 |
+
"dim": 768,
|
| 12 |
+
"n_layers": 22,
|
| 13 |
+
"n_heads": 16,
|
| 14 |
+
"head_dim": 64,
|
| 15 |
+
"n_kv_heads": 8,
|
| 16 |
+
"vocab_size": 65536,
|
| 17 |
+
"ffn_dim": 2304,
|
| 18 |
+
"norm_eps": 1e-05,
|
| 19 |
+
"max_seq_len": 8192,
|
| 20 |
+
"rope_theta": 10000,
|
| 21 |
+
"channel_size": 3,
|
| 22 |
+
"spatial_patch_size": 16,
|
| 23 |
+
"temporal_patch_size": 1,
|
| 24 |
+
"eos_id": 11,
|
| 25 |
+
"img_id": 227,
|
| 26 |
+
"image_cls_token_id": 244,
|
| 27 |
+
"image_reg_1_token_id": 245,
|
| 28 |
+
"image_reg_2_token_id": 246,
|
| 29 |
+
"image_reg_3_token_id": 247,
|
| 30 |
+
"image_reg_4_token_id": 248,
|
| 31 |
+
"img_end_id": 230
|
| 32 |
+
}
|
configuration_falcon_ocr.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class FalconOCRConfig(PretrainedConfig):
    """Configuration for the Falcon-OCR causal LM.

    Groups three kinds of settings:

    * transformer hyperparameters (``dim`` .. ``rope_theta``),
    * image patchification (``channel_size``, ``spatial_patch_size``,
      ``temporal_patch_size``),
    * special token ids marking text/image/video structure in the
      token stream.
    """

    model_type = "falcon_ocr"

    def __init__(
        self,
        # --- transformer hyperparameters ---
        dim: int = 768,
        n_layers: int = 22,
        n_heads: int = 16,
        head_dim: int = 64,
        n_kv_heads: int = 8,
        vocab_size: int = 65536,
        ffn_dim: int = 2304,
        norm_eps: float = 1e-5,
        max_seq_len: int = 8192,
        rope_theta: int = 10000,
        # --- image patchification ---
        channel_size: int = 3,
        spatial_patch_size: int = 16,
        temporal_patch_size: int = 1,
        # --- special token ids ---
        img_id: int = 227,
        eos_id: int = 11,
        image_cls_token_id: int = 244,
        image_mask_token_id: int = 243,
        image_reg_1_token_id: int = 245,
        image_reg_2_token_id: int = 246,
        image_reg_3_token_id: int = 247,
        image_reg_4_token_id: int = 248,
        img_start_id: int = 229,
        img_end_id: int = 230,
        img_row_sep_id: int = 228,
        vid_start_id: int = 231,
        vid_end_id: int = 232,
        frame_sep_id: int = 233,
        **kwargs,
    ):
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        self.ffn_dim = ffn_dim
        self.norm_eps = norm_eps
        self.max_seq_len = max_seq_len
        self.rope_theta = rope_theta
        self.channel_size = channel_size
        self.spatial_patch_size = spatial_patch_size
        self.temporal_patch_size = temporal_patch_size
        self.img_id = img_id
        self.eos_id = eos_id
        self.image_cls_token_id = image_cls_token_id
        self.image_mask_token_id = image_mask_token_id
        self.image_reg_1_token_id = image_reg_1_token_id
        self.image_reg_2_token_id = image_reg_2_token_id
        self.image_reg_3_token_id = image_reg_3_token_id
        self.image_reg_4_token_id = image_reg_4_token_id
        self.img_start_id = img_start_id
        self.img_end_id = img_end_id
        self.img_row_sep_id = img_row_sep_id
        self.vid_start_id = vid_start_id
        self.vid_end_id = vid_end_id
        self.frame_sep_id = frame_sep_id
        # Remaining kwargs (torch_dtype, architectures, ...) go to the base class.
        super().__init__(**kwargs)
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6e7f73a508c050f3a4f0b5ce196dba36db896626e44c4f8976d3a3e3c18ceb2a
|
| 3 |
+
size 1079789440
|
model_args.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"channel_size": 3,
|
| 3 |
+
"coord_dec_dim": 8192,
|
| 4 |
+
"coord_enc_dim": 512,
|
| 5 |
+
"coord_out_dim": 2048,
|
| 6 |
+
"coord_token_id": 240,
|
| 7 |
+
"dim": 768,
|
| 8 |
+
"eos_id": 11,
|
| 9 |
+
"ffn_dim": 2304,
|
| 10 |
+
"head_dim": 64,
|
| 11 |
+
"image_cls_token_id": 244,
|
| 12 |
+
"image_reg_1_token_id": 245,
|
| 13 |
+
"image_reg_2_token_id": 246,
|
| 14 |
+
"image_reg_3_token_id": 247,
|
| 15 |
+
"image_reg_4_token_id": 248,
|
| 16 |
+
"img_end_id": 230,
|
| 17 |
+
"img_id": 227,
|
| 18 |
+
"img_row_sep_id": 228,
|
| 19 |
+
"img_start_id": 229,
|
| 20 |
+
"max_seq_len": 8192,
|
| 21 |
+
"n_heads": 16,
|
| 22 |
+
"n_kv_heads": 8,
|
| 23 |
+
"n_layers": 22,
|
| 24 |
+
"norm_eps": 1e-05,
|
| 25 |
+
"num_segm_layers": 3,
|
| 26 |
+
"perception_heads": false,
|
| 27 |
+
"rope_theta": 10000,
|
| 28 |
+
"seg_token_id": 262,
|
| 29 |
+
"segm_out_dim": 256,
|
| 30 |
+
"size_dec_dim": 8192,
|
| 31 |
+
"size_enc_dim": 512,
|
| 32 |
+
"size_out_dim": 2048,
|
| 33 |
+
"size_token_id": 241,
|
| 34 |
+
"spatial_patch_size": 16,
|
| 35 |
+
"temporal_patch_size": 1,
|
| 36 |
+
"vocab_size": 65536
|
| 37 |
+
}
|
modeling_falcon_ocr.py
ADDED
|
@@ -0,0 +1,845 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
import einops as E
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import triton
|
| 7 |
+
import triton.language as tl
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from torch import Tensor as T
|
| 10 |
+
from torch import nn
|
| 11 |
+
from torch.nn.attention.flex_attention import (
|
| 12 |
+
AuxRequest,
|
| 13 |
+
BlockMask,
|
| 14 |
+
)
|
| 15 |
+
from transformers import AutoTokenizer, PreTrainedModel
|
| 16 |
+
|
| 17 |
+
from .attention import (
|
| 18 |
+
compiled_flex_attn_decode,
|
| 19 |
+
compiled_flex_attn_prefill,
|
| 20 |
+
create_batch_attention_mask,
|
| 21 |
+
offset_mask_mod,
|
| 22 |
+
)
|
| 23 |
+
from .configuration_falcon_ocr import FalconOCRConfig
|
| 24 |
+
from .processing_falcon_ocr import load_image, process_batch
|
| 25 |
+
from .rope import (
|
| 26 |
+
apply_3d_rotary_emb,
|
| 27 |
+
apply_golden_freqs_cis_to_visual_pos,
|
| 28 |
+
precompute_freqs_cis,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Per-category OCR instruction prompts. "plain" is an alias of "text";
# every other category names itself in its prompt.
_OCR_CATEGORIES = (
    "plain",
    "formula",
    "table",
    "text",
    "caption",
    "footnote",
    "list-item",
    "page-footer",
    "page-header",
    "section-header",
    "title",
)
CATEGORY_PROMPTS = {
    category: f"Extract the {'text' if category == 'plain' else category} content from this image."
    for category in _OCR_CATEGORIES
}
|
| 45 |
+
|
| 46 |
+
# Maps layout-detector class labels to the OCR prompt category used with
# CATEGORY_PROMPTS. A value of None marks region types with no extractable
# text (pure graphics) that should be skipped entirely.
LAYOUT_TO_OCR_CATEGORY: dict[str, str | None] = {
    "text": "text",
    "table": "table",
    "formula": "formula",
    "caption": "caption",
    "footnote": "footnote",
    "list-item": "list-item",
    "title": "title",
    "header": "text",
    "footer": "page-footer",
    "number": "text",
    "figure_title": "caption",
    "paragraph_title": "section-header",
    "doc_title": "title",
    "reference_content": "text",
    "reference": "text",
    "abstract": "text",
    "aside_text": "text",
    "content": "text",
    "formula_number": "text",
    "vision_footnote": "footnote",
    "algorithm": "text",
    "page-footer": "page-footer",
    "page-header": "page-header",
    "section-header": "section-header",
    # Skip — no text to extract
    "image": None,
    "picture": None,
    "figure": None,
    "chart": None,
    "seal": None,
}
|
| 78 |
+
|
| 79 |
+
# Target (height, width) for the layout stage — presumably the resolution the
# layout model expects; TODO confirm at the use site (not visible here).
_LAYOUT_TARGET_H, _LAYOUT_TARGET_W = 800, 800
# Minimum crop dimension in pixels — NOTE(review): looks like a guard against
# degenerate/too-small crops; verify against the cropping code.
_MIN_CROP_DIM = 16
|
| 81 |
+
|
| 82 |
+
def _box_area(bbox):
|
| 83 |
+
return max(0, bbox[2] - bbox[0]) * max(0, bbox[3] - bbox[1])
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _intersection_area(a, b):
|
| 87 |
+
return max(0, min(a[2], b[2]) - max(a[0], b[0])) * max(0, min(a[3], b[3]) - max(a[1], b[1]))
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _containment_ratio(small, large):
    """Fraction of *small*'s area that lies inside *large* (0.0 for degenerate boxes)."""
    small_area = _box_area(small)
    if small_area <= 0:
        # Avoid division by zero for empty/inverted boxes.
        return 0.0
    return _intersection_area(small, large) / small_area
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _filter_nested_detections(detections: list[dict], containment_threshold: float = 0.8) -> list[dict]:
    """Remove any box that is mostly contained within a strictly larger box."""
    areas = [_box_area(det["bbox"]) for det in detections]

    def _is_nested(i: int) -> bool:
        # Nested = a strictly larger box covers more than the threshold
        # fraction of this box's area.
        bbox = detections[i]["bbox"]
        for j, other in enumerate(detections):
            if j == i or areas[j] <= areas[i]:
                continue
            if _containment_ratio(bbox, other["bbox"]) > containment_threshold:
                return True
        return False

    return [det for i, det in enumerate(detections) if not _is_nested(i)]
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# Attention
|
| 117 |
+
|
| 118 |
+
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head *n_rep* times: [B, S, H, D] -> [B, S, H * n_rep, D]."""
    batch, seq, heads, dim = x.shape
    if n_rep == 1:
        # MHA case: nothing to expand.
        return x
    # Broadcast a repeat axis next to the head axis, then fold it in.
    expanded = x.unsqueeze(3).expand(batch, seq, heads, n_rep, dim)
    return expanded.reshape(batch, seq, heads * n_rep, dim)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class Attention(nn.Module):
    """Grouped-query attention with QK-RMSNorm, rotary embeddings,
    FlexAttention, and learned per-head attention sinks.

    NOTE(review): KV heads are expanded to full head count up-front in
    `_pre_attention_qkv`, so whatever cache stores them holds the expanded
    heads — presumably a shape-stability tradeoff; confirm against the
    cache sizing.
    """

    def __init__(self, config: FalconOCRConfig, layer_id: int):
        super().__init__()
        self.layer_id = layer_id
        # Fall back to MHA when n_kv_heads is unset/zero.
        self.n_kv_heads = config.n_kv_heads or config.n_heads
        # How many query heads share each KV head.
        self.n_rep = config.n_heads // self.n_kv_heads
        self.head_dim = config.head_dim or config.dim // config.n_heads
        self.q_dim = config.n_heads * self.head_dim
        self.kv_dim = self.n_kv_heads * self.head_dim

        # Fused QKV projection; split apart in `_pre_attention_qkv`.
        self.wqkv = nn.Linear(config.dim, self.q_dim + 2 * self.kv_dim, bias=False)
        self.wo = nn.Linear(config.n_heads * self.head_dim, config.dim, bias=False)
        # One learned sink logit per head, applied in `_post_attention`.
        self.sinks = nn.Parameter(torch.empty((config.n_heads,)))

    def _pre_attention_qkv(self, x) -> tuple[T, T, T]:
        """Pre-norm, fused QKV projection, head split, QK-norm, KV expansion."""
        qkv = self.wqkv(F.rms_norm(x, (x.size(-1),)))
        xq, xk, xv = qkv.split([self.q_dim, self.kv_dim, self.kv_dim], dim=-1)
        xq = E.rearrange(xq, "b s (h d) -> b s h d", d=self.head_dim)
        xk = E.rearrange(xk, "b s (h d) -> b s h d", d=self.head_dim)
        xv = E.rearrange(xv, "b s (h d) -> b s h d", d=self.head_dim)
        # Per-head query/key normalization (no learned scale).
        xq = F.rms_norm(xq, (xq.size(-1),))
        xk = F.rms_norm(xk, (xk.size(-1),))
        # Expand KV heads so attention sees n_heads KV heads.
        xk = repeat_kv(xk, n_rep=self.n_rep)
        xv = repeat_kv(xv, n_rep=self.n_rep)
        return xq, xk, xv

    def _post_attention(self, output: T, lse: T) -> T:
        # Sink-based scaling: sigmoid(lse - sinks) * output
        # equivalent to prepending a sink token to the input
        sinks_BHS = self.sinks.view(1, -1, 1)
        sink_scale = torch.sigmoid(lse - sinks_BHS)
        output = (output * sink_scale.unsqueeze(-1)).to(output.dtype)
        # [B, H, S, D] -> [B, S, H*D] ahead of the output projection.
        output = output.permute(0, 2, 1, 3).contiguous().flatten(2)
        return self.wo(output)

    def compile_attention(self, *, dynamic: bool = True, mode: str = "default"):
        """Compile the pre/post stages; the flex kernel itself is compiled globally."""
        self._pre_attention_qkv = torch.compile(self._pre_attention_qkv, dynamic=dynamic, mode=mode)
        self._post_attention = torch.compile(self._post_attention, dynamic=dynamic, mode=mode)

    def forward(
        self, x: T, attention_masks: BlockMask, freqs_cis: T,
        freqs_cis_2d: T | None = None, pos_hw: T | None = None,
        kv_cache=None, input_pos=None, batch_idx=None,
        flex_attn_kernel_options=None,
    ):
        """Attention step: QKV, RoPE, KV-cache insert, flex attention, sink-scaled output.

        NOTE(review): `flex_attn_kernel_options` is accepted but never used
        in this body.
        """
        xq, xk, xv = self._pre_attention_qkv(x)
        xq, xk = apply_3d_rotary_emb(xq, xk, freqs_cis, freqs_cis_2d, pos_hw)
        xq = E.rearrange(xq, "b s h d -> b h s d")
        xk = E.rearrange(xk, "b s h d -> b h s d")
        xv = E.rearrange(xv, "b s h d -> b h s d")
        xk, xv = kv_cache.insert_kv(self.layer_id, xk, xv, input_pos=input_pos, batch_idx=batch_idx)
        # S_q == 1 -> decode (static shapes, CUDA-graph safe); otherwise prefill.
        flex_fn = compiled_flex_attn_decode if xq.shape[2] == 1 else compiled_flex_attn_prefill
        output, aux_output = flex_fn(xq, xk, xv, block_mask=attention_masks, return_aux=AuxRequest(lse=True))
        return self._post_attention(output, aux_output.lse)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# FeedForward
|
| 182 |
+
|
| 183 |
+
# Elementwise Triton kernel: out[r, c] = relu(packed[r, 2c])^2 * packed[r, 2c+1].
# Input columns are interleaved [gate, up, gate, up, ...]; output has half the
# columns of the input.
@triton.jit
def _squared_relu_gate_kernel(
    packed_ptr, out_ptr, n_rows, n_cols,
    in_row_stride, in_col_stride, out_row_stride, out_col_stride,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program instance covers one contiguous BLOCK_SIZE slice of the
    # flattened (row-major) output index space.
    pid = tl.program_id(0)
    n_elements = n_rows * n_cols
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Guard against the ragged final block.
    mask = offsets < n_elements
    rows = offsets // n_cols
    cols = offsets % n_cols
    # Gate sits at even input columns, its paired "up" value at odd ones.
    gate_idx = rows * in_row_stride + (2 * cols) * in_col_stride
    up_idx = rows * in_row_stride + (2 * cols + 1) * in_col_stride
    out_idx = rows * out_row_stride + cols * out_col_stride
    gate = tl.load(packed_ptr + gate_idx, mask=mask)
    up = tl.load(packed_ptr + up_idx, mask=mask)
    # ReLU, then square the gate and multiply by up.
    gate = tl.where(gate > 0, gate, 0.0)
    out = gate * gate * up
    tl.store(out_ptr + out_idx, out, mask=mask)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def squared_relu_gate(packed: T, hidden_dim: int) -> T:
    """Processes interleaved [gate, up, gate, up, ...] from w13; output = ReLU(gate)^2 * up."""
    # Collapse all leading dims so the kernel operates on a 2D view.
    packed_2d = packed.flatten(0, -2)
    n_rows = packed_2d.shape[0]
    n_cols = hidden_dim
    out_2d = torch.empty((n_rows, n_cols), device=packed.device, dtype=packed.dtype)
    n = n_rows * n_cols
    # One program instance per BLOCK_SIZE elements of the flattened output.
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    _squared_relu_gate_kernel[grid](
        packed_2d, out_2d, n_rows, n_cols,
        packed_2d.stride(0), packed_2d.stride(1),
        out_2d.stride(0), out_2d.stride(1),
        BLOCK_SIZE=1024,
    )
    # Restore the original leading dims; feature dim is halved to hidden_dim.
    return out_2d.view(*packed.shape[:-1], hidden_dim)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class FeedForward(nn.Module):
    """Pre-norm gated MLP with squared-ReLU activation: w2(ReLU(gate)^2 * up).

    `w13` packs the gate and up projections with interleaved output columns
    ([gate, up, gate, up, ...]); `squared_relu_gate` de-interleaves and
    applies the gating in a single Triton kernel.
    """

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        # 2 * hidden_dim: interleaved gate/up columns from one matmul.
        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.hidden_dim = hidden_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-normalization (no learned scale), matching the attention path.
        x = F.rms_norm(x, (x.size(-1),))
        w13_out = self.w13(x)
        return self.w2(squared_relu_gate(w13_out, self.hidden_dim))
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
# TransformerBlock
|
| 236 |
+
|
| 237 |
+
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: residual attention then residual MLP.

    Normalization lives inside `Attention` / `FeedForward`; this block only
    wires the residual connections and forwards generation state.
    """

    def __init__(self, layer_id: int, config: FalconOCRConfig):
        super().__init__()
        self.attention = Attention(config, layer_id)
        self.feed_forward = FeedForward(config.dim, config.ffn_dim)

    def compile(self, *, dynamic: bool = True, mode: str = "default"):
        # NOTE(review): this shadows nn.Module.compile (torch >= 2.0) with a
        # different signature/return — confirm the shadowing is intentional.
        self.feed_forward = torch.compile(self.feed_forward, dynamic=dynamic, mode=mode)
        self.attention.compile_attention(dynamic=dynamic, mode=mode)
        return self

    def forward(
        self, x: T, freqs_cis: T, freqs_cis_2d: T | None = None,
        pos_hw: T | None = None, attention_masks=None, kv_cache=None,
        input_pos=None, batch_idx=None, flex_attn_kernel_options=None,
    ):
        B, S, D = x.shape
        x = x + self.attention(
            x, freqs_cis=freqs_cis, freqs_cis_2d=freqs_cis_2d, pos_hw=pos_hw,
            attention_masks=attention_masks, kv_cache=kv_cache,
            input_pos=input_pos, batch_idx=batch_idx,
            flex_attn_kernel_options=flex_attn_kernel_options,
        )
        out = x + self.feed_forward(x)
        # Explicitly re-assert the (B, S, D) shape contract on the way out.
        return out.reshape(B, S, D)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
# KV Cache
|
| 265 |
+
|
| 266 |
+
class KVCache:
    """Lazily-allocated KV cache shared by every transformer layer.

    Backing layout is (num_layers, 2, B, H, S, D); index 0/1 on dim 1 holds
    keys/values. The write cursor `pos` advances only after the LAST layer
    inserts, so all layers of one forward step write into the same slot range.
    `pos_t` is an externally-managed per-step position (set before decoding).
    """

    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, num_layers):
        self.kv_shape = (num_layers, 2, max_batch_size, n_heads, max_seq_length, head_dim)
        # Allocated on first insert, taking dtype/device from the incoming keys.
        self.kv_cache = None
        self.pos = 0
        self.pos_t: T | None = None

    def reset(self):
        """Rewind the cursor; the backing tensor is kept for reuse."""
        self.pos = 0
        self.pos_t = None

    def get_pos(self):
        return self.pos

    def set_pos_t(self, pos_t):
        self.pos_t = pos_t

    def increment_and_get_pos_t(self):
        assert self.pos_t is not None
        self.pos_t += 1
        return self.pos_t

    def insert_kv(self, layer_id: int, k: T, v: T, **kwargs):
        # Extra engine kwargs (input_pos, batch_idx, ...) are ignored here.
        del kwargs
        assert self.pos_t is not None
        if self.kv_cache is None:
            self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
        _batch, _heads, n_new, _dim = k.size()
        start, stop = self.pos, self.pos + n_new
        self.kv_cache[layer_id, 0, :, :, start:stop] = k
        self.kv_cache[layer_id, 1, :, :, start:stop] = v
        keys = self.kv_cache[layer_id, 0, :, :, :stop]
        values = self.kv_cache[layer_id, 1, :, :, :stop]
        if layer_id == self.kv_cache.size(0) - 1:
            # Last layer of the step: commit the cursor for the next step.
            self.pos = stop
        return keys, values
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
# Sampling
|
| 305 |
+
|
| 306 |
+
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=0.0, top_k=None):
    """Sample one token id per row: greedy at temperature 0, otherwise
    temperature-scaled softmax sampling, optionally restricted to top_k."""
    assert temperature >= 0.0
    if temperature == 0.0:
        # Greedy decoding.
        return torch.argmax(logits, dim=-1, keepdim=True)
    if top_k is None:
        probs = F.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1, generator=rng)
    # Restrict to the k highest-scoring tokens, then sample among them and
    # map back to vocabulary indices.
    k = min(top_k, logits.size(-1))
    top_vals, top_idx = torch.topk(logits, k, dim=-1)
    probs = F.softmax(top_vals / temperature, dim=-1)
    picked = torch.multinomial(probs, num_samples=1, generator=rng)
    return top_idx.gather(1, picked)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
# Main Model
|
| 324 |
+
|
| 325 |
+
class FalconOCRForCausalLM(PreTrainedModel):
    """Falcon-OCR vision-language model for document OCR.

    Combines a linear image-patch projector with a decoder-only transformer
    (flex-attention BlockMask + external KVCache) and exposes two user-facing
    entry points: `generate` (whole-image OCR) and `generate_with_layout`
    (PP-DocLayoutV3 region detection followed by per-crop OCR).
    """
    config_class = FalconOCRConfig
    _no_split_modules = ["TransformerBlock"]

    def __init__(self, config: FalconOCRConfig):
        """Build embeddings, transformer layers, output head and RoPE buffers."""
        super().__init__(config)
        # Flattened patch dim: temporal * spatial^2 patches * channels.
        img_in_dim = config.temporal_patch_size * config.spatial_patch_size ** 2 * config.channel_size
        self.img_projector = nn.Linear(img_in_dim, config.dim, bias=False)
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)

        # ModuleDict keyed by stringified layer index (HF sharding-friendly).
        self.layers = nn.ModuleDict()
        for layer_id in range(config.n_layers):
            self.layers[str(layer_id)] = TransformerBlock(layer_id, config)

        self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        # Temporal RoPE table is recomputable (non-persistent); the "golden"
        # 2D visual-position table is learned/loaded (persistent).
        rope_dim = config.head_dim // 2
        freqs_cis = precompute_freqs_cis(rope_dim, config.max_seq_len, config.rope_theta)
        freqs_cis_golden = torch.empty((config.n_heads, rope_dim // 2, 2), dtype=torch.float)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
        self.register_buffer("freqs_cis_golden", freqs_cis_golden, persistent=True)

        self._weights_fused = False
        self._is_compiled = False

        self.post_init()

    # Weight management

    def _ensure_device_buffers(self):
        """Recompute non-persistent buffers that HF meta-device loading may discard."""
        if self._weights_fused:
            return
        device = self.tok_embeddings.weight.device
        c = self.config
        rope_dim = c.head_dim // 2
        freqs_cis = precompute_freqs_cis(rope_dim, c.max_seq_len, c.rope_theta).to(device)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
        if self.freqs_cis_golden.device != device:
            self.freqs_cis_golden = self.freqs_cis_golden.to(device)
        # Reuses the _weights_fused flag as a "buffers ready" latch.
        self._weights_fused = True

    def compile_model(self):
        """torch.compile every transformer layer once (idempotent)."""
        if self._is_compiled:
            return
        # CUDA graphs interact badly with the dynamic decode loop here.
        torch._inductor.config.triton.cudagraphs = False
        for layer in self.layers.values():
            layer.compile(dynamic=True, mode="default")
        self._is_compiled = True

    # Tokenizer

    def _get_tokenizer(self):
        """Lazily load and memoize the tokenizer from the model's own path.

        Also mirrors every special token and its id as attributes on the
        tokenizer object for convenient access (e.g. `tok.eos_token_id`).
        """
        if not hasattr(self, "_tokenizer"):
            import os
            path = self.config._name_or_path
            is_local = os.path.exists(path)
            self._tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=is_local, trust_remote_code=True)
            for token_name, token in self._tokenizer.special_tokens_map.items():
                if isinstance(token, str):
                    setattr(self._tokenizer, token_name, token)
                    setattr(
                        self._tokenizer, token_name + "_id",
                        self._tokenizer.convert_tokens_to_ids(token),
                    )
        return self._tokenizer

    # Attention mask

    def get_attention_mask(self, input_batch: T, max_len: int | None = None):
        """Build the batch BlockMask (causal text + bidirectional image spans).

        NOTE(review): relies on `self._pad_token_id`, which is only assigned
        inside `_generate_batch` — callers must go through that path first.
        """
        return create_batch_attention_mask(
            input_batch,
            pad_token_id=self._pad_token_id,
            eos_token_id=self.config.eos_id,
            soi_token_id=self.config.image_cls_token_id,
            eoi_token_id=self.config.img_end_id,
            max_len=max_len,
        )

    # Embedding helpers

    def _scatter_img_tokens_with_projector(self, h_BSD, pixel_patches_NLC, pixel_masks_NTHW, tokens_BS):
        """Project valid pixel patches and scatter them over `<img>` token slots.

        The asserted invariant: the number of projected patch features must
        exactly match the number of image-token positions in the batch.
        """
        B, S, D = h_BSD.shape
        # A patch is valid if any pixel inside its (pt, ph, pw) cell is unmasked.
        pixel_patch_mask = E.reduce(
            pixel_masks_NTHW,
            "n (t pt) (h ph) (w pw) -> (n t h w)",
            reduction="any",
            pt=self.config.temporal_patch_size,
            ph=self.config.spatial_patch_size,
            pw=self.config.spatial_patch_size,
        )
        pixel_patches_flat = E.rearrange(pixel_patches_NLC, "n p c -> (n p) c")
        valid_patches = pixel_patches_flat[pixel_patch_mask]
        valid_feats = self.img_projector(valid_patches)
        img_mask_h_BSD = E.repeat(tokens_BS == self.config.img_id, "b s -> b s d", d=D)
        assert valid_feats.numel() == img_mask_h_BSD.sum()
        return torch.masked_scatter(h_BSD, img_mask_h_BSD, valid_feats)

    # Core forward

    def forward(
        self,
        tokens: T,
        attention_mask: BlockMask,
        kv_cache,
        rope_pos_t: T | None = None,
        rope_pos_hw: T | None = None,
        pixel_values: T | None = None,
        pixel_mask: T | None = None,
    ):
        """One forward step: prefill (S > 1) or single-token decode (S == 1).

        Prefill slices RoPE positions from the provided tables and stores the
        last temporal position in the cache; decode advances that stored
        position and carves the single relevant BlockMask column out of the
        prefill mask, offsetting its mask_mod by the cache length.
        Returns logits of shape (B, S, vocab).
        """
        B, S = tokens.size()
        c = self.config
        block_mask = attention_mask

        T_pos = kv_cache.get_pos()
        is_prefill = S != 1

        if is_prefill:
            assert rope_pos_t is not None and rope_pos_hw is not None
            pos_t = rope_pos_t[:, T_pos:T_pos + S].long()
            # Remember the last temporal position so decode can continue from it.
            kv_cache.pos_t = pos_t[:, -1:]
            freqs_cis = self.freqs_cis[pos_t]
            rope_pos_hw = rope_pos_hw[:, T_pos:T_pos + S]
            freqs_cis_golden = apply_golden_freqs_cis_to_visual_pos(self.freqs_cis_golden, rope_pos_hw)
            block_mask.seq_lengths = (S, S)
        else:
            pos_t = kv_cache.increment_and_get_pos_t()
            freqs_cis = self.freqs_cis[pos_t]
            # No 2D visual RoPE during decode: only text tokens are generated.
            freqs_cis_golden = None
            block_idx = T_pos // block_mask.BLOCK_SIZE[0]
            block_mask = block_mask[:, :, block_idx]
            block_mask.seq_lengths = (S, T_pos + S)
            block_mask.mask_mod = offset_mask_mod(attention_mask.mask_mod, offset=T_pos)

        h_BSD = self.tok_embeddings(tokens)

        if pixel_values is not None:
            assert pixel_mask is not None
            pixel_values = pixel_values.to(self.dtype)
            pixel_mask = pixel_mask.to(self.dtype)
            # Fold (t, h, w) patch grids into a flat patch sequence per image.
            pixel_patches_NLC = E.rearrange(
                pixel_values,
                "n (t pt) (h ph) (w pw) c -> n (t h w) (pt ph pw c)",
                pt=c.temporal_patch_size, ph=c.spatial_patch_size, pw=c.spatial_patch_size,
            )
            h_BSD = self._scatter_img_tokens_with_projector(h_BSD, pixel_patches_NLC, pixel_mask, tokens)

        for layer in self.layers.values():
            h_BSD = layer(
                h_BSD, freqs_cis=freqs_cis, freqs_cis_2d=freqs_cis_golden,
                pos_hw=rope_pos_hw, attention_masks=block_mask, kv_cache=kv_cache,
            )

        h_BSD = self.norm(h_BSD)
        logits_BSV = self.output(h_BSD)
        return logits_BSV

    # Layout detection

    def _load_layout_model(self, layout_model: str = "PaddlePaddle/PP-DocLayoutV3_safetensors"):
        """Lazily load the PP-DocLayoutV3 detector + processor (fp16, eval)."""
        if hasattr(self, "_layout_model"):
            return
        import torchvision.transforms.functional as tvF
        from transformers import AutoModelForObjectDetection, PPDocLayoutV3ImageProcessorFast

        self._layout_processor = PPDocLayoutV3ImageProcessorFast.from_pretrained(layout_model)
        self._layout_det_model = AutoModelForObjectDetection.from_pretrained(
            layout_model, torch_dtype=torch.float16,
        ).to(self.device).eval()
        self._layout_id2label = self._layout_det_model.config.id2label
        self._tvF = tvF

    @torch.inference_mode()
    def _run_layout_detection(
        self, images: list[Image.Image], threshold: float = 0.5,
    ) -> list[list[dict]]:
        """Run PP-DocLayoutV3 on a batch of PIL images, return per-image detections."""
        device = self.device
        tvF = self._tvF

        # PIL size is (w, h); target_sizes wants (h, w).
        target_sizes = torch.tensor([img.size[::-1] for img in images])
        tensors = [tvF.pil_to_tensor(img) for img in images]

        # GPU-accelerated resize + normalize
        result = torch.empty(
            len(tensors), 3, _LAYOUT_TARGET_H, _LAYOUT_TARGET_W,
            dtype=torch.float16, device=device,
        )
        # Group same-sized images so each interpolate call gets a stacked batch.
        size_groups: dict[tuple[int, int], list[int]] = {}
        for i, t in enumerate(tensors):
            size_groups.setdefault((t.shape[1], t.shape[2]), []).append(i)

        for shape, indices in size_groups.items():
            batch = torch.stack([tensors[i] for i in indices])
            batch = batch.to(device=device, dtype=torch.float32, non_blocking=True)
            batch = F.interpolate(
                batch, size=(_LAYOUT_TARGET_H, _LAYOUT_TARGET_W),
                mode="bicubic", align_corners=False, antialias=False,
            )
            # Bicubic can overshoot [0, 255]; clamp before scaling to [0, 1].
            batch = (batch.clamp_(0, 255) / 255.0).to(torch.float16)
            for j, idx in enumerate(indices):
                result[idx] = batch[j]
            del batch

        outputs = self._layout_det_model(pixel_values=result)
        del result

        # Postprocess on GPU
        logits = outputs.logits
        boxes = outputs.pred_boxes
        order_logits = outputs.order_logits

        # cxcywh (normalized) -> xyxy in pixel space.
        box_centers, box_dims = boxes.split(2, dim=-1)
        boxes_xyxy = torch.cat([box_centers - 0.5 * box_dims, box_centers + 0.5 * box_dims], dim=-1)

        img_h, img_w = target_sizes.unbind(1)
        scale = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(device, dtype=boxes_xyxy.dtype)
        boxes_xyxy = boxes_xyxy * scale[:, None, :]

        # Flatten (query, class) scores and take the top num_queries candidates.
        num_queries = logits.shape[1]
        num_classes = logits.shape[2]
        scores = logits.sigmoid()
        scores_flat, index = scores.flatten(1).topk(num_queries, dim=-1)
        labels = index % num_classes
        box_indices = index // num_classes
        boxes_xyxy = boxes_xyxy.gather(dim=1, index=box_indices.unsqueeze(-1).expand(-1, -1, 4))

        # Reading-order sequence, realigned to the selected boxes.
        order_seqs = self._layout_processor._get_order_seqs(order_logits)
        order_seqs = order_seqs.gather(dim=1, index=box_indices)

        batch_results = []
        for s, l, b, o in zip(scores_flat, labels, boxes_xyxy, order_seqs):
            mask = s >= threshold
            o_valid = o[mask]
            # Sort surviving detections by predicted reading order.
            _, indices_sorted = o_valid.sort()

            detections = []
            for si, li, bi in zip(s[mask][indices_sorted], l[mask][indices_sorted], b[mask][indices_sorted]):
                detections.append({
                    "category": self._layout_id2label[li.item()],
                    "bbox": [round(x, 2) for x in bi.tolist()],
                    "score": round(si.item(), 4),
                })
            batch_results.append(detections)

        return batch_results

    # Core batch decode (shared by generate & generate_with_layout)

    def _generate_batch(
        self,
        image_prompt_pairs: list[tuple],
        *,
        max_new_tokens: int,
        temperature: float,
        top_k: int | None,
        min_dimension: int,
        max_dimension: int,
        seed: int | None,
    ) -> list[str]:
        """Core autoregressive decode for a list of (image, prompt) pairs."""
        device = self.device
        tokenizer = self._get_tokenizer()
        # Side effect: get_attention_mask reads self._pad_token_id set here.
        self._pad_token_id = tokenizer.convert_tokens_to_ids("<|pad|>")
        stop_token_ids = [self.config.eos_id, tokenizer.convert_tokens_to_ids("<|end_of_query|>")]

        batch_inputs = process_batch(
            tokenizer, self.config, image_prompt_pairs,
            max_length=4096, min_dimension=min_dimension, max_dimension=max_dimension,
        )
        batch_inputs = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch_inputs.items()}

        tokens = batch_inputs["tokens"]
        B, L = tokens.size()
        # Round the total budget up to a multiple of the flex-attention block.
        block_size = 128
        S = (L + max_new_tokens + block_size - 1) // block_size * block_size
        assert S <= self.config.max_seq_len

        rng = torch.Generator(device).manual_seed(seed) if seed is not None else None

        kv_cache = KVCache(
            max_batch_size=B, max_seq_length=S, n_heads=self.config.n_heads,
            head_dim=self.config.head_dim, num_layers=self.config.n_layers,
        )

        padded_tokens = torch.full((B, S), self._pad_token_id, dtype=tokens.dtype, device=device)
        padded_tokens[:, :L] = tokens

        attention_mask = self.get_attention_mask(padded_tokens, max_len=S)

        # Prefill: process the whole prompt at once.
        logits_BSV = self.forward(
            tokens=tokens, rope_pos_t=batch_inputs["pos_t"], rope_pos_hw=batch_inputs["pos_hw"],
            attention_mask=attention_mask, kv_cache=kv_cache,
            pixel_values=batch_inputs["pixel_values"], pixel_mask=batch_inputs["pixel_mask"],
        )

        stop_ids = torch.tensor(stop_token_ids).to(device)
        should_stop_B = torch.full((B,), False, dtype=torch.bool, device=device)
        generated_ids: list[list[int]] = [[] for _ in range(B)]

        # Decode loop: one token per iteration until all rows stop or budget ends.
        while not torch.all(should_stop_B) and (pos := kv_cache.get_pos()) < S:
            tokens_B1 = sample_next_token(logits_BSV[:, -1], rng, temperature, top_k)

            if torch.any(should_stop_B):
                # Finished rows keep feeding pad so the batch stays aligned.
                tokens_B1 = tokens_B1.clone()
                tokens_B1[should_stop_B, :] = self._pad_token_id
            padded_tokens[:, pos] = tokens_B1[:, -1]

            for b in range(B):
                if not should_stop_B[b]:
                    generated_ids[b].append(tokens_B1[b, 0].item())

            logits_BSV = self.forward(
                tokens=tokens_B1, attention_mask=attention_mask, kv_cache=kv_cache,
            )

            hit_stop_B = torch.isin(tokens_B1, stop_ids).any(dim=-1)
            should_stop_B = should_stop_B.logical_or(hit_stop_B)

        results = []
        for b in range(B):
            text = tokenizer.decode(generated_ids[b], skip_special_tokens=False)
            text = text.replace("<|end_of_query|>", "").replace("<|end_of_text|>", "").strip()
            results.append(text)

        return results

    # Main API: generate

    @torch.inference_mode()
    def generate(
        self,
        images,
        *,
        category: str | list[str] = "plain",
        max_new_tokens: int = 4096,
        temperature: float = 0.0,
        top_k: int | None = None,
        min_dimension: int = 64,
        max_dimension: int = 1024,
        compile: bool = True,
        seed: int | None = 42,
    ) -> list[str]:
        """
        Extract text from document images.

        Args:
            images: Single PIL Image (or path/URL) or list of them.
            category: OCR category — one of "plain", "text", "table", "formula",
                "caption", "footnote", "list-item", "page-footer", "page-header",
                "section-header", "title". Can be a single string (applied to all
                images) or a list (one per image).
            max_new_tokens: Maximum generation steps.
            temperature: Sampling temperature (0.0 = greedy).
            top_k: Top-k sampling (None = disabled).
            min_dimension: Min image side after resize.
            max_dimension: Max image side after resize.
            compile: Whether to torch.compile on first call.
            seed: Random seed for reproducibility (None = non-deterministic).

        Returns:
            List of extracted text strings, one per image.
        """
        self._ensure_device_buffers()
        if compile:
            self.compile_model()

        if isinstance(images, (str, Path, Image.Image)):
            images = [images]
        if isinstance(category, str):
            category = [category] * len(images)
        assert len(images) == len(category), "Must provide one category per image"

        # Unknown categories silently fall back to the "plain" prompt.
        image_prompt_pairs = []
        for img, cat in zip(images, category):
            instruction = CATEGORY_PROMPTS.get(cat.strip().lower(), CATEGORY_PROMPTS["plain"])
            prompt = f"<|image|>{instruction}\n<|OCR_PLAIN|>"
            image_prompt_pairs.append((img, prompt))

        return self._generate_batch(
            image_prompt_pairs,
            max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k,
            min_dimension=min_dimension, max_dimension=max_dimension, seed=seed,
        )

    # Main API: generate_with_layout

    @torch.inference_mode()
    def generate_with_layout(
        self,
        images,
        *,
        max_new_tokens: int = 4096,
        temperature: float = 0.0,
        top_k: int | None = None,
        min_dimension: int = 64,
        max_dimension: int = 1024,
        compile: bool = True,
        seed: int | None = 42,
        layout_threshold: float = 0.3,
        layout_batch_size: int = 4,
        ocr_batch_size: int = 32,
        containment_threshold: float = 0.8,
        layout_model: str = "PaddlePaddle/PP-DocLayoutV3_safetensors",
    ) -> list[list[dict]]:
        """
        Run layout detection then OCR on each detected region.

        Args:
            images: Single PIL Image (or path/URL) or list of them.
            max_new_tokens: Maximum generation steps per crop.
            temperature: Sampling temperature (0.0 = greedy).
            top_k: Top-k sampling (None = disabled).
            min_dimension: Min crop side after resize for OCR.
            max_dimension: Max crop side after resize for OCR.
            compile: Whether to torch.compile on first call.
            seed: Random seed for reproducibility.
            layout_threshold: Confidence threshold for layout detections.
            layout_batch_size: Batch size for layout detection.
            ocr_batch_size: Batch size for OCR generation (chunks crops).
            containment_threshold: Drop formula boxes >threshold contained in text boxes.
            layout_model: HuggingFace model ID for layout detection.

        Returns:
            Per-image list of detections, each a dict with keys:
            ``category``, ``bbox`` [x1,y1,x2,y2], ``score``, ``text``.
        """
        self._ensure_device_buffers()
        if compile:
            self.compile_model()
        self._load_layout_model(layout_model)

        if isinstance(images, (str, Path, Image.Image)):
            images = [images]
        pil_images = [load_image(img).convert("RGB") for img in images]

        # --- Layout detection (batched) ---
        all_layout_dets: list[list[dict]] = []
        for i in range(0, len(pil_images), layout_batch_size):
            batch_imgs = pil_images[i : i + layout_batch_size]
            dets = self._run_layout_detection(batch_imgs, threshold=layout_threshold)
            all_layout_dets.extend(dets)

        # --- Filter nested boxes (e.g. inline formulas inside text) ---
        all_layout_dets = [
            _filter_nested_detections(dets, containment_threshold)
            for dets in all_layout_dets
        ]

        # --- Build crops + track origin ---
        flat_crops: list[tuple[Image.Image, str]] = []
        crop_origins: list[tuple[int, int]] = []  # (image_idx, det_idx)

        for img_idx, (pil_img, dets) in enumerate(zip(pil_images, all_layout_dets)):
            # Nothing detected (or only a full-page "image"): OCR the whole page.
            if not dets or (len(dets) == 1 and dets[0]["category"].strip().lower() == "image"):
                prompt = f"<|image|>{CATEGORY_PROMPTS['plain']}\n<|OCR_PLAIN|>"
                flat_crops.append((pil_img, prompt))
                crop_origins.append((img_idx, -1))
                continue

            img_w, img_h = pil_img.size
            for det_idx, det in enumerate(dets):
                cat_key = det["category"].strip().lower()
                ocr_cat = LAYOUT_TO_OCR_CATEGORY.get(cat_key)
                if ocr_cat is None:
                    # Layout class with no OCR mapping (e.g. pictures): skip.
                    continue

                # Clamp the box to image bounds; +0.5 rounds the far edge.
                x1, y1, x2, y2 = det["bbox"]
                x1 = max(0, int(x1))
                y1 = max(0, int(y1))
                x2 = min(img_w, int(x2 + 0.5))
                y2 = min(img_h, int(y2 + 0.5))
                cw, ch = x2 - x1, y2 - y1
                if cw < _MIN_CROP_DIM or ch < _MIN_CROP_DIM:
                    continue
                # Skip crops that would shrink below the minimum after resize.
                short, long = sorted((cw, ch))
                resized_short = short * (max_dimension / long) if long > max_dimension else short
                if resized_short < _MIN_CROP_DIM:
                    continue

                crop = pil_img.crop((x1, y1, x2, y2))
                instruction = CATEGORY_PROMPTS.get(ocr_cat, CATEGORY_PROMPTS["plain"])
                prompt = f"<|image|>{instruction}\n<|OCR_PLAIN|>"
                flat_crops.append((crop, prompt))
                crop_origins.append((img_idx, det_idx))

        # --- OCR in chunks ---
        flat_texts: list[str] = []
        for i in range(0, max(len(flat_crops), 1), ocr_batch_size):
            chunk = flat_crops[i : i + ocr_batch_size]
            if not chunk:
                break
            texts = self._generate_batch(
                chunk,
                max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k,
                min_dimension=min_dimension, max_dimension=max_dimension, seed=seed,
            )
            flat_texts.extend(texts)

        # --- Reassemble per-image results ---
        results: list[list[dict]] = [[] for _ in range(len(pil_images))]
        for (img_idx, det_idx), text in zip(crop_origins, flat_texts):
            if det_idx == -1:
                # Whole-page fallback entry spans the full image.
                img_w, img_h = pil_images[img_idx].size
                results[img_idx].append({
                    "category": "plain",
                    "bbox": [0, 0, img_w, img_h],
                    "score": 1.0,
                    "text": text,
                })
            else:
                det = all_layout_dets[img_idx][det_idx]
                results[img_idx].append({
                    "category": det["category"],
                    "bbox": det["bbox"],
                    "score": det["score"],
                    "text": text,
                })

        return results
|
processing_falcon_ocr.py
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import einops as E
|
| 5 |
+
import numpy as np
|
| 6 |
+
import requests
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from transformers.image_processing_utils import BaseImageProcessor
|
| 10 |
+
from transformers.image_transforms import convert_to_rgb, resize
|
| 11 |
+
from transformers.image_utils import (
|
| 12 |
+
ImageInput,
|
| 13 |
+
get_image_size,
|
| 14 |
+
infer_channel_dimension_format,
|
| 15 |
+
to_numpy_array,
|
| 16 |
+
valid_images,
|
| 17 |
+
validate_preprocess_arguments,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
IMAGE_MEAN = [0.5, 0.5, 0.5]
|
| 21 |
+
IMAGE_STD = [0.5, 0.5, 0.5]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def load_image(image):
    """Normalize an image input (PIL / URL / path / .npy / bytes / array) to a PIL Image."""
    if image is None:
        return None
    if isinstance(image, Image.Image):
        return image
    if isinstance(image, np.ndarray):
        return Image.fromarray(image)
    if isinstance(image, np.bytes_):
        return Image.open(io.BytesIO(image))
    if isinstance(image, str):
        if image.startswith(("http://", "https://")):
            # Remote image: fetch with a bounded timeout.
            resp = requests.get(image, timeout=10)
            resp.raise_for_status()
            return Image.open(io.BytesIO(resp.content))
        if image.endswith(".npy"):
            # Serialized bytes stored in a numpy file.
            return Image.open(io.BytesIO(np.load(image)))
        return Image.open(image)
    raise TypeError(f"Unknown image format {image}")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def load_images(images_input, min_dimension: int, max_dimension: int):
    """Load every entry of `images_input` and bound-resize it; None yields []."""
    if images_input is None:
        return []
    return [
        resize_image_if_necessary(load_image(entry), min_dimension, max_dimension)
        for entry in images_input
    ]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def resize_image_if_necessary(
    image,
    shortest_dimension=224,
    longest_dimension=896,
):
    """Return `image` unchanged if both sides lie within
    [shortest_dimension, longest_dimension]; otherwise resize so the
    limiting side hits the relevant bound while preserving aspect ratio,
    re-clamping if the derived side overshoots the long bound.
    """
    width, height = image.size
    aspect = width / height

    within_bounds = (
        shortest_dimension <= width <= longest_dimension
        and shortest_dimension <= height <= longest_dimension
    )
    if within_bounds:
        return image

    portrait = width < height
    undersized = width < shortest_dimension or height < shortest_dimension
    # Undersized images grow to the short bound; oversized shrink to the long one.
    target = shortest_dimension if undersized else longest_dimension
    if portrait:
        new_w = target
        new_h = int(new_w / aspect)
    else:
        new_h = target
        new_w = int(new_h * aspect)

    # The aspect-ratio math may push the other side past the long bound.
    if new_w > longest_dimension:
        new_w = longest_dimension
        new_h = int(new_w / aspect)
    if new_h > longest_dimension:
        new_h = longest_dimension
        new_w = int(new_h * aspect)

    return image.resize((new_w, new_h))
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def smart_resize(
    image,
    factor: int,
    resample,
    input_data_format,
    min_pixels: int = 56 * 56,
    max_pixels: int = 14 * 14 * 4 * 1280,
):
    """Resize so both sides are multiples of `factor` and the total pixel
    count lands inside [min_pixels, max_pixels], preserving aspect ratio.

    Raises ValueError when either side is smaller than `factor` or the
    aspect ratio exceeds 200.
    """
    height, width = get_image_size(image, channel_dim=input_data_format)
    if height < factor or width < factor:
        raise ValueError(f"{height=} or {width=} must be larger than {factor=}")
    if max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )
    # Snap both sides to the nearest multiple of `factor`.
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        # Too many pixels: scale down (floor keeps us under the budget).
        beta = np.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        # Too few pixels: scale up (ceil keeps us over the floor).
        beta = np.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return resize(
        image,
        size=(h_bar, w_bar),
        resample=resample,
        input_data_format=input_data_format,
    )
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class ImageProcessor(BaseImageProcessor):
|
| 131 |
+
def __init__(
    self,
    patch_size,
    merge_size,
    do_resize: bool = True,
    resample: Image.Resampling = Image.Resampling.BICUBIC,
    do_rescale: bool = True,
    rescale_factor: float = 1 / 255,
    do_normalize: bool = True,
    image_mean: float | list[float] | None = None,
    image_std: float | list[float] | None = None,
    do_convert_rgb: bool = True,
    min_pixels: int = 56 * 56,
    max_pixels: int = 28 * 28 * 1280,
    **kwargs,
) -> None:
    """Configure the convert/resize/rescale/normalize pipeline parameters."""
    super().__init__(**kwargs)
    # Vision patching geometry and pixel budget.
    self.patch_size = patch_size
    self.merge_size = merge_size
    self.min_pixels = min_pixels
    self.max_pixels = max_pixels
    self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
    # Per-step toggles and parameters; mean/std fall back to module defaults.
    self.do_convert_rgb = do_convert_rgb
    self.do_resize = do_resize
    self.resample = resample
    self.do_rescale = do_rescale
    self.rescale_factor = rescale_factor
    self.do_normalize = do_normalize
    self.image_mean = image_mean or IMAGE_MEAN
    self.image_std = image_std or IMAGE_STD
    validate_preprocess_arguments(
        rescale_factor=self.rescale_factor,
        do_normalize=self.do_normalize,
        image_mean=self.image_mean,
        image_std=self.image_std,
        do_resize=self.do_resize,
        size=self.size,
        resample=self.resample,
    )
|
| 170 |
+
|
| 171 |
+
def _preprocess(self, image: ImageInput, do_rescale=None, do_normalize=None):
|
| 172 |
+
if self.do_convert_rgb:
|
| 173 |
+
image = convert_to_rgb(image)
|
| 174 |
+
image = to_numpy_array(image)
|
| 175 |
+
input_data_format = infer_channel_dimension_format(image)
|
| 176 |
+
if self.do_resize:
|
| 177 |
+
image = smart_resize(
|
| 178 |
+
image,
|
| 179 |
+
factor=self.patch_size * self.merge_size,
|
| 180 |
+
resample=self.resample,
|
| 181 |
+
input_data_format=input_data_format,
|
| 182 |
+
min_pixels=self.min_pixels,
|
| 183 |
+
max_pixels=self.max_pixels,
|
| 184 |
+
)
|
| 185 |
+
if do_rescale or self.do_rescale:
|
| 186 |
+
image = self.rescale(image, scale=self.rescale_factor, input_data_format=input_data_format)
|
| 187 |
+
if do_normalize or self.do_normalize:
|
| 188 |
+
image = self.normalize(
|
| 189 |
+
image=image, mean=self.image_mean, std=self.image_std,
|
| 190 |
+
input_data_format=input_data_format,
|
| 191 |
+
)
|
| 192 |
+
return image
|
| 193 |
+
|
| 194 |
+
def preprocess(self, images: list[ImageInput] | None, do_rescale=None, do_normalize=None, **kwargs):
|
| 195 |
+
del kwargs
|
| 196 |
+
if images is None:
|
| 197 |
+
return []
|
| 198 |
+
images = [item for item in images if item is not None]
|
| 199 |
+
if not valid_images(images):
|
| 200 |
+
raise ValueError(
|
| 201 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 202 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 203 |
+
)
|
| 204 |
+
pixel_values = []
|
| 205 |
+
for image in images:
|
| 206 |
+
processed_image = self._preprocess(image, do_rescale, do_normalize)
|
| 207 |
+
processed_image = processed_image[None, ...]
|
| 208 |
+
pixel_values.append(processed_image)
|
| 209 |
+
return pixel_values
|
| 210 |
+
|
| 211 |
+
def batch_images_with_mask(self, pixel_values, max_image_height, max_image_width):
|
| 212 |
+
if pixel_values is None:
|
| 213 |
+
return None
|
| 214 |
+
pixel_values = [item for item in pixel_values if item is not None and len(item) != 0]
|
| 215 |
+
if len(pixel_values) == 0:
|
| 216 |
+
return None
|
| 217 |
+
pixel_values = [torch.from_numpy(img) for img in pixel_values]
|
| 218 |
+
max_temporal = max(img.shape[0] for img in pixel_values)
|
| 219 |
+
|
| 220 |
+
def pad_image_and_mask(img):
|
| 221 |
+
time_steps, height, width, channels = img.shape
|
| 222 |
+
if channels != 3:
|
| 223 |
+
raise ValueError(f"Expected 3-channel RGB images, got {channels} channels.")
|
| 224 |
+
padding = (0, 0, 0, max_image_width - width, 0, max_image_height - height, 0, max_temporal - time_steps)
|
| 225 |
+
padded_image = torch.nn.functional.pad(img, padding)
|
| 226 |
+
mask = torch.zeros((max_temporal, max_image_height, max_image_width), dtype=torch.long)
|
| 227 |
+
mask[:time_steps, :height, :width] = 1
|
| 228 |
+
return padded_image, mask
|
| 229 |
+
|
| 230 |
+
padded_pixel_values, padding_masks = zip(*[pad_image_and_mask(img) for img in pixel_values])
|
| 231 |
+
padded_pixel_values = torch.stack(list(padded_pixel_values))
|
| 232 |
+
padding_masks = torch.stack(list(padding_masks))
|
| 233 |
+
return {"pixel_values": padded_pixel_values, "padding_mask": padding_masks}
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
# ---------------------------------------------------------------------------
|
| 237 |
+
# Positional encoding helpers
|
| 238 |
+
# ---------------------------------------------------------------------------
|
| 239 |
+
|
| 240 |
+
def _compute_image_spatial_positions(
    pixel_mask_THW: torch.Tensor,
    spatial_patch_size: int,
    temporal_patch_size: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute normalized (h, w) coordinates for one image's spatial tokens.

    The pixel-level validity mask is pooled to patch resolution (a patch is
    valid if any of its pixels is), and each valid grid cell receives
    aspect-ratio-normalized coordinates spanning [-sqrt(w/h), sqrt(w/h)] on x
    and [-sqrt(h/w), sqrt(h/w)] on y. Returns flattened (hpos, wpos).
    """
    t_len, h_len, w_len = pixel_mask_THW.shape
    # Pool pixels into (t, h, w) patches: any valid pixel validates the patch.
    patch_mask_thw = (
        pixel_mask_THW.reshape(
            t_len // temporal_patch_size, temporal_patch_size,
            h_len // spatial_patch_size, spatial_patch_size,
            w_len // spatial_patch_size, spatial_patch_size,
        )
        .any(dim=5)
        .any(dim=3)
        .any(dim=1)
    )
    # Widest row / tallest column of valid patches across all time steps.
    width = patch_mask_thw.sum(dim=-1).int().amax()
    height = patch_mask_thw.sum(dim=-2).int().amax()
    xlim = torch.sqrt(width / height)
    ylim = torch.sqrt(height / width)
    xpos = torch.linspace(-xlim, xlim, int(width))
    ypos = torch.linspace(-ylim, ylim, int(height))
    wpos, hpos = torch.meshgrid(xpos, ypos, indexing="xy")
    return hpos.flatten(), wpos.flatten()
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def _get_image_token_masks(tokens, config):
|
| 264 |
+
spatial_mask = tokens == config.img_id
|
| 265 |
+
no_increase_mask = (
|
| 266 |
+
spatial_mask
|
| 267 |
+
| (tokens == config.image_reg_1_token_id)
|
| 268 |
+
| (tokens == config.image_reg_2_token_id)
|
| 269 |
+
| (tokens == config.image_reg_3_token_id)
|
| 270 |
+
| (tokens == config.image_reg_4_token_id)
|
| 271 |
+
| (tokens == config.img_end_id)
|
| 272 |
+
)
|
| 273 |
+
return spatial_mask, no_increase_mask
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def get_pos_thw(
    tokens: torch.Tensor,
    pixel_masks_NTHW: torch.Tensor,
    config,
    spatial_patch_size: int,
    temporal_patch_size: int = 1,
    pad_token_id: int = None,
):
    """Build temporal and 2D spatial position tensors for a padded token batch.

    Returns:
        tpos_BS: long tensor of temporal indices; image spatial/register/end
            tokens share their image's slot, padding is pinned to 0.
        hw_pos_BS2: float (B, S, 2) tensor of (h, w) coordinates for spatial
            image tokens, NaN everywhere else.
    """
    assert pad_token_id is not None
    assert tokens.ndim == 2
    assert pixel_masks_NTHW.ndim == 4

    spatial_mask_BS, no_advance_mask_BS = _get_image_token_masks(tokens, config)

    # Per-image normalized (h, w) grids, concatenated over the batch.
    grids = [
        _compute_image_spatial_positions(mask_THW, spatial_patch_size, temporal_patch_size)
        for mask_THW in pixel_masks_NTHW
    ]
    if grids:
        hpos_N = torch.cat([h for h, _ in grids])
        wpos_N = torch.cat([w for _, w in grids])
    else:
        hpos_N = torch.empty(0)
        wpos_N = torch.empty(0)

    expected_tokens = spatial_mask_BS.sum().item()
    actual_tokens = hpos_N.numel()
    assert actual_tokens == expected_tokens, (
        f"Mismatch between spatial image tokens ({expected_tokens}) and generated positions ({actual_tokens})."
    )

    # Scatter the flat coordinate lists into NaN-initialized (B, S) tensors.
    hpos_BS = torch.full_like(tokens, fill_value=torch.nan, dtype=torch.float, device=tokens.device)
    wpos_BS = torch.full_like(tokens, fill_value=torch.nan, dtype=torch.float, device=tokens.device)
    hpos_BS.masked_scatter_(spatial_mask_BS, hpos_N)
    wpos_BS.masked_scatter_(spatial_mask_BS, wpos_N)

    # Temporal index: each token advances by one except spatial/register/end
    # image tokens (step 0); cumulative sum minus one starts at zero.
    step_BS = torch.ones_like(tokens, dtype=torch.float, device=tokens.device)
    step_BS[no_advance_mask_BS] = 0
    tpos_BS = torch.cumsum(step_BS, dim=1) - 1
    tpos_BS[tokens == pad_token_id] = 0

    return tpos_BS.long(), torch.stack([hpos_BS, wpos_BS], dim=-1)
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def calculate_image_tokens(image, patch_size, merge_size):
    """Number of spatial image tokens after patchification and patch merging."""
    height, width = get_image_size(image)
    pixels_per_token = patch_size * patch_size * merge_size * merge_size
    return int(height * width / pixels_per_token)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def tokenize_inputs(prompt, images, tokenizer, config, patch_size, merge_size, max_length):
    """Tokenize a prompt, expanding each image placeholder into its token block.

    The prompt is split on the image placeholder token, and each split point
    is replaced by ``<image_cls> <reg_1..4> <img> * n <img_end>`` where ``n``
    is the number of spatial tokens the corresponding image will produce.
    Image blocks that would push the sequence past ``max_length`` are dropped
    together with their image.

    Returns:
        (LongTensor of token ids, list of the images actually kept).
    """
    # Register tokens inserted once at the start of every image block.
    img_reg_ids = [
        config.image_reg_1_token_id,
        config.image_reg_2_token_id,
        config.image_reg_3_token_id,
        config.image_reg_4_token_id,
    ]

    if images is not None and len(images) > 0:
        image_token_counts = [calculate_image_tokens(image, patch_size, merge_size) for image in images]
    else:
        image_token_counts = []

    image_token = tokenizer.convert_ids_to_tokens(config.img_id)
    prompt_chunks = [tokenizer.encode(chunk) for chunk in prompt.split(image_token)]

    def insert_separator(X, sep):
        # Interleave chunks and separators as [x0, s0, x1, s1, ...] and drop
        # the trailing element so the sequence ends on a text chunk.
        return [ele for sublist in zip(X, sep) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    # If the tokenizer put a BOS at the head of the first chunk, emit it once
    # up front and strip position 0 from re-appended chunks below.
    bos_id = getattr(tokenizer, "bos_token_id", None)
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and bos_id is not None and prompt_chunks[0][0] == bos_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    separators = []
    for count in image_token_counts:
        tokens = [config.img_id] * count
        image_block = [config.image_cls_token_id, *img_reg_ids, *tokens, config.img_end_id]
        separators.append(image_block)

    # NOTE(review): pads the separator list with a copy of the last image
    # block whenever image count and chunk count disagree — presumably to
    # handle a trailing placeholder; verify against callers.
    if len(separators) != 0 and len(separators) != len(prompt_chunks):
        separators.append(separators[-1])

    selected_images = []
    if len(separators) == 0:
        # Text-only prompt: single chunk used verbatim (including any BOS).
        input_ids = prompt_chunks[0]
    else:
        for index, x in enumerate(insert_separator(prompt_chunks, separators)):
            if index % 2 != 0:
                # Odd positions are image blocks: keep only if they still fit.
                if (len(input_ids) + len(x)) < max_length:
                    input_ids.extend(x)
                    selected_images.append(images[index // 2])
            elif index % 2 == 0:
                # Even positions are text chunks; `offset` skips the leading
                # token that encode() presumably prepends to each chunk.
                input_ids.extend(x[offset:])

    input_ids = torch.LongTensor(input_ids)
    return input_ids, selected_images
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def process_batch(
    tokenizer,
    config,
    image_prompt_pairs,
    max_length,
    min_dimension,
    max_dimension,
    patch_size=16,
    merge_size=1,
):
    """Process a batch of (image, prompt) pairs into model-ready tensors.

    Uses LEFT PADDING for proper batch generation with causal models.

    Args:
        tokenizer: Tokenizer providing encode / token-id conversion.
        config: Model config holding the special image token ids.
        image_prompt_pairs: Iterable of (image source, prompt string) pairs.
        max_length: Maximum token sequence length; overflowing image blocks
            are dropped by `tokenize_inputs`.
        min_dimension / max_dimension: Bounds passed to image resizing; the
            padded pixel batch is max_dimension x max_dimension.
        patch_size / merge_size: Vision patchification parameters.

    Returns:
        Dict with padded ``tokens``, batched ``pixel_values`` and
        ``pixel_mask``, temporal/spatial positions ``pos_t`` / ``pos_hw``,
        and the ``pad_token_id`` used for left padding.
    """
    all_input_ids = []
    all_selected_images = []
    processor_local = ImageProcessor(patch_size, merge_size)

    for img_input, prompt in image_prompt_pairs:
        img = load_image(img_input)
        if img is not None:
            img = resize_image_if_necessary(img, min_dimension, max_dimension)
        # Fix: test identity against None (consistent with the check above).
        # `if img` relied on object truthiness, which raises for numpy arrays
        # and is not a reliable emptiness signal for image objects.
        images = processor_local.preprocess(images=[img] if img is not None else [])
        input_ids, selected_images = tokenize_inputs(
            prompt, images, tokenizer, config, patch_size, merge_size, max_length,
        )
        all_input_ids.append(input_ids)
        all_selected_images.extend(selected_images)

    # Left-pad so the last (generation) position of every row is real content.
    pad_token_id = tokenizer.convert_tokens_to_ids("<|pad|>")
    padded_input_ids = torch.nn.utils.rnn.pad_sequence(
        all_input_ids, batch_first=True, padding_value=pad_token_id, padding_side="left",
    )

    processed = processor_local.batch_images_with_mask(all_selected_images, max_dimension, max_dimension)
    assert processed is not None

    pos_t, pos_hw = get_pos_thw(
        padded_input_ids, processed["padding_mask"], config, patch_size, pad_token_id=pad_token_id,
    )

    return {
        "tokens": padded_input_ids,
        "pixel_values": processed["pixel_values"],
        "pixel_mask": processed["padding_mask"],
        "pos_t": pos_t,
        "pos_hw": pos_hw,
        "pad_token_id": pad_token_id,
    }
|
rope.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import einops as E
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """Precompute rotary-embedding complex exponentials.

    Builds the standard RoPE table: inverse frequencies ``1 / theta**(2i/dim)``
    for i in [0, dim // 2), multiplied by every position in [0, end), then
    mapped onto the unit circle via ``torch.polar``.

    Args:
        dim: Head dimension; yields dim // 2 frequencies.
        end: Number of positions to precompute.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex64 tensor of shape (end, dim // 2).
    """
    half_indices = torch.arange(0, dim, 2)[: (dim // 2)].float()
    inv_freqs = 1.0 / (theta ** (half_indices / dim))
    positions = torch.arange(end, device=inv_freqs.device)
    angles = torch.outer(positions, inv_freqs).float()
    # polar(1, angle) == cos(angle) + i * sin(angle), dtype complex64.
    return torch.polar(torch.ones_like(angles), angles)  # [S, D//2]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply 1D rotary embedding to query and key tensors.

    The last dimension is read as interleaved (real, imag) pairs and rotated
    by the per-position complex exponentials in ``freqs_cis``.
    """
    q_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    k_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    assert freqs_cis.ndim == 3, (
        "Freqs_cis must be indexed by position ids already and has shape (B,S,D)"
    )
    # Broadcast over the head axis: (B, S, F) -> (B, S, 1, F).
    rotation = freqs_cis.unsqueeze(2)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
###### 2D golden rope
|
| 46 |
+
"""
|
| 47 |
+
Dimension key:
|
| 48 |
+
B: batch size
|
| 49 |
+
S: number of tokens per sample, Seqlen
|
| 50 |
+
T: Number of selected Tokens
|
| 51 |
+
P: pos_dim
|
| 52 |
+
h: n_heads
|
| 53 |
+
d: head_dim
|
| 54 |
+
F: num_freqs == head_dim // 2
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def apply_golden_freqs_cis_to_visual_pos(freqs_hFP, pos_BSP) -> torch.Tensor:
    """Build per-image-token complex rotations from 2D positions.

    Computed once per input batch; the cached result is passed to all layers.
    Uses integer index gathering rather than boolean indexing, keeping the
    function safe for Torch-Inductor (no unbacked symbolic shapes).
    """
    # Image tokens are exactly the (b, s) positions with no NaN coordinate.
    visual_mask_BS = torch.isnan(pos_BSP).logical_not().all(dim=-1)
    batch_idx, seq_idx = torch.nonzero(visual_mask_BS, as_tuple=True)  # each (N,)

    # Gather the positional vectors of those tokens.
    visual_pos_tP = pos_BSP[batch_idx, seq_idx].float()  # (N, P)

    # Angle per token / head / frequency: theta = <pos, freq>.
    theta_thF = torch.einsum("tp,hfp->thf", visual_pos_tP, freqs_hFP.float())

    # Unit-circle complex numbers cos(theta) + i*sin(theta).
    return torch.polar(torch.ones_like(theta_thF), theta_thF)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def apply_golden_rotary_emb(input_BShd, freqs_cis_thF, pos_BSP) -> torch.Tensor:
    """Rotate only the image tokens of ``input_BShd`` by precomputed rotations.

    Text tokens (NaN rows in ``pos_BSP``) pass through unchanged. Integer
    index gathering avoids boolean indexing, so the function stays safe for
    Torch-Inductor.
    """
    visual_mask_BS = torch.isnan(pos_BSP).logical_not().all(dim=-1)
    batch_idx, seq_idx = torch.nonzero(visual_mask_BS, as_tuple=True)  # (N,)

    selected_thd = input_BShd[batch_idx, seq_idx].float()  # (N, h, d)
    even_thF = selected_thd[..., 0::2]  # (N, h, F)
    odd_thF = selected_thd[..., 1::2]   # (N, h, F)

    cos_thF = freqs_cis_thF.real
    sin_thF = freqs_cis_thF.imag

    # Complex product expanded on interleaved pairs:
    # (a + ib)(c + id) = (ac - bd) + i(ad + bc)
    rotated_thd = torch.empty_like(selected_thd)
    rotated_thd[..., 0::2] = even_thF * cos_thF - odd_thF * sin_thF
    rotated_thd[..., 1::2] = even_thF * sin_thF + odd_thF * cos_thF
    rotated_thd = rotated_thd.type_as(input_BShd)

    # Write the rotated tokens back into a copy of the input.
    output_BShd = input_BShd.clone()
    output_BShd[batch_idx, seq_idx] = rotated_thd

    return output_BShd
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def apply_3d_rotary_emb(
    xq: torch.Tensor,  # (B, S, H, D)
    xk: torch.Tensor,  # (B, S, H, D)
    freqs_cis: torch.Tensor,
    freqs_cis_2d: torch.Tensor | None,
    pos_hw: torch.Tensor | None,  # (B, S, 2); NaN rows mark non-image tokens
) -> tuple[torch.Tensor, torch.Tensor]:
    """Combine 1D temporal and 2D spatial rotary embeddings.

    The head dimension is split in half: the first half is rotated by the 1D
    (temporal) table; when 2D tables and positions are supplied, the second
    half of image tokens is rotated by the golden 2D scheme, otherwise it
    passes through unchanged.
    """
    q_temporal, q_spatial = xq.chunk(chunks=2, dim=-1)
    k_temporal, k_spatial = xk.chunk(chunks=2, dim=-1)
    # The unpack doubles as a rank-4 shape check on the input.
    _batch, _seq, _heads, _head_dim = xq.shape

    q_temporal, k_temporal = apply_rotary_emb(q_temporal, k_temporal, freqs_cis)
    if freqs_cis_2d is not None and pos_hw is not None:
        q_spatial = apply_golden_rotary_emb(q_spatial, freqs_cis_2d, pos_hw)
        k_spatial = apply_golden_rotary_emb(k_spatial, freqs_cis_2d, pos_hw)

    xq_out = torch.concat([q_temporal, q_spatial], dim=-1).type_as(xq)
    xk_out = torch.concat([k_temporal, k_spatial], dim=-1).type_as(xk)
    return xq_out, xk_out
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"<|pad|>",
|
| 4 |
+
">>ABSTRACT<<",
|
| 5 |
+
">>INTRODUCTION<<",
|
| 6 |
+
">>SUMMARY<<",
|
| 7 |
+
">>COMMENT<<",
|
| 8 |
+
">>ANSWER<<",
|
| 9 |
+
">>QUESTION<<",
|
| 10 |
+
">>DOMAIN<<",
|
| 11 |
+
">>PREFIX<<",
|
| 12 |
+
">>SUFFIX<<",
|
| 13 |
+
">>MIDDLE<<",
|
| 14 |
+
"<|finetune_right_pad_id|>",
|
| 15 |
+
"<|start_header_id|>",
|
| 16 |
+
"<|end_header_id|>",
|
| 17 |
+
"<|eom_id|>",
|
| 18 |
+
"<|eot_id|>",
|
| 19 |
+
"<|begin_of_text|>",
|
| 20 |
+
">>TITLE<<",
|
| 21 |
+
"<tool_response>",
|
| 22 |
+
"</tool_response>",
|
| 23 |
+
"<tool_call>",
|
| 24 |
+
"</tool_call>",
|
| 25 |
+
"<schema>",
|
| 26 |
+
"</schema>",
|
| 27 |
+
"<scratch_pad>",
|
| 28 |
+
"</scratch_pad>",
|
| 29 |
+
"<thinking>",
|
| 30 |
+
"</thinking>",
|
| 31 |
+
"<explanation>",
|
| 32 |
+
"</explanation>",
|
| 33 |
+
"<file_sep>",
|
| 34 |
+
"<repo_name>",
|
| 35 |
+
"<tr>",
|
| 36 |
+
"</tr>",
|
| 37 |
+
"<|image|>",
|
| 38 |
+
"<|image_row_sep|>",
|
| 39 |
+
"<|start_of_image|>",
|
| 40 |
+
"<|end_of_image|>",
|
| 41 |
+
"<|start_of_video|>",
|
| 42 |
+
"<|end_of_video|>",
|
| 43 |
+
"<|frame_sep|>",
|
| 44 |
+
"<|start_of_turn|>",
|
| 45 |
+
"<|end_of_turn|>",
|
| 46 |
+
"<|start_of_diffusion_query|>",
|
| 47 |
+
"<|end_of_diffusion_query|>",
|
| 48 |
+
"<|diffusion_query|>",
|
| 49 |
+
"<|object|>",
|
| 50 |
+
"<|coord|>",
|
| 51 |
+
"<|size|>",
|
| 52 |
+
"<|perceive|>",
|
| 53 |
+
"<|image_mask_token|>",
|
| 54 |
+
"<|image_cls|>",
|
| 55 |
+
"<|image_reg_1|>",
|
| 56 |
+
"<|image_reg_2|>",
|
| 57 |
+
"<|image_reg_3|>",
|
| 58 |
+
"<|image_reg_4|>",
|
| 59 |
+
"<|image_reg_5|>",
|
| 60 |
+
"<|image_reg_6|>",
|
| 61 |
+
"<|image_reg_7|>",
|
| 62 |
+
"<|image_reg_8|>",
|
| 63 |
+
"<|DET|>",
|
| 64 |
+
"<|POINTING|>",
|
| 65 |
+
"<|OCR_GROUNDING|>",
|
| 66 |
+
"<|OCR_DOC_PARSER|>",
|
| 67 |
+
"<|OCR_PLAIN|>",
|
| 68 |
+
"<|REF_SEG|>",
|
| 69 |
+
"<|POINT_REF_SEG|>",
|
| 70 |
+
"<|CAPTION|>",
|
| 71 |
+
"<|DETAILED_CAPTION|>",
|
| 72 |
+
"<|seg|>",
|
| 73 |
+
"<|end_of_query|>",
|
| 74 |
+
"<|start_of_query|>",
|
| 75 |
+
"<|task_sep|>",
|
| 76 |
+
"<|QA|>",
|
| 77 |
+
"<|LAYOUT_DETECTION|>",
|
| 78 |
+
"<|category_sep|>",
|
| 79 |
+
"<td>",
|
| 80 |
+
"</td>",
|
| 81 |
+
"<th>",
|
| 82 |
+
"</th>",
|
| 83 |
+
">>UNUSED_261<<",
|
| 84 |
+
">>UNUSED_262<<",
|
| 85 |
+
">>UNUSED_263<<",
|
| 86 |
+
">>UNUSED_264<<",
|
| 87 |
+
">>UNUSED_265<<",
|
| 88 |
+
">>UNUSED_266<<",
|
| 89 |
+
">>UNUSED_267<<",
|
| 90 |
+
">>UNUSED_268<<",
|
| 91 |
+
">>UNUSED_269<<",
|
| 92 |
+
">>UNUSED_270<<",
|
| 93 |
+
">>UNUSED_271<<",
|
| 94 |
+
">>UNUSED_272<<",
|
| 95 |
+
">>UNUSED_273<<",
|
| 96 |
+
">>UNUSED_274<<",
|
| 97 |
+
">>UNUSED_275<<",
|
| 98 |
+
">>UNUSED_276<<",
|
| 99 |
+
">>UNUSED_277<<",
|
| 100 |
+
">>UNUSED_278<<",
|
| 101 |
+
">>UNUSED_279<<",
|
| 102 |
+
">>UNUSED_280<<",
|
| 103 |
+
">>UNUSED_281<<",
|
| 104 |
+
">>UNUSED_282<<",
|
| 105 |
+
">>UNUSED_283<<",
|
| 106 |
+
">>UNUSED_284<<",
|
| 107 |
+
">>UNUSED_285<<",
|
| 108 |
+
">>UNUSED_286<<",
|
| 109 |
+
">>UNUSED_287<<",
|
| 110 |
+
">>UNUSED_288<<",
|
| 111 |
+
">>UNUSED_289<<",
|
| 112 |
+
">>UNUSED_290<<",
|
| 113 |
+
">>UNUSED_291<<",
|
| 114 |
+
">>UNUSED_292<<",
|
| 115 |
+
">>UNUSED_293<<",
|
| 116 |
+
">>UNUSED_294<<",
|
| 117 |
+
">>UNUSED_295<<",
|
| 118 |
+
">>UNUSED_296<<",
|
| 119 |
+
">>UNUSED_297<<",
|
| 120 |
+
">>UNUSED_298<<",
|
| 121 |
+
">>UNUSED_299<<",
|
| 122 |
+
">>UNUSED_300<<",
|
| 123 |
+
">>UNUSED_301<<",
|
| 124 |
+
">>UNUSED_302<<",
|
| 125 |
+
">>UNUSED_303<<",
|
| 126 |
+
">>UNUSED_304<<",
|
| 127 |
+
">>UNUSED_305<<",
|
| 128 |
+
">>UNUSED_306<<",
|
| 129 |
+
">>UNUSED_307<<",
|
| 130 |
+
">>UNUSED_308<<",
|
| 131 |
+
">>UNUSED_309<<",
|
| 132 |
+
">>UNUSED_310<<",
|
| 133 |
+
">>UNUSED_311<<",
|
| 134 |
+
">>UNUSED_312<<",
|
| 135 |
+
">>UNUSED_313<<",
|
| 136 |
+
">>UNUSED_314<<",
|
| 137 |
+
">>UNUSED_315<<",
|
| 138 |
+
">>UNUSED_316<<",
|
| 139 |
+
">>UNUSED_317<<",
|
| 140 |
+
">>UNUSED_318<<",
|
| 141 |
+
">>UNUSED_319<<",
|
| 142 |
+
">>UNUSED_320<<",
|
| 143 |
+
">>UNUSED_321<<",
|
| 144 |
+
">>UNUSED_322<<",
|
| 145 |
+
">>UNUSED_323<<",
|
| 146 |
+
">>UNUSED_324<<",
|
| 147 |
+
">>UNUSED_325<<",
|
| 148 |
+
">>UNUSED_326<<",
|
| 149 |
+
">>UNUSED_327<<",
|
| 150 |
+
">>UNUSED_328<<",
|
| 151 |
+
">>UNUSED_329<<",
|
| 152 |
+
">>UNUSED_330<<",
|
| 153 |
+
">>UNUSED_331<<",
|
| 154 |
+
">>UNUSED_332<<",
|
| 155 |
+
">>UNUSED_333<<",
|
| 156 |
+
">>UNUSED_334<<",
|
| 157 |
+
">>UNUSED_335<<",
|
| 158 |
+
">>UNUSED_336<<",
|
| 159 |
+
">>UNUSED_337<<",
|
| 160 |
+
">>UNUSED_338<<",
|
| 161 |
+
">>UNUSED_339<<",
|
| 162 |
+
">>UNUSED_340<<",
|
| 163 |
+
">>UNUSED_341<<",
|
| 164 |
+
">>UNUSED_342<<",
|
| 165 |
+
">>UNUSED_343<<",
|
| 166 |
+
">>UNUSED_344<<",
|
| 167 |
+
">>UNUSED_345<<",
|
| 168 |
+
">>UNUSED_346<<",
|
| 169 |
+
">>UNUSED_347<<",
|
| 170 |
+
">>UNUSED_348<<",
|
| 171 |
+
">>UNUSED_349<<",
|
| 172 |
+
">>UNUSED_350<<",
|
| 173 |
+
">>UNUSED_351<<",
|
| 174 |
+
">>UNUSED_352<<",
|
| 175 |
+
">>UNUSED_353<<",
|
| 176 |
+
">>UNUSED_354<<",
|
| 177 |
+
">>UNUSED_355<<",
|
| 178 |
+
">>UNUSED_356<<",
|
| 179 |
+
">>UNUSED_357<<",
|
| 180 |
+
">>UNUSED_358<<",
|
| 181 |
+
">>UNUSED_359<<",
|
| 182 |
+
">>UNUSED_360<<",
|
| 183 |
+
">>UNUSED_361<<",
|
| 184 |
+
">>UNUSED_362<<",
|
| 185 |
+
">>UNUSED_363<<",
|
| 186 |
+
">>UNUSED_364<<",
|
| 187 |
+
">>UNUSED_365<<",
|
| 188 |
+
">>UNUSED_366<<",
|
| 189 |
+
">>UNUSED_367<<",
|
| 190 |
+
">>UNUSED_368<<",
|
| 191 |
+
">>UNUSED_369<<",
|
| 192 |
+
">>UNUSED_370<<",
|
| 193 |
+
">>UNUSED_371<<",
|
| 194 |
+
">>UNUSED_372<<",
|
| 195 |
+
">>UNUSED_373<<",
|
| 196 |
+
">>UNUSED_374<<",
|
| 197 |
+
">>UNUSED_375<<",
|
| 198 |
+
">>UNUSED_376<<",
|
| 199 |
+
">>UNUSED_377<<",
|
| 200 |
+
">>UNUSED_378<<",
|
| 201 |
+
">>UNUSED_379<<",
|
| 202 |
+
">>UNUSED_380<<",
|
| 203 |
+
">>UNUSED_381<<",
|
| 204 |
+
">>UNUSED_382<<",
|
| 205 |
+
">>UNUSED_383<<",
|
| 206 |
+
">>UNUSED_384<<",
|
| 207 |
+
">>UNUSED_385<<",
|
| 208 |
+
">>UNUSED_386<<",
|
| 209 |
+
">>UNUSED_387<<",
|
| 210 |
+
">>UNUSED_388<<",
|
| 211 |
+
">>UNUSED_389<<",
|
| 212 |
+
">>UNUSED_390<<",
|
| 213 |
+
">>UNUSED_391<<",
|
| 214 |
+
">>UNUSED_392<<",
|
| 215 |
+
">>UNUSED_393<<",
|
| 216 |
+
">>UNUSED_394<<",
|
| 217 |
+
">>UNUSED_395<<",
|
| 218 |
+
">>UNUSED_396<<",
|
| 219 |
+
">>UNUSED_397<<",
|
| 220 |
+
">>UNUSED_398<<",
|
| 221 |
+
">>UNUSED_399<<",
|
| 222 |
+
">>UNUSED_400<<",
|
| 223 |
+
">>UNUSED_401<<",
|
| 224 |
+
">>UNUSED_402<<",
|
| 225 |
+
">>UNUSED_403<<",
|
| 226 |
+
">>UNUSED_404<<",
|
| 227 |
+
">>UNUSED_405<<",
|
| 228 |
+
">>UNUSED_406<<",
|
| 229 |
+
">>UNUSED_407<<",
|
| 230 |
+
">>UNUSED_408<<",
|
| 231 |
+
">>UNUSED_409<<",
|
| 232 |
+
">>UNUSED_410<<",
|
| 233 |
+
">>UNUSED_411<<",
|
| 234 |
+
">>UNUSED_412<<",
|
| 235 |
+
">>UNUSED_413<<",
|
| 236 |
+
">>UNUSED_414<<",
|
| 237 |
+
">>UNUSED_415<<",
|
| 238 |
+
">>UNUSED_416<<",
|
| 239 |
+
">>UNUSED_417<<",
|
| 240 |
+
">>UNUSED_418<<",
|
| 241 |
+
">>UNUSED_419<<",
|
| 242 |
+
">>UNUSED_420<<",
|
| 243 |
+
">>UNUSED_421<<",
|
| 244 |
+
">>UNUSED_422<<",
|
| 245 |
+
">>UNUSED_423<<",
|
| 246 |
+
">>UNUSED_424<<",
|
| 247 |
+
">>UNUSED_425<<",
|
| 248 |
+
">>UNUSED_426<<",
|
| 249 |
+
">>UNUSED_427<<",
|
| 250 |
+
">>UNUSED_428<<",
|
| 251 |
+
">>UNUSED_429<<",
|
| 252 |
+
">>UNUSED_430<<",
|
| 253 |
+
">>UNUSED_431<<",
|
| 254 |
+
">>UNUSED_432<<",
|
| 255 |
+
">>UNUSED_433<<",
|
| 256 |
+
">>UNUSED_434<<",
|
| 257 |
+
">>UNUSED_435<<",
|
| 258 |
+
">>UNUSED_436<<",
|
| 259 |
+
">>UNUSED_437<<",
|
| 260 |
+
">>UNUSED_438<<",
|
| 261 |
+
">>UNUSED_439<<",
|
| 262 |
+
">>UNUSED_440<<",
|
| 263 |
+
">>UNUSED_441<<",
|
| 264 |
+
">>UNUSED_442<<",
|
| 265 |
+
">>UNUSED_443<<",
|
| 266 |
+
">>UNUSED_444<<",
|
| 267 |
+
">>UNUSED_445<<",
|
| 268 |
+
">>UNUSED_446<<",
|
| 269 |
+
">>UNUSED_447<<",
|
| 270 |
+
">>UNUSED_448<<",
|
| 271 |
+
">>UNUSED_449<<",
|
| 272 |
+
">>UNUSED_450<<",
|
| 273 |
+
">>UNUSED_451<<",
|
| 274 |
+
">>UNUSED_452<<",
|
| 275 |
+
">>UNUSED_453<<",
|
| 276 |
+
">>UNUSED_454<<",
|
| 277 |
+
">>UNUSED_455<<",
|
| 278 |
+
">>UNUSED_456<<",
|
| 279 |
+
">>UNUSED_457<<",
|
| 280 |
+
">>UNUSED_458<<",
|
| 281 |
+
">>UNUSED_459<<",
|
| 282 |
+
">>UNUSED_460<<",
|
| 283 |
+
">>UNUSED_461<<",
|
| 284 |
+
">>UNUSED_462<<",
|
| 285 |
+
">>UNUSED_463<<",
|
| 286 |
+
">>UNUSED_464<<",
|
| 287 |
+
">>UNUSED_465<<",
|
| 288 |
+
">>UNUSED_466<<",
|
| 289 |
+
">>UNUSED_467<<",
|
| 290 |
+
">>UNUSED_468<<",
|
| 291 |
+
">>UNUSED_469<<",
|
| 292 |
+
">>UNUSED_470<<",
|
| 293 |
+
">>UNUSED_471<<",
|
| 294 |
+
">>UNUSED_472<<",
|
| 295 |
+
">>UNUSED_473<<",
|
| 296 |
+
">>UNUSED_474<<",
|
| 297 |
+
">>UNUSED_475<<",
|
| 298 |
+
">>UNUSED_476<<",
|
| 299 |
+
">>UNUSED_477<<",
|
| 300 |
+
">>UNUSED_478<<",
|
| 301 |
+
">>UNUSED_479<<",
|
| 302 |
+
">>UNUSED_480<<",
|
| 303 |
+
">>UNUSED_481<<",
|
| 304 |
+
">>UNUSED_482<<",
|
| 305 |
+
">>UNUSED_483<<",
|
| 306 |
+
">>UNUSED_484<<",
|
| 307 |
+
">>UNUSED_485<<",
|
| 308 |
+
">>UNUSED_486<<",
|
| 309 |
+
">>UNUSED_487<<",
|
| 310 |
+
">>UNUSED_488<<",
|
| 311 |
+
">>UNUSED_489<<",
|
| 312 |
+
">>UNUSED_490<<",
|
| 313 |
+
">>UNUSED_491<<",
|
| 314 |
+
">>UNUSED_492<<",
|
| 315 |
+
">>UNUSED_493<<",
|
| 316 |
+
">>UNUSED_494<<",
|
| 317 |
+
">>UNUSED_495<<",
|
| 318 |
+
">>UNUSED_496<<",
|
| 319 |
+
">>UNUSED_497<<",
|
| 320 |
+
">>UNUSED_498<<",
|
| 321 |
+
">>UNUSED_499<<",
|
| 322 |
+
">>UNUSED_500<<",
|
| 323 |
+
">>UNUSED_501<<",
|
| 324 |
+
">>UNUSED_502<<",
|
| 325 |
+
">>UNUSED_503<<",
|
| 326 |
+
">>UNUSED_504<<",
|
| 327 |
+
">>UNUSED_505<<",
|
| 328 |
+
">>UNUSED_506<<",
|
| 329 |
+
">>UNUSED_507<<",
|
| 330 |
+
">>UNUSED_508<<",
|
| 331 |
+
">>UNUSED_509<<",
|
| 332 |
+
">>UNUSED_510<<",
|
| 333 |
+
">>UNUSED_511<<"
|
| 334 |
+
],
|
| 335 |
+
"eos_token": {
|
| 336 |
+
"content": "<|end_of_text|>",
|
| 337 |
+
"lstrip": false,
|
| 338 |
+
"normalized": false,
|
| 339 |
+
"rstrip": false,
|
| 340 |
+
"single_word": false
|
| 341 |
+
},
|
| 342 |
+
"image_token": "<|image|>",
|
| 343 |
+
"image_cls_token": "<|image_cls|>",
|
| 344 |
+
"image_reg_1_token": "<|image_reg_1|>",
|
| 345 |
+
"image_reg_2_token": "<|image_reg_2|>",
|
| 346 |
+
"image_reg_3_token": "<|image_reg_3|>",
|
| 347 |
+
"image_reg_4_token": "<|image_reg_4|>",
|
| 348 |
+
"image_reg_5_token": "<|image_reg_5|>",
|
| 349 |
+
"image_reg_6_token": "<|image_reg_6|>",
|
| 350 |
+
"image_reg_7_token": "<|image_reg_7|>",
|
| 351 |
+
"image_reg_8_token": "<|image_reg_8|>",
|
| 352 |
+
"image_row_sep_token": "<|image_row_sep|>",
|
| 353 |
+
"start_of_image_token": "<|start_of_image|>",
|
| 354 |
+
"end_of_image_token": "<|end_of_image|>",
|
| 355 |
+
"start_of_video_token": "<|start_of_video|>",
|
| 356 |
+
"end_of_video_token": "<|end_of_video|>",
|
| 357 |
+
"frame_sep_token": "<|frame_sep|>",
|
| 358 |
+
"start_of_turn_token": "<|start_of_turn|>",
|
| 359 |
+
"end_of_turn_token": "<|end_of_turn|>",
|
| 360 |
+
"start_of_diffusion_query_token": "<|start_of_diffusion_query|>",
|
| 361 |
+
"end_of_diffusion_query_token": "<|end_of_diffusion_query|>",
|
| 362 |
+
"diffusion_query_token": "<|diffusion_query|>",
|
| 363 |
+
"object_token": "<|object|>",
|
| 364 |
+
"coord_token": "<|coord|>",
|
| 365 |
+
"size_token": "<|size|>",
|
| 366 |
+
"perceive_token": "<|perceive|>",
|
| 367 |
+
"image_mask_token": "<|image_mask_token|>",
|
| 368 |
+
"det_token": "<|DET|>",
|
| 369 |
+
"pointing_token": "<|POINTING|>",
|
| 370 |
+
"ocr_grounding_token": "<|OCR_GROUNDING|>",
|
| 371 |
+
"ocr_doc_parser_token": "<|OCR_DOC_PARSER|>",
|
| 372 |
+
"ocr_plain_token": "<|OCR_PLAIN|>",
|
| 373 |
+
"ref_seg_token": "<|REF_SEG|>",
|
| 374 |
+
"point_ref_seg_token": "<|POINT_REF_SEG|>",
|
| 375 |
+
"caption_token": "<|CAPTION|>",
|
| 376 |
+
"detailed_caption_token": "<|DETAILED_CAPTION|>",
|
| 377 |
+
"seg_token": "<|seg|>",
|
| 378 |
+
"start_of_query_token": "<|start_of_query|>",
|
| 379 |
+
"end_of_query_token": "<|end_of_query|>",
|
| 380 |
+
"task_sep_token": "<|task_sep|>",
|
| 381 |
+
"qa_token": "<|QA|>",
|
| 382 |
+
"layout_detection_token": "<|LAYOUT_DETECTION|>",
|
| 383 |
+
"category_sep_token": "<|category_sep|>",
|
| 384 |
+
"table_row_start_token": "<tr>",
|
| 385 |
+
"table_row_end_token": "</tr>",
|
| 386 |
+
"table_data_start_token": "<td>",
|
| 387 |
+
"table_data_end_token": "</td>",
|
| 388 |
+
"table_header_start_token": "<th>",
|
| 389 |
+
"table_header_end_token": "</th>"
|
| 390 |
+
}
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"backend": "tokenizers",
|
| 3 |
+
"caption_token": "<|CAPTION|>",
|
| 4 |
+
"category_sep_token": "<|category_sep|>",
|
| 5 |
+
"clean_up_tokenization_spaces": true,
|
| 6 |
+
"coord_token": "<|coord|>",
|
| 7 |
+
"det_token": "<|DET|>",
|
| 8 |
+
"detailed_caption_token": "<|DETAILED_CAPTION|>",
|
| 9 |
+
"diffusion_query_token": "<|diffusion_query|>",
|
| 10 |
+
"end_of_diffusion_query_token": "<|end_of_diffusion_query|>",
|
| 11 |
+
"end_of_image_token": "<|end_of_image|>",
|
| 12 |
+
"end_of_query_token": "<|end_of_query|>",
|
| 13 |
+
"end_of_turn_token": "<|end_of_turn|>",
|
| 14 |
+
"end_of_video_token": "<|end_of_video|>",
|
| 15 |
+
"eos_token": "<|end_of_text|>",
|
| 16 |
+
"frame_sep_token": "<|frame_sep|>",
|
| 17 |
+
"image_cls_token": "<|image_cls|>",
|
| 18 |
+
"image_mask_token": "<|image_mask_token|>",
|
| 19 |
+
"image_reg_1_token": "<|image_reg_1|>",
|
| 20 |
+
"image_reg_2_token": "<|image_reg_2|>",
|
| 21 |
+
"image_reg_3_token": "<|image_reg_3|>",
|
| 22 |
+
"image_reg_4_token": "<|image_reg_4|>",
|
| 23 |
+
"image_reg_5_token": "<|image_reg_5|>",
|
| 24 |
+
"image_reg_6_token": "<|image_reg_6|>",
|
| 25 |
+
"image_reg_7_token": "<|image_reg_7|>",
|
| 26 |
+
"image_reg_8_token": "<|image_reg_8|>",
|
| 27 |
+
"image_row_sep_token": "<|image_row_sep|>",
|
| 28 |
+
"image_token": "<|image|>",
|
| 29 |
+
"is_local": true,
|
| 30 |
+
"layout_detection_token": "<|LAYOUT_DETECTION|>",
|
| 31 |
+
"model_input_names": [
|
| 32 |
+
"input_ids",
|
| 33 |
+
"attention_mask"
|
| 34 |
+
],
|
| 35 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 36 |
+
"model_specific_special_tokens": {
|
| 37 |
+
"caption_token": "<|CAPTION|>",
|
| 38 |
+
"category_sep_token": "<|category_sep|>",
|
| 39 |
+
"coord_token": "<|coord|>",
|
| 40 |
+
"det_token": "<|DET|>",
|
| 41 |
+
"detailed_caption_token": "<|DETAILED_CAPTION|>",
|
| 42 |
+
"diffusion_query_token": "<|diffusion_query|>",
|
| 43 |
+
"end_of_diffusion_query_token": "<|end_of_diffusion_query|>",
|
| 44 |
+
"end_of_image_token": "<|end_of_image|>",
|
| 45 |
+
"end_of_query_token": "<|end_of_query|>",
|
| 46 |
+
"end_of_turn_token": "<|end_of_turn|>",
|
| 47 |
+
"end_of_video_token": "<|end_of_video|>",
|
| 48 |
+
"frame_sep_token": "<|frame_sep|>",
|
| 49 |
+
"image_cls_token": "<|image_cls|>",
|
| 50 |
+
"image_mask_token": "<|image_mask_token|>",
|
| 51 |
+
"image_reg_1_token": "<|image_reg_1|>",
|
| 52 |
+
"image_reg_2_token": "<|image_reg_2|>",
|
| 53 |
+
"image_reg_3_token": "<|image_reg_3|>",
|
| 54 |
+
"image_reg_4_token": "<|image_reg_4|>",
|
| 55 |
+
"image_reg_5_token": "<|image_reg_5|>",
|
| 56 |
+
"image_reg_6_token": "<|image_reg_6|>",
|
| 57 |
+
"image_reg_7_token": "<|image_reg_7|>",
|
| 58 |
+
"image_reg_8_token": "<|image_reg_8|>",
|
| 59 |
+
"image_row_sep_token": "<|image_row_sep|>",
|
| 60 |
+
"image_token": "<|image|>",
|
| 61 |
+
"layout_detection_token": "<|LAYOUT_DETECTION|>",
|
| 62 |
+
"object_token": "<|object|>",
|
| 63 |
+
"ocr_doc_parser_token": "<|OCR_DOC_PARSER|>",
|
| 64 |
+
"ocr_grounding_token": "<|OCR_GROUNDING|>",
|
| 65 |
+
"ocr_plain_token": "<|OCR_PLAIN|>",
|
| 66 |
+
"perceive_token": "<|perceive|>",
|
| 67 |
+
"point_ref_seg_token": "<|POINT_REF_SEG|>",
|
| 68 |
+
"pointing_token": "<|POINTING|>",
|
| 69 |
+
"qa_token": "<|QA|>",
|
| 70 |
+
"ref_seg_token": "<|REF_SEG|>",
|
| 71 |
+
"seg_token": "<|seg|>",
|
| 72 |
+
"size_token": "<|size|>",
|
| 73 |
+
"start_of_diffusion_query_token": "<|start_of_diffusion_query|>",
|
| 74 |
+
"start_of_image_token": "<|start_of_image|>",
|
| 75 |
+
"start_of_query_token": "<|start_of_query|>",
|
| 76 |
+
"start_of_turn_token": "<|start_of_turn|>",
|
| 77 |
+
"start_of_video_token": "<|start_of_video|>",
|
| 78 |
+
"table_data_end_token": "</td>",
|
| 79 |
+
"table_data_start_token": "<td>",
|
| 80 |
+
"table_header_end_token": "</th>",
|
| 81 |
+
"table_header_start_token": "<th>",
|
| 82 |
+
"table_row_end_token": "</tr>",
|
| 83 |
+
"table_row_start_token": "<tr>",
|
| 84 |
+
"task_sep_token": "<|task_sep|>"
|
| 85 |
+
},
|
| 86 |
+
"object_token": "<|object|>",
|
| 87 |
+
"ocr_doc_parser_token": "<|OCR_DOC_PARSER|>",
|
| 88 |
+
"ocr_grounding_token": "<|OCR_GROUNDING|>",
|
| 89 |
+
"ocr_plain_token": "<|OCR_PLAIN|>",
|
| 90 |
+
"perceive_token": "<|perceive|>",
|
| 91 |
+
"point_ref_seg_token": "<|POINT_REF_SEG|>",
|
| 92 |
+
"pointing_token": "<|POINTING|>",
|
| 93 |
+
"qa_token": "<|QA|>",
|
| 94 |
+
"ref_seg_token": "<|REF_SEG|>",
|
| 95 |
+
"seg_token": "<|seg|>",
|
| 96 |
+
"size_token": "<|size|>",
|
| 97 |
+
"start_of_diffusion_query_token": "<|start_of_diffusion_query|>",
|
| 98 |
+
"start_of_image_token": "<|start_of_image|>",
|
| 99 |
+
"start_of_query_token": "<|start_of_query|>",
|
| 100 |
+
"start_of_turn_token": "<|start_of_turn|>",
|
| 101 |
+
"start_of_video_token": "<|start_of_video|>",
|
| 102 |
+
"table_data_end_token": "</td>",
|
| 103 |
+
"table_data_start_token": "<td>",
|
| 104 |
+
"table_header_end_token": "</th>",
|
| 105 |
+
"table_header_start_token": "<th>",
|
| 106 |
+
"table_row_end_token": "</tr>",
|
| 107 |
+
"table_row_start_token": "<tr>",
|
| 108 |
+
"task_sep_token": "<|task_sep|>",
|
| 109 |
+
"tokenizer_class": "TokenizersBackend"
|
| 110 |
+
}
|