Prashant26am's picture
Upload app.py with huggingface_hub
f3a3039 verified
import html
import os
import random

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

# Import model definitions
from model import SimplifiedAlexNet
# Global variables
# The classifier instance; stays None until load_model() populates it.
MODEL = None
# Prefer the GPU when torch can see one, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# CIFAR-10 label names; presumably in the model's output index order — confirm
# against the training code.
CLASSES = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)
# Load the model
def load_model():
    """Build the classifier and publish it through the module-level MODEL.

    No checkpoint is read: the network keeps whatever weights
    SimplifiedAlexNet initialises it with, which is what this demonstration
    Space uses. The model is moved to DEVICE and switched to eval mode.
    """
    global MODEL
    MODEL = SimplifiedAlexNet(num_classes=10)
    # Demonstration only: no trained weights are loaded here.
    print("Using a demonstration model for the Hugging Face Space")
    # nn.Module.to() and .eval() act in place and return the module itself.
    MODEL.to(DEVICE).eval()
# Preprocess image for model input
# The same transforms used for testing. Built once at import time instead of
# being reconstructed on every call.
_TRANSFORM = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    # NOTE(review): per-channel mean/std; confirm these are exactly the values
    # used when the model was trained.
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


def preprocess_image(image):
    """Turn an input image into a normalized single-image batch for the model.

    Args:
        image: a PIL image, or a NumPy array accepted by ``Image.fromarray``
            (presumably HxW or HxWxC uint8 — confirm for non-Gradio callers).

    Returns:
        A float tensor of shape (1, 3, 32, 32): the image converted to RGB,
        resized to 32x32, scaled by ToTensor and normalized per channel.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Forcing RGB guarantees 3 channels for grayscale / RGBA / palette inputs.
    image = image.convert("RGB")
    return _TRANSFORM(image).unsqueeze(0)  # add batch dimension
# Make prediction
def predict(image):
    """Classify *image* with the loaded MODEL and return class probabilities.

    Args:
        image: the input image (a PIL image from ``gr.Image(type="pil")``, or
            a NumPy array), or ``None`` when the input component is empty.

    Returns:
        dict mapping every name in CLASSES to a float probability. The values
        come from a softmax and so sum to 1. All values are 0.0 when *image*
        is None.

    Raises:
        RuntimeError: if load_model() has not been called yet.
    """
    if image is None:
        return {class_name: 0.0 for class_name in CLASSES}
    if MODEL is None:
        raise RuntimeError("Model is not loaded; call load_model() first.")

    # NOTE: load_model() does not load trained weights, so until a checkpoint
    # is loaded these scores come from an untrained network.
    inputs = preprocess_image(image).to(DEVICE)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = MODEL(inputs)
    # assumes MODEL returns raw logits of shape (1, len(CLASSES)) — TODO confirm
    probabilities = torch.softmax(logits, dim=1)[0].cpu().tolist()
    return {
        class_name: float(prob)
        for class_name, prob in zip(CLASSES, probabilities)
    }
# Load the model at startup. This must run before the interface is built,
# because the article below reads MODEL's repr and parameter counts.
load_model()

# Example inputs offered in the UI. Only files that exist on disk are passed
# to Gradio, so a missing example image is skipped rather than handed to the
# interface.
_EXAMPLE_PATHS = (
    "examples/airplane.jpg",
    "examples/automobile.jpg",
    "examples/cat.jpg",
)
_examples = [[path] for path in _EXAMPLE_PATHS if os.path.isfile(path)] or None

# Class list derived from CLASSES so the article cannot drift from the labels
# predict() returns. Renders as "plane, car, ..., ship, and truck".
_class_list = ", ".join(CLASSES[:-1]) + ", and " + CLASSES[-1]
_total_params = sum(p.numel() for p in MODEL.parameters())
_trainable_params = sum(p.numel() for p in MODEL.parameters() if p.requires_grad)

# Create Gradio interface
# NOTE(review): load_model() does not load trained weights, so the
# "trained on the CIFAR-10 dataset" wording below only holds once a trained
# checkpoint is actually loaded.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title="AlexNet CNN Image Classifier",
    description="Upload an image to classify it into one of the CIFAR-10 categories.",
    article=f"""
<div>
<h3>Model Information</h3>
<p>This model is trained on the CIFAR-10 dataset and can classify images into {len(CLASSES)} categories:
{_class_list}.</p>
<h3>Model Architecture</h3>
<pre>{html.escape(str(MODEL))}</pre>
<h3>Model Parameters</h3>
<ul>
<li>Total parameters: {_total_params:,}</li>
<li>Trainable parameters: {_trainable_params:,}</li>
</ul>
</div>
""",
    examples=_examples,
    flagging_mode="never",
)

# Launch the app only when run as a script (e.g. `python app.py`), so that
# importing this module does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()