Spaces:
Sleeping
Sleeping
| import json | |
| import torch | |
| import os | |
| class Config: | |
| """Config class""" | |
| def __init__(self, tag, root=""): | |
| self.tag = tag | |
| self.cli = False | |
| # self.wandb = True | |
| self.path = os.path.join(root, f"runs/{self.tag}") | |
| self.cm = "gray" | |
| self.data_path = "" | |
| self.mask_coords = [] | |
| self.net_type = "conv-resize" | |
| self.image_type = "n-phase" | |
| self.l = 80 | |
| self.n_phases = 2 | |
| # Training hyperparams | |
| self.batch_size = 4 | |
| self.beta1 = 0.9 | |
| self.beta2 = 0.999 | |
| self.max_iters = 400e3 | |
| self.timeout = 1e12 | |
| self.lrg = 0.0005 | |
| self.lr = 0.0005 | |
| self.Lambda = 10 | |
| self.critic_iters = 10 | |
| self.pw_coeff = 1 | |
| self.ngpu = torch.cuda.device_count() | |
| if self.ngpu > 0: | |
| self.device_name = "cuda:0" | |
| else: | |
| self.device_name = "cpu" | |
| self.conv_resize = True | |
| self.nz = 100 | |
| # Architecture | |
| self.lays = 4 | |
| self.laysd = 5 | |
| # kernel sizes | |
| self.dk, self.gk = [4] * self.laysd, [4] * self.lays | |
| self.ds, self.gs = [2] * self.laysd, [2] * self.lays | |
| self.df, self.gf = [self.n_phases, 64, 128, 256, 512, 1], [ | |
| self.nz, | |
| 512, | |
| 256, | |
| 128, | |
| self.n_phases, | |
| ] | |
| self.dp, self.gp = [1] * self.laysd, [2] * self.lays | |
| # Last two layers conv resize (3,1,0) | |
| self.gk[-2:], self.gs[-2:], self.gp[-2:] = [3, 3], [1, 1], [0, 0] | |
| def update_params(self): | |
| self.df[0] = self.n_phases | |
| self.gf[-1] = self.n_phases | |
| def save(self): | |
| # j = {} | |
| # for k, v in self.__dict__.items(): | |
| # j[k] = v | |
| # with open(f"{self.path}/config.json", "w") as f: | |
| # json.dump(j, f) | |
| pass | |
| def load(self): | |
| with open(f"{self.path}/config.json", "r") as f: | |
| j = json.load(f) | |
| for k, v in j.items(): | |
| setattr(self, k, v) | |
| def get_net_params(self): | |
| return self.dk, self.ds, self.df, self.dp, self.gk, self.gs, self.gf, self.gp | |
| def get_train_params(self): | |
| return ( | |
| self.l, | |
| self.batch_size, | |
| self.beta1, | |
| self.beta2, | |
| self.lrg, | |
| self.lr, | |
| self.Lambda, | |
| self.critic_iters, | |
| self.nz, | |
| ) | |
| class ConfigPoly(Config): | |
| def __init__(self, tag, root): | |
| super(ConfigPoly, self).__init__(tag, root=root) | |
| self.frames = 100 | |
| # optimisation parameters | |
| if self.cli: | |
| self.opt_iters = 10000 | |
| else: | |
| self.opt_iters = 1000 | |
| self.opt_lr = 0.001 | |
| # if self.image_type=='colour': | |
| self.opt_kl_coeff = 0.00001 | |
| def get_train_params(self): | |
| return ( | |
| self.l, | |
| self.batch_size, | |
| self.beta1, | |
| self.beta2, | |
| self.lrg, | |
| self.lr, | |
| self.Lambda, | |
| self.critic_iters, | |
| self.nz, | |
| ) | |