"""OCR API: a FastAPI service that extracts text from images and PDFs."""

import logging
import os
import shutil
import sys
import tempfile
import time
import uuid
from enum import Enum
from pathlib import Path
from typing import List, Optional

# Third-party imports
import pytesseract
import uvicorn
from dotenv import load_dotenv
from fastapi import (
    Depends,
    FastAPI,
    File,
    HTTPException,
    Request,
    UploadFile,
    status,
)
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pdf2image import convert_from_path
from PIL import Image
from pydantic import BaseModel

# ==========================================
# 1. CONFIGURATION & ENV LOADING
# ==========================================
load_dotenv()


class Config:
    """Process-wide settings, read once from the environment at import time.

    Environment variables:
        APP_NAME: API title (default "OCR API").
        API_BEARER_TOKEN: shared secret expected in the Authorization header.
        MAX_FILE_SIZE: upload limit in bytes (default 52428800, i.e. 50 MiB).
        ALLOWED_ORIGINS: comma-separated list of CORS origins.
    """

    APP_NAME = os.getenv("APP_NAME", "OCR API")
    API_TOKEN = os.getenv("API_BEARER_TOKEN")
    MAX_SIZE = int(os.getenv("MAX_FILE_SIZE", 52428800))

    # Parse allowed origins from comma-separated string; blank entries are dropped.
    allowed_origins_raw = os.getenv("ALLOWED_ORIGINS")
    ALLOWED_ORIGINS = (
        [origin.strip() for origin in allowed_origins_raw.split(",") if origin.strip()]
        if allowed_origins_raw
        else []
    )

    ALLOWED_TYPES = [
        "image/jpeg",
        "image/png",
        "image/bmp",
        "image/webp",
        "application/pdf",
    ]


# Validation check on startup. Logging is not configured yet at this point,
# so the warning goes to stdout.
if not Config.API_TOKEN:
    print("CRITICAL WARNING: API_BEARER_TOKEN is not set in .env")

# ==========================================
# 2. LOGGING SETUP
# ==========================================


class RequestIdFilter(logging.Filter):
    """Guarantee every log record has a ``request_id`` attribute.

    The log format below references ``%(request_id)s``; a record without that
    attribute would fail to format. Records that do not already carry one are
    tagged ``'system'``.
    """

    def filter(self, record):
        if not hasattr(record, 'request_id'):
            record.request_id = 'system'
        return True


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(levelname)s | ReqID:%(request_id)s | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

_request_id_filter = RequestIdFilter()

# Logger-level filters only see records logged directly on that logger.
# The format string lives on the root handler(s), so the filter must be
# attached there as well; otherwise records from any other logger that
# propagate to root would lack ``request_id`` and fail to format.
for _root_handler in logging.getLogger().handlers:
    _root_handler.addFilter(_request_id_filter)

logger = logging.getLogger("ocr_api")
logger.addFilter(_request_id_filter)
# ==========================================
# 3. PYDANTIC MODELS
# ==========================================


class StatusEnum(str, Enum):
    """Outcome marker carried in every response envelope."""

    SUCCESS = "success"
    ERROR = "error"


class BaseResponse(BaseModel):
    """Envelope fields shared by all API responses."""

    request_id: str
    process_time_ms: float
    status: StatusEnum
    message: Optional[str] = None


class OCRResult(BaseModel):
    """Payload describing the uploaded file and the text extracted from it."""

    filename: str
    content_type: str
    pages: int
    text: str


class APIResponse(BaseResponse):
    """OCR endpoint response: envelope plus either ``data`` or ``error_message``."""

    data: Optional[OCRResult] = None
    error_message: Optional[str] = None


# ==========================================
# 4. BUSINESS LOGIC SERVICES
# ==========================================


class SecurityService:
    """Bearer-token authentication used as a FastAPI dependency."""

    security_scheme = HTTPBearer()

    @staticmethod
    async def validate_token(
        credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
    ):
        """Validate the Bearer token against ``Config.API_TOKEN``.

        Returns:
            The presented token string when it matches.

        Raises:
            HTTPException: 401 when the token does not match, or when no
                token is configured on the server (fail closed).
        """
        # Function-scope import keeps this block self-contained.
        from secrets import compare_digest

        expected = Config.API_TOKEN
        supplied = credentials.credentials

        # Constant-time comparison avoids leaking how many leading characters
        # matched. Comparing bytes keeps compare_digest from raising on
        # non-ASCII input. An unset/empty server token never authenticates.
        token_ok = bool(expected) and compare_digest(
            supplied.encode("utf-8"), expected.encode("utf-8")
        )
        if not token_ok:
            # Deliberately log no part of the presented credential.
            logger.warning("Auth Failed. Invalid bearer token presented.")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid Bearer Token",
                headers={"WWW-Authenticate": "Bearer"},
            )
        return supplied


