Spaces:
Sleeping
Sleeping
| import torch | |
| import collections | |
| import logging | |
| import omegaconf | |
| import wandb | |
| import datetime | |
| import glob | |
| import os | |
| import json | |
| from PIL import Image | |
class BaseTimer:
    """Times GPU work, in seconds, from construction until ``stop()``.

    Uses CUDA events, so the measured interval reflects device-side
    execution; the clock starts as soon as the object is created.
    """

    def __init__(self):
        # Both events need timing enabled for elapsed_time() to work later.
        self.start = torch.cuda.Event(enable_timing=True)
        self.end = torch.cuda.Event(enable_timing=True)
        # Start timing immediately.
        self.start.record()

    def stop(self):
        """Record the end event and return the elapsed time in seconds."""
        self.end.record()
        # elapsed_time() is only valid after both events have completed.
        torch.cuda.synchronize()
        milliseconds = self.start.elapsed_time(self.end)
        return milliseconds / 1000
class Timer:
    """Context manager that measures GPU time of its ``with`` body.

    The duration (seconds) is stored on ``self.duration``; if ``info``
    (a mapping) was supplied, it is also written under the key
    ``f"duration/{log_event}"``.
    """

    def __init__(self, info=None, log_event=None):
        self.info = info
        self.log_event = log_event

    def __enter__(self):
        # Create and start the CUDA timing events on entry.
        self.start = torch.cuda.Event(enable_timing=True)
        self.end = torch.cuda.Event(enable_timing=True)
        self.start.record()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end.record()
        # Wait for the end event before querying elapsed time.
        torch.cuda.synchronize()
        milliseconds = self.start.elapsed_time(self.end)
        self.duration = milliseconds / 1000
        if self.info:
            key = f"duration/{self.log_event}"
            self.info[key] = self.duration
| class _StreamingMean: | |
| def __init__(self, val=None, counts=None): | |
| if val is None: | |
| self.mean = 0.0 | |
| self.counts = 0 | |
| else: | |
| if isinstance(val, torch.Tensor): | |
| val = val.data.cpu().numpy() | |
| self.mean = val | |
| if counts is not None: | |
| self.counts = counts | |
| else: | |
| self.counts = 1 | |
| def update(self, mean, counts=1): | |
| if isinstance(mean, torch.Tensor): | |
| mean = mean.data.cpu().numpy() | |
| elif isinstance(mean, _StreamingMean): | |
| mean, counts = mean.mean, mean.counts * counts | |
| assert counts >= 0 | |
| if counts == 0: | |
| return | |
| total = self.counts + counts | |
| self.mean = self.counts / total * self.mean + counts / total * mean | |
| self.counts = total | |
| def __add__(self, other): | |
| new = self.__class__(self.mean, self.counts) | |
| if isinstance(other, _StreamingMean): | |
| if other.counts == 0: | |
| return new | |
| else: | |
| new.update(other.mean, other.counts) | |
| else: | |
| new.update(other) | |
| return new | |
class StreamingMeans(collections.defaultdict):
    """Dict of named ``_StreamingMean`` accumulators, keyed by metric name."""

    def __init__(self):
        # Unknown keys auto-create an empty accumulator.
        super().__init__(_StreamingMean)

    def __setitem__(self, key, value):
        # Wrap raw values so every stored entry is a _StreamingMean.
        if not isinstance(value, _StreamingMean):
            value = _StreamingMean(value)
        super().__setitem__(key, value)

    def update(self, *args, **kwargs):
        """Merge every (key, value) pair into the matching accumulator."""
        for key, value in dict(*args, **kwargs).items():
            self[key].update(value)

    def to_dict(self, prefix=""):
        """Return plain mean values, with keys optionally prefixed."""
        return {prefix + key: acc.mean for key, acc in self.items()}

    def to_str(self):
        """Human-readable ``key = value`` summary of all means."""
        parts = (f"{key} = {value:.3f}" for key, value in self.to_dict().items())
        return ", ".join(parts)
