github-actions[bot] committed · Commit 6bd37dd · 0 parent(s)
Sync from https://github.com/ryanlinjui/menu-text-detection
- .checkpoints/.gitkeep +0 -0
- .env.example +3 -0
- .github/workflows/sync.yml +25 -0
- .gitignore +24 -0
- .python-version +1 -0
- LICENSE +21 -0
- README.md +65 -0
- app.py +157 -0
- menu/donut.py +472 -0
- menu/llm/__init__.py +2 -0
- menu/llm/base.py +9 -0
- menu/llm/gemini.py +36 -0
- menu/llm/openai.py +39 -0
- menu/utils.py +48 -0
- pyproject.toml +27 -0
- requirements.txt +183 -0
- tools/schema_gemini.json +44 -0
- tools/schema_openai.json +47 -0
- train.ipynb +235 -0
- uv.lock +0 -0
.checkpoints/.gitkeep
ADDED
|
File without changes
|
.env.example
ADDED
|
@@ -0,0 +1,3 @@
| 1 |
+
HUGGINGFACE_TOKEN="HUGGINGFACE_TOKEN"
|
| 2 |
+
GEMINI_API_TOKEN="GEMINI_API_TOKEN"
|
| 3 |
+
OPENAI_API_TOKEN="OPENAI_API_TOKEN"
|
.github/workflows/sync.yml
ADDED
|
@@ -0,0 +1,25 @@
| 1 |
+
name: Sync to Hugging Face Spaces
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- main
|
| 7 |
+
jobs:
|
| 8 |
+
sync:
|
| 9 |
+
name: Sync
|
| 10 |
+
runs-on: ubuntu-latest
|
| 11 |
+
steps:
|
| 12 |
+
- name: Checkout Repository
|
| 13 |
+
uses: actions/checkout@v4
|
| 14 |
+
|
| 15 |
+
- name: Remove bad files
|
| 16 |
+
run: rm -rf examples assets
|
| 17 |
+
|
| 18 |
+
- name: Sync to Hugging Face Spaces
|
| 19 |
+
uses: JacobLinCool/huggingface-sync@v1
|
| 20 |
+
with:
|
| 21 |
+
github: ${{ secrets.GITHUB_TOKEN }}
|
| 22 |
+
user: ryanlinjui # Hugging Face username or organization name
|
| 23 |
+
space: menu-text-detection # Hugging Face space name
|
| 24 |
+
token: ${{ secrets.HF_TOKEN }} # Hugging Face token
|
| 25 |
+
python_version: 3.11 # Python version
|
.gitignore
ADDED
|
@@ -0,0 +1,24 @@
| 1 |
+
# mac
|
| 2 |
+
.DS_Store
|
| 3 |
+
|
| 4 |
+
# cache
|
| 5 |
+
__pycache__
|
| 6 |
+
|
| 7 |
+
# datasets
|
| 8 |
+
datasets
|
| 9 |
+
|
| 10 |
+
# papers
|
| 11 |
+
docs/papers
|
| 12 |
+
|
| 13 |
+
# uv
|
| 14 |
+
.venv
|
| 15 |
+
|
| 16 |
+
# gradio
|
| 17 |
+
.gradio
|
| 18 |
+
|
| 19 |
+
# env
|
| 20 |
+
.env
|
| 21 |
+
|
| 22 |
+
# checkpoint
|
| 23 |
+
.checkpoints/*
|
| 24 |
+
!.checkpoints/.gitkeep
|
.python-version
ADDED
|
@@ -0,0 +1 @@
| 1 |
+
3.11
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 RyanLin
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,65 @@
| 1 |
+
---
|
| 2 |
+
title: menu text detection
|
| 3 |
+
emoji: 🦄
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: pink
|
| 6 |
+
sdk: gradio
|
| 7 |
+
python_version: 3.11
|
| 8 |
+
short_description: Extract structured menu information from images into JSON...
|
| 9 |
+
tags: [ "donut","fine-tuning","image-to-text","transformer" ]
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# Menu Text Detection System
|
| 13 |
+
|
| 14 |
+
Extract structured menu information from images into JSON using a fine-tuned E2E model or LLM.
|
| 15 |
+
|
| 16 |
+
[](https://huggingface.co/spaces/ryanlinjui/menu-text-detection)
|
| 17 |
+
[](https://huggingface.co/collections/ryanlinjui/menu-text-detection-670ccf527626bb004bbfb39b)
|
| 18 |
+
|
| 19 |
+
https://github.com/user-attachments/assets/80e5d54c-f2c8-4593-ad9b-499e5b71d8f6
|
| 20 |
+
|
| 21 |
+
## 🚀 Features
|
| 22 |
+
### Overview
|
| 23 |
+
Currently supports extracting the following information from menu images:
|
| 24 |
+
|
| 25 |
+
- **Restaurant Name**
|
| 26 |
+
- **Business Hours**
|
| 27 |
+
- **Address**
|
| 28 |
+
- **Phone Number**
|
| 29 |
+
- **Dish Information**
|
| 30 |
+
- Name
|
| 31 |
+
- Price
|
| 32 |
+
|
| 33 |
+
> For the JSON schema, see [tools directory](./tools).
|
| 34 |
+
|
| 35 |
+
### Supported Methods to Extract Menu Information
|
| 36 |
+
#### Fine-tuned E2E model and Training metrics
|
| 37 |
+
- [**Donut (Document Parsing Task)**](https://huggingface.co/ryanlinjui/donut-base-finetuned-menu) - Base model by [Clova AI (ECCV ’22)](https://github.com/clovaai/donut)
|
| 38 |
+
|
| 39 |
+
#### LLM Function Calling
|
| 40 |
+
- Google Gemini API
|
| 41 |
+
- OpenAI GPT API
|
| 42 |
+
|
| 43 |
+
## 💻 Training / Fine-Tuning
|
| 44 |
+
### Setup
|
| 45 |
+
Use [uv](https://github.com/astral-sh/uv) to set up the development environment:
|
| 46 |
+
|
| 47 |
+
```bash
|
| 48 |
+
uv sync
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
> Alternatively, run `pip install -r requirements.txt` if `uv sync` fails.
|
| 52 |
+
|
| 53 |
+
### Training Script (Dataset Collection, Fine-Tuning)
|
| 54 |
+
Please refer to [`train.ipynb`](./train.ipynb). Use Jupyter Notebook for training:
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
uv run jupyter-notebook
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
> For VSCode users, please install the Jupyter extension, then select `.venv/bin/python` as your kernel.
|
| 61 |
+
|
| 62 |
+
### Run Demo Locally
|
| 63 |
+
```bash
|
| 64 |
+
uv run python app.py
|
| 65 |
+
```
|
app.py
ADDED
|
@@ -0,0 +1,157 @@
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
from pillow_heif import register_heif_opener
|
| 8 |
+
|
| 9 |
+
from menu.llm import (
|
| 10 |
+
GeminiAPI,
|
| 11 |
+
OpenAIAPI
|
| 12 |
+
)
|
| 13 |
+
from menu.donut import DonutFinetuned
|
| 14 |
+
|
| 15 |
+
register_heif_opener()
|
| 16 |
+
load_dotenv(override=True)
|
| 17 |
+
GEMINI_API_TOKEN = os.getenv("GEMINI_API_TOKEN", "")
|
| 18 |
+
OPENAI_API_TOKEN = os.getenv("OPENAI_API_TOKEN", "")
|
| 19 |
+
|
| 20 |
+
SOURCE_CODE_GH_URL = "https://github.com/ryanlinjui/menu-text-detection"
|
| 21 |
+
BADGE_URL = "https://img.shields.io/badge/GitHub_Code-Click_Here!!-default?logo=github"
|
| 22 |
+
|
| 23 |
+
GITHUB_RAW_URL = "https://raw.githubusercontent.com/ryanlinjui/menu-text-detection/main"
|
| 24 |
+
EXAMPLE_IMAGE_LIST = [
|
| 25 |
+
f"{GITHUB_RAW_URL}/examples/menu-hd.jpg",
|
| 26 |
+
f"{GITHUB_RAW_URL}/examples/menu-vs.jpg",
|
| 27 |
+
f"{GITHUB_RAW_URL}/examples/menu-si.jpg"
|
| 28 |
+
]
|
| 29 |
+
FINETUNED_MODEL_LIST = [
|
| 30 |
+
"Donut (Document Parsing Task) Fine-tuned Model"
|
| 31 |
+
]
|
| 32 |
+
LLM_MODEL_LIST = [
|
| 33 |
+
"gemini-2.5-pro",
|
| 34 |
+
"gemini-2.5-flash",
|
| 35 |
+
"gemini-2.0-flash",
|
| 36 |
+
"gpt-4.1",
|
| 37 |
+
"gpt-4o",
|
| 38 |
+
"o4-mini"
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
donut_finetuned = DonutFinetuned("ryanlinjui/donut-base-finetuned-menu")
|
| 42 |
+
|
| 43 |
+
def handle(image: Image.Image, model: str, api_token: str) -> str:
|
| 44 |
+
if image is None:
|
| 45 |
+
raise gr.Error("Please upload an image first.")
|
| 46 |
+
|
| 47 |
+
if model == FINETUNED_MODEL_LIST[0]:
|
| 48 |
+
result = donut_finetuned.predict(image)
|
| 49 |
+
|
| 50 |
+
elif model in LLM_MODEL_LIST:
|
| 51 |
+
if len(api_token) < 10:
|
| 52 |
+
raise gr.Error(f"Please provide a valid token for {model}.")
|
| 53 |
+
try:
|
| 54 |
+
if model in LLM_MODEL_LIST[:3]:
|
| 55 |
+
result = GeminiAPI.call(image, model, api_token)
|
| 56 |
+
else:
|
| 57 |
+
result = OpenAIAPI.call(image, model, api_token)
|
| 58 |
+
except Exception as e:
|
| 59 |
+
raise gr.Error(f"Failed to process with API model {model}: {str(e)}")
|
| 60 |
+
else:
|
| 61 |
+
raise gr.Error("Invalid model selection. Please choose a valid model.")
|
| 62 |
+
|
| 63 |
+
return json.dumps(result, indent=4, ensure_ascii=False, sort_keys=True)
|
| 64 |
+
|
| 65 |
+
def UserInterface() -> gr.Blocks:
|
| 66 |
+
with gr.Blocks(
|
| 67 |
+
delete_cache=(86400, 86400),
|
| 68 |
+
css="""
|
| 69 |
+
.image-panel {
|
| 70 |
+
display: flex;
|
| 71 |
+
flex-direction: column;
|
| 72 |
+
height: 600px;
|
| 73 |
+
}
|
| 74 |
+
.image-panel img {
|
| 75 |
+
object-fit: contain;
|
| 76 |
+
max-height: 600px;
|
| 77 |
+
max-width: 600px;
|
| 78 |
+
width: 100%;
|
| 79 |
+
}
|
| 80 |
+
.large-text textarea {
|
| 81 |
+
font-size: 20px !important;
|
| 82 |
+
height: 600px !important;
|
| 83 |
+
width: 100% !important;
|
| 84 |
+
}
|
| 85 |
+
"""
|
| 86 |
+
) as gradio_interface:
|
| 87 |
+
gr.HTML(f'<a href="{SOURCE_CODE_GH_URL}"><img src="{BADGE_URL}" alt="GitHub Code"/></a>')
|
| 88 |
+
gr.Markdown("# Menu Text Detection")
|
| 89 |
+
|
| 90 |
+
with gr.Row():
|
| 91 |
+
with gr.Column(scale=1, min_width=500):
|
| 92 |
+
gr.Markdown("## 📷 Menu Image")
|
| 93 |
+
menu_image = gr.Image(
|
| 94 |
+
type="pil",
|
| 95 |
+
label="Input menu image",
|
| 96 |
+
elem_classes="image-panel"
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
gr.Markdown("## 🤖 Model Selection")
|
| 100 |
+
model_choice_dropdown = gr.Dropdown(
|
| 101 |
+
choices=FINETUNED_MODEL_LIST + LLM_MODEL_LIST,
|
| 102 |
+
value=FINETUNED_MODEL_LIST[0],
|
| 103 |
+
label="Select Text Detection Model"
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
api_token_textbox = gr.Textbox(
|
| 107 |
+
label="API Token",
|
| 108 |
+
placeholder="Enter your API token here...",
|
| 109 |
+
type="password",
|
| 110 |
+
visible=False
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
generate_button = gr.Button("Generate Menu Information", variant="primary")
|
| 114 |
+
|
| 115 |
+
gr.Examples(
|
| 116 |
+
examples=EXAMPLE_IMAGE_LIST,
|
| 117 |
+
inputs=menu_image,
|
| 118 |
+
label="Example Menu Images"
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
with gr.Column(scale=1):
|
| 122 |
+
gr.Markdown("## 🍽️ Menu Info")
|
| 123 |
+
menu_json_textbox = gr.Textbox(
|
| 124 |
+
label="Output JSON",
|
| 125 |
+
interactive=True,
|
| 126 |
+
text_align="left",
|
| 127 |
+
elem_classes="large-text"
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
def update_token_visibility(choice):
|
| 131 |
+
if choice in LLM_MODEL_LIST:
|
| 132 |
+
current_token = ""
|
| 133 |
+
if choice in LLM_MODEL_LIST[:3]:
|
| 134 |
+
current_token = GEMINI_API_TOKEN
|
| 135 |
+
else:
|
| 136 |
+
current_token = OPENAI_API_TOKEN
|
| 137 |
+
return gr.update(visible=True, value=current_token)
|
| 138 |
+
else:
|
| 139 |
+
return gr.update(visible=False)
|
| 140 |
+
|
| 141 |
+
model_choice_dropdown.change(
|
| 142 |
+
fn=update_token_visibility,
|
| 143 |
+
inputs=model_choice_dropdown,
|
| 144 |
+
outputs=api_token_textbox
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
generate_button.click(
|
| 148 |
+
fn=handle,
|
| 149 |
+
inputs=[menu_image, model_choice_dropdown, api_token_textbox],
|
| 150 |
+
outputs=menu_json_textbox
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
return gradio_interface
|
| 154 |
+
|
| 155 |
+
if __name__ == "__main__":
|
| 156 |
+
demo = UserInterface()
|
| 157 |
+
demo.launch()
|
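The `handle` function in app.py routes each request by model name: the fine-tuned Donut entry runs locally, the first three entries of `LLM_MODEL_LIST` go to Gemini, and the remaining entries go to OpenAI. A minimal sketch of that routing in isolation is shown here. It assumes a hypothetical `route_model` helper (not part of app.py) that returns a backend label instead of calling any model or API, and it raises `ValueError` where app.py raises `gr.Error`.

```python
from typing import List

FINETUNED_MODEL_LIST: List[str] = ["Donut (Document Parsing Task) Fine-tuned Model"]
LLM_MODEL_LIST: List[str] = [
    "gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.0-flash",
    "gpt-4.1", "gpt-4o", "o4-mini",
]


def route_model(model: str, api_token: str = "") -> str:
    """Hypothetical helper: return which backend would serve `model`."""
    if model == FINETUNED_MODEL_LIST[0]:
        return "donut"
    if model in LLM_MODEL_LIST:
        # Same minimal sanity check as app.py: very short tokens are rejected.
        if len(api_token) < 10:
            raise ValueError(f"Please provide a valid token for {model}.")
        # The first three list entries are Gemini models; the rest are OpenAI models.
        return "gemini" if model in LLM_MODEL_LIST[:3] else "openai"
    raise ValueError("Invalid model selection. Please choose a valid model.")
```

Note that the Gemini/OpenAI split depends on the order of `LLM_MODEL_LIST`, so inserting a new model in the wrong position would silently change its backend. Keeping the two providers in separate lists may be a more robust choice.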
menu/donut.py
ADDED
|
@@ -0,0 +1,472 @@
| 1 |
+
"""
|
| 2 |
+
This file is modified from the HuggingFace transformers tutorial script for fine-tuning Donut on a custom dataset.
|
| 3 |
+
It has been converted from the original `.ipynb` notebook into a module for better reusability and maintainability.
|
| 4 |
+
Reference: https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Donut/CORD/Fine_tune_Donut_on_a_custom_dataset_(CORD)_with_PyTorch_Lightning.ipynb
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import re
|
| 8 |
+
import random
|
| 9 |
+
from typing import Any, List, Tuple, Dict
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import numpy as np
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from tqdm.auto import tqdm
|
| 15 |
+
from nltk import edit_distance
|
| 16 |
+
import pytorch_lightning as pl
|
| 17 |
+
from datasets import DatasetDict
|
| 18 |
+
from donut import JSONParseEvaluator
|
| 19 |
+
from huggingface_hub import upload_folder
|
| 20 |
+
from pillow_heif import register_heif_opener
|
| 21 |
+
from pytorch_lightning.callbacks import Callback
|
| 22 |
+
from pytorch_lightning.loggers import TensorBoardLogger
|
| 23 |
+
from torch.utils.data import (
|
| 24 |
+
Dataset,
|
| 25 |
+
DataLoader
|
| 26 |
+
)
|
| 27 |
+
from transformers import (
|
| 28 |
+
DonutProcessor,
|
| 29 |
+
VisionEncoderDecoderModel,
|
| 30 |
+
VisionEncoderDecoderConfig
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
TASK_PROMPT_NAME = "<s_menu-text-detection>"
|
| 34 |
+
register_heif_opener()
|
| 35 |
+
|
| 36 |
+
class DonutFinetuned:
|
| 37 |
+
def __init__(self, pretrained_model_repo_id: str = "ryanlinjui/donut-test"):
|
| 38 |
+
self.device = (
|
| 39 |
+
"cuda"
|
| 40 |
+
if torch.cuda.is_available()
|
| 41 |
+
else "mps" if torch.backends.mps.is_available() else "cpu"
|
| 42 |
+
)
|
| 43 |
+
self.processor = DonutProcessor.from_pretrained(pretrained_model_repo_id)
|
| 44 |
+
self.model = VisionEncoderDecoderModel.from_pretrained(pretrained_model_repo_id)
|
| 45 |
+
self.model.eval()
|
| 46 |
+
self.model.to(self.device)
|
| 47 |
+
print(f"Using {self.device} device")
|
| 48 |
+
|
| 49 |
+
def predict(self, image: Image.Image) -> Dict[str, Any]:
|
| 50 |
+
# prepare encoder inputs
|
| 51 |
+
pixel_values = self.processor(image.convert("RGB"), return_tensors="pt").pixel_values
|
| 52 |
+
pixel_values = pixel_values.to(self.device)
|
| 53 |
+
|
| 54 |
+
# prepare decoder inputs
|
| 55 |
+
decoder_input_ids = self.processor.tokenizer(TASK_PROMPT_NAME, add_special_tokens=False, return_tensors="pt").input_ids
|
| 56 |
+
decoder_input_ids = decoder_input_ids.to(self.device)
|
| 57 |
+
|
| 58 |
+
# autoregressively generate sequence
|
| 59 |
+
outputs = self.model.generate(
|
| 60 |
+
pixel_values,
|
| 61 |
+
decoder_input_ids=decoder_input_ids,
|
| 62 |
+
max_length=self.model.decoder.config.max_position_embeddings,
|
| 63 |
+
early_stopping=True,
|
| 64 |
+
pad_token_id=self.processor.tokenizer.pad_token_id,
|
| 65 |
+
eos_token_id=self.processor.tokenizer.eos_token_id,
|
| 66 |
+
use_cache=True,
|
| 67 |
+
num_beams=1,
|
| 68 |
+
bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
|
| 69 |
+
return_dict_in_generate=True
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# turn into JSON
|
| 73 |
+
seq = self.processor.batch_decode(outputs.sequences)[0]
|
| 74 |
+
seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
|
| 75 |
+
seq = re.sub(r"<.*?>", "", seq, count=1).strip() # remove first task start token
|
| 76 |
+
seq = self.processor.token2json(seq)
|
| 77 |
+
return seq
|
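The post-processing at the end of `predict` can be tried without loading a model. The sketch below is illustrative only: `clean_sequence` is a hypothetical helper, the `</s>` and `<pad>` defaults are assumed stand-ins for the tokenizer's real `eos_token` and `pad_token`, and the `<s_name>` field is an invented example key. The final `token2json` step is not reproduced.

```python
import re


def clean_sequence(seq: str, eos_token: str = "</s>", pad_token: str = "<pad>") -> str:
    """Hypothetical helper mirroring the string cleanup in `predict`."""
    # Drop end-of-sequence and padding markers (assumed token strings).
    seq = seq.replace(eos_token, "").replace(pad_token, "")
    # Remove only the first tag, which is the task start token.
    return re.sub(r"<.*?>", "", seq, count=1).strip()


raw = "<s_menu-text-detection><s_name>Cafe</s_name></s><pad><pad>"
print(clean_sequence(raw))
```

Because `count=1` is passed to `re.sub`, only the leading task prompt is stripped and the field tags that follow are left for the JSON conversion step.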
| 78 |
+
|
| 79 |
+
def evaluate(self, dataset: Dataset, ground_truth_key: str = "ground_truth") -> Tuple[Dict[str, Any], List[Any]]:
|
| 80 |
+
output_list = []
|
| 81 |
+
accs = []
|
| 82 |
+
ted_accs = []
|
| 83 |
+
f1_accs = []
|
| 84 |
+
|
| 85 |
+
for idx, sample in tqdm(enumerate(dataset), total=len(dataset)):
|
| 86 |
+
seq = self.predict(sample["image"])
|
| 87 |
+
ground_truth = sample[ground_truth_key]
|
| 88 |
+
|
| 89 |
+
# Original JSON accuracy
|
| 90 |
+
evaluator = JSONParseEvaluator()
|
| 91 |
+
score = evaluator.cal_acc(seq, ground_truth)
|
| 92 |
+
accs.append(score)
|
| 93 |
+
output_list.append(seq)
|
| 94 |
+
|
| 95 |
+
# Edit-distance accuracy (string-level normalized Levenshtein distance, not a true tree edit distance)
|
| 96 |
+
# Convert predictions and ground truth to string format for comparison
|
| 97 |
+
pred_str = str(seq) if seq else ""
|
| 98 |
+
gt_str = str(ground_truth) if ground_truth else ""
|
| 99 |
+
|
| 100 |
+
# Calculate normalized edit distance (1 - normalized_edit_distance = accuracy)
|
| 101 |
+
if len(pred_str) == 0 and len(gt_str) == 0:
|
| 102 |
+
ted_acc = 1.0
|
| 103 |
+
elif len(pred_str) == 0 or len(gt_str) == 0:
|
| 104 |
+
ted_acc = 0.0
|
| 105 |
+
else:
|
| 106 |
+
edit_dist = edit_distance(pred_str, gt_str)
|
| 107 |
+
max_len = max(len(pred_str), len(gt_str))
|
| 108 |
+
ted_acc = 1 - (edit_dist / max_len)
|
| 109 |
+
ted_accs.append(ted_acc)
|
| 110 |
+
|
| 111 |
+
# F1 Score Accuracy (over the sets of unique characters, so order and repetition are ignored)
|
| 112 |
+
if len(pred_str) == 0 and len(gt_str) == 0:
|
| 113 |
+
f1_acc = 1.0
|
| 114 |
+
elif len(pred_str) == 0 or len(gt_str) == 0:
|
| 115 |
+
f1_acc = 0.0
|
| 116 |
+
else:
|
| 117 |
+
# Precision and recall over unique characters
|
| 118 |
+
pred_chars = set(pred_str)
|
| 119 |
+
gt_chars = set(gt_str)
|
| 120 |
+
|
| 121 |
+
if len(pred_chars) == 0:
|
| 122 |
+
precision = 0.0
|
| 123 |
+
else:
|
| 124 |
+
precision = len(pred_chars.intersection(gt_chars)) / len(pred_chars)
|
| 125 |
+
|
| 126 |
+
if len(gt_chars) == 0:
|
| 127 |
+
recall = 0.0
|
| 128 |
+
else:
|
| 129 |
+
recall = len(pred_chars.intersection(gt_chars)) / len(gt_chars)
|
| 130 |
+
|
| 131 |
+
if precision + recall == 0:
|
| 132 |
+
f1_acc = 0.0
|
| 133 |
+
else:
|
| 134 |
+
f1_acc = 2 * (precision * recall) / (precision + recall)
|
| 135 |
+
f1_accs.append(f1_acc)
|
| 136 |
+
|
| 137 |
+
scores = {
|
| 138 |
+
"accuracies": accs,
|
| 139 |
+
"mean_accuracy": np.mean(accs),
|
| 140 |
+
"ted_accuracies": ted_accs,
|
| 141 |
+
"mean_ted_accuracy": np.mean(ted_accs),
|
| 142 |
+
"f1_accuracies": f1_accs,
|
| 143 |
+
"mean_f1_accuracy": np.mean(f1_accs),
|
| 144 |
+
"length": len(accs)
|
| 145 |
+
}
|
| 146 |
+
return scores, output_list
|
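The two string-based scores computed in `evaluate` can be reproduced on plain strings. The following is a minimal sketch, not the module's own code: `levenshtein` is a small stand-in for `nltk.edit_distance` (assumed here to use its default unit costs), and `edit_accuracy` and `char_set_f1` are hypothetical names for the logic behind the `ted_*` and `f1_*` values.

```python
def levenshtein(a: str, b: str) -> int:
    """Classic dynamic-programming edit distance with unit costs."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution
        prev = curr
    return prev[-1]


def edit_accuracy(pred: str, gt: str) -> float:
    """1 minus the edit distance normalized by the longer string."""
    if not pred and not gt:
        return 1.0
    if not pred or not gt:
        return 0.0
    return 1 - levenshtein(pred, gt) / max(len(pred), len(gt))


def char_set_f1(pred: str, gt: str) -> float:
    """F1 over the sets of unique characters in each string."""
    if not pred and not gt:
        return 1.0
    if not pred or not gt:
        return 0.0
    pred_chars, gt_chars = set(pred), set(gt)
    overlap = len(pred_chars & gt_chars)
    precision = overlap / len(pred_chars)
    recall = overlap / len(gt_chars)
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)
```

Since `char_set_f1` compares sets, `"abc"` and `"cba"` score 1.0 even though the strings differ, so it is best read as a coarse signal alongside the edit-distance accuracy.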
| 147 |
+
|
| 148 |
+
class DonutTrainer:
|
| 149 |
+
processor = None
|
| 150 |
+
max_length = 768
|
| 151 |
+
image_size = [1280, 960]
|
| 152 |
+
added_tokens = []
|
| 153 |
+
train_dataloader = None
|
| 154 |
+
val_dataloader = None
|
| 155 |
+
huggingface_model_id = None
|
| 156 |
+
|
| 157 |
+
class DonutDataset(Dataset):
|
| 158 |
+
"""
|
| 159 |
+
PyTorch Dataset for Donut. This class takes a HuggingFace Dataset as input.
|
| 160 |
+
|
| 161 |
+
Each row consists of an image path (png/jpg/jpeg) and gt data (json/jsonl/txt),
|
| 162 |
+
and it will be converted into pixel_values (vectorized image) and labels (input_ids of the tokenized string).
|
| 163 |
+
|
| 164 |
+
Args:
|
| 165 |
+
dataset: HuggingFace DatasetDict containing the dataset to be used
|
| 166 |
+
max_length: the max number of tokens for the target sequences
|
| 167 |
+
split: whether to load "train", "validation" or "test" split
|
| 168 |
+
ignore_id: ignore_index for torch.nn.CrossEntropyLoss
|
| 169 |
+
task_start_token: the special token to be fed to the decoder to conduct the target task
|
| 170 |
+
prompt_end_token: the special token at the end of the sequences
|
| 171 |
+
sort_json_key: whether or not to sort the JSON keys
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
def __init__(
|
| 175 |
+
self,
|
| 176 |
+
dataset: DatasetDict,
|
| 177 |
+
ground_truth_key: str,
|
| 178 |
+
max_length: int,
|
| 179 |
+
split: str = "train",
|
| 180 |
+
ignore_id: int = -100,
|
| 181 |
+
task_start_token: str = "<s>",
|
| 182 |
+
prompt_end_token: str = None,
|
| 183 |
+
sort_json_key: bool = True,
|
| 184 |
+
):
|
| 185 |
+
super().__init__()
|
| 186 |
+
|
| 187 |
+
self.dataset = dataset[split]
|
| 188 |
+
self.ground_truth_key = ground_truth_key
|
| 189 |
+
self.max_length = max_length
|
| 190 |
+
self.split = split
|
| 191 |
+
self.ignore_id = ignore_id
|
| 192 |
+
self.task_start_token = task_start_token
|
| 193 |
+
self.prompt_end_token = prompt_end_token if prompt_end_token else task_start_token
|
| 194 |
+
self.sort_json_key = sort_json_key
|
| 195 |
+
|
| 196 |
+
self.dataset_length = len(self.dataset)
|
| 197 |
+
|
| 198 |
+
self.gt_token_sequences = []
|
| 199 |
+
for sample in self.dataset:
|
| 200 |
+
ground_truth = sample[self.ground_truth_key]
|
| 201 |
+
self.gt_token_sequences.append(
|
| 202 |
+
[
|
| 203 |
+
self.json2token(
|
| 204 |
+
gt_json,
|
| 205 |
+
update_special_tokens_for_json_key=self.split == "train",
|
| 206 |
+
sort_json_key=self.sort_json_key,
|
| 207 |
+
)
|
| 208 |
+
+ DonutTrainer.processor.tokenizer.eos_token
|
| 209 |
+
for gt_json in [ground_truth] # load json from list of json
|
| 210 |
+
]
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
self.add_tokens([self.task_start_token, self.prompt_end_token])
|
| 214 |
+
self.prompt_end_token_id = DonutTrainer.processor.tokenizer.convert_tokens_to_ids(self.prompt_end_token)
|
| 215 |
+
|
| 216 |
+
def json2token(self, obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
|
| 217 |
+
"""
|
| 218 |
+
Convert an ordered JSON object into a token sequence
|
| 219 |
+
"""
|
| 220 |
+
if type(obj) == dict:
|
| 221 |
+
if len(obj) == 1 and "text_sequence" in obj:
|
| 222 |
+
return obj["text_sequence"]
|
| 223 |
+
else:
|
| 224 |
+
output = ""
|
| 225 |
+
if sort_json_key:
|
| 226 |
+
keys = sorted(obj.keys(), reverse=True)
|
| 227 |
+
else:
|
| 228 |
+
keys = obj.keys()
|
| 229 |
+
for k in keys:
|
| 230 |
+
if update_special_tokens_for_json_key:
|
| 231 |
+
self.add_tokens([fr"<s_{k}>", fr"</s_{k}>"])
|
| 232 |
+
output += (
|
| 233 |
+
fr"<s_{k}>"
|
| 234 |
+
+ self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
|
| 235 |
+
+ fr"</s_{k}>"
|
| 236 |
+
)
|
| 237 |
+
return output
|
| 238 |
+
elif type(obj) == list:
|
| 239 |
+
return r"<sep/>".join(
|
| 240 |
+
[self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
|
| 241 |
+
)
|
| 242 |
+
else:
|
| 243 |
+
obj = str(obj)
|
| 244 |
+
if f"<{obj}/>" in DonutTrainer.added_tokens:
|
| 245 |
+
obj = f"<{obj}/>" # for categorical special tokens
|
| 246 |
+
return obj
|
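The flattening performed by `json2token` can be illustrated without a tokenizer. This is a simplified sketch under stated assumptions: it omits the special-token registration (`add_tokens`) and the categorical-token branch, and the sample keys (`name`, `dishes`, `price`) are invented for illustration rather than taken from the project's schema.

```python
from typing import Any


def json2token(obj: Any, sort_json_key: bool = True) -> str:
    """Simplified stand-in: flatten nested JSON into a tagged token string."""
    if isinstance(obj, dict):
        if len(obj) == 1 and "text_sequence" in obj:
            return obj["text_sequence"]
        # Keys are emitted in reverse-sorted order when sorting is enabled.
        keys = sorted(obj.keys(), reverse=True) if sort_json_key else list(obj.keys())
        return "".join(
            f"<s_{k}>{json2token(obj[k], sort_json_key)}</s_{k}>" for k in keys
        )
    if isinstance(obj, list):
        # List items are joined with a separator token.
        return "<sep/>".join(json2token(item, sort_json_key) for item in obj)
    return str(obj)


sample = {
    "name": "Cafe",
    "dishes": [{"name": "Tea", "price": "30"}, {"name": "Pie", "price": "50"}],
}
print(json2token(sample))
```

Each key becomes an opening and closing tag pair, which is why the training code registers `<s_{k}>` and `</s_{k}>` as tokens for every key it encounters in the training split.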
| 247 |
+
|
| 248 |
+
def add_tokens(self, list_of_tokens: List[str]):
|
| 249 |
+
"""
|
| 250 |
+
Add special tokens to tokenizer and resize the token embeddings of the decoder
|
| 251 |
+
"""
|
| 252 |
+
newly_added_num = DonutTrainer.processor.tokenizer.add_tokens(list_of_tokens)
|
| 253 |
+
if newly_added_num > 0:
|
| 254 |
+
DonutTrainer.model.decoder.resize_token_embeddings(len(DonutTrainer.processor.tokenizer))
|
| 255 |
+
DonutTrainer.added_tokens.extend(list_of_tokens)
|
| 256 |
+
|
| 257 |
+
def __len__(self) -> int:
|
| 258 |
+
return self.dataset_length
|
| 259 |
+
|
| 260 |
+
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, str]:
|
| 261 |
+
"""
|
| 262 |
+
Load image from image_path of given dataset_path and convert into input_tensor and labels
|
| 263 |
+
Convert gt data into input_ids (tokenized string)
|
| 264 |
+
Returns:
|
| 265 |
+
pixel_values : preprocessed image
|
| 266 |
+
labels : tokenized gt_data with pad tokens masked (model doesn't need to predict pad token)
|
| 267 |
+
target_sequence : ground-truth token sequence (string) the labels were built from
|
| 268 |
+
"""
|
| 269 |
+
sample = self.dataset[idx]
|
| 270 |
+
|
| 271 |
+
# inputs
|
| 272 |
+
pixel_values = DonutTrainer.processor(sample["image"], random_padding=self.split == "train", return_tensors="pt").pixel_values
|
| 273 |
+
pixel_values = pixel_values.squeeze()
|
| 274 |
+
|
| 275 |
+
# targets
|
| 276 |
+
target_sequence = random.choice(self.gt_token_sequences[idx]) # can be more than one, e.g., DocVQA Task 1
|
| 277 |
+
input_ids = DonutTrainer.processor.tokenizer(
|
| 278 |
+
target_sequence,
|
| 279 |
+
add_special_tokens=False,
|
| 280 |
+
max_length=self.max_length,
|
| 281 |
+
padding="max_length",
|
| 282 |
+
truncation=True,
|
| 283 |
+
return_tensors="pt",
|
| 284 |
+
)["input_ids"].squeeze(0)
|
| 285 |
+
|
| 286 |
+
labels = input_ids.clone()
|
| 287 |
+
labels[labels == DonutTrainer.processor.tokenizer.pad_token_id] = self.ignore_id # model doesn't need to predict pad token
|
| 288 |
+
# labels[: torch.nonzero(labels == self.prompt_end_token_id).sum() + 1] = self.ignore_id # model doesn't need to predict prompt (for VQA)
|
| 289 |
+
return pixel_values, labels, target_sequence
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class DonutModelPLModule(pl.LightningModule):
|
| 293 |
+
def __init__(self, config, processor, model):
|
| 294 |
+
super().__init__()
|
| 295 |
+
self.config = config
|
| 296 |
+
self.processor = processor
|
| 297 |
+
self.model = model
|
| 298 |
+
|
| 299 |
+
def training_step(self, batch, batch_idx):
|
| 300 |
+
pixel_values, labels, _ = batch
|
| 301 |
+
|
| 302 |
+
outputs = self.model(pixel_values, labels=labels)
|
| 303 |
+
loss = outputs.loss
|
| 304 |
+
self.log("train_loss", loss)
|
| 305 |
+
return loss
|
| 306 |
+
|
| 307 |
+
def validation_step(self, batch, batch_idx, dataset_idx=0):
|
| 308 |
+
pixel_values, labels, answers = batch
|
| 309 |
+
batch_size = pixel_values.shape[0]
|
| 310 |
+
# we feed the prompt to the model
|
| 311 |
+
decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)
|
| 312 |
+
|
| 313 |
+
outputs = self.model.generate(pixel_values,
|
| 314 |
+
decoder_input_ids=decoder_input_ids,
|
| 315 |
+
max_length=DonutTrainer.max_length,
|
| 316 |
+
early_stopping=True,
|
| 317 |
+
pad_token_id=self.processor.tokenizer.pad_token_id,
|
| 318 |
+
eos_token_id=self.processor.tokenizer.eos_token_id,
|
| 319 |
+
use_cache=True,
|
| 320 |
+
num_beams=1,
|
| 321 |
+
bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
|
| 322 |
+
return_dict_in_generate=True,)
|
| 323 |
+
|
| 324 |
+
predictions = []
|
| 325 |
+
for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
|
| 326 |
+
seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
|
| 327 |
+
seq = re.sub(r"<.*?>", "", seq, count=1).strip() # remove first task start token
|
| 328 |
+
predictions.append(seq)
|
| 329 |
+
|
| 330 |
+
scores = []
|
| 331 |
+
for pred, answer in zip(predictions, answers):
|
| 332 |
+
pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
|
| 333 |
+
# NOT NEEDED ANYMORE
|
| 334 |
+
# answer = re.sub(r"<.*?>", "", answer, count=1)
|
| 335 |
+
answer = answer.replace(self.processor.tokenizer.eos_token, "")
|
| 336 |
+
scores.append(edit_distance(pred, answer) / max(len(pred), len(answer), 1)) # guard against division by zero when both strings are empty
|
| 337 |
+
|
| 338 |
+
if self.config.get("verbose", False) and len(scores) == 1:
|
| 339 |
+
print(f"Prediction: {pred}")
|
| 340 |
+
print(f" Answer: {answer}")
|
| 341 |
+
print(f" Normed ED: {scores[0]}")
|
| 342 |
+
|
| 343 |
+
val_edit_distance = np.mean(scores)
|
| 344 |
+
self.log("val_edit_distance", val_edit_distance)
|
| 345 |
+
print(f"Validation Edit Distance: {val_edit_distance}")
|
| 346 |
+
|
| 347 |
+
return scores
|
| 348 |
+
|
| 349 |
+
def configure_optimizers(self):
|
| 350 |
+
# you could also add a learning rate scheduler if you want
|
| 351 |
+
optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))
|
| 352 |
+
|
| 353 |
+
return optimizer
|
| 354 |
+
|
| 355 |
+
def train_dataloader(self):
|
| 356 |
+
return DonutTrainer.train_dataloader
|
| 357 |
+
|
| 358 |
+
def val_dataloader(self):
|
| 359 |
+
return DonutTrainer.val_dataloader
|
| 360 |
+
|
| 361 |
+
class PushToHubCallback(Callback):
|
| 362 |
+
def on_train_epoch_end(self, trainer, pl_module):
|
| 363 |
+
print(f"Pushing model to the hub, epoch {trainer.current_epoch}")
|
| 364 |
+
pl_module.model.push_to_hub(DonutTrainer.huggingface_model_id, commit_message=f"Training in progress, epoch {trainer.current_epoch}")
|
| 365 |
+
self._upload_logs(trainer.logger.log_dir, trainer.current_epoch)
|
| 366 |
+
|
| 367 |
+
def on_train_end(self, trainer, pl_module):
|
| 368 |
+
print("Pushing model to the hub after training")
|
| 369 |
+
pl_module.processor.push_to_hub(DonutTrainer.huggingface_model_id, commit_message="Training done")
|
| 370 |
+
pl_module.model.push_to_hub(DonutTrainer.huggingface_model_id, commit_message="Training done")
|
| 371 |
+
self._upload_logs(trainer.logger.log_dir, "final")
|
| 372 |
+
|
| 373 |
+
def _upload_logs(self, log_dir: str, epoch_info):
|
| 374 |
+
try:
|
| 375 |
+
print(f"Attempting to upload logs from: {log_dir}")
|
| 376 |
+
upload_folder(folder_path=log_dir, repo_id=DonutTrainer.huggingface_model_id,
|
| 377 |
+
path_in_repo="tensorboard_logs",
|
| 378 |
+
commit_message=f"Upload logs - epoch {epoch_info}", ignore_patterns=["*.tmp", "*.lock"])
|
| 379 |
+
print(f"Successfully uploaded logs for epoch {epoch_info}")
|
| 380 |
+
except Exception as e:
|
| 381 |
+
print(f"Failed to upload logs: {e}")
|
| 382 |
+
pass
|
| 383 |
+
|
| 384 |
+
@classmethod
|
| 385 |
+
def train(
|
| 386 |
+
cls,
|
| 387 |
+
dataset: DatasetDict,
|
| 388 |
+
pretrained_model_repo_id: str,
|
| 389 |
+
huggingface_model_id: str,
|
| 390 |
+
epochs: int,
|
| 391 |
+
train_batch_size: int,
|
| 392 |
+
val_batch_size: int,
|
| 393 |
+
learning_rate: float,
|
| 394 |
+
val_check_interval: float,
|
| 395 |
+
check_val_every_n_epoch: int,
|
| 396 |
+
gradient_clip_val: float,
|
| 397 |
+
num_training_samples_per_epoch: int,
|
| 398 |
+
num_nodes: int,
|
| 399 |
+
warmup_steps: int,
|
| 400 |
+
ground_truth_key: str = "ground_truth"
|
| 401 |
+
):
|
| 402 |
+
cls.huggingface_model_id = huggingface_model_id
|
| 403 |
+
config = VisionEncoderDecoderConfig.from_pretrained(pretrained_model_repo_id)
|
| 404 |
+
config.encoder.image_size = cls.image_size
|
| 405 |
+
config.decoder.max_length = cls.max_length
|
| 406 |
+
|
| 407 |
+
cls.processor = DonutProcessor.from_pretrained(pretrained_model_repo_id)
|
| 408 |
+
cls.model = VisionEncoderDecoderModel.from_pretrained(pretrained_model_repo_id, config=config)
|
| 409 |
+
cls.processor.image_processor.size = cls.image_size[::-1]
|
| 410 |
+
cls.processor.image_processor.do_align_long_axis = False
|
| 411 |
+
|
| 412 |
+
train_dataset = cls.DonutDataset(
|
| 413 |
+
dataset=dataset,
|
| 414 |
+
ground_truth_key=ground_truth_key,
|
| 415 |
+
max_length=cls.max_length,
|
| 416 |
+
split="train",
|
| 417 |
+
task_start_token=TASK_PROMPT_NAME,
|
| 418 |
+
prompt_end_token=TASK_PROMPT_NAME,
|
| 419 |
+
sort_json_key=True
|
| 420 |
+
)
|
| 421 |
+
val_dataset = cls.DonutDataset(
|
| 422 |
+
dataset=dataset,
|
| 423 |
+
ground_truth_key=ground_truth_key,
|
| 424 |
+
max_length=cls.max_length,
|
| 425 |
+
split="validation",
|
| 426 |
+
task_start_token=TASK_PROMPT_NAME,
|
| 427 |
+
prompt_end_token=TASK_PROMPT_NAME,
|
| 428 |
+
sort_json_key=True
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
cls.model.config.pad_token_id = cls.processor.tokenizer.pad_token_id
|
| 432 |
+
cls.model.config.decoder_start_token_id = cls.processor.tokenizer.convert_tokens_to_ids([TASK_PROMPT_NAME])[0]
|
| 433 |
+
|
| 434 |
+
cls.train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=4)  # use the batch size passed to train()
|
| 435 |
+
cls.val_dataloader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False, num_workers=4)  # use the batch size passed to train()
|
| 436 |
+
|
| 437 |
+
config = {
|
| 438 |
+
"max_epochs": epochs,
|
| 439 |
+
"val_check_interval": val_check_interval, # fraction of an epoch between validation runs (0.5 validates twice per epoch)
|
| 440 |
+
"check_val_every_n_epoch": check_val_every_n_epoch,
|
| 441 |
+
"gradient_clip_val": gradient_clip_val,
|
| 442 |
+
"num_training_samples_per_epoch": num_training_samples_per_epoch,
|
| 443 |
+
"lr": learning_rate,
|
| 444 |
+
"train_batch_sizes": [train_batch_size],
|
| 445 |
+
"val_batch_sizes": [val_batch_size],
|
| 446 |
+
# "seed":2022,
|
| 447 |
+
"num_nodes": num_nodes,
|
| 448 |
+
"warmup_steps": warmup_steps, # 10%
|
| 449 |
+
"result_path": "./.checkpoints",
|
| 450 |
+
"verbose": True,
|
| 451 |
+
}
|
| 452 |
+
model_module = cls.DonutModelPLModule(config, cls.processor, cls.model)
|
| 453 |
+
|
| 454 |
+
device = (
|
| 455 |
+
"cuda"
|
| 456 |
+
if torch.cuda.is_available()
|
| 457 |
+
else "mps" if torch.backends.mps.is_available() else "cpu"
|
| 458 |
+
)
|
| 459 |
+
print(f"Using {device} device")
|
| 460 |
+
trainer = pl.Trainer(
|
| 461 |
+
accelerator="gpu" if device == "cuda" else "mps" if device == "mps" else "cpu",
|
| 462 |
+
devices=1,  # Lightning rejects devices=0; a single device is used on every accelerator
|
| 463 |
+
max_epochs=config.get("max_epochs"),
|
| 464 |
+
val_check_interval=config.get("val_check_interval"),
|
| 465 |
+
check_val_every_n_epoch=config.get("check_val_every_n_epoch"),
|
| 466 |
+
gradient_clip_val=config.get("gradient_clip_val"),
|
| 467 |
+
precision=16 if device == "cuda" else 32, # we'll use mixed precision if device == "cuda"
|
| 468 |
+
num_sanity_val_steps=0,
|
| 469 |
+
logger=TensorBoardLogger(save_dir="./.checkpoints", name="donut_training", version=None),
|
| 470 |
+
callbacks=[cls.PushToHubCallback()]
|
| 471 |
+
)
|
| 472 |
+
trainer.fit(model_module)
|
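Note on the validation metric in menu/donut.py: the score appended in the validation step is an edit distance divided by the length of the longer string. The following is a minimal, self-contained sketch of that metric. It assumes a plain Levenshtein distance written out by hand (the file calls an `edit_distance` helper whose import lies outside this section), and the guard for two empty strings is an assumption of this sketch. The names `levenshtein` and `normalized_edit_distance` are hypothetical.

```python
def levenshtein(a: str, b: str) -> int:
    """Edit distance counting insertions, deletions and substitutions."""
    # prev[j] holds the distance between the processed prefix of a and b[:j].
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            curr.append(min(prev[j] + 1,          # delete ca
                            curr[j - 1] + 1,      # insert cb
                            prev[j - 1] + cost))  # substitute or match
        prev = curr
    return prev[-1]


def normalized_edit_distance(pred: str, answer: str) -> float:
    """Distance scaled to [0, 1]; 0.0 means the strings are identical."""
    longest = max(len(pred), len(answer))
    if longest == 0:
        # Assumption of this sketch: two empty strings count as a perfect match.
        return 0.0
    return levenshtein(pred, answer) / longest
```

A lower value is better, so a validation run that averages these scores should trend toward 0 as the model improves.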
menu/llm/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
| 1 |
+
from .gemini import GeminiAPI
|
| 2 |
+
from .openai import OpenAIAPI
|
menu/llm/base.py
ADDED
|
@@ -0,0 +1,9 @@
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
class LLMBase(ABC):
|
| 6 |
+
@classmethod
|
| 7 |
+
@abstractmethod
|
| 8 |
+
def call(cls, image: np.ndarray, model: str, token: str) -> dict:
|
| 9 |
+
raise NotImplementedError
|
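Note on menu/llm/base.py: stacking `@classmethod` on top of `@abstractmethod` means every implementation is called on the class itself and receives that class as its first argument. A minimal sketch of the pattern follows. `Backend` and `EchoBackend` are hypothetical stand-ins, not classes from this repository, and the string payload replaces the image argument only to keep the example free of third-party imports.

```python
from abc import ABC, abstractmethod


class Backend(ABC):
    @classmethod
    @abstractmethod
    def call(cls, payload: str) -> dict:
        """Subclasses must override this; cls is the class the call was made on."""
        raise NotImplementedError


class EchoBackend(Backend):
    @classmethod
    def call(cls, payload: str) -> dict:
        # No instance is needed: the method is invoked as EchoBackend.call(...).
        return {"backend": cls.__name__, "payload": payload}


result = EchoBackend.call("menu.jpg")
```

Because the method is bound to the class, callers can pass the class object around (as the notebook does with `SELECTED_FUNCTION`) and invoke `.call(...)` without instantiating anything.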
menu/llm/gemini.py
ADDED
|
@@ -0,0 +1,36 @@
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from google import genai
|
| 6 |
+
from google.genai import types
|
| 7 |
+
|
| 8 |
+
from .base import LLMBase
|
| 9 |
+
|
| 10 |
+
FUNCTION_CALL = json.load(open("tools/schema_gemini.json", "r"))
|
| 11 |
+
|
| 12 |
+
class GeminiAPI(LLMBase):
|
| 13 |
+
@classmethod
|
| 14 |
+
def call(cls, image: np.ndarray, model: str, token: str) -> dict:
|
| 15 |
+
client = genai.Client(api_key=token) # Initialize the client with the API key
|
| 16 |
+
encode_img = Image.fromarray(image) # Convert the image for the API
|
| 17 |
+
|
| 18 |
+
config = types.GenerateContentConfig(
|
| 19 |
+
tools=[types.Tool(function_declarations=[FUNCTION_CALL])],
|
| 20 |
+
tool_config={
|
| 21 |
+
"function_calling_config": {
|
| 22 |
+
"mode": "ANY",
|
| 23 |
+
"allowed_function_names": [FUNCTION_CALL["name"]]
|
| 24 |
+
}
|
| 25 |
+
}
|
| 26 |
+
)
|
| 27 |
+
response = client.models.generate_content(
|
| 28 |
+
model=model,
|
| 29 |
+
contents=[encode_img],
|
| 30 |
+
config=config
|
| 31 |
+
)
|
| 32 |
+
if response.candidates[0].content.parts[0].function_call:
|
| 33 |
+
function_call = response.candidates[0].content.parts[0].function_call
|
| 34 |
+
return function_call.args
|
| 35 |
+
|
| 36 |
+
return {}
|
menu/llm/openai.py
ADDED
|
@@ -0,0 +1,39 @@
|
| 1 |
+
import json
|
| 2 |
+
import base64
|
| 3 |
+
from io import BytesIO
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from openai import OpenAI
|
| 8 |
+
|
| 9 |
+
from .base import LLMBase
|
| 10 |
+
|
| 11 |
+
FUNCTION_CALL = json.load(open("tools/schema_openai.json", "r"))
|
| 12 |
+
|
| 13 |
+
class OpenAIAPI(LLMBase):
|
| 14 |
+
@classmethod
|
| 15 |
+
def call(cls, image: np.ndarray, model: str, token: str) -> dict:
|
| 16 |
+
client = OpenAI(api_key=token) # Initialize the client with the API key
|
| 17 |
+
buffer = BytesIO()
|
| 18 |
+
Image.fromarray(image).save(buffer, format="JPEG")
|
| 19 |
+
encode_img = base64.b64encode(buffer.getvalue()).decode("utf-8") # Convert the image for the API
|
| 20 |
+
|
| 21 |
+
response = client.responses.create(
|
| 22 |
+
model=model,
|
| 23 |
+
input=[
|
| 24 |
+
{
|
| 25 |
+
"role": "user",
|
| 26 |
+
"content": [
|
| 27 |
+
{
|
| 28 |
+
"type": "input_image",
|
| 29 |
+
"image_url": f"data:image/jpeg;base64,{encode_img}",
|
| 30 |
+
},
|
| 31 |
+
],
|
| 32 |
+
}
|
| 33 |
+
],
|
| 34 |
+
tools=[FUNCTION_CALL],
|
| 35 |
+
)
|
| 36 |
+
if response and response.output:
|
| 37 |
+
if hasattr(response.output[0], "arguments"):
|
| 38 |
+
return json.loads(response.output[0].arguments)
|
| 39 |
+
return {}
|
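Note on menu/llm/openai.py: the method inspects only the first output item for an `arguments` attribute and returns `{}` otherwise. If a response ever carries another item before the function call, a scan over all items may be more forgiving. The sketch below illustrates that idea with `SimpleNamespace` stand-ins. These are not the SDK's real response types, and `first_function_arguments` is a hypothetical helper, so treat this as an assumption about the response shape rather than documented behavior.

```python
import json
from types import SimpleNamespace


def first_function_arguments(output_items) -> dict:
    """Return the parsed arguments of the first item that carries any."""
    for item in output_items or []:
        arguments = getattr(item, "arguments", None)
        if arguments:
            return json.loads(arguments)
    return {}


# Stand-in objects that only mimic the attribute the original code checks.
items = [
    SimpleNamespace(type="message"),
    SimpleNamespace(type="function_call", arguments='{"restaurant": "A"}'),
]
parsed = first_function_arguments(items)
```

The fallback to an empty dict matches the original method, so callers would not need to change.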
menu/utils.py
ADDED
|
@@ -0,0 +1,48 @@
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
from datasets import Dataset, DatasetDict
|
| 4 |
+
|
| 5 |
+
def split_dataset(
|
| 6 |
+
dataset: Dataset,
|
| 7 |
+
train: float,
|
| 8 |
+
validation: float,
|
| 9 |
+
test: float,
|
| 10 |
+
seed: Optional[int] = None
|
| 11 |
+
) -> DatasetDict:
|
| 12 |
+
"""
|
| 13 |
+
Split a single-split Hugging Face Dataset into train/validation/test subsets.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
dataset (Dataset): The input dataset (e.g. load_dataset(...)['train']).
|
| 17 |
+
train (float): Proportion of data for the train split (0 < train < 1).
|
| 18 |
+
validation (float): Proportion of data for the validation split (0 < validation < 1).
|
| 19 |
+
test (float): Proportion of data for the test split (0 < test < 1).
|
| 20 |
+
Must satisfy train + validation + test == 1.0.
|
| 21 |
+
seed (Optional[int]): Random seed for reproducibility (default: None).
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
DatasetDict: A dictionary with keys "train", "validation", and "test".
|
| 25 |
+
"""
|
| 26 |
+
# Verify ratios sum to 1.0
|
| 27 |
+
total = train + validation + test
|
| 28 |
+
if abs(total - 1.0) > 1e-8:
|
| 29 |
+
raise ValueError(f"train + validation + test must equal 1.0 (got {total})")
|
| 30 |
+
|
| 31 |
+
# First split: extract train vs. temp (validation + test)
|
| 32 |
+
temp_size = validation + test
|
| 33 |
+
split_1 = dataset.train_test_split(test_size=temp_size, seed=seed)
|
| 34 |
+
train_ds = split_1["train"]
|
| 35 |
+
temp_ds = split_1["test"]
|
| 36 |
+
|
| 37 |
+
# Second split: divide temp into validation vs. test
|
| 38 |
+
relative_test_size = test / temp_size
|
| 39 |
+
split_2 = temp_ds.train_test_split(test_size=relative_test_size, seed=seed)
|
| 40 |
+
validation_ds = split_2["train"]
|
| 41 |
+
test_ds = split_2["test"]
|
| 42 |
+
|
| 43 |
+
# Return a DatasetDict with all three splits
|
| 44 |
+
return DatasetDict({
|
| 45 |
+
"train": train_ds,
|
| 46 |
+
"validation": validation_ds,
|
| 47 |
+
"test": test_ds,
|
| 48 |
+
})
|
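Note on menu/utils.py: the function splits twice, so the second `test_size` has to be expressed relative to the held-out remainder, not to the full dataset. The sketch below reproduces only that arithmetic, without the `datasets` library. `split_fractions` is a hypothetical helper, and the check that rejects an empty held-out portion is an assumption of this sketch (the original would divide by zero when `validation + test` is 0).

```python
def split_fractions(train: float, validation: float, test: float):
    """Return (test_size for the first split, test_size for the second split)."""
    total = train + validation + test
    if abs(total - 1.0) > 1e-8:
        raise ValueError(f"train + validation + test must equal 1.0 (got {total})")

    # First split: everything that is not train is held out.
    temp_size = validation + test
    if temp_size <= 0:
        # Assumption of this sketch: refuse to split when nothing is held out.
        raise ValueError("validation + test must be greater than 0")

    # Second split: test as a share of the held-out portion only.
    relative_test_size = test / temp_size
    return temp_size, relative_test_size


first, second = split_fractions(0.8, 0.1, 0.1)
```

With 0.8 / 0.1 / 0.1, the first split holds out 20% and the second split halves that remainder, which is what yields two equally sized validation and test sets.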
pyproject.toml
ADDED
|
@@ -0,0 +1,27 @@
|
| 1 |
+
[project]
|
| 2 |
+
authors = [{name = "ryanlinjui", email = "ryanlinjui@gmail.com"}]
|
| 3 |
+
name = "menu-text-detection"
|
| 4 |
+
version = "0.1.0"
|
| 5 |
+
description = "Extract structured menu information from images into JSON using a fine-tuned Donut E2E model."
|
| 6 |
+
readme = "README.md"
|
| 7 |
+
requires-python = "==3.11.*"
|
| 8 |
+
dependencies = [
|
| 9 |
+
"datasets>=3.6.0",
|
| 10 |
+
"dotenv>=0.9.9",
|
| 11 |
+
"google-genai>=1.14.0",
|
| 12 |
+
"gradio>=5.29.0",
|
| 13 |
+
"huggingface-hub>=0.31.1",
|
| 14 |
+
"matplotlib>=3.10.1",
|
| 15 |
+
"nltk>=3.9.1",
|
| 16 |
+
"notebook>=7.4.2",
|
| 17 |
+
"openai>=1.77.0",
|
| 18 |
+
"pillow>=11.2.1",
|
| 19 |
+
"pillow-heif>=0.22.0",
|
| 20 |
+
"protobuf>=6.30.2",
|
| 21 |
+
"pytorch-lightning>=2.5.2",
|
| 22 |
+
"sentencepiece>=0.2.0",
|
| 23 |
+
"tensorboardx>=2.6.2.2",
|
| 24 |
+
"transformers==4.49",
|
| 25 |
+
"torch==2.4.1",
|
| 26 |
+
"donut-python>=1.0.9",
|
| 27 |
+
]
|
requirements.txt
ADDED
|
@@ -0,0 +1,183 @@
|
| 1 |
+
aiofiles==24.1.0
|
| 2 |
+
aiohappyeyeballs==2.6.1
|
| 3 |
+
aiohttp==3.11.18
|
| 4 |
+
aiosignal==1.3.2
|
| 5 |
+
annotated-types==0.7.0
|
| 6 |
+
anyio==4.9.0
|
| 7 |
+
appnope==0.1.4
|
| 8 |
+
argon2-cffi==23.1.0
|
| 9 |
+
argon2-cffi-bindings==21.2.0
|
| 10 |
+
arrow==1.3.0
|
| 11 |
+
asttokens==3.0.0
|
| 12 |
+
async-lru==2.0.5
|
| 13 |
+
attrs==25.3.0
|
| 14 |
+
babel==2.17.0
|
| 15 |
+
beautifulsoup4==4.13.4
|
| 16 |
+
bleach==6.2.0
|
| 17 |
+
cachetools==5.5.2
|
| 18 |
+
certifi==2025.4.26
|
| 19 |
+
cffi==1.17.1
|
| 20 |
+
charset-normalizer==3.4.2
|
| 21 |
+
click==8.1.8
|
| 22 |
+
comm==0.2.2
|
| 23 |
+
contourpy==1.3.2
|
| 24 |
+
cycler==0.12.1
|
| 25 |
+
datasets==3.6.0
|
| 26 |
+
debugpy==1.8.14
|
| 27 |
+
decorator==5.2.1
|
| 28 |
+
defusedxml==0.7.1
|
| 29 |
+
dill==0.3.8
|
| 30 |
+
distro==1.9.0
|
| 31 |
+
donut-python==1.0.9
|
| 32 |
+
dotenv==0.9.9
|
| 33 |
+
executing==2.2.0
|
| 34 |
+
fastapi==0.115.12
|
| 35 |
+
fastjsonschema==2.21.1
|
| 36 |
+
ffmpy==0.5.0
|
| 37 |
+
filelock==3.18.0
|
| 38 |
+
fonttools==4.57.0
|
| 39 |
+
fqdn==1.5.1
|
| 40 |
+
frozenlist==1.6.0
|
| 41 |
+
fsspec==2025.3.0
|
| 42 |
+
google-auth==2.40.1
|
| 43 |
+
google-genai==1.14.0
|
| 44 |
+
gradio==5.29.0
|
| 45 |
+
gradio-client==1.10.0
|
| 46 |
+
groovy==0.1.2
|
| 47 |
+
h11==0.16.0
|
| 48 |
+
hf-xet==1.1.0
|
| 49 |
+
httpcore==1.0.9
|
| 50 |
+
httpx==0.28.1
|
| 51 |
+
huggingface-hub==0.31.1
|
| 52 |
+
idna==3.10
|
| 53 |
+
ipykernel==6.29.5
|
| 54 |
+
ipython==9.2.0
|
| 55 |
+
ipython-pygments-lexers==1.1.1
|
| 56 |
+
isoduration==20.11.0
|
| 57 |
+
jedi==0.19.2
|
| 58 |
+
jinja2==3.1.6
|
| 59 |
+
jiter==0.9.0
|
| 60 |
+
joblib==1.5.0
|
| 61 |
+
json5==0.12.0
|
| 62 |
+
jsonpointer==3.0.0
|
| 63 |
+
jsonschema==4.23.0
|
| 64 |
+
jsonschema-specifications==2025.4.1
|
| 65 |
+
jupyter-client==8.6.3
|
| 66 |
+
jupyter-core==5.7.2
|
| 67 |
+
jupyter-events==0.12.0
|
| 68 |
+
jupyter-lsp==2.2.5
|
| 69 |
+
jupyter-server==2.15.0
|
| 70 |
+
jupyter-server-terminals==0.5.3
|
| 71 |
+
jupyterlab==4.4.2
|
| 72 |
+
jupyterlab-pygments==0.3.0
|
| 73 |
+
jupyterlab-server==2.27.3
|
| 74 |
+
kiwisolver==1.4.8
|
| 75 |
+
lightning-utilities==0.14.3
|
| 76 |
+
markdown-it-py==3.0.0
|
| 77 |
+
markupsafe==3.0.2
|
| 78 |
+
matplotlib==3.10.1
|
| 79 |
+
matplotlib-inline==0.1.7
|
| 80 |
+
mdurl==0.1.2
|
| 81 |
+
mistune==3.1.3
|
| 82 |
+
mpmath==1.3.0
|
| 83 |
+
multidict==6.4.3
|
| 84 |
+
multiprocess==0.70.16
|
| 85 |
+
munch==4.0.0
|
| 86 |
+
nbclient==0.10.2
|
| 87 |
+
nbconvert==7.16.6
|
| 88 |
+
nbformat==5.10.4
|
| 89 |
+
nest-asyncio==1.6.0
|
| 90 |
+
networkx==3.4.2
|
| 91 |
+
nltk==3.9.1
|
| 92 |
+
notebook==7.4.2
|
| 93 |
+
notebook-shim==0.2.4
|
| 94 |
+
numpy==2.2.5
|
| 95 |
+
openai==1.77.0
|
| 96 |
+
orjson==3.10.18
|
| 97 |
+
overrides==7.7.0
|
| 98 |
+
packaging==25.0
|
| 99 |
+
pandas==2.2.3
|
| 100 |
+
pandocfilters==1.5.1
|
| 101 |
+
parso==0.8.4
|
| 102 |
+
pexpect==4.9.0
|
| 103 |
+
pillow==11.2.1
|
| 104 |
+
pillow-heif==0.22.0
|
| 105 |
+
platformdirs==4.3.8
|
| 106 |
+
prometheus-client==0.21.1
|
| 107 |
+
prompt-toolkit==3.0.51
|
| 108 |
+
propcache==0.3.1
|
| 109 |
+
protobuf==6.30.2
|
| 110 |
+
psutil==7.0.0
|
| 111 |
+
ptyprocess==0.7.0
|
| 112 |
+
pure-eval==0.2.3
|
| 113 |
+
pyarrow==20.0.0
|
| 114 |
+
pyasn1==0.6.1
|
| 115 |
+
pyasn1-modules==0.4.2
|
| 116 |
+
pycparser==2.22
|
| 117 |
+
pydantic==2.11.4
|
| 118 |
+
pydantic-core==2.33.2
|
| 119 |
+
pydub==0.25.1
|
| 120 |
+
pygments==2.19.1
|
| 121 |
+
pyparsing==3.2.3
|
| 122 |
+
python-dateutil==2.9.0.post0
|
| 123 |
+
python-dotenv==1.1.0
|
| 124 |
+
python-json-logger==3.3.0
|
| 125 |
+
python-multipart==0.0.20
|
| 126 |
+
pytorch-lightning==2.5.2
|
| 127 |
+
pytz==2025.2
|
| 128 |
+
pyyaml==6.0.2
|
| 129 |
+
pyzmq==26.4.0
|
| 130 |
+
referencing==0.36.2
|
| 131 |
+
regex==2024.11.6
|
| 132 |
+
requests==2.32.3
|
| 133 |
+
rfc3339-validator==0.1.4
|
| 134 |
+
rfc3986-validator==0.1.1
|
| 135 |
+
rich==14.0.0
|
| 136 |
+
rpds-py==0.24.0
|
| 137 |
+
rsa==4.9.1
|
| 138 |
+
ruamel-yaml==0.18.14
|
| 139 |
+
ruamel-yaml-clib==0.2.12
|
| 140 |
+
ruff==0.11.8
|
| 141 |
+
safehttpx==0.1.6
|
| 142 |
+
safetensors==0.5.3
|
| 143 |
+
sconf==0.2.5
|
| 144 |
+
semantic-version==2.10.0
|
| 145 |
+
send2trash==1.8.3
|
| 146 |
+
sentencepiece==0.2.0
|
| 147 |
+
setuptools==80.3.1
|
| 148 |
+
shellingham==1.5.4
|
| 149 |
+
six==1.17.0
|
| 150 |
+
sniffio==1.3.1
|
| 151 |
+
soupsieve==2.7
|
| 152 |
+
stack-data==0.6.3
|
| 153 |
+
starlette==0.46.2
|
| 154 |
+
sympy==1.14.0
|
| 155 |
+
tensorboardx==2.6.2.2
|
| 156 |
+
terminado==0.18.1
|
| 157 |
+
timm==1.0.16
|
| 158 |
+
tinycss2==1.4.0
|
| 159 |
+
tokenizers==0.21.1
|
| 160 |
+
tomlkit==0.13.2
|
| 161 |
+
torch==2.4.1
|
| 162 |
+
torchmetrics==1.7.3
|
| 163 |
+
torchvision==0.19.1
|
| 164 |
+
tornado==6.4.2
|
| 165 |
+
tqdm==4.67.1
|
| 166 |
+
traitlets==5.14.3
|
| 167 |
+
transformers==4.49.0
|
| 168 |
+
typer==0.15.3
|
| 169 |
+
types-python-dateutil==2.9.0.20241206
|
| 170 |
+
typing-extensions==4.13.2
|
| 171 |
+
typing-inspection==0.4.0
|
| 172 |
+
tzdata==2025.2
|
| 173 |
+
uri-template==1.3.0
|
| 174 |
+
urllib3==2.4.0
|
| 175 |
+
uvicorn==0.34.2
|
| 176 |
+
wcwidth==0.2.13
|
| 177 |
+
webcolors==24.11.1
|
| 178 |
+
webencodings==0.5.1
|
| 179 |
+
websocket-client==1.8.0
|
| 180 |
+
websockets==15.0.1
|
| 181 |
+
xxhash==3.5.0
|
| 182 |
+
yarl==1.20.0
|
| 183 |
+
zss==1.2.0
|
tools/schema_gemini.json
ADDED
|
@@ -0,0 +1,44 @@
|
| 1 |
+
{
|
| 2 |
+
"name": "extract_menu_data",
|
| 3 |
+
"description": "Extract structured menu information from images.",
|
| 4 |
+
"parameters": {
|
| 5 |
+
"type": "object",
|
| 6 |
+
"properties": {
|
| 7 |
+
"restaurant": {
|
| 8 |
+
"type": "string",
|
| 9 |
+
"description": "Name of the restaurant. If the name is not available, it should be ''."
|
| 10 |
+
},
|
| 11 |
+
"address": {
|
| 12 |
+
"type": "string",
|
| 13 |
+
"description": "Address of the restaurant. If the address is not available, it should be ''."
|
| 14 |
+
},
|
| 15 |
+
"phone": {
|
| 16 |
+
"type": "string",
|
| 17 |
+
"description": "Phone number of the restaurant. If the phone number is not available, it should be ''."
|
| 18 |
+
},
|
| 19 |
+
"business_hours": {
|
| 20 |
+
"type": "string",
|
| 21 |
+
"description": "Business hours of the restaurant. If the business hours are not available, it should be ''."
|
| 22 |
+
},
|
| 23 |
+
"dishes": {
|
| 24 |
+
"type": "array",
|
| 25 |
+
"items": {
|
| 26 |
+
"type": "object",
|
| 27 |
+
"properties": {
|
| 28 |
+
"name": {
|
| 29 |
+
"type": "string",
|
| 30 |
+
"description": "Name of the menu item."
|
| 31 |
+
},
|
| 32 |
+
"price": {
|
| 33 |
+
"type": "string",
|
| 34 |
+
"description": "Price of the menu item. If the price is not available, it should be -1."
|
| 35 |
+
}
|
| 36 |
+
},
|
| 37 |
+
"required": ["name", "price"]
|
| 38 |
+
},
|
| 39 |
+
"description": "List of dishes on the menu."
|
| 40 |
+
}
|
| 41 |
+
},
|
| 42 |
+
"required": ["restaurant", "address", "phone", "business_hours", "dishes"]
|
| 43 |
+
}
|
| 44 |
+
}
|
tools/schema_openai.json
ADDED
|
@@ -0,0 +1,47 @@
|
| 1 |
+
{
|
| 2 |
+
"type": "function",
|
| 3 |
+
"name": "extract_menu_data",
|
| 4 |
+
"description": "Extract structured menu information from images.",
|
| 5 |
+
"parameters": {
|
| 6 |
+
"type": "object",
|
| 7 |
+
"properties": {
|
| 8 |
+
"restaurant": {
|
| 9 |
+
"type": "string",
|
| 10 |
+
"description": "Name of the restaurant. If the name is not available, it should be ''."
|
| 11 |
+
},
|
| 12 |
+
"address": {
|
| 13 |
+
"type": "string",
|
| 14 |
+
"description": "Address of the restaurant. If the address is not available, it should be ''."
|
| 15 |
+
},
|
| 16 |
+
"phone": {
|
| 17 |
+
"type": "string",
|
| 18 |
+
"description": "Phone number of the restaurant. If the phone number is not available, it should be ''."
|
| 19 |
+
},
|
| 20 |
+
"business_hours": {
|
| 21 |
+
"type": "string",
|
| 22 |
+
"description": "Business hours of the restaurant. If the business hours are not available, it should be ''."
|
| 23 |
+
},
|
| 24 |
+
"dishes": {
|
| 25 |
+
"type": "array",
|
| 26 |
+
"items": {
|
| 27 |
+
"type": "object",
|
| 28 |
+
"properties": {
|
| 29 |
+
"name": {
|
| 30 |
+
"type": "string",
|
| 31 |
+
"description": "Name of the menu item."
|
| 32 |
+
},
|
| 33 |
+
"price": {
|
| 34 |
+
"type": "string",
|
| 35 |
+
"description": "Price of the menu item. If the price is not available, it should be -1."
|
| 36 |
+
}
|
| 37 |
+
},
|
| 38 |
+
"required": ["name", "price"],
|
| 39 |
+
"additionalProperties": false
|
| 40 |
+
},
|
| 41 |
+
"description": "List of dishes on the menu."
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"required": ["restaurant", "address", "phone", "business_hours", "dishes"],
|
| 45 |
+
"additionalProperties": false
|
| 46 |
+
}
|
| 47 |
+
}
|
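Note on the two schema files: both declare the same required keys at the top level and per dish. A response can be checked against those keys before it is written to `metadata.jsonl`. The following is a small sketch of such a check. The key lists are copied from the `required` arrays above, while `missing_fields` is a hypothetical helper and not part of the repository.

```python
# Key names copied from the "required" arrays in the schema files.
REQUIRED_TOP = ["restaurant", "address", "phone", "business_hours", "dishes"]
REQUIRED_DISH = ["name", "price"]


def missing_fields(menu: dict) -> list:
    """List every required key that is absent from a menu payload."""
    missing = [key for key in REQUIRED_TOP if key not in menu]
    for index, dish in enumerate(menu.get("dishes", [])):
        missing += [f"dishes[{index}].{key}" for key in REQUIRED_DISH if key not in dish]
    return missing


complete = {
    "restaurant": "", "address": "", "phone": "", "business_hours": "",
    "dishes": [{"name": "tea", "price": "30"}],
}
partial = {
    "restaurant": "", "address": "", "business_hours": "",
    "dishes": [{"name": "tea"}],
}
```

An empty result means the payload carries every required key; it does not validate value types.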
train.ipynb
ADDED
|
@@ -0,0 +1,235 @@
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# Log in to Hugging Face (only needed once)"
|
| 8 |
+
]
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"cell_type": "code",
|
| 12 |
+
"execution_count": null,
|
| 13 |
+
"metadata": {},
|
| 14 |
+
"outputs": [],
|
| 15 |
+
"source": [
|
| 16 |
+
"from huggingface_hub import interpreter_login\n",
|
| 17 |
+
"interpreter_login()"
|
| 18 |
+
]
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"cell_type": "markdown",
|
| 22 |
+
"metadata": {},
|
| 23 |
+
"source": [
|
| 24 |
+
"# Collect Menu Image Datasets\n",
|
| 25 |
+
"- Use `metadata.jsonl` to label each image's ground truth. See the [examples](https://github.com/ryanlinjui/menu-text-detection/tree/main/examples) for reference.\n",
|
| 26 |
+
"- After finishing, push to HuggingFace Datasets.\n",
|
| 27 |
+
"- For labeling:\n",
|
| 28 |
+
" - [Google AI Studio](https://aistudio.google.com) or [OpenAI ChatGPT](https://chatgpt.com).\n",
|
| 29 |
+
" - Use function calling through the API: start the Gradio app locally or visit the [hosted Space](https://huggingface.co/spaces/ryanlinjui/menu-text-detection).\n",
|
| 30 |
+
"\n",
|
| 31 |
+
"### Menu Type\n",
|
| 32 |
+
"- **h**: horizontal menu\n",
|
| 33 |
+
"- **v**: vertical menu\n",
|
| 34 |
+
"- **d**: document-style menu\n",
|
| 35 |
+
"- **s**: in-scene menu (non-document style)\n",
|
| 36 |
+
"- **i**: irregular menu (menu with irregular text layout)\n",
|
| 37 |
+
"\n",
|
| 38 |
+
"> Please see the [examples](https://github.com/ryanlinjui/menu-text-detection/tree/main/examples) for more details."
|
| 39 |
+
]
|
| 40 |
+
},
|
| 41 |
+
{
|
| 42 |
+
"cell_type": "code",
|
| 43 |
+
"execution_count": null,
|
| 44 |
+
"metadata": {},
|
| 45 |
+
"outputs": [],
|
| 46 |
+
"source": [
|
| 47 |
+
"import os\n",
|
| 48 |
+
"import json\n",
|
| 49 |
+
"\n",
|
| 50 |
+
"import numpy as np\n",
|
| 51 |
+
"from PIL import Image\n",
|
| 52 |
+
"from pillow_heif import register_heif_opener\n",
|
| 53 |
+
"\n",
|
| 54 |
+
"from menu.llm import (\n",
|
| 55 |
+
" GeminiAPI,\n",
|
| 56 |
+
" OpenAIAPI\n",
|
| 57 |
+
")\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"IMAGE_DIR = \"datasets/images\" # set your image directory here\n",
|
| 60 |
+
"SELECTED_MODEL = \"gemini-2.5-flash\" # set model name here; refer to MODEL_LIST in app.py for more options\n",
|
| 61 |
+
"API_TOKEN = \"\" # set your API token here\n",
|
| 62 |
+
"SELECTED_FUNCTION = GeminiAPI # set \"GeminiAPI\" or \"OpenAIAPI\"\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"register_heif_opener()\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"for file in os.listdir(IMAGE_DIR):\n",
|
| 67 |
+
" print(f\"Processing image: {file}\")\n",
|
| 68 |
+
" try:\n",
|
| 69 |
+
" image = np.array(Image.open(os.path.join(IMAGE_DIR, file)))\n",
|
| 70 |
+
" data = {\n",
|
| 71 |
+
" \"file_name\": file,\n",
|
| 72 |
+
" \"menu\": SELECTED_FUNCTION.call(image, SELECTED_MODEL, API_TOKEN)\n",
|
| 73 |
+
" }\n",
|
| 74 |
+
" with open(os.path.join(IMAGE_DIR, \"metadata.jsonl\"), \"a\", encoding=\"utf-8\") as metaf:\n",
|
| 75 |
+
" metaf.write(json.dumps(data, ensure_ascii=False, sort_keys=True) + \"\\n\")\n",
|
| 76 |
+
" except Exception as e:\n",
|
| 77 |
+
" print(f\"Skipping invalid image '{file}': {e}\")\n",
|
| 78 |
+
" continue"
|
| 79 |
+
]
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"cell_type": "markdown",
|
| 83 |
+
"metadata": {},
|
| 84 |
+
"source": [
|
| 85 |
+
"# Push Datasets to HuggingFace"
|
| 86 |
+
]
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"cell_type": "code",
|
| 90 |
+
"execution_count": null,
|
| 91 |
+
"metadata": {},
|
| 92 |
+
"outputs": [],
|
| 93 |
+
"source": [
|
| 94 |
+
"from datasets import load_dataset\n",
|
| 95 |
+
"\n",
|
| 96 |
+
"dataset = load_dataset(path=\"datasets/menu-zh-TW\") # load the dataset from a local directory containing metadata.jsonl and the image files\n",
|
| 97 |
+
"dataset.push_to_hub(repo_id=\"ryanlinjui/menu-zh-TW\") # push to the huggingface dataset hub"
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"cell_type": "markdown",
|
| 102 |
+
"metadata": {},
|
| 103 |
+
"source": [
|
| 104 |
+
"# Prepare the dataset for training"
|
| 105 |
+
]
|
| 106 |
+
},
|
| 107 |
+
{
|
| 108 |
+
"cell_type": "code",
|
| 109 |
+
"execution_count": null,
|
| 110 |
+
"metadata": {},
|
| 111 |
+
"outputs": [],
|
| 112 |
+
"source": [
|
| 113 |
+
"from menu.utils import split_dataset\n",
|
| 114 |
+
"from datasets import load_dataset\n",
|
| 115 |
+
"\n",
|
| 116 |
+
"dataset = load_dataset(path=\"ryanlinjui/menu-zh-TW\") # set your dataset repo id for training\n",
|
| 117 |
+
"dataset = split_dataset(dataset[\"train\"], train=0.8, validation=0.1, test=0.1, seed=42) # (optional) use it if your dataset is not split into train/validation/test\n",
|
| 118 |
+
"print(f\"Dataset split: {len(dataset['train'])} train, {len(dataset['validation'])} validation, {len(dataset['test'])} test\")"
|
| 119 |
+
]
|
| 120 |
+
},
|
| 121 |
+
{
|
| 122 |
+
"cell_type": "markdown",
|
| 123 |
+
"metadata": {},
|
| 124 |
+
"source": [
|
| 125 |
+
"# Fine-tune Donut Model"
|
| 126 |
+
]
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"cell_type": "code",
|
| 130 |
+
"execution_count": null,
|
| 131 |
+
"metadata": {},
|
| 132 |
+
"outputs": [],
|
| 133 |
+
"source": [
|
| 134 |
+
"import logging\n",
|
| 135 |
+
"from menu.donut import DonutTrainer\n",
|
| 136 |
+
"\n",
|
| 137 |
+
"logging.getLogger(\"transformers\").setLevel(logging.ERROR) # filter output message from transformers\n",
|
| 138 |
+
"\n",
|
| 139 |
+
"DonutTrainer.train(\n",
|
| 140 |
+
" dataset=dataset,\n",
|
| 141 |
+
" pretrained_model_repo_id=\"naver-clova-ix/donut-base\", # set your pretrained model repo id for fine-tuning\n",
|
| 142 |
+
" ground_truth_key=\"menu\", # set your ground truth key for training\n",
|
| 143 |
+
" huggingface_model_id=\"ryanlinjui/donut-base-finetuned-menu\", # set your huggingface model repo id for saving / pushing to the hub\n",
|
| 144 |
+
" epochs=15, # set your training epochs\n",
|
| 145 |
+
" train_batch_size=8, # set your training batch size\n",
|
| 146 |
+
" val_batch_size=1, # set your validation batch size\n",
|
| 147 |
+
" learning_rate=3e-5, # set your learning rate\n",
|
| 148 |
+
" val_check_interval=0.5, # fraction of an epoch between validation runs (0.5 validates twice per epoch)\n",
|
| 149 |
+
" check_val_every_n_epoch=1, # run validation every n epochs\n",
|
| 150 |
+
" gradient_clip_val=1.0, # gradient clipping value for training stability\n",
|
| 151 |
+
" num_training_samples_per_epoch=198, # set num_training_samples_per_epoch = training set size\n",
|
| 152 |
+
" num_nodes=1, # number of nodes for distributed training\n",
|
| 153 |
+
" warmup_steps=75 # number of warmup steps for learning rate scheduler, 198/8*30/10, 10%\n",
|
| 154 |
+
")"
|
| 155 |
+
]
|
| 156 |
+
},
|
| 157 |
+
{
|
| 158 |
+
"cell_type": "markdown",
|
| 159 |
+
"metadata": {},
|
| 160 |
+
"source": [
|
| 161 |
+
"# Evaluate Donut Model"
|
| 162 |
+
]
|
| 163 |
+
},
|
| 164 |
+
{
|
| 165 |
+
"cell_type": "code",
|
| 166 |
+
"execution_count": null,
|
| 167 |
+
"metadata": {},
|
| 168 |
+
"outputs": [],
|
| 169 |
+
"source": [
|
| 170 |
+
"import json\n",
|
| 171 |
+
"from datasets import load_dataset\n",
|
| 172 |
+
"\n",
|
| 173 |
+
"from menu.utils import split_dataset\n",
|
| 174 |
+
"from menu.donut import DonutFinetuned\n",
|
| 175 |
+
"\n",
|
| 176 |
+
"dataset = load_dataset(\"ryanlinjui/menu-zh-TW\")\n",
|
| 177 |
+
"dataset = split_dataset(dataset[\"train\"], train=0.8, validation=0.1, test=0.1, seed=42) # (optional) use it if your dataset is not split into train/validation/test\n",
|
| 178 |
+
"donut_finetuned = DonutFinetuned(pretrained_model_repo_id=\"ryanlinjui/donut-base-finetuned-menu\")\n",
|
| 179 |
+
"scores, output_list = donut_finetuned.evaluate(dataset=dataset[\"test\"], ground_truth_key=\"menu\")\n",
|
| 180 |
+
"\n",
|
| 181 |
+
"print(\"Evaluation scores:\")\n",
|
| 182 |
+
"for key, value in scores.items():\n",
|
| 183 |
+
" print(f\"{key}: {value}\")\n",
|
| 184 |
+
"\n",
|
| 185 |
+
"print(\"\\nSample outputs:\")\n",
|
| 186 |
+
"for output in output_list[:5]:\n",
|
| 187 |
+
" print(json.dumps(output, ensure_ascii=False, indent=4))"
|
| 188 |
+
]
|
| 189 |
+
},
|
| 190 |
+
{
|
| 191 |
+
"cell_type": "markdown",
|
| 192 |
+
"metadata": {},
|
| 193 |
+
"source": [
|
| 194 |
+
"# Test Donut Model"
|
| 195 |
+
]
|
| 196 |
+
},
|
| 197 |
+
{
|
| 198 |
+
"cell_type": "code",
|
| 199 |
+
"execution_count": null,
|
| 200 |
+
"metadata": {},
|
| 201 |
+
"outputs": [],
|
| 202 |
+
"source": [
|
| 203 |
+
"from PIL import Image\n",
|
| 204 |
+
"from menu.donut import DonutFinetuned\n",
|
| 205 |
+
"\n",
|
| 206 |
+
"image = Image.open(\"./examples/menu-hd.jpg\")\n",
|
| 207 |
+
"\n",
|
| 208 |
+
"donut_finetuned = DonutFinetuned(pretrained_model_repo_id=\"ryanlinjui/donut-base-finetuned-menu\")\n",
|
| 209 |
+
"outputs = donut_finetuned.predict(image=image)\n",
|
| 210 |
+
"print(outputs)"
|
| 211 |
+
]
|
| 212 |
+
}
|
| 213 |
+
],
|
| 214 |
+
"metadata": {
|
| 215 |
+
"kernelspec": {
|
| 216 |
+
"display_name": "menu-text-detection",
|
| 217 |
+
"language": "python",
|
| 218 |
+
"name": "python3"
|
| 219 |
+
},
|
| 220 |
+
"language_info": {
|
| 221 |
+
"codemirror_mode": {
|
| 222 |
+
"name": "ipython",
|
| 223 |
+
"version": 3
|
| 224 |
+
},
|
| 225 |
+
"file_extension": ".py",
|
| 226 |
+
"mimetype": "text/x-python",
|
| 227 |
+
"name": "python",
|
| 228 |
+
"nbconvert_exporter": "python",
|
| 229 |
+
"pygments_lexer": "ipython3",
|
| 230 |
+
"version": "3.11.12"
|
| 231 |
+
}
|
| 232 |
+
},
|
| 233 |
+
"nbformat": 4,
|
| 234 |
+
"nbformat_minor": 2
|
| 235 |
+
}
|
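Note on `warmup_steps` in train.ipynb: the inline comment derives 75 from `198/8*30/10`, which is 10% of the optimizer steps for 30 epochs, while the same call passes `epochs=15`. The sketch below recomputes the figure so the two can be compared. `warmup_steps_for` is a hypothetical helper, and rounding up is an assumption chosen to match the 75 in the comment. Note also that the `configure_optimizers` method shown in menu/donut.py returns a bare Adam optimizer, so within this section the warmup value is stored in the config but no scheduler reads it.

```python
import math


def warmup_steps_for(num_samples: int, batch_size: int, epochs: int, warmup_percent: int) -> int:
    """Warmup steps as a percentage of all optimizer steps, rounded up."""
    # Integer numerator and denominator keep the division free of float drift.
    return math.ceil(num_samples * epochs * warmup_percent / (batch_size * 100))


as_commented = warmup_steps_for(198, 8, 30, 10)  # the formula in the notebook comment
as_configured = warmup_steps_for(198, 8, 15, 10)  # the epochs value actually passed
```

If a scheduler is added later, it may be worth deciding which of the two epoch counts the warmup length should follow.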
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|