class FileValidator:
    """Upload checks: declared content type and size-limited save to disk."""

    # Read uploads in 1 MiB pieces so the size limit is enforced while streaming.
    _CHUNK_SIZE = 1024 * 1024

    @staticmethod
    def validate(file: UploadFile):
        """Reject uploads whose declared content type is not allowed.

        Raises:
            HTTPException: 400 when ``file.content_type`` is not in
                ``Config.ALLOWED_TYPES``.
        """
        if file.content_type not in Config.ALLOWED_TYPES:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid file type. Allowed: {Config.ALLOWED_TYPES}"
            )

    @staticmethod
    def check_size_and_save(file: UploadFile) -> str:
        """Stream the upload to a temp file, enforcing ``Config.MAX_SIZE``.

        The copy is aborted as soon as the limit is exceeded, so an oversized
        upload is never fully written to disk. The temp file is removed on
        every failure path; on success the caller owns it and must delete it.

        Returns:
            Path of the temp file holding the upload.

        Raises:
            HTTPException: 413 when the upload exceeds the limit, 500 on any
                other failure while saving.
        """
        # A missing filename simply yields no suffix.
        suffix = Path(file.filename or "").suffix
        tmp_path = None
        saved = False
        try:
            # delete=False: the file must outlive this function so the OCR
            # step can read it by path.
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as buffer:
                tmp_path = buffer.name
                written = 0
                read_chunk = file.file.read
                for chunk in iter(lambda: read_chunk(FileValidator._CHUNK_SIZE), b""):
                    written += len(chunk)
                    if written > Config.MAX_SIZE:
                        raise HTTPException(
                            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
                            detail=f"File size exceeds limit of {Config.MAX_SIZE / (1024*1024)}MB"
                        )
                    buffer.write(chunk)
            saved = True
            return tmp_path
        except HTTPException:
            raise
        except Exception as e:
            logger.exception("File save error: %s", e)
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR, "File upload failed"
            ) from e
        finally:
            # Runs after the temp file handle is closed; drop partial files.
            if not saved and tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass


class OCRProcessor:
    """Runs Tesseract OCR over an image file or every page of a PDF."""

    @classmethod
    def process_file(cls, file_path: str, content_type: str) -> dict:
        """Extract text from the file at ``file_path``.

        Heavy CPU logic; callers should run it off the event loop.

        Args:
            file_path: Path of the saved upload.
            content_type: ``"application/pdf"`` selects per-page PDF handling;
                anything else is opened as a single image.

        Returns:
            ``{"pages": <int>, "text": <str>}``. PDF page texts are prefixed
            with ``--- Page N ---`` and joined by blank lines.

        Raises:
            ValueError: when conversion or OCR fails (original error chained).
        """
        start = time.perf_counter()
        try:
            if content_type == "application/pdf":
                # Convert PDF to images, then OCR each page in order.
                images = convert_from_path(file_path)
                pages = len(images)
                text = "\n\n".join(
                    f"--- Page {page_no} ---\n{pytesseract.image_to_string(img)}"
                    for page_no, img in enumerate(images, start=1)
                )
            else:
                # Standard image; the context manager releases the file handle.
                pages = 1
                with Image.open(file_path) as img:
                    text = pytesseract.image_to_string(img)
        except Exception as e:
            logger.exception("OCR Extraction Error: %s", e)
            raise ValueError("Failed to extract text from document") from e

        duration = (time.perf_counter() - start) * 1000
        logger.info(f"OCR CPU Engine finished in {duration:.2f}ms")
        return {"pages": pages, "text": text}
# ==========================================
# 5. FASTAPI APP INIT
# ==========================================
app = FastAPI(
    title=Config.APP_NAME,
    version="1.0.0",
    docs_url="/docs",  # You can disable this in prod by setting to None
    redoc_url=None
)

# STRICT CORS CONFIGURATION
app.add_middleware(
    CORSMiddleware,
    allow_origins=Config.ALLOWED_ORIGINS,  # Loaded from .env
    allow_credentials=True,
    allow_methods=["GET", "POST"],  # STRICT: Only GET and POST allowed
    allow_headers=["Authorization", "Content-Type", "X-Request-ID"],  # Allowed headers
)


def _install_request_id_record_factory():
    """Install, once, a log-record factory that stamps the current request ID.

    The ID is held in a ContextVar, so each request (and any threadpool work
    started from it) sees its own value. Records created outside a request
    get ``'system'``.

    Returns:
        The ContextVar the middleware sets for each request.
    """
    # Function-scope import keeps this block self-contained.
    from contextvars import ContextVar

    request_id_var = ContextVar("request_id", default="system")
    base_factory = logging.getLogRecordFactory()

    def record_factory(*args, **kwargs):
        record = base_factory(*args, **kwargs)
        record.request_id = request_id_var.get()
        return record

    logging.setLogRecordFactory(record_factory)
    return request_id_var