class ConsoleLogger:
    """Console logger that prints training metrics as aligned columns."""

    def __init__(self, name):
        """Create a console-only logger named ``name``.

        Pre-existing handlers are dropped so messages are not duplicated
        when a logger with the same name is re-created.
        """
        self.logger = logging.getLogger(name)
        self.logger.handlers = []
        self.logger.setLevel(logging.INFO)
        log_formatter = logging.Formatter(
            "%(asctime)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
        )
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(log_formatter)
        self.logger.addHandler(console_handler)
        self.logger.propagate = False  # avoid double output via the root logger

    def format_info(self, info):
        """Format ``info`` (a mapping with ``to_dict()``) as aligned columns.

        Keys must look like ``"group/metric"``; metrics are grouped by the
        part before the first ``/``. Falsy ``info`` is returned as ``str(info)``.

        BUG FIX: this method previously lacked the ``self`` parameter, so
        ``self.format_info(...)`` raised TypeError when called from
        ``log_iter`` / ``log_epoch``.
        """
        if not info:
            return str(info)
        log_groups = collections.defaultdict(dict)
        for k, v in info.to_dict().items():
            prefix, suffix = k.split("/", 1)
            log_groups[prefix][suffix] = f"{v:.3f}" if isinstance(v, float) else str(v)
        formatted_info = ""
        # Column widths are computed across all groups for vertical alignment.
        max_group_size = len(max(log_groups, key=len)) + 2
        max_k_size = max([len(max(g, key=len)) for g in log_groups.values()]) + 1
        max_v_size = (
            max([len(max(g.values(), key=len)) for g in log_groups.values()]) + 1
        )
        for group, group_info in log_groups.items():
            group_str = [
                f"{k:<{max_k_size}}={v:>{max_v_size}}" for k, v in group_info.items()
            ]
            max_g_size = len(max(group_str, key=len)) + 2
            group_str = "".join([f"{g:>{max_g_size}}" for g in group_str])
            formatted_info += f"\n{group + ':':<{max_group_size}}{group_str}"
        return formatted_info

    def log_iter(self, epoch_num, iter_num, num_iters, iter_info, event="epoch"):
        """Log per-iteration metrics for the given epoch/iteration."""
        output_info = f"{event.upper()} {epoch_num}, ITER {iter_num}/{num_iters}:"
        output_info += self.format_info(iter_info)
        self.logger.info(output_info)

    def log_epoch(self, epoch_info, epoch_num):
        """Log end-of-epoch metrics."""
        output_info = f"EPOCH {epoch_num}:"
        output_info += self.format_info(epoch_info)
        self.logger.info(output_info)
class WandbLogger:
    """Weights & Biases logger: starts a fresh run or resumes a saved one."""

    def __init__(self, config):
        """Log in to W&B (key from ``WANDB_KEY`` env var) and init the run.

        A fresh run (empty ``config.train.resume_path``) also uploads all
        project ``*.py`` sources as a ``project-source`` artifact.
        Resuming reads the run id/project/name/config from the JSON file
        at ``config.train.resume_path``.
        """
        wandb.login(key=os.environ['WANDB_KEY'].strip(), relogin=True)
        if config.train.resume_path == "":
            config_for_logger = omegaconf.OmegaConf.to_container(config)
            self.wandb_args = {
                "id": wandb.util.generate_id(),
                "project": config.exp.wandb_project,
                "name": config.exp.name,
                "config": config_for_logger,
            }
            wandb.init(**self.wandb_args, resume="allow")
            run_dir = wandb.run.dir
            print("run_dir", run_dir)
            # Snapshot the source tree, skipping wandb's own directory.
            code = wandb.Artifact("project-source", type="code")
            for path in glob.glob("**/*.py", recursive=True):
                if not path.startswith("wandb"):
                    if os.path.basename(path) != path:
                        # Nested file: add its whole directory once.
                        code.add_dir(
                            os.path.dirname(path), name=os.path.dirname(path)
                        )
                    else:
                        code.add_file(os.path.basename(path), name=path)
            wandb.run.log_artifact(code)
        else:
            print(f"Resume training from {config.train.resume_path}")
            with open(config.train.resume_path, "r") as f:
                options = json.load(f)
            self.wandb_args = {
                "id": options['id'],
                "project": options['project'],
                "name": options['name'],
                "config": options['config'],
            }
            wandb.init(resume=True, **self.wandb_args)

    def log_epoch(self, iter_info, step):
        """Push averaged metrics (``.mean`` of each entry) for ``step``.

        BUG FIX: ``self`` was missing from the signature, so calling this
        as a bound method raised TypeError.
        """
        wandb.log(
            data={k: v.mean for k, v in iter_info.items()},
            step=step + 1,
            commit=True,
        )

    def log_special_pics(self, pics, captions, paths):
        """Log one ``wandb.Image`` per path, captioned via ``captions[path]``.

        BUG FIX: ``self`` was missing from the signature, so calling this
        as a bound method raised TypeError.
        """
        to_log = {}
        for i, path in enumerate(paths):
            to_log[path] = wandb.Image(pics[i], caption=captions[path])
        wandb.log(to_log)
