File size: 3,030 Bytes
7a5bb5d
 
 
 
 
 
 
 
 
 
9ddf356
998bc6e
 
 
7a5bb5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import tensorflow as tf
import numpy as np
from PIL import Image
import io
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from model.unet import build_unet
from model.loss_metrics import get_metrics, bce_dice_loss
from backend.utils.model_fetcher import download_if_missing

# Ensure the saved model file is present before anything else in this module
# runs (delegates to backend.utils.model_fetcher.download_if_missing). This
# executes at import time, ahead of the module-level OilSpillModel() below.
download_if_missing()

# Path to the saved Keras model, resolved relative to this file's directory.
MODEL_PATH = os.path.join(os.path.dirname(__file__), "../model/saved_models/oil_spill_unet_best.keras")
# Spatial size images are resized to before inference. It is passed to PIL's
# Image.resize, so it is read as (width, height); it is square here.
IMG_SIZE = (256, 256)

class OilSpillModel:
    """Wrap the U-Net oil-spill segmentation model for single-image inference.

    On construction, the saved Keras model at ``MODEL_PATH`` is loaded. If the
    file is missing or fails to load, an untrained U-Net with random weights is
    built instead so the application can still run during development. In that
    case ``is_stub`` is True and predictions are not from trained weights.
    """

    def __init__(self):
        self.model = None
        # True when self.model is the randomly initialized fallback rather than
        # trained weights, so callers can detect and report degraded mode.
        self.is_stub = False
        self.load_model()

    def load_model(self):
        """Load the trained model from MODEL_PATH, falling back to a stub.

        Any exception raised while loading is printed and replaced by the
        untrained stub model rather than propagated (deliberate best-effort).
        """
        print(f"Attempting to load model from: {os.path.abspath(MODEL_PATH)}")
        if not os.path.exists(MODEL_PATH):
            print("Trained model weights not found. Building an untrained stub model for development.")
            self._build_stub_model()
            return

        try:
            # The model was saved with custom loss/metric functions, and Keras
            # needs them by their saved names to deserialize it. get_metrics()
            # is called once; it stays inside the try so its failure also
            # triggers the stub fallback, as before.
            metrics = get_metrics()
            custom_objects = {
                'bce_dice_loss': bce_dice_loss,
                'dice_coef': metrics[0],
                'iou_metric': metrics[1],
            }
            self.model = tf.keras.models.load_model(MODEL_PATH, custom_objects=custom_objects)
        except Exception as e:
            print(f"Failed to load model from file: {e}")
            self._build_stub_model()
        else:
            self.is_stub = False
            print("Model loaded successfully.")

    def _build_stub_model(self):
        """Build the U-Net architecture with random weights (development fallback)."""
        # IMG_SIZE is PIL's (width, height); Keras input_shape is (height, width, channels).
        width, height = IMG_SIZE
        self.model = build_unet(input_shape=(height, width, 3))
        self.is_stub = True

    def predict(self, image_bytes, threshold=0.5):
        """Segment oil in an encoded image.

        Args:
            image_bytes: Encoded image file contents (any format PIL can open).
            threshold: Probability above which a pixel counts as oil. Defaults
                to 0.5, the previously hard-coded value.

        Returns:
            A ``(binary_mask, confidence)`` tuple:
            - ``binary_mask``: a uint8 array with the model's per-image output
              shape (presumably (256, 256, 1) for this U-Net). Oil pixels are
              255 and background pixels are 0.
            - ``confidence``: the mean predicted probability over pixels above
              ``threshold``, or 0.0 when no pixel exceeds it.
        """
        # Decode and force 3-channel RGB. The `with` releases the lazily
        # opened source once the converted, resized copy exists.
        with Image.open(io.BytesIO(image_bytes)) as src:
            img = src.convert("RGB").resize(IMG_SIZE)

        # Normalize to [0, 1] and add a leading batch dimension of 1.
        img_array = np.array(img, dtype=np.float32) / 255.0
        batch = np.expand_dims(img_array, axis=0)

        # verbose=0 suppresses Keras' per-call progress bar in server logs.
        pred_mask = self.model.predict(batch, verbose=0)[0]

        # Build the mask once and reuse it for both outputs.
        is_oil = pred_mask > threshold
        confidence = float(np.mean(pred_mask[is_oil])) if is_oil.any() else 0.0
        binary_mask = is_oil.astype(np.uint8) * 255

        return binary_mask, confidence

# Module-level singleton: it is created at import time, so the model is loaded
# (or stubbed) once and every importer shares this one OilSpillModel instance.
prediction_engine = OilSpillModel()