Munazz commited on
Commit
e94e577
·
1 Parent(s): ccc3c0d

Move files to Clothes-Category-Classifier

Browse files
Files changed (6) hide show
  1. .gitignore +31 -0
  2. main.py +100 -0
  3. requirements.txt +6 -0
  4. src/datasets.py +91 -0
  5. src/model_jigsaw.py +235 -0
  6. src/utils.py +267 -0
.gitignore ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ignore Python bytecode files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # Ignore virtual environments
7
+ env/
8
+ venv/
9
+ myenv/
10
+
11
+ # Ignore Jupyter notebook checkpoints
12
+ .ipynb_checkpoints
13
+
14
+ # Ignore model files that are tracked by Git LFS
15
+ netBest.pth
16
+
17
+ # Ignore test images (if you don't want to include them in your repo)
18
+ test_images/
19
+
20
+ # Ignore log files and temporary files
21
+ *.log
22
+ *.tmp
23
+ *.bak
24
+
25
+ # Ignore system files
26
+ .DS_Store
27
+ Thumbs.db
28
+
29
+ # Ignore any other temporary or backup files
30
+ *.swp
31
+ *.swo
main.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import os
from PIL import Image
import torch
import torchvision
from torchvision import transforms, datasets  # transforms used below; datasets kept for parity with torchvision.datasets usage
import sys


# Add the 'src' folder to the system path for module imports
sys.path.append('./src')

# Import from the 'src' folder (made importable by the sys.path line above)
from model_jigsaw import mae_vit_small_patch16


# Static Variables
MODEL_PATH = "model/netBest.pth"      # checkpoint loaded below; stores weights under the 'net' key
TEST_IMAGE_FOLDER = "test_images/"    # folder of example images, one subfolder per class id

# Class Mapping from Clothing1M dataset: class index -> human-readable name
class_names = {
    0: "T-shirt",
    1: "Shirt",
    2: "Knitwear",
    3: "Chiffon",
    4: "Sweater",
    5: "Hoodie",
    6: "Windbreaker",
    7: "Jacket",
    8: "Down Coat",
    9: "Suit",
    10: "Shawl",
    11: "Dress",
    12: "Vest",
    13: "Nightwear"
}
39
# Image Preprocessing
def preprocess_image(image):
    """Convert a PIL image into a normalized (1, 3, 224, 224) float tensor."""
    pipeline = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # ImageNet channel statistics, matching the model's training setup.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    rgb = image.convert('RGB')  # guard against grayscale / RGBA inputs
    tensor = pipeline(rgb)
    return tensor.unsqueeze(0)  # add the batch dimension
50
+
51
# Load Model
model = mae_vit_small_patch16(nb_cls=14)  # 14 = number of Clothing1M categories
# Checkpoint stores the state dict under the 'net' key; force CPU so the demo
# works on machines without CUDA.
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device("cpu"))['net'])
model.eval()

# Load dataset mapping
# ImageFolder assigns indices to subfolder names in sorted order; invert
# class_to_idx so a predicted index maps back to the folder name (which is
# assumed to be a numeric class-id string -- see predict_single_image).
val_dataset = torchvision.datasets.ImageFolder(root=TEST_IMAGE_FOLDER)
idx_to_class = {v: k for k, v in val_dataset.class_to_idx.items()}
59
+
60
def predict_single_image(image):
    """Predicts the class of an image using the exact logic from eval.py"""
    batch = preprocess_image(image)

    with torch.no_grad():
        logits = model.forward_cls(batch)

    predicted_class = torch.max(logits, 1)[1]

    # Map the raw model index through the ImageFolder mapping, then look up
    # the human-readable category name.
    folder_label = idx_to_class[predicted_class.item()]
    return class_names[int(folder_label)]
74
+
75
# Get all images from subfolders (recursively)
def get_image_paths(test_image_folder):
    """Recursively collect image file paths under *test_image_folder*.

    Extensions are matched case-insensitively (.jpg/.jpeg/.png) -- the
    original check was case-sensitive and silently skipped files such as
    "IMG.JPG" or "photo.jpeg".  The result is sorted so the Gradio examples
    gallery has a deterministic order across runs and filesystems.
    """
    image_paths = []
    for root, dirs, files in os.walk(test_image_folder):
        for file in files:
            if file.lower().endswith((".jpg", ".jpeg", ".png")):
                image_paths.append(os.path.join(root, file))
    return sorted(image_paths)
