Update app.py
Browse files
app.py
CHANGED
|
@@ -10,7 +10,7 @@ num_features = model.fc.in_features
|
|
| 10 |
model.fc = torch.nn.Linear(num_features, 5) # Replace the final layer for 5 classes
|
| 11 |
|
| 12 |
# Load the model weights
|
| 13 |
-
checkpoint = torch.load('shiva_flower_classification.pth', map_location=torch.device('cpu'))
|
| 14 |
|
| 15 |
# Get model state_dict without the 'fc' layer
|
| 16 |
state_dict = checkpoint
|
|
@@ -37,7 +37,6 @@ transform = transforms.Compose([
|
|
| 37 |
# Prediction function
|
| 38 |
def predict(image):
|
| 39 |
# Preprocess the image
|
| 40 |
-
image = Image.open(image).convert('RGB')
|
| 41 |
image = transform(image).unsqueeze(0)
|
| 42 |
|
| 43 |
# Predict the class
|
|
@@ -51,7 +50,7 @@ def predict(image):
|
|
| 51 |
# Gradio Interface
|
| 52 |
interface = gr.Interface(
|
| 53 |
fn=predict,
|
| 54 |
-
inputs=gr.Image(type="
|
| 55 |
outputs="text",
|
| 56 |
title="Flower Classification",
|
| 57 |
description="Upload an image of a flower to classify it into one of the five categories: daisy, dandelion, rose, sunflower, or tulip."
|
|
|
|
| 10 |
model.fc = torch.nn.Linear(num_features, 5) # Replace the final layer for 5 classes
|
| 11 |
|
| 12 |
# Load the model weights
|
| 13 |
+
checkpoint = torch.load('shiva_flower_classification.pth', map_location=torch.device('cpu'), weights_only=True)
|
| 14 |
|
| 15 |
# Get model state_dict without the 'fc' layer
|
| 16 |
state_dict = checkpoint
|
|
|
|
| 37 |
# Prediction function
|
| 38 |
def predict(image):
|
| 39 |
# Preprocess the image
|
|
|
|
| 40 |
image = transform(image).unsqueeze(0)
|
| 41 |
|
| 42 |
# Predict the class
|
|
|
|
| 50 |
# Gradio Interface
|
| 51 |
interface = gr.Interface(
|
| 52 |
fn=predict,
|
| 53 |
+
inputs=gr.Image(type="pil"),
|
| 54 |
outputs="text",
|
| 55 |
title="Flower Classification",
|
| 56 |
description="Upload an image of a flower to classify it into one of the five categories: daisy, dandelion, rose, sunflower, or tulip."
|