from dataclasses import dataclass, field
from typing import Any, Dict, Mapping, Sequence
import os

import numpy as np
import torch
from torch import nn, Tensor

import models
from utils import parse_structure
from .base import BaseSystemConfig, BaseSystem


@dataclass
class SimpleClassificationConfig(BaseSystemConfig):
    """Configuration for ``SimpleClassificationSystem``.

    Declares no fields of its own. The options read by the system
    (``model_type``, ``model``, ``log_on_step``, ``log_on_epoch``,
    ``log_prog_bar``, ``log_logger``) are presumably inherited from
    ``BaseSystemConfig`` -- verify against ``.base``.
    """


class SimpleClassificationSystem(BaseSystem):
    """Run one model through train / valid / test with a shared step.

    Every ``*_step`` takes ``x = batch[0]`` and ``y = batch[1].float()``,
    runs ``self.model(x).squeeze(-1)``, computes ``self.criterion``, logs the
    loss as ``"<stage>/loss"`` and forwards ``self.metrics_func(...)`` to
    ``self.log_metrics``.

    NOTE(review): ``criterion``, ``metrics_func`` and ``log_metrics`` are not
    defined here; they are assumed to be provided by ``BaseSystem``. The
    ``self.log(..., on_step=, on_epoch=, prog_bar=, logger=)`` call matches
    the PyTorch Lightning ``LightningModule.log`` signature, so ``BaseSystem``
    is presumably a ``LightningModule`` -- confirm.
    """

    def __init__(self, cfg: Dict, *args: Any, **kwargs: Any) -> None:
        """Parse ``cfg`` and build the wrapped model.

        Args:
            cfg: Raw configuration mapping; also handed unchanged to
                ``BaseSystem.__init__``.
            *args, **kwargs: Forwarded to ``BaseSystem.__init__``.
        """
        super().__init__(cfg, *args, **kwargs)
        # Re-parse into the concrete config type so attribute access is typed.
        self.cfg: SimpleClassificationConfig = parse_structure(SimpleClassificationConfig, cfg)
        # The model class/factory is looked up by name in the ``models``
        # package and called with the ``model`` sub-config.
        self.model: nn.Module = getattr(models, self.cfg.model_type)(self.cfg.model)

    def forward(self, x: Tensor) -> Tensor:
        """Return the raw output of the wrapped model for ``x``."""
        return self.model(x)

    def _shared_step(self, batch: Sequence[Tensor], stage: str) -> Tensor:
        """Compute, log and return the loss for one batch of ``stage``.

        Args:
            batch: Positional batch; ``batch[0]`` is the input and
                ``batch[1]`` the target (assumed ``(inputs, targets)``
                pairs, e.g. from a ``TensorDataset`` -- confirm).
            stage: ``"train"``, ``"valid"`` or ``"test"``. It is used both as
                the log-key prefix and as the stage tag for ``metrics_func``.

        Returns:
            The scalar loss from ``self.criterion``.
        """
        x: Tensor = batch[0]
        # Targets are cast to float, presumably for a BCE-style criterion.
        y: Tensor = batch[1].float()
        # Drop the trailing dim so predictions line up with ``y``; assumes the
        # model emits shape (batch, 1) -- TODO confirm.
        y_hat: Tensor = self.model(x).squeeze(-1)
        loss = self.criterion(y_hat, y)
        self.log(
            f"{stage}/loss", loss,
            on_step=self.cfg.log_on_step,
            on_epoch=self.cfg.log_on_epoch,
            prog_bar=self.cfg.log_prog_bar,
            logger=self.cfg.log_logger,
        )
        self.log_metrics(self.metrics_func(y_hat, y, stage))
        return loss

    def training_step(self, batch: Sequence[Tensor], batch_idx: int) -> Tensor:
        """Return the training loss for ``batch`` (logged as ``train/loss``)."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch: Sequence[Tensor], batch_idx: int) -> Tensor:
        """Return the validation loss for ``batch`` (logged as ``valid/loss``)."""
        return self._shared_step(batch, "valid")

    def test_step(self, batch: Sequence[Tensor], batch_idx: int) -> Tensor:
        """Return the test loss for ``batch`` (logged as ``test/loss``).

        The original version returned ``None`` despite its ``-> Tensor``
        annotation. It now returns the loss like the other two steps.
        """
        return self._shared_step(batch, "test")