Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI | |
| import base64 | |
| from PIL import Image, ImageEnhance | |
| import pytesseract | |
| from langdetect import detect, DetectorFactory | |
| from deep_translator import GoogleTranslator | |
| import re | |
| import numpy as np | |
| import cv2 | |
| import unicodedata | |
| import io | |
| from pydantic import BaseModel | |
# Location of the Tesseract executable that pytesseract should invoke.
pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"

# Fix language detection randomness: seed langdetect with a constant.
DetectorFactory.seed = 0

# Application object for the API defined in this file.
app = FastAPI()
# Two-letter language codes mapped to three-letter codes. The values are the
# same codes that appear in the Tesseract ``lang`` strings used in this file.
LANG_CODE_MAP = dict(
    en="eng",
    ta="tam",
    hi="hin",
    kn="kan",
    ml="mal",
    te="tel",
    bn="ben",
    gu="guj",
    pa="pan",
    mr="mar",
)
def perform_ocr(image):
    """Run multi-language Tesseract OCR on ``image`` and return the text.

    The recognised text is returned with surrounding whitespace stripped.
    If anything raises during OCR, the error is printed and ``None`` is
    returned instead.

    NOTE(review): ``perform_ocr`` is defined again further down in this
    file; those later definitions rebind the name, so this version is not
    the one callers end up with.
    """
    ocr_languages = 'eng+tam+kan+hin+tel+mal+ben+guj+pan+mar'
    try:
        raw_text = pytesseract.image_to_string(
            image,
            lang=ocr_languages,
            config='--psm 6'
        )
        return raw_text.strip()
    except Exception as e:
        print("OCR Error:", e)
        return None
def perform_ocr(image):
    """OCR ``image``, detect the text language and translate it to English.

    A first OCR pass runs with Tesseract's default language. If the detected
    language is not English and has an entry in ``LANG_CODE_MAP``, OCR is
    repeated with that Tesseract language, and the result is translated to
    English.

    Returns:
        A dict with keys ``detected_language``, ``original_text`` and
        ``translated_text`` (``None`` for English text), or ``None`` when
        any step raises; the error is printed in that case.

    NOTE(review): ``perform_ocr`` is defined again further down in this
    file; that later definition rebinds the name, so this version is not
    the one callers end up with.
    """
    try:
        # First OCR pass (default settings)
        text = pytesseract.image_to_string(image, config='--psm 6').strip()

        detected_lang = detect(text)
        if detected_lang == 'en':
            return {
                "detected_language": detected_lang,
                "original_text": text,
                "translated_text": None
            }

        # langdetect reports two-letter codes while Tesseract expects its own
        # three-letter names, so map before re-running OCR. Languages without
        # a mapping keep the first-pass text.
        tesseract_lang = LANG_CODE_MAP.get(detected_lang)
        if tesseract_lang:
            text = pytesseract.image_to_string(
                image,
                lang=tesseract_lang,
                config='--psm 6'
            ).strip()

        # Use the translator this file actually imports (deep_translator).
        translated_text = GoogleTranslator(
            source=detected_lang, target='en'
        ).translate(text)

        return {
            "detected_language": detected_lang,
            "original_text": text,
            "translated_text": translated_text
        }
    except Exception as e:
        print("OCR Error:", e)
        return None
