|
|
from modules.processing import StableDiffusionProcessingTxt2Img as txt2img |
|
|
from modules.processing import StableDiffusionProcessingImg2Img as img2img |
|
|
from modules.sd_samplers_kdiffusion import KDiffusionSampler as KSampler |
|
|
from modules.script_callbacks import on_script_unloaded |
|
|
from modules.ui_components import InputAccordion |
|
|
from modules import scripts |
|
|
|
|
|
from lib_resharpen.scaling import apply_scaling |
|
|
from lib_resharpen.param import ReSharpenParams |
|
|
from lib_resharpen.xyz import xyz_support |
|
|
|
|
|
from functools import wraps |
|
|
from typing import Callable |
|
|
import gradio as gr |
|
|
import torch |
|
|
|
|
|
|
|
|
# Reference to the sampler's untouched per-step callback: the hijack delegates
# to it, and it is what gets reinstated when the script is unloaded.
original_callback: Callable = KSampler.callback_state


@torch.inference_mode()
@wraps(original_callback)
def hijack_callback(self, d: dict):
    """Push the latent along (or against) its last step-to-step change.

    Replacement for ``KSampler.callback_state``. When ReSharpen parameters are
    attached to the sampler, the difference between the current latent
    ``d["x"]`` and the latent cached by the previous call is multiplied by
    ``apply_scaling(...)`` and added to ``d["x"]`` in place; the adjusted
    latent is then cached for the next call. The original callback is always
    invoked last and its result is returned unchanged.

    Args:
        self: The sampler instance whose callback fired.
        d: Callback payload. ``d["x"]`` (latent tensor) is read and modified
            in place; ``d["i"]`` is read as the current step index.

    Returns:
        Whatever the original ``callback_state`` returns.
    """
    # NOTE(review): `_ad_inner` presumably marks a nested processing pass
    # started by another extension; such passes are left untouched.
    if getattr(self.p, "_ad_inner", False):
        return original_callback(self, d)

    params: ReSharpenParams = getattr(self, "resharpen_params", None)
    if not params:
        return original_callback(self, d)

    previous = params.cache

    # A cached latent seen at step 0 was left behind by an earlier sampling
    # run (the parameters outlive a single run, e.g. with several batches),
    # so it must not be diffed against the first latent of the new run.
    # Assumes `d["i"]` is the zero-based step index supplied by the sampler.
    if previous is not None and d["i"] > 0:
        # The subtraction already allocates a new tensor; no clone is needed.
        delta: torch.Tensor = d["x"] - previous
        d["x"] += delta * apply_scaling(
            params.scaling, params.strength, d["i"], params.total_step
        )

    # Snapshot the (possibly adjusted) latent for the next step's delta.
    params.cache = d["x"].detach().clone()
    return original_callback(self, d)


