# sup-toolbox-app / ui/ui_events.py
# Last commit 1a844b7 (elismasilva): fixed layout and added verification to env variable RUN_ON_SPACES
# Copyright 2025 The DEVAIEXP Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import json
import logging
import math
import os
import queue
import random
import sys
import threading
from dataclasses import asdict, fields, is_dataclass
from pathlib import Path
from typing import Any, Union, cast
import gradio as gr
from gradio_imagemeta.helpers import extract_metadata, transfer_metadata
from gradio_livelog.utils import ProgressTracker, Tee, TqdmToQueueWriter, capture_logs
from gradio_propertysheet import PropertySheet
from gradio_propertysheet.helpers import flatten_dataclass_with_labels
from PIL import Image
from sup_toolbox.enums import (
ColorFix,
ImageSizeFixMode,
PromptMethod,
RestorerEngine,
Sampler,
StartPoint,
SUPIRModel,
UpscalerEngine,
UpscalingMode,
WeightingMethod,
)
from sup_toolbox.utils.system import infer_type
from ui.globals import pipeline_lock, sup_toolbox_pipe
from ui.ui_config import (
APPSETTINGS_SHEET_DEPENDENCY_RULES,
DEFAULT_PROMPTS,
RESTORER_CONFIG_MAPPING,
RESTORER_SHEET_DEPENDENCY_RULES,
RUN_ON_SPACES,
SAMPLER_MAPPING,
SUPIR_ADVANCED_RULES,
UPSCALER_CONFIG_MAPPING,
UPSCALER_SHEET_DEPENDENCY_RULES,
ControlNetTile_Config,
FaithDiff_Config,
SUPIR_Config,
SUPIRAdvanced_Config,
)
from ui.ui_layout import UIComponents
from ui.ui_state import AppState
from ui.util.dataclass_helpers import (
apply_dynamic_changes,
dataclass_from_dict,
get_nested_attr,
)
class EventHandlers:
def __init__(self, state: AppState, components: UIComponents):
    """
    Bind the shared application state and the UI component registry.

    Args:
        state: Application-wide state (configuration, cancel event, engine selections).
        components: Container exposing the Gradio components (e.g. `ALL_UI_COMPONENTS`).
    """
    self.state, self.components = state, components
def update_pipeline(self, log_callback=None, progress_bar_handler=None):
    """
    Create the shared SUPToolBoxPipeline on first use, or refresh the existing one.

    `SUPToolBoxPipeline` is imported inside this method, so the heavy pipeline
    module is only loaded once a pipeline is actually needed. Both creation and
    refresh run while holding `pipeline_lock`, so concurrent callers cannot
    build two pipelines at once. NOTE(review): if `pipeline_lock` is a
    `threading` lock, as its use here suggests, it only serializes threads of
    this process, not separate processes.

    Parameters
    ----------
    log_callback : callable, optional
        Forwarded to the pipeline for emitting log messages. Defaults to `None`.
    progress_bar_handler : callable, optional
        Forwarded to the pipeline for progress reporting. Defaults to `None`.

    Side Effects
    ------------
    - First call: binds a new `SUPToolBoxPipeline(self.state.uidata.config, ...)`
      to the name `sup_toolbox_pipe` in this module's globals. NOTE(review): the
      name was imported from `ui.globals`; the `global` statement rebinds this
      module's copy only, so `ui.globals.sup_toolbox_pipe` is not reassigned here.
    - Later calls: overwrite `config`, `log_callback`, `progress_bar_handler`
      and `cancel_event` on the existing instance.

    Returns
    -------
    None
    """
    global sup_toolbox_pipe
    from sup_toolbox.sup_toolbox_pipeline import SUPToolBoxPipeline

    with pipeline_lock:
        if sup_toolbox_pipe is not None:
            # Reuse the loaded pipeline; only refresh what may have changed.
            sup_toolbox_pipe.config = self.state.uidata.config
            sup_toolbox_pipe.log_callback = log_callback
            sup_toolbox_pipe.progress_bar_handler = progress_bar_handler
            sup_toolbox_pipe.cancel_event = self.state.cancel_event
            return

        sup_toolbox_pipe = SUPToolBoxPipeline(
            self.state.uidata.config,
            log_callback=log_callback,
            progress_bar_handler=progress_bar_handler,
            cancel_event=self.state.cancel_event,
        )
def get_supir_advanced_values(self):
    """
    Build a SUPIRAdvanced_Config whose two stage-1 cross-up blocks have SFT disabled.

    A new `SUPIRAdvanced_Config()` is constructed with its defaults, and then
    `sft_active` is set to False on `cross_up_block_0_stage1` and
    `cross_up_block_1_stage1`. No parameters are needed.

    Returns:
        SUPIRAdvanced_Config: The newly constructed configuration.

    Example:
        cfg = self.get_supir_advanced_values()
        assert cfg.cross_up_block_0_stage1.sft_active is False
        assert cfg.cross_up_block_1_stage1.sft_active is False
    """
    advanced = SUPIRAdvanced_Config()
    for block_name in ("cross_up_block_0_stage1", "cross_up_block_1_stage1"):
        getattr(advanced, block_name).sft_active = False
    return advanced
def calculate_effective_steps(
    self,
    config: Union[SUPIR_Config, FaithDiff_Config, ControlNetTile_Config],
    is_upscaler: bool,
) -> int:
    """
    Calculate the total effective number of inference steps for one engine configuration.

    Each pass runs ``min(int(num_steps * strength), num_steps)`` steps. For an
    upscaler in 'Progressive' mode the image is enlarged in successive 2x passes.
    After the first pass, ``strength`` decays linearly: the total decay,
    ``strength * strength_decay_rate``, is spread evenly over the remaining
    passes. Each pass's strength is floored at 0.1 and rounded to 2 decimals.
    The steps of every pass are summed.

    Args:
        config: The configuration object for the engine. Attributes read:
            ``general.num_images`` (default 1), ``general.num_steps`` (default 0),
            ``strength`` (default 1.0) and, for progressive upscaling,
            ``general.upscaling_mode``, ``general.upscale_factor`` (e.g. "4x",
            "4" or 4; unparsable values count as 1x) and
            ``general.strength_decay_rate`` (default 0.5).
        is_upscaler: True when ``config`` belongs to the upscaler stage.

    Returns:
        The total effective steps across all images and passes, or 0 when
        ``config`` is falsy or has no ``general`` section.
    """
    if not config or not hasattr(config, "general"):
        return 0

    general = config.general
    num_images = getattr(general, "num_images", 1)
    num_steps = getattr(general, "num_steps", 0)
    # NOTE(review): strength is read from the top-level config, while the other
    # settings come from `config.general`; confirm this matches the dataclass layout.
    initial_strength = getattr(config, "strength", 1.0)

    def steps_for(strength):
        # Truncate the fractional step count and never exceed num_steps.
        return min(int(num_steps * strength), num_steps)

    # 1. The first (or only) pass uses the initial strength.
    steps_per_image = steps_for(initial_strength)

    # 2. Progressive upscaling adds further 2x passes with a decaying strength.
    is_progressive_mode = (
        is_upscaler
        and hasattr(general, "upscaling_mode")
        and general.upscaling_mode == UpscalingMode.Progressive.value
    )
    if is_progressive_mode:
        raw_factor = getattr(general, "upscale_factor", "1x")
        try:
            if isinstance(raw_factor, (int, float)):
                # Numeric factors such as 4 are used directly.
                target_scale_factor = int(raw_factor)
            else:
                # String factors such as "4x", "4X" or "4"; empty/None means 1x.
                target_scale_factor = int(str(raw_factor or "").strip().lower().rstrip("x"))
        except (ValueError, OverflowError):
            target_scale_factor = 1

        # A 2x target needs only the first pass; decay applies from two 2x passes on.
        num_2x_passes = math.ceil(math.log2(target_scale_factor)) if target_scale_factor > 1 else 1
        if num_2x_passes > 1:
            strength_decay_rate = getattr(general, "strength_decay_rate", 0.5)
            num_decay_steps = num_2x_passes - 1
            # Total decay (initial_strength * rate) is spread evenly over the extra passes.
            strength_decay_per_step = initial_strength * strength_decay_rate / num_decay_steps
            current_strength = initial_strength
            for _ in range(num_decay_steps):
                current_strength -= strength_decay_per_step
                # Safety floor of 0.1, same as in the pipeline.
                steps_per_image += steps_for(round(max(current_strength, 0.1), 2))

    # 3. Every generated image goes through all passes.
    return num_images * steps_per_image
def prepare_engine_configs(
    self,
    restorer_engine: str,
    restorer_config: dict,
    upscaler_engine: str,
    upscaler_config: dict,
):
    """
    Select which engine configurations apply to the current run.

    A configuration is passed through only when its engine is a supported one:
    SUPIR or FaithDiff for the restorer; SUPIR, FaithDiff or ControlNetTile for
    the upscaler. Otherwise `None` is returned in its place. `typing.cast` is a
    static-typing hint only, so the objects are returned unchanged (not
    converted or copied).

    Args:
        restorer_engine: Name of the selected restorer engine.
        restorer_config: Restorer configuration from the UI. NOTE(review):
            annotated as `dict`, but `calculate_effective_steps` reads it by
            attribute, so it is presumably the PropertySheet dataclass instance.
        upscaler_engine: Name of the selected upscaler engine.
        upscaler_config: Upscaler configuration from the UI (same caveat).

    Returns:
        tuple: `(restorer_config or None, upscaler_config or None)`.
    """
    supported_restorers = (RestorerEngine.SUPIR.value, RestorerEngine.FaithDiff.value)
    diffusion_upscalers = (UpscalerEngine.SUPIR.value, UpscalerEngine.FaithDiff.value)

    prepared_restorer = (
        cast(Union[SUPIR_Config, FaithDiff_Config], restorer_config)
        if restorer_engine in supported_restorers
        else None
    )

    if upscaler_engine == UpscalerEngine.ControlNetTile.value:
        prepared_upscaler = cast(ControlNetTile_Config, upscaler_config)
    elif upscaler_engine in diffusion_upscalers:
        prepared_upscaler = cast(Union[SUPIR_Config, FaithDiff_Config], upscaler_config)
    else:
        prepared_upscaler = None

    return prepared_restorer, prepared_upscaler
def calculate_total_steps(self, res_config, ups_config, res_engine_name, ups_engine_name):
    """
    Sum the effective inference steps of the restorer and upscaler stages.

    A stage contributes only when its config is truthy and its engine name is
    not the string "None". Each contribution comes from
    `self.calculate_effective_steps`, with `is_upscaler=False` for the restorer
    and `True` for the upscaler; the restorer is evaluated first.

    Args:
        res_config: The prepared configuration object for the restorer.
        ups_config: The prepared configuration object for the upscaler.
        res_engine_name: The selected restorer engine name.
        ups_engine_name: The selected upscaler engine name.

    Returns:
        int: The combined number of effective steps (0 when no stage is active).
    """
    stages = (
        (res_config, res_engine_name, False),
        (ups_config, ups_engine_name, True),
    )
    return sum(
        self.calculate_effective_steps(stage_config, is_upscaler=is_upscaler)
        for stage_config, engine_name, is_upscaler in stages
        if stage_config and engine_name != "None"
    )
