nirmalpratheep's picture
Update app.py
75b682e verified
"""
CIFAR-10 Image Classifier - Hugging Face Space
===============================================
Advanced CNN with Residual Connections achieving 87.88% test accuracy
Architecture: CIFARNet
- Depthwise Separable Convolutions
- Residual Connections in C2 and C3 layers
- Spatial Dropout for regularization
- 174,762 parameters (0.67 MB model size)
"""
import torch
import gradio as gr
from PIL import Image
from pathlib import Path
import numpy as np
# Import model architecture and preprocessing
from model_cifar import CIFARNet
from preprocess import CIFAR10_MEAN, CIFAR10_STD
from torchvision import transforms
# CIFAR-10 label names. A label's position in this list is the class index
# that predict() uses to turn a model output index into a name.
CIFAR10_CLASSES = [
    "airplane",    # 0
    "automobile",  # 1
    "bird",        # 2
    "cat",         # 3
    "deer",        # 4
    "dog",         # 5
    "frog",        # 6
    "horse",       # 7
    "ship",        # 8
    "truck",       # 9
]
# Short human-readable blurb per class, shown under the predicted label in
# the result card. Keys must match the entries of CIFAR10_CLASSES exactly,
# because predict() looks descriptions up by class name.
CLASS_DESCRIPTIONS = {
    "airplane": "✈️ Commercial or military aircraft",
    "automobile": "🚗 Cars, sedans, and vehicles",
    "bird": "🐦 Various bird species",
    "cat": "🐱 Domestic and wild cats",
    "deer": "🦌 Deer and similar animals",
    "dog": "🐕 Domestic dogs of various breeds",
    "frog": "🐸 Frogs and similar amphibians",
    "horse": "🐴 Horses and equines",
    "ship": "🚢 Ships, boats, and vessels",
    "truck": "🚚 Trucks and large vehicles",
}
# Pick the compute device once at import time: the GPU when CUDA is
# available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Load model
@torch.no_grad()
def load_model(checkpoint_path: str = None):
    """Build a CIFARNet and, when possible, load trained weights into it.

    Loading is best-effort: any failure is reported on stdout and the
    function still returns a usable (but untrained) model instead of raising.

    Args:
        checkpoint_path: Path to a checkpoint file. The file may hold either
            a dict with a ``'model_state_dict'`` entry (an ``'epoch'`` entry
            is reported if present) or a bare state dict. ``None`` or an
            empty string skips loading.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    model = CIFARNet(num_classes=10).to(device)

    if not checkpoint_path:
        print("ℹ️ No checkpoint provided, using randomly initialized model")
    elif not Path(checkpoint_path).exists():
        # A path was given but the file is missing: say so explicitly rather
        # than claiming no checkpoint was provided.
        print(f"⚠️ Checkpoint not found at {checkpoint_path}, using randomly initialized model")
    else:
        try:
            # NOTE(review): depending on the installed PyTorch version's
            # `weights_only` default, torch.load may unpickle arbitrary
            # objects. Only point this at trusted checkpoint files.
            checkpoint = torch.load(checkpoint_path, map_location=device)
            if 'model_state_dict' in checkpoint:
                # Training-style checkpoint: weights nested under a key.
                model.load_state_dict(checkpoint['model_state_dict'])
                print(f"✅ Loaded checkpoint from epoch {checkpoint.get('epoch', '?')}")
            else:
                # Otherwise treat the loaded object as the state dict itself.
                model.load_state_dict(checkpoint)
                print(f"✅ Loaded model weights from {checkpoint_path}")
        except Exception as e:
            # Deliberate best-effort: keep the app running and report why
            # the weights could not be loaded.
            print(f"⚠️ Could not load checkpoint: {e}")
            print("Using randomly initialized model")

    model.eval()
    return model
# Report the selected device, then build the model once at import time so
# every prediction reuses the same instance.
print("Device: {}".format(device))
model = load_model(checkpoint_path="./snapshots_complete/cifar_epoch_249.pth")

# Input pipeline applied to each image before inference:
#   1. resize to 32x32 (aspect ratio is not preserved),
#   2. convert to a tensor,
#   3. normalize with the CIFAR-10 mean/std imported from `preprocess`.
# NOTE(review): the original comment says this matches training; confirm
# against the training code.
_preprocess_steps = [
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
]
preprocess = transforms.Compose(_preprocess_steps)
def _render_result_html(probabilities, ranked_indices) -> str:
    """Build the detail card: headline prediction plus a top-5 bar chart.

    Args:
        probabilities: Per-class probabilities, indexable by class index.
        ranked_indices: Class indices ordered from most to least probable.

    Returns:
        An HTML fragment. Only module constants and formatted numbers are
        interpolated, so no escaping is required here.
    """
    top_idx = ranked_indices[0]
    predicted_class = CIFAR10_CLASSES[top_idx]
    confidence = probabilities[top_idx]

    # Collect fragments and join once instead of growing a string in a loop.
    parts = [f"""
    <div style='padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    border-radius: 10px; color: white; box-shadow: 0 4px 6px rgba(0,0,0,0.1);'>
    <h2 style='margin: 0 0 10px 0;'>🎯 Prediction Result</h2>
    <div style='font-size: 24px; font-weight: bold; margin: 10px 0;'>
    {predicted_class.upper()}
    </div>
    <div style='font-size: 16px; opacity: 0.9;'>
    {CLASS_DESCRIPTIONS[predicted_class]}
    </div>
    <div style='font-size: 18px; margin-top: 10px;'>
    Confidence: <strong>{confidence*100:.2f}%</strong>
    </div>
    </div>
    <div style='margin-top: 20px; padding: 15px; background: #f8f9fa;
    border-radius: 8px; border-left: 4px solid #667eea;'>
    <h3 style='margin-top: 0; color: #333;'>📊 Top 5 Predictions:</h3>
    <div style='margin-top: 10px;'>
    """]

    # Color coding based on rank: green for the top prediction, blue for the
    # second, gray for the rest.
    rank_colors = {1: "#28a745", 2: "#17a2b8"}
    default_color = "#6c757d"

    for rank, idx in enumerate(ranked_indices[:5], 1):
        class_name = CIFAR10_CLASSES[idx]
        prob = probabilities[idx]
        bar_width = int(prob * 100)  # bar length as a whole percentage
        color = rank_colors.get(rank, default_color)
        parts.append(f"""
        <div style='margin: 8px 0;'>
        <div style='display: flex; justify-content: space-between; align-items: center; margin-bottom: 4px;'>
        <span style='font-weight: 500; color: #333;'>{rank}. {class_name}</span>
        <span style='font-weight: bold; color: {color};'>{prob*100:.2f}%</span>
        </div>
        <div style='width: 100%; background: #e9ecef; border-radius: 4px; height: 20px; overflow: hidden;'>
        <div style='width: {bar_width}%; background: {color}; height: 100%;
        transition: width 0.3s ease;'></div>
        </div>
        </div>
        """)

    parts.append("""
    </div>
    </div>
    """)
    return "".join(parts)


def predict(image: "Image.Image") -> tuple:
    """Classify an input image and build the payloads for both output widgets.

    Args:
        image: PIL image, or ``None`` when nothing has been uploaded.

    Returns:
        Tuple of ``(predictions_dict, confidence_html)``.
        ``predictions_dict`` maps the three most probable class names to
        their probabilities; it is empty when there is no image or when
        prediction fails. ``confidence_html`` is the detail card, or a red
        message describing the missing input / the error.
    """
    import html  # local import: only needed to escape error text below

    if image is None:
        return {}, "<p style='color: red;'>Please upload an image first!</p>"

    try:
        # Force three channels first so the image matches the 3-channel
        # normalization in the preprocessing pipeline.
        rgb_image = image.convert("RGB")
        img_tensor = preprocess(rgb_image).unsqueeze(0).to(device)

        # Inference
        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = torch.softmax(outputs, dim=1)[0].cpu().numpy()

        # Class indices ordered from most to least probable.
        sorted_indices = np.argsort(probabilities)[::-1]

        # Plain-float dict for the label widget (top 3 only).
        top3_results = {
            CIFAR10_CLASSES[i]: float(probabilities[i])
            for i in sorted_indices[:3]
        }

        return top3_results, _render_result_html(probabilities, sorted_indices)
    except Exception as e:
        # UI boundary: report the failure in the page instead of crashing.
        # Exception text is arbitrary, so escape it before embedding in HTML.
        error_html = f"<p style='color: red;'>Error during prediction: {html.escape(str(e))}</p>"
        return {}, error_html
# Markdown shown inside the "Model Information & Performance Metrics"
# accordion. The figures below are static text maintained by hand; nothing
# in this file computes them.
model_description = """
## 🚀 About This Model
**CIFARNet** is an advanced CNN architecture designed for CIFAR-10 image classification. It achieves state-of-the-art performance with exceptional efficiency.
### 📊 Performance Metrics
- **Test Accuracy:** 87.88%
- **Top-3 Accuracy:** 97.74%
- **Top-5 Accuracy:** 99.31%
- **Model Size:** 174,762 parameters (0.67 MB)
### 🏗️ Architecture Highlights
- **Depthwise Separable Convolutions** for parameter efficiency
- **Residual Connections** in C2 and C3 layers for improved gradient flow
- **Spatial Dropout** for better regularization
- **Dilated Convolutions** (C3 layer with dilation=4) for larger receptive field
### 🎯 Best Performing Classes
- Ship: 92.95% F1-score
- Truck: 92.19% F1-score
- Automobile: 93.77% F1-score
- Frog: 90.28% F1-score
### 🔬 Training Details
- **Training Set:** CIFAR-10 (50,000 images)
- **Test Set:** CIFAR-10 (10,000 images)
- **Epochs:** 250 with cosine annealing
- **Optimizer:** SGD with Nesterov momentum
- **Augmentation:** HorizontalFlip, ShiftScaleRotate, CoarseDropout, ColorJitter
### 📚 Classes
The model classifies images into 10 categories:
- ✈️ Airplane
- 🚗 Automobile
- 🐦 Bird
- 🐱 Cat
- 🦌 Deer
- 🐕 Dog
- 🐸 Frog
- 🐴 Horse
- 🚢 Ship
- 🚚 Truck
### 💡 Tips
- Upload clear images for best results
- The model works best with images containing the main object centered
- Images are automatically resized to 32×32 pixels (CIFAR-10 standard)
- Try different angles or lighting conditions to see how the model performs
### 🔗 Links
- [GitHub Repository](https://github.com/yourusername/CIFAR10-MLTraining)
- [Model Architecture Details](https://github.com/yourusername/CIFAR10-MLTraining#model-architecture)
- [Training Logs & Metrics](https://github.com/yourusername/CIFAR10-MLTraining#performance-results)
"""
# Example images offered below the classifier: ten files under examples/,
# one or two for each of eight classes (no bird or deer example). Every
# entry is a one-element list because the interface has a single image input.
examples = [
    [f"examples/{stem}.jpg"]
    for stem in (
        "airplane_1",
        "airplane_2",
        "automobile_1",
        "ship_1",
        "ship_2",
        "cat_1",
        "dog_1",
        "horse_1",
        "frog_1",
        "truck_1",
    )
]
# Extra CSS passed to gr.Blocks: use the Inter font (falling back to a
# generic sans-serif) for the page container and the .output-html class.
custom_css = (
    "\n"
    ".gradio-container {\n"
    "font-family: 'Inter', sans-serif;\n"
    "}\n"
    ".output-html {\n"
    "font-family: 'Inter', sans-serif;\n"
    "}\n"
)
# Create Gradio interface.
# Layout: a title, then one row with two equal columns (image input and
# button on the left, label and HTML outputs on the right), followed by
# usage notes, clickable examples and a collapsed model-information section.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎯 CIFAR-10 Image Classifier")
    gr.Markdown("### Advanced CNN achieving 87.88% test accuracy")

    with gr.Row():
        with gr.Column(scale=1):
            # type="pil" makes Gradio hand predict() a PIL image.
            image_input = gr.Image(type="pil")
            predict_btn = gr.Button("🚀 Classify Image", variant="primary", size="lg")
            gr.Markdown("### 📤 Try it out!")
            gr.Markdown("Upload an image containing one of the 10 CIFAR-10 classes.")
        with gr.Column(scale=1):
            label_output = gr.Label(num_top_classes=3, label="Top 3 Predictions")
            html_output = gr.HTML(label="Detailed Results")

    # Add examples section
    gr.Markdown("---")
    gr.Markdown("## 💡 How to use")
    # Written as adjacent literals so the Markdown lines start at column 0
    # regardless of how deeply this call is indented.
    gr.Markdown(
        "\n"
        "1. **Upload an image** using the upload box on the left\n"
        "2. **Click 'Classify Image'** to get predictions\n"
        "3. **View results** showing the top predictions with confidence scores\n"
        "The model works best with images from these categories: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck.\n"
    )
    gr.Markdown("### 📸 Try These Examples")
    gr.Examples(
        examples=examples,
        inputs=image_input,
        outputs=[label_output, html_output],
        fn=predict,
        cache_examples=False,  # do not precompute predictions for the examples
    )

    # Model information
    gr.Markdown("---")
    with gr.Accordion("📖 Model Information & Performance Metrics", open=False):
        gr.Markdown(model_description)

    # Connect the prediction function to button click only: no other event
    # in this file triggers predict() directly.
    predict_btn.click(
        fn=predict,
        inputs=image_input,
        outputs=[label_output, html_output],
    )

# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()