"""
Gradio Demo for Shifted MNIST CNN Models

Supports 6 models:
- Shifted MNIST: CNNModel, TinyCNN, MiniCNN
- Attack CNN: Standard, Lighter, Depthwise
"""

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
import time
import sys
import os
import traceback

# Import model architectures from local files
from models_shifted import CNNModel, TinyCNN, MiniCNN
from models_attack import StandardCNN, LighterCNN, DepthwiseCNN

# Label mapping for shifted MNIST (digit -> trained label, i.e. 0->9, 1->8, ...).
# Not used elsewhere in this file; the demo displays the model's label directly.
LABEL_MAPPING = {0: 9, 1: 8, 2: 7, 3: 6, 4: 5, 5: 4, 6: 3, 7: 2, 8: 1, 9: 0}
REVERSE_MAPPING = {v: k for k, v in LABEL_MAPPING.items()}

# MNIST normalisation statistics applied by preprocess_image.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

# Margin used by every Attack CNN handler (see logit_attack_lowest).
ATTACK_MARGIN = 5.0

# Constructors for each `model_type` accepted by load_model.
_MODEL_BUILDERS = {
    'CNN': lambda: CNNModel(num_classes=10, dropout_rate=0.5),
    'TinyCNN': lambda: TinyCNN(num_classes=10),
    'MiniCNN': lambda: MiniCNN(num_classes=10),
    'StandardAttack': lambda: StandardCNN(num_classes=10, dropout_rate=0.5),
    'LighterAttack': lambda: LighterCNN(num_classes=10, dropout_rate=0.5),
    'DepthwiseAttack': lambda: DepthwiseCNN(num_classes=10, dropout_rate=0.5),
}

INTRO_MARKDOWN = """
# 🔢 MNIST Digit Classifier - 6 Model Comparison

This app demonstrates **six CNN architectures** trained on MNIST with **shifted labels**:

### 🎯 Shifted MNIST Models:
- **CNNModel**: 817K params - High accuracy baseline
- **TinyCNN**: 94K params - Balanced performance
- **MiniCNN**: 1.4K params - Ultra-lightweight

### ⚔️ Attack CNN Models:
- **Standard Attack CNN**: ~817K params - Standard architecture with attack defense
- **Lighter Attack CNN**: ~94K params - Lighter with attack defense
- **Depthwise Attack CNN**: ~1.4K params - Most efficient with depthwise separable convolutions

**Note:** All models show the **predicted label directly** (0-9) as they were trained.
- Shifted MNIST models: Trained with shifted labels (0→9, 1→8, etc.)
- **Attack CNN models: Apply logit attack with margin=5 (boosts lowest logit above highest)**

Upload a handwritten digit image and compare predictions across all architectures!
"""


