File size: 11,287 Bytes
c29babb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
from enum import Enum
from typing import Literal, Self

from pydantic import BaseModel as Validation
from pydantic import field_validator

# Learning-rate scheduler identifiers accepted by ``Config.lr_scheduler``:
#   "cosine" -> CosineAnnealingLR
#   "cyclic" -> CosineAnnealingWarmRestarts
Scheduler = Literal["cosine", "cyclic"]

# Values accepted by ``Config.precision``: plain bit widths (as int or str)
# plus the "-true" / "-mixed" variants, grouped here by bit width family.
Precision = Literal[
    16, 32, 64,
    "16", "16-true", "16-mixed",
    "bf16-true", "bf16-mixed",
    "32", "32-true",
    "64", "64-true",
]


class ValidateEnum(str, Enum):
    """String-valued enum base class with helpers to list and check its values."""

    @classmethod
    def get_all_values(cls) -> list[str]:
        """Return the ``value`` of every member, in definition order."""
        return list(map(lambda member: member.value, cls))

    @classmethod
    def validate(cls, value: str) -> str:
        """Return ``value`` unchanged if it equals one of the member values.

        Raises:
            ValueError: if ``value`` is not among the member values; the message
                lists the allowed values and points to this file.
        """
        values = cls.get_all_values()
        if value in values:
            return value
        raise ValueError(f"\n\nInvalid value: '{value}'\n\nPossible values are: {values}\n\nSee {__file__}\n\n")


class Optimizer(ValidateEnum):
    """Optimizer names accepted by the ``Config.optimizer`` field validator."""

    AdamW = "AdamW"
    SGD = "SGD"


class InferenceStrategy(ValidateEnum):
    """Inference strategy names accepted by the ``Config.inference_strategy`` field validator."""

    SOFTMAX = "softmax"


class Head(ValidateEnum):
    """Head names accepted by the ``Config.head`` field validator."""

    Linear = "linear"
    # NOTE(review): presumably a linear head with normalization - confirm against the model code
    NLinear = "LinearNorm"


class Backbone(ValidateEnum):
    """Backbone identifiers accepted by the ``Config.backbone`` field validator.

    Members are grouped by model family; the links above each group point to
    the corresponding model listings.
    """

    # https://hf.co/docs/transformers/en/model_doc/clip
    # https://hf.co/openai/models?search=clip
    CLIP_B_16 = "openai/clip-vit-base-patch16"
    CLIP_B_32 = "openai/clip-vit-base-patch32"
    CLIP_L_14 = "openai/clip-vit-large-patch14"
    CLIP_L_14_336 = "openai/clip-vit-large-patch14-336"

    # https://hf.co/collections/facebook/perception-encoder-67f977c9a65ca5895a7f6ba1
    PerceptionEncoder_B_p16_224 = "vit_pe_core_base_patch16_224"  # (from timm)
    PerceptionEncoder_L_p14_336 = "vit_pe_core_large_patch14_336"  # (from timm)
    PerceptionEncoder_G_p14_448 = "vit_pe_core_gigantic_patch14_448"  # (from timm)

    # https://hf.co/models?search=facebook/dinov3
    DINOv3_ViT_B = "facebook/dinov3-vitb16-pretrain-lvd1689m"
    DINOv3_ViT_L = "facebook/dinov3-vitl16-pretrain-lvd1689m"


class BackboneArgs(Validation, validate_assignment=True):
    """Extra arguments for the backbone model (used as ``Config.backbone_args``)."""

    img_size: None | int = 224  # Image size for the backbone
    # How the CLS token is merged with the patch tokens: "cat" or "mean".
    # NOTE(review): None presumably means no merging - confirm against the backbone code.
    merge_cls_token_with_patches: None | Literal["cat", "mean"] = None  # Concatenate CLS token with patches


