"""Invoice OCR API.

Decodes a base64 image, preprocesses it with OpenCV/PIL, runs multi-language
Tesseract OCR, translates non-English text to English, and extracts the
invoice number, date and total amount with regex heuristics.
"""

import base64
import io
import logging
import re
import unicodedata

import cv2
import numpy as np
import pytesseract
from deep_translator import GoogleTranslator
from fastapi import FastAPI
from langdetect import DetectorFactory, detect
from langdetect.lang_detect_exception import LangDetectException
from PIL import Image, ImageEnhance
from pydantic import BaseModel

logger = logging.getLogger(__name__)

# Fix language detection randomness (langdetect is non-deterministic unless seeded)
DetectorFactory.seed = 0

app = FastAPI()

# Two-letter language codes mapped to Tesseract language codes.
# NOTE(review): not referenced anywhere in this module.
LANG_CODE_MAP = {
    "en": "eng",
    "ta": "tam",
    "hi": "hin",
    "kn": "kan",
    "ml": "mal",
    "te": "tel",
    "bn": "ben",
    "gu": "guj",
    "pa": "pan",
    "mr": "mar",
}

# Tesseract languages tried during OCR, joined with "+".
TESSERACT_LANGS = "eng+tam+kan+hin+tel+mal+ben+guj+pan+mar"


# ------------------ CLEANING ------------------

# Common letter/digit OCR confusions, applied case-insensitively within a line.
_OCR_FIXUPS = (
    (r'\bI(?=\d)', '1'),
    (r'(?<=\d)O\b', '0'),
    (r'\bO(?=\d)', '0'),
    (r'(?<=\d)l\b', '1'),
    (r'\bS(?=\d)', '5'),
    (r'\bBi\s*11\b', 'Bill'),
)


def _clean_line(line):
    """Collapse whitespace and apply the OCR fixups to a single line."""
    line = re.sub(r'\s+', ' ', line).strip()
    for pattern, replacement in _OCR_FIXUPS:
        line = re.sub(pattern, replacement, line, flags=re.IGNORECASE)
    line = line.replace(" .", ".").replace(" ,", ",")
    line = re.sub(r' +: *', ': ', line)
    line = re.sub(r' +# *', ' #', line)
    return line.strip()


def clean_ocr_text(text):
    """Normalize OCR/translated text while preserving its line structure.

    Applies NFKC normalization, rewrites the rupee sign as "Rs.", replaces every
    other non-ASCII run with a space, then cleans each line (whitespace
    collapse, letter/digit fixups, punctuation spacing). Blank lines are
    dropped; the remaining lines are joined with "\\n".
    """
    text = unicodedata.normalize("NFKC", text)
    # Keep the rupee sign in ASCII form so the amount patterns can still see it;
    # all other non-ASCII characters are dropped on the next line.
    text = text.replace("₹", "Rs.")
    text = re.sub(r'[^\x00-\x7F]+', ' ', text)
    # Clean per line: extract_invoice_fields matches line by line, so newlines
    # must survive (collapsing them let matches bleed across unrelated lines).
    cleaned = (_clean_line(line) for line in text.splitlines())
    return "\n".join(line for line in cleaned if line)


# ------------------ PREPROCESSING ------------------

def preprocess_image(image):
    """Prepare an image for OCR.

    Args:
        image: a PIL image (channel order RGB/RGBA) or a NumPy array; 3/4-channel
            arrays are treated as OpenCV BGR/BGRA, as before.

    Returns:
        A single-channel NumPy array: grayscale, 3x3 median-blurred, sharpened
        and with contrast enhanced by a factor of 2.
    """
    from_pil = not isinstance(image, np.ndarray)
    pixels = np.array(image) if from_pil else image
    if pixels.ndim == 2:
        gray = pixels  # already single-channel
    else:
        # np.array(PIL image) is RGB-ordered; only raw arrays are assumed BGR.
        code = cv2.COLOR_RGB2GRAY if from_pil else cv2.COLOR_BGR2GRAY
        gray = cv2.cvtColor(pixels, code)
    gray = cv2.medianBlur(gray, 3)
    sharpen_kernel = np.array([[0, -1, 0],
                               [-1, 5, -1],
                               [0, -1, 0]])
    gray = cv2.filter2D(gray, -1, sharpen_kernel)
    pil_img = ImageEnhance.Contrast(Image.fromarray(gray)).enhance(2)
    return np.array(pil_img)


