"""Gradio app for structured receipt extraction with a fine-tuned Donut model.

The app exposes two workflows:

* a single-receipt tab that shows the extracted JSON for one image, and
* a batch tab that processes many images, shows a results table plus
  thumbnails, and offers the combined results as a downloadable JSON file.
"""

import os
import io
import sys
import time
import signal
import asyncio
import tempfile
import json
import base64

import torch
import cv2
import numpy as np
import gradio as gr
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Fix Python 3.13 asyncio garbage-collection bug on container restart.
# See: https://github.com/python/cpython/issues/109496
_original_del = asyncio.BaseEventLoop.__del__


def _safe_loop_del(self):
    """Run the original event-loop finalizer, ignoring closed-descriptor errors."""
    try:
        _original_del(self)
    except (ValueError, OSError):
        pass  # ignore already-closed file descriptors


asyncio.BaseEventLoop.__del__ = _safe_loop_del

# Configuration
MODEL_REPO = "Awarebeyond/receipt-donut"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Target size handed to cv2.resize, which expects (width, height).
TARGET_SIZE = (1536, 1152)
# Upper bound on the number of files accepted by one batch run.
MAX_BATCH_SIZE = 200
# Number of thumbnails rendered in the gallery after a batch run.
MAX_THUMBNAILS = 50

print(f"Loading model from {MODEL_REPO}...")
print(f"Using device: {DEVICE}")

processor = DonutProcessor.from_pretrained(MODEL_REPO)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_REPO)
model.to(DEVICE)
model.eval()


def preprocess_image(pil_image):
    """Resize the image and collapse it to 3-channel grayscale.

    Args:
        pil_image: A PIL image in any mode; it is converted to RGB first so
            palette, grayscale, and RGBA inputs are handled uniformly.

    Returns:
        A PIL RGB image of size ``TARGET_SIZE`` whose three channels all hold
        the grayscale intensity.
    """
    rgb = np.array(pil_image.convert("RGB"))
    # The original pipeline converted RGB -> BGR -> RGB before resizing; that
    # round trip is an exact identity, so it is omitted here.
    rgb = cv2.resize(rgb, TARGET_SIZE, interpolation=cv2.INTER_LINEAR)
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    return Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))


