Invoice_OCR / main.py
HemanthR007's picture
Update main.py
fd47650 verified
from fastapi import FastAPI
import base64
from PIL import Image, ImageEnhance
import pytesseract
from langdetect import detect, DetectorFactory
from deep_translator import GoogleTranslator
import re
import numpy as np
import cv2
import unicodedata
import io
from pydantic import BaseModel
# Path of the Tesseract executable that pytesseract should invoke.
pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"
# Fix language detection randomness: seed langdetect so repeated calls on the
# same text give the same answer.
DetectorFactory.seed = 0
# FastAPI application object; the route decorators further down register on it.
app = FastAPI()
# Two-letter language codes mapped to the three-letter codes that appear in the
# `lang` argument of the pytesseract calls in this file.
LANG_CODE_MAP = dict(
    en="eng",
    ta="tam",
    hi="hin",
    kn="kan",
    ml="mal",
    te="tel",
    bn="ben",
    gu="guj",
    pa="pan",
    mr="mar",
)
def perform_ocr(image):
    """Extract text from *image* with Tesseract using the combined language set.

    Returns the OCR text with surrounding whitespace stripped. If anything
    raises, the error is printed and None is returned.

    NOTE: `perform_ocr` is defined again further down this module, so this
    definition is rebound before any caller can use it.
    """
    ocr_options = {
        "lang": 'eng+tam+kan+hin+tel+mal+ben+guj+pan+mar',
        "config": '--psm 6',
    }
    try:
        raw_text = pytesseract.image_to_string(image, **ocr_options)
        return raw_text.strip()
    except Exception as e:
        print("OCR Error:", e)
        return None
def perform_ocr(image):
    """Two-pass OCR: detect the language, re-run OCR for it, then translate.

    The first pass uses Tesseract's default language. When the detected
    language is not English and a Tesseract code for it exists in
    ``LANG_CODE_MAP``, OCR is repeated with that language; otherwise the
    first-pass text is kept. Non-English text is translated to English.

    Args:
        image: The image to hand to ``pytesseract.image_to_string``.

    Returns:
        A dict with ``detected_language``, ``original_text`` and
        ``translated_text`` (None for English text), or None if any step
        raises (the error is printed).

    NOTE: `perform_ocr` is defined again further down this module, so this
    definition is rebound before any caller can use it.
    """
    try:
        # First OCR pass (default settings)
        text = pytesseract.image_to_string(image, config='--psm 6').strip()
        detected_lang = detect(text)
        is_english = detected_lang == 'en'

        if not is_english:
            # `detect` yields two-letter codes while Tesseract is addressed
            # with three-letter ones elsewhere in this file, so translate the
            # code before the second pass instead of passing it through raw.
            tesseract_lang = LANG_CODE_MAP.get(detected_lang)
            if tesseract_lang:
                text = pytesseract.image_to_string(
                    image,
                    lang=tesseract_lang,
                    config='--psm 6'
                ).strip()

        translated_text = None
        if not is_english:
            # Use the translator this file actually imports; the previous
            # `Translator` name was never defined.
            translated_text = GoogleTranslator(
                source=detected_lang, target="en"
            ).translate(text)

        return {
            "detected_language": detected_lang,
            "original_text": text,
            "translated_text": translated_text
        }
    except Exception as e:
        print("OCR Error:", e)
        return None
# Letter/digit confusions fixed by clean_ocr_text, applied in this order.
# Every rule is applied case-insensitively, so e.g. a lowercase "o" next to a
# digit is corrected as well.
_OCR_SUBSTITUTIONS = (
    (r'\bI(?=\d)', '1'),        # I before a digit -> 1
    (r'(?<=\d)O\b', '0'),       # O after a digit -> 0
    (r'\bO(?=\d)', '0'),        # O before a digit -> 0
    (r'(?<=\d)l\b', '1'),       # l after digit -> 1
    (r'\bS(?=\d)', '5'),        # S before digit -> 5
    (r'\bBi\s*11\b', 'Bill'),   # Specific common OCR error
)


