# HF Deploy
# Deploy CIFAR-100 classifier
# a92663e
"""
CIFAR-100 Image Classifier - Hugging Face Space
===============================================
Advanced ResNet-34 model trained on CIFAR-100.
Architecture: ResNet-34
- 100 output classes across diverse object categories
- Deep residual learning for robust feature extraction
- Trained with Albumentations augmentations
"""
import torch
import torch.nn.functional as F
import gradio as gr
from PIL import Image
from pathlib import Path
import numpy as np
import cv2
# Import model architecture and preprocessing
from model_cifar import ResNet34
from preprocess import CIFAR100_MEAN, CIFAR100_STD
from torchvision import transforms
# CIFAR-100 class names (official dataset labels), in alphabetical order.
# The position in this list is used as the class index when mapping model
# outputs to names. Ten names per row: the row comment gives the index range.
CIFAR100_CLASSES = [
    # 0-9
    "apple", "aquarium_fish", "baby", "bear", "beaver",
    "bed", "bee", "beetle", "bicycle", "bottle",
    # 10-19
    "bowl", "boy", "bridge", "bus", "butterfly",
    "camel", "can", "castle", "caterpillar", "cattle",
    # 20-29
    "chair", "chimpanzee", "clock", "cloud", "cockroach",
    "couch", "crab", "crocodile", "cup", "dinosaur",
    # 30-39
    "dolphin", "elephant", "flatfish", "forest", "fox",
    "girl", "hamster", "house", "kangaroo", "keyboard",
    # 40-49
    "lamp", "lawn_mower", "leopard", "lion", "lizard",
    "lobster", "man", "maple_tree", "motorcycle", "mountain",
    # 50-59
    "mouse", "mushroom", "oak_tree", "orange", "orchid",
    "otter", "palm_tree", "pear", "pickup_truck", "pine_tree",
    # 60-69
    "plain", "plate", "poppy", "porcupine", "possum",
    "rabbit", "raccoon", "ray", "road", "rocket",
    # 70-79
    "rose", "sea", "seal", "shark", "shrew",
    "skunk", "skyscraper", "snail", "snake", "spider",
    # 80-89
    "squirrel", "streetcar", "sunflower", "sweet_pepper", "table",
    "tank", "telephone", "television", "tiger", "tractor",
    # 90-99
    "train", "trout", "tulip", "turtle", "wardrobe",
    "whale", "willow_tree", "wolf", "woman", "worm",
]
# Device configuration: use the CUDA GPU when one is available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# ---------------------------
# Load Model
# ---------------------------
@torch.no_grad()
def load_model(checkpoint_path: str = None):
    """Build a ResNet-34 for CIFAR-100 and load trained weights if possible.

    The trained checkpoint is reported to reach 76.68% accuracy.

    Args:
        checkpoint_path: Path to a checkpoint file readable by ``torch.load``.
            The file may contain either a dict with a ``'model_state_dict'``
            entry or a bare state dict. ``None`` (or an empty string) skips
            loading altogether.

    Returns:
        The model, moved to ``device`` and switched to eval mode. Loading is
        best-effort: a missing or unreadable checkpoint is reported on stdout
        and the model is returned anyway instead of raising.
    """
    model = ResNet34(num_classes=100).to(device)

    if not checkpoint_path:
        print("ℹ️ No checkpoint provided, using randomly initialized model")
    elif not Path(checkpoint_path).exists():
        # A path WAS provided but does not exist: say so explicitly rather
        # than claiming that no checkpoint was provided.
        print(f"⚠️ Checkpoint not found at {checkpoint_path}")
        print("Using randomly initialized model")
    else:
        try:
            # NOTE(security): torch.load unpickles the file; only point this
            # at checkpoints that come from a trusted source.
            checkpoint = torch.load(checkpoint_path, map_location=device)
            if 'model_state_dict' in checkpoint:
                # Training checkpoint: weights are nested under a key.
                model.load_state_dict(checkpoint['model_state_dict'])
                print(f"βœ… Loaded checkpoint from epoch {checkpoint.get('epoch', '?')}")
            else:
                # Bare state dict saved directly.
                model.load_state_dict(checkpoint)
                print(f"βœ… Loaded model weights from {checkpoint_path}")
        except Exception as e:
            # Deliberate best-effort: the app should still start without weights.
            print(f"⚠️ Could not load checkpoint: {e}")
            print("Using randomly initialized model")

    model.eval()
    return model
print(f"Device: {device}")

