File size: 4,876 Bytes
fd6cceb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import gradio as gr
import joblib
import numpy as np
from PIL import Image

# --- Imports from your training script ---
import os
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.preprocessing import image

# --- 1. Configuration (from training) ---
# Target image size fed to ResNet50. Note the two orderings in use:
# PIL's resize() takes (width, height), while Keras input_shape is
# (height, width, channels).
IMG_WIDTH = 224
IMG_HEIGHT = 224

# --- 2. Load All Models (Run once on startup) ---
print("Loading all models...")

# Load the SVM and Scaler.
# NOTE(security): joblib.load unpickles the file, which can execute arbitrary
# code. Only ship/load .pkl files you produced yourself.
try:
    svm_model = joblib.load("svm_classifier.pkl")
    scaler = joblib.load("scaler.pkl")
    print("SVM and Scaler loaded.")
except Exception as e:
    print(f"CRITICAL ERROR: Could not load .pkl files: {e}")
    # Stop the app if the models cannot be loaded. Chain the original error so
    # a corrupt or version-incompatible pickle is not hidden behind the
    # "could not find" message.
    raise FileNotFoundError("Could not find svm_classifier.pkl or scaler.pkl") from e

# Load the ResNet50 feature extractor (headless, global-average pooled).
try:
    feature_extractor = ResNet50(
        weights='imagenet',
        include_top=False,
        pooling='avg',
        # Keras expects (height, width, channels).
        input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
    )
    print("ResNet50 feature extractor loaded.")
except Exception as e:
    print(f"CRITICAL ERROR: Could not load ResNet50: {e}")
    # e.g. the ImageNet weights could not be downloaded or read. A missing
    # TensorFlow install would already have failed at import time.
    raise

print("--- All models loaded successfully! ---")


# --- 3. The Corrected Feature Extraction Function ---
def extract_features(pil_image):
    """
    Extract a ResNet50 feature vector from a single PIL image.

    Intended to replicate the preprocessing in train_classifier.py (not
    visible here). TODO confirm the resize interpolation matches training:
    keras ``load_img`` defaults to nearest-neighbour, while PIL's
    ``Image.resize`` uses its own default filter.

    Args:
        pil_image: A ``PIL.Image.Image`` in any mode; non-RGB images are
            converted to RGB first.

    Returns:
        1-D NumPy array of ResNet50 global-average-pooled features
        (length 2048).
    """
    # 0. ResNet50 needs exactly 3 channels. Grayscale ("L"), palette ("P") or
    #    RGBA images would otherwise yield a 1- or 4-channel array and fail.
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    # 1. Resize to the model's input size. PIL takes (width, height).
    pil_image_resized = pil_image.resize((IMG_WIDTH, IMG_HEIGHT))

    # 2. Convert to a float array of shape (height, width, 3).
    img_array = image.img_to_array(pil_image_resized)

    # 3. Add the batch dimension the model expects: (1, height, width, 3).
    img_batch = np.expand_dims(img_array, axis=0)

    # 4. Apply ResNet50's own preprocessing (channel order / mean centering).
    img_preprocessed = preprocess_input(img_batch)

    # 5. Feature vector of shape (1, 2048). verbose=0 stops Keras printing a
    #    progress bar to the server log on every request.
    features = feature_extractor.predict(img_preprocessed, verbose=0)

    # 6. Flatten to a 1-D vector of shape (2048,).
    return features.flatten()


# --- 4. The Main Prediction Function (Now More Robust) ---
def predict(input_image):
    """
    Classify one uploaded image; this is the function Gradio calls.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when nothing was submitted.

    Returns:
        A ``{label: probability}`` dict (labels as strings) for ``gr.Label``,
        or ``None`` for empty input. If the SVM does not support
        ``predict_proba`` (e.g. trained with ``probability=False``), returns
        ``{predicted_label: 1.0}``.

    Raises:
        gr.Error: on any failure, so Gradio shows a clean message in the UI.
    """
    # Explicit None check; object truthiness is ill-defined for array inputs.
    if input_image is None:
        return None

    # 1. Extract features using ResNet50 -> shape (2048,)
    try:
        features_1d = extract_features(input_image)
    except Exception as e:
        print(f"Error extracting features: {e}")
        raise gr.Error(f"Feature Extraction Failed: {e}")

    # 2. scikit-learn expects 2-D input: (1, n_features)
    features_2d = features_1d.reshape(1, -1)

    # Guard against an extractor/scaler mismatch. n_features_in_ only exists
    # on newer scikit-learn versions, so skip the check when it is absent.
    expected = getattr(scaler, "n_features_in_", None)
    if expected is not None and features_2d.shape[1] != expected:
        raise gr.Error(
            f"Feature Mismatch! Model expects {expected} features, "
            f"but got {features_2d.shape[1]}."
        )

    # 3. Scale the features
    try:
        scaled_features = scaler.transform(features_2d)
    except Exception as e:
        print(f"Error scaling features: {e}")
        raise gr.Error(f"Feature Scaling Failed: {e}")

    # 4. Predict probabilities (or a hard label as fallback)
    try:
        try:
            # SVC raises AttributeError here when trained with
            # probability=False. Keep this try tight so unrelated
            # AttributeErrors are not mistaken for that case.
            probabilities = svm_model.predict_proba(scaled_features)[0]
        except AttributeError:
            prediction = svm_model.predict(scaled_features)[0]
            return {str(prediction): 1.0}  # Return definite prediction

        # str() keys: numpy scalar labels (e.g. np.int64) may not serialize
        # for gr.Label, and this matches the fallback branch above.
        return {
            str(label): float(prob)
            for label, prob in zip(svm_model.classes_, probabilities)
        }
    except Exception as e:
        print(f"Error during prediction: {e}")
        raise gr.Error(f"Prediction Failed: {e}")


# --- 5. Create and Launch the Gradio Interface ---
# Input component: uploads reach predict() as PIL images.
image_input = gr.Image(label="Upload Otolith Image", type="pil")
# Output component: displays the top three entries of predict()'s dict.
label_output = gr.Label(label="Classification Results", num_top_classes=3)

app = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title="Otolith Classification Engine",
    description=(
        "Upload an image of an otolith to classify it. "
        "This app uses a ResNet50 feature extractor and an SVM classifier."
    ),
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    app.launch()