# hispath/system/base.py
# Author: kohido -- commit 148d42e ("init")
from dataclasses import dataclass, field
from typing import *
from .ops import parse_optimizer, parse_scheduler, parse_loss
from utils import parse_structure
from metrics import BaseMetrics
import lightning.pytorch as pl
import torch
import os
@dataclass
class BaseSystemConfig:
    """Structured configuration consumed by ``BaseSystem``.

    Sub-configuration routing inside ``BaseSystem``:
      * ``optimizer`` -> ``parse_optimizer``
      * ``scheduler`` -> ``parse_scheduler``
      * ``loss``      -> ``parse_loss``
      * ``metrics``   -> ``BaseMetrics``
    The expected keys of each dict are defined by those helpers, not here.
    """

    # Not read by BaseSystem itself; presumably consumed by subclasses or a
    # model factory -- confirm against callers.
    model_type: str = "BaseModel"

    # Free-form sub-configurations. default_factory gives every instance its
    # own fresh empty dict instead of one shared mutable default.
    model: Dict = field(default_factory=dict)  # not read by BaseSystem itself
    optimizer: Dict = field(default_factory=dict)
    scheduler: Dict = field(default_factory=dict)
    loss: Dict = field(default_factory=dict)
    metrics: Dict = field(default_factory=dict)

    # Forwarded verbatim to ``self.log(...)`` by ``BaseSystem.log_metrics``.
    log_on_step: bool = False
    log_on_epoch: bool = True
    log_prog_bar: bool = True
    log_logger: bool = True
class BaseSystem(pl.LightningModule):
    """Base LightningModule that builds loss, metrics, optimizer and scheduler
    from a parsed ``BaseSystemConfig``.

    Subclasses are expected to assign ``self.model`` (the object handed to
    ``parse_optimizer``) before ``configure_optimizers`` runs. They should also
    call :meth:`set_save_dir` before fitting if a ``done.txt`` completion
    marker is wanted.
    """

    cfg: BaseSystemConfig

    def __init__(self, cfg: Dict, *args: Any, **kwargs: Any) -> None:
        """Parse ``cfg`` and build the loss criterion and metrics helper.

        Args:
            cfg: Raw configuration mapping, converted with
                ``parse_structure(BaseSystemConfig, cfg)``.
            *args, **kwargs: Forwarded to ``pl.LightningModule.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.cfg = parse_structure(BaseSystemConfig, cfg)
        self.criterion = parse_loss(self.cfg.loss)
        self.metrics_func = BaseMetrics(self.cfg.metrics)

    def log_metrics(self, metrics: Dict[str, float]) -> None:
        """Log every ``name -> value`` pair using the config's logging flags."""
        for name, value in metrics.items():
            self.log(
                name,
                value,
                on_step=self.cfg.log_on_step,
                on_epoch=self.cfg.log_on_epoch,
                prog_bar=self.cfg.log_prog_bar,
                logger=self.cfg.log_logger,
            )

    def __str__(self) -> str:
        return self.__class__.__name__

    def set_save_dir(self, path: str) -> None:
        """Record ``path`` as the trial directory and create it if missing."""
        self.trial_dir = path
        os.makedirs(self.trial_dir, exist_ok=True)

    def configure_optimizers(self):
        """Build the optimizer and LR scheduler from the config.

        Returns:
            ``{"optimizer": ..., "lr_scheduler": {"scheduler": ...}}``.

        Raises:
            ValueError: if ``self.model`` was never assigned or is None.
        """
        # getattr with a default: if the attribute was never assigned, plain
        # ``self.model`` raises AttributeError and the intended ValueError
        # below would never be reached.
        model = getattr(self, "model", None)
        if model is None:
            raise ValueError("self.model is not initialized")
        optimizer = parse_optimizer(self.cfg.optimizer, model)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": parse_scheduler(self.cfg.scheduler, optimizer)},
        }

    def on_fit_start(self) -> None:
        """Announce the start of fitting on stdout."""
        super().on_fit_start()
        print('[INFO]: Experiment Started')

    def on_fit_end(self) -> None:
        """Announce the end of fitting and write an empty ``done.txt`` marker.

        The marker goes into the directory given to :meth:`set_save_dir`.
        If no directory was set, the marker is skipped with a warning rather
        than raising AttributeError after training has already finished.
        """
        super().on_fit_end()
        print('[INFO]: Experiment Ended')
        trial_dir = getattr(self, "trial_dir", None)
        if trial_dir is None:
            print('[WARNING]: trial_dir is not set; skipping done.txt marker')
            return
        # Opening in 'w' mode creates (or truncates) the file. The context
        # manager closes it, so no explicit close() is needed.
        with open(os.path.join(trial_dir, 'done.txt'), 'w'):
            pass