# flower-vision / app.py
# Last commit by faranbutt789: "Update app.py" (875fc02, verified)
import torch
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
from model import create_model
# Fixed spatial size (height, width) that every input image is resized to.
image_height = image_width = 224

# Restore the trained classifier. The checkpoint is a dict from which this
# app reads the keys "classes", "dropout" and "model_state".
# map_location="cpu" lets it load on machines without a GPU, even if the
# tensors were saved from CUDA.
checkpoint = torch.load("flower_checkpoint.pth", map_location="cpu")

# Class-name list; position i names the model's i-th output logit.
classes = checkpoint["classes"]

model = create_model(num_classes=len(classes), dropout=checkpoint["dropout"])
model.load_state_dict(checkpoint["model_state"])
# Switch to evaluation mode so dropout (and batch-norm, if any) behave
# deterministically at inference time.
model.eval()
# Transform
transform = transforms.Compose([
transforms.Resize((image_height, image_width)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
# Prediction function
def predict(image):
    """Classify a flower image and return a probability for every class.

    Args:
        image: The uploaded image, either a ``PIL.Image.Image`` or an array
            accepted by ``PIL.Image.fromarray``. ``None`` (the user submitted
            without uploading anything) is rejected with a user-facing error.

    Returns:
        dict[str, float]: Maps each name in ``classes`` to its softmax
        probability. The values sum to approximately 1.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Without this guard, Image.fromarray(None) fails with an opaque
        # AttributeError; gr.Error shows a readable message in the UI instead.
        raise gr.Error("Please upload an image.")

    # Normalize to a 3-channel PIL image so grayscale/RGBA inputs match the
    # 3-channel Normalize step in `transform`.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")

    batch = transform(image).unsqueeze(0)  # add batch dimension -> (1, C, H, W)
    with torch.no_grad():
        logits = model(batch)
        probs = torch.nn.functional.softmax(logits, dim=1)[0]

    # Return every class; gr.Label selects the top-k for display.
    # The model is built with len(classes) outputs, so zip pairs them 1:1.
    return {name: p for name, p in zip(classes, probs.tolist())}
# Gradio web UI: the user uploads an image, and the most probable classes are shown.
demo = gr.Interface(
    fn=predict,
    # type="pil" makes Gradio pass `predict` a PIL.Image rather than a NumPy array.
    inputs=gr.Image(type="pil"),
    # Renders the {class: probability} dict from `predict`, limited to the
    # 5 most confident classes.
    outputs=gr.Label(num_top_classes=5),
    title="Flower Classification",
    description="Upload a flower image to classify.",
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()