import gradio as gr
import joblib
import numpy as np
from PIL import Image

# --- Imports from your training script ---
import os  # noqa: F401  (kept for parity with the training script)
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.preprocessing import image

# --- 1. Configuration (from training) ---
IMG_WIDTH = 224
IMG_HEIGHT = 224

# --- 2. Load All Models (Run once on startup) ---
print("Loading all models...")

# Load the SVM and Scaler.
# NOTE: joblib.load unpickles its input; only ever point it at trusted files.
try:
    svm_model = joblib.load("svm_classifier.pkl")
    scaler = joblib.load("scaler.pkl")
    print("SVM and Scaler loaded.")
except FileNotFoundError as e:
    print(f"CRITICAL ERROR: Could not load .pkl files: {e}")
    # This will stop the app if models are missing
    raise FileNotFoundError(
        "Could not find svm_classifier.pkl or scaler.pkl"
    ) from e
except Exception as e:
    # The files exist but could not be unpickled (e.g. a library version
    # mismatch). Re-raise the real error instead of reporting a missing file.
    print(f"CRITICAL ERROR: Could not load .pkl files: {e}")
    raise

# Load the ResNet50 feature extractor
try:
    feature_extractor = ResNet50(
        weights="imagenet",
        include_top=False,
        pooling="avg",
        # Keras expects (height, width, channels).
        input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
    )
    print("ResNet50 feature extractor loaded.")
except Exception as e:
    # A missing tensorflow would already have failed at import time; a failure
    # here is presumably about obtaining/reading the ImageNet weights.
    print(f"CRITICAL ERROR: Could not load ResNet50: {e}")
    raise

print("--- All models loaded successfully! ---")


# --- 3. The Corrected Feature Extraction Function ---
def extract_features(pil_image):
    """Return the flattened ResNet50 feature vector for one PIL image.

    Intended to replicate the preprocessing in train_classifier.py: resize to
    (IMG_WIDTH, IMG_HEIGHT), convert to an array, add a batch axis, apply
    ResNet50's ``preprocess_input`` and run the average-pooled extractor.

    Args:
        pil_image: A ``PIL.Image.Image`` of any mode; it is converted to RGB
            first so the extractor always receives three channels.

    Returns:
        A 1-D ``np.ndarray`` of features (2048 values for pooled ResNet50).
    """
    # ResNet50 needs exactly 3 channels; grayscale ("L"), palette ("P") or
    # RGBA uploads would otherwise yield a 1- or 4-channel array.
    rgb_image = pil_image.convert("RGB")

    # 1. Resize to the model's expected input. PIL takes (width, height).
    # NOTE(review): PIL's default resampling filter may differ from how
    # train_classifier.py loaded images (keras load_img defaults to
    # "nearest"); confirm they match to avoid train/serve skew.
    resized = rgb_image.resize((IMG_WIDTH, IMG_HEIGHT))

    # 2. PIL -> float array of shape (IMG_HEIGHT, IMG_WIDTH, 3).
    img_array = image.img_to_array(resized)

    # 3. Add the batch dimension: (1, IMG_HEIGHT, IMG_WIDTH, 3).
    batch = np.expand_dims(img_array, axis=0)

    # 4. ResNet50-specific pixel preprocessing.
    batch = preprocess_input(batch)

    # 5. Feature vector of shape (1, 2048). verbose=0 stops Keras from
    # printing a progress bar on every request.
    features = feature_extractor.predict(batch, verbose=0)

    # 6. Flatten to (2048,).
    return features.flatten()


# --- 4. The Main Prediction Function (Now More Robust) ---
def predict(input_image):
    """Classify an uploaded image; this is the function Gradio calls.

    Args:
        input_image: PIL image supplied by ``gr.Image(type="pil")``, or
            ``None`` when nothing was uploaded.

    Returns:
        ``None`` for empty input; otherwise a ``{label: probability}`` dict
        for ``gr.Label``. Labels are converted to ``str`` in both branches so
        numeric class labels serialize consistently. If the SVM has no
        ``predict_proba`` (trained with ``probability=False``), the single
        predicted label is returned with confidence 1.0.

    Raises:
        gr.Error: when feature extraction, the feature-count check, scaling
            or prediction fails; Gradio shows the message in the UI.
    """
    if input_image is None:
        return None  # Handle empty input

    # 1. Extract features using the ResNet50 function -> shape (2048,)
    try:
        features_1d = extract_features(input_image)
    except Exception as e:
        print(f"Error extracting features: {e}")
        # gr.Error shows a clean error message in the UI
        raise gr.Error(f"Feature Extraction Failed: {e}") from e

    # 2. Reshape to 2D for the scaler (shape 1, n_features)
    features_2d = features_1d.reshape(1, -1)

    # Guard against a scaler trained on a different feature extractor.
    if features_2d.shape[1] != scaler.n_features_in_:
        raise gr.Error(
            f"Feature Mismatch! Model expects {scaler.n_features_in_} features, "
            f"but got {features_2d.shape[1]}."
        )

    # 3. Scale the features
    try:
        scaled_features = scaler.transform(features_2d)
    except Exception as e:
        print(f"Error scaling features: {e}")
        raise gr.Error(f"Feature Scaling Failed: {e}") from e

    # 4. Predict. Probabilities require the SVM to have been trained with
    # probability=True; scikit-learn then exposes predict_proba.
    try:
        if not hasattr(svm_model, "predict_proba"):
            # Fallback if probability=False: return a definite prediction.
            prediction = svm_model.predict(scaled_features)[0]
            return {str(prediction): 1.0}

        probabilities = svm_model.predict_proba(scaled_features)[0]
        return {
            str(label): float(prob)
            for label, prob in zip(svm_model.classes_, probabilities)
        }
    except Exception as e:
        print(f"Error during prediction: {e}")
        raise gr.Error(f"Prediction Failed: {e}") from e
# --- 5. Create and Launch the Gradio Interface ---
# Input: the uploaded image, delivered to predict() as a PIL object.
image_input = gr.Image(label="Upload Otolith Image", type="pil")
# Output: a label widget listing the three highest-scoring classes.
label_output = gr.Label(label="Classification Results", num_top_classes=3)

app = gr.Interface(
    predict,
    image_input,
    label_output,
    title="Otolith Classification Engine",
    description=(
        "Upload an image of an otolith to classify it. "
        "This app uses a ResNet50 feature extractor and an SVM classifier."
    ),
)


def _main():
    """Start the Gradio server; used only when this file runs as a script."""
    app.launch()


if __name__ == "__main__":
    _main()