import os, sys
# Force English UI regardless of the host locale.
os.environ["GRADIO_LANG"] = "en"
# # os.environ.pop("TORCH_LOGS", None) # make sure no env var is suppressing/overriding
# os.environ["TORCH_LOGS"]= "recompiles"
import torch._logging as tlog
# tlog.set_logs(recompiles=True, guards=True, graph_breaks=True)
# Make the app's own directory importable regardless of the current working directory.
p = os.path.dirname(os.path.abspath(__file__))
if p not in sys.path:
    sys.path.insert(0, p)
import asyncio
# On Windows use the selector event loop (the default Proactor loop is silenced below too).
if os.name == "nt":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
# NOTE(review): presumably works around a numba threading-layer issue on Linux — confirm.
if sys.platform.startswith("linux") and "NUMBA_THREADING_LAYER" not in os.environ:
    os.environ["NUMBA_THREADING_LAYER"] = "workqueue"
from shared.asyncio_utils import silence_proactor_connection_reset
silence_proactor_connection_reset()
import time
import threading
import argparse
import warnings
warnings.filterwarnings('ignore', message='Failed to find.*', module='triton')
from mmgp import offload, safetensors2, profile_type , quant_router
try:
    import triton
except ImportError:
    pass
from pathlib import Path
from datetime import datetime
import gradio as gr
import random
import json
import numpy as np
import importlib
from shared.utils import notification_sound
from shared.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers
from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, calculate_dimensions_and_resize_image, rescale_and_crop, get_video_frame, resize_and_remove_background, rgb_bw_to_rgba_mask
from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions
from shared.utils.utils import has_video_file_extension, has_image_file_extension, has_audio_file_extension
from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, save_video, save_image
from shared.utils.audio_video import save_image_metadata, read_image_metadata
from shared.utils.audio_metadata import save_audio_metadata, read_audio_metadata
from shared.utils.video_metadata import save_video_metadata
from shared.match_archi import match_nvidia_architecture
from shared.attention import get_attention_modes, get_supported_attention_modes
from shared.utils.utils import truncate_for_filesystem, sanitize_file_name, process_images_multithread, get_default_workers
from shared.utils.process_locks import acquire_GPU_ressources, release_GPU_ressources, any_GPU_process_running, gen_lock
from shared.loras_migration import migrate_loras_layout
from huggingface_hub import hf_hub_download, snapshot_download
from shared.utils import files_locator as fl
from shared.gradio.audio_gallery import AudioGallery
import torch
import gc
import traceback
import math
import typing
import inspect
from shared.utils import prompt_parser
import base64
import io
from PIL import Image
import zipfile
import tempfile
import atexit
import shutil
import glob
import cv2
import html
from transformers.utils import logging
# Fix: the call parentheses were missing, so the attribute was merely referenced
# and transformers verbosity was never actually lowered.
logging.set_verbosity_error()
from tqdm import tqdm
import requests
from shared.gradio.gallery import AdvancedMediaGallery
from shared.ffmpeg_setup import download_ffmpeg
from shared.utils.plugins import PluginManager, WAN2GPApplication, SYSTEM_PLUGINS
from collections import defaultdict
# import torch._dynamo as dynamo
# dynamo.config.recompile_limit = 2000 # default is 256
# dynamo.config.accumulated_recompile_limit = 2000 # or whatever limit you want
# ---------------------------------------------------------------------------
# Module-wide constants and shared mutable state.
# ---------------------------------------------------------------------------
global_queue_ref = []
AUTOSAVE_FILENAME = "queue.zip"
AUTOSAVE_PATH = AUTOSAVE_FILENAME
AUTOSAVE_TEMPLATE_PATH = AUTOSAVE_FILENAME
CONFIG_FILENAME = "wgp_config.json"
PROMPT_VARS_MAX = 10
# Required mmgp release; enforced with an exact-match check below.
target_mmgp_version = "3.6.16"
WanGP_version = "10.41"
settings_version = 2.43
# Upper bound on frames taken from a source video (also used to validate frame positions).
max_source_video_frames = 3000
# Prompt-enhancer models; all lazily populated elsewhere, None until loaded.
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
image_names_list = ["image_start", "image_end", "image_refs"]
# All media attachment keys for queue save/load
ATTACHMENT_KEYS = ["image_start", "image_end", "image_refs", "image_guide", "image_mask",
                   "video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2", "audio_source", "custom_guide"]
from importlib.metadata import version
# Abort startup on any mmgp version mismatch (exact match required).
mmgp_version = version("mmgp")
if mmgp_version != target_mmgp_version:
    print(f"Incorrect version of mmgp ({mmgp_version}), version {target_mmgp_version} is needed. Please upgrade with the command 'pip install -r requirements.txt'")
    exit()
# Lock used around queue access (see edit_task_in_queue and friends).
lock = threading.Lock()
current_task_id = None
task_id = 0
# Monotonic counter + its lock, used by get_unique_id().
unique_id = 0
unique_id_lock = threading.Lock()
# Currently loaded model and its offload managers; reload_needed forces a (re)load.
offloadobj = enhancer_offloadobj = wan_model = None
reload_needed = True
# Quantization handler modules registered with mmgp's quant router.
_HANDLER_MODULES = [
    "shared.qtypes.scaled_fp8",
    "shared.qtypes.nvfp4",
    "shared.qtypes.nunchaku_int4",
    "shared.qtypes.nunchaku_fp4",
    "shared.qtypes.gguf",
]
# Drop the built-in quanto bridge, then register the project's own handlers.
quant_router.unregister_handler(".fp8_quanto_bridge")
for handler in _HANDLER_MODULES:
    quant_router.register_handler(handler)
from shared.qtypes import gguf as gguf_handler
offload.register_file_extension("gguf", gguf_handler)
def set_wgp_global(variable_name: str, new_value: typing.Any) -> typing.Optional[str]:
    """Set an existing module-level global to *new_value* (plugin entry point).

    Returns an error message string on failure, or None on success.
    Only pre-existing globals may be modified; unknown names are rejected
    with a gr.Warning so plugins cannot silently create new state.
    """
    if variable_name not in globals():
        error_msg = f"Plugin tried to modify a non-existent global: '{variable_name}'."
        print(f"ERROR: {error_msg}")
        gr.Warning(error_msg)
        return f"Error: Global variable '{variable_name}' does not exist."
    try:
        globals()[variable_name] = new_value
    except Exception as e:
        error_msg = f"Error while setting global '{variable_name}': {e}"
        print(f"ERROR: {error_msg}")
        return error_msg
    # Success: explicit None (the original fell off the end despite the -> str hint).
    return None
def clear_gen_cache():
    """Drop the shared generation cache entry from mmgp's offload state, if present."""
    offload.shared_state.pop("_cache", None)
def release_model():
    """Release the currently loaded model and flush caches.

    Clears the module globals `wan_model` / `offloadobj`, drops the shared
    generation cache, flushes torch caches via mmgp, and marks the model as
    needing a reload on next use.
    """
    global wan_model, offloadobj, reload_needed
    wan_model = None
    # Drops offload.shared_state["_cache"]; the original repeated this
    # deletion inline right after the call — the duplicate was removed.
    clear_gen_cache()
    if offloadobj is not None:
        offloadobj.release()
        offloadobj = None
    offload.flush_torch_caches()
    reload_needed = True
def get_unique_id():
    """Return a unique id string: wall-clock time plus a lock-protected counter."""
    global unique_id
    with unique_id_lock:
        unique_id = unique_id + 1
        counter = unique_id
    return str(time.time() + counter)
def format_time(seconds):
    """Render a duration as 'Hh MMm SSs', 'Mm SSs', or 'S.Ss' depending on magnitude."""
    whole_hours, remainder = divmod(int(seconds), 3600)
    whole_minutes, whole_secs = divmod(remainder, 60)
    if whole_hours > 0:
        return f"{whole_hours}h {whole_minutes:02d}m {whole_secs:02d}s"
    if seconds >= 60:
        return f"{whole_minutes}m {whole_secs:02d}s"
    return f"{seconds:.1f}s"
def format_generation_time(seconds):
    """Format generation time showing raw seconds with human-readable time in parentheses when over 60s"""
    raw = f"{int(seconds)}s"
    if seconds < 60:
        return raw
    total = int(seconds)
    h, rem = divmod(total, 3600)
    m, s = divmod(rem, 60)
    detail = f"{h}h {m}m {s}s" if h > 0 else f"{m}m {s}s"
    return f"{raw} ({detail})"
def pil_to_base64_uri(pil_image, format="png", quality=75):
    """Encode an image as a `data:image/...;base64,...` URI string.

    *pil_image* may be a PIL image or a path: for a video path the first frame
    is used, for an image path the file is opened, anything else yields None.
    Returns None on any conversion/encoding failure.
    """
    if pil_image is None:
        return None
    if isinstance(pil_image, str):
        # Check file type and load appropriately
        if has_video_file_extension(pil_image):
            from shared.utils.utils import get_video_frame
            pil_image = get_video_frame(pil_image, 0)
        elif has_image_file_extension(pil_image):
            pil_image = Image.open(pil_image)
        else:
            # Audio or unknown file type - can't convert to image
            return None
    buffer = io.BytesIO()
    try:
        fmt = format.lower()
        img_to_save = pil_image
        # Normalize palette images first so any transparency info survives.
        if img_to_save.mode == 'P':
            img_to_save = img_to_save.convert('RGBA' if 'transparency' in img_to_save.info else 'RGB')
        elif fmt == 'png' and img_to_save.mode not in ['RGB', 'RGBA', 'L']:
            img_to_save = img_to_save.convert('RGBA')
        # JPEG has no alpha channel: flatten RGBA to RGB. This also fixes the
        # previous elif-chain, which turned a transparent 'P' image into RGBA
        # and then failed to save it as JPEG (returning None).
        if fmt == 'jpeg' and img_to_save.mode == 'RGBA':
            img_to_save = img_to_save.convert('RGB')
        if fmt == 'jpeg':
            img_to_save.save(buffer, format=format, quality=quality)
        else:
            img_to_save.save(buffer, format=format)
        img_bytes = buffer.getvalue()
        encoded_string = base64.b64encode(img_bytes).decode("utf-8")
        return f"data:image/{fmt};base64,{encoded_string}"
    except Exception as e:
        print(f"Error converting PIL to base64: {e}")
        return None
def is_integer(n):
    """Return True if *n* parses as a number with an integral value.

    Accepts strings ("3", "4.0") and numbers. Non-numeric values — including
    None, which previously raised TypeError — now return False.
    """
    try:
        value = float(n)
    except (ValueError, TypeError):
        return False
    return value.is_integer()
def get_state_model_type(state):
    """Return the model type relevant to the current form: the 'add' form uses
    state["model_type"], any other active form uses state["edit_model_type"]."""
    if state.get("active_form", "add") == "add":
        return state["model_type"]
    return state["edit_model_type"]
def compute_sliding_window_no(current_video_length, sliding_window_size, discard_last_frames, reuse_frames):
    """Number of sliding windows needed to generate *current_video_length* frames:
    one full first window, then ceil of the remainder over the effective stride."""
    remaining = current_video_length - (sliding_window_size - discard_last_frames)
    stride = sliding_window_size - discard_last_frames - reuse_frames
    return 1 + math.ceil(remaining / stride)
def clean_image_list(gradio_list):
    """Normalize a gradio gallery payload to a list of converted PIL images.

    Accepts a single item or a list; gallery tuples are reduced to their first
    element. Returns None if any entry is neither a PIL image nor an image
    file path.
    """
    items = gradio_list if isinstance(gradio_list, list) else [gradio_list]
    items = [entry[0] if isinstance(entry, tuple) else entry for entry in items]
    for entry in items:
        if not isinstance(entry, (Image.Image, str)):
            return None
        if isinstance(entry, str) and not has_image_file_extension(entry):
            return None
    return [convert_image(Image.open(entry) if isinstance(entry, str) else entry) for entry in items]
# Cache for the parameter-name list below (computed once on first use).
_generate_video_param_names = None
def _get_generate_video_param_names():
    """Get parameter names from generate_video signature (cached)."""
    global _generate_video_param_names
    if _generate_video_param_names is None:
        excluded = {"task", "send_cmd", "plugin_data"}
        _generate_video_param_names = [
            name for name in inspect.signature(generate_video).parameters
            if name not in excluded
        ]
    return _generate_video_param_names
def silent_cancel_edit(state):
    """Abort the current task edit without notifying the user; un-pause the
    queue if it was paused for editing and switch back to the generation tab."""
    info = get_gen_info(state)
    state["editing_task_id"] = None
    if info.get("queue_paused_for_edit"):
        info["queue_paused_for_edit"] = False
    return gr.Tabs(selected="video_gen"), None, gr.update(visible=False)
def cancel_edit(state):
    """Abort the current task edit, telling the user whether the queue resumes."""
    info = get_gen_info(state)
    state["editing_task_id"] = None
    if not info.get("queue_paused_for_edit"):
        gr.Info("Edit cancelled.")
    else:
        info["queue_paused_for_edit"] = False
        gr.Info("Edit cancelled. Resuming queue processing.")
    return gr.Tabs(selected="video_gen"), gr.update(visible=False)
def validate_edit(state):
    """Re-validate state["edit_state"] and merge the normalized overrides back
    into it; state["validate_edit_success"] is 1 on success, 0 otherwise."""
    state["validate_edit_success"] = 0
    model_type = get_state_model_type(state)
    edit_inputs = state.get("edit_state", None)
    if edit_inputs is None:
        return
    override_inputs, _prompts, _image_start, _image_end = validate_settings(state, model_type, True, edit_inputs)
    if override_inputs is None:
        return
    edit_inputs.update(override_inputs)
    state["edit_state"] = edit_inputs
    state["validate_edit_success"] = 1
def edit_task_in_queue( state ):
    """Apply the validated settings in state["edit_state"] to the queued task
    currently being edited, then return the UI updates:
    (selected queue row, tab selector, edit-form visibility, queue table data)."""
    gen = get_gen_info(state)
    queue = gen.get("queue", [])
    editing_task_id = state.get("editing_task_id", None)
    # NOTE(review): edit_state is popped before the validation flag is checked,
    # so a failed validation consumes the pending edit — confirm this is intended.
    new_inputs = state.pop("edit_state", None)
    if editing_task_id is None or new_inputs is None:
        gr.Warning("No task selected for editing.")
        return None, gr.Tabs(selected="video_gen"), gr.update(visible=False), gr.update()
    if state.get("validate_edit_success", 0) == 0:
        # Validation failed: leave every UI component untouched.
        return None, gr.update(), gr.update(), gr.update()
    task_to_edit_index = -1
    with lock:
        # Locate and mutate the task while holding the queue lock.
        task_to_edit_index = next((i for i, task in enumerate(queue) if task['id'] == editing_task_id), -1)
        if task_to_edit_index == -1:
            gr.Warning("Task not found in queue. It might have been processed or deleted.")
            state["editing_task_id"] = None
            gen["queue_paused_for_edit"] = False
            return None, gr.Tabs(selected="video_gen"), gr.update(visible=False), gr.update()
        model_type = get_state_model_type(state)
        new_inputs["model_type"] = model_type
        new_inputs["state"] = state
        new_inputs["model_filename"] = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy)
        task_to_edit = queue[task_to_edit_index]
        task_to_edit['params'] = new_inputs
        # Refresh the summary fields the queue table displays.
        task_to_edit['prompt'] = new_inputs.get('prompt')
        task_to_edit['length'] = new_inputs.get('video_length')
        task_to_edit['steps'] = new_inputs.get('num_inference_steps')
        update_task_thumbnails(task_to_edit, task_to_edit['params'])
        gr.Info(f"Task ID {task_to_edit['id']} has been updated successfully.")
        state["editing_task_id"] = None
        if gen.get("queue_paused_for_edit"):
            gr.Info("Resuming queue processing.")
            gen["queue_paused_for_edit"] = False
        # NOTE(review): the returned row index is task_to_edit_index - 1 —
        # presumably the table excludes the in-progress task; confirm against
        # update_queue_data.
        return task_to_edit_index -1, gr.Tabs(selected="video_gen"), gr.update(visible=False), update_queue_data(queue)
