|
|
from dataclasses import dataclass, field |
|
|
from torch import nn |
|
|
from typing import * |
|
|
from utils import parse_structure |
|
|
|
|
|
import torch.nn.functional as F |
|
|
import torch |
|
|
|
|
|
@dataclass
class OptConfig:
    """Optimizer description: a class name plus constructor keyword arguments.

    ``parse_optimizer`` resolves ``name`` on ``torch.optim`` and forwards
    ``args`` as keyword arguments after the model parameters.
    """

    # Attribute name of the optimizer class to look up.
    name: str = 'Adam'

    # Keyword arguments for the optimizer constructor; ``default_factory``
    # gives every instance its own empty dict instead of a shared one.
    args: Dict = field(default_factory=dict)
|
|
|
|
|
@dataclass
class LossConfig:
    """Loss description: a class name plus constructor keyword arguments.

    NOTE(review): presumably the schema ``parse_loss`` is meant to validate
    against, with ``name`` resolved on ``torch.nn`` - confirm against callers.
    """

    # Attribute name of the loss class to look up.
    name: str = 'BCELoss'

    # Keyword arguments for the loss constructor; ``default_factory`` gives
    # every instance its own empty dict instead of a shared one.
    args: Dict = field(default_factory=dict)
|
|
|
|
|
def parse_loss(cfg: Dict) -> nn.Module:
    """Build a ``torch.nn`` loss module from a configuration mapping.

    Args:
        cfg: Raw loss configuration. It is passed through ``parse_structure``
            with ``LossConfig`` as the schema (presumably validating it and
            filling in defaults - confirm in ``utils.parse_structure``).

    Returns:
        An instance of the ``torch.nn`` class named by ``name``, constructed
        with ``args`` as keyword arguments.

    Raises:
        AttributeError: If ``torch.nn`` has no attribute with that name.
    """
    # Use the loss schema here. The optimizer schema would default the name
    # to 'Adam', which does not exist in ``torch.nn``.
    loss_cfg: LossConfig = parse_structure(LossConfig, cfg)
    loss_cls = getattr(torch.nn, loss_cfg.name)
    return loss_cls(**loss_cfg.args)
|
|
|
|
|
def parse_optimizer(cfg: Dict, model: nn.Module) -> torch.optim.Optimizer:
    """Build a ``torch.optim`` optimizer over all parameters of ``model``.

    Args:
        cfg: Raw optimizer configuration, passed through ``parse_structure``
            with ``OptConfig`` as the schema.
        model: Module whose ``parameters()`` the optimizer will update.

    Returns:
        An instance of the ``torch.optim`` class named by ``name``, constructed
        with the model parameters followed by ``args`` as keyword arguments.
    """
    settings: OptConfig = parse_structure(OptConfig, cfg)
    trainable = model.parameters()
    optimizer_cls = getattr(torch.optim, settings.name)
    return optimizer_cls(trainable, **settings.args)
|
|
|
|
|
def parse_scheduler(cfg: Dict, optimizer: torch.optim.Optimizer):
    """Build a ``torch.optim.lr_scheduler`` scheduler for ``optimizer``.

    Args:
        cfg: Scheduler configuration exposing ``name`` (scheduler class name)
            and ``args`` (constructor keyword arguments). Both attribute-style
            objects and plain mappings such as ``dict`` are accepted.
        optimizer: Optimizer the scheduler will drive.

    Returns:
        An instance of the ``torch.optim.lr_scheduler`` class named by
        ``name``, constructed with ``optimizer`` followed by ``args``.

    Raises:
        AttributeError: If ``cfg`` is not a mapping and lacks the field, or if
            ``torch.optim.lr_scheduler`` has no class with that name.
        KeyError: If ``cfg`` is a mapping that lacks ``name`` or ``args``.
    """
    from collections.abc import Mapping

    def _field(key: str):
        # Attribute access comes first so configs that already worked behave
        # exactly as before. Mappings without attribute access (a plain dict,
        # as the annotation advertises) fall back to key lookup.
        try:
            return getattr(cfg, key)
        except AttributeError:
            if isinstance(cfg, Mapping):
                return cfg[key]
            raise

    scheduler_cls = getattr(torch.optim.lr_scheduler, _field("name"))
    return scheduler_cls(optimizer, **_field("args"))