| import numpy as np |
| import gradio as gr |
| from PIL import Image |
| import torch, math, re, random |
| from distutils.version import StrictVersion |
|
|
| import modules.images as images |
| import modules.sd_models as sd_models |
| from modules import scripts, script_callbacks, shared, sd_samplers, devices, extra_networks |
|
|
| from modules.shared import opts |
| from modules.processing import program_version |
| from modules.processing import StableDiffusionProcessingTxt2Img, create_random_tensors, opt_C, opt_f, decode_first_stage, get_fixed_seed, create_infotext |
|
|
# Oldest WebUI release whose hires.fix implementation this script can replace.
suppver = "1.3.0"


def _version_tuple(text):
    """Turn a dotted version string such as "1.3.0" into a comparable tuple of ints."""
    parts = [int(part) for part in text.split(".") if part]
    # Pad to three components so that "1.3" compares equal to "1.3.0"
    # (a bare tuple comparison would rank (1, 3) below (1, 3, 0)).
    return tuple(parts + [0] * (3 - len(parts)))


# Only the dotted-number part that follows the first "v" is kept; anything after it
# (for example a "-123-gabcdef" suffix) is ignored.
# NOTE(review): assumes program_version() returns a "v1.2.3"-style tag -- confirm.
_version_match = re.search(r"v(\d+(?:\.\d+)*)", program_version() or "")
version = _version_match.group(1) if _version_match else ""

# A version that cannot be parsed is treated as unsupported instead of crashing the
# import, so the script loads and shows its "please update" notice.
low = not version or _version_tuple(version) < _version_tuple(suppver)

# Untouched upstream implementation, restored after each generation.
sample_org = StableDiffusionProcessingTxt2Img.sample
|
|
def _upscale_for_hires(p, samples, target_width, target_height, latent_scale_mode, save_intermediate=None):
    """Upscale a latent batch to the target pixel size and build its image conditioning.

    Args:
        p: processing object; supplies the model, the upscaler name and the
            conditioning helpers.
        samples: latent batch produced by the previous pass.
        target_width: requested width in pixels.
        target_height: requested height in pixels.
        latent_scale_mode: entry of ``shared.latent_upscale_modes`` used for
            latent-space interpolation, or ``None`` to decode the latents, resize
            the images with ``p.hr_upscaler`` and encode them again.
        save_intermediate: optional ``callable(image_or_latents, index)`` invoked
            once per batch item before that item is upscaled.

    Returns:
        ``(samples, image_conditioning)`` for the img2img pass that follows.
    """
    if latent_scale_mode is not None:
        if save_intermediate is not None:
            for i in range(samples.shape[0]):
                save_intermediate(samples, i)

        samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])

        if getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
            image_conditioning = p.img2img_image_conditioning(decode_first_stage(p.sd_model, samples), samples)
        else:
            image_conditioning = p.txt2img_image_conditioning(samples)

        return samples, image_conditioning

    decoded_samples = decode_first_stage(p.sd_model, samples)
    lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

    batch_images = []
    for i, x_sample in enumerate(lowres_samples):
        x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
        x_sample = x_sample.astype(np.uint8)
        image = Image.fromarray(x_sample)

        # Hand over the already decoded image so the saver does not have to decode
        # the latents a second time.
        if save_intermediate is not None:
            save_intermediate(image, i)

        image = images.resize_image(0, image, target_width, target_height, upscaler_name=p.hr_upscaler)
        image = np.array(image).astype(np.float32) / 255.0
        image = np.moveaxis(image, 2, 0)
        batch_images.append(image)

    decoded_samples = torch.from_numpy(np.array(batch_images))
    decoded_samples = decoded_samples.to(shared.device)
    decoded_samples = 2. * decoded_samples - 1.

    samples = p.sd_model.get_first_stage_encoding(p.sd_model.encode_first_stage(decoded_samples))
    image_conditioning = p.img2img_image_conditioning(decoded_samples, samples)

    return samples, image_conditioning


