"""
CIFAR-100 Image Classification App
Deployed on Hugging Face Spaces with Gradio

Author: Krishnakanth
Date: 2025-10-10
"""

import html

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
from typing import Dict, Tuple, List
import torchvision.transforms as transforms
import plotly.graph_objects as go

# Import model architecture
from model import CIFAR100ResNet34, ModelConfig

# CIFAR-100 class names (index i corresponds to model output i)
CIFAR100_CLASSES = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle',
    'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle',
    'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard',
    'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
    'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree',
    'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket',
    'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider',
    'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor',
    'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm',
]

# CIFAR-100 normalization values (per-channel RGB mean / std)
CIFAR100_MEAN = (0.5071, 0.4867, 0.4408)
CIFAR100_STD = (0.2675, 0.2565, 0.2761)

# Number of ranked predictions returned by predict() and shown in the Label panel
TOP_K = 10

# Global variables for model; populated by load_model() only after a successful load
model = None
device = None


def load_model(model_path: str = "cifar100_model.pth") -> bool:
    """
    Load the trained CIFAR-100 model into the module-level ``model`` / ``device``.

    Args:
        model_path: Path to a checkpoint that is either a raw state dict or a dict
            holding a ``'model_state_dict'`` entry (optionally with ``'metrics'``).

    Returns:
        True if the weights were loaded, False otherwise. On failure the global
        ``model`` is left untouched, so ``predict`` keeps refusing to run instead
        of silently serving an untrained network.
    """
    global model, device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create model configuration
    config = ModelConfig(
        input_channels=3,
        input_size=(32, 32),
        num_classes=100,
        dropout_rate=0.05,
    )

    # Build into a local name; it is published globally only once the weights are in.
    net = CIFAR100ResNet34(config)

    try:
        # PyTorch 2.6+ requires weights_only=False for checkpoints with custom classes.
        # SECURITY: weights_only=False unpickles arbitrary objects. This is acceptable
        # only because the checkpoint ships with the app — never point it at a
        # user-supplied file.
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)

        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            net.load_state_dict(checkpoint['model_state_dict'])
            print(f"✅ Model loaded with metrics: {checkpoint.get('metrics', {})}")
        else:
            net.load_state_dict(checkpoint)

        net.to(device)
        net.eval()
    except Exception as e:
        print(f"❌ Error loading model: {str(e)}")
        return False

    model = net
    total_params = sum(p.numel() for p in model.parameters())
    print(f"✅ Model loaded successfully on {device}")
    print(f" Total parameters: {total_params:,}")
    return True


def get_transform():
    """Return the preprocessing pipeline: resize to 32x32, to tensor, CIFAR-100 normalize."""
    return transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
    ])


def preprocess_image(image: Image.Image) -> torch.Tensor:
    """
    Preprocess a PIL image for model input.

    Args:
        image: Any-size PIL image; non-RGB modes (e.g. RGBA, L) are converted to RGB.

    Returns:
        A tensor with a leading batch dimension of 1, as produced by ``get_transform``.
    """
    # Convert to RGB if necessary
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Apply transformations, then add the batch dimension
    image_tensor = get_transform()(image)
    return image_tensor.unsqueeze(0)


def _confidence_badge(confidence_pct: float) -> Tuple[str, str, str]:
    """Map a top-1 confidence percentage to (emoji, label, CSS color)."""
    if confidence_pct > 70:
        return "✅", "High Confidence", "#28a745"
    if confidence_pct > 40:
        return "⚠️", "Medium Confidence", "#ffc107"
    return "❌", "Low Confidence", "#dc3545"


def predict(image: Image.Image) -> Tuple[Dict[str, float], str, str]:
    """
    Make a prediction on an image.

    Args:
        image: PIL image from the Gradio input, or None when nothing was uploaded.

    Returns:
        - Dictionary of the top-``TOP_K`` predictions {display class name: probability}
        - HTML snippet describing the main prediction (or an error message)
        - Empty string for the hidden placeholder output
    """
    if model is None:
        return {}, "❌ Model not loaded", ""

    # Clicking "Classify" with an empty input passes None; report it clearly instead
    # of surfacing an AttributeError from preprocessing.
    if image is None:
        return {}, "⚠️ Please upload an image first.", ""

    try:
        # Preprocess image
        image_tensor = preprocess_image(image).to(device)

        with torch.no_grad():
            output = model(image_tensor)
            # softmax is correct whether the model emits raw logits or
            # log-probabilities: softmax(log p) == p when p already sums to 1.
            probabilities = F.softmax(output, dim=1)

        # Get top-k predictions for the single image in the batch
        top_probs, top_indices = torch.topk(probabilities, TOP_K, dim=1)
        top_probs = top_probs[0].cpu().numpy()
        top_indices = top_indices[0].cpu().numpy()

        # Get predicted class
        predicted_class = CIFAR100_CLASSES[top_indices[0]]
        confidence_pct = float(top_probs[0]) * 100

        # Results dictionary for the Gradio Label output
        results_dict = {
            CIFAR100_CLASSES[idx].replace('_', ' ').title(): float(prob)
            for idx, prob in zip(top_indices, top_probs)
        }

        # Formatted main prediction card, colored by confidence tier
        conf_emoji, conf_text, color = _confidence_badge(confidence_pct)
        display_name = predicted_class.replace('_', ' ').upper()
        main_prediction = f"""
<div style="text-align: center; padding: 20px; border-radius: 10px; border: 2px solid {color};">
    <p style="margin: 0; font-size: 1.1em;">Predicted Class</p>
    <h2 style="margin: 8px 0; color: {color};">{display_name}</h2>
    <p style="margin: 0; font-size: 1.4em;">{conf_emoji} {confidence_pct:.2f}%</p>
    <p style="margin: 4px 0 0; color: {color};">{conf_text}</p>
</div>
"""
        return results_dict, main_prediction, ""

    except Exception as e:
        # The message is rendered by gr.HTML, so escape it before display.
        error_msg = f"❌ Error during prediction: {html.escape(str(e))}"
        return {}, error_msg, ""