def generate_image_metadata(
    self,
    res_config_class,
    ups_config_class,
    res_supir_advanced_config_class,
    ups_supir_advanced_config_class,
    input_params,
    res_sampler_config_class,
    ups_sampler_config_class,
):
    """
    Build one flat metadata dictionary from the run's inputs and configurations.

    The result starts as a shallow copy of `input_params`. A stage is active
    when its engine name is truthy and is not "none" (case-insensitive). For
    each active stage, every truthy configuration object is flattened with
    `flatten_dataclass_with_labels`, and its keys are added with a prefix:

    - engine config         -> "<Stage> - <engine> - <label>"
    - SUPIR advanced config -> "<Stage> - <label>"
    - sampler config        -> "<Stage> - Sampler - <label>"

    where <Stage> is "Restorer" or "Upscaler".

    Args:
        res_config_class: Restorer engine configuration (dataclass instance or None).
        ups_config_class: Upscaler engine configuration (dataclass instance or None).
        res_supir_advanced_config_class: Restorer SUPIR advanced configuration, or None.
        ups_supir_advanced_config_class: Upscaler SUPIR advanced configuration, or None.
        input_params (dict): User inputs; must contain "Image Restore Engine"
            and "Image Upscale Engine".
        res_sampler_config_class: Restorer sampler configuration, or None.
        ups_sampler_config_class: Upscaler sampler configuration, or None.

    Returns:
        dict: The merged metadata. A prefixed key overwrites an equal key that
        is already present.
    """
    restore_engine = input_params["Image Restore Engine"]
    upscale_engine = input_params["Image Upscale Engine"]
    metadata = input_params.copy()

    def merge_flattened(instance: Any, prefix: str):
        # Falsy instances (e.g. None for an unused sheet) contribute nothing.
        if not instance:
            return
        flat = flatten_dataclass_with_labels(instance)
        metadata.update({f"{prefix} - {label}": value for label, value in flat.items()})

    stage_sources = (
        ("Restorer", restore_engine, res_config_class, res_supir_advanced_config_class, res_sampler_config_class),
        ("Upscaler", upscale_engine, ups_config_class, ups_supir_advanced_config_class, ups_sampler_config_class),
    )
    for stage, engine, engine_config, advanced_config, sampler_config in stage_sources:
        if not engine or engine.lower() == "none":
            continue
        merge_flattened(engine_config, f"{stage} - {engine}")
        merge_flattened(advanced_config, stage)
        merge_flattened(sampler_config, f"{stage} - Sampler")

    return metadata
def update_preset_list(self, preset):
    """
    Refresh the preset dropdown's choices and select `preset`.

    Calls `self.state.uidata.get_preset_list()` (its return value is ignored)
    and then reads `self.state.uidata.PRESETS_LIST`, which that call presumably
    refreshes.

    Args:
        preset (str): Value to select in the dropdown after the refresh.

    Returns:
        gr.update: Update carrying the new choices and the selected value.
    """
    uidata = self.state.uidata
    uidata.get_preset_list()
    return gr.update(choices=uidata.PRESETS_LIST, value=preset)
def save_preset(self, preset_name: str, *all_component_values):
    """
    Save the current values of all UI components as a named preset.

    Parameters:
        preset_name (str): Name of the preset. Surrounding whitespace is ignored.
            It must not be empty and must not be an existing default preset
            (a listed name containing "Default:").
        *all_component_values: Current component values, in the same order as
            `self.components.ALL_UI_COMPONENTS` (the order of the inputs passed
            to the `.click()` event).

    Raises:
        gr.Error: If the name is empty, or if it would overwrite a default preset.

    Notes:
        - Values are keyed by elem_id (the keys of `ALL_UI_COMPONENTS`), which is
          safer than relying on positions.
        - The sampler settings objects currently in `SAMPLER_MAPPING` are stored
          under "restorer_sampler_settings" / "upscaler_sampler_settings" when
          an entry exists.
        - 'restorer_supir_advanced_settings' is dropped unless the restorer engine
          is SUPIR; 'upscaler_supir_advanced_settings' is dropped unless the
          upscaler engine is SUPIR.
        - Dataclass values (e.g. PropertySheet values) are converted with `asdict`.
    """
    # Validate the same trimmed name that is saved, so " Default: X" cannot pass
    # the default-preset guard and then be stored as "Default: X".
    name = (preset_name or "").strip()
    if not name:
        raise gr.Error("Please enter a preset name.")
    if name in self.state.uidata.PRESETS_LIST and "Default:" in name:
        raise gr.Error("A default preset cannot be overwritten; please set a different name.")

    all_ui_components = self.components.ALL_UI_COMPONENTS
    component_values = dict(zip(all_ui_components.keys(), all_component_values))

    # Sampler settings live in SAMPLER_MAPPING, not in a UI input. The key can be
    # absent (on_update_ear_visibility checks for it), so skip it instead of
    # raising KeyError.
    for settings_key, sampler_key in (
        ("restorer_sampler_settings", "restorer_sampler"),
        ("upscaler_sampler_settings", "upscaler_sampler"),
    ):
        sampler_settings = SAMPLER_MAPPING.get(sampler_key)
        if sampler_settings is not None:
            component_values[settings_key] = sampler_settings

    # SUPIR advanced sheets are only meaningful when that stage runs SUPIR.
    if component_values.get("restorer_engine") != RestorerEngine.SUPIR.value:
        component_values.pop("restorer_supir_advanced_settings", None)
    if component_values.get("upscaler_engine") != UpscalerEngine.SUPIR.value:
        component_values.pop("upscaler_supir_advanced_settings", None)

    preset_data = {
        elem_id: asdict(value) if dataclasses.is_dataclass(value) else value
        for elem_id, value in component_values.items()
    }
    self.state.uidata.save_preset(name, preset_data)
    gr.Info(f"Preset '{name}' saved successfully!")
def load_preset(self, preset_name: str):
    """
    Load a saved preset into the UI components and the backend state.

    Besides producing one update per UI component, this method also makes two
    backend changes. It replaces the engine-specific entries of
    RESTORER_CONFIG_MAPPING / UPSCALER_CONFIG_MAPPING with the rebuilt sheets.
    It also copies the saved sampler settings into the objects held in
    SAMPLER_MAPPING.

    Parameters:
        preset_name (str): The name of the preset to load. Must be a non-empty string.

    Returns:
        list: One entry per component of `ALL_UI_COMPONENTS`, in order. Components
        found in the preset get a `gr.update(value=...)`; all others get `gr.skip()`.
        PropertySheet values are rebuilt into dataclass instances.

    Raises:
        gr.Error: If no preset is selected, if the preset is empty or cannot be
        found, or if applying it fails.
    """
    all_ui_components = self.components.ALL_UI_COMPONENTS
    # `gr.skip()` means "do not change this component".
    output_updates = [gr.skip()] * len(all_ui_components)

    if not preset_name or not preset_name.strip():
        raise gr.Error("No preset selected to load.")

    try:
        preset_data = self.state.uidata.load_preset(preset_name.strip())
        if not preset_data:
            raise gr.Error(f"Preset '{preset_name}' is empty or could not be found.")

        # The engines saved in the preset decide which config dataclass each
        # engine-specific sheet is rebuilt into. Reading them by key avoids
        # depending on the engine dropdowns sitting at output positions 0 and 1.
        restorer_engine = preset_data.get("restorer_engine")
        upscaler_engine = preset_data.get("upscaler_engine")

        for output_index, (elem_id, component) in enumerate(all_ui_components.items()):
            if elem_id not in preset_data:
                continue
            value_from_preset = preset_data[elem_id]

            if not isinstance(component, PropertySheet):
                # Standard Gradio component: use the stored value directly.
                output_updates[output_index] = gr.update(value=value_from_preset)
                continue

            # PropertySheet: rebuild a dataclass instance from the stored dict.
            dc_type = type(getattr(component, "_dataclass_value", None))
            if not (is_dataclass(dc_type) and isinstance(value_from_preset, dict)):
                continue

            # Computed fresh for each sheet, so no instance leaks between sheets.
            instance = self._rebuild_engine_sheet_from_preset(
                component.elem_id, value_from_preset, restorer_engine, upscaler_engine
            )
            if instance is None:
                instance = dataclass_from_dict(dc_type, value_from_preset)
            output_updates[output_index] = gr.update(value=instance)

        # Copy saved sampler settings into the live objects in SAMPLER_MAPPING.
        # Preset keys are "<sampler elem_id>_settings"; SAMPLER_MAPPING is keyed
        # by the sampler elem_id itself (e.g. "restorer_sampler").
        for sampler_key in ("restorer_sampler_settings", "upscaler_sampler_settings"):
            sampler_data = preset_data.get(sampler_key)
            target_instance = SAMPLER_MAPPING.get(sampler_key.rpartition("_")[0])
            if target_instance and is_dataclass(target_instance) and isinstance(sampler_data, dict):
                # Same field-wise update as on_flyout_change.
                for field_name, value in sampler_data.items():
                    if hasattr(target_instance, field_name):
                        setattr(target_instance, field_name, value)

        gr.Info(f"Preset '{preset_name}' loaded successfully.")
        return output_updates
    except gr.Error:
        # Already user-facing; don't wrap it a second time.
        raise
    except Exception as e:
        raise gr.Error(f"Failed to load or apply preset '{preset_name}': {e}") from e

