| import argparse, os, sys, glob |
| import torch |
| import time |
| import numpy as np |
| from omegaconf import OmegaConf |
| from PIL import Image |
| from tqdm import tqdm, trange |
| from einops import repeat |
|
|
| from main import instantiate_from_config |
| from taming.modules.transformer.mingpt import sample_with_past |
|
|
|
|
def rescale(x):
    """Apply the affine map ``(x + 1) / 2``, sending -1 to 0 and 1 to 1.

    Works for any operand that supports ``+`` and ``/`` with Python floats
    (plain numbers, NumPy arrays, tensors).
    """
    return (x + 1.) / 2.
|
|
|
|
def chw_to_pillow(x):
    """Convert one tensor into an 8-bit PIL image.

    The tensor is detached, copied to the CPU and its first axis is moved to
    the end (CHW -> HWC, going by the function name).  Values are passed
    through ``rescale``, multiplied by 255, clipped to [0, 255] and cast to
    ``uint8`` before being handed to ``Image.fromarray``.
    """
    array = x.detach().cpu().numpy()
    array = array.transpose(1, 2, 0)  # axis 0 becomes the last axis
    pixels = 255 * rescale(array)
    pixels = pixels.clip(0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
|
|
|
|
@torch.no_grad()
def sample_classconditional(model, batch_size, class_label, steps=256, temperature=None, top_k=None, callback=None,
                            dim_z=256, h=16, w=16, verbose_time=False, top_p=None):
    """Sample a batch of images conditioned on one class label.

    Args:
        model: object exposing ``be_unconditional``, ``device``,
            ``transformer`` and ``decode_to_img``; it must not be
            unconditional.
        batch_size: number of samples to draw.
        class_label: class index; must be exactly of type ``int``.
        steps, temperature, top_k, top_p, callback: forwarded unchanged to
            ``sample_with_past``.
        dim_z, h, w: together with ``batch_size`` they form the shape
            ``[batch_size, dim_z, h, w]`` passed to ``model.decode_to_img``.
        verbose_time: if true, print the wall-clock time spent sampling.

    Returns:
        dict with ``"samples"`` (the output of ``model.decode_to_img``) and
        ``"class_label"`` (the ``(batch_size, 1)`` conditioning tensor).

    Raises:
        AssertionError: if ``class_label`` is not an ``int`` or the model is
            unconditional.
    """
    # The type check deliberately runs before the model is touched.
    assert type(class_label) == int, f'expecting type int but type is {type(class_label)}'
    assert not model.be_unconditional, 'Expecting a class-conditional Net2NetTransformer.'

    # One copy of the label per batch element, as a (batch_size, 1) column.
    label = torch.tensor([class_label])
    c_indices = repeat(label, '1 -> b 1', b=batch_size).to(model.device)

    started = time.time()
    index_sample = sample_with_past(
        c_indices,
        model.transformer,
        steps=steps,
        sample_logits=True,
        top_k=top_k,
        callback=callback,
        temperature=temperature,
        top_p=top_p,
    )
    if verbose_time:
        sampling_time = time.time() - started
        print(f"Full sampling takes about {sampling_time:.2f} seconds.")

    x_sample = model.decode_to_img(index_sample, [batch_size, dim_z, h, w])
    return {"samples": x_sample, "class_label": c_indices}
|
|
|
|
@torch.no_grad()
def sample_unconditional(model, batch_size, steps=256, temperature=None, top_k=None, top_p=None, callback=None,
                         dim_z=256, h=16, w=16, verbose_time=False):
    """Sample a batch of images from an unconditional model.

    Sampling is seeded with ``model.sos_token`` repeated once per batch
    element.

    Args:
        model: object exposing ``be_unconditional``, ``sos_token``,
            ``device``, ``transformer`` and ``decode_to_img``; it must be
            unconditional.
        batch_size: number of samples to draw.
        steps, temperature, top_k, top_p, callback: forwarded unchanged to
            ``sample_with_past``.
        dim_z, h, w: together with ``batch_size`` they form the shape
            ``[batch_size, dim_z, h, w]`` passed to ``model.decode_to_img``.
        verbose_time: if true, print the wall-clock time spent sampling.

    Returns:
        dict with the single key ``"samples"`` holding the output of
        ``model.decode_to_img``.

    Raises:
        AssertionError: if the model is not unconditional.
    """
    assert model.be_unconditional, 'Expecting an unconditional model.'

    # (batch_size, 1) column filled with the start-of-sequence token.
    sos = torch.tensor([model.sos_token])
    c_indices = repeat(sos, '1 -> b 1', b=batch_size).to(model.device)

    started = time.time()
    index_sample = sample_with_past(
        c_indices,
        model.transformer,
        steps=steps,
        sample_logits=True,
        top_k=top_k,
        callback=callback,
        temperature=temperature,
        top_p=top_p,
    )
    if verbose_time:
        sampling_time = time.time() - started
        print(f"Full sampling takes about {sampling_time:.2f} seconds.")

    x_sample = model.decode_to_img(index_sample, [batch_size, dim_z, h, w])
    return {"samples": x_sample}
|
|
|
|
@torch.no_grad()
def run(logdir, model, batch_size, temperature, top_k, unconditional=True, num_samples=50000,
        given_classes=None, top_p=None):
    """Draw ``num_samples`` samples in batches and save them under ``logdir``.

    In unconditional mode ``num_samples`` images are produced in total.  In
    class-conditional mode ``num_samples`` images are produced for *each*
    entry of ``given_classes``; the file counter restarts for every class
    and the class tensor is passed to ``save_from_logs`` as ``cond_key``.

    Args:
        logdir: output directory handed to ``save_from_logs``.
        model: model passed through to the sampling functions.
        batch_size: size of every full batch; a smaller final batch covers
            the remainder.
        temperature, top_k, top_p: sampling settings forwarded unchanged.
        unconditional: selects unconditional vs. class-conditional sampling.
        num_samples: number of samples (per class when class-conditional).
        given_classes: iterable of integer class labels; required when
            ``unconditional`` is false.

    Raises:
        AssertionError: if class-conditional sampling is requested without
            ``given_classes``.
    """
    # Full batches plus one smaller remainder batch, omitted when empty.
    batches = [batch_size] * (num_samples // batch_size)
    remainder = num_samples % batch_size
    if remainder:
        batches.append(remainder)

    if not unconditional:
        assert given_classes is not None, "given_classes is required for class-conditional sampling"
        print("Running in pure class-conditional sampling mode. I will produce "
              f"{num_samples} samples for each of the {len(given_classes)} classes, "
              f"i.e. {num_samples*len(given_classes)} in total.")
        for class_label in tqdm(given_classes, desc="Classes"):
            # enumerate(tqdm(...)) rather than tqdm(enumerate(...)): tqdm can
            # only show a total/ETA when it wraps the sized list directly.
            for n, bs in enumerate(tqdm(batches, desc="Sampling Class")):
                logs = sample_classconditional(model, batch_size=bs, class_label=class_label,
                                               temperature=temperature, top_k=top_k, top_p=top_p)
                # n * batch_size keeps file numbers contiguous across batches.
                save_from_logs(logs, logdir, base_count=n * batch_size, cond_key=logs["class_label"])
    else:
        print(f"Running in unconditional sampling mode, producing {num_samples} samples.")
        for n, bs in enumerate(tqdm(batches, desc="Sampling")):
            logs = sample_unconditional(model, batch_size=bs, temperature=temperature, top_k=top_k, top_p=top_p)
            save_from_logs(logs, logdir, base_count=n * batch_size)
|
|
|
|
def save_from_logs(logs, logdir, base_count, key="samples", cond_key=None):
    """Write every image in ``logs[key]`` to disk as a zero-padded PNG.

    Args:
        logs: mapping that holds an iterable of image tensors under ``key``.
        logdir: directory the images are written into.
        base_count: number assigned to the first image; later images count
            up from there.
        key: entry of ``logs`` to save.
        cond_key: optional indexable of per-image labels.  When given, image
            ``i`` is stored in the sub-directory ``str(cond_key[i])`` of
            ``logdir`` (created on demand); tensor labels are converted with
            ``.item()`` first.  When ``None``, images are written directly
            into ``logdir``, which must already exist.
    """
    created_dirs = set()  # avoids repeating makedirs for every image
    for i, x in enumerate(logs[key]):
        image = chw_to_pillow(x)
        filename = f"{base_count + i:06}.png"
        if cond_key is None:
            target_dir = logdir
        else:
            condlabel = cond_key[i]
            # isinstance also covers Tensor subclasses, which an exact type
            # comparison would leave unconverted.
            if isinstance(condlabel, torch.Tensor):
                condlabel = condlabel.item()
            target_dir = os.path.join(logdir, str(condlabel))
            if target_dir not in created_dirs:
                os.makedirs(target_dir, exist_ok=True)
                created_dirs.add(target_dir)
        image.save(os.path.join(target_dir, filename))
|
|
|
|
def get_parser():
    """Build the command-line parser for the sampling script.

    Options:
        -r/--resume       log directory or checkpoint to load from.
        -o/--outdir       alternative output directory (default ``""``).
        -b/--base         zero or more base config paths (default ``[]``).
        -n/--num_samples  number of samples to draw (default 50000).
        --batch_size      batch size (default 25).
        -k/--top_k        top-k value (default 250).
        -t/--temperature  sampling temperature (default 1.0).
        -p/--top_p        top-p value (default 1.0).
        --classes         comma-separated class list (default ``"imagenet"``).

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        nargs="?",
        help="load from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-o",
        "--outdir",
        type=str,
        nargs="?",
        help="path where the samples will be logged to.",
        default="",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
             "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=[],
    )
    parser.add_argument(
        "-n",
        "--num_samples",
        type=int,
        nargs="?",
        help="num_samples to draw",
        default=50000,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        nargs="?",
        help="the batch size",
        default=25,
    )
    parser.add_argument(
        "-k",
        "--top_k",
        type=int,
        nargs="?",
        help="top-k value to sample with",
        default=250,
    )
    parser.add_argument(
        "-t",
        "--temperature",
        type=float,
        nargs="?",
        help="temperature value to sample with",
        default=1.0,
    )
    parser.add_argument(
        "-p",
        "--top_p",
        type=float,
        nargs="?",
        help="top-p value to sample with",
        default=1.0,
    )
    parser.add_argument(
        "--classes",
        type=str,
        nargs="?",
        help="specify comma-separated classes to sample from. Uses 1000 classes per default.",
        default="imagenet",
    )
    return parser
|
|
|
|
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    """Instantiate a model from ``config`` and optionally prepare it for use.

    Args:
        config: configuration passed to ``instantiate_from_config``.
        sd: state dict loaded into the model, or ``None`` to skip loading.
        gpu: if true, call ``model.cuda()``.
        eval_mode: if true, call ``model.eval()``.

    Returns:
        dict with the single key ``"model"``.
    """
    net = instantiate_from_config(config)

    if sd is not None:
        net.load_state_dict(sd)
    if gpu:
        net.cuda()
    if eval_mode:
        net.eval()

    return dict(model=net)
|
|
|
|
def load_model(config, ckpt, gpu, eval_mode):
    """Build the model described by ``config.model``, optionally from a checkpoint.

    Args:
        config: object whose ``model`` attribute is passed to
            ``load_model_from_config``.
        ckpt: checkpoint path; if falsy, no weights are loaded.
        gpu: forwarded to ``load_model_from_config``.
        eval_mode: forwarded to ``load_model_from_config``.

    Returns:
        tuple ``(model, global_step)``; ``global_step`` is the checkpoint's
        ``"global_step"`` entry, or ``None`` when no checkpoint was given.
    """
    if not ckpt:
        # No checkpoint: keep the same dict shape so the lookup below works.
        checkpoint = {"state_dict": None}
        global_step = None
    else:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(ckpt, map_location="cpu")
        global_step = checkpoint["global_step"]
        print(f"loaded model from global step {global_step}.")

    loaded = load_model_from_config(
        config.model, checkpoint["state_dict"], gpu=gpu, eval_mode=eval_mode
    )
    return loaded["model"], global_step
|
|
|
|
if __name__ == "__main__":
    # Make modules in the current working directory importable.
    sys.path.append(os.getcwd())
    parser = get_parser()

    # Arguments the parser does not know are kept and later interpreted as
    # OmegaConf dotlist overrides.
    opt, unknown = parser.parse_known_args()
    assert opt.resume

    ckpt = None

    if not os.path.exists(opt.resume):
        raise ValueError("Cannot find {}".format(opt.resume))

    if not os.path.isfile(opt.resume):
        # --resume names a log directory: use its last checkpoint.
        assert os.path.isdir(opt.resume), opt.resume
        logdir = opt.resume.rstrip("/")
        ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
    else:
        # --resume names a checkpoint file: derive the log directory from it.
        paths = opt.resume.split("/")
        if "logs" in paths:
            # Keep everything up to and including the component that follows
            # the last "logs" entry.
            idx = len(paths) - paths[::-1].index("logs") + 1
        else:
            # Otherwise drop the final two path components.
            idx = -2
        logdir = "/".join(paths[:idx])
        ckpt = opt.resume

    # Project configs stored in the log directory come first, so configs given
    # with --base are merged on top of them.
    base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
    opt.base = base_configs + opt.base

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)

    model, global_step = load_model(config, ckpt, gpu=True, eval_mode=True)

    if opt.outdir:
        print(f"Switching logdir from '{logdir}' to '{opt.outdir}'")
        logdir = opt.outdir

    if opt.classes != "imagenet":
        cls_str = opt.classes
        assert not cls_str.endswith(","), 'class string should not end with a ","'
        given_classes = [int(c) for c in cls_str.split(",")]
    else:
        given_classes = list(range(1000))

    # Samples go into a sub-directory named after the sampling settings and
    # the checkpoint's global step.
    sampling_tag = f"top_k_{opt.top_k}_temp_{opt.temperature:.2f}_top_p_{opt.top_p}"
    logdir = os.path.join(logdir, "samples", sampling_tag, f"{global_step}")

    print(f"Logging to {logdir}")
    os.makedirs(logdir, exist_ok=True)

    run(logdir, model, opt.batch_size, opt.temperature, opt.top_k, unconditional=model.be_unconditional,
        given_classes=given_classes, num_samples=opt.num_samples, top_p=opt.top_p)

    print("done.")
|
|