|
|
""" |
|
|
批处理优化版本的 latent_to_video |
|
|
相比原版逐帧处理,使用批处理加速约 10-30 倍 |
|
|
v2: 优化 GPU→CPU 传输和视频编码,使用流式处理 |
|
|
""" |
|
|
import sys |
|
|
import os |
|
|
|
|
|
|
|
|
# Make the repository root importable when this file is run as a script
# (the file lives two directories below the project root).
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.abspath(os.path.join(_SCRIPT_DIR, os.pardir, os.pardir))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from PIL import Image |
|
|
import torchvision.transforms as T |
|
|
from omegaconf import OmegaConf |
|
|
import fire |
|
|
import imageio |
|
|
import moviepy.editor as mp |
|
|
from tqdm import tqdm |
|
|
import time |
|
|
import subprocess |
|
|
import tempfile |
|
|
|
|
|
|
|
|
def init_fn(config_path, version):
    """Build the model context used by ``latent_to_video_batch``.

    Loads the OmegaConf config, instantiates the model class, restores
    weights from ``config.resume_ckpt`` (CPU-mapped, then moved to CUDA
    in eval mode) and builds the 512x512 input transform.

    Args:
        config_path: path to the OmegaConf YAML config file.
        version: model version tag; version-specific code is imported
            from ``./utils/model_{version}``.

    Returns:
        dict with keys ``transform``, ``flow_estimator``,
        ``face_generator`` and ``face_encoder``.
    """
    # Version-specific code ships its own `utils` package; put it at the
    # front of sys.path but avoid inserting duplicates on repeated calls
    # (same guard style as the top-of-file bootstrap).
    model_dir = f'./utils/model_{version}'
    if model_dir not in sys.path:
        sys.path.insert(0, model_dir)
    from utils import instantiate

    config = OmegaConf.load(config_path)
    module = instantiate(config.model, instantiate_module=False)
    model = module(config=config)

    checkpoint = torch.load(config.resume_ckpt, map_location="cpu")
    # strict=False: tolerate checkpoints with extra or missing keys.
    model.load_state_dict(checkpoint["state_dict"], strict=False)
    model.eval().to("cuda")

    # Inputs are resized to 512x512 and normalized from [0, 1] to [-1, 1].
    transform = T.Compose([
        T.Resize((512, 512)),
        T.ToTensor(),
        T.Normalize([0.5], [0.5]),
    ])
    return {
        "transform": transform,
        "flow_estimator": model.flow_estimator,
        "face_generator": model.face_generator,
        "face_encoder": model.face_encoder,
    }
