# Sivakkanth
# queue enabled
# f121f9b
import gradio as gr
import cv2
from ultralytics import YOLO
import easyocr
import numpy as np
import re
from datetime import datetime
from PIL import Image
# Load the trained YOLO detector from a relative weights path (resolved
# against the process's working directory). Loaded once at import time.
model = YOLO("model/best.pt")
# Initialize the EasyOCR reader for English; created once and shared by all calls.
reader = easyocr.Reader(['en'])
# Class names, looked up by the detector's integer class id.
# NOTE(review): the order must match the class ids the model was trained with — confirm.
class_names = ["Merchant","date","total","no","item"]
# Regex for numbers: digits with an optional '.' decimal part (no sign, no
# thousands separators).
NUMBER_RE = re.compile(r"\d+(?:\.\d+)?")
# Looser number regex: optional sign, 1-3 leading digits optionally followed by
# a run of digits/commas ending in a digit, then an optional '.' or ',' with
# one or two decimal digits.
NUMBER_RE_PARSE = re.compile(r'[-+]?\d{1,3}(?:[,\d]*\d)?(?:[.,]\d{1,2})?')
# ---------- Helper functions ----------
def normalize_ocr_text(s: str) -> str:
    """Flatten OCR output onto a single line.

    Newlines become spaces, leading/trailing whitespace is trimmed, and any
    run of two or more whitespace characters collapses to one space.
    """
    single_line = s.replace('\n', ' ').strip()
    return re.sub(r'\s{2,}', ' ', single_line)
def extract_numbers_parse(s: str):
    """Return every number found in ``s`` as a list of floats, in order.

    Tokens are located with ``NUMBER_RE_PARSE``. Commas are treated as
    thousands separators ("1,234.50" -> 1234.5). The exception is a token
    with no '.' that ends in a comma followed by one or two digits: there the
    comma is read as a decimal point ("12,50" -> 12.5). Tokens that still do
    not form a valid float (e.g. "1,234,56" -> "1.234.56") are skipped.
    """
    nums = []
    for token in NUMBER_RE_PARSE.findall(s):
        if ',' in token and '.' not in token and re.search(r',\d{1,2}$', token):
            # Decimal comma, as in "12,50".
            normalized = token.replace(',', '.')
        else:
            normalized = token.replace(',', '')
        try:
            nums.append(float(normalized))
        except ValueError:
            # Only malformed tokens are skipped; other errors should surface.
            continue
    return nums
def pick_price_from_numbers(numbers, original_str):
    """Choose the price among the numbers parsed from one line item.

    Returns ``None`` when ``numbers`` is empty, otherwise the largest value
    (which is simply the only value for a one-element list).

    ``original_str`` is currently unused; it is kept so existing callers
    keep working.
    """
    if not numbers:
        return None
    return max(numbers)
def clean_product_name(s: str):
    """Strip numbers, quantity markers and stray punctuation from an item line.

    Numbers are removed *before* the quantity keywords. That way, markers
    glued to a count (e.g. "2x", "x2", "3pcs") become standalone words and
    are stripped too. Otherwise "Milk 2x 45.00" would come out as "Milk x".
    """
    s = NUMBER_RE_PARSE.sub('', s)
    s = re.sub(r'\b(x|qty|pcs|pc|nos|no|each)\b', '', s, flags=re.IGNORECASE)
    # Currency symbols, separators and OCR quote/bullet noise become spaces.
    s = re.sub(r'[\$₹£€:,()*`"“”]', ' ', s)
    s = re.sub(r'\s{2,}', ' ', s).strip()
    return s
def parse_line_item(raw_line: str):
    """Split one OCR'd receipt line into a product name and a price.

    Returns a dict with:
      * ``"product"`` - the cleaned name, or ``raw_line`` unchanged when
        cleaning leaves nothing behind;
      * ``"price"`` - the chosen price formatted with two decimals, or ``""``
        when no number could be parsed from the line.
    """
    cleaned_line = normalize_ocr_text(raw_line)
    price = pick_price_from_numbers(extract_numbers_parse(cleaned_line), cleaned_line)
    name = clean_product_name(cleaned_line)
    price_text = "" if price is None else f"{price:.2f}"
    return {"product": name or raw_line, "price": price_text}
