Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import io | |
| import uuid | |
| import threading | |
| import time | |
| from typing import Optional, Dict | |
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks | |
| from fastapi.responses import JSONResponse | |
| from PIL import Image | |
| import numpy as np | |
| from dotenv import load_dotenv | |
| # Import functions from app.py | |
| import app | |
# Load variables from a .env file (if present) before anything reads the
# environment; the extraction endpoint falls back to GEMINI_API_KEY from there.
load_dotenv()

app_api = FastAPI(title="OCR Route Data Extraction API")

# Process-local, in-memory state shared by the handlers and the background task.
# In a production app, use Redis or a database and a task queue like Celery.
jobs: Dict[str, dict] = {}  # job_id -> job record (status, progress, result, ...)
uploaded_files: Dict[str, bytes] = {}  # file_id -> raw bytes of the upload
class APIProgress:
    """Progress callback that mirrors pipeline progress into a job record.

    Each instance is bound to one job id. Calling it writes the latest
    percentage and status text into that job's entry in the module-level
    ``jobs`` dict; calls for a job id that is not in ``jobs`` are ignored.
    """

    def __init__(self, job_id: str):
        self.job_id = job_id

    def __call__(self, progress_val, desc=None):
        """Record the current progress for the bound job.

        Args:
            progress_val: Fraction of the work done. It is multiplied by 100
                and rounded to two decimals before being stored.
            desc: Optional status text. A falsy value is stored as
                "Processing...".
        """
        if self.job_id not in jobs:
            return

        record = jobs[self.job_id]
        record["progress"] = round(progress_val * 100, 2)
        record["status_message"] = desc if desc else "Processing..."
def run_extraction_task(job_id: str, file_id: str, ocr_engine: str, api_key: str):
    """Run the OCR pipeline for one uploaded file and record the outcome in ``jobs``.

    The job record is moved to "processing", the uploaded bytes are decoded
    into an RGB image and handed to ``app.run_pipeline``, and the parsed JSON
    result plus debug info are stored on the record. Any exception is printed
    and turned into a "failed" record carrying the error text.

    The uploaded bytes are dropped from ``uploaded_files`` only after a
    successful run; on failure they are left in place.

    Args:
        job_id: Key of an existing record in ``jobs``.
        file_id: Key of the uploaded bytes in ``uploaded_files``.
        ocr_engine: Engine name forwarded unchanged to ``app.run_pipeline``.
        api_key: API key forwarded unchanged to ``app.run_pipeline``.

    Raises:
        KeyError: If ``job_id`` has no record in ``jobs``.
    """
    job = jobs[job_id]
    try:
        job["status"] = "processing"

        contents = uploaded_files.get(file_id)
        if not contents:
            # Write the payload first and the terminal status last, so a
            # client that polls and sees "failed" always finds "error" too.
            job["error"] = "File not found"
            job["status"] = "failed"
            return

        image_pil = Image.open(io.BytesIO(contents)).convert("RGB")

        # The pipeline reports progress through this callback, which writes
        # straight into the job record.
        progress_tracker = APIProgress(job_id)

        # run_pipeline is expected to return (JSON string, debug info).
        json_result_str, debug_info = app.run_pipeline(
            image=image_pil,
            ocr_engine=ocr_engine,
            api_key=api_key,
            progress=progress_tracker,
        )
        result = json.loads(json_result_str)

        # Publish the result before flipping the status: the record is read
        # by the status handler while this function is still running, and
        # "completed" must never be visible without its "result".
        job.update({"progress": 100.0, "result": result, "debug": debug_info})
        job["status"] = "completed"

        # Clean up the file from memory after processing.
        uploaded_files.pop(file_id, None)
    except Exception as e:
        import traceback

        traceback.print_exc()
        job["error"] = str(e)
        job["status"] = "failed"
async def root():
    """Return a fixed message confirming that the API is running.

    NOTE(review): no route decorator is visible on this handler in this view
    of the file; confirm how it is registered with ``app_api``.
    """
    status_payload = {"message": "OCR Route Data Extraction API is running"}
    return status_payload
