# NOTE(review): the original three lines here ("Spaces:", "Configuration error",
# "Configuration error") were page-scrape residue from the hosting viewer,
# not part of this Python module.
| import numpy as np | |
| import yaml | |
| import argparse | |
| import math | |
| import torch | |
| from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.utils import * | |
| from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.encoder_decoder import AutoencoderKL | |
| # from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.transmodel import TransModel | |
| from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.uncond_unet import Unet | |
| from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.data import * | |
| from fvcore.common.config import CfgNode | |
| from pathlib import Path | |
def load_conf(config_file, conf=None):
    """Load a YAML config file and merge its top-level keys into *conf*.

    Args:
        config_file: Path to a YAML configuration file.
        conf: Optional dict to update in place. A fresh dict is created
            when omitted.

    Returns:
        dict: The merged configuration.
    """
    # BUGFIX: the previous signature used a mutable default (`conf={}`),
    # which is shared across calls — keys loaded by one call leaked into
    # every later call. A None sentinel keeps the interface compatible
    # while giving each call its own dict.
    if conf is None:
        conf = {}
    with open(config_file) as f:
        exp_conf = yaml.load(f, Loader=yaml.FullLoader)
    for k, v in exp_conf.items():
        conf[k] = v
    return conf
def prepare_args(ckpt_path, sampling_timesteps=1):
    """Assemble the ``argparse.Namespace`` consumed by ``DiffusionEdge``.

    Loads the ``default.yaml`` shipped alongside this module and pairs it
    with the given checkpoint path and sampler step count.
    """
    default_cfg = load_conf(Path(__file__).parent / "default.yaml")
    return argparse.Namespace(
        cfg=default_cfg,
        pre_weight=ckpt_path,
        sampling_timesteps=sampling_timesteps,
    )
