# ocr-api / main.py
# Soumik Bose - Initial commit (8607b18), 10.6 kB
# (repository page links: raw, history, blame)
import contextvars
import hmac
import os
import sys
import uuid
import time
import logging
import shutil
import tempfile
from typing import Optional, List
from enum import Enum
from pathlib import Path

# Third-party imports
import uvicorn
import pytesseract
from fastapi import (
    FastAPI, File, UploadFile, Depends,
    HTTPException, Request, status
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import JSONResponse
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from dotenv import load_dotenv
from PIL import Image
from pdf2image import convert_from_path
# ==========================================
# 1. CONFIGURATION & ENV LOADING
# ==========================================
# Runs before Config is defined, because Config reads its settings with
# os.getenv at class-definition time (presumably this loads a local .env file).
load_dotenv()
class Config:
    # Application settings, resolved from environment variables when the
    # class body is executed (i.e. at import time).

    # Service title shown by the API; APP_NAME overrides the default.
    APP_NAME = os.environ.get("APP_NAME", "OCR API")

    # Shared secret expected in the Authorization header; None when unset.
    API_TOKEN = os.environ.get("API_BEARER_TOKEN")

    # Upload size ceiling in bytes (default 50 MiB).
    MAX_SIZE = int(os.environ.get("MAX_FILE_SIZE", 50 * 1024 * 1024))

    # Comma-separated origin list; blank entries are dropped, and an unset
    # or empty variable yields an empty list.
    allowed_origins_raw = os.environ.get("ALLOWED_ORIGINS")
    ALLOWED_ORIGINS = [
        entry
        for entry in map(str.strip, (allowed_origins_raw or "").split(","))
        if entry
    ]

    # MIME types accepted for upload.
    ALLOWED_TYPES = [
        "image/jpeg",
        "image/png",
        "image/bmp",
        "image/webp",
        "application/pdf",
    ]
# Startup sanity check: warn loudly when no bearer token has been configured
# (the variable is either missing or set to an empty string).
if Config.API_TOKEN in (None, ""):
    print("CRITICAL WARNING: API_BEARER_TOKEN is not set in .env")
# ==========================================
# 2. LOGGING SETUP
# ==========================================
class RequestIdFilter(logging.Filter):
    """Logging filter that guarantees each record has a ``request_id``."""

    def filter(self, record):
        """Give ``record`` a ``request_id`` of ``'system'`` if it has none.

        Always returns True, so no record is ever dropped by this filter.
        """
        try:
            record.request_id
        except AttributeError:
            # Nothing has tagged this record yet; use the placeholder value.
            record.request_id = 'system'
        return True
# Configure the root handler. Its format string references %(request_id)s,
# so every record that reaches it must carry that attribute.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(levelname)s | ReqID:%(request_id)s | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# A filter attached to a logger only sees records created through that
# logger. Records from other loggers that propagate to the root handler would
# lack request_id and fail to format, so the filter is also attached to the
# root handlers themselves.
_request_id_filter = RequestIdFilter()
for _root_handler in logging.getLogger().handlers:
    _root_handler.addFilter(_request_id_filter)

# Application logger used throughout this module.
logger = logging.getLogger("ocr_api")
logger.addFilter(_request_id_filter)
# ==========================================
# 3. PYDANTIC MODELS
# ==========================================
class StatusEnum(str, Enum):
    # Outcome marker carried in every response envelope. Mixing in ``str``
    # makes each member compare equal to its plain string value.

    # The request was handled and produced a result.
    SUCCESS = "success"

    # The request could not be completed.
    ERROR = "error"
class BaseResponse(BaseModel):
    # Envelope fields shared by every JSON response of this API.

    # Identifier assigned to the request by the HTTP middleware.
    request_id: str

    # Handler processing time in milliseconds.
    process_time_ms: float

    # Overall outcome of the request.
    status: StatusEnum

    # Optional human-readable note.
    message: Optional[str] = None
class OCRResult(BaseModel):
    # Payload describing the text extracted from one uploaded document.

    # File name as supplied with the upload.
    filename: str

    # MIME type reported for the upload.
    content_type: str

    # Number of pages processed (1 for plain images).
    pages: int

    # Extracted text.
    text: str
class APIResponse(BaseResponse):
    # Response envelope of the OCR endpoint: the common fields plus either a
    # result payload or an error description.

    # OCR result; left as None when no result is returned.
    data: Optional[OCRResult] = None

    # Failure description; left as None when no error is reported.
    error_message: Optional[str] = None
# ==========================================
# 4. BUSINESS LOGIC SERVICES
# ==========================================
class SecurityService:
    """Bearer-token authentication used as a FastAPI dependency."""

    # Security scheme that extracts the credentials from the Authorization header.
    security_scheme = HTTPBearer()

    @staticmethod
    async def validate_token(credentials: HTTPAuthorizationCredentials = Depends(security_scheme)):
        """
        Validates the Bearer token.

        Args:
            credentials: Credentials extracted by ``security_scheme``.

        Returns:
            The presented token string when it matches ``Config.API_TOKEN``.

        Raises:
            HTTPException: 401 when the token does not match, or when no
                token is configured on the server (fail closed).
        """
        expected = Config.API_TOKEN
        presented = credentials.credentials

        # Compare in constant time so response timing does not reveal how much
        # of the token was correct. Both sides are encoded to bytes because
        # compare_digest rejects non-ASCII str arguments.
        authorized = bool(expected) and hmac.compare_digest(
            presented.encode("utf-8"), expected.encode("utf-8")
        )
        if not authorized:
            # Deliberately log no part of the presented token: a near-miss
            # (e.g. a typo) would otherwise leak a prefix of the real secret.
            logger.warning("Auth Failed. Invalid bearer token presented.")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid Bearer Token",
                headers={"WWW-Authenticate": "Bearer"},
            )
        return presented