def extract_total_amount(total_str: str):
    """Return the last number in ``total_str`` as a float, or ``None``.

    Commas inside a number are treated as thousands separators, so
    "TOTAL 1,234.50" yields 1234.5 rather than 234.5. Returns ``None`` for
    empty/``None`` input or text without digits.
    """
    if not total_str:
        return None
    # Digits optionally grouped by commas, with an optional decimal part.
    matches = re.findall(r"\d+(?:,\d+)*(?:\.\d+)?", total_str)
    if not matches:
        return None
    # Use the last number in the text; every match is a valid float once
    # the grouping commas are removed.
    return float(matches[-1].replace(",", ""))
def parse_date(text):
    """Parse a receipt date and return it as ``YYYY-MM-DD``, or ``None``.

    Accepted forms, with '/' or '-' as the separator:
      * year first: ``YYYY/MM/DD``
      * day first:  ``DD/MM/YY`` or ``DD/MM/YYYY``

    The year-first form is tried first, and a match may not be part of a
    longer digit run. This keeps "2024/05/12" from being misread as the
    day-first date "24/05/12" (24 May 2012).
    """
    text = text.replace('Date', '').replace('date', '').replace(':', '').strip()
    candidates = (
        (r"(?<!\d)(\d{4}[/-]\d{2}[/-]\d{2})(?!\d)", ("%Y/%m/%d",)),
        (r"(?<!\d)(\d{2}[/-]\d{2}[/-]\d{2,4})(?!\d)", ("%d/%m/%y", "%d/%m/%Y")),
    )
    for pattern, formats in candidates:
        match = re.search(pattern, text)
        if not match:
            continue
        # Unify separators so one format string covers both '/' and '-'.
        dt_str = re.sub(r"[/-]", "/", match.group(1))
        for fmt in formats:
            try:
                return datetime.strptime(dt_str, fmt).strftime("%Y-%m-%d")
            except ValueError:
                continue
    return None
def parse_time(text):
    """Parse a clock time (``H:MM`` or ``H:MM:SS``) and return ``HH:MM:SS``.

    Returns ``None`` when no valid 24-hour time is found. Colons are
    deliberately left in the text: the pattern needs them as separators.
    """
    text = text.replace('Time', '').replace('time', '').strip()
    match = re.search(r"(\d{1,2}:\d{2}(?::\d{2})?)", text)
    if match:
        tm_str = match.group(1)
        for fmt in ("%H:%M:%S", "%H:%M"):
            try:
                return datetime.strptime(tm_str, fmt).strftime("%H:%M:%S")
            except ValueError:
                continue
    return None
# ---------- Main extraction function ----------
def _parse_item(text):
    """Turn the OCR text of one "item" box into a ``{"product", "price"}`` dict."""
    # Heuristic 1: treat the last space-separated token as the price,
    # reading the letter "O" as the digit "0".
    parts = text.rsplit(" ", 1)
    last_token_price = None
    if len(parts) == 2:
        try:
            last_token_price = float(parts[1].replace(",", "").replace("O", "0"))
        except ValueError:
            last_token_price = None
    # Heuristic 2: the regex-based line parser.
    parsed = parse_line_item(text)
    parser_price = float(parsed["price"]) if parsed["price"] else None
    # Keep plausible candidates only. abs() also rejects -inf, and NaN fails
    # every comparison, so tokens like "nan"/"-inf" cannot reach the JSON.
    candidates = [
        p for p in (last_token_price, parser_price)
        if p is not None and abs(p) <= 1000000
    ]
    final_price = round(min(candidates), 2) if candidates else None
    return {
        "product": parsed["product"],
        "price": final_price if final_price is not None else "",
    }


