"""
CIFAR-100 Image Classifier - Hugging Face Space
===============================================
Advanced ResNet-34 model trained on CIFAR-100.

Architecture: ResNet-34
- 100 output classes across diverse object categories
- Deep residual learning for robust feature extraction
- Trained with Albumentations augmentations
"""

from pathlib import Path
from typing import Optional

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

# Import model architecture and preprocessing
from model_cifar import ResNet34
from preprocess import CIFAR100_MEAN, CIFAR100_STD

# CIFAR-100 class names (official dataset labels)
CIFAR100_CLASSES = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee',
    'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus',
    'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair',
    'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab',
    'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish',
    'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard',
    'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man',
    'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree',
    'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck',
    'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit',
    'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark',
    'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel',
    'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone',
    'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle',
    'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm',
]

# Spatial size (width, height) every image is resized to before inference.
INPUT_SIZE = (32, 32)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ---------------------------
# Load Model
# ---------------------------
def _build_model():
    """Return a freshly initialised ResNet34 with 100 outputs on `device`."""
    return ResNet34(num_classes=100).to(device)


@torch.no_grad()
def load_model(checkpoint_path: Optional[str] = None):
    """Build the ResNet-34 classifier and load weights when available.

    Args:
        checkpoint_path: Path to a checkpoint file. The file may hold either
            a dict with a ``'model_state_dict'`` entry or a bare state dict.
            When the path is ``None`` or does not exist, no weights are
            loaded.

    Returns:
        The model in eval mode on ``device``. If no checkpoint is given, or
        loading fails for any reason, the model is randomly initialised.
    """
    model = _build_model()

    if checkpoint_path and Path(checkpoint_path).exists():
        try:
            # NOTE(security): torch.load unpickles the file; only point this
            # at checkpoints that ship with the application.
            checkpoint = torch.load(checkpoint_path, map_location=device)
            if 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
                print(f"✅ Loaded checkpoint from epoch {checkpoint.get('epoch', '?')}")
            else:
                model.load_state_dict(checkpoint)
                print(f"✅ Loaded model weights from {checkpoint_path}")
        except Exception as e:
            # Best-effort: the app should still start without weights.
            print(f"⚠️ Could not load checkpoint: {e}")
            print("Using randomly initialized model")
            # A failed load_state_dict may already have overwritten some
            # tensors; rebuild so the message above is actually true.
            model = _build_model()
    else:
        print("ℹ️ No checkpoint provided, using randomly initialized model")

    model.eval()
    return model


print(f"Device: {device}")

# Try to load the best checkpoint, fallback to epoch 99, then random init
checkpoint_paths = [
    "./best_model.pth",  # For HF Space deployment
    "./snapshots_complete/cifar_epoch_99.pth",
    None,  # Fallback to random initialization
]

# First candidate that exists on disk, or None for random initialisation.
selected_checkpoint = next(
    (path for path in checkpoint_paths
     if path is not None and Path(path).exists()),
    None,
)
model = load_model(selected_checkpoint)

# ---------------------------
# Preprocessing pipeline
# ---------------------------
preprocess = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])

preprocess_no_norm = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
])


