Awarebeyond's picture
Remove unsupported show_download_button from Gallery for Gradio compatibility
f04e377 verified
import os
import io
import sys
import time
import signal
import asyncio
import torch
import cv2
import json
import base64
import numpy as np
import gradio as gr
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel
# Work around the Python 3.13 asyncio garbage-collection bug that shows up
# when the container restarts.
# See: https://github.com/python/cpython/issues/109496
_original_del = asyncio.BaseEventLoop.__del__


def _safe_loop_del(self):
    """Call the stock event-loop finalizer, ignoring ValueError/OSError."""
    try:
        _original_del(self)
    except (OSError, ValueError):
        # Already-closed file descriptors: nothing left to clean up.
        return


# Install the tolerant finalizer for every event loop in this process.
asyncio.BaseEventLoop.__del__ = _safe_loop_del
# Configuration
MODEL_REPO = "Awarebeyond/receipt-donut"
# Prefer the GPU when torch can see one; otherwise run on the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

print(f"Loading model from {MODEL_REPO}...")
print(f"Using device: {DEVICE}")

# Both objects are loaded once at import time and shared by every request.
processor = DonutProcessor.from_pretrained(MODEL_REPO)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_REPO).to(DEVICE)
model.eval()  # evaluation mode: this app only runs inference
def preprocess_image(pil_image):
    """Resize a receipt image and flatten it to grayscale stored as RGB.

    Per the original author's note, these steps are meant to match the
    preprocessing used at training time: resize to a fixed 1536x1152
    (width x height) frame with bilinear interpolation, collapse to a single
    gray channel, then expand that channel back to three.

    Args:
        pil_image: A PIL image. Any mode is accepted; it is converted to RGB
            before being handed to OpenCV.

    Returns:
        A new PIL image built from the 3-channel grayscale array.
    """
    # Normalise the mode first so grayscale, palette or alpha inputs reach
    # OpenCV as a plain 3-channel RGB array instead of failing in cvtColor.
    rgb = np.array(pil_image.convert("RGB"))
    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(rgb, (1536, 1152), interpolation=cv2.INTER_LINEAR)
    gray = cv2.cvtColor(resized, cv2.COLOR_RGB2GRAY)
    # Replicate the gray channel so downstream code still sees three channels.
    return Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
