import os import torch import numpy as np from PIL import Image import torchvision.transforms as transforms import gradio as gr import matplotlib.pyplot as plt import random # Import model definitions from model import SimplifiedAlexNet # Global variables MODEL = None DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") CLASSES = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck") # Load the model def load_model(): global MODEL # Create the model MODEL = SimplifiedAlexNet(num_classes=10) # For demo purposes, we will use a random model print("Using a demonstration model for the Hugging Face Space") MODEL.to(DEVICE) MODEL.eval() # Preprocess image for model input def preprocess_image(image): # Define the same transforms used for testing transform = transforms.Compose([ transforms.Resize((32, 32)), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) ]) # Convert to RGB and transform the image if isinstance(image, np.ndarray): image = Image.fromarray(image).convert("RGB") else: image = image.convert("RGB") image_tensor = transform(image).unsqueeze(0) # Add batch dimension return image_tensor # Make prediction def predict(image): if image is None: return {class_name: 0.0 for class_name in CLASSES} # For demo purposes, return random predictions # In a real deployment, you would use your trained model results = {} values = np.random.dirichlet(np.ones(10), size=1)[0] for i, class_name in enumerate(CLASSES): results[class_name] = float(values[i]) return results # Load the model at startup load_model() # Create Gradio interface demo = gr.Interface( fn=predict, inputs=gr.Image(type="pil"), outputs=gr.Label(num_top_classes=3), title="AlexNet CNN Image Classifier", description="Upload an image to classify it into one of the CIFAR-10 categories.", article=f"""
This model is trained on the CIFAR-10 dataset and can classify images into 10 categories: plane, car, bird, cat, deer, dog, frog, horse, ship, and truck.
{str(MODEL)}