|
|
|
|
|
|
|
|
def latent_to_video_batch(
        npz_dir="./test_case/",
        save_dir="./test_case/",
        save_fps: int = 25,
        config_path: str = './configs/head_animator_best_0416.yaml',
        version: str = '0416',
        batch_size: int = 32,
        use_fp16: bool = True,
):
    """Batch-optimized latent_to_video.

    Renders motion latents from ``*_output.npz`` files to H.264 videos.
    Frames are generated in GPU batches and streamed straight into an
    ffmpeg subprocess as raw RGB over stdin (no intermediate frame
    files). When the NPZ references an audio file, it is muxed in
    afterwards. Single-frame inputs are saved as a PNG instead.

    Args:
        npz_dir: directory scanned for ``*_output.npz`` files.
        save_dir: directory where output videos / images are written.
        save_fps: frame rate of the encoded video.
        config_path: model config path; the "0416" tag inside it is
            replaced by *version*.
        version: model version tag (selects ``./utils/model_{version}``).
        batch_size: frames per GPU batch; lower to 16 or 8 if VRAM is short.
        use_fp16: run the models under autocast mixed precision.
    """
    os.makedirs(save_dir, exist_ok=True)
    config_path = config_path.replace("0416", version)

    print("Initializing models...")
    ctx = init_fn(config_path, version)
    transform = ctx["transform"]
    flow_estimator = ctx["flow_estimator"]
    face_generator = ctx["face_generator"]
    face_encoder = ctx["face_encoder"]

    npz_files = [f for f in os.listdir(npz_dir) if f.endswith('_output.npz')]
    print(f"Found {len(npz_files)} files to process")
    print(f"Batch size: {batch_size}, FP16: {use_fp16}")

    total_frames = 0
    total_time = 0

    for npz_file in tqdm(npz_files, desc="Processing files"):
        try:
            npz_path = os.path.join(npz_dir, npz_file)
            data = np.load(npz_path, allow_pickle=True)
            motion_latent = torch.from_numpy(data["motion_latent"]).to("cuda").float()
            # Drop a leading batch dim if present -> (num_frames, latent_dim).
            if len(motion_latent.shape) == 3:
                motion_latent = motion_latent.squeeze(0)
            num_frames = motion_latent.shape[0]
            print(f"\nProcessing {npz_file} with {num_frames} frames")

            ref_img_path = str(data["ref_img_path"])
            if not os.path.isabs(ref_img_path):
                ref_img_path = os.path.join(_PROJECT_ROOT, ref_img_path)
            ref_img = Image.open(ref_img_path).convert("RGB")
            ref_img = transform(ref_img).unsqueeze(0).to("cuda")

            video_id = str(data["video_id"])
            # A leading '-' would make ffmpeg parse the output path as an
            # option, so strip it for names passed on the command line.
            if video_id.startswith('-'):
                video_id = video_id[1:]

            audio_path = str(data["audio_path"]) if "audio_path" in data.files else None
            if audio_path and not os.path.isabs(audio_path):
                audio_path = os.path.join(_PROJECT_ROOT, audio_path)

            start_time = time.time()

            temp_mp4 = os.path.join(save_dir, f"{video_id}_temp.mp4")
            # The final file keeps the original (possibly '-'-prefixed) id;
            # it is only used with os.rename, never as an ffmpeg argument.
            finalfinal_mp4 = os.path.join(save_dir, f"{str(data['video_id'])}.mp4")

            if num_frames == 1:
                # Single frame: reconstruct once and save a PNG, no video.
                with torch.no_grad():
                    with torch.cuda.amp.autocast(enabled=use_fp16):
                        face_feat = face_encoder(ref_img)
                        tgt = flow_estimator(motion_latent[0:1], motion_latent[0:1])
                        recon = face_generator(tgt, face_feat)
                        if use_fp16:
                            recon = recon.float()

                video_np = recon.permute(0, 2, 3, 1).cpu().numpy()
                # Model output is in [-1, 1]; map to uint8 [0, 255].
                video_np = np.clip((video_np + 1) / 2 * 255, 0, 255).astype("uint8")
                out_path = os.path.join(save_dir, f"{video_id}_rec.png")
                Image.fromarray(video_np[0]).save(out_path)
            else:
                # Stream raw RGB frames into ffmpeg over stdin -> H.264 mp4.
                ffmpeg_cmd = [
                    'ffmpeg', '-y',
                    '-f', 'rawvideo',
                    '-vcodec', 'rawvideo',
                    '-s', '512x512',
                    '-pix_fmt', 'rgb24',
                    '-r', str(save_fps),
                    '-i', '-',
                    '-c:v', 'libx264',
                    '-preset', 'fast',
                    '-crf', '18',
                    '-pix_fmt', 'yuv420p',
                    temp_mp4
                ]
                ffmpeg_process = subprocess.Popen(
                    ffmpeg_cmd,
                    stdin=subprocess.PIPE,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL
                )

                with torch.no_grad():
                    with torch.cuda.amp.autocast(enabled=use_fp16):
                        # Reference features are computed once and broadcast
                        # over every batch below.
                        face_feat = face_encoder(ref_img)
                        ref_latent = motion_latent[0:1]

                        for i in range(0, num_frames, batch_size):
                            batch_end = min(i + batch_size, num_frames)
                            current_batch_size = batch_end - i

                            batch_motion = motion_latent[i:batch_end]
                            # expand() creates views, not copies; assumes the
                            # latent is 2-D and face_feat is 5-D as in the
                            # original code — TODO confirm against the models.
                            ref_latent_expanded = ref_latent.expand(current_batch_size, -1)
                            face_feat_expanded = face_feat.expand(current_batch_size, -1, -1, -1, -1)

                            tgt = flow_estimator(ref_latent_expanded, batch_motion)
                            recon = face_generator(tgt, face_feat_expanded)

                            # [-1, 1] float -> uint8 HWC frames for rawvideo.
                            recon = recon.float()
                            recon = (recon + 1) / 2 * 255
                            recon = recon.clamp(0, 255).to(torch.uint8)
                            recon = recon.permute(0, 2, 3, 1).contiguous()

                            frames_np = recon.cpu().numpy()
                            ffmpeg_process.stdin.write(frames_np.tobytes())

                ffmpeg_process.stdin.close()
                ffmpeg_process.wait()

            elapsed = time.time() - start_time
            total_frames += num_frames
            total_time += elapsed
            fps = num_frames / elapsed if elapsed > 0 else float('inf')
            print(f"  Rendered + encoded {num_frames} frames in {elapsed:.2f}s ({fps:.1f} fps)")

            # BUGFIX: the single-frame branch writes a PNG and never creates
            # temp_mp4, so muxing/renaming only applies to real videos.
            if num_frames > 1:
                if audio_path and os.path.exists(audio_path):
                    final_with_audio = os.path.join(save_dir, f"{video_id}_with_audio.mp4")
                    ffmpeg_audio_cmd = [
                        'ffmpeg', '-y',
                        '-i', temp_mp4,
                        '-i', audio_path,
                        '-c:v', 'copy',
                        '-c:a', 'aac',
                        '-shortest',
                        final_with_audio
                    ]
                    result = subprocess.run(ffmpeg_audio_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
                    if result.returncode == 0 and os.path.exists(final_with_audio):
                        os.remove(temp_mp4)
                        os.rename(final_with_audio, finalfinal_mp4)
                    else:
                        # Muxing failed: keep the silent video rather than
                        # losing the rendered output.
                        print("  Warning: audio mux failed, saving video without audio")
                        os.rename(temp_mp4, finalfinal_mp4)
                else:
                    os.rename(temp_mp4, finalfinal_mp4)

        except Exception as e:
            # Best-effort batch processing: report and move to the next file.
            import traceback
            print(f"Error processing {npz_file}: {str(e)}")
            traceback.print_exc()
            continue

    if total_time > 0:
        print(f"\n{'='*50}")
        print(f"总计: {total_frames} 帧, {total_time:.2f} 秒")
        print(f"平均渲染速度: {total_frames / total_time:.1f} fps")
        print(f"{'='*50}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: python-fire maps the function's keyword arguments
    # to command-line flags (e.g. --npz_dir, --batch_size).
    fire.Fire(latent_to_video_batch)
|
|
|
|
|
|
|
|
|