def process_prompt_and_add_tasks(state, current_gallery_tab, model_choice):
    """Validate the current form and enqueue one or more generation tasks.

    Handles three flows: post-processing/remux edits (mode "edit_*"),
    multi-prompt/multi-image expansion, and the single "one generation per
    line-set" case. Returns (queue table update, accordion open update).
    """
    def ret():
        return gr.update(), gr.update()
    gen = get_gen_info(state)
    gen["last_was_audio"] = current_gallery_tab == 1
    if state.get("validate_success",0) != 1:
        # Fix: the original called ret() without returning it, so a failed
        # prior validation still fell through and enqueued tasks.
        return ret()
    state["validate_success"] = 0
    model_type = get_state_model_type(state)
    inputs = get_model_settings(state, model_type)
    if model_choice != model_type or inputs ==None:
        raise gr.Error("Webform can not be used as the App has been restarted since the form was displayed. Please refresh the page")
    inputs["state"] = state
    inputs["model_type"] = model_type
    inputs.pop("lset_name")
    if inputs == None:
        # Defensive: unreachable in practice (inputs==None raised above).
        gr.Warning("Internal state error: Could not retrieve inputs for the model.")
        queue = gen.get("queue", [])
        return ret()
    mode = inputs["mode"]
    if mode.startswith("edit_"):
        # Post-processing / remux of an existing output: all form fields are
        # cleared and replaced by the stashed edit overrides.
        edit_video_source =gen.get("edit_video_source", None)
        edit_overrides =gen.get("edit_overrides", None)
        _ , _ , _, frames_count = get_video_info(edit_video_source)
        if frames_count > max_source_video_frames:
            gr.Info(f"Post processing is not supported on videos longer than {max_source_video_frames} frames. Output Video will be truncated")
            # return
        for prop in ["state", "model_type", "mode"]:
            edit_overrides[prop] = inputs[prop]
        for k,v in inputs.items():
            inputs[k] = None
        inputs.update(edit_overrides)
        del gen["edit_video_source"], gen["edit_overrides"]
        inputs["video_source"]= edit_video_source
        prompt = []
        repeat_generation = 1
        if mode == "edit_postprocessing":
            spatial_upsampling = inputs.get("spatial_upsampling","")
            if len(spatial_upsampling) >0: prompt += ["Spatial Upsampling"]
            temporal_upsampling = inputs.get("temporal_upsampling","")
            if len(temporal_upsampling) >0: prompt += ["Temporal Upsampling"]
            if has_image_file_extension(edit_video_source) and len(temporal_upsampling) > 0:
                gr.Info("Temporal Upsampling can not be used with an Image")
                return ret()
            film_grain_intensity = inputs.get("film_grain_intensity",0)
            film_grain_saturation = inputs.get("film_grain_saturation",0.5)
            # if film_grain_intensity >0: prompt += [f"Film Grain: intensity={film_grain_intensity}, saturation={film_grain_saturation}"]
            if film_grain_intensity >0: prompt += ["Film Grain"]
        elif mode =="edit_remux":
            MMAudio_setting = inputs.get("MMAudio_setting",0)
            repeat_generation= inputs.get("repeat_generation",1)
            audio_source = inputs["audio_source"]
            if MMAudio_setting== 1:
                prompt += ["MMAudio"]
                audio_source = None
                inputs["audio_source"] = audio_source
            else:
                if audio_source is None:
                    gr.Info("You must provide a custom Audio")
                    return ret()
                prompt += ["Custom Audio"]
                repeat_generation = 1
        inputs["repeat_generation"] = repeat_generation
        if len(prompt) == 0:
            if mode=="edit_remux":
                gr.Info("You must choose at least one Remux Method")
            else:
                gr.Info("You must choose at least one Post Processing Method")
            return ret()
        # The "prompt" of an edit task is just a label listing the operations.
        inputs["prompt"] = ", ".join(prompt)
        add_video_task(**inputs)
        new_prompts_count = gen["prompts_max"] = 1 + gen.get("prompts_max",0)
        state["validate_success"] = 1
        queue= gen.get("queue", [])
        return update_queue_data(queue), gr.update(open=True) if new_prompts_count > 1 else gr.update()
    override_inputs, prompts, image_start, image_end = validate_settings(state, model_type, False, inputs)
    if override_inputs is None:
        return ret()
    multi_prompts_gen_type = inputs["multi_prompts_gen_type"]
    if multi_prompts_gen_type in [0,2]:
        if image_start != None and len(image_start) > 0:
            # Pair up prompts with start/end images, either by cartesian
            # product (multi_images_gen_type == 0) or by even repetition.
            if inputs["multi_images_gen_type"] == 0:
                new_prompts = []
                new_image_start = []
                new_image_end = []
                for i in range(len(prompts) * len(image_start) ):
                    new_prompts.append( prompts[ i % len(prompts)] )
                    new_image_start.append(image_start[i // len(prompts)] )
                    if image_end != None:
                        new_image_end.append(image_end[i // len(prompts)] )
                prompts = new_prompts
                image_start = new_image_start
                if image_end != None:
                    image_end = new_image_end
            else:
                if len(prompts) >= len(image_start):
                    if len(prompts) % len(image_start) != 0:
                        gr.Info("If there are more text prompts than input images the number of text prompts should be dividable by the number of images")
                        return ret()
                    rep = len(prompts) // len(image_start)
                    new_image_start = []
                    new_image_end = []
                    for i, _ in enumerate(prompts):
                        new_image_start.append(image_start[i//rep] )
                        if image_end != None:
                            new_image_end.append(image_end[i//rep] )
                    image_start = new_image_start
                    if image_end != None:
                        image_end = new_image_end
                else:
                    if len(image_start) % len(prompts) !=0:
                        gr.Info("If there are more input images than text prompts the number of images should be dividable by the number of text prompts")
                        return ret()
                    rep = len(image_start) // len(prompts)
                    new_prompts = []
                    for i, _ in enumerate(image_start):
                        new_prompts.append( prompts[ i//rep] )
                    prompts = new_prompts
            if image_end == None or len(image_end) == 0:
                image_end = [None] * len(prompts)
            # One task per (prompt, start image, end image) triple.
            for single_prompt, start, end in zip(prompts, image_start, image_end) :
                override_inputs.update({
                    "prompt" : single_prompt,
                    "image_start": start,
                    "image_end" : end,
                })
                inputs.update(override_inputs)
                add_video_task(**inputs)
        else:
            for single_prompt in prompts :
                override_inputs["prompt"] = single_prompt
                inputs.update(override_inputs)
                add_video_task(**inputs)
        new_prompts_count = len(prompts)
    else:
        # All lines feed a single generation (sliding-window mode).
        new_prompts_count = 1
        override_inputs["prompt"] = "\n".join(prompts)
        inputs.update(override_inputs)
        add_video_task(**inputs)
    new_prompts_count += gen.get("prompts_max",0)
    gen["prompts_max"] = new_prompts_count
    state["validate_success"] = 1
    queue= gen.get("queue", [])
    return update_queue_data(queue), gr.update(open=True) if new_prompts_count > 1 else gr.update()
def validate_settings(state, model_type, single_prompt, inputs):
    """Validate and normalize the generation settings in *inputs*.

    On the first failed check, shows a gr.Info message and returns
    (None, None, None, None). On success returns:
      override_inputs -- normalized fields to merge back into inputs
      prompts         -- list of prompts (one per generation, or one merged
                         prompt when *single_prompt* is True)
      image_start     -- normalized list of start images, or None
      image_end       -- normalized list of end images, or None
    """
    def ret():
        return None, None, None, None
    model_def = get_model_def(model_type)
    model_handler = get_model_handler(model_type)
    image_outputs = inputs["image_mode"] > 0
    any_steps_skipping = model_def.get("tea_cache", False) or model_def.get("mag_cache", False)
    model_type = get_base_model_type(model_type)
    model_filename = get_model_filename(model_type)
    # Model-specific validation hooks first.
    if hasattr(model_handler, "validate_generative_settings"):
        error = model_handler.validate_generative_settings(model_type, model_def, inputs)
        if error is not None and len(error) > 0:
            gr.Info(error)
            return ret()
    if inputs.get("cfg_star_switch", 0) != 0 and inputs.get("apg_switch", 0) != 0:
        gr.Info("Adaptive Progressive Guidance and Classifier Free Guidance Star can not be set at the same time")
        return ret()
    prompt = inputs["prompt"]
    if len(prompt) ==0:
        gr.Info("Prompt cannot be empty.")
        gen = get_gen_info(state)
        queue = gen.get("queue", [])
        return ret()
    prompt, errors = prompt_parser.process_template(prompt)
    if len(errors) > 0:
        gr.Info("Error processing prompt template: " + errors)
        return ret()
    multi_prompts_gen_type = inputs["multi_prompts_gen_type"]
    # One prompt per non-empty, non-comment line; optionally merged back into one.
    prompts = prompt.replace("\r", "").split("\n")
    prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
    if single_prompt or multi_prompts_gen_type == 2:
        prompts = ["\n".join(prompts)]
    if len(prompts) == 0:
        gr.Info("Prompt cannot be empty.")
        gen = get_gen_info(state)
        queue = gen.get("queue", [])
        return ret()
    if hasattr(model_handler, "validate_generative_prompt"):
        for one_prompt in prompts:
            error = model_handler.validate_generative_prompt(model_type, model_def, inputs, one_prompt)
            if error is not None and len(error) > 0:
                gr.Info(error)
                return ret()
    resolution = inputs["resolution"]
    width, height = resolution.split("x")
    width, height = int(width), int(height)
    # Pull every field that gets validated/normalized below.
    image_start = inputs["image_start"]
    image_end = inputs["image_end"]
    image_refs = inputs["image_refs"]
    image_prompt_type = inputs["image_prompt_type"]
    audio_prompt_type = inputs["audio_prompt_type"]
    if image_prompt_type == None: image_prompt_type = ""
    video_prompt_type = inputs["video_prompt_type"]
    if video_prompt_type == None: video_prompt_type = ""
    force_fps = inputs["force_fps"]
    audio_guide = inputs["audio_guide"]
    audio_guide2 = inputs["audio_guide2"]
    audio_source = inputs["audio_source"]
    video_guide = inputs["video_guide"]
    image_guide = inputs["image_guide"]
    video_mask = inputs["video_mask"]
    image_mask = inputs["image_mask"]
    custom_guide = inputs["custom_guide"]
    speakers_locations = inputs["speakers_locations"]
    video_source = inputs["video_source"]
    frames_positions = inputs["frames_positions"]
    keep_frames_video_guide= inputs["keep_frames_video_guide"]
    keep_frames_video_source = inputs["keep_frames_video_source"]
    denoising_strength= inputs["denoising_strength"]
    masking_strength= inputs["masking_strength"]
    input_video_strength = inputs.get("input_video_strength", 1.0)
    sliding_window_size = inputs["sliding_window_size"]
    sliding_window_overlap = inputs["sliding_window_overlap"]
    sliding_window_discard_last_frames = inputs["sliding_window_discard_last_frames"]
    video_length = inputs["video_length"]
    num_inference_steps= inputs["num_inference_steps"]
    skip_steps_cache_type= inputs["skip_steps_cache_type"]
    MMAudio_setting = inputs["MMAudio_setting"]
    image_mode = inputs["image_mode"]
    loras_multipliers = inputs["loras_multipliers"]
    activated_loras = inputs["activated_loras"]
    guidance_phases= inputs["guidance_phases"]
    model_switch_phase = inputs["model_switch_phase"]
    switch_threshold = inputs["switch_threshold"]
    switch_threshold2 = inputs["switch_threshold2"]
    video_guide_outpainting = inputs["video_guide_outpainting"]
    spatial_upsampling = inputs["spatial_upsampling"]
    motion_amplitude = inputs["motion_amplitude"]
    medium = "Videos" if image_mode == 0 else "Images"
    if image_start is not None and not isinstance(image_start, list): image_start = [image_start]
    outpainting_dims = get_outpainting_dims(video_guide_outpainting)
    if server_config.get("fit_canvas", 0) == 2 and outpainting_dims is not None and any_letters(video_prompt_type, "VKF"):
        gr.Info("Output Resolution Cropping will be not used for this Generation as it is not compatible with Video Outpainting")
    if not model_def.get("motion_amplitude", False): motion_amplitude = 1.
    if "vae" in spatial_upsampling:
        if image_mode not in model_def.get("vae_upsampler", []):
            gr.Info(f"VAE Spatial Upsampling is not available for {medium}")
            return ret()
    if len(activated_loras) > 0:
        error = check_loras_exist(model_type, activated_loras)
        if len(error) > 0:
            gr.Info(error)
            return ret()
    # Guidance phases: either locked by the model or capped to its maximum.
    if model_def.get("lock_guidance_phases", False):
        guidance_phases = model_def.get("guidance_max_phases", 0)
    else:
        guidance_phases = min(guidance_phases, model_def.get("guidance_max_phases", 0))
    if len(loras_multipliers) > 0:
        _, _, errors = parse_loras_multipliers(loras_multipliers, len(activated_loras), num_inference_steps, nb_phases= guidance_phases)
        if len(errors) > 0:
            gr.Info(f"Error parsing Loras Multipliers: {errors}")
            return ret()
    if guidance_phases == 3:
        if switch_threshold < switch_threshold2:
            gr.Info(f"Phase 1-2 Switch Noise Level ({switch_threshold}) should be Greater than Phase 2-3 Switch Noise Level ({switch_threshold2}). As a reminder, noise will gradually go down from 1000 to 0.")
            return ret()
    else:
        model_switch_phase = 1
    if not any_steps_skipping: skip_steps_cache_type = ""
    if not model_def.get("lock_inference_steps", False) and model_type in ["ltxv_13B"] and num_inference_steps < 20:
        gr.Info("The minimum number of steps should be 20")
        return ret()
    if skip_steps_cache_type == "mag":
        if num_inference_steps > 50:
            gr.Info("Mag Cache maximum number of steps is 50")
            return ret()
    if image_mode > 0:
        audio_prompt_type = ""
    if "K" in audio_prompt_type and "V" not in video_prompt_type:
        gr.Info("You must enable a Control Video to use the Control Video Audio Track as an audio prompt")
        return ret()
    if "B" in audio_prompt_type or "X" in audio_prompt_type:
        from models.wan.multitalk.multitalk import parse_speakers_locations
        speakers_bboxes, error = parse_speakers_locations(speakers_locations)
        if len(error) > 0:
            gr.Info(error)
            return ret()
    if MMAudio_setting != 0 and get_mmaudio_settings(server_config)[0] and video_length <16: #should depend on the architecture
        gr.Info("MMAudio can generate an Audio track only if the Video is at least 1s long")
    if "F" in video_prompt_type:
        if len(frames_positions.strip()) > 0:
            positions = frames_positions.replace(","," ").split(" ")
            for pos_str in positions:
                if not pos_str in ["L", "l"] and len(pos_str)>0:
                    if not is_integer(pos_str):
                        gr.Info(f"Invalid Frame Position '{pos_str}'")
                        return ret()
                    pos = int(pos_str)
                    if pos <1 or pos > max_source_video_frames:
                        gr.Info(f"Invalid Frame Position Value'{pos_str}'")
                        return ret()
    else:
        frames_positions = None
    if audio_source is not None and MMAudio_setting != 0:
        # Fixed message: previously read "can't not be used" (double negative).
        gr.Info("MMAudio and Custom Audio Soundtrack can't be used at the same time")
        return ret()
    if len(filter_letters(image_prompt_type, "VLG")) > 0 and len(keep_frames_video_source) > 0:
        if not is_integer(keep_frames_video_source) or int(keep_frames_video_source) == 0:
            gr.Info("The number of frames to keep must be a non null integer")
            return ret()
    else:
        keep_frames_video_source = ""
    if image_outputs:
        image_prompt_type = image_prompt_type.replace("V", "").replace("L", "")
    custom_guide_def = model_def.get("custom_guide", None)
    if custom_guide_def is not None:
        if custom_guide is None and custom_guide_def.get("required", False):
            gr.Info(f"You must provide a {custom_guide_def.get('label', 'Custom Guide')}")
            return ret()
    else:
        custom_guide = None
    if "V" in image_prompt_type:
        if video_source == None:
            gr.Info("You must provide a Source Video file to continue")
            return ret()
    else:
        video_source = None
    if len(model_def.get("input_video_strength", ""))==0 or not any_letters(image_prompt_type, "SVL"):
        input_video_strength = 1.0
    if "A" in audio_prompt_type:
        if audio_guide == None:
            gr.Info("You must provide an Audio Source")
            return ret()
        if "B" in audio_prompt_type:
            if audio_guide2 == None:
                gr.Info("You must provide a second Audio Source")
                return ret()
        else:
            audio_guide2 = None
    else:
        audio_guide = None
        audio_guide2 = None
    if model_type in ["vace_multitalk_14B"] and ("B" in audio_prompt_type or "X" in audio_prompt_type):
        # Fixed condition: the original had "not not 'V' in ..." which warned
        # when a Control Video WAS set, contradicting the advisory text.
        if not "I" in video_prompt_type and not "V" in video_prompt_type:
            gr.Info("To get good results with Multitalk and two people speaking, it is recommended to set a Reference Frame or a Control Video (potentially truncated) that contains the two people one on each side")
    if model_def.get("one_image_ref_needed", False):
        if image_refs == None :
            gr.Info("You must provide an Image Reference")
            return ret()
        if len(image_refs) > 1:
            gr.Info("Only one Image Reference (a person) is supported for the moment by this model")
            return ret()
    if model_def.get("at_least_one_image_ref_needed", False):
        if image_refs == None :
            gr.Info("You must provide at least one Image Reference")
            return ret()
    if "I" in video_prompt_type:
        if image_refs == None or len(image_refs) == 0:
            gr.Info("You must provide at least one Reference Image")
            return ret()
        image_refs = clean_image_list(image_refs)
        if image_refs == None :
            gr.Info("A Reference Image should be an Image")
            return ret()
    else:
        image_refs = None
    if "V" in video_prompt_type:
        if image_outputs:
            if image_guide is None:
                gr.Info("You must provide a Control Image")
                return ret()
        else:
            if video_guide is None:
                gr.Info("You must provide a Control Video")
                return ret()
        if "A" in video_prompt_type and not "U" in video_prompt_type:
            if image_outputs:
                if image_mask is None:
                    gr.Info("You must provide a Image Mask")
                    return ret()
            else:
                if video_mask is None:
                    gr.Info("You must provide a Video Mask")
                    return ret()
        else:
            video_mask = None
            image_mask = None
        if "G" in video_prompt_type:
            if denoising_strength < 1. and not model_def.get("custom_denoising_strength", False):
                gr.Info(f"With Denoising Strength {denoising_strength:.1f}, Denoising will start at Step no {int(round(num_inference_steps * (1. - denoising_strength),4))} ")
        else:
            denoising_strength = 1.0
        if "G" in video_prompt_type or model_def.get("mask_strength_always_enabled", False):
            if "A" in video_prompt_type and "U" not in video_prompt_type and masking_strength < 1.:
                masking_duration = math.ceil(num_inference_steps * masking_strength)
                if masking_strength:
                    gr.Info(f"With Masking Strength {masking_strength:.1f}, Masking will last {masking_duration}{' Step' if masking_duration==1 else ' Steps'}")
        else:
            masking_strength = 1.0
        if len(keep_frames_video_guide) > 0 and model_type in ["ltxv_13B"]:
            gr.Info("Keep Frames for Control Video is not supported with LTX Video")
            return ret()
        _, error = parse_keep_frames_video_guide(keep_frames_video_guide, video_length)
        if len(error) > 0:
            gr.Info(f"Invalid Keep Frames property: {error}")
            return ret()
    else:
        video_guide = None
        image_guide = None
        video_mask = None
        image_mask = None
        keep_frames_video_guide = ""
        denoising_strength = 1.0
        masking_strength = 1.0
    if image_outputs:
        video_guide = None
        video_mask = None
    else:
        image_guide = None
        image_mask = None
    if "S" in image_prompt_type:
        # Models with "black_frame" synthesize a black start image when none given.
        if model_def.get("black_frame", False) and len(image_start or [])==0:
            if "E" in image_prompt_type and len(image_end or []):
                image_end = clean_image_list(image_end)
                image_start = [Image.new("RGB", image.size, (0, 0, 0, 255)) for image in image_end]
            else:
                image_start = [Image.new("RGB", (width, height), (0, 0, 0, 255))]
        if image_start == None or isinstance(image_start, list) and len(image_start) == 0:
            gr.Info("You must provide a Start Image")
            return ret()
        image_start = clean_image_list(image_start)
        if image_start == None :
            gr.Info("Start Image should be an Image")
            return ret()
        if multi_prompts_gen_type in [1] and len(image_start) > 1:
            gr.Info("Only one Start Image is supported if the option 'Each Line Will be used for a new Sliding Window of the same Video Generation' is set")
            return ret()
    else:
        image_start = None
    if not any_letters(image_prompt_type, "SVL"):
        image_prompt_type = image_prompt_type.replace("E", "")
    if "E" in image_prompt_type:
        if image_end == None or isinstance(image_end, list) and len(image_end) == 0:
            gr.Info("You must provide an End Image")
            return ret()
        image_end = clean_image_list(image_end)
        if image_end == None :
            gr.Info("End Image should be an Image")
            return ret()
        if (video_source is not None or "L" in image_prompt_type):
            if multi_prompts_gen_type in [0,2] and len(image_end)> 1:
                gr.Info("If you want to Continue a Video, you can use Multiple End Images only if the option 'Each Line Will be used for a new Sliding Window of the same Video Generation' is set")
                return ret()
        elif multi_prompts_gen_type in [0, 2]:
            if len(image_start or []) != len(image_end or []):
                gr.Info("The number of Start and End Images should be the same if the option 'Each Line Will be used for a new Sliding Window of the same Video Generation' is not set")
                return ret()
    else:
        image_end = None
    if test_any_sliding_window(model_type) and image_mode == 0:
        if video_length > sliding_window_size:
            if test_class_t2v(model_type) and not "G" in video_prompt_type :
                gr.Info(f"You have requested to Generate Sliding Windows with a Text to Video model. Unless you use the Video to Video feature this is useless as a t2v model doesn't see past frames and it will generate the same video in each new window.")
                return ret()
            full_video_length = video_length if video_source is None else video_length + sliding_window_overlap -1
            extra = "" if full_video_length == video_length else f" including {sliding_window_overlap} added for Video Continuation"
            no_windows = compute_sliding_window_no(full_video_length, sliding_window_size, sliding_window_discard_last_frames, sliding_window_overlap)
            gr.Info(f"The Number of Frames to generate ({video_length}{extra}) is greater than the Sliding Window Size ({sliding_window_size}), {no_windows} Windows will be generated")
    if "recam" in model_filename:
        if video_guide == None:
            gr.Info("You must provide a Control Video")
            return ret()
        computed_fps = get_computed_fps(force_fps, model_type , video_guide, video_source )
        frames = get_resampled_video(video_guide, 0, 81, computed_fps)
        if len(frames)<81:
            gr.Info(f"Recammaster Control video should be at least 81 frames once the resampling at {computed_fps} fps has been done")
            return ret()
    if "hunyuan_custom_custom_edit" in model_filename:
        if len(keep_frames_video_guide) > 0:
            gr.Info("Filtering Frames with this model is not supported")
            return ret()
    if multi_prompts_gen_type in [1] or single_prompt:
        if image_start != None and len(image_start) > 1:
            if single_prompt:
                gr.Info("Only one Start Image can be provided in Edit Mode")
            else:
                gr.Info("Only one Start Image must be provided if multiple prompts are used for different windows")
            return ret()
        # if image_end != None and len(image_end) > 1:
        #     gr.Info("Only one End Image must be provided if multiple prompts are used for different windows")
        #     return
    override_inputs = {
        "image_start": image_start[0] if image_start !=None and len(image_start) > 0 else None,
        "image_end": image_end, #[0] if image_end !=None and len(image_end) > 0 else None,
        "image_refs": image_refs,
        "audio_guide": audio_guide,
        "audio_guide2": audio_guide2,
        "audio_source": audio_source,
        "video_guide": video_guide,
        "image_guide": image_guide,
        "video_mask": video_mask,
        "image_mask": image_mask,
        "custom_guide": custom_guide,
        "video_source": video_source,
        "frames_positions": frames_positions,
        "keep_frames_video_source": keep_frames_video_source,
        "input_video_strength": input_video_strength,
        "keep_frames_video_guide": keep_frames_video_guide,
        "denoising_strength": denoising_strength,
        "masking_strength": masking_strength,
        "image_prompt_type": image_prompt_type,
        "video_prompt_type": video_prompt_type,
        "audio_prompt_type": audio_prompt_type,
        "skip_steps_cache_type": skip_steps_cache_type,
        "model_switch_phase": model_switch_phase,
        "motion_amplitude": motion_amplitude,
    }
    return override_inputs, prompts, image_start, image_end
def get_preview_images(inputs):
    """Collect preview media from task inputs.

    Scans the known media keys in *inputs* and splits the found items into a
    "start" group (the first non-empty key, in priority order) and an "end"
    group (everything found afterwards). If only one key matched but held
    several items, the extras are spilled into the end group so both
    previews get content.

    Returns:
        (start_image_data, end_image_data, start_image_labels, end_image_labels)
        where the data entries are lists (or None when nothing was found) and
        each label list carries one human-readable label per item.
    """
    inputs_to_query = ["image_start", "video_source", "image_end", "video_guide", "image_guide", "video_mask", "image_mask", "image_refs" ]
    labels = ["Start Image", "Video Source", "End Image", "Video Guide", "Image Guide", "Video Mask", "Image Mask", "Image Reference"]
    start_image_data = None
    start_image_labels = []
    end_image_data = None
    end_image_labels = []
    for label, name in zip(labels, inputs_to_query):
        image = inputs.get(name, None)
        if image is None:  # fix: identity comparison instead of `is not None`'s `!= None` form
            continue
        # Normalize to a list; copy so later mutations never touch the input.
        image = image.copy() if isinstance(image, list) else [image]
        if start_image_data is None:
            start_image_data = image
            start_image_labels += [label] * len(image)
        else:
            if end_image_data is None:
                end_image_data = image
            else:
                end_image_data += image
            end_image_labels += [label] * len(image)
    # Only the first key matched but contained several items: move the extras
    # to the end group so both previews are populated.
    if start_image_data is not None and len(start_image_data) > 1 and end_image_data is None:
        end_image_data = start_image_data[1:]
        end_image_labels = start_image_labels[1:]
        start_image_data = start_image_data[:1]
        start_image_labels = start_image_labels[:1]
    return start_image_data, end_image_data, start_image_labels, end_image_labels
def add_video_task(**inputs):
    """Build a queue entry from the generation *inputs* and append it to the
    session queue stored in the state."""
    global task_id
    state = inputs["state"]
    gen = get_gen_info(state)
    queue = gen["queue"]
    task_id += 1
    current_task_id = task_id
    start_image_data, end_image_data, start_image_labels, end_image_labels = get_preview_images(inputs)
    # plugin_data is stored on the task itself, not inside params.
    plugin_data = inputs.pop('plugin_data', {})

    def encode_all(images):
        # Thumbnails travel to the UI as base64 JPEG data URIs.
        if images is None:
            return None
        return [pil_to_base64_uri(img, format="jpeg", quality=70) for img in images]

    new_task = {
        "id": current_task_id,
        "params": inputs.copy(),
        "plugin_data": plugin_data,
        "repeats": inputs.get("repeat_generation", 1),
        "length": inputs.get("video_length", 0) or 0,
        "steps": inputs.get("num_inference_steps", 0) or 0,
        "prompt": inputs.get("prompt", ""),
        "start_image_labels": start_image_labels,
        "end_image_labels": end_image_labels,
        "start_image_data": start_image_data,
        "end_image_data": end_image_data,
        "start_image_data_base64": encode_all(start_image_data),
        "end_image_data_base64": encode_all(end_image_data),
    }
    queue.append(new_task)
def update_task_thumbnails(task, inputs):
    """Refresh a queued task's preview labels and base64 thumbnails from *inputs*."""
    start_image_data, end_image_data, start_labels, end_labels = get_preview_images(inputs)

    def to_b64(images):
        if images is None:
            return None
        return [pil_to_base64_uri(img, format="jpeg", quality=70) for img in images]

    task["start_image_labels"] = start_labels
    task["end_image_labels"] = end_labels
    task["start_image_data_base64"] = to_b64(start_image_data)
    task["end_image_data_base64"] = to_b64(end_image_data)
def move_task(queue, old_index_str, new_index_str):
    """Move a pending task from one queue position to another.

    The indices arrive as strings from the UI and are 0-based over the
    pending tasks; queue slot 0 (the active task) is never displaced.
    Returns refreshed queue-table data either way.
    """
    try:
        src = int(old_index_str)
        dst = int(new_index_str)
    except (ValueError, IndexError):
        return update_queue_data(queue)
    with lock:
        # Shift by one: UI indices skip queue[0].
        src += 1
        dst += 1
        if not 0 < src < len(queue):
            return update_queue_data(queue)
        moved = queue.pop(src)
        if src < dst:
            # Removing the item shifted everything after it left by one.
            dst -= 1
        queue.insert(max(1, min(dst, len(queue))), moved)
    return update_queue_data(queue)
def remove_task(queue, task_id_to_remove):
    """Remove the task with the given id from the queue.

    If the task occupies slot 0 (the one currently running) the model is
    signalled to interrupt before the entry is dropped.
    """
    if not task_id_to_remove:
        return update_queue_data(queue)
    with lock:
        for position, task in enumerate(queue):
            if task['id'] == task_id_to_remove:
                if position == 0:
                    wan_model._interrupt = True
                del queue[position]
                break
    return update_queue_data(queue)
def update_global_queue_ref(queue):
    """Publish a shallow snapshot of *queue* for the autosave machinery."""
    global global_queue_ref
    with lock:
        global_queue_ref = list(queue)
def _save_queue_to_zip(queue, output):
    """Save queue to ZIP. output can be a filename (str) or BytesIO buffer.
    Returns True on success, False on failure.

    The archive contains a "queue.json" manifest (one entry per task) plus
    every media attachment a task references, staged into a temp dir and
    written into the archive root. Attachment params in the manifest are
    rewritten to the in-zip filenames.
    """
    if not queue:
        return False
    with tempfile.TemporaryDirectory() as tmpdir:
        queue_manifest = []
        # Maps id(PIL image) or source path -> in-zip filename so shared
        # attachments are stored only once.
        file_paths_in_zip = {}
        for task_index, task in enumerate(queue):
            if task is None or not isinstance(task, dict) or task.get('id') is None:
                continue
            params_copy = task.get('params', {}).copy()
            task_id_s = task.get('id', f"task_{task_index}")
            for key in ATTACHMENT_KEYS:
                value = params_copy.get(key)
                if value is None:
                    continue
                # Attachments may be a single item or a list; remember which.
                is_originally_list = isinstance(value, list)
                items = value if is_originally_list else [value]
                processed_filenames = []
                for item_index, item in enumerate(items):
                    if isinstance(item, Image.Image):
                        # In-memory image: render it as a PNG in the staging dir.
                        item_id = id(item)
                        if item_id in file_paths_in_zip:
                            processed_filenames.append(file_paths_in_zip[item_id])
                            continue
                        filename_in_zip = f"task{task_id_s}_{key}_{item_index}.png"
                        save_path = os.path.join(tmpdir, filename_in_zip)
                        try:
                            item.save(save_path, "PNG")
                            processed_filenames.append(filename_in_zip)
                            file_paths_in_zip[item_id] = filename_in_zip
                        except Exception as e:
                            print(f"Error saving attachment {filename_in_zip}: {e}")
                    elif isinstance(item, str):
                        # On-disk file: copy it into the staging dir.
                        if item in file_paths_in_zip:
                            processed_filenames.append(file_paths_in_zip[item])
                            continue
                        if not os.path.isfile(item):
                            continue
                        _, extension = os.path.splitext(item)
                        filename_in_zip = f"task{task_id_s}_{key}_{item_index}{extension if extension else ''}"
                        save_path = os.path.join(tmpdir, filename_in_zip)
                        try:
                            shutil.copy2(item, save_path)
                            processed_filenames.append(filename_in_zip)
                            file_paths_in_zip[item] = filename_in_zip
                        except Exception as e:
                            print(f"Error copying attachment {item}: {e}")
                if processed_filenames:
                    # Preserve the original scalar/list shape of the param.
                    params_copy[key] = processed_filenames if is_originally_list else processed_filenames[0]
            # Remove runtime-only keys
            for runtime_key in ['state', 'start_image_labels', 'end_image_labels',
                                'start_image_data_base64', 'end_image_data_base64',
                                'start_image_data', 'end_image_data']:
                params_copy.pop(runtime_key, None)
            params_copy['settings_version'] = settings_version
            params_copy['base_model_type'] = get_base_model_type(params_copy["model_type"])
            manifest_entry = {"id": task.get('id'), "params": params_copy}
            manifest_entry = {k: v for k, v in manifest_entry.items() if v is not None}
            queue_manifest.append(manifest_entry)
        manifest_path = os.path.join(tmpdir, "queue.json")
        try:
            with open(manifest_path, 'w', encoding='utf-8') as f:
                json.dump(queue_manifest, f, indent=4)
        except Exception as e:
            print(f"Error writing queue.json: {e}")
            return False
        try:
            with zipfile.ZipFile(output, 'w', zipfile.ZIP_DEFLATED) as zf:
                zf.write(manifest_path, arcname="queue.json")
                for saved_file_rel_path in file_paths_in_zip.values():
                    saved_file_abs_path = os.path.join(tmpdir, saved_file_rel_path)
                    if os.path.exists(saved_file_abs_path):
                        zf.write(saved_file_abs_path, arcname=saved_file_rel_path)
            return True
        except Exception as e:
            print(f"Error creating zip: {e}")
            return False
def save_queue_action(state):
    """Serialize the current queue into a base64-encoded ZIP for the UI.

    Returns the base64 string, "" when the queue is empty, or None when
    zipping failed.
    """
    gen = get_gen_info(state)
    queue = gen.get("queue", [])
    if not queue or len(queue) == 0:
        gr.Info("Queue is empty. Nothing to save.")
        return ""
    with io.BytesIO() as zip_buffer:
        if not _save_queue_to_zip(queue, zip_buffer):
            gr.Warning("Failed to save queue.")
            return None
        zip_buffer.seek(0)
        encoded = base64.b64encode(zip_buffer.getvalue()).decode('utf-8')
    print(f"Queue saved ({len(encoded)} chars)")
    return encoded
def _load_task_attachments(params, media_base_path, cache_dir=None, log_prefix="[load]"):
    """Resolve attachment filenames in *params* into usable media.

    For every key in ATTACHMENT_KEYS, filenames (relative to
    *media_base_path* unless absolute) are located on disk; images are
    loaded as PIL objects while other files keep their paths. When
    *cache_dir* is given, files are first copied there so they outlive a
    temporary extraction directory. *params* is modified in place: each key
    keeps its original list/scalar shape, and keys with no loadable item
    are removed.
    """
    for key in ATTACHMENT_KEYS:
        value = params.get(key)
        if value is None:
            continue
        is_originally_list = isinstance(value, list)
        filenames = value if is_originally_list else [value]
        loaded_items = []
        for filename in filenames:
            if not isinstance(filename, str) or not filename.strip():
                print(f"{log_prefix} Warning: Invalid filename for key '{key}'. Skipping.")
                continue
            if os.path.isabs(filename):
                source_path = filename
            else:
                source_path = os.path.join(media_base_path, filename)
            if not os.path.exists(source_path):
                print(f"{log_prefix} Warning: File not found for '{key}': {source_path}")
                continue
            if cache_dir:
                # Copy into the persistent cache so the file survives the
                # temporary extraction directory.
                final_path = os.path.join(cache_dir, os.path.basename(filename))
                try:
                    shutil.copy2(source_path, final_path)
                except Exception as e:
                    # Fix: report which file failed (message previously printed
                    # a literal "(unknown)" placeholder).
                    print(f"{log_prefix} Error copying {source_path}: {e}")
                    continue
            else:
                final_path = source_path
            # Load images as PIL, keep videos/audio as paths
            if has_image_file_extension(final_path):
                try:
                    loaded_items.append(Image.open(final_path))
                    print(f"{log_prefix} Loaded image: {final_path}")
                except Exception as e:
                    print(f"{log_prefix} Error loading image {final_path}: {e}")
            else:
                loaded_items.append(final_path)
                print(f"{log_prefix} Using path: {final_path}")
        # Update params, preserving list/single structure
        if loaded_items:
            if is_originally_list:
                params[key] = loaded_items
            else:
                params[key] = loaded_items[0]
        else:
            params.pop(key, None)
def _build_runtime_task(task_id_val, params, plugin_data=None):
    """Build a runtime task dict (queue entry) from loaded *params*."""
    primary_preview, secondary_preview, primary_labels, secondary_labels = get_preview_images(params)

    def first_thumb(preview):
        # Only the first preview item is rendered as a base64 thumbnail.
        if isinstance(preview, list) and preview:
            return [pil_to_base64_uri(preview[0], format="jpeg", quality=70)]
        return None

    return {
        "id": task_id_val,
        "params": params,
        "plugin_data": plugin_data or {},
        "repeats": params.get('repeat_generation', 1),
        "length": params.get('video_length'),
        "steps": params.get('num_inference_steps'),
        "prompt": params.get('prompt'),
        "start_image_labels": primary_labels,
        "end_image_labels": secondary_labels,
        "start_image_data": params.get("image_start") or params.get("image_refs"),
        "end_image_data": params.get("image_end"),
        "start_image_data_base64": first_thumb(primary_preview),
        "end_image_data_base64": first_thumb(secondary_preview),
    }
def _process_task_params(params, state, log_prefix="[load]"):
    """Apply defaults, fix settings, and prepare params for a task.
    Returns (model_type, error_msg or None). Modifies params in place.
    """
    base_model_type = params.get('base_model_type', None)
    original_model_type = params.get('model_type', base_model_type)
    model_type = original_model_type
    # Unknown model types fall back to their declared base type.
    if model_type is not None and get_model_def(model_type) is None:
        model_type = base_model_type
    if model_type is None:
        return None, "Settings must contain 'model_type'"
    params["model_type"] = model_type
    if get_model_def(model_type) is None:
        return None, f"Unknown model type: {original_model_type}"
    # Use primary_settings as base (not model-specific saved settings) so
    # loaded queues/settings behave predictably; saved values win.
    saved_settings_version = params.get('settings_version', 0)
    merged = {**primary_settings, **params}
    params.clear()
    params.update(merged)
    fix_settings(model_type, params, saved_settings_version)
    # Strip manifest-only metadata before the params reach generation.
    for meta_key in ['type', 'base_model_type', 'settings_version']:
        params.pop(meta_key, None)
    params['state'] = state
    return model_type, None
def _parse_task_manifest(manifest, state, media_base_path, cache_dir=None, log_prefix="[load]"):
    """Turn a manifest (list of {id, params, plugin_data} dicts) into runtime tasks.

    Invalid entries are skipped with a log line. Also advances the global
    task_id counter past every loaded id. Returns (task_list, None).
    """
    global task_id
    loaded = []
    total = len(manifest)
    for task_index, task_data in enumerate(manifest):
        if task_data is None or not isinstance(task_data, dict):
            print(f"{log_prefix} Skipping invalid task data at index {task_index}")
            continue
        params = task_data.get('params', {})
        task_id_loaded = task_data.get('id', task_id + 1)
        # Process params (merge defaults, fix settings)
        model_type, error = _process_task_params(params, state, log_prefix)
        if error:
            print(f"{log_prefix} {error} for task #{task_id_loaded}. Skipping.")
            continue
        # Load media attachments
        _load_task_attachments(params, media_base_path, cache_dir, log_prefix)
        # Build runtime task
        loaded.append(_build_runtime_task(task_id_loaded, params, task_data.get('plugin_data', {})))
        print(f"{log_prefix} Task {task_index+1}/{total} ready, ID: {task_id_loaded}, model: {model_type}")
    if loaded:
        # Keep the global counter strictly ahead of every loaded id.
        current_max_id = max([t['id'] for t in loaded if 'id' in t] + [0])
        if current_max_id >= task_id:
            task_id = current_max_id + 1
    return loaded, None
def _parse_queue_zip(filename, state):
    """Parse queue ZIP file. Returns (queue_list, error_msg or None).

    The ZIP must contain a "queue.json" manifest; referenced media files
    are extracted to a temp dir and copied into a persistent cache
    directory under the configured save path.
    """
    save_path_base = server_config.get("save_path", "outputs")
    cache_dir = os.path.join(save_path_base, "_loaded_queue_cache")
    try:
        # Fix: log the actual path (message previously printed a lost
        # "(unknown)" placeholder instead of the filename).
        print(f"[load_queue] Attempting to load queue from: {filename}")
        os.makedirs(cache_dir, exist_ok=True)
        with tempfile.TemporaryDirectory() as tmpdir:
            with zipfile.ZipFile(filename, 'r') as zf:
                if "queue.json" not in zf.namelist():
                    return None, "queue.json not found in zip file"
                print(f"[load_queue] Extracting to temp directory...")
                zf.extractall(tmpdir)
            manifest_path = os.path.join(tmpdir, "queue.json")
            with open(manifest_path, 'r', encoding='utf-8') as f:
                manifest = json.load(f)
            print(f"[load_queue] Loaded manifest with {len(manifest)} tasks.")
            return _parse_task_manifest(manifest, state, tmpdir, cache_dir, "[load_queue]")
    except Exception as e:
        traceback.print_exc()
        return None, str(e)
def _parse_settings_json(filename, state):
    """Parse a single settings JSON file. Returns (queue_list, error_msg or None).
    Media paths in JSON are filesystem paths (absolute or relative to WanGP folder).
    """
    global task_id
    try:
        # Fix: log the actual path (message previously printed a lost
        # "(unknown)" placeholder instead of the filename).
        print(f"[load_settings] Loading settings from: {filename}")
        with open(filename, 'r', encoding='utf-8') as f:
            params = json.load(f)
        if isinstance(params, list):
            # Accept full queue manifests or a list of settings dicts
            if all(isinstance(item, dict) and "params" in item for item in params):
                manifest = params
            else:
                manifest = []
                for item in params:
                    if not isinstance(item, dict):
                        continue
                    task_id += 1
                    manifest.append({"id": task_id, "params": item, "plugin_data": {}})
        elif isinstance(params, dict):
            # Wrap as single-task manifest
            task_id += 1
            manifest = [{"id": task_id, "params": params, "plugin_data": {}}]
        else:
            return None, "Settings file must contain a JSON object or a list of tasks"
        # Media paths are relative to WanGP folder (no cache needed)
        wgp_folder = os.path.dirname(os.path.abspath(__file__))
        return _parse_task_manifest(manifest, state, wgp_folder, None, "[load_settings]")
    except json.JSONDecodeError as e:
        return None, f"Invalid JSON: {e}"
    except Exception as e:
        traceback.print_exc()
        return None, str(e)
def load_queue_action(filepath, state, evt:gr.EventData):
    """Load queue from ZIP or JSON file (Gradio UI wrapper).

    Handles both the startup autoload (evt.target is None) and a user
    upload. Loaded tasks are merged into the existing queue in place so a
    running process_tasks loop keeps observing the same list object.
    Returns refreshed queue-table data (or None on autoload no-ops).
    """
    global task_id
    gen = get_gen_info(state)
    original_queue = gen.get("queue", [])
    # Determine filename (autoload vs user upload)
    delete_autoqueue_file = False
    if evt.target is None:
        # Autoload only works with empty queue
        if original_queue:
            return
        autoload_path = None
        if Path(AUTOSAVE_PATH).is_file():
            autoload_path = AUTOSAVE_PATH
            delete_autoqueue_file = True
        elif AUTOSAVE_TEMPLATE_PATH != AUTOSAVE_PATH and Path(AUTOSAVE_TEMPLATE_PATH).is_file():
            autoload_path = AUTOSAVE_TEMPLATE_PATH
        else:
            return
        print(f"Autoloading queue from {autoload_path}...")
        filename = autoload_path
    else:
        if not filepath or not hasattr(filepath, 'name') or not Path(filepath.name).is_file():
            print("[load_queue_action] Warning: No valid file selected or file not found.")
            return update_queue_data(original_queue)
        filename = filepath.name
    try:
        # Detect file type and use appropriate parser
        is_json = filename.lower().endswith('.json')
        if is_json:
            newly_loaded_queue, error = _parse_settings_json(filename, state)
            # Safety: clear attachment paths when loading JSON through UI
            # (JSON files contain filesystem paths which could be security-sensitive)
            if newly_loaded_queue:
                for task in newly_loaded_queue:
                    params = task.get('params', {})
                    for key in ATTACHMENT_KEYS:
                        if key in params:
                            params[key] = None
        else:
            newly_loaded_queue, error = _parse_queue_zip(filename, state)
        if error:
            gr.Warning(f"Failed to load queue: {error[:200]}")
            return update_queue_data(original_queue)
        # Merge with existing queue: renumber task IDs to avoid conflicts
        # IMPORTANT: Modify list in-place to preserve references held by process_tasks
        if original_queue:
            # Find the highest existing task ID
            max_existing_id = max([t.get('id', 0) for t in original_queue] + [0])
            # Renumber newly loaded tasks
            for i, task in enumerate(newly_loaded_queue):
                task['id'] = max_existing_id + 1 + i
            # Update global task_id counter
            task_id = max_existing_id + len(newly_loaded_queue) + 1
            # Extend existing queue in-place (preserves reference for running process_tasks)
            original_queue.extend(newly_loaded_queue)
            action_msg = f"Merged {len(newly_loaded_queue)} task(s) with existing {len(original_queue) - len(newly_loaded_queue)} task(s)"
            merged_queue = original_queue
        else:
            # No existing queue - assign newly loaded queue directly
            merged_queue = newly_loaded_queue
            action_msg = f"Loaded {len(newly_loaded_queue)} task(s)"
        with lock:
            gen["queue"] = merged_queue
        # Update state (Gradio-specific)
        with lock:
            gen["prompts_max"] = len(merged_queue)
        update_global_queue_ref(merged_queue)
        print(f"[load_queue_action] {action_msg}.")
        gr.Info(action_msg)
        return update_queue_data(merged_queue)
    except Exception as e:
        error_message = f"Error during queue load: {e}"
        print(f"[load_queue_action] Caught error: {error_message}")
        traceback.print_exc()
        gr.Warning(f"Failed to load queue: {error_message[:200]}")
        return update_queue_data(original_queue)
    finally:
        # The autosave file is one-shot: remove it once it has been consumed.
        if delete_autoqueue_file:
            if os.path.isfile(filename):
                os.remove(filename)
                # Fix: log the deleted path (message previously printed a lost
                # "(unknown)" placeholder).
                print(f"Clear Queue: Deleted autosave file '{filename}'.")
        # Clean up Gradio's temporary upload copy, but never a user's own file.
        if filepath and hasattr(filepath, 'name') and filepath.name and os.path.exists(filepath.name):
            if tempfile.gettempdir() in os.path.abspath(filepath.name):
                try:
                    os.remove(filepath.name)
                    print(f"[load_queue_action] Removed temporary upload file: {filepath.name}")
                except OSError as e:
                    print(f"[load_queue_action] Info: Could not remove temp file {filepath.name}: {e}")
            else:
                print(f"[load_queue_action] Info: Did not remove non-temporary file: {filepath.name}")
def clear_queue_action(state):
    """Abort the in-progress generation (if any) and clear all queued tasks.

    Also deletes the autosave file when pending tasks were actually removed.
    Returns refreshed (empty) queue-table data for the UI.
    """
    gen = get_gen_info(state)
    gen["resume"] = True
    queue = gen.get("queue", [])
    aborted_current = False
    cleared_pending = False
    with lock:
        if "in_progress" in gen and gen["in_progress"]:
            print("Clear Queue: Signalling abort for in-progress task.")
            gen["abort"] = True
            gen["extra_orders"] = 0
            if wan_model is not None:
                wan_model._interrupt = True
            aborted_current = True
        if queue:
            # Only counts as "cleared" if there is something beyond an empty
            # placeholder in the active-task slot.
            if len(queue) > 1 or (len(queue) == 1 and queue[0] is not None and queue[0].get('id') is not None):
                print(f"Clear Queue: Clearing {len(queue)} tasks from queue.")
                queue.clear()
                cleared_pending = True
            else:
                pass
    if aborted_current or cleared_pending:
        gen["prompts_max"] = 0
    if cleared_pending:
        # The autosave would re-load the queue we just cleared; drop it.
        try:
            if os.path.isfile(AUTOSAVE_PATH):
                os.remove(AUTOSAVE_PATH)
                print(f"Clear Queue: Deleted autosave file '{AUTOSAVE_PATH}'.")
        except OSError as e:
            print(f"Clear Queue: Error deleting autosave file '{AUTOSAVE_PATH}': {e}")
            gr.Warning(f"Could not delete the autosave file '{AUTOSAVE_PATH}'. You may need to remove it manually.")
    if aborted_current and cleared_pending:
        gr.Info("Queue cleared and current generation aborted.")
    elif aborted_current:
        gr.Info("Current generation aborted.")
    elif cleared_pending:
        gr.Info("Queue cleared.")
    else:
        gr.Info("Queue is already empty or only contains the active task (which wasn't aborted now).")
    return update_queue_data([])
def quit_application():
    """Persist the queue via autosave, then stop the server process."""
    import signal
    print("Save and Quit requested...")
    autosave_queue()
    # Deliver SIGINT to ourselves so the regular Ctrl+C shutdown path runs.
    os.kill(os.getpid(), signal.SIGINT)
def start_quit_process():
    """Arm the quit countdown: returns the initial tick count plus visibility
    updates swapping the two quit-related buttons."""
    countdown_start = 5
    return countdown_start, gr.update(visible=False), gr.update(visible=True)
def cancel_quit_process():
    """Disarm the quit countdown: -1 sentinel plus visibility updates
    restoring the two quit-related buttons."""
    return (-1, gr.update(visible=True), gr.update(visible=False))
def show_countdown_info_from_state(current_value: int):
    """Tick the quit countdown: announce the remaining count and decrement.

    Values <= 0 pass through unchanged (countdown inactive or finished).
    """
    if current_value <= 0:
        return current_value
    gr.Info(f"Quitting in {current_value}...")
    return current_value - 1
# Set once the shutdown/autosave path has started.
quitting_app = False
def autosave_queue():
    """Write the current global queue snapshot to AUTOSAVE_PATH (used on quit)."""
    global quitting_app, global_queue_ref
    quitting_app = True
    if not global_queue_ref:
        print("Autosave: Queue is empty, nothing to save.")
        return
    print(f"Autosaving queue ({len(global_queue_ref)} items) to {AUTOSAVE_PATH}...")
    try:
        saved = _save_queue_to_zip(global_queue_ref, AUTOSAVE_PATH)
    except Exception as e:
        print(f"Error during autosave: {e}")
        traceback.print_exc()
        return
    if saved:
        print(f"Queue autosaved successfully to {AUTOSAVE_PATH}")
    else:
        print("Autosave failed.")
def finalize_generation_with_state(current_state):
    """Wrapper around finalize_generation that also returns the state and
    collapses the queue accordion when at most one task remains.

    Falls back to a neutral set of UI updates when *current_state* is not a
    valid state dict.
    NOTE(review): the fallback returns 8 values while the main path returns
    13 — confirm this matches the Gradio outputs list wired to this handler.
    """
    if not isinstance(current_state, dict) or 'gen' not in current_state:
        return gr.update(), gr.update(interactive=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=""), gr.update(), current_state
    gallery_update, audio_files_paths_update, audio_file_selected_update, audio_gallery_refresh_trigger_update, gallery_tabs_update, current_gallery_tab_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update = finalize_generation(current_state)
    # Collapse the queue accordion once the queue is (almost) empty.
    accordion_update = gr.Accordion(open=False) if len(get_gen_info(current_state).get("queue", [])) <= 1 else gr.update()
    return gallery_update, audio_files_paths_update, audio_file_selected_update, audio_gallery_refresh_trigger_update, gallery_tabs_update, current_gallery_tab_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update, accordion_update, current_state
def generate_queue_html(queue):
if len(queue) <= 1:
return "
Attention mode " + (attn_mode if attn_mode!="auto" else "auto/" + get_auto_attention() )
if attention_mode not in attention_modes_installed:
header += " -NOT INSTALLED-"
elif attention_mode not in attention_modes_supported:
header += " -NOT SUPPORTED-"
elif overridden_attention is not None and attention_mode != overridden_attention:
header += " -MODEL SPECIFIC-"
header += ""
if compile:
header += ", Pytorch compilation ON"
if "fp16" in model_filename:
header += ", Data Type FP16"
else:
header += ", Data Type BF16"
quant_label = quant_router.detect_quantization_label_from_filename(get_local_model_filename(full_filename))
if quant_label:
header += f", Quantization {quant_label}"
header += "
"
return header
def release_RAM():
    """Free models held in RAM, unless a generation is currently running."""
    if gen_in_progress:
        gr.Info("Unable to release RAM when a Generation is in Progress")
        return
    release_model()
    gr.Info("Models stored in RAM have been released")
def get_gen_info(state):
    """Return the per-session generation cache stored under state["gen"].

    Creates and stores an empty dict on first access (or when the entry is
    None), so callers can freely mutate the returned dict.
    """
    cache = state.get("gen", None)
    if cache is None:  # fix: identity comparison instead of `== None`
        cache = dict()
        state["gen"] = cache
    return cache
def build_callback(state, pipe, send_cmd, status, num_inference_steps, preview_meta=None):
    """Create the per-step progress callback handed to the denoising pipe.

    The returned callback:
      * honors pause requests posted in gen["process_status"] (blocks until
        the status returns to "process:main"),
      * throttles pure UI refreshes via gen["refresh"] / state["refresh"],
      * reports phase + elapsed time through send_cmd("progress", ...),
      * ships latent previews (tensors copied to CPU) via send_cmd("preview", ...).
    """
    gen = get_gen_info(state)
    gen["num_inference_steps"] = num_inference_steps
    start_time = time.time()
    def callback(step_idx = -1, latent = None, force_refresh = True, read_state = False, override_num_inference_steps = -1, pass_no = -1, preview_meta=preview_meta, denoising_extra =""):
        in_pause = False
        with gen_lock:
            process_status = gen.get("process_status", None)
            pause_msg = None
            # Fix: guard against an unset process_status (the unconditional
            # .startswith call raised AttributeError when the key was missing).
            if process_status is not None and process_status.startswith("request:"):
                gen["process_status"] = "process:" + process_status[len("request:"):]
                offloadobj.unload_all()
                pause_msg = gen.get("pause_msg", "Unknown Pause")
                in_pause = True
        if in_pause:
            send_cmd("progress", [0, pause_msg])
            # Poll (outside the lock) until the pause is lifted.
            while True:
                time.sleep(0.1)
                with gen_lock:
                    process_status = gen.get("process_status", None)
                    if process_status == "process:main": break
            force_refresh = True
        refresh_id = gen.get("refresh", -1)
        if not (force_refresh or step_idx >= 0):
            # Pure UI-refresh tick: skip when there is nothing new to show.
            if refresh_id < 0:
                return
            UI_refresh = state.get("refresh", 0)
            if UI_refresh >= refresh_id:
                return
        if override_num_inference_steps > 0:
            gen["num_inference_steps"] = override_num_inference_steps
        num_inference_steps = gen.get("num_inference_steps", 0)
        status = gen["progress_status"]
        state["refresh"] = refresh_id
        if read_state:
            # Re-emit the last stored phase instead of advancing a step.
            phase, step_idx = gen["progress_phase"]
        else:
            step_idx += 1
            if gen.get("abort", False):
                phase = "Aborting"
            elif step_idx == num_inference_steps:
                phase = "VAE Decoding"
            else:
                if pass_no <= 0:
                    phase = "Denoising"
                elif pass_no == 1:
                    phase = "Denoising First Pass"
                elif pass_no == 2:
                    phase = "Denoising Second Pass"
                elif pass_no == 3:
                    phase = "Denoising Third Pass"
                else:
                    phase = f"Denoising {pass_no}th Pass"
            if len(denoising_extra) > 0: phase += " | " + denoising_extra
            gen["progress_phase"] = (phase, step_idx)
        elapsed_time = time.time() - start_time
        # Fix: removed a dead merge_status_context(status, phase) assignment
        # that was immediately overwritten by the line below.
        status_msg = merge_status_context(status, f"{phase} | {format_time(elapsed_time)}")
        if step_idx >= 0:
            progress_args = [(step_idx , num_inference_steps) , status_msg , num_inference_steps]
        else:
            progress_args = [0, status_msg]
        send_cmd("progress", progress_args)
        if latent is not None:
            # Let the pipe post-process the raw latent into a preview payload
            # when it knows how; otherwise forward the latent itself.
            payload = pipe.prepare_preview_payload(latent, preview_meta) if hasattr(pipe, "prepare_preview_payload") else latent
            if isinstance(payload, dict):
                data = payload.copy()
                lat = data.get("latents")
                if torch.is_tensor(lat):
                    # Copy preview tensors to CPU (non-blocking) before handing them off.
                    data["latents"] = lat.to("cpu", non_blocking=True)
                payload = data
            elif torch.is_tensor(payload):
                payload = payload.to("cpu", non_blocking=True)
            if payload is not None:
                send_cmd("preview", payload)
    return callback
def pause_generation(state):
    """Generator (Gradio event handler) that pauses the generation loop.

    Yields button updates in three stages: disable the pause button, swap it
    for the resume button once the GPU has been handed over, and restore the
    buttons after resume_generation() flips gen["resume"].
    """
    gen = get_gen_info(state)
    process_id = "pause"
    GPU_process_running = any_GPU_process_running(state, process_id, ignore_main= True )
    if GPU_process_running:
        gr.Info("Unable to pause, a PlugIn is using the GPU")
        yield gr.update(), gr.update()
        return
    gen["resume"] = False
    yield gr.Button(interactive= False), gr.update()
    pause_msg = "Generation on Pause, click Resume to Restart Generation"
    # Waits until the pause request has been processed and the GPU released.
    acquire_GPU_ressources(state, process_id , "Pause", gr= gr, custom_pause_msg= pause_msg, custom_wait_msg= "Please wait while the Pause Request is being Processed...")
    gr.Info(pause_msg)
    yield gr.Button(visible= False, interactive= True), gr.Button(visible= True)
    # Poll until resume_generation() sets the resume flag.
    while not gen.get("resume", False):
        time.sleep(0.5)
    release_GPU_ressources(state, process_id )
    gen["resume"] = False
    yield gr.Button(visible= True, interactive= True), gr.Button(visible= False)
def resume_generation(state):
    """Signal a paused generation to continue (see pause_generation)."""
    get_gen_info(state)["resume"] = True
def abort_generation(state):
    """Request cancellation of the current generation, if one is running.

    Returns an update for the abort button: disabled while the abort is
    being processed, re-enabled when there was nothing to abort.
    """
    gen = get_gen_info(state)
    gen["resume"] = True
    if "in_progress" not in gen:
        return gr.Button(interactive= True)
    if wan_model is not None:
        wan_model._interrupt = True
    gen["abort"] = True
    msg = "Processing Request to abort Current Generation"
    gen["status"] = msg
    gr.Info(msg)
    return gr.Button(interactive= False)
def pack_audio_gallery_state(audio_file_list, selected_index, refresh = True):
    """Serialize the audio gallery state as [json_file_list, index, timestamp]."""
    packed_list = json.dumps(audio_file_list)
    # The timestamp makes every packed value unique, forcing a UI refresh.
    return [packed_list, selected_index, time.time()]
def unpack_audio_list(packed_audio_file_list):
    """Inverse of pack_audio_gallery_state's first element: JSON -> list."""
    return json.loads(packed_audio_file_list)
def refresh_gallery(state): #, msg
    """Refresh both galleries and the in-progress status panel.

    Returns Gradio updates for: the media gallery and its selection, the
    packed audio-gallery state (3 values), the status HTML, visibility of
    the generate/abort buttons and rows, refreshed queue-table data, the
    abort-button interactivity, and the "one more window" button.
    NOTE(review): several template strings below look truncated in this copy
    of the file (markup apparently stripped) — verify against the original.
    """
    gen = get_gen_info(state)
    # gen["last_msg"] = msg
    file_list = gen.get("file_list", None)
    choice = gen.get("selected",0)
    audio_file_list = gen.get("audio_file_list", None)
    audio_choice = gen.get("audio_selected",0)
    header_text = gen.get("header_text", "")
    in_progress = "in_progress" in gen
    # "last_selected" means the user was on the newest item: keep following it.
    if gen.get("last_selected", True) and file_list is not None:
        choice = max(len(file_list) - 1,0)
    if gen.get("audio_last_selected", True) and audio_file_list is not None:
        audio_choice = max(len(audio_file_list) - 1,0)
    last_was_audio = gen.get("last_was_audio", False)
    queue = gen.get("queue", [])
    abort_interactive = not gen.get("abort", False)
    if not in_progress or len(queue) == 0:
        # Idle path: hide the status panel, show the Generate button.
        return gr.Gallery(value = file_list) if last_was_audio else gr.Gallery(selected_index=choice, value = file_list), gr.update() if last_was_audio else choice, *pack_audio_gallery_state(audio_file_list, audio_choice), gr.HTML("", visible= False), gr.Button(visible=True), gr.Button(visible=False), gr.Row(visible=False), gr.Row(visible=False), update_queue_data(queue), gr.Button(interactive= abort_interactive), gr.Button(visible= False)
    else:
        # Busy path: build the status header from the active (first) task.
        task = queue[0]
        prompt = task["prompt"]
        params = task["params"]
        model_type = params["model_type"]
        multi_prompts_gen_type = params["multi_prompts_gen_type"]
        base_model_type = get_base_model_type(model_type)
        model_def = get_model_def(model_type)
        onemorewindow_visible = test_any_sliding_window(base_model_type) and params.get("image_mode",0) == 0 and (not params.get("mode","").startswith("edit_")) and not model_def.get("preprocess_all", False)
        enhanced = False
        # Prompts rewritten by the prompt enhancer carry a marker prefix.
        if prompt.startswith("!enhanced!\n"):
            enhanced = True
            prompt = prompt[len("!enhanced!\n"):]
        prompt = html.escape(prompt)
        if multi_prompts_gen_type == 2:
            prompt = prompt.replace("\n", " ")
        elif "\n" in prompt :
            # One prompt per sliding window: mark the current window's line.
            prompts = prompt.split("\n")
            window_no= gen.get("window_no",1)
            if window_no > len(prompts):
                window_no = len(prompts)
            window_no -= 1
            prompts[window_no]="" + prompts[window_no] + ""
            prompt = " ".join(prompts)
        if enhanced:
            prompt = "Enhanced: " + prompt
        if len(header_text) > 0:
            prompt = "" + header_text + "
" + prompt
        thumbnail_size = "100px"
        thumbnails = ""
        start_img_data = task.get('start_image_data_base64')
        start_img_labels = task.get('start_image_labels')
        if start_img_data and start_img_labels:
            for i, (img_uri, img_label) in enumerate(zip(start_img_data, start_img_labels)):
                thumbnails += f'
{img_label}
'
        end_img_data = task.get('end_image_data_base64')
        end_img_labels = task.get('end_image_labels')
        if end_img_data and end_img_labels:
            for i, (img_uri, img_label) in enumerate(zip(end_img_data, end_img_labels)):
                thumbnails += f'
{img_label}
'
        # Get current theme from server config
        current_theme = server_config.get("UI_theme", "default")
        # Use minimal, adaptive styling that blends with any background
        # This creates a subtle container that doesn't interfere with the page's theme
        table_style = """
border: 1px solid rgba(128, 128, 128, 0.3);
background-color: transparent;
color: inherit;
padding: 8px;
border-radius: 6px;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
"""
        if params.get("mode", None) in ['edit'] : onemorewindow_visible = False
        gen_buttons_visible = True
        html_content = f"
" + prompt + "
" + thumbnails + "
"
        html_output = gr.HTML(html_content, visible= True)
        if last_was_audio:
            audio_choice = max(0, audio_choice)
        else:
            choice = max(0, choice)
        return gr.Gallery(value = file_list) if last_was_audio else gr.Gallery(selected_index=choice, value = file_list), gr.update() if last_was_audio else choice, *pack_audio_gallery_state(audio_file_list, audio_choice), html_output, gr.Button(visible=False), gr.Button(visible=True), gr.Row(visible=True), gr.Row(visible= gen_buttons_visible), update_queue_data(queue), gr.Button(interactive= abort_interactive), gr.Button(visible= onemorewindow_visible)
def finalize_generation(state):
    """Reset per-session generation flags and restore the idle UI layout."""
    global gen_in_progress
    gen = get_gen_info(state)
    if "in_progress" in gen:
        del gen["in_progress"]
    # Jump to the newest item when the user was following the tail.
    choice = gen.get("selected", 0)
    if gen.get("last_selected", True):
        choice = len(gen.get("file_list", [])) - 1
    audio_file_list = gen.get("audio_file_list", [])
    audio_choice = gen.get("audio_selected", 0)
    if gen.get("audio_last_selected", True):
        audio_choice = len(audio_file_list) - 1
    gen["extra_orders"] = 0
    last_was_audio = gen.get("last_was_audio", False)
    gallery_tabs = gr.Tabs(selected= "audio" if last_was_audio else "video_images")
    time.sleep(0.2)
    gen_in_progress = False
    gallery_update = gr.update() if last_was_audio else gr.Gallery(selected_index=choice)
    tab_index = 1 if last_was_audio else 0
    return (gallery_update, *pack_audio_gallery_state(audio_file_list, audio_choice),
            gallery_tabs, tab_index, gr.Button(interactive= True), gr.Button(visible= True),
            gr.Button(visible= False), gr.Column(visible= False), gr.HTML(visible= False, value=""))
def get_default_video_info():
    """Placeholder text shown when no gallery item is selected."""
    # Fix: grammar in the user-facing message ("an Video" -> "a Video").
    return "Please Select a Video / Image"
def get_file_list(state, input_file_list, audio_files = False):
    """Return (file_list, file_settings_list) for the requested gallery.

    Lazily initialised from *input_file_list* and cached in the gen dict
    under "file_list"/"file_settings_list" ("audio_"-prefixed variants for
    the audio gallery).
    """
    gen = get_gen_info(state)
    with lock:
        prefix = "audio_" if audio_files else ""
        file_list_name = prefix + "file_list"
        file_settings_name = prefix + "file_settings_list"
        if file_list_name in gen:
            file_list = gen[file_list_name]
            file_settings_list = gen[file_settings_name]
        else:
            file_list = []
            file_settings_list = []
            if input_file_list is not None:
                for file_path in input_file_list:
                    # Gallery items may arrive as (path, caption) tuples.
                    if isinstance(file_path, tuple):
                        file_path = file_path[0]
                    file_settings, _, _ = get_settings_from_file(state, file_path, False, False, False)
                    file_list.append(file_path)
                    file_settings_list.append(file_settings)
            gen[file_list_name] = file_list
            gen[file_settings_name] = file_settings_list
    return file_list, file_settings_list
def set_file_choice(gen, file_list, choice, audio_files = False):
    """Record the current gallery selection in *gen*.

    Stores the (clamped) selected index and whether the selection sits on the
    last item, under audio- or video-specific keys depending on *audio_files*.
    """
    if file_list:
        choice = max(choice, 0)
    if audio_files:
        last_key, selected_key = "audio_last_selected", "audio_selected"
    else:
        last_key, selected_key = "last_selected", "selected"
    gen[last_key] = choice + 1 >= len(file_list)
    gen[selected_key] = choice
def select_audio(state, audio_files_paths, audio_file_selected):
    """Sync the audio-gallery selection into the generation state.

    Uses the explicitly selected index when one is provided, otherwise falls
    back to the previously stored audio selection (clamped to the list size).
    """
    gen = get_gen_info(state)
    # Fix: pass audio_files=True so the audio gallery is cached under
    # "audio_file_list" instead of clobbering the video "file_list"
    # (mirrors the audio branch of select_video).
    audio_file_list, audio_file_settings_list = get_file_list(state, unpack_audio_list(audio_files_paths), audio_files=True)
    if audio_file_selected >= 0:
        choice = audio_file_selected
    else:
        choice = min(len(audio_file_list)-1, gen.get("audio_selected",0)) if len(audio_file_list) > 0 else -1
    set_file_choice(gen, audio_file_list, choice, audio_files=True)
# Single-letter codes for the preprocessing filters that can be applied to the
# control video (the keys of process_map_video_guide below).
video_guide_processes = "PEDSLCMU"
# Superset of guide-related letters, including non-filter markers
# (e.g. "V" = control video present) and the face/bbox processes.
# NOTE(review): exact semantics of "G" here inferred from usage elsewhere — confirm.
all_guide_processes = video_guide_processes + "VGBH"
# Letter -> process applied OUTSIDE the masked area.
process_map_outside_mask = { "Y" : "depth", "W": "scribble", "X": "inpaint", "Z": "flow"}
# Letter -> process applied to the control video (inside the mask when one is used).
process_map_video_guide = { "P": "pose", "D" : "depth", "S": "scribble", "E": "canny", "L": "flow", "C": "gray", "M": "inpaint", "U": "identity"}
# Extra processes, merged with the standard map below.
all_process_map_video_guide = { "B": "face", "H" : "bbox"}
all_process_map_video_guide.update(process_map_video_guide)
# Human-readable names used when displaying a generation's settings.
processes_names = { "pose": "Open Pose", "depth": "Depth Mask", "scribble" : "Shapes", "flow" : "Flow Map", "gray" : "Gray Levels", "inpaint" : "Inpaint Mask", "identity": "Identity Mask", "raw" : "Raw Format", "canny" : "Canny Edges", "face": "Face Movements", "bbox": "BBox"}
def update_video_prompt_type(state, any_video_guide = False, any_video_mask = False, any_background_image_ref = False, process_type = None, default_update = ""):
    """Update the current model's "video_prompt_type" letter string.

    Translates the requested *process_type* names back into their single-letter
    codes, adds markers for the presence of a guide video ("V"), a mask ("A")
    and background image refs ("KI"), then merges them into the stored setting.
    """
    letters = default_update
    settings = get_current_model_settings(state)
    video_prompt_type = settings["video_prompt_type"]
    if process_type is not None:
        # Replace any existing preprocessing letters with the requested ones.
        video_prompt_type = del_in_sequence(video_prompt_type, video_guide_processes)
        for one_process_type in process_type:
            for k,v in process_map_video_guide.items():
                if v== one_process_type:
                    letters += k
                    break
    model_type = get_state_model_type(state)
    model_def = get_model_def(model_type)
    guide_preprocessing = model_def.get("guide_preprocessing", None)
    mask_preprocessing = model_def.get("mask_preprocessing", None)
    guide_custom_choices = model_def.get("guide_custom_choices", None)
    if any_video_guide: letters += "V"
    if any_video_mask: letters += "A"
    if any_background_image_ref:
        # Background refs replace frame injection ("F").
        video_prompt_type = del_in_sequence(video_prompt_type, "F")
        letters += "KI"
    # Keep only letters the current model actually supports.
    validated_letters = ""
    for letter in letters:
        if not guide_preprocessing is None:
            if any(letter in choice for choice in guide_preprocessing["selection"] ):
                validated_letters += letter
                continue
        if not mask_preprocessing is None:
            if any(letter in choice for choice in mask_preprocessing["selection"] ):
                validated_letters += letter
                continue
        if not guide_custom_choices is None:
            if any(letter in choice for label, choice in guide_custom_choices["choices"] ):
                validated_letters += letter
                continue
    # NOTE(review): validated_letters is computed but never used — the raw
    # `letters` string is merged below. Looks like a bug (validation has no
    # effect); confirm intent before changing.
    video_prompt_type = add_to_sequence(video_prompt_type, letters)
    settings["video_prompt_type"] = video_prompt_type
def select_video(state, current_gallery_tab, input_file_list, file_selected, audio_files_paths, audio_file_selected, source, event_data: gr.EventData):
    """Handle a selection in the video/image or audio gallery and build the
    HTML panel describing the selected file's generation settings.

    Returns: new gallery choice (video tab only), the info HTML, and visibility
    updates for the video- / image- / audio-specific button rows.
    """
    gen = get_gen_info(state)
    if source=="video":
        if current_gallery_tab != 0:
            # Event fired for the tab that is not displayed: ignore.
            return [gr.update()] * 7
        file_list, file_settings_list = get_file_list(state, input_file_list)
        # Prefer the index carried by the Gradio event, then the explicit
        # selection, then the last stored one.
        data= event_data._data
        if data!=None and isinstance(data, dict):
            choice = data.get("index",0)
        elif file_selected >= 0:
            choice = file_selected
        else:
            choice = gen.get("selected",0)
        choice = min(len(file_list)-1, choice)
        set_file_choice(gen, file_list, choice)
        files, settings_list = file_list, file_settings_list
    else:
        if current_gallery_tab != 1:
            return [gr.update()] * 7
        audio_file_list, audio_file_settings_list = get_file_list(state, unpack_audio_list(audio_files_paths), audio_files= True)
        if audio_file_selected >= 0:
            choice = audio_file_selected
        else:
            choice = gen.get("audio_selected",0)
        choice = min(len(audio_file_list)-1, choice)
        set_file_choice(gen, audio_file_list, choice, audio_files=True )
        files, settings_list = audio_file_list, audio_file_settings_list
    is_audio = False
    is_image = False
    is_video = False
    if len(files) > 0:
        # NOTE(review): dead guard — probably meant to resync settings_list
        # with files when their lengths diverge; confirm intent.
        if len(settings_list) <= choice:
            pass
        configs = settings_list[choice]
        file_name = files[choice]
        values = [ os.path.basename(file_name)]
        labels = [ "File Name"]
        misc_values= []
        misc_labels = []
        pp_values= []
        pp_labels = []
        extension = os.path.splitext(file_name)[-1]
        # Probe the media type and its basic characteristics.
        if has_audio_file_extension(file_name):
            is_audio = True
            width, height = 0, 0
            frames_count = fps = 1
            nb_audio_tracks = 0
        elif not has_video_file_extension(file_name):
            img = Image.open(file_name)
            width, height = img.size
            is_image = True
            frames_count = fps = 1
            nb_audio_tracks = 0
        else:
            fps, width, height, frames_count = get_video_info(file_name)
            is_image = False
            nb_audio_tracks = extract_audio_tracks(file_name,query_only = True)
        is_video = not (is_image or is_audio)
        if configs != None:
            # Model name is stored as "family - name": keep only the name part.
            video_model_name = configs.get("type", "Unknown model")
            if "-" in video_model_name: video_model_name = video_model_name[video_model_name.find("-")+2:]
            misc_values += [video_model_name]
            misc_labels += ["Model"]
            video_temporal_upsampling = configs.get("temporal_upsampling", "")
            video_spatial_upsampling = configs.get("spatial_upsampling", "")
            video_film_grain_intensity = configs.get("film_grain_intensity", 0)
            video_film_grain_saturation = configs.get("film_grain_saturation", 0.5)
            video_MMAudio_setting = configs.get("MMAudio_setting", 0)
            video_MMAudio_prompt = configs.get("MMAudio_prompt", "")
            video_MMAudio_neg_prompt = configs.get("MMAudio_neg_prompt", "")
            video_seed = configs.get("seed", -1)
            video_MMAudio_seed = configs.get("MMAudio_seed", video_seed)
            # Post-processing summary entries (shared by both display paths).
            if len(video_spatial_upsampling) > 0:
                video_temporal_upsampling += " " + video_spatial_upsampling
            if len(video_temporal_upsampling) > 0:
                pp_values += [ video_temporal_upsampling ]
                pp_labels += [ "Upsampling" ]
            if video_film_grain_intensity > 0:
                pp_values += [ f"Intensity={video_film_grain_intensity}, Saturation={video_film_grain_saturation}" ]
                pp_labels += [ "Film Grain" ]
            if video_MMAudio_setting != 0:
                pp_values += [ f'Prompt="{video_MMAudio_prompt}", Neg Prompt="{video_MMAudio_neg_prompt}", Seed={video_MMAudio_seed}' ]
                pp_labels += [ "MMAudio" ]
        if configs == None or not "seed" in configs:
            # No (or partial) generation metadata: show basic file info only.
            values += misc_values
            labels += misc_labels
            video_creation_date = str(get_file_creation_date(file_name))
            if "." in video_creation_date: video_creation_date = video_creation_date[:video_creation_date.rfind(".")]
            if is_audio:
                pass
            elif is_image:
                values += [f"{width}x{height}"]
                labels += ["Resolution"]
            else:
                values += [f"{width}x{height}", f"{frames_count} frames (duration={frames_count/fps:.1f} s, fps={round(fps)})"]
                labels += ["Resolution", "Frames"]
            if nb_audio_tracks > 0:
                values +=[nb_audio_tracks]
                labels +=["Nb Audio Tracks"]
            values += pp_values
            labels += pp_labels
            values +=[video_creation_date]
            labels +=["Creation Date"]
        else:
            # Full metadata available: build the detailed settings summary.
            video_prompt = html.escape(configs.get("prompt", "")[:1024]).replace("\n", " ")
            enhanced_video_prompt = html.escape(configs.get("enhanced_prompt", "")[:1024]).replace("\n", " ")
            video_video_prompt_type = configs.get("video_prompt_type", "")
            video_image_prompt_type = configs.get("image_prompt_type", "")
            video_audio_prompt_type = configs.get("audio_prompt_type", "")
            def check(src, cond):
                # cond is either a letter string (all required) or a tuple
                # (required_letters, forbidden_letters).
                pos, neg = cond if isinstance(cond, tuple) else (cond, None)
                if not all_letters(src, pos): return False
                if neg is not None and any_letters(src, neg): return False
                return True
            image_outputs = configs.get("image_mode",0) > 0
            map_video_prompt = {"V" : "Control Image" if image_outputs else "Control Video", ("VA", "U") : "Mask Image" if image_outputs else "Mask Video", "I" : "Reference Images"}
            map_image_prompt = {"V" : "Source Video", "L" : "Last Video", "S" : "Start Image", "E" : "End Image"}
            map_audio_prompt = {"A" : "Audio Source", "B" : "Audio Source #2", "K": "Control Video Audio Track"}
            video_other_prompts = [ v for s,v in map_image_prompt.items() if all_letters(video_image_prompt_type,s)] \
                + [ v for s,v in map_video_prompt.items() if check(video_video_prompt_type,s)] \
                + [ v for s,v in map_audio_prompt.items() if all_letters(video_audio_prompt_type,s)]
            any_mask = "A" in video_video_prompt_type and not "U" in video_video_prompt_type
            video_model_type = configs.get("model_type", "t2v")
            model_family = get_model_family(video_model_type)
            model_def = get_model_def(video_model_type)
            multiple_submodels = model_def.get("multiple_submodels", False)
            video_other_prompts = ", ".join(video_other_prompts)
            if is_audio:
                video_resolution = None
                video_length_summary = None
                video_length_label = ""
                original_fps = 0
                video_num_inference_steps = None
            else:
                video_length = configs.get("video_length", 0)
                original_fps= int(video_length/frames_count*fps)
                video_length_summary = f"{video_length} frames"
                video_window_no = configs.get("window_no", 0)
                if video_window_no > 0: video_length_summary +=f", Window no {video_window_no }"
                if is_image:
                    video_length_summary = configs.get("batch_size", 1)
                    video_length_label = "Number of Images"
                else:
                    video_length_summary += " ("
                    video_length_label = "Video Length"
                    if video_length != frames_count: video_length_summary += f"real: {frames_count} frames, "
                    video_length_summary += f"{frames_count/fps:.1f}s, {round(fps)} fps)"
                video_resolution = configs.get("resolution", "") + f" (real: {width}x{height})"
                video_num_inference_steps = configs.get("num_inference_steps", 0)
            video_guidance_scale = configs.get("guidance_scale", None)
            video_guidance2_scale = configs.get("guidance2_scale", None)
            video_guidance3_scale = configs.get("guidance3_scale", None)
            video_audio_guidance_scale = configs.get("audio_guidance_scale", None)
            video_alt_guidance_scale = configs.get("alt_guidance_scale", None)
            video_switch_threshold = configs.get("switch_threshold", 0)
            video_switch_threshold2 = configs.get("switch_threshold2", 0)
            video_model_switch_phase = configs.get("model_switch_phase", 1)
            video_guidance_phases = configs.get("guidance_phases", 0)
            video_embedded_guidance_scale = configs.get("embedded_guidance_scale", None)
            video_guidance_label = "Guidance"
            visible_phases = model_def.get("visible_phases", video_guidance_phases)
            # Build a human-readable guidance description per phase configuration.
            if model_def.get("embedded_guidance", False):
                video_guidance_scale = video_embedded_guidance_scale
                video_guidance_label = "Embedded Guidance Scale"
            elif video_guidance_phases == 0 or visible_phases ==0:
                video_guidance_scale = None
            elif video_guidance_phases > 0:
                if video_guidance_phases == 1 and visible_phases >=1:
                    video_guidance_scale = f"{video_guidance_scale}"
                elif video_guidance_phases == 2 and visible_phases >=2:
                    if multiple_submodels:
                        video_guidance_scale = f"{video_guidance_scale} (High Noise), {video_guidance2_scale} (Low Noise) with Switch at Noise Level {video_switch_threshold}"
                    else:
                        # NOTE(review): the second literal lacks an f-prefix, so
                        # {video_switch_threshold} is shown verbatim — likely a bug.
                        video_guidance_scale = f"{video_guidance_scale}, {video_guidance2_scale}" + ("" if video_switch_threshold ==0 else " with Guidance Switch at Noise Level {video_switch_threshold}")
                elif visible_phases >=3:
                    video_guidance_scale = f"{video_guidance_scale}, {video_guidance2_scale} & {video_guidance3_scale} with Switch at Noise Levels {video_switch_threshold} & {video_switch_threshold2}"
                    if multiple_submodels:
                        video_guidance_scale += f" + Model Switch at {video_switch_threshold if video_model_switch_phase ==1 else video_switch_threshold2}"
            if model_def.get("flow_shift", False):
                video_flow_shift = configs.get("flow_shift", None)
            else:
                video_flow_shift = None
            video_video_guide_outpainting = configs.get("video_guide_outpainting", "")
            video_outpainting = ""
            # Outpainting margins are stored as "top bottom left right" percentages;
            # a leading "#" marks the setting as disabled.
            if len(video_video_guide_outpainting) > 0 and not video_video_guide_outpainting.startswith("#") \
                and (any_letters(video_video_prompt_type, "VFK") ) :
                video_video_guide_outpainting = video_video_guide_outpainting.split(" ")
                video_outpainting = f"Top={video_video_guide_outpainting[0]}%, Bottom={video_video_guide_outpainting[1]}%, Left={video_video_guide_outpainting[2]}%, Right={video_video_guide_outpainting[3]}%"
            video_creation_date = str(get_file_creation_date(file_name))
            if "." in video_creation_date: video_creation_date = video_creation_date[:video_creation_date.rfind(".")]
            video_generation_time = format_generation_time(float(configs.get("generation_time", "0")))
            video_activated_loras = configs.get("activated_loras", [])
            video_loras_multipliers = configs.get("loras_multipliers", "")
            video_loras_multipliers = preparse_loras_multipliers(video_loras_multipliers)
            # Pad multipliers so every lora has one (defaults to "1" below).
            video_loras_multipliers += [""] * len(video_activated_loras)
            video_activated_loras = [ f"{os.path.basename(lora)}{lora}" for lora in video_activated_loras]
            # NOTE(review): the literals below appear garbled (HTML table markup
            # lost in extraction); preserved exactly as found.
            video_activated_loras = [ f"
            {lora}
            x{multiplier if len(multiplier)>0 else '1'}
            " for lora, multiplier in zip(video_activated_loras, video_loras_multipliers) ]
            video_activated_loras_str = "
            " + "".join(video_activated_loras) + "
            " if len(video_activated_loras) > 0 else ""
            values += misc_values + [video_prompt]
            labels += misc_labels + ["Text Prompt"]
            if len(enhanced_video_prompt):
                values += [enhanced_video_prompt]
                labels += ["Enhanced Text Prompt"]
            if len(video_other_prompts) >0 :
                values += [video_other_prompts]
                labels += ["Other Prompts"]
            def gen_process_list(map):
                # Comma-separated display names of the processes whose letter
                # appears in video_video_prompt_type.
                video_preprocesses = ""
                for k,v in map.items():
                    if k in video_video_prompt_type:
                        process_name = processes_names[v]
                        video_preprocesses += process_name if len(video_preprocesses) == 0 else ", " + process_name
                return video_preprocesses
            # NOTE(review): `if "V"` is always truthy — almost certainly meant
            # `if "V" in video_video_prompt_type`; preserved as found.
            video_preprocesses_in = gen_process_list(all_process_map_video_guide) if "V" else ""
            video_preprocesses_out = gen_process_list(process_map_outside_mask) if "V" else ""
            if "N" in video_video_prompt_type:
                # "N" swaps the inside/outside mask processes.
                alt = video_preprocesses_in
                video_preprocesses_in = video_preprocesses_out
                video_preprocesses_out = alt
            if len(video_preprocesses_in) >0 :
                values += [video_preprocesses_in]
                labels += [ "Process Inside Mask" if any_mask else "Preprocessing"]
            if len(video_preprocesses_out) >0 :
                values += [video_preprocesses_out]
                labels += [ "Process Outside Mask"]
            video_frames_positions = configs.get("frames_positions", "")
            if "F" in video_video_prompt_type and len(video_frames_positions):
                values += [video_frames_positions]
                labels += [ "Injected Frames"]
            if len(video_outpainting) >0:
                values += [video_outpainting]
                labels += ["Outpainting"]
            if len(model_def.get("input_video_strength", "")) and any_letters(video_image_prompt_type, "SVL"):
                values += [configs.get("input_video_strength",1)]
                labels += ["Input Image Strength"]
            if "G" in video_video_prompt_type and "V" in video_video_prompt_type:
                values += [configs.get("denoising_strength",1)]
                labels += ["Denoising Strength"]
            if ("G" in video_video_prompt_type or model_def.get("mask_strength_always_enabled", False)) and "A" in video_video_prompt_type and "U" not in video_video_prompt_type:
                values += [configs.get("masking_strength",1)]
                labels += ["Masking Strength"]
            video_sample_solver = configs.get("sample_solver", "")
            if model_def.get("sample_solvers", None) is not None and len(video_sample_solver) > 0 :
                values += [video_sample_solver]
                labels += ["Sampler Solver"]
            values += [video_resolution, video_length_summary, video_seed, video_guidance_scale, video_audio_guidance_scale]
            labels += ["Resolution", video_length_label, "Seed", video_guidance_label, "Audio Guidance Scale"]
            alt_guidance_type = model_def.get("alt_guidance", None)
            if alt_guidance_type is not None and video_alt_guidance_scale is not None:
                values += [video_alt_guidance_scale]
                labels += [alt_guidance_type]
            values += [video_flow_shift, video_num_inference_steps]
            labels += ["Shift Scale", "Num Inference steps"]
            video_negative_prompt = configs.get("negative_prompt", "")
            if len(video_negative_prompt) > 0:
                values += [video_negative_prompt]
                labels += ["Negative Prompt"]
            video_NAG_scale = configs.get("NAG_scale", None)
            if video_NAG_scale is not None and video_NAG_scale > 1:
                values += [video_NAG_scale]
                labels += ["NAG Scale"]
            video_apg_switch = configs.get("apg_switch", None)
            if video_apg_switch is not None and video_apg_switch != 0:
                values += ["on"]
                labels += ["APG"]
            video_motion_amplitude = configs.get("motion_amplitude", 1.)
            if video_motion_amplitude != 1:
                values += [video_motion_amplitude]
                labels += ["Motion Amplitude"]
            control_net_weight_name = model_def.get("control_net_weight_name", "")
            control_net_weight = ""
            if len(control_net_weight_name):
                video_control_net_weight = configs.get("control_net_weight", 1)
                # Two control-net weights when more than one preprocessing filter is active.
                if len(filter_letters(video_video_prompt_type, video_guide_processes))> 1:
                    video_control_net_weight2 = configs.get("control_net_weight2", 1)
                    control_net_weight = f"{control_net_weight_name} #1={video_control_net_weight}, {control_net_weight_name} #2={video_control_net_weight2}"
                else:
                    control_net_weight = f"{control_net_weight_name}={video_control_net_weight}"
            control_net_weight_alt_name = model_def.get("control_net_weight_alt_name", "")
            if len(control_net_weight_alt_name) >0:
                if len(control_net_weight): control_net_weight += ", "
                control_net_weight += control_net_weight_alt_name + "=" + str(configs.get("control_net_weight_alt", 1))
            if len(control_net_weight) > 0:
                values += [control_net_weight]
                labels += ["Control Net Weights"]
            audio_scale_name = model_def.get("audio_scale_name", "")
            if len(audio_scale_name) > 0:
                values += [configs.get("audio_scale", 1)]
                labels += [audio_scale_name]
            video_skip_steps_cache_type = configs.get("skip_steps_cache_type", "")
            video_skip_steps_multiplier = configs.get("skip_steps_multiplier", 0)
            video_skip_steps_cache_start_step_perc = configs.get("skip_steps_start_step_perc", 0)
            if len(video_skip_steps_cache_type) > 0:
                video_skip_steps_cache = "TeaCache" if video_skip_steps_cache_type == "tea" else "MagCache"
                video_skip_steps_cache += f" x{video_skip_steps_multiplier }"
                if video_skip_steps_cache_start_step_perc >0: video_skip_steps_cache += f", Start from {video_skip_steps_cache_start_step_perc}%"
                values += [ video_skip_steps_cache ]
                labels += [ "Skip Steps" ]
            values += pp_values
            labels += pp_labels
            if len(video_activated_loras_str) > 0:
                values += [video_activated_loras_str]
                labels += ["Loras"]
            if nb_audio_tracks > 0:
                values +=[nb_audio_tracks]
                labels +=["Nb Audio Tracks"]
            values += [ video_creation_date, video_generation_time ]
            labels += [ "Creation Date", "Generation Time" ]
        # Drop entries whose value is None, then render the HTML table.
        labels = [label for value, label in zip(values, labels) if value is not None]
        values = [value for value in values if value is not None]
        # NOTE(review): the literals below appear garbled (CSS/HTML markup lost
        # in extraction); preserved exactly as found.
        table_style = """
        """
        rows = [f"
        {label}
        {value}
        " for label, value in zip(labels, values)]
        html_content = f"{table_style}
        " + "".join(rows) + "
        "
    else:
        html_content = get_default_video_info()
    visible= len(files) > 0
    return choice if source=="video" else gr.update(), html_content, gr.update(visible=visible and is_video) , gr.update(visible=visible and is_image), gr.update(visible=visible and is_audio), gr.update(visible=visible and is_video) , gr.update(visible=visible and is_video)
def convert_image(image):
    """Open *image* (file path or PIL image), convert it to RGB and apply the
    EXIF orientation.

    Args:
        image: a file path or a PIL.Image.Image.
    Returns:
        A PIL.Image.Image in RGB mode with EXIF transpose applied.
    """
    from PIL import ImageOps
    from typing import cast
    if isinstance(image, str):
        image = Image.open(image)
    image = image.convert('RGB')
    # Fix: `Image` is the PIL *module*; the type for cast() is `Image.Image`.
    # (cast is a runtime no-op, so behavior is unchanged.)
    return cast(Image.Image, ImageOps.exif_transpose(image))
def get_resampled_video(video_in, start_frame, max_frames, target_fps, bridge='torch'):
    """Return the frames of *video_in* resampled at *target_fps*.

    A still image (path with an image extension, or a PIL image) is returned
    as a single-frame uint8 tensor. Videos are decoded through decord; a
    negative *max_frames* means "all frames minus that many" at the target fps.
    """
    if isinstance(video_in, str) and has_image_file_extension(video_in):
        video_in = Image.open(video_in)
    if isinstance(video_in, Image.Image):
        still = np.array(video_in).astype(np.uint8)
        return torch.from_numpy(still).unsqueeze(0)
    from shared.utils.utils import resample
    import decord
    decord.bridge.set_bridge(bridge)
    reader = decord.VideoReader(video_in)
    src_fps = round(reader.get_avg_fps())
    if max_frames < 0:
        # Negative count: take the whole clip (at target fps) shortened by |max_frames|.
        max_frames = int(max(len(reader) / src_fps * target_fps + max_frames, 0))
    frame_nos = resample(src_fps, len(reader), max_target_frames_count=max_frames, target_fps=target_fps, start_target_frame=start_frame)
    return reader.get_batch(frame_nos)
# def get_resampled_video(video_in, start_frame, max_frames, target_fps):
# from torchvision.io import VideoReader
# import torch
# from shared.utils.utils import resample
# vr = VideoReader(video_in, "video")
# meta = vr.get_metadata()["video"]
# fps = round(float(meta["fps"][0]))
# duration_s = float(meta["duration"][0])
# num_src_frames = int(round(duration_s * fps)) # robust length estimate
# if max_frames < 0:
# max_frames = max(int(num_src_frames / fps * target_fps + max_frames), 0)
# frame_nos = resample(
# fps, num_src_frames,
# max_target_frames_count=max_frames,
# target_fps=target_fps,
# start_target_frame=start_frame
# )
# if len(frame_nos) == 0:
# return torch.empty((0,)) # nothing to return
# target_ts = [i / fps for i in frame_nos]
# # Read forward once, grabbing frames when we pass each target timestamp
# frames = []
# vr.seek(target_ts[0])
# idx = 0
# tol = 0.5 / fps # half-frame tolerance
# for frame in vr:
# t = float(frame["pts"]) # seconds
# if idx < len(target_ts) and t + tol >= target_ts[idx]:
# frames.append(frame["data"].permute(1,2,0)) # Tensor [H, W, C]
# idx += 1
# if idx >= len(target_ts):
# break
# return frames
def get_preprocessor(process_type, inpaint_color):
    """Return a callable mapping a list of frames to their preprocessed form.

    Imports are done lazily inside each branch so only the selected annotator's
    dependencies are loaded. Note that every returned lambda constructs a fresh
    annotator on each call (visible from the code; whether that is intentional
    caching-wise is for the annotator classes to decide).

    Raises:
        Exception: if *process_type* is not one of the supported types.
    """
    if process_type=="pose":
        from preprocessing.dwpose.pose import PoseBodyFaceVideoAnnotator
        cfg_dict = {
            "DETECTION_MODEL": fl.locate_file("pose/yolox_l.onnx"),
            "POSE_MODEL": fl.locate_file("pose/dw-ll_ucoco_384.onnx"),
            "RESIZE_SIZE": 1024
        }
        anno_ins = lambda img: PoseBodyFaceVideoAnnotator(cfg_dict).forward(img)
    elif process_type=="depth":
        from preprocessing.depth_anything_v2.depth import DepthV2VideoAnnotator
        # Depth model variant is a server-wide setting (vitl by default).
        if server_config.get("depth_anything_v2_variant", "vitl") == "vitl":
            cfg_dict = {
                "PRETRAINED_MODEL": fl.locate_file("depth/depth_anything_v2_vitl.pth"),
                'MODEL_VARIANT': 'vitl'
            }
        else:
            cfg_dict = {
                "PRETRAINED_MODEL": fl.locate_file("depth/depth_anything_v2_vitb.pth"),
                'MODEL_VARIANT': 'vitb',
            }
        anno_ins = lambda img: DepthV2VideoAnnotator(cfg_dict).forward(img)
    elif process_type=="gray":
        from preprocessing.gray import GrayVideoAnnotator
        cfg_dict = {}
        anno_ins = lambda img: GrayVideoAnnotator(cfg_dict).forward(img)
    elif process_type=="canny":
        from preprocessing.canny import CannyVideoAnnotator
        cfg_dict = {
            "PRETRAINED_MODEL": fl.locate_file("scribble/netG_A_latest.pth")
        }
        anno_ins = lambda img: CannyVideoAnnotator(cfg_dict).forward(img)
    elif process_type=="scribble":
        from preprocessing.scribble import ScribbleVideoAnnotator
        cfg_dict = {
            "PRETRAINED_MODEL": fl.locate_file("scribble/netG_A_latest.pth")
        }
        anno_ins = lambda img: ScribbleVideoAnnotator(cfg_dict).forward(img)
    elif process_type=="flow":
        from preprocessing.flow import FlowVisAnnotator
        cfg_dict = {
            "PRETRAINED_MODEL": fl.locate_file("flow/raft-things.pth")
        }
        anno_ins = lambda img: FlowVisAnnotator(cfg_dict).forward(img)
    elif process_type=="inpaint":
        # Constant fill: one inpaint_color entry per input frame.
        anno_ins = lambda img : len(img) * [inpaint_color]
    elif process_type == None or process_type in ["raw", "identity"]:
        # Pass-through: frames are returned unchanged.
        anno_ins = lambda img : img
    else:
        raise Exception(f"process type '{process_type}' non supported")
    return anno_ins
def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_frames, start_frame, target_fps, size = 512):
    """Extract face crops from a video (optionally restricted by a mask video).

    Frames are resampled to *target_fps*; when a mask is given, pixels outside
    the mask are zeroed before face detection. Returns a float tensor of shape
    (c, t, h, w) scaled to [-1, 1], or None when nothing could be extracted.

    A negative *start_frame* requests that many padding frames, which are
    prepended by replicating a frame of the extracted sequence.
    """
    if not input_video_path or max_frames <= 0:
        # Fix: all other early exits return a single None — was `return None, None`.
        return None
    pad_frames = 0
    if start_frame < 0:
        pad_frames= -start_frame
        max_frames += start_frame
        start_frame = 0
    any_mask = input_mask_path != None
    video = get_resampled_video(input_video_path, start_frame, max_frames, target_fps)
    if len(video) == 0: return None
    frame_height, frame_width, _ = video[0].shape
    num_frames = len(video)
    if any_mask:
        mask_video = get_resampled_video(input_mask_path, start_frame, max_frames, target_fps)
        num_frames = min(num_frames, len(mask_video))
    if num_frames == 0: return None
    video = video[:num_frames]
    if any_mask:
        mask_video = mask_video[:num_frames]
    from preprocessing.face_preprocessor import FaceProcessor
    face_processor = FaceProcessor()
    face_list = []
    for frame_idx in range(num_frames):
        frame = video[frame_idx].cpu().numpy()
        if any_mask:
            mask = Image.fromarray(mask_video[frame_idx].cpu().numpy())
            if (frame_width, frame_height) != mask.size:
                mask = mask.resize((frame_width, frame_height), resample=Image.Resampling.LANCZOS)
            mask = np.array(mask)
            # Zero out everything outside the (binarized) mask.
            alpha_mask = np.zeros((frame_height, frame_width, 3), dtype=np.uint8)
            alpha_mask[mask > 127] = 1
            frame = frame * alpha_mask
        frame = Image.fromarray(frame)
        face = face_processor.process(frame, resize_to=size)
        face_list.append(face)
    face_processor = None
    gc.collect()
    torch.cuda.empty_cache()
    face_tensor= torch.tensor(np.stack(face_list, dtype= np.float32) / 127.5 - 1).permute(-1, 0, 1, 2 ) # t h w c -> c t h w
    if pad_frames > 0:
        # Fix: concatenate along the time axis (dim=1 after the permute above);
        # dim=2 raised a size mismatch in torch.cat.
        # NOTE(review): padding replicates the LAST frame ([:, -1:]) although it
        # is prepended — possibly should be [:, :1]; confirm intent.
        face_tensor= torch.cat([face_tensor[:, -1:].expand(-1, pad_frames, -1, -1), face_tensor ], dim=1)
    if args.save_masks:
        from preprocessing.dwpose.pose import save_one_video
        saved_faces_frames = [np.array(face) for face in face_list ]
        save_one_video(f"faces.mp4", saved_faces_frames, fps=target_fps, quality=8, macro_block_size=None)
    return face_tensor
def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1):
    """Build the control-video tensor (and optional mask tensor) for generation.

    Resamples the input video (and mask video) to *target_fps*, resizes/crops
    to the requested dimensions, runs the selected preprocessing inside and
    outside the mask, applies optional mask expansion/negation/bbox conversion
    and outpainting margins.

    Returns:
        (masked_frames, masks): float tensors of shape (c, t, h, w) —
        masked_frames in [-1, 1], masks in [0, 1] (masks is None when no mask
        video is used). Returns (None, None) when no frame could be extracted.
    """
    def mask_to_xyxy_box(mask):
        # Tight bounding box (x0, y0, x1, y1) around the 255-valued pixels.
        rows, cols = np.where(mask == 255)
        xmin = min(cols)
        xmax = max(cols) + 1
        ymin = min(rows)
        ymax = max(rows) + 1
        xmin = max(xmin, 0)
        ymin = max(ymin, 0)
        xmax = min(xmax, mask.shape[1])
        ymax = min(ymax, mask.shape[0])
        box = [xmin, ymin, xmax, ymax]
        box = [int(x) for x in box]
        return box
    inpaint_color = int(inpaint_color)
    pad_frames = 0
    if start_frame < 0:
        # Negative start: remember how many frames to pad at the front.
        pad_frames= -start_frame
        max_frames += start_frame
        start_frame = 0
    if not input_video_path or max_frames <= 0:
        return None, None
    any_mask = input_mask_path != None
    pose_special = "pose" in process_type
    any_identity_mask = False
    if process_type == "identity":
        # Identity: keep the frame itself, with an all-zero synthetic mask.
        any_identity_mask = True
        negate_mask = False
        process_outside_mask = None
    preproc = get_preprocessor(process_type, inpaint_color)
    preproc2 = None
    if process_type2 != None:
        preproc2 = get_preprocessor(process_type2, inpaint_color) if process_type != process_type2 else preproc
    # Reuse an existing preprocessor for the outside-mask area when possible.
    if process_outside_mask == process_type :
        preproc_outside = preproc
    elif preproc2 != None and process_outside_mask == process_type2 :
        preproc_outside = preproc2
    else:
        preproc_outside = get_preprocessor(process_outside_mask, inpaint_color)
    video = get_resampled_video(input_video_path, start_frame, max_frames, target_fps)
    if any_mask:
        mask_video = get_resampled_video(input_mask_path, start_frame, max_frames, target_fps)
    if len(video) == 0 or any_mask and len(mask_video) == 0:
        return None, None
    if fit_crop and outpainting_dims != None:
        # Cropping and outpainting are incompatible: fall back to canvas fit.
        fit_crop = False
        fit_canvas = 0 if fit_canvas is not None else None
    frame_height, frame_width, _ = video[0].shape
    if outpainting_dims != None:
        if fit_canvas != None:
            frame_height, frame_width = get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims)
        else:
            frame_height, frame_width = height, width
    if fit_canvas != None:
        height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas = fit_canvas, block_size = block_size)
    if outpainting_dims != None:
        # height/width become the inner area; final_* is the outpainted canvas.
        final_height, final_width = height, width
        height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1)
    if any_mask:
        num_frames = min(len(video), len(mask_video))
    else:
        num_frames = len(video)
    if any_identity_mask:
        any_mask = True
    proc_list =[]
    proc_list_outside =[]
    proc_mask = []
    def prep_prephase(frame_idx):
        # Resize frame (and mask), binarize/expand the mask, and produce the
        # (inside_frame, full_frame, mask) triple for the later process phase.
        frame = Image.fromarray(video[frame_idx].cpu().numpy())
        if fit_crop:
            frame = rescale_and_crop(frame, width, height)
        else:
            frame = frame.resize((width, height), resample=Image.Resampling.LANCZOS)
        frame = np.array(frame)
        if any_mask:
            if any_identity_mask:
                mask = np.full( (height, width, 3), 0, dtype= np.uint8)
            else:
                mask = Image.fromarray(mask_video[frame_idx].cpu().numpy())
                if fit_crop:
                    mask = rescale_and_crop(mask, width, height)
                else:
                    mask = mask.resize((width, height), resample=Image.Resampling.LANCZOS)
                mask = np.array(mask)
            if len(mask.shape) == 3 and mask.shape[2] == 3:
                mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
            _, mask = cv2.threshold(mask, 127.5, 255, cv2.THRESH_BINARY)
            original_mask = mask.copy()
            if expand_scale != 0:
                # Grow (dilate) or shrink (erode) the mask by |expand_scale|.
                kernel_size = abs(expand_scale)
                kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
                op_expand = cv2.dilate if expand_scale > 0 else cv2.erode
                mask = op_expand(mask, kernel, iterations=3)
            if to_bbox and np.sum(mask == 255) > 0 :
                # Replace the mask by its bounding box.
                x0, y0, x1, y1 = mask_to_xyxy_box(mask)
                mask = mask * 0
                mask[y0:y1, x0:x1] = 255
            if negate_mask:
                mask = 255 - mask
                if pose_special:
                    original_mask = 255 - original_mask
        if pose_special and any_mask:
            # Pose detection only sees the (unexpanded) masked-out area.
            target_frame = np.where(original_mask[..., None], frame, 0)
        else:
            target_frame = frame
        if any_mask:
            return (target_frame, frame, mask)
        else:
            return (target_frame, None, None)
    max_workers = get_default_workers()
    proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False, max_workers=max_workers, in_place= True)
    proc_list, proc_list_outside, proc_mask = [None] * len(proc_lists), [None] * len(proc_lists), [None] * len(proc_lists)
    for frame_idx, frame_group in enumerate(proc_lists):
        proc_list[frame_idx], proc_list_outside[frame_idx], proc_mask[frame_idx] = frame_group
    prep_prephase = None
    video = None
    mask_video = None
    if preproc2 != None:
        proc_list2 = process_images_multithread(preproc2, proc_list, process_type2, max_workers=max_workers)
        #### to be finished ...or not
    proc_list = process_images_multithread(preproc, proc_list, process_type, max_workers=max_workers)
    if any_mask:
        proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask, max_workers=max_workers)
    else:
        proc_list_outside = proc_mask = len(proc_list) * [None]
    masked_frames = []
    masks = []
    for frame_no, (processed_img, processed_img_outside, mask) in enumerate(zip(proc_list, proc_list_outside, proc_mask)):
        if any_mask :
            # Compose inside/outside processed pixels according to the mask.
            masked_frame = np.where(mask[..., None], processed_img, processed_img_outside)
            if process_outside_mask != None:
                # Outside area was preprocessed too: the whole frame is "active".
                mask = np.full_like(mask, 255)
            mask = torch.from_numpy(mask)
            if RGB_Mask:
                mask = mask.unsqueeze(-1).repeat(1,1,3)
            if outpainting_dims != None:
                full_frame= torch.full( (final_height, final_width, mask.shape[-1]), 255, dtype= torch.uint8, device= mask.device)
                full_frame[margin_top:margin_top+height, margin_left:margin_left+width] = mask
                mask = full_frame
            masks.append(mask[:, :, 0:1].clone())
        else:
            masked_frame = processed_img
        if isinstance(masked_frame, int):
            # Inpaint preprocessors return a plain fill color per frame.
            masked_frame= np.full( (height, width, 3), inpaint_color, dtype= np.uint8)
        masked_frame = torch.from_numpy(masked_frame)
        if masked_frame.shape[-1] == 1:
            masked_frame = masked_frame.repeat(1,1,3).to(torch.uint8)
        if outpainting_dims != None:
            full_frame= torch.full( (final_height, final_width, masked_frame.shape[-1]), inpaint_color, dtype= torch.uint8, device= masked_frame.device)
            full_frame[margin_top:margin_top+height, margin_left:margin_left+width] = masked_frame
            masked_frame = full_frame
        masked_frames.append(masked_frame)
        # Release per-frame references early to cap peak memory.
        proc_list[frame_no] = proc_list_outside[frame_no] = proc_mask[frame_no] = None
    # if args.save_masks:
    #     from preprocessing.dwpose.pose import save_one_video
    #     saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ]
    #     save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None)
    #     if any_mask:
    #         saved_masks = [mask.cpu().numpy() for mask in masks ]
    #         save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None)
    preproc = None
    preproc_outside = None
    gc.collect()
    torch.cuda.empty_cache()
    if pad_frames > 0:
        # Fix: replicate the first element in the LIST — the original code did
        # `masked_frames[0] * pad_frames + masked_frames`, which multiplies the
        # tensor elementwise then fails to concatenate with a list; the second
        # line also assigned the padded masks to `masked_frames` by mistake.
        masked_frames = [masked_frames[0]] * pad_frames + masked_frames
        if any_mask: masks = [masks[0]] * pad_frames + masks
    masked_frames = torch.stack(masked_frames).permute(-1,0,1,2).float().div_(127.5).sub_(1.)
    masks = torch.stack(masks).permute(-1,0,1,2).float().div_(255) if any_mask else None
    return masked_frames, masks
def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size = 16):
    """Resample `video_in` to `target_fps`, resize (or rescale-and-crop) every frame
    and return them stacked as a uint8 tensor, or None when no frame is available."""
    frames = get_resampled_video(video_in, start_frame, max_frames, target_fps)
    if not len(frames):
        return None
    if fit_crop or fit_canvas is None:
        new_width, new_height = width, height
    else:
        src_h, src_w, _ = frames[0].shape
        if fit_canvas:
            # largest scale that fits the frame inside the canvas in either orientation
            scale = max(min(height / src_h, width / src_w), min(height / src_w, width / src_h))
        else:
            # keep the source aspect ratio while matching the requested pixel budget
            scale = ((height * width) / (src_h * src_w)) ** 0.5
        # snap both dimensions down to a multiple of the VAE block size
        new_height = int(src_h * scale) // block_size * block_size
        new_width = int(src_w * scale) // block_size * block_size
    resized = []
    for frame in frames:
        pil_frame = Image.fromarray(np.clip(frame.cpu().numpy(), 0, 255).astype(np.uint8))
        if fit_crop:
            pil_frame = rescale_and_crop(pil_frame, new_width, new_height)
        else:
            pil_frame = pil_frame.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
        resized.append(pil_frame)
    return torch.stack([torch.from_numpy(np.array(img)) for img in resized])
def parse_keep_frames_video_guide(keep_frames, video_length):
    """Parse a "keep frames" specification into a boolean mask over guide frames.

    The spec is a space-separated list of 1-based frame indices ("1 4") and
    inclusive ranges ("2:5"); negative values count from the end (-1 is the
    last frame).  An empty spec keeps every frame.

    Returns (frames, error): `frames` is a list of booleans truncated after the
    last kept frame; `error` is a non-empty message when parsing failed (in
    which case `frames` is empty).
    """
    def absolute(n):
        # Convert a 1-based (or negative, from-the-end) index into a clamped 0-based one.
        if n == 0:
            return 0
        elif n < 0:
            return max(0, video_length + n)
        else:
            return min(n - 1, video_length - 1)
    keep_frames = keep_frames.strip()
    if len(keep_frames) == 0:
        return [True] * video_length, ""
    frames = [False] * video_length
    error = ""
    for section in keep_frames.split(" "):
        section = section.strip()
        if ":" in section:
            parts = section.split(":")
            if not is_integer(parts[0]):
                error = f"Invalid integer {parts[0]}"
                break
            start_range = absolute(int(parts[0]))
            if not is_integer(parts[1]):
                error = f"Invalid integer {parts[1]}"
                break
            end_range = absolute(int(parts[1]))
            for i in range(start_range, end_range + 1):
                frames[i] = True
        else:
            if not is_integer(section) or int(section) == 0:
                error = f"Invalid integer {section}"
                break
            frames[absolute(int(section))] = True
    if len(error) > 0:
        return [], error
    # Drop trailing discarded frames.  The previous reverse-scan could leave one
    # trailing False (when only frame 0 was kept) and referenced an unbound loop
    # index for single-frame videos; this loop handles both cases.
    while len(frames) > 1 and not frames[-1]:
        frames.pop()
    return frames, error
def perform_temporal_upsampling(sample, previous_last_frame, temporal_upsampling, fps):
    """Optionally multiply the frame rate of `sample` with RIFE frame interpolation.

    sample: video tensor with frames along dim 1.
    previous_last_frame: last frame of the previous sliding window (or None); it
        is prepended so interpolation is continuous across windows, then the
        duplicated frame is stripped again.
    temporal_upsampling: "rife2" (x2), "rife4" (x4); anything else is a no-op.

    Returns (sample, previous_last_frame, output_fps).
    """
    exp = 0
    if temporal_upsampling == "rife2":
        exp = 1
    elif temporal_upsampling == "rife4":
        exp = 2
    output_fps = fps
    if exp > 0:
        from postprocessing.rife.inference import temporal_interpolation
        # `previous_last_frame` may be a tensor: `!= None` would trigger an
        # elementwise comparison, so test identity instead.
        if previous_last_frame is not None:
            sample = torch.cat([previous_last_frame, sample], dim=1)
            previous_last_frame = sample[:, -1:].clone()
            sample = temporal_interpolation(fl.locate_file("flownet.pkl"), sample, exp, device=processing_device)
            sample = sample[:, 1:]  # drop the re-injected overlap frame
        else:
            sample = temporal_interpolation(fl.locate_file("flownet.pkl"), sample, exp, device=processing_device)
            previous_last_frame = sample[:, -1:].clone()
        output_fps = output_fps * 2**exp
    return sample, previous_last_frame, output_fps
def perform_spatial_upsampling(sample, spatial_upsampling):
    """Spatially resize every frame of `sample` according to the `spatial_upsampling` preset."""
    from shared.utils.utils import resize_lanczos
    if spatial_upsampling == "vae2":
        # handled by the VAE upsampler itself, nothing to do here
        return sample
    if spatial_upsampling == "vae1":
        scale, method = 0.5, Image.Resampling.BICUBIC
    elif spatial_upsampling == "lanczos1.5":
        scale, method = 1.5, None
    else:
        scale, method = 2, None
    src_h, src_w = sample.shape[-2:]
    # round the scaled size to the nearest multiple of 16
    new_h = int(round(src_h * scale / 16) * 16)
    new_w = int(round(src_w * scale / 16) * 16)
    frames = [sample[:, i] for i in range(sample.shape[1])]
    def _resize_one(frame):
        return resize_lanczos(frame, new_h, new_w, method).unsqueeze(1)
    resized = process_images_multithread(_resize_one, frames, "upsample", wrap_in_list=False, max_workers=get_default_workers(), in_place=True)
    return torch.cat(resized, dim=1)
def any_audio_track(model_type):
    """Return whether the model either generates an audio track or accepts any audio prompt."""
    model_def = get_model_def(model_type) or {}
    return model_def.get("returns_audio", False) or model_def.get("any_audio_prompt", False)
def get_available_filename(target_path, video_source, suffix = "", force_extension = None):
    """Build a non-clashing filename in `target_path` derived from `video_source`.

    The base name of `video_source` is kept (with `suffix` appended and the
    extension replaced by `force_extension` when given); if the resulting file
    already exists, a "(2)", "(3)", ... counter is inserted before the extension.

    NOTE: subject to a TOCTOU race — another process may create the file between
    the existence check and the caller's write.
    """
    name, extension = os.path.splitext(os.path.basename(video_source))
    if force_extension is not None:
        extension = force_extension
    name += suffix
    full_path = os.path.join(target_path, f"{name}{extension}")
    if not os.path.exists(full_path):
        return full_path
    counter = 2
    while True:
        full_path = os.path.join(target_path, f"{name}({counter}){extension}")
        if not os.path.exists(full_path):
            return full_path
        counter += 1
def set_seed(seed):
    """Seed every RNG used during generation (torch, CUDA, numpy, random).

    A None or negative seed is replaced by a fresh random one; the effective
    seed is returned so it can be recorded in the output metadata.
    """
    import random
    if seed is None or seed < 0:
        seed = random.randint(0, 999999999)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op on CPU-only setups
    np.random.seed(seed)
    random.seed(seed)
    # favor reproducibility over speed for cuDNN kernels
    torch.backends.cudnn.deterministic = True
    return seed
def edit_video(
    send_cmd,
    state,
    mode,
    video_source,
    seed,
    temporal_upsampling,
    spatial_upsampling,
    film_grain_intensity,
    film_grain_saturation,
    MMAudio_setting,
    MMAudio_prompt,
    MMAudio_neg_prompt,
    repeat_generation,
    audio_source,
    **kwargs
):
    """Post-process an existing video: temporal/spatial upsampling, film grain,
    MMAudio soundtrack generation and/or audio remuxing.

    Each result is appended to the generation file list together with its
    settings; `send_cmd` reports progress/output events back to the UI.
    """
    gen = get_gen_info(state)
    if gen.get("abort", False): return
    abort = False
    MMAudio_setting = MMAudio_setting or 0
    # reuse the settings embedded in the source video when available
    configs, _ , _ = get_settings_from_file(state, video_source, False, False, False)
    if configs == None: configs = { "type" : get_model_record("Post Processing") }
    has_already_audio = False
    audio_tracks = []
    if MMAudio_setting == 0:
        # no soundtrack will be generated: keep the source's own audio tracks
        audio_tracks, audio_metadata = extract_audio_tracks(video_source)
        has_already_audio = len(audio_tracks) > 0
    if audio_source is not None:
        # an explicit audio file overrides the source's tracks
        # NOTE(review): `audio_metadata` stays unbound on this path when
        # MMAudio_setting != 0 and the `elif len(audio_tracks) > 0` branch
        # below is reached (e.g. MMAudio disabled/too short) — confirm.
        audio_tracks = [audio_source]
    with lock:
        file_list = gen["file_list"]
        file_settings_list = gen["file_settings_list"]
    seed = set_seed(seed)
    from shared.utils.utils import get_video_info
    fps, width, height, frames_count = get_video_info(video_source)
    frames_count = min(frames_count, max_source_video_frames)
    sample = None
    if mode == "edit_postprocessing":
        if len(temporal_upsampling) > 0 or len(spatial_upsampling) > 0 or film_grain_intensity > 0:
            send_cmd("progress", [0, get_latest_status(state,"Upsampling" if len(temporal_upsampling) > 0 or len(spatial_upsampling) > 0 else "Adding Film Grain" )])
            # decode the frames once, normalized to [-1, 1], frames on dim 1
            sample = get_resampled_video(video_source, 0, max_source_video_frames, fps)
            sample = sample.float().div_(127.5).sub_(1.).permute(-1,0,1,2)
            frames_count = sample.shape[1]
        output_fps = round(fps)
        if len(temporal_upsampling) > 0:
            sample, previous_last_frame, output_fps = perform_temporal_upsampling(sample, None, temporal_upsampling, fps)
            configs["temporal_upsampling"] = temporal_upsampling
            frames_count = sample.shape[1]
        if len(spatial_upsampling) > 0:
            sample = perform_spatial_upsampling(sample, spatial_upsampling )
            configs["spatial_upsampling"] = spatial_upsampling
        if film_grain_intensity > 0:
            from postprocessing.film_grain import add_film_grain
            sample = add_film_grain(sample, film_grain_intensity, film_grain_saturation)
            configs["film_grain_intensity"] = film_grain_intensity
            configs["film_grain_saturation"] = film_grain_saturation
    else:
        output_fps = round(fps)
    mmaudio_enabled, mmaudio_mode, mmaudio_persistence, mmaudio_model_name, mmaudio_model_path = get_mmaudio_settings(server_config)
    # MMAudio requires at least roughly one second worth of frames
    any_mmaudio = MMAudio_setting != 0 and mmaudio_enabled and frames_count >=output_fps
    if any_mmaudio: download_mmaudio()
    tmp_path = None
    any_change = False
    if sample != None:
        # frames were modified: write them out (to a temp "_tmp" file when audio
        # still has to be merged in afterwards)
        video_path =get_available_filename(save_path, video_source, "_tmp") if any_mmaudio or has_already_audio else get_available_filename(save_path, video_source, "_post")
        save_video( tensor=sample[None], save_file=video_path, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1), codec_type= server_config.get("video_output_codec", None), container=server_config.get("video_container", "mp4"))
        if any_mmaudio or has_already_audio: tmp_path = video_path
        any_change = True
    else:
        video_path = video_source
    repeat_no = 0
    extra_generation = 0
    initial_total_windows = 0
    any_change_initial = any_change
    # Repeat loop: mainly useful to sample several MMAudio soundtracks with different seeds
    while not gen.get("abort", False):
        any_change = any_change_initial
        # pick up extra generations queued from the UI while running
        extra_generation += gen.get("extra_orders",0)
        gen["extra_orders"] = 0
        total_generation = repeat_generation + extra_generation
        gen["total_generation"] = total_generation
        if repeat_no >= total_generation: break
        repeat_no +=1
        gen["repeat_no"] = repeat_no
        suffix = "" if "_post" in video_source else "_post"
        if audio_source is not None:
            # record in the settings that an external audio track was injected
            audio_prompt_type = configs.get("audio_prompt_type", "")
            if not "T" in audio_prompt_type:audio_prompt_type += "T"
            configs["audio_prompt_type"] = audio_prompt_type
            any_change = True
        if any_mmaudio:
            send_cmd("progress", [0, get_latest_status(state,"MMAudio Soundtrack Generation")])
            from postprocessing.mmaudio.mmaudio import video_to_audio
            new_video_path = get_available_filename(save_path, video_source, suffix)
            video_to_audio(video_path, prompt = MMAudio_prompt, negative_prompt = MMAudio_neg_prompt, seed = seed, num_steps = 25, cfg_strength = 4.5, duration= frames_count /output_fps, save_path = new_video_path , persistent_models = mmaudio_persistence == MMAUDIO_PERSIST_RAM, verboseLevel = verbose_level, model_name = mmaudio_model_name, model_path = mmaudio_model_path)
            configs["MMAudio_setting"] = MMAudio_setting
            configs["MMAudio_prompt"] = MMAudio_prompt
            configs["MMAudio_neg_prompt"] = MMAudio_neg_prompt
            configs["MMAudio_seed"] = seed
            any_change = True
        elif len(audio_tracks) > 0:
            # combine audio files and new video file
            new_video_path = get_available_filename(save_path, video_source, suffix)
            combine_video_with_audio_tracks(video_path, audio_tracks, new_video_path, audio_metadata=audio_metadata)
        else:
            new_video_path = video_path
        # NOTE(review): the temp file is deleted inside the repeat loop; with
        # repeat_generation > 1 the next iteration would re-remove (and, for
        # MMAudio, re-read) an already deleted file — confirm.
        if tmp_path != None:
            os.remove(tmp_path)
        if any_change:
            if mode == "edit_remux":
                print(f"Remuxed Video saved to Path: "+ new_video_path)
            else:
                print(f"Postprocessed video saved to Path: "+ new_video_path)
            with lock:
                file_list.append(new_video_path)
                file_settings_list.append(configs)
            if configs != None:
                # carry over settings and source images embedded in the original video
                from shared.utils.video_metadata import extract_source_images, save_video_metadata
                temp_images_path = get_available_filename(save_path, video_source, force_extension= ".temp")
                embedded_images = extract_source_images(video_source, temp_images_path)
                save_video_metadata(new_video_path, configs, embedded_images)
                if os.path.isdir(temp_images_path):
                    shutil.rmtree(temp_images_path, ignore_errors= True)
        send_cmd("output")
        # fresh random seed for the next repeat
        seed = set_seed(-1)
    if has_already_audio:
        cleanup_temp_audio_files(audio_tracks)
    clear_status(state)
def get_overridden_attention(model_type):
    """Return the attention mode forced by the model definition for the current GPU architecture, or None."""
    model_def = get_model_def(model_type)
    override = model_def.get("attention", None)
    if override is None:
        return None
    candidates = match_nvidia_architecture(override, gpu_major * 10 + gpu_minor)
    if not candidates:
        return None
    choice = candidates[0]
    # an override that is not actually installed/supported is ignored
    if choice is not None and choice not in attention_modes_supported:
        return None
    return choice
def get_transformer_loras(model_type):
    """Collect the loras bundled with the model definition and their multipliers.

    Filenames are remapped into the model's lora directory; the multipliers list
    is padded with 1. and truncated so it always matches the filenames.
    (The previous version also fetched `get_model_def` into an unused local.)
    """
    lora_dir = get_lora_dir(model_type)
    transformer_loras_filenames = [
        os.path.join(lora_dir, os.path.basename(filename))
        for filename in get_model_recursive_prop(model_type, "loras", return_list=True)
    ]
    transformer_loras_multipliers = get_model_recursive_prop(model_type, "loras_multipliers", return_list=True)
    transformer_loras_multipliers = (transformer_loras_multipliers + [1.] * len(transformer_loras_filenames))[:len(transformer_loras_filenames)]
    return transformer_loras_filenames, transformer_loras_multipliers
class DynamicClass:
    """A lightweight object whose public attributes are backed by an internal dict.

    Attributes can be preassigned through the constructor, set directly
    (``obj.foo = 1``) or in bulk with assign()/update().  Names starting with an
    underscore are stored as regular instance attributes.
    """
    def __init__(self, **kwargs):
        # preassign default properties from kwargs
        self._data = dict(kwargs)

    def __getattr__(self, name):
        # Called only when normal lookup fails.  Fetch `_data` through __dict__
        # so a missing backing dict cannot recurse back into __getattr__
        # (the previous implementation could recurse infinitely in that case).
        data = self.__dict__.get('_data')
        if data is not None and name in data:
            return data[name]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __setattr__(self, name, value):
        if name.startswith('_'):
            # private names bypass the backing dict
            super().__setattr__(name, value)
        else:
            if '_data' not in self.__dict__:
                super().__setattr__('_data', {})
            self._data[name] = value

    def assign(self, **kwargs):
        """Assign multiple properties at once; returns self for method chaining."""
        self._data.update(kwargs)
        return self

    def update(self, dict):  # parameter name kept for backward compatibility (shadows the builtin)
        """Alias for assign() - more dict-like."""
        return self.assign(**dict)
def process_prompt_enhancer(model_def, prompt_enhancer, original_prompts, image_start, original_image_refs, is_image, audio_only, seed, prompt_enhancer_instructions = None):
    """Run the LLM prompt enhancer over `original_prompts`, optionally conditioned on images.

    prompt_enhancer: flag string; "I" enables image conditioning, "T" enables text prompts.
    prompt_enhancer_instructions: optional override; when None the instructions
        come from the model definition.  (The previous version unconditionally
        overwrote the caller-supplied value, making the parameter dead.)

    Returns the list of enhanced prompts, or None when there is nothing to enhance.
    """
    if prompt_enhancer_instructions is None:
        prompt_enhancer_instructions = model_def.get("image_prompt_enhancer_instructions" if is_image else "video_prompt_enhancer_instructions", None)
    text_encoder_max_tokens = model_def.get("image_prompt_enhancer_max_tokens" if is_image else "video_prompt_enhancer_max_tokens", 256)
    prompt_images = []
    if "I" in prompt_enhancer:
        if image_start is not None:
            if not isinstance(image_start, list): image_start = [image_start]
            prompt_images += image_start
        if original_image_refs is not None:
            # only the first reference image is used for captioning
            prompt_images += original_image_refs[:1]
        prompt_images = [Image.open(img) if isinstance(img, str) else img for img in prompt_images]
    if len(original_prompts) == 0 and not "T" in prompt_enhancer:
        return None
    # heavy imports deferred until we know there is actual work to do
    from models.ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt
    from shared.utils.utils import seed_everything
    seed = seed_everything(seed)
    prompts = generate_cinematic_prompt(
        prompt_enhancer_image_caption_model,
        prompt_enhancer_image_caption_processor,
        prompt_enhancer_llm_model,
        prompt_enhancer_llm_tokenizer,
        original_prompts if "T" in prompt_enhancer else ["an image"],
        prompt_images if len(prompt_images) > 0 else None,
        video_prompt = not is_image,
        text_prompt = audio_only,
        max_new_tokens=text_encoder_max_tokens,
        prompt_enhancer_instructions = prompt_enhancer_instructions,
    )
    return prompts
def enhance_prompt(state, prompt, prompt_enhancer, multi_images_gen_type, override_profile, progress=gr.Progress()):
    """Enhance the prompt(s) of the current model settings with the LLM prompt
    enhancer and return the enhanced text (twice, for the two bound UI fields).

    Each enhanced prompt is preceded by a "#!PROMPT!:" comment holding the
    original text, so re-running the enhancer starts again from the originals.
    """
    global enhancer_offloadobj
    prefix = "#!PROMPT!:"
    model_type = get_state_model_type(state)
    inputs = get_model_settings(state, model_type)
    original_prompts = inputs["prompt"]
    original_prompts, errors = prompt_parser.process_template(original_prompts, keep_comments= True)
    if len(errors) > 0:
        gr.Info("Error processing prompt template: " + errors)
        return gr.update(), gr.update()
    original_prompts = original_prompts.replace("\r", "").split("\n")
    # Collect the prompts to enhance: a "#!PROMPT!:" line carries the original
    # text of a previous enhancement (the line after it is the old enhanced
    # version and is skipped); other comments and empty lines are ignored.
    # NOTE(review): the loop variable shadows the `prompt` parameter.
    prompts_to_process = []
    skip_next_non_comment = False
    for prompt in original_prompts:
        if prompt.startswith(prefix):
            new_prompt = prompt[len(prefix):].strip()
            prompts_to_process.append(new_prompt)
            skip_next_non_comment = True
        else:
            if not prompt.startswith("#") and not skip_next_non_comment and len(prompt) > 0:
                prompts_to_process.append(prompt)
            skip_next_non_comment = False
    original_prompts = prompts_to_process
    num_prompts = len(original_prompts)
    image_start = inputs["image_start"]
    if image_start is None or not "I" in prompt_enhancer:
        image_start = [None] * num_prompts
    else:
        # one start image per prompt: a single image is broadcast, otherwise the
        # counts must match and "match images and text prompts" must be enabled
        image_start = [convert_image(img[0]) for img in image_start]
        if len(image_start) == 1:
            image_start = image_start * num_prompts
        else:
            if multi_images_gen_type !=1:
                gr.Info("On Demand Prompt Enhancer with multiple Start Images requires that option 'Match images and text prompts' is set")
                return gr.update(), gr.update()
            if len(image_start) != num_prompts:
                gr.Info("On Demand Prompt Enhancer supports only mutiple Start Images if their number matches the number of Text Prompts")
                return gr.update(), gr.update()
    if enhancer_offloadobj is None:
        status = "Please Wait While Loading Prompt Enhancer"
        progress(0, status)
    kwargs = {}
    pipe = {}
    download_models()
    model_def = get_model_def(get_state_model_type(state))
    audio_only = model_def.get("audio_only", False)
    acquire_GPU_ressources(state, "prompt_enhancer", "Prompt Enhancer")
    # re-check after acquiring the GPU: another task may have built it meanwhile
    if enhancer_offloadobj is None:
        _, mmgp_profile = init_pipe(pipe, kwargs, override_profile)
        setup_prompt_enhancer(pipe, kwargs)
        enhancer_offloadobj = offload.profile(pipe, profile_no= mmgp_profile, **kwargs)
    original_image_refs = inputs["image_refs"]
    if original_image_refs is not None:
        original_image_refs = [ convert_image(tup[0]) for tup in original_image_refs ]
    is_image = inputs["image_mode"] > 0
    seed = inputs["seed"]
    seed = set_seed(seed)
    enhanced_prompts = []
    for i, (one_prompt, one_image) in enumerate(zip(original_prompts, image_start)):
        start_images = [one_image] if one_image is not None else None
        status = f'Please Wait While Enhancing Prompt' if num_prompts==1 else f'Please Wait While Enhancing Prompt #{i+1}'
        progress((i , num_prompts), desc=status, total= num_prompts)
        try:
            enhanced_prompt = process_prompt_enhancer(model_def, prompt_enhancer, [one_prompt], start_images, original_image_refs, is_image, audio_only, seed)
        except Exception as e:
            # release the enhancer and the GPU before surfacing the error to the UI
            enhancer_offloadobj.unload_all()
            release_GPU_ressources(state, "prompt_enhancer")
            raise gr.Error(e)
        if enhanced_prompt is not None:
            # keep the original prompt as a comment line above its enhanced version
            enhanced_prompt = enhanced_prompt[0].replace("\n", "").replace("\r", "")
            enhanced_prompts.append(prefix + " " + one_prompt)
            enhanced_prompts.append(enhanced_prompt)
    enhancer_offloadobj.unload_all()
    release_GPU_ressources(state, "prompt_enhancer")
    prompt = '\n'.join(enhanced_prompts)
    if num_prompts > 1:
        gr.Info(f'{num_prompts} Prompts have been Enhanced')
    else:
        gr.Info(f'Prompt "{original_prompts[0][:100]}" has been enhanced')
    return prompt, prompt
def get_outpainting_dims(video_guide_outpainting):
    """Parse the outpainting margins string (space-separated ints) into a list.

    Returns None when outpainting is disabled: a None/empty value, the all-zero
    default "0 0 0 0", or a value commented out with a leading '#'.
    """
    if video_guide_outpainting is None or len(video_guide_outpainting) == 0:
        return None
    if video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#"):
        return None
    return [int(v) for v in video_guide_outpainting.split(" ")]
def truncate_audio(generated_audio, trim_video_frames_beginning, trim_video_frames_end, video_fps, audio_sampling_rate):
    """Trim the audio samples corresponding to frames cut from the beginning/end of the video."""
    per_frame = audio_sampling_rate / video_fps
    first = int(trim_video_frames_beginning * per_frame)
    last = len(generated_audio) - int(trim_video_frames_end * per_frame)
    if last <= 0:
        # trimming would consume everything from the end: keep up to the last sample
        last = None
    return generated_audio[first:last]
def slice_audio_window(audio_path, start_frame, num_frames, fps, output_dir, suffix=""):
    """Read the audio samples matching a window of video frames.

    The window [start_frame, start_frame + num_frames) is converted to seconds
    using `fps`; portions before the start of the file (negative start) or past
    its end are zero-padded so the result always covers the full window.

    Returns (data, sample_rate) where data is a float32 array of shape
    (channels, samples).
    """
    # NOTE(review): `output_dir` and `suffix` are accepted but never used —
    # confirm whether they can be dropped from the signature.
    import soundfile as sf
    import numpy as np
    start_sec = float(start_frame) / float(fps)
    duration_sec = float(num_frames) / float(fps)
    with sf.SoundFile(audio_path) as audio_file:
        sample_rate = audio_file.samplerate
        channels = audio_file.channels
        total_frames = len(audio_file)
        start_sample = int(round(start_sec * sample_rate))
        pad_start = 0
        if start_sample < 0:
            # window starts before the file: remember how much silence to prepend
            pad_start = -start_sample
            start_sample = 0
        frames_to_read = int(round(duration_sec * sample_rate))
        if start_sample > total_frames:
            # window entirely past the end of the file
            data = np.zeros((0, channels), dtype=np.float32)
        else:
            audio_file.seek(min(start_sample, total_frames))
            data = audio_file.read(frames_to_read, dtype="float32", always_2d=True)
    if pad_start > 0:
        data = np.concatenate([np.zeros((pad_start, channels), dtype=np.float32), data], axis=0)
    # pad with silence at the end so the window is always fully covered
    target_frames = pad_start + frames_to_read
    if data.shape[0] < target_frames:
        pad_end = target_frames - data.shape[0]
        data = np.concatenate([data, np.zeros((pad_end, channels), dtype=np.float32)], axis=0)
    if data.ndim == 2:
        # (samples, channels) -> (channels, samples)
        data = data.T
    return data, sample_rate
def extract_audio_track_to_wav(video_path, output_dir, suffix=""):
    """Extract the first audio stream of `video_path` into a 16-bit PCM .wav file.

    Returns the path of the written file, or None on any failure (missing input,
    non-path argument, ffmpeg-python unavailable, or an ffmpeg error).
    """
    if not video_path:
        return None
    try:
        video_path = os.fspath(video_path)
    except TypeError:
        return None
    try:
        import ffmpeg
    except Exception:
        return None
    target_dir = output_dir or os.path.dirname(video_path) or "."
    wav_path = get_available_filename(target_dir, video_path, suffix=suffix, force_extension=".wav")
    try:
        stream = ffmpeg.input(video_path).output(wav_path, **{"map": "0:a:0", "acodec": "pcm_s16le"})
        stream.overwrite_output().run(quiet=True)
    except Exception:
        return None
    return wav_path
def write_wav_file(file_path, audio_data, sampling_rate):
    """Normalize `audio_data` to a float (samples,) or (samples, channels) array and write it as a WAV file."""
    if audio_data is None:
        return
    import numpy as np
    import soundfile as sf
    samples = np.asarray(audio_data)
    if samples.ndim > 1 and samples.shape[0] == 1:
        # drop a leading singleton batch dimension
        samples = samples.squeeze(0)
    if samples.ndim == 2 and samples.shape[0] in (1, 2) and samples.shape[1] > samples.shape[0]:
        # channels-first layout detected: transpose to (samples, channels)
        samples = samples.T
    if samples.ndim == 0:
        samples = samples.reshape(1)
    if samples.ndim == 2 and samples.shape[1] == 1:
        # mono column vector -> flat vector
        samples = samples[:, 0]
    if not np.issubdtype(samples.dtype, np.floating):
        samples = samples.astype(np.float32)
    rate = int(sampling_rate)
    samples = np.ascontiguousarray(samples)
    try:
        sf.write(file_path, samples, rate)
    except sf.LibsndfileError:
        # fall back to an explicit float WAV subtype when the default fails
        sf.write(file_path, samples, rate, format="WAV", subtype="FLOAT")
def custom_preprocess_video_with_mask(model_handler, base_model_type, pre_video_guide, video_guide, video_mask, height, width, max_frames, start_frame, fit_canvas, fit_crop, target_fps, block_size, expand_scale, video_prompt_type):
    """Resample a control video (and optional mask), clean the mask up and hand
    both over to the model handler's custom preprocessing.

    Returns (video_guide_processed, video_guide_processed2, video_mask_processed,
    video_mask_processed2), or four Nones when there is nothing to process.
    """
    # A negative start frame means the window begins before the video: remember
    # how many frames to pad.  NOTE(review): pad_frames is computed but unused —
    # see the commented-out padding code at the end.
    pad_frames = 0
    if start_frame < 0:
        pad_frames= -start_frame
        max_frames += start_frame
        start_frame = 0
    max_workers = get_default_workers()
    if not video_guide or max_frames <= 0:
        return None, None, None, None
    # (frames, h, w, channels) -> (channels, frames, h, w), normalized to [-1, 1]
    video_guide = get_resampled_video(video_guide, start_frame, max_frames, target_fps).permute(-1, 0, 1, 2)
    video_guide = video_guide / 127.5 - 1.
    any_mask = video_mask is not None
    if video_mask is not None:
        # keep a single channel, normalized to [0, 1]
        video_mask = get_resampled_video(video_mask, start_frame, max_frames, target_fps).permute(-1, 0, 1, 2)
        video_mask = video_mask[:1] / 255.
    # Mask filtering: resize, binarize, expand mask and keep only masked areas of video guide
    if any_mask:
        invert_mask = "N" in video_prompt_type
        import concurrent.futures
        tgt_h, tgt_w = video_guide.shape[2], video_guide.shape[3]
        # positive expand_scale dilates the mask, negative erodes it
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (abs(expand_scale), abs(expand_scale))) if expand_scale != 0 else None
        op = (cv2.dilate if expand_scale > 0 else cv2.erode) if expand_scale != 0 else None
        def process_mask(idx):
            # Resize one mask frame to the guide's resolution, binarize and
            # optionally dilate/erode it; returns a float32 {0, 1} tensor.
            m = (video_mask[0, idx].numpy() * 255).astype(np.uint8)
            if m.shape[0] != tgt_h or m.shape[1] != tgt_w:
                m = cv2.resize(m, (tgt_w, tgt_h), interpolation=cv2.INTER_NEAREST)
            _, m = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY) # binarize grey values
            if op: m = op(m, kernel, iterations=3)
            if invert_mask:
                return torch.from_numpy((m <= 127).astype(np.float32))
            else:
                return torch.from_numpy((m > 127).astype(np.float32))
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
            video_mask = torch.stack([f.result() for f in [ex.submit(process_mask, i) for i in range(video_mask.shape[1])]]).unsqueeze(0)
        # keep only the masked area of the guide; everything else becomes -1 (black)
        video_guide = video_guide * video_mask + (-1) * (1-video_mask)
    if video_guide.shape[1] == 0 or any_mask and video_mask.shape[1] == 0:
        return None, None, None, None
    video_guide_processed, video_guide_processed2, video_mask_processed, video_mask_processed2 = model_handler.custom_preprocess(base_model_type = base_model_type, pre_video_guide = pre_video_guide, video_guide = video_guide, video_mask = video_mask, height = height, width = width, fit_canvas = fit_canvas , fit_crop = fit_crop, target_fps = target_fps, block_size = block_size, max_workers = max_workers, expand_scale = expand_scale, video_prompt_type=video_prompt_type)
    # if pad_frames > 0:
    #     masked_frames = masked_frames[0] * pad_frames + masked_frames
    #     if any_mask: masked_frames = masks[0] * pad_frames + masks
    return video_guide_processed, video_guide_processed2, video_mask_processed, video_mask_processed2
def generate_video(
task,
send_cmd,
image_mode,
prompt,
negative_prompt,
resolution,
video_length,
batch_size,
seed,
force_fps,
num_inference_steps,
guidance_scale,
guidance2_scale,
guidance3_scale,
switch_threshold,
switch_threshold2,
guidance_phases,
model_switch_phase,
alt_guidance_scale,
audio_guidance_scale,
audio_scale,
flow_shift,
sample_solver,
embedded_guidance_scale,
repeat_generation,
multi_prompts_gen_type,
multi_images_gen_type,
skip_steps_cache_type,
skip_steps_multiplier,
skip_steps_start_step_perc,
activated_loras,
loras_multipliers,
image_prompt_type,
image_start,
image_end,
model_mode,
video_source,
keep_frames_video_source,
input_video_strength,
video_prompt_type,
image_refs,
frames_positions,
video_guide,
image_guide,
keep_frames_video_guide,
denoising_strength,
masking_strength,
video_guide_outpainting,
video_mask,
image_mask,
control_net_weight,
control_net_weight2,
control_net_weight_alt,
motion_amplitude,
mask_expand,
audio_guide,
audio_guide2,
custom_guide,
audio_source,
audio_prompt_type,
speakers_locations,
sliding_window_size,
sliding_window_overlap,
sliding_window_color_correction_strength,
sliding_window_overlap_noise,
sliding_window_discard_last_frames,
image_refs_relative_size,
remove_background_images_ref,
temporal_upsampling,
spatial_upsampling,
film_grain_intensity,
film_grain_saturation,
MMAudio_setting,
MMAudio_prompt,
MMAudio_neg_prompt,
RIFLEx_setting,
NAG_scale,
NAG_tau,
NAG_alpha,
slg_switch,
slg_layers,
slg_start_perc,
slg_end_perc,
apg_switch,
cfg_star_switch,
cfg_zero_step,
prompt_enhancer,
min_frames_if_references,
override_profile,
pace,
exaggeration,
temperature,
output_filename,
state,
model_type,
mode,
plugin_data=None,
):
def remove_temp_filenames(temp_filenames_list):
for temp_filename in temp_filenames_list:
if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename)
global wan_model, offloadobj, reload_needed
gen = get_gen_info(state)
torch.set_grad_enabled(False)
if mode.startswith("edit_"):
edit_video(send_cmd, state, mode, video_source, seed, temporal_upsampling, spatial_upsampling, film_grain_intensity, film_grain_saturation, MMAudio_setting, MMAudio_prompt, MMAudio_neg_prompt, repeat_generation, audio_source)
return
with lock:
file_list = gen["file_list"]
file_settings_list = gen["file_settings_list"]
audio_file_list = gen["audio_file_list"]
audio_file_settings_list = gen["audio_file_settings_list"]
model_def = get_model_def(model_type)
is_image = image_mode > 0
audio_only = model_def.get("audio_only", False)
set_video_prompt_type = model_def.get("set_video_prompt_type", None)
if set_video_prompt_type is not None:
video_prompt_type = add_to_sequence(video_prompt_type, set_video_prompt_type)
if is_image:
if not model_def.get("custom_video_length", False):
if min_frames_if_references >= 1000:
video_length = min_frames_if_references - 1000
else:
video_length = min_frames_if_references if "I" in video_prompt_type or "V" in video_prompt_type else 1
else:
batch_size = 1
temp_filenames_list = []
if image_guide is not None and isinstance(image_guide, Image.Image):
video_guide = image_guide
image_guide = None
if image_mask is not None and isinstance(image_mask, Image.Image):
video_mask = image_mask
image_mask = None
if model_def.get("no_background_removal", False): remove_background_images_ref = 0
base_model_type = get_base_model_type(model_type)
model_handler = get_model_handler(base_model_type)
block_size = model_handler.get_vae_block_size(base_model_type) if hasattr(model_handler, "get_vae_block_size") else 16
if "P" in preload_model_policy and not "U" in preload_model_policy:
while wan_model == None:
time.sleep(1)
vae_upsampling = model_def.get("vae_upsampler", None)
model_kwargs = {}
if vae_upsampling is not None:
new_vae_upsampling = None if image_mode not in vae_upsampling or "vae" not in spatial_upsampling else spatial_upsampling
old_vae_upsampling = None if reload_needed or wan_model is None or not hasattr(wan_model, "vae") or not hasattr(wan_model.vae, "upsampling_set") else wan_model.vae.upsampling_set
reload_needed = reload_needed or old_vae_upsampling != new_vae_upsampling
if new_vae_upsampling: model_kwargs = {"VAE_upsampling": new_vae_upsampling}
if model_type != transformer_type or reload_needed or override_profile>0 and override_profile != loaded_profile or override_profile<0 and default_profile != loaded_profile:
wan_model = None
release_model()
send_cmd("status", f"Loading model {get_model_name(model_type)}...")
wan_model, offloadobj = load_models(model_type, override_profile, **model_kwargs)
send_cmd("status", "Model loaded")
reload_needed= False
if args.test:
skip_gemma_save = os.environ.get("WAN2GP_SKIP_GEMMA_SAVE", "").strip().lower() in (
"1",
"true",
"yes",
"on",
)
send_cmd("info", "Test mode: model loaded, skipping generation.")
return
overridden_attention = get_overridden_attention(model_type)
# if overridden_attention is not None and overridden_attention != attention_mode: print(f"Attention mode has been overriden to {overridden_attention} for model type '{model_type}'")
attn = overridden_attention if overridden_attention is not None else attention_mode
if attn == "auto":
attn = get_auto_attention()
elif not attn in attention_modes_supported:
send_cmd("info", f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. You should either install it or switch to the default 'sdpa' attention.")
send_cmd("exit")
return
width, height = resolution.split("x")
width, height = int(width) // block_size * block_size, int(height) // block_size * block_size
default_image_size = (height, width)
if slg_switch == 0:
slg_layers = None
offload.shared_state["_attention"] = attn
device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
if hasattr(wan_model, "vae") and hasattr(wan_model.vae, "get_VAE_tile_size"):
get_tile_size = wan_model.vae.get_VAE_tile_size
try:
sig = inspect.signature(get_tile_size)
except (TypeError, ValueError):
sig = None
if sig is not None and "output_height" in sig.parameters:
VAE_tile_size = get_tile_size(
vae_config,
device_mem_capacity,
server_config.get("vae_precision", "16") == "32",
output_height=height,
output_width=width,
)
else:
VAE_tile_size = get_tile_size(
vae_config,
device_mem_capacity,
server_config.get("vae_precision", "16") == "32",
)
else:
VAE_tile_size = None
trans = get_transformer_model(wan_model)
trans2 = get_transformer_model(wan_model, 2)
audio_sampling_rate = 16000
if multi_prompts_gen_type == 2:
prompts = [prompt]
else:
prompts = prompt.split("\n")
prompts = [part.strip() for part in prompts if len(prompt)>0]
parsed_keep_frames_video_source= max_source_video_frames if len(keep_frames_video_source) ==0 else int(keep_frames_video_source)
transformer_loras_filenames, transformer_loras_multipliers = get_transformer_loras(model_type)
if guidance_phases < 1: guidance_phases = 1
if transformer_loras_filenames != None:
loras_list_mult_choices_nums, loras_slists, errors = parse_loras_multipliers(transformer_loras_multipliers, len(transformer_loras_filenames), num_inference_steps, nb_phases = guidance_phases, model_switch_phase= model_switch_phase )
if len(errors) > 0: raise Exception(f"Error parsing Transformer Loras: {errors}")
loras_selected = transformer_loras_filenames
if hasattr(wan_model, "get_loras_transformer"):
extra_loras_transformers, extra_loras_multipliers = wan_model.get_loras_transformer(get_model_recursive_prop, **locals())
loras_list_mult_choices_nums, loras_slists, errors = parse_loras_multipliers(extra_loras_multipliers, len(extra_loras_transformers), num_inference_steps, nb_phases = guidance_phases, merge_slist= loras_slists, model_switch_phase= model_switch_phase )
if len(errors) > 0: raise Exception(f"Error parsing Extra Transformer Loras: {errors}")
loras_selected += extra_loras_transformers
if len(activated_loras) > 0:
loras_list_mult_choices_nums, loras_slists, errors = parse_loras_multipliers(loras_multipliers, len(activated_loras), num_inference_steps, nb_phases = guidance_phases, merge_slist= loras_slists, model_switch_phase= model_switch_phase )
if len(errors) > 0: raise Exception(f"Error parsing Loras: {errors}")
lora_dir = get_lora_dir(model_type)
errors = check_loras_exist(model_type, activated_loras, True, send_cmd)
if len(errors) > 0 : raise gr.Error(errors)
loras_selected += [ os.path.join(lora_dir, os.path.basename(lora)) for lora in activated_loras]
if hasattr(wan_model, "get_trans_lora"):
trans_lora, trans2_lora = wan_model.get_trans_lora()
else:
trans_lora, trans2_lora = trans, trans2
if len(loras_selected) > 0:
pinnedLora = loaded_profile !=5 # and transformer_loras_filenames == None False # # #
preprocess_target = trans_lora if trans_lora is not None else trans
split_linear_modules_map = getattr(preprocess_target, "split_linear_modules_map", None)
offload.load_loras_into_model(
trans_lora,
loras_selected,
loras_list_mult_choices_nums,
activate_all_loras=True,
preprocess_sd=get_loras_preprocessor(preprocess_target, base_model_type),
pinnedLora=pinnedLora,
maxReservedLoras=server_config.get("max_reserved_loras", -1),
split_linear_modules_map=split_linear_modules_map,
)
errors = trans_lora._loras_errors
if len(errors) > 0:
error_files = [msg for _ , msg in errors]
raise gr.Error("Error while loading Loras: " + ", ".join(error_files))
if trans2_lora is not None:
offload.sync_models_loras(trans_lora, trans2_lora)
seed = None if seed == -1 else seed
# negative_prompt = "" # not applicable in the inference
model_filename = get_model_filename(base_model_type)
_, _, latent_size = get_model_min_frames_and_step(model_type)
video_length = (video_length -1) // latent_size * latent_size + 1
if sliding_window_size !=0:
sliding_window_size = (sliding_window_size -1) // latent_size * latent_size + 1
if sliding_window_overlap !=0:
sliding_window_defaults = model_def.get("sliding_window_defaults", {})
if sliding_window_defaults.get("overlap_default", 0) != sliding_window_overlap:
sliding_window_overlap = (sliding_window_overlap -1) // latent_size * latent_size + 1
if sliding_window_discard_last_frames !=0:
sliding_window_discard_last_frames = sliding_window_discard_last_frames // latent_size * latent_size
current_video_length = video_length
# VAE Tiling
device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
guide_inpaint_color = model_def.get("guide_inpaint_color", 127.5)
extract_guide_from_window_start = model_def.get("extract_guide_from_window_start", False)
hunyuan_custom = "hunyuan_video_custom" in model_filename
hunyuan_custom_edit = hunyuan_custom and "edit" in model_filename
fantasy = base_model_type in ["fantasy"]
multitalk = model_def.get("multitalk_class", False)
# Speaker bounding boxes are only needed for dual-speaker audio modes ("B"/"X")
if "B" in audio_prompt_type or "X" in audio_prompt_type:
from models.wan.multitalk.multitalk import parse_speakers_locations
# NOTE(review): `error` is captured but never checked here — presumably the
# UI validates speakers_locations upstream; confirm
speakers_bboxes, error = parse_speakers_locations(speakers_locations)
else:
speakers_bboxes = None
# "L" = continue the last video: newest file from this session, else newest mp4 on disk
if "L" in image_prompt_type:
if len(file_list)>0:
video_source = file_list[-1]
else:
mp4_files = glob.glob(os.path.join(save_path, "*.mp4"))
video_source = max(mp4_files, key=os.path.getmtime) if mp4_files else None
fps = 1 if is_image else get_computed_fps(force_fps, base_model_type , video_guide, video_source )
# NOTE(review): all three names alias the same empty list here; safe only
# because they are reassigned (never mutated in place) below
control_audio_tracks = source_audio_tracks = source_audio_metadata = []
# "R": keep the control video's own audio tracks in the final output (unless
# another audio mode "ABXK" owns the soundtrack, or MMAudio will generate one)
if any_letters(audio_prompt_type, "R") and video_guide is not None and MMAudio_setting == 0 and not any_letters(audio_prompt_type, "ABXK"):
control_audio_tracks, _ = extract_audio_tracks(video_guide)
if "K" in audio_prompt_type and video_guide is not None:
try:
audio_guide = extract_audio_track_to_wav(video_guide, save_path, suffix="_control_audio")
temp_filenames_list.append(audio_guide)
except:
audio_guide = None
audio_guide2 = None
# Source-video continuation: keep its audio tracks and measure its duration
if video_source is not None:
source_audio_tracks, source_audio_metadata = extract_audio_tracks(video_source)
video_fps, _, _, video_frames_count = get_video_info(video_source)
video_source_duration = video_frames_count / video_fps
else:
video_source_duration = 0
# "T": align control inputs (audio / guide frames) to the generated part
# rather than to the absolute start of the source video
reset_control_aligment = "T" in video_prompt_type
if test_any_sliding_window(model_type) :
if video_source is not None:
# the first window overlaps the tail of the source video
current_video_length += sliding_window_overlap - 1
sliding_window = current_video_length > sliding_window_size
# cap reuse so each window still produces at least one latent of new frames
reuse_frames = min(sliding_window_size - latent_size, sliding_window_overlap)
else:
sliding_window = False
sliding_window_size = current_video_length
reuse_frames = 0
original_image_refs = image_refs
image_refs = None if image_refs is None else ([] + image_refs) # work on a copy as it is going to be modified
# image_refs = None
# nb_frames_positions= 0
# Output Video Ratio Priorities:
# Source Video or Start Image > Control Video > Image Ref (background or positioned frames only) > UI Width, Height
# Image Ref (non background and non positioned frames) are boxed in a white canvas in order to keep their own width/height ratio
frames_to_inject = []
# any_background_ref: 0 = none, 1 = first image ref is a background ref,
# 2 = every image ref is a background ref
any_background_ref = 0
if "K" in video_prompt_type:
any_background_ref = 2 if model_def.get("all_image_refs_are_background_ref", False) else 1
outpainting_dims = get_outpainting_dims(video_guide_outpainting)
fit_canvas = server_config.get("fit_canvas", 0)
fit_crop = fit_canvas == 2
# cropping is incompatible with outpainting: fall back to plain resize
if fit_crop and outpainting_dims is not None:
fit_crop = False
fit_canvas = 0
joint_pass = boost ==1 #and profile != 1 and profile != 3
# --- Step-skipping cache (TeaCache / MagCache) setup --------------------
skip_steps_cache = None if len(skip_steps_cache_type) == 0 else DynamicClass(cache_type = skip_steps_cache_type)
if skip_steps_cache != None:
skip_steps_cache.update({
"multiplier" : skip_steps_multiplier,
"start_step": int(skip_steps_start_step_perc*num_inference_steps/100)
})
model_handler.set_cache_parameters(skip_steps_cache_type, base_model_type, model_def, locals(), skip_steps_cache)
# per-model calibration constants override the generic ones when provided
if skip_steps_cache_type == "mag":
def_mag_ratios = model_def.get("magcache_ratios", None) if model_def != None else None
if def_mag_ratios is not None: skip_steps_cache.def_mag_ratios = def_mag_ratios
elif skip_steps_cache_type == "tea":
def_tea_coefficients = model_def.get("teacache_coefficients", None) if model_def != None else None
if def_tea_coefficients is not None: skip_steps_cache.coefficients = def_tea_coefficients
else:
raise Exception(f"unknown cache type {skip_steps_cache_type}")
# a single cache object is shared by both transformers
trans.cache = skip_steps_cache
if trans2 is not None: trans2.cache = skip_steps_cache
face_arc_embeds = None
src_ref_images = src_ref_masks = None
output_new_audio_data = None
output_new_audio_filepath = None
original_audio_guide = audio_guide
original_audio_guide2 = audio_guide2
audio_proj_split = None
audio_proj_full = None
# audio_scale only applies to models that expose an audio-scale control
audio_scale = audio_scale if model_def.get("audio_scale_name") else None
audio_context_lens = None
# --- Audio guide preparation -------------------------------------------
# Compute the usable duration, optionally strip background music ("V"),
# split dual speakers ("X"), then build model-specific audio embeddings.
if audio_guide != None:
from preprocessing.extract_vocals import get_vocals
import librosa
duration = librosa.get_duration(path=audio_guide)
combination_type = "add"
clean_audio_files = "V" in audio_prompt_type
if audio_guide2 is not None:
duration2 = librosa.get_duration(path=audio_guide2)
# "C" = concatenate the two speaker tracks, otherwise overlap them
if "C" in audio_prompt_type: duration += duration2
else: duration = min(duration, duration2)
combination_type = "para" if "P" in audio_prompt_type else "add"
if clean_audio_files:
# keep only the vocals of each speaker track
audio_guide = get_vocals(original_audio_guide, get_available_filename(save_path, audio_guide, "_clean", ".wav"))
audio_guide2 = get_vocals(original_audio_guide2, get_available_filename(save_path, audio_guide2, "_clean2", ".wav"))
temp_filenames_list += [audio_guide, audio_guide2]
else:
if "X" in audio_prompt_type:
# dual speaker, voice separation
from preprocessing.speakers_separator import extract_dual_audio
combination_type = "para"
if args.save_speakers:
audio_guide, audio_guide2 = "speaker1.wav", "speaker2.wav"
else:
audio_guide, audio_guide2 = get_available_filename(save_path, audio_guide, "_tmp1", ".wav"), get_available_filename(save_path, audio_guide, "_tmp2", ".wav")
temp_filenames_list += [audio_guide, audio_guide2]
if clean_audio_files:
# separate speakers on the vocals-only track for better results
clean_audio_guide = get_vocals(original_audio_guide, get_available_filename(save_path, original_audio_guide, "_clean", ".wav"))
temp_filenames_list += [clean_audio_guide]
extract_dual_audio(clean_audio_guide if clean_audio_files else original_audio_guide, audio_guide, audio_guide2)
elif clean_audio_files:
# Single Speaker
audio_guide = get_vocals(original_audio_guide, get_available_filename(save_path, audio_guide, "_clean", ".wav"))
temp_filenames_list += [audio_guide]
output_new_audio_filepath = original_audio_guide
# clamp the requested length to the audio duration (latent-grid aligned)
current_video_length = min(int(fps * duration //latent_size) * latent_size + latent_size + 1, current_video_length)
if fantasy:
from models.wan.fantasytalking.infer import parse_audio
# audio_proj_split_full, audio_context_lens_full = parse_audio(audio_guide, num_frames= max_source_video_frames, fps= fps, padded_frames_for_embeddings= (reuse_frames if reset_control_aligment else 0), device= processing_device )
if audio_scale is None:
audio_scale = 1.0
elif multitalk:
from models.wan.multitalk.multitalk import get_full_audio_embeddings
# pad audio_proj_full if aligned to beginning of window to simulate source window overlap
min_audio_duration = current_video_length/fps if reset_control_aligment else video_source_duration + current_video_length/fps
audio_proj_full, output_new_audio_data = get_full_audio_embeddings(audio_guide1 = audio_guide, audio_guide2= audio_guide2, combination_type= combination_type , num_frames= max_source_video_frames, sr= audio_sampling_rate, fps =fps, padded_frames_for_embeddings = (reuse_frames if reset_control_aligment else 0), min_audio_duration = min_audio_duration)
if output_new_audio_data is not None: # not none if modified
if clean_audio_files: # need to rebuild the sum of audios with original audio
_, output_new_audio_data = get_full_audio_embeddings(audio_guide1 = original_audio_guide, audio_guide2= original_audio_guide2, combination_type= combination_type , num_frames= max_source_video_frames, sr= audio_sampling_rate, fps =fps, padded_frames_for_embeddings = (reuse_frames if reset_control_aligment else 0), min_audio_duration = min_audio_duration, return_sum_only= True)
output_new_audio_filepath= None # need to build original speaker track if it changed size (due to padding at the end) or if it has been combined
if hunyuan_custom_edit and video_guide != None:
import cv2
cap = cv2.VideoCapture(video_guide)
length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
current_video_length = min(current_video_length, length)
seed = set_seed(seed)
torch.set_grad_enabled(False)
os.makedirs(save_path, exist_ok=True)
os.makedirs(image_save_path, exist_ok=True)
# free as much VRAM as possible before sampling
gc.collect()
torch.cuda.empty_cache()
wan_model._interrupt = False
abort = False
# bail out early if the user already cancelled this task
if gen.get("abort", False):
return
# gen["abort"] = False
gen["prompt"] = prompt
# --- Generation loop state ---------------------------------------------
repeat_no = 0
extra_generation = 0
initial_total_windows = 0
discard_last_frames = sliding_window_discard_last_frames
default_requested_frames_to_generate = current_video_length
nb_frames_positions = 0
if sliding_window:
initial_total_windows= compute_sliding_window_no(default_requested_frames_to_generate, sliding_window_size, discard_last_frames, reuse_frames)
current_video_length = sliding_window_size
else:
initial_total_windows = 1
first_window_video_length = current_video_length
original_prompts = prompts.copy()
gen["sliding_window"] = sliding_window
# Outer loop: one iteration per generated sample. The user can queue extra
# generations while this runs, hence the re-read of gen["extra_orders"].
while not abort:
extra_generation += gen.get("extra_orders",0)
gen["extra_orders"] = 0
total_generation = repeat_generation + extra_generation
gen["total_generation"] = total_generation
gen["header_text"] = ""
if repeat_no >= total_generation: break
repeat_no +=1
gen["repeat_no"] = repeat_no
# per-sample state reset
src_video = src_video2 = src_mask = src_mask2 = src_faces = sparse_video_image = full_generated_audio =None
prefix_video = pre_video_frame = None
source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window
source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before )
frames_already_processed = None
overlapped_latents = None
context_scale = None
window_no = 0
extra_windows = 0
abort_scheduled = False
guide_start_frame = 0 # pos of of first control video frame of current window (reuse_frames later than the first processed frame)
keep_frames_parsed = [] # aligned to the first control frame of current window (therefore ignore previous reuse_frames)
pre_video_guide = None # reuse_frames of previous window
image_size = default_image_size # default frame dimensions for budget until it is change due to a resize
sample_fit_canvas = fit_canvas
current_video_length = first_window_video_length
gen["extra_windows"] = 0
gen["total_windows"] = 1
gen["window_no"] = 1
input_waveform, input_waveform_sample_rate = None, 0
num_frames_generated = 0 # num of new frames created (lower than the number of frames really processed due to overlaps and discards)
requested_frames_to_generate = default_requested_frames_to_generate # num of num frames to create (if any source window this num includes also the overlapped source window frames)
cached_video_guide_processed = cached_video_mask_processed = cached_video_guide_processed2 = cached_video_mask_processed2 = None
cached_video_video_start_frame = cached_video_video_end_frame = -1
start_time = time.time()
# Optional LLM prompt enhancement (enhancer_mode 0 = enhance at generation time)
if prompt_enhancer_image_caption_model != None and prompt_enhancer !=None and len(prompt_enhancer)>0 and server_config.get("enhancer_mode", 0) == 0:
send_cmd("progress", [0, get_latest_status(state, "Enhancing Prompt")])
enhanced_prompts = process_prompt_enhancer(model_def, prompt_enhancer, original_prompts, image_start, original_image_refs, is_image, audio_only, seed )
if enhanced_prompts is not None:
print(f"Enhanced prompts: {enhanced_prompts}" )
task["prompt"] = "\n".join(["!enhanced!"] + enhanced_prompts)
send_cmd("output")
prompts = enhanced_prompts
abort = gen.get("abort", False)
# Inner loop: one iteration per sliding window of the current sample
while not abort:
# RIFLEx extends positional embeddings for long clips (setting 0 = auto
# above ~6 s at model fps, 1 = always on)
enable_RIFLEx = RIFLEx_setting == 0 and current_video_length > (6* get_model_fps(base_model_type)+1) or RIFLEx_setting == 1
# one prompt per window; the last prompt repeats if there are fewer prompts than windows
prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1]
new_extra_windows = gen.get("extra_windows",0)
gen["extra_windows"] = 0
extra_windows += new_extra_windows
requested_frames_to_generate += new_extra_windows * (sliding_window_size - discard_last_frames - reuse_frames)
sliding_window = sliding_window or extra_windows > 0
if sliding_window and window_no > 0:
# num_frames_generated -= reuse_frames
# stop when less than one latent of new frames remains to produce
if (requested_frames_to_generate - num_frames_generated) < latent_size:
break
current_video_length = min(sliding_window_size, ((requested_frames_to_generate - num_frames_generated + reuse_frames + discard_last_frames) // latent_size) * latent_size + 1 )
total_windows = initial_total_windows + extra_windows
gen["total_windows"] = total_windows
if window_no >= total_windows:
break
window_no += 1
gen["window_no"] = window_no
return_latent_slice = None
if reuse_frames > 0:
# latent slice handed to the next window: the overlap region, minus the discarded tail
return_latent_slice = slice(- max(1, (reuse_frames + discard_last_frames ) // latent_size) , None if discard_last_frames == 0 else -(discard_last_frames // latent_size) )
refresh_preview = {"image_guide" : image_guide, "image_mask" : image_mask} if image_mode >= 1 else {}
image_start_tensor = image_end_tensor = None
# First window only: seed the generation with the start image or the source video tail
if window_no == 1 and (video_source is not None or image_start is not None):
if image_start is not None:
image_start_tensor, new_height, new_width = calculate_dimensions_and_resize_image(image_start, height, width, sample_fit_canvas, fit_crop, block_size = block_size)
if fit_crop: refresh_preview["image_start"] = image_start_tensor
image_start_tensor = convert_image_to_tensor(image_start_tensor)
pre_video_guide = prefix_video = image_start_tensor.unsqueeze(1)
else:
prefix_video = preprocess_video(width=width, height=height,video_in=video_source, max_frames= parsed_keep_frames_video_source , start_frame = 0, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size = block_size )
prefix_video = prefix_video.permute(3, 0, 1, 2)
# normalize uint8 [0,255] to float [-1,1]
prefix_video = prefix_video.float().div_(127.5).sub_(1.) # c, f, h, w
if fit_crop or "L" in image_prompt_type: refresh_preview["video_source"] = convert_tensor_to_image(prefix_video, 0)
new_height, new_width = prefix_video.shape[-2:]
pre_video_guide = prefix_video[:, -reuse_frames:]
pre_video_frame = convert_tensor_to_image(prefix_video[:, -1])
source_video_overlap_frames_count = pre_video_guide.shape[1]
source_video_frames_count = prefix_video.shape[1]
if sample_fit_canvas != None:
# lock the output resolution once the source resolution is known
image_size = pre_video_guide.shape[-2:]
sample_fit_canvas = None
guide_start_frame = prefix_video.shape[1]
# Optional end image(s), one per window
if image_end is not None:
image_end_list= image_end if isinstance(image_end, list) else [image_end]
if len(image_end_list) >= window_no:
new_height, new_width = image_size
image_end_tensor, _, _ = calculate_dimensions_and_resize_image(image_end_list[window_no-1], new_height, new_width, sample_fit_canvas, fit_crop, block_size = block_size)
# image_end_tensor =image_end_list[window_no-1].resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
refresh_preview["image_end"] = image_end_tensor
image_end_tensor = convert_image_to_tensor(image_end_tensor)
image_end_list= None
# Frame bookkeeping: "aligned" positions are shifted relative to the control
# inputs when alignment to the generated part is requested ("T")
window_start_frame = guide_start_frame - (reuse_frames if window_no > 1 else source_video_overlap_frames_count)
guide_end_frame = guide_start_frame + current_video_length - (source_video_overlap_frames_count if window_no == 1 else reuse_frames)
alignment_shift = source_video_frames_count if reset_control_aligment else 0
aligned_guide_start_frame = guide_start_frame - alignment_shift
aligned_guide_end_frame = guide_end_frame - alignment_shift
aligned_window_start_frame = window_start_frame - alignment_shift
# Per-window audio slicing / embedding extraction
if audio_guide is not None and model_def.get("audio_guide_window_slicing", False):
audio_start_frame = aligned_window_start_frame
if reset_control_aligment:
audio_start_frame += source_video_overlap_frames_count
input_waveform, input_waveform_sample_rate = slice_audio_window( audio_guide, audio_start_frame, current_video_length, fps, save_path, suffix=f"_win{window_no}", )
if fantasy and audio_guide is not None:
audio_proj_split , audio_context_lens = parse_audio(audio_guide, start_frame = aligned_window_start_frame, num_frames= current_video_length, fps= fps, device= processing_device )
if multitalk:
from models.wan.multitalk.multitalk import get_window_audio_embeddings
# special treatment for start frame pos when alignement to first frame requested as otherwise the start frame number will be negative due to overlapped frames (has been previously compensated later with padding)
audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= aligned_window_start_frame + (source_video_overlap_frames_count if reset_control_aligment else 0 ), clip_length = current_video_length)
# Parse the user's frame positions for injected reference images. Done only
# once (first window of the first sample). "L"/"l" means "last frame of the
# next window"; numbers are absolute 1-based frame positions.
if repeat_no == 1 and window_no == 1 and image_refs is not None and len(image_refs) > 0:
frames_positions_list = []
if frames_positions is not None and len(frames_positions)> 0:
positions = frames_positions.replace(","," ").split(" ")
cur_end_pos = -1 + (source_video_frames_count - source_video_overlap_frames_count)
last_frame_no = requested_frames_to_generate + source_video_frames_count - source_video_overlap_frames_count
joker_used = False
project_window_no = 1
for pos in positions :
if len(pos) > 0:
if pos in ["L", "l"]:
cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length
# the final "L" may overshoot the clip: clamp it once to the very last frame
if cur_end_pos >= last_frame_no-1 and not joker_used:
joker_used = True
cur_end_pos = last_frame_no -1
project_window_no += 1
frames_positions_list.append(cur_end_pos)
# step back over the frames the next window will discard / reuse
cur_end_pos -= sliding_window_discard_last_frames + reuse_frames
else:
# absolute position, shifted past the source video when aligned to it
frames_positions_list.append(int(pos)-1 + alignment_shift)
frames_positions_list = frames_positions_list[:len(image_refs)]
nb_frames_positions = len(frames_positions_list)
if nb_frames_positions > 0:
# sparse timeline: frames_to_inject[pos] is the image forced at frame pos
frames_to_inject = [None] * (max(frames_positions_list) + 1)
for i, pos in enumerate(frames_positions_list):
frames_to_inject[pos] = image_refs[i]
video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = sparse_video_image = None
# --- Control video preprocessing for this window ------------------------
if video_guide is not None:
keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
if len(error) > 0:
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
guide_frames_extract_start = aligned_window_start_frame if extract_guide_from_window_start else aligned_guide_start_frame
extra_control_frames = model_def.get("extra_control_frames", 0)
if extra_control_frames > 0 and aligned_guide_start_frame >= extra_control_frames: guide_frames_extract_start -= extra_control_frames
# a negative start index means frames before the control video begins: keep them
keep_frames_parsed = [True] * -guide_frames_extract_start if guide_frames_extract_start <0 else []
keep_frames_parsed += keep_frames_parsed_full[max(0, guide_frames_extract_start): aligned_guide_end_frame ]
guide_frames_extract_count = len(keep_frames_parsed)
# preprocess_all: run the expensive preprocessor once over the whole guide,
# cache the result, and slice it per window further below
process_all = model_def.get("preprocess_all", False)
if process_all:
guide_slice_to_extract = guide_frames_extract_count
guide_frames_extract_count = (-guide_frames_extract_start if guide_frames_extract_start <0 else 0) + len( keep_frames_parsed_full[max(0, guide_frames_extract_start):] )
# Extract Faces to video
if "B" in video_prompt_type:
send_cmd("progress", [0, get_latest_status(state, "Extracting Face Movements")])
src_faces = extract_faces_from_video_with_mask(video_guide, video_mask, max_frames= guide_frames_extract_count, start_frame= guide_frames_extract_start, size= 512, target_fps = fps)
if src_faces is not None and src_faces.shape[1] < current_video_length:
# pad missing face frames with -1 ("no face")
src_faces = torch.cat([src_faces, torch.full( (3, current_video_length - src_faces.shape[1], 512, 512 ), -1, dtype = src_faces.dtype, device= src_faces.device) ], dim=1)
# Sparse Video to Video
sparse_video_image = None
if "R" in video_prompt_type:
sparse_video_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, target_fps = fps, return_PIL = True)
if not process_all or cached_video_video_start_frame < 0:
# Generic Video Preprocessing
process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None)
preprocess_type, preprocess_type2 = "raw", None
# first matching letter selects the primary preprocessor, second the secondary
for process_num, process_letter in enumerate( filter_letters(video_prompt_type, video_guide_processes)):
if process_num == 0:
preprocess_type = process_map_video_guide.get(process_letter, "raw")
else:
preprocess_type2 = process_map_video_guide.get(process_letter, None)
custom_preprocessor = model_def.get("custom_preprocessor", None)
if custom_preprocessor is not None:
status_info = custom_preprocessor
send_cmd("progress", [0, get_latest_status(state, status_info)])
video_guide_processed, video_guide_processed2, video_mask_processed, video_mask_processed2 = custom_preprocess_video_with_mask(model_handler, base_model_type, pre_video_guide, video_guide if sparse_video_image is None else sparse_video_image, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size = block_size, expand_scale = mask_expand, video_prompt_type= video_prompt_type)
else:
status_info = "Extracting " + processes_names[preprocess_type]
extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask])
if len(extra_process_list) == 1:
status_info += " and " + processes_names[extra_process_list[0]]
elif len(extra_process_list) == 2:
status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]]
# two control streams split the weight budget, hence the /2
context_scale = [control_net_weight /2, control_net_weight2 /2] if preprocess_type2 is not None else [control_net_weight]
if not (preprocess_type == "identity" and preprocess_type2 is None and video_mask is None):send_cmd("progress", [0, get_latest_status(state, status_info)])
inpaint_color = 0 if preprocess_type=="pose" and process_outside_mask == "inpaint" else guide_inpaint_color
video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide if sparse_video_image is None else sparse_video_image, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color, block_size = block_size, to_bbox = "H" in video_prompt_type )
if preprocess_type2 != None:
video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2, block_size = block_size, to_bbox = "H" in video_prompt_type )
if video_guide_processed is not None and sample_fit_canvas is not None:
image_size = video_guide_processed.shape[-2:]
sample_fit_canvas = None
if process_all:
cached_video_guide_processed, cached_video_mask_processed, cached_video_guide_processed2, cached_video_mask_processed2 = video_guide_processed, video_mask_processed, video_guide_processed2, video_mask_processed2
cached_video_video_start_frame = guide_frames_extract_start
if process_all:
# slice the cached full-length preprocessing result for this window
process_slice = slice(guide_frames_extract_start - cached_video_video_start_frame, guide_frames_extract_start - cached_video_video_start_frame + guide_slice_to_extract )
video_guide_processed = None if cached_video_guide_processed is None else cached_video_guide_processed[:, process_slice]
video_mask_processed = None if cached_video_mask_processed is None else cached_video_mask_processed[:, process_slice]
video_guide_processed2 = None if cached_video_guide_processed2 is None else cached_video_guide_processed2[:, process_slice]
video_mask_processed2 = None if cached_video_mask_processed2 is None else cached_video_mask_processed2[:, process_slice]
# --- Reference image processing (first window of each sample) -----------
if window_no == 1 and image_refs is not None and len(image_refs) > 0:
# output resolution not fixed yet: derive it from the first image ref
if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) :
from shared.utils.utils import get_outpainting_full_area_dimensions
w, h = image_refs[0].size
if outpainting_dims != None:
h, w = get_outpainting_full_area_dimensions(h,w, outpainting_dims)
image_size = calculate_new_dimensions(height, width, h, w, fit_canvas)
sample_fit_canvas = None
if repeat_no == 1:
if fit_crop:
# only full-frame refs (background / positioned frames) are cropped to
# the output aspect ratio; the rest keep their own ratio
if any_background_ref == 2:
end_ref_position = len(image_refs)
elif any_background_ref == 1:
end_ref_position = nb_frames_positions + 1
else:
end_ref_position = nb_frames_positions
for i, img in enumerate(image_refs[:end_ref_position]):
image_refs[i] = rescale_and_crop(img, default_image_size[1], default_image_size[0])
refresh_preview["image_refs"] = image_refs
# refs beyond the positioned frames are identity / object references
if len(image_refs) > nb_frames_positions:
src_ref_images = image_refs[nb_frames_positions:]
# "Q" = extract an ArcFace identity embedding from the last image ref
if "Q" in video_prompt_type:
from preprocessing.arc.face_encoder import FaceEncoderArcFace, get_landmarks_from_image
image_pil = src_ref_images[-1]
face_encoder = FaceEncoderArcFace()
face_encoder.init_encoder_model(processing_device)
face_arc_embeds = face_encoder(image_pil, need_proc=True, landmarks=get_landmarks_from_image(image_pil))
face_arc_embeds = face_arc_embeds.squeeze(0).cpu()
# drop the encoder immediately to give VRAM back to the generator
face_encoder = image_pil = None
gc.collect()
torch.cuda.empty_cache()
if remove_background_images_ref > 0:
send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
src_ref_images, src_ref_masks = resize_and_remove_background(src_ref_images , image_size[1], image_size[0],
remove_background_images_ref > 0, any_background_ref,
fit_into_canvas= model_def.get("fit_into_canvas_image_refs", 1),
block_size=block_size,
outpainting_dims =outpainting_dims,
background_ref_outpainted = model_def.get("background_ref_outpainted", True),
return_tensor= model_def.get("return_image_refs_tensor", False),
ignore_last_refs =model_def.get("no_processing_on_last_images_refs",0),
background_removal_color = model_def.get("background_removal_color", [255, 255, 255] ))
frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame]
# --- Assemble the final guide/mask tensors fed to the model --------------
if video_guide is not None or len(frames_to_inject_parsed) > 0 or model_def.get("forced_guide_mask_inputs", False):
any_mask = video_mask is not None or model_def.get("forced_guide_mask_inputs", False)
any_guide_padding = model_def.get("pad_guide_video", False)
dont_cat_preguide = extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None
from shared.utils.utils import prepare_video_guide_and_mask
src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed] + ([] if video_guide_processed2 is None else [video_guide_processed2]),
[video_mask_processed] + ([] if video_guide_processed2 is None else [video_mask_processed2]),
None if dont_cat_preguide else pre_video_guide,
image_size, current_video_length, latent_size,
any_mask, any_guide_padding, guide_inpaint_color,
keep_frames_parsed, frames_to_inject_parsed , outpainting_dims)
# release intermediates; only src_video(s)/src_mask(s) are needed from here on
video_guide_processed = video_guide_processed2 = video_mask_processed = video_mask_processed2 = None
if len(src_videos) == 1:
src_video, src_video2, src_mask, src_mask2 = src_videos[0], None, src_masks[0], None
else:
src_video, src_video2 = src_videos
src_mask, src_mask2 = src_masks
src_videos = src_masks = None
# control video exhausted: no new guide frames left for this window
if src_video is None or window_no >1 and src_video.shape[1] <= sliding_window_overlap and not dont_cat_preguide:
abort = True
break
if model_def.get("control_video_trim", False) :
if src_video is None:
abort = True
break
elif src_video.shape[1] < current_video_length:
# shorten this (final) window to the remaining guide frames
current_video_length = src_video.shape[1]
abort_scheduled = True
if src_faces is not None:
# keep the face track exactly as long as the guide video
if src_faces.shape[1] < src_video.shape[1]:
src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1)
else:
src_faces = src_faces[:, :src_video.shape[1]]
if video_guide is not None or len(frames_to_inject_parsed) > 0:
if args.save_masks:
# debug dumps of the exact tensors fed to the model
if src_video is not None:
save_video( src_video, "masked_frames.mp4", fps)
if any_mask: save_video( src_mask, "masks.mp4", fps, value_range=(0, 1))
if src_video2 is not None:
save_video( src_video2, "masked_frames2.mp4", fps)
if any_mask: save_video( src_mask2, "masks2.mp4", fps, value_range=(0, 1))
# refresh the UI thumbnails with the processed guide / mask / refs
if video_guide is not None:
preview_frame_no = 0 if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else (guide_start_frame - window_start_frame)
preview_frame_no = min(src_video.shape[1] -1, preview_frame_no)
refresh_preview["video_guide"] = convert_tensor_to_image(src_video, preview_frame_no)
if src_video2 is not None and not model_def.get("no_guide2_refresh", False):
refresh_preview["video_guide"] = [refresh_preview["video_guide"], convert_tensor_to_image(src_video2, preview_frame_no)]
if src_mask is not None and video_mask is not None and not model_def.get("no_mask_refresh", False):
refresh_preview["video_mask"] = convert_tensor_to_image(src_mask, preview_frame_no, mask_levels = True)
if src_ref_images is not None or nb_frames_positions:
if len(frames_to_inject_parsed):
new_image_refs = [convert_tensor_to_image(src_video, frame_no + (0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame)) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
else:
new_image_refs = []
if src_ref_images is not None:
new_image_refs += [convert_tensor_to_image(img) if torch.is_tensor(img) else img for img in src_ref_images ]
refresh_preview["image_refs"] = new_image_refs
new_image_refs = None
if len(refresh_preview) > 0:
new_inputs= locals()
new_inputs.update(refresh_preview)
update_task_thumbnails(task, new_inputs)
send_cmd("output")
# Number of latent frames that come from conditioning (source overlap on the
# first window, reused frames afterwards)
if window_no == 1:
conditioning_latents_size = ( (source_video_overlap_frames_count-1) // latent_size) + 1 if source_video_overlap_frames_count > 0 else 0
else:
conditioning_latents_size = ( (reuse_frames-1) // latent_size) + 1
status = get_latest_status(state)
gen["progress_status"] = status
progress_phase = "Generation Audio" if audio_only else "Encoding Prompt"
gen["progress_phase"] = (progress_phase , -1 )
callback = build_callback(state, trans, send_cmd, status, num_inference_steps)
progress_args = [0, merge_status_context(status, progress_phase )]
send_cmd("progress", progress_args)
# reset the step-skipping cache counters before each sampling run
if skip_steps_cache != None:
skip_steps_cache.update({
"num_steps" : num_inference_steps,
"skipped_steps" : 0,
"previous_residual": None,
"previous_modulated_input": None,
})
# samples = torch.empty( (1,2)) #for testing
# if False:
def set_header_text(txt):
# update the UI header line and push the refresh to the client
gen["header_text"] = txt
send_cmd("output")
try:
input_video_for_model = pre_video_guide
prefix_frames_count = source_video_overlap_frames_count if window_no <= 1 else reuse_frames
samples = wan_model.generate(
input_prompt = prompt,
image_start = image_start_tensor,
image_end = image_end_tensor,
input_frames = src_video,
input_frames2 = src_video2,
input_ref_images= src_ref_images,
input_ref_masks = src_ref_masks,
input_masks = src_mask,
input_masks2 = src_mask2,
input_video= input_video_for_model,
input_faces = src_faces,
input_custom = custom_guide,
denoising_strength=denoising_strength,
masking_strength=masking_strength,
prefix_frames_count = prefix_frames_count,
frame_num= (current_video_length // latent_size)* latent_size + 1,
batch_size = batch_size,
height = image_size[0],
width = image_size[1],
fit_into_canvas = fit_canvas,
shift=flow_shift,
sample_solver=sample_solver,
sampling_steps=num_inference_steps,
guide_scale=guidance_scale,
guide2_scale = guidance2_scale,
guide3_scale = guidance3_scale,
switch_threshold = switch_threshold,
switch2_threshold = switch_threshold2,
guide_phases= guidance_phases,
model_switch_phase = model_switch_phase,
embedded_guidance_scale=embedded_guidance_scale,
n_prompt=negative_prompt,
seed=seed,
callback=callback,
enable_RIFLEx = enable_RIFLEx,
VAE_tile_size = VAE_tile_size,
joint_pass = joint_pass,
slg_layers = slg_layers,
slg_start = slg_start_perc/100,
slg_end = slg_end_perc/100,
apg_switch = apg_switch,
cfg_star_switch = cfg_star_switch,
cfg_zero_step = cfg_zero_step,
alt_guide_scale= alt_guidance_scale,
audio_cfg_scale= audio_guidance_scale,
input_waveform=input_waveform,
input_waveform_sample_rate=input_waveform_sample_rate,
audio_guide=audio_guide,
audio_guide2=audio_guide2,
audio_proj= audio_proj_split,
audio_scale= audio_scale,
audio_context_lens= audio_context_lens,
context_scale = context_scale,
control_scale_alt = control_net_weight_alt,
motion_amplitude = motion_amplitude,
model_mode = model_mode,
causal_block_size = 5,
causal_attention = True,
fps = fps,
overlapped_latents = overlapped_latents,
return_latent_slice= return_latent_slice,
overlap_noise = sliding_window_overlap_noise,
overlap_size = sliding_window_overlap,
color_correction_strength = sliding_window_color_correction_strength,
conditioning_latents_size = conditioning_latents_size,
keep_frames_parsed = keep_frames_parsed,
model_filename = model_filename,
model_type = base_model_type,
loras_slists = loras_slists,
NAG_scale = NAG_scale,
NAG_tau = NAG_tau,
NAG_alpha = NAG_alpha,
speakers_bboxes =speakers_bboxes,
image_mode = image_mode,
video_prompt_type= video_prompt_type,
window_no = window_no,
offloadobj = offloadobj,
set_header_text= set_header_text,
pre_video_frame = pre_video_frame,
prefix_video = prefix_video,
original_input_ref_images = original_image_refs[nb_frames_positions:] if original_image_refs is not None else [],
image_refs_relative_size = image_refs_relative_size,
outpainting_dims = outpainting_dims,
face_arc_embeds = face_arc_embeds,
exaggeration=exaggeration,
pace=pace,
temperature=temperature,
window_start_frame_no = window_start_frame,
input_video_strength = input_video_strength,
)
except Exception as e:
if len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0:
cleanup_temp_audio_files(control_audio_tracks + source_audio_tracks)
remove_temp_filenames(temp_filenames_list)
clear_gen_cache()
offloadobj.unload_all()
trans.cache = None
if trans2 is not None:
trans2.cache = None
offload.unload_loras_from_model(trans_lora)
if trans2_lora is not None:
offload.unload_loras_from_model(trans2_lora)
skip_steps_cache = None
# if compile:
# cache_size = torch._dynamo.config.cache_size_limit
# torch.compiler.reset()
# torch._dynamo.config.cache_size_limit = cache_size
gc.collect()
torch.cuda.empty_cache()
s = str(e)
keyword_list = {"CUDA out of memory" : "VRAM", "Tried to allocate":"VRAM", "CUDA error: out of memory": "RAM", "CUDA error: too many resources requested": "RAM"}
crash_type = ""
for keyword, tp in keyword_list.items():
if keyword in s:
crash_type = tp
break
state["prompt"] = ""
if crash_type == "VRAM":
new_error = "The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames."
elif crash_type == "RAM":
new_error = "The generation of the video has encountered an error: it is likely that you have unsufficient RAM and / or Reserved RAM allocation should be reduced using 'perc_reserved_mem_max' or using a different Profile."
else:
new_error = gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
tb = traceback.format_exc().split('\n')[:-1]
print('\n'.join(tb))
send_cmd("error", new_error)
clear_status(state)
return
if skip_steps_cache != None :
skip_steps_cache.previous_residual = None
skip_steps_cache.previous_modulated_input = None
print(f"Skipped Steps:{skip_steps_cache.skipped_steps}/{skip_steps_cache.num_steps}" )
generated_audio = None
BGRA_frames = None
post_decode_pre_trim = 0
output_audio_sampling_rate= audio_sampling_rate
if samples != None:
if isinstance(samples, dict):
overlapped_latents = samples.get("latent_slice", None)
BGRA_frames = samples.get("BGRA_frames", None)
generated_audio = samples.get("audio", generated_audio)
if generated_audio is not None and model_def.get("output_audio_is_input_audio", False) and output_new_audio_filepath is not None:
generated_audio = None
else:
output_new_audio_filepath = None
output_audio_sampling_rate = samples.get("audio_sampling_rate", audio_sampling_rate)
post_decode_pre_trim = samples.get("post_decode_pre_trim", 0)
samples = samples.get("x", None)
if samples is not None:
samples = samples.to("cpu")
clear_gen_cache()
offloadobj.unload_all()
gc.collect()
torch.cuda.empty_cache()
# time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
# save_prompt = "_in_" + original_prompts[0]
# file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(save_prompt[:50]).strip()}.mp4"
# sample = samples.cpu()
# cache_video( tensor=sample[None].clone(), save_file=os.path.join(save_path, file_name), fps=16, nrow=1, normalize=True, value_range=(-1, 1))
if samples == None:
abort = True
state["prompt"] = ""
send_cmd("output")
else:
sample = samples.cpu()
abort = abort_scheduled or not (is_image or audio_only) and sample.shape[1] < current_video_length
# if True: # for testing
# torch.save(sample, "output.pt")
# else:
# sample =torch.load("output.pt")
if post_decode_pre_trim > 0 :
sample = sample[:, post_decode_pre_trim:]
if gen.get("extra_windows",0) > 0:
sliding_window = True
if sliding_window :
# guide_start_frame = guide_end_frame
guide_start_frame += current_video_length
if discard_last_frames > 0:
sample = sample[: , :-discard_last_frames]
guide_start_frame -= discard_last_frames
if generated_audio is not None:
generated_audio = truncate_audio( generated_audio, 0, discard_last_frames, fps, output_audio_sampling_rate,)
if reuse_frames == 0:
pre_video_guide = sample[:,max_source_video_frames :].clone()
else:
pre_video_guide = sample[:, -reuse_frames:].clone()
if prefix_video != None and window_no == 1 :
if prefix_video.shape[1] > 1:
# remove sliding window overlapped frames at the beginning of the generation
sample = torch.cat([ prefix_video, sample[: , source_video_overlap_frames_count:]], dim = 1)
else:
# remove source video overlapped frames at the beginning of the generation if there is only a start frame
sample = torch.cat([ prefix_video[:, :-source_video_overlap_frames_count], sample], dim = 1)
guide_start_frame -= source_video_overlap_frames_count
if generated_audio is not None:
generated_audio = truncate_audio( generated_audio, source_video_overlap_frames_count, 0, fps, output_audio_sampling_rate,)
elif sliding_window and window_no > 1 and reuse_frames > 0:
# remove sliding window overlapped frames at the beginning of the generation
sample = sample[: , reuse_frames:]
guide_start_frame -= reuse_frames
if generated_audio is not None:
generated_audio = truncate_audio( generated_audio, reuse_frames, 0, fps, output_audio_sampling_rate,)
num_frames_generated = guide_start_frame - (source_video_frames_count - source_video_overlap_frames_count)
if generated_audio is not None:
full_generated_audio = generated_audio if full_generated_audio is None else np.concatenate([full_generated_audio, generated_audio], axis=0)
output_new_audio_data = full_generated_audio
if len(temporal_upsampling) > 0 or len(spatial_upsampling) > 0 and not "vae2" in spatial_upsampling:
send_cmd("progress", [0, get_latest_status(state,"Upsampling")])
output_fps = fps
if len(temporal_upsampling) > 0:
sample, previous_last_frame, output_fps = perform_temporal_upsampling(sample, previous_last_frame if sliding_window and window_no > 1 else None, temporal_upsampling, fps)
if len(spatial_upsampling) > 0:
sample = perform_spatial_upsampling(sample, spatial_upsampling )
if film_grain_intensity> 0:
from postprocessing.film_grain import add_film_grain
sample = add_film_grain(sample, film_grain_intensity, film_grain_saturation)
if sliding_window :
if frames_already_processed == None:
frames_already_processed = sample
else:
sample = torch.cat([frames_already_processed, sample], dim=1)
frames_already_processed = sample
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
save_prompt = original_prompts[0]
if audio_only:
extension = "wav"
elif is_image:
extension = "jpg"
else:
container = server_config.get("video_container", "mp4")
extension = container
inputs = get_function_arguments(generate_video, locals())
if len(output_filename):
from shared.utils.filename_formatter import FilenameFormatter
file_name = FilenameFormatter.format_filename(output_filename, inputs)
file_name = f"{sanitize_file_name(truncate_for_filesystem(os.path.splitext(os.path.basename(file_name))[0])).strip()}.{extension}"
file_name = os.path.basename(get_available_filename(save_path, file_name))
else:
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(truncate_for_filesystem(save_prompt)).strip()}.{extension}"
video_path = os.path.join(save_path, file_name)
mmaudio_enabled, mmaudio_mode, mmaudio_persistence, mmaudio_model_name, mmaudio_model_path = get_mmaudio_settings(server_config)
any_mmaudio = MMAudio_setting != 0 and mmaudio_enabled and sample.shape[1] >=fps
if BGRA_frames is not None:
from models.wan.alpha.utils import write_zip_file
write_zip_file(os.path.splitext(video_path)[0] + ".zip", BGRA_frames)
BGRA_frames = None
if audio_only:
audio_path = os.path.join(image_save_path, file_name)
write_wav_file(audio_path, sample.squeeze(0), output_audio_sampling_rate)
video_path= audio_path
elif is_image:
image_path = os.path.join(image_save_path, file_name)
sample = sample.transpose(1,0) #c f h w -> f c h w
new_image_path = []
for no, img in enumerate(sample):
img_path = os.path.splitext(image_path)[0] + ("" if no==0 else f"_{no}") + ".jpg"
new_image_path.append(save_image(img, save_file = img_path, quality = server_config.get("image_output_codec", None)))
video_path= new_image_path
elif len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0 or output_new_audio_filepath is not None or any_mmaudio or output_new_audio_data is not None or audio_source is not None:
video_path = os.path.join(save_path, file_name)
save_path_tmp = video_path.rsplit('.', 1)[0] + f"_tmp.{container}"
save_video( tensor=sample[None], save_file=save_path_tmp, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1), codec_type = server_config.get("video_output_codec", None), container=container)
output_new_audio_temp_filepath = None
new_audio_added_from_audio_start = reset_control_aligment or full_generated_audio is not None # if not beginning of audio will be skipped
source_audio_duration = source_video_frames_count / fps
if any_mmaudio:
send_cmd("progress", [0, get_latest_status(state,"MMAudio Soundtrack Generation")])
from postprocessing.mmaudio.mmaudio import video_to_audio
output_new_audio_filepath = output_new_audio_temp_filepath = get_available_filename(save_path, f"tmp{time_flag}.wav" )
video_to_audio(save_path_tmp, prompt = MMAudio_prompt, negative_prompt = MMAudio_neg_prompt, seed = seed, num_steps = 25, cfg_strength = 4.5, duration= sample.shape[1] /fps, save_path = output_new_audio_filepath, persistent_models = mmaudio_persistence == MMAUDIO_PERSIST_RAM, audio_file_only = True, verboseLevel = verbose_level, model_name = mmaudio_model_name, model_path = mmaudio_model_path)
new_audio_added_from_audio_start = False
elif audio_source is not None:
output_new_audio_filepath = audio_source
new_audio_added_from_audio_start = True
elif output_new_audio_data is not None:
output_new_audio_filepath = output_new_audio_temp_filepath = get_available_filename(save_path, f"tmp{time_flag}.wav" )
write_wav_file(output_new_audio_filepath, output_new_audio_data, output_audio_sampling_rate)
if output_new_audio_filepath is not None:
new_audio_tracks = [output_new_audio_filepath]
else:
new_audio_tracks = control_audio_tracks
combine_and_concatenate_video_with_audio_tracks(
video_path,
save_path_tmp,
source_audio_tracks,
new_audio_tracks,
source_audio_duration,
output_audio_sampling_rate,
new_audio_from_start=new_audio_added_from_audio_start,
source_audio_metadata=source_audio_metadata,
verbose=verbose_level >= 2,
)
os.remove(save_path_tmp)
if output_new_audio_temp_filepath is not None: os.remove(output_new_audio_temp_filepath)
else:
save_video( tensor=sample[None], save_file=video_path, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1), codec_type= server_config.get("video_output_codec", None), container= container)
end_time = time.time()
inputs.pop("send_cmd")
inputs.pop("task")
inputs.pop("mode")
inputs["model_type"] = model_type
inputs["model_filename"] = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy)
if is_image:
inputs["image_quality"] = server_config.get("image_output_codec", None)
else:
inputs["video_quality"] = server_config.get("video_output_codec", None)
modules = get_model_recursive_prop(model_type, "modules", return_list= True)
if len(modules) > 0 : inputs["modules"] = modules
if len(transformer_loras_filenames) > 0:
inputs.update({
"transformer_loras_filenames" : transformer_loras_filenames,
"transformer_loras_multipliers" : transformer_loras_multipliers
})
embedded_images = {img_name: inputs[img_name] for img_name in image_names_list } if server_config.get("embed_source_images", False) else None
configs = prepare_inputs_dict("metadata", inputs, model_type)
if sliding_window: configs["window_no"] = window_no
configs["prompt"] = "\n".join(original_prompts)
if prompt_enhancer_image_caption_model != None and prompt_enhancer !=None and len(prompt_enhancer)>0:
configs["enhanced_prompt"] = "\n".join(prompts)
configs["generation_time"] = round(end_time-start_time)
# if sample_is_image: configs["is_image"] = True
metadata_choice = server_config.get("metadata_type","metadata")
video_path = [video_path] if not isinstance(video_path, list) else video_path
for no, path in enumerate(video_path):
if metadata_choice == "json":
json_path = os.path.splitext(path)[0] + ".json"
with open(json_path, 'w') as f:
json.dump(configs, f, indent=4)
elif metadata_choice == "metadata":
if audio_only:
save_audio_metadata(path, configs)
if is_image:
save_image_metadata(path, configs)
else:
save_video_metadata(path, configs, embedded_images)
if audio_only:
print(f"New audio file saved to Path: "+ path)
elif is_image:
print(f"New image saved to Path: "+ path)
else:
print(f"New video saved to Path: "+ path)
with lock:
if audio_only:
audio_file_list.append(path)
audio_file_settings_list.append(configs if no > 0 else configs.copy())
else:
file_list.append(path)
file_settings_list.append(configs if no > 0 else configs.copy())
gen["last_was_audio"] = audio_only
embedded_images = None
# Play notification sound for single video
try:
if server_config.get("notification_sound_enabled", 0):
volume = server_config.get("notification_sound_volume", 50)
notification_sound.notify_video_completion(
video_path=video_path,
volume=volume
)
except Exception as e:
print(f"Error playing notification sound for individual video: {e}")
send_cmd("output")
seed = set_seed(-1)
clear_status(state)
trans.cache = None
offload.unload_loras_from_model(trans_lora)
if not trans2_lora is None:
offload.unload_loras_from_model(trans2_lora)
if not trans2 is None:
trans2.cache = None
if len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0:
cleanup_temp_audio_files(control_audio_tracks + source_audio_tracks)
remove_temp_filenames(temp_filenames_list)
def prepare_generate_video(state):
    """Toggle the generate/abort controls based on the last validation result.

    Returns updates for (generate button, abort button, status column, extra).
    """
    validated = state.get("validate_success", 0) == 1
    return (
        gr.Button(visible=not validated),
        gr.Button(visible=validated),
        gr.Column(visible=validated),
        gr.update(visible=False),
    )
