Shiva-teja-chary commited on
Commit
ebf4e14
·
verified ·
1 Parent(s): 1ee5dea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -10,7 +10,7 @@ num_features = model.fc.in_features
10
  model.fc = torch.nn.Linear(num_features, 5) # Replace the final layer for 5 classes
11
 
12
  # Load the model weights
13
- checkpoint = torch.load('shiva_flower_classification.pth', map_location=torch.device('cpu'))
14
 
15
  # Get model state_dict without the 'fc' layer
16
  state_dict = checkpoint
@@ -37,7 +37,6 @@ transform = transforms.Compose([
37
  # Prediction function
38
  def predict(image):
39
  # Preprocess the image
40
- image = Image.open(image).convert('RGB')
41
  image = transform(image).unsqueeze(0)
42
 
43
  # Predict the class
@@ -51,7 +50,7 @@ def predict(image):
51
  # Gradio Interface
52
  interface = gr.Interface(
53
  fn=predict,
54
- inputs=gr.Image(type="file"),
55
  outputs="text",
56
  title="Flower Classification",
57
  description="Upload an image of a flower to classify it into one of the five categories: daisy, dandelion, rose, sunflower, or tulip."
 
10
  model.fc = torch.nn.Linear(num_features, 5) # Replace the final layer for 5 classes
11
 
12
  # Load the model weights
13
+ checkpoint = torch.load('shiva_flower_classification.pth', map_location=torch.device('cpu'), weights_only=True)
14
 
15
  # Get model state_dict without the 'fc' layer
16
  state_dict = checkpoint
 
37
  # Prediction function
38
  def predict(image):
39
  # Preprocess the image
 
40
  image = transform(image).unsqueeze(0)
41
 
42
  # Predict the class
 
50
  # Gradio Interface
51
  interface = gr.Interface(
52
  fn=predict,
53
+ inputs=gr.Image(type="pil"),
54
  outputs="text",
55
  title="Flower Classification",
56
  description="Upload an image of a flower to classify it into one of the five categories: daisy, dandelion, rose, sunflower, or tulip."