Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from decimal import Decimal, InvalidOperation | |
| from fastapi.encoders import jsonable_encoder | |
| from starlette.responses import JSONResponse | |
| import pytesseract | |
| import cv2 | |
| import os | |
| from PIL import Image | |
| import json | |
| import unicodedata | |
| from pdf2image import convert_from_bytes | |
| from pypdf import PdfReader | |
| import numpy as np | |
| from typing import List, Any | |
| import io | |
| import logging | |
| import time | |
| import asyncio | |
| import psutil | |
| import cachetools | |
| import hashlib | |
| import re | |
| import google.generativeai as genai | |
| from dotenv import load_dotenv | |
| # --- START OF MODIFICATIONS --- | |
| # 1. Define a custom JSON encoder function | |
def custom_encoder(obj: Any) -> Any:
    """Serialize values that ``json.dumps`` cannot handle natively.

    ``Decimal`` values are rendered as plain decimal-notation strings, never
    in scientific notation. The digits are taken from the Decimal itself
    instead of going through ``float``, so large or very precise amounts are
    not rounded and tiny values do not collapse to ``"0"``.

    Args:
        obj: The object handed to the ``default`` hook of ``json.dumps``.

    Returns:
        For a ``Decimal``: ``"0.0"`` for any zero, otherwise the fixed-point
        digits with trailing fractional zeros removed. Non-finite Decimals
        and values whose exponent is beyond +/-1000 are returned as
        ``str(obj)``. Any other type is delegated to ``jsonable_encoder``.
    """
    if not isinstance(obj, Decimal):
        return jsonable_encoder(obj)
    if not obj.is_finite():
        # NaN / Infinity have no fixed-point form.
        return str(obj)
    if obj.is_zero():
        return "0.0"
    # Guard against absurd exponents (e.g. 1E+999999999) that would expand
    # into an enormous string of zeros.
    if abs(obj.adjusted()) > 1000:
        return str(obj)
    plain = format(obj, "f")
    if "." in plain:
        # Only strip zeros after the decimal point; "1500000" must stay intact.
        plain = plain.rstrip("0").rstrip(".")
    return plain
def custom_decimal_parser(s):
    """Parse a numeric token into a ``Decimal`` written in plain notation.

    Intended as the ``parse_float`` hook of ``json.loads``. The token is
    parsed directly as a ``Decimal`` (no ``float`` round-trip), so every digit
    of the source text is preserved.

    Args:
        s: The numeric text (anything whose ``str()`` is a valid number).

    Returns:
        ``Decimal('0.0')`` for any zero; otherwise a ``Decimal`` rebuilt from
        the fixed-point spelling of the value with trailing fractional zeros
        removed. Non-finite values and values whose exponent is beyond
        +/-1000 are returned as parsed.

    Raises:
        decimal.InvalidOperation: If ``s`` is not a valid number.
    """
    value = Decimal(str(s))
    if not value.is_finite():
        return value
    if value.is_zero():
        return Decimal("0.0")
    # Refuse to expand absurd exponents into a giant run of zeros.
    if abs(value.adjusted()) > 1000:
        return value
    plain = format(value, "f")
    if "." in plain:
        plain = plain.rstrip("0").rstrip(".")
    return Decimal(plain)
def fix_scientific_notation_in_json(json_str):
    """Rewrite scientific-notation numbers in JSON text as plain decimals.

    The text is scanned token by token so that JSON string literals are
    skipped: an identifier such as ``"INV-12E5"`` is left alone. Only two
    kinds of token are rewritten:

    * a bare number written with an exponent (``9e-7`` -> ``0.0000009``);
    * a string literal whose entire content is such a number
      (``"9E-7"`` -> ``"0.0000009"``).

    Conversion uses ``Decimal`` so no digits are lost, and a large exponent
    such as ``1e400`` expands to digits instead of the invalid token ``inf``.

    Args:
        json_str: JSON text, typically extracted from an LLM response.

    Returns:
        The text with exponent-form numbers expanded. Tokens that cannot be
        parsed, or whose exponent is beyond +/-1000, are left unchanged.
    """
    log = logging.getLogger(__name__)
    sci_number = r'-?\d+\.?\d*[eE][+-]?\d+'
    # A complete JSON string literal (with escapes) OR a bare exponent number.
    # Matching strings first is what keeps numbers inside them untouched.
    token_re = re.compile(r'"(?:\\.|[^"\\])*"|' + sci_number)
    quoted_number_re = re.compile('"(' + sci_number + ')"')

    def expand(number_text):
        """Return ``number_text`` spelled without an exponent."""
        try:
            value = Decimal(number_text)
        except InvalidOperation as exc:
            log.error("Error converting %s: %s", number_text, exc)
            return number_text
        if value.is_zero():
            return "0.0"
        # Do not expand absurd exponents into a giant run of zeros.
        if abs(value.adjusted()) > 1000:
            return number_text
        plain = format(value, "f")
        if "." in plain:
            plain = plain.rstrip("0").rstrip(".")
        return plain

    def replace_token(match):
        token = match.group(0)
        if not token.startswith('"'):
            return expand(token)
        quoted = quoted_number_re.fullmatch(token)
        if quoted is None:
            # Ordinary string value: never touch its content.
            return token
        return '"' + expand(quoted.group(1)) + '"'

    fixed = token_re.sub(replace_token, json_str)
    if fixed != json_str:
        log.info("JSON transformation occurred")
        log.info("Original: %s...", json_str[:200])
        log.info("Fixed: %s...", fixed[:200])
    return fixed