def _rebuild_engine_sheet_from_preset(self, sheet_id, sheet_values, restorer_engine, upscaler_engine):
    """
    Rebuild an engine-specific settings sheet from preset values and store it.

    Only the restorer/upscaler settings sheets and the SUPIR advanced sheets are
    handled. For these, the dataclass type comes from the instance currently
    stored in RESTORER_CONFIG_MAPPING / UPSCALER_CONFIG_MAPPING for the preset's
    engine. The rebuilt instance replaces that mapping entry.

    Args:
        sheet_id (str): elem_id of the PropertySheet being loaded.
        sheet_values (dict): The sheet's values from the preset.
        restorer_engine: Restorer engine name stored in the preset (may be None).
        upscaler_engine: Upscaler engine name stored in the preset (may be None).

    Returns:
        The rebuilt dataclass instance, or None when `sheet_id` is not an
        engine-specific sheet for the preset's engine.
    """
    if sheet_id in ("restorer_settings", "restorer_supir_advanced_settings"):
        engine = restorer_engine
    elif sheet_id in ("upscaler_settings", "upscaler_supir_advanced_settings"):
        engine = upscaler_engine
    else:
        return None

    # (sheet elem_id, engine) -> (config mapping, mapping key, dependency rules or None).
    # Upscaler engine sheets are rebuilt without dependency rules.
    routes = {
        ("restorer_settings", RestorerEngine.SUPIR.value): (
            RESTORER_CONFIG_MAPPING, RestorerEngine.SUPIR.value, RESTORER_SHEET_DEPENDENCY_RULES,
        ),
        ("restorer_supir_advanced_settings", RestorerEngine.SUPIR.value): (
            RESTORER_CONFIG_MAPPING, "SUPIRAdvanced", SUPIR_ADVANCED_RULES,
        ),
        ("restorer_settings", RestorerEngine.FaithDiff.value): (
            RESTORER_CONFIG_MAPPING, RestorerEngine.FaithDiff.value, RESTORER_SHEET_DEPENDENCY_RULES,
        ),
        ("upscaler_settings", UpscalerEngine.SUPIR.value): (
            UPSCALER_CONFIG_MAPPING, UpscalerEngine.SUPIR.value, None,
        ),
        ("upscaler_supir_advanced_settings", UpscalerEngine.SUPIR.value): (
            UPSCALER_CONFIG_MAPPING, "SUPIRAdvanced", SUPIR_ADVANCED_RULES,
        ),
        ("upscaler_settings", UpscalerEngine.FaithDiff.value): (
            UPSCALER_CONFIG_MAPPING, UpscalerEngine.FaithDiff.value, None,
        ),
        ("upscaler_settings", UpscalerEngine.ControlNetTile.value): (
            UPSCALER_CONFIG_MAPPING, UpscalerEngine.ControlNetTile.value, None,
        ),
    }
    route = routes.get((sheet_id, engine))
    if route is None:
        return None

    mapping, mapping_key, rules = route
    instance = dataclass_from_dict(type(mapping[mapping_key]), sheet_values)
    if rules is not None:
        instance = apply_dynamic_changes(instance, rules)
    mapping[mapping_key] = instance
    return instance
def restart(self):
    """
    Restart the application by replacing the current process with a new Python
    interpreter that runs the same command line (`sys.argv`).

    `os.execv` does not return on success, and exec* does not flush Python-level
    I/O buffers. stdout/stderr are therefore flushed first, so the notice below
    is not lost when output is redirected to a file or pipe.
    """
    print("Please wait. The UI is being restarted...")
    for stream in (sys.stdout, sys.stderr):
        flush = getattr(stream, "flush", None)
        if callable(flush):
            try:
                flush()
            except (OSError, ValueError):
                # Best effort only: a closed or broken stream must not block the restart.
                pass
    os.execv(sys.executable, [os.path.basename(sys.executable)] + sys.argv)
# region Flyout Event Function Logic
def on_handle_flyout_toggle(self, is_vis, current_anchor, *, clicked_elem_id, target_elem_id):
    """
    Show, hide, or re-anchor a flyout panel in response to a click.

    - Clicked element has no `SAMPLER_MAPPING` entry: the flyout is always shown
      at that element, and the flyout sheet is left untouched (`gr.skip()`).
    - Clicked element has sampler settings and is already the visible anchor:
      the flyout is hidden.
    - Otherwise: the flyout is shown at the clicked element and filled with
      that element's settings.

    Args:
        is_vis (bool): Current visibility of the flyout.
        current_anchor (str): elem_id the flyout is currently anchored to.
        clicked_elem_id (str): elem_id of the element that was just clicked.
        target_elem_id (str): elem_id of the flyout panel to control.

    Returns:
        tuple: (flyout_visible, active_anchor_id, flyout_sheet update,
        js_data_bridge update carrying a JSON command for the frontend).
    """
    def js_command(visible, anchor_id):
        payload = {"isVisible": visible, "anchorId": anchor_id, "targetId": target_elem_id}
        return gr.update(value=json.dumps(payload))

    settings_obj = SAMPLER_MAPPING.get(clicked_elem_id)
    if settings_obj is None:
        # Not backed by a PropertySheet: just show and position the flyout.
        return True, clicked_elem_id, gr.skip(), js_command(True, clicked_elem_id)

    if is_vis and current_anchor == clicked_elem_id:
        # A second click on the active anchor hides the flyout.
        return False, None, gr.update(), js_command(False, None)

    return (
        True,
        clicked_elem_id,
        gr.update(value=settings_obj),
        js_command(True, clicked_elem_id),
    )
def on_update_ear_visibility(self, elem_id: str):
    """
    Show a sampler dropdown's 'ear' button only when `SAMPLER_MAPPING` has an entry for it.

    Args:
        elem_id (str): elem_id of the sampler dropdown component.

    Returns:
        gr.update: Visibility update for the ear button.
    """
    return gr.update(visible=elem_id in SAMPLER_MAPPING)
def on_flyout_change(self, updated_settings, active_id):
    """
    Copy the flyout PropertySheet's new values into the matching SAMPLER_MAPPING entry.

    The object stored at `SAMPLER_MAPPING[active_id]` is updated in place. Every
    dataclass field it declares is copied, provided `updated_settings` also has
    an attribute of that name. Nothing happens when either argument is None or
    when `active_id` has no entry.

    Args:
        updated_settings (dataclass): New settings object from the PropertySheet.
        active_id (str): elem_id of the component that opened the flyout; used
            as the `SAMPLER_MAPPING` key.

    Returns:
        None
    """
    if updated_settings is None or active_id is None:
        return
    if active_id not in SAMPLER_MAPPING:
        return

    target = SAMPLER_MAPPING[active_id]
    for fld in dataclasses.fields(target):
        if not hasattr(updated_settings, fld.name):
            continue
        setattr(target, fld.name, getattr(updated_settings, fld.name))
def on_close_the_flyout(self, target_elem_id):
    """
    Hide the flyout panel.

    Args:
        target_elem_id (str): elem_id of the flyout panel to close.

    Returns:
        tuple: (False for the visibility state, None for the active anchor id,
        js_data_bridge update with the JSON "hide" command).
    """
    hide_command = {"isVisible": False, "anchorId": None, "targetId": target_elem_id}
    return False, None, gr.update(value=json.dumps(hide_command))
def initial_flyout_setup(self):
    """
    Compute the initial visibility of the sampler 'ear' buttons on app load.

    Returns:
        dict: Maps each ear button's elem_id ("restorer_sampler_ear_btn",
        "upscaler_sampler_ear_btn") to the update returned by
        `on_update_ear_visibility` for its sampler dropdown.
    """
    ear_buttons = {
        "restorer_sampler_ear_btn": "restorer_sampler",
        "upscaler_sampler_ear_btn": "upscaler_sampler",
    }
    return {
        button_id: self.on_update_ear_visibility(dropdown_id)
        for button_id, dropdown_id in ear_buttons.items()
    }
# endregion
# region Tokenizer Event Function Logic
def update_positive_tokenizer(self, p1, p2):
    """
    Join the two positive prompt boxes into one string for the tokenizer display.

    The prompts are joined with a newline, and leading/trailing whitespace of
    the combined text is removed.

    Args:
        p1 (str): Content of the first prompt textbox.
        p2 (str): Content of the second prompt textbox.

    Returns:
        gr.update: Update for the TokenizerTextBox with the combined prompt.
    """
    combined = "\n".join((f"{p1}", f"{p2}"))
    return gr.update(value=combined.strip())
def update_positive_prompt_helper(self, evt: gr.EventData):
    """
    Point the positive TagGroupHelper at the prompt textbox that triggered the event.

    Intended for the focus event of a prompt textbox, so that clicked tags are
    inserted into the active box.

    Args:
        evt (gr.EventData): Gradio event data; `evt.target.elem_id` identifies the textbox.

    Returns:
        gr.update: Update setting the helper's `target_textbox_id`.
    """
    focused_textbox_id = evt.target.elem_id
    return gr.update(target_textbox_id=focused_textbox_id)
def update_prompt_helper_from_tab(self, evt: gr.EventData):
    """
    Retarget both TagGroupHelpers when a tab is selected.

    The tab with elem_id "res-tab" targets the restorer prompt boxes; any other
    tab targets the upscaler prompt boxes.

    Args:
        evt (gr.EventData): Gradio event data; `evt.target.elem_id` identifies the tab.

    Returns:
        tuple: (positive helper update, negative helper update).
    """
    if evt.target.elem_id != "res-tab":
        positive_id, negative_id = "upscaler_prompt_1", "upscaler_negative_prompt"
    else:
        positive_id, negative_id = "restorer_prompt_1", "restorer_negative_prompt"
    return (
        gr.update(target_textbox_id=positive_id),
        gr.update(target_textbox_id=negative_id),
    )
# endregion
# region Other UI Event Function Logic
def on_reset_settings(self):
    """
    Handle the 'reset settings' button: load the defaults via `uidata` and
    notify the user that the UI will restart.

    NOTE(review): this method does not restart anything itself; the restart is
    presumably triggered by a separate handler (e.g. `restart`). Confirm in the
    event wiring.
    """
    uidata = self.state.uidata
    uidata.load_defaults()
    gr.Info("Defaults loaded! UI will be restarted!")
def on_save_settings(self, settings_dict):
    """
    Handle the 'save settings' button: persist the AppSettings sheet via `uidata`
    and notify the user that the UI will restart.

    Args:
        settings_dict (AppSettings): The dataclass instance from the settings
            PropertySheet (despite the parameter name, not a dict). It is
            converted with `dataclasses.asdict` before being saved.
    """
    serialized = asdict(settings_dict)
    self.state.uidata.save_settings(serialized)
    gr.Info("Settings saved! UI will be restarted!")
def on_settings_sheet_change(self, updated_settings: Any):
    """
    Re-apply the AppSettings sheet's dynamic visibility rules after an edit.

    `quantization_mode` is shown only when `quantization_method` is not "None".
    The resulting rule is registered through `uidata.add_visibility_rules` and
    then applied with `apply_dynamic_changes`.

    Args:
        updated_settings (Any): The edited AppSettings dataclass instance, or None.

    Returns:
        The instance returned by `apply_dynamic_changes`, or None when
        `updated_settings` is None.
    """
    if updated_settings is None:
        return updated_settings

    show_quantization_mode = updated_settings.quantization_method != "None"
    self.state.uidata.add_visibility_rules(
        APPSETTINGS_SHEET_DEPENDENCY_RULES,
        [{"quantization_mode": show_quantization_mode}],
    )
    return apply_dynamic_changes(updated_settings, APPSETTINGS_SHEET_DEPENDENCY_RULES)
