"""FastAPI service that OCRs invoice images and extracts key invoice fields.

Pipeline for ``POST /predict``: decode a base64 image -> grayscale, denoise,
sharpen and boost contrast -> multilingual Tesseract OCR -> optional
translation to English -> text clean-up -> regex extraction of the invoice
number, invoice date and total amount.
"""
import base64
import io
import re
import unicodedata

import cv2
import numpy as np
import pytesseract
from deep_translator import GoogleTranslator
from fastapi import FastAPI
from langdetect import DetectorFactory, detect
from langdetect.lang_detect_exception import LangDetectException
from PIL import Image, ImageEnhance
from pydantic import BaseModel

pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"

# Fix language detection randomness (langdetect is non-deterministic unless seeded)
DetectorFactory.seed = 0

app = FastAPI()

# langdetect (ISO 639-1) code -> Tesseract language name.
# NOTE(review): not referenced by the code visible here; kept for other users.
LANG_CODE_MAP = {
    "en": "eng",
    "ta": "tam",
    "hi": "hin",
    "kn": "kan",
    "ml": "mal",
    "te": "tel",
    "bn": "ben",
    "gu": "guj",
    "pa": "pan",
    "mr": "mar",
}


def clean_ocr_text(text):
    """Normalise raw OCR output while preserving its line structure.

    Steps: Unicode NFKC normalisation, removal of non-ASCII residue (the
    rupee sign is kept because the amount patterns look for it), whitespace
    collapsing within lines, common letter/digit confusion fixes and a few
    punctuation spacing fixes. Line breaks are kept on purpose because
    ``extract_invoice_fields`` matches its label patterns line by line.

    Args:
        text: Raw OCR text. ``None`` or ``""`` yields ``""``.

    Returns:
        The cleaned text, one non-empty line per OCR line.
    """
    if not text:
        return ""

    # Normalize unicode (fix weird diacritics, compatibility forms, spacing)
    text = unicodedata.normalize("NFKC", text)

    # Remove OCR garbage outside ASCII, but keep "₹" for amount detection.
    text = re.sub(r'[^\x00-\x7F₹]+', ' ', text)

    # Collapse runs of spaces/tabs/CRs to a single space, then trim spaces
    # around line breaks and drop blank lines.
    text = re.sub(r'[^\S\n]+', ' ', text)
    text = re.sub(r' ?\n[ \n]*', '\n', text).strip()

    # Common OCR letter/number confusions: (pattern, replacement, flags).
    # The "I" and "l" rules are case-sensitive: "i"/"L" do not look like "1",
    # and a case-insensitive "l" rule would turn units such as "2L" into "21".
    replacements = [
        (r'\bI(?=\d)', '1', 0),                  # I before a digit -> 1
        (r'(?<=\d)O\b', '0', re.IGNORECASE),     # O after a digit -> 0
        (r'\bO(?=\d)', '0', re.IGNORECASE),      # O before a digit -> 0
        (r'(?<=\d)l\b', '1', 0),                 # l after a digit -> 1
        (r'\bS(?=\d)', '5', re.IGNORECASE),      # S before a digit -> 5
        (r'\bBi *11\b', 'Bill', re.IGNORECASE),  # "Bill" misread as "Bi11"
    ]
    for pattern, replacement, flags in replacements:
        text = re.sub(pattern, replacement, text, flags=flags)

    # Fix common punctuation spacing errors. Only spaces are touched (never
    # newlines) so lines are not merged; a space is required before ":" so
    # values such as "12:30" are left alone.
    text = text.replace(" .", ".").replace(" ,", ",")
    text = re.sub(r' +: *', ': ', text)
    text = re.sub(r' +# *', ' #', text)

    return text


def preprocess_image(image):
    """Return a denoised, sharpened, contrast-boosted grayscale copy of *image*.

    Args:
        image: A PIL image (interpreted as RGB) or an array-like. NumPy
            arrays are assumed to use OpenCV's BGR/BGRA channel order; 2-D
            arrays are treated as already grayscale.

    Returns:
        A 2-D NumPy array, or ``None`` (after printing an error) when
        *image* is ``None``.
    """
    if image is None:  # Check if image is None
        print("Error: Input image is None.")
        return None

    if isinstance(image, Image.Image):
        # PIL stores RGB; normalise other modes (L, P, RGBA, ...) to RGB first.
        pixels = np.array(image.convert("RGB"))
        to_gray = cv2.COLOR_RGB2GRAY
    else:
        pixels = np.asarray(image)
        to_gray = cv2.COLOR_BGR2GRAY

    gray = pixels if pixels.ndim == 2 else cv2.cvtColor(pixels, to_gray)

    # Denoise + sharpen
    gray = cv2.medianBlur(gray, 3)
    kernel = np.array([[0, -1, 0],
                       [-1, 5, -1],
                       [0, -1, 0]])
    gray = cv2.filter2D(gray, -1, kernel)

    # Increase contrast
    enhanced = ImageEnhance.Contrast(Image.fromarray(gray)).enhance(2)
    return np.array(enhanced)