# Candidate checkpoints in priority order: the best model first, then the
# epoch-99 snapshot. The trailing None stands for random initialization.
checkpoint_paths = [
    "./best_model.pth",  # For HF Space deployment
    "./snapshots_complete/cifar_epoch_99.pth",
    None,  # Fallback to random initialization
]

model = None
for checkpoint_path in checkpoint_paths:
    # Skip named checkpoints that are not on disk; None always qualifies.
    if checkpoint_path is not None and not Path(checkpoint_path).exists():
        continue
    model = load_model(checkpoint_path)
    break

# Safety net for the case where no candidate in the list was usable.
if model is None:
    model = load_model(None)
# ---------------------------
# Preprocessing pipeline
# ---------------------------
# Resize to 32x32 and convert to a tensor, without normalization.
preprocess_no_norm = transforms.Compose(
    [
        transforms.Resize(size=(32, 32)),
        transforms.ToTensor(),
    ]
)

# Model input pipeline: the same resize + tensor conversion, followed by
# normalization with the CIFAR-100 mean/std imported from `preprocess`.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
    ]
)
# ---------------------------
# Grad-CAM Implementation
# ---------------------------
class GradCAM:
    """Grad-CAM: Visual Explanations from Deep Networks.

    Hooks are attached to ``target_layer`` to record its output on the forward
    pass and the gradient with respect to that output on the backward pass.
    ``generate_cam`` combines the two into a heatmap normalized to [0, 1].

    Attributes:
        model: The network being explained.
        target_layer: The module whose output is visualized.
        activations: Detached output of ``target_layer`` from the latest
            forward pass (``None`` until a forward pass has run).
        gradients: Detached gradient w.r.t. that output from the latest
            backward pass (``None`` until a backward pass has run).
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # Register hooks and keep their handles so they can be detached later
        # via remove_hooks(); otherwise every instance would leave its hooks
        # on the layer permanently.
        self._hook_handles = [
            target_layer.register_forward_hook(self.save_activation),
            target_layer.register_full_backward_hook(self.save_gradient),
        ]

    def remove_hooks(self):
        """Detach the forward/backward hooks from the target layer."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def save_activation(self, module, input, output):
        """Forward hook: remember the target layer's output."""
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: remember the gradient w.r.t. the layer's output."""
        self.gradients = grad_output[0].detach()

    def generate_cam(self, input_tensor, target_class=None):
        """Generate a Grad-CAM heatmap for the first sample of ``input_tensor``.

        Args:
            input_tensor: Model input. Only sample 0 is visualized, and the
                target layer's output for one sample is expected to be 3-D
                (the gradient is averaged over its dims 1 and 2).
            target_class: Index of the class to explain. When ``None`` the
                model's argmax prediction is used (requires a batch of one).

        Returns:
            ``(cam, target_class)`` where ``cam`` is a 2-D NumPy array scaled
            to [0, 1] (all zeros if the raw map has no positive values).

        Raises:
            RuntimeError: If the hooks captured nothing, e.g. after
                ``remove_hooks()``.
        """
        # Clear leftovers so a hook that fails to fire is detected instead of
        # silently reusing data from a previous call.
        self.activations = None
        self.gradients = None

        # Gradients are required even if the caller runs under torch.no_grad().
        with torch.enable_grad():
            # Forward pass
            model_output = self.model(input_tensor)
            if target_class is None:
                target_class = model_output.argmax(dim=1).item()

            # Backward pass from the selected class score only. The graph is
            # not needed afterwards, so it is not retained.
            self.model.zero_grad()
            model_output[0, target_class].backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks did not capture activations/gradients "
                "for the target layer"
            )

        gradients = self.gradients[0]
        activations = self.activations[0]

        # Global average pooling on gradients gives one weight per channel.
        weights = gradients.mean(dim=(1, 2), keepdim=True)

        # Weighted combination of activation maps, keeping positive evidence.
        cam = F.relu((weights * activations).sum(dim=0))

        # Normalize to [0, 1]; skip the division when the map is all zeros.
        cam = cam - cam.min()
        peak = cam.max()
        if peak > 0:
            cam = cam / peak

        return cam.cpu().numpy(), target_class
