Uzairabbasi commited on
Commit
433d89f
·
verified ·
1 Parent(s): d3d7d5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -0
app.py CHANGED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+
7
+ # Define the complex CNN model
8
+ class ComplexCNN(nn.Module):
9
+ def __init__(self, num_classes=4):
10
+ super(ComplexCNN, self).__init__()
11
+ self.layer1 = nn.Sequential(
12
+ nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
13
+ nn.BatchNorm2d(64),
14
+ nn.ReLU(),
15
+ nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
16
+ )
17
+ self.layer2 = nn.Sequential(
18
+ nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
19
+ nn.BatchNorm2d(128),
20
+ nn.ReLU(),
21
+ nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
22
+ )
23
+ self.layer3 = nn.Sequential(
24
+ nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
25
+ nn.BatchNorm2d(256),
26
+ nn.ReLU(),
27
+ nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
28
+ )
29
+ self.layer4 = nn.Sequential(
30
+ nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
31
+ nn.BatchNorm2d(512),
32
+ nn.ReLU(),
33
+ nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
34
+ )
35
+ self.fc1 = nn.Linear(512 * 2 * 2, 1024) # Adjust based on the final feature map size
36
+ self.fc2 = nn.Linear(1024, 512)
37
+ self.fc3 = nn.Linear(512, num_classes)
38
+ self.relu = nn.ReLU()
39
+ self.dropout = nn.Dropout(p=0.5)
40
+
41
+ def forward(self, x):
42
+ x = self.layer1(x)
43
+ x = self.layer2(x)
44
+ x = self.layer3(x)
45
+ x = self.layer4(x)
46
+ x = x.view(x.size(0), -1) # Flatten the output
47
+ x = self.dropout(self.relu(self.fc1(x)))
48
+ x = self.relu(self.fc2(x))
49
+ x = self.fc3(x)
50
+ return x
51
+
52
+ # Load the trained model
53
+ model = ComplexCNN(num_classes=4)
54
+ model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device('cpu')))
55
+ model.eval()
56
+
57
+ # Define the transformation
58
+ transform = transforms.Compose([
59
+ transforms.Resize((32, 32)),
60
+ transforms.ToTensor(),
61
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
62
+ ])
63
+
64
+ # Define the class labels
65
+ class_labels = ['dog', 'goat', 'lion', 'sheep']
66
+
67
+ # Function to predict the class of an uploaded image
68
+ def predict(image):
69
+ image = transform(image).unsqueeze(0) # Transform and add batch dimension
70
+ with torch.no_grad():
71
+ outputs = model(image)
72
+ _, predicted = torch.max(outputs, 1)
73
+ predicted_class = class_labels[predicted.item()]
74
+ return predicted_class
75
+
76
+ # Create the Gradio interface
77
+ interface = gr.Interface(
78
+ fn=predict,
79
+ inputs=gr.Image(type="pil"),
80
+ outputs=gr.Textbox(label="Predicted Class")
81
+ )
82
+
83
+ # Launch the Gradio app
84
+ if __name__ == "__main__":
85
+ interface.launch()