# leharris3's picture
# Minimal HF Space deployment with gradio 5.x fix
# 0917e8d
import json
import torch
import os
class Config:
    """Hyperparameters, network layout and run paths for a single run.

    The ``d*`` / ``g*`` attribute pairs hold per-layer settings for two
    networks (presumably the discriminator/critic and the generator, given
    ``critic_iters`` and ``lrg`` -- confirm against the model code):
    ``k`` = kernel size, ``s`` = stride, ``p`` = padding, ``f`` = channel
    counts. ``laysd`` is the layer count for the ``d*`` lists and ``lays``
    for the ``g*`` lists.

    Args:
        tag: Run name; the run directory is ``<root>/runs/<tag>``.
        root: Prefix for the run directory (default ``""``, i.e. relative
            to the current working directory).
    """

    def __init__(self, tag, root=""):
        self.tag = tag
        self.cli = False
        self.path = os.path.join(root, f"runs/{self.tag}")
        self.cm = "gray"
        self.data_path = ""
        self.mask_coords = []
        self.net_type = "conv-resize"
        self.image_type = "n-phase"
        self.l = 80
        self.n_phases = 2
        # Training hyperparams
        self.batch_size = 4
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.max_iters = 400e3
        self.timeout = 1e12
        self.lrg = 0.0005
        self.lr = 0.0005
        self.Lambda = 10
        self.critic_iters = 10
        self.pw_coeff = 1
        self.ngpu = torch.cuda.device_count()
        self.device_name = "cuda:0" if self.ngpu > 0 else "cpu"
        self.conv_resize = True
        self.nz = 100
        # Architecture: number of layers in each network
        self.lays = 4
        self.laysd = 5
        # Per-layer kernel sizes, strides, channel counts and paddings.
        self.dk, self.gk = [4] * self.laysd, [4] * self.lays
        self.ds, self.gs = [2] * self.laysd, [2] * self.lays
        # Channel lists carry one more entry than there are layers; the
        # first df entry and last gf entry follow n_phases (see update_params).
        self.df = [self.n_phases, 64, 128, 256, 512, 1]
        self.gf = [self.nz, 512, 256, 128, self.n_phases]
        self.dp, self.gp = [1] * self.laysd, [2] * self.lays
        # Last two g-layers use conv-resize: (kernel, stride, padding) = (3, 1, 0)
        self.gk[-2:], self.gs[-2:], self.gp[-2:] = [3, 3], [1, 1], [0, 0]

    def update_params(self):
        """Copy ``n_phases`` into the first ``df`` and last ``gf`` channel entry.

        Call this after changing ``n_phases`` so the channel lists agree with it.
        """
        self.df[0] = self.n_phases
        self.gf[-1] = self.n_phases

    def save(self):
        """Persist the config to ``<path>/config.json``.

        Currently a no-op: writing the file is disabled. ``load`` expects a
        flat JSON object mapping attribute names to values.
        """

    def load(self):
        """Overwrite attributes from ``<path>/config.json``.

        Every key of the JSON object is set on the instance via ``setattr``.
        ``update_params`` is not called, so the channel lists are not
        re-derived from a loaded ``n_phases``.

        Raises:
            FileNotFoundError: if the config file does not exist.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        config_file = os.path.join(self.path, "config.json")
        # Explicit encoding: don't depend on the platform's locale default.
        with open(config_file, "r", encoding="utf-8") as f:
            for key, value in json.load(f).items():
                setattr(self, key, value)

    def get_net_params(self):
        """Return ``(dk, ds, df, dp, gk, gs, gf, gp)`` layer-parameter lists."""
        return self.dk, self.ds, self.df, self.dp, self.gk, self.gs, self.gf, self.gp

    def get_train_params(self):
        """Return ``(l, batch_size, beta1, beta2, lrg, lr, Lambda, critic_iters, nz)``."""
        return (
            self.l,
            self.batch_size,
            self.beta1,
            self.beta2,
            self.lrg,
            self.lr,
            self.Lambda,
            self.critic_iters,
            self.nz,
        )
class ConfigPoly(Config):
    """``Config`` extended with optimisation settings.

    ``get_train_params`` is inherited unchanged from ``Config``.

    Args:
        tag: Run name (see ``Config``).
        root: Prefix for the run directory (default ``""``, matching ``Config``).
        cli: Keyword-only. When True, use the longer optimisation schedule
            (``opt_iters = 10000`` instead of ``1000``).
    """

    def __init__(self, tag, root="", *, cli=False):
        super().__init__(tag, root=root)
        # Config.__init__ unconditionally sets cli = False, so reading
        # self.cli here without this assignment made the CLI branch dead.
        self.cli = cli
        self.frames = 100
        # optimisation parameters (fixed at construction; setting .cli later
        # does not update opt_iters)
        self.opt_iters = 10000 if self.cli else 1000
        self.opt_lr = 0.001
        self.opt_kl_coeff = 0.00001