| from tile_methods.abstractdiffusion import AbstractDiffusion |
| from tile_utils.utils import * |
| import torch.nn.functional as F |
| import random |
| from copy import deepcopy |
| import inspect |
| from modules import sd_samplers_common |
|
|
|
|
class DemoFusion(AbstractDiffusion):
    """
        DemoFusion Implementation
        https://arxiv.org/abs/2311.16973

        Every denoising step is evaluated twice and the two results are blended:
          - a "local" pass over (optionally jittered) sliding-window tiles, and
          - a "global" pass over strided ("dilated") sub-samplings of the whole latent.
        The blend factors follow a cosine schedule of the current step.

        Attributes such as ``self.p``, ``self.sampler``, ``self.w``, ``self.h``,
        ``self.window_size``, ``self.sig`` and ``self.is_kdiff`` are read here but set
        elsewhere (presumably by AbstractDiffusion or the calling script -- TODO confirm).
    """

    def __init__(self, p:Processing, *args, **kwargs):
        super().__init__(p, *args, **kwargs)
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert p.sampler_name != 'UniPC', 'Demofusion is not compatible with UniPC!'

    def hook(self):
        """
            Install the DemoFusion hooks on the sampler.

            The original CFG-denoiser forward is saved as ``forward_ori`` on the denoiser
            and the original inner-model forward is saved on ``self.sampler_forward``;
            the CFG-denoiser forward is then replaced by ``forward_one_step``.
        """
        steps, self.t_enc = sd_samplers_common.setup_img2img_steps(self.p, None)

        self.sampler.model_wrap_cfg.forward_ori = self.sampler.model_wrap_cfg.forward
        self.sampler_forward = self.sampler.model_wrap_cfg.inner_model.forward
        self.sampler.model_wrap_cfg.forward = self.forward_one_step
        # The bare annotations below are not evaluated at runtime; they only serve as type hints.
        if self.is_kdiff:
            self.sampler: KDiffusionSampler
            self.sampler.model_wrap_cfg: CFGDenoiserKDiffusion
            self.sampler.model_wrap_cfg.inner_model: Union[CompVisDenoiser, CompVisVDenoiser]
        else:
            self.sampler: CompVisSampler
            self.sampler.model_wrap_cfg: CFGDenoiserTimesteps
            self.sampler.model_wrap_cfg.inner_model: Union[CompVisTimestepsDenoiser, CompVisTimestepsVDenoiser]
            # The timestep table is only read by the non k-diffusion branch of forward_one_step.
            self.timesteps = self.sampler.get_timesteps(self.p, steps)

    @staticmethod
    def unhook():
        """ Restore ``shared.sd_model.apply_model`` if a saved original is present. """
        if hasattr(shared.sd_model, 'apply_model_ori'):
            shared.sd_model.apply_model = shared.sd_model.apply_model_ori
            del shared.sd_model.apply_model_ori

    def reset_buffer(self, x_in:Tensor):
        """ Delegate to the parent implementation (no DemoFusion-specific buffers to reset). """
        super().reset_buffer(x_in)

    def repeat_tensor(self, x:Tensor, n:int) -> Tensor:
        ''' repeat the tensor on its first dim '''
        if n == 1: return x
        B = x.shape[0]
        r_dims = len(x.shape) - 1
        if B == 1:
            # A batch of one can be broadcast with expand() instead of being copied.
            shape = [n] + [-1] * r_dims
            return x.expand(shape)
        else:
            shape = [n] + [1] * r_dims
            return x.repeat(shape)

    def repeat_cond_dict(self, cond_in:CondDict, bboxes, mode) -> CondDict:
        '''
            repeat all tensors in cond_dict on its first dim (for a batch of tiles), returns a new object

            mode == 0: ``bboxes`` are objects exposing ``.slicer`` (local tiles)
            otherwise: ``bboxes`` are ``(x, y)`` offsets of strided global views
        '''
        n_rep = len(bboxes)

        tcond = self.get_tcond(cond_in)
        tcond = self.repeat_tensor(tcond, n_rep)

        icond = self.get_icond(cond_in)
        if icond.shape[2:] == (self.h, self.w):
            # The image cond has the latent's spatial size, so it is cut the same way as the latent.
            if mode == 0:
                if self.p.random_jitter:
                    # Tile coordinates refer to the jitter-padded latent; pad the cond to match.
                    jitter_range = self.jitter_range
                    icond = F.pad(icond, (jitter_range, jitter_range, jitter_range, jitter_range), 'constant', value=0)
                icond = torch.cat([icond[bbox.slicer] for bbox in bboxes], dim=0)
            else:
                scale = self.p.current_scale_num
                icond = torch.cat([icond[:, :, bbox[1]::scale, bbox[0]::scale] for bbox in bboxes], dim=0)
        else:
            icond = self.repeat_tensor(icond, n_rep)

        vcond = self.get_vcond(cond_in)
        if vcond is not None:
            vcond = self.repeat_tensor(vcond, n_rep)
        return self.make_cond_dict(cond_in, tcond, icond, vcond)

    def global_split_bboxes(self):
        """
            Return the ``(x, y)`` start offsets of all strided global views, row-major.

            With ``self.p.mixture`` enabled the list is returned twice in a row;
            sample_one_step feeds the first half from the plain latent and the
            second half from the Gaussian-filtered one.
        """
        scale = self.p.current_scale_num
        bbox_list = [(x, y) for y in range(scale) for x in range(scale)]
        return bbox_list + bbox_list if self.p.mixture else bbox_list

    @staticmethod
    def _sample_jitter(start:int, size:int, limit:int, jitter_range:int) -> int:
        """ Draw a random shift for one axis of a tile; the range is one-sided for tiles touching a border. """
        at_start = start == 0
        at_end = start + size == limit
        if at_start and at_end:
            return 0
        if at_start:
            return random.randint(-jitter_range, 0)
        if at_end:
            return random.randint(0, jitter_range)
        return random.randint(-jitter_range, jitter_range)

    def split_bboxes_jitter(self, w_l:int, h_l:int, tile_w:int, tile_h:int, overlap:int=16, init_weight:Union[Tensor, float]=1.0) -> Tuple[List[BBox], Tensor]:
        """
            Split a ``w_l`` x ``h_l`` area into overlapping tiles, optionally jittered.

            Also sets ``self.jitter_range`` (0 when ``self.p.random_jitter`` is off).
            With jitter enabled, each tile origin is shifted by ``jitter + jitter_range``,
            i.e. the coordinates address a latent padded by ``jitter_range`` on every side.

            ``init_weight`` is accepted for signature compatibility but not used.
            Returns ``(bbox_list, None)``; no weight tensor is produced.
        """
        # max(1, ...) also covers negative counts that ceil() yields when the area
        # is much smaller than the overlap, so at least one tile is always emitted.
        cols = max(1, math.ceil((w_l - overlap) / (tile_w - overlap)))
        rows = max(1, math.ceil((h_l - overlap) / (tile_h - overlap)))
        dx = (w_l - tile_w) / (cols - 1) if cols > 1 else 0
        dy = (h_l - tile_h) / (rows - 1) if rows > 1 else 0

        # The jitter range does not depend on the tile, so compute it once.
        if self.p.random_jitter:
            self.jitter_range = min(max((min(self.w, self.h) - self.stride) // 4, 0), min(int(self.window_size / 2), int(self.overlap / 2)))
        else:
            self.jitter_range = 0
        jitter_range = self.jitter_range

        bbox_list: List[BBox] = []
        for row in range(rows):
            for col in range(cols):
                h = min(int(row * dy), h_l - tile_h)
                w = min(int(col * dx), w_l - tile_w)
                if self.p.random_jitter:
                    # Width is drawn before height to keep the RNG consumption order stable.
                    w_jitter = self._sample_jitter(w, tile_w, w_l, jitter_range)
                    h_jitter = self._sample_jitter(h, tile_h, h_l, jitter_range)
                    h += h_jitter + jitter_range
                    w += w_jitter + jitter_range

                bbox_list.append(BBox(w, h, tile_w, tile_h))
        return bbox_list, None

    @grid_bbox
    def get_views(self, overlap:int, tile_bs:int, tile_bs_g:int):
        """
            Build the batched local tile list and the batched global view list.

            ``tile_bs`` / ``tile_bs_g`` are upper bounds on the batch size of the local /
            global pass; the effective batch sizes are re-balanced so batches are even.
        """
        self.enable_grid_bbox = True
        self.tile_w = self.window_size
        self.tile_h = self.window_size

        # Keep overlap in [0, window_size - 4] so that the stride is at least 4.
        self.overlap = max(0, min(overlap, self.window_size - 4))
        self.stride = max(4, self.window_size - self.overlap)

        # local tiles
        bboxes, _ = self.split_bboxes_jitter(self.w, self.h, self.tile_w, self.tile_h, self.overlap, self.get_tile_weights())
        self.num_tiles = len(bboxes)
        self.num_batches = math.ceil(self.num_tiles / tile_bs)
        self.tile_bs = math.ceil(len(bboxes) / self.num_batches)
        self.batched_bboxes = [bboxes[i*self.tile_bs:(i+1)*self.tile_bs] for i in range(self.num_batches)]

        # global (strided) views
        global_bboxes = self.global_split_bboxes()
        self.global_num_tiles = len(global_bboxes)
        self.global_num_batches = math.ceil(self.global_num_tiles / tile_bs_g)
        self.global_tile_bs = math.ceil(len(global_bboxes) / self.global_num_batches)
        self.global_batched_bboxes = [global_bboxes[i*self.global_tile_bs:(i+1)*self.global_tile_bs] for i in range(self.global_num_batches)]

    def gaussian_kernel(self, kernel_size=3, sigma=1.0, channels=3):
        """ Return a normalized 2D Gaussian kernel of shape ``(channels, 1, kernel_size, kernel_size)``. """
        x_coord = torch.arange(kernel_size, device=devices.device)
        gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
        gaussian_1d = gaussian_1d / gaussian_1d.sum()
        gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
        kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)

        return kernel

    def gaussian_filter(self, latents, kernel_size=3, sigma=1.0):
        """ Blur ``latents`` per channel (grouped convolution) with a Gaussian kernel. """
        channels = latents.shape[1]
        kernel = self.gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
        blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)

        return blurred_latents

    # ↓↓↓ kernel hijacks ↓↓↓

    @torch.no_grad()
    @keep_signature
    def forward_one_step(self, x_in, sigma, **kwarg):
        """
            Replacement for the CFG denoiser's forward (installed by ``hook``).

            Mixes the incoming latent with a re-noised copy of ``self.p.x`` (skip residual,
            weighted by ``cosine_factor ** cosine_scale_1``), pads it by the jitter range,
            runs the original CFG forward with ``sample_one_step`` installed as the inner
            model's forward, and crops the padding off the result.
        """
        if self.is_kdiff:
            x_noisy = self.p.x + self.p.noise * sigma[0]
        else:
            alphas_cumprod = self.p.sd_model.alphas_cumprod
            alpha_t = alphas_cumprod[self.timesteps[self.t_enc - self.p.current_step]]
            x_noisy = self.p.x * torch.sqrt(alpha_t) + self.p.noise * torch.sqrt(1 - alpha_t)

        self.cosine_factor = 0.5 * (1 + torch.cos(torch.pi * torch.tensor(((self.p.current_step + 1) / (self.t_enc + 1)))))

        c1 = self.cosine_factor ** self.p.cosine_scale_1

        x_in = x_in * (1 - c1) + x_noisy * c1

        jitter_range = self.jitter_range if self.p.random_jitter else 0
        x_in_ = F.pad(x_in, (jitter_range, jitter_range, jitter_range, jitter_range), 'constant', value=0)
        _, _, H, W = x_in.shape

        self.sampler.model_wrap_cfg.inner_model.forward = self.sample_one_step
        self.repeat_3 = False
        try:
            x_out = self.sampler.model_wrap_cfg.forward_ori(x_in_, sigma, **kwarg)
        finally:
            # Always put the original inner forward back, even if sampling raised.
            self.sampler.model_wrap_cfg.inner_model.forward = self.sampler_forward

        return x_out[:, :, jitter_range:jitter_range+H, jitter_range:jitter_range+W]

    @torch.no_grad()
    @keep_signature
    def sample_one_step(self, x_in, sigma, cond):
        """
            Evaluate the model on ``x_in`` tile-wise and return the local/global blend.

            ``x_in`` is expected to be already padded by ``self.jitter_range`` on each
            spatial side (both callers in this class pad it before calling).
            Relies on ``self.cosine_factor`` having been set by the caller.
            Returns ``x_in`` unchanged if the job is interrupted during the local pass.
        """
        assert LatentDiffusion.apply_model

        def repeat_func_1(x_tile:Tensor, bboxes, mode=0) -> Tensor:
            sigma_tile = self.repeat_tensor(sigma, len(bboxes))
            cond_tile = self.repeat_cond_dict(cond, bboxes, mode)
            return self.sampler_forward(x_tile, sigma_tile, cond=cond_tile)

        def repeat_func_2(x_tile:Tensor, bboxes, mode=0) -> Tuple[Tensor, Tensor]:
            n_rep = len(bboxes)
            ts_tile = self.repeat_tensor(sigma, n_rep)
            if isinstance(cond, dict):
                cond_tile = self.repeat_cond_dict(cond, bboxes, mode)
            else:
                cond_tile = self.repeat_tensor(cond, n_rep)
            return self.sampler_forward(x_tile, ts_tile, cond=cond_tile)

        def repeat_func_3(x_tile:Tensor, bboxes, mode=0):
            sigma_in_tile = sigma.repeat(len(bboxes))
            cond_out = self.repeat_cond_dict(cond, bboxes, mode)
            x_tile_out = shared.sd_model.apply_model(x_tile, sigma_in_tile, cond=cond_out)
            return x_tile_out

        if self.repeat_3:
            # One-shot flag set by get_noise(): go through apply_model directly.
            repeat_func = repeat_func_3
            self.repeat_3 = False
        elif self.is_kdiff:
            repeat_func = repeat_func_1
        else:
            repeat_func = repeat_func_2
        N, _, _, _ = x_in.shape

        # ---- local pass: average the model output over the sliding-window tiles ----
        self.x_buffer = torch.zeros_like(x_in)
        self.weights = torch.zeros_like(x_in)

        for batch_id, bboxes in enumerate(self.batched_bboxes):
            if state.interrupted: return x_in
            x_tile = torch.cat([x_in[bbox.slicer] for bbox in bboxes], dim=0)
            x_tile_out = repeat_func(x_tile, bboxes)
            # de-batching: tile i occupies rows [i*N, (i+1)*N) of the batched output
            for i, bbox in enumerate(bboxes):
                self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :]
                self.weights[bbox.slicer] += 1
        # Positions no tile covered (possible with random jitter) would divide by zero;
        # counts are non-negative, so clamping to 1 only affects the zeros.
        self.weights = self.weights.clamp(min=1)

        x_local = self.x_buffer / self.weights

        self.x_buffer = torch.zeros_like(self.x_buffer)
        self.weights = torch.zeros_like(self.weights)

        # ---- global pass input: optionally low-pass filter the latent ----
        if self.p.gaussian_filter:
            std_, mean_ = x_in.std(), x_in.mean()
            c3 = 0.99 * self.cosine_factor ** self.p.cosine_scale_3 + 1e-2
            x_in_g = self.gaussian_filter(x_in, kernel_size=(2*self.p.current_scale_num-1), sigma=self.sig*c3)
            # Restore the statistics of the unfiltered latent.
            x_in_g = (x_in_g - x_in_g.mean()) / x_in_g.std() * std_ + mean_
        else:
            # Without the filter the global views are taken from the latent as-is.
            x_in_g = x_in

        x_global = torch.zeros_like(x_local)
        jitter_range = self.jitter_range
        scale = self.p.current_scale_num
        # Separate bounds per axis so that non-square latents are covered completely.
        end_h = x_global.shape[2] - jitter_range
        end_w = x_global.shape[3] - jitter_range

        def dilated_slicer(bbox):
            ''' index of the strided view starting at offset (x, y) = bbox inside the padded latent '''
            return (
                slice(None),
                slice(None),
                slice(bbox[1] + jitter_range, end_h, scale),
                slice(bbox[0] + jitter_range, end_w, scale),
            )

        # With mixture on, the view list holds every offset twice: the first half is
        # sampled from the plain latent, the second half from the filtered one.
        half = self.global_num_tiles // 2
        tiles_seen = 0

        if not hasattr(self.p.sd_model, 'apply_model_ori'):
            self.p.sd_model.apply_model_ori = self.p.sd_model.apply_model
        self.p.sd_model.apply_model = self.apply_model_hijack
        try:
            for batch_id, bboxes in enumerate(self.global_batched_bboxes):
                slicers = [dilated_slicer(bbox) for bbox in bboxes]
                if self.p.mixture:
                    views = [x_in[s] if tiles_seen + idx < half else x_in_g[s] for idx, s in enumerate(slicers)]
                else:
                    views = [x_in_g[s] for s in slicers]
                tiles_seen += len(bboxes)

                x_global_i = repeat_func(torch.cat(views, dim=0), bboxes, mode=1)

                for idx, s in enumerate(slicers):
                    x_global[s] += x_global_i[idx*N:(idx+1)*N, :, :, :]
        finally:
            # Always undo the temporary apply_model replacement, even on error.
            self.p.sd_model.apply_model = self.p.sd_model.apply_model_ori

        if self.p.mixture:
            # Two passes were accumulated into x_global; average them.
            self.x_buffer += x_global / 2
        else:
            self.x_buffer += x_global
        self.weights += 1

        x_global = self.x_buffer / self.weights
        c2 = self.cosine_factor ** self.p.cosine_scale_2
        self.x_buffer = x_local * (1 - c2) + x_global * c2

        return self.x_buffer

    @torch.no_grad()
    @keep_signature
    def apply_model_hijack(self, x_in:Tensor, t_in:Tensor, cond:CondDict):
        """ Stand-in for ``sd_model.apply_model`` during the global pass; forwards to the saved original. """
        assert LatentDiffusion.apply_model

        x_tile_out = self.p.sd_model.apply_model_ori(x_in, t_in, cond)
        return x_tile_out

    def get_noise(self, x_in:Tensor, sigma_in:Tensor, cond_in:Dict[str, Tensor], step:int) -> Tensor:
        """
            Run one tiled evaluation through ``shared.sd_model.apply_model`` (via the
            ``repeat_3`` flag) and return the result cropped to ``x_in``'s spatial size.

            ``step`` is not used; the schedule position is read from ``self.p.current_step``.
        """
        cond_in_original = cond_in.copy()
        self.repeat_3 = True
        self.cosine_factor = 0.5 * (1 + torch.cos(torch.pi * torch.tensor(((self.p.current_step + 1) / (self.t_enc + 1)))))
        jitter_range = self.jitter_range
        _, _, H, W = x_in.shape
        x_in_ = F.pad(x_in, (jitter_range, jitter_range, jitter_range, jitter_range), 'constant', value=0)
        return self.sample_one_step(x_in_, sigma_in, cond_in_original)[:, :, jitter_range:jitter_range+H, jitter_range:jitter_range+W]
|
|