# inv-ocr / app.py
# abinash73's picture
# Update app.py
# 6ac6ef8 verified
# ================================
# FIX PIL ANTIALIAS (new Pillow versions)
# ================================
from PIL import Image
# Compatibility shim: if this Pillow build does not define Image.ANTIALIAS,
# alias it to the LANCZOS resampling filter so code that still references the
# old name keeps working.
# NOTE(review): presumably needed by a library imported later in this file
# (it runs before those imports) -- confirm which dependency reads ANTIALIAS.
if not hasattr(Image, "ANTIALIAS"):
    Image.ANTIALIAS = Image.Resampling.LANCZOS
# ================================
# IMPORTS
# ================================
import gradio as gr
import easyocr
import fitz # PyMuPDF
import numpy as np
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image as PILImage
import re
import os
import tempfile
# =====================================================
# LOAD OCR MODEL (only once)
# =====================================================
# Build the shared EasyOCR reader once at import time; run_ocr() reuses this
# module-level instance for every request. It is configured for English only
# and with GPU use disabled.
print("Loading EasyOCR model...")
reader = easyocr.Reader(["en"], gpu=False)
print("Model loaded successfully!")
# =====================================================
# PDF → IMAGE CONVERSION
# =====================================================
def pdf_to_images(pdf_path):
    """Render every page of a PDF into an RGB PIL image.

    Args:
        pdf_path: Filesystem path of the PDF to rasterise.

    Returns:
        A list with one PIL image per page, rendered at 200 DPI. An empty
        list is returned (and the error printed) if the document cannot be
        opened or any page fails to render.
    """
    images = []
    try:
        # Open the document as a context manager so it is closed even when
        # rendering a page raises; previously the handle was never released.
        with fitz.open(pdf_path) as doc:
            for page in doc:
                pix = page.get_pixmap(dpi=200)
                img = PILImage.frombytes("RGB", (pix.width, pix.height), pix.samples)
                images.append(img)
    except Exception as e:
        # Best-effort: callers treat an empty list as "no pages".
        print("PDF ERROR:", e)
        return []
    return images
