"""FastAPI service exposing a U-Net oil-spill segmentation model.

Endpoints:
    GET  /health   -- reports API status and whether the model is loaded.
    POST /predict  -- accepts an uploaded image and returns a PNG mask.

If a built frontend exists at ``frontend/dist`` it is served at ``/``.
"""
import io
import logging
import os
import sys
import time

import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response
from fastapi.staticfiles import StaticFiles
from PIL import Image

# Put the parent directory of this file on sys.path so the `backend` package
# below can be imported when this module is run directly.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from backend.inference import prediction_engine  # noqa: E402  (needs the sys.path tweak above)

logger = logging.getLogger(__name__)

# Metadata headers attached to /predict responses. Cross-origin JavaScript can
# only read them if they are listed in Access-Control-Expose-Headers.
CONFIDENCE_HEADER = "X-Confidence-Score"
LATENCY_HEADER = "X-Inference-Latency-Ms"

# Built frontend bundle; resolved relative to the process working directory.
FRONTEND_DIST_DIR = "frontend/dist"

app = FastAPI(
    title="Oil Spill Detection API",
    version="1.0.0",
    description="API to detect oil spills from satellite images using U-Net",
)

# Setup CORS for the React frontend.
# NOTE(review): wildcard origins combined with credentials is very permissive;
# restrict allow_origins for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all for local dev
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    # Let the frontend read the inference metadata set by /predict.
    expose_headers=[CONFIDENCE_HEADER, LATENCY_HEADER],
)


@app.get("/health")
def health_check():
    """Health check endpoint to ensure API and Model are ready."""
    return {"status": "healthy", "model_loaded": prediction_engine.model is not None}


def _mask_to_png_bytes(mask_array):
    """Encode a mask array as 8-bit grayscale PNG bytes.

    A trailing singleton channel axis (e.g. ``(H, W, 1)``) is squeezed away so
    Pillow receives a 2-D array. Values are cast to ``uint8`` without any
    rescaling. NOTE(review): this assumes the engine already returns values in
    0-255; a 0-1 probability mask would encode as (almost) all black.
    """
    if mask_array.ndim == 3 and mask_array.shape[-1] == 1:
        mask_array = np.squeeze(mask_array, axis=-1)
    mask_image = Image.fromarray(mask_array.astype(np.uint8), mode="L")
    buffer = io.BytesIO()
    mask_image.save(buffer, format="PNG")
    return buffer.getvalue()


@app.post("/predict")
async def predict_spill(file: UploadFile = File(...)):
    """Run oil-spill segmentation on an uploaded image and return the mask.

    Args:
        file: Multipart upload whose declared content type must be ``image/*``.

    Returns:
        An ``image/png`` response containing the predicted mask. The model
        confidence (as a percentage, rounded to 2 decimals) is sent in
        ``X-Confidence-Score``, and the time spent reading, predicting, and
        encoding is sent in ``X-Inference-Latency-Ms``.

    Raises:
        HTTPException: 400 if the upload is not declared as an image;
            500 if reading, inference, or encoding fails.
    """
    # content_type is None when the multipart part carries no Content-Type
    # header; treat that as "not an image" instead of crashing.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Invalid file format. Upload an image.")

    try:
        start_time = time.perf_counter()  # monotonic clock for durations
        contents = await file.read()

        # NOTE(review): predict() is synchronous and runs on the event loop,
        # so it blocks every other request (including /health) while it runs.
        # Consider fastapi.concurrency.run_in_threadpool once the engine is
        # confirmed to be thread-safe.
        mask_array, confidence = prediction_engine.predict(contents)
        png_bytes = _mask_to_png_bytes(mask_array)

        latency_ms = int((time.perf_counter() - start_time) * 1000)
        # float() normalises numpy scalars so the header is a plain number.
        confidence_pct = round(float(confidence) * 100, 2)
    except HTTPException:
        # Preserve deliberate HTTP errors instead of rewrapping them as 500s.
        raise
    except Exception as e:
        logger.exception("Error during prediction")
        raise HTTPException(status_code=500, detail=str(e)) from e

    headers = {
        CONFIDENCE_HEADER: str(confidence_pct),
        LATENCY_HEADER: str(latency_ms),
    }
    return Response(content=png_bytes, media_type="image/png", headers=headers)


# Serve the built frontend files. This must be registered AFTER the API routes:
# a mount at "/" matches every path and Starlette dispatches to the first
# matching route, so mounting it earlier would shadow /health and /predict.
# Ensure 'frontend/dist' exists after the build stage in Docker.
if os.path.isdir(FRONTEND_DIST_DIR):
    app.mount("/", StaticFiles(directory=FRONTEND_DIST_DIR, html=True), name="static")


if __name__ == "__main__":
    import uvicorn

    # Hugging Face Spaces defaults to port 7860
    port = int(os.environ.get("PORT", 7860))
    # NOTE(review): reload=True is a development setting, and the "main:app"
    # import string assumes this module is importable as `main` from the CWD.
    uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True)