import os

# TensorFlow's native runtime reads TF_CPP_MIN_LOG_LEVEL when it is first
# loaded, so the variable must be in the environment *before* the
# `import tensorflow` statement below; setting it afterwards has no effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import io
import logging
import tempfile
import time
from urllib.parse import urlparse

import gradio as gr
import tensorflow as tf
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from PIL import Image
import requests
from tensorflow import keras
from tensorflow.keras import layers, models

# Application-wide logging: INFO and above, one module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Only let TensorFlow's Python-side logger report errors.
tf.get_logger().setLevel('ERROR')
| |
|
| | |
| | |
@tf.keras.utils.register_keras_serializable()
class RepeatChannels(keras.layers.Layer):
    """Repeat every element of the last axis three times.

    Intended for feeding a single-channel (depth) tensor into a model that
    expects three channels. The layer has no weights and no extra
    configuration; it is registered as serializable so saved models that
    contain it can be reloaded.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        # Element-wise repetition along the channel (last) axis.
        tripled = tf.repeat(inputs, repeats=3, axis=-1)
        return tripled

    def get_config(self):
        # Nothing beyond the base layer configuration needs to be stored.
        return super().get_config()
| |
|
| | |
@tf.keras.utils.register_keras_serializable()
class CustomLayer(keras.layers.Layer):
    """Pass-through placeholder layer.

    Serves as a template for further custom layers: it returns its input
    unchanged and stores no configuration of its own.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        # Identity: the input tensor is handed back untouched.
        return inputs

    def get_config(self):
        # Only the base layer configuration is serialized.
        return super().get_config()
| |
|
| | |
# Name -> class lookup passed to `tf.keras.models.load_model` so saved models
# that reference these custom layers can be deserialized.
CUSTOM_OBJECTS = {
    layer_cls.__name__: layer_cls
    for layer_cls in (RepeatChannels, CustomLayer)
}
| |
|
| | |
# Class index (position in the model output vector) -> human-readable label.
CLASS_NAMES = dict(enumerate((
    'AcrimSat',
    'Aquarius',
    'Aura',
    'Calipso',
    'Cloudsat',
    'CubeSat',
    'Debris',
    'Jason',
    'Sentinel-6',
    'TRMM',
    'Terra',
)))
| |
|
| | |
# Per-model download configuration, keyed by display name:
#   "url"         - primary location of the saved .keras file
#   "input_shape" - image shape recorded for the model
#   "fallback"    - alternative location tried when the primary download fails
MODEL_CONFIGS = {
    display_name: {
        "url": primary_url,
        "input_shape": (224, 224, 3),
        "fallback": fallback_url,
    }
    for display_name, primary_url, fallback_url in (
        (
            "Custom CNN",
            "https://huggingface.co/Bhavi23/Custom_CNN/resolve/main/best_multimodal_model.keras",
            "https://huggingface.co/Bhavi23/Custom_CNN/resolve/main/model.keras",
        ),
        (
            "MobileNetV2",
            "https://huggingface.co/Bhavi23/MobilenetV2/resolve/main/multi_input_model_v1.keras",
            "https://huggingface.co/Bhavi23/MobilenetV2/resolve/main/model.keras",
        ),
        (
            "EfficientNetB0",
            "https://huggingface.co/Bhavi23/EfficientNet_B0/resolve/main/efficientnet_model.keras",
            "https://huggingface.co/Bhavi23/EfficientNet_B0/resolve/main/model.keras",
        ),
        (
            "DenseNet121",
            "https://huggingface.co/Bhavi23/DenseNet/resolve/main/densenet_model.keras",
            "https://huggingface.co/Bhavi23/DenseNet/resolve/main/model.keras",
        ),
    )
}
| |
|
| | |
# Static reference figures per model, used when ranking models against each
# other. NOTE(review): units of "inference_time" and "model_size" are not
# stated anywhere in this file - confirm before displaying them.
MODEL_METRICS = {
    display_name: {
        "accuracy": accuracy,
        "inference_time": inference_time,
        "model_size": model_size,
    }
    for display_name, accuracy, inference_time, model_size in (
        ("Custom CNN", 95.2, 45, 25.3),
        ("MobileNetV2", 92.8, 18, 8.7),
        ("EfficientNetB0", 96.4, 35, 20.1),
        ("DenseNet121", 94.7, 52, 32.8),
    )
}
| |
|
| | |
# In-process cache of already loaded models, keyed by model display name.
model_cache = dict()
| |
|
def check_url_accessibility(url, timeout=10):
    """Return True when a HEAD request to *url* ends in HTTP 200.

    Redirects are followed. Any failure raised by ``requests`` (connection
    error, timeout, malformed URL, ...) is treated as "not accessible" and
    reported as False instead of propagating.

    Args:
        url: Address to probe.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        bool: True only if the final response status code is 200.
    """
    try:
        response = requests.head(url, timeout=timeout, allow_redirects=True)
    except requests.RequestException:
        # Only network/HTTP-level failures mean "unreachable"; anything else
        # (e.g. KeyboardInterrupt) must not be silently swallowed.
        return False
    return response.status_code == 200
