import os
from argparse import Namespace

import yaml

class DataConfig:
    def __init__(
        self,
        train_data_path: str,
        valid_data_path: str,
        batch_size: int,
        num_data_workers: int,
        prefetch_factor: int,
        time_delta_input_minutes: list[int],
        n_input_timestamps: int | None = None,
        pooling: int | None = None,
        random_vert_flip: bool = False,
        **kwargs,
    ):
        # Attach any additional config keys as attributes.
        self.__dict__.update(kwargs)

        self.train_data_path = train_data_path
        self.valid_data_path = valid_data_path
        self.batch_size = batch_size
        self.num_data_workers = num_data_workers
        self.prefetch_factor = prefetch_factor
        self.time_delta_input_minutes = sorted(time_delta_input_minutes)
        self.n_input_timestamps = n_input_timestamps
        self.pooling = pooling
        self.random_vert_flip = random_vert_flip

        # Default to using every configured input timestamp.
        if self.n_input_timestamps is None:
            self.n_input_timestamps = len(self.time_delta_input_minutes)

        assert (
            self.n_input_timestamps > 0
        ), "Number of input timestamps must be greater than 0."
        assert self.n_input_timestamps <= len(self.time_delta_input_minutes), (
            f"Cannot sample {self.n_input_timestamps} timestamps from "
            f"{self.time_delta_input_minutes}."
        )

    def to_dict(self):
        return self.__dict__

    @staticmethod
    def from_argparse(args: Namespace):
        return DataConfig(**args.__dict__)

    def __str__(self):
        return (
            f"Training index: {self.train_data_path}, "
            f"Validation index: {self.valid_data_path}"
        )

    __repr__ = __str__
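
# A minimal usage sketch (hypothetical values; when `n_input_timestamps` is
# omitted it defaults to the length of `time_delta_input_minutes`):
#
#   cfg = DataConfig(
#       train_data_path="data/train",
#       valid_data_path="data/valid",
#       batch_size=4,
#       num_data_workers=2,
#       prefetch_factor=2,
#       time_delta_input_minutes=[-180, -60, 0],
#   )
#   assert cfg.n_input_timestamps == 3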

class ModelConfig:
    def __init__(
        self,
        # Architecture hyperparameters (e.g. enc_num_layers, enc_num_heads,
        # enc_embed_size, dec_num_layers, dec_num_heads, dec_embed_size,
        # mask_ratio, mlp_ratio) are passed via kwargs and attached as
        # attributes.
        **kwargs,
    ):
        self.__dict__.update(kwargs)

    def to_dict(self):
        return self.__dict__

    @staticmethod
    def from_argparse(args: Namespace):
        return ModelConfig(**args.__dict__)

    def encoder_d_ff(self):
        # Feed-forward width: embedding size scaled by the MLP ratio.
        return int(self.enc_embed_size * self.mlp_ratio)

    def decoder_d_ff(self):
        return int(self.dec_embed_size * self.mlp_ratio)

    def __str__(self):
        return (
            f"Input channels: {self.in_channels}, "
            f"Encoder (L, H, E): {[self.enc_num_layers, self.enc_num_heads, self.enc_embed_size]}, "
            f"Decoder (L, H, E): {[self.dec_num_layers, self.dec_num_heads, self.dec_embed_size]}"
        )

    __repr__ = __str__
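
# A minimal sketch of the feed-forward width computation (hypothetical values):
#
#   model_cfg = ModelConfig(enc_embed_size=1024, dec_embed_size=512, mlp_ratio=4.0)
#   model_cfg.encoder_d_ff()  # 4096
#   model_cfg.decoder_d_ff()  # 2048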

class OptimizerConfig:
    def __init__(
        self,
        warm_up_steps: int,
        max_epochs: int,
        learning_rate: float,
        min_lr: float,
        **kwargs,
    ):
        # Attach any additional config keys as attributes.
        self.__dict__.update(kwargs)
        self.warm_up_steps = warm_up_steps
        self.max_epochs = max_epochs
        self.learning_rate = learning_rate
        self.min_lr = min_lr

    def to_dict(self):
        return self.__dict__

    @staticmethod
    def from_argparse(args: Namespace):
        return OptimizerConfig(**args.__dict__)

    def __str__(self):
        return (
            f"Epochs: {self.max_epochs}, "
            f"LR: {[self.learning_rate, self.min_lr]}, "
            f"Warm up: {self.warm_up_steps}"
        )

    __repr__ = __str__
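
# A hedged sketch of how these fields are typically consumed; the actual
# scheduler lives elsewhere, and the decay shape (cosine here) is an
# assumption, not something this module defines:
#
#   import math
#
#   def lr_at(step: int, total_steps: int, cfg: OptimizerConfig) -> float:
#       # Linear warm-up to the peak rate, then cosine decay toward min_lr.
#       if step < cfg.warm_up_steps:
#           return cfg.learning_rate * step / cfg.warm_up_steps
#       t = (step - cfg.warm_up_steps) / max(1, total_steps - cfg.warm_up_steps)
#       return cfg.min_lr + 0.5 * (cfg.learning_rate - cfg.min_lr) * (1 + math.cos(math.pi * t))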

class ExperimentConfig:
    def __init__(
        self,
        job_id: str,
        data_config: DataConfig,
        model_config: ModelConfig,
        optimizer_config: OptimizerConfig,
        path_experiment: str,
        parallelism: str,
        from_checkpoint: str | None = None,
        **kwargs,
    ):
        # Additional experiment parameters used in downstream tasks.
        self.__dict__.update(kwargs)

        self.job_id = job_id
        self.data = data_config
        self.model = model_config
        self.optimizer = optimizer_config
        self.path_experiment = path_experiment
        self.from_checkpoint = from_checkpoint
        self.parallelism = parallelism

        assert self.model.in_channels == len(self.data.channels), (
            f"Number of model input channels ({self.model.in_channels}) must be "
            f"equal to number of input variables ({len(self.data.channels)})."
        )

        if self.model.time_embedding["type"] == "linear":
            assert (
                self.model.time_embedding["time_dim"] == self.data.n_input_timestamps
            ), "Time dimension of linear embedding must be equal to number of input timestamps."

        if self.rollout_steps > 0:
            assert self.data.n_input_timestamps == len(
                self.data.time_delta_input_minutes
            ), "Rollout does not support randomly sampled input timestamps."

        # Collect every channel referenced by a metric definition. Metric
        # definitions are colon-separated strings whose optional third field
        # lists channels joined by "&" (or "..."), e.g. "rmse::T2M&U10M".
        metrics_channels = []
        for metrics_config in (
            self.metrics["train_metrics_config"],
            self.metrics["validation_metrics_config"],
        ):
            for group in metrics_config.values():
                for metric_definition in group.get("metrics", []):
                    split_metric_definition = metric_definition.split(":")
                    if len(split_metric_definition) > 2:
                        channels = split_metric_definition[2]
                        metrics_channels += channels.replace("...", "&").split("&")

        assert set(metrics_channels).issubset(self.data.channels), (
            f"{set(metrics_channels).difference(self.data.channels)} "
            f"not part of data input channels."
        )

        assert self.parallelism in [
            "ddp",
            "fsdp",
        ], 'Valid choices for `parallelism` are "ddp" and "fsdp".'

    @property
    def path_checkpoint(self) -> str:
        if self.path_experiment == "":
            return os.path.join(self.path_weights, "train", "checkpoint.pt")
        else:
            return os.path.join(
                os.path.dirname(self.path_experiment),
                "weights",
                "train",
                "checkpoint.pt",
            )

    @property
    def path_weights(self) -> str:
        return os.path.join(self.path_experiment, self.make_suffix_path(), "weights")

    @property
    def path_states(self) -> str:
        return os.path.join(self.path_experiment, self.make_suffix_path(), "states")
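
    # Resulting layout (derived from the properties above; `<job_id>` comes
    # from `make_suffix_path`):
    #
    #   <path_experiment>/<job_id>/weights
    #   <path_experiment>/<job_id>/states
    #   dirname(<path_experiment>)/weights/train/checkpoint.pt  (checkpoint,
    #       when `path_experiment` is non-empty)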

    def to_dict(self):
        # Serialize nested configs so the result is a plain dict.
        d = self.__dict__.copy()
        d["model"] = self.model.to_dict()
        d["data"] = self.data.to_dict()
        d["optimizer"] = self.optimizer.to_dict()
        return d

    @staticmethod
    def from_argparse(args: Namespace):
        return ExperimentConfig(
            data_config=DataConfig.from_argparse(args),
            model_config=ModelConfig.from_argparse(args),
            optimizer_config=OptimizerConfig.from_argparse(args),
            **args.__dict__,
        )

    @staticmethod
    def from_dict(params: dict):
        return ExperimentConfig(
            data_config=DataConfig(**params["data"]),
            model_config=ModelConfig(**params["model"]),
            optimizer_config=OptimizerConfig(**params["optimizer"]),
            **params,
        )
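
    # A hedged sketch of the dict `from_dict` expects (key names taken from
    # the code above; the concrete values are hypothetical, and the elided
    # sections must satisfy the assertions in `__init__`):
    #
    #   params = {
    #       "job_id": "run-001",
    #       "path_experiment": "experiments/run-001",
    #       "parallelism": "ddp",
    #       "rollout_steps": 0,
    #       "metrics": {"train_metrics_config": {...}, "validation_metrics_config": {...}},
    #       "data": {...},       # DataConfig kwargs
    #       "model": {...},      # ModelConfig kwargs
    #       "optimizer": {...},  # OptimizerConfig kwargs
    #   }
    #   config = ExperimentConfig.from_dict(params=params)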

    def make_folder_name(self) -> str:
        return "wpt-c1-s1"

    def make_suffix_path(self) -> str:
        return self.job_id

    def __str__(self):
        return (
            f"ID: {self.job_id}, "
            f"Epochs: {self.optimizer.max_epochs}, "
            f"Batch size: {self.data.batch_size}, "
            f"LR: {[self.optimizer.learning_rate, self.optimizer.min_lr]}, "
            f"Warm up: {self.optimizer.warm_up_steps}, "
            f"DL workers: {self.data.num_data_workers}, "
            f"Parallelism: {self.parallelism}"
        )

    __repr__ = __str__

def get_config(
    config_path: str,
) -> ExperimentConfig:
    # Load the experiment YAML and inline the scaler definitions referenced
    # by `data.scalers_path`, closing both files promptly.
    with open(config_path, "r") as f:
        cfg = yaml.safe_load(f)
    with open(cfg["data"]["scalers_path"], "r") as f:
        cfg["data"]["scalers"] = yaml.safe_load(f)
    return ExperimentConfig.from_dict(params=cfg)
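
# A minimal usage sketch (hypothetical path; the YAML must provide the keys
# shown in the `from_dict` sketch above, plus `data.scalers_path`):
#
#   config = get_config("configs/experiment.yaml")
#   print(config)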