def _reconcile_totals(model_total, item_sum):
    """Return ``(total, tax, discount)`` from the OCR'd total and the item sum."""
    if model_total is None:
        return round(item_sum, 2), 0.0, 0.0
    if item_sum <= 0:
        # No priced items to compare against. Keep the detected total; the
        # "> item_sum * 10" check below would otherwise discard any positive total.
        return model_total, 0.0, 0.0
    if model_total > item_sum * 10:
        # Implausibly large next to the items; fall back to the item sum.
        return round(item_sum, 2), 0.0, 0.0
    if abs(model_total - item_sum) < 0.01:
        return model_total, 0.0, 0.0
    if model_total > item_sum:
        return model_total, round(model_total - item_sum, 2), 0.0
    return model_total, 0.0, round(item_sum - model_total, 2)


def extract_receipt(image):
    """Detect receipt fields with YOLO, OCR each box, and assemble a summary.

    Args:
        image: PIL image from the Gradio ``gr.Image(type="pil")`` input.

    Returns:
        ``(output, annotated)``. ``output`` is a dict with keys ``items``,
        ``name``, ``total``, ``date``, ``time``, ``discount`` and ``tax``.
        Missing date/time are filled with the current local date/time.
        ``annotated`` is a PIL image of the YOLO detections.

    Raises:
        ValueError: when no image was supplied.
    """
    if image is None:
        # An empty submission would otherwise fail deep inside cv2.
        raise ValueError("Please upload a receipt image.")
    # Force 3 channels so the RGB->BGR conversion also works for RGBA/greyscale input.
    img = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
    results = model(img)[0]
    output = {"items": [], "name": "", "total": "", "date": "", "time": "", "discount": 0.0, "tax": 0.0}

    for box, cls_id in zip(results.boxes.xyxy, results.boxes.cls):
        x1, y1, x2, y2 = [int(v) for v in box]
        crop = img[y1:y2, x1:x2]
        if crop.size == 0:
            # Degenerate box after integer truncation; nothing to OCR.
            continue
        cls_name = class_names[int(cls_id)]
        text = " ".join(t[1] for t in reader.readtext(crop))
        if cls_name == "Merchant":
            output["name"] = text
        elif cls_name == "date":
            parsed_date = parse_date(text)
            if parsed_date:
                output["date"] = parsed_date
        elif cls_name == "total":
            output["total"] = text
        elif cls_name == "no":
            # NOTE(review): the "no" class is parsed as a time here — confirm
            # this matches what the detector was trained to box.
            parsed_time = parse_time(text)
            if parsed_time:
                output["time"] = parsed_time
        elif cls_name == "item":
            output["items"].append(_parse_item(text))

    # ---------- Post-processing totals ----------
    model_total = extract_total_amount(output.get("total", ""))
    item_sum = sum(it["price"] for it in output["items"] if it.get("price") not in ("", None))
    output["total"], output["tax"], output["discount"] = _reconcile_totals(model_total, item_sum)

    # ---------- Fill missing date/time ----------
    now = datetime.now()
    if not output["date"]:
        output["date"] = now.strftime("%Y-%m-%d")
    if not output["time"]:
        output["time"] = now.strftime("%H:%M:%S")

    # ---------- Generate YOLO prediction image ----------
    # The plotted array is converted BGR->RGB before being handed to PIL.
    annotated = Image.fromarray(cv2.cvtColor(results.plot(), cv2.COLOR_BGR2RGB))
    return output, annotated
# ---------- Gradio Interface ----------
# NOTE(review): newer Gradio releases rename `allow_flagging` to
# `flagging_mode` — update this argument if the installed version rejects it.
iface = gr.Interface(
    fn=extract_receipt,
    inputs=gr.Image(type="pil"),
    outputs=[gr.JSON(), gr.Image(type="pil")],
    title="Receipt Extractor",
    description="Upload a receipt image to extract merchant, date, total, time, items, and see YOLO predictions.",
    allow_flagging="never"
)
# Enable Gradio's request queue explicitly. YOLO + OCR inference is slow, and
# older Gradio versions do not queue by default; on newer versions this is a no-op.
iface.queue()
iface.launch(share=True)