83
+
84
# Collect the example image paths once at startup for the Gradio gallery.
test_image_files = get_image_paths(TEST_IMAGE_FOLDER)
85
+
86
def load_test_image(selected_image):
    """Open the selected test image (a path under test_images/) as a PIL image."""
    img = Image.open(selected_image)
    return img
89
+
90
# Create Gradio Interface with dynamic examples from test_images folder.
# fn receives the uploaded PIL image and returns the predicted category name.
demo = gr.Interface(
    fn=predict_single_image,
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=gr.Textbox(label="Predicted Category"),
    title="Clothes Category Classifier",
    description="Upload an image to classify its clothing category.",
    examples=test_image_files  # Use the correct paths for example images
)

# Launch the web app (blocks until the server is stopped).
demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ Pillow
5
+ einops
6
+ timm
src/datasets.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # --------------------------------------------------------
10
+
11
+ import os
12
+ import PIL
13
+
14
+ from torchvision import datasets, transforms
15
+
16
+ from timm.data import create_transform
17
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
18
+
19
+
20
+ def build_dataset(is_train, args) :
21
+ transform = build_transform(is_train, args)
22
+
23
+ if args.data_set == 'CIFAR10' :
24
+ root = os.path.join(args.data_path, 'train' if is_train else 'val')
25
+ nb_cls = 10
26
+ elif args.data_set == 'CIFAR100' :
27
+ root = os.path.join(args.data_path, 'train' if is_train else 'val')
28
+ nb_cls = 100
29
+ elif args.data_set == 'Animal10N' :
30
+ root = os.path.join(args.data_path, 'train' if is_train else 'test')
31
+ nb_cls = 10
32
+ elif args.data_set == 'Clothing1M' :
33
+ # we use a randomly selected balanced training subset
34
+ root = os.path.join(args.data_path, 'noisy_rand_subtrain' if is_train else 'clean_val')
35
+ nb_cls = 14
36
+ elif args.data_set == 'Food101N' :
37
+ root = os.path.join(args.data_path, 'train' if is_train else 'test')
38
+ nb_cls = 101
39
+
40
+ dataset = datasets.ImageFolder(root, transform=transform)
41
+
42
+ print(dataset)
43
+
44
+ return dataset, nb_cls
45
+
46
+
47
def build_transform(is_train, args) :
    """Build the train/eval torchvision transform for the selected dataset.

    Training uses timm's ``create_transform`` (augmentation, color jitter,
    random erasing as configured on ``args``); evaluation uses the standard
    resize + center-crop + normalize pipeline.
    NOTE(review): assumes ``args`` carries input_size, color_jitter, aa,
    reprob, remode and recount -- confirm against the training script.
    """
    # CIFAR has its own channel statistics; every other dataset uses ImageNet's.
    if args.data_set == 'CIFAR10' or args.data_set == 'CIFAR100' :
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    else :
        mean = IMAGENET_DEFAULT_MEAN
        std = IMAGENET_DEFAULT_STD

    # Images of 32px or smaller (CIFAR) are not resized during training.
    resize_im = args.input_size > 32
    if is_train :
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        if not resize_im :
            # replace RandomResizedCropAndInterpolation with
            # RandomCrop
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    # eval transform: resize so the center crop keeps the usual 224/256 ratio
    t = []
    if args.input_size <= 224 :
        crop_pct = 224 / 256
    else :
        crop_pct = 1.0
    size = int(args.input_size / crop_pct)
    t.append(
        transforms.Resize(size, interpolation=PIL.Image.BICUBIC),  # to maintain same ratio w.r.t. 224 images
    )
    t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)