def create_interface():
    """Build and return the Gradio Blocks UI wired to ``predict``."""
    # Custom CSS
    custom_css = """
    .gradio-container {
        font-family: 'Arial', sans-serif;
    }
    .gr-button-primary {
        background: linear-gradient(90deg, #1f77b4, #9467bd) !important;
        border: none !important;
    }
    footer {
        visibility: hidden;
    }
    """

    # Create interface
    with gr.Blocks(css=custom_css, title="CIFAR-100 Classifier", theme=gr.themes.Soft()) as demo:
        # Header
        gr.Markdown("""
        # 🖼️ CIFAR-100 Image Classifier

        Upload an image and the AI will classify it into one of **100 different categories** with confidence scores.

        Built with **PyTorch ResNet-34** architecture.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                # Input section
                gr.Markdown("## 📤 Upload Image")
                image_input = gr.Image(
                    type="pil",
                    label="Upload an image",
                    sources=["upload", "clipboard", "webcam"],
                    height=400,
                )
                predict_btn = gr.Button("🔍 Classify Image", variant="primary", size="lg")

                # Examples section with two categories
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### ✅ Good Predictions Examples")
                        gr.Examples(
                            examples=[
                                "apple_s_000028.png",
                                "breakfast_table_s_000178.png",
                                "cichlid_fish_s_000888.png",
                            ],
                            inputs=image_input,
                            label="High Confidence Predictions",
                        )
                    with gr.Column():
                        gr.Markdown("### ⚠️ Challenging Predictions Examples")
                        gr.Examples(
                            examples=[
                                "crocodile_s_000018.png",
                                "boy_s_000005.png",
                                "armchair_s_000853.png",
                            ],
                            inputs=image_input,
                            label="Lower Confidence Predictions",
                        )

                gr.Markdown("""
                ### 💡 Tips for Best Results
                - Use clear, well-lit images
                - Center the main object
                - Any size works (auto-resized to 32×32)
                - Supported: JPG, PNG, BMP, WEBP
                """)

            with gr.Column(scale=1):
                # Output section
                gr.Markdown("## 🎯 Classification Results")
                main_output = gr.HTML(label="Main Prediction")
                gr.Markdown("### 📊 Top 10 Predictions")
                label_output = gr.Label(
                    num_top_classes=TOP_K,
                    label="Confidence Scores",
                    show_label=False,
                )

        # Additional info section
        with gr.Row():
            with gr.Column():
                gr.Markdown("""
                ### 🤖 Model Information
                - **Architecture**: ResNet-34 with Bottleneck Layers
                - **Parameters**: ~21 Million
                - **Dataset**: CIFAR-100 (60,000 images, 100 classes)
                - **Input Size**: 32×32 RGB images
                - **Categories**: Animals, vehicles, household items, nature scenes, and more
                """)
            with gr.Column():
                # 27 categories are listed below, hence "73 more".
                gr.Markdown("""
                ### 📚 Sample Categories
                **Animals**: bear, dolphin, elephant, fox, leopard, tiger, whale

                **Vehicles**: bicycle, bus, motorcycle, train, tractor

                **Nature**: cloud, forest, mountain, sea, plain

                **Objects**: chair, clock, lamp, telephone, keyboard

                **Plants**: maple_tree, oak_tree, orchid, rose, sunflower

                *...and 73 more categories!*
                """)

        # Footer
        gr.Markdown("""
        ---

        Built with ❤️ using PyTorch, Gradio, and Hugging Face Spaces

        Model: ResNet-34 trained on CIFAR-100 dataset

        Created by Krishnakanth | © 2025
        """)

        # Hidden sink for predict()'s third (always empty) return value
        hidden_output = gr.Textbox(visible=False)

        # Connect button to prediction function
        predict_btn.click(
            fn=predict,
            inputs=image_input,
            outputs=[label_output, main_output, hidden_output],
        )

    return demo


def main():
    """Load the model, build the interface, and launch the Gradio server."""
    print("=" * 60)
    print("🖼️ CIFAR-100 Image Classifier")
    print("=" * 60)

    # Load model
    print("\n📦 Loading model...")
    success = load_model("cifar100_model.pth")
    if not success:
        print("❌ Failed to load model. Please check if cifar100_model.pth exists.")
        return

    print("\n🚀 Creating Gradio interface...")

    # Create and launch interface
    demo = create_interface()
    print("\n✅ Interface created successfully!")
    print("=" * 60)

    # Launch app
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )


if __name__ == "__main__":
    main()