|
|
from dataclasses import dataclass, field |
|
|
from typing import * |
|
|
from .ops import parse_optimizer, parse_scheduler, parse_loss |
|
|
from utils import parse_structure |
|
|
from metrics import BaseMetrics |
|
|
|
|
|
import lightning.pytorch as pl |
|
|
import torch |
|
|
import os |
|
|
|
|
|
@dataclass
class BaseSystemConfig:
    """Structured configuration consumed by ``BaseSystem``.

    The dict-valued sections are handed to the project's parser helpers
    (``parse_loss``, ``parse_optimizer``, ``parse_scheduler``) and to
    ``BaseMetrics``; their expected schema is defined by those helpers and
    is not visible here.
    """

    # Model selection. Not read by BaseSystem itself in this file;
    # presumably used by subclasses to build ``self.model`` -- TODO confirm.
    model_type: str = 'BaseModel'
    model: Dict = field(default_factory=dict)

    # Training components; each section is passed to the matching parser.
    optimizer: Dict = field(default_factory=dict)
    scheduler: Dict = field(default_factory=dict)
    loss: Dict = field(default_factory=dict)
    metrics: Dict = field(default_factory=dict)

    # Flags forwarded verbatim to ``self.log`` by ``BaseSystem.log_metrics``
    # as on_step / on_epoch / prog_bar / logger respectively.
    log_on_step: bool = False
    log_on_epoch: bool = True
    log_prog_bar: bool = True
    log_logger: bool = True
|
|
|
|
|
class BaseSystem(pl.LightningModule):
    """Base Lightning module wiring config, loss, metrics and optimizers.

    The constructor parses ``cfg`` into a ``BaseSystemConfig`` and builds the
    loss (``self.criterion``) and metrics (``self.metrics_func``) objects.
    It does not create ``self.model``; that attribute must be assigned
    elsewhere (presumably by a subclass) before ``configure_optimizers`` runs.
    """

    cfg: BaseSystemConfig

    def __init__(self, cfg: Dict, *args: Any, **kwargs: Any) -> None:
        """Parse the configuration and build the loss and metrics helpers.

        Args:
            cfg: Raw configuration passed to ``parse_structure`` together
                with ``BaseSystemConfig``.
            *args: Forwarded to ``pl.LightningModule.__init__``.
            **kwargs: Forwarded to ``pl.LightningModule.__init__``.
        """
        super().__init__(*args, **kwargs)

        self.cfg = parse_structure(BaseSystemConfig, cfg)
        self.criterion = parse_loss(self.cfg.loss)
        self.metrics_func = BaseMetrics(self.cfg.metrics)

    def log_metrics(self, metrics: Dict[str, float]):
        """Log every ``name -> value`` pair using the flags from the config.

        Args:
            metrics: Mapping of metric name to the value to log.
        """
        for name, value in metrics.items():
            self.log(
                name,
                value,
                on_step=self.cfg.log_on_step,
                on_epoch=self.cfg.log_on_epoch,
                prog_bar=self.cfg.log_prog_bar,
                logger=self.cfg.log_logger,
            )

    def __str__(self):
        """Return the concrete class name."""
        return self.__class__.__name__

    def set_save_dir(self, path: str):
        """Remember ``path`` as the trial directory and create it if missing.

        Args:
            path: Directory stored on ``self.trial_dir``.
        """
        self.trial_dir = path
        os.makedirs(self.trial_dir, exist_ok=True)

    def configure_optimizers(self):
        """Build the optimizer and LR scheduler from the config.

        Returns:
            A dict with the ``"optimizer"`` and an ``"lr_scheduler"`` entry
            wrapping the scheduler returned by ``parse_scheduler``.

        Raises:
            ValueError: If ``self.model`` is missing or ``None``.
        """
        # ``__init__`` never assigns ``model``; on an nn.Module a missing
        # attribute raises AttributeError, so a plain ``self.model is None``
        # test could never reach the intended ValueError.
        model = getattr(self, 'model', None)
        if model is None:
            raise ValueError('self.model is not initialized')

        optimizer = parse_optimizer(self.cfg.optimizer, model)
        scheduler = parse_scheduler(self.cfg.scheduler, optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler},
        }

    def on_fit_start(self) -> None:
        """Announce the start of the experiment."""
        super().on_fit_start()
        print('[INFO]: Experiment Started')

    def on_fit_end(self) -> None:
        """Announce the end of the experiment and drop a ``done.txt`` marker.

        The marker is an empty file created inside ``self.trial_dir``. When
        ``set_save_dir`` was never called there is no directory to write to,
        so the marker is skipped instead of failing at the end of the run.
        """
        super().on_fit_end()
        print('[INFO]: Experiment Ended')

        trial_dir = getattr(self, 'trial_dir', None)
        if trial_dir is None:
            print('[WARN]: save dir is not set, skipping done.txt')
            return

        # Opening in 'w' mode creates/truncates the file; the context
        # manager closes it, so no explicit close() is needed.
        with open(os.path.join(trial_dir, 'done.txt'), 'w'):
            pass