import io
import logging
import os
import tempfile
import time
from datetime import datetime

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import requests
import streamlit as st
import tensorflow as tf
from PIL import Image
from plotly.subplots import make_subplots
| |
|
| | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Page-level Streamlit settings. Applied inside a guard so that a failure is
# logged rather than aborting the script.
_PAGE_SETTINGS = {
    "page_title": "Satellite Classification Dashboard",
    "page_icon": "🛰️",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}

try:
    st.set_page_config(**_PAGE_SETTINGS)
except Exception as exc:
    logger.error(f"Error setting page config: {exc}")
| | |
| |
|
| | |
# Custom CSS for the dashboard: page header, model/metric cards, the
# prediction banner and alert padding. Injected as raw HTML on every run.
_DASHBOARD_CSS = """
<style>
.main-header {
font-size: 3rem;
font-weight: bold;
text-align: center;
color: #1f77b4;
margin-bottom: 2rem;
}
.model-card {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 20px;
border-radius: 10px;
margin: 10px 0;
color: white;
}
.metric-card {
background: #f8f9fa;
padding: 15px;
border-radius: 8px;
border-left: 4px solid #1f77b4;
margin: 5px 0;
}
.prediction-box {
background: linear-gradient(135deg, #ff7e5f 0%, #feb47b 100%);
padding: 20px;
border-radius: 10px;
text-align: center;
color: white;
font-size: 1.2rem;
}
.stAlert > div {
padding: 10px;
border-radius: 5px;
}
</style>
"""

st.markdown(_DASHBOARD_CSS, unsafe_allow_html=True)
| |
|
| | |
# Index -> label mapping for the classifier outputs. A label's position in the
# tuple is the class index that an argmax over the model scores selects.
CLASS_NAMES = dict(enumerate((
    'AcrimSat', 'Aquarius', 'Aura', 'Calipso', 'Cloudsat',
    'CubeSat', 'Debris', 'Jason', 'Sentinel-6', 'TRMM', 'Terra',
)))
| |
|
| | |
# Every bundled model takes 224x224 RGB input.
_DEFAULT_INPUT_SHAPE = (224, 224, 3)

# Per-model metadata: HuggingFace download URL, a short description, the
# expected input shape, and the strengths / use cases listed on the
# "About Models" page.
MODEL_CONFIGS = {
    "Custom CNN": dict(
        url="https://huggingface.co/Bhavi23/Custom_CNN/resolve/main/best_multimodal_model.keras",
        description="Custom CNN architecture designed for satellite classification",
        input_shape=_DEFAULT_INPUT_SHAPE,
        strengths=["Good generalization", "Balanced performance", "Stable training"],
        best_for=["General purpose", "Balanced datasets", "When interpretability matters"],
    ),
    "MobileNetV2": dict(
        url="https://huggingface.co/Bhavi23/MobilenetV2/resolve/main/multi_input_model_v1.keras",
        description="Lightweight model optimized for mobile deployment",
        input_shape=_DEFAULT_INPUT_SHAPE,
        strengths=["Fast inference", "Small model size", "Energy efficient"],
        best_for=["Real-time applications", "Mobile devices", "Resource constraints"],
    ),
    "EfficientNetB0": dict(
        url="https://huggingface.co/Bhavi23/EfficientNet_B0/resolve/main/efficientnet_model.keras",
        description="Balanced efficiency and accuracy with compound scaling",
        input_shape=_DEFAULT_INPUT_SHAPE,
        strengths=["High accuracy", "Parameter efficient", "Good transfer learning"],
        best_for=["High accuracy needs", "Limited data", "Transfer learning scenarios"],
    ),
    "DenseNet121": dict(
        url="https://huggingface.co/Bhavi23/DenseNet/resolve/main/densenet_model.keras",
        description="Dense connections for feature reuse and gradient flow",
        input_shape=_DEFAULT_INPUT_SHAPE,
        strengths=["Feature reuse", "Good gradient flow", "Parameter efficiency"],
        best_for=["Complex patterns", "Feature-rich images", "When accuracy is priority"],
    ),
}
| |
|
| | |
# Column order of the per-model metric rows below.
_METRIC_FIELDS = (
    "accuracy", "precision", "recall", "f1_score",
    "inference_time", "model_size", "training_time",
)