def convert_scientific_decimals(obj):
    """Recursively normalise ``Decimal`` values inside dicts and lists.

    Each finite ``Decimal`` is rebuilt from its fixed-point spelling with
    trailing fractional zeros removed, so a value such as ``Decimal('1.5E+6')``
    becomes ``Decimal('1500000')``. The conversion is exact: the value is
    never passed through ``float`` and is never rounded.

    Note that ``str()`` of a very small ``Decimal`` still uses an exponent;
    that is a property of ``Decimal`` itself, not of this function.

    Args:
        obj: Any value; dicts and lists are traversed, other types are
            returned unchanged.

    Returns:
        A new structure of the same shape. Zero becomes ``Decimal('0.0')``.
        Non-finite Decimals and values whose exponent is beyond +/-1000 are
        returned as they are.
    """
    if isinstance(obj, dict):
        return {key: convert_scientific_decimals(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_scientific_decimals(item) for item in obj]
    if not isinstance(obj, Decimal):
        return obj
    if not obj.is_finite():
        return obj
    if obj.is_zero():
        return Decimal("0.0")
    # Do not expand absurd exponents into a giant run of zeros.
    if abs(obj.adjusted()) > 1000:
        return obj
    plain = format(obj, "f")
    if "." in plain:
        plain = plain.rstrip("0").rstrip(".")
    return Decimal(plain)
def force_decimal_format(data):
    """Convert the numeric ``value`` of known amount fields to ``Decimal``.

    Walks dicts and lists recursively. For each key in the set of amount
    fields whose value is a dict containing a numeric ``'value'`` entry, that
    entry is replaced by an exactly converted ``Decimal`` (trailing fractional
    zeros removed, zero as ``Decimal('0.0')``). All other entries of that dict
    -- ``'accuracy'`` and anything else -- are carried over as they are, and a
    missing ``'accuracy'`` key is not an error.

    Args:
        data: Parsed JSON-like data (dict, list or scalar).

    Returns:
        A new structure of the same shape; the input is not modified.
    """
    if isinstance(data, list):
        return [force_decimal_format(item) for item in data]
    if not isinstance(data, dict):
        return data

    amount_fields = {
        'unit_price', 'total_price', 'tax_amount', 'discount', 'net_amount',
        'sub_total', 'tax_total', 'discount_total', 'total_amount', 'tax_rate',
    }
    result = {}
    for key, value in data.items():
        wrapped = key in amount_fields and isinstance(value, dict) and 'value' in value
        if not wrapped:
            result[key] = force_decimal_format(value)
            continue

        raw = value['value']
        if not isinstance(raw, (Decimal, float, int)):
            # Non-numeric payload (e.g. an empty string): keep as received.
            result[key] = value
            continue

        if isinstance(raw, Decimal):
            number = raw
        elif isinstance(raw, int):
            number = Decimal(raw)
        else:
            # repr() gives the shortest text that round-trips the float.
            number = Decimal(repr(raw))

        if number.is_finite():
            if number.is_zero():
                number = Decimal("0.0")
            elif abs(number.adjusted()) <= 1000:
                # Skipped for absurd exponents, which would expand into a
                # giant run of zeros.
                plain = format(number, "f")
                if "." in plain:
                    plain = plain.rstrip("0").rstrip(".")
                number = Decimal(plain)

        result[key] = {**value, 'value': number}
    return result
| # --- END OF MODIFICATIONS --- | |
app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load environment variables (must run before the API key is read below)
load_dotenv()

# Configure Gemini API
api_key = os.getenv("GOOGLE_API_KEY")
if not api_key:
    logger.error("GOOGLE_API_KEY not set")
    # This runs at import time, outside any request, so an HTTPException has
    # nothing to turn into a response. Fail start-up with a plain error.
    raise RuntimeError("GOOGLE_API_KEY not set")
genai.configure(api_key=api_key)
model = genai.GenerativeModel("gemini-2.5-flash")

# Set Tesseract path
pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"

# In-memory caches: up to 100 entries each, expiring after one hour.
# raw_text_cache is keyed by file-content hash, structured_data_cache by
# extracted-text hash.
raw_text_cache = cachetools.TTLCache(maxsize=100, ttl=3600)
structured_data_cache = cachetools.TTLCache(maxsize=100, ttl=3600)
def log_memory_usage():
    """Return a log-ready string with this process's resident memory in MB.

    Despite its name the function does not log anything itself; callers embed
    the returned text in their own log messages.
    """
    rss_bytes = psutil.Process().memory_info().rss
    rss_megabytes = rss_bytes / 1024 / 1024
    return f"Memory usage: {rss_megabytes:.2f} MB"
def get_file_hash(file_bytes):
    """Return the hexadecimal MD5 digest of ``file_bytes``."""
    digest = hashlib.md5()
    digest.update(file_bytes)
    return digest.hexdigest()
def get_text_hash(raw_text):
    """Return the hexadecimal MD5 digest of ``raw_text`` encoded as UTF-8."""
    encoded = raw_text.encode('utf-8')
    digest = hashlib.md5()
    digest.update(encoded)
    return digest.hexdigest()
async def process_image(img_bytes, filename, idx):
    """Run OCR on one encoded image and return the recognised text.

    The image is normalised to RGB before the grayscale conversion, so inputs
    with an alpha channel, a palette or a single channel are handled instead
    of failing inside OpenCV. The blocking decode + OCR work runs in the
    default executor so the event loop stays responsive.

    Args:
        img_bytes: Raw bytes of the image file.
        filename: Name used in log messages only.
        idx: Image index used in log messages only.

    Returns:
        The OCR text followed by a newline, or ``""`` if anything fails (the
        failure is logged, not raised).
    """
    start_time = time.time()
    logger.info(f"Starting OCR for {filename} image {idx}, {log_memory_usage()}")

    def _decode_and_ocr():
        # Copy the pixels out while the image handle is open, then release it.
        with Image.open(io.BytesIO(img_bytes)) as img:
            rgb_pixels = np.array(img.convert("RGB"))
        gray = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2GRAY)
        img_pil = Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
        custom_config = r'--oem 1 --psm 6 -l eng+ara'
        return pytesseract.image_to_string(img_pil, config=custom_config)

    try:
        loop = asyncio.get_running_loop()
        page_text = await loop.run_in_executor(None, _decode_and_ocr)
        logger.info(f"Completed OCR for {filename} image {idx}, took {time.time() - start_time:.2f} seconds, {log_memory_usage()}")
        return page_text + "\n"
    except Exception as e:
        logger.error(f"OCR failed for {filename} image {idx}: {str(e)}, {log_memory_usage()}")
        return ""