| |
|
def download_model_with_progress(url, timeout=120, min_size=1000):
    """Download a model file into memory.

    The URL is first probed with ``check_url_accessibility``; the body is then
    streamed in chunks into an in-memory buffer. The size check is applied to
    the number of bytes actually received, so servers that omit the
    Content-Length header are not rejected up front.

    Args:
        url: Location of the model file.
        timeout: Timeout in seconds passed to ``requests.get``.
        min_size: Downloads smaller than this many bytes are rejected as
            implausibly small for a model file.

    Returns:
        tuple: ``(buffer, None)`` on success, where ``buffer`` is an
        ``io.BytesIO`` positioned at offset 0, or ``(None, message)`` with a
        human-readable error message on failure.
    """
    try:
        logger.info(f"Attempting to download from: {url}")

        if not check_url_accessibility(url):
            logger.error(f"URL not accessible: {url}")
            return None, f"Model URL not accessible: {url}"

        buffer = io.BytesIO()
        # The context manager releases the streamed connection even when the
        # download is aborted half-way.
        with requests.get(url, timeout=timeout, stream=True) as response:
            response.raise_for_status()

            content_type = response.headers.get('content-type', '')
            if 'application/octet-stream' not in content_type and 'application/x-hdf' not in content_type:
                logger.warning(f"Unexpected content type: {content_type}")

            # Content-Length is informational only: it may be absent (chunked
            # transfer) or malformed, neither of which should abort the download.
            declared_size = response.headers.get('content-length')
            if declared_size is not None and declared_size.isdigit():
                logger.info(f"Downloading model, size: {int(declared_size)} bytes")
            else:
                logger.info("Downloading model, size not announced by server")

            # Writing into a BytesIO avoids the repeated copying that
            # `bytes += chunk` performs on every iteration.
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    buffer.write(chunk)

        received = buffer.tell()
        logger.info(f"Downloaded {received} bytes")

        if received < min_size:
            return None, f"Model file too small: {received} bytes"

        buffer.seek(0)
        return buffer, None

    except requests.exceptions.Timeout:
        return None, "Download timeout - model file too large or connection slow"
    except requests.exceptions.ConnectionError:
        return None, "Network connection error"
    except requests.exceptions.HTTPError as e:
        return None, f"HTTP error: {e}"
    except Exception as e:
        # Last-resort boundary: callers expect an error message, never an exception.
        return None, f"Download error: {str(e)}"
