Shiva-teja-chary commited on
Commit
912da9a
·
verified ·
1 Parent(s): 6ec04de
Files changed (1) hide show
  1. app.py +62 -0
app.py CHANGED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ import torchvision.transforms as transforms
5
+ from torchvision import models
6
+
7
+ # Define the model architecture
8
+ model = models.resnet18(weights='IMAGENET1K_V1') # Load pretrained ResNet18 from ImageNet
9
+ num_features = model.fc.in_features
10
+ model.fc = torch.nn.Linear(num_features, 5) # Replace the final layer for 5 classes
11
+
12
+ # Load the model weights
13
+ checkpoint = torch.load('shiva_flower_classification.pth', map_location=torch.device('cpu'))
14
+
15
+ # Get model state_dict without the 'fc' layer
16
+ state_dict = checkpoint
17
+
18
+ # Remove the 'fc' layer's weights from the state_dict
19
+ state_dict.pop('fc.weight', None)
20
+ state_dict.pop('fc.bias', None)
21
+
22
+ # Load the state_dict into the model
23
+ model.load_state_dict(state_dict, strict=False)
24
+
25
+ model.eval() # Set the model to evaluation mode
26
+
27
+ # Define the class labels
28
+ classes = ['daisy', 'dandelion', 'rose', 'sunflower', 'tulip']
29
+
30
+ # Define image transformations
31
+ transform = transforms.Compose([
32
+ transforms.Resize((224, 224)),
33
+ transforms.ToTensor(),
34
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
35
+ ])
36
+
37
+ # Prediction function
38
+ def predict(image):
39
+ # Preprocess the image
40
+ image = Image.open(image).convert('RGB')
41
+ image = transform(image).unsqueeze(0)
42
+
43
+ # Predict the class
44
+ with torch.no_grad():
45
+ outputs = model(image)
46
+ _, predicted = torch.max(outputs, 1)
47
+ class_name = classes[predicted.item()]
48
+
49
+ return class_name
50
+
51
+ # Gradio Interface
52
+ interface = gr.Interface(
53
+ fn=predict,
54
+ inputs=gr.Image(type="file"),
55
+ outputs="text",
56
+ title="Flower Classification",
57
+ description="Upload an image of a flower to classify it into one of the five categories: daisy, dandelion, rose, sunflower, or tulip."
58
+ )
59
+
60
+ # Launch the Gradio app
61
+ if __name__ == "__main__":
62
+ interface.launch()