def clean_ocr_text(text):
    """Normalise raw OCR output into a single clean ASCII line.

    Steps, in order:
      1. Unicode NFKC normalisation.
      2. Every run of non-ASCII characters is replaced by a space.
      3. All whitespace (including line breaks) is collapsed to single
         spaces, so the result is always one line.
      4. Common letter/digit OCR confusions are corrected
         (case-insensitively).
      5. Spacing around ``.``, ``,``, ``:`` and ``#`` is tidied.

    Non-ASCII characters are removed *before* whitespace is collapsed so the
    spaces they leave behind cannot survive as double, leading or trailing
    blanks in the result.

    Args:
        text: OCR text as a ``str``.

    Returns:
        The cleaned text, without leading or trailing whitespace.
    """
    # Normalize unicode (fix weird diacritics, spacing issues)
    text = unicodedata.normalize("NFKC", text)

    # Remove weird OCR garbage characters (anything outside ASCII)
    text = re.sub(r'[^\x00-\x7F]+', ' ', text)

    # Remove excessive spaces & fix newlines
    text = re.sub(r'\s+', ' ', text).strip()

    # Common OCR letter/number confusion corrections (global)
    replacements = {
        r'\bI(?=\d)': '1',        # I before a digit -> 1
        r'(?<=\d)O\b': '0',       # O after a digit -> 0
        r'\bO(?=\d)': '0',        # O before a digit -> 0
        r'(?<=\d)l\b': '1',       # l after digit -> 1
        r'\bS(?=\d)': '5',        # S before digit -> 5
        r'\bBi\s*11\b': 'Bill',   # Specific common OCR error
    }
    for pattern, replacement in replacements.items():
        text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)

    # Fix common punctuation errors
    text = text.replace(" .", ".").replace(" ,", ",")
    text = re.sub(r'\s+:\s*', ': ', text)
    text = re.sub(r'\s+#\s*', ' #', text)

    # The colon rule can leave a trailing space at the end of the text.
    return text.strip()
def preprocess_image(image):
    """Convert an image to a denoised, sharpened, high-contrast grayscale array.

    Args:
        image: A NumPy array or any object ``np.array`` can convert (for
            example a PIL image). ``None`` is reported and passed through.

    Returns:
        The processed grayscale image as a NumPy array, or ``None`` when
        ``image`` is ``None``.

    Channel order: arrays passed in directly keep the original BGR
    interpretation, while anything converted here (e.g. a PIL image, whose
    pixels are RGB/RGBA) is converted with the matching RGB code so the red
    and blue weights are not swapped. Input that is already single-channel
    skips the colour conversion.
    """
    if image is None:  # Check if image is None
        print("Error: Input image is None.")
        return None

    came_as_array = isinstance(image, np.ndarray)
    pixels = image if came_as_array else np.array(image)

    if pixels.ndim == 2:
        # Already grayscale; cvtColor would reject a single-channel input.
        gray = pixels
    elif came_as_array:
        gray = cv2.cvtColor(pixels, cv2.COLOR_BGR2GRAY)
    elif pixels.shape[2] == 4:
        gray = cv2.cvtColor(pixels, cv2.COLOR_RGBA2GRAY)
    else:
        gray = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)

    # Denoise + sharpen
    gray = cv2.medianBlur(gray, 3)
    kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    gray = cv2.filter2D(gray, -1, kernel)

    # Increase contrast
    pil_img = Image.fromarray(gray)
    pil_img = ImageEnhance.Contrast(pil_img).enhance(2)

    return np.array(pil_img)
def detect_language(text_data):
    """Detect the language of extracted text.

    Reads ``text_data['original_text']``, prints the detected language
    (with a display name when one is known) and returns the language code.
    If anything raises, the error is printed and ``None`` is returned.
    """
    display_names = {
        'en': 'English',
        'hi': 'Hindi',
        'ta': 'Tamil',
        'te': 'Telugu',
        'kn': 'Kannada'
    }
    try:
        code = detect(text_data['original_text'])
        # Fall back to the raw code when there is no display name for it.
        label = display_names.get(code, code)
        print(f"\nDetected Language: {label} ({code})")
        return code
    except Exception as e:
        print(f"Language detection error: {e}")
        return None
def perform_ocr(image):
    """OCR ``image`` and, for non-English text, translate it to English.

    Args:
        image: Image accepted by ``pytesseract.image_to_string``.

    Returns:
        A dict with:
          * ``detected_language`` - language code from langdetect; ``"en"``
            when no text was recognised or the language could not be
            determined.
          * ``original_text`` - stripped OCR text.
          * ``translated_text`` - English translation, or ``None`` when the
            text is English, empty, or translation failed.

    A failed translation is printed and reported as ``None`` rather than as
    an error string, so callers that fall back to ``original_text`` never
    treat an error message as invoice text.
    """
    # Imported here so the block brings its own dependency into scope.
    from langdetect.lang_detect_exception import LangDetectException

    text = pytesseract.image_to_string(
        image,
        lang='eng+tam+kan+hin+tel+mal+ben+guj+pan+mar',
        config='--psm 6'
    ).strip()

    detected_lang = "en"
    if text:
        try:
            detected_lang = detect(text)
        except LangDetectException:
            # langdetect raises when the text has no usable features
            # (e.g. only digits and punctuation); treat that like English.
            detected_lang = "en"

    translated_text = None
    if text and detected_lang != 'en':
        try:
            translated_text = GoogleTranslator(
                source=detected_lang, target="en"
            ).translate(text)
        except Exception as e:
            print(f"Translation failed: {e}")

    return {
        "detected_language": detected_lang,
        "original_text": text,
        "translated_text": translated_text
    }
