1ripon1's picture
Upload folder using huggingface_hub
7344bef verified
Raw
History Blame Contribute Delete
22 kB
from __future__ import annotations
from dataclasses import dataclass
import gradio as gr
from . import common
from . import constants as ui_constants
from . import frame_planning as frames
from . import process_catalog as catalog
@dataclass(frozen=True)
class ProcessFormState:
    """Immutable, normalized snapshot of every value shown in the process form.

    Instances are produced by ``ProcessFormController.build_form_state`` and
    converted to plain dicts with :meth:`to_dict` (for example when the form
    values of a process are remembered).
    """

    process_model_type: str
    process_name: str
    source_path: str
    process_strength: float
    output_path: str
    prompt: str
    continue_enabled: bool
    source_audio_track: str
    output_resolution: str
    target_ratio: str
    chunk_size_seconds: float
    sliding_window_overlap: int
    # Kept as strings; an empty string is used when no value was provided.
    start_seconds: str
    end_seconds: str

    def to_dict(self) -> dict:
        """Return the state as a plain ``{field_name: value}`` dict.

        Keys follow the field declaration order. They are derived from
        ``dataclasses.fields`` so the serialized form can never drift out of
        sync with the class when fields are added or renamed.
        """
        # Local import: the module header only brings ``dataclass`` into scope.
        from dataclasses import fields

        return {field.name: getattr(self, field.name) for field in fields(self)}
@dataclass(frozen=True)
class FormComponentValues:
    """Unvalidated values captured from the process form components.

    Every field is typed ``object`` because nothing has been coerced yet;
    ``ProcessFormController.build_form_state`` normalizes these values after
    they are converted with :meth:`to_raw_state`.
    """

    source_path: object
    process_strength: object
    output_path: object
    prompt_text: object
    continue_enabled: object
    source_audio_track: object
    output_resolution: object
    target_ratio: object
    chunk_size_seconds: object
    sliding_window_overlap: object
    start_seconds: object
    end_seconds: object

    def to_raw_state(self) -> dict:
        """Map the component values onto the raw-state keys.

        Every key equals its field name except ``prompt_text``, which is
        stored under ``"prompt"``.
        """
        return dict(
            source_path=self.source_path,
            process_strength=self.process_strength,
            output_path=self.output_path,
            prompt=self.prompt_text,
            continue_enabled=self.continue_enabled,
            source_audio_track=self.source_audio_track,
            output_resolution=self.output_resolution,
            target_ratio=self.target_ratio,
            chunk_size_seconds=self.chunk_size_seconds,
            sliding_window_overlap=self.sliding_window_overlap,
            start_seconds=self.start_seconds,
            end_seconds=self.end_seconds,
        )
@dataclass(frozen=True)
class InitialFormPatch:
    """Everything needed to populate the process form on first render.

    Built by ``ProcessFormController.build_initial_form``.
    """

    # Dropdown choices as (label, value) pairs.
    model_type_choices: list[tuple[str, str]]
    process_choices: list[tuple[str, str]]
    # Initially selected values of the model-type and process dropdowns.
    model_type: str
    process_name: str
    # Normalized values for the remaining form fields.
    form_state: ProcessFormState
    # Sliding-window-overlap slider configuration (maximum is 1 when hidden).
    overlap_step: int
    overlap_max: int
    overlap_visible: bool
    # Visibility flags for optional controls.
    output_resolution_visible: bool
    prompt_visible: bool
    process_strength_visible: bool
    # Target-ratio control. When the process defines target controls, the
    # label and choices come from those controls instead of the default
    # "Target Ratio" label and ratio choices.
    target_ratio_visible: bool
    target_ratio_label: str
    target_ratio_choices: list[tuple[str, str]]
# Number of outputs in the tuple returned by ``ProcessFormController.restore_state``;
# ``skipped_restored_form_outputs`` must yield exactly this many placeholders.
RESTORED_FORM_OUTPUT_COUNT = 12


