chinmay0805 commited on
Commit
fd6cceb
·
verified ·
1 Parent(s): 693b1bd

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +137 -0
  2. requirements.txt +0 -0
  3. scaler.pkl +3 -0
  4. svm_classifier.pkl +3 -0
app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import joblib
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ # --- Imports from your training script ---
7
+ import os
8
+ from tensorflow.keras.applications import ResNet50
9
+ from tensorflow.keras.applications.resnet50 import preprocess_input
10
+ from tensorflow.keras.preprocessing import image
11
+
12
+ # --- 1. Configuration (from training) ---
13
+ IMG_WIDTH = 224
14
+ IMG_HEIGHT = 224
15
+
16
+ # --- 2. Load All Models (Run once on startup) ---
17
+ print("Loading all models...")
18
+
19
+ # Load the SVM and Scaler
20
+ try:
21
+ svm_model = joblib.load("svm_classifier.pkl")
22
+ scaler = joblib.load("scaler.pkl")
23
+ print("SVM and Scaler loaded.")
24
+ except Exception as e:
25
+ print(f"CRITICAL ERROR: Could not load .pkl files: {e}")
26
+ # This will stop the app if models are missing
27
+ raise FileNotFoundError("Could not find svm_classifier.pkl or scaler.pkl")
28
+
29
+ # Load the ResNet50 feature extractor
30
+ try:
31
+ feature_extractor = ResNet50(weights='imagenet',
32
+ include_top=False,
33
+ pooling='avg',
34
+ input_shape=(IMG_WIDTH, IMG_HEIGHT, 3))
35
+ print("ResNet50 feature extractor loaded.")
36
+ except Exception as e:
37
+ print(f"CRITICAL ERROR: Could not load ResNet50: {e}")
38
+ # This often happens if tensorflow is not installed
39
+ raise e
40
+
41
+ print("--- All models loaded successfully! ---")
42
+
43
+
44
+ # --- 3. The Corrected Feature Extraction Function ---
45
+ def extract_features(pil_image):
46
+ """
47
+ Processes a single PIL image and extracts ResNet50 features,
48
+ replicating the logic from train_classifier.py.
49
+ """
50
+
51
+ # 1. Resize the image to match model's expected input (224, 224)
52
+ # We use PIL's resize, as the input is already a PIL object
53
+ pil_image_resized = pil_image.resize((IMG_WIDTH, IMG_HEIGHT))
54
+
55
+ # 2. Convert PIL image to NumPy array (shape: 224, 224, 3)
56
+ img_array = image.img_to_array(pil_image_resized)
57
+
58
+ # 3. Add batch dimension (model expects 1, 224, 224, 3)
59
+ img_array_expanded = np.expand_dims(img_array, axis=0)
60
+
61
+ # 4. Preprocess the image for ResNet50 (handles color/pixel scaling)
62
+ img_preprocessed = preprocess_input(img_array_expanded)
63
+
64
+ # 5. Get the feature vector (shape: 1, 2048)
65
+ features = feature_extractor.predict(img_preprocessed)
66
+
67
+ # 6. Return the flattened 1D feature vector (shape: 2048,)
68
+ return features.flatten()
69
+
70
+
71
+ # --- 4. The Main Prediction Function (Now More Robust) ---
72
+ def predict(input_image):
73
+ """
74
+ The main prediction function called by Gradio.
75
+ """
76
+ if not input_image:
77
+ return None # Handle empty input
78
+
79
+ # 1. Extract features using the ResNet50 function
80
+ try:
81
+ # features_1d will have shape (2048,)
82
+ features_1d = extract_features(input_image)
83
+ except Exception as e:
84
+ print(f"Error extracting features: {e}")
85
+ # gr.Error shows a clean error message in the UI
86
+ raise gr.Error(f"Feature Extraction Failed: {e}")
87
+
88
+ # 2. Reshape to 2D for the scaler (shape 1, 2048)
89
+ features_2d = features_1d.reshape(1, -1)
90
+
91
+ # Check shape just in case
92
+ if features_2d.shape[1] != scaler.n_features_in_:
93
+ raise gr.Error(
94
+ f"Feature Mismatch! Model expects {scaler.n_features_in_} features, "
95
+ f"but got {features_2d.shape[1]}."
96
+ )
97
+
98
+ # 3. Scale the features
99
+ try:
100
+ scaled_features = scaler.transform(features_2d)
101
+ except Exception as e:
102
+ print(f"Error scaling features: {e}")
103
+ raise gr.Error(f"Feature Scaling Failed: {e}")
104
+
105
+ # 4. Predict probabilities
106
+ try:
107
+ # Ensure your SVM was trained with probability=True
108
+ probabilities = svm_model.predict_proba(scaled_features)[0]
109
+ class_labels = svm_model.classes_
110
+
111
+ # Create a {label: probability} dictionary
112
+ confidences = {label: float(prob) for label, prob in zip(class_labels, probabilities)}
113
+ return confidences
114
+
115
+ except AttributeError:
116
+ # Fallback if probability=False
117
+ prediction = svm_model.predict(scaled_features)[0]
118
+ return {str(prediction): 1.0} # Return definite prediction
119
+ except Exception as e:
120
+ print(f"Error during prediction: {e}")
121
+ raise gr.Error(f"Prediction Failed: {e}")
122
+
123
+
124
+ # --- 5. Create and Launch the Gradio Interface ---
125
+ image_input = gr.Image(type="pil", label="Upload Otolith Image")
126
+ label_output = gr.Label(num_top_classes=3, label="Classification Results")
127
+
128
+ app = gr.Interface(
129
+ fn=predict,
130
+ inputs=image_input,
131
+ outputs=label_output,
132
+ title="Otolith Classification Engine",
133
+ description="Upload an image of an otolith to classify it. This app uses a ResNet50 feature extractor and an SVM classifier."
134
+ )
135
+
136
+ if __name__ == "__main__":
137
+ app.launch()
requirements.txt ADDED
Binary file (214 Bytes). View file
 
scaler.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f41b7222c4c4ff58ef935b47784c849f8c8b39197caac51ec2c774fc56a27e4d
3
+ size 49767
svm_classifier.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb5a45a2335d36f0ee3b453a60dbbcf80b2f1d50ab2769d9b8035f7dfe7e5430
3
+ size 4674815