class BlankWandbLogger:
    """No-op drop-in replacement for ``WandbLogger`` when W&B is disabled.

    Exposes the same method names; every call accepts any arguments and
    does nothing.
    """

    def __init__(self):
        # Mirror WandbLogger's attribute so callers can introspect it safely.
        self.wandb_args = None

    def log_epoch(self, *args, **kwargs):
        """Accept and ignore any arguments (explicit ``self`` added for
        consistency with WandbLogger's interface; ``kwars`` typo fixed)."""
        pass

    def log_special_pics(self, *args, **kwargs):
        """Accept and ignore any arguments."""
        pass
class TrainigLogger:
    """Top-level training logger combining console output with optional W&B.

    NOTE(review): the class name keeps the original "Trainig" spelling
    (and the ``trainig_steps`` attribute) because external callers may
    reference them.
    """

    def __init__(self, config):
        self.console_logger = ConsoleLogger("")
        # W&B logging is optional; fall back to a no-op stub when disabled.
        if config.exp.wandb:
            self.wandb_logger = WandbLogger(config)
        else:
            self.wandb_logger = BlankWandbLogger()
        self.trainig_steps = config.train.steps  # total number of train steps
        self.val_step = config.train.val_step    # validation runs every N steps

    def log_train_time_left(self, iter_info, step):
        """Print an ETA estimate from measured train/val step durations.

        ``iter_info`` must contain ``"duration/iter_train"`` and
        ``"duration/iter_val"`` entries exposing a ``.mean`` in seconds.
        """
        float_iter_time = iter_info["duration/iter_train"].mean
        float_val_time = iter_info["duration/iter_val"].mean
        steps_left = self.trainig_steps - step
        # Remaining validation passes: one every ``val_step`` steps.
        vals_left = steps_left // self.val_step
        seconds_left = float_iter_time * steps_left + float_val_time * vals_left
        # BUG FIX: previously computed via
        # fromtimestamp(x) - fromtimestamp(0), which round-trips through
        # local time and can be wrong across DST transitions; building the
        # timedelta directly is exact and renders identically via str().
        time_left = str(datetime.timedelta(seconds=seconds_left))
        print()
        print(f"Step {step}/{self.trainig_steps}")
        print(f"Time left: {time_left}")
        print(f"Time per step: {float_iter_time:.3f}")
        print()
        print()

    def save_train_logs(self, iter_info, step):
        """Forward averaged metrics to W&B and the console, then print ETA."""
        self.wandb_logger.log_epoch(iter_info, step)
        self.console_logger.log_epoch(iter_info, step)
        self.log_train_time_left(iter_info, step)

    def save_validation_logs(self, orig_pics, method_pics, captions, special_paths):
        """Log side-by-side (real | generated) image pairs to W&B."""
        log_pics = []
        for real_img, fake_img in zip(orig_pics, method_pics):
            # Concatenate each pair horizontally onto one RGB canvas.
            concat_img = Image.new(
                "RGB", (real_img.width + fake_img.width, real_img.height)
            )
            concat_img.paste(real_img, (0, 0))
            concat_img.paste(fake_img, (real_img.width, 0))
            log_pics.append(concat_img)
        self.wandb_logger.log_special_pics(log_pics, captions, special_paths)