# ---------------------------
# Grad-CAM Implementation
# ---------------------------
class GradCAM:
    """Grad-CAM: Visual Explanations from Deep Networks.

    Registers a forward hook and a full backward hook on ``target_layer`` so
    that its output and the gradient flowing into that output are captured
    on every forward/backward pass of ``model``.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Register hooks
        target_layer.register_forward_hook(self.save_activation)
        target_layer.register_full_backward_hook(self.save_gradient)

    def save_activation(self, module, input, output):
        """Forward hook: keep a detached copy of the layer output."""
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: keep the gradient w.r.t. the layer output."""
        self.gradients = grad_output[0].detach()

    def generate_cam(self, input_tensor, target_class=None):
        """Generate a Grad-CAM heatmap for the first sample of a batch.

        Args:
            input_tensor: Model input; only batch element 0 is explained.
            target_class: Class index to explain. Defaults to the model's
                top prediction for batch element 0.

        Returns:
            ``(cam, target_class)`` where ``cam`` is a 2-D numpy array scaled
            to ``[0, 1]`` (all zeros if no positive evidence was found) and
            ``target_class`` is the class index as a plain ``int``.

        Raises:
            RuntimeError: If the hooks captured no activations or gradients.
        """
        # Stale values from an earlier call must not leak into this one.
        self.gradients = None

        # The backward pass needs a graph even if the caller disabled grad.
        with torch.enable_grad():
            # Forward pass
            model_output = self.model(input_tensor)

            if target_class is None:
                target_class = model_output.argmax(dim=1).item()
            target_class = int(target_class)

            # Backward pass
            self.model.zero_grad()
            one_hot = torch.zeros_like(model_output)
            one_hot[0, target_class] = 1
            model_output.backward(gradient=one_hot, retain_graph=True)

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "Grad-CAM hooks did not capture activations/gradients"
            )

        # Generate CAM
        gradients = self.gradients[0]
        activations = self.activations[0]

        # Global average pooling on gradients
        weights = gradients.mean(dim=(1, 2), keepdim=True)

        # Weighted combination of activation maps
        cam = (weights * activations).sum(dim=0)

        # Apply ReLU
        cam = F.relu(cam)

        # Normalize
        cam = cam - cam.min()
        peak = cam.max()
        if peak > 0:
            cam = cam / peak

        return cam.cpu().numpy(), target_class


def apply_gradcam(image_pil, model, gradcam, top_class_idx, output_size=200):
    """Apply Grad-CAM and overlay the heatmap on the input image.

    Args:
        image_pil: Input PIL image in any mode; it is converted to RGB.
        model: Unused here; kept so existing callers keep working.
        gradcam: ``GradCAM`` instance used to compute the heatmap.
        top_class_idx: Class index to explain.
        output_size: Side length in pixels of the returned square images.

    Returns:
        ``(overlay, heatmap)`` as RGB uint8 arrays of shape
        ``(output_size, output_size, 3)``.
    """
    # Convert once so the tensor and the overlay use the same 3-channel
    # image; RGBA or grayscale input would otherwise break addWeighted.
    rgb_image = image_pil.convert("RGB")

    # Prepare input
    img_tensor = preprocess(rgb_image).unsqueeze(0).to(device)

    # Generate Grad-CAM
    cam, _ = gradcam.generate_cam(img_tensor, target_class=top_class_idx)

    # Resize CAM to match input size
    cam_resized = cv2.resize(cam, INPUT_SIZE)

    # Convert original image to numpy
    img_np = np.array(rgb_image.resize(INPUT_SIZE))

    # Create heatmap (OpenCV colour maps are BGR; convert to RGB)
    heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Overlay heatmap on original image
    overlay = cv2.addWeighted(img_np, 0.6, heatmap, 0.4, 0)

    # Upscale for better visibility
    target = (output_size, output_size)
    overlay_large = cv2.resize(overlay, target, interpolation=cv2.INTER_LINEAR)
    heatmap_large = cv2.resize(heatmap, target, interpolation=cv2.INTER_LINEAR)

    return overlay_large, heatmap_large


# Initialize Grad-CAM (target the last convolutional layer)
gradcam = GradCAM(model, model.layer4[-1].conv2)
# ---------------------------
# Prediction Function
# ---------------------------
def _render_results_html(probabilities, sorted_indices):
    """Build the textual result panel for the top prediction and top five."""
    best = sorted_indices[0]
    predicted_class = CIFAR100_CLASSES[best]
    confidence = probabilities[best]

    parts = [
        "\n\n🎯 Prediction Result\n\n"
        f"{predicted_class.upper()}\n"
        f"Confidence: {confidence*100:.2f}%\n\n"
        "📊 Top 5 Predictions\n\n"
    ]
    parts.extend(
        f"\n{rank}. {CIFAR100_CLASSES[idx]} {probabilities[idx]*100:.2f}%\n"
        for rank, idx in enumerate(sorted_indices[:5], 1)
    )
    parts.append(
        "\n\n💡 Grad-CAM Visualization: The heatmap shows which parts of the "
        "image the model focused on to make its prediction. Red/yellow areas "
        "indicate high importance.\n\n"
    )
    return "".join(parts)