def generate_preview(model_type, payload):
    """Convert intermediate denoising latents into a small RGB preview image.

    payload is either a latents tensor or a dict {"latents": tensor, ...meta}.
    Returns a PIL.Image (a ~200px-tall strip of up to 4 sampled frames), or
    None when no preview can be produced.
    """
    import einops
    if payload is None:
        return None
    if isinstance(payload, dict):
        # Everything except the tensor itself is passed to the model handler as metadata.
        meta = {k: v for k, v in payload.items() if k != "latents"}
        latents = payload.get("latents")
    else:
        meta = {}
        latents = payload
    if latents is None:
        return None
    # latents shape should be C, T, H, W (no batch)
    if not torch.is_tensor(latents):
        return None
    model_handler = get_model_handler(model_type)
    base_model_type = get_base_model_type(model_type)
    # A handler may provide its own preview routine; fall back to the generic
    # RGB-factor projection below when it is absent or declines (returns None).
    custom_preview = getattr(model_handler, "preview_latents", None)
    if callable(custom_preview):
        preview = custom_preview(base_model_type, latents, meta)
        if preview is not None:
            return preview
    if hasattr(model_handler, "get_rgb_factors"):
        latent_rgb_factors, latent_rgb_factors_bias = model_handler.get_rgb_factors(base_model_type )
    else:
        return None
    if latent_rgb_factors is None: return None
    latents = latents.unsqueeze(0)
    nb_latents = latents.shape[2]
    # Pick at most 4 evenly spaced latent frames (float stride, truncated index).
    latents_to_preview = 4
    latents_to_preview = min(nb_latents, latents_to_preview)
    skip_latent = nb_latents / latents_to_preview
    latent_no = 0
    selected_latents = []
    while latent_no < nb_latents:
        selected_latents.append( latents[:, : , int(latent_no): int(latent_no)+1])
        latent_no += skip_latent
    latents = torch.cat(selected_latents, dim = 2)
    # Project latent channels to RGB via a per-model linear map (1x1x1 conv3d).
    weight = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None]
    bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
    images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1)
    # Map the assumed [-1, 1] range to [0, 255].
    images = images.add_(1.0).mul_(127.5)
    images = images.detach().cpu()
    if images.dtype == torch.bfloat16:
        # numpy has no bfloat16; go through float16 first.
        images = images.to(torch.float16)
    images = images.numpy().clip(0, 255).astype(np.uint8)
    # Lay the sampled frames out side by side: b c t h w -> (b h) (t w) c.
    images = einops.rearrange(images, 'b c t h w -> (b h) (t w) c')
    h, w, _ = images.shape
    scale = 200 / h
    images= Image.fromarray(images)
    images = images.resize(( int(w*scale),int(h*scale)), resample=Image.Resampling.BILINEAR)
    return images
