Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import cv2 | |
| from ultralytics import YOLO | |
| import easyocr | |
| import numpy as np | |
| import re | |
| from datetime import datetime | |
| from PIL import Image | |
# ---- Shared resources, created once when the module is loaded ----
# YOLO detector for receipt regions; the weights path is relative to the
# working directory the app is started from.
model = YOLO(model="model/best.pt")
# English-only EasyOCR reader applied to every detected crop.
reader = easyocr.Reader(lang_list=['en'])
# Detector class id -> label. Presumably this order must match the order the
# model was trained with — NOTE(review): verify against the model's `names`.
class_names = [
    "Merchant",
    "date",
    "total",
    "no",
    "item",
]
# Plain unsigned integer or decimal (no sign, no thousands separators).
NUMBER_RE = re.compile(r"\d+(?:\.\d+)?")
# Looser line-item number: optional sign, digits that may contain commas,
# and an optional 1-2 digit fractional part after '.' or ','.
NUMBER_RE_PARSE = re.compile(r'[-+]?\d{1,3}(?:[,\d]*\d)?(?:[.,]\d{1,2})?')
# ---------- Helper functions ----------
# Matches any run of two or more whitespace characters.
_WHITESPACE_RUN = re.compile(r'\s{2,}')


def normalize_ocr_text(s: str) -> str:
    """Return ``s`` on a single line with its whitespace tidied.

    Newlines become spaces, leading/trailing whitespace is trimmed, and every
    run of two or more whitespace characters collapses to one space. A lone
    tab (or other single whitespace character) is left as it is.
    """
    one_line = s.replace('\n', ' ')
    return _WHITESPACE_RUN.sub(' ', one_line.strip())
def extract_numbers_parse(s: str):
    """Return every number found in ``s`` as a float, in left-to-right order.

    Tokens are located with NUMBER_RE_PARSE. Commas are normally read as
    thousands separators and dropped ("1,234.50" -> 1234.5). When a token has
    no '.' and ends in a comma followed by one or two digits ("12,5", "3,99",
    "1,234,56"), that final comma is read as the decimal point and any earlier
    commas are dropped. Tokens that still fail to convert are skipped.
    """
    nums = []
    for token in NUMBER_RE_PARSE.findall(s):
        if ',' in token and '.' not in token and re.search(r',\d{1,2}$', token):
            # Decimal comma: only the last comma is the decimal separator.
            whole, _, frac = token.rpartition(',')
            normalized = f"{whole.replace(',', '')}.{frac}"
        else:
            normalized = token.replace(',', '')
        try:
            nums.append(float(normalized))
        except ValueError:
            continue
    return nums
def pick_price_from_numbers(numbers, original_str):
    """Choose the most likely price among the numbers found on a receipt line.

    Args:
        numbers: floats parsed from the line (as produced by
            extract_numbers_parse).
        original_str: the line text the numbers came from.

    Returns:
        None when ``numbers`` is empty; the only value when there is one; the
        maximum when it occurs more than once (e.g. unit price equal to line
        total); otherwise the last number parsed from ``original_str``, on the
        assumption that the price is printed last on the line, falling back to
        the maximum when ``original_str`` yields no number.
    """
    if not numbers:
        return None
    if len(numbers) == 1:
        return numbers[0]
    largest = max(numbers)
    # A repeated maximum is a strong signal. The check must be "> 1": the max
    # always occurs at least once, so ">= 1" made the fallback below dead code.
    if numbers.count(largest) > 1:
        return largest
    in_line = extract_numbers_parse(original_str)
    return in_line[-1] if in_line else largest
def clean_product_name(s: str):
    """Strip quantities, prices and decoration from an OCR'd line-item name.

    Numbers matched by NUMBER_RE_PARSE are removed first. Quantity markers
    glued to a number ("2x", "x2", "3pcs") therefore become standalone words
    that the quantity-keyword filter can drop. Before, "2x Milk" came out as
    "x Milk". Currency symbols and punctuation are then replaced by spaces,
    and whitespace runs are collapsed.
    """
    s = NUMBER_RE_PARSE.sub('', s)
    s = re.sub(r'\b(x|qty|pcs|pc|nos|no|each)\b', '', s, flags=re.IGNORECASE)
    s = re.sub(r'[\$₹£€:,()*`"“”]', ' ', s)
    s = re.sub(r'\s{2,}', ' ', s).strip()
    return s
def parse_line_item(raw_line: str):
    """Split one OCR'd receipt line into a product name and a price string.

    Returns a dict with:
        "product": the cleaned name, or the untouched ``raw_line`` when
            cleaning leaves nothing;
        "price": the chosen price formatted with two decimals, or "" when no
            number was found.
    """
    cleaned = normalize_ocr_text(raw_line)
    price = pick_price_from_numbers(extract_numbers_parse(cleaned), cleaned)
    product = clean_product_name(cleaned) or raw_line
    price_text = "" if price is None else f"{price:.2f}"
    return {"product": product, "price": price_text}
