# DyStream/tools/visualization_0416/latent_to_video_batch.py
# (HuggingFace upload by robinwitch — commit 872b1a7, "upload ckpt")
"""
批处理优化版本的 latent_to_video
相比原版逐帧处理,使用批处理加速约 10-30 倍
v2: 优化 GPU→CPU 传输和视频编码,使用流式处理
"""
import sys
import os
# Resolve the project root (two levels above this script) and prepend it to
# sys.path so the correct top-level `utils` module is imported instead of a
# same-named package that may shadow it elsewhere on the path.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.abspath(os.path.join(_SCRIPT_DIR, '..', '..'))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)
import numpy as np
import torch
from PIL import Image
import torchvision.transforms as T
from omegaconf import OmegaConf
import fire
import imageio
import moviepy.editor as mp
from tqdm import tqdm
import time
import subprocess
import tempfile
def init_fn(config_path, version):
    """Load the model checkpoint and build the inference context.

    Args:
        config_path: Path to the OmegaConf YAML config; must define ``model``
            and ``resume_ckpt``.
        version: Model version tag; ``utils/model_<version>`` is prepended to
            ``sys.path`` so the matching model package is importable.

    Returns:
        Dict with the 512x512 image ``transform`` and the three sub-modules
        used for rendering: ``flow_estimator``, ``face_generator``,
        ``face_encoder`` (all on CUDA, eval mode).
    """
    # Anchor the version-specific utils package at the project root so the
    # import works regardless of the current working directory (the previous
    # './utils/model_<version>' only worked when run from the project root).
    sys.path.insert(0, os.path.join(_PROJECT_ROOT, 'utils', f'model_{version}'))
    from utils import instantiate
    config = OmegaConf.load(config_path)
    module = instantiate(config.model, instantiate_module=False)
    model = module(config=config)
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(config.resume_ckpt, map_location="cpu")
    # strict=False tolerates missing/unexpected keys — presumably to skip
    # training-only heads not needed at inference; confirm against the ckpt.
    model.load_state_dict(checkpoint["state_dict"], strict=False)
    model.eval().to("cuda")
    # Map PIL RGB images to (3, 512, 512) tensors normalized to [-1, 1].
    transform = T.Compose([
        T.Resize((512, 512)),
        T.ToTensor(),
        T.Normalize([0.5], [0.5]),
    ])
    return {
        "transform": transform,
        "flow_estimator": model.flow_estimator,
        "face_generator": model.face_generator,
        "face_encoder": model.face_encoder,
    }