def on_set_default_prompts(
    self,
    res_engine_name,
    ups_engine_name,
    res_prompt_value,
    res_prompt_2_value,
    res_negative_prompt_value,
    ups_prompt_value,
    ups_prompt_2_value,
    ups_negative_prompt_value,
):
    """
    Fill the prompt boxes with engine defaults when needed.

    Each of the six prompt boxes is resolved independently:

    - engine name is "None": the box is cleared ("");
    - box is non-empty and the engine did not change since the last call
      (`self.state.*_engine_selected`): the user's text is kept;
    - otherwise: the engine's default from
      `DEFAULT_PROMPTS[<"Restorer"|"Upscaler">][engine][<key>]` is used.

    Afterwards the current engine names are stored in `self.state` for the next call.

    Args:
        res_engine_name (str): The selected restorer engine name.
        ups_engine_name (str): The selected upscaler engine name.
        res_prompt_value (str): Current value of the restorer's prompt 1.
        res_prompt_2_value (str): Current value of the restorer's prompt 2.
        res_negative_prompt_value (str): Current value of the restorer's negative prompt.
        ups_prompt_value (str): Current value of the upscaler's prompt 1.
        ups_prompt_2_value (str): Current value of the upscaler's prompt 2.
        ups_negative_prompt_value (str): Current value of the upscaler's negative prompt.

    Returns:
        tuple: New values for (restorer prompt, restorer prompt 2, restorer
        negative, upscaler prompt, upscaler prompt 2, upscaler negative).
    """
    previous_restorer = self.state.restorer_engine_selected
    previous_upscaler = self.state.upscaler_engine_selected

    def resolve(section, engine_name, previous_engine, current_value, prompt_key):
        if engine_name == "None":
            return ""
        # Keep the user's text as long as the engine did not change.
        if current_value and previous_engine == engine_name:
            return current_value
        return DEFAULT_PROMPTS[section][engine_name][prompt_key]

    prompts = (
        resolve("Restorer", res_engine_name, previous_restorer, res_prompt_value, "prompt"),
        resolve("Restorer", res_engine_name, previous_restorer, res_prompt_2_value, "prompt_2"),
        resolve("Restorer", res_engine_name, previous_restorer, res_negative_prompt_value, "negative_prompt"),
        resolve("Upscaler", ups_engine_name, previous_upscaler, ups_prompt_value, "prompt"),
        resolve("Upscaler", ups_engine_name, previous_upscaler, ups_prompt_2_value, "prompt_2"),
        resolve("Upscaler", ups_engine_name, previous_upscaler, ups_negative_prompt_value, "negative_prompt"),
    )

    self.state.restorer_engine_selected = res_engine_name
    self.state.upscaler_engine_selected = ups_engine_name
    return prompts
def _update_sheet_changes(self, updated_config, mode="Restorer"):
    """
    Apply dependency rules and seed-randomization logic to a sheet's config.

    Shared by the restorer and upscaler PropertySheets. For the upscaler, the
    visibility of the decay-rate fields is first registered according to the
    current upscaling mode before the dependency rules are applied.

    Args:
        updated_config (dataclass): The configuration dataclass instance to update.
        mode (str): "Restorer" selects the restorer rules/state; any other value
            selects the upscaler rules/state.

    Returns:
        dataclass: The processed configuration (a new seed is drawn when required).
    """
    if mode == "Restorer":
        updated_config = apply_dynamic_changes(updated_config, RESTORER_SHEET_DEPENDENCY_RULES)
        previous_config = self.state.restorer_config_class
    else:
        # Decay-rate fields are hidden in "Direct" upscaling mode and shown otherwise.
        is_direct = updated_config.general.upscaling_mode == "Direct"
        decay_visibility_rules = [
            {"general.cfg_decay_rate": not is_direct},
            {"general.strength_decay_rate": not is_direct},
        ]
        self.state.uidata.add_visibility_rules(UPSCALER_SHEET_DEPENDENCY_RULES, decay_visibility_rules)
        updated_config = apply_dynamic_changes(updated_config, UPSCALER_SHEET_DEPENDENCY_RULES)
        previous_config = self.state.upscaler_config_class
    was_randomized = get_nested_attr(previous_config, "general.randomize_seed")

    # Draw a new seed when randomization was just switched on, or when it is on
    # and the seed still holds the -1 placeholder.
    wants_random = updated_config.general.randomize_seed
    if wants_random and (not was_randomized or updated_config.general.seed == -1):
        updated_config.general.seed = random.randint(0, self.state.uidata.MAX_SEED)
    return updated_config
def on_restore_engine_change(self, restorer_engine_name: str, upscaler_engine_name: str):
    """
    Handles UI changes when the restorer engine dropdown is modified.

    Loads the selected engine's configuration into the `restorer_sheet`
    PropertySheet, refreshes the SUPIR-advanced config when SUPIR is chosen, and
    clears the restorer configs from state when no restorer is active.

    Args:
        restorer_engine_name (str): The newly selected restorer engine.
        upscaler_engine_name (str): The current upscaler engine.

    Returns:
        Tuple[gr.update, gr.update, gr.update, gr.update, gr.update]: Updates for
        the restoration tab, SUPIR advanced settings, main settings sheet, engine
        configuration accordion, and the config tabs selector.
    """
    is_restorer_active = restorer_engine_name != "None" and restorer_engine_name in RESTORER_CONFIG_MAPPING
    is_upscaler_active = upscaler_engine_name != "None" and upscaler_engine_name in UPSCALER_CONFIG_MAPPING
    is_supir = restorer_engine_name == "SUPIR"

    if is_restorer_active:
        # Membership is guaranteed by `is_restorer_active`, so index directly.
        config_class = RESTORER_CONFIG_MAPPING[restorer_engine_name]
        if is_supir:
            self.state.restorer_supir_advanced_config_class = self.get_supir_advanced_values()
        self.state.restorer_config_class = self._update_sheet_changes(config_class, "Restorer")
    else:
        # No restorer selected: drop both configs. The former `if is_supir` guard here
        # could only be true for an unmapped "SUPIR", so the advanced config was
        # effectively never cleared.
        self.state.restorer_config_class = None
        self.state.restorer_supir_advanced_config_class = None

    ec_visible = is_restorer_active or is_upscaler_active
    # Show the upscaler config tab (index 1) only when it is the sole active engine.
    selected_tab = 1 if is_upscaler_active and not is_restorer_active else 0
    return (
        gr.update(visible=is_restorer_active),
        gr.update(value=self.state.restorer_supir_advanced_config_class, visible=is_supir),
        gr.update(
            value=self.state.restorer_config_class,
            label=f"{restorer_engine_name} Settings",
        ),
        gr.update(visible=ec_visible),
        gr.update(selected=selected_tab),
    )
def on_upscaler_engine_change(self, restorer_engine_name: str, upscaler_engine_name: str):
    """
    Handles UI changes when the upscaler engine dropdown is modified.

    Loads the selected engine's configuration into the `upscaler_sheet`
    PropertySheet, refreshes the SUPIR-advanced config when SUPIR is chosen,
    clears the upscaler configs from state when no upscaler is active, and hides
    the prompt method dropdown for ControlNetTile.

    Args:
        restorer_engine_name (str): The current restorer engine.
        upscaler_engine_name (str): The newly selected upscaler engine.

    Returns:
        Tuple[gr.update, ...]: Updates for the upscaling tab, SUPIR advanced
        settings, main settings sheet, engine accordion, config tabs, and the
        prompt method dropdown.
    """
    is_restorer_active = restorer_engine_name != "None" and restorer_engine_name in RESTORER_CONFIG_MAPPING
    is_upscaler_active = upscaler_engine_name != "None" and upscaler_engine_name in UPSCALER_CONFIG_MAPPING
    is_supir = upscaler_engine_name == "SUPIR"
    is_controlnettile = upscaler_engine_name == "ControlNetTile"

    if is_upscaler_active:
        # Membership is guaranteed by `is_upscaler_active`, so index directly.
        config_class = UPSCALER_CONFIG_MAPPING[upscaler_engine_name]
        if is_supir:
            self.state.upscaler_supir_advanced_config_class = self.get_supir_advanced_values()
        self.state.upscaler_config_class = self._update_sheet_changes(config_class, "Upscaler")
    else:
        # No upscaler selected: drop both configs. The former `if is_supir` guard here
        # could only be true for an unmapped "SUPIR", so the advanced config was
        # effectively never cleared.
        self.state.upscaler_config_class = None
        self.state.upscaler_supir_advanced_config_class = None

    ec_visible = is_restorer_active or is_upscaler_active
    # Show the upscaler config tab (index 1) only when it is the sole active engine.
    selected_tab = 1 if is_upscaler_active and not is_restorer_active else 0
    return (
        gr.update(visible=is_upscaler_active),
        gr.update(value=self.state.upscaler_supir_advanced_config_class, visible=is_supir),
        gr.update(
            value=self.state.upscaler_config_class,
            label=f"{upscaler_engine_name} Settings",
        ),
        gr.update(visible=ec_visible),
        gr.update(selected=selected_tab),
        gr.update(visible=not is_controlnettile),
    )
def on_restorer_sheet_change(self, updated_config: Union[SUPIR_Config, FaithDiff_Config, None]):
    """
    Sync edits from the restorer PropertySheet into application state.

    A non-None config is run through `_update_sheet_changes` (dependency rules
    and seed randomization) and stored; a None value leaves state untouched.

    Args:
        updated_config (dataclass | None): The configuration emitted by the sheet.

    Returns:
        dataclass | None: The restorer configuration currently held in state.
    """
    if updated_config is not None:
        self.state.restorer_config_class = self._update_sheet_changes(updated_config, "Restorer")
    return self.state.restorer_config_class
def on_restorer_supir_advanced_sheet_change(self, updated_config: SUPIRAdvanced_Config | None):
    """
    Sync edits from the restorer's SUPIR-advanced PropertySheet into state.

    A non-None config has `SUPIR_ADVANCED_RULES` applied via
    `apply_dynamic_changes` and is stored; a None value leaves state untouched.
    (No seed handling happens here.)

    Args:
        updated_config (SUPIRAdvanced_Config | None): The configuration emitted by the sheet.

    Returns:
        SUPIRAdvanced_Config | None: The restorer SUPIR-advanced config held in state.
    """
    if updated_config is not None:
        self.state.restorer_supir_advanced_config_class = apply_dynamic_changes(updated_config, SUPIR_ADVANCED_RULES)
    return self.state.restorer_supir_advanced_config_class
