File size: 3,017 Bytes
7a5bb5d
 
998bc6e
7a5bb5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
998bc6e
 
 
 
 
7a5bb5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
998bc6e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import io
import logging
import os
import sys
import threading
import time

import numpy as np
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, JSONResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image

# Put the parent of this file's directory on sys.path so the `backend`
# package can be imported below.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from backend.inference import prediction_engine

app = FastAPI(
    title="Oil Spill Detection API",
    version="1.0.0",
    description="API to detect oil spills from satellite images using U-Net",
)

# Cross-origin access for the React frontend. Every origin, method and header
# is permitted, which is intended for local development; narrow
# allow_origins before exposing the API publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

logger = logging.getLogger(__name__)

# Model calls run in a worker thread so they do not block the event loop.
# This lock keeps them one-at-a-time, because the thread-safety of
# prediction_engine.predict is not established here.
# NOTE(review): remove the lock if predict() is known to be thread-safe.
_inference_lock = threading.Lock()


@app.get("/health")
def health_check():
    """Health check endpoint reporting API liveness and model readiness.

    ``status`` is always ``"healthy"``; ``model_loaded`` is True when
    ``prediction_engine.model`` is not None.
    """
    return {"status": "healthy", "model_loaded": prediction_engine.model is not None}


def _mask_to_png(mask_array: np.ndarray) -> bytes:
    """Encode a segmentation mask as 8-bit grayscale PNG bytes."""
    # A trailing singleton channel, e.g. (256, 256, 1), is dropped because
    # Pillow's "L" mode needs a 2-D array.
    if mask_array.ndim == 3 and mask_array.shape[-1] == 1:
        mask_array = np.squeeze(mask_array, axis=-1)
    # Values are cast to uint8 without rescaling.
    # NOTE(review): if the model yields 0/1 labels or [0, 1] floats, the PNG
    # will look (almost) black; confirm this is what the frontend expects.
    mask_image = Image.fromarray(mask_array.astype(np.uint8), mode="L")
    buffer = io.BytesIO()
    mask_image.save(buffer, format="PNG")
    return buffer.getvalue()


def _run_inference(contents: bytes):
    """Run the model on raw image bytes and return ``(png_bytes, confidence)``.

    Synchronous and CPU-bound; meant to be called through ``run_in_threadpool``.
    """
    with _inference_lock:
        mask_array, confidence = prediction_engine.predict(contents)
    return _mask_to_png(mask_array), confidence


@app.post("/predict")
async def predict_spill(file: UploadFile = File(...)):
    """Receive an image, run inference, and return a PNG mask of the oil spill.

    Response headers:
        X-Confidence-Score: ``confidence * 100`` rounded to two decimals
            (presumably confidence is in [0, 1] -- TODO confirm against
            prediction_engine.predict).
        X-Inference-Latency-Ms: milliseconds from starting to read the upload
            until the PNG is encoded (includes any wait for the model lock).

    Raises:
        HTTPException(400): the part's content type is missing or not
            ``image/*``, or the uploaded file is empty.
        HTTPException(500): reading, inference or encoding failed; the detail
            carries the exception text.
    """
    # content_type is None when the client omits the part's Content-Type.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Invalid file format. Upload an image.")

    # perf_counter is monotonic, unlike time.time(), so it suits intervals.
    start_time = time.perf_counter()
    try:
        contents = await file.read()
        if not contents:
            raise HTTPException(status_code=400, detail="Empty file. Upload an image.")

        # Inference and PNG encoding are CPU-bound: run them off the event
        # loop so other requests (e.g. /health) are still served meanwhile.
        png_bytes, confidence = await run_in_threadpool(_run_inference, contents)

        latency_ms = int((time.perf_counter() - start_time) * 1000)

        # Metadata travels in headers; the expose header lets browser JS on
        # another origin read them.
        headers = {
            "Access-Control-Expose-Headers": "X-Confidence-Score, X-Inference-Latency-Ms",
            "X-Confidence-Score": str(round(confidence * 100, 2)),
            "X-Inference-Latency-Ms": str(latency_ms),
        }
        return Response(content=png_bytes, media_type="image/png", headers=headers)
    except HTTPException:
        # Client errors raised above must not be turned into 500s below.
        raise
    except Exception as e:
        logger.exception("Error during prediction")
        raise HTTPException(status_code=500, detail=str(e)) from e


# Serve the built frontend files. This mount is registered AFTER the API routes
# on purpose: Starlette matches routes in registration order and a Mount at "/"
# matches every path, so mounting it first would shadow /health and /predict.
# Ensure 'frontend/dist' exists after the build stage in Docker.
# NOTE(review): the path is resolved relative to the process working directory.
if os.path.isdir("frontend/dist"):
    app.mount("/", StaticFiles(directory="frontend/dist", html=True), name="static")

if __name__ == "__main__":
    import uvicorn

    # Listen on $PORT when set; otherwise use 7860, the Hugging Face Spaces default.
    listen_port = int(os.getenv("PORT", "7860"))
    # The app is given as an import string ("main:app") with reload enabled.
    uvicorn.run("main:app", host="0.0.0.0", port=listen_port, reload=True)