class FileValidator:
    """Validation and temp-file persistence for uploaded files."""

    # Number of bytes read from the upload per iteration while saving.
    _CHUNK_SIZE = 1024 * 1024

    @staticmethod
    def validate(file: UploadFile):
        """Reject uploads whose content type is not in ``Config.ALLOWED_TYPES``.

        Raises:
            HTTPException: 400 when the content type is not allowed.
        """
        if file.content_type not in Config.ALLOWED_TYPES:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid file type. Allowed: {Config.ALLOWED_TYPES}"
            )

    @staticmethod
    def _remove_quietly(path: Optional[str]) -> None:
        """Best-effort removal of a partially written temp file."""
        if not path:
            return
        try:
            os.remove(path)
        except OSError:
            # Already gone or not removable; nothing more can be done here.
            pass

    @staticmethod
    def check_size_and_save(file: UploadFile) -> str:
        """
        Stream file to disk to check size without loading entire file into RAM.
        Returns path to temp file.

        The upload is copied chunk by chunk and the copy is aborted as soon as
        the running total exceeds ``Config.MAX_SIZE``, so an oversized upload
        is never written to disk in full. The temp file is removed on every
        failure path; on success the caller owns it and must delete it.

        Raises:
            HTTPException: 413 when the upload exceeds the size limit,
                500 when the file cannot be saved.
        """
        tmp_path = None
        try:
            # The client may omit the filename; fall back to no suffix.
            suffix = Path(file.filename or "").suffix
            written = 0
            # Create a named temp file that persists so Tesseract can read it
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as buffer:
                tmp_path = buffer.name
                read_chunk = lambda: file.file.read(FileValidator._CHUNK_SIZE)
                for chunk in iter(read_chunk, b""):
                    written += len(chunk)
                    if written > Config.MAX_SIZE:
                        raise HTTPException(
                            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
                            detail=f"File size exceeds limit of {Config.MAX_SIZE / (1024*1024)}MB"
                        )
                    buffer.write(chunk)
            return tmp_path
        except Exception as e:
            # Never leave a partial temp file behind, whatever went wrong.
            FileValidator._remove_quietly(tmp_path)
            if isinstance(e, HTTPException):
                raise
            logger.error(f"File save error: {e}")
            raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "File upload failed") from e