def extract_field_from_lines(lines, patterns):
    """Return the first value matched by any pattern in any line.

    Lines are scanned in order and, for each line, patterns are tried in
    order with ``re.IGNORECASE``. For the first match found:

      * if the pattern has capture groups, the first group that actually
        took part in the match is returned;
      * if it has no capture groups, or none took part, the whole match is
        returned.

    Args:
        lines: Iterable of text lines.
        patterns: Iterable of regular-expression strings.

    Returns:
        The matched text with surrounding whitespace stripped, or ``None``
        when nothing matches.
    """
    for line in lines:
        for pattern in patterns:
            match = re.search(pattern, line, flags=re.IGNORECASE)
            if match is None:
                continue
            # An optional group that did not take part is None; skip it
            # rather than calling .strip() on None.
            captured = next(
                (group for group in match.groups() if group is not None),
                None,
            )
            value = match.group(0) if captured is None else captured
            return value.strip()
    return None
def extract_invoice_fields(text):
    """Extract invoice number, date and total amount from OCR text.

    The text is split on newlines and each field is looked up with
    ``extract_field_from_lines``, which applies the patterns
    case-insensitively and returns the first capture group. Every pattern
    here therefore has exactly one capturing group: the value itself.

    Fallbacks:
      * date - unlabelled date-looking strings;
      * total amount - the largest number with two decimals found anywhere,
        formatted with two decimals and no thousands separators.

    Args:
        text: OCR text as a ``str``.

    Returns:
        A dict with keys ``invoice_number``, ``invoice_date`` and
        ``total_amount``; each value is a string or ``None``.
    """
    lines = [line.strip() for line in text.split('\n') if line.strip()]

    # Identifiers must contain at least one digit, otherwise plain words that
    # follow the label ("Invoice Date", "Invoice Number: 12", "Order Details")
    # would be returned as the invoice number.
    digit_in_id = r'(?=[A-Z0-9\-_/]*\d)'
    digit_in_dotted_id = r'(?=[A-Z0-9\-_/.]*\d)'
    invoice_number_patterns = [
        r'invoice\s*(?:number|no)?\.?\s*[:\-]?\s*'
        + digit_in_id + r'([A-Z0-9][A-Z0-9\-_/]{4,})',
        r'invoice\s*(?:number|no|nos|na|#)?\s*[:\-\=\.]?\s*'
        + digit_in_dotted_id + r'([A-Z0-9][A-Z0-9\-_/\.]{3,})',
        r'receipt\s*(?:number|no|#)?\s*[:\-]?\s*'
        + digit_in_dotted_id + r'([A-Z0-9][A-Z0-9\-_/\.]{2,})',
        r'(?:^|\s)#\s*'
        + digit_in_dotted_id + r'([A-Z0-9][A-Z0-9\-_/\.]{2,})',
        r'order\s*'
        + digit_in_dotted_id + r'([A-Z0-9][A-Z0-9\-_/\.]{2,})',
    ]

    date_label = r'(?:invoice\s*date|bill\s*date|receipt\s*date|date)\s*[:\-]?\s*'
    date_patterns = [
        date_label + r'(\d{1,2}[/-][A-Za-z]{3,9}[/-]?\d{2,4})',
        date_label + r'([A-Za-z]{3,9}[ ]?\d{1,2},?[ ]?\d{4})',
        date_label + r'(\d{4}[/-]\d{1,2}[/-]\d{1,2})',
        date_label + r'(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})',
    ]
    fallback_date_patterns = [
        r'\b(\d{1,2}\s[A-Za-z]{3,9}\s?\d{2,4})\b',
        r'\b(\d{1,2}[/-][A-Za-z]{3,9}[/-]?\d{2,4})\b',
        r'\b([A-Za-z]{3,9}\s*\d{1,2},?\s*\d{4})\b',
        r'\b(\d{4}[/-]\d{1,2}[/-]\d{1,2})\b',
        r'\b(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})\b',
    ]

    # The currency marker is deliberately non-capturing: the only capture
    # group must be the number, since the first group is what gets returned.
    currency = r'(?:₹|Rs\.?|INR)'
    amount = r'(\d[\d,]*\.\d{2})'
    amount_patterns = [
        r'(?:total\s*amount|grand\s*total|amount\s*payable|net\s*amount|total|rounding)'
        r'\s*[:\-]?\s*' + currency + r'?\s*' + amount,
        # (?<!\w) instead of \b: a word boundary can never precede the rupee
        # sign at the start of a line or after a space.
        r'(?<!\w)' + currency + r'\s*' + amount + r'\b',
    ]

    invoice_number = extract_field_from_lines(lines, invoice_number_patterns)
    invoice_date = extract_field_from_lines(lines, date_patterns)
    total_amount = extract_field_from_lines(lines, amount_patterns)

    if not invoice_date:
        invoice_date = extract_field_from_lines(lines, fallback_date_patterns)

    # Fallback: largest number in OCR
    if not total_amount:
        # Accept any comma grouping (none, thousands, or Indian lakh style)
        # and anchor on both sides so a long number is never cut in half.
        number_re = r'(?<![\d,.])\d[\d,]*\.\d{2}(?!\d)'
        numbers = [
            float(found.replace(',', ''))
            for line in lines
            for found in re.findall(number_re, line)
        ]
        if numbers:
            total_amount = f"{max(numbers):.2f}"

    return {
        "invoice_number": invoice_number,
        "invoice_date": invoice_date,
        "total_amount": total_amount
    }