def detect_language(text_data):
    """Detect the language of ``text_data['original_text']`` and print it.

    Args:
        text_data: Mapping with an ``"original_text"`` key, e.g. the dict
            returned by ``perform_ocr``.

    Returns:
        The langdetect language code (e.g. ``"hi"``), or ``None`` if
        detection fails for any reason (the error is printed).
    """
    try:
        lang_code = detect(text_data['original_text'])
        language_map = {
            'en': 'English',
            'hi': 'Hindi',
            'ta': 'Tamil',
            'te': 'Telugu',
            'kn': 'Kannada',
        }
        detected_lang = language_map.get(lang_code, lang_code)
        print(f"\nDetected Language: {detected_lang} ({lang_code})")
        return lang_code
    except Exception as e:
        print(f"Language detection error: {e}")
        return None


def perform_ocr(image):
    """OCR *image* with all supported Tesseract languages, translating if needed.

    Args:
        image: Anything ``pytesseract.image_to_string`` accepts (here the
            array returned by ``preprocess_image``).

    Returns:
        A dict with:
          * ``detected_language`` - langdetect code; ``"en"`` when the text
            is empty or has no detectable letters (e.g. digits only).
          * ``original_text`` - stripped OCR output.
          * ``translated_text`` - English translation, or ``None`` when the
            text is English/empty or translation failed.
          * ``translation_error`` - error message if translation raised,
            else ``None``.
    """
    text = pytesseract.image_to_string(
        image,
        lang='eng+tam+kan+hin+tel+mal+ben+guj+pan+mar',
        config='--psm 6'
    ).strip()

    detected_lang = "en"
    if text:
        try:
            detected_lang = detect(text)
        except LangDetectException:
            # Raised for text with no language features (digits/symbols only);
            # there is nothing to translate, so treat it as English.
            detected_lang = "en"

    translated_text = None
    translation_error = None
    if detected_lang != 'en' and text:
        try:
            translated_text = GoogleTranslator(
                source=detected_lang, target="en"
            ).translate(text)
        except Exception as e:
            # Best effort (network call / unsupported language): report the
            # failure separately instead of passing it off as OCR text.
            translation_error = str(e)

    return {
        "detected_language": detected_lang,
        "original_text": text,
        "translated_text": translated_text,
        "translation_error": translation_error,
    }


def extract_field_from_lines(lines, patterns):
    """Return the first regex capture found in *lines*, honouring pattern order.

    Patterns are tried in list order (most specific first); for each pattern
    every line is scanned top to bottom, so an earlier pattern on a later
    line wins over a later, more generic pattern on an earlier line.

    Args:
        lines: Iterable of text lines.
        patterns: Regex strings, matched case-insensitively.

    Returns:
        The first participating capturing group of the first match (or the
        whole match when the pattern has no groups), stripped; ``None`` if
        no pattern matches any line.
    """
    lines = list(lines)  # iterated once per pattern
    for pattern in patterns:
        regex = re.compile(pattern, re.IGNORECASE)
        for line in lines:
            match = regex.search(line)
            if match is None:
                continue
            captured = next(
                (group for group in match.groups() if group is not None),
                match.group(0),
            )
            return captured.strip()
    return None


