File size: 10,290 Bytes
872b1a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 |
"""
批处理优化版本的 latent_to_video
相比原版逐帧处理,使用批处理加速约 10-30 倍
v2: 优化 GPU→CPU 传输和视频编码,使用流式处理
"""
import sys
import os
# 获取项目根目录并添加到 sys.path 最前面,确保导入正确的 utils 模块
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.abspath(os.path.join(_SCRIPT_DIR, '..', '..'))
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
import numpy as np
import torch
from PIL import Image
import torchvision.transforms as T
from omegaconf import OmegaConf
import fire
import imageio
import moviepy.editor as mp
from tqdm import tqdm
import time
import subprocess
import tempfile
def init_fn(config_path, version):
    """Build the inference context for one model version.

    Loads the config, instantiates the model class, restores the checkpoint
    (non-strict), moves the model to CUDA in eval mode, and returns the image
    preprocessing transform plus the three sub-modules used for rendering.
    """
    # Make the version-specific `utils` package shadow any other `utils`.
    sys.path.insert(0, f'./utils/model_{version}')
    from utils import instantiate

    cfg = OmegaConf.load(config_path)
    model_cls = instantiate(cfg.model, instantiate_module=False)
    net = model_cls(config=cfg)
    ckpt = torch.load(cfg.resume_ckpt, map_location="cpu")
    net.load_state_dict(ckpt["state_dict"], strict=False)
    net.eval().to("cuda")

    # 512x512, [0,1] -> [-1,1] normalization to match the generator's range.
    preprocess = T.Compose([
        T.Resize((512, 512)),
        T.ToTensor(),
        T.Normalize([0.5], [0.5]),
    ])
    return {
        "transform": preprocess,
        "flow_estimator": net.flow_estimator,
        "face_generator": net.face_generator,
        "face_encoder": net.face_encoder,
    }
def _start_ffmpeg_encoder(out_path, fps, size=512):
    """Spawn an FFmpeg process that consumes raw RGB24 frames on stdin.

    Frames must be `size` x `size`, uint8, RGB, written row-major.  Output is
    H.264 (CRF 18, yuv420p) at `fps`.  Caller must close stdin and wait().
    """
    cmd = [
        'ffmpeg', '-y',
        '-f', 'rawvideo',
        '-vcodec', 'rawvideo',
        '-s', f'{size}x{size}',
        '-pix_fmt', 'rgb24',
        '-r', str(fps),
        '-i', '-',
        '-c:v', 'libx264',
        '-preset', 'fast',
        '-crf', '18',
        '-pix_fmt', 'yuv420p',
        out_path,
    ]
    return subprocess.Popen(
        cmd,
        stdin=subprocess.PIPE,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )


def _merge_audio(video_path, audio_path, out_path):
    """Mux audio into a video (video stream-copied, audio re-encoded to AAC).

    Returns True on success.  Much faster than moviepy since the video stream
    is copied, not re-encoded.
    """
    cmd = [
        'ffmpeg', '-y',
        '-i', video_path,
        '-i', audio_path,
        '-c:v', 'copy',
        '-c:a', 'aac',
        '-shortest',
        out_path,
    ]
    result = subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    return result.returncode == 0