# Static benchmark figures shown by the dashboard. The charts in this file
# label them as percentages (quality metrics), milliseconds (inference_time),
# MB (model_size) and minutes (training_time).
MODEL_METRICS = {
    model: dict(zip(_METRIC_FIELDS, row))
    for model, row in (
        ("Custom CNN",     (95.2, 94.8, 95.1, 94.9, 45, 25.3, 120)),
        ("MobileNetV2",    (92.8, 92.1, 92.5, 92.3, 18, 8.7, 95)),
        ("EfficientNetB0", (96.4, 96.1, 96.2, 96.1, 35, 20.1, 140)),
        ("DenseNet121",    (94.7, 94.2, 94.5, 94.3, 52, 32.8, 160)),
    )
}
| |
|
class _ModelTooSmallError(Exception):
    """Raised when a downloaded model payload is below ``_MIN_MODEL_BYTES``."""


class _ModelDeserializationError(Exception):
    """Raised when Keras cannot deserialize a downloaded model file."""


# Payloads smaller than this many bytes are treated as failed downloads.
_MIN_MODEL_BYTES = 1000


@st.cache_resource
def _fetch_and_load_model(model_name):
    """Download ``model_name`` from its configured URL and deserialize it.

    Raises on every failure instead of returning None. Only successful results
    are returned, so only successful results end up in Streamlit's resource
    cache; after a failure the next call retries the download.

    Raises:
        requests.exceptions.RequestException: download failed or timed out.
        _ModelTooSmallError: payload below ``_MIN_MODEL_BYTES``.
        _ModelDeserializationError: Keras could not load the file.
    """
    logger.info(f"Loading model: {model_name}")
    url = MODEL_CONFIGS[model_name]["url"]

    # The full payload is needed before Keras can read it, so it is not
    # streamed. The context manager releases the connection afterwards.
    with requests.get(url, timeout=60) as response:
        response.raise_for_status()
        payload = response.content

    if len(payload) < _MIN_MODEL_BYTES:
        raise _ModelTooSmallError(model_name)

    # tf.keras.models.load_model expects a file path (the ".keras" suffix
    # selects the zip-based format), not an in-memory buffer, so the payload
    # is spilled to a temporary file that is always removed afterwards.
    fd, tmp_path = tempfile.mkstemp(suffix=".keras")
    try:
        with os.fdopen(fd, "wb") as tmp_file:
            tmp_file.write(payload)
        try:
            model = tf.keras.models.load_model(tmp_path)
        except Exception as load_error:
            raise _ModelDeserializationError(str(load_error)) from load_error
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            logger.warning(f"Could not remove temporary model file {tmp_path}")

    logger.info(f"Successfully loaded model: {model_name}")
    return model


def load_model(model_name):
    """Load ``model_name`` from HuggingFace, reporting failures in the UI.

    Successful loads are cached across reruns. Failures are not cached, so a
    transient network error can be retried.

    Args:
        model_name: A key of ``MODEL_CONFIGS``.

    Returns:
        The loaded Keras model, or None if downloading or deserializing failed.
        In that case an error message is shown with ``st.error``.
    """
    try:
        return _fetch_and_load_model(model_name)
    except _ModelTooSmallError:
        st.error(f"Model {model_name} download failed - file too small")
    except _ModelDeserializationError as load_error:
        st.error(f"Error loading Keras model {model_name}: {str(load_error)}")
    except requests.exceptions.Timeout:
        # Must precede RequestException, of which Timeout is a subclass.
        st.error(f"Timeout loading {model_name}. Please try again.")
    except requests.exceptions.RequestException as e:
        st.error(f"Network error loading {model_name}: {str(e)}")
    except Exception as e:
        st.error(f"Unexpected error loading {model_name}: {str(e)}")
        logger.error(f"Error loading {model_name}: {str(e)}")
    return None


# Keep the cache-management hook that the previously decorated function exposed.
load_model.clear = _fetch_and_load_model.clear
| |
|
def preprocess_image(image, target_size=(224, 224)):
    """Convert a PIL image into a single-item model input batch scaled to [0, 1].

    Args:
        image: PIL image. Anything not already in RGB mode is converted first.
        target_size: Size passed to ``image.resize`` (PIL order: width, height).

    Returns:
        A float array with a leading batch axis of length 1, or None if any step
        fails. In that case the error is shown with ``st.error``.
    """
    try:
        rgb = image if image.mode == 'RGB' else image.convert('RGB')
        scaled = np.array(rgb.resize(target_size)) / 255.0
        return scaled[np.newaxis, ...]
    except Exception as e:
        st.error(f"Error preprocessing image: {str(e)}")
        return None