def process_tasks(state):
    """Main generation loop (Gradio generator handler).

    Drains the queued tasks one by one, running generate_video in a worker
    thread and relaying its commands (progress, preview, output, error) to
    the UI; every `yield` triggers a UI refresh.
    """
    from shared.utils.thread_utils import AsyncStream, async_run
    gen = get_gen_info(state)
    queue = gen.get("queue", [])
    progress = None
    if len(queue) == 0:
        gen["status_display"] = False
        return
    with lock:
        gen = get_gen_info(state)
        clear_file_list = server_config.get("clear_file_list", 0)
        def truncate_list(file_list, file_settings_list, choice):
            # Keep only the most recent `clear_file_list` gallery entries and
            # shift the selected index accordingly.
            if clear_file_list > 0:
                file_list_current_size = len(file_list)
                keep_file_from = max(file_list_current_size - clear_file_list, 0)
                files_removed = keep_file_from
                choice = max(choice- files_removed, 0)
                file_list = file_list[ keep_file_from: ]
                file_settings_list = file_settings_list[ keep_file_from: ]
            else:
                # NOTE(review): only file_list is emptied here; file_settings_list
                # is returned unchanged — confirm whether it should be cleared too.
                file_list = []
                choice = 0
            return file_list, file_settings_list, choice
        file_list = gen.get("file_list", [])
        file_settings_list = gen.get("file_settings_list", [])
        choice = gen.get("selected",0)
        gen["file_list"], gen["file_settings_list"], gen["selected"] = truncate_list(file_list, file_settings_list, choice)
        audio_file_list = gen.get("audio_file_list", [])
        audio_file_settings_list = gen.get("audio_file_settings_list", [])
        audio_choice = gen.get("audio_selected",0)
        gen["audio_file_list"], gen["audio_file_settings_list"], gen["audio_selected"] = truncate_list(audio_file_list, audio_file_settings_list, audio_choice)
    # Spin until no other process owns the generation slot, then claim it.
    while True:
        with gen_lock:
            process_status = gen.get("process_status", None)
            if process_status is None or process_status == "process:main":
                gen["process_status"] = "process:main"
                break
        time.sleep(0.1)
    def release_gen():
        # Hand the generation slot to a pending requester, or free it.
        with gen_lock:
            process_status = gen.get("process_status", None)
            if process_status.startswith("request:"):
                gen["process_status"] = "process:" + process_status[len("request:"):]
            else:
                gen["process_status"] = None
    start_time = time.time()
    global gen_in_progress
    gen_in_progress = True
    gen["in_progress"] = True
    gen["preview"] = None
    gen["status"] = "Generating Video"
    gen["header_text"] = ""
    yield time.time(), time.time()
    prompt_no = 0
    while len(queue) > 0:
        # Pause processing while the user edits the current task in the UI.
        paused_for_edit = False
        while gen.get("queue_paused_for_edit", False):
            if not paused_for_edit:
                gr.Info("Queue Paused until Current Task Edition is Done")
                gen["status"] = "Queue paused for editing..."
                yield time.time(), time.time()
                paused_for_edit = True
            time.sleep(0.5)
        if paused_for_edit:
            gen["status"] = "Resuming queue processing..."
            yield time.time(), time.time()
        prompt_no += 1
        gen["prompt_no"] = prompt_no
        task = None
        with lock:
            if len(queue) > 0:
                task = queue[0]
        if task is None:
            break
        task_id = task["id"]
        params = task['params']
        # These keys are UI-only and must not reach generate_video.
        for key in ["model_filename", "lset_name"]:
            params.pop(key, None)
        com_stream = AsyncStream()
        send_cmd = com_stream.output_queue.push
        def generate_video_error_handler():
            # Runs in a worker thread: backfills missing/legacy params, then
            # forwards any exception to the UI through the command stream.
            try:
                import inspect
                model_type = params.get('model_type')
                known_defaults = {
                    'image_refs_relative_size': 50,
                }
                for arg_name, default_value in known_defaults.items():
                    if arg_name not in params:
                        print(f"Warning: Missing argument '{arg_name}' in loaded task. Applying default value: {default_value}")
                        params[arg_name] = default_value
                if model_type:
                    # Fill any argument generate_video expects but the saved task lacks.
                    default_settings = get_default_settings(model_type)
                    expected_args = inspect.signature(generate_video).parameters.keys()
                    for arg_name in expected_args:
                        if arg_name not in params and arg_name in default_settings:
                            params[arg_name] = default_settings[arg_name]
                plugin_data = task.pop('plugin_data', {})
                generate_video(task, send_cmd, plugin_data=plugin_data, **params)
            except Exception as e:
                tb = traceback.format_exc().split('\n')[:-1]
                print('\n'.join(tb))
                send_cmd("error",str(e))
            finally:
                # Always unblock the relay loop below.
                send_cmd("exit", None)
        async_run(generate_video_error_handler)
        # Relay worker commands to the UI until the task signals "exit".
        while True:
            cmd, data = com_stream.output_queue.next()
            if cmd == "exit":
                break
            elif cmd == "info":
                gr.Info(data)
            elif cmd == "error":
                queue.clear()
                gen["prompts_max"] = 0
                gen["prompt"] = ""
                gen["status_display"] = False
                release_gen()
                raise gr.Error(data, print_exception= False, duration = 0)
            elif cmd == "status":
                gen["status"] = data
            elif cmd == "output":
                gen["preview"] = None
                yield time.time() , time.time()
            elif cmd == "progress":
                gen["progress_args"] = data
            elif cmd == "preview":
                # Make sure the latents are fully computed before reading them.
                torch.cuda.current_stream().synchronize()
                preview= None if data== None else generate_preview(params["model_type"], data)
                gen["preview"] = preview
                yield time.time() , gr.Text()
            else:
                release_gen()
                raise Exception(f"unknown command {cmd}")
        abort = gen.get("abort", False)
        if abort:
            gen["abort"] = False
            # NOTE(review): this assigns a 2-tuple to gen["status"] — probably
            # meant to be a single string; confirm what the UI expects.
            status = "Video Generation Aborted", "Video Generation Aborted"
            yield time.time() , time.time()
            gen["status"] = status
        with lock:
            # Remove the finished (or aborted) task from the shared queue.
            queue[:] = [item for item in queue if item['id'] != task_id]
            update_global_queue_ref(queue)
    gen["prompts_max"] = 0
    gen["prompt"] = ""
    end_time = time.time()
    if abort:
        status = f"Video generation was aborted. Total Generation Time: {format_time(end_time-start_time)}"
    else:
        status = f"Total Generation Time: {format_time(end_time-start_time)}"
    try:
        if server_config.get("notification_sound_enabled", 1):
            volume = server_config.get("notification_sound_volume", 50)
            notification_sound.notify_video_completion(volume=volume)
    except Exception as e:
        # Best-effort: a failing sound device must not kill the handler.
        print(f"Error playing notification sound: {e}")
    gen["status"] = status
    gen["status_display"] = False
    release_gen()
