Spaces:
Sleeping
Sleeping
File size: 4,845 Bytes
c65e61c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
from models.custom_cnn import create_custom_cnn
from models.resnet18 import load_resnet18
from utils.data_loader import get_cifar10_info
# Normalization statistics handed to transforms.Normalize in preprocess_image.
# One value per colour channel, in R, G, B order (images are converted to RGB
# before the transform runs).
# NOTE(review): presumably these match the statistics used when the models
# were trained -- confirm against the training pipeline.
CIFAR10_MEAN = (
    0.4914,  # R
    0.4822,  # G
    0.4465,  # B
)
CIFAR10_STD = (
    0.2023,  # R
    0.1994,  # G
    0.2010,  # B
)
def preprocess_image(image):
    """Turn an uploaded image into a normalized single-item input batch.

    The image is converted to RGB when it is in any other mode, resized to
    32x32, converted to a tensor, normalized with CIFAR10_MEAN / CIFAR10_STD,
    and finally given a leading batch dimension.

    Args:
        image: PIL image (must expose ``mode`` and ``convert``).

    Returns:
        The transformed tensor with an extra dimension inserted at index 0.
    """
    pipeline = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ])

    # Normalize expects three channels, so grayscale/RGBA/palette inputs are
    # converted first; images already in RGB are passed through untouched.
    rgb_image = image if image.mode == 'RGB' else image.convert('RGB')

    tensor = pipeline(rgb_image)
    return tensor.unsqueeze(0)
def predict_image(image, model_choice="CustomCNN"):
    """Classify an uploaded image and return a formatted text report.

    The selected model is built, its trained weights are loaded from disk,
    and the image is run through it. Note that the model is rebuilt and its
    weights re-read on every call.

    Args:
        image: PIL Image uploaded by the user, or None when nothing was
            uploaded.
        model_choice: Which model to use. "CustomCNN" selects the custom
            network; any other value selects ResNet18.

    Returns:
        str: The top prediction with its confidence, the highest-ranked
        classes (up to three), the model name and the device used. On
        failure a human-readable error message is returned instead; this
        function does not raise.
    """
    # Gradio's Image input yields None when the form is submitted empty.
    # Answer that case directly instead of letting it surface later as an
    # opaque "'NoneType' object has no attribute 'mode'" error.
    if image is None:
        return "❌ **Error:** No image provided.\n\nPlease upload an image first."

    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if model_choice == "CustomCNN":
            model = create_custom_cnn()
            model_path = "best_model_custom.pth"
        else:
            model = load_resnet18()
            model_path = "best_model_resnet18.pth"

        # Load trained weights; a missing file gets its own message because
        # it means "not trained yet" rather than a bad input.
        try:
            checkpoint = torch.load(model_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
        except FileNotFoundError:
            return f"❌ Model weights not found: {model_path}\nTrain the model first!"

        model.to(device)
        model.eval()

        input_tensor = preprocess_image(image).to(device)

        with torch.no_grad():
            outputs = model(input_tensor)
            probabilities = F.softmax(outputs, dim=1)
            confidence, predicted = torch.max(probabilities, 1)

        class_names = get_cifar10_info()['class_names']
        predicted_class = class_names[predicted.item()]
        confidence_score = confidence.item() * 100

        # torch.topk raises when k exceeds the size of the dimension, so
        # never ask for more entries than the model has outputs.
        top_k = min(3, probabilities.size(1))
        top_probs, top_indices = torch.topk(probabilities, top_k)

        parts = [
            f"🎯 **Prediction: {predicted_class}** ({confidence_score:.1f}% confidence)\n\n",
            "📊 **Top 3 Predictions:**\n",
        ]
        ranked = zip(top_indices[0].tolist(), top_probs[0].tolist())
        for rank, (class_index, prob) in enumerate(ranked, start=1):
            parts.append(f"{rank}. {class_names[class_index]}: {prob * 100:.1f}%\n")
        parts.append(f"\n🔧 **Model:** {model_choice}\n")
        parts.append(f"📱 **Device:** {device}")
        return "".join(parts)
    except Exception as e:
        # UI boundary: report any failure as text in the output box rather
        # than letting the exception propagate to the caller.
        return f"❌ **Error:** {str(e)}\n\nPlease ensure the image is valid and model is trained."
def create_gradio_interface():
    """Build the Gradio interface for the CIFAR-10 demo.

    Wires ``predict_image`` to an image upload plus a model dropdown, and
    shows the returned text in a textbox.

    Returns:
        The configured ``gr.Interface``. It is only constructed here;
        launching is left to the caller.
    """
    # The multi-line literals below are kept flush-left so their runtime
    # value is exactly the text shown, with no added indentation.
    custom_css = """
.gradio-container {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.output-text {
font-family: 'Courier New', monospace;
font-size: 14px;
}
"""
    about_text = """
Upload an image to test our trained models!
**Models:**
- **CustomCNN**: Lightweight 3M parameter model
- **ResNet18**: Standard 11M parameter baseline
**Note:** Images will be resized to 32x32 pixels (CIFAR-10 format)
**CIFAR-10 Classes:** airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
"""

    # Inputs, in the same order as predict_image's parameters.
    image_input = gr.Image(type="pil", label="Upload Image", height=300)
    model_selector = gr.Dropdown(
        choices=["CustomCNN", "ResNet18"],
        value="CustomCNN",
        label="Select Model",
    )

    # Output box; the "output-text" class is the one styled in custom_css.
    result_box = gr.Textbox(
        label="Prediction Result",
        lines=10,
        elem_classes=["output-text"],
    )

    return gr.Interface(
        fn=predict_image,
        inputs=[image_input, model_selector],
        outputs=result_box,
        title="🧠 CIFAR-10 CNN Benchmark",
        description=about_text,
        css=custom_css,
        theme=gr.themes.Soft(),
        allow_flagging="never",
    )
if __name__ == "__main__":
    # Script entry point: build the interface and start serving it.
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "share": True,  # also request a public share link
        "show_error": True,
    }
    demo = create_gradio_interface()
    demo.launch(**launch_options)