"""Hires.fix Progressive: a rolling (multi-round) replacement for WebUI's hires fix.

The script temporarily replaces ``StableDiffusionProcessingTxt2Img.sample``
while a txt2img job with hires fix enabled is running, adds a "Hires CFG" and a
"Rolling factor" slider to the txt2img UI, and registers a few settings plus
two extra X/Y/Z plot axes.
"""
import math
import random
import re
import traceback

import gradio as gr
import numpy as np
import torch
from PIL import Image

import modules.images as images
import modules.sd_models as sd_models
from modules import scripts, script_callbacks, shared, sd_samplers, devices, extra_networks
from modules.shared import opts
from modules.processing import program_version
from modules.processing import StableDiffusionProcessingTxt2Img, create_random_tensors, opt_C, opt_f, decode_first_stage, get_fixed_seed, create_infotext


def _version_tuple(text):
    """Turn a dotted version string such as ``"1.3.0"`` into a tuple of ints.

    Non-numeric pieces are ignored and the result is padded to three
    components, so ``"1.3"`` compares equal to ``"1.3.0"`` and an empty string
    becomes ``(0, 0, 0)``.
    """
    parts = [int(piece) for piece in text.split(".") if piece.isdigit()]
    parts += [0] * (3 - len(parts))
    return tuple(parts)


# Oldest WebUI version whose hires-fix internals this script knows how to replace.
suppver = "1.3.0"

# program_version() may not contain a "v<digits>" token; in that case the
# version is treated as unknown (empty) and therefore as too old, instead of
# crashing at import time.
_version_match = re.search(r"v([\d.]*)", program_version() or "")
version = _version_match[1] if _version_match else ""
low = _version_tuple(version) < _version_tuple(suppver)

# Original implementation, restored in HiresFixPlus.postprocess().
sample_org = StableDiffusionProcessingTxt2Img.sample


def _save_before_hires(p, image, index, seeds, prompts):
    """Save an image before applying hires fix, if enabled in options.

    ``image`` is either a PIL image or a batch of latent space images; a latent
    batch is converted to an image for position ``index`` first.
    """
    if not opts.save or p.do_not_save_samples or not opts.save_images_before_highres_fix:
        return

    if not isinstance(image, Image.Image):
        image = sd_samplers.sample_to_image(image, index, approximation=0)

    info = create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, [], iteration=p.iteration, position_in_batch=index)
    images.save_image(image, p.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")


def _upscale_for_hires(p, samples, target_width, target_height, latent_scale_mode, seeds, prompts, save):
    """Upscale ``samples`` to the target size and build the matching image conditioning.

    When ``latent_scale_mode`` is set the latents are interpolated directly;
    otherwise they are decoded, resized with the upscaler named by
    ``p.hr_upscaler`` and encoded again. When ``save`` is true, the
    pre-upscale image of every batch item is offered to ``_save_before_hires``.

    Returns a ``(samples, image_conditioning)`` tuple.
    """
    if latent_scale_mode is not None:
        if save:
            for i in range(samples.shape[0]):
                _save_before_hires(p, samples, i, seeds, prompts)

        samples = torch.nn.functional.interpolate(
            samples,
            size=(target_height // opt_f, target_width // opt_f),
            mode=latent_scale_mode["mode"],
            antialias=latent_scale_mode["antialias"],
        )

        # Avoid making the inpainting conditioning unless necessary as
        # this does need some extra compute to decode / encode the image again.
        if getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
            image_conditioning = p.img2img_image_conditioning(decode_first_stage(p.sd_model, samples), samples)
        else:
            image_conditioning = p.txt2img_image_conditioning(samples)

        return samples, image_conditioning

    decoded_samples = decode_first_stage(p.sd_model, samples)
    lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

    batch_images = []
    for i, x_sample in enumerate(lowres_samples):
        x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
        x_sample = x_sample.astype(np.uint8)
        image = Image.fromarray(x_sample)

        if save:
            # The image is already decoded here, so it is saved as-is rather
            # than decoding the latent batch a second time.
            _save_before_hires(p, image, i, seeds, prompts)

        image = images.resize_image(0, image, target_width, target_height, upscaler_name=p.hr_upscaler)
        image = np.array(image).astype(np.float32) / 255.0
        image = np.moveaxis(image, 2, 0)
        batch_images.append(image)

    decoded_samples = torch.from_numpy(np.array(batch_images))
    decoded_samples = decoded_samples.to(shared.device)
    decoded_samples = 2. * decoded_samples - 1.

    samples = p.sd_model.get_first_stage_encoding(p.sd_model.encode_first_stage(decoded_samples))
    image_conditioning = p.img2img_image_conditioning(decoded_samples, samples)

    return samples, image_conditioning


def _run_hires_pass(p, samples, image_conditioning, seeds, subseeds, subseed_strength):
    """Run one img2img sampling pass over already-upscaled ``samples`` and return the result.

    While sampling, ``p.cfg_scale`` is replaced by ``p.hfp_cfg`` when that
    attribute is set to a non-zero value; the previous value is put back
    afterwards, even if sampling raises.
    """
    shared.state.nextjob()

    img2img_sampler_name = p.hr_sampler_name or p.sampler_name
    # PLMS/UniPC do not support img2img so we just silently switch to DDIM.
    # The check is made on the sampler actually selected for this pass.
    if img2img_sampler_name in ('PLMS', 'UniPC'):
        img2img_sampler_name = 'DDIM'

    p.sampler = sd_samplers.create_sampler(img2img_sampler_name, p.sd_model)

    samples = samples[:, :, p.truncate_y//2:samples.shape[2]-(p.truncate_y+1)//2, p.truncate_x//2:samples.shape[3]-(p.truncate_x+1)//2]

    noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=p)

    # GC now before running the next img2img to prevent running out of memory
    devices.torch_gc()

    if not p.disable_extra_networks:
        with devices.autocast():
            extra_networks.activate(p, p.hr_extra_network_data)

    sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio(for_hr=True))

    cfg = p.cfg_scale
    p.cfg_scale = getattr(p, 'hfp_cfg', 0) or p.cfg_scale
    try:
        samples = p.sampler.sample_img2img(p, samples, noise, p.hr_c, p.hr_uc, steps=p.hr_second_pass_steps or p.steps, image_conditioning=image_conditioning)
    finally:
        p.cfg_scale = cfg

    sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())

    return samples