def _collapse_whitespace(text):
    """Squeeze whitespace inside each line and drop empty lines."""
    squeezed = (re.sub(r'\s+', ' ', line).strip() for line in text.splitlines())
    return '\n'.join(line for line in squeezed if line)


def clean_ocr_text(text):
    """Normalise raw OCR output into tidy ASCII text, one entry per line.

    Steps: Unicode NFKC normalisation, whitespace squeezing per line,
    letter/digit confusion fixes, punctuation spacing fixes, removal of
    non-ASCII characters, and a final whitespace pass.

    Line breaks are preserved (as single ``\\n``) so that line-based parsing
    of the result still sees the document's rows.

    Args:
        text: Raw OCR text. None or an empty string yields ``""``.

    Returns:
        The cleaned text.
    """
    if not text:
        return ""

    # Normalize unicode (fix weird diacritics, spacing issues)
    text = unicodedata.normalize("NFKC", text)
    text = _collapse_whitespace(text)

    for pattern, replacement in _OCR_SUBSTITUTIONS:
        text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)

    # Fix common punctuation errors. Only spaces are touched here so that the
    # colon/hash fixes never glue two lines together.
    text = text.replace(" .", ".").replace(" ,", ",")
    text = re.sub(r' +: *', ': ', text)
    text = re.sub(r' +# *', ' #', text)

    # Remove non-ASCII leftovers, then tidy again: each removed run leaves a
    # space behind, and the colon fix can leave one at the end of a line.
    text = re.sub(r'[^\x00-\x7F]+', ' ', text)
    return _collapse_whitespace(text)
def preprocess_image(image):
    """Turn *image* into a denoised, sharpened, contrast-boosted grayscale array.

    Args:
        image: A PIL image, a NumPy array, or anything ``np.array`` accepts.
            PIL images are read as RGB. Arrays with a channel axis are
            assumed to be in OpenCV's BGR order (as before). 2-D arrays are
            taken to be grayscale already.

    Returns:
        A NumPy array with the processed grayscale image, or None when
        *image* is None (a message is printed in that case).
    """
    if image is None:  # Check if image is None
        print("Error: Input image is None.")
        return None

    if isinstance(image, Image.Image):
        # PIL stores channels as R, G, B, so the BGR conversion code would
        # apply the red and blue luminance weights the wrong way round.
        image = np.array(image.convert("RGB"))
        to_gray = cv2.COLOR_RGB2GRAY
    else:
        if not isinstance(image, np.ndarray):
            image = np.array(image)
        to_gray = cv2.COLOR_BGR2GRAY

    # cvtColor rejects single-channel input, so pass grayscale through as is.
    gray = image if image.ndim == 2 else cv2.cvtColor(image, to_gray)

    # Denoise + sharpen
    gray = cv2.medianBlur(gray, 3)
    kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    gray = cv2.filter2D(gray, -1, kernel)

    # Increase contrast
    pil_img = Image.fromarray(gray)
    pil_img = ImageEnhance.Contrast(pil_img).enhance(2)
    return np.array(pil_img)
def detect_language(text_data):
    """Detect and report the language of ``text_data['original_text']``.

    Prints the detected language (with a readable name when one is known)
    and returns the language code. If anything raises, the error is printed
    and None is returned.
    """
    display_names = {
        'en': 'English',
        'hi': 'Hindi',
        'ta': 'Tamil',
        'te': 'Telugu',
        'kn': 'Kannada'
    }
    try:
        lang_code = detect(text_data['original_text'])
        # Fall back to the bare code when there is no readable name for it
        if lang_code in display_names:
            detected_lang = display_names[lang_code]
        else:
            detected_lang = lang_code
        print(f"\nDetected Language: {detected_lang} ({lang_code})")
    except Exception as e:
        print(f"Language detection error: {e}")
        return None
    return lang_code