def latent_to_video_batch(
    npz_dir="./test_case/",
    save_dir="./test_case/",
    save_fps: int = 25,
    config_path: str = './configs/head_animator_best_0416.yaml',
    version: str = '0416',
    batch_size: int = 32,
    use_fp16: bool = True,
):
    """Batch-optimized latent-to-video rendering.

    Loads every ``*_output.npz`` in ``npz_dir``, decodes the stored motion
    latents into 512x512 RGB frames with the face generator, and streams them
    through an FFmpeg pipe into H.264 MP4 files in ``save_dir``. Single-frame
    inputs are saved as a PNG instead. When the npz records an audio path,
    the audio track is muxed in with FFmpeg stream-copy.

    Args:
        npz_dir: Directory containing ``*_output.npz`` files.
        save_dir: Output directory for rendered videos/images.
        save_fps: Frame rate of the output videos.
        config_path: Model config path; its "0416" token is replaced by
            ``version``.
        version: Model version tag (selects ``./utils/model_<version>``).
        batch_size: Frames per inference batch; lower it (16 or 8) if GPU
            memory is tight.
        use_fp16: Enable mixed-precision (autocast) inference.
    """
    os.makedirs(save_dir, exist_ok=True)
    config_path = config_path.replace("0416", version)
    # Initialize models only once for the whole directory.
    print("Initializing models...")
    ctx = init_fn(config_path, version)
    transform = ctx["transform"]
    flow_estimator = ctx["flow_estimator"]
    face_generator = ctx["face_generator"]
    face_encoder = ctx["face_encoder"]
    # Gather the npz files produced by the upstream stage.
    npz_files = [f for f in os.listdir(npz_dir) if f.endswith('_output.npz')]
    print(f"Found {len(npz_files)} files to process")
    print(f"Batch size: {batch_size}, FP16: {use_fp16}")
    total_frames = 0
    total_time = 0
    # Process each file independently; a failure in one file must not abort
    # the whole batch (see the broad except at the bottom of the loop).
    for npz_file in tqdm(npz_files, desc="Processing files"):
        try:
            npz_path = os.path.join(npz_dir, npz_file)
            data = np.load(npz_path, allow_pickle=True)
            motion_latent = torch.from_numpy(data["motion_latent"]).to("cuda").float()
            # Accept both (1, T, D) and (T, D) latent layouts.
            if len(motion_latent.shape) == 3:
                motion_latent = motion_latent.squeeze(0)
            num_frames = motion_latent.shape[0]
            print(f"\nProcessing {npz_file} with {num_frames} frames")
            # Resolve ref_img_path against the project root when relative.
            ref_img_path = str(data["ref_img_path"])
            if not os.path.isabs(ref_img_path):
                ref_img_path = os.path.join(_PROJECT_ROOT, ref_img_path)
            ref_img = Image.open(ref_img_path).convert("RGB")
            ref_img = transform(ref_img).unsqueeze(0).to("cuda")
            video_id = str(data["video_id"])
            # Remove leading dash to prevent FFMPEG command line parsing issues
            if video_id.startswith('-'):
                video_id = video_id[1:]
            # Resolve audio_path (optional) against the project root.
            audio_path = str(data["audio_path"]) if "audio_path" in data.files else None
            if audio_path and not os.path.isabs(audio_path):
                audio_path = os.path.join(_PROJECT_ROOT, audio_path)
            start_time = time.time()
            # Temp file uses the dash-stripped id (safe on the FFmpeg command
            # line); the final name keeps the original video_id — os.rename
            # never goes through FFmpeg, so a leading dash is harmless there.
            temp_mp4 = os.path.join(save_dir, f"{video_id}_temp.mp4")
            final_mp4 = os.path.join(save_dir, f"{str(data['video_id'])}.mp4")
            if num_frames == 1:
                # Single-frame case: render one image and save a PNG.
                with torch.no_grad():
                    with torch.cuda.amp.autocast(enabled=use_fp16):
                        face_feat = face_encoder(ref_img)
                        tgt = flow_estimator(motion_latent[0:1], motion_latent[0:1])
                        recon = face_generator(tgt, face_feat)
                if use_fp16:
                    recon = recon.float()
                # [-1, 1] -> uint8 [0, 255], channel-last for PIL.
                video_np = recon.permute(0, 2, 3, 1).cpu().numpy()
                video_np = np.clip((video_np + 1) / 2 * 255, 0, 255).astype("uint8")
                out_path = os.path.join(save_dir, f"{video_id}_rec.png")
                Image.fromarray(video_np[0]).save(out_path)
            else:
                # Multi-frame case: stream raw RGB frames into an FFmpeg pipe
                # so frames are encoded while the GPU renders the next batch.
                ffmpeg_cmd = [
                    'ffmpeg', '-y',
                    '-f', 'rawvideo',
                    '-vcodec', 'rawvideo',
                    '-s', '512x512',
                    '-pix_fmt', 'rgb24',
                    '-r', str(save_fps),
                    '-i', '-',
                    '-c:v', 'libx264',
                    '-preset', 'fast',
                    '-crf', '18',
                    '-pix_fmt', 'yuv420p',
                    temp_mp4
                ]
                ffmpeg_process = subprocess.Popen(
                    ffmpeg_cmd,
                    stdin=subprocess.PIPE,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL
                )
                try:
                    with torch.no_grad():
                        with torch.cuda.amp.autocast(enabled=use_fp16):
                            face_feat = face_encoder(ref_img)  # (1, 32, 16, 64, 64)
                            ref_latent = motion_latent[0:1]  # reference-frame latent
                            # Batched inference + streaming write.
                            for i in range(0, num_frames, batch_size):
                                batch_end = min(i + batch_size, num_frames)
                                current_batch_size = batch_end - i
                                batch_motion = motion_latent[i:batch_end]
                                # Broadcast the reference latent and face
                                # features to the current batch size (views,
                                # no copy).
                                ref_latent_expanded = ref_latent.expand(current_batch_size, -1)
                                face_feat_expanded = face_feat.expand(current_batch_size, -1, -1, -1, -1)
                                tgt = flow_estimator(ref_latent_expanded, batch_motion)
                                recon = face_generator(tgt, face_feat_expanded)
                                # Normalize on the GPU, then ship channel-last
                                # uint8 frames to the encoder pipe.
                                recon = recon.float()
                                recon = (recon + 1) / 2 * 255
                                recon = recon.clamp(0, 255).to(torch.uint8)
                                recon = recon.permute(0, 2, 3, 1).contiguous()
                                frames_np = recon.cpu().numpy()
                                ffmpeg_process.stdin.write(frames_np.tobytes())
                finally:
                    # Always close the pipe and reap the process, even when
                    # rendering raises — otherwise FFmpeg lingers forever.
                    ffmpeg_process.stdin.close()
                    ffmpeg_process.wait()
                if ffmpeg_process.returncode != 0:
                    raise RuntimeError(
                        f"FFmpeg encoding failed with exit code {ffmpeg_process.returncode}"
                    )
                elapsed = time.time() - start_time
                total_frames += num_frames
                total_time += elapsed
                fps = num_frames / elapsed
                print(f" Rendered + encoded {num_frames} frames in {elapsed:.2f}s ({fps:.1f} fps)")
                # Mux in the audio track, if any.
                if audio_path and os.path.exists(audio_path):
                    # FFmpeg stream-copy mux (much faster than moviepy).
                    final_with_audio = os.path.join(save_dir, f"{video_id}_with_audio.mp4")
                    ffmpeg_audio_cmd = [
                        'ffmpeg', '-y',
                        '-i', temp_mp4,
                        '-i', audio_path,
                        '-c:v', 'copy',
                        '-c:a', 'aac',
                        '-shortest',
                        final_with_audio
                    ]
                    mux = subprocess.run(
                        ffmpeg_audio_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
                    )
                    if mux.returncode == 0 and os.path.exists(final_with_audio):
                        os.remove(temp_mp4)
                        os.rename(final_with_audio, final_mp4)
                    else:
                        # Mux failed: keep the silent video instead of
                        # deleting it and crashing on a missing rename source.
                        print(f"Audio mux failed for {npz_file}; keeping silent video")
                        os.rename(temp_mp4, final_mp4)
                else:
                    os.rename(temp_mp4, final_mp4)
        except Exception as e:
            import traceback
            print(f"Error processing {npz_file}: {str(e)}")
            traceback.print_exc()
            continue
    # Overall statistics.
    if total_time > 0:
        print(f"\n{'='*50}")
        print(f"总计: {total_frames} 帧, {total_time:.2f} 秒")
        print(f"平均渲染速度: {total_frames / total_time:.1f} fps")
        print(f"{'='*50}")
if __name__ == "__main__":
fire.Fire(latent_to_video_batch)
# Example usage:
# python latent_to_video_batch.py --npz_dir ./test_case/ --save_dir ./test_case/ --batch_size 32 --use_fp16 True