| | import os, math |
| | import torch |
| | import torch.nn.functional as F |
| | import pytorch_lightning as pl |
| |
|
| | from main import instantiate_from_config |
| | from taming.modules.util import SOSProvider |
| |
|
| |
|
def disabled_train(self, mode=True):
    """Stand-in for ``nn.Module.train`` that never changes train/eval mode.

    Assigning this function to ``model.train`` turns subsequent calls into
    no-ops, so a module put into eval mode stays there.

    Args:
        self: value handed straight back to the caller.
        mode: requested training mode; deliberately ignored.

    Returns:
        ``self``, unchanged.
    """
    result = self  # no training flag is read or written
    return result
| |
|
| |
|
class Net2NetTransformer(pl.LightningModule):
    """Transformer that autoregressively models first-stage codebook indices,
    prefixed by the indices produced by a conditioning stage.

    The first-stage model (and a separately configured cond-stage model) is
    put in eval mode and has its ``train`` method replaced by
    ``disabled_train`` so its mode cannot be switched back.  Only the
    transformer's parameters are handed to the optimizer.
    """

    def __init__(self,
                 transformer_config,
                 first_stage_config,
                 cond_stage_config,
                 permuter_config=None,
                 ckpt_path=None,
                 ignore_keys=None,
                 first_stage_key="image",
                 cond_stage_key="depth",
                 downsample_cond_size=-1,
                 pkeep=1.0,
                 sos_token=0,
                 unconditional=False,
                 ):
        super().__init__()
        self.be_unconditional = unconditional
        self.sos_token = sos_token
        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key
        self.init_first_stage_from_ckpt(first_stage_config)
        # May switch the model to unconditional mode and overwrite
        # ``cond_stage_key``, so it must run after the attributes above are set.
        self.init_cond_stage_from_ckpt(cond_stage_config)
        if permuter_config is None:
            permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
        self.permuter = instantiate_from_config(config=permuter_config)
        self.transformer = instantiate_from_config(config=transformer_config)

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.downsample_cond_size = downsample_cond_size
        self.pkeep = pkeep

    def init_from_ckpt(self, path, ignore_keys=None):
        """Load this module's weights from the ``state_dict`` entry of a checkpoint.

        Args:
            path: checkpoint file path.
            ignore_keys: iterable of key prefixes; matching entries are dropped
                before loading. ``None`` (default) drops nothing.

        Loading is non-strict, so missing or unexpected keys are tolerated.
        """
        prefixes = tuple(ignore_keys) if ignore_keys is not None else ()
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        sd = torch.load(path, map_location="cpu")["state_dict"]
        # Iterate over a snapshot of the keys: deleting from a dict while
        # iterating its live key view raises RuntimeError. ``startswith`` with a
        # tuple also guarantees each key is deleted at most once.
        for k in list(sd.keys()):
            if k.startswith(prefixes):
                self.print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def init_first_stage_from_ckpt(self, config):
        """Instantiate the first-stage model from ``config`` and pin it to eval mode."""
        # NOTE(review): despite the name, no checkpoint is read here; presumably
        # the instantiated model loads its own weights from ``config``.
        model = instantiate_from_config(config)
        model = model.eval()
        model.train = disabled_train
        self.first_stage_model = model

    def init_cond_stage_from_ckpt(self, config):
        """Set up ``self.cond_stage_model``.

        * ``"__is_first_stage__"``: reuse the first-stage model.
        * ``"__is_unconditional__"`` (or ``unconditional=True``): use an
          ``SOSProvider`` emitting ``self.sos_token`` and read the conditioning
          from ``first_stage_key``.
        * anything else: instantiate from ``config`` and pin it to eval mode.
        """
        if config == "__is_first_stage__":
            print("Using first stage also as cond stage.")
            self.cond_stage_model = self.first_stage_model
        elif config == "__is_unconditional__" or self.be_unconditional:
            print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
                  f"Prepending {self.sos_token} as a sos token.")
            self.be_unconditional = True
            self.cond_stage_key = self.first_stage_key
            self.cond_stage_model = SOSProvider(self.sos_token)
        else:
            model = instantiate_from_config(config)
            model = model.eval()
            model.train = disabled_train
            self.cond_stage_model = model

    def forward(self, x, c):
        """Compute next-index logits for the first-stage indices of ``x`` given ``c``.

        While training with ``pkeep < 1``, each input index is independently
        replaced by a uniformly random index in ``[0, vocab_size)`` with
        probability ``1 - pkeep``; the target always uses the clean indices.

        Returns:
            ``(logits, target)`` where ``target`` is the first-stage index
            tensor and ``logits`` are sliced so position ``i`` predicts
            ``target[:, i]``.
        """
        _, z_indices = self.encode_to_z(x)
        _, c_indices = self.encode_to_c(c)

        if self.training and self.pkeep < 1.0:
            mask = torch.bernoulli(self.pkeep * torch.ones(z_indices.shape,
                                                           device=z_indices.device))
            mask = mask.round().to(dtype=torch.int64)
            r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
            a_indices = mask * z_indices + (1 - mask) * r_indices
        else:
            a_indices = z_indices

        cz_indices = torch.cat((c_indices, a_indices), dim=1)

        target = z_indices
        # Feed everything but the last index; the model predicts the next one.
        logits, _ = self.transformer(cz_indices[:, :-1])
        # Drop predictions made inside the conditioning prefix except the last,
        # which predicts the first first-stage index.
        logits = logits[:, c_indices.shape[1] - 1:]

        return logits, target

    def top_k_logits(self, logits, k):
        """Return a copy of ``logits`` with everything below the k-th largest
        value along the last axis set to ``-inf``."""
        v, ix = torch.topk(logits, k)
        out = logits.clone()
        out[out < v[..., [-1]]] = -float('Inf')
        return out

    @torch.no_grad()
    def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
               callback=lambda k: None):
        """Generate indices continuing prefix ``x`` under conditioning ``c``.

        Args:
            x: index prefix to continue (may be empty along dim 1).
            c: conditioning indices, prepended to ``x``.
            steps: number of indices to generate.
            temperature: divisor applied to the logits before the softmax.
            sample: draw from the distribution if True, otherwise take the argmax.
            top_k: if given, keep only the ``top_k`` largest logits.
            callback: called with the step number before each autoregressive step.

        Returns:
            The generated sequence with the conditioning prefix removed.
        """
        x = torch.cat((c, x), dim=1)
        block_size = self.transformer.get_block_size()
        assert not self.transformer.training
        if self.pkeep <= 0.0:
            # With pkeep <= 0 the model only ever saw random inputs in training,
            # so a single non-autoregressive pass is used.
            assert len(x.shape) == 2
            # NOTE(review): filler input is taken from a slice of ``c``; its
            # length depends on c/x lengths, not on ``steps`` — confirm intent.
            noise = c.clone()[:, x.shape[1] - c.shape[1]:-1]
            x = torch.cat((x, noise), dim=1)
            logits, _ = self.transformer(x)
            logits = logits / temperature
            if top_k is not None:
                logits = self.top_k_logits(logits, top_k)
            probs = F.softmax(logits, dim=-1)
            if sample:
                b, t, v = probs.shape
                ix = torch.multinomial(probs.reshape(b * t, v), num_samples=1)
                ix = ix.reshape(b, t)
            else:
                _, ix = torch.topk(probs, k=1, dim=-1)
                # Squeeze the k=1 axis so both branches return (batch, positions).
                ix = ix.squeeze(-1)
            x = ix[:, c.shape[1] - 1:]
        else:
            for k in range(steps):
                callback(k)
                # The whole sequence (conditioning included) must fit in the
                # context window; with assertions enabled the crop is never taken.
                assert x.size(1) <= block_size
                x_cond = x if x.size(1) <= block_size else x[:, -block_size:]
                logits, _ = self.transformer(x_cond)
                # Only the prediction for the next position is needed.
                logits = logits[:, -1, :] / temperature
                if top_k is not None:
                    logits = self.top_k_logits(logits, top_k)
                probs = F.softmax(logits, dim=-1)
                if sample:
                    ix = torch.multinomial(probs, num_samples=1)
                else:
                    _, ix = torch.topk(probs, k=1, dim=-1)
                x = torch.cat((x, ix), dim=1)
            # Cut off the conditioning prefix.
            x = x[:, c.shape[1]:]
        return x

    @torch.no_grad()
    def encode_to_z(self, x):
        """Encode ``x`` with the first-stage model.

        Returns the quantized latent and its codebook indices, flattened to
        ``(batch, -1)`` and reordered by ``self.permuter``.
        """
        quant_z, _, info = self.first_stage_model.encode(x)
        indices = info[2].view(quant_z.shape[0], -1)
        indices = self.permuter(indices)
        return quant_z, indices

    @torch.no_grad()
    def encode_to_c(self, c):
        """Encode the conditioning ``c`` with the cond-stage model.

        When ``downsample_cond_size > -1`` the input is first resized to a
        ``downsample_cond_size`` x ``downsample_cond_size`` square. Returns the
        cond-stage quantized output and its indices flattened to 2-D.
        """
        if self.downsample_cond_size > -1:
            c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
        quant_c, _, [_, _, indices] = self.cond_stage_model.encode(c)
        if len(indices.shape) != 2:
            indices = indices.view(c.shape[0], -1)
        return quant_c, indices

    @torch.no_grad()
    def decode_to_img(self, index, zshape):
        """Decode first-stage indices back to an image.

        ``zshape`` is the quantized-latent shape from ``encode_to_z``
        (presumably (B, C, H, W); it is reordered to ``bhwc`` below).
        """
        index = self.permuter(index, reverse=True)
        bhwc = (zshape[0], zshape[2], zshape[3], zshape[1])
        quant_z = self.first_stage_model.quantize.get_codebook_entry(
            index.reshape(-1), shape=bhwc)
        x = self.first_stage_model.decode(quant_z)
        return x

    @torch.no_grad()
    def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
        """Return a dict of images for logging, computed on the first 4 batch items.

        Keys: ``inputs``, ``reconstructions``, ``samples_half`` (second half of
        the indices sampled given the first half), ``samples_nopix`` (all
        indices sampled), ``samples_det`` (all indices chosen greedily) and,
        unless the cond key is "image", "nucleus" or "target",
        ``conditioning`` / ``conditioning_rec``.
        """
        log = dict()

        N = 4
        if lr_interface:
            x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
        else:
            x, c = self.get_xc(batch, N)
        x = x.to(device=self.device)
        c = c.to(device=self.device)

        quant_z, z_indices = self.encode_to_z(x)
        quant_c, c_indices = self.encode_to_c(c)

        # Resolve defaults once instead of at every ``sample`` call.
        temperature = temperature if temperature is not None else 1.0
        top_k = top_k if top_k is not None else 100
        callback = callback if callback is not None else lambda k: None

        # Keep the first half of the indices and sample the rest.
        z_start_indices = z_indices[:, :z_indices.shape[1] // 2]
        index_sample = self.sample(z_start_indices, c_indices,
                                   steps=z_indices.shape[1] - z_start_indices.shape[1],
                                   temperature=temperature,
                                   sample=True,
                                   top_k=top_k,
                                   callback=callback)
        x_sample = self.decode_to_img(index_sample, quant_z.shape)

        # Sample every index from the conditioning alone.
        z_start_indices = z_indices[:, :0]
        index_sample = self.sample(z_start_indices, c_indices,
                                   steps=z_indices.shape[1],
                                   temperature=temperature,
                                   sample=True,
                                   top_k=top_k,
                                   callback=callback)
        x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)

        # Greedy decoding with sample()'s own defaults (temperature 1, no top-k).
        index_sample = self.sample(z_start_indices, c_indices,
                                   steps=z_indices.shape[1],
                                   sample=False,
                                   callback=callback)
        x_sample_det = self.decode_to_img(index_sample, quant_z.shape)

        # Reconstruction from the ground-truth indices.
        x_rec = self.decode_to_img(z_indices, quant_z.shape)

        log["inputs"] = x
        log["reconstructions"] = x_rec

        # The previous test chained ``!=`` with ``or`` and was therefore always
        # true, so decode() ran even for these keys (including the
        # unconditional SOSProvider case, whose key is the first-stage key).
        if self.cond_stage_key not in ("image", "nucleus", "target"):
            cond_rec = self.cond_stage_model.decode(quant_c)
            if self.cond_stage_key == "segmentation":
                # Turn class maps into one-hot images and colourize them.
                num_classes = cond_rec.shape[1]

                c = torch.argmax(c, dim=1, keepdim=True)
                c = F.one_hot(c, num_classes=num_classes)
                c = c.squeeze(1).permute(0, 3, 1, 2).float()
                c = self.cond_stage_model.to_rgb(c)

                cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
                cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
                cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
                cond_rec = self.cond_stage_model.to_rgb(cond_rec)
            log["conditioning_rec"] = cond_rec
            log["conditioning"] = c

        log["samples_half"] = x_sample
        log["samples_nopix"] = x_sample_nopix
        log["samples_det"] = x_sample_det
        return log

    def get_input(self, key, batch):
        """Fetch ``batch[key]``; add a trailing axis to 3-D tensors and cast
        float64 to float32."""
        x = batch[key]
        if len(x.shape) == 3:
            x = x[..., None]
        if x.dtype == torch.double:
            x = x.float()
        return x

    def get_xc(self, batch, N=None, **kwargs):
        """Return ``(x, c)`` read via ``first_stage_key`` / ``cond_stage_key``,
        truncated to the first ``N`` items when ``N`` is given.

        Extra keyword arguments (``log_images(lr_interface=True)`` passes
        ``diffuse`` and ``upsample_factor``) are accepted so that call does not
        raise TypeError; this base implementation ignores them.
        """
        x = self.get_input(self.first_stage_key, batch)
        c = self.get_input(self.cond_stage_key, batch)
        if N is not None:
            x = x[:N]
            c = c[:N]
        return x, c

    def shared_step(self, batch):
        """Cross-entropy between predicted logits and target indices for ``batch``."""
        x, c = self.get_xc(batch)
        logits, target = self(x, c)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """
        Following minGPT:
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.transformer.named_modules():
            for pn, _ in m.named_parameters():
                fpn = f"{mn}.{pn}" if mn else pn  # full parameter name

                if pn.endswith('bias'):
                    # All biases are excluded from weight decay.
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)

        # The positional embedding is a bare parameter on the transformer, not
        # owned by one of the module types above.
        no_decay.add('pos_emb')

        # Every parameter must land in exactly one bucket.
        param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, \
            f"parameters {str(inter_params)} made it into both decay/no_decay sets!"
        assert len(param_dict.keys() - union_params) == 0, \
            f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!"

        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": 0.01},
            {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
        ]
        # NOTE(review): ``learning_rate`` is not set in __init__; presumably the
        # training script assigns it on the instance before fitting.
        optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
        return optimizer
| |
|