async def process_pdf_page(img, page_idx):
    """Run OCR on one rendered PDF page and return the recognised text.

    The blocking OCR work runs in the default executor, so several pages
    awaited together (e.g. with ``asyncio.gather``) are processed
    concurrently instead of one after another on the event loop.

    Args:
        img: The page as a PIL image.
        page_idx: Page index used in log messages only.

    Returns:
        The OCR text followed by a newline, or ``""`` if anything fails (the
        failure is logged, not raised).
    """
    start_time = time.time()
    logger.info(f"Starting OCR for PDF page {page_idx}, {log_memory_usage()}")

    def _ocr_page():
        # Normalise the mode first so the grayscale conversion always
        # receives three channels.
        rgb_pixels = np.array(img.convert("RGB"))
        gray = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2GRAY)
        img_pil = Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
        custom_config = r'--oem 1 --psm 6 -l eng+ara'
        return pytesseract.image_to_string(img_pil, config=custom_config)

    try:
        loop = asyncio.get_running_loop()
        page_text = await loop.run_in_executor(None, _ocr_page)
        logger.info(f"Completed OCR for PDF page {page_idx}, took {time.time() - start_time:.2f} seconds, {log_memory_usage()}")
        return page_text + "\n"
    except Exception as e:
        logger.error(f"OCR failed for PDF page {page_idx}: {str(e)}, {log_memory_usage()}")
        return ""
