Aumkeshchy2003 committed on
Commit
db8f24a
·
verified ·
1 Parent(s): 00a96fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -43
app.py CHANGED
@@ -13,9 +13,9 @@ cfg = {
13
  "patch_size": 4,
14
  "in_channels": 3,
15
  "num_classes": 100,
16
- "emb_dim": 384,
17
  "num_heads": 6,
18
- "depth": 8,
19
  "mlp_ratio": 4.0,
20
  "drop": 0.1
21
  }
@@ -40,32 +40,32 @@ classes = [
40
 
41
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
 
43
- # ViT model implementation
44
  class ConvPatchEmbed(nn.Module):
45
- def __init__(self, in_chans=3, embed_dim=384):
46
  super().__init__()
47
- # Input 32x32 -> conv1: 32x32 -> conv2 stride2 -> 16x16 -> conv3 stride2 -> 8x8
48
- self.conv = nn.Sequential(
49
  nn.Conv2d(in_chans, 64, kernel_size=3, stride=1, padding=1, bias=False),
50
  nn.BatchNorm2d(64),
51
  nn.ReLU(inplace=True),
52
 
53
- nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
54
  nn.BatchNorm2d(128),
55
  nn.ReLU(inplace=True),
56
 
57
- nn.Conv2d(128, embed_dim, kernel_size=3, stride=2, padding=1, bias=False),
58
  nn.BatchNorm2d(embed_dim),
59
  nn.ReLU(inplace=True),
60
  )
61
- # n_patches = (32/4)^2 = 8*8 = 64
62
- self.n_patches = (32 // 4) ** 2
 
 
63
 
64
  def forward(self, x):
65
- # x: (B, C, H, W)
66
- x = self.conv(x)
67
- x = x.flatten(2)
68
- x = x.transpose(1, 2)
69
  return x
70
 
71
  class MLP(nn.Module):
@@ -76,6 +76,7 @@ class MLP(nn.Module):
76
  self.act = nn.GELU()
77
  self.fc2 = nn.Linear(hidden_features, in_features)
78
  self.drop = nn.Dropout(drop)
 
79
  def forward(self, x):
80
  x = self.fc1(x)
81
  x = self.act(x)
@@ -85,70 +86,87 @@ class MLP(nn.Module):
85
  return x
86
 
87
  class Attention(nn.Module):
88
- def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
89
  super().__init__()
90
  self.num_heads = num_heads
91
  head_dim = dim // num_heads
92
  self.scale = head_dim ** -0.5
93
 
94
- self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias)
95
  self.attn_drop = nn.Dropout(attn_drop)
96
  self.proj = nn.Linear(dim, dim)
97
  self.proj_drop = nn.Dropout(proj_drop)
98
 
99
  def forward(self, x):
100
  B, N, C = x.shape
101
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2,0,3,1,4)
102
- q, k, v = qkv[0], qkv[1], qkv[2] # each: (B, heads, N, head_dim)
 
 
103
  attn = (q @ k.transpose(-2, -1)) * self.scale
104
  attn = attn.softmax(dim=-1)
105
  attn = self.attn_drop(attn)
106
- x = (attn @ v).transpose(1,2).reshape(B, N, C)
 
107
  x = self.proj(x)
108
  x = self.proj_drop(x)
109
  return x
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  class Block(nn.Module):
112
  def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0.):
113
  super().__init__()
114
  self.norm1 = nn.LayerNorm(dim)
115
  self.attn = Attention(dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop)
116
- self.drop_path = nn.Identity() if drop_path == 0. else _StochasticDepth(drop_path)
117
  self.norm2 = nn.LayerNorm(dim)
118
- self.mlp = MLP(dim, int(dim*mlp_ratio), drop=drop)
119
 
120
  def forward(self, x):
121
  x = x + self.drop_path(self.attn(self.norm1(x)))
122
  x = x + self.drop_path(self.mlp(self.norm2(x)))
123
  return x
124
 
125
- # Simple implementation of stochastic depth
126
- class _StochasticDepth(nn.Module):
127
- def __init__(self, p):
128
- super().__init__()
129
- self.p = p
130
- def forward(self, x):
131
- if not self.training or self.p == 0.:
132
- return x
133
- keep = torch.rand(x.shape[0], 1, 1, device=x.device) >= self.p
134
- return x * keep / (1 - self.p)
135
-
136
  class ViT(nn.Module):
