File size: 9,903 Bytes
85752bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import torch
from PIL import Image
from einops import rearrange
import numpy as np
from typing import Optional, List, Tuple, Callable
import json
import math

from tqdm import tqdm
import os
import argparse
import torch.distributed as dist
import torch.nn.functional as F
from diffsynth.models import ModelManager
from diffsynth.models.utils import load_state_dict
import torch
from PIL import Image
from diffsynth import save_video
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
import yaml
import torch, os, json
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, wan_parser
from datasets.videodataset import MulltiShot_MultiView_Dataset
from PIL import Image, ImageOps


def test_video(args, indices=None):
    """Render sample videos for selected dataset items with a trained checkpoint.

    Builds a ``WanVideoPipeline`` from two sources:

    * the Wan2.2-TI2V-5B T5 text encoder and VAE under ``args.local_model_path``;
    * the trained weights at
      ``<output_path>/<visual_log_project_name>/checkpoint-step-<infer_step>-epoch-<epoch_id>/weights.safetensors``.

    For every selected dataset index it generates a video from the item's
    caption and reference images. It then writes the following under
    ``./output/<visual_log_project_name>/``:

    * ``video/<index>.mp4`` -- the generated video (15 fps, quality 10);
    * ``ref_images/<index>-<r>.png`` -- each reference image that was used;
    * ``output_log.txt`` -- one line per item containing its caption.

    Args:
        args: Parsed namespace. Read for output_path, visual_log_project_name,
            infer_step, epoch_id, local_model_path, dataset_base_path, height,
            width, ref_num and num_frames. It is also forwarded to the
            pipeline call as ``args=``.
        indices: Dataset indices to render. Defaults to ``[0, 5, 15, 20]``.
    """
    if indices is None:
        indices = [0, 5, 15, 20]

    checkpoint_path = os.path.join(args.output_path, args.visual_log_project_name, f"checkpoint-step-{args.infer_step}-epoch-{args.epoch_id}", "weights.safetensors")
    output_path = os.path.join("./output", args.visual_log_project_name)
    os.makedirs(f"{output_path}/ref_images", exist_ok=True)
    os.makedirs(f"{output_path}/video", exist_ok=True)
    print(checkpoint_path)
    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device="cuda",
        model_configs=[
            ModelConfig(path=os.path.join(args.local_model_path, "Wan2.2-TI2V-5B/models_t5_umt5-xxl-enc-bf16.pth"), offload_device="cuda"),
            ModelConfig(path=os.path.join(args.local_model_path, "Wan2.2-TI2V-5B/Wan2.2_VAE.pth"), offload_device="cuda"),
            ModelConfig(path=checkpoint_path, offload_device="cuda"),
        ],
        redirect_common_files=False,
    )
    pipe.enable_vram_management()

    dataset = MulltiShot_MultiView_Dataset(
        dataset_base_path=args.dataset_base_path,
        resolution=(args.height, args.width),
        ref_num=args.ref_num,
        training=False,
    )

    log_file_name = "output_log.txt"
    # Captions may contain non-ASCII (e.g. Chinese) text, so write the log as
    # UTF-8 rather than relying on the platform default encoding.
    with open(os.path.join(output_path, log_file_name), "w", encoding="utf-8") as f:
        for v_index in indices:
            metadata = dataset[v_index]
            video, _ = pipe(
                args=args,
                # The pipeline takes batched inputs: wrap each in a batch-of-1 list.
                prompt=[metadata["single_caption"]],
                ref_images=[metadata["ref_images"]],
                negative_prompt=["色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"],
                seed=42, tiled=True,
                height=args.height, width=args.width,
                num_frames=args.num_frames,
                cfg_scale_face=5.,
                num_ref_images=metadata["ref_num"],
            )
            for r_index, img in enumerate(metadata["ref_images"]):
                img.save(f"{output_path}/ref_images/{v_index}-{r_index}.png")

            save_video(video, f"{output_path}/video/{v_index}.mp4", fps=15, quality=10)

            f.write(f"{metadata['single_caption']}\n")
    
