| """Complete Generator architecture: |
| * OmniGenerator |
| * Encoder |
| * Decoders |
| """ |
| from pathlib import Path |
| import traceback |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import yaml |
| from addict import Dict |
| from torch import softmax |
|
|
| import climategan.strings as strings |
| from climategan.deeplab import create_encoder, create_segmentation_decoder |
| from climategan.depth import create_depth_decoder |
| from climategan.masker import create_mask_decoder |
| from climategan.painter import create_painter |
| from climategan.tutils import init_weights, mix_noise, normalize |
|
|
|
|
def create_generator(opts, device="cpu", latent_shape=None, no_init=False, verbose=0):
    """
    Instantiate an OmniGenerator from opts and move it to device.

    Unless no_init is True, the weights of every decoder (except the
    segmentation decoder "s", which handles its own initialization) and of a
    "base" encoder are (re-)initialized with init_weights.

    Args:
        opts (addict.Dict): configuration dict
        device (str, optional): target device. Defaults to "cpu".
        latent_shape (tuple, optional): painter latent shape. Defaults to None.
        no_init (bool, optional): skip weight initialization. Defaults to False.
        verbose (int, optional): verbosity level. Defaults to 0.

    Returns:
        OmniGenerator: the generator, on `device`
    """
    G = OmniGenerator(opts, latent_shape, verbose, no_init)
    if no_init:
        print("Sending to", device)
        return G.to(device)

    for model in G.decoders:
        net = G.decoders[model]
        if model == "s":
            # segmentation decoder comes with its own (pretrained) init
            continue
        if isinstance(net, nn.ModuleDict):
            for domain, domain_model in net.items():
                # FIX: initialize the per-domain sub-module itself. The
                # original called init_weights(net[domain_model], ...), i.e.
                # indexed the ModuleDict with the module object instead of its
                # key, which raises KeyError for any ModuleDict decoder.
                init_weights(
                    domain_model,
                    init_type=opts.gen[model].init_type,
                    init_gain=opts.gen[model].init_gain,
                    verbose=verbose,
                    caller=f"create_generator decoder {model} {domain}",
                )
        else:
            init_weights(
                G.decoders[model],
                init_type=opts.gen[model].init_type,
                init_gain=opts.gen[model].init_gain,
                verbose=verbose,
                caller=f"create_generator decoder {model}",
            )
    if G.encoder is not None and opts.gen.encoder.architecture == "base":
        # non-"base" encoders (deeplab) load pretrained weights elsewhere
        init_weights(
            G.encoder,
            init_type=opts.gen.encoder.init_type,
            init_gain=opts.gen.encoder.init_gain,
            verbose=verbose,
            caller="create_generator encoder",
        )

    print("Sending to", device)
    return G.to(device)
|
|
|
|
class OmniGenerator(nn.Module):
    def __init__(self, opts, latent_shape=None, verbose=0, no_init=False):
        """Build the generator: a shared encoder for the masker-related tasks
        plus one decoder per task listed in opts.gen / opts.tasks.

        Decoders live in the `decoders` ModuleDict and can be accessed as
        G.decoders["d"] / G.decoders["s"] / G.decoders["m"]; the painter
        ("p" task) is stored separately in `self.painter`.

        Args:
            opts (addict.Dict): configuration dict
            latent_shape (tuple, optional): painter latent shape. Defaults to None.
            verbose (int, optional): verbosity level. Defaults to 0.
            no_init (bool, optional): skip weight init in the sub-creators.
        """
        super().__init__()
        self.opts = opts
        self.verbose = verbose

        # the encoder is only needed when at least one masker-related task
        # (m / s / d) is enabled
        self.encoder = (
            create_encoder(opts, no_init, verbose)
            if any(t in opts.tasks for t in "msd")
            else None
        )

        decoders = {}
        self.painter = nn.Module()

        if "d" in opts.tasks:
            decoders["d"] = create_depth_decoder(opts, no_init, verbose)
            if self.verbose > 0:
                print(f" - Add {decoders['d'].__class__.__name__}")

        if "s" in opts.tasks:
            decoders["s"] = create_segmentation_decoder(opts, no_init, verbose)

        if "m" in opts.tasks:
            decoders["m"] = create_mask_decoder(opts, no_init, verbose)

        self.decoders = nn.ModuleDict(decoders)

        if "p" in self.opts.tasks:
            self.painter = create_painter(opts, no_init, verbose)
        elif self.verbose > 0:
            print(" - Add Empty Painter")
