Spaces:
Configuration error
Configuration error
| """ | |
| Avatar Preprocessor - Pre-compute everything related to the avatar video. | |
| This saves ~30s per inference when using the same avatar repeatedly. | |
| """ | |
| import os | |
| import cv2 | |
| import torch | |
| import glob | |
| import pickle | |
| import numpy as np | |
| import hashlib | |
| import time | |
| import argparse | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| from omegaconf import OmegaConf | |
| # MuseTalk imports | |
| from musetalk.utils.utils import get_file_type, load_all_model | |
| from musetalk.utils.preprocessing import get_landmark_and_bbox, coord_placeholder | |
| from musetalk.utils.face_parsing import FaceParsing | |
class AvatarPreprocessor:
    """Pre-compute and cache everything derived from an avatar video.

    For each avatar a directory is created under ``avatar_dir`` holding the
    extracted frames plus pickled bounding boxes, original frames, VAE
    latents, crop info, face-parsing results and metadata, so that later
    inference runs can reuse them instead of recomputing.

    The configuration values set in ``__init__`` (fps, version, margins,
    cheek widths) must match the ones used by the inference server.
    """

    # Glob pattern selecting the image files found in an avatar's frames dir.
    _FRAME_GLOB = "*.[jpJP][pnPN]*[gG]"

    def __init__(self, avatar_dir: str = "./avatars"):
        """Create the preprocessor and make sure ``avatar_dir`` exists.

        Args:
            avatar_dir: Root directory under which one sub-directory per
                avatar is stored. Missing parent directories are created.
        """
        self.avatar_dir = Path(avatar_dir)
        # parents=True so a nested, not-yet-existing root works as well.
        self.avatar_dir.mkdir(parents=True, exist_ok=True)

        # Config (must match server)
        self.fps = 25
        self.version = "v15"
        self.extra_margin = 10
        self.left_cheek_width = 90
        self.right_cheek_width = 90

        # Models (loaded lazily by _load_models)
        self.device = None
        self.vae = None
        self.fp = None

    def _get_file_hash(self, file_path: str) -> str:
        """Return the first 16 hex characters of the file's MD5 digest.

        The digest is only used to derive a default avatar name, not for
        any security purpose.
        """
        digest = hashlib.md5()
        with open(file_path, "rb") as stream:
            # Read in chunks so large videos are never held in memory at once.
            for chunk in iter(lambda: stream.read(65536), b""):
                digest.update(chunk)
        return digest.hexdigest()[:16]

    def _load_models(self):
        """Load the VAE and the face parser on first use; later calls are no-ops.

        The models are built in local variables and only published on
        ``self`` once everything succeeded, so a failure halfway through
        does not leave the instance half-initialised (``vae`` set but
        ``fp`` still ``None``).
        """
        if self.vae is not None and self.fp is not None:
            return

        print("Loading models for preprocessing...")
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # load_all_model returns three objects; only the first is needed here.
        vae, _, _ = load_all_model(
            unet_model_path="./models/musetalkV15/unet.pth",
            vae_type="sd-vae",
            unet_config="./models/musetalk/config.json",
            device=device,
        )
        vae.vae = vae.vae.half().to(device)
        face_parser = FaceParsing(
            left_cheek_width=self.left_cheek_width,
            right_cheek_width=self.right_cheek_width,
        )

        self.device = device
        self.vae = vae
        self.fp = face_parser
        print("Models loaded!")

    @staticmethod
    def _is_placeholder(bbox) -> bool:
        """Return True when ``bbox`` equals the ``coord_placeholder`` marker."""
        return isinstance(bbox, (list, tuple)) and list(bbox) == list(coord_placeholder)

    def _crop_face(self, bbox, frame):
        """Crop ``frame`` to ``bbox`` and return ``(crop, (width, height))``.

        For version "v15" the lower edge is extended by ``extra_margin``
        pixels, clipped to the frame height.
        """
        x1, y1, x2, y2 = bbox
        if self.version == "v15":
            y2 = min(y2 + self.extra_margin, frame.shape[0])
        return frame[y1:y2, x1:x2], (x2 - x1, y2 - y1)

    def _extract_frames(self, video_path: Path, frames_dir: Path) -> list:
        """Fill ``frames_dir`` from ``video_path`` and return the sorted frame paths.

        Videos are split into PNG frames with ffmpeg at ``self.fps``; any
        other input is treated as a single image and copied as frame 0.

        Raises:
            RuntimeError: If ffmpeg exits with a non-zero status or no frame
                was produced.
            FileNotFoundError: If the ffmpeg executable cannot be found.
        """
        import shutil
        import subprocess

        if get_file_type(str(video_path)) != "video":
            # Single image
            dest = frames_dir / "00000000.png"
            shutil.copy(video_path, dest)
            return [str(dest)]

        pattern = str(frames_dir / self._FRAME_GLOB)
        # Drop frames left by an earlier run with the same avatar name;
        # otherwise a shorter video would be mixed with stale frames.
        for stale in glob.glob(pattern):
            os.remove(stale)

        # Argument list instead of a shell string: paths containing spaces
        # or shell metacharacters are passed to ffmpeg verbatim.
        command = [
            "ffmpeg", "-y", "-v", "fatal",
            "-i", str(video_path),
            "-vf", f"fps={self.fps}",
            "-start_number", "0",
            str(frames_dir / "%08d.png"),
        ]
        completed = subprocess.run(command, check=False)
        if completed.returncode != 0:
            raise RuntimeError(
                f"ffmpeg failed with exit code {completed.returncode} for {video_path}"
            )

        frame_paths = sorted(glob.glob(pattern))
        if not frame_paths:
            raise RuntimeError(f"No frames could be extracted from {video_path}")
        return frame_paths

    @staticmethod
    def _dump(target: Path, payload) -> None:
        """Pickle ``payload`` into the file ``target``."""
        with open(target, "wb") as stream:
            pickle.dump(payload, stream)

    def preprocess_avatar(self, video_path: str, avatar_name: str = None) -> dict:
        """Pre-process an avatar video (or image) and save all computed data.

        Steps: extract frames, compute landmarks and bounding boxes, encode
        each face crop with the VAE, and pre-compute face parsing per frame.
        Everything is pickled into ``<avatar_dir>/<avatar_name>/``.

        Args:
            video_path: Path to the avatar video or image.
            avatar_name: Directory name for the avatar. When ``None`` a name
                is derived from the hash of the input file.

        Returns:
            Dict with ``avatar_name``, ``avatar_path``, ``num_frames`` and
            ``timings`` (seconds per step, plus ``total``).

        Raises:
            FileNotFoundError: If ``video_path`` does not exist.
            RuntimeError: If frame extraction fails.
        """
        video_path = Path(video_path)
        # Validate the input before paying for the model load.
        if not video_path.exists():
            raise FileNotFoundError(f"Video not found: {video_path}")

        self._load_models()

        # Generate avatar name from hash if not provided
        if avatar_name is None:
            avatar_name = f"avatar_{self._get_file_hash(str(video_path))}"

        avatar_path = self.avatar_dir / avatar_name
        avatar_path.mkdir(parents=True, exist_ok=True)

        print(f"\n{'='*50}")
        print(f"Pre-processing avatar: {avatar_name}")
        print(f"{'='*50}")

        timings = {}
        total_start = time.time()

        # 1. Extract frames
        print("\n[1/4] Extracting frames...")
        t0 = time.time()
        frames_dir = avatar_path / "frames"
        frames_dir.mkdir(exist_ok=True)
        input_img_list = self._extract_frames(video_path, frames_dir)
        timings["frame_extraction"] = time.time() - t0
        print(f" Extracted {len(input_img_list)} frames in {timings['frame_extraction']:.2f}s")

        # 2. Compute landmarks and bboxes
        print("\n[2/4] Computing landmarks and bounding boxes...")
        t0 = time.time()
        coord_list, frame_list = get_landmark_and_bbox(input_img_list, 0)
        timings["landmarks"] = time.time() - t0
        print(f" Computed landmarks in {timings['landmarks']:.2f}s")

        # 3. Compute VAE latents for each frame
        print("\n[3/4] Computing VAE latents...")
        t0 = time.time()
        input_latent_list = []
        crop_frames = []  # Crop geometry per encoded frame, used for blending
        for bbox, frame in tqdm(zip(coord_list, frame_list), total=len(coord_list)):
            if self._is_placeholder(bbox):
                # Placeholder frames get no latent and no crop entry.
                continue
            crop_frame, crop_size = self._crop_face(bbox, frame)
            crop_frame_resized = cv2.resize(
                crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4
            )
            with torch.no_grad():
                latents = self.vae.get_latents_for_unet(crop_frame_resized)
            # Convert to CPU numpy for storage
            input_latent_list.append(latents.cpu().numpy())
            crop_frames.append({
                'bbox': bbox,
                'original_size': crop_size,
            })
        timings["vae_encoding"] = time.time() - t0
        print(f" Computed {len(input_latent_list)} latents in {timings['vae_encoding']:.2f}s")

        # 4. Pre-compute face parsing masks (for blending)
        print("\n[4/4] Pre-computing face parsing data...")
        t0 = time.time()
        parsing_data = []
        for index, (bbox, frame) in enumerate(
            tqdm(zip(coord_list, frame_list), total=len(coord_list))
        ):
            if self._is_placeholder(bbox):
                # Keep one entry per frame so indices line up with coord_list.
                parsing_data.append(None)
                continue
            crop_frame, _ = self._crop_face(bbox, frame)
            try:
                # Pre-compute the parsing for the crop region
                parsing = self.fp.get_parsing(crop_frame)
            except Exception as exc:
                # Best effort: a frame that cannot be parsed is stored as
                # None, but the failure is reported instead of being hidden.
                # Ctrl-C / SystemExit are no longer swallowed.
                tqdm.write(f" Warning: face parsing failed for frame {index}: {exc}")
                parsing = None
            parsing_data.append(parsing)
        timings["face_parsing"] = time.time() - t0
        print(f" Computed parsing in {timings['face_parsing']:.2f}s")

        # Save all data
        print("\nSaving preprocessed data...")
        # Bounding boxes
        self._dump(avatar_path / "coords.pkl", coord_list)
        # Original frames as numpy arrays
        self._dump(
            avatar_path / "frames.pkl",
            [
                image if isinstance(image, np.ndarray) else np.array(image)
                for image in frame_list
            ],
        )
        # Latents
        self._dump(avatar_path / "latents.pkl", input_latent_list)
        # Crop frame info
        self._dump(avatar_path / "crop_info.pkl", crop_frames)
        # Parsing data
        self._dump(avatar_path / "parsing.pkl", parsing_data)

        timings["total"] = time.time() - total_start

        # Metadata is written last: list_avatars() uses metadata.pkl as the
        # marker of a preprocessed avatar, and the saved timings now include
        # the "total" entry.
        metadata = {
            'video_path': str(video_path),
            'avatar_name': avatar_name,
            'num_frames': len(input_img_list),
            'fps': self.fps,
            'version': self.version,
            'extra_margin': self.extra_margin,
            'timings': timings,
        }
        self._dump(avatar_path / "metadata.pkl", metadata)

        print(f"\n{'='*50}")
        print("Avatar preprocessed successfully!")
        print(f"Total time: {timings['total']:.2f}s")
        print(f"Saved to: {avatar_path}")
        print(f"{'='*50}")

        return {
            'avatar_name': avatar_name,
            'avatar_path': str(avatar_path),
            'num_frames': len(input_img_list),
            'timings': timings,
        }

    def list_avatars(self) -> list:
        """List the metadata of all preprocessed avatars, ordered by directory path.

        Only sub-directories of ``avatar_dir`` that contain a
        ``metadata.pkl`` file are reported.
        """
        avatars = []
        # sorted() gives a stable order; iterdir() order is arbitrary.
        for entry in sorted(self.avatar_dir.iterdir()):
            metadata_file = entry / "metadata.pkl"
            if not (entry.is_dir() and metadata_file.exists()):
                continue
            # NOTE: pickle.load executes arbitrary code from the file; only
            # point avatar_dir at directories written by this tool.
            with open(metadata_file, 'rb') as stream:
                avatars.append(pickle.load(stream))
        return avatars
if __name__ == "__main__":
    # Command-line entry point: preprocess one avatar video/image and print
    # the summary dict returned by AvatarPreprocessor.preprocess_avatar.
    cli = argparse.ArgumentParser(description="Preprocess avatar for MuseTalk")
    cli.add_argument("video_path", type=str, help="Path to avatar video/image")

    # Optional string flags as (flag, default, help text).
    optional_flags = (
        ("--name", None, "Avatar name (optional)"),
        ("--avatar_dir", "./avatars", "Directory to save avatars"),
    )
    for flag, default_value, help_text in optional_flags:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    options = cli.parse_args()

    summary = AvatarPreprocessor(avatar_dir=options.avatar_dir).preprocess_avatar(
        options.video_path,
        avatar_name=options.name,
    )
    print(f"\nResult: {summary}")