# 1. Endpoint to upload the document
# NOTE(review): no route decorator is visible on this handler in this view of
# the file; confirm how it is registered with ``app_api``.
async def upload_document(file: UploadFile = File(...)):
    """Keep an uploaded document in memory and hand back an id for it.

    The whole upload is read into memory and stored in ``uploaded_files``
    under a fresh UUID, which the client later passes to the extraction
    endpoint.

    Args:
        file: The uploaded document.

    Returns:
        A dict with the generated ``file_id``, the original ``filename`` and
        a human-readable ``message``.

    Raises:
        HTTPException: 500 if reading the upload fails; 400 if the upload is
            empty.
    """
    # Only the read can fail on I/O. Keeping the try this narrow stops the
    # deliberate 400 below from being caught and re-reported as a 500.
    try:
        contents = await file.read()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}") from e

    # The extraction task treats empty content as a missing file, so an empty
    # upload could never be processed; reject it here with a clear error.
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    file_id = str(uuid.uuid4())
    uploaded_files[file_id] = contents

    return {
        "file_id": file_id,
        "filename": file.filename,
        "message": "File uploaded successfully. Use /extract to start processing.",
    }
# 2. Endpoint to extract and validate data (starts a job)
# NOTE(review): no route decorator is visible on this handler in this view of
# the file; confirm how it is registered with ``app_api``.
async def extract_data(
    background_tasks: BackgroundTasks,
    file_id: str = Form(...),
    ocr_engine: str = Form("PaddleOCR"),
    api_key: Optional[str] = Form(None),
):
    """Create an extraction job for an uploaded file and queue it.

    Args:
        background_tasks: Task collector the extraction is queued on.
        file_id: Id of a file currently held in ``uploaded_files``.
        ocr_engine: Engine name passed through to the extraction task.
        api_key: Optional API key; when falsy, the ``GEMINI_API_KEY``
            environment variable is used, and failing that an empty string.

    Returns:
        A dict with the new ``job_id`` and a confirmation ``message``.

    Raises:
        HTTPException: 404 if ``file_id`` is not in ``uploaded_files``.
    """
    if file_id not in uploaded_files:
        raise HTTPException(status_code=404, detail="File ID not found. Please upload first.")

    job_id = str(uuid.uuid4())

    # Key precedence: form field, then environment, then empty string.
    gemini_api_key = api_key
    if not gemini_api_key:
        gemini_api_key = os.getenv("GEMINI_API_KEY")
    if not gemini_api_key:
        gemini_api_key = ""

    # Register the job before queuing it so it can be polled immediately.
    new_job = {
        "status": "pending",
        "progress": 0.0,
        "status_message": "Starting job...",
        "created_at": time.time(),
    }
    jobs[job_id] = new_job

    # Start the task in the background
    background_tasks.add_task(run_extraction_task, job_id, file_id, ocr_engine, gemini_api_key)

    response = {"job_id": job_id, "message": "Extraction job started."}
    return response
# 3. Endpoint to show the running job progress/status
# NOTE(review): no route decorator is visible on this handler in this view of
# the file; confirm how it is registered with ``app_api``.
async def get_job_status(job_id: str):
    """Return the current record for a job.

    Args:
        job_id: Id of a job stored in ``jobs``.

    Returns:
        A shallow copy of the job record as it stands at the time of the call.

    Raises:
        HTTPException: 404 if ``job_id`` is not in ``jobs``.
    """
    try:
        job = jobs[job_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Job ID not found.") from None

    # Return a snapshot rather than the live record: the extraction task keeps
    # adding keys to the record while it runs, and serializing a dict that is
    # being resized can fail.
    return dict(job)
# Script entry point: serve the API with uvicorn when this file is run directly.
if __name__ == "__main__":
    import uvicorn

    # Bind to all interfaces on port 8000.
    uvicorn.run(app_api, port=8000, host="0.0.0.0")