| |
|
def predict_with_model(model, image, model_name):
    """Run ``model`` on a preprocessed batch and summarise its top prediction.

    Args:
        model: Loaded Keras model, or None (e.g. when loading failed).
        image: Preprocessed input batch, as returned by ``preprocess_image``.
        model_name: Display name used in error messages.

    Returns:
        None if ``model`` is None or anything fails (errors are shown with
        ``st.error`` and logged). Otherwise a dict with:
            class: argmax index of the first batch item's scores (int).
            class_name: the matching ``CLASS_NAMES`` label.
            confidence: max score * 100 (float). This assumes the model's
                final layer emits probabilities (e.g. softmax); that cannot be
                confirmed from here.
            inference_time: elapsed time of ``model.predict`` in milliseconds.
            probabilities: the raw score vector for the first batch item.
    """
    if model is None:
        return None

    try:
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # so the measured duration is not skewed by system clock adjustments.
        started = time.perf_counter()
        predictions = model.predict(image, verbose=0)
        inference_time = (time.perf_counter() - started) * 1000

        if predictions is None or len(predictions) == 0:
            st.error(f"No predictions returned from {model_name}")
            return None

        scores = predictions[0]
        # Native Python scalars (not NumPy ones) keep the result easy to
        # format, compare and serialise downstream.
        predicted_class = int(np.argmax(scores))
        confidence = float(np.max(scores)) * 100

        if predicted_class not in CLASS_NAMES:
            st.error(f"Invalid class prediction from {model_name}: {predicted_class}")
            return None

        return {
            'class': predicted_class,
            'class_name': CLASS_NAMES[predicted_class],
            'confidence': confidence,
            'inference_time': inference_time,
            'probabilities': scores,
        }
    except Exception as e:
        st.error(f"Prediction error with {model_name}: {str(e)}")
        logger.error(f"Prediction error with {model_name}: {str(e)}")
        return None
| |
|
def recommend_best_model(image_predictions):
    """Choose the model whose prediction should be trusted most.

    Each model with a truthy prediction is scored as:
        benchmark accuracy
        + 0.1 * prediction confidence
        + 0.05 * max(0, 100 - benchmark inference time in ms)
    The highest score wins. On a tie, the model that appears first in
    ``image_predictions`` wins. Returns "EfficientNetB0" when the mapping is
    empty/None or no entry has a usable prediction.
    """
    fallback = "EfficientNetB0"
    if not image_predictions:
        return fallback

    best_name = None
    best_score = None
    for name, prediction in image_predictions.items():
        if not prediction:
            continue
        stats = MODEL_METRICS[name]
        score = (stats["accuracy"]
                 + prediction['confidence'] * 0.1
                 + max(0, 100 - stats["inference_time"]) * 0.05)
        # Strict ">" keeps the earliest model on ties.
        if best_score is None or score > best_score:
            best_name, best_score = name, score

    return fallback if best_name is None else best_name
