File size: 2,805 Bytes
6104fdc 62febf3 6104fdc 62febf3 6104fdc 62febf3 6104fdc 62febf3 6104fdc 62febf3 6104fdc 62febf3 6104fdc 62febf3 6104fdc 62febf3 6104fdc 62febf3 6104fdc 62febf3 6104fdc 62febf3 6104fdc 62febf3 6104fdc 62febf3 6104fdc 62febf3 6104fdc 62febf3 6104fdc 182c437 6104fdc 721dfc7 6104fdc 182c437 62febf3 6104fdc 182c437 62febf3 6104fdc 62febf3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import os
import sys
import time
import torch
# Make the project packages importable regardless of the working directory:
# register this file's directory and its two nearest ancestors on sys.path.
current_file_path = os.path.abspath(__file__)

project_roots = []
_ancestor = current_file_path
while len(project_roots) < 3:
    _ancestor = os.path.dirname(_ancestor)
    project_roots.append(_ancestor)
del _ancestor

for project_root in project_roots:
    if project_root in sys.path:
        continue
    # Each insert goes to the front, so directories later in project_roots
    # (higher up the tree) end up ahead of earlier ones on sys.path.
    sys.path.insert(0, project_root)
from cogvideox.api.api import (
infer_forward_api,
update_diffusion_transformer_api,
update_edition_api
)
from cogvideox.ui.controller import flow_scheduler_dict
from cogvideox.ui.wan_fun_ui import ui, ui_eas, ui_modelscope
if __name__ == "__main__":
    import inspect  # only needed when run as a script (ssr_mode probe below)

    # --- Configuration ---
    # UI mode: "modelscope", "eas", or any other value for the default UI.
    ui_mode = "eas"

    # GPU memory mode: choices are
    # - "model_cpu_offload"
    # - "model_cpu_offload_and_qfloat8"
    # - "sequential_cpu_offload"
    # Passed to the "modelscope" and default UIs (not to the "eas" UI).
    GPU_memory_mode = "model_cpu_offload"

    # Weight dtype: bfloat16 when CUDA is available and reports bf16 support,
    # otherwise float16. Passed to the "modelscope" and default UIs.
    weight_dtype = (
        torch.bfloat16
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        else torch.float16
    )

    # Path to the OmegaConf config for WAN2.1. This is a relative path;
    # presumably resolved against the current working directory by the UI
    # code -- launch from the repository root (TODO confirm).
    config_path = "config/wan2.1/wan_civitai.yaml"

    # Server binding for Gradio.
    server_name = "0.0.0.0"
    server_port = 7860

    # Model location/type and sample output directory.
    # model_name and savedir_sample are used by the "modelscope" and "eas"
    # UIs; model_type is used only by the "modelscope" UI.
    model_name = "models/Diffusion_Transformer/Wan2.1-Fun-1.3B-InP"
    model_type = "Inpaint"  # or "Control"
    savedir_sample = "samples"

    # --- Initialize UI & Controller ---
    if ui_mode == "modelscope":
        demo, controller = ui_modelscope(
            model_name,
            model_type,
            savedir_sample,
            GPU_memory_mode,
            flow_scheduler_dict,
            weight_dtype,
            config_path,
        )
    elif ui_mode == "eas":
        demo, controller = ui_eas(
            model_name,
            flow_scheduler_dict,
            savedir_sample,
            config_path,
        )
    else:
        demo, controller = ui(
            GPU_memory_mode,
            flow_scheduler_dict,
            weight_dtype,
            config_path,
        )

    # --- Launch Gradio app ---
    # share=False: no public Gradio share link; serve on server_name:server_port.
    # prevent_thread_lock=True: launch() does not block, so the API routes can
    # be mounted afterwards; the loop at the end keeps the process alive.
    launch_kwargs = dict(
        share=False,
        server_name=server_name,
        server_port=server_port,
        prevent_thread_lock=True,
    )
    queued_demo = demo.queue(status_update_rate=1)
    # Disable experimental server-side rendering to avoid 405 errors, but only
    # when this Gradio version's launch() accepts `ssr_mode` (Gradio 5+);
    # passing an unknown keyword to older versions would raise TypeError.
    if "ssr_mode" in inspect.signature(queued_demo.launch).parameters:
        launch_kwargs["ssr_mode"] = False
    app, _, _ = queued_demo.launch(**launch_kwargs)

    # --- Mount API endpoints on the server app returned by launch() ---
    infer_forward_api(None, app, controller)
    update_diffusion_transformer_api(None, app, controller)
    update_edition_api(None, app, controller)

    # Keep the script alive; on Ctrl+C close the Gradio server cleanly
    # instead of exiting with a KeyboardInterrupt traceback.
    try:
        while True:
            time.sleep(5)
    except KeyboardInterrupt:
        demo.close()
|