Spaces:
Paused
Paused
| import wandb | |
| import torch | |
| from torchvision.utils import make_grid | |
| import torch.distributed as dist | |
| from PIL import Image | |
| import os | |
| import argparse | |
| import hashlib | |
| import math | |
def is_main_process():
    """Return True when this process is the one that should emit logs.

    In a distributed run that is the process whose rank is 0. When
    ``torch.distributed`` is unavailable in this build, or no default
    process group has been initialized (e.g. a single-process or debug
    run), there is only one process, so it is treated as the main one
    instead of letting ``dist.get_rank()`` raise.

    Returns:
        bool: True for rank 0 or for a non-distributed process.
    """
    # get_rank() requires an initialized default process group; check first.
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0
def namespace_to_dict(namespace):
    """Convert an ``argparse.Namespace`` into a plain dict, recursively.

    Attribute values that are themselves ``argparse.Namespace`` instances
    are converted the same way; every other value is stored unchanged.

    Args:
        namespace: The namespace (or any object supporting ``vars()``)
            to convert.

    Returns:
        dict: Attribute names mapped to their (possibly converted) values.
    """
    converted = {}
    for attr_name, attr_value in vars(namespace).items():
        if isinstance(attr_value, argparse.Namespace):
            # Nested namespaces become nested dicts.
            attr_value = namespace_to_dict(attr_value)
        converted[attr_name] = attr_value
    return converted
def generate_run_id(exp_name):
    """Derive a deterministic numeric run id from an experiment name.

    The name is UTF-8 encoded and hashed with SHA-256; the digest,
    read as an integer, is reduced modulo 10**8 and returned as a
    decimal string. The result is not zero-padded, so it can be
    shorter than eight characters.

    Args:
        exp_name: Experiment name (str).

    Returns:
        str: Decimal string of at most eight digits.
    """
    # https://stackoverflow.com/questions/16008670/how-to-hash-a-string-into-8-digits
    digest = hashlib.sha256(exp_name.encode("utf-8")).digest()
    # Big-endian interpretation matches parsing the hex digest as base 16.
    digest_as_int = int.from_bytes(digest, "big")
    return str(digest_as_int % 100_000_000)
def initialize(args, entity, exp_name, project_name):
    """Log in to Weights & Biases and start the run for ``exp_name``.

    The API key is read from the ``WANDB_KEY`` environment variable
    (a ``KeyError`` is raised if it is not set). The run id is derived
    from ``exp_name`` via ``generate_run_id`` and ``resume="allow"`` is
    passed to ``wandb.init``, so the same experiment name always maps to
    the same run id.

    Args:
        args: ``argparse.Namespace`` with the run configuration; it is
            converted to a dict and passed as the run config.
        entity: Value passed as ``entity`` to ``wandb.init``.
        exp_name: Experiment name; used as run name and to derive the id.
        project_name: Value passed as ``project`` to ``wandb.init``.
    """
    run_config = namespace_to_dict(args)

    api_key = os.environ["WANDB_KEY"]
    wandb.login(key=api_key)

    init_kwargs = {
        "entity": entity,
        "project": project_name,
        "name": exp_name,
        "config": run_config,
        "id": generate_run_id(exp_name),
        "resume": "allow",
    }
    wandb.init(**init_kwargs)
def log(stats, step=None):
    """Send a mapping of metrics to wandb from the main process only.

    Args:
        stats: Mapping of metric names to values; a shallow copy is
            what gets passed to ``wandb.log``.
        step: Optional step forwarded as ``wandb.log``'s ``step`` argument.
    """
    if not is_main_process():
        return
    payload = dict(stats.items())
    wandb.log(payload, step=step)
def log_image(name, sample, step=None):
    """Log a batch of images to wandb as a single grid (main process only).

    ``sample`` is tiled into one image by ``array2grid`` and logged
    under ``name``. Note that ``step`` is logged as an ordinary value
    under the ``"train_step"`` key, not passed as ``wandb.log``'s
    ``step`` argument.

    Args:
        name: Key under which the image is logged.
        sample: Batch of images accepted by ``array2grid``.
        step: Value recorded under ``"train_step"`` (``None`` by default).
    """
    if not is_main_process():
        return
    grid = array2grid(sample)
    payload = {f"{name}": wandb.Image(grid), "train_step": step}
    wandb.log(payload)
def array2grid(x):
    """Tile a batch of images into one uint8 image array.

    The grid has ``round(sqrt(batch_size))`` images per row. Values are
    normalized by ``make_grid`` using the range (-1, 1), then scaled to
    0..255 with +0.5 for round-to-nearest before the uint8 cast.

    Args:
        x: Image batch tensor; ``x.size(0)`` is taken as the number of
            images. Presumably shaped (N, C, H, W) with values in
            [-1, 1] -- confirm against callers.

    Returns:
        numpy.ndarray: uint8 array with the channel axis moved last.
    """
    images_per_row = round(math.sqrt(x.size(0)))
    grid = make_grid(x, nrow=images_per_row, normalize=True, value_range=(-1, 1))

    # mul() allocates a new tensor, so the in-place ops below never touch `grid`.
    pixels = grid.mul(255)
    pixels.add_(0.5)
    pixels.clamp_(0, 255)

    # Move channels last, then convert to a CPU uint8 array.
    channels_last = pixels.permute(1, 2, 0)
    return channels_last.to('cpu', torch.uint8).numpy()