Spaces:
Sleeping
Sleeping
| import os | |
| import io | |
| import sys | |
| import time | |
| import signal | |
| import asyncio | |
| import torch | |
| import cv2 | |
| import json | |
| import base64 | |
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image | |
| from transformers import DonutProcessor, VisionEncoderDecoderModel | |
| # Fix Python 3.13 asyncio garbage-collection bug on container restart. | |
| # See: https://github.com/python/cpython/issues/109496 | |
| _original_del = asyncio.BaseEventLoop.__del__ | |
| def _safe_loop_del(self): | |
| try: | |
| _original_del(self) | |
| except (ValueError, OSError): | |
| pass # ignore already-closed file descriptors | |
| asyncio.BaseEventLoop.__del__ = _safe_loop_del | |
# Configuration
MODEL_REPO = "Awarebeyond/receipt-donut"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading model from {MODEL_REPO}...")
print(f"Using device: {DEVICE}")

# Both objects are loaded once at import time and shared by every request.
processor = DonutProcessor.from_pretrained(MODEL_REPO)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_REPO)
model.to(DEVICE)  # move the weights to the selected device
model.eval()  # switch to evaluation mode; this app only runs inference
def preprocess_image(pil_image):
    """Convert a receipt image to the fixed-size grayscale-as-RGB form.

    The original author notes that these steps are meant to mirror the
    preprocessing used at training time, so the resize target and the
    interpolation mode are kept exactly as they were.

    Args:
        pil_image: A ``PIL.Image.Image`` in any mode that Pillow can convert
            to RGB (RGB, RGBA, L, P, ...).

    Returns:
        A new RGB ``PIL.Image.Image`` that is 1536 pixels wide and 1152
        pixels high, with all three channels carrying the same gray value.
    """
    # Normalise the mode first: palette or single-channel inputs would
    # otherwise yield a 2-D array that the colour conversion below rejects.
    rgb = np.array(pil_image.convert("RGB"))
    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(rgb, (1536, 1152), interpolation=cv2.INTER_LINEAR)
    gray = cv2.cvtColor(resized, cv2.COLOR_RGB2GRAY)
    # Replicate the gray plane into three channels for the image processor.
    return Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