def specify_video(args, ref_image_paths=None, prompt=None, output_name="cl.mp4"):
    """Generate one video from hand-picked reference images and a text prompt.

    The reference images are letterboxed to ``(args.width, args.height)`` on a
    white background. The pipeline uses the same model setup as the
    dataset-driven sampler: the Wan2.2-TI2V-5B text encoder and VAE under
    ``args.local_model_path``, plus the trained checkpoint selected by
    ``args.infer_step`` / ``args.epoch_id``. The result is saved to
    ``./output/<visual_log_project_name>/video/<output_name>``.

    Args:
        args: Parsed namespace. Read for output_path, visual_log_project_name,
            infer_step, epoch_id, local_model_path, height, width and
            num_frames. It is also forwarded to the pipeline call as ``args=``.
        ref_image_paths: Paths of the reference images. Defaults to the single
            ``cl_0.png`` image used by the original experiment.
        prompt: Text prompt. Defaults to the original elderly-man head-turn
            prompt.
        output_name: File name of the saved video inside ``video/``.
    """
    def process_ref_images(ref_images, height, width):
        """Letterbox each image to ``width x height`` on white padding.

        Each image is converted to RGB and resized with LANCZOS so that it
        fits inside the target while keeping its aspect ratio. It is then
        centred and padded with white to exactly the target size.
        """
        target_ratio = width / height
        ref_images_new = []
        for ref_image in ref_images:
            ref_image = ref_image.convert("RGB")
            img_ratio = ref_image.width / ref_image.height
            if img_ratio > target_ratio:  # wider than target: width-limited
                new_width = width
                new_height = int(new_width / img_ratio)
            else:  # taller than (or same shape as) target: height-limited
                new_height = height
                new_width = int(new_height * img_ratio)
            # Extreme aspect ratios can truncate one side to 0, which
            # Image.resize rejects; always keep at least one pixel.
            new_width = max(1, new_width)
            new_height = max(1, new_height)
            ref_image = ref_image.resize((new_width, new_height), Image.Resampling.LANCZOS)

            # Centre the resized image; any odd leftover pixel goes right/bottom.
            delta_w = width - ref_image.size[0]
            delta_h = height - ref_image.size[1]
            padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
            ref_images_new.append(ImageOps.expand(ref_image, padding, fill=(255, 255, 255)))
        return ref_images_new

    if ref_image_paths is None:
        ref_image_paths = [
            "/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/datasets/cl_0.png",
        ]
    if prompt is None:
        prompt = "An elderly man with short gray hair and glasses stands in a softly lit indoor hallway. The shot begins with a frontal view of his face, his expression calm and attentive as he looks straight ahead. Then, he turns his head to his right, responding to someone standing beside him. His gaze shifts fully toward the other person as his expression becomes more engaged. The movement continues until he reaches a complete side profile, fully turning his face toward the person he is interacting with. Smooth and natural head rotation, warm indoor lighting."

    checkpoint_path = os.path.join(args.output_path, args.visual_log_project_name, f"checkpoint-step-{args.infer_step}-epoch-{args.epoch_id}", "weights.safetensors")
    output_path = os.path.join("./output", args.visual_log_project_name)
    os.makedirs(f"{output_path}/ref_images", exist_ok=True)
    os.makedirs(f"{output_path}/video", exist_ok=True)
    print(checkpoint_path)

    # Load the reference images before building the (large) pipeline so that
    # a bad path fails fast. The context manager closes each file handle;
    # convert() returns an in-memory copy that outlives it.
    loaded_images = []
    for path in ref_image_paths:
        with Image.open(path) as im:
            loaded_images.append(im.convert("RGB"))
    ref_images = process_ref_images(loaded_images, args.height, args.width)

    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device="cuda",
        model_configs=[
            ModelConfig(path=os.path.join(args.local_model_path, "Wan2.2-TI2V-5B/models_t5_umt5-xxl-enc-bf16.pth"), offload_device="cuda"),
            ModelConfig(path=os.path.join(args.local_model_path, "Wan2.2-TI2V-5B/Wan2.2_VAE.pth"), offload_device="cuda"),
            ModelConfig(path=checkpoint_path, offload_device="cuda"),
        ],
        redirect_common_files=False,
    )
    pipe.enable_vram_management()

    video, _ = pipe(
        args=args,
        # The pipeline takes batched inputs: wrap each in a batch-of-1 list.
        prompt=[prompt],
        ref_images=[ref_images],
        negative_prompt=["色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"],
        seed=42, tiled=True,
        height=args.height, width=args.width,
        num_frames=args.num_frames,
        cfg_scale_face=5.,
        num_ref_images=len(ref_images),
    )
    save_video(video, f"{output_path}/video/{output_name}", fps=15, quality=10)
            

if __name__ == '__main__':
    # Script entry point: load the run configuration and render sample videos
    # (originally described as a long-video multi-shot continuous generation
    # script).
    #
    # Arguments come from diffsynth's wan_parser(). Use parse_known_args()
    # only: calling parse_args() first would make argparse exit on any
    # unrecognised argument before it could be reported below.
    parser = wan_parser()
    args, unknown = parser.parse_known_args()
    print("❗ Unknown arguments:", unknown)
    # NOTE (from the original author): if diffsynth was installed with
    # `pip install -e .`, reinstall it after modifying its sources.

    # Copy run settings from the YAML file at args.train_yaml onto args.
    with open(args.train_yaml, "r", encoding="utf-8") as f:
        conf_info = yaml.safe_load(f)  # safe_load: never constructs arbitrary Python objects
    print(conf_info)

    # Each entry is (attribute set on args, YAML section, key within that
    # section). A missing section or key raises KeyError, as before.
    config_fields = (
        ("dataset_base_path", "dataset_args", "base_path"),
        ("max_checkpoints_to_keep", "train_args", "max_checkpoints_to_keep"),
        ("resume_from_checkpoint", "train_args", "resume_from_checkpoint"),
        ("visual_log_project_name", "train_args", "visual_log_project_name"),
        ("seed", "train_args", "seed"),
        ("output_path", "train_args", "output_path"),
        ("save_steps", "train_args", "save_steps"),
        ("save_epoches", "train_args", "save_epoches"),
        ("batch_size", "train_args", "batch_size"),
        ("local_model_path", "train_args", "local_model_path"),
        ("height", "dataset_args", "height"),
        ("width", "dataset_args", "width"),
        ("num_frames", "dataset_args", "num_frames"),
        ("ref_num", "dataset_args", "ref_num"),
        ("infer_step", "infer_args", "infer_step"),
        ("epoch_id", "infer_args", "epoch_id"),
        ("split_rope", "train_args", "split_rope"),
        ("split1", "train_args", "split1"),
        ("split2", "train_args", "split2"),
        ("split3", "train_args", "split3"),
    )
    for attr, section, key in config_fields:
        setattr(args, attr, conf_info[section][key])

    test_video(args)
    # specify_video(args)  # use instead to render the hand-picked reference example