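# Source-page residue, kept as comments so the file parses:
# author: kohido · commit 148d42e ("init") · file size 930 Bytes
# (page controls: raw, history, blame, contribute, delete)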
from dataclasses import dataclass, field
from torch import nn
from typing import *
from utils import parse_structure
import torch.nn.functional as F
import torch
@dataclass
class OptConfig:
    """Name plus constructor keyword arguments for a ``torch.optim`` optimizer.

    ``name`` is looked up as an attribute of ``torch.optim`` (see
    ``parse_optimizer``); ``args`` is splatted into that class's constructor.
    """

    # Optimizer class name inside ``torch.optim``.
    name: str = field(default='Adam')
    # Extra constructor kwargs; the factory gives every instance its own dict.
    args: Dict = field(default_factory=dict)
@dataclass
class LossConfig:
    """Name plus constructor keyword arguments describing a loss module.

    The default ``'BCELoss'`` is a class in ``torch.nn``; ``args`` is meant to
    be passed as keyword arguments to that class.
    """

    # Loss class name, defaulting to 'BCELoss'.
    name: str = field(default='BCELoss')
    # Extra constructor kwargs; the factory gives every instance its own dict.
    args: Dict = field(default_factory=dict)
def parse_loss(cfg: Dict) -> nn.Module:
    """Build a ``torch.nn`` loss module from a raw config.

    Args:
        cfg: Raw loss config, converted into a ``LossConfig`` by
            ``parse_structure`` (presumably a dict with ``name``/``args``
            keys, with missing keys filled from ``LossConfig`` defaults --
            confirm against ``utils.parse_structure``).

    Returns:
        ``getattr(torch.nn, name)(**args)`` for the parsed config.

    Raises:
        AttributeError: if the parsed ``name`` is not an attribute of
            ``torch.nn``.
    """
    # Parse against LossConfig, not OptConfig: with OptConfig a config that
    # omits ``name`` would presumably fall back to OptConfig's default 'Adam',
    # which is not an attribute of torch.nn.
    loss_cfg: LossConfig = parse_structure(LossConfig, cfg)
    loss_cls = getattr(torch.nn, loss_cfg.name)
    return loss_cls(**loss_cfg.args)
def parse_optimizer(cfg: Dict, model: nn.Module) -> torch.optim.Optimizer:
    """Instantiate a ``torch.optim`` optimizer over all of ``model``'s parameters.

    Args:
        cfg: Raw optimizer config, converted into an ``OptConfig`` by
            ``parse_structure`` (assumed to hold ``name``/``args`` -- confirm
            against ``utils.parse_structure``).
        model: Module whose ``parameters()`` are handed to the optimizer.

    Returns:
        ``getattr(torch.optim, name)(model.parameters(), **args)``.
    """
    opt_cfg: OptConfig = parse_structure(OptConfig, cfg)
    optimizer_cls = getattr(torch.optim, opt_cfg.name)
    return optimizer_cls(model.parameters(), **opt_cfg.args)
def parse_scheduler(cfg: Dict, optimizer: torch.optim.Optimizer):
    """Build a ``torch.optim.lr_scheduler`` scheduler bound to ``optimizer``.

    Args:
        cfg: Scheduler description. Either a mapping with a ``'name'`` key and
            an optional ``'args'`` mapping (matching the ``Dict`` annotation
            and the dict configs the sibling ``parse_*`` helpers accept), or
            any object exposing ``.name`` and ``.args`` attributes.
        optimizer: Optimizer passed as the scheduler's first positional
            argument.

    Returns:
        ``getattr(torch.optim.lr_scheduler, name)(optimizer, **args)``.

    Raises:
        KeyError: if a mapping ``cfg`` has no ``'name'`` key.
        AttributeError: if ``name`` is not an attribute of
            ``torch.optim.lr_scheduler``, or a non-mapping ``cfg`` lacks
            ``.name``/``.args``.
    """
    from collections.abc import Mapping

    # A plain dict has no ``.name`` attribute, so reading cfg.name directly
    # made the annotated input type unusable; read mappings by key instead.
    if isinstance(cfg, Mapping):
        name = cfg['name']
        args = cfg.get('args') or {}
    else:
        # Attribute-style configs keep the original behavior.
        name, args = cfg.name, cfg.args
    scheduler_cls = getattr(torch.optim.lr_scheduler, name)
    return scheduler_cls(optimizer, **args)