Xiaoyan-Yang committed on
Commit
234acd4
·
verified ·
1 Parent(s): 0bf66f3

Delete telestylevideo_inference.py

Browse files
Files changed (1) hide show
  1. telestylevideo_inference.py +0 -207
telestylevideo_inference.py DELETED
@@ -1,207 +0,0 @@
1
-
2
- import os
3
- import cv2
4
- import json
5
- import time
6
- import torch
7
- import numpy as np
8
- import torch.nn.functional as F
9
-
10
- from PIL import Image
11
- from typing import List, Dict, Optional, Tuple
12
- from einops import rearrange
13
- from omegaconf import OmegaConf
14
- from decord import VideoReader
15
- from diffusers.utils import export_to_video
16
- from diffusers.models import AutoencoderKLWan
17
- from diffusers.schedulers import UniPCMultistepScheduler
18
- from telestylevideo_transformer import WanTransformer3DModel
19
- from telestylevideo_pipeline import WanPipeline
20
-
21
-
22
def load_video(video_path: str, video_length: int) -> torch.Tensor:
    """Load a video (or a single image) and normalize pixel values to [-1, 1].

    Args:
        video_path: Path to a video file, or to a single image
            (.png / .jpg / .jpeg), which is treated as a one-frame video.
        video_length: Maximum number of frames to read from a video.

    Returns:
        A float tensor of shape (1, F, H, W, C) with values in [-1, 1];
        F == 1 for a single image.

    Raises:
        FileNotFoundError: If an image path cannot be read by OpenCV.
    """
    # Decide by file extension rather than substring matching, so a path
    # such as "jpg_clips/a.mp4" is not mistaken for an image.
    ext = os.path.splitext(video_path)[1].lower()
    if ext in (".png", ".jpg", ".jpeg"):
        image = cv2.imread(video_path)
        if image is None:
            # cv2.imread silently returns None on failure; fail loudly instead.
            raise FileNotFoundError(f"Could not read image: {video_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image[None, None]  # add batch and frame dimensions
        return torch.from_numpy(image) / 127.5 - 1.0

    vr = VideoReader(video_path)
    frame_indices = list(range(min(len(vr), video_length)))
    frames = vr.get_batch(frame_indices).asnumpy()
    frames = torch.from_numpy(frames) / 127.5 - 1.0
    return frames[None]  # add batch dimension
37
-
38
class VideoStyleInference:
    """
    Video style-transfer inference engine.

    Wraps a Wan VAE + transformer + UniPC scheduler into a pipeline and
    exposes a single `inference` call that stylizes a source video using a
    style reference image.
    """

    def __init__(self, config: Dict):
        """
        Initialize the inference engine and eagerly load all models.

        Args:
            config: Configuration dict with keys: random_seed, video_length,
                height, width, num_inference_steps, ckpt_t2v_path,
                ckpt_dit_path, prompt_embeds_path, output_path.
        """
        self.config = config
        self.device = torch.device("cuda:0")
        self.random_seed = config['random_seed']
        self.video_length = config['video_length']
        self.H = config['height']
        self.W = config['width']
        self.num_inference_steps = config['num_inference_steps']
        self.vae_path = os.path.join(config['ckpt_t2v_path'], "vae")
        self.transformer_config_path = os.path.join(config['ckpt_t2v_path'], "transformer_config.json")
        self.scheduler_path = os.path.join(config['ckpt_t2v_path'], "scheduler")
        self.ckpt_path = config['ckpt_dit_path']
        self.output_path = config['output_path']
        self.prompt_embeds_path = config['prompt_embeds_path']

        # Load models up front so the first inference call has no warm-up cost.
        self._load_models()

    def _load_models(self):
        """
        Load the VAE, transformer, and scheduler, then assemble the pipeline.
        """
        # Load the checkpoint state dict. DDP-trained checkpoints prefix
        # parameter keys with "module.".
        state_dict = torch.load(self.ckpt_path, map_location="cpu")["transformer_state_dict"]
        transformer_state_dict = {}
        for key, value in state_dict.items():
            # Strip only a *leading* "module." prefix. The previous
            # `key.split("module.")[1]` corrupted keys containing the
            # substring more than once and raised IndexError on keys
            # without the prefix.
            clean_key = key[len("module."):] if key.startswith("module.") else key
            transformer_state_dict[clean_key] = value

        # Load the transformer configuration.
        config = OmegaConf.to_container(
            OmegaConf.load(self.transformer_config_path)
        )

        # Instantiate models; VAE and transformer run in fp16 on the GPU.
        self.vae = AutoencoderKLWan.from_pretrained(self.vae_path, torch_dtype=torch.float16).to(self.device)
        self.transformer = WanTransformer3DModel(**config)
        self.transformer.load_state_dict(transformer_state_dict)
        self.transformer = self.transformer.to(self.device).half()
        self.scheduler = UniPCMultistepScheduler.from_pretrained(self.scheduler_path)

        # Assemble the generation pipeline.
        self.pipe = WanPipeline(
            transformer=self.transformer,
            vae=self.vae,
            scheduler=self.scheduler
        )
        self.pipe.to(self.device)

    def inference(self, source_videos: torch.Tensor, first_images: torch.Tensor, video_path: str, step: int) -> torch.Tensor:
        """
        Run style-transfer inference on one source video.

        Args:
            source_videos: Source video tensor, shape (B, F, H, W, C) in [-1, 1].
            first_images: Style reference image tensor, shape (B, 1, H, W, C).
            video_path: Path to the source video (unused here; kept for logging/
                bookkeeping by callers).
            step: Inference step index (unused here; kept for callers).

        Returns:
            The generated video frames returned by the pipeline.
        """
        source_videos = source_videos.to(self.device).half()
        first_images = first_images.to(self.device).half()
        prompt_embeds_ = torch.load(self.prompt_embeds_path).to(self.device).half()

        print(f"Source videos shape: {source_videos.shape}, First images shape: {first_images.shape}")

        # VAE latent normalization statistics; note latents_std here is the
        # *reciprocal* of the configured std, so normalization is a multiply.
        latents_mean = torch.tensor(self.vae.config.latents_mean)
        latents_mean = latents_mean.view(1, 16, 1, 1, 1).to(self.device, torch.float16)
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std)
        latents_std = latents_std.view(1, 16, 1, 1, 1).to(self.device, torch.float16)

        bsz = 1
        _, _, h, w, _ = source_videos.shape

        # Match the output orientation (landscape vs. portrait) to the input.
        if h < w:
            output_h, output_w = self.H, self.W
        else:
            output_h, output_w = self.W, self.H

        with torch.no_grad():
            # Resize the source video to the target resolution.
            source_videos = rearrange(source_videos, "b f h w c -> (b f) c h w")
            source_videos = F.interpolate(source_videos, (output_h, output_w), mode="bilinear")
            source_videos = rearrange(source_videos, "(b f) c h w -> b c f h w", b=bsz)

            # Resize the style reference image the same way.
            first_images = rearrange(first_images, "b f h w c -> (b f) c h w")
            first_images = F.interpolate(first_images, (output_h, output_w), mode="bilinear")
            first_images = rearrange(first_images, "(b f) c h w -> b c f h w", b=bsz)

            # Encode into the VAE latent space and normalize.
            source_latents = self.vae.encode(source_videos).latent_dist.mode()
            source_latents = (source_latents - latents_mean) * latents_std

            first_latents = self.vae.encode(first_images).latent_dist.mode()
            first_latents = (first_latents - latents_mean) * latents_std

            # Negative (unconditional) style latents from a black image,
            # used for classifier-free guidance.
            neg_first_latents = self.vae.encode(torch.zeros_like(first_images)).latent_dist.mode()
            neg_first_latents = (neg_first_latents - latents_mean) * latents_std

            video = self.pipe(
                source_latents=source_latents,
                first_latents=first_latents,
                neg_first_latents=neg_first_latents,
                num_frames=self.video_length,
                guidance_scale=3.0,
                height=output_h,
                width=output_w,
                prompt_embeds_=prompt_embeds_,
                num_inference_steps=self.num_inference_steps,
                generator=torch.Generator(device=self.device).manual_seed(self.random_seed),
            ).frames[0]

        return video
