# Oill_split / backend / inference.py
# Author: Utkarshres32
# Commit message: Fix backend imports and package initialization
# Commit: 9ddf356
import os
import tensorflow as tf
import numpy as np
from PIL import Image
import io
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from model.unet import build_unet
from model.loss_metrics import get_metrics, bce_dice_loss
from backend.utils.model_fetcher import download_if_missing
# Fetch the model weights if they are not on disk yet. This has to happen
# before OilSpillModel is instantiated at the bottom of this module.
download_if_missing()

# Trained Keras model, resolved relative to this file's directory.
MODEL_PATH = os.path.join(
    os.path.dirname(__file__),
    "../model/saved_models/oil_spill_unet_best.keras",
)

# (width, height) that every input image is resized to before inference;
# passed directly to PIL's Image.resize.
IMG_SIZE = (256, 256)
class OilSpillModel:
    """Wrapper around the U-Net oil-spill segmentation network.

    On construction it tries to load the trained Keras model from
    ``MODEL_PATH``. If the file is missing or fails to load, it falls back to
    an untrained U-Net with random weights so the backend can still start
    during development. ``is_stub`` records which of the two is in use, so
    callers can tell when predictions are not meaningful.
    """

    def __init__(self):
        self.model = None
        # True when self.model is the untrained development stub rather than
        # the trained network loaded from MODEL_PATH.
        self.is_stub = False
        self.load_model()

    def load_model(self):
        """Load the trained model from MODEL_PATH, falling back to a stub.

        Failures are reported with print() and not raised. This is a
        deliberate best-effort so the service keeps starting without usable
        weights; check ``is_stub`` afterwards.
        """
        print(f"Attempting to load model from: {os.path.abspath(MODEL_PATH)}")
        if not os.path.exists(MODEL_PATH):
            print("Trained model weights not found. Building an untrained stub model for development.")
            self._build_stub_model()
            return
        try:
            # The saved model presumably references these custom loss/metric
            # functions by name, so Keras needs them to deserialize it.
            metrics = get_metrics()
            custom_objects = {
                'bce_dice_loss': bce_dice_loss,
                'dice_coef': metrics[0],
                'iou_metric': metrics[1],
            }
            self.model = tf.keras.models.load_model(MODEL_PATH, custom_objects=custom_objects)
        except Exception as e:
            print(f"Failed to load model from file: {e}")
            self._build_stub_model()
            return
        self.is_stub = False
        print("Model loaded successfully.")

    def _build_stub_model(self):
        """Build the U-Net architecture with random weights (development only)."""
        # IMG_SIZE is (width, height) as PIL expects. np.array() of the
        # resized image has shape (height, width, 3), so the network input
        # must use that order.
        width, height = IMG_SIZE
        self.model = build_unet(input_shape=(height, width, 3))
        self.is_stub = True

    def predict(self, image_bytes, threshold=0.5):
        """Segment oil spills in an encoded image.

        Args:
            image_bytes: Raw encoded image bytes in any format PIL can open.
                The image is converted to RGB and resized to IMG_SIZE.
            threshold: Probability above which a pixel is classified as oil.
                Defaults to 0.5.

        Returns:
            A ``(binary_mask, confidence)`` tuple:
            - ``binary_mask`` is a uint8 array shaped like the model's
              per-image output (presumably (H, W, 1)), with 255 for oil and
              0 for background.
            - ``confidence`` is the mean predicted probability of the pixels
              classified as oil, or 0.0 when no pixel exceeds ``threshold``.

        Raises:
            PIL.UnidentifiedImageError: if ``image_bytes`` is not a readable image.
        """
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        img = img.resize(IMG_SIZE)
        # Normalize to [0, 1] and add a batch dimension of 1.
        img_array = np.array(img, dtype=np.float32) / 255.0
        batch = np.expand_dims(img_array, axis=0)
        # verbose=0 suppresses Keras' progress bar, which would otherwise be
        # printed on every request.
        pred_mask = self.model.predict(batch, verbose=0)[0]
        return self._postprocess(pred_mask, threshold)

    @staticmethod
    def _postprocess(pred_mask, threshold=0.5):
        """Turn a probability map into a 0/255 uint8 mask and a confidence score."""
        is_oil = pred_mask > threshold
        oil_probs = pred_mask[is_oil]
        # Confidence = mean probability over the pixels predicted as oil.
        confidence = float(np.mean(oil_probs)) if oil_probs.size else 0.0
        binary_mask = is_oil.astype(np.uint8) * 255
        return binary_mask, confidence
# Shared module-level instance: the model is loaded once, when this module is
# first imported, and reused by every importer of `prediction_engine`.
prediction_engine = OilSpillModel()