# File size: 20,526 Bytes
# ba96580
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
import os
import sys

import numpy as np
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from omegaconf import OmegaConf
from PIL import Image

# Make the project importable when this script is run directly: prepend this
# file's directory, its parent and its grandparent to sys.path (skipping any
# already present). The grandparent is inserted last, so it ends up first.
current_file_path = os.path.abspath(__file__)
project_roots = []
_ancestor = current_file_path
for _ in range(3):
    _ancestor = os.path.dirname(_ancestor)
    project_roots.append(_ancestor)

for project_root in project_roots:
    if project_root not in sys.path:
        sys.path.insert(0, project_root)

from videox_fun.dist import set_multi_gpus_devices, shard_model
from videox_fun.models import (AutoencoderKL, AutoTokenizer,
                               Qwen3ForCausalLM, ZImageControlTransformer2DModel)
from videox_fun.models.cache_utils import get_teacache_coefficients
from videox_fun.pipeline import ZImageControlPipeline
from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler
from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8,
                                               convert_weight_dtype_wrapper)
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
from videox_fun.utils.utils import (filter_kwargs, get_image_to_video_latent, get_image_latent, get_image,
                                    get_video_to_video_latent,
                                    save_videos_grid)

from loguru import logger
import onnx
import subprocess


def run_onnxslim(input_file="vae.onnx", output_file="vae_slim.onnx", *, executable="onnxslim"):
    """Simplify an ONNX model by running the ``onnxslim`` command-line tool.

    Runs ``<executable> <input_file> <output_file>``. The tool's stdout is
    echoed line by line as it is produced; stderr is drained concurrently on a
    helper thread and printed only if the command fails.

    Args:
        input_file: Path of the ONNX model to simplify.
        output_file: Path the simplified model is written to.
        executable: Command to invoke; defaults to ``onnxslim`` looked up on PATH.

    Returns:
        True if the command exited with status 0. False otherwise, including
        when the executable cannot be found or launching it raises.
    """
    import threading

    cmd = [executable, input_file, output_file]
    print(f"执行命令: {' '.join(cmd)}")

    try:
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=1,
        ) as process:
            # Drain stderr in the background. Reading only stdout while the
            # child fills the stderr pipe buffer would block both processes.
            stderr_parts = []
            stderr_reader = threading.Thread(
                target=lambda: stderr_parts.append(process.stderr.read()),
                daemon=True,
            )
            stderr_reader.start()

            # Echo stdout in real time.
            for line in process.stdout:
                print(line, end='')

            returncode = process.wait()
            stderr_reader.join()
    except FileNotFoundError:
        print("错误: 未找到 onnxslim 命令, 请确保已安装 onnxslim")
        # The `onnxslim` CLI ships in the `onnxslim` package (onnx-simplifier provides `onnxsim`).
        print("安装方法: pip install onnxslim")
        return False
    except Exception as e:
        print(f"执行命令时发生错误: {e}")
        return False

    if returncode != 0:
        print(f"命令执行失败, 错误信息:\n{''.join(stderr_parts)}")
        return False
    print("ONNX模型压缩完成!")
    return True


# ---------------------------------------------------------------------------
# GPU memory mode, one of:
#   "model_full_load"               - move the whole pipeline to the GPU.
#   "model_full_load_and_qfloat8"   - move the whole pipeline to the GPU and
#                                     quantize the transformer to float8,
#                                     saving more GPU memory.
#   "model_cpu_offload"             - move each model back to the CPU after
#                                     use, saving some GPU memory.
#   "model_cpu_offload_and_qfloat8" - CPU offload as above, plus a float8
#                                     transformer for further savings.
#   "sequential_cpu_offload"        - offload layer by layer after use;
#                                     slower, but saves a large amount of
#                                     GPU memory.
# ---------------------------------------------------------------------------
GPU_memory_mode = "model_cpu_offload"

