Spaces:
Sleeping
Sleeping
File size: 4,876 Bytes
fd6cceb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 | import gradio as gr
import joblib
import numpy as np
from PIL import Image
# --- Imports from your training script ---
import os
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.preprocessing import image
# --- 1. Configuration (values taken from the training setup) ---
# Spatial size, in pixels, that every image is resized to before being fed
# to the ResNet50 feature extractor.
IMG_WIDTH, IMG_HEIGHT = 224, 224
# --- 2. Load All Models (Run once on startup) ---
# Everything below runs at import time; any failure aborts startup so the
# app never serves requests with a missing model.
print("Loading all models...")

# Load the SVM and Scaler.
# NOTE: joblib.load unpickles the file, which can execute arbitrary code.
# Only point these paths at artifacts you produced yourself.
try:
    svm_model = joblib.load("svm_classifier.pkl")
    scaler = joblib.load("scaler.pkl")
except FileNotFoundError as e:
    print(f"CRITICAL ERROR: Could not load .pkl files: {e}")
    # This will stop the app if models are missing
    raise FileNotFoundError("Could not find svm_classifier.pkl or scaler.pkl") from e
except Exception as e:
    # The file was reachable but could not be loaded (e.g. permissions, a
    # corrupt file, or an incompatible library version). Don't misreport
    # this as "file not found"; keep the real cause chained.
    print(f"CRITICAL ERROR: Could not load .pkl files: {e}")
    raise RuntimeError(
        "Could not load svm_classifier.pkl or scaler.pkl"
    ) from e
print("SVM and Scaler loaded.")

# Load the ResNet50 feature extractor (no classification head; global
# average pooling turns the last conv block into one vector per image).
try:
    feature_extractor = ResNet50(
        weights="imagenet",
        include_top=False,
        pooling="avg",
        # Keras expects input_shape as (height, width, channels).
        input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
    )
except Exception as e:
    print(f"CRITICAL ERROR: Could not load ResNet50: {e}")
    # weights="imagenet" downloads the pretrained weights on first use, so a
    # failure here is often a network/download problem. A missing
    # TensorFlow install would already have failed at import time.
    raise
print("ResNet50 feature extractor loaded.")

print("--- All models loaded successfully! ---")
# --- 3. Feature Extraction ---
def extract_features(pil_image):
    """
    Extract a ResNet50 feature vector from a single PIL image.

    Intended to replicate the preprocessing in train_classifier.py (that
    script is not part of this file).

    Args:
        pil_image: A PIL image in any mode (RGB, RGBA, L, P, ...). The
            caller's object is not modified.

    Returns:
        A 1-D NumPy array holding the extractor output for this image
        (2048 values for ResNet50 with pooling='avg').
    """
    # The extractor is built for 3-channel input. Grayscale (L), palette (P)
    # or RGBA uploads would otherwise yield 1 or 4 channels and fail inside
    # the model. convert() returns a new image.
    rgb_image = pil_image.convert("RGB")

    # PIL's resize takes (width, height).
    # NOTE(review): if training loaded images with keras image.load_img(
    # target_size=...), its default interpolation is "nearest", which may
    # differ from PIL's resize default. Confirm the resampling matches training.
    resized = rgb_image.resize((IMG_WIDTH, IMG_HEIGHT))

    # (height, width, 3) float array -> add batch axis -> (1, height, width, 3)
    img_array = image.img_to_array(resized)
    batch = np.expand_dims(img_array, axis=0)

    # ResNet50-specific normalisation (RGB->BGR and ImageNet mean
    # subtraction, per the Keras docs).
    batch = preprocess_input(batch)

    # verbose=0: don't print a Keras progress bar for every request.
    features = feature_extractor.predict(batch, verbose=0)

    # Drop the batch dimension: (1, n_features) -> (n_features,)
    return features.flatten()
# --- 4. The Main Prediction Function ---
def _svm_confidences(scaled_features):
    """
    Return a ``{label: confidence}`` dict for one scaled feature row.

    Uses ``svm_model.predict_proba`` when the model exposes it. Otherwise
    (e.g. an SVC trained with probability=False, where scikit-learn raises
    AttributeError) it returns the hard prediction with confidence 1.0.
    Keys are always ``str`` so both paths produce the same key type for
    ``gr.Label``.
    """
    try:
        probabilities = svm_model.predict_proba(scaled_features)[0]
    except AttributeError:
        # Fallback if probability=False
        prediction = svm_model.predict(scaled_features)[0]
        return {str(prediction): 1.0}  # Return definite prediction
    return {
        str(label): float(prob)
        for label, prob in zip(svm_model.classes_, probabilities)
    }


def predict(input_image):
    """
    Classify one uploaded image; this is the callback Gradio invokes.

    Args:
        input_image: The uploaded image as a PIL image (the interface declares
            ``type="pil"``), or ``None`` when nothing was uploaded.

    Returns:
        ``None`` for empty input; otherwise a ``{label: confidence}`` dict
        with ``str`` labels and ``float`` confidences.

    Raises:
        gr.Error: if feature extraction, scaling or prediction fails, or if
            the feature count does not match what the scaler was fitted on.
            The underlying exception, if any, is chained as ``__cause__``.
    """
    # Explicit None check: truthiness is not a reliable "empty" test for
    # image objects (a NumPy array, for instance, raises on bool()).
    if input_image is None:
        return None  # Handle empty input

    # 1. Extract features using the ResNet50 function -> 1-D vector
    try:
        features_1d = extract_features(input_image)
    except Exception as e:
        print(f"Error extracting features: {e}")
        # gr.Error shows a clean error message in the UI
        raise gr.Error(f"Feature Extraction Failed: {e}") from e

    # 2. The scaler expects a 2-D (n_samples, n_features) array
    features_2d = features_1d.reshape(1, -1)

    # Catch an extractor/scaler mismatch with a readable message instead of
    # an opaque scikit-learn error.
    if features_2d.shape[1] != scaler.n_features_in_:
        raise gr.Error(
            f"Feature Mismatch! Model expects {scaler.n_features_in_} features, "
            f"but got {features_2d.shape[1]}."
        )

    # 3. Scale the features
    try:
        scaled_features = scaler.transform(features_2d)
    except Exception as e:
        print(f"Error scaling features: {e}")
        raise gr.Error(f"Feature Scaling Failed: {e}") from e

    # 4. Predict: probabilities if available, otherwise the hard label
    try:
        return _svm_confidences(scaled_features)
    except Exception as e:
        print(f"Error during prediction: {e}")
        raise gr.Error(f"Prediction Failed: {e}") from e
# --- 5. Build the Gradio UI; start the server only when run as a script ---
# Input is delivered to predict() as a PIL image; output is rendered as a
# label widget that shows the three highest-confidence classes.
image_input = gr.Image(label="Upload Otolith Image", type="pil")
label_output = gr.Label(label="Classification Results", num_top_classes=3)

app = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title="Otolith Classification Engine",
    description=(
        "Upload an image of an otolith to classify it. "
        "This app uses a ResNet50 feature extractor and an SVM classifier."
    ),
)

if __name__ == "__main__":
    app.launch()