Bhavi23 commited on
Commit
59713eb
·
verified ·
1 Parent(s): 0940511

Rename app.py to Satellite Classification Streamlit app.py

Browse files
Files changed (2) hide show
  1. Satellite Classification Streamlit app.py +600 -0
  2. app.py +0 -241
Satellite Classification Streamlit app.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ import pandas as pd
5
+ import plotly.express as px
6
+ import plotly.graph_objects as go
7
+ from plotly.subplots import make_subplots
8
+ from PIL import Image
9
+ import requests
10
+ import io
11
+ from datetime import datetime
12
+ import time
13
+
14
+ # Configure page
15
+ st.set_page_config(
16
+ page_title="Satellite Classification Dashboard",
17
+ page_icon="🛰️",
18
+ layout="wide",
19
+ initial_sidebar_state="expanded"
20
+ )
21
+
22
+ # Custom CSS for better styling
23
+ st.markdown("""
24
+ <style>
25
+ .main-header {
26
+ font-size: 3rem;
27
+ font-weight: bold;
28
+ text-align: center;
29
+ color: #1f77b4;
30
+ margin-bottom: 2rem;
31
+ }
32
+ .model-card {
33
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
34
+ padding: 20px;
35
+ border-radius: 10px;
36
+ margin: 10px 0;
37
+ color: white;
38
+ }
39
+ .metric-card {
40
+ background: #f8f9fa;
41
+ padding: 15px;
42
+ border-radius: 8px;
43
+ border-left: 4px solid #1f77b4;
44
+ margin: 5px 0;
45
+ }
46
+ .prediction-box {
47
+ background: linear-gradient(135deg, #ff7e5f 0%, #feb47b 100%);
48
+ padding: 20px;
49
+ border-radius: 10px;
50
+ text-align: center;
51
+ color: white;
52
+ font-size: 1.2rem;
53
+ }
54
+ </style>
55
+ """, unsafe_allow_html=True)
56
+
57
+ # Class mappings
58
+ CLASS_NAMES = {
59
+ 0: 'AcrimSat', 1: 'Aquarius', 2: 'Aura', 3: 'Calipso', 4: 'Cloudsat',
60
+ 5: 'CubeSat', 6: 'Debris', 7: 'Jason', 8: 'Sentinel-6', 9: 'TRMM', 10: 'Terra'
61
+ }
62
+
63
+ # Model configurations
64
+ MODEL_CONFIGS = {
65
+ "Custom CNN": {
66
+ "url": "https://huggingface.co/Bhavi23/Custom_CNN/resolve/main/best_multimodal_model.keras",
67
+ "description": "Custom CNN architecture designed for satellite classification",
68
+ "input_shape": (224, 224, 3),
69
+ "strengths": ["Good generalization", "Balanced performance", "Stable training"],
70
+ "best_for": ["General purpose", "Balanced datasets", "When interpretability matters"]
71
+ },
72
+ "MobileNetV2": {
73
+ "url": "https://huggingface.co/Bhavi23/MobilenetV2/resolve/main/multi_input_model_v1.keras",
74
+ "description": "Lightweight model optimized for mobile deployment",
75
+ "input_shape": (224, 224, 3),
76
+ "strengths": ["Fast inference", "Small model size", "Energy efficient"],
77
+ "best_for": ["Real-time applications", "Mobile devices", "Resource constraints"]
78
+ },
79
+ "EfficientNetB0": {
80
+ "url": "https://huggingface.co/Bhavi23/EfficientNet_B0/resolve/main/efficientnet_model.keras",
81
+ "description": "Balanced efficiency and accuracy with compound scaling",
82
+ "input_shape": (224, 224, 3),
83
+ "strengths": ["High accuracy", "Parameter efficient", "Good transfer learning"],
84
+ "best_for": ["High accuracy needs", "Limited data", "Transfer learning scenarios"]
85
+ },
86
+ "DenseNet121": {
87
+ "url": "https://huggingface.co/Bhavi23/DenseNet/resolve/main/densenet_model.keras",
88
+ "description": "Dense connections for feature reuse and gradient flow",
89
+ "input_shape": (224, 224, 3),
90
+ "strengths": ["Feature reuse", "Good gradient flow", "Parameter efficiency"],
91
+ "best_for": ["Complex patterns", "Feature-rich images", "When accuracy is priority"]
92
+ }
93
+ }
94
+
95
+ # Performance metrics (based on the results shown in your document)
96
+ MODEL_METRICS = {
97
+ "Custom CNN": {
98
+ "accuracy": 95.2,
99
+ "precision": 94.8,
100
+ "recall": 95.1,
101
+ "f1_score": 94.9,
102
+ "inference_time": 45, # ms
103
+ "model_size": 25.3, # MB
104
+ "training_time": 120 # minutes
105
+ },
106
+ "MobileNetV2": {
107
+ "accuracy": 92.8,
108
+ "precision": 92.1,
109
+ "recall": 92.5,
110
+ "f1_score": 92.3,
111
+ "inference_time": 18, # ms
112
+ "model_size": 8.7, # MB
113
+ "training_time": 95 # minutes
114
+ },
115
+ "EfficientNetB0": {
116
+ "accuracy": 96.4,
117
+ "precision": 96.1,
118
+ "recall": 96.2,
119
+ "f1_score": 96.1,
120
+ "inference_time": 35, # ms
121
+ "model_size": 20.1, # MB
122
+ "training_time": 140 # minutes
123
+ },
124
+ "DenseNet121": {
125
+ "accuracy": 94.7,
126
+ "precision": 94.2,
127
+ "recall": 94.5,
128
+ "f1_score": 94.3,
129
+ "inference_time": 52, # ms
130
+ "model_size": 32.8, # MB
131
+ "training_time": 160 # minutes
132
+ }
133
+ }
134
+
135
+ @st.cache_resource
136
+ def load_model(model_name):
137
+ """Load model from HuggingFace with caching"""
138
+ try:
139
+ with st.spinner(f'Loading {model_name}...'):
140
+ url = MODEL_CONFIGS[model_name]["url"]
141
+ response = requests.get(url)
142
+ response.raise_for_status()
143
+
144
+ model_bytes = io.BytesIO(response.content)
145
+ model = tf.keras.models.load_model(model_bytes)
146
+ return model
147
+ except Exception as e:
148
+ st.error(f"Error loading {model_name}: {str(e)}")
149
+ return None
150
+
151
+ def preprocess_image(image, target_size=(224, 224)):
152
+ """Preprocess image for model prediction"""
153
+ if image.mode != 'RGB':
154
+ image = image.convert('RGB')
155
+ image = image.resize(target_size)
156
+ image_array = np.array(image) / 255.0
157
+ return np.expand_dims(image_array, axis=0)
158
+
159
+ def predict_with_model(model, image, model_name):
160
+ """Make prediction with a specific model"""
161
+ try:
162
+ start_time = time.time()
163
+ predictions = model.predict(image, verbose=0)
164
+ inference_time = (time.time() - start_time) * 1000 # Convert to ms
165
+
166
+ predicted_class = np.argmax(predictions[0])
167
+ confidence = np.max(predictions[0]) * 100
168
+
169
+ return {
170
+ 'class': predicted_class,
171
+ 'class_name': CLASS_NAMES[predicted_class],
172
+ 'confidence': confidence,
173
+ 'inference_time': inference_time,
174
+ 'probabilities': predictions[0]
175
+ }
176
+ except Exception as e:
177
+ st.error(f"Prediction error with {model_name}: {str(e)}")
178
+ return None
179
+
180
+ def recommend_best_model(image_predictions):
181
+ """Recommend the best model based on predictions and confidence"""
182
+ if not image_predictions:
183
+ return "EfficientNetB0" # Default recommendation
184
+
185
+ # Calculate recommendation score based on confidence and model performance
186
+ recommendations = {}
187
+ for model_name, pred in image_predictions.items():
188
+ if pred:
189
+ # Combine confidence with model's overall accuracy
190
+ base_score = MODEL_METRICS[model_name]["accuracy"]
191
+ confidence_bonus = pred['confidence'] * 0.1
192
+ speed_bonus = max(0, 100 - MODEL_METRICS[model_name]["inference_time"]) * 0.05
193
+
194
+ recommendations[model_name] = base_score + confidence_bonus + speed_bonus
195
+
196
+ if recommendations:
197
+ best_model = max(recommendations, key=recommendations.get)
198
+ return best_model
199
+ return "EfficientNetB0"
200
+
201
+ def create_metrics_comparison():
202
+ """Create comprehensive metrics comparison dashboard"""
203
+
204
+ # Create subplots
205
+ fig = make_subplots(
206
+ rows=2, cols=2,
207
+ subplot_titles=('Accuracy Comparison', 'Model Size vs Inference Time',
208
+ 'Performance Metrics Radar', 'Training Efficiency'),
209
+ specs=[[{"type": "bar"}, {"type": "scatter"}],
210
+ [{"type": "scatterpolar"}, {"type": "bar"}]]
211
+ )
212
+
213
+ models = list(MODEL_METRICS.keys())
214
+
215
+ # 1. Accuracy Comparison Bar Chart
216
+ accuracies = [MODEL_METRICS[model]["accuracy"] for model in models]
217
+ fig.add_trace(
218
+ go.Bar(x=models, y=accuracies, name="Accuracy",
219
+ marker_color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']),
220
+ row=1, col=1
221
+ )
222
+
223
+ # 2. Model Size vs Inference Time Scatter
224
+ sizes = [MODEL_METRICS[model]["model_size"] for model in models]
225
+ times = [MODEL_METRICS[model]["inference_time"] for model in models]
226
+ fig.add_trace(
227
+ go.Scatter(x=sizes, y=times, mode='markers+text',
228
+ text=models, textposition="top center",
229
+ marker=dict(size=15, color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']),
230
+ name="Size vs Speed"),
231
+ row=1, col=2
232
+ )
233
+
234
+ # 3. Radar Chart for Performance Metrics
235
+ metrics = ['accuracy', 'precision', 'recall', 'f1_score']
236
+ for i, model in enumerate(models):
237
+ values = [MODEL_METRICS[model][metric] for metric in metrics]
238
+ fig.add_trace(
239
+ go.Scatterpolar(r=values, theta=metrics, fill='toself',
240
+ name=model, opacity=0.7),
241
+ row=2, col=1
242
+ )
243
+
244
+ # 4. Training Time Comparison
245
+ training_times = [MODEL_METRICS[model]["training_time"] for model in models]
246
+ fig.add_trace(
247
+ go.Bar(x=models, y=training_times, name="Training Time",
248
+ marker_color=['#9467bd', '#8c564b', '#e377c2', '#7f7f7f']),
249
+ row=2, col=2
250
+ )
251
+
252
+ # Update layout
253
+ fig.update_layout(height=800, showlegend=True,
254
+ title_text="Comprehensive Model Comparison Dashboard")
255
+ fig.update_xaxes(title_text="Models", row=1, col=1)
256
+ fig.update_yaxes(title_text="Accuracy (%)", row=1, col=1)
257
+ fig.update_xaxes(title_text="Model Size (MB)", row=1, col=2)
258
+ fig.update_yaxes(title_text="Inference Time (ms)", row=1, col=2)
259
+ fig.update_xaxes(title_text="Models", row=2, col=2)
260
+ fig.update_yaxes(title_text="Training Time (minutes)", row=2, col=2)
261
+
262
+ return fig
263
+
264
+ def create_class_distribution_chart():
265
+ """Create class distribution visualization"""
266
+ classes = list(CLASS_NAMES.values())
267
+ samples = [7500 if cls != 'Debris' else 15000 for cls in classes]
268
+ percentages = [8.33 if cls != 'Debris' else 16.67 for cls in classes]
269
+
270
+ fig = go.Figure()
271
+ fig.add_trace(go.Bar(
272
+ x=classes,
273
+ y=samples,
274
+ text=[f'{s} ({p}%)' for s, p in zip(samples, percentages)],
275
+ textposition='auto',
276
+ marker_color=['#ff6b6b' if cls == 'Debris' else '#4ecdc4' for cls in classes]
277
+ ))
278
+
279
+ fig.update_layout(
280
+ title="Class Distribution in Training Dataset",
281
+ xaxis_title="Satellite Classes",
282
+ yaxis_title="Number of Samples",
283
+ height=400
284
+ )
285
+
286
+ return fig
287
+
288
+ # Main App
289
+ def main():
290
+ # Header
291
+ st.markdown('<h1 class="main-header">🛰️ Satellite Classification Dashboard</h1>',
292
+ unsafe_allow_html=True)
293
+
294
+ # Sidebar
295
+ st.sidebar.title("Navigation")
296
+ page = st.sidebar.selectbox("Choose a page",
297
+ ["🏠 Home", "📊 Model Comparison", "🔍 Image Classification",
298
+ "📈 Performance Analytics", "ℹ️ About Models"])
299
+
300
+ if page == "🏠 Home":
301
+ st.markdown("## Welcome to the Satellite Classification System")
302
+
303
+ col1, col2 = st.columns(2)
304
+
305
+ with col1:
306
+ st.markdown("### 🎯 System Overview")
307
+ st.write("""
308
+ This dashboard provides comprehensive satellite classification using 4 different
309
+ deep learning models. Upload satellite images to classify them into 11 different
310
+ categories including various satellites and space debris.
311
+ """)
312
+
313
+ st.markdown("### 🛰️ Supported Classes")
314
+ for i, (class_id, class_name) in enumerate(CLASS_NAMES.items()):
315
+ if i < 6: # First column
316
+ st.write(f"• **{class_name}**")
317
+
318
+ with col2:
319
+ st.markdown("### 🤖 Available Models")
320
+ st.write("""
321
+ - **Custom CNN**: Tailored architecture for satellite imagery
322
+ - **MobileNetV2**: Lightweight and fast inference
323
+ - **EfficientNetB0**: Best accuracy-efficiency balance
324
+ - **DenseNet121**: Complex pattern recognition
325
+ """)
326
+
327
+ st.markdown("### 📊 Class Distribution")
328
+ for i, (class_id, class_name) in enumerate(CLASS_NAMES.items()):
329
+ if i >= 6: # Second column
330
+ st.write(f"• **{class_name}**")
331
+
332
+ # Class distribution chart
333
+ st.plotly_chart(create_class_distribution_chart(), use_container_width=True)
334
+
335
+ elif page == "📊 Model Comparison":
336
+ st.markdown("## 📊 Model Performance Comparison")
337
+
338
+ # Metrics table
339
+ st.markdown("### Performance Metrics Summary")
340
+ df_metrics = pd.DataFrame(MODEL_METRICS).T
341
+ st.dataframe(df_metrics.style.highlight_max(axis=0), use_container_width=True)
342
+
343
+ # Comprehensive comparison chart
344
+ st.plotly_chart(create_metrics_comparison(), use_container_width=True)
345
+
346
+ # Model recommendations
347
+ st.markdown("### 🎯 Model Selection Guide")
348
+
349
+ col1, col2 = st.columns(2)
350
+
351
+ with col1:
352
+ st.markdown("#### 🏆 Best for Accuracy")
353
+ st.success("**EfficientNetB0** - 96.4% accuracy")
354
+
355
+ st.markdown("#### ⚡ Best for Speed")
356
+ st.info("**MobileNetV2** - 18ms inference time")
357
+
358
+ with col2:
359
+ st.markdown("#### 💾 Most Lightweight")
360
+ st.info("**MobileNetV2** - 8.7MB model size")
361
+
362
+ st.markdown("#### 🎯 Best Overall Balance")
363
+ st.warning("**EfficientNetB0** - High accuracy + efficiency")
364
+
365
+ elif page == "🔍 Image Classification":
366
+ st.markdown("## 🔍 Image Classification")
367
+
368
+ uploaded_file = st.file_uploader(
369
+ "Upload a satellite image",
370
+ type=['png', 'jpg', 'jpeg'],
371
+ help="Upload an image of a satellite or space object for classification"
372
+ )
373
+
374
+ if uploaded_file is not None:
375
+ # Display uploaded image
376
+ image = Image.open(uploaded_file)
377
+
378
+ col1, col2 = st.columns([1, 2])
379
+
380
+ with col1:
381
+ st.image(image, caption="Uploaded Image", use_container_width=True)
382
+
383
+ with col2:
384
+ st.markdown("### Image Details")
385
+ st.write(f"**Filename:** {uploaded_file.name}")
386
+ st.write(f"**Size:** {image.size}")
387
+ st.write(f"**Mode:** {image.mode}")
388
+
389
+ # Model selection
390
+ selected_models = st.multiselect(
391
+ "Select models for prediction",
392
+ list(MODEL_CONFIGS.keys()),
393
+ default=["EfficientNetB0", "Custom CNN"]
394
+ )
395
+
396
+ if st.button("🚀 Classify Image", type="primary"):
397
+ if not selected_models:
398
+ st.warning("Please select at least one model.")
399
+ return
400
+
401
+ # Preprocess image
402
+ processed_image = preprocess_image(image)
403
+
404
+ # Store predictions
405
+ predictions = {}
406
+
407
+ # Create progress bar
408
+ progress_bar = st.progress(0)
409
+ status_text = st.empty()
410
+
411
+ # Make predictions with selected models
412
+ for i, model_name in enumerate(selected_models):
413
+ status_text.text(f'Loading {model_name}...')
414
+ model = load_model(model_name)
415
+
416
+ if model:
417
+ status_text.text(f'Predicting with {model_name}...')
418
+ pred = predict_with_model(model, processed_image, model_name)
419
+ predictions[model_name] = pred
420
+
421
+ progress_bar.progress((i + 1) / len(selected_models))
422
+
423
+ status_text.empty()
424
+ progress_bar.empty()
425
+
426
+ # Display results
427
+ if predictions:
428
+ # Get recommendation
429
+ recommended_model = recommend_best_model(predictions)
430
+
431
+ st.markdown("### 🎯 Prediction Results")
432
+
433
+ # Show recommendation
434
+ st.markdown(f"""
435
+ <div class="prediction-box">
436
+ <h3>🏆 Recommended Model: {recommended_model}</h3>
437
+ <p>Based on confidence and model performance</p>
438
+ </div>
439
+ """, unsafe_allow_html=True)
440
+
441
+ # Results table
442
+ results_data = []
443
+ for model_name, pred in predictions.items():
444
+ if pred:
445
+ results_data.append({
446
+ 'Model': model_name,
447
+ 'Predicted Class': pred['class_name'],
448
+ 'Confidence (%)': f"{pred['confidence']:.1f}%",
449
+ 'Inference Time (ms)': f"{pred['inference_time']:.1f}",
450
+ 'Recommended': '🏆' if model_name == recommended_model else ''
451
+ })
452
+
453
+ if results_data:
454
+ df_results = pd.DataFrame(results_data)
455
+ st.dataframe(df_results, use_container_width=True)
456
+
457
+ # Confidence comparison
458
+ st.markdown("### 📊 Confidence Comparison")
459
+ confidences = [pred['confidence'] for pred in predictions.values() if pred]
460
+ model_names = [name for name, pred in predictions.items() if pred]
461
+
462
+ fig_conf = go.Figure()
463
+ fig_conf.add_trace(go.Bar(
464
+ x=model_names,
465
+ y=confidences,
466
+ marker_color=['gold' if name == recommended_model else 'lightblue'
467
+ for name in model_names]
468
+ ))
469
+ fig_conf.update_layout(
470
+ title="Prediction Confidence by Model",
471
+ xaxis_title="Models",
472
+ yaxis_title="Confidence (%)",
473
+ height=400
474
+ )
475
+ st.plotly_chart(fig_conf, use_container_width=True)
476
+
477
+ # Probability distribution for recommended model
478
+ if recommended_model in predictions and predictions[recommended_model]:
479
+ st.markdown(f"### 🔍 Detailed Probabilities - {recommended_model}")
480
+ probs = predictions[recommended_model]['probabilities']
481
+ prob_df = pd.DataFrame({
482
+ 'Class': [CLASS_NAMES[i] for i in range(len(probs))],
483
+ 'Probability': probs * 100
484
+ }).sort_values('Probability', ascending=False)
485
+
486
+ fig_prob = px.bar(
487
+ prob_df.head(5),
488
+ x='Probability',
489
+ y='Class',
490
+ orientation='h',
491
+ title=f"Top 5 Class Probabilities - {recommended_model}"
492
+ )
493
+ st.plotly_chart(fig_prob, use_container_width=True)
494
+
495
+ elif page == "📈 Performance Analytics":
496
+ st.markdown("## 📈 Performance Analytics")
497
+
498
+ # Performance overview
499
+ col1, col2, col3, col4 = st.columns(4)
500
+
501
+ with col1:
502
+ st.metric("Best Accuracy", "96.4%", "EfficientNetB0")
503
+ with col2:
504
+ st.metric("Fastest Inference", "18ms", "MobileNetV2")
505
+ with col3:
506
+ st.metric("Smallest Model", "8.7MB", "MobileNetV2")
507
+ with col4:
508
+ st.metric("Total Classes", "11", "Satellites + Debris")
509
+
510
+ # Detailed analytics
511
+ tab1, tab2, tab3 = st.tabs(["Accuracy Analysis", "Efficiency Metrics", "Model Comparison"])
512
+
513
+ with tab1:
514
+ # Accuracy breakdown
515
+ models = list(MODEL_METRICS.keys())
516
+ metrics_list = ['accuracy', 'precision', 'recall', 'f1_score']
517
+
518
+ for metric in metrics_list:
519
+ values = [MODEL_METRICS[model][metric] for model in models]
520
+ fig = go.Figure()
521
+ fig.add_trace(go.Bar(x=models, y=values, name=metric.title()))
522
+ fig.update_layout(title=f"{metric.title()} Comparison", height=300)
523
+ st.plotly_chart(fig, use_container_width=True)
524
+
525
+ with tab2:
526
+ # Efficiency metrics
527
+ col1, col2 = st.columns(2)
528
+
529
+ with col1:
530
+ # Inference time
531
+ times = [MODEL_METRICS[model]["inference_time"] for model in models]
532
+ fig_time = go.Figure()
533
+ fig_time.add_trace(go.Bar(x=models, y=times,
534
+ marker_color=['red' if t > 40 else 'green' for t in times]))
535
+ fig_time.update_layout(title="Inference Time (ms)", height=400)
536
+ st.plotly_chart(fig_time, use_container_width=True)
537
+
538
+ with col2:
539
+ # Model size
540
+ sizes = [MODEL_METRICS[model]["model_size"] for model in models]
541
+ fig_size = go.Figure()
542
+ fig_size.add_trace(go.Bar(x=models, y=sizes,
543
+ marker_color=['red' if s > 25 else 'green' for s in sizes]))
544
+ fig_size.update_layout(title="Model Size (MB)", height=400)
545
+ st.plotly_chart(fig_size, use_container_width=True)
546
+
547
+ with tab3:
548
+ # Side-by-side comparison
549
+ comparison_data = []
550
+ for model in models:
551
+ metrics = MODEL_METRICS[model]
552
+ comparison_data.append({
553
+ 'Model': model,
554
+ 'Accuracy (%)': metrics['accuracy'],
555
+ 'Inference Time (ms)': metrics['inference_time'],
556
+ 'Model Size (MB)': metrics['model_size'],
557
+ 'Training Time (min)': metrics['training_time'],
558
+ 'Efficiency Score': round(metrics['accuracy'] / (metrics['inference_time'] * 0.1 + metrics['model_size'] * 0.1), 2)
559
+ })
560
+
561
+ df_comparison = pd.DataFrame(comparison_data)
562
+ st.dataframe(df_comparison.style.highlight_max(axis=0, subset=['Accuracy (%)', 'Efficiency Score'])
563
+ .highlight_min(axis=0, subset=['Inference Time (ms)', 'Model Size (MB)', 'Training Time (min)']),
564
+ use_container_width=True)
565
+
566
+ elif page == "ℹ️ About Models":
567
+ st.markdown("## ℹ️ Model Information")
568
+
569
+ for model_name, config in MODEL_CONFIGS.items():
570
+ with st.expander(f"📋 {model_name}", expanded=False):
571
+ col1, col2 = st.columns(2)
572
+
573
+ with col1:
574
+ st.markdown("### Description")
575
+ st.write(config["description"])
576
+
577
+ st.markdown("### Input Shape")
578
+ st.code(f"{config['input_shape']}")
579
+
580
+ st.markdown("### Model URL")
581
+ st.code(config["url"])
582
+
583
+ with col2:
584
+ st.markdown("### Strengths")
585
+ for strength in config["strengths"]:
586
+ st.write(f"• {strength}")
587
+
588
+ st.markdown("### Best Use Cases")
589
+ for use_case in config["best_for"]:
590
+ st.write(f"• {use_case}")
591
+
592
+ # Performance summary
593
+ metrics = MODEL_METRICS[model_name]
594
+ st.markdown("### Key Metrics")
595
+ st.write(f"**Accuracy:** {metrics['accuracy']}%")
596
+ st.write(f"**Inference Time:** {metrics['inference_time']}ms")
597
+ st.write(f"**Model Size:** {metrics['model_size']}MB")
598
+
599
+ if __name__ == "__main__":
600
+ main()
app.py DELETED
@@ -1,241 +0,0 @@
1
- import streamlit as st
2
- import tensorflow as tf
3
- import numpy as np
4
- from PIL import Image
5
- import pandas as pd
6
- import matplotlib.pyplot as plt
7
- import plotly.express as px
8
-
9
- # Configure page
10
- st.set_page_config(
11
- page_title="ML Model Demo",
12
- page_icon="🤖",
13
- layout="wide"
14
- )
15
-
16
- @st.cache_resource
17
- def load_model():
18
- """Load your model (cached to avoid reloading)"""
19
- try:
20
- # Replace this with your actual model loading
21
- # Example: model = tf.keras.models.load_model('path/to/your/model.h5')
22
-
23
- # For demonstration, we'll create a simple model
24
- model = tf.keras.Sequential([
25
- tf.keras.layers.Dense(64, activation='relu', input_shape=(4,)),
26
- tf.keras.layers.Dense(32, activation='relu'),
27
- tf.keras.layers.Dense(3, activation='softmax')
28
- ])
29
-
30
- st.success("✅ Model loaded successfully!")
31
- return model
32
-
33
- except Exception as e:
34
- st.error(f"❌ Error loading model: {str(e)}")
35
- return None
36
-
37
- def preprocess_image(image):
38
- """Preprocess image for model input"""
39
- # Resize image to expected dimensions
40
- image = image.resize((224, 224))
41
- # Convert to array and normalize
42
- image_array = np.array(image) / 255.0
43
- # Add batch dimension
44
- return np.expand_dims(image_array, axis=0)
45
-
46
- def make_prediction(model, input_data, input_type):
47
- """Make prediction with the model"""
48
- if model is None:
49
- return "❌ Model not available"
50
-
51
- try:
52
- if input_type == "image":
53
- # Process image prediction
54
- processed_input = preprocess_image(input_data)
55
- # Mock prediction for demo
56
- prediction = np.random.rand(1, 3)
57
- classes = ['Class A', 'Class B', 'Class C']
58
- predicted_class = classes[np.argmax(prediction)]
59
- confidence = np.max(prediction) * 100
60
-
61
- return {
62
- 'predicted_class': predicted_class,
63
- 'confidence': confidence,
64
- 'all_predictions': dict(zip(classes, prediction[0]))
65
- }
66
-
67
- elif input_type == "numeric":
68
- # Process numeric prediction
69
- prediction = model.predict(input_data.reshape(1, -1))
70
- predicted_class = f"Class {np.argmax(prediction[0])}"
71
- confidence = np.max(prediction[0]) * 100
72
-
73
- return {
74
- 'predicted_class': predicted_class,
75
- 'confidence': confidence,
76
- 'raw_output': prediction[0].tolist()
77
- }
78
-
79
- elif input_type == "text":
80
- # Mock text processing
81
- return {
82
- 'sentiment': 'Positive',
83
- 'confidence': 85.6,
84
- 'keywords': ['example', 'text', 'analysis']
85
- }
86
-
87
- except Exception as e:
88
- return f"❌ Prediction error: {str(e)}"
89
-
90
- def main():
91
- # Header
92
- st.title("🤖 Machine Learning Model Demo")
93
- st.markdown("---")
94
-
95
- # Sidebar
96
- st.sidebar.header("🎛️ Model Controls")
97
-
98
- # Load model
99
- with st.spinner("Loading model..."):
100
- model = load_model()
101
-
102
- # Model selection
103
- model_type = st.sidebar.selectbox(
104
- "Select Model Type:",
105
- ["Image Classification", "Numeric Prediction", "Text Analysis"]
106
- )
107
-
108
- # Main content area
109
- col1, col2 = st.columns([2, 1])
110
-
111
- with col1:
112
- if model_type == "Image Classification":
113
- st.subheader("📸 Image Classification")
114
-
115
- uploaded_file = st.file_uploader(
116
- "Upload an image:",
117
- type=['jpg', 'jpeg', 'png', 'bmp'],
118
- help="Supported formats: JPG, JPEG, PNG, BMP"
119
- )
120
-
121
- if uploaded_file is not None:
122
- # Display uploaded image
123
- image = Image.open(uploaded_file)
124
- st.image(image, caption="Uploaded Image", use_column_width=True)
125
-
126
- # Prediction button
127
- if st.button("🔍 Classify Image", type="primary"):
128
- with st.spinner("Analyzing image..."):
129
- result = make_prediction(model, image, "image")
130
-
131
- if isinstance(result, dict):
132
- st.success(f"**Prediction:** {result['predicted_class']}")
133
- st.info(f"**Confidence:** {result['confidence']:.1f}%")
134
-
135
- # Show all predictions
136
- st.subheader("All Predictions:")
137
- for class_name, prob in result['all_predictions'].items():
138
- st.write(f"• {class_name}: {prob*100:.1f}%")
139
- else:
140
- st.error(result)
141
-
142
- elif model_type == "Numeric Prediction":
143
- st.subheader("🔢 Numeric Prediction")
144
-
145
- # Input parameters
146
- col_a, col_b = st.columns(2)
147
-
148
- with col_a:
149
- param1 = st.number_input("Parameter 1:", value=5.0, step=0.1)
150
- param2 = st.number_input("Parameter 2:", value=3.2, step=0.1)
151
-
152
- with col_b:
153
- param3 = st.number_input("Parameter 3:", value=1.4, step=0.1)
154
- param4 = st.number_input("Parameter 4:", value=0.2, step=0.1)
155
-
156
- # Create input array
157
- input_array = np.array([param1, param2, param3, param4])
158
-
159
- if st.button("🚀 Make Prediction", type="primary"):
160
- with st.spinner("Computing prediction..."):
161
- result = make_prediction(model, input_array, "numeric")
162
-
163
- if isinstance(result, dict):
164
- st.success(f"**Prediction:** {result['predicted_class']}")
165
- st.info(f"**Confidence:** {result['confidence']:.1f}%")
166
-
167
- # Visualization
168
- fig, ax = plt.subplots()
169
- ax.bar(range(len(result['raw_output'])), result['raw_output'])
170
- ax.set_xlabel('Class')
171
- ax.set_ylabel('Probability')
172
- ax.set_title('Prediction Probabilities')
173
- st.pyplot(fig)
174
- else:
175
- st.error(result)
176
-
177
- elif model_type == "Text Analysis":
178
- st.subheader("📝 Text Analysis")
179
-
180
- text_input = st.text_area(
181
- "Enter your text:",
182
- placeholder="Type your text here for analysis...",
183
- height=150
184
- )
185
-
186
- if st.button("📊 Analyze Text", type="primary") and text_input.strip():
187
- with st.spinner("Analyzing text..."):
188
- result = make_prediction(model, text_input, "text")
189
-
190
- if isinstance(result, dict):
191
- st.success(f"**Sentiment:** {result['sentiment']}")
192
- st.info(f"**Confidence:** {result['confidence']:.1f}%")
193
-
194
- st.subheader("Keywords:")
195
- for keyword in result['keywords']:
196
- st.write(f"• {keyword}")
197
- else:
198
- st.error(result)
199
-
200
- with col2:
201
- st.subheader("📊 Model Info")
202
-
203
- # Model statistics (mock data)
204
- metrics = {
205
- 'Accuracy': 94.2,
206
- 'Precision': 91.8,
207
- 'Recall': 93.5,
208
- 'F1-Score': 92.6
209
- }
210
-
211
- for metric, value in metrics.items():
212
- st.metric(metric, f"{value}%")
213
-
214
- # Additional info
215
- st.markdown("---")
216
- st.subheader("ℹ️ About")
217
- st.info("""
218
- **Model Details:**
219
- - Framework: TensorFlow 2.13
220
- - Architecture: Deep Neural Network
221
- - Training Data: Custom dataset
222
- - Last Updated: July 2025
223
- """)
224
-
225
- # Usage stats (mock)
226
- st.markdown("---")
227
- st.subheader("📈 Usage Stats")
228
- usage_data = pd.DataFrame({
229
- 'Day': ['Mon', 'Tue', 'Wed', 'Thu', 'Fri'],
230
- 'Predictions': [45, 52, 38, 61, 49]
231
- })
232
-
233
- fig = px.bar(usage_data, x='Day', y='Predictions', title='Daily Predictions')
234
- st.plotly_chart(fig, use_container_width=True)
235
-
236
- # Footer
237
- st.markdown("---")
238
- st.markdown("Built with ❤️ using Streamlit and TensorFlow")
239
-
240
- if __name__ == "__main__":
241
- main()