# Installed at import time, exactly once. Re-wrapping the factory on every
# request would grow the wrapper chain without bound and let concurrent
# requests overwrite each other's ID.
_request_id_ctx = _install_request_id_record_factory()


# Middleware: Request ID & Logging
@app.middleware("http")
async def request_context_middleware(request: Request, call_next):
    """Assign a request ID, log the request/response, and add timing headers.

    Sets ``request.state.request_id``, and adds ``X-Request-ID`` and
    ``X-Process-Time`` headers to the response. Any exception escaping the
    downstream app is logged and turned into a JSON 500 response.
    """
    req_id = str(uuid.uuid4())
    request.state.request_id = req_id

    # Make the ID visible to every log record created during this request.
    ctx_token = _request_id_ctx.set(req_id)
    start_time = time.perf_counter()
    logger.info(f"Incoming Request: {request.method} {request.url.path} | Origin: {request.headers.get('origin', 'unknown')}")

    try:
        response = await call_next(request)
        process_time = (time.perf_counter() - start_time) * 1000
        response.headers["X-Request-ID"] = req_id
        response.headers["X-Process-Time"] = f"{process_time:.2f}ms"
        logger.info(f"Response: {response.status_code} | Time: {process_time:.2f}ms")
        return response
    except Exception:
        logger.exception("Unhandled Exception in Middleware")
        return JSONResponse(
            status_code=500,
            content={"status": "error", "message": "Internal Server Error", "request_id": req_id}
        )
    finally:
        _request_id_ctx.reset(ctx_token)


# ==========================================
# 6. ENDPOINTS
# ==========================================
@app.get("/", response_model=BaseResponse)
async def root(request: Request):
    """Simple connectivity check."""
    return {
        "request_id": request.state.request_id,
        "process_time_ms": 0,
        "status": StatusEnum.SUCCESS,
        "message": "OCR API is running."
    }


@app.get("/api/v1/ping", response_model=BaseResponse)
async def health_check(request: Request):
    """Docker Healthcheck Endpoint."""
    return {
        "request_id": request.state.request_id,
        "process_time_ms": 0,
        "status": StatusEnum.SUCCESS,
        "message": "OCR API is healthy."
    }


@app.post("/api/v1/get_data", response_model=APIResponse)
async def extract_data(
    request: Request,
    file: UploadFile = File(...),
    token: str = Depends(SecurityService.validate_token)
):
    """Main OCR endpoint.

    Validates the upload's content type, saves it to a temp file, runs OCR,
    and returns the extracted text. Both the file save and the OCR run in the
    threadpool so the event loop is never blocked.

    HTTP errors raised during validation or saving propagate unchanged. OCR
    failures and unexpected errors are reported in the response body with
    ``status == "error"``. The temp file is always removed afterwards.
    """
    start_ts = time.perf_counter()
    tmp_file_path = None

    def elapsed_ms():
        return (time.perf_counter() - start_ts) * 1000

    try:
        # 1. Validate File Type
        FileValidator.validate(file)

        # 2. Save File (blocking disk I/O) - off the event loop
        tmp_file_path = await run_in_threadpool(
            FileValidator.check_size_and_save, file
        )

        # 3. Process (CPU Bound) - Run in ThreadPool for Non-Blocking
        result = await run_in_threadpool(
            OCRProcessor.process_file,
            tmp_file_path,
            file.content_type
        )

        return {
            "request_id": request.state.request_id,
            "process_time_ms": elapsed_ms(),
            "status": StatusEnum.SUCCESS,
            "message": "OCR Extraction Successful",
            "data": {
                "filename": file.filename,
                "content_type": file.content_type,
                "pages": result["pages"],
                "text": result["text"]
            }
        }

    except HTTPException:
        raise
    except ValueError as ve:
        # OCR logic errors
        return {
            "request_id": request.state.request_id,
            "process_time_ms": elapsed_ms(),
            "status": StatusEnum.ERROR,
            "error_message": str(ve)
        }
    except Exception as e:
        logger.exception("Unexpected API Error: %s", e)
        return {
            "request_id": request.state.request_id,
            "process_time_ms": elapsed_ms(),
            "status": StatusEnum.ERROR,
            "error_message": "An unexpected error occurred."
        }
    finally:
        # Cleanup temp file (best effort; never masks the real outcome)
        if tmp_file_path:
            try:
                os.remove(tmp_file_path)
                logger.info("Temp file deleted.")
            except FileNotFoundError:
                pass
            except OSError as cleanup_err:
                logger.warning(
                    "Could not delete temp file %s: %s", tmp_file_path, cleanup_err
                )