extensions/multidiffusion-upscaler-for-automatic1111/tile_methods/demofusion.py
from tile_methods.abstractdiffusion import AbstractDiffusion
from tile_utils.utils import *

import torch.nn.functional as F
import random
from copy import deepcopy
import inspect

from modules import sd_samplers_common


class DemoFusion(AbstractDiffusion):
    """
    DemoFusion Implementation
    https://arxiv.org/abs/2311.16973
    """

    def __init__(self, p:Processing, *args, **kwargs):
        super().__init__(p, *args, **kwargs)
        assert p.sampler_name != 'UniPC', 'DemoFusion is not compatible with UniPC!'

    def hook(self):
        steps, self.t_enc = sd_samplers_common.setup_img2img_steps(self.p, None)
        self.sampler.model_wrap_cfg.forward_ori = self.sampler.model_wrap_cfg.forward
        self.sampler_forward = self.sampler.model_wrap_cfg.inner_model.forward
        self.sampler.model_wrap_cfg.forward = self.forward_one_step
        if self.is_kdiff:
            self.sampler: KDiffusionSampler
            self.sampler.model_wrap_cfg: CFGDenoiserKDiffusion
            self.sampler.model_wrap_cfg.inner_model: Union[CompVisDenoiser, CompVisVDenoiser]
        else:
            self.sampler: CompVisSampler
            self.sampler.model_wrap_cfg: CFGDenoiserTimesteps
            self.sampler.model_wrap_cfg.inner_model: Union[CompVisTimestepsDenoiser, CompVisTimestepsVDenoiser]
            self.timesteps = self.sampler.get_timesteps(self.p, steps)

    def unhook(self):
        if hasattr(shared.sd_model, 'apply_model_ori'):
            shared.sd_model.apply_model = shared.sd_model.apply_model_ori
            del shared.sd_model.apply_model_ori

    def reset_buffer(self, x_in:Tensor):
        super().reset_buffer(x_in)

    def repeat_tensor(self, x:Tensor, n:int) -> Tensor:
        ''' repeat the tensor on its first dim '''
        if n == 1: return x
        B = x.shape[0]
        r_dims = len(x.shape) - 1
        if B == 1:      # batch_size = 1 (not `tile_batch_size`)
            shape = [n] + [-1] * r_dims     # [N, -1, ...]
            return x.expand(shape)          # `expand` is much lighter than `tile`
        else:
            shape = [n] + [1] * r_dims      # [N, 1, ...]
            return x.repeat(shape)

    def repeat_cond_dict(self, cond_in:CondDict, bboxes, mode) -> CondDict:
        ''' repeat all tensors in cond_dict on their first dim (for a batch of tiles), returns a new object '''
        # n_repeat
        n_rep = len(bboxes)
        # txt cond
        tcond = self.get_tcond(cond_in)             # [B=1, L, D] => [B*N, L, D]
        tcond = self.repeat_tensor(tcond, n_rep)
        # img cond
        icond = self.get_icond(cond_in)
        if icond.shape[2:] == (self.h, self.w):     # img2img, [B=1, C, H, W]
            if mode == 0:
                # local views: crop the image cond to each (optionally jitter-padded) tile
                if self.p.random_jitter:
                    jitter_range = self.jitter_range
                    icond = F.pad(icond, (jitter_range, jitter_range, jitter_range, jitter_range), 'constant', value=0)
                icond = torch.cat([icond[bbox.slicer] for bbox in bboxes], dim=0)
            else:
                # global (dilated) views: take every `current_scale_num`-th pixel, offset by the bbox
                icond = torch.cat([icond[:, :, bbox[1]::self.p.current_scale_num, bbox[0]::self.p.current_scale_num] for bbox in bboxes], dim=0)
        else:                                       # txt2img, [B=1, C=5, H=1, W=1]
            icond = self.repeat_tensor(icond, n_rep)
        # vec cond (SDXL)
        vcond = self.get_vcond(cond_in)             # [B=1, D]
        if vcond is not None:
            vcond = self.repeat_tensor(vcond, n_rep)    # [B*N, D]
        return self.make_cond_dict(cond_in, tcond, icond, vcond)

    def global_split_bboxes(self):
        ''' offsets of the dilated (global) views: one per cell of a `current_scale_num` x `current_scale_num` grid '''
        cols = self.p.current_scale_num
        rows = cols
        bbox_list = []
        for row in range(rows):
            y = row
            for col in range(cols):
                x = col
                bbox = (x, y)
                bbox_list.append(bbox)
        # with `mixture` enabled, every view is processed twice: once on the plain latent, once on the blurred one
        return bbox_list + bbox_list if self.p.mixture else bbox_list

    def split_bboxes_jitter(self, w_l:int, h_l:int, tile_w:int, tile_h:int, overlap:int=16, init_weight:Union[Tensor, float]=1.0) -> Tuple[List[BBox], Tensor]:
        cols = math.ceil((w_l - overlap) / (tile_w - overlap))
        rows = math.ceil((h_l - overlap) / (tile_h - overlap))
        if rows == 0:
            rows = 1
        if cols == 0:
            cols = 1
        dx = (w_l - tile_w) / (cols - 1) if cols > 1 else 0
        dy = (h_l - tile_h) / (rows - 1) if rows > 1 else 0
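        # For example (illustrative numbers, not defaults): with w_l=128, tile_w=96, overlap=48, the grid
        # has cols = ceil((128-48)/(96-48)) = 2 tiles per row, spaced dx = (128-96)/(2-1) = 32 latent pixels
        # apart, so neighbouring tiles actually overlap by 96-32 = 64 pixels (>= the requested overlap).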
        bbox_list: List[BBox] = []
        self.jitter_range = 0
        for row in range(rows):
            for col in range(cols):
                h = min(int(row * dy), h_l - tile_h)
                w = min(int(col * dx), w_l - tile_w)
                if self.p.random_jitter:
                    self.jitter_range = min(max((min(self.w, self.h) - self.stride) // 4, 0), min(int(self.window_size / 2), int(self.overlap / 2)))
                jitter_range = self.jitter_range
                w_jitter = 0
                h_jitter = 0
                # tiles on a canvas edge only jitter outwards (into the zero padding); interior tiles jitter either way
                if (w != 0) and (w + tile_w != w_l):
                    w_jitter = random.randint(-jitter_range, jitter_range)
                elif (w == 0) and (w + tile_w != w_l):
                    w_jitter = random.randint(-jitter_range, 0)
                elif (w != 0) and (w + tile_w == w_l):
                    w_jitter = random.randint(0, jitter_range)
                if (h != 0) and (h + tile_h != h_l):
                    h_jitter = random.randint(-jitter_range, jitter_range)
                elif (h == 0) and (h + tile_h != h_l):
                    h_jitter = random.randint(-jitter_range, 0)
                elif (h != 0) and (h + tile_h == h_l):
                    h_jitter = random.randint(0, jitter_range)
                h += (h_jitter + jitter_range)
                w += (w_jitter + jitter_range)
                bbox = BBox(w, h, tile_w, tile_h)
                bbox_list.append(bbox)
        return bbox_list, None

    def get_views(self, overlap:int, tile_bs:int, tile_bs_g:int):
        self.enable_grid_bbox = True

        self.tile_w = self.window_size
        self.tile_h = self.window_size
        self.overlap = max(0, min(overlap, self.window_size - 4))
        self.stride = max(4, self.window_size - self.overlap)

        # split the latent into overlapped tiles, then batch them
        # weights basically indicate how many times a pixel is painted
        bboxes, _ = self.split_bboxes_jitter(self.w, self.h, self.tile_w, self.tile_h, self.overlap, self.get_tile_weights())
        self.num_tiles = len(bboxes)
        self.num_batches = math.ceil(self.num_tiles / tile_bs)
        self.tile_bs = math.ceil(len(bboxes) / self.num_batches)    # optimal batch size
        self.batched_bboxes = [bboxes[i*self.tile_bs:(i+1)*self.tile_bs] for i in range(self.num_batches)]

        # dilated (global) views, batched separately
        global_bboxes = self.global_split_bboxes()
        self.global_num_tiles = len(global_bboxes)
        self.global_num_batches = math.ceil(self.global_num_tiles / tile_bs_g)
        self.global_tile_bs = math.ceil(len(global_bboxes) / self.global_num_batches)
        self.global_batched_bboxes = [global_bboxes[i*self.global_tile_bs:(i+1)*self.global_tile_bs] for i in range(self.global_num_batches)]

    def gaussian_kernel(self, kernel_size=3, sigma=1.0, channels=3):
        ''' depthwise 2D Gaussian kernel: outer product of a normalized 1D Gaussian with itself '''
        x_coord = torch.arange(kernel_size, device=devices.device)
        gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
        gaussian_1d = gaussian_1d / gaussian_1d.sum()
        gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
        kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)
        return kernel

    def gaussian_filter(self, latents, kernel_size=3, sigma=1.0):
        ''' blur the latents channel-wise with the Gaussian kernel (grouped conv, one group per channel) '''
        channels = latents.shape[1]
        kernel = self.gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
        blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
        return blurred_latents

    ''' ↓↓↓ kernel hijacks ↓↓↓ '''

    def forward_one_step(self, x_in, sigma, **kwarg):
        # skip residual: re-noise the original latent to the current noise level ...
        if self.is_kdiff:
            x_noisy = self.p.x + self.p.noise * sigma[0]
        else:
            alphas_cumprod = self.p.sd_model.alphas_cumprod
            sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[self.timesteps[self.t_enc - self.p.current_step]])
            sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[self.timesteps[self.t_enc - self.p.current_step]])
            x_noisy = self.p.x * sqrt_alpha_cumprod + self.p.noise * sqrt_one_minus_alpha_cumprod
        # ... and blend it with the current latent, weighted by the cosine-decayed c1
        self.cosine_factor = 0.5 * (1 + torch.cos(torch.pi * torch.tensor((self.p.current_step + 1) / (self.t_enc + 1))))
        c1 = self.cosine_factor ** self.p.cosine_scale_1
        x_in = x_in * (1 - c1) + x_noisy * c1
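        # cosine_factor decays from ~1 at the first step to 0 at the last, so early steps lean almost
        # entirely on the re-noised original latent and later steps on the running denoised estimate;
        # a larger cosine_scale_1 makes the skip-residual influence fade out sooner.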

        if self.p.random_jitter:
            jitter_range = self.jitter_range
        else:
            jitter_range = 0
        x_in_ = F.pad(x_in, (jitter_range, jitter_range, jitter_range, jitter_range), 'constant', value=0)
        _, _, H, W = x_in.shape

        self.sampler.model_wrap_cfg.inner_model.forward = self.sample_one_step
        self.repeat_3 = False
        x_out = self.sampler.model_wrap_cfg.forward_ori(x_in_, sigma, **kwarg)
        self.sampler.model_wrap_cfg.inner_model.forward = self.sampler_forward
        # crop the jitter padding back off
        x_out = x_out[:, :, jitter_range:jitter_range+H, jitter_range:jitter_range+W]
        return x_out

    def sample_one_step(self, x_in, sigma, cond):
        assert LatentDiffusion.apply_model

        def repeat_func_1(x_tile:Tensor, bboxes, mode=0) -> Tensor:
            # k-diffusion samplers
            sigma_tile = self.repeat_tensor(sigma, len(bboxes))
            cond_tile = self.repeat_cond_dict(cond, bboxes, mode)
            return self.sampler_forward(x_tile, sigma_tile, cond=cond_tile)

        def repeat_func_2(x_tile:Tensor, bboxes, mode=0) -> Tuple[Tensor, Tensor]:
            # timestep-based (CompVis) samplers
            n_rep = len(bboxes)
            ts_tile = self.repeat_tensor(sigma, n_rep)
            if isinstance(cond, dict):  # FIXME: when will enter this branch?
                cond_tile = self.repeat_cond_dict(cond, bboxes, mode)
            else:
                cond_tile = self.repeat_tensor(cond, n_rep)
            return self.sampler_forward(x_tile, ts_tile, cond=cond_tile)

        def repeat_func_3(x_tile:Tensor, bboxes, mode=0):
            # direct apply_model path (used via `get_noise`)
            sigma_in_tile = sigma.repeat(len(bboxes))
            cond_out = self.repeat_cond_dict(cond, bboxes, mode)
            x_tile_out = shared.sd_model.apply_model(x_tile, sigma_in_tile, cond=cond_out)
            return x_tile_out

        if self.repeat_3:
            repeat_func = repeat_func_3
            self.repeat_3 = False
        elif self.is_kdiff:
            repeat_func = repeat_func_1
        else:
            repeat_func = repeat_func_2

        # local branch: denoise overlapping tiles and average them into x_local
        N, _, _, _ = x_in.shape
        self.x_buffer = torch.zeros_like(x_in)
        self.weights = torch.zeros_like(x_in)
        for batch_id, bboxes in enumerate(self.batched_bboxes):
            if state.interrupted: return x_in
            x_tile = torch.cat([x_in[bbox.slicer] for bbox in bboxes], dim=0)
            x_tile_out = repeat_func(x_tile, bboxes)
            # de-batching
            for i, bbox in enumerate(bboxes):
                self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :]
                self.weights[bbox.slicer] += 1
        self.weights = torch.where(self.weights == 0, torch.tensor(1), self.weights)    # prevent NaN from uncovered pixels in random_jitter mode
        x_local = self.x_buffer / self.weights

        # global branch: denoise dilated views and average them into x_global
        self.x_buffer = torch.zeros_like(self.x_buffer)
        self.weights = torch.zeros_like(self.weights)
        std_, mean_ = x_in.std(), x_in.mean()
        c3 = 0.99 * self.cosine_factor ** self.p.cosine_scale_3 + 1e-2
        if self.p.gaussian_filter:
            x_in_g = self.gaussian_filter(x_in, kernel_size=(2*self.p.current_scale_num - 1), sigma=self.sig*c3)
            # renormalize the blurred latent back to the original statistics
            x_in_g = (x_in_g - x_in_g.mean()) / x_in_g.std() * std_ + mean_
        if not hasattr(self.p.sd_model, 'apply_model_ori'):
            self.p.sd_model.apply_model_ori = self.p.sd_model.apply_model
        self.p.sd_model.apply_model = self.apply_model_hijack
        x_global = torch.zeros_like(x_local)
        jitter_range = self.jitter_range
        end = x_global.shape[3] - jitter_range
        current_num = 0
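        # dilated sampling: each global view takes every `current_scale_num`-th latent pixel starting at offset
        # (bbox[0], bbox[1]) inside the jitter margin, so every view spans the whole canvas at roughly the base
        # resolution. With `mixture` enabled the first half of the views reads the plain latent and the second
        # half reads its Gaussian-blurred copy; otherwise all views read the blurred latent.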
        if self.p.mixture:
            for batch_id, bboxes in enumerate(self.global_batched_bboxes):
                current_num += len(bboxes)
                if current_num > (self.global_num_tiles//2) and (current_num - self.global_tile_bs) < (self.global_num_tiles//2):
                    # this batch straddles the halfway point: the first `res` views come from the plain latent, the rest from the blurred one
                    res = len(bboxes) - (current_num - self.global_num_tiles//2)
                    x_in_i = torch.cat([
                        (x_in if idx < res else x_in_g)[:, :, bbox[1]+jitter_range:end:self.p.current_scale_num, bbox[0]+jitter_range:end:self.p.current_scale_num]
                        for idx, bbox in enumerate(bboxes)], dim=0)
                elif current_num > (self.global_num_tiles//2):
                    x_in_i = torch.cat([x_in_g[:, :, bbox[1]+jitter_range:end:self.p.current_scale_num, bbox[0]+jitter_range:end:self.p.current_scale_num] for bbox in bboxes], dim=0)
                else:
                    x_in_i = torch.cat([x_in[:, :, bbox[1]+jitter_range:end:self.p.current_scale_num, bbox[0]+jitter_range:end:self.p.current_scale_num] for bbox in bboxes], dim=0)
                x_global_i = repeat_func(x_in_i, bboxes, mode=1)
                # de-batching: scatter each denoised view back to its stride positions
                # (the three source branches above all scatter identically, so a single loop suffices)
                for idx, bbox in enumerate(bboxes):
                    x_global[:, :, bbox[1]+jitter_range:end:self.p.current_scale_num, bbox[0]+jitter_range:end:self.p.current_scale_num] += x_global_i[idx*N:(idx+1)*N, :, :, :]
        else:
            for batch_id, bboxes in enumerate(self.global_batched_bboxes):
                x_in_i = torch.cat([x_in_g[:, :, bbox[1]+jitter_range:end:self.p.current_scale_num, bbox[0]+jitter_range:end:self.p.current_scale_num] for bbox in bboxes], dim=0)
                x_global_i = repeat_func(x_in_i, bboxes, mode=1)
                for idx, bbox in enumerate(bboxes):
                    x_global[:, :, bbox[1]+jitter_range:end:self.p.current_scale_num, bbox[0]+jitter_range:end:self.p.current_scale_num] += x_global_i[idx*N:(idx+1)*N, :, :, :]

        # NOTE: according to the original execution flow, it would be very strange to use the noise predicted
        # from the Gaussian-blurred latents to denoise the non-blurred latents. Why?
        if self.p.mixture:
            self.x_buffer += x_global / 2   # each position was accumulated twice (plain + blurred pass)
        else:
            self.x_buffer += x_global
        self.weights += 1
        self.p.sd_model.apply_model = self.p.sd_model.apply_model_ori
        x_global = self.x_buffer / self.weights

        # fuse local & global estimates with the cosine-decayed weight c2
        c2 = self.cosine_factor ** self.p.cosine_scale_2
        self.x_buffer = x_local * (1 - c2) + x_global * c2
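        # like c1, c2 decays from ~1 to 0 over the steps, so the dilated/global estimate dominates early
        # (overall structure) and the per-tile local estimate dominates late (fine detail).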
        return self.x_buffer

    def apply_model_hijack(self, x_in:Tensor, t_in:Tensor, cond:CondDict):
        assert LatentDiffusion.apply_model
        x_tile_out = self.p.sd_model.apply_model_ori(x_in, t_in, cond)
        return x_tile_out

        # NOTE: using the Gaussian-blurred latent to predict noise for the original latent
        # if self.flag == 1:
        #     x_tile_out = self.p.sd_model.apply_model_ori(x_in, t_in, cond)
        #     self.x_out_list.append(x_tile_out)
        #     return x_tile_out
        # else:
        #     self.x_out_idx += 1
        #     return self.x_out_list[self.x_out_idx]

    def get_noise(self, x_in:Tensor, sigma_in:Tensor, cond_in:Dict[str, Tensor], step:int) -> Tensor:
        # NOTE: the following code is analytically wrong but aesthetically beautiful
        cond_in_original = cond_in.copy()
        self.repeat_3 = True
        self.cosine_factor = 0.5 * (1 + torch.cos(torch.pi * torch.tensor((self.p.current_step + 1) / (self.t_enc + 1))))
        jitter_range = self.jitter_range
        _, _, H, W = x_in.shape
        x_in_ = F.pad(x_in, (jitter_range, jitter_range, jitter_range, jitter_range), 'constant', value=0)
        return self.sample_one_step(x_in_, sigma_in, cond_in_original)[:, :, jitter_range:jitter_range+H, jitter_range:jitter_range+W]