File size: 23,344 Bytes
618f472
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
import gradio as gr
from shared.utils.plugins import WAN2GPPlugin
import json

class ConfigTabPlugin(WAN2GPPlugin):
    """Plugin that adds a "Configuration" tab for editing WAN2GP server settings.

    The tab groups options into General / Performance / Extensions / Outputs /
    Notifications sub-tabs. Pressing "Save Settings" writes the chosen values to
    ``server_config_filename`` as JSON, pushes them back into the host
    application's globals and refreshes the header and model dropdowns.
    """

    # Host globals this plugin reads (or later replaces via ``set_global``),
    # requested in this exact order.
    _REQUIRED_GLOBALS = (
        "args",
        "server_config",
        "server_config_filename",
        "attention_mode",
        "compile",
        "default_profile",
        "vae_config",
        "boost",
        "preload_model_policy",
        "transformer_quantization",
        "transformer_dtype_policy",
        "transformer_types",
        "text_encoder_quantization",
        "attention_modes_installed",
        "attention_modes_supported",
        "displayed_model_types",
        "memory_profile_choices",
        "save_path",
        "image_save_path",
        "quit_application",
        "release_model",
        "get_sorted_dropdown",
        "app",
        "fl",
        "is_generation_in_progress",
        "generate_header",
        "generate_dropdown_model_list",
        "get_unique_id",
        "enhancer_offloadobj",
    )

    # Host UI components this tab reads as inputs or refreshes as outputs.
    _REQUIRED_COMPONENTS = (
        "header",
        "model_family",
        "model_base_type_choice",
        "model_choice",
        "refresh_form_trigger",
        "state",
        "resolution",
    )

    # Config keys whose modification alone does not set the host's ``reload_needed`` flag.
    # NOTE(review): the ``last_*`` UI-state keys written by ``_save_changes`` are not
    # listed here, so they count as reload-triggering whenever they differ from the
    # saved file — confirm that is intended.
    _NO_RELOAD_KEYS = frozenset((
        "attention_mode", "vae_config", "boost", "save_path", "image_save_path",
        "metadata_type", "clear_file_list", "fit_canvas", "depth_anything_v2_variant",
        "notification_sound_enabled", "notification_sound_volume", "mmaudio_enabled",
        "max_frames_multiplier", "display_stats", "video_output_codec", "video_container",
        "embed_source_images", "image_output_codec", "audio_output_codec", "checkpoints_paths",
        "model_hierarchy_type", "UI_theme", "queue_color_scheme",
    ))

    # Outputs of the Save button after the status message: header, model_family,
    # model_base_type_choice, model_choice, refresh_form_trigger.
    _EXTRA_OUTPUT_COUNT = 5

    def __init__(self):
        super().__init__()
        self.name = "Configuration Tab"
        self.version = "1.1.0"
        self.description = "Lets you adjust all your performance and UI options for WAN2GP"

    def setup_ui(self):
        """Request the host globals/components this tab needs and register the tab."""
        for global_name in self._REQUIRED_GLOBALS:
            self.request_global(global_name)
        for component_name in self._REQUIRED_COMPONENTS:
            self.request_component(component_name)

        self.add_tab(
            tab_id="configuration",
            label="Configuration",
            component_constructor=self.create_config_ui,
            position=4,
        )

    def create_config_ui(self):
        """Build the Configuration tab's widgets and wire up its buttons.

        Initial widget values come from the requested globals and ``server_config``.
        Several controls are made non-interactive when ``args.lock_config`` is set.

        Returns:
            A list containing the "Force Unload Models from RAM" button.
        """
        with gr.Column():
            with gr.Tabs():
                with gr.Tab("General"):
                    _, _, dropdown_choices = self.get_sorted_dropdown(self.displayed_model_types, None, None, False)

                    self.transformer_types_choices = gr.Dropdown(
                        choices=dropdown_choices, value=self.transformer_types,
                        label="Selectable Generative Models (leave empty for all)", multiselect=True
                    )
                    self.model_hierarchy_type_choice = gr.Dropdown(
                        choices=[
                            ("Two Levels: Model Family > Models & Finetunes", 0),
                            ("Three Levels: Model Family > Models > Finetunes", 1),
                        ],
                        value=self.server_config.get("model_hierarchy_type", 1),
                        label="Models Hierarchy In User Interface",
                        interactive=not self.args.lock_config
                    )
                    self.fit_canvas_choice = gr.Dropdown(
                        choices=[
                            ("Dimensions are Pixel Budget (preserves aspect ratio, may exceed dimensions)", 0),
                            ("Dimensions are Max Width/Height (preserves aspect ratio, fits within box)", 1),
                            ("Dimensions are Exact Output (crops input to fit exact dimensions)", 2),
                        ],
                        value=self.server_config.get("fit_canvas", 0),
                        label="Input Image/Video Sizing Behavior",
                        interactive=not self.args.lock_config
                    )

                    def check_attn(mode):
                        # Suffix appended to an attention label when that backend is unusable here.
                        if mode not in self.attention_modes_installed: return " (NOT INSTALLED)"
                        if mode not in self.attention_modes_supported: return " (NOT SUPPORTED)"
                        return ""

                    self.attention_choice = gr.Dropdown(
                        choices=[
                            ("Auto: Best available (sage2 > sage > sdpa)", "auto"),
                            ("sdpa: Default, always available", "sdpa"),
                            (f'flash{check_attn("flash")}: High quality, requires manual install', "flash"),
                            (f'xformers{check_attn("xformers")}: Good quality, less VRAM, requires manual install', "xformers"),
                            (f'sage{check_attn("sage")}: ~30% faster, requires manual install', "sage"),
                            (f'sage2/sage2++{check_attn("sage2")}: ~40% faster, requires manual install', "sage2"),
                        # "radial" is only offered when the app runs with --betatest.
                        ] + ([(f'radial{check_attn("radial")}: Experimental, may be faster, requires manual install', "radial")] if self.args.betatest else []) + [
                            (f'sage3{check_attn("sage3")}: >50% faster, may have quality trade-offs, requires manual install', "sage3"),
                        ],
                        value=self.attention_mode, label="Attention Type", interactive=not self.args.lock_config
                    )
                    self.preload_model_policy_choice = gr.CheckboxGroup(
                        [("Preload Model on App Launch","P"), ("Preload Model on Switch", "S"), ("Unload Model when Queue is Done", "U")],
                        value=self.preload_model_policy, label="Model Loading/Unloading Policy"
                    )
                    self.clear_file_list_choice = gr.Dropdown(
                        choices=[("None", 0), ("Keep last video", 1), ("Keep last 5 videos", 5), ("Keep last 10", 10), ("Keep last 20", 20), ("Keep last 30", 30)],
                        value=self.server_config.get("clear_file_list", 5), label="Keep Previous Generations in Gallery"
                    )
                    self.display_stats_choice = gr.Dropdown(
                        choices=[("Disabled", 0), ("Enabled", 1)],
                        value=self.server_config.get("display_stats", 0), label="Display real-time RAM/VRAM stats (requires restart)"
                    )
                    self.max_frames_multiplier_choice = gr.Dropdown(
                        choices=[("Default", 1), ("x2", 2), ("x3", 3), ("x4", 4), ("x5", 5), ("x6", 6), ("x7", 7)],
                        value=self.server_config.get("max_frames_multiplier", 1), label="Max Frames Multiplier (requires restart)"
                    )
                    default_paths = self.fl.default_checkpoints_paths
                    checkpoints_paths_text = "\n".join(self.server_config.get("checkpoints_paths", default_paths))
                    self.checkpoints_paths_choice = gr.Textbox(
                        label="Model Checkpoint Folders (One Path per Line. First is Default Download Path)",
                        value=checkpoints_paths_text,
                        lines=3,
                        interactive=not self.args.lock_config
                    )
                    self.UI_theme_choice = gr.Dropdown(
                        choices=[("Blue Sky (Default)", "default"), ("Classic Gradio", "gradio")],
                        value=self.server_config.get("UI_theme", "default"), label="UI Theme (requires restart)"
                    )
                    self.queue_color_scheme_choice = gr.Dropdown(
                        choices=[
                            ("Pastel (Unique color for each item)", "pastel"),
                            ("Alternating Grey Shades", "alternating_grey"),
                        ],
                        value=self.server_config.get("queue_color_scheme", "pastel"),
                        label="Queue Color Scheme"
                    )

                with gr.Tab("Performance"):
                    self.quantization_choice = gr.Dropdown(choices=[("Scaled Int8 (recommended)", "int8"), ("16-bit (no quantization)", "bf16")], value=self.transformer_quantization, label="Transformer Model Quantization (if available)")
                    self.transformer_dtype_policy_choice = gr.Dropdown(choices=[("Auto (Best for Hardware)", ""), ("FP16", "fp16"), ("BF16", "bf16")], value=self.transformer_dtype_policy, label="Transformer Data Type (if available)")
                    self.mixed_precision_choice = gr.Dropdown(choices=[("16-bit only (less VRAM)", "0"), ("Mixed 16/32-bit (better quality)", "1")], value=self.server_config.get("mixed_precision", "0"), label="Transformer Engine Precision")
                    self.text_encoder_quantization_choice = gr.Dropdown(choices=[("16-bit (more RAM, better quality)", "bf16"), ("8-bit (less RAM, slightly lower quality)", "int8")], value=self.text_encoder_quantization, label="Text Encoder Precision")
                    self.VAE_precision_choice = gr.Dropdown(choices=[("16-bit (faster, less VRAM)", "16"), ("32-bit (slower, better for sliding window)", "32")], value=self.server_config.get("vae_precision", "16"), label="VAE Encoding/Decoding Precision")
                    self.compile_choice = gr.Dropdown(choices=[("On (up to 20% faster, requires Triton)", "transformer"), ("Off", "")], value=self.compile, label="Compile Transformer Model", interactive=not self.args.lock_config)
                    self.depth_anything_v2_variant_choice = gr.Dropdown(choices=[("Large (more precise, slower)", "vitl"), ("Big (less precise, faster)", "vitb")], value=self.server_config.get("depth_anything_v2_variant", "vitl"), label="Depth Anything v2 VACE Preprocessor")
                    self.vae_config_choice = gr.Dropdown(choices=[("Auto", 0), ("Disabled (fastest, high VRAM)", 1), ("256x256 Tiles (for >=8GB VRAM)", 2), ("128x128 Tiles (for >=6GB VRAM)", 3)], value=self.vae_config, label="VAE Tiling (to reduce VRAM usage)")
                    self.boost_choice = gr.Dropdown(choices=[("ON", 1), ("OFF", 2)], value=self.boost, label="Boost (~10% speedup for ~1GB VRAM)")
                    self.profile_choice = gr.Dropdown(choices=self.memory_profile_choices, value=self.default_profile, label="Memory Profile (Advanced)")
                    self.preload_in_VRAM_choice = gr.Slider(0, 40000, value=self.server_config.get("preload_in_VRAM", 0), step=100, label="VRAM (MB) for Preloaded Models (0=profile default)")
                    self.release_RAM_btn = gr.Button("Force Unload Models from RAM")

                with gr.Tab("Extensions"):
                    self.enhancer_enabled_choice = gr.Dropdown(choices=[("Off", 0), ("Florence 2 + LLama 3.2", 1), ("Florence 2 + Llama Joy (uncensored)", 2)], value=self.server_config.get("enhancer_enabled", 0), label="Prompt Enhancer (requires 8-14GB extra download)")
                    self.enhancer_mode_choice = gr.Dropdown(choices=[("Automatic on Generation", 0), ("On-Demand Button Only", 1)], value=self.server_config.get("enhancer_mode", 0), label="Prompt Enhancer Usage")
                    self.mmaudio_enabled_choice = gr.Dropdown(choices=[("Off", 0), ("Enabled (unloads after use)", 1), ("Enabled (persistent in RAM)", 2)], value=self.server_config.get("mmaudio_enabled", 0), label="MMAudio Soundtrack Generation (requires 10GB extra download)")

                with gr.Tab("Outputs"):
                    self.video_output_codec_choice = gr.Dropdown(choices=[("x265 CRF 28 (Balanced)", 'libx265_28'), ("x264 Level 8 (Balanced)", 'libx264_8'), ("x265 CRF 8 (High Quality)", 'libx265_8'), ("x264 Level 10 (High Quality)", 'libx264_10'), ("x264 Lossless", 'libx264_lossless')], value=self.server_config.get("video_output_codec", "libx264_8"), label="Video Codec")
                    self.image_output_codec_choice = gr.Dropdown(choices=[("JPEG Q85", 'jpeg_85'), ("WEBP Q85", 'webp_85'), ("JPEG Q95", 'jpeg_95'), ("WEBP Q95", 'webp_95'), ("WEBP Lossless", 'webp_lossless'), ("PNG Lossless", 'png')], value=self.server_config.get("image_output_codec", "jpeg_95"), label="Image Codec")
                    self.audio_output_codec_choice = gr.Dropdown(choices=[("AAC 128 kbit", 'aac_128')], value=self.server_config.get("audio_output_codec", "aac_128"), visible=False, label="Audio Codec to use")
                    self.metadata_choice = gr.Dropdown(
                        choices=[("Export JSON files", "json"), ("Embed metadata in file (Exif/tag)", "metadata"), ("None", "none")],
                        value=self.server_config.get("metadata_type", "metadata"), label="Metadata Handling"
                    )
                    self.embed_source_images_choice = gr.Checkbox(
                        value=self.server_config.get("embed_source_images", False),
                        label="Embed Source Images",
                        info="Saves i2v source images inside MP4 files"
                    )
                    self.video_save_path_choice = gr.Textbox(label="Video Output Folder (requires restart)", value=self.save_path)
                    self.image_save_path_choice = gr.Textbox(label="Image Output Folder (requires restart)", value=self.image_save_path)

                with gr.Tab("Notifications"):
                    self.notification_sound_enabled_choice = gr.Dropdown(choices=[("On", 1), ("Off", 0)], value=self.server_config.get("notification_sound_enabled", 0), label="Notification Sound")
                    self.notification_sound_volume_choice = gr.Slider(0, 100, value=self.server_config.get("notification_sound_volume", 50), step=5, label="Notification Volume")

            self.msg = gr.Markdown()
            with gr.Row():
                self.apply_btn = gr.Button("Save Settings")

        # IMPORTANT: this order must match the tuple unpacking in `_save_changes`
        # (the state comes first and is passed as its own parameter).
        inputs = [
            self.state,
            self.transformer_types_choices, self.model_hierarchy_type_choice, self.fit_canvas_choice,
            self.attention_choice, self.preload_model_policy_choice, self.clear_file_list_choice,
            self.display_stats_choice, self.max_frames_multiplier_choice, self.checkpoints_paths_choice,
            self.UI_theme_choice, self.queue_color_scheme_choice,
            self.quantization_choice, self.transformer_dtype_policy_choice, self.mixed_precision_choice,
            self.text_encoder_quantization_choice, self.VAE_precision_choice, self.compile_choice,
            self.depth_anything_v2_variant_choice, self.vae_config_choice, self.boost_choice,
            self.profile_choice, self.preload_in_VRAM_choice,
            self.enhancer_enabled_choice, self.enhancer_mode_choice, self.mmaudio_enabled_choice,
            self.video_output_codec_choice, self.image_output_codec_choice, self.audio_output_codec_choice,
            self.metadata_choice, self.embed_source_images_choice,
            self.video_save_path_choice, self.image_save_path_choice,
            self.notification_sound_enabled_choice, self.notification_sound_volume_choice,
            self.resolution
        ]

        # The status message plus `_EXTRA_OUTPUT_COUNT` refreshable host components.
        self.apply_btn.click(
            fn=self._save_changes,
            inputs=inputs,
            outputs=[
                self.msg,
                self.header,
                self.model_family,
                self.model_base_type_choice,
                self.model_choice,
                self.refresh_form_trigger
            ]
        )

        def release_ram_and_notify():
            self.release_model()
            gr.Info("Models unloaded from RAM.")

        self.release_RAM_btn.click(fn=release_ram_and_notify)
        return [self.release_RAM_btn]

    def _status_response(self, message_html):
        """Return *message_html* for the status box and no-op updates for every other output."""
        return (message_html, *[gr.update()] * self._EXTRA_OUTPUT_COUNT)

    def _parse_checkpoints_paths(self, text):
        """Convert the checkpoint-folders textbox into a list of paths.

        Surrounding whitespace and blank lines are dropped. An empty, all-whitespace
        or missing value falls back to ``self.fl.default_checkpoints_paths``.
        """
        if not (text or "").strip():
            return self.fl.default_checkpoints_paths
        return [path.strip() for path in text.replace("\r", "").split("\n") if path.strip()]

    def _reset_prompt_enhancer(self):
        """Drop the cached prompt-enhancer models and release their offload object, if any."""
        self.set_global("prompt_enhancer_image_caption_model", None)
        self.set_global("prompt_enhancer_image_caption_processor", None)
        self.set_global("prompt_enhancer_llm_model", None)
        self.set_global("prompt_enhancer_llm_tokenizer", None)
        if self.enhancer_offloadobj:
            self.enhancer_offloadobj.release()
            self.set_global("enhancer_offloadobj", None)

    def _publish_globals(self, new_server_config, needs_reload):
        """Push the saved configuration values into the host application's globals."""
        self.set_global("server_config", new_server_config)
        self.set_global("three_levels_hierarchy", new_server_config["model_hierarchy_type"] == 1)
        self.set_global("attention_mode", new_server_config["attention_mode"])
        self.set_global("default_profile", new_server_config["profile"])
        self.set_global("compile", new_server_config["compile"])
        self.set_global("text_encoder_quantization", new_server_config["text_encoder_quantization"])
        self.set_global("vae_config", new_server_config["vae_config"])
        self.set_global("boost", new_server_config["boost"])
        self.set_global("save_path", new_server_config["save_path"])
        self.set_global("image_save_path", new_server_config["image_save_path"])
        self.set_global("preload_model_policy", new_server_config["preload_model_policy"])
        self.set_global("transformer_quantization", new_server_config["transformer_quantization"])
        self.set_global("transformer_dtype_policy", new_server_config["transformer_dtype_policy"])
        self.set_global("transformer_types", new_server_config["transformer_types"])
        self.set_global("reload_needed", needs_reload)

    def _save_changes(self, state, *args):
        """Persist the Configuration tab's settings and apply them to the running app.

        Args:
            state: the app's state dict. Reads ``model_type``,
                ``last_model_per_family``, ``last_model_per_type``, ``advanced`` and
                ``last_resolution_per_group`` from it.
            *args: the tab's control values, in the order of ``inputs`` in
                ``create_config_ui`` (the last one is the current resolution).

        Returns:
            A 6-tuple for (msg, header, model_family, model_base_type_choice,
            model_choice, refresh_form_trigger). When saving is refused, only the
            message is meaningful and the remaining slots are no-op ``gr.update()``.
        """
        if self.is_generation_in_progress():
            return self._status_response("<div style='color:red; text-align:center;'>Unable to change config when a generation is in progress.</div>")

        # A locked config is rejected wholesale here, so nothing below needs to
        # special-case `lock_config` (e.g. preserving attention_mode/compile).
        if self.args.lock_config:
            return self._status_response("<div style='color:red; text-align:center;'>Configuration is locked by command-line arguments.</div>")

        old_server_config = self.server_config.copy()

        (
            transformer_types_choices, model_hierarchy_type_choice, fit_canvas_choice,
            attention_choice, preload_model_policy_choice, clear_file_list_choice,
            display_stats_choice, max_frames_multiplier_choice, checkpoints_paths_choice,
            UI_theme_choice, queue_color_scheme_choice,
            quantization_choice, transformer_dtype_policy_choice, mixed_precision_choice,
            text_encoder_quantization_choice, VAE_precision_choice, compile_choice,
            depth_anything_v2_variant_choice, vae_config_choice, boost_choice,
            profile_choice, preload_in_VRAM_choice,
            enhancer_enabled_choice, enhancer_mode_choice, mmaudio_enabled_choice,
            video_output_codec_choice, image_output_codec_choice, audio_output_codec_choice,
            metadata_choice, embed_source_images_choice,
            save_path_choice, image_save_path_choice,
            notification_sound_enabled_choice, notification_sound_volume_choice,
            last_resolution_choice
        ) = args

        checkpoints_paths = self._parse_checkpoints_paths(checkpoints_paths_choice)
        self.fl.set_checkpoints_paths(checkpoints_paths)

        new_server_config = {
            "attention_mode": attention_choice, "transformer_types": transformer_types_choices,
            "text_encoder_quantization": text_encoder_quantization_choice, "save_path": save_path_choice,
            "image_save_path": image_save_path_choice, "compile": compile_choice, "profile": profile_choice,
            "vae_config": vae_config_choice, "vae_precision": VAE_precision_choice,
            "mixed_precision": mixed_precision_choice, "metadata_type": metadata_choice,
            "transformer_quantization": quantization_choice, "transformer_dtype_policy": transformer_dtype_policy_choice,
            "boost": boost_choice, "clear_file_list": clear_file_list_choice,
            "preload_model_policy": preload_model_policy_choice, "UI_theme": UI_theme_choice,
            "fit_canvas": fit_canvas_choice, "enhancer_enabled": enhancer_enabled_choice,
            "enhancer_mode": enhancer_mode_choice, "mmaudio_enabled": mmaudio_enabled_choice,
            "preload_in_VRAM": preload_in_VRAM_choice, "depth_anything_v2_variant": depth_anything_v2_variant_choice,
            "notification_sound_enabled": notification_sound_enabled_choice,
            "notification_sound_volume": notification_sound_volume_choice,
            "max_frames_multiplier": max_frames_multiplier_choice, "display_stats": display_stats_choice,
            "video_output_codec": video_output_codec_choice, "image_output_codec": image_output_codec_choice,
            "audio_output_codec": audio_output_codec_choice,
            "model_hierarchy_type": model_hierarchy_type_choice,
            "checkpoints_paths": checkpoints_paths,
            "queue_color_scheme": queue_color_scheme_choice,
            "embed_source_images": embed_source_images_choice,
            "video_container": "mp4", # Fixed to MP4
            "last_model_type": state["model_type"],
            "last_model_per_family": state["last_model_per_family"],
            "last_model_per_type": state["last_model_per_type"],
            "last_advanced_choice": state["advanced"], "last_resolution_choice": last_resolution_choice,
            "last_resolution_per_group": state["last_resolution_per_group"],
        }

        # The plugin list is not editable from this tab; carry it over unchanged.
        if "enabled_plugins" in self.server_config:
            new_server_config["enabled_plugins"] = self.server_config["enabled_plugins"]

        # Serialize before opening the file: mode "w" truncates it immediately, so a
        # serialization error after opening would leave an empty config on disk.
        serialized_config = json.dumps(new_server_config, indent=4)
        with open(self.server_config_filename, "w", encoding="utf-8") as writer:
            writer.write(serialized_config)

        changes = [key for key, value in new_server_config.items() if value != old_server_config.get(key)]
        needs_reload = any(key not in self._NO_RELOAD_KEYS for key in changes)

        self._publish_globals(new_server_config, needs_reload)

        if "enhancer_enabled" in changes or "enhancer_mode" in changes:
            self._reset_prompt_enhancer()

        model_type = state["model_type"]

        model_family_update, model_base_type_update, model_choice_update = self.generate_dropdown_model_list(model_type)
        header_update = self.generate_header(model_type, compile=new_server_config["compile"], attention_mode=new_server_config["attention_mode"])

        return (
            "<div style='color:green; text-align:center;'>The new configuration has been successfully applied.</div>",
            header_update,
            model_family_update,
            model_base_type_update,
            model_choice_update,
            # A fresh unique id, presumably so the form-refresh trigger always sees a new value.
            self.get_unique_id()
        )