"""
scripts/mega_freeu.py - Mega FreeU for A1111 / Forge

Combined from 5 sources:
  1. sd-webui-freeu
       th.cat hijack, V1/V2 backbone, box filter, schedule, presets JSON,
       PNG metadata, XYZ, ControlNet, region masking,
       dict-API compat (alwayson_scripts legacy)
  2. WAS FreeU_Advanced
       9 blending modes, 13 multi-scale FFT presets, override_scales,
       Post-CFG Shift (WAS_PostCFGShift ported to A1111 callback)
       NOTE: target_block / input_block / middle_block / slice_b1/b2 were
       not ported — th.cat hijack works on output-side skip concat.
  3. ComfyUI_FreeU_V2_Adv
       Gaussian filter, Adaptive Cap (MAX_CAP_ITER=3), independent B/S
       timestep ranges per-stage, channel_threshold
  4. FreeU_V2_timestepadd
       b_start/b_end%, s_start/s_end% per-stage gating
       NOTE: gating uses step-fraction (cur/total), not percent_to_sigma as
       in original ComfyUI sources. Conceptually equivalent.
  5. nrs_kohaku_v3.5
       hf_boost param, on_cpu_devices dict, gaussian standalone

BUGS FIXED vs sdwebui-freeU-extension:
  BUG 1: bool mask in Fourier filter (scale multiplication was NOOP)
  BUG 2: single-quadrant mask instead of symmetric center
"""
import dataclasses
import json
from typing import List

import gradio as gr

from modules import script_callbacks, scripts, shared, processing
from lib_mega_freeu import global_state, unet, xyz_grid

# Steps slider per tab, captured by _on_after_component once the web UI
# creates it; None until then.
_steps_comps = {"txt2img": None, "img2img": None}
# Registrations that need the steps slider and were requested before it existed.
_steps_cbs = {"txt2img": [], "img2img": []}

# StageInfo field names in dataclass order; this order defines the layout of
# every "flat" per-stage value list in this module.
_SF = [f.name for f in dataclasses.fields(global_state.StageInfo)]
_SN = len(_SF)  # 19 fields per stage


def _section_note(text):
    """Emit a small gr.HTML label/note inside the current Gradio context.

    The text is wrapped in two newlines on each side and passed to gr.HTML
    unchanged otherwise.
    NOTE(review): no markup is applied here; if headings are meant to be
    styled, this is the single place to add it.
    """
    return gr.HTML(f"\n\n{text}\n\n")


