# Source snapshot: revision 148d42e, 930 bytes.
from dataclasses import dataclass, field
from torch import nn
from typing import *
from utils import parse_structure
import torch.nn.functional as F
import torch
@dataclass
class OptConfig:
    """Optimizer settings: a ``torch.optim`` class name plus its keyword arguments."""

    # Optimizer class name looked up on ``torch.optim`` by parse_optimizer.
    name: str = field(default='Adam')
    # Extra keyword arguments for the optimizer constructor; default_factory
    # gives every instance its own fresh dict instead of a shared one.
    args: Dict = field(default_factory=dict)
@dataclass
class LossConfig:
    """Loss settings: a ``torch.nn`` loss class name plus its keyword arguments."""

    # Loss class name; 'BCELoss' unless overridden.
    name: str = field(default='BCELoss')
    # Constructor keyword arguments; each instance gets its own empty dict.
    args: Dict = field(default_factory=dict)
def parse_loss(cfg: Dict) -> nn.Module:
    """Instantiate a ``torch.nn`` loss module described by ``cfg``.

    ``cfg`` is converted into a ``LossConfig`` via ``parse_structure``. The
    resulting ``name`` selects a class on ``torch.nn``, and ``args`` are
    passed to it as keyword arguments.

    Args:
        cfg: Raw loss configuration. Presumably it has ``name``/``args`` keys
            matching ``LossConfig``. TODO confirm against ``parse_structure``.

    Returns:
        The constructed loss module.

    Raises:
        AttributeError: if ``name`` is not an attribute of ``torch.nn``.
    """
    # Parse against LossConfig, not OptConfig. The old code used OptConfig.
    # If parse_structure fills missing fields from dataclass defaults
    # (assumed), a config without ``name`` then became 'Adam', which is not
    # a torch.nn loss.
    loss_cfg: LossConfig = parse_structure(LossConfig, cfg)
    loss_cls = getattr(torch.nn, loss_cfg.name)
    return loss_cls(**loss_cfg.args)
def parse_optimizer(cfg: Dict, model: nn.Module) -> torch.optim.Optimizer:
    """Instantiate a ``torch.optim`` optimizer over ``model.parameters()``.

    ``cfg`` is converted into an ``OptConfig`` by ``parse_structure``. The
    resulting ``name`` picks the optimizer class, and ``args`` become its
    keyword arguments.

    Raises:
        AttributeError: if ``name`` is not an attribute of ``torch.optim``.
    """
    opt_cfg: OptConfig = parse_structure(OptConfig, cfg)
    param_iter = model.parameters()
    optimizer_cls = getattr(torch.optim, opt_cfg.name)
    return optimizer_cls(param_iter, **opt_cfg.args)
def parse_scheduler(cfg: Dict, optimizer: torch.optim.Optimizer):
    """Instantiate a ``torch.optim.lr_scheduler`` class for ``optimizer``.

    Args:
        cfg: Either a mapping with a ``'name'`` key and an optional ``'args'``
            mapping, or an object exposing ``.name`` / ``.args`` attributes.
            The attribute form is the only one the previous version supported.
        optimizer: The optimizer passed as the scheduler's first argument.

    Returns:
        The constructed scheduler instance.

    Raises:
        KeyError: if ``cfg`` is a mapping without a ``'name'`` key.
        AttributeError: if an attribute-style ``cfg`` has no ``name``, or if
            ``name`` is not an attribute of ``torch.optim.lr_scheduler``.
    """
    from collections.abc import Mapping

    # The signature says Dict, but the old body only used attribute access,
    # so a plain dict raised AttributeError. Support both forms.
    if isinstance(cfg, Mapping):
        name = cfg['name']
        args = cfg.get('args')
    else:
        name = cfg.name
        args = getattr(cfg, 'args', None)

    scheduler_cls = getattr(torch.optim.lr_scheduler, name)
    # Missing/None args means "no extra keyword arguments".
    return scheduler_cls(optimizer, **(args or {}))