def _run_hires_pass(p, samples, image_conditioning, seeds, subseeds, subseed_strength):
    """Run one img2img refinement pass over upscaled latents and return the result.

    The pass uses ``p.hr_c`` / ``p.hr_uc`` as conditioning and, when ``p.hfp_cfg``
    is set to a non-zero value, that value as the CFG scale. ``p.cfg_scale`` is put
    back afterwards even if sampling raises.
    """
    shared.state.nextjob()

    img2img_sampler_name = p.hr_sampler_name or p.sampler_name

    # The fallback has to look at the sampler that will really run the img2img pass:
    # testing the first-pass sampler would override an explicit hires sampler and
    # would miss a hires sampler that is itself PLMS/UniPC.
    if img2img_sampler_name in ['PLMS', 'UniPC']:
        img2img_sampler_name = 'DDIM'

    p.sampler = sd_samplers.create_sampler(img2img_sampler_name, p.sd_model)

    samples = samples[:, :, p.truncate_y//2:samples.shape[2]-(p.truncate_y+1)//2, p.truncate_x//2:samples.shape[3]-(p.truncate_x+1)//2]

    noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=p)

    devices.torch_gc()

    if not p.disable_extra_networks:
        with devices.autocast():
            extra_networks.activate(p, p.hr_extra_network_data)

    sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio(for_hr=True))

    base_cfg = p.cfg_scale
    p.cfg_scale = getattr(p, 'hfp_cfg', 0) or p.cfg_scale
    try:
        samples = p.sampler.sample_img2img(p, samples, noise, p.hr_c, p.hr_uc, steps=p.hr_second_pass_steps or p.steps, image_conditioning=image_conditioning)
    finally:
        # Never leave the hires CFG on the processing object, even on failure.
        p.cfg_scale = base_cfg

    sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())

    return samples


