OCRroute / api.py
hikmat123's picture
make the pipeline generic and add API layer
937fe69
Raw
History Blame Contribute Delete
4.27 kB
import os
import json
import io
import uuid
import threading
import time
from typing import Optional, Dict
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse
from PIL import Image
import numpy as np
from dotenv import load_dotenv
# Import functions from app.py
import app
# Load environment variables from a .env file into the process environment.
# NOTE(review): `import app` above runs before this call, so any environment
# lookups app.py performs at import time would not see .env values — confirm.
load_dotenv()
# FastAPI application object; every route in this module is registered on it.
app_api = FastAPI(title="OCR Route Data Extraction API")
# In-memory storage for jobs and uploaded files
# In a production app, use Redis or a database and a task queue like Celery
# jobs maps job_id -> mutable status record (status, progress, status_message,
# created_at, and later result/debug or error).
jobs: Dict[str, dict] = {}
# uploaded_files maps file_id -> raw bytes of an upload awaiting extraction.
uploaded_files: Dict[str, bytes] = {}
class APIProgress:
    """Callable progress reporter that writes pipeline progress into ``jobs``.

    An instance is handed to the extraction pipeline as its ``progress``
    callback; each call updates the ``progress`` (percentage) and
    ``status_message`` fields of the job record identified by ``job_id``.
    """

    def __init__(self, job_id: str):
        """Remember which job record this reporter updates."""
        self.job_id = job_id

    def __call__(self, progress_val, desc=None):
        """Record the latest progress for the job, if the job still exists.

        Args:
            progress_val: Completed fraction in ``[0, 1]``, an
                ``(index, total)`` pair, or ``None`` when only the message
                changes. The tuple/``None`` forms are accepted because
                Gradio-style progress callbacks may use them — presumably
                what ``app.run_pipeline`` expects; confirm against app.py.
            desc: Human-readable status text; falls back to
                ``"Processing..."`` when empty or omitted.
        """
        job = jobs.get(self.job_id)
        if job is None:
            # Unknown job: nothing to update.
            return

        percent = self._to_percent(progress_val)
        if percent is not None:
            # Keep the previous percentage when the caller gave no usable value.
            job["progress"] = percent
        job["status_message"] = desc or "Processing..."

    @staticmethod
    def _to_percent(progress_val):
        """Convert a progress value to a 0-100 percentage, or ``None`` if unknown."""
        if progress_val is None:
            return None
        if isinstance(progress_val, (tuple, list)):
            index, total = progress_val
            if index is None or not total:
                # Total unknown or zero: no meaningful fraction can be derived.
                return None
            fraction = index / total
        else:
            fraction = progress_val
        # Clamp so a misbehaving caller cannot report <0% or >100%.
        return round(min(max(fraction * 100, 0.0), 100.0), 2)
def run_extraction_task(job_id: str, file_id: str, ocr_engine: str, api_key: str):
    """Run the extraction pipeline for one uploaded file and record the outcome.

    The job record ``jobs[job_id]`` is updated in place: ``status`` moves to
    ``"processing"`` and then to ``"completed"`` (with ``result`` and
    ``debug``) or ``"failed"`` (with ``error``). No exception raised by the
    pipeline escapes; it is printed and stored on the job instead.

    Args:
        job_id: Key of the job record in ``jobs``.
        file_id: Key of the uploaded bytes in ``uploaded_files``.
        ocr_engine: Engine name forwarded unchanged to ``app.run_pipeline``.
        api_key: API key forwarded unchanged to ``app.run_pipeline``.
    """
    import traceback

    try:
        jobs[job_id]["status"] = "processing"

        # Get the file contents
        contents = uploaded_files.get(file_id)
        if not contents:
            jobs[job_id]["status"] = "failed"
            jobs[job_id]["error"] = "File not found"
            return

        # convert() returns an independent image, so the source handle can be
        # closed as soon as the RGB copy exists.
        with Image.open(io.BytesIO(contents)) as source_image:
            image_pil = source_image.convert("RGB")

        progress_tracker = APIProgress(job_id)

        # Call the pipeline from app.py
        # run_pipeline(image, ocr_engine, api_key, progress)
        json_result_str, debug_info = app.run_pipeline(
            image=image_pil,
            ocr_engine=ocr_engine,
            api_key=api_key,
            progress=progress_tracker
        )

        result = json.loads(json_result_str)

        jobs[job_id]["status"] = "completed"
        jobs[job_id]["progress"] = 100.0
        jobs[job_id]["result"] = result
        jobs[job_id]["debug"] = debug_info
    except Exception as e:
        traceback.print_exc()
        jobs[job_id]["status"] = "failed"
        jobs[job_id]["error"] = str(e)
    finally:
        # Release the uploaded bytes whether the job succeeded or failed;
        # previously a failed job left its file in memory indefinitely.
        uploaded_files.pop(file_id, None)
@app_api.get("/")
async def root():
    """Liveness endpoint: return a fixed message confirming the API is up."""
    status_payload = {"message": "OCR Route Data Extraction API is running"}
    return status_payload
# 1. Endpoint to upload the document
@app_api.post("/upload")
async def upload_document(file: UploadFile = File(...)):
    """Store an uploaded file in memory and return the id used by /extract.

    Args:
        file: The multipart upload to store.

    Returns:
        A dict with the generated ``file_id``, the original ``filename`` and a
        human-readable ``message``.

    Raises:
        HTTPException: 500 if reading the upload fails; 400 if the upload
            contains no bytes.
    """
    # Only the read can fail here, so keep the try body to that one call.
    try:
        contents = await file.read()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}") from e

    # An empty payload can never be processed; rejecting it now avoids storing
    # it and later reporting a misleading "File not found" from the job.
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    file_id = str(uuid.uuid4())
    uploaded_files[file_id] = contents
    return {
        "file_id": file_id,
        "filename": file.filename,
        "message": "File uploaded successfully. Use /extract to start processing."
    }
# 2. Endpoint to extract and validate data (starts a job)
@app_api.post("/extract")
async def extract_data(
    background_tasks: BackgroundTasks,
    file_id: str = Form(...),
    ocr_engine: str = Form("PaddleOCR"),
    api_key: Optional[str] = Form(None)
):
    """Create a pending job for an uploaded file and schedule its extraction.

    Responds with 404 when ``file_id`` is not in ``uploaded_files``.
    Otherwise registers a new job record in ``jobs``, queues
    ``run_extraction_task`` as a background task and returns the new job id.
    The API key is taken from the form field, then from the
    ``GEMINI_API_KEY`` environment variable, then defaults to ``""``.
    """
    file_is_known = file_id in uploaded_files
    if not file_is_known:
        raise HTTPException(status_code=404, detail="File ID not found. Please upload first.")

    job_id = str(uuid.uuid4())

    # Resolve the key: explicit form value, then environment, then empty string.
    if api_key:
        gemini_api_key = api_key
    else:
        gemini_api_key = os.getenv("GEMINI_API_KEY") or ""

    job_record = dict(
        status="pending",
        progress=0.0,
        status_message="Starting job...",
        created_at=time.time(),
    )
    jobs[job_id] = job_record

    # Hand the work to the background task runner so this request returns now.
    task_args = (job_id, file_id, ocr_engine, gemini_api_key)
    background_tasks.add_task(run_extraction_task, *task_args)

    response = {"job_id": job_id, "message": "Extraction job started."}
    return response
# 3. Endpoint to show the running job progress/status
@app_api.get("/job/{job_id}")
async def get_job_status(job_id: str):
    """Return the stored record for ``job_id``, or respond 404 if it is unknown."""
    try:
        job_record = jobs[job_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Job ID not found.") from None
    return job_record
if __name__ == "__main__":
    # Script entry point: serve the API with uvicorn on all interfaces, port 8000.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app_api, **server_options)