trafaqat commited on
Commit
481e120
·
verified ·
1 Parent(s): eb61e20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -141
app.py CHANGED
@@ -4,149 +4,15 @@ from huggingface_hub import hf_hub_download
4
  import torch
5
  import torch.nn as nn
6
  from torchvision import transforms
 
7
 
8
-
9
- class SimpleResidualBlock(nn.Module):
10
- def __init__(self, in_channels, out_channels, set_stride=False):
11
- super().__init__()
12
- stride = 2 if in_channels != out_channels and set_stride else 1
13
-
14
- self.conv1 = nn.LazyConv2d(
15
- out_channels,
16
- kernel_size=3,
17
- padding="same" if stride == 1 else 1,
18
- stride=stride,
19
- )
20
- self.conv2 = nn.LazyConv2d(out_channels, kernel_size=3, padding="same")
21
-
22
- self.bn1 = nn.LazyBatchNorm2d()
23
- self.bn2 = nn.LazyBatchNorm2d()
24
-
25
- self.relu = nn.ReLU()
26
-
27
- if in_channels != out_channels:
28
- self.residual = nn.Sequential(
29
- nn.LazyConv2d(out_channels, kernel_size=1, stride=stride),
30
- nn.LazyBatchNorm2d(),
31
- )
32
- else:
33
- self.residual = nn.Identity()
34
-
35
- def forward(self, x):
36
- out = self.relu(self.bn1(self.conv1(x)))
37
- out = self.bn2(self.conv2(out))
38
- out += self.residual(x)
39
- out = self.relu(out)
40
- return out
41
-
42
-
43
- class BottleneckResidualBlock(nn.Module):
44
- def __init__(
45
- self, in_channels, out_channels, identity_mapping=False, set_stride=False
46
- ):
47
- super().__init__()
48
- stride = 2 if in_channels != out_channels and set_stride else 1
49
-
50
- self.conv1 = nn.LazyConv2d(
51
- out_channels,
52
- kernel_size=1,
53
- padding="same" if stride == 1 else 0,
54
- stride=stride,
55
- )
56
- self.conv2 = nn.LazyConv2d(out_channels, kernel_size=3, padding="same")
57
- self.conv3 = nn.LazyConv2d(out_channels * 4, kernel_size=1, padding="same")
58
-
59
- self.bn1 = nn.LazyBatchNorm2d()
60
- self.bn2 = nn.LazyBatchNorm2d()
61
- self.bn3 = nn.LazyBatchNorm2d()
62
-
63
- self.relu = nn.ReLU()
64
-
65
- if in_channels != out_channels or not identity_mapping:
66
- self.residual = nn.Sequential(
67
- nn.LazyConv2d(out_channels * 4, kernel_size=1, stride=stride),
68
- nn.LazyBatchNorm2d(),
69
- )
70
- else:
71
- self.residual = nn.Identity()
72
-
73
- def forward(self, x):
74
- out = self.relu(self.bn1(self.conv1(x)))
75
- out = self.relu(self.bn2(self.conv2(out)))
76
- out = self.bn3(self.conv3(out))
77
- out += self.residual(x)
78
- out = self.relu(out)
79
- return out
80
-
81
-
82
- RESNET_18 = [2, 2, 2, 2]
83
- RESNET_34 = [3, 4, 6, 3]
84
- RESNET_50 = [3, 4, 6, 3]
85
- RESNET_101 = [3, 4, 23, 3]
86
- RESNET_152 = [3, 8, 36, 3]
87
-
88
-
89
- class ResNet(nn.Module):
90
- def __init__(self, arch=RESNET_18, block="simple", num_classes=256):
91
- super().__init__()
92
- self.conv1 = nn.Sequential(
93
- nn.LazyConv2d(64, kernel_size=7, stride=2, padding=3),
94
- nn.LazyBatchNorm2d(),
95
- nn.ReLU(),
96
- )
97
- self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
98
- self.conv2 = self._make_layer(64, 64, arch[0], set_stride=False, block=block)
99
- self.conv3 = self._make_layer(64, 128, arch[1], block=block)
100
- self.conv4 = self._make_layer(128, 256, arch[2], block=block)
101
- self.conv5 = self._make_layer(256, 512, arch[3], block=block)
102
- self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
103
- self.flatten = nn.Flatten()
104
- self.fc = nn.LazyLinear(num_classes)
105
-
106
- def _make_layer(
107
- self, in_channels, out_channels, num_blocks, set_stride=True, block="simple"
108
- ):
109
- """Block is either 'simple' or 'bottleneck'"""
110
- layers = []
111
- for i in range(num_blocks):
112
- layers.append(
113
- SimpleResidualBlock(in_channels, out_channels, set_stride=set_stride)
114
- if block == "simple"
115
- else BottleneckResidualBlock(
116
- in_channels if i == 0 else out_channels * 4,
117
- out_channels,
118
- set_stride=set_stride,
119
- )
120
- )
121
- set_stride = False
122
- return nn.Sequential(*layers)
123
-
124
- def forward(self, x):
125
- out = self.conv1(x)
126
- out = self.maxpool(self.conv2(out))
127
- out = self.conv3(out)
128
- out = self.conv4(out)
129
- out = self.conv5(out)
130
- out = self.avgpool(out)
131
- out = self.flatten(out)
132
- out = self.fc(out)
133
- return out
134
-
135
- def _init_weights(module):
136
- # Initlize weights with glorot uniform
137
- if isinstance(module, nn.Conv2d):
138
- nn.init.xavier_uniform_(module.weight)
139
- nn.init.zeros_(module.bias)
140
- elif isinstance(module, nn.Linear):
141
- nn.init.xavier_uniform_(module.weight)
142
- nn.init.zeros_(module.bias)
143
-
144
 
