| |
| |
| |
| |
| |
| |
| |
| import os |
| from dataclasses import dataclass |
| from model import ModelConfig, CoSeNetConfig, TransformerConfig |
| from dataset import DatasetConfig |
|
|
|
|
| |
| |
| |
| @dataclass |
| class SetupConfig: |
| """ |
| Configuration parameters related to the execution environment and logging. |
| |
| This configuration controls device selection, checkpointing behavior, |
| reproducibility settings, and logging paths for an experiment. |
| """ |
| device_number: int = 0 |
| save_model_each: int = 0 |
| seed: int = None |
| logging_path: str = None |
| reload_checkpoint: bool = False |
|
|
|
|
def overwrite_setup_config() -> SetupConfig:
    """
    Build the setup configuration used for this experiment.

    Only the logging path, the checkpoint-reload flag and the model saving
    frequency are overridden; every other field keeps the default declared
    on ``SetupConfig``.

    Returns:
        SetupConfig: The configured setup configuration object.
    """
    # Passing the overrides as keyword arguments yields the same object as
    # creating a default instance and assigning the fields afterwards.
    return SetupConfig(
        logging_path=r'/workspace/logs',
        reload_checkpoint=True,
        save_model_each=1,
    )
|
|
|
|
| |
| |
| |
@dataclass
class TrainConfig:
    """
    Top-level container describing a training run.

    Bundles the model, dataset and setup sub-configurations together with
    the batching, epoch and optimizer hyperparameters.

    Attributes:
        model_config: Model configuration, or ``None`` until attached.
        dataset_config: Dataset configuration, or ``None`` until attached.
        setup_config: Setup configuration, or ``None`` until attached.
        batch_size: Batch size. Defaults to ``32``.
        num_epochs: Number of training epochs. Defaults to ``100``.
        learning_rate: Learning rate. Defaults to ``1e-4``.
        learning_rate_min: Minimum learning rate. Defaults to ``1e-5``.
            Presumably the lower bound used by a scheduler; confirm
            against the training loop.
        weight_decay: Weight decay coefficient. Defaults to ``1e-8``.
        betas: Pair of optimizer beta coefficients. Defaults to
            ``(0.5, 0.999)``. Presumably for an Adam-style optimizer;
            confirm against the training loop.
    """

    # --- Sub-configurations (unset by default) ---
    model_config: ModelConfig | None = None
    dataset_config: DatasetConfig | None = None
    setup_config: SetupConfig | None = None

    # --- Batching and schedule length ---
    batch_size: int = 32
    num_epochs: int = 100

    # --- Optimizer hyperparameters ---
    learning_rate: float = 1e-4
    learning_rate_min: float = 1e-5
    weight_decay: float = 1e-8
    betas: tuple[float, float] = (0.5, 0.999)
|
|
|
|
def overwrite_train_config() -> TrainConfig:
    """
    Build the training configuration used for this experiment.

    Overrides the batch size, the number of epochs, both learning-rate
    values and the weight decay. The sub-configurations and ``betas`` keep
    the defaults declared on ``TrainConfig``.

    Returns:
        TrainConfig: The configured training configuration object.
    """
    # Keyword construction is equivalent to instantiating the defaults and
    # then assigning each overridden field.
    return TrainConfig(
        batch_size=4,
        num_epochs=200,
        learning_rate=5e-4,
        learning_rate_min=5e-5,
        weight_decay=1e-6,
    )
|
|
|
|
| |
| |
| |
def overwrite_dataset_config() -> DatasetConfig:
    """
    Build the dataset configuration used for this experiment.

    Assigns the data path and the usage percentage of the training,
    validation and test splits on a default-constructed ``DatasetConfig``.

    Returns:
        DatasetConfig: The configured dataset configuration object.
    """
    # Attribute name -> value, applied in insertion order.
    overrides = {
        "train_data_path": r"/workspace/data/tokens-A000-segmentation",
        "val_data_path": r"/workspace/data/tokens-A001-segmentation",
        "test_data_path": r"/workspace/data/tokens-A002-segmentation",
        "train_percentage": 1.0,
        "val_percentage": 1.0,
        "test_percentage": 1.0,
    }

    dataset_cfg = DatasetConfig()
    for attribute, value in overrides.items():
        setattr(dataset_cfg, attribute, value)
    return dataset_cfg
|
|
|
|
| |
| |
| |
def overwrite_model_config() -> ModelConfig:
    """
    Build the model configuration used for this experiment.

    Sets the vocabulary size, the model dimensionality and the padding
    flag, attaches a ``CoSeNetConfig`` and a list of two
    ``TransformerConfig`` objects that share the same settings.

    Returns:
        ModelConfig: The configured model configuration object.
    """
    model_cfg = ModelConfig()

    # Top-level model parameters.
    model_cfg.vocab_size = 32_768
    model_cfg.model_dim = 256
    model_cfg.valid_padding = True

    # CoSeNet settings.
    cosenet_cfg = CoSeNetConfig(trainable=True, init_scale=5.0)
    model_cfg.cosenet = cosenet_cfg

    # Both Transformer entries use identical settings, so the keyword set
    # is declared once and a separate config object is created per entry.
    num_transformers = 2
    transformer_kwargs = {
        "attention_heads": 16,
        "feed_forward_multiplier": 8,
        "dropout": 0.0,
        "pre_normalize": True,
    }
    model_cfg.transformers = [
        TransformerConfig(**transformer_kwargs)
        for _ in range(num_transformers)
    ]

    return model_cfg
|
|
|
|
| |
| |
| |
def configuration() -> TrainConfig:
    """
    Create the experiment configuration.

    Assembles the training configuration with its setup, model and dataset
    sub-configurations, then validates the training and validation data
    paths and usage percentages.

    Returns:
        TrainConfig: The fully assembled configuration object.

    Raises:
        FileNotFoundError: If the training or validation data path does
            not exist.
        ValueError: If the training or validation percentage is outside
            the interval (0.0, 1.0].
    """
    experiment = overwrite_train_config()
    experiment.setup_config = overwrite_setup_config()
    experiment.model_config = overwrite_model_config()
    experiment.dataset_config = overwrite_dataset_config()

    data_cfg = experiment.dataset_config

    # Paths are checked first (train, then validation); attributes are read
    # lazily so the checks run in the same order as independent statements.
    for label, attribute in (
        ("Train", "train_data_path"),
        ("Validation", "val_data_path"),
    ):
        data_path = getattr(data_cfg, attribute)
        if not os.path.exists(data_path):
            raise FileNotFoundError(f"{label} data path does not exist: {data_path}")

    # Percentages must lie in the half-open interval (0.0, 1.0].
    for label, attribute in (
        ("Train", "train_percentage"),
        ("Validation", "val_percentage"),
    ):
        if not 0.0 < getattr(data_cfg, attribute) <= 1.0:
            raise ValueError(f"{label} percentage must be in (0.0, 1.0]")

    return experiment
| |
| |
| |
|
|