def perform_ocr(image):
    """Run multi-language OCR on *image* and translate the text to English.

    Args:
        image: The image to hand to ``pytesseract.image_to_string``.

    Returns:
        A dict with:
          * ``detected_language`` - code returned by ``detect``; ``"en"``
            when the text is empty or its language cannot be detected.
          * ``original_text`` - the stripped OCR text.
          * ``translated_text`` - the English translation, or None when the
            text is English, empty, or the translation failed.

    Raises:
        Whatever ``pytesseract.image_to_string`` raises; OCR errors are not
        caught here.
    """
    from langdetect.lang_detect_exception import LangDetectException

    text = pytesseract.image_to_string(
        image,
        lang='eng+tam+kan+hin+tel+mal+ben+guj+pan+mar',
        config='--psm 6'
    ).strip()

    detected_lang = "en"
    if text:
        try:
            detected_lang = detect(text)
        except LangDetectException as e:
            # Raised for text without usable features (e.g. only digits and
            # punctuation); treat it like the empty-text case.
            print("Language detection error:", e)

    translated_text = None
    if detected_lang != 'en' and text:
        try:
            translated_text = GoogleTranslator(
                source=detected_lang, target="en"
            ).translate(text)
        except Exception as e:
            # Leave the translation as None so callers fall back to the
            # original text rather than parsing an error message as if it
            # were the invoice.
            print("Translation error:", e)
            translated_text = None

    return {
        "detected_language": detected_lang,
        "original_text": text,
        "translated_text": translated_text
    }
# All patterns below are matched case-insensitively by extract_field_from_lines.

# Invoice/receipt/order identifiers. Each captured token must contain at least
# one digit (the look-ahead), so that ordinary words following a label, such
# as "Invoice Date" or "Tax Invoice Original", are not taken for a number.
_INVOICE_NUMBER_PATTERNS = (
    r'invoice\s*(?:number|no)?\.?\s*[:\-]?\s*(?!date)'
    r'(?=[A-Z0-9\-_/]*\d)([A-Z0-9][A-Z0-9\-_/]{4,})',
    r'invoice\s*(?:number|no|nos|na|#)?\s*[:\-=.]?\s*'
    r'(?=[A-Z0-9\-_/.]*\d)([A-Z0-9][A-Z0-9\-_/.]{3,})',
    r'receipt\s*(?:number|no|#)?\s*[:\-]?\s*'
    r'(?=[A-Z0-9\-_/.]*\d)([A-Z0-9][A-Z0-9\-_/.]{2,})',
    r'(?:^|\s)#\s*(?=[A-Z0-9\-_/.]*\d)([A-Z0-9][A-Z0-9\-_/.]{2,})',
    r'order\s*(?:number|no|#)?\s*[:\-]?\s*'
    r'(?=[A-Z0-9\-_/.]*\d)([A-Z0-9][A-Z0-9\-_/.]{2,})',
)

# Dates introduced by a label.
_DATE_LABEL = r'(?:invoice\s*date|bill\s*date|receipt\s*date|date)\s*[:\-]?\s*'
_DATE_PATTERNS = (
    _DATE_LABEL + r'(\d{1,2}[/-][A-Za-z]{3,9}[/-]?\d{2,4})',
    _DATE_LABEL + r'([A-Za-z]{3,9}[ ]?\d{1,2},?[ ]?\d{4})',
    _DATE_LABEL + r'(\d{4}[/-]\d{1,2}[/-]\d{1,2})',
    _DATE_LABEL + r'(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})',
)

# Dates without a label, tried only when no labelled date was found.
_FALLBACK_DATE_PATTERNS = (
    r'\b(\d{1,2}\s[A-Za-z]{3,9}\s?\d{2,4})\b',
    r'\b(\d{1,2}[/-][A-Za-z]{3,9}[/-]?\d{2,4})\b',
    r'\b([A-Za-z]{3,9}\s*\d{1,2},?\s*\d{4})\b',
    r'\b(\d{4}[/-]\d{1,2}[/-]\d{1,2})\b',
    r'\b(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})\b',
)

# Amounts. The currency marker is non-capturing so the only group is the
# number itself.
_CURRENCY = r'(?:₹|Rs\.?|INR)'
_AMOUNT_PATTERNS = (
    r'(?:total\s*amount|grand\s*total|amount\s*payable|net\s*amount|total|rounding)'
    r'\s*[:\-]?\s*' + _CURRENCY + r'?\s*([\d,]+\.\d{2})',
    r'(?<!\w)' + _CURRENCY + r'\s*([\d,]+\.\d{2})\b',
)

