|
|
|
|
|
|
|
|
import wandb |
|
|
import torch |
|
|
from torchvision.utils import make_grid |
|
|
import torch.distributed as dist |
|
|
from PIL import Image |
|
|
import os |
|
|
import argparse |
|
|
import hashlib |
|
|
import math |
|
|
|
|
|
|
|
|
def is_main_process():
    """Return True if this process should perform rank-0-only work (e.g. logging).

    When torch.distributed is unavailable or no default process group has been
    initialized, the run is single-process. That process is treated as the
    main one instead of letting ``dist.get_rank()`` raise.
    """
    # dist.get_rank() raises if the default process group was never
    # initialized, which would break logging in plain single-GPU/CPU runs.
    if not dist.is_available() or not dist.is_initialized():
        return True
    return dist.get_rank() == 0
|
|
|
|
|
def namespace_to_dict(namespace):
    """Convert an argparse.Namespace (possibly nested) into a plain dict.

    Attributes are read with ``vars()``. Any attribute value that is itself an
    ``argparse.Namespace`` is converted recursively. All other values are
    copied through unchanged, and attribute order is preserved.
    """
    converted = {}
    for attr_name, attr_value in vars(namespace).items():
        if isinstance(attr_value, argparse.Namespace):
            attr_value = namespace_to_dict(attr_value)
        converted[attr_name] = attr_value
    return converted
|
|
|
|
|
|
|
|
def generate_run_id(exp_name):
    """Derive a deterministic run id (as a decimal string) from an experiment name.

    The UTF-8 encoded name is hashed with SHA-256. The digest, read as a
    big-endian integer, is reduced modulo 10**8. The same name therefore
    always yields the same id string of at most 8 digits.
    """
    raw_digest = hashlib.sha256(exp_name.encode("utf-8")).digest()
    # Big-endian interpretation of the raw bytes equals int(hexdigest, 16).
    as_int = int.from_bytes(raw_digest, "big")
    return str(as_int % 100_000_000)
|
|
|
|
|
|
|
|
def initialize(args, entity, exp_name, project_name):
    """Log in to Weights & Biases and start (or resume) a run.

    Args:
        args: argparse.Namespace of run settings. It is stored as the run config
            after conversion to a (nested) plain dict.
        entity: W&B entity passed to ``wandb.init``.
        exp_name: Run name. It also seeds the deterministic run id, so the
            same name always maps to the same id.
        project_name: W&B project passed to ``wandb.init``.

    Raises:
        KeyError: if the ``WANDB_KEY`` environment variable is not set.
    """
    run_config = namespace_to_dict(args)
    api_key = os.environ["WANDB_KEY"]
    wandb.login(key=api_key)

    init_kwargs = dict(
        entity=entity,
        project=project_name,
        name=exp_name,
        config=run_config,
        # Deterministic id + resume="allow": relaunching with the same
        # exp_name reuses the same run id instead of creating a new run.
        id=generate_run_id(exp_name),
        resume="allow",
    )
    wandb.init(**init_kwargs)
|
|
|
|
|
|
|
|
def log(stats, step=None):
    """Log a mapping of metrics to W&B from the main process only.

    A shallow copy of ``stats`` is passed to ``wandb.log`` together with
    ``step``. Non-main processes return without logging.
    """
    if not is_main_process():
        return
    # Copy via items() so the caller's mapping object is not handed to wandb.
    wandb.log(dict(stats.items()), step=step)
|
|
|
|
|
|
|
|
def log_image(name, sample, step=None):
    """Log a batch of images to W&B as one tiled grid, from the main process only.

    ``sample`` is tiled by ``array2grid`` and wrapped in ``wandb.Image`` under
    the key ``name``.
    """
    if not is_main_process():
        return
    grid_array = array2grid(sample)
    # NOTE(review): unlike `log`, the step is recorded as a "train_step" metric
    # rather than passed as wandb.log(step=...); with step=None, None is logged.
    payload = {f"{name}": wandb.Image(grid_array), "train_step": step}
    wandb.log(payload)
|
|
|
|
|
|
|
|
def array2grid(x):
    """Tile a batch of images into one grid and return it as an HxWxC uint8 numpy array.

    The grid has roughly sqrt(batch) images per row. Pixel values are
    normalized from the range (-1, 1) to [0, 255] (values outside it are
    clamped by make_grid) and rounded to uint8. The result is moved to the CPU.
    """
    # assumes x is a (B, C, H, W) image batch in [-1, 1] — TODO confirm with callers
    per_row = round(math.sqrt(x.shape[0]))
    grid = make_grid(x, nrow=per_row, normalize=True, value_range=(-1, 1))
    # Scale [0, 1] -> [0, 255]; +0.5 makes the later uint8 truncation round to nearest.
    scaled = grid.mul(255).add_(0.5).clamp_(0, 255)
    hwc = scaled.permute(1, 2, 0).to("cpu", torch.uint8)
    return hwc.numpy()