def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
    """Replacement for ``StableDiffusionProcessingTxt2Img.sample`` with rolling hires fix.

    Runs the normal first pass. If hires fix is enabled it then either

    * performs several upscale + img2img rounds, each growing the image by
      ``self.hfp_rolling_factor`` (the last round lands exactly on
      ``hr_upscale_to_x`` x ``hr_upscale_to_y``), when that factor is greater
      than 1 and smaller than the overall width ratio, or
    * performs a single upscale + img2img pass otherwise.

    Returns the final latent samples.
    """
    self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

    if self.hr_upscaler is not None:
        latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None)
    else:
        latent_scale_mode = shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")

    if self.enable_hr and latent_scale_mode is None:
        assert any(x.name == self.hr_upscaler for x in shared.sd_upscalers), f"could not find upscaler named {self.hr_upscaler}"

    x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
    samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))

    if not self.enable_hr:
        return samples

    # The first-pass noise is no longer needed; drop the reference so the
    # garbage collection done before each hires pass can reclaim it.
    del x

    self.is_hr_pass = True

    rolling_factor = getattr(self, 'hfp_rolling_factor', 0)
    upscale_ratio = self.hr_upscale_to_x / self.width

    # A factor <= 1 (including the "attribute missing" default of 0) cannot be
    # used as a logarithm base / growth rate, so it selects the single pass.
    if 1 < rolling_factor < upscale_ratio:
        rounds = math.ceil(math.log(upscale_ratio) / math.log(rolling_factor))
        shared.state.job_count = rounds
        shared.total_tqdm.updateTotal(self.steps + get_steps(self) * (rounds - 1))

        for t in range(1, rounds):
            print(f"Generation round {t}/{rounds - 1} ")

            if t == rounds - 1:
                # The last round always lands exactly on the requested size.
                target_width = self.hr_upscale_to_x
                target_height = self.hr_upscale_to_y
            else:
                target_width = int(self.width * math.pow(rolling_factor, t))
                target_height = int(self.height * math.pow(rolling_factor, t))

            if opts.data.get("hfp_jitter_seeds", False):
                jitter_step = opts.data.get("hfp_jitter_step", 1)
                seeds = [seed + jitter_step for seed in seeds]
            if opts.data.get("hfp_random_seeds", False):
                seeds = [get_fixed_seed(-1)] * len(seeds)

            # Only the image entering the first round is saved unless the user
            # asked for every intermediate image.
            save = opts.data.get("hfp_save_every_image", False) or t == 1

            samples, image_conditioning = _upscale_for_hires(self, samples, target_width, target_height, latent_scale_mode, seeds, prompts, save)
            samples = _run_hires_pass(self, samples, image_conditioning, seeds, subseeds, subseed_strength)
    else:
        samples, image_conditioning = _upscale_for_hires(self, samples, self.hr_upscale_to_x, self.hr_upscale_to_y, latent_scale_mode, seeds, prompts, True)
        samples = _run_hires_pass(self, samples, image_conditioning, seeds, subseeds, subseed_strength)

    self.is_hr_pass = False

    return samples