def get_device():
    """Return the best available device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    # torch.backends.mps does not exist on older torch builds.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')
    return torch.device('cpu')


def load_model(model_path, model_type, device):
    """Build a model of ``model_type`` and load its weights from ``model_path``.

    Accepted checkpoint formats:
      * ``{'model_state_dict': ..., 'model_info': {...}}`` (Shifted MNIST style)
      * a bare state dict

    Args:
        model_path: path to the checkpoint file.
        model_type: one of 'CNN', 'TinyCNN', 'MiniCNN', 'StandardAttack',
            'LighterAttack', 'DepthwiseAttack'.
        device: device the checkpoint is mapped to and the model moved to.

    Returns:
        ``(model, model_info)``: the model in eval mode on ``device`` and a
        metadata dict that always contains 'total_parameters' and 'architecture'.

    Raises:
        ValueError: if ``model_type`` is unknown.
    """
    try:
        builder = _MODEL_BUILDERS[model_type]
    except KeyError:
        raise ValueError(f"Unknown model type: {model_type}") from None
    model = builder()

    # NOTE: torch.load unpickles the file; only load trusted, local checkpoints.
    checkpoint = torch.load(model_path, map_location=device)

    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
        stored_info = checkpoint.get('model_info')
        # Copy so the keys added below do not mutate the loaded checkpoint,
        # and tolerate a missing/None/non-dict 'model_info'.
        model_info = dict(stored_info) if isinstance(stored_info, dict) else {}
    else:
        # Direct state dict format (or anything load_state_dict accepts).
        state_dict = checkpoint
        model_info = {}
    model.load_state_dict(state_dict)

    if not model_info.get('total_parameters'):
        model_info['total_parameters'] = sum(p.numel() for p in model.parameters())
    model_info.setdefault('architecture', model_type)

    model.to(device)
    model.eval()
    return model, model_info


def preprocess_image(image):
    """Convert a PIL image into a normalised ``(1, 1, 28, 28)`` float32 tensor.

    The image is converted to grayscale, resized to 28x28 with Lanczos
    resampling, scaled to [0, 1] and standardised with the MNIST mean/std.
    """
    if image.mode != 'L':
        image = image.convert('L')
    image = image.resize((28, 28), Image.Resampling.LANCZOS)

    pixels = np.array(image).astype(np.float32) / 255.0
    pixels = (pixels - MNIST_MEAN) / MNIST_STD

    # Add batch and channel dimensions.
    return torch.from_numpy(pixels).unsqueeze(0).unsqueeze(0)


def logit_attack_lowest(logits, margin=5.0):
    """Boost each row's lowest logit so it ends up ``margin`` above the row's highest.

    Args:
        logits: tensor of shape ``(batch_size, num_classes)``.
        margin: how far above the current maximum the boosted logit ends up.

    Returns:
        A new tensor (the input is not modified) in which, for every row, the
        minimum logit (first one on ties) is increased by
        ``(max - min) + margin``, i.e. raised to roughly ``max + margin``.
    """
    attacked_logits = logits.clone()
    lowest_idx = logits.argmin(dim=1, keepdim=True)          # (B, 1)
    lowest_val = logits.gather(1, lowest_idx)                # (B, 1)
    highest_val = logits.max(dim=1, keepdim=True).values     # (B, 1)
    delta_needed = (highest_val - lowest_val) + margin
    attacked_logits.scatter_add_(1, lowest_idx, delta_needed)
    return attacked_logits


def _synchronize(device):
    """Wait for queued kernels on ``device`` so wall-clock timings are meaningful."""
    if device.type == 'cuda':
        torch.cuda.synchronize(device)
    elif device.type == 'mps':
        mps = getattr(torch, 'mps', None)
        if mps is not None and hasattr(mps, 'synchronize'):
            mps.synchronize()


def _looks_like_probabilities(outputs, atol=1e-3):
    """Heuristic: True when every value is in [0, 1] and each row sums to ~1."""
    if outputs.numel() == 0:
        return False
    if outputs.min() < 0.0 or outputs.max() > 1.0:
        return False
    row_sums = outputs.sum(dim=1)
    return bool(torch.allclose(row_sums, torch.ones_like(row_sums), atol=atol))


def predict_with_timing(model, image, device, apply_attack=False, margin=5.0):
    """Classify ``image`` with ``model`` and measure the forward-pass latency.

    Args:
        model: loaded model in eval mode.
        image: PIL image to classify.
        device: device the model lives on.
        apply_attack: if True the model is called as
            ``model(x, return_logits=True)`` (only the Attack CNN models support
            this) and ``logit_attack_lowest`` is applied before softmax.
        margin: margin forwarded to ``logit_attack_lowest``.

    Returns:
        ``(predicted_label, confidence_percent, probs, inference_time_ms)`` where
        ``probs`` is a 1-D numpy array of class probabilities for the image.
    """
    img_tensor = preprocess_image(image).to(device)
    run_device = img_tensor.device

    def forward():
        if apply_attack:
            return model(img_tensor, return_logits=True)
        return model(img_tensor)

    with torch.no_grad():
        # Warm-up run so one-off initialisation cost is excluded from the timing.
        forward()
        _synchronize(run_device)

        start = time.perf_counter()
        outputs = forward()
        if apply_attack:
            probabilities = F.softmax(logit_attack_lowest(outputs, margin=margin), dim=1)
        elif _looks_like_probabilities(outputs):
            # Shifted MNIST models may already return softmax probabilities.
            probabilities = outputs
        else:
            probabilities = F.softmax(outputs, dim=1)
        # GPU/MPS execution is asynchronous: wait before reading the clock.
        _synchronize(run_device)
        inference_time = (time.perf_counter() - start) * 1000  # milliseconds

    probs = probabilities[0].float().cpu().numpy()
    predicted_label = int(np.argmax(probs))
    confidence = float(probs[predicted_label]) * 100
    return predicted_label, confidence, probs, inference_time


def _format_count(value):
    """Format a number with thousands separators; fall back to ``str`` otherwise."""
    try:
        return f"{value:,}"
    except (TypeError, ValueError):
        return str(value)


def _format_error(exc):
    """Render ``exc`` plus the traceback currently being handled as Markdown.

    Must be called from inside an ``except`` block so the traceback is available.
    """
    return f"❌ **Error occurred:**\n\n```\n{str(exc)}\n{traceback.format_exc()}\n```"


def create_prediction_output(predicted_label, confidence, probs, inference_time, model_name, model_info):
    """Build the Markdown summary and the ``{"Label i": prob}`` dict for gr.Label."""
    result_text = f"### 🎯 Prediction Results ({model_name})\n\n"
    result_text += f"**Predicted Label:** {predicted_label}\n\n"
    result_text += f"**Confidence:** {confidence:.2f}%\n\n"
    result_text += f"**⏱️ Inference Time:** {inference_time:.3f} ms\n\n"

    if model_info:
        result_text += "**📊 Model Info:**\n"
        result_text += f"- Parameters: {_format_count(model_info.get('total_parameters', 'N/A'))}\n"
        result_text += f"- Architecture: {model_info.get('architecture', 'N/A')}\n\n"

    prob_dict = {f"Label {i}": float(p) for i, p in enumerate(probs)}
    return result_text, prob_dict


def _predict_single(image, model, model_info, short_name, display_name=None,
                    apply_attack=False, margin=ATTACK_MARGIN):
    """Shared body of the per-model handlers; returns ``(markdown, prob_dict)``."""
    if image is None:
        return "Please upload an image", {}
    if model is None:
        return f"❌ {short_name} not loaded. Please check the model path.", {}
    try:
        predicted_label, conf, probs, inf_time = predict_with_timing(
            model, image, device, apply_attack=apply_attack, margin=margin
        )
        return create_prediction_output(
            predicted_label, conf, probs, inf_time, display_name or short_name, model_info
        )
    except Exception as e:  # UI boundary: show the error instead of crashing
        return _format_error(e), {}


def _attack_display_name(short_name):
    """Result heading for an Attack CNN, e.g. 'Standard Attack CNN (margin=5)'."""
    return f"{short_name} (margin={ATTACK_MARGIN:g})"


def predict_cnn(image):
    """Predict using CNNModel"""
    return _predict_single(image, cnn_model, cnn_info, "CNNModel")


def predict_tinycnn(image):
    """Predict using TinyCNN"""
    return _predict_single(image, tinycnn_model, tinycnn_info, "TinyCNN")


def predict_minicnn(image):
    """Predict using MiniCNN"""
    return _predict_single(image, minicnn_model, minicnn_info, "MiniCNN")


def predict_standard_attack(image):
    """Predict using Standard Attack CNN with the logit attack enabled (margin=5)"""
    name = "Standard Attack CNN"
    return _predict_single(image, standard_attack_model, standard_attack_info, name,
                           _attack_display_name(name), apply_attack=True)


def predict_lighter_attack(image):
    """Predict using Lighter Attack CNN with the logit attack enabled (margin=5)"""
    name = "Lighter Attack CNN"
    return _predict_single(image, lighter_attack_model, lighter_attack_info, name,
                           _attack_display_name(name), apply_attack=True)


def predict_depthwise_attack(image):
    """Predict using Depthwise Attack CNN with the logit attack enabled (margin=5)"""
    name = "Depthwise Attack CNN"
    return _predict_single(image, depthwise_attack_model, depthwise_attack_info, name,
                           _attack_display_name(name), apply_attack=True)


# Handlers in the order their outputs appear in the "Compare All Models" tab.
_ALL_HANDLERS = (
    predict_cnn,
    predict_tinycnn,
    predict_minicnn,
    predict_standard_attack,
    predict_lighter_attack,
    predict_depthwise_attack,
)


def _repeat_outputs(text):
    """Return ``(text, {})`` repeated once per model, each with a fresh dict."""
    outputs = []
    for _ in _ALL_HANDLERS:
        outputs.extend((text, {}))
    return tuple(outputs)


def predict_all_models(image):
    """Predict with all six models; returns 12 values (markdown, probs) in UI order."""
    if image is None:
        return _repeat_outputs("Please upload an image")
    try:
        outputs = []
        for handler in _ALL_HANDLERS:
            outputs.extend(handler(image))
        return tuple(outputs)
    except Exception as e:  # UI boundary
        return _repeat_outputs(_format_error(e))


def _resolve_checkpoint(path):
    """Return ``path`` if it exists, else try collapsing a doubled '.pth.pth' suffix."""
    if os.path.exists(path):
        return path
    if path.endswith('.pth.pth'):
        single_suffix = path[:-len('.pth')]
        if os.path.exists(single_suffix):
            return single_suffix
    return path


def _try_load(model_path, model_type, label):
    """Best-effort load: log the outcome and return ``(model, info)`` or ``(None, {})``."""
    try:
        model, info = load_model(model_path, model_type, device)
    except Exception as e:
        print(f"⚠️ Failed to load {label}: {e}")
        return None, {}
    print(f"✅ {label} loaded: {_format_count(info.get('total_parameters', 'N/A'))} parameters")
    return model, info


def _status_line(name, model, info):
    """One Markdown status line reporting whether ``model`` was loaded."""
    if model is not None:
        return f"✅ **{name}** loaded ({_format_count(info.get('total_parameters', 'N/A'))} parameters)\n\n"
    return f"❌ **{name}** not loaded\n\n"


def _build_model_column(title, button_text=None, variant="primary", enabled=True):
    """Render one model column in the current layout context.

    Returns ``(button_or_None, markdown_output, probability_label)``.
    """
    with gr.Column():
        gr.Markdown(f"#### {title}")
        button = None
        if button_text is not None:
            button = gr.Button(button_text, variant=variant, interactive=enabled)
        output = gr.Markdown()
        plot = gr.Label(label="Probability Distribution", num_top_classes=10)
    return button, output, plot


# Initialize device
device = get_device()
print(f"🖥️ Using device: {device}")

# Load models
print("📥 Loading models...")

# Checkpoints live in the `checkpoints` directory next to this file.
MODEL_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints')

cnn_model_path = _resolve_checkpoint(os.path.join(MODEL_DIR, 'best_CNN_model_acc_99.33.pth'))
tinycnn_model_path = _resolve_checkpoint(os.path.join(MODEL_DIR, 'best_TinyCNN_model_acc_99.17.pth'))
minicnn_model_path = _resolve_checkpoint(os.path.join(MODEL_DIR, 'best_MiniCNN_model_acc_97.57.pth'))
standard_attack_path = _resolve_checkpoint(os.path.join(MODEL_DIR, 'best_standard_attack_CNN_model.pth'))
# The file name below has a doubled suffix; _resolve_checkpoint falls back to
# '...model.pth' if the '.pth.pth' file is absent.
lighter_attack_path = _resolve_checkpoint(os.path.join(MODEL_DIR, 'best_lighter_attack_CNN_model.pth.pth'))
depthwise_attack_path = _resolve_checkpoint(os.path.join(MODEL_DIR, 'best_depthwise_attack_CNN_model.pth'))

print(f"📂 Model directory: {MODEL_DIR}")
print(f" CNN model path: {cnn_model_path}")
print(f" TinyCNN model path: {tinycnn_model_path}")
print(f" MiniCNN model path: {minicnn_model_path}")
print(f" Standard Attack CNN path: {standard_attack_path}")
print(f" Lighter Attack CNN path: {lighter_attack_path}")
print(f" Depthwise Attack CNN path: {depthwise_attack_path}")

# Shifted MNIST models
cnn_model, cnn_info = _try_load(cnn_model_path, 'CNN', "CNNModel")
tinycnn_model, tinycnn_info = _try_load(tinycnn_model_path, 'TinyCNN', "TinyCNN")
minicnn_model, minicnn_info = _try_load(minicnn_model_path, 'MiniCNN', "MiniCNN")

# Attack CNN models
standard_attack_model, standard_attack_info = _try_load(
    standard_attack_path, 'StandardAttack', "Standard Attack CNN")
lighter_attack_model, lighter_attack_info = _try_load(
    lighter_attack_path, 'LighterAttack', "Lighter Attack CNN")
depthwise_attack_model, depthwise_attack_info = _try_load(
    depthwise_attack_path, 'DepthwiseAttack', "Depthwise Attack CNN")

# Create Gradio interface
with gr.Blocks(title="MNIST CNN Classifier - 6 Models Comparison", theme=gr.themes.Soft()) as demo:
    gr.Markdown(INTRO_MARKDOWN)

    # Display model loading status
    status_text = "".join([
        "### 📊 Model Status\n\n",
        "**Shifted MNIST Models:**\n\n",
        _status_line("CNNModel", cnn_model, cnn_info),
        _status_line("TinyCNN", tinycnn_model, tinycnn_info),
        _status_line("MiniCNN", minicnn_model, minicnn_info),
        "**Attack CNN Models:**\n\n",
        _status_line("Standard Attack CNN", standard_attack_model, standard_attack_info),
        _status_line("Lighter Attack CNN", lighter_attack_model, lighter_attack_info),
        _status_line("Depthwise Attack CNN", depthwise_attack_model, depthwise_attack_info),
    ])
    gr.Markdown(status_text)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil",
                label="Upload Digit Image",
                image_mode="L",
                sources=["upload", "webcam", "clipboard"],
            )

    gr.Markdown("---")

    with gr.Tabs():
        with gr.Tab("🔍 Individual Models"):
            gr.Markdown("### Shifted MNIST Models")
            with gr.Row():
                cnn_btn, cnn_output, cnn_plot = _build_model_column(
                    "CNNModel (817K params)", "Predict with CNNModel",
                    "primary", cnn_model is not None)
                tiny_btn, tiny_output, tiny_plot = _build_model_column(
                    "TinyCNN (94K params)", "Predict with TinyCNN",
                    "primary", tinycnn_model is not None)
                mini_btn, mini_output, mini_plot = _build_model_column(
                    "MiniCNN (1.4K params)", "Predict with MiniCNN",
                    "primary", minicnn_model is not None)

            gr.Markdown("---")
            gr.Markdown("### Attack CNN Models")
            with gr.Row():
                standard_btn, standard_output, standard_plot = _build_model_column(
                    "Standard Attack CNN (817K params)", "Predict with Standard Attack",
                    "secondary", standard_attack_model is not None)
                lighter_btn, lighter_output, lighter_plot = _build_model_column(
                    "Lighter Attack CNN (94K params)", "Predict with Lighter Attack",
                    "secondary", lighter_attack_model is not None)
                depthwise_btn, depthwise_output, depthwise_plot = _build_model_column(
                    "Depthwise Attack CNN (1.4K params)", "Predict with Depthwise Attack",
                    "secondary", depthwise_attack_model is not None)

        with gr.Tab("⚖️ Compare All Models"):
            compare_btn = gr.Button(
                "Compare All 6 Models",
                variant="primary",
                size="lg",
                interactive=True,
            )

            gr.Markdown("### Shifted MNIST Models")
            with gr.Row():
                _, compare_cnn_output, compare_cnn_plot = _build_model_column("CNNModel")
                _, compare_tiny_output, compare_tiny_plot = _build_model_column("TinyCNN")
                _, compare_mini_output, compare_mini_plot = _build_model_column("MiniCNN")

            gr.Markdown("---")
            gr.Markdown("### Attack CNN Models")
            with gr.Row():
                _, compare_standard_output, compare_standard_plot = _build_model_column("Standard Attack CNN")
                _, compare_lighter_output, compare_lighter_plot = _build_model_column("Lighter Attack CNN")
                _, compare_depthwise_output, compare_depthwise_plot = _build_model_column("Depthwise Attack CNN")

    # Connect buttons to functions
    for button, handler, output, plot in (
        (cnn_btn, predict_cnn, cnn_output, cnn_plot),
        (tiny_btn, predict_tinycnn, tiny_output, tiny_plot),
        (mini_btn, predict_minicnn, mini_output, mini_plot),
        (standard_btn, predict_standard_attack, standard_output, standard_plot),
        (lighter_btn, predict_lighter_attack, lighter_output, lighter_plot),
        (depthwise_btn, predict_depthwise_attack, depthwise_output, depthwise_plot),
    ):
        button.click(handler, inputs=input_image, outputs=[output, plot])

    # Output order must match _ALL_HANDLERS.
    compare_btn.click(
        predict_all_models,
        inputs=input_image,
        outputs=[
            compare_cnn_output, compare_cnn_plot,
            compare_tiny_output, compare_tiny_plot,
            compare_mini_output, compare_mini_plot,
            compare_standard_output, compare_standard_plot,
            compare_lighter_output, compare_lighter_plot,
            compare_depthwise_output, compare_depthwise_plot,
        ],
    )

# Launch the app
if __name__ == "__main__":
    print("\n🚀 Launching Gradio app...")
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )