Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import joblib | |
| import numpy as np | |
| from PIL import Image | |
| # --- Imports from your training script --- | |
| import os | |
| from tensorflow.keras.applications import ResNet50 | |
| from tensorflow.keras.applications.resnet50 import preprocess_input | |
| from tensorflow.keras.preprocessing import image | |
# --- 1. Configuration (from training) ---
# Spatial size of the images fed to ResNet50; must match the size used when
# the SVM/scaler were trained.
IMG_WIDTH, IMG_HEIGHT = 224, 224
# --- 2. Load All Models (Run once on startup) ---
print("Loading all models...")

# Load the SVM and Scaler.
# NOTE: joblib.load unpickles arbitrary Python objects; only ship trusted
# .pkl files alongside this app.
try:
    svm_model = joblib.load("svm_classifier.pkl")
    scaler = joblib.load("scaler.pkl")
    print("SVM and Scaler loaded.")
except FileNotFoundError as e:
    print(f"CRITICAL ERROR: Could not load .pkl files: {e}")
    # Stop the app: nothing can be classified without these artifacts.
    raise FileNotFoundError("Could not find svm_classifier.pkl or scaler.pkl") from e
except Exception as e:
    # The files exist but could not be loaded (e.g. an unpickling/version
    # error). Re-raise the real error rather than reporting a missing file.
    print(f"CRITICAL ERROR: Could not load .pkl files: {e}")
    raise

# Load the ResNet50 feature extractor (no classification head, global
# average pooling so the output is one vector per image).
try:
    feature_extractor = ResNet50(
        weights="imagenet",
        include_top=False,
        pooling="avg",
        # Keras input_shape is (height, width, channels).
        input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
    )
    print("ResNet50 feature extractor loaded.")
except Exception as e:
    print(f"CRITICAL ERROR: Could not load ResNet50: {e}")
    # A missing TensorFlow install would already fail at import time; a
    # failure here is more likely the ImageNet weights failing to download
    # or load. Re-raise with the original traceback.
    raise
print("--- All models loaded successfully! ---")
# --- 3. Feature Extraction ---
def extract_features(pil_image):
    """Extract a ResNet50 feature vector from a single PIL image.

    Intended to replicate the preprocessing in train_classifier.py. It resizes
    to (IMG_WIDTH, IMG_HEIGHT), converts to an array, adds a batch axis,
    applies ResNet50 ``preprocess_input`` and runs ``feature_extractor``.

    Args:
        pil_image: A ``PIL.Image.Image`` in any mode. It is converted to RGB
            so that grayscale, RGBA or palette uploads still yield the
            3 channels the network was built for.

    Returns:
        The flattened 1-D feature vector (2048 values for ResNet50 with
        ``pooling='avg'``).
    """
    # Force 3 channels: "L" would give a 1-channel array and "RGBA" a
    # 4-channel one, both of which the 3-channel ResNet50 input rejects.
    rgb_image = pil_image.convert("RGB")

    # PIL's resize takes (width, height).
    # NOTE(review): keras `image.load_img(..., target_size=...)` resizes with
    # nearest-neighbour by default, while recent Pillow defaults to bicubic
    # here. If training used load_img, confirm the resample filter matches.
    resized = rgb_image.resize((IMG_WIDTH, IMG_HEIGHT))

    # (H, W, 3) array -> (1, H, W, 3) batch of one.
    batch = np.expand_dims(image.img_to_array(resized), axis=0)

    # ResNet50-specific pixel preprocessing (channel order / mean subtraction).
    batch = preprocess_input(batch)

    # verbose=0 keeps Keras from printing a progress bar on every request.
    features = feature_extractor.predict(batch, verbose=0)

    # (1, n_features) -> (n_features,)
    return features.flatten()
# --- 4. The Main Prediction Function ---
def _svm_confidences(scaled_features):
    """Return a ``{label: confidence}`` dict for the first row of the input.

    Uses ``svm_model.predict_proba`` when the model provides it. If that call
    raises AttributeError, the model is treated as having no probabilities.
    This is what an SVC trained with ``probability=False`` does. In that case
    ``svm_model.predict`` is used and its label is reported with 1.0.
    Labels are passed through ``str`` on both paths so ``gr.Label`` always
    receives the same key type.
    """
    try:
        probabilities = svm_model.predict_proba(scaled_features)[0]
    except AttributeError:
        # Fallback if the SVM was trained without probability=True.
        prediction = svm_model.predict(scaled_features)[0]
        return {str(prediction): 1.0}
    return {
        str(label): float(prob)
        for label, prob in zip(svm_model.classes_, probabilities)
    }


def predict(input_image):
    """Classify one uploaded image; called by Gradio on each submission.

    Args:
        input_image: The ``PIL.Image.Image`` delivered by the ``gr.Image``
            input, or ``None`` when nothing was uploaded.

    Returns:
        ``None`` for empty input. Otherwise returns a ``{label: confidence}``
        dict for ``gr.Label``. If the SVM exposes no probabilities, this is
        ``{predicted_label: 1.0}``.

    Raises:
        gr.Error: If feature extraction, the feature-count check, scaling or
            prediction fails. Gradio shows the message in the UI.
    """
    # Explicit None check: truth-testing the input is not a reliable
    # emptiness test (e.g. a numpy-array input would raise ValueError).
    if input_image is None:
        return None

    # 1. Extract the ResNet50 feature vector (1-D).
    try:
        features_1d = extract_features(input_image)
    except Exception as e:
        print(f"Error extracting features: {e}")
        raise gr.Error(f"Feature Extraction Failed: {e}") from e

    # 2. scikit-learn expects a 2-D (n_samples, n_features) array.
    features_2d = features_1d.reshape(1, -1)

    # Catch an extractor/scaler mismatch with a readable message.
    # n_features_in_ may be absent on scalers pickled by older scikit-learn.
    expected = getattr(scaler, "n_features_in_", None)
    if expected is not None and features_2d.shape[1] != expected:
        raise gr.Error(
            f"Feature Mismatch! Model expects {expected} features, "
            f"but got {features_2d.shape[1]}."
        )

    # 3. Scale the features with the scaler fitted at training time.
    try:
        scaled_features = scaler.transform(features_2d)
    except Exception as e:
        print(f"Error scaling features: {e}")
        raise gr.Error(f"Feature Scaling Failed: {e}") from e

    # 4. Predict. Both the probability path and the fallback are wrapped, so
    # any model failure surfaces as a clean gr.Error.
    try:
        return _svm_confidences(scaled_features)
    except Exception as e:
        print(f"Error during prediction: {e}")
        raise gr.Error(f"Prediction Failed: {e}") from e
# --- 5. Create and Launch the Gradio Interface ---
# Input component: the upload is handed to `predict` as a PIL image.
image_input = gr.Image(label="Upload Otolith Image", type="pil")
# Output component: shows the three highest-confidence labels from `predict`.
label_output = gr.Label(label="Classification Results", num_top_classes=3)

app = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title="Otolith Classification Engine",
    description=(
        "Upload an image of an otolith to classify it. "
        "This app uses a ResNet50 feature extractor and an SVM classifier."
    ),
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    app.launch()