""" Avatar Preprocessor - Pre-compute everything related to the avatar video. This saves ~30s per inference when using the same avatar repeatedly. """ import os import cv2 import torch import glob import pickle import numpy as np import hashlib import time import argparse from pathlib import Path from tqdm import tqdm from omegaconf import OmegaConf # MuseTalk imports from musetalk.utils.utils import get_file_type, load_all_model from musetalk.utils.preprocessing import get_landmark_and_bbox, coord_placeholder from musetalk.utils.face_parsing import FaceParsing class AvatarPreprocessor: def __init__(self, avatar_dir: str = "./avatars"): self.avatar_dir = Path(avatar_dir) self.avatar_dir.mkdir(exist_ok=True) # Config (must match server) self.fps = 25 self.version = "v15" self.extra_margin = 10 self.left_cheek_width = 90 self.right_cheek_width = 90 # Models (loaded lazily) self.device = None self.vae = None self.fp = None def _get_file_hash(self, file_path: str) -> str: hash_md5 = hashlib.md5() with open(file_path, "rb") as f: for chunk in iter(lambda: f.read(4096), b""): hash_md5.update(chunk) return hash_md5.hexdigest()[:16] def _load_models(self): if self.vae is not None: return print("Loading models for preprocessing...") self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self.vae, _, _ = load_all_model( unet_model_path="./models/musetalkV15/unet.pth", vae_type="sd-vae", unet_config="./models/musetalk/config.json", device=self.device ) self.vae.vae = self.vae.vae.half().to(self.device) self.fp = FaceParsing( left_cheek_width=self.left_cheek_width, right_cheek_width=self.right_cheek_width ) print("Models loaded!") def preprocess_avatar(self, video_path: str, avatar_name: str = None) -> dict: """ Pre-process an avatar video and save all computed data. Returns dict with paths to saved data. """ self._load_models() video_path = Path(video_path) if not video_path.exists(): raise FileNotFoundError(f"Video not found: {video_path}") # Generate avatar name from hash if not provided if avatar_name is None: avatar_name = f"avatar_{self._get_file_hash(str(video_path))}" avatar_path = self.avatar_dir / avatar_name avatar_path.mkdir(exist_ok=True) print(f"\n{'='*50}") print(f"Pre-processing avatar: {avatar_name}") print(f"{'='*50}") timings = {} total_start = time.time() # 1. Extract frames print("\n[1/4] Extracting frames...") t0 = time.time() frames_dir = avatar_path / "frames" frames_dir.mkdir(exist_ok=True) if get_file_type(str(video_path)) == "video": cmd = f"ffmpeg -y -v fatal -i {video_path} -vf fps={self.fps} -start_number 0 {frames_dir}/%08d.png" os.system(cmd) input_img_list = sorted(glob.glob(str(frames_dir / '*.[jpJP][pnPN]*[gG]'))) else: # Single image import shutil dest = frames_dir / "00000000.png" shutil.copy(video_path, dest) input_img_list = [str(dest)] timings["frame_extraction"] = time.time() - t0 print(f" Extracted {len(input_img_list)} frames in {timings['frame_extraction']:.2f}s") # 2. Compute landmarks and bboxes print("\n[2/4] Computing landmarks and bounding boxes...") t0 = time.time() coord_list, frame_list = get_landmark_and_bbox(input_img_list, 0) timings["landmarks"] = time.time() - t0 print(f" Computed landmarks in {timings['landmarks']:.2f}s") # 3. Compute VAE latents for each frame print("\n[3/4] Computing VAE latents...") t0 = time.time() input_latent_list = [] crop_frames = [] # Store crop frames for blending for i, (bbox, frame) in enumerate(tqdm(zip(coord_list, frame_list), total=len(coord_list))): if isinstance(bbox, (list, tuple)) and list(bbox) == list(coord_placeholder): continue x1, y1, x2, y2 = bbox if self.version == "v15": y2 = y2 + self.extra_margin y2 = min(y2, frame.shape[0]) crop_frame = frame[y1:y2, x1:x2] crop_frame_resized = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4) with torch.no_grad(): latents = self.vae.get_latents_for_unet(crop_frame_resized) # Convert to CPU numpy for storage input_latent_list.append(latents.cpu().numpy()) crop_frames.append({ 'bbox': bbox, 'original_size': (x2-x1, y2-y1) }) timings["vae_encoding"] = time.time() - t0 print(f" Computed {len(input_latent_list)} latents in {timings['vae_encoding']:.2f}s") # 4. Pre-compute face parsing masks (for blending) print("\n[4/4] Pre-computing face parsing data...") t0 = time.time() parsing_data = [] for i, (bbox, frame) in enumerate(tqdm(zip(coord_list, frame_list), total=len(coord_list))): if isinstance(bbox, (list, tuple)) and list(bbox) == list(coord_placeholder): parsing_data.append(None) continue x1, y1, x2, y2 = bbox if self.version == "v15": y2 = y2 + self.extra_margin y2 = min(y2, frame.shape[0]) # Get parsing mask for this frame region crop_frame = frame[y1:y2, x1:x2] try: # Pre-compute the parsing for the crop region parsing = self.fp.get_parsing(crop_frame) parsing_data.append(parsing) except: parsing_data.append(None) timings["face_parsing"] = time.time() - t0 print(f" Computed parsing in {timings['face_parsing']:.2f}s") # Save all data print("\nSaving preprocessed data...") # Save metadata metadata = { 'video_path': str(video_path), 'avatar_name': avatar_name, 'num_frames': len(input_img_list), 'fps': self.fps, 'version': self.version, 'extra_margin': self.extra_margin, 'timings': timings } with open(avatar_path / "metadata.pkl", 'wb') as f: pickle.dump(metadata, f) # Save coord_list (bounding boxes) with open(avatar_path / "coords.pkl", 'wb') as f: pickle.dump(coord_list, f) # Save frame_list (original frames as numpy) with open(avatar_path / "frames.pkl", 'wb') as f: pickle.dump([f if isinstance(f, np.ndarray) else np.array(f) for f in frame_list], f) # Save latents with open(avatar_path / "latents.pkl", 'wb') as f: pickle.dump(input_latent_list, f) # Save crop frame info with open(avatar_path / "crop_info.pkl", 'wb') as f: pickle.dump(crop_frames, f) # Save parsing data with open(avatar_path / "parsing.pkl", 'wb') as f: pickle.dump(parsing_data, f) timings["total"] = time.time() - total_start print(f"\n{'='*50}") print(f"Avatar preprocessed successfully!") print(f"Total time: {timings['total']:.2f}s") print(f"Saved to: {avatar_path}") print(f"{'='*50}") return { 'avatar_name': avatar_name, 'avatar_path': str(avatar_path), 'num_frames': len(input_img_list), 'timings': timings } def list_avatars(self) -> list: """List all preprocessed avatars.""" avatars = [] for p in self.avatar_dir.iterdir(): if p.is_dir() and (p / "metadata.pkl").exists(): with open(p / "metadata.pkl", 'rb') as f: metadata = pickle.load(f) avatars.append(metadata) return avatars if __name__ == "__main__": parser = argparse.ArgumentParser(description="Preprocess avatar for MuseTalk") parser.add_argument("video_path", type=str, help="Path to avatar video/image") parser.add_argument("--name", type=str, default=None, help="Avatar name (optional)") parser.add_argument("--avatar_dir", type=str, default="./avatars", help="Directory to save avatars") args = parser.parse_args() preprocessor = AvatarPreprocessor(avatar_dir=args.avatar_dir) result = preprocessor.preprocess_avatar(args.video_path, avatar_name=args.name) print(f"\nResult: {result}")