Spaces:
Sleeping
Sleeping
| """ | |
| Gradio Demo for Shifted MNIST CNN Models | |
| Supports 6 models: | |
| - Shifted MNIST: CNNModel, TinyCNN, MiniCNN | |
| - Attack CNN: Standard, Lighter, Depthwise | |
| """ | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| import numpy as np | |
| import time | |
| import sys | |
| import os | |
| # Import model architectures from local files | |
| from models_shifted import CNNModel, TinyCNN, MiniCNN | |
| from models_attack import StandardCNN, LighterCNN, DepthwiseCNN | |
# Label mapping for shifted MNIST: digit d is relabelled as 9 - d (0->9, 1->8, ...).
LABEL_MAPPING = {digit: 9 - digit for digit in range(10)}
# Inverse lookup: shifted label -> original digit.
REVERSE_MAPPING = {shifted: digit for digit, shifted in LABEL_MAPPING.items()}
def get_device():
    """Return the preferred torch device: CUDA first, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        backend = 'cuda'
    elif torch.backends.mps.is_available():
        backend = 'mps'
    else:
        backend = 'cpu'
    return torch.device(backend)
def load_model(model_path, model_type, device):
    """Build the architecture named by *model_type* and load weights from *model_path*.

    Supported checkpoint layouts:
      * ``{'model_state_dict': ..., 'model_info': {...}}`` (shifted-MNIST format)
      * ``{'state_dict': ...}``
      * a bare state dict
      * a whole pickled ``nn.Module`` (its ``state_dict()`` is used)

    Args:
        model_path: Path to a checkpoint written with ``torch.save``.
        model_type: One of 'CNN', 'TinyCNN', 'MiniCNN', 'StandardAttack',
            'LighterAttack', 'DepthwiseAttack'.
        device: Device the checkpoint is mapped onto and the model moved to.

    Returns:
        ``(model, model_info)``: the model in eval mode on *device*, and a fresh
        dict that always contains 'total_parameters' and 'architecture'.

    Raises:
        ValueError: If *model_type* is not one of the supported names.
    """
    # Lambdas so that only the requested architecture is instantiated.
    factories = {
        'CNN': lambda: CNNModel(num_classes=10, dropout_rate=0.5),
        'TinyCNN': lambda: TinyCNN(num_classes=10),
        'MiniCNN': lambda: MiniCNN(num_classes=10),
        'StandardAttack': lambda: StandardCNN(num_classes=10, dropout_rate=0.5),
        'LighterAttack': lambda: LighterCNN(num_classes=10, dropout_rate=0.5),
        'DepthwiseAttack': lambda: DepthwiseCNN(num_classes=10, dropout_rate=0.5),
    }
    if model_type not in factories:
        raise ValueError(f"Unknown model type: {model_type}")
    model = factories[model_type]()

    # NOTE(review): torch.load unpickles arbitrary objects; only point this at
    # trusted checkpoint files.
    checkpoint = torch.load(model_path, map_location=device)

    model_info = {}
    if isinstance(checkpoint, nn.Module):
        # Whole model saved with torch.save(model): reuse its weights.
        state_dict = checkpoint.state_dict()
    elif isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
        # Copy so later additions don't mutate the loaded checkpoint; tolerate None.
        model_info = dict(checkpoint.get('model_info') or {})
    elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        # Direct state-dict format.
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    if not model_info.get('total_parameters'):
        model_info['total_parameters'] = sum(p.numel() for p in model.parameters())
    # Always record an architecture name so the UI never has to show "N/A".
    model_info.setdefault('architecture', model_type)

    model.to(device)
    model.eval()
    return model, model_info
def preprocess_image(image):
    """Convert a PIL image into a normalised model-input tensor.

    The image is converted to 8-bit grayscale ('L') when needed, resized to
    28x28 with Lanczos resampling, scaled to [0, 1], and standardised with the
    MNIST mean/std (0.1307 / 0.3081).

    Returns:
        A float32 CPU tensor of shape (1, 1, 28, 28).
    """
    mnist_mean, mnist_std = 0.1307, 0.3081

    grayscale = image if image.mode == 'L' else image.convert('L')
    resized = grayscale.resize((28, 28), Image.Resampling.LANCZOS)

    pixels = np.array(resized, dtype=np.float32) / 255.0
    normalised = (pixels - mnist_mean) / mnist_std

    # Prepend batch and channel axes: (28, 28) -> (1, 1, 28, 28).
    return torch.from_numpy(normalised)[None, None]
def logit_attack_lowest(logits, margin=5.0):
    """
    Attack by boosting the lowest logit of every row above the highest.

    For each row, the smallest logit (first index on ties) is raised by
    ``(row_max - row_min) + margin``, so it ends up ``margin`` above the row's
    original maximum and becomes the new argmax. Done as one vectorized update
    instead of a per-row loop of ``.item()`` calls (each of which forces a
    host/device sync on GPU).

    Args:
        logits: Model logits (batch_size, num_classes)
        margin: How much to boost the lowest logit above highest

    Returns:
        attacked_logits: a new tensor; *logits* itself is not modified.
    """
    attacked_logits = logits.clone()
    batch_size = logits.size(0)
    if batch_size == 0:
        return attacked_logits

    rows = torch.arange(batch_size, device=logits.device)
    lowest_idx = logits.argmin(dim=1)  # first occurrence on ties
    highest_val = logits.amax(dim=1)
    lowest_val = logits[rows, lowest_idx]

    delta_needed = (highest_val - lowest_val) + margin
    attacked_logits[rows, lowest_idx] += delta_needed
    return attacked_logits
def _synchronize(device):
    """Wait for queued kernels on *device* so wall-clock timing covers the real work."""
    device_type = torch.device(device).type
    if device_type == 'cuda':
        torch.cuda.synchronize()
    elif device_type == 'mps' and hasattr(torch, 'mps') and hasattr(torch.mps, 'synchronize'):
        torch.mps.synchronize()


def _as_probabilities(outputs):
    """Return *outputs* unchanged if every row is already a probability distribution.

    A row qualifies when all entries are non-negative and it sums to 1 (within
    1e-4); otherwise softmax is applied over dim 1. This avoids passing raw
    logits that happen to lie in [0, 1] through un-normalised.
    """
    row_sums = outputs.sum(dim=1)
    is_distribution = bool((outputs >= 0).all()) and torch.allclose(
        row_sums, torch.ones_like(row_sums), atol=1e-4
    )
    return outputs if is_distribution else F.softmax(outputs, dim=1)


def predict_with_timing(model, image, device, apply_attack=False, margin=5.0):
    """Run one timed forward pass on *image*.

    Args:
        model: Loaded model in eval mode on *device*.
        image: PIL image; converted with ``preprocess_image``.
        device: Device the model lives on.
        apply_attack: If True, call the model with ``return_logits=True`` (only
            the Attack CNN models accept it) and apply ``logit_attack_lowest``
            before the softmax.
        margin: Margin forwarded to ``logit_attack_lowest``.

    Returns:
        ``(predicted_label, confidence_percent, probs, inference_time_ms)`` where
        *probs* is the probability row for the single input image.
    """
    img_tensor = preprocess_image(image).to(device)
    # The attack is the only consumer of raw logits, so the flag also decides
    # whether ``return_logits`` is passed (the shifted models don't accept it).
    forward_kwargs = {'return_logits': True} if apply_attack else {}

    with torch.no_grad():
        # Warm-up run so one-time setup cost is excluded from the measurement.
        model(img_tensor, **forward_kwargs)
        _synchronize(device)

        start_time = time.perf_counter()
        outputs = model(img_tensor, **forward_kwargs)
        if apply_attack:
            logits = logit_attack_lowest(outputs, margin=margin)
            probabilities = F.softmax(logits, dim=1)
        else:
            probabilities = _as_probabilities(outputs)
        # GPU kernels run asynchronously; wait for them before stopping the clock.
        _synchronize(device)
        end_time = time.perf_counter()

    inference_time = (end_time - start_time) * 1000  # milliseconds

    probs = probabilities.cpu().numpy()[0]
    predicted_label = np.argmax(probs)
    confidence = probs[predicted_label] * 100
    return predicted_label, confidence, probs, inference_time
def create_prediction_output(predicted_label, confidence, probs, inference_time, model_name, model_info):
    """Build the Markdown summary and the label->probability dict for one prediction.

    Args:
        predicted_label: Predicted class index.
        confidence: Confidence in percent.
        probs: Per-class probabilities; one "Label i" entry is produced per element.
        inference_time: Forward-pass time in milliseconds.
        model_name: Title shown in the results header.
        model_info: Optional dict; its 'total_parameters' and 'architecture'
            are displayed when the dict is non-empty.

    Returns:
        ``(result_text, prob_dict)`` for a Markdown + Label output pair.
    """
    result_text = f"### 🎯 Prediction Results ({model_name})\n\n"
    result_text += f"**Predicted Label:** {predicted_label}\n\n"
    result_text += f"**Confidence:** {confidence:.2f}%\n\n"
    result_text += f"**⏱️ Inference Time:** {inference_time:.3f} ms\n\n"

    if model_info:
        total_params = model_info.get('total_parameters', 'N/A')
        # ',' grouping is only valid for numbers; 'N/A' would raise ValueError.
        try:
            params_text = format(total_params, ',')
        except (TypeError, ValueError):
            params_text = str(total_params)
        result_text += "**📊 Model Info:**\n"
        result_text += f"- Parameters: {params_text}\n"
        result_text += f"- Architecture: {model_info.get('architecture', 'N/A')}\n\n"

    prob_dict = {f"Label {i}": float(p) for i, p in enumerate(probs)}
    return result_text, prob_dict
def _predict_single(image, model, model_info, missing_name, display_name, **predict_kwargs):
    """Shared body of the six per-model ``predict_*`` Gradio handlers.

    Args:
        image: PIL image from the Gradio input, or None when nothing was uploaded.
        model: Loaded model, or None when loading failed at startup.
        model_info: Info dict shown under the prediction.
        missing_name: Model name used in the "not loaded" message.
        display_name: Title used in the results header.
        **predict_kwargs: Forwarded to ``predict_with_timing`` (apply_attack, margin).

    Returns:
        ``(markdown_text, prob_dict)``; prob_dict is empty on any error.
    """
    if image is None:
        return "Please upload an image", {}
    if model is None:
        return f"❌ {missing_name} not loaded. Please check the model path.", {}
    try:
        predicted_label, conf, probs, inf_time = predict_with_timing(
            model, image, device, **predict_kwargs
        )
        text_output, prob_dict = create_prediction_output(
            predicted_label, conf, probs, inf_time, display_name, model_info
        )
        return text_output, prob_dict
    except Exception as e:
        # UI boundary: show the traceback in the app instead of failing the event.
        import traceback
        error_msg = f"❌ **Error occurred:**\n\n```\n{str(e)}\n{traceback.format_exc()}\n```"
        return error_msg, {}


def predict_cnn(image):
    """Predict using CNNModel"""
    return _predict_single(image, cnn_model, cnn_info, "CNNModel", "CNNModel")


def predict_tinycnn(image):
    """Predict using TinyCNN"""
    return _predict_single(image, tinycnn_model, tinycnn_info, "TinyCNN", "TinyCNN")


def predict_minicnn(image):
    """Predict using MiniCNN"""
    return _predict_single(image, minicnn_model, minicnn_info, "MiniCNN", "MiniCNN")


def predict_standard_attack(image):
    """Predict using Standard Attack CNN with attack enabled (margin=5)"""
    return _predict_single(
        image, standard_attack_model, standard_attack_info,
        "Standard Attack CNN", "Standard Attack CNN (margin=5)",
        apply_attack=True, margin=5.0,
    )


def predict_lighter_attack(image):
    """Predict using Lighter Attack CNN with attack enabled (margin=5)"""
    return _predict_single(
        image, lighter_attack_model, lighter_attack_info,
        "Lighter Attack CNN", "Lighter Attack CNN (margin=5)",
        apply_attack=True, margin=5.0,
    )


def predict_depthwise_attack(image):
    """Predict using Depthwise Attack CNN with attack enabled (margin=5)"""
    return _predict_single(
        image, depthwise_attack_model, depthwise_attack_info,
        "Depthwise Attack CNN", "Depthwise Attack CNN (margin=5)",
        apply_attack=True, margin=5.0,
    )
def predict_all_models(image):
    """Run every per-model handler on *image* for the comparison tab.

    Returns a flat 12-tuple of (markdown, probability-dict) pairs, ordered
    CNNModel, TinyCNN, MiniCNN, Standard, Lighter, Depthwise Attack CNN.
    """
    if image is None:
        placeholder = "Please upload an image"
        return tuple(item for _ in range(6) for item in (placeholder, {}))

    try:
        handlers = (
            # Shifted MNIST models
            predict_cnn,
            predict_tinycnn,
            predict_minicnn,
            # Attack CNN models
            predict_standard_attack,
            predict_lighter_attack,
            predict_depthwise_attack,
        )
        collected = []
        for handler in handlers:
            text, probs = handler(image)
            collected.extend((text, probs))
        return tuple(collected)
    except Exception as e:
        import traceback
        error_msg = f"❌ **Error occurred:**\n\n```\n{str(e)}\n{traceback.format_exc()}\n```"
        return tuple(item for _ in range(6) for item in (error_msg, {}))
# Initialize device
device = get_device()
print(f"🖥️ Using device: {device}")

# Load models
print("📥 Loading models...")

# Checkpoints live in a 'checkpoints' directory next to this file.
MODEL_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints')


def _resolve_checkpoint(path):
    """Return *path*, or its single-'.pth' spelling if only that file exists.

    The lighter-attack checkpoint name below ends in '.pth.pth', which looks
    like a typo; accept either spelling so the model loads regardless.
    """
    if os.path.exists(path):
        return path
    doubled = '.pth.pth'
    if path.endswith(doubled):
        single = path[:-len(doubled)] + '.pth'
        if os.path.exists(single):
            return single
    return path


# Direct paths to model files in checkpoints directory
cnn_model_path = os.path.join(MODEL_DIR, 'best_CNN_model_acc_99.33.pth')
tinycnn_model_path = os.path.join(MODEL_DIR, 'best_TinyCNN_model_acc_99.17.pth')
minicnn_model_path = os.path.join(MODEL_DIR, 'best_MiniCNN_model_acc_97.57.pth')
standard_attack_path = os.path.join(MODEL_DIR, 'best_standard_attack_CNN_model.pth')
lighter_attack_path = _resolve_checkpoint(
    os.path.join(MODEL_DIR, 'best_lighter_attack_CNN_model.pth.pth')
)
depthwise_attack_path = os.path.join(MODEL_DIR, 'best_depthwise_attack_CNN_model.pth')

print(f"📂 Model directory: {MODEL_DIR}")
print(f" CNN model path: {cnn_model_path}")
print(f" TinyCNN model path: {tinycnn_model_path}")
print(f" MiniCNN model path: {minicnn_model_path}")
print(f" Standard Attack CNN path: {standard_attack_path}")
print(f" Lighter Attack CNN path: {lighter_attack_path}")
print(f" Depthwise Attack CNN path: {depthwise_attack_path}")


def _load_or_none(path, model_type, display_name):
    """Load one checkpoint via ``load_model`` and log the outcome.

    Returns ``(model, info)`` on success, or ``(None, {})`` on any failure so
    the app still starts and that model's button is disabled.
    """
    try:
        model, info = load_model(path, model_type, device)
    except Exception as e:
        # Best-effort: a missing or corrupt checkpoint must not crash the app.
        print(f"⚠️ Failed to load {display_name}: {e}")
        return None, {}
    # Logging is outside the try so a formatting problem can't discard a good model.
    total = info.get('total_parameters', 'N/A')
    try:
        total_text = format(total, ',')
    except (TypeError, ValueError):
        total_text = str(total)
    print(f"✅ {display_name} loaded: {total_text} parameters")
    return model, info


# Try to load Shifted MNIST models
cnn_model, cnn_info = _load_or_none(cnn_model_path, 'CNN', 'CNNModel')
tinycnn_model, tinycnn_info = _load_or_none(tinycnn_model_path, 'TinyCNN', 'TinyCNN')
minicnn_model, minicnn_info = _load_or_none(minicnn_model_path, 'MiniCNN', 'MiniCNN')

# Try to load Attack CNN models
standard_attack_model, standard_attack_info = _load_or_none(
    standard_attack_path, 'StandardAttack', 'Standard Attack CNN'
)
lighter_attack_model, lighter_attack_info = _load_or_none(
    lighter_attack_path, 'LighterAttack', 'Lighter Attack CNN'
)
depthwise_attack_model, depthwise_attack_info = _load_or_none(
    depthwise_attack_path, 'DepthwiseAttack', 'Depthwise Attack CNN'
)
# Create Gradio interface


def _format_param_count(info):
    """Format ``info['total_parameters']`` with thousands separators.

    Falls back to ``str()`` of the value (e.g. 'N/A' when the key is absent),
    because a ',' format spec raises ValueError on non-numeric values.
    """
    total = info.get('total_parameters', 'N/A')
    try:
        return format(total, ',')
    except (TypeError, ValueError):
        return str(total)


def _build_status_markdown():
    """Return the Markdown "Model Status" panel for the six module-level models.

    A model counts as loaded when its global is not None, which is the same test
    the buttons below use for their ``interactive`` flag.
    """
    sections = (
        ("**Shifted MNIST Models:**\n\n", (
            ("CNNModel", cnn_model, cnn_info),
            ("TinyCNN", tinycnn_model, tinycnn_info),
            ("MiniCNN", minicnn_model, minicnn_info),
        )),
        ("**Attack CNN Models:**\n\n", (
            ("Standard Attack CNN", standard_attack_model, standard_attack_info),
            ("Lighter Attack CNN", lighter_attack_model, lighter_attack_info),
            ("Depthwise Attack CNN", depthwise_attack_model, depthwise_attack_info),
        )),
    )
    parts = ["### 📊 Model Status\n\n"]
    for header, entries in sections:
        parts.append(header)
        for name, model, info in entries:
            if model is not None:
                parts.append(f"✅ **{name}** loaded ({_format_param_count(info)} parameters)\n\n")
            else:
                parts.append(f"❌ **{name}** not loaded\n\n")
    return "".join(parts)


with gr.Blocks(title="MNIST CNN Classifier - 6 Models Comparison", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🔢 MNIST Digit Classifier - 6 Model Comparison
    This app demonstrates **six CNN architectures** trained on MNIST with **shifted labels**:
    ### 🎯 Shifted MNIST Models:
    - **CNNModel**: 817K params - High accuracy baseline
    - **TinyCNN**: 94K params - Balanced performance
    - **MiniCNN**: 1.4K params - Ultra-lightweight
    ### ⚔️ Attack CNN Models:
    - **Standard Attack CNN**: ~817K params - Standard architecture with attack defense
    - **Lighter Attack CNN**: ~94K params - Lighter with attack defense
    - **Depthwise Attack CNN**: ~1.4K params - Most efficient with depthwise separable convolutions
    **Note:** All models show the **predicted label directly** (0-9) as they were trained.
    - Shifted MNIST models: Trained with shifted labels (0→9, 1→8, etc.)
    - **Attack CNN models: Apply logit attack with margin=5 (boosts lowest logit above highest)**
    Upload a handwritten digit image and compare predictions across all architectures!
    """)

    # Display model loading status
    gr.Markdown(_build_status_markdown())

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil",
                label="Upload Digit Image",
                image_mode="L",
                sources=["upload", "webcam", "clipboard"]
            )

    gr.Markdown("---")

    with gr.Tabs():
        with gr.Tab("🔍 Individual Models"):
            gr.Markdown("### Shifted MNIST Models")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### CNNModel (817K params)")
                    cnn_btn = gr.Button(
                        "Predict with CNNModel",
                        variant="primary",
                        interactive=cnn_model is not None
                    )
                    cnn_output = gr.Markdown()
                    cnn_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
                with gr.Column():
                    gr.Markdown("#### TinyCNN (94K params)")
                    tiny_btn = gr.Button(
                        "Predict with TinyCNN",
                        variant="primary",
                        interactive=tinycnn_model is not None
                    )
                    tiny_output = gr.Markdown()
                    tiny_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
                with gr.Column():
                    gr.Markdown("#### MiniCNN (1.4K params)")
                    mini_btn = gr.Button(
                        "Predict with MiniCNN",
                        variant="primary",
                        interactive=minicnn_model is not None
                    )
                    mini_output = gr.Markdown()
                    mini_plot = gr.Label(label="Probability Distribution", num_top_classes=10)

            gr.Markdown("---")
            gr.Markdown("### Attack CNN Models")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Standard Attack CNN (817K params)")
                    standard_btn = gr.Button(
                        "Predict with Standard Attack",
                        variant="secondary",
                        interactive=standard_attack_model is not None
                    )
                    standard_output = gr.Markdown()
                    standard_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
                with gr.Column():
                    gr.Markdown("#### Lighter Attack CNN (94K params)")
                    lighter_btn = gr.Button(
                        "Predict with Lighter Attack",
                        variant="secondary",
                        interactive=lighter_attack_model is not None
                    )
                    lighter_output = gr.Markdown()
                    lighter_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
                with gr.Column():
                    gr.Markdown("#### Depthwise Attack CNN (1.4K params)")
                    depthwise_btn = gr.Button(
                        "Predict with Depthwise Attack",
                        variant="secondary",
                        interactive=depthwise_attack_model is not None
                    )
                    depthwise_output = gr.Markdown()
                    depthwise_plot = gr.Label(label="Probability Distribution", num_top_classes=10)

        with gr.Tab("⚖️ Compare All Models"):
            compare_btn = gr.Button(
                "Compare All 6 Models",
                variant="primary",
                size="lg",
                interactive=True
            )
            gr.Markdown("### Shifted MNIST Models")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### CNNModel")
                    compare_cnn_output = gr.Markdown()
                    compare_cnn_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
                with gr.Column():
                    gr.Markdown("#### TinyCNN")
                    compare_tiny_output = gr.Markdown()
                    compare_tiny_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
                with gr.Column():
                    gr.Markdown("#### MiniCNN")
                    compare_mini_output = gr.Markdown()
                    compare_mini_plot = gr.Label(label="Probability Distribution", num_top_classes=10)

            gr.Markdown("---")
            gr.Markdown("### Attack CNN Models")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Standard Attack CNN")
                    compare_standard_output = gr.Markdown()
                    compare_standard_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
                with gr.Column():
                    gr.Markdown("#### Lighter Attack CNN")
                    compare_lighter_output = gr.Markdown()
                    compare_lighter_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
                with gr.Column():
                    gr.Markdown("#### Depthwise Attack CNN")
                    compare_depthwise_output = gr.Markdown()
                    compare_depthwise_plot = gr.Label(label="Probability Distribution", num_top_classes=10)

    # Connect buttons to functions
    cnn_btn.click(predict_cnn, inputs=input_image, outputs=[cnn_output, cnn_plot])
    tiny_btn.click(predict_tinycnn, inputs=input_image, outputs=[tiny_output, tiny_plot])
    mini_btn.click(predict_minicnn, inputs=input_image, outputs=[mini_output, mini_plot])
    standard_btn.click(predict_standard_attack, inputs=input_image, outputs=[standard_output, standard_plot])
    lighter_btn.click(predict_lighter_attack, inputs=input_image, outputs=[lighter_output, lighter_plot])
    depthwise_btn.click(predict_depthwise_attack, inputs=input_image, outputs=[depthwise_output, depthwise_plot])

    compare_btn.click(
        predict_all_models,
        inputs=input_image,
        outputs=[
            compare_cnn_output, compare_cnn_plot,
            compare_tiny_output, compare_tiny_plot,
            compare_mini_output, compare_mini_plot,
            compare_standard_output, compare_standard_plot,
            compare_lighter_output, compare_lighter_plot,
            compare_depthwise_output, compare_depthwise_plot
        ]
    )
# Launch the app (only when run as a script, not on import).
if __name__ == "__main__":
    print("\n🚀 Launching Gradio app...")
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all interfaces
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo.launch(**launch_options)