Jagjeet2003 commited on
Commit
55b8a37
·
verified ·
1 Parent(s): 0fd920d

Upload 3 files

Browse files
Files changed (3) hide show
  1. Modified_ALexnet_for_CIFAR.pth +3 -0
  2. app.py +30 -0
  3. model.py +71 -0
Modified_ALexnet_for_CIFAR.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d2363681901b216dabb565c232ddbceb945c2092e6ad7d41c5b5f540342667ea
3
+ size 30058715
app.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from model import ALexNet # Make sure this matches your actual class name
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+
7
+ # Load model
8
+ model = ALexNet(3, 64, 10)
9
+ model.load_state_dict(torch.load("Modified_ALexnet_for_CIFAR.pth", map_location=torch.device("cpu")))
10
+ model.eval()
11
+
12
+ # Preprocessing
13
+ transform = transforms.Compose([
14
+ transforms.Resize((32, 32)), # Adjust to your model's input size
15
+ transforms.ToTensor()
16
+ ])
17
+
18
+ # Inference function
19
+ def predict(img):
20
+ img = transform(img).unsqueeze(0)
21
+ with torch.no_grad():
22
+ outputs = model(img)
23
+ predicted_class = torch.argmax(outputs, dim=1).item()
24
+ class_names = ["airplane", "automobile", "bird", "cat", "deer",
25
+ "dog", "frog", "horse", "ship", "truck"]
26
+
27
+ return f"Predicted class: {class_names[predicted_class]}"
28
+
29
+ # Gradio UI
30
+ gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="text").launch()
model.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ class ALexNet(nn.Module):
5
+ def __init__(self, input_shape: int, hidden_units: int, output_shape):
6
+ super().__init__()
7
+ self.block1 = nn.Sequential(
8
+ nn.Conv2d(input_shape, 64, kernel_size=3, padding=1),
9
+ nn.BatchNorm2d(64),
10
+ nn.ReLU(),
11
+ nn.MaxPool2d(2, 2)
12
+ )
13
+ self.block2 = nn.Sequential(
14
+ nn.Conv2d(64, 192, kernel_size=3, padding=1),
15
+ nn.BatchNorm2d(192),
16
+ nn.ReLU(),
17
+ nn.MaxPool2d(2, 2)
18
+ )
19
+ self.block3 = nn.Sequential(
20
+ nn.Conv2d(192, 384, kernel_size=3, padding=1),
21
+ nn.BatchNorm2d(384),
22
+ nn.ReLU()
23
+ )
24
+ self.block4 = nn.Sequential(
25
+ nn.Conv2d(384, 256, kernel_size=3, padding=1),
26
+ nn.BatchNorm2d(256),
27
+ nn.ReLU()
28
+ )
29
+ self.block5 = nn.Sequential(
30
+ nn.Conv2d(256, 256, kernel_size=3, padding=1),
31
+ nn.BatchNorm2d(256),
32
+ nn.ReLU(),
33
+ nn.MaxPool2d(2, 2)
34
+ )
35
+
36
+ with torch.no_grad():
37
+ dummy = torch.zeros(1, input_shape, 32, 32) # change 224 if needed
38
+ x = self.block1(dummy)
39
+ x = self.block2(x)
40
+ x = self.block3(x)
41
+ x = self.block4(x)
42
+ x = self.block5(x)
43
+ self.flattened_size = x.view(1, -1).shape[1]
44
+ self.flatten = nn.Flatten()
45
+ self.fc1 = nn.Sequential(
46
+ nn.Linear(in_features=self.flattened_size,
47
+ out_features=1024),
48
+ nn.ReLU(),
49
+ nn.Dropout(0.5)
50
+ )
51
+ self.fc2 = nn.Sequential(
52
+ nn.Linear(1024, 1024),
53
+ nn.ReLU(),
54
+ nn.Dropout(0.5)
55
+ )
56
+ self.classifier = nn.Sequential(
57
+ nn.Linear(1024, output_shape)
58
+ )
59
+
60
+ def forward(self, x: torch.Tensor):
61
+ x = self.block1(x)
62
+ x = self.block2(x)
63
+ x = self.block3(x)
64
+ x = self.block4(x)
65
+ x = self.block5(x)
66
+ x = self.flatten(x)
67
+ x = self.fc1(x)
68
+ x = self.fc2(x)
69
+ x = self.classifier(x)
70
+
71
+ return x