"""
CIFAR-100 Image Classifier - Hugging Face Space
===============================================
Advanced ResNet-34 model trained on CIFAR-100.

Architecture: ResNet-34
- 100 output classes across diverse object categories
- Deep residual learning for robust feature extraction
- Trained with Albumentations augmentations
"""
|
|
import html
from pathlib import Path
from typing import Optional

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from model_cifar import ResNet34
from preprocess import CIFAR100_MEAN, CIFAR100_STD
|
|
| |
# Human-readable fine-label names, sorted alphabetically. predict() indexes
# this list with the model's arg-max logit, so position i must name logit i.
CIFAR100_CLASSES = """
    apple aquarium_fish baby bear beaver bed bee beetle bicycle bottle
    bowl boy bridge bus butterfly camel can castle caterpillar cattle
    chair chimpanzee clock cloud cockroach couch crab crocodile cup dinosaur
    dolphin elephant flatfish forest fox girl hamster house kangaroo keyboard
    lamp lawn_mower leopard lion lizard lobster man maple_tree motorcycle mountain
    mouse mushroom oak_tree orange orchid otter palm_tree pear pickup_truck pine_tree
    plain plate poppy porcupine possum rabbit raccoon ray road rocket
    rose sea seal shark shrew skunk skyscraper snail snake spider
    squirrel streetcar sunflower sweet_pepper table tank telephone television tiger tractor
    train trout tulip turtle wardrobe whale willow_tree wolf woman worm
""".split()


# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
| |
| |
| |
def load_model(checkpoint_path: Optional[str] = None):
    """Build the CIFAR-100 ResNet34 and optionally load trained weights.

    Reported accuracy of the trained checkpoint: 76.68%.

    Args:
        checkpoint_path: Path to a checkpoint file. It may hold either a raw
            state dict or a dict with a ``'model_state_dict'`` entry (and
            optionally ``'epoch'``). ``None``/empty or a path that does not
            exist keeps the randomly initialised weights.

    Returns:
        The model on ``device`` in eval mode. If loading fails, the error is
        printed and the randomly initialised model is returned instead
        (best-effort so the demo still starts).
    """
    model = ResNet34(num_classes=100).to(device)

    if not checkpoint_path:
        print("ℹ️ No checkpoint provided, using randomly initialized model")
    elif not Path(checkpoint_path).exists():
        # Previously reported as "No checkpoint provided", which hid typos in the path.
        print(f"ℹ️ Checkpoint not found at {checkpoint_path}, using randomly initialized model")
    else:
        try:
            # NOTE: torch.load unpickles the file - only load trusted checkpoints.
            # NOTE(review): newer PyTorch defaults to weights_only=True; a checkpoint
            # holding non-tensor objects would then fail here and fall back to
            # random weights - confirm against the deployed torch version.
            checkpoint = torch.load(checkpoint_path, map_location=device)
            # Training snapshots wrap the weights; plain files are the state dict itself.
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
                print(f"✅ Loaded checkpoint from epoch {checkpoint.get('epoch', '?')}")
            else:
                model.load_state_dict(checkpoint)
                print(f"✅ Loaded model weights from {checkpoint_path}")
        except Exception as e:
            print(f"⚠️ Could not load checkpoint: {e}")
            print("Using randomly initialized model")

    model.eval()
    return model
|
|
|
|
print(f"Device: {device}")

# Candidate checkpoints in priority order. The trailing None is the fallback:
# load_model(None) returns a randomly initialised model.
checkpoint_paths = [
    "./best_model.pth",
    "./snapshots_complete/cifar_epoch_99.pth",
    None,
]

# Pick the first candidate present on disk. None always qualifies, so a model
# is always built (the old "if model is None" fallback was unreachable).
# NOTE: if the chosen file exists but fails to load, load_model falls back to
# random weights rather than trying the next candidate.
checkpoint_path = next(
    (path for path in checkpoint_paths if path is None or Path(path).exists()),
    None,
)
model = load_model(checkpoint_path)
|
|
| |
| |
| |
def _base_steps():
    """Return fresh resize + to-tensor steps shared by both pipelines."""
    return [transforms.Resize((32, 32)), transforms.ToTensor()]


# Model input pipeline: resize to exactly 32x32 (aspect ratio is not kept),
# convert to a tensor, then normalise with the CIFAR-100 channel statistics.
preprocess = transforms.Compose(
    _base_steps() + [transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD)]
)