# Amount pattern for totals, tried in order:
#   1. comma-grouped thousands, e.g. "1,234.50" (kept whole instead of being
#      split into "1" and "234.50");
#   2. a decimal comma with 1-2 fractional digits, e.g. "45,90";
#   3. a plain integer or dot-decimal, e.g. "1234.50".
_TOTAL_AMOUNT_RE = re.compile(
    r"\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|\d+,\d{1,2}(?!\d)|\d+(?:\.\d+)?"
)


def extract_total_amount(total_str: str):
    """Return the last amount in ``total_str`` as a float, or None.

    The last amount is used because a total label tends to precede the value
    ("TOTAL 12.50"). Thousands separators are dropped and a decimal comma is
    converted to a decimal point.
    """
    if not total_str:
        return None
    matches = _TOTAL_AMOUNT_RE.findall(total_str)
    if not matches:
        return None
    amount = matches[-1]
    if "," in amount and "." not in amount and re.search(r",\d{1,2}$", amount):
        return float(amount.replace(",", "."))
    return float(amount.replace(",", ""))
# Date layouts tried by parse_date, in order. Year-first must come first:
# otherwise the day-first pattern matches the tail of "2023/05/12" as
# "23/05/12" and yields 2012-05-23. The lookbehinds stop a match from
# starting in the middle of a longer digit run.
_DATE_PATTERNS = (
    (re.compile(r"(?<!\d)\d{4}[/.-]\d{1,2}[/.-]\d{1,2}(?!\d)"), ("%Y/%m/%d",)),
    (re.compile(r"(?<!\d)\d{1,2}[/.-]\d{1,2}[/.-]\d{2,4}"), ("%d/%m/%y", "%d/%m/%Y")),
)


def parse_date(text):
    """Find a calendar date in OCR text and return it as "YYYY-MM-DD".

    Accepts year-first dates (2023-05-12, 2023/05/12, 2023.05.12) and
    day-first dates with a 2- or 4-digit year (12/05/23, 12-05-2023,
    5.6.2023), separated by '/', '-' or '.'. Day-first is assumed over
    month-first. Returns None when no candidate is a valid date.
    """
    if not text:
        return None
    for pattern, formats in _DATE_PATTERNS:
        for match in pattern.finditer(text):
            # Normalise separators so one format string covers '/', '-' and '.'.
            candidate = re.sub(r"[.-]", "/", match.group(0))
            for fmt in formats:
                try:
                    return datetime.strptime(candidate, fmt).strftime("%Y-%m-%d")
                except ValueError:
                    continue
    return None
# H:MM or H:MM:SS, optionally followed by an AM/PM marker ("pm", "P.M.").
# Groups: hour, minute, second (optional), meridiem letter (optional).
_TIME_RE = re.compile(
    r"(?<!\d)(\d{1,2}):(\d{2})(?::(\d{2}))?(?!\d)"
    r"(?:\s*([AaPp])\.?\s*[Mm]\.?(?![A-Za-z]))?"
)


def parse_time(text):
    """Find a clock time in OCR text and return it as "HH:MM:SS".

    Colons are kept in the text. They are what identifies a time, and
    stripping them made every search fail. Missing seconds default to "00".
    A 12-hour time with an AM/PM marker is converted to 24-hour form.
    Returns None when no candidate is a valid time.
    """
    if not text:
        return None
    for match in _TIME_RE.finditer(text):
        hour_txt, minute_txt, second_txt, meridiem = match.groups()
        hour = int(hour_txt)
        if meridiem and 1 <= hour <= 12:
            # 12 AM -> 0, 12 PM -> 12, 1-11 PM -> 13-23.
            hour = hour % 12 + (12 if meridiem in "Pp" else 0)
        try:
            parsed = datetime.strptime(
                f"{hour}:{minute_txt}:{second_txt or '00'}", "%H:%M:%S"
            )
        except ValueError:
            continue
        return parsed.strftime("%H:%M:%S")
    return None
# ---------- Main extraction function ----------
# Item prices above this are treated as OCR noise rather than real prices.
_MAX_ITEM_PRICE = 1000000


def _ocr_text(crop):
    """Run EasyOCR on one image crop and join the recognised strings with spaces."""
    return " ".join(fragment[1] for fragment in reader.readtext(crop))


