BBracke committed on
Commit
ce0919b
·
1 Parent(s): 7bf71f5

Upload 13 files

Browse files
classDist_HMP_missedRemoved.p ADDED
Binary file (28.8 kB). View file
 
code_class_mapping_obid.csv ADDED
The diff for this file is too large to render. See raw diff
 
exp1/convnext2b_exp1_baselineFE.py ADDED
@@ -0,0 +1,679 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, time, pickle, shutil
2
+ import pandas as pd
3
+ import numpy as np
4
+
5
+ from PIL import Image, ImageFile
6
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from torch.cuda.amp import GradScaler
12
+ from torch import autocast
13
+
14
+ import torchvision.transforms as transforms
15
+
16
+ import timm
17
+ from timm.models import create_model
18
+ from timm.utils import ModelEmaV2
19
+
20
+ from timm.optim import create_optimizer_v2
21
+
22
+
23
+ from torchmetrics import MeanMetric
24
+ from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score
25
+ from torchmetrics import MetricCollection
26
+
27
+ import wandb
28
+
29
+ import matplotlib.pyplot as plt
30
+
31
+
32
# ### parameters
################## Settings #############################
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
torch.backends.cudnn.benchmark = True

################## Data Paths ##########################
MODEL_DIR = "./convnext2b_baselineFE_iNet21k/"

# Create the run directory (no-op if it already exists) and keep a copy of this script next to
# the checkpoints.
# NOTE(review): the original indentation was lost; the copy is assumed to run on every start,
# not only when the directory is first created -- confirm.
os.makedirs(MODEL_DIR, exist_ok=True)
shutil.copyfile('./convnext2b_exp1_baselineFE.py', f'{MODEL_DIR}convnext2b_exp1_baselineFE.py')

TRAIN_DATA_DIR = "/SnakeCLEF2023-large_size/"  # train imgs. path
ADD_TRAIN_DATA_DIR = "/HMP/"  # add. train imgs. path
VAL_DATA_DIR = "/SnakeCLEF2023-large_size/"  # val imgs. path

TRAINDATA_CONFIG = "/SnakeCLEF2023-TrainMetadata-iNat.csv"
ADD_TRAINDATA_CONFIG = "/SnakeCLEF2023-TrainMetadata-HM.csv"
VALIDDATA_CONFIG = "/SnakeCLEF2023-ValMetadata.csv"

MISSING_FILES = "../missing_train_data.csv"  # csv with missing img. files that will be filtered out

CCM = "../code_class_mapping_obid.csv"  # csv mapping metadata codes to the snake species distribution


NUM_CLASSES = 1784

################## Hyperparameters ########################
NUM_EPOCHS = 30
WARMUP_EPOCHS = 5  # num. epochs only training classification head of model
RESUME_EPOCH = 0  # epoch to resume from model, optimizer checkpoints


LEARNING_RATE = {
    'cnn': 1e-05,
    'classifier': 1e-04,
}

BATCH_SIZE = {
    'train': 128,
    'valid': 96,
    'grad_acc': 1,  # gradient acc. steps of 'train'-sized batches, global batch size = 'grad_acc' * 'train'
}

BATCH_SIZE_AFTER_WARMUP = {
    'train': 64,
    'valid': 96,
    'grad_acc': 2,  # gradient acc. steps of 'train'-sized batches, global batch size = 'grad_acc' * 'train'
}

TRANSFORMS = {
    'IMAGE_SIZE_TRAIN': 384,
    'IMAGE_SIZE_VAL': 384,
    'RandAug': {
        'm': 7,
        'n': 2
    }
}


############# Checkpoints ####################
CHECKPOINTS = {
    # main difference between the runs of experiment 1; iNaturalist pre-trained model checkpoints
    # are available at "https://huggingface.co/BBracke/convnextv2_base.inat21_384"
    'fe_cnn': None,
    'model': None,
    'optimizer': None,
    'scaler': None,
}


################### WandB ##################
WANDB = False

if WANDB:
    wandb.init(
        entity="snakeclef2023",  # our team at wandb

        # set the wandb project where this run will be logged
        project="exp1",  # -> define sub-projects here, e.g. experiments with MetaFormer or CNNs...

        # define a name for this run
        name="iNet21k",

        # track all the used hyperparameters here, config is just a dict object so any key:value pairs are possible
        config={
            "learning_rate": LEARNING_RATE,
            "architecture": "convnextv2_base.fcmae_ft_in22k_in1k_384",
            "pretrained": "iNet21",
            "dataset": f"snakeclef2023, additional train data: {bool(ADD_TRAINDATA_CONFIG)}",
            "epochs": NUM_EPOCHS,
            "transforms": TRANSFORMS,
            "checkpoints": CHECKPOINTS,
            "model_dir": MODEL_DIR
            # ... any other hyperparameter that is necessary to reproduce the result
        },
        save_code=True,  # save the script file as backup
        dir=MODEL_DIR  # local folder where wandb log files are saved
    )
129
+
130
+
131
+
132
+
133
+ ##################### Dataset & AugTransforms #####################################
134
+ # ### dataset & loaders
135
class SnakeTrainDataset(Dataset):
    """Dataset yielding ``(image, label, code-class-mapping vector)`` triples.

    Args:
        data: table with one row per image; ``__getitem__`` reads the ``image_path``,
            ``class_id`` and ``code`` fields of a row.
        ccm: code-class-mapping table indexed by metadata code (``ccm[code]`` must return
            an object with ``to_numpy()``, e.g. a DataFrame column).
            NOTE(review): codes without an entry fall back to ``"unknown"``, so the table
            is assumed to contain an ``"unknown"`` entry -- confirm against the CSV.
        transform: optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, data, ccm, transform=None):
        self.data = data
        self.transform = transform  # image augmentation / preprocessing pipeline
        self.code_class_mapping = ccm

    def __len__(self):
        return self.data.shape[0]

    def _resolve_code(self, code):
        """Return ``code`` if the mapping has an entry for it, otherwise ``"unknown"``."""
        # The original looked the code up in a never-defined ``self.code_tokens`` attribute,
        # which raised AttributeError on the first item.
        return code if code in self.code_class_mapping.keys() else "unknown"

    def __getitem__(self, index):
        obj = self.data.iloc[index]  # get instance
        label = obj.class_id  # get label
        code = self._resolve_code(obj.code)

        img = Image.open(obj.image_path).convert("RGB")  # load image
        ccm = torch.tensor(self.code_class_mapping[code].to_numpy())  # code class mapping

        # img. augmentation (skipped when no transform was supplied)
        if self.transform is not None:
            img = self.transform(img)

        return (img, label, ccm)
156
+
157
+
158
+ # valid data preprocessing pipeline
159
def _five_crops_to_tensor(crops):
    """Convert the PIL crops produced by ``FiveCrop`` into one stacked 4D tensor."""
    to_tensor = transforms.ToTensor()
    return torch.stack([to_tensor(crop) for crop in crops])


def get_val_preprocessing(img_size):
    """Build the validation pipeline: resize, five-crop, tensor conversion, normalization.

    Args:
        img_size: edge length of each square crop; the image is first resized to
            ``int(img_size * 1.25)``.

    Returns:
        A ``transforms.Compose`` whose output stacks the five crops of one image along a
        leading dimension.
    """
    print(f'IMG_SIZE_VAL: {img_size}')
    return transforms.Compose([
        transforms.Resize(int(img_size * 1.25)),  # expand image before five-cropping
        transforms.FiveCrop((img_size, img_size)),  # this is a tuple of PIL Images
        # A named module-level function (instead of an inline lambda) keeps the pipeline
        # picklable for DataLoader workers started with the "spawn" method.
        transforms.Lambda(_five_crops_to_tensor),  # returns a 4D tensor
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
169
+
170
class IdentityTransform:
    """No-op transform: returns its input unchanged (placeholder when RandAugment is off)."""

    def __call__(self, x):
        return x
173
+
174
+
175
+ # train data augmentation/ preprocessing pipeline
176
def get_train_augmentation_preprocessing(img_size, rand_aug=False):
    """Build the training augmentation / preprocessing pipeline.

    Args:
        img_size: edge length of the square random crop; the image is first resized to
            ``int(img_size * 1.25)``.
        rand_aug: when true, insert ``RandAugment`` configured from ``TRANSFORMS['RandAug']``;
            otherwise an identity step is used in its place.

    Returns:
        A ``transforms.Compose`` producing a normalized image tensor.
    """
    print(f'IMG_SIZE_TRAIN: {img_size}, RandAug: {rand_aug}')

    if rand_aug:
        extra_augmentation = transforms.RandAugment(num_ops=TRANSFORMS['RandAug']['n'],
                                                    magnitude=TRANSFORMS['RandAug']['m'])
    else:
        extra_augmentation = IdentityTransform()

    return transforms.Compose([
        transforms.Resize(int(img_size * 1.25)),  # expand image before random crop
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomCrop((img_size, img_size)),  # random crop to img_size
        extra_augmentation,
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
187
+
188
+
189
def get_datasets(train_transfroms, val_transforms):
    """Load the metadata CSVs and build the train and validation datasets.

    Rows listed in ``MISSING_FILES`` are removed from the training table, image paths are
    prefixed with their data directories, the additional training table is appended when
    ``ADD_TRAINDATA_CONFIG`` is set, and both tables are shuffled with a fixed seed.

    Args:
        train_transfroms: transform pipeline for training images (parameter name kept
            as spelled, because callers pass it by keyword).
        val_transforms: transform pipeline for validation images.

    Returns:
        ``(train_dataset, valid_dataset)`` as ``SnakeTrainDataset`` instances.
    """
    # Explicit NA markers with keep_default_na=False, so that only these strings are parsed
    # as missing values.
    nan_values = ['', '#N/A', '#N/A N/A', '#NA', '-1.#IND', '-1.#QNAN', '-NaN', '-nan', '1.#IND', '1.#QNAN', '<NA>', 'N/A', 'NULL', 'NaN', 'n/a', 'nan', 'null']
    train_data = pd.read_csv(TRAINDATA_CONFIG, na_values=nan_values, keep_default_na=False)
    missing_train_data = pd.read_csv(MISSING_FILES, na_values=nan_values, keep_default_na=False)
    valid_data = pd.read_csv(VALIDDATA_CONFIG, na_values=nan_values, keep_default_na=False)

    # delete missing files from the train data table: an outer merge with indicator marks
    # rows present only in the train table as 'left_only'
    train_data = pd.merge(train_data, missing_train_data, how='outer', indicator=True)
    kept_columns = ["observation_id", "endemic", "binomial_name", "code", "image_path", "class_id", "subset"]
    train_data = train_data.loc[train_data._merge == 'left_only', kept_columns]

    # add image path
    train_data["image_path"] = TRAIN_DATA_DIR + train_data['image_path']
    valid_data["image_path"] = VAL_DATA_DIR + valid_data['image_path']

    # add additional data
    if ADD_TRAINDATA_CONFIG:
        add_train_data = pd.read_csv(ADD_TRAINDATA_CONFIG, na_values=nan_values, keep_default_na=False)
        add_train_data["image_path"] = ADD_TRAIN_DATA_DIR + add_train_data['image_path']
        train_data = pd.concat([train_data, add_train_data], axis=0)

    # to limit the data size for debugging, take e.g. train_data.head(1000) here
    print(f'train data shape: {train_data.shape}')

    # shuffle (fixed seed) and rebuild a contiguous index
    train_data = train_data.sample(frac=1, random_state=1).reset_index(drop=True)
    valid_data = valid_data.sample(frac=1, random_state=1).reset_index(drop=True)

    # load transposed version of CCM table
    ccm = pd.read_csv(CCM, na_values=nan_values, keep_default_na=False)

    # create datasets
    train_dataset = SnakeTrainDataset(train_data, ccm, transform=train_transfroms)
    valid_dataset = SnakeTrainDataset(valid_data, ccm, transform=val_transforms)

    return train_dataset, valid_dataset
227
+
228
+
229
def get_dataloaders(imgsize_train, imgsize_val, rand_aug):
    """Create the train and validation ``DataLoader`` objects.

    Batch sizes are read from the module-level ``BATCH_SIZE`` dict at call time, so the
    loaders must be rebuilt after ``BATCH_SIZE`` is changed.

    Args:
        imgsize_train: crop size for the training pipeline.
        imgsize_val: crop size for the validation pipeline.
        rand_aug: whether the training pipeline uses RandAugment.

    Returns:
        ``(train_loader, valid_loader)``.
    """
    # get train, valid augmentation & preprocessing pipelines
    train_aug_preprocessing = get_train_augmentation_preprocessing(imgsize_train, rand_aug)
    val_preprocessing = get_val_preprocessing(imgsize_val)

    # prepare the datasets
    train_dataset, valid_dataset = get_datasets(train_transfroms=train_aug_preprocessing,
                                                val_transforms=val_preprocessing)

    train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE['train'],
                              num_workers=6, drop_last=True, pin_memory=True)
    valid_loader = DataLoader(dataset=valid_dataset, shuffle=False, batch_size=BATCH_SIZE['valid'],
                              num_workers=6, drop_last=False, pin_memory=True)

    return train_loader, valid_loader
239
+
240
+
241
+ # #################### plot train history #########################
242
+
243
def plot_history(logs):
    """Plot train vs. validation loss, accuracy and F1 and save the figure as SVG.

    Args:
        logs: dict of per-epoch lists; reads the keys ``loss``/``val_loss``,
            ``acc``/``val_acc`` and ``f1``/``val_f1``.

    The figure is written to ``{MODEL_DIR}model_history.svg`` and then shown.
    """
    # (train key, valid key, y-label, title, y-limits); the loss axis is capped at
    # -log(1/NUM_CLASSES), the cross-entropy of a uniform prediction.
    panels = (
        ('loss', 'val_loss', "loss", "train- vs. valid loss", [0, -np.log(1/NUM_CLASSES)]),
        ('acc', 'val_acc', "accuracy", "train- vs. valid accuracy", [0, 1.01]),
        ('f1', 'val_f1', "f1", "train- vs. valid f1", [0, 1.01]),
    )

    fig, ax = plt.subplots(len(panels), 1, figsize=(8, 12))

    for axis, (train_key, valid_key, ylabel, title, ylim) in zip(ax, panels):
        axis.plot(logs[train_key], label="train data")
        axis.plot(logs[valid_key], label="valid data")
        axis.legend(loc="best")
        axis.set_ylabel(ylabel)
        axis.set_ylim(ylim)
        axis.set_title(title)

    # only the bottom panel carries the shared x-axis label
    ax[-1].set_xlabel("epochs")

    fig.savefig(f'{MODEL_DIR}model_history.svg', dpi=150, format="svg")
    plt.show()
    plt.close(fig)  # release the figure once it has been saved and shown
272
+
273
+
274
+ # #################### Model #####################################
275
+
276
class FeatureExtractor(nn.Module):
    """ConvNeXt-V2 base backbone created without a classification head (``num_classes=0``).

    If ``CHECKPOINTS['fe_cnn']`` is set, its state dict is loaded strictly into the backbone
    on top of the timm pretrained weights.
    """

    def __init__(self):
        super().__init__()
        self.conv_backbone = create_model('convnextv2_base.fcmae_ft_in22k_in1k_384', pretrained=True,
                                          num_classes=0, drop_path_rate=0.2)
        if CHECKPOINTS['fe_cnn']:
            self.conv_backbone.load_state_dict(torch.load(CHECKPOINTS['fe_cnn'], map_location='cpu'), strict=True)
            print(f"use FE_CHECKPOINTS: {CHECKPOINTS['fe_cnn']}")
        torch.cuda.empty_cache()

    def forward(self, img):
        """Return the backbone's output features for an image batch."""
        conv_features = self.conv_backbone(img)
        return conv_features
288
+
289
+
290
class Classifier(nn.Module):
    """Linear classification head with optional dropout in front of it.

    Args:
        num_classes: number of output logits.
        dim_embeddings: size of the incoming feature vector.
        dropout: dropout probability; ``None`` or ``0`` disables dropout (identity).
    """

    def __init__(self, num_classes: int, dim_embeddings: int, dropout: "float | None" = None):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout, inplace=False) if dropout else nn.Identity()
        self.classifier = nn.Linear(in_features=dim_embeddings, out_features=num_classes, bias=True)

    def forward(self, embeddings):
        """Apply dropout, then the linear layer, and return the logits."""
        dropped_feature = self.dropout(embeddings)
        outputs = self.classifier(dropped_feature)

        return outputs
301
+
302
+
303
class Model(nn.Module):
    """Full network: ``FeatureExtractor`` backbone followed by the ``Classifier`` head."""

    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        # dim_embeddings=1024 must match the backbone's output feature size
        self.classifier = Classifier(num_classes=NUM_CLASSES, dim_embeddings=1024, dropout=0.25)

    def forward(self, img):
        """Return class logits for an image batch."""
        img_features = self.feature_extractor(img)
        classifier_outputs = self.classifier(img_features)
        return classifier_outputs
313
+
314
+
315
def load_checkpoints(model=None, optimizer=None, scaler=None):
    """Load state dicts from the paths configured in ``CHECKPOINTS``.

    Each object is restored only if it was passed in and its ``CHECKPOINTS`` entry
    (``'model'``, ``'optimizer'``, ``'scaler'``) is set.
    """
    targets = (('model', model), ('optimizer', optimizer), ('scaler', scaler))
    for key, target in targets:
        if CHECKPOINTS[key] and target is not None:
            target.load_state_dict(torch.load(CHECKPOINTS[key], map_location='cpu'))
            print(f"use {key} checkpoints: {CHECKPOINTS[key]}")
    torch.cuda.empty_cache()
326
+
327
def resume_checkpoints(model=None, optimizer=None, scaler=None):
    """Restore the passed objects from the ``RESUME_EPOCH`` checkpoint files in ``MODEL_DIR``.

    File names follow the ones written by the training loop:
    ``model_epoch{N}.pth``, ``optimizer_epoch{N}.pth`` and ``mp_scaler_epoch{N}.pth``.
    """
    targets = (
        ('model', 'model', model),
        ('optimizer', 'optimizer', optimizer),
        ('scaler', 'mp_scaler', scaler),
    )
    for label, file_prefix, target in targets:
        if target is None:
            continue
        path = f'{MODEL_DIR}{file_prefix}_epoch{RESUME_EPOCH}.pth'
        target.load_state_dict(torch.load(path, map_location='cpu'))
        print(f"use {label} checkpoints: {path}")
    torch.cuda.empty_cache()
339
+
340
+
341
def resume_logs(logs):
    """Extend ``logs`` in place with the history stored in ``{MODEL_DIR}train_history.csv``.

    Only the rows up to and including ``RESUME_EPOCH`` are restored, so that the list index
    of every entry keeps matching its epoch number even if the CSV contains later epochs
    from a previous, longer run.

    Args:
        logs: dict mapping metric name to a list; every key must be a column of the CSV.
    """
    old_logs = pd.read_csv(f"{MODEL_DIR}train_history.csv").iloc[:RESUME_EPOCH + 1]
    for metric_name in logs:
        # tolist() yields plain Python scalars instead of numpy scalars
        logs[metric_name].extend(old_logs[metric_name].tolist())
345
+
346
+ ######################## Optimizer #####################################
347
def get_optm_group(module):
    """Split the parameters of ``module`` into weight-decay and no-weight-decay groups.

    Biases and the weights of normalization/embedding layers are exempt from decay; weights
    of linear/convolution layers (and of ``GlobalResponseNormMlp`` blocks) are decayed.

    Args:
        module: the ``nn.Module`` whose parameters are grouped.

    Returns:
        ``(param_dict, decay, no_decay)``: a dict from full parameter name to parameter, and
        two sets of full parameter names.

    Raises:
        AssertionError: if a parameter lands in both sets or in neither.
    """
    # separate out all parameters to those that will and won't experience regularizing weight decay
    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv1d, timm.layers.GlobalResponseNormMlp)
    blacklist_weight_modules = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.LayerNorm, torch.nn.Embedding)
    for mn, m in module.named_modules():
        # named_parameters() recurses, so a parameter is visited once per enclosing module;
        # joining module name and relative name always yields the same full name.
        for pn, p in m.named_parameters():
            fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

            if pn.endswith('bias'):
                # all biases will not be decayed
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                # weights of whitelist modules will be weight decayed
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                # weights of blacklist modules will NOT be weight decayed
                no_decay.add(fpn)

    # validate that we considered every parameter
    param_dict = {pn: p for pn, p in module.named_parameters()}
    inter_params = decay & no_decay
    union_params = decay | no_decay
    assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
    assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
        % (str(param_dict.keys() - union_params), )

    return param_dict, decay, no_decay