async def process_with_gemini(filename: str, raw_text: str):
    """Ask Gemini to turn raw invoice text into structured data.

    Results are cached by the MD5 of the full (untruncated) text. Text longer
    than 20000 characters is truncated before it is sent. The model reply is
    cut down to its outermost ``{...}`` span, exponent-form numbers are
    expanded, and the JSON is parsed with floats as ``Decimal``.

    Args:
        filename: Name used in log messages only.
        raw_text: Text extracted from the invoice.

    Returns:
        The parsed structure on success, or ``{"error": "..."}`` if the model
        call or the parsing fails (the failure is logged, not raised). Error
        results are not cached.
    """
    start_time = time.time()
    logger.info(f"Starting Gemini processing for {filename}, {log_memory_usage()}")
    text_hash = get_text_hash(raw_text)
    # A single lookup avoids a KeyError if the entry expires between a
    # membership test and the read.
    cached = structured_data_cache.get(text_hash)
    if cached is not None:
        logger.info(f"Structured data cache hit for {filename}, {log_memory_usage()}")
        return cached
    if len(raw_text) > 20000:
        raw_text = raw_text[:20000]
        logger.info(f"Truncated raw text for {filename} to 20000 characters, {log_memory_usage()}")
    try:
        # The prompt lines are deliberately flush-left: indenting them would
        # change the text sent to the model.
        prompt = f"""You are an intelligent invoice data extractor. Given raw text from an invoice (in English or other languages),
extract key business fields into the specified JSON format. Return each field along with an estimated accuracy score between 0 and 1.
- Accuracy reflects your confidence in the correctness of each field.
- Handle synonyms (e.g., 'total' = 'net', 'tax' = 'GST'/'TDS').
- Detect currency from symbols ($, ₹, €) or keywords (USD, INR, EUR); default to USD if unclear.
- The 'items' list may have multiple entries, each with detailed attributes.
- If a field is missing or not found, return an empty value (`""` or `0`) and set `accuracy` to `0.0`.
- Convert any date found in format: YYYY-MM-DD
CRITICAL: ALL numeric values must be in full decimal notation. NEVER EVER use scientific notation or exponential format:
- CORRECT: 0.0000009, 0.00000015, 0.0000002, 1500000, 0.00123
- ABSOLUTELY FORBIDDEN: 9e-7, 9E-7, 1.5e-7, 1.5E-7, 2e-7, 2E-7, 1.5e+6, 1.23e-3, any number with 'e' or 'E'
- For very small numbers like 0.0000009, you MUST write out all the zeros: 0.0000009
- For large numbers like 1500000, you MUST write out all the digits: 1500000
- This is MANDATORY for: unit_price, total_price, tax_amount, discount, net_amount, sub_total, tax_total, discount_total, total_amount
- Example: if unit price is 9 * 10^-7, write it as 0.0000009, NOT 9e-7 or 9E-7
Raw text:
{raw_text}
Output JSON:
{{
"invoice": {{
"invoice_number": {{"value": "", "accuracy": 0.0}},
"invoice_date": {{"value": "", "accuracy": 0.0}},
"due_date": {{"value": "", "accuracy": 0.0}},
"purchase_order_number": {{"value": "", "accuracy": 0.0}},
"vendor": {{
"vendor_id": {{"value": "", "accuracy": 0.0}},
"name": {{"value": "", "accuracy": 0.0}},
"address": {{
"line1": {{"value": "", "accuracy": 0.0}},
"line2": {{"value": "", "accuracy": 0.0}},
"city": {{"value": "", "accuracy": 0.0}},
"state": {{"value": "", "accuracy": 0.0}},
"postal_code": {{"value": "", "accuracy": 0.0}},
"country": {{"value": "", "accuracy": 0.0}}
}},
"contact": {{
"email": {{"value": "", "accuracy": 0.0}},
"phone": {{"value": "", "accuracy": 0.0}}
}},
"tax_id": {{"value": "", "accuracy": 0.0}}
}},
"buyer": {{
"buyer_id": {{"value": "", "accuracy": 0.0}},
"name": {{"value": "", "accuracy": 0.0}},
"address": {{
"line1": {{"value": "", "accuracy": 0.0}},
"line2": {{"value": "", "accuracy": 0.0}},
"city": {{"value": "", "accuracy": 0.0}},
"state": {{"value": "", "accuracy": 0.0}},
"postal_code": {{"value": "", "accuracy": 0.0}},
"country": {{"value": "", "accuracy": 0.0}}
}},
"contact": {{
"email": {{"value": "", "accuracy": 0.0}},
"phone": {{"value": "", "accuracy": 0.0}}
}},
"tax_id": {{"value": "", "accuracy": 0.0}}
}},
"items": [
{{
"item_id": {{"value": "", "accuracy": 0.0}},
"description": {{"value": "", "accuracy": 0.0}},
"quantity": {{"value": 0, "accuracy": 0.0}},
"unit_of_measure": {{"value": "", "accuracy": 0.0}},
"unit_price": {{"value": 0.0, "accuracy": 0.0}},
"total_price": {{"value": 0.0, "accuracy": 0.0}},
"tax_rate": {{"value": 0.0, "accuracy": 0.0}},
"tax_amount": {{"value": 0.0, "accuracy": 0.0}},
"discount": {{"value": 0.0, "accuracy": 0.0}},
"net_amount": {{"value": 0.0, "accuracy": 0.0}}
}}
],
"sub_total": {{"value": 0.0, "accuracy": 0.0}},
"tax_total": {{"value": 0.0, "accuracy": 0.0}},
"discount_total": {{"value": 0.0, "accuracy": 0.0}},
"total_amount": {{"value": 0.0, "accuracy": 0.0}},
"currency": {{"value": "", "accuracy": 0.0}}
}}
}}
"""
        # generate_content is a blocking network call; run it in the default
        # executor so other requests are not stalled while waiting.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, model.generate_content, prompt)
        llm_output = response.text

        # Keep only the outermost {...} span; the reply may wrap the JSON in
        # prose or a code fence.
        json_start = llm_output.find("{")
        json_end = llm_output.rfind("}") + 1
        if json_start == -1 or json_end <= json_start:
            raise ValueError("No JSON object found in Gemini response")
        json_str = llm_output[json_start:json_end]
        logger.info(f"Extracted JSON before fix: {json_str}")

        json_str = fix_scientific_notation_in_json(json_str)
        structured_data = json.loads(json_str, parse_float=custom_decimal_parser)
        structured_data = convert_scientific_decimals(structured_data)
        structured_data = force_decimal_format(structured_data)
        structured_data_cache[text_hash] = structured_data
        logger.info(f"Gemini processing for {filename}, took {time.time() - start_time:.2f} seconds, {log_memory_usage()}")
        return structured_data
    except Exception as e:
        logger.error(f"Gemini processing failed for {filename}: {str(e)}, {log_memory_usage()}")
        return {"error": f"Gemini processing failed: {str(e)}"}
