Spaces:
Sleeping
Sleeping
import io
import logging
import os
import sys
import time

import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response
from fastapi.staticfiles import StaticFiles
from PIL import Image

# Make the parent directory importable so `backend.inference` resolves
# regardless of the directory this file is launched from.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from backend.inference import prediction_engine  # noqa: E402
app = FastAPI(
    title="Oil Spill Detection API",
    version="1.0.0",
    description="API to detect oil spills from satellite images using U-Net",
)

# CORS for the React frontend when it runs on a different origin (e.g. a local
# dev server). Restrict origins with a comma-separated CORS_ALLOW_ORIGINS env
# var; the default "*" keeps local development open.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Credentialed CORS must not be combined with a wildcard origin: with "*"
    # plus allow_credentials=True, Starlette echoes back whatever Origin the
    # caller sent, effectively allowing credentialed requests from any site.
    # Only enable credentials when explicit origins are configured.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
    # Let browsers read the custom metadata headers returned with prediction masks.
    expose_headers=["X-Confidence-Score", "X-Inference-Latency-Ms"],
)
# Serves the built frontend files.
# Ensure 'frontend/dist' exists after the build stage in Docker. The path is
# resolved relative to the current working directory.
_FRONTEND_DIST = "frontend/dist"


@app.on_event("startup")
def _mount_frontend() -> None:
    """Mount the built frontend at "/" after all API routes are registered.

    A mount at "/" matches every request path, and Starlette dispatches to the
    first matching route in registration order. Mounting at import time,
    before the endpoint functions later in this module are registered, would
    make the static mount shadow them. Running this at startup appends the
    mount after every route registered at module level.
    """
    if not os.path.isdir(_FRONTEND_DIST):
        return
    # Startup can fire more than once for the same app object (e.g. repeated
    # test clients); avoid stacking duplicate mounts.
    if any(getattr(route, "name", None) == "static" for route in app.routes):
        return
    app.mount("/", StaticFiles(directory=_FRONTEND_DIST, html=True), name="static")
def health_check():
    """Health check endpoint to ensure API and Model are ready."""
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed source, so as shown it is never registered with `app`; it was
    # presumably something like `@app.get("/health")` -- confirm and restore.
    # "status" is a constant; "model_loaded" reports whether the inference
    # engine currently holds a model object.
    model_loaded = prediction_engine.model is not None
    return {"status": "healthy", "model_loaded": model_loaded}
async def predict_spill(file: UploadFile = File(...)):
    """Receives an image, performs inference, and returns a PNG mask segmenting the oil spill."""
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed source; it was presumably registered with something like
    # `@app.post("/predict")` -- confirm the path the frontend calls.

    # content_type is None when the client omits the part's Content-Type;
    # treat that as "not an image" (400) instead of crashing with an
    # AttributeError (500).
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Invalid file format. Upload an image.")

    try:
        # perf_counter is monotonic, so latency is immune to wall-clock jumps.
        start_time = time.perf_counter()
        contents = await file.read()
        if not contents:
            raise HTTPException(status_code=400, detail="Empty file. Upload an image.")

        # NOTE(review): predict() is called synchronously inside this async
        # handler, so it blocks the event loop for the whole inference.
        # Consider starlette.concurrency.run_in_threadpool once the engine is
        # confirmed to be thread-safe.
        mask_array, confidence = prediction_engine.predict(contents)
        png_bytes = _mask_to_png_bytes(mask_array)

        latency_ms = int((time.perf_counter() - start_time) * 1000)

        # Include metadata in headers so the frontend can read confidence and latency.
        headers = {
            "Access-Control-Expose-Headers": "X-Confidence-Score, X-Inference-Latency-Ms",
            "X-Confidence-Score": str(round(confidence * 100, 2)),
            "X-Inference-Latency-Ms": str(latency_ms),
        }
        return Response(content=png_bytes, media_type="image/png", headers=headers)
    except HTTPException:
        # Client errors raised above keep their own status code instead of
        # being converted into a 500 by the generic handler below.
        raise
    except Exception as e:
        # Record the full traceback server-side; the client still receives the
        # exception message as the error detail.
        logging.getLogger(__name__).exception("Error during prediction")
        raise HTTPException(status_code=500, detail=str(e)) from e


def _mask_to_png_bytes(mask_array):
    """Encode a predicted mask array as 8-bit grayscale ("L") PNG bytes.

    A trailing singleton channel axis, e.g. (256, 256, 1), is squeezed to
    (256, 256) because Pillow's "L" mode expects a 2-D array. Values are cast
    with ``astype(np.uint8)``; values outside 0-255 are not clipped.
    """
    # NOTE(review): if predict() returns 0/1 labels or 0.0-1.0 probabilities,
    # this cast produces a near-black image -- confirm the engine already
    # scales the mask to 0-255.
    if mask_array.ndim == 3 and mask_array.shape[-1] == 1:
        mask_array = np.squeeze(mask_array, axis=-1)
    mask_image = Image.fromarray(mask_array.astype(np.uint8), mode="L")
    buffer = io.BytesIO()
    mask_image.save(buffer, format="PNG")
    return buffer.getvalue()
if __name__ == "__main__":
    import uvicorn

    # Hugging Face Spaces defaults to port 7860.
    port = int(os.environ.get("PORT", 7860))
    # Auto-reload is a development feature: it runs a file-watching supervisor
    # and restarts the worker on changes. Opt in with RELOAD=1 (or true/yes)
    # rather than enabling it unconditionally in the deployed entry point.
    reload = os.environ.get("RELOAD", "").strip().lower() in {"1", "true", "yes"}
    # Reload requires an import string. Without reload, pass the app object
    # directly so this module is not imported a second time under the name
    # "main", and the launcher no longer depends on the file being main.py.
    uvicorn.run("main:app" if reload else app, host="0.0.0.0", port=port, reload=reload)