Bhavi23 commited on
Commit
16fb9de
·
verified ·
1 Parent(s): eec819a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +242 -0
app.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ import pandas as pd
5
+ import plotly.express as px
6
+ import plotly.graph_objects as go
7
+ from PIL import Image
8
+ import requests
9
+ import io
10
+ import logging
11
+ import time
12
+
13
+ # Set up logging
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Class mappings
18
+ CLASS_NAMES = {
19
+ 0: 'AcrimSat', 1: 'Aquarius', 2: 'Aura', 3: 'Calipso', 4: 'Cloudsat',
20
+ 5: 'CubeSat', 6: 'Debris', 7: 'Jason', 8: 'Sentinel-6', 9: 'TRMM', 10: 'Terra'
21
+ }
22
+
23
+ # Model configurations
24
+ MODEL_CONFIGS = {
25
+ "Custom CNN": {
26
+ "url": "https://huggingface.co/Bhavi23/Custom_CNN/resolve/main/best_multimodal_model.keras",
27
+ "input_shape": (224, 224, 3)
28
+ },
29
+ "MobileNetV2": {
30
+ "url": "https://huggingface.co/Bhavi23/MobilenetV2/resolve/main/multi_input_model_v1.keras",
31
+ "input_shape": (224, 224, 3)
32
+ },
33
+ "EfficientNetB0": {
34
+ "url": "https://huggingface.co/Bhavi23/EfficientNet_B0/resolve/main/efficientnet_model.keras",
35
+ "input_shape": (224, 224, 3)
36
+ },
37
+ "DenseNet121": {
38
+ "url": "https://huggingface.co/Bhavi23/DenseNet/resolve/main/densenet_model.keras",
39
+ "input_shape": (224, 224, 3)
40
+ }
41
+ }
42
+
43
+ # Performance metrics (for recommendation logic)
44
+ MODEL_METRICS = {
45
+ "Custom CNN": {"accuracy": 95.2, "inference_time": 45, "model_size": 25.3},
46
+ "MobileNetV2": {"accuracy": 92.8, "inference_time": 18, "model_size": 8.7},
47
+ "EfficientNetB0": {"accuracy": 96.4, "inference_time": 35, "model_size": 20.1},
48
+ "DenseNet121": {"accuracy": 94.7, "inference_time": 52, "model_size": 32.8}
49
+ }
50
+
51
+ def load_model(model_name):
52
+ """Load model from Hugging Face with error handling"""
53
+ try:
54
+ logger.info(f"Loading model: {model_name}")
55
+ url = MODEL_CONFIGS[model_name]["url"]
56
+ response = requests.get(url, timeout=60, stream=True)
57
+ response.raise_for_status()
58
+ if len(response.content) < 1000:
59
+ return None, f"Model {model_name} download failed - file too small"
60
+ model_bytes = io.BytesIO(response.content)
61
+ model = tf.keras.models.load_model(model_bytes)
62
+ logger.info(f"Successfully loaded model: {model_name}")
63
+ return model, None
64
+ except Exception as e:
65
+ logger.error(f"Error loading {model_name}: {str(e)}")
66
+ return None, f"Error loading {model_name}: {str(e)}"
67
+
68
+ def preprocess_image(image, target_size=(224, 224)):
69
+ """Preprocess image for model prediction"""
70
+ try:
71
+ if image.mode != 'RGB':
72
+ image = image.convert('RGB')
73
+ image = image.resize(target_size)
74
+ image_array = np.array(image) / 255.0
75
+ return np.expand_dims(image_array, axis=0), None
76
+ except Exception as e:
77
+ return None, f"Error preprocessing image: {str(e)}"
78
+
79
+ def predict_with_model(model, image, model_name):
80
+ """Make prediction with a specific model"""
81
+ if model is None:
82
+ return None
83
+ try:
84
+ start_time = time.time()
85
+ predictions = model.predict(image, verbose=0)
86
+ inference_time = (time.time() - start_time) * 1000
87
+ predicted_class = np.argmax(predictions[0])
88
+ confidence = np.max(predictions[0]) * 100
89
+ if predicted_class not in CLASS_NAMES:
90
+ return None
91
+ return {
92
+ 'class': predicted_class,
93
+ 'class_name': CLASS_NAMES[predicted_class],
94
+ 'confidence': confidence,
95
+ 'inference_time': inference_time,
96
+ 'probabilities': predictions[0]
97
+ }
98
+ except Exception as e:
99
+ logger.error(f"Prediction error with {model_name}: {str(e)}")
100
+ return None
101
+
102
+ def recommend_best_model(predictions):
103
+ """Recommend the best model based on confidence and performance"""
104
+ if not predictions:
105
+ return "EfficientNetB0"
106
+ recommendations = {}
107
+ for model_name, pred in predictions.items():
108
+ if pred:
109
+ base_score = MODEL_METRICS[model_name]["accuracy"]
110
+ confidence_bonus = pred['confidence'] * 0.1
111
+ speed_bonus = max(0, 100 - MODEL_METRICS[model_name]["inference_time"]) * 0.05
112
+ recommendations[model_name] = base_score + confidence_bonus + speed_bonus
113
+ return max(recommendations, key=recommendations.get) if recommendations else "EfficientNetB0"
114
+
115
+ def create_confidence_plot(predictions):
116
+ """Create a bar plot for model confidence comparison"""
117
+ if not predictions:
118
+ return None
119
+ confidences = [pred['confidence'] for pred in predictions.values() if pred]
120
+ model_names = [name for name, pred in predictions.items() if pred]
121
+ recommended_model = recommend_best_model(predictions)
122
+ fig = go.Figure()
123
+ fig.add_trace(go.Bar(
124
+ x=model_names,
125
+ y=confidences,
126
+ marker_color=['gold' if name == recommended_model else 'lightblue' for name in model_names],
127
+ text=[f'{c:.1f}%' for c in confidences],
128
+ textposition='auto'
129
+ ))
130
+ fig.update_layout(
131
+ title="Prediction Confidence by Model",
132
+ xaxis_title="Models",
133
+ yaxis_title="Confidence (%)",
134
+ height=400
135
+ )
136
+ return fig
137
+
138
+ def create_probability_plot(predictions, recommended_model):
139
+ """Create a bar plot for top 5 class probabilities of the recommended model"""
140
+ if recommended_model not in predictions or not predictions[recommended_model]:
141
+ return None
142
+ probs = predictions[recommended_model]['probabilities']
143
+ prob_df = pd.DataFrame({
144
+ 'Class': [CLASS_NAMES[i] for i in range(len(probs))],
145
+ 'Probability': probs * 100
146
+ }).sort_values('Probability', ascending=False).head(5)
147
+ fig = px.bar(
148
+ prob_df,
149
+ x='Probability',
150
+ y='Class',
151
+ orientation='h',
152
+ title=f"Top 5 Class Probabilities - {recommended_model}",
153
+ color='Probability',
154
+ color_continuous_scale='viridis'
155
+ )
156
+ fig.update_layout(height=400)
157
+ return fig
158
+
159
+ def classify_image(image, selected_models):
160
+ """Main function to classify an image and return results"""
161
+ if image is None:
162
+ return "Please upload an image.", None, None, None, None
163
+ if not selected_models:
164
+ return "Please select at least one model.", None, None, None, None
165
+
166
+ processed_image, error = preprocess_image(image)
167
+ if error:
168
+ return error, None, None, None, None
169
+
170
+ predictions = {}
171
+ results_data = []
172
+ for model_name in selected_models:
173
+ model, error = load_model(model_name)
174
+ if error:
175
+ results_data.append({'Model': model_name, 'Error': error})
176
+ continue
177
+ pred = predict_with_model(model, processed_image, model_name)
178
+ if pred:
179
+ predictions[model_name] = pred
180
+ results_data.append({
181
+ 'Model': model_name,
182
+ 'Predicted Class': pred['class_name'],
183
+ 'Confidence (%)': f"{pred['confidence']:.1f}%",
184
+ 'Inference Time (ms)': f"{pred['inference_time']:.1f}"
185
+ })
186
+ else:
187
+ results_data.append({'Model': model_name, 'Error': f"Prediction failed for {model_name}"})
188
+
189
+ recommended_model = recommend_best_model(predictions)
190
+ results_df = pd.DataFrame(results_data)
191
+ confidence_plot = create_confidence_plot(predictions)
192
+ probability_plot = create_probability_plot(predictions, recommended_model)
193
+
194
+ return (
195
+ f"**Recommended Model**: {recommended_model}",
196
+ results_df,
197
+ image,
198
+ confidence_plot,
199
+ probability_plot
200
+ )
201
+
202
+ # Gradio interface
203
+ with gr.Blocks(title="Satellite Classification Dashboard") as demo:
204
+ gr.Markdown("# 🛰️ Satellite Classification Dashboard")
205
+ gr.Markdown("Upload a satellite image and select models to classify it into one of 11 categories. View predictions, confidence scores, and visualizations.")
206
+
207
+ with gr.Row():
208
+ with gr.Column(scale=1):
209
+ image_input = gr.Image(type="pil", label="Upload Satellite Image (PNG, JPG, JPEG)")
210
+ model_select = gr.Dropdown(
211
+ choices=list(MODEL_CONFIGS.keys()),
212
+ value=["EfficientNetB0"],
213
+ multiselect=True,
214
+ label="Select Models"
215
+ )
216
+ classify_button = gr.Button("Classify Image", variant="primary")
217
+ with gr.Column(scale=2):
218
+ output_text = gr.Markdown(label="Prediction Results")
219
+ output_table = gr.Dataframe(label="Prediction Details")
220
+ output_image = gr.Image(label="Uploaded Image")
221
+
222
+ with gr.Row():
223
+ confidence_plot = gr.Plot(label="Confidence Comparison")
224
+ probability_plot = gr.Plot(label="Class Probabilities")
225
+
226
+ classify_button.click(
227
+ fn=classify_image,
228
+ inputs=[image_input, model_select],
229
+ outputs=[output_text, output_table, output_image, confidence_plot, probability_plot]
230
+ )
231
+
232
+ gr.Markdown("""
233
+ ### Supported Classes
234
+ - AcrimSat, Aquarius, Aura, Calipso, Cloudsat, CubeSat, Debris, Jason, Sentinel-6, TRMM, Terra
235
+ ### Available Models
236
+ - **Custom CNN**: Tailored for satellite imagery (95.2% accuracy)
237
+ - **MobileNetV2**: Lightweight and fast (92.8% accuracy, 18ms inference)
238
+ - **EfficientNetB0**: Best accuracy-efficiency balance (96.4% accuracy)
239
+ - **DenseNet121**: Complex pattern recognition (94.7% accuracy)
240
+ """)
241
+
242
+ demo.launch()