src/model_jigsaw.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ # DeiT: https://github.com/facebookresearch/deit
10
+ # --------------------------------------------------------
11
+
12
+ from functools import partial
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import einops
17
+
18
+ from timm.models.vision_transformer import PatchEmbed, Block
19
+
20
+ import utils
21
+
22
class MaskedAutoencoderViT(nn.Module):
    """Masked Autoencoder with a VisionTransformer backbone.

    Encoder-only variant with two output heads:
      * ``head``: linear classifier over the CLS token (``forward_cls``).
      * ``jigsaw``: per-token MLP that predicts each kept patch token's
        original patch index (``forward_jigsaw``) -- the jigsaw pretext task.
    """
    def __init__(self,
                 nb_cls=10,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 embed_dim=1024,
                 depth=24,
                 num_heads=16,
                 mlp_ratio=4.,
                 norm_layer=nn.LayerNorm):
        # nb_cls: number of classification categories (output size of ``head``);
        # embed_dim / depth / num_heads are the usual ViT capacity knobs.
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        self.num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # Classification head applied to the CLS token.
        self.head = torch.nn.Linear(embed_dim, nb_cls)
        # Jigsaw head: 3-layer MLP mapping each patch token to position logits.
        self.jigsaw = torch.nn.Sequential(*[torch.nn.Linear(embed_dim, embed_dim),
                                            torch.nn.ReLU(),
                                            torch.nn.Linear(embed_dim, embed_dim),
                                            torch.nn.ReLU(),
                                            torch.nn.Linear(embed_dim, self.num_patches)])
        # Canonical patch ordering 0..num_patches-1; random_masking derives its
        # targets from shuffle indices instead (see the commented lines there).
        self.target = torch.arange(self.num_patches)


        self.initialize_weights()

    def initialize_weights(self):
        """One-time weight init: frozen sin-cos pos-embed, xavier linears."""
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = utils.get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))


        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module init hook: xavier for Linear, ones/zeros for LayerNorm."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        # requires square images whose side is a multiple of the patch size
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence

        Returns (x_masked, target_masked) where target_masked holds the
        original patch index of each kept token -- the jigsaw labels.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        # target = einops.repeat(self.target, 'L -> N L', N=N)
        # target = target.to(x.device)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]  # N, len_keep
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        target_masked = ids_keep

        return x_masked, target_masked

    def forward_jigsaw(self, x, mask_ratio):
        """Jigsaw branch: predict the original position of each kept patch.

        Note: no positional embedding is added, so position must be inferred
        from patch content -- that is the pretext task.
        Returns ([N*len_keep, num_patches] logits, flat position targets).
        """
        # embed patches
        x = self.patch_embed(x)

        # masking: length -> length * mask_ratio
        x, target = self.random_masking(x, mask_ratio)

        # append cls token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        x = self.jigsaw(x[:, 1:])  # drop the CLS token: one prediction per patch
        return x.reshape(-1, self.num_patches), target.reshape(-1)

    def forward_cls(self, x) :
        """Classification branch: standard ViT forward, logits from CLS token."""
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        x = self.head(x[:, 0])
        return x

    def forward(self, x_jigsaw, x_cls, mask_ratio) :
        """Run both branches; returns (pred_jigsaw, targets_jigsaw, pred_cls)."""
        pred_jigsaw, targets_jigsaw = self.forward_jigsaw(x_jigsaw, mask_ratio)
        pred_cls = self.forward_cls(x_cls)
        return pred_jigsaw, targets_jigsaw, pred_cls
