import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

# Import your model class from the model.py file
from model import SimpleCNN

# --- 1. SETUP ---
# Path to the trained weights. Make sure this file is in your repository.
MODEL_PATH = "air_analyzer_cnn_iden_7m.pth"

# Class names, in the same order as the model's output logits.
class_names = ["Cat", "Dog", "Bird"]
# Derived from class_names so the label list and the model head cannot drift apart.
NUM_CLASSES = len(class_names)

# --- 2. LOAD THE MODEL ---
# Instantiate the model (must be the same architecture as the one you saved).
model = SimpleCNN(num_classes=NUM_CLASSES)

# Load the trained weights, mapping every tensor onto the CPU so the app also
# runs on CPU-only hosts such as the Hugging Face Spaces free tier.
# NOTE(review): torch.load unpickles the checkpoint, so only load files you trust.
# On torch >= 1.13, consider passing weights_only=True.
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device("cpu")))

# Switch to evaluation mode so layers such as dropout / batch-norm (if
# SimpleCNN has any) use their inference behavior.
model.eval()

# --- 3. DEFINE IMAGE TRANSFORMATIONS ---
# Must match the preprocessing used during training/validation.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # resize to 224x224 pixels
    transforms.ToTensor(),  # PIL image -> float tensor of shape (C, H, W)
    # Per-channel mean/std (the commonly used ImageNet values); expects 3 channels.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# --- 4. DEFINE THE PREDICTION FUNCTION ---
def predict(input_image: Image.Image) -> dict:
    """Classify an image and return the probability of each class.

    Args:
        input_image: The image to classify, as supplied by the ``gr.Image``
            component (``type="pil"``). Any PIL mode is accepted; the image is
            converted to RGB before preprocessing.

    Returns:
        A dict mapping each name in ``class_names`` to its softmax probability
        (a Python float), the format ``gr.Label`` displays.

    Raises:
        gr.Error: If ``input_image`` is None (e.g. the form was submitted
            without an image), so the UI shows a readable message.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    # Normalize uses 3-element mean/std, so RGBA, grayscale or palette images
    # would otherwise fail with a channel-count mismatch.
    rgb_image = input_image.convert("RGB")
    image_tensor = transform(rgb_image).unsqueeze(0)  # add a batch dimension

    with torch.no_grad():  # gradients are not needed for inference
        logits = model(image_tensor)

    # Softmax over the logits of the single item in the batch.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    return {name: prob for name, prob in zip(class_names, probabilities.tolist())}
# --- 5. CREATE THE GRADIO INTERFACE ---
import os

# Define the input and output components.
image_input = gr.Image(type="pil", label="Upload an Image")
# Show every class instead of a hard-coded count, so adding a class needs no change here.
label_output = gr.Label(num_top_classes=NUM_CLASSES, label="Predictions")

# Example images (optional but highly recommended).
# Upload these images to your Space repository. Any file that is missing is
# skipped, so a forgotten sample does not break app startup.
example_images = [
    path
    for path in ("sample_cat.jpg", "sample_dog.jpg", "sample_bird.jpg")
    if os.path.isfile(path)
]

# Create the interface.
iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title="Image Classifier",
    description="Upload an image of a cat, dog, or bird to see the model's prediction.",
    # Pass None rather than an empty list when no sample files are available.
    examples=example_images or None,
)

if __name__ == "__main__":
    iface.launch()