def gr_show(visible=True, n=1):
    """Return a gradio visibility update dict, or a list of ``n`` of them when ``n > 1``."""
    update = {"visible": visible, "__type__": "update"}
    if n > 1:
        return [update] * n
    return update


def get_steps(p):
    """Return the number of steps to use for a hires pass.

    An explicit non-zero ``p.hr_second_pass_steps`` wins. Otherwise the count
    is derived from ``p.steps`` and ``p.denoising_strength`` and is never lower
    than the ``hfp_smartstep_min`` setting (9 when the setting is absent).
    """
    if p.hr_second_pass_steps != 0:
        return p.hr_second_pass_steps

    min_steps = opts.data.get("hfp_smartstep_min", 9)
    if p.steps <= 1:
        # math.log(10, base) is undefined for a base of 1 or less.
        return min_steps

    return max(min_steps, round(math.log(10, p.steps) * p.steps * p.denoising_strength))


class HiresFixPlus(scripts.Script):
    """Always-visible txt2img script that swaps in the progressive hires fix sampler."""

    def title(self):
        return 'Hires.fix Progressive'

    def describe(self):
        return "A progressive version of hires.fix implementation."

    def show(self, is_img2img):
        # Only txt2img has a hires fix; for img2img this returns None (falsy).
        if not is_img2img:
            return scripts.AlwaysVisible

    def after_component(self, component, **kwargs):
        """Insert this script's widgets right after specific txt2img components."""
        elem_id = kwargs.get("elem_id")

        if low:
            # Unsupported WebUI: show only a warning, create no controls.
            if elem_id == "txt2img_enable_hr":
                # gr.HTML renders markup, so the line break is an explicit <br>.
                self.warring_text = gr.HTML(value=f'Hires.fix+ requires WebUI v{suppver} or later<br>But you have {program_version()}, please update it.', elem_id="hfp_warring_text")
            return

        if elem_id == "txt2img_enable_hr":
            self.warring_text = gr.HTML(value='Set "Hires steps" to [0], if you need<br>Hires. fix+ to do steps optimization', elem_id="hfp_warring_text")
        elif elem_id == "txt2img_denoising_strength":
            self.hfp_cfg = gr.Slider(minimum=0.0, maximum=30.0, step=0.5, label='Hires CFG', value=0.0, elem_id="txt2img_hfp_cfg", interactive=True)
        elif elem_id == "txt2img_hr_resize_y":
            self.hfp_rolling_factor = gr.Slider(minimum=1.0, maximum=2.0, step=0.05, label='Rolling factor', value=1.0, elem_id="txt2img_hfp_rolling_factor", interactive=True)

    def ui(self, is_img2img):
        """Return the sliders whose values are passed to ``process``."""
        if low:
            # The sliders are never created on unsupported versions.
            return []

        self.infotext_fields = [
            (self.hfp_cfg, "Hires CFG"),
            (self.hfp_rolling_factor, "Rolling factor"),
        ]
        # WebUI expects infotext field *names* here, not (component, name) pairs.
        self.paste_field_names = [name for _, name in self.infotext_fields]

        return [self.hfp_cfg, self.hfp_rolling_factor]

    def process(self, p, hfp_cfg: float = 0, hfp_rolling_factor: float = 1.0):
        """Install the custom sampler and stash this job's hires settings on ``p``."""
        if low or not p.enable_hr:
            return

        print('Hijacking Hires. fix... ')
        StableDiffusionProcessingTxt2Img.sample = sample

        # Remember what the user entered so process_batch can report it.
        self.hr_step = p.hr_second_pass_steps
        p.hr_second_pass_steps = get_steps(p)

        # Precedence: value already set on p, then the slider, then the normal CFG scale.
        hires_cfg = (getattr(p, 'hfp_cfg', 0) or hfp_cfg) or p.cfg_scale
        p.hfp_cfg = hires_cfg
        p.hfp_rolling_factor = hfp_rolling_factor

        if hires_cfg != p.cfg_scale:
            # Record the value that is actually used, not the raw slider value.
            p.extra_generation_params["Hires CFG"] = hires_cfg
        if hfp_rolling_factor != 1:
            p.extra_generation_params["Rolling factor"] = hfp_rolling_factor

    def process_batch(self, p, *args, **kwargs):
        """Report the user-entered hires step count (None when it was 0)."""
        if not low and p.enable_hr:
            p.extra_generation_params["Hires steps"] = self.hr_step if self.hr_step != 0 else None

    def postprocess(self, p, processed, *args):
        """Restore the original ``sample`` implementation."""
        if not low and p.enable_hr:
            StableDiffusionProcessingTxt2Img.sample = sample_org