def latent_to_video_batch(
    npz_dir="./test_case/",
    save_dir="./test_case/",
    save_fps: int = 25,
    config_path: str = './configs/head_animator_best_0416.yaml',
    version: str = '0416',
    batch_size: int = 32,
    use_fp16: bool = True,
):
    """Batched latent-to-video rendering with streaming FFmpeg encoding.

    For every ``*_output.npz`` file in `npz_dir`, decodes the stored motion
    latents into frames in batches, pipes the frames straight into FFmpeg
    (no intermediate frame files), then muxes in the source audio if present.
    A single-frame input is saved as a PNG instead of a video.

    Args:
        npz_dir: Directory containing ``*_output.npz`` files.
        save_dir: Directory for the output videos / PNGs.
        save_fps: Frame rate of the encoded video.
        config_path: Model config path; "0416" in it is replaced by `version`.
        version: Model version tag.
        batch_size: Frames per inference batch (default 32; lower to 16 or 8
            if VRAM is tight).
        use_fp16: Enable autocast mixed precision (default True).
    """
    os.makedirs(save_dir, exist_ok=True)
    config_path = config_path.replace("0416", version)

    # Initialize models only once for all files.
    print("Initializing models...")
    ctx = init_fn(config_path, version)
    transform = ctx["transform"]
    flow_estimator = ctx["flow_estimator"]
    face_generator = ctx["face_generator"]
    face_encoder = ctx["face_encoder"]

    npz_files = [f for f in os.listdir(npz_dir) if f.endswith('_output.npz')]
    print(f"Found {len(npz_files)} files to process")
    print(f"Batch size: {batch_size}, FP16: {use_fp16}")

    total_frames = 0
    total_time = 0
    for npz_file in tqdm(npz_files, desc="Processing files"):
        try:
            npz_path = os.path.join(npz_dir, npz_file)
            data = np.load(npz_path, allow_pickle=True)

            motion_latent = torch.from_numpy(data["motion_latent"]).to("cuda").float()
            # Drop a leading batch dim if the latents were saved as (1, T, D).
            if len(motion_latent.shape) == 3:
                motion_latent = motion_latent.squeeze(0)
            num_frames = motion_latent.shape[0]
            print(f"\nProcessing {npz_file} with {num_frames} frames")

            # Resolve the reference image path against the project root when relative.
            ref_img_path = str(data["ref_img_path"])
            if not os.path.isabs(ref_img_path):
                ref_img_path = os.path.join(_PROJECT_ROOT, ref_img_path)
            ref_img = Image.open(ref_img_path).convert("RGB")
            ref_img = transform(ref_img).unsqueeze(0).to("cuda")

            # A leading dash would be parsed as an option flag by FFmpeg, so
            # the temp file name drops it; the final rename restores the raw id.
            raw_video_id = str(data["video_id"])
            video_id = raw_video_id[1:] if raw_video_id.startswith('-') else raw_video_id

            # Resolve the audio path the same way as the reference image.
            audio_path = str(data["audio_path"]) if "audio_path" in data.files else None
            if audio_path and not os.path.isabs(audio_path):
                audio_path = os.path.join(_PROJECT_ROOT, audio_path)

            start_time = time.time()
            temp_mp4 = os.path.join(save_dir, f"{video_id}_temp.mp4")
            final_mp4 = os.path.join(save_dir, f"{raw_video_id}.mp4")

            if num_frames == 1:
                # Single frame: render one image and save it as a PNG.
                with torch.no_grad():
                    with torch.cuda.amp.autocast(enabled=use_fp16):
                        face_feat = face_encoder(ref_img)
                        tgt = flow_estimator(motion_latent[0:1], motion_latent[0:1])
                        recon = face_generator(tgt, face_feat)
                if use_fp16:
                    recon = recon.float()
                frame_np = recon.permute(0, 2, 3, 1).cpu().numpy()
                # [-1, 1] -> [0, 255]
                frame_np = np.clip((frame_np + 1) / 2 * 255, 0, 255).astype("uint8")
                out_path = os.path.join(save_dir, f"{video_id}_rec.png")
                Image.fromarray(frame_np[0]).save(out_path)
            else:
                # Multi-frame: batched inference streamed into FFmpeg via a pipe.
                encoder = _start_ffmpeg_encoder(temp_mp4, save_fps)
                with torch.no_grad():
                    with torch.cuda.amp.autocast(enabled=use_fp16):
                        face_feat = face_encoder(ref_img)  # (1, 32, 16, 64, 64)
                        ref_latent = motion_latent[0:1]    # reference frame latent
                        for i in range(0, num_frames, batch_size):
                            batch_end = min(i + batch_size, num_frames)
                            current_batch_size = batch_end - i
                            batch_motion = motion_latent[i:batch_end]
                            # Broadcast the reference latent / features to the batch.
                            ref_latent_expanded = ref_latent.expand(current_batch_size, -1)
                            face_feat_expanded = face_feat.expand(
                                current_batch_size, -1, -1, -1, -1)
                            tgt = flow_estimator(ref_latent_expanded, batch_motion)
                            recon = face_generator(tgt, face_feat_expanded)
                            # Normalize on the GPU, then ship uint8 HWC frames out:
                            # (B, 3, 512, 512) [-1,1] -> (B, 512, 512, 3) uint8.
                            recon = recon.float()
                            recon = ((recon + 1) / 2 * 255).clamp(0, 255).to(torch.uint8)
                            recon = recon.permute(0, 2, 3, 1).contiguous()
                            encoder.stdin.write(recon.cpu().numpy().tobytes())
                encoder.stdin.close()
                encoder.wait()

            elapsed = time.time() - start_time
            total_frames += num_frames
            total_time += elapsed
            fps = num_frames / elapsed
            print(f"  Rendered + encoded {num_frames} frames in {elapsed:.2f}s ({fps:.1f} fps)")

            # BUGFIX: the single-frame path writes a PNG and never creates
            # temp_mp4; the original code always ran the audio-merge/rename
            # stage and raised FileNotFoundError.  Finalize only real videos.
            if num_frames > 1:
                if audio_path and os.path.exists(audio_path):
                    merged_mp4 = os.path.join(save_dir, f"{video_id}_with_audio.mp4")
                    if not _merge_audio(temp_mp4, audio_path, merged_mp4):
                        # Keep temp_mp4 around for debugging instead of
                        # deleting it and failing on the rename.
                        raise RuntimeError(f"FFmpeg audio merge failed for {video_id}")
                    os.remove(temp_mp4)
                    os.rename(merged_mp4, final_mp4)
                else:
                    os.rename(temp_mp4, final_mp4)
        except Exception as e:
            import traceback
            print(f"Error processing {npz_file}: {str(e)}")
            traceback.print_exc()
            continue

    # Overall throughput statistics.
    if total_time > 0:
        print(f"\n{'='*50}")
        print(f"总计: {total_frames} 帧, {total_time:.2f} 秒")
        print(f"平均渲染速度: {total_frames / total_time:.1f} fps")
        print(f"{'='*50}")
if __name__ == "__main__":
fire.Fire(latent_to_video_batch)
# Example usage:
# python latent_to_video_batch.py --npz_dir ./test_case/ --save_dir ./test_case/ --batch_size 32 --use_fp16 True
|