import argparse
import math
from pathlib import Path

import numpy as np
import torch
import yaml
from fvcore.common.config import CfgNode

# `F` (torch.nn.functional) and `normalize_to_neg_one_to_one`, both used
# below, are expected to come from the utils wildcard import.
from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.utils import *
from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.encoder_decoder import AutoencoderKL
from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.uncond_unet import Unet
from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.data import *


def load_conf(config_file, conf=None):
    # A mutable default argument ({}) would be shared across calls; create a
    # fresh dict when none is supplied.
    if conf is None:
        conf = {}
    with open(config_file) as f:
        exp_conf = yaml.load(f, Loader=yaml.FullLoader)
    conf.update(exp_conf)
    return conf


def prepare_args(ckpt_path, sampling_timesteps=1):
    return argparse.Namespace(
        cfg=load_conf(Path(__file__).parent / "default.yaml"),
        pre_weight=ckpt_path,
        sampling_timesteps=sampling_timesteps,
    )
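

# Rough shape of the YAML configuration, inferred from the attribute accesses
# in this file (values are placeholders, not the shipped defaults):
#
#   model:
#     model_name: cond_unet        # only 'cond_unet' is implemented
#     model_type: const_sde        # only 'const_sde' is implemented
#     first_stage:                 # AutoencoderKL kwargs
#       ddconfig: ...
#       lossconfig: ...
#       embed_dim: ...
#       ckpt_path: ...
#     unet: ...                    # mask-conditioned U-Net kwargs
#     timesteps: ...               # plus the LatentDiffusion kwargs below
#   sampler:
#     sample_type: slide           # 'whole' or 'slide'
#     stride: ...
#     use_ema: ...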


class DiffusionEdge:
    def __init__(self, args) -> None:
        self.cfg = CfgNode(args.cfg)
        # Fix seeds so repeated runs sample the same edges.
        torch.manual_seed(42)
        np.random.seed(42)

        model_cfg = self.cfg.model
        # First stage: the KL autoencoder that maps edge maps to and from latents.
        first_stage_cfg = model_cfg.first_stage
        first_stage_model = AutoencoderKL(
            ddconfig=first_stage_cfg.ddconfig,
            lossconfig=first_stage_cfg.lossconfig,
            embed_dim=first_stage_cfg.embed_dim,
            ckpt_path=first_stage_cfg.ckpt_path,
        )

        # Denoiser: the conditional U-Net (this import shadows the
        # unconditional Unet imported at module level).
        if model_cfg.model_name == 'cond_unet':
            from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.mask_cond_unet import Unet
            unet_cfg = model_cfg.unet
            unet = Unet(
                dim=unet_cfg.dim,
                channels=unet_cfg.channels,
                dim_mults=unet_cfg.dim_mults,
                learned_variance=unet_cfg.get('learned_variance', False),
                out_mul=unet_cfg.out_mul,
                cond_in_dim=unet_cfg.cond_in_dim,
                cond_dim=unet_cfg.cond_dim,
                cond_dim_mults=unet_cfg.cond_dim_mults,
                window_sizes1=unet_cfg.window_sizes1,
                window_sizes2=unet_cfg.window_sizes2,
                fourier_scale=unet_cfg.fourier_scale,
                cfg=unet_cfg,
            )
        else:
            raise NotImplementedError(f'{model_cfg.model_name} is not supported!')

        if model_cfg.model_type == 'const_sde':
            from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.ddm_const_sde import LatentDiffusion
        else:
            raise NotImplementedError(f'{model_cfg.model_type} is not supported!')

        self.model = LatentDiffusion(
            model=unet,
            auto_encoder=first_stage_model,
            train_sample=model_cfg.train_sample,
            image_size=model_cfg.image_size,
            timesteps=model_cfg.timesteps,
            sampling_timesteps=args.sampling_timesteps,
            loss_type=model_cfg.loss_type,
            objective=model_cfg.objective,
            scale_factor=model_cfg.scale_factor,
            scale_by_std=model_cfg.scale_by_std,
            scale_by_softsign=model_cfg.scale_by_softsign,
            default_scale=model_cfg.get('default_scale', False),
            input_keys=model_cfg.input_keys,
            ckpt_path=model_cfg.ckpt_path,
            ignore_keys=model_cfg.ignore_keys,
            only_model=model_cfg.only_model,
            start_dist=model_cfg.start_dist,
            perceptual_weight=model_cfg.perceptual_weight,
            use_l1=model_cfg.get('use_l1', True),
            cfg=model_cfg,
        )
        self.cfg.sampler.ckpt_path = args.pre_weight
        # Load the pretrained weights, preferring the EMA copy when configured.
        data = torch.load(self.cfg.sampler.ckpt_path, map_location="cpu")
        if self.cfg.sampler.use_ema:
            sd = data['ema']
            new_sd = {}
            for k in sd.keys():
                if k.startswith("ema_model."):
                    # Strip the "ema_model." prefix (10 characters).
                    new_k = k[10:]
                    new_sd[new_k] = sd[k]
            sd = new_sd
            self.model.load_state_dict(sd)
        else:
            self.model.load_state_dict(data['model'])
        if 'scale_factor' in data['model']:
            self.model.scale_factor = data['model']['scale_factor']

        self.model.eval()
        self.device = "cpu"

    def to(self, device):
        self.model.to(device)
        self.device = device
        return self
    def __call__(self, image, batch_size=8):
        # Rescale the input image from [0, 1] to [-1, 1] before sampling.
        image = normalize_to_neg_one_to_one(image).to(self.device)
        mask = None
        if self.cfg.sampler.sample_type == 'whole':
            return self.whole_sample(image, raw_size=image.shape[2:], mask=mask)
        elif self.cfg.sampler.sample_type == 'slide':
            return self.slide_sample(image, crop_size=self.cfg.sampler.get('crop_size', [320, 320]),
                                     stride=self.cfg.sampler.stride, mask=mask, bs=batch_size)
        else:
            raise NotImplementedError(f'{self.cfg.sampler.sample_type} is not supported!')

    def whole_sample(self, inputs, raw_size, mask=None):
        # Sample at a fixed 416x416 resolution, then resize the logits back
        # to the original image size.
        inputs = F.interpolate(inputs, size=(416, 416), mode='bilinear', align_corners=True)
        seg_logits = self.model.sample(batch_size=inputs.shape[0], cond=inputs, mask=mask)
        seg_logits = F.interpolate(seg_logits, size=raw_size, mode='bilinear', align_corners=True)
        return seg_logits

    def slide_sample(self, inputs, crop_size, stride, mask=None, bs=8):
        """Inference by sliding window with overlap.

        If h_crop > h_img or w_crop > w_img, the small patch will be used to
        decode without padding.

        Args:
            inputs (Tensor): tensor of shape NxCxHxW containing all images
                in the batch.
            crop_size (tuple[int, int]): (height, width) of each window.
            stride (tuple[int, int]): (height, width) step between windows.
            mask (Tensor, optional): optional mask forwarded to the sampler.
            bs (int): number of windows sampled per forward pass.

        Returns:
            Tensor: The segmentation results, seg_logits from the model for
                each input image.
        """
        h_stride, w_stride = stride
        h_crop, w_crop = crop_size
        batch_size, _, h_img, w_img = inputs.size()
        out_channels = 1
        # Number of windows along each axis (ceil division over the strides).
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        # Accumulators for summed logits and per-pixel window counts.
        preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
        count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
        crop_imgs = []
        x1s, x2s, y1s, y2s = [], [], [], []
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                # Shift windows at the right/bottom borders back inside the
                # image so every crop keeps the full crop_size.
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_imgs.append(inputs[:, :, y1:y2, x1:x2])
                x1s.append(x1)
                x2s.append(x2)
                y1s.append(y1)
                y2s.append(y2)
        crop_imgs = torch.cat(crop_imgs, dim=0)

        # Sample the windows in chunks of at most `bs` to bound memory use;
        # Python slicing already clamps the final chunk to what remains.
        crop_seg_logits_list = []
        num_windows = crop_imgs.shape[0]
        for i in range(math.ceil(num_windows / bs)):
            crop_imgs_temp = crop_imgs[bs * i:bs * (i + 1), ...]
            crop_seg_logits = self.model.sample(batch_size=crop_imgs_temp.shape[0], cond=crop_imgs_temp, mask=mask)
            crop_seg_logits_list.append(crop_seg_logits)
        crop_seg_logits = torch.cat(crop_seg_logits_list, dim=0)

        # Paste each window's logits back at its location and average where
        # windows overlap.
        for crop_seg_logit, x1, x2, y1, y2 in zip(crop_seg_logits, x1s, x2s, y1s, y2s):
            preds += F.pad(crop_seg_logit,
                           (int(x1), int(preds.shape[3] - x2), int(y1),
                            int(preds.shape[2] - y2)))
            count_mat[:, :, y1:y2, x1:x2] += 1

        # Every pixel must be covered by at least one window.
        assert (count_mat == 0).sum() == 0
        seg_logits = preds / count_mat
        return seg_logits
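

# A minimal usage sketch (not part of the original module). The checkpoint
# path and input filename are placeholders.
if __name__ == "__main__":
    import torchvision.transforms.functional as TF
    from PIL import Image

    args = prepare_args("path/to/diffusion_edge_checkpoint.pt", sampling_timesteps=1)
    detector = DiffusionEdge(args).to("cuda" if torch.cuda.is_available() else "cpu")

    # Load an image as a 1x3xHxW float tensor in [0, 1]; __call__ rescales it
    # to [-1, 1] internally before sampling.
    image = TF.to_tensor(Image.open("input.png").convert("RGB")).unsqueeze(0)
    edge_logits = detector(image, batch_size=8)
    print(edge_logits.shape)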