| import gradio as gr |
|
|
| from modules import scripts |
| import modules.shared as shared |
| import torch, math |
| import torchvision.transforms.functional as TF |
|
|
| |
|
|
| class driftrForge(scripts.Script): |
| def __init__(self): |
| self.method1 = "None" |
| self.method2 = "None" |
|
|
| def title(self): |
| return "Latent Drift Correction" |
|
|
| def show(self, is_img2img): |
| |
| return scripts.AlwaysVisible |
|
|
| def ui(self, *args, **kwargs): |
| with gr.Accordion(open=False, label=self.title()): |
| with gr.Row(): |
| method1 = gr.Dropdown(["None", "custom", "mean", "median", "mean/median average", "centered mean", "average of extremes", "average of quantiles"], value="None", type="value", label='Correction method (per channel)') |
| method2 = gr.Dropdown(["None", "mean", "median", "mean/median average", "center to quantile", "local average"], value="None", type="value", label='Correction method (overall)') |
| with gr.Row(): |
| strengthC = gr.Slider(minimum=-1.0, maximum=1.0, step=0.01, value=1.0, label='strength (per channel)') |
| strengthO = gr.Slider(minimum=-1.0, maximum=1.0, step=0.01, value=0.8, label='strength (overall)') |
| with gr.Row(equalHeight=True): |
| custom = gr.Textbox(value='0.5 * (M + m)', max_lines=1, label='custom function', visible=True) |
| topK = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, value=0.5, label='quantiles', visible=False, scale=0) |
| blur = gr.Slider(minimum=0, maximum=128, step=1, value=0, label='blur radius (x8)', visible=False, scale=0) |
| sigmaWeight = gr.Dropdown(["Hard", "Soft", "None"], value="Hard", type="value", label='Limit effect by sigma', scale=0) |
| with gr.Row(): |
| stepS = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.0, label='Start step') |
| stepE = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label='End step') |
| with gr.Row(): |
| softClampS = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label='Soft clamp start step') |
| softClampE = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label='Soft clamp end step') |
|
|
| def show_topK(m1, m2): |
| if m1 == "centered mean" or m1 == "average of extremes" or m1 == "average of quantiles": |
| return gr.update(visible=True), gr.update(visible=False) |
| elif m2 == "center to quantile": |
| return gr.update(visible=True), gr.update(visible=False) |
| elif m2 == "local average": |
| return gr.update(visible=False), gr.update(visible=True) |
| else: |
| return gr.update(visible=False), gr.update(visible=False) |
|
|
| method1.change( |
| fn=show_topK, |
| inputs=[method1, method2], |
| outputs=[topK, blur], |
| show_progress=False |
| ) |
| method2.change( |
| fn=show_topK, |
| inputs=[method1, method2], |
| outputs=[topK, blur], |
| show_progress=False |
| ) |
|
|
| self.infotext_fields = [ |
| (method1, "ldc_method1"), |
| (method2, "ldc_method2"), |
| (topK, "ldc_topK"), |
| (blur, "ldc_blur"), |
| (strengthC, "ldc_strengthC"), |
| (strengthO, "ldc_strengthO"), |
| (stepS, "ldc_stepS"), |
| (stepE, "ldc_stepE"), |
| (sigmaWeight, "ldc_sigW"), |
| (softClampS, "ldc_softClampS"), |
| (softClampE, "ldc_softClampE"), |
| (custom, "ldc_custom"), |
| ] |
|
|
| return method1, method2, topK, blur, strengthC, strengthO, stepS, stepE, sigmaWeight, softClampS, softClampE, custom |
|
|
|
|
| def patch(self, model): |
| model_sampling = model.model.model_sampling |
| sigmin = model_sampling.sigma(model_sampling.timestep(model_sampling.sigma_min)) |
| sigmax = model_sampling.sigma(model_sampling.timestep(model_sampling.sigma_max)) |
|
|
|
|
| |
| def soft_clamp_tensor(input_tensor, threshold=3.5, boundary=4): |
| if max(abs(input_tensor.max()), abs(input_tensor.min())) < 4: |
| return input_tensor |
| channel_dim = 1 |
|
|
| max_vals = input_tensor.max(channel_dim, keepdim=True)[0] |
| max_replace = ((input_tensor - threshold) / (max_vals - threshold)) * (boundary - threshold) + threshold |
| over_mask = (input_tensor > threshold) |
|
|
| min_vals = input_tensor.min(channel_dim, keepdim=True)[0] |
| min_replace = ((input_tensor + threshold) / (min_vals + threshold)) * (-boundary + threshold) - threshold |
| under_mask = (input_tensor < -threshold) |
|
|
| return torch.where(over_mask, max_replace, torch.where(under_mask, min_replace, input_tensor)) |
|
|
| def center_latent_mean_values(latent, multiplier): |
| thisStep = shared.state.sampling_step |
| lastStep = shared.state.sampling_steps |
|
|
| channelMultiplier = multiplier * self.strengthC |
| fullMultiplier = multiplier * self.strengthO |
| |
| if thisStep >= self.stepS * lastStep and thisStep <= self.stepE * lastStep: |
| for b in range(len(latent)): |
| for c in range(4): |
| custom = None |
| channel = latent[b][c] |
|
|
| if self.method1 == "mean": |
| custom = "M" |
| |
| |
|
|
| elif self.method1 == "median": |
| custom = "m" |
| |
| |
|
|
| elif self.method1 == "mean/median average": |
| custom = "0.5 * (M+m)" |
| |
| |
| |
|
|
| elif self.method1 == "centered mean": |
| custom="rM(self.topK, 1.0-self.topK)" |
| |
| |
| |
| |
| |
| |
| |
|
|
| elif self.method1 == "average of extremes": |
| custom="0.5 * (inner_rL(self.topK) + inner_rH(self.topK))" |
| |
| |
| |
| |
|
|
| elif self.method1 == "average of quantiles": |
| custom="0.5 * (q(self.topK) + q(1.0-self.topK))" |
| |
| |
|
|
| elif self.method1 == "custom": |
| custom = self.custom |
|
|
| if custom != None: |
| M = torch.mean(channel) |
| m = torch.quantile(channel, 0.5) |
| def q(quant): |
| return torch.quantile(channel, quant) |
| def qa(quant): |
| return torch.quantile(abs(channel), quant) |
| def inner_rL(lo): |
| valuesLo = torch.topk(channel, int(len(channel)*lo), largest=False).values |
| return torch.mean(valuesLo).item() |
| def inner_rH(hi): |
| valuesHi = torch.topk(channel, int(len(channel)*hi), largest=True).values |
| return torch.mean(valuesHi).item() |
| |
| def rM(rangelo, rangehi): |
| if rangelo == rangehi: |
| return M |
| else: |
| averageHi = inner_rH(1.0-rangehi) |
| averageLo = inner_rL(rangelo) |
|
|
| average = torch.mean(channel).item() * len(channel) |
| average -= averageLo * len(channel) * rangelo |
| average -= averageHi * len(channel) * (1.0-rangehi) |
| average /= len(channel)*(rangehi - rangelo) |
| return average |
|
|
| averageMid = eval(custom) |
| latent[b][c] -= averageMid * channelMultiplier |
| |
| if self.method2 == "mean": |
| latent[b] -= latent[b].mean() * fullMultiplier |
| elif self.method2 == "median": |
| latent[b] -= latent[b].median() * fullMultiplier |
| elif self.method2 == "mean/median average": |
| mm = latent[b].mean() + latent[b].median() |
| latent[b] -= 0.5 * fullMultiplier * mm |
| elif self.method2 == "center to quantile": |
| quantile = torch.quantile(latent[b].flatten(), self.topK) |
| latent[b] -= quantile * fullMultiplier |
| elif self.method2 == "local average" and fullMultiplier != 0.0 and self.blur != 0: |
| minDim = min(latent.size(2), latent.size(3)) |
| if minDim % 2 == 0: |
| minDim -= 1 |
| blurSize = min (minDim, 1+self.blur+self.blur) |
| |
| blurred = TF.gaussian_blur(latent[b], blurSize) |
| torch.lerp(latent[b], blurred, fullMultiplier, out=latent[b]) |
| del blurred |
|
|
|
|
| if thisStep >= self.softClampS * lastStep and thisStep <= self.softClampE * lastStep: |
| for b in range(len(latent)): |
| latent[b] = soft_clamp_tensor (latent[b]) |
|
|
|
|
| return latent |
|
|
|
|
| def map_sigma(sigma, sigmax, sigmin): |
| return (sigma - sigmin) / (sigmax - sigmin) |
|
|
| def center_mean_latent_post_cfg(args): |
| denoised = args["denoised"] |
| sigma = args["sigma"][0] |
|
|
| if self.sigmaWeight == "None": |
| mult = 1 |
| else: |
| mult = map_sigma(sigma, sigmax, sigmin) |
| if self.sigmaWeight == "Soft": |
| mult += 1.0 |
| mult /= 2.0 |
|
|
| denoised = center_latent_mean_values(denoised, mult) |
| return denoised |
|
|
| m = model.clone() |
| m.set_model_sampler_post_cfg_function(center_mean_latent_post_cfg) |
|
|
| return (m, ) |
|
|
|
|
| def process(self, params, *script_args, **kwargs): |
| method1, method2, topK, blur, strengthC, strengthO, stepS, stepE, sigmaWeight, softClampS, softClampE, custom = script_args |
|
|
| if method1 == "None" and method2 == "None": |
| return |
|
|
| self.method1 = method1 |
| self.method2 = method2 |
| self.topK = topK |
| self.blur = blur |
| self.strengthC = strengthC |
| self.strengthO = strengthO |
| self.stepS = stepS |
| self.stepE = stepE |
| self.sigmaWeight = sigmaWeight |
| self.softClampS = softClampS |
| self.softClampE = softClampE |
| self.custom = custom |
| |
| |
| |
| params.extra_generation_params.update(dict( |
| ldc_method1 = method1, |
| ldc_method2 = method2, |
| ldc_strengthC = strengthC, |
| ldc_strengthO = strengthO, |
| ldc_stepS = stepS, |
| ldc_stepE = stepE, |
| ldc_sigW = sigmaWeight, |
| ldc_softClampS = softClampS, |
| ldc_softClampE = softClampE, |
| )) |
| if method1 == "custom": |
| params.extra_generation_params.update(dict(ldc_custom = custom, )) |
| if method1 == "centered mean" or method1 == "average of extremes" or method1 == "average of quantiles" or method2 == "center to quantile": |
| params.extra_generation_params.update(dict(ldc_topK = topK, )) |
| if method2 == "local average": |
| params.extra_generation_params.update(dict(ldc_blur = blur, )) |
|
|
| return |
|
|
|
|
| def process_before_every_sampling(self, params, *script_args, **kwargs): |
| method1, method2 = script_args[0], script_args[1] |
| if method1 != "None" or method2 != "None": |
| unet = params.sd_model.forge_objects.unet |
| unet = driftrForge.patch(self, unet)[0] |
| params.sd_model.forge_objects.unet = unet |
|
|
| return |
|
|