| """ |
| scripts/mega_freeu.py - Mega FreeU for A1111 / Forge |
| |
| Combined from 5 sources: |
| 1. sd-webui-freeu th.cat hijack, V1/V2 backbone, box filter, schedule, |
| presets JSON, PNG metadata, XYZ, ControlNet, region masking, |
| dict-API compat (alwayson_scripts legacy) |
| 2. WAS FreeU_Advanced 9 blending modes, 13 multi-scale FFT presets, override_scales, |
| Post-CFG Shift (WAS_PostCFGShift ported to A1111 callback) |
| NOTE: target_block / input_block / middle_block / slice_b1/b2 |
| were not ported — th.cat hijack works on output-side skip concat. |
| 3. ComfyUI_FreeU_V2_Adv Gaussian filter, Adaptive Cap (MAX_CAP_ITER=3), |
| independent B/S timestep ranges per-stage, channel_threshold |
| 4. FreeU_V2_timestepadd b_start/b_end%, s_start/s_end% per-stage gating |
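   NOTE: gating uses the step fraction (cur/total) rather than
   percent_to_sigma as in the original ComfyUI sources. The two are only
   approximately equivalent: how a step fraction maps to sigma depends on
   the sampler's noise schedule.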
| 5. nrs_kohaku_v3.5 hf_boost param, on_cpu_devices dict, gaussian standalone |
| |
| BUGS FIXED vs sdwebui-freeU-extension: |
| BUG 1: bool mask in Fourier filter (scale multiplication was NOOP) |
| BUG 2: single-quadrant mask instead of symmetric center |
| """ |
| import dataclasses |
| import json |
| from typing import List |
|
|
| import gradio as gr |
| from modules import script_callbacks, scripts, shared, processing |
|
|
| from lib_mega_freeu import global_state, unet, xyz_grid |
|
|
_TABS = ("txt2img", "img2img")

# Per-tab Steps slider component, captured by _on_after_component once the
# WebUI creates it.
_steps_comps = dict.fromkeys(_TABS)
# Per-tab callbacks waiting for the Steps slider; used when the script UI is
# built before the slider exists and flushed by _on_after_component.
_steps_cbs = {tab: [] for tab in _TABS}

# StageInfo field names in declaration order. _flat_to_sis assigns the flat
# per-stage UI values to these fields purely by position.
_SF = [field.name for field in dataclasses.fields(global_state.StageInfo)]
_SN = len(_SF)  # number of UI values per stage
|
|
|
|
def _stage_ui(idx, si, elem_id_fn):
    """Build the controls for one FreeU stage and return its input components.

    Args:
        idx: zero-based stage index; index 0 is labelled "~1280ch (deep)".
        si: ``global_state.StageInfo`` supplying the initial control values.
        elem_id_fn: currently unused; kept so existing callers keep working.

    Returns:
        The stage's input components. ``_flat_to_sis`` assigns these values
        to ``StageInfo`` fields purely by position (``dataclasses.fields``
        order), so this list must stay in the dataclass declaration order.
    """
    n = idx + 1
    ch = {0: "~1280ch (deep)", 1: "~640ch (mid)", 2: "~320ch (shallow)"}.get(idx, f"stage{n}")

    with gr.Accordion(open=(idx < 2), label=f"Stage {n} ({ch})"):
        # Backbone (B) scaling.
        gr.HTML("<p style='margin:4px 0;font-size:.82em;color:#aaa;'>Backbone h (B)</p>")
        with gr.Row():
            bf = gr.Slider(label=f"B{n} Scale", minimum=-1, maximum=3, step=0.001,
                           value=si.backbone_factor,
                           info=">1 strengthens backbone features. V2: adaptive per-region.")
            bo = gr.Slider(label=f"B{n} Offset", minimum=0, maximum=1, step=0.001,
                           value=si.backbone_offset, info="Channel region start [0-1].")
            bw = gr.Slider(label=f"B{n} Width", minimum=-1, maximum=1, step=0.001,
                           value=si.backbone_width, info="Channel region width. Negative=invert.")
        with gr.Row():
            bm = gr.Dropdown(label=f"B{n} Blend Mode",
                             choices=global_state.BLEND_MODE_NAMES,
                             value=si.backbone_blend_mode,
                             info="lerp=default, stable_slerp=quality, inject=additive")
            bb = gr.Slider(label=f"B{n} Blend Str", minimum=0, maximum=2, step=0.001,
                           value=si.backbone_blend)
        gr.HTML("<p style='font-size:.75em;color:#888;margin:2px 0;'>B timestep range (ComfyUI V2)</p>")
        with gr.Row():
            bsr = gr.Slider(label=f"B{n} Start%", minimum=0, maximum=1, step=0.001,
                            value=si.b_start_ratio, info="B activates at this step fraction.")
            ber = gr.Slider(label=f"B{n} End%", minimum=0, maximum=1, step=0.001,
                            value=si.b_end_ratio, info="B stops. 0.35=structure phase only.")

        # Skip (S) Fourier filtering.
        gr.HTML("<p style='margin:8px 0 4px;font-size:.82em;color:#aaa;'>Skip h_skip (S) - Fourier Filter</p>")
        with gr.Row():
            sf = gr.Slider(label=f"S{n} LF Scale", minimum=-1, maximum=3, step=0.001,
                           value=si.skip_factor,
                           info="<1 suppresses LF components. 0.2=strong suppression.")
            she = gr.Slider(label=f"S{n} HF (Box)", minimum=-1, maximum=3, step=0.001,
                            value=si.skip_high_end_factor,
                            info="HF scale outside LF region (box filter). >1=boost HF.")
            hfb = gr.Slider(label=f"S{n} HF Boost (Gauss)", minimum=0, maximum=3, step=0.001,
                            value=si.hf_boost,
                            info="Gaussian explicit HF multiplier. Combined as max(hf_boost, high_end).")
        with gr.Row():
            ft = gr.Radio(label=f"S{n} FFT Type",
                          choices=global_state.FFT_TYPES, value=si.fft_type,
                          info="gaussian=smooth no-ringing. box=original FreeU (both bugs fixed).")
            sco = gr.Slider(label=f"S{n} Cutoff (Box)", minimum=0, maximum=1, step=0.001,
                            value=si.skip_cutoff, info="Box: LF cutoff fraction. 0=1px default.")
            srr = gr.Slider(label=f"S{n} Radius (Gauss)", minimum=0.01, maximum=0.5, step=0.001,
                            value=si.fft_radius_ratio,
                            info="Gaussian R=ratio*min(H,W). 0.07=moderate LF.")
        gr.HTML("<p style='font-size:.75em;color:#888;margin:2px 0;'>S timestep range (ComfyUI V2)</p>")
        with gr.Row():
            ssr = gr.Slider(label=f"S{n} Start%", minimum=0, maximum=1, step=0.001,
                            value=si.s_start_ratio,
                            info="S activates. Tip: set = B End% for clean phase separation.")
            ser = gr.Slider(label=f"S{n} End%", minimum=0, maximum=1, step=0.001,
                            value=si.s_end_ratio, info="S stops. 1.0=to last step.")

        # Adaptive cap. Labels carry the stage prefix so each stage gets its
        # own ui-config key; A1111 keys saved UI defaults by label, and the
        # previous shared "Threshold"/"Factor"/"Mode" labels collided.
        gr.HTML("<p style='font-size:.75em;color:#888;margin:4px 0;'>Adaptive Cap - prevents LF over-attenuation (FreeU_S1S2.py)</p>")
        with gr.Row():
            eac = gr.Checkbox(label=f"S{n} Enable Cap", value=si.enable_adaptive_cap,
                              info="Iteratively weakens Gaussian if LF/HF drop exceeds threshold.")
            ct = gr.Slider(label=f"S{n} Cap Threshold", minimum=0, maximum=1, step=0.001,
                           value=si.cap_threshold, info="Max allowed LF/HF ratio drop. 0.35=35%.")
            cf = gr.Slider(label=f"S{n} Cap Factor", minimum=0, maximum=1, step=0.001,
                           value=si.cap_factor, info="Relaxation factor. 0.6=moderate.")
            cm = gr.Radio(label=f"S{n} Cap Mode", choices=["adaptive", "fixed"],
                          value=si.adaptive_cap_mode,
                          info="adaptive: scales factor with over-attenuation. fixed: always cap_factor.")

    # Order must match StageInfo field declaration order (see docstring).
    return [bf, sf, bo, bw, sco, she, bm, bb, bsr, ber, ssr, ser, ft, srr, hfb, eac, ct, cf, cm]
|
|
|
|
class MegaFreeUScript(scripts.Script):
    """AlwaysVisible A1111/Forge script exposing the Mega FreeU controls.

    ``ui`` builds the accordion and returns the component list whose order
    defines the positional ``*args`` that ``process`` unpacks. ``process``
    copies those values (or a legacy dict payload) into
    ``global_state.instance`` and arms the Post-CFG callback via
    ``p._mega_pcfg``.
    """

    # Number of Post-CFG values that follow the stage sliders in ui()'s return
    # list (enable .. gain); the verbose checkbox comes right after them.
    _N_PCFG = 11

    def title(self):
        return "Mega FreeU"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the Mega FreeU accordion and wire presets and infotext restore.

        Returns the input components in the order ``process`` unpacks them:
        enable, version, preset, schedule (3), multiscale (3), channel
        threshold, per-stage values, Post-CFG (11), verbose.
        """
        global_state.reload_presets()
        pnames = list(global_state.all_presets.keys())
        def_sis = global_state.all_presets[pnames[0]].stage_infos

        with gr.Accordion(open=False, label="Mega FreeU"):
            # --- Enable / version ----------------------------------------
            with gr.Row():
                enabled = gr.Checkbox(label="Enable Mega FreeU", value=False)
                version = gr.Dropdown(
                    label="Version",
                    choices=list(global_state.ALL_VERSIONS.keys()),
                    value="Version 2",
                    elem_id=self.elem_id("version"),
                    info="V2=adaptive hidden-mean backbone. V1=flat multiplier.")

            # --- Presets -------------------------------------------------
            with gr.Row():
                preset_dd = gr.Dropdown(
                    label="Preset", choices=pnames, value=pnames[0],
                    allow_custom_value=True,
                    elem_id=self.elem_id("preset_name"),
                    info="Apply loads settings. Custom name enables Save. Delete auto-saves.")
                btn_apply = gr.Button("Apply", size="sm", elem_classes="tool")
                btn_save = gr.Button("Save", size="sm", elem_classes="tool")
                btn_refresh = gr.Button("Refresh", size="sm", elem_classes="tool")
                btn_delete = gr.Button("Delete", size="sm", elem_classes="tool")

            # --- Global schedule -----------------------------------------
            gr.HTML("<p style='font-size:.82em;color:#aaa;margin:6px 0 2px;'>Global Schedule</p>")
            with gr.Row():
                start_r = gr.Slider(label="Start At", elem_id=self.elem_id("start_at_step"),
                                    minimum=0, maximum=1, step=0.001, value=0)
                stop_r = gr.Slider(label="Stop At", elem_id=self.elem_id("stop_at_step"),
                                   minimum=0, maximum=1, step=0.001, value=1)
                smooth = gr.Slider(label="Transition Smoothness",
                                   elem_id=self.elem_id("transition_smoothness"),
                                   minimum=0, maximum=1, step=0.001, value=0,
                                   info="0=hard on/off. 1=smooth fade.")

            # --- Box multi-scale FFT -------------------------------------
            with gr.Accordion(open=False, label="Box Multi-Scale FFT (WAS FreeU_Advanced)"):
                gr.HTML("<p style='font-size:.8em;color:#888;'>Applied on top of Box filter. Ignored in Gaussian mode.</p>")
                with gr.Row():
                    ms_mode = gr.Dropdown(label="Multiscale Mode",
                                          choices=list(global_state.MSCALES.keys()),
                                          value="Default")
                    ms_str = gr.Slider(label="Strength", minimum=0, maximum=1,
                                       step=0.001, value=1.0)
                ov_scales = gr.Textbox(
                    label="Override Scales (WAS format: radius_px, scale per line, # comments)",
                    lines=3,
                    placeholder="# Example custom scales:\n10, 1.5\n20, 0.8",
                    value="")

            with gr.Row():
                ch_thresh = gr.Slider(
                    label="Channel Match Threshold (+-)",
                    elem_id=self.elem_id("ch_thresh"),
                    minimum=0, maximum=256, step=1, value=96,
                    info="Stage channel tolerance. 96=standard (FreeU_B1B2.py default).")

            # --- Per-stage controls --------------------------------------
            flat_comps: List = []
            for i in range(global_state.STAGES_COUNT):
                si = def_sis[i] if i < len(def_sis) else global_state.StageInfo()
                flat_comps.extend(_stage_ui(i, si, self.elem_id))

            # --- Post-CFG shift ------------------------------------------
            with gr.Accordion(open=False, label="Post-CFG Shift (WAS_PostCFGShift -> A1111 callback)"):
                gr.HTML("<p style='font-size:.8em;color:#888;'>Runs after combine_denoised. Blends denoised*b into output via on_cfg_after_cfg callback.</p>")
                with gr.Row():
                    pcfg_en = gr.Checkbox(label="Enable Post-CFG Shift", value=False)
                    pcfg_steps = gr.Slider(label="Max Steps", minimum=1, maximum=200,
                                           step=1, value=20,
                                           info="Apply only to first N steps.")
                with gr.Row():
                    pcfg_mode = gr.Dropdown(label="Blend Mode",
                                            choices=global_state.BLEND_MODE_NAMES,
                                            value="inject")
                    pcfg_bl = gr.Slider(label="Blend", minimum=0, maximum=5,
                                        step=0.001, value=1.0)
                    pcfg_b = gr.Slider(label="B Factor", minimum=0, maximum=5,
                                       step=0.001, value=1.1,
                                       info=">1 amplifies shift.")
                with gr.Row():
                    pcfg_fou = gr.Checkbox(label="Apply Fourier Filter", value=False)
                    pcfg_mmd = gr.Dropdown(label="Fourier Multiscale",
                                           choices=list(global_state.MSCALES.keys()),
                                           value="Default")
                    pcfg_mst = gr.Slider(label="Fourier Strength", minimum=0, maximum=1,
                                         step=0.001, value=1.0)
                with gr.Row():
                    pcfg_thr = gr.Slider(label="Threshold (px)", minimum=1, maximum=20,
                                         step=1, value=1,
                                         info="Box filter LF radius in pixels.")
                    pcfg_s = gr.Slider(label="S Scale", minimum=0, maximum=3,
                                       step=0.001, value=0.5)
                    pcfg_gain = gr.Slider(label="Force Gain", minimum=0, maximum=5,
                                          step=0.01, value=1.0,
                                          info="Final output multiplier.")

            verbose = gr.Checkbox(label="Verbose Logging (Adaptive Cap, energy stats)", value=False)

            # --- Hidden infotext carriers (see self.infotext_fields) -----
            sched_info = gr.HTML(visible=False)
            stages_info = gr.HTML(visible=False)
            version_info = gr.HTML(visible=False)
            ms_mode_info = gr.HTML(visible=False)
            ms_str_info = gr.HTML(visible=False)
            ov_scales_info = gr.HTML(visible=False)
            ch_thresh_info = gr.HTML(visible=False)
            postcfg_info = gr.HTML(visible=False)
            verbose_info = gr.HTML(visible=False)
            # Keys written by the older sd-webui-freeu extension.
            legacy_sched_info = gr.HTML(visible=False)
            legacy_stages_info = gr.HTML(visible=False)
            legacy_version_info = gr.HTML(visible=False)

        pcfg_comps = [pcfg_en, pcfg_steps, pcfg_mode, pcfg_bl, pcfg_b,
                      pcfg_fou, pcfg_mmd, pcfg_mst, pcfg_thr, pcfg_s, pcfg_gain]
        # Non-stage controls, in the order _apply_p returns and _save_p receives them.
        preset_outputs = [start_r, stop_r, smooth, version,
                          ms_mode, ms_str, ov_scales, ch_thresh,
                          *pcfg_comps, verbose]

        # ----------------------------------------------------------------
        # Presets
        # ----------------------------------------------------------------
        def _btn_states(name):
            """Interactivity updates for (Apply, Save, Delete) given a preset name."""
            exists = name in global_state.all_presets
            is_user = name not in global_state.default_presets
            return (gr.update(interactive=exists),
                    gr.update(interactive=is_user),
                    gr.update(interactive=is_user and exists))

        def _stage_values(stage_infos):
            """Flatten stage infos into per-slider values for exactly STAGES_COUNT stages.

            Extra stages are dropped and missing ones are filled with
            ``StageInfo()`` defaults, so the result always matches ``flat_comps``.
            """
            count = global_state.STAGES_COUNT
            sis = list(stage_infos)[:count]
            sis += [global_state.StageInfo() for _ in range(count - len(sis))]
            default = global_state.StageInfo()
            return [getattr(si, f, getattr(default, f)) for si in sis for f in _SF]

        preset_dd.change(fn=_btn_states, inputs=[preset_dd],
                         outputs=[btn_apply, btn_save, btn_delete])

        def _apply_p(name):
            preset = global_state.all_presets.get(name)
            if preset is None:
                return [gr.skip()] * (len(preset_outputs) + len(flat_comps))
            values = [
                preset.start_ratio, preset.stop_ratio, preset.transition_smoothness,
                global_state.REVERSED_VERSIONS.get(preset.version, "Version 2"),
                preset.multiscale_mode, preset.multiscale_strength,
                preset.override_scales, preset.channel_threshold,
                preset.pcfg_enabled, preset.pcfg_steps, preset.pcfg_mode,
                preset.pcfg_blend, preset.pcfg_b, preset.pcfg_fourier,
                preset.pcfg_ms_mode, preset.pcfg_ms_str, preset.pcfg_threshold,
                preset.pcfg_s, preset.pcfg_gain,
                preset.verbose,
                *_stage_values(preset.stage_infos),
            ]
            return [gr.update(value=v) for v in values]

        btn_apply.click(fn=_apply_p, inputs=[preset_dd],
                        outputs=[*preset_outputs, *flat_comps])

        def _save_p(name, sr, sp, sm, ver, msm, mss, ovs, cht,
                    p_en, p_steps, p_mode, p_bl, p_b,
                    p_four, p_mmd, p_mst, p_thr, p_s, p_gain,
                    v_log, *flat):
            # Server-side guard: built-in presets are read-only and an empty
            # name cannot be stored (the Save button state alone can be stale).
            if not name or name in global_state.default_presets:
                return gr.skip(), gr.skip(), gr.skip()
            global_state.all_presets[name] = global_state.State(
                start_ratio=sr, stop_ratio=sp, transition_smoothness=sm,
                version=global_state.ALL_VERSIONS.get(ver, "1"),
                multiscale_mode=msm,
                multiscale_strength=float(mss),
                override_scales=ovs or "",
                channel_threshold=int(cht),
                stage_infos=_flat_to_sis(flat),
                pcfg_enabled=bool(p_en),
                pcfg_steps=int(p_steps),
                pcfg_mode=str(p_mode),
                pcfg_blend=float(p_bl),
                pcfg_b=float(p_b),
                pcfg_fourier=bool(p_four),
                pcfg_ms_mode=str(p_mmd),
                pcfg_ms_str=float(p_mst),
                pcfg_threshold=int(p_thr),
                pcfg_s=float(p_s),
                pcfg_gain=float(p_gain),
                verbose=bool(v_log),
            )
            global_state.save_presets()
            return (gr.update(choices=list(global_state.all_presets.keys())),
                    gr.update(interactive=True),
                    gr.update(interactive=True))

        btn_save.click(fn=_save_p,
                       inputs=[preset_dd, *preset_outputs, *flat_comps],
                       outputs=[preset_dd, btn_apply, btn_delete])

        def _refresh_p(name):
            global_state.reload_presets()
            return (gr.update(choices=list(global_state.all_presets.keys()), value=name),
                    *_btn_states(name))

        btn_refresh.click(fn=_refresh_p, inputs=[preset_dd],
                          outputs=[preset_dd, btn_apply, btn_save, btn_delete])

        def _delete_p(name):
            if name in global_state.all_presets and name not in global_state.default_presets:
                idx = list(global_state.all_presets.keys()).index(name)
                del global_state.all_presets[name]
                global_state.save_presets()
                names = list(global_state.all_presets.keys())
                # Select the neighbour of the deleted preset, or nothing if none remain.
                name = names[min(idx, len(names) - 1)] if names else None
            return (gr.update(choices=list(global_state.all_presets.keys()), value=name),
                    *_btn_states(name))

        btn_delete.click(fn=_delete_p, inputs=[preset_dd],
                         outputs=[preset_dd, btn_apply, btn_save, btn_delete])

        # ----------------------------------------------------------------
        # Infotext restore. Each handler clears its hidden carrier after
        # use; that re-fires .change with "", so empty input must be a no-op.
        # ----------------------------------------------------------------
        def _restore_sched(info, steps):
            if not info:
                return [gr.skip()] * 4
            try:
                sr, sp, sm = info.split(", ")[:3]
                total = max(int(float(steps)), 1)

                def _to_ratio(v):
                    # Values > 1 are treated as absolute step numbers and
                    # converted to a fraction of the current step count.
                    n = float(v.strip())
                    return n / total if n > 1.0 else n

                return (gr.update(value=""), gr.update(value=_to_ratio(sr)),
                        gr.update(value=_to_ratio(sp)), gr.update(value=float(sm)))
            except Exception:
                return [gr.skip()] * 4

        def _reg_sched_cb(steps_comp):
            # Both the current and the legacy "FreeU Schedule" fields need the
            # Steps slider to convert step-based values into ratios, so both
            # are wired only once that slider is known.
            for info_comp in (sched_info, legacy_sched_info):
                info_comp.change(fn=_restore_sched,
                                 inputs=[info_comp, steps_comp],
                                 outputs=[info_comp, start_r, stop_r, smooth])

        mode_key = "img2img" if is_img2img else "txt2img"
        if _steps_comps[mode_key] is None:
            _steps_cbs[mode_key].append(_reg_sched_cb)
        else:
            _reg_sched_cb(_steps_comps[mode_key])

        def _restore_stages(info):
            n_out = 2 + len(flat_comps)
            if not info:
                return [gr.skip()] * n_out
            try:
                sis = [
                    global_state.StageInfo(**{k: v for k, v in d.items()
                                              if k in global_state.STAGE_FIELD_NAMES})
                    for d in json.loads(info)
                ]
                # Pads/truncates so the output count always matches flat_comps.
                values = _stage_values(sis)
            except Exception:
                return [gr.skip()] * n_out
            auto_en = shared.opts.data.get("mega_freeu_png_auto_enable", True)
            return [gr.update(value=""), gr.update(value=auto_en),
                    *[gr.update(value=v) for v in values]]

        def _scalar_restorer(parse):
            """Build a handler restoring one control from one hidden infotext field.

            Empty or unparsable values are ignored (both outputs skipped);
            otherwise the carrier is cleared and ``parse(info)`` is pushed
            into the target control.
            """
            def _restore(info):
                if not info:
                    return gr.skip(), gr.skip()
                try:
                    value = parse(info)
                except Exception:
                    return gr.skip(), gr.skip()
                return gr.update(value=""), gr.update(value=value)
            return _restore

        _restore_ver = _scalar_restorer(
            lambda s: global_state.REVERSED_VERSIONS.get(s.strip(), s.strip()))
        _restore_ms_mode = _scalar_restorer(lambda s: s.strip())
        _restore_ms_str = _scalar_restorer(lambda s: float(s.strip()))
        # Previously guarded with `is None`, so the follow-up change("") wiped
        # the just-restored override scales.
        _restore_ov_scales = _scalar_restorer(lambda s: s)
        _restore_ch_thresh = _scalar_restorer(lambda s: int(float(s.strip())))
        _restore_verbose = _scalar_restorer(lambda s: s.strip().lower() == "true")

        def _restore_postcfg(info):
            n = 1 + len(pcfg_comps)
            if not info:
                return [gr.skip()] * n
            try:
                d = json.loads(info)
                return (
                    gr.update(value=""),
                    gr.update(value=bool(d.get("enabled", False))),
                    gr.update(value=int(d.get("steps", 20))),
                    gr.update(value=str(d.get("mode", "inject"))),
                    gr.update(value=float(d.get("blend", 1.0))),
                    gr.update(value=float(d.get("b", 1.1))),
                    gr.update(value=bool(d.get("fourier", False))),
                    gr.update(value=str(d.get("ms_mode", "Default"))),
                    gr.update(value=float(d.get("ms_str", 1.0))),
                    gr.update(value=int(d.get("threshold", 1))),
                    gr.update(value=float(d.get("s", 0.5))),
                    gr.update(value=float(d.get("gain", 1.0))),
                )
            except Exception:
                return [gr.skip()] * n

        for info_comp in (stages_info, legacy_stages_info):
            info_comp.change(fn=_restore_stages, inputs=[info_comp],
                             outputs=[info_comp, enabled, *flat_comps])

        for info_comp, target, restore in (
                (version_info, version, _restore_ver),
                (legacy_version_info, version, _restore_ver),
                (ms_mode_info, ms_mode, _restore_ms_mode),
                (ms_str_info, ms_str, _restore_ms_str),
                (ov_scales_info, ov_scales, _restore_ov_scales),
                (ch_thresh_info, ch_thresh, _restore_ch_thresh),
                (verbose_info, verbose, _restore_verbose),
        ):
            info_comp.change(fn=restore, inputs=[info_comp], outputs=[info_comp, target])

        postcfg_info.change(fn=_restore_postcfg, inputs=[postcfg_info],
                            outputs=[postcfg_info, *pcfg_comps])

        self.infotext_fields = [
            (sched_info, "MegaFreeU Schedule"),
            (stages_info, "MegaFreeU Stages"),
            (version_info, "MegaFreeU Version"),
            (ms_mode_info, "MegaFreeU Multiscale Mode"),
            (ms_str_info, "MegaFreeU Multiscale Strength"),
            (ov_scales_info, "MegaFreeU Override Scales"),
            (ch_thresh_info, "MegaFreeU Channel Threshold"),
            (postcfg_info, "MegaFreeU PostCFG"),
            (verbose_info, "MegaFreeU Verbose"),
            (legacy_sched_info, "FreeU Schedule"),
            (legacy_stages_info, "FreeU Stages"),
            (legacy_version_info, "FreeU Version"),
        ]
        self.paste_field_names = [f for _, f in self.infotext_fields]

        return [
            enabled, version, preset_dd,
            start_r, stop_r, smooth,
            ms_mode, ms_str, ov_scales,
            ch_thresh,
            *flat_comps,
            *pcfg_comps,
            verbose,
        ]

    def process(self, p: processing.StableDiffusionProcessing, *args):
        """Load this generation's settings, arm Post-CFG and record infotext.

        ``args`` is either the positional component values returned by
        :meth:`ui`, or (legacy alwayson_scripts API) a single dict of
        ``global_state.State`` field values.
        """
        if args and isinstance(args[0], dict):
            state_fields = {f.name for f in dataclasses.fields(global_state.State)}
            global_state.instance = global_state.State(
                **{k: v for k, v in args[0].items() if k in state_fields})
        else:
            self._load_ui_args(args)

        # XYZ-plot overrides are applied last so they win over UI/API values.
        global_state.apply_xyz()
        global_state.xyz_attrs.clear()
        st = global_state.instance
        unet.verbose_ref.value = bool(getattr(st, "verbose", False))

        p._mega_pcfg = self._pcfg_runtime(st)

        if not st.enable:
            # FreeU itself is off: record only Post-CFG (if it will run).
            # Writing the stage keys here made pasting any image's infotext
            # auto-enable FreeU via _restore_stages.
            if p._mega_pcfg["enabled"]:
                p.extra_generation_params["MegaFreeU PostCFG"] = json.dumps(
                    {k: v for k, v in p._mega_pcfg.items() if k != "step"})
            return

        unet.detect_model_channels()
        unet._on_cpu_devices.clear()
        _write_generation_params(p, st)

        if unet.verbose_ref.value:
            self._log_state(st)

    @classmethod
    def _load_ui_args(cls, args):
        """Copy positional UI values (in ui()'s return order) into global_state.instance."""
        (enabled, version, _preset_name,
         start_r, stop_r, smooth,
         ms_mode, ms_str, ov_scales,
         ch_thresh, *rest) = args

        n_stage_values = _SN * global_state.STAGES_COUNT
        flat_stage = rest[:n_stage_values]
        post = rest[n_stage_values:]

        inst = global_state.instance
        inst.enable = bool(enabled)
        inst.start_ratio = start_r
        inst.stop_ratio = stop_r
        inst.transition_smoothness = smooth
        inst.version = global_state.ALL_VERSIONS.get(version, "1")
        inst.multiscale_mode = ms_mode
        inst.multiscale_strength = float(ms_str)
        inst.override_scales = ov_scales or ""
        inst.channel_threshold = int(ch_thresh)
        inst.stage_infos = _flat_to_sis(flat_stage)

        pcfg = post[:cls._N_PCFG]
        if len(pcfg) == cls._N_PCFG:
            inst.pcfg_enabled = bool(pcfg[0])
            inst.pcfg_steps = int(pcfg[1])
            inst.pcfg_mode = str(pcfg[2])
            inst.pcfg_blend = float(pcfg[3])
            inst.pcfg_b = float(pcfg[4])
            inst.pcfg_fourier = bool(pcfg[5])
            inst.pcfg_ms_mode = str(pcfg[6])
            inst.pcfg_ms_str = float(pcfg[7])
            inst.pcfg_threshold = int(pcfg[8])
            inst.pcfg_s = float(pcfg[9])
            inst.pcfg_gain = float(pcfg[10])
        inst.verbose = bool(post[cls._N_PCFG]) if len(post) > cls._N_PCFG else False

    @staticmethod
    def _pcfg_runtime(st):
        """Return the Post-CFG dict that _on_cfg_post reads from ``p._mega_pcfg``."""
        if not getattr(st, "pcfg_enabled", False):
            return {"enabled": False}
        return {
            "enabled": True,
            "steps": st.pcfg_steps,
            "mode": st.pcfg_mode,
            "blend": st.pcfg_blend,
            "b": st.pcfg_b,
            "fourier": st.pcfg_fourier,
            "ms_mode": st.pcfg_ms_mode,
            "ms_str": st.pcfg_ms_str,
            "threshold": st.pcfg_threshold,
            "s": st.pcfg_s,
            "gain": st.pcfg_gain,
            "step": 0,  # incremented by _on_cfg_post, reset per batch
        }

    @staticmethod
    def _log_state(st):
        """Print the effective global and per-stage settings (verbose mode)."""
        print(f"[MegaFreeU] v{st.version} "
              f"start={st.start_ratio:.3f} stop={st.stop_ratio:.3f} "
              f"smooth={st.transition_smoothness:.3f} "
              f"ch_thresh=+-{st.channel_threshold}")
        for i, si in enumerate(st.stage_infos):
            ch = unet._stage_channels[i] if i < len(unet._stage_channels) else "?"
            print(f"  Stage {i+1} ({ch}ch): "
                  f"b={si.backbone_factor:.3f} [{si.b_start_ratio:.2f}-{si.b_end_ratio:.2f}] "
                  f"{si.backbone_blend_mode}:{si.backbone_blend:.2f} "
                  f"s={si.skip_factor:.3f} [{si.s_start_ratio:.2f}-{si.s_end_ratio:.2f}] "
                  f"fft={si.fft_type} r={si.fft_radius_ratio:.3f} "
                  f"hfe={si.skip_high_end_factor:.2f} hfb={si.hf_boost:.2f} "
                  f"cap={'ON' if si.enable_adaptive_cap else 'off'} "
                  f"({si.cap_threshold:.2f}/{si.cap_factor:.2f} {si.adaptive_cap_mode})")

    def process_batch(self, p, *args, **kwargs):
        """Reset the sampling-step counters at the start of each batch."""
        global_state.current_sampling_step = 0
        if hasattr(p, "_mega_pcfg"):
            p._mega_pcfg["step"] = 0

    def postprocess(self, p, processed, *args, **kwargs):
        """Disarm Post-CFG once processing has finished."""
        if hasattr(p, "_mega_pcfg"):
            p._mega_pcfg = {"enabled": False}
|
|
|
|
| def _write_generation_params(p, st): |
| """Write full Mega FreeU state into PNG extra_generation_params.""" |
| p.extra_generation_params["MegaFreeU Schedule"] = ( |
| f"{st.start_ratio}, {st.stop_ratio}, {st.transition_smoothness}") |
| p.extra_generation_params["MegaFreeU Stages"] = ( |
| json.dumps([si.to_dict() for si in st.stage_infos])) |
| p.extra_generation_params["MegaFreeU Version"] = st.version |
| p.extra_generation_params["MegaFreeU Multiscale Mode"] = st.multiscale_mode |
| p.extra_generation_params["MegaFreeU Multiscale Strength"] = str(st.multiscale_strength) |
| p.extra_generation_params["MegaFreeU Override Scales"] = st.override_scales or "" |
| p.extra_generation_params["MegaFreeU Channel Threshold"] = str(st.channel_threshold) |
| p.extra_generation_params["MegaFreeU Verbose"] = str(st.verbose) |
| if st.pcfg_enabled: |
| p.extra_generation_params["MegaFreeU PostCFG"] = json.dumps({ |
| "enabled": st.pcfg_enabled, |
| "steps": st.pcfg_steps, |
| "mode": st.pcfg_mode, |
| "blend": st.pcfg_blend, |
| "b": st.pcfg_b, |
| "fourier": st.pcfg_fourier, |
| "ms_mode": st.pcfg_ms_mode, |
| "ms_str": st.pcfg_ms_str, |
| "threshold": st.pcfg_threshold, |
| "s": st.pcfg_s, |
| "gain": st.pcfg_gain, |
| }) |
|
|
|
|
def _flat_to_sis(flat) -> List[global_state.StageInfo]:
    """Rebuild STAGES_COUNT StageInfo objects from flat per-stage UI values.

    Values are consumed in consecutive chunks of ``_SN`` and assigned to the
    fields in ``_SF`` order. Fields without a corresponding value (short
    input) keep their ``StageInfo()`` defaults.
    """
    stages = []
    for stage_idx in range(global_state.STAGES_COUNT):
        offset = stage_idx * _SN
        values = flat[offset:offset + _SN]
        info = global_state.StageInfo()
        # zip stops at the shorter side, leaving trailing fields at defaults.
        for field_name, value in zip(_SF, values):
            setattr(info, field_name, value)
        stages.append(info)
    return stages
|
|
|
|
| |
def _on_cfg_step(*_args, **_kwargs):
    """Advance the shared sampling-step counter by one per CFG callback.

    Arguments are ignored; the counter is reset to 0 in process_batch.
    """
    global_state.current_sampling_step = global_state.current_sampling_step + 1
|
|
def _on_cfg_post(params):
    """Apply the WAS_PostCFGShift blend to the CFG output (A1111 callback port).

    Reads the per-generation dict stored on ``p._mega_pcfg``. Does nothing
    when no processing object is found, Post-CFG is disabled, or more than
    ``steps`` callbacks have already run for the current batch.
    """
    # NOTE(review): A1111's AfterCFGCallbackParams (and CFGDenoisedParams)
    # appear to carry only x / sampling_step / total_sampling_steps, with no
    # `p` or `denoiser`. If that holds for the target WebUI, this lookup
    # always yields None and Post-CFG never runs. Confirm, and if so stash the
    # active `p` during process() instead.
    p = getattr(params, "p", None)
    if p is None:
        denoiser = getattr(params, "denoiser", None)
        p = getattr(denoiser, "p", None)
        if p is None:
            return

    state = getattr(p, "_mega_pcfg", None)
    if not state or not state.get("enabled"):
        return
    state["step"] = state.get("step", 0) + 1
    if state["step"] > state["steps"]:
        return

    x = params.x
    blend_fn = unet.BLENDING_MODES.get(state["mode"], unet.BLENDING_MODES["inject"])
    shifted = blend_fn(x, x * state["b"], state["blend"])

    if state["fourier"]:
        scales = global_state.MSCALES.get(state["ms_mode"])
        shifted = unet.filter_skip_box_multiscale(
            shifted, state["threshold"], state["s"], scales, state["ms_str"])

    gain = state["gain"]
    if gain != 1.0:
        shifted = shifted * float(gain)
    params.x = shifted
|
|
# Prefer the post-CFG hook; WebUI builds that lack on_cfg_after_cfg fall
# back to on_cfg_denoised for both callbacks.
if hasattr(script_callbacks, "on_cfg_after_cfg"):
    script_callbacks.on_cfg_after_cfg(_on_cfg_step)
    script_callbacks.on_cfg_after_cfg(_on_cfg_post)
else:
    script_callbacks.on_cfg_denoised(_on_cfg_step)
    script_callbacks.on_cfg_denoised(_on_cfg_post)
|
|
# Steps slider elem_id -> tab key used by _steps_comps / _steps_cbs.
_STEPS_ELEM_TO_TAB = {"txt2img_steps": "txt2img", "img2img_steps": "img2img"}


def _on_after_component(component, **kwargs):
    """Capture each tab's Steps slider and flush callbacks waiting for it."""
    tab = _STEPS_ELEM_TO_TAB.get(kwargs.get("elem_id", ""))
    if tab is None:
        return
    _steps_comps[tab] = component
    pending = _steps_cbs[tab]
    for callback in pending:
        callback(component)
    pending.clear()


script_callbacks.on_after_component(_on_after_component)
|
|
def _on_ui_settings():
    """Register the Mega FreeU section and its option in WebUI Settings."""
    option = shared.OptionInfo(
        default=True,
        label="Auto-enable Mega FreeU when loading PNG info from a FreeU generation",
        section=("mega_freeu", "Mega FreeU"),
    )
    shared.opts.add_option("mega_freeu_png_auto_enable", option)


script_callbacks.on_ui_settings(_on_ui_settings)
script_callbacks.on_before_ui(xyz_grid.patch)

# Patch the UNet at import time (see lib_mega_freeu.unet.patch).
unet.patch()
|
|