File size: 9,903 Bytes
85752bc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 | import torch
from PIL import Image
from einops import rearrange
import numpy as np
from typing import Optional, List, Tuple, Callable
import json
import math
from tqdm import tqdm
import os
import argparse
import torch.distributed as dist
import torch.nn.functional as F
from diffsynth.models import ModelManager
from diffsynth.models.utils import load_state_dict
import torch
from PIL import Image
from diffsynth import save_video
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
import yaml
import torch, os, json
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, wan_parser
from datasets.videodataset import MulltiShot_MultiView_Dataset
from PIL import Image, ImageOps
def test_video(args):
    """Generate sample videos from a saved checkpoint for a fixed set of dataset items.

    Builds a ``WanVideoPipeline`` from the text-encoder and VAE weights under
    ``args.local_model_path`` plus the checkpoint selected by
    ``args.infer_step`` / ``args.epoch_id``, then, for each hard-coded dataset
    index, runs the pipeline and writes:

    * the item's reference images to ``./output/<project>/ref_images/``,
    * the generated video to ``./output/<project>/video/<index>.mp4``,
    * the item's caption as one line of ``./output/<project>/output_log.txt``.

    Args:
        args: Namespace that must provide ``output_path``,
            ``visual_log_project_name``, ``infer_step``, ``epoch_id``,
            ``local_model_path``, ``dataset_base_path``, ``height``, ``width``,
            ``ref_num`` and ``num_frames``. The whole namespace is also
            forwarded to the pipeline call.
    """
    checkpoint_path = os.path.join(
        args.output_path,
        args.visual_log_project_name,
        f"checkpoint-step-{args.infer_step}-epoch-{args.epoch_id}",
        "weights.safetensors",
    )
    output_path = os.path.join("./output", args.visual_log_project_name)
    ref_image_dir = os.path.join(output_path, "ref_images")
    video_dir = os.path.join(output_path, "video")
    os.makedirs(ref_image_dir, exist_ok=True)
    os.makedirs(video_dir, exist_ok=True)
    print(checkpoint_path)

    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device="cuda",
        model_configs=[
            ModelConfig(path=os.path.join(args.local_model_path, "Wan2.2-TI2V-5B/models_t5_umt5-xxl-enc-bf16.pth"), offload_device="cuda"),
            ModelConfig(path=os.path.join(args.local_model_path, "Wan2.2-TI2V-5B/Wan2.2_VAE.pth"), offload_device="cuda"),
            ModelConfig(path=checkpoint_path, offload_device="cuda"),
        ],
        redirect_common_files=False,
    )
    pipe.enable_vram_management()

    dataset = MulltiShot_MultiView_Dataset(
        dataset_base_path=args.dataset_base_path,
        resolution=(args.height, args.width),
        ref_num=args.ref_num,
        training=False,
    )

    # The same negative prompt is used for every sample, so build it once.
    negative_prompt = ["色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"]

    log_file_name = "output_log.txt"
    # Dataset indices to render; extend this list to evaluate more samples.
    v_indexs = [0, 5, 15, 20]

    # Captions may contain non-ASCII text, so the log encoding is explicit.
    with open(os.path.join(output_path, log_file_name), "w", encoding="utf-8") as f:
        for v_index in v_indexs:
            metadata = dataset[v_index]
            # Prompt and reference images are wrapped in lists to form a
            # batch of size 1.
            video, _ = pipe(
                args=args,
                prompt=[metadata["single_caption"]],
                ref_images=[metadata["ref_images"]],
                negative_prompt=negative_prompt,
                seed=42, tiled=True,
                height=args.height, width=args.width,
                num_frames=args.num_frames,
                cfg_scale_face=5.,
                num_ref_images=metadata["ref_num"],
            )
            for r_index, img in enumerate(metadata["ref_images"]):
                img.save(os.path.join(ref_image_dir, f"{v_index}-{r_index}.png"))
            save_video(video, os.path.join(video_dir, f"{v_index}.mp4"), fps=15, quality=10)
            f.write(f"{metadata['single_caption']}\n")