# Same geometry and tensor conversion, without normalisation.
preprocess_no_norm = transforms.Compose(_base_steps())
|
|
|
|
| |
| |
| |
class GradCAM:
    """Grad-CAM: Visual Explanations from Deep Networks.

    Hooks ``target_layer`` to capture its forward output (activations) and the
    gradient flowing back into that output. It then weights each activation
    channel by its spatially averaged gradient to produce a class-specific
    saliency map.

    Args:
        model: Classifier whose forward returns per-class scores for a batch.
        target_layer: Sub-module of ``model`` whose output is explained. The
            indexing in ``generate_cam`` requires a 4-D output, presumably
            (batch, channels, H, W).
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Keep the handles so the hooks can be detached later via remove().
        self._handles = [
            target_layer.register_forward_hook(self.save_activation),
            target_layer.register_full_backward_hook(self.save_gradient),
        ]

    def save_activation(self, module, input, output):
        """Forward hook: cache the target layer's output (detached)."""
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: cache the gradient w.r.t. the target layer's output."""
        self.gradients = grad_output[0].detach()

    def remove(self):
        """Detach the hooks registered on the target layer."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def generate_cam(self, input_tensor, target_class=None):
        """Compute a Grad-CAM heatmap for the first sample of ``input_tensor``.

        Args:
            input_tensor: Model input batch; only sample 0 is explained.
            target_class: Class index to explain. Defaults to the arg-max score.

        Returns:
            ``(cam, target_class)``. ``cam`` is a 2-D numpy array at the target
            layer's spatial resolution, shifted to a minimum of 0 and scaled so
            that its maximum is 1 (all zeros when there is no positive
            evidence). ``target_class`` is the index that was explained.

        Raises:
            RuntimeError: If the hooks captured no activations or gradients
                (for example after ``remove()``).
        """
        # Clear previous captures so a hook that did not fire is detected below
        # instead of silently reusing data from an earlier image.
        self.gradients = None
        self.activations = None

        # Grad-CAM needs autograd even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            model_output = self.model(input_tensor)

            if target_class is None:
                target_class = model_output.argmax(dim=1).item()

            # Back-propagate d(score[target_class]) / d(everything) for sample 0.
            self.model.zero_grad()
            one_hot = torch.zeros_like(model_output)
            one_hot[0, target_class] = 1
            model_output.backward(gradient=one_hot)

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "Grad-CAM hooks did not capture activations/gradients; "
                "is target_layer part of model and are the hooks still attached?"
            )

        gradients = self.gradients[0]
        activations = self.activations[0]

        # One weight per channel: the gradient averaged over spatial positions.
        weights = gradients.mean(dim=(1, 2), keepdim=True)

        # Weighted channel sum; ReLU keeps only evidence *for* the class.
        cam = F.relu((weights * activations).sum(dim=0))

        cam = cam - cam.min()
        if cam.max() > 0:
            cam = cam / cam.max()

        return cam.cpu().numpy(), target_class
|
|
|
|
def apply_gradcam(image_pil, model, gradcam, top_class_idx, output_size=200):
    """Run Grad-CAM for one class and render it over the input image.

    Args:
        image_pil: Input PIL image in any mode; it is converted to RGB here.
        model: The classifier. Not used directly (``gradcam`` holds its own
            reference); kept for interface compatibility.
        gradcam: Object exposing ``generate_cam(tensor, target_class=...)``.
        top_class_idx: Class index whose evidence is visualised.
        output_size: Edge length in pixels of the returned square images.

    Returns:
        ``(overlay, heatmap)``: uint8 RGB arrays of shape
        ``(output_size, output_size, 3)``. ``overlay`` blends the resized input
        image (60%) with the JET-coloured heatmap (40%).
    """
    # Convert once. RGBA, greyscale or palette uploads would otherwise give a
    # 4- or 1-channel array that cv2.addWeighted cannot blend with the
    # 3-channel heatmap, making Grad-CAM fail for those images.
    rgb_image = image_pil.convert("RGB")
    img_tensor = preprocess(rgb_image).unsqueeze(0).to(device)

    cam, _ = gradcam.generate_cam(img_tensor, target_class=top_class_idx)

    # Upsample the coarse CAM directly to the display size. It is aligned with
    # the squashed square model input, so the image is resized the same way.
    cam_resized = cv2.resize(cam, (output_size, output_size), interpolation=cv2.INTER_LINEAR)
    heatmap = cv2.applyColorMap(np.uint8(255 * np.clip(cam_resized, 0.0, 1.0)), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # OpenCV colormaps are BGR

    # Blend at display resolution rather than upscaling a 32x32 blend, which
    # discarded almost all detail of the uploaded image.
    img_np = np.asarray(rgb_image.resize((output_size, output_size)))
    overlay = cv2.addWeighted(img_np, 0.6, heatmap, 0.4, 0)

    return overlay, heatmap
|
|
|
|
| |
# Shared Grad-CAM explainer, hooked on `conv2` of the last block in `layer4`.
# Presumably this is the deepest conv before pooling - confirm in model_cifar.
gradcam = GradCAM(model=model, target_layer=model.layer4[-1].conv2)
|
|
|
|
| |
| |
| |
def _gradcam_images(rgb_image, class_idx):
    """Render Grad-CAM for ``class_idx`` as PIL images, or ``(None, None)`` on failure.

    Grad-CAM is auxiliary output: an error here is logged and must not hide the
    classification result.
    """
    try:
        overlay, heatmap = apply_gradcam(rgb_image, model, gradcam, class_idx)
        return (
            Image.fromarray(overlay.astype(np.uint8)),
            Image.fromarray(heatmap.astype(np.uint8)),
        )
    except Exception as e:
        print(f"Grad-CAM error: {e}")
        return None, None


def _build_result_html(probabilities, sorted_indices):
    """Render the headline prediction and a top-5 probability bar chart as HTML.

    Args:
        probabilities: 1-D array of per-class softmax probabilities.
        sorted_indices: Class indices ordered from most to least likely.
    """
    top_idx = sorted_indices[0]
    predicted_class = CIFAR100_CLASSES[top_idx]
    confidence = probabilities[top_idx]

    # NOTE(review): the emoji prefixes in these headings look mis-encoded
    # (mojibake); restore them from the original source.
    parts = [f"""