| |
|
def create_metrics_comparison():
    """Build a 2x2 plotly dashboard comparing every model in ``MODEL_METRICS``.

    Panels: accuracy bars, a model-size vs inference-time scatter, a radar of
    the four quality metrics, and training-time bars. Colours are assigned by
    cycling a fixed palette, so the figure works for any number of models.
    Only the radar traces (one per model) appear in the legend.

    Returns:
        A ``plotly.graph_objects.Figure``, or None if building it failed. In
        that case the error is shown with ``st.error``.
    """
    try:
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=('Accuracy Comparison', 'Model Size vs Inference Time',
                            'Performance Metrics Radar', 'Training Efficiency'),
            specs=[[{"type": "bar"}, {"type": "scatter"}],
                   [{"type": "scatterpolar"}, {"type": "bar"}]]
        )

        models = list(MODEL_METRICS)
        palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
                   '#9467bd', '#8c564b', '#e377c2', '#7f7f7f']
        # One colour per model, cycling the palette so extra models do not
        # raise IndexError. Training-time bars start at the palette's second
        # half to stay visually distinct from the accuracy bars.
        model_colors = [palette[i % len(palette)] for i in range(len(models))]
        training_colors = [palette[(i + 4) % len(palette)] for i in range(len(models))]

        def column(metric):
            """Values of ``metric`` for every model, in ``models`` order."""
            return [MODEL_METRICS[model][metric] for model in models]

        fig.add_trace(
            go.Bar(x=models, y=column("accuracy"), name="Accuracy",
                   marker_color=model_colors, showlegend=False),
            row=1, col=1
        )

        fig.add_trace(
            go.Scatter(x=column("model_size"), y=column("inference_time"),
                       mode='markers+text', text=models, textposition="top center",
                       marker=dict(size=15, color=model_colors),
                       name="Size vs Speed", showlegend=False),
            row=1, col=2
        )

        metrics = ['accuracy', 'precision', 'recall', 'f1_score']
        metric_labels = [m.replace('_', ' ').title() for m in metrics]
        for model, color in zip(models, model_colors):
            fig.add_trace(
                go.Scatterpolar(r=[MODEL_METRICS[model][m] for m in metrics],
                                theta=metric_labels, fill='toself',
                                name=model, opacity=0.7, line_color=color),
                row=2, col=1
            )

        fig.add_trace(
            go.Bar(x=models, y=column("training_time"), name="Training Time",
                   marker_color=training_colors, showlegend=False),
            row=2, col=2
        )

        fig.update_layout(height=800, showlegend=True,
                          title_text="Comprehensive Model Comparison Dashboard")
        fig.update_xaxes(title_text="Models", row=1, col=1)
        fig.update_yaxes(title_text="Accuracy (%)", row=1, col=1)
        fig.update_xaxes(title_text="Model Size (MB)", row=1, col=2)
        fig.update_yaxes(title_text="Inference Time (ms)", row=1, col=2)
        fig.update_xaxes(title_text="Models", row=2, col=2)
        fig.update_yaxes(title_text="Training Time (minutes)", row=2, col=2)

        return fig
    except Exception as e:
        st.error(f"Error creating metrics comparison chart: {str(e)}")
        return None
| |
|
def create_class_distribution_chart():
    """Build a bar chart of the training-set class distribution.

    NOTE(review): the sample counts are hard-coded here (15000 for 'Debris',
    7500 for every other class), not read from any dataset. Confirm they match
    the real training data. The percentage labels are derived from those counts
    so that bars and labels always agree.

    Returns:
        A ``plotly.graph_objects.Figure``, or None if building it failed. In
        that case the error is shown with ``st.error``.
    """
    try:
        classes = list(CLASS_NAMES.values())
        samples = [15000 if cls == 'Debris' else 7500 for cls in classes]
        total = sum(samples)
        # Only evaluated when there is at least one sample, so total > 0 here.
        percentages = [100 * s / total for s in samples]

        fig = go.Figure()
        fig.add_trace(go.Bar(
            x=classes,
            y=samples,
            text=[f'{s} ({p:.1f}%)' for s, p in zip(samples, percentages)],
            textposition='auto',
            marker_color=['#ff6b6b' if cls == 'Debris' else '#4ecdc4' for cls in classes]
        ))

        fig.update_layout(
            title="Class Distribution in Training Dataset",
            xaxis_title="Satellite Classes",
            yaxis_title="Number of Samples",
            height=400,
            xaxis_tickangle=-45
        )

        return fig
    except Exception as e:
        st.error(f"Error creating class distribution chart: {str(e)}")
        return None
| |
|
def create_confusion_matrix_heatmap():
    """Build a heatmap of a synthetic (demo) confusion matrix.

    The matrix is random data, not model output. Off-diagonal cells are drawn
    from [0, 100) and diagonal cells from [400, 500), using a fixed seed so
    the figure is identical on every run.

    Returns:
        A ``plotly.graph_objects.Figure``, or None if building it failed. In
        that case the error is shown with ``st.error``.
    """
    try:
        classes = list(CLASS_NAMES.values())
        n_classes = len(classes)

        # A private RandomState seeded with 42 yields exactly the values the
        # former global np.random.seed(42) produced, without resetting the
        # process-wide NumPy RNG as a side effect.
        rng = np.random.RandomState(42)
        confusion_matrix = rng.randint(0, 100, size=(n_classes, n_classes))
        np.fill_diagonal(confusion_matrix, rng.randint(400, 500, size=n_classes))

        fig = go.Figure(data=go.Heatmap(
            z=confusion_matrix,
            x=classes,
            y=classes,
            colorscale='Blues',
            showscale=True
        ))

        fig.update_layout(
            title="Sample Confusion Matrix (Demo Data)",
            xaxis_title="Predicted Class",
            yaxis_title="True Class",
            height=600
        )

        return fig
    except Exception as e:
        st.error(f"Error creating confusion matrix: {str(e)}")
        return None