def on_upscaler_sheet_change(
    self,
    updated_config: Union[SUPIR_Config, ControlNetTile_Config, FaithDiff_Config, None],
):
    """
    Sync edits from the upscaler PropertySheet into application state.

    A non-None config is run through `_update_sheet_changes` (upscaling-mode
    visibility rules, dependency rules and seed randomization) and stored; a
    None value leaves state untouched.

    Args:
        updated_config (dataclass | None): The configuration emitted by the sheet.

    Returns:
        dataclass | None: The upscaler configuration currently held in state.
    """
    if updated_config is not None:
        self.state.upscaler_config_class = self._update_sheet_changes(updated_config, "Upscaler")
    return self.state.upscaler_config_class
def on_upscaler_supir_advanced_sheet_change(self, updated_config: SUPIRAdvanced_Config | None):
    """
    Sync edits from the upscaler's SUPIR-advanced PropertySheet into state.

    A non-None config has `SUPIR_ADVANCED_RULES` applied via
    `apply_dynamic_changes` and is stored; a None value leaves state untouched.
    (No seed handling happens here.)

    Args:
        updated_config (SUPIRAdvanced_Config | None): The configuration emitted by the sheet.

    Returns:
        SUPIRAdvanced_Config | None: The upscaler SUPIR-advanced config held in state.
    """
    if updated_config is not None:
        self.state.upscaler_supir_advanced_config_class = apply_dynamic_changes(updated_config, SUPIR_ADVANCED_RULES)
    return self.state.upscaler_supir_advanced_config_class
def on_settings_tab_select(self):
    """
    Re-render the settings sheet when the 'Settings' tab is opened.

    Feeds the current app settings through `on_settings_sheet_change` so any
    dynamic rules are reflected in what the sheet displays.

    Returns:
        AppSettings: Whatever `on_settings_sheet_change` produces for the current settings.
    """
    current_settings = self.state.uidata.settings
    return self.on_settings_sheet_change(current_settings)
def on_cancel_click(self):
    """
    Request cancellation of the running process when 'Cancel' is clicked.

    Only sets the shared `cancel_event`; acting on it is left to the code that
    polls the event (the pipeline, which is not visible here).
    """
    cancel_event = self.state.cancel_event
    cancel_event.set()
def load_metadata(self, metadata: dict):
    """
    Loads settings from an image's metadata dictionary into the UI components.

    Builds a map describing how each engine PropertySheet (and, for SUPIR, its
    advanced sheet) should be populated, maps the standard Gradio input
    components to their metadata labels, and delegates the value transfer to
    `transfer_metadata`. Sampler settings are written in place onto the shared
    sampler instances held in `SAMPLER_MAPPING`.

    Args:
        metadata (dict): A dictionary of key-value pairs extracted from image metadata.

    Returns:
        List[Any]: A list of values in the correct order to update the UI
        output components.

    Raises:
        gr.Error: If the metadata lacks the "Image Restore Engine" or
            "Image Upscale Engine" keys, or if a sampler setting cannot be converted.
    """
    # Metadata without these keys was not written by this app (or was stripped);
    # fail with a readable message instead of a bare KeyError.
    missing_keys = [key for key in ("Image Restore Engine", "Image Upscale Engine") if key not in metadata]
    if missing_keys:
        raise gr.Error(f"Image metadata is missing required key(s): {', '.join(missing_keys)}.")

    ui_inputs, output_fields = self.components._get_ui_inputs_and_outputs()
    restorer_sheet = self.components.restorer_sheet
    upscaler_sheet = self.components.upscaler_sheet
    restorer_sheet_supir_advanced = self.components.restorer_sheet_supir_advanced
    upscaler_sheet_supir_advanced = self.components.upscaler_sheet_supir_advanced

    source_restorer_engine = RestorerEngine.from_str(metadata["Image Restore Engine"])
    source_upscaler_engine = UpscalerEngine.from_str(metadata["Image Upscale Engine"])
    # The mapping values are config instances, hence `.__class__` below for the type.
    source_restorer_class = RESTORER_CONFIG_MAPPING.get(source_restorer_engine.value)
    source_upscaler_class = UPSCALER_CONFIG_MAPPING.get(source_upscaler_engine.value)

    # Tell `transfer_metadata` which dataclass type and key prefixes feed each sheet.
    sheet_map = {}
    if source_restorer_class:
        sheet_map[id(restorer_sheet)] = {
            "type": source_restorer_class.__class__,
            "prefixes": ["Restorer", "Image Restore Engine"],
        }
        if source_restorer_engine == RestorerEngine.SUPIR:
            sheet_map[id(restorer_sheet_supir_advanced)] = {
                "type": restorer_sheet_supir_advanced._dataclass_type,
                "prefixes": ["Restorer"],
            }
    if source_upscaler_class:
        sheet_map[id(upscaler_sheet)] = {
            "type": source_upscaler_class.__class__,
            "prefixes": ["Upscaler", "Image Upscale Engine"],
        }
        if source_upscaler_engine == UpscalerEngine.SUPIR:
            sheet_map[id(upscaler_sheet_supir_advanced)] = {
                "type": upscaler_sheet_supir_advanced._dataclass_type,
                "prefixes": ["Upscaler"],
            }

    gradio_map = {id(component): label for label, component in ui_inputs.items()}
    output_values = transfer_metadata(
        output_fields=output_fields,
        metadata=metadata,
        propertysheet_map=sheet_map,
        gradio_component_map=gradio_map,
    )

    # Sampler settings are not UI outputs; they are applied in place to the shared instances.
    sampler_targets = (
        ("Restorer - Sampler", SAMPLER_MAPPING.get("restorer_sampler")),
        ("Upscaler - Sampler", SAMPLER_MAPPING.get("upscaler_sampler")),
    )
    for prefix, sampler_instance in sampler_targets:
        if not (sampler_instance and is_dataclass(sampler_instance)):
            continue
        for sampler_field in fields(sampler_instance):
            label = sampler_field.metadata.get("label", sampler_field.name.replace("_", " ").title())
            metadata_key = f"{prefix} - {label}"
            if metadata_key not in metadata:
                continue
            try:
                setattr(sampler_instance, sampler_field.name, infer_type(metadata[metadata_key]))
            except (ValueError, TypeError) as err:
                print(f"Warning: Could not convert metadata value '{metadata[metadata_key]}' for sampler field '{sampler_field.name}'.")
                raise gr.Error("Error loading Image metadata, see console log.") from err
    return output_values
def on_load_metadata_from_gallery(self, folder_explorer, image_data: gr.EventData):
    """
    Callback to load metadata from an image selected in the 'Generated' gallery.

    Extracts the metadata carried by the selection event, applies it through
    `load_metadata`, re-runs the sheet rules on the loaded engine configs, and
    stores them back into the engine config mappings. When the gallery folder is
    an examples output folder, the input image name is resolved under
    "assets/samples"; otherwise the input image value is cleared.

    Args:
        folder_explorer (Any): The current value of the folder explorer component
            (a path; may be empty/None when nothing is selected).
        image_data (gr.EventData): Event data for the selected image, containing metadata.

    Returns:
        List[Any]: A list of values to update the UI components with the loaded settings.
    """
    _, output_fields = self.components._get_ui_inputs_and_outputs()

    # Validate the event payload before touching anything else.
    if not image_data or not hasattr(image_data, "_data"):
        return [gr.skip()] * len(output_fields)

    # Guard against an empty/None folder value: Path(None) raises TypeError.
    gallery_parts = Path(folder_explorer).parts if folder_explorer else ()
    is_example = all(part in gallery_parts for part in ("outputs", "examples"))

    metadata = image_data._data
    output_values = self.load_metadata(metadata)
    # Indices 0 and 2 hold the restorer and upscaler sheet values respectively.
    if output_values[0]:
        self.state.restorer_config_class = self._update_sheet_changes(output_values[0], "Restorer")
        RESTORER_CONFIG_MAPPING[metadata["Image Restore Engine"]] = self.state.restorer_config_class
    if output_values[2]:
        self.state.upscaler_config_class = self._update_sheet_changes(output_values[2], "Upscaler")
        UPSCALER_CONFIG_MAPPING[metadata["Image Upscale Engine"]] = self.state.upscaler_config_class
    output_values[0], output_values[2] = (
        self.state.restorer_config_class,
        self.state.upscaler_config_class,
    )

    # Index 11 holds the input image value; only example outputs can be mapped back to a bundled sample.
    input_image_name = output_values[11]
    output_values[11] = os.path.join("assets/samples", input_image_name) if is_example and input_image_name else None
    gr.Info("Image metadata loaded.")
    return output_values
def on_input_image_change(self, input_image):
    """
    Record the newly selected input image's file path in application state.

    Args:
        input_image (Any): The input component's value. Its `path` attribute is
            stored when present; otherwise (e.g. the image was cleared) the stored
            path is reset to None.
    """
    try:
        image_path = input_image.path
    except AttributeError:
        image_path = None
    self.state.input_image_path = image_path
def on_load_metadata_from_single_image(self, image_data):
    """
    Apply settings embedded in the uploaded input image to the UI.

    Reads only the app's custom metadata from the image, applies it through
    `load_metadata`, re-runs the sheet rules on the loaded engine configs, and
    stores them back into the engine config mappings. The input image output
    value (index 11) is always reset to None.

    Args:
        image_data (Image.Image | None): The image object from the ImageMeta component.

    Returns:
        List[Any]: Values/updates for the UI components, or `gr.skip()` for every
        output when there is no usable image or no custom metadata.
    """
    _, output_fields = self.components._get_ui_inputs_and_outputs()

    metadata = None
    if image_data and hasattr(image_data, "path"):
        metadata = extract_metadata(image_data, only_custom_metadata=True)
    if not metadata:
        return [gr.skip()] * len(output_fields)

    loaded = self.load_metadata(metadata)
    restorer_value, upscaler_value = loaded[0], loaded[2]
    if restorer_value:
        self.state.restorer_config_class = self._update_sheet_changes(restorer_value, "Restorer")
        RESTORER_CONFIG_MAPPING[metadata["Image Restore Engine"]] = self.state.restorer_config_class
    if upscaler_value:
        self.state.upscaler_config_class = self._update_sheet_changes(upscaler_value, "Upscaler")
        UPSCALER_CONFIG_MAPPING[metadata["Image Upscale Engine"]] = self.state.upscaler_config_class

    new_restorer_value = self.state.restorer_config_class
    new_upscaler_value = self.state.upscaler_config_class
    loaded[0] = new_restorer_value
    loaded[2] = new_upscaler_value
    loaded[11] = None
    gr.Info("Image metadata loaded.")
    return loaded
def on_clear_log_output(self):
    """
    Clear the LiveLog viewer.

    Returns:
        None: Used as the component's new value, which empties the log view.
    """
    cleared_value = None
    return cleared_value
