# MuseTalk / avatar_preprocessor.py
# marcosremar's picture
# Add LFS tracking for binary files
# b53da6e
"""
Avatar Preprocessor - Pre-compute everything related to the avatar video.
This saves ~30s per inference when using the same avatar repeatedly.
"""
import argparse
import glob
import hashlib
import os
import pickle
import shutil
import subprocess
import time
from pathlib import Path

import cv2
import numpy as np
import torch
from omegaconf import OmegaConf
from tqdm import tqdm

# MuseTalk imports
from musetalk.utils.utils import get_file_type, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, coord_placeholder
from musetalk.utils.face_parsing import FaceParsing
class AvatarPreprocessor:
    """Pre-compute per-avatar data (frames, bboxes, VAE latents, parsing masks).

    Caching these artifacts on disk under ``avatar_dir`` saves ~30s of work per
    inference when the same avatar video is reused.
    """

    def __init__(self, avatar_dir: str = "./avatars"):
        """Create the preprocessor and ensure ``avatar_dir`` exists.

        Args:
            avatar_dir: Root directory where per-avatar caches are written.
        """
        self.avatar_dir = Path(avatar_dir)
        # parents=True so a nested path like ./cache/avatars works too.
        self.avatar_dir.mkdir(parents=True, exist_ok=True)
        # Config (must match the inference server's settings)
        self.fps = 25
        self.version = "v15"
        self.extra_margin = 10
        self.left_cheek_width = 90
        self.right_cheek_width = 90
        # Models (loaded lazily on first use)
        self.device = None
        self.vae = None
        self.fp = None

    def _get_file_hash(self, file_path: str) -> str:
        """Return the first 16 hex chars of the file's MD5 (cache key, not security)."""
        hash_md5 = hashlib.md5()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()[:16]

    def _load_models(self):
        """Lazily load the VAE and face-parsing models (no-op if already loaded)."""
        if self.vae is not None:
            return
        print("Loading models for preprocessing...")
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.vae, _, _ = load_all_model(
            unet_model_path="./models/musetalkV15/unet.pth",
            vae_type="sd-vae",
            unet_config="./models/musetalk/config.json",
            device=self.device
        )
        # Half precision keeps VRAM usage down during batch encoding.
        self.vae.vae = self.vae.vae.half().to(self.device)
        self.fp = FaceParsing(
            left_cheek_width=self.left_cheek_width,
            right_cheek_width=self.right_cheek_width
        )
        print("Models loaded!")

    @staticmethod
    def _is_placeholder(bbox) -> bool:
        """True when face detection produced the placeholder bbox for a frame."""
        return isinstance(bbox, (list, tuple)) and list(bbox) == list(coord_placeholder)

    def _adjusted_bbox(self, bbox, frame):
        """Return (x1, y1, x2, y2) with the v15 chin margin applied, clamped to the frame height."""
        x1, y1, x2, y2 = bbox
        if self.version == "v15":
            y2 = min(y2 + self.extra_margin, frame.shape[0])
        return x1, y1, x2, y2

    def _extract_frames(self, video_path: Path, frames_dir: Path) -> list:
        """Extract frames from a video via ffmpeg, or copy a single image.

        Returns the sorted list of extracted image paths.
        """
        frames_dir.mkdir(exist_ok=True)
        if get_file_type(str(video_path)) == "video":
            # List-form subprocess call avoids the shell-injection / spaces-in-path
            # hazard of the previous os.system(f"ffmpeg ... {video_path} ...");
            # check=True surfaces ffmpeg failures instead of silently yielding 0 frames.
            subprocess.run(
                ["ffmpeg", "-y", "-v", "fatal", "-i", str(video_path),
                 "-vf", f"fps={self.fps}", "-start_number", "0",
                 str(frames_dir / "%08d.png")],
                check=True,
            )
            return sorted(glob.glob(str(frames_dir / '*.[jpJP][pnPN]*[gG]')))
        # Single image input: treat it as a one-frame "video".
        dest = frames_dir / "00000000.png"
        shutil.copy(video_path, dest)
        return [str(dest)]

    def _encode_latents(self, coord_list, frame_list):
        """VAE-encode the face crop of every frame with a valid bbox.

        Returns (latents, crop_info): latents as CPU numpy arrays (cheap to
        pickle), crop_info recording each original bbox and its crop size.
        """
        latents, crop_info = [], []
        for bbox, frame in tqdm(zip(coord_list, frame_list), total=len(coord_list)):
            if self._is_placeholder(bbox):
                continue
            x1, y1, x2, y2 = self._adjusted_bbox(bbox, frame)
            crop_frame = frame[y1:y2, x1:x2]
            crop_frame_resized = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
            with torch.no_grad():
                lat = self.vae.get_latents_for_unet(crop_frame_resized)
            latents.append(lat.cpu().numpy())
            crop_info.append({
                'bbox': bbox,  # original (un-margined) bbox, as downstream expects
                'original_size': (x2 - x1, y2 - y1)
            })
        return latents, crop_info

    def _compute_parsing(self, coord_list, frame_list):
        """Pre-compute a face-parsing mask per frame (None where unavailable)."""
        parsing_data = []
        for bbox, frame in tqdm(zip(coord_list, frame_list), total=len(coord_list)):
            if self._is_placeholder(bbox):
                parsing_data.append(None)
                continue
            x1, y1, x2, y2 = self._adjusted_bbox(bbox, frame)
            crop_frame = frame[y1:y2, x1:x2]
            try:
                parsing_data.append(self.fp.get_parsing(crop_frame))
            except Exception:
                # Parsing can fail on degenerate crops; mark the frame as
                # unavailable rather than aborting the whole run. (Narrowed
                # from a bare except, which also swallowed KeyboardInterrupt.)
                parsing_data.append(None)
        return parsing_data

    @staticmethod
    def _save_pickle(path: Path, obj) -> None:
        """Pickle ``obj`` to ``path``."""
        with open(path, 'wb') as f:
            pickle.dump(obj, f)

    def preprocess_avatar(self, video_path: str, avatar_name: str = None) -> dict:
        """Pre-process an avatar video and save all computed data.

        Args:
            video_path: Path to the avatar video (or a single image).
            avatar_name: Cache directory name; derived from the file's
                content hash when omitted.

        Returns:
            Dict with the avatar name, cache path, frame count and timings.

        Raises:
            FileNotFoundError: If ``video_path`` does not exist.
        """
        self._load_models()
        video_path = Path(video_path)
        if not video_path.exists():
            raise FileNotFoundError(f"Video not found: {video_path}")
        # Content-addressed name: the same video always maps to the same cache.
        if avatar_name is None:
            avatar_name = f"avatar_{self._get_file_hash(str(video_path))}"
        avatar_path = self.avatar_dir / avatar_name
        avatar_path.mkdir(exist_ok=True)
        print(f"\n{'='*50}")
        print(f"Pre-processing avatar: {avatar_name}")
        print(f"{'='*50}")
        timings = {}
        total_start = time.time()
        # 1. Extract frames
        print("\n[1/4] Extracting frames...")
        t0 = time.time()
        input_img_list = self._extract_frames(video_path, avatar_path / "frames")
        timings["frame_extraction"] = time.time() - t0
        print(f" Extracted {len(input_img_list)} frames in {timings['frame_extraction']:.2f}s")
        # 2. Compute landmarks and bboxes
        print("\n[2/4] Computing landmarks and bounding boxes...")
        t0 = time.time()
        coord_list, frame_list = get_landmark_and_bbox(input_img_list, 0)
        timings["landmarks"] = time.time() - t0
        print(f" Computed landmarks in {timings['landmarks']:.2f}s")
        # 3. Compute VAE latents for each frame
        print("\n[3/4] Computing VAE latents...")
        t0 = time.time()
        input_latent_list, crop_frames = self._encode_latents(coord_list, frame_list)
        timings["vae_encoding"] = time.time() - t0
        print(f" Computed {len(input_latent_list)} latents in {timings['vae_encoding']:.2f}s")
        # 4. Pre-compute face parsing masks (used for blending at inference time)
        print("\n[4/4] Pre-computing face parsing data...")
        t0 = time.time()
        parsing_data = self._compute_parsing(coord_list, frame_list)
        timings["face_parsing"] = time.time() - t0
        print(f" Computed parsing in {timings['face_parsing']:.2f}s")
        # Persist everything needed to skip preprocessing at inference time.
        print("\nSaving preprocessed data...")
        metadata = {
            'video_path': str(video_path),
            'avatar_name': avatar_name,
            'num_frames': len(input_img_list),
            'fps': self.fps,
            'version': self.version,
            'extra_margin': self.extra_margin,
            'timings': timings
        }
        self._save_pickle(avatar_path / "metadata.pkl", metadata)
        self._save_pickle(avatar_path / "coords.pkl", coord_list)
        # `fr` (not `f`) avoids shadowing an open-file handle as the old code did.
        self._save_pickle(
            avatar_path / "frames.pkl",
            [fr if isinstance(fr, np.ndarray) else np.array(fr) for fr in frame_list]
        )
        self._save_pickle(avatar_path / "latents.pkl", input_latent_list)
        self._save_pickle(avatar_path / "crop_info.pkl", crop_frames)
        self._save_pickle(avatar_path / "parsing.pkl", parsing_data)
        timings["total"] = time.time() - total_start
        print(f"\n{'='*50}")
        print(f"Avatar preprocessed successfully!")
        print(f"Total time: {timings['total']:.2f}s")
        print(f"Saved to: {avatar_path}")
        print(f"{'='*50}")
        return {
            'avatar_name': avatar_name,
            'avatar_path': str(avatar_path),
            'num_frames': len(input_img_list),
            'timings': timings
        }

    def list_avatars(self) -> list:
        """Return the metadata dict of every preprocessed avatar on disk."""
        avatars = []
        for p in self.avatar_dir.iterdir():
            if p.is_dir() and (p / "metadata.pkl").exists():
                with open(p / "metadata.pkl", 'rb') as f:
                    avatars.append(pickle.load(f))
        return avatars
if __name__ == "__main__":
    # CLI entry point: preprocess one avatar video/image and report the result.
    cli = argparse.ArgumentParser(description="Preprocess avatar for MuseTalk")
    cli.add_argument("video_path", type=str, help="Path to avatar video/image")
    cli.add_argument("--name", type=str, default=None, help="Avatar name (optional)")
    cli.add_argument("--avatar_dir", type=str, default="./avatars", help="Directory to save avatars")
    opts = cli.parse_args()
    result = AvatarPreprocessor(avatar_dir=opts.avatar_dir).preprocess_avatar(
        opts.video_path, avatar_name=opts.name
    )
    print(f"\nResult: {result}")