| # ------------------ API ENDPOINTS ------------------ | |
class ImagePayload(BaseModel):
    """Request body carrying one image as text."""

    # Base64-encoded image data; the handler that consumes it also accepts a
    # leading "data:image...," prefix.
    image: str
# NOTE(review): this handler carried no route decorator, so it was never
# registered with ``app``. "/" is assumed as the intended path - confirm.
@app.get("/")
def read_root():
    """Health check: report that the service is up."""
    return {"status": "ok", "message": "Invoice OCR API is running!"}
def _run_invoice_pipeline(img_base64):
    """Decode a base64 image and run preprocessing, OCR, cleaning and extraction."""
    # Remove base64 prefix if present
    if img_base64.startswith("data:image"):
        img_base64 = img_base64.split(",", 1)[1]

    # Decode base64 to image
    image_bytes = base64.b64decode(img_base64)
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    # Preprocess
    processed_img = preprocess_image(image)

    # OCR + Translation
    text_data = perform_ocr(processed_img)

    # Cleaning: prefer the English translation when there is one
    cleaned_text = clean_ocr_text(
        text_data["translated_text"] or text_data["original_text"]
    )

    # Extraction
    fields = extract_invoice_fields(cleaned_text)

    return {
        "text": cleaned_text,
        "fields": fields
    }


# NOTE(review): this handler carried no route decorator, so it was never
# registered with ``app``. "/predict" is assumed from the function name - confirm.
@app.post("/predict")
async def predict(payload: ImagePayload):
    """Run the invoice OCR pipeline on a base64-encoded image.

    Args:
        payload: Request body whose ``image`` field holds the base64 image,
            optionally prefixed with ``data:image...,``.

    Returns:
        ``{"text": <cleaned text>, "fields": <extracted fields>}`` on
        success, or ``{"error": <message>}`` when no image was supplied or
        any step raised.
    """
    # Imported here so the block brings its own dependency into scope.
    from fastapi.concurrency import run_in_threadpool

    try:
        img_base64 = payload.image
        if not img_base64:
            return {"error": "No image provided"}

        # OCR and translation are blocking calls; run them in a worker thread
        # so this coroutine does not stall the event loop while they execute.
        return await run_in_threadpool(_run_invoice_pipeline, img_base64)
    except Exception as e:
        return {"error": str(e)}
| if __name__ == "__main__": | |
| import os | |
| port = int(os.environ.get("PORT", 8080)) |