Delete telestylevideo_inference.py
Browse files- telestylevideo_inference.py +0 -207
telestylevideo_inference.py
DELETED
|
@@ -1,207 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
import os
|
| 3 |
-
import cv2
|
| 4 |
-
import json
|
| 5 |
-
import time
|
| 6 |
-
import torch
|
| 7 |
-
import numpy as np
|
| 8 |
-
import torch.nn.functional as F
|
| 9 |
-
|
| 10 |
-
from PIL import Image
|
| 11 |
-
from typing import List, Dict, Optional, Tuple
|
| 12 |
-
from einops import rearrange
|
| 13 |
-
from omegaconf import OmegaConf
|
| 14 |
-
from decord import VideoReader
|
| 15 |
-
from diffusers.utils import export_to_video
|
| 16 |
-
from diffusers.models import AutoencoderKLWan
|
| 17 |
-
from diffusers.schedulers import UniPCMultistepScheduler
|
| 18 |
-
from telestylevideo_transformer import WanTransformer3DModel
|
| 19 |
-
from telestylevideo_pipeline import WanPipeline
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def load_video(video_path: str, video_length: int) -> torch.Tensor:
    """Load a video (or a single still image) as a normalized float tensor.

    Args:
        video_path: Path to a video file, or to a single image
            (``.png`` / ``.jpg`` / ``.jpeg``).
        video_length: Maximum number of frames to read from a video.
            Ignored for a still image, which always yields one frame.

    Returns:
        Tensor of shape (1, frames, H, W, C) with values in [-1, 1].

    Raises:
        FileNotFoundError: If an image path cannot be read by OpenCV.
    """
    # Detect still images by extension suffix rather than substring search,
    # so a video named e.g. "my_png_clip.mp4" is not misclassified.
    if video_path.lower().endswith((".png", ".jpg", ".jpeg")):
        image = cv2.imread(video_path)
        if image is None:
            # cv2.imread fails silently by returning None; fail loudly instead
            # of crashing later in cvtColor with an opaque error.
            raise FileNotFoundError(f"Could not read image: {video_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Add batch and frame dimensions -> (1, 1, H, W, C).
        image = image[None, None]
        # Map uint8 [0, 255] to [-1, 1].
        return torch.from_numpy(image) / 127.5 - 1.0

    vr = VideoReader(video_path)
    frame_indices = list(range(min(len(vr), video_length)))
    frames = vr.get_batch(frame_indices).asnumpy()
    frames = torch.from_numpy(frames) / 127.5 - 1.0
    # Add batch dimension -> (1, F, H, W, C).
    return frames[None]
|
| 37 |
-
|
| 38 |
-
class VideoStyleInference:
    """Video style-transfer inference engine.

    Loads a Wan VAE, a style-conditioned Wan transformer and a UniPC
    scheduler, then restyles source videos to match a reference image.
    """

    def __init__(self, config: Dict):
        """Initialize the engine and load all model weights.

        Args:
            config: Dict with keys 'random_seed', 'video_length', 'height',
                'width', 'num_inference_steps', 'ckpt_t2v_path',
                'ckpt_dit_path', 'output_path', 'prompt_embeds_path'.
        """
        self.config = config
        # NOTE(review): device is hard-coded to the first CUDA GPU.
        # (Original used an f-string with no placeholders.)
        self.device = torch.device("cuda:0")
        self.random_seed = config['random_seed']
        self.video_length = config['video_length']
        self.H = config['height']
        self.W = config['width']
        self.num_inference_steps = config['num_inference_steps']
        self.vae_path = os.path.join(config['ckpt_t2v_path'], "vae")
        self.transformer_config_path = os.path.join(config['ckpt_t2v_path'], "transformer_config.json")
        self.scheduler_path = os.path.join(config['ckpt_t2v_path'], "scheduler")
        self.ckpt_path = config['ckpt_dit_path']
        self.output_path = config['output_path']
        self.prompt_embeds_path = config['prompt_embeds_path']

        # Load models and build the pipeline.
        self._load_models()

    def _load_models(self):
        """Load VAE/transformer/scheduler weights and assemble the pipeline."""
        # SECURITY NOTE: torch.load unpickles arbitrary objects — only load
        # trusted checkpoint files.
        state_dict = torch.load(self.ckpt_path, map_location="cpu")["transformer_state_dict"]
        # Strip the DistributedDataParallel "module." prefix. Keys without the
        # prefix pass through unchanged (the previous split("module.")[1]
        # raised IndexError on such keys).
        prefix = "module."
        transformer_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

        # Transformer architecture configuration.
        transformer_config = OmegaConf.to_container(
            OmegaConf.load(self.transformer_config_path)
        )

        # Instantiate the models in fp16 on the target device.
        self.vae = AutoencoderKLWan.from_pretrained(self.vae_path, torch_dtype=torch.float16).to(self.device)
        self.transformer = WanTransformer3DModel(**transformer_config)
        self.transformer.load_state_dict(transformer_state_dict)
        self.transformer = self.transformer.to(self.device).half()
        self.scheduler = UniPCMultistepScheduler.from_pretrained(self.scheduler_path)

        # Load the (fixed) prompt embeddings once here instead of re-reading
        # them from disk on every inference() call.
        self.prompt_embeds = torch.load(self.prompt_embeds_path).to(self.device).half()

        # Assemble the diffusion pipeline.
        self.pipe = WanPipeline(
            transformer=self.transformer,
            vae=self.vae,
            scheduler=self.scheduler
        )
        self.pipe.to(self.device)

    def inference(self, source_videos: torch.Tensor, first_images: torch.Tensor, video_path: str, step: int) -> torch.Tensor:
        """Run style-transfer inference on one source video.

        Args:
            source_videos: Source video tensor, shape (b, f, h, w, c),
                values assumed in [-1, 1].
            first_images: Style reference image tensor, shape (b, 1, h, w, c).
            video_path: Source video path (unused; kept for interface
                compatibility).
            step: Index of the current item (unused; kept for interface
                compatibility).

        Returns:
            The generated (restyled) video frames from the pipeline.
        """
        source_videos = source_videos.to(self.device).half()
        first_images = first_images.to(self.device).half()
        prompt_embeds_ = self.prompt_embeds

        print(f"Source videos shape: {source_videos.shape}, First images shape: {first_images.shape}")

        # VAE latent normalization statistics (16 latent channels).
        latents_mean = torch.tensor(self.vae.config.latents_mean)
        latents_mean = latents_mean.view(1, 16, 1, 1, 1).to(self.device, torch.float16)
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std)
        latents_std = latents_std.view(1, 16, 1, 1, 1).to(self.device, torch.float16)

        bsz = 1
        _, _, h, w, _ = source_videos.shape

        # Match the configured resolution's orientation to the input:
        # landscape input -> (H, W); portrait input -> swapped.
        if h < w:
            output_h, output_w = self.H, self.W
        else:
            output_h, output_w = self.W, self.H

        with torch.no_grad():
            # Resize source video frames to the output resolution.
            source_videos = rearrange(source_videos, "b f h w c -> (b f) c h w")
            source_videos = F.interpolate(source_videos, (output_h, output_w), mode="bilinear")
            source_videos = rearrange(source_videos, "(b f) c h w -> b c f h w", b=bsz)

            # Resize the style reference image the same way.
            first_images = rearrange(first_images, "b f h w c -> (b f) c h w")
            first_images = F.interpolate(first_images, (output_h, output_w), mode="bilinear")
            first_images = rearrange(first_images, "(b f) c h w -> b c f h w", b=bsz)

            # Encode to the VAE latent space and normalize.
            source_latents = self.vae.encode(source_videos).latent_dist.mode()
            source_latents = (source_latents - latents_mean) * latents_std

            first_latents = self.vae.encode(first_images).latent_dist.mode()
            first_latents = (first_latents - latents_mean) * latents_std

            # Negative (unconditional) style latents from an all-zero image.
            neg_first_latents = self.vae.encode(torch.zeros_like(first_images)).latent_dist.mode()
            neg_first_latents = (neg_first_latents - latents_mean) * latents_std

            video = self.pipe(
                source_latents=source_latents,
                first_latents=first_latents,
                neg_first_latents=neg_first_latents,
                num_frames=self.video_length,
                guidance_scale=3.0,
                height=output_h,
                width=output_w,
                prompt_embeds_=prompt_embeds_,
                num_inference_steps=self.num_inference_steps,
                generator=torch.Generator(device=self.device).manual_seed(self.random_seed),
            ).frames[0]

        return video