def validate_task(task, state):
    """Validate a queued task's settings.

    Merges the task parameters over the primary settings, runs the shared
    settings validator, and returns the resulting input dict — or None when
    the task is invalid.
    """
    task_params = task.get('params', {})
    model_type = task_params.get('model_type')
    if not model_type:
        print(" [SKIP] No model_type specified")
        return None
    merged = primary_settings.copy()
    merged.update(task_params)
    merged['prompt'] = task.get('prompt', '')
    merged.setdefault('mode', "")
    overrides, _, _, _ = validate_settings(state, model_type, single_prompt=True, inputs=merged)
    if overrides is None:
        return None
    merged.update(overrides)
    return merged
def process_tasks_cli(queue, state):
    """Process queue tasks with console output for CLI mode. Returns True on success.

    Mirrors process_tasks but renders progress as overwritable console status
    lines (carriage-return rewrites) instead of Gradio UI updates.
    """
    from shared.utils.thread_utils import AsyncStream, async_run
    import inspect
    gen = get_gen_info(state)
    total_tasks = len(queue)
    completed = 0
    skipped = 0
    start_time = time.time()
    for task_idx, task in enumerate(queue):
        task_no = task_idx + 1
        prompt_preview = (task.get('prompt', '') or '')[:60]
        print(f"\n[Task {task_no}/{total_tasks}] {prompt_preview}...")
        # Validate task settings before processing
        validated_params = validate_task(task, state)
        if validated_params is None:
            print(f" [SKIP] Task {task_no} failed validation")
            skipped += 1
            continue
        # Update gen state for this task
        gen["prompt_no"] = task_no
        gen["prompts_max"] = total_tasks
        params = validated_params.copy()
        params['state'] = state
        com_stream = AsyncStream()
        send_cmd = com_stream.output_queue.push
        def make_error_handler(task, params, send_cmd):
            # Factory binds the loop variables so each worker closes over its
            # own task/params (avoids the late-binding-closure pitfall).
            def error_handler():
                try:
                    # Filter to only valid generate_video params
                    expected_args = set(inspect.signature(generate_video).parameters.keys())
                    filtered_params = {k: v for k, v in params.items() if k in expected_args}
                    plugin_data = task.get('plugin_data', {})
                    generate_video(task, send_cmd, plugin_data=plugin_data, **filtered_params)
                except Exception as e:
                    print(f"\n [ERROR] {e}")
                    traceback.print_exc()
                    send_cmd("error", str(e))
                finally:
                    # Always unblock the consumer loop below.
                    send_cmd("exit", None)
            return error_handler
        async_run(make_error_handler(task, params, send_cmd))
        # Process output stream
        task_error = False
        last_msg_len = 0
        in_status_line = False  # Track if we're in an overwritable line
        while True:
            cmd, data = com_stream.output_queue.next()
            if cmd == "exit":
                if in_status_line:
                    print()  # End the status line
                break
            elif cmd == "error":
                print(f"\n [ERROR] {data}")
                in_status_line = False
                task_error = True
            elif cmd == "progress":
                # data is [(step, total), msg] or [step_or_msg, msg].
                if isinstance(data, list) and len(data) >= 2:
                    if isinstance(data[0], tuple):
                        step, total = data[0]
                        msg = data[1] if len(data) > 1 else ""
                    else:
                        step, msg = 0, data[1] if len(data) > 1 else str(data[0])
                        total = 1
                    status_line = f"\r [{step}/{total}] {msg}"
                    # Pad to clear previous longer messages
                    print(status_line.ljust(max(last_msg_len, len(status_line))), end="", flush=True)
                    last_msg_len = len(status_line)
                    in_status_line = True
            elif cmd == "status":
                # "Loading..." messages are followed by external library output, so end with newline
                if "Loading" in str(data):
                    print(data)
                    in_status_line = False
                    last_msg_len = 0
                else:
                    status_line = f"\r {data}"
                    print(status_line.ljust(max(last_msg_len, len(status_line))), end="", flush=True)
                    last_msg_len = len(status_line)
                    in_status_line = True
            elif cmd == "output":
                # "output" is used for UI refresh, not just video saves - don't print anything
                pass
            elif cmd == "info":
                print(f"\n [INFO] {data}")
                in_status_line = False
        if not task_error:
            completed += 1
            print(f"\n Task {task_no} completed")
    elapsed = time.time() - start_time
    print(f"\n{'='*50}")
    summary = f"Queue completed: {completed}/{total_tasks} tasks in {format_time(elapsed)}"
    if skipped > 0:
        summary += f" ({skipped} skipped)"
    print(summary)
    return completed == (total_tasks - skipped)