def apply_gradcam(image_pil, model, gradcam, top_class_idx, output_size=200):
    """Apply Grad-CAM and overlay the heatmap on the original image.

    Args:
        image_pil: Input PIL image in any mode; it is converted to RGB.
        model: Unused here (the model attached to ``gradcam`` is the one that
            runs); kept so existing callers keep working.
        gradcam: Object exposing ``generate_cam(tensor, target_class=...)``.
        top_class_idx: Index of the class to explain.
        output_size: Side length, in pixels, of the returned square images.

    Returns:
        ``(overlay, heatmap)``: two RGB image arrays of size
        ``output_size`` x ``output_size``. ``overlay`` blends the downsized
        input (60%) with the heatmap (40%).
    """
    # Convert once and reuse: the overlay base must be 3-channel RGB, the same
    # as the heatmap, or cv2.addWeighted fails on RGBA/grayscale/palette input.
    rgb_image = image_pil.convert("RGB")

    # Prepare input
    img_tensor = preprocess(rgb_image).unsqueeze(0).to(device)

    # Generate Grad-CAM
    cam, _ = gradcam.generate_cam(img_tensor, target_class=int(top_class_idx))

    # Resize CAM to the 32x32 size that `preprocess` resizes inputs to
    cam_resized = cv2.resize(cam, (32, 32))

    # Original image at the same 32x32 resolution
    img_np = np.array(rgb_image.resize((32, 32)))

    # Create heatmap (OpenCV colormaps are BGR, so convert to RGB)
    heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Overlay heatmap on original image
    overlay = cv2.addWeighted(img_np, 0.6, heatmap, 0.4, 0)

    # Upscale both images for better visibility in the UI
    target_size = (output_size, output_size)
    overlay_large = cv2.resize(overlay, target_size, interpolation=cv2.INTER_LINEAR)
    heatmap_large = cv2.resize(heatmap, target_size, interpolation=cv2.INTER_LINEAR)
    return overlay_large, heatmap_large
# Initialize Grad-CAM on conv2 of the final block in layer4
# (the last convolutional layer).
gradcam = GradCAM(model=model, target_layer=model.layer4[-1].conv2)
# ---------------------------
# Prediction Function
# ---------------------------
def _render_prediction_html(probabilities, sorted_indices):
    """Build the HTML result card plus the top-5 probability bars.

    Args:
        probabilities: Per-class probabilities, indexable by class index.
        sorted_indices: Class indices ordered from most to least probable.

    Returns:
        One HTML string: headline prediction, top-5 bars, Grad-CAM hint.
    """
    best_idx = sorted_indices[0]
    predicted_class = CIFAR100_CLASSES[best_idx]
    confidence = probabilities[best_idx]

    sections = [f"""
    <div style='padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    border-radius: 10px; color: white; box-shadow: 0 4px 6px rgba(0,0,0,0.1);'>
    <h2>🎯 Prediction Result</h2>
    <div style='font-size: 24px; font-weight: bold;'>{predicted_class.upper()}</div>
    <div style='font-size: 18px;'>Confidence: <strong>{confidence*100:.2f}%</strong></div>
    </div>
    <div style='margin-top: 20px; background: #f8f9fa; border-radius: 8px; padding: 15px;'>
    <h3>πŸ“Š Top 5 Predictions</h3>
    """]

    for rank, idx in enumerate(sorted_indices[:5], 1):
        name = CIFAR100_CLASSES[idx]
        prob = probabilities[idx]
        bar_width = int(prob * 100)
        # Highlight the first two ranks; the rest share a neutral grey.
        color = "#667eea" if rank == 1 else ("#764ba2" if rank == 2 else "#95a5a6")
        sections.append(f"""
        <div style='margin: 8px 0;'>
        <div style='display: flex; justify-content: space-between;'>
        <span>{rank}. {name}</span>
        <span style='font-weight:bold; color:{color}'>{prob*100:.2f}%</span>
        </div>
        <div style='background:#e9ecef; border-radius:4px; height:20px;'>
        <div style='width:{bar_width}%; background:{color}; height:100%; border-radius:4px;'></div>
        </div>
        </div>
        """)

    sections.append("""
    </div>
    <div style='margin-top: 15px; padding: 10px; background: #e8f4f8; border-left: 4px solid #667eea; border-radius: 4px;'>
    <p style='margin: 0; color: #333;'><strong>πŸ’‘ Grad-CAM Visualization:</strong> The heatmap shows which parts of the image the model focused on to make its prediction. Red/yellow areas indicate high importance.</p>
    </div>
    """)
    return "".join(sections)


