| import gradio as gr | |
| from shared.utils.plugins import WAN2GPPlugin | |
| import json | |
| import traceback | |
class LoraMultipliersUIPlugin(WAN2GPPlugin):
    """Plugin that renders a slider-based editor for the ``loras_multipliers`` textbox.

    The multipliers string holds one space-separated entry per selected LoRA.
    An optional ``|`` separates the leading "accelerator" LoRAs from the user
    LoRAs. Each entry may be split into comma-separated step ranges. Each step
    range holds up to three ``;``-separated per-phase values.

    Python builds a JSON description of the LoRAs and their values. A
    JavaScript renderer turns it into sliders. Slider edits are serialised back
    into the textbox through a hidden input and a hidden button.
    """

    # Stylesheet injected next to the editor; ``__CONTAINER_ID__`` is replaced per instance.
    _CSS_TEMPLATE = """
<style>
#__CONTAINER_ID__ { font-family: 'Segoe UI', 'Roboto', 'Helvetica Neue', sans-serif; }
.lora-main-container { border: 1px solid var(--border-color-primary); border-radius: 8px; padding: 12px; margin-bottom: 8px !important; background-color: var(--background-fill-secondary); }
.lora-main-container .gr-row { display: flex; margin-bottom: 8px; justify-content: space-between; align-items: center; }
.lora-main-container .gr-row h3 { margin-top: 0; margin-bottom: 10px; }
.lora-step-split-container { border: 1px dashed var(--border-color-accent); border-radius: 6px; padding: 10px; margin-top: 8px; }
.lora-slider-row { display: flex; gap: 16px; align-items: end; }
.lora-main-container button, .lora-main-container .wgp-lora-button { padding: 4px 12px !important; font-size: 1em !important; min-width: fit-content !important; flex-grow: 0; background: var(--button-secondary-background-fill); color: var(--button-secondary-text-color); border: 1px solid var(--button-secondary-border-color); border-radius: 4px; cursor: pointer; }
.lora-main-container button:hover, .lora-main-container .wgp-lora-button:hover { background: var(--button-secondary-background-fill-hover); border-color: var(--button-secondary-border-color-hover); }
.lora-section-header h3 { border-bottom: 1px solid var(--border-color-primary); padding-bottom: 4px; margin-top: 16px; margin-bottom: 8px; }
#__CONTAINER_ID__ > .lora-section-header:first-child h3 { margin-top: 0; }
.lora-main-container > h3:first-child { margin-top: 0; }
.lora-slider-group { flex: 1; display: flex; flex-direction: column; }
.lora-slider-input-wrapper { display: flex; align-items: center; gap: 8px; }
.lora-slider-group label { display: block; color: var(--body-text-color); font-size: 0.9em; margin-bottom: 4px; }
.lora-slider-group input[type=range] { flex-grow: 1; width: auto; }
.lora-slider-group input[type=number] {
    width: 60px;
    padding: 4px;
    border: 1px solid var(--border-color-primary);
    border-radius: 4px;
    background-color: var(--input-background-fill);
    color: var(--input-text-color);
    font-size: 0.9em;
    text-align: center;
    -moz-appearance: textfield;
}
.lora-slider-group input[type=number]::-webkit-inner-spin-button,
.lora-slider-group input[type=number]::-webkit-outer-spin-button { -webkit-appearance: none; margin: 0; }
.hidden { display: none !important; }
.lora-split-title { margin-bottom: 8px; }
</style>
"""

    # Browser-side editor, run once on page load. ``__WGP_LORA_CONFIG__`` is
    # replaced by a JSON object holding the per-instance element ids and the
    # renderer name. Using JSON avoids brace-doubling inside an f-string.
    _JS_TEMPLATE = """
() => {
    const CFG = __WGP_LORA_CONFIG__;
    const phasesGlobal = 'wgp_guidance_phases_' + CFG.instanceId;
    const stepsGlobal = 'wgp_total_steps_' + CFG.instanceId;

    const debounce = (func, delay) => {
        let timeout;
        return (...args) => {
            clearTimeout(timeout);
            timeout = setTimeout(() => func(...args), delay);
        };
    };
    // Values edited through the UI stay inside the slider bounds [0, 1].
    const clampMultiplier = (raw) => {
        const val = parseFloat(raw);
        return isNaN(val) ? 0 : Math.max(0, Math.min(1, val));
    };
    // Compact text form: 1 -> "1", 0.5 -> "0.5", 0.333 -> "0.333".
    const formatMultiplier = (raw) => String(parseFloat(Number(raw).toFixed(4)));

    // Serialise every visible phase value back into the Python textbox.
    const pushToPython = debounce(() => {
        const container = document.getElementById(CFG.containerId);
        if (!container) return;
        const perLora = [];
        container.querySelectorAll('.lora-main-container').forEach((loraEl) => {
            const perSplit = [];
            loraEl.querySelectorAll('.lora-step-split-container').forEach((splitEl) => {
                // Read the stored exact value, not the range input: a range input
                // clamps and snaps to its step, which would rewrite untouched values.
                const phaseValues = Array.from(splitEl.querySelectorAll('.lora-slider-group'))
                    .filter((group) => !group.classList.contains('hidden'))
                    .map((group) => formatMultiplier(group.dataset.value));
                if (phaseValues.length > 0) perSplit.push(phaseValues.join(';'));
            });
            if (perSplit.length > 0) perLora.push(perSplit.join(','));
        });
        const separatorIndex = parseInt(container.dataset.separatorIndex || '-1', 10);
        let finalString;
        if (separatorIndex > 0 && separatorIndex <= perLora.length) {
            finalString = perLora.slice(0, separatorIndex).join(' ') + '|' + perLora.slice(separatorIndex).join(' ');
        } else {
            finalString = perLora.join(' ');
        }
        const hiddenInput = document.querySelector(`#${CFG.hiddenInputId} textarea`);
        const updateButton = document.getElementById(CFG.updateBtnId);
        if (hiddenInput && updateButton) {
            hiddenInput.value = finalString;
            hiddenInput.dispatchEvent(new Event('input', { bubbles: true }));
            updateButton.click();
        }
    }, 200);

    function createSlider(phase, value, isVisible) {
        const group = document.createElement('div');
        group.className = 'lora-slider-group';
        if (!isVisible) group.classList.add('hidden');
        let current = parseFloat(value);
        if (isNaN(current)) current = 1.0;
        // Exact multiplier for this phase; it changes only when the user edits this phase.
        group.dataset.value = String(current);
        group.innerHTML = `
            <label>Phase ${phase}</label>
            <div class="lora-slider-input-wrapper">
                <input type="range" min="0" max="1" step="0.05">
                <input type="number" min="0" max="1" step="0.05">
            </div>`;
        const rangeInput = group.querySelector('input[type="range"]');
        const numberInput = group.querySelector('input[type="number"]');
        rangeInput.value = String(current);
        numberInput.value = current.toFixed(2);
        const commit = (val) => {
            group.dataset.value = String(val);
            pushToPython();
        };
        rangeInput.addEventListener('input', () => {
            const val = clampMultiplier(rangeInput.value);
            numberInput.value = val.toFixed(2);
            commit(val);
        });
        // While typing, leave the number field alone so partial input such as "0." is not rewritten.
        numberInput.addEventListener('input', () => {
            if (numberInput.value === '' || isNaN(parseFloat(numberInput.value))) return;
            const val = clampMultiplier(numberInput.value);
            rangeInput.value = String(val);
            commit(val);
        });
        // Normalise the displayed text once editing is finished.
        numberInput.addEventListener('change', () => {
            const val = clampMultiplier(numberInput.value);
            numberInput.value = val.toFixed(2);
            rangeInput.value = String(val);
            commit(val);
        });
        return group;
    }

    function createStepSplit(values, visiblePhases) {
        const split = document.createElement('div');
        split.className = 'lora-step-split-container';
        const title = document.createElement('div');
        title.className = 'lora-split-title';
        title.appendChild(document.createElement('strong'));
        split.appendChild(title);
        const row = document.createElement('div');
        row.className = 'lora-slider-row';
        for (let i = 0; i < 3; i++) {
            const value = values[i] !== undefined ? values[i] : 1.0;
            row.appendChild(createSlider(i + 1, value, (i + 1) <= visiblePhases));
        }
        split.appendChild(row);
        return split;
    }

    function updateRejoinVisibility(loraContainer) {
        if (!loraContainer) return;
        const splits = loraContainer.querySelectorAll('.lora-step-split-container');
        const rejoinBtn = loraContainer.querySelector('.rejoin-btn');
        if (rejoinBtn) rejoinBtn.style.display = splits.length > 1 ? 'inline-block' : 'none';
    }

    // Divide the total step count evenly across the splits; earlier splits take the remainder.
    function recalculateStepRanges(loraContainer) {
        const totalSteps = window[stepsGlobal] || 1;
        const splits = loraContainer.querySelectorAll('.lora-step-split-container');
        if (splits.length === 0) return;
        const baseSteps = Math.floor(totalSteps / splits.length);
        const remainder = totalSteps % splits.length;
        let startStep = 0;
        splits.forEach((split, i) => {
            const endStep = startStep + baseSteps + (i < remainder ? 1 : 0);
            const titleStrong = split.querySelector('.lora-split-title strong');
            if (titleStrong) {
                titleStrong.textContent = `Steps ${startStep + 1} to ${Math.max(startStep + 1, endStep)}`;
            }
            startStep = endStep;
        });
    }

    function addSplit(loraContainer) {
        loraContainer.appendChild(createStepSplit([1.0, 1.0, 1.0], window[phasesGlobal]));
        recalculateStepRanges(loraContainer);
        updateRejoinVisibility(loraContainer);
        pushToPython();
    }

    function removeLastSplit(loraContainer) {
        const splits = loraContainer.querySelectorAll('.lora-step-split-container');
        if (splits.length <= 1) return;
        splits[splits.length - 1].remove();
        recalculateStepRanges(loraContainer);
        updateRejoinVisibility(loraContainer);
        pushToPython();
    }

    function makeButton(extraClass, label, loraIndex) {
        const btn = document.createElement('button');
        btn.type = 'button';
        btn.className = `wgp-lora-button ${extraClass}`;
        btn.dataset.loraIndex = String(loraIndex);
        btn.textContent = label;
        return btn;
    }

    function makeHeader(text) {
        const header = document.createElement('div');
        header.className = 'lora-section-header';
        const h3 = document.createElement('h3');
        h3.textContent = text;
        header.appendChild(h3);
        return header;
    }

    window[CFG.renderer] = (jsonData) => {
        let data;
        try { data = JSON.parse(jsonData); } catch (e) { console.error('Error parsing Lora UI JSON:', e); return; }
        // Python sends "{}" when building the payload failed; keep the current UI in that case.
        if (!data || !Array.isArray(data.loras)) return;
        const container = document.getElementById(CFG.containerId);
        if (!container) return;
        container.innerHTML = '';
        window[phasesGlobal] = data.guidance_phases;
        window[stepsGlobal] = data.total_steps;
        container.dataset.separatorIndex = data.separator_index;
        if (data.separator_index > 0 && data.loras.length > 0) {
            container.appendChild(makeHeader('Accelerator LoRAs'));
        }
        data.loras.forEach((lora, i) => {
            if ((data.separator_index === -1 && i === 0) || data.separator_index === i) {
                container.appendChild(makeHeader('User LoRAs'));
            }
            const loraContainer = document.createElement('div');
            loraContainer.className = 'lora-main-container';
            loraContainer.id = `lora-container-${CFG.instanceId}-${i}`;
            const headerRow = document.createElement('div');
            headerRow.className = 'gr-row';
            const title = document.createElement('h3');
            title.textContent = lora.name;  // textContent: names are never parsed as HTML
            const buttonBar = document.createElement('div');
            buttonBar.style.display = 'flex';
            buttonBar.style.gap = '8px';
            const splitBtn = makeButton('split-btn', 'Split Steps', i);
            const rejoinBtn = makeButton('rejoin-btn', 'Rejoin Step', i);
            rejoinBtn.style.display = 'none';
            splitBtn.addEventListener('click', () => addSplit(loraContainer));
            rejoinBtn.addEventListener('click', () => removeLastSplit(loraContainer));
            buttonBar.append(splitBtn, rejoinBtn);
            headerRow.append(title, buttonBar);
            loraContainer.appendChild(headerRow);
            (lora.splits || []).forEach((split) => {
                loraContainer.appendChild(createStepSplit(split.values || [], data.guidance_phases));
            });
            container.appendChild(loraContainer);
            recalculateStepRanges(loraContainer);
            updateRejoinVisibility(loraContainer);
        });
    };
}
"""

    def __init__(self):
        super().__init__()
        self.name = "Lora Multipliers UI"
        self.version = "1.0.9"
        self.description = "Dynamically set lora multipliers with a fast, JavaScript-powered UI."
        # Per-instance sync state, keyed by the multipliers component id.
        # NOTE(review): this lives on the plugin object, so every browser
        # session shares it. A per-session gr.State would isolate them.
        self.previous_loras_state = {}
        self.request_component("loras_multipliers")
        self.request_component("loras_choices")
        self.request_component("guidance_phases")
        self.request_component("num_inference_steps")
        self.request_component("main")

    @staticmethod
    def _new_state():
        """Return a fresh sync-state record: UI LoRA order, accelerators, per-name values, last seen string."""
        return {'loras': [], 'accelerators': [], 'multipliers': {}, 'multipliers_str': None}

    @staticmethod
    def _split_multipliers(multipliers_str):
        """Split a multipliers string into per-LoRA tokens, ignoring '|' and any whitespace (spaces or newlines)."""
        return (multipliers_str or "").replace('|', ' ').split()

    @staticmethod
    def _parse_phase_values(step_str):
        """Parse one ``a;b;c`` step entry into exactly three floats.

        Missing, unparsable or non-finite values become 1.0. Non-finite values
        would otherwise serialise as ``NaN``/``Infinity``, which is invalid JSON.
        """
        import math

        phase_strs = step_str.split(';')
        values = []
        for k in range(3):
            try:
                value = float(phase_strs[k])
            except (ValueError, IndexError):
                value = 1.0
            values.append(value if math.isfinite(value) else 1.0)
        return values

    def _build_ui_payload(self, instance_id, selected_lora_names, multipliers_str,
                          guidance_phases_val, total_steps):
        """Build the JSON payload the JS renderer consumes and record the new sync state.

        The textbox is trusted position by position when it has one entry per
        selected LoRA and either:
          * the LoRA selection is unchanged, or
          * the text changed since the plugin last saw it (for example a saved
            setting or preset applied together with its LoRA list, or the very
            first render).
        Otherwise the string is stale, so values are looked up by LoRA name from
        the previous state (default "1.0"), and previously known accelerators are
        kept first.

        Returns the JSON string, or ``"{}"`` if anything fails.
        """
        try:
            lora_names = list(selected_lora_names or [])
            tokens = self._split_multipliers(multipliers_str)
            previous_state = self.previous_loras_state.get(instance_id) or self._new_state()
            old_mult_map = previous_state.get('multipliers', {})
            lora_list_changed = set(previous_state.get('loras', [])) != set(lora_names)
            string_changed = multipliers_str != previous_state.get('multipliers_str')
            trust_string = len(lora_names) == len(tokens) and (not lora_list_changed or string_changed)

            if trust_string:
                names_for_ui = lora_names
                per_lora = tokens
                separator_index = -1
                if multipliers_str and '|' in multipliers_str:
                    separator_index = len(multipliers_str.split('|')[0].split())
            else:
                old_accelerators = set(previous_state.get('accelerators', []))
                accelerators = [name for name in lora_names if name in old_accelerators]
                user_loras = [name for name in lora_names if name not in old_accelerators]
                names_for_ui = accelerators + user_loras
                separator_index = len(accelerators) if accelerators else -1
                per_lora = [old_mult_map.get(name, "1.0") for name in names_for_ui]

            loras_data = []
            mult_map = {}
            for name, steps_and_phases_str in zip(names_for_ui, per_lora):
                mult_map[name] = steps_and_phases_str
                loras_data.append({
                    "name": name,
                    "splits": [{"values": self._parse_phase_values(step_str)}
                               for step_str in steps_and_phases_str.split(',')],
                })

            payload = {
                "loras": loras_data,
                # A phase count of 0/None would hide every slider, and the next
                # edit would then write back an empty multipliers string.
                "guidance_phases": guidance_phases_val or 1,
                "total_steps": total_steps or 1,
                "separator_index": separator_index,
            }
            self.previous_loras_state[instance_id] = {
                'loras': names_for_ui,
                'accelerators': names_for_ui[:separator_index] if separator_index > 0 else [],
                'multipliers': mult_map,
                'multipliers_str': multipliers_str,
            }
            return json.dumps(payload)
        except Exception:
            traceback.print_exc()
            return "{}"

    def _remember_js_value(self, instance_id, new_value):
        """Record a multipliers string written by the JS editor.

        Without this step, the per-LoRA map keeps the values from the last
        Python render. A later LoRA selection change would then revert the
        slider edits.
        """
        state = self.previous_loras_state.get(instance_id)
        if state is None:
            return
        tokens = self._split_multipliers(new_value)
        names = state.get('loras', [])
        # The JS writes entries in the same order the payload listed the LoRAs.
        if len(tokens) == len(names):
            state['multipliers'] = dict(zip(names, tokens))
            state['multipliers_str'] = new_value

    def post_ui_setup(self, components: dict) -> dict:
        """Insert the multiplier editor after ``loras_multipliers`` and wire its events.

        Errors are printed and swallowed so the host UI still builds. Always returns ``{}``.
        """
        try:
            loras_multipliers = components["loras_multipliers"]
            loras_choices = components["loras_choices"]
            guidance_phases = components["guidance_phases"]
            num_inference_steps = components["num_inference_steps"]
            main_ui_block = components["main"]
            instance_id = loras_multipliers._id
            self.previous_loras_state.setdefault(instance_id, self._new_state())

            container_id = f"lora_multiplier_ui_container_{instance_id}"
            update_btn_id = f"lora_mults_update_btn_{instance_id}"
            hidden_input_id = f"lora_mults_hidden_input_{instance_id}"
            js_renderer_func = f"wgpLoraUIRenderer_{instance_id}"
            css = self._CSS_TEMPLATE.replace("__CONTAINER_ID__", container_id)
            js_config = json.dumps({
                "instanceId": str(instance_id),
                "containerId": container_id,
                "hiddenInputId": hidden_input_id,
                "updateBtnId": update_btn_id,
                "renderer": js_renderer_func,
            })
            main_js_script = self._JS_TEMPLATE.replace("__WGP_LORA_CONFIG__", js_config)

            def update_ui_data_from_python(selected_lora_names, multipliers_str, guidance_phases_val, total_steps):
                return self._build_ui_payload(instance_id, selected_lora_names, multipliers_str,
                                              guidance_phases_val, total_steps)

            def update_textbox_from_js(new_value):
                self._remember_js_value(instance_id, new_value)
                return gr.update(value=new_value)

            def create_and_wire_ui():
                with gr.Accordion("Dynamic Lora Multipliers", open=True) as main_accordion:
                    gr.HTML(value=css)
                    gr.HTML(f"<div id='{container_id}'></div>")
                    # Hidden bridge: the JS fills hidden_input and clicks update_button.
                    with gr.Row(visible=False):
                        hidden_input = gr.Text(elem_id=hidden_input_id)
                        update_button = gr.Button(elem_id=update_btn_id)
                    ui_data_json = gr.Text(elem_id=f"ui_data_json_{instance_id}", visible=False)

                main_ui_block.load(fn=None, js=main_js_script)
                ui_data_json.change(
                    fn=None,
                    inputs=[ui_data_json],
                    js=f"(jsonData) => {{ if(window.{js_renderer_func}) window.{js_renderer_func}(jsonData); }}",
                    show_progress="hidden",
                )
                input_components = [loras_choices, loras_multipliers, guidance_phases, num_inference_steps]
                main_ui_block.load(
                    fn=update_ui_data_from_python,
                    inputs=input_components,
                    outputs=[ui_data_json],
                    show_progress="hidden",
                )
                # blur (not change) on the textbox, so writes coming from the JS do not re-render the editor.
                events_to_trigger = [loras_choices.change, guidance_phases.change,
                                     num_inference_steps.change, loras_multipliers.blur]
                for event_fn in events_to_trigger:
                    event_fn(
                        fn=update_ui_data_from_python,
                        inputs=input_components,
                        outputs=[ui_data_json],
                        show_progress="hidden",
                    )
                update_button.click(
                    fn=update_textbox_from_js,
                    inputs=[hidden_input],
                    outputs=[loras_multipliers],
                    show_progress="hidden",
                )
                return main_accordion

            self.insert_after(
                target_component_id="loras_multipliers",
                new_component_constructor=create_and_wire_ui,
            )
        except Exception:
            traceback.print_exc()
        return {}