# Multi-GPU settings: ulysses_degree * ring_degree must equal the number of
# GPUs in use (e.g. 8 GPUs -> ulysses_degree = 2, ring_degree = 4; a single
# GPU -> 1 and 1).
ulysses_degree = 1
ring_degree = 1
# FSDP sharding to save GPU memory on multi-GPU runs.
# NOTE(review): the FSDP code path in this script is commented out, so these
# flags currently have no effect.
fsdp_dit = False
fsdp_text_encoder = False
# torch.compile speeds up fixed-resolution runs at a small GPU-memory cost.
# Not compatible with fsdp_dit or sequential_cpu_offload.
# NOTE(review): the compile code path is commented out as well.
compile_dit = False

# Model config file and base model directory.
config_path = "config/z_image/z_image_control.yaml"
model_name = "models/Diffusion_Transformer/Z-Image-Turbo/"

# Sampler: "Flow", "Flow_Unipc" or "Flow_DPM++".
sampler_name = "Flow"

# Optional checkpoints applied on top of the base model (None = skip).
transformer_path = "models/Personalized_Model/Z-Image-Turbo-Fun-Controlnet-Union.safetensors"
vae_path = None
# NOTE(review): merge_lora is commented out below, so lora_path/lora_weight are unused.
lora_path = None

# Output resolution as [height, width].
sample_size = [1728, 992]

# Some GPUs (e.g. V100, 2080 Ti) do not support torch.bfloat16;
# switch to torch.float16 on those cards.
weight_dtype = torch.bfloat16
control_image = "asset/pose.jpg"
control_context_scale = 0.75

# A longer negative prompt can improve stability, e.g.
#   "模糊, 突变, 变形, 失真, 画面暗, 文本字幕, 画面固定, 连环画, 漫画, 线稿, 没有主体."
#   (blurry, mutation, deformation, distortion, dark image, subtitles,
#    static frame, comic strip, manga, line art, no subject).
# Adding words such as "安静, 固定" (quiet, static) to the negative prompt
# can increase dynamism.
prompt = "一位年轻女子站在阳光明媚的海岸线上, 白裙在轻拂的海风中微微飘动.她拥有一头鲜艳的紫色长发, 在风中轻盈舞动, 发间系着一个精致的黑色蝴蝶结, 与身后柔和的蔚蓝天空形成鲜明对比.她面容清秀, 眉目精致, 透着一股甜美的青春气息;神情柔和, 略带羞涩, 目光静静地凝望着远方的地平线, 双手自然交叠于身前, 仿佛沉浸在思绪之中.在她身后, 是辽阔无垠、波光粼粼的大海, 阳光洒在海面上, 映出温暖的金色光晕."
# Alternative prompt:
# prompt              = "一位身穿白色仙袍的仙人女子手持青色仙剑, 她抬头望着天空疾驰而来的雷霆, 面容严肃, 一袭紫色长发在风中飘扬, 仿佛与天地间的风雷共舞.她站立在一片古老的山巅, 背后是连绵起伏的群山和翻滚的乌云, 整个场景充满了神秘而壮丽的气息.天空中闪电划过, 照亮了她坚定的眼神和手中的仙剑, 彷佛预示着一场即将到来的大战.她的姿态优雅而坚定, 彷佛是天地间的守护者, 准备迎接任何挑战."
negative_prompt = " "
guidance_scale = 0.0
seed = 43
num_inference_steps = 9
lora_weight = 0.55
save_path = "samples/z-image-t2i-control"

# Resolve the compute device (and multi-GPU groups) and load the YAML config.
device = set_multi_gpus_devices(ulysses_degree, ring_degree)
config = OmegaConf.load(config_path)

# Base control transformer, created directly in weight_dtype.
transformer = ZImageControlTransformer2DModel.from_pretrained(
    model_name,
    subfolder="transformer",
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
    transformer_additional_kwargs=OmegaConf.to_container(
        config['transformer_additional_kwargs']
    ),
).to(weight_dtype)