163
-
164
- if __name__ == "__main__":
165
- config = {
166
- "random_seed": 42,
167
- "video_length": 129,
168
- "height": 720,
169
- "width": 1248,
170
- "num_inference_steps": 25,
171
- "ckpt_t2v_path": "./Wan2.1-T2V-1.3B-Diffusers",
172
- "ckpt_dit_path": "weights/dit.ckpt",
173
- "prompt_embeds_path": "weights/prompt_embeds.pth",
174
- "output_path": "./results"
175
- }
176
-
177
- # 初始化推理器
178
- inference_engine = VideoStyleInference(config)
179
-
180
- data_list = [
181
- {
182
- "video_path": "assets/example/2.mp4",
183
- "image_path": "assets/example/2-0.png"
184
- },
185
- {
186
- "video_path": "assets/example/2.mp4",
187
- "image_path": "assets/example/2-1.png"
188
- },
189
- ]
190
-
191
- for step, data in enumerate(data_list):
192
- video_path = data['video_path']
193
- style_image_path = data['image_path']
194
-
195
- source_video = load_video(video_path, config['video_length'])
196
- style_image = Image.open(style_image_path)
197
- style_image = np.array(style_image)
198
- style_image = torch.from_numpy(style_image) / 127.5 - 1.0
199
- style_image = style_image[None, None, :, :, :] # 添加 batch 和 frame 维度
200
-
201
- with torch.no_grad():
202
- generated_video = inference_engine.inference(source_video, style_image, video_path, step)
203
-
204
- os.makedirs(config['output_path'], exist_ok=True)
205
- output_filename = f"{config['output_path']}/{step}.mp4"
206
- export_to_video(generated_video, output_filename)
207
-