| |
|
| | |
def _best_model_by(metric, lowest=False):
    """Return ``(model_name, value)`` for the best model on ``metric``.

    "Best" means the maximum, or the minimum when ``lowest`` is True (use this
    for cost metrics such as inference time or model size). Ties resolve to the
    model listed first in ``MODEL_METRICS``.
    """
    pick = min if lowest else max
    name = pick(MODEL_METRICS, key=lambda model: MODEL_METRICS[model][metric])
    return name, MODEL_METRICS[name][metric]


def _pretty_metric(metric):
    """Turn a metric key such as ``"f1_score"`` into a label like ``"F1 Score"``."""
    return metric.replace('_', ' ').title()


def _render_home():
    """Render the landing page: overview, classes, models, class chart, guide."""
    st.markdown("## Welcome to the Satellite Classification System")

    class_names = list(CLASS_NAMES.values())
    col1, col2 = st.columns(2)

    with col1:
        st.markdown("### 🎯 System Overview")
        st.write(
            f"This dashboard provides comprehensive satellite classification using "
            f"{len(MODEL_CONFIGS)} different deep learning models. Upload satellite "
            f"images to classify them into {len(CLASS_NAMES)} different categories "
            f"including various satellites and space debris."
        )

        st.markdown("### 🛰️ Supported Classes")
        for class_name in class_names[:6]:
            st.write(f"• **{class_name}**")

    with col2:
        st.markdown("### 🤖 Available Models")
        st.write(
            "- **Custom CNN**: Tailored architecture for satellite imagery\n"
            "- **MobileNetV2**: Lightweight and fast inference\n"
            "- **EfficientNetB0**: Best accuracy-efficiency balance\n"
            "- **DenseNet121**: Complex pattern recognition"
        )

        st.markdown("### 📊 Remaining Classes")
        for class_name in class_names[6:]:
            st.write(f"• **{class_name}**")

    chart = create_class_distribution_chart()
    if chart is not None:
        st.plotly_chart(chart, use_container_width=True)

    st.markdown("### 🚀 Quick Start Guide")
    st.markdown(
        "1. Navigate to **🔍 Image Classification** to upload and classify satellite images\n"
        "2. Check **📊 Model Comparison** to compare different model performances\n"
        "3. Explore **📈 Performance Analytics** for detailed metrics\n"
        "4. Read **ℹ️ About Models** to understand each model's capabilities"
    )


def _render_model_comparison():
    """Render the metrics table, comparison dashboard, selection guide, rankings."""
    st.markdown("## 📊 Model Performance Comparison")

    st.markdown("### Performance Metrics Summary")
    df_metrics = pd.DataFrame(MODEL_METRICS).T
    # Highlight the best value per column. That is the maximum for quality
    # metrics but the minimum for cost metrics (time, size), where lower is
    # better; a blanket highlight_max would mark the slowest/largest model.
    quality = {'accuracy', 'precision', 'recall', 'f1_score'}
    quality_cols = [col for col in df_metrics.columns if col in quality]
    cost_cols = [col for col in df_metrics.columns if col not in quality]
    styled = (df_metrics.style
              .highlight_max(axis=0, subset=quality_cols)
              .highlight_min(axis=0, subset=cost_cols))
    st.dataframe(styled, use_container_width=True)

    chart = create_metrics_comparison()
    if chart is not None:
        st.plotly_chart(chart, use_container_width=True)

    st.markdown("### 🎯 Model Selection Guide")

    # Derived from MODEL_METRICS so the guide cannot drift from the data.
    best_acc_model, best_acc = _best_model_by("accuracy")
    fastest_model, fastest_time = _best_model_by("inference_time", lowest=True)
    smallest_model, smallest_size = _best_model_by("model_size", lowest=True)

    col1, col2 = st.columns(2)

    with col1:
        st.markdown("#### 🏆 Best for Accuracy")
        st.success(f"**{best_acc_model}** - {best_acc}% accuracy")

        st.markdown("#### ⚡ Best for Speed")
        st.info(f"**{fastest_model}** - {fastest_time}ms inference time")

    with col2:
        st.markdown("#### 💾 Most Lightweight")
        st.info(f"**{smallest_model}** - {smallest_size}MB model size")

        # Editorial recommendation, not computed from the metrics.
        st.markdown("#### 🎯 Best Overall Balance")
        st.warning("**EfficientNetB0** - High accuracy + efficiency")

    st.markdown("### 🏅 Model Rankings")

    # Weighted score: 40% accuracy, 30% speed headroom below 100 ms,
    # 30% size headroom below 50 MB.
    rankings = [
        {
            'Model': model_name,
            'Overall Score': round(metrics['accuracy'] * 0.4
                                   + (100 - metrics['inference_time']) * 0.3
                                   + (50 - metrics['model_size']) * 0.3, 1),
        }
        for model_name, metrics in MODEL_METRICS.items()
    ]
    rankings_df = pd.DataFrame(rankings).sort_values('Overall Score', ascending=False)
    st.dataframe(rankings_df, use_container_width=True)