def predict(image: Image.Image):
    """Predict the class of an input image with Grad-CAM visualization.

    Args:
        image: PIL image from the UI, or ``None`` when nothing is uploaded.

    Returns:
        A 4-tuple ``(top3, html, overlay, heatmap)``:
        ``top3`` maps the three most probable class names to probabilities,
        ``html`` is the detailed result (or an error message),
        ``overlay`` / ``heatmap`` are PIL images, or ``None`` when Grad-CAM
        failed. On a missing image or a prediction error the tuple is
        ``({}, <error html>, None, None)``; this function does not raise.
    """
    import html  # stdlib; only needed to escape error text below

    if image is None:
        return {}, "<p style='color: red;'>Please upload an image first!</p>", None, None

    try:
        # Prepare input
        img_tensor = preprocess(image.convert("RGB")).unsqueeze(0).to(device)

        # Get predictions
        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = torch.softmax(outputs, dim=1)[0].cpu().numpy()

        # Class indices from most to least probable
        sorted_indices = np.argsort(probabilities)[::-1]
        top3_results = {
            CIFAR100_CLASSES[i]: float(probabilities[i])
            for i in sorted_indices[:3]
        }

        # Grad-CAM is best-effort: a failure must not hide the prediction.
        try:
            overlay, heatmap = apply_gradcam(image, model, gradcam, int(sorted_indices[0]))
            gradcam_overlay = Image.fromarray(overlay.astype(np.uint8))
            gradcam_heatmap = Image.fromarray(heatmap.astype(np.uint8))
        except Exception as e:
            print(f"Grad-CAM error: {e}")
            gradcam_overlay = None
            gradcam_heatmap = None

        result_html = _render_prediction_html(probabilities, sorted_indices)
        return top3_results, result_html, gradcam_overlay, gradcam_heatmap
    except Exception as e:
        # UI boundary: report the failure instead of crashing the app. The
        # message is escaped because it is rendered as HTML.
        return {}, f"<p style='color: red;'>Error during prediction: {html.escape(str(e))}</p>", None, None
# ---------------------------
# Model Information Section
# ---------------------------
# Markdown text rendered by gr.Markdown inside the "Model Information &
# Performance Metrics" accordion of the UI. The figures below are static
# text; nothing in this file computes them.
model_description = """
## πŸš€ About This Model
**Custom Lightweight ResNet trained on CIFAR-100 from scratch**
### πŸ“Š Performance Metrics (100 Epochs)
- **Top-1 Accuracy:** 76.68% βœ… (Target: 73%)
- **Top-3 Accuracy:** 90.95%
- **Top-5 Accuracy:** 94.07%
- **Best Test Accuracy:** 76.79% (Epoch 99)
- **Macro F1-Score:** 0.7670
- **Dataset:** CIFAR-100 (50,000 train / 10,000 test)
### πŸ—οΈ Architecture
- **Model:** Custom ResNet-34 variant (CIFAR-optimized)
- **Parameters:** 4,949,412 (~5M)
- **Depth:** 10 weight layers (1 stem + 8 residual + 1 FC)
- **Design:** 4 BasicBlocks with [1,1,1,1] configuration
- **Key Features:**
- 3Γ—3 initial conv (no 7Γ—7, no MaxPool)
- Progressive downsampling: 32Γ—32 β†’ 16Γ—16 β†’ 8Γ—8 β†’ 4Γ—4
- Channel expansion: 64 β†’ 128 β†’ 256 β†’ 512
- Receptive field: 63Γ—63 (covers full 32Γ—32 image)
- Global Average Pooling + Linear(512 β†’ 100)
### 🎯 Training Configuration
- **Optimizer:** SGD with Nesterov momentum (0.9)
- **LR Schedule:** OneCycle (0.01 β†’ 0.1 β†’ 0.01 β†’ 0.001)
- Phase 1: 41 epochs warmup
- Phase 2: 41 epochs cooldown
- Phase 3: 18 epochs annihilation
- **Augmentations:** Albumentations (Flip, ShiftScaleRotate, CoarseDropout, ColorJitter)
- **Regularization:** Weight decay (1e-4), Label smoothing (0.1)
- **Mixed Precision:** Enabled (AMP)
- **Batch Size:** 512
### πŸ’‘ CIFAR-100 Classes
100 fine-grained categories across 20 superclasses:
- 🦁 **Animals:** lion, tiger, elephant, whale, bear, leopard, wolf
- πŸš— **Vehicles:** pickup_truck, bus, train, streetcar, motorcycle, tractor
- 🌳 **Trees:** maple_tree, oak_tree, palm_tree, pine_tree, willow_tree
- 🌺 **Flowers:** rose, poppy, orchid, sunflower, tulip
- 🐠 **Aquatic:** aquarium_fish, flatfish, ray, shark, trout
- 🏠 **Structures:** house, castle, skyscraper, bridge, road
- 🍎 **Food:** apple, orange, pear, sweet_pepper, mushroom
- πŸ‘¨ **People:** man, woman, baby, boy, girl
- πŸͺ‘ **Furniture:** bed, chair, couch, table, wardrobe
- And many more!
### πŸ† Best Performing Classes (F1-Score)
1. **Wardrobe** - 94.58%
2. **Sunflower** - 93.81%
3. **Poppy** - 93.15%
4. **Can** - 93.10%
5. **Skyscraper** - 91.00%
### πŸš€ Deployment
- Trained without pre-trained weights
- Built with PyTorch and Albumentations
- Deployed on Hugging Face Spaces
- Inference optimized for CPU/GPU
"""
# Example images (curated selection from available examples)
def collect_example_images(example_dir, priority_names, limit=12):
    """Pick up to ``limit`` example image paths from ``example_dir``.

    Files named in ``priority_names`` are taken first, in the given order and
    only if they exist. Remaining slots are filled with other ``*_1.jpg``
    files, visited in sorted order so the selection is reproducible.

    Args:
        example_dir: ``Path`` of the directory holding example images.
        priority_names: File names to prefer, in order.
        limit: Maximum number of examples to return.

    Returns:
        A list of single-element lists ``[[path_str], ...]`` (one row per
        example). Empty when the directory is missing or ``limit`` is not
        positive.
    """
    if limit <= 0 or not example_dir.is_dir():
        return []

    # Priority files first, then every "first of its category" file.
    candidates = [example_dir / name for name in priority_names]
    candidates += sorted(example_dir.glob("*_1.jpg"))

    chosen = []
    seen = set()  # paths already selected, for O(1) duplicate checks
    for path in candidates:
        key = str(path)
        if key in seen or not path.exists():
            continue
        seen.add(key)
        chosen.append([key])
        if len(chosen) >= limit:  # Limit the number of examples for a clean UI
            break
    return chosen


