# Batik-Classification-Transformer-model / advanced_augmentation.py
# Uploaded by Decoder24 via huggingface_hub (commit a6eed2b, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
import random
def cutmix_data(x, y, alpha=1.0, device='cuda'):
    """
    CutMix data augmentation (Yun et al., 2019).

    Pastes a random rectangular patch from a shuffled copy of the batch
    into each image and returns both label sets so the loss can be
    interpolated by the patch-area ratio.

    Args:
        x: Input batch of shape (B, C, dim2, dim3); mixed IN PLACE.
        y: Target batch of shape (B,).
        alpha: Beta-distribution parameter; alpha <= 0 disables mixing.
        device: Kept for backward compatibility; the permutation is now
            created on x's own device, which is correct for any device.

    Returns:
        x: Mixed input batch (same tensor object as the input).
        y_a, y_b: Original and permuted targets for loss interpolation.
        lam: Mixing ratio, adjusted to the exact pixel ratio of the patch.
    """
    # Sample the mixing ratio from Beta(alpha, alpha); lam == 1 means "no mix".
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0

    batch_size = x.size(0)
    # Fix: build the permutation directly on the tensor's own device instead
    # of trusting the `device` string (the old `.cuda()` call broke for
    # non-default devices like 'cuda:1' or 'mps', or when x was on CPU).
    index = torch.randperm(batch_size, device=x.device)

    # NOTE(review): follows the reference implementation, which labels
    # size(2) as W and size(3) as H; for (B, C, H, W) tensors the patch is
    # still a valid random rectangle, so the math is unaffected.
    W = x.size(2)
    H = x.size(3)
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Uniformly sample the patch centre, then clip the box to the image.
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    # Paste the patch from the permuted batch into the original batch.
    x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]

    # Adjust lambda to the exact pixel ratio (clipping can shrink the box).
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (W * H))

    y_a, y_b = y, y[index]
    return x, y_a, y_b, lam
def cutmix_criterion(criterion, pred, y_a, y_b, lam):
    """
    Interpolated CutMix loss: lam-weighted loss against the original
    targets plus (1 - lam)-weighted loss against the permuted targets.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b
class LabelSmoothingCrossEntropy(nn.Module):
    """
    Cross entropy with label smoothing.

    The target distribution places (1 - smoothing) mass on the true class
    and spreads `smoothing` uniformly over all classes; the resulting loss
    is the matching mixture of NLL and the mean negative log-probability.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        # Fraction of probability mass moved off the true class.
        self.smoothing = smoothing

    def forward(self, x, target):
        log_probs = F.log_softmax(x, dim=-1)
        # Negative log-likelihood of the true class for each sample.
        nll = -log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        # Uniform component: mean negative log-prob over all classes.
        uniform = -log_probs.mean(dim=-1)
        per_sample = (1.0 - self.smoothing) * nll + self.smoothing * uniform
        return per_sample.mean()
