"""FastAPI front-end for the OCR route-data extraction pipeline in ``app.py``.

Flow: ``POST /upload`` stores the raw file bytes in memory and returns a
``file_id``; ``POST /extract`` starts a background job for that ``file_id``
and returns a ``job_id``; ``GET /job/{job_id}`` reports the job's progress
and, once finished, its result or error.
"""

import io
import json
import logging
import os
import threading
import time
import uuid
from typing import Dict, Optional

import numpy as np
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image

# Import functions from app.py
import app

# Load environment variables.
# NOTE(review): this runs after ``import app``; if app.py reads environment
# variables at import time it will not see values coming from .env — confirm.
load_dotenv()

logger = logging.getLogger(__name__)

app_api = FastAPI(title="OCR Route Data Extraction API")

# In-memory storage for jobs and uploaded files.
# In a production app, use Redis or a database and a task queue like Celery.
jobs: Dict[str, dict] = {}
uploaded_files: Dict[str, bytes] = {}


def _update_job(job_id: str, **fields) -> None:
    """Merge *fields* into the record for *job_id*, if that job exists.

    Fields are written in keyword order, so callers pass ``status`` last: a
    client polling ``/job`` that observes a terminal status then also
    observes the ``result``/``error`` written together with it.
    """
    job = jobs.get(job_id)
    if job is not None:
        job.update(fields)


class APIProgress:
    """Progress callback handed to ``app.run_pipeline``.

    Each call records the pipeline's progress and description on the job so
    that ``GET /job/{job_id}`` can report it while the job runs.
    """

    def __init__(self, job_id: str):
        self.job_id = job_id

    def __call__(self, progress_val, desc=None):
        """Store *progress_val* as a percentage and *desc* as the status message.

        *progress_val* is presumably a 0-1 fraction (it is scaled by 100 here,
        and a finished job reports 100.0) — confirm against app.run_pipeline.
        """
        _update_job(
            self.job_id,
            progress=round(progress_val * 100, 2),
            status_message=desc or "Processing...",
        )


def run_extraction_task(job_id: str, file_id: str, ocr_engine: str, api_key: str):
    """Run the OCR pipeline on an uploaded file and record the outcome on the job.

    Scheduled through FastAPI ``BackgroundTasks`` by ``/extract``.

    On success the job gets ``progress=100.0``, ``result`` (the JSON string
    returned by ``app.run_pipeline``, parsed), ``debug`` and
    ``status="completed"``, and the uploaded bytes are released.  On any
    failure the job gets an ``error`` message and ``status="failed"``; the
    upload is kept, so the same ``file_id`` can be submitted again.
    """
    try:
        _update_job(job_id, status="processing")

        contents = uploaded_files.get(file_id)
        if not contents:
            _update_job(job_id, error="File not found", status="failed")
            return

        # Decode and convert, then close the source image; only the RGB copy
        # is passed on to the pipeline.
        with Image.open(io.BytesIO(contents)) as source_image:
            image_pil = source_image.convert("RGB")

        # run_pipeline(image, ocr_engine, api_key, progress)
        json_result_str, debug_info = app.run_pipeline(
            image=image_pil,
            ocr_engine=ocr_engine,
            api_key=api_key,
            progress=APIProgress(job_id),
        )
        result = json.loads(json_result_str)

        # ``status`` is written last so a poller never sees "completed"
        # without the accompanying result.
        _update_job(
            job_id,
            progress=100.0,
            result=result,
            debug=debug_info,
            status="completed",
        )

        # Release the uploaded bytes now that processing has finished.
        uploaded_files.pop(file_id, None)
    except Exception as e:
        # Top-level boundary of a background task: nothing above us can
        # handle the error, so log it and surface it through the job record.
        logger.exception("Extraction job %s failed", job_id)
        _update_job(job_id, error=str(e), status="failed")


@app_api.get("/")
async def root():
    """Liveness check."""
    return {"message": "OCR Route Data Extraction API is running"}


# 1. Endpoint to upload the document
@app_api.post("/upload")
async def upload_document(file: UploadFile = File(...)):
    """Store the uploaded file in memory and return its ``file_id``.

    Raises:
        HTTPException(400): the uploaded file is empty.
        HTTPException(500): reading the upload failed.
    """
    try:
        contents = await file.read()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}") from e

    # Checked outside the ``try`` so the 400 is not rewrapped as a 500.
    # An empty upload would otherwise only fail later, as "File not found".
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    file_id = str(uuid.uuid4())
    uploaded_files[file_id] = contents
    return {
        "file_id": file_id,
        "filename": file.filename,
        "message": "File uploaded successfully. Use /extract to start processing.",
    }


# 2. Endpoint to extract and validate data (starts a job)
@app_api.post("/extract")
async def extract_data(
    background_tasks: BackgroundTasks,
    file_id: str = Form(...),
    ocr_engine: str = Form("PaddleOCR"),
    api_key: Optional[str] = Form(None),
):
    """Create a job for *file_id* and schedule the extraction in the background.

    The Gemini API key is taken from the *api_key* form field, falling back
    to the ``GEMINI_API_KEY`` environment variable, then to ``""``.

    Raises:
        HTTPException(404): *file_id* is not a known upload.
    """
    if file_id not in uploaded_files:
        raise HTTPException(status_code=404, detail="File ID not found. Please upload first.")

    job_id = str(uuid.uuid4())
    gemini_api_key = api_key or os.getenv("GEMINI_API_KEY") or ""

    jobs[job_id] = {
        "status": "pending",
        "progress": 0.0,
        "status_message": "Starting job...",
        "created_at": time.time(),
    }

    # Start the task in the background
    background_tasks.add_task(
        run_extraction_task,
        job_id,
        file_id,
        ocr_engine,
        gemini_api_key,
    )

    return {
        "job_id": job_id,
        "message": "Extraction job started.",
    }


# 3. Endpoint to show the running job progress/status
@app_api.get("/job/{job_id}")
async def get_job_status(job_id: str):
    """Return a snapshot of the job record (status, progress, result/error).

    A shallow copy is returned because the background task may add keys to
    the live record while FastAPI is serializing the response.

    Raises:
        HTTPException(404): *job_id* is unknown.
    """
    job = jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job ID not found.")
    return dict(job)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app_api, host="0.0.0.0", port=8000)