from modules.processing import StableDiffusionProcessingTxt2Img as txt2img
from modules.processing import StableDiffusionProcessingImg2Img as img2img
from modules.sd_samplers_kdiffusion import KDiffusionSampler as KSampler
from modules.script_callbacks import on_script_unloaded
from modules.ui_components import InputAccordion
from modules import scripts

from lib_resharpen.scaling import apply_scaling
from lib_resharpen.param import ReSharpenParams
from lib_resharpen.xyz import xyz_support

from functools import wraps
from typing import Callable
import gradio as gr
import torch


# Choices shared by the first-pass and Hires. Fix "Scaling Settings" radios.
_SCALING_CHOICES = ("Flat", "Cos", "Sin", "1 - Cos", "1 - Sin")

# The unpatched KDiffusionSampler.callback_state; restored on script unload.
original_callback: Callable = KSampler.callback_state


def _is_ancestral(sampler_name) -> bool:
    """Return True if *sampler_name* looks like an ancestral sampler.

    Matches a standalone "a" token ("Euler a", "DPM++ 2S a") or the word
    "ancestral", case-insensitively. A missing name is treated as non-ancestral.
    """
    tokens = (sampler_name or "").lower().split()
    return "a" in tokens or "ancestral" in tokens


def _remove_params() -> None:
    """Detach any ReSharpen parameters from the sampler class."""
    if hasattr(KSampler, "resharpen_params"):
        delattr(KSampler, "resharpen_params")


@torch.inference_mode()
@wraps(original_callback)
def hijack_callback(self, d: dict):
    """Replacement for ``KSampler.callback_state`` that rescales every step.

    When ReSharpen parameters are attached to the sampler, the difference
    between the current ``d["x"]`` and the copy cached at the previous callback
    is multiplied by ``apply_scaling(scaling, strength, d["i"], total_step)``
    and added back to ``d["x"]`` in place. The original callback always runs
    afterwards with the (possibly modified) dict.
    """
    # `_ad_inner` is presumably set by ADetailer for its inner passes; those
    # are left untouched.
    if getattr(self.p, "_ad_inner", False):
        return original_callback(self, d)

    params: ReSharpenParams = getattr(self, "resharpen_params", None)
    if not params:
        return original_callback(self, d)

    # The first callback of a pass has nothing cached to diff against.
    if params.cache is not None:
        delta: torch.Tensor = d["x"] - params.cache
        d["x"] += delta * apply_scaling(
            params.scaling, params.strength, d["i"], params.total_step
        )

    params.cache = d["x"].detach().clone()
    return original_callback(self, d)


KSampler.callback_state = hijack_callback


