alvinfn yasserDahou committed on
Commit
ed18782
·
0 Parent(s):

Duplicate from tiiuae/Falcon-OCR


Co-authored-by: Yasser Dahou <yasserDahou@users.noreply.huggingface.co>

.eval_results/olmocrbench.yaml ADDED
@@ -0,0 +1,89 @@
1
+ - dataset:
2
+ id: allenai/olmOCR-bench
3
+ task_id: overall
4
+ value: 80.3
5
+ source:
6
+ url: https://huggingface.co/tiiuae/Falcon-OCR
7
+ name: Falcon-OCR Model Card
8
+ user: nielsr
9
+ notes: English-subset only
10
+
11
+ - dataset:
12
+ id: allenai/olmOCR-bench
13
+ task_id: arxiv_math
14
+ value: 80.5
15
+ source:
16
+ url: https://huggingface.co/tiiuae/Falcon-OCR
17
+ name: Falcon-OCR Model Card
18
+ user: nielsr
19
+ notes: English-subset only
20
+
21
+ - dataset:
22
+ id: allenai/olmOCR-bench
23
+ task_id: old_scans_math
24
+ value: 69.2
25
+ source:
26
+ url: https://huggingface.co/tiiuae/Falcon-OCR
27
+ name: Falcon-OCR Model Card
28
+ user: nielsr
29
+ notes: English-subset only
30
+
31
+ - dataset:
32
+ id: allenai/olmOCR-bench
33
+ task_id: table_tests
34
+ value: 90.3
35
+ source:
36
+ url: https://huggingface.co/tiiuae/Falcon-OCR
37
+ name: Falcon-OCR Model Card
38
+ user: nielsr
39
+ notes: English-subset only
40
+
41
+ - dataset:
42
+ id: allenai/olmOCR-bench
43
+ task_id: old_scans
44
+ value: 43.5
45
+ source:
46
+ url: https://huggingface.co/tiiuae/Falcon-OCR
47
+ name: Falcon-OCR Model Card
48
+ user: nielsr
49
+ notes: English-subset only
50
+
51
+ - dataset:
52
+ id: allenai/olmOCR-bench
53
+ task_id: headers_footers
54
+ value: 94.0
55
+ source:
56
+ url: https://huggingface.co/tiiuae/Falcon-OCR
57
+ name: Falcon-OCR Model Card
58
+ user: nielsr
59
+ notes: English-subset only
60
+
61
+ - dataset:
62
+ id: allenai/olmOCR-bench
63
+ task_id: multi_column
64
+ value: 87.1
65
+ source:
66
+ url: https://huggingface.co/tiiuae/Falcon-OCR
67
+ name: Falcon-OCR Model Card
68
+ user: nielsr
69
+ notes: English-subset only
70
+
71
+ - dataset:
72
+ id: allenai/olmOCR-bench
73
+ task_id: long_tiny_text
74
+ value: 78.5
75
+ source:
76
+ url: https://huggingface.co/tiiuae/Falcon-OCR
77
+ name: Falcon-OCR Model Card
78
+ user: nielsr
79
+ notes: English-subset only
80
+
81
+ - dataset:
82
+ id: allenai/olmOCR-bench
83
+ task_id: baseline
84
+ value: 99.5
85
+ source:
86
+ url: https://huggingface.co/tiiuae/Falcon-OCR
87
+ name: Falcon-OCR Model Card
88
+ user: nielsr
89
+ notes: English-subset only
.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,456 @@
1
+ ---
2
+ pipeline_tag: image-to-text
3
+ library_name: transformers
4
+ tags:
5
+ - falcon
6
+ - ocr
7
+ - vision-language
8
+ - document-understanding
9
+ ---
10
+
11
+
12
+ <div style="width: 480px; text-align: left;">
13
+ <img src="https://cdn-uploads.huggingface.co/production/uploads/663c9939c1b4f7297c4ae6f6/YIuxzgDiV5T2ZuSB4bam9.png" alt="Falcon OCR Logo" style="max-width: 100%; height: auto;">
14
+ </div>
15
+
16
+ # Falcon OCR
17
+
18
+ Falcon OCR is a 300M-parameter early-fusion vision-language model for document OCR. Given an image, it can produce plain text, LaTeX for formulas, or HTML for tables, depending on the requested output format.
19
+
20
+ Most OCR VLM systems are built as a pipeline with a vision encoder feeding a separate text decoder, plus additional task-specific glue. Falcon OCR takes a different approach: a single Transformer processes image patches and text tokens in a shared parameter space from the first layer, using a hybrid attention mask where image tokens attend bidirectionally and text tokens decode causally conditioned on the image.
21
+
22
+ We built it this way for two practical reasons. First, it keeps the interface simple: one backbone, one decoding path, and task switching through prompts rather than a growing set of modules. Second, a 0.3B model has a lower latency and cost footprint than 0.9B-class OCR VLMs, and in our vLLM-based serving setup this translates into higher throughput, often 2–3× faster depending on sequence lengths and batch configuration. To our knowledge, this is one of the first attempts to apply this early-fusion single-stack recipe directly to competitive document OCR at this scale.
23
+
24
+ ### Links
25
+
26
+ - Code and inference engine: [https://github.com/tiiuae/Falcon-Perception](https://github.com/tiiuae/Falcon-Perception)
27
+ - Tech report: arXiv link coming soon
28
+ - Perception model: `tiiuae/falcon-perception`
29
+ - vLLM/Docker: [https://ghcr.io/tiiuae/falcon-ocr:latest](https://ghcr.io/tiiuae/falcon-ocr:latest)
30
+
31
+ ## Quickstart
32
+
33
+ ### Installation
34
+
35
+ ```bash
36
+ pip install "torch>=2.5" transformers pillow einops
37
+ ```
38
+
39
+ Falcon OCR requires PyTorch 2.5 or newer for FlexAttention. The first call may be slower as `torch.compile` builds optimized kernels.
40
+
41
+ ### Single-Image OCR
42
+
43
+ ```python
44
+ import torch
45
+ from PIL import Image
46
+ from transformers import AutoModelForCausalLM
47
+ model = AutoModelForCausalLM.from_pretrained(
+     "tiiuae/Falcon-OCR",
+     trust_remote_code=True,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+ )
53
+ image = Image.open("document.png")
54
+ texts = model.generate(image) # default category is "plain"
55
+ print(texts[0])
56
+ ```
57
+
58
+ ### Choose an output format with `category`
59
+
60
+ ```python
61
+ texts = model.generate(image, category="text") # plain text
62
+ texts = model.generate(image, category="formula") # LaTeX
63
+ texts = model.generate(image, category="table") # HTML table
64
+ ```
65
+
66
+ ## API
67
+
68
+ ### `model.generate(images, category="plain", **kwargs)`
69
+ - **Inputs**:
70
+ - `images`: a `PIL.Image.Image` or a list of images
71
+ - `category`: one of `plain`, `text`, `table`, `formula`, `caption`, `footnote`, `list-item`, `page-footer`, `page-header`, `section-header`, `title`
72
+ - **Returns**: `list[str]`, one extracted string per image
73
+
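+ Because `images` accepts a list, a whole batch can be processed in one call. A minimal sketch (file names are placeholders):
+
+ ```python
+ from PIL import Image
+
+ paths = ["page1.png", "page2.png"]
+ pages = [Image.open(p) for p in paths]
+ texts = model.generate(pages, category="text")  # list[str], one result per image, in input order
+ for path, text in zip(paths, texts):
+     print(path, "->", text[:80])
+ ```
+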
74
+ ## Layout OCR (Two-Stage Pipeline)
75
+ For sparse documents, running OCR on the whole image can work well. For dense documents with heterogeneous regions (multi-column layouts, interleaved tables and formulas, small captions), we provide an optional two-stage pipeline:
76
+ 1. A layout detector finds regions on the page.
77
+ 2. Falcon OCR runs independently on each crop with a category-specific prompt.
78
+ We use [PP-DocLayoutV3](https://huggingface.co/PaddlePaddle/PP-DocLayoutV3_safetensors) as the layout detector.
79
+ ```python
80
+ results = model.generate_with_layout(image)
81
+ for det in results[0]:
82
+     print(f"[{det['category']}] {det['text'][:100]}...")
83
+ ```
84
+ Batch mode:
85
+ ```python
86
+ results = model.generate_with_layout(
87
+     [Image.open("page1.png"), Image.open("page2.png")],
+     ocr_batch_size=32,
89
+ )
90
+ ```
91
+ The layout model is loaded lazily on the first `generate_with_layout()` call and runs on the same GPU as the OCR model.
92
+ **Returns**: `list[list[dict]]`, one list per image, in reading order:
93
+ ```python
94
+ {
95
+ "category": "text", # layout category
96
+ "bbox": [x1, y1, x2, y2], # in original image pixels
97
+ "score": 0.93, # detection confidence
98
+ "text": "..." # extracted text
99
+ }
100
+ ```
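+
+ For downstream use, the per-region results can be stitched back into a single string in reading order. A minimal sketch built on the schema above (the joining rule is illustrative, not part of the API):
+
+ ```python
+ def detections_to_text(detections: list[dict]) -> str:
+     """Concatenate extracted regions in reading order, skipping empty ones (illustrative helper)."""
+     parts = [det["text"].strip() for det in detections if det["text"].strip()]
+     return "\n\n".join(parts)
+
+ page_text = detections_to_text(results[0])  # results from generate_with_layout(...)
+ ```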
101
+
102
+ ## When to Use What
103
+
104
+ | Mode | Best for | How |
105
+ |------|----------|-----|
106
+ | **Plain OCR** | Simple documents, real-world photos, slides, receipts, invoices | `model.generate(image)` |
107
+ | **Layout + OCR** | Complex multi-column documents, academic papers, reports, dense pages like newspapers | `model.generate_with_layout(image)` |
108
+
109
+ ## Benchmark Results
110
+ <details name="benchmarks" open>
111
+ <summary><b>olmOCR Benchmark</b></summary>
112
+
113
+ Category-wise performance comparison of FalconOCR against state-of-the-art OCR models. We report accuracy (%) across all category splits.
114
+
115
+ <table>
116
+ <tr><th>Model</th><th>Average</th><th>ArXiv Math</th><th>Base</th><th>Hdr/Ftr</th><th>TinyTxt</th><th>MultCol</th><th>OldScan</th><th>OldMath</th><th>Tables</th></tr>
117
+ <tr><td>Mistral OCR 3</td><td>81.7</td><td><b>85.4</b></td><td><b>99.9</b></td><td>93.8</td><td>88.9</td><td>82.1</td><td>48.8</td><td>68.3</td><td>86.1</td></tr>
118
+ <tr><td>Chandra</td><td><b>82.0</b></td><td>81.4</td><td>99.8</td><td>88.8</td><td><b>91.9</b></td><td>82.9</td><td><b>49.2</b></td><td>73.6</td><td>88.2</td></tr>
119
+ <tr><td>Gemini 3 Pro</td><td>80.2</td><td>70.6</td><td>99.8</td><td>84.0</td><td>90.3</td><td>79.2</td><td>47.5</td><td>84.9</td><td>84.9</td></tr>
120
+ <tr><td>PaddleOCR VL 1.5</td><td>79.3</td><td><b>85.4</b></td><td>98.8</td><td><b>96.9</b></td><td>80.8</td><td>82.6</td><td>39.2</td><td>66.4</td><td>84.1</td></tr>
121
+ <tr><td>PaddleOCR VL</td><td>79.2</td><td><b>85.4</b></td><td>98.6</td><td><b>96.9</b></td><td>80.8</td><td>82.5</td><td>38.8</td><td>66.4</td><td>83.9</td></tr>
122
+ <tr><td>DeepSeek OCR v2</td><td>78.8</td><td>81.9</td><td>99.8</td><td>95.6</td><td>88.7</td><td>83.6</td><td>33.7</td><td>68.8</td><td>78.1</td></tr>
123
+ <tr><td>Gemini 3 Flash</td><td>77.5</td><td>66.5</td><td>99.8</td><td>83.8</td><td>88.2</td><td>73.7</td><td>46.0</td><td><b>85.8</b></td><td>75.9</td></tr>
124
+ <tr><td>GPT 5.2</td><td>69.8</td><td>61.0</td><td>99.8</td><td>75.6</td><td>62.2</td><td>70.2</td><td>34.6</td><td>75.8</td><td>79.0</td></tr>
125
+ <tr style="background:#a358e5; color:white"><td><b>FalconOCR</b></td><td>80.3</td><td>80.5</td><td>99.5</td><td>94.0</td><td>78.5</td><td><b>87.1</b></td><td>43.5</td><td>69.2</td><td><b>90.3</b></td></tr>
126
+ </table>
127
+
128
+ </details>
129
+
130
+ <details name="benchmarks">
131
+ <summary><b>OmniDocBench</b></summary>
132
+
133
+ Performance comparison on full-page document parsing. Overall↑ aggregates the three sub-metrics. Edit↓ measures text edit distance (lower is better). CDM↑ evaluates formula recognition accuracy. TEDS↑ measures table structure similarity.
134
+
135
+ <table>
136
+ <tr><th>Model</th><th>Overall↑</th><th>Edit↓</th><th>CDM↑</th><th>TEDS↑</th></tr>
137
+ <tr><td>PaddleOCR VL 1.5</td><td><b>94.37</b></td><td>0.025</td><td><b>94.4</b></td><td><b>91.1</b></td></tr>
138
+ <tr><td>PaddleOCR VL</td><td>91.76</td><td><b>0.024</b></td><td>91.7</td><td>85.9</td></tr>
139
+ <tr><td>Chandra</td><td>88.97</td><td>0.046</td><td>88.1</td><td>89.5</td></tr>
140
+ <tr><td>DeepSeek OCR v2</td><td>87.66</td><td>0.037</td><td>89.2</td><td>77.5</td></tr>
141
+ <tr><td>GPT 5.2</td><td>86.56</td><td>0.061</td><td>88.0</td><td>77.7</td></tr>
142
+ <tr><td>Mistral OCR 3</td><td>85.20</td><td>0.053</td><td>84.3</td><td>76.1</td></tr>
143
+ <tr style="background:#a358e5; color:white"><td><b>FalconOCR</b></td><td>88.64</td><td>0.055</td><td>86.8</td><td>84.6</td></tr>
144
+ </table>
145
+
146
+ </details>
147
+
148
+ ### Results Analysis
149
+
150
+ Two observations stand out. First, a compact model can be competitive when the interface is simple and the training signal is targeted: on olmOCR, Falcon OCR performs strongly on multi-column documents and tables, and is competitive overall against substantially larger systems. Second, evaluation on full-page parsing is sensitive to matching and representation details: on OmniDocBench, the table and formula metrics depend not only on recognition quality but also on how predicted elements are matched to ground truth and how output structure is normalized.
151
+
152
+ More broadly, these results suggest that an early-fusion single-stack Transformer can be a viable alternative to the common "vision encoder plus text decoder" recipe for OCR. We do not view this as a finished answer, but as a promising direction: one early-fusion backbone, a shared parameter space between text and images, a single decoding interface, and better data and training signals, rather than increasingly complex pipelines. To our knowledge, this is among the first demonstrations that this early-fusion recipe can reach competitive document OCR accuracy at this scale, and we hope it encourages further work in this direction.
153
+
154
+ ## Serving Throughput
155
+
156
+ Measured on a single A100-80GB GPU with vLLM, processing document images from olmOCR-Bench under high concurrency to keep vLLM fully utilized.
+
+ - **Layout + OCR** — the full end-to-end pipeline: layout detection finds regions on each page, crops them, and vLLM runs OCR on every crop. This reflects real-world serving throughput, inclusive of both layout-detection and OCR time.
+
+ | Mode | tok/s | img/s | Description |
+ |------|------:|------:|-------------|
+ | **Layout + OCR** | 5,825 | 2.9 | Full pipeline: layout detection → crop → per-region OCR |
167
+
168
+ At 0.3B parameters, Falcon OCR is roughly 3× smaller than 0.9B-class OCR VLMs (e.g., PaddleOCR VL), which translates directly into higher serving throughput at competitive accuracy.
169
+
170
+ ## Limitations
171
+
172
+ - **Old scans and tiny text**: Heavily degraded scans and very small glyphs remain challenging. These cases often require higher effective resolution and better coverage in the training mixture.
173
+ - **Non-unique table representations**: Visually identical tables can be encoded in structurally different HTML forms, which can affect tree-based metrics.
174
+ - **Formula matching sensitivity**: LaTeX and Unicode conventions can be penalized differently depending on the benchmark normalization and matching pipeline.
175
+
176
+ ## Examples
177
+
178
+ *Click each section below to expand.*
179
+
180
+ <details name="ocr-examples" open>
181
+ <summary><b>Handwriting and Real World Images</b></summary>
182
+ <p align="center">
183
+ <img src="https://cdn-uploads.huggingface.co/production/uploads/62fe441427c98b09b503a4e3/51Fj1wxxtAV_jwubml6sa.png" width="600" alt="Handwriting and real world OCR examples" />
184
+ </p>
185
+ </details>
186
+
187
+ <details name="ocr-examples">
188
+ <summary><b>Tables</b></summary>
189
+ <p align="center">
190
+ <img src="https://cdn-uploads.huggingface.co/production/uploads/62fe441427c98b09b503a4e3/2yZjZJAEHVVpd_jfyyDcQ.png" width="600" alt="Table OCR examples" />
191
+ </p>
192
+ </details>
193
+
194
+ <details name="ocr-examples">
195
+ <summary><b>Formulas</b></summary>
196
+ <p align="center">
197
+ <img src="https://cdn-uploads.huggingface.co/production/uploads/62fe441427c98b09b503a4e3/__XMb0GyGO02IPKlQsPQx.png" width="600" alt="Formula OCR examples" />
198
+ </p>
199
+ </details>
200
+
201
+ <details name="ocr-examples">
202
+ <summary><b>Complex Layout</b></summary>
203
+ <p align="center">
204
+ <img src="https://cdn-uploads.huggingface.co/production/uploads/62fe441427c98b09b503a4e3/kTR7nI7ogEqdI1SQtXpTu.png" width="600" alt="Complex layout OCR examples" />
205
+ </p>
206
+ </details>
207
+
208
+
209
+ ---
210
+
211
+ ## vLLM Server
212
+ We also provide a Docker-based vLLM-backed inference server capable of serving approximately 6,000 tokens per second.
213
+
214
+ Single Docker image with two services:
215
+
216
+ | Service | Default Port | Description |
217
+ |---------|-------------|-------------|
218
+ | **vLLM** | 8000 | Falcon-OCR vision-language model (OpenAI-compatible API) |
219
+ | **Pipeline** | 5002 | Full document parsing: layout detection → crop → OCR → markdown |
220
+
221
+ The layout model runs inside the pipeline process — it is not a standalone service.
222
+
223
+ ### Quick Start
224
+
225
+ ```bash
226
+ docker run -d --name falcon-ocr \
227
+ --gpus '"device=0,1"' \
228
+ -e EXPOSED_GPU_IDS=0,1 \
229
+ -e VLLM_GPU=0 \
230
+ -e PIPELINE_GPU=1 \
231
+ -e VLLM_GPU_MEM_UTIL=0.90 \
232
+ -p 8000:8000 \
233
+ -p 5002:5002 \
234
+ ghcr.io/tiiuae/falcon-ocr:latest
235
+ ```
236
+
237
+ ### API
238
+
239
+ <details name="api" open>
240
+ <summary><b>Health Checks</b></summary>
241
+
242
+ ```bash
243
+ curl http://localhost:8000/health # vLLM
244
+ curl http://localhost:5002/health # Pipeline
245
+ ```
246
+ </details>
247
+
248
+ <details name="api">
249
+ <summary><b>Upload</b> (multipart file upload — images and PDFs)</summary>
250
+
251
+ The easiest way to send files. Supports images and multi-page PDFs:
252
+
253
+ ```bash
254
+ # Single image
255
+ curl -X POST http://localhost:5002/falconocr/upload \
256
+ -F "files=@photo.jpg;type=image/jpeg"
257
+ # PDF document
258
+ curl -X POST http://localhost:5002/falconocr/upload \
259
+ -F "files=@document.pdf;type=application/pdf"
260
+ ```
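+
+ The same upload can be scripted from Python. A minimal sketch, assuming the `requests` package; file name and timeout are placeholders:
+
+ ```python
+ import requests
+
+ with open("document.pdf", "rb") as f:
+     resp = requests.post(
+         "http://localhost:5002/falconocr/upload",
+         files={"files": ("document.pdf", f, "application/pdf")},
+         timeout=300,
+     )
+ resp.raise_for_status()
+ print(resp.json())
+ ```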
261
+ </details>
262
+
263
+ <details name="api">
264
+ <summary><b>Parse</b> (full pipeline: layout + OCR)</summary>
265
+
266
+ Send base64-encoded images for layout detection, cropping, and OCR:
267
+
268
+ ```bash
269
+ curl -X POST http://localhost:5002/falconocr/parse \
270
+ -H "Content-Type: application/json" \
271
+ -d '{
272
+ "images": ["data:image/jpeg;base64,<...>"],
273
+ "skip_layout": false
274
+ }'
275
+ ```
276
+
277
+ Response:
278
+
279
+ ```json
280
+ {
281
+ "json_result": [[{
282
+ "index": 0,
283
+ "mapped_label": "text",
284
+ "content": "The Manuscript",
285
+ "bbox": [273, 273, 937, 380],
286
+ "score": 0.3145
287
+ }]],
288
+ "markdown_result": "The Manuscript",
289
+ "total_output_tokens": 93,
290
+ "processing_time_ms": 414
291
+ }
292
+ ```
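+
+ The same request from Python, as a minimal sketch assuming the `requests` package (file name and timeout are placeholders):
+
+ ```python
+ import base64
+ import requests
+
+ with open("page.png", "rb") as f:
+     b64 = base64.b64encode(f.read()).decode()
+
+ resp = requests.post(
+     "http://localhost:5002/falconocr/parse",
+     json={"images": [f"data:image/png;base64,{b64}"], "skip_layout": False},
+     timeout=300,
+ )
+ resp.raise_for_status()
+ print(resp.json()["markdown_result"])
+ ```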
293
+ </details>
294
+
295
+ <details name="api">
296
+ <summary><b>Parse</b> (direct VLM, no layout)</summary>
297
+
298
+ Skip layout detection and send the full image directly to the VLM:
299
+
300
+ ```bash
301
+ curl -X POST http://localhost:5002/falconocr/parse \
302
+ -H "Content-Type: application/json" \
303
+ -d '{
304
+ "images": ["data:image/jpeg;base64,<...>"],
305
+ "skip_layout": true
306
+ }'
307
+ ```
308
+ </details>
309
+
310
+ <details name="api">
311
+ <summary><b>Direct vLLM</b> (OpenAI-compatible)</summary>
312
+
313
+ ```bash
314
+ curl -X POST http://localhost:8000/v1/chat/completions \
315
+ -H "Content-Type: application/json" \
316
+ -d '{
317
+ "model": "falcon-ocr",
318
+ "messages": [{"role": "user", "content": [
319
+ {"type": "image_url", "image_url": {"url": "data:image/png;base64,<...>"}},
320
+ {"type": "text", "text": "Extract the text content from this image.\n<|OCR_PLAIN|>"}
321
+ ]}],
322
+ "max_tokens": 2048
323
+ }'
324
+ ```
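+
+ Because the endpoint is OpenAI-compatible, the official `openai` Python client can be pointed at it. A minimal sketch; the API key is a dummy value unless the server was started with one:
+
+ ```python
+ import base64
+ from openai import OpenAI
+
+ client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
+
+ with open("page.png", "rb") as f:
+     b64 = base64.b64encode(f.read()).decode()
+
+ resp = client.chat.completions.create(
+     model="falcon-ocr",
+     messages=[{"role": "user", "content": [
+         {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
+         {"type": "text", "text": "Extract the text content from this image.\n<|OCR_PLAIN|>"},
+     ]}],
+     max_tokens=2048,
+ )
+ print(resp.choices[0].message.content)
+ ```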
325
+ </details>
326
+
327
+ ### Configuration
328
+
329
+ All settings are controlled via environment variables at `docker run` time.
330
+
331
+ <details name="config" open>
332
+ <summary><b>GPU Assignment</b></summary>
333
+
334
+ | Variable | Default | Description |
335
+ |----------|---------|-------------|
336
+ | `VLLM_GPU` | `0` | Host GPU ID for the vLLM process |
337
+ | `PIPELINE_GPU` | `0` | Host GPU ID for the pipeline (layout model) |
338
+ | `EXPOSED_GPU_IDS` | *(all visible)* | Comma-separated host GPU IDs passed via `--gpus` (for index remapping) |
339
+ </details>
340
+
341
+ <details name="config">
342
+ <summary><b>Port Assignment</b></summary>
343
+
344
+ | Variable | Default | Description |
345
+ |----------|---------|-------------|
346
+ | `VLLM_PORT` | `8000` | Port for the vLLM OpenAI-compatible API |
347
+ | `PIPELINE_PORT` | `5002` | Port for the pipeline API |
348
+ </details>
349
+
350
+ <details name="config">
351
+ <summary><b>vLLM Tuning</b></summary>
352
+
353
+ | Variable | Default | Description |
354
+ |----------|---------|-------------|
355
+ | `VLLM_GPU_MEM_UTIL` | `0.90` | Fraction of GPU memory vLLM can use |
356
+ | `MAX_NUM_SEQS` | `2048` | Max concurrent sequences in vLLM |
357
+ | `MAX_MODEL_LEN` | `8192` | Max model context length |
358
+ | `DTYPE` | `bfloat16` | Model dtype |
359
+ | `MAX_NUM_BATCHED_TOKENS` | *(auto)* | Max batched tokens per iteration |
360
+ | `CHUNKED_PREFILL` | `false` | Enable chunked prefill |
361
+ </details>
362
+
363
+ <details name="config">
364
+ <summary><b>Layout Model Tuning</b></summary>
365
+
366
+ | Variable | Default | Description |
367
+ |----------|---------|-------------|
368
+ | `LAYOUT_BATCH_SIZE` | `64` | Batch size for layout detection inference |
369
+ </details>
370
+
371
+ <details name="config">
372
+ <summary><b>Model Paths</b></summary>
373
+
374
+ | Variable | Default | Description |
375
+ |----------|---------|-------------|
376
+ | `FALCON_OCR_MODEL` | `/models/Falcon-OCR` | Path to Falcon-OCR VLM weights (inside container) |
377
+ | `SERVED_MODEL_NAME` | `falcon-ocr` | Model name exposed by vLLM API |
378
+ </details>
379
+
380
+ ### Deployment Modes
381
+
382
+ <details name="deploy" open>
383
+ <summary><b>Two GPUs</b> (best throughput)</summary>
384
+
385
+ vLLM on one GPU, layout model on another — zero GPU contention:
386
+
387
+ ```bash
388
+ docker run -d --name falcon-ocr \
389
+ --gpus '"device=3,4"' \
390
+ -e EXPOSED_GPU_IDS=3,4 \
391
+ -e VLLM_GPU=3 \
392
+ -e PIPELINE_GPU=4 \
393
+ -e VLLM_GPU_MEM_UTIL=0.90 \
394
+ -p 8000:8000 \
395
+ -p 5002:5002 \
396
+ ghcr.io/tiiuae/falcon-ocr:latest
397
+ ```
398
+ </details>
399
+
400
+ <details name="deploy">
401
+ <summary><b>Single GPU</b> (memory sharing)</summary>
402
+
403
+ Both services share one GPU — tune `VLLM_GPU_MEM_UTIL` to leave room for the layout model:
404
+
405
+ ```bash
406
+ docker run -d --name falcon-ocr \
407
+ --gpus '"device=0"' \
408
+ -e EXPOSED_GPU_IDS=0 \
409
+ -e VLLM_GPU=0 \
410
+ -e PIPELINE_GPU=0 \
411
+ -e VLLM_GPU_MEM_UTIL=0.55 \
412
+ -e LAYOUT_BATCH_SIZE=32 \
413
+ -e MAX_NUM_SEQS=512 \
414
+ -p 8000:8000 \
415
+ -p 5002:5002 \
416
+ ghcr.io/tiiuae/falcon-ocr:latest
417
+ ```
418
+ </details>
419
+
420
+ <details name="deploy">
421
+ <summary><b>Custom Ports</b></summary>
422
+
423
+ ```bash
424
+ docker run -d --name falcon-ocr \
425
+ --gpus '"device=0,1"' \
426
+ -e EXPOSED_GPU_IDS=0,1 \
427
+ -e VLLM_GPU=0 \
428
+ -e PIPELINE_GPU=1 \
429
+ -e VLLM_PORT=18000 \
430
+ -e PIPELINE_PORT=15002 \
431
+ -p 18000:18000 \
432
+ -p 15002:15002 \
433
+ ghcr.io/tiiuae/falcon-ocr:latest
434
+ ```
435
+
436
+ Docker `--gpus "device=3,4"` makes the container see GPUs as local indices `0,1`.
437
+ `EXPOSED_GPU_IDS=3,4` allows you to reference host GPU IDs (`VLLM_GPU=3`, `PIPELINE_GPU=4`);
438
+ the entrypoint remaps them to the correct container-local indices.
439
+ </details>
440
+
441
+
442
+ ## Citation
443
+
444
+ If you use Falcon OCR, please cite:
445
+
446
+ ```bibtex
447
+ @misc{falconocr2026,
448
+   title        = {Falcon OCR},
+   author       = {TII Falcon Vision Team},
+   year         = {2026},
+   howpublished = {arXiv preprint, link forthcoming},
+   note         = {Code: https://github.com/tiiuae/Falcon-Perception},
453
+ }
454
+ ```
455
+
456
+
attention.py ADDED
@@ -0,0 +1,129 @@
1
+ import torch
2
+ from torch import Tensor as T
3
+ from torch.nn.attention.flex_attention import (
4
+ BlockMask,
5
+ _mask_mod_signature,
6
+ and_masks,
7
+ create_block_mask,
8
+ flex_attention,
9
+ or_masks,
10
+ )
11
+
12
+ # ---------------------------------------------------------------------------
13
+ # Two compiled variants of flex_attention
14
+ # ---------------------------------------------------------------------------
15
+ # _decode: fullgraph=True, static shapes.
16
+ # Used for decode steps (S_q == 1) where shapes are fixed and
17
+ # the call will be captured inside a CUDA graph. fullgraph=True
18
+ # avoids graph breaks that would corrupt the capture.
19
+ #
20
+ # _prefill: dynamic=True, symbolic shapes.
21
+ # Used for prefill steps (S_q > 1) where the sequence length
22
+ # varies per image. dynamic=True lets one compiled graph handle
23
+ # all lengths without recompilation. Prefill is never inside a
24
+ # CUDA graph, so symbolic shape guards are fine.
25
+ compiled_flex_attn_decode = torch.compile(flex_attention, fullgraph=True)
26
+ compiled_flex_attn_prefill = torch.compile(flex_attention, dynamic=True)
27
+
28
+
29
+ def offset_mask_mod(mask_mod: _mask_mod_signature, offset: int):
30
+ """Get a mask mod function with an offset applied to the query positions."""
31
+
32
+ def _mask_mod(b, h, q, kv):
33
+ return mask_mod(b, h, q + offset, kv)
34
+
35
+ return _mask_mod
36
+
37
+
38
+ def get_causal_mask_mod() -> _mask_mod_signature:
39
+ """Causal mask that prevents attention to future tokens."""
40
+
41
+ def _causal_mask(b: T, h: T, q_idx: T, kv_idx: T) -> T:
42
+ return q_idx >= kv_idx
43
+
44
+ return _causal_mask
45
+
46
+
47
+ def get_document_mask_mod(batch: T, eos_id: int) -> _mask_mod_signature:
48
+ """Document mask: prevents attention across document boundaries (token IDs [B, S])."""
49
+ eos_mask = batch == eos_id
50
+ eos_mask[:, -1] = True
51
+ cumulative_mask = torch.cumsum(torch.where(eos_mask, 1, 0), dim=1)
52
+ sequence_indices = torch.zeros_like(cumulative_mask, dtype=torch.int32)
53
+ sequence_indices[:, 1:] = cumulative_mask[:, :-1]
54
+
55
+ def document_mask(b: T, h: T, q_idx: T, kv_idx: T) -> T:
56
+ return sequence_indices[b, q_idx] == sequence_indices[b, kv_idx]
57
+
58
+ return document_mask
59
+
60
+
61
+ def get_non_left_pad_mask_mod(batch: T, pad_id: int) -> _mask_mod_signature:
62
+ """Prevent model from attending to the left-padded token required for correct batch inference."""
63
+
64
+ non_pad_mask_id = torch.cumsum(batch != pad_id, dim=1)
65
+
66
+ # Left-most pad tokens have cumulative id == 0.
67
+ def mask_mod(b, h, q_idx, kv_idx):
68
+ return non_pad_mask_id[b, kv_idx] > 0
69
+
70
+ return mask_mod
71
+
72
+
73
+ def get_image_prefix_mask_mod(
74
+ batch: T, soi_id: int, eoi_id: int
75
+ ) -> _mask_mod_signature:
76
+ """Image-prefix mask: tokens between SOI and EOI attend only within same image."""
77
+ soi_mask = batch == soi_id
78
+ eoi_mask = batch == eoi_id
79
+ acc_soi_mask = torch.cumsum(soi_mask, dim=1)
80
+ acc_eoi_mask = torch.cumsum(eoi_mask, dim=1)
81
+ img_mask = (acc_soi_mask - acc_eoi_mask) > 0
82
+ img_indices = acc_soi_mask * img_mask
83
+
84
+ def image_prefix_mask_mod(b, h, q_idx, kv_idx):
85
+ is_img_tokens = img_mask[b, q_idx] & img_mask[b, kv_idx]
86
+ is_same_image = img_indices[b, q_idx] == img_indices[b, kv_idx]
87
+ return is_img_tokens & is_same_image
88
+
89
+ return image_prefix_mask_mod
90
+
91
+
92
+ _compiled_create_block_mask = torch.compile(
93
+ create_block_mask, dynamic=True
94
+ ) # reduce-overhead mode breaks manual CUDA graph capture (private streams)
95
+
96
+
97
+ @torch.inference_mode()
98
+ def create_attention_mask(*args, **kwargs) -> BlockMask:
99
+ """Compiled for large masks; inference_mode avoids grad_mode recompiles."""
100
+ return _compiled_create_block_mask(*args, **kwargs)
101
+
102
+
103
+ def create_batch_attention_mask(
104
+ input_batch: T,
105
+ *,
106
+ pad_token_id: int,
107
+ eos_token_id: int,
108
+ soi_token_id: int,
109
+ eoi_token_id: int,
110
+ max_len: int | None = None,
111
+ ) -> BlockMask:
112
+ """Build the combined FlexAttention mask for the batch engine.
113
+
114
+ Composes causal + document + non-left-pad + image-prefix masks.
115
+ """
116
+ B, S = input_batch.size()
117
+ block_causal_mask_mod = and_masks(
118
+ get_causal_mask_mod(),
119
+ get_document_mask_mod(input_batch, eos_token_id),
120
+ get_non_left_pad_mask_mod(input_batch, pad_token_id),
121
+ )
122
+ image_prefix_mask_mod = get_image_prefix_mask_mod(
123
+ batch=input_batch,
124
+ soi_id=soi_token_id,
125
+ eoi_id=eoi_token_id,
126
+ )
127
+ mask_mod = or_masks(image_prefix_mask_mod, block_causal_mask_mod)
128
+ max_len = max_len or S
129
+ return create_attention_mask(mask_mod, B, None, max_len, max_len)
config.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "architectures": [
3
+ "FalconOCRForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_falcon_ocr.FalconOCRConfig",
7
+ "AutoModelForCausalLM": "modeling_falcon_ocr.FalconOCRForCausalLM"
8
+ },
9
+ "model_type": "falcon_ocr",
10
+ "torch_dtype": "float32",
11
+ "dim": 768,
12
+ "n_layers": 22,
13
+ "n_heads": 16,
14
+ "head_dim": 64,
15
+ "n_kv_heads": 8,
16
+ "vocab_size": 65536,
17
+ "ffn_dim": 2304,
18
+ "norm_eps": 1e-05,
19
+ "max_seq_len": 8192,
20
+ "rope_theta": 10000,
21
+ "channel_size": 3,
22
+ "spatial_patch_size": 16,
23
+ "temporal_patch_size": 1,
24
+ "eos_id": 11,
25
+ "img_id": 227,
26
+ "image_cls_token_id": 244,
27
+ "image_reg_1_token_id": 245,
28
+ "image_reg_2_token_id": 246,
29
+ "image_reg_3_token_id": 247,
30
+ "image_reg_4_token_id": 248,
31
+ "img_end_id": 230
32
+ }
configuration_falcon_ocr.py ADDED
@@ -0,0 +1,65 @@
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class FalconOCRConfig(PretrainedConfig):
5
+ model_type = "falcon_ocr"
6
+
7
+ def __init__(
8
+ self,
9
+ dim: int = 768,
10
+ n_layers: int = 22,
11
+ n_heads: int = 16,
12
+ head_dim: int = 64,
13
+ n_kv_heads: int = 8,
14
+ vocab_size: int = 65536,
15
+ ffn_dim: int = 2304,
16
+ norm_eps: float = 1e-5,
17
+ max_seq_len: int = 8192,
18
+ rope_theta: int = 10000,
19
+ channel_size: int = 3,
20
+ spatial_patch_size: int = 16,
21
+ temporal_patch_size: int = 1,
22
+ img_id: int = 227,
23
+ eos_id: int = 11,
24
+ image_cls_token_id: int = 244,
25
+ image_mask_token_id: int = 243,
26
+ image_reg_1_token_id: int = 245,
27
+ image_reg_2_token_id: int = 246,
28
+ image_reg_3_token_id: int = 247,
29
+ image_reg_4_token_id: int = 248,
30
+ img_start_id: int = 229,
31
+ img_end_id: int = 230,
32
+ img_row_sep_id: int = 228,
33
+ vid_start_id: int = 231,
34
+ vid_end_id: int = 232,
35
+ frame_sep_id: int = 233,
36
+ **kwargs,
37
+ ):
38
+ self.dim = dim
39
+ self.n_layers = n_layers
40
+ self.n_heads = n_heads
41
+ self.head_dim = head_dim
42
+ self.n_kv_heads = n_kv_heads
43
+ self.vocab_size = vocab_size
44
+ self.ffn_dim = ffn_dim
45
+ self.norm_eps = norm_eps
46
+ self.max_seq_len = max_seq_len
47
+ self.rope_theta = rope_theta
48
+ self.channel_size = channel_size
49
+ self.spatial_patch_size = spatial_patch_size
50
+ self.temporal_patch_size = temporal_patch_size
51
+ self.img_id = img_id
52
+ self.eos_id = eos_id
53
+ self.image_cls_token_id = image_cls_token_id
54
+ self.image_mask_token_id = image_mask_token_id
55
+ self.image_reg_1_token_id = image_reg_1_token_id
56
+ self.image_reg_2_token_id = image_reg_2_token_id
57
+ self.image_reg_3_token_id = image_reg_3_token_id
58
+ self.image_reg_4_token_id = image_reg_4_token_id
59
+ self.img_start_id = img_start_id
60
+ self.img_end_id = img_end_id
61
+ self.img_row_sep_id = img_row_sep_id
62
+ self.vid_start_id = vid_start_id
63
+ self.vid_end_id = vid_end_id
64
+ self.frame_sep_id = frame_sep_id
65
+ super().__init__(**kwargs)
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e7f73a508c050f3a4f0b5ce196dba36db896626e44c4f8976d3a3e3c18ceb2a
3
+ size 1079789440
model_args.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "channel_size": 3,
3
+ "coord_dec_dim": 8192,
4
+ "coord_enc_dim": 512,
5
+ "coord_out_dim": 2048,
6
+ "coord_token_id": 240,
7
+ "dim": 768,
8
+ "eos_id": 11,
9
+ "ffn_dim": 2304,
10
+ "head_dim": 64,
11
+ "image_cls_token_id": 244,
12
+ "image_reg_1_token_id": 245,
13
+ "image_reg_2_token_id": 246,
14
+ "image_reg_3_token_id": 247,
15
+ "image_reg_4_token_id": 248,
16
+ "img_end_id": 230,
17
+ "img_id": 227,
18
+ "img_row_sep_id": 228,
19
+ "img_start_id": 229,
20
+ "max_seq_len": 8192,
21
+ "n_heads": 16,
22
+ "n_kv_heads": 8,
23
+ "n_layers": 22,
24
+ "norm_eps": 1e-05,
25
+ "num_segm_layers": 3,
26
+ "perception_heads": false,
27
+ "rope_theta": 10000,
28
+ "seg_token_id": 262,
29
+ "segm_out_dim": 256,
30
+ "size_dec_dim": 8192,
31
+ "size_enc_dim": 512,
32
+ "size_out_dim": 2048,
33
+ "size_token_id": 241,
34
+ "spatial_patch_size": 16,
35
+ "temporal_patch_size": 1,
36
+ "vocab_size": 65536
37
+ }
modeling_falcon_ocr.py ADDED
@@ -0,0 +1,845 @@
1
+ from pathlib import Path
2
+
3
+ import einops as E
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import triton
7
+ import triton.language as tl
8
+ from PIL import Image
9
+ from torch import Tensor as T
10
+ from torch import nn
11
+ from torch.nn.attention.flex_attention import (
12
+ AuxRequest,
13
+ BlockMask,
14
+ )
15
+ from transformers import AutoTokenizer, PreTrainedModel
16
+
17
+ from .attention import (
18
+ compiled_flex_attn_decode,
19
+ compiled_flex_attn_prefill,
20
+ create_batch_attention_mask,
21
+ offset_mask_mod,
22
+ )
23
+ from .configuration_falcon_ocr import FalconOCRConfig
24
+ from .processing_falcon_ocr import load_image, process_batch
25
+ from .rope import (
26
+ apply_3d_rotary_emb,
27
+ apply_golden_freqs_cis_to_visual_pos,
28
+ precompute_freqs_cis,
29
+ )
30
+
31
+
32
+ CATEGORY_PROMPTS = {
33
+ "plain": "Extract the text content from this image.",
34
+ "formula": "Extract the formula content from this image.",
35
+ "table": "Extract the table content from this image.",
36
+ "text": "Extract the text content from this image.",
37
+ "caption": "Extract the caption content from this image.",
38
+ "footnote": "Extract the footnote content from this image.",
39
+ "list-item": "Extract the list-item content from this image.",
40
+ "page-footer": "Extract the page-footer content from this image.",
41
+ "page-header": "Extract the page-header content from this image.",
42
+ "section-header": "Extract the section-header content from this image.",
43
+ "title": "Extract the title content from this image.",
44
+ }
45
+
46
+ LAYOUT_TO_OCR_CATEGORY: dict[str, str | None] = {
47
+ "text": "text",
48
+ "table": "table",
49
+ "formula": "formula",
50
+ "caption": "caption",
51
+ "footnote": "footnote",
52
+ "list-item": "list-item",
53
+ "title": "title",
54
+ "header": "text",
55
+ "footer": "page-footer",
56
+ "number": "text",
57
+ "figure_title": "caption",
58
+ "paragraph_title": "section-header",
59
+ "doc_title": "title",
60
+ "reference_content": "text",
61
+ "reference": "text",
62
+ "abstract": "text",
63
+ "aside_text": "text",
64
+ "content": "text",
65
+ "formula_number": "text",
66
+ "vision_footnote": "footnote",
67
+ "algorithm": "text",
68
+ "page-footer": "page-footer",
69
+ "page-header": "page-header",
70
+ "section-header": "section-header",
71
+ # Skip — no text to extract
72
+ "image": None,
73
+ "picture": None,
74
+ "figure": None,
75
+ "chart": None,
76
+ "seal": None,
77
+ }
78
+
79
+ _LAYOUT_TARGET_H, _LAYOUT_TARGET_W = 800, 800
80
+ _MIN_CROP_DIM = 16
81
+
82
+ def _box_area(bbox):
83
+ return max(0, bbox[2] - bbox[0]) * max(0, bbox[3] - bbox[1])
84
+
85
+
86
+ def _intersection_area(a, b):
87
+ return max(0, min(a[2], b[2]) - max(a[0], b[0])) * max(0, min(a[3], b[3]) - max(a[1], b[1]))
88
+
89
+
90
+ def _containment_ratio(small, large):
91
+ area = _box_area(small)
92
+ if area <= 0:
93
+ return 0.0
94
+ return _intersection_area(small, large) / area
95
+
96
+
97
+ def _filter_nested_detections(detections: list[dict], containment_threshold: float = 0.8) -> list[dict]:
98
+ """Remove any box that is mostly contained within a strictly larger box."""
99
+ areas = [_box_area(d["bbox"]) for d in detections]
100
+ keep = []
101
+ for i, det in enumerate(detections):
102
+ is_nested = False
103
+ for j, other in enumerate(detections):
104
+ if i == j:
105
+ continue
106
+ if areas[j] <= areas[i]:
107
+ continue
108
+ if _containment_ratio(det["bbox"], other["bbox"]) > containment_threshold:
109
+ is_nested = True
110
+ break
111
+ if not is_nested:
112
+ keep.append(det)
113
+ return keep
114
+
115
+
116
+ # Attention
117
+
118
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
119
+ B, S, H, D = x.shape
120
+ if n_rep == 1:
121
+ return x
122
+ return torch.unsqueeze(x, dim=3).expand(B, S, H, n_rep, D).reshape(B, S, H * n_rep, D)
123
+
124
+
125
+ class Attention(nn.Module):
126
+ def __init__(self, config: FalconOCRConfig, layer_id: int):
127
+ super().__init__()
128
+ self.layer_id = layer_id
129
+ self.n_kv_heads = config.n_kv_heads or config.n_heads
130
+ self.n_rep = config.n_heads // self.n_kv_heads
131
+ self.head_dim = config.head_dim or config.dim // config.n_heads
132
+ self.q_dim = config.n_heads * self.head_dim
133
+ self.kv_dim = self.n_kv_heads * self.head_dim
134
+
135
+ self.wqkv = nn.Linear(config.dim, self.q_dim + 2 * self.kv_dim, bias=False)
136
+ self.wo = nn.Linear(config.n_heads * self.head_dim, config.dim, bias=False)
137
+ self.sinks = nn.Parameter(torch.empty((config.n_heads,)))
138
+
139
+ def _pre_attention_qkv(self, x) -> tuple[T, T, T]:
140
+ qkv = self.wqkv(F.rms_norm(x, (x.size(-1),)))
141
+ xq, xk, xv = qkv.split([self.q_dim, self.kv_dim, self.kv_dim], dim=-1)
142
+ xq = E.rearrange(xq, "b s (h d) -> b s h d", d=self.head_dim)
143
+ xk = E.rearrange(xk, "b s (h d) -> b s h d", d=self.head_dim)
144
+ xv = E.rearrange(xv, "b s (h d) -> b s h d", d=self.head_dim)
145
+ xq = F.rms_norm(xq, (xq.size(-1),))
146
+ xk = F.rms_norm(xk, (xk.size(-1),))
147
+ xk = repeat_kv(xk, n_rep=self.n_rep)
148
+ xv = repeat_kv(xv, n_rep=self.n_rep)
149
+ return xq, xk, xv
150
+
151
+ def _post_attention(self, output: T, lse: T) -> T:
152
+ # Sink-based scaling: sigmoid(lse - sinks) * output
153
+ # equivalent to prepending a sink token to the input
154
+ sinks_BHS = self.sinks.view(1, -1, 1)
155
+ sink_scale = torch.sigmoid(lse - sinks_BHS)
156
+ output = (output * sink_scale.unsqueeze(-1)).to(output.dtype)
157
+ output = output.permute(0, 2, 1, 3).contiguous().flatten(2)
158
+ return self.wo(output)
159
+
160
+ def compile_attention(self, *, dynamic: bool = True, mode: str = "default"):
161
+ self._pre_attention_qkv = torch.compile(self._pre_attention_qkv, dynamic=dynamic, mode=mode)
162
+ self._post_attention = torch.compile(self._post_attention, dynamic=dynamic, mode=mode)
163
+
164
+ def forward(
165
+ self, x: T, attention_masks: BlockMask, freqs_cis: T,
166
+ freqs_cis_2d: T | None = None, pos_hw: T | None = None,
167
+ kv_cache=None, input_pos=None, batch_idx=None,
168
+ flex_attn_kernel_options=None,
169
+ ):
170
+ xq, xk, xv = self._pre_attention_qkv(x)
171
+ xq, xk = apply_3d_rotary_emb(xq, xk, freqs_cis, freqs_cis_2d, pos_hw)
172
+ xq = E.rearrange(xq, "b s h d -> b h s d")
173
+ xk = E.rearrange(xk, "b s h d -> b h s d")
174
+ xv = E.rearrange(xv, "b s h d -> b h s d")
175
+ xk, xv = kv_cache.insert_kv(self.layer_id, xk, xv, input_pos=input_pos, batch_idx=batch_idx)
176
+ flex_fn = compiled_flex_attn_decode if xq.shape[2] == 1 else compiled_flex_attn_prefill
177
+ output, aux_output = flex_fn(xq, xk, xv, block_mask=attention_masks, return_aux=AuxRequest(lse=True))
178
+ return self._post_attention(output, aux_output.lse)
179
+
180
+
181
+ # FeedForward
182
+
183
+ @triton.jit
184
+ def _squared_relu_gate_kernel(
185
+ packed_ptr, out_ptr, n_rows, n_cols,
186
+ in_row_stride, in_col_stride, out_row_stride, out_col_stride,
187
+ BLOCK_SIZE: tl.constexpr,
188
+ ):
189
+ pid = tl.program_id(0)
190
+ n_elements = n_rows * n_cols
191
+ offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
192
+ mask = offsets < n_elements
193
+ rows = offsets // n_cols
194
+ cols = offsets % n_cols
195
+ gate_idx = rows * in_row_stride + (2 * cols) * in_col_stride
196
+ up_idx = rows * in_row_stride + (2 * cols + 1) * in_col_stride
197
+ out_idx = rows * out_row_stride + cols * out_col_stride
198
+ gate = tl.load(packed_ptr + gate_idx, mask=mask)
199
+ up = tl.load(packed_ptr + up_idx, mask=mask)
200
+ gate = tl.where(gate > 0, gate, 0.0)
201
+ out = gate * gate * up
202
+ tl.store(out_ptr + out_idx, out, mask=mask)
203
+
204
+
205
+ def squared_relu_gate(packed: T, hidden_dim: int) -> T:
206
+ """Processes interleaved [gate, up, gate, up, ...] from w13; output = ReLU(gate)^2 * up."""
207
+ packed_2d = packed.flatten(0, -2)
208
+ n_rows = packed_2d.shape[0]
209
+ n_cols = hidden_dim
210
+ out_2d = torch.empty((n_rows, n_cols), device=packed.device, dtype=packed.dtype)
211
+ n = n_rows * n_cols
212
+ grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
213
+ _squared_relu_gate_kernel[grid](
214
+ packed_2d, out_2d, n_rows, n_cols,
215
+ packed_2d.stride(0), packed_2d.stride(1),
216
+ out_2d.stride(0), out_2d.stride(1),
217
+ BLOCK_SIZE=1024,
218
+ )
219
+ return out_2d.view(*packed.shape[:-1], hidden_dim)
220
+
221
+
222
+ class FeedForward(nn.Module):
223
+ def __init__(self, dim: int, hidden_dim: int):
224
+ super().__init__()
225
+ self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False)
226
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
227
+ self.hidden_dim = hidden_dim
228
+
229
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
230
+ x = F.rms_norm(x, (x.size(-1),))
231
+ w13_out = self.w13(x)
232
+ return self.w2(squared_relu_gate(w13_out, self.hidden_dim))
233
+
234
+
235
+ # TransformerBlock
236
+
237
+ class TransformerBlock(nn.Module):
238
+ def __init__(self, layer_id: int, config: FalconOCRConfig):
239
+ super().__init__()
240
+ self.attention = Attention(config, layer_id)
241
+ self.feed_forward = FeedForward(config.dim, config.ffn_dim)
242
+
243
+ def compile(self, *, dynamic: bool = True, mode: str = "default"):
244
+ self.feed_forward = torch.compile(self.feed_forward, dynamic=dynamic, mode=mode)
245
+ self.attention.compile_attention(dynamic=dynamic, mode=mode)
246
+ return self
247
+
248
+ def forward(
249
+ self, x: T, freqs_cis: T, freqs_cis_2d: T | None = None,
250
+ pos_hw: T | None = None, attention_masks=None, kv_cache=None,
251
+ input_pos=None, batch_idx=None, flex_attn_kernel_options=None,
252
+ ):
253
+ B, S, D = x.shape
254
+ x = x + self.attention(
255
+ x, freqs_cis=freqs_cis, freqs_cis_2d=freqs_cis_2d, pos_hw=pos_hw,
256
+ attention_masks=attention_masks, kv_cache=kv_cache,
257
+ input_pos=input_pos, batch_idx=batch_idx,
258
+ flex_attn_kernel_options=flex_attn_kernel_options,
259
+ )
260
+ out = x + self.feed_forward(x)
261
+ return out.reshape(B, S, D)
262
+
263
+
264
+ # KV Cache
265
+
266
+ class KVCache:
267
+ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, num_layers):
268
+ self.kv_shape = (num_layers, 2, max_batch_size, n_heads, max_seq_length, head_dim)
269
+ self.kv_cache = None
270
+ self.pos = 0
271
+ self.pos_t: T | None = None
272
+
273
+ def reset(self):
274
+ self.pos = 0
275
+ self.pos_t = None
276
+
277
+ def get_pos(self):
278
+ return self.pos
279
+
280
+ def set_pos_t(self, pos_t):
281
+ self.pos_t = pos_t
282
+
283
+ def increment_and_get_pos_t(self):
284
+ assert self.pos_t is not None
285
+ self.pos_t += 1
286
+ return self.pos_t
287
+
288
+ def insert_kv(self, layer_id: int, k: T, v: T, **kwargs):
289
+ del kwargs
290
+ assert self.pos_t is not None
291
+ if self.kv_cache is None:
292
+ self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
293
+ B, H, T_add, D = k.size()
294
+ t0, t1 = self.pos, self.pos + T_add
295
+ self.kv_cache[layer_id, 0, :, :, t0:t1] = k
296
+ self.kv_cache[layer_id, 1, :, :, t0:t1] = v
297
+ key_view = self.kv_cache[layer_id, 0, :, :, :t1]
298
+ value_view = self.kv_cache[layer_id, 1, :, :, :t1]
299
+ if layer_id == self.kv_cache.size(0) - 1:
300
+ self.pos = t1
301
+ return key_view, value_view
302
+
303
+
304
+ # Sampling
305
+
306
+ @torch.inference_mode()
307
+ def sample_next_token(logits, rng, temperature=0.0, top_k=None):
308
+ assert temperature >= 0.0
309
+ if temperature == 0.0:
310
+ return torch.argmax(logits, dim=-1, keepdim=True)
311
+ if top_k is not None:
312
+ k = min(top_k, logits.size(-1))
313
+ vals, idx = torch.topk(logits, k, dim=-1)
314
+ vals = vals / temperature
315
+ probs = F.softmax(vals, dim=-1)
316
+ choice = torch.multinomial(probs, num_samples=1, generator=rng)
317
+ return idx.gather(1, choice)
318
+ logits = logits / temperature
319
+ probs = F.softmax(logits, dim=-1)
320
+ return torch.multinomial(probs, num_samples=1, generator=rng)
321
+
322
+
323
+ # Main Model
324
+
325
+ class FalconOCRForCausalLM(PreTrainedModel):
326
+ config_class = FalconOCRConfig
327
+ _no_split_modules = ["TransformerBlock"]
328
+
329
+ def __init__(self, config: FalconOCRConfig):
330
+ super().__init__(config)
331
+ img_in_dim = config.temporal_patch_size * config.spatial_patch_size ** 2 * config.channel_size
332
+ self.img_projector = nn.Linear(img_in_dim, config.dim, bias=False)
333
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
334
+
335
+ self.layers = nn.ModuleDict()
336
+ for layer_id in range(config.n_layers):
337
+ self.layers[str(layer_id)] = TransformerBlock(layer_id, config)
338
+
339
+ self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps)
340
+ self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
341
+
342
+ rope_dim = config.head_dim // 2
343
+ freqs_cis = precompute_freqs_cis(rope_dim, config.max_seq_len, config.rope_theta)
344
+ freqs_cis_golden = torch.empty((config.n_heads, rope_dim // 2, 2), dtype=torch.float)
345
+ self.register_buffer("freqs_cis", freqs_cis, persistent=False)
346
+ self.register_buffer("freqs_cis_golden", freqs_cis_golden, persistent=True)
347
+
348
+ self._weights_fused = False
349
+ self._is_compiled = False
350
+
351
+ self.post_init()
352
+
353
+ # Weight management
354
+
355
+ def _ensure_device_buffers(self):
356
+ """Recompute non-persistent buffers that HF meta-device loading may discard."""
357
+ if self._weights_fused:
358
+ return
359
+ device = self.tok_embeddings.weight.device
360
+ c = self.config
361
+ rope_dim = c.head_dim // 2
362
+ freqs_cis = precompute_freqs_cis(rope_dim, c.max_seq_len, c.rope_theta).to(device)
363
+ self.register_buffer("freqs_cis", freqs_cis, persistent=False)
364
+ if self.freqs_cis_golden.device != device:
365
+ self.freqs_cis_golden = self.freqs_cis_golden.to(device)
366
+ self._weights_fused = True
367
+
368
+ def compile_model(self):
369
+ if self._is_compiled:
370
+ return
371
+ torch._inductor.config.triton.cudagraphs = False
372
+ for layer in self.layers.values():
373
+ layer.compile(dynamic=True, mode="default")
374
+ self._is_compiled = True
375
+
376
+ # Tokenizer
377
+
378
+ def _get_tokenizer(self):
379
+ if not hasattr(self, "_tokenizer"):
380
+ import os
381
+ path = self.config._name_or_path
382
+ is_local = os.path.exists(path)
383
+ self._tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=is_local, trust_remote_code=True)
384
+ for token_name, token in self._tokenizer.special_tokens_map.items():
385
+ if isinstance(token, str):
386
+ setattr(self._tokenizer, token_name, token)
387
+ setattr(
388
+ self._tokenizer, token_name + "_id",
389
+ self._tokenizer.convert_tokens_to_ids(token),
390
+ )
391
+ return self._tokenizer
392
+
393
+ # Attention mask
394
+
395
+ def get_attention_mask(self, input_batch: T, max_len: int | None = None):
396
+ return create_batch_attention_mask(
397
+ input_batch,
398
+ pad_token_id=self._pad_token_id,
399
+ eos_token_id=self.config.eos_id,
400
+ soi_token_id=self.config.image_cls_token_id,
401
+ eoi_token_id=self.config.img_end_id,
402
+ max_len=max_len,
403
+ )
404
+
405
+ # Embedding helpers
406
+
407
+ def _scatter_img_tokens_with_projector(self, h_BSD, pixel_patches_NLC, pixel_masks_NTHW, tokens_BS):
408
+ B, S, D = h_BSD.shape
409
+ pixel_patch_mask = E.reduce(
410
+ pixel_masks_NTHW,
411
+ "n (t pt) (h ph) (w pw) -> (n t h w)",
412
+ reduction="any",
413
+ pt=self.config.temporal_patch_size,
414
+ ph=self.config.spatial_patch_size,
415
+ pw=self.config.spatial_patch_size,
416
+ )
417
+ pixel_patches_flat = E.rearrange(pixel_patches_NLC, "n p c -> (n p) c")
418
+ valid_patches = pixel_patches_flat[pixel_patch_mask]
419
+ valid_feats = self.img_projector(valid_patches)
420
+ img_mask_h_BSD = E.repeat(tokens_BS == self.config.img_id, "b s -> b s d", d=D)
421
+ assert valid_feats.numel() == img_mask_h_BSD.sum()
422
+ return torch.masked_scatter(h_BSD, img_mask_h_BSD, valid_feats)
423
+
424
+ # Core forward
425
+
426
+ def forward(
427
+ self,
428
+ tokens: T,
429
+ attention_mask: BlockMask,
430
+ kv_cache,
431
+ rope_pos_t: T | None = None,
432
+ rope_pos_hw: T | None = None,
433
+ pixel_values: T | None = None,
434
+ pixel_mask: T | None = None,
435
+ ):
436
+ B, S = tokens.size()
437
+ c = self.config
438
+ block_mask = attention_mask
439
+
440
+ T_pos = kv_cache.get_pos()
441
+ is_prefill = S != 1
442
+
443
+ if is_prefill:
444
+ assert rope_pos_t is not None and rope_pos_hw is not None
445
+ pos_t = rope_pos_t[:, T_pos:T_pos + S].long()
446
+ kv_cache.pos_t = pos_t[:, -1:]
447
+ freqs_cis = self.freqs_cis[pos_t]
448
+ rope_pos_hw = rope_pos_hw[:, T_pos:T_pos + S]
449
+ freqs_cis_golden = apply_golden_freqs_cis_to_visual_pos(self.freqs_cis_golden, rope_pos_hw)
450
+ block_mask.seq_lengths = (S, S)
451
+ else:
452
+ pos_t = kv_cache.increment_and_get_pos_t()
453
+ freqs_cis = self.freqs_cis[pos_t]
454
+ freqs_cis_golden = None
455
+ block_idx = T_pos // block_mask.BLOCK_SIZE[0]
456
+ block_mask = block_mask[:, :, block_idx]
457
+ block_mask.seq_lengths = (S, T_pos + S)
458
+ block_mask.mask_mod = offset_mask_mod(attention_mask.mask_mod, offset=T_pos)
459
+
460
+ h_BSD = self.tok_embeddings(tokens)
461
+
462
+ if pixel_values is not None:
463
+ assert pixel_mask is not None
464
+ pixel_values = pixel_values.to(self.dtype)
465
+ pixel_mask = pixel_mask.to(self.dtype)
466
+ pixel_patches_NLC = E.rearrange(
467
+ pixel_values,
468
+ "n (t pt) (h ph) (w pw) c -> n (t h w) (pt ph pw c)",
469
+ pt=c.temporal_patch_size, ph=c.spatial_patch_size, pw=c.spatial_patch_size,
470
+ )
471
+ h_BSD = self._scatter_img_tokens_with_projector(h_BSD, pixel_patches_NLC, pixel_mask, tokens)
472
+
473
+ for layer in self.layers.values():
474
+ h_BSD = layer(
475
+ h_BSD, freqs_cis=freqs_cis, freqs_cis_2d=freqs_cis_golden,
476
+ pos_hw=rope_pos_hw, attention_masks=block_mask, kv_cache=kv_cache,
477
+ )
478
+
479
+ h_BSD = self.norm(h_BSD)
480
+ logits_BSV = self.output(h_BSD)
481
+ return logits_BSV
482
+
483
+ # Layout detection
484
+
485
+ def _load_layout_model(self, layout_model: str = "PaddlePaddle/PP-DocLayoutV3_safetensors"):
486
+ if hasattr(self, "_layout_model"):
487
+ return
488
+ import torchvision.transforms.functional as tvF
489
+ from transformers import AutoModelForObjectDetection, PPDocLayoutV3ImageProcessorFast
490
+
491
+ self._layout_processor = PPDocLayoutV3ImageProcessorFast.from_pretrained(layout_model)
492
+ self._layout_det_model = AutoModelForObjectDetection.from_pretrained(
493
+ layout_model, torch_dtype=torch.float16,
494
+ ).to(self.device).eval()
495
+ self._layout_id2label = self._layout_det_model.config.id2label
496
+ self._tvF = tvF
497
+
498
+ @torch.inference_mode()
499
+ def _run_layout_detection(
500
+ self, images: list[Image.Image], threshold: float = 0.5,
501
+ ) -> list[list[dict]]:
502
+ """Run PP-DocLayoutV3 on a batch of PIL images, return per-image detections."""
503
+ device = self.device
504
+ tvF = self._tvF
505
+
506
+ target_sizes = torch.tensor([img.size[::-1] for img in images])
507
+ tensors = [tvF.pil_to_tensor(img) for img in images]
508
+
509
+ # GPU-accelerated resize + normalize
510
+ result = torch.empty(
511
+ len(tensors), 3, _LAYOUT_TARGET_H, _LAYOUT_TARGET_W,
512
+ dtype=torch.float16, device=device,
513
+ )
514
+ size_groups: dict[tuple[int, int], list[int]] = {}
515
+ for i, t in enumerate(tensors):
516
+ size_groups.setdefault((t.shape[1], t.shape[2]), []).append(i)
517
+
518
+ for shape, indices in size_groups.items():
519
+ batch = torch.stack([tensors[i] for i in indices])
520
+ batch = batch.to(device=device, dtype=torch.float32, non_blocking=True)
521
+ batch = F.interpolate(
522
+ batch, size=(_LAYOUT_TARGET_H, _LAYOUT_TARGET_W),
523
+ mode="bicubic", align_corners=False, antialias=False,
524
+ )
525
+ batch = (batch.clamp_(0, 255) / 255.0).to(torch.float16)
526
+ for j, idx in enumerate(indices):
527
+ result[idx] = batch[j]
528
+ del batch
529
+
530
+ outputs = self._layout_det_model(pixel_values=result)
531
+ del result
532
+
533
+ # Postprocess on GPU
534
+ logits = outputs.logits
535
+ boxes = outputs.pred_boxes
536
+ order_logits = outputs.order_logits
537
+
538
+ box_centers, box_dims = boxes.split(2, dim=-1)
539
+ boxes_xyxy = torch.cat([box_centers - 0.5 * box_dims, box_centers + 0.5 * box_dims], dim=-1)
540
+
541
+ img_h, img_w = target_sizes.unbind(1)
542
+ scale = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(device, dtype=boxes_xyxy.dtype)
543
+ boxes_xyxy = boxes_xyxy * scale[:, None, :]
544
+
545
+ num_queries = logits.shape[1]
546
+ num_classes = logits.shape[2]
547
+ scores = logits.sigmoid()
548
+ scores_flat, index = scores.flatten(1).topk(num_queries, dim=-1)
549
+ labels = index % num_classes
550
+ box_indices = index // num_classes
551
+ boxes_xyxy = boxes_xyxy.gather(dim=1, index=box_indices.unsqueeze(-1).expand(-1, -1, 4))
552
+
553
+ order_seqs = self._layout_processor._get_order_seqs(order_logits)
554
+ order_seqs = order_seqs.gather(dim=1, index=box_indices)
555
+
556
+ batch_results = []
557
+ for s, l, b, o in zip(scores_flat, labels, boxes_xyxy, order_seqs):
558
+ mask = s >= threshold
559
+ o_valid = o[mask]
560
+ _, indices_sorted = o_valid.sort()
561
+
562
+ detections = []
563
+ for si, li, bi in zip(s[mask][indices_sorted], l[mask][indices_sorted], b[mask][indices_sorted]):
564
+ detections.append({
565
+ "category": self._layout_id2label[li.item()],
566
+ "bbox": [round(x, 2) for x in bi.tolist()],
567
+ "score": round(si.item(), 4),
568
+ })
569
+ batch_results.append(detections)
570
+
571
+ return batch_results
572
+
573
+ # Core batch decode (shared by generate & generate_with_layout)
574
+
575
+ def _generate_batch(
576
+ self,
577
+ image_prompt_pairs: list[tuple],
578
+ *,
579
+ max_new_tokens: int,
580
+ temperature: float,
581
+ top_k: int | None,
582
+ min_dimension: int,
583
+ max_dimension: int,
584
+ seed: int | None,
585
+ ) -> list[str]:
586
+ """Core autoregressive decode for a list of (image, prompt) pairs."""
587
+ device = self.device
588
+ tokenizer = self._get_tokenizer()
589
+ self._pad_token_id = tokenizer.convert_tokens_to_ids("<|pad|>")
590
+ stop_token_ids = [self.config.eos_id, tokenizer.convert_tokens_to_ids("<|end_of_query|>")]
591
+
592
+ batch_inputs = process_batch(
593
+ tokenizer, self.config, image_prompt_pairs,
594
+ max_length=4096, min_dimension=min_dimension, max_dimension=max_dimension,
595
+ )
596
+ batch_inputs = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch_inputs.items()}
597
+
598
+ tokens = batch_inputs["tokens"]
599
+ B, L = tokens.size()
600
+ block_size = 128
601
+ S = (L + max_new_tokens + block_size - 1) // block_size * block_size
602
+ assert S <= self.config.max_seq_len
603
+
604
+ rng = torch.Generator(device).manual_seed(seed) if seed is not None else None
605
+
606
+ kv_cache = KVCache(
607
+ max_batch_size=B, max_seq_length=S, n_heads=self.config.n_heads,
608
+ head_dim=self.config.head_dim, num_layers=self.config.n_layers,
609
+ )
610
+
611
+ padded_tokens = torch.full((B, S), self._pad_token_id, dtype=tokens.dtype, device=device)
612
+ padded_tokens[:, :L] = tokens
613
+
614
+ attention_mask = self.get_attention_mask(padded_tokens, max_len=S)
615
+
616
+ logits_BSV = self.forward(
617
+ tokens=tokens, rope_pos_t=batch_inputs["pos_t"], rope_pos_hw=batch_inputs["pos_hw"],
618
+ attention_mask=attention_mask, kv_cache=kv_cache,
619
+ pixel_values=batch_inputs["pixel_values"], pixel_mask=batch_inputs["pixel_mask"],
620
+ )
621
+
622
+ stop_ids = torch.tensor(stop_token_ids).to(device)
623
+ should_stop_B = torch.full((B,), False, dtype=torch.bool, device=device)
624
+ generated_ids: list[list[int]] = [[] for _ in range(B)]
625
+
626
+ while not torch.all(should_stop_B) and (pos := kv_cache.get_pos()) < S:
627
+ tokens_B1 = sample_next_token(logits_BSV[:, -1], rng, temperature, top_k)
628
+
629
+ if torch.any(should_stop_B):
630
+ tokens_B1 = tokens_B1.clone()
631
+ tokens_B1[should_stop_B, :] = self._pad_token_id
632
+ padded_tokens[:, pos] = tokens_B1[:, -1]
633
+
634
+ for b in range(B):
635
+ if not should_stop_B[b]:
636
+ generated_ids[b].append(tokens_B1[b, 0].item())
637
+
638
+ logits_BSV = self.forward(
639
+ tokens=tokens_B1, attention_mask=attention_mask, kv_cache=kv_cache,
640
+ )
641
+
642
+ hit_stop_B = torch.isin(tokens_B1, stop_ids).any(dim=-1)
643
+ should_stop_B = should_stop_B.logical_or(hit_stop_B)
644
+
645
+ results = []
646
+ for b in range(B):
647
+ text = tokenizer.decode(generated_ids[b], skip_special_tokens=False)
648
+ text = text.replace("<|end_of_query|>", "").replace("<|end_of_text|>", "").strip()
649
+ results.append(text)
650
+
651
+ return results
652
+
653
+ # Main API: generate
654
+
655
+ @torch.inference_mode()
656
+ def generate(
657
+ self,
658
+ images,
659
+ *,
660
+ category: str | list[str] = "plain",
661
+ max_new_tokens: int = 4096,
662
+ temperature: float = 0.0,
663
+ top_k: int | None = None,
664
+ min_dimension: int = 64,
665
+ max_dimension: int = 1024,
666
+ compile: bool = True,
667
+ seed: int | None = 42,
668
+ ) -> list[str]:
669
+ """
670
+ Extract text from document images.
671
+
672
+ Args:
673
+ images: Single PIL Image (or path/URL) or list of them.
674
+ category: OCR category — one of "plain", "text", "table", "formula",
675
+ "caption", "footnote", "list-item", "page-footer", "page-header",
676
+ "section-header", "title". Can be a single string (applied to all
677
+ images) or a list (one per image).
678
+ max_new_tokens: Maximum generation steps.
679
+ temperature: Sampling temperature (0.0 = greedy).
680
+ top_k: Top-k sampling (None = disabled).
681
+ min_dimension: Min image side after resize.
682
+ max_dimension: Max image side after resize.
683
+ compile: Whether to torch.compile on first call.
684
+ seed: Random seed for reproducibility (None = non-deterministic).
685
+
686
+ Returns:
687
+ List of extracted text strings, one per image.
688
+ """
689
+ self._ensure_device_buffers()
690
+ if compile:
691
+ self.compile_model()
692
+
693
+ if isinstance(images, (str, Path, Image.Image)):
694
+ images = [images]
695
+ if isinstance(category, str):
696
+ category = [category] * len(images)
697
+ assert len(images) == len(category), "Must provide one category per image"
698
+
699
+ image_prompt_pairs = []
700
+ for img, cat in zip(images, category):
701
+ instruction = CATEGORY_PROMPTS.get(cat.strip().lower(), CATEGORY_PROMPTS["plain"])
702
+ prompt = f"<|image|>{instruction}\n<|OCR_PLAIN|>"
703
+ image_prompt_pairs.append((img, prompt))
704
+
705
+ return self._generate_batch(
706
+ image_prompt_pairs,
707
+ max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k,
708
+ min_dimension=min_dimension, max_dimension=max_dimension, seed=seed,
709
+ )
710
+
711
+ # Main API: generate_with_layout
712
+
713
+ @torch.inference_mode()
714
+ def generate_with_layout(
715
+ self,
716
+ images,
717
+ *,
718
+ max_new_tokens: int = 4096,
719
+ temperature: float = 0.0,
720
+ top_k: int | None = None,
721
+ min_dimension: int = 64,
722
+ max_dimension: int = 1024,
723
+ compile: bool = True,
724
+ seed: int | None = 42,
725
+ layout_threshold: float = 0.3,
726
+ layout_batch_size: int = 4,
727
+ ocr_batch_size: int = 32,
728
+ containment_threshold: float = 0.8,
729
+ layout_model: str = "PaddlePaddle/PP-DocLayoutV3_safetensors",
730
+ ) -> list[list[dict]]:
731
+ """
732
+ Run layout detection then OCR on each detected region.
733
+
734
+ Args:
735
+ images: Single PIL Image (or path/URL) or list of them.
736
+ max_new_tokens: Maximum generation steps per crop.
737
+ temperature: Sampling temperature (0.0 = greedy).
738
+ top_k: Top-k sampling (None = disabled).
739
+ min_dimension: Min crop side after resize for OCR.
740
+ max_dimension: Max crop side after resize for OCR.
741
+ compile: Whether to torch.compile on first call.
742
+ seed: Random seed for reproducibility.
743
+ layout_threshold: Confidence threshold for layout detections.
744
+ layout_batch_size: Batch size for layout detection.
745
+ ocr_batch_size: Batch size for OCR generation (chunks crops).
746
+ containment_threshold: Drop formula boxes whose containment ratio inside a text box exceeds this threshold.
747
+ layout_model: HuggingFace model ID for layout detection.
748
+
749
+ Returns:
750
+ Per-image list of detections, each a dict with keys:
751
+ ``category``, ``bbox`` [x1,y1,x2,y2], ``score``, ``text``.
752
+ """
753
+ self._ensure_device_buffers()
754
+ if compile:
755
+ self.compile_model()
756
+ self._load_layout_model(layout_model)
757
+
758
+ if isinstance(images, (str, Path, Image.Image)):
759
+ images = [images]
760
+ pil_images = [load_image(img).convert("RGB") for img in images]
761
+
762
+ # --- Layout detection (batched) ---
763
+ all_layout_dets: list[list[dict]] = []
764
+ for i in range(0, len(pil_images), layout_batch_size):
765
+ batch_imgs = pil_images[i : i + layout_batch_size]
766
+ dets = self._run_layout_detection(batch_imgs, threshold=layout_threshold)
767
+ all_layout_dets.extend(dets)
768
+
769
+ # --- Filter nested boxes (e.g. inline formulas inside text) ---
770
+ all_layout_dets = [
771
+ _filter_nested_detections(dets, containment_threshold)
772
+ for dets in all_layout_dets
773
+ ]
774
+
775
+ # --- Build crops + track origin ---
776
+ flat_crops: list[tuple[Image.Image, str]] = []
777
+ crop_origins: list[tuple[int, int]] = [] # (image_idx, det_idx)
778
+
779
+ for img_idx, (pil_img, dets) in enumerate(zip(pil_images, all_layout_dets)):
780
+ if not dets or (len(dets) == 1 and dets[0]["category"].strip().lower() == "image"):
781
+ prompt = f"<|image|>{CATEGORY_PROMPTS['plain']}\n<|OCR_PLAIN|>"
782
+ flat_crops.append((pil_img, prompt))
783
+ crop_origins.append((img_idx, -1))
784
+ continue
785
+
786
+ img_w, img_h = pil_img.size
787
+ for det_idx, det in enumerate(dets):
788
+ cat_key = det["category"].strip().lower()
789
+ ocr_cat = LAYOUT_TO_OCR_CATEGORY.get(cat_key)
790
+ if ocr_cat is None:
791
+ continue
792
+
793
+ x1, y1, x2, y2 = det["bbox"]
794
+ x1 = max(0, int(x1))
795
+ y1 = max(0, int(y1))
796
+ x2 = min(img_w, int(x2 + 0.5))
797
+ y2 = min(img_h, int(y2 + 0.5))
798
+ cw, ch = x2 - x1, y2 - y1
799
+ if cw < _MIN_CROP_DIM or ch < _MIN_CROP_DIM:
800
+ continue
801
+ short, long = sorted((cw, ch))
802
+ resized_short = short * (max_dimension / long) if long > max_dimension else short
803
+ if resized_short < _MIN_CROP_DIM:
804
+ continue
805
+
806
+ crop = pil_img.crop((x1, y1, x2, y2))
807
+ instruction = CATEGORY_PROMPTS.get(ocr_cat, CATEGORY_PROMPTS["plain"])
808
+ prompt = f"<|image|>{instruction}\n<|OCR_PLAIN|>"
809
+ flat_crops.append((crop, prompt))
810
+ crop_origins.append((img_idx, det_idx))
811
+
812
+ # --- OCR in chunks ---
813
+ flat_texts: list[str] = []
814
+ for i in range(0, max(len(flat_crops), 1), ocr_batch_size):
815
+ chunk = flat_crops[i : i + ocr_batch_size]
816
+ if not chunk:
817
+ break
818
+ texts = self._generate_batch(
819
+ chunk,
820
+ max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k,
821
+ min_dimension=min_dimension, max_dimension=max_dimension, seed=seed,
822
+ )
823
+ flat_texts.extend(texts)
824
+
825
+ # --- Reassemble per-image results ---
826
+ results: list[list[dict]] = [[] for _ in range(len(pil_images))]
827
+ for (img_idx, det_idx), text in zip(crop_origins, flat_texts):
828
+ if det_idx == -1:
829
+ img_w, img_h = pil_images[img_idx].size
830
+ results[img_idx].append({
831
+ "category": "plain",
832
+ "bbox": [0, 0, img_w, img_h],
833
+ "score": 1.0,
834
+ "text": text,
835
+ })
836
+ else:
837
+ det = all_layout_dets[img_idx][det_idx]
838
+ results[img_idx].append({
839
+ "category": det["category"],
840
+ "bbox": det["bbox"],
841
+ "score": det["score"],
842
+ "text": text,
843
+ })
844
+
845
+ return results
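Usage sketch for the two entry points defined above. This is a minimal illustration and not part of the committed files: it assumes the checkpoint is loaded with `AutoModel.from_pretrained(..., trust_remote_code=True)` so this modeling code is picked up, and the image path is a placeholder.

    # Hedged usage sketch: the image path is a placeholder; the model id is this repo.
    import torch
    from PIL import Image
    from transformers import AutoModel

    model = AutoModel.from_pretrained(
        "tiiuae/Falcon-OCR", trust_remote_code=True, torch_dtype=torch.bfloat16,
    ).to("cuda").eval()

    page = Image.open("sample_page.png").convert("RGB")

    # Whole-page OCR: one category string is broadcast to every image in the list.
    texts = model.generate([page], category="plain", max_new_tokens=2048)
    print(texts[0])

    # Layout-aware OCR: detect regions with PP-DocLayoutV3, then OCR each crop.
    detections = model.generate_with_layout([page], layout_threshold=0.3)
    for det in detections[0]:
        print(det["category"], det["bbox"], det["text"][:80])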
processing_falcon_ocr.py ADDED
@@ -0,0 +1,423 @@
1
+ import io
2
+ import math
3
+
4
+ import einops as E
5
+ import numpy as np
6
+ import requests
7
+ import torch
8
+ from PIL import Image
9
+ from transformers.image_processing_utils import BaseImageProcessor
10
+ from transformers.image_transforms import convert_to_rgb, resize
11
+ from transformers.image_utils import (
12
+ ImageInput,
13
+ get_image_size,
14
+ infer_channel_dimension_format,
15
+ to_numpy_array,
16
+ valid_images,
17
+ validate_preprocess_arguments,
18
+ )
19
+
20
+ IMAGE_MEAN = [0.5, 0.5, 0.5]
21
+ IMAGE_STD = [0.5, 0.5, 0.5]
22
+
23
+
24
+ def load_image(image):
25
+ if image is None:
26
+ return None
27
+ if isinstance(image, Image.Image):
28
+ return image
29
+ if isinstance(image, str):
30
+ if image.startswith(("http://", "https://")):
31
+ response = requests.get(image, timeout=10)
32
+ response.raise_for_status()
33
+ return Image.open(io.BytesIO(response.content))
34
+ if image.endswith(".npy"):
35
+ img_array = io.BytesIO(np.load(image))
36
+ return Image.open(img_array)
37
+ return Image.open(image)
38
+ if isinstance(image, np.bytes_):
39
+ return Image.open(io.BytesIO(image))
40
+ if isinstance(image, np.ndarray):
41
+ return Image.fromarray(image)
42
+ raise TypeError(f"Unknown image format {image}")
43
+
44
+
45
+ def load_images(images_input, min_dimension: int, max_dimension: int):
46
+ images = []
47
+ if images_input is not None:
48
+ for inp in images_input:
49
+ img = load_image(inp)
50
+ img = resize_image_if_necessary(img, min_dimension, max_dimension)
51
+ images.append(img)
52
+ return images
53
+
54
+
55
+ def resize_image_if_necessary(
56
+ image,
57
+ shortest_dimension=224,
58
+ longest_dimension=896,
59
+ ):
60
+ original_width, original_height = image.size
61
+ aspect_ratio = original_width / original_height
62
+
63
+ if (
64
+ shortest_dimension <= original_width <= longest_dimension
65
+ and shortest_dimension <= original_height <= longest_dimension
66
+ ):
67
+ return image
68
+
69
+ is_vertical_image = original_width < original_height
70
+ if original_width < shortest_dimension or original_height < shortest_dimension:
71
+ if is_vertical_image:
72
+ new_width = shortest_dimension
73
+ new_height = int(new_width / aspect_ratio)
74
+ else:
75
+ new_height = shortest_dimension
76
+ new_width = int(new_height * aspect_ratio)
77
+ else:
78
+ if is_vertical_image:
79
+ new_width = longest_dimension
80
+ new_height = int(new_width / aspect_ratio)
81
+ else:
82
+ new_height = longest_dimension
83
+ new_width = int(new_height * aspect_ratio)
84
+
85
+ if new_width > longest_dimension:
86
+ new_width = longest_dimension
87
+ new_height = int(new_width / aspect_ratio)
88
+ if new_height > longest_dimension:
89
+ new_height = longest_dimension
90
+ new_width = int(new_height * aspect_ratio)
91
+
92
+ resized_image = image.resize((new_width, new_height))
93
+ return resized_image
94
+
95
+
96
+ def smart_resize(
97
+ image,
98
+ factor: int,
99
+ resample,
100
+ input_data_format,
101
+ min_pixels: int = 56 * 56,
102
+ max_pixels: int = 14 * 14 * 4 * 1280,
103
+ ):
104
+ height, width = get_image_size(image, channel_dim=input_data_format)
105
+ if height < factor or width < factor:
106
+ raise ValueError(f"{height=} or {width=} must be larger than {factor=}")
107
+ if max(height, width) / min(height, width) > 200:
108
+ raise ValueError(
109
+ f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
110
+ )
111
+ h_bar = round(height / factor) * factor
112
+ w_bar = round(width / factor) * factor
113
+ if h_bar * w_bar > max_pixels:
114
+ beta = np.sqrt((height * width) / max_pixels)
115
+ h_bar = math.floor(height / beta / factor) * factor
116
+ w_bar = math.floor(width / beta / factor) * factor
117
+ elif h_bar * w_bar < min_pixels:
118
+ beta = np.sqrt(min_pixels / (height * width))
119
+ h_bar = math.ceil(height * beta / factor) * factor
120
+ w_bar = math.ceil(width * beta / factor) * factor
121
+ image = resize(
122
+ image,
123
+ size=(h_bar, w_bar),
124
+ resample=resample,
125
+ input_data_format=input_data_format,
126
+ )
127
+ return image
128
+
129
+
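A worked illustration of the `smart_resize` pixel budget above, with made-up page dimensions; it mirrors the max_pixels branch rather than calling the function, so the numbers are purely illustrative.

    import math

    # Illustration only: a hypothetical 2000x1400 page, factor = patch_size * merge_size = 16.
    height, width, factor = 2000, 1400, 16
    max_pixels = 14 * 14 * 4 * 1280                          # 1,003,520 pixel budget

    h_bar = round(height / factor) * factor                  # 2000
    w_bar = round(width / factor) * factor                   # 1408
    if h_bar * w_bar > max_pixels:                           # 2,816,000 > 1,003,520: shrink
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor  # 1184
        w_bar = math.floor(width / beta / factor) * factor   # 832
    print(h_bar, w_bar, h_bar * w_bar)                       # 1184 832 985088, under budget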
130
+ class ImageProcessor(BaseImageProcessor):
131
+ def __init__(
132
+ self,
133
+ patch_size,
134
+ merge_size,
135
+ do_resize: bool = True,
136
+ resample: Image.Resampling = Image.Resampling.BICUBIC,
137
+ do_rescale: bool = True,
138
+ rescale_factor: float = 1 / 255,
139
+ do_normalize: bool = True,
140
+ image_mean: float | list[float] | None = None,
141
+ image_std: float | list[float] | None = None,
142
+ do_convert_rgb: bool = True,
143
+ min_pixels: int = 56 * 56,
144
+ max_pixels: int = 28 * 28 * 1280,
145
+ **kwargs,
146
+ ) -> None:
147
+ super().__init__(**kwargs)
148
+ self.do_resize = do_resize
149
+ self.resample = resample
150
+ self.do_rescale = do_rescale
151
+ self.rescale_factor = rescale_factor
152
+ self.do_normalize = do_normalize
153
+ self.image_mean = image_mean or IMAGE_MEAN
154
+ self.image_std = image_std or IMAGE_STD
155
+ self.min_pixels = min_pixels
156
+ self.max_pixels = max_pixels
157
+ self.patch_size = patch_size
158
+ self.merge_size = merge_size
159
+ self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
160
+ self.do_convert_rgb = do_convert_rgb
161
+ validate_preprocess_arguments(
162
+ rescale_factor=self.rescale_factor,
163
+ do_normalize=self.do_normalize,
164
+ image_mean=self.image_mean,
165
+ image_std=self.image_std,
166
+ do_resize=self.do_resize,
167
+ size=self.size,
168
+ resample=self.resample,
169
+ )
170
+
171
+ def _preprocess(self, image: ImageInput, do_rescale=None, do_normalize=None):
172
+ if self.do_convert_rgb:
173
+ image = convert_to_rgb(image)
174
+ image = to_numpy_array(image)
175
+ input_data_format = infer_channel_dimension_format(image)
176
+ if self.do_resize:
177
+ image = smart_resize(
178
+ image,
179
+ factor=self.patch_size * self.merge_size,
180
+ resample=self.resample,
181
+ input_data_format=input_data_format,
182
+ min_pixels=self.min_pixels,
183
+ max_pixels=self.max_pixels,
184
+ )
185
+ if do_rescale or self.do_rescale:
186
+ image = self.rescale(image, scale=self.rescale_factor, input_data_format=input_data_format)
187
+ if do_normalize or self.do_normalize:
188
+ image = self.normalize(
189
+ image=image, mean=self.image_mean, std=self.image_std,
190
+ input_data_format=input_data_format,
191
+ )
192
+ return image
193
+
194
+ def preprocess(self, images: list[ImageInput] | None, do_rescale=None, do_normalize=None, **kwargs):
195
+ del kwargs
196
+ if images is None:
197
+ return []
198
+ images = [item for item in images if item is not None]
199
+ if not valid_images(images):
200
+ raise ValueError(
201
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
202
+ "torch.Tensor, tf.Tensor or jax.ndarray."
203
+ )
204
+ pixel_values = []
205
+ for image in images:
206
+ processed_image = self._preprocess(image, do_rescale, do_normalize)
207
+ processed_image = processed_image[None, ...]
208
+ pixel_values.append(processed_image)
209
+ return pixel_values
210
+
211
+ def batch_images_with_mask(self, pixel_values, max_image_height, max_image_width):
212
+ if pixel_values is None:
213
+ return None
214
+ pixel_values = [item for item in pixel_values if item is not None and len(item) != 0]
215
+ if len(pixel_values) == 0:
216
+ return None
217
+ pixel_values = [torch.from_numpy(img) for img in pixel_values]
218
+ max_temporal = max(img.shape[0] for img in pixel_values)
219
+
220
+ def pad_image_and_mask(img):
221
+ time_steps, height, width, channels = img.shape
222
+ if channels != 3:
223
+ raise ValueError(f"Expected 3-channel RGB images, got {channels} channels.")
224
+ padding = (0, 0, 0, max_image_width - width, 0, max_image_height - height, 0, max_temporal - time_steps)
225
+ padded_image = torch.nn.functional.pad(img, padding)
226
+ mask = torch.zeros((max_temporal, max_image_height, max_image_width), dtype=torch.long)
227
+ mask[:time_steps, :height, :width] = 1
228
+ return padded_image, mask
229
+
230
+ padded_pixel_values, padding_masks = zip(*[pad_image_and_mask(img) for img in pixel_values])
231
+ padded_pixel_values = torch.stack(list(padded_pixel_values))
232
+ padding_masks = torch.stack(list(padding_masks))
233
+ return {"pixel_values": padded_pixel_values, "padding_mask": padding_masks}
234
+
235
+
236
+ # ---------------------------------------------------------------------------
237
+ # Positional encoding helpers
238
+ # ---------------------------------------------------------------------------
239
+
240
+ def _compute_image_spatial_positions(
241
+ pixel_mask_THW: torch.Tensor,
242
+ spatial_patch_size: int,
243
+ temporal_patch_size: int = 1,
244
+ ) -> tuple[torch.Tensor, torch.Tensor]:
245
+ mask_thw = E.reduce(
246
+ pixel_mask_THW,
247
+ "(t tp) (h hp) (w wp) -> t h w",
248
+ reduction="any",
249
+ tp=temporal_patch_size,
250
+ hp=spatial_patch_size,
251
+ wp=spatial_patch_size,
252
+ )
253
+ width = E.reduce(mask_thw.sum(dim=-1).int(), "t h -> ", reduction="max")
254
+ height = E.reduce(mask_thw.sum(dim=-2).int(), "t w -> ", reduction="max")
255
+ xlim = torch.sqrt(width / height)
256
+ ylim = torch.sqrt(height / width)
257
+ xpos = torch.linspace(-xlim, xlim, int(width))
258
+ ypos = torch.linspace(-ylim, ylim, int(height))
259
+ wpos, hpos = torch.meshgrid(xpos, ypos, indexing="xy")
260
+ return hpos.flatten(), wpos.flatten()
261
+
262
+
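A toy check of the spatial-position helper above, with a made-up 2x4 patch grid: the wider axis gets the wider coordinate range because of the square-root aspect normalisation. This re-implements the core lines rather than calling the helper, so the values are purely illustrative.

    import torch

    # Hypothetical fully-valid patch grid: 2 patch rows (h) and 4 patch columns (w).
    h, w = 2, 4
    xlim = (w / h) ** 0.5                    # ~1.414, wider axis spans a wider range
    ylim = (h / w) ** 0.5                    # ~0.707
    xpos = torch.linspace(-xlim, xlim, w)
    ypos = torch.linspace(-ylim, ylim, h)
    wpos, hpos = torch.meshgrid(xpos, ypos, indexing="xy")
    print(hpos.flatten())                    # 8 y-coordinates, one per patch token
    print(wpos.flatten())                    # 8 x-coordinates, one per patch token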
263
+ def _get_image_token_masks(tokens, config):
264
+ spatial_mask = tokens == config.img_id
265
+ no_increase_mask = (
266
+ spatial_mask
267
+ | (tokens == config.image_reg_1_token_id)
268
+ | (tokens == config.image_reg_2_token_id)
269
+ | (tokens == config.image_reg_3_token_id)
270
+ | (tokens == config.image_reg_4_token_id)
271
+ | (tokens == config.img_end_id)
272
+ )
273
+ return spatial_mask, no_increase_mask
274
+
275
+
276
+ def get_pos_thw(
277
+ tokens: torch.Tensor,
278
+ pixel_masks_NTHW: torch.Tensor,
279
+ config,
280
+ spatial_patch_size: int,
281
+ temporal_patch_size: int = 1,
282
+ pad_token_id: int = None,
283
+ ):
284
+ assert pad_token_id is not None
285
+ assert tokens.ndim == 2
286
+ assert pixel_masks_NTHW.ndim == 4
287
+
288
+ spatial_img_token_mask_BS, no_increase_idx_img_token_mask_BS = _get_image_token_masks(tokens, config)
289
+
290
+ hpos_parts, wpos_parts = [], []
291
+ for i in range(pixel_masks_NTHW.shape[0]):
292
+ h, w = _compute_image_spatial_positions(pixel_masks_NTHW[i], spatial_patch_size, temporal_patch_size)
293
+ hpos_parts.append(h)
294
+ wpos_parts.append(w)
295
+
296
+ hpos_N = torch.cat(hpos_parts) if hpos_parts else torch.empty(0)
297
+ wpos_N = torch.cat(wpos_parts) if wpos_parts else torch.empty(0)
298
+
299
+ expected_tokens = spatial_img_token_mask_BS.sum().item()
300
+ actual_tokens = hpos_N.numel()
301
+ assert actual_tokens == expected_tokens, (
302
+ f"Mismatch between spatial image tokens ({expected_tokens}) and generated positions ({actual_tokens})."
303
+ )
304
+
305
+ hpos_BS = torch.full_like(tokens, fill_value=torch.nan, dtype=torch.float, device=tokens.device)
306
+ wpos_BS = torch.full_like(tokens, fill_value=torch.nan, dtype=torch.float, device=tokens.device)
307
+ hpos_BS = hpos_BS.masked_scatter_(spatial_img_token_mask_BS, hpos_N)
308
+ wpos_BS = wpos_BS.masked_scatter_(spatial_img_token_mask_BS, wpos_N)
309
+
310
+ tpos_BS = torch.ones_like(tokens, dtype=torch.float, device=tokens.device)
311
+ tpos_BS[no_increase_idx_img_token_mask_BS] = 0
312
+ tpos_BS = torch.cumsum(tpos_BS, dim=1) - 1
313
+ tpos_BS[tokens == pad_token_id] = 0
314
+
315
+ hw_pos_BS2 = torch.stack([hpos_BS, wpos_BS], dim=-1)
316
+ return tpos_BS.long(), hw_pos_BS2
317
+
318
+
319
+ def calculate_image_tokens(image, patch_size, merge_size):
320
+ height, width = get_image_size(image)
321
+ return int((height * width) / (patch_size * patch_size * merge_size * merge_size))
322
+
323
+
324
+ def tokenize_inputs(prompt, images, tokenizer, config, patch_size, merge_size, max_length):
325
+ img_reg_ids = [
326
+ config.image_reg_1_token_id,
327
+ config.image_reg_2_token_id,
328
+ config.image_reg_3_token_id,
329
+ config.image_reg_4_token_id,
330
+ ]
331
+
332
+ if images is not None and len(images) > 0:
333
+ image_token_counts = [calculate_image_tokens(image, patch_size, merge_size) for image in images]
334
+ else:
335
+ image_token_counts = []
336
+
337
+ image_token = tokenizer.convert_ids_to_tokens(config.img_id)
338
+ prompt_chunks = [tokenizer.encode(chunk) for chunk in prompt.split(image_token)]
339
+
340
+ def insert_separator(X, sep):
341
+ return [ele for sublist in zip(X, sep) for ele in sublist][:-1]
342
+
343
+ input_ids = []
344
+ offset = 0
345
+ bos_id = getattr(tokenizer, "bos_token_id", None)
346
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and bos_id is not None and prompt_chunks[0][0] == bos_id:
347
+ offset = 1
348
+ input_ids.append(prompt_chunks[0][0])
349
+
350
+ separators = []
351
+ for count in image_token_counts:
352
+ tokens = [config.img_id] * count
353
+ image_block = [config.image_cls_token_id, *img_reg_ids, *tokens, config.img_end_id]
354
+ separators.append(image_block)
355
+
356
+ if len(separators) != 0 and len(separators) != len(prompt_chunks):
357
+ separators.append(separators[-1])
358
+
359
+ selected_images = []
360
+ if len(separators) == 0:
361
+ input_ids = prompt_chunks[0]
362
+ else:
363
+ for index, x in enumerate(insert_separator(prompt_chunks, separators)):
364
+ if index % 2 != 0:
365
+ if (len(input_ids) + len(x)) < max_length:
366
+ input_ids.extend(x)
367
+ selected_images.append(images[index // 2])
368
+ elif index % 2 == 0:
369
+ input_ids.extend(x[offset:])
370
+
371
+ input_ids = torch.LongTensor(input_ids)
372
+ return input_ids, selected_images
373
+
374
+
375
+ def process_batch(
376
+ tokenizer,
377
+ config,
378
+ image_prompt_pairs,
379
+ max_length,
380
+ min_dimension,
381
+ max_dimension,
382
+ patch_size=16,
383
+ merge_size=1,
384
+ ):
385
+ """
386
+ Process a batch of images with text prompts.
387
+ Uses LEFT PADDING for proper batch generation with causal models.
388
+ """
389
+ all_input_ids = []
390
+ all_selected_images = []
391
+ processor_local = ImageProcessor(patch_size, merge_size)
392
+
393
+ for img_input, prompt in image_prompt_pairs:
394
+ img = load_image(img_input)
395
+ if img is not None:
396
+ img = resize_image_if_necessary(img, min_dimension, max_dimension)
397
+ images = processor_local.preprocess(images=[img] if img else [])
398
+ input_ids, selected_images = tokenize_inputs(
399
+ prompt, images, tokenizer, config, patch_size, merge_size, max_length,
400
+ )
401
+ all_input_ids.append(input_ids)
402
+ all_selected_images.extend(selected_images)
403
+
404
+ pad_token_id = tokenizer.convert_tokens_to_ids("<|pad|>")
405
+ padded_input_ids = torch.nn.utils.rnn.pad_sequence(
406
+ all_input_ids, batch_first=True, padding_value=pad_token_id, padding_side="left",
407
+ )
408
+
409
+ processed = processor_local.batch_images_with_mask(all_selected_images, max_dimension, max_dimension)
410
+ assert processed is not None
411
+
412
+ pos_t, pos_hw = get_pos_thw(
413
+ padded_input_ids, processed["padding_mask"], config, patch_size, pad_token_id=pad_token_id,
414
+ )
415
+
416
+ return {
417
+ "tokens": padded_input_ids,
418
+ "pixel_values": processed["pixel_values"],
419
+ "pixel_mask": processed["padding_mask"],
420
+ "pos_t": pos_t,
421
+ "pos_hw": pos_hw,
422
+ "pad_token_id": pad_token_id,
423
+ }
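The docstring of `process_batch` stresses left padding: with a causal decoder, the last real token of every row must sit at the same rightmost position so that a single forward step advances all rows together. A minimal standalone sketch of the padding call (values are made up; the `padding_side` argument of `pad_sequence` requires a recent PyTorch, 2.4 or newer).

    import torch

    pad_id = 0
    seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

    # Left padding aligns the newest token of each sequence at the right edge,
    # which the batched decode loop in _generate_batch relies on.
    padded = torch.nn.utils.rnn.pad_sequence(
        seqs, batch_first=True, padding_value=pad_id, padding_side="left",
    )
    print(padded)   # tensor([[5, 6, 7],
                    #         [0, 8, 9]])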
rope.py ADDED
@@ -0,0 +1,127 @@
1
+ import einops as E
2
+ import torch
3
+
4
+
5
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
6
+ """
7
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
8
+
9
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
10
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
11
+ The returned tensor contains complex values in complex64 data type.
12
+
13
+ Args:
14
+ dim (int): Dimension of the frequency tensor.
15
+ end (int): End index for precomputing frequencies.
16
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
17
+
18
+ Returns:
19
+ torch.Tensor: Precomputed frequency tensor with complex exponentials.
20
+ """
21
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
22
+ t = torch.arange(end, device=freqs.device)
23
+ freqs = torch.outer(t, freqs).float()
24
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
25
+ return freqs_cis # [S, D//2]
26
+
27
+
28
+ def apply_rotary_emb(
29
+ xq: torch.Tensor,
30
+ xk: torch.Tensor,
31
+ freqs_cis: torch.Tensor,
32
+ ) -> tuple[torch.Tensor, torch.Tensor]:
33
+ """1D rotary embedding"""
34
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
35
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
36
+ assert freqs_cis.ndim == 3, (
37
+ "Freqs_cis must be indexed by position ids already and has shape (B,S,D)"
38
+ )
39
+ freqs_cis = E.rearrange(freqs_cis, "b s d -> b s 1 d")
40
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
41
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
42
+ return xq_out.type_as(xq), xk_out.type_as(xk)
43
+
44
+
45
+ ###### 2D golden rope
46
+ """
47
+ Dimension key:
48
+ B: batch size
49
+ S: number of tokens per sample, Seqlen
50
+ T: Number of selected Tokens
51
+ P: pos_dim
52
+ h: n_heads
53
+ d: head_dim
54
+ F: num_freqs == head_dim // 2
55
+ """
56
+
57
+
58
+ def apply_golden_freqs_cis_to_visual_pos(freqs_hFP, pos_BSP) -> torch.Tensor:
59
+ """
60
+ This function is applied once per input batch, and the cached
61
+ freqs_cis is passed through to all layers.
62
+ Safe for Torch-Inductor because it never uses boolean indexing on a symbolic tensor.
63
+ """
64
+ # 1. Boolean mask → integer indices (no unbacked shapes)
65
+ img_mask_BS = E.reduce(~torch.isnan(pos_BSP), 'b s p -> b s', reduction='all')
66
+ idx_b, idx_s = torch.nonzero(img_mask_BS, as_tuple=True) # each shape: (N,)
67
+
68
+ # 2. Gather the positional tensor for those tokens
69
+ pos_tP = pos_BSP[idx_b, idx_s].float() # (N, p)
70
+
71
+ # 3. Project positions onto the frequency table → angles θ
72
+ theta_thF = torch.einsum("tp,hfp->thf", pos_tP, freqs_hFP.float()) # (t, h, f)
73
+
74
+ # 4. Convert to complex numbers on the unit circle
75
+ freqs_cis_thF = torch.polar(torch.ones_like(theta_thF), theta_thF)
76
+ return freqs_cis_thF
77
+
78
+
79
+ def apply_golden_rotary_emb(input_BShd, freqs_cis_thF, pos_BSP) -> torch.Tensor:
80
+ """
81
+ Rotates *only* the image tokens in `input_BShd`. No boolean indexing,
82
+ so it is safe for Torch-Inductor.
83
+ """
84
+ img_mask_BS = E.reduce(~torch.isnan(pos_BSP), 'b s p -> b s', reduction='all')
85
+ idx_b, idx_s = torch.nonzero(img_mask_BS, as_tuple=True) # (N,)
86
+
87
+ input_thd = input_BShd[idx_b, idx_s].float() # (N, h, d)
88
+ x_even = input_thd[..., 0::2] # (N, h, F)
89
+ x_odd = input_thd[..., 1::2] # (N, h, F)
90
+
91
+ cos_thF = freqs_cis_thF.real
92
+ sin_thF = freqs_cis_thF.imag
93
+
94
+ # (a + ib) * (c + id) = (ac - bd) + i(ad + bc)
95
+ rot_even = x_even * cos_thF - x_odd * sin_thF
96
+ rot_odd = x_even * sin_thF + x_odd * cos_thF
97
+
98
+ output_real = torch.empty_like(input_thd)
99
+ output_real[..., 0::2] = rot_even
100
+ output_real[..., 1::2] = rot_odd
101
+ output_real = output_real.type_as(input_BShd)
102
+
103
+ output_BShd = input_BShd.clone()
104
+ output_BShd[idx_b, idx_s] = output_real
105
+
106
+ return output_BShd
107
+
108
+
109
+ def apply_3d_rotary_emb(
110
+ xq: torch.Tensor, # (B, S, H, D)
111
+ xk: torch.Tensor, # (B, S, H, D)
112
+ freqs_cis: torch.Tensor,
113
+ freqs_cis_2d: torch.Tensor | None,
114
+ pos_hw: torch.Tensor | None, # (B,S,3)
115
+ ) -> tuple[torch.Tensor, torch.Tensor]:
116
+ xq_t, xq_hw = xq.chunk(chunks=2, dim=-1)
117
+ xk_t, xk_hw = xk.chunk(chunks=2, dim=-1)
118
+ B, S, H, D = xq.shape
119
+
120
+ xq_t, xk_t = apply_rotary_emb(xq_t, xk_t, freqs_cis)
121
+ if freqs_cis_2d is not None and pos_hw is not None:
122
+ xq_hw = apply_golden_rotary_emb(xq_hw, freqs_cis_2d, pos_hw)
123
+ xk_hw = apply_golden_rotary_emb(xk_hw, freqs_cis_2d, pos_hw)
124
+
125
+ xq_out = torch.concat([xq_t, xq_hw], dim=-1).type_as(xq)
126
+ xk_out = torch.concat([xk_t, xk_hw], dim=-1).type_as(xk)
127
+ return xq_out, xk_out
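A shape-level sanity sketch for the 1D rotary helpers in this file, with made-up sizes; the module name used in the import is a placeholder for however this file ends up being loaded.

    import torch
    from rope import precompute_freqs_cis, apply_rotary_emb   # placeholder module name

    B, S, H, D = 2, 16, 4, 64
    freqs = precompute_freqs_cis(dim=D, end=1024)       # (1024, D//2) complex64 table
    pos = torch.arange(S).unsqueeze(0).expand(B, S)     # (B, S) position ids
    freqs_bsd = freqs[pos]                              # gather per-token rows -> (B, S, D//2)

    xq = torch.randn(B, S, H, D)
    xk = torch.randn(B, S, H, D)
    xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_bsd)
    assert xq_rot.shape == (B, S, H, D) and xk_rot.shape == (B, S, H, D)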
special_tokens_map.json ADDED
@@ -0,0 +1,390 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|pad|>",
4
+ ">>ABSTRACT<<",
5
+ ">>INTRODUCTION<<",
6
+ ">>SUMMARY<<",
7
+ ">>COMMENT<<",
8
+ ">>ANSWER<<",
9
+ ">>QUESTION<<",
10
+ ">>DOMAIN<<",
11
+ ">>PREFIX<<",
12
+ ">>SUFFIX<<",
13
+ ">>MIDDLE<<",
14
+ "<|finetune_right_pad_id|>",
15
+ "<|start_header_id|>",
16
+ "<|end_header_id|>",
17
+ "<|eom_id|>",
18
+ "<|eot_id|>",
19
+ "<|begin_of_text|>",
20
+ ">>TITLE<<",
21
+ "<tool_response>",
22
+ "</tool_response>",
23
+ "<tool_call>",
24
+ "</tool_call>",
25
+ "<schema>",
26
+ "</schema>",
27
+ "<scratch_pad>",
28
+ "</scratch_pad>",
29
+ "<thinking>",
30
+ "</thinking>",
31
+ "<explanation>",
32
+ "</explanation>",
33
+ "<file_sep>",
34
+ "<repo_name>",
35
+ "<tr>",
36
+ "</tr>",
37
+ "<|image|>",
38
+ "<|image_row_sep|>",
39
+ "<|start_of_image|>",
40
+ "<|end_of_image|>",
41
+ "<|start_of_video|>",
42
+ "<|end_of_video|>",
43
+ "<|frame_sep|>",
44
+ "<|start_of_turn|>",
45
+ "<|end_of_turn|>",
46
+ "<|start_of_diffusion_query|>",
47
+ "<|end_of_diffusion_query|>",
48
+ "<|diffusion_query|>",
49
+ "<|object|>",
50
+ "<|coord|>",
51
+ "<|size|>",
52
+ "<|perceive|>",
53
+ "<|image_mask_token|>",
54
+ "<|image_cls|>",
55
+ "<|image_reg_1|>",
56
+ "<|image_reg_2|>",
57
+ "<|image_reg_3|>",
58
+ "<|image_reg_4|>",
59
+ "<|image_reg_5|>",
60
+ "<|image_reg_6|>",
61
+ "<|image_reg_7|>",
62
+ "<|image_reg_8|>",
63
+ "<|DET|>",
64
+ "<|POINTING|>",
65
+ "<|OCR_GROUNDING|>",
66
+ "<|OCR_DOC_PARSER|>",
67
+ "<|OCR_PLAIN|>",
68
+ "<|REF_SEG|>",
69
+ "<|POINT_REF_SEG|>",
70
+ "<|CAPTION|>",
71
+ "<|DETAILED_CAPTION|>",
72
+ "<|seg|>",
73
+ "<|end_of_query|>",
74
+ "<|start_of_query|>",
75
+ "<|task_sep|>",
76
+ "<|QA|>",
77
+ "<|LAYOUT_DETECTION|>",
78
+ "<|category_sep|>",
79
+ "<td>",
80
+ "</td>",
81
+ "<th>",
82
+ "</th>",
83
+ ">>UNUSED_261<<",
84
+ ">>UNUSED_262<<",
85
+ ">>UNUSED_263<<",
86
+ ">>UNUSED_264<<",
87
+ ">>UNUSED_265<<",
88
+ ">>UNUSED_266<<",
89
+ ">>UNUSED_267<<",
90
+ ">>UNUSED_268<<",
91
+ ">>UNUSED_269<<",
92
+ ">>UNUSED_270<<",
93
+ ">>UNUSED_271<<",
94
+ ">>UNUSED_272<<",
95
+ ">>UNUSED_273<<",
96
+ ">>UNUSED_274<<",
97
+ ">>UNUSED_275<<",
98
+ ">>UNUSED_276<<",
99
+ ">>UNUSED_277<<",
100
+ ">>UNUSED_278<<",
101
+ ">>UNUSED_279<<",
102
+ ">>UNUSED_280<<",
103
+ ">>UNUSED_281<<",
104
+ ">>UNUSED_282<<",
105
+ ">>UNUSED_283<<",
106
+ ">>UNUSED_284<<",
107
+ ">>UNUSED_285<<",
108
+ ">>UNUSED_286<<",
109
+ ">>UNUSED_287<<",
110
+ ">>UNUSED_288<<",
111
+ ">>UNUSED_289<<",
112
+ ">>UNUSED_290<<",
113
+ ">>UNUSED_291<<",
114
+ ">>UNUSED_292<<",
115
+ ">>UNUSED_293<<",
116
+ ">>UNUSED_294<<",
117
+ ">>UNUSED_295<<",
118
+ ">>UNUSED_296<<",
119
+ ">>UNUSED_297<<",
120
+ ">>UNUSED_298<<",
121
+ ">>UNUSED_299<<",
122
+ ">>UNUSED_300<<",
123
+ ">>UNUSED_301<<",
124
+ ">>UNUSED_302<<",
125
+ ">>UNUSED_303<<",
126
+ ">>UNUSED_304<<",
127
+ ">>UNUSED_305<<",
128
+ ">>UNUSED_306<<",
129
+ ">>UNUSED_307<<",
130
+ ">>UNUSED_308<<",
131
+ ">>UNUSED_309<<",
132
+ ">>UNUSED_310<<",
133
+ ">>UNUSED_311<<",
134
+ ">>UNUSED_312<<",
135
+ ">>UNUSED_313<<",
136
+ ">>UNUSED_314<<",
137
+ ">>UNUSED_315<<",
138
+ ">>UNUSED_316<<",
139
+ ">>UNUSED_317<<",
140
+ ">>UNUSED_318<<",
141
+ ">>UNUSED_319<<",
142
+ ">>UNUSED_320<<",
143
+ ">>UNUSED_321<<",
144
+ ">>UNUSED_322<<",
145
+ ">>UNUSED_323<<",
146
+ ">>UNUSED_324<<",
147
+ ">>UNUSED_325<<",
148
+ ">>UNUSED_326<<",
149
+ ">>UNUSED_327<<",
150
+ ">>UNUSED_328<<",
151
+ ">>UNUSED_329<<",
152
+ ">>UNUSED_330<<",
153
+ ">>UNUSED_331<<",
154
+ ">>UNUSED_332<<",
155
+ ">>UNUSED_333<<",
156
+ ">>UNUSED_334<<",
157
+ ">>UNUSED_335<<",
158
+ ">>UNUSED_336<<",
159
+ ">>UNUSED_337<<",
160
+ ">>UNUSED_338<<",
161
+ ">>UNUSED_339<<",
162
+ ">>UNUSED_340<<",
163
+ ">>UNUSED_341<<",
164
+ ">>UNUSED_342<<",
165
+ ">>UNUSED_343<<",
166
+ ">>UNUSED_344<<",
167
+ ">>UNUSED_345<<",
168
+ ">>UNUSED_346<<",
169
+ ">>UNUSED_347<<",
170
+ ">>UNUSED_348<<",
171
+ ">>UNUSED_349<<",
172
+ ">>UNUSED_350<<",
173
+ ">>UNUSED_351<<",
174
+ ">>UNUSED_352<<",
175
+ ">>UNUSED_353<<",
176
+ ">>UNUSED_354<<",
177
+ ">>UNUSED_355<<",
178
+ ">>UNUSED_356<<",
179
+ ">>UNUSED_357<<",
180
+ ">>UNUSED_358<<",
181
+ ">>UNUSED_359<<",
182
+ ">>UNUSED_360<<",
183
+ ">>UNUSED_361<<",
184
+ ">>UNUSED_362<<",
185
+ ">>UNUSED_363<<",
186
+ ">>UNUSED_364<<",
187
+ ">>UNUSED_365<<",
188
+ ">>UNUSED_366<<",
189
+ ">>UNUSED_367<<",
190
+ ">>UNUSED_368<<",
191
+ ">>UNUSED_369<<",
192
+ ">>UNUSED_370<<",
193
+ ">>UNUSED_371<<",
194
+ ">>UNUSED_372<<",
195
+ ">>UNUSED_373<<",
196
+ ">>UNUSED_374<<",
197
+ ">>UNUSED_375<<",
198
+ ">>UNUSED_376<<",
199
+ ">>UNUSED_377<<",
200
+ ">>UNUSED_378<<",
201
+ ">>UNUSED_379<<",
202
+ ">>UNUSED_380<<",
203
+ ">>UNUSED_381<<",
204
+ ">>UNUSED_382<<",
205
+ ">>UNUSED_383<<",
206
+ ">>UNUSED_384<<",
207
+ ">>UNUSED_385<<",
208
+ ">>UNUSED_386<<",
209
+ ">>UNUSED_387<<",
210
+ ">>UNUSED_388<<",
211
+ ">>UNUSED_389<<",
212
+ ">>UNUSED_390<<",
213
+ ">>UNUSED_391<<",
214
+ ">>UNUSED_392<<",
215
+ ">>UNUSED_393<<",
216
+ ">>UNUSED_394<<",
217
+ ">>UNUSED_395<<",
218
+ ">>UNUSED_396<<",
219
+ ">>UNUSED_397<<",
220
+ ">>UNUSED_398<<",
221
+ ">>UNUSED_399<<",
222
+ ">>UNUSED_400<<",
223
+ ">>UNUSED_401<<",
224
+ ">>UNUSED_402<<",
225
+ ">>UNUSED_403<<",
226
+ ">>UNUSED_404<<",
227
+ ">>UNUSED_405<<",
228
+ ">>UNUSED_406<<",
229
+ ">>UNUSED_407<<",
230
+ ">>UNUSED_408<<",
231
+ ">>UNUSED_409<<",
232
+ ">>UNUSED_410<<",
233
+ ">>UNUSED_411<<",
234
+ ">>UNUSED_412<<",
235
+ ">>UNUSED_413<<",
236
+ ">>UNUSED_414<<",
237
+ ">>UNUSED_415<<",
238
+ ">>UNUSED_416<<",
239
+ ">>UNUSED_417<<",
240
+ ">>UNUSED_418<<",
241
+ ">>UNUSED_419<<",
242
+ ">>UNUSED_420<<",
243
+ ">>UNUSED_421<<",
244
+ ">>UNUSED_422<<",
245
+ ">>UNUSED_423<<",
246
+ ">>UNUSED_424<<",
247
+ ">>UNUSED_425<<",
248
+ ">>UNUSED_426<<",
249
+ ">>UNUSED_427<<",
250
+ ">>UNUSED_428<<",
251
+ ">>UNUSED_429<<",
252
+ ">>UNUSED_430<<",
253
+ ">>UNUSED_431<<",
254
+ ">>UNUSED_432<<",
255
+ ">>UNUSED_433<<",
256
+ ">>UNUSED_434<<",
257
+ ">>UNUSED_435<<",
258
+ ">>UNUSED_436<<",
259
+ ">>UNUSED_437<<",
260
+ ">>UNUSED_438<<",
261
+ ">>UNUSED_439<<",
262
+ ">>UNUSED_440<<",
263
+ ">>UNUSED_441<<",
264
+ ">>UNUSED_442<<",
265
+ ">>UNUSED_443<<",
266
+ ">>UNUSED_444<<",
267
+ ">>UNUSED_445<<",
268
+ ">>UNUSED_446<<",
269
+ ">>UNUSED_447<<",
270
+ ">>UNUSED_448<<",
271
+ ">>UNUSED_449<<",
272
+ ">>UNUSED_450<<",
273
+ ">>UNUSED_451<<",
274
+ ">>UNUSED_452<<",
275
+ ">>UNUSED_453<<",
276
+ ">>UNUSED_454<<",
277
+ ">>UNUSED_455<<",
278
+ ">>UNUSED_456<<",
279
+ ">>UNUSED_457<<",
280
+ ">>UNUSED_458<<",
281
+ ">>UNUSED_459<<",
282
+ ">>UNUSED_460<<",
283
+ ">>UNUSED_461<<",
284
+ ">>UNUSED_462<<",
285
+ ">>UNUSED_463<<",
286
+ ">>UNUSED_464<<",
287
+ ">>UNUSED_465<<",
288
+ ">>UNUSED_466<<",
289
+ ">>UNUSED_467<<",
290
+ ">>UNUSED_468<<",
291
+ ">>UNUSED_469<<",
292
+ ">>UNUSED_470<<",
293
+ ">>UNUSED_471<<",
294
+ ">>UNUSED_472<<",
295
+ ">>UNUSED_473<<",
296
+ ">>UNUSED_474<<",
297
+ ">>UNUSED_475<<",
298
+ ">>UNUSED_476<<",
299
+ ">>UNUSED_477<<",
300
+ ">>UNUSED_478<<",
301
+ ">>UNUSED_479<<",
302
+ ">>UNUSED_480<<",
303
+ ">>UNUSED_481<<",
304
+ ">>UNUSED_482<<",
305
+ ">>UNUSED_483<<",
306
+ ">>UNUSED_484<<",
307
+ ">>UNUSED_485<<",
308
+ ">>UNUSED_486<<",
309
+ ">>UNUSED_487<<",
310
+ ">>UNUSED_488<<",
311
+ ">>UNUSED_489<<",
312
+ ">>UNUSED_490<<",
313
+ ">>UNUSED_491<<",
314
+ ">>UNUSED_492<<",
315
+ ">>UNUSED_493<<",
316
+ ">>UNUSED_494<<",
317
+ ">>UNUSED_495<<",
318
+ ">>UNUSED_496<<",
319
+ ">>UNUSED_497<<",
320
+ ">>UNUSED_498<<",
321
+ ">>UNUSED_499<<",
322
+ ">>UNUSED_500<<",
323
+ ">>UNUSED_501<<",
324
+ ">>UNUSED_502<<",
325
+ ">>UNUSED_503<<",
326
+ ">>UNUSED_504<<",
327
+ ">>UNUSED_505<<",
328
+ ">>UNUSED_506<<",
329
+ ">>UNUSED_507<<",
330
+ ">>UNUSED_508<<",
331
+ ">>UNUSED_509<<",
332
+ ">>UNUSED_510<<",
333
+ ">>UNUSED_511<<"
334
+ ],
335
+ "eos_token": {
336
+ "content": "<|end_of_text|>",
337
+ "lstrip": false,
338
+ "normalized": false,
339
+ "rstrip": false,
340
+ "single_word": false
341
+ },
342
+ "image_token": "<|image|>",
343
+ "image_cls_token": "<|image_cls|>",
344
+ "image_reg_1_token": "<|image_reg_1|>",
345
+ "image_reg_2_token": "<|image_reg_2|>",
346
+ "image_reg_3_token": "<|image_reg_3|>",
347
+ "image_reg_4_token": "<|image_reg_4|>",
348
+ "image_reg_5_token": "<|image_reg_5|>",
349
+ "image_reg_6_token": "<|image_reg_6|>",
350
+ "image_reg_7_token": "<|image_reg_7|>",
351
+ "image_reg_8_token": "<|image_reg_8|>",
352
+ "image_row_sep_token": "<|image_row_sep|>",
353
+ "start_of_image_token": "<|start_of_image|>",
354
+ "end_of_image_token": "<|end_of_image|>",
355
+ "start_of_video_token": "<|start_of_video|>",
356
+ "end_of_video_token": "<|end_of_video|>",
357
+ "frame_sep_token": "<|frame_sep|>",
358
+ "start_of_turn_token": "<|start_of_turn|>",
359
+ "end_of_turn_token": "<|end_of_turn|>",
360
+ "start_of_diffusion_query_token": "<|start_of_diffusion_query|>",
361
+ "end_of_diffusion_query_token": "<|end_of_diffusion_query|>",
362
+ "diffusion_query_token": "<|diffusion_query|>",
363
+ "object_token": "<|object|>",
364
+ "coord_token": "<|coord|>",
365
+ "size_token": "<|size|>",
366
+ "perceive_token": "<|perceive|>",
367
+ "image_mask_token": "<|image_mask_token|>",
368
+ "det_token": "<|DET|>",
369
+ "pointing_token": "<|POINTING|>",
370
+ "ocr_grounding_token": "<|OCR_GROUNDING|>",
371
+ "ocr_doc_parser_token": "<|OCR_DOC_PARSER|>",
372
+ "ocr_plain_token": "<|OCR_PLAIN|>",
373
+ "ref_seg_token": "<|REF_SEG|>",
374
+ "point_ref_seg_token": "<|POINT_REF_SEG|>",
375
+ "caption_token": "<|CAPTION|>",
376
+ "detailed_caption_token": "<|DETAILED_CAPTION|>",
377
+ "seg_token": "<|seg|>",
378
+ "start_of_query_token": "<|start_of_query|>",
379
+ "end_of_query_token": "<|end_of_query|>",
380
+ "task_sep_token": "<|task_sep|>",
381
+ "qa_token": "<|QA|>",
382
+ "layout_detection_token": "<|LAYOUT_DETECTION|>",
383
+ "category_sep_token": "<|category_sep|>",
384
+ "table_row_start_token": "<tr>",
385
+ "table_row_end_token": "</tr>",
386
+ "table_data_start_token": "<td>",
387
+ "table_data_end_token": "</td>",
388
+ "table_header_start_token": "<th>",
389
+ "table_header_end_token": "</th>"
390
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,110 @@
1
+ {
2
+ "backend": "tokenizers",
3
+ "caption_token": "<|CAPTION|>",
4
+ "category_sep_token": "<|category_sep|>",
5
+ "clean_up_tokenization_spaces": true,
6
+ "coord_token": "<|coord|>",
7
+ "det_token": "<|DET|>",
8
+ "detailed_caption_token": "<|DETAILED_CAPTION|>",
9
+ "diffusion_query_token": "<|diffusion_query|>",
10
+ "end_of_diffusion_query_token": "<|end_of_diffusion_query|>",
11
+ "end_of_image_token": "<|end_of_image|>",
12
+ "end_of_query_token": "<|end_of_query|>",
13
+ "end_of_turn_token": "<|end_of_turn|>",
14
+ "end_of_video_token": "<|end_of_video|>",
15
+ "eos_token": "<|end_of_text|>",
16
+ "frame_sep_token": "<|frame_sep|>",
17
+ "image_cls_token": "<|image_cls|>",
18
+ "image_mask_token": "<|image_mask_token|>",
19
+ "image_reg_1_token": "<|image_reg_1|>",
20
+ "image_reg_2_token": "<|image_reg_2|>",
21
+ "image_reg_3_token": "<|image_reg_3|>",
22
+ "image_reg_4_token": "<|image_reg_4|>",
23
+ "image_reg_5_token": "<|image_reg_5|>",
24
+ "image_reg_6_token": "<|image_reg_6|>",
25
+ "image_reg_7_token": "<|image_reg_7|>",
26
+ "image_reg_8_token": "<|image_reg_8|>",
27
+ "image_row_sep_token": "<|image_row_sep|>",
28
+ "image_token": "<|image|>",
29
+ "is_local": true,
30
+ "layout_detection_token": "<|LAYOUT_DETECTION|>",
31
+ "model_input_names": [
32
+ "input_ids",
33
+ "attention_mask"
34
+ ],
35
+ "model_max_length": 1000000000000000019884624838656,
36
+ "model_specific_special_tokens": {
37
+ "caption_token": "<|CAPTION|>",
38
+ "category_sep_token": "<|category_sep|>",
39
+ "coord_token": "<|coord|>",
40
+ "det_token": "<|DET|>",
41
+ "detailed_caption_token": "<|DETAILED_CAPTION|>",
42
+ "diffusion_query_token": "<|diffusion_query|>",
43
+ "end_of_diffusion_query_token": "<|end_of_diffusion_query|>",
44
+ "end_of_image_token": "<|end_of_image|>",
45
+ "end_of_query_token": "<|end_of_query|>",
46
+ "end_of_turn_token": "<|end_of_turn|>",
47
+ "end_of_video_token": "<|end_of_video|>",
48
+ "frame_sep_token": "<|frame_sep|>",
49
+ "image_cls_token": "<|image_cls|>",
50
+ "image_mask_token": "<|image_mask_token|>",
51
+ "image_reg_1_token": "<|image_reg_1|>",
52
+ "image_reg_2_token": "<|image_reg_2|>",
53
+ "image_reg_3_token": "<|image_reg_3|>",
54
+ "image_reg_4_token": "<|image_reg_4|>",
55
+ "image_reg_5_token": "<|image_reg_5|>",
56
+ "image_reg_6_token": "<|image_reg_6|>",
57
+ "image_reg_7_token": "<|image_reg_7|>",
58
+ "image_reg_8_token": "<|image_reg_8|>",
59
+ "image_row_sep_token": "<|image_row_sep|>",
60
+ "image_token": "<|image|>",
61
+ "layout_detection_token": "<|LAYOUT_DETECTION|>",
62
+ "object_token": "<|object|>",
63
+ "ocr_doc_parser_token": "<|OCR_DOC_PARSER|>",
64
+ "ocr_grounding_token": "<|OCR_GROUNDING|>",
65
+ "ocr_plain_token": "<|OCR_PLAIN|>",
66
+ "perceive_token": "<|perceive|>",
67
+ "point_ref_seg_token": "<|POINT_REF_SEG|>",
68
+ "pointing_token": "<|POINTING|>",
69
+ "qa_token": "<|QA|>",
70
+ "ref_seg_token": "<|REF_SEG|>",
71
+ "seg_token": "<|seg|>",
72
+ "size_token": "<|size|>",
73
+ "start_of_diffusion_query_token": "<|start_of_diffusion_query|>",
74
+ "start_of_image_token": "<|start_of_image|>",
75
+ "start_of_query_token": "<|start_of_query|>",
76
+ "start_of_turn_token": "<|start_of_turn|>",
77
+ "start_of_video_token": "<|start_of_video|>",
78
+ "table_data_end_token": "</td>",
79
+ "table_data_start_token": "<td>",
80
+ "table_header_end_token": "</th>",
81
+ "table_header_start_token": "<th>",
82
+ "table_row_end_token": "</tr>",
83
+ "table_row_start_token": "<tr>",
84
+ "task_sep_token": "<|task_sep|>"
85
+ },
86
+ "object_token": "<|object|>",
87
+ "ocr_doc_parser_token": "<|OCR_DOC_PARSER|>",
88
+ "ocr_grounding_token": "<|OCR_GROUNDING|>",
89
+ "ocr_plain_token": "<|OCR_PLAIN|>",
90
+ "perceive_token": "<|perceive|>",
91
+ "point_ref_seg_token": "<|POINT_REF_SEG|>",
92
+ "pointing_token": "<|POINTING|>",
93
+ "qa_token": "<|QA|>",
94
+ "ref_seg_token": "<|REF_SEG|>",
95
+ "seg_token": "<|seg|>",
96
+ "size_token": "<|size|>",
97
+ "start_of_diffusion_query_token": "<|start_of_diffusion_query|>",
98
+ "start_of_image_token": "<|start_of_image|>",
99
+ "start_of_query_token": "<|start_of_query|>",
100
+ "start_of_turn_token": "<|start_of_turn|>",
101
+ "start_of_video_token": "<|start_of_video|>",
102
+ "table_data_end_token": "</td>",
103
+ "table_data_start_token": "<td>",
104
+ "table_header_end_token": "</th>",
105
+ "table_header_start_token": "<th>",
106
+ "table_row_end_token": "</tr>",
107
+ "table_row_start_token": "<tr>",
108
+ "task_sep_token": "<|task_sep|>",
109
+ "tokenizer_class": "TokenizersBackend"
110
+ }