vj1148 committed on
Commit
cc0be5b
·
verified ·
1 Parent(s): deadb5b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +178 -0
app.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.transforms as transforms
5
+ from PIL import Image
6
+ import numpy as np
7
+
8
# Define the ResNet architecture (same as in training)
class BasicBlock(nn.Module):
    """Two-convolution residual block (ResNet-18/34 style).

    Each branch is conv3x3 -> BN -> ReLU -> conv3x3 -> BN; the input is
    added back before the final ReLU. When the spatial size or channel
    count changes, the skip connection is a 1x1 strided conv + BN;
    otherwise it is the identity.
    """

    # Output channel multiplier; 1 for basic (non-bottleneck) blocks.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Identity skip by default; projection only when shapes differ.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Run the residual branch, add the skip path, apply final ReLU."""
        residual = self.shortcut(x)
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return self.relu(branch + residual)
36
+
37
class ResNet(nn.Module):
    """CIFAR-style ResNet: 3x3 stem (no max-pool), four stages, FC head.

    Attribute names (conv1, bn1, layer1..layer4, avgpool, fc) are part of
    the checkpoint's state-dict key layout and must not be renamed.
    """

    def __init__(self, block, num_blocks, num_classes=100):
        """Build the network.

        Args:
            block: residual block class exposing an ``expansion`` attribute
                and a ``block(in_channels, out_channels, stride)`` ctor.
            num_blocks: list of four ints, blocks per stage.
            num_classes: size of the final classification layer.
        """
        super(ResNet, self).__init__()
        # Running channel count, advanced by _make_layer as stages grow.
        self.in_channels = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first may downsample."""
        layers = []
        current_stride = stride
        for _ in range(num_blocks):
            layers.append(block(self.in_channels, out_channels, current_stride))
            self.in_channels = out_channels * block.expansion
            current_stride = 1  # all blocks after the first keep resolution
        return nn.Sequential(*layers)

    def forward(self, x):
        """Stem -> four residual stages -> global pool -> logits."""
        h = self.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            h = stage(h)
        h = self.avgpool(h)
        return self.fc(torch.flatten(h, 1))
72
+
73
def ResNet34(num_classes=100):
    """Build a ResNet-34 (BasicBlock, [3, 4, 6, 3] stage depths).

    Args:
        num_classes: size of the classification head. Defaults to 100
            (CIFAR-100), which preserves the original call signature.

    Returns:
        A freshly initialized ``ResNet`` instance.
    """
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)
75
+
76
# CIFAR-100 class names (100 entries, alphabetical).
# Ordering matters: index i must match training label i, because predict()
# maps the model's top-k indices straight into this list.
CIFAR100_CLASSES = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
    'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
    'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
    'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
    'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
    'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
    'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
    'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
    'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
    'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
    'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
    'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
    'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
]
93
+
94
# Load model: runs once at import time; `device`, `model`, `mean`, `std`,
# and `transform` become module globals consumed by predict() below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet34()

# Load checkpoint.
# NOTE(review): torch.load unpickles the file — only safe for a trusted,
# locally bundled checkpoint; confirm the file ships with this Space.
checkpoint = torch.load('cifar100_resnet34_final.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()  # inference mode: fixes BatchNorm running stats

# Get normalization parameters from checkpoint; the fallbacks are the
# commonly used CIFAR-100 per-channel mean/std — presumably the values
# used in training; verify against the training script.
mean = checkpoint.get('normalization_mean', [0.5071, 0.4867, 0.4408])
std = checkpoint.get('normalization_std', [0.2675, 0.2565, 0.2761])

# Define transforms: resize any upload to the 32x32 training resolution,
# then normalize with the training statistics.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
114
+
115
def predict(image):
    """Classify an uploaded image into CIFAR-100 categories.

    Args:
        image: a PIL.Image (Gradio is configured with type="pil") or a
            numpy HxWxC uint8-compatible array; None when nothing was
            uploaded.

    Returns:
        dict mapping the top-5 class names to their softmax probabilities
        (floats), suitable for gr.Label; None when `image` is None.
    """
    if image is None:
        return None

    # Accept raw arrays too, mirroring the original defensive conversion.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'), 'RGB')

    # Force 3 channels: a grayscale or RGBA upload would otherwise yield a
    # 1- or 4-channel tensor and crash the model's 3-channel stem.
    image = image.convert('RGB')

    # Preprocess and add the batch dimension.
    img_tensor = transform(image).unsqueeze(0).to(device)

    # Inference without gradient tracking.
    with torch.no_grad():
        outputs = model(img_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)

    # Top-5 classes with probabilities, keyed by class name.
    top5_prob, top5_idx = torch.topk(probabilities[0], 5)
    return {
        CIFAR100_CLASSES[idx.item()]: float(prob.item())
        for prob, idx in zip(top5_prob, top5_idx)
    }
145
+
146
# Create Gradio interface: UI copy shown above the upload widget.
title = "CIFAR-100 Image Classifier (ResNet34)"
description = """
This is a ResNet34 model trained on CIFAR-100 dataset with 80%+ accuracy.
Upload an image to classify it into one of 100 categories.

The model works best with:
- Natural images (animals, objects, vehicles, etc.)
- Images with clear subjects
- Square aspect ratio images

Note: The model was trained on 32x32 images, so very high resolution details might not be fully utilized.
"""

# Example inputs for the demo gallery; empty for now.
examples = [
    # You can add example image paths here if you have them
]

# Create the interface: predict() returns a name->probability dict, which
# gr.Label renders as a ranked bar list of the top 5 classes.
# NOTE(review): `allow_flagging` is deprecated in newer Gradio releases
# (replaced by `flagging_mode`) — confirm the pinned gradio version.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Label(num_top_classes=5, label="Top 5 Predictions"),
    title=title,
    description=description,
    examples=examples if examples else None,  # None hides the examples row
    theme="default",
    allow_flagging="never"
)

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()