class FocalLoss(nn.Module):
    """
    Focal loss for class-imbalanced classification.

    Down-weights well-classified examples by (1 - p_t)^gamma so training
    focuses on hard samples; alpha acts as a global scaling factor.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha          # global loss scale
        self.gamma = gamma          # focusing exponent
        self.reduction = reduction  # 'mean', 'sum', or anything else for per-sample

    def forward(self, inputs, targets):
        # Per-sample cross entropy; the true-class probability is recovered
        # from it via p_t = exp(-CE).
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        p_t = torch.exp(-per_sample_ce)
        modulated = self.alpha * (1 - p_t) ** self.gamma * per_sample_ce
        if self.reduction == 'sum':
            return modulated.sum()
        if self.reduction == 'mean':
            return modulated.mean()
        return modulated
class AdvancedAugmentation:
    """
    Advanced augmentation pipelines for better generalization.

    Produces torchvision transform pipelines that take PIL images and
    output normalized tensors (ImageNet mean/std).
    """

    # ImageNet normalization statistics (match pretrained backbones).
    IMAGENET_MEAN = [0.485, 0.456, 0.406]
    IMAGENET_STD = [0.229, 0.224, 0.225]

    def __init__(self, image_size=224):
        # Side length of the square image fed to the model.
        self.image_size = image_size

    def get_train_transforms(self):
        """
        Comprehensive training transforms with heavy augmentation.

        Bug fix: RandomErasing operates on tensors only, so it must run
        AFTER ToTensor — the original pipeline applied it to a PIL image,
        which raises a TypeError at runtime.
        """
        return transforms.Compose([
            # Resize slightly larger, then random-crop back down.
            transforms.Resize((self.image_size + 32, self.image_size + 32)),
            transforms.RandomCrop((self.image_size, self.image_size), padding=4),
            # Geometric augmentations.
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.2),
            transforms.RandomRotation(degrees=15),
            transforms.RandomAffine(
                degrees=0,
                translate=(0.1, 0.1),
                scale=(0.9, 1.1),
                shear=5
            ),
            # Photometric augmentations.
            transforms.ColorJitter(
                brightness=0.2,
                contrast=0.2,
                saturation=0.2,
                hue=0.05
            ),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
            # TrivialAugmentWide expects a PIL image, so it stays before ToTensor.
            transforms.TrivialAugmentWide(num_magnitude_bins=31),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD),
            # Tensor-only op: moved here, after ToTensor (bug fix).
            transforms.RandomErasing(p=0.2, scale=(0.02, 0.33), ratio=(0.3, 3.3)),
        ])

    def get_val_transforms(self):
        """
        Validation transforms: deterministic resize + normalize only.
        """
        return transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD),
        ])
class TestTimeAugmentation:
    """
    Test-Time Augmentation (TTA): average softmax predictions over
    several augmented views of the same image.
    """

    def __init__(self, model, device, num_augmentations=5):
        """
        Args:
            model: Trained classifier returning logits.
            device: Device the model lives on.
            num_augmentations: How many of the 5 TTA views to use.
        """
        self.model = model
        self.device = device
        self.num_augmentations = num_augmentations

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        def _view(*augs):
            # Every view: resize -> optional augmentation(s) -> tensor -> norm.
            return transforms.Compose([
                transforms.Resize((224, 224)),
                *augs,
                transforms.ToTensor(),
                normalize,
            ])

        # Five distinct views. Bug fix: the original listed the rotation
        # view twice (copy-paste); the duplicate is replaced by a vertical
        # flip so each view contributes new information.
        self.tta_transforms = [
            _view(),
            _view(transforms.RandomHorizontalFlip(p=1.0)),
            _view(transforms.RandomRotation(degrees=10)),
            _view(transforms.RandomVerticalFlip(p=1.0)),
            _view(transforms.ColorJitter(brightness=0.1, contrast=0.1)),
        ]

    def predict(self, image):
        """
        Average the model's softmax outputs over the first
        `num_augmentations` TTA views of `image`.

        Args:
            image: A PIL image (the transform pipelines start from PIL).

        Returns:
            Tensor of shape (1, num_classes): mean class-probability
            distribution across views.
        """
        self.model.eval()
        predictions = []
        with torch.no_grad():
            for transform in self.tta_transforms[:self.num_augmentations]:
                # The original had a PIL/tensor if-else whose branches were
                # identical; the transform handles the input type itself.
                view = transform(image).unsqueeze(0).to(self.device)
                output = self.model(view)
                predictions.append(F.softmax(output, dim=1))
        # Average the per-view probability distributions.
        return torch.mean(torch.stack(predictions), dim=0)
def calculate_class_weights(train_targets, num_classes, method='balanced'):
    """
    Calculate per-class weights for handling class imbalance.

    Args:
        train_targets: Iterable of integer class labels.
        num_classes: Total number of classes.
        method: 'balanced' (sklearn-style n / (k * count)),
                'inverse' (1 / count), or 'sqrt' (1 / sqrt(count)).

    Returns:
        torch.FloatTensor of shape (num_classes,), normalized so the
        weights sum to num_classes.

    Raises:
        ValueError: If `method` is unknown or `train_targets` is empty.
    """
    train_targets = np.asarray(train_targets)
    if train_targets.size == 0:
        raise ValueError("train_targets must not be empty")

    class_counts = np.bincount(train_targets, minlength=num_classes)
    # Fix: guard against division by zero for classes absent from the
    # training set — the original produced inf weights; treat an absent
    # class as having one sample instead.
    safe_counts = np.maximum(class_counts, 1)

    if method == 'balanced':
        # sklearn's "balanced" heuristic.
        class_weights = len(train_targets) / (num_classes * safe_counts)
    elif method == 'inverse':
        # Simple inverse frequency.
        class_weights = 1.0 / safe_counts
    elif method == 'sqrt':
        # Square root of inverse frequency (milder re-weighting).
        class_weights = 1.0 / np.sqrt(safe_counts)
    else:
        raise ValueError(f"Unknown method: {method}")

    # Normalize so the weights sum to num_classes.
    class_weights = class_weights / class_weights.sum() * num_classes
    return torch.tensor(class_weights, dtype=torch.float)
def get_advanced_scheduler(optimizer, method='cosine_warmup', total_epochs=50):
    """
    Build a learning rate scheduler.

    Args:
        optimizer: PyTorch optimizer (its current lr seeds OneCycleLR's max_lr).
        method: 'cosine_warmup' (CosineAnnealingWarmRestarts),
                'onecycle' (OneCycleLR, stepped once per epoch), or
                'plateau' (ReduceLROnPlateau on a maximized metric,
                e.g. validation accuracy).
        total_epochs: Horizon used by OneCycleLR as total_steps.

    Returns:
        A torch.optim.lr_scheduler scheduler instance.

    Raises:
        ValueError: If `method` is unknown.
    """
    if method == 'cosine_warmup':
        from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
        # Restart every 10 epochs, doubling the period after each restart.
        return CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-7)
    elif method == 'onecycle':
        from torch.optim.lr_scheduler import OneCycleLR
        return OneCycleLR(
            optimizer,
            max_lr=optimizer.param_groups[0]['lr'],
            total_steps=total_epochs,  # assumes one scheduler.step() per epoch
            pct_start=0.3,
            anneal_strategy='cos'
        )
    elif method == 'plateau':
        from torch.optim.lr_scheduler import ReduceLROnPlateau
        # Fix: dropped `verbose=True` — the kwarg is deprecated and has been
        # removed in recent PyTorch releases, where it raised a TypeError.
        return ReduceLROnPlateau(
            optimizer,
            mode='max',
            factor=0.5,
            patience=3,
            min_lr=1e-7
        )
    else:
        raise ValueError(f"Unknown scheduler method: {method}")
def apply_mixup_cutmix_probability():
    """
    Pick an augmentation strategy for the current batch, uniformly at
    random among Mixup, CutMix, or no mixing.
    """
    strategies = ('mixup', 'cutmix', 'none')
    return random.choice(strategies)