# =====================================================
# EXTRACT ITEMS (Enhanced parsing)
# =====================================================
def extract_items(text):
    """Parse invoice line items out of raw OCR text.

    A line item starts at any line that begins with a serial number followed
    by whitespace (e.g. ``"1 Widget"``). The rest of that line is taken as
    the product name, and the following lines (up to the next serial line, a
    totals/tax line, or 19 lines, whichever comes first) are scanned for
    HSN, part number, quantity/UOM, GST, discount, rate and taxable value.

    Args:
        text: OCR output with one recognised text fragment per line.

    Returns:
        A list of dicts, one per detected item, with the string-valued keys
        "Serial", "Product Name", "Part Number", "HSN/SAC", "Quantity",
        "UOM", "Rate (Excl. Tax)", "Taxable Value", "GST Percentage",
        "GST Amount", "Discount Percentage" and "Discount Amount". Fields
        that could not be found are empty strings.
    """
    serial_regex = r"^\d+\s+"  # Serial at start: "1 "
    hsn_regex = r"\b\d{6,8}\b"  # 6-8 digit HSN
    # Amounts: either comma-grouped (western "12,345.00" or Indian
    # "1,23,456.00") or a plain run of digits, each with optional decimals.
    # The plain alternative keeps "1234.50" in one piece instead of
    # splitting it into "123" and "4.50".
    amount_regex = (
        r"\d{1,3}(?:,\d{2,3})*,\d{3}(?:\.\d+)?"
        r"|\d+(?:\.\d+)?"
    )
    qty_regex = r"(\d+(?:\.\d{1,2})?)\s*(NOS|PCS|PC|KG|KILO|LTR|LITRE|MTR|METRE|BOX|SET|UNIT|EA|EACH)"

    lines = [l.strip() for l in text.split("\n") if l.strip()]
    items = []
    n = len(lines)

    for i, line in enumerate(lines):
        # Only lines that start with a serial number open a new item.
        serial_match = re.match(serial_regex, line)
        if not serial_match:
            continue
        try:
            serial = serial_match.group().strip()
            # Product description is whatever follows the serial number.
            remaining = line[len(serial_match.group()):].strip()
            # Skip if this looks like table headers
            if re.search(r"Description|Goods|HSN|Quantity|Rate|Amount", remaining, re.I):
                continue
            product_name = remaining

            hsn = ""
            part_no = ""
            quantity = ""
            uom = ""
            rate_excl_tax = ""
            gst_percentage = ""
            gst_amount = ""
            discount_amount = ""
            discount_percentage = ""
            taxable_value = ""

            # Look ahead through the lines that belong to this item.
            for current_line in lines[i + 1:min(i + 20, n)]:
                # Stop if we hit next serial number
                if re.match(serial_regex, current_line):
                    break
                # Stop if we hit total/subtotal
                if re.search(r"^(Total|Sub|CGST|SGST|IGST|Round)", current_line, re.I):
                    break

                # HSN: first 6-8 digit number that is not repeated on the line.
                if not hsn:
                    hsn_match = re.search(hsn_regex, current_line)
                    if hsn_match:
                        potential_hsn = hsn_match.group()
                        if not re.search(rf"\b{potential_hsn}\b.*\b{potential_hsn}\b", current_line):
                            hsn = potential_hsn

                # Part number: patterns like "84408596-P".
                if not part_no:
                    part_match = re.search(r"\b([A-Z0-9]{6,}(?:-[A-Z0-9]+)?)\b", current_line)
                    if part_match:
                        potential_part = part_match.group(1)
                        # Not the HSN and not a bare 10-digit (phone-like) number.
                        if (potential_part != hsn and
                                len(potential_part) >= 6 and
                                not re.match(r'^\d{10}$', potential_part)):
                            part_no = potential_part

                # Quantity and unit of measure.
                if not quantity:
                    qty_match = re.search(qty_regex, current_line, re.I)
                    if qty_match:
                        quantity = qty_match.group(1)
                        uom = qty_match.group(2)

                # GST percentage: patterns like "13 %", "CGST@6%".
                if not gst_percentage:
                    gst_match = re.search(r"(\d{1,2})\s*%", current_line)
                    if gst_match and not re.search(r"Disc|Discount", current_line, re.I):
                        gst_pct = gst_match.group(1)
                        # For CGST/SGST, double the percentage for total GST
                        if re.search(r"CGST|SGST", current_line, re.I):
                            gst_percentage = str(int(gst_pct) * 2) + "%"
                        else:
                            gst_percentage = gst_pct + "%"

                # Discount: prefer a percentage, otherwise the first amount.
                if not discount_amount and not discount_percentage:
                    if re.search(r"Disc|Discount", current_line, re.I):
                        disc_match = re.search(r"(\d+(?:\.\d{2})?)\s*%", current_line)
                        if disc_match:
                            discount_percentage = disc_match.group(1) + "%"
                        else:
                            amount_match = re.search(amount_regex, current_line)
                            if amount_match:
                                discount_amount = amount_match.group()

                # Rate: last amount on a line mentioning "Rate" (but not "Rate %").
                if not rate_excl_tax:
                    if re.search(r"Rate(?!\s*%)", current_line, re.I):
                        rate_matches = re.findall(amount_regex, current_line)
                        if rate_matches:
                            rate_excl_tax = rate_matches[-1].replace(',', '')

                # Taxable value: last amount on a "Taxable"/"Value" line.
                if not taxable_value:
                    if re.search(r"Taxable|Value", current_line, re.I) and not re.search(r"Rate", current_line, re.I):
                        tax_matches = re.findall(amount_regex, current_line)
                        if tax_matches:
                            taxable_value = tax_matches[-1].replace(',', '')

                # GST amount: last amount on a "<tax> ... Amount" line.
                if not gst_amount:
                    if re.search(r"(CGST|SGST|IGST).*Amount", current_line, re.I):
                        gst_matches = re.findall(amount_regex, current_line)
                        if gst_matches:
                            gst_amount = gst_matches[-1].replace(',', '')

            # Derive whichever of rate / taxable value is missing. Bad numbers
            # or a zero quantity simply leave the field empty.
            if not rate_excl_tax and taxable_value and quantity:
                try:
                    rate_excl_tax = str(round(float(taxable_value.replace(',', '')) / float(quantity), 2))
                except (ValueError, ZeroDivisionError):
                    pass
            if not taxable_value and rate_excl_tax and quantity:
                try:
                    taxable_value = str(round(float(rate_excl_tax.replace(',', '')) * float(quantity), 2))
                except ValueError:
                    pass

            # Collapse runs of whitespace in the product name.
            product_name = re.sub(r'\s+', ' ', product_name).strip()

            # Only add if we have meaningful data
            if product_name or hsn or quantity:
                items.append({
                    "Serial": serial,
                    "Product Name": product_name,
                    "Part Number": part_no,
                    "HSN/SAC": hsn,
                    "Quantity": quantity,
                    "UOM": uom,
                    "Rate (Excl. Tax)": rate_excl_tax,
                    "Taxable Value": taxable_value,
                    "GST Percentage": gst_percentage,
                    "GST Amount": gst_amount,
                    "Discount Percentage": discount_percentage,
                    "Discount Amount": discount_amount
                })
        except Exception as e:
            # One malformed item must not abort parsing of the rest.
            print(f"Item extraction error at line {i}: {e}")
    return items
