# Last updated by Shiva-teja-chary in commit ebf4e14 ("Update app.py", verified).
import warnings

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models
# Define the model architecture: an ImageNet-pretrained ResNet18 whose final
# fully connected layer is replaced by a new 5-way head (one output per class).
model = models.resnet18(weights='IMAGENET1K_V1')
num_features = model.fc.in_features
model.fc = torch.nn.Linear(num_features, 5)

# Load the fine-tuned weights onto the CPU. weights_only=True restricts
# unpickling to tensors and plain containers instead of arbitrary objects.
checkpoint = torch.load(
    'shiva_flower_classification.pth',
    map_location=torch.device('cpu'),
    weights_only=True,
)
state_dict = checkpoint

# Keep the checkpoint's classifier head whenever it fits the new 5-way head:
# discarding it would leave model.fc at its random initialisation and make
# every prediction arbitrary. Only head tensors whose shape is incompatible
# with the model are dropped (loading them would raise a size-mismatch error).
model_state = model.state_dict()
for head_key in ('fc.weight', 'fc.bias'):
    if head_key in state_dict and state_dict[head_key].shape != model_state[head_key].shape:
        warnings.warn(
            f"Dropping checkpoint tensor {head_key!r}: shape "
            f"{tuple(state_dict[head_key].shape)} does not match the model's "
            f"{tuple(model_state[head_key].shape)}; this layer keeps its "
            f"random initialisation."
        )
        del state_dict[head_key]

# strict=False tolerates key mismatches, so surface them instead of silently
# running with parameters that were never loaded from the checkpoint.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    warnings.warn(
        f"Parameters not found in checkpoint (left at their initial values): "
        f"{load_result.missing_keys}"
    )
if load_result.unexpected_keys:
    warnings.warn(f"Checkpoint entries ignored by the model: {load_result.unexpected_keys}")

model.eval()  # Evaluation mode: BatchNorm uses running stats, dropout is off
# Flower category names, indexed by the classifier's output position.
classes = 'daisy dandelion rose sunflower tulip'.split()

# Preprocessing for each input PIL image: resize to a fixed 224x224, convert
# to a float tensor, then normalize each channel with the standard ImageNet
# mean/std statistics used by torchvision's pretrained models.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Prediction function
def predict(image):
    """Classify a flower photo and return the predicted class name.

    Args:
        image: A PIL image, as produced by the ``gr.Image(type="pil")``
            input, or ``None`` if no image was supplied.

    Returns:
        The entry of ``classes`` whose model output score is highest.

    Raises:
        gr.Error: If ``image`` is ``None``; Gradio displays the message to
            the user instead of failing inside the preprocessing pipeline.
    """
    if image is None:
        raise gr.Error("Please upload an image of a flower.")

    # The pipeline normalizes exactly 3 channels and the network takes RGB
    # input; RGBA, grayscale or palette images would otherwise fail.
    rgb_image = image.convert("RGB")

    # Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
    batch = transform(rgb_image).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)

    predicted_index = logits.argmax(dim=1).item()
    return classes[predicted_index]
# Gradio UI: one PIL image in, the predicted class name out as plain text.
_TITLE = "Flower Classification"
_DESCRIPTION = (
    "Upload an image of a flower to classify it into one of the five "
    "categories: daisy, dandelion, rose, sunflower, or tulip."
)

interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title=_TITLE,
    description=_DESCRIPTION,
)

# Start the web app only when this file is executed directly, not on import.
if __name__ == "__main__":
    interface.launch()