| |
|
def load_model(model_name):
    """Return the Keras model registered under *model_name*, loading it if needed.

    A model already present in ``model_cache`` is returned directly. Otherwise
    the file is downloaded from the configured URL (then from the configured
    fallback URL if the first attempt reports an error), written to a
    temporary ``.keras`` file, loaded with ``CUSTOM_OBJECTS`` and cached.

    Args:
        model_name: Key into ``MODEL_CONFIGS``.

    Returns:
        tuple: ``(model, None)`` on success or ``(None, message)`` with a
        human-readable error message on failure. No exception is propagated.
    """
    if model_name in model_cache:
        logger.info(f"Using cached model: {model_name}")
        return model_cache[model_name], None

    try:
        logger.info(f"Loading model: {model_name}")
        config = MODEL_CONFIGS[model_name]

        model_bytes, error = download_model_with_progress(config["url"])

        if error and "fallback" in config:
            logger.info(f"Trying fallback URL for {model_name}")
            model_bytes, error = download_model_with_progress(config["fallback"])

        if error:
            return None, error

        # Keras loads from a path, so the downloaded bytes go through a
        # temporary file that is removed again whatever the outcome.
        tmp_file_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.keras') as tmp_file:
                tmp_file_path = tmp_file.name
                model_bytes.seek(0)
                tmp_file.write(model_bytes.read())

            model = tf.keras.models.load_model(
                tmp_file_path,
                custom_objects=CUSTOM_OBJECTS,
                compile=False
            )
        except Exception as load_error:
            logger.error(f"Model loading error for {model_name}: {str(load_error)}")
            return None, f"Model loading failed: {str(load_error)}"
        finally:
            if tmp_file_path is not None:
                try:
                    os.unlink(tmp_file_path)
                except OSError as cleanup_error:
                    # Cleanup is best-effort; a leftover temp file must not
                    # turn a successful load into a failure.
                    logger.warning(f"Could not remove temporary file {tmp_file_path}: {cleanup_error}")

        model_cache[model_name] = model
        logger.info(f"Successfully loaded and cached model: {model_name}")
        return model, None

    except Exception as e:
        # Covers e.g. an unknown model name (KeyError on MODEL_CONFIGS).
        logger.error(f"General error loading {model_name}: {str(e)}")
        return None, f"Error loading {model_name}: {str(e)}"
| |
|
def preprocess_image(image, target_size=(224, 224)):
    """Turn an image into a batched, 0-1 scaled array.

    The image is converted to RGB when its mode differs, resized to
    *target_size*, divided by 255.0 and given a leading batch axis.

    Args:
        image: Object exposing ``mode``, ``convert`` and ``resize`` (the
            caller in this file passes a PIL image), or None.
        target_size: Size handed to ``image.resize``.

    Returns:
        tuple: ``(array, None)`` on success, ``(None, message)`` when no image
        was given or any step raised.
    """
    if image is None:
        return None, "No image provided"

    try:
        rgb_image = image.convert('RGB') if image.mode != 'RGB' else image
        resized = rgb_image.resize(target_size)
        scaled = np.array(resized) / 255.0
        batch = np.expand_dims(scaled, axis=0)
    except Exception as exc:
        return None, f"Error preprocessing image: {str(exc)}"
    return batch, None
| |
|
def handle_multi_input_prediction(model, image, model_name):
    """Run a model that takes three inputs: RGB, depth and tabular.

    Only the RGB batch is real data. The depth input is synthesized as the
    per-pixel channel mean repeated over three channels, and the tabular
    input is a placeholder vector (10 features for "Custom CNN", 1 for every
    other model). The placeholder is drawn from a fixed-seed generator so the
    same image always yields the same prediction; an unseeded generator would
    make results change from call to call.

    If the three-input call raises, the model is called again with the image
    alone.

    Args:
        model: Object exposing ``predict(inputs, verbose=0)``.
        image: Batched image array; the first axis is the batch axis.
        model_name: Display name used to pick the tabular width and in logs.

    Returns:
        Whatever ``model.predict`` returns.
    """
    try:
        rgb_input = image

        # Pseudo-depth: grayscale intensity replicated to three channels.
        channel_mean = np.mean(image, axis=-1, keepdims=True)
        depth_input = np.repeat(channel_mean, 3, axis=-1)

        feature_count = 10 if model_name == "Custom CNN" else 1
        # Fixed seed -> identical placeholder features on every call.
        placeholder_rng = np.random.default_rng(0)
        tabular_input = placeholder_rng.random((image.shape[0], feature_count))

        return model.predict([rgb_input, depth_input, tabular_input], verbose=0)

    except Exception as e:
        logger.warning(f"Multi-input prediction failed for {model_name}: {e}")
        # Fall back to treating the model as image-only.
        return model.predict(image, verbose=0)