# =====================================================
# FIELD EXTRACTION (Vendor, Buyer, GSTIN, Invoice)
# =====================================================
def extract_fields(text):
    """Extract header-level invoice fields and line items from OCR text.

    The text is split into non-empty, stripped lines and scanned with
    keyword/regex heuristics for the vendor, the buyer, the invoice number
    and the invoice date. Line items are delegated to ``extract_items``.

    Args:
        text: OCR output with one recognised text fragment per line.

    Returns:
        A dict with the keys "Vendor Name", "Vendor GSTIN", "Vendor Contact",
        "Buyer Name", "Buyer GSTIN", "Buyer Contact", "Invoice Number",
        "Invoice Date" (all strings, empty when not found) and "Items" (the
        list returned by ``extract_items``).
    """
    data = {
        "Vendor Name": "",
        "Vendor GSTIN": "",
        "Vendor Contact": "",
        "Buyer Name": "",
        "Buyer GSTIN": "",
        "Buyer Contact": "",
        "Invoice Number": "",
        "Invoice Date": "",
        "Items": []
    }
    lines = [l.strip() for l in text.split("\n") if l.strip()]
    total = len(lines)

    # Phone number pattern (Indian format); always ends in the 10-digit number.
    phone_regex = r"(?:(?:\+91|0)?[\s-]?)?[6-9]\d{9}"

    # === VENDOR INFORMATION ===
    # Vendor name: first plausible line within 4 lines after the first line
    # mentioning "INVOICE" (this also covers "TAX INVOICE").
    for i, l in enumerate(lines):
        if "INVOICE" in l.upper():
            for j in range(i + 1, min(i + 5, total)):
                candidate = lines[j]
                if (len(candidate) > 3 and
                        not re.search(r"GSTIN|GST|PAN|ADDRESS|PHONE|EMAIL|^\d+$", candidate, re.I)):
                    data["Vendor Name"] = candidate
                    break
            break
    # Fallback: first reasonable line among the first five.
    if not data["Vendor Name"]:
        for candidate in lines[:5]:
            if len(candidate) > 3 and not re.search(r"INVOICE|ORIGINAL|DUPLICATE", candidate, re.I):
                data["Vendor Name"] = candidate
                break

    # GSTIN extraction with context
    gst_regex = r"\b\d{2}[A-Z]{5}\d{4}[A-Z]\d[A-Z\d]{3}\b"

    # Vendor GSTIN: first GSTIN found on/after a "GSTIN"/"UIN" label in the
    # first half of the document.
    for i, l in enumerate(lines[:total // 2]):
        if data["Vendor GSTIN"]:
            break
        if "GSTIN" in l.upper() or "UIN" in l.upper():
            # Check same line and next few lines
            for j in range(i, min(i + 3, total)):
                match = re.search(gst_regex, lines[j])
                if match:
                    data["Vendor GSTIN"] = match.group()
                    break

    # Buyer GSTIN: only the first buyer marker, and only the first GSTIN
    # label within 10 lines of it, is considered.
    for i, l in enumerate(lines):
        if any(kw in l.upper() for kw in ["BUYER", "BILL TO"]):
            for j in range(i, min(i + 10, total)):
                if "GSTIN" in lines[j].upper() or "UIN" in lines[j].upper():
                    for k in range(j, min(j + 3, total)):
                        match = re.search(gst_regex, lines[k])
                        if match and match.group() != data["Vendor GSTIN"]:
                            data["Buyer GSTIN"] = match.group()
                            break
                    break
            break

    # Vendor Contact: first phone number in the first half of the document.
    first_half = "\n".join(lines[:total // 2])
    vendor_phones = re.findall(phone_regex, first_half)
    if vendor_phones:
        data["Vendor Contact"] = vendor_phones[0].strip()

    # === BUYER INFORMATION ===
    buyer_keywords = ["BUYER", "BILL TO", "BILLED TO", "CUSTOMER", "CONSIGNEE", "SHIP TO"]
    for i, l in enumerate(lines):
        if any(keyword in l.upper() for keyword in buyer_keywords):
            for j in range(i + 1, min(i + 10, total)):
                candidate = lines[j]
                # Skip lines that contain common non-name patterns
                if (len(candidate) > 3 and
                        not re.search(r"GSTIN|GST|STATE|CODE|ADDRESS|PHONE|GROUND|FLOOR|KHATA|PLOT|OPPOSITE|UNIT|DATED|PLACE", candidate, re.I)):
                    # Looks like a company name: has a capital letter and a
                    # reasonable length.
                    if re.search(r"[A-Z]", candidate) and len(candidate) < 50:
                        data["Buyer Name"] = candidate
                        break
            break

    # Buyer Contact: first phone in the middle half that is not the vendor's.
    # The regex can capture a leading separator or "+91"/"0" prefix, so the
    # comparison uses the trailing 10-digit number; comparing raw matches let
    # the vendor's own number slip through as the buyer's.
    middle_section = "\n".join(lines[total // 4:3 * total // 4])
    vendor_number = data["Vendor Contact"][-10:]
    for phone in re.findall(phone_regex, middle_section):
        phone = phone.strip()
        if phone[-10:] != vendor_number:
            data["Buyer Contact"] = phone
            break

    # === INVOICE NUMBER ===
    skip_words = ["KHATA", "PLOT", "POST", "LANE", "STATE", "DATED", "ORIGINAL", "DUPLICATE"]
    inv_keywords = ["INVOICE NO", "INVOICE NUMBER", "INV NO", "BILL NO", "BILL NUMBER"]
    for i, l in enumerate(lines):
        if any(keyword in l.upper() for keyword in inv_keywords):
            # Prefer the value after the first colon on the label line.
            parts = l.split(":")
            if len(parts) > 1:
                inv_candidate = parts[1].strip()
                if inv_candidate and not any(sw in inv_candidate.upper() for sw in skip_words):
                    data["Invoice Number"] = inv_candidate
                    break
            # Otherwise take the first plausible line of the next four.
            for j in range(i + 1, min(i + 5, total)):
                cand = lines[j]
                if not any(sw in cand.upper() for sw in skip_words):
                    if re.search(r"[A-Z0-9]", cand) and len(cand) > 2:
                        data["Invoice Number"] = cand
                        break
            break

    # === INVOICE DATE ===
    date_patterns = [
        r"\b\d{1,2}-[A-Za-z]{3}-\d{2,4}\b",  # 16-Aug-25
        r"\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b",  # DD-MM-YYYY or DD/MM/YYYY
        r"\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b",  # YYYY-MM-DD
    ]
    date_keywords = ["DATE", "DATED", "INVOICE DATE", "BILL DATE"]
    for i, l in enumerate(lines):
        if any(keyword in l.upper() for keyword in date_keywords):
            # First check the same line for date
            for pattern in date_patterns:
                m = re.search(pattern, l)
                if m:
                    potential_date = m.group(0)
                    # Guard against picking up the invoice number.
                    # NOTE(review): this also rejects slash-separated dates
                    # and upper-case month names on the label line itself --
                    # confirm that is intended.
                    if not re.search(r"[A-Z]{2,}", potential_date) and '/' not in potential_date[:5]:
                        data["Invoice Date"] = potential_date
                        break
            if data["Invoice Date"]:
                break
            # Then check next few lines
            for j in range(i + 1, min(i + 4, total)):
                for pattern in date_patterns:
                    m = re.search(pattern, lines[j])
                    if m:
                        potential_date = m.group(0)
                        # Validate it's not invoice number
                        if not re.search(r"[A-Z]{2,}", potential_date) and len(potential_date) < 15:
                            data["Invoice Date"] = potential_date
                            break
                if data["Invoice Date"]:
                    break
            break

    # Items
    data["Items"] = extract_items(text)
    return data
# =====================================================
# OCR MAIN FUNCTION
# =====================================================
def run_ocr(file_path):
    """Run OCR on an invoice file and extract structured fields from it.

    Files whose name ends in ".pdf" (case-insensitive) are rasterised page by
    page with ``pdf_to_images``; anything else is opened as a single image.

    Args:
        file_path: Path of the PDF or image to process.

    Returns:
        A ``(full_text, fields)`` tuple: the recognised text and the dict
        produced by ``extract_fields``. On any exception the tuple is
        ``("Error processing file: <message>", {})`` instead.
    """
    try:
        if file_path.lower().endswith(".pdf"):
            # One chunk per page, each terminated by a newline.
            page_chunks = []
            for page_image in pdf_to_images(file_path):
                fragments = reader.readtext(np.array(page_image), detail=0)
                page_chunks.append("\n".join(fragments) + "\n")
            full_text = "".join(page_chunks)
        else:
            rgb_image = PILImage.open(file_path).convert("RGB")
            fragments = reader.readtext(np.array(rgb_image), detail=0)
            full_text = "\n".join(fragments)
        return full_text, extract_fields(full_text)
    except Exception as e:
        return f"Error processing file: {str(e)}", {}
# =====================================================
# FASTAPI APP WITH CORS
# =====================================================
# Create the FastAPI application that serves the JSON extraction API.
app = FastAPI(title="Invoice OCR API", version="1.0")
# Allow cross-origin requests from any origin, with any method and header.
# NOTE(review): wildcard origins are combined with allow_credentials=True
# here; the FastAPI CORS documentation advises against that combination --
# confirm whether credentialed cross-origin requests are actually required.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    """Return a short description of the service and its endpoints."""
    endpoint_help = {
        "POST /api/extract": "Extract data from invoice (PDF/Image)",
        "GET /docs": "API Documentation",
    }
    return {"message": "Invoice OCR API", "endpoints": endpoint_help}
@app.post("/api/extract")
async def extract_api(file: UploadFile = File(...)):
    """Extract text and structured fields from an uploaded invoice.

    The upload is written to a temporary file, passed to ``run_ocr`` and the
    temporary file is always removed afterwards.

    Args:
        file: The uploaded PDF, JPEG or PNG.

    Returns:
        A JSON response with "success", "filename", "text" (raw OCR text)
        and "fields" (the structured extraction).

    Raises:
        HTTPException: 400 if the declared content type is not supported,
            500 if anything else fails while handling the upload.
    """
    # Accepted content types and the suffix used when the client supplies no
    # usable file extension.
    suffix_by_type = {
        "application/pdf": ".pdf",
        "image/jpeg": ".jpg",
        "image/png": ".png",
        "image/jpg": ".jpg",
    }
    if file.content_type not in suffix_by_type:
        raise HTTPException(
            status_code=400,
            detail="Invalid file type. Allowed: PDF, JPEG, PNG"
        )

    tmp_path = None
    try:
        # run_ocr() decides between the PDF and image paths from the file
        # extension, so fall back to the content type when the upload has no
        # filename or no extension.
        suffix = os.path.splitext(file.filename or "")[1] or suffix_by_type[file.content_type]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Record the path first so cleanup happens even if the write fails.
            tmp_path = tmp.name
            content = await file.read()
            tmp.write(content)
        full_text, fields = run_ocr(tmp_path)
        return JSONResponse({
            "success": True,
            "filename": file.filename,
            "text": full_text,
            "fields": fields
        })
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always remove the temporary copy; a failed cleanup is reported but
        # must not mask the real response or error.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError as cleanup_error:
                print("TEMP FILE CLEANUP ERROR:", cleanup_error)
# =====================================================
# GRADIO FRONTEND
# =====================================================
def process_invoice(file):
    """Gradio callback: run OCR on the uploaded invoice.

    Args:
        file: The value supplied by the file input. This may be ``None`` (no
            upload), a filesystem path string, or an object exposing the
            path as ``.name``.

    Returns:
        A ``(text, fields)`` tuple for the text box and JSON outputs;
        ``("No file uploaded", {})`` when nothing was uploaded.
    """
    if file is None:
        return "No file uploaded", {}
    # A plain path string has no ``.name`` attribute, so fall back to the
    # value itself instead of raising AttributeError.
    file_path = getattr(file, "name", file)
    full_text, fields = run_ocr(file_path)
    return full_text, fields
# Gradio UI: a single file input wired to process_invoice, whose two return
# values feed a text box (raw OCR text) and a JSON viewer (extracted fields).
# No example files are configured and example caching is switched off.
demo = gr.Interface(
    fn=process_invoice,
    inputs=gr.File(type="filepath", label="Upload Invoice (PDF/Image)"),
    outputs=[
        gr.Textbox(label="Extracted Text", lines=10),
        gr.JSON(label="Extracted Fields")
    ],
    title="📄 Invoice OCR Extractor",
    description="Upload PDF or Image invoices to extract text and structured data using EasyOCR",
    examples=None,
    cache_examples=False
)
# =====================================================
# MOUNT GRADIO ON FASTAPI
# =====================================================
# Attach the Gradio interface to the FastAPI app at the root path and rebind
# ``app`` to the result, so one server exposes both the UI and the API.
# NOTE(review): a GET "/" route is already registered on this app earlier in
# the file -- confirm which handler actually answers requests for "/".
app = gr.mount_gradio_app(app, demo, path="/")
# =====================================================
# LAUNCH (Hugging Face compatible)
# =====================================================
# When executed as a script, serve the combined app with uvicorn on all
# interfaces at port 7860. uvicorn is imported here so that merely importing
# this module does not require it.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)