# testappgpt / app.py
# Last commit by Hameed0342j: "Update app.py" (7bd18af, verified)
import gradio as gr
import torch
from FlowerClassificationModel import FlowerClassificationModel # Replace with your model's class name
from torchvision import transforms
from PIL import Image
# Load the model: build the architecture, then fill it with the trained weights
# stored in the checkpoint file, mapping every tensor onto the CPU.
model = FlowerClassificationModel()  # Replace with your model's class if different
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs; it is the safer loader and silences the warning some
# PyTorch releases emit when the argument is omitted. The loaded dict is passed
# inline so no second copy of the weights stays referenced at module level.
model.load_state_dict(
    torch.load(
        "flower_classification_model.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
)
# Inference mode: changes the behavior of layers such as dropout / batch-norm,
# if the model contains any.
model.eval()
# Image preprocessing pipeline applied to every input before it reaches the
# model: fixed-size resize -> float tensor -> per-channel normalization.
_INPUT_SIZE = (224, 224)  # adjust to your model's input size
# Standard ImageNet channel statistics; presumably what the model was trained
# with -- confirm against the training script.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

_preprocess_steps = [
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
preprocess = transforms.Compose(_preprocess_steps)
# Class names, indexed by the model's output position.
# TODO: replace the placeholders with the real class names, in training order.
FLOWER_LABELS = ["Class1", "Class2", "Class3", "Class4", "Class5"]


# Define the prediction function
def classify_flower(image):
    """Classify one image and return the predicted class label.

    Args:
        image: The value from the Gradio ``"image"`` input -- a NumPy array
            (as converted with ``Image.fromarray``), an already-decoded
            ``PIL.Image.Image``, or ``None`` when nothing was uploaded.

    Returns:
        The entry of ``FLOWER_LABELS`` for the highest-scoring model output,
        or a short message when no image was supplied.
    """
    # Submitting the form without an upload passes None; report it instead of
    # crashing inside Image.fromarray.
    if image is None:
        return "No image provided."

    # Accept both NumPy arrays and PIL images.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Force exactly 3 channels: grayscale, RGBA or palette images would otherwise
    # yield a tensor whose channel count does not match the 3-value Normalize.
    image = image.convert("RGB")

    input_tensor = preprocess(image).unsqueeze(0)  # add batch dimension

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        output = model(input_tensor)
    predicted_index = int(output.argmax(dim=1).item())
    return FLOWER_LABELS[predicted_index]
# Gradio UI: a single image input wired to classify_flower, with the predicted
# label shown as text.
_interface_kwargs = dict(
    fn=classify_flower,
    inputs="image",
    outputs="text",
    title="Flower Classification",
    description="Upload an image to classify the flower type.",
)
demo = gr.Interface(**_interface_kwargs)

# Start serving the app.
demo.launch()