vivienfanghua commited on
Commit
c546b3f
·
verified ·
1 Parent(s): 6ba4fdc

Delete generate.py

Browse files
Files changed (1) hide show
  1. generate.py +0 -411
generate.py DELETED
@@ -1,411 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import argparse
3
- import logging
4
- import os
5
- import sys
6
- import warnings
7
- from datetime import datetime
8
-
9
- warnings.filterwarnings('ignore')
10
-
11
- import random
12
-
13
- import torch
14
- import torch.distributed as dist
15
- from PIL import Image
16
-
17
- import wan
18
- from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
19
- from wan.distributed.util import init_distributed_group
20
- from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
21
- from wan.utils.utils import cache_video, str2bool
22
-
23
# Per-task fallback inputs: the demo prompt (and, for i2v, a demo image)
# substituted by _validate_args when the user supplies none on the CLI.
EXAMPLE_PROMPT = {
    "t2v-A14B": {
        "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
    "i2v-A14B": {
        "prompt": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
        "image": "examples/i2v_input.JPG",
    },
    "ti2v-5B": {
        "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
}
39
-
40
-
41
def _validate_args(args):
    """Validate CLI arguments and fill unset fields with per-task defaults.

    Mutates *args* in place: the prompt/image fall back to EXAMPLE_PROMPT,
    the sampling hyper-parameters fall back to the task config, and a random
    base seed is drawn when a negative one was given.

    Raises:
        ValueError: if a required argument is missing or an unsupported
            task/size combination was requested.
    """
    # Basic check.  Raise ValueError instead of assert so validation is not
    # stripped when the script runs under `python -O`.
    if args.ckpt_dir is None:
        raise ValueError("Please specify the checkpoint directory.")
    if args.task not in WAN_CONFIGS or args.task not in EXAMPLE_PROMPT:
        raise ValueError(f"Unsupport task: {args.task}")

    if args.prompt is None:
        args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
    if args.image is None and "image" in EXAMPLE_PROMPT[args.task]:
        args.image = EXAMPLE_PROMPT[args.task]["image"]

    if args.task == "i2v-A14B" and args.image is None:
        raise ValueError("Please specify the image path for i2v.")

    cfg = WAN_CONFIGS[args.task]

    # Fall back to the per-task config defaults for any unspecified
    # sampling options.
    if args.sample_steps is None:
        args.sample_steps = cfg.sample_steps

    if args.sample_shift is None:
        args.sample_shift = cfg.sample_shift

    if args.sample_guide_scale is None:
        args.sample_guide_scale = cfg.sample_guide_scale

    if args.frame_num is None:
        args.frame_num = cfg.frame_num

    # A negative seed means "pick one at random".
    if args.base_seed < 0:
        args.base_seed = random.randint(0, sys.maxsize)

    # Size check
    if args.size not in SUPPORTED_SIZES[args.task]:
        raise ValueError(
            f"Unsupport size {args.size} for task {args.task}, "
            f"supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}")
75
-
76
-
77
def _parse_args():
    """Build the CLI argument parser, parse ``sys.argv``, and validate.

    Returns:
        argparse.Namespace: the parsed arguments, after _validate_args has
        filled in per-task defaults and checked consistency.
    """
    parser = argparse.ArgumentParser(
        # NOTE: fixed help-text grammar ("a image" -> "an image").
        description="Generate an image or video from a text prompt or image using Wan"
    )
    parser.add_argument(
        "--task",
        type=str,
        default="t2v-A14B",
        choices=list(WAN_CONFIGS.keys()),
        help="The task to run.")
    parser.add_argument(
        "--size",
        type=str,
        default="1280*720",
        choices=list(SIZE_CONFIGS.keys()),
        help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
    )
    parser.add_argument(
        "--frame_num",
        type=int,
        default=None,
        help="How many frames of video are generated. The number should be 4n+1"
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--offload_model",
        type=str2bool,
        default=None,
        help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
    )
    parser.add_argument(
        "--ulysses_size",
        type=int,
        default=1,
        help="The size of the ulysses parallelism in DiT.")
    parser.add_argument(
        "--t5_fsdp",
        action="store_true",
        default=False,
        help="Whether to use FSDP for T5.")
    parser.add_argument(
        "--t5_cpu",
        action="store_true",
        default=False,
        help="Whether to place T5 model on CPU.")
    parser.add_argument(
        "--dit_fsdp",
        action="store_true",
        default=False,
        help="Whether to use FSDP for DiT.")
    parser.add_argument(
        "--save_file",
        type=str,
        default=None,
        help="The file to save the generated video to.")
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="The prompt to generate the video from.")
    parser.add_argument(
        "--use_prompt_extend",
        action="store_true",
        default=False,
        help="Whether to use prompt extend.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")
    parser.add_argument(
        "--prompt_extend_target_lang",
        type=str,
        default="zh",
        choices=["zh", "en"],
        help="The target language of prompt extend.")
    parser.add_argument(
        "--base_seed",
        type=int,
        default=-1,
        help="The seed to use for generating the video.")
    parser.add_argument(
        "--image",
        type=str,
        default=None,
        help="The image to generate the video from.")
    parser.add_argument(
        "--sample_solver",
        type=str,
        default='unipc',
        choices=['unipc', 'dpm++'],
        help="The solver used to sample.")
    parser.add_argument(
        "--sample_steps", type=int, default=None, help="The sampling steps.")
    parser.add_argument(
        "--sample_shift",
        type=float,
        default=None,
        help="Sampling shift factor for flow matching schedulers.")
    parser.add_argument(
        "--sample_guide_scale",
        type=float,
        default=None,
        help="Classifier free guidance scale.")
    parser.add_argument(
        "--convert_model_dtype",
        action="store_true",
        default=False,
        # NOTE: fixed help-text typo ("paramerters" -> "parameters").
        help="Whether to convert model parameters dtype.")

    args = parser.parse_args()

    _validate_args(args)

    return args
202
-
203
-
204
- def _init_logging(rank):
205
- # logging
206
- if rank == 0:
207
- # set format
208
- logging.basicConfig(
209
- level=logging.INFO,
210
- format="[%(asctime)s] %(levelname)s: %(message)s",
211
- handlers=[logging.StreamHandler(stream=sys.stdout)])
212
- else:
213
- logging.basicConfig(level=logging.ERROR)
214
-
215
-
216
def generate(args):
    """Run the end-to-end video generation job described by *args*.

    Reads the torchrun environment (RANK / WORLD_SIZE / LOCAL_RANK),
    optionally initializes NCCL distributed state and prompt extension,
    builds the pipeline matching ``args.task`` (WanT2V / WanTI2V / WanI2V),
    generates the video, and saves it to disk on rank 0 only.
    """
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    device = local_rank
    _init_logging(rank)

    if args.offload_model is None:
        # Default: offload to CPU only in single-process runs; multi-GPU
        # runs keep the model resident.
        args.offload_model = False if world_size > 1 else True
        logging.info(
            f"offload_model is not specified, set to {args.offload_model}.")
    if world_size > 1:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size)
    else:
        # FSDP and sequence parallelism require a distributed launch.
        assert not (
            args.t5_fsdp or args.dit_fsdp
        ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments."
        assert not (
            args.ulysses_size > 1
        ), f"sequence parallel are not supported in non-distributed environments."

    if args.ulysses_size > 1:
        assert args.ulysses_size == world_size, f"The number of ulysses_size should be equal to the world size."
        init_distributed_group()

    if args.use_prompt_extend:
        # Pick a prompt-expander backend; only rank 0 actually calls it below.
        if args.prompt_extend_method == "dashscope":
            prompt_expander = DashScopePromptExpander(
                model_name=args.prompt_extend_model,
                task=args.task,
                is_vl=args.image is not None)
        elif args.prompt_extend_method == "local_qwen":
            prompt_expander = QwenPromptExpander(
                model_name=args.prompt_extend_model,
                task=args.task,
                is_vl=args.image is not None,
                device=rank)
        else:
            raise NotImplementedError(
                f"Unsupport prompt_extend_method: {args.prompt_extend_method}")

    cfg = WAN_CONFIGS[args.task]
    if args.ulysses_size > 1:
        # Attention heads must split evenly across the ulysses parallel group.
        assert cfg.num_heads % args.ulysses_size == 0, f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`."

    logging.info(f"Generation job args: {args}")
    logging.info(f"Generation model config: {cfg}")

    if dist.is_initialized():
        # Share rank 0's seed so every rank samples the same noise.
        base_seed = [args.base_seed] if rank == 0 else [None]
        dist.broadcast_object_list(base_seed, src=0)
        args.base_seed = base_seed[0]

    logging.info(f"Input prompt: {args.prompt}")
    img = None
    if args.image is not None:
        img = Image.open(args.image).convert("RGB")
        logging.info(f"Input image: {args.image}")

    # prompt extend: rank 0 runs the expander (falling back to the original
    # prompt on failure), then broadcasts the result to the other ranks.
    if args.use_prompt_extend:
        logging.info("Extending prompt ...")
        if rank == 0:
            prompt_output = prompt_expander(
                args.prompt,
                image=img,
                tar_lang=args.prompt_extend_target_lang,
                seed=args.base_seed)
            if prompt_output.status == False:
                logging.info(
                    f"Extending prompt failed: {prompt_output.message}")
                logging.info("Falling back to original prompt.")
                input_prompt = args.prompt
            else:
                input_prompt = prompt_output.prompt
            # Wrapped in a list so broadcast_object_list can mutate it in place.
            input_prompt = [input_prompt]
        else:
            input_prompt = [None]
        if dist.is_initialized():
            dist.broadcast_object_list(input_prompt, src=0)
        args.prompt = input_prompt[0]
        logging.info(f"Extended prompt: {args.prompt}")

    # Dispatch on the task name. Note "t2v" is not a substring of "ti2v-5B",
    # so the branch order is safe; anything else falls through to i2v.
    if "t2v" in args.task:
        logging.info("Creating WanT2V pipeline.")
        wan_t2v = wan.WanT2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_sp=(args.ulysses_size > 1),
            t5_cpu=args.t5_cpu,
            convert_model_dtype=args.convert_model_dtype,
        )

        logging.info(f"Generating video ...")
        video = wan_t2v.generate(
            args.prompt,
            size=SIZE_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)
    elif "ti2v" in args.task:
        logging.info("Creating WanTI2V pipeline.")
        wan_ti2v = wan.WanTI2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_sp=(args.ulysses_size > 1),
            t5_cpu=args.t5_cpu,
            convert_model_dtype=args.convert_model_dtype,
        )

        logging.info(f"Generating video ...")
        video = wan_ti2v.generate(
            args.prompt,
            img=img,  # may be None: ti2v works with or without an input image
            size=SIZE_CONFIGS[args.size],
            max_area=MAX_AREA_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)
    else:
        logging.info("Creating WanI2V pipeline.")
        wan_i2v = wan.WanI2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_sp=(args.ulysses_size > 1),
            t5_cpu=args.t5_cpu,
            convert_model_dtype=args.convert_model_dtype,
        )

        logging.info("Generating video ...")
        video = wan_i2v.generate(
            args.prompt,
            img,
            max_area=MAX_AREA_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)

    # Only rank 0 writes the output file.
    if rank == 0:
        if args.save_file is None:
            # Derive a filename from task/size/prompt/timestamp; '*' in the
            # size string is replaced on Windows where it is illegal in paths.
            formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
            formatted_prompt = args.prompt.replace(" ", "_").replace("/",
                                                                     "_")[:50]
            suffix = '.mp4'
            args.save_file = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{formatted_prompt}_{formatted_time}" + suffix

        logging.info(f"Saving generated video to {args.save_file}")
        cache_video(
            tensor=video[None],
            save_file=args.save_file,
            fps=cfg.sample_fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))
    # Free the (potentially large) video tensor on every rank.
    # NOTE(review): placement at function level (outside `if rank == 0`)
    # reconstructed from context — confirm against upstream.
    del video

    torch.cuda.synchronize()
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()

    logging.info("Finished.")
407
-
408
-
409
if __name__ == "__main__":
    # CLI entry point: parse and validate arguments, then run the job.
    generate(_parse_args())