<div style='padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-radius: 10px; color: white; box-shadow: 0 4px 6px rgba(0,0,0,0.1);'>
<h2>π― Prediction Result</h2>
<div style='font-size: 24px; font-weight: bold;'>{predicted_class.upper()}</div>
<div style='font-size: 18px;'>Confidence: <strong>{confidence*100:.2f}%</strong></div>
</div>
<div style='margin-top: 20px; background: #f8f9fa; border-radius: 8px; padding: 15px;'>
<h3>π Top 5 Predictions</h3>
"""]
    for rank, idx in enumerate(sorted_indices[:5], 1):
        name = CIFAR100_CLASSES[idx]
        prob = probabilities[idx]
        bar_width = int(prob * 100)
        # Highlight ranks 1 and 2; the rest share a neutral grey.
        color = "#667eea" if rank == 1 else ("#764ba2" if rank == 2 else "#95a5a6")
        parts.append(f"""
<div style='margin: 8px 0;'>
<div style='display: flex; justify-content: space-between;'>
<span>{rank}. {name}</span>
<span style='font-weight:bold; color:{color}'>{prob*100:.2f}%</span>
</div>
<div style='background:#e9ecef; border-radius:4px; height:20px;'>
<div style='width:{bar_width}%; background:{color}; height:100%; border-radius:4px;'></div>
</div>
</div>
""")
    parts.append("""