def process_single(pil_image):
    """Run the receipt model on one image and return the decoded result.

    Args:
        pil_image: The image to analyse; it is passed through
            ``preprocess_image`` before being handed to the processor.

    Returns:
        One of:
        * the value produced by ``json.loads`` on the decoded text;
        * ``{"raw_output": text}`` when the decoded text is not valid JSON;
        * ``{"error": message}`` when any step raised an exception.
    """
    try:
        prepared = preprocess_image(pil_image)
        encoded = processor(prepared, return_tensors="pt")
        pixel_values = encoded.pixel_values.to(DEVICE)

        tokenizer = processor.tokenizer
        start_token_id = model.config.decoder_start_token_id
        # Seed the decoder with a single start token (batch of one).
        decoder_input_ids = torch.tensor([[start_token_id]]).to(DEVICE)

        with torch.no_grad():
            generated = model.generate(
                pixel_values,
                decoder_input_ids=decoder_input_ids,
                max_length=512,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
                # Forbid the unknown token from being generated.
                bad_words_ids=[[tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )

        text = tokenizer.batch_decode(generated.sequences)[0]
        # Strip the special tokens that frame the payload, in the same
        # order as before: EOS, padding, then the decoder start token.
        markers = (
            tokenizer.eos_token,
            tokenizer.pad_token,
            tokenizer.decode([start_token_id]),
        )
        for marker in markers:
            text = text.replace(marker, "")
        text = text.strip()

        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return {"raw_output": text}
    except Exception as exc:
        return {"error": str(exc)}
def extract_receipt_single(image):
    """Handle the Single tab: extract one receipt and pretty-print it.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        The result of ``process_single`` serialised as indented JSON text
        (non-ASCII characters kept as-is), or a short prompt asking for an
        upload when ``image`` is ``None``.
    """
    if image is not None:
        extracted = process_single(image)
        return json.dumps(extracted, indent=2, ensure_ascii=False)
    return "Please upload an image."
def extract_receipt_batch(files, progress=gr.Progress()):
    """Run extraction over several uploaded files.

    Args:
        files: Uploaded items. Each is either a path string or an object
            whose ``name`` attribute holds the path.
        progress: Gradio progress tracker; called once before the loop and
            once after every file.

    Returns:
        A ``(results, gallery_images, output_json)`` tuple:
        * ``results``: one dict per file with the keys ``filename``,
          ``merchant``, ``date``, ``subtotal``, ``tax``, ``total`` and
          ``full_json``;
        * ``gallery_images``: the RGB images that could be opened;
        * ``output_json``: ``results`` serialised as indented JSON.
        When ``files`` is empty the tuple is
        ``([], None, "No images uploaded.")``.
    """
    if not files:
        return [], None, "No images uploaded."

    results = []
    gallery_images = []
    total = len(files)
    progress(0, total=total, desc="Processing receipts...")

    for done, file in enumerate(files, start=1):
        # Resolve the name up front so a failure is still reported against
        # the file that caused it instead of a generic "error" label.
        path = file if isinstance(file, str) else getattr(file, "name", str(file))
        filename = os.path.basename(path)
        try:
            # Close the underlying file as soon as the RGB copy exists.
            with Image.open(path) as source:
                img = source.convert("RGB")
            gallery_images.append(img)

            parsed = process_single(img)
            if not isinstance(parsed, dict):
                # Valid JSON that is not an object (list, number, ...) has
                # no fields to look up; keep it under a single key.
                parsed = {"raw_output": parsed}

            results.append(
                {
                    "filename": filename,
                    "merchant": parsed.get("merchant", "N/A"),
                    "date": parsed.get("date", "N/A"),
                    "subtotal": parsed.get("subtotal", "N/A"),
                    "tax": parsed.get("tax", "N/A"),
                    "total": parsed.get("total", "N/A"),
                    "full_json": json.dumps(parsed, ensure_ascii=False),
                }
            )
        except Exception as e:
            # Best effort: one unreadable file must not abort the batch.
            results.append(
                {
                    "filename": filename,
                    "merchant": "ERROR",
                    "date": "ERROR",
                    "subtotal": "ERROR",
                    "tax": "ERROR",
                    "total": "ERROR",
                    "full_json": str(e),
                }
            )
        progress(done, total=total)

    # Build downloadable JSON
    output_json = json.dumps(results, indent=2, ensure_ascii=False)
    return results, gallery_images, output_json
def create_download_file(json_str):
    """Write *json_str* to a temporary ``.json`` file for downloading.

    Args:
        json_str: The JSON text to store.

    Returns:
        The path of the newly created file, or ``None`` when ``json_str``
        is empty or ``None`` (there is nothing to download).
    """
    if not json_str:
        return None

    import tempfile

    # delete=False: the file has to outlive this function so that it can be
    # served afterwards. The temp directory is therefore not cleaned up here.
    with tempfile.NamedTemporaryFile(
        mode="w",
        suffix=".json",
        prefix="receipt_results_",
        delete=False,
        encoding="utf-8",
    ) as handle:
        handle.write(json_str)
    return handle.name


# -- Gradio Blocks App --------------------------------------------------
with gr.Blocks(title="π§Ύπ© Receipt Donut") as demo:
    gr.Markdown(
        """
        # π§Ύπ© Receipt Donut β Live Receipt Extraction
        **Fine-tuned Donut model for structured receipt extraction.**
        Upload one or multiple receipt images to instantly extract merchant, date, subtotal, tax, and total.
        """
    )
    with gr.Tabs():
        # -- Single Upload Tab ------------------------------------------
        with gr.TabItem("π Single Receipt"):
            with gr.Row():
                with gr.Column(scale=1):
                    single_input = gr.Image(
                        type="pil",
                        label="Upload Receipt Image",
                        sources=["upload", "clipboard"],
                    )
                    single_btn = gr.Button(
                        "π Extract Data", variant="primary", size="lg"
                    )
                with gr.Column(scale=1):
                    single_output = gr.Code(
                        label="Extracted JSON",
                        language="json",
                        lines=18,
                    )
            single_btn.click(
                fn=extract_receipt_single,
                inputs=single_input,
                outputs=single_output,
            )
        # -- Batch Upload Tab -------------------------------------------
        with gr.TabItem("π Batch Processing"):
            gr.Markdown(
                "Upload multiple receipts at once. Results appear in a table with JSON download. "
                "**Note:** For bulk uploads, only the first 50 thumbnails are shown to keep the page fast."
            )
            batch_files = gr.File(
                label="Upload Receipt Images",
                file_count="multiple",
                file_types=["image"],
                height=120,  # compact file list
            )
            batch_btn = gr.Button(
                "π Process All Receipts", variant="primary", size="lg"
            )
            batch_status = gr.Textbox(
                label="Status",
                interactive=False,
                value="Ready β upload images to begin",
            )
            batch_gallery = gr.Gallery(
                label="Receipt Thumbnails (first 50 shown)",
                columns=10,
                rows=5,
                height=400,
                object_fit="cover",
                preview=True,
            )
            batch_results = gr.Dataframe(
                headers=["Filename", "Merchant", "Date", "Subtotal", "Tax", "Total", "Full JSON"],
                label="Extraction Results",
                wrap=True,
                column_widths=["15%", "15%", "10%", "12%", "10%", "12%", "26%"],
            )
            batch_json = gr.State("")
            batch_download = gr.DownloadButton(
                label="β¬οΈ Download All Results (JSON)",
                variant="secondary",
            )

            def on_batch_click(files, progress=gr.Progress()):
                """Process the uploaded batch and fill every batch widget.

                Returns, in order: table rows, gallery images (at most 50),
                the results as JSON text, a status message, and the path of
                the JSON file offered by the download button (``None`` when
                nothing was processed).
                """
                if not files:
                    return [], None, "", "No images uploaded.", None
                if len(files) > 200:
                    return [], None, "", "β οΈ Please upload 200 or fewer receipts at once.", None
                # Hand the tracker received by this handler to the worker so
                # its per-file updates belong to the running event.
                table_data, gallery_imgs, json_str = extract_receipt_batch(files, progress)
                # Show only first 50 thumbnails for performance
                display_imgs = gallery_imgs[:50] if gallery_imgs else None
                rows = [
                    [
                        r["filename"],
                        r["merchant"],
                        r["date"],
                        r["subtotal"],
                        r["tax"],
                        r["total"],
                        r["full_json"],
                    ]
                    for r in table_data
                ]
                status_msg = f"Processed {len(rows)} receipt(s)"
                if len(rows) > 50:
                    status_msg += " β showing first 50 thumbnails"
                # The download button needs a file path as its value, not
                # the JSON text itself, so write the results to disk now.
                download_path = create_download_file(json_str)
                return rows, display_imgs, json_str, status_msg, download_path

            # batch_download is an output here so the file is already
            # attached to the button by the time the user clicks it. No
            # separate click handler is registered: the previous one put raw
            # JSON text where a file path is expected.
            batch_btn.click(
                fn=on_batch_click,
                inputs=batch_files,
                outputs=[
                    batch_results,
                    batch_gallery,
                    batch_json,
                    batch_status,
                    batch_download,
                ],
            )
        # -- Info Tab ---------------------------------------------------
        with gr.TabItem("βΉοΈ About"):
            gr.Markdown(
                """
                ### Model Details
                - **Architecture:** Donut (Vision Encoder + Text Decoder)
                - **Fine-tuned on:** 8,615 real-world receipt images
                - **Training hardware:** Google Cloud L4 GPU (bf16 mixed precision)
                - **Base model:** `naver-clova-ix/donut-base`
                ### Extracted Fields
                | Field | Description |
                |-------|-------------|
                | `merchant` | Store or company name |
                | `date` | Transaction date |
                | `subtotal` | Amount before tax |
                | `tax` | Tax amount |
                | `total` | Final amount |
                ### GitHub
                [Awarebeyond/receipt-donut](https://huggingface.co/Awarebeyond/receipt-donut)
                """
            )
# Clean shutdown handler: prevents asyncio garbage-collection errors
# when HF Spaces sends SIGTERM during container restart
def _handle_sigterm(signum, frame):
    """Announce the SIGTERM and exit with status 0.

    ``sys.exit`` raises ``SystemExit``; the keep-alive loop at the bottom of
    this file catches it and prints the final shutdown message.
    """
    print("\nReceived SIGTERM, shutting down gracefully...")
    sys.exit(0)


# Register the handler before the server starts.
signal.signal(signal.SIGTERM, _handle_sigterm)

# With prevent_thread_lock=True, launch() is expected to return instead of
# blocking (see the Gradio launch() documentation), so the main thread is
# kept alive by the loop below.
demo.launch(theme=gr.themes.Soft(), prevent_thread_lock=True)
try:
    # Idle in one-second steps until Ctrl+C or the SIGTERM handler's
    # SystemExit interrupts the sleep.
    while True:
        time.sleep(1)
except (KeyboardInterrupt, SystemExit):
    print("Shutdown complete.")