| """SAMPLING ONLY.""" | |
| import torch | |
| from dpm_solver_v3 import NoiseScheduleVP, model_wrapper, DPM_Solver_v3 | |
| from uni_pc import UniPC | |
| from free_lunch_utils import register_free_upblock2d, register_free_crossattn_upblock2d | |
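
# Thin sampler wrappers around diffusers-style pipelines: DPMSolverv3Sampler
# drives DPM-Solver-v3 (which requires precomputed statistics), UniPCSampler
# drives UniPC, and both can switch FreeU feature re-weighting on and off
# mid-trajectory via the free_u_*_callback hooks.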

class DPMSolverv3Sampler:
    def __init__(self, stats_dir, pipe, steps, guidance_scale, **kwargs):
        super().__init__()
        self.model = pipe
        DTYPE = torch.float32  # torch.float16 also works, but image quality tends to be slightly worse
        device = "cuda"  # NOTE: hard-coded; the pipeline is expected to live on CUDA
        noise_scheduler = pipe.scheduler
        alpha_schedule = noise_scheduler.alphas_cumprod.to(device=device, dtype=DTYPE)
        self.alphas_cumprod = alpha_schedule
        self.device = device
        self.guidance_scale = guidance_scale
        self.ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)
        assert stats_dir is not None, "DPM-Solver-v3 requires a directory of precomputed statistics (stats_dir)."
        print("Using statistics from", stats_dir)
        self.dpm_solver_v3 = DPM_Solver_v3(
            statistics_dir=stats_dir,
            noise_schedule=self.ns,
            steps=steps,
            t_start=None,
            t_end=None,
            skip_type="customed_time_karras",
            degenerated=False,
            device=self.device,
        )
        self.steps = steps
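
    # FreeU (free_lunch_utils) re-weights UNet features at sampling time:
    # b1/b2 scale backbone features and s1/s2 scale skip connections in the
    # up-blocks. Factors of 1.0 leave the UNet unchanged, which is how
    # stop_free_unet() switches the effect back off.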
    def apply_free_unet(self):
        register_free_upblock2d(self.model, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
        register_free_crossattn_upblock2d(self.model, b1=1.1, b2=1.1, s1=0.9, s2=0.2)

    def stop_free_unet(self):
        register_free_upblock2d(self.model, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
        register_free_crossattn_upblock2d(self.model, b1=1.0, b2=1.0, s1=1.0, s2=1.0)

    def sample(
        self,
        batch_size,
        shape,
        conditioning=None,
        x_T=None,
        unconditional_conditioning=None,  # must come in the same format as the conditioning, e.g. as encoded tokens
        use_corrector=False,
        half=False,
        start_free_u_step=None,
        **kwargs,
    ):
        if conditioning is not None:
            # Sanity check: the conditioning batch must match the requested batch size.
            if isinstance(conditioning, dict):
                cbs = next(iter(conditioning.values())).shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        if x_T is None:
            img = torch.randn(size, device=self.device)
        else:
            img = x_T

        if conditioning is None:
            model_fn = model_wrapper(
                lambda x, t, c: self.model.unet(x, t, encoder_hidden_states=c).sample,
                self.ns,
                model_type="noise",
                guidance_type="uncond",
            )
            ORDER = 3
        else:
            model_fn = model_wrapper(
                lambda x, t, c: self.model.unet(x, t, encoder_hidden_states=c).sample,
                self.ns,
                model_type="noise",
                guidance_type="classifier-free",
                condition=conditioning,
                unconditional_condition=unconditional_conditioning,
                guidance_scale=self.guidance_scale,
            )
            # Second-order solver is reserved for the 8-step setting; otherwise first order.
            if self.steps == 8:
                ORDER = 2
            else:
                ORDER = 1

        x = self.dpm_solver_v3.sample(
            img,
            model_fn,
            order=ORDER,
            p_pseudo=False,
            c_pseudo=True,
            lower_order_final=True,
            use_corrector=use_corrector,
            start_free_u_step=start_free_u_step,
            free_u_apply_callback=self.apply_free_unet if start_free_u_step is not None else None,
            free_u_stop_callback=self.stop_free_unet if start_free_u_step is not None else None,
            half=half,
        )
        return x.to(self.device), None
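
# A minimal usage sketch (not part of the original code): it assumes a
# diffusers StableDiffusionPipeline and a directory of precomputed
# DPM-Solver-v3 statistics; the model id, stats path, and prompt encoding
# below are illustrative.
def _demo_dpm_solver_v3_sampler(prompt="a photo of an astronaut riding a horse"):
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float32
    ).to("cuda")
    sampler = DPMSolverv3Sampler(
        stats_dir="statistics/sd15",  # hypothetical path to the precomputed statistics
        pipe=pipe,
        steps=8,
        guidance_scale=7.5,
    )

    def encode(text):
        # Encode a prompt into the encoder hidden states the UNet expects.
        tokens = pipe.tokenizer(
            [text],
            padding="max_length",
            max_length=pipe.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(pipe.device)
        return pipe.text_encoder(tokens)[0]

    latents, _ = sampler.sample(
        batch_size=1,
        shape=(4, 64, 64),  # SD 1.5 latent shape for 512x512 images
        conditioning=encode(prompt),
        unconditional_conditioning=encode(""),
    )
    # Decode the denoised latents with the pipeline's VAE.
    return pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample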

class UniPCSampler:
    def __init__(
        self,
        pipe,
        model_closure,
        steps,
        guidance_scale,
        denoise_to_zero=False,
        need_fp16_discrete_method=False,
        ultilize_vae_in_fp16=False,  # sic: spelling matches UniPC's keyword argument
        is_high_resoulution=True,  # sic: spelling matches UniPC's keyword argument
        skip_type="customed_time_karras",
        force_not_use_afs=False,
        **kwargs,
    ):
        super().__init__()
        self.model = model_closure(pipe)
        self.pipe = pipe
        self.need_fp16_discrete_method = need_fp16_discrete_method
        DTYPE = self.pipe.unet.dtype  # torch.float16 also works, but image quality tends to be slightly worse
        device = self.pipe.device
        noise_scheduler = pipe.scheduler
        alpha_schedule = noise_scheduler.alphas_cumprod.to(device=device, dtype=DTYPE)
        self.alphas_cumprod = alpha_schedule
        self.device = device
        self.guidance_scale = guidance_scale
        # AFS is only enabled for few-step, high-resolution sampling, unless explicitly disabled.
        self.use_afs = steps <= 8 and is_high_resoulution and not force_not_use_afs
        self.ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)
        self.unipc_solver = UniPC(
            noise_schedule=self.ns,
            steps=steps,
            t_start=None,
            t_end=None,
            skip_type=skip_type,
            degenerated=False,
            use_afs=self.use_afs,
            device=self.device,
            denoise_to_zero=denoise_to_zero,
            need_fp16_discrete_method=self.need_fp16_discrete_method,
            ultilize_vae_in_fp16=ultilize_vae_in_fp16,
            is_high_resoulution=is_high_resoulution,
        )
        self.steps = steps

    def apply_free_unet(self):
        register_free_upblock2d(self.pipe, b1=1.2, b2=1.2, s1=0.9, s2=0.2)
        register_free_crossattn_upblock2d(self.pipe, b1=1.2, b2=1.2, s1=0.9, s2=0.2)

    def stop_free_unet(self):
        register_free_upblock2d(self.pipe, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
        register_free_crossattn_upblock2d(self.pipe, b1=1.0, b2=1.0, s1=1.0, s2=1.0)

    def sample(
        self,
        batch_size,
        shape,
        conditioning=None,
        x_T=None,
        unconditional_conditioning=None,  # must come in the same format as the conditioning, e.g. as encoded tokens
        use_corrector=False,
        half=False,
        start_free_u_step=None,
        xl_preprocess_closure=None,
        npnet=None,
        **kwargs,
    ):
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        new_img = None
        if xl_preprocess_closure is not None:
            # For SDXL-style pipelines, raw prompts are encoded here into prompt
            # embeddings plus the extra conditioning kwargs the UNet needs.
            prompt_embeds, cond_kwargs = xl_preprocess_closure(
                pipe=self.pipe,
                prompts=conditioning,
                need_cfg=True,
                device=self.device,
                negative_prompts=unconditional_conditioning,
            )
        if x_T is None:
            img = torch.randn(size, device=self.device)
        else:
            img = x_T
        if xl_preprocess_closure is not None and npnet is not None:
            c, _ = prompt_embeds
            c = c.unsqueeze(0)  # add a dummy dimension for npnet
            new_img = npnet(img, c)
        if conditioning is None:
            model_fn = model_wrapper(
                lambda x, t, c: self.model(x, t, c),
                self.ns,
                model_type="noise",
                guidance_type="uncond",
            )
            ORDER = 3
        else:
            model_fn = model_wrapper(
                lambda x, t, c: self.model(x, t, c),
                self.ns,
                model_type="noise",
                guidance_type="classifier-free",
                condition=conditioning if xl_preprocess_closure is None else prompt_embeds,
                unconditional_condition=unconditional_conditioning if xl_preprocess_closure is None else cond_kwargs,
                guidance_scale=self.guidance_scale,
            )
            # Second-order UniPC once there are enough steps; first order otherwise.
            if self.steps >= 7:
                ORDER = 2
            else:
                ORDER = 1
        x, full_cache = self.unipc_solver.sample(
            x=img,
            model_fn=model_fn,
            order=ORDER,
            use_corrector=use_corrector,
            lower_order_final=True,
            start_free_u_step=start_free_u_step,
            free_u_apply_callback=self.apply_free_unet if start_free_u_step is not None else None,
            free_u_stop_callback=self.stop_free_unet if start_free_u_step is not None else None,
            npnet_x=new_img,
            npnet_scale=self.guidance_scale if new_img is not None else None,
            half=half,
        )
        return x.to(self.device), full_cache

    def sample_mix(
        self,
        batch_size,
        shape,
        conditioning=None,
        x_T=None,
        unconditional_conditioning=None,  # must come in the same format as the conditioning, e.g. as encoded tokens
        use_corrector=False,
        half=False,
        start_free_u_step=None,
        xl_preprocess_closure=None,
        npnet=None,
        **kwargs,
    ):
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        if xl_preprocess_closure is not None:
            prompt_embeds, cond_kwargs = xl_preprocess_closure(
                pipe=self.pipe,
                prompts=conditioning,
                need_cfg=True,
                device=self.device,
                negative_prompts=unconditional_conditioning,
            )
        if x_T is None:
            img = torch.randn(size, device=self.device)
        else:
            img = x_T
        if xl_preprocess_closure is not None and npnet is not None:
            c, _ = prompt_embeds
            c = c.unsqueeze(0)  # add a dummy dimension for npnet
            # Unlike sample(), the refined noise replaces the initial latents directly.
            img = npnet(img, c)
        if conditioning is None:
            model_fn = model_wrapper(
                lambda x, t, c: self.model(x, t, c),
                self.ns,
                model_type="noise",
                guidance_type="uncond",
            )
            ORDER = 3
        else:
            model_fn = model_wrapper(
                lambda x, t, c: self.model(x, t, c),
                self.ns,
                model_type="noise",
                guidance_type="classifier-free",
                condition=conditioning if xl_preprocess_closure is None else prompt_embeds,
                unconditional_condition=unconditional_conditioning if xl_preprocess_closure is None else cond_kwargs,
                guidance_scale=self.guidance_scale,
            )
            if self.steps >= 8 and not self.need_fp16_discrete_method:
                ORDER = 2
            else:
                ORDER = 1
        x, full_cache = self.unipc_solver.sample_mix(
            x=img,
            model_fn=model_fn,
            order=ORDER,
            use_corrector=use_corrector,
            lower_order_final=True,
            start_free_u_step=start_free_u_step,
            free_u_apply_callback=self.apply_free_unet if start_free_u_step is not None else None,
            free_u_stop_callback=self.stop_free_unet if start_free_u_step is not None else None,
            half=half,
        )
        return x.to(self.device), full_cache
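
# A minimal usage sketch (not part of the original code): it assumes a
# diffusers StableDiffusionPipeline; the model id, prompt encoding, and the
# model_closure below are illustrative, not the project's canonical ones.
def _demo_unipc_sampler(prompt="a photo of an astronaut riding a horse"):
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float32
    ).to("cuda")

    # The sampler calls self.model(x, t, c); this closure adapts the diffusers
    # UNet, which expects encoder_hidden_states, to that signature.
    def model_closure(p):
        return lambda x, t, c: p.unet(x, t, encoder_hidden_states=c).sample

    sampler = UniPCSampler(
        pipe=pipe,
        model_closure=model_closure,
        steps=8,
        guidance_scale=7.5,
    )

    def encode(text):
        tokens = pipe.tokenizer(
            [text],
            padding="max_length",
            max_length=pipe.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(pipe.device)
        return pipe.text_encoder(tokens)[0]

    latents, _ = sampler.sample(
        batch_size=1,
        shape=(4, 64, 64),  # SD 1.5 latent shape for 512x512 images
        conditioning=encode(prompt),
        unconditional_conditioning=encode(""),
    )
    return pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample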