from __future__ import annotations import argparse import hashlib import importlib.util import io import json import os import platform import shutil import subprocess import sys import tempfile import time import uuid import zipfile from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import cv2 import numpy as np import torch import zstandard as zstd from PIL import Image from pydub import AudioSegment from scipy.io import loadmat, savemat # ============================================================ # GENERAL HELPERS # ============================================================ REAL_STYLE_ALIASES = {"real", "realistic", "photo", "photoreal", "liveaction"} def ensure_dir(path: Path) -> None: path.mkdir(parents=True, exist_ok=True) def utc_now_iso() -> str: from datetime import datetime, timezone return datetime.now(timezone.utc).isoformat(timespec="seconds") def sha256_bytes(data: bytes) -> str: return hashlib.sha256(data).hexdigest() def sha256_file(path: Path, chunk_size: int = 1024 * 1024) -> str: h = hashlib.sha256() with path.open("rb") as f: for chunk in iter(lambda: f.read(chunk_size), b""): h.update(chunk) return h.hexdigest() def tensor_to_bytes(obj: Any) -> bytes: if isinstance(obj, (bytes, bytearray)): return bytes(obj) if torch.is_tensor(obj): return obj.detach().cpu().contiguous().numpy().tobytes() raise TypeError(f"Expected bytes or tensor, got {type(obj)!r}") def bytes_to_tensor(data: bytes) -> torch.Tensor: try: return torch.frombuffer(memoryview(data), dtype=torch.uint8).clone() except Exception: return torch.tensor(list(data), dtype=torch.uint8) def decode_png_or_zstd_image(blob: bytes) -> Image.Image: """Decode a preview blob that may be a raw PNG or zstd-compressed PNG bytes.""" try: raw = zstd.ZstdDecompressor().decompress(blob) except Exception: raw = blob return Image.open(io.BytesIO(raw)).convert("RGB") def pil_to_numpy_rgb(img: Image.Image) -> np.ndarray: return 
np.asarray(img.convert("RGB"), dtype=np.uint8) def normalize_style_name(style: Optional[str]) -> str: return (style or "").strip().lower() def normalize_gender_name(gender: Optional[str]) -> str: return (gender or "").strip().lower() def safe_load_bundle(path_or_bundle: Any) -> Optional[Dict[str, Any]]: if path_or_bundle is None: return None if isinstance(path_or_bundle, dict): return path_or_bundle if isinstance(path_or_bundle, (str, os.PathLike)): p = Path(path_or_bundle) ext = p.suffix.lower() if ext in {".pt", ".pth"}: return torch.load(str(p), map_location="cpu", weights_only=False) if ext == ".mat": return loadmat(str(p)) raise TypeError("Conditioning input must be None, a dict, or a .pt/.pth/.mat path") def _resolve_checkpoint(self): candidates = [ "SadTalker_V0.0.2_512.safetensors", "SadTalker_V0.0.2_256.safetensors", "SadTalker_V0.0.2_512.pth", "SadTalker_V0.0.2_256.pth", ] for name in candidates: p = Path(self.checkpoint_path) / name if p.exists(): return str(p) raise FileNotFoundError( f"No SadTalker checkpoint found in {self.checkpoint_path}" ) def composite_alpha_to_rgb(image_path: Path, bg_rgb=(255, 255, 255)) -> Path: """If the input image has alpha, composite it to RGB and return a new PNG path.""" with Image.open(image_path) as im: im = im.convert("RGBA") bg = Image.new("RGBA", im.size, (*bg_rgb, 255)) out = Image.alpha_composite(bg, im).convert("RGB") out_path = image_path.with_name(f"{image_path.stem}_rgb.png") out.save(out_path) return out_path def prepare_image_for_sadtalker(image_path: Path, remove_background_result: Optional[Path] = None) -> Path: if remove_background_result is None: with Image.open(image_path) as im: if im.mode in {"RGBA", "LA"} or ("transparency" in im.info): return composite_alpha_to_rgb(image_path) return image_path return composite_alpha_to_rgb(remove_background_result) # ============================================================ # ARCHIVE EXTRACTION # ============================================================ 
@dataclass
class MountedArchive:
    """Describes a zip payload that has been extracted into a directory."""

    name: str  # label recorded in the marker file
    zip_sha256: str  # digest identifying the extracted payload
    target_dir: Path  # directory the archive was extracted into
    marker_path: Path  # JSON marker written after a successful extraction


def extract_zip_bytes_to_dir(zip_bytes: bytes, dest_dir: Path) -> None:
    """Extract an in-memory zip archive into *dest_dir* (created if missing)."""
    ensure_dir(dest_dir)
    with zipfile.ZipFile(io.BytesIO(zip_bytes), "r") as zf:
        zf.extractall(dest_dir)