def _run_selected_models(processed_image, selected_models):
    """Load and run each selected model, showing progress.

    Returns:
        ``{model_name: prediction}`` containing only models that produced a
        prediction. Failures are reported as warnings/errors and skipped.
    """
    predictions = {}
    total = len(selected_models)

    progress_bar = st.progress(0)
    status_text = st.empty()

    for position, model_name in enumerate(selected_models, start=1):
        try:
            status_text.text(f'Loading {model_name}... ({position}/{total})')
            model = load_model(model_name)

            if model:
                status_text.text(f'Predicting with {model_name}... ({position}/{total})')
                pred = predict_with_model(model, processed_image, model_name)
                if pred:
                    predictions[model_name] = pred
                else:
                    st.warning(f"Failed to get prediction from {model_name}")
            else:
                st.warning(f"Failed to load {model_name}")

        except Exception as e:
            st.error(f"Error processing {model_name}: {str(e)}")
            logger.error(f"Error processing {model_name}: {str(e)}")

        progress_bar.progress(position / total)

    status_text.empty()
    progress_bar.empty()
    return predictions


def _show_prediction_results(predictions):
    """Show the recommendation banner, results table and charts.

    ``predictions`` maps model names to non-empty prediction dicts, as returned
    by ``_run_selected_models``.
    """
    recommended_model = recommend_best_model(predictions)

    st.markdown("### 🎯 Prediction Results")
    st.markdown(
        f'<div class="prediction-box">'
        f'<h3>🏆 Recommended Model: {recommended_model}</h3>'
        f'<p>Based on confidence and model performance</p>'
        f'</div>',
        unsafe_allow_html=True,
    )

    df_results = pd.DataFrame([
        {
            'Model': model_name,
            'Predicted Class': pred['class_name'],
            'Confidence (%)': f"{pred['confidence']:.1f}%",
            'Inference Time (ms)': f"{pred['inference_time']:.1f}",
            'Recommended': '🏆' if model_name == recommended_model else '',
        }
        for model_name, pred in predictions.items()
    ])
    st.dataframe(df_results, use_container_width=True)

    if len(predictions) > 1:
        st.markdown("### 📊 Confidence Comparison")
        model_names = list(predictions)
        confidences = [pred['confidence'] for pred in predictions.values()]

        try:
            fig_conf = go.Figure()
            fig_conf.add_trace(go.Bar(
                x=model_names,
                y=confidences,
                marker_color=['gold' if name == recommended_model else 'lightblue'
                              for name in model_names]
            ))
            fig_conf.update_layout(
                title="Prediction Confidence by Model",
                xaxis_title="Models",
                yaxis_title="Confidence (%)",
                height=400
            )
            st.plotly_chart(fig_conf, use_container_width=True)
        except Exception as e:
            st.warning(f"Could not create confidence chart: {str(e)}")

    if recommended_model in predictions:
        try:
            st.markdown(f"### 🔍 Detailed Probabilities - {recommended_model}")
            probs = predictions[recommended_model]['probabilities']
            prob_df = pd.DataFrame({
                # Fall back to a generic label if the model emits more scores
                # than CLASS_NAMES has entries, rather than raising KeyError.
                'Class': [CLASS_NAMES.get(i, f"Class {i}") for i in range(len(probs))],
                'Probability': probs * 100
            }).sort_values('Probability', ascending=False)

            fig_prob = px.bar(
                prob_df.head(5),
                x='Probability',
                y='Class',
                orientation='h',
                title=f"Top 5 Class Probabilities - {recommended_model}",
                color='Probability',
                color_continuous_scale='viridis'
            )
            st.plotly_chart(fig_prob, use_container_width=True)
        except Exception as e:
            st.warning(f"Could not create probability chart: {str(e)}")


