yongqiang
initialize this repo
ba96580
import argparse
import os
import sys
import time
import gradio as gr
import ray
import torch
# Make the repository importable when this script is launched directly: add this
# file's directory and its two nearest ancestor directories to the front of
# sys.path (each one only if it is not already present).
current_file_path = os.path.abspath(__file__)
project_roots = [
    os.path.dirname(current_file_path),
    os.path.dirname(os.path.dirname(current_file_path)),
    os.path.dirname(os.path.dirname(os.path.dirname(current_file_path))),
]
for project_root in project_roots:
    # Plain `if` rather than a conditional expression evaluated only for its
    # side effect. Each insert goes to index 0, so later entries end up first.
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
from videox_fun.api.api_multi_nodes import (MultiNodesEngine,
multi_nodes_infer_forward_api)
from videox_fun.ui.controller import flow_scheduler_dict
from videox_fun.ui.wan2_2_ui import Wan2_2_Controller
def main():
    """Start the multi-node Wan2.2 xDiT HTTP service.

    Parses the command line and builds a ``MultiNodesEngine`` around
    ``Wan2_2_Controller`` with the requested parallelism and memory options.
    It then launches a non-blocking Gradio server whose UI is empty, and
    registers the inference API on that server's app through
    ``multi_nodes_infer_forward_api``.

    Afterwards the main thread sleeps forever so the process stays alive.
    The function only leaves via an exception (e.g. ``KeyboardInterrupt``).
    """
    parser = argparse.ArgumentParser(description='xDiT HTTP Service')
    parser.add_argument('--world_size', type=int, default=8, help='Number of parallel workers')
    parser.add_argument(
        '--gpu_memory_mode', type=str, default="model_full_load", help='''
GPU memory mode, which can be chosen in [model_full_load, model_full_load_and_qfloat8, model_cpu_offload, model_cpu_offload_and_qfloat8].
model_full_load means that the entire model will be moved to the GPU.
model_full_load_and_qfloat8 means that the entire model will be moved to the GPU,
and the transformer model has been quantized to float8, which can save more GPU memory.
model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory.
model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use,
and the transformer model has been quantized to float8, which can save more GPU memory.
'''
    )
    parser.add_argument('--ulysses_degree', type=int, default=4, help='Degree of Ulysses configuration')
    parser.add_argument('--ring_degree', type=int, default=2, help='Degree of Ring configuration')
    parser.add_argument(
        '--compile_dit', action='store_true', help='''
Enable compile dit.
Compile will give a speedup in fixed resolution and need a little GPU memory.
The compile_dit is not compatible with the fsdp_dit and sequential_cpu_offload.
'''
    )
    parser.add_argument('--fsdp_dit', action='store_true', help="Use DIT FSDP to save more GPU memory in multi gpus.")
    parser.add_argument('--fsdp_text_encoder', action='store_true', help="Use Text Encoder FSDP to save more GPU memory in multi gpus.")
    parser.add_argument(
        '--weight_dtype', type=str, default='bf16',
        help='Weight data type: bf16, fp16 or fp32 (unrecognized values fall back to fp32 with a warning)',
    )
    parser.add_argument('--server_name', type=str, default="0.0.0.0", help='Server IP address')
    parser.add_argument('--server_port', type=int, default=7860, help='Server Port')
    parser.add_argument('--config_path', type=str, default="config/wan2.2/wan_civitai_i2v.yaml", help='Path to config file')
    parser.add_argument('--model_name', type=str, default="models/Diffusion_Transformer/Wan2.2-I2V-A14B", help='Model path')
    parser.add_argument('--model_type', type=str, default="Inpaint", help='Model type (Inpaint/Control)')
    parser.add_argument('--savedir_sample', type=str, default=None, help='The save directory for samples')
    args = parser.parse_args()

    # Resolve --weight_dtype to a torch dtype. Unknown names still fall back to
    # float32 (the previous behavior), but the fallback is now reported.
    # Before, a typo silently loaded the weights in full precision, which takes
    # twice the memory of bf16/fp16.
    dtype_by_name = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }
    weight_dtype = dtype_by_name.get(args.weight_dtype)
    if weight_dtype is None:
        print(
            f"Warning: unrecognized --weight_dtype {args.weight_dtype!r} "
            f"(expected one of {sorted(dtype_by_name)}); falling back to fp32.",
            file=sys.stderr,
        )
        weight_dtype = torch.float32

    engine = MultiNodesEngine(
        world_size=args.world_size, Controller=Wan2_2_Controller,
        GPU_memory_mode=args.gpu_memory_mode, scheduler_dict=flow_scheduler_dict, model_name=args.model_name, model_type=args.model_type, config_path=args.config_path,
        ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree,
        fsdp_dit=args.fsdp_dit, fsdp_text_encoder=args.fsdp_text_encoder, compile_dit=args.compile_dit,
        weight_dtype=weight_dtype, savedir_sample=args.savedir_sample,
    )

    # The Gradio UI is deliberately empty (a single blank Markdown element).
    # Presumably it exists only to provide the web server that hosts the API
    # routes registered below.
    with gr.Blocks() as demo:
        gr.Markdown("")
    # prevent_thread_lock=True makes launch() return instead of blocking.
    # launch() returns (app, local_url, share_url); only the app is needed here.
    app, _, _ = demo.queue(status_update_rate=1).launch(
        server_name=args.server_name,
        server_port=args.server_port,
        prevent_thread_lock=True,
    )
    # Register the inference API on the server's app.
    multi_nodes_infer_forward_api(None, app, engine)

    # launch() did not block, so keep the main thread (and the process) alive.
    while True:
        time.sleep(5)
if __name__ == '__main__':
    # Start the service only when run as a script, never on import.
    main()