# Author: chinmay0805
# Commit: fd6cceb ("Upload 4 files")
import gradio as gr
import joblib
import numpy as np
from PIL import Image
# --- Imports from your training script ---
import os
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.preprocessing import image
# --- 1. Configuration (from training) ---
# Pixel resolution each image is resized to before ResNet50 sees it. This
# should match the size used when the SVM/scaler were trained.
IMG_WIDTH, IMG_HEIGHT = 224, 224
# --- 2. Load All Models (Run once on startup) ---
print("Loading all models...")

# Load the SVM and Scaler from the current working directory.
# NOTE: joblib.load unpickles arbitrary objects -- only load trusted files.
try:
    svm_model = joblib.load("svm_classifier.pkl")
    scaler = joblib.load("scaler.pkl")
    print("SVM and Scaler loaded.")
except FileNotFoundError as e:
    print(f"CRITICAL ERROR: Could not load .pkl files: {e}")
    # This will stop the app if models are missing.
    raise FileNotFoundError("Could not find svm_classifier.pkl or scaler.pkl") from e
except Exception as e:
    print(f"CRITICAL ERROR: Could not load .pkl files: {e}")
    # The files exist but could not be loaded (e.g. a library version
    # mismatch); re-raise the real error rather than reporting "not found".
    raise

# Load the ResNet50 feature extractor: ImageNet weights, no classifier head,
# global average pooling so each image yields a single feature vector.
try:
    feature_extractor = ResNet50(
        weights="imagenet",
        include_top=False,
        pooling="avg",
        # Keras image models take (height, width, channels).
        input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
    )
    print("ResNet50 feature extractor loaded.")
except Exception as e:
    print(f"CRITICAL ERROR: Could not load ResNet50: {e}")
    # TensorFlow itself is imported at the top of the file, so a failure here
    # is more likely e.g. the ImageNet weights failing to download.
    raise

print("--- All models loaded successfully! ---")
# --- 3. The Corrected Feature Extraction Function ---
def extract_features(pil_image):
    """
    Convert a single PIL image into a flat ResNet50 feature vector.

    Intended to replicate the preprocessing in train_classifier.py (that
    script is not visible here -- confirm its resize settings match).

    Args:
        pil_image: A ``PIL.Image.Image`` of any size and color mode.

    Returns:
        numpy.ndarray: 1-D feature vector produced by ``feature_extractor``
        (2048 values for the ``pooling='avg'`` ResNet50 loaded above).
    """
    # ResNet50 needs exactly 3 channels. RGBA PNGs or grayscale ("L") images
    # would otherwise yield a 4- or 1-channel array and crash the model.
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    # PIL's resize takes (width, height).
    pil_image_resized = pil_image.resize((IMG_WIDTH, IMG_HEIGHT))

    # (height, width, 3) array.
    img_array = image.img_to_array(pil_image_resized)

    # Add the batch axis the model expects: (1, height, width, 3).
    img_batch = np.expand_dims(img_array, axis=0)

    # ResNet50's own preprocessing (channel reorder + ImageNet mean subtraction).
    img_preprocessed = preprocess_input(img_batch)

    # verbose=0 keeps Keras from printing a progress bar on every request.
    features = feature_extractor.predict(img_preprocessed, verbose=0)

    # Drop the batch axis -> shape (n_features,).
    return features.flatten()
# --- 4. The Main Prediction Function (Now More Robust) ---
def predict(input_image):
    """
    Classify one uploaded image; this is the callback wired into Gradio.

    Args:
        input_image: ``PIL.Image.Image`` supplied by ``gr.Image(type="pil")``,
            or ``None`` when nothing has been uploaded.

    Returns:
        dict[str, float] | None: ``{label: confidence}`` for ``gr.Label``, or
        ``None`` for empty input. Without probability support on the SVM, the
        single predicted label is returned with confidence 1.0.

    Raises:
        gr.Error: if feature extraction, scaling, or prediction fails, or the
            feature length does not match the scaler, so the UI shows a
            readable message.
    """
    if input_image is None:
        return None  # Handle empty input

    # 1. Image -> 1-D ResNet50 feature vector.
    try:
        features_1d = extract_features(input_image)
    except Exception as e:
        print(f"Error extracting features: {e}")
        raise gr.Error(f"Feature Extraction Failed: {e}") from e

    # 2. scikit-learn transformers expect 2-D (n_samples, n_features).
    features_2d = features_1d.reshape(1, -1)

    # Catch a scaler fitted on a different feature extractor early. Older
    # scikit-learn versions may not expose n_features_in_; skip the check then.
    expected = getattr(scaler, "n_features_in_", None)
    if expected is not None and features_2d.shape[1] != expected:
        raise gr.Error(
            f"Feature Mismatch! Model expects {expected} features, "
            f"but got {features_2d.shape[1]}."
        )

    # 3. Scale the features.
    try:
        scaled_features = scaler.transform(features_2d)
    except Exception as e:
        print(f"Error scaling features: {e}")
        raise gr.Error(f"Feature Scaling Failed: {e}") from e

    # 4. Predict. predict_proba is only available when the SVM was trained
    # with probability=True; check for it explicitly so unrelated errors are
    # not silently turned into a hard prediction.
    try:
        if hasattr(svm_model, "predict_proba"):
            probabilities = svm_model.predict_proba(scaled_features)[0]
            # str() keys so both branches return the same label type.
            return {
                str(label): float(prob)
                for label, prob in zip(svm_model.classes_, probabilities)
            }
        prediction = svm_model.predict(scaled_features)[0]
        return {str(prediction): 1.0}  # Return definite prediction
    except Exception as e:
        print(f"Error during prediction: {e}")
        raise gr.Error(f"Prediction Failed: {e}") from e
# --- 5. Create and Launch the Gradio Interface ---
# Input component: hands the uploaded image to `predict` as a PIL object.
image_input = gr.Image(label="Upload Otolith Image", type="pil")
# Output component: shows the three highest-confidence labels from `predict`.
label_output = gr.Label(label="Classification Results", num_top_classes=3)

app = gr.Interface(
    title="Otolith Classification Engine",
    description=(
        "Upload an image of an otolith to classify it. "
        "This app uses a ResNet50 feature extractor and an SVM classifier."
    ),
    fn=predict,
    inputs=image_input,
    outputs=label_output,
)

if __name__ == "__main__":
    # Only start the web server when executed directly, not when imported.
    app.launch()