# air / app.py
# Copied from the Hugging Face Space page (author avatar caption: "grfdjiwsd's picture");
# commit b5bb51f "Create app.py" (verified).
import os

import gradio as gr
import torch
from torchvision import transforms
from PIL import Image

# Import your model class from the model.py file
from model import SimpleCNN
# --- 1. SETUP ---
# Path to the trained weights (loaded below as a state_dict). This is a
# relative path, so it is resolved against the process's current working
# directory.
MODEL_PATH = "air_analyzer_cnn_iden_7m.pth"  # Make sure this file is in your repository
# Human-readable labels, indexed by model output position:
# class_names[i] is the label reported for output unit i.
class_names = ["Cat", "Dog", "Bird"]
# Derive the classifier head size from the label list so the two cannot drift
# apart. The saved weights must have been trained with this many classes.
NUM_CLASSES = len(class_names)
# --- 2. LOAD THE MODEL ---
# Instantiate the model. The architecture must match the one the weights were
# saved from; otherwise load_state_dict (strict by default) raises on missing,
# unexpected or mis-shaped parameters.
model = SimpleCNN(num_classes=NUM_CLASSES)
# Load the trained weights.
# - map_location=torch.device("cpu") remaps tensors that were saved on a GPU
#   onto the CPU, so the app also runs on CPU-only hardware (such as the free
#   Hugging Face Spaces tier).
# - weights_only=True restricts unpickling to tensors and plain containers,
#   which is all a state_dict contains. It refuses to execute arbitrary pickled
#   code from the checkpoint. (This is the default from PyTorch 2.6 on; the
#   argument requires PyTorch >= 1.13.)
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=torch.device("cpu"), weights_only=True)
)
# Put the model in evaluation mode, so any train/eval-dependent layers
# (dropout, batch-norm), if SimpleCNN has them, behave deterministically.
model.eval()
# --- 3. DEFINE IMAGE TRANSFORMATIONS ---
# Preprocessing applied to every input image. It must mirror the pipeline
# used during training/validation, or the model sees differently-scaled inputs.
transform = transforms.Compose(
    [
        # Force an exact 224x224 size (a (h, w) tuple does not keep aspect ratio).
        transforms.Resize((224, 224)),
        # PIL image -> float tensor of shape (C, H, W) with values in [0, 1].
        transforms.ToTensor(),
        # Per-channel standardization with the standard ImageNet mean/std values.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# --- 4. DEFINE THE PREDICTION FUNCTION ---
def predict(input_image: Image.Image) -> dict:
    """
    Classify one image and return the probability of each class.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input.
            Gradio passes ``None`` when no image has been provided.

    Returns:
        A dict mapping each entry of ``class_names`` to its softmax
        probability, in the format ``gr.Label`` displays. An empty dict is
        returned when there is no image.
    """
    # Nothing uploaded (or the input was cleared): show no predictions instead
    # of crashing inside the transform pipeline.
    if input_image is None:
        return {}
    # Uploads may be RGBA (PNG with alpha), grayscale ("L") or palette ("P")
    # images. Normalize() carries exactly three channel statistics, so force a
    # 3-channel image first.
    rgb_image = input_image.convert("RGB")
    # Apply the preprocessing and add a leading batch dimension:
    # (C, H, W) -> (1, C, H, W).
    image_tensor = transform(rgb_image).unsqueeze(0)
    # Inference only: disable gradient tracking.
    with torch.no_grad():
        outputs = model(image_tensor)
    # Presumably outputs are raw logits of shape (1, NUM_CLASSES); take the
    # single batch row and softmax it over the class dimension.
    probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
    # Label each probability with the class name at the same output index.
    return {class_names[i]: float(prob) for i, prob in enumerate(probabilities.tolist())}
# --- 5. CREATE THE GRADIO INTERFACE ---
# Input component: the upload is delivered to predict() as a PIL image.
image_input = gr.Image(type="pil", label="Upload an Image")
# Output component: show a confidence entry for every class. The count follows
# class_names instead of being hard-coded.
label_output = gr.Label(num_top_classes=len(class_names), label="Predictions")
# Example images (optional but highly recommended).
# Make sure you upload these images to your Space repository. The paths are
# relative to the working directory. Files that are missing are skipped,
# because a missing example file can make Gradio fail while preparing the
# examples.
example_images = [
    path
    for path in ("sample_cat.jpg", "sample_dog.jpg", "sample_bird.jpg")
    if os.path.isfile(path)
]
# Create the interface. When none of the example files exist, pass None
# (no examples) instead of an empty list.
iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title="Image Classifier",
    description="Upload an image of a cat, dog, or bird to see the model's prediction.",
    examples=example_images or None,
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    iface.launch()