384
+
385
+
386
def get_warmup_optimizer(model):
    """Create the AdamW optimizer used during warm-up, covering only ``model.classifier``.

    Two parameter groups are built, both with ``LEARNING_RATE['classifier']``: decayed
    weights (weight decay 0.05) and non-decayed parameters (weight decay 0.0).
    """
    param_dict, decay, no_decay = get_optm_group(model.classifier)
    lr = LEARNING_RATE['classifier']

    # sorted names give a deterministic parameter order inside each group
    params_group = [
        {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": 0.05, 'lr': lr},
        {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0, 'lr': lr},
    ]

    optimizer = torch.optim.AdamW(params_group)
    return optimizer
395
+
396
+
397
def get_after_warmup_optimizer(model, old_opt):
    """Create the optimizer for the phase after warm-up.

    A new AdamW optimizer with layer-wise learning-rate decay is created for the backbone,
    then the parameter groups of ``old_opt`` are appended to it.

    Args:
        model: model exposing ``feature_extractor.conv_backbone``.
        old_opt: the optimizer whose ``param_groups`` are carried over.

    Returns:
        The new optimizer. Only the parameter groups are copied from ``old_opt``; its
        per-parameter state is not transferred by this function.
    """
    new_opt = create_optimizer_v2(model.feature_extractor.conv_backbone, opt='adamw', filter_bias_and_bn=True,
                                  weight_decay=1e-8, layer_decay=0.85, lr=LEARNING_RATE['cnn'])

    # add old param groups
    for group in old_opt.param_groups:
        new_opt.add_param_group(group)

    return new_opt
405
+
406
+
407
+ # #################### Model Warmup #####################################
408
+
409
def warmup_start(model):
    """Freeze all parameters of ``model.feature_extractor.conv_backbone`` for the warm-up phase."""
    for param in model.feature_extractor.conv_backbone.parameters():
        param.requires_grad = False
    print(f'--> freeze feature_extractor.conv_backbone during warmup phase')
414
+
415
def warmup_end(model):
    """Unfreeze all parameters of ``model.feature_extractor.conv_backbone`` after warm-up."""
    for param in model.feature_extractor.conv_backbone.parameters():
        param.requires_grad = True
    print(f'--> unfreeze feature_extractor.conv_backbone after warmup phase')
420
+
421
+
422
+ # #################### Train Loop #####################################
423
+
424
+ # ### train
425
def _collect_epoch_metrics(loss_metric, metrics, metric_ccm):
    """Compute the running metrics, reset them, and return the values as a dict of floats."""
    epoch_metrics = metrics.compute()
    results = {
        'loss': loss_metric.compute().cpu().item(),
        'acc': epoch_metrics['acc'].cpu().item(),
        'acc_top3': epoch_metrics['top3_acc'].cpu().item(),
        'f1': epoch_metrics['f1'].cpu().item(),
        'f1country': metric_ccm.compute().detach().cpu().item(),
    }
    loss_metric.reset()
    metrics.reset()
    metric_ccm.reset()
    return results


def _format_metrics(prefix, results):
    """Render a metrics dict as ``"<prefix>name: value, ..."`` with five decimals."""
    return ", ".join(f"{prefix}{name}: {value:.5f}" for name, value in results.items())


def _evaluate(net, loader, device, loss_fn, loss_metric, metrics, metric_ccm):
    """Run ``net`` over the five-crop validation loader and return the epoch metrics dict."""
    net.eval()
    with torch.no_grad():
        # iterate over validation batches
        for (inputs, labels, ccm) in loader:
            inputs = inputs.to(device, non_blocking=True)
            # fold the five crops of every image into the batch dimension
            inputs = inputs.view(-1, 3, TRANSFORMS['IMAGE_SIZE_VAL'], TRANSFORMS['IMAGE_SIZE_VAL'])
            labels = labels.to(device, non_blocking=True)
            ccm = ccm.to(device, non_blocking=True)

            # forward with mixed precision
            with autocast(device_type='cuda', dtype=torch.float16):
                outputs = net(inputs)
                # average the logits of the five crops belonging to one image
                outputs = outputs.view(-1, 5, NUM_CLASSES).mean(1)
                loss = loss_fn(outputs, labels)

            # Compute metrics
            loss_metric.update(loss.detach())

            preds = outputs.softmax(dim=-1).detach()
            metrics.update(preds, labels)
            metric_ccm.update(preds * ccm, labels)

    return _collect_epoch_metrics(loss_metric, metrics, metric_ccm)


def main():
    """Train the model (head-only warm-up, then full fine-tuning) and validate the EMA model.

    Per epoch: trains with mixed precision and gradient accumulation, validates, appends
    the metrics to ``train_history.csv`` and writes model / EMA / optimizer / scaler
    checkpoints to ``MODEL_DIR``. After the last epoch the EMA model is validated once, the
    result is written to ``ema_results.txt`` and the history is plotted.
    """
    global BATCH_SIZE  # replaced by BATCH_SIZE_AFTER_WARMUP once the warm-up phase is over

    device = torch.device(f'cuda:1')
    torch.cuda.set_device(device)

    start_epoch = RESUME_EPOCH+1 if RESUME_EPOCH > 0 else 0
    # When resuming from a checkpoint written after the warm-up toggle, the toggle inside the
    # epoch loop never fires again. The post-warm-up setup therefore has to be built up front;
    # otherwise loading the optimizer checkpoint fails (different number of parameter groups)
    # and the backbone would stay frozen.
    resumed_past_warmup = start_epoch > WARMUP_EPOCHS
    if resumed_past_warmup:
        BATCH_SIZE = BATCH_SIZE_AFTER_WARMUP

    # prepare the datasets
    train_loader, valid_loader = get_dataloaders(imgsize_train=TRANSFORMS['IMAGE_SIZE_TRAIN'],
                                                 imgsize_val=TRANSFORMS['IMAGE_SIZE_VAL'],
                                                 rand_aug=True)

    # instantiate the model
    model = Model().to(device)
    if RESUME_EPOCH > 0:
        resume_checkpoints(model=model)
    ema_model = ModelEmaV2(model, decay=0.9998, device=device)

    # Optimizer & mixed-precision scaler
    if resumed_past_warmup:
        # same group order as produced by the warm-up toggle: backbone groups, then head groups
        optimizer = get_after_warmup_optimizer(model, get_warmup_optimizer(model))
    else:
        warmup_start(model)
        optimizer = get_warmup_optimizer(model)
    scaler = GradScaler()
    if RESUME_EPOCH > 0:
        resume_checkpoints(optimizer=optimizer, scaler=scaler)

    loss_fn = nn.CrossEntropyLoss()
    loss_val_fn = nn.CrossEntropyLoss()

    # running metrics during training
    loss_metric = MeanMetric().to(device)
    metrics = MetricCollection(metrics={
        'acc': MulticlassAccuracy(num_classes=NUM_CLASSES, average='macro'),
        'top3_acc': MulticlassAccuracy(num_classes=NUM_CLASSES, average='macro', top_k=3),
        'f1': MulticlassF1Score(num_classes=NUM_CLASSES, average='macro')
    }).to(device)
    metric_ccm = MulticlassF1Score(num_classes=NUM_CLASSES, average='macro').to(device)

    # start time of training
    start_training = time.perf_counter()
    # create log dict
    logs = {'loss': [], 'acc': [], 'acc_top3': [], 'f1': [], 'f1country': [], 'val_loss': [], 'val_acc': [], 'val_acc_top3': [], 'val_f1': [], 'val_f1country': []}
    if RESUME_EPOCH > 0:
        resume_logs(logs)

    # iterate over epochs
    for epoch in range(start_epoch, NUM_EPOCHS):
        # start time of epoch
        epoch_start = time.perf_counter()
        print(f'Epoch {epoch+1}/{NUM_EPOCHS}')

        ######################## toggle warmup ########################################
        if epoch == WARMUP_EPOCHS:
            warmup_end(model)
            optimizer = get_after_warmup_optimizer(model, optimizer)
            BATCH_SIZE = BATCH_SIZE_AFTER_WARMUP
            # the loaders read BATCH_SIZE when they are built, so rebuild them
            train_loader, valid_loader = get_dataloaders(imgsize_train=TRANSFORMS['IMAGE_SIZE_TRAIN'],
                                                         imgsize_val=TRANSFORMS['IMAGE_SIZE_VAL'],
                                                         rand_aug=True)
        elif epoch < WARMUP_EPOCHS:
            print(f'--> Warm Up {epoch+1}/{WARMUP_EPOCHS}')

        ############################## train phase ####################################
        model.train()

        # zero the parameter gradients
        optimizer.zero_grad(set_to_none=True)

        # number of batches whose gradients are accumulated per optimizer step; a configured
        # value of 0 is treated as 1 (the modulo below would otherwise divide by zero)
        grad_acc_steps = max(1, BATCH_SIZE['grad_acc'])

        # iterate over training batches
        for batch_idx, (inputs, labels, ccm) in enumerate(train_loader):
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            ccm = ccm.to(device, non_blocking=True)

            # forward with mixed precision
            with autocast(device_type='cuda', dtype=torch.float16):
                outputs = model(inputs)
                batch_loss = loss_fn(outputs, labels)

            # loss backward, scaled down so the accumulated gradient is the mean over the steps
            loss = batch_loss / grad_acc_steps
            scaler.scale(loss).backward()

            # Compute metrics (on the undivided loss)
            loss_metric.update(batch_loss.detach())

            preds = outputs.softmax(dim=-1).detach()
            metrics.update(preds, labels)
            metric_ccm.update(preds * ccm, labels)

            ############################ grad acc ##############################
            if (batch_idx+1) % grad_acc_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                # zero the parameter gradients
                optimizer.zero_grad(set_to_none=True)
                # update ema model
                ema_model.update(model)

        # compute & reset train metrics, append them to the logs
        train_results = _collect_epoch_metrics(loss_metric, metrics, metric_ccm)
        for name, value in train_results.items():
            logs[name].append(value)

        print(_format_metrics('', train_results), end=' || ')

        # drop gradients of an incomplete accumulation window at the end of the epoch
        optimizer.zero_grad(set_to_none=True)

        # free the GPU tensors of the last training batch before validation
        del inputs, labels, ccm, preds, outputs, loss, batch_loss
        torch.cuda.empty_cache()

        ############################## valid phase ####################################
        val_results = _evaluate(model, valid_loader, device, loss_val_fn, loss_metric, metrics, metric_ccm)
        for name, value in val_results.items():
            logs[f'val_{name}'].append(value)

        print(_format_metrics('val_', val_results), end=' || ')

        torch.cuda.empty_cache()

        # save logs as csv
        logs_df = pd.DataFrame(logs)
        logs_df.to_csv(f'{MODEL_DIR}train_history.csv', index_label='epoch', sep=',', encoding='utf-8')

        if WANDB:
            # at the end of each epoch, log the latest value of every metric
            wandb.log(
                {k: v[-1] for k, v in logs.items()},
                step=epoch  # epoch index for wandb
            )

        # save trained model for each epoch
        torch.save(model.state_dict(), f'{MODEL_DIR}model_epoch{epoch}.pth')
        torch.save(ema_model.module.state_dict(), f'{MODEL_DIR}ema_model_epoch{epoch}.pth')
        torch.save(optimizer.state_dict(), f'{MODEL_DIR}optimizer_epoch{epoch}.pth')
        torch.save(scaler.state_dict(), f'{MODEL_DIR}mp_scaler_epoch{epoch}.pth')

        # end time of epoch
        epoch_end = time.perf_counter()
        print(f"epoch runtime: {epoch_end-epoch_start:5.3f} sec.")

        del logs_df
        torch.cuda.empty_cache()

    ################################## EMA Model Validation ################################
    del model
    torch.cuda.empty_cache()

    ema_results = _evaluate(ema_model.module, valid_loader, device, loss_val_fn, loss_metric, metrics, metric_ccm)
    ema_summary = _format_metrics('ema_', ema_results)

    print(ema_summary)

    with open(f'{MODEL_DIR}ema_results.txt', 'w') as f:
        print(ema_summary, file=f)

    plot_history(logs)
    # end time of training
    end_training = time.perf_counter()
    print(f'Training succeeded in {(end_training - start_training):5.3f}s')

    if WANDB:
        wandb.finish()
672
+
673
+
674
# run the training only when executed as a script, not on import
if __name__ == "__main__":
    main()
676
+
677
+
678
+
679
+
exp2/convnext2b_exp2_imgSizes_e10.py ADDED
@@ -0,0 +1,646 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, time, pickle, shutil
2
+ import pandas as pd
3
+ import numpy as np
4
+
5
+ from PIL import Image, ImageFile
6
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from torch.cuda.amp import GradScaler
12
+ from torch import autocast
13
+
14
+ import torchvision.transforms as transforms
15
+
16
+ import timm
17
+ from timm.models import create_model
18
+ from timm.utils import ModelEmaV2
19
+
20
+ from timm.optim import create_optimizer_v2
21
+
22
+
23
+ from torchmetrics import MeanMetric
24
+ from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score
25
+ from torchmetrics import MetricCollection
26
+
27
+
28
+ import wandb
29
+
30
+ import matplotlib.pyplot as plt
31
+
32
+
33
+ # ### parameters
34
+ ################## Settings #############################
35
+ #os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
36
+ torch.backends.cudnn.benchmark = True
37
+
38
+ ################## Data Paths ##########################
39
+ MODEL_DIR = "./convnext2b_imgSize_464px/"
40
+
41
+ if not os.path.exists(MODEL_DIR):
42
+ os.makedirs(MODEL_DIR)
43
+ shutil.copyfile('./convnext2b_exp2_imgSizes_e10.py', f'{MODEL_DIR}convnext2b_exp2_imgSizes_e10.py')
44
+
45
+ TRAIN_DATA_DIR = "/SnakeCLEF2023-large_size/" # train imgs. path
46
+ ADD_TRAIN_DATA_DIR = "/HMP/" # add. train imgs. path
47
+ VAL_DATA_DIR = "/SnakeCLEF2023-large_size/" # val imgs. path
48
+
49
+ TRAINDATA_CONFIG = "/SnakeCLEF2023-TrainMetadata-iNat.csv"
50
+ ADD_TRAINDATA_CONFIG = "/SnakeCLEF2023-TrainMetadata-HM.csv"
51
+ VALIDDATA_CONFIG = "/SnakeCLEF2023-ValMetadata.csv"
52
+
53
+ MISSING_FILES = "../missing_train_data.csv" # csv with missing img. files that will be filtered out
54
+
55
+ CCM = "../code_class_mapping_obid.csv" # csv to metadata code to snake species dist.
56
+
57
+
58
+ NUM_CLASSES = 1784
59
+
60
+ ################## Hyperparameters ########################
61
+ NUM_EPOCHS = 40
62
+ RESUME_EPOCH = 29 # resume model, optimizer from epoch 29 of experiment 1, checkpoint files need to be copied to the MODEL_DIR folder
63
+
64
+
65
+ LEARNING_RATE = {
66
+ 'cnn': 1e-05,
67
+ 'classifier': 1e-04,
68
+ }
69
+
70
+ BATCH_SIZE = {
71
+ 'train': 32,
72
+ 'valid': 64,
73
+ 'grad_acc': 4, # gradient acc. steps with 'train' of batch sizes, global batch size = 'grad_acc' * 'train'
74
+ }
75
+
76
+ TRANSFORMS = {
77
+ 'IMAGE_SIZE_TRAIN': 464, # set image sizes here, main differents of runs in experiment 2, i.e. 384px, 464px, 544px, 624px
78
+ 'IMAGE_SIZE_VAL': 464,
79
+ 'RandAug' : {
80
+ 'm': 7,
81
+ 'n': 2
82
+ },
83
+ }
84
+
85
+
86
+ ############# Checkpoints ####################
87
+ CHECKPOINTS = {
88
+ 'fe_cnn': None, # iNaturalist pre-trained model checkpoints available at "https://huggingface.co/BBracke/convnextv2_base.inat21_384"
89
+ 'model': None,
90
+ 'optimizer': None,
91
+ 'scaler': None,
92
+ }
93
+
94
+ ################### WandB ##################
95
+ WANDB = False
96
+
97
+ if WANDB:
98
+ wandb.init(
99
+ entity="snakeclef2023", # our team at wandb
100
+
101
+ # set the wandb project where this run will be logged
102
+ project="exp2", # -> define sub-projects here, e.g. experiments with MetaFormer or CNNs...
103
+
104
+ # define a name for this run
105
+ name="464px",
106
+
107
+ # track all the used hyperparameters here, config is just a dict object so any key:value pairs are possible
108
+ config={
109
+ "learning_rate": LEARNING_RATE,
110
+ "architecture": "convnextv2_base.fcmae_ft_in22k_in1k_384",
111
+ "pretrained": "iNat21",
112
+ "dataset": f"snakeclef2023, additional train data: {True if ADD_TRAINDATA_CONFIG else False}",
113
+ "epochs": NUM_EPOCHS,
114
+ "transforms": TRANSFORMS,
115
+ "checkpoints": CHECKPOINTS,
116
+ "model_dir": MODEL_DIR
117
+ # ... any other hyperparameter that is necessary to reproduce the result
118
+ },
119
+ save_code=True, # save the script file as backup
120
+ dir=MODEL_DIR # locally folder where wandb log files are saved
121
+ )
122
+
123
+
124
+
125
+
126
+ ##################### Dataset & AugTransforms #####################################
127
+ # ### dataset & loaders
128
class SnakeTrainDataset(Dataset):
    """Dataset yielding ``(image, label, code-class mapping)`` triples.

    Args:
        data: DataFrame whose rows provide ``image_path``, ``class_id`` and ``code``.
        ccm: DataFrame indexed by column with one column per metadata code;
            a column named ``"unknown"`` is used as fallback
            (assumed to exist in the CSV -- TODO confirm).
        transform: optional callable applied to the loaded RGB image.
    """

    def __init__(self, data, ccm, transform=None):
        self.data = data
        self.transform = transform  # Image augmentation pipeline
        self.code_class_mapping = ccm

    def __len__(self):
        return self.data.shape[0]

    def _code_mapping(self, code):
        """Return the mapping column for ``code`` as a tensor, falling back to "unknown"."""
        # The original looked the code up in an undefined ``self.code_tokens``;
        # the columns of the mapping table are the set of known codes.
        if code not in self.code_class_mapping:
            code = "unknown"
        return torch.tensor(self.code_class_mapping[code].to_numpy())

    def __getitem__(self, index):
        obj = self.data.iloc[index]  # get instance
        label = obj.class_id  # get label

        img = Image.open(obj.image_path).convert("RGB")  # load image
        ccm = self._code_mapping(obj.code)  # code class mapping

        # img. augmentation (skipped when no transform was supplied)
        if self.transform is not None:
            img = self.transform(img)

        return (img, label, ccm)
149
+
150
+
151
+ # valid data preprocessing pipeline
152
def get_val_preprocessing(img_size):
    """Build the validation pipeline: resize, five-crop, stack crops, normalize.

    Each image becomes a 4D tensor holding the five crops of size
    ``img_size`` x ``img_size``.
    """
    print(f'IMG_SIZE_VAL: {img_size}')

    def _stack_crops(crops):
        # FiveCrop returns several PIL images; convert each and stack into one 4D tensor
        return torch.stack([transforms.ToTensor()(crop) for crop in crops])

    five_crop_to_tensor = transforms.Compose([
        transforms.FiveCrop((img_size, img_size)),
        transforms.Lambda(_stack_crops),
    ])

    return transforms.Compose([
        transforms.Resize(int(img_size * 1.25)),  # enlarge before cropping
        five_crop_to_tensor,
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
162
+
163
class IdentityTransform:
    """No-op transform: returns its input unchanged (placeholder when RandAugment is off)."""

    def __call__(self, x):
        return x
166
+
167
+
168
+ # train data augmentation/ preprocessing pipeline
169
def get_train_augmentation_preprocessing(img_size, rand_aug=False):
    """Build the training pipeline: resize, random flips, random crop, optional RandAugment, normalize.

    RandAugment takes its ``num_ops``/``magnitude`` from ``TRANSFORMS['RandAug']``.
    """
    print(f'IMG_SIZE_TRAIN: {img_size}, RandAug: {rand_aug}')

    if rand_aug:
        aug_cfg = TRANSFORMS['RandAug']
        extra_aug = transforms.RandAugment(num_ops=aug_cfg['n'], magnitude=aug_cfg['m'])
    else:
        extra_aug = IdentityTransform()

    steps = [
        transforms.Resize(int(img_size * 1.25)),  # enlarge before the random crop
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomCrop((img_size, img_size)),  # random crop to the target size
        extra_aug,
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    return transforms.Compose(steps)
180
+
181
+
182
def get_datasets(train_transfroms, val_transforms):
    """Load the metadata CSVs and build the train and validation datasets.

    Rows listed in ``MISSING_FILES`` are removed from the training table, image
    paths are prefixed with their data directories, the additional training
    table is appended when ``ADD_TRAINDATA_CONFIG`` is set, and both tables are
    shuffled with a fixed seed.

    Returns:
        ``(train_dataset, valid_dataset)``.
    """
    nan_values = ['', '#N/A', '#N/A N/A', '#NA', '-1.#IND', '-1.#QNAN', '-NaN', '-nan', '1.#IND', '1.#QNAN', '<NA>', 'N/A', 'NULL', 'NaN', 'n/a', 'nan', 'null']

    def _read(path):
        # keep_default_na=False: only the values listed above are parsed as NaN
        return pd.read_csv(path, na_values=nan_values, keep_default_na=False)

    train_data = _read(TRAINDATA_CONFIG)
    missing_train_data = _read(MISSING_FILES)
    valid_data = _read(VALIDDATA_CONFIG)

    # anti-join: keep only training rows that do not appear in the missing-files table
    kept_columns = ["observation_id", "endemic", "binomial_name", "code", "image_path", "class_id", "subset"]
    train_data = pd.merge(train_data, missing_train_data, how='outer', indicator=True)
    train_data = train_data.loc[train_data._merge == 'left_only', kept_columns]

    # prefix image paths with the data directories
    train_data["image_path"] = TRAIN_DATA_DIR + train_data['image_path']
    valid_data["image_path"] = VAL_DATA_DIR + valid_data['image_path']

    # optional additional training data
    if ADD_TRAINDATA_CONFIG:
        add_train_data = _read(ADD_TRAINDATA_CONFIG)
        add_train_data["image_path"] = ADD_TRAIN_DATA_DIR + add_train_data['image_path']
        train_data = pd.concat([train_data, add_train_data], axis=0)

    print(f'train data shape: {train_data.shape}')

    # deterministic shuffle
    train_data = train_data.sample(frac=1, random_state=1).reset_index(drop=True)
    valid_data = valid_data.sample(frac=1, random_state=1).reset_index(drop=True)

    # code -> class mapping table shared by both datasets
    ccm = _read(CCM)

    train_dataset = SnakeTrainDataset(train_data, ccm, transform=train_transfroms)
    valid_dataset = SnakeTrainDataset(valid_data, ccm, transform=val_transforms)
    return train_dataset, valid_dataset
220
+
221
+
222
def get_dataloaders(imgsize_train, imgsize_val, rand_aug):
    """Create the train and validation ``DataLoader`` objects.

    Batch sizes come from ``BATCH_SIZE``; the train loader shuffles and drops
    the last incomplete batch, the validation loader does neither.
    """
    train_dataset, valid_dataset = get_datasets(
        train_transfroms=get_train_augmentation_preprocessing(imgsize_train, rand_aug),
        val_transforms=get_val_preprocessing(imgsize_val),
    )

    shared = {'num_workers': 6, 'pin_memory': True}
    train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE['train'], drop_last=True, **shared)
    valid_loader = DataLoader(dataset=valid_dataset, shuffle=False, batch_size=BATCH_SIZE['valid'], drop_last=False, **shared)
    return train_loader, valid_loader
232
+
233
+
234
+ # #################### plot train history #########################
235
+
236
def plot_history(logs):
    """Plot train vs. validation loss, accuracy and F1 and save the figure as SVG in ``MODEL_DIR``."""
    fig, ax = plt.subplots(3, 1, figsize=(8, 12))

    # (train key, valid key, y-label / title suffix, y-limits) per panel
    panels = [
        ('loss', 'val_loss', "loss", [0, -np.log(1/NUM_CLASSES)]),
        ('acc', 'val_acc', "accuracy", [0, 1.01]),
        ('f1', 'val_f1', "f1", [0, 1.01]),
    ]

    last = len(panels) - 1
    for idx, (train_key, valid_key, name, y_limits) in enumerate(panels):
        axis = ax[idx]
        axis.plot(logs[train_key], label="train data")
        axis.plot(logs[valid_key], label="valid data")
        axis.legend(loc="best")
        axis.set_ylabel(name)
        axis.set_ylim(y_limits)
        if idx == last:
            # only the bottom panel carries the x-axis label
            axis.set_xlabel("epochs")
        axis.set_title(f"train- vs. valid {name}")

    fig.savefig(f'{MODEL_DIR}model_history.svg', dpi=150, format="svg")
    plt.show()
265
+
266
+ # #################### Model #####################################
267
+
268
class FeatureExtractor(nn.Module):
    """ConvNeXt-V2 base backbone created through timm without a classification head (``num_classes=0``)."""

    def __init__(self):
        super().__init__()
        self.conv_backbone = create_model(
            'convnextv2_base.fcmae_ft_in22k_in1k_384',
            pretrained=True,
            num_classes=0,
            drop_path_rate=0.2,
        )
        fe_checkpoint = CHECKPOINTS['fe_cnn']
        if fe_checkpoint:
            # replace the pretrained weights with the configured backbone checkpoint
            state = torch.load(fe_checkpoint, map_location='cpu')
            self.conv_backbone.load_state_dict(state, strict=True)
            print(f"use FE_CHECKPOINTS: {CHECKPOINTS['fe_cnn']}")
        torch.cuda.empty_cache()

    def forward(self, img):
        return self.conv_backbone(img)
280
+
281
class Classifier(nn.Module):
    """Linear classification head with optional dropout on the incoming embeddings.

    Args:
        num_classes: number of output logits.
        dim_embeddings: size of the input embedding.
        dropout: dropout probability; a falsy value (``None`` or ``0``) disables dropout.
    """

    def __init__(self, num_classes: int, dim_embeddings: int, dropout: float = None):
        super().__init__()
        if dropout:
            self.dropout = nn.Dropout(p=dropout, inplace=False)
        else:
            self.dropout = nn.Identity()
        self.classifier = nn.Linear(in_features=dim_embeddings, out_features=num_classes, bias=True)

    def forward(self, embeddings):
        return self.classifier(self.dropout(embeddings))
292
+
293
+
294
class Model(nn.Module):
    """Full network: backbone feature extractor followed by the linear classification head."""

    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.classifier = Classifier(num_classes=NUM_CLASSES, dim_embeddings=1024, dropout=0.25)

    def forward(self, img):
        return self.classifier(self.feature_extractor(img))
304
+
305
+
306
def load_checkpoints(model=None, optimizer=None, scaler=None):
    """Load state dicts from the paths in ``CHECKPOINTS`` into whichever objects were passed.

    An object is skipped when it is ``None`` or when its ``CHECKPOINTS`` entry is falsy.
    """
    targets = (("model", model), ("optimizer", optimizer), ("scaler", scaler))
    for key, target in targets:
        path = CHECKPOINTS[key]
        if not path or target is None:
            continue
        target.load_state_dict(torch.load(path, map_location='cpu'))
        print(f"use {key} checkpoints: {path}")
    torch.cuda.empty_cache()
317
+
318
def resume_checkpoints(model=None, optimizer=None, scaler=None):
    """Restore the passed objects from the ``RESUME_EPOCH`` checkpoint files in ``MODEL_DIR``.

    The model is restored from the EMA weights file (``ema_model_epoch<N>.pth``).
    """
    # (label used in the log line, file-name stem, object to restore)
    sources = (
        ("model", "ema_model", model),
        ("optimizer", "optimizer", optimizer),
        ("scaler", "mp_scaler", scaler),
    )
    for label, stem, target in sources:
        if target is None:
            continue
        path = f'{MODEL_DIR}{stem}_epoch{RESUME_EPOCH}.pth'
        target.load_state_dict(torch.load(path, map_location='cpu'))
        print(f"use {label} checkpoints: {path}")
    torch.cuda.empty_cache()
329
+
330
+
331
def resume_logs(logs):
    """Extend every list in ``logs`` in place with the values stored in ``train_history.csv``."""
    history = pd.read_csv(f"{MODEL_DIR}train_history.csv")
    for metric_name, values in logs.items():
        values.extend(list(history[metric_name].values))
335
+
336
+ ######################## Optimizer #####################################
337
def get_optm_group(module):
    """Split the parameters of ``module`` into weight-decay and no-weight-decay name sets.

    Biases and the weights of normalization/embedding layers go to ``no_decay``;
    weights of linear/convolutional layers (and timm's GlobalResponseNormMlp) go
    to ``decay``.

    Returns:
        ``(param_dict, decay, no_decay)``: a name -> parameter dict and the two
        sets of full parameter names.

    Raises:
        AssertionError: if a parameter ends up in both sets or in neither.
    """
    decay_modules = (torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv1d, timm.layers.GlobalResponseNormMlp)
    no_decay_modules = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.LayerNorm, torch.nn.Embedding)

    decay, no_decay = set(), set()
    for mod_name, mod in module.named_modules():
        for par_name, _ in mod.named_parameters():
            full_name = '%s.%s' % (mod_name, par_name) if mod_name else par_name

            if par_name.endswith('bias'):
                # biases are never decayed
                no_decay.add(full_name)
                continue
            if not par_name.endswith('weight'):
                continue
            if isinstance(mod, decay_modules):
                decay.add(full_name)
            elif isinstance(mod, no_decay_modules):
                no_decay.add(full_name)

    # validate that every parameter was classified exactly once
    param_dict = dict(module.named_parameters())
    inter_params = decay & no_decay
    union_params = decay | no_decay
    assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
    assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                % (str(param_dict.keys() - union_params), )

    return param_dict, decay, no_decay
374
+
375
+
376
def get_optimizer(model):
    """Create the AdamW optimizer: layer-decayed backbone groups plus two classifier groups.

    The backbone groups come from timm's ``create_optimizer_v2``; the classifier
    parameters are appended as one weight-decayed and one non-decayed group,
    both using ``LEARNING_RATE['classifier']``.
    """
    optimizer = create_optimizer_v2(
        model.feature_extractor.conv_backbone,
        opt='adamw',
        filter_bias_and_bn=True,
        weight_decay=1e-8,
        layer_decay=0.85,
        lr=LEARNING_RATE['cnn'],
    )

    param_dict, decay, no_decay = get_optm_group(model.classifier)
    head_lr = LEARNING_RATE['classifier']
    for names, weight_decay in ((decay, 0.05), (no_decay, 0.0)):
        optimizer.add_param_group({
            "params": [param_dict[name] for name in sorted(names)],
            "weight_decay": weight_decay,
            'lr': head_lr,
        })

    return optimizer
386
+
387
+
388
+ # #################### Model FixRes #####################################
389
+
390
+ #def fixres(model):
391
+ # # freeze model during fixres
392
+ # for i, (param_name, param) in enumerate(model.named_parameters()):
393
+ # param.requires_grad = False
394
+ #
395
+ # # unfreeze last layers of feature extractor
396
+ # for i, (param_name, param) in enumerate(model.feature_extractor.conv_backbone.head.named_parameters()):
397
+ # param.requires_grad = True
398
+ #
399
+ # # unfreeze classifier
400
+ # for i, (param_name, param) in enumerate(model.classifier.named_parameters()):
401
+ # param.requires_grad = True
402
+
403
+ # #################### Train Loop #####################################
404
+
405
+ # ### train
406
def _run_validation(net, valid_loader, device, loss_val_fn, loss_metric, metrics, metric_ccm):
    """Evaluate ``net`` on ``valid_loader`` and return the epoch metrics as Python floats.

    The five crops of every sample are folded into the batch axis for the
    forward pass and their logits are averaged afterwards. All metric objects
    are reset before returning.

    Returns:
        dict with the keys ``loss``, ``acc``, ``acc_top3``, ``f1`` and ``f1country``.
    """
    img_size = TRANSFORMS['IMAGE_SIZE_VAL']
    net.eval()
    with torch.no_grad():
        for inputs, labels, ccm in valid_loader:
            inputs = inputs.to(device, non_blocking=True)
            inputs = inputs.view(-1, 3, img_size, img_size)
            labels = labels.to(device, non_blocking=True)
            ccm = ccm.to(device, non_blocking=True)

            # forward with mixed precision
            with autocast(device_type='cuda', dtype=torch.float16):
                outputs = net(inputs)
                outputs = outputs.view(-1, 5, NUM_CLASSES).mean(1)  # average over the five crops
                loss = loss_val_fn(outputs, labels)

            loss_metric.update(loss.detach())
            preds = outputs.softmax(dim=-1).detach()
            metrics.update(preds, labels)
            metric_ccm.update(preds * ccm, labels)

        epoch_metrics = metrics.compute()
        results = {
            'loss': loss_metric.compute().cpu().item(),
            'acc': epoch_metrics['acc'].cpu().item(),
            'acc_top3': epoch_metrics['top3_acc'].cpu().item(),
            'f1': epoch_metrics['f1'].cpu().item(),
            'f1country': metric_ccm.compute().detach().cpu().item(),
        }

    loss_metric.reset()
    metrics.reset()
    metric_ccm.reset()
    return results


def main():
    """Train the model with mixed precision, gradient accumulation and EMA, validating every epoch.

    After each epoch the history CSV and the model/EMA/optimizer/scaler state
    dicts are written to ``MODEL_DIR``. After the last epoch the EMA model is
    validated once and the result is written to ``ema_results.txt``.
    """
    device = torch.device('cuda:1')
    torch.cuda.set_device(device)

    # prepare the datasets
    train_loader, valid_loader = get_dataloaders(imgsize_train=TRANSFORMS['IMAGE_SIZE_TRAIN'],
                                                 imgsize_val=TRANSFORMS['IMAGE_SIZE_VAL'],
                                                 rand_aug=True)

    # instantiate the model
    model = Model().to(device)
    if RESUME_EPOCH > 0:
        resume_checkpoints(model=model)
    ema_model = ModelEmaV2(model, decay=0.9998, device=device)

    # optimizer & mixed-precision scaler
    optimizer = get_optimizer(model)
    scaler = GradScaler()
    if RESUME_EPOCH > 0:
        resume_checkpoints(optimizer=optimizer, scaler=scaler)

    loss_fn = nn.CrossEntropyLoss()
    loss_val_fn = nn.CrossEntropyLoss()

    # running metrics during training
    loss_metric = MeanMetric().to(device)
    metrics = MetricCollection(metrics={
        'acc': MulticlassAccuracy(num_classes=NUM_CLASSES, average='macro'),
        'top3_acc': MulticlassAccuracy(num_classes=NUM_CLASSES, average='macro', top_k=3),
        'f1': MulticlassF1Score(num_classes=NUM_CLASSES, average='macro')
    }).to(device)
    metric_ccm = MulticlassF1Score(num_classes=NUM_CLASSES, average='macro').to(device)

    # start time of training
    start_training = time.perf_counter()
    # create log dict
    logs = {'loss': [], 'acc': [], 'acc_top3': [], 'f1': [], 'f1country': [], 'val_loss': [], 'val_acc': [], 'val_acc_top3': [], 'val_f1': [], 'val_f1country': []}
    if RESUME_EPOCH > 0:
        resume_logs(logs)

    # A configured value of 0 means "no accumulation". The original guarded only
    # the loss divider against 0 and then crashed on ``% 0`` in the step check.
    grad_acc = max(1, BATCH_SIZE['grad_acc'])
    loss_div = torch.tensor(grad_acc, dtype=torch.float16, device=device, requires_grad=False)

    # iterate over epochs
    start_epoch = RESUME_EPOCH+1 if RESUME_EPOCH > 0 else 0
    for epoch in range(start_epoch, NUM_EPOCHS):
        epoch_start = time.perf_counter()
        print(f'Epoch {epoch+1}/{NUM_EPOCHS}')

        ############################## train phase ####################################
        model.train()
        optimizer.zero_grad(set_to_none=True)

        for batch_idx, (inputs, labels, ccm) in enumerate(train_loader):
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            ccm = ccm.to(device, non_blocking=True)

            # forward with mixed precision
            with autocast(device_type='cuda', dtype=torch.float16):
                outputs = model(inputs)
                loss = loss_fn(outputs, labels) / loss_div

            scaler.scale(loss).backward()

            # metrics are tracked on the undivided loss
            loss_metric.update((loss * loss_div).detach())
            preds = outputs.softmax(dim=-1).detach()
            metrics.update(preds, labels)
            metric_ccm.update(preds * ccm, labels)

            # optimizer step once every ``grad_acc`` batches
            if (batch_idx+1) % grad_acc == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                ema_model.update(model)

        # compute & reset the training metrics
        epoch_loss = loss_metric.compute()
        epoch_metrics = metrics.compute()
        epoch_metric_ccm = metric_ccm.compute()

        loss_metric.reset()
        metrics.reset()
        metric_ccm.reset()

        logs['loss'].append(epoch_loss.cpu().item())
        logs['acc'].append(epoch_metrics['acc'].cpu().item())
        logs['acc_top3'].append(epoch_metrics['top3_acc'].cpu().item())
        logs['f1'].append(epoch_metrics['f1'].cpu().item())
        logs['f1country'].append(epoch_metric_ccm.detach().cpu().item())

        # [-1] is the value just appended; stays correct even if the resumed
        # history length does not equal the epoch index
        print(f"loss: {logs['loss'][-1]:.5f}, acc: {logs['acc'][-1]:.5f}, acc_top3: {logs['acc_top3'][-1]:.5f}, f1: {logs['f1'][-1]:.5f}, f1country: {logs['f1country'][-1]:.5f}", end=' || ')

        # discard gradients of an incomplete accumulation window
        optimizer.zero_grad(set_to_none=True)

        # free the training tensors before validation
        del inputs, labels, ccm, preds, outputs, loss, epoch_loss, epoch_metrics, epoch_metric_ccm
        torch.cuda.empty_cache()

        ############################## valid phase ####################################
        val_results = _run_validation(model, valid_loader, device, loss_val_fn, loss_metric, metrics, metric_ccm)
        for key, value in val_results.items():
            logs[f'val_{key}'].append(value)

        print(f"val_loss: {logs['val_loss'][-1]:.5f}, val_acc: {logs['val_acc'][-1]:.5f}, val_acc_top3: {logs['val_acc_top3'][-1]:.5f}, val_f1: {logs['val_f1'][-1]:.5f}, val_f1country: {logs['val_f1country'][-1]:.5f}", end=' || ')

        del val_results
        torch.cuda.empty_cache()

        # save logs as csv
        logs_df = pd.DataFrame(logs)
        logs_df.to_csv(f'{MODEL_DIR}train_history.csv', index_label='epoch', sep=',', encoding='utf-8')

        if WANDB:
            # log each metric value of the current epoch
            wandb.log(
                {k: v[-1] for k, v in logs.items()},
                step=epoch
            )

        # save trained model for each epoch
        torch.save(model.state_dict(), f'{MODEL_DIR}model_epoch{epoch}.pth')
        torch.save(ema_model.module.state_dict(), f'{MODEL_DIR}ema_model_epoch{epoch}.pth')
        torch.save(optimizer.state_dict(), f'{MODEL_DIR}optimizer_epoch{epoch}.pth')
        torch.save(scaler.state_dict(), f'{MODEL_DIR}mp_scaler_epoch{epoch}.pth')

        epoch_end = time.perf_counter()
        print(f"epoch runtime: {epoch_end-epoch_start:5.3f} sec.")

        del logs_df, epoch_start, epoch_end
        torch.cuda.empty_cache()

    ################################## EMA Model Validation ################################
    del model
    torch.cuda.empty_cache()

    ema = _run_validation(ema_model.module, valid_loader, device, loss_val_fn, loss_metric, metrics, metric_ccm)
    ema_summary = f"ema_loss: {ema['loss']:.5f}, ema_acc: {ema['acc']:.5f}, ema_acc_top3: {ema['acc_top3']:.5f}, ema_f1: {ema['f1']:.5f}, ema_f1country: {ema['f1country']:.5f}"
    print(ema_summary)

    with open(f'{MODEL_DIR}ema_results.txt', 'w') as f:
        print(ema_summary, file=f)

    plot_history(logs)
    # end time of training
    end_training = time.perf_counter()
    print(f'Training succeeded in {(end_training - start_training):5.3f}s')

    if WANDB:
        wandb.finish()
639
+
640
+
641
+ if __name__=="__main__":
642
+ main()
643
+
644
+
645
+
646
+
exp2/convnext2b_exp2_imgSizes_e40.py ADDED
@@ -0,0 +1,681 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, time, pickle, shutil
2
+ import pandas as pd
3
+ import numpy as np
4
+
5
+ from PIL import Image, ImageFile
6
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from torch.cuda.amp import GradScaler
12
+ from torch import autocast
13
+
14
+ import torchvision.transforms as transforms
15
+
16
+ import timm
17
+ from timm.models import create_model
18
+ from timm.utils import ModelEmaV2
19
+
20
+ from timm.optim import create_optimizer_v2
21
+
22
+ from torchmetrics import MeanMetric
23
+ from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score
24
+ from torchmetrics import MetricCollection
25
+
26
+ import wandb
27
+
28
+ import matplotlib.pyplot as plt
29
+
30
+
31
+ # ### parameters
32
+ ################## Settings #############################
33
+ #os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
34
+ torch.backends.cudnn.benchmark = True
35
+
36
+ ################## Data Paths ##########################
37
+ MODEL_DIR = "./convnext2b_imgSize_544px_end2end/"
38
+
39
+ if not os.path.exists(MODEL_DIR):
40
+ os.makedirs(MODEL_DIR)
41
+ shutil.copyfile('./convnext2b_exp2_imgSizes_e40.py', f'{MODEL_DIR}convnext2b_exp2_imgSizes_e40.py')
42
+
43
+ TRAIN_DATA_DIR = "/SnakeCLEF2023-large_size/" # train imgs. path
44
+ ADD_TRAIN_DATA_DIR = "/HMP/" # add. train imgs. path
45
+ VAL_DATA_DIR = "/SnakeCLEF2023-large_size/" # val imgs. path
46
+
47
+ TRAINDATA_CONFIG = "/SnakeCLEF2023-TrainMetadata-iNat.csv"
48
+ ADD_TRAINDATA_CONFIG = "/SnakeCLEF2023-TrainMetadata-HM.csv"
49
+ VALIDDATA_CONFIG = "/SnakeCLEF2023-ValMetadata.csv"
50
+
51
+ MISSING_FILES = "../missing_train_data.csv" # csv with missing img. files that will be filtered out
52
+
53
+ CCM = "../code_class_mapping_obid.csv" # csv to metadata code to snake species dist.
54
+
55
+
56
+ NUM_CLASSES = 1784
57
+
58
+ ################## Hyperparameters ########################
59
+ WARMUP_EPOCHS = 5 # num. epochs only training classification head of model
60
+ NUM_EPOCHS = 40
61
+ RESUME_EPOCH = 0
62
+
63
+
64
+ LEARNING_RATE = {
65
+ 'cnn': 1e-05,
66
+ 'classifier': 1e-04,
67
+ }
68
+
69
+ BATCH_SIZE = {
70
+ 'train': 32,
71
+ 'valid': 48,
72
+ 'grad_acc': 4, # gradient acc. steps with 'train' of batch sizes, global batch size = 'grad_acc' * 'train'
73
+ }
74
+
75
+ BATCH_SIZE_AFTER_WARMUP = {
76
+ 'train': 32,
77
+ 'valid': 48,
78
+ 'grad_acc': 4, # gradient acc. steps with 'train' of batch sizes, global batch size = 'grad_acc' * 'train'
79
+ }
80
+
81
+ TRANSFORMS = {
82
+ 'IMAGE_SIZE_TRAIN': 544,
83
+ 'IMAGE_SIZE_VAL': 544,
84
+ 'RandAug' : {
85
+ 'm': 7,
86
+ 'n': 2
87
+ }
88
+ }
89
+
90
+
91
+
92
+ ############# Checkpoints ####################
93
+ CHECKPOINTS = {
94
+ 'fe_cnn': "./iNat21_convnext2b.pth", # iNaturalist pre-trained model checkpoints available at "https://huggingface.co/BBracke/convnextv2_base.inat21_384"
95
+ 'model': None,
96
+ 'optimizer': None,
97
+ 'scaler': None,
98
+ }
99
+
100
+
101
+ ################### WandB ##################
102
+ WANDB = False
103
+
104
+ if WANDB:
105
+ wandb.init(
106
+ entity="snakeclef2023", # our team at wandb
107
+
108
+ # set the wandb project where this run will be logged
109
+ project="exp2", # -> define sub-projects here, e.g. experiments with MetaFormer or CNNs...
110
+
111
+ # define a name for this run
112
+ name="544px_end2end",
113
+
114
+ # track all the used hyperparameters here, config is just a dict object so any key:value pairs are possible
115
+ config={
116
+ "learning_rate": LEARNING_RATE,
117
+ "architecture": "convnextv2_base.fcmae_ft_in22k_in1k_384",
118
+ "pretrained": "iNat21",
119
+ "dataset": f"snakeclef2023, additional train data: {True if ADD_TRAINDATA_CONFIG else False}",
120
+ "epochs": NUM_EPOCHS,
121
+ "transforms": TRANSFORMS,
122
+ "checkpoints": CHECKPOINTS,
123
+ "model_dir": MODEL_DIR
124
+ # ... any other hyperparameter that is necessary to reproduce the result
125
+ },
126
+ save_code=True, # save the script file as backup
127
+ dir=MODEL_DIR # locally folder where wandb log files are saved
128
+ )
129
+
130
+
131
+
132
+
133
+ ##################### Dataset & AugTransforms #####################################
134
+ # ### dataset & loaders
135
class SnakeTrainDataset(Dataset):
    """Dataset yielding ``(image, label, code-class mapping)`` triples.

    Args:
        data: DataFrame whose rows provide ``image_path``, ``class_id`` and ``code``.
        ccm: DataFrame indexed by column with one column per metadata code;
            a column named ``"unknown"`` is used as fallback
            (assumed to exist in the CSV -- TODO confirm).
        transform: optional callable applied to the loaded RGB image.
    """

    def __init__(self, data, ccm, transform=None):
        self.data = data
        self.transform = transform  # Image augmentation pipeline
        self.code_class_mapping = ccm

    def __len__(self):
        return self.data.shape[0]

    def _code_mapping(self, code):
        """Return the mapping column for ``code`` as a tensor, falling back to "unknown"."""
        # The original looked the code up in an undefined ``self.code_tokens``;
        # the columns of the mapping table are the set of known codes.
        if code not in self.code_class_mapping:
            code = "unknown"
        return torch.tensor(self.code_class_mapping[code].to_numpy())

    def __getitem__(self, index):
        obj = self.data.iloc[index]  # get instance
        label = obj.class_id  # get label

        img = Image.open(obj.image_path).convert("RGB")  # load image
        ccm = self._code_mapping(obj.code)  # code class mapping

        # img. augmentation (skipped when no transform was supplied)
        if self.transform is not None:
            img = self.transform(img)

        return (img, label, ccm)
156
+
157
+
158
+ # valid data preprocessing pipeline
159
def get_val_preprocessing(img_size):
    """Build the validation preprocessing pipeline.

    The image is resized to 1.25x ``img_size``, split into five
    ``img_size`` x ``img_size`` crops, stacked into one 4D tensor and normalized.
    """
    print(f'IMG_SIZE_VAL: {img_size}')

    enlarge = transforms.Resize(int(img_size * 1.25))  # expand before the five-crop
    five_crop_to_tensor = transforms.Compose([
        transforms.FiveCrop((img_size, img_size)),  # this is a list of PIL Images
        # returns a 4D tensor (one entry per crop)
        transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
    ])
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    return transforms.Compose([enlarge, five_crop_to_tensor, normalize])
169
+
170
class IdentityTransform:
    """No-op transform: calling it returns its argument unchanged."""

    def __call__(self, x):
        return x
173
+
174
+
175
+ # train data augmentation/ preprocessing pipeline
176
def get_train_augmentation_preprocessing(img_size, rand_aug=False):
    """Build the training augmentation / preprocessing pipeline.

    Resize to 1.25x ``img_size``, random flips, random ``img_size`` crop,
    optional RandAugment (settings from ``TRANSFORMS['RandAug']``), tensor
    conversion and normalization.
    """
    print(f'IMG_SIZE_TRAIN: {img_size}, RandAug: {rand_aug}')

    if rand_aug:
        aug_step = transforms.RandAugment(num_ops=TRANSFORMS['RandAug']['n'],
                                          magnitude=TRANSFORMS['RandAug']['m'])
    else:
        aug_step = IdentityTransform()

    steps = [
        transforms.Resize(int(img_size * 1.25)),  # expand before the random crop
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomCrop((img_size, img_size)),  # random crop to img_size
        aug_step,
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    return transforms.Compose(steps)
187
+
188
+
189
+
190
def get_datasets(train_transfroms, val_transforms):
    """Load the metadata CSVs and wrap them in train / validation datasets.

    Args:
        train_transfroms: transform pipeline handed to the training dataset.
        val_transforms: transform pipeline handed to the validation dataset.

    Returns:
        ``(train_dataset, valid_dataset)``.
    """
    # Explicit NA markers together with keep_default_na=False: only these
    # strings are parsed as missing values.
    nan_values = ['', '#N/A', '#N/A N/A', '#NA', '-1.#IND', '-1.#QNAN', '-NaN', '-nan', '1.#IND', '1.#QNAN', '<NA>', 'N/A', 'NULL', 'NaN', 'n/a', 'nan', 'null']

    def _read(csv_path):
        return pd.read_csv(csv_path, na_values=nan_values, keep_default_na=False)

    # load CSVs
    train_data = _read(TRAINDATA_CONFIG)
    missing_train_data = _read(MISSING_FILES)
    valid_data = _read(VALIDDATA_CONFIG)

    # drop train rows listed as missing files: keep only rows found on the left side of the merge
    merged = pd.merge(train_data, missing_train_data, how='outer', indicator=True)
    kept_columns = ["observation_id", "endemic", "binomial_name", "code", "image_path", "class_id", "subset"]
    train_data = merged.loc[merged._merge == 'left_only', kept_columns]

    # prefix the image directories
    train_data["image_path"] = TRAIN_DATA_DIR + train_data['image_path']
    valid_data["image_path"] = VAL_DATA_DIR + valid_data['image_path']

    # append the additional training data, if configured
    if ADD_TRAINDATA_CONFIG:
        add_train_data = _read(ADD_TRAINDATA_CONFIG)
        add_train_data["image_path"] = ADD_TRAIN_DATA_DIR + add_train_data['image_path']
        train_data = pd.concat([train_data, add_train_data], axis=0)

    print(f'train data shape: {train_data.shape}')

    # shuffle with a fixed seed
    train_data = train_data.sample(frac=1, random_state=1).reset_index(drop=True)
    valid_data = valid_data.sample(frac=1, random_state=1).reset_index(drop=True)

    # load transposed version of CCM table
    ccm = _read(CCM)

    # create datasets
    train_dataset = SnakeTrainDataset(train_data, ccm, transform=train_transfroms)
    valid_dataset = SnakeTrainDataset(valid_data, ccm, transform=val_transforms)
    return train_dataset, valid_dataset
228
+
229
+
230
def get_dataloaders(imgsize_train, imgsize_val, rand_aug):
    """Create the train and validation ``DataLoader`` objects.

    Batch sizes are read from the module-level ``BATCH_SIZE`` dict at call time.
    """
    # augmentation & preprocessing pipelines
    train_pipeline = get_train_augmentation_preprocessing(imgsize_train, rand_aug)
    valid_pipeline = get_val_preprocessing(imgsize_val)

    # datasets
    train_dataset, valid_dataset = get_datasets(train_transfroms=train_pipeline,
                                                val_transforms=valid_pipeline)

    # settings shared by both loaders
    common = dict(num_workers=6, pin_memory=True)
    train_loader = DataLoader(dataset=train_dataset, shuffle=True,
                              batch_size=BATCH_SIZE['train'], drop_last=True, **common)
    valid_loader = DataLoader(dataset=valid_dataset, shuffle=False,
                              batch_size=BATCH_SIZE['valid'], drop_last=False, **common)
    return train_loader, valid_loader
240
+
241
+
242
+ # #################### plot train history #########################
243
+
244
def plot_history(logs):
    """Plot train vs. validation loss, accuracy and F1 and save the figure.

    The figure is written to ``{MODEL_DIR}model_history.svg`` and then shown.
    """
    fig, ax = plt.subplots(3, 1, figsize=(8, 12))

    # (train key, valid key, y-label, y-limits, title) for each panel, top to bottom
    panels = (
        ('loss', 'val_loss', "loss", [0, -np.log(1/NUM_CLASSES)], "train- vs. valid loss"),
        ('acc', 'val_acc', "accuracy", [0, 1.01], "train- vs. valid accuracy"),
        ('f1', 'val_f1', "f1", [0, 1.01], "train- vs. valid f1"),
    )
    bottom_row = len(panels) - 1

    for row, (train_key, valid_key, y_label, y_limits, title) in enumerate(panels):
        axis = ax[row]
        axis.plot(logs[train_key], label="train data")
        axis.plot(logs[valid_key], label="valid data")
        axis.legend(loc="best")
        axis.set_ylabel(y_label)
        axis.set_ylim(y_limits)
        if row == bottom_row:
            # only the bottom panel carries the x-axis label
            axis.set_xlabel("epochs")
        axis.set_title(title)

    fig.savefig(f'{MODEL_DIR}model_history.svg', dpi=150, format="svg")
    plt.show()
273
+
274
+
275
+ # #################### Model #####################################
276
+
277
class FeatureExtractor(nn.Module):
    """ConvNeXt-V2 base backbone created without a classification head (``num_classes=0``)."""

    def __init__(self):
        super().__init__()
        self.conv_backbone = create_model('convnextv2_base.fcmae_ft_in22k_in1k_384',
                                          pretrained=True, num_classes=0, drop_path_rate=0.2)
        # optionally replace the weights with the configured backbone checkpoint
        if CHECKPOINTS['fe_cnn']:
            state = torch.load(CHECKPOINTS['fe_cnn'], map_location='cpu')
            self.conv_backbone.load_state_dict(state, strict=True)
            print(f"use FE_CHECKPOINTS: {CHECKPOINTS['fe_cnn']}")
        torch.cuda.empty_cache()

    def forward(self, img):
        return self.conv_backbone(img)
289
+
290
+
291
class Classifier(nn.Module):
    """Linear classification head with optional dropout in front of it.

    Args:
        num_classes: number of output logits.
        dim_embeddings: size of the incoming feature vector.
        dropout: dropout probability; a falsy value disables dropout.
    """

    def __init__(self, num_classes: int, dim_embeddings: int, dropout: float = None):
        super().__init__()
        if dropout:
            self.dropout = nn.Dropout(p=dropout, inplace=False)
        else:
            self.dropout = nn.Identity()
        self.classifier = nn.Linear(in_features=dim_embeddings, out_features=num_classes, bias=True)

    def forward(self, embeddings):
        return self.classifier(self.dropout(embeddings))
302
+
303
+
304
class Model(nn.Module):
    """Image classifier: ``FeatureExtractor`` followed by a ``Classifier`` head."""

    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.classifier = Classifier(num_classes=NUM_CLASSES, dim_embeddings=1024, dropout=0.25)

    def forward(self, img):
        return self.classifier(self.feature_extractor(img))
315
+
316
+
317
def load_checkpoints(model=None, optimizer=None, scaler=None):
    """Load state dicts from the paths configured in ``CHECKPOINTS``.

    A component is restored only when it was passed in and its
    ``CHECKPOINTS`` entry is truthy.
    """
    for kind, target in (('model', model), ('optimizer', optimizer), ('scaler', scaler)):
        ckpt_path = CHECKPOINTS[kind]
        if ckpt_path and target is not None:
            target.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
            print(f"use {kind} checkpoints: {ckpt_path}")
    torch.cuda.empty_cache()
328
+
329
def resume_checkpoints(model=None, optimizer=None, scaler=None):
    """Restore the given components from the per-epoch files of ``RESUME_EPOCH``.

    Files are read from ``MODEL_DIR`` using the same names the training loop
    writes (``model_epoch*``, ``optimizer_epoch*``, ``mp_scaler_epoch*``).
    """
    # (label used in the log line, file name prefix, object to restore)
    components = (
        ('model', 'model', model),
        ('optimizer', 'optimizer', optimizer),
        ('scaler', 'mp_scaler', scaler),
    )
    for label, prefix, target in components:
        if target is None:
            continue
        ckpt_path = f'{MODEL_DIR}{prefix}_epoch{RESUME_EPOCH}.pth'
        target.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
        print(f"use {label} checkpoints: {ckpt_path}")
    torch.cuda.empty_cache()
341
+
342
+
343
def resume_logs(logs):
    """Extend every list in *logs* in place with the column of the same name
    from ``{MODEL_DIR}train_history.csv``."""
    previous = pd.read_csv(f"{MODEL_DIR}train_history.csv")
    for metric_name, history in logs.items():
        history.extend(list(previous[metric_name].values))
347
+
348
+ ######################## Optimizer #####################################
349
def get_optm_group(module):
    """Split the parameters of *module* into weight-decay / no-weight-decay buckets.

    Biases and the weights of normalization / embedding layers are not decayed;
    weights of linear, convolutional and ``GlobalResponseNormMlp`` modules are.

    Returns:
        ``(param_dict, decay, no_decay)``: a name -> parameter mapping and the
        two sets of fully qualified parameter names.

    Raises:
        AssertionError: if a parameter lands in both buckets or in neither.
    """
    decay, no_decay = set(), set()
    whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv1d, timm.layers.GlobalResponseNormMlp)
    blacklist_weight_modules = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.LayerNorm, torch.nn.Embedding)

    for mn, m in module.named_modules():
        for pn, _ in m.named_parameters():
            fpn = f'{mn}.{pn}' if mn else pn  # full param name

            if pn.endswith('bias'):
                # all biases will not be decayed
                no_decay.add(fpn)
            elif pn.endswith('weight'):
                if isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

    # validate that we considered every parameter
    param_dict = dict(module.named_parameters())
    inter_params = decay & no_decay
    union_params = decay | no_decay
    unassigned = param_dict.keys() - union_params
    assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
    assert len(unassigned) == 0, "parameters %s were not separated into either decay/no_decay set!" \
        % (str(unassigned), )

    return param_dict, decay, no_decay
386
+
387
+
388
def get_warmup_optimizer(model):
    """Return an AdamW optimizer over the classifier head only.

    Two parameter groups are built from ``get_optm_group(model.classifier)``:
    decayed weights (``weight_decay=0.05``) and non-decayed parameters, both
    using ``LEARNING_RATE['classifier']``.
    """
    param_dict, decay, no_decay = get_optm_group(model.classifier)

    def _group(names, weight_decay):
        # sort the names so the parameter order inside the group is deterministic
        return {
            "params": [param_dict[pn] for pn in sorted(names)],
            "weight_decay": weight_decay,
            'lr': LEARNING_RATE['classifier'],
        }

    params_group = [_group(decay, 0.05), _group(no_decay, 0.0)]
    return torch.optim.AdamW(params_group)
397
+
398
+
399
def get_after_warmup_optimizer(model, old_opt):
    """Build the optimizer used once warm-up is over.

    A new AdamW optimizer is created for the conv backbone (layer-wise lr
    decay 0.85, base lr ``LEARNING_RATE['cnn']``), then the parameter groups
    of *old_opt* are appended to it.
    """
    backbone = model.feature_extractor.conv_backbone
    new_opt = create_optimizer_v2(backbone, opt='adamw', filter_bias_and_bn=True,
                                  weight_decay=1e-8, layer_decay=0.85, lr=LEARNING_RATE['cnn'])

    # carry over the param groups of the previous optimizer
    for previous_group in old_opt.param_groups:
        new_opt.add_param_group(previous_group)

    return new_opt
407
+
408
+
409
+ # #################### Model Warmup #####################################
410
+
411
def warmup_start(model):
    """Freeze every parameter of ``model.feature_extractor.conv_backbone``."""
    backbone = model.feature_extractor.conv_backbone
    for weight in backbone.parameters():
        weight.requires_grad = False
    print('--> freeze feature_extractor.conv_backbone during warmup phase')
416
+
417
def warmup_end(model):
    """Make every parameter of ``model.feature_extractor.conv_backbone`` trainable again."""
    backbone = model.feature_extractor.conv_backbone
    for weight in backbone.parameters():
        weight.requires_grad = True
    print('--> unfreeze feature_extractor.conv_backbone after warmup phase')
422
+
423
+
424
+ # #################### Train Loop #####################################
425
+
426
+ # ### train
427
def main():
    """Train the model, validate after every epoch and finally evaluate the EMA weights.

    Per epoch: mixed-precision training with gradient accumulation, five-crop
    validation, metric logging to ``{MODEL_DIR}train_history.csv`` (and wandb
    when enabled) and saving of model / EMA / optimizer / scaler state dicts.
    During the first ``WARMUP_EPOCHS`` epochs only the classifier head is
    trained; afterwards the backbone is unfrozen and the optimizer and data
    loaders are rebuilt. After the last epoch the EMA model is validated once
    and the result is written to ``{MODEL_DIR}ema_results.txt``.
    """
    # BATCH_SIZE is rebound when warm-up ends; get_dataloaders() reads the global.
    global BATCH_SIZE

    device = torch.device('cuda:0')
    torch.cuda.set_device(device)

    # prepare the datasets
    train_loader, valid_loader = get_dataloaders(imgsize_train=TRANSFORMS['IMAGE_SIZE_TRAIN'],
                                                 imgsize_val=TRANSFORMS['IMAGE_SIZE_VAL'],
                                                 rand_aug=True)

    # instantiate the model
    model = Model().to(device)
    #load_checkpoints(model=model)
    if RESUME_EPOCH > 0:
        resume_checkpoints(model=model)
    ema_model = ModelEmaV2(model, decay=0.9998, device=device)
    warmup_start(model)

    # Optimizer & Schedules & early stopping
    # NOTE(review): when resuming at an epoch past WARMUP_EPOCHS the
    # `epoch == WARMUP_EPOCHS` toggle below never fires, so the backbone stays
    # frozen and the warm-up optimizer is kept - confirm resume is only used
    # during warm-up, or rebuild the post-warm-up state here.
    optimizer = get_warmup_optimizer(model)
    scaler = GradScaler()
    #load_checkpoints(optimizer=optimizer, scaler=scaler)
    if RESUME_EPOCH > 0:
        resume_checkpoints(optimizer=optimizer, scaler=scaler)

    loss_fn = nn.CrossEntropyLoss()
    loss_val_fn = nn.CrossEntropyLoss()

    # running metrics during training (reused for validation after a reset)
    loss_metric = MeanMetric().to(device)
    metrics = MetricCollection(metrics={
        'acc': MulticlassAccuracy(num_classes=NUM_CLASSES, average='macro'),
        'top3_acc': MulticlassAccuracy(num_classes=NUM_CLASSES, average='macro', top_k=3),
        'f1': MulticlassF1Score(num_classes=NUM_CLASSES, average='macro')
    }).to(device)
    # F1 on predictions masked with the code-class mapping
    metric_ccm = MulticlassF1Score(num_classes=NUM_CLASSES, average='macro').to(device)

    # start time of training
    start_training = time.perf_counter()
    # create log dict
    logs = {'loss': [], 'acc': [], 'acc_top3': [], 'f1': [], 'f1country': [], 'val_loss': [], 'val_acc': [], 'val_acc_top3': [], 'val_f1': [], 'val_f1country': []}
    if RESUME_EPOCH > 0:
        resume_logs(logs)

    # iterate over epochs
    start_epoch = RESUME_EPOCH+1 if RESUME_EPOCH > 0 else 0
    for epoch in range(start_epoch, NUM_EPOCHS):
        # start time of epoch
        epoch_start = time.perf_counter()
        print(f'Epoch {epoch+1}/{NUM_EPOCHS}')

        ######################## toggle warmup ########################################
        if epoch == WARMUP_EPOCHS:
            warmup_end(model)
            optimizer = get_after_warmup_optimizer(model, optimizer)
            BATCH_SIZE = BATCH_SIZE_AFTER_WARMUP
            # rebuild the loaders so they pick up the new batch sizes
            train_loader, valid_loader = get_dataloaders(imgsize_train=TRANSFORMS['IMAGE_SIZE_TRAIN'],
                                                         imgsize_val=TRANSFORMS['IMAGE_SIZE_VAL'],
                                                         rand_aug=True)
        elif epoch < WARMUP_EPOCHS:
            print(f'--> Warm Up {epoch+1}/{WARMUP_EPOCHS}')

        ############################## train phase ####################################
        model.train()

        # zero the parameter gradients
        optimizer.zero_grad(set_to_none=True)

        # grad acc loss divider
        loss_div = torch.tensor(BATCH_SIZE['grad_acc'], dtype=torch.float16, device=device, requires_grad=False) if BATCH_SIZE['grad_acc'] != 0 else torch.tensor(1.0, dtype=torch.float16, device=device, requires_grad=False)

        # iterate over training batches
        for batch_idx, (inputs, labels, ccm) in enumerate(train_loader):
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            ccm = ccm.to(device, non_blocking=True)

            # forward with mixed precision
            with autocast(device_type='cuda', dtype=torch.float16):
                outputs = model(inputs)
                loss = loss_fn(outputs, labels) / loss_div

            # loss backward
            scaler.scale(loss).backward()

            # Compute metrics (undo the grad-acc division so the logged loss is per batch)
            loss_metric.update((loss * loss_div).detach())

            preds = outputs.softmax(dim=-1).detach()
            metrics.update(preds, labels)
            metric_ccm.update(preds * ccm, labels)

            ############################ grad acc ##############################
            if (batch_idx+1) % BATCH_SIZE['grad_acc'] == 0:
                #scaler.unscale_(optimizer)
                #torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # optimize with gradient clipping to 1 with mixed precision
                scaler.step(optimizer)
                scaler.update()
                # zero the parameter gradients
                optimizer.zero_grad(set_to_none=True)
                # update ema model
                ema_model.update(model)

        # compute, sync & reset metrics for validation
        epoch_loss = loss_metric.compute()
        epoch_metrics = metrics.compute()
        epoch_metric_ccm = metric_ccm.compute()

        loss_metric.reset()
        metrics.reset()
        metric_ccm.reset()

        # Append metric results to logs
        logs['loss'].append(epoch_loss.cpu().item())
        logs['acc'].append(epoch_metrics['acc'].cpu().item())
        logs['acc_top3'].append(epoch_metrics['top3_acc'].cpu().item())
        logs['f1'].append(epoch_metrics['f1'].cpu().item())
        logs['f1country'].append(epoch_metric_ccm.detach().cpu().item())

        print(f"loss: {logs['loss'][epoch]:.5f}, acc: {logs['acc'][epoch]:.5f}, acc_top3: {logs['acc_top3'][epoch]:.5f}, f1: {logs['f1'][epoch]:.5f}, f1country: {logs['f1country'][epoch]:.5f}", end=' || ')

        # zero the parameter gradients (drops gradients of an incomplete accumulation window)
        optimizer.zero_grad(set_to_none=True)

        del inputs, labels, ccm, preds, outputs, loss, loss_div, epoch_loss, epoch_metrics, epoch_metric_ccm
        torch.cuda.empty_cache()

        ############################## valid phase ####################################
        with torch.no_grad():
            model.eval()

            # iterate over validation batches
            for (inputs, labels, ccm) in valid_loader:
                inputs = inputs.to(device, non_blocking=True)
                # fold the five crops of each image into the batch dimension
                inputs = inputs.view(-1, 3, TRANSFORMS['IMAGE_SIZE_VAL'], TRANSFORMS['IMAGE_SIZE_VAL'])
                labels = labels.to(device, non_blocking=True)
                ccm = ccm.to(device, non_blocking=True)

                # forward with mixed precision
                with autocast(device_type='cuda', dtype=torch.float16):
                    outputs = model(inputs)
                    # average the logits over the five crops of each image
                    outputs = outputs.view(-1, 5, NUM_CLASSES).mean(1)
                    loss = loss_val_fn(outputs, labels)

                # Compute metrics
                loss_metric.update(loss.detach())

                preds = outputs.softmax(dim=-1).detach()
                metrics.update(preds, labels)
                metric_ccm.update(preds * ccm, labels)

            # compute, sync & reset metrics for validation
            epoch_loss = loss_metric.compute()
            epoch_metrics = metrics.compute()
            epoch_metric_ccm = metric_ccm.compute()

            loss_metric.reset()
            metrics.reset()
            metric_ccm.reset()

        # Append metric results to logs
        logs['val_loss'].append(epoch_loss.cpu().item())
        logs['val_acc'].append(epoch_metrics['acc'].cpu().item())
        logs['val_acc_top3'].append(epoch_metrics['top3_acc'].cpu().item())
        logs['val_f1'].append(epoch_metrics['f1'].cpu().item())
        logs['val_f1country'].append(epoch_metric_ccm.detach().cpu().item())

        print(f"val_loss: {logs['val_loss'][epoch]:.5f}, val_acc: {logs['val_acc'][epoch]:.5f}, val_acc_top3: {logs['val_acc_top3'][epoch]:.5f}, val_f1: {logs['val_f1'][epoch]:.5f}, val_f1country: {logs['val_f1country'][epoch]:.5f}", end=' || ')

        del inputs, labels, ccm, preds, outputs, loss, epoch_loss, epoch_metrics, epoch_metric_ccm
        torch.cuda.empty_cache()

        # save logs as csv
        logs_df = pd.DataFrame(logs)
        logs_df.to_csv(f'{MODEL_DIR}train_history.csv', index_label='epoch', sep=',', encoding='utf-8')

        if WANDB:
            # at the end of each epoch, log anything you want to log for that epoch
            wandb.log(
                {k: v[epoch] for k, v in logs.items()},  # metric values of the current epoch
                step=epoch  # epoch index for wandb
            )

        # save trained model for each epoch
        torch.save(model.state_dict(), f'{MODEL_DIR}model_epoch{epoch}.pth')
        torch.save(ema_model.module.state_dict(), f'{MODEL_DIR}ema_model_epoch{epoch}.pth')
        torch.save(optimizer.state_dict(), f'{MODEL_DIR}optimizer_epoch{epoch}.pth')
        torch.save(scaler.state_dict(), f'{MODEL_DIR}mp_scaler_epoch{epoch}.pth')

        # end time of epoch
        epoch_end = time.perf_counter()
        print(f"epoch runtime: {epoch_end-epoch_start:5.3f} sec.")

        del logs_df, epoch_start, epoch_end
        torch.cuda.empty_cache()

    ################################## EMA Model Validation ################################
    del model
    torch.cuda.empty_cache()

    ema_net = ema_model.module
    ema_net.eval()

    with torch.no_grad():
        # iterate over validation batches
        for (inputs, labels, ccm) in valid_loader:
            inputs = inputs.to(device, non_blocking=True)
            inputs = inputs.view(-1, 3, TRANSFORMS['IMAGE_SIZE_VAL'], TRANSFORMS['IMAGE_SIZE_VAL'])
            labels = labels.to(device, non_blocking=True)
            ccm = ccm.to(device, non_blocking=True)

            # forward with mixed precision
            with autocast(device_type='cuda', dtype=torch.float16):
                # Model.forward takes the image batch only; the extra positional
                # `None` that used to be passed here raised a TypeError.
                outputs = ema_net(inputs)
                outputs = outputs.view(-1, 5, NUM_CLASSES).mean(1)
                loss = loss_val_fn(outputs, labels)

            # Compute metrics
            loss_metric.update(loss.detach())

            preds = outputs.softmax(dim=-1).detach()
            metrics.update(preds, labels)
            metric_ccm.update(preds * ccm, labels)

        # compute, sync & reset metrics for validation
        epoch_loss = loss_metric.compute()
        epoch_metrics = metrics.compute()
        epoch_metric_ccm = metric_ccm.compute()

        loss_metric.reset()
        metrics.reset()
        metric_ccm.reset()

    ema_summary = f"ema_loss: {epoch_loss.cpu().item():.5f}, ema_acc: {epoch_metrics['acc'].cpu().item():.5f}, ema_acc_top3: {epoch_metrics['top3_acc'].cpu().item():.5f}, ema_f1: {epoch_metrics['f1'].cpu().item():.5f}, ema_f1country: {epoch_metric_ccm.detach().cpu().item():.5f}"
    print(ema_summary)

    with open(f'{MODEL_DIR}ema_results.txt', 'w') as f:
        print(ema_summary, file=f)

    plot_history(logs)
    # end time of training
    end_training = time.perf_counter()
    print(f'Training succeeded in {(end_training - start_training):5.3f}s')

    if WANDB:
        wandb.finish()
674
+
675
+
676
# Run the training only when the file is executed as a script.
if __name__ == "__main__":
    main()
678
+
679
+
680
+
681
+
exp3/convnext2b_exp3_metaEmbedding.py ADDED
@@ -0,0 +1,731 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, time, pickle, shutil
2
+ import pandas as pd
3
+ import numpy as np
4
+
5
+ from PIL import Image, ImageFile
6
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from torch.cuda.amp import GradScaler
12
+ from torch import autocast
13
+
14
+ import torchvision.transforms as transforms
15
+
16
+ import timm
17
+ from timm.models import create_model
18
+ from timm.utils import ModelEmaV2
19
+
20
+ from timm.optim import create_optimizer_v2
21
+
22
+
23
+ from torchmetrics import MeanMetric
24
+ from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score
25
+ from torchmetrics import MetricCollection
26
+
27
+
28
+ import wandb
29
+
30
+ import matplotlib.pyplot as plt
31
+
32
+
33
+ # ### parameters
34
+ ################## Settings #############################
35
+ #os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
36
+ torch.backends.cudnn.benchmark = True  # enable the cuDNN autotuner
37
+
38
+ ################## Data Paths ##########################
39
+ MODEL_DIR = "./convnext2b_meta_embedding/"  # output folder; used as an f-string prefix elsewhere, so keep the trailing slash
40
+
41
+ if not os.path.exists(MODEL_DIR):  # create the output folder if it does not exist yet
42
+ os.makedirs(MODEL_DIR)
43
+ shutil.copyfile('./convnext2b_exp3_metaEmbedding.py', f'{MODEL_DIR}convnext2b_exp3_metaEmbedding.py')  # copy this script into the output folder; NOTE(review): indentation is lost in this view - confirm whether the copy runs on every start or only when the folder is created
44
+
45
# image folders (prefixed to the 'image_path' column of the metadata tables)
TRAIN_DATA_DIR = "/SnakeCLEF2023-large_size/"  # train imgs. path
ADD_TRAIN_DATA_DIR = "/HMP/"  # add. train imgs. path
VAL_DATA_DIR = "/SnakeCLEF2023-large_size/"  # val imgs. path

# metadata tables
TRAINDATA_CONFIG = "/SnakeCLEF2023-TrainMetadata-iNat.csv"
ADD_TRAINDATA_CONFIG = "/SnakeCLEF2023-TrainMetadata-HM.csv"
VALIDDATA_CONFIG = "/SnakeCLEF2023-ValMetadata.csv"

# csv with missing img. files that will be filtered out
MISSING_FILES = "../missing_train_data.csv"

# csv to metadata code to snake species dist.
CCM = "../code_class_mapping_obid.csv"


NUM_CLASSES = 1784

################## Hyperparameters ########################
NUM_EPOCHS = 40
WARMUP_EPOCHS = 5  # num. epochs only training classification head of model
RESUME_EPOCH = 0


LEARNING_RATE = dict(
    cnn=1e-05,
    embeddings=1e-04,
    classifier=1e-04,
)

# 'grad_acc': gradient acc. steps with 'train' of batch sizes,
# global batch size = 'grad_acc' * 'train'
BATCH_SIZE = dict(train=32, valid=48, grad_acc=4)

# batch sizes used once the warm-up phase is over
BATCH_SIZE_AFTER_WARMUP = dict(train=32, valid=48, grad_acc=4)

TRANSFORMS = dict(
    IMAGE_SIZE_TRAIN=544,
    IMAGE_SIZE_VAL=544,
    RandAug=dict(m=7, n=2),
)

############# Checkpoints ####################
CHECKPOINTS = dict(
    # iNaturalist pre-trained model checkpoints available at
    # "https://huggingface.co/BBracke/convnextv2_base.inat21_384"
    fe_cnn=None,
    model=None,
    optimizer=None,
    scaler=None,
)
100
+
101
+ # ####### Embedding Token Mappings ########################
102
META_SIZES = {'endemic': 2, 'code': 212}
EMBEDDING_SIZES = {'endemic': 64, 'code': 64}


def _load_pickle(path):
    """Return the object stored in the pickle file at *path*.

    The file is opened in a ``with`` block so the handle is closed right after
    loading (the previous ``pickle.load(open(...))`` left it open).
    """
    # NOTE: unpickling can execute arbitrary code - only load trusted files.
    with open(path, "rb") as handle:
        return pickle.load(handle)


CODE_TOKENS = _load_pickle("../meta_code_tokens.p")
ENDEMIC_TOKENS = _load_pickle("../meta_endemic_tokens.p")
107
+
108
+ ################### WandB ##################
109
# Toggle Weights & Biases logging for this run.
WANDB = False

if WANDB:
    wandb.init(
        # our team at wandb
        entity="snakeclef2023",
        # wandb project where this run will be logged
        # -> define sub-projects here, e.g. experiments with MetaFormer or CNNs...
        project="exp3",
        # name of this run
        name="meta_embedding",
        # hyperparameters to track; config is a plain dict, any key:value pairs are possible
        config={
            "learning_rate": LEARNING_RATE,
            "architecture": "convnextv2_base.fcmae_ft_in22k_in1k_384",
            "pretrained": "iNat21",
            "dataset": f"snakeclef2023, additional train data: {bool(ADD_TRAINDATA_CONFIG)}",
            "epochs": NUM_EPOCHS,
            "transforms": TRANSFORMS,
            "checkpoints": CHECKPOINTS,
            "model_dir": MODEL_DIR,
            # ... any other hyperparameter that is necessary to reproduce the result
        },
        # save the script file as backup
        save_code=True,
        # local folder where wandb log files are saved
        dir=MODEL_DIR,
    )
136
+
137
+
138
+
139
+
140
+ ##################### Dataset & AugTransforms #####################################
141
+ # ### dataset & loaders
142
class SnakeTrainDataset(Dataset):
    """Dataset yielding ``(image, label, code-class mapping, metadata tokens)``.

    Args:
        data: table with one row per image; rows are read through ``.iloc``
            and must expose ``class_id``, ``code``, ``endemic`` and ``image_path``.
        ccm: table with one column per metadata code.
        transform: callable applied to the loaded RGB PIL image.
    """

    def __init__(self, data, ccm, transform=None):
        self.data = data
        self.transform = transform  # Image augmentation pipeline
        self.code_class_mapping = ccm
        # token lookups come from the module-level mappings
        self.code_tokens = CODE_TOKENS
        self.endemic_tokens = ENDEMIC_TOKENS

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        row = self.data.iloc[index]
        label = row.class_id

        # fall back to "unknown" / False for metadata values without a token
        if row.code in self.code_tokens.keys():
            code = row.code
        else:
            code = "unknown"
        if row.endemic in self.endemic_tokens.keys():
            endemic = row.endemic
        else:
            endemic = False

        img = Image.open(row.image_path).convert("RGB")  # load image
        ccm = torch.tensor(self.code_class_mapping[code].to_numpy())  # code class mapping
        meta = torch.tensor([self.code_tokens[code], self.endemic_tokens[endemic]])  # metadata tokens

        # img. augmentation
        img = self.transform(img)

        return (img, label, ccm, meta)
167
+
168
+
169
+ # valid data preprocessing pipeline
170
def get_val_preprocessing(img_size):
    """Build the validation preprocessing pipeline.

    The image is resized to 1.25x ``img_size``, split into five
    ``img_size`` x ``img_size`` crops, stacked into one 4D tensor and normalized.
    """
    print(f'IMG_SIZE_VAL: {img_size}')

    enlarge = transforms.Resize(int(img_size * 1.25))  # expand before the five-crop
    five_crop_to_tensor = transforms.Compose([
        transforms.FiveCrop((img_size, img_size)),  # this is a list of PIL Images
        # returns a 4D tensor (one entry per crop)
        transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
    ])
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    return transforms.Compose([enlarge, five_crop_to_tensor, normalize])
180
+
181
class IdentityTransform:
    """No-op transform: calling it returns its argument unchanged."""

    def __call__(self, x):
        return x
184
+
185
+
186
+ # train data augmentation/ preprocessing pipeline
187
def get_train_augmentation_preprocessing(img_size, rand_aug=False):
    """Build the training augmentation / preprocessing pipeline.

    Resize to 1.25x ``img_size``, random flips, random ``img_size`` crop,
    optional RandAugment (settings from ``TRANSFORMS['RandAug']``), tensor
    conversion and normalization.
    """
    print(f'IMG_SIZE_TRAIN: {img_size}, RandAug: {rand_aug}')

    if rand_aug:
        aug_step = transforms.RandAugment(num_ops=TRANSFORMS['RandAug']['n'],
                                          magnitude=TRANSFORMS['RandAug']['m'])
    else:
        aug_step = IdentityTransform()

    steps = [
        transforms.Resize(int(img_size * 1.25)),  # expand before the random crop
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomCrop((img_size, img_size)),  # random crop to img_size
        aug_step,
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    return transforms.Compose(steps)
198
+
199
+
200
def get_datasets(train_transfroms, val_transforms):
    """Load the metadata tables and build the train and validation datasets.

    Rows listed in the ``MISSING_FILES`` table are removed from the training
    table, image paths are prefixed with their data directories, optional
    additional training data is appended, and both tables are shuffled with a
    fixed seed before being wrapped in ``SnakeTrainDataset``.

    Returns:
        ``(train_dataset, valid_dataset)``
    """
    nan_values = ['', '#N/A', '#N/A N/A', '#NA', '-1.#IND', '-1.#QNAN', '-NaN', '-nan', '1.#IND', '1.#QNAN', '<NA>', 'N/A', 'NULL', 'NaN', 'n/a', 'nan', 'null']

    def read_table(path):
        # every CSV is parsed with the same explicit NaN marker list
        return pd.read_csv(path, na_values=nan_values, keep_default_na=False)

    train_data = read_table(TRAINDATA_CONFIG)
    missing_train_data = read_table(MISSING_FILES)
    valid_data = read_table(VALIDDATA_CONFIG)

    # anti-join: keep only training rows that do not appear in the missing-files table
    kept_columns = ["observation_id","endemic","binomial_name","code","image_path","class_id","subset"]
    merged = pd.merge(train_data, missing_train_data, how='outer', indicator=True)
    train_data = merged.loc[merged._merge == 'left_only', kept_columns]

    # prefix the image paths with their data directories
    train_data["image_path"] = TRAIN_DATA_DIR + train_data['image_path']
    valid_data["image_path"] = VAL_DATA_DIR + valid_data['image_path']

    # append the additional training data, if configured
    if ADD_TRAINDATA_CONFIG:
        add_train_data = read_table(ADD_TRAINDATA_CONFIG)
        add_train_data["image_path"] = ADD_TRAIN_DATA_DIR + add_train_data['image_path']
        train_data = pd.concat([train_data, add_train_data], axis=0)

    print(f'train data shape: {train_data.shape}')

    # shuffle both tables with a fixed seed
    train_data = train_data.sample(frac=1, random_state=1).reset_index(drop=True)
    valid_data = valid_data.sample(frac=1, random_state=1).reset_index(drop=True)

    # code-class-mapping table, indexed by code column in the dataset
    ccm = read_table(CCM)

    train_dataset = SnakeTrainDataset(train_data, ccm, transform=train_transfroms)
    valid_dataset = SnakeTrainDataset(valid_data, ccm, transform=val_transforms)

    return train_dataset, valid_dataset
238
+
239
+
240
def get_dataloaders(imgsize_train, imgsize_val, rand_aug):
    """Create the train and validation ``DataLoader`` objects.

    Batch sizes are read from ``BATCH_SIZE['train']`` / ``BATCH_SIZE['valid']``
    at call time. The train loader shuffles and drops the last incomplete
    batch; the validation loader does neither.

    Returns:
        ``(train_loader, valid_loader)``
    """
    train_pipeline = get_train_augmentation_preprocessing(imgsize_train, rand_aug)
    val_pipeline = get_val_preprocessing(imgsize_val)

    train_dataset, valid_dataset = get_datasets(train_transfroms=train_pipeline, val_transforms=val_pipeline)

    shared_kwargs = dict(num_workers=6, pin_memory=True)
    train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE['train'], drop_last=True, **shared_kwargs)
    valid_loader = DataLoader(dataset=valid_dataset, shuffle=False, batch_size=BATCH_SIZE['valid'], drop_last=False, **shared_kwargs)

    return train_loader, valid_loader
250
+
251
+
252
+ # #################### plot train history #########################
253
+
254
def plot_history(logs):
    """Plot train vs. validation loss, accuracy and F1 and save the figure.

    Reads ``logs['loss' | 'acc' | 'f1']`` and their ``val_`` counterparts,
    draws one panel per metric, saves the figure as
    ``{MODEL_DIR}model_history.svg`` and shows it.
    """
    fig, ax = plt.subplots(3, 1, figsize=(8, 12))

    # (log key, axis label, y-range); the loss range ends at the loss of a uniform prediction
    panels = [
        ('loss', "loss", [0, -np.log(1/NUM_CLASSES)]),
        ('acc', "accuracy", [0, 1.01]),
        ('f1', "f1", [0, 1.01]),
    ]
    for axis, (key, name, y_range) in zip(ax, panels):
        axis.plot(logs[key], label="train data")
        axis.plot(logs[f'val_{key}'], label="valid data")
        axis.legend(loc="best")
        axis.set_ylabel(name)
        axis.set_ylim(y_range)
        axis.set_title(f"train- vs. valid {name}")

    # only the bottom panel carries the x-axis label
    ax[2].set_xlabel("epochs")

    fig.savefig(f'{MODEL_DIR}model_history.svg', dpi=150, format="svg")
    plt.show()
283
+
284
+
285
+ # #################### Model #####################################
286
+
287
class FeatureExtractor(nn.Module):
    """Image feature extractor wrapping a pretrained ConvNeXt-V2 base backbone.

    The backbone is created with ``num_classes=0`` and, if
    ``CHECKPOINTS['fe_cnn']`` is set, its weights are overwritten from that
    checkpoint file.
    """

    def __init__(self):
        super().__init__()
        self.conv_backbone = create_model(
            'convnextv2_base.fcmae_ft_in22k_in1k_384',
            pretrained=True,
            num_classes=0,
            drop_path_rate=0.2,
        )
        backbone_ckpt = CHECKPOINTS['fe_cnn']
        if backbone_ckpt:
            state = torch.load(backbone_ckpt, map_location='cpu')
            self.conv_backbone.load_state_dict(state, strict=True)
            print(f"use FE_CHECKPOINTS: {backbone_ckpt}")
        torch.cuda.empty_cache()

    def forward(self, img):
        """Return the backbone output for the image batch ``img``."""
        return self.conv_backbone(img)
299
+
300
+
301
class MetaEmbeddings(nn.Module):
    """Embed the ``code`` and ``endemic`` metadata tokens and mix them with an MLP.

    Args:
        embedding_sizes: embedding width per metadata field (keys ``'endemic'``, ``'code'``).
        meta_sizes: number of tokens per metadata field (same keys).
        dropout: dropout probability between the two MLP stages; a falsy value disables it.
    """

    def __init__(self, embedding_sizes: dict, meta_sizes: dict, dropout: float = None):
        super().__init__()
        self.endemic_embedding = nn.Embedding(meta_sizes['endemic'], embedding_sizes['endemic'], max_norm=1.0)
        self.code_embedding = nn.Embedding(meta_sizes['code'], embedding_sizes['code'], max_norm=1.0)

        # width of the concatenated embeddings; also the output width of the MLP
        width = sum(embedding_sizes.values())
        self.dim_embedding = width

        mid_layer = nn.Dropout(p=dropout, inplace=False) if dropout else nn.Identity()
        self.embedding_net = nn.Sequential(
            nn.Linear(in_features=width, out_features=width, bias=True),
            nn.GELU(),
            nn.LayerNorm(width, eps=1e-06),
            mid_layer,
            nn.Linear(in_features=width, out_features=width, bias=True),
            nn.GELU(),
            nn.LayerNorm(width, eps=1e-06),
        )

    def forward(self, meta):
        """Map ``meta`` (column 0: code token, column 1: endemic token) to features."""
        code_tokens = meta[:, 0]
        endemic_tokens = meta[:, 1]

        stacked = torch.cat(
            [self.code_embedding(code_tokens), self.endemic_embedding(endemic_tokens)],
            dim=-1,
        )
        return self.embedding_net(stacked)
326
+
327
+
328
class Classifier(nn.Module):
    """Linear classification head with optional input dropout.

    Args:
        num_classes: number of output logits.
        dim_embeddings: width of the incoming feature vector.
        dropout: dropout probability applied to the input; a falsy value disables it.
    """

    def __init__(self, num_classes: int, dim_embeddings: int, dropout: float = None):
        super().__init__()
        if dropout:
            self.dropout = nn.Dropout(p=dropout, inplace=False)
        else:
            self.dropout = nn.Identity()
        self.classifier = nn.Linear(in_features=dim_embeddings, out_features=num_classes, bias=True)

    def forward(self, embeddings):
        """Return the class logits for ``embeddings``."""
        return self.classifier(self.dropout(embeddings))
339
+
340
+
341
class Model(nn.Module):
    """Full network: image backbone + metadata embedding MLP + linear classifier.

    Image features and metadata features are concatenated along the last
    dimension and passed to the classifier.
    """

    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.embedding_net = MetaEmbeddings(embedding_sizes=EMBEDDING_SIZES, meta_sizes=META_SIZES, dropout=0.25)

        # Derive the classifier input width from the sub-modules instead of
        # hard-coding 1024+128, so changing EMBEDDING_SIZES or the backbone
        # does not silently produce a shape mismatch in forward().
        dim_img = self.feature_extractor.conv_backbone.num_features
        dim_meta = self.embedding_net.dim_embedding
        self.classifier = Classifier(num_classes=NUM_CLASSES, dim_embeddings=dim_img + dim_meta, dropout=0.25)

    def forward(self, img, meta):
        """Return class logits for an image batch and its metadata tokens.

        Args:
            img: image batch passed to the feature extractor.
            meta: metadata token tensor passed to the embedding network.
        """
        img_features = self.feature_extractor(img)
        meta_features = self.embedding_net(meta)

        cat_features = torch.concat([img_features, meta_features], dim=-1)
        return self.classifier(cat_features)
356
+
357
+
358
def load_checkpoints(model=None, optimizer=None, scaler=None):
    """Load state dicts from the paths configured in ``CHECKPOINTS``.

    An object is restored only when it is passed in and its ``CHECKPOINTS``
    entry is truthy. Tensors are mapped to the CPU while loading.
    """
    targets = (
        ('model', model),
        ('optimizer', optimizer),
        ('scaler', scaler),
    )
    for key, target in targets:
        path = CHECKPOINTS[key]
        if not path or target is None:
            continue
        target.load_state_dict(torch.load(path, map_location='cpu'))
        print(f"use {key} checkpoints: {path}")
    torch.cuda.empty_cache()
369
+
370
def resume_checkpoints(model=None, optimizer=None, scaler=None):
    """Restore the passed objects from the ``RESUME_EPOCH`` files in ``MODEL_DIR``.

    File names follow ``{MODEL_DIR}{stem}_epoch{RESUME_EPOCH}.pth`` with the
    stems ``model``, ``optimizer`` and ``mp_scaler``. Objects left as ``None``
    are skipped.
    """
    targets = (
        ('model', 'model', model),
        ('optimizer', 'optimizer', optimizer),
        ('scaler', 'mp_scaler', scaler),
    )
    for label, stem, target in targets:
        if target is None:
            continue
        path = f'{MODEL_DIR}{stem}_epoch{RESUME_EPOCH}.pth'
        target.load_state_dict(torch.load(path, map_location='cpu'))
        print(f"use {label} checkpoints: {path}")
    torch.cuda.empty_cache()
382
+
383
+
384
def resume_logs(logs):
    """Pre-fill ``logs`` with the history of the epochs that are being resumed.

    Reads ``{MODEL_DIR}train_history.csv`` and extends every list in ``logs``
    in place with the matching column.

    Only the rows of epochs ``0..RESUME_EPOCH`` are taken over: the training
    loop indexes ``logs[...][epoch]``, so any extra rows in the file (e.g.
    written by a run that continued past the checkpoint being resumed) would
    shift every later entry and make the loop print/log stale values.

    Args:
        logs: dict mapping metric name to a list of per-epoch values.
    """
    old_logs = pd.read_csv(f"{MODEL_DIR}train_history.csv")
    old_logs = old_logs.iloc[:RESUME_EPOCH + 1]
    for metric_name, values in logs.items():
        values.extend(old_logs[metric_name].tolist())
388
+
389
+ ######################## Optimizer #####################################
390
def get_optm_group(module):
    """
    Split the parameters of ``module`` into two buckets: names that should
    receive weight decay (weights of the whitelisted module types) and names
    that should not (all biases, and weights of normalization / embedding
    modules).

    Returns:
        ``(param_dict, decay, no_decay)`` where ``param_dict`` maps the full
        parameter name to the parameter and ``decay`` / ``no_decay`` are sets
        of full parameter names. No optimizer is created here.

    Raises:
        AssertionError: if a parameter lands in both buckets or in neither.
    """
    decay, no_decay = set(), set()
    whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv1d, timm.layers.GlobalResponseNormMlp)
    blacklist_weight_modules = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.LayerNorm, torch.nn.Embedding)

    for module_name, submodule in module.named_modules():
        # named_parameters() recurses, so a nested parameter is visited once per
        # ancestor module; the sets absorb the repeated names
        for param_name, _ in submodule.named_parameters():
            full_name = f'{module_name}.{param_name}' if module_name else param_name

            if param_name.endswith('bias'):
                # biases are never decayed
                no_decay.add(full_name)
                continue
            if not param_name.endswith('weight'):
                continue
            if isinstance(submodule, whitelist_weight_modules):
                decay.add(full_name)
            elif isinstance(submodule, blacklist_weight_modules):
                no_decay.add(full_name)

    # every parameter must end up in exactly one bucket
    param_dict = dict(module.named_parameters())
    overlap = decay & no_decay
    unassigned = param_dict.keys() - (decay | no_decay)
    assert len(overlap) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(overlap), )
    assert len(unassigned) == 0, "parameters %s were not separated into either decay/no_decay set!" \
        % (str(unassigned), )

    return param_dict, decay, no_decay
