# Oill_split / backend/main.py
# Author: Utkarshres32
# Initial commit for Hugging Face deployment (commit 998bc6e)
import io
import logging
import os
import sys
import time

import numpy as np
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, JSONResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image

# Put the parent of this file's directory on sys.path so the absolute import
# `backend.inference` below resolves regardless of how this file is launched.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from backend.inference import prediction_engine  # noqa: E402  (needs sys.path above)
logger = logging.getLogger(__name__)

# Custom response headers carrying inference metadata. Browsers only let
# cross-origin JavaScript read non-safelisted headers that the server lists in
# Access-Control-Expose-Headers, so they are registered with the CORS layer.
_METADATA_HEADERS = ["X-Confidence-Score", "X-Inference-Latency-Ms"]

app = FastAPI(
    title="Oil Spill Detection API",
    version="1.0.0",
    description="API to detect oil spills from satellite images using U-Net",
)

# Setup CORS for the React frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all for local dev
    # NOTE(review): the Starlette docs say wildcard origins should not be combined
    # with credentials; list explicit origins if cookies/auth are ever required.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=_METADATA_HEADERS,
)


@app.get("/health")
def health_check():
    """Report that the API is up and whether the prediction model is loaded."""
    return {"status": "healthy", "model_loaded": prediction_engine.model is not None}


@app.post("/predict")
async def predict_spill(file: UploadFile = File(...)):
    """Segment oil spills in an uploaded image and return the mask as a PNG.

    The response body is a single-channel (mode "L") PNG built from the mask
    returned by ``prediction_engine.predict``. Metadata is sent in headers:

    * ``X-Confidence-Score``: ``confidence * 100`` rounded to 2 decimals.
    * ``X-Inference-Latency-Ms``: time spent reading the upload, running
      inference and encoding the PNG, in whole milliseconds.

    Raises:
        HTTPException: 400 if the upload is not declared as ``image/*``;
            500 if reading, inference or encoding fails.
    """
    # content_type is None when the multipart part carries no Content-Type;
    # treat that as "not an image" instead of crashing with AttributeError.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Invalid file format. Upload an image.")

    try:
        # perf_counter is monotonic, so wall-clock adjustments can't skew latency.
        started = time.perf_counter()
        contents = await file.read()

        # Perform inference.
        # NOTE(review): predict() is called synchronously inside an async handler,
        # so it blocks the event loop while it runs. Consider run_in_threadpool
        # once the engine is confirmed to be thread-safe.
        mask_array, confidence = prediction_engine.predict(contents)

        # Drop a trailing singleton channel, e.g. (256, 256, 1) -> (256, 256), for Pillow.
        if mask_array.ndim == 3 and mask_array.shape[-1] == 1:
            mask_array = np.squeeze(mask_array, axis=-1)
        # NOTE(review): astype(np.uint8) does not rescale; if the engine returns a
        # 0/1 or [0, 1] mask the PNG will look almost black. Confirm the value range.
        mask_image = Image.fromarray(mask_array.astype(np.uint8), mode="L")

        buffer = io.BytesIO()
        mask_image.save(buffer, format="PNG")
        png_bytes = buffer.getvalue()

        latency_ms = int((time.perf_counter() - started) * 1000)

        # float() normalizes NumPy scalars / 1-element arrays so the header is a
        # plain number rather than something like "[87.65]".
        headers = {
            "X-Confidence-Score": str(round(float(confidence) * 100, 2)),
            "X-Inference-Latency-Ms": str(latency_ms),
        }
        return Response(content=png_bytes, media_type="image/png", headers=headers)
    except Exception as e:
        # Top-level request boundary: log the full traceback, then report a 500.
        logger.exception("Error during prediction")
        raise HTTPException(status_code=500, detail=str(e)) from e


# Serve the built frontend files. This MUST be registered after the API routes:
# Starlette matches routes in registration order and a Mount at "/" matches every
# path, so mounting it earlier would shadow /health and /predict.
# 'frontend/dist' (relative to the working directory) is expected to exist after
# the Docker build stage. isdir() avoids StaticFiles failing on a non-directory.
if os.path.isdir("frontend/dist"):
    app.mount("/", StaticFiles(directory="frontend/dist", html=True), name="static")
if __name__ == "__main__":
    import uvicorn

    # Hugging Face Spaces defaults to port 7860
    port = int(os.environ.get("PORT", 7860))
    # Auto-reload is a development feature (file watcher plus restarts on change)
    # and should not run in a deployment; opt in with RELOAD=1/true/yes.
    reload_enabled = os.environ.get("RELOAD", "").strip().lower() in {"1", "true", "yes"}
    # reload requires the "module:attr" import-string form rather than the app object.
    uvicorn.run("main:app", host="0.0.0.0", port=port, reload=reload_enabled)