Bhavi23 commited on
Commit
9eb2913
·
verified ·
1 Parent(s): 271a740

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +567 -376
app.py CHANGED
@@ -10,14 +10,23 @@ import requests
10
  import io
11
  from datetime import datetime
12
  import time
 
13
 
14
- # Configure page
15
- st.set_page_config(
16
- page_title="Satellite Classification Dashboard",
17
- page_icon="🛰️",
18
- layout="wide",
19
- initial_sidebar_state="expanded"
20
- )
 
 
 
 
 
 
 
 
21
 
22
  # Custom CSS for better styling
23
  st.markdown("""
@@ -51,6 +60,10 @@ st.markdown("""
51
  color: white;
52
  font-size: 1.2rem;
53
  }
 
 
 
 
54
  </style>
55
  """, unsafe_allow_html=True)
56
 
@@ -134,38 +147,77 @@ MODEL_METRICS = {
134
 
135
  @st.cache_resource
136
  def load_model(model_name):
137
- """Load model from HuggingFace with caching"""
138
  try:
139
- with st.spinner(f'Loading {model_name}...'):
140
- url = MODEL_CONFIGS[model_name]["url"]
141
- response = requests.get(url)
142
- response.raise_for_status()
 
 
 
 
 
 
 
143
 
144
- model_bytes = io.BytesIO(response.content)
 
 
 
145
  model = tf.keras.models.load_model(model_bytes)
 
146
  return model
 
 
 
 
 
 
 
 
 
 
147
  except Exception as e:
148
- st.error(f"Error loading {model_name}: {str(e)}")
 
149
  return None
150
 
151
  def preprocess_image(image, target_size=(224, 224)):
152
- """Preprocess image for model prediction"""
153
- if image.mode != 'RGB':
154
- image = image.convert('RGB')
155
- image = image.resize(target_size)
156
- image_array = np.array(image) / 255.0
157
- return np.expand_dims(image_array, axis=0)
 
 
 
 
158
 
159
  def predict_with_model(model, image, model_name):
160
- """Make prediction with a specific model"""
 
 
 
161
  try:
162
  start_time = time.time()
163
  predictions = model.predict(image, verbose=0)
164
  inference_time = (time.time() - start_time) * 1000 # Convert to ms
165
 
 
 
 
 
 
166
  predicted_class = np.argmax(predictions[0])
167
  confidence = np.max(predictions[0]) * 100
168
 
 
 
 
 
 
169
  return {
170
  'class': predicted_class,
171
  'class_name': CLASS_NAMES[predicted_class],
@@ -175,6 +227,7 @@ def predict_with_model(model, image, model_name):
175
  }
176
  except Exception as e:
177
  st.error(f"Prediction error with {model_name}: {str(e)}")
 
178
  return None
179
 
180
  def recommend_best_model(image_predictions):
@@ -199,402 +252,540 @@ def recommend_best_model(image_predictions):
199
  return "EfficientNetB0"
200
 
201
  def create_metrics_comparison():
202
- """Create comprehensive metrics comparison dashboard"""
203
-
204
- # Create subplots
205
- fig = make_subplots(
206
- rows=2, cols=2,
207
- subplot_titles=('Accuracy Comparison', 'Model Size vs Inference Time',
208
- 'Performance Metrics Radar', 'Training Efficiency'),
209
- specs=[[{"type": "bar"}, {"type": "scatter"}],
210
- [{"type": "scatterpolar"}, {"type": "bar"}]]
211
- )
212
-
213
- models = list(MODEL_METRICS.keys())
214
-
215
- # 1. Accuracy Comparison Bar Chart
216
- accuracies = [MODEL_METRICS[model]["accuracy"] for model in models]
217
- fig.add_trace(
218
- go.Bar(x=models, y=accuracies, name="Accuracy",
219
- marker_color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']),
220
- row=1, col=1
221
- )
222
-
223
- # 2. Model Size vs Inference Time Scatter
224
- sizes = [MODEL_METRICS[model]["model_size"] for model in models]
225
- times = [MODEL_METRICS[model]["inference_time"] for model in models]
226
- fig.add_trace(
227
- go.Scatter(x=sizes, y=times, mode='markers+text',
228
- text=models, textposition="top center",
229
- marker=dict(size=15, color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']),
230
- name="Size vs Speed"),
231
- row=1, col=2
232
- )
233
-
234
- # 3. Radar Chart for Performance Metrics
235
- metrics = ['accuracy', 'precision', 'recall', 'f1_score']
236
- for i, model in enumerate(models):
237
- values = [MODEL_METRICS[model][metric] for metric in metrics]
238
  fig.add_trace(
239
- go.Scatterpolar(r=values, theta=metrics, fill='toself',
240
- name=model, opacity=0.7),
241
- row=2, col=1
242
  )
243
-
244
- # 4. Training Time Comparison
245
- training_times = [MODEL_METRICS[model]["training_time"] for model in models]
246
- fig.add_trace(
247
- go.Bar(x=models, y=training_times, name="Training Time",
248
- marker_color=['#9467bd', '#8c564b', '#e377c2', '#7f7f7f']),
249
- row=2, col=2
250
- )
251
-
252
- # Update layout
253
- fig.update_layout(height=800, showlegend=True,
254
- title_text="Comprehensive Model Comparison Dashboard")
255
- fig.update_xaxes(title_text="Models", row=1, col=1)
256
- fig.update_yaxes(title_text="Accuracy (%)", row=1, col=1)
257
- fig.update_xaxes(title_text="Model Size (MB)", row=1, col=2)
258
- fig.update_yaxes(title_text="Inference Time (ms)", row=1, col=2)
259
- fig.update_xaxes(title_text="Models", row=2, col=2)
260
- fig.update_yaxes(title_text="Training Time (minutes)", row=2, col=2)
261
-
262
- return fig
263
-
264
- def create_class_distribution_chart():
265
- """Create class distribution visualization"""
266
- classes = list(CLASS_NAMES.values())
267
- samples = [7500 if cls != 'Debris' else 15000 for cls in classes]
268
- percentages = [8.33 if cls != 'Debris' else 16.67 for cls in classes]
269
-
270
- fig = go.Figure()
271
- fig.add_trace(go.Bar(
272
- x=classes,
273
- y=samples,
274
- text=[f'{s} ({p}%)' for s, p in zip(samples, percentages)],
275
- textposition='auto',
276
- marker_color=['#ff6b6b' if cls == 'Debris' else '#4ecdc4' for cls in classes]
277
- ))
278
-
279
- fig.update_layout(
280
- title="Class Distribution in Training Dataset",
281
- xaxis_title="Satellite Classes",
282
- yaxis_title="Number of Samples",
283
- height=400
284
- )
285
-
286
- return fig
287
-
288
- # Main App
289
- def main():
290
- # Header
291
- st.markdown('<h1 class="main-header">🛰️ Satellite Classification Dashboard</h1>',
292
- unsafe_allow_html=True)
293
-
294
- # Sidebar
295
- st.sidebar.title("Navigation")
296
- page = st.sidebar.selectbox("Choose a page",
297
- ["🏠 Home", "📊 Model Comparison", "🔍 Image Classification",
298
- "📈 Performance Analytics", "ℹ️ About Models"])
299
-
300
- if page == "🏠 Home":
301
- st.markdown("## Welcome to the Satellite Classification System")
302
 
303
- col1, col2 = st.columns(2)
 
 
 
 
 
 
 
 
 
304
 
305
- with col1:
306
- st.markdown("### 🎯 System Overview")
307
- st.write("""
308
- This dashboard provides comprehensive satellite classification using 4 different
309
- deep learning models. Upload satellite images to classify them into 11 different
310
- categories including various satellites and space debris.
311
- """)
312
-
313
- st.markdown("### 🛰️ Supported Classes")
314
- for i, (class_id, class_name) in enumerate(CLASS_NAMES.items()):
315
- if i < 6: # First column
316
- st.write(f"• **{class_name}**")
317
 
318
- with col2:
319
- st.markdown("### 🤖 Available Models")
320
- st.write("""
321
- - **Custom CNN**: Tailored architecture for satellite imagery
322
- - **MobileNetV2**: Lightweight and fast inference
323
- - **EfficientNetB0**: Best accuracy-efficiency balance
324
- - **DenseNet121**: Complex pattern recognition
325
- """)
326
-
327
- st.markdown("### 📊 Class Distribution")
328
- for i, (class_id, class_name) in enumerate(CLASS_NAMES.items()):
329
- if i >= 6: # Second column
330
- st.write(f"• **{class_name}**")
331
 
332
- # Class distribution chart
333
- st.plotly_chart(create_class_distribution_chart(), use_container_width=True)
334
-
335
- elif page == "📊 Model Comparison":
336
- st.markdown("## 📊 Model Performance Comparison")
 
 
 
 
337
 
338
- # Metrics table
339
- st.markdown("### Performance Metrics Summary")
340
- df_metrics = pd.DataFrame(MODEL_METRICS).T
341
- st.dataframe(df_metrics.style.highlight_max(axis=0), use_container_width=True)
 
 
 
 
 
 
 
342
 
343
- # Comprehensive comparison chart
344
- st.plotly_chart(create_metrics_comparison(), use_container_width=True)
 
 
 
 
 
 
345
 
346
- # Model recommendations
347
- st.markdown("### 🎯 Model Selection Guide")
 
 
 
 
 
348
 
349
- col1, col2 = st.columns(2)
 
 
 
 
 
 
 
 
 
 
350
 
351
- with col1:
352
- st.markdown("#### 🏆 Best for Accuracy")
353
- st.success("**EfficientNetB0** - 96.4% accuracy")
354
-
355
- st.markdown("#### ⚡ Best for Speed")
356
- st.info("**MobileNetV2** - 18ms inference time")
357
 
358
- with col2:
359
- st.markdown("#### 💾 Most Lightweight")
360
- st.info("**MobileNetV2** - 8.7MB model size")
361
-
362
- st.markdown("#### 🎯 Best Overall Balance")
363
- st.warning("**EfficientNetB0** - High accuracy + efficiency")
364
-
365
- elif page == "🔍 Image Classification":
366
- st.markdown("## 🔍 Image Classification")
367
 
368
- uploaded_file = st.file_uploader(
369
- "Upload a satellite image",
370
- type=['png', 'jpg', 'jpeg'],
371
- help="Upload an image of a satellite or space object for classification"
 
372
  )
373
 
374
- if uploaded_file is not None:
375
- # Display uploaded image
376
- image = Image.open(uploaded_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
 
378
- col1, col2 = st.columns([1, 2])
379
 
380
  with col1:
381
- st.image(image, caption="Uploaded Image", use_container_width=True)
 
 
 
 
 
 
 
 
 
 
382
 
383
  with col2:
384
- st.markdown("### Image Details")
385
- st.write(f"**Filename:** {uploaded_file.name}")
386
- st.write(f"**Size:** {image.size}")
387
- st.write(f"**Mode:** {image.mode}")
388
-
389
- # Model selection
390
- selected_models = st.multiselect(
391
- "Select models for prediction",
392
- list(MODEL_CONFIGS.keys()),
393
- default=["EfficientNetB0", "Custom CNN"]
394
- )
 
395
 
396
- if st.button("🚀 Classify Image", type="primary"):
397
- if not selected_models:
398
- st.warning("Please select at least one model.")
399
- return
400
 
401
- # Preprocess image
402
- processed_image = preprocess_image(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
 
404
- # Store predictions
405
- predictions = {}
 
 
 
 
406
 
407
- # Create progress bar
408
- progress_bar = st.progress(0)
409
- status_text = st.empty()
410
 
411
- # Make predictions with selected models
412
- for i, model_name in enumerate(selected_models):
413
- status_text.text(f'Loading {model_name}...')
414
- model = load_model(model_name)
415
-
416
- if model:
417
- status_text.text(f'Predicting with {model_name}...')
418
- pred = predict_with_model(model, processed_image, model_name)
419
- predictions[model_name] = pred
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
 
421
- progress_bar.progress((i + 1) / len(selected_models))
422
-
423
- status_text.empty()
424
- progress_bar.empty()
425
-
426
- # Display results
427
- if predictions:
428
- # Get recommendation
429
- recommended_model = recommend_best_model(predictions)
430
 
431
- st.markdown("### 🎯 Prediction Results")
 
432
 
433
- # Show recommendation
434
- st.markdown(f"""
435
- <div class="prediction-box">
436
- <h3>🏆 Recommended Model: {recommended_model}</h3>
437
- <p>Based on confidence and model performance</p>
438
- </div>
439
- """, unsafe_allow_html=True)
440
 
441
- # Results table
442
- results_data = []
443
- for model_name, pred in predictions.items():
444
- if pred:
445
- results_data.append({
446
- 'Model': model_name,
447
- 'Predicted Class': pred['class_name'],
448
- 'Confidence (%)': f"{pred['confidence']:.1f}%",
449
- 'Inference Time (ms)': f"{pred['inference_time']:.1f}",
450
- 'Recommended': '🏆' if model_name == recommended_model else ''
451
- })
452
 
453
- if results_data:
454
- df_results = pd.DataFrame(results_data)
455
- st.dataframe(df_results, use_container_width=True)
 
456
 
457
- # Confidence comparison
458
- st.markdown("### 📊 Confidence Comparison")
459
- confidences = [pred['confidence'] for pred in predictions.values() if pred]
460
- model_names = [name for name, pred in predictions.items() if pred]
 
461
 
462
- fig_conf = go.Figure()
463
- fig_conf.add_trace(go.Bar(
464
- x=model_names,
465
- y=confidences,
466
- marker_color=['gold' if name == recommended_model else 'lightblue'
467
- for name in model_names]
468
- ))
469
- fig_conf.update_layout(
470
- title="Prediction Confidence by Model",
471
- xaxis_title="Models",
472
- yaxis_title="Confidence (%)",
473
- height=400
474
- )
475
- st.plotly_chart(fig_conf, use_container_width=True)
476
 
477
- # Probability distribution for recommended model
478
- if recommended_model in predictions and predictions[recommended_model]:
479
- st.markdown(f"### 🔍 Detailed Probabilities - {recommended_model}")
480
- probs = predictions[recommended_model]['probabilities']
481
- prob_df = pd.DataFrame({
482
- 'Class': [CLASS_NAMES[i] for i in range(len(probs))],
483
- 'Probability': probs * 100
484
- }).sort_values('Probability', ascending=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
 
486
- fig_prob = px.bar(
487
- prob_df.head(5),
488
- x='Probability',
489
- y='Class',
490
- orientation='h',
491
- title=f"Top 5 Class Probabilities - {recommended_model}"
492
- )
493
- st.plotly_chart(fig_prob, use_container_width=True)
494
-
495
- elif page == "📈 Performance Analytics":
496
- st.markdown("## 📈 Performance Analytics")
497
-
498
- # Performance overview
499
- col1, col2, col3, col4 = st.columns(4)
500
-
501
- with col1:
502
- st.metric("Best Accuracy", "96.4%", "EfficientNetB0")
503
- with col2:
504
- st.metric("Fastest Inference", "18ms", "MobileNetV2")
505
- with col3:
506
- st.metric("Smallest Model", "8.7MB", "MobileNetV2")
507
- with col4:
508
- st.metric("Total Classes", "11", "Satellites + Debris")
509
-
510
- # Detailed analytics
511
- tab1, tab2, tab3 = st.tabs(["Accuracy Analysis", "Efficiency Metrics", "Model Comparison"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
512
 
513
- with tab1:
514
- # Accuracy breakdown
515
- models = list(MODEL_METRICS.keys())
516
- metrics_list = ['accuracy', 'precision', 'recall', 'f1_score']
517
 
518
- for metric in metrics_list:
519
- values = [MODEL_METRICS[model][metric] for model in models]
520
- fig = go.Figure()
521
- fig.add_trace(go.Bar(x=models, y=values, name=metric.title()))
522
- fig.update_layout(title=f"{metric.title()} Comparison", height=300)
523
- st.plotly_chart(fig, use_container_width=True)
524
-
525
- with tab2:
526
- # Efficiency metrics
527
- col1, col2 = st.columns(2)
528
 
529
  with col1:
530
- # Inference time
531
- times = [MODEL_METRICS[model]["inference_time"] for model in models]
532
- fig_time = go.Figure()
533
- fig_time.add_trace(go.Bar(x=models, y=times,
534
- marker_color=['red' if t > 40 else 'green' for t in times]))
535
- fig_time.update_layout(title="Inference Time (ms)", height=400)
536
- st.plotly_chart(fig_time, use_container_width=True)
537
-
538
  with col2:
539
- # Model size
540
- sizes = [MODEL_METRICS[model]["model_size"] for model in models]
541
- fig_size = go.Figure()
542
- fig_size.add_trace(go.Bar(x=models, y=sizes,
543
- marker_color=['red' if s > 25 else 'green' for s in sizes]))
544
- fig_size.update_layout(title="Model Size (MB)", height=400)
545
- st.plotly_chart(fig_size, use_container_width=True)
546
-
547
- with tab3:
548
- # Side-by-side comparison
549
- comparison_data = []
550
- for model in models:
551
- metrics = MODEL_METRICS[model]
552
- comparison_data.append({
553
- 'Model': model,
554
- 'Accuracy (%)': metrics['accuracy'],
555
- 'Inference Time (ms)': metrics['inference_time'],
556
- 'Model Size (MB)': metrics['model_size'],
557
- 'Training Time (min)': metrics['training_time'],
558
- 'Efficiency Score': round(metrics['accuracy'] / (metrics['inference_time'] * 0.1 + metrics['model_size'] * 0.1), 2)
559
- })
560
 
561
- df_comparison = pd.DataFrame(comparison_data)
562
- st.dataframe(df_comparison.style.highlight_max(axis=0, subset=['Accuracy (%)', 'Efficiency Score'])
563
- .highlight_min(axis=0, subset=['Inference Time (ms)', 'Model Size (MB)', 'Training Time (min)']),
564
- use_container_width=True)
565
-
566
- elif page == "ℹ️ About Models":
567
- st.markdown("## ℹ️ Model Information")
568
-
569
- for model_name, config in MODEL_CONFIGS.items():
570
- with st.expander(f"📋 {model_name}", expanded=False):
571
- col1, col2 = st.columns(2)
572
-
573
- with col1:
574
- st.markdown("### Description")
575
- st.write(config["description"])
576
-
577
- st.markdown("### Input Shape")
578
- st.code(f"{config['input_shape']}")
579
-
580
- st.markdown("### Model URL")
581
- st.code(config["url"])
582
-
583
- with col2:
584
- st.markdown("### Strengths")
585
- for strength in config["strengths"]:
586
- st.write(f"• {strength}")
587
 
588
- st.markdown("### Best Use Cases")
589
- for use_case in config["best_for"]:
590
- st.write(f"• {use_case}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
591
 
592
- # Performance summary
593
- metrics = MODEL_METRICS[model_name]
594
- st.markdown("### Key Metrics")
595
- st.write(f"**Accuracy:** {metrics['accuracy']}%")
596
- st.write(f"**Inference Time:** {metrics['inference_time']}ms")
597
- st.write(f"**Model Size:** {metrics['model_size']}MB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
598
 
 
599
  if __name__ == "__main__":
600
- main()
 
10
  import io
11
  from datetime import datetime
12
  import time
13
+ import logging
14
 
15
+ # Set up logging to help debug issues
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # Configure page - move this to the very top and add error handling
20
+ try:
21
+ st.set_page_config(
22
+ page_title="Satellite Classification Dashboard",
23
+ page_icon="🛰️",
24
+ layout="wide",
25
+ initial_sidebar_state="expanded"
26
+ )
27
+ except Exception as e:
28
+ logger.error(f"Error setting page config: {e}")
29
+ # Continue without custom config if it fails
30
 
31
  # Custom CSS for better styling
32
  st.markdown("""
 