# ------------------ OCR ------------------

def _detect_language(text):
    """Return langdetect's language code for *text*, defaulting to "en"."""
    if not text:
        return "en"
    try:
        return detect(text)
    except LangDetectException:
        # langdetect raises when it finds no language features (e.g. only
        # digits/symbols); treat such text as English instead of failing.
        return "en"


def perform_ocr(image):
    """OCR *image* and translate the result to English when needed.

    Returns a dict with:
        detected_language: langdetect code ("en" when undetectable or empty).
        original_text: stripped Tesseract output.
        translated_text: English translation, or None when the text is empty,
            already English, or translation failed.
        translation_error: "[Translation failed: ...]" message, or None.
    """
    text = pytesseract.image_to_string(
        image,
        lang=TESSERACT_LANGS,
        config='--psm 6',  # psm 6: assume a single uniform block of text
    ).strip()
    detected_lang = _detect_language(text)

    translated_text = None
    translation_error = None
    if text and detected_lang != 'en':
        try:
            translated_text = GoogleTranslator(source=detected_lang, target="en").translate(text)
        except Exception as e:
            # Best-effort: keep the message out of translated_text so callers
            # fall back to the OCR text instead of parsing the error string.
            translation_error = f"[Translation failed: {e}]"

    return {
        "detected_language": detected_lang,
        "original_text": text,
        "translated_text": translated_text,
        "translation_error": translation_error,
    }


# ------------------ FIELD EXTRACTION ------------------

# Each list is ordered from most to least specific; extract_field_from_lines
# tries the patterns in that order. Invoice-number captures must contain a
# digit so label words ("Date", "Number", "GSTIN") are never returned.
_INVOICE_NUMBER_PATTERNS = (
    r'\binvoice\s*(?:number|no)?\.?\s*[:\-]?\s*((?=[A-Z\-_/]*\d)[A-Z0-9][A-Z0-9\-_/]{4,})',
    r'\binvoice\s*(?:number|no|nos|na|#)?\s*[:\-\=\.]?\s*((?=[A-Z\-_/\.]*\d)[A-Z0-9][A-Z0-9\-_/\.]{3,})',
    r'\breceipt\s*(?:number|no|#)?\s*[:\-]?\s*((?=[A-Z\-_/\.]*\d)[A-Z0-9][A-Z0-9\-_/\.]{2,})',
    r'(?:^|\s)#\s*((?=[A-Z\-_/\.]*\d)[A-Z0-9][A-Z0-9\-_/\.]{2,})',
    r'\border\s*((?=[A-Z\-_/\.]*\d)[A-Z0-9][A-Z0-9\-_/\.]{2,})',
)

_DATE_PATTERNS = (
    r'(?:invoice\s*date|bill\s*date|receipt\s*date|date)\s*[:\-]?\s*(\d{1,2}[/-][A-Za-z]{3,9}[/-]?\d{2,4})',
    r'(?:invoice\s*date|bill\s*date|receipt\s*date|date)\s*[:\-]?\s*([A-Za-z]{3,9}[ ]?\d{1,2},?[ ]?\d{4})',
    r'(?:invoice\s*date|bill\s*date|receipt\s*date|date)\s*[:\-]?\s*(\d{4}[/-]\d{1,2}[/-]\d{1,2})',
    r'(?:invoice\s*date|bill\s*date|receipt\s*date|date)\s*[:\-]?\s*(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})',
)

_FALLBACK_DATE_PATTERNS = (
    r'\b(\d{1,2}\s[A-Za-z]{3,9}\s?\d{2,4})\b',
    r'\b(\d{1,2}[/-][A-Za-z]{3,9}[/-]?\d{2,4})\b',
    r'\b([A-Za-z]{3,9}\s*\d{1,2},?\s*\d{4})\b',
    r'\b(\d{4}[/-]\d{1,2}[/-]\d{1,2})\b',
    r'\b(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})\b',
)