|
|
| @property |
| def device(self): |
| return next(self.parameters()).device |
|
|
| def __str__(self): |
| return strings.generator(self) |
|
|
| def encode(self, x): |
| """ |
| Forward x through the encoder |
| |
| Args: |
| x (torch.Tensor): B3HW input tensor |
| |
| Returns: |
| list: High and Low level features from the encoder |
| """ |
| assert self.encoder is not None |
| return self.encoder.forward(x) |
|
|
| def decode(self, x=None, z=None, return_z=False, return_z_depth=False): |
| """ |
| Comptutes the predictions of all available decoders from either x or z. |
| If using spade for the masker with 15 channels, x *must* be provided, |
| whether z is too or not. |
| |
| Args: |
| x (torch.Tensor, optional): Input tensor (B3HW). Defaults to None. |
| z (list, optional): List of high and low-level features as BCHW. |
| Defaults to None. |
| return_z (bool, optional): whether or not to return z in the dict. |
| Defaults to False. |
| return_z_depth (bool, optional): whether or not to return z_depth |
| in the dict. Defaults to False. |
| |
| Raises: |
| ValueError: If using spade for the masker with 15 channels but x is None |
| |
| Returns: |
| dict: {task: prediction_tensor} (may include z and z_depth |
| depending on args) |
| """ |
|
|
| assert x is not None or z is not None |
| if self.opts.gen.m.use_spade and self.opts.m.spade.cond_nc == 15: |
| if x is None: |
| raise ValueError( |
| "When using spade for the Masker with 15 channels," |
| + " x MUST be provided" |
| ) |
|
|
| z_depth = cond = d = s = None |
| out = {} |
|
|
| if z is None: |
| z = self.encode(x) |
|
|
| if return_z: |
| out["z"] = z |
|
|
| if "d" in self.decoders: |
| d, z_depth = self.decoders["d"](z) |
| out["d"] = d |
|
|
| if return_z_depth: |
| out["z_depth"] = z_depth |
|
|
| if "s" in self.decoders: |
| s = self.decoders["s"](z, z_depth) |
| out["s"] = s |
|
|
| if "m" in self.decoders: |
| if s is not None and d is not None: |
| cond = self.make_m_cond(d, s, x) |
| m = self.mask(z=z, cond=cond) |
| out["m"] = m |
|
|
| return out |
|
|
| def sample_painter_z(self, batch_size, device, force_half=False): |
| if self.opts.gen.p.no_z: |
| return None |
|
|
| z = torch.empty( |
| batch_size, |
| self.opts.gen.p.latent_dim, |
| self.painter.z_h, |
| self.painter.z_w, |
| device=device, |
| ).normal_(mean=0, std=1.0) |
|
|
| if force_half: |
| z = z.half() |
|
|
| return z |
|
|
| def make_m_cond(self, d, s, x=None): |
| """ |
| Create the masker's conditioning input when using spade from the |
| d and s predictions and from the input x when cond_nc == 15. |
| |
| d and s are assumed to have the the same spatial resolution. |
| if cond_nc == 15 then x is interpolated to match that dimension. |
| |
| Args: |
| d (torch.Tensor): Raw depth prediction (B1HW) |
| s (torch.Tensor): Raw segmentation prediction (BCHW) |
| x (torch.Tensor, optional): Input tensor (B3hW). Mandatory |
| when opts.gen.m.spade.cond_nc == 15 |
| |
| Raises: |
| ValueError: opts.gen.m.spade.cond_nc == 15 but x is None |
| |
| Returns: |
| torch.Tensor: B x cond_nc x H x W conditioning tensor. |
| """ |
| if self.opts.gen.m.spade.detach: |
| d = d.detach() |
| s = s.detach() |
| cats = [normalize(d), softmax(s, dim=1)] |
| if self.opts.gen.m.spade.cond_nc == 15: |
| if x is None: |
| raise ValueError( |
| "When using spade for the Masker with 15 channels," |
| + " x MUST be provided" |
| ) |
| cats += [ |
| F.interpolate(x, s.shape[-2:], mode="bilinear", align_corners=True) |
| ] |
|
|
| return torch.cat(cats, dim=1) |
|
|
| def mask(self, x=None, z=None, cond=None, z_depth=None, sigmoid=True): |
| """ |
| Create a mask from either an input x or a latent vector z. |
| Optionally if the Masker has a spade architecture the conditioning tensor |
| may be provided (cond). Default behavior applies an element-wise |
| sigmoid, but can be deactivated (sigmoid=False). |
| |
| At least one of x or z must be provided (i.e. not None). |
| If the Masker has a spade architecture and cond_nc == 15 then x cannot |
| be None. |
| |
| Args: |
| x (torch.Tensor, optional): Input tensor B3HW. Defaults to None. |
| z (list, optional): High and Low level features of the encoder. |
| Will be computed if None. Defaults to None. |
| cond ([type], optional): [description]. Defaults to None. |
| sigmoid (bool, optional): [description]. Defaults to True. |
| |
| Returns: |
| torch.Tensor: B1HW mask tensor |
| """ |
| assert x is not None or z is not None |
| if z is None: |
| z = self.encode(x) |
|
|
| if cond is None and self.opts.gen.m.use_spade: |
| assert "s" in self.opts.tasks and "d" in self.opts.tasks |
| with torch.no_grad(): |
| d_pred, z_d = self.decoders["d"](z) |
| s_pred = self.decoders["s"](z, z_d) |
| cond = self.make_m_cond(d_pred, s_pred, x) |
| if z_depth is None and self.opts.gen.m.use_dada: |
| assert "d" in self.opts.tasks |
| with torch.no_grad(): |
| _, z_depth = self.decoders["d"](z) |
|
|
| if cond is not None: |
| device = z[0].device if isinstance(z, (tuple, list)) else z.device |
| cond = cond.to(device) |
|
|
| logits = self.decoders["m"](z, cond, z_depth) |
|
|
| if not sigmoid: |
| return logits |
|
|
| return torch.sigmoid(logits) |
|
|
| def paint(self, m, x, no_paste=False): |
| """ |
| Paints given a mask and an image |
| calls painter(z, x * (1.0 - m)) |
| Mask has 1s where water should be painted |
| |
| Args: |
| m (torch.Tensor): Mask |
| x (torch.Tensor): Image to paint |
| |
| Returns: |
| torch.Tensor: painted image |
| """ |
| z_paint = self.sample_painter_z(x.shape[0], x.device) |
| m = m.to(x.dtype) |
| fake = self.painter(z_paint, x * (1.0 - m)) |
| if self.opts.gen.p.paste_original_content and not no_paste: |
| return x * (1.0 - m) + fake * m |
| return fake |
|
|
| def paint_cloudy(self, m, x, s, sky_idx=9, res=(8, 8), weight=0.8): |
| """ |
| Paints x with water in m through an intermediary cloudy image |
| where the sky has been replaced with perlin noise to imitate clouds. |
| |
| The intermediary cloudy image is only used to control the painter's |
| painting mode, probing it with a cloudy input. |
| |
| Args: |
| m (torch.Tensor): water mask |
| x (torch.Tensor): input tensor |
| s (torch.Tensor): segmentation prediction (BCHW) |
| sky_idx (int, optional): Index of the sky class along s's C dimension. |
| Defaults to 9. |
| res (tuple, optional): Perlin noise spatial resolution. Defaults to (8, 8). |
| weight (float, optional): Intermediate image's cloud proportion |
| (w * cloud + (1-w) * original_sky). Defaults to 0.8. |
| |
| Returns: |
| torch.Tensor: painted image with original content pasted. |
| """ |
| sky_mask = ( |
| torch.argmax( |
| F.interpolate(s, x.shape[-2:], mode="bilinear"), dim=1, keepdim=True |
| ) |
| == sky_idx |
| ).to(x.dtype) |
| noised_x = mix_noise(x, sky_mask, res=res, weight=weight).to(x.dtype) |
| fake = self.paint(m, noised_x, no_paste=True) |
| return x * (1.0 - m) + fake * m |
|
|
| def depth(self, x=None, z=None, return_z_depth=False): |
| """ |
| Compute the depth head's output |
| |
| Args: |
| x (torch.Tensor, optional): Input B3HW tensor. Defaults to None. |
| z (list, optional): High and Low level features of the encoder. |
| Defaults to None. |
| |
| Returns: |
| torch.Tensor: B1HW tensor of depth predictions |
| """ |
| assert x is not None or z is not None |
| assert not (x is not None and z is not None) |
| if z is None: |
| z = self.encode(x) |
| depth, z_depth = self.decoders["d"](z) |
|
|
| if depth.shape[1] > 1: |
| depth = torch.argmax(depth, dim=1) |
| depth = depth / depth.max() |
|
|
| if return_z_depth: |
| return depth, z_depth |
|
|
| return depth |
|
|
    def load_val_painter(self):
        """
        Loads a validation painter if available in opts.val.val_painter

        Returns:
            bool: operation success status
        """
        try:
            # every precondition below is asserted; any failure is caught by
            # the broad except and turns into a False return — the load is
            # deliberately best-effort and never raises to the caller
            assert self.opts.val.val_painter

            # the checkpoint must be an existing file
            ckpt_path = Path(self.opts.val.val_painter).resolve()
            assert ckpt_path.exists()

            assert ckpt_path.is_file()

            # training opts are expected two levels up from the checkpoint:
            # <run_dir>/opts.yaml next to <run_dir>/<ckpt_dir>/<ckpt_file>
            opts_path = ckpt_path.parent.parent / "opts.yaml"
            assert opts_path.exists()

            with opts_path.open("r") as f:
                val_painter_opts = Dict(yaml.safe_load(f))

            # load weights onto this generator's device
            state_dict = torch.load(ckpt_path, map_location=self.device)

            # fresh painter built from the checkpoint's own opts
            painter = create_painter(val_painter_opts)

            # the checkpoint stores the whole generator; strip the "painter."
            # qualifier so the keys address the painter sub-module directly.
            # NOTE(review): str.replace removes ALL occurrences, not only the
            # prefix — fine unless a key contains "painter." mid-string; confirm
            painter.load_state_dict(
                {k.replace("painter.", ""): v for k, v in state_dict["G"].items()}
            )

            # validation-only: eval mode, same device as self, gradients frozen
            device = next(self.parameters()).device
            self.painter = painter.eval().to(device)

            for p in self.painter.parameters():
                p.requires_grad = False

            print(" - Loaded validation-only painter")
            return True

        except Exception as e:
            # deliberate broad catch: a missing or incompatible validation
            # painter should not crash the run — log and report failure
            print(traceback.format_exc())
            print(e)
            print(">>> WARNING: error (^) in load_val_painter, aborting.")
            return False
|
|