def specify_video(args):
    """Generate one video from hand-picked reference images and a fixed prompt.

    Builds the same ``WanVideoPipeline`` as the dataset-driven evaluation
    (text encoder, VAE and the checkpoint selected by ``args.infer_step`` /
    ``args.epoch_id``), letterboxes the reference images to
    ``args.height`` x ``args.width`` on a white canvas, runs the pipeline once
    and saves the result to ``./output/<project>/video/cl.mp4``.

    Args:
        args: Namespace that must provide ``output_path``,
            ``visual_log_project_name``, ``infer_step``, ``epoch_id``,
            ``local_model_path``, ``height``, ``width`` and ``num_frames``.
            The whole namespace is also forwarded to the pipeline call.
    """

    def process_ref_images(ref_images, height, width):
        """Resize each image to fit ``width`` x ``height`` and pad it with white.

        The aspect ratio is preserved; the resized image is centred and the
        remaining area is filled with white so every output is exactly
        ``width`` x ``height`` pixels.

        Args:
            ref_images: Iterable of PIL images.
            height: Target height in pixels.
            width: Target width in pixels.

        Returns:
            A list of new RGB PIL images of size ``(width, height)``.
        """
        ref_images_new = []
        target_ratio = width / height
        for ref_image in ref_images:
            ref_image = ref_image.convert("RGB")
            img_ratio = ref_image.width / ref_image.height
            if img_ratio > target_ratio:
                # Image is wider than the target: width is the limiting side.
                new_width = width
                new_height = int(new_width / img_ratio)
            else:
                # Image is taller than (or matches) the target: height limits.
                new_height = height
                new_width = int(new_height * img_ratio)
            # Truncation can yield 0 for extreme aspect ratios, which resize
            # rejects; keep at least one pixel on each side.
            new_width = max(1, new_width)
            new_height = max(1, new_height)
            ref_image = ref_image.resize((new_width, new_height), Image.Resampling.LANCZOS)

            # Centre the resized image; any odd leftover pixel goes to the
            # right/bottom border.
            delta_w = width - ref_image.size[0]
            delta_h = height - ref_image.size[1]
            padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
            ref_images_new.append(ImageOps.expand(ref_image, padding, fill=(255, 255, 255)))
        return ref_images_new

    checkpoint_path = os.path.join(
        args.output_path,
        args.visual_log_project_name,
        f"checkpoint-step-{args.infer_step}-epoch-{args.epoch_id}",
        "weights.safetensors",
    )
    output_path = os.path.join("./output", args.visual_log_project_name)
    os.makedirs(f"{output_path}/ref_images", exist_ok=True)
    os.makedirs(f"{output_path}/video", exist_ok=True)
    print(checkpoint_path)

    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device="cuda",
        model_configs=[
            ModelConfig(path=os.path.join(args.local_model_path, "Wan2.2-TI2V-5B/models_t5_umt5-xxl-enc-bf16.pth"), offload_device="cuda"),
            ModelConfig(path=os.path.join(args.local_model_path, "Wan2.2-TI2V-5B/Wan2.2_VAE.pth"), offload_device="cuda"),
            ModelConfig(path=checkpoint_path, offload_device="cuda"),
        ],
        redirect_common_files=False,
    )
    pipe.enable_vram_management()

    # Reference images to condition on. Further views are kept here, disabled.
    ref_image_paths = [
        "/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/datasets/cl_0.png",
        # "/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/datasets/cl_2.png",
        # "/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/datasets/cl_3.png",
    ]
    loaded_images = []
    for image_path in ref_image_paths:
        # convert() returns an independent in-memory copy, so the underlying
        # file can be closed as soon as the block exits.
        with Image.open(image_path) as source_image:
            loaded_images.append(source_image.convert("RGB"))
    ref_images = process_ref_images(loaded_images, args.height, args.width)

    # Prompt and reference images are wrapped in lists to form a batch of
    # size 1.
    video, _ = pipe(
        args=args,
        prompt=["An elderly man with short gray hair and glasses stands in a softly lit indoor hallway. The shot begins with a frontal view of his face, his expression calm and attentive as he looks straight ahead. Then, he turns his head to his right, responding to someone standing beside him. His gaze shifts fully toward the other person as his expression becomes more engaged. The movement continues until he reaches a complete side profile, fully turning his face toward the person he is interacting with. Smooth and natural head rotation, warm indoor lighting."],
        ref_images=[ref_images],
        negative_prompt=["色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"],
        seed=42, tiled=True,
        height=args.height, width=args.width,
        num_frames=args.num_frames,
        cfg_scale_face=5.,
        num_ref_images=len(ref_images),
    )
    save_video(video, f"{output_path}/video/cl.mp4", fps=15, quality=10)
if __name__ == '__main__':
    # Script entry point: parse CLI arguments, overlay the settings stored in
    # the training YAML onto the namespace, then run inference.
    parser = wan_parser()
    # Only parse_known_args() is used: a preceding parse_args() would abort on
    # any unrecognised argument before they could be reported here.
    args, unknown = parser.parse_known_args()
    print("❗ Unknown arguments:", unknown)

    # NOTE (from the original author): if diffsynth was installed with
    # `pip install -e .`, reinstall it after modifying its sources.

    # Load the training configuration; safe_load avoids constructing
    # arbitrary Python objects from the YAML file.
    with open(args.train_yaml, "r", encoding="utf-8") as f:
        conf_info = yaml.safe_load(f)
    print(conf_info)

    # (attribute on args, YAML section, key inside that section), applied in
    # order. A missing section or key raises KeyError.
    yaml_overrides = (
        ("dataset_base_path", "dataset_args", "base_path"),
        ("max_checkpoints_to_keep", "train_args", "max_checkpoints_to_keep"),
        ("resume_from_checkpoint", "train_args", "resume_from_checkpoint"),
        ("visual_log_project_name", "train_args", "visual_log_project_name"),
        ("seed", "train_args", "seed"),
        ("output_path", "train_args", "output_path"),
        ("save_steps", "train_args", "save_steps"),
        ("save_epoches", "train_args", "save_epoches"),
        ("batch_size", "train_args", "batch_size"),
        ("local_model_path", "train_args", "local_model_path"),
        ("height", "dataset_args", "height"),
        ("width", "dataset_args", "width"),
        ("num_frames", "dataset_args", "num_frames"),
        ("ref_num", "dataset_args", "ref_num"),
        ("infer_step", "infer_args", "infer_step"),
        ("epoch_id", "infer_args", "epoch_id"),
        ("split_rope", "train_args", "split_rope"),
        ("split1", "train_args", "split1"),
        ("split2", "train_args", "split2"),
        ("split3", "train_args", "split3"),
    )
    for attr_name, section, key in yaml_overrides:
        setattr(args, attr_name, conf_info[section][key])

    test_video(args)
    # Alternative entry point for hand-picked reference images:
    # specify_video(args)