class ReSharpen(scripts.Script):
    """Always-visible script exposing ReSharpen's first-pass and Hires. Fix
    Sharpness / Scaling settings and attaching them to the sampler."""

    def __init__(self):
        # Overrides injected through lib_resharpen.xyz; this class reads the
        # keys "decay", "scaling", "hr_decay" and "hr_scaling".
        self.XYZ_CACHE = {}
        xyz_support(self.XYZ_CACHE)
        # First-pass parameters of the current job. process_batch() re-attaches
        # them (with an empty cache) before every batch iteration.
        self._first_pass_params = None
        # (hr_decay, hr_scaling) resolved once in process(), so every batch's
        # before_hr() uses the same, possibly X/Y/Z-overridden, values.
        self._hr_settings = None

    def title(self):
        return "ReSharpen"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the settings accordion.

        Returns ``[enable, decay, scaling, hr_decay, hr_scaling]``, the order in
        which the hooks below receive their script arguments.
        """
        with InputAccordion(value=False, label=self.title()) as enable:
            # NOTE(review): only the plain "Softer"/"Sharper" labels are visible
            # in this copy of the file; any HTML that positioned them may have
            # been lost — confirm against upstream.
            gr.Markdown("\n\nSofter\n\nSharper\n\n")
            with gr.Group(elem_classes="resharpen"):
                decay = gr.Slider(
                    label="Sharpness", minimum=-1.0, maximum=1.0, step=0.1, value=0.0
                )
                scaling = gr.Radio(
                    list(_SCALING_CHOICES), label="Scaling Settings", value="Flat"
                )
            # Hidden on the img2img tab, where the Hires. Fix hook is not used.
            with gr.Group(elem_classes=["resharpen", "hr"], visible=(not is_img2img)):
                hr_decay = gr.Slider(
                    label="Hires. Fix Sharpness",
                    minimum=-1.0,
                    maximum=1.0,
                    step=0.1,
                    value=0.0,
                )
                hr_scaling = gr.Radio(
                    list(_SCALING_CHOICES), label="Scaling Settings", value="Flat"
                )

        self.infotext_fields = [
            (enable, "Resharpen Enabled"),
            (decay, "Resharpen Sharpness"),
            (scaling, "Resharpen Scaling"),
            (hr_decay, "Resharpen Sharpness Hires"),
            (hr_scaling, "Resharpen Scaling Hires"),
        ]
        self.paste_field_names = [name for _, name in self.infotext_fields]
        for comp, _ in self.infotext_fields:
            comp.do_not_save_to_config = True

        return [enable, decay, scaling, hr_decay, hr_scaling]

    def process(
        self,
        p,
        enable: bool,
        decay: float,
        scaling: str,
        hr_decay: float,
        hr_scaling: str,
        *args,
        **kwargs,
    ):
        """Resolve this job's settings and attach the first-pass parameters.

        X/Y/Z overrides in ``XYZ_CACHE`` take precedence over the UI values and
        are consumed here. A Sharpness value ``decay`` becomes the strength
        ``decay / -10.0``. When disabled, any attached parameters are removed.
        Returns ``p``.
        """
        # Drop state left over from the previous job.
        self._first_pass_params = None
        self._hr_settings = None

        if not enable:
            _remove_params()
            self.XYZ_CACHE.clear()
            return p

        if _is_ancestral(p.sampler_name):
            print("\n[ReSharpen] has little effect with Ancestral samplers!\n")

        decay = float(self.XYZ_CACHE.pop("decay", decay))
        scaling = str(self.XYZ_CACHE.pop("scaling", scaling))

        p.extra_generation_params.update(
            {
                "Resharpen Enabled": enable,
                "Resharpen Sharpness": decay,
                "Resharpen Scaling": scaling,
            }
        )

        params = ReSharpenParams(
            enable,
            scaling,
            decay / -10.0,
            0,  # total_step: filled in by process_before_every_sampling()
            None,  # cache: empty until the first callback
        )
        self._first_pass_params = params
        setattr(KSampler, "resharpen_params", params)

        if isinstance(p, img2img):
            self.XYZ_CACHE.clear()
            return p

        assert isinstance(p, txt2img)
        if not getattr(p, "enable_hr", False):
            self.XYZ_CACHE.clear()
            return p

        # Resolve the Hires. Fix settings once for the whole job; before_hr()
        # runs once per batch iteration and reuses them.
        hr_decay = float(self.XYZ_CACHE.pop("hr_decay", hr_decay))
        hr_scaling = str(self.XYZ_CACHE.pop("hr_scaling", hr_scaling))
        self._hr_settings = (hr_decay, hr_scaling)
        self.XYZ_CACHE.clear()

        p.extra_generation_params.update(
            {
                "Resharpen Sharpness Hires": hr_decay,
                "Resharpen Scaling Hires": hr_scaling,
            }
        )
        return p

    def process_batch(self, p, *args, **kwargs):
        """Re-attach the first-pass parameters, with an empty cache.

        before_hr() replaces the attached parameters, and the callback leaves
        the last step's latent in ``cache``. Without this reset, a later batch
        iteration would diff its first step against the previous image and,
        with Hires. Fix, sample its first pass using the Hires. settings.
        """
        params = self._first_pass_params
        if params is None:
            return
        params.cache = None
        setattr(KSampler, "resharpen_params", params)

    def process_before_every_sampling(self, p, enable, *args, **kwargs):
        """Set the first-pass ``total_step`` from the step count of ``p``."""
        params = getattr(KSampler, "resharpen_params", None)
        if not enable or params is None:
            return
        # Leave Hires. Fix parameters alone: before_hr() already derived their
        # total_step from hr_second_pass_steps.
        if params is self._first_pass_params:
            params.total_step = getattr(p, "firstpass_steps", 0) or p.steps

    def before_hr(
        self,
        p,
        enable: bool,
        decay: float,
        scaling: str,
        hr_decay: float,
        hr_scaling: str,
        *args,
        **kwargs,
    ):
        """Attach fresh Hires. Fix parameters before the second pass.

        Uses the settings resolved in process(); falls back to X/Y/Z overrides
        or the UI values if process() did not record any. Returns ``p``.
        """
        if not enable:
            return p

        if self._hr_settings is not None:
            hr_decay, hr_scaling = self._hr_settings
        else:
            hr_decay = float(self.XYZ_CACHE.pop("hr_decay", hr_decay))
            hr_scaling = str(self.XYZ_CACHE.pop("hr_scaling", hr_scaling))

        params = ReSharpenParams(
            enable,
            hr_scaling,
            hr_decay / -10.0,
            getattr(p, "hr_second_pass_steps", 0) or p.steps,
            None,
        )
        setattr(KSampler, "resharpen_params", params)

        self.XYZ_CACHE.clear()
        return p


def restore_callback():
    """Undo the monkey-patch and drop attached parameters on script unload."""
    KSampler.callback_state = original_callback
    _remove_params()


on_script_unloaded(restore_callback)