</div>
<div style='margin-top: 15px; padding: 10px; background: #e8f4f8; border-left: 4px solid #667eea; border-radius: 4px;'>
<p style='margin: 0; color: #333;'><strong>π‘ Grad-CAM Visualization:</strong> The heatmap shows which parts of the image the model focused on to make its prediction. Red/yellow areas indicate high importance.</p>
</div>
""")
    return "".join(parts)


def predict(image: Image.Image):
    """Predict the class of an input image with Grad-CAM visualization.

    Returns:
        A 4-tuple matching the Gradio outputs: a top-3 ``{label: probability}``
        dict, the result HTML, and the Grad-CAM overlay and heatmap PIL images
        (each ``None`` when Grad-CAM fails). A missing image or an error yields
        an empty dict, an error message and two ``None`` values.
    """
    if image is None:
        return {}, "<p style='color: red;'>Please upload an image first!</p>", None, None

    try:
        # Normalise the mode once so classification and Grad-CAM see the same
        # 3-channel image (RGBA or greyscale uploads used to break the overlay).
        rgb_image = image.convert("RGB")
        img_tensor = preprocess(rgb_image).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = torch.softmax(outputs, dim=1)[0].cpu().numpy()

        # Class indices from most to least likely.
        sorted_indices = np.argsort(probabilities)[::-1]

        top3_results = {
            CIFAR100_CLASSES[i]: float(probabilities[i])
            for i in sorted_indices[:3]
        }

        gradcam_overlay, gradcam_heatmap = _gradcam_images(rgb_image, sorted_indices[0])
        html_output = _build_result_html(probabilities, sorted_indices)

        return top3_results, html_output, gradcam_overlay, gradcam_heatmap

    except Exception as e:
        print(f"Prediction error: {e}")
        # Escape: the exception text is interpolated into HTML.
        return {}, f"<p style='color: red;'>Error during prediction: {html.escape(str(e))}</p>", None, None
|
|
|
|
| |
| |
| |
# Markdown rendered inside the "Model Information & Performance Metrics"
# accordion of the UI.
# NOTE(review): the architecture claims here disagree with each other and with
# the module docstring and the ResNet34 import. "Custom Lightweight ResNet",
# "4 BasicBlocks with [1,1,1,1]" and "10 weight layers" describe a much
# shallower net than a standard ResNet-34 ([3, 4, 6, 3] blocks). Confirm
# against model_cifar.ResNet34 before publishing.
# NOTE(review): emoji and symbols (e.g. "Γ", "β", "π…") look mis-encoded
# (UTF-8 decoded with a legacy code page), and the "(Target: 73%)" line break
# looks like a corrupted character. Nested bullets may also have lost their
# indentation. Restore from the original source.
model_description = """
## π About This Model
**Custom Lightweight ResNet trained on CIFAR-100 from scratch**

### π Performance Metrics (100 Epochs)
- **Top-1 Accuracy:** 76.68% β
(Target: 73%)
- **Top-3 Accuracy:** 90.95%
- **Top-5 Accuracy:** 94.07%
- **Best Test Accuracy:** 76.79% (Epoch 99)
- **Macro F1-Score:** 0.7670
- **Dataset:** CIFAR-100 (50,000 train / 10,000 test)

### ποΈ Architecture
- **Model:** Custom ResNet-34 variant (CIFAR-optimized)
- **Parameters:** 4,949,412 (~5M)
- **Depth:** 10 weight layers (1 stem + 8 residual + 1 FC)
- **Design:** 4 BasicBlocks with [1,1,1,1] configuration
- **Key Features:**
- 3Γ3 initial conv (no 7Γ7, no MaxPool)
- Progressive downsampling: 32Γ32 β 16Γ16 β 8Γ8 β 4Γ4
- Channel expansion: 64 β 128 β 256 β 512
- Receptive field: 63Γ63 (covers full 32Γ32 image)
- Global Average Pooling + Linear(512 β 100)

### π― Training Configuration
- **Optimizer:** SGD with Nesterov momentum (0.9)
- **LR Schedule:** OneCycle (0.01 β 0.1 β 0.01 β 0.001)
- Phase 1: 41 epochs warmup
- Phase 2: 41 epochs cooldown
- Phase 3: 18 epochs annihilation
- **Augmentations:** Albumentations (Flip, ShiftScaleRotate, CoarseDropout, ColorJitter)
- **Regularization:** Weight decay (1e-4), Label smoothing (0.1)
- **Mixed Precision:** Enabled (AMP)
- **Batch Size:** 512

### π‘ CIFAR-100 Classes
100 fine-grained categories across 20 superclasses:
- π¦ **Animals:** lion, tiger, elephant, whale, bear, leopard, wolf
- π **Vehicles:** pickup_truck, bus, train, streetcar, motorcycle, tractor
- π³ **Trees:** maple_tree, oak_tree, palm_tree, pine_tree, willow_tree
- πΊ **Flowers:** rose, poppy, orchid, sunflower, tulip
- π **Aquatic:** aquarium_fish, flatfish, ray, shark, trout
- π **Structures:** house, castle, skyscraper, bridge, road
- π **Food:** apple, orange, pear, sweet_pepper, mushroom
- π¨ **People:** man, woman, baby, boy, girl
- πͺ **Furniture:** bed, chair, couch, table, wardrobe
- And many more!

### π Best Performing Classes (F1-Score)
1. **Wardrobe** - 94.58%
2. **Sunflower** - 93.81%
3. **Poppy** - 93.15%
4. **Can** - 93.10%
5. **Skyscraper** - 91.00%