def _stage_ui(idx, si, elem_id_fn):
    """Build the accordion of controls for one stage.

    Args:
        idx: zero-based stage index (used for labels and default-open state).
        si: object providing the initial value of every control
            (a global_state.StageInfo in the visible callers).
        elem_id_fn: accepted for interface compatibility; not used here.

    Returns:
        The 19 created components, ordered to match ``_SF``.
    """
    n = idx + 1
    ch = {0: "~1280ch (deep)", 1: "~640ch (mid)", 2: "~320ch (shallow)"}.get(idx, f"stage{n}")
    with gr.Accordion(open=(idx < 2), label=f"Stage {n} ({ch})"):
        # Backbone
        _section_note("Backbone h (B)")
        with gr.Row():
            bf = gr.Slider(
                label=f"B{n} Scale", minimum=-1, maximum=3, step=0.001,
                value=si.backbone_factor,
                info=">1 strengthens backbone features. V2: adaptive per-region.")
            bo = gr.Slider(
                label=f"B{n} Offset", minimum=0, maximum=1, step=0.001,
                value=si.backbone_offset,
                info="Channel region start [0-1].")
            bw = gr.Slider(
                label=f"B{n} Width", minimum=-1, maximum=1, step=0.001,
                value=si.backbone_width,
                info="Channel region width. Negative=invert.")
        with gr.Row():
            bm = gr.Dropdown(
                label=f"B{n} Blend Mode", choices=global_state.BLEND_MODE_NAMES,
                value=si.backbone_blend_mode,
                info="lerp=default, stable_slerp=quality, inject=additive")
            bb = gr.Slider(
                label=f"B{n} Blend Str", minimum=0, maximum=2, step=0.001,
                value=si.backbone_blend)
        _section_note("B timestep range (ComfyUI V2)")
        with gr.Row():
            bsr = gr.Slider(
                label=f"B{n} Start%", minimum=0, maximum=1, step=0.001,
                value=si.b_start_ratio,
                info="B activates at this step fraction.")
            ber = gr.Slider(
                label=f"B{n} End%", minimum=0, maximum=1, step=0.001,
                value=si.b_end_ratio,
                info="B stops. 0.35=structure phase only.")

        # Skip / FFT
        _section_note("Skip h_skip (S) - Fourier Filter")
        with gr.Row():
            sf = gr.Slider(
                label=f"S{n} LF Scale", minimum=-1, maximum=3, step=0.001,
                value=si.skip_factor,
                info="<1 suppresses LF components. 0.2=strong suppression.")
            she = gr.Slider(
                label=f"S{n} HF (Box)", minimum=-1, maximum=3, step=0.001,
                value=si.skip_high_end_factor,
                info="HF scale outside LF region (box filter). >1=boost HF.")
            hfb = gr.Slider(
                label=f"S{n} HF Boost (Gauss)", minimum=0, maximum=3, step=0.001,
                value=si.hf_boost,
                info="Gaussian explicit HF multiplier. Combined as max(hf_boost, high_end).")
        with gr.Row():
            ft = gr.Radio(
                label=f"S{n} FFT Type", choices=global_state.FFT_TYPES,
                value=si.fft_type,
                info="gaussian=smooth no-ringing. box=original FreeU (both bugs fixed).")
            sco = gr.Slider(
                label=f"S{n} Cutoff (Box)", minimum=0, maximum=1, step=0.001,
                value=si.skip_cutoff,
                info="Box: LF cutoff fraction. 0=1px default.")
            srr = gr.Slider(
                label=f"S{n} Radius (Gauss)", minimum=0.01, maximum=0.5, step=0.001,
                value=si.fft_radius_ratio,
                info="Gaussian R=ratio*min(H,W). 0.07=moderate LF.")
        _section_note("S timestep range (ComfyUI V2)")
        with gr.Row():
            ssr = gr.Slider(
                label=f"S{n} Start%", minimum=0, maximum=1, step=0.001,
                value=si.s_start_ratio,
                info="S activates. Tip: set = B End% for clean phase separation.")
            ser = gr.Slider(
                label=f"S{n} End%", minimum=0, maximum=1, step=0.001,
                value=si.s_end_ratio,
                info="S stops. 1.0=to last step.")

        # Adaptive Cap
        _section_note("Adaptive Cap - prevents LF over-attenuation (FreeU_S1S2.py)")
        with gr.Row():
            eac = gr.Checkbox(
                label=f"S{n} Enable Cap", value=si.enable_adaptive_cap,
                info="Iteratively weakens Gaussian if LF/HF drop exceeds threshold.")
            ct = gr.Slider(
                label="Threshold", minimum=0, maximum=1, step=0.001,
                value=si.cap_threshold,
                info="Max allowed LF/HF ratio drop. 0.35=35%.")
            cf = gr.Slider(
                label="Factor", minimum=0, maximum=1, step=0.001,
                value=si.cap_factor,
                info="Relaxation factor. 0.6=moderate.")
            cm = gr.Radio(
                label="Mode", choices=["adaptive", "fixed"],
                value=si.adaptive_cap_mode,
                info="adaptive: scales factor with over-attenuation. fixed: always cap_factor.")

    # Return exactly in _SF field order
    return [bf, sf, bo, bw, sco, she, bm, bb, bsr, ber, ssr, ser,
            ft, srr, hfb, eac, ct, cf, cm]


def _sis_to_flat(stage_infos) -> list:
    """Flatten stage infos into one value list in ``_SF`` order.

    The input is truncated or padded with default StageInfo objects to exactly
    ``global_state.STAGES_COUNT`` entries, so the result always lines up with
    the per-stage UI components. Attributes missing on an entry fall back to
    the StageInfo default.
    """
    count = global_state.STAGES_COUNT
    sis = list(stage_infos)[:count]
    sis.extend(global_state.StageInfo() for _ in range(count - len(sis)))
    defaults = global_state.StageInfo()
    return [getattr(si, f, getattr(defaults, f)) for si in sis for f in _SF]


def _build_pcfg(st):
    """Return the per-run Post-CFG Shift config dict for state ``st``.

    A disabled (or absent) ``pcfg_enabled`` yields ``{"enabled": False}``;
    otherwise the dict carries every pcfg_* setting plus a ``"step"`` counter
    starting at 0.
    """
    if not getattr(st, "pcfg_enabled", False):
        return {"enabled": False}
    return {
        "enabled": True,
        "steps": st.pcfg_steps,
        "mode": st.pcfg_mode,
        "blend": st.pcfg_blend,
        "b": st.pcfg_b,
        "fourier": st.pcfg_fourier,
        "ms_mode": st.pcfg_ms_mode,
        "ms_str": st.pcfg_ms_str,
        "threshold": st.pcfg_threshold,
        "s": st.pcfg_s,
        "gain": st.pcfg_gain,
        "step": 0,
    }