# Any number with two decimals, with or without comma grouping of any width
# (1234.56, 1,234.56, 1,23,456.00).
_ANY_AMOUNT = r'(?<![\d,])\d[\d,]*\.\d{2}(?!\d)'


def extract_field_from_lines(lines, patterns):
    """Return the first pattern hit found while scanning *lines* in order.

    Lines are scanned top to bottom and, for each line, *patterns* are tried
    in order with ``re.IGNORECASE``. For the first match, the value is the
    first capturing group that took part in the match, or the whole match if
    the pattern has no groups or none took part.

    Args:
        lines: Iterable of text lines.
        patterns: Iterable of regular-expression strings.

    Returns:
        The matched value with surrounding whitespace stripped, or None when
        nothing matches.
    """
    for line in lines:
        for pattern in patterns:
            match = re.search(pattern, line, flags=re.IGNORECASE)
            if not match:
                continue
            value = next(
                (group for group in match.groups() if group is not None),
                match.group(0),
            )
            return value.strip()
    return None


def extract_invoice_fields(text):
    """Pull invoice number, invoice date and total amount out of OCR text.

    Args:
        text: OCR text; it is split into lines and blank lines are ignored.
            None or an empty string yields all-None fields.

    Returns:
        A dict with ``invoice_number``, ``invoice_date`` and
        ``total_amount``. Each value is a string or None. A total found via
        a pattern is returned as written in the text; a total found via the
        fallback is the largest two-decimal number, formatted with two
        decimals and no separators.
    """
    lines = [line.strip() for line in (text or "").splitlines() if line.strip()]

    invoice_number = extract_field_from_lines(lines, _INVOICE_NUMBER_PATTERNS)
    invoice_date = extract_field_from_lines(lines, _DATE_PATTERNS)
    total_amount = extract_field_from_lines(lines, _AMOUNT_PATTERNS)

    if not invoice_date:
        invoice_date = extract_field_from_lines(lines, _FALLBACK_DATE_PATTERNS)

    # Fallback: largest number in OCR
    if not total_amount:
        numbers = [
            float(found.replace(',', ''))
            for line in lines
            for found in re.findall(_ANY_AMOUNT, line)
        ]
        if numbers:
            total_amount = f"{max(numbers):.2f}"

    return {
        "invoice_number": invoice_number,
        "invoice_date": invoice_date,
        "total_amount": total_amount
    }
# ------------------ API ENDPOINTS ------------------
# Request body accepted by POST /predict.
class ImagePayload(BaseModel):
    # Base64-encoded image. A leading "data:image...," header is allowed;
    # predict() removes it before decoding.
    image: str
# Liveness probe: GET / answers with a fixed status payload.
@app.get("/")
def read_root():
    health = {"status": "ok", "message": "Invoice OCR API is running!"}
    return health
# POST /predict: decode the base64 image, preprocess it, run OCR (plus
# translation), clean the text and extract the invoice fields.
# Any failure along the way is reported as {"error": <message>}.
# NOTE: every pipeline step is a plain synchronous call made from inside this
# async handler.
@app.post("/predict")
async def predict(payload: ImagePayload):
    try:
        encoded = payload.image
        if not encoded:
            return {"error": "No image provided"}

        # Drop a data-URL header by taking the segment right after its comma
        if encoded.startswith("data:image"):
            encoded = encoded.split(",")[1]

        raw_bytes = base64.b64decode(encoded)
        pil_image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")

        text_data = perform_ocr(preprocess_image(pil_image))

        # Prefer the translation when there is one, else the text as read
        source_text = text_data["translated_text"] or text_data["original_text"]
        cleaned_text = clean_ocr_text(source_text)
        fields = extract_invoice_fields(cleaned_text)
    except Exception as e:
        return {"error": str(e)}

    return {
        "text": cleaned_text,
        "fields": fields
    }
if __name__ == "__main__":
import os
port = int(os.environ.get("PORT", 8080))