ivantv commited on
Commit
b4fc516
·
verified ·
1 Parent(s): 61d00be

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +117 -0
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from PIL import Image
4
+ import torchvision.transforms as transforms
5
+ import gradio as gr
6
+ from huggingface_hub import hf_hub_download
7
+
8
+ CIFAR100_CLASSES = [
9
+ 'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
10
+ 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
11
+ 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
12
+ 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
13
+ 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
14
+ 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
15
+ 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
16
+ 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
17
+ 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
18
+ 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea',
19
+ 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider',
20
+ 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank',
21
+ 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip',
22
+ 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
23
+ ]
24
+
25
+ class BasicBlock(nn.Module):
26
+ expansion = 1
27
+ def __init__(self, in_channels, out_channels, stride=1):
28
+ super(BasicBlock, self).__init__()
29
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
30
+ self.bn1 = nn.BatchNorm2d(out_channels)
31
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
32
+ self.bn2 = nn.BatchNorm2d(out_channels)
33
+ self.shortcut = nn.Sequential()
34
+ if stride != 1 or in_channels != self.expansion * out_channels:
35
+ self.shortcut = nn.Sequential(
36
+ nn.Conv2d(in_channels, self.expansion * out_channels, kernel_size=1, stride=stride, bias=False),
37
+ nn.BatchNorm2d(self.expansion * out_channels)
38
+ )
39
+ def forward(self, x):
40
+ out = torch.relu(self.bn1(self.conv1(x)))
41
+ out = self.bn2(self.conv2(out))
42
+ out += self.shortcut(x)
43
+ out = torch.relu(out)
44
+ return out
45
+
46
+ class ResNet(nn.Module):
47
+ def __init__(self, block, num_blocks, num_classes=100):
48
+ super(ResNet, self).__init__()
49
+ self.in_channels = 64
50
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
51
+ self.bn1 = nn.BatchNorm2d(64)
52
+ self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
53
+ self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
54
+ self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
55
+ self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
56
+ self.linear = nn.Linear(512 * block.expansion, num_classes)
57
+ def _make_layer(self, block, out_channels, num_blocks, stride):
58
+ strides = [stride] + [1] * (num_blocks - 1)
59
+ layers = []
60
+ for stride in strides:
61
+ layers.append(block(self.in_channels, out_channels, stride))
62
+ self.in_channels = out_channels * block.expansion
63
+ return nn.Sequential(*layers)
64
+ def forward(self, x):
65
+ out = torch.relu(self.bn1(self.conv1(x)))
66
+ out = self.layer1(out)
67
+ out = self.layer2(out)
68
+ out = self.layer3(out)
69
+ out = self.layer4(out)
70
+ out = torch.nn.functional.avg_pool2d(out, 4)
71
+ out = out.view(out.size(0), -1)
72
+ out = self.linear(out)
73
+ return out
74
+
75
+ def ResNet18():
76
+ return ResNet(BasicBlock, [2, 2, 2, 2])
77
+
78
+ print("Loading model...")
79
+ model_path = hf_hub_download(repo_id="ivantv/cifar100-resnet18", filename="best_model.pth")
80
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
81
+ model = ResNet18().to(device)
82
+ model.load_state_dict(torch.load(model_path, map_location=device))
83
+ model.eval()
84
+ print("Model loaded!")
85
+
86
+ transform = transforms.Compose([
87
+ transforms.Resize((32, 32)),
88
+ transforms.ToTensor(),
89
+ transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
90
+ ])
91
+
92
+ def classify_image(image):
93
+ if image is None:
94
+ return {}
95
+ if not isinstance(image, Image.Image):
96
+ image = Image.fromarray(image)
97
+ if image.mode != 'RGB':
98
+ image = image.convert('RGB')
99
+
100
+ img_tensor = transform(image).unsqueeze(0).to(device)
101
+ with torch.no_grad():
102
+ output = model(img_tensor)
103
+ probs = torch.nn.functional.softmax(output, dim=1)[0]
104
+ top5_prob, top5_idx = torch.topk(probs, 5)
105
+
106
+ return {CIFAR100_CLASSES[idx]: prob.item() for prob, idx in zip(top5_prob, top5_idx)}
107
+
108
+ iface = gr.Interface(
109
+ fn=classify_image,
110
+ inputs=gr.Image(type="pil"),
111
+ outputs=gr.Label(num_top_classes=5),
112
+ title="CIFAR-100 Classifier",
113
+ description="ResNet-18 model trained on CIFAR-100 (75.84% accuracy)"
114
+ )
115
+
116
+ iface.launch()
117
+