def _failed_file_entry(filename, message):
    """Build the per-file result entry for a file that could not be processed."""
    return {
        "filename": filename,
        "structured_data": {"error": message},
        "error": message,
    }


async def extract_and_structure(files: List[UploadFile] = File(...)):
    """Extract text from uploaded invoices and return structured data.

    For each file: validate the extension, read the bytes, obtain text (from
    the cache, from embedded PDF text, or via OCR), then pass the text to
    ``process_with_gemini``. One file failing never aborts the others; every
    file gets an entry in ``data`` with an ``error`` string that is empty on
    success.

    NOTE(review): no route decorator is visible on this function in this
    view -- confirm it is registered with ``app`` elsewhere.

    Args:
        files: Uploaded files; ``.pdf``, ``.jpg``, ``.jpeg`` and ``.png`` are
            accepted.

    Returns:
        A ``JSONResponse`` with ``success`` (``False`` only when every file
        failed), a summary ``message`` and the per-file ``data`` list.
    """
    output_data = {
        "success": True,
        "message": "",
        "data": []
    }
    success_count = 0
    fail_count = 0
    valid_extensions = {'.pdf', '.jpg', '.jpeg', '.png'}
    logger.info(f"Starting processing for {len(files)} files, {log_memory_usage()}")

    for file in files:
        total_start_time = time.time()
        logger.info(f"Processing file: {file.filename}, {log_memory_usage()}")

        # An upload may arrive without a filename; treat it as having no
        # extension instead of failing on None.
        file_ext = os.path.splitext((file.filename or "").lower())[1]
        if file_ext not in valid_extensions:
            fail_count += 1
            output_data["data"].append(
                _failed_file_entry(file.filename, f"Unsupported file format: {file_ext}")
            )
            logger.error(f"Unsupported file format for {file.filename}: {file_ext}")
            continue

        try:
            file_start_time = time.time()
            file_bytes = await file.read()
            file_hash = get_file_hash(file_bytes)
            logger.info(f"Read file {file.filename}, took {time.time() - file_start_time:.2f} seconds, size: {len(file_bytes)/1024:.2f} KB, {log_memory_usage()}")
        except Exception as e:
            fail_count += 1
            output_data["data"].append(
                _failed_file_entry(file.filename, f"Failed to read file: {str(e)}")
            )
            logger.error(f"Failed to read file {file.filename}: {str(e)}, {log_memory_usage()}")
            continue

        # Single lookup: only non-empty text is ever cached, so "" means a
        # miss, and there is no window for the entry to expire between a
        # membership test and the read.
        raw_text = raw_text_cache.get(file_hash, "")
        if raw_text:
            logger.info(f"Raw text cache hit for {file.filename}, {log_memory_usage()}")
        else:
            if file_ext == '.pdf':
                try:
                    extract_start_time = time.time()
                    reader = PdfReader(io.BytesIO(file_bytes))
                    for page in reader.pages:
                        text = page.extract_text()
                        if text:
                            raw_text += text + "\n"
                    logger.info(f"Embedded text extraction for {file.filename}, took {time.time() - extract_start_time:.2f} seconds, text length: {len(raw_text)}, {log_memory_usage()}")
                except Exception as e:
                    # Not fatal: fall through to OCR below.
                    logger.warning(f"Embedded text extraction failed for {file.filename}: {str(e)}, {log_memory_usage()}")

                if not raw_text.strip():
                    # No embedded text: render the pages and OCR them.
                    try:
                        convert_start_time = time.time()
                        images = convert_from_bytes(file_bytes, dpi=150)
                        logger.info(f"PDF to images conversion for {file.filename}, {len(images)} pages, took {time.time() - convert_start_time:.2f} seconds, {log_memory_usage()}")
                        ocr_tasks = [process_pdf_page(img, i) for i, img in enumerate(images)]
                        page_texts = await asyncio.gather(*ocr_tasks)
                        raw_text = "".join(page_texts)
                        logger.info(f"Total OCR for {file.filename}, text length: {len(raw_text)}, {log_memory_usage()}")
                    except Exception as e:
                        fail_count += 1
                        output_data["data"].append(
                            _failed_file_entry(file.filename, f"OCR failed: {str(e)}")
                        )
                        logger.error(f"OCR failed for {file.filename}: {str(e)}, {log_memory_usage()}")
                        continue
            else:
                try:
                    raw_text = await process_image(file_bytes, file.filename, 0)
                    logger.info(f"Image OCR for {file.filename}, text length: {len(raw_text)}, {log_memory_usage()}")
                except Exception as e:
                    fail_count += 1
                    output_data["data"].append(
                        _failed_file_entry(file.filename, f"Image OCR failed: {str(e)}")
                    )
                    logger.error(f"Image OCR failed for {file.filename}: {str(e)}, {log_memory_usage()}")
                    continue

            if not raw_text.strip():
                # The OCR helpers report failure as empty text. Sending that
                # to the model would yield an empty or invented invoice that
                # gets counted as a success, so report the failure instead.
                fail_count += 1
                output_data["data"].append(
                    _failed_file_entry(file.filename, "No text could be extracted from file")
                )
                logger.error(f"No text could be extracted from {file.filename}, {log_memory_usage()}")
                continue

            raw_text = unicodedata.normalize('NFKC', raw_text)
            raw_text_cache[file_hash] = raw_text

        structured_data = await process_with_gemini(file.filename, raw_text)
        if "error" not in structured_data:
            success_count += 1
        else:
            fail_count += 1
        output_data["data"].append({
            "filename": file.filename,
            "structured_data": structured_data,
            "error": structured_data.get("error", "")
        })
        logger.info(f"Total processing for {file.filename}, took {time.time() - total_start_time:.2f} seconds, {log_memory_usage()}")

    output_data["message"] = f"Processed {len(files)} files. {success_count} succeeded, {fail_count} failed."
    if fail_count > 0 and success_count == 0:
        output_data["success"] = False
    logger.info(f"Completed processing for {len(files)} files, {success_count} succeeded, {fail_count} failed, {log_memory_usage()}")

    # Round-trip through json with the custom encoder so Decimal values are
    # emitted as plain-notation strings in the response body.
    output_data = force_decimal_format(output_data)
    encoded_data = json.dumps(output_data, default=custom_encoder)
    return JSONResponse(content=json.loads(encoded_data))