class OCRProcessor:
    """Text extraction from image and PDF files via Tesseract."""

    @classmethod
    def process_file(cls, file_path: str, content_type: str) -> dict:
        """
        Heavy CPU Logic.

        Args:
            file_path: Path of the file on disk.
            content_type: MIME type of the file; ``"application/pdf"`` selects
                the per-page PDF path, anything else is opened as one image.

        Returns:
            A dict with ``"pages"`` (page count, 1 for images) and ``"text"``
            (extracted text; PDF pages are separated by ``--- Page N ---``
            headers).

        Raises:
            ValueError: When conversion or OCR fails; the original exception
                is attached as ``__cause__``.
        """
        start = time.perf_counter()
        try:
            if content_type == "application/pdf":
                # Convert PDF to images, then OCR each page in order.
                images = convert_from_path(file_path)
                pages = len(images)
                text = "\n\n".join(
                    f"--- Page {number} ---\n{pytesseract.image_to_string(page_image)}"
                    for number, page_image in enumerate(images, start=1)
                )
            else:
                # Standard Image. The context manager closes the underlying
                # file handle, so the temp file can be deleted afterwards.
                pages = 1
                with Image.open(file_path) as image:
                    text = pytesseract.image_to_string(image)
        except Exception as e:
            logger.error(f"OCR Extraction Error: {str(e)}")
            raise ValueError("Failed to extract text from document") from e

        duration = (time.perf_counter() - start) * 1000
        logger.info(f"OCR CPU Engine finished in {duration:.2f}ms")
        return {"pages": pages, "text": text}
# ==========================================
# 5. FASTAPI APP INIT
# ==========================================
# Application instance. Swagger UI is served at /docs; ReDoc is switched off.
app = FastAPI(
    redoc_url=None,
    docs_url="/docs",  # You can disable this in prod by setting to None
    version="1.0.0",
    title=Config.APP_NAME,
)
# STRICT CORS CONFIGURATION
# Only the origins listed in Config.ALLOWED_ORIGINS are accepted, and only
# for the methods and request headers enumerated below.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["Authorization", "Content-Type", "X-Request-ID"],  # Allowed headers
    allow_methods=["GET", "POST"],  # STRICT: Only GET and POST allowed
    allow_credentials=True,
    allow_origins=Config.ALLOWED_ORIGINS,  # Loaded from .env
)
# Middleware: Request ID & Logging

# Request ID of the request being handled in the current execution context;
# "system" outside of any request.
_request_id_ctx = contextvars.ContextVar("request_id", default="system")

# Factory that was active before this module installed its own.
_base_record_factory = logging.getLogRecordFactory()


def _request_id_record_factory(*args, **kwargs):
    """Create a log record tagged with the current context's request ID."""
    record = _base_record_factory(*args, **kwargs)
    record.request_id = _request_id_ctx.get()
    return record


# Installed exactly once. Re-wrapping the global factory on every request
# would build an ever-growing chain of nested factories (ending in a
# RecursionError) and let concurrent requests overwrite each other's ID.
logging.setLogRecordFactory(_request_id_record_factory)


@app.middleware("http")
async def request_context_middleware(request: Request, call_next):
    """Assign a request ID, log the request and response, and time the call.

    The ID is stored on ``request.state.request_id``, exposed to logging
    through a context variable, and returned in the ``X-Request-ID`` header.
    ``X-Process-Time`` carries the elapsed time in milliseconds. Any unhandled
    exception is logged and converted into a JSON 500 response.
    """
    req_id = str(uuid.uuid4())
    request.state.request_id = req_id
    # Inject ID into logger for everything that runs in this request's context.
    ctx_token = _request_id_ctx.set(req_id)
    start_time = time.perf_counter()
    try:
        logger.info(f"Incoming Request: {request.method} {request.url.path} | Origin: {request.headers.get('origin', 'unknown')}")
        response = await call_next(request)
        process_time = (time.perf_counter() - start_time) * 1000
        response.headers["X-Request-ID"] = req_id
        response.headers["X-Process-Time"] = f"{process_time:.2f}ms"
        logger.info(f"Response: {response.status_code} | Time: {process_time:.2f}ms")
        return response
    except Exception:
        logger.exception("Unhandled Exception in Middleware")
        return JSONResponse(
            status_code=500,
            content={"status": "error", "message": "Internal Server Error", "request_id": req_id},
            # Keep the correlation header on failures too.
            headers={"X-Request-ID": req_id},
        )
    finally:
        # Restore the previous value so the ID never outlives its request.
        _request_id_ctx.reset(ctx_token)