def mount_zip_payload(zip_bytes: bytes, zip_sha256: str, target_dir: Path, marker_name: str) -> MountedArchive:
    """Extract *zip_bytes* into *target_dir* unless it is already mounted.

    A JSON marker named *marker_name* inside *target_dir* records the digest of
    the last successful extraction. When it matches *zip_sha256* (and says
    ``mounted: true``), the directory is reused as-is. Otherwise every other
    entry in *target_dir* is deleted, the archive is extracted and the marker
    is rewritten.

    Args:
        zip_bytes: raw zip archive contents.
        zip_sha256: caller-supplied digest of the payload (not re-verified here).
        target_dir: extraction directory.
        marker_name: file name of the marker inside *target_dir*.

    Returns:
        A MountedArchive describing the mounted directory.
    """
    ensure_dir(target_dir)
    marker_path = target_dir / marker_name

    if marker_path.exists():
        try:
            existing = json.loads(marker_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            existing = None  # unreadable/corrupt marker: treat as not mounted
        if (
            isinstance(existing, dict)
            and existing.get("zip_sha256") == zip_sha256
            and existing.get("mounted") is True
        ):
            return MountedArchive(
                name=existing.get("name", marker_name),
                zip_sha256=zip_sha256,
                target_dir=target_dir,
                marker_path=marker_path,
            )

    # Clear any stale contents before extracting.
    for child in list(target_dir.iterdir()):
        if child == marker_path:
            continue
        if child.is_dir() and not child.is_symlink():
            shutil.rmtree(child, ignore_errors=True)
        else:
            # Files and symlinks (rmtree refuses symlinks) are unlinked.
            try:
                child.unlink()
            except OSError:
                pass

    extract_zip_bytes_to_dir(zip_bytes, target_dir)

    marker_path.write_text(
        json.dumps(
            {
                "mounted": True,
                "zip_sha256": zip_sha256,
                "name": marker_name,
                "created_at": utc_now_iso(),
            },
            indent=2,
        ),
        encoding="utf-8",
    )

    return MountedArchive(
        name=marker_name,
        zip_sha256=zip_sha256,
        target_dir=target_dir,
        marker_path=marker_path,
    )


# ============================================================
# AVATAR BANK RUNTIME
# ============================================================


class AvatarBankRuntime:
    """In-memory view over an avatar-bank payload.

    The payload is a dict holding three mappings keyed by avatar id. From
    ``index`` (metadata) this class reads ``gender`` and ``style``. From
    ``embeddings`` (conditioning data) it reads ``motion_3dmm``,
    ``full_3dmm``, ``crop_info`` and ``crop_preview``. ``previews`` holds
    image blobs decoded with ``decode_png_or_zstd_image``. A missing or falsy
    mapping becomes an empty dict.
    """

    def __init__(
        self,
        payload: Dict[str, Any],
        defaults: Optional[Dict[str, Any]] = None,
    ):
        self.index: Dict[str, Dict[str, Any]] = payload.get("index", {}) or {}
        self.embeddings: Dict[str, Dict[str, Any]] = payload.get("embeddings", {}) or {}
        self.previews: Dict[str, Any] = payload.get("previews", {}) or {}
        # Optional settings; keys read here: "default_avatar", "real_style_aliases".
        self.defaults = defaults or {}

    @classmethod
    def load(
        cls,
        path: Path,
        defaults: Optional[Dict[str, Any]] = None,
    ) -> "AvatarBankRuntime":
        """Load a bank written by ``torch.save``.

        NOTE: ``weights_only=False`` unpickles arbitrary objects; only load
        trusted files.

        Raises:
            ValueError: if the file does not contain a dict.
        """
        payload = torch.load(str(path), map_location="cpu", weights_only=False)
        if not isinstance(payload, dict):
            raise ValueError(f"Avatar bank file did not contain a dictionary: {path}")
        return cls(payload, defaults=defaults)

    def save(self, path: str | os.PathLike) -> None:
        """Write index, embeddings and previews to *path* (``defaults`` is not saved)."""
        torch.save(
            {
                "index": self.index,
                "embeddings": self.embeddings,
                "previews": self.previews,
            },
            str(path),
        )

    # --------------------------------------------------------
    # BASIC ACCESS
    # --------------------------------------------------------

    def __contains__(self, avatar_id: str) -> bool:
        return avatar_id in self.index

    def exists(self, avatar_id: str) -> bool:
        """Return True if *avatar_id* is present in the index."""
        return avatar_id in self.index

    def available_ids(self) -> List[str]:
        """Return indexed avatar ids in index order."""
        return list(self.index.keys())

    def list_avatars(self) -> List[str]:
        """Alias of ``available_ids``."""
        return self.available_ids()

    def get_metadata(self, avatar_id: str) -> Dict[str, Any]:
        """Return a shallow copy of the index metadata for *avatar_id*.

        Raises:
            KeyError: if the id is not indexed.
        """
        if avatar_id not in self.index:
            raise KeyError(f"Avatar not found: {avatar_id}")
        return dict(self.index[avatar_id])

    def get_avatar(self, avatar_id: str) -> Dict[str, Any]:
        """Alias of ``build_avatar_condition``."""
        return self.build_avatar_condition(avatar_id)

    def get_embedding_bundle(self, avatar_id: str) -> Dict[str, Any]:
        """Return the raw (uncopied) embedding bundle for *avatar_id*.

        Raises:
            KeyError: if the id has no embeddings entry.
        """
        if avatar_id not in self.embeddings:
            raise KeyError(f"Avatar not found: {avatar_id}")
        return self.embeddings[avatar_id]

    # --------------------------------------------------------
    # FUZZY SEARCH
    # --------------------------------------------------------

    def _fuzzy_match_single(
        self,
        query: str,
        choices: set,
        cutoff: float = 0.6,
    ):
        """Return the element of *choices* closest to *query* (case-insensitive), or None."""
        from difflib import get_close_matches

        if not query:
            return None
        choice_map = {c.lower(): c for c in choices}
        matches = get_close_matches(query.lower(), list(choice_map.keys()), n=1, cutoff=cutoff)
        return choice_map[matches[0]] if matches else None

    def fuzzy_search_id(
        self,
        query_id: str,
        n: int = 5,
        cutoff: float = 0.5,
    ) -> List[str]:
        """Return up to *n* indexed ids closest to *query_id* (case-insensitive difflib match)."""
        from difflib import get_close_matches

        id_map = {aid.lower(): aid for aid in self.index.keys()}
        matches = get_close_matches(query_id.lower(), list(id_map.keys()), n=n, cutoff=cutoff)
        return [id_map[m] for m in matches]

    # --------------------------------------------------------
    # QUERY
    # --------------------------------------------------------

    def query(
        self,
        gender=None,
        style=None,
        fuzzy=True,
        cutoff=0.6,
    ) -> List[str]:
        """Return ids whose metadata matches *gender* and/or *style*.

        With ``fuzzy=True`` each filter is first snapped to the closest value
        present in the index (case-insensitive, at *cutoff*). If nothing is
        close enough, the raw value is kept. The final comparison against
        metadata is exact, and falsy filters are ignored.
        """
        target_gender = gender
        target_style = style

        if fuzzy:
            if gender:
                known_genders = {meta["gender"] for meta in self.index.values() if meta.get("gender")}
                target_gender = self._fuzzy_match_single(gender, known_genders, cutoff) or gender
            if style:
                known_styles = {meta["style"] for meta in self.index.values() if meta.get("style")}
                target_style = self._fuzzy_match_single(style, known_styles, cutoff) or style

        results = []
        for aid, meta in self.index.items():
            if target_gender and meta.get("gender") != target_gender:
                continue
            if target_style and meta.get("style") != target_style:
                continue
            results.append(aid)
        return results

    # --------------------------------------------------------
    # PREVIEWS
    # --------------------------------------------------------

    def get_preview(self, avatar_id: str):
        """Return the decoded RGB PIL preview for *avatar_id*.

        Raises:
            KeyError: if there is no preview for the id.
        """
        if avatar_id not in self.previews:
            raise KeyError(f"Avatar not found: {avatar_id}")
        return decode_png_or_zstd_image(self.previews[avatar_id])

    def get_preview_numpy(self, avatar_id: str) -> Optional[np.ndarray]:
        """Return the preview as a uint8 RGB array, or None if missing/undecodable."""
        return self._preview_to_numpy(avatar_id)

    def _preview_to_numpy(self, avatar_id: str) -> Optional[np.ndarray]:
        blob = self.previews.get(avatar_id)
        if blob is None:
            return None
        try:
            return pil_to_numpy_rgb(decode_png_or_zstd_image(blob))
        except Exception:
            # Best-effort: an undecodable preview is reported as "no preview".
            return None

    # --------------------------------------------------------
    # MUTATION
    # --------------------------------------------------------

    def delete_avatar(self, avatar_id: str) -> None:
        """Remove *avatar_id* from index, embeddings and previews (missing ids are ignored)."""
        self.index.pop(avatar_id, None)
        self.embeddings.pop(avatar_id, None)
        self.previews.pop(avatar_id, None)

    # --------------------------------------------------------
    # DEFAULT SELECTION / CONDITIONING
    # --------------------------------------------------------

    def _style_is_real(self, style: Optional[str]) -> bool:
        """Return True if *style* is one of the configured real-style aliases.

        Uses ``defaults["real_style_aliases"]`` when provided, otherwise the
        module-level REAL_STYLE_ALIASES. Both sides are normalized.
        """
        aliases = self.defaults.get("real_style_aliases") or REAL_STYLE_ALIASES
        if isinstance(aliases, str):
            aliases = [aliases]
        return normalize_style_name(style) in {normalize_style_name(a) for a in aliases}

    def resolve_default_avatar_id(self) -> str:
        """Pick a default avatar id from the index.

        Preference order:
        1. the configured ``default_avatar`` (if indexed);
        2. the first real-style male;
        3. the first real-style avatar;
        4. the first male;
        5. the first indexed avatar with a non-None embedding;
        6. the first indexed avatar.

        Raises:
            RuntimeError: if the index is empty.
        """
        if not self.index:
            raise RuntimeError("Avatar bank is empty.")

        configured = self.defaults.get("default_avatar")
        if configured and configured in self.index:
            return configured

        entries = [(aid, meta or {}) for aid, meta in self.index.items()]

        def is_male(meta: Dict[str, Any]) -> bool:
            return normalize_gender_name(meta.get("gender")) == "male"

        # Prefer first real male.
        for avatar_id, meta in entries:
            if is_male(meta) and self._style_is_real(meta.get("style")):
                return avatar_id
        # Then any real-style avatar.
        for avatar_id, meta in entries:
            if self._style_is_real(meta.get("style")):
                return avatar_id
        # Then any male avatar.
        for avatar_id, meta in entries:
            if is_male(meta):
                return avatar_id
        # Then any indexed avatar that actually has embeddings (never an
        # id missing from the index).
        for avatar_id, _ in entries:
            if self.embeddings.get(avatar_id) is not None:
                return avatar_id
        # Finally first available entry.
        return entries[0][0]

    @staticmethod
    def _detach_cpu(value: Any) -> Any:
        """Return a detached CPU copy of a tensor; other values pass through."""
        return value.detach().cpu() if torch.is_tensor(value) else value

    def build_avatar_condition(self, avatar_id: str) -> Dict[str, Any]:
        """Build the conditioning dict used by ``SadTalkerRunner.generate``.

        ``coeff_3dmm`` is ``motion_3dmm`` if present, else ``full_3dmm``.
        ``crop_preview`` falls back to the decoded bank preview. Tensors are
        detached and moved to CPU.

        Raises:
            KeyError: if the id has no embeddings entry.
            ValueError: if neither motion_3dmm nor full_3dmm is present.
        """
        if avatar_id not in self.embeddings:
            raise KeyError(f"Avatar not found: {avatar_id}")

        meta = self.index.get(avatar_id) or {}
        emb = self.embeddings[avatar_id] or {}

        coeff = emb.get("motion_3dmm")
        if coeff is None:
            coeff = emb.get("full_3dmm")
        if coeff is None:
            raise ValueError(f"Avatar '{avatar_id}' is missing motion_3dmm/full_3dmm")

        crop_preview = emb.get("crop_preview")
        if crop_preview is None:
            crop_preview = self._preview_to_numpy(avatar_id)
        elif torch.is_tensor(crop_preview):
            crop_preview = crop_preview.detach().cpu()
        elif not isinstance(crop_preview, np.ndarray):
            crop_preview = np.asarray(crop_preview)

        return {
            "avatar_id": avatar_id,
            "gender": meta.get("gender"),
            "style": meta.get("style"),
            "coeff_3dmm": self._detach_cpu(coeff),
            "motion_3dmm": self._detach_cpu(emb.get("motion_3dmm")),
            "full_3dmm": self._detach_cpu(emb.get("full_3dmm")),
            "crop_info": emb.get("crop_info"),
            "crop_preview": crop_preview,
        }


# ============================================================
# BRIA RMBG BACKGROUND REMOVER (BEST-EFFORT)
# ============================================================


class BriaBackgroundRemover:
    """
    Best-effort loader for the packed briaaiRMBG-2.0 directory.

    It searches for a likely inference script and tries callable or CLI-based
    execution patterns. If the local folder layout differs, the search list
    below is the only part that usually needs adjustment.

    NOTE: the discovered script is imported/executed as trusted code.
    """

    def __init__(self, root: Path):
        self.root = root
        self.entrypoint = self._discover_entrypoint()

    def _discover_entrypoint(self) -> Optional[Path]:
        """Return the first likely inference script under ``root``, or None."""
        if not self.root.exists():
            return None

        preferred = [
            "inference.py",
            "predict.py",
            "app.py",
            "main.py",
            "run.py",
        ]
        for name in preferred:
            hit = next(self.root.rglob(name), None)
            if hit is not None:
                return hit

        # Fall back to any Python file with a likely folder name.
        for p in self.root.rglob("*.py"):
            lower = str(p).lower()
            if "bria" in lower or "rmbg" in lower or "background" in lower:
                return p
        return None

    def _import_module_from_path(self, py_file: Path):
        """Import *py_file* under a unique, path-derived module name."""
        module_name = f"packed_bria_{sha256_bytes(str(py_file).encode('utf-8'))[:12]}"
        spec = importlib.util.spec_from_file_location(module_name, str(py_file))
        if spec is None or spec.loader is None:
            raise RuntimeError(f"Could not import module from {py_file}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

    @staticmethod
    def _persist_result(result: Any, output_path: Path) -> bool:
        """Store a callable's *result* at *output_path*; return True on success.

        Handles a path to an existing file, a PIL image, an HxWx3/4 tensor or
        array, or ``None`` when the callable wrote *output_path* itself.
        """
        if isinstance(result, (str, os.PathLike)):
            result_path = Path(result)
            if not result_path.exists():
                return False
            # A callable that wrote straight to output_path and returned it must
            # count as success (copy2 would raise SameFileError).
            if result_path.resolve() != output_path.resolve():
                shutil.copy2(result_path, output_path)
            return True
        if isinstance(result, Image.Image):
            result.save(output_path)
            return True
        if torch.is_tensor(result):
            result = result.detach().cpu().numpy()
        if isinstance(result, np.ndarray):
            if result.ndim == 3 and result.shape[-1] in (3, 4):
                Image.fromarray(result.astype(np.uint8)).save(output_path)
                return True
            return False
        return result is None and output_path.exists()

    def _call_module_callable(self, module, image_path: Path, output_path: Path) -> bool:
        """Try likely entry functions of *module* with several argument shapes."""
        candidates = [
            "remove_background",
            "predict_image",
            "predict",
            "run",
            "inference",
            "main",
        ]
        callables = [fn for fn in (getattr(module, name, None) for name in candidates) if callable(fn)]

        def load_image_copy() -> Image.Image:
            # In-memory copy so no file handle outlives the attempt.
            with Image.open(image_path) as im:
                return im.copy()

        arg_factories = [
            lambda: (str(image_path), str(output_path)),
            lambda: (str(image_path),),
            lambda: (load_image_copy(),),
            lambda: (),
        ]

        for fn in callables:
            for make_args in arg_factories:
                try:
                    if self._persist_result(fn(*make_args()), output_path):
                        return True
                except (Exception, SystemExit):
                    # SystemExit too: script-style main() functions often call
                    # argparse/sys.exit, which must not terminate the host.
                    continue
        return False

    def _call_cli_with_patterns(self, image_path: Path, output_path: Path) -> bool:
        """Run the entrypoint as a script with common CLI argument layouts."""
        if self.entrypoint is None:
            return False

        script = str(self.entrypoint)
        src, dst = str(image_path), str(output_path)
        arg_patterns = [
            [src, dst],
            ["--input", src, "--output", dst],
            ["--image", src, "--output", dst],
            ["--input_path", src, "--output_path", dst],
            ["-i", src, "-o", dst],
        ]

        for extra_args in arg_patterns:
            try:
                proc = subprocess.run(
                    [sys.executable, script, *extra_args],
                    cwd=str(self.root),
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                    check=False,
                )
            except (OSError, subprocess.SubprocessError):
                continue
            if proc.returncode == 0 and output_path.exists():
                return True
        return False

    def remove_background(self, image_path: Path, output_dir: Path) -> Path:
        """Write ``<stem>_rmbg.png`` for *image_path* into *output_dir* and return its path.

        Raises:
            RuntimeError: if no entrypoint was found or every call pattern failed.
        """
        if self.entrypoint is None:
            raise RuntimeError(
                f"No usable background-removal entrypoint found under {self.root}."
            )

        ensure_dir(output_dir)
        output_path = output_dir / f"{image_path.stem}_rmbg.png"
        # Remove output from an earlier run so existence checks only succeed
        # when this run actually produced the file.
        try:
            output_path.unlink()
        except FileNotFoundError:
            pass

        try:
            module = self._import_module_from_path(self.entrypoint)
            if self._call_module_callable(module, image_path, output_path):
                return output_path
        except (Exception, SystemExit):
            # Import-time argparse/sys.exit in the script must not kill the host.
            pass

        if self._call_cli_with_patterns(image_path, output_path):
            return output_path

        raise RuntimeError(
            f"Could not execute background removal with entrypoint {self.entrypoint}. "
            f"You may need to adjust the call patterns in BriaBackgroundRemover."
        )


# ============================================================
# SADTALKER CORE RUNTIME
# ============================================================


class SadTalkerRunner:
    """Thin wrapper around the packed SadTalker modules.

    Models are (re)loaded by ``_ensure_models`` at the start of each
    ``generate``/``extract_embeddings`` call. ``generate`` releases them
    again when it finishes.
    """

    def __init__(self, checkpoint_path: str, config_path: str, device: str = "cpu"):
        self.checkpoint_path = checkpoint_path
        self.config_path = config_path
        self.device = device
        self._mods_loaded = False
        self._load_modules()

    def _load_modules(self):
        """Import the SadTalker entry points once and cache them on the instance."""
        if self._mods_loaded:
            return

        from SadTalker.src.facerender.pirender_animate import AnimateFromCoeff_PIRender
        from SadTalker.src.utils.preprocess import CropAndExtract
        from SadTalker.src.test_audio2coeff import Audio2Coeff
        from SadTalker.src.facerender.animate import AnimateFromCoeff
        from SadTalker.src.generate_batch import get_data
        from SadTalker.src.generate_facerender_batch import get_facerender_data
        from SadTalker.src.utils.init_path import init_path

        self.AnimateFromCoeff_PIRender = AnimateFromCoeff_PIRender
        self.CropAndExtract = CropAndExtract
        self.Audio2Coeff = Audio2Coeff
        self.AnimateFromCoeff = AnimateFromCoeff
        self.get_data = get_data
        self.get_facerender_data = get_facerender_data
        self.init_path = init_path
        self._mods_loaded = True

    @staticmethod
    def _mp3_to_wav(mp3_filename: str, wav_filename: str, frame_rate: int):
        """Decode *mp3_filename* with pydub and export it as WAV at *frame_rate* Hz."""
        mp3_file = AudioSegment.from_file(file=mp3_filename)
        mp3_file.set_frame_rate(frame_rate).export(wav_filename, format="wav")

    def _to_numpy(self, x):
        """Convert tensors/array-likes to numpy; ``None`` and ndarrays pass through."""
        if x is None:
            return None
        if isinstance(x, np.ndarray):
            return x
        if torch.is_tensor(x):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    @staticmethod
    def _first_present(bundle, *keys):
        """Return the value of the first key in *keys* whose value is not None."""
        for key in keys:
            value = bundle.get(key)
            if value is not None:
                return value
        return None

    def _save_png_from_bundle(self, bundle, out_path):
        """Write the first usable image-like field of *bundle* to *out_path* as RGB.

        Fields are tried in the order ``crop_preview``, ``aligned_face``,
        ``image``, ``png``. Encoded bytes are decoded with PIL. Arrays must be
        2-D or HxWx1/3/4. Non-uint8 data is clipped to [0, 255] and cast;
        float images in [0, 1] are NOT rescaled.

        Raises:
            ValueError: if no field is usable.
        """
        for key in ("crop_preview", "aligned_face", "image", "png"):
            value = bundle.get(key)
            if value is None:
                continue
            if isinstance(value, (bytes, bytearray)):
                decode_png_or_zstd_image(value).save(out_path)
                return out_path

            arr = self._to_numpy(value)
            if arr.ndim == 3 and arr.shape[-1] == 1:
                arr = arr[..., 0]  # single channel -> grayscale
            if not (arr.ndim == 2 or (arr.ndim == 3 and arr.shape[-1] in (3, 4))):
                continue
            if arr.dtype != np.uint8:
                arr = np.clip(arr, 0, 255).astype(np.uint8)
            # fromarray infers L/RGB/RGBA from the shape; convert flattens to RGB.
            Image.fromarray(np.ascontiguousarray(arr)).convert("RGB").save(out_path)
            return out_path

        raise ValueError(
            "Avatar conditioning bundle needs at least one image-like field such as crop_preview or aligned_face."
        )

    def _save_mat_from_avatar_bundle(self, bundle, out_path):
        """Write a SadTalker-style ``.mat`` with ``coeff_3dmm`` (+ ``full_3dmm``) from an avatar bundle.

        Raises:
            ValueError: if no coefficient field is present.
        """
        coeff_3dmm = self._first_present(bundle, "coeff_3dmm", "motion_3dmm", "full_3dmm")
        if coeff_3dmm is None:
            raise ValueError("Avatar bundle must contain coeff_3dmm, motion_3dmm, or full_3dmm.")

        mat_dict = {"coeff_3dmm": self._to_numpy(coeff_3dmm)}
        full_3dmm = bundle.get("full_3dmm", None)
        if full_3dmm is not None:
            mat_dict["full_3dmm"] = self._to_numpy(full_3dmm)

        savemat(out_path, mat_dict)
        return out_path

    def _save_mat_from_motion_bundle(self, bundle, out_path):
        """Write a ``.mat`` with ``coeff_3dmm`` (+ ``full_3dmm``) from a motion bundle.

        When only ``full_3dmm_seq`` is available, its first element is used as
        ``full_3dmm`` if it has 3 or more dimensions.

        Raises:
            ValueError: if no motion field is present.
        """
        motion = self._first_present(bundle, "motion_3dmm", "coeff_3dmm", "full_3dmm_seq", "full_3dmm")
        if motion is None:
            raise ValueError(
                "Motion bundle must contain motion_3dmm, coeff_3dmm, full_3dmm_seq, or full_3dmm."
            )

        mat_dict = {"coeff_3dmm": self._to_numpy(motion)}
        if bundle.get("full_3dmm") is not None:
            mat_dict["full_3dmm"] = self._to_numpy(bundle["full_3dmm"])
        elif bundle.get("full_3dmm_seq") is not None:
            seq = self._to_numpy(bundle["full_3dmm_seq"])
            mat_dict["full_3dmm"] = seq[0] if seq.ndim >= 3 else seq

        savemat(out_path, mat_dict)
        return out_path

    def _bundle_from_preprocess_output(
        self,
        coeff_path,
        crop_pic_path,
        crop_info,
    ):
        """Assemble a conditioning dict from SadTalker preprocessing outputs.

        Loads the non-dunder entries of the ``.mat`` at *coeff_path*, records
        the source paths and *crop_info*, and adds a ``crop_preview`` array.
        It then mirrors ``coeff_3dmm``/``motion_3dmm`` and derives
        ``full_3dmm`` when missing. Load failures are skipped (best-effort).
        """
        bundle = {}

        # Load whatever the SadTalker preprocessing wrote to disk.
        if coeff_path is not None and os.path.isfile(coeff_path):
            try:
                raw = loadmat(coeff_path)
                bundle.update({k: v for k, v in raw.items() if not k.startswith("__")})
            except Exception:
                pass

        # Preserve the paths used to generate the bundle.
        if coeff_path is not None:
            bundle["coeff_path"] = str(coeff_path)
        if crop_pic_path is not None:
            bundle["crop_pic_path"] = str(crop_pic_path)
        if crop_info is not None:
            bundle["crop_info"] = crop_info

        # Keep a usable preview in memory.
        try:
            if crop_pic_path is not None and os.path.isfile(crop_pic_path):
                with Image.open(crop_pic_path) as im:
                    bundle["crop_preview"] = pil_to_numpy_rgb(im)
        except Exception:
            pass

        # Normalize common aliases so downstream code can rely on them.
        if "coeff_3dmm" in bundle and "motion_3dmm" not in bundle:
            bundle["motion_3dmm"] = bundle["coeff_3dmm"]
        if "motion_3dmm" in bundle and "coeff_3dmm" not in bundle:
            bundle["coeff_3dmm"] = bundle["motion_3dmm"]

        if "full_3dmm" not in bundle:
            if "full_3dmm_seq" in bundle:
                seq = bundle["full_3dmm_seq"]
                try:
                    bundle["full_3dmm"] = seq[0] if getattr(seq, "ndim", 0) >= 3 else seq
                except Exception:
                    bundle["full_3dmm"] = seq
            elif "motion_3dmm" in bundle:
                bundle["full_3dmm"] = bundle["motion_3dmm"]

        return bundle

    def extract_embeddings(
        self,
        input_path,
        crop_or_resize: str = "crop",
        pic_size: int = 256,
        save_dir: Optional[str] = None,
    ):
        """
        Public preprocessing helper.

        Accepts either a source image or a reference video (chosen by file
        extension), runs the packed SadTalker preprocessing, and returns the
        extracted conditioning bundle. Intermediate files go to *save_dir*,
        or to a fresh temporary directory when it is None.

        Raises:
            FileNotFoundError: if *input_path* does not exist.
        """
        self._load_modules()
        self._ensure_models(size=pic_size, preprocess=crop_or_resize, facerender="facevid2vid")

        input_path = Path(input_path)
        if not input_path.exists():
            raise FileNotFoundError(str(input_path))

        if save_dir is None:
            save_dir = tempfile.mkdtemp(prefix="packedavatar_embeddings_")
        else:
            ensure_dir(Path(save_dir))
        work_dir = Path(save_dir)

        video_exts = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".flv", ".wmv", ".m4v", ".gif"}
        if input_path.suffix.lower() in video_exts:
            frame_dir = work_dir / f"{input_path.stem}_frames"
            ensure_dir(frame_dir)
            coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
                str(input_path),
                str(frame_dir),
                crop_or_resize,
                source_image_flag=False,
            )
        else:
            staged = work_dir / input_path.name
            shutil.copy2(input_path, staged)
            first_frame_dir = work_dir / "first_frame_dir"
            ensure_dir(first_frame_dir)
            coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
                str(staged),
                str(first_frame_dir),
                crop_or_resize,
                True,
                pic_size,
            )

        return self._bundle_from_preprocess_output(coeff_path, crop_pic_path, crop_info)

    def ExtractEmbeddings(
        self,
        input_path,
        crop_or_resize: str = "crop",
        pic_size: int = 256,
        save_dir: Optional[str] = None,
    ):
        """CamelCase alias of ``extract_embeddings``."""
        return self.extract_embeddings(
            input_path=input_path,
            crop_or_resize=crop_or_resize,
            pic_size=pic_size,
            save_dir=save_dir,
        )

    def _materialize_avatar_condition(self, avatar_condition, save_dir):
        """Ensure the avatar condition exists on disk as a ``.mat`` and a PNG.

        Existing ``coeff_path``/``crop_pic_path`` files are reused; otherwise
        ``avatar_condition.mat``/``avatar_condition.png`` are written to
        *save_dir*. Returns ``(coeff_path, crop_pic_path, crop_info)``, or
        ``(None, None, None)`` when *avatar_condition* is None.
        """
        bundle = safe_load_bundle(avatar_condition)
        if bundle is None:
            return None, None, None

        coeff_path = bundle.get("coeff_path", None)
        crop_pic_path = bundle.get("crop_pic_path", None)
        crop_info = bundle.get("crop_info", None)

        if coeff_path is None or not os.path.isfile(coeff_path):
            coeff_path = os.path.join(save_dir, "avatar_condition.mat")
            self._save_mat_from_avatar_bundle(bundle, coeff_path)

        if crop_pic_path is None or not os.path.isfile(crop_pic_path):
            crop_pic_path = os.path.join(save_dir, "avatar_condition.png")
            self._save_png_from_bundle(bundle, crop_pic_path)

        return coeff_path, crop_pic_path, crop_info

    def _materialize_motion_condition(self, motion_condition, save_dir):
        """Return a ``.mat`` path for the motion condition (existing or newly written), or None."""
        bundle = safe_load_bundle(motion_condition)
        if bundle is None:
            return None

        coeff_path = bundle.get("coeff_path", None)
        if coeff_path is not None and os.path.isfile(coeff_path):
            return coeff_path

        coeff_path = os.path.join(save_dir, "motion_condition.mat")
        self._save_mat_from_motion_bundle(bundle, coeff_path)
        return coeff_path

    def _resolve_checkpoint(self):
        """Return the first existing SadTalker checkpoint file under ``checkpoint_path``.

        Raises:
            FileNotFoundError: if none of the candidates exists.
        """
        candidates = [
            "SadTalker_V0.0.2_512.safetensors",
            "SadTalker_V0.0.2_256.safetensors",
            "SadTalker_V0.0.2_512.pth",
            "SadTalker_V0.0.2_256.pth",
        ]
        for name in candidates:
            p = Path(self.checkpoint_path) / name
            if p.exists():
                return str(p)
        raise FileNotFoundError(
            f"No SadTalker checkpoint found in {self.checkpoint_path}"
        )

    def _ensure_models(self, size: int, preprocess: str, facerender: str):
        """(Re)build the preprocessing, audio-to-coeff and renderer models.

        The PIRender renderer is used unless *facerender* is ``"facevid2vid"``
        and the device is not ``"mps"``.
        """
        self.sadtalker_paths = self.init_path(
            self.checkpoint_path,
            self.config_path,
            size,
            False,
            preprocess,
        )

        # override whatever init_path guessed
        self.sadtalker_paths["checkpoint"] = self._resolve_checkpoint()

        print("\n[PackedAvatar] Using checkpoint:")
        print(self.sadtalker_paths["checkpoint"])

        self.audio_to_coeff = self.Audio2Coeff(self.sadtalker_paths, self.device)
        self.preprocess_model = self.CropAndExtract(self.sadtalker_paths, self.device)

        if facerender == "facevid2vid" and self.device != "mps":
            self.animate_from_coeff = self.AnimateFromCoeff(self.sadtalker_paths, self.device)
        else:
            self.animate_from_coeff = self.AnimateFromCoeff_PIRender(self.sadtalker_paths, self.device)

    def _release_models(self) -> None:
        """Drop the models built by ``_ensure_models`` and free cached memory."""
        for attr in ("preprocess_model", "audio_to_coeff", "animate_from_coeff"):
            self.__dict__.pop(attr, None)
        import gc

        # Collect first so model objects held in reference cycles are freed
        # before the CUDA cache is emptied.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    def generate(
        self,
        source_image=None,
        driven_audio=None,
        preprocess="crop",
        still_mode=False,
        use_enhancer=False,
        batch_size=1,
        size=256,
        pose_style=0,
        facerender="facevid2vid",
        exp_scale=1.0,
        use_ref_video=False,
        ref_video=None,
        ref_info=None,
        use_idle_mode=False,
        length_of_audio=0,
        use_blink=True,
        result_dir="./results/",
        avatar_condition=None,
        motion_condition=None,
    ):
        """Render a talking-head video into a fresh ``result_dir/<uuid>`` folder.

        The face comes from *avatar_condition* when given, otherwise from
        *source_image*. Audio comes from *driven_audio* (MP3 converted to
        16 kHz WAV), from silence of *length_of_audio* seconds when
        *use_idle_mode* is set, or from *ref_video* via ffmpeg when
        ``use_ref_video`` and ``ref_info == "all"``. Pose/blink guidance comes
        from *motion_condition*, or from *ref_video* according to *ref_info*
        (``"pose"``, ``"blink"``, ``"pose+blink"``, ``"all"``). Models are
        released when the call ends, even on error.

        Returns:
            ``(video_path, audio_path, save_dir)``.

        Raises:
            ValueError: for missing audio/face inputs or an invalid *ref_info*.
            AttributeError: for an invalid avatar bundle or when no face is detected.
        """
        self._load_modules()
        self._ensure_models(size=size, preprocess=preprocess, facerender=facerender)

        try:
            save_dir = os.path.join(result_dir, str(uuid.uuid4()))
            os.makedirs(save_dir, exist_ok=True)
            input_dir = os.path.join(save_dir, "input")
            os.makedirs(input_dir, exist_ok=True)

            # -----------------------------
            # Audio handling
            # -----------------------------
            audio_path = None
            if driven_audio is not None and os.path.isfile(driven_audio):
                audio_name = os.path.basename(driven_audio)
                audio_path = os.path.join(input_dir, audio_name)
                if audio_name.lower().endswith(".mp3"):
                    wav_path = os.path.splitext(audio_path)[0] + ".wav"
                    self._mp3_to_wav(driven_audio, wav_path, 16000)
                    audio_path = wav_path
                else:
                    shutil.copy2(driven_audio, audio_path)
            elif use_idle_mode:
                audio_path = os.path.join(input_dir, f"idlemode_{str(length_of_audio)}.wav")
                silence = AudioSegment.silent(duration=1000 * length_of_audio)
                silence.export(audio_path, format="wav")
            elif not (use_ref_video is True and ref_info == "all"):
                # Explicit raise (not assert) so the check survives `python -O`.
                raise ValueError(
                    "Either driven_audio, use_idle_mode, or use_ref_video/ref_info='all' must be provided."
                )

            if use_ref_video and ref_info == "all" and ref_video is not None:
                ref_video_videoname = os.path.basename(ref_video)
                audio_path = os.path.join(save_dir, ref_video_videoname + ".wav")
                # Argument list without a shell: quotes or metacharacters in
                # ref_video cannot inject commands.
                subprocess.run(
                    [
                        "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
                        "-i", str(ref_video), audio_path,
                    ],
                    check=False,
                )

            if audio_path is None:
                raise ValueError(
                    "ref_video is required when use_ref_video=True and ref_info='all' without driven_audio."
                )

            # -----------------------------
            # Avatar / source conditioning
            # -----------------------------
            if avatar_condition is not None:
                first_coeff_path, crop_pic_path, crop_info = self._materialize_avatar_condition(
                    avatar_condition, save_dir
                )
                if first_coeff_path is None:
                    raise AttributeError("Invalid avatar_condition bundle.")
                pic_path = crop_pic_path
            else:
                if source_image is None:
                    raise ValueError("source_image is required when avatar_condition is not provided.")

                pic_path = os.path.join(input_dir, os.path.basename(source_image))
                shutil.copy2(source_image, pic_path)

                first_frame_dir = os.path.join(save_dir, "first_frame_dir")
                os.makedirs(first_frame_dir, exist_ok=True)

                first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
                    pic_path,
                    first_frame_dir,
                    preprocess,
                    True,
                    size,
                )
                if first_coeff_path is None:
                    raise AttributeError("No face is detected")

            # -----------------------------
            # Motion conditioning / reference video
            # -----------------------------
            ref_video_coeff_path = None
            ref_pose_coeff_path = None
            ref_eyeblink_coeff_path = None

            if motion_condition is not None:
                ref_video_coeff_path = self._materialize_motion_condition(motion_condition, save_dir)
                ref_pose_coeff_path = ref_video_coeff_path
                ref_eyeblink_coeff_path = ref_video_coeff_path
            elif use_ref_video and ref_video is not None:
                # Validate before the (expensive) reference-video extraction.
                if ref_info not in ("pose", "blink", "pose+blink", "all"):
                    raise ValueError("error in ref_info")

                ref_video_videoname = os.path.splitext(os.path.split(ref_video)[-1])[0]
                ref_video_frame_dir = os.path.join(save_dir, ref_video_videoname)
                os.makedirs(ref_video_frame_dir, exist_ok=True)

                print("3DMM Extraction for the reference video providing pose")
                ref_video_coeff_path, _, _ = self.preprocess_model.generate(
                    ref_video,
                    ref_video_frame_dir,
                    preprocess,
                    source_image_flag=False,
                )

                # "all" drives everything from the reference coeffs below, so
                # neither pose nor blink guidance is set for it.
                if ref_info in ("pose", "pose+blink"):
                    ref_pose_coeff_path = ref_video_coeff_path
                if ref_info in ("blink", "pose+blink"):
                    ref_eyeblink_coeff_path = ref_video_coeff_path

            # -----------------------------
            # Audio -> coeff
            # -----------------------------
            if use_ref_video and ref_info == "all" and ref_video_coeff_path is not None:
                coeff_path = ref_video_coeff_path
            else:
                batch = self.get_data(
                    first_coeff_path,
                    audio_path,
                    self.device,
                    ref_eyeblink_coeff_path=ref_eyeblink_coeff_path,
                    still=still_mode,
                    idlemode=use_idle_mode,
                    length_of_audio=length_of_audio,
                    use_blink=use_blink,
                )
                coeff_path = self.audio_to_coeff.generate(
                    batch,
                    save_dir,
                    pose_style,
                    ref_pose_coeff_path,
                )

            # -----------------------------
            # coeff -> video
            # -----------------------------
            data = self.get_facerender_data(
                coeff_path,
                crop_pic_path,
                first_coeff_path,
                audio_path,
                batch_size,
                still_mode=still_mode,
                preprocess=preprocess,
                size=size,
                expression_scale=exp_scale,
                facemodel=facerender,
            )

            # pic_path already equals crop_pic_path when avatar_condition is used.
            return_path = self.animate_from_coeff.generate(
                data,
                save_dir,
                pic_path,
                crop_info,
                enhancer="gfpgan" if use_enhancer else None,
                preprocess=preprocess,
                img_size=size,
            )

            video_name = data.get("video_name", "output")
            print(f"The generated video is named {video_name} in {save_dir}")

            return return_path, audio_path, save_dir
        finally:
            self._release_models()