example_dir = Path("examples")
# One example per category, preferring the _1.jpg files for consistency.
priority_examples = [
    "lion_1.jpg", "tiger_1.jpg", "elephant_1.jpg", "bear_1.jpg",
    "pickup_truck_1.jpg", "bus_1.jpg", "train_1.jpg",
    "rose_1.jpg", "sunflower_1.jpg", "tulip_1.jpg",
    "apple_1.jpg", "orange_1.jpg", "pear_1.jpg",
    "castle_1.jpg", "skyscraper_1.jpg", "house_1.jpg",
]
examples = collect_example_images(example_dir, priority_examples)
# ---------------------------
# Gradio UI
# ---------------------------
custom_css = """
.gradio-container { font-family: 'Inter', sans-serif; }
.output-html { font-family: 'Inter', sans-serif; }
"""

with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎯 CIFAR-100 Image Classifier with Grad-CAM")
    gr.Markdown("### Deep ResNet-34 trained on 100 object categories β€’ Explainable AI with Grad-CAM heatmaps")

    with gr.Row():
        # Left column: image upload and the classify button
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image", height=500)
            predict_btn = gr.Button("πŸš€ Classify Image", variant="primary", size="lg")
            gr.Markdown("Upload an image belonging to one of the CIFAR-100 categories.")

        # Right column: Grad-CAM images on top, predictions underneath
        with gr.Column(scale=1):
            # The two Grad-CAM views sit side by side
            with gr.Row():
                with gr.Column(scale=1):
                    gradcam_overlay_output = gr.Image(
                        label="πŸ”₯ Grad-CAM Overlay", type="pil", height=200
                    )
                    gr.Markdown("**Overlay** - Model attention on image")
                with gr.Column(scale=1):
                    gradcam_heatmap_output = gr.Image(
                        label="🌑️ Grad-CAM Heatmap", type="pil", height=200
                    )
                    gr.Markdown("**Heatmap** - Red = high importance")

            # Predictions below Grad-CAM
            label_output = gr.Label(num_top_classes=3, label="Top 3 Predictions")
            html_output = gr.HTML(label="Detailed Results")

    # Every consumer of `predict` receives its four results in this order.
    prediction_outputs = [
        label_output,
        html_output,
        gradcam_overlay_output,
        gradcam_heatmap_output,
    ]

    # The examples gallery is only shown when example images were found.
    if examples:
        gr.Examples(
            examples=examples,
            inputs=image_input,
            outputs=prediction_outputs,
            fn=predict,
            cache_examples=False,
        )

    with gr.Accordion("πŸ“– Model Information & Performance Metrics", open=False):
        gr.Markdown(model_description)

    # Run the classifier on an explicit button click and whenever the
    # input image changes.
    for register_event in (predict_btn.click, image_input.change):
        register_event(
            fn=predict,
            inputs=image_input,
            outputs=prediction_outputs,
        )

# ---------------------------
# Launch
# ---------------------------
if __name__ == "__main__":
    demo.launch()