### π Deployment
- Trained without pre-trained weights
- Built with PyTorch and Albumentations
- Deployed on Hugging Face Spaces
- Inference optimized for CPU/GPU
"""
|
|
| |
MAX_EXAMPLES = 12


def collect_examples(example_dir, priority, limit=MAX_EXAMPLES):
    """Return up to ``limit`` example images as Gradio ``[[path], ...]`` rows.

    Files named in ``priority`` come first, in that order, skipping any that
    are missing. Remaining slots are filled with other ``*_1.jpg`` files in
    sorted order, so the example gallery is stable across filesystems (raw
    glob order is not guaranteed).

    Args:
        example_dir: Directory to search; a missing directory yields ``[]``.
        priority: Preferred file names, relative to ``example_dir``.
        limit: Maximum number of examples to return.
    """
    example_dir = Path(example_dir)
    if not example_dir.is_dir():
        return []

    chosen = []
    seen = set()  # O(1) duplicate check instead of rebuilding a list per file

    def _add(path):
        key = str(path)
        if key not in seen:
            seen.add(key)
            chosen.append([key])

    for filename in priority:
        if len(chosen) >= limit:
            break
        candidate = example_dir / filename
        if candidate.is_file():
            _add(candidate)

    for candidate in sorted(example_dir.glob("*_1.jpg")):
        if len(chosen) >= limit:
            break
        if candidate.is_file():
            _add(candidate)

    return chosen


example_dir = Path("examples")

# Hand-picked showcase images, shown first when present.
priority_examples = [
    "lion_1.jpg", "tiger_1.jpg", "elephant_1.jpg", "bear_1.jpg",
    "pickup_truck_1.jpg", "bus_1.jpg", "train_1.jpg",
    "rose_1.jpg", "sunflower_1.jpg", "tulip_1.jpg",
    "apple_1.jpg", "orange_1.jpg", "pear_1.jpg",
    "castle_1.jpg", "skyscraper_1.jpg", "house_1.jpg",
]

examples = collect_examples(example_dir, priority_examples)
|
|
|
|
| |
| |
| |
# Page-wide stylesheet: use the Inter font for the app container and HTML output.
custom_css = (
    "\n"
    ".gradio-container { font-family: 'Inter', sans-serif; }\n"
    ".output-html { font-family: 'Inter', sans-serif; }\n"
)
|
|
# Gradio UI: upload column on the left; Grad-CAM images and predictions on the right.
# NOTE(review): the container nesting below was reconstructed because the source
# lost its indentation - confirm which Row/Column each component belongs to.
# NOTE(review): emoji in the labels look mis-encoded (mojibake); restore from the original.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# π― CIFAR-100 Image Classifier with Grad-CAM")
    gr.Markdown("### Deep ResNet-34 trained on 100 object categories β’ Explainable AI with Grad-CAM heatmaps")

    with gr.Row():
        # Left column: image upload, classify button and a usage hint.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image", height=500)
            predict_btn = gr.Button("π Classify Image", variant="primary", size="lg")
            gr.Markdown("Upload an image belonging to one of the CIFAR-100 categories.")

        # Right column: Grad-CAM overlay/heatmap side by side, then the predictions.
        with gr.Column(scale=1):
            with gr.Row():
                with gr.Column(scale=1):
                    gradcam_overlay_output = gr.Image(label="π₯ Grad-CAM Overlay", type="pil", height=200)
                    gr.Markdown("**Overlay** - Model attention on image")

                with gr.Column(scale=1):
                    gradcam_heatmap_output = gr.Image(label="π‘οΈ Grad-CAM Heatmap", type="pil", height=200)
                    gr.Markdown("**Heatmap** - Red = high importance")

            # Output order here must match predict()'s 4-tuple return.
            label_output = gr.Label(num_top_classes=3, label="Top 3 Predictions")
            html_output = gr.HTML(label="Detailed Results")

    # Example gallery, only rendered when example images were found on disk.
    if examples:
        gr.Examples(
            examples=examples,
            inputs=image_input,
            outputs=[label_output, html_output, gradcam_overlay_output, gradcam_heatmap_output],
            fn=predict,
            cache_examples=False,
        )

    with gr.Accordion("π Model Information & Performance Metrics", open=False):
        gr.Markdown(model_description)

    # predict runs both on the button and whenever the image value changes.
    # predict() returns a prompt message when the image is None (e.g. after clearing).
    predict_btn.click(
        fn=predict,
        inputs=image_input,
        outputs=[label_output, html_output, gradcam_overlay_output, gradcam_heatmap_output]
    )
    image_input.change(
        fn=predict,
        inputs=image_input,
        outputs=[label_output, html_output, gradcam_overlay_output, gradcam_heatmap_output]
    )


# Launch the web app only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()
|
|
|
|