# Uploaded by eksemyashkina in commit f096e52 ("Added files"); original size 845 bytes.
import torchvision.transforms as T
# Per-channel (R, G, B) normalization statistics. These are the standard ImageNet
# values used with torchvision's pretrained models. They are kept as module-level
# lists so other modules can import them (e.g. to un-normalize images).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


def _tensor_and_normalize():
    """Return fresh ``ToTensor`` + ``Normalize`` steps that end both pipelines."""
    return [T.ToTensor(), T.Normalize(mean=mean, std=std)]


# Training pipeline: random geometric and photometric augmentation, followed by
# tensor conversion and normalization.
# NOTE(review): RandomRotation runs before the crop, so the corners it fills
# (torchvision default fill=0) can survive into the crop. Confirm this is intended.
train_transform = T.Compose(
    [
        T.RandomRotation(degrees=15),
        T.RandomResizedCrop(224, scale=(0.5, 1.0)),
        T.RandomHorizontalFlip(),
        T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        *_tensor_and_normalize(),
    ]
)

# Evaluation pipeline: deterministic resize (shorter side to 256), then a centered
# 224x224 crop, followed by the same tensor conversion and normalization.
test_transform = T.Compose(
    [
        T.Resize(256),
        T.CenterCrop(224),
        *_tensor_and_normalize(),
    ]
)
class EMA:
    """Exponential moving average over a stream of values.

    The first observed value seeds the average. Each later call blends in the
    new observation as ``alpha * previous + (1 - alpha) * new``, so a larger
    ``alpha`` keeps more weight on the history.
    """

    def __init__(self, alpha: float = 0.9) -> None:
        # Running average; stays ``None`` until the first observation arrives.
        self.value = None
        # Weight given to the previous average on every update.
        self.alpha = alpha

    def __call__(self, value: float) -> float:
        """Fold *value* into the running average and return the updated average."""
        if self.value is None:
            # First observation: use it directly as the starting average.
            self.value = value
            return self.value

        blended = self.alpha * self.value + (1 - self.alpha) * value
        self.value = blended
        return self.value