| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import gradio as gr |
|
|
| from . import common |
| from . import constants as ui_constants |
| from . import frame_planning as frames |
| from . import process_catalog as catalog |
|
|
|
|
@dataclass(frozen=True)
class ProcessFormState:
    """Validated, immutable snapshot of every process-form field.

    ``to_dict`` exposes the same data as a plain dict keyed by field name, in
    declaration order; that dict is what gets kept in per-process form memory.
    """

    # Which process (and its model type) the form is configured for.
    process_model_type: str
    process_name: str
    # Input / output locations and generation controls.
    source_path: str
    process_strength: float
    output_path: str
    prompt: str
    continue_enabled: bool
    source_audio_track: str
    output_resolution: str
    target_ratio: str
    # Chunking and sliding-window settings.
    chunk_size_seconds: float
    sliding_window_overlap: int
    # Optional time range, kept as text ("" when unset).
    start_seconds: str
    end_seconds: str

    def to_dict(self) -> dict:
        """Return all fields as a ``{field_name: value}`` dict, in declaration order."""
        return dict(
            process_model_type=self.process_model_type,
            process_name=self.process_name,
            source_path=self.source_path,
            process_strength=self.process_strength,
            output_path=self.output_path,
            prompt=self.prompt,
            continue_enabled=self.continue_enabled,
            source_audio_track=self.source_audio_track,
            output_resolution=self.output_resolution,
            target_ratio=self.target_ratio,
            chunk_size_seconds=self.chunk_size_seconds,
            sliding_window_overlap=self.sliding_window_overlap,
            start_seconds=self.start_seconds,
            end_seconds=self.end_seconds,
        )
|
|
|
|
@dataclass(frozen=True)
class FormComponentValues:
    """Raw form-component values, captured before any validation.

    Every field is typed ``object`` and is not checked here. ``to_raw_state``
    maps the values onto the key names that ``build_form_state`` reads.
    """

    source_path: object
    process_strength: object
    output_path: object
    # Stored under the "prompt" key by to_raw_state.
    prompt_text: object
    continue_enabled: object
    source_audio_track: object
    output_resolution: object
    target_ratio: object
    chunk_size_seconds: object
    sliding_window_overlap: object
    start_seconds: object
    end_seconds: object

    def to_raw_state(self) -> dict:
        """Return the values as a raw-state dict (``prompt_text`` becomes ``"prompt"``)."""
        raw_state = dict(
            source_path=self.source_path,
            process_strength=self.process_strength,
            output_path=self.output_path,
            prompt=self.prompt_text,
            continue_enabled=self.continue_enabled,
            source_audio_track=self.source_audio_track,
            output_resolution=self.output_resolution,
            target_ratio=self.target_ratio,
            chunk_size_seconds=self.chunk_size_seconds,
            sliding_window_overlap=self.sliding_window_overlap,
            start_seconds=self.start_seconds,
            end_seconds=self.end_seconds,
        )
        return raw_state
|
|
|
|
@dataclass(frozen=True)
class InitialFormPatch:
    """Everything needed to populate the process form on first render.

    Produced by ``ProcessFormController.build_initial_form``. It holds the
    selector choices and selected values, the initial ``ProcessFormState``, and
    the configuration/visibility of the controls that depend on the selected
    process.
    """

    # Selector choices as (label, value) pairs, plus the selected values.
    model_type_choices: list[tuple[str, str]]
    process_choices: list[tuple[str, str]]
    model_type: str
    process_name: str
    # Initial values for every form field.
    form_state: ProcessFormState
    # Sliding-window overlap slider configuration.
    overlap_step: int
    overlap_max: int
    overlap_visible: bool
    # Visibility of process-dependent controls.
    output_resolution_visible: bool
    prompt_visible: bool
    process_strength_visible: bool
    # Target-ratio dropdown configuration; choices are (label, value) pairs.
    target_ratio_visible: bool
    target_ratio_label: str
    target_ratio_choices: list[tuple[str, str]]
