from dataclasses import dataclass, field
from torch import nn
from typing import *
from utils import parse_structure
import torch.nn.functional as F
import torch


@dataclass
class OptConfig:
    """Configuration for building an optimizer from ``torch.optim``."""

    # Name of the optimizer class looked up on ``torch.optim`` (e.g. 'Adam').
    name: str = 'Adam'
    # Keyword arguments forwarded to the optimizer constructor.
    args: Dict = field(default_factory=dict)


@dataclass
class LossConfig:
    """Configuration for building a loss module from ``torch.nn``."""

    # Name of the loss class looked up on ``torch.nn`` (e.g. 'BCELoss').
    name: str = 'BCELoss'
    # Keyword arguments forwarded to the loss constructor.
    args: Dict = field(default_factory=dict)


@dataclass
class SchedulerConfig:
    """Configuration for building a scheduler from ``torch.optim.lr_scheduler``."""

    # Name of the scheduler class looked up on ``torch.optim.lr_scheduler``.
    # ConstantLR is used as the default because it has no required
    # constructor arguments, so an empty ``args`` still constructs.
    name: str = 'ConstantLR'
    # Keyword arguments forwarded to the scheduler constructor
    # (after the optimizer, which is always passed positionally).
    args: Dict = field(default_factory=dict)


def _resolve(namespace: Any, name: str, what: str) -> Any:
    """Return attribute ``name`` of ``namespace``.

    Raises:
        AttributeError: if ``namespace`` has no such attribute. The message
            names both the kind of object (``what``) and the namespace, to make
            config typos easier to diagnose.
    """
    try:
        return getattr(namespace, name)
    except AttributeError:
        raise AttributeError(
            f"Unknown {what} {name!r}: not found in "
            f"{getattr(namespace, '__name__', namespace)}"
        ) from None


def parse_loss(cfg: Dict) -> nn.Module:
    """Build a loss module from a config dict.

    Args:
        cfg: Raw config, parsed into a ``LossConfig`` via ``parse_structure``.

    Returns:
        An instance of ``torch.nn.<cfg.name>`` constructed with ``**cfg.args``.

    Raises:
        AttributeError: if ``cfg.name`` is not an attribute of ``torch.nn``.
    """
    # Parse with LossConfig (not OptConfig) so the loss default 'BCELoss' applies.
    cfg: LossConfig = parse_structure(LossConfig, cfg)
    loss_cls = _resolve(torch.nn, cfg.name, 'loss')
    return loss_cls(**cfg.args)


def parse_optimizer(cfg: Dict, model: nn.Module) -> torch.optim.Optimizer:
    """Build an optimizer over all of ``model``'s parameters.

    Args:
        cfg: Raw config, parsed into an ``OptConfig`` via ``parse_structure``.
        model: Module whose ``parameters()`` are handed to the optimizer.

    Returns:
        An instance of ``torch.optim.<cfg.name>`` constructed with
        ``model.parameters()`` and ``**cfg.args``.

    Raises:
        AttributeError: if ``cfg.name`` is not an attribute of ``torch.optim``.
    """
    cfg: OptConfig = parse_structure(OptConfig, cfg)
    optim_cls = _resolve(torch.optim, cfg.name, 'optimizer')
    return optim_cls(model.parameters(), **cfg.args)


def parse_scheduler(cfg: Union[Dict, SchedulerConfig], optimizer: torch.optim.Optimizer):
    """Build a learning-rate scheduler bound to ``optimizer``.

    Args:
        cfg: Either a raw config dict, which is parsed into a
            ``SchedulerConfig`` via ``parse_structure``, or an already-parsed
            object exposing ``name`` and ``args`` attributes. The second form
            is what the original implementation required.
        optimizer: Optimizer the scheduler will drive.

    Returns:
        An instance of ``torch.optim.lr_scheduler.<cfg.name>`` constructed with
        ``optimizer`` and ``**cfg.args``.

    Raises:
        AttributeError: if ``cfg.name`` is not an attribute of
            ``torch.optim.lr_scheduler``.
    """
    # Previously a plain dict crashed on ``cfg.name``. Parse dicts the same way
    # the sibling parsers do, and pass attribute-style configs through unchanged.
    if isinstance(cfg, dict):
        cfg = parse_structure(SchedulerConfig, cfg)
    scheduler_cls = _resolve(torch.optim.lr_scheduler, cfg.name, 'lr scheduler')
    return scheduler_cls(optimizer, **cfg.args)