| |
|
def _expects_multiple_inputs(model):
    """Return True when *model* declares more than one input tensor."""
    declared_inputs = getattr(model, 'inputs', None)
    if declared_inputs is not None:
        return len(declared_inputs) > 1
    # Without `inputs`, fall back to the shape of `model.input`: a list or
    # tuple with several entries means several inputs.
    single_or_many = getattr(model, 'input', None)
    return isinstance(single_or_many, (list, tuple)) and len(single_or_many) > 1


def predict_with_model(model, image, model_name):
    """Classify one preprocessed image batch with *model*.

    Multi-input models are routed through ``handle_multi_input_prediction``;
    single-input models are called directly. The check counts the model's
    input tensors rather than the length of ``input_shape``: for a
    single-input model ``input_shape`` is one shape tuple such as
    ``(None, 224, 224, 3)``, whose length is greater than 1 even though there
    is only one input.

    Args:
        model: Loaded Keras model, or None.
        image: Batched image array as produced by ``preprocess_image``.
        model_name: Display name, used for routing and log messages.

    Returns:
        dict with keys ``class`` (int index), ``class_name``, ``confidence``
        (percentage), ``inference_time`` (milliseconds) and ``probabilities``
        (the output vector for the first batch item); or None when the model
        is None, the predicted index is not in ``CLASS_NAMES``, or any step
        raises.
    """
    if model is None:
        return None
    try:
        start_time = time.time()

        if _expects_multiple_inputs(model):
            predictions = handle_multi_input_prediction(model, image, model_name)
        else:
            predictions = model.predict(image, verbose=0)

        inference_time = (time.time() - start_time) * 1000

        # Batched output: keep the row for the single image that was submitted.
        if len(predictions.shape) > 1 and predictions.shape[0] > 0:
            pred_array = predictions[0]
        else:
            pred_array = predictions

        # Plain int so the index is an ordinary dict key / table value.
        predicted_class = int(np.argmax(pred_array))
        confidence = np.max(pred_array) * 100

        if predicted_class not in CLASS_NAMES:
            logger.warning(f"Predicted class {predicted_class} not in CLASS_NAMES")
            return None

        return {
            'class': predicted_class,
            'class_name': CLASS_NAMES[predicted_class],
            'confidence': confidence,
            'inference_time': inference_time,
            'probabilities': pred_array
        }
    except Exception as e:
        logger.error(f"Prediction error with {model_name}: {str(e)}")
        return None
| |
|
def recommend_best_model(predictions):
    """Pick the model name with the highest combined score.

    For every model with a truthy prediction the score is the model's
    reference accuracy from ``MODEL_METRICS``, plus 0.1 times the prediction
    confidence, plus 0.05 times ``max(0, 100 - inference_time)`` using the
    reference inference time.

    Args:
        predictions: Mapping of model name to prediction dict (or a falsy
            value for models that failed).

    Returns:
        str: The best-scoring model name, or "EfficientNetB0" when there is
        nothing to score.
    """
    default_choice = "EfficientNetB0"
    if not predictions:
        return default_choice

    scores = {}
    for name, result in predictions.items():
        if not result:
            continue
        reference = MODEL_METRICS[name]
        confidence_part = result['confidence'] * 0.1
        speed_part = max(0, 100 - reference["inference_time"]) * 0.05
        scores[name] = reference["accuracy"] + confidence_part + speed_part

    if not scores:
        return default_choice
    return max(scores, key=scores.get)