427
+
428
+
429
def get_warmup_optimizer(model):
    """Create the AdamW optimizer used during warm-up.

    Only ``model.embedding_net`` and ``model.classifier`` are optimized. Each
    contributes two parameter groups, in this order: decayed weights
    (``weight_decay=0.05``) and non-decayed parameters (``weight_decay=0.0``),
    both with the learning rate configured for that sub-module.
    """
    heads = (
        (model.embedding_net, 'embeddings'),
        (model.classifier, 'classifier'),
    )

    params_group = []
    for head, lr_key in heads:
        param_dict, decay, no_decay = get_optm_group(head)
        for names, weight_decay in ((decay, 0.05), (no_decay, 0.0)):
            params_group.append({
                "params": [param_dict[name] for name in sorted(names)],
                "weight_decay": weight_decay,
                'lr': LEARNING_RATE[lr_key],
            })

    return torch.optim.AdamW(params_group)
442
+
443
+
444
def get_after_warmup_optimizer(model, old_opt):
    """Create the optimizer used once warm-up is over.

    A new AdamW optimizer with layer-wise learning-rate decay is built for the
    conv backbone, then every parameter group of ``old_opt`` is appended to
    it. Only the parameter groups are taken from ``old_opt``; its per-parameter
    optimizer state is not copied.
    """
    backbone = model.feature_extractor.conv_backbone
    optimizer = create_optimizer_v2(
        backbone,
        opt='adamw',
        filter_bias_and_bn=True,
        weight_decay=1e-8,
        layer_decay=0.85,
        lr=LEARNING_RATE['cnn'],
    )

    # carry the warm-up parameter groups over
    for param_group in old_opt.param_groups:
        optimizer.add_param_group(param_group)

    return optimizer