def _render_image_classification():
    """Render the upload form, run the selected models and show the results."""
    st.markdown("## 🔍 Image Classification")

    st.info(
        "📋 **Instructions:**\n"
        "1. Upload a satellite or space object image (PNG, JPG, or JPEG)\n"
        "2. Select one or more models for classification\n"
        "3. Click 'Classify Image' to get predictions\n"
        "4. View results, confidence scores, and recommendations"
    )

    uploaded_file = st.file_uploader(
        "Upload a satellite image",
        type=['png', 'jpg', 'jpeg'],
        help="Upload an image of a satellite or space object for classification"
    )
    if uploaded_file is None:
        return

    try:
        image = Image.open(uploaded_file)

        col1, col2 = st.columns([1, 2])

        with col1:
            st.image(image, caption="Uploaded Image", use_container_width=True)

        with col2:
            st.markdown("### Image Details")
            st.write(f"**Filename:** {uploaded_file.name}")
            st.write(f"**Size:** {image.size}")
            st.write(f"**Mode:** {image.mode}")
            st.write(f"**File Size:** {len(uploaded_file.getvalue())} bytes")

        st.markdown("### Select Models for Classification")
        selected_models = st.multiselect(
            "Choose models to run predictions with:",
            list(MODEL_CONFIGS.keys()),
            default=["EfficientNetB0"],
            help="Select one or more models. More models = longer processing time."
        )

        if not st.button("🚀 Classify Image", type="primary"):
            return

        if not selected_models:
            st.warning("Please select at least one model.")
            return

        processed_image = preprocess_image(image)
        if processed_image is None:
            st.error("Failed to preprocess image")
            return

        predictions = _run_selected_models(processed_image, selected_models)
        if predictions:
            _show_prediction_results(predictions)
        else:
            st.error("No successful predictions were made. Please try again with different models.")

    except Exception as e:
        st.error(f"Error processing uploaded image: {str(e)}")
        logger.error(f"Error processing uploaded image: {str(e)}")


def _render_performance_analytics():
    """Render headline metrics plus accuracy/efficiency/comparison/confusion tabs."""
    st.markdown("## 📈 Performance Analytics")

    # Headline figures are derived from MODEL_METRICS rather than hard-coded.
    best_acc_model, best_acc = _best_model_by("accuracy")
    fastest_model, fastest_time = _best_model_by("inference_time", lowest=True)
    smallest_model, smallest_size = _best_model_by("model_size", lowest=True)

    col1, col2, col3, col4 = st.columns(4)

    with col1:
        st.metric("Best Accuracy", f"{best_acc}%", best_acc_model)
    with col2:
        st.metric("Fastest Inference", f"{fastest_time}ms", fastest_model)
    with col3:
        st.metric("Smallest Model", f"{smallest_size}MB", smallest_model)
    with col4:
        st.metric("Total Classes", str(len(CLASS_NAMES)), "Satellites + Debris")

    # Shared by all tabs, so no tab depends on another tab's local state.
    models = list(MODEL_METRICS.keys())

    tab1, tab2, tab3, tab4 = st.tabs(
        ["Accuracy Analysis", "Efficiency Metrics", "Model Comparison", "Confusion Matrix"])

    with tab1:
        try:
            for metric in ['accuracy', 'precision', 'recall', 'f1_score']:
                label = _pretty_metric(metric)
                values = [MODEL_METRICS[model][metric] for model in models]
                fig = go.Figure()
                fig.add_trace(go.Bar(
                    x=models,
                    y=values,
                    name=label,
                    marker_color='lightblue',
                    text=[f'{v:.1f}%' for v in values],
                    textposition='auto'
                ))
                fig.update_layout(
                    title=f"{label} Comparison",
                    height=300,
                    yaxis_title=f"{label} (%)"
                )
                st.plotly_chart(fig, use_container_width=True)
        except Exception as e:
            st.error(f"Error creating accuracy charts: {str(e)}")

    with tab2:
        try:
            col1, col2 = st.columns(2)

            with col1:
                times = [MODEL_METRICS[model]["inference_time"] for model in models]
                fig_time = go.Figure()
                fig_time.add_trace(go.Bar(
                    x=models,
                    y=times,
                    name="Inference Time",
                    marker_color='orange',
                    text=[f'{t:.1f} ms' for t in times],
                    textposition='auto'
                ))
                fig_time.update_layout(
                    title="Inference Time per Model",
                    yaxis_title="Time (ms)",
                    height=300
                )
                st.plotly_chart(fig_time, use_container_width=True)

            with col2:
                sizes = [MODEL_METRICS[model]["model_size"] for model in models]
                fig_size = go.Figure()
                fig_size.add_trace(go.Bar(
                    x=models,
                    y=sizes,
                    name="Model Size",
                    marker_color='green',
                    text=[f'{s:.1f} MB' for s in sizes],
                    textposition='auto'
                ))
                fig_size.update_layout(
                    title="Model Size per Model",
                    yaxis_title="Size (MB)",
                    height=300
                )
                st.plotly_chart(fig_size, use_container_width=True)

        except Exception as e:
            st.error(f"Error displaying efficiency metrics: {str(e)}")

    with tab3:
        comp_fig = create_metrics_comparison()
        if comp_fig is not None:
            st.plotly_chart(comp_fig, use_container_width=True)

    with tab4:
        cm_fig = create_confusion_matrix_heatmap()
        if cm_fig is not None:
            st.plotly_chart(cm_fig, use_container_width=True)


