File size: 930 Bytes
148d42e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from dataclasses import dataclass, field
from torch import nn
from typing import *
from utils import parse_structure

import torch.nn.functional as F
import torch

@dataclass
class OptConfig:
    """Optimizer selection: which ``torch.optim`` class to build and how.

    Attributes:
        name: class name looked up on ``torch.optim`` (default ``'Adam'``).
        args: keyword arguments forwarded to that class's constructor after
            the model parameters (default: none).
    """

    # Optimizer class name, resolved with getattr(torch.optim, name).
    name: str = 'Adam'
    # Extra constructor kwargs; default_factory gives every instance its own
    # dict instead of sharing one mutable default.
    args: Dict = field(default_factory=dict)

@dataclass
class LossConfig:
    """Loss selection: which ``torch.nn`` loss class to build and how.

    Attributes:
        name: class name, presumably looked up on ``torch.nn``
            (default ``'BCELoss'``).
        args: keyword arguments for that class's constructor (default: none).
    """

    # Loss class name; the default refers to torch.nn.BCELoss.
    name: str = 'BCELoss'
    # Extra constructor kwargs; a fresh dict per instance via default_factory.
    args: Dict = field(default_factory=dict)

def parse_loss(cfg: Dict)->nn.Module:
    """Build a loss module from a raw configuration.

    The configuration is parsed into a ``LossConfig``. Its ``name`` selects a
    class from ``torch.nn`` and its ``args`` are passed to that class as
    keyword arguments.

    Args:
        cfg: raw loss configuration. Presumably a mapping with optional
            ``name`` / ``args`` entries; the exact parsing rules (defaults,
            validation) live in ``utils.parse_structure`` — confirm there.

    Returns:
        The instantiated ``torch.nn`` loss module.

    Raises:
        AttributeError: if ``name`` is not an attribute of ``torch.nn``.
    """
    # Parse against LossConfig, not OptConfig: with OptConfig a missing loss
    # name would fall back to the optimizer default 'Adam', which is not a
    # torch.nn class, instead of the loss default 'BCELoss'.
    loss_cfg: LossConfig = parse_structure(LossConfig, cfg)
    loss_cls = getattr(torch.nn, loss_cfg.name)
    return loss_cls(**loss_cfg.args)

def parse_optimizer(cfg: Dict, model:nn.Module)->torch.optim.Optimizer:
    """Build an optimizer over all parameters of ``model``.

    ``cfg`` is parsed into an ``OptConfig``. Its ``name`` picks a class from
    ``torch.optim``, which is constructed with ``model.parameters()`` followed
    by the ``args`` keyword arguments.

    Args:
        cfg: raw optimizer configuration, interpreted by
            ``utils.parse_structure``.
        model: module whose parameters the optimizer will update.

    Returns:
        The instantiated ``torch.optim`` optimizer.
    """
    opt_cfg: OptConfig = parse_structure(OptConfig, cfg)
    trainable = model.parameters()
    optimizer_cls = getattr(torch.optim, opt_cfg.name)
    return optimizer_cls(trainable, **opt_cfg.args)

def parse_scheduler(cfg: Dict, optimizer: torch.optim.Optimizer):
    """Build a learning-rate scheduler for ``optimizer`` from ``cfg``.

    ``cfg`` may be a mapping (e.g. a plain ``dict``) with a required ``name``
    key and an optional ``args`` key. It may also be any object exposing
    ``name`` and, optionally, ``args`` attributes. ``name`` selects a class
    from ``torch.optim.lr_scheduler``. ``args`` (default: no extra arguments)
    are passed as keyword arguments after the optimizer.

    Args:
        cfg: scheduler configuration.
        optimizer: the optimizer whose learning rate will be scheduled.

    Returns:
        The instantiated scheduler.

    Raises:
        KeyError: if ``cfg`` is a mapping without a ``name`` key.
        AttributeError: if a non-mapping ``cfg`` lacks ``name``, or ``name``
            is not an attribute of ``torch.optim.lr_scheduler``.
    """
    from collections.abc import Mapping

    # The signature advertises a Dict, but attribute access (cfg.name) fails
    # on plain dicts; support both mapping- and attribute-style configs.
    if isinstance(cfg, Mapping):
        name = cfg['name']
        args = cfg.get('args')
    else:
        name = cfg.name
        args = getattr(cfg, 'args', None)
    scheduler_cls = getattr(torch.optim.lr_scheduler, name)
    return scheduler_cls(optimizer, **(args or {}))