452
+
453
+
454
+ # #################### Model Warmup #####################################
455
+
456
def warmup_start(model):
    """Freeze every parameter of ``model.feature_extractor.conv_backbone``."""
    backbone = model.feature_extractor.conv_backbone
    for param in backbone.parameters():
        param.requires_grad = False
    print('--> freeze feature_extractor.conv_backbone during warmup phase')
461
+
462
def warmup_end(model):
    """Unfreeze every parameter of ``model.feature_extractor.conv_backbone``."""
    backbone = model.feature_extractor.conv_backbone
    for param in backbone.parameters():
        param.requires_grad = True
    print('--> unfreeze feature_extractor.conv_backbone after warmup phase')
467
+
468
+
469
+ # #################### Train Loop #####################################
470
+
471
+ # ### train
472
def _run_validation(net, valid_loader, device, loss_val_fn, loss_metric, metrics, metric_ccm):
    """Evaluate ``net`` on ``valid_loader`` and return the aggregated results.

    Every validation sample arrives as a stack of 5 crops: the crops are
    flattened into the batch dimension, the metadata is repeated 5 times to
    match, and the 5 outputs per sample are averaged before the loss and the
    metrics are computed. All metric objects are reset before returning.

    Returns:
        ``(epoch_loss, epoch_metrics, epoch_metric_ccm)``
    """
    net.eval()
    with torch.no_grad():
        for (inputs, labels, ccm, meta) in valid_loader:
            inputs = inputs.to(device, non_blocking=True)
            inputs = inputs.view(-1, 3, TRANSFORMS['IMAGE_SIZE_VAL'], TRANSFORMS['IMAGE_SIZE_VAL'])
            meta = meta.to(device, non_blocking=True)
            meta = torch.repeat_interleave(meta, repeats=5, dim=0)
            labels = labels.to(device, non_blocking=True)
            ccm = ccm.to(device, non_blocking=True)

            # forward with mixed precision
            with autocast(device_type='cuda', dtype=torch.float16):
                outputs = net(inputs, meta)
                outputs = outputs.view(-1, 5, NUM_CLASSES).mean(1)
                loss = loss_val_fn(outputs, labels)

            # update running metrics
            loss_metric.update(loss.detach())

            preds = outputs.softmax(dim=-1).detach()
            metrics.update(preds, labels)
            metric_ccm.update(preds * ccm, labels)

        # compute & reset metrics
        epoch_loss = loss_metric.compute()
        epoch_metrics = metrics.compute()
        epoch_metric_ccm = metric_ccm.compute()

        loss_metric.reset()
        metrics.reset()
        metric_ccm.reset()

    return epoch_loss, epoch_metrics, epoch_metric_ccm


def main():
    """Train the model, validate it every epoch and finally validate the EMA model.

    Per epoch: optional switch from the warm-up setup (frozen backbone) to the
    full setup, one training pass with mixed precision and gradient
    accumulation, one validation pass, then the history CSV and the
    model / EMA / optimizer / scaler state dicts are written to ``MODEL_DIR``.
    After the last epoch the EMA weights are validated, the result is written
    to ``ema_results.txt`` and the training history is plotted.
    """
    # BATCH_SIZE is rebound when the warm-up phase ends
    global BATCH_SIZE

    device = torch.device(f'cuda:0')
    torch.cuda.set_device(device)

    # prepare the datasets
    train_loader, valid_loader = get_dataloaders(imgsize_train=TRANSFORMS['IMAGE_SIZE_TRAIN'],
                                                 imgsize_val=TRANSFORMS['IMAGE_SIZE_VAL'],
                                                 rand_aug=True)

    # instantiate the model
    model = Model().to(device)
    #load_checkpoints(model=model)
    if RESUME_EPOCH > 0:
        resume_checkpoints(model=model)
    # NOTE(review): the EMA copy is initialised from the current model weights;
    # the saved ema_model_epoch*.pth files are not loaded on resume -- confirm intended.
    ema_model = ModelEmaV2(model, decay=0.9998, device=device)
    warmup_start(model)

    # Optimizer & Schedules & early stopping
    optimizer = get_warmup_optimizer(model)
    scaler = GradScaler()
    #load_checkpoints(optimizer=optimizer, scaler=scaler)
    # NOTE(review): the optimizer state is loaded into the warm-up optimizer; when
    # resuming from an epoch after warm-up the saved state presumably has more
    # parameter groups, and the warm-up toggle below is never reached -- verify.
    if RESUME_EPOCH > 0:
        resume_checkpoints(optimizer=optimizer, scaler=scaler)

    loss_fn = nn.CrossEntropyLoss() #FocalLoss(gamma=FOCAL_LOSS['gamma'], class_dist=FOCAL_LOSS['class_dist'])
    loss_val_fn = nn.CrossEntropyLoss()

    # running metrics during training
    loss_metric = MeanMetric().to(device)
    metrics = MetricCollection(metrics={
        'acc': MulticlassAccuracy(num_classes=NUM_CLASSES, average='macro'),
        'top3_acc': MulticlassAccuracy(num_classes=NUM_CLASSES, average='macro', top_k=3),
        'f1': MulticlassF1Score(num_classes=NUM_CLASSES, average='macro')
    }).to(device)
    metric_ccm = MulticlassF1Score(num_classes=NUM_CLASSES, average='macro').to(device)

    # start time of training
    start_training = time.perf_counter()
    # create log dict
    logs = {'loss': [], 'acc': [], 'acc_top3': [], 'f1': [], 'f1country': [], 'val_loss': [], 'val_acc': [], 'val_acc_top3': [], 'val_f1': [], 'val_f1country': []}
    if RESUME_EPOCH > 0:
        resume_logs(logs)

    # iterate over epochs
    start_epoch = RESUME_EPOCH+1 if RESUME_EPOCH > 0 else 0
    for epoch in range(start_epoch, NUM_EPOCHS):
        # start time of epoch
        epoch_start = time.perf_counter()
        print(f'Epoch {epoch+1}/{NUM_EPOCHS}')

        ######################## toggle warmup ########################################
        if (epoch) == WARMUP_EPOCHS:
            warmup_end(model)
            optimizer = get_after_warmup_optimizer(model, optimizer)
            BATCH_SIZE = BATCH_SIZE_AFTER_WARMUP
            train_loader, valid_loader = get_dataloaders(imgsize_train=TRANSFORMS['IMAGE_SIZE_TRAIN'],
                                                         imgsize_val=TRANSFORMS['IMAGE_SIZE_VAL'],
                                                         rand_aug=True)

        elif (epoch) < WARMUP_EPOCHS:
            print(f'--> Warm Up {epoch+1}/{WARMUP_EPOCHS}')

        ############################## train phase ####################################
        model.train()

        # zero the parameter gradients
        optimizer.zero_grad(set_to_none=True)

        # A configured value of 0 means "no accumulation". Use the same effective
        # step count for the loss divider AND the step condition, so that 0 no
        # longer raises ZeroDivisionError in the modulo below.
        grad_acc_steps = BATCH_SIZE['grad_acc'] or 1
        loss_div = torch.tensor(grad_acc_steps, dtype=torch.float16, device=device, requires_grad=False)

        # iterate over training batches
        for batch_idx, (inputs, labels, ccm, meta) in enumerate(train_loader):
            inputs = inputs.to(device, non_blocking=True)
            meta = meta.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            ccm = ccm.to(device, non_blocking=True)

            # forward with mixed precision
            with autocast(device_type='cuda', dtype=torch.float16):
                outputs = model(inputs, meta)
                loss = loss_fn(outputs, labels) / loss_div

            # loss backward
            scaler.scale(loss).backward()

            # update running metrics with the undivided loss
            loss_metric.update((loss * loss_div).detach())

            preds = outputs.softmax(dim=-1).detach()
            metrics.update(preds, labels)
            metric_ccm.update(preds * ccm, labels)

            ############################ grad acc ##############################
            if (batch_idx+1) % grad_acc_steps == 0:
                #scaler.unscale_(optimizer)
                #torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # optimize with gradient clipping to 1 with mixed precision
                scaler.step(optimizer)
                scaler.update()
                # zero the parameter gradients
                optimizer.zero_grad(set_to_none=True)
                # update ema model
                ema_model.update(model)

        # compute & reset the training metrics
        epoch_loss = loss_metric.compute()
        epoch_metrics = metrics.compute()
        epoch_metric_ccm = metric_ccm.compute()

        loss_metric.reset()
        metrics.reset()
        metric_ccm.reset()

        # Append metric results to logs
        logs['loss'].append(epoch_loss.cpu().item())
        logs['acc'].append(epoch_metrics['acc'].cpu().item())
        logs['acc_top3'].append(epoch_metrics['top3_acc'].cpu().item())
        logs['f1'].append(epoch_metrics['f1'].cpu().item())
        logs['f1country'].append(epoch_metric_ccm.detach().cpu().item())

        print(f"loss: {logs['loss'][epoch]:.5f}, acc: {logs['acc'][epoch]:.5f}, acc_top3: {logs['acc_top3'][epoch]:.5f}, f1: {logs['f1'][epoch]:.5f}, f1country: {logs['f1country'][epoch]:.5f}", end=' || ')

        # drop gradients left over from an incomplete accumulation window
        optimizer.zero_grad(set_to_none=True)

        del inputs, labels, ccm, meta, preds, outputs, loss, loss_div, epoch_loss, epoch_metrics, epoch_metric_ccm
        torch.cuda.empty_cache()

        ############################## valid phase ####################################
        epoch_loss, epoch_metrics, epoch_metric_ccm = _run_validation(
            model, valid_loader, device, loss_val_fn, loss_metric, metrics, metric_ccm)

        # Append metric results to logs
        logs['val_loss'].append(epoch_loss.cpu().item())
        logs['val_acc'].append(epoch_metrics['acc'].cpu().item())
        logs['val_acc_top3'].append(epoch_metrics['top3_acc'].cpu().item())
        logs['val_f1'].append(epoch_metrics['f1'].cpu().item())
        logs['val_f1country'].append(epoch_metric_ccm.detach().cpu().item())

        print(f"val_loss: {logs['val_loss'][epoch]:.5f}, val_acc: {logs['val_acc'][epoch]:.5f}, val_acc_top3: {logs['val_acc_top3'][epoch]:.5f}, val_f1: {logs['val_f1'][epoch]:.5f}, val_f1country: {logs['val_f1country'][epoch]:.5f}", end=' || ')

        del epoch_loss, epoch_metrics, epoch_metric_ccm
        torch.cuda.empty_cache()

        # save logs as csv
        logs_df = pd.DataFrame(logs)
        logs_df.to_csv(f'{MODEL_DIR}train_history.csv', index_label='epoch', sep=',', encoding='utf-8')

        if WANDB:
            # at the end of each epoch, log anything you want to log for that epoch
            wandb.log(
                {k:v[epoch] for k,v in logs.items()}, # e.g. log each metric value for the current epoch in our defined logs dict
                step=epoch # epoch index for wandb
            )

        # save the training state of this epoch
        torch.save(model.state_dict(), f'{MODEL_DIR}model_epoch{epoch}.pth')
        torch.save(ema_model.module.state_dict(), f'{MODEL_DIR}ema_model_epoch{epoch}.pth')
        torch.save(optimizer.state_dict(), f'{MODEL_DIR}optimizer_epoch{epoch}.pth')
        torch.save(scaler.state_dict(), f'{MODEL_DIR}mp_scaler_epoch{epoch}.pth')

        # end time of epoch
        epoch_end = time.perf_counter()
        print(f"epoch runtime: {epoch_end-epoch_start:5.3f} sec.")

        del logs_df, epoch_start, epoch_end
        torch.cuda.empty_cache()

    ################################## EMA Model Validation ################################
    del model
    torch.cuda.empty_cache()

    ema_net = ema_model.module
    epoch_loss, epoch_metrics, epoch_metric_ccm = _run_validation(
        ema_net, valid_loader, device, loss_val_fn, loss_metric, metrics, metric_ccm)

    ema_summary = f"ema_loss: {epoch_loss.cpu().item():.5f}, ema_acc: {epoch_metrics['acc'].cpu().item():.5f}, ema_acc_top3: {epoch_metrics['top3_acc'].cpu().item():.5f}, ema_f1: {epoch_metrics['f1'].cpu().item():.5f}, ema_f1country: {epoch_metric_ccm.detach().cpu().item():.5f}"
    print(ema_summary)

    with open(f'{MODEL_DIR}ema_results.txt', 'w') as f:
        print(ema_summary, file=f)

    plot_history(logs)
    # end time of training
    end_training = time.perf_counter()
    print(f'Training succeeded in {(end_training - start_training):5.3f}s')

    if WANDB:
        wandb.finish()
724
+
725
+
726
+ if __name__=="__main__":
727
+ main()
728
+
729
+
730
+
731
+
exp4/convnext2b_exp4_meta_embedding_focalarcloss.py ADDED
@@ -0,0 +1,778 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, time, pickle, shutil
2
+ import pandas as pd
3
+ import numpy as np
4
+
5
+ from PIL import Image, ImageFile
6
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from torch.cuda.amp import GradScaler
12
+ from torch import autocast
13
+
14
+ import torchvision.transforms as transforms
15
+
16
+ import timm
17
+ from timm.models import create_model
18
+ from timm.utils import ModelEmaV2
19
+
20
+ from timm.optim import create_optimizer_v2
21
+
22
+ from torchmetrics import MeanMetric
23
+ from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score
24
+ from torchmetrics import MetricCollection
25
+
26
+ from pytorch_metric_learning.losses import ArcFaceLoss
27
+
28
+ import wandb
29
+
30
+ import matplotlib.pyplot as plt
31
+
32
+
33
+ # ### parameters
34
+ ################## Settings #############################
35
+ #os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
36
+ torch.backends.cudnn.benchmark = True
37
+
38
+ ################## Data Paths ##########################
39
+ MODEL_DIR = "./convnext2b_metaEmbedding_focal05es_arcloss/"
40
+
41
+ if not os.path.exists(MODEL_DIR):
42
+ os.makedirs(MODEL_DIR)
43
+ shutil.copyfile('./convnext2b_exp4_meta_embedding_focalarcloss.py', f'{MODEL_DIR}convnext2b_exp4_meta_embedding_focalarcloss.py')
44
+
45
# Image root directories; each is prepended to the `image_path` column of its metadata table.
TRAIN_DATA_DIR = "/SnakeCLEF2023-large_size/" # train imgs. path
ADD_TRAIN_DATA_DIR = "/HMP/" # add. train imgs. path
VAL_DATA_DIR = "/SnakeCLEF2023-large_size/" # val imgs. path

# Metadata CSVs; a falsy ADD_TRAINDATA_CONFIG disables the additional training data.
TRAINDATA_CONFIG = "/SnakeCLEF2023-TrainMetadata-iNat.csv"
ADD_TRAINDATA_CONFIG = "/SnakeCLEF2023-TrainMetadata-HM.csv"
VALIDDATA_CONFIG = "/SnakeCLEF2023-ValMetadata.csv"

MISSING_FILES = "../missing_train_data.csv" # csv with missing img. files that will be filtered out

CCM = "../code_class_mapping_obid.csv" # csv mapping each metadata code to its snake species distribution


# number of target classes (output logits)
NUM_CLASSES = 1784
59
+
60
################## Hyperparameters ########################
NUM_EPOCHS = 40
WARMUP_EPOCHS = 0
RESUME_EPOCH = 14 # resume model, optimizer from epoch 14 of experiment 3, checkpoint files need to be copied to the MODEL_DIR folder


# learning rate per model part
LEARNING_RATE = {
    'cnn': 1e-05,
    'embeddings': 1e-04,
    'classifier': 1e-04,
}

# batch sizes used at the start of training
BATCH_SIZE = {
    'train': 32,
    'valid': 48,
    'grad_acc': 4, # gradient acc. steps with 'train' of batch sizes, global batch size = 'grad_acc' * 'train'
}

# batch sizes intended for the phase after warm-up (same keys as BATCH_SIZE)
BATCH_SIZE_AFTER_WARMUP = {
    'train': 32,
    'valid': 48,
    'grad_acc': 4, # gradient acc. steps with 'train' of batch sizes, global batch size = 'grad_acc' * 'train'
}

# image sizes and RandAugment settings ('m' -> magnitude, 'n' -> num_ops)
TRANSFORMS = {
    'IMAGE_SIZE_TRAIN': 544,
    'IMAGE_SIZE_VAL': 544,
    'RandAug' : {
        'm': 7,
        'n': 2
    }
}
92
+
93
+
94
def _load_pickle(path):
    """Load a pickled object from ``path`` and close the file afterwards."""
    # NOTE: unpickling can execute arbitrary code -- only load trusted project files.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# ############# Focal Loss ####################
FOCAL_LOSS = {
    'class_dist': _load_pickle("../classDist_HMP_missedRemoved.p")['counts'], # snake species frequency obtained on observation_id level, taking into account the observation_ids of missing image files
    'gamma': 0.5,
}


############# Checkpoints ####################
# optional checkpoint paths; None disables loading for that part
CHECKPOINTS = {
    'fe_cnn': None,
    'model': None,
    'optimizer': None,
    'scaler': None,
}

# ####### Embedding Token Mappings ########################
# number of tokens / embedding width per metadata field
META_SIZES = {'endemic': 2, 'code': 212}
EMBEDDING_SIZES = {'endemic': 64, 'code': 64}

# lookup tables from metadata value to embedding token
CODE_TOKENS = _load_pickle("../meta_code_tokens.p")
ENDEMIC_TOKENS = _load_pickle("../meta_endemic_tokens.p")
115
+
116
################### WandB ##################
# toggles Weights & Biases experiment tracking
WANDB = True

if WANDB:
    wandb.init(
        entity="snakeclef2023", # our team at wandb

        # set the wandb project where this run will be logged
        project="exp4", # -> define sub-projects here, e.g. experiments with MetaFormer or CNNs...

        # define a name for this run
        name="focal05es_arcloss",

        # track all the used hyperparameters here, config is just a dict object so any key:value pairs are possible
        config={
            "learning_rate": LEARNING_RATE,
            "focal_loss": FOCAL_LOSS,
            "architecture": "convnextv2_base.fcmae_ft_in22k_in1k_384",
            # NOTE(review): the architecture tag above names in22k/in1k weights while this
            # label says "iNat21" -- confirm which pretraining is meant.
            "pretrained": "iNat21",
            "dataset": f"snakeclef2023, additional train data: {True if ADD_TRAINDATA_CONFIG else False}",
            "epochs": NUM_EPOCHS,
            "transforms": TRANSFORMS,
            "checkpoints": CHECKPOINTS,
            "model_dir": MODEL_DIR
            # ... any other hyperparameter that is necessary to reproduce the result
        },
        save_code=True, # save the script file as backup
        dir=MODEL_DIR # local folder where wandb log files are saved
    )
145
+
146
+
147
+
148
+
149
+ ##################### Dataset & AugTransforms #####################################
150
+ # ### dataset & loaders
151
class SnakeTrainDataset(Dataset):
    """Dataset yielding ``(img, label, ccm, meta)`` tuples from a metadata table.

    Args:
        data: table with one row per image (reads ``class_id``, ``code``,
            ``endemic`` and ``image_path``).
        ccm: code-class-mapping table, indexed by code.
        transform: image augmentation / preprocessing pipeline.
    """

    def __init__(self, data, ccm, transform=None):
        self.data = data
        self.transform = transform
        self.code_class_mapping = ccm
        # metadata value -> embedding token lookups
        self.code_tokens = CODE_TOKENS
        self.endemic_tokens = ENDEMIC_TOKENS

    def __len__(self):
        n_rows = self.data.shape[0]
        return n_rows

    def __getitem__(self, index):
        """Return the sample at ``index`` as ``(img, label, ccm, meta)``."""
        row = self.data.iloc[index]

        # metadata values without a token entry fall back to "unknown" / False
        code = row.code if row.code in self.code_tokens else "unknown"
        endemic = row.endemic if row.endemic in self.endemic_tokens else False

        image = Image.open(row.image_path).convert("RGB")
        ccm = torch.tensor(self.code_class_mapping[code].to_numpy())
        code_token = self.code_tokens[code]
        endemic_token = self.endemic_tokens[endemic]
        meta = torch.tensor([code_token, endemic_token])

        # augmentation / preprocessing pipeline
        image = self.transform(image)

        return (image, row.class_id, ccm, meta)
176
+
177
+
178
+ # valid data preprocessing pipeline
179
def get_val_preprocessing(img_size):
    """Build the validation preprocessing pipeline.

    The image is resized with ``int(img_size * 1.25)``, cut into the five
    ``FiveCrop`` views of ``img_size x img_size``, every view is converted to
    a tensor, the views are stacked into one 4D tensor and then normalized.
    """
    print(f'IMG_SIZE_VAL: {img_size}')

    to_tensor = transforms.ToTensor()

    def stack_crops(crops):
        # FiveCrop yields several PIL images; stack them into a single 4D tensor
        return torch.stack([to_tensor(crop) for crop in crops])

    resize = transforms.Resize(int(img_size * 1.25))  # enlarge before cropping
    five_crop = transforms.Compose([
        transforms.FiveCrop((img_size, img_size)),
        transforms.Lambda(stack_crops),
    ])
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    return transforms.Compose([resize, five_crop, normalize])
189
+
190
class IdentityTransform:
    """No-op transform: calling it returns its argument unchanged."""

    def __call__(self, sample):
        return sample
193
+
194
+
195
+ # train data augmentation/ preprocessing pipeline
196
def get_train_augmentation_preprocessing(img_size, rand_aug=False):
    """Build the training pipeline: resize, flips, random crop, optional RandAugment,
    tensor conversion and normalization.

    RandAugment settings are read from TRANSFORMS['RandAug'] ('n' ops, magnitude 'm').
    """
    print(f'IMG_SIZE_TRAIN: {img_size}, RandAug: {rand_aug}')

    if rand_aug:
        ra_cfg = TRANSFORMS['RandAug']
        extra_aug = transforms.RandAugment(num_ops=ra_cfg['n'], magnitude=ra_cfg['m'])
    else:
        extra_aug = IdentityTransform()

    steps = [
        transforms.Resize(int(img_size * 1.25)),  # enlarge before the random crop
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomCrop((img_size, img_size)),  # crop to the target size
        extra_aug,
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    return transforms.Compose(steps)
207
+
208
+
209
def get_datasets(train_transfroms, val_transforms):
    """Load the metadata CSVs and return (train_dataset, valid_dataset).

    Rows listed in MISSING_FILES are removed from the training table, image
    paths are prefixed with their data directories, optional additional
    training data is appended, and both tables are shuffled with a fixed seed.
    """
    # explicit NA markers; pandas' default list is disabled via keep_default_na=False
    nan_values = ['', '#N/A', '#N/A N/A', '#NA', '-1.#IND', '-1.#QNAN', '-NaN', '-nan', '1.#IND', '1.#QNAN', '<NA>', 'N/A', 'NULL', 'NaN', 'n/a', 'nan', 'null']
    csv_kwargs = dict(na_values=nan_values, keep_default_na=False)

    train_data = pd.read_csv(TRAINDATA_CONFIG, **csv_kwargs)
    missing_train_data = pd.read_csv(MISSING_FILES, **csv_kwargs)
    valid_data = pd.read_csv(VALIDDATA_CONFIG, **csv_kwargs)

    # drop every training row that also appears in the missing-files table
    kept_columns = ["observation_id","endemic","binomial_name","code","image_path","class_id","subset"]
    merged = pd.merge(train_data, missing_train_data, how='outer', indicator=True)
    train_data = merged.loc[merged._merge == 'left_only', kept_columns]

    # prefix relative image paths with the data directories
    train_data["image_path"] = TRAIN_DATA_DIR + train_data['image_path']
    valid_data["image_path"] = VAL_DATA_DIR + valid_data['image_path']

    # optionally append the additional training data
    if ADD_TRAINDATA_CONFIG:
        extra_data = pd.read_csv(ADD_TRAINDATA_CONFIG, **csv_kwargs)
        extra_data["image_path"] = ADD_TRAIN_DATA_DIR + extra_data['image_path']
        train_data = pd.concat([train_data, extra_data], axis=0)

    print(f'train data shape: {train_data.shape}')

    # deterministic shuffle
    train_data = train_data.sample(frac=1, random_state=1).reset_index(drop=True)
    valid_data = valid_data.sample(frac=1, random_state=1).reset_index(drop=True)

    # code -> class mapping table
    ccm = pd.read_csv(CCM, **csv_kwargs)

    train_dataset = SnakeTrainDataset(train_data, ccm, transform=train_transfroms)
    valid_dataset = SnakeTrainDataset(valid_data, ccm, transform=val_transforms)
    return train_dataset, valid_dataset
247
+
248
+
249
def get_dataloaders(imgsize_train, imgsize_val, rand_aug):
    """Create the train and validation DataLoaders for the given image sizes.

    Batch sizes are read from the module-level BATCH_SIZE dict at call time.
    """
    train_pipeline = get_train_augmentation_preprocessing(imgsize_train, rand_aug)
    val_pipeline = get_val_preprocessing(imgsize_val)

    train_dataset, valid_dataset = get_datasets(train_transfroms=train_pipeline, val_transforms=val_pipeline)

    # settings shared by both loaders
    common = dict(num_workers=6, pin_memory=True)
    train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE['train'], drop_last=True, **common)
    valid_loader = DataLoader(dataset=valid_dataset, shuffle=False, batch_size=BATCH_SIZE['valid'], drop_last=False, **common)

    return train_loader, valid_loader
259
+
260
+
261
+ # #################### plot train history #########################
262
+
263
def plot_history(logs):
    """Plot train vs. validation loss, accuracy and F1 and save the figure as SVG.

    The figure is written to '{MODEL_DIR}model_history.svg' and then shown.
    """
    fig, ax = plt.subplots(3, 1, figsize=(8, 12))

    # (train key, valid key, y-label, y-limits, title) for each of the three panels
    panels = [
        ('loss', 'val_loss', "loss", [0, -np.log(1/NUM_CLASSES)], "train- vs. valid loss"),
        ('acc', 'val_acc', "accuracy", [0, 1.01], "train- vs. valid accuracy"),
        ('f1', 'val_f1', "f1", [0, 1.01], "train- vs. valid f1"),
    ]
    for axis, (train_key, valid_key, y_label, y_limits, title) in zip(ax, panels):
        axis.plot(logs[train_key], label="train data")
        axis.plot(logs[valid_key], label="valid data")
        axis.legend(loc="best")
        axis.set_ylabel(y_label)
        axis.set_ylim(y_limits)
        axis.set_title(title)

    # only the bottom panel carries the x-axis label
    ax[2].set_xlabel("epochs")

    fig.savefig(f'{MODEL_DIR}model_history.svg', dpi=150, format="svg")
    plt.show()
292
+
293
+ #################### Focal Loss ##################################
294
class FocalLoss(nn.Module):
    '''
    Multi-class Focal Loss with optional per-class weighting.

    Args:
        gamma: focusing exponent applied to (1 - p).
        class_dist: per-class sample counts (array-like). When given, classes
            are weighted by the "effective number of samples" term
            (1 - beta) / (1 - beta**count) with beta = 0.999; when None, all
            NUM_CLASSES classes get weight 1.
        reduction: reduction mode forwarded to ``nll_loss``.
        device: device the weight tensor is created on.
    '''
    def __init__(self, gamma, class_dist=None, reduction='mean', device='cuda'):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.reduction = reduction

        if class_dist is not None:
            # np.asarray lets lists / Series work as well as ndarrays
            counts = np.asarray(class_dist, dtype=np.float64)
            weight = torch.tensor((1.0 - 0.999) / (1.0 - 0.999**counts), dtype=torch.float32, device=device)
        else:
            weight = torch.ones(NUM_CLASSES, device=device)

        # Registered as a buffer so that module.to(device) moves the weights
        # together with the module; non-persistent so state_dict() is unchanged.
        self.register_buffer('weight', weight, persistent=False)

    def forward(self, inputs, targets):
        """
        input: [N, C], float32
        target: [N, ], int64
        """
        logpt = torch.nn.functional.log_softmax(inputs, dim=1)
        pt = torch.exp(logpt)
        # down-weight well-classified samples by (1 - p)^gamma
        logpt = (1-pt)**self.gamma * logpt
        loss = torch.nn.functional.nll_loss(logpt, targets, weight=self.weight, reduction=self.reduction)
        return loss
315
+
316
+
317
+ # #################### Model #####################################
318
+
319
class FeatureExtractor(nn.Module):
    """Image backbone: a timm ConvNeXt-V2 base model created with num_classes=0."""

    def __init__(self):
        super().__init__()
        self.conv_backbone = create_model('convnextv2_base.fcmae_ft_in22k_in1k_384', pretrained=True, num_classes=0, drop_path_rate=0.2)

        # optionally overwrite the pretrained weights with a backbone checkpoint
        backbone_ckpt = CHECKPOINTS['fe_cnn']
        if backbone_ckpt:
            backbone_state = torch.load(backbone_ckpt, map_location='cpu')
            self.conv_backbone.load_state_dict(backbone_state, strict=True)
            print(f"use FE_CHECKPOINTS: {backbone_ckpt}")
        torch.cuda.empty_cache()

    def forward(self, img):
        return self.conv_backbone(img)
331
+
332
+
333
class MetaEmbeddings(nn.Module):
    """Embeds the (code, endemic) metadata tokens and mixes them with a small MLP.

    Args:
        embedding_sizes: embedding width per metadata field ('endemic', 'code').
        meta_sizes: vocabulary size per metadata field ('endemic', 'code').
        dropout: dropout probability inside the MLP; falsy disables dropout.
    """

    def __init__(self, embedding_sizes: dict, meta_sizes: dict, dropout: float = None):
        super().__init__()
        self.endemic_embedding = nn.Embedding(meta_sizes['endemic'], embedding_sizes['endemic'], max_norm=1.0)
        self.code_embedding = nn.Embedding(meta_sizes['code'], embedding_sizes['code'], max_norm=1.0)

        # width of the concatenated embeddings
        self.dim_embedding = sum(embedding_sizes.values())
        width = self.dim_embedding

        layers = [
            nn.Linear(in_features=width, out_features=width, bias=True),
            nn.GELU(),
            nn.LayerNorm(width, eps=1e-06),
            nn.Dropout(p=dropout, inplace=False) if dropout else nn.Identity(),
            nn.Linear(in_features=width, out_features=width, bias=True),
            nn.GELU(),
            nn.LayerNorm(width, eps=1e-06),
        ]
        self.embedding_net = nn.Sequential(*layers)

    def forward(self, meta):
        # column 0 holds the code token, column 1 the endemic token
        code_vec = self.code_embedding(meta[:,0])
        endemic_vec = self.endemic_embedding(meta[:,1])

        joined = torch.concat([code_vec, endemic_vec], dim=-1)
        return self.embedding_net(joined)
358
+
359
+
360
class Classifier(nn.Module):
    """Linear classification head with optional dropout on its input."""

    def __init__(self, num_classes: int, dim_embeddings: int, dropout: float = None):
        super().__init__()
        # a falsy dropout value (None or 0) disables dropout
        if dropout:
            self.dropout = nn.Dropout(p=dropout, inplace=False)
        else:
            self.dropout = nn.Identity()
        self.classifier = nn.Linear(in_features=dim_embeddings, out_features=num_classes, bias=True)

    def forward(self, embeddings):
        return self.classifier(self.dropout(embeddings))
371
+
372
+
373
class Model(nn.Module):
    """Full network: image backbone + metadata embedding net + linear classifier.

    ``forward`` returns the class logits together with the concatenated
    image/metadata feature vector the logits were computed from.
    """

    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.embedding_net = MetaEmbeddings(embedding_sizes=EMBEDDING_SIZES, meta_sizes=META_SIZES, dropout=0.25)
        # presumably 1024 backbone features + 128 metadata features -- TODO confirm
        self.classifier = Classifier(num_classes=NUM_CLASSES, dim_embeddings=1024+128, dropout=0.25)

    def forward(self, img, meta):
        image_part = self.feature_extractor(img)
        meta_part = self.embedding_net(meta)

        fused = torch.concat([image_part, meta_part], dim=-1)
        logits = self.classifier(fused)
        return logits, fused
388
+
389
class LossLayer(nn.Module):
    """Training objective: focal loss on the logits plus ArcFace loss on the features."""

    def __init__(self):
        super().__init__()
        self.arcloss = ArcFaceLoss(num_classes=NUM_CLASSES, embedding_size=1024+128, margin=28.6, scale=64)
        self.celoss = FocalLoss(gamma=FOCAL_LOSS['gamma'], class_dist=FOCAL_LOSS['class_dist'])

    def forward(self, classifier_outputs, cat_features, labels):
        focal_term = self.celoss(classifier_outputs, labels)
        arcface_term = self.arcloss(cat_features, labels)
        return focal_term + arcface_term
399
+
400
+
401
def load_checkpoints(model=None, optimizer=None, scaler=None):
    """Load state dicts from the paths configured in CHECKPOINTS.

    An object is only touched when it is passed in AND a path is configured
    for it under 'model', 'optimizer' or 'scaler'.
    """
    model_path = CHECKPOINTS['model']
    if model_path and model is not None:
        model_state = torch.load(model_path, map_location='cpu')
        model.load_state_dict(model_state)
        print(f"use model checkpoints: {CHECKPOINTS['model']}")

    optimizer_path = CHECKPOINTS['optimizer']
    if optimizer_path and optimizer is not None:
        optimizer_state = torch.load(optimizer_path, map_location='cpu')
        optimizer.load_state_dict(optimizer_state)
        print(f"use optimizer checkpoints: {CHECKPOINTS['optimizer']}")

    scaler_path = CHECKPOINTS['scaler']
    if scaler_path and scaler is not None:
        scaler_state = torch.load(scaler_path, map_location='cpu')
        scaler.load_state_dict(scaler_state)
        print(f"use scaler checkpoints: {CHECKPOINTS['scaler']}")

    torch.cuda.empty_cache()