class Loss(Validation, validate_assignment=True):
    """Loss configuration (used as ``Config.loss``); all values default to 0.0."""

    # Cross-entropy loss (multi-class classification)
    ce_labels: float = 0.0  # Loss weight
    # NOTE(review): the original comment said "Loss weight"; the name suggests this is
    # the label-smoothing factor of the cross-entropy loss rather than a weight - confirm.
    label_smoothing: float = 0.0
    # Uniformity and alignment loss
    uniformity: float = 0.0  # Loss weight
    alignment_labels: float = 0.0  # Loss weight


class LoRA(Validation, validate_assignment=True):
    """LoRA hyperparameters (used as ``PEFT.lora``).

    NOTE(review): the field names resemble options of a PEFT-library LoRA
    config - confirm how they are mapped where this model is consumed.
    """

    target_modules: list[str] | str = ["out_proj"]  # Target modules
    rank: int = 1  # Rank of the decomposition
    alpha: int = 32  # Scaling factor
    dropout: float = 0.05  # Dropout probability
    bias: str = "none"  # Bias configuration
    use_rslora: bool = False  # Use rsLoRA
    use_dora: bool = False  # Use DoRA


class PEFT(Validation, validate_assignment=True):
    """Parameter-efficient fine-tuning settings (used as ``Config.peft_v2``)."""

    lora: None | LoRA = None  # LoRA configuration; defaults to None (unset)


class CustomPreprocessing(Validation, validate_assignment=True):
    """Custom input preprocessing settings (used as ``Config.custom_preprocessing``)."""

    zoom_factor: float = 1.0  # Zoom factor for the input images
    image_size: None | list[int] = None  # Target image size (width, height)
    flip_left_right: bool = False  # Whether to flip the image left-right (mirror)


class Augmentations(Validation, validate_assignment=True):
    """Augmentation settings (used as ``Config.augmentations`` / ``Config.test_augmentations``).

    Each field comment states the value that disables the corresponding
    augmentation; ``get_empty`` builds an instance with all of them disabled.
    """

    random_horizontal_flip: float = 0.5  # Probability of random horizontal flip, 0 - no augmentations
    random_affine_degrees: int = 10  # Random affine rotation degrees, 0 - no rotation
    random_affine_translate: None | list[float] = [0.1, 0.1]  # Random affine translation, None - no translation
    random_affine_scale: None | list[float] = [0.9, 1.1]  # Random affine scale, None - no scaling
    gaussian_blur_prob: float = 0.1  # Probability of applying Gaussian blur, 0 - no blur
    gaussian_blur_kernel_size: int | list[int] = 7  # Gaussian blur kernel size, 0 - no blur
    gaussian_blur_sigma: float | list[float] = [0.1, 2.0]  # Gaussian blur sigma
    color_jitter_brightness: float = 0.1  # Brightness jitter factor, 0 - no brightness jitter
    color_jitter_contrast: float = 0.1  # Contrast jitter factor, 0 - no contrast jitter
    jpeg_quality: int | list[int] = [40, 100]  # JPEG quality range, 100 - no JPEG compression
    resize: None | int | list[int] = None  # Resize to (width, height), None - no resizing
    # 0:nearest,  1:lanczos, 2:bilinear, 3:bicubic, 4:box, 5:hamming
    resize_interpolation: int = 2  # Interpolation method for resizing, see InterpolationMode or Pillow integer constant
    gaussian_noise_sigma: float = 0.0  # Standard deviation of Gaussian noise to add, 0 - no noise

    @classmethod
    def get_empty(cls) -> Self:
        """Return an instance with every augmentation set to its "disabled" value.

        This is a classmethod (not a staticmethod) so that the ``Self`` return
        annotation is valid and subclasses get an instance of their own type.
        ``resize_interpolation`` keeps its default; it has no effect while
        ``resize`` is None according to the field comments.
        """
        return cls(
            random_horizontal_flip=0.0,
            random_affine_degrees=0,
            random_affine_translate=None,
            random_affine_scale=None,
            gaussian_blur_prob=0.0,
            gaussian_blur_kernel_size=0,
            gaussian_blur_sigma=0.0,
            color_jitter_brightness=0.0,
            color_jitter_contrast=0.0,
            jpeg_quality=100,
            resize=None,
            # Explicit so the "empty" config does not depend on the field default.
            gaussian_noise_sigma=0.0,
        )