# ============================================================
# PACKED AVATAR ORCHESTRATOR
#
# ============================================================


class PackedAvatar:
    """Runtime front-end for a self-contained ``PackedAvatar.pt`` bundle.

    Construction does the following:
    - loads the bundle, a ``torch.load``-able dict with ``manifest`` and
      ``assets`` entries;
    - extracts its ``checkpoints_zip`` and ``sadtalker_zip`` assets into a
      per-bundle cache directory;
    - prepends the extraction root to ``sys.path``;
    - builds the avatar bank and the Bria background remover from the
      extracted ``checkpoints`` folder.
    """

    def __init__(
        self,
        packed_pt_path: Optional[str] = None,
        cache_dir: Optional[str] = None,
        device: Optional[str] = None,
    ):
        """Load and mount a packed avatar bundle.

        Args:
            packed_pt_path: Path to the bundle. Defaults to
                ``<this file's dir>/checkpoints/PackedAvatar.pt``.
            cache_dir: Directory that archives are extracted into. Defaults to
                ``<system temp dir>/PackedAvatarCache``.
            device: Torch device string. When omitted, ``cuda`` is used if
                available. Otherwise ``mps`` is used on macOS when the MPS
                backend is actually usable, and ``cpu`` in all other cases.

        Raises:
            FileNotFoundError: If the bundle (or the packed AvatarBank.pt) is missing.
            ValueError: If the bundle does not deserialize to a dict.
            RuntimeError: If extraction does not produce the expected folders.
        """
        default_pt = Path(__file__).resolve().parent / "checkpoints" / "PackedAvatar.pt"
        self.packed_pt_path = Path(packed_pt_path or default_pt)
        if not self.packed_pt_path.exists():
            raise FileNotFoundError(f"Packed bundle not found: {self.packed_pt_path}")

        self.device = device or self._default_device()

        self.cache_dir = (
            Path(cache_dir) if cache_dir else Path(tempfile.gettempdir()) / "PackedAvatarCache"
        )
        ensure_dir(self.cache_dir)

        self.bundle = self._load_bundle(self.packed_pt_path)
        self.manifest = self.bundle.get("manifest", {}) or {}

        self._extract_and_mount()
        self._mount_python_path()

        self.avatar_bank = self._load_avatar_bank()
        self.bria_root = self.extracted_root / "checkpoints" / "briaaiRMBG-2.0"
        self.background_remover = BriaBackgroundRemover(self.bria_root)

        # Runners are cached per (size, preprocess, facerender) key.
        self._runner_cache: Dict[Tuple[int, str, str], SadTalkerRunner] = {}

    @staticmethod
    def _default_device() -> str:
        """Pick a torch device: CUDA, then a usable MPS backend on macOS, else CPU."""
        if torch.cuda.is_available():
            return "cuda"
        # Being on macOS is not enough: Intel Macs and older torch builds have
        # no working MPS backend, so ask torch directly.
        mps_backend = getattr(torch.backends, "mps", None)
        if (
            platform.system() == "Darwin"
            and mps_backend is not None
            and mps_backend.is_available()
        ):
            return "mps"
        return "cpu"

    @staticmethod
    def _load_bundle(path: Path) -> Dict[str, Any]:
        """Deserialize the packed bundle and check that it is a dict.

        NOTE(security): ``weights_only=False`` lets ``torch.load`` unpickle
        arbitrary objects, which can execute code. Only load bundles from a
        trusted source.

        Raises:
            ValueError: If the deserialized object is not a dict.
        """
        bundle = torch.load(str(path), map_location="cpu", weights_only=False)
        if not isinstance(bundle, dict):
            raise ValueError("PackedAvatar.pt did not contain a dictionary bundle.")
        return bundle

    def _asset_bytes(self, key: str) -> bytes:
        """Return the raw bytes stored under ``bundle["assets"][key]``.

        Raises:
            KeyError: If the asset is not present in the bundle.
        """
        asset = self.bundle.get("assets", {}).get(key)
        if asset is None:
            raise KeyError(f"Missing asset in bundle: {key}")
        return tensor_to_bytes(asset)

    def _bundle_id(self) -> str:
        """Derive a short cache id from the manifest's two archive SHA-256 values.

        Missing hashes are treated as empty strings, so a manifest without
        archive hashes always maps to the same id.
        """
        archives = self.manifest.get("archives", {})
        ck_hash = archives.get("checkpoints_zip", {}).get("sha256", "")
        sd_hash = archives.get("sadtalker_zip", {}).get("sha256", "")
        seed = f"{ck_hash}:{sd_hash}".encode("utf-8")
        return sha256_bytes(seed)[:16]

    def _extract_and_mount(self) -> None:
        """Extract the bundle's archives into the cache, reusing a valid prior extraction.

        Sets ``runtime_root``, ``extracted_root``, ``checkpoints_dir`` and
        ``sadtalker_dir``. A ``mount.json`` marker records which archive hashes
        were extracted. The marker is only trusted when it matches the current
        manifest and both extracted folders still exist.

        Raises:
            RuntimeError: If the expected folders are missing after extraction.
        """
        bundle_id = self._bundle_id()
        runtime_root = self.cache_dir / f"packedavatar_{bundle_id}"
        self.runtime_root = runtime_root
        self.extracted_root = runtime_root / "extracted"
        ensure_dir(self.extracted_root)

        self.checkpoints_dir = self.extracted_root / "checkpoints"
        self.sadtalker_dir = self.extracted_root / "SadTalker"

        marker = runtime_root / "mount.json"
        archives = self.manifest.get("archives", {})
        expected = {
            "bundle_id": bundle_id,
            "checkpoints_sha256": archives.get("checkpoints_zip", {}).get("sha256"),
            "sadtalker_sha256": archives.get("sadtalker_zip", {}).get("sha256"),
        }

        if marker.exists():
            try:
                existing = json.loads(marker.read_text(encoding="utf-8"))
            except (OSError, ValueError):
                # Unreadable or corrupt marker: fall through to a fresh extraction.
                existing = None
            # Temp-dir cleaners may delete the extracted folders while leaving
            # the marker behind, so verify the folders before reusing them.
            if (
                existing == expected
                and self.checkpoints_dir.is_dir()
                and self.sadtalker_dir.is_dir()
            ):
                return
            # Drop the stale marker before touching the folders. If the
            # re-extraction below is interrupted, no marker can then vouch for
            # half-written content.
            marker.unlink()

        # Reset a stale or partial extraction (best-effort cleanup).
        for child in list(self.extracted_root.iterdir()):
            if child.is_dir():
                shutil.rmtree(child, ignore_errors=True)
            else:
                try:
                    child.unlink()
                except OSError:
                    pass

        checkpoints_zip = self._asset_bytes("checkpoints_zip")
        sadtalker_zip = self._asset_bytes("sadtalker_zip")

        # Both archives are extracted into the same root.
        extract_zip_bytes_to_dir(checkpoints_zip, self.extracted_root)
        extract_zip_bytes_to_dir(sadtalker_zip, self.extracted_root)

        if not self.checkpoints_dir.exists():
            raise RuntimeError(f"checkpoints folder missing after extraction: {self.checkpoints_dir}")
        if not self.sadtalker_dir.exists():
            raise RuntimeError(f"SadTalker folder missing after extraction: {self.sadtalker_dir}")

        # Only mark the extraction as valid once the layout has been verified.
        marker.write_text(json.dumps(expected, indent=2), encoding="utf-8")

    def _mount_python_path(self) -> None:
        """Prepend the extraction root to ``sys.path`` (once) so packed code is importable."""
        extracted = str(self.extracted_root)
        if extracted not in sys.path:
            sys.path.insert(0, extracted)

    def _load_avatar_bank(self) -> AvatarBankRuntime:
        """Load ``checkpoints/AvatarBank.pt`` using defaults taken from the manifest.

        Raises:
            FileNotFoundError: If AvatarBank.pt is not in the extracted checkpoints.
        """
        bank_path = self.checkpoints_dir / "AvatarBank.pt"
        if not bank_path.exists():
            raise FileNotFoundError(f"AvatarBank.pt not found inside packed checkpoints: {bank_path}")
        manifest_defaults = self.manifest.get("defaults", {})
        defaults = {
            "default_avatar": manifest_defaults.get("default_avatar", ""),
            "real_style_aliases": manifest_defaults.get("real_style_aliases", list(REAL_STYLE_ALIASES)),
        }
        return AvatarBankRuntime.load(bank_path, defaults=defaults)

    def _get_runner(self, size: int, preprocess: str, facerender: str) -> SadTalkerRunner:
        """Return a cached SadTalkerRunner for this (size, preprocess, facerender) key.

        NOTE(review): every key currently builds an identically configured
        runner (the key fields are not passed to the constructor), so distinct
        keys only yield duplicate instances. Confirm whether the per-key split
        is intentional.
        """
        key = (int(size), preprocess, facerender)
        runner = self._runner_cache.get(key)
        if runner is None:
            runner = SadTalkerRunner(
                checkpoint_path=str(self.checkpoints_dir),
                config_path=str(self.sadtalker_dir / "src" / "config"),
                device=self.device,
            )
            self._runner_cache[key] = runner
        return runner

    def extract_embeddings(
        self,
        input_path: str,
        crop_or_resize: str = "crop",
        pic_size: int = 256,
        save_dir: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Extract a conditioning bundle from a source image or reference video.

        The returned dictionary is the same kind of bundle the runtime uses
        internally for avatar conditioning and motion conditioning.
        """
        runner = self._get_runner(size=pic_size, preprocess=crop_or_resize, facerender="facevid2vid")
        return runner.extract_embeddings(
            input_path=input_path,
            crop_or_resize=crop_or_resize,
            pic_size=pic_size,
            save_dir=save_dir,
        )

    def ExtractEmbeddings(
        self,
        input_path: str,
        crop_or_resize: str = "crop",
        pic_size: int = 256,
        save_dir: Optional[str] = None,
    ) -> Dict[str, Any]:
        """CamelCase alias of :meth:`extract_embeddings`."""
        return self.extract_embeddings(
            input_path=input_path,
            crop_or_resize=crop_or_resize,
            pic_size=pic_size,
            save_dir=save_dir,
        )

    def _resolve_avatar_condition_from_bank(self, avatar_id: Optional[str]) -> Dict[str, Any]:
        """Build an avatar condition from the bank, using the bank default when ``avatar_id`` is None."""
        if avatar_id is None:
            avatar_id = self.avatar_bank.resolve_default_avatar_id()
        return self.avatar_bank.build_avatar_condition(avatar_id)

    def _normalize_avatar_condition(self, avatar_condition: Any) -> Optional[Dict[str, Any]]:
        """Load an avatar condition and make sure it exposes ``coeff_3dmm`` when possible.

        Accepts None, a dict, or a .pt/.pth/.mat path (via ``safe_load_bundle``).
        If ``coeff_3dmm`` is absent, it is aliased to the first non-None of
        ``motion_3dmm`` or ``full_3dmm``. Dict inputs are shallow-copied first,
        so the caller's dict is never modified.
        """
        loaded = safe_load_bundle(avatar_condition)
        if loaded is None:
            return None
        bundle = dict(loaded) if isinstance(loaded, dict) else loaded
        if "coeff_3dmm" not in bundle:
            for fallback_key in ("motion_3dmm", "full_3dmm"):
                if bundle.get(fallback_key) is not None:
                    bundle["coeff_3dmm"] = bundle[fallback_key]
                    break
        return bundle

    def _remove_background_if_requested(
        self,
        source_image: Optional[str],
        remove_background: bool,
        work_dir: Path,
    ) -> Optional[Path]:
        """Stage ``source_image`` into ``work_dir`` and prepare it for SadTalker.

        Returns None when ``source_image`` is None. Otherwise returns the path
        of an RGB image. It is background-removed with Bria RMBG when
        ``remove_background`` is True.

        Raises:
            FileNotFoundError: If ``source_image`` does not exist.
            RuntimeError: If background removal was requested but failed.
        """
        if source_image is None:
            return None

        src = Path(source_image)
        if not src.exists():
            raise FileNotFoundError(str(src))

        ensure_dir(work_dir)
        staged = work_dir / src.name
        # copy2 raises SameFileError when the image already lives in work_dir
        # (e.g. a previously staged file is passed back in), so skip the copy.
        if not (staged.exists() and staged.samefile(src)):
            shutil.copy2(src, staged)

        if not remove_background:
            return prepare_image_for_sadtalker(staged)

        # Background removal is mandatory once requested: failures are surfaced.
        try:
            removed = self.background_remover.remove_background(staged, work_dir)
            return prepare_image_for_sadtalker(staged, removed)
        except Exception as e:
            raise RuntimeError(
                f"remove_background=True was requested, but Bria RMBG execution failed: {e}"
            ) from e

    def _run_wav2lip_gan(
        self,
        face_video: str,
        audio_path: str,
        save_dir: str,
        wav2lip_repo: Optional[str] = None,
    ) -> str:
        """Re-sync lips of ``face_video`` to ``audio_path`` with the bundled Wav2Lip GAN.

        The Wav2Lip ``inference.py`` is searched for in the following order:
        - ``wav2lip_repo``;
        - ``checkpoints/Wav2Lip``;
        - ``SadTalker/Wav2Lip``;
        - a ``Wav2Lip`` folder next to this file.

        If none is found, a message is printed and ``face_video`` is returned
        unchanged.

        Raises:
            FileNotFoundError: If ``checkpoints/wav2lip_gan.pth`` is missing.
            subprocess.CalledProcessError: If Wav2Lip inference exits non-zero.
        """
        wav2lip_checkpoint = self.checkpoints_dir / "wav2lip_gan.pth"
        if not wav2lip_checkpoint.is_file():
            raise FileNotFoundError(
                f"Could not find bundled Wav2Lip GAN checkpoint at: {wav2lip_checkpoint}"
            )

        candidate_repos: List[Path] = [Path(wav2lip_repo)] if wav2lip_repo else []
        # Prefer packed locations before the script-relative fallback.
        candidate_repos += [
            self.checkpoints_dir / "Wav2Lip",
            self.sadtalker_dir / "Wav2Lip",
            Path(__file__).resolve().parent / "Wav2Lip",
        ]
        repo = next((c for c in candidate_repos if (c / "inference.py").is_file()), None)

        # Missing Wav2Lip code is not an error: fall back to the SadTalker video.
        if repo is None:
            print(
                "[PackedAvatar] Wav2Lip inference code was not found; "
                "skipping Wav2Lip post-processing and returning the SadTalker video."
            )
            return face_video

        inference_py = repo / "inference.py"
        out_video = os.path.join(save_dir, f"{Path(face_video).stem}_wav2lip_gan.mp4")
        cmd = [
            sys.executable,
            str(inference_py),
            "--checkpoint_path", str(wav2lip_checkpoint),
            "--face", str(face_video),
            "--audio", str(audio_path),
            "--outfile", str(out_video),
        ]
        subprocess.run(cmd, cwd=str(repo), check=True)
        return out_video

    def list_avatars(self):
        """Delegate to the avatar bank's ``list_avatars``."""
        return self.avatar_bank.list_avatars()

    def query_avatars(self, *args, **kwargs):
        """Delegate to the avatar bank's ``query``."""
        return self.avatar_bank.query(*args, **kwargs)

    def fuzzy_search_avatar(self, query, n=5, cutoff=0.5):
        """Delegate to the avatar bank's ``fuzzy_search_id``."""
        return self.avatar_bank.fuzzy_search_id(query, n, cutoff)

    def get_avatar(self, avatar_id):
        """Delegate to the avatar bank's ``get_avatar``."""
        return self.avatar_bank.get_avatar(avatar_id)

    def get_avatar_preview(self, avatar_id):
        """Delegate to the avatar bank's ``get_preview``."""
        return self.avatar_bank.get_preview(avatar_id)

    def get_avatar_metadata(self, avatar_id):
        """Delegate to the avatar bank's ``get_metadata``."""
        return self.avatar_bank.get_metadata(avatar_id)

    def delete_avatar(self, avatar_id):
        """Delegate to the avatar bank's ``delete_avatar``."""
        self.avatar_bank.delete_avatar(avatar_id)

    def save_avatar_bank(self, path):
        """Delegate to the avatar bank's ``save``."""
        self.avatar_bank.save(path)

    def generate(
        self,
        source_image: Optional[str] = None,
        driven_audio: Optional[str] = None,
        preprocess: str = "crop",
        still_mode: bool = False,
        use_enhancer: bool = False,
        batch_size: int = 1,
        size: int = 256,
        pose_style: int = 0,
        facerender: str = "facevid2vid",
        exp_scale: float = 1.0,
        use_ref_video: bool = False,
        ref_video: Optional[str] = None,
        ref_info: Optional[str] = None,
        use_idle_mode: bool = False,
        length_of_audio: int = 0,
        use_blink: bool = True,
        result_dir: str = "./results/",
        avatar_id: Optional[str] = None,
        avatar_condition: Optional[Any] = None,
        motion_condition: Optional[Any] = None,
        remove_background: bool = False,
        use_wav2lip: bool = False,
        wav2lip_repo: Optional[str] = None,
    ) -> str:
        """Render a talking-head video and return the path of the final video.

        Avatar source precedence:
        1. An explicit ``avatar_condition`` (dict or .pt/.pth/.mat path). Any
           ``source_image`` is ignored.
        2. ``source_image``, staged (and optionally background-removed) and
           passed to SadTalker directly.
        3. The packed avatar bank: ``avatar_id``, or the bank default when it
           is None.

        When ``use_wav2lip`` is True, the SadTalker output is post-processed
        with Wav2Lip GAN.
        """
        runner = self._get_runner(size=size, preprocess=preprocess, facerender=facerender)
        ensure_dir(Path(result_dir))

        resolved_avatar_condition = self._normalize_avatar_condition(avatar_condition)
        source_image_for_runner: Optional[str] = source_image

        if resolved_avatar_condition is not None:
            # An explicit avatar_condition supersedes source_image-driven conditioning.
            source_image_for_runner = None
        elif source_image_for_runner is None:
            # Neither a condition nor an image: use the packed AvatarBank.
            resolved_avatar_condition = self._resolve_avatar_condition_from_bank(avatar_id)
        else:
            # The source image is used directly by SadTalker, optionally background-removed.
            source_work_dir = self.runtime_root / "source_work"
            ensure_dir(source_work_dir)
            prepared = self._remove_background_if_requested(
                source_image_for_runner, remove_background, source_work_dir
            )
            source_image_for_runner = str(prepared) if prepared is not None else source_image_for_runner

        return_path, audio_path, save_dir = runner.generate(
            source_image=source_image_for_runner,
            driven_audio=driven_audio,
            preprocess=preprocess,
            still_mode=still_mode,
            use_enhancer=use_enhancer,
            batch_size=batch_size,
            size=size,
            pose_style=pose_style,
            facerender=facerender,
            exp_scale=exp_scale,
            use_ref_video=use_ref_video,
            ref_video=ref_video,
            ref_info=ref_info,
            use_idle_mode=use_idle_mode,
            length_of_audio=length_of_audio,
            use_blink=use_blink,
            result_dir=result_dir,
            avatar_condition=resolved_avatar_condition,
            motion_condition=motion_condition,
        )

        if use_wav2lip:
            return_path = self._run_wav2lip_gan(
                face_video=return_path,
                audio_path=audio_path,
                save_dir=save_dir,
                wav2lip_repo=wav2lip_repo,
            )

        return return_path


PackedAvatarModel = PackedAvatar


# ============================================================
# CLI
# ============================================================


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for running a packed avatar bundle."""
    p = argparse.ArgumentParser(description="Run the packed avatar bundle.")
    # NOTE(review): this default (next to the script) differs from
    # PackedAvatar's own default (checkpoints/PackedAvatar.pt). Confirm which
    # location is intended.
    p.add_argument("--packed-pt", type=Path, default=Path(__file__).resolve().parent / "PackedAvatar.pt")
    p.add_argument("--cache-dir", type=Path, default=None)
    p.add_argument("--device", type=str, default=None)
    p.add_argument("--source-image", type=Path, default=None)
    p.add_argument("--driven-audio", type=Path, default="speech.wav")
    p.add_argument("--avatar-id", type=str, default=None)
    p.add_argument("--avatar-condition", type=Path, default=None)
    p.add_argument("--motion-condition", type=Path, default=None)
    p.add_argument("--remove-background", action="store_true")
    # Wav2Lip is on by default. --no-wav2lip is the only way to disable it,
    # because a store_true flag defaulting to True can never become False.
    p.add_argument("--use-wav2lip", action="store_true", default=True)
    p.add_argument("--no-wav2lip", action="store_false", dest="use_wav2lip",
                   help="Disable Wav2Lip GAN post-processing.")
    p.add_argument("--wav2lip-repo", type=Path, default=None)
    p.add_argument("--result-dir", type=Path, default=Path("./results"))
    p.add_argument("--preprocess", type=str, default="crop")
    p.add_argument("--size", type=int, default=256)
    p.add_argument("--facerender", type=str, default="facevid2vid")
    p.add_argument("--still-mode", action="store_true")
    p.add_argument("--use-enhancer", action="store_true")
    p.add_argument("--batch-size", type=int, default=1)
    p.add_argument("--pose-style", type=int, default=0)
    p.add_argument("--exp-scale", type=float, default=1.0)
    p.add_argument("--use-ref-video", action="store_true")
    p.add_argument("--ref-video", type=Path, default=None)
    p.add_argument("--ref-info", type=str, default=None)
    p.add_argument("--use-idle-mode", action="store_true")
    p.add_argument("--length-of-audio", type=int, default=0)
    p.add_argument("--use-blink", action="store_true", default=True)
    p.add_argument("--no-blink", action="store_false", dest="use_blink")
    p.add_argument("--manual-audio", action="store_true", help="Alias for driven-audio handling; kept for clarity.")
    return p


def main() -> None:
    """CLI entry point: parse arguments, run generation and print the output path."""
    parser = build_parser()
    args = parser.parse_args()

    model = PackedAvatar(
        packed_pt_path=str(args.packed_pt),
        cache_dir=str(args.cache_dir) if args.cache_dir else None,
        device=args.device,
    )

    output = model.generate(
        source_image=str(args.source_image) if args.source_image else None,
        driven_audio=str(args.driven_audio) if args.driven_audio else None,
        preprocess=args.preprocess,
        still_mode=args.still_mode,
        use_enhancer=args.use_enhancer,
        batch_size=args.batch_size,
        size=args.size,
        pose_style=args.pose_style,
        facerender=args.facerender,
        exp_scale=args.exp_scale,
        use_ref_video=args.use_ref_video,
        ref_video=str(args.ref_video) if args.ref_video else None,
        ref_info=args.ref_info,
        use_idle_mode=args.use_idle_mode,
        length_of_audio=args.length_of_audio,
        use_blink=args.use_blink,
        result_dir=str(args.result_dir),
        avatar_id=args.avatar_id,
        avatar_condition=str(args.avatar_condition) if args.avatar_condition else None,
        motion_condition=str(args.motion_condition) if args.motion_condition else None,
        remove_background=args.remove_background,
        use_wav2lip=args.use_wav2lip,
        wav2lip_repo=str(args.wav2lip_repo) if args.wav2lip_repo else None,
    )
    print(output)


if __name__ == "__main__":
    main()