412
+
413
def resume_checkpoints(model=None, optimizer=None, scaler=None):
    """Restore the per-epoch checkpoints of epoch RESUME_EPOCH from MODEL_DIR.

    Each object that is passed in gets its state loaded from the matching
    '<name>_epoch{RESUME_EPOCH}.pth' file; objects left as None are skipped.
    """
    if model is not None:
        model_file = f'{MODEL_DIR}model_epoch{RESUME_EPOCH}.pth'
        model.load_state_dict(torch.load(model_file, map_location='cpu'))
        print(f"use model checkpoints: {MODEL_DIR}model_epoch{RESUME_EPOCH}.pth")

    if optimizer is not None:
        optimizer_file = f'{MODEL_DIR}optimizer_epoch{RESUME_EPOCH}.pth'
        optimizer.load_state_dict(torch.load(optimizer_file, map_location='cpu'))
        print(f"use optimizer checkpoints: {MODEL_DIR}optimizer_epoch{RESUME_EPOCH}.pth")

    if scaler is not None:
        scaler_file = f'{MODEL_DIR}mp_scaler_epoch{RESUME_EPOCH}.pth'
        scaler.load_state_dict(torch.load(scaler_file, map_location='cpu'))
        print(f"use scaler checkpoints: {MODEL_DIR}mp_scaler_epoch{RESUME_EPOCH}.pth")

    torch.cuda.empty_cache()
425
+
426
+
427
def resume_logs(logs):
    """Pre-fill ``logs`` with the history of epochs 0..RESUME_EPOCH.

    Reads '{MODEL_DIR}train_history.csv' and extends every list in ``logs``
    (in place) with the matching column. Only the first RESUME_EPOCH + 1 rows
    are used: the training loop indexes ``logs[key][epoch]``, so rows written
    by epochs after the resume point would shift every later entry.

    Raises:
        KeyError: if a key of ``logs`` is not a column of the CSV.
    """
    old_logs = pd.read_csv(f"{MODEL_DIR}train_history.csv")
    old_logs = old_logs.iloc[:RESUME_EPOCH + 1]
    for metric_name in logs:
        # tolist() yields plain Python numbers instead of numpy scalars
        logs[metric_name].extend(old_logs[metric_name].tolist())
431
+
432
+ ######################## Optimizer #####################################
433
def get_optm_group(module):
    """
    Split the parameters of ``module`` into two buckets: those that will
    experience weight decay for regularization and those that won't
    (biases, and norm/embedding weights).

    Returns:
        (param_dict, decay, no_decay) where ``param_dict`` maps full parameter
        names to parameters and ``decay`` / ``no_decay`` are sets of full
        parameter names. No optimizer is created here; callers build the
        optimizer param groups from these sets.

    Raises:
        AssertionError: if a parameter lands in both sets or in neither.
    """

    # separate out all parameters to those that will and won't experience regularizing weight decay
    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv1d, timm.layers.GlobalResponseNormMlp)
    blacklist_weight_modules = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.LayerNorm, torch.nn.Embedding)
    # named_parameters() recurses, so a parameter is visited once per ancestor
    # module; it is only classified when `m` is a matching module type (or for biases)
    for mn, m in module.named_modules():
        for pn, p in m.named_parameters():
            fpn = '%s.%s' % (mn, pn) if mn else pn # full param name

            if pn.endswith('bias'):
                # all biases will not be decayed
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                # weights of whitelist modules will be weight decayed
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                # weights of blacklist modules will NOT be weight decayed
                no_decay.add(fpn)


    # validate that we considered every parameter
    param_dict = {pn: p for pn, p in module.named_parameters()}
    inter_params = decay & no_decay
    union_params = decay | no_decay
    assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
    assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                % (str(param_dict.keys() - union_params), )

    return param_dict, decay, no_decay
470
+
471
+
472
def get_warmup_optimizer(model):
    """Create the AdamW optimizer for the metadata embedding net and the classifier.

    Each sub-module contributes two param groups (weight decay 0.05 / 0.0), in
    the order: embeddings-decay, embeddings-no-decay, classifier-decay,
    classifier-no-decay. The image backbone is not included.
    """
    # (sub-module, key into LEARNING_RATE); order defines the param group order
    parts = (
        (model.embedding_net, 'embeddings'),
        (model.classifier, 'classifier'),
    )

    params_group = []
    for sub_module, lr_key in parts:
        param_dict, decay, no_decay = get_optm_group(sub_module)
        # sorted names keep the parameter order inside a group deterministic
        params_group.append({"params": [param_dict[name] for name in sorted(decay)], "weight_decay": 0.05, 'lr': LEARNING_RATE[lr_key]})
        params_group.append({"params": [param_dict[name] for name in sorted(no_decay)], "weight_decay": 0.0, 'lr': LEARNING_RATE[lr_key]})

    return torch.optim.AdamW(params_group)
485
+
486
+
487
def get_after_warmup_optimizer(model, old_opt):
    """Build an optimizer for the image backbone and append the groups of ``old_opt``.

    The backbone groups come first, followed by the param groups of ``old_opt``
    in their existing order.
    """
    backbone = model.feature_extractor.conv_backbone
    combined_opt = create_optimizer_v2(backbone, opt='adamw', filter_bias_and_bn=True, weight_decay=1e-8, layer_decay=0.85, lr=LEARNING_RATE['cnn'])

    # carry over the already existing param groups
    for carried_group in old_opt.param_groups:
        combined_opt.add_param_group(carried_group)

    return combined_opt
495
+
496
+
497
+ # #################### Model Warmup #####################################
498
+
499
def warmup_start(model):
    """Freeze every parameter of model.feature_extractor.conv_backbone."""
    backbone = model.feature_extractor.conv_backbone
    for weight in backbone.parameters():
        weight.requires_grad = False
    print('--> freeze feature_extractor.conv_backbone during warmup phase')
504
+
505
def warmup_end(model):
    """Unfreeze every parameter of model.feature_extractor.conv_backbone."""
    backbone = model.feature_extractor.conv_backbone
    for weight in backbone.parameters():
        weight.requires_grad = True
    print('--> unfreeze feature_extractor.conv_backbone after warmup phase')
510
+
511
+
512
+ # #################### Train Loop #####################################
513
+
514
+ # ### train
515
def _run_validation(net, loader, loss_fn, loss_metric, metrics, metric_ccm, device):
    """Run one five-crop validation pass; return (loss, metrics dict, code-masked f1) as floats.

    The running metric objects are computed and reset before returning.
    """
    img_size = TRANSFORMS['IMAGE_SIZE_VAL']
    net.eval()
    with torch.no_grad():
        for (inputs, labels, ccm, meta) in loader:
            inputs = inputs.to(device, non_blocking=True)
            # fold the five crops of every sample into the batch dimension
            inputs = inputs.view(-1, 3, img_size, img_size)
            meta = meta.to(device, non_blocking=True)
            meta = torch.repeat_interleave(meta, repeats=5, dim=0)
            labels = labels.to(device, non_blocking=True)
            ccm = ccm.to(device, non_blocking=True)

            # forward with mixed precision
            with autocast(device_type='cuda', dtype=torch.float16):
                outputs, embeddings = net(inputs, meta)
                # average logits and features over the five crops
                outputs = outputs.view(-1, 5, NUM_CLASSES).mean(1)
                embeddings = embeddings.view(-1, 5, embeddings.shape[-1]).mean(1)
                loss = loss_fn(outputs, embeddings, labels)

            loss_metric.update(loss.detach())

            preds = outputs.softmax(dim=-1).detach()
            metrics.update(preds, labels)
            metric_ccm.update(preds * ccm, labels)

    epoch_loss = loss_metric.compute().cpu().item()
    epoch_metrics = {name: value.cpu().item() for name, value in metrics.compute().items()}
    epoch_metric_ccm = metric_ccm.compute().detach().cpu().item()

    loss_metric.reset()
    metrics.reset()
    metric_ccm.reset()

    return epoch_loss, epoch_metrics, epoch_metric_ccm


def main():
    """Train the model, validate after every epoch, checkpoint, and finally validate the EMA model.

    Resume behaviour (RESUME_EPOCH > 0): model, ArcFace loss weights, optimizer
    and grad scaler are restored from the epoch-RESUME_EPOCH checkpoints in
    MODEL_DIR and training continues with epoch RESUME_EPOCH + 1. The optimizer
    is rebuilt with the same param-group layout it had when it was saved
    (ArcFace group included, backbone groups included once warmup is over);
    otherwise ``load_state_dict`` rejects the checkpoint because the number of
    param groups differs.
    """
    global BATCH_SIZE

    device = torch.device('cuda:1')
    torch.cuda.set_device(device)

    resuming = RESUME_EPOCH > 0
    # the optimizer saved at epoch e already contains the backbone groups when e >= WARMUP_EPOCHS
    resumed_after_warmup = resuming and RESUME_EPOCH >= WARMUP_EPOCHS
    if resumed_after_warmup:
        # the warmup toggle inside the epoch loop will not fire again
        BATCH_SIZE = BATCH_SIZE_AFTER_WARMUP

    # prepare the datasets
    train_loader, valid_loader = get_dataloaders(imgsize_train=TRANSFORMS['IMAGE_SIZE_TRAIN'],
                                                 imgsize_val=TRANSFORMS['IMAGE_SIZE_VAL'],
                                                 rand_aug=True)

    # instantiate the model
    # NOTE: warmup_start(model) is not called, so the backbone is never frozen;
    # during warmup it merely has no optimizer param group.
    model = Model().to(device)
    if resuming:
        resume_checkpoints(model=model)
    ema_model = ModelEmaV2(model, decay=0.9998, device=device)

    # loss (owns the trainable ArcFace weights)
    loss_fn = LossLayer().to(device)
    if resuming:
        arcloss_ckpt = f'{MODEL_DIR}arcloss_epoch{RESUME_EPOCH}.pth'
        if os.path.exists(arcloss_ckpt):
            loss_fn.arcloss.load_state_dict(torch.load(arcloss_ckpt, map_location='cpu'))
            print(f"use arcloss checkpoints: {arcloss_ckpt}")

    # Optimizer & grad scaler. The ArcFace group must be added BEFORE the saved
    # optimizer state is loaded so that the group layout matches the checkpoint.
    optimizer = get_warmup_optimizer(model)
    optimizer.add_param_group({"params": loss_fn.arcloss.parameters(), "weight_decay": 0.0, 'lr': LEARNING_RATE['classifier']})
    scaler = GradScaler()
    if resuming:
        if resumed_after_warmup:
            optimizer = get_after_warmup_optimizer(model, optimizer)
        resume_checkpoints(optimizer=optimizer, scaler=scaler)

    # running metrics during training
    loss_metric = MeanMetric().to(device)
    metrics = MetricCollection(metrics={
        'acc': MulticlassAccuracy(num_classes=NUM_CLASSES, average='macro'),
        'top3_acc': MulticlassAccuracy(num_classes=NUM_CLASSES, average='macro', top_k=3),
        'f1': MulticlassF1Score(num_classes=NUM_CLASSES, average='macro')
    }).to(device)
    metric_ccm = MulticlassF1Score(num_classes=NUM_CLASSES, average='macro').to(device)

    # start time of training
    start_training = time.perf_counter()
    # create log dict
    logs = {'loss': [], 'acc': [], 'acc_top3': [], 'f1': [], 'f1country': [], 'val_loss': [], 'val_acc': [], 'val_acc_top3': [], 'val_f1': [], 'val_f1country': []}
    if resuming:
        resume_logs(logs)

    # iterate over epochs
    start_epoch = RESUME_EPOCH+1 if resuming else 0
    for epoch in range(start_epoch, NUM_EPOCHS):
        # start time of epoch
        epoch_start = time.perf_counter()
        print(f'Epoch {epoch+1}/{NUM_EPOCHS}')

        ######################## toggle warmup ########################################
        if epoch == WARMUP_EPOCHS:
            warmup_end(model)
            optimizer = get_after_warmup_optimizer(model, optimizer)
            BATCH_SIZE = BATCH_SIZE_AFTER_WARMUP
            train_loader, valid_loader = get_dataloaders(imgsize_train=TRANSFORMS['IMAGE_SIZE_TRAIN'],
                                                         imgsize_val=TRANSFORMS['IMAGE_SIZE_VAL'],
                                                         rand_aug=True)

        elif epoch < WARMUP_EPOCHS:
            print(f'--> Warm Up {epoch+1}/{WARMUP_EPOCHS}')

        ############################## train phase ####################################
        model.train()

        # zero the parameter gradients
        optimizer.zero_grad(set_to_none=True)

        # gradient accumulation steps; 0 is treated like 1 for both the loss
        # divisor and the step condition (the latter used to divide by zero)
        grad_acc_steps = BATCH_SIZE['grad_acc'] if BATCH_SIZE['grad_acc'] != 0 else 1
        loss_div = torch.tensor(grad_acc_steps, dtype=torch.float16, device=device, requires_grad=False)

        # iterate over training batches
        for batch_idx, (inputs, labels, ccm, meta) in enumerate(train_loader):
            inputs = inputs.to(device, non_blocking=True)
            meta = meta.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            ccm = ccm.to(device, non_blocking=True)

            # forward with mixed precision
            with autocast(device_type='cuda', dtype=torch.float16):
                outputs, embeddings = model(inputs, meta)
                loss = loss_fn(outputs, embeddings, labels) / loss_div

            # loss backward
            scaler.scale(loss).backward()

            # Compute metrics (undo the accumulation scaling for logging)
            loss_metric.update((loss * loss_div).detach())

            preds = outputs.softmax(dim=-1).detach()
            metrics.update(preds, labels)
            metric_ccm.update(preds * ccm, labels)

            ############################ grad acc ##############################
            if (batch_idx+1) % grad_acc_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                # zero the parameter gradients
                optimizer.zero_grad(set_to_none=True)
                # update ema model
                ema_model.update(model)

        # compute, sync & reset metrics for the train phase
        epoch_loss = loss_metric.compute()
        epoch_metrics = metrics.compute()
        epoch_metric_ccm = metric_ccm.compute()

        loss_metric.reset()
        metrics.reset()
        metric_ccm.reset()

        # Append metric results to logs
        logs['loss'].append(epoch_loss.cpu().item())
        logs['acc'].append(epoch_metrics['acc'].cpu().item())
        logs['acc_top3'].append(epoch_metrics['top3_acc'].cpu().item())
        logs['f1'].append(epoch_metrics['f1'].cpu().item())
        logs['f1country'].append(epoch_metric_ccm.detach().cpu().item())

        print(f"loss: {logs['loss'][epoch]:.5f}, acc: {logs['acc'][epoch]:.5f}, acc_top3: {logs['acc_top3'][epoch]:.5f}, f1: {logs['f1'][epoch]:.5f}, f1country: {logs['f1country'][epoch]:.5f}", end=' || ')

        # drop gradients of an incomplete accumulation window
        optimizer.zero_grad(set_to_none=True)

        del inputs, labels, ccm, meta, preds, outputs, loss, loss_div, epoch_loss, epoch_metrics, epoch_metric_ccm
        torch.cuda.empty_cache()

        ############################## valid phase ####################################
        val_loss, val_metrics, val_f1country = _run_validation(model, valid_loader, loss_fn, loss_metric, metrics, metric_ccm, device)

        # Append metric results to logs
        logs['val_loss'].append(val_loss)
        logs['val_acc'].append(val_metrics['acc'])
        logs['val_acc_top3'].append(val_metrics['top3_acc'])
        logs['val_f1'].append(val_metrics['f1'])
        logs['val_f1country'].append(val_f1country)

        print(f"val_loss: {logs['val_loss'][epoch]:.5f}, val_acc: {logs['val_acc'][epoch]:.5f}, val_acc_top3: {logs['val_acc_top3'][epoch]:.5f}, val_f1: {logs['val_f1'][epoch]:.5f}, val_f1country: {logs['val_f1country'][epoch]:.5f}", end=' || ')

        torch.cuda.empty_cache()

        # save logs as csv
        logs_df = pd.DataFrame(logs)
        logs_df.to_csv(f'{MODEL_DIR}train_history.csv', index_label='epoch', sep=',', encoding='utf-8')

        if WANDB:
            # log each metric value of the current epoch
            wandb.log(
                {k:v[epoch] for k,v in logs.items()},
                step=epoch # epoch index for wandb
            )

        # save trained model for each epoch
        torch.save(model.state_dict(), f'{MODEL_DIR}model_epoch{epoch}.pth')
        torch.save(ema_model.module.state_dict(), f'{MODEL_DIR}ema_model_epoch{epoch}.pth')
        torch.save(optimizer.state_dict(), f'{MODEL_DIR}optimizer_epoch{epoch}.pth')
        torch.save(scaler.state_dict(), f'{MODEL_DIR}mp_scaler_epoch{epoch}.pth')
        torch.save(loss_fn.arcloss.state_dict(), f'{MODEL_DIR}arcloss_epoch{epoch}.pth')

        # end time of epoch
        epoch_end = time.perf_counter()
        print(f"epoch runtime: {epoch_end-epoch_start:5.3f} sec.")

        del logs_df, epoch_start, epoch_end
        torch.cuda.empty_cache()

    ################################## EMA Model Validation ################################
    del model
    torch.cuda.empty_cache()

    ema_loss, ema_metrics, ema_f1country = _run_validation(ema_model.module, valid_loader, loss_fn, loss_metric, metrics, metric_ccm, device)

    ema_summary = f"ema_loss: {ema_loss:.5f}, ema_acc: {ema_metrics['acc']:.5f}, ema_acc_top3: {ema_metrics['top3_acc']:.5f}, ema_f1: {ema_metrics['f1']:.5f}, ema_f1country: {ema_f1country:.5f}"
    print(ema_summary)

    with open(f'{MODEL_DIR}ema_results.txt', 'w') as f:
        print(ema_summary, file=f)

    plot_history(logs)
    # end time of training
    end_training = time.perf_counter()
    print(f'Training succeeded in {(end_training - start_training):5.3f}s')

    if WANDB:
        wandb.finish()
771
+
772
+
773
+ if __name__=="__main__":
774
+ main()
775
+
776
+
777
+
778
+
exp4/convnext2b_exp4_meta_embedding_focalloss.py ADDED
@@ -0,0 +1,766 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, time, pickle, shutil
2
+ import pandas as pd
3
+ import numpy as np
4
+
5
+ from PIL import Image, ImageFile
6
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from torch.cuda.amp import GradScaler
12
+ from torch import autocast
13
+
14
+ import torchvision.transforms as transforms
15
+
16
+ import timm
17
+ from timm.models import create_model
18
+ from timm.utils import ModelEmaV2
19
+
20
+ from timm.optim import create_optimizer_v2
21
+ #from mixup import Mixup
22
+ #from gridshuffle import RandomGridShuffle
23
+
24
+ from torchmetrics import MeanMetric
25
+ from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score
26
+ from torchmetrics import MetricCollection
27
+
28
+ # from pytorch_metric_learning.losses import ArcFaceLoss
29
+
30
+ import wandb
31
+
32
+ import matplotlib.pyplot as plt
33
+
34
+
35
+ # ### parameters
36
+ ################## Settings #############################
37
+ #os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
38
+ torch.backends.cudnn.benchmark = True
39
+
40
+ ################## Data Paths ##########################
41
+ MODEL_DIR = "./convnext2b_meta_embedding_focal05es/"
42
+
43
+ if not os.path.exists(MODEL_DIR):
44
+ os.makedirs(MODEL_DIR)
45
+ shutil.copyfile('./convnext2b_exp4_meta_embedding_focalloss.py', f'{MODEL_DIR}convnext2b_exp4_meta_embedding_focalloss.py')
46
+
47
+ TRAIN_DATA_DIR = "/SnakeCLEF2023-large_size/" # train imgs. path
48
+ ADD_TRAIN_DATA_DIR = "/HMP/" # add. train imgs. path
49
+ VAL_DATA_DIR = "/SnakeCLEF2023-large_size/" # val imgs. path
50
+
51
+ TRAINDATA_CONFIG = "/SnakeCLEF2023-TrainMetadata-iNat.csv"
52
+ ADD_TRAINDATA_CONFIG = "/SnakeCLEF2023-TrainMetadata-HM.csv"
53
+ VALIDDATA_CONFIG = "/SnakeCLEF2023-ValMetadata.csv"
54
+
55
+ MISSING_FILES = "../missing_train_data.csv" # csv with missing img. files that will be filtered out
56
+
57
+ CCM = "../code_class_mapping_obid.csv" # csv to metadata code to snake species dist.
58
+
59
+
60
+ NUM_CLASSES = 1784
61
+
62
+ ################## Hyperparameters ########################
63
+ NUM_EPOCHS = 40
64
+ WARMUP_EPOCHS = 0
65
+ RESUME_EPOCH = 14 # resume model, optimizer from epoch 14 of experiment 3, checkpoint files need to be copied to the MODEL_DIR folder
66
+
67
+
68
+ LEARNING_RATE = {
69
+ 'cnn': 1e-05,
70
+ 'embeddings': 1e-04,
71
+ 'classifier': 1e-04,
72
+ }
73
+
74
+ BATCH_SIZE = {
75
+ 'train': 32,
76
+ 'valid': 48,
77
+ 'grad_acc': 4, # gradient acc. steps with 'train' of batch sizes, global batch size = 'grad_acc' * 'train'
78
+ }
79
+
80
+ BATCH_SIZE_AFTER_WARMUP = {
81
+ 'train': 32,
82
+ 'valid': 48,
83
+ 'grad_acc': 4, # gradient acc. steps with 'train' of batch sizes, global batch size = 'grad_acc' * 'train'
84
+ }
85
+
86
+ TRANSFORMS = {
87
+ 'IMAGE_SIZE_TRAIN': 544,
88
+ 'IMAGE_SIZE_VAL': 544,
89
+ 'RandAug' : {
90
+ 'm': 7,
91
+ 'n': 2
92
+ },
93
+ }
94
+
95
+
96
+ # ############# Focal Loss ####################
97
# NOTE: pickle executes arbitrary code on load -- only use trusted local files here.
# The context manager closes the file handle right after loading.
with open("../classDist_HMP_missedRemoved.p", "rb") as _class_dist_file:
    _CLASS_DIST_COUNTS = pickle.load(_class_dist_file)['counts']

FOCAL_LOSS = {
    'class_dist': _CLASS_DIST_COUNTS, # snake species frequency obtained on observation_id level taken into account missing observation_id of missing image files
    'gamma': 0.5, # main difference of experiment 4 as well as weighting term in FocalLoss class
}
101
+
102
+
103
+ ############# Checkpoints ####################
104
+ CHECKPOINTS = {
105
+ 'fe_cnn': None,
106
+ 'model': None,
107
+ 'optimizer': None,
108
+ 'scaler': None,
109
+ }
110
+
111
+ # ####### Embedding Token Mappings ########################
112
+ META_SIZES = {'endemic': 2, 'code': 212}
113
+ EMBEDDING_SIZES = {'endemic': 64, 'code': 64}
114
+
115
+ CODE_TOKENS = pickle.load(open("../meta_code_tokens.p", "rb"))
116
+ ENDEMIC_TOKENS = pickle.load(open("../meta_endemic_tokens.p", "rb"))
117
+
118
################### WandB ##################
WANDB = True

if WANDB:
    # Hyperparameters tracked for this run; the config is a plain dict, so any
    # key/value pair needed to reproduce the result can be added here.
    _wandb_config = {
        "learning_rate": LEARNING_RATE,
        "focal_loss": FOCAL_LOSS,
        "architecture": "convnextv2_base.fcmae_ft_in22k_in1k_384",
        "pretrained": "iNat21",
        "dataset": f"snakeclef2023, additional train data: {bool(ADD_TRAINDATA_CONFIG)}",
        "epochs": NUM_EPOCHS,
        "transforms": TRANSFORMS,
        "checkpoints": CHECKPOINTS,
        "model_dir": MODEL_DIR,
    }
    wandb.init(
        entity="snakeclef2023",  # team name at wandb
        project="exp4",          # wandb project (one per experiment family) this run is logged to
        name="focal05_es",       # display name of this run
        config=_wandb_config,
        save_code=True,          # keep a backup copy of the script with the run
        dir=MODEL_DIR,           # local folder for the wandb log files
    )
147
+
148
+
149
+
150
+
151
+ ##################### Dataset & AugTransforms #####################################
152
+ # ### dataset & loaders
153
class SnakeTrainDataset(Dataset):
    """Dataset yielding ``(img, label, ccm, meta)`` for one row of a metadata table.

    Args:
        data: DataFrame with at least the columns ``class_id``, ``code``,
            ``endemic`` and ``image_path`` (the attributes read per row).
        ccm: code-to-class mapping table, indexed by code as ``ccm[code]``.
        transform: optional callable applied to the loaded RGB image; when
            ``None`` the PIL image is returned unchanged.
    """

    def __init__(self, data, ccm, transform=None):
        self.data = data
        self.transform = transform  # image augmentation / preprocessing pipeline
        self.code_class_mapping = ccm
        self.code_tokens = CODE_TOKENS
        self.endemic_tokens = ENDEMIC_TOKENS

    def __len__(self):
        """Return the number of rows in the metadata table."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Load the image of row ``index`` together with its label and metadata."""
        obj = self.data.iloc[index]
        label = obj.class_id
        # values without a token fall back to the "unknown" code / endemic=False
        code = obj.code if obj.code in self.code_tokens else "unknown"
        endemic = obj.endemic if obj.endemic in self.endemic_tokens else False

        # convert() returns a fully loaded copy, so the file can be closed right away
        with Image.open(obj.image_path) as raw_img:
            img = raw_img.convert("RGB")
        ccm = torch.tensor(self.code_class_mapping[code].to_numpy())  # code class mapping
        meta = torch.tensor([self.code_tokens[code], self.endemic_tokens[endemic]])  # metadata tokens

        # the transform is optional (its constructor default is None)
        if self.transform is not None:
            img = self.transform(img)

        return (img, label, ccm, meta)
