Spaces:
Runtime error
Runtime error
| from fastapi import APIRouter, File, UploadFile, Form, HTTPException, status | |
| from fastapi.responses import JSONResponse | |
| from config import settings | |
| from PIL import Image | |
| import urllib.request | |
| from io import BytesIO | |
| import utils | |
| import os | |
| import time | |
| from functools import lru_cache | |
| from paddleocr import PaddleOCR | |
| from pdf2image import convert_from_bytes | |
| import io | |
| import json | |
| from routers.data_utils import merge_data | |
| from routers.data_utils import store_data | |
| import motor.motor_asyncio | |
| from typing import Optional | |
| from pymongo import ASCENDING | |
| from pymongo.errors import DuplicateKeyError | |
# Router object for the functions defined in this module.
router = APIRouter()

# MongoDB handles shared by this module. Both stay None until
# startup_event() assigns them, which only happens when the MONGODB_URL
# environment variable is set.
client = None
db = None
async def create_unique_index(collection, *fields):
    """Create a unique index over ``fields`` on ``collection``.

    Args:
        collection: Object exposing an awaitable ``create_index`` method
            (a Motor collection in this module).
        *fields: Names of the document fields that make up the index key,
            in key order.

    Returns:
        Whatever ``collection.create_index`` resolves to.
    """
    # Every key uses direction 1 (ascending).
    key_spec = [(name, 1) for name in fields]
    created = await collection.create_index(key_spec, unique=True)
    return created
async def create_ttl_index(db, collection_name, field, expire_after_seconds):
    """Create an expiring (TTL) index on a single field and log the result.

    Args:
        db: Database-like object indexable by collection name.
        collection_name: Name of the collection that receives the index.
        field: Document field the index is built on (ascending).
        expire_after_seconds: Value passed as ``expireAfterSeconds`` to
            ``create_index``.

    Returns:
        None. The value returned by ``create_index`` is only printed.
    """
    index_name = await db[collection_name].create_index(
        [(field, ASCENDING)],
        expireAfterSeconds=expire_after_seconds,
    )
    print(f"TTL index created or already exists: {index_name}")
async def startup_event():
    """Connect to MongoDB and make sure the required indexes exist.

    Does nothing unless the ``MONGODB_URL`` environment variable is set.
    Otherwise it stores a Motor client and its ``chatgpt_plugin`` database
    in the module-level ``client`` / ``db`` globals and creates:

    * a unique index on ``uploads.receipt_key``;
    * a unique compound index on ``receipts`` over ``(user, receipt_key)``;
    * a TTL index on ``uploads.created_at``.
    """
    global client, db

    if "MONGODB_URL" not in os.environ:
        return

    client = motor.motor_asyncio.AsyncIOMotorClient(os.environ.get("MONGODB_URL"))
    db = client.chatgpt_plugin

    uploads_index = await create_unique_index(db['uploads'], 'receipt_key')
    print(f"Unique index created or already exists: {uploads_index}")

    receipts_index = await create_unique_index(db['receipts'], 'user', 'receipt_key')
    print(f"Unique index created or already exists: {receipts_index}")

    # Upload records expire 15 minutes (15*60 seconds) after created_at.
    await create_ttl_index(db, 'uploads', 'created_at', 15*60)

    print("Connected to MongoDB from OCR!")
async def shutdown_event():
    """Close the module-level MongoDB client, if one was opened.

    The client is only closed when ``MONGODB_URL`` is set in the
    environment (the same condition under which startup_event creates it)
    and a client object actually exists.
    """
    # client is still None when startup_event never ran or failed before
    # assigning it; calling close() on None would raise AttributeError
    # during shutdown.
    if "MONGODB_URL" in os.environ and client is not None:
        client.close()
@lru_cache(maxsize=1)
def load_ocr_model():
    """Return the shared PaddleOCR model, constructing it on first use.

    The model is built with angle classification enabled and the English
    language pack. Construction is cached so that invoke_ocr, which calls
    this on every request, reuses one instance per process instead of
    rebuilding the model each time.

    Returns:
        The cached ``PaddleOCR`` instance.
    """
    return PaddleOCR(use_angle_cls=True, lang='en')
def invoke_ocr(doc, content_type):
    """Run OCR on one image and return the merged lines plus elapsed time.

    Args:
        doc: Image object exposing ``save(fp, format=...)`` (a PIL image
            in this module).
        content_type: MIME type of the source. ``"image/png"`` is
            re-encoded as PNG; every other value is encoded as JPEG.

    Returns:
        Tuple ``(values, processing_time)`` where ``values`` is the output
        of ``merge_data`` over all detected lines and ``processing_time``
        is the elapsed wall-clock time in seconds.
    """
    worker_pid = os.getpid()
    print(f"Handling OCR request with worker PID: {worker_pid}")
    start_time = time.time()

    model = load_ocr_model()

    # The model is fed encoded image bytes, so serialise the image first.
    format_img = "PNG" if content_type == "image/png" else "JPEG"
    with io.BytesIO() as bytes_img:
        doc.save(bytes_img, format=format_img)
        bytes_data = bytes_img.getvalue()

    result = model.ocr(bytes_data, cls=True)

    # The result holds one entry per page. An entry can be None (or empty)
    # when no text was detected, which must not abort the request.
    values = []
    for page in result or []:
        if page:
            values.extend(page)

    values = merge_data(values)

    end_time = time.time()
    processing_time = end_time - start_time
    print(f"OCR done, worker PID: {worker_pid}")
    return values, processing_time
# Content types decoded directly as images.
_IMAGE_CONTENT_TYPES = ("image/jpeg", "image/jpg", "image/png")
# Content types treated as PDF for uploaded files.
_UPLOAD_PDF_CONTENT_TYPES = ("application/pdf",)
# Content types treated as PDF for downloaded URLs. Some hosts serve PDFs
# as a generic binary stream, so "application/octet-stream" is accepted
# alongside the proper PDF type.
_URL_PDF_CONTENT_TYPES = ("application/octet-stream", "application/pdf")
_INVALID_TYPE_MESSAGE = "Invalid file type. Only JPG/PNG images and PDF are allowed."


def _decode_document(payload, is_pdf):
    """Turn raw bytes into an image; for a PDF only the first page is used."""
    if is_pdf:
        return convert_from_bytes(payload, 300)[0]
    return Image.open(BytesIO(payload))


async def _ocr_and_store(doc, content_type, source_name, post_processing):
    """Run OCR on ``doc``, log timing stats, and optionally store the result.

    Returns the OCR result, or the value returned by ``store_data`` when
    post-processing is requested and ``MONGODB_URL`` is set.

    Raises:
        HTTPException: 400 when ``store_data`` hits a duplicate key.
    """
    result, processing_time = invoke_ocr(doc, content_type)

    utils.log_stats(settings.ocr_stats_file, [processing_time, source_name])
    print(f"Processing time OCR: {processing_time:.2f} seconds")

    if post_processing and "MONGODB_URL" in os.environ:
        print("Postprocessing...")
        try:
            result = await store_data(result, db)
        except DuplicateKeyError as err:
            # Must be raised, not returned: a returned HTTPException is
            # treated as an ordinary response body instead of a 400 error.
            raise HTTPException(status_code=400, detail="Duplicate data.") from err
        print(f"Stored data with key: {result}")

    return result


async def run_ocr(file: Optional[UploadFile] = File(None), image_url: Optional[str] = Form(None),
                  post_processing: Optional[bool] = Form(False), sparrow_key: str = Form(None)):
    """Run OCR on an uploaded file or on a document fetched from a URL.

    Args:
        file: Uploaded JPG/PNG image or PDF. Takes precedence over
            ``image_url`` when both are supplied.
        image_url: URL of a JPG/PNG image or PDF to download.
        post_processing: When true and ``MONGODB_URL`` is set, the OCR
            result is passed to ``store_data`` and the stored key is
            returned instead of the raw result.
        sparrow_key: Must equal ``settings.sparrow_key``.

    Returns:
        A ``JSONResponse`` with the result, or a plain ``{"error": ...}``
        dict for an invalid key or an unsupported content type.

    Raises:
        HTTPException: 400 on duplicate stored data, or when processing
            produced no result.
    """
    if sparrow_key != settings.sparrow_key:
        return {"error": "Invalid Sparrow key."}

    if file:
        content_type = file.content_type
        is_pdf = content_type in _UPLOAD_PDF_CONTENT_TYPES
        if not is_pdf and content_type not in _IMAGE_CONTENT_TYPES:
            return {"error": _INVALID_TYPE_MESSAGE}

        doc = _decode_document(await file.read(), is_pdf)
        result = await _ocr_and_store(doc, content_type, file.filename, post_processing)
    elif image_url:
        # test image url: https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/invoices/processed/images/invoice_10.jpg
        # test PDF: https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/receipts/2021/us/bestbuy-20211211_006.pdf
        # NOTE(review): image_url is caller-supplied and urlopen also
        # accepts non-HTTP schemes (e.g. file:). Consider restricting it to
        # http/https if this endpoint is reachable by untrusted callers.
        with urllib.request.urlopen(image_url) as response:
            content_type = response.info().get_content_type()
            is_pdf = content_type in _URL_PDF_CONTENT_TYPES
            if not is_pdf and content_type not in _IMAGE_CONTENT_TYPES:
                return {"error": _INVALID_TYPE_MESSAGE}

            doc = _decode_document(response.read(), is_pdf)

        # The last path segment of the URL stands in for the file name.
        file_name = image_url.split("/")[-1]
        result = await _ocr_and_store(doc, content_type, file_name, post_processing)
    else:
        result = {"info": "No input provided"}

    if result is None:
        raise HTTPException(status_code=400, detail="Failed to process the input.")

    return JSONResponse(status_code=status.HTTP_200_OK, content=result)
async def get_statistics():
    """Return the parsed contents of the OCR statistics file.

    Reads the JSON file at ``settings.ocr_stats_file``.

    Returns:
        The decoded JSON content, or an empty list when the file does not
        exist or does not contain valid JSON.
    """
    stats_path = settings.ocr_stats_file

    # No stats file yet: report an empty history.
    if not os.path.exists(stats_path):
        return []

    with open(stats_path, 'r') as stats_file:
        try:
            return json.load(stats_file)
        except json.JSONDecodeError:
            # An unparsable file is treated the same as a missing one.
            return []