|
| 163 |
-
|
| 164 |
-
if __name__ == "__main__":
    # Inference configuration.
    config = {
        "random_seed": 42,
        "video_length": 129,
        "height": 720,
        "width": 1248,
        "num_inference_steps": 25,
        "ckpt_t2v_path": "./Wan2.1-T2V-1.3B-Diffusers",
        "ckpt_dit_path": "weights/dit.ckpt",
        "prompt_embeds_path": "weights/prompt_embeds.pth",
        "output_path": "./results",
    }

    # Build the inference engine (loads all model weights).
    inference_engine = VideoStyleInference(config)

    # (source video, style image) pairs; outputs are numbered by position.
    data_list = [
        {"video_path": "assets/example/2.mp4", "image_path": "assets/example/2-0.png"},
        {"video_path": "assets/example/2.mp4", "image_path": "assets/example/2-1.png"},
    ]

    for idx, item in enumerate(data_list):
        src_path = item['video_path']
        ref_path = item['image_path']

        # Load the source clip, capped at the configured frame count.
        src_video = load_video(src_path, config['video_length'])

        # Normalize the style reference to [-1, 1] with batch and frame
        # dimensions -> (1, 1, H, W, C).
        ref_image = np.array(Image.open(ref_path))
        ref_image = torch.from_numpy(ref_image) / 127.5 - 1.0
        ref_image = ref_image[None, None, :, :, :]

        with torch.no_grad():
            result = inference_engine.inference(src_video, ref_image, src_path, idx)

        os.makedirs(config['output_path'], exist_ok=True)
        export_to_video(result, f"{config['output_path']}/{idx}.mp4")
|
| 207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|