178
+
179
+
180
+ # valid data preprocessing pipeline
181
def get_val_preprocessing(img_size):
    """Build the validation pipeline: resize, five-crop, stack to a tensor, normalize.

    The five crops of one image are stacked along a new leading dimension, so
    each sample leaves the pipeline as a single 4D tensor.
    """
    print(f'IMG_SIZE_VAL: {img_size}')

    to_tensor = transforms.ToTensor()

    def _stack_crops(crops):
        # FiveCrop yields PIL images; convert each one and stack them
        return torch.stack([to_tensor(crop) for crop in crops])

    five_crop_to_tensor = transforms.Compose([
        transforms.FiveCrop((img_size, img_size)),
        transforms.Lambda(_stack_crops),
    ])
    pipeline = [
        transforms.Resize(int(img_size * 1.25)),  # enlarge before cropping to img_size
        five_crop_to_tensor,
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    return transforms.Compose(pipeline)
191
+
192
class IdentityTransform:
    """No-op transform: calling it returns its argument unchanged."""

    def __call__(self, x):
        return x
195
+
196
+
197
+ # train data augmentation/ preprocessing pipeline
198
def get_train_augmentation_preprocessing(img_size, rand_aug=False):
    """Build the training pipeline: resize, random flips, random crop, optional RandAugment, tensor, normalize.

    RandAugment takes its ``num_ops`` / ``magnitude`` from ``TRANSFORMS['RandAug']``
    and is replaced by a no-op when ``rand_aug`` is falsy.
    """
    print(f'IMG_SIZE_TRAIN: {img_size}, RandAug: {rand_aug}')

    if rand_aug:
        rand_aug_cfg = TRANSFORMS['RandAug']
        extra_aug = transforms.RandAugment(num_ops=rand_aug_cfg['n'], magnitude=rand_aug_cfg['m'])
    else:
        extra_aug = IdentityTransform()

    steps = [
        transforms.Resize(int(img_size * 1.25)),      # enlarge before the random crop
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomCrop((img_size, img_size)),  # random crop down to img_size
        extra_aug,
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    return transforms.Compose(steps)
209
+
210
+
211
def get_datasets(train_transfroms, val_transforms):
    """Read the metadata CSVs and return ``(train_dataset, valid_dataset)``.

    Training rows that also appear in the MISSING_FILES table are dropped, the
    image directories are prepended to the relative image paths, the additional
    training table is appended when ADD_TRAINDATA_CONFIG is set, and both tables
    are shuffled with a fixed seed.

    Note: the keyword name ``train_transfroms`` (sic) is kept because callers
    pass it by name.
    """
    nan_values = ['', '#N/A', '#N/A N/A', '#NA', '-1.#IND', '-1.#QNAN', '-NaN', '-nan', '1.#IND', '1.#QNAN', '<NA>', 'N/A', 'NULL', 'NaN', 'n/a', 'nan', 'null']

    def _read_table(path):
        # every table is parsed with the same explicit NaN markers
        return pd.read_csv(path, na_values=nan_values, keep_default_na=False)

    train_data = _read_table(TRAINDATA_CONFIG)
    missing_train_data = _read_table(MISSING_FILES)
    valid_data = _read_table(VALIDDATA_CONFIG)

    # keep only the training rows that have no match in the missing-files table
    kept_columns = ["observation_id", "endemic", "binomial_name", "code", "image_path", "class_id", "subset"]
    merged = pd.merge(train_data, missing_train_data, how='outer', indicator=True)
    train_data = merged.loc[merged._merge == 'left_only', kept_columns]

    # turn the relative image paths into full paths
    train_data["image_path"] = TRAIN_DATA_DIR + train_data['image_path']
    valid_data["image_path"] = VAL_DATA_DIR + valid_data['image_path']

    # optionally append the additional training table
    if ADD_TRAINDATA_CONFIG:
        add_train_data = _read_table(ADD_TRAINDATA_CONFIG)
        add_train_data["image_path"] = ADD_TRAIN_DATA_DIR + add_train_data['image_path']
        train_data = pd.concat([train_data, add_train_data], axis=0)

    print(f'train data shape: {train_data.shape}')

    # deterministic shuffle of both tables
    train_data = train_data.sample(frac=1, random_state=1).reset_index(drop=True)
    valid_data = valid_data.sample(frac=1, random_state=1).reset_index(drop=True)

    # code-to-class mapping table shared by both datasets
    ccm = _read_table(CCM)

    train_dataset = SnakeTrainDataset(train_data, ccm, transform=train_transfroms)
    valid_dataset = SnakeTrainDataset(valid_data, ccm, transform=val_transforms)
    return train_dataset, valid_dataset
249
+
250
+
251
def get_dataloaders(imgsize_train, imgsize_val, rand_aug):
    """Create the train and validation DataLoaders for the given image sizes.

    Batch sizes are read from the module-level BATCH_SIZE dict at call time.
    """
    train_dataset, valid_dataset = get_datasets(
        train_transfroms=get_train_augmentation_preprocessing(imgsize_train, rand_aug),
        val_transforms=get_val_preprocessing(imgsize_val),
    )

    shared_options = {'num_workers': 6, 'pin_memory': True}
    # training: shuffled, incomplete last batch dropped
    train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE['train'], drop_last=True, **shared_options)
    # validation: fixed order, every sample kept
    valid_loader = DataLoader(dataset=valid_dataset, shuffle=False, batch_size=BATCH_SIZE['valid'], drop_last=False, **shared_options)

    return train_loader, valid_loader
261
+
262
+
263
+ # #################### plot train history #########################
264
+
265
def plot_history(logs):
    """Plot train vs. validation loss, accuracy and F1 per epoch and save the figure.

    ``logs`` must provide the keys 'loss', 'acc', 'f1' and their 'val_'
    counterparts. The figure is written to ``{MODEL_DIR}model_history.svg`` and
    then shown.
    """
    fig, ax = plt.subplots(3, 1, figsize=(8, 12))

    # (log key, y-label, y-limits, title) for each of the three panels
    panels = [
        ('loss', "loss", [0, -np.log(1/NUM_CLASSES)], "train- vs. valid loss"),
        ('acc', "accuracy", [0, 1.01], "train- vs. valid accuracy"),
        ('f1', "f1", [0, 1.01], "train- vs. valid f1"),
    ]
    last_panel = len(panels) - 1
    for row, (key, ylabel, ylim, title) in enumerate(panels):
        axis = ax[row]
        axis.plot(logs[key], label="train data")
        axis.plot(logs[f'val_{key}'], label="valid data")
        axis.legend(loc="best")
        axis.set_ylabel(ylabel)
        axis.set_ylim(ylim)
        if row == last_panel:
            # only the bottom panel carries the x-axis label
            axis.set_xlabel("epochs")
        axis.set_title(title)

    fig.savefig(f'{MODEL_DIR}model_history.svg', dpi=150, format="svg")
    plt.show()
294
+
295
+ #################### Focal Loss ##################################
296
class FocalLoss(nn.Module):
    """Multi-class focal loss with optional "effective number of samples" class weights.

    Args:
        gamma: focusing exponent applied to ``(1 - p)``.
        class_dist: per-class sample counts (array-like). When given, class ``c``
            is weighted by ``(1 - 0.999) / (1 - 0.999 ** count_c)``; a count of 0
            yields an infinite weight. When ``None`` all NUM_CLASSES classes get
            weight 1.
        reduction: reduction mode forwarded to ``nll_loss``.
        device: device on which the class-weight tensor is created.
    """

    def __init__(self, gamma, class_dist=None, reduction='mean', device='cuda'):
        super().__init__()
        self.gamma = gamma
        if class_dist is not None:
            # accept any array-like (ndarray, list, Series) of counts
            counts = np.asarray(class_dist, dtype=np.float64)
            weight = torch.tensor((1.0 - 0.999) / (1.0 - 0.999 ** counts), dtype=torch.float32, device=device)
        else:
            weight = torch.ones(NUM_CLASSES, device=device)
        # non-persistent buffer: follows .to()/.cuda() but stays out of the state_dict
        self.register_buffer('weight', weight, persistent=False)
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss.

        inputs: [N, C] float logits
        targets: [N] int64 class indices
        """
        logpt = torch.nn.functional.log_softmax(inputs, dim=1)
        pt = torch.exp(logpt)
        # For gamma < 1 the derivative of (1 - pt) ** gamma is infinite at pt == 1,
        # which turns the gradient into NaN (0 * inf) once softmax saturates.
        # Clamping the base keeps the backward pass finite.
        eps = torch.finfo(pt.dtype).eps
        focal_term = (1.0 - pt).clamp(min=eps) ** self.gamma
        return torch.nn.functional.nll_loss(focal_term * logpt, targets, weight=self.weight, reduction=self.reduction)
317
+
318
+
319
+ # #################### Model #####################################
320
+
321
class FeatureExtractor(nn.Module):
    """Image backbone: a headless (``num_classes=0``) ConvNeXt-V2 base model.

    When ``CHECKPOINTS['fe_cnn']`` is set, the backbone weights are replaced by
    that state dict (strict key matching).
    """

    def __init__(self):
        super().__init__()
        self.conv_backbone = create_model(
            'convnextv2_base.fcmae_ft_in22k_in1k_384',
            pretrained=True,
            num_classes=0,
            drop_path_rate=0.2,
        )
        fe_checkpoint = CHECKPOINTS['fe_cnn']
        if fe_checkpoint:
            backbone_state = torch.load(fe_checkpoint, map_location='cpu')
            self.conv_backbone.load_state_dict(backbone_state, strict=True)
            print(f"use FE_CHECKPOINTS: {fe_checkpoint}")
        torch.cuda.empty_cache()

    def forward(self, img):
        """Return the backbone features for a batch of images."""
        return self.conv_backbone(img)
333
+
334
+
335
class MetaEmbeddings(nn.Module):
    """Embed the (code, endemic) metadata tokens and mix them with a two-layer MLP.

    ``meta`` is indexed as ``meta[:, 0]`` (code token) and ``meta[:, 1]``
    (endemic token). The output width equals the sum of the embedding sizes.
    """

    def __init__(self, embedding_sizes: dict, meta_sizes: dict, dropout: float = None):
        super().__init__()
        self.endemic_embedding = nn.Embedding(meta_sizes['endemic'], embedding_sizes['endemic'], max_norm=1.0)
        self.code_embedding = nn.Embedding(meta_sizes['code'], embedding_sizes['code'], max_norm=1.0)

        width = sum(embedding_sizes.values())
        self.dim_embedding = width

        # Linear -> GELU -> LayerNorm, twice, with optional dropout in between
        layers = [
            nn.Linear(in_features=width, out_features=width, bias=True),
            nn.GELU(),
            nn.LayerNorm(width, eps=1e-06),
        ]
        if dropout:
            layers.append(nn.Dropout(p=dropout, inplace=False))
        else:
            layers.append(nn.Identity())
        layers += [
            nn.Linear(in_features=width, out_features=width, bias=True),
            nn.GELU(),
            nn.LayerNorm(width, eps=1e-06),
        ]
        self.embedding_net = nn.Sequential(*layers)

    def forward(self, meta):
        """Return the mixed metadata features for a batch of token pairs."""
        code_tokens = meta[:, 0]
        endemic_tokens = meta[:, 1]
        code_feature = self.code_embedding(code_tokens)
        endemic_feature = self.endemic_embedding(endemic_tokens)

        # code features first, endemic features second
        stacked = torch.cat([code_feature, endemic_feature], dim=-1)
        return self.embedding_net(stacked)
360
+
361
+
362
class Classifier(nn.Module):
    """Linear classification head with optional dropout on its input."""

    def __init__(self, num_classes: int, dim_embeddings: int, dropout: float = None):
        super().__init__()
        if dropout:
            self.dropout = nn.Dropout(p=dropout, inplace=False)
        else:
            self.dropout = nn.Identity()
        self.classifier = nn.Linear(in_features=dim_embeddings, out_features=num_classes, bias=True)

    def forward(self, embeddings):
        """Return the class logits for a batch of feature vectors."""
        return self.classifier(self.dropout(embeddings))
373
+
374
+
375
class Model(nn.Module):
    """Full network: image backbone + metadata embedding net + linear classifier.

    The classifier input width is fixed to 1024 + 128, i.e. the concatenation of
    the image features and the metadata features.
    """

    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.embedding_net = MetaEmbeddings(
            embedding_sizes=EMBEDDING_SIZES,
            meta_sizes=META_SIZES,
            dropout=0.25,
        )
        self.classifier = Classifier(
            num_classes=NUM_CLASSES,
            dim_embeddings=1024+128,
            dropout=0.25,
        )

    def forward(self, img, meta):
        """Return the class logits for a batch of images and their metadata tokens."""
        img_features = self.feature_extractor(img)
        meta_features = self.embedding_net(meta)

        # image features first, metadata features second
        joint_features = torch.cat([img_features, meta_features], dim=-1)
        return self.classifier(joint_features)
390
+
391
+
392
def load_checkpoints(model=None, optimizer=None, scaler=None):
    """Load the state dicts configured in CHECKPOINTS into the given objects.

    An object is only touched when it is passed in and its CHECKPOINTS entry is
    set. State dicts are loaded onto the CPU first.
    """
    candidates = (('model', model), ('optimizer', optimizer), ('scaler', scaler))
    for key, target in candidates:
        path = CHECKPOINTS[key]
        if not path or target is None:
            continue
        target.load_state_dict(torch.load(path, map_location='cpu'))
        print(f"use {key} checkpoints: {path}")
    torch.cuda.empty_cache()
403
+
404
def resume_checkpoints(model=None, optimizer=None, scaler=None):
    """Restore the given objects from the RESUME_EPOCH checkpoint files in MODEL_DIR.

    Only the objects that are passed in are restored. State dicts are loaded
    onto the CPU first.
    """
    # (name used in the log line, checkpoint file prefix, object to restore)
    candidates = (
        ('model', 'model_epoch', model),
        ('optimizer', 'optimizer_epoch', optimizer),
        ('scaler', 'mp_scaler_epoch', scaler),
    )
    for label, prefix, target in candidates:
        if target is None:
            continue
        path = f'{MODEL_DIR}{prefix}{RESUME_EPOCH}.pth'
        target.load_state_dict(torch.load(path, map_location='cpu'))
        print(f"use {label} checkpoints: {path}")
    torch.cuda.empty_cache()
416
+
417
+
418
def resume_logs(logs):
    """Extend ``logs`` in place with the history stored in ``{MODEL_DIR}train_history.csv``.

    Only the rows of epochs 0..RESUME_EPOCH are taken over. A history file copied
    from a longer run would otherwise leave more entries than finished epochs,
    and the per-epoch indexing of the resumed run would read stale values.

    Raises:
        KeyError: if the history file lacks one of the metrics in ``logs``.
    """
    old_logs = pd.read_csv(f"{MODEL_DIR}train_history.csv")
    finished_epochs = RESUME_EPOCH + 1
    for metric_name, history in logs.items():
        history.extend(old_logs[metric_name].tolist()[:finished_epochs])
422
+
423
+ ######################## Optimizer #####################################
424
def get_optm_group(module):
    """Split the parameters of ``module`` into weight-decay and no-weight-decay sets.

    Biases and the weights of normalization / embedding layers are not decayed;
    the weights of linear, convolutional and GlobalResponseNormMlp modules are.

    Returns:
        (param_dict, decay, no_decay): a dict from full parameter name to
        parameter, and the two sets of full parameter names.

    Raises:
        AssertionError: if a parameter lands in both sets or in neither.
    """
    decayed_module_types = (torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv1d, timm.layers.GlobalResponseNormMlp)
    undecayed_module_types = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.LayerNorm, torch.nn.Embedding)

    decay, no_decay = set(), set()
    for module_name, submodule in module.named_modules():
        for param_name, _ in submodule.named_parameters():
            full_name = f'{module_name}.{param_name}' if module_name else param_name

            if param_name.endswith('bias'):
                # biases are never decayed
                no_decay.add(full_name)
                continue
            if not param_name.endswith('weight'):
                continue
            if isinstance(submodule, decayed_module_types):
                decay.add(full_name)
            elif isinstance(submodule, undecayed_module_types):
                no_decay.add(full_name)

    # validate that every parameter was assigned to exactly one set
    param_dict = dict(module.named_parameters())
    inter_params = decay & no_decay
    union_params = decay | no_decay
    unassigned = param_dict.keys() - union_params
    assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
    assert len(unassigned) == 0, "parameters %s were not separated into either decay/no_decay set!" \
        % (str(unassigned), )

    return param_dict, decay, no_decay
461
+
462
+
463
def get_warmup_optimizer(model):
    """Create the AdamW optimizer for the embedding net and the classifier head.

    Each sub-module contributes two parameter groups, in this order: decayed
    weights (weight_decay 0.05) and non-decayed parameters (weight_decay 0.0).
    The embedding net comes first, the classifier second.
    """
    submodules = (
        (model.embedding_net, 'embeddings'),
        (model.classifier, 'classifier'),
    )
    params_group = []
    for submodule, lr_key in submodules:
        param_dict, decay, no_decay = get_optm_group(submodule)
        lr = LEARNING_RATE[lr_key]
        # names are sorted so the parameter order inside a group is deterministic
        for names, weight_decay in ((decay, 0.05), (no_decay, 0.0)):
            params_group.append({
                "params": [param_dict[name] for name in sorted(names)],
                "weight_decay": weight_decay,
                'lr': lr,
            })

    return torch.optim.AdamW(params_group)
476
+
477
+
478
def get_after_warmup_optimizer(model, old_opt):
    """Create the optimizer used after warm-up.

    A new AdamW optimizer is built for the CNN backbone (layer-wise learning-rate
    decay 0.85) and the parameter groups of ``old_opt`` are appended to it.
    """
    backbone = model.feature_extractor.conv_backbone
    backbone_options = {
        'opt': 'adamw',
        'filter_bias_and_bn': True,
        'weight_decay': 1e-8,
        'layer_decay': 0.85,
        'lr': LEARNING_RATE['cnn'],
    }
    new_opt = create_optimizer_v2(backbone, **backbone_options)

    # carry over the parameter groups of the previous optimizer
    for param_group in old_opt.param_groups:
        new_opt.add_param_group(param_group)

    return new_opt
486
+
487
+
488
+ # #################### Model Warmup #####################################
489
+
490
def warmup_start(model):
    """Freeze every parameter of ``model.feature_extractor.conv_backbone`` for the warm-up phase."""
    model.feature_extractor.conv_backbone.requires_grad_(False)
    print('--> freeze feature_extractor.conv_backbone during warmup phase')
495
+
496
def warmup_end(model):
    """Unfreeze every parameter of ``model.feature_extractor.conv_backbone`` after the warm-up phase."""
    model.feature_extractor.conv_backbone.requires_grad_(True)
    print('--> unfreeze feature_extractor.conv_backbone after warmup phase')
501
+
502
+
503
+ # #################### Train Loop #####################################
504
+
505
+ # ### train
506
def _collect_epoch_results(loss_metric, metrics, metric_ccm):
    """Compute the running metrics, reset them and return the results as plain floats."""
    epoch_metrics = metrics.compute()
    results = {
        'loss': loss_metric.compute().item(),
        'acc': epoch_metrics['acc'].item(),
        'acc_top3': epoch_metrics['top3_acc'].item(),
        'f1': epoch_metrics['f1'].item(),
        'f1country': metric_ccm.compute().item(),
    }
    loss_metric.reset()
    metrics.reset()
    metric_ccm.reset()
    return results


def _train_one_epoch(model, ema_model, loader, device, optimizer, scaler, loss_fn, grad_acc,
                     loss_metric, metrics, metric_ccm):
    """Train ``model`` for one epoch with mixed precision and gradient accumulation."""
    model.train()
    optimizer.zero_grad(set_to_none=True)

    for batch_idx, (inputs, labels, ccm, meta) in enumerate(loader):
        inputs = inputs.to(device, non_blocking=True)
        meta = meta.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        ccm = ccm.to(device, non_blocking=True)

        # forward with mixed precision
        with autocast(device_type='cuda', dtype=torch.float16):
            outputs = model(inputs, meta)
            raw_loss = loss_fn(outputs, labels)

        # scale the loss down so the accumulated gradient matches one large batch
        scaler.scale(raw_loss / grad_acc).backward()

        loss_metric.update(raw_loss.detach())
        preds = outputs.softmax(dim=-1).detach()
        metrics.update(preds, labels)
        metric_ccm.update(preds * ccm, labels)  # predictions masked by the code-class mapping

        # optimizer step once every `grad_acc` batches
        if (batch_idx + 1) % grad_acc == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            ema_model.update(model)

    results = _collect_epoch_results(loss_metric, metrics, metric_ccm)
    # gradients of an incomplete accumulation window are discarded
    optimizer.zero_grad(set_to_none=True)
    return results


def _evaluate(net, loader, device, loss_fn, loss_metric, metrics, metric_ccm):
    """Run one validation pass, averaging the logits of the five crops of each sample."""
    img_size = TRANSFORMS['IMAGE_SIZE_VAL']
    net.eval()
    with torch.no_grad():
        for inputs, labels, ccm, meta in loader:
            inputs = inputs.to(device, non_blocking=True)
            # fold the crop dimension into the batch dimension
            inputs = inputs.view(-1, 3, img_size, img_size)
            meta = meta.to(device, non_blocking=True)
            meta = torch.repeat_interleave(meta, repeats=5, dim=0)  # one metadata row per crop
            labels = labels.to(device, non_blocking=True)
            ccm = ccm.to(device, non_blocking=True)

            # forward with mixed precision
            with autocast(device_type='cuda', dtype=torch.float16):
                outputs = net(inputs, meta)
                outputs = outputs.view(-1, 5, NUM_CLASSES).mean(1)
                loss = loss_fn(outputs, labels)

            loss_metric.update(loss.detach())
            preds = outputs.softmax(dim=-1).detach()
            metrics.update(preds, labels)
            metric_ccm.update(preds * ccm, labels)

    return _collect_epoch_results(loss_metric, metrics, metric_ccm)


def main():
    """Train the model, validate after every epoch, save checkpoints and finally validate the EMA model."""
    global BATCH_SIZE  # rebound to BATCH_SIZE_AFTER_WARMUP when the warm-up phase ends

    device = torch.device('cuda:1')
    torch.cuda.set_device(device)

    # prepare the datasets
    train_loader, valid_loader = get_dataloaders(imgsize_train=TRANSFORMS['IMAGE_SIZE_TRAIN'],
                                                 imgsize_val=TRANSFORMS['IMAGE_SIZE_VAL'],
                                                 rand_aug=True)

    # instantiate the model
    model = Model().to(device)
    if RESUME_EPOCH > 0:
        resume_checkpoints(model=model)
    ema_model = ModelEmaV2(model, decay=0.9998, device=device)

    # optimizer & gradient scaler
    optimizer = get_warmup_optimizer(model)
    scaler = GradScaler()
    if RESUME_EPOCH > 0:
        if RESUME_EPOCH > WARMUP_EPOCHS:
            # the stored optimizer state already contains the backbone parameter groups
            optimizer = get_after_warmup_optimizer(model, optimizer)
        resume_checkpoints(optimizer=optimizer, scaler=scaler)

    loss_fn = FocalLoss(gamma=FOCAL_LOSS['gamma'], class_dist=FOCAL_LOSS['class_dist'], device=device)
    loss_val_fn = nn.CrossEntropyLoss()

    # running metrics, shared by the train and validation phases (reset after each)
    loss_metric = MeanMetric().to(device)
    metrics = MetricCollection(metrics={
        'acc': MulticlassAccuracy(num_classes=NUM_CLASSES, average='macro'),
        'top3_acc': MulticlassAccuracy(num_classes=NUM_CLASSES, average='macro', top_k=3),
        'f1': MulticlassF1Score(num_classes=NUM_CLASSES, average='macro')
    }).to(device)
    metric_ccm = MulticlassF1Score(num_classes=NUM_CLASSES, average='macro').to(device)

    start_training = time.perf_counter()

    logs = {'loss': [], 'acc': [], 'acc_top3': [], 'f1': [], 'f1country': [], 'val_loss': [], 'val_acc': [], 'val_acc_top3': [], 'val_f1': [], 'val_f1country': []}
    if RESUME_EPOCH > 0:
        resume_logs(logs)

    start_epoch = RESUME_EPOCH+1 if RESUME_EPOCH > 0 else 0
    for epoch in range(start_epoch, NUM_EPOCHS):
        epoch_start = time.perf_counter()
        print(f'Epoch {epoch+1}/{NUM_EPOCHS}')

        ######################## toggle warmup ########################################
        if epoch == WARMUP_EPOCHS:
            warmup_end(model)
            optimizer = get_after_warmup_optimizer(model, optimizer)
            BATCH_SIZE = BATCH_SIZE_AFTER_WARMUP
            # rebuild the loaders so they pick up the new batch sizes
            train_loader, valid_loader = get_dataloaders(imgsize_train=TRANSFORMS['IMAGE_SIZE_TRAIN'],
                                                         imgsize_val=TRANSFORMS['IMAGE_SIZE_VAL'],
                                                         rand_aug=True)
        elif epoch < WARMUP_EPOCHS:
            print(f'--> Warm Up {epoch+1}/{WARMUP_EPOCHS}')

        ############################## train phase ####################################
        # a configured value of 0 means "no accumulation"
        grad_acc = BATCH_SIZE['grad_acc'] if BATCH_SIZE['grad_acc'] != 0 else 1
        train_results = _train_one_epoch(model, ema_model, train_loader, device, optimizer, scaler,
                                         loss_fn, grad_acc, loss_metric, metrics, metric_ccm)
        for key, value in train_results.items():
            logs[key].append(value)

        # index with [-1]: the resumed history need not have exactly `epoch` earlier entries
        print(f"loss: {logs['loss'][-1]:.5f}, acc: {logs['acc'][-1]:.5f}, acc_top3: {logs['acc_top3'][-1]:.5f}, f1: {logs['f1'][-1]:.5f}, f1country: {logs['f1country'][-1]:.5f}", end=' || ')
        torch.cuda.empty_cache()

        ############################## valid phase ####################################
        val_results = _evaluate(model, valid_loader, device, loss_val_fn, loss_metric, metrics, metric_ccm)
        for key, value in val_results.items():
            logs[f'val_{key}'].append(value)

        print(f"val_loss: {logs['val_loss'][-1]:.5f}, val_acc: {logs['val_acc'][-1]:.5f}, val_acc_top3: {logs['val_acc_top3'][-1]:.5f}, val_f1: {logs['val_f1'][-1]:.5f}, val_f1country: {logs['val_f1country'][-1]:.5f}", end=' || ')
        torch.cuda.empty_cache()

        # save logs as csv
        logs_df = pd.DataFrame(logs)
        logs_df.to_csv(f'{MODEL_DIR}train_history.csv', index_label='epoch', sep=',', encoding='utf-8')

        if WANDB:
            # log the newest value of every metric for this epoch
            wandb.log(
                {k: v[-1] for k, v in logs.items()},
                step=epoch  # epoch index for wandb
            )

        # save the training state of this epoch
        torch.save(model.state_dict(), f'{MODEL_DIR}model_epoch{epoch}.pth')
        torch.save(ema_model.module.state_dict(), f'{MODEL_DIR}ema_model_epoch{epoch}.pth')
        torch.save(optimizer.state_dict(), f'{MODEL_DIR}optimizer_epoch{epoch}.pth')
        torch.save(scaler.state_dict(), f'{MODEL_DIR}mp_scaler_epoch{epoch}.pth')

        epoch_end = time.perf_counter()
        print(f"epoch runtime: {epoch_end-epoch_start:5.3f} sec.")

        del logs_df
        torch.cuda.empty_cache()

    ################################## EMA Model Validation ################################
    del model
    torch.cuda.empty_cache()

    ema_results = _evaluate(ema_model.module, valid_loader, device, loss_val_fn, loss_metric, metrics, metric_ccm)
    ema_summary = f"ema_loss: {ema_results['loss']:.5f}, ema_acc: {ema_results['acc']:.5f}, ema_acc_top3: {ema_results['acc_top3']:.5f}, ema_f1: {ema_results['f1']:.5f}, ema_f1country: {ema_results['f1country']:.5f}"
    print(ema_summary)

    with open(f'{MODEL_DIR}ema_results.txt', 'w') as f:
        print(ema_summary, file=f)

    plot_history(logs)
    end_training = time.perf_counter()
    print(f'Training succeeded in {(end_training - start_training):5.3f}s')

    if WANDB:
        wandb.finish()


if __name__=="__main__":
    main()
763
+
764
+
765
+
766
+
exp5/convnext2b_exp5_OBIDattention.py ADDED
@@ -0,0 +1,853 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from email.policy import strict
2
+ import os, time, pickle, shutil
3
+ import pandas as pd
4
+ import numpy as np
5
+
6
+ from PIL import Image, ImageFile
7
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.utils.data import Dataset, DataLoader
12
+ from torch.cuda.amp import GradScaler
13
+ from torch import autocast
14
+
15
+ import torchvision.transforms as transforms
16
+
17
+ import timm
18
+ from timm.models import create_model
19
+ from timm.utils import ModelEmaV2
20
+
21
+ from timm.optim import create_optimizer_v2
22
+
23
+ from torchmetrics import MeanMetric
24
+ from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score
25
+ from torchmetrics import MetricCollection
26
+
27
+ from pytorch_metric_learning.losses import ArcFaceLoss
28
+
29
+ import wandb
30
+
31
+ import matplotlib.pyplot as plt
32
+
33
+
34
# ### parameters
################## Settings #############################
#os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
torch.backends.cudnn.benchmark = True


def _load_pickle(path):
    """Unpickle `path` and close the file handle afterwards.

    NOTE: pickle executes arbitrary code on load; only use with trusted files.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


################## Data Paths ##########################
MODEL_DIR = "./convnext2b_obdid_attention/"

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
os.makedirs(MODEL_DIR, exist_ok=True)
# NOTE(review): the original indentation was lost, so it is unclear whether this
# backup ran only when MODEL_DIR was newly created; it is done on every launch here.
shutil.copyfile('./convnext2b_exp5_OBIDattention.py', f'{MODEL_DIR}convnext2b_exp5_OBIDattention.py')

TRAIN_DATA_DIR = "/SnakeCLEF2023-large_size/"  # train imgs. path
ADD_TRAIN_DATA_DIR = "/HMP/"  # add. train imgs. path
VAL_DATA_DIR = "/SnakeCLEF2023-large_size/"  # val imgs. path

TRAINDATA_CONFIG = "/SnakeCLEF2023-TrainMetadata-iNat.csv"
ADD_TRAINDATA_CONFIG = "/SnakeCLEF2023-TrainMetadata-HM.csv"
VALIDDATA_CONFIG = "/SnakeCLEF2023-ValMetadata.csv"

MISSING_FILES = "../missing_train_data.csv"  # csv with missing img. files that will be filtered out

CCM = "../code_class_mapping_obid.csv"  # csv mapping metadata code to snake species dist.


NUM_CLASSES = 1784

################## Hyperparameters ########################
NUM_EPOCHS = 50
WARMUP_EPOCHS = 0
RESUME_EPOCH = 39  # resume model, optimizer from epoch 39 of experiment 4, checkpoint files need to be copied to the MODEL_DIR folder


LEARNING_RATE = {
    'cnn': 1e-05,
    'embeddings': 1e-04,
    'classifier': 1e-04,
    'attention': 1e-04,
}

BATCH_SIZE = {
    'train': 1,
    'valid': 1,
    'grad_acc': 128,  # gradient acc. steps with 'train' of batch sizes, global batch size = 'grad_acc' * 'train'
    'max_imgs_per_instance': 100,  # maximum number of considered image instances (includes TTA) for each observation_id
}

BATCH_SIZE_AFTER_WARMUP = {
    'train': 1,
    'valid': 1,
    'grad_acc': 128,  # gradient acc. steps with 'train' of batch sizes, global batch size = 'grad_acc' * 'train'
    'max_imgs_per_instance': 100,  # maximum number of considered image instances (includes TTA) for each observation_id
}

TRANSFORMS = {
    'IMAGE_SIZE_TRAIN': 544,
    'IMAGE_SIZE_VAL': 544,
    'RandAug': {
        'm': 7,
        'n': 2,
    },
    'num_rand_crops': 5,  # num. of random crops during training per image instance
}


############# Focal Loss ####################
FOCAL_LOSS = {
    # snake species frequency obtained on observation_id level, taking into account
    # the observation_ids dropped because of missing image files
    'class_dist': _load_pickle("../classDist_HMP_missedRemoved.p")['counts'],
    'gamma': 0.5,
}


############# Checkpoints ####################
CHECKPOINTS = {
    'fe_cnn': None,
    'model': None,
    'optimizer': None,
    'scaler': None,
    'arcloss': None,
}

# ####### Embedding Token Mappings ########################
META_SIZES = {'endemic': 2, 'code': 212}
EMBEDDING_SIZES = {'endemic': 64, 'code': 64}

CODE_TOKENS = _load_pickle("../meta_code_tokens.p")
ENDEMIC_TOKENS = _load_pickle("../meta_endemic_tokens.p")

################### WandB ##################
WANDB = True

if WANDB:
    wandb.init(
        entity="snakeclef2023",  # our team at wandb

        # set the wandb project where this run will be logged
        project="exp5",  # -> define sub-projects here, e.g. experiments with MetaFormer or CNNs...

        # define a name for this run
        name="OBIDattention",

        # track all the used hyperparameters here, config is just a dict object so any key:value pairs are possible
        config={
            "learning_rate": LEARNING_RATE,
            "focal_loss": FOCAL_LOSS,
            "architecture": "convnextv2_base.fcmae_ft_in22k_in1k_384",
            "pretrained": "iNat21",
            "dataset": f"snakeclef2023, additional train data: {bool(ADD_TRAINDATA_CONFIG)}",
            "epochs": NUM_EPOCHS,
            "transforms": TRANSFORMS,
            "checkpoints": CHECKPOINTS,
            "model_dir": MODEL_DIR,
            # ... any other hyperparameter that is necessary to reproduce the result
        },
        save_code=True,  # save the script file as backup
        dir=MODEL_DIR,  # local folder where wandb log files are saved
    )
151
+
152
+
153
+
154
+
155
+ ##################### Dataset & AugTransforms #####################################
156
+ # ### dataset & loaders
157
class SnakeInstanceDataset(Dataset):
    """Dataset that yields one item per ``observation_id``.

    All rows of `data` sharing an observation_id are grouped; an item holds the
    transformed images of that group plus the label and metadata of its first row.

    Args:
        data: DataFrame with the columns ``observation_id``, ``code``,
            ``endemic``, ``class_id`` and ``image_path``.
        ccm: DataFrame indexed by column name ``code`` (code -> class mapping).
        transform: callable applied to every RGB PIL image; its outputs are stacked.
        fix_num: optional upper bound on the number of images returned per item.
    """

    def __init__(self, data, ccm, transform, fix_num=None):
        self.data = data
        # groupby(...).groups maps each observation_id to the index LABELS of its rows
        self.instance_groups = data.groupby('observation_id').groups
        self.instance_obids = list(self.instance_groups.keys())

        self.transform = transform  # image augmentation / preprocessing pipeline
        self.code_class_mapping = ccm
        self.code_tokens = CODE_TOKENS
        self.endemic_tokens = ENDEMIC_TOKENS

        self.fix_num = fix_num
        self.random_gen = torch.Generator().manual_seed(1)

    def __len__(self):
        """Number of distinct observation ids."""
        return len(self.instance_obids)

    def _load_image(self, file):
        """Open `file`, convert it to RGB, apply the transform and close the file."""
        with Image.open(file) as img:
            return self.transform(img.convert("RGB"))

    def __getitem__(self, index):
        """Return ``(imgs, label, ccm, meta)`` for the observation at `index`."""
        obid = self.instance_obids[index]
        # label-based lookup: the group values are index labels, not positions
        instances = self.data.loc[self.instance_groups[obid]]

        # metadata is taken from the first row of the observation
        code = instances.code.tolist()[0]
        code = code if code in self.code_tokens else "unknown"
        endemic = instances.endemic.tolist()[0]
        endemic = endemic if endemic in self.endemic_tokens else False

        label = torch.tensor([instances.class_id.tolist()[0]])
        ccm = torch.from_numpy(self.code_class_mapping[code].to_numpy())
        meta = torch.tensor([[self.code_tokens[code], self.endemic_tokens[endemic]]])

        # load every image of the observation and flatten crops into one batch axis
        files = instances.image_path.tolist()
        imgs = torch.stack([self._load_image(file) for file in files])
        img_size = imgs.size(-1)
        imgs = imgs.view(-1, 3, img_size, img_size)

        # shuffle the images and keep at most `fix_num` of them (all if fix_num is falsy)
        num_imgs = imgs.size(0)
        idx = torch.randperm(num_imgs, generator=self.random_gen)
        idx = idx[:self.fix_num] if self.fix_num else idx
        imgs = imgs[idx, :, :, :]

        return (imgs, label, ccm, meta)
201
+
202
+
203
+ # valid data preprocessing pipeline
204
def get_val_preprocessing(img_size):
    """Build the validation pipeline: resize, five fixed crops, tensor stack, normalize.

    The returned transform maps a PIL image to a 4D tensor holding the five crops.
    """
    print(f'IMG_SIZE_VAL: {img_size}')
    to_tensor = transforms.ToTensor()

    def _stack_crops(crops):
        # FiveCrop yields a tuple of PIL images; turn it into one 4D tensor
        return torch.stack([to_tensor(crop) for crop in crops])

    five_crops = transforms.Compose([
        transforms.FiveCrop((img_size, img_size)),
        transforms.Lambda(_stack_crops),
    ])
    enlarge = transforms.Resize(int(img_size * 1.25))  # upscale before cropping
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    return transforms.Compose([enlarge, five_crops, normalize])
214
+
215
class MultipleRandomCropsWithAugmentation:
    """Draw several random crops of one image, augment each and stack them as tensors."""

    def __init__(self, img_size, num_crops=5):
        super().__init__()
        self.num_crops = num_crops
        self.random_crop = transforms.RandomCrop((img_size, img_size))
        rand_aug_cfg = TRANSFORMS['RandAug']
        self.augment = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandAugment(num_ops=rand_aug_cfg['n'], magnitude=rand_aug_cfg['m']),
        ])
        self.to_tensor = transforms.ToTensor()

    def __call__(self, x):
        """Return a tensor stacking `num_crops` augmented random crops of `x`."""
        crops = []
        for _ in range(self.num_crops):
            patch = self.random_crop(x)
            crops.append(self.to_tensor(self.augment(patch)))
        return torch.stack(crops)
