Aumkeshchy2003 committed on
Commit
3b058d0
·
verified ·
1 Parent(s): 03d5776

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -0
app.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import transforms
5
+ from einops import rearrange
6
+ import gradio as gr
7
+ from PIL import Image
8
+ import math
9
+
10
# ------------------------
# Configuration (must match your trained model)
cfg = {
    "image_size": 32,    # input resolution (CIFAR images are 32x32)
    "patch_size": 4,     # 4x4 patches -> (32/4)^2 = 64 patches per image
    "in_channels": 3,    # RGB input
    "num_classes": 100,  # CIFAR-100 has 100 classes
    "emb_dim": 192,      # transformer embedding width
    "num_heads": 6,      # attention heads (192 / 6 = 32 dims per head)
    "depth": 8,          # number of transformer encoder blocks
    "mlp_ratio": 4.0,    # MLP hidden width = mlp_ratio * emb_dim
    "drop": 0.1          # dropout probability used throughout the model
}
23
+
24
# CIFAR-100 class names
# Index i maps to model output logit i, so this list must stay in the
# dataset's canonical (alphabetical) order for predictions to be correct.
classes = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
    'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
    'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
    'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
    'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
    'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
    'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
    'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
    'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea',
    'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider',
    'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank',
    'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip',
    'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
]
41
+
42
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
43
+
44
+ # ------------------------
45
+ # Model definition
46
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each.

    Implemented as a single strided convolution whose kernel size and stride
    both equal the patch size, followed by flattening the spatial grid into a
    token sequence of shape (B, n_patches, embed_dim).
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=192):
        super().__init__()
        self.patch_size = patch_size
        # Number of patch tokens produced per image.
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (B, C, H, W) -> (B, embed_dim, H/p, W/p)
        embedded = self.proj(x)
        # Flatten the spatial grid and move channels last:
        # (B, embed_dim, H/p, W/p) -> (B, n_patches, embed_dim)
        tokens = embedded.flatten(start_dim=2).transpose(1, 2)
        return tokens
57
+
58
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> drop -> Linear -> drop.

    When ``hidden_features`` is falsy, the hidden width defaults to
    ``in_features``; the output width always equals the input width.
    """

    def __init__(self, in_features, hidden_features=None, drop=0.):
        super().__init__()
        # Falsy (None/0) hidden width falls back to the input width.
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # Expand, activate, drop, project back down, drop again.
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
73
+
74
class Attention(nn.Module):
    """Multi-head self-attention over a token sequence of shape (B, N, C)."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5  # 1/sqrt(d_k) attention scaling
        # Single projection producing q, k and v in one matmul.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, head_dim), then split into q/k/v.
        qkv = (
            self.qkv(x)
            .reshape(batch, tokens, 3, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)
        # Scaled dot-product attention weights, normalized over keys.
        weights = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))
        # Weighted sum of values; merge heads back into the channel dim.
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
95
+
96
+ class _StochasticDepth(nn.Module):
97
+ def __init__(self, p):
98
+ super().__init__()
99
+ self.p = p
100
+ def forward(self, x):
101
+ if not self.training or self.p==0.:
102
+ return x
103
+ keep = torch.rand(x.shape[0],1,1, device=x.device) >= self.p
104
+ return x * keep / (1 - self.p)
105
+
106
class Block(nn.Module):
    """Pre-norm transformer encoder block.

    Self-attention and MLP sub-layers, each applied after LayerNorm and added
    back through a residual connection that may be gated by stochastic depth.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., drop_path=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, attn_drop=drop, proj_drop=drop)
        # A plain Identity when drop_path is zero skips the RNG work entirely.
        self.drop_path = _StochasticDepth(drop_path) if drop_path != 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * mlp_ratio), drop=drop)

    def forward(self, x):
        # Attention sub-layer with residual.
        x = x + self.drop_path(self.attn(self.norm1(x)))
        # Feed-forward sub-layer with residual.
        return x + self.drop_path(self.mlp(self.norm2(x)))
118
+
119
class ViT(nn.Module):
    """Vision Transformer classifier configured by a ``cfg`` dict.

    Pipeline: patch embedding -> prepend CLS token -> add learned positional
    embeddings -> ``depth`` transformer blocks (with linearly increasing
    stochastic-depth rates from 0 up to 0.1) -> LayerNorm -> linear head on
    the CLS token.
    """

    def __init__(self, cfg):
        super().__init__()
        dim = cfg["emb_dim"]
        self.patch_embed = PatchEmbed(cfg["image_size"], cfg["patch_size"],
                                      cfg["in_channels"], dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        # One position per patch plus one for the CLS token.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, 1 + self.patch_embed.n_patches, dim))
        self.pos_drop = nn.Dropout(p=cfg["drop"])
        # Per-block drop-path rates, linearly spaced from 0 to 0.1.
        rates = [r.item() for r in torch.linspace(0, 0.1, cfg["depth"])]
        self.blocks = nn.ModuleList(
            Block(dim, cfg["num_heads"], cfg["mlp_ratio"], cfg["drop"], rate)
            for rate in rates
        )
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, cfg["num_classes"])
        # Parameter init: truncated normal for the learned tokens, then
        # xavier/zeros/ones for the submodules via _init_weights.
        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Xavier-uniform for linear layers; standard zeros/ones for LayerNorm.
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)

    def forward(self, x):
        batch = x.shape[0]
        tokens = self.patch_embed(x)                # (B, n_patches, dim)
        cls = self.cls_token.expand(batch, -1, -1)  # one CLS token per sample
        tokens = torch.cat((cls, tokens), dim=1)
        tokens = self.pos_drop(tokens + self.pos_embed)
        for block in self.blocks:
            tokens = block(tokens)
        tokens = self.norm(tokens)
        # Classify from the CLS token only.
        return self.head(tokens[:, 0])
155
+
156
# ------------------------
# Load model weights
model = ViT(cfg).to(device)
# weights_only=True restricts torch.load's unpickler to tensor data, which is
# all a state_dict contains; it prevents arbitrary-code execution from a
# tampered checkpoint file.
state_dict = torch.load("best_vit_cifar100.pt", map_location=device,
                        weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout and stochastic depth
161
+
162
# ------------------------
# Image preprocessing: resize any input to the model's 32x32 resolution,
# convert to a [0, 1] float tensor, then standardize per channel.
transform = transforms.Compose([
    transforms.Resize((32,32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)) # CIFAR-100 stats
])
169
+
170
def predict(img: Image.Image) -> str:
    """Classify a PIL image and return the predicted CIFAR-100 class name.

    The image is forced to RGB first: Normalize above uses 3-channel stats,
    so a grayscale or RGBA upload would otherwise fail in preprocessing.
    """
    batch = transform(img.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(batch)
    pred = logits.argmax(dim=1).item()
    return classes[pred]
176
+
177
# ------------------------
# Gradio interface: one image input routed through predict() to a label output.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),          # hands the callback a PIL.Image
    outputs=gr.Label(num_top_classes=1),  # displays the single predicted class
    title="ViT CIFAR-100 Classifier",
    description="Upload a 32x32 image, and the model predicts the CIFAR-100 class."
)

# Start the local web server (blocks until the app is stopped).
iface.launch()