def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
    """Replacement for ``StableDiffusionProcessingTxt2Img.sample`` with progressive hires.fix.

    The first pass samples at ``self.width`` x ``self.height``. When hires.fix is
    enabled the result is then upscaled and refined with img2img passes: a single
    pass straight to ``hr_upscale_to_x`` x ``hr_upscale_to_y``, or -- when
    ``self.hfp_rolling_factor`` is greater than 1 and smaller than the total upscale
    ratio -- several passes whose size grows by that factor each round, the last
    round landing exactly on the final size.

    Args:
        conditioning: positive conditioning for the first pass.
        unconditional_conditioning: negative conditioning for the first pass.
        seeds: per-image seeds; rolling rounds may jitter or randomise them
            according to the ``hfp_jitter_seeds`` / ``hfp_random_seeds`` options.
        subseeds: per-image subseeds.
        subseed_strength: subseed strength used for the hires noise (the first pass
            uses ``self.subseed_strength``).
        prompts: per-image prompts, used only to name saved intermediate images.

    Returns:
        The final latent batch.

    Raises:
        AssertionError: hires.fix is enabled and ``self.hr_upscaler`` names neither a
            latent upscale mode nor a known upscaler.
    """
    self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

    latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
    if self.enable_hr and latent_scale_mode is None:
        assert any(x.name == self.hr_upscaler for x in shared.sd_upscalers), f"could not find upscaler named {self.hr_upscaler}"

    x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
    samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))

    if not self.enable_hr:
        return samples

    # Drop the first-pass noise so the torch_gc() call of the hires pass can reclaim it.
    x = None

    def save_intermediate(image, index):
        """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""

        if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
            return

        if not isinstance(image, Image.Image):
            image = sd_samplers.sample_to_image(image, index, approximation=0)

        info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
        # `seeds` is read when the image is saved, so a rolling round names the file
        # after the seeds that are current for that round.
        images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")

    upscale_ratio = self.hr_upscale_to_x / self.width
    rolling_factor = getattr(self, 'hfp_rolling_factor', 0)

    self.is_hr_pass = True
    try:
        # A factor of 1 would never reach the target and a factor <= 0 is outside the
        # domain of math.log; anything that is not strictly between 1 and the total
        # ratio falls back to the ordinary single hires pass.
        if 1 < rolling_factor < upscale_ratio:
            # `rounds` counts the first pass too, so rounds - 1 hires passes follow.
            # max() guards against a quotient that rounds down to exactly 1.0, which
            # would skip upscaling altogether.
            rounds = max(2, math.ceil(math.log(upscale_ratio) / math.log(rolling_factor)))

            shared.state.job_count = rounds
            shared.total_tqdm.updateTotal(self.steps + get_steps(self) * (rounds - 1))

            for t in range(1, rounds):
                print(f"Generation round {t}/{rounds - 1} ")

                if t == rounds - 1:
                    # The last round always lands exactly on the requested size.
                    target_width = self.hr_upscale_to_x
                    target_height = self.hr_upscale_to_y
                else:
                    target_width = int(self.width * math.pow(rolling_factor, t))
                    target_height = int(self.height * math.pow(rolling_factor, t))

                if opts.data.get("hfp_jitter_seeds", False):
                    jitter_step = opts.data.get("hfp_jitter_step", 1)
                    seeds = [seed + jitter_step for seed in seeds]
                if opts.data.get("hfp_random_seeds", False):
                    seeds = [get_fixed_seed(-1)] * len(seeds)

                # Without "save every image" only the first-pass result is saved.
                save_this_round = opts.data.get("hfp_save_every_image", False) or t == 1

                samples, image_conditioning = _upscale_for_hires(self, samples, target_width, target_height, latent_scale_mode, save_intermediate if save_this_round else None)
                samples = _run_hires_pass(self, samples, image_conditioning, seeds, subseeds, subseed_strength)
        else:
            samples, image_conditioning = _upscale_for_hires(self, samples, self.hr_upscale_to_x, self.hr_upscale_to_y, latent_scale_mode, save_intermediate)
            samples = _run_hires_pass(self, samples, image_conditioning, seeds, subseeds, subseed_strength)
    finally:
        self.is_hr_pass = False

    return samples
|
|
def gr_show(visible=True, n=1):
    """Build Gradio "update" payload(s) that set a component's visibility.

    Returns a single dict, or a list holding the same dict ``n`` times when
    ``n`` is greater than 1.
    """
    update = {"visible": visible, "__type__": "update"}
    return [update] * n if n > 1 else update
|
|
def get_steps(p):
    """Return the number of sampling steps to use for each hires pass.

    An explicit ``p.hr_second_pass_steps`` (anything other than 0) is returned
    unchanged. Otherwise the count is estimated as
    ``round(log_steps(10) * steps * denoising_strength)`` and is never lower than
    the ``hfp_smartstep_min`` option (default 9).
    """
    if p.hr_second_pass_steps != 0:
        return p.hr_second_pass_steps

    min_steps = opts.data.get("hfp_smartstep_min", 9)

    # A logarithm to base 1 (or below) is undefined: math.log(10, 1) divides by
    # zero. With so few first-pass steps simply use the configured minimum.
    if p.steps <= 1:
        return min_steps

    return max(min_steps, round(math.log(10, p.steps) * p.steps * p.denoising_strength))
|
|
class HiresFixPlus(scripts.Script):
    """Always-on txt2img script that swaps in the progressive hires.fix sampler.

    While a generation with hires.fix enabled is running,
    ``StableDiffusionProcessingTxt2Img.sample`` is replaced by this module's
    ``sample``; ``postprocess`` puts the original back.
    """

    # Hires steps as entered by the user, remembered by process() so that
    # process_batch() can report them in the infotext.
    hr_step = 0

    def title(self):
        return 'Hires.fix Progressive'

    def describe(self):
        return "A progressive version of hires.fix implementation."

    def show(self, is_img2img):
        """Show the script on the txt2img tab only."""
        if is_img2img:
            return False
        return scripts.AlwaysVisible

    def after_component(self, component, **kwargs):
        """Insert this script's controls right after specific txt2img components."""
        elem_id = kwargs.get("elem_id")

        if low:
            # Unsupported WebUI: only explain why the script is inactive.
            if elem_id == "txt2img_enable_hr":
                self.warring_text = gr.HTML(value=f'Hires.fix+ requires WebUI v{suppver} or later<br>But you have {program_version()}, please update it.', elem_id="hfp_warring_text")
            return

        if elem_id == "txt2img_enable_hr":
            self.warring_text = gr.HTML(value='Set "Hires steps" to [0], if you need<br>Hires. fix+ to do steps optimization', elem_id="hfp_warring_text")
        elif elem_id == "txt2img_denoising_strength":
            self.hfp_cfg = gr.Slider(minimum=0.0, maximum=30.0, step=0.5, label='Hires CFG', value=0.0, elem_id="txt2img_hfp_cfg", interactive=True)
        elif elem_id == "txt2img_hr_resize_y":
            self.hfp_rolling_factor = gr.Slider(minimum=1.0, maximum=2.0, step=0.05, label='Rolling factor', value=1.0, elem_id="txt2img_hfp_rolling_factor", interactive=True)

    def ui(self, is_img2img):
        """Register the infotext bindings and return the controls passed to ``process``.

        Returns ``None`` on an unsupported WebUI, where no controls were created.
        """
        if low:
            return None

        self.infotext_fields = [
            (self.hfp_cfg, "Hires CFG"),
            (self.hfp_rolling_factor, "Rolling factor"),
        ]

        # paste_field_names holds infotext field *names*; (component, name) tuples
        # never match the name lookup done for the "Send to" buttons.
        self.paste_field_names = [name for _, name in self.infotext_fields]

        return [self.hfp_cfg, self.hfp_rolling_factor]

    def process(self, p, hfp_cfg: float = 0, hfp_rolling_factor: float = 1.0):
        """Install the progressive sampler and attach this script's settings to ``p``.

        Args:
            p: the processing object for the current generation.
            hfp_cfg: value of the "Hires CFG" slider; 0 means "use ``p.cfg_scale``".
            hfp_rolling_factor: value of the "Rolling factor" slider.
        """
        if low or not p.enable_hr:
            return

        print('Hijacking Hires. fix... ')
        StableDiffusionProcessingTxt2Img.sample = sample

        # Keep the user's value for the infotext before replacing it with the
        # step count that will really be used.
        self.hr_step = p.hr_second_pass_steps
        p.hr_second_pass_steps = get_steps(p)

        # A value already present on p (set by the X/Y/Z plot axis) wins over the
        # slider; 0 everywhere means "same as the first pass".
        hires_cfg = getattr(p, 'hfp_cfg', 0) or hfp_cfg or p.cfg_scale
        p.hfp_cfg = hires_cfg
        p.hfp_rolling_factor = hfp_rolling_factor

        if hires_cfg != p.cfg_scale:
            # Record the CFG that is actually applied, not the raw slider value.
            p.extra_generation_params["Hires CFG"] = hires_cfg
        if hfp_rolling_factor != 1:
            p.extra_generation_params["Rolling factor"] = hfp_rolling_factor

    def process_batch(self, p, *args, **kwargs):
        """Report the user-entered hires steps in the infotext (``None`` when 0)."""
        if not low and p.enable_hr:
            p.extra_generation_params["Hires steps"] = self.hr_step if self.hr_step != 0 else None

    def postprocess(self, p, processed, *args):
        """Restore the original ``sample`` implementation."""
        if not low and p.enable_hr:
            StableDiffusionProcessingTxt2Img.sample = sample_org