60
  color: white;
61
  font-size: 1.2rem;
62
  }
63
+ .stAlert > div {
64
+ padding: 10px;
65
+ border-radius: 5px;
66
+ }
67
  </style>
68
  """, unsafe_allow_html=True)
69
 
 
147
 
148
  @st.cache_resource
149
  def load_model(model_name):
150
+ """Load model from HuggingFace with caching and better error handling"""
151
  try:
152
+ logger.info(f"Loading model: {model_name}")
153
+ url = MODEL_CONFIGS[model_name]["url"]
154
+
155
+ # Add timeout and better error handling
156
+ response = requests.get(url, timeout=60, stream=True)
157
+ response.raise_for_status()
158
+
159
+ # Check if response is actually a Keras model
160
+ if len(response.content) < 1000: # Too small to be a model
161
+ st.error(f"Model {model_name} download failed - file too small")
162
+ return None
163
 
164
+ model_bytes = io.BytesIO(response.content)
165
+
166
+ # Try to load the model with error handling
167
+ try:
168
  model = tf.keras.models.load_model(model_bytes)
169
+ logger.info(f"Successfully loaded model: {model_name}")
170
  return model
171
+ except Exception as load_error:
172
+ st.error(f"Error loading Keras model {model_name}: {str(load_error)}")
173
+ return None
174
+
175
+ except requests.exceptions.Timeout:
176
+ st.error(f"Timeout loading {model_name}. Please try again.")
177
+ return None
178
+ except requests.exceptions.RequestException as e:
179
+ st.error(f"Network error loading {model_name}: {str(e)}")
180
+ return None
181
  except Exception as e:
182
+ st.error(f"Unexpected error loading {model_name}: {str(e)}")
183
+ logger.error(f"Error loading {model_name}: {str(e)}")
184
  return None
185
 
186
  def preprocess_image(image, target_size=(224, 224)):
187
+ """Preprocess image for model prediction with error handling"""
188
+ try:
189
+ if image.mode != 'RGB':
190
+ image = image.convert('RGB')
191
+ image = image.resize(target_size)
192
+ image_array = np.array(image) / 255.0
193
+ return np.expand_dims(image_array, axis=0)
194
+ except Exception as e:
195
+ st.error(f"Error preprocessing image: {str(e)}")
196
+ return None
197
 
198
  def predict_with_model(model, image, model_name):
199
+ """Make prediction with a specific model with better error handling"""
200
+ if model is None:
201
+ return None
202
+
203
  try:
204
  start_time = time.time()
205
  predictions = model.predict(image, verbose=0)
206
  inference_time = (time.time() - start_time) * 1000 # Convert to ms
207
 
208
+ # Validate predictions
209
+ if predictions is None or len(predictions) == 0:
210
+ st.error(f"No predictions returned from {model_name}")
211
+ return None
212
+
213
  predicted_class = np.argmax(predictions[0])
214
  confidence = np.max(predictions[0]) * 100
215
 
216
+ # Validate class prediction
217
+ if predicted_class not in CLASS_NAMES:
218
+ st.error(f"Invalid class prediction from {model_name}: {predicted_class}")
219
+ return None
220
+
221
  return {
222
  'class': predicted_class,
223
  'class_name': CLASS_NAMES[predicted_class],
 
227
  }
228
  except Exception as e:
229
  st.error(f"Prediction error with {model_name}: {str(e)}")
230
+ logger.error(f"Prediction error with {model_name}: {str(e)}")
231
  return None
232
 
233
  def recommend_best_model(image_predictions):
 
252
  return "EfficientNetB0"
253
 
254
  def create_metrics_comparison():
255
+ """Create comprehensive metrics comparison dashboard with error handling"""
256
+ try:
257
+ # Create subplots
258
+ fig = make_subplots(
259
+ rows=2, cols=2,
260
+ subplot_titles=('Accuracy Comparison', 'Model Size vs Inference Time',
261
+ 'Performance Metrics Radar', 'Training Efficiency'),
262
+ specs=[[{"type": "bar"}, {"type": "scatter"}],
263
+ [{"type": "scatterpolar"}, {"type": "bar"}]]
264
+ )
265
+
266
+ models = list(MODEL_METRICS.keys())
267
+
268
+ # 1. Accuracy Comparison Bar Chart
269
+ accuracies = [MODEL_METRICS[model]["accuracy"] for model in models]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  fig.add_trace(
271
+ go.Bar(x=models, y=accuracies, name="Accuracy",
272
+ marker_color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']),
273
+ row=1, col=1
274
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
+ # 2. Model Size vs Inference Time Scatter
277
+ sizes = [MODEL_METRICS[model]["model_size"] for model in models]
278
+ times = [MODEL_METRICS[model]["inference_time"] for model in models]
279
+ fig.add_trace(
280
+ go.Scatter(x=sizes, y=times, mode='markers+text',
281
+ text=models, textposition="top center",
282
+ marker=dict(size=15, color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']),
283
+ name="Size vs Speed"),
284
+ row=1, col=2
285
+ )
286
 
287
+ # 3. Radar Chart for Performance Metrics
288
+ metrics = ['accuracy', 'precision', 'recall', 'f1_score']
289
+ colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
290
+ for i, model in enumerate(models):
291
+ values = [MODEL_METRICS[model][metric] for metric in metrics]
292
+ fig.add_trace(
293
+ go.Scatterpolar(r=values, theta=metrics, fill='toself',
294
+ name=model, opacity=0.7, line_color=colors[i]),
295
+ row=2, col=1
296
+ )
 
 
297
 
298
+ # 4. Training Time Comparison
299
+ training_times = [MODEL_METRICS[model]["training_time"] for model in models]
300
+ fig.add_trace(
301
+ go.Bar(x=models, y=training_times, name="Training Time",
302
+ marker_color=['#9467bd', '#8c564b', '#e377c2', '#7f7f7f']),
303
+ row=2, col=2
304
+ )
 
 
 
 
 
 
305
 
306
+ # Update layout
307
+ fig.update_layout(height=800, showlegend=True,
308
+ title_text="Comprehensive Model Comparison Dashboard")
309
+ fig.update_xaxes(title_text="Models", row=1, col=1)
310
+ fig.update_yaxes(title_text="Accuracy (%)", row=1, col=1)
311
+ fig.update_xaxes(title_text="Model Size (MB)", row=1, col=2)
312
+ fig.update_yaxes(title_text="Inference Time (ms)", row=1, col=2)
313
+ fig.update_xaxes(title_text="Models", row=2, col=2)
314
+ fig.update_yaxes(title_text="Training Time (minutes)", row=2, col=2)
315
 
316
+ return fig
317
+ except Exception as e:
318
+ st.error(f"Error creating metrics comparison chart: {str(e)}")
319
+ return None
320
+
321
+ def create_class_distribution_chart():
322
+ """Create class distribution visualization with error handling"""
323
+ try:
324
+ classes = list(CLASS_NAMES.values())
325
+ samples = [7500 if cls != 'Debris' else 15000 for cls in classes]
326
+ percentages = [8.33 if cls != 'Debris' else 16.67 for cls in classes]
327
 
328
+ fig = go.Figure()
329
+ fig.add_trace(go.Bar(
330
+ x=classes,
331
+ y=samples,
332
+ text=[f'{s} ({p:.1f}%)' for s, p in zip(samples, percentages)],
333
+ textposition='auto',
334
+ marker_color=['#ff6b6b' if cls == 'Debris' else '#4ecdc4' for cls in classes]
335
+ ))
336
 
337
+ fig.update_layout(
338
+ title="Class Distribution in Training Dataset",
339
+ xaxis_title="Satellite Classes",
340
+ yaxis_title="Number of Samples",
341
+ height=400,
342
+ xaxis_tickangle=-45
343
+ )
344
 
345
+ return fig
346
+ except Exception as e:
347
+ st.error(f"Error creating class distribution chart: {str(e)}")
348
+ return None
349
+
350
+ def create_confusion_matrix_heatmap():
351
+ """Create a sample confusion matrix heatmap for demonstration"""
352
+ try:
353
+ # Sample confusion matrix data (you would replace this with actual data)
354
+ classes = list(CLASS_NAMES.values())
355
+ np.random.seed(42) # For reproducible demo data
356
 
357
+ # Create a realistic-looking confusion matrix
358
+ confusion_matrix = np.random.randint(0, 100, size=(len(classes), len(classes)))
359
+ # Make diagonal elements higher (correct predictions)
360
+ np.fill_diagonal(confusion_matrix, np.random.randint(400, 500, size=len(classes)))
 
 
361
 
362
+ fig = go.Figure(data=go.Heatmap(
363
+ z=confusion_matrix,
364
+ x=classes,
365
+ y=classes,
366
+ colorscale='Blues',
367
+ showscale=True
368
+ ))
 
 
369
 
370
+ fig.update_layout(
371
+ title="Sample Confusion Matrix (Demo Data)",
372
+ xaxis_title="Predicted Class",
373
+ yaxis_title="True Class",
374
+ height=600
375
  )
376
 
377
+ return fig
378
+ except Exception as e:
379
+ st.error(f"Error creating confusion matrix: {str(e)}")
380
+ return None
381
+
382
+ # Main App
383
+ def main():
384
+ try:
385
+ # Header
386
+ st.markdown('<h1 class="main-header">🛰️ Satellite Classification Dashboard</h1>',
387
+ unsafe_allow_html=True)
388
+
389
+ # Sidebar
390
+ st.sidebar.title("Navigation")
391
+ page = st.sidebar.selectbox("Choose a page",
392
+ ["🏠 Home", "📊 Model Comparison", "🔍 Image Classification",
393
+ "📈 Performance Analytics", "ℹ️ About Models"])
394
+
395
+ # Add sidebar information
396
+ st.sidebar.markdown("---")
397
+ st.sidebar.markdown("### System Info")
398
+ st.sidebar.info(f"Total Classes: {len(CLASS_NAMES)}")
399
+ st.sidebar.info(f"Available Models: {len(MODEL_CONFIGS)}")
400
+ st.sidebar.info("Built with Streamlit & TensorFlow")
401
+
402
+ if page == "🏠 Home":
403
+ st.markdown("## Welcome to the Satellite Classification System")
404
 
405
+ col1, col2 = st.columns(2)
406
 
407
  with col1:
408
+ st.markdown("### 🎯 System Overview")
409
+ st.write("""
410
+ This dashboard provides comprehensive satellite classification using 4 different
411
+ deep learning models. Upload satellite images to classify them into 11 different
412
+ categories including various satellites and space debris.
413
+ """)
414
+
415
+ st.markdown("### 🛰️ Supported Classes")
416
+ for i, (class_id, class_name) in enumerate(CLASS_NAMES.items()):
417
+ if i < 6: # First column
418
+ st.write(f"• **{class_name}**")
419
 
420
  with col2:
421
+ st.markdown("### 🤖 Available Models")
422
+ st.write("""
423
+ - **Custom CNN**: Tailored architecture for satellite imagery
424
+ - **MobileNetV2**: Lightweight and fast inference
425
+ - **EfficientNetB0**: Best accuracy-efficiency balance
426
+ - **DenseNet121**: Complex pattern recognition
427
+ """)
428
+
429
+ st.markdown("### 📊 Remaining Classes")
430
+ for i, (class_id, class_name) in enumerate(CLASS_NAMES.items()):
431
+ if i >= 6: # Second column
432
+ st.write(f"• **{class_name}**")
433
 
434
+ # Class distribution chart
435
+ chart = create_class_distribution_chart()
436
+ if chart:
437
+ st.plotly_chart(chart, use_container_width=True)
438
 
439
+ # Quick start guide
440
+ st.markdown("### 🚀 Quick Start Guide")
441
+ st.markdown("""
442
+ 1. Navigate to **🔍 Image Classification** to upload and classify satellite images
443
+ 2. Check **📊 Model Comparison** to compare different model performances
444
+ 3. Explore **📈 Performance Analytics** for detailed metrics
445
+ 4. Read **ℹ️ About Models** to understand each model's capabilities
446
+ """)
447
+
448
+ elif page == "📊 Model Comparison":
449
+ st.markdown("## 📊 Model Performance Comparison")
450
+
451
+ # Metrics table
452
+ st.markdown("### Performance Metrics Summary")
453
+ df_metrics = pd.DataFrame(MODEL_METRICS).T
454
+ st.dataframe(df_metrics.style.highlight_max(axis=0), use_container_width=True)
455
+
456
+ # Comprehensive comparison chart
457
+ chart = create_metrics_comparison()
458
+ if chart:
459
+ st.plotly_chart(chart, use_container_width=True)
460
+
461
+ # Model recommendations
462
+ st.markdown("### 🎯 Model Selection Guide")
463
+
464
+ col1, col2 = st.columns(2)
465
+
466
+ with col1:
467
+ st.markdown("#### 🏆 Best for Accuracy")
468
+ st.success("**EfficientNetB0** - 96.4% accuracy")
469
 
470
+ st.markdown("#### Best for Speed")
471
+ st.info("**MobileNetV2** - 18ms inference time")
472
+
473
+ with col2:
474
+ st.markdown("#### 💾 Most Lightweight")
475
+ st.info("**MobileNetV2** - 8.7MB model size")
476
 
477
+ st.markdown("#### 🎯 Best Overall Balance")
478
+ st.warning("**EfficientNetB0** - High accuracy + efficiency")
 
479
 
480
+ # Model rankings
481
+ st.markdown("### 🏅 Model Rankings")
482
+
483
+ # Calculate overall scores
484
+ rankings = []
485
+ for model_name, metrics in MODEL_METRICS.items():
486
+ # Weighted score: accuracy (40%), speed (30%), size (30%)
487
+ score = (metrics['accuracy'] * 0.4 +
488
+ (100 - metrics['inference_time']) * 0.3 +
489
+ (50 - metrics['model_size']) * 0.3)
490
+ rankings.append({'Model': model_name, 'Overall Score': round(score, 1)})
491
+
492
+ rankings_df = pd.DataFrame(rankings).sort_values('Overall Score', ascending=False)
493
+ st.dataframe(rankings_df, use_container_width=True)
494
+
495
+ elif page == "🔍 Image Classification":
496
+ st.markdown("## 🔍 Image Classification")
497
+
498
+ # Instructions
499
+ st.info("""
500
+ 📋 **Instructions:**
501
+ 1. Upload a satellite or space object image (PNG, JPG, or JPEG)
502
+ 2. Select one or more models for classification
503
+ 3. Click 'Classify Image' to get predictions
504
+ 4. View results, confidence scores, and recommendations
505
+ """)
506
+
507
+ uploaded_file = st.file_uploader(
508
+ "Upload a satellite image",
509
+ type=['png', 'jpg', 'jpeg'],
510
+ help="Upload an image of a satellite or space object for classification"
511
+ )
512
+
513
+ if uploaded_file is not None:
514
+ try:
515
+ # Display uploaded image
516
+ image = Image.open(uploaded_file)
517
 
518
+ col1, col2 = st.columns([1, 2])
 
 
 
 
 
 
 
 
519
 
520
+ with col1:
521
+ st.image(image, caption="Uploaded Image", use_container_width=True)
522
 
523
+ with col2:
524
+ st.markdown("### Image Details")
525
+ st.write(f"**Filename:** {uploaded_file.name}")
526
+ st.write(f"**Size:** {image.size}")
527
+ st.write(f"**Mode:** {image.mode}")
528
+ st.write(f"**File Size:** {len(uploaded_file.getvalue())} bytes")
 
529
 
530
+ # Model selection
531
+ st.markdown("### Select Models for Classification")
532
+ selected_models = st.multiselect(
533
+ "Choose models to run predictions with:",
534
+ list(MODEL_CONFIGS.keys()),
535
+ default=["EfficientNetB0"], # Start with just one model to avoid timeouts
536
+ help="Select one or more models. More models = longer processing time."
537
+ )
 
 
 
538
 
539
+ if st.button("🚀 Classify Image", type="primary"):
540
+ if not selected_models:
541
+ st.warning("Please select at least one model.")
542
+ return
543
 
544
+ # Preprocess image
545
+ processed_image = preprocess_image(image)
546
+ if processed_image is None:
547
+ st.error("Failed to preprocess image")
548
+ return
549
 
550
+ # Store predictions
551
+ predictions = {}
 
 
 
 
 
 
 
 
 
 
 
 
552
 
553
+ # Create progress bar
554
+ progress_bar = st.progress(0)
555
+ status_text = st.empty()
556
+
557
+ # Make predictions with selected models
558
+ for i, model_name in enumerate(selected_models):
559
+ try:
560
+ status_text.text(f'Loading {model_name}... ({i+1}/{len(selected_models)})')
561
+ model = load_model(model_name)
562
+
563
+ if model:
564
+ status_text.text(f'Predicting with {model_name}... ({i+1}/{len(selected_models)})')
565
+ pred = predict_with_model(model, processed_image, model_name)
566
+ if pred:
567
+ predictions[model_name] = pred
568
+ else:
569
+ st.warning(f"Failed to get prediction from {model_name}")
570
+ else:
571
+ st.warning(f"Failed to load {model_name}")
572
+
573
+ except Exception as e:
574
+ st.error(f"Error processing {model_name}: {str(e)}")
575
+ logger.error(f"Error processing {model_name}: {str(e)}")
576
 
577
+ progress_bar.progress((i + 1) / len(selected_models))
578
+
579
+ status_text.empty()
580
+ progress_bar.empty()
581
+
582
+ # Display results
583
+ if predictions:
584
+ # Get recommendation
585
+ recommended_model = recommend_best_model(predictions)
586
+
587
+ st.markdown("### 🎯 Prediction Results")
588
+
589
+ # Show recommendation
590
+ st.markdown(f"""
591
+ <div class="prediction-box">
592
+ <h3>🏆 Recommended Model: {recommended_model}</h3>
593
+ <p>Based on confidence and model performance</p>
594
+ </div>
595
+ """, unsafe_allow_html=True)
596
+
597
+ # Results table
598
+ results_data = []
599
+ for model_name, pred in predictions.items():
600
+ if pred:
601
+ results_data.append({
602
+ 'Model': model_name,
603
+ 'Predicted Class': pred['class_name'],
604
+ 'Confidence (%)': f"{pred['confidence']:.1f}%",
605
+ 'Inference Time (ms)': f"{pred['inference_time']:.1f}",
606
+ 'Recommended': '🏆' if model_name == recommended_model else ''
607
+ })
608
+
609
+ if results_data:
610
+ df_results = pd.DataFrame(results_data)
611
+ st.dataframe(df_results, use_container_width=True)
612
+
613
+ # Confidence comparison
614
+ if len(predictions) > 1:
615
+ st.markdown("### 📊 Confidence Comparison")
616
+ confidences = [pred['confidence'] for pred in predictions.values() if pred]
617
+ model_names = [name for name, pred in predictions.items() if pred]
618
+
619
+ try:
620
+ fig_conf = go.Figure()
621
+ fig_conf.add_trace(go.Bar(
622
+ x=model_names,
623
+ y=confidences,
624
+ marker_color=['gold' if name == recommended_model else 'lightblue'
625
+ for name in model_names]
626
+ ))
627
+ fig_conf.update_layout(
628
+ title="Prediction Confidence by Model",
629
+ xaxis_title="Models",
630
+ yaxis_title="Confidence (%)",
631
+ height=400
632
+ )
633
+ st.plotly_chart(fig_conf, use_container_width=True)
634
+ except Exception as e:
635
+ st.warning(f"Could not create confidence chart: {str(e)}")
636
+
637
+ # Probability distribution for recommended model
638
+ if recommended_model in predictions and predictions[recommended_model]:
639
+ try:
640
+ st.markdown(f"### 🔍 Detailed Probabilities - {recommended_model}")
641
+ probs = predictions[recommended_model]['probabilities']
642
+ prob_df = pd.DataFrame({
643
+ 'Class': [CLASS_NAMES[i] for i in range(len(probs))],
644
+ 'Probability': probs * 100
645
+ }).sort_values('Probability', ascending=False)
646
+
647
+ fig_prob = px.bar(
648
+ prob_df.head(5),
649
+ x='Probability',
650
+ y='Class',
651
+ orientation='h',
652
+ title=f"Top 5 Class Probabilities - {recommended_model}",
653
+ color='Probability',
654
+ color_continuous_scale='viridis'
655
+ )
656
+ st.plotly_chart(fig_prob, use_container_width=True)
657
+ except Exception as e:
658
+ st.warning(f"Could not create probability chart: {str(e)}")
659
+ else:
660
+ st.error("No successful predictions were made. Please try again with different models.")
661
+
662
+ except Exception as e:
663
+ st.error(f"Error processing uploaded image: {str(e)}")
664
+ logger.error(f"Error processing uploaded image: {str(e)}")
665
 
666
+ elif page == "📈 Performance Analytics":
667
+ st.markdown("## 📈 Performance Analytics")
 
 
668
 
669
+ # Performance overview
670
+ col1, col2, col3, col4 = st.columns(4)
 
 
 
 
 
 
 
 
671
 
672
  with col1:
673
+ st.metric("Best Accuracy", "96.4%", "EfficientNetB0")
 
 
 
 
 
 
 
674
  with col2:
675
+ st.metric("Fastest Inference", "18ms", "MobileNetV2")
676
+ with col3:
677
+ st.metric("Smallest Model", "8.7MB", "MobileNetV2")
678
+ with col4:
679
+ st.metric("Total Classes", "11", "Satellites + Debris")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
680
 
681
+ # Detailed analytics
682
+ tab1, tab2, tab3, tab4 = st.tabs(["Accuracy Analysis", "Efficiency Metrics", "Model Comparison", "Confusion Matrix"])
683
+
684
+ with tab1:
685
+ try:
686
+ # Accuracy breakdown
687
+ models = list(MODEL_METRICS.keys())
688
+ metrics_list = ['accuracy', 'precision', 'recall', 'f1_score']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
689
 
690
+ for metric in metrics_list:
691
+ values = [MODEL_METRICS[model][metric] for model in models]
692
+ fig = go.Figure()
693
+ fig.add_trace(go.Bar(
694
+ x=models,
695
+ y=values,
696
+ name=metric.title(),
697
+ marker_color='lightblue',
698
+ text=[f'{v:.1f}%' for v in values],
699
+ textposition='auto'
700
+ ))
701
+ fig.update_layout(
702
+ title=f"{metric.title()} Comparison",
703
+ height=300,
704
+ yaxis_title=f"{metric.title()} (%)"
705
+ )
706
+ st.plotly_chart(fig, use_container_width=True)
707
+ except Exception as e:
708
+ st.error(f"Error creating accuracy charts: {str(e)}")
709
+
710
+ with tab2:
711
+ try:
712
+ # Efficiency metrics
713
+ col1, col2 = st.columns(2)
714
 
715
+ with col1:
716
+ # Inference time
717
+ times = [MODEL_METRICS[model]["inference_time"] for model in models]
718
+ fig_time = go.Figure()
719
+ fig_time.add_trace(go.Bar(
720
+ x=models,
721
+ y=times,
722
+ name="Inference Time",
723
+ marker_color='orange',
724
+ text=[f'{t:.1f} ms' for t in times],
725
+ textposition='auto'
726
+ ))
727
+ fig_time.update_layout(
728
+ title="Inference Time per Model",
729
+ yaxis_title="Time (ms)",
730
+ height=300
731
+ )
732
+ st.plotly_chart(fig_time, use_container_width=True)
733
+
734
+ with col2:
735
+ # Model sizes
736
+ sizes = [MODEL_METRICS[model]["model_size"] for model in models]
737
+ fig_size = go.Figure()
738
+ fig_size.add_trace(go.Bar(
739
+ x=models,
740
+ y=sizes,
741
+ name="Model Size",
742
+ marker_color='green',
743
+ text=[f'{s:.1f} MB' for s in sizes],
744
+ textposition='auto'
745
+ ))
746
+ fig_size.update_layout(
747
+ title="Model Size per Model",
748
+ yaxis_title="Size (MB)",
749
+ height=300
750
+ )
751
+ st.plotly_chart(fig_size, use_container_width=True)
752
+
753
+ except Exception as e:
754
+ st.error(f"Error displaying efficiency metrics: {str(e)}")
755
+
756
+ with tab3:
757
+ # Reuse full comparison dashboard
758
+ comp_fig = create_metrics_comparison()
759
+ if comp_fig:
760
+ st.plotly_chart(comp_fig, use_container_width=True)
761
+
762
+ with tab4:
763
+ # Display the confusion matrix
764
+ cm_fig = create_confusion_matrix_heatmap()
765
+ if cm_fig:
766
+ st.plotly_chart(cm_fig, use_container_width=True)
767
+
768
+ elif page == "ℹ️ About Models":
769
+ st.markdown("## ℹ️ Model Details and Use Cases")
770
+
771
+ for model_name, config in MODEL_CONFIGS.items():
772
+ with st.expander(f"🔍 {model_name}"):
773
+ st.markdown(f"<div class='model-card'><h4>{model_name}</h4>", unsafe_allow_html=True)
774
+ st.markdown(f"**Description:** {config['description']}")
775
+ st.markdown(f"**Input Shape:** {config['input_shape']}")
776
+ st.markdown("**Strengths:**")
777
+ for s in config['strengths']:
778
+ st.markdown(f"• {s}")
779
+ st.markdown("**Best For:**")
780
+ for use in config['best_for']:
781
+ st.markdown(f"• {use}")
782
+ st.markdown("</div>", unsafe_allow_html=True)
783
+
784
+ except Exception as e:
785
+ st.error(f"An unexpected error occurred: {str(e)}")
786
+ logger.error(f"Main app error: {str(e)}")
787
+
788
 
789
# Entry point: launch the Streamlit dashboard when executed as a script.
if __name__ == "__main__":
    main()