|
|
|
|
# Number of outputs returned by ProcessFormController.restore_state; handlers that
# must leave those components untouched emit this many skip updates instead.
RESTORED_FORM_OUTPUT_COUNT: int = 12
|
|
|
|
def skipped_restored_form_outputs() -> tuple:
    """Return one ``gr.skip()`` per restored form output so none of them is updated.

    A separate ``gr.skip()`` object is created for every position.
    """
    skips = []
    for _position in range(RESTORED_FORM_OUTPUT_COUNT):
        skips.append(gr.skip())
    return tuple(skips)
|
|
|
|
class ProcessFormController:
    """Compute process-form state and the matching Gradio updates.

    Process-specific rules are queried from ``library``: hidden controls,
    defaults, choices and frame rules. This class validates form values against
    those rules and packages the results either as ``ProcessFormState`` objects
    or as ``gr.update`` payloads.

    Per-process form values are remembered in a plain
    ``{process_name: state_dict}`` memory dict, so switching back to a process
    restores what the user had entered.
    """

    def __init__(
        self,
        *,
        library,
        get_model_def,
        output_resolution_values: set[str],
        source_audio_track_values: set[str],
        ratio_values: set[str],
        default_model_type: str | None = None,
    ) -> None:
        """Store collaborators and the whitelists used to validate form values.

        Args:
            library: Process catalog facade queried for definitions and UI rules.
            get_model_def: Callable forwarded to the ``frames`` helpers.
            output_resolution_values: Accepted resolutions; others become "720p".
            source_audio_track_values: Accepted audio tracks; others become "".
            ratio_values: Accepted target ratios when the process has no target control.
            default_model_type: Fallback model type for ``refresh_from_main``;
                defaults to ``catalog.DEFAULT_MODEL_TYPE``.
        """
        self.library = library
        self.get_model_def = get_model_def
        self.output_resolution_values = output_resolution_values
        self.source_audio_track_values = source_audio_track_values
        self.ratio_values = ratio_values
        self.default_model_type = default_model_type or catalog.DEFAULT_MODEL_TYPE

    @staticmethod
    def _process_strength_slider_bounds(value: float) -> tuple[float, float]:
        """Return slider ``(minimum, maximum)``: [0.0, 3.0] widened to include ``value``."""
        # Explicit comparisons (rather than min/max) keep NaN mapping to (0.0, 3.0).
        lower = value if value < 0.0 else 0.0
        upper = value if value > 3.0 else 3.0
        return lower, upper

    @staticmethod
    def _fit_overlap_slider_value(value: int, maximum: int) -> int:
        """Clamp an overlap value to ``[1, maximum]`` (the lower bound is checked first)."""
        if value < 1:
            return 1
        if value > maximum:
            return maximum
        return value

    @staticmethod
    def _choose_available(preferred: str, fallback: str, available) -> str:
        """Pick ``preferred`` if offered, else ``fallback`` if offered, else the first offered value.

        ``available`` is any iterable supporting ``in`` (a dict is searched by key).
        When it is empty, ``fallback`` is returned.
        """
        if preferred in available:
            return preferred
        if fallback in available:
            return fallback
        return next(iter(available), fallback)

    @staticmethod
    def _normalized_user_refs(user_refs: list[str] | None):
        """Normalize user-settings refs through the catalog's saved-settings reader."""
        return catalog.get_saved_user_settings_refs({catalog.USER_SETTINGS_STORAGE_KEY: user_refs})

    def build_initial_form(self, saved_ui_settings: dict, main_state: dict | None, initial_user_refs: list[str]) -> InitialFormPatch:
        """Build everything needed to render the process form for the first time.

        Selection rules:
          * The saved model type is used while it is still available (the saved
            process's own model type wins when the library reports one);
            otherwise ``catalog.DEFAULT_MODEL_TYPE``, then the first available one.
          * The process name follows the same rules within that model type.
          * ``saved_ui_settings`` also seeds the initial field values.

        Side effect: ``self.default_model_type`` is set to the chosen model type.

        Args:
            saved_ui_settings: Previously saved form settings; non-dict values count as empty.
            main_state: Main application state forwarded to ``library``.
            initial_user_refs: User-settings refs forwarded to ``library``.

        Returns:
            An ``InitialFormPatch`` with selector choices, field values and control visibility.
        """
        if not isinstance(saved_ui_settings, dict):
            saved_ui_settings = {}
        saved_process_name = str(saved_ui_settings.get("process_name") or "").strip()
        saved_model_type = str(saved_ui_settings.get("process_model_type") or "").strip()
        saved_process_definition = self.library.process_definition(saved_process_name, main_state, initial_user_refs)
        if saved_process_definition is not None:
            # Prefer the model type recorded on the saved process, when the library reports one.
            saved_model_type = self.library.process_definition_model_type(saved_process_definition) or saved_model_type
        process_names_by_model_type = self.library.process_values_by_model_type(initial_user_refs)
        default_model_type = self._choose_available(saved_model_type, catalog.DEFAULT_MODEL_TYPE, process_names_by_model_type)
        default_process_choices = self.library.normal_process_choices(default_model_type, initial_user_refs)
        default_process_values = [value for _label, value in default_process_choices]
        default_process_name = self._choose_available(saved_process_name, catalog.DEFAULT_PROCESS_NAME, default_process_values)
        form_state = self.build_form_state(default_process_name, saved_ui_settings, main_state, initial_user_refs)
        default_rules = self.library.process_frame_rules(default_process_name, main_state, initial_user_refs)
        target_control_choices = self.library.target_control_choices(default_process_name, main_state, initial_user_refs)
        has_target_control = len(target_control_choices) > 0
        overlap_visible = not self.library.hides_sliding_window_overlap(default_process_name, main_state, initial_user_refs)
        # Size the slider from the process's own model type, i.e. the one build_form_state
        # clamped sliding_window_overlap against (the selector value may differ from it).
        overlap_max = frames.get_overlap_slider_max(form_state.process_model_type, self.get_model_def) if overlap_visible else 1
        self.default_model_type = default_model_type
        return InitialFormPatch(
            model_type_choices=self.library.model_type_choices(initial_user_refs),
            process_choices=default_process_choices,
            model_type=default_model_type,
            process_name=default_process_name,
            form_state=form_state,
            overlap_step=default_rules.frame_step,
            overlap_max=overlap_max,
            overlap_visible=overlap_visible,
            output_resolution_visible=not self.library.hides_output_resolution(default_process_name, main_state, initial_user_refs),
            prompt_visible=not self.library.hides_prompt(default_process_name, main_state, initial_user_refs),
            process_strength_visible=self.library.is_process_strength_visible(default_process_name, main_state, initial_user_refs),
            target_ratio_visible=has_target_control or self.library.has_process_outpaint(default_process_name, main_state, initial_user_refs),
            target_ratio_label=self.library.target_control_label(default_process_name, main_state, initial_user_refs) if has_target_control else "Target Ratio",
            target_ratio_choices=target_control_choices if has_target_control else ui_constants.RATIO_CHOICES,
        )

    def user_settings_hint_update(self, process_choices: list[tuple[str, str]]):
        """Show the user-settings hint only when ``process_choices`` contain user settings."""
        return gr.update(visible=self.library.process_choices_have_user_settings(process_choices))

    @staticmethod
    def settings_action_updates(process_model_type_value: str, process_value: str) -> tuple:
        """Return visibility updates for the user-settings actions.

        The four updates are, in order:
          1. the actions container, visible if either button is;
          2. the add button, visible in add-user-settings mode;
          3. the delete button, visible for user processes;
          4. a fourth component, which is always hidden.
        """
        add_visible = process_model_type_value == ui_constants.ADD_USER_SETTINGS_MODEL_TYPE
        delete_visible = catalog.is_user_process_value(process_value)
        actions_visible = add_visible or delete_visible
        return gr.update(visible=actions_visible), gr.update(visible=add_visible), gr.update(visible=delete_visible), gr.update(visible=False)

    def target_ratio_update(self, process_name: str, main_state: dict | None, user_refs: list[str] | None, target_ratio: str | None = None):
        """Return the dropdown update for the target-ratio / target-control field.

        With target-control choices: the control's label, its choices, and the
        control's default value if ``target_ratio`` is not one of them.
        Otherwise: the standard ratio list, shown only for outpainting processes
        (hidden processes get an empty value and the list that includes "").
        """
        target_control_choices = self.library.target_control_choices(process_name, main_state, user_refs)
        if len(target_control_choices) > 0:
            values = {value for _label, value in target_control_choices}
            value = str(target_ratio or "").strip()
            if value not in values:
                value = self.library.target_control_default(process_name, main_state, user_refs)
            return gr.update(label=self.library.target_control_label(process_name, main_state, user_refs), value=value, visible=True, choices=target_control_choices)
        visible = self.library.has_process_outpaint(process_name, main_state, user_refs)
        return gr.update(label="Target Ratio", value=target_ratio if visible else "", visible=visible, choices=ui_constants.RATIO_CHOICES if visible else ui_constants.RATIO_CHOICES_WITH_EMPTY)

    def process_strength_update(self, process_name: str, main_state: dict | None, user_refs: list[str] | None, process_strength: float | None = None):
        """Return the strength-slider update: value, visibility and bounds wide enough for the value.

        User-defined processes may override the default through the library's
        user LoRA strength override. When the slider is hidden, its value is
        reset to that default.
        """
        process_definition = self.library.process_definition(process_name, main_state, user_refs)
        visible = self.library.is_process_strength_visible(process_name, main_state, user_refs)
        default_value = common.get_default_process_strength((process_definition or {}).get("settings", {}))
        if isinstance(process_definition, dict) and process_definition.get("source") == "user":
            user_default = self.library.user_lora_strength_override_default(process_definition)
            if user_default is not None:
                default_value = user_default
        value = common.coerce_float(process_strength, default_value) if visible else default_value
        minimum, maximum = self._process_strength_slider_bounds(value)
        return gr.update(value=value, visible=visible, minimum=minimum, maximum=maximum)

    def prompt_update(self, process_name: str, main_state: dict | None, user_refs: list[str] | None, prompt: str):
        """Set the prompt value and hide the field for processes that hide it."""
        return gr.update(value=prompt, visible=not self.library.hides_prompt(process_name, main_state, user_refs))

    def output_resolution_update(self, process_name: str, main_state: dict | None, user_refs: list[str] | None, output_resolution: str):
        """Set the output resolution and hide the field for processes that hide it."""
        return gr.update(value=output_resolution, visible=not self.library.hides_output_resolution(process_name, main_state, user_refs))

    def overlap_control_updates(self, process_name: str, main_state: dict | None, user_refs: list[str] | None, sliding_window_overlap: int | None = None):
        """Return the update for the sliding-window-overlap slider.

        Hidden processes get a hidden slider reset to 0. Otherwise the step and
        maximum come from the process's model type.

        Args:
            sliding_window_overlap: Value to show (e.g. a remembered form value).
                When None, the process's default overlap is used. Either way, the
                value is snapped to the step and clamped to ``[1, maximum]``.
        """
        if self.library.hides_sliding_window_overlap(process_name, main_state, user_refs):
            return gr.update(minimum=0, maximum=1, step=1, value=0, visible=False)
        process_definition = self.library.process_definition_or_default(process_name, main_state, user_refs)
        settings = process_definition.get("settings", {})
        # Same fallback as build_form_state, so the slider range matches the clamped state value.
        model_type = str(settings.get("model_type") or catalog.DEFAULT_MODEL_TYPE)
        step = frames.get_vae_temporal_latent_size(model_type, self.get_model_def)
        maximum = frames.get_overlap_slider_max(model_type, self.get_model_def)
        default_value = common.coerce_int(settings.get("sliding_window_overlap"), 1, minimum=1)
        if sliding_window_overlap is None:
            value = default_value
        else:
            value = common.coerce_int(sliding_window_overlap, default_value, minimum=1)
        value = self._fit_overlap_slider_value(frames.normalize_overlap_frames(value, frame_step=step), maximum)
        return gr.update(minimum=1, maximum=maximum, step=step, value=value, visible=True)

    def build_form_state(self, process_name: str, raw_state: dict | None = None, main_state: dict | None = None, user_refs: list[str] | None = None) -> ProcessFormState:
        """Validate raw form values for ``process_name``, filling gaps with process defaults.

        Args:
            process_name: Process whose definition supplies defaults and rules.
            raw_state: Unvalidated field values; non-dict values count as empty.
            main_state: Main application state forwarded to ``library``.
            user_refs: User-settings refs forwarded to ``library``.

        Returns:
            A ``ProcessFormState`` with these fallbacks applied:
              * unknown resolutions become "720p" and unknown audio tracks "";
              * unknown ratios become "4:3", or the target-control default when
                the process has a target control;
              * the overlap is snapped to the frame step and clamped to the slider
                range, or set to 0 when the process hides the overlap control.
        """
        process_definition = self.library.process_definition_or_default(process_name, main_state, user_refs)
        process_settings = process_definition.get("settings", {})
        model_type = str(process_settings.get("model_type") or catalog.DEFAULT_MODEL_TYPE)
        frame_rules = self.library.process_frame_rules(process_name, main_state, user_refs)
        step = int(frame_rules.frame_step)
        overlap_hidden = self.library.hides_sliding_window_overlap(process_name, main_state, user_refs)
        maximum = 1 if overlap_hidden else frames.get_overlap_slider_max(model_type, self.get_model_def)
        raw_state = raw_state if isinstance(raw_state, dict) else {}

        default_strength = common.get_default_process_strength(process_settings)
        # "control_video_strength" is read as a fallback key (presumably a legacy name).
        saved_process_strength = raw_state.get("process_strength", raw_state.get("control_video_strength"))
        process_strength = default_strength if saved_process_strength is None else common.coerce_float(saved_process_strength, default_strength)
        source_audio_track = str(raw_state.get("source_audio_track") or "").strip()
        output_resolution = str(raw_state.get("output_resolution") or "").strip()

        target_control_choices = self.library.target_control_choices(process_name, main_state, user_refs)
        if len(target_control_choices) > 0:
            target_values = {value for _label, value in target_control_choices}
            target_ratio = str(raw_state.get("target_ratio") or process_settings.get("target_ratio") or self.library.target_control_default(process_name, main_state, user_refs)).strip()
            if target_ratio not in target_values:
                target_ratio = self.library.target_control_default(process_name, main_state, user_refs)
        else:
            target_ratio = str(raw_state.get("target_ratio") or "4:3").strip()
            if target_ratio not in self.ratio_values:
                target_ratio = "4:3"

        if overlap_hidden:
            sliding_window_overlap = 0
        else:
            settings_overlap = common.coerce_int(process_settings.get("sliding_window_overlap"), 1, minimum=1)
            overlap_default = self._fit_overlap_slider_value(frames.normalize_overlap_frames(settings_overlap, frame_step=step), maximum)
            overlap_value = common.coerce_int(raw_state.get("sliding_window_overlap"), overlap_default, minimum=1)
            sliding_window_overlap = self._fit_overlap_slider_value(frames.normalize_overlap_frames(overlap_value, frame_step=step), maximum)
        default_chunk_size_seconds = self.library.default_chunk_size_seconds(process_name, main_state, user_refs)
        start_seconds = raw_state.get("start_seconds")
        end_seconds = raw_state.get("end_seconds")

        return ProcessFormState(
            process_model_type=model_type,
            process_name=process_name,
            source_path=str(raw_state.get("source_path") or ui_constants.DEFAULT_SOURCE_PATH),
            process_strength=process_strength,
            output_path=str(raw_state.get("output_path") or ui_constants.DEFAULT_OUTPUT_PATH),
            # An explicit "prompt" key (even an empty one) overrides the process's default prompt.
            prompt=str(raw_state.get("prompt") or "") if "prompt" in raw_state else str(process_settings.get("prompt") or ""),
            continue_enabled=common.coerce_bool(raw_state.get("continue_enabled"), True),
            source_audio_track=source_audio_track if source_audio_track in self.source_audio_track_values else "",
            output_resolution=output_resolution if output_resolution in self.output_resolution_values else "720p",
            target_ratio=target_ratio,
            chunk_size_seconds=common.coerce_float(raw_state.get("chunk_size_seconds"), default_chunk_size_seconds, minimum=0.1),
            sliding_window_overlap=sliding_window_overlap,
            start_seconds="" if start_seconds in (None, "") else str(start_seconds),
            end_seconds="" if end_seconds in (None, "") else str(end_seconds),
        )

    def build_state(self, process_name: str, raw_state: dict | None = None, main_state: dict | None = None, user_refs: list[str] | None = None) -> dict:
        """Same as ``build_form_state`` but returned as a plain dict."""
        return self.build_form_state(process_name, raw_state, main_state, user_refs).to_dict()

    def snapshot_state(self, process_name: str, main_state: dict | None, user_refs: list[str] | None, values: FormComponentValues) -> dict:
        """Validate the current component values for ``process_name`` into a state dict."""
        return self.build_state(process_name, values.to_raw_state(), main_state, user_refs)

    def store_memory(self, memory_state: dict | None, current_process_name: str, main_state: dict | None, user_refs: list[str] | None, values: FormComponentValues):
        """Return a copy of ``memory_state`` with the current form saved under its process name.

        Nothing is saved when the process name is unknown to the library. The
        input dict is never mutated; non-dict inputs start an empty memory.
        """
        updated_memory = dict(memory_state) if isinstance(memory_state, dict) else {}
        current_process_name = str(current_process_name or "").strip()
        if self.library.process_definition(current_process_name, main_state, user_refs) is not None:
            updated_memory[current_process_name] = self.snapshot_state(current_process_name, main_state, user_refs, values)
        return updated_memory

    def restore_state(self, memory_state: dict | None, process_name: str, current_source_path: str, main_state: dict | None, user_refs: list[str] | None) -> tuple:
        """Return the form outputs for ``process_name``, restored from memory when available.

        A non-empty ``current_source_path`` is kept instead of the remembered one.

        Returns:
            ``RESTORED_FORM_OUTPUT_COUNT`` (12) values, in order: source path,
            strength update, output path, prompt update, continue flag, audio
            track, resolution update, target-ratio update, chunk size, overlap
            update, start seconds, end seconds.
        """
        memory = memory_state if isinstance(memory_state, dict) else {}
        state = self.build_form_state(process_name, memory.get(process_name), main_state, user_refs)
        source_path_value = str(current_source_path or "").strip() or state.source_path
        return (
            source_path_value,
            self.process_strength_update(process_name, main_state, user_refs, state.process_strength),
            state.output_path,
            self.prompt_update(process_name, main_state, user_refs, state.prompt),
            state.continue_enabled,
            state.source_audio_track,
            self.output_resolution_update(process_name, main_state, user_refs, state.output_resolution),
            self.target_ratio_update(process_name, main_state, user_refs, state.target_ratio),
            state.chunk_size_seconds,
            # Pass the remembered overlap so it is restored like every other field.
            self.overlap_control_updates(process_name, main_state, user_refs, state.sliding_window_overlap),
            state.start_seconds,
            state.end_seconds,
        )

    def change_process_model_type(self, memory_state: dict | None, current_process_name: str, next_model_type: str, main_state: dict | None, main_lset_name: str | None, user_refs: list[str] | None, values: FormComponentValues) -> tuple:
        """Handle a model-type selection change.

        Steps:
          1. Remember the current form.
          2. Pick the next process for the new model type and persist the selection.
          3. Restore that process's form.

        Returns:
            20 values: the updated memory, the next process name, the process
            dropdown update, the hint update, 4 action updates and the 12
            restored form outputs.
        """
        refs = self._normalized_user_refs(user_refs)
        updated_memory = self.store_memory(memory_state, current_process_name, main_state, refs, values)
        process_choices, next_process_name = self.library.process_choices(next_model_type, main_state, main_lset_name, refs)
        next_process_name = str(next_process_name or ui_constants.NO_USER_SETTINGS_VALUE).strip()
        catalog.save_process_full_video_selection(next_model_type, next_process_name)
        return (
            updated_memory,
            next_process_name,
            gr.update(choices=process_choices, value=next_process_name),
            self.user_settings_hint_update(process_choices),
            *self.settings_action_updates(next_model_type, next_process_name),
            *self.restore_state(updated_memory, next_process_name, values.source_path, main_state, refs),
        )

    def change_process_name(self, memory_state: dict | None, current_process_name: str, next_process_name: str, process_model_type_value: str, main_state: dict | None, user_refs: list[str] | None, values: FormComponentValues) -> tuple:
        """Handle a process selection change.

        Steps:
          1. Remember the current form.
          2. Persist the new selection.
          3. Restore the new process's form.

        Returns:
            17 values: the updated memory, the next process name, the actions,
            delete and placeholder updates (the add-button update is not
            returned), and the 12 restored form outputs.
        """
        refs = self._normalized_user_refs(user_refs)
        updated_memory = self.store_memory(memory_state, current_process_name, main_state, refs, values)
        next_process_name = str(next_process_name or "").strip()
        catalog.save_process_full_video_selection(process_model_type_value, next_process_name)
        actions_update, _add_update, delete_update, placeholder_update = self.settings_action_updates(process_model_type_value, next_process_name)
        return (
            updated_memory,
            next_process_name,
            actions_update,
            delete_update,
            placeholder_update,
            *self.restore_state(updated_memory, next_process_name, values.source_path, main_state, refs),
        )

    def refresh_from_main(self, _refresh_id, memory_state: dict | None, current_process_name: str, process_model_type_value: str, main_state: dict | None, main_lset_name: str | None, user_refs: list[str] | None, values: FormComponentValues) -> tuple:
        """Refresh the selectors after a change in the main application.

        Outside add-user-settings mode, only the model-type selector is refreshed;
        an unknown model type falls back to ``self.default_model_type``, and the
        form outputs are skipped. In add-user-settings mode, the process choices
        are rebuilt from the main state's current user settings and the form is
        restored.

        Returns:
            21 values in both modes: model dropdown, process dropdown, hint,
            memory, process name, 4 action updates, and 12 form outputs.
        """
        refs = self._normalized_user_refs(user_refs)
        model_choices = self.library.model_type_choices(refs)
        process_model_type_value = str(process_model_type_value or "").strip()
        if process_model_type_value != ui_constants.ADD_USER_SETTINGS_MODEL_TYPE:
            valid_model_values = {value for _label, value in model_choices}
            model_value = process_model_type_value if process_model_type_value in valid_model_values else self.default_model_type
            return (
                gr.update(choices=model_choices, value=model_value),
                gr.update(),
                gr.update(),
                memory_state,
                current_process_name,
                *self.settings_action_updates(model_value, current_process_name),
                *skipped_restored_form_outputs(),
            )
        updated_memory = self.store_memory(memory_state, current_process_name, main_state, refs, values)
        process_choices, next_process_name = self.library.current_user_settings_choices(main_state, main_lset_name)
        return (
            gr.update(choices=model_choices, value=ui_constants.ADD_USER_SETTINGS_MODEL_TYPE),
            gr.update(choices=process_choices, value=next_process_name),
            self.user_settings_hint_update(process_choices),
            updated_memory,
            next_process_name,
            *self.settings_action_updates(ui_constants.ADD_USER_SETTINGS_MODEL_TYPE, next_process_name),
            *self.restore_state(updated_memory, next_process_name, values.source_path, main_state, refs),
        )
|
|