from dataclasses import dataclass, field
from typing import Any, Dict, Optional

from .ops import parse_optimizer, parse_scheduler, parse_loss
from utils import parse_structure
from metrics import BaseMetrics
import lightning.pytorch as pl
import torch
import os


@dataclass
class BaseSystemConfig:
    """Configuration shared by all systems: sub-configs for the model,
    optimizer, scheduler, loss, and metrics, plus logging flags."""
    model_type: str = 'BaseModel'
    model: Dict = field(default_factory=dict)
    optimizer: Dict = field(default_factory=dict)
    scheduler: Dict = field(default_factory=dict)
    loss: Dict = field(default_factory=dict)
    metrics: Dict = field(default_factory=dict)
    log_on_step: bool = False
    log_on_epoch: bool = True
    log_prog_bar: bool = True
    log_logger: bool = True


class BaseSystem(pl.LightningModule):
    """Base LightningModule that wires a config dict into a loss, metrics,
    optimizer, and scheduler. Subclasses are expected to build ``self.model``."""

    cfg: BaseSystemConfig

    def __init__(self, cfg: Dict, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.cfg = parse_structure(BaseSystemConfig, cfg)
        self.criterion = parse_loss(self.cfg.loss)
        self.metrics_func = BaseMetrics(self.cfg.metrics)
        # Subclasses must construct the actual network from self.cfg.model;
        # configure_optimizers() checks that this has happened.
        self.model: Optional[torch.nn.Module] = None

    def log_metrics(self, metrics: Dict[str, float]) -> None:
        # Log every metric with the flags set in the config.
        for name, value in metrics.items():
            self.log(
                name, value,
                on_step=self.cfg.log_on_step,
                on_epoch=self.cfg.log_on_epoch,
                prog_bar=self.cfg.log_prog_bar,
                logger=self.cfg.log_logger,
            )

    def __str__(self) -> str:
        return self.__class__.__name__

    def set_save_dir(self, path: str) -> None:
        # Directory where run artifacts (e.g. the completion marker) are written.
        self.trial_dir = path
        os.makedirs(self.trial_dir, exist_ok=True)

    def configure_optimizers(self):
        if self.model is None:
            raise ValueError('self.model is not initialized')
        optimizer = parse_optimizer(self.cfg.optimizer, self.model)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": parse_scheduler(self.cfg.scheduler, optimizer)},
        }

    def on_fit_start(self) -> None:
        super().on_fit_start()
        print('[INFO]: Experiment Started')

    def on_fit_end(self) -> None:
        super().on_fit_end()
        print('[INFO]: Experiment Ended')
        # Drop an empty marker file so external tooling can tell the run finished.
        with open(os.path.join(self.trial_dir, 'done.txt'), 'w'):
            pass
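
# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the original module): one way a
# concrete system could subclass BaseSystem. The class name, the MLP
# architecture, the `in_dim`/`out_dim` config keys, and the training_step
# body are assumptions for demonstration; a real subclass should build its
# network from `self.cfg.model` however this project's model registry expects.
# ---------------------------------------------------------------------------
class ExampleRegressionSystem(BaseSystem):
    def __init__(self, cfg: Dict, *args: Any, **kwargs: Any) -> None:
        super().__init__(cfg, *args, **kwargs)
        # BaseSystem leaves self.model as None; building it here makes
        # configure_optimizers() usable.
        self.model = torch.nn.Sequential(
            torch.nn.Linear(self.cfg.model.get('in_dim', 16), 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, self.cfg.model.get('out_dim', 1)),
        )

    def training_step(self, batch, batch_idx):
        x, y = batch
        # self.criterion comes from parse_loss(self.cfg.loss) in BaseSystem.
        loss = self.criterion(self.model(x), y)
        self.log_metrics({'train/loss': loss})
        return loss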