github-actions[bot] committed on
Commit 5778306 · 0 Parent(s):

Sync from https://github.com/ryanlinjui/menu-text-detection
.checkpoints/.gitkeep ADDED
File without changes
.env.example ADDED
@@ -0,0 +1,3 @@
+ HUGGINGFACE_TOKEN="HUGGINGFACE_TOKEN"
+ GEMINI_API_TOKEN="GEMINI_API_TOKEN"
+ OPENAI_API_TOKEN="OPENAI_API_TOKEN"
.github/workflows/sync.yml ADDED
@@ -0,0 +1,25 @@
+ name: Sync to Hugging Face Spaces
+
+ on:
+   push:
+     branches:
+       - main
+ jobs:
+   sync:
+     name: Sync
+     runs-on: ubuntu-latest
+     steps:
+       - name: Checkout Repository
+         uses: actions/checkout@v4
+
+       - name: Remove bad files
+         run: rm -rf examples assets
+
+       - name: Sync to Hugging Face Spaces
+         uses: JacobLinCool/huggingface-sync@v1
+         with:
+           github: ${{ secrets.GITHUB_TOKEN }}
+           user: ryanlinjui # Hugging Face username or organization name
+           space: menu-text-detection # Hugging Face space name
+           token: ${{ secrets.HF_TOKEN }} # Hugging Face token
+           python_version: 3.11 # Python version
.gitignore ADDED
@@ -0,0 +1,24 @@
+ # mac
+ .DS_Store
+
+ # cache
+ __pycache__
+
+ # datasets
+ datasets
+
+ # papers
+ docs/papers
+
+ # uv
+ .venv
+
+ # gradio
+ .gradio
+
+ # env
+ .env
+
+ # checkpoint
+ .checkpoints/*
+ !.checkpoints/.gitkeep
.python-version ADDED
@@ -0,0 +1 @@
+ 3.11
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 RyanLin
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,65 @@
+ ---
+ title: menu text detection
+ emoji: 🦄
+ colorFrom: indigo
+ colorTo: pink
+ sdk: gradio
+ python_version: 3.11
+ short_description: Extract structured menu information from images into JSON...
+ tags: ["document-understanding", "donut", "fine-tuning", "image-text-to-text", "transformer"]
+ ---
+
+ # Menu Text Detection System
+
+ Extract structured menu information from images into JSON using a fine-tuned E2E model or an LLM.
+
+ [![Gradio Space Demo](https://img.shields.io/badge/GradioSpace-Demo-important?logo=huggingface)](https://huggingface.co/spaces/ryanlinjui/menu-text-detection)
+ [![Hugging Face Models & Datasets](https://img.shields.io/badge/HuggingFace-Models_&_Datasets-important?logo=huggingface)](https://huggingface.co/collections/ryanlinjui/menu-text-detection-670ccf527626bb004bbfb39b)
+
+ https://github.com/user-attachments/assets/80e5d54c-f2c8-4593-ad9b-499e5b71d8f6
+
+ ## 🚀 Features
+ ### Overview
+ Currently extracts the following information from menu images:
+
+ - **Restaurant Name**
+ - **Business Hours**
+ - **Address**
+ - **Phone Number**
+ - **Dish Information**
+   - Name
+   - Price
+
+ > For the JSON schema, see the [tools directory](./tools); an illustrative example follows below.
+
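To make that schema concrete, here is an illustrative output (field values invented for the example; the authoritative definitions live in `tools/schema_gemini.json` and `tools/schema_openai.json`):

```python
# Illustrative only: field values are made up; see tools/ for the real schema.
example_output = {
    "restaurant": "Ryan's Diner",       # "" when the name is not on the menu
    "address": "123 Example Road",      # "" when not available
    "phone": "02-1234-5678",            # "" when not available
    "business_hours": "11:00-21:00",    # "" when not available
    "dishes": [
        {"name": "Beef Noodles", "price": "180"},
        {"name": "Iced Tea", "price": "-1"},  # -1 when the price is missing
    ],
}
```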
+ ### Supported Methods to Extract Menu Information
+ #### Fine-tuned E2E Model and Training Metrics
+ - [**Donut (Document Parsing Task)**](https://huggingface.co/ryanlinjui/donut-base-finetuned-menu) - Base model by [Clova AI (ECCV '22)](https://github.com/clovaai/donut)
+
+ #### LLM Function Calling
+ - Google Gemini API
+ - OpenAI GPT API
+
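Both providers sit behind the same `call` classmethod in `menu/llm/`, so switching models is a one-line change. A minimal sketch, assuming a valid API token and the repo's example image:

```python
from PIL import Image
from menu.llm import GeminiAPI, OpenAIAPI

image = Image.open("examples/menu-hd.jpg")

# Gemini and OpenAI share the same interface: call(images, model, token).
menu = GeminiAPI.call([image], "gemini-2.5-flash", "YOUR_GEMINI_API_TOKEN")
# menu = OpenAIAPI.call([image], "gpt-4o", "YOUR_OPENAI_API_TOKEN")
print(menu)  # dict following the schema in tools/
```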
+ ## 💻 Training / Fine-Tuning
+ ### Setup
+ Use [uv](https://github.com/astral-sh/uv) to set up the development environment:
+
+ ```bash
+ uv sync
+ ```
+
+ > Or use `pip install -r requirements.txt` if you run into any problems.
+
+ ### Training Script (Dataset Collection, Fine-Tuning)
+ Please refer to [`train.ipynb`](./train.ipynb). Use Jupyter Notebook for training:
+
+ ```bash
+ uv run jupyter-notebook
+ ```
+
+ > For VSCode users, please install the Jupyter extension, then select `.venv/bin/python` as your kernel.
+
+ ### Run Demo Locally
+ ```bash
+ uv run python app.py
+ ```
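`app.py` also registers the `handle` function as an API endpoint named `run` via `gr.api`, so the hosted Space can be queried programmatically. A hedged sketch with `gradio_client`; the exact endpoint path and argument handling may vary with the Gradio version:

```python
from gradio_client import Client

# Assumes the public Space and the "run" endpoint registered in app.py.
client = Client("ryanlinjui/menu-text-detection")
result = client.predict(
    images=["https://raw.githubusercontent.com/ryanlinjui/menu-text-detection/main/examples/menu-hd.jpg"],
    model="Donut (Document Parsing Task) Fine-tuned Model",
    api_token="",  # only needed for the LLM models
    api_name="/run",
)
print(result)  # JSON string with the extracted menu information
```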
app.py ADDED
@@ -0,0 +1,327 @@
+ import os
+ import json
+ import base64
+ import requests
+ from io import BytesIO
+ from typing import List
+
+ import gradio as gr
+ from PIL import Image
+ from dotenv import load_dotenv
+ from pillow_heif import register_heif_opener
+
+ from menu.llm import (
+     GeminiAPI,
+     OpenAIAPI
+ )
+ from menu.donut import DonutFinetuned
+
+ donut_finetuned = DonutFinetuned("ryanlinjui/donut-base-finetuned-menu")
+
+ register_heif_opener()
+ load_dotenv(override=True)
+ GEMINI_API_TOKEN = os.getenv("GEMINI_API_TOKEN", "")
+ OPENAI_API_TOKEN = os.getenv("OPENAI_API_TOKEN", "")
+
+ SOURCE_CODE_GH_URL = "https://github.com/ryanlinjui/menu-text-detection"
+ BADGE_URL = "https://img.shields.io/badge/GitHub_Code-Click_Here!!-default?logo=github"
+
+ GITHUB_RAW_URL = "https://raw.githubusercontent.com/ryanlinjui/menu-text-detection/main"
+ EXAMPLE_IMAGE_LIST = [
+     [f"{GITHUB_RAW_URL}/examples/menu-hd.jpg"],
+     [f"{GITHUB_RAW_URL}/examples/menu-vs.jpg"],
+     [f"{GITHUB_RAW_URL}/examples/menu-si.jpg"]
+ ]
+ FINETUNED_MODEL_LIST = [
+     "Donut (Document Parsing Task) Fine-tuned Model"
+ ]
+ LLM_MODEL_LIST = [
+     "gemini-2.5-pro",
+     "gemini-2.5-flash",
+     "gemini-2.0-flash",
+     "gpt-4.1",
+     "gpt-4o",
+     "o4-mini"
+ ]
+ CSS_STYLE = """
+ .image-panel img {
+     max-height: 500px;
+     margin-top: -100px;
+ }
+ .large-text textarea {
+     font-size: 20px !important;
+     height: 600px !important;
+     width: 100% !important;
+ }
+ .control-row {
+     margin-top: -10px !important;
+     margin-bottom: -10px !important;
+     align-items: center !important;
+     justify-content: center !important;
+ }
+ .page-info {
+     text-align: center !important;
+     font-size: 20px !important;
+     display: flex !important;
+     align-items: center !important;
+     justify-content: center !important;
+     height: 100% !important;
+     font-weight: 900 !important;
+     color: #374151; /* Darker gray for clarity */
+ }
+ .page-info p {
+     margin: 0 !important;
+     width: 100% !important;
+     text-align: center !important;
+ }
+ .upload-btn {
+     margin-top: 2px !important;
+     background-color: #e0f2fe !important; /* Light blue background */
+     color: #0369a1 !important; /* Dark blue text */
+     border: 1px solid #0ea5e9 !important;
+ }
+ .upload-btn:hover {
+     background-color: #bae6fd !important;
+ }
+ .clear-btn {
+     margin-top: 2px !important;
+ }
+ .image-container {
+     height: 650px !important;
+     display: flex;
+     flex-direction: column;
+     border: 1px solid #e5e7eb;
+     border-radius: 8px;
+     padding: 4px;
+ }
+ """
+
+ def handle(images: List[str], model: str, api_token: str) -> str:
+     if not images:
+         raise gr.Error("Please upload an image first.")
+
+     # Convert to PIL Images
+     pil_images = []
+     for img in images:
+         if img.startswith("http://") or img.startswith("https://"):
+             try:
+                 response = requests.get(img)
+                 response.raise_for_status()
+                 pil_images.append(Image.open(BytesIO(response.content)))
+             except Exception as e:
+                 raise gr.Error(f"Failed to load image from URL: {str(e)}")
+         elif img.startswith("data:image/") and ";base64," in img:
+             try:
+                 _, encoded = img.split(";base64,", 1)
+                 data = base64.b64decode(encoded)
+                 pil_images.append(Image.open(BytesIO(data)))
+             except Exception as e:
+                 raise gr.Error(f"Failed to decode Base64 image: {str(e)}")
+         else:
+             pil_images.append(Image.open(img))
+
+     if model == FINETUNED_MODEL_LIST[0]:
+         result = donut_finetuned.predict(pil_images[0])
+
+     elif model in LLM_MODEL_LIST:
+         if len(api_token) < 10:
+             raise gr.Error(f"Please provide a valid token for {model}.")
+         try:
+             if model in LLM_MODEL_LIST[:3]:
+                 result = GeminiAPI.call(pil_images, model, api_token)
+             else:
+                 result = OpenAIAPI.call(pil_images, model, api_token)
+         except Exception as e:
+             raise gr.Error(f"Failed to process with API model {model}: {str(e)}")
+     else:
+         raise gr.Error("Invalid model selection. Please choose a valid model.")
+
+     return json.dumps(result, indent=4, ensure_ascii=False, sort_keys=True)
+
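`handle` accepts local file paths, HTTP(S) URLs, or Base64 data URIs; note that the Donut branch only consumes the first image, while the LLM branch sends every page at once. A quick sanity check, assuming the repo's example image is available locally:

```python
# Hedged sketch: run from the repo root with examples/menu-hd.jpg present.
output = handle(
    images=["examples/menu-hd.jpg"],  # local path; URLs and data URIs also work
    model="Donut (Document Parsing Task) Fine-tuned Model",
    api_token="",  # unused for the fine-tuned model
)
print(output)  # pretty-printed JSON string
```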
+ def UserInterface() -> gr.Blocks:
+     with gr.Blocks(delete_cache=(86400, 86400)) as gradio_interface:
+         gr.HTML(f'<a href="{SOURCE_CODE_GH_URL}"><img src="{BADGE_URL}" alt="GitHub Code"/></a>')
+         gr.Markdown("# Menu Text Detection")
+
+         images_state = gr.State([])
+         current_index_state = gr.State(0)
+
+         with gr.Row():
+             with gr.Column(scale=1, min_width=500):
+                 gr.Markdown("## 📷 Menu Image")
+
+                 with gr.Column(elem_classes="image-container"):
+                     menu_image_display = gr.Image(
+                         label="Input menu image",
+                         type="filepath",
+                         elem_classes="image-panel",
+                         interactive=False,
+                         show_label=True,
+                         height="100%",
+                         width="100%"
+                     )
+                     with gr.Row(elem_classes="control-row"):
+                         prev_btn = gr.Button("◀️ Previous", variant="secondary", scale=1)
+                         with gr.Column(scale=2, min_width=50):
+                             page_info = gr.Markdown("Page 1 / 1", elem_classes="page-info")
+                         next_btn = gr.Button("Next ▶️", variant="secondary", scale=1)
+
+                 with gr.Row():
+                     upload_btn = gr.UploadButton(
+                         "📷 Upload Menu Images",
+                         file_types=["image"],
+                         file_count="multiple",
+                         scale=3,
+                         elem_classes="upload-btn",
+                         variant="primary"
+                     )
+                     clear_btn = gr.Button("🗑️ Remove", variant="stop", scale=1, elem_classes="clear-btn")
+
+                 gr.Markdown("## 🤖 Model Selection")
+                 model_choice_dropdown = gr.Dropdown(
+                     choices=FINETUNED_MODEL_LIST + LLM_MODEL_LIST,
+                     value=FINETUNED_MODEL_LIST[0],
+                     label="Select Text Detection Model"
+                 )
+
+                 api_token_textbox = gr.Textbox(
+                     label="API Token",
+                     placeholder="Enter your API token here...",
+                     type="password",
+                     visible=False
+                 )
+
+                 generate_button = gr.Button("Generate Menu Information", variant="primary")
+                 example_receiver = gr.Image(visible=False, label="Example Preview", type="filepath")
+
+                 examples_component = gr.Examples(
+                     examples=[[img_list[0]] for img_list in EXAMPLE_IMAGE_LIST],
+                     inputs=example_receiver,
+                     label="Example Menu Images"
+                 )
+
+             with gr.Column(scale=1):
+                 gr.Markdown("## 🍽️ Menu Info")
+                 menu_json_textbox = gr.Textbox(
+                     label="Output JSON",
+                     interactive=True,
+                     text_align="left",
+                     elem_classes="large-text"
+                 )
+
+         def update_display(images, index):
+             if not images:
+                 return None, "Page 1 / 1"
+             idx = max(0, min(index, len(images) - 1))
+             return images[idx], f"Page {idx + 1} / {len(images)}"
+
+         def on_upload(new_files, current_images):
+             if current_images is None:
+                 current_images = []
+             if new_files:
+                 new_paths = [f.name for f in new_files]
+                 current_images.extend(new_paths)
+             new_index = len(current_images) - 1
+             img, info = update_display(current_images, new_index)
+             return current_images, new_index, img, info
+
+         upload_btn.upload(
+             fn=on_upload,
+             inputs=[upload_btn, images_state],
+             outputs=[images_state, current_index_state, menu_image_display, page_info]
+         )
+
+         def on_clear(images, index):
+             if not images:
+                 return [], 0, None, "Page 1 / 1"
+
+             new_images = list(images)
+             if 0 <= index < len(new_images):
+                 new_images.pop(index)
+
+             if not new_images:
+                 return [], 0, None, "Page 1 / 1"
+
+             new_index = index
+             if new_index >= len(new_images):
+                 new_index = len(new_images) - 1
+
+             img, info = update_display(new_images, new_index)
+             return new_images, new_index, img, info
+
+         clear_btn.click(
+             fn=on_clear,
+             inputs=[images_state, current_index_state],
+             outputs=[images_state, current_index_state, menu_image_display, page_info]
+         )
+
+         def on_prev(images, index):
+             if not images:
+                 return 0, None, "Page 1 / 1"
+             new_index = max(0, index - 1)
+             img, info = update_display(images, new_index)
+             return new_index, img, info
+
+         def on_next(images, index):
+             if not images:
+                 return 0, None, "Page 1 / 1"
+             new_index = min(len(images) - 1, index + 1)
+             img, info = update_display(images, new_index)
+             return new_index, img, info
+
+         prev_btn.click(on_prev, [images_state, current_index_state], [current_index_state, menu_image_display, page_info])
+         next_btn.click(on_next, [images_state, current_index_state], [current_index_state, menu_image_display, page_info])
+
+         def on_example_click(evt: gr.SelectData):
+             if evt.index is None:
+                 return [], 0, None, "Page 1 / 1"
+
+             # Retrieve the full batch based on the clicked index
+             if 0 <= evt.index < len(EXAMPLE_IMAGE_LIST):
+                 current_images = EXAMPLE_IMAGE_LIST[evt.index]
+             else:
+                 current_images = []
+
+             new_index = 0
+             img, info = update_display(current_images, new_index)
+             return current_images, new_index, img, info
+
+         examples_component.dataset.select(
+             fn=on_example_click,
+             inputs=None,
+             outputs=[images_state, current_index_state, menu_image_display, page_info]
+         )
+
+         def update_token_visibility(choice):
+             if choice in LLM_MODEL_LIST:
+                 current_token = ""
+                 if choice in LLM_MODEL_LIST[:3]:
+                     current_token = GEMINI_API_TOKEN
+                 else:
+                     current_token = OPENAI_API_TOKEN
+                 return gr.update(visible=True, value=current_token)
+             else:
+                 return gr.update(visible=False)
+
+         model_choice_dropdown.change(
+             fn=update_token_visibility,
+             inputs=model_choice_dropdown,
+             outputs=api_token_textbox
+         )
+
+         generate_button.click(
+             fn=handle,
+             inputs=[images_state, model_choice_dropdown, api_token_textbox],
+             outputs=menu_json_textbox
+         )
+
+         gr.api(
+             fn=handle,
+             api_name="run"
+         )
+
+     return gradio_interface
+
+ if __name__ == "__main__":
+     demo = UserInterface()
+     demo.launch(css=CSS_STYLE)
menu/donut.py ADDED
@@ -0,0 +1,472 @@
+ """
+ This file is modified from the HuggingFace transformers tutorial script for fine-tuning Donut on a custom dataset.
+ It was converted from a `.ipynb` notebook into a module for better reusability and maintainability.
+ Reference: https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Donut/CORD/Fine_tune_Donut_on_a_custom_dataset_(CORD)_with_PyTorch_Lightning.ipynb
+ """
+
+ import re
+ import random
+ from typing import Any, List, Tuple, Dict
+
+ import torch
+ import numpy as np
+ from PIL import Image
+ from tqdm.auto import tqdm
+ from nltk import edit_distance
+ import pytorch_lightning as pl
+ from datasets import DatasetDict
+ from donut import JSONParseEvaluator
+ from huggingface_hub import upload_folder
+ from pillow_heif import register_heif_opener
+ from pytorch_lightning.callbacks import Callback
+ from pytorch_lightning.loggers import TensorBoardLogger
+ from torch.utils.data import (
+     Dataset,
+     DataLoader
+ )
+ from transformers import (
+     DonutProcessor,
+     VisionEncoderDecoderModel,
+     VisionEncoderDecoderConfig
+ )
+
+ TASK_PROMPT_NAME = "<s_menu-text-detection>"
+ register_heif_opener()
+
+ class DonutFinetuned:
+     def __init__(self, pretrained_model_repo_id: str = "ryanlinjui/donut-test"):
+         self.device = (
+             "cuda"
+             if torch.cuda.is_available()
+             else "mps" if torch.backends.mps.is_available() else "cpu"
+         )
+         self.processor = DonutProcessor.from_pretrained(pretrained_model_repo_id)
+         self.model = VisionEncoderDecoderModel.from_pretrained(pretrained_model_repo_id)
+         self.model.eval()
+         self.model.to(self.device)
+         print(f"Using {self.device} device")
+
+     def predict(self, image: Image.Image) -> Dict[str, Any]:
+         # prepare encoder inputs
+         pixel_values = self.processor(image.convert("RGB"), return_tensors="pt").pixel_values
+         pixel_values = pixel_values.to(self.device)
+
+         # prepare decoder inputs
+         decoder_input_ids = self.processor.tokenizer(TASK_PROMPT_NAME, add_special_tokens=False, return_tensors="pt").input_ids
+         decoder_input_ids = decoder_input_ids.to(self.device)
+
+         # autoregressively generate sequence
+         outputs = self.model.generate(
+             pixel_values,
+             decoder_input_ids=decoder_input_ids,
+             max_length=self.model.decoder.config.max_position_embeddings,
+             early_stopping=True,
+             pad_token_id=self.processor.tokenizer.pad_token_id,
+             eos_token_id=self.processor.tokenizer.eos_token_id,
+             use_cache=True,
+             num_beams=1,
+             bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
+             return_dict_in_generate=True
+         )
+
+         # turn into JSON
+         seq = self.processor.batch_decode(outputs.sequences)[0]
+         seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
+         seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
+         seq = self.processor.token2json(seq)
+         return seq
+
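`predict` can return structured data because Donut emits the JSON as a flat tagged token sequence, which `DonutProcessor.token2json` folds back into a dict. A rough illustration of that last step (the token string is invented for the example):

```python
# Hypothetical decoded sequence after the task-start and EOS tokens are stripped:
seq = "<s_restaurant>Ryan's Diner</s_restaurant><s_dishes><s_name>Latte</s_name><s_price>120</s_price></s_dishes>"
# processor.token2json(seq) would yield roughly:
# {"restaurant": "Ryan's Diner", "dishes": {"name": "Latte", "price": "120"}}
```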
+     def evaluate(self, dataset: Dataset, ground_truth_key: str = "ground_truth") -> Tuple[Dict[str, Any], List[Any]]:
+         output_list = []
+         accs = []
+         ted_accs = []
+         f1_accs = []
+
+         for idx, sample in tqdm(enumerate(dataset), total=len(dataset)):
+             seq = self.predict(sample["image"])
+             ground_truth = sample[ground_truth_key]
+
+             # Original JSON accuracy
+             evaluator = JSONParseEvaluator()
+             score = evaluator.cal_acc(seq, ground_truth)
+             accs.append(score)
+             output_list.append(seq)
+
+             # TED (Tree Edit Distance) Accuracy
+             # Convert predictions and ground truth to string format for comparison
+             pred_str = str(seq) if seq else ""
+             gt_str = str(ground_truth) if ground_truth else ""
+
+             # Calculate normalized edit distance (1 - normalized_edit_distance = accuracy)
+             if len(pred_str) == 0 and len(gt_str) == 0:
+                 ted_acc = 1.0
+             elif len(pred_str) == 0 or len(gt_str) == 0:
+                 ted_acc = 0.0
+             else:
+                 edit_dist = edit_distance(pred_str, gt_str)
+                 max_len = max(len(pred_str), len(gt_str))
+                 ted_acc = 1 - (edit_dist / max_len)
+             ted_accs.append(ted_acc)
+
+             # F1 Score Accuracy (character-level)
+             if len(pred_str) == 0 and len(gt_str) == 0:
+                 f1_acc = 1.0
+             elif len(pred_str) == 0 or len(gt_str) == 0:
+                 f1_acc = 0.0
+             else:
+                 # Character-level precision and recall
+                 pred_chars = set(pred_str)
+                 gt_chars = set(gt_str)
+
+                 if len(pred_chars) == 0:
+                     precision = 0.0
+                 else:
+                     precision = len(pred_chars.intersection(gt_chars)) / len(pred_chars)
+
+                 if len(gt_chars) == 0:
+                     recall = 0.0
+                 else:
+                     recall = len(pred_chars.intersection(gt_chars)) / len(gt_chars)
+
+                 if precision + recall == 0:
+                     f1_acc = 0.0
+                 else:
+                     f1_acc = 2 * (precision * recall) / (precision + recall)
+             f1_accs.append(f1_acc)
+
+         scores = {
+             "accuracies": accs,
+             "mean_accuracy": np.mean(accs),
+             "ted_accuracies": ted_accs,
+             "mean_ted_accuracy": np.mean(ted_accs),
+             "f1_accuracies": f1_accs,
+             "mean_f1_accuracy": np.mean(f1_accs),
+             "length": len(accs)
+         }
+         return scores, output_list
+
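To make the character-level F1 concrete: it compares *sets* of characters, so order and repetition are ignored. A small worked example under that definition:

```python
# pred_str = "abc", gt_str = "abd"; the shared characters are {"a", "b"}.
precision = 2 / 3  # two of the three predicted characters appear in the ground truth
recall = 2 / 3     # two of the three ground-truth characters were predicted
f1 = 2 * (precision * recall) / (precision + recall)  # = 2/3 ≈ 0.667
```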
+ class DonutTrainer:
+     processor = None
+     max_length = 768
+     image_size = [1280, 960]
+     added_tokens = []
+     train_dataloader = None
+     val_dataloader = None
+     huggingface_model_id = None
+
+     class DonutDataset(Dataset):
+         """
+         PyTorch Dataset for Donut. This class takes a HuggingFace Dataset as input.
+
+         Each row consists of an image path (png/jpg/jpeg) and ground-truth data (json/jsonl/txt),
+         and it will be converted into pixel_values (vectorized image) and labels (input_ids of the tokenized string).
+
+         Args:
+             dataset: HuggingFace DatasetDict containing the dataset to be used
+             max_length: the max number of tokens for the target sequences
+             split: whether to load the "train", "validation" or "test" split
+             ignore_id: ignore_index for torch.nn.CrossEntropyLoss
+             task_start_token: the special token to be fed to the decoder to conduct the target task
+             prompt_end_token: the special token at the end of the sequences
+             sort_json_key: whether or not to sort the JSON keys
+         """
+
+         def __init__(
+             self,
+             dataset: DatasetDict,
+             ground_truth_key: str,
+             max_length: int,
+             split: str = "train",
+             ignore_id: int = -100,
+             task_start_token: str = "<s>",
+             prompt_end_token: str = None,
+             sort_json_key: bool = True,
+         ):
+             super().__init__()
+
+             self.dataset = dataset[split]
+             self.ground_truth_key = ground_truth_key
+             self.max_length = max_length
+             self.split = split
+             self.ignore_id = ignore_id
+             self.task_start_token = task_start_token
+             self.prompt_end_token = prompt_end_token if prompt_end_token else task_start_token
+             self.sort_json_key = sort_json_key
+
+             self.dataset_length = len(self.dataset)
+
+             self.gt_token_sequences = []
+             for sample in self.dataset:
+                 ground_truth = sample[self.ground_truth_key]
+                 self.gt_token_sequences.append(
+                     [
+                         self.json2token(
+                             gt_json,
+                             update_special_tokens_for_json_key=self.split == "train",
+                             sort_json_key=self.sort_json_key,
+                         )
+                         + DonutTrainer.processor.tokenizer.eos_token
+                         for gt_json in [ground_truth]  # load json from list of json
+                     ]
+                 )
+
+             self.add_tokens([self.task_start_token, self.prompt_end_token])
+             self.prompt_end_token_id = DonutTrainer.processor.tokenizer.convert_tokens_to_ids(self.prompt_end_token)
+
+         def json2token(self, obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
+             """
+             Convert an ordered JSON object into a token sequence
+             """
+             if type(obj) == dict:
+                 if len(obj) == 1 and "text_sequence" in obj:
+                     return obj["text_sequence"]
+                 else:
+                     output = ""
+                     if sort_json_key:
+                         keys = sorted(obj.keys(), reverse=True)
+                     else:
+                         keys = obj.keys()
+                     for k in keys:
+                         if update_special_tokens_for_json_key:
+                             self.add_tokens([fr"<s_{k}>", fr"</s_{k}>"])
+                         output += (
+                             fr"<s_{k}>"
+                             + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
+                             + fr"</s_{k}>"
+                         )
+                     return output
+             elif type(obj) == list:
+                 return r"<sep/>".join(
+                     [self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
+                 )
+             else:
+                 obj = str(obj)
+                 if f"<{obj}/>" in DonutTrainer.added_tokens:
+                     obj = f"<{obj}/>"  # for categorical special tokens
+                 return obj
+
+ def add_tokens(self, list_of_tokens: List[str]):
249
+ """
250
+ Add special tokens to tokenizer and resize the token embeddings of the decoder
251
+ """
252
+ newly_added_num = DonutTrainer.processor.tokenizer.add_tokens(list_of_tokens)
253
+ if newly_added_num > 0:
254
+ DonutTrainer.model.decoder.resize_token_embeddings(len(DonutTrainer.processor.tokenizer))
255
+ DonutTrainer.added_tokens.extend(list_of_tokens)
256
+
257
+ def __len__(self) -> int:
258
+ return self.dataset_length
259
+
260
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
261
+ """
262
+ Load image from image_path of given dataset_path and convert into input_tensor and labels
263
+ Convert gt data into input_ids (tokenized string)
264
+ Returns:
265
+ input_tensor : preprocessed image
266
+ input_ids : tokenized gt_data
267
+ labels : masked labels (model doesn't need to predict prompt and pad token)
268
+ """
269
+ sample = self.dataset[idx]
270
+
271
+ # inputs
272
+ pixel_values = DonutTrainer.processor(sample["image"], random_padding=self.split == "train", return_tensors="pt").pixel_values
273
+ pixel_values = pixel_values.squeeze()
274
+
275
+ # targets
276
+ target_sequence = random.choice(self.gt_token_sequences[idx]) # can be more than one, e.g., DocVQA Task 1
277
+ input_ids = DonutTrainer.processor.tokenizer(
278
+ target_sequence,
279
+ add_special_tokens=False,
280
+ max_length=self.max_length,
281
+ padding="max_length",
282
+ truncation=True,
283
+ return_tensors="pt",
284
+ )["input_ids"].squeeze(0)
285
+
286
+ labels = input_ids.clone()
287
+ labels[labels == DonutTrainer.processor.tokenizer.pad_token_id] = self.ignore_id # model doesn't need to predict pad token
288
+ # labels[: torch.nonzero(labels == self.prompt_end_token_id).sum() + 1] = self.ignore_id # model doesn't need to predict prompt (for VQA)
289
+ return pixel_values, labels, target_sequence
290
+
291
+
292
+ class DonutModelPLModule(pl.LightningModule):
293
+ def __init__(self, config, processor, model):
294
+ super().__init__()
295
+ self.config = config
296
+ self.processor = processor
297
+ self.model = model
298
+
299
+ def training_step(self, batch, batch_idx):
300
+ pixel_values, labels, _ = batch
301
+
302
+ outputs = self.model(pixel_values, labels=labels)
303
+ loss = outputs.loss
304
+ self.log("train_loss", loss)
305
+ return loss
306
+
307
+ def validation_step(self, batch, batch_idx, dataset_idx=0):
308
+ pixel_values, labels, answers = batch
309
+ batch_size = pixel_values.shape[0]
310
+ # we feed the prompt to the model
311
+ decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)
312
+
313
+ outputs = self.model.generate(pixel_values,
314
+ decoder_input_ids=decoder_input_ids,
315
+ max_length=DonutTrainer.max_length,
316
+ early_stopping=True,
317
+ pad_token_id=self.processor.tokenizer.pad_token_id,
318
+ eos_token_id=self.processor.tokenizer.eos_token_id,
319
+ use_cache=True,
320
+ num_beams=1,
321
+ bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
322
+ return_dict_in_generate=True,)
323
+
324
+ predictions = []
325
+ for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
326
+ seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
327
+ seq = re.sub(r"<.*?>", "", seq, count=1).strip() # remove first task start token
328
+ predictions.append(seq)
329
+
330
+ scores = []
331
+ for pred, answer in zip(predictions, answers):
332
+ pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
333
+ # NOT NEEDED ANYMORE
334
+ # answer = re.sub(r"<.*?>", "", answer, count=1)
335
+ answer = answer.replace(self.processor.tokenizer.eos_token, "")
336
+ scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))
337
+
338
+ if self.config.get("verbose", False) and len(scores) == 1:
339
+ print(f"Prediction: {pred}")
340
+ print(f" Answer: {answer}")
341
+ print(f" Normed ED: {scores[0]}")
342
+
343
+ val_edit_distance = np.mean(scores)
344
+ self.log("val_edit_distance", val_edit_distance)
345
+ print(f"Validation Edit Distance: {val_edit_distance}")
346
+
347
+ return scores
348
+
349
+ def configure_optimizers(self):
350
+ # you could also add a learning rate scheduler if you want
351
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))
352
+
353
+ return optimizer
354
+
355
+ def train_dataloader(self):
356
+ return DonutTrainer.train_dataloader
357
+
358
+ def val_dataloader(self):
359
+ return DonutTrainer.val_dataloader
360
+
361
+ class PushToHubCallback(Callback):
362
+ def on_train_epoch_end(self, trainer, pl_module):
363
+ print(f"Pushing model to the hub, epoch {trainer.current_epoch}")
364
+ pl_module.model.push_to_hub(DonutTrainer.huggingface_model_id, commit_message=f"Training in progress, epoch {trainer.current_epoch}")
365
+ self._upload_logs(trainer.logger.log_dir, trainer.current_epoch)
366
+
367
+ def on_train_end(self, trainer, pl_module):
368
+ print(f"Pushing model to the hub after training")
369
+ pl_module.processor.push_to_hub(DonutTrainer.huggingface_model_id,commit_message=f"Training done")
370
+ pl_module.model.push_to_hub(DonutTrainer.huggingface_model_id, commit_message=f"Training done")
371
+ self._upload_logs(trainer.logger.log_dir, "final")
372
+
373
+ def _upload_logs(self, log_dir: str, epoch_info):
374
+ try:
375
+ print(f"Attempting to upload logs from: {log_dir}")
376
+ upload_folder(folder_path=log_dir, repo_id=DonutTrainer.huggingface_model_id,
377
+ path_in_repo="tensorboard_logs",
378
+ commit_message=f"Upload logs - epoch {epoch_info}", ignore_patterns=["*.tmp", "*.lock"])
379
+ print(f"Successfully uploaded logs for epoch {epoch_info}")
380
+ except Exception as e:
381
+ print(f"Failed to upload logs: {e}")
382
+ pass
383
+
384
+ @classmethod
385
+ def train(
386
+ cls,
387
+ dataset: DatasetDict,
388
+ pretrained_model_repo_id: str,
389
+ huggingface_model_id: str,
390
+ epochs: int,
391
+ train_batch_size: int,
392
+ val_batch_size: int,
393
+ learning_rate: float,
394
+ val_check_interval: float,
395
+ check_val_every_n_epoch: int,
396
+ gradient_clip_val: float,
397
+ num_training_samples_per_epoch: int,
398
+ num_nodes: int,
399
+ warmup_steps: int,
400
+ ground_truth_key: str = "ground_truth"
401
+ ):
402
+ cls.huggingface_model_id = huggingface_model_id
403
+ config = VisionEncoderDecoderConfig.from_pretrained(pretrained_model_repo_id)
404
+ config.encoder.image_size = cls.image_size
405
+ config.decoder.max_length = cls.max_length
406
+
407
+ cls.processor = DonutProcessor.from_pretrained(pretrained_model_repo_id)
408
+ cls.model = VisionEncoderDecoderModel.from_pretrained(pretrained_model_repo_id, config=config)
409
+ cls.processor.image_processor.size = cls.image_size[::-1]
410
+ cls.processor.image_processor.do_align_long_axis = False
411
+
412
+ train_dataset = cls.DonutDataset(
413
+ dataset=dataset,
414
+ ground_truth_key=ground_truth_key,
415
+ max_length=cls.max_length,
416
+ split="train",
417
+ task_start_token=TASK_PROMPT_NAME,
418
+ prompt_end_token=TASK_PROMPT_NAME,
419
+ sort_json_key=True
420
+ )
421
+ val_dataset = cls.DonutDataset(
422
+ dataset=dataset,
423
+ ground_truth_key=ground_truth_key,
424
+ max_length=cls.max_length,
425
+ split="validation",
426
+ task_start_token=TASK_PROMPT_NAME,
427
+ prompt_end_token=TASK_PROMPT_NAME,
428
+ sort_json_key=True
429
+ )
430
+
431
+ cls.model.config.pad_token_id = cls.processor.tokenizer.pad_token_id
432
+ cls.model.config.decoder_start_token_id = cls.processor.tokenizer.convert_tokens_to_ids([TASK_PROMPT_NAME])[0]
433
+
434
+ cls.train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
435
+ cls.val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)
436
+
437
+ config = {
438
+ "max_epochs": epochs,
439
+ "val_check_interval": val_check_interval, # how many times we want to validate during an epoch
440
+ "check_val_every_n_epoch": check_val_every_n_epoch,
441
+ "gradient_clip_val": gradient_clip_val,
442
+ "num_training_samples_per_epoch": num_training_samples_per_epoch,
443
+ "lr": learning_rate,
444
+ "train_batch_sizes": [train_batch_size],
445
+ "val_batch_sizes": [val_batch_size],
446
+ # "seed":2022,
447
+ "num_nodes": num_nodes,
448
+ "warmup_steps": warmup_steps, # 10%
449
+ "result_path": "./.checkpoints",
450
+ "verbose": True,
451
+ }
452
+ model_module = cls.DonutModelPLModule(config, cls.processor, cls.model)
453
+
454
+ device = (
455
+ "cuda"
456
+ if torch.cuda.is_available()
457
+ else "mps" if torch.backends.mps.is_available() else "cpu"
458
+ )
459
+ print(f"Using {device} device")
460
+ trainer = pl.Trainer(
461
+ accelerator="gpu" if device == "cuda" else "mps" if device == "mps" else "cpu",
462
+ devices=1 if device == "cuda" else 0,
463
+ max_epochs=config.get("max_epochs"),
464
+ val_check_interval=config.get("val_check_interval"),
465
+ check_val_every_n_epoch=config.get("check_val_every_n_epoch"),
466
+ gradient_clip_val=config.get("gradient_clip_val"),
467
+ precision=16 if device == "cuda" else 32, # we'll use mixed precision if device == "cuda"
468
+ num_sanity_val_steps=0,
469
+ logger=TensorBoardLogger(save_dir="./.checkpoints", name="donut_training", version=None),
470
+ callbacks=[cls.PushToHubCallback()]
471
+ )
472
+ trainer.fit(model_module)
menu/llm/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .gemini import GeminiAPI
+ from .openai import OpenAIAPI
menu/llm/base.py ADDED
@@ -0,0 +1,12 @@
+ from typing import List
+ from abc import ABC, abstractmethod
+
+ from PIL import Image
+
+ PROMPT = "The provided images display a menu. IMPORTANT: There may be MULTIPLE images representing different pages. You MUST examine EVERY image provided and combine all extracted information into the final result. Do not miss any dishes from any page."
+
+ class LLMBase(ABC):
+     @classmethod
+     @abstractmethod
+     def call(cls, images: List[Image.Image], model: str, token: str) -> dict:
+         raise NotImplementedError
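Any new provider only has to implement this one classmethod. A minimal sketch of a hypothetical provider (the class and its internals are invented purely to illustrate the contract):

```python
from typing import List

from PIL import Image

from menu.llm.base import LLMBase, PROMPT

class EchoAPI(LLMBase):
    """Hypothetical provider; illustrates the LLMBase contract only."""
    @classmethod
    def call(cls, images: List[Image.Image], model: str, token: str) -> dict:
        # A real implementation would send `images` plus PROMPT to the model
        # identified by `model`, authenticate with `token`, and return the
        # function-call arguments as a dict matching tools/schema_*.json.
        return {"restaurant": "", "address": "", "phone": "", "business_hours": "", "dishes": []}
```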
menu/llm/gemini.py ADDED
@@ -0,0 +1,36 @@
+ import json
+ from typing import List
+
+ from PIL import Image
+ from google import genai
+ from google.genai import types
+
+ from .base import LLMBase, PROMPT
+
+ FUNCTION_CALL = json.load(open("tools/schema_gemini.json", "r"))
+
+ class GeminiAPI(LLMBase):
+     @classmethod
+     def call(cls, images: List[Image.Image], model: str, token: str) -> dict:
+         client = genai.Client(api_key=token)  # Initialize the client with the API key
+
+         config = types.GenerateContentConfig(
+             tools=[types.Tool(function_declarations=[FUNCTION_CALL])],
+             tool_config={
+                 "function_calling_config": {
+                     "mode": "ANY",
+                     "allowed_function_names": [FUNCTION_CALL["name"]]
+                 }
+             }
+         )
+
+         response = client.models.generate_content(
+             model=model,
+             contents=images + [PROMPT],
+             config=config
+         )
+         if response.candidates[0].content.parts[0].function_call:
+             function_call = response.candidates[0].content.parts[0].function_call
+             return function_call.args
+
+         return {}
menu/llm/openai.py ADDED
@@ -0,0 +1,43 @@
+ import json
+ import base64
+ from io import BytesIO
+ from typing import List
+
+ from PIL import Image
+ from openai import OpenAI
+
+ from .base import LLMBase, PROMPT
+
+ FUNCTION_CALL = json.load(open("tools/schema_openai.json", "r"))
+
+ class OpenAIAPI(LLMBase):
+     @classmethod
+     def call(cls, images: List[Image.Image], model: str, token: str) -> dict:
+         client = OpenAI(api_key=token)  # Initialize the client with the API key
+
+         content = []
+         for image in images:
+             buffer = BytesIO()
+             image.convert("RGB").save(buffer, format="JPEG")  # JPEG has no alpha channel, so convert first
+             encode_img = base64.b64encode(buffer.getvalue()).decode("utf-8")
+             content.append({
+                 "type": "input_image",
+                 "image_url": f"data:image/jpeg;base64,{encode_img}",  # the Responses API expects a plain string here
+             })
+
+         content.append({"type": "input_text", "text": PROMPT})  # Responses API naming for text parts
+
+         response = client.responses.create(
+             model=model,
+             input=[
+                 {
+                     "role": "user",
+                     "content": content,
+                 }
+             ],
+             tools=[FUNCTION_CALL],
+         )
+         if response and response.output:
+             if hasattr(response.output[0], "arguments"):
+                 return json.loads(response.output[0].arguments)
+         return {}
menu/utils.py ADDED
@@ -0,0 +1,48 @@
+ from typing import Optional
+
+ from datasets import Dataset, DatasetDict
+
+ def split_dataset(
+     dataset: Dataset,
+     train: float,
+     validation: float,
+     test: float,
+     seed: Optional[int] = None
+ ) -> DatasetDict:
+     """
+     Split a single-split Hugging Face Dataset into train/validation/test subsets.
+
+     Args:
+         dataset (Dataset): The input dataset (e.g. load_dataset(...)['train']).
+         train (float): Proportion of data for the train split (0 < train < 1).
+         validation (float): Proportion of data for the validation split (0 < validation < 1).
+         test (float): Proportion of data for the test split (0 < test < 1).
+             Must satisfy train + validation + test == 1.0.
+         seed (int): Random seed for reproducibility (default: None).
+
+     Returns:
+         DatasetDict: A dictionary with keys "train", "validation", and "test".
+     """
+     # Verify ratios sum to 1.0
+     total = train + validation + test
+     if abs(total - 1.0) > 1e-8:
+         raise ValueError(f"train + validation + test must equal 1.0 (got {total})")
+
+     # First split: extract train vs. temp (validation + test)
+     temp_size = validation + test
+     split_1 = dataset.train_test_split(test_size=temp_size, seed=seed)
+     train_ds = split_1["train"]
+     temp_ds = split_1["test"]
+
+     # Second split: divide temp into validation vs. test
+     relative_test_size = test / temp_size
+     split_2 = temp_ds.train_test_split(test_size=relative_test_size, seed=seed)
+     validation_ds = split_2["train"]
+     test_ds = split_2["test"]
+
+     # Return a DatasetDict with all three splits
+     return DatasetDict({
+         "train": train_ds,
+         "validation": validation_ds,
+         "test": test_ds,
+     })
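A quick usage sketch (the dataset repo id is the one used in `train.ipynb`; the ratios must sum to 1.0 or a `ValueError` is raised):

```python
from datasets import load_dataset

from menu.utils import split_dataset

ds = load_dataset("ryanlinjui/menu-zh-TW")["train"]
splits = split_dataset(ds, train=0.8, validation=0.1, test=0.1, seed=42)
print({name: len(split) for name, split in splits.items()})  # sizes depend on the dataset
```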
pyproject.toml ADDED
@@ -0,0 +1,27 @@
+ [project]
+ authors = [{name = "ryanlinjui", email = "ryanlinjui@gmail.com"}]
+ name = "menu-text-detection"
+ version = "0.1.0"
+ description = "Extract structured menu information from images into JSON using a fine-tuned Donut E2E model."
+ readme = "README.md"
+ requires-python = "==3.11.*"
+ dependencies = [
+     "datasets>=3.6.0",
+     "dotenv>=0.9.9",
+     "google-genai>=1.14.0",
+     "gradio>=5.29.0",
+     "huggingface-hub>=0.31.1",
+     "matplotlib>=3.10.1",
+     "nltk>=3.9.1",
+     "notebook>=7.4.2",
+     "openai>=1.77.0",
+     "pillow>=11.2.1",
+     "pillow-heif>=0.22.0",
+     "protobuf>=6.30.2",
+     "pytorch-lightning>=2.5.2",
+     "sentencepiece>=0.2.0",
+     "tensorboardx>=2.6.2.2",
+     "transformers==4.49",
+     "torch==2.4.1",
+     "donut-python>=1.0.9",
+ ]
requirements.txt ADDED
@@ -0,0 +1,184 @@
+ aiofiles==24.1.0
+ aiohappyeyeballs==2.6.1
+ aiohttp==3.13.3
+ aiosignal==1.4.0
+ annotated-doc==0.0.4
+ annotated-types==0.7.0
+ anyio==4.12.1
+ appnope==0.1.4
+ argon2-cffi==25.1.0
+ argon2-cffi-bindings==25.1.0
+ arrow==1.4.0
+ asttokens==3.0.1
+ async-lru==2.0.5
+ attrs==25.4.0
+ babel==2.17.0
+ beautifulsoup4==4.14.3
+ bleach==6.3.0
+ brotli==1.2.0
+ certifi==2026.1.4
+ cffi==2.0.0
+ charset-normalizer==3.4.4
+ click==8.3.1
+ comm==0.2.3
+ contourpy==1.3.3
+ cycler==0.12.1
+ datasets==4.5.0
+ debugpy==1.8.19
+ decorator==5.2.1
+ defusedxml==0.7.1
+ dill==0.4.0
+ distro==1.9.0
+ donut-python==1.0.9
+ dotenv==0.9.9
+ executing==2.2.1
+ fastapi==0.128.0
+ fastjsonschema==2.21.2
+ ffmpy==1.0.0
+ filelock==3.20.3
+ fonttools==4.61.1
+ fqdn==1.5.1
+ frozenlist==1.8.0
+ fsspec==2025.10.0
+ google-auth==2.47.0
+ google-genai==1.58.0
+ gradio==6.3.0
+ gradio-client==2.0.3
+ groovy==0.1.2
+ h11==0.16.0
+ hf-xet==1.2.0
+ httpcore==1.0.9
+ httpx==0.28.1
+ huggingface-hub==0.36.0
+ idna==3.11
+ ipykernel==7.1.0
+ ipython==9.9.0
+ ipython-pygments-lexers==1.1.1
+ isoduration==20.11.0
+ jedi==0.19.2
+ jinja2==3.1.6
+ jiter==0.12.0
+ joblib==1.5.3
+ json5==0.13.0
+ jsonpointer==3.0.0
+ jsonschema==4.26.0
+ jsonschema-specifications==2025.9.1
+ jupyter-client==8.8.0
+ jupyter-core==5.9.1
+ jupyter-events==0.12.0
+ jupyter-lsp==2.3.0
+ jupyter-server==2.17.0
+ jupyter-server-terminals==0.5.4
+ jupyterlab==4.5.2
+ jupyterlab-pygments==0.3.0
+ jupyterlab-server==2.28.0
+ kiwisolver==1.4.9
+ lark==1.3.1
+ lightning-utilities==0.15.2
+ markdown-it-py==4.0.0
+ markupsafe==3.0.3
+ matplotlib==3.10.8
+ matplotlib-inline==0.2.1
+ mdurl==0.1.2
+ mistune==3.2.0
+ mpmath==1.3.0
+ multidict==6.7.0
+ multiprocess==0.70.18
+ munch==4.0.0
+ nbclient==0.10.4
+ nbconvert==7.16.6
+ nbformat==5.10.4
+ nest-asyncio==1.6.0
+ networkx==3.6.1
+ nltk==3.9.2
+ notebook==7.5.2
+ notebook-shim==0.2.4
+ numpy==2.4.1
+ openai==2.15.0
+ orjson==3.11.5
+ overrides==7.7.0
+ packaging==25.0
+ pandas==2.3.3
+ pandocfilters==1.5.1
+ parso==0.8.5
+ pexpect==4.9.0
+ pillow==12.1.0
+ pillow-heif==1.1.1
+ platformdirs==4.5.1
+ prometheus-client==0.24.1
+ prompt-toolkit==3.0.52
+ propcache==0.4.1
+ protobuf==6.33.4
+ psutil==7.2.1
+ ptyprocess==0.7.0
+ pure-eval==0.2.3
+ pyarrow==22.0.0
+ pyasn1==0.6.1
+ pyasn1-modules==0.4.2
+ pycparser==2.23
+ pydantic==2.12.5
+ pydantic-core==2.41.5
+ pydub==0.25.1
+ pygments==2.19.2
+ pyparsing==3.3.1
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.2.1
+ python-json-logger==4.0.0
+ python-multipart==0.0.21
+ pytorch-lightning==2.6.0
+ pytz==2025.2
+ pyyaml==6.0.3
+ pyzmq==27.1.0
+ referencing==0.37.0
+ regex==2026.1.15
+ requests==2.32.5
+ rfc3339-validator==0.1.4
+ rfc3986-validator==0.1.1
+ rfc3987-syntax==1.1.0
+ rich==14.2.0
+ rpds-py==0.30.0
+ rsa==4.9.1
+ ruamel-yaml==0.19.1
+ safehttpx==0.1.7
+ safetensors==0.7.0
+ sconf==0.2.5
+ semantic-version==2.10.0
+ send2trash==2.1.0
+ sentencepiece==0.2.1
+ setuptools==80.9.0
+ shellingham==1.5.4
+ six==1.17.0
+ sniffio==1.3.1
+ soupsieve==2.8.1
+ stack-data==0.6.3
+ starlette==0.50.0
+ sympy==1.14.0
+ tenacity==9.1.2
+ tensorboardx==2.6.4
+ terminado==0.18.1
+ timm==1.0.24
+ tinycss2==1.4.0
+ tokenizers==0.21.4
+ tomlkit==0.13.3
+ torch==2.4.1
+ torchmetrics==1.8.2
+ torchvision==0.19.1
+ tornado==6.5.4
+ tqdm==4.67.1
+ traitlets==5.14.3
+ transformers==4.49.0
+ typer==0.21.1
+ typing-extensions==4.15.0
+ typing-inspection==0.4.2
+ tzdata==2025.3
+ uri-template==1.3.0
+ urllib3==2.6.3
+ uvicorn==0.40.0
+ wcwidth==0.2.14
+ webcolors==25.10.0
+ webencodings==0.5.1
+ websocket-client==1.9.0
+ websockets==15.0.1
+ xxhash==3.6.0
+ yarl==1.22.0
+ zss==1.2.0
tools/schema_gemini.json ADDED
@@ -0,0 +1,44 @@
+ {
+     "name": "extract_menu_data",
+     "description": "Extract structured menu information from images.",
+     "parameters": {
+         "type": "object",
+         "properties": {
+             "restaurant": {
+                 "type": "string",
+                 "description": "Name of the restaurant. If the name is not available, it should be ''."
+             },
+             "address": {
+                 "type": "string",
+                 "description": "Address of the restaurant. If the address is not available, it should be ''."
+             },
+             "phone": {
+                 "type": "string",
+                 "description": "Phone number of the restaurant. If the phone number is not available, it should be ''."
+             },
+             "business_hours": {
+                 "type": "string",
+                 "description": "Business hours of the restaurant. If the business hours are not available, it should be ''."
+             },
+             "dishes": {
+                 "type": "array",
+                 "items": {
+                     "type": "object",
+                     "properties": {
+                         "name": {
+                             "type": "string",
+                             "description": "Name of the menu item."
+                         },
+                         "price": {
+                             "type": "string",
+                             "description": "Price of the menu item. If the price is not available, it should be -1."
+                         }
+                     },
+                     "required": ["name", "price"]
+                 },
+                 "description": "List of menu dish items."
+             }
+         },
+         "required": ["restaurant", "address", "phone", "business_hours", "dishes"]
+     }
+ }
tools/schema_openai.json ADDED
@@ -0,0 +1,47 @@
+ {
+     "type": "function",
+     "name": "extract_menu_data",
+     "description": "Extract structured menu information from images.",
+     "parameters": {
+         "type": "object",
+         "properties": {
+             "restaurant": {
+                 "type": "string",
+                 "description": "Name of the restaurant. If the name is not available, it should be ''."
+             },
+             "address": {
+                 "type": "string",
+                 "description": "Address of the restaurant. If the address is not available, it should be ''."
+             },
+             "phone": {
+                 "type": "string",
+                 "description": "Phone number of the restaurant. If the phone number is not available, it should be ''."
+             },
+             "business_hours": {
+                 "type": "string",
+                 "description": "Business hours of the restaurant. If the business hours are not available, it should be ''."
+             },
+             "dishes": {
+                 "type": "array",
+                 "items": {
+                     "type": "object",
+                     "properties": {
+                         "name": {
+                             "type": "string",
+                             "description": "Name of the menu item."
+                         },
+                         "price": {
+                             "type": "string",
+                             "description": "Price of the menu item. If the price is not available, it should be -1."
+                         }
+                     },
+                     "required": ["name", "price"],
+                     "additionalProperties": false
+                 },
+                 "description": "List of menu dish items."
+             }
+         },
+         "required": ["restaurant", "address", "phone", "business_hours", "dishes"],
+         "additionalProperties": false
+     }
+ }
train.ipynb ADDED
@@ -0,0 +1,235 @@
+ {
+  "cells": [
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "# Login to HuggingFace (just login once)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from huggingface_hub import interpreter_login\n",
+     "interpreter_login()"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "# Collect Menu Image Datasets\n",
+     "- Use `metadata.jsonl` to label the images' ground truth. You can visit [here](https://github.com/ryanlinjui/menu-text-detection/tree/main/examples) to see the examples.\n",
+     "- After finishing, push to HuggingFace Datasets.\n",
+     "- For labeling:\n",
+     "  - [Google AI Studio](https://aistudio.google.com) or [OpenAI ChatGPT](https://chatgpt.com).\n",
+     "  - Use function calling by API. Start the gradio app locally or visit [here](https://huggingface.co/spaces/ryanlinjui/menu-text-detection).\n",
+     "\n",
+     "### Menu Type\n",
+     "- **h**: horizontal menu\n",
+     "- **v**: vertical menu\n",
+     "- **d**: document-style menu\n",
+     "- **s**: in-scene menu (non-document style)\n",
+     "- **i**: irregular menu (menu with irregular text layout)\n",
+     "\n",
+     "> Please see the [examples](https://github.com/ryanlinjui/menu-text-detection/tree/main/examples) for more details."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import os\n",
+     "import json\n",
+     "\n",
+     "from PIL import Image\n",
+     "from pillow_heif import register_heif_opener\n",
+     "\n",
+     "from menu.llm import (\n",
+     "    GeminiAPI,\n",
+     "    OpenAIAPI\n",
+     ")\n",
+     "\n",
+     "IMAGE_DIR = \"datasets/images\"  # set your image directory here\n",
+     "SELECTED_MODEL = \"gemini-2.5-flash\"  # set model name here; refer to LLM_MODEL_LIST in app.py for more\n",
+     "API_TOKEN = \"\"  # set your API token here\n",
+     "SELECTED_FUNCTION = GeminiAPI  # set \"GeminiAPI\" or \"OpenAIAPI\"\n",
+     "\n",
+     "register_heif_opener()\n",
+     "\n",
+     "for file in os.listdir(IMAGE_DIR):\n",
+     "    print(f\"Processing image: {file}\")\n",
+     "    try:\n",
+     "        image = Image.open(os.path.join(IMAGE_DIR, file))  # keep as PIL Image; call() expects a list of PIL Images\n",
+     "        data = {\n",
+     "            \"file_name\": file,\n",
+     "            \"menu\": SELECTED_FUNCTION.call([image], SELECTED_MODEL, API_TOKEN)\n",
+     "        }\n",
+     "        with open(os.path.join(IMAGE_DIR, \"metadata.jsonl\"), \"a\", encoding=\"utf-8\") as metaf:\n",
+     "            metaf.write(json.dumps(data, ensure_ascii=False, sort_keys=True) + \"\\n\")\n",
+     "    except Exception as e:\n",
+     "        print(f\"Skipping invalid image '{file}': {e}\")\n",
+     "        continue"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "# Push Datasets to HuggingFace"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from datasets import load_dataset\n",
+     "\n",
+     "dataset = load_dataset(path=\"datasets/menu-zh-TW\")  # load dataset from the local directory including the metadata.jsonl and image files\n",
+     "dataset.push_to_hub(repo_id=\"ryanlinjui/menu-zh-TW\")  # push to the huggingface dataset hub"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "# Prepare the dataset for training"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from menu.utils import split_dataset\n",
+     "from datasets import load_dataset\n",
+     "\n",
+     "dataset = load_dataset(path=\"ryanlinjui/menu-zh-TW\")  # set your dataset repo id for training\n",
+     "dataset = split_dataset(dataset[\"train\"], train=0.8, validation=0.1, test=0.1, seed=42)  # (optional) use it if your dataset is not split into train/validation/test\n",
+     "print(f\"Dataset split: {len(dataset['train'])} train, {len(dataset['validation'])} validation, {len(dataset['test'])} test\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "# Fine-tune Donut Model"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import logging\n",
+     "from menu.donut import DonutTrainer\n",
+     "\n",
+     "logging.getLogger(\"transformers\").setLevel(logging.ERROR)  # filter output messages from transformers\n",
+     "\n",
+     "DonutTrainer.train(\n",
+     "    dataset=dataset,\n",
+     "    pretrained_model_repo_id=\"naver-clova-ix/donut-base\",  # set your pretrained model repo id for fine-tuning\n",
+     "    ground_truth_key=\"menu\",  # set your ground truth key for training\n",
+     "    huggingface_model_id=\"ryanlinjui/donut-base-finetuned-menu\",  # set your huggingface model repo id for saving / pushing to the hub\n",
+     "    epochs=15,  # set your training epochs\n",
+     "    train_batch_size=8,  # set your training batch size\n",
+     "    val_batch_size=1,  # set your validation batch size\n",
+     "    learning_rate=3e-5,  # set your learning rate\n",
+     "    val_check_interval=0.5,  # how many times we want to validate during an epoch\n",
+     "    check_val_every_n_epoch=1,  # validate every n epochs\n",
+     "    gradient_clip_val=1.0,  # gradient clipping value for training stability\n",
+     "    num_training_samples_per_epoch=198,  # set num_training_samples_per_epoch = training set size\n",
+     "    num_nodes=1,  # number of nodes for distributed training\n",
+     "    warmup_steps=75  # number of warmup steps for the learning rate scheduler (roughly 10% of total steps)\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "# Evaluate Donut Model"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import json\n",
+     "from datasets import load_dataset\n",
+     "\n",
+     "from menu.utils import split_dataset\n",
+     "from menu.donut import DonutFinetuned\n",
+     "\n",
+     "dataset = load_dataset(\"ryanlinjui/menu-zh-TW\")\n",
+     "dataset = split_dataset(dataset[\"train\"], train=0.8, validation=0.1, test=0.1, seed=42)  # (optional) use it if your dataset is not split into train/validation/test\n",
+     "donut_finetuned = DonutFinetuned(pretrained_model_repo_id=\"ryanlinjui/donut-base-finetuned-menu\")\n",
+     "scores, output_list = donut_finetuned.evaluate(dataset=dataset[\"test\"], ground_truth_key=\"menu\")\n",
+     "\n",
+     "print(\"Evaluation scores:\")\n",
+     "for key, value in scores.items():\n",
+     "    print(f\"{key}: {value}\")\n",
+     "\n",
+     "print(\"\\nSample outputs:\")\n",
+     "for output in output_list[:5]:\n",
+     "    print(json.dumps(output, ensure_ascii=False, indent=4))"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "# Test Donut Model"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from PIL import Image\n",
+     "from menu.donut import DonutFinetuned\n",
+     "\n",
+     "image = Image.open(\"./examples/menu-hd.jpg\")\n",
+     "\n",
+     "donut_finetuned = DonutFinetuned(pretrained_model_repo_id=\"ryanlinjui/donut-base-finetuned-menu\")\n",
+     "outputs = donut_finetuned.predict(image=image)\n",
+     "print(outputs)"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "menu-text-detection",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.11.12"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
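For reference, each line the collection cell appends to `metadata.jsonl` pairs a file name with the extracted menu. A single-line example with invented values (keys are sorted alphabetically because the cell uses `sort_keys=True`):

```python
# One line of datasets/images/metadata.jsonl (stored as JSON; shown here as a Python dict):
{"file_name": "menu-hd.jpg", "menu": {"address": "", "business_hours": "", "dishes": [{"name": "Latte", "price": "120"}], "phone": "", "restaurant": "Ryan's Diner"}}
```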
uv.lock ADDED
The diff for this file is too large to render. See raw diff