File size: 2,805 Bytes
6104fdc
 
 
 
 
 
62febf3
6104fdc
62febf3
 
 
 
 
6104fdc
62febf3
 
6104fdc
62febf3
 
 
 
 
6104fdc
 
 
 
62febf3
 
 
6104fdc
62febf3
 
 
 
 
6104fdc
62febf3
 
 
 
 
 
 
 
 
6104fdc
 
62febf3
6104fdc
 
 
62febf3
6104fdc
62febf3
6104fdc
 
62febf3
 
6104fdc
62febf3
 
 
 
 
 
 
 
 
6104fdc
62febf3
 
 
 
 
 
6104fdc
62febf3
 
 
 
 
 
6104fdc
62febf3
 
 
6104fdc
182c437
6104fdc
 
721dfc7
6104fdc
182c437
62febf3
 
6104fdc
 
 
182c437
62febf3
6104fdc
62febf3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import sys
import time

import torch

# Make the script's directory and its two ancestors importable, so the
# `cogvideox` package resolves no matter where this script is launched from.
current_file_path = os.path.abspath(__file__)
project_roots = []
_walk = current_file_path
for _ in range(3):
    _walk = os.path.dirname(_walk)
    project_roots.append(_walk)
for project_root in project_roots:
    # insert(0, ...) so project modules shadow any same-named installed packages
    if project_root not in sys.path:
        sys.path.insert(0, project_root)

from cogvideox.api.api import (
    infer_forward_api,
    update_diffusion_transformer_api,
    update_edition_api
)
from cogvideox.ui.controller import flow_scheduler_dict
from cogvideox.ui.wan_fun_ui import ui, ui_eas, ui_modelscope

if __name__ == "__main__":
    # --- Configuration ---

    # UI variant selector: "eas", "modelscope", or anything else for the default UI.
    ui_mode = "eas"

    # How model weights are placed across GPU/CPU memory. Options:
    #   "model_cpu_offload"
    #   "model_cpu_offload_and_qfloat8"
    #   "sequential_cpu_offload"
    GPU_memory_mode = "model_cpu_offload"

    # Prefer bfloat16 when the CUDA device supports it; otherwise fall back to float16.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        weight_dtype = torch.bfloat16
    else:
        weight_dtype = torch.float16

    # OmegaConf config describing the WAN2.1 pipeline.
    config_path = "config/wan2.1/wan_civitai.yaml"

    # Gradio server binding.
    server_name = "0.0.0.0"
    server_port = 7860

    # Settings consumed only by the "modelscope" (and partly "eas") UI variants.
    model_name = "models/Diffusion_Transformer/Wan2.1-Fun-1.3B-InP"
    model_type = "Inpaint"   # or "Control"
    savedir_sample = "samples"

    # --- Build the requested UI variant ---

    if ui_mode == "eas":
        demo, controller = ui_eas(
            model_name,
            flow_scheduler_dict,
            savedir_sample,
            config_path,
        )
    elif ui_mode == "modelscope":
        demo, controller = ui_modelscope(
            model_name,
            model_type,
            savedir_sample,
            GPU_memory_mode,
            flow_scheduler_dict,
            weight_dtype,
            config_path,
        )
    else:
        demo, controller = ui(
            GPU_memory_mode,
            flow_scheduler_dict,
            weight_dtype,
            config_path,
        )

    # --- Launch Gradio app ---
    # prevent_thread_lock=True makes launch() return immediately so we can
    # attach the API routes below before the process settles into its idle loop.
    # NOTE(review): an earlier comment claimed SSR is disabled here (ssr=False)
    # to avoid 405 errors, but no ssr/ssr_mode argument is actually passed to
    # launch() — confirm whether the installed Gradio version requires it.
    app, _, _ = demo.queue(status_update_rate=1).launch(
        share=False,
        server_name=server_name,
        server_port=server_port,
        prevent_thread_lock=True,
    )

    # --- Mount API endpoints on the underlying web app returned by launch() ---
    infer_forward_api(None, app, controller)
    update_diffusion_transformer_api(None, app, controller)
    update_edition_api(None, app, controller)

    # launch() did not block (prevent_thread_lock=True), so idle forever to
    # keep the server process alive.
    while True:
        time.sleep(5)