Update app.py
Browse files
app.py
CHANGED
|
@@ -13,9 +13,9 @@ cfg = {
|
|
| 13 |
"patch_size": 4,
|
| 14 |
"in_channels": 3,
|
| 15 |
"num_classes": 100,
|
| 16 |
-
"emb_dim":
|
| 17 |
"num_heads": 6,
|
| 18 |
-
"depth":
|
| 19 |
"mlp_ratio": 4.0,
|
| 20 |
"drop": 0.1
|
| 21 |
}
|
|
@@ -40,32 +40,32 @@ classes = [
|
|
| 40 |
|
| 41 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 42 |
|
| 43 |
-
# ViT model implementation
|
| 44 |
class ConvPatchEmbed(nn.Module):
|
| 45 |
-
def __init__(self, in_chans=3, embed_dim=
|
| 46 |
super().__init__()
|
| 47 |
-
#
|
| 48 |
-
self.
|
| 49 |
nn.Conv2d(in_chans, 64, kernel_size=3, stride=1, padding=1, bias=False),
|
| 50 |
nn.BatchNorm2d(64),
|
| 51 |
nn.ReLU(inplace=True),
|
| 52 |
|
| 53 |
-
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
|
| 54 |
nn.BatchNorm2d(128),
|
| 55 |
nn.ReLU(inplace=True),
|
| 56 |
|
| 57 |
-
nn.Conv2d(128, embed_dim, kernel_size=3, stride=
|
| 58 |
nn.BatchNorm2d(embed_dim),
|
| 59 |
nn.ReLU(inplace=True),
|
| 60 |
)
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
| 63 |
|
| 64 |
def forward(self, x):
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
x = x.flatten(2)
|
| 68 |
-
x = x.transpose(1, 2)
|
| 69 |
return x
|
| 70 |
|
| 71 |
class MLP(nn.Module):
|
|
@@ -76,6 +76,7 @@ class MLP(nn.Module):
|
|
| 76 |
self.act = nn.GELU()
|
| 77 |
self.fc2 = nn.Linear(hidden_features, in_features)
|
| 78 |
self.drop = nn.Dropout(drop)
|
|
|
|
| 79 |
def forward(self, x):
|
| 80 |
x = self.fc1(x)
|
| 81 |
x = self.act(x)
|
|
@@ -85,70 +86,87 @@ class MLP(nn.Module):
|
|
| 85 |
return x
|
| 86 |
|
| 87 |
class Attention(nn.Module):
|
| 88 |
-
def __init__(self, dim, num_heads=8, qkv_bias=
|
| 89 |
super().__init__()
|
| 90 |
self.num_heads = num_heads
|
| 91 |
head_dim = dim // num_heads
|
| 92 |
self.scale = head_dim ** -0.5
|
| 93 |
|
| 94 |
-
self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias)
|
| 95 |
self.attn_drop = nn.Dropout(attn_drop)
|
| 96 |
self.proj = nn.Linear(dim, dim)
|
| 97 |
self.proj_drop = nn.Dropout(proj_drop)
|
| 98 |
|
| 99 |
def forward(self, x):
|
| 100 |
B, N, C = x.shape
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
| 103 |
attn = (q @ k.transpose(-2, -1)) * self.scale
|
| 104 |
attn = attn.softmax(dim=-1)
|
| 105 |
attn = self.attn_drop(attn)
|
| 106 |
-
|
|
|
|
| 107 |
x = self.proj(x)
|
| 108 |
x = self.proj_drop(x)
|
| 109 |
return x
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
class Block(nn.Module):
|
| 112 |
def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0.):
|
| 113 |
super().__init__()
|
| 114 |
self.norm1 = nn.LayerNorm(dim)
|
| 115 |
self.attn = Attention(dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop)
|
| 116 |
-
self.drop_path =
|
| 117 |
self.norm2 = nn.LayerNorm(dim)
|
| 118 |
-
self.mlp = MLP(dim, int(dim*mlp_ratio), drop=drop)
|
| 119 |
|
| 120 |
def forward(self, x):
|
| 121 |
x = x + self.drop_path(self.attn(self.norm1(x)))
|
| 122 |
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 123 |
return x
|
| 124 |
|
| 125 |
-
# Simple implementation of stochastic depth
|
| 126 |
-
class _StochasticDepth(nn.Module):
|
| 127 |
-
def __init__(self, p):
|
| 128 |
-
super().__init__()
|
| 129 |
-
self.p = p
|
| 130 |
-
def forward(self, x):
|
| 131 |
-
if not self.training or self.p == 0.:
|
| 132 |
-
return x
|
| 133 |
-
keep = torch.rand(x.shape[0], 1, 1, device=x.device) >= self.p
|
| 134 |
-
return x * keep / (1 - self.p)
|
| 135 |
-
|
| 136 |
class ViT(nn.Module):
|
| 137 |
def __init__(self, cfg):
|
| 138 |
super().__init__()
|
| 139 |
-
img_size
|
| 140 |
-
|
| 141 |
-
self.patch_embed = ConvPatchEmbed(
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
-
self.cls_token = nn.Parameter(torch.zeros(1,1,cfg["emb_dim"]))
|
| 145 |
self.pos_embed = nn.Parameter(torch.zeros(1, 1 + n_patches, cfg["emb_dim"]))
|
| 146 |
self.pos_drop = nn.Dropout(p=cfg["drop"])
|
| 147 |
|
| 148 |
-
#
|
| 149 |
-
dpr =
|
| 150 |
self.blocks = nn.ModuleList([
|
| 151 |
-
Block(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
for i in range(cfg["depth"])
|
| 153 |
])
|
| 154 |
self.norm = nn.LayerNorm(cfg["emb_dim"])
|
|
@@ -174,9 +192,9 @@ class ViT(nn.Module):
|
|
| 174 |
|
| 175 |
def forward(self, x):
|
| 176 |
B = x.shape[0]
|
| 177 |
-
x = self.patch_embed(x)
|
| 178 |
cls_tokens = self.cls_token.expand(B, -1, -1)
|
| 179 |
-
x = torch.cat((cls_tokens, x), dim=1)
|
| 180 |
x = x + self.pos_embed
|
| 181 |
x = self.pos_drop(x)
|
| 182 |
|
|
@@ -190,7 +208,7 @@ class ViT(nn.Module):
|
|
| 190 |
|
| 191 |
|
| 192 |
# Load model weights
|
| 193 |
-
checkpoint = torch.load("
|
| 194 |
|
| 195 |
model = ViT(cfg).to(device)
|
| 196 |
|
|
|
|
| 13 |
"patch_size": 4,
|
| 14 |
"in_channels": 3,
|
| 15 |
"num_classes": 100,
|
| 16 |
+
"emb_dim": 192,
|
| 17 |
"num_heads": 6,
|
| 18 |
+
"depth": 6,
|
| 19 |
"mlp_ratio": 4.0,
|
| 20 |
"drop": 0.1
|
| 21 |
}
|
|
|
|
| 40 |
|
| 41 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 42 |
|
|
|
|
| 43 |
class ConvPatchEmbed(nn.Module):
    """Convolutional patch embedding for small images (e.g. CIFAR).

    A three-layer conv stem replaces the usual single large-kernel patch
    projection: 3x3 stride-1, then 3x3 stride-2 (halves H and W once),
    then 3x3 stride-1 into ``embed_dim`` channels.  A 32x32 input thus
    yields a 16x16 grid of patch tokens.

    Args:
        img_size: input spatial size (assumed square; the stride-2 stage
            halves it exactly, so it should be even).
        in_chans: number of input image channels.
        embed_dim: embedding dimension of each output patch token.
    """

    def __init__(self, img_size=32, in_chans=3, embed_dim=192):
        super().__init__()
        # 32x32 -> 32x32 -> 16x16 -> 16x16
        self.proj = nn.Sequential(
            nn.Conv2d(in_chans, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),  # 32 -> 16
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, embed_dim, kernel_size=3, stride=1, padding=1, bias=False),  # stays 16x16
            nn.BatchNorm2d(embed_dim),
            nn.ReLU(inplace=True),
        )

        # Token grid after the single stride-2 stage; callers use
        # num_patches to size the positional embedding.
        grid_size = (img_size // 2, img_size // 2)  # (16, 16) for img_size=32
        self.grid_size = grid_size
        self.num_patches = grid_size[0] * grid_size[1]

    def forward(self, x):
        """Map a (B, C, H, W) image batch to (B, num_patches, embed_dim) tokens."""
        x = self.proj(x)                  # (B, E, H/2, W/2)
        x = x.flatten(2).transpose(1, 2)  # (B, N = H/2 * W/2, E)
        return x
|
| 70 |
|
| 71 |
class MLP(nn.Module):
|
|
|
|
| 76 |
self.act = nn.GELU()
|
| 77 |
self.fc2 = nn.Linear(hidden_features, in_features)
|
| 78 |
self.drop = nn.Dropout(drop)
|
| 79 |
+
|
| 80 |
def forward(self, x):
|
| 81 |
x = self.fc1(x)
|
| 82 |
x = self.act(x)
|
|
|
|
| 86 |
return x
|
| 87 |
|
| 88 |
class Attention(nn.Module):
    """Multi-head self-attention over a sequence of token embeddings.

    Args:
        dim: total embedding dimension; must be divisible by ``num_heads``.
        num_heads: number of parallel attention heads.
        qkv_bias: whether the fused QKV projection carries a bias term.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over tokens: (B, N, C) in, (B, N, C) out."""
        batch, seq_len, channels = x.shape
        head_dim = channels // self.num_heads

        # One fused projection, then split into per-head Q, K, V:
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = (
            self.qkv(x)
            .reshape(batch, seq_len, 3, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        out = (weights @ v).transpose(1, 2).reshape(batch, seq_len, channels)
        return self.proj_drop(self.proj(out))
|
| 114 |
|
| 115 |
+
# Simple Stochastic Depth (drop-path)
class StochasticDepth(nn.Module):
    """Randomly zero whole residual branches per sample during training.

    With probability ``p`` a sample's branch output is dropped entirely;
    survivors are rescaled by 1/(1-p) so the expected value is unchanged.
    In eval mode, or when p == 0, this module is the identity.
    """

    def __init__(self, p):
        super().__init__()
        self.p = float(p)

    def forward(self, x):
        # Identity outside training or when dropping is disabled.
        if not self.training or self.p == 0.0:
            return x
        keep_prob = 1.0 - self.p
        # One Bernoulli draw per sample, broadcast across all other dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        survive = torch.rand(mask_shape, dtype=x.dtype, device=x.device) >= self.p
        return x * survive / keep_prob
|
| 129 |
+
|
| 130 |
class Block(nn.Module):
    """Pre-norm transformer encoder block: attention + MLP, each with a
    drop-path-regularized residual connection.

    Args:
        dim: token embedding dimension.
        num_heads: attention head count.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        drop: dropout used in the MLP and the attention output projection.
        attn_drop: dropout on the attention weights.
        drop_path: stochastic-depth probability for both residual branches.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop)
        # Identity when drop-path is disabled, so forward needs no branching.
        self.drop_path = nn.Identity() if drop_path <= 0. else StochasticDepth(drop_path)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * mlp_ratio), drop=drop)

    def forward(self, x):
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_out)
|
| 143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
class ViT(nn.Module):
|
| 145 |
def __init__(self, cfg):
|
| 146 |
super().__init__()
|
| 147 |
+
img_size = cfg["image_size"]
|
| 148 |
+
|
| 149 |
+
self.patch_embed = ConvPatchEmbed(
|
| 150 |
+
img_size=img_size,
|
| 151 |
+
in_chans=cfg["in_channels"],
|
| 152 |
+
embed_dim=cfg["emb_dim"]
|
| 153 |
+
)
|
| 154 |
+
n_patches = self.patch_embed.num_patches
|
| 155 |
|
| 156 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, cfg["emb_dim"]))
|
| 157 |
self.pos_embed = nn.Parameter(torch.zeros(1, 1 + n_patches, cfg["emb_dim"]))
|
| 158 |
self.pos_drop = nn.Dropout(p=cfg["drop"])
|
| 159 |
|
| 160 |
+
# stochastic depth decay rule
|
| 161 |
+
dpr = torch.linspace(0, cfg["drop_path"], cfg["depth"]).tolist()
|
| 162 |
self.blocks = nn.ModuleList([
|
| 163 |
+
Block(
|
| 164 |
+
dim=cfg["emb_dim"],
|
| 165 |
+
num_heads=cfg["num_heads"],
|
| 166 |
+
mlp_ratio=cfg["mlp_ratio"],
|
| 167 |
+
drop=cfg["drop"],
|
| 168 |
+
drop_path=dpr[i]
|
| 169 |
+
)
|
| 170 |
for i in range(cfg["depth"])
|
| 171 |
])
|
| 172 |
self.norm = nn.LayerNorm(cfg["emb_dim"])
|
|
|
|
| 192 |
|
| 193 |
def forward(self, x):
|
| 194 |
B = x.shape[0]
|
| 195 |
+
x = self.patch_embed(x) # (B, N, E)
|
| 196 |
cls_tokens = self.cls_token.expand(B, -1, -1)
|
| 197 |
+
x = torch.cat((cls_tokens, x), dim=1) # (B, 1+N, E)
|
| 198 |
x = x + self.pos_embed
|
| 199 |
x = self.pos_drop(x)
|
| 200 |
|
|
|
|
| 208 |
|
| 209 |
|
| 210 |
# Load model weights
|
| 211 |
+
checkpoint = torch.load("Revised_best_ViT_CIFAR100_baseline_checkpoint.pth", map_location=device)
|
| 212 |
|
| 213 |
model = ViT(cfg).to(device)
|
| 214 |
|