def on_check_inputs(
    self,
    restorer_engine,
    upscaler_engine,
    restorer_model_name,
    upscaler_model_name,
    action="generation_process",
):
    """
    Validate core inputs before a process (generation or masking) starts.

    An input image is always required. For "generation_process", at least one
    engine must be selected and every selected engine needs a model.

    Args:
        restorer_engine (str): The selected restorer engine ("None" when unused).
        upscaler_engine (str): The selected upscaler engine ("None" when unused).
        restorer_model_name (str): The selected restorer model.
        upscaler_model_name (str): The selected upscaler model.
        action (str): The action being started; only "generation_process" runs
            the engine/model checks and uses the "full" log display.

    Returns:
        Tuple[gr.update, gr.update]: Opens the bottom bar and sets the LiveLog display mode.

    Raises:
        gr.Error: With a user-facing message when a check fails.
    """
    if self.state.input_image_path is None:
        raise gr.Error("Input image is required. Please upload an image to proceed.")

    is_generation = action == "generation_process"
    if is_generation:
        uses_restorer = restorer_engine != "None"
        uses_upscaler = upscaler_engine != "None"
        if not (uses_restorer or uses_upscaler):
            raise gr.Error("Please select at least a Restorer or an Upscaler engine.")
        if uses_restorer and restorer_model_name == "None":
            raise gr.Error("Please select a Restorer model when using the Restore engine.")
        if uses_upscaler and upscaler_model_name == "None":
            raise gr.Error("Please select an Upscaler model when using the Upscale engine.")

    bottom_bar_update = gr.update(open=True)
    log_display_update = gr.update(display_mode="full" if is_generation else "log")
    return bottom_bar_update, log_display_update
def on_refresh_restoration_mask(self, mask_prompt, **kwargs):
    """
    Generates a restoration mask for the current input image from a text prompt.

    Reads the input image from `self.state.input_image_path`, prepares the
    pipeline via `update_pipeline`, and calls `generate_prompt_mask`. Intended to
    be wrapped by a `@livelog`-style decorator that injects `log_callback` (and
    optionally `log_name`) through ``kwargs``.

    Args:
        mask_prompt (str): The prompt used to identify the area to mask (e.g., "head").
        **kwargs: Injected by the decorator; `log_callback` (optional) receives
            progress messages, `log_name` selects the logger (default "suptoolbox_app").

    Returns:
        PIL.Image | None: The generated mask image, or None if none was produced
        (a UI warning is shown in that case).

    Raises:
        gr.Error: If no input image has been provided.
        Exception: Any error from mask generation is logged and re-raised.
    """
    log_callback = kwargs.get("log_callback")
    logger = logging.getLogger(kwargs.get("log_name", "suptoolbox_app"))
    # The callback is optional in **kwargs; calling None would raise TypeError.
    if log_callback is not None:
        log_callback(log_content="Starting masking generation...")
    if self.state.input_image_path is None:
        raise gr.Error("Input image must be provided!")
    # The context manager closes the file handle; convert() returns an independent copy.
    with Image.open(self.state.input_image_path) as source_image:
        lq_image = source_image.convert("RGB")
    self.state.cancel_event.clear()
    self.update_pipeline(log_callback=log_callback)
    try:
        mask, _ = sup_toolbox_pipe.generate_prompt_mask(lq_image, mask_prompt)
    except Exception as e:
        logger.exception(f"Error in generation object mask: {e}, process aborted!")
        raise
    if mask is None:
        gr.Warning("Mask couldn't be generated!")
    return mask
def on_generate_caption(self, **kwargs):
    """
    Generates a descriptive caption for the current input image.

    Reads the input image from `self.state.input_image_path`, prepares the
    pipeline via `update_pipeline`, and calls `generate_caption`. Intended to be
    wrapped by a `@livelog`-style decorator that injects `log_callback` (and
    optionally `log_name`) through ``kwargs``.

    Args:
        **kwargs: Injected by the decorator; `log_callback` (optional) receives
            progress messages, `log_name` selects the logger (default "suptoolbox_app").

    Returns:
        str | None: The generated caption, or None if none was produced
        (a UI warning is shown in that case).

    Raises:
        gr.Error: If no input image has been provided.
        Exception: Any error from caption generation is logged and re-raised.
    """
    log_callback = kwargs.get("log_callback")
    logger = logging.getLogger(kwargs.get("log_name", "suptoolbox_app"))
    # The callback is optional in **kwargs; calling None would raise TypeError.
    if log_callback is not None:
        log_callback(log_content="Starting caption generation...")
    if self.state.input_image_path is None:
        raise gr.Error("Input image must be provided!")
    # The context manager closes the file handle; convert() returns an independent copy.
    with Image.open(self.state.input_image_path) as source_image:
        lq_image = source_image.convert("RGB")
    self.state.cancel_event.clear()
    self.update_pipeline(log_callback=log_callback)
    try:
        caption = sup_toolbox_pipe.generate_caption(lq_image)
    except Exception as e:
        logger.exception(f"Error in caption generation: {e}, process aborted!")
        raise
    if caption is None:
        gr.Warning("Caption couldn't be generated!")
    return caption
def on_generate(self, *args):
    """
    Starts the image processing thread and streams its updates to the UI.

    Entry point for the Gradio event chain. Yields 5-tuples of
    (images, log update, and three component updates): first to lock the UI,
    then one per item the worker puts on the queue, and finally to unlock the UI.

    Args:
        *args: Forwarded positionally to `_process_image` after the update queue
            (the first one is the total step count).

    Yields:
        tuple: Values for the event's five outputs.
    """
    # 1. Initial UI state update (disable buttons, show progress).
    yield None, None, gr.update(interactive=False), gr.update(visible=True), gr.update(open=True)

    update_queue = queue.Queue()

    def _run_worker():
        # Always deliver the sentinel: if _process_image raises before its own
        # error handling (e.g. while building metadata), the read loop below
        # would otherwise block forever and leave the UI locked.
        try:
            self._process_image(update_queue, *args)
        except Exception as e:
            logging.getLogger("suptoolbox_app").exception(f"Unhandled error in image processing thread: {e}")
            update_queue.put((None, {"logs": [{"type": "log", "level": "ERROR", "content": f"Error: {e}"}]}))
        finally:
            update_queue.put(None)

    # 2. Start the image processing worker thread.
    diffusion_thread = threading.Thread(target=_run_worker)
    diffusion_thread.start()

    # 3. Relay queued (images, log_update) pairs until the None sentinel arrives.
    final_images, log_update = None, None
    while True:
        update = update_queue.get()
        if update is None:
            break
        images, log_update = update
        if images:
            final_images = images
        yield final_images, log_update, gr.skip(), gr.skip(), gr.skip()

    # 4. Final UI state update (re-enable buttons).
    yield final_images, log_update, gr.update(interactive=True), gr.update(visible=False), gr.skip()
def _process_image(self, update_queue: queue.Queue, total_steps: int, *input_params: Any, **kwargs):
"""
Executes the main image processing pipeline in a worker thread.
This method receives simple, serializable data types from the UI event.
It retrieves complex configuration objects (like PropertySheet values)
directly from the shared application state (`self.state`), which are kept
up-to-date by their respective `.change()` events.
Args:
update_queue (queue.Queue): The queue for sending status updates back to the UI.
total_steps (int): The pre-calculated total number of diffusion steps.
*input_params (Any): A tuple of the remaining simple UI input values
(e.g., engine selections, prompts), passed positionally.
**kwargs: Catches any extra keyword arguments.
Side Effects:
- Puts multiple update dictionaries and a final `None` sentinel onto the `update_queue`.
- Modifies and uses the shared `sup_toolbox_pipe` instance via `self.update_pipeline`.
- Catches all exceptions and reports them via the queue.
"""
from sup_toolbox.sup_toolbox_pipeline import PipelineCancelationRequested
# 1. Prepare input params
restorer_config = self.state.restorer_config_class
upscaler_config = self.state.upscaler_config_class
restorer_supir_advanced_config = self.state.restorer_supir_advanced_config_class
upscaler_supir_advanced_config = self.state.upscaler_supir_advanced_config_class
ui_inputs, _ = self.components._get_ui_inputs_and_outputs()
ui_inputs_keys_sliced = list(ui_inputs.keys())[4:]
input_param_with_values = dict(zip(ui_inputs_keys_sliced, input_params))
input_image, res_engine, ups_engine = (
self.state.input_image_path,
input_param_with_values.get("Image Restore Engine"),
input_param_with_values.get("Image Upscale Engine"),
)
if hasattr(input_image, "orig_name"):
input_param_with_values["Input Image"] = input_image.orig_name
else:
input_param_with_values.pop("Input Image", None)
if input_image is None:
update_queue.put((None, {"logs": [{"type": "log", "level": "ERROR", "content": "Error: Input image is required."}]}))
update_queue.put(None)
return
# 2. Generate image metadata
image_metadata = self.generate_image_metadata(
restorer_config,
upscaler_config,
(restorer_supir_advanced_config if res_engine == RestorerEngine.SUPIR.value else None),
(upscaler_supir_advanced_config if ups_engine == UpscalerEngine.SUPIR.value else None),
input_param_with_values,
SAMPLER_MAPPING["restorer_sampler"],
SAMPLER_MAPPING["upscaler_sampler"],
)
tracker = None
self.state.cancel_event.clear()
with capture_logs(log_level=logging.INFO, log_name=["suptoolbox_app", "suptoolbox"]) as get_logs:
try:
rate_queue = queue.Queue()
tqdm_writer = TqdmToQueueWriter(rate_queue)
progress_bar_handler = Tee(sys.stderr, tqdm_writer)
all_logs, last_known_rate_data = [], None
# 2. Prepare selected engines
config = self.state.uidata.config
_restorer_config, _upscaler_config = self.prepare_engine_configs(res_engine, restorer_config, ups_engine, upscaler_config)
# 3. Define update callback
def process_and_send_updates(status="running", advance=0, final_image_payload=None):
nonlocal all_logs, last_known_rate_data
new_rate_data = None
while not rate_queue.empty():
try:
new_rate_data = rate_queue.get_nowait()
except queue.Empty:
break
if new_rate_data:
last_known_rate_data = new_rate_data
new_records = get_logs()
if new_records:
new_logs = [
{"type": "log", "level": "SUCCESS" if r.levelno == logging.INFO + 5 else r.levelname, "content": r.getMessage()}
for r in new_records
]
all_logs.extend(new_logs)
update_dict = (
tracker.update(advance=advance, status=status, logs=all_logs, rate_data=last_known_rate_data)
if tracker
else {"type": "progress", "logs": all_logs, "current": 0, "total": total_steps, "desc": "Diffusion Steps"}
)
update_queue.put((final_image_payload, update_dict))
logger = logging.getLogger("suptoolbox_app")
logger.info("Starting diffusion process...")
process_and_send_updates()
# 4. Create tracker
tracker = ProgressTracker(total=total_steps, description="Diffusion Steps", rate_unit="s/it")
# 5. Define pipeline progress callback
def progress_callback(_, __, ___, callback_kwargs):
    """Per-step hook installed as ``callback_on_step_end`` on the pipeline params.

    The first three positional arguments are ignored. ``callback_kwargs`` must
    contain an ``"advance"`` entry, which is forwarded as the step increment to
    ``process_and_send_updates``. The same dict is returned unchanged; this
    presumably follows the diffusers-style callback contract of handing the
    kwargs back (assumption, confirm against sup_toolbox).
    """
    step_increment = callback_kwargs["advance"]
    process_and_send_updates(advance=step_increment)
    return callback_kwargs