# Optionally overlay weights from an extra checkpoint (e.g. the ControlNet-Union
# file configured above). The load is non-strict; only the mismatch counts are reported.
if transformer_path is not None:
    print(f"From checkpoint: {transformer_path}")
    if not transformer_path.endswith("safetensors"):
        state_dict = torch.load(transformer_path, map_location="cpu")
    else:
        from safetensors.torch import load_file, safe_open
        state_dict = load_file(transformer_path)
    # Unwrap checkpoints that nest the weights under a "state_dict" key.
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    missing_keys, unexpected_keys = transformer.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(missing_keys)}, unexpected keys: {len(unexpected_keys)}")


# Debug switch: flip to True to export the control transformer to ONNX and
# exit instead of running inference.
if False:

    class DummyControlTransformerWrapper(torch.nn.Module):
        """Give the control transformer a fixed positional signature for ONNX export.

        ``control_context_scale`` is stored as a plain Python float rather than
        passed as a graph input, so presumably the tracer bakes it into the
        exported graph as a constant.
        """

        def __init__(self, model, control_context_scale):
            super().__init__()
            self.model = model
            self.control_context_scale = control_context_scale

        def forward(
            self,
            latent_model_input,
            timestep,
            prompt_embeds,
            control_context,
        ):
            model_out = self.model(
                latent_model_input,
                timestep,
                prompt_embeds,
                control_context=control_context,
                control_context_scale=self.control_context_scale,
            )
            return model_out

    # Dummy tracing inputs (float32, CPU). Shapes mirror the pdb notes at the
    # end of this block, with the prompt padded to 512 tokens.
    dummy_input = {
        "latent_model_input": [torch.randn(16, 1, sample_size[0] // 8, sample_size[1] // 8, device="cpu", dtype=torch.float32)],
        "timestep": torch.tensor([0.], device="cpu", dtype=torch.float32),
        # TODO: support the maximum prompt length here. No dynamic axes are
        # declared, so the exported graph presumably fixes 512 tokens.
        "prompt_embeds": [torch.randn(512, 2560, device="cpu", dtype=torch.float32)],
        "control_context": torch.randn(1, 16, 1, sample_size[0] // 8, sample_size[1] // 8, device="cpu", dtype=torch.float32),
    }

    # Use the configured scale rather than a hard-coded copy of it.
    transformer_wrapper = DummyControlTransformerWrapper(transformer, control_context_scale)
    transformer_wrapper.eval()

    # Separate name so the `transformer_path` checkpoint setting is not clobbered.
    transformer_onnx_dir = "onnx-models/trans/"
    transformer_onnx_path = os.path.join(transformer_onnx_dir, "z_image_control_transformer.onnx")
    os.makedirs(transformer_onnx_dir, exist_ok=True)

    torch.onnx.export(
        transformer_wrapper.to(device="cpu", dtype=torch.float32),
        tuple(dummy_input.values()),
        transformer_onnx_path,
        opset_version=17,
        input_names=list(dummy_input.keys()),
        output_names=["model_out"],
        do_constant_folding=True,
        export_params=True,
        verbose=False
    )
    trans_onnx = onnx.load(transformer_onnx_path)

    # Re-save with every weight tensor in a single external-data file.
    simp_onnx_data = "onnx-models/z_image_control_transformer.onnx"
    onnx.save(
        trans_onnx,
        simp_onnx_data,
        save_as_external_data=True,
        all_tensors_to_one_file=True
    )

    logger.info("Transformer ONNX model exported, start to simplify.")
    # Equivalent to running `onnxslim <model>.onnx <model>_slim.onnx` in a shell.
    success = run_onnxslim(simp_onnx_data, simp_onnx_data.replace(".onnx", "_slim.onnx"))
    if success:
        logger.info("Transformer ONNX model exported successfully.")
    else:
        sys.exit(1)

    # Values recorded in a pdb session during a real pipeline run:
    #   latent_model_input_list[0].shape   -> torch.Size([16, 1, 216, 124])
    #   timestep_model_input               -> tensor([0.], device='cuda:0')
    #   prompt_embeds_model_input[0].shape -> torch.Size([165, 2560])
    #   control_context.shape              -> torch.Size([1, 16, 1, 216, 124])
    #   control_context_scale              -> 0.75
    #   model_out_list[0].shape            -> torch.Size([16, 1, 216, 124])
    sys.exit(0)


# Load the VAE from the base model directory, then cast it to weight_dtype.
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae")
vae = vae.to(weight_dtype)


# Debug switch: flip to True to export the VAE encoder and decoder to ONNX and
# exit instead of running inference.
if False:

    class DummyVAEEncoderWrapper(torch.nn.Module):
        """Encode an image batch and return the mode of the latent distribution."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, x):
            latent_dist = self.model.encode(x)[0].mode()
            return latent_dist

    class DummyVAEDecoderWrapper(torch.nn.Module):
        """Decode a latent batch back into an image tensor."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, latents):
            image = self.model.decode(latents, return_dict=False)[0]
            return image

    def export_vae_onnx(vae, sample_size):
        """Export the VAE encoder and decoder to ONNX in float32 and slim both.

        Exits the process with status 1 if either onnxslim run fails.
        """
        vae.eval()
        os.makedirs("./onnx-models", exist_ok=True)

        # --- VAE encoder: image (1, 3, H, W) -> latent ---
        vae_encoder_onnx_path = "./onnx-models/vae_encoder.onnx"
        dummy_input = torch.randn(1, 3, sample_size[0], sample_size[1], device="cpu", dtype=torch.float32)
        vae_encode_wrapper = DummyVAEEncoderWrapper(vae)
        vae_encode_wrapper.eval()
        torch.onnx.export(vae_encode_wrapper.to(torch.float32), dummy_input, vae_encoder_onnx_path, opset_version=17)
        onnx.checker.check_model(vae_encoder_onnx_path)
        logger.info("VAE-Encoder ONNX model exported, start to simplify.")
        # Equivalent to running `onnxslim vae_encoder.onnx vae_encoder_slim.onnx` in a shell.
        encoder_ok = run_onnxslim(vae_encoder_onnx_path, vae_encoder_onnx_path.replace(".onnx", "_slim.onnx"))

        # --- VAE decoder: latent (1, latent_channels, H/8, W/8) -> image ---
        vae_decoder_onnx_path = "./onnx-models/vae_decoder.onnx"
        dummy_latent = torch.randn(1, vae.config.latent_channels, sample_size[0] // 8, sample_size[1] // 8, device="cpu", dtype=torch.float32)
        vae_decode_wrapper = DummyVAEDecoderWrapper(vae)
        vae_decode_wrapper.eval()
        torch.onnx.export(vae_decode_wrapper.to(torch.float32), dummy_latent, input_names=["latent"], output_names=["image"], f=vae_decoder_onnx_path, opset_version=17)
        onnx.checker.check_model(vae_decoder_onnx_path)
        logger.info("VAE-Decoder ONNX model exported, start to simplify.")
        decoder_ok = run_onnxslim(vae_decoder_onnx_path, vae_decoder_onnx_path.replace(".onnx", "_slim.onnx"))

        # Both halves must simplify successfully.
        if encoder_ok and decoder_ok:
            logger.info("VAE ONNX model exported successfully.")
        else:
            sys.exit(1)

    export_vae_onnx(vae, sample_size)
    sys.exit(0)

# Optionally overlay VAE weights from a separate checkpoint. The load is
# non-strict; only the mismatch counts are reported.
if vae_path is not None:
    print(f"From checkpoint: {vae_path}")
    if not vae_path.endswith("safetensors"):
        state_dict = torch.load(vae_path, map_location="cpu")
    else:
        from safetensors.torch import load_file, safe_open
        state_dict = load_file(vae_path)
    # Some checkpoints nest the weights under a "state_dict" key.
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    missing_keys, unexpected_keys = vae.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(missing_keys)}, unexpected keys: {len(unexpected_keys)}")

# The tokenizer and the Qwen3 text encoder both come from the base model directory.
tokenizer = AutoTokenizer.from_pretrained(model_name, subfolder="tokenizer")
text_encoder = Qwen3ForCausalLM.from_pretrained(
    model_name,
    subfolder="text_encoder",
    torch_dtype=weight_dtype,
    low_cpu_mem_usage=True,
)

# Debug switch: flip to True to export the Qwen3 text encoder to ONNX.
# NOTE: unlike the transformer/VAE export blocks, this one does not exit
# afterwards, so the script continues on to inference.
if False:
    # Currently compiled with the llm_build flow (Qwen3 architecture).
    text_encoder.eval()
    text_encoder.config.use_cache = False
    text_encoder.config.output_attentions = False
    text_encoder.config.output_hidden_states = False

    # Wrapper giving the exported graph a plain (input_ids, attention_mask) signature.
    import torch.nn as nn
    class Qwen3TextEncoderExporter(nn.Module):
        """Return the second-to-last hidden state of the wrapped Qwen3 model."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids, attention_mask=None):
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True
            )
            # NOTE: this is hidden_states[-2], even though the ONNX output is
            # named "last_hidden_state" in the export call below.
            return outputs.hidden_states[-2]

    wrapped_text_encoder = Qwen3TextEncoderExporter(text_encoder)
    wrapped_text_encoder.eval()

    # Export the text encoder to ONNX with a fixed sequence length.
    max_sequence_length = 512
    text_encoder_onnx_path = "./onnx-models/text_encoder.onnx"
    # The target directory must exist before torch.onnx.export writes into it.
    os.makedirs("./onnx-models", exist_ok=True)

    # NOTE: the attention_mask fed to the ONNX model must have the same size as
    # input_ids: True for the first N real tokens, False for the padding. E.g.
    # for input_ids of shape (1, 512) with 20 real tokens:
    #     attention_mask = [True] * 20 + [False] * 492   # 512 elements in total
    # This lets the ONNX model handle the padding correctly and skip wasted work.
    input_ids = torch.randint(0, tokenizer.vocab_size, (1, max_sequence_length), device="cpu", dtype=torch.long)
    attention_mask = torch.ones((1, max_sequence_length), device="cpu", dtype=torch.long)

    # Sanity-check the wrapper with one forward pass before exporting.
    with torch.no_grad():
        test_output = wrapped_text_encoder(input_ids, attention_mask)
        print(f"测试输出形状: {test_output.shape}")

    torch.onnx.export(
        wrapped_text_encoder,
        (input_ids, attention_mask),
        text_encoder_onnx_path,
        opset_version=17,
        input_names=["input_ids", "attention_mask"],
        output_names=["last_hidden_state"],
        do_constant_folding=True,
        export_params=True,
        verbose=False
    )
    onnx.checker.check_model(text_encoder_onnx_path)

    logger.info("Text Encoder ONNX model exported, start to simplify.")
    # Equivalent to running `onnxslim text_encoder.onnx text_encoder_slim.onnx` in a shell.
    success = run_onnxslim(text_encoder_onnx_path, text_encoder_onnx_path.replace(".onnx", "_slim.onnx"))
    if success:
        logger.info("Text Encoder ONNX model exported successfully.")
    else:
        sys.exit(1)

# Map the user-facing sampler name to its scheduler class.
# (Previously a chained assignment bound `scheduler_dict` to the chosen class.)
scheduler_dict = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}
if sampler_name not in scheduler_dict:
    raise ValueError(
        f"Unknown sampler_name {sampler_name!r}; expected one of {sorted(scheduler_dict)}"
    )
Chosen_Scheduler = scheduler_dict[sampler_name]
scheduler = Chosen_Scheduler.from_pretrained(
    model_name,
    subfolder="scheduler"
)

# Assemble the control pipeline from the components loaded above.
pipeline = ZImageControlPipeline(
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    scheduler=scheduler,
)

# if ulysses_degree > 1 or ring_degree > 1:
#     from functools import partial
#     transformer.enable_multi_gpus_inference()
#     if fsdp_dit:
#         shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, module_to_wrapper=list(transformer.transformer_blocks) + list(transformer.single_transformer_blocks))
#         pipeline.transformer = shard_fn(pipeline.transformer)
#         print("Add FSDP DIT")
#     if fsdp_text_encoder:
#         shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, module_to_wrapper=text_encoder.language_model.layers, ignored_modules=[text_encoder.language_model.embed_tokens], transformer_layer_cls_to_wrap=["MistralDecoderLayer", "PixtralTransformer"])
#         text_encoder = shard_fn(text_encoder)
#         print("Add FSDP TEXT ENCODER")

# if compile_dit:
#     for i in range(len(pipeline.transformer.transformer_blocks)):
#         pipeline.transformer.transformer_blocks[i] = torch.compile(pipeline.transformer.transformer_blocks[i])
#     print("Add Compile")

# Place the pipeline according to GPU_memory_mode (the modes are described in
# the configuration section at the top of the script).
if GPU_memory_mode == "sequential_cpu_offload":
    pipeline.enable_sequential_cpu_offload(device=device)
else:
    if GPU_memory_mode in ("model_cpu_offload_and_qfloat8", "model_full_load_and_qfloat8"):
        # Convert transformer weights to float8 (the listed module names are
        # excluded), then wrap them, presumably so compute runs in weight_dtype.
        convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", "txt_in", "timestep"], device=device)
        convert_weight_dtype_wrapper(transformer, weight_dtype)

    if GPU_memory_mode in ("model_cpu_offload", "model_cpu_offload_and_qfloat8"):
        pipeline.enable_model_cpu_offload(device=device)
    else:
        # "model_full_load", "model_full_load_and_qfloat8" and any other value.
        pipeline.to(device=device)

# Seeded generator on the target device, for reproducible sampling.
generator = torch.Generator(device=device).manual_seed(seed)

# if lora_path is not None:
#     pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)

# Generate the image(s) conditioned on the control image, with autograd disabled.
with torch.no_grad():
    if control_image is not None:
        # Load the control image at the target resolution and keep index 0 of
        # axis 2 (presumably a frame axis), giving
        # torch.Size([1, 3, sample_size[0], sample_size[1]]).
        control_image = get_image_latent(control_image, sample_size=sample_size)[:, :, 0]

    pipeline_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=sample_size[0],
        width=sample_size[1],
        generator=generator,
        guidance_scale=guidance_scale,
        control_image=control_image,
        num_inference_steps=num_inference_steps,
        control_context_scale=control_context_scale,
    )
    sample = pipeline(**pipeline_kwargs).images

# if lora_path is not None:
#     pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)

def save_results(image=None, output_dir=None):
    """Save a generated image as the next zero-padded ``NNNNNNNN.png`` file.

    Args:
        image: Object with a ``save(path)`` method (e.g. a PIL image). Defaults
            to ``sample[0]``, the first image produced by the pipeline run.
        output_dir: Target directory. Defaults to the module-level ``save_path``
            and is created if missing.

    Returns:
        The path the image was written to.
    """
    if image is None:
        image = sample[0]
    if output_dir is None:
        output_dir = save_path
    os.makedirs(output_dir, exist_ok=True)

    # Number the file after the current directory size, but never overwrite an
    # existing file (deleted or foreign entries can make that count collide).
    index = len(os.listdir(output_dir)) + 1
    image_path = os.path.join(output_dir, str(index).zfill(8) + ".png")
    while os.path.exists(image_path):
        index += 1
        image_path = os.path.join(output_dir, str(index).zfill(8) + ".png")

    image.save(image_path)
    return image_path

# Multi-GPU variant (currently disabled): only rank 0 would write the output.
# if ulysses_degree * ring_degree > 1:
#     import torch.distributed as dist
#     if dist.get_rank() == 0:
#         save_results()
# else:
#     save_results()

# Write the first generated image into save_path.
save_results()