137
  def __init__(self, cfg):
138
  super().__init__()
139
- img_size, patch_size = cfg["image_size"], cfg["patch_size"]
140
- # Use ConvPatchEmbed (hybrid) instead of linear patch conv with kernel=patch_size
141
- self.patch_embed = ConvPatchEmbed(cfg["in_channels"], cfg["emb_dim"])
142
- n_patches = self.patch_embed.n_patches
 
 
 
 
143
 
144
- self.cls_token = nn.Parameter(torch.zeros(1,1,cfg["emb_dim"]))
145
  self.pos_embed = nn.Parameter(torch.zeros(1, 1 + n_patches, cfg["emb_dim"]))
146
  self.pos_drop = nn.Dropout(p=cfg["drop"])
147
 
148
- # transformer blocks
149
- dpr = [x.item() for x in torch.linspace(0, cfg.get("drop_path", 0.2), cfg["depth"])] # stochastic depth decay
150
  self.blocks = nn.ModuleList([
151
- Block(cfg["emb_dim"], num_heads=cfg["num_heads"], mlp_ratio=cfg["mlp_ratio"], drop=cfg["drop"], drop_path=dpr[i])
 
 
 
 
 
 
152
  for i in range(cfg["depth"])
153
  ])
154
  self.norm = nn.LayerNorm(cfg["emb_dim"])
@@ -174,9 +192,9 @@ class ViT(nn.Module):
174
 
175
  def forward(self, x):
176
  B = x.shape[0]
177
- x = self.patch_embed(x) # (B, N, E)
178
  cls_tokens = self.cls_token.expand(B, -1, -1)
179
- x = torch.cat((cls_tokens, x), dim=1) # (B, 1+N, E)
180
  x = x + self.pos_embed
181
  x = self.pos_drop(x)
182
 
@@ -190,7 +208,7 @@ class ViT(nn.Module):
190
 
191
 
192
  # Load model weights
193
- checkpoint = torch.load("best_CTD_ViT_CIFAR100_checkpoint.pth", map_location=device)
194
 
195
  model = ViT(cfg).to(device)
196
 
 
13
  "patch_size": 4,
14
  "in_channels": 3,
15
  "num_classes": 100,
16
+ "emb_dim": 192,
17
  "num_heads": 6,
18
+ "depth": 6,
19
  "mlp_ratio": 4.0,
20
  "drop": 0.1
21
  }
 
40
 
41
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
 
 
43
class ConvPatchEmbed(nn.Module):
    """Hybrid convolutional patch embedding for small (CIFAR-sized) images.

    A three-layer conv stem maps an image to a sequence of patch tokens:
    32x32 -> 32x32 -> 16x16 -> 16x16, i.e. a single stride-2 reduction,
    yielding ``(img_size // 2) ** 2`` tokens of dimension ``embed_dim``.

    Args:
        img_size: Input height/width in pixels (assumed square).
        in_chans: Number of input image channels.
        embed_dim: Output token (embedding) dimension.
    """

    def __init__(self, img_size=32, in_chans=3, embed_dim=192):
        super().__init__()
        # 32x32 -> 32x32 -> 16x16 -> 16x16
        self.proj = nn.Sequential(
            nn.Conv2d(in_chans, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),  # 32 -> 16
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, embed_dim, kernel_size=3, stride=1, padding=1, bias=False),  # stays 16x16
            nn.BatchNorm2d(embed_dim),
            nn.ReLU(inplace=True),
        )

        # Only one stride-2 conv above, so the spatial grid is halved once.
        grid_size = (img_size // 2, img_size // 2)  # (16, 16) for img_size=32
        self.grid_size = grid_size
        self.num_patches = grid_size[0] * grid_size[1]

    def forward(self, x):
        """Embed an image batch (B, C, H, W) into patch tokens (B, N, embed_dim)."""
        x = self.proj(x)  # (B, embed_dim, H/2, W/2)
        # Fix: the original unpacked B, C, H, W = x.shape and never used them.
        x = x.flatten(2).transpose(1, 2)  # (B, N = (H/2)*(W/2), embed_dim)
        return x
70
 
71
  class MLP(nn.Module):
 
76
  self.act = nn.GELU()
77
  self.fc2 = nn.Linear(hidden_features, in_features)
78
  self.drop = nn.Dropout(drop)
79
+
80
  def forward(self, x):
81
  x = self.fc1(x)
82
  x = self.act(x)
 
86
  return x
87
 
88
  class Attention(nn.Module):
89
+ def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0., proj_drop=0.):
90
  super().__init__()