def get_generation_status(prompt_no, prompts_max, repeat_no, repeat_max, window_no, total_windows):
    """Build the human-readable progress label for the current generation.

    Each counter pair is included only when there is more than one of it:
    prompts, samples (repeats) and sliding windows.
    """
    parts = []
    if prompts_max != 1:
        parts.append(f"Prompt {prompt_no}/{prompts_max}")
    if repeat_max > 1:
        parts.append(f"Sample {repeat_no}/{repeat_max}")
    if total_windows > 1:
        parts.append(f"Sliding Window {window_no}/{total_windows}")
    return ", ".join(parts)
# Monotonic counter used to signal the UI that cached status must be re-rendered.
refresh_id = 0
def get_new_refresh_id():
    """Increment the global refresh counter and return the fresh id."""
    global refresh_id
    refresh_id += 1
    return refresh_id
def merge_status_context(status="", context=""):
    """Join a generation status label with a phase/context label.

    When the context embeds a time suffix after '|', the two segments are
    re-trimmed so the merged string stays tidy.
    """
    if not status:
        return context
    if not context:
        return status
    if "|" not in context:
        return f"{status} - {context}"
    segments = context.split("|")
    return f"{status} - {segments[0].strip()} | {segments[1].strip()}"
def clear_status(state):
    """Reset the per-generation window/sample counters in the state."""
    gen = get_gen_info(state)
    gen.update({
        "extra_windows": 0,
        "total_windows": 1,
        "window_no": 1,
        "extra_orders": 0,
        "repeat_no": 0,
        "total_generation": 0,
    })
def get_latest_status(state, context=""):
    """Compose the full progress string for the current generation state."""
    gen = get_gen_info(state)
    # Extra user-requested samples/windows extend the announced totals.
    samples_total = gen.get("total_generation", 1) + gen.get("extra_orders", 0)
    windows_total = gen.get("total_windows", 0) + gen.get("extra_windows", 0)
    status = get_generation_status(
        gen["prompt_no"],
        gen.get("prompts_max", 0),
        gen.get("repeat_no", 0),
        samples_total,
        gen.get("window_no", 0),
        windows_total,
    )
    return merge_status_context(status, context)
def update_status(state):
    """Refresh the cached progress label and bump the UI refresh id."""
    gen = get_gen_info(state)
    gen.update(
        progress_status=get_latest_status(state),
        refresh=get_new_refresh_id(),
    )
def one_more_sample(state):
    """Schedule one extra sample generation for the current prompt."""
    gen = get_gen_info(state)
    extra = gen.get("extra_orders", 0) + 1
    gen["extra_orders"] = extra
    if not gen.get("in_progress", False):
        return state
    planned_total = gen.get("total_generation", 0) + extra
    gen["progress_status"] = get_latest_status(state)
    gen["refresh"] = get_new_refresh_id()
    gr.Info(f"An extra sample generation is planned for a total of {planned_total} samples for this prompt")
    return state
def one_more_window(state):
    """Schedule one extra sliding window for the current sample."""
    gen = get_gen_info(state)
    extra = gen.get("extra_windows", 0) + 1
    gen["extra_windows"] = extra
    if not gen.get("in_progress", False):
        return state
    planned_total = gen.get("total_windows", 0) + extra
    gen["progress_status"] = get_latest_status(state)
    gen["refresh"] = get_new_refresh_id()
    gr.Info(f"An extra window generation is planned for a total of {planned_total} videos for this sample")
    return state
def get_new_preset_msg(advanced = True):
    """Placeholder text shown in the preset dropdown (wording depends on UI mode)."""
    return (
        "Enter here a Name for a Lora Preset or a Settings or Choose one"
        if advanced
        else "Choose a Lora Preset or a Settings file in this List"
    )
def compute_lset_choices(model_type, loras_presets):
    """Build the (label, value) choice list for the preset dropdown.

    Three groups, each preceded by a line-drawing section header whose value
    starts with '>': accelerator profiles found on disk (only when a
    model_type is given), the user's settings files, and .lset lora presets.

    Args:
        model_type: model identifier used to resolve the profiles directories,
            or None to skip the on-disk profiles scan.
        loras_presets: iterable of preset filenames (*.lset and settings files).

    Returns:
        list of (display_label, value) tuples suitable for a dropdown.
    """
    global_list = []
    if model_type is not None:
        top_dir = "profiles"  # get_lora_dir(model_type)
        model_def = get_model_def(model_type)
        settings_dir = get_model_recursive_prop(model_type, "profiles_dir", return_list=False)
        if settings_dir is None or len(settings_dir) == 0: settings_dir = [""]
        # Fix: the previous guard `if len(dir) == "": continue` compared an int
        # to a str and was always False. Empty entries (the default [""])
        # intentionally scan the base "profiles" directory, so no guard belongs
        # here; the dead line was removed and the loop variable renamed so it
        # no longer shadows the builtin `dir`.
        for sub_dir in settings_dir:
            cur_path = os.path.join(top_dir, sub_dir)
            if os.path.isdir(cur_path):
                cur_dir_presets = glob.glob(os.path.join(cur_path, "*.json"))
                global_list += [os.path.join(sub_dir, os.path.basename(path)) for path in cur_dir_presets]
        global_list = sorted(global_list, key=lambda n: os.path.basename(n))
    lset_list = []
    settings_list = []
    for item in loras_presets:
        if item.endswith(".lset"):
            lset_list.append(item)
        else:
            settings_list.append(item)
    sep = '\u2500'
    indent = chr(160) * 4  # non-breaking spaces so the dropdown keeps the indent
    lset_choices = []
    if len(global_list) > 0:
        lset_choices += [((sep * 12) + "Accelerators Profiles" + (sep * 13), ">profiles")]
        lset_choices += [(indent + os.path.splitext(os.path.basename(preset))[0], preset) for preset in global_list]
    if len(settings_list) > 0:
        settings_list.sort()
        lset_choices += [((sep * 16) + "Settings" + (sep * 17), ">settings")]
        lset_choices += [(indent + os.path.splitext(preset)[0], preset) for preset in settings_list]
    if len(lset_list) > 0:
        lset_list.sort()
        lset_choices += [((sep * 18) + "Lsets" + (sep * 18), ">lset")]
        lset_choices += [(indent + os.path.splitext(preset)[0], preset) for preset in lset_list]
    return lset_choices
def get_lset_name(state, lset_name):
    """Normalize a dropdown entry into a real preset/settings file name.

    Returns "" for empty input, section separator values (starting with ">")
    and the two placeholder messages. A known preset name is returned as-is;
    a displayed label is resolved back to its underlying value; anything else
    (free text typed by the user) is returned unchanged.
    """
    presets = state["loras_presets"]
    placeholders = (get_new_preset_msg(True), get_new_preset_msg(False))
    if not lset_name or lset_name.startswith(">") or lset_name in placeholders:
        return ""
    if lset_name in presets:
        return lset_name
    # The dropdown may hand us the label rather than the value: map it back.
    for label, value in compute_lset_choices(get_state_model_type(state), presets):
        if label == lset_name:
            return value
    return lset_name
