""" CIFAR-100 Image Classification App Deployed on Hugging Face Spaces with Gradio Author: Krishnakanth Date: 2025-10-10 """ import gradio as gr import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image import numpy as np from typing import Dict, Tuple, List import torchvision.transforms as transforms import plotly.graph_objects as go # Import model architecture from model import CIFAR100ResNet34, ModelConfig # CIFAR-100 class names CIFAR100_CLASSES = [ 'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm' ] # CIFAR-100 normalization values CIFAR100_MEAN = (0.5071, 0.4867, 0.4408) CIFAR100_STD = (0.2675, 0.2565, 0.2761) # Global variables for model model = None device = None def load_model(model_path: str = "cifar100_model.pth"): """Load the trained CIFAR-100 model.""" global model, device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Create model configuration config = ModelConfig( input_channels=3, input_size=(32, 32), num_classes=100, dropout_rate=0.05 ) # Initialize model model = CIFAR100ResNet34(config) # Load trained weights try: # PyTorch 2.6+ requires weights_only=False for checkpoints with custom classes checkpoint = torch.load(model_path, map_location=device, weights_only=False) if 'model_state_dict' in checkpoint: model.load_state_dict(checkpoint['model_state_dict']) print(f"✅ Model loaded with metrics: {checkpoint.get('metrics', {})}") else: model.load_state_dict(checkpoint) model.to(device) model.eval() total_params = sum(p.numel() for p in model.parameters()) print(f"✅ Model loaded successfully on {device}") print(f" Total parameters: {total_params:,}") return True except Exception as e: print(f"❌ Error loading model: {str(e)}") return False def get_transform(): """Get image transformation pipeline.""" return transforms.Compose([ transforms.Resize((32, 32)), transforms.ToTensor(), transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD) ]) def preprocess_image(image: Image.Image) -> torch.Tensor: """Preprocess image for model input.""" # Convert to RGB if necessary if image.mode != 'RGB': image = image.convert('RGB') # Apply transformations transform = get_transform() image_tensor = transform(image) # Add batch dimension image_tensor = image_tensor.unsqueeze(0) return image_tensor def predict(image: Image.Image) -> Tuple[Dict[str, float], str, str]: """ Make prediction on image. Returns: - Dictionary of top predictions {class: probability} - HTML formatted main prediction - Plotly chart (not used in Gradio, for reference) """ if model is None: return {}, "❌ Model not loaded", "" try: # Preprocess image image_tensor = preprocess_image(image) # Make prediction with torch.no_grad(): image_tensor = image_tensor.to(device) # Get model output (log probabilities) output = model(image_tensor) # Convert to probabilities probabilities = torch.exp(output) # Get top-10 predictions top_probs, top_indices = torch.topk(probabilities, 10, dim=1) top_probs = top_probs[0].cpu().numpy() top_indices = top_indices[0].cpu().numpy() # Get predicted class predicted_class = CIFAR100_CLASSES[top_indices[0]] confidence = top_probs[0] # Create results dictionary for Gradio Label output results_dict = {} for idx, prob in zip(top_indices, top_probs): class_name = CIFAR100_CLASSES[idx].replace('_', ' ').title() results_dict[class_name] = float(prob) # Create formatted output confidence_pct = confidence * 100 if confidence_pct > 70: conf_emoji = "✅" conf_text = "High Confidence" color = "#28a745" elif confidence_pct > 40: conf_emoji = "⚠️" conf_text = "Medium Confidence" color = "#ffc107" else: conf_emoji = "❌" conf_text = "Low Confidence" color = "#dc3545" main_prediction = f"""
{conf_text}
Built with ❤️ using PyTorch, Gradio, and Hugging Face Spaces
Model: ResNet-34 trained on CIFAR-100 dataset
Created by Krishnakanth | © 2025