Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import random | |
| # Import model definitions | |
| from model import SimplifiedAlexNet | |
# Shared module-level state.

# Label names, in the order predict() uses when it builds its result dict.
CLASSES = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)

# Use the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Placeholder until load_model() assigns the network instance.
MODEL = None
def load_model():
    """Create the network and publish it as the module-level ``MODEL``.

    The model is built with ``num_classes=10``, moved to ``DEVICE`` and put
    into evaluation mode. No checkpoint is loaded in this function, so the
    network keeps whatever parameters its constructor produced.

    Returns:
        None. The result is stored in the global ``MODEL``.
    """
    global MODEL

    network = SimplifiedAlexNet(num_classes=10)
    MODEL = network

    # Demo deployment: tell the log reader that no trained weights are used.
    print("Using a demonstration model for the Hugging Face Space")

    network.to(DEVICE)
    network.eval()
def preprocess_image(image):
    """Turn an input image into a normalized, batched tensor.

    Args:
        image: Either a NumPy array (converted via ``Image.fromarray``) or an
            object exposing ``convert("RGB")`` such as a PIL image.

    Returns:
        The transformed image with a leading batch dimension added by
        ``unsqueeze(0)``; expected shape is (1, 3, 32, 32) given the RGB
        conversion and the 32x32 resize. The tensor is not moved to
        ``DEVICE`` here.
    """
    # Arrays are wrapped in a PIL image first so both input kinds share the
    # same RGB conversion below.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    rgb_image = image.convert("RGB")

    # Per-channel mean / std used for normalization (the original comment
    # states these mirror the transforms used for testing).
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    pipeline = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    # unsqueeze(0) adds the batch dimension.
    return pipeline(rgb_image).unsqueeze(0)
def predict(image):
    """Return a class-name -> score mapping for the Gradio label output.

    This is the demonstration path: the scores are a random draw from a
    Dirichlet distribution and neither ``MODEL`` nor the pixel content of
    ``image`` is consulted. In a real deployment the trained model would be
    used instead.

    Args:
        image: The uploaded image, or ``None`` when nothing was provided.

    Returns:
        dict mapping every entry of ``CLASSES`` to a Python float. All values
        are 0.0 when ``image`` is ``None``; otherwise they are one Dirichlet
        sample over 10 categories.
    """
    if image is None:
        return dict.fromkeys(CLASSES, 0.0)

    # One sample with 10 components; [0] drops the size=1 leading axis.
    scores = np.random.dirichlet(np.ones(10), size=1)[0]
    return {label: float(score) for label, score in zip(CLASSES, scores)}
# Instantiate MODEL first: the article text below embeds its repr and its
# parameter counts, so it must exist before the interface is assembled.
load_model()

# Figures shown in the "Model Parameters" section of the article.
_total_params = sum(p.numel() for p in MODEL.parameters())
_trainable_params = sum(p.numel() for p in MODEL.parameters() if p.requires_grad)

# HTML rendered by Gradio underneath the interface.
_article_html = f"""
<div>
<h3>Model Information</h3>
<p>This model is trained on the CIFAR-10 dataset and can classify images into 10 categories:
plane, car, bird, cat, deer, dog, frog, horse, ship, and truck.</p>
<h3>Model Architecture</h3>
<pre>{MODEL!s}</pre>
<h3>Model Parameters</h3>
<ul>
<li>Total parameters: {_total_params:,}</li>
<li>Trainable parameters: {_trainable_params:,}</li>
</ul>
</div>
"""

# Each example row holds one value per input component (a single image path).
_example_rows = [
    [path]
    for path in (
        "examples/airplane.jpg",
        "examples/automobile.jpg",
        "examples/cat.jpg",
    )
]

# Assemble the Gradio interface: PIL image in, top-3 label scores out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title="AlexNet CNN Image Classifier",
    description="Upload an image to classify it into one of the CIFAR-10 categories.",
    article=_article_html,
    examples=_example_rows,
    flagging_mode="never",
)

# Start serving the app.
demo.launch()