class DiffusionEdge:
    """Latent-diffusion edge detector.

    Builds a first-stage ``AutoencoderKL`` plus a conditional U-Net
    denoiser from a ``CfgNode`` config, loads pretrained weights
    (optionally the EMA copy), and exposes whole-image or sliding-window
    sampling through ``__call__``.
    """

    def __init__(self, args) -> None:
        """Construct the model stack and load weights from ``args.pre_weight``.

        Args:
            args: Namespace with ``cfg`` (config dict), ``pre_weight``
                (checkpoint path) and ``sampling_timesteps`` (int).
        """
        self.cfg = CfgNode(args.cfg)
        # Fixed seeds so repeated runs produce identical edge maps.
        torch.manual_seed(42)
        np.random.seed(42)
        model_cfg = self.cfg.model
        first_stage_cfg = model_cfg.first_stage
        first_stage_model = AutoencoderKL(
            ddconfig=first_stage_cfg.ddconfig,
            lossconfig=first_stage_cfg.lossconfig,
            embed_dim=first_stage_cfg.embed_dim,
            ckpt_path=first_stage_cfg.ckpt_path,
        )
        if model_cfg.model_name == 'cond_unet':
            # Deliberately shadows the unconditional Unet imported at module level.
            from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.mask_cond_unet import Unet
            unet_cfg = model_cfg.unet
            unet = Unet(
                dim=unet_cfg.dim,
                channels=unet_cfg.channels,
                dim_mults=unet_cfg.dim_mults,
                learned_variance=unet_cfg.get('learned_variance', False),
                out_mul=unet_cfg.out_mul,
                cond_in_dim=unet_cfg.cond_in_dim,
                cond_dim=unet_cfg.cond_dim,
                cond_dim_mults=unet_cfg.cond_dim_mults,
                window_sizes1=unet_cfg.window_sizes1,
                window_sizes2=unet_cfg.window_sizes2,
                fourier_scale=unet_cfg.fourier_scale,
                cfg=unet_cfg,
            )
        else:
            raise NotImplementedError
        if model_cfg.model_type == 'const_sde':
            from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.ddm_const_sde import LatentDiffusion
        else:
            # BUGFIX: corrected typo in the original message ("surportted").
            raise NotImplementedError(f'{model_cfg.model_type} is not supported!')
        self.model = LatentDiffusion(
            model=unet,
            auto_encoder=first_stage_model,
            train_sample=model_cfg.train_sample,
            image_size=model_cfg.image_size,
            timesteps=model_cfg.timesteps,
            sampling_timesteps=args.sampling_timesteps,
            loss_type=model_cfg.loss_type,
            objective=model_cfg.objective,
            scale_factor=model_cfg.scale_factor,
            scale_by_std=model_cfg.scale_by_std,
            scale_by_softsign=model_cfg.scale_by_softsign,
            default_scale=model_cfg.get('default_scale', False),
            input_keys=model_cfg.input_keys,
            ckpt_path=model_cfg.ckpt_path,
            ignore_keys=model_cfg.ignore_keys,
            only_model=model_cfg.only_model,
            start_dist=model_cfg.start_dist,
            perceptual_weight=model_cfg.perceptual_weight,
            use_l1=model_cfg.get('use_l1', True),
            cfg=model_cfg,
        )
        self.cfg.sampler.ckpt_path = args.pre_weight
        data = torch.load(self.cfg.sampler.ckpt_path, map_location="cpu")
        if self.cfg.sampler.use_ema:
            # The EMA copy is stored under "ema_model."-prefixed keys;
            # strip the prefix so it loads into the plain model.
            sd = data['ema']
            new_sd = {}
            for k in sd.keys():
                if k.startswith("ema_model."):
                    new_k = k[len("ema_model."):]
                    new_sd[new_k] = sd[k]
            sd = new_sd
            self.model.load_state_dict(sd)
        else:
            self.model.load_state_dict(data['model'])
        if 'scale_factor' in data['model']:
            self.model.scale_factor = data['model']['scale_factor']
        self.model.eval()
        self.device = "cpu"

    def to(self, device):
        """Move the underlying model to *device*; return self for chaining."""
        self.model.to(device)
        self.device = device
        return self

    def __call__(self, image, batch_size=8):
        """Detect edges on an NxCxHxW image tensor.

        The input is shifted from [0, 1] to [-1, 1] before conditioning.
        The sampling strategy is picked from ``cfg.sampler.sample_type``
        ('whole' or 'slide'); any other value returns None.
        """
        image = normalize_to_neg_one_to_one(image).to(self.device)
        mask = None
        if self.cfg.sampler.sample_type == 'whole':
            return self.whole_sample(image, raw_size=image.shape[2:], mask=mask)
        elif self.cfg.sampler.sample_type == 'slide':
            return self.slide_sample(image, crop_size=self.cfg.sampler.get('crop_size', [320, 320]),
                                     stride=self.cfg.sampler.stride, mask=mask, bs=batch_size)

    def whole_sample(self, inputs, raw_size, mask=None):
        """Sample edges on the full image at a fixed 416x416, then resize back."""
        inputs = F.interpolate(inputs, size=(416, 416), mode='bilinear', align_corners=True)
        seg_logits = self.model.sample(batch_size=inputs.shape[0], cond=inputs, mask=mask)
        seg_logits = F.interpolate(seg_logits, size=raw_size, mode='bilinear', align_corners=True)
        return seg_logits

    def slide_sample(self, inputs, crop_size, stride, mask=None, bs=8):
        """Inference by sliding-window with overlap.

        If h_crop > h_img or w_crop > w_img, the small patch will be used to
        decode without padding.

        Args:
            inputs (tensor): NxCxHxW batch of conditioning images.
            crop_size (tuple): (h_crop, w_crop) window size.
            stride (tuple): (h_stride, w_stride) window step.
            mask: optional mask forwarded to the diffusion sampler.
            bs (int): max number of windows per diffusion call (memory bound).

        Returns:
            Tensor: seg_logits averaged over overlapping windows,
            shape (N, 1, H, W).
        """
        h_stride, w_stride = stride
        h_crop, w_crop = crop_size
        batch_size, _, h_img, w_img = inputs.size()
        out_channels = 1
        # Number of window positions needed to cover each axis (ceil division).
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
        count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
        crop_imgs = []
        x1s = []
        x2s = []
        y1s = []
        y2s = []
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                # Shift the window back so it stays fully inside the image.
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_imgs.append(inputs[:, :, y1:y2, x1:x2])
                x1s.append(x1)
                x2s.append(x2)
                y1s.append(y1)
                y2s.append(y2)
        crop_imgs = torch.cat(crop_imgs, dim=0)
        num_windows = crop_imgs.shape[0]
        # Run the sampler in chunks of at most `bs` windows. The previous
        # code special-cased the final chunk and carried a no-op `bs = bs`;
        # Python slicing already clamps at the end, so one branch suffices.
        crop_seg_logits_list = []
        for i in range(math.ceil(num_windows / bs)):
            chunk = crop_imgs[bs * i:bs * (i + 1), ...]
            crop_seg_logits_list.append(
                self.model.sample(batch_size=chunk.shape[0], cond=chunk, mask=mask))
        crop_seg_logits = torch.cat(crop_seg_logits_list, dim=0)
        # Paste each window back (zero-padded to full size) and count the
        # overlaps so overlapping regions can be averaged.
        # NOTE(review): zip pairs the concatenated window outputs with the
        # per-window coordinate lists; the counts line up only when
        # batch_size == 1 — confirm against callers.
        for crop_seg_logit, x1, x2, y1, y2 in zip(crop_seg_logits, x1s, x2s, y1s, y2s):
            preds += F.pad(crop_seg_logit,
                           (int(x1), int(preds.shape[3] - x2), int(y1),
                            int(preds.shape[2] - y2)))
            count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0
        seg_logits = preds / count_mat
        return seg_logits