91
  self.num_heads = num_heads
92
  head_dim = dim // num_heads
93
  self.scale = head_dim ** -0.5
94
 
95
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
96
  self.attn_drop = nn.Dropout(attn_drop)
97
  self.proj = nn.Linear(dim, dim)
98
  self.proj_drop = nn.Dropout(proj_drop)
99
 
100
  def forward(self, x):
101
  B, N, C = x.shape
102
+ # (B, N, 3C) -> (3, B, heads, N, head_dim)
103
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
104
+ q, k, v = qkv[0], qkv[1], qkv[2]
105
+
106
  attn = (q @ k.transpose(-2, -1)) * self.scale
107
  attn = attn.softmax(dim=-1)
108
  attn = self.attn_drop(attn)
109
+
110
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
111
  x = self.proj(x)
112
  x = self.proj_drop(x)
113
  return x
114
 
115
+ # Simple Stochastic Depth
116
+ class StochasticDepth(nn.Module):
117
+ def __init__(self, p):
118
+ super().__init__()
119
+ self.p = float(p)
120
+
121
+ def forward(self, x):
122
+ if not self.training or self.p == 0.0:
123
+ return x
124
+ keep_prob = 1.0 - self.p
125
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
126
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
127
+ random_tensor.floor_()
128
+ return x / keep_prob * random_tensor
129
+
130
class Block(nn.Module):
    """Pre-norm transformer encoder block: self-attention + MLP, each residual.

    Both sub-layers follow x + drop_path(sublayer(norm(x))), the standard
    pre-LayerNorm ViT arrangement; ``drop_path`` applies stochastic depth
    when drop_path > 0 and is an identity otherwise.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop)
        # Per-block stochastic depth; identity when the rate is zero.
        self.drop_path = StochasticDepth(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * mlp_ratio), drop=drop)

    def forward(self, x):
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_out)
143
 
 
 
 
 
 
 
 
 
 
 
 
144
  class ViT(nn.Module):
145
  def __init__(self, cfg):
146
  super().__init__()
147
+ img_size = cfg["image_size"]
148
+
149
+ self.patch_embed = ConvPatchEmbed(
150
+ img_size=img_size,
151
+ in_chans=cfg["in_channels"],
152
+ embed_dim=cfg["emb_dim"]
153
+ )
154
+ n_patches = self.patch_embed.num_patches
155
 
156
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, cfg["emb_dim"]))
157
  self.pos_embed = nn.Parameter(torch.zeros(1, 1 + n_patches, cfg["emb_dim"]))
158
  self.pos_drop = nn.Dropout(p=cfg["drop"])
159
 
160
+ # stochastic depth decay rule
161
+ dpr = torch.linspace(0, cfg["drop_path"], cfg["depth"]).tolist()
162
  self.blocks = nn.ModuleList([
163
+ Block(
164
+ dim=cfg["emb_dim"],
165
+ num_heads=cfg["num_heads"],
166
+ mlp_ratio=cfg["mlp_ratio"],
167
+ drop=cfg["drop"],
168
+ drop_path=dpr[i]
169
+ )
170
  for i in range(cfg["depth"])
171
  ])
172
  self.norm = nn.LayerNorm(cfg["emb_dim"])
 
192
 
193
  def forward(self, x):
194
  B = x.shape[0]
195
+ x = self.patch_embed(x) # (B, N, E)
196
  cls_tokens = self.cls_token.expand(B, -1, -1)
197
+ x = torch.cat((cls_tokens, x), dim=1) # (B, 1+N, E)
198
  x = x + self.pos_embed
199
  x = self.pos_drop(x)
200
 
 
208
 
209
 
210
  # Load model weights
211
+ checkpoint = torch.load("Revised_best_ViT_CIFAR100_baseline_checkpoint.pth", map_location=device)
212
 
213
  model = ViT(cfg).to(device)
214