def process_single(pil_image):
    """Run the model on one image and return the result as a dict.

    Args:
        pil_image: The receipt image as a PIL image.

    Returns:
        A dict that is one of:

        * the decoded model output parsed as a JSON object,
        * ``{"raw_output": <text>}`` when the output is not a JSON object, or
        * ``{"error": <message>}`` when preprocessing or generation raised.
    """
    try:
        processed = preprocess_image(pil_image)
        pixel_values = processor(processed, return_tensors="pt").pixel_values.to(DEVICE)
        decoder_input_ids = torch.tensor(
            [[model.config.decoder_start_token_id]], device=DEVICE
        )

        with torch.no_grad():
            outputs = model.generate(
                pixel_values,
                decoder_input_ids=decoder_input_ids,
                max_length=512,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                bad_words_ids=[[processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )

        tokenizer = processor.tokenizer
        seq = tokenizer.batch_decode(outputs.sequences)[0]
        # Strip the special tokens that surround the generated payload.
        start_token = tokenizer.decode([model.config.decoder_start_token_id])
        for special in (tokenizer.eos_token, tokenizer.pad_token, start_token):
            seq = seq.replace(special, "")
        seq = seq.strip()

        try:
            parsed = json.loads(seq)
        except json.JSONDecodeError:
            return {"raw_output": seq}
        # Valid JSON that is not an object (list, number, string) would break
        # callers that use ``.get``; surface it as raw output instead.
        if not isinstance(parsed, dict):
            return {"raw_output": seq}
        return parsed
    except Exception as e:
        # Top-level boundary for the UI: report the failure instead of raising.
        return {"error": str(e)}


def extract_receipt_single(image):
    """Single image inference for the Single tab.

    Args:
        image: The uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        Pretty-printed JSON text of the extraction result, or a prompt asking
        the user to upload an image.
    """
    if image is None:
        return "Please upload an image."
    result = process_single(image)
    return json.dumps(result, indent=2, ensure_ascii=False)


def extract_receipt_batch(files, progress=gr.Progress()):
    """Batch inference for multiple files.

    Args:
        files: Uploaded files; each item is either a path string or an object
            exposing the path as ``.name``.
        progress: Gradio progress tracker updated after every file.

    Returns:
        A 3-tuple ``(results, gallery_images, output_json)`` where ``results``
        is a list of per-file dicts, ``gallery_images`` is the list of images
        that could be opened, and ``output_json`` is ``results`` serialized as
        indented JSON. For empty input the tuple is
        ``([], None, "No images uploaded.")``.
    """
    if not files:
        return [], None, "No images uploaded."

    results = []
    gallery_images = []
    total = len(files)
    progress(0, total=total, desc="Processing receipts...")

    for i, file in enumerate(files):
        # Resolve the name before the try block so error rows can still be
        # attributed to the file that failed.
        path = file if isinstance(file, str) else getattr(file, "name", str(file))
        filename = os.path.basename(path)
        try:
            # convert() returns a fully loaded copy, so the file handle can be
            # released as soon as the context exits.
            with Image.open(path) as source:
                img = source.convert("RGB")
            gallery_images.append(img)
            parsed = process_single(img)
            results.append(
                {
                    "filename": filename,
                    "merchant": parsed.get("merchant", "N/A"),
                    "date": parsed.get("date", "N/A"),
                    "subtotal": parsed.get("subtotal", "N/A"),
                    "tax": parsed.get("tax", "N/A"),
                    "total": parsed.get("total", "N/A"),
                    "full_json": json.dumps(parsed, ensure_ascii=False),
                }
            )
        except Exception as e:
            # One unreadable file must not abort the rest of the batch.
            results.append(
                {
                    "filename": filename,
                    "merchant": "ERROR",
                    "date": "ERROR",
                    "subtotal": "ERROR",
                    "tax": "ERROR",
                    "total": "ERROR",
                    "full_json": str(e),
                }
            )
        progress(i + 1, total=total)

    # Build downloadable JSON
    output_json = json.dumps(results, indent=2, ensure_ascii=False)
    return results, gallery_images, output_json


def create_download_file(json_str):
    """Create a temporary file for downloading.

    Args:
        json_str: The JSON text to offer for download.

    Returns:
        The path of a newly written UTF-8 ``.json`` file in the system temp
        directory, or ``None`` when ``json_str`` is empty.
    """
    if not json_str:
        return None
    # delete=False: the file must outlive this function so it can be served.
    with tempfile.NamedTemporaryFile(
        mode="w",
        suffix=".json",
        prefix="receipt_results_",
        delete=False,
        encoding="utf-8",
    ) as handle:
        handle.write(json_str)
        return handle.name


# ── Gradio Blocks App ──────────────────────────────────────────────
with gr.Blocks(title="🧾🍩 Receipt Donut") as demo:
    gr.Markdown(
        """
        # 🧾🍩 Receipt Donut — Live Receipt Extraction
        **Fine-tuned Donut model for structured receipt extraction.**
        Upload one or multiple receipt images to instantly extract merchant, date, subtotal, tax, and total.
        """
    )

    with gr.Tabs():
        # ── Single Upload Tab ────────────────────────────────────
        with gr.TabItem("📄 Single Receipt"):
            with gr.Row():
                with gr.Column(scale=1):
                    single_input = gr.Image(
                        type="pil",
                        label="Upload Receipt Image",
                        sources=["upload", "clipboard"],
                    )
                    single_btn = gr.Button(
                        "🔍 Extract Data", variant="primary", size="lg"
                    )
                with gr.Column(scale=1):
                    single_output = gr.Code(
                        label="Extracted JSON",
                        language="json",
                        lines=18,
                    )

            single_btn.click(
                fn=extract_receipt_single,
                inputs=single_input,
                outputs=single_output,
            )

        # ── Batch Upload Tab ─────────────────────────────────────
        with gr.TabItem("📁 Batch Processing"):
            gr.Markdown(
                "Upload multiple receipts at once. Results appear in a table with JSON download. "
                "**Note:** For bulk uploads, only the first 50 thumbnails are shown to keep the page fast."
            )

            batch_files = gr.File(
                label="Upload Receipt Images",
                file_count="multiple",
                file_types=["image"],
                height=120,  # compact file list
            )
            batch_btn = gr.Button(
                "🚀 Process All Receipts", variant="primary", size="lg"
            )
            batch_status = gr.Textbox(
                label="Status",
                interactive=False,
                value="Ready — upload images to begin",
            )
            batch_gallery = gr.Gallery(
                label="Receipt Thumbnails (first 50 shown)",
                columns=10,
                rows=5,
                height=400,
                object_fit="cover",
                preview=True,
            )
            batch_results = gr.Dataframe(
                headers=["Filename", "Merchant", "Date", "Subtotal", "Tax", "Total", "Full JSON"],
                label="Extraction Results",
                wrap=True,
                column_widths=["15%", "15%", "10%", "12%", "10%", "12%", "26%"],
            )
            batch_json = gr.State("")
            batch_download = gr.DownloadButton(
                label="⬇️ Download All Results (JSON)",
                variant="secondary",
            )

            def on_batch_click(files, progress=gr.Progress()):
                """Validate the upload, run the batch, and shape the UI outputs.

                The ``progress`` parameter is declared on this handler and
                forwarded to ``extract_receipt_batch`` so the per-file updates
                are reported against this event.

                Returns:
                    ``(table_rows, thumbnails, json_text, status_message)``.
                """
                if not files:
                    return [], None, "", "No images uploaded."
                if len(files) > MAX_BATCH_SIZE:
                    return (
                        [],
                        None,
                        "",
                        f"⚠️ Please upload {MAX_BATCH_SIZE} or fewer receipts at once.",
                    )

                table_data, gallery_imgs, json_str = extract_receipt_batch(
                    files, progress
                )

                # Show only first 50 thumbnails for performance
                display_imgs = gallery_imgs[:MAX_THUMBNAILS] if gallery_imgs else None

                columns = (
                    "filename",
                    "merchant",
                    "date",
                    "subtotal",
                    "tax",
                    "total",
                    "full_json",
                )
                rows = [[record[col] for col in columns] for record in table_data]

                status_msg = f"Processed {len(rows)} receipt(s)"
                if len(rows) > MAX_THUMBNAILS:
                    status_msg += f" — showing first {MAX_THUMBNAILS} thumbnails"

                return rows, display_imgs, json_str, status_msg

            # After the batch finishes, materialize the JSON as a file and set
            # it as the download button's value; the button needs a file path,
            # not the JSON text itself.
            batch_btn.click(
                fn=on_batch_click,
                inputs=batch_files,
                outputs=[batch_results, batch_gallery, batch_json, batch_status],
            ).then(
                fn=create_download_file,
                inputs=batch_json,
                outputs=batch_download,
            )

        # ── Info Tab ───────────────────────────────────────────────
        with gr.TabItem("ℹ️ About"):
            gr.Markdown(
                """
                ### Model Details
                - **Architecture:** Donut (Vision Encoder + Text Decoder)
                - **Fine-tuned on:** 8,615 real-world receipt images
                - **Training hardware:** Google Cloud L4 GPU (bf16 mixed precision)
                - **Base model:** `naver-clova-ix/donut-base`

                ### Extracted Fields
                | Field | Description |
                |-------|-------------|
                | `merchant` | Store or company name |
                | `date` | Transaction date |
                | `subtotal` | Amount before tax |
                | `tax` | Tax amount |
                | `total` | Final amount |

                ### Model Repository
                [Awarebeyond/receipt-donut](https://huggingface.co/Awarebeyond/receipt-donut)
                """
            )


# Clean shutdown handler: prevents asyncio garbage-collection errors
# when HF Spaces sends SIGTERM during container restart
def _handle_sigterm(signum, frame):
    """Turn SIGTERM into a normal interpreter exit."""
    print("\nReceived SIGTERM, shutting down gracefully...")
    sys.exit(0)


signal.signal(signal.SIGTERM, _handle_sigterm)

demo.launch(theme=gr.themes.Soft(), prevent_thread_lock=True)

# launch() returns immediately because of prevent_thread_lock=True, so keep
# the main thread alive here until Ctrl+C or the SIGTERM handler exits.
try:
    while True:
        time.sleep(1)
except (KeyboardInterrupt, SystemExit):
    print("Shutdown complete.")