|
|
def create_script_items():
    """Add "Hires Sampler" and "Hires CFG" axes to the X/Y/Z plot script.

    Best effort: any failure (for example the plot script not being loaded) is
    printed and otherwise ignored. Options whose label is already registered are
    skipped, so calling this again does not add duplicates.
    """
    # Imported here because the module does not import traceback at top level;
    # without it the except handler below would itself fail with NameError.
    import traceback

    try:
        xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module

        def apply_hires_cfg(p, x, xs):
            setattr(p, "hfp_cfg", x)

        def apply_hires_sampler(p, x, xs):
            hr_sampler = sd_samplers.samplers_map.get(x.lower(), None)
            if hr_sampler is None:
                raise RuntimeError(f"Unknown sampler: {x} ")
            setattr(p, "hr_sampler_name", hr_sampler)

        extra_axis_options = [
            xyz_grid.AxisOptionTxt2Img("Hires Sampler", str, apply_hires_sampler, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
            xyz_grid.AxisOptionTxt2Img("Hires CFG", float, apply_hires_cfg),
        ]

        # Compare against the labels that are actually being added; a marker that
        # none of these labels contains can never detect an earlier registration.
        known_labels = {option.label for option in xyz_grid.axis_options}
        xyz_grid.axis_options.extend(option for option in extra_axis_options if option.label not in known_labels)
    except Exception as e:
        traceback.print_exc()
        print(f"Failed to add support for X/Y/Z Plot Script because: {e} ")
|
|
def create_settings_items():
    """Register this extension's options in the "Hires. fix+" settings section."""
    section_hfp = ('hiresfix_plus', 'Hires. fix+')

    # Lower bound for the automatically estimated hires step count.
    opts.add_option("hfp_smartstep_min", shared.OptionInfo(
        9, "If Smart-Step is enabled, the number of iterations for Hires. fix will never be less than this:",
        gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, section=section_hfp
    ))
    # Intermediate images are only written when the WebUI option
    # save_images_before_highres_fix is on, so the label names that option.
    opts.add_option("hfp_save_every_image", shared.OptionInfo(
        False, "If \"Save a copy of image before applying highres fix.\" is enabled, save every image during rolling generation", section=section_hfp
    ))
    opts.add_option("hfp_jitter_seeds", shared.OptionInfo(
        False, "Jitter the seeds of sub-generations when doing a rolling generation (Still deterministic)", section=section_hfp
    ))
    # Amount added to every seed per rolling round when jitter is enabled.
    opts.add_option("hfp_jitter_step", shared.OptionInfo(
        1, "Jitter step:",
        gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}, section=section_hfp
    ))
    opts.add_option("hfp_random_seeds", shared.OptionInfo(
        False, "Use random seeds for sub-generations when doing a rolling generation (WARNING!!! The result will be non-deterministic!!!)", section=section_hfp
    ))
|
|
# Hook the extension into the WebUI only when the running version is supported;
# otherwise just tell the user why nothing was registered.
if low:
    print(f'Hires.fix+ requires WebUI v{suppver} or later. But you have {program_version()}, please update it. ')
else:
    # Use the directly imported script_callbacks for both hooks instead of reaching
    # it through modules.scripts for one of them.
    script_callbacks.on_ui_settings(create_settings_items)
    script_callbacks.on_before_ui(create_script_items)
|
|