def _parse_item(text):
    """Turn an "item" crop's OCR text into an output row {"product", "price"}.

    Two price guesses are compared: the final space-separated token read as a
    number (commas dropped, letter 'O' read as zero) and the price chosen by
    parse_line_item. The smaller guess that is at most _MAX_ITEM_PRICE wins,
    rounded to 2 decimals. The price is "" when neither guess is usable.
    """
    token_price = None
    parts = text.rsplit(" ", 1)
    if len(parts) == 2:
        try:
            token_price = float(parts[1].replace(",", "").replace("O", "0"))
        except ValueError:
            token_price = None
    parsed = parse_line_item(text)
    parsed_price = float(parsed["price"]) if parsed["price"] else None
    candidates = [
        p for p in (token_price, parsed_price)
        if p is not None and p <= _MAX_ITEM_PRICE
    ]
    price = round(min(candidates), 2) if candidates else None
    return {"product": parsed["product"], "price": price if price is not None else ""}


def _reconcile_totals(total_text, items):
    """Return (total, tax, discount) from the total box text and parsed items.

    The detected total is trusted unless it is missing or more than 10x the
    item sum. In those cases the item sum is used instead. A positive
    difference (total - items) is reported as tax and a negative one as
    discount. When there are no positively priced items there is nothing to
    compare against, so the detected total is kept with zero tax/discount.
    """
    model_total = extract_total_amount(total_text)
    item_sum = sum(it["price"] for it in items if it.get("price") not in ("", None))
    if model_total is None:
        return round(item_sum, 2), 0.0, 0.0
    if item_sum <= 0:
        # Without this guard, any real total exceeds 0 * 10 and would be
        # replaced by 0.
        return model_total, 0.0, 0.0
    if model_total > item_sum * 10:
        return round(item_sum, 2), 0.0, 0.0
    if abs(model_total - item_sum) < 0.01:
        return model_total, 0.0, 0.0
    if model_total > item_sum:
        return model_total, round(model_total - item_sum, 2), 0.0
    return model_total, 0.0, round(item_sum - model_total, 2)


def extract_receipt(image):
    """Detect receipt regions with YOLO, OCR each one and assemble a summary.

    Args:
        image: PIL image from the Gradio input (None when nothing was uploaded).

    Returns:
        (output, preview): ``output`` is a dict with keys "items", "name",
        "total", "date", "time", "discount" and "tax"; ``preview`` is a PIL
        image of the detections drawn by ``results.plot()``.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload a receipt image.")
    if isinstance(image, Image.Image):
        # Grayscale/RGBA uploads would otherwise break the RGB->BGR conversion.
        image = image.convert("RGB")
    img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    results = model(img)[0]
    output = {"items": [], "name": "", "total": "", "date": "", "time": "", "discount": 0.0, "tax": 0.0}
    height, width = img.shape[:2]

    for box, cls_id in zip(results.boxes.xyxy, results.boxes.cls):
        cls_idx = int(cls_id)
        if not 0 <= cls_idx < len(class_names):
            continue  # class id unknown to the hard-coded label list
        x1, y1, x2, y2 = (int(v) for v in box)
        # Clamp to the image so negative coordinates cannot wrap around in slicing.
        x1, y1 = max(x1, 0), max(y1, 0)
        x2, y2 = min(x2, width), min(y2, height)
        crop = img[y1:y2, x1:x2]
        if crop.size == 0:
            continue  # degenerate box; nothing to OCR
        text = _ocr_text(crop)
        cls_name = class_names[cls_idx]
        if cls_name == "Merchant":
            output["name"] = text
        elif cls_name == "date":
            parsed_date = parse_date(text)
            if parsed_date:
                output["date"] = parsed_date
        elif cls_name == "total":
            output["total"] = text
        elif cls_name == "no":
            # Text from the "no" class is parsed as the time of purchase.
            parsed_time = parse_time(text)
            if parsed_time:
                output["time"] = parsed_time
        elif cls_name == "item":
            output["items"].append(_parse_item(text))

    # ---------- Post-processing totals ----------
    output["total"], output["tax"], output["discount"] = _reconcile_totals(
        output["total"], output["items"]
    )

    # ---------- Fill missing date/time ----------
    now = datetime.now()
    if not output["date"]:
        output["date"] = now.strftime("%Y-%m-%d")
    if not output["time"]:
        output["time"] = now.strftime("%H:%M:%S")

    # ---------- Generate YOLO prediction image ----------
    # results.plot() returns a BGR array; convert it for PIL/Gradio display.
    preview = cv2.cvtColor(results.plot(), cv2.COLOR_BGR2RGB)
    return output, Image.fromarray(preview)
# ---------- Gradio Interface ----------
iface = gr.Interface(
    fn=extract_receipt,
    inputs=gr.Image(type="pil"),
    outputs=[gr.JSON(), gr.Image(type="pil")],
    title="Receipt Extractor",
    description="Upload a receipt image to extract merchant, date, total, time, items, and see YOLO predictions.",
    allow_flagging="never",
)

# Launch only when run as a script, so importing this module (e.g. to reuse
# extract_receipt or in tests) does not start a web server as a side effect.
if __name__ == "__main__":
    iface.launch(share=True)