# ==========================================
# 6. ENDPOINTS
# ==========================================
@app.get("/", response_model=BaseResponse)
async def root(request: Request):
    """Simple connectivity check."""
    # Static success envelope; only the request ID varies between calls.
    payload = dict(
        request_id=request.state.request_id,
        process_time_ms=0,
        status=StatusEnum.SUCCESS,
        message="OCR API is running.",
    )
    return payload
@app.get("/api/v1/ping", response_model=BaseResponse)
async def health_check(request: Request):
    """Docker Healthcheck Endpoint."""
    # Static success envelope; only the request ID varies between calls.
    payload = dict(
        request_id=request.state.request_id,
        process_time_ms=0,
        status=StatusEnum.SUCCESS,
        message="OCR API is healthy.",
    )
    return payload
@app.post("/api/v1/get_data", response_model=APIResponse)
async def extract_data(
    request: Request,
    file: UploadFile = File(...),
    token: str = Depends(SecurityService.validate_token)
):
    """
    Main OCR Endpoint.
    Non-blocking: Offloads OCR to threadpool.
    """
    # Flow: validate the content type, persist the upload to a temp file, run
    # OCR, and always delete the temp file afterwards.
    # HTTPException (validation, size, auth) propagates unchanged. OCR and
    # unexpected failures are reported as an APIResponse with status "error".
    start_ts = time.perf_counter()
    tmp_file_path = None

    def error_payload(error_message: str) -> dict:
        """Build the error envelope, stamping the elapsed time at call time."""
        return {
            "request_id": request.state.request_id,
            "process_time_ms": (time.perf_counter() - start_ts) * 1000,
            "status": StatusEnum.ERROR,
            "error_message": error_message
        }

    try:
        # 1. Validate File Type
        FileValidator.validate(file)
        # 2. Save File (IO Bound) - blocking disk I/O, so it runs in the
        #    threadpool too instead of stalling the event loop.
        tmp_file_path = await run_in_threadpool(
            FileValidator.check_size_and_save,
            file
        )
        # 3. Process (CPU Bound) - Run in ThreadPool for Non-Blocking
        result = await run_in_threadpool(
            OCRProcessor.process_file,
            tmp_file_path,
            file.content_type
        )
        return {
            "request_id": request.state.request_id,
            "process_time_ms": (time.perf_counter() - start_ts) * 1000,
            "status": StatusEnum.SUCCESS,
            "message": "OCR Extraction Successful",
            "data": {
                "filename": file.filename,
                "content_type": file.content_type,
                "pages": result["pages"],
                "text": result["text"]
            }
        }
    except HTTPException:
        raise
    except ValueError as ve:
        # OCR logic errors
        return error_payload(str(ve))
    except Exception as e:
        # logger.exception keeps the traceback, which the client never sees.
        logger.exception(f"Unexpected API Error: {e}")
        return error_payload("An unexpected error occurred.")
    finally:
        # Cleanup temp file
        if tmp_file_path:
            try:
                os.remove(tmp_file_path)
                logger.info("Temp file deleted.")
            except FileNotFoundError:
                # Already removed; nothing to clean up.
                pass
            except OSError as cleanup_error:
                # Best effort: record the leak rather than fail the request.
                logger.warning(f"Temp file cleanup failed: {cleanup_error}")