def skipped_restored_form_outputs() -> tuple:
    """Return one ``gr.skip()`` placeholder per restored form output.

    Used in place of ``restore_state``'s outputs when the form should be left
    untouched, so the overall output tuple keeps the same length.
    """
    placeholders = []
    for _slot in range(RESTORED_FORM_OUTPUT_COUNT):
        # A fresh ``gr.skip()`` per slot, exactly one call per output.
        placeholders.append(gr.skip())
    return tuple(placeholders)
class ProcessFormController:
def __init__(
self,
*,
library,
get_model_def,
output_resolution_values: set[str],
source_audio_track_values: set[str],
ratio_values: set[str],
default_model_type: str | None = None,
) -> None:
self.library = library
self.get_model_def = get_model_def
self.output_resolution_values = output_resolution_values
self.source_audio_track_values = source_audio_track_values
self.ratio_values = ratio_values
self.default_model_type = default_model_type or catalog.DEFAULT_MODEL_TYPE
@staticmethod
def _process_strength_slider_bounds(value: float) -> tuple[float, float]:
lower = value if value < 0.0 else 0.0
upper = value if value > 3.0 else 3.0
return lower, upper
@staticmethod
def _fit_overlap_slider_value(value: int, maximum: int) -> int:
if value < 1:
return 1
if value > maximum:
return maximum
return value
def build_initial_form(self, saved_ui_settings: dict, main_state: dict | None, initial_user_refs: list[str]) -> InitialFormPatch:
saved_process_name = str(saved_ui_settings.get("process_name") or "").strip()
saved_model_type = str(saved_ui_settings.get("process_model_type") or "").strip()
saved_process_definition = self.library.process_definition(saved_process_name, main_state, initial_user_refs)
if saved_process_definition is not None:
saved_model_type = self.library.process_definition_model_type(saved_process_definition) or saved_model_type
process_names_by_model_type = self.library.process_values_by_model_type(initial_user_refs)
default_model_type = (
saved_model_type
if saved_model_type in process_names_by_model_type
else catalog.DEFAULT_MODEL_TYPE
if catalog.DEFAULT_MODEL_TYPE in process_names_by_model_type
else next(iter(process_names_by_model_type), catalog.DEFAULT_MODEL_TYPE)
)
default_process_choices = self.library.normal_process_choices(default_model_type, initial_user_refs)
default_process_values = [value for _label, value in default_process_choices]
default_process_name = (
saved_process_name
if saved_process_name in default_process_values
else catalog.DEFAULT_PROCESS_NAME
if catalog.DEFAULT_PROCESS_NAME in default_process_values
else (default_process_values[0] if default_process_values else catalog.DEFAULT_PROCESS_NAME)
)
form_state = self.build_form_state(default_process_name, saved_ui_settings, main_state, initial_user_refs)
default_rules = self.library.process_frame_rules(default_process_name, main_state, initial_user_refs)
target_control_choices = self.library.target_control_choices(default_process_name, main_state, initial_user_refs)
has_target_control = len(target_control_choices) > 0
overlap_visible = not self.library.hides_sliding_window_overlap(default_process_name, main_state, initial_user_refs)
self.default_model_type = default_model_type
return InitialFormPatch(
model_type_choices=self.library.model_type_choices(initial_user_refs),
process_choices=default_process_choices,
model_type=default_model_type,
process_name=default_process_name,
form_state=form_state,
overlap_step=default_rules.frame_step,
overlap_max=frames.get_overlap_slider_max(default_model_type, self.get_model_def) if overlap_visible else 1,
overlap_visible=overlap_visible,
output_resolution_visible=not self.library.hides_output_resolution(default_process_name, main_state, initial_user_refs),
prompt_visible=not self.library.hides_prompt(default_process_name, main_state, initial_user_refs),
process_strength_visible=self.library.is_process_strength_visible(default_process_name, main_state, initial_user_refs),
target_ratio_visible=has_target_control or self.library.has_process_outpaint(default_process_name, main_state, initial_user_refs),
target_ratio_label=self.library.target_control_label(default_process_name, main_state, initial_user_refs) if has_target_control else "Target Ratio",
target_ratio_choices=target_control_choices if has_target_control else ui_constants.RATIO_CHOICES,
)
def user_settings_hint_update(self, process_choices: list[tuple[str, str]]):
return gr.update(visible=self.library.process_choices_have_user_settings(process_choices))
@staticmethod
def settings_action_updates(process_model_type_value: str, process_value: str) -> tuple:
add_visible = process_model_type_value == ui_constants.ADD_USER_SETTINGS_MODEL_TYPE
delete_visible = catalog.is_user_process_value(process_value)
actions_visible = add_visible or delete_visible
return gr.update(visible=actions_visible), gr.update(visible=add_visible), gr.update(visible=delete_visible), gr.update(visible=False)
def target_ratio_update(self, process_name: str, main_state: dict | None, user_refs: list[str] | None, target_ratio: str | None = None):
target_control_choices = self.library.target_control_choices(process_name, main_state, user_refs)
if len(target_control_choices) > 0:
values = {value for _label, value in target_control_choices}
value = str(target_ratio or "").strip()
if value not in values:
value = self.library.target_control_default(process_name, main_state, user_refs)
return gr.update(label=self.library.target_control_label(process_name, main_state, user_refs), value=value, visible=True, choices=target_control_choices)
visible = self.library.has_process_outpaint(process_name, main_state, user_refs)
return gr.update(label="Target Ratio", value=target_ratio if visible else "", visible=visible, choices=ui_constants.RATIO_CHOICES if visible else ui_constants.RATIO_CHOICES_WITH_EMPTY)
def process_strength_update(self, process_name: str, main_state: dict | None, user_refs: list[str] | None, process_strength: float | None = None):
process_definition = self.library.process_definition(process_name, main_state, user_refs)
visible = self.library.is_process_strength_visible(process_name, main_state, user_refs)
default_value = common.get_default_process_strength((process_definition or {}).get("settings", {}))
if isinstance(process_definition, dict) and process_definition.get("source") == "user":
user_default = self.library.user_lora_strength_override_default(process_definition)
if user_default is not None:
default_value = user_default
value = common.coerce_float(process_strength, default_value) if visible else default_value
minimum, maximum = self._process_strength_slider_bounds(value)
return gr.update(value=value, visible=visible, minimum=minimum, maximum=maximum)
def prompt_update(self, process_name: str, main_state: dict | None, user_refs: list[str] | None, prompt: str):
return gr.update(value=prompt, visible=not self.library.hides_prompt(process_name, main_state, user_refs))
def output_resolution_update(self, process_name: str, main_state: dict | None, user_refs: list[str] | None, output_resolution: str):
return gr.update(value=output_resolution, visible=not self.library.hides_output_resolution(process_name, main_state, user_refs))
def overlap_control_updates(self, process_name: str, main_state: dict | None, user_refs: list[str] | None):
if self.library.hides_sliding_window_overlap(process_name, main_state, user_refs):
return gr.update(minimum=0, maximum=1, step=1, value=0, visible=False)
process_definition = self.library.process_definition_or_default(process_name, main_state, user_refs)
settings = process_definition.get("settings", {})
model_type = str(settings.get("model_type") or "")
step = frames.get_vae_temporal_latent_size(model_type, self.get_model_def)
maximum = frames.get_overlap_slider_max(model_type, self.get_model_def)
value = common.coerce_int(settings.get("sliding_window_overlap"), 1, minimum=1)
value = self._fit_overlap_slider_value(frames.normalize_overlap_frames(value, frame_step=step), maximum)
return gr.update(minimum=1, maximum=maximum, step=step, value=value, visible=True)
def build_form_state(self, process_name: str, raw_state: dict | None = None, main_state: dict | None = None, user_refs: list[str] | None = None) -> ProcessFormState:
process_definition = self.library.process_definition_or_default(process_name, main_state, user_refs)
process_settings = process_definition.get("settings", {})
model_type = str(process_settings.get("model_type") or catalog.DEFAULT_MODEL_TYPE)
frame_rules = self.library.process_frame_rules(process_name, main_state, user_refs)
step = int(frame_rules.frame_step)
maximum = frames.get_overlap_slider_max(model_type, self.get_model_def) if not self.library.hides_sliding_window_overlap(process_name, main_state, user_refs) else 1
raw_state = raw_state if isinstance(raw_state, dict) else {}
default_strength = common.get_default_process_strength(process_settings)
saved_process_strength = raw_state.get("process_strength", raw_state.get("control_video_strength"))
process_strength = default_strength if saved_process_strength is None else common.coerce_float(saved_process_strength, default_strength)
source_audio_track = str(raw_state.get("source_audio_track") or "").strip()
output_resolution = str(raw_state.get("output_resolution") or "").strip()
target_control_choices = self.library.target_control_choices(process_name, main_state, user_refs)
if len(target_control_choices) > 0:
target_values = {value for _label, value in target_control_choices}
target_ratio = str(raw_state.get("target_ratio") or process_settings.get("target_ratio") or self.library.target_control_default(process_name, main_state, user_refs)).strip()
if target_ratio not in target_values:
target_ratio = self.library.target_control_default(process_name, main_state, user_refs)
else:
target_ratio = str(raw_state.get("target_ratio") or "4:3").strip()
if self.library.hides_sliding_window_overlap(process_name, main_state, user_refs):
sliding_window_overlap = 0
else:
overlap_default = self._fit_overlap_slider_value(frames.normalize_overlap_frames(common.coerce_int(process_settings.get("sliding_window_overlap"), 1, minimum=1), frame_step=step), maximum)
overlap_value = common.coerce_int(raw_state.get("sliding_window_overlap"), overlap_default, minimum=1)
sliding_window_overlap = self._fit_overlap_slider_value(frames.normalize_overlap_frames(overlap_value, frame_step=step), maximum)
default_chunk_size_seconds = self.library.default_chunk_size_seconds(process_name, main_state, user_refs)
return ProcessFormState(
process_model_type=model_type,
process_name=process_name,
source_path=str(raw_state.get("source_path") or ui_constants.DEFAULT_SOURCE_PATH),
process_strength=process_strength,
output_path=str(raw_state.get("output_path") or ui_constants.DEFAULT_OUTPUT_PATH),
prompt=str(raw_state.get("prompt") or "") if "prompt" in raw_state else str(process_settings.get("prompt") or ""),
continue_enabled=common.coerce_bool(raw_state.get("continue_enabled"), True),
source_audio_track=source_audio_track if source_audio_track in self.source_audio_track_values else "",
output_resolution=output_resolution if output_resolution in self.output_resolution_values else "720p",
target_ratio=target_ratio if len(target_control_choices) > 0 or target_ratio in self.ratio_values else "4:3",
chunk_size_seconds=common.coerce_float(raw_state.get("chunk_size_seconds"), default_chunk_size_seconds, minimum=0.1),
sliding_window_overlap=sliding_window_overlap,
start_seconds="" if raw_state.get("start_seconds") in (None, "") else str(raw_state.get("start_seconds")),
end_seconds="" if raw_state.get("end_seconds") in (None, "") else str(raw_state.get("end_seconds")),
)
def build_state(self, process_name: str, raw_state: dict | None = None, main_state: dict | None = None, user_refs: list[str] | None = None) -> dict:
return self.build_form_state(process_name, raw_state, main_state, user_refs).to_dict()
def snapshot_state(self, process_name: str, main_state: dict | None, user_refs: list[str] | None, values: FormComponentValues) -> dict:
return self.build_state(process_name, values.to_raw_state(), main_state, user_refs)
def store_memory(self, memory_state: dict | None, current_process_name: str, main_state: dict | None, user_refs: list[str] | None, values: FormComponentValues):
updated_memory = dict(memory_state) if isinstance(memory_state, dict) else {}
current_process_name = str(current_process_name or "").strip()
if self.library.process_definition(current_process_name, main_state, user_refs) is not None:
updated_memory[current_process_name] = self.snapshot_state(current_process_name, main_state, user_refs, values)
return updated_memory
def restore_state(self, memory_state: dict | None, process_name: str, current_source_path: str, main_state: dict | None, user_refs: list[str] | None) -> tuple:
state = self.build_form_state(process_name, (memory_state or {}).get(process_name), main_state, user_refs)
source_path_value = current_source_path.strip() or state.source_path
return (
source_path_value,
self.process_strength_update(process_name, main_state, user_refs, state.process_strength),
state.output_path,
self.prompt_update(process_name, main_state, user_refs, state.prompt),
state.continue_enabled,
state.source_audio_track,
self.output_resolution_update(process_name, main_state, user_refs, state.output_resolution),
self.target_ratio_update(process_name, main_state, user_refs, state.target_ratio),
state.chunk_size_seconds,
self.overlap_control_updates(process_name, main_state, user_refs),
state.start_seconds,
state.end_seconds,
)
def change_process_model_type(self, memory_state: dict | None, current_process_name: str, next_model_type: str, main_state: dict | None, main_lset_name: str | None, user_refs: list[str] | None, values: FormComponentValues) -> tuple:
refs = catalog.get_saved_user_settings_refs({catalog.USER_SETTINGS_STORAGE_KEY: user_refs})
updated_memory = self.store_memory(memory_state, current_process_name, main_state, refs, values)
process_choices, next_process_name = self.library.process_choices(next_model_type, main_state, main_lset_name, refs)
next_process_name = str(next_process_name or ui_constants.NO_USER_SETTINGS_VALUE).strip()
catalog.save_process_full_video_selection(next_model_type, next_process_name)
return (
updated_memory,
next_process_name,
gr.update(choices=process_choices, value=next_process_name),
self.user_settings_hint_update(process_choices),
*self.settings_action_updates(next_model_type, next_process_name),
*self.restore_state(updated_memory, next_process_name, values.source_path, main_state, refs),
)
def change_process_name(self, memory_state: dict | None, current_process_name: str, next_process_name: str, process_model_type_value: str, main_state: dict | None, user_refs: list[str] | None, values: FormComponentValues) -> tuple:
refs = catalog.get_saved_user_settings_refs({catalog.USER_SETTINGS_STORAGE_KEY: user_refs})
updated_memory = self.store_memory(memory_state, current_process_name, main_state, refs, values)
next_process_name = str(next_process_name or "").strip()
catalog.save_process_full_video_selection(process_model_type_value, next_process_name)
actions_update, _add_update, delete_update, placeholder_update = self.settings_action_updates(process_model_type_value, next_process_name)
return (
updated_memory,
next_process_name,
actions_update,
delete_update,
placeholder_update,
*self.restore_state(updated_memory, next_process_name, values.source_path, main_state, refs),
)
def refresh_from_main(self, _refresh_id, memory_state: dict | None, current_process_name: str, process_model_type_value: str, main_state: dict | None, main_lset_name: str | None, user_refs: list[str] | None, values: FormComponentValues) -> tuple:
refs = catalog.get_saved_user_settings_refs({catalog.USER_SETTINGS_STORAGE_KEY: user_refs})
model_choices = self.library.model_type_choices(refs)
process_model_type_value = str(process_model_type_value or "").strip()
if process_model_type_value != ui_constants.ADD_USER_SETTINGS_MODEL_TYPE:
valid_model_values = {value for _label, value in model_choices}
model_value = process_model_type_value if process_model_type_value in valid_model_values else self.default_model_type
return (
gr.update(choices=model_choices, value=model_value),
gr.update(),
gr.update(),
memory_state,
current_process_name,
*self.settings_action_updates(model_value, current_process_name),
*skipped_restored_form_outputs(),
)
updated_memory = self.store_memory(memory_state, current_process_name, main_state, refs, values)
process_choices, next_process_name = self.library.current_user_settings_choices(main_state, main_lset_name)
return (
gr.update(choices=model_choices, value=ui_constants.ADD_USER_SETTINGS_MODEL_TYPE),
gr.update(choices=process_choices, value=next_process_name),
self.user_settings_hint_update(process_choices),
updated_memory,
next_process_name,
*self.settings_action_updates(ui_constants.ADD_USER_SETTINGS_MODEL_TYPE, next_process_name),
*self.restore_state(updated_memory, next_process_name, values.source_path, main_state, refs),
)