def validate_delete_lset(state, lset_name):
    """First step of deleting a preset: swap the toolbar for a confirm/cancel pair.

    Returns six gr component visibility updates: the first four control the
    normal preset toolbar widgets, the last two the confirmation widgets.
    Invalid names (or read-only profiles, recognized by a path separator in
    the resolved name) keep the toolbar visible and just show a message.
    """
    lset_name = get_lset_name(state, lset_name)
    if len(lset_name) == 0 :
        # Nothing usable selected (empty, separator row, or placeholder text).
        gr.Info(f"Choose a Preset to delete")
        return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False)
    elif "/" in lset_name or "\\" in lset_name:
        # A path separator means this is a global accelerator profile, not a user preset.
        gr.Info(f"You can't Delete a Profile")
        return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False)
    else:
        # Valid user preset: hide the toolbar, show the confirm/cancel buttons.
        return gr.Button(visible= False), gr.Checkbox(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True)
def validate_save_lset(state, lset_name):
    """First step of saving a preset: swap the toolbar for a confirm/cancel pair.

    Returns seven gr component visibility updates: the first four control the
    normal toolbar widgets, the next two the confirmation widgets, the last
    the save-type checkbox. Empty names and read-only profiles (path
    separator in the name) keep the toolbar and only show a message.
    NOTE(review): in the success branch the second update is a gr.Button while
    the other branches use gr.Checkbox for that slot — confirm the output
    component type is a Checkbox and this mismatch is benign.
    """
    lset_name = get_lset_name(state, lset_name)
    if len(lset_name) == 0:
        gr.Info("Please enter a name for the preset")
        return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False),gr.Checkbox(visible= False)
    elif "/" in lset_name or "\\" in lset_name:
        # Accelerator profiles are shared read-only files.
        gr.Info(f"You can't Edit a Profile")
        return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False),gr.Checkbox(visible= False)
    else:
        # Valid name: hide the toolbar, show confirm/cancel and the save-type checkbox.
        return gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True),gr.Checkbox(visible= True)
def cancel_lset():
    """Abort a pending save/delete confirmation.

    Restores the four normal toolbar buttons and hides the three
    confirmation buttons plus the save-type checkbox (eight updates,
    matching the event's positional outputs).
    """
    return gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_cbox):
    """Save the current loras selection (or the full model settings) to disk.

    Depending on save_lset_prompt_cbox:
      2 -> dump the whole current model settings as "<name>.json";
      otherwise -> save only loras/multipliers (and, unless the value is 1,
      the comment lines of the prompt) as "<name>.lset".
    An existing file with the same stem (either extension) is treated as an
    update: it is renamed to a .bkp file when the extension changes.
    Returns the refreshed dropdown plus seven visibility updates restoring
    the toolbar.
    """
    # Strip a typed extension; the real one is decided by the save type below.
    if lset_name.endswith(".json") or lset_name.endswith(".lset"):
        lset_name = os.path.splitext(lset_name)[0]
    loras_presets = state["loras_presets"]
    loras = state["loras"]
    # NOTE(review): dead statement — the validate_success check does nothing.
    if state.get("validate_success",0) == 0:
        pass
    lset_name = get_lset_name(state, lset_name)
    if len(lset_name) == 0:
        gr.Info("Please enter a name for the preset / settings file")
        lset_choices =[("Please enter a name for a Lora Preset / Settings file","")]
    else:
        lset_name = sanitize_file_name(lset_name)
        # Remove the box-drawing char used by separator rows so a separator
        # label can never become a file name.
        lset_name = lset_name.replace('\u2500',"").strip()
        if save_lset_prompt_cbox ==2:
            # Full settings export.
            lset = collect_current_model_settings(state)
            extension = ".json"
        else:
            # Loras-only preset; keep only the "after" side of the selection.
            from shared.utils.loras_mutipliers import extract_loras_side
            loras_choices, loras_mult_choices = extract_loras_side(loras_choices, loras_mult_choices, "after")
            lset = {"loras" : loras_choices, "loras_mult" : loras_mult_choices}
            if save_lset_prompt_cbox!=1:
                # Keep only the comment lines ("#...") of the prompt.
                prompts = prompt.replace("\r", "").split("\n")
                prompts = [prompt for prompt in prompts if len(prompt)> 0 and prompt.startswith("#")]
                prompt = "\n".join(prompts)
            if len(prompt) > 0:
                lset["prompt"] = prompt
            lset["full_prompt"] = save_lset_prompt_cbox ==1
            extension = ".lset"
        if lset_name.endswith(".json") or lset_name.endswith(".lset"): lset_name = os.path.splitext(lset_name)[0]
        # Look for an existing preset with the same stem, whatever its extension.
        old_lset_name = lset_name + ".json"
        if not old_lset_name in loras_presets:
            old_lset_name = lset_name + ".lset"
            if not old_lset_name in loras_presets: old_lset_name = ""
        lset_name = lset_name + extension
        model_type = get_state_model_type(state)
        lora_dir = get_lora_dir(model_type)
        full_lset_name_filename = os.path.join(lora_dir, lset_name )
        with open(full_lset_name_filename, "w", encoding="utf-8") as writer:
            writer.write(json.dumps(lset, indent=4))
        if len(old_lset_name) > 0 :
            if save_lset_prompt_cbox ==2:
                gr.Info(f"Settings File '{lset_name}' has been updated")
            else:
                gr.Info(f"Lora Preset '{lset_name}' has been updated")
            if old_lset_name != lset_name:
                # Extension changed: keep list position and back up the old file.
                pos = loras_presets.index(old_lset_name)
                loras_presets[pos] = lset_name
                shutil.move( os.path.join(lora_dir, old_lset_name), get_available_filename(lora_dir, old_lset_name + ".bkp" ) )
        else:
            if save_lset_prompt_cbox ==2:
                gr.Info(f"Settings File '{lset_name}' has been created")
            else:
                gr.Info(f"Lora Preset '{lset_name}' has been created")
            loras_presets.append(lset_name)
        state["loras_presets"] = loras_presets
        lset_choices = compute_lset_choices(model_type, loras_presets)
        lset_choices.append( (get_new_preset_msg(), ""))
    return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
def delete_lset(state, lset_name):
    """Delete the selected preset/settings file from the lora directory.

    Also removes it from state["loras_presets"], rebuilds the dropdown and
    tries to keep the selection at the same position in the list. Returns the
    dropdown update plus six visibility updates restoring the toolbar.
    """
    loras_presets = state["loras_presets"]
    lset_name = get_lset_name(state, lset_name)
    model_type = get_state_model_type(state)
    if len(lset_name) > 0:
        lset_name_filename = os.path.join( get_lora_dir(model_type), sanitize_file_name(lset_name))
        if not os.path.isfile(lset_name_filename):
            gr.Info(f"Preset '{lset_name}' not found ")
            return [gr.update()]*7
        os.remove(lset_name_filename)
        # Record the item's position before removal (model_type=None: user
        # presets only, no global profiles section) to restore the selection.
        lset_choices = compute_lset_choices(None, loras_presets)
        pos = next( (i for i, item in enumerate(lset_choices) if item[1]==lset_name ), -1)
        gr.Info(f"Lora Preset '{lset_name}' has been deleted")
        loras_presets.remove(lset_name)
    else:
        pos = -1
        gr.Info(f"Choose a Preset / Settings File to delete")
    state["loras_presets"] = loras_presets
    lset_choices = compute_lset_choices(model_type, loras_presets)
    lset_choices.append((get_new_preset_msg(), ""))
    # Select the entry now sitting at the deleted item's position (clamped).
    selected_lset_name = "" if pos < 0 else lset_choices[min(pos, len(lset_choices)-1)][1]
    return gr.Dropdown(choices=lset_choices, value= selected_lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False)
def get_updated_loras_dropdown(loras, loras_choices):
    """Merge the currently selected lora files into the known lora list.

    Any selected choice (reduced to its basename) that is not already present
    in `loras` is appended, so a selection made before a refresh is never
    lost. Returns the merged list and (label, value) pairs for a gr.Dropdown,
    the label being the file name without its extension. `loras` itself is
    not mutated.
    """
    known = set(loras)
    extras_seen = set()
    merged = list(loras)
    for choice in loras_choices:
        base = os.path.basename(choice)
        if base not in known and base not in extras_seen:
            extras_seen.add(base)
            merged.append(base)
    dropdown_entries = [(os.path.splitext(name)[0], name) for name in merged]
    return merged, dropdown_entries
def refresh_lora_list(state, lset_name, loras_choices):
    """Re-scan the lora directory and refresh both dropdowns.

    Keeps the user's current selection even if the files just disappeared
    from disk, reports lora files the loaded model failed to parse, and
    clears the preset selection when it no longer exists.
    Returns updates for the preset dropdown and the loras dropdown.
    """
    model_type= get_state_model_type(state)
    loras, loras_presets, _, _, _, _ = setup_loras(model_type, None, get_lora_dir(model_type), lora_preselected_preset, None)
    state["loras_presets"] = loras_presets
    gc.collect()
    # Preserve the current selection in the refreshed list.
    loras, new_loras_dropdown = get_updated_loras_dropdown(loras, loras_choices)
    state["loras"] = loras
    model_type = get_state_model_type(state)
    lset_choices = compute_lset_choices(model_type, loras_presets)
    lset_choices.append((get_new_preset_msg( state["advanced"]), ""))
    if not lset_name in loras_presets:
        lset_name = ""
    if wan_model != None:
        # Surface lora load errors recorded by the transformer (if any).
        trans_err = get_transformer_model(wan_model)
        if hasattr(wan_model, "get_trans_lora"):
            trans_err, _ = wan_model.get_trans_lora()
        errors = getattr(trans_err, "_loras_errors", "")
        if errors !=None and len(errors) > 0:
            error_files = [path for path, _ in errors]
            gr.Info("Error while refreshing Lora List, invalid Lora files: " + ", ".join(error_files))
        else:
            gr.Info("Lora List has been refreshed")
    return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Dropdown(choices=new_loras_dropdown, value= loras_choices)
def update_lset_type(state, lset_name):
    """Return the save-type radio value matching the selected file name:
    1 for a ".lset" lora preset, 2 for anything else (full settings)."""
    if lset_name.endswith(".lset"):
        return 1
    return 2
from shared.utils.loras_mutipliers import merge_loras_settings
def apply_lset(state, wizard_prompt_activated, lset_name, loras_choices, loras_mult_choices, prompt):
    """Apply the selected preset: a ".lset" lora preset, a settings file, or
    an accelerator profile (a settings file living under "profiles/").

    Nine positional outputs: wizard state, loras, multipliers, prompt, a
    refresh trigger, three slots for the model dropdown, and a settings
    reload trigger; which slots carry real values depends on the branch.
    """
    state["apply_success"] = 0
    lset_name = get_lset_name(state, lset_name)
    if len(lset_name) == 0:
        gr.Info("Please choose a Lora Preset or Setting File in the list or create one")
        return wizard_prompt_activated, loras_choices, loras_mult_choices, prompt, gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
    else:
        current_model_type = get_state_model_type(state)
        ui_settings = get_current_model_settings(state)
        old_activated_loras, old_loras_multipliers = ui_settings.get("activated_loras", []), ui_settings.get("loras_multipliers", ""),
        if lset_name.endswith(".lset"):
            # Lora preset: replaces loras/multipliers and may inject a prompt.
            loras = state["loras"]
            loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(current_model_type, lset_name, loras)
            if full_prompt:
                # Preset carries a complete prompt: overwrite the current one.
                prompt = preset_prompt
            elif len(preset_prompt) > 0:
                # Preset carries comment lines only: prepend them, dropping the
                # old comment lines from the current prompt.
                prompts = prompt.replace("\r", "").split("\n")
                prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")]
                prompt = "\n".join(prompts)
                prompt = preset_prompt + '\n' + prompt
            loras_choices, loras_mult_choices = merge_loras_settings(old_activated_loras, old_loras_multipliers, loras_choices, loras_mult_choices, "merge after")
            loras_choices = update_loras_url_cache(get_lora_dir(current_model_type), loras_choices)
            loras_choices = [os.path.basename(lora) for lora in loras_choices]
            gr.Info(f"Lora Preset '{lset_name}' has been applied")
            state["apply_success"] = 1
            wizard_prompt_activated = "on"
            return wizard_prompt_activated, loras_choices, loras_mult_choices, prompt, get_unique_id(), gr.update(), gr.update(), gr.update(), gr.update()
        else:
            # Settings file / accelerator profile (a path with a sub dir).
            accelerator_profile = len(Path(lset_name).parts)>1
            lset_path = os.path.join("profiles" if accelerator_profile else get_lora_dir(current_model_type), lset_name)
            configs, _, _ = get_settings_from_file(state,lset_path , True, True, True, min_settings_version=2.38, merge_loras = "merge after" if len(Path(lset_name).parts)<=1 else "merge before" )
            if configs == None:
                gr.Info("File not supported")
                return [gr.update()] * 9
            model_type = configs["model_type"]
            configs["lset_name"] = lset_name
            if accelerator_profile:
                gr.Info(f"Accelerator Profile '{os.path.splitext(os.path.basename(lset_name))[0]}' has been applied")
            else:
                gr.Info(f"Settings File '{os.path.basename(lset_name)}' has been applied")
            help = configs.get("help", None)
            if help is not None: gr.Info(help)
            if model_type == current_model_type:
                # Same model: just reload its settings in place.
                set_model_settings(state, current_model_type, configs)
                return *[gr.update()] * 4, gr.update(), gr.update(), gr.update(), gr.update(), get_unique_id()
            else:
                # Different model: store the settings and switch the model dropdown.
                set_model_settings(state, model_type, configs)
                return *[gr.update()] * 4, gr.update(), *generate_dropdown_model_list(model_type), gr.update()
def extract_prompt_from_wizard(state, variables_names, prompt, wizard_prompt, allow_null_values, *args):
    """Convert the wizard template plus its variable values back into a raw prompt.

    A "! {var}=v1,v2 : ..." macro line is inserted before the first template
    line (a non-comment line containing "{" and "}"). Each value in *args
    corresponds by position to a name in `variables_names` (newline
    separated); multi-line values become comma-separated quoted items.
    Returns (prompt, ""), or (offending_line, error) when a value is missing
    and allow_null_values is False.
    """
    prompts = wizard_prompt.replace("\r" ,"").split("\n")
    new_prompts = []
    macro_already_written = False  # only one macro line is ever emitted
    for prompt in prompts:
        if not macro_already_written and not prompt.startswith("#") and "{" in prompt and "}" in prompt:
            variables = variables_names.split("\n")
            values = args[:len(variables)]
            macro = "! "
            for i, (variable, value) in enumerate(zip(variables, values)):
                if len(value) == 0 and not allow_null_values:
                    return prompt, "You need to provide a value for '" + variable + "'"
                # Each line of a multi-line value becomes one quoted item.
                sub_values= [ "\"" + sub_value + "\"" for sub_value in value.split("\n") ]
                value = ",".join(sub_values)
                if i>0:
                    macro += " : "
                macro += "{" + variable + "}"+ f"={value}"
            if len(variables) > 0:
                macro_already_written = True
                new_prompts.append(macro)
            new_prompts.append(prompt)
        else:
            new_prompts.append(prompt)
    prompt = "\n".join(new_prompts)
    return prompt, ""
def validate_wizard_prompt(state, wizard_prompt_activated, wizard_variables_names, prompt, wizard_prompt, *args):
    """Resolve the final prompt before generation.

    When the wizard is inactive the raw prompt is returned untouched.
    Otherwise the wizard template is expanded with the variable values in
    *args; a missing value leaves state["validate_success"] at 0 and reports
    the error via gr.Info.
    """
    state["validate_success"] = 0
    if wizard_prompt_activated != "on":
        state["validate_success"] = 1
        return prompt
    resolved, errors = extract_prompt_from_wizard(state, wizard_variables_names, prompt, wizard_prompt, False, *args)
    if errors:
        gr.Info(errors)
    else:
        state["validate_success"] = 1
    return resolved
def fill_prompt_from_wizard(state, wizard_prompt_activated, wizard_variables_names, prompt, wizard_prompt, *args):
    """Switch from the wizard editor back to the raw prompt editor.

    If the wizard was active, fold its fields back into a raw prompt
    (null values allowed here). Returns the deactivated wizard state,
    cleared variable names, the visible raw prompt textbox, the hidden
    wizard textbox and the matching column/variable-textbox visibility
    updates.
    """
    if wizard_prompt_activated == "on":
        prompt, errors = extract_prompt_from_wizard(state, wizard_variables_names, prompt, wizard_prompt, True, *args)
        if len(errors) > 0:
            gr.Info(errors)
    wizard_prompt_activated = "off"
    return wizard_prompt_activated, "", gr.Textbox(visible= True, value =prompt) , gr.Textbox(visible= False), gr.Column(visible = True), *[gr.Column(visible = False)] * 2, *[gr.Textbox(visible= False)] * PROMPT_VARS_MAX
def extract_wizard_prompt(prompt):
    """Parse a raw prompt into wizard form.

    Returns (wizard_prompt, variables, values, errors): the prompt with the
    "!" macro line removed, the deduplicated variable names found in macro
    and template lines, their current values, and an error string (empty on
    success). Prompts with more than one macro line or more than
    PROMPT_VARS_MAX variables are rejected as too complex.
    """
    variables = []
    values = {}
    prompts = prompt.replace("\r" ,"").split("\n")
    # The basic editor supports at most one "!" macro line.
    if sum(prompt.startswith("!") for prompt in prompts) > 1:
        return "", variables, values, "Prompt is too complex for basic Prompt editor, switching to Advanced Prompt"
    new_prompts = []
    errors = ""
    for prompt in prompts:
        if prompt.startswith("!"):
            # Macro line: provides both names and current values; excluded
            # from the wizard prompt text. ("templace" typo kept: runtime string.)
            variables, errors = prompt_parser.extract_variable_names(prompt)
            if len(errors) > 0:
                return "", variables, values, "Error parsing Prompt templace: " + errors
            if len(variables) > PROMPT_VARS_MAX:
                return "", variables, values, "Prompt is too complex for basic Prompt editor, switching to Advanced Prompt"
            values, errors = prompt_parser.extract_variable_values(prompt)
            if len(errors) > 0:
                return "", variables, values, "Error parsing Prompt templace: " + errors
        else:
            # Template line: may reference additional variables.
            variables_extra, errors = prompt_parser.extract_variable_names(prompt)
            if len(errors) > 0:
                return "", variables, values, "Error parsing Prompt templace: " + errors
            variables += variables_extra
            # Deduplicate while preserving first-occurrence order.
            variables = [var for pos, var in enumerate(variables) if var not in variables[:pos]]
            if len(variables) > PROMPT_VARS_MAX:
                return "", variables, values, "Prompt is too complex for basic Prompt editor, switching to Advanced Prompt"
            new_prompts.append(prompt)
    wizard_prompt = "\n".join(new_prompts)
    return wizard_prompt, variables, values, errors
def fill_wizard_prompt(state, wizard_prompt_activated, prompt, wizard_prompt):
    """Switch from the raw prompt editor to the wizard editor.

    Parses the raw prompt into a template plus variable textboxes. Stays in
    (or falls back to) raw mode when advanced mode is on, no preset was just
    applied, or the prompt is too complex to parse. Outputs: wizard state,
    variable names, raw/wizard textboxes, and the column/per-variable
    textbox visibility updates.
    """
    def get_hidden_textboxes(num = PROMPT_VARS_MAX ):
        # Padding so the number of textbox outputs is always PROMPT_VARS_MAX.
        return [gr.Textbox(value="", visible=False)] * num
    hidden_column = gr.Column(visible = False)
    visible_column = gr.Column(visible = True)
    wizard_prompt_activated = "off"
    if state["advanced"] or state.get("apply_success") != 1:
        # Nothing to do: leave every component unchanged.
        return wizard_prompt_activated, gr.Text(), prompt, wizard_prompt, gr.Column(), gr.Column(), hidden_column, *get_hidden_textboxes()
    prompt_parts= []
    wizard_prompt, variables, values, errors = extract_wizard_prompt(prompt)
    if len(errors) > 0:
        # Too complex / parse error: fall back to the raw editor.
        gr.Info( errors )
        return wizard_prompt_activated, "", gr.Textbox(prompt, visible=True), gr.Textbox(wizard_prompt, visible=False), visible_column, *[hidden_column] * 2, *get_hidden_textboxes()
    for variable in variables:
        value = values.get(variable, "")
        prompt_parts.append(gr.Textbox( placeholder=variable, info= variable, visible= True, value= "\n".join(value) ))
    any_macro = len(variables) > 0
    prompt_parts += get_hidden_textboxes(PROMPT_VARS_MAX-len(prompt_parts))
    variables_names= "\n".join(variables)
    wizard_prompt_activated = "on"
    return wizard_prompt_activated, variables_names, gr.Textbox(prompt, visible = False), gr.Textbox(wizard_prompt, visible = True), hidden_column, visible_column, visible_column if any_macro else hidden_column, *prompt_parts
def switch_prompt_type(state, wizard_prompt_activated_var, wizard_variables_names, prompt, wizard_prompt, *prompt_vars):
    """Toggle between the advanced (raw) prompt editor and the wizard editor,
    delegating to the matching fill_* routine for the target mode."""
    if state["advanced"]:
        # Entering advanced mode: fold the wizard fields back into a raw prompt.
        return fill_prompt_from_wizard(state, wizard_prompt_activated_var, wizard_variables_names, prompt, wizard_prompt, *prompt_vars)
    # Entering wizard mode: mark apply as successful so the wizard UI rebuilds.
    state["apply_success"] = 1
    return fill_wizard_prompt(state, wizard_prompt_activated_var, prompt, wizard_prompt)
# NOTE(review): stray module-level flag between two function definitions —
# presumably a leftover default; confirm nothing reads the global `visible`
# before removing it.
visible= False
def switch_advanced(state, new_advanced, lset_name):
    """Toggle advanced mode: persist the choice to the server config file,
    refresh the preset dropdown placeholder and adjust which preset-editing
    rows/buttons are visible (edit controls may be advanced-only)."""
    state["advanced"] = new_advanced
    loras_presets = state["loras_presets"]
    model_type = get_state_model_type(state)
    lset_choices = compute_lset_choices(model_type, loras_presets)
    lset_choices.append((get_new_preset_msg(new_advanced), ""))
    # Persist the preference across sessions.
    server_config["last_advanced_choice"] = new_advanced
    with open(server_config_filename, "w", encoding="utf-8") as writer:
        writer.write(json.dumps(server_config, indent=4))
    # Swap the placeholder text to the wording matching the new mode.
    if lset_name== get_new_preset_msg(True) or lset_name== get_new_preset_msg(False) or lset_name=="":
        lset_name = get_new_preset_msg(new_advanced)
    if only_allow_edit_in_advanced:
        return gr.Row(visible=new_advanced), gr.Row(visible=new_advanced), gr.Button(visible=new_advanced), gr.Row(visible= not new_advanced), gr.Dropdown(choices=lset_choices, value= lset_name)
    else:
        return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name)
def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None ):
    """Normalize a dict of UI inputs for a given destination.

    target:
      "state"/"edit_state" -> only normalize the loras selection;
      "settings"           -> additionally strip UI-only keys and tag the dict
                              with model name and settings version;
      "metadata"           -> additionally prune every key irrelevant to the
                              current model's capabilities, drop None values
                              and run the plugin pre-save hooks.
    Mutates and returns `inputs` (for "metadata" a filtered copy is returned).
    """
    state = inputs.pop("state")
    plugin_data = inputs.pop("plugin_data", {})
    if "loras_choices" in inputs:
        # Fresh UI values: rename loras_choices -> activated_loras.
        loras_choices = inputs.pop("loras_choices")
        inputs.pop("model_filename", None)
    else:
        loras_choices = inputs["activated_loras"]
    if model_type == None: model_type = get_state_model_type(state)
    inputs["activated_loras"] = update_loras_url_cache(get_lora_dir(model_type), loras_choices)
    if target in ["state", "edit_state"]:
        return inputs
    if "lset_name" in inputs:
        inputs.pop("lset_name")
    # Attachments are never persisted.
    unsaved_params = ATTACHMENT_KEYS
    for k in unsaved_params:
        inputs.pop(k)
    inputs["type"] = get_model_record(get_model_name(model_type))
    inputs["settings_version"] = settings_version
    model_def = get_model_def(model_type)
    base_model_type = get_base_model_type(model_type)
    model_family = get_model_family(base_model_type)
    if model_type != base_model_type:
        inputs["base_model_type"] = base_model_type
    # Capability flags driving the pruning below.
    diffusion_forcing = base_model_type in ["sky_df_1.3B", "sky_df_14B"]
    vace = test_vace_module(base_model_type)
    t2v= test_class_t2v(base_model_type)
    ltxv = base_model_type in ["ltxv_13B"]
    if target == "settings":
        return inputs
    image_outputs = inputs.get("image_mode",0) > 0
    # Collect every key that does not apply to this model/configuration,
    # then pop them all at once at the end.
    pop=[]
    if not model_def.get("audio_only", False):
        pop += [ "pace", "exaggeration", "temperature"]
    if "force_fps" in inputs and len(inputs["force_fps"])== 0:
        pop += ["force_fps"]
    if model_def.get("sample_solvers", None) is None:
        pop += ["sample_solver"]
    if any_audio_track(base_model_type) or not get_mmaudio_settings(server_config)[0]:
        pop += ["MMAudio_setting", "MMAudio_prompt", "MMAudio_neg_prompt"]
    image_prompt_type = inputs.get("image_prompt_type", "") or ""
    video_prompt_type = inputs["video_prompt_type"]
    if "G" not in video_prompt_type:
        pop += ["denoising_strength"]
    if "G" not in video_prompt_type and not model_def.get("mask_strength_always_enabled", False):
        pop += ["masking_strength"]
    if len(model_def.get("input_video_strength", ""))==0 or not any_letters(image_prompt_type, "SVL"):
        pop += ["input_video_strength"]
    if not (server_config.get("enhancer_enabled", 0) > 0 and server_config.get("enhancer_mode", 0) == 0):
        pop += ["prompt_enhancer"]
    if model_def.get("model_modes", None) is None:
        pop += ["model_mode"]
    if model_def.get("guide_custom_choices", None ) is None and model_def.get("guide_preprocessing", None ) is None:
        pop += ["keep_frames_video_guide", "mask_expand"]
    if not "I" in video_prompt_type:
        pop += ["remove_background_images_ref"]
    if not model_def.get("any_image_refs_relative_size", False):
        pop += ["image_refs_relative_size"]
    if not vace:
        pop += ["frames_positions"]
    if model_def.get("control_net_weight_name", None) is None:
        pop += ["control_net_weight", "control_net_weight2"]
    if not len(model_def.get("control_net_weight_alt_name", "")) >0:
        pop += ["control_net_weight_alt"]
    if model_def.get("audio_scale_name", None) is None:
        pop += ["audio_scale"]
    if not model_def.get("motion_amplitude", False):
        pop += ["motion_amplitude"]
    if model_def.get("video_guide_outpainting", None) is None:
        pop += ["video_guide_outpainting"]
    if not (vace or t2v):
        pop += ["min_frames_if_references"]
    if not (diffusion_forcing or ltxv or vace):
        pop += ["keep_frames_video_source"]
    if not test_any_sliding_window( base_model_type):
        pop += ["sliding_window_size", "sliding_window_overlap", "sliding_window_overlap_noise", "sliding_window_discard_last_frames", "sliding_window_color_correction_strength"]
    if not model_def.get("audio_guidance", False):
        pop += ["audio_guidance_scale", "speakers_locations"]
    if not model_def.get("embedded_guidance", False):
        pop += ["embedded_guidance_scale"]
    if model_def.get("alt_guidance", None) is None:
        pop += ["alt_guidance_scale"]
    if not (model_def.get("tea_cache", False) or model_def.get("mag_cache", False)) :
        pop += ["skip_steps_cache_type", "skip_steps_multiplier", "skip_steps_start_step_perc"]
    # Guidance keys are kept only up to the number of phases actually in use.
    guidance_max_phases = model_def.get("guidance_max_phases", 0)
    guidance_phases = inputs.get("guidance_phases", 1)
    visible_phases = model_def.get("visible_phases", guidance_phases)
    if guidance_max_phases < 1 or visible_phases < 1:
        pop += ["guidance_scale", "guidance_phases"]
    if guidance_max_phases < 2 or guidance_phases < 2 or visible_phases < 2:
        pop += ["guidance2_scale", "switch_threshold"]
    if guidance_max_phases < 3 or guidance_phases < 3 or visible_phases < 3:
        pop += ["guidance3_scale", "switch_threshold2", "model_switch_phase"]
    if not model_def.get("flow_shift", False):
        pop += ["flow_shift"]
    if model_def.get("no_negative_prompt", False) :
        pop += ["negative_prompt" ]
    if not model_def.get("skip_layer_guidance", False):
        pop += ["slg_switch", "slg_layers", "slg_start_perc", "slg_end_perc"]
    if not model_def.get("cfg_zero", False):
        pop += [ "cfg_zero_step" ]
    if not model_def.get("cfg_star", False):
        pop += ["cfg_star_switch" ]
    if not model_def.get("adaptive_projected_guidance", False):
        pop += ["apg_switch"]
    if not model_def.get("NAG", False):
        pop +=["NAG_scale", "NAG_tau", "NAG_alpha" ]
    for k in pop:
        if k in inputs: inputs.pop(k)
    if target == "metadata":
        # Metadata never stores None values; plugins get a final say.
        inputs = {k: v for k,v in inputs.items() if v != None }
        if hasattr(app, 'plugin_manager'):
            inputs = app.plugin_manager.run_data_hooks(
                'before_metadata_save',
                configs=inputs,
                plugin_data=plugin_data,
                model_type=model_type
            )
    return inputs
def get_function_arguments(func, locals):
    """Collect the values of *func*'s parameters from a mapping of local
    variables (typically the caller's ``locals()``), preserving the
    parameters' declaration order. Raises KeyError if a parameter is
    missing from the mapping."""
    parameter_names = inspect.signature(func).parameters
    return typing.OrderedDict((name, locals[name]) for name in parameter_names)
def init_generate(state, input_file_list, last_choice, audio_files_paths, audio_file_selected):
    """Prepare a generation run: record which video/image item and which audio
    item are currently selected in their galleries, then return a fresh
    trigger id and an empty status string."""
    gen = get_gen_info(state)
    videos, _ = get_file_list(state, input_file_list)
    set_file_choice(gen, videos, last_choice)
    audios, _ = get_file_list(state, unpack_audio_list(audio_files_paths), audio_files=True)
    set_file_choice(gen, audios, audio_file_selected, audio_files=True)
    return get_unique_id(), ""
def video_to_control_video(state, input_file_list, choice):
    """Copy the currently selected gallery video to the Control Video input.

    Returns gr.update() (no change) when the gallery is empty or the
    selection index is invalid; otherwise the selected file path.
    """
    file_list, _ = get_file_list(state, input_file_list)
    # `>= len` (was `> len`): choice == len(file_list) would raise IndexError below.
    if len(file_list) == 0 or choice == None or choice < 0 or choice >= len(file_list): return gr.update()
    gr.Info("Selected Video was copied to Control Video input")
    return file_list[choice]
def video_to_source_video(state, input_file_list, choice):
    """Copy the currently selected gallery video to the Source Video input.

    Returns gr.update() (no change) when the gallery is empty or the
    selection index is invalid; otherwise the selected file path.
    """
    file_list, _ = get_file_list(state, input_file_list)
    # `>= len` (was `> len`): choice == len(file_list) would raise IndexError below.
    if len(file_list) == 0 or choice == None or choice < 0 or choice >= len(file_list): return gr.update()
    gr.Info("Selected Video was copied to Source Video input")
    return file_list[choice]
def image_to_ref_image_add(state, input_file_list, choice, target, target_name):
    """Add the selected gallery image to the Reference Images input.

    Models flagged `one_image_ref_needed` accept a single reference, so the
    selection replaces the list; otherwise it is appended (mutating the
    caller-supplied list when present). Returns the updated target list, or
    gr.update() when the selection is invalid.
    """
    file_list, _ = get_file_list(state, input_file_list)
    # `>= len` (was `> len`): choice == len(file_list) would raise IndexError below.
    if len(file_list) == 0 or choice == None or choice < 0 or choice >= len(file_list): return gr.update()
    model_type = get_state_model_type(state)
    model_def = get_model_def(model_type)
    if model_def.get("one_image_ref_needed", False):
        gr.Info(f"Selected Image was set to {target_name}")
        target =[file_list[choice]]
    else:
        gr.Info(f"Selected Image was added to {target_name}")
        if target == None:
            target =[]
        target.append( file_list[choice])
    return target
def image_to_ref_image_set(state, input_file_list, choice, target, target_name):
    """Copy the selected gallery image to the input named *target_name*.

    Returns gr.update() (no change) when the gallery is empty or the
    selection index is invalid; otherwise the selected file path.
    """
    file_list, _ = get_file_list(state, input_file_list)
    # `>= len` (was `> len`): choice == len(file_list) would raise IndexError below.
    if len(file_list) == 0 or choice == None or choice < 0 or choice >= len(file_list): return gr.update()
    gr.Info(f"Selected Image was copied to {target_name}")
    return file_list[choice]
def image_to_ref_image_guide(state, input_file_list, choice):
    """Copy the selected gallery image to the Control Image input(s).

    Returns a pair (control image, secondary image). NOTE(review): the
    `or True` below forces the pair-return path for every image mode and
    makes the else branch unreachable — it looks like a deliberate override
    of the image_mode == 2 special case; confirm before removing it.
    """
    file_list, _ = get_file_list(state, input_file_list)
    # `>= len` (was `> len`): choice == len(file_list) would raise IndexError below.
    if len(file_list) == 0 or choice == None or choice < 0 or choice >= len(file_list): return gr.update(), gr.update()
    ui_settings = get_current_model_settings(state)
    gr.Info(f"Selected Image was copied to Control Image")
    new_image = file_list[choice]
    if ui_settings["image_mode"]==2 or True:
        return new_image, new_image
    else:
        return new_image, None
def audio_to_source_set(state, input_file_list, choice, target_name):
    """Copy the selected audio gallery file to the input named *target_name*.

    Returns gr.update() (no change) when the audio gallery is empty or the
    selection index is invalid; otherwise the selected file path.
    """
    file_list, _ = get_file_list(state, unpack_audio_list(input_file_list), audio_files=True)
    # `>= len` (was `> len`): choice == len(file_list) would raise IndexError below.
    if len(file_list) == 0 or choice == None or choice < 0 or choice >= len(file_list): return gr.update()
    gr.Info(f"Selected Audio File was copied to {target_name}")
    return file_list[choice]
def apply_post_processing(state, input_file_list, choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation):
    """Queue a post-processing pass (upsampling / film grain) on the selected video.

    Stores the source file and override parameters in the gen info, then
    returns the edit mode tag plus two trigger slots: the first fires when
    no generation is running, the second queues the edit behind the current
    generation. Only .mp4/.mkv files are accepted.
    """
    gen = get_gen_info(state)
    file_list, _ = get_file_list(state, input_file_list)
    # `>= len` (was `> len`): choice == len(file_list) would raise IndexError below.
    if len(file_list) == 0 or choice == None or choice < 0 or choice >= len(file_list) :
        return gr.update(), gr.update(), gr.update()
    if not (file_list[choice].endswith(".mp4") or file_list[choice].endswith(".mkv")):
        gr.Info("Post processing is only available with Videos")
        return gr.update(), gr.update(), gr.update()
    overrides = {
        "temporal_upsampling":PP_temporal_upsampling,
        "spatial_upsampling":PP_spatial_upsampling,
        "film_grain_intensity": PP_film_grain_intensity,
        "film_grain_saturation": PP_film_grain_saturation,
    }
    gen["edit_video_source"] = file_list[choice]
    gen["edit_overrides"] = overrides
    # Trigger immediately when idle, otherwise queue behind the running generation.
    in_progress = gen.get("in_progress", False)
    return "edit_postprocessing", get_unique_id() if not in_progress else gr.update(), get_unique_id() if in_progress else gr.update()
def remux_audio(state, input_file_list, choice, PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, PP_MMAudio_seed, PP_repeat_generation, PP_custom_audio):
    """Queue an audio remux (MMAudio or custom track) on the selected video.

    Stores the source file and override parameters in the gen info, then
    returns the edit mode tag plus two trigger slots: the first fires when
    no generation is running, the second queues the edit behind the current
    generation. Only .mp4/.mkv files are accepted.
    """
    gen = get_gen_info(state)
    file_list, _ = get_file_list(state, input_file_list)
    # `>= len` (was `> len`): choice == len(file_list) would raise IndexError below.
    if len(file_list) == 0 or choice == None or choice < 0 or choice >= len(file_list) :
        return gr.update(), gr.update(), gr.update()
    if not (file_list[choice].endswith(".mp4") or file_list[choice].endswith(".mkv")):
        gr.Info("Post processing is only available with Videos")
        return gr.update(), gr.update(), gr.update()
    overrides = {
        "MMAudio_setting" : PP_MMAudio_setting,
        "MMAudio_prompt" : PP_MMAudio_prompt,
        "MMAudio_neg_prompt": PP_MMAudio_neg_prompt,
        "seed": PP_MMAudio_seed,
        "repeat_generation": PP_repeat_generation,
        "audio_source": PP_custom_audio,
    }
    gen["edit_video_source"] = file_list[choice]
    gen["edit_overrides"] = overrides
    # Trigger immediately when idle, otherwise queue behind the running generation.
    in_progress = gen.get("in_progress", False)
    return "edit_remux", get_unique_id() if not in_progress else gr.update(), get_unique_id() if in_progress else gr.update()
def eject_video_from_gallery(state, input_file_list, choice):
    """Remove the selected item from the video/image gallery.

    file_list / file_settings_list are session-owned lists, so they are
    edited through slice assignment to mutate the shared objects rather than
    rebinding the local names. Returns the refreshed gallery, an info-panel
    update and the visibility of the gallery row.
    NOTE(review): `choice > len(file_list)` lets choice == len through; the
    slices tolerate it but selected_index may then equal len — confirm.
    """
    gen = get_gen_info(state)
    file_list, file_settings_list = get_file_list(state, input_file_list)
    with lock:
        if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list) :
            return gr.update(), gr.update(), gr.update()
        extend_list = file_list[choice + 1:] # inplace List change
        file_list[:] = file_list[:choice]
        file_list.extend(extend_list)
        extend_list = file_settings_list[choice + 1:]
        file_settings_list[:] = file_settings_list[:choice]
        file_settings_list.extend(extend_list)
        choice = min(choice, len(file_list))
    return gr.Gallery(value = file_list, selected_index= choice), gr.update() if len(file_list) >0 else get_default_video_info(), gr.Row(visible= len(file_list) > 0)
def eject_audio_from_gallery(state, input_file_list, choice):
    """Remove the selected item from the audio gallery.

    Mirrors eject_video_from_gallery: the session-owned lists are mutated in
    place through slice assignment. Returns the packed audio gallery state
    plus an info-panel update and the gallery row visibility (6 outputs).
    NOTE(review): `choice > len(file_list)` lets choice == len through; the
    slices tolerate it but the resulting selection may equal len — confirm.
    """
    gen = get_gen_info(state)
    file_list, file_settings_list = get_file_list(state, unpack_audio_list(input_file_list), audio_files=True)
    with lock:
        if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list) :
            return [gr.update()] * 6
        extend_list = file_list[choice + 1:] # inplace List change
        file_list[:] = file_list[:choice]
        file_list.extend(extend_list)
        extend_list = file_settings_list[choice + 1:]
        file_settings_list[:] = file_settings_list[:choice]
        file_settings_list.extend(extend_list)
        choice = min(choice, len(file_list))
    return *pack_audio_gallery_state(file_list, choice), gr.update() if len(file_list) >0 else get_default_video_info(), gr.Row(visible= len(file_list) > 0)