def predict(image: Image.Image):
    """Predict the class of an input image with Grad-CAM visualization.

    Args:
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        A 4-tuple ``(top3, html, overlay, heatmap)``: a dict mapping the
        three most likely class names to their probabilities, a result
        string for the HTML panel, and the Grad-CAM overlay and heatmap as
        PIL images. The two images are ``None`` when Grad-CAM fails; on a
        missing image or a prediction error the dict is empty and the string
        carries the message.
    """
    if image is None:
        return {}, "\n\nPlease upload an image first!\n\n", None, None

    try:
        # Prepare input
        img_tensor = preprocess(image.convert("RGB")).unsqueeze(0).to(device)

        # Get predictions
        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = torch.softmax(outputs, dim=1)[0].cpu().numpy()

        # Class indices ordered from most to least likely
        sorted_indices = np.argsort(probabilities)[::-1]

        top3_results = {
            CIFAR100_CLASSES[i]: float(probabilities[i])
            for i in sorted_indices[:3]
        }

        # Generate Grad-CAM visualization; a failure here must not hide the
        # prediction itself, so it is reported and the images are omitted.
        try:
            overlay, heatmap = apply_gradcam(
                image, model, gradcam, int(sorted_indices[0])
            )
            gradcam_overlay = Image.fromarray(overlay.astype(np.uint8))
            gradcam_heatmap = Image.fromarray(heatmap.astype(np.uint8))
        except Exception as e:
            print(f"Grad-CAM error: {e}")
            gradcam_overlay = None
            gradcam_heatmap = None

        # Create HTML output
        html_output = _render_results_html(probabilities, sorted_indices)

        return top3_results, html_output, gradcam_overlay, gradcam_heatmap

    except Exception as e:
        # UI boundary: surface the error in the panel instead of crashing.
        return {}, f"\n\nError during prediction: {str(e)}\n\n", None, None


# ---------------------------
# Model Information Section
# ---------------------------
model_description = """
## 🚀 About This Model

**Custom Lightweight ResNet trained on CIFAR-100 from scratch**

### 📊 Performance Metrics (100 Epochs)
- **Top-1 Accuracy:** 76.68% ✅ (Target: 73%)
- **Top-3 Accuracy:** 90.95%
- **Top-5 Accuracy:** 94.07%
- **Best Test Accuracy:** 76.79% (Epoch 99)
- **Macro F1-Score:** 0.7670
- **Dataset:** CIFAR-100 (50,000 train / 10,000 test)

### 🏗️ Architecture
- **Model:** Custom ResNet-34 variant (CIFAR-optimized)
- **Parameters:** 4,949,412 (~5M)
- **Depth:** 10 weight layers (1 stem + 8 residual + 1 FC)
- **Design:** 4 BasicBlocks with [1,1,1,1] configuration
- **Key Features:**
  - 3×3 initial conv (no 7×7, no MaxPool)
  - Progressive downsampling: 32×32 → 16×16 → 8×8 → 4×4
  - Channel expansion: 64 → 128 → 256 → 512
  - Receptive field: 63×63 (covers full 32×32 image)
  - Global Average Pooling + Linear(512 → 100)

### 🎯 Training Configuration
- **Optimizer:** SGD with Nesterov momentum (0.9)
- **LR Schedule:** OneCycle (0.01 → 0.1 → 0.01 → 0.001)
  - Phase 1: 41 epochs warmup
  - Phase 2: 41 epochs cooldown
  - Phase 3: 18 epochs annihilation
- **Augmentations:** Albumentations (Flip, ShiftScaleRotate, CoarseDropout, ColorJitter)
- **Regularization:** Weight decay (1e-4), Label smoothing (0.1)
- **Mixed Precision:** Enabled (AMP)
- **Batch Size:** 512

### 💡 CIFAR-100 Classes
100 fine-grained categories across 20 superclasses:
- 🦁 **Animals:** lion, tiger, elephant, whale, bear, leopard, wolf
- 🚗 **Vehicles:** pickup_truck, bus, train, streetcar, motorcycle, tractor
- 🌳 **Trees:** maple_tree, oak_tree, palm_tree, pine_tree, willow_tree
- 🌺 **Flowers:** rose, poppy, orchid, sunflower, tulip
- 🐠 **Aquatic:** aquarium_fish, flatfish, ray, shark, trout
- 🏠 **Structures:** house, castle, skyscraper, bridge, road
- 🍎 **Food:** apple, orange, pear, sweet_pepper, mushroom
- 👨 **People:** man, woman, baby, boy, girl
- 🪑 **Furniture:** bed, chair, couch, table, wardrobe
- And many more!

### 🏆 Best Performing Classes (F1-Score)
1. **Wardrobe** - 94.58%
2. **Sunflower** - 93.81%
3. **Poppy** - 93.15%
4. **Can** - 93.10%
5. **Skyscraper** - 91.00%

### 🚀 Deployment
- Trained without pre-trained weights
- Built with PyTorch and Albumentations
- Deployed on Hugging Face Spaces
- Inference optimized for CPU/GPU
"""

