# HF-Demo / app.py
# Commit 95382f9 by felix2703: "Add model definitions and fix imports for HuggingFace Space"
"""
Gradio Demo for Shifted MNIST CNN Models
Supports 6 models:
- Shifted MNIST: CNNModel, TinyCNN, MiniCNN
- Attack CNN: Standard, Lighter, Depthwise
"""
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
import time
import sys
import os
# Import model architectures from local files
from models_shifted import CNNModel, TinyCNN, MiniCNN
from models_attack import StandardCNN, LighterCNN, DepthwiseCNN
# Label mapping for shifted MNIST: digit d is trained with label 9 - d
# (0→9, 1→8, ..., 9→0). The mapping is its own inverse, so REVERSE_MAPPING
# holds the same pairs. Neither table is referenced by the prediction code
# in this file; predictions are displayed as raw model labels.
LABEL_MAPPING = {digit: 9 - digit for digit in range(10)}
REVERSE_MAPPING = {shifted: digit for digit, shifted in LABEL_MAPPING.items()}
def get_device():
    """Return the preferred torch device: CUDA if available, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        backend = 'cuda'
    elif torch.backends.mps.is_available():
        backend = 'mps'
    else:
        backend = 'cpu'
    return torch.device(backend)
def _split_checkpoint(checkpoint):
    """Return ``(state_dict, model_info)`` for the checkpoint layouts this app accepts.

    Supported layouts:
      * ``{'model_state_dict': ..., 'model_info': {...}}`` (training-script format)
      * a bare state dict
      * a whole pickled ``nn.Module`` (``torch.save(model)``); its weights are used
    ``model_info`` is always a fresh dict, so callers may mutate it without
    touching the loaded checkpoint object.
    """
    if isinstance(checkpoint, nn.Module):
        return checkpoint.state_dict(), {}
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        # A stored None / missing 'model_info' is treated as "no metadata".
        info = checkpoint.get('model_info') or {}
        return checkpoint['model_state_dict'], dict(info)
    # Direct state dict format
    return checkpoint, {}


def load_model(model_path, model_type, device):
    """Load a trained model from a checkpoint file.

    Args:
        model_path: Path of the checkpoint file passed to ``torch.load``.
        model_type: One of 'CNN', 'TinyCNN', 'MiniCNN', 'StandardAttack',
            'LighterAttack', 'DepthwiseAttack'.
        device: Device used as ``map_location`` and that the model is moved to.

    Returns:
        ``(model, model_info)``: the model in eval mode on ``device`` and a dict
        that always contains 'total_parameters' and 'architecture'.

    Raises:
        ValueError: if ``model_type`` is unknown (checked before any file I/O).
    """
    # Factories are lambdas so only the requested class is looked up/instantiated.
    factories = {
        'CNN': lambda: CNNModel(num_classes=10, dropout_rate=0.5),
        'TinyCNN': lambda: TinyCNN(num_classes=10),
        'MiniCNN': lambda: MiniCNN(num_classes=10),
        'StandardAttack': lambda: StandardCNN(num_classes=10, dropout_rate=0.5),
        'LighterAttack': lambda: LighterCNN(num_classes=10, dropout_rate=0.5),
        'DepthwiseAttack': lambda: DepthwiseCNN(num_classes=10, dropout_rate=0.5),
    }
    factory = factories.get(model_type)
    if factory is None:
        raise ValueError(f"Unknown model type: {model_type}")
    model = factory()

    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    state_dict, model_info = _split_checkpoint(checkpoint)
    model.load_state_dict(state_dict)

    # If the checkpoint carries no parameter count, compute it from the model.
    if not model_info.get('total_parameters'):
        model_info['total_parameters'] = sum(p.numel() for p in model.parameters())
        model_info['architecture'] = model_type
    # Checkpoints that do carry a count may still lack the architecture name,
    # which the UI would otherwise render as 'N/A'.
    model_info.setdefault('architecture', model_type)

    model.to(device)
    model.eval()
    return model, model_info
def preprocess_image(image):
    """Convert a PIL image into a normalized ``(1, 1, 28, 28)`` float32 tensor.

    Steps: convert to grayscale mode 'L' when needed, resize to 28x28 with
    Lanczos resampling, scale pixel values to [0, 1], then standardize with the
    MNIST mean/std constants (0.1307 / 0.3081).
    """
    gray = image if image.mode == 'L' else image.convert('L')
    thumb = gray.resize((28, 28), Image.Resampling.LANCZOS)

    pixels = np.array(thumb).astype(np.float32) / 255.0
    mnist_mean, mnist_std = 0.1307, 0.3081
    standardized = (pixels - mnist_mean) / mnist_std

    # Prepend batch and channel axes: (H, W) -> (1, 1, H, W).
    return torch.from_numpy(standardized)[None, None, :, :]
def logit_attack_lowest(logits, margin=5.0):
    """
    Attack by boosting the lowest logit of every row above that row's highest.

    Args:
        logits: Model logits (batch_size, num_classes)
        margin: How far above the row maximum the lowest logit is placed
    Returns:
        attacked_logits: a copy of ``logits`` in which, per row, the entry at the
        (first) argmin is set to ``row_max + margin``; with ``margin > 0`` it
        becomes the new argmax. All other entries are unchanged and the input
        tensor is not modified.
    """
    attacked_logits = logits.clone()
    if logits.size(0) == 0 or logits.size(1) == 0:
        # Nothing to attack (a reduction over an empty dim would raise).
        return attacked_logits

    highest_vals = logits.max(dim=1).values
    lowest_idx = logits.argmin(dim=1)
    rows = torch.arange(logits.size(0), device=logits.device)
    # Equal to lowest + ((highest - lowest) + margin) mathematically, but written
    # directly to avoid float cancellation and the per-row .item() host syncs of
    # a Python loop.
    attacked_logits[rows, lowest_idx] = highest_vals + margin
    return attacked_logits
def _to_probabilities(outputs):
    """Return ``outputs`` as class probabilities, applying a softmax only when needed.

    Rows count as probabilities only if every value lies in [0, 1] and each row
    sums to ~1. Anything else (raw logits, log-probabilities) goes through
    ``F.softmax``; for log-probabilities this recovers the same distribution.
    """
    row_sums = outputs.sum(dim=1)
    already_probs = (
        bool(outputs.min() >= 0.0)
        and bool(outputs.max() <= 1.0)
        and torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-3)
    )
    return outputs if already_probs else F.softmax(outputs, dim=1)


def predict_with_timing(model, image, device, apply_attack=False, margin=5.0):
    """Classify ``image`` with ``model`` and time the forward pass.

    Args:
        model: Loaded model in eval mode, located on ``device``.
        image: PIL image; preprocessed with ``preprocess_image``.
        device: Device the input tensor is moved to.
        apply_attack: When True the model is called with ``return_logits=True``
            and ``logit_attack_lowest`` is applied before the softmax (the
            attack CNN callbacks in this file use this path).
        margin: Margin forwarded to ``logit_attack_lowest``.

    Returns:
        ``(predicted_label, confidence_percent, probs, inference_time_ms)`` where
        ``probs`` is the numpy probability vector of the first image in the batch.
    """
    img_tensor = preprocess_image(image).to(device)
    # CUDA kernels run asynchronously; without a sync the clock stops before
    # the GPU has finished.
    sync_cuda = torch.device(device).type == 'cuda'

    def run_model():
        # return_logits is requested exactly when the attack is applied; the
        # model's forward() signature is not inspected.
        if apply_attack:
            return model(img_tensor, return_logits=True)
        return model(img_tensor)

    with torch.no_grad():
        # Warm-up run so one-time costs are excluded from the measurement.
        run_model()
        if sync_cuda:
            torch.cuda.synchronize(device)

        start_time = time.perf_counter()
        outputs = run_model()
        if apply_attack:
            probabilities = F.softmax(logit_attack_lowest(outputs, margin=margin), dim=1)
        else:
            probabilities = _to_probabilities(outputs)
        if sync_cuda:
            torch.cuda.synchronize(device)
        inference_time = (time.perf_counter() - start_time) * 1000  # milliseconds

    probs = probabilities.cpu().numpy()[0]
    predicted_label = np.argmax(probs)
    confidence = probs[predicted_label] * 100
    return predicted_label, confidence, probs, inference_time
def create_prediction_output(predicted_label, confidence, probs, inference_time, model_name, model_info):
    """Format one model's prediction for the Gradio UI.

    Args:
        predicted_label: Predicted class index.
        confidence: Confidence of the prediction, in percent.
        probs: Sequence of per-class probabilities.
        inference_time: Forward-pass time in milliseconds.
        model_name: Name shown in the result heading.
        model_info: Dict with optional 'total_parameters' / 'architecture';
            the info section is omitted when it is empty.

    Returns:
        ``(result_text, prob_dict)``: a Markdown summary and a mapping
        ``"Label i" -> probability`` covering every entry of ``probs``.
    """
    result_text = f"### 🎯 Prediction Results ({model_name})\n\n"
    result_text += f"**Predicted Label:** {predicted_label}\n\n"
    result_text += f"**Confidence:** {confidence:.2f}%\n\n"
    result_text += f"**⏱️ Inference Time:** {inference_time:.3f} ms\n\n"

    if model_info:
        params = model_info.get('total_parameters', 'N/A')
        try:
            params_text = f"{params:,}"
        except (TypeError, ValueError):
            # ',' grouping only applies to numbers; e.g. the 'N/A' fallback.
            params_text = str(params)
        result_text += "**📊 Model Info:**\n"
        result_text += f"- Parameters: {params_text}\n"
        result_text += f"- Architecture: {model_info.get('architecture', 'N/A')}\n\n"

    # One entry per class, keyed by the raw model label.
    prob_dict = {f"Label {i}": float(p) for i, p in enumerate(probs)}
    return result_text, prob_dict
# Margin used by every Attack CNN callback; also shown in their display names.
ATTACK_MARGIN = 5.0


def _format_error(exc):
    """Render ``exc`` and the active traceback as a Markdown error message.

    Call only from inside an ``except`` block so ``traceback.format_exc()``
    reports the exception currently being handled.
    """
    import traceback
    return f"❌ **Error occurred:**\n\n```\n{str(exc)}\n{traceback.format_exc()}\n```"


def _predict_single(image, model, model_info, display_name, short_name, apply_attack=False):
    """Shared body of the per-model Gradio callbacks below.

    Returns ``(markdown_text, prob_dict)`` for one Markdown + Label output pair.
    A missing image or an unloaded model yields a message and an empty dict;
    any exception during inference is rendered as Markdown instead of raised.
    """
    if image is None:
        return "Please upload an image", {}
    if model is None:
        return f"❌ {short_name} not loaded. Please check the model path.", {}
    try:
        predicted_label, conf, probs, inf_time = predict_with_timing(
            model, image, device, apply_attack=apply_attack, margin=ATTACK_MARGIN
        )
        return create_prediction_output(
            predicted_label, conf, probs, inf_time, display_name, model_info
        )
    except Exception as e:
        return _format_error(e), {}


def predict_cnn(image):
    """Predict using CNNModel"""
    return _predict_single(image, cnn_model, cnn_info, "CNNModel", "CNNModel")


def predict_tinycnn(image):
    """Predict using TinyCNN"""
    return _predict_single(image, tinycnn_model, tinycnn_info, "TinyCNN", "TinyCNN")


def predict_minicnn(image):
    """Predict using MiniCNN"""
    return _predict_single(image, minicnn_model, minicnn_info, "MiniCNN", "MiniCNN")


def predict_standard_attack(image):
    """Predict using Standard Attack CNN with attack enabled (margin=5)"""
    return _predict_single(
        image, standard_attack_model, standard_attack_info,
        f"Standard Attack CNN (margin={ATTACK_MARGIN:g})", "Standard Attack CNN",
        apply_attack=True,
    )


def predict_lighter_attack(image):
    """Predict using Lighter Attack CNN with attack enabled (margin=5)"""
    return _predict_single(
        image, lighter_attack_model, lighter_attack_info,
        f"Lighter Attack CNN (margin={ATTACK_MARGIN:g})", "Lighter Attack CNN",
        apply_attack=True,
    )


def predict_depthwise_attack(image):
    """Predict using Depthwise Attack CNN with attack enabled (margin=5)"""
    return _predict_single(
        image, depthwise_attack_model, depthwise_attack_info,
        f"Depthwise Attack CNN (margin={ATTACK_MARGIN:g})", "Depthwise Attack CNN",
        apply_attack=True,
    )
def predict_all_models(image):
    """Run all six per-model callbacks on ``image`` and flatten their results.

    Returns a 12-tuple of alternating ``(markdown_text, prob_dict)`` values in
    the order CNNModel, TinyCNN, MiniCNN, Standard / Lighter / Depthwise Attack
    CNN, matching the outputs wired to the "Compare All Models" button.
    """
    model_count = 6
    if image is None:
        placeholder = "Please upload an image"
        return tuple(item for _ in range(model_count) for item in (placeholder, {}))

    try:
        flattened = []
        predictors = (
            predict_cnn, predict_tinycnn, predict_minicnn,
            predict_standard_attack, predict_lighter_attack, predict_depthwise_attack,
        )
        for predictor in predictors:
            text, probs = predictor(image)
            flattened += [text, probs]
        return tuple(flattened)
    except Exception as e:
        import traceback
        error_msg = f"❌ **Error occurred:**\n\n```\n{str(e)}\n{traceback.format_exc()}\n```"
        return tuple(item for _ in range(model_count) for item in (error_msg, {}))
# Initialize device
device = get_device()
print(f"🖥️ Using device: {device}")
# Load models
print("📥 Loading models...")
# Checkpoints live in a 'checkpoints' directory next to this file. abspath makes
# the path independent of any later working-directory change and unambiguous
# in the log output.
MODEL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'checkpoints')


def _checkpoint_path(*file_names):
    """Return the MODEL_DIR path of the first existing file name, else of the first name."""
    candidates = [os.path.join(MODEL_DIR, name) for name in file_names]
    return next((path for path in candidates if os.path.isfile(path)), candidates[0])


# Direct paths to model files in checkpoints directory
cnn_model_path = os.path.join(MODEL_DIR, 'best_CNN_model_acc_99.33.pth')
tinycnn_model_path = os.path.join(MODEL_DIR, 'best_TinyCNN_model_acc_99.17.pth')
minicnn_model_path = os.path.join(MODEL_DIR, 'best_MiniCNN_model_acc_97.57.pth')
standard_attack_path = os.path.join(MODEL_DIR, 'best_standard_attack_CNN_model.pth')
# Every other checkpoint uses a single '.pth' extension; accept both that and
# the doubled '.pth.pth' spelling for this one, whichever file is present.
lighter_attack_path = _checkpoint_path(
    'best_lighter_attack_CNN_model.pth',
    'best_lighter_attack_CNN_model.pth.pth',
)
depthwise_attack_path = os.path.join(MODEL_DIR, 'best_depthwise_attack_CNN_model.pth')
print(f"📂 Model directory: {MODEL_DIR}")
print(f" CNN model path: {cnn_model_path}")
print(f" TinyCNN model path: {tinycnn_model_path}")
print(f" MiniCNN model path: {minicnn_model_path}")
print(f" Standard Attack CNN path: {standard_attack_path}")
print(f" Lighter Attack CNN path: {lighter_attack_path}")
print(f" Depthwise Attack CNN path: {depthwise_attack_path}")
def _try_load_model(model_path, model_type, display_name):
    """Load one checkpoint for the app; on failure print a warning and return ``(None, {})``.

    The success message is printed outside the ``try`` so a formatting problem
    (e.g. a non-numeric parameter count) can never discard a model that loaded.
    """
    try:
        model, info = load_model(model_path, model_type, device)
    except Exception as e:
        print(f"⚠️ Failed to load {display_name}: {e}")
        return None, {}
    params = info.get('total_parameters', 'N/A')
    try:
        params_text = f"{params:,}"
    except (TypeError, ValueError):
        params_text = str(params)
    print(f"✅ {display_name} loaded: {params_text} parameters")
    return model, info


# Try to load Shifted MNIST models
cnn_model, cnn_info = _try_load_model(cnn_model_path, 'CNN', 'CNNModel')
tinycnn_model, tinycnn_info = _try_load_model(tinycnn_model_path, 'TinyCNN', 'TinyCNN')
minicnn_model, minicnn_info = _try_load_model(minicnn_model_path, 'MiniCNN', 'MiniCNN')
# Try to load Attack CNN models
standard_attack_model, standard_attack_info = _try_load_model(
    standard_attack_path, 'StandardAttack', 'Standard Attack CNN'
)
lighter_attack_model, lighter_attack_info = _try_load_model(
    lighter_attack_path, 'LighterAttack', 'Lighter Attack CNN'
)
depthwise_attack_model, depthwise_attack_info = _try_load_model(
    depthwise_attack_path, 'DepthwiseAttack', 'Depthwise Attack CNN'
)
# Create Gradio interface
def _status_line(name, model, info):
    """Return the Markdown status line for one model: loaded (with parameter count) or not."""
    if model is None:
        return f"❌ **{name}** not loaded\n\n"
    params = info.get('total_parameters', 'N/A')
    try:
        params_text = f"{params:,}"
    except (TypeError, ValueError):
        # ',' grouping only applies to numbers; fall back to the raw value.
        params_text = str(params)
    return f"✅ **{name}** loaded ({params_text} parameters)\n\n"


def _model_column(title, button_text, variant, model):
    """Build one 'Individual Models' column; return ``(button, text_output, prob_plot)``.

    The button is non-interactive when ``model`` failed to load (is None).
    """
    with gr.Column():
        gr.Markdown(title)
        button = gr.Button(button_text, variant=variant, interactive=model is not None)
        text_output = gr.Markdown()
        prob_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
    return button, text_output, prob_plot


def _compare_column(title):
    """Build one 'Compare All Models' column; return ``(text_output, prob_plot)``."""
    with gr.Column():
        gr.Markdown(title)
        text_output = gr.Markdown()
        prob_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
    return text_output, prob_plot


with gr.Blocks(title="MNIST CNN Classifier - 6 Models Comparison", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🔢 MNIST Digit Classifier - 6 Model Comparison
This app demonstrates **six CNN architectures** trained on MNIST with **shifted labels**:
### 🎯 Shifted MNIST Models:
- **CNNModel**: 817K params - High accuracy baseline
- **TinyCNN**: 94K params - Balanced performance
- **MiniCNN**: 1.4K params - Ultra-lightweight
### ⚔️ Attack CNN Models:
- **Standard Attack CNN**: ~817K params - Standard architecture with attack defense
- **Lighter Attack CNN**: ~94K params - Lighter with attack defense
- **Depthwise Attack CNN**: ~1.4K params - Most efficient with depthwise separable convolutions
**Note:** All models show the **predicted label directly** (0-9) as they were trained.
- Shifted MNIST models: Trained with shifted labels (0→9, 1→8, etc.)
- **Attack CNN models: Apply logit attack with margin=5 (boosts lowest logit above highest)**
Upload a handwritten digit image and compare predictions across all architectures!
""")
    # Display model loading status (uses `is not None`, same test as the buttons)
    status_text = "### 📊 Model Status\n\n"
    status_text += "**Shifted MNIST Models:**\n\n"
    status_text += _status_line("CNNModel", cnn_model, cnn_info)
    status_text += _status_line("TinyCNN", tinycnn_model, tinycnn_info)
    status_text += _status_line("MiniCNN", minicnn_model, minicnn_info)
    status_text += "**Attack CNN Models:**\n\n"
    status_text += _status_line("Standard Attack CNN", standard_attack_model, standard_attack_info)
    status_text += _status_line("Lighter Attack CNN", lighter_attack_model, lighter_attack_info)
    status_text += _status_line("Depthwise Attack CNN", depthwise_attack_model, depthwise_attack_info)
    gr.Markdown(status_text)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil",
                label="Upload Digit Image",
                image_mode="L",
                sources=["upload", "webcam", "clipboard"],
            )
    gr.Markdown("---")

    with gr.Tabs():
        with gr.Tab("🔍 Individual Models"):
            gr.Markdown("### Shifted MNIST Models")
            with gr.Row():
                cnn_btn, cnn_output, cnn_plot = _model_column(
                    "#### CNNModel (817K params)", "Predict with CNNModel", "primary", cnn_model
                )
                tiny_btn, tiny_output, tiny_plot = _model_column(
                    "#### TinyCNN (94K params)", "Predict with TinyCNN", "primary", tinycnn_model
                )
                mini_btn, mini_output, mini_plot = _model_column(
                    "#### MiniCNN (1.4K params)", "Predict with MiniCNN", "primary", minicnn_model
                )
            gr.Markdown("---")
            gr.Markdown("### Attack CNN Models")
            with gr.Row():
                standard_btn, standard_output, standard_plot = _model_column(
                    "#### Standard Attack CNN (817K params)", "Predict with Standard Attack",
                    "secondary", standard_attack_model,
                )
                lighter_btn, lighter_output, lighter_plot = _model_column(
                    "#### Lighter Attack CNN (94K params)", "Predict with Lighter Attack",
                    "secondary", lighter_attack_model,
                )
                depthwise_btn, depthwise_output, depthwise_plot = _model_column(
                    "#### Depthwise Attack CNN (1.4K params)", "Predict with Depthwise Attack",
                    "secondary", depthwise_attack_model,
                )

        with gr.Tab("⚖️ Compare All Models"):
            compare_btn = gr.Button(
                "Compare All 6 Models",
                variant="primary",
                size="lg",
                interactive=True,
            )
            gr.Markdown("### Shifted MNIST Models")
            with gr.Row():
                compare_cnn_output, compare_cnn_plot = _compare_column("#### CNNModel")
                compare_tiny_output, compare_tiny_plot = _compare_column("#### TinyCNN")
                compare_mini_output, compare_mini_plot = _compare_column("#### MiniCNN")
            gr.Markdown("---")
            gr.Markdown("### Attack CNN Models")
            with gr.Row():
                compare_standard_output, compare_standard_plot = _compare_column("#### Standard Attack CNN")
                compare_lighter_output, compare_lighter_plot = _compare_column("#### Lighter Attack CNN")
                compare_depthwise_output, compare_depthwise_plot = _compare_column("#### Depthwise Attack CNN")

    # Connect buttons to functions
    cnn_btn.click(predict_cnn, inputs=input_image, outputs=[cnn_output, cnn_plot])
    tiny_btn.click(predict_tinycnn, inputs=input_image, outputs=[tiny_output, tiny_plot])
    mini_btn.click(predict_minicnn, inputs=input_image, outputs=[mini_output, mini_plot])
    standard_btn.click(predict_standard_attack, inputs=input_image, outputs=[standard_output, standard_plot])
    lighter_btn.click(predict_lighter_attack, inputs=input_image, outputs=[lighter_output, lighter_plot])
    depthwise_btn.click(predict_depthwise_attack, inputs=input_image, outputs=[depthwise_output, depthwise_plot])
    compare_btn.click(
        predict_all_models,
        inputs=input_image,
        outputs=[
            compare_cnn_output, compare_cnn_plot,
            compare_tiny_output, compare_tiny_plot,
            compare_mini_output, compare_mini_plot,
            compare_standard_output, compare_standard_plot,
            compare_lighter_output, compare_lighter_plot,
            compare_depthwise_output, compare_depthwise_plot,
        ],
    )
# Launch the app when executed as a script (not when imported).
if __name__ == "__main__":
    print("\n🚀 Launching Gradio app...")
    launch_options = {
        "server_name": "0.0.0.0",  # bind on all network interfaces
        "server_port": 7860,
        "share": False,            # no public gradio.live tunnel
        "show_error": True,        # surface handler errors in the UI
    }
    demo.launch(**launch_options)