receipt_scanner / app.py
Roy7384's picture
Update app.py
e8e83e5 verified
Raw
History Blame Contribute Delete
15.9 kB
"""
Receipt Scanner β€” AI-powered receipt parser using MiniCPM-V 4.6
Deploy to Hugging Face Spaces (GPU T4 small or better recommended).
"""
# `spaces` MUST be imported before torch/transformers on HF Spaces —
# the package hooks into CUDA initialisation and raises a RuntimeError
# if anything has already touched CUDA before it loads.
# The try/except makes the same file work fine when running locally.
try:
    import spaces  # noqa: F401
except ImportError:
    # The `spaces` package is not installed in this environment.  Bind a
    # stand-in so that the `@spaces.GPU` decorator used further down still
    # resolves instead of raising NameError at import time.
    class _LocalSpaces:
        """Minimal replacement exposing a pass-through ``GPU`` decorator."""

        @staticmethod
        def GPU(func=None, **_options):
            """Return *func* unchanged; works as ``@GPU`` and ``@GPU(...)``."""
            if callable(func):
                return func
            # Called with keyword options only: hand back a no-op decorator.
            return lambda wrapped: wrapped

    spaces = _LocalSpaces()
import json
import re
import io
import base64
import numpy as np
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor
# ─────────────────────────────────────────────────────────────────────────────
# Config
# ─────────────────────────────────────────────────────────────────────────────
# Model identifier handed to both `AutoProcessor.from_pretrained` and
# `AutoModelForImageTextToText.from_pretrained`.
MODEL_ID = "openbmb/MiniCPM-V-4.6"
# Forwarded as `downsample_mode` to the chat template and to `generate`.
DOWNSAMPLE_MODE = "4x" # "4x" = finer detail, ideal for dense receipt text
# Forwarded as `max_slice_nums` to the chat template.
MAX_SLICE_NUMS = 36 # allow high-res slicing for sharp photos
# Forwarded as `max_new_tokens` to `generate` (caps the length of the reply).
MAX_NEW_TOKENS = 1200
# ─────────────────────────────────────────────────────────────────────────────
# Structured extraction prompt
# ─────────────────────────────────────────────────────────────────────────────
RECEIPT_PROMPT = """\
You are a precise receipt data extractor. Carefully read every part of the receipt image.
Return ONLY a valid JSON object β€” no markdown fences, no explanation, nothing else.
Use this exact schema (set any unknown field to null):
{
"store": {
"name": "string | null",
"address": "string | null",
"phone": "string | null"
},
"transaction": {
"date": "YYYY-MM-DD string | null",
"time": "HH:MM string | null",
"receipt_number": "string | null",
"cashier": "string | null"
},
"items": [
{
"name": "string",
"quantity": number,
"unit_price": number | null,
"total_price": number
}
],
"subtotal": number | null,
"discounts": number | null,
"tax": number | null,
"tax_rate": "string | null",
"total": number | null,
"payment": {
"method": "string | null",
"amount_tendered": number | null,
"change": number | null
},
"currency": "string"
}
Rules:
- Numbers must be numeric (e.g. 4.99), never strings.
- If quantity is not printed, assume 1.
- Extract EVERY line item you can see.
- For discounts/coupons, use a positive number (it will be shown as a deduction).
- Currency: use the symbol or 3-letter ISO code visible on the receipt (default "$").
"""
# ─────────────────────────────────────────────────────────────────────────────
# Utility — normalise escaped newlines emitted by some model responses
# (taken from the official MiniCPM-V 4.6 model card)
# ─────────────────────────────────────────────────────────────────────────────
# Group 1 captures spans that must stay untouched (``` fences, `inline code`,
# $$...$$, $...$, \(...\) and \[...\]).  The final alternative matches a
# literal backslash-escaped \r\n, \n or \r that is not itself preceded by a
# backslash.
_NL_PATTERN = re.compile(
    r"(```[\s\S]*?```|`[^`]+`|\$\$[\s\S]*?\$\$|\$[^$]+\$"
    r"|\\\([\s\S]*?\\\)|\\\[[\s\S]*?\\\])"
    r"|(?<!\\)(?:\\r\\n|\\[nr])"
)


def _normalize(text: str) -> str:
    """Turn literal ``\\n`` / ``\\r`` / ``\\r\\n`` sequences into real newlines.

    Spans captured by group 1 of ``_NL_PATTERN`` are copied through verbatim.
    Non-string input, and strings without any backslash, are returned as-is.
    """
    if isinstance(text, str) and "\\" in text:

        def _swap(hit):
            # A protected span was matched: keep it; otherwise emit a newline.
            protected = hit.group(1)
            return protected if protected else "\n"

        return _NL_PATTERN.sub(_swap, text)
    return text
# ─────────────────────────────────────────────────────────────────────────────
# Model — lazy-loaded on first inference (required for ZeroGPU)
# ─────────────────────────────────────────────────────────────────────────────
# Module-level cache filled by `_get_model` on its first call.
_processor = None
_model = None


def _get_model():
    """Return the ``(processor, model)`` pair, loading both on first use.

    The first call loads the processor and the model for ``MODEL_ID``, puts
    the model in eval mode and stores both in module globals; later calls
    return the cached objects.

    Returns:
        A ``(processor, model)`` tuple.
    """
    global _processor, _model
    if _model is None:
        print(f"Loading {MODEL_ID} …")
        _processor = AutoProcessor.from_pretrained(MODEL_ID)
        _model = AutoModelForImageTextToText.from_pretrained(
            MODEL_ID,
            torch_dtype="auto",
            device_map="cuda",
        )
        _model.eval()
        print("✓ Model ready")
    return _processor, _model
# ─────────────────────────────────────────────────────────────────────────────
# Inference
# ─────────────────────────────────────────────────────────────────────────────
def _to_pil(image) -> Image.Image:
    """Return *image* converted to an RGB PIL image.

    NumPy arrays are first wrapped with ``Image.fromarray``; any other input
    is used directly and must provide ``convert`` (e.g. a PIL image).
    """
    pil = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    return pil.convert("RGB")
@spaces.GPU
def run_model(pil_image: Image.Image) -> str:
    """Ask the model to read *pil_image* and return its decoded text reply.

    Builds a single user turn holding the image and ``RECEIPT_PROMPT``, runs
    greedy generation, strips the prompt tokens from the output, decodes the
    remainder and passes it through ``_normalize``.
    """
    processor, model = _get_model()

    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": pil_image},
            {"type": "text", "text": RECEIPT_PROMPT},
        ],
    }
    template_options = dict(
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
        downsample_mode=DOWNSAMPLE_MODE,
        max_slice_nums=MAX_SLICE_NUMS,
    )
    inputs = processor.apply_chat_template([user_turn], **template_options)
    inputs = inputs.to(model.device)

    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            downsample_mode=DOWNSAMPLE_MODE,
            do_sample=False,
        )

    # Each output sequence starts with its prompt; keep only what follows it.
    completions = []
    for prompt_ids, full_ids in zip(inputs["input_ids"], generated_ids):
        completions.append(full_ids[len(prompt_ids):])

    decoded = processor.batch_decode(
        completions,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return _normalize(decoded[0])
# ─────────────────────────────────────────────────────────────────────────────
# JSON extraction & formatting
# ─────────────────────────────────────────────────────────────────────────────
def _extract_json(raw: str) -> dict | None:
"""Strip markdown fences and parse the first JSON object found."""
# Remove ```json … ``` wrappers
cleaned = re.sub(r"^```(?:json)?\s*|\s*```$", "", raw.strip(), flags=re.MULTILINE)
match = re.search(r"\{[\s\S]*\}", cleaned)
if not match:
return None
try:
return json.loads(match.group())
except json.JSONDecodeError:
return None
def _fmt(value, sym: str = "") -> str:
if value is None:
return "β€”"
try:
return f"{sym}{float(value):.2f}"
except (TypeError, ValueError):
return str(value)
def build_markdown(d: dict) -> str:
lines: list[str] = []
# Currency symbol
raw_cur = d.get("currency") or "$"
sym = raw_cur if len(raw_cur) == 1 else "$"
# ── Store ────────────────────────────────────────────────────────────────
store = d.get("store") or {}
if store.get("name"):
lines.append(f"## πŸͺ {store['name']}")
if store.get("address"):
lines.append(f"πŸ“ {store['address']}")
if store.get("phone"):
lines.append(f"πŸ“ž {store['phone']}")
# ── Transaction metadata ─────────────────────────────────────────────────
tx = d.get("transaction") or {}
tx_lines = []
if tx.get("date"): tx_lines.append(f"πŸ“… **Date:** {tx['date']}")
if tx.get("time"): tx_lines.append(f"πŸ• **Time:** {tx['time']}")
if tx.get("receipt_number"): tx_lines.append(f"🧾 **Receipt #:** {tx['receipt_number']}")
if tx.get("cashier"): tx_lines.append(f"πŸ‘€ **Cashier:** {tx['cashier']}")
if tx_lines:
lines.append("")
lines.extend(tx_lines)
# ── Line items ───────────────────────────────────────────────────────────
items = d.get("items") or []
if items:
lines += ["", "---", "### πŸ›’ Items Purchased", ""]
for item in items:
name = item.get("name", "Unknown")
qty = item.get("quantity") or 1
total = item.get("total_price")
unit = item.get("unit_price")
unit_str = ""
if unit is not None and qty != 1:
unit_str = f" ({_fmt(unit, sym)} ea.)"
lines.append(f"- **{name}** Γ—{qty}{unit_str} &nbsp;β†’&nbsp; **{_fmt(total, sym)}**")
# ── Totals ───────────────────────────────────────────────────────────────
lines += ["", "---", ""]
if d.get("subtotal") is not None:
lines.append(f"Subtotal: &nbsp; {_fmt(d['subtotal'], sym)}")
if d.get("discounts") and float(d.get("discounts") or 0) != 0:
lines.append(f"Discounts: &nbsp; βˆ’{_fmt(abs(float(d['discounts'])), sym)}")
if d.get("tax") is not None:
rate_str = f" ({d['tax_rate']})" if d.get("tax_rate") else ""
lines.append(f"Tax{rate_str}: &nbsp; {_fmt(d['tax'], sym)}")
if d.get("total") is not None:
lines.append(f"\n### πŸ’° Total: {_fmt(d['total'], sym)}")
# ── Payment ──────────────────────────────────────────────────────────────
pay = d.get("payment") or {}
pay_lines = []
if pay.get("method"): pay_lines.append(f"πŸ’³ **Method:** {pay['method']}")
if pay.get("amount_tendered") is not None:
pay_lines.append(f"πŸ’΅ **Tendered:** {_fmt(pay['amount_tendered'], sym)}")
if pay.get("change") is not None:
pay_lines.append(f"πŸ”„ **Change:** {_fmt(pay['change'], sym)}")
if pay_lines:
lines.append("")
lines.extend(pay_lines)
# Currency code (only show when it's a 3-letter code, not a symbol)
if raw_cur and len(raw_cur) > 1:
lines.append(f"\n*Currency: {raw_cur}*")
return "\n".join(lines)
# ─────────────────────────────────────────────────────────────────────────────
# Top-level handler wired to Gradio
# ─────────────────────────────────────────────────────────────────────────────
def parse_receipt(image) -> tuple[str, str]:
    """Scan a receipt image and return ``(markdown_summary, json_string)``.

    Gradio calls this with a numpy array, or ``None`` when no image is set.
    Every failure is reported through the returned Markdown text rather than
    raised, so the UI always receives a value for both outputs.

    Args:
        image: The receipt image, or ``None``.

    Returns:
        A tuple of the Markdown summary and the pretty-printed JSON. The JSON
        part is ``"{}"`` whenever no structured data could be produced.
    """
    if image is None:
        return "⚠️ Please upload or capture a receipt image to begin.", "{}"

    try:
        pil_image = _to_pil(image)
    except Exception as exc:  # UI boundary: report instead of crashing
        return f"❌ Could not read image: {exc}", "{}"

    try:
        raw_text = run_model(pil_image)
    except Exception as exc:  # UI boundary: report instead of crashing
        return f"❌ Model error: {exc}", "{}"

    data = _extract_json(raw_text)
    if data is None:
        # Model returned non-JSON — show raw text as fallback
        return f"**Raw model output (JSON parse failed):**\n\n```\n{raw_text}\n```", "{}"

    json_str = json.dumps(data, indent=2, ensure_ascii=False)
    try:
        markdown = build_markdown(data)
    except Exception as exc:
        # The JSON itself is valid, so still return it alongside the error.
        markdown = f"❌ Could not format summary: {exc}"
    return markdown, json_str
# ─────────────────────────────────────────────────────────────────────────────
# Gradio UI
# ─────────────────────────────────────────────────────────────────────────────
TIPS = """\
**Tips for best results:**
- Hold the camera directly above the receipt (avoid angles)
- Make sure the receipt is fully visible and well-lit
- Flatten crumpled receipts before scanning
"""
# Two-column layout: image input, scan button and tips on the left; a tabbed
# Markdown summary and raw-JSON view on the right.  Markdown bodies are kept
# flush-left so their text is not indented inside the string.
with gr.Blocks(title="🧾 AI Receipt Scanner") as demo:
    gr.Markdown("""
# 🧾 AI Receipt Scanner
Upload a receipt photo or snap one with your camera.
The model extracts every line item, price, tax, and store metadata automatically.
""")
    with gr.Row(equal_height=False):
        # ── Input column ─────────────────────────────────────────────────────
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="Receipt Image",
                sources=["upload", "webcam", "clipboard"],
                type="numpy",
                height=500,
                image_mode="RGB",
            )
            scan_btn = gr.Button("🔍 Scan Receipt", variant="primary", size="lg")
            gr.Markdown(TIPS)
        # ── Output column ────────────────────────────────────────────────────
        with gr.Column(scale=1):
            with gr.Tabs():
                with gr.TabItem("📋 Summary"):
                    summary_out = gr.Markdown(
                        value="*Scan a receipt to see results here.*"
                    )
                with gr.TabItem("{ } Raw JSON"):
                    json_out = gr.Code(
                        value="{}",
                        language="json",
                        label="Structured JSON output",
                        interactive=False,
                    )
    # Wire up the button
    scan_btn.click(
        fn=parse_receipt,
        inputs=[image_input],
        outputs=[summary_out, json_out],
        api_name="scan",
    )
    # Also scan automatically when an image is uploaded/captured
    image_input.change(
        fn=parse_receipt,
        inputs=[image_input],
        outputs=[summary_out, json_out],
    )
    gr.Markdown("""
---
*Powered by [MiniCPM-V 4.6](https://huggingface.co/openbmb/MiniCPM-V-4.6) — a lightweight 1.3 B multimodal model.*
*Source: [OpenBMB / MiniCPM-V](https://github.com/OpenBMB/MiniCPM-V)*
""")
if __name__ == "__main__":
demo.launch(theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate"), share=True)