class MegaFreeUScript(scripts.Script):
    """Always-visible script exposing the Mega FreeU controls and wiring them
    into ``global_state`` / ``unet`` for each generation."""

    def title(self):
        return "Mega FreeU"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the UI, register preset / infotext handlers, and return the
        components whose values are passed to ``process`` (same order)."""
        global_state.reload_presets()
        pnames = list(global_state.all_presets.keys())
        def_sis = global_state.all_presets[pnames[0]].stage_infos

        with gr.Accordion(open=False, label="Mega FreeU"):
            # Top bar
            with gr.Row():
                enabled = gr.Checkbox(label="Enable Mega FreeU", value=False)
                version = gr.Dropdown(
                    label="Version",
                    choices=list(global_state.ALL_VERSIONS.keys()),
                    value="Version 2",
                    elem_id=self.elem_id("version"),
                    info="V2=adaptive hidden-mean backbone. V1=flat multiplier.")
            with gr.Row():
                preset_dd = gr.Dropdown(
                    label="Preset",
                    choices=pnames,
                    value=pnames[0],
                    allow_custom_value=True,
                    elem_id=self.elem_id("preset_name"),
                    info="Apply loads settings. Custom name enables Save. Delete auto-saves.")
                btn_apply = gr.Button("Apply", size="sm", elem_classes="tool")
                btn_save = gr.Button("Save", size="sm", elem_classes="tool")
                btn_refresh = gr.Button("Refresh", size="sm", elem_classes="tool")
                btn_delete = gr.Button("Delete", size="sm", elem_classes="tool")

            # Global schedule
            _section_note("Global Schedule")
            with gr.Row():
                start_r = gr.Slider(
                    label="Start At", elem_id=self.elem_id("start_at_step"),
                    minimum=0, maximum=1, step=0.001, value=0)
                stop_r = gr.Slider(
                    label="Stop At", elem_id=self.elem_id("stop_at_step"),
                    minimum=0, maximum=1, step=0.001, value=1)
                smooth = gr.Slider(
                    label="Transition Smoothness",
                    elem_id=self.elem_id("transition_smoothness"),
                    minimum=0, maximum=1, step=0.001, value=0,
                    info="0=hard on/off. 1=smooth fade.")

            # Box Multi-Scale (WAS FreeU_Advanced)
            with gr.Accordion(open=False, label="Box Multi-Scale FFT (WAS FreeU_Advanced)"):
                _section_note("Applied on top of Box filter. Ignored in Gaussian mode.")
                with gr.Row():
                    ms_mode = gr.Dropdown(
                        label="Multiscale Mode",
                        choices=list(global_state.MSCALES.keys()),
                        value="Default")
                    ms_str = gr.Slider(
                        label="Strength", minimum=0, maximum=1, step=0.001, value=1.0)
                ov_scales = gr.Textbox(
                    label="Override Scales (WAS format: radius_px, scale per line, # comments)",
                    lines=3,
                    placeholder="# Example custom scales:\n10, 1.5\n20, 0.8",
                    value="")

            with gr.Row():
                ch_thresh = gr.Slider(
                    label="Channel Match Threshold (+-)",
                    elem_id=self.elem_id("ch_thresh"),
                    minimum=0, maximum=256, step=1, value=96,
                    info="Stage channel tolerance. 96=standard (FreeU_B1B2.py default).")

            # Per-stage accordions
            flat_comps: List = []
            for i in range(global_state.STAGES_COUNT):
                si = def_sis[i] if i < len(def_sis) else global_state.StageInfo()
                flat_comps.extend(_stage_ui(i, si, self.elem_id))

            # Post-CFG Shift (WAS_PostCFGShift -> A1111)
            with gr.Accordion(open=False, label="Post-CFG Shift (WAS_PostCFGShift -> A1111 callback)"):
                _section_note(
                    "Runs after combine_denoised. Blends denoised*b into output "
                    "via on_cfg_after_cfg callback.")
                with gr.Row():
                    pcfg_en = gr.Checkbox(label="Enable Post-CFG Shift", value=False)
                    pcfg_steps = gr.Slider(
                        label="Max Steps", minimum=1, maximum=200, step=1, value=20,
                        info="Apply only to first N steps.")
                with gr.Row():
                    pcfg_mode = gr.Dropdown(
                        label="Blend Mode", choices=global_state.BLEND_MODE_NAMES,
                        value="inject")
                    pcfg_bl = gr.Slider(
                        label="Blend", minimum=0, maximum=5, step=0.001, value=1.0)
                    pcfg_b = gr.Slider(
                        label="B Factor", minimum=0, maximum=5, step=0.001, value=1.1,
                        info=">1 amplifies shift.")
                with gr.Row():
                    pcfg_fou = gr.Checkbox(label="Apply Fourier Filter", value=False)
                    pcfg_mmd = gr.Dropdown(
                        label="Fourier Multiscale",
                        choices=list(global_state.MSCALES.keys()),
                        value="Default")
                    pcfg_mst = gr.Slider(
                        label="Fourier Strength", minimum=0, maximum=1, step=0.001,
                        value=1.0)
                with gr.Row():
                    pcfg_thr = gr.Slider(
                        label="Threshold (px)", minimum=1, maximum=20, step=1, value=1,
                        info="Box filter LF radius in pixels.")
                    pcfg_s = gr.Slider(
                        label="S Scale", minimum=0, maximum=3, step=0.001, value=0.5)
                    pcfg_gain = gr.Slider(
                        label="Force Gain", minimum=0, maximum=5, step=0.01, value=1.0,
                        info="Final output multiplier.")

            verbose = gr.Checkbox(
                label="Verbose Logging (Adaptive Cap, energy stats)", value=False)

            # Hidden PNG infotext components
            sched_info = gr.HTML(visible=False)
            stages_info = gr.HTML(visible=False)
            version_info = gr.HTML(visible=False)
            ms_mode_info = gr.HTML(visible=False)
            ms_str_info = gr.HTML(visible=False)
            ov_scales_info = gr.HTML(visible=False)
            ch_thresh_info = gr.HTML(visible=False)
            postcfg_info = gr.HTML(visible=False)
            verbose_info = gr.HTML(visible=False)
            # Legacy sd-webui-freeu keys for backward compat
            legacy_sched_info = gr.HTML(visible=False)
            legacy_stages_info = gr.HTML(visible=False)
            legacy_version_info = gr.HTML(visible=False)

        # Components a preset reads/writes besides the per-stage ones. The
        # same list drives Apply outputs and Save inputs so the two can never
        # drift apart (8 main + 11 Post-CFG + 1 verbose).
        pcfg_comps = [pcfg_en, pcfg_steps, pcfg_mode, pcfg_bl, pcfg_b, pcfg_fou,
                      pcfg_mmd, pcfg_mst, pcfg_thr, pcfg_s, pcfg_gain]
        preset_fields = [start_r, stop_r, smooth, version,
                         ms_mode, ms_str, ov_scales, ch_thresh,
                         *pcfg_comps, verbose]
        preset_buttons = [btn_apply, btn_save, btn_delete]

        # Preset buttons
        def _btn_upd(name):
            """Enable Apply for existing presets, Save/Delete for user ones."""
            ex = name in global_state.all_presets
            usr = name not in global_state.default_presets
            return (gr.update(interactive=ex),
                    gr.update(interactive=usr),
                    gr.update(interactive=usr and ex))

        preset_dd.change(fn=_btn_upd, inputs=[preset_dd], outputs=preset_buttons)

        def _apply_p(name):
            """Load preset ``name`` into every preset-controlled component."""
            p = global_state.all_presets.get(name)
            if p is None:
                return [gr.skip()] * (len(preset_fields) + len(flat_comps))
            vlabel = global_state.REVERSED_VERSIONS.get(p.version, "Version 2")
            values = [
                p.start_ratio, p.stop_ratio, p.transition_smoothness, vlabel,
                p.multiscale_mode, p.multiscale_strength,
                p.override_scales, p.channel_threshold,
                p.pcfg_enabled, p.pcfg_steps, p.pcfg_mode, p.pcfg_blend,
                p.pcfg_b, p.pcfg_fourier, p.pcfg_ms_mode, p.pcfg_ms_str,
                p.pcfg_threshold, p.pcfg_s, p.pcfg_gain,
                p.verbose,
                # Padded/truncated so the count always matches flat_comps.
                *_sis_to_flat(p.stage_infos),
            ]
            return [gr.update(value=v) for v in values]

        btn_apply.click(fn=_apply_p, inputs=[preset_dd],
                        outputs=[*preset_fields, *flat_comps])

        def _save_p(
            name, sr, sp, sm, ver, msm, mss, ovs, cht,
            p_en, p_steps, p_mode, p_bl, p_b, p_four,
            p_mmd, p_mst, p_thr, p_s, p_gain, v_log,
            *flat
        ):
            """Store the current UI values as preset ``name`` and persist."""
            if not name or not str(name).strip():
                # Nothing sensible to save under a blank name.
                return [gr.skip()] * 3
            sis = _flat_to_sis(flat)
            vc = global_state.ALL_VERSIONS.get(ver, "1")
            global_state.all_presets[name] = global_state.State(
                start_ratio=sr,
                stop_ratio=sp,
                transition_smoothness=sm,
                version=vc,
                multiscale_mode=msm,
                multiscale_strength=float(mss),
                override_scales=ovs or "",
                channel_threshold=int(cht),
                stage_infos=sis,
                pcfg_enabled=bool(p_en),
                pcfg_steps=int(p_steps),
                pcfg_mode=str(p_mode),
                pcfg_blend=float(p_bl),
                pcfg_b=float(p_b),
                pcfg_fourier=bool(p_four),
                pcfg_ms_mode=str(p_mmd),
                pcfg_ms_str=float(p_mst),
                pcfg_threshold=int(p_thr),
                pcfg_s=float(p_s),
                pcfg_gain=float(p_gain),
                verbose=bool(v_log),
            )
            global_state.save_presets()
            return (
                gr.update(choices=list(global_state.all_presets.keys())),
                gr.update(interactive=True),
                gr.update(interactive=True),
            )

        btn_save.click(fn=_save_p,
                       inputs=[preset_dd, *preset_fields, *flat_comps],
                       outputs=[preset_dd, btn_apply, btn_delete])

        def _refresh_p(name):
            """Reload presets from storage and refresh dropdown/buttons."""
            global_state.reload_presets()
            ex = name in global_state.all_presets
            usr = name not in global_state.default_presets
            ch = list(global_state.all_presets.keys())
            return (gr.update(choices=ch, value=name),
                    gr.update(interactive=ex),
                    gr.update(interactive=usr),
                    gr.update(interactive=usr and ex))

        btn_refresh.click(fn=_refresh_p, inputs=[preset_dd],
                          outputs=[preset_dd, *preset_buttons])

        def _delete_p(name):
            """Delete user preset ``name`` and select its neighbour."""
            if name in global_state.all_presets and name not in global_state.default_presets:
                idx = list(global_state.all_presets.keys()).index(name)
                del global_state.all_presets[name]
                global_state.save_presets()
                names = list(global_state.all_presets.keys())
                # Guard: with no presets left there is no neighbour to select.
                if names:
                    name = names[min(idx, len(names) - 1)]
            ex = name in global_state.all_presets
            usr = name not in global_state.default_presets
            return (gr.update(choices=list(global_state.all_presets.keys()), value=name),
                    gr.update(interactive=ex),
                    gr.update(interactive=usr),
                    gr.update(interactive=usr and ex))

        btn_delete.click(fn=_delete_p, inputs=[preset_dd],
                         outputs=[preset_dd, *preset_buttons])

        # PNG schedule restore
        def _restore_sched(info, steps):
            """Parse "start, stop, smooth" and push it to the sliders.

            Values above 1.0 are divided by the step count; values up to 1.0
            are used as-is. Any parse failure leaves the UI untouched.
            """
            if not info:
                return [gr.skip()] * 4
            try:
                parts = info.split(", ")
                sr, sp, sm = parts[0], parts[1], parts[2]
                total = max(int(float(steps)), 1)

                def _r(v):
                    n = float(v.strip())
                    return n / total if n > 1.0 else n

                return (gr.update(value=""),
                        gr.update(value=_r(sr)),
                        gr.update(value=_r(sp)),
                        gr.update(value=float(sm)))
            except Exception:
                # Deliberately best-effort: bad infotext must not break paste.
                return [gr.skip()] * 4

        def _reg_sched_cb(steps_comp):
            # Both the current and the legacy sd-webui-freeu schedule keys need
            # the real steps slider as input, so both are registered here and
            # deferred together until that slider exists.
            for info_comp in (sched_info, legacy_sched_info):
                info_comp.change(fn=_restore_sched,
                                 inputs=[info_comp, steps_comp],
                                 outputs=[info_comp, start_r, stop_r, smooth])

        mode_key = "img2img" if is_img2img else "txt2img"
        if _steps_comps[mode_key] is None:
            _steps_cbs[mode_key].append(_reg_sched_cb)
        else:
            _reg_sched_cb(_steps_comps[mode_key])

        def _restore_stages(info):
            """Restore per-stage values (and the Enable box) from JSON."""
            n_out = 2 + len(flat_comps)
            if not info:
                return [gr.skip()] * n_out
            try:
                raw_list = json.loads(info)
                sis = []
                for d in raw_list:
                    known = {k: v for k, v in d.items()
                             if k in global_state.STAGE_FIELD_NAMES}
                    sis.append(global_state.StageInfo(**known))
            except Exception:
                # Deliberately best-effort: bad infotext must not break paste.
                return [gr.skip()] * n_out
            # _sis_to_flat pads missing stages and drops surplus ones so the
            # number of updates always equals len(flat_comps).
            flat = _sis_to_flat(sis)
            auto_en = shared.opts.data.get("mega_freeu_png_auto_enable", True)
            return (gr.update(value=""),
                    gr.update(value=auto_en),
                    *[gr.update(value=v) for v in flat])

        def _restore_ver(info):
            if not info:
                return [gr.skip()] * 2
            lbl = global_state.REVERSED_VERSIONS.get(info.strip(), info.strip())
            return gr.update(value=""), gr.update(value=lbl)

        # New extended PNG restore callbacks
        def _restore_ms_mode(info):
            if not info:
                return gr.skip(), gr.skip()
            return gr.update(value=""), gr.update(value=info.strip())

        def _restore_ms_str(info):
            if not info:
                return gr.skip(), gr.skip()
            try:
                return gr.update(value=""), gr.update(value=float(info.strip()))
            except Exception:
                return gr.skip(), gr.skip()

        def _restore_ov_scales(info):
            # Same guard as the sibling handlers: this handler resets the
            # hidden field to "", and if that reset fires the handler again an
            # `is None` check would overwrite the value just restored with "".
            if not info:
                return gr.skip(), gr.skip()
            return gr.update(value=""), gr.update(value=info)

        def _restore_ch_thresh(info):
            if not info:
                return gr.skip(), gr.skip()
            try:
                return gr.update(value=""), gr.update(value=int(float(info.strip())))
            except Exception:
                return gr.skip(), gr.skip()

        def _restore_verbose(info):
            if not info:
                return gr.skip(), gr.skip()
            return gr.update(value=""), gr.update(value=(info.strip().lower() == "true"))

        def _restore_postcfg(info):
            """Restore the 11 Post-CFG controls from their JSON infotext."""
            n = 1 + len(pcfg_comps)
            if not info:
                return [gr.skip()] * n
            try:
                d = json.loads(info)
                return (
                    gr.update(value=""),
                    gr.update(value=bool(d.get("enabled", False))),
                    gr.update(value=int(d.get("steps", 20))),
                    gr.update(value=str(d.get("mode", "inject"))),
                    gr.update(value=float(d.get("blend", 1.0))),
                    gr.update(value=float(d.get("b", 1.1))),
                    gr.update(value=bool(d.get("fourier", False))),
                    gr.update(value=str(d.get("ms_mode", "Default"))),
                    gr.update(value=float(d.get("ms_str", 1.0))),
                    gr.update(value=int(d.get("threshold", 1))),
                    gr.update(value=float(d.get("s", 0.5))),
                    gr.update(value=float(d.get("gain", 1.0))),
                )
            except Exception:
                # Deliberately best-effort: bad infotext must not break paste.
                return [gr.skip()] * n

        # Current and legacy sd-webui-freeu keys share the same restore logic.
        for info_comp in (stages_info, legacy_stages_info):
            info_comp.change(fn=_restore_stages, inputs=[info_comp],
                             outputs=[info_comp, enabled, *flat_comps])
        for info_comp in (version_info, legacy_version_info):
            info_comp.change(fn=_restore_ver, inputs=[info_comp],
                             outputs=[info_comp, version])

        ms_mode_info.change(fn=_restore_ms_mode, inputs=[ms_mode_info],
                            outputs=[ms_mode_info, ms_mode])
        ms_str_info.change(fn=_restore_ms_str, inputs=[ms_str_info],
                           outputs=[ms_str_info, ms_str])
        ov_scales_info.change(fn=_restore_ov_scales, inputs=[ov_scales_info],
                              outputs=[ov_scales_info, ov_scales])
        ch_thresh_info.change(fn=_restore_ch_thresh, inputs=[ch_thresh_info],
                              outputs=[ch_thresh_info, ch_thresh])
        verbose_info.change(fn=_restore_verbose, inputs=[verbose_info],
                            outputs=[verbose_info, verbose])
        postcfg_info.change(fn=_restore_postcfg, inputs=[postcfg_info],
                            outputs=[postcfg_info, *pcfg_comps])

        self.infotext_fields = [
            (sched_info, "MegaFreeU Schedule"),
            (stages_info, "MegaFreeU Stages"),
            (version_info, "MegaFreeU Version"),
            (ms_mode_info, "MegaFreeU Multiscale Mode"),
            (ms_str_info, "MegaFreeU Multiscale Strength"),
            (ov_scales_info, "MegaFreeU Override Scales"),
            (ch_thresh_info, "MegaFreeU Channel Threshold"),
            (postcfg_info, "MegaFreeU PostCFG"),
            (verbose_info, "MegaFreeU Verbose"),
            # Backward compat with sd-webui-freeu generated PNGs
            (legacy_sched_info, "FreeU Schedule"),
            (legacy_stages_info, "FreeU Stages"),
            (legacy_version_info, "FreeU Version"),
        ]
        self.paste_field_names = [f for _, f in self.infotext_fields]

        # Order here is the positional order process() unpacks.
        return [
            enabled, version, preset_dd,
            start_r, stop_r, smooth,
            ms_mode, ms_str, ov_scales, ch_thresh,
            *flat_comps,
            *pcfg_comps,
            verbose,
        ]

    def process(self, p: processing.StableDiffusionProcessing, *args):
        """Load settings into ``global_state.instance`` and prepare the run.

        ``args`` is either a single dict (old sd-webui-freeu API) or the
        positional UI values in the order returned by ``ui``.
        """
        # Branch 1: old sd-webui-freeu API (dict passed as first arg)
        if args and isinstance(args[0], dict):
            state_fields = {f.name for f in dataclasses.fields(global_state.State)}
            global_state.instance = global_state.State(**{
                k: v for k, v in args[0].items() if k in state_fields
            })
            global_state.apply_xyz()
            global_state.xyz_attrs.clear()
            st = global_state.instance
            unet.verbose_ref.value = bool(getattr(st, "verbose", False))
            p._mega_pcfg = _build_pcfg(st)
            if st.enable:
                unet.detect_model_channels()
                unet._on_cpu_devices.clear()
            _write_generation_params(p, st)
            return

        # Branch 2: normal UI call
        (enabled, version, preset_dd,
         start_r, stop_r, smooth,
         ms_mode, ms_str, ov_scales, ch_thresh,
         *rest) = args

        n_sv = _SN * global_state.STAGES_COUNT
        flat_stage = rest[:n_sv]
        post = rest[n_sv:]  # 11 pcfg params + verbose

        verbose = bool(post[11]) if len(post) > 11 else False
        unet.verbose_ref.value = verbose

        # Write UI values into instance BEFORE apply_xyz so XYZ can override any of them
        inst = global_state.instance
        inst.enable = bool(enabled)
        inst.start_ratio = start_r
        inst.stop_ratio = stop_r
        inst.transition_smoothness = smooth
        inst.version = global_state.ALL_VERSIONS.get(version, "1")
        inst.multiscale_mode = ms_mode
        inst.multiscale_strength = float(ms_str)
        inst.override_scales = ov_scales or ""
        inst.channel_threshold = int(ch_thresh)
        inst.stage_infos = _flat_to_sis(flat_stage)

        # Sync Post-CFG into instance state so presets/PNG capture it
        pcfg = post[:11]
        if len(pcfg) >= 11:
            inst.pcfg_enabled = bool(pcfg[0])
            inst.pcfg_steps = int(pcfg[1])
            inst.pcfg_mode = str(pcfg[2])
            inst.pcfg_blend = float(pcfg[3])
            inst.pcfg_b = float(pcfg[4])
            inst.pcfg_fourier = bool(pcfg[5])
            inst.pcfg_ms_mode = str(pcfg[6])
            inst.pcfg_ms_str = float(pcfg[7])
            inst.pcfg_threshold = int(pcfg[8])
            inst.pcfg_s = float(pcfg[9])
            inst.pcfg_gain = float(pcfg[10])
        inst.verbose = verbose

        # apply_xyz() may replace global_state.instance with a preset copy;
        # take the fresh reference AFTER so PNG metadata / verbose use the final state.
        global_state.apply_xyz()
        global_state.xyz_attrs.clear()
        st = global_state.instance  # fresh ref post-XYZ

        # Post-CFG: set up ALWAYS (independent of main Enable)
        p._mega_pcfg = _build_pcfg(st)

        if not st.enable:
            # Write partial params so PNG records the session even when disabled
            _write_generation_params(p, st)
            return

        unet.detect_model_channels()
        unet._on_cpu_devices.clear()
        _write_generation_params(p, st)

        if unet.verbose_ref.value:
            print(f"[MegaFreeU] v{st.version} "
                  f"start={st.start_ratio:.3f} stop={st.stop_ratio:.3f} "
                  f"smooth={st.transition_smoothness:.3f} "
                  f"ch_thresh=+-{st.channel_threshold}")
            for i, si in enumerate(st.stage_infos):
                ch = unet._stage_channels[i] if i < len(unet._stage_channels) else "?"
                print(f" Stage {i+1} ({ch}ch): "
                      f"b={si.backbone_factor:.3f} [{si.b_start_ratio:.2f}-{si.b_end_ratio:.2f}] "
                      f"{si.backbone_blend_mode}:{si.backbone_blend:.2f} "
                      f"s={si.skip_factor:.3f} [{si.s_start_ratio:.2f}-{si.s_end_ratio:.2f}] "
                      f"fft={si.fft_type} r={si.fft_radius_ratio:.3f} "
                      f"hfe={si.skip_high_end_factor:.2f} hfb={si.hf_boost:.2f} "
                      f"cap={'ON' if si.enable_adaptive_cap else 'off'} "
                      f"({si.cap_threshold:.2f}/{si.cap_factor:.2f} {si.adaptive_cap_mode})")

    def process_batch(self, p, *args, **kwargs):
        """Reset the step counters at the start of each batch."""
        global_state.current_sampling_step = 0
        # FIX: reset PostCFG step counter for each image in batch
        if hasattr(p, "_mega_pcfg"):
            p._mega_pcfg["step"] = 0

    def postprocess(self, p, processed, *args, **kwargs):
        """Clean up per-image state after generation."""
        if hasattr(p, "_mega_pcfg"):
            p._mega_pcfg = {"enabled": False}


def _write_generation_params(p, st):
    """Write full Mega FreeU state into PNG extra_generation_params.

    The "MegaFreeU PostCFG" entry is only written when Post-CFG is enabled.
    """
    params = p.extra_generation_params
    params["MegaFreeU Schedule"] = (
        f"{st.start_ratio}, {st.stop_ratio}, {st.transition_smoothness}")
    params["MegaFreeU Stages"] = (
        json.dumps([si.to_dict() for si in st.stage_infos]))
    params["MegaFreeU Version"] = st.version
    params["MegaFreeU Multiscale Mode"] = st.multiscale_mode
    params["MegaFreeU Multiscale Strength"] = str(st.multiscale_strength)
    params["MegaFreeU Override Scales"] = st.override_scales or ""
    params["MegaFreeU Channel Threshold"] = str(st.channel_threshold)
    params["MegaFreeU Verbose"] = str(st.verbose)
    if st.pcfg_enabled:
        params["MegaFreeU PostCFG"] = json.dumps({
            "enabled": st.pcfg_enabled,
            "steps": st.pcfg_steps,
            "mode": st.pcfg_mode,
            "blend": st.pcfg_blend,
            "b": st.pcfg_b,
            "fourier": st.pcfg_fourier,
            "ms_mode": st.pcfg_ms_mode,
            "ms_str": st.pcfg_ms_str,
            "threshold": st.pcfg_threshold,
            "s": st.pcfg_s,
            "gain": st.pcfg_gain,
        })


def _flat_to_sis(flat) -> List[global_state.StageInfo]:
    """Rebuild STAGES_COUNT StageInfo objects from a flat value sequence.

    Values are consumed in chunks of ``_SN`` in ``_SF`` order; fields with no
    corresponding value keep their StageInfo default.
    """
    result = []
    for i in range(global_state.STAGES_COUNT):
        chunk = flat[i * _SN:(i + 1) * _SN]
        si = global_state.StageInfo()
        for j, fname in enumerate(_SF):
            if j < len(chunk):
                setattr(si, fname, chunk[j])
        result.append(si)
    return result


# Callbacks
def _on_cfg_step(*_args, **_kwargs):
    """Advance the global sampling-step counter by one."""
    global_state.current_sampling_step += 1


def _on_cfg_post(params):
    """WAS_PostCFGShift ported to A1111 on_cfg_after_cfg callback (exact algorithm)."""
    # NOTE(review): the processing object is looked up as params.p, then
    # params.denoiser.p; when the callback params expose neither, the shift is
    # silently skipped. Confirm which attribute the targeted web UI provides.
    p = getattr(params, "p", None)
    if p is None:
        p = getattr(getattr(params, "denoiser", None), "p", None)
    if p is None:
        return
    cfg = getattr(p, "_mega_pcfg", None)
    if not cfg or not cfg.get("enabled"):
        return

    # Count this call, then stop once the configured step budget is exceeded.
    cfg["step"] = cfg.get("step", 0) + 1
    if cfg["step"] > cfg["steps"]:
        return

    x = params.x
    fn = unet.BLENDING_MODES.get(cfg["mode"], unet.BLENDING_MODES["inject"])
    y = fn(x, x * cfg["b"], cfg["blend"])
    if cfg["fourier"]:
        ms = global_state.MSCALES.get(cfg["ms_mode"])
        y = unet.filter_skip_box_multiscale(
            y, cfg["threshold"], cfg["s"], ms, cfg["ms_str"])
    if cfg["gain"] != 1.0:
        y = y * float(cfg["gain"])
    params.x = y


try:
    script_callbacks.on_cfg_after_cfg(_on_cfg_step)
    script_callbacks.on_cfg_after_cfg(_on_cfg_post)
except AttributeError:
    # webui < 1.6.0 (sd-webui-freeu compatibility note)
    script_callbacks.on_cfg_denoised(_on_cfg_step)
    script_callbacks.on_cfg_denoised(_on_cfg_post)


def _on_after_component(component, **kwargs):
    """Capture the txt2img/img2img steps slider and run deferred registrations."""
    eid = kwargs.get("elem_id", "")
    for key, sid in [("txt2img", "txt2img_steps"), ("img2img", "img2img_steps")]:
        if eid == sid:
            _steps_comps[key] = component
            for cb in _steps_cbs[key]:
                cb(component)
            _steps_cbs[key].clear()


script_callbacks.on_after_component(_on_after_component)


def _on_ui_settings():
    """Register the "auto-enable on PNG info load" option."""
    shared.opts.add_option(
        "mega_freeu_png_auto_enable",
        shared.OptionInfo(
            default=True,
            label="Auto-enable Mega FreeU when loading PNG info from a FreeU generation",
            section=("mega_freeu", "Mega FreeU")))


script_callbacks.on_ui_settings(_on_ui_settings)
script_callbacks.on_before_ui(xyz_grid.patch)

# Install th.cat patch at import (sd-webui-freeu pattern)
unet.patch()