_AMOUNT_PATTERNS = (
    # Unambiguous final-total labels first ...
    r'\b(?:grand\s*total|total\s*amount|amount\s*payable|net\s*amount)\s*[:\-]?\s*(?:₹|Rs\.?|INR)?\s*([\d,]+\.\d{2})',
    # ... then a bare "total"/"rounding" label ("\b" keeps "Subtotal" out) ...
    r'\b(?:total|rounding)\s*[:\-]?\s*(?:₹|Rs\.?|INR)?\s*([\d,]+\.\d{2})',
    # ... then any currency-prefixed amount; the prefix is non-capturing so
    # group 1 is the amount, not the currency symbol.
    r'(?:₹|\bRs\.?|\bINR)\s*([\d,]+\.\d{2})\b',
)

# A money-like number with exactly two decimals, any comma grouping (including
# Indian lakh grouping), not part of a longer number or a dotted date.
_MONEY_RE = re.compile(r'(?<![\d,])\d[\d,]*\.\d{2}(?!\.?\d)')


def extract_field_from_lines(lines, patterns):
    """Return the first match of the highest-priority pattern that hits any line.

    Patterns are tried in order (outer loop) so a specific pattern on a later
    line wins over a generic pattern on an earlier line. Matching is
    case-insensitive. Returns capture group 1 when the pattern has groups,
    otherwise the whole match, stripped; None if nothing matches.
    """
    for pattern in patterns:
        for line in lines:
            match = re.search(pattern, line, flags=re.IGNORECASE)
            if match:
                value = match.group(1) if match.lastindex else match.group(0)
                return value.strip()
    return None


def extract_invoice_fields(text):
    """Extract invoice number, date and total amount from cleaned OCR text.

    Returns a dict with "invoice_number", "invoice_date" and "total_amount";
    each value is a string or None. When no labelled or currency-prefixed
    amount is found, the largest money-like number is used as the total.
    """
    lines = [line.strip() for line in text.splitlines() if line.strip()]

    invoice_number = extract_field_from_lines(lines, _INVOICE_NUMBER_PATTERNS)
    invoice_date = (extract_field_from_lines(lines, _DATE_PATTERNS)
                    or extract_field_from_lines(lines, _FALLBACK_DATE_PATTERNS))
    total_amount = extract_field_from_lines(lines, _AMOUNT_PATTERNS)

    if not total_amount:
        numbers = [
            float(m.replace(',', ''))
            for line in lines
            for m in _MONEY_RE.findall(line)
        ]
        if numbers:
            total_amount = f"{max(numbers):.2f}"

    return {
        "invoice_number": invoice_number,
        "invoice_date": invoice_date,
        "total_amount": total_amount,
    }


# ------------------ API ENDPOINTS ------------------

class ImagePayload(BaseModel):
    """Request body for /predict: base64 image data, optionally as a data: URI."""

    image: str


@app.get("/")
def read_root():
    """Health check."""
    return {"status": "ok", "message": "Invoice OCR API is running!"}


@app.post("/predict")
def predict(payload: ImagePayload):
    """Run the OCR pipeline on a base64-encoded image.

    Declared with plain ``def`` (not ``async def``) so FastAPI runs it in its
    threadpool: Tesseract OCR and the translation HTTP call are blocking and
    would otherwise stall the event loop for all other requests.

    Returns the detected language, cleaned text, extracted fields and any
    translation error, or ``{"error": ...}`` if any step fails.
    """
    try:
        img_base64 = payload.image
        if not img_base64:
            return {"error": "No image provided"}

        # Remove base64 prefix if present
        if img_base64.startswith("data:image"):
            img_base64 = img_base64.split(",", 1)[1]

        image_bytes = base64.b64decode(img_base64)
        with Image.open(io.BytesIO(image_bytes)) as raw_image:
            image = raw_image.convert("RGB")

        processed_img = preprocess_image(image)
        text_data = perform_ocr(processed_img)
        # Prefer the English translation; fall back to the raw OCR text.
        cleaned_text = clean_ocr_text(text_data["translated_text"] or text_data["original_text"])
        fields = extract_invoice_fields(cleaned_text)

        return {
            "language": text_data["detected_language"],
            "text": cleaned_text,
            "fields": fields,
            "translation_error": text_data["translation_error"],
        }
    except Exception as e:
        # Top-level boundary: report the failure to the client, log the traceback.
        logger.exception("Invoice OCR request failed")
        return {"error": str(e)}