# Source: _vggt/training/data/augmentation.py
# Uploaded by CgvKodai using huggingface_hub (commit 66003a2, verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Dict
from torchvision import transforms
def get_image_augmentation(
    color_jitter: Optional[Dict[str, float]] = None,
    gray_scale: bool = True,
    gau_blur: bool = False
) -> Optional[transforms.Compose]:
    """Build the photometric augmentation pipeline for training images.

    The pipeline always starts with a randomly applied colour jitter and
    optionally continues with a random grayscale conversion and a randomly
    applied Gaussian blur, in that order.

    Args:
        color_jitter: Overrides for the colour jitter settings. Recognised
            keys and their fallback values are:
                - brightness: 0.5
                - contrast: 0.5
                - saturation: 0.5
                - hue: 0.1
                - p: 0.9 (probability that the jitter is applied)
            Keys that are absent fall back to the values above; keys other
            than these five are never read. Passing None selects the
            fallback values for everything.
        gray_scale: When true, add a random grayscale step with
            probability 0.05.
        gau_blur: When true, add a Gaussian blur (kernel size 5, sigma
            range (0.1, 1.0)) that is applied with probability 0.05.

    Returns:
        A Compose wrapping the selected transforms. None is returned only
        if the transform list is empty; because the colour jitter step is
        added on every call, that case is not reached in practice.
    """
    # Baseline jitter settings; caller-supplied entries take precedence.
    jitter_cfg = {
        "brightness": 0.5,
        "contrast": 0.5,
        "saturation": 0.5,
        "hue": 0.1,
        "p": 0.9,
    }
    if color_jitter is not None:
        jitter_cfg = {**jitter_cfg, **color_jitter}

    # Colour jitter is always part of the pipeline; "p" only controls how
    # often it fires per sample.
    jitter = transforms.ColorJitter(
        brightness=jitter_cfg["brightness"],
        contrast=jitter_cfg["contrast"],
        saturation=jitter_cfg["saturation"],
        hue=jitter_cfg["hue"],
    )
    pipeline = [transforms.RandomApply([jitter], p=jitter_cfg["p"])]

    if gray_scale:
        pipeline.append(transforms.RandomGrayscale(p=0.05))

    if gau_blur:
        blur = transforms.GaussianBlur(5, sigma=(0.1, 1.0))
        pipeline.append(transforms.RandomApply([blur], p=0.05))

    if not pipeline:
        return None
    return transforms.Compose(pipeline)