230
+
231
+ # train data augmentation/ preprocessing pipeline
232
def get_train_augmentation_preprocessing(img_size):
    """Build the training pipeline: resize, multiple augmented random crops, normalize."""
    print(f'IMG_SIZE_TRAIN: {img_size}')
    steps = [
        transforms.Resize(int(img_size * 1.25)),  # upscale before random cropping
        MultipleRandomCropsWithAugmentation(img_size, TRANSFORMS['num_rand_crops']),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    return transforms.Compose(steps)
239
+
240
+
241
def get_datasets(train_transfroms, val_transforms):
    """Read the metadata CSVs and build the train and validation datasets.

    Rows listed in MISSING_FILES are removed from the training table, image
    paths are prefixed with their data directory, the additional training
    table is appended when ADD_TRAINDATA_CONFIG is set, and both tables are
    shuffled with a fixed seed.

    Returns:
        (train_dataset, valid_dataset) as SnakeInstanceDataset objects.
    """
    nan_values = ['', '#N/A', '#N/A N/A', '#NA', '-1.#IND', '-1.#QNAN', '-NaN', '-nan', '1.#IND', '1.#QNAN', '<NA>', 'N/A', 'NULL', 'NaN', 'n/a', 'nan', 'null']
    csv_opts = {'na_values': nan_values, 'keep_default_na': False}

    # load metadata tables
    train_data = pd.read_csv(TRAINDATA_CONFIG, **csv_opts)
    train_data = train_data.drop_duplicates(subset='image_path', keep="first")
    missing_train_data = pd.read_csv(MISSING_FILES, **csv_opts)
    valid_data = pd.read_csv(VALIDDATA_CONFIG, **csv_opts)
    valid_data = valid_data.drop_duplicates(subset='image_path', keep="first")

    # anti-join: keep only training rows that do not appear in the missing-files table
    merged = pd.merge(train_data, missing_train_data, how='outer', indicator=True)
    kept_columns = ["observation_id", "endemic", "binomial_name", "code", "image_path", "class_id", "subset"]
    train_data = merged.loc[merged._merge == 'left_only', kept_columns]

    # code -> class mapping table
    ccm = pd.read_csv(CCM, **csv_opts)

    # prefix image paths with their data directories
    train_data["image_path"] = TRAIN_DATA_DIR + train_data['image_path']
    valid_data["image_path"] = VAL_DATA_DIR + valid_data['image_path']

    # optional additional training data
    if ADD_TRAINDATA_CONFIG:
        extra = pd.read_csv(ADD_TRAINDATA_CONFIG, **csv_opts)
        extra["image_path"] = ADD_TRAIN_DATA_DIR + extra['image_path']
        train_data = pd.concat([train_data, extra], axis=0)

    # deterministic shuffle; reset_index gives both tables a fresh 0..n-1 index
    train_data = train_data.sample(frac=1, random_state=1).reset_index(drop=True)
    valid_data = valid_data.sample(frac=1, random_state=1).reset_index(drop=True)

    max_imgs = BATCH_SIZE['max_imgs_per_instance']
    train_dataset = SnakeInstanceDataset(train_data, ccm, transform=train_transfroms, fix_num=max_imgs)
    valid_dataset = SnakeInstanceDataset(valid_data, ccm, transform=val_transforms, fix_num=max_imgs)
    print(f'train dataset shape: {len(train_dataset)}')
    print(f'valid dataset shape: {len(valid_dataset)}')

    return train_dataset, valid_dataset
288
+
289
def get_collate_fn():
    """Return a collate function that unwraps the single sample of a batch.

    The returned function ignores every batch element except the first and
    returns its first four fields as the list ``[imgs, targets, ccm, meta]``.
    """
    def collate_fn(batch):
        sample = batch[0]
        return [sample[field] for field in range(4)]

    return collate_fn
297
+
298
def get_dataloaders(imgsize_train, imgsize_val):
    """Create the train and validation DataLoaders (one observation per batch)."""
    train_dataset, valid_dataset = get_datasets(
        train_transfroms=get_train_augmentation_preprocessing(imgsize_train),
        val_transforms=get_val_preprocessing(imgsize_val),
    )

    # settings shared by both loaders; only shuffling differs
    shared = dict(batch_size=1, num_workers=4, prefetch_factor=8, drop_last=False, pin_memory=True)
    train_loader = DataLoader(dataset=train_dataset, shuffle=True, collate_fn=get_collate_fn(), **shared)
    valid_loader = DataLoader(dataset=valid_dataset, shuffle=False, collate_fn=get_collate_fn(), **shared)

    return train_loader, valid_loader
308
+
309
+
310
+ # #################### plot train history #########################
311
+
312
def plot_history(logs):
    """Plot train vs. validation loss, accuracy and F1 and save the figure as SVG.

    `logs` must provide the keys 'loss', 'acc', 'f1' and their 'val_' counterparts.
    """
    fig, ax = plt.subplots(3, 1, figsize=(8, 12))

    # (log key, y label, y limits, title) for each of the three stacked panels
    panels = (
        ('loss', "loss", [0, -np.log(1/NUM_CLASSES)], "train- vs. valid loss"),
        ('acc', "accuracy", [0, 1.01], "train- vs. valid accuracy"),
        ('f1', "f1", [0, 1.01], "train- vs. valid f1"),
    )
    bottom_row = len(panels) - 1
    for row, (key, ylabel, ylim, title) in enumerate(panels):
        axis = ax[row]
        axis.plot(logs[key], label="train data")
        axis.plot(logs[f'val_{key}'], label="valid data")
        axis.legend(loc="best")
        axis.set_ylabel(ylabel)
        axis.set_ylim(ylim)
        if row == bottom_row:
            # only the lowest panel carries the x-axis label
            axis.set_xlabel("epochs")
        axis.set_title(title)

    fig.savefig(f'{MODEL_DIR}model_history.svg', dpi=150, format="svg")
    plt.show()
341
+
342
+ #################### Focal Loss ##################################
343
class FocalLoss(nn.Module):
    """Multi-class focal loss with optional class-frequency weighting.

    Args:
        gamma: focusing exponent applied to ``(1 - p)``.
        class_dist: per-class sample counts (array-like). When given, class
            ``c`` is weighted by ``(1 - 0.999) / (1 - 0.999 ** count_c)``;
            when ``None`` all NUM_CLASSES classes get weight 1.
        reduction: reduction mode forwarded to ``nll_loss``.
        device: device the weight tensor is initially created on.
    """

    def __init__(self, gamma=2, class_dist=None, reduction='mean', device='cuda'):
        super().__init__()
        self.gamma = gamma
        if class_dist is not None:
            # np.asarray makes plain lists/Series work with the element-wise power
            counts = np.asarray(class_dist, dtype=np.float64)
            weight = torch.tensor((1.0 - 0.999) / (1.0 - 0.999 ** counts), dtype=torch.float32, device=device)
        else:
            weight = torch.ones(NUM_CLASSES, device=device)
        # Registered as a buffer so `.to(device)` moves it together with the module;
        # non-persistent so the module's state_dict stays empty as before.
        self.register_buffer('weight', weight, persistent=False)
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        inputs: [N, C] float logits
        targets: [N, ] int64 class indices
        """
        logpt = torch.nn.functional.log_softmax(inputs, dim=1)
        pt = torch.exp(logpt)
        # down-weight well-classified samples: (1 - p)^gamma * log(p)
        logpt = (1 - pt) ** self.gamma * logpt
        return torch.nn.functional.nll_loss(logpt, targets, weight=self.weight, reduction=self.reduction)
363
+
364
+
365
+ # #################### Model #####################################
366
+
367
class FeatureExtractor(nn.Module):
    """ConvNeXt-V2 base backbone without classification head (``num_classes=0``)."""

    def __init__(self):
        super().__init__()
        self.conv_backbone = create_model(
            'convnextv2_base.fcmae_ft_in22k_in1k_384',
            pretrained=True,
            num_classes=0,
            drop_path_rate=0.2,
        )
        backbone_ckpt = CHECKPOINTS['fe_cnn']
        if backbone_ckpt:
            # optionally replace the pretrained weights with a stored backbone checkpoint
            state = torch.load(backbone_ckpt, map_location='cpu')
            self.conv_backbone.load_state_dict(state, strict=True)
            print(f"use FE_CHECKPOINTS: {CHECKPOINTS['fe_cnn']}")
        torch.cuda.empty_cache()

    def forward(self, img):
        """Return the backbone features for the image batch `img`."""
        return self.conv_backbone(img)
379
+
380
+
381
class MetaEmbeddings(nn.Module):
    """Embed the (code, endemic) metadata tokens and mix them with a small MLP.

    Args:
        embedding_sizes: embedding width per metadata field ('endemic', 'code').
        meta_sizes: vocabulary size per metadata field ('endemic', 'code').
        dropout: dropout probability inside the MLP; falsy disables dropout.
    """

    def __init__(self, embedding_sizes: dict, meta_sizes: dict, dropout: float = None):
        super().__init__()
        self.endemic_embedding = nn.Embedding(meta_sizes['endemic'], embedding_sizes['endemic'], max_norm=1.0)
        self.code_embedding = nn.Embedding(meta_sizes['code'], embedding_sizes['code'], max_norm=1.0)

        self.dim_embedding = sum(embedding_sizes.values())
        width = self.dim_embedding
        regularizer = nn.Dropout(p=dropout, inplace=False) if dropout else nn.Identity()
        # layer order (and therefore the state_dict indices) must stay as is
        self.embedding_net = nn.Sequential(
            nn.Linear(in_features=width, out_features=width, bias=True),
            nn.GELU(),
            nn.LayerNorm(width, eps=1e-06),
            regularizer,
            nn.Linear(in_features=width, out_features=width, bias=True),
            nn.GELU(),
            nn.LayerNorm(width, eps=1e-06),
        )

    def forward(self, meta):
        """Map token pairs ``meta[:, 0]`` (code) and ``meta[:, 1]`` (endemic) to features."""
        code_tokens, endemic_tokens = meta[:, 0], meta[:, 1]
        code_feature = self.code_embedding(code_tokens)
        endemic_feature = self.endemic_embedding(endemic_tokens)
        joined = torch.cat([code_feature, endemic_feature], dim=-1)
        return self.embedding_net(joined)
406
+
407
+
408
class Classifier(nn.Module):
    """Linear classification head with optional dropout on its input."""

    def __init__(self, num_classes: int, dim_embeddings: int, dropout: float = None):
        super().__init__()
        if dropout:
            self.dropout = nn.Dropout(p=dropout, inplace=False)
        else:
            self.dropout = nn.Identity()
        self.classifier = nn.Linear(in_features=dim_embeddings, out_features=num_classes, bias=True)

    def forward(self, embeddings):
        """Return class logits for `embeddings`."""
        return self.classifier(self.dropout(embeddings))
419
+
420
class Attention(nn.Module):
    """Attention pooling over the N instance features of one observation.

    A two-layer scorer (L=1024 -> D=256 -> K=1) rates every instance; the
    softmax of the scores over the instances weights their feature average.
    """

    def __init__(self):
        super().__init__()
        self.L = 1024
        self.D = 256
        self.K = 1

        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K),
        )

    def forward(self, x):
        """Pool `x` of shape (N, L) into ``M`` (1, L); also return weights ``A`` (1, 1, N)."""
        num_instances, feat_dim = x.shape
        bag = x.view(1, num_instances, feat_dim)

        scores = self.attention(bag)                  # 1 x N x 1
        scores = scores.transpose(1, 2)               # 1 x 1 x N
        A = torch.softmax(scores, dim=-1)             # normalised over the N instances
        M = torch.bmm(A, bag).squeeze(dim=1)          # 1 x L

        return M, A
443
+
444
+
445
class Model(nn.Module):
    """Backbone + attention pooling + metadata embeddings + linear classifier."""

    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.embedding_net = MetaEmbeddings(embedding_sizes=EMBEDDING_SIZES, meta_sizes=META_SIZES, dropout=0.25)
        self.mil_pooling = Attention()
        self.classifier = Classifier(num_classes=NUM_CLASSES, dim_embeddings=1024+128, dropout=0.25)

    def forward(self, img, meta):
        """Return ``(classifier_outputs, cat_features)`` for one observation.

        `cat_features` is the pooled image feature concatenated with the
        metadata feature; it is also the classifier input.
        """
        instance_features = self.feature_extractor(img)
        # the attention weights are computed but not used any further here
        pooled_features, _ = self.mil_pooling(instance_features)

        meta_features = self.embedding_net(meta)
        cat_features = torch.cat([pooled_features, meta_features], dim=-1)
        return self.classifier(cat_features), cat_features
462
+
463
class LossLayer(nn.Module):
    """Sum of the focal classification loss and the ArcFace embedding loss."""

    def __init__(self):
        super().__init__()
        self.arcloss = ArcFaceLoss(num_classes=NUM_CLASSES, embedding_size=1024+128, margin=28.6, scale=64)
        self.celoss = FocalLoss(gamma=FOCAL_LOSS['gamma'], class_dist=FOCAL_LOSS['class_dist'])

    def forward(self, classifier_outputs, cat_features, labels):
        """Return focal loss on the logits plus ArcFace loss on the embeddings."""
        total = self.celoss(classifier_outputs, labels)
        total = total + self.arcloss(cat_features, labels)
        return total
473
+
474
+
475
def load_checkpoints(model=None, ema_model=None, optimizer=None, scaler=None, arcloss=None):
    """Load state dicts from the paths configured in CHECKPOINTS.

    An object is restored only if it was passed in AND CHECKPOINTS holds a
    truthy path under its name. Model and EMA model are loaded with
    ``strict=False``; the other objects use the default loading behaviour.
    """
    targets = (
        ('model', model, {'strict': False}),
        ('ema_model', ema_model, {'strict': False}),
        ('optimizer', optimizer, {}),
        ('scaler', scaler, {}),
        ('arcloss', arcloss, {}),
    )
    for name, target, load_kwargs in targets:
        # .get(): CHECKPOINTS is not guaranteed to define every key (e.g. 'ema_model')
        path = CHECKPOINTS.get(name)
        if not path or target is None:
            continue
        target.load_state_dict(torch.load(path, map_location='cpu'), **load_kwargs)
        print(f"use {name} checkpoints: {path}")
    torch.cuda.empty_cache()
492
+
493
def resume_checkpoints(model=None, optimizer=None, scaler=None, arcloss=None):
    """Restore the passed objects from the RESUME_EPOCH checkpoints in MODEL_DIR.

    Files are expected as ``{MODEL_DIR}<prefix>_epoch{RESUME_EPOCH}.pth`` with
    the prefixes ``model``, ``optimizer``, ``mp_scaler`` and ``arcloss``.
    The model is loaded with ``strict=False``. Objects left as ``None`` are skipped.

    `arcloss` was added because the training entry point resumes the ArcFace
    loss through this function.
    """
    targets = (
        ('model', 'model', model, {'strict': False}),
        ('optimizer', 'optimizer', optimizer, {}),
        ('scaler', 'mp_scaler', scaler, {}),
        ('arcloss', 'arcloss', arcloss, {}),
    )
    for label, file_prefix, target, load_kwargs in targets:
        if target is None:
            continue
        path = f'{MODEL_DIR}{file_prefix}_epoch{RESUME_EPOCH}.pth'
        target.load_state_dict(torch.load(path, map_location='cpu'), **load_kwargs)
        print(f"use {label} checkpoints: {path}")
    torch.cuda.empty_cache()
505
+
506
+
507
def resume_logs(logs):
    """Extend `logs` in place with the stored history up to and including RESUME_EPOCH.

    Reads ``{MODEL_DIR}train_history.csv`` and, for every key of `logs`,
    appends the first ``RESUME_EPOCH + 1`` values of the matching column.
    Rows of later epochs are dropped: the training loop indexes the logs by
    epoch number, so extra rows would misalign the entries of resumed epochs.
    """
    old_logs = pd.read_csv(f"{MODEL_DIR}train_history.csv").iloc[:RESUME_EPOCH + 1]
    for key in logs:
        logs[key].extend(old_logs[key].tolist())
511
+
512
+ ######################## Optimizer #####################################
513
def get_optm_group(module):
    """Split the parameters of `module` into weight-decay and no-weight-decay names.

    Biases and the weights of normalisation/embedding layers are not decayed;
    weights of linear/convolution layers (and of timm's GlobalResponseNormMlp)
    are decayed.

    Returns:
        (param_dict, decay, no_decay): a name -> parameter dict and two sets of
        full parameter names.
    """
    decay_types = (torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv1d, timm.layers.GlobalResponseNormMlp)
    no_decay_types = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.LayerNorm, torch.nn.Embedding)

    decay, no_decay = set(), set()
    for mod_name, mod in module.named_modules():
        # named_parameters() recurses, so nested parameters are seen from every ancestor too
        for par_name, _ in mod.named_parameters():
            full_name = f'{mod_name}.{par_name}' if mod_name else par_name

            if par_name.endswith('bias'):
                # biases are never decayed
                no_decay.add(full_name)
                continue
            if not par_name.endswith('weight'):
                continue
            if isinstance(mod, decay_types):
                decay.add(full_name)
            elif isinstance(mod, no_decay_types):
                no_decay.add(full_name)

    # every parameter must land in exactly one of the two sets
    param_dict = dict(module.named_parameters())
    inter_params = decay & no_decay
    union_params = decay | no_decay
    assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
    assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
        % (str(param_dict.keys() - union_params), )

    return param_dict, decay, no_decay
550
+
551
+
552
def get_warmup_optimizer(model):
    """Create an AdamW optimizer for the embedding net and the classifier.

    Each sub-module contributes two parameter groups: decayed weights
    (weight_decay 0.05) and non-decayed parameters (weight_decay 0.0).
    """
    parts = (
        (model.embedding_net, 'embeddings'),
        (model.classifier, 'classifier'),
    )
    params_group = []
    for submodule, lr_key in parts:
        param_dict, decay, no_decay = get_optm_group(submodule)
        for names, weight_decay in ((decay, 0.05), (no_decay, 0.0)):
            params_group.append({
                "params": [param_dict[name] for name in sorted(names)],
                "weight_decay": weight_decay,
                'lr': LEARNING_RATE[lr_key],
            })

    return torch.optim.AdamW(params_group)
565
+
566
+
567
def get_after_warmup_optimizer(model, old_opt):
    """Create an AdamW optimizer for the CNN backbone and append the groups of `old_opt`."""
    backbone = model.feature_extractor.conv_backbone
    new_opt = create_optimizer_v2(
        backbone,
        opt='adamw',
        filter_bias_and_bn=True,
        weight_decay=1e-8,
        layer_decay=0.85,
        lr=LEARNING_RATE['cnn'],
    )

    # carry over the parameter groups that were trained during warmup
    for previous_group in old_opt.param_groups:
        new_opt.add_param_group(previous_group)

    return new_opt
575
+
576
+
577
+ # #################### Model Warmup #####################################
578
+
579
def warmup_start(model):
    """Freeze the CNN backbone and the metadata embedding net (requires_grad = False)."""
    for weight in model.feature_extractor.conv_backbone.parameters():
        weight.requires_grad = False
    print('--> freeze feature_extractor.conv_backbone during warmup phase')

    for weight in model.embedding_net.parameters():
        weight.requires_grad = False
    print('--> freeze feature_extractor.embedding_net during warmup phase')
589
+
590
def warmup_end(model):
    """Unfreeze the CNN backbone and the metadata embedding net (requires_grad = True)."""
    for weight in model.feature_extractor.conv_backbone.parameters():
        weight.requires_grad = True
    print('--> unfreeze feature_extractor.conv_backbone after warmup phase')

    for weight in model.embedding_net.parameters():
        weight.requires_grad = True
    print('--> unfreeze feature_extractor.embedding_net during warmup phase')
600
+
601
+
602
+ # #################### Train Loop #####################################
603
+
604
+ # ### train
605
def _collect_metrics(loss_metric, metrics, metric_ccm):
    """Compute the running metrics as Python floats, then reset all trackers."""
    epoch_metrics = metrics.compute()
    results = {
        'loss': loss_metric.compute().cpu().item(),
        'acc': epoch_metrics['acc'].cpu().item(),
        'acc_top3': epoch_metrics['top3_acc'].cpu().item(),
        'f1': epoch_metrics['f1'].cpu().item(),
        'f1country': metric_ccm.compute().detach().cpu().item(),
    }
    loss_metric.reset()
    metrics.reset()
    metric_ccm.reset()
    return results


def _forward_batch(net, loss_fn, batch, device):
    """Move one collated batch to `device` and run `net` and the loss under fp16 autocast."""
    inputs, labels, ccm, meta = (item.to(device, non_blocking=True) for item in batch)
    with autocast(device_type='cuda', dtype=torch.float16):
        outputs, embeddings = net(inputs, meta)
        loss = loss_fn(outputs, embeddings, labels)
    return outputs, loss, labels, ccm


def _update_trackers(outputs, loss, labels, ccm, loss_metric, metrics, metric_ccm):
    """Feed one batch into the loss/accuracy/F1 trackers."""
    loss_metric.update(loss.detach())
    preds = outputs.softmax(dim=-1).detach()
    metrics.update(preds, labels)
    # `ccm` re-weights the class probabilities with the code -> class mapping
    metric_ccm.update(preds * ccm, labels)


def _train_one_epoch(model, ema_model, loader, loss_fn, optimizer, scaler, grad_acc, trackers, device):
    """Train for one epoch with gradient accumulation; return the epoch metrics."""
    model.train()
    optimizer.zero_grad(set_to_none=True)

    for batch_idx, batch in enumerate(loader):
        outputs, loss, labels, ccm = _forward_batch(model, loss_fn, batch, device)
        # scale the loss so the accumulated gradient is the mean over `grad_acc` batches
        scaler.scale(loss / grad_acc).backward()
        _update_trackers(outputs, loss, labels, ccm, *trackers)

        if (batch_idx + 1) % grad_acc == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            ema_model.update(model)

    results = _collect_metrics(*trackers)
    # gradients of an incomplete trailing accumulation window are discarded
    optimizer.zero_grad(set_to_none=True)
    return results


def _evaluate(net, loader, loss_fn, trackers, device):
    """Run `net` over `loader` in eval mode without gradients; return the metrics."""
    net.eval()
    with torch.no_grad():
        for batch in loader:
            outputs, loss, labels, ccm = _forward_batch(net, loss_fn, batch, device)
            _update_trackers(outputs, loss, labels, ccm, *trackers)
    return _collect_metrics(*trackers)


def _format_results(results, prefix=''):
    """Render a metrics dict as 'name: value' pairs with five decimals."""
    return ', '.join(f"{prefix}{name}: {value:.5f}" for name, value in results.items())


def main():
    """Train the attention-pooling model, validate each epoch and finally the EMA model.

    When RESUME_EPOCH > 0 the model, ArcFace loss, optimizer, scaler and the
    log history are restored from MODEL_DIR first. Per epoch, the history CSV
    and the model/EMA/optimizer/scaler/ArcFace checkpoints are written to
    MODEL_DIR, and the metrics are sent to wandb when WANDB is set.
    """
    device = torch.device('cuda:1')
    torch.cuda.set_device(device)

    # prepare the datasets
    train_loader, valid_loader = get_dataloaders(
        imgsize_train=TRANSFORMS['IMAGE_SIZE_TRAIN'],
        imgsize_val=TRANSFORMS['IMAGE_SIZE_VAL'],
    )

    # instantiate the model
    model = Model().to(device)
    if RESUME_EPOCH > 0:
        resume_checkpoints(model=model)
    ema_model = ModelEmaV2(model, decay=0.9998, device=device)
    warmup_start(model)

    loss_fn = LossLayer().to(device)
    if RESUME_EPOCH > 0:
        # loaded inline: resume_checkpoints() has no parameter for the ArcFace loss
        arcloss_path = f'{MODEL_DIR}arcloss_epoch{RESUME_EPOCH}.pth'
        loss_fn.arcloss.load_state_dict(torch.load(arcloss_path, map_location='cpu'))
        print(f"use arcloss checkpoints: {arcloss_path}")

    # optimizer & mixed-precision scaler
    optimizer = get_warmup_optimizer(model)
    optimizer.add_param_group({"params": loss_fn.arcloss.parameters(), "weight_decay": 0.0, 'lr': LEARNING_RATE['classifier']})

    scaler = GradScaler()
    if RESUME_EPOCH > 0:
        resume_checkpoints(optimizer=optimizer, scaler=scaler)

    # The attention module is added AFTER the optimizer state was restored, so the
    # restored state only has to match the groups created above.
    # The pooling module is registered on Model as `mil_pooling`.
    param_dict, decay, no_decay = get_optm_group(model.mil_pooling)
    optimizer.add_param_group({"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": 0.05, 'lr': LEARNING_RATE['attention']})
    optimizer.add_param_group({"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0, 'lr': LEARNING_RATE['attention']})

    # running metrics during training / validation
    loss_metric = MeanMetric().to(device)
    metrics = MetricCollection(metrics={
        'acc': MulticlassAccuracy(num_classes=NUM_CLASSES, average='macro'),
        'top3_acc': MulticlassAccuracy(num_classes=NUM_CLASSES, average='macro', top_k=3),
        'f1': MulticlassF1Score(num_classes=NUM_CLASSES, average='macro'),
    }).to(device)
    metric_ccm = MulticlassF1Score(num_classes=NUM_CLASSES, average='macro').to(device)
    trackers = (loss_metric, metrics, metric_ccm)

    # a configured value of 0 means "no accumulation" and must not reach the modulo
    grad_acc = BATCH_SIZE['grad_acc'] or 1

    # start time of training
    start_training = time.perf_counter()
    logs = {'loss': [], 'acc': [], 'acc_top3': [], 'f1': [], 'f1country': [], 'val_loss': [], 'val_acc': [], 'val_acc_top3': [], 'val_f1': [], 'val_f1country': []}
    if RESUME_EPOCH > 0:
        resume_logs(logs)

    start_epoch = RESUME_EPOCH + 1 if RESUME_EPOCH > 0 else 0
    for epoch in range(start_epoch, NUM_EPOCHS):
        epoch_start = time.perf_counter()
        print(f'Epoch {epoch+1}/{NUM_EPOCHS}')

        ############################## train phase ####################################
        train_results = _train_one_epoch(model, ema_model, train_loader, loss_fn, optimizer, scaler, grad_acc, trackers, device)
        for name, value in train_results.items():
            logs[name].append(value)
        print(_format_results(train_results), end=' || ')
        torch.cuda.empty_cache()

        ############################## valid phase ####################################
        valid_results = _evaluate(model, valid_loader, loss_fn, trackers, device)
        for name, value in valid_results.items():
            logs[f'val_{name}'].append(value)
        print(_format_results(valid_results, prefix='val_'), end=' || ')
        torch.cuda.empty_cache()

        # save logs as csv
        logs_df = pd.DataFrame(logs)
        logs_df.to_csv(f'{MODEL_DIR}train_history.csv', index_label='epoch', sep=',', encoding='utf-8')

        if WANDB:
            # log the values just appended for this epoch
            wandb.log({name: values[-1] for name, values in logs.items()}, step=epoch)

        # save trained model for each epoch
        torch.save(model.state_dict(), f'{MODEL_DIR}model_epoch{epoch}.pth')
        torch.save(ema_model.module.state_dict(), f'{MODEL_DIR}ema_model_epoch{epoch}.pth')
        torch.save(optimizer.state_dict(), f'{MODEL_DIR}optimizer_epoch{epoch}.pth')
        torch.save(scaler.state_dict(), f'{MODEL_DIR}mp_scaler_epoch{epoch}.pth')
        torch.save(loss_fn.arcloss.state_dict(), f'{MODEL_DIR}arcloss_epoch{epoch}.pth')

        epoch_end = time.perf_counter()
        print(f"epoch runtime: {epoch_end-epoch_start:5.3f} sec.")

        del logs_df
        torch.cuda.empty_cache()

    ################################## EMA Model Validation ################################
    del model
    torch.cuda.empty_cache()

    # evaluate the averaged weights; the raw model was deleted above
    ema_results = _evaluate(ema_model.module, valid_loader, loss_fn, trackers, device)
    ema_summary = _format_results(ema_results, prefix='ema_')
    print(ema_summary)
    with open(f'{MODEL_DIR}ema_results.txt', 'w') as f:
        print(ema_summary, file=f)

    plot_history(logs)
    # end time of training
    end_training = time.perf_counter()
    print(f'Training succeeded in {(end_training - start_training):5.3f}s')

    if WANDB:
        wandb.finish()


if __name__ == "__main__":
    main()
850
+
851
+
852
+
853
+
exp5/convnext2b_exp5_TTAattention.py ADDED
@@ -0,0 +1,829 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, time, pickle, shutil
2
+ import pandas as pd
3
+ import numpy as np
4
+
5
+ from PIL import Image, ImageFile
6
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from torch.cuda.amp import GradScaler
12
+ from torch import autocast
13
+
14
+ import torchvision.transforms as transforms
15
+
16
+ import timm
17
+ from timm.models import create_model
18
+ from timm.utils import ModelEmaV2
19
+
20
+ from timm.optim import create_optimizer_v2
21
+
22
+ from torchmetrics import MeanMetric
23
+ from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score
24
+ from torchmetrics import MetricCollection
25
+
26
+ from pytorch_metric_learning.losses import ArcFaceLoss
27
+
28
+ import wandb
29
+
30
+ import matplotlib.pyplot as plt
31
+
32
+
33
+ # ### parameters
34
+ ################## Settings #############################
35
+ #os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
36
+ torch.backends.cudnn.benchmark = True
37
+
38
+ ################## Data Paths ##########################
39
+ MODEL_DIR = "./convnext2b_TTAattention/"
40
+
41
+ if not os.path.exists(MODEL_DIR):
42
+ os.makedirs(MODEL_DIR)
43
+ shutil.copyfile('./convnext2b_exp5_TTAattention.py', f'{MODEL_DIR}convnext2b_exp5_TTAattention.py')
44
+
45
+ TRAIN_DATA_DIR = "/SnakeCLEF2023-large_size/" # train imgs. path
46
+ ADD_TRAIN_DATA_DIR = "/HMP/" # add. train imgs. path
47
+ VAL_DATA_DIR = "/SnakeCLEF2023-large_size/" # val imgs. path
48
+
49
+ TRAINDATA_CONFIG = "/SnakeCLEF2023-TrainMetadata-iNat.csv"
50
+ ADD_TRAINDATA_CONFIG = "/SnakeCLEF2023-TrainMetadata-HM.csv"
51
+ VALIDDATA_CONFIG = "/SnakeCLEF2023-ValMetadata.csv"
52
+
53
+ MISSING_FILES = "../missing_train_data.csv" # csv with missing img. files that will be filtered out
54
+
55
+ CCM = "../code_class_mapping_obid.csv" # csv to metadata code to snake species dist.
56
+
57
+
58
+ NUM_CLASSES = 1784
59
+
60
+ ################## Hyperparameters ########################
61
+ NUM_EPOCHS = 50
62
+ WARMUP_EPOCHS = 0
63
+ RESUME_EPOCH = 39 # resume model, optimizer from epoch 39 of experiment 4, checkpoint files need to be copied to the MODEL_DIR folder
64
+
65
+
66
+ LEARNING_RATE = {
67
+ 'cnn': 1e-05,
68
+ 'embeddings': 1e-04,
69
+ 'classifier': 1e-04,
70
+ 'attention': 1e-04,
71
+ }
72
+
73
+ BATCH_SIZE = {
74
+ 'train': 42,
75
+ 'valid': 48,
76
+ 'grad_acc': 3, # gradient acc. steps with 'train' of batch sizes, global batch size = 'grad_acc' * 'train'
77
+ }
78
+
79
+ BATCH_SIZE_AFTER_WARMUP = {
80
+ 'train': 42,
81
+ 'valid': 48,
82
+ 'grad_acc': 3, # gradient acc. steps with 'train' of batch sizes, global batch size = 'grad_acc' * 'train'
83
+ }
84
+
85
+ TRANSFORMS = {
86
+ 'IMAGE_SIZE_TRAIN': 544,
87
+ 'IMAGE_SIZE_VAL': 544,
88
+ 'RandAug' : {
89
+ 'm': 7,
90
+ 'n': 2
91
+ },
92
+ 'num_rand_crops': 5, # num. of random crops during training per image instance
93
+ }
94
+
95
+
96
+ # ############# Focal Loss ####################
97
# Snake species frequencies counted on observation_id level, with the
# observation_ids of missing image files already taken into account.
# The file is opened in a `with` block so the handle is closed right after loading.
# NOTE: pickle must only be used on trusted, locally produced files.
with open("../classDist_HMP_missedRemoved.p", "rb") as _class_dist_file:
    _CLASS_DIST_COUNTS = pickle.load(_class_dist_file)['counts']

FOCAL_LOSS = {
    'class_dist': _CLASS_DIST_COUNTS,  # per-class counts used to derive the loss class weights
    'gamma': 0.5,                      # focusing exponent of the focal loss
}
101
+
102
+
103
+ ############# Checkpoints ####################
104
+ CHECKPOINTS = {
105
+ 'fe_cnn': None,
106
+ 'model': None,
107
+ 'optimizer': None,
108
+ 'scaler': None,
109
+ 'arcloss': None,
110
+ }
111
+
112
+ # ####### Embedding Token Mappings ########################
113
+ META_SIZES = {'endemic': 2, 'code': 212}
114
+ EMBEDDING_SIZES = {'endemic': 64, 'code': 64}
115
+
116
# Token lookup tables for the metadata embeddings. Each file is opened in a
# `with` block so its handle is closed deterministically after loading.
# NOTE: pickle must only be used on trusted, locally produced files.
with open("../meta_code_tokens.p", "rb") as _code_tokens_file:
    CODE_TOKENS = pickle.load(_code_tokens_file)
with open("../meta_endemic_tokens.p", "rb") as _endemic_tokens_file:
    ENDEMIC_TOKENS = pickle.load(_endemic_tokens_file)
118
+
119
+ ################### WandB ##################
120
+ WANDB = False
121
+
122
+ if WANDB:
123
+ wandb.init(
124
+ entity="snakeclef2023", # our team at wandb
125
+
126
+ # set the wandb project where this run will be logged
127
+ project="exp5", # -> define sub-projects here, e.g. experiments with MetaFormer or CNNs...
128
+
129
+ # define a name for this run
130
+ name="TTAattention",
131
+
132
+ # track all the used hyperparameters here, config is just a dict object so any key:value pairs are possible
133
+ config={
134
+ "learning_rate": LEARNING_RATE,
135
+ "focal_loss": FOCAL_LOSS,
136
+ "architecture": "convnextv2_base.fcmae_ft_in22k_in1k_384",
137
+ "pretrained": "iNat21",
138
+ "dataset": f"snakeclef2023, additional train data: {True if ADD_TRAINDATA_CONFIG else False}",
139
+ "epochs": NUM_EPOCHS,
140
+ "transforms": TRANSFORMS,
141
+ "checkpoints": CHECKPOINTS,
142
+ "model_dir": MODEL_DIR
143
+ # ... any other hyperparameter that is necessary to reproduce the result
144
+ },
145
+ save_code=True, # save the script file as backup
146
+ dir=MODEL_DIR # locally folder where wandb log files are saved
147
+ )
148
+
149
+
150
+
151
+
152
+ ##################### Dataset & AugTransforms #####################################
153
+ # ### dataset & loaders
154
class SnakeTrainDataset(Dataset):
    """Dataset returning ``(img, label, ccm, meta)`` for one row of a metadata table.

    Args:
        data: table with the columns ``class_id``, ``code``, ``endemic`` and
            ``image_path`` (accessed via ``.iloc`` and attribute lookup).
        ccm: code-class mapping; ``ccm[code]`` must return an object offering
            ``.to_numpy()`` (presumably a DataFrame with one column per code —
            verify against the CCM csv).
        transform: optional image pipeline applied to the loaded RGB image.
            When ``None`` the PIL image is returned unchanged.
    """

    def __init__(self, data, ccm, transform=None):
        self.data = data
        self.transform = transform  # image augmentation / preprocessing pipeline
        self.code_class_mapping = ccm
        self.code_tokens = CODE_TOKENS
        self.endemic_tokens = ENDEMIC_TOKENS

    def __len__(self):
        """Return the number of rows in the metadata table."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Load the image of row ``index`` together with its label and metadata tokens."""
        obj = self.data.iloc[index]
        label = obj.class_id

        # Codes / endemic flags without a token fall back to a default key.
        # NOTE(review): assumes "unknown" and False are valid keys of the token
        # tables (and "unknown" a valid ccm key) — confirm against the pickles.
        code = obj.code if obj.code in self.code_tokens else "unknown"
        endemic = obj.endemic if obj.endemic in self.endemic_tokens else False

        # convert() returns a fully loaded copy, so the file handle of the
        # lazily opened source image can be released immediately.
        with Image.open(obj.image_path) as raw_img:
            img = raw_img.convert("RGB")

        ccm = torch.tensor(self.code_class_mapping[code].to_numpy())  # code class mapping
        meta = torch.tensor([self.code_tokens[code], self.endemic_tokens[endemic]])  # metadata tokens

        # The constructor default is transform=None; skip the call in that case
        # instead of failing with "'NoneType' object is not callable".
        if self.transform is not None:
            img = self.transform(img)

        return (img, label, ccm, meta)
179
+
180
+
181
+ # valid data preprocessing pipeline
182
def _stack_crops_to_tensor(crops):
    """Convert a sequence of PIL crops into one stacked 4D tensor."""
    to_tensor = transforms.ToTensor()
    return torch.stack([to_tensor(crop) for crop in crops])


def get_val_preprocessing(img_size):
    """Build the validation pipeline: resize, five fixed crops, stack, normalize.

    A named module-level function replaces the former anonymous lambda so the
    pipeline stays picklable when DataLoader workers are started with the
    ``spawn`` method. The result of the pipeline is unchanged.

    Args:
        img_size: edge length of each square crop.

    Returns:
        A ``transforms.Compose`` turning a PIL image into a normalized tensor
        holding the five crops stacked along a new leading dimension.
    """
    print(f'IMG_SIZE_VAL: {img_size}')
    return transforms.Compose([
        transforms.Resize(int(img_size * 1.25)),  # enlarge before taking the five crops
        transforms.FiveCrop((img_size, img_size)),  # yields a tuple of PIL images
        transforms.Lambda(_stack_crops_to_tensor),  # -> 4D tensor of crops
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
192
+
193
class IdentityTransform:
    """Transform that hands its input back untouched."""

    def __call__(self, x):
        """Return ``x`` itself, without copying or modifying it."""
        return x
196
+
197
+
198
class MultipleRandomCropsWithAugmentation:
    """Produce several independently cropped and augmented views of one image.

    Each view is a random square crop, followed by random horizontal and
    vertical flips and RandAugment, then converted to a tensor. The
    RandAugment settings are read from the module-level ``TRANSFORMS`` dict.

    Args:
        img_size: edge length of each square crop.
        num_crops: number of views generated per call.
    """

    def __init__(self, img_size, num_crops=5):
        super().__init__()
        rand_aug_cfg = TRANSFORMS['RandAug']
        self.num_crops = num_crops
        self.random_crop = transforms.RandomCrop((img_size, img_size))
        self.augment = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandAugment(num_ops=rand_aug_cfg['n'], magnitude=rand_aug_cfg['m'])
        ])
        self.to_tensor = transforms.ToTensor()

    def __call__(self, x):
        """Return a tensor stacking ``num_crops`` augmented crops of ``x``."""
        views = []
        for _ in range(self.num_crops):
            crop = self.random_crop(x)
            views.append(self.to_tensor(self.augment(crop)))
        return torch.stack(views)
213
+
214
+ # train data augmentation/ preprocessing pipeline
215
def get_train_augmentation_preprocessing(img_size, rang_aug):
    """Build the training pipeline: resize, multiple random augmented crops, normalize.

    Args:
        img_size: edge length of each square crop.
        rang_aug: accepted for interface compatibility but not read here; the
            augmentation settings come from the module-level ``TRANSFORMS``.

    Returns:
        A ``transforms.Compose`` producing a normalized tensor of stacked crops.
    """
    print(f'IMG_SIZE_TRAIN: {img_size}')
    steps = [
        # enlarge the image so the random crops have room to move
        transforms.Resize(int(img_size * 1.25)),
        MultipleRandomCropsWithAugmentation(img_size, TRANSFORMS['num_rand_crops']),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    return transforms.Compose(steps)
222
+
223
+
224
+
225
def get_datasets(train_transfroms, val_transforms):
    """Read the metadata CSVs and build the train and validation datasets.

    Rows listed in ``MISSING_FILES`` are removed from the training table, image
    paths are prefixed with their data directories, the optional additional
    training table is appended, and both tables are shuffled with a fixed seed.

    Args:
        train_transfroms: image pipeline for the training dataset.
        val_transforms: image pipeline for the validation dataset.

    Returns:
        Tuple ``(train_dataset, valid_dataset)`` of ``SnakeTrainDataset``.
    """
    # Explicit NaN markers; pandas' defaults are disabled via keep_default_na=False.
    nan_values = ['', '#N/A', '#N/A N/A', '#NA', '-1.#IND', '-1.#QNAN', '-NaN', '-nan', '1.#IND', '1.#QNAN', '<NA>', 'N/A', 'NULL', 'NaN', 'n/a', 'nan', 'null']
    csv_options = dict(na_values=nan_values, keep_default_na=False)

    train_data = pd.read_csv(TRAINDATA_CONFIG, **csv_options)
    missing_train_data = pd.read_csv(MISSING_FILES, **csv_options)
    valid_data = pd.read_csv(VALIDDATA_CONFIG, **csv_options)

    # Anti-join: keep only training rows that have no match in the missing-files table.
    merged = pd.merge(train_data, missing_train_data, how='outer', indicator=True)
    kept_columns = ["observation_id", "endemic", "binomial_name", "code", "image_path", "class_id", "subset"]
    train_data = merged.loc[merged._merge == 'left_only', kept_columns]

    # Prefix the relative image paths with their data directories.
    train_data["image_path"] = TRAIN_DATA_DIR + train_data['image_path']
    valid_data["image_path"] = VAL_DATA_DIR + valid_data['image_path']

    # Optionally append the additional training table.
    if ADD_TRAINDATA_CONFIG:
        add_train_data = pd.read_csv(ADD_TRAINDATA_CONFIG, **csv_options)
        add_train_data["image_path"] = ADD_TRAIN_DATA_DIR + add_train_data['image_path']
        train_data = pd.concat([train_data, add_train_data], axis=0)

    print(f'train data shape: {train_data.shape}')

    # Deterministic shuffle of both tables.
    train_data = train_data.sample(frac=1, random_state=1).reset_index(drop=True)
    valid_data = valid_data.sample(frac=1, random_state=1).reset_index(drop=True)

    # Code-class mapping table shared by both datasets.
    ccm = pd.read_csv(CCM, **csv_options)

    train_dataset = SnakeTrainDataset(train_data, ccm, transform=train_transfroms)
    valid_dataset = SnakeTrainDataset(valid_data, ccm, transform=val_transforms)
    return train_dataset, valid_dataset
263
+
264
+
265
def get_dataloaders(imgsize_train, imgsize_val, rand_aug):
    """Create the train and validation ``DataLoader`` objects.

    Batch sizes are read from the module-level ``BATCH_SIZE`` dict at call time.

    Args:
        imgsize_train: crop size for the training pipeline.
        imgsize_val: crop size for the validation pipeline.
        rand_aug: forwarded to ``get_train_augmentation_preprocessing``.

    Returns:
        Tuple ``(train_loader, valid_loader)``.
    """
    train_pipeline = get_train_augmentation_preprocessing(imgsize_train, rand_aug)
    val_pipeline = get_val_preprocessing(imgsize_val)
    train_dataset, valid_dataset = get_datasets(train_transfroms=train_pipeline, val_transforms=val_pipeline)

    # Options shared by both loaders.
    shared = dict(num_workers=6, pin_memory=True)
    # Training: shuffled, incomplete last batch dropped.
    train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE['train'], drop_last=True, **shared)
    # Validation: fixed order, every sample kept.
    valid_loader = DataLoader(dataset=valid_dataset, shuffle=False, batch_size=BATCH_SIZE['valid'], drop_last=False, **shared)
    return train_loader, valid_loader
275
+
276
+
277
+ # #################### plot train history #########################
278
+
279
def plot_history(logs):
    """Plot train vs. validation loss, accuracy and f1 and save the figure as SVG.

    Args:
        logs: mapping with the keys ``loss``, ``val_loss``, ``acc``,
            ``val_acc``, ``f1`` and ``val_f1``, each a per-epoch sequence.
    """
    fig, ax = plt.subplots(3, 1, figsize=(8, 12))

    # (train key, valid key, y label, y limits, title) per panel; the loss
    # panel is capped at -log(1/NUM_CLASSES).
    panels = [
        ('loss', 'val_loss', "loss", [0, -np.log(1/NUM_CLASSES)], "train- vs. valid loss"),
        ('acc', 'val_acc', "accuracy", [0, 1.01], "train- vs. valid accuracy"),
        ('f1', 'val_f1', "f1", [0, 1.01], "train- vs. valid f1"),
    ]
    for axis, (train_key, valid_key, ylabel, ylim, title) in zip(ax, panels):
        axis.plot(logs[train_key], label="train data")
        axis.plot(logs[valid_key], label="valid data")
        axis.legend(loc="best")
        axis.set_ylabel(ylabel)
        axis.set_ylim(ylim)
        axis.set_title(title)

    # Only the bottom panel carries the x-axis label.
    ax[2].set_xlabel("epochs")

    fig.savefig(f'{MODEL_DIR}model_history.svg', dpi=150, format="svg")
    plt.show()
308
+
309
+ #################### Focal Loss ##################################
310
class FocalLoss(nn.Module):
    """Multi-class focal loss with optional class-frequency based weights.

    Args:
        gamma: focusing exponent applied to ``(1 - p)``.
        class_dist: per-class sample counts (array-like). When given, class
            ``c`` is weighted with ``(1 - 0.999) / (1 - 0.999 ** count_c)``;
            when ``None`` all ``NUM_CLASSES`` classes get weight 1.
        reduction: reduction mode forwarded to ``nll_loss``.
        device: device on which the weight tensor is created.
    """

    def __init__(self, gamma, class_dist=None, reduction='mean', device='cuda'):
        super().__init__()
        self.gamma = gamma
        if class_dist is not None:
            # asarray lets plain lists work as well as numpy arrays
            counts = np.asarray(class_dist, dtype=np.float64)
            weight = torch.tensor((1.0 - 0.999) / (1.0 - 0.999 ** counts), dtype=torch.float32, device=device)
        else:
            weight = torch.ones(NUM_CLASSES, device=device)
        # Non-persistent buffer: follows module.to(...) but stays out of the
        # state_dict, so existing checkpoints keep loading unchanged.
        self.register_buffer('weight', weight, persistent=False)
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs: class logits of shape [N, C].
            targets: class indices of shape [N], int64.

        Returns:
            The weighted focal loss reduced according to ``self.reduction``.
        """
        logpt = torch.nn.functional.log_softmax(inputs, dim=1)
        pt = torch.exp(logpt)
        # For gamma < 1 the derivative of (1 - pt) ** gamma is infinite at
        # pt == 1, which turns into NaN gradients (0 * inf) once any class
        # probability saturates. Clamping only alters the exact-zero case.
        one_minus_pt = (1 - pt).clamp(min=torch.finfo(pt.dtype).tiny)
        focal_logpt = one_minus_pt ** self.gamma * logpt
        return torch.nn.functional.nll_loss(focal_logpt, targets, weight=self.weight, reduction=self.reduction)
330
+
331
+
332
+ # #################### Model #####################################
333
+
334
class FeatureExtractor(nn.Module):
    """Image backbone wrapper around a timm ConvNeXt-V2 base model without classifier head.

    When ``CHECKPOINTS['fe_cnn']`` is set, its state dict is loaded strictly
    over the pretrained backbone weights.
    """

    def __init__(self):
        super().__init__()
        # num_classes=0 removes the classification head.
        self.conv_backbone = create_model('convnextv2_base.fcmae_ft_in22k_in1k_384', pretrained=True, num_classes=0, drop_path_rate=0.2)
        fe_checkpoint = CHECKPOINTS['fe_cnn']
        if fe_checkpoint:
            backbone_state = torch.load(fe_checkpoint, map_location='cpu')
            self.conv_backbone.load_state_dict(backbone_state, strict=True)
            print(f"use FE_CHECKPOINTS: {fe_checkpoint}")
            # NOTE(review): the original indentation was lost; this call is
            # assumed to belong to the checkpoint branch — confirm.
            torch.cuda.empty_cache()

    def forward(self, img):
        """Return the backbone features computed for ``img``."""
        return self.conv_backbone(img)
346
+
347
+
348
class MetaEmbeddings(nn.Module):
    """Embed the ``code`` and ``endemic`` metadata tokens and mix them with a small MLP.

    Args:
        embedding_sizes: embedding width per metadata field (keys ``endemic``, ``code``).
        meta_sizes: number of distinct tokens per metadata field (same keys).
        dropout: dropout probability between the two MLP stages; a falsy
            value disables dropout.
    """

    def __init__(self, embedding_sizes: dict, meta_sizes: dict, dropout: float = None):
        super().__init__()
        # max_norm=1.0 makes nn.Embedding renormalize looked-up rows in place.
        self.endemic_embedding = nn.Embedding(meta_sizes['endemic'], embedding_sizes['endemic'], max_norm=1.0)
        self.code_embedding = nn.Embedding(meta_sizes['code'], embedding_sizes['code'], max_norm=1.0)

        self.dim_embedding = sum(embedding_sizes.values())
        width = self.dim_embedding
        stages = [
            nn.Linear(in_features=width, out_features=width, bias=True),
            nn.GELU(),
            nn.LayerNorm(width, eps=1e-06),
            nn.Dropout(p=dropout, inplace=False) if dropout else nn.Identity(),
            nn.Linear(in_features=width, out_features=width, bias=True),
            nn.GELU(),
            nn.LayerNorm(width, eps=1e-06),
        ]
        self.embedding_net = nn.Sequential(*stages)

    def forward(self, meta):
        """Return the metadata feature for each row of ``meta``.

        Column 0 of ``meta`` holds the code token, column 1 the endemic token.
        """
        code_tokens, endemic_tokens = meta[:, 0], meta[:, 1]
        code_feature = self.code_embedding(code_tokens)
        endemic_feature = self.endemic_embedding(endemic_tokens)
        joined = torch.concat([code_feature, endemic_feature], dim=-1)
        return self.embedding_net(joined)
373
+
374
+
375
class Attention(nn.Module):
    """Attention-weighted average over the views (crops) belonging to one instance.

    A two-layer scorer assigns one score per view; the scores are
    softmax-normalized across the views of an instance and used to average
    that instance's feature vectors.

    Args:
        in_features: width ``L`` of each incoming feature vector (default 1024).
        hidden_features: hidden width ``D`` of the scorer (default 256).
    """

    def __init__(self, in_features=1024, hidden_features=256):
        super().__init__()
        self.L = in_features
        self.D = hidden_features
        self.K = 1  # one attention score per view

        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )

    def forward(self, x, imgs_per_instance=5):
        """Pool the view features of each instance.

        Args:
            x: features whose total size is ``b * imgs_per_instance * L``;
                regrouped into ``(b, imgs_per_instance, L)``.
            imgs_per_instance: number of consecutive rows forming one instance.

        Returns:
            Tuple ``(M, A)``: pooled features of shape ``(b, L)`` and attention
            weights of shape ``(b, 1, imgs_per_instance)``.
        """
        # reshape (instead of view) also accepts non-contiguous inputs
        x = x.reshape(-1, imgs_per_instance, self.L)

        A = self.attention(x)                   # (b, n, 1)
        A = torch.transpose(A, 2, 1)            # (b, 1, n)
        A = nn.functional.softmax(A, dim=-1)    # normalize over the n views
        M = torch.bmm(A, x).squeeze(dim=1)      # (b, 1, n) @ (b, n, L) -> (b, L)

        return M, A
397
+
398
+
399
class Classifier(nn.Module):
    """Linear classification head with optional dropout on its input.

    Args:
        num_classes: number of output logits.
        dim_embeddings: width of the incoming feature vector.
        dropout: dropout probability; a falsy value disables dropout.
    """

    def __init__(self, num_classes: int, dim_embeddings: int, dropout: float = None):
        super().__init__()
        if dropout:
            self.dropout = nn.Dropout(p=dropout, inplace=False)
        else:
            self.dropout = nn.Identity()
        self.classifier = nn.Linear(in_features=dim_embeddings, out_features=num_classes, bias=True)

    def forward(self, embeddings):
        """Return the class logits for ``embeddings``."""
        return self.classifier(self.dropout(embeddings))
410
+
411
+
412
class Model(nn.Module):
    """Full network: image backbone, attention pooling over views, metadata embeddings, classifier."""

    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.attention = Attention()
        self.embedding_net = MetaEmbeddings(embedding_sizes=EMBEDDING_SIZES, meta_sizes=META_SIZES, dropout=0.25)
        # classifier input = 1024 pooled image features + 128 metadata features
        fused_dim = 1024 + 128
        self.classifier = Classifier(num_classes=NUM_CLASSES, dim_embeddings=fused_dim, dropout=0.25)

    def forward(self, img, meta):
        """Return ``(logits, fused_features)`` for a batch.

        Args:
            img: image views fed to the backbone.
            meta: metadata tokens fed to the embedding network.
        """
        view_features = self.feature_extractor(img)
        # the attention weights are computed but not used further here
        pooled_features, _attention_weights = self.attention(view_features)

        meta_features = self.embedding_net(meta)
        fused = torch.concat([pooled_features, meta_features], dim=-1)
        logits = self.classifier(fused)
        return logits, fused
429
+
430
class LossLayer(nn.Module):
    """Training objective: focal loss on the logits plus ArcFace loss on the fused features."""

    def __init__(self):
        super().__init__()
        self.arcloss = ArcFaceLoss(num_classes=NUM_CLASSES, embedding_size=1024+128, margin=28.6, scale=64)
        self.celoss = FocalLoss(gamma=FOCAL_LOSS['gamma'], class_dist=FOCAL_LOSS['class_dist'])

    def forward(self, classifier_outputs, cat_features, labels):
        """Return the sum of the classifier loss and the embedding loss.

        Args:
            classifier_outputs: logits passed to the focal loss.
            cat_features: fused features passed to the ArcFace loss.
            labels: target class indices used by both terms.
        """
        focal_term = self.celoss(classifier_outputs, labels)
        arcface_term = self.arcloss(cat_features, labels)
        total = focal_term + arcface_term
        return total
440
+
441
+
442
def load_checkpoints(model=None, optimizer=None, scaler=None):
    """Load state dicts from the paths configured in ``CHECKPOINTS``.

    A component is restored only when it is passed in and its entry in
    ``CHECKPOINTS`` (keys ``model``, ``optimizer``, ``scaler``) is set.
    """
    components = (('model', model), ('optimizer', optimizer), ('scaler', scaler))
    for key, component in components:
        checkpoint_path = CHECKPOINTS[key]
        if checkpoint_path and component is not None:
            component.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
            print(f"use {key} checkpoints: {checkpoint_path}")
    torch.cuda.empty_cache()
453
+
454
def resume_checkpoints(model=None, optimizer=None, scaler=None, arcloss=None):
    """Restore the passed components from the ``RESUME_EPOCH`` files in ``MODEL_DIR``.

    The model state is loaded with ``strict=False``, so keys missing from or
    unknown to the checkpoint do not raise; all other components are loaded
    with their default ``load_state_dict`` behavior.
    """
    # (label used in the log line, file stem, component, extra load kwargs)
    plan = (
        ('model', 'model', model, {'strict': False}),
        ('optimizer', 'optimizer', optimizer, {}),
        ('scaler', 'mp_scaler', scaler, {}),
        ('arcloss', 'arcloss', arcloss, {}),
    )
    for label, stem, component, load_kwargs in plan:
        if component is None:
            continue
        checkpoint_path = f'{MODEL_DIR}{stem}_epoch{RESUME_EPOCH}.pth'
        component.load_state_dict(torch.load(checkpoint_path, map_location='cpu'), **load_kwargs)
        print(f"use {label} checkpoints: {checkpoint_path}")
    torch.cuda.empty_cache()
469
+
470
+
471
def resume_logs(logs):
    """Prefill ``logs`` with the history stored in ``MODEL_DIR/train_history.csv``.

    Only the rows of epochs ``0 .. RESUME_EPOCH`` are taken over. The training
    loop reads ``logs[metric][epoch]`` after appending the new epoch, so any
    later rows left in the file (e.g. from a run that continued past the resume
    point) would shift every new entry away from its epoch index.

    Args:
        logs: dict mapping metric name to a list; each list is extended in
            place with the stored values of the column of the same name.
    """
    old_logs = pd.read_csv(f"{MODEL_DIR}train_history.csv")
    resumed = old_logs.iloc[:RESUME_EPOCH + 1]
    for metric_name, history in logs.items():
        history.extend(resumed[metric_name].tolist())
475
+
476
+ ######################## Optimizer #####################################
477
def get_optm_group(module):
    """Split the parameters of ``module`` into weight-decay and no-weight-decay names.

    Biases and the weights of normalization / embedding layers are not
    decayed; the weights of linear, convolutional and GlobalResponseNormMlp
    modules are.

    Returns:
        Tuple ``(param_dict, decay, no_decay)``: a name-to-parameter dict and
        the two sets of fully qualified parameter names.

    Raises:
        AssertionError: if a parameter lands in both sets or in neither.
    """
    decay, no_decay = set(), set()
    decayed_types = (torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv1d, timm.layers.GlobalResponseNormMlp)
    undecayed_types = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.LayerNorm, torch.nn.Embedding)

    for module_name, submodule in module.named_modules():
        # named_parameters() recurses, so a parameter is visited once per
        # ancestor; the sets absorb the repeated names.
        for param_name, _ in submodule.named_parameters():
            full_name = '%s.%s' % (module_name, param_name) if module_name else param_name

            if param_name.endswith('bias'):
                # biases are never decayed
                no_decay.add(full_name)
                continue
            if not param_name.endswith('weight'):
                continue
            if isinstance(submodule, decayed_types):
                decay.add(full_name)
            elif isinstance(submodule, undecayed_types):
                no_decay.add(full_name)

    # every parameter must end up in exactly one of the two sets
    param_dict = dict(module.named_parameters())
    overlap = decay & no_decay
    unassigned = param_dict.keys() - (decay | no_decay)
    assert len(overlap) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(overlap), )
    assert len(unassigned) == 0, "parameters %s were not separated into either decay/no_decay set!" \
        % (str(unassigned), )

    return param_dict, decay, no_decay
514
+
515
+
516
def get_warmup_optimizer(model):
    """Create the AdamW optimizer covering the embedding network and the classifier.

    For each of the two sub-modules two parameter groups are added, in this
    order: decayed weights (weight_decay 0.05), then undecayed parameters
    (weight_decay 0.0). Learning rates come from ``LEARNING_RATE``.
    """
    params_group = []
    sources = ((model.embedding_net, 'embeddings'), (model.classifier, 'classifier'))
    for submodule, lr_key in sources:
        param_dict, decay, no_decay = get_optm_group(submodule)
        for names, weight_decay in ((decay, 0.05), (no_decay, 0.0)):
            params_group.append({
                "params": [param_dict[pn] for pn in sorted(names)],
                "weight_decay": weight_decay,
                'lr': LEARNING_RATE[lr_key],
            })
    return torch.optim.AdamW(params_group)
529
+
530
+
531
def get_after_warmup_optimizer(model, old_opt):
    """Build a new AdamW optimizer for the backbone and append the old optimizer's groups.

    Args:
        model: network whose ``feature_extractor.conv_backbone`` gets its own
            layer-decayed parameter groups.
        old_opt: optimizer whose ``param_groups`` are added to the new one.
            NOTE(review): only the groups are carried over, not ``old_opt.state``
            — confirm that dropping the optimizer state is intended.

    Returns:
        The newly created optimizer.
    """
    backbone = model.feature_extractor.conv_backbone
    new_opt = create_optimizer_v2(backbone, opt='adamw', filter_bias_and_bn=True, weight_decay=1e-8, layer_decay=0.85, lr=LEARNING_RATE['cnn'])

    # carry over the param groups of the previous optimizer
    for carried_group in old_opt.param_groups:
        new_opt.add_param_group(carried_group)

    return new_opt
539
+
540
+
541
+ # #################### Model Warmup #####################################
542
+
543
def warmup_start(model):
    """Freeze the backbone and the metadata embedding network for the warmup phase.

    Sets ``requires_grad = False`` on every parameter of
    ``model.feature_extractor.conv_backbone`` and ``model.embedding_net``.
    """
    for backbone_param in model.feature_extractor.conv_backbone.parameters():
        backbone_param.requires_grad = False
    print('--> freeze feature_extractor.conv_backbone during warmup phase')

    for embedding_param in model.embedding_net.parameters():
        embedding_param.requires_grad = False
    print('--> freeze embedding_net during warmup phase')
552
+
553
+
554
def warmup_end(model):
    """Unfreeze the backbone once the warmup phase is over.

    Sets ``requires_grad = True`` on every parameter of
    ``model.feature_extractor.conv_backbone``; no other sub-module is touched.
    """
    for backbone_param in model.feature_extractor.conv_backbone.parameters():
        backbone_param.requires_grad = True
    print('--> unfreeze feature_extractor.conv_backbone after warmup phase')
559
+
560
+
561
+
562
+ # #################### Train Loop #####################################
563
+
564
+ # ### train
565
+ def main():
566
+ device = torch.device(f'cuda:1')
567
+ torch.cuda.set_device(device)
568
+
569
+ # prepare the datasets
570
+ train_loader, valid_loader = get_dataloaders(imgsize_train=TRANSFORMS['IMAGE_SIZE_TRAIN'],
571
+ imgsize_val=TRANSFORMS['IMAGE_SIZE_VAL'],
572
+ rand_aug=True)
573
+
574
+ # instantiate the model
575
+ model = Model().to(device)
576
+ if RESUME_EPOCH > 0:
577
+ resume_checkpoints(model=model)
578
+ ema_model = ModelEmaV2(model, decay=0.9998, device=device)
579
+ warmup_start(model)
580
+
581
+ loss_fn = LossLayer().to(device)
582
+ if RESUME_EPOCH > 0:
583
+ resume_checkpoints(arcloss=loss_fn.arcloss)
584
+
585
+ # Optimizer & Schedules & early stopping
586
+ optimizer = get_warmup_optimizer(model)
587
+ optimizer.add_param_group({"params": loss_fn.arcloss.parameters(), "weight_decay": 0.0, 'lr': LEARNING_RATE['classifier']})
588
+
589
+ scaler = GradScaler()
590
+ if RESUME_EPOCH > 0:
591
+ #optimizer = get_after_warmup_optimizer(model, optimizer) if RESUME_EPOCH > WARMUP_EPOCHS else optimizer
592
+ resume_checkpoints(optimizer=optimizer, scaler=scaler)
593
+
594
+ # add attention module
595
+ param_dict, decay, no_decay = get_optm_group(model.attention)
596
+ optimizer.add_param_group({"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.05, 'lr': LEARNING_RATE['attention']})
597
+ optimizer.add_param_group({"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0, 'lr': LEARNING_RATE['attention']})
598
+
599
+ # running metrics during training
600
+ loss_metric = MeanMetric().to(device)
601
+ metrics = MetricCollection(metrics={
602
+ 'acc': MulticlassAccuracy(num_classes=NUM_CLASSES, average='macro'),
603
+ 'top3_acc': MulticlassAccuracy(num_classes=NUM_CLASSES, average='macro', top_k=3),
604
+ 'f1': MulticlassF1Score(num_classes=NUM_CLASSES, average='macro')
605
+ }).to(device)
606
+ metric_ccm = MulticlassF1Score(num_classes=NUM_CLASSES, average='macro').to(device)
607
+
608
+ # start time of trainig
609
+ start_training = time.perf_counter()
610
+ # create log dict
611
+ logs = {'loss': [], 'acc': [], 'acc_top3': [], 'f1': [], 'f1country': [], 'val_loss': [], 'val_acc': [], 'val_acc_top3': [], 'val_f1': [], 'val_f1country': []}
612
+ if RESUME_EPOCH > 0:
613
+ resume_logs(logs)
614
+
615
+ #iterate over epochs
616
+ start_epoch = RESUME_EPOCH+1 if RESUME_EPOCH > 0 else 0
617
+ for epoch in range(start_epoch, NUM_EPOCHS):
618
+ # start time of epoch
619
+ epoch_start = time.perf_counter()
620
+ print(f'Epoch {epoch+1}/{NUM_EPOCHS}')
621
+
622
+ ######################## toggle warmup ########################################
623
+ if (epoch) == WARMUP_EPOCHS:
624
+ warmup_end(model)
625
+ optimizer = get_after_warmup_optimizer(model, optimizer)
626
+ global BATCH_SIZE
627
+ BATCH_SIZE = BATCH_SIZE_AFTER_WARMUP
628
+ train_loader, valid_loader = get_dataloaders(imgsize_train=TRANSFORMS['IMAGE_SIZE_TRAIN'],
629
+ imgsize_val=TRANSFORMS['IMAGE_SIZE_VAL'],
630
+ rand_aug=True)
631
+
632
+ elif (epoch) < WARMUP_EPOCHS:
633
+ print(f'--> Warm Up {epoch+1}/{WARMUP_EPOCHS}')
634
+
635
+ ############################## train phase ####################################
636
+ model.train()
637
+
638
+ # zero the parameter gradients
639
+ optimizer.zero_grad(set_to_none=True)
640
+
641
+ # grad acc loss divider
642
+ loss_div = torch.tensor(BATCH_SIZE['grad_acc'], dtype=torch.float16, device=device, requires_grad=False) if BATCH_SIZE['grad_acc'] != 0 else torch.tensor(1.0, dtype=torch.float16, device=device, requires_grad=False)
643
+
644
+ # iterate over training batches
645
+ for batch_idx, (inputs, labels, ccm, meta) in enumerate(train_loader):
646
+ inputs = inputs.to(device, non_blocking=True)
647
+ inputs = inputs.view(-1, 3, TRANSFORMS['IMAGE_SIZE_TRAIN'], TRANSFORMS['IMAGE_SIZE_TRAIN'])
648
+ meta = meta.to(device, non_blocking=True)
649
+ labels = labels.to(device, non_blocking=True)
650
+ ccm = ccm.to(device, non_blocking=True)
651
+
652
+ # forward with mixed precision
653
+ with autocast(device_type='cuda', dtype=torch.float16):
654
+ outputs, embeddings = model(inputs, meta)
655
+ loss = loss_fn(outputs, embeddings, labels) / loss_div
656
+
657
+ # loss backward
658
+ scaler.scale(loss).backward()
659
+
660
+ # Compute metrics
661
+ loss_metric.update((loss * loss_div).detach())
662
+
663
+ preds = outputs.softmax(dim=-1).detach()
664
+ metrics.update(preds, labels)
665
+ metric_ccm.update(preds * ccm, labels)
666
+
667
+ ############################ grad acc ##############################
668
+ if (batch_idx+1) % BATCH_SIZE['grad_acc'] == 0:
669
+ #scaler.unscale_(optimizer)
670
+ #torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # optimize with gradient clipping to 1 with mixed precision
671
+ scaler.step(optimizer)
672
+ scaler.update()
673
+ # zero the parameter gradients
674
+ optimizer.zero_grad(set_to_none=True)
675
+ # update ema model
676
+ ema_model.update(model)
677
+
678
+
679
+ # compute, sync & reset metrics for validation
680
+ epoch_loss = loss_metric.compute()
681
+ epoch_metrics = metrics.compute()
682
+ epoch_metric_ccm = metric_ccm.compute()
683
+
684
+ loss_metric.reset()
685
+ metrics.reset()
686
+ metric_ccm.reset()
687
+
688
+ # Append metric results to logs
689
+ logs['loss'].append(epoch_loss.cpu().item())
690
+ logs['acc'].append(epoch_metrics['acc'].cpu().item())
691
+ logs['acc_top3'].append(epoch_metrics['top3_acc'].cpu().item())
692
+ logs['f1'].append(epoch_metrics['f1'].cpu().item())
693
+ logs['f1country'].append(epoch_metric_ccm.detach().cpu().item())
694
+
695
+ print(f"loss: {logs['loss'][epoch]:.5f}, acc: {logs['acc'][epoch]:.5f}, acc_top3: {logs['acc_top3'][epoch]:.5f}, f1: {logs['f1'][epoch]:.5f}, f1country: {logs['f1country'][epoch]:.5f}", end=' || ')
696
+
697
+ # zero the parameter gradients
698
+ optimizer.zero_grad(set_to_none=True)
699
+
700
+ del inputs, labels, ccm, meta, preds, outputs, loss, loss_div, epoch_loss, epoch_metrics, epoch_metric_ccm
701
+ torch.cuda.empty_cache()
702
+
703
+ ############################## valid phase ####################################
704
+ with torch.no_grad():
705
+ model.eval()
706
+
707
+ # iterate over validation batches
708
+ for (inputs, labels, ccm, meta) in valid_loader:
709
+ inputs = inputs.to(device, non_blocking=True)
710
+ inputs = inputs.view(-1, 3, TRANSFORMS['IMAGE_SIZE_VAL'], TRANSFORMS['IMAGE_SIZE_VAL'])
711
+ meta = meta.to(device, non_blocking=True)
712
+ labels = labels.to(device, non_blocking=True)
713
+ ccm = ccm.to(device, non_blocking=True)
714
+
715
+ # forward with mixed precision
716
+ with autocast(device_type='cuda', dtype=torch.float16):
717
+ outputs, embeddings = model(inputs, meta)
718
+ loss = loss_fn(outputs, embeddings, labels)
719
+
720
+ # Compute metrics
721
+ loss_metric.update(loss.detach())
722
+
723
+ preds = outputs.softmax(dim=-1).detach()
724
+ metrics.update(preds, labels)
725
+ metric_ccm.update(preds * ccm, labels)
726
+
727
+ # compute, sync & reset metrics for validation
728
+ epoch_loss = loss_metric.compute()
729
+ epoch_metrics = metrics.compute()
730
+ epoch_metric_ccm = metric_ccm.compute()
731
+
732
+ loss_metric.reset()
733
+ metrics.reset()
734
+ metric_ccm.reset()
735
+
736
+ # Append metric results to logs
737
+ logs['val_loss'].append(epoch_loss.cpu().item())
738
+ logs['val_acc'].append(epoch_metrics['acc'].cpu().item())
739
+ logs['val_acc_top3'].append(epoch_metrics['top3_acc'].cpu().item())
740
+ logs['val_f1'].append(epoch_metrics['f1'].cpu().item())
741
+ logs['val_f1country'].append(epoch_metric_ccm.detach().cpu().item())
742
+
743
+ print(f"val_loss: {logs['val_loss'][epoch]:.5f}, val_acc: {logs['val_acc'][epoch]:.5f}, val_acc_top3: {logs['val_acc_top3'][epoch]:.5f}, val_f1: {logs['val_f1'][epoch]:.5f}, val_f1country: {logs['val_f1country'][epoch]:.5f}", end=' || ')
744
+
745
+ del inputs, labels, ccm, meta, preds, outputs, loss, epoch_loss, epoch_metrics, epoch_metric_ccm
746
+ torch.cuda.empty_cache()
747
+
748
+ # save logs as csv
749
+ logs_df = pd.DataFrame(logs)
750
+ logs_df.to_csv(f'{MODEL_DIR}train_history.csv', index_label='epoch', sep=',', encoding='utf-8')
751
+
752
+ if WANDB:
753
+ # at the end of each epoch, log anything you want to log for that epoch
754
+ wandb.log(
755
+ {k:v[epoch] for k,v in logs.items()}, # e.g. log each metric value for the current epoch in our defined logs dict
756
+ step=epoch # epoch index for wandb
757
+ )
758
+
759
+ #save trained model for each epoch
760
+ torch.save(model.state_dict(), f'{MODEL_DIR}model_epoch{epoch}.pth')
761
+ torch.save(ema_model.module.state_dict(), f'{MODEL_DIR}ema_model_epoch{epoch}.pth')
762
+ torch.save(optimizer.state_dict(), f'{MODEL_DIR}optimizer_epoch{epoch}.pth')
763
+ torch.save(scaler.state_dict(), f'{MODEL_DIR}mp_scaler_epoch{epoch}.pth')
764
+ torch.save(loss_fn.arcloss.state_dict(), f'{MODEL_DIR}arcloss_epoch{epoch}.pth')
765
+
766
+ # end time of epoch
767
+ epoch_end = time.perf_counter()
768
+ print(f"epoch runtime: {epoch_end-epoch_start:5.3f} sec.")
769
+
770
+ del logs_df, epoch_start, epoch_end
771
+ torch.cuda.empty_cache()
772
+
773
+ ################################## EMA Model Validation ################################
774
+ del model
775
+ torch.cuda.empty_cache()
776
+
777
+ ema_net = ema_model.module
778
+ ema_net.eval()
779
+
780
+ with torch.no_grad():
781
+ # iterate over validation batches
782
+ for (inputs, labels, ccm, meta) in valid_loader:
783
+ inputs = inputs.to(device, non_blocking=True)
784
+ inputs = inputs.view(-1, 3, TRANSFORMS['IMAGE_SIZE_VAL'], TRANSFORMS['IMAGE_SIZE_VAL'])
785
+ meta = meta.to(device, non_blocking=True)
786
+ labels = labels.to(device, non_blocking=True)
787
+ ccm = ccm.to(device, non_blocking=True)
788
+
789
+ # forward with mixed precision
790
+ with autocast(device_type='cuda', dtype=torch.float16):
791
+ outputs, embeddings = ema_net(inputs, meta)
792
+ loss = loss_fn(outputs, embeddings, labels)
793
+
794
+ # Compute metrics
795
+ loss_metric.update(loss.detach())
796
+
797
+ preds = outputs.softmax(dim=-1).detach()
798
+ metrics.update(preds, labels)
799
+ metric_ccm.update(preds * ccm, labels)
800
+
801
+ # compute, sync & reset metrics for validation
802
+ epoch_loss = loss_metric.compute()
803
+ epoch_metrics = metrics.compute()
804
+ epoch_metric_ccm = metric_ccm.compute()
805
+
806
+ loss_metric.reset()
807
+ metrics.reset()
808
+ metric_ccm.reset()
809
+
810
+ print(f"ema_loss: {epoch_loss.cpu().item():.5f}, ema_acc: {epoch_metrics['acc'].cpu().item():.5f}, ema_acc_top3: {epoch_metrics['top3_acc'].cpu().item():.5f}, ema_f1: {epoch_metrics['f1'].cpu().item():.5f}, ema_f1country: {epoch_metric_ccm.detach().cpu().item():.5f}")
811
+
812
+ with open(f'{MODEL_DIR}ema_results.txt', 'w') as f:
813
+ print(f"ema_loss: {epoch_loss.cpu().item():.5f}, ema_acc: {epoch_metrics['acc'].cpu().item():.5f}, ema_acc_top3: {epoch_metrics['top3_acc'].cpu().item():.5f}, ema_f1: {epoch_metrics['f1'].cpu().item():.5f}, ema_f1country: {epoch_metric_ccm.detach().cpu().item():.5f}", file=f)
814
+
815
+ plot_history(logs)
816
+ # end time of training
817
+ end_training = time.perf_counter()
818
+ print(f'Training succeeded in {(end_training - start_training):5.3f}s')
819
+
820
+ if WANDB:
821
+ wandb.finish()
822
+
823
+
824
+ if __name__=="__main__":  # script entry point: call main() only when this file is executed directly, not when imported
825
+ main()
826
+
827
+
828
+
829
+
meta_code_tokens.p ADDED
Binary file (1.51 kB). View file
 
meta_endemic_tokens.p ADDED
Binary file (129 Bytes). View file
 
missing_train_data.csv ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ observation_id,endemic,binomial_name,code,image_path,class_id,subset
2
+ 67796298,False,Leptodrymus pulcherrimus,NI,2021/Leptodrymus_pulcherrimus/109630991.jpeg,911,train
3
+ 90990396,False,Anilios bicolor,AU,2021/Anilios_bicolor/150442199.jpg,69,train
4
+ 69872390,False,Hydrophis platurus,MX,2021/Hydrophis_platurus/113480809.jpeg,822,train
5
+ 68196893,True,Crotalus aquilus,MX,2021/Crotalus_aquilus/110364809.jpeg,403,train
6
+ 68306088,False,Coniophanes imperialis,unknown,2021/Coniophanes_imperialis/110568759.jpg,365,train
7
+ 79774040,False,Xenodon dorbignyi,AR,2021/Xenodon_dorbignyi/130690394.jpeg,1759,train
8
+ 69653234,False,Corallus ruschenbergerii,CR,2021/Corallus_ruschenbergerii/113065848.jpg,386,train
9
+ 69021659,False,Porthidium lansbergii,CO,2021/Porthidium_lansbergii/111905125.jpeg,1315,train
10
+ 69021659,False,Porthidium lansbergii,CO,2021/Porthidium_lansbergii/111905148.jpeg,1315,train
11
+ 69482422,False,Ficimia publia,MX,2021/Ficimia_publia/112749228.jpg,702,train
12
+ 69682741,False,Leptophis mexicanus,BZ,2021/Leptophis_mexicanus/113236039.jpg,917,train
13
+ 70454074,False,Indotyphlops braminus,unknown,2021/Indotyphlops_braminus/114541765.jpg,849,train
14
+ 70860740,False,Salvadora lineata,unknown,2021/Salvadora_lineata/115274966.jpg,1490,train
15
+ 71088614,False,Zamenis situla,HR,2021/Zamenis_situla/115682327.jpeg,1783,train
16
+ 71088614,False,Zamenis situla,HR,2021/Zamenis_situla/115682470.jpeg,1783,train
17
+ 71088614,False,Zamenis situla,HR,2021/Zamenis_situla/115682540.jpeg,1783,train
18
+ 77936031,False,Bothrops atrox,CO,2021/Bothrops_atrox/127491786.jpg,224,train
19
+ 77936031,False,Bothrops atrox,CO,2021/Bothrops_atrox/127491831.jpg,224,train
20
+ 71687721,False,Atretium schistosum,IN,2021/Atretium_schistosum/116753675.jpeg,149,train
21
+ 72522920,False,Coluber constrictor,US,2021/Coluber_constrictor/118257455.jpg,358,train
22
+ 72522920,False,Coluber constrictor,US,2021/Coluber_constrictor/118257461.jpg,358,train
23
+ 72522920,False,Coluber constrictor,US,2021/Coluber_constrictor/118257468.jpg,358,train
24
+ 72676335,False,Indotyphlops braminus,ZA,2021/Indotyphlops_braminus/118534637.jpg,849,train
25
+ 72676354,False,Indotyphlops braminus,ZA,2021/Indotyphlops_braminus/118534676.jpg,849,train
26
+ 72829108,False,Tachymenis ocellata,AR,2021/Tachymenis_ocellata/118815002.jpeg,1572,train
27
+ 73273528,False,Nerodia fasciata,unknown,2021/Nerodia_fasciata/119645669.jpg,1143,train
28
+ 73297381,False,Nerodia fasciata,unknown,2021/Nerodia_fasciata/119645716.jpg,1143,train
29
+ 73843863,False,Crotalus mitchellii,unknown,2021/Crotalus_mitchellii/120616047.jpg,419,train
30
+ 74623459,False,Bungarus caeruleus,IN,2021/Bungarus_caeruleus/121987144.jpeg,256,train
31
+ 78460090,False,Naja melanoleuca,BJ,2021/Naja_melanoleuca/128687447.jpg,1114,train
32
+ 77123157,False,Atractus carrioni,EC,2021/Atractus_carrioni/126102215.jpg,118,train
33
+ 78527509,False,Erythrolamprus melanotus,TT,2021/Erythrolamprus_melanotus/128525135.jpg,668,train
34
+ 78527511,False,Erythrolamprus melanotus,TT,2021/Erythrolamprus_melanotus/128525153.jpg,668,train
35
+ 78668716,False,Thamnophis ordinoides,CA,2021/Thamnophis_ordinoides/128773221.jpeg,1634,train
36
+ 78700443,False,Lampropeltis micropholis,EC,2021/Lampropeltis_micropholis/128829321.jpg,873,train
37
+ 79646763,False,Leptodeira nigrofasciata,NI,2021/Leptodeira_nigrofasciata/130472627.jpeg,901,train
38
+ 79646763,False,Leptodeira nigrofasciata,NI,2021/Leptodeira_nigrofasciata/130472758.jpeg,901,train
39
+ 79646763,False,Leptodeira nigrofasciata,NI,2021/Leptodeira_nigrofasciata/130472802.jpeg,901,train
40
+ 83296315,True,Vipera berus,GB,2021/Vipera_berus/136792442.jpeg,1736,train
41
+ 82283665,False,Storeria dekayi,CA,2021/Storeria_dekayi/135035849.jpg,1551,train
42
+ 83525188,False,Micrurus camilae,CO,2021/Micrurus_camilae/137183120.jpeg,1047,train
43
+ 95346633,False,Tantilla melanocephala,BR,2021/Tantilla_melanocephala/158274903.jpg,1590,train
44
+ 82672187,True,Macrovipera lebetinus,TR,2021/Macrovipera_lebetinus/135709924.jpeg,998,train
45
+ 84851943,False,Bothrops bilineatus,EC,2021/Bothrops_bilineatus/139513236.jpeg,225,train
46
+ 83516790,False,Bothrops asper,EC,2021/Bothrops_asper/137173097.jpeg,223,train
47
+ 84043527,True,Oligodon sublineatus,LK,2021/Oligodon_sublineatus/138090906.jpeg,1181,train
48
+ 89994824,False,Siphlophis compressus,EC,2021/Siphlophis_compressus/148675926.jpg,1524,train
49
+ 86759404,False,Thamnophis cyrtopsis,MX,2021/Thamnophis_cyrtopsis/142982312.jpeg,1624,train
50
+ 86508209,False,Clelia scytalina,MX,2021/Clelia_scytalina/143000260.jpg,351,train
51
+ 86658948,False,Demansia reticulata,AU,2021/Demansia_reticulata/142798080.jpeg,476,train
52
+ 86866624,False,Dolichophis jugularis,TR,2021/Dolichophis_jugularis/143176053.jpg,564,train
53
+ 87485941,False,Urotheca fulviceps,PA,2021/Urotheca_fulviceps/144246084.jpeg,1729,train
54
+ 95469215,False,Tretanorhinus nigroluteus,CR,2021/Tretanorhinus_nigroluteus/158498893.jpeg,1658,train
55
+ 96371001,False,Dendrophidion percarinatum,CO,2021/Dendrophidion_percarinatum/160094559.jpeg,521,train
56
+ 132373886,False,Pareas stanleyi,CN,2021/Pareas_stanleyi/225417796.jpeg,1255,train
57
+ 93380278,False,Coronella girondica,FR,2021/Coronella_girondica/154762765.jpg,388,train
58
+ 93380278,False,Coronella girondica,FR,2021/Coronella_girondica/154762825.jpg,388,train
59
+ 95515839,False,Laticauda colubrina,FJ,2021/Laticauda_colubrina/158580938.jpeg,887,train
60
+ 101870194,False,Hebius boulengeri,CN,2021/Hebius_boulengeri/170186519.jpeg,760,train
61
+ 94524054,False,Pantherophis spiloides,CA,2021/Pantherophis_spiloides/156809193.jpeg,1239,train
62
+ 95070025,False,Micrurus lemniscatus,EC,2021/Micrurus_lemniscatus/157784766.jpeg,1066,train
63
+ 107309695,False,Zamenis situla,AL,2021/Zamenis_situla/180456801.jpeg,1783,train
64
+ 97546752,True,Ahaetulla borealis,IN,2021/Ahaetulla_borealis/162250238.jpeg,36,train
65
+ 99942012,True,Micrurus diastema,GT,2021/Micrurus_diastema/166721251.jpg,1052,train
66
+ 97988351,False,Lampropeltis triangulum,unknown,2021/Lampropeltis_triangulum/163060352.jpeg,881,train
67
+ 101760741,False,Salvadora lineata,unknown,2021/Salvadora_lineata/169985709.jpg,1490,train
68
+ 101589271,False,Oxybelis potosiensis,BZ,2021/Oxybelis_potosiensis/169675126.jpeg,1210,train
69
+ 120112461,False,Eunectes murinus,PE,2021/Eunectes_murinus/204299422.jpeg,694,train
70
+ 122750982,False,Dipsas neuwiedi,BR,2021/Dipsas_neuwiedi/207869435.jpg,548,train
71
+ 122750989,False,Dipsas neuwiedi,BR,2021/Dipsas_neuwiedi/207869485.jpg,548,train
72
+ 102038435,False,Thamnophis proximus,BZ,2021/Thamnophis_proximus/170495233.jpeg,1635,train
73
+ 102115213,False,Boa imperator,PA,2021/Boa_imperator/170637991.jpeg,167,train
74
+ 102200994,False,Crotalus ehecatl,MX,2021/Crotalus_ehecatl/170788633.jpg,413,train
75
+ 102200994,False,Crotalus ehecatl,MX,2021/Crotalus_ehecatl/170788634.jpg,413,train
76
+ 102200994,False,Crotalus ehecatl,MX,2021/Crotalus_ehecatl/170788636.jpg,413,train
77
+ 102200994,False,Crotalus ehecatl,MX,2021/Crotalus_ehecatl/170788644.jpg,413,train
78
+ 102439855,False,Erythrolamprus typhlus,BR,2021/Erythrolamprus_typhlus/171240131.jpeg,681,train
79
+ 102661878,False,Chironius maculoventris,AR,2021/Chironius_maculoventris/171663522.jpeg,336,train
80
+ 109034494,False,Coelognathus radiatus,VN,2021/Coelognathus_radiatus/183693780.jpg,356,train
81
+ 108785823,False,Pseudonaja mengdeni,AU,2021/Pseudonaja_mengdeni/185153738.jpg,1393,train
82
+ 108785823,False,Pseudonaja mengdeni,AU,2021/Pseudonaja_mengdeni/185153744.jpg,1393,train
83
+ 108785823,False,Pseudonaja mengdeni,AU,2021/Pseudonaja_mengdeni/185153763.jpg,1393,train
84
+ 103160025,False,Stenorrhina degenhardtii,CR,2021/Stenorrhina_degenhardtii/172594851.jpeg,1549,train
85
+ 104199941,False,Bothrops ammodytoides,AR,2021/Bothrops_ammodytoides/174579877.jpg,222,train
86
+ 125047291,True,Lycognathophis seychellensis,SC,2021/Lycognathophis_seychellensis/212072332.jpg,980,train