def add_videos_to_gallery(state, input_file_list, choice, audio_files_paths, audio_file_selected, files_to_load):
    """Import uploaded files into the matching gallery.

    Each file in files_to_load is classified as audio, video or image (from
    embedded settings metadata when present, otherwise by extension/probing);
    valid files are appended to the audio or video/image gallery lists (the
    session-owned lists are mutated in place), invalid ones are counted and
    reported. The selection of each gallery is moved to its last added item.
    NOTE(review): the early return emits 7 updates while the final return
    emits 6 + len(pack_audio_gallery_state(...)) items — confirm they match
    the event's outputs.
    """
    gen = get_gen_info(state)
    if files_to_load == None:
        return [gr.update()]*7
    new_audio= False
    new_video= False
    file_list, file_settings_list = get_file_list(state, input_file_list)
    audio_file_list, audio_file_settings_list = get_file_list(state, unpack_audio_list(audio_files_paths), audio_files= True)
    audio_file = False
    with lock:
        valid_files_count = 0
        invalid_files_count = 0
        for file_path in files_to_load:
            # Embedded settings (if any) also tell us whether this is an audio file.
            file_settings, _, audio_file = get_settings_from_file(state, file_path, False, False, False)
            if file_settings == None:
                # No metadata: classify by extension and probe-ability.
                audio_file = False
                fps = 0
                try:
                    if has_audio_file_extension(file_path):
                        audio_file = True
                    elif has_video_file_extension(file_path):
                        fps, width, height, frames_count = get_video_info(file_path)
                    elif has_image_file_extension(file_path):
                        width, height = Image.open(file_path).size
                        fps = 1
                except:
                    # Unreadable/corrupt file: falls through as invalid (fps == 0).
                    pass
                if fps == 0 and not audio_file:
                    invalid_files_count += 1
                    continue
            if audio_file:
                new_audio= True
                audio_file_list.append(file_path)
                audio_file_settings_list.append(file_settings)
            else:
                new_video= True
                file_list.append(file_path)
                file_settings_list.append(file_settings)
            valid_files_count +=1
    if valid_files_count== 0 and invalid_files_count ==0:
        gr.Info("No Video to Add")
    else:
        txt = ""
        if valid_files_count > 0:
            txt = f"{valid_files_count} files were added. " if valid_files_count > 1 else "One file was added."
        if invalid_files_count > 0:
            txt += f"Unable to add {invalid_files_count} files which were invalid. " if invalid_files_count > 1 else "Unable to add one file which was invalid."
        gr.Info(txt)
    # Move the video selection to the last added item (or clamp the old one).
    if new_video:
        choice = len(file_list) - 1
    else:
        choice = min(len(file_list) - 1, choice)
    gen["selected"] = choice
    if new_audio:
        audio_file_selected = len(audio_file_list) - 1
    else:
        # Fixed: clamp against the audio list (was mistakenly len(file_list) - 1).
        audio_file_selected = min(len(audio_file_list) - 1, audio_file_selected)
    gen["audio_selected"] = audio_file_selected
    # Bring the tab of the last classified file to the front.
    gallery_tabs = gr.Tabs(selected= "audio" if audio_file else "video_images")
    return gallery_tabs, 1 if audio_file else 0, gr.Gallery(value = file_list, selected_index=choice, preview= True) , *pack_audio_gallery_state(audio_file_list, audio_file_selected), gr.Files(value=[]), gr.Tabs(selected="video_info"), "audio" if audio_file else "video"
def get_model_settings(state, model_type):
    """Return the UI settings stored in this session for *model_type*,
    or None when no settings registry or entry exists yet."""
    registry = state.get("all_settings", None)
    if registry is None:
        return None
    return registry.get(model_type, None)
def set_model_settings(state, model_type, settings):
    """Store *settings* as the session UI settings for *model_type*,
    creating the per-session settings registry on first use."""
    registry = state.get("all_settings", None)
    if registry is None:
        registry = {}
        state["all_settings"] = registry
    registry[model_type] = settings
def collect_current_model_settings(state):
    """Return the current model's settings as a metadata-ready dict,
    tagged with the model filename and type.

    NOTE(review): `settings` is the very dict stored in the session state and
    prepare_inputs_dict pops keys from it before building the filtered copy,
    so the stored settings are mutated in place — confirm this is intended.
    """
    model_type = get_state_model_type(state)
    settings = get_model_settings(state, model_type)
    # prepare_inputs_dict expects (and pops) the session state under "state".
    settings["state"] = state
    settings = prepare_inputs_dict("metadata", settings)
    settings["model_filename"] = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy)
    settings["model_type"] = model_type
    return settings
def export_settings(state):
    """Serialize the current model settings to (base64-encoded JSON, timestamped file name)."""
    model_type = get_state_model_type(state)
    payload = json.dumps(collect_current_model_settings(state), indent=4)
    encoded = base64.b64encode(payload.encode('utf8')).decode('utf-8')
    stamp = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
    return encoded, sanitize_file_name(model_type + "_" + stamp + ".json")
def extract_and_apply_source_images(file_path, current_settings):
    """Copy source images embedded in a video file into *current_settings*.

    Returns the number of images applied (0 when the file is missing or
    carries no embedded images).
    """
    from shared.utils.video_metadata import extract_source_images
    if not os.path.isfile(file_path):
        return 0
    extracted = extract_source_images(file_path)
    if not extracted:
        return 0
    applied = 0
    for key in image_names_list:
        if key not in extracted:
            continue
        value = extracted[key]
        images = value if isinstance(value, list) else [value]
        current_settings[key] = images
        applied += len(images)
    return applied
def use_video_settings(state, input_file_list, choice, source):
    """Load the generation settings embedded in the gallery file at *choice*.

    source == "audio" means *input_file_list* is the packed audio-gallery state.
    Returns (3 model-dropdown updates, refresh trigger): dropdowns are rebuilt
    only when the file's model type is not compatible with the current one.
    """
    gen = get_gen_info(state)
    any_audio = source == "audio"
    if any_audio:
        input_file_list = unpack_audio_list(input_file_list)
    file_list, file_settings_list = get_file_list(state, input_file_list, audio_files=any_audio)
    if choice != None and len(file_list) > 0:
        choice = max(0, choice)
        configs = file_settings_list[choice]
        file_name = file_list[choice]
        if configs == None:
            gr.Info("No Settings to Extract")
        else:
            current_model_type = get_state_model_type(state)
            model_type = configs["model_type"]
            models_compatible = are_model_types_compatible(model_type, current_model_type)
            if models_compatible:
                # keep the user's current model when the stored type is interchangeable
                model_type = current_model_type
            defaults = get_model_settings(state, model_type)
            defaults = get_default_settings(model_type) if defaults == None else defaults
            defaults.update(configs)
            defaults["model_type"] = model_type
            prompt = configs.get("prompt", "")
            if has_audio_file_extension(file_name):
                set_model_settings(state, model_type, defaults)
                gr.Info(f"Settings Loaded from Audio File with prompt '{prompt[:100]}'")
            elif has_image_file_extension(file_name):
                set_model_settings(state, model_type, defaults)
                gr.Info(f"Settings Loaded from Image with prompt '{prompt[:100]}'")
            elif has_video_file_extension(file_name):
                # videos may carry their source images as attachments; apply them too
                extracted_images = extract_and_apply_source_images(file_name, defaults)
                set_model_settings(state, model_type, defaults)
                info_msg = f"Settings Loaded from Video with prompt '{prompt[:100]}'"
                if extracted_images:
                    info_msg += f" + {extracted_images} source {'image' if extracted_images == 1 else 'images'} extracted"
                gr.Info(info_msg)
            if models_compatible:
                return gr.update(), gr.update(), gr.update(), str(time.time())
            else:
                # incompatible model: rebuild the model dropdowns to switch to it
                return *generate_dropdown_model_list(model_type), gr.update()
    else:
        gr.Info(f"Please Select a File")
    return gr.update(), gr.update(), gr.update(), gr.update()
# Lazy-loaded mapping of local lora file path -> original download URL,
# persisted in loras_url_cache.json (None until first use).
loras_url_cache = None
def update_loras_url_cache(lora_dir, loras_selected):
    """Translate selected lora entries through the URL cache.

    Entries that are http(s) URLs are recorded in the on-disk cache as
    local-path -> URL (and the cache file is rewritten); every entry is then
    replaced by its cached URL when one is known, otherwise by its base file
    name.  Returns the translated list, or None when *loras_selected* is None.
    """
    if loras_selected is None:
        return None
    global loras_url_cache
    loras_cache_file = "loras_url_cache.json"
    if loras_url_cache is None:
        if os.path.isfile(loras_cache_file):
            try:
                with open(loras_cache_file, 'r', encoding='utf-8') as f:
                    loras_url_cache = json.load(f)
            except Exception:
                # corrupt or unreadable cache: start fresh
                # (was a bare `except:` which also swallowed KeyboardInterrupt)
                loras_url_cache = {}
        else:
            loras_url_cache = {}
    new_loras_selected = []
    update = False
    for lora in loras_selected:
        base_name = os.path.basename(lora)
        local_name = os.path.join(lora_dir, base_name)
        url = loras_url_cache.get(local_name, base_name)
        if (lora.startswith("http:") or lora.startswith("https:")) and url != lora:
            loras_url_cache[local_name] = lora
            update = True
        new_loras_selected.append(url)
    if update:
        with open(loras_cache_file, "w", encoding="utf-8") as writer:
            writer.write(json.dumps(loras_url_cache, indent=4))
    return new_loras_selected
def _ensure_loras_url_cache():
    # Force the lazy loras_url_cache to be loaded from disk (no-op translation).
    update_loras_url_cache("", [])
def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, switch_type_if_compatible, min_settings_version = 0, merge_loras = None):
    """Extract WanGP generation settings embedded in a file.

    Supports .json settings files (when *allow_json*), .mp4/.mkv video
    metadata, image metadata and audio metadata.  Optionally merges the
    extracted config over the model's defaults, remaps lora entries through
    the URL cache and merges them with previously activated loras.

    Returns (configs_or_None, is_image_or_video, is_audio).
    """
    configs = None
    any_image_or_video = False
    any_audio = False
    # Metadata extraction is best-effort: unreadable files yield configs=None.
    # Bare `except:` narrowed to `except Exception:` throughout so that
    # KeyboardInterrupt / SystemExit still propagate.
    if file_path.endswith(".json") and allow_json:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                configs = json.load(f)
        except Exception:
            pass
    elif file_path.endswith(".mp4") or file_path.endswith(".mkv"):
        from shared.utils.video_metadata import read_metadata_from_video
        try:
            configs = read_metadata_from_video(file_path)
            if configs:
                any_image_or_video = True
        except Exception:
            pass
    elif has_image_file_extension(file_path):
        try:
            configs = read_image_metadata(file_path)
            any_image_or_video = True
        except Exception:
            pass
    elif has_audio_file_extension(file_path):
        try:
            configs = read_audio_metadata(file_path)
            any_audio = True
        except Exception:
            pass
    if configs is None: return None, False, False
    try:
        if isinstance(configs, dict):
            # only trust metadata explicitly tagged as produced by WanGP,
            # unless the caller asked to merge with defaults anyway
            if (not merge_with_defaults) and not "WanGP" in configs.get("type", ""): configs = None
        else:
            configs = None
    except Exception:
        configs = None
    if configs is None: return None, False, False
    current_model_type = get_state_model_type(state)
    model_type = configs.get("model_type", None)
    if get_base_model_type(model_type) == None:
        # fall back: explicit base_model_type, then guess from the model filename
        model_type = configs.get("base_model_type", None)
        if model_type == None:
            model_filename = configs.get("model_filename", "")
            model_type = get_model_type(model_filename)
            if model_type == None:
                model_type = current_model_type
        elif not model_type in model_types:
            model_type = current_model_type
    if switch_type_if_compatible and are_model_types_compatible(model_type, current_model_type):
        model_type = current_model_type
    old_loras_selected = old_loras_multipliers = None
    if merge_with_defaults:
        defaults = get_model_settings(state, model_type)
        defaults = get_default_settings(model_type) if defaults == None else defaults
        if merge_loras is not None and model_type == current_model_type:
            # remember the previously activated loras so they can be merged below
            old_loras_selected, old_loras_multipliers = defaults.get("activated_loras", []), defaults.get("loras_multipliers", "")
        defaults.update(configs)
        configs = defaults
    loras_selected = configs.get("activated_loras", [])
    loras_multipliers = configs.get("loras_multipliers", "")
    if loras_selected is not None and len(loras_selected) > 0:
        loras_selected = update_loras_url_cache(get_lora_dir(model_type), loras_selected)
    if old_loras_selected is not None:
        if len(old_loras_selected) == 0 and "|" in loras_multipliers:
            # phased multipliers with no previous loras: nothing to merge
            pass
        else:
            old_loras_selected = update_loras_url_cache(get_lora_dir(model_type), old_loras_selected)
            loras_selected, loras_multipliers = merge_loras_settings(old_loras_selected, old_loras_multipliers, loras_selected, loras_multipliers, merge_loras)
    configs["activated_loras"] = loras_selected or []
    configs["loras_multipliers"] = loras_multipliers
    fix_settings(model_type, configs, min_settings_version)
    configs["model_type"] = model_type
    return configs, any_image_or_video, any_audio
def record_image_mode_tab(state, evt:gr.SelectData):
    # Remember which image-mode tab (by index) the user last selected.
    state["image_mode_tab"] = evt.index
def switch_image_mode(state):
    """Apply the last selected image-mode tab to the current model's settings.

    For models with inpaint support, swaps video_prompt_type between the
    per-mode values cached in state["inpaint_cache"] (mode 1 vs mode 2),
    synthesizing a default when no cached value exists yet.
    Returns a fresh timestamp string used as a UI refresh trigger.
    """
    image_mode = state.get("image_mode_tab", 0)
    model_type = get_state_model_type(state)
    ui_defaults = get_model_settings(state, model_type)
    ui_defaults["image_mode"] = image_mode
    video_prompt_type = ui_defaults.get("video_prompt_type", "")
    model_def = get_model_def( model_type)
    inpaint_support = model_def.get("inpaint_support", False)
    if inpaint_support:
        model_type = get_state_model_type(state)
        # two-level cache: state["inpaint_cache"][model_type][image_mode] -> prompt type
        inpaint_cache = state.get("inpaint_cache", None)
        if inpaint_cache is None:
            state["inpaint_cache"] = inpaint_cache = {}
        model_cache = inpaint_cache.get(model_type, None)
        if model_cache is None:
            inpaint_cache[model_type] = model_cache = {}
        video_prompt_inpaint_mode = model_def.get("inpaint_video_prompt_type", "VAG")
        video_prompt_image_mode = "KI"
        old_video_prompt_type = video_prompt_type
        if image_mode == 1:
            # stash the mode-2 prompt type, restore (or synthesize) the mode-1 one
            model_cache[2] = video_prompt_type
            video_prompt_type = model_cache.get(1, None)
            if video_prompt_type is None:
                video_prompt_type = del_in_sequence(old_video_prompt_type, video_prompt_inpaint_mode + all_guide_processes)
                video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_image_mode)
        elif image_mode == 2:
            # stash the mode-1 prompt type, restore (or synthesize) the mode-2 one
            model_cache[1] = video_prompt_type
            video_prompt_type = model_cache.get(2, None)
            if video_prompt_type is None:
                video_prompt_type = del_in_sequence(old_video_prompt_type, video_prompt_image_mode + all_guide_processes)
                video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_inpaint_mode)
        ui_defaults["video_prompt_type"] = video_prompt_type
    return str(time.time())
def load_settings_from_file(state, file_path):
    """Load generation settings from a dropped file (json/video/image/audio).

    Applies the settings to the matching model (switching the UI model when it
    differs) and returns (3 model-dropdown updates, refresh trigger, None) —
    the trailing None clears the file-upload widget.
    """
    gen = get_gen_info(state)
    if file_path == None:
        return gr.update(), gr.update(), gr.update(), gr.update(), None
    configs, any_video_or_image_file, any_audio = get_settings_from_file(state, file_path, True, True, True)
    if configs == None:
        gr.Info("File not supported")
        return gr.update(), gr.update(), gr.update(), gr.update(), None
    current_model_type = get_state_model_type(state)
    model_type = configs["model_type"]
    prompt = configs.get("prompt", "")
    is_image = configs.get("is_image", False)
    # Extract and apply embedded source images from video files
    extracted_images = 0
    if file_path.endswith('.mkv') or file_path.endswith('.mp4'):
        extracted_images = extract_and_apply_source_images(file_path, configs)
    if any_audio:
        gr.Info(f"Settings Loaded from Audio file with prompt '{prompt[:100]}'")
    elif any_video_or_image_file:
        info_msg = f"Settings Loaded from {'Image' if is_image else 'Video'} generated with prompt '{prompt[:100]}'"
        if extracted_images > 0:
            info_msg += f" + {extracted_images} source image(s) extracted and applied"
        gr.Info(info_msg)
    else:
        gr.Info(f"Settings Loaded from Settings file with prompt '{prompt[:100]}'")
    if model_type == current_model_type:
        # same model: refresh the form in place via the timestamp trigger
        set_model_settings(state, current_model_type, configs)
        return gr.update(), gr.update(), gr.update(), str(time.time()), None
    else:
        # different model: rebuild the dropdowns to switch to it
        set_model_settings(state, model_type, configs)
        return *generate_dropdown_model_list(model_type), gr.update(), None
def goto_model_type(state, model_type):
    # Switch the UI to *model_type* by rebuilding the model-selection dropdowns.
    gen = get_gen_info(state)
    return *generate_dropdown_model_list(model_type), gr.update()
def reset_settings(state):
    """Restore the current model's default settings and return a refresh trigger."""
    model_type = get_state_model_type(state)
    ui_defaults = get_default_settings(model_type)
    set_model_settings(state, model_type, ui_defaults)
    gr.Info(f"Default Settings have been Restored")
    return str(time.time())
def save_inputs(
    target,
    image_mask_guide,
    lset_name,
    image_mode,
    prompt,
    negative_prompt,
    resolution,
    video_length,
    batch_size,
    seed,
    force_fps,
    num_inference_steps,
    guidance_scale,
    guidance2_scale,
    guidance3_scale,
    switch_threshold,
    switch_threshold2,
    guidance_phases,
    model_switch_phase,
    alt_guidance_scale,
    audio_guidance_scale,
    audio_scale,
    flow_shift,
    sample_solver,
    embedded_guidance_scale,
    repeat_generation,
    multi_prompts_gen_type,
    multi_images_gen_type,
    skip_steps_cache_type,
    skip_steps_multiplier,
    skip_steps_start_step_perc,
    loras_choices,
    loras_multipliers,
    image_prompt_type,
    image_start,
    image_end,
    model_mode,
    video_source,
    keep_frames_video_source,
    input_video_strength,
    video_guide_outpainting,
    video_prompt_type,
    image_refs,
    frames_positions,
    video_guide,
    image_guide,
    keep_frames_video_guide,
    denoising_strength,
    masking_strength,
    video_mask,
    image_mask,
    control_net_weight,
    control_net_weight2,
    control_net_weight_alt,
    motion_amplitude,
    mask_expand,
    audio_guide,
    audio_guide2,
    custom_guide,
    audio_source,
    audio_prompt_type,
    speakers_locations,
    sliding_window_size,
    sliding_window_overlap,
    sliding_window_color_correction_strength,
    sliding_window_overlap_noise,
    sliding_window_discard_last_frames,
    image_refs_relative_size,
    remove_background_images_ref,
    temporal_upsampling,
    spatial_upsampling,
    film_grain_intensity,
    film_grain_saturation,
    MMAudio_setting,
    MMAudio_prompt,
    MMAudio_neg_prompt,
    RIFLEx_setting,
    NAG_scale,
    NAG_tau,
    NAG_alpha,
    slg_switch,
    slg_layers,
    slg_start_perc,
    slg_end_perc,
    apg_switch,
    cfg_star_switch,
    cfg_zero_step,
    prompt_enhancer,
    min_frames_if_references,
    override_profile,
    pace,
    exaggeration,
    temperature,
    output_filename,
    mode,
    state,
    plugin_data,
):
    """Persist the full set of generation-form values.

    *target* selects the destination: "settings" writes the model's default
    settings JSON file, "state" caches them per model in the session state,
    "edit_state" stores them for the queue-edit flow.  All other parameters
    mirror the UI controls one-to-one; they are harvested by name via
    get_function_arguments(save_inputs, locals()), so parameter names must
    match the keys expected by prepare_inputs_dict.
    """
    if state.pop("ignore_save_form", False):
        return
    model_type = get_state_model_type(state)
    # When the combined mask editor is active ("A" without "U" in
    # video_prompt_type), split its value back into image_guide + image_mask.
    if image_mask_guide is not None and image_mode >= 1 and video_prompt_type is not None and "A" in video_prompt_type and not "U" in video_prompt_type:
        # if image_mask_guide is not None and image_mode == 2:
        if "background" in image_mask_guide:
            image_guide = image_mask_guide["background"]
        if "layers" in image_mask_guide and len(image_mask_guide["layers"]) > 0:
            image_mask = image_mask_guide["layers"][0]
        image_mask_guide = None
    inputs = get_function_arguments(save_inputs, locals())  # snapshot of all parameters by name
    inputs.pop("target")
    inputs.pop("image_mask_guide")
    cleaned_inputs = prepare_inputs_dict(target, inputs)
    if target == "settings":
        # write as the new default settings file for this model
        defaults_filename = get_settings_file_name(model_type)
        with open(defaults_filename, "w", encoding="utf-8") as f:
            json.dump(cleaned_inputs, f, indent=4)
        gr.Info("New Default Settings saved")
    elif target == "state":
        set_model_settings(state, model_type, cleaned_inputs)
    elif target == "edit_state":
        state["edit_state"] = cleaned_inputs
def handle_queue_action(state, action_string):
    """Dispatch a queue action encoded as "<action>_<params...>".

    Supported actions: edit / silent_edit (load a task for editing),
    move_<old>_to_<new>, remove_<task_id>.
    Always returns 4 outputs: (queue data/HTML, tabs update, edit-visibility
    update, edit-trigger update).  The empty-action and parse-error paths
    previously returned only 3 values, mismatching the other branches (and the
    Gradio outputs list); they now return 4.
    """
    if not action_string:
        return gr.HTML(), gr.Tabs(), gr.update(), gr.update()
    gen = get_gen_info(state)
    queue = gen.get("queue", [])
    try:
        parts = action_string.split('_')
        action = parts[0]
        params = parts[1:]
    except (IndexError, ValueError):
        return update_queue_data(queue), gr.Tabs(), gr.update(), gr.update()
    if action == "edit" or action == "silent_edit":
        task_id = int(params[0])
        with lock:
            task_index = next((i for i, task in enumerate(queue) if task['id'] == task_id), -1)
            if task_index != -1:
                state["editing_task_id"] = task_id
                task_data = queue[task_index]
                if task_index == 1:
                    # index 1 is the next task to run: pause after the current generation
                    gen["queue_paused_for_edit"] = True
                    gr.Info("Queue processing will pause after the current generation, as you are editing the next item to generate.")
                if action == "edit":
                    gr.Info(f"Loading task ID {task_id} ('{task_data['prompt'][:50]}...') for editing.")
                return update_queue_data(queue), gr.Tabs(selected="edit"), gr.update(visible=True), get_unique_id()
            else:
                gr.Warning("Task ID not found. It may have already been processed.")
                return update_queue_data(queue), gr.Tabs(), gr.update(), gr.update()
    elif action == "move" and len(params) == 3 and params[1] == "to":
        old_index_str, new_index_str = params[0], params[2]
        return move_task(queue, old_index_str, new_index_str), gr.Tabs(), gr.update(), gr.update()
    elif action == "remove":
        task_id_to_remove = int(params[0])
        new_queue_data = remove_task(queue, task_id_to_remove)
        gen["prompts_max"] = gen.get("prompts_max", 0) - 1
        update_status(state)
        return new_queue_data, gr.Tabs(), gr.update(), gr.update()
    return update_queue_data(queue), gr.Tabs(), gr.update(), gr.update()
def change_model(state, model_choice):
    """Switch the active model, persist the choice in the server config file,
    and return the refreshed UI header HTML (None input returns None)."""
    if model_choice == None:
        return
    # NOTE(review): model_filename is computed but unused here — confirm whether
    # the call is kept for its side effects or can be dropped.
    model_filename = get_model_filename(model_choice, transformer_quantization, transformer_dtype_policy)
    # remember the last choice per family and per base type for quick switching
    last_model_per_family = state["last_model_per_family"]
    last_model_per_family[get_model_family(model_choice, for_ui= True)] = model_choice
    server_config["last_model_per_family"] = last_model_per_family
    last_model_per_type = state["last_model_per_type"]
    last_model_per_type[get_base_model_type(model_choice)] = model_choice
    server_config["last_model_per_type"] = last_model_per_type
    server_config["last_model_type"] = model_choice
    with open(server_config_filename, "w", encoding="utf-8") as writer:
        writer.write(json.dumps(server_config, indent=4))
    state["model_type"] = model_choice
    header = generate_header(model_choice, compile=compile, attention_mode=attention_mode)
    return header
def get_current_model_settings(state):
    """Return the UI settings of the current model, seeding the defaults on first access."""
    model_type = get_state_model_type(state)
    settings = get_model_settings(state, model_type)
    if settings is None:
        settings = get_default_settings(model_type)
        set_model_settings(state, model_type, settings)
    return settings
def fill_inputs(state):
    # Rebuild the whole generation form from the current model's settings.
    ui_defaults = get_current_model_settings(state)
    return generate_video_tab(update_form = True, state_dict = state, ui_defaults = ui_defaults)
def preload_model_when_switching(state):
    """Generator: when preload policy "S" is active, eagerly load the newly
    selected model, yielding progress strings for the UI status text.

    NOTE(review): because this function contains yields it is a generator, so
    the final `return gr.Text()` only sets StopIteration.value — confirm the
    caller relies on the yields, not the return value.
    """
    global reload_needed, wan_model, offloadobj
    if "S" in preload_model_policy:
        model_type = get_state_model_type(state)
        if model_type != transformer_type:
            # drop the previous model before loading the new one
            wan_model = None
            release_model()
            model_filename = get_model_name(model_type)
            yield f"Loading model {model_filename}..."
            wan_model, offloadobj = load_models(model_type)
            yield f"Model loaded"
            reload_needed = False
        return
    return gr.Text()
def unload_model_if_needed(state):
    """Release the loaded model when unload policy "U" is enabled."""
    global wan_model
    if "U" in preload_model_policy and wan_model != None:
        wan_model = None
        release_model()
def all_letters(source_str, letters):
    """Return True when every character of *letters* occurs in *source_str* (vacuously True for empty *letters*)."""
    return all(letter in source_str for letter in letters)
def any_letters(source_str, letters):
    """Return True when at least one character of *letters* occurs in *source_str*."""
    return any(letter in source_str for letter in letters)
def filter_letters(source_str, letters, default= ""):
    """Return the characters of *letters* (in *letters* order) that occur in *source_str*, or *default* when none do."""
    kept = "".join(letter for letter in letters if letter in source_str)
    return kept if kept else default
def add_to_sequence(source_str, letters):
    """Append each character of *letters* that is not already in *source_str*.

    Membership is tested against the original string only, so a character
    repeated in *letters* may be appended more than once.
    """
    missing = [letter for letter in letters if letter not in source_str]
    return source_str + "".join(missing)
def del_in_sequence(source_str, letters):
    """Return *source_str* with every occurrence of each character in *letters* removed."""
    return "".join(ch for ch in source_str if ch not in letters)
def refresh_audio_prompt_type_remux(state, audio_prompt_type, remux):
    """Replace the remux flag "R" in audio_prompt_type with the selected value."""
    stripped = del_in_sequence(audio_prompt_type, "R")
    return add_to_sequence(stripped, remux)
def refresh_remove_background_sound(state, audio_prompt_type, remove_background_sound):
    """Toggle the "V" (remove background sound) flag in audio_prompt_type."""
    cleaned = del_in_sequence(audio_prompt_type, "V")
    return add_to_sequence(cleaned, "V") if remove_background_sound else cleaned
def refresh_audio_prompt_type_sources(state, audio_prompt_type, audio_prompt_type_sources):
    # Swap in the selected audio-source letters ("XCPABK") and recompute the
    # visibility of the dependent widgets (audio guides, speakers locations, ...).
    audio_prompt_type = del_in_sequence(audio_prompt_type, "XCPABK")
    audio_prompt_type = add_to_sequence(audio_prompt_type, audio_prompt_type_sources)
    return audio_prompt_type, gr.update(visible = "A" in audio_prompt_type), gr.update(visible = "B" in audio_prompt_type), gr.update(visible = ("B" in audio_prompt_type or "X" in audio_prompt_type)), gr.update(visible= any_letters(audio_prompt_type, "ABXK")), gr.update(visible= any_letters(audio_prompt_type,"AB"))
def refresh_image_prompt_type_radio(state, image_prompt_type, image_prompt_type_radio):
    """Apply the start-mode radio selection (letters "VLTS") to image_prompt_type.

    Returns the new flags plus visibility updates for the start image, end
    image, source video, input-video strength, keep-frames and end-checkbox
    widgets.
    """
    image_prompt_type = del_in_sequence(image_prompt_type, "VLTS")
    image_prompt_type = add_to_sequence(image_prompt_type, image_prompt_type_radio)
    any_video_source = len(filter_letters(image_prompt_type, "VL"))>0
    model_def = get_model_def(get_state_model_type(state))
    image_prompt_types_allowed = model_def.get("image_prompt_types_allowed", "")
    # end frame only offered when the model allows "E" and a start/source is selected
    end_visible = "E" in image_prompt_types_allowed and any_letters(image_prompt_type, "SVL")
    input_strength_visible = len(model_def.get("input_video_strength","")) and any_letters(image_prompt_type, "SVL")
    return image_prompt_type, gr.update(visible = "S" in image_prompt_type ), gr.update(visible = end_visible and ("E" in image_prompt_type) ), gr.update(visible = "V" in image_prompt_type) , gr.update(visible = input_strength_visible), gr.update(visible = any_video_source), gr.update(visible = end_visible)
def refresh_image_prompt_type_endcheckbox(state, image_prompt_type, image_prompt_type_radio, end_checkbox):
    """Toggle the end-frame flag "E" and reapply the radio-selected letters."""
    base = del_in_sequence(image_prompt_type, "E") + ("E" if end_checkbox else "")
    result = add_to_sequence(base, image_prompt_type_radio)
    return result, gr.update(visible = "E" in result)
def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_type_image_refs, image_mode):
    """Apply the reference-images selection to video_prompt_type.

    Clears the model's declared image-ref letters (default "KFI") before
    adding the new ones, then computes visibility updates for the ref-image
    widgets (gallery, background removal, relative size, frame positions,
    outpainting).
    """
    model_type = get_state_model_type(state)
    model_def = get_model_def(model_type)
    image_ref_choices = model_def.get("image_ref_choices", None)
    if image_ref_choices is not None:
        video_prompt_type = del_in_sequence(video_prompt_type, image_ref_choices["letters_filter"])
    else:
        video_prompt_type = del_in_sequence(video_prompt_type, "KFI")
    video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_image_refs)
    visible = "I" in video_prompt_type
    any_outpainting = image_mode in model_def.get("video_guide_outpainting", [])
    rm_bg_visible = visible and not model_def.get("no_background_removal", False)
    img_rel_size_visible = visible and model_def.get("any_image_refs_relative_size", False)
    return video_prompt_type, gr.update(visible = visible),gr.update(visible = rm_bg_visible), gr.update(visible = img_rel_size_visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "K" in video_prompt_type_image_refs or "V" in video_prompt_type) and any_outpainting )
def update_image_mask_guide(state, image_mask_guide):
    """Split an RGBA editor background into an RGB image plus a mask layer.

    The alpha channel is inverted, expanded to 3 channels and converted into
    a mask layer; non-RGBA backgrounds are returned untouched.
    """
    background = image_mask_guide["background"]
    if background.mode != 'RGBA':
        return image_mask_guide
    pixels = np.array(background)
    rgb_image = Image.fromarray(pixels[..., :3], 'RGB')
    inverted_alpha = 255 - np.repeat(pixels[..., 3:4], 3, axis=2)
    mask_image = Image.fromarray(inverted_alpha, 'RGB')
    return {"background" : rgb_image, "composite" : None, "layers": [rgb_bw_to_rgba_mask(mask_image)]}
def switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ):
    """Swap between the combined mask editor and the separate guide/mask images.

    When the mask flag ("A" present without "U") toggles between the old and
    new video_prompt_type, migrate the current widget values from one widget
    set to the other.  Returns gr.update()s for
    (image_mask_guide, image_guide, image_mask).
    """
    if image_mode == 0: return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
    mask_in_old = "A" in old_video_prompt_type and not "U" in old_video_prompt_type
    mask_in_new = "A" in video_prompt_type and not "U" in video_prompt_type
    image_mask_guide_value, image_mask_value, image_guide_value = {}, {}, {}
    visible = "V" in video_prompt_type
    if mask_in_old != mask_in_new:
        if mask_in_new:
            # moving to the combined editor: build its value from guide + mask
            if old_image_mask_value is None:
                image_mask_guide_value["value"] = old_image_guide_value
            else:
                image_mask_guide_value["value"] = {"background" : old_image_guide_value, "composite" : None, "layers": [rgb_bw_to_rgba_mask(old_image_mask_value)]}
            image_guide_value["value"] = image_mask_value["value"] = None
        else:
            # moving back to separate widgets: split the combined editor value
            if old_image_mask_guide_value is not None and "background" in old_image_mask_guide_value:
                image_guide_value["value"] = old_image_mask_guide_value["background"]
                if "layers" in old_image_mask_guide_value:
                    image_mask_value["value"] = old_image_mask_guide_value["layers"][0] if len(old_image_mask_guide_value["layers"]) >=1 else None
            image_mask_guide_value["value"] = {"background" : None, "composite" : None, "layers": []}
    image_mask_guide = gr.update(visible= visible and mask_in_new, **image_mask_guide_value)
    image_guide = gr.update(visible = visible and not mask_in_new, **image_guide_value)
    image_mask = gr.update(visible = False, **image_mask_value)
    return image_mask_guide, image_guide, image_mask
def refresh_video_prompt_type_video_mask(state, video_prompt_type, video_prompt_type_video_mask, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ):
    """Apply the mask-mode selection (letters "XYZWNA") to video_prompt_type.

    Migrates the guide/mask editor widgets via switch_image_guide_editor and
    returns visibility updates for the mask-related controls.
    """
    old_video_prompt_type = video_prompt_type
    video_prompt_type = del_in_sequence(video_prompt_type, "XYZWNA")
    video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_mask)
    visible = "A" in video_prompt_type
    model_type = get_state_model_type(state)
    model_def = get_model_def(model_type)
    image_outputs = image_mode > 0
    mask_strength_always_enabled = model_def.get("mask_strength_always_enabled", False)
    image_mask_guide, image_guide, image_mask = switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value )
    return video_prompt_type, gr.update(visible= visible and not image_outputs), image_mask_guide, image_guide, image_mask, gr.update(visible= visible ) , gr.update(visible= visible and (mask_strength_always_enabled or "G" in video_prompt_type ) )
def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_type_video_guide):
    """Replace the alignment flag "T" in video_prompt_type with the new selection."""
    return add_to_sequence(del_in_sequence(video_prompt_type, "T"), video_prompt_type_video_guide)
def refresh_video_prompt_type_video_guide(state, filter_type, video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ):
    """Apply a control-video (guide) selection to video_prompt_type.

    filter_type "alt" clears only the model's guide_custom_choices letters;
    otherwise the global guide-process letters are replaced.  Returns the new
    flags plus visibility updates for all guide/mask-related widgets.
    """
    model_type = get_state_model_type(state)
    model_def = get_model_def(model_type)
    old_video_prompt_type = video_prompt_type
    if filter_type == "alt":
        guide_custom_choices = model_def.get("guide_custom_choices",{})
        letter_filter = guide_custom_choices.get("letters_filter","")
    else:
        letter_filter = all_guide_processes
    video_prompt_type = del_in_sequence(video_prompt_type, letter_filter)
    video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
    visible = "V" in video_prompt_type
    any_outpainting = image_mode in model_def.get("video_guide_outpainting", [])
    # mask widgets only when "A" is set and masks are not user-provided ("U")
    mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type
    image_outputs = image_mode > 0
    keep_frames_video_guide_visible = not image_outputs and visible and not model_def.get("keep_frames_video_guide_not_supported", False)
    image_mask_guide, image_guide, image_mask = switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value )
    # mask_video_input_visible = image_mode == 0 and mask_visible
    mask_preprocessing = model_def.get("mask_preprocessing", None)
    if mask_preprocessing is not None:
        mask_selector_visible = mask_preprocessing.get("visible", True)
    else:
        mask_selector_visible = True
    ref_images_visible = "I" in video_prompt_type
    # the model may expose an extra custom selector (dropdown or checkbox),
    # optionally gated on a trigger letter being present in video_prompt_type
    custom_options = custom_checkbox = False
    custom_video_selection = model_def.get("custom_video_selection", None)
    if custom_video_selection is not None:
        custom_trigger = custom_video_selection.get("trigger","")
        if len(custom_trigger) == 0 or custom_trigger in video_prompt_type:
            custom_options = True
            custom_checkbox = custom_video_selection.get("type","") == "checkbox"
    mask_strength_always_enabled = model_def.get("mask_strength_always_enabled", False)
    return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible = mask_visible and( mask_strength_always_enabled or "G" in video_prompt_type)), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and mask_selector_visible and not "U" in video_prompt_type ) , gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible), gr.update(visible = ref_images_visible ), gr.update(visible= custom_options and not custom_checkbox ), gr.update(visible= custom_options and custom_checkbox )
def refresh_video_prompt_type_video_custom_dropbox(state, video_prompt_type, video_prompt_type_video_custom_dropbox):
    # Replace the model's custom-video-selection letters with the dropdown choice;
    # no-op when the model defines no custom selector.
    model_type = get_state_model_type(state)
    model_def = get_model_def(model_type)
    custom_video_selection = model_def.get("custom_video_selection", None)
    if custom_video_selection is None: return gr.update()
    letters_filter = custom_video_selection.get("letters_filter", "")
    video_prompt_type = del_in_sequence(video_prompt_type, letters_filter)
    video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_custom_dropbox)
    return video_prompt_type
def refresh_video_prompt_type_video_custom_checkbox(state, video_prompt_type, video_prompt_type_video_custom_checkbox):
    # Toggle the model's custom-video-selection letters with the checkbox:
    # checked applies the second choice's letters, unchecked just clears them.
    model_type = get_state_model_type(state)
    model_def = get_model_def(model_type)
    custom_video_selection = model_def.get("custom_video_selection", None)
    if custom_video_selection is None: return gr.update()
    letters_filter = custom_video_selection.get("letters_filter", "")
    video_prompt_type = del_in_sequence(video_prompt_type, letters_filter)
    if video_prompt_type_video_custom_checkbox:
        video_prompt_type = add_to_sequence(video_prompt_type, custom_video_selection["choices"][1][1])
    return video_prompt_type
def refresh_preview(state):
gen = get_gen_info(state)
preview_image = gen.get("preview", None)
if preview_image is None:
return ""
preview_base64 = pil_to_base64_uri(preview_image, format="jpeg", quality=85)
if preview_base64 is None:
return ""
html_content = f"""