179
+
180
+
181
def mae_vit_small_patch16(nb_cls, **kwargs):
    """Factory for the ViT-Small/16 variant (embed 384, depth 12, 6 heads)."""
    return MaskedAutoencoderViT(
        nb_cls,
        img_size=224, patch_size=16,
        embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
+
193
+
194
def mae_vit_base_patch16(nb_cls, **kwargs):
    """Factory for the ViT-Base/16 variant (embed 768, depth 12, 12 heads)."""
    return MaskedAutoencoderViT(
        nb_cls,
        img_size=224, patch_size=16,
        embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
205
+
206
+
207
def mae_vit_large_patch16(nb_cls, **kwargs):
    """Factory for the ViT-Large/16 variant (embed 1024, depth 24, 16 heads)."""
    return MaskedAutoencoderViT(
        nb_cls,
        img_size=224, patch_size=16,
        embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
218
+
219
+
220
def create_model(arch, nb_cls) :
    """Instantiate a jigsaw MAE ViT by architecture name.

    Args:
        arch: one of 'vit_small_patch16', 'vit_base_patch16', 'vit_large_patch16'.
        nb_cls: number of classification categories.

    Raises:
        ValueError: on an unknown ``arch``.  The original silently returned
        None, deferring the failure to the first attribute access.
    """
    if arch == 'vit_small_patch16' :
        return mae_vit_small_patch16(nb_cls)
    elif arch == 'vit_base_patch16' :
        return mae_vit_base_patch16(nb_cls)
    elif arch == 'vit_large_patch16' :
        return mae_vit_large_patch16(nb_cls)
    raise ValueError('Unknown architecture: {}'.format(arch))
227
+
228
if __name__ == '__main__':
    # Smoke test: one jigsaw forward pass entirely on CPU.
    net = create_model(arch = 'vit_small_patch16', nb_cls = 10)
    net = net.cpu()
    # BUG FIX: the input must live on the same device as the model.  The
    # original allocated an *uninitialized CUDA* tensor for a CPU model,
    # which crashes on CUDA-less machines and mismatches devices otherwise.
    img = torch.zeros(6, 3, 224, 224)
    mask_ratio = 0.75
    with torch.no_grad():
        x, target = net.forward_jigsaw(img, mask_ratio)
src/utils.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''Some helper functions for PyTorch, including:
2
+ - get_mean_and_std: calculate the mean and std value of dataset.
3
+ - msr_init: net parameter initialization.
4
+ - progress_bar: progress bar mimic xlua.progress.
5
+ '''
6
+ import os
7
+ import sys
8
+ import time
9
+ import math
10
+
11
+ import logging
12
+ from datetime import datetime
13
+ import torch
14
+ import numpy as np
15
+ from torch.nn import Parameter
16
+
17
+
18
def get_logger(out_dir):
    """Create (or return) the shared 'Exp' logger writing to stdout and a file.

    If *out_dir* is a directory, a timestamped ``run_<ts>.log`` is created
    inside it; otherwise *out_dir* (minus any '.pth.tar' suffix) is used as
    the log file path directly.

    BUG FIX: ``logging.getLogger`` returns the same logger on every call, so
    configuring it repeatedly used to stack duplicate handlers and duplicate
    every log line.  Now an already-configured logger is returned as-is.
    """
    logger = logging.getLogger('Exp')
    logger.setLevel(logging.INFO)
    if logger.handlers:
        return logger  # already configured by an earlier call
    formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")

    # Filesystem-safe timestamp, e.g. 2024_01_31_12_00_00
    ts = str(datetime.now()).split(".")[0].replace(" ", "_")
    ts = ts.replace(":", "_").replace("-", "_")
    file_path = os.path.join(out_dir, "run_{}.log".format(ts)) if os.path.isdir(out_dir) else out_dir.replace('.pth.tar', '')
    file_hdlr = logging.FileHandler(file_path)
    file_hdlr.setFormatter(formatter)

    strm_hdlr = logging.StreamHandler(sys.stdout)
    strm_hdlr.setFormatter(formatter)

    logger.addHandler(file_hdlr)
    logger.addHandler(strm_hdlr)
    return logger
35
+
36
+
37
class AverageMeter(object):
    """Tracks the most recent value plus a running sum/count/average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Record *val* observed *n* times and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
53
+
54
+
55
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: [batch, num_classes] logits.
        target: [batch] integer class labels.
        topk: tuple of k values to evaluate.

    Returns:
        List of 1-element tensors, one percentage per k in *topk*.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size()[0]

        # [batch, maxk] top-class indices, transposed to [maxk, batch].
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # BUG FIX: use reshape(-1) instead of view(-1) -- the slice of
            # the transposed comparison tensor is not guaranteed contiguous,
            # and view() raises a RuntimeError in that case.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
70
+
71
+
72
def update_lr(iteration, warmup_iter, total_iter, max_lr, min_lr) :
    """Linear warmup to *max_lr*, then half-cycle cosine decay to *min_lr*."""
    if iteration < warmup_iter:
        # warmup phase: scale linearly from 0 up to max_lr
        return max_lr * iteration / warmup_iter
    # cosine phase: progress goes 0 -> 1 over the post-warmup iterations
    progress = (iteration - warmup_iter) / (total_iter - warmup_iter)
    return min_lr + (max_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * progress))
79
+
80
+
81
def adjust_learning_rate(optimizer, iteration, warmup_iter, total_iter, max_lr, min_lr):
    """Decay the learning rate with half-cycle cosine after warmup.

    Writes the new rate into every param group, scaled by the group's
    'lr_scale' entry when present (layer-wise LR decay), and returns it.
    """
    current_lr = update_lr(iteration, warmup_iter, total_iter, max_lr, min_lr)
    for group in optimizer.param_groups:
        group["lr"] = current_lr * group["lr_scale"] if "lr_scale" in group else current_lr
    return current_lr
90
+
91
+
92
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    axis = np.arange(grid_size, dtype=np.float32)
    # meshgrid with w first, matching the reference MAE implementation
    mesh = np.stack(np.meshgrid(axis, axis), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if cls_token:
        # prepend an all-zero embedding row for the class token
        zero_row = np.zeros([1, embed_dim])
        pos_embed = np.concatenate([zero_row, pos_embed], axis=0)
    return pos_embed
108
+
109
+
110
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Combine per-axis 1-D sin-cos embeddings into a 2-D embedding (H*W, D)."""
    assert embed_dim % 2 == 0

    # half of the channels encode the h coordinate, the other half the w one
    half = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)

    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
119
+
120
+
121
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # frequency ladder: 1 / 10000^(2i/D) for i in [0, D/2)
    freqs = np.arange(embed_dim // 2, dtype=float)
    freqs = 1. / 10000 ** (freqs / (embed_dim / 2.))

    angles = np.einsum('m,d->md', pos.reshape(-1), freqs)  # (M, D/2) outer product

    # first half sine, second half cosine -> (M, D)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
140
+
141
# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    """Resize a checkpoint's 'pos_embed' in place to match *model*'s grid.

    Used when fine-tuning at a different input resolution than pre-training:
    the patch-position tokens are bicubically interpolated from the old
    (orig_size x orig_size) grid to the new one, while extra tokens (CLS,
    etc.) are kept unchanged.  Mutates ``checkpoint_model`` directly.
    """
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed
168
+
169
+
170
def load_my_state_dict(net, state_dict):
    """Copy matching entries of *state_dict* into *net*, in place.

    Strips a leading 'module.' prefix (DataParallel checkpoints) and silently
    skips keys the model does not have.
    """
    own_state = net.state_dict()
    for raw_name, value in state_dict.items():
        key = raw_name.replace('module.', '')
        if key not in own_state:
            continue  # checkpoint has extra entries; ignore them
        if isinstance(value, Parameter):
            # backwards compatibility: unwrap serialized Parameter objects
            value = value.data
        own_state[key].copy_(value)
180
+
181
# Fix for non-interactive environments.
# 'stty size' prints "<rows> <cols>" on a real terminal; in non-interactive
# shells it prints nothing, so the empty-string split fails to unpack and
# raises ValueError, which we catch to fall back to a fixed width.
try:
    _, term_width = os.popen('stty size', 'r').read().split()
    term_width = int(term_width)
except ValueError:
    term_width = 80  # Set a default value if the stty command fails

# Shared state for progress_bar below.
TOTAL_BAR_LENGTH = 65.   # characters used for the bar itself
last_time = time.time()  # timestamp of the previous progress_bar call
begin_time = last_time   # timestamp when the current bar started
191
+
192
+
193
def progress_bar(current, total, msg=None):
    """Render an in-place terminal progress bar (mimics xlua.progress).

    current: zero-based index of the step just completed.
    total: total number of steps.
    msg: optional extra text appended after the timing info.
    Uses the module-level last_time/begin_time globals for step/total timing.
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    # draw " [====>....]"
    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append(' Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    # pad to the terminal width so leftovers from a longer line are erased
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

    # carriage return keeps redrawing in place; newline ends the bar
    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()
235
+
236
+
237
def format_time(seconds):
    """Render a duration in seconds as a compact string of at most two units.

    Examples: 0.5 -> '500ms', 61 -> '1m1s', 3661 -> '1h1m', 90061 -> '1D1h'.
    """
    remaining = seconds
    days = int(remaining / 3600 / 24)
    remaining = remaining - days * 3600 * 24
    hours = int(remaining / 3600)
    remaining = remaining - hours * 3600
    minutes = int(remaining / 60)
    remaining = remaining - minutes * 60
    whole_seconds = int(remaining)
    millis = int((remaining - whole_seconds) * 1000)

    # Keep only the two most significant non-zero components.
    parts = []
    for amount, suffix in ((days, 'D'), (hours, 'h'), (minutes, 'm'),
                           (whole_seconds, 's'), (millis, 'ms')):
        if amount > 0:
            parts.append(str(amount) + suffix)
        if len(parts) == 2:
            break
    return ''.join(parts) if parts else '0ms'