class Config(Validation, validate_assignment=True):
    """Top-level experiment configuration.

    Groups run, model, PEFT, data, optimization, training, logging and
    post-processing settings. The ``head``, ``backbone``, ``inference_strategy``
    and ``optimizer`` fields are plain strings checked against the matching
    ``ValidateEnum`` subclasses by the field validators below.
    """

    # Run configuration
    run_name: str = "exp-name-1"  # Name of the run
    run_dir: str = "runs/exp"  # Directory to save the run
    seed: int = 42  # Random seed for reproducibility
    throw_exception_if_run_exists: bool = False  # Throw an exception if the run directory exists
    remove_if_run_exists: bool = False  # Remove existing run directory if it exists

    # Model configuration
    num_classes: int = 2
    num_sources: int = 5
    checkpoint: None | str = None  # Path to a checkpoint to load
    backbone: str = Backbone.CLIP_B_32  # Backbone model to use
    backbone_args: None | BackboneArgs = None  # Arguments for the backbone model
    freeze_feature_extractor: bool = True  # Freeze the feature extractor
    unfreeze_layers: list[str] = []  # Layers to unfreeze
    head: str = Head.Linear  # Head model to use
    inference_strategy: str = "softmax"  # Inference strategy to use

    # PEFT configuration
    peft_v2: None | PEFT = None

    # Data configuration
    trn_files: list[str] | dict[str, list[str]] = []  # Files containing paths to training samples
    val_files: list[str] | dict[str, list[str]] = []  # Files containing paths to validation samples
    tst_files: list[str] | dict[str, list[str]] = []  # Files containing paths to test samples
    limit_trn_files: None | int = None  # Limit the number of training files
    limit_val_files: None | int = None  # Limit the number of validation files
    limit_tst_files: None | int = None  # Limit the number of test files
    binary_labels: bool = True  # Use binary labels
    custom_preprocessing: None | CustomPreprocessing = None  # Custom preprocessing pipeline
    augmentations: None | Augmentations = Augmentations()  # Training augmentations
    test_augmentations: None | Augmentations = None  # Test-time augmentations
    load_pairs: bool = False  # Whether to load csv files with paired videos

    # Optimization configuration
    lr: float = 0.0003  # Learning rate (initial / base)
    min_lr: float = 1e-6  # Minimum learning rate
    lr_scheduler: None | Scheduler = "cosine"  # Learning rate scheduler
    warmup_epochs: float = 0  # Number of warmup epochs (can be a fraction)
    num_epochs_in_cycle: float = 1  # Number of epochs in a cycle (for cyclic schedulers)
    optimizer: str = "AdamW"  # Optimizer to use
    weight_decay: float = 0.0  # AdamW weight decay
    betas: list[float] = [0.9, 0.999]  # First and second moment coefficients for SGD and AdamW
    loss: Loss = Loss()  # Loss function to use

    # Training configuration (managed by Lightning Trainer)
    max_epochs: int = 1  # Number of epochs to train
    batch_size: int = 512  # Required batch size to perform one step
    mini_batch_size: int = 512  # Mini batch size per device
    num_workers: int = 12  # Number of workers for the DataLoader
    devices: list[int] | str | int = "auto"  # Devices to use for training
    precision: Precision = "bf16-mixed"  # Precision for the model
    fast_dev_run: int | bool = False  # Run a fast development run
    overfit_batches: int | float = 0.0  # Overfit on a subset of the data
    limit_train_batches: None | int | float = None  # Limit the number of training batches
    limit_test_batches: None | int | float = None  # Limit the number of test batches
    limit_val_batches: None | int | float = None  # Limit the number of validation batches
    deterministic: None | bool = None  # Set random seed for reproducibility
    detect_anomaly: bool = False  # Detect anomalies in the model
    early_stopping_patience: int = -1  # Early stopping patience, -1 to disable
    checkpoint_name: str = "best_mAP"  # Checkpoint to use for testing
    monitor_metric: str = "val/mAP_video"  # Metric to monitor for early stopping and checkpointing
    monitor_metric_mode: str = "max"  # Mode for monitoring metric ("max" or "min")

    # Logging
    wandb: bool = False  # Log metrics to Weights & Biases
    wandb_tags: list[str] = []  # Tags to use for Weights & Biases
    wandb_group: None | str = None  # Group to use for Weights & Biases

    # Post-processing
    make_binary_before_video_aggregation: bool = True  # Make binary labels before video aggregation
    reduce_video_predictions: Literal["mean", "median"] = "mean"  # Reduce strategy for frame to video probs

    # Validation
    @field_validator("head")
    @classmethod
    def validate_head(cls, head: str) -> str:
        """Reject ``head`` values that are not a ``Head`` member value."""
        return Head.validate(head)

    @field_validator("backbone")
    @classmethod
    def validate_backbone(cls, backbone: str) -> str:
        """Reject ``backbone`` values that are not a ``Backbone`` member value."""
        return Backbone.validate(backbone)

    @field_validator("inference_strategy")
    @classmethod
    def validate_inference_strategy(cls, inference_strategy: str) -> str:
        """Reject ``inference_strategy`` values that are not an ``InferenceStrategy`` member value."""
        return InferenceStrategy.validate(inference_strategy)

    @field_validator("optimizer")
    @classmethod
    def validate_optimizer(cls, optimizer: str) -> str:
        """Reject ``optimizer`` values that are not an ``Optimizer`` member value."""
        return Optimizer.validate(optimizer)

    # NOTE: the parameter name ``dict`` shadows the builtin; it is kept as-is
    # because renaming it would break callers that pass it by keyword.
    def set_values_from_dict(self, dict: dict) -> Self:
        """
        Set values in the config from a dictionary. The keys of the dictionary can be
        either the names of the attributes in the config or a dot-separated path to the
        attribute. For example, if the config has an attribute `a.b.c`, you can set its
        value by passing a dictionary with the key `a.b.c`.

        Returns:
            This config instance, to allow chaining.

        Raises:
            AttributeError: if a dot-separated path goes through an attribute that
                does not exist or is currently None (e.g. an optional sub-config
                that has not been set), so there is nothing to assign into.
        """
        for key, value in dict.items():
            # A key without dots yields no parents, so the value is set on self directly.
            *parents, leaf = key.split(".")
            target = self
            for depth, name in enumerate(parents, start=1):
                target = getattr(target, name)
                if target is None:
                    # Fail with the offending path instead of the opaque
                    # "'NoneType' object has no attribute ..." from setattr.
                    prefix = ".".join(parents[:depth])
                    raise AttributeError(
                        f"Cannot set '{key}': '{prefix}' is None; assign a value to '{prefix}' first"
                    )
            setattr(target, leaf, value)
        return self


def load_config(path: str) -> Config:
    """Build a ``Config`` from a YAML file.

    The top-level YAML mapping is passed to ``Config`` as keyword arguments,
    so keys must be ``Config`` field names. An empty file yields a ``Config``
    with all defaults.

    Args:
        path: Path to the YAML file.

    Returns:
        The validated ``Config`` instance.

    Raises:
        TypeError: if the YAML document's top level is not a mapping.
    """
    import yaml

    # read yaml config (explicit encoding so the result does not depend on the locale)
    with open(path, encoding="utf-8") as f:
        overrides = yaml.safe_load(f)

    # An empty (or comment-only) document parses to None: treat it as "no overrides"
    # instead of failing on ``Config(**None)``.
    if overrides is None:
        overrides = {}
    if not isinstance(overrides, dict):
        raise TypeError(
            f"Config file '{path}' must contain a mapping at the top level, got {type(overrides).__name__}"
        )

    # overwrite config
    return Config(**overrides)