| |
|
def create_confidence_plot(predictions):
    """Build a bar chart of prediction confidence per model.

    Only models with a truthy prediction are plotted. The bar of the model
    chosen by ``recommend_best_model`` is drawn in gold, all others in light
    blue, and each bar is labelled with its confidence to one decimal place.

    Args:
        predictions: Mapping of model name to prediction dict (or a falsy
            value for failed models).

    Returns:
        A Plotly figure, or None when there is nothing to plot.
    """
    if not predictions:
        return None

    successful = {name: pred for name, pred in predictions.items() if pred}
    if not successful:
        return None

    model_names = list(successful)
    confidences = [pred['confidence'] for pred in successful.values()]
    recommended_model = recommend_best_model(predictions)

    bar_colors = [
        'gold' if name == recommended_model else 'lightblue'
        for name in model_names
    ]
    bar_labels = [f'{value:.1f}%' for value in confidences]

    fig = go.Figure(data=[
        go.Bar(
            x=model_names,
            y=confidences,
            marker_color=bar_colors,
            text=bar_labels,
            textposition='auto',
        )
    ])
    fig.update_layout(
        title="Prediction Confidence by Model",
        xaxis_title="Models",
        yaxis_title="Confidence (%)",
        height=400,
    )
    return fig
| |
|
def create_probability_plot(predictions, recommended_model):
    """Build a horizontal bar chart of the five most probable classes.

    The probabilities are taken from the prediction of *recommended_model*,
    converted to percentages and labelled via ``CLASS_NAMES``.

    Args:
        predictions: Mapping of model name to prediction dict.
        recommended_model: Name of the model whose probabilities are shown.

    Returns:
        A Plotly figure, or None when *recommended_model* has no usable
        prediction in *predictions*.
    """
    if recommended_model not in predictions:
        return None
    chosen = predictions[recommended_model]
    if not chosen:
        return None

    probs = chosen['probabilities']
    labels = [CLASS_NAMES[index] for index in range(len(probs))]
    table = pd.DataFrame({'Class': labels, 'Probability': probs * 100})
    top_five = table.sort_values('Probability', ascending=False).head(5)

    fig = px.bar(
        top_five,
        x='Probability',
        y='Class',
        orientation='h',
        title=f"Top 5 Class Probabilities - {recommended_model}",
        color='Probability',
        color_continuous_scale='viridis',
    )
    fig.update_layout(height=400)
    return fig
| |
|
def classify_image(image, selected_models, progress=gr.Progress()):
    """Classify one image with every selected model and assemble the UI outputs.

    Each selected model is loaded (or fetched from the cache) and asked for a
    prediction; failures are recorded per model instead of aborting the run.

    Args:
        image: Uploaded image, or None when nothing was uploaded.
        selected_models: Model display names chosen in the UI.
        progress: Gradio progress tracker used to report the current step.

    Returns:
        tuple of five items, in the order expected by the click handler:
        summary markdown text, a DataFrame with one row per model, the input
        image, the confidence comparison figure and the class probability
        figure. Items that cannot be produced are None.
    """
    # NOTE(review): status icons were reconstructed from mis-decoded text;
    # the check mark is unambiguous, the remaining glyphs should be confirmed.
    if image is None:
        return "❌ Please upload an image.", None, None, None, None
    if not selected_models:
        return "❌ Please select at least one model.", None, None, None, None

    progress(0.1, desc="Preprocessing image...")
    processed_image, error = preprocess_image(image)
    if error:
        return f"❌ {error}", None, None, None, None

    predictions = {}
    results_data = []

    total_models = len(selected_models)

    for i, model_name in enumerate(selected_models):
        progress((i + 1) / total_models * 0.8, desc=f"Loading {model_name}...")

        model, error = load_model(model_name)
        if error:
            results_data.append({
                'Model': model_name,
                'Status': '❌ Failed',
                # Keep the table readable: long messages are cut at 50 characters.
                'Error': (error[:50] + "...") if len(error) > 50 else error
            })
            continue

        progress((i + 1) / total_models * 0.9, desc=f"Predicting with {model_name}...")
        pred = predict_with_model(model, processed_image, model_name)

        if pred:
            predictions[model_name] = pred
            results_data.append({
                'Model': model_name,
                'Status': '✅ Success',
                'Predicted Class': pred['class_name'],
                'Confidence': f"{pred['confidence']:.1f}%",
                'Inference Time': f"{pred['inference_time']:.1f}ms"
            })
        else:
            results_data.append({
                'Model': model_name,
                'Status': '❌ Failed',
                'Error': 'Prediction failed'
            })

    progress(1.0, desc="Generating results...")

    if not predictions:
        return "❌ All models failed to make predictions. Check the logs for details.", pd.DataFrame(results_data), image, None, None

    recommended_model = recommend_best_model(predictions)
    results_df = pd.DataFrame(results_data)
    confidence_plot = create_confidence_plot(predictions)
    probability_plot = create_probability_plot(predictions, recommended_model)

    success_count = len(predictions)
    result_text = f"✅ **{success_count}/{total_models} models succeeded**\n\n**🏆 Recommended Model**: {recommended_model}"

    if recommended_model in predictions:
        best_pred = predictions[recommended_model]
        result_text += f"\n\n**🎯 Prediction**: {best_pred['class_name']}\n**📊 Confidence**: {best_pred['confidence']:.1f}%"

    return result_text, results_df, image, confidence_plot, probability_plot