def create_script_items():
    """Add "Hires Sampler" and "Hires CFG" axes to the X/Y/Z plot script (best effort)."""
    try:
        xyz_grid = next(x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py").module

        def apply_hires_cfg(p, x, xs):
            setattr(p, "hfp_cfg", x)

        def apply_hires_sampler(p, x, xs):
            hr_sampler = sd_samplers.samplers_map.get(x.lower(), None)
            if hr_sampler is None:
                raise RuntimeError(f"Unknown sampler: {x} ")
            setattr(p, "hr_sampler_name", hr_sampler)

        extra_axis_options = [
            xyz_grid.AxisOptionTxt2Img("Hires Sampler", str, apply_hires_sampler, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
            xyz_grid.AxisOptionTxt2Img("Hires CFG", float, apply_hires_cfg),
        ]

        # This function may run more than once; add only the axes whose label
        # is not already present.
        existing_labels = {option.label for option in xyz_grid.axis_options}
        xyz_grid.axis_options.extend(option for option in extra_axis_options if option.label not in existing_labels)
    except Exception as e:
        # X/Y/Z plot support is optional: report the problem and carry on.
        traceback.print_exc()
        print(f"Failed to add support for X/Y/Z Plot Script because: {e} ")


def create_settings_items():
    """Register this extension's options under the "Hires. fix+" settings section."""
    section_hfp = ('hiresfix_plus', 'Hires. fix+')

    opts.add_option("hfp_smartstep_min", shared.OptionInfo(
        9, "If Smart-Step is enabled, the number of iterations for Hires. fix will never be less than this:",
        gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, section=section_hfp
    ))
    opts.add_option("hfp_save_every_image", shared.OptionInfo(
        False, "If \"Save a copy of image before doing face restoration.\" is enabled, save every image during rolling generation", section=section_hfp
    ))
    opts.add_option("hfp_jitter_seeds", shared.OptionInfo(
        False, "Jitter the seeds of sub-generations when doing a rolling generation (Still deterministic)", section=section_hfp
    ))
    opts.add_option("hfp_jitter_step", shared.OptionInfo(
        1, "Jitter step:",
        gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}, section=section_hfp
    ))
    opts.add_option("hfp_random_seeds", shared.OptionInfo(
        False, "Use random seeds for sub-generations when doing a rolling generation (WARNING!!! The result will be non-deterministic!!!)", section=section_hfp
    ))


if low:
    print(f'Hires.fix+ requires WebUI v{suppver} or later. But you have {program_version()}, please update it. ')
else:
    script_callbacks.on_ui_settings(create_settings_items)
    script_callbacks.on_before_ui(create_script_items)