# Install the hijack for every KSampler instance.
KSampler.callback_state = hijack_callback
|
|
|
|
|
|
|
|
class ReSharpen(scripts.Script):
    """Always-visible script exposing the ReSharpen controls.

    The script collects the UI (or X/Y/Z override) values and publishes them
    as ``KSampler.resharpen_params``, where the hijacked sampler callback
    picks them up.
    """

    # Choices shared by both "Scaling Settings" radios.
    SCALING_MODES = ("Flat", "Cos", "Sin", "1 - Cos", "1 - Sin")

    def __init__(self):
        # Per-run overrides; `process` / `before_hr` read the keys "decay",
        # "scaling", "hr_decay" and "hr_scaling" from it.
        # NOTE(review): presumably filled by the X/Y/Z plot integration set up
        # in `xyz_support` - confirm in lib_resharpen.xyz.
        self.XYZ_CACHE = {}
        xyz_support(self.XYZ_CACHE)

    def title(self):
        """Return the display name of the script."""
        return "ReSharpen"

    def show(self, is_img2img):
        """Show the script in both tabs without requiring selection."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the controls and register them for infotext pasting.

        Returns:
            ``[enable, decay, scaling, hr_decay, hr_scaling]`` - the same
            order in which `process` and `before_hr` receive their values.
        """
        sharpness_range = dict(minimum=-1.0, maximum=1.0, step=0.1, value=0.0)

        with InputAccordion(value=False, label=self.title()) as enable:

            gr.Markdown(
                '<h3 style="float: left;">Softer</h3> <h3 style="float: right;">Sharper</h3>'
            )

            with gr.Group(elem_classes="resharpen"):
                decay = gr.Slider(label="Sharpness", **sharpness_range)
                scaling = gr.Radio(
                    list(self.SCALING_MODES),
                    label="Scaling Settings",
                    value="Flat",
                )

            # The Hires. fix controls are hidden in the img2img tab.
            with gr.Group(elem_classes=["resharpen", "hr"], visible=(not is_img2img)):
                hr_decay = gr.Slider(label="Hires. Fix Sharpness", **sharpness_range)
                hr_scaling = gr.Radio(
                    list(self.SCALING_MODES),
                    label="Scaling Settings",
                    value="Flat",
                )

        self.paste_field_names = []
        self.infotext_fields = [
            (enable, "Resharpen Enabled"),
            (decay, "Resharpen Sharpness"),
            (scaling, "Resharpen Scaling"),
            (hr_decay, "Resharpen Sharpness Hires"),
            (hr_scaling, "Resharpen Scaling Hires"),
        ]

        for comp, name in self.infotext_fields:
            comp.do_not_save_to_config = True
            self.paste_field_names.append(name)

        return [enable, decay, scaling, hr_decay, hr_scaling]

    @staticmethod
    def _install_params(enable, scaling: str, decay: float, total_step: int) -> None:
        """Publish a fresh parameter set (with an empty cache) on KSampler."""
        params = ReSharpenParams(
            enable,
            scaling,
            # Slider value is divided by -10 to obtain the strength.
            decay / -10.0,
            total_step,
            None,
        )
        setattr(KSampler, "resharpen_params", params)

    def process(
        self,
        p,
        enable: bool,
        decay: float,
        scaling: str,
        hr_decay: float,
        hr_scaling: str,
        *args,
        **kwargs,
    ):
        """Install the first-pass parameters and record them in the infotext.

        When disabled, any previously installed parameters are removed from
        ``KSampler`` instead. Values found in ``self.XYZ_CACHE`` take
        precedence over the UI values.

        Returns:
            The processing object ``p`` that was passed in.
        """
        if not enable:
            if hasattr(KSampler, "resharpen_params"):
                delattr(KSampler, "resharpen_params")
            self.XYZ_CACHE.clear()
            return p

        if p.sampler_name.strip() == "Euler a":
            print("\n[ReSharpen] has little effect with Ancestral samplers!\n")

        decay = float(self.XYZ_CACHE.pop("decay", decay))
        scaling = str(self.XYZ_CACHE.pop("scaling", scaling))

        p.extra_generation_params.update(
            {
                "Resharpen Enabled": enable,
                "Resharpen Sharpness": decay,
                "Resharpen Scaling": scaling,
            }
        )

        # The step count is not known here; it is filled in by
        # `process_before_every_sampling`.
        self._install_params(enable, scaling, decay, 0)

        if isinstance(p, img2img):
            self.XYZ_CACHE.clear()
            return p

        assert isinstance(p, txt2img)

        if not getattr(p, "enable_hr", False):
            self.XYZ_CACHE.clear()
            return p

        # Only peek at the Hires. fix overrides: `before_hr` consumes them.
        hr_decay = float(self.XYZ_CACHE.get("hr_decay", hr_decay))
        hr_scaling = str(self.XYZ_CACHE.get("hr_scaling", hr_scaling))
        p.extra_generation_params.update(
            {
                "Resharpen Sharpness Hires": hr_decay,
                "Resharpen Scaling Hires": hr_scaling,
            }
        )

        return p

    def process_before_every_sampling(self, p, enable, *args, **kwargs):
        """Refresh the per-pass state of the installed parameters.

        Sets ``total_step`` from ``p`` and drops the cached latent so a new
        sampling pass never computes a delta against a latent left behind by
        the previous pass. Does nothing when disabled or when no parameters
        are installed.
        """
        if not enable:
            return

        # Mirror the hijacked callback: tolerate missing parameters instead
        # of raising AttributeError.
        params = getattr(KSampler, "resharpen_params", None)
        if params is None:
            return

        params.total_step = getattr(p, "firstpass_steps", 0) or p.steps
        params.cache = None

    def before_hr(
        self,
        p,
        enable: bool,
        decay: float,
        scaling: str,
        hr_decay: float,
        hr_scaling: str,
        *args,
        **kwargs,
    ):
        """Replace the installed parameters with the Hires. fix settings.

        Returns:
            The processing object ``p`` that was passed in.
        """
        if not enable:
            return p

        hr_decay = float(self.XYZ_CACHE.pop("hr_decay", hr_decay))
        hr_scaling = str(self.XYZ_CACHE.pop("hr_scaling", hr_scaling))

        self._install_params(
            enable,
            hr_scaling,
            hr_decay,
            getattr(p, "hr_second_pass_steps", 0) or p.steps,
        )

        self.XYZ_CACHE.clear()
        return p
|
|
|
|
|
|
|
|
def restore_callback() -> None:
    """Put the sampler's original ``callback_state`` back in place."""
    setattr(KSampler, "callback_state", original_callback)


# Undo the monkey-patch when the script is unloaded.
on_script_unloaded(restore_callback)
|
|
|