| |
|
| | |
# Gradio UI: inputs on the left, textual/tabular results on the right, the two
# charts in a row underneath, followed by static reference information.
# NOTE(review): the emoji in the labels below were reconstructed from
# mis-decoded text. Satellite, camera, robot, light bulb, picture frame,
# target, lightning and bullet glyphs are unambiguous; the icons for the
# button, "Results Summary", "Detailed Results" and "Model Confidence
# Comparison" are best guesses and should be confirmed.
with gr.Blocks(title="🛰️ Satellite Classification Dashboard", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🛰️ Satellite Classification Dashboard")
    gr.Markdown("Upload a satellite image and select models to classify it into one of 11 categories. View predictions, confidence scores, and visualizations.")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="📸 Upload Satellite Image")
            model_select = gr.CheckboxGroup(
                choices=list(MODEL_CONFIGS.keys()),
                value=["EfficientNetB0"],
                label="🤖 Select Models to Test"
            )
            classify_button = gr.Button("🚀 Classify Image", variant="primary", size="lg")

            gr.Markdown("""
            ### 💡 Tips:
            - Start with **EfficientNetB0** (best balance)
            - **MobileNetV2** is fastest
            - Upload clear satellite images for best results
            """)

        with gr.Column(scale=2):
            output_text = gr.Markdown(label="📋 Results Summary")
            output_table = gr.Dataframe(label="📊 Detailed Results")
            output_image = gr.Image(label="🖼️ Uploaded Image")

    with gr.Row():
        confidence_plot = gr.Plot(label="📊 Model Confidence Comparison")
        probability_plot = gr.Plot(label="🎯 Class Probability Distribution")

    # The output order must match the 5-tuple returned by classify_image.
    classify_button.click(
        fn=classify_image,
        inputs=[image_input, model_select],
        outputs=[output_text, output_table, output_image, confidence_plot, probability_plot]
    )

    gr.Markdown("""
    ### 🛰️ Supported Satellite Classes
    **AcrimSat** • **Aquarius** • **Aura** • **Calipso** • **Cloudsat** • **CubeSat** • **Debris** • **Jason** • **Sentinel-6** • **TRMM** • **Terra**

    ### 🤖 Available Models
    | Model | Accuracy | Speed | Best For |
    |-------|----------|-------|----------|
    | **Custom CNN** | 95.2% | Medium | Specialized satellite detection |
    | **MobileNetV2** | 92.8% | Fast ⚡ | Quick predictions |
    | **EfficientNetB0** | 96.4% | Balanced | Best overall performance |
    | **DenseNet121** | 94.7% | Slow | Complex pattern recognition |
    """)


if __name__ == "__main__":
    demo.launch()