|
|
""" |
|
|
CIFAR-100 Image Classification App |
|
|
Deployed on Hugging Face Spaces with Gradio |
|
|
|
|
|
Author: Krishnakanth |
|
|
Date: 2025-10-10 |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
from typing import Dict, Tuple, List |
|
|
import torchvision.transforms as transforms |
|
|
import plotly.graph_objects as go |
|
|
|
|
|
|
|
|
from model import CIFAR100ResNet34, ModelConfig |
|
|
|
|
|
|
|
|
CIFAR100_CLASSES = [ |
|
|
'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', |
|
|
'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', |
|
|
'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', |
|
|
'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', |
|
|
'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', |
|
|
'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', |
|
|
'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', |
|
|
'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', |
|
|
'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', |
|
|
'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', |
|
|
'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', |
|
|
'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', |
|
|
'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', |
|
|
'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm' |
|
|
] |
|
|
|
|
|
|
|
|
CIFAR100_MEAN = (0.5071, 0.4867, 0.4408) |
|
|
CIFAR100_STD = (0.2675, 0.2565, 0.2761) |
|
|
|
|
|
|
|
|
model = None |
|
|
device = None |
|
|
|
|
|
|
|
|
def load_model(model_path: str = "cifar100_model.pth"): |
|
|
"""Load the trained CIFAR-100 model.""" |
|
|
global model, device |
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
|
|
|
config = ModelConfig( |
|
|
input_channels=3, |
|
|
input_size=(32, 32), |
|
|
num_classes=100, |
|
|
dropout_rate=0.05 |
|
|
) |
|
|
|
|
|
|
|
|
model = CIFAR100ResNet34(config) |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
checkpoint = torch.load(model_path, map_location=device, weights_only=False) |
|
|
|
|
|
if 'model_state_dict' in checkpoint: |
|
|
model.load_state_dict(checkpoint['model_state_dict']) |
|
|
print(f"✅ Model loaded with metrics: {checkpoint.get('metrics', {})}") |
|
|
else: |
|
|
model.load_state_dict(checkpoint) |
|
|
|
|
|
model.to(device) |
|
|
model.eval() |
|
|
|
|
|
total_params = sum(p.numel() for p in model.parameters()) |
|
|
print(f"✅ Model loaded successfully on {device}") |
|
|
print(f" Total parameters: {total_params:,}") |
|
|
|
|
|
return True |
|
|
|
|
|
except Exception as e: |
|
|
print(f"❌ Error loading model: {str(e)}") |
|
|
return False |
|
|
|
|
|
|
|
|
def get_transform(): |
|
|
"""Get image transformation pipeline.""" |
|
|
return transforms.Compose([ |
|
|
transforms.Resize((32, 32)), |
|
|
transforms.ToTensor(), |
|
|
transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD) |
|
|
]) |
|
|
|
|
|
|
|
|
def preprocess_image(image: Image.Image) -> torch.Tensor: |
|
|
"""Preprocess image for model input.""" |
|
|
|
|
|
if image.mode != 'RGB': |
|
|
image = image.convert('RGB') |
|
|
|
|
|
|
|
|
transform = get_transform() |
|
|
image_tensor = transform(image) |
|
|
|
|
|
|
|
|
image_tensor = image_tensor.unsqueeze(0) |
|
|
|
|
|
return image_tensor |
|
|
|
|
|
|
|
|
def predict(image: Image.Image) -> Tuple[Dict[str, float], str, str]: |
|
|
""" |
|
|
Make prediction on image. |
|
|
|
|
|
Returns: |
|
|
- Dictionary of top predictions {class: probability} |
|
|
- HTML formatted main prediction |
|
|
- Plotly chart (not used in Gradio, for reference) |
|
|
""" |
|
|
if model is None: |
|
|
return {}, "❌ Model not loaded", "" |
|
|
|
|
|
try: |
|
|
|
|
|
image_tensor = preprocess_image(image) |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
image_tensor = image_tensor.to(device) |
|
|
|
|
|
|
|
|
output = model(image_tensor) |
|
|
|
|
|
|
|
|
probabilities = torch.exp(output) |
|
|
|
|
|
|
|
|
top_probs, top_indices = torch.topk(probabilities, 10, dim=1) |
|
|
|
|
|
top_probs = top_probs[0].cpu().numpy() |
|
|
top_indices = top_indices[0].cpu().numpy() |
|
|
|
|
|
|
|
|
predicted_class = CIFAR100_CLASSES[top_indices[0]] |
|
|
confidence = top_probs[0] |
|
|
|
|
|
|
|
|
results_dict = {} |
|
|
for idx, prob in zip(top_indices, top_probs): |
|
|
class_name = CIFAR100_CLASSES[idx].replace('_', ' ').title() |
|
|
results_dict[class_name] = float(prob) |
|
|
|
|
|
|
|
|
confidence_pct = confidence * 100 |
|
|
|
|
|
if confidence_pct > 70: |
|
|
conf_emoji = "✅" |
|
|
conf_text = "High Confidence" |
|
|
color = "#28a745" |
|
|
elif confidence_pct > 40: |
|
|
conf_emoji = "⚠️" |
|
|
conf_text = "Medium Confidence" |
|
|
color = "#ffc107" |
|
|
else: |
|
|
conf_emoji = "❌" |
|
|
conf_text = "Low Confidence" |
|
|
color = "#dc3545" |
|
|
|
|
|
main_prediction = f""" |
|
|
<div style='padding: 20px; background-color: #f0f2f6; border-radius: 10px; text-align: center;'> |
|
|
<h2 style='color: #1f77b4; margin: 10px 0;'>Predicted Class</h2> |
|
|
<h1 style='color: #1f77b4; font-size: 2.5em; margin: 15px 0;'> |
|
|
{predicted_class.replace('_', ' ').upper()} |
|
|
</h1> |
|
|
<h2 style='color: {color}; margin: 10px 0;'> |
|
|
{conf_emoji} {confidence_pct:.2f}% |
|
|
</h2> |
|
|
<p style='color: #666; font-size: 1.2em;'>{conf_text}</p> |
|
|
</div> |
|
|
""" |
|
|
|
|
|
return results_dict, main_prediction, "" |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"❌ Error during prediction: {str(e)}" |
|
|
return {}, error_msg, "" |
|
|
|
|
|
|
|
|
def create_interface(): |
|
|
"""Create Gradio interface.""" |
|
|
|
|
|
|
|
|
custom_css = """ |
|
|
.gradio-container { |
|
|
font-family: 'Arial', sans-serif; |
|
|
} |
|
|
.gr-button-primary { |
|
|
background: linear-gradient(90deg, #1f77b4, #9467bd) !important; |
|
|
border: none !important; |
|
|
} |
|
|
footer { |
|
|
visibility: hidden; |
|
|
} |
|
|
""" |
|
|
|
|
|
|
|
|
with gr.Blocks(css=custom_css, title="CIFAR-100 Classifier", theme=gr.themes.Soft()) as demo: |
|
|
|
|
|
|
|
|
gr.Markdown(""" |
|
|
# 🖼️ CIFAR-100 Image Classifier |
|
|
|
|
|
Upload an image and the AI will classify it into one of **100 different categories** with confidence scores. |
|
|
Built with **PyTorch ResNet-34** architecture. |
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
|
|
|
gr.Markdown("## 📤 Upload Image") |
|
|
|
|
|
image_input = gr.Image( |
|
|
type="pil", |
|
|
label="Upload an image", |
|
|
sources=["upload", "clipboard", "webcam"], |
|
|
height=400 |
|
|
) |
|
|
|
|
|
predict_btn = gr.Button("🔍 Classify Image", variant="primary", size="lg") |
|
|
|
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
gr.Markdown("### ✅ Good Predictions Examples") |
|
|
gr.Examples( |
|
|
examples=[ |
|
|
"apple_s_000028.png", |
|
|
"breakfast_table_s_000178.png", |
|
|
"cichlid_fish_s_000888.png" |
|
|
], |
|
|
inputs=image_input, |
|
|
label="High Confidence Predictions" |
|
|
) |
|
|
|
|
|
with gr.Column(): |
|
|
gr.Markdown("### ⚠️ Challenging Predictions Examples") |
|
|
gr.Examples( |
|
|
examples=[ |
|
|
"crocodile_s_000018.png", |
|
|
"boy_s_000005.png", |
|
|
"armchair_s_000853.png" |
|
|
], |
|
|
inputs=image_input, |
|
|
label="Lower Confidence Predictions" |
|
|
) |
|
|
|
|
|
gr.Markdown(""" |
|
|
### 💡 Tips for Best Results |
|
|
- Use clear, well-lit images |
|
|
- Center the main object |
|
|
- Any size works (auto-resized to 32×32) |
|
|
- Supported: JPG, PNG, BMP, WEBP |
|
|
""") |
|
|
|
|
|
with gr.Column(scale=1): |
|
|
|
|
|
gr.Markdown("## 🎯 Classification Results") |
|
|
|
|
|
main_output = gr.HTML(label="Main Prediction") |
|
|
|
|
|
gr.Markdown("### 📊 Top 10 Predictions") |
|
|
|
|
|
label_output = gr.Label( |
|
|
num_top_classes=10, |
|
|
label="Confidence Scores", |
|
|
show_label=False |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
gr.Markdown(""" |
|
|
### 🤖 Model Information |
|
|
|
|
|
- **Architecture**: ResNet-34 with Bottleneck Layers |
|
|
- **Parameters**: ~21 Million |
|
|
- **Dataset**: CIFAR-100 (60,000 images, 100 classes) |
|
|
- **Input Size**: 32×32 RGB images |
|
|
- **Categories**: Animals, vehicles, household items, nature scenes, and more |
|
|
""") |
|
|
|
|
|
with gr.Column(): |
|
|
gr.Markdown(""" |
|
|
### 📚 Sample Categories |
|
|
|
|
|
**Animals**: bear, dolphin, elephant, fox, leopard, tiger, whale |
|
|
**Vehicles**: bicycle, bus, motorcycle, train, tractor |
|
|
**Nature**: cloud, forest, mountain, sea, plain |
|
|
**Objects**: chair, clock, lamp, telephone, keyboard |
|
|
**Plants**: maple_tree, oak_tree, orchid, rose, sunflower |
|
|
|
|
|
*...and 75 more categories!* |
|
|
""") |
|
|
|
|
|
|
|
|
|
|
|
gr.Markdown(""" |
|
|
--- |
|
|
<div style='text-align: center; color: #666; padding: 20px;'> |
|
|
<p style='font-size: 1.1em;'>Built with ❤️ using <strong>PyTorch</strong>, <strong>Gradio</strong>, and <strong>Hugging Face Spaces</strong></p> |
|
|
<p>Model: ResNet-34 trained on CIFAR-100 dataset</p> |
|
|
<p style='font-size: 0.9em; margin-top: 10px;'>Created by Krishnakanth | © 2025</p> |
|
|
</div> |
|
|
""") |
|
|
|
|
|
|
|
|
predict_btn.click( |
|
|
fn=predict, |
|
|
inputs=image_input, |
|
|
outputs=[label_output, main_output, gr.Textbox(visible=False)] |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return demo |
|
|
|
|
|
|
|
|
def main(): |
|
|
"""Main function to run the Gradio app.""" |
|
|
print("="*60) |
|
|
print("🖼️ CIFAR-100 Image Classifier") |
|
|
print("="*60) |
|
|
|
|
|
|
|
|
print("\n📦 Loading model...") |
|
|
success = load_model("cifar100_model.pth") |
|
|
|
|
|
if not success: |
|
|
print("❌ Failed to load model. Please check if cifar100_model.pth exists.") |
|
|
return |
|
|
|
|
|
print("\n🚀 Creating Gradio interface...") |
|
|
|
|
|
|
|
|
demo = create_interface() |
|
|
|
|
|
print("\n✅ Interface created successfully!") |
|
|
print("="*60) |
|
|
|
|
|
|
|
|
demo.launch( |
|
|
share=False, |
|
|
server_name="0.0.0.0", |
|
|
server_port=7860, |
|
|
show_error=True |
|
|
) |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|