# 6. Map all pipeline configuration from UI to SUP-Toolbox Pipeline
config = self.state.uidata.config
config.running_on_spaces = True if RUN_ON_SPACES == "True" else False
config.restorer_engine, config.upscaler_engine = (
RestorerEngine.from_str(res_engine),
UpscalerEngine.from_str(ups_engine),
)
config.selected_vae_model = input_param_with_values["VAE Model"]
if res_engine == RestorerEngine.SUPIR.value:
_restorer_config = cast(SUPIR_Config, _restorer_config)
config.selected_restorer_checkpoint_model = input_param_with_values["Image Restore Model"]
config.restorer_engine = RestorerEngine.SUPIR
config.selected_restorer_sampler = Sampler.from_str(input_param_with_values["Image Restore Sampler"])
config.restorer_pipeline_params.supir_model = SUPIRModel.from_str(_restorer_config.supir_model)
config.restorer_pipeline_params.seed = _restorer_config.general.seed
config.restorer_pipeline_params.upscale_factor = _restorer_config.general.upscale_factor
config.restorer_pipeline_params.prompt = input_param_with_values["Image Restore - Prompt 1"]
config.restorer_pipeline_params.prompt_2 = input_param_with_values["Image Restore - Prompt 2"]
config.restorer_pipeline_params.negative_prompt = input_param_with_values["Image Restore - Negative prompt"]
config.restore_face = input_param_with_values["Enable Face Restoration"]
config.mask_prompt = input_param_with_values["Mask Prompt"]
config.restorer_pipeline_params.num_images = _restorer_config.general.num_images
config.restorer_pipeline_params.num_steps = _restorer_config.general.num_steps
config.restorer_pipeline_params.use_lpw_prompt = (
True if input_param_with_values["Image Restore - Prompt method"] == PromptMethod.Weighted.value else False
)
config.restorer_pipeline_params.tile_size = _restorer_config.general.tile_size
config.restorer_pipeline_params.restoration_scale = float(_restorer_config.restoration_scale)
config.restorer_pipeline_params.s_churn = float(_restorer_config.s_churn)
config.restorer_pipeline_params.s_noise = float(_restorer_config.s_noise)
config.restorer_pipeline_params.strength = float(_restorer_config.strength)
config.restorer_pipeline_params.use_linear_CFG = _restorer_config.cfg_settings.use_linear_CFG
config.restorer_pipeline_params.guidance_scale = float(_restorer_config.general.guidance_scale)
config.restorer_pipeline_params.guidance_rescale = float(_restorer_config.general.guidance_rescale)
config.restorer_pipeline_params.reverse_linear_CFG = float(_restorer_config.cfg_settings.reverse_linear_CFG)
config.restorer_pipeline_params.guidance_scale_start = float(_restorer_config.cfg_settings.guidance_scale_start)
config.restorer_pipeline_params.use_linear_control_scale = _restorer_config.controlnet_settings.use_linear_control_scale
config.restorer_pipeline_params.reverse_linear_control_scale = _restorer_config.controlnet_settings.reverse_linear_control_scale
config.restorer_pipeline_params.controlnet_conditioning_scale = float(_restorer_config.controlnet_settings.controlnet_conditioning_scale)
config.restorer_pipeline_params.control_scale_start = float(_restorer_config.controlnet_settings.control_scale_start)
config.restorer_pipeline_params.enable_PAG = _restorer_config.pag_settings.enable_PAG and len(_restorer_config.pag_settings.pag_layers) > 0
config.restorer_pipeline_params.use_linear_PAG = _restorer_config.pag_settings.use_linear_PAG
config.restorer_pipeline_params.reverse_linear_PAG = _restorer_config.pag_settings.reverse_linear_PAG
config.restorer_pipeline_params.pag_scale = float(_restorer_config.pag_settings.pag_scale)
config.restorer_pipeline_params.pag_scale_start = float(_restorer_config.pag_settings.pag_scale_start)
config.restorer_pipeline_params.pag_layers = _restorer_config.pag_settings.pag_layers
config.restorer_pipeline_params.start_point = StartPoint.from_str(_restorer_config.start_point)
config.restorer_pipeline_params.image_size_fix_mode = ImageSizeFixMode.from_str(_restorer_config.general.image_size_fix_mode)
config.restorer_pipeline_params.color_fix_mode = ColorFix.from_str(_restorer_config.post_processsing_settings.color_fix_mode)
(
config.restorer_pipeline_params.zero_sft_injection_configs,
config.restorer_pipeline_params.zero_sft_injection_flags,
) = self.state.uidata.map_ui_supir_injection_to_pipeline_params(restorer_supir_advanced_config)
config.restorer_pipeline_params.callback_on_step_end = progress_callback
config.restorer_sampler_config = self.state.uidata.map_scheduler_settings_to_config(SAMPLER_MAPPING["restorer_sampler"])
elif res_engine == RestorerEngine.FaithDiff.value:
_restorer_config = cast(FaithDiff_Config, _restorer_config)
config.selected_restorer_checkpoint_model = input_param_with_values["Image Restore Model"]
config.restorer_engine = RestorerEngine.FaithDiff
config.selected_restorer_sampler = Sampler.from_str(input_param_with_values["Image Restore Sampler"])
config.restorer_pipeline_params.seed = _restorer_config.general.seed
config.restorer_pipeline_params.upscale_factor = _restorer_config.general.upscale_factor
config.restorer_pipeline_params.prompt = input_param_with_values["Image Restore - Prompt 1"]
config.restorer_pipeline_params.prompt_2 = input_param_with_values["Image Restore - Prompt 2"]
config.restorer_pipeline_params.negative_prompt = input_param_with_values["Image Restore - Negative prompt"]
config.restore_face = input_param_with_values["Enable Face Restoration"]
config.mask_prompt = input_param_with_values["Mask Prompt"]
config.restorer_pipeline_params.num_images = _restorer_config.general.num_images
config.restorer_pipeline_params.num_steps = _restorer_config.general.num_steps
config.restorer_pipeline_params.use_lpw_prompt = (
True if input_param_with_values["Image Restore - Prompt method"] == PromptMethod.Weighted.value else False
)
config.restorer_pipeline_params.tile_size = _restorer_config.general.tile_size
config.restorer_pipeline_params.s_churn = float(_restorer_config.s_churn)
config.restorer_pipeline_params.s_noise = float(_restorer_config.s_noise)
config.restorer_pipeline_params.strength = float(_restorer_config.strength)
config.restorer_pipeline_params.guidance_scale = float(_restorer_config.general.guidance_scale)
config.restorer_pipeline_params.guidance_rescale = float(_restorer_config.general.guidance_rescale)
config.restorer_pipeline_params.use_linear_control_scale = _restorer_config.controlnet_settings.use_linear_control_scale
config.restorer_pipeline_params.reverse_linear_control_scale = _restorer_config.controlnet_settings.reverse_linear_control_scale
config.restorer_pipeline_params.controlnet_conditioning_scale = float(_restorer_config.controlnet_settings.controlnet_conditioning_scale)
config.restorer_pipeline_params.control_scale_start = float(_restorer_config.controlnet_settings.control_scale_start)
config.restorer_pipeline_params.enable_PAG = _restorer_config.pag_settings.enable_PAG and len(_restorer_config.pag_settings.pag_layers) > 0
config.restorer_pipeline_params.use_linear_PAG = _restorer_config.pag_settings.use_linear_PAG
config.restorer_pipeline_params.reverse_linear_PAG = _restorer_config.pag_settings.reverse_linear_PAG
config.restorer_pipeline_params.pag_scale = float(_restorer_config.pag_settings.pag_scale)
config.restorer_pipeline_params.pag_scale_start = float(_restorer_config.pag_settings.pag_scale_start)
config.restorer_pipeline_params.pag_layers = _restorer_config.pag_settings.pag_layers
config.restorer_pipeline_params.start_point = StartPoint.from_str(_restorer_config.start_point)
config.restorer_pipeline_params.image_size_fix_mode = ImageSizeFixMode.from_str(_restorer_config.general.image_size_fix_mode)
config.restorer_pipeline_params.color_fix_mode = ColorFix.from_str(_restorer_config.post_processsing_settings.color_fix_mode)
config.restorer_pipeline_params.invert_prompts = _restorer_config.invert_prompts
config.restorer_pipeline_params.apply_ipa_embeds = _restorer_config.apply_ipa_embeds
config.restorer_pipeline_params.callback_on_step_end = progress_callback
config.restorer_sampler_config = self.state.uidata.map_scheduler_settings_to_config(SAMPLER_MAPPING["restorer_sampler"])
if ups_engine == UpscalerEngine.SUPIR.value:
_upscaler_config = cast(SUPIR_Config, _upscaler_config)
config.selected_upscaler_checkpoint_model = input_param_with_values["Image Upscale Model"]
config.upscaler_engine = UpscalerEngine.SUPIR
config.selected_upscaler_sampler = Sampler.from_str(input_param_with_values["Image Upscale Sampler"])
config.upscaler_pipeline_params.supir_model = SUPIRModel.from_str(_upscaler_config.supir_model)
config.upscaler_pipeline_params.seed = _upscaler_config.general.seed
config.upscaler_pipeline_params.upscale_factor = _upscaler_config.general.upscale_factor
config.upscaler_pipeline_params.prompt = input_param_with_values["Image Upscale - Prompt 1"]
config.upscaler_pipeline_params.prompt_2 = input_param_with_values["Image Upscale - Prompt 2"]
config.upscaler_pipeline_params.negative_prompt = input_param_with_values["Image Upscale - Negative prompt"]
config.upscaler_pipeline_params.num_images = _upscaler_config.general.num_images
config.upscaler_pipeline_params.num_steps = _upscaler_config.general.num_steps
config.upscaler_pipeline_params.use_lpw_prompt = (
True if input_param_with_values["Image Upscale - Prompt method"] == PromptMethod.Weighted.value else False
)
config.upscaler_pipeline_params.tile_size = _upscaler_config.general.tile_size
config.upscaler_pipeline_params.restoration_scale = float(_upscaler_config.restoration_scale)
config.upscaler_pipeline_params.s_churn = float(_upscaler_config.s_churn)
config.upscaler_pipeline_params.s_noise = float(_upscaler_config.s_noise)
config.upscaler_pipeline_params.strength = float(_upscaler_config.strength)
config.upscaler_pipeline_params.use_linear_CFG = _upscaler_config.cfg_settings.use_linear_CFG
config.upscaler_pipeline_params.guidance_scale = float(_upscaler_config.general.guidance_scale)
config.upscaler_pipeline_params.guidance_rescale = float(_upscaler_config.general.guidance_rescale)
config.upscaler_pipeline_params.reverse_linear_CFG = float(_upscaler_config.cfg_settings.reverse_linear_CFG)
config.upscaler_pipeline_params.guidance_scale_start = float(_upscaler_config.cfg_settings.guidance_scale_start)
config.upscaler_pipeline_params.use_linear_control_scale = _upscaler_config.controlnet_settings.use_linear_control_scale
config.upscaler_pipeline_params.reverse_linear_control_scale = _upscaler_config.controlnet_settings.reverse_linear_control_scale
config.upscaler_pipeline_params.controlnet_conditioning_scale = float(_upscaler_config.controlnet_settings.controlnet_conditioning_scale)
config.upscaler_pipeline_params.control_scale_start = float(_upscaler_config.controlnet_settings.control_scale_start)
config.upscaler_pipeline_params.enable_PAG = _upscaler_config.pag_settings.enable_PAG
config.upscaler_pipeline_params.use_linear_PAG = _upscaler_config.pag_settings.use_linear_PAG
config.upscaler_pipeline_params.reverse_linear_PAG = (
_upscaler_config.pag_settings.reverse_linear_PAG and len(_upscaler_config.pag_settings.pag_layers) > 0
)
config.upscaler_pipeline_params.pag_scale = float(_upscaler_config.pag_settings.pag_scale)
config.upscaler_pipeline_params.pag_scale_start = float(_upscaler_config.pag_settings.pag_scale_start)
config.upscaler_pipeline_params.pag_layers = _upscaler_config.pag_settings.pag_layers
config.upscaler_pipeline_params.start_point = StartPoint.from_str(_upscaler_config.start_point)
config.upscaler_pipeline_params.image_size_fix_mode = ImageSizeFixMode.from_str(_upscaler_config.general.image_size_fix_mode)
config.upscaler_pipeline_params.upscaling_mode = UpscalingMode.from_str(_upscaler_config.general.upscaling_mode)
(
config.upscaler_pipeline_params.zero_sft_injection_configs,
config.upscaler_pipeline_params.zero_sft_injection_flags,
) = self.state.uidata.map_ui_supir_injection_to_pipeline_params(upscaler_supir_advanced_config)
config.upscaler_pipeline_params.cfg_decay_rate = _upscaler_config.general.cfg_decay_rate
config.upscaler_pipeline_params.strength_decay_rate = _upscaler_config.general.strength_decay_rate
config.upscaler_pipeline_params.color_fix_mode = ColorFix.from_str(_upscaler_config.post_processsing_settings.color_fix_mode)
config.upscaler_pipeline_params.callback_on_step_end = progress_callback
config.upscaler_sampler_config = self.state.uidata.map_scheduler_settings_to_config(SAMPLER_MAPPING["upscaler_sampler"])
elif ups_engine == UpscalerEngine.FaithDiff.value:
_upscaler_config = cast(FaithDiff_Config, _upscaler_config)
config.selected_upscaler_checkpoint_model = input_param_with_values["Image Upscale Model"]
config.upscaler_engine = UpscalerEngine.FaithDiff
config.selected_upscaler_sampler = Sampler.from_str(input_param_with_values["Image Upscale Sampler"])
config.upscaler_pipeline_params.seed = _upscaler_config.general.seed
config.upscaler_pipeline_params.upscale_factor = _upscaler_config.general.upscale_factor
config.upscaler_pipeline_params.prompt = input_param_with_values["Image Upscale - Prompt 1"]
config.upscaler_pipeline_params.prompt_2 = input_param_with_values["Image Upscale - Prompt 2"]
config.upscaler_pipeline_params.negative_prompt = input_param_with_values["Image Upscale - Negative prompt"]
config.upscaler_pipeline_params.num_images = _upscaler_config.general.num_images
config.upscaler_pipeline_params.num_steps = _upscaler_config.general.num_steps
config.upscaler_pipeline_params.use_lpw_prompt = (
True if input_param_with_values["Image Upscale - Prompt method"] == PromptMethod.Weighted.value else False
)
config.upscaler_pipeline_params.tile_size = _upscaler_config.general.tile_size
config.upscaler_pipeline_params.s_churn = float(_upscaler_config.s_churn)
config.upscaler_pipeline_params.s_noise = float(_upscaler_config.s_noise)
config.upscaler_pipeline_params.strength = float(_upscaler_config.strength)
config.upscaler_pipeline_params.guidance_scale = float(_upscaler_config.general.guidance_scale)
config.upscaler_pipeline_params.guidance_rescale = float(_upscaler_config.general.guidance_rescale)
config.upscaler_pipeline_params.use_linear_control_scale = _upscaler_config.controlnet_settings.use_linear_control_scale
config.upscaler_pipeline_params.reverse_linear_control_scale = _upscaler_config.controlnet_settings.reverse_linear_control_scale
config.upscaler_pipeline_params.controlnet_conditioning_scale = float(_upscaler_config.controlnet_settings.controlnet_conditioning_scale)
config.upscaler_pipeline_params.control_scale_start = float(_upscaler_config.controlnet_settings.control_scale_start)
config.upscaler_pipeline_params.enable_PAG = _upscaler_config.pag_settings.enable_PAG and len(_upscaler_config.pag_settings.pag_layers) > 0
config.upscaler_pipeline_params.use_linear_PAG = _upscaler_config.pag_settings.use_linear_PAG
config.upscaler_pipeline_params.reverse_linear_PAG = _upscaler_config.pag_settings.reverse_linear_PAG
config.upscaler_pipeline_params.pag_scale = float(_upscaler_config.pag_settings.pag_scale)
config.upscaler_pipeline_params.pag_scale_start = float(_upscaler_config.pag_settings.pag_scale_start)
config.upscaler_pipeline_params.pag_layers = _upscaler_config.pag_settings.pag_layers
config.upscaler_pipeline_params.start_point = StartPoint.from_str(_upscaler_config.start_point)
config.upscaler_pipeline_params.image_size_fix_mode = ImageSizeFixMode.from_str(_upscaler_config.general.image_size_fix_mode)
config.upscaler_pipeline_params.upscaling_mode = UpscalingMode.from_str(_upscaler_config.general.upscaling_mode)
config.upscaler_pipeline_params.invert_prompts = _upscaler_config.invert_prompts
config.upscaler_pipeline_params.apply_ipa_embeds = _upscaler_config.apply_ipa_embeds
config.upscaler_pipeline_params.cfg_decay_rate = _upscaler_config.general.cfg_decay_rate
config.upscaler_pipeline_params.strength_decay_rate = _upscaler_config.general.strength_decay_rate
config.upscaler_pipeline_params.color_fix_mode = ColorFix.from_str(_upscaler_config.post_processsing_settings.color_fix_mode)
config.upscaler_pipeline_params.callback_on_step_end = progress_callback
config.upscaler_sampler_config = self.state.uidata.map_scheduler_settings_to_config(SAMPLER_MAPPING["upscaler_sampler"])
elif ups_engine == UpscalerEngine.ControlNetTile.value:
_upscaler_config = cast(ControlNetTile_Config, _upscaler_config)
config.selected_upscaler_checkpoint_model = input_param_with_values["Image Upscale Model"]
config.upscaler_engine = UpscalerEngine.ControlNetTile
config.selected_upscaler_sampler = Sampler.from_str(input_param_with_values["Image Upscale Sampler"])
config.upscaler_pipeline_params.seed = _upscaler_config.general.seed
config.upscaler_pipeline_params.upscale_factor = _upscaler_config.general.upscale_factor
config.upscaler_pipeline_params.prompt = input_param_with_values["Image Upscale - Prompt 1"]
config.upscaler_pipeline_params.prompt_2 = input_param_with_values["Image Upscale - Prompt 2"]
config.upscaler_pipeline_params.negative_prompt = input_param_with_values["Image Upscale - Negative prompt"]
config.upscaler_pipeline_params.num_images = _upscaler_config.general.num_images
config.upscaler_pipeline_params.num_steps = _upscaler_config.general.num_steps
config.upscaler_pipeline_params.tile_size = _upscaler_config.general.tile_size
config.upscaler_pipeline_params.tile_overlap = _upscaler_config.tile_overlap
config.upscaler_pipeline_params.tile_weighting_method = WeightingMethod.from_str(_upscaler_config.tile_weighting_method)
config.upscaler_pipeline_params.tile_gaussian_sigma = _upscaler_config.tile_gaussian_sigma
config.upscaler_pipeline_params.strength = float(_upscaler_config.strength)
config.upscaler_pipeline_params.guidance_scale = float(_upscaler_config.general.guidance_scale)
config.upscaler_pipeline_params.guidance_rescale = float(_upscaler_config.general.guidance_rescale)
config.upscaler_pipeline_params.image_size_fix_mode = ImageSizeFixMode.from_str(_upscaler_config.general.image_size_fix_mode)
config.upscaler_pipeline_params.upscaling_mode = UpscalingMode.from_str(_upscaler_config.general.upscaling_mode)
config.upscaler_pipeline_params.cfg_decay_rate = _upscaler_config.general.cfg_decay_rate
config.upscaler_pipeline_params.strength_decay_rate = _upscaler_config.general.strength_decay_rate
config.upscaler_pipeline_params.color_fix_mode = ColorFix.from_str(_upscaler_config.post_processsing_settings.color_fix_mode)
config.upscaler_pipeline_params.callback_on_step_end = progress_callback
config.upscaler_sampler_config = self.state.uidata.map_scheduler_settings_to_config(SAMPLER_MAPPING["upscaler_sampler"])
config.image_path = input_image
self.update_pipeline(log_callback=process_and_send_updates, progress_bar_handler=progress_bar_handler)
initialize_status = sup_toolbox_pipe.initialize()
if initialize_status:
result, process_status = sup_toolbox_pipe.predict(metadata=image_metadata)
if process_status:
logger.log(logging.INFO + 5, "Image generated successfully!")
process_and_send_updates(status="success", final_image_payload=(input_image, result))
else:
raise RuntimeError("Pipeline prediction returned a failure status.")
else:
raise RuntimeError("Pipeline initialization failed.")
except PipelineCancelationRequested as ce:
logger.warning(str(ce))
process_and_send_updates(status="error")
except Exception as e:
logger.error(f"Error in diffusion thread: {e}, process aborted!", exc_info=True)
process_and_send_updates(status="error")
finally:
update_queue.put(None)