def process_single(pil_image):
    """Run the model on one receipt image and always return a dict.

    Args:
        pil_image: PIL image of a receipt.

    Returns:
        dict: one of
            * the decoded JSON object, when the generated text parses to a
              JSON object;
            * ``{"raw_output": <text>}`` when the text is not valid JSON or
              parses to something other than an object (list, number, null);
            * ``{"error": <message>}`` when any step raises.
    """
    try:
        processed = preprocess_image(pil_image)
        pixel_values = processor(processed, return_tensors="pt").pixel_values.to(DEVICE)
        tokenizer = processor.tokenizer
        start_token_id = model.config.decoder_start_token_id
        decoder_input_ids = torch.tensor([[start_token_id]]).to(DEVICE)
        with torch.no_grad():
            outputs = model.generate(
                pixel_values,
                decoder_input_ids=decoder_input_ids,
                max_length=512,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
                bad_words_ids=[[tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )
        seq = tokenizer.batch_decode(outputs.sequences)[0]
        # Strip the special tokens (EOS, padding, decoder start) so only the
        # generated payload is left.
        seq = seq.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
        seq = seq.replace(tokenizer.decode([start_token_id]), "").strip()
        try:
            parsed = json.loads(seq)
        except json.JSONDecodeError:
            return {"raw_output": seq}
        # Callers use dict methods such as .get() on the result, so valid JSON
        # that is not an object is reported as raw text instead.
        if not isinstance(parsed, dict):
            return {"raw_output": seq}
        return parsed
    except Exception as e:
        # Deliberate catch-all: a failure is reported to the UI as data
        # rather than aborting the request.
        return {"error": str(e)}
def extract_receipt_single(image):
    """Handle the Single tab: return pretty-printed JSON for one image.

    When no image is supplied, a short prompt string is returned instead.
    """
    if image is not None:
        extracted = process_single(image)
        return json.dumps(extracted, indent=2, ensure_ascii=False)
    return "Please upload an image."
def _upload_path(file):
    """Return the filesystem path of an upload (a path, or an object with ``.name``)."""
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    return file.name


def extract_receipt_batch(files, progress=gr.Progress()):
    """Run extraction over several uploaded files.

    Args:
        files: Sequence of uploads, each either a filesystem path or an
            object whose ``name`` attribute is the path.
        progress: Progress callback, called once before the loop and once
            after every file.

    Returns:
        tuple: ``(results, gallery_images, output_json)`` where ``results``
        is a list with one summary dict per file, ``gallery_images`` is the
        list of images that could be opened, and ``output_json`` is
        ``results`` serialised as indented JSON. For empty input the tuple
        is ``([], None, "No images uploaded.")``.
    """
    if not files:
        return [], None, "No images uploaded."
    results = []
    gallery_images = []
    total = len(files)
    progress(0, total=total, desc="Processing receipts...")
    for i, file in enumerate(files):
        # Placeholder until the real name is known, so a failure while
        # resolving the path still produces a row.
        filename = "error"
        try:
            path = _upload_path(file)
            filename = os.path.basename(path)
            # convert() returns an independent copy, so the file handle can
            # be released straight away instead of waiting for the GC.
            with Image.open(path) as source:
                img = source.convert("RGB")
            gallery_images.append(img)
            parsed = process_single(img)
            results.append(
                {
                    "filename": filename,
                    "merchant": parsed.get("merchant", "N/A"),
                    "date": parsed.get("date", "N/A"),
                    "subtotal": parsed.get("subtotal", "N/A"),
                    "tax": parsed.get("tax", "N/A"),
                    "total": parsed.get("total", "N/A"),
                    "full_json": json.dumps(parsed, ensure_ascii=False),
                }
            )
        except Exception as e:
            # Best effort per file: one bad upload must not abort the batch.
            # The row keeps the real filename whenever it was resolved.
            results.append(
                {
                    "filename": filename,
                    "merchant": "ERROR",
                    "date": "ERROR",
                    "subtotal": "ERROR",
                    "tax": "ERROR",
                    "total": "ERROR",
                    "full_json": str(e),
                }
            )
        progress(i + 1, total=total)
    # Build downloadable JSON
    output_json = json.dumps(results, indent=2, ensure_ascii=False)
    return results, gallery_images, output_json
def create_download_file(json_str):
    """Return the JSON text intended for download.

    This is a pass-through: the string is handed back unchanged and no file
    is written.
    """
    return json_str
# ── Gradio Blocks App ──────────────────────────────────────────────
# Three tabs: single-image extraction, batch extraction with a table,
# thumbnail gallery and JSON download, and a static "About" page.
with gr.Blocks(title="🧾🍩 Receipt Donut") as demo:
    gr.Markdown(
        """
        # 🧾🍩 Receipt Donut — Live Receipt Extraction
        **Fine-tuned Donut model for structured receipt extraction.**
        Upload one or multiple receipt images to instantly extract merchant, date, subtotal, tax, and total.
        """
    )
    with gr.Tabs():
        # ── Single Upload Tab ────────────────────────────────────
        with gr.TabItem("📄 Single Receipt"):
            with gr.Row():
                with gr.Column(scale=1):
                    single_input = gr.Image(
                        type="pil",
                        label="Upload Receipt Image",
                        sources=["upload", "clipboard"],
                    )
                    single_btn = gr.Button(
                        "🔍 Extract Data", variant="primary", size="lg"
                    )
                with gr.Column(scale=1):
                    single_output = gr.Code(
                        label="Extracted JSON",
                        language="json",
                        lines=18,
                    )
            single_btn.click(
                fn=extract_receipt_single,
                inputs=single_input,
                outputs=single_output,
            )
        # ── Batch Upload Tab ─────────────────────────────────────
        with gr.TabItem("📁 Batch Processing"):
            gr.Markdown(
                "Upload multiple receipts at once. Results appear in a table with JSON download. "
                "**Note:** For bulk uploads, only the first 50 thumbnails are shown to keep the page fast."
            )
            batch_files = gr.File(
                label="Upload Receipt Images",
                file_count="multiple",
                file_types=["image"],
                height=120,  # compact file list
            )
            batch_btn = gr.Button(
                "🚀 Process All Receipts", variant="primary", size="lg"
            )
            batch_status = gr.Textbox(
                label="Status",
                interactive=False,
                value="Ready — upload images to begin",
            )
            batch_gallery = gr.Gallery(
                label="Receipt Thumbnails (first 50 shown)",
                columns=10,
                rows=5,
                height=400,
                object_fit="cover",
                preview=True,
            )
            batch_results = gr.Dataframe(
                headers=["Filename", "Merchant", "Date", "Subtotal", "Tax", "Total", "Full JSON"],
                label="Extraction Results",
                wrap=True,
                column_widths=["15%", "15%", "10%", "12%", "10%", "12%", "26%"],
            )
            batch_json = gr.State("")
            batch_download = gr.DownloadButton(
                label="⬇️ Download All Results (JSON)",
                variant="secondary",
            )

            def on_batch_click(files, progress=gr.Progress()):
                """Process the uploaded batch and populate the batch widgets.

                Args:
                    files: Uploads from ``batch_files`` (may be empty/None).
                    progress: Progress tracker; declared on the event
                        handler and forwarded to ``extract_receipt_batch``.

                Returns:
                    tuple: table rows, thumbnails (at most 50, or None), the
                    results as a JSON string, a status message, and the path
                    of a JSON file for the download button (None when
                    nothing was processed).
                """
                import tempfile  # only needed here, to stage the download file

                if not files:
                    return [], None, "", "No images uploaded.", None
                if len(files) > 200:
                    return [], None, "", "⚠️ Please upload 200 or fewer receipts at once.", None
                table_data, gallery_imgs, json_str = extract_receipt_batch(
                    files, progress=progress
                )
                # Show only first 50 thumbnails for performance
                display_imgs = gallery_imgs[:50] if gallery_imgs else None
                rows = [
                    [
                        r["filename"],
                        r["merchant"],
                        r["date"],
                        r["subtotal"],
                        r["tax"],
                        r["total"],
                        r["full_json"],
                    ]
                    for r in table_data
                ]
                status_msg = f"Processed {len(rows)} receipt(s)"
                if len(rows) > 50:
                    status_msg += " — showing first 50 thumbnails"
                # The download button serves a file, so the JSON text is
                # written to disk and the button receives the file path.
                with tempfile.NamedTemporaryFile(
                    "w",
                    suffix=".json",
                    prefix="receipt_results_",
                    delete=False,
                    encoding="utf-8",
                ) as handle:
                    handle.write(json_str)
                return rows, display_imgs, json_str, status_msg, handle.name

            # The download button's value is set here; it needs no click
            # handler of its own to serve the staged file.
            batch_btn.click(
                fn=on_batch_click,
                inputs=batch_files,
                outputs=[
                    batch_results,
                    batch_gallery,
                    batch_json,
                    batch_status,
                    batch_download,
                ],
            )
        # ── Info Tab ───────────────────────────────────────────────
        with gr.TabItem("ℹ️ About"):
            gr.Markdown(
                """
                ### Model Details
                - **Architecture:** Donut (Vision Encoder + Text Decoder)
                - **Fine-tuned on:** 8,615 real-world receipt images
                - **Training hardware:** Google Cloud L4 GPU (bf16 mixed precision)
                - **Base model:** `naver-clova-ix/donut-base`
                ### Extracted Fields
                | Field | Description |
                |-------|-------------|
                | `merchant` | Store or company name |
                | `date` | Transaction date |
                | `subtotal` | Amount before tax |
                | `tax` | Tax amount |
                | `total` | Final amount |
                ### GitHub
                [Awarebeyond/receipt-donut](https://huggingface.co/Awarebeyond/receipt-donut)
                """
            )
# Clean shutdown handler: prevents asyncio garbage-collection errors
# when HF Spaces sends SIGTERM during container restart
def _handle_sigterm(signum, frame):
    """Announce the shutdown and exit the process with status 0."""
    print("\nReceived SIGTERM, shutting down gracefully...")
    raise SystemExit(0)


signal.signal(signal.SIGTERM, _handle_sigterm)

# launch() returns immediately because of prevent_thread_lock, so the main
# thread is kept alive by the sleep loop below.
demo.launch(prevent_thread_lock=True, theme=gr.themes.Soft())
try:
    while True:
        time.sleep(1)
except (SystemExit, KeyboardInterrupt):
    print("Shutdown complete.")