def extract_invoice_fields(text):
    """Extract invoice number, invoice date and total amount from cleaned text.

    Args:
        text: Cleaned OCR text with one OCR line per text line.

    Returns:
        ``{"invoice_number", "invoice_date", "total_amount"}``; each value is
        a string or ``None`` when not found. ``total_amount`` falls back to
        the largest two-decimal number in the text.
    """
    lines = [line.strip() for line in (text or "").splitlines() if line.strip()]

    # "(?!date)" keeps "Invoice Date ..." / "Order Date ..." from being read
    # as the identifier itself.
    invoice_number_patterns = [
        r'(?:invoice\s*(?:number|no)?\.?\s*[:\-]?\s*)(?!date)([A-Z0-9][A-Z0-9\-_/]{4,})',
        r'(?:invoice\s*(?:number|no|nos|na|#)?\s*[:\-\=\.]?\s*)(?!date)([A-Z0-9][A-Z0-9\-_/\.]{3,})',
        r'(?:receipt\s*(?:number|no|#)?\s*[:\-]?\s*)(?!date)([A-Z0-9][A-Z0-9\-_/\.]{2,})',
        r'(?:^|\s)#\s*([A-Z0-9][A-Z0-9\-_/\.]{2,})',
        r'(?:order\s*)(?!date)([A-Z0-9][A-Z0-9\-_/\.]{2,})',
    ]

    date_patterns = [
        r'(?:invoice\s*date|bill\s*date|receipt\s*date|date)\s*[:\-]?\s*(\d{1,2}[/-][A-Za-z]{3,9}[/-]?\d{2,4})',
        r'(?:invoice\s*date|bill\s*date|receipt\s*date|date)\s*[:\-]?\s*([A-Za-z]{3,9}[ ]?\d{1,2},?[ ]?\d{4})',
        r'(?:invoice\s*date|bill\s*date|receipt\s*date|date)\s*[:\-]?\s*(\d{4}[/-]\d{1,2}[/-]\d{1,2})',
        r'(?:invoice\s*date|bill\s*date|receipt\s*date|date)\s*[:\-]?\s*(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})',
        r'(?:receipt\s*date)\s*[:\-]?\s*(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})',
    ]

    fallback_date_patterns = [
        r'\b(\d{1,2}\s[A-Za-z]{3,9}\s?\d{2,4})\b',
        r'\b(\d{1,2}[/-][A-Za-z]{3,9}[/-]?\d{2,4})\b',
        r'\b([A-Za-z]{3,9}\s*\d{1,2},?\s*\d{4})\b',
        r'\b(\d{4}[/-]\d{1,2}[/-]\d{1,2})\b',
        r'\b(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})\b',
    ]

    # Label, optional ":"/"-", optional currency marker, then the amount.
    amount_tail = r'\s*[:\-]?\s*(?:₹|Rs\.?|INR)?\s*([\d,]+\.\d{2})'
    amount_patterns = [
        r'(?:grand\s*total|total\s*amount|amount\s*payable|net\s*amount)' + amount_tail,
        r'(?<!sub)(?<!sub\s)total' + amount_tail,  # plain "total", not "sub total"
        # Currency is non-capturing so the amount (not "Rs"/"INR") is returned;
        # no "\b" before "₹", which is a non-word character.
        r'(?:₹|\bRs\.?|\bINR)\s*([\d,]+\.\d{2})\b',
        r'rounding' + amount_tail,
    ]

    invoice_number = extract_field_from_lines(lines, invoice_number_patterns)
    invoice_date = extract_field_from_lines(lines, date_patterns)
    total_amount = extract_field_from_lines(lines, amount_patterns)

    if not invoice_date:
        invoice_date = extract_field_from_lines(lines, fallback_date_patterns)

    # Fallback: largest two-decimal number in the OCR text. The pattern takes
    # whole numbers, with or without grouping (1234.56, 1,234.56, 1,23,456.00),
    # instead of only their last digit group.
    if not total_amount:
        amounts = [
            float(found.replace(',', ''))
            for line in lines
            for found in re.findall(r'(?<![\d,])\d[\d,]*\.\d{2}(?!\d)', line)
        ]
        if amounts:
            total_amount = f"{max(amounts):.2f}"

    return {
        "invoice_number": invoice_number,
        "invoice_date": invoice_date,
        "total_amount": total_amount,
    }


# ------------------ API ENDPOINTS ------------------

class ImagePayload(BaseModel):
    # Base64 image data, optionally prefixed with a "data:image/...;base64," header.
    image: str


@app.get("/")
def read_root():
    """Health check."""
    return {"status": "ok", "message": "Invoice OCR API is running!"}


@app.post("/predict")
def predict(payload: ImagePayload):
    """OCR a base64-encoded invoice image and extract its key fields.

    Declared as a plain ``def`` so FastAPI runs it in its worker threadpool:
    Tesseract and the translation request are blocking calls that would
    otherwise stall the event loop for every other request.

    Returns:
        ``{"text": <cleaned text>, "fields": {...}}`` on success, otherwise
        ``{"error": <message>}``.
    """
    try:
        img_base64 = payload.image
        if not img_base64:
            return {"error": "No image provided"}

        # Remove data-URL prefix if present
        if img_base64.startswith("data:image"):
            img_base64 = img_base64.split(",", 1)[-1]

        # Decode base64 to an RGB image; the context manager closes the source.
        image_bytes = base64.b64decode(img_base64)
        with Image.open(io.BytesIO(image_bytes)) as source:
            image = source.convert("RGB")

        processed_img = preprocess_image(image)

        # OCR + translation
        text_data = perform_ocr(processed_img)

        # Prefer the English translation; fall back to the raw OCR text.
        cleaned_text = clean_ocr_text(
            text_data["translated_text"] or text_data["original_text"]
        )

        fields = extract_invoice_fields(cleaned_text)

        return {
            "text": cleaned_text,
            "fields": fields,
        }
    except Exception as e:
        return {"error": str(e)}


if __name__ == "__main__":
    import os
    # NOTE(review): no server start using `port` is visible here — confirm it exists.
    port = int(os.environ.get("PORT", 8080))