145
  class ImageClassifier:
146
  def __init__(self, checkpoint_path):
147
  self.checkpoint_path = checkpoint_path
148
  self.model = self.load_model(checkpoint_path)
149
- self.transform = self.get_transform((244, 244))
150
  self.labels = [
151
  "airplane",
152
  "automobile",
@@ -166,7 +32,7 @@ class ImageClassifier:
166
  block="simple",
167
  num_classes=10,
168
  )
169
- classifier.load_state_dict(torch.load(checkpoint_path))
170
  classifier = classifier.cpu()
171
  classifier.eval()
172
  return classifier
@@ -190,18 +56,18 @@ class ImageClassifier:
190
  def classify(self, input_image):
191
  return self.predict(input_image)
192
 
193
-
194
  def classify(input_image):
195
  return classifier.classify(input_image)
196
 
197
-
198
  checkpoint_path = hf_hub_download(
199
  repo_id="SatwikKambham/resnet18-cifar10",
200
  filename="model.pt",
201
  )
 
202
  classifier = ImageClassifier(checkpoint_path)
 
203
  iface = gr.Interface(
204
- classify,
205
  inputs=[
206
  gr.Image(label="Input Image", type="pil"),
207
  ],
 
4
  import torch
5
  import torch.nn as nn
6
  from torchvision import transforms
7
+ from PIL import Image
8
 
9
+ # Define the model classes (SimpleResidualBlock, BottleneckResidualBlock, ResNet) here...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  class ImageClassifier:
12
  def __init__(self, checkpoint_path):
13
  self.checkpoint_path = checkpoint_path
14
  self.model = self.load_model(checkpoint_path)
15
+ self.transform = self.get_transform((224, 224)) # Typical size for ResNet
16
  self.labels = [
17
  "airplane",
18
  "automobile",
 
32
  block="simple",
33
  num_classes=10,
34
  )
35
+ classifier.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
36
  classifier = classifier.cpu()
37
  classifier.eval()
38
  return classifier
 
56
  def classify(self, input_image):
57
  return self.predict(input_image)
58
 
 
59
  def classify(input_image):
60
  return classifier.classify(input_image)
61
 
 
62
  checkpoint_path = hf_hub_download(
63
  repo_id="SatwikKambham/resnet18-cifar10",
64
  filename="model.pt",
65
  )
66
+
67
  classifier = ImageClassifier(checkpoint_path)
68
+
69
  iface = gr.Interface(
70
+ fn=classify,
71
  inputs=[
72
  gr.Image(label="Input Image", type="pil"),
73
  ],