# Example images (curated selection from available examples)
MAX_EXAMPLES = 12  # Limit to 12 examples for clean UI

examples = []
example_dir = Path("examples")

if example_dir.is_dir():
    # Get one example from each category (prioritize _1.jpg files for consistency)
    priority_examples = [
        "lion_1.jpg", "tiger_1.jpg", "elephant_1.jpg", "bear_1.jpg",
        "pickup_truck_1.jpg", "bus_1.jpg", "train_1.jpg",
        "rose_1.jpg", "sunflower_1.jpg", "tulip_1.jpg",
        "apple_1.jpg", "orange_1.jpg", "pear_1.jpg",
        "castle_1.jpg", "skyscraper_1.jpg", "house_1.jpg",
    ]

    # Priority files first; any other "*_1.jpg" files only fill the remaining
    # slots. The glob is sorted so the selection is the same on every start.
    candidates = [example_dir / filename for filename in priority_examples]
    candidates.extend(sorted(example_dir.glob("*_1.jpg")))

    seen = set()
    for file_path in candidates:
        key = str(file_path)
        if key in seen or not file_path.exists():
            continue
        seen.add(key)
        examples.append([key])
        if len(examples) >= MAX_EXAMPLES:
            break

# ---------------------------
# Gradio UI
# ---------------------------
custom_css = """
.gradio-container {
    font-family: 'Inter', sans-serif;
}
.output-html {
    font-family: 'Inter', sans-serif;
}
"""

with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎯 CIFAR-100 Image Classifier with Grad-CAM")
    gr.Markdown("### Deep ResNet-34 trained on 100 object categories • Explainable AI with Grad-CAM heatmaps")

    with gr.Row():
        # Left Column: Input Image
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image", height=500)
            predict_btn = gr.Button("🚀 Classify Image", variant="primary", size="lg")
            gr.Markdown("Upload an image belonging to one of the CIFAR-100 categories.")

        # Right Column: Grad-CAM first, then Predictions
        with gr.Column(scale=1):
            # Grad-CAM visualizations at the top (side by side)
            with gr.Row():
                with gr.Column(scale=1):
                    gradcam_overlay_output = gr.Image(label="🔥 Grad-CAM Overlay", type="pil", height=200)
                    gr.Markdown("**Overlay** - Model attention on image")
                with gr.Column(scale=1):
                    gradcam_heatmap_output = gr.Image(label="🌡️ Grad-CAM Heatmap", type="pil", height=200)
                    gr.Markdown("**Heatmap** - Red = high importance")

            # Predictions below Grad-CAM
            label_output = gr.Label(num_top_classes=3, label="Top 3 Predictions")
            html_output = gr.HTML(label="Detailed Results")

    # Every trigger below fills the same four components, in predict()'s
    # return order.
    prediction_outputs = [
        label_output,
        html_output,
        gradcam_overlay_output,
        gradcam_heatmap_output,
    ]

    if examples:
        gr.Examples(
            examples=examples,
            inputs=image_input,
            outputs=prediction_outputs,
            fn=predict,
            cache_examples=False,
        )

    with gr.Accordion("📖 Model Information & Performance Metrics", open=False):
        gr.Markdown(model_description)

    predict_btn.click(
        fn=predict,
        inputs=image_input,
        outputs=prediction_outputs,
    )

    image_input.change(
        fn=predict,
        inputs=image_input,
        outputs=prediction_outputs,
    )

# ---------------------------
# Launch
# ---------------------------
if __name__ == "__main__":
    demo.launch()