# Commit dbd79bd (alverciito): upload safetensors and refactor research files
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# This file was created by: Alberto Palomo Alonso #
# Universidad de Alcalá - Escuela Politécnica Superior #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import os
from dataclasses import dataclass
from model import ModelConfig, CoSeNetConfig, TransformerConfig
from dataset import DatasetConfig
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# SETUP CONFIGURATION #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
@dataclass
class SetupConfig:
"""
Configuration parameters related to the execution environment and logging.
This configuration controls device selection, checkpointing behavior,
reproducibility settings, and logging paths for an experiment.
"""
device_number: int = 0
save_model_each: int = 0
seed: int = None
logging_path: str = None
reload_checkpoint: bool = False
def overwrite_setup_config() -> SetupConfig:
    """
    Build the setup configuration for this experiment.

    Overrides the defaults to enable checkpoint reloading, save the model
    every epoch, and direct logging to the workspace log directory.

    Returns:
        SetupConfig: The customized setup configuration object.
    """
    setup = SetupConfig()
    # Checkpointing: resume from disk and persist every epoch.
    setup.reload_checkpoint = True
    setup.save_model_each = 1
    # Logging destination:
    setup.logging_path = r'/workspace/logs'
    return setup
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# TRAINING CONFIGURATION #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
@dataclass
class TrainConfig:
    """
    Training configuration container.

    This dataclass aggregates model, dataset, and setup configurations,
    together with optimization and training hyperparameters.
    """
    # Linked configurations (populated by `configuration()` below):
    model_config: ModelConfig | None = None
    dataset_config: DatasetConfig | None = None
    setup_config: SetupConfig | None = None
    # Training parameters:
    batch_size: int = 32    # samples per optimization step
    num_epochs: int = 100   # full passes over the training data
    # Optimizer parameters:
    # NOTE(review): betas suggest an Adam-style optimizer — confirm with the
    # training loop, which is not visible in this file.
    learning_rate: float = 1e-4        # initial learning rate
    learning_rate_min: float = 1e-5    # lower bound, presumably for a decaying LR schedule
    weight_decay: float = 1e-8
    betas: tuple[float, float] = (0.5, 0.999)
def overwrite_train_config() -> TrainConfig:
    """
    Build the training configuration for this experiment.

    Overrides the default batch size, epoch count, and optimizer
    hyperparameters (learning-rate bounds and weight decay).

    Returns:
        TrainConfig: The customized training configuration object.
    """
    train_cfg = TrainConfig()
    # Optimizer schedule: decay from 5e-4 down to a floor of 5e-5.
    train_cfg.learning_rate = 5e-4
    train_cfg.learning_rate_min = 5e-5
    train_cfg.weight_decay = 1e-6
    # Batching and duration:
    train_cfg.num_epochs = 200
    train_cfg.batch_size = 4
    return train_cfg
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# DATASET CONFIGURATION #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
def overwrite_dataset_config() -> DatasetConfig:
    """
    Build the dataset configuration for this experiment.

    Points the train, validation, and test splits at their token
    directories and uses 100% of every split.

    Returns:
        DatasetConfig: The customized dataset configuration object.
    """
    data_cfg = DatasetConfig()
    # Token directories, one per split:
    data_cfg.train_data_path = r"/workspace/data/tokens-A000-segmentation"
    data_cfg.val_data_path = r"/workspace/data/tokens-A001-segmentation"
    data_cfg.test_data_path = r"/workspace/data/tokens-A002-segmentation"
    # Consume every sample in each split:
    for split in ("train", "val", "test"):
        setattr(data_cfg, f"{split}_percentage", 1.0)
    return data_cfg
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# MODEL CONFIGURATION #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
def overwrite_model_config() -> ModelConfig:
    """
    Build the model configuration for this experiment.

    Defines the architecture-level parameters: vocabulary size, embedding
    dimensionality, padding mode, the CoSeNet stage settings, and a stack
    of two identical Transformer encoder configurations.

    Returns:
        ModelConfig: The customized model configuration object.
    """
    model_cfg = ModelConfig()
    # Top-level architecture parameters:
    model_cfg.vocab_size = 32_768
    model_cfg.model_dim = 256
    model_cfg.valid_padding = True
    # CoSeNet stage (trainable, with a scaled initialization):
    model_cfg.cosenet = CoSeNetConfig(trainable=True, init_scale=5.0)
    # Two identical pre-norm Transformer encoder layers, no dropout:
    model_cfg.transformers = [
        TransformerConfig(
            attention_heads=16,
            feed_forward_multiplier=8,
            dropout=0.0,
            pre_normalize=True
        )
        for _ in range(2)
    ]
    return model_cfg
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# WHOLE CONFIGURATION #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
def configuration() -> TrainConfig:
    """
    Create the full experiment configuration.

    Aggregates the training, setup, model, and dataset configurations,
    then validates the dataset paths and usage percentages.

    Returns:
        TrainConfig: The fully assembled configuration object.

    Raises:
        FileNotFoundError: If a dataset path does not exist on disk.
        ValueError: If a usage percentage is outside the range (0.0, 1.0].
    """
    config = overwrite_train_config()
    config.setup_config = overwrite_setup_config()
    config.model_config = overwrite_model_config()
    config.dataset_config = overwrite_dataset_config()
    # Validate every split uniformly. Fixed: the test split was configured
    # above but never validated, unlike train and validation.
    dataset = config.dataset_config
    splits = (
        ("Train", dataset.train_data_path, dataset.train_percentage),
        ("Validation", dataset.val_data_path, dataset.val_percentage),
        ("Test", dataset.test_data_path, dataset.test_percentage),
    )
    for split_name, data_path, percentage in splits:
        if not os.path.exists(data_path):
            raise FileNotFoundError(f"{split_name} data path does not exist: {data_path}")
        if not 0.0 < percentage <= 1.0:
            raise ValueError(f"{split_name} percentage must be in (0.0, 1.0]")
    return config
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# END OF FILE #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #