# Handling files
import os

# Custom Dataset
from dataset import DinoDataset

# Training and Evaluating
from train_utils import *

# ConvNext model
from transformers import ConvNextForImageClassification

# Torch
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader, WeightedRandomSampler
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# mixup/cutmix augmentation
from timm.data import Mixup

if __name__ == "__main__":
    # Get current device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Paths
    dataset_path = "/dinosaur_project/data/filtered_split_dinosaurs_dataset"
    train_path = os.path.join(dataset_path, "train")
    val_path = os.path.join(dataset_path, "val")

    # Transform pipeline for training data
    train_transforms = T.Compose(
        [
            T.Resize((256, 256)),
            T.RandomResizedCrop(224, scale=(0.5, 1.0), ratio=(0.9, 1.1)),
            T.RandomHorizontalFlip(),
            T.RandomRotation(degrees=15),
            T.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.2, hue=0.01
            ),
            T.RandomApply(
                [T.GaussianBlur(kernel_size=3)], p=0.05
            ),
            T.RandomAffine(
                degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05)
            ),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
            T.RandomErasing(
                p=0.1, scale=(0.02, 0.2), ratio=(0.3, 3.3), value="random"
            )
        ]
    )

    # Transform pipeline for validation/testing data
    val_transforms = T.Compose(
        [
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            )
        ]
    )

    # Datasets
    train_dataset = DinoDataset(root=train_path, transform=train_transforms)
    val_dataset = DinoDataset(root=val_path, transform=val_transforms)

    # Weighted random sampler
    ## Count the number of images per class
    cls_count = torch.bincount(torch.tensor(train_dataset.targets))
    ## Assign each training sample a weight inversely proportional to its class frequency
    sample_weights = torch.tensor(
        [1.0 / cls_count[img[1]] for img in train_dataset.imgs]
    )
    ## Create sampler
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(train_dataset),
        replacement=True
    )

    # DataLoaders
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=32,
        sampler=sampler,
        num_workers=4,
        drop_last=True
    )
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4
    )

    # Model: pretrained ConvNeXt-Tiny with a freshly initialized classification head
    convnext_tiny = ConvNextForImageClassification.from_pretrained(
        "facebook/convnext-tiny-224",
        num_labels=len(train_dataset.classes),
        ignore_mismatched_sizes=True
    )

    # Freeze the whole model, then unfreeze the last two encoder stages and the classifier
    for param in convnext_tiny.parameters():
        param.requires_grad = False
    for param in convnext_tiny.convnext.encoder.stages[-2:].parameters():
        param.requires_grad = True
    for param in convnext_tiny.classifier.parameters():
        param.requires_grad = True

    # Loss function, optimizer, LR scheduler
    loss_func = nn.CrossEntropyLoss(reduction="sum", label_smoothing=0.05)
    optimizer = AdamW(
        params=filter(lambda p: p.requires_grad, convnext_tiny.parameters()),
        lr=2e-4,
        weight_decay=1e-2,
        betas=(0.9, 0.999),
        eps=1e-8
    )
    scheduler = CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=2, eta_min=1e-6
    )

    # CutMix/MixUp augmentation
    mixup_fn = Mixup(
        mixup_alpha=0.1,
        cutmix_alpha=0.3,
        prob=0.5,
        switch_prob=0.3,
        mode="elem",
        label_smoothing=0,
        num_classes=len(train_dataset.classes)
    )

    # Number of epochs
    n_epochs = 30

    # Training
    train(
        model=convnext_tiny,
        n_epochs=n_epochs,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        early_stopping_patience=5,
        scheduler=scheduler,
        mix_augment=mixup_fn
    )
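
# ---------------------------------------------------------------------------
# `train` comes from train_utils, which is not shown here. The sketch below is
# an illustrative stand-in, NOT the project's actual implementation: it is one
# plausible loop compatible with the keyword arguments in the call above. It
# applies the Mixup callable per batch, reads logits from the Hugging Face
# output object, steps the scheduler once per epoch, and early-stops on mean
# validation loss. The name `train_sketch` and every detail of its body are
# assumptions; it relies on the imports at the top of this file.
# ---------------------------------------------------------------------------
def train_sketch(model, n_epochs, loss_func, optimizer, train_loader,
                 val_loader, device, early_stopping_patience,
                 scheduler=None, mix_augment=None):
    model.to(device)
    best_val_loss = float("inf")
    epochs_without_improvement = 0

    for epoch in range(n_epochs):
        # Training pass
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            if mix_augment is not None:
                # timm's Mixup returns mixed images and soft (one-hot-mixed)
                # targets; nn.CrossEntropyLoss accepts probability targets
                images, labels = mix_augment(images, labels)
            optimizer.zero_grad()
            logits = model(images).logits  # HF output object carries .logits
            loss = loss_func(logits, labels)
            loss.backward()
            optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # Validation pass
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                logits = model(images).logits
                # reduction="sum" above, so divide by dataset size afterwards
                val_loss += loss_func(logits, labels).item()
        val_loss /= len(val_loader.dataset)

        # Early stopping on validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stopping_patience:
                break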