Spaces: Runtime error
import gradio as gr
import torch
from FlowerClassificationModel import FlowerClassificationModel  # project-local model class
from torchvision import transforms
from PIL import Image

# Load the trained weights onto CPU so the app runs without a GPU.
# NOTE(security): torch.load uses pickle under the hood — only load
# checkpoints from a trusted source.
model = FlowerClassificationModel()
model.load_state_dict(
    torch.load("flower_classification_model.pth", map_location=torch.device("cpu"))
)
model.eval()  # inference mode: freezes dropout / batch-norm statistics

# Preprocessing pipeline using the standard ImageNet mean/std — the model
# was presumably trained with these statistics (TODO: confirm).
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # assumes the network expects 224x224 input
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def classify_flower(image):
    """Classify a flower image and return the predicted class name.

    Args:
        image: Image from the Gradio image component — a NumPy array of
            shape (H, W, C), or a PIL Image (both are accepted).

    Returns:
        str: The predicted class label.
    """
    # Accept either a NumPy array or a PIL Image, and force RGB so
    # grayscale or RGBA uploads don't break the 3-channel Normalize.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")
    input_tensor = preprocess(image).unsqueeze(0)  # add batch dimension

    # Inference: no gradients needed.
    with torch.no_grad():
        output = model(input_tensor)
        _, predicted = torch.max(output, 1)

    # Map the predicted index to a human-readable label.
    labels = ["Class1", "Class2", "Class3", "Class4", "Class5"]  # TODO: replace with real class names
    return labels[predicted.item()]
# Wire the classifier into a simple web UI: one image input, one text output.
demo = gr.Interface(
    fn=classify_flower,
    inputs="image",
    outputs="text",
    title="Flower Classification",
    description="Upload an image to classify the flower type.",
)

# Launch the app (blocks until the server is stopped).
demo.launch()