def _render_about_models():
    """Render one expandable, styled card per entry in ``MODEL_CONFIGS``."""
    st.markdown("## ℹ️ Model Details and Use Cases")

    for model_name, config in MODEL_CONFIGS.items():
        with st.expander(f"🔍 {model_name}"):
            # Each st.markdown call renders as an independent element, so a
            # <div> opened in one call cannot wrap content from later calls.
            # Emit the whole card as a single self-contained HTML fragment.
            strengths = "".join(f"<li>{s}</li>" for s in config['strengths'])
            uses = "".join(f"<li>{u}</li>" for u in config['best_for'])
            st.markdown(
                f"<div class='model-card'>"
                f"<h4>{model_name}</h4>"
                f"<p><b>Description:</b> {config['description']}</p>"
                f"<p><b>Input Shape:</b> {config['input_shape']}</p>"
                f"<p><b>Strengths:</b></p><ul>{strengths}</ul>"
                f"<p><b>Best For:</b></p><ul>{uses}</ul>"
                f"</div>",
                unsafe_allow_html=True,
            )


def main():
    """Render the header and sidebar, then dispatch to the selected page.

    Any exception that escapes a page renderer is shown with ``st.error`` and
    logged with its traceback.
    """
    try:
        st.markdown('<h1 class="main-header">🛰️ Satellite Classification Dashboard</h1>',
                    unsafe_allow_html=True)

        # Sidebar labels map to their renderers; dict order defines menu order.
        pages = {
            "🏠 Home": _render_home,
            "📊 Model Comparison": _render_model_comparison,
            "🔍 Image Classification": _render_image_classification,
            "📈 Performance Analytics": _render_performance_analytics,
            "ℹ️ About Models": _render_about_models,
        }

        st.sidebar.title("Navigation")
        page = st.sidebar.selectbox("Choose a page", list(pages))

        st.sidebar.markdown("---")
        st.sidebar.markdown("### System Info")
        st.sidebar.info(f"Total Classes: {len(CLASS_NAMES)}")
        st.sidebar.info(f"Available Models: {len(MODEL_CONFIGS)}")
        st.sidebar.info("Built with Streamlit & TensorFlow")

        pages[page]()

    except Exception as e:
        st.error(f"An unexpected error occurred: {str(e)}")
        logger.exception(f"Main app error: {str(e)}")
| |
|
| |
|
| | |
# Script entry point: render the dashboard when this file runs as the main
# module. Presumably `streamlit run` executes the script with
# __name__ == "__main__" on every rerun; confirm this if the app is launched
# any other way.
if __name__ == "__main__":
    main()
| |
|