| from __future__ import annotations |
|
|
import argparse
import hashlib
import importlib.util
import io
import json
import os
import platform
import shutil
import subprocess
import sys
import tempfile
import time
import uuid
import zipfile
from dataclasses import dataclass
from difflib import get_close_matches
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
| import cv2 |
| import numpy as np |
| import torch |
| import zstandard as zstd |
| from PIL import Image |
| from pydub import AudioSegment |
| from scipy.io import loadmat, savemat |
|
|
|
|
| |
| |
| |
|
|
# Style labels treated as "photorealistic". Lookups against this set are done
# on the normalized (stripped, lower-cased) style name.
REAL_STYLE_ALIASES = {
    "real",
    "realistic",
    "photo",
    "photoreal",
    "liveaction",
}
|
|
|
|
def ensure_dir(path: Path) -> None:
    """Create ``path`` as a directory, together with any missing parents.

    Calling this on a directory that already exists is a no-op.
    """
    path.mkdir(exist_ok=True, parents=True)
|
|
|
|
def utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string with second precision."""
    import datetime as _dt

    current = _dt.datetime.now(_dt.timezone.utc)
    return current.isoformat(timespec="seconds")
|
|
|
|
def sha256_bytes(data: bytes) -> str:
    """Return the hex-encoded SHA-256 digest of ``data``."""
    hasher = hashlib.sha256()
    hasher.update(data)
    return hasher.hexdigest()
|
|
|
|
def sha256_file(path: Path, chunk_size: int = 1024 * 1024) -> str:
    """Return the hex-encoded SHA-256 digest of the file at ``path``.

    The file is read in pieces of at most ``chunk_size`` bytes so that the
    whole content never has to be held in memory at once.
    """
    hasher = hashlib.sha256()
    with path.open("rb") as stream:
        while True:
            piece = stream.read(chunk_size)
            if piece == b"":
                # An empty read marks end of file.
                break
            hasher.update(piece)
    return hasher.hexdigest()
|
|
|
|
def tensor_to_bytes(obj: Any) -> bytes:
    """Convert a bytes-like object or a torch tensor into ``bytes``.

    Tensors are detached, moved to the CPU and made contiguous before their
    underlying buffer is serialized.

    Raises:
        TypeError: if ``obj`` is neither ``bytes``/``bytearray`` nor a tensor.
    """
    if torch.is_tensor(obj):
        host_tensor = obj.detach().cpu().contiguous()
        return host_tensor.numpy().tobytes()
    if isinstance(obj, (bytes, bytearray)):
        return bytes(obj)
    raise TypeError(f"Expected bytes or tensor, got {type(obj)!r}")
|
|
|
|
def bytes_to_tensor(data: bytes) -> torch.Tensor:
    """Return a 1-D ``torch.uint8`` tensor holding a copy of ``data``.

    Args:
        data: Any bytes-like object.

    Returns:
        A tensor that owns its memory (it does not alias ``data``).
    """
    # ``torch.frombuffer`` rejects empty buffers, so handle that case directly
    # instead of relying on the exception fallback.
    if len(data) == 0:
        return torch.empty(0, dtype=torch.uint8)
    try:
        # Copy into a writable ``bytearray`` first: ``bytes`` is read-only and
        # ``torch.frombuffer`` expects a writable buffer.
        return torch.frombuffer(bytearray(data), dtype=torch.uint8).clone()
    except (AttributeError, TypeError, ValueError, RuntimeError):
        # Slow element-wise fallback, e.g. for torch builds that lack
        # ``frombuffer``.
        return torch.tensor(list(data), dtype=torch.uint8)
|
|
|
|
def decode_png_or_zstd_image(blob: bytes) -> Image.Image:
    """Decode a preview blob into an RGB PIL image.

    The blob is first tried as zstd-compressed data; if decompression fails
    for any reason it is assumed to be an uncompressed image file already.
    """
    try:
        decompressor = zstd.ZstdDecompressor()
        payload = decompressor.decompress(blob)
    except Exception:
        # Not zstd (or not decompressible): treat the blob as the image itself.
        payload = blob
    buffer = io.BytesIO(payload)
    return Image.open(buffer).convert("RGB")
|
|
|
|
def pil_to_numpy_rgb(img: Image.Image) -> np.ndarray:
    """Convert a PIL image to RGB and return it as a ``uint8`` numpy array."""
    rgb_image = img.convert("RGB")
    return np.asarray(rgb_image, dtype=np.uint8)
|
|
|
|
def normalize_style_name(style: Optional[str]) -> str:
    """Return ``style`` stripped and lower-cased; ``None``/empty becomes ``""``."""
    if not style:
        return ""
    return style.strip().lower()
|
|
|
|
def normalize_gender_name(gender: Optional[str]) -> str:
    """Return ``gender`` stripped and lower-cased; ``None``/empty becomes ``""``."""
    if not gender:
        return ""
    return gender.strip().lower()
|
|
|
|
def safe_load_bundle(path_or_bundle: Any) -> Optional[Dict[str, Any]]:
    """Resolve a conditioning input to an in-memory bundle.

    Args:
        path_or_bundle: ``None``, an already-loaded ``dict``, or a path to a
            ``.pt``/``.pth`` file (loaded with ``torch.load``) or a ``.mat``
            file (loaded with ``scipy.io.loadmat``).

    Returns:
        ``None`` or the dictionary unchanged when given one of those, otherwise
        whatever the matching loader returns.

    Raises:
        TypeError: for any other input, including paths with another suffix.
    """
    if path_or_bundle is None or isinstance(path_or_bundle, dict):
        return path_or_bundle

    if isinstance(path_or_bundle, (str, os.PathLike)):
        bundle_path = Path(path_or_bundle)
        suffix = bundle_path.suffix.lower()
        if suffix == ".mat":
            return loadmat(str(bundle_path))
        if suffix in (".pt", ".pth"):
            # NOTE(review): weights_only=False unpickles arbitrary objects;
            # only use this on trusted files.
            return torch.load(str(bundle_path), map_location="cpu", weights_only=False)

    raise TypeError("Conditioning input must be None, a dict, or a .pt/.pth/.mat path")
def _resolve_checkpoint(self):
    """Return the first SadTalker checkpoint found in ``self.checkpoint_path``.

    NOTE(review): this module-level function takes ``self`` and has the same
    body as ``SadTalkerRunner._resolve_checkpoint``; it looks like a stray
    copy -- confirm whether anything uses it before removing.

    Raises:
        FileNotFoundError: if none of the known checkpoint files exists.
    """
    # Probed in priority order: 512px before 256px, safetensors before pth.
    known_names = (
        "SadTalker_V0.0.2_512.safetensors",
        "SadTalker_V0.0.2_256.safetensors",
        "SadTalker_V0.0.2_512.pth",
        "SadTalker_V0.0.2_256.pth",
    )
    for filename in known_names:
        candidate = Path(self.checkpoint_path) / filename
        if candidate.exists():
            return str(candidate)

    raise FileNotFoundError(
        f"No SadTalker checkpoint found in {self.checkpoint_path}"
    )
|
|
def composite_alpha_to_rgb(image_path: Path, bg_rgb=(255, 255, 255)) -> Path:
    """Flatten an image onto a solid background and save it as an RGB PNG.

    The image is converted to RGBA, alpha-composited over an opaque canvas of
    colour ``bg_rgb`` and written next to the input as ``<stem>_rgb.png``.

    Returns:
        The path of the newly written PNG.
    """
    with Image.open(image_path) as source:
        foreground = source.convert("RGBA")
        opaque_color = (*bg_rgb, 255)
        canvas = Image.new("RGBA", foreground.size, opaque_color)
        flattened = Image.alpha_composite(canvas, foreground).convert("RGB")

    target = image_path.with_name(f"{image_path.stem}_rgb.png")
    flattened.save(target)
    return target
|
|
|
|
def prepare_image_for_sadtalker(image_path: Path, remove_background_result: Optional[Path] = None) -> Path:
    """Return the path of an image that carries no transparency.

    If ``remove_background_result`` is given, that image is flattened to RGB
    and its new path returned.  Otherwise ``image_path`` is flattened only
    when its mode is ``RGBA``/``LA`` or it declares a ``transparency`` entry;
    an image without transparency is returned untouched.
    """
    if remove_background_result is not None:
        return composite_alpha_to_rgb(remove_background_result)

    with Image.open(image_path) as probe:
        has_transparency = probe.mode in ("RGBA", "LA") or "transparency" in probe.info

    if has_transparency:
        return composite_alpha_to_rgb(image_path)
    return image_path
|
|
|
|
| |
| |
| |
|
|
@dataclass
class MountedArchive:
    """Record describing a zip payload that has been extracted to disk."""

    # Name stored in the marker file (defaults to the marker file name).
    name: str
    # SHA-256 of the zip payload that was extracted.
    zip_sha256: str
    # Directory the archive content lives in.
    target_dir: Path
    # JSON marker file recording the mount inside ``target_dir``.
    marker_path: Path
|
|
|
|
def extract_zip_bytes_to_dir(zip_bytes: bytes, dest_dir: Path) -> None:
    """Extract an in-memory zip archive into ``dest_dir``.

    ``dest_dir`` is created first if it does not exist yet.
    """
    ensure_dir(dest_dir)
    in_memory = io.BytesIO(zip_bytes)
    with zipfile.ZipFile(in_memory, mode="r") as archive:
        archive.extractall(path=dest_dir)
|
|
|
|
def mount_zip_payload(zip_bytes: bytes, zip_sha256: str, target_dir: Path, marker_name: str) -> MountedArchive:
    """Extract a zip payload into ``target_dir`` unless it is already mounted.

    A JSON marker file ``target_dir / marker_name`` records the SHA-256 of the
    archive that was last extracted.  When the marker matches ``zip_sha256``
    the directory is reused as-is; otherwise the directory content is removed,
    the archive is extracted and a fresh marker is written.

    Args:
        zip_bytes: Raw bytes of the zip archive.
        zip_sha256: Digest identifying ``zip_bytes``; it is compared with the
            marker, not recomputed here.
        target_dir: Directory to extract into (created if missing).
        marker_name: File name of the marker inside ``target_dir``.

    Returns:
        A ``MountedArchive`` describing the mounted directory.
    """
    ensure_dir(target_dir)
    marker_path = target_dir / marker_name

    if marker_path.exists():
        try:
            existing = json.loads(marker_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # Unreadable or corrupt marker: treat the directory as unmounted.
            existing = None

        if (
            isinstance(existing, dict)
            and existing.get("zip_sha256") == zip_sha256
            and existing.get("mounted") is True
        ):
            return MountedArchive(
                name=existing.get("name", marker_name),
                zip_sha256=zip_sha256,
                target_dir=target_dir,
                marker_path=marker_path,
            )

        # Invalidate the stale marker BEFORE touching the content.  If the
        # wipe or extraction below fails, no marker may be left behind that
        # still claims the previous archive is fully mounted.
        try:
            marker_path.unlink()
        except OSError:
            pass

    # Best-effort removal of the previous content.
    for child in list(target_dir.iterdir()):
        if child == marker_path:
            continue
        if child.is_dir() and not child.is_symlink():
            shutil.rmtree(child, ignore_errors=True)
        else:
            # Files and symlinks (including symlinks to directories, which
            # ``rmtree`` refuses to remove).
            try:
                child.unlink()
            except OSError:
                pass

    extract_zip_bytes_to_dir(zip_bytes, target_dir)

    # The marker is written last so that it only exists after a complete
    # extraction.
    marker_path.write_text(
        json.dumps(
            {
                "mounted": True,
                "zip_sha256": zip_sha256,
                "name": marker_name,
                "created_at": utc_now_iso(),
            },
            indent=2,
        ),
        encoding="utf-8",
    )
    return MountedArchive(
        name=marker_name,
        zip_sha256=zip_sha256,
        target_dir=target_dir,
        marker_path=marker_path,
    )
|
|
|
|
| |
| |
| |
class AvatarBankRuntime:
    """In-memory avatar bank: metadata, embeddings and previews keyed by avatar id.

    The bank wraps a payload dictionary with three optional sub-dictionaries,
    each keyed by avatar id:

    * ``index``      -- per-avatar metadata; ``gender`` and ``style`` are the
      keys this class reads.
    * ``embeddings`` -- per-avatar conditioning data; ``motion_3dmm``,
      ``full_3dmm``, ``crop_info`` and ``crop_preview`` are the keys read.
    * ``previews``   -- encoded preview blobs, decoded with
      ``decode_png_or_zstd_image``.
    """

    def __init__(
        self,
        payload: Dict[str, Any],
        defaults: Optional[Dict[str, Any]] = None,
    ):
        """Wrap ``payload``; missing or empty sections become empty dicts.

        Args:
            payload: Dictionary with ``index``/``embeddings``/``previews``.
            defaults: Optional settings; ``default_avatar`` is the only key
                read (by ``resolve_default_avatar_id``).
        """
        self.index: Dict[str, Dict[str, Any]] = payload.get("index", {}) or {}
        self.embeddings: Dict[str, Dict[str, Any]] = payload.get("embeddings", {}) or {}
        self.previews: Dict[str, Any] = payload.get("previews", {}) or {}
        self.defaults = defaults or {}

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    @classmethod
    def load(
        cls,
        path: Path,
        defaults: Optional[Dict[str, Any]] = None,
    ) -> "AvatarBankRuntime":
        """Load a bank from a file written by ``torch.save``.

        Raises:
            ValueError: if the file does not contain a dictionary.
        """
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load bank files from trusted sources.
        payload = torch.load(str(path), map_location="cpu", weights_only=False)
        if not isinstance(payload, dict):
            raise ValueError(f"Avatar bank file did not contain a dictionary: {path}")
        return cls(payload, defaults=defaults)

    def save(self, path: Union[str, Path]) -> None:
        """Write index, embeddings and previews to ``path`` via ``torch.save``.

        ``defaults`` are not part of the saved payload.
        """
        torch.save(
            {
                "index": self.index,
                "embeddings": self.embeddings,
                "previews": self.previews,
            },
            str(path),
        )

    # ------------------------------------------------------------------
    # Basic lookup
    # ------------------------------------------------------------------

    def __contains__(self, avatar_id: str) -> bool:
        return avatar_id in self.index

    def exists(self, avatar_id: str) -> bool:
        """Return whether ``avatar_id`` has an index entry."""
        return avatar_id in self.index

    def available_ids(self) -> List[str]:
        """Return all avatar ids from the index, in index order."""
        return list(self.index.keys())

    def list_avatars(self) -> List[str]:
        """Alias of ``available_ids``."""
        return self.available_ids()

    def get_metadata(self, avatar_id: str) -> Dict[str, Any]:
        """Return a shallow copy of the index entry for ``avatar_id``.

        Raises:
            KeyError: if the avatar has no index entry.
        """
        if avatar_id not in self.index:
            raise KeyError(f"Avatar not found: {avatar_id}")
        return dict(self.index[avatar_id])

    def get_avatar(self, avatar_id: str) -> Dict[str, Any]:
        """Alias of ``build_avatar_condition``."""
        return self.build_avatar_condition(avatar_id)

    def get_embedding_bundle(self, avatar_id: str) -> Dict[str, Any]:
        """Return the stored embedding entry (not a copy) for ``avatar_id``.

        Raises:
            KeyError: if the avatar has no embedding entry.
        """
        if avatar_id not in self.embeddings:
            raise KeyError(f"Avatar not found: {avatar_id}")
        return self.embeddings[avatar_id]

    # ------------------------------------------------------------------
    # Fuzzy matching
    # ------------------------------------------------------------------

    def _fuzzy_match_single(
        self,
        query: str,
        choices: set,
        cutoff: float = 0.6,
    ):
        """Return the choice closest to ``query`` (case-insensitive), or ``None``.

        Non-string choices are ignored.  The returned value keeps the original
        casing of the matched choice.
        """
        if not query:
            return None

        # Map lower-cased spelling -> original spelling so the comparison is
        # case-insensitive but the caller gets the stored value back.
        choice_map = {c.lower(): c for c in choices if isinstance(c, str)}
        matches = get_close_matches(
            query.lower(),
            list(choice_map.keys()),
            n=1,
            cutoff=cutoff,
        )
        return choice_map[matches[0]] if matches else None

    def fuzzy_search_id(
        self,
        query_id: str,
        n: int = 5,
        cutoff: float = 0.5,
    ) -> List[str]:
        """Return up to ``n`` avatar ids similar to ``query_id``, best first.

        Matching is case-insensitive; ids are returned in their stored casing.
        """
        id_map = {aid.lower(): aid for aid in self.index}
        matches = get_close_matches(
            query_id.lower(),
            list(id_map.keys()),
            n=n,
            cutoff=cutoff,
        )
        return [id_map[m] for m in matches]

    # ------------------------------------------------------------------
    # Filtering
    # ------------------------------------------------------------------

    def query(
        self,
        gender=None,
        style=None,
        fuzzy=True,
        cutoff=0.6,
    ) -> List[str]:
        """Return the ids of avatars whose metadata matches the given filters.

        Args:
            gender: Required ``gender`` metadata value, or falsy for "any".
            style: Required ``style`` metadata value, or falsy for "any".
            fuzzy: When true, each filter is first snapped to the closest
                value actually present in the index; if nothing is close
                enough the filter is used verbatim.
            cutoff: Similarity threshold for the fuzzy snapping.

        The final comparison against the metadata is an exact ``!=`` check.
        """
        target_gender = gender
        target_style = style

        if fuzzy:
            if gender:
                known_genders = {
                    meta["gender"]
                    for meta in self.index.values()
                    if meta.get("gender")
                }
                target_gender = (
                    self._fuzzy_match_single(gender, known_genders, cutoff) or gender
                )
            if style:
                known_styles = {
                    meta["style"]
                    for meta in self.index.values()
                    if meta.get("style")
                }
                target_style = (
                    self._fuzzy_match_single(style, known_styles, cutoff) or style
                )

        results = []
        for aid, meta in self.index.items():
            if target_gender and meta.get("gender") != target_gender:
                continue
            if target_style and meta.get("style") != target_style:
                continue
            results.append(aid)
        return results

    # ------------------------------------------------------------------
    # Previews
    # ------------------------------------------------------------------

    def get_preview(self, avatar_id: str):
        """Decode and return the stored preview image for ``avatar_id``.

        Raises:
            KeyError: if the avatar has no preview entry.
        """
        if avatar_id not in self.previews:
            raise KeyError(f"Avatar not found: {avatar_id}")
        return decode_png_or_zstd_image(self.previews[avatar_id])

    def get_preview_numpy(
        self,
        avatar_id: str,
    ) -> Optional[np.ndarray]:
        """Return the preview as an RGB ``uint8`` array, or ``None`` if unavailable."""
        return self._preview_to_numpy(avatar_id)

    def _preview_to_numpy(
        self,
        avatar_id: str,
    ) -> Optional[np.ndarray]:
        """Decode the preview blob to an array; ``None`` if missing or undecodable."""
        blob = self.previews.get(avatar_id)
        if blob is None:
            return None
        try:
            return pil_to_numpy_rgb(decode_png_or_zstd_image(blob))
        except Exception:
            # Deliberately best-effort: a broken preview is treated as absent.
            return None

    # ------------------------------------------------------------------
    # Mutation
    # ------------------------------------------------------------------

    def delete_avatar(
        self,
        avatar_id: str,
    ) -> None:
        """Remove ``avatar_id`` from all three sections; unknown ids are ignored."""
        self.index.pop(avatar_id, None)
        self.embeddings.pop(avatar_id, None)
        self.previews.pop(avatar_id, None)

    # ------------------------------------------------------------------
    # Default selection and conditioning
    # ------------------------------------------------------------------

    def _style_is_real(self, style: Optional[str]) -> bool:
        """Return whether the normalized ``style`` is one of ``REAL_STYLE_ALIASES``."""
        return normalize_style_name(style) in REAL_STYLE_ALIASES

    def resolve_default_avatar_id(self) -> str:
        """Pick the avatar id to use when the caller did not choose one.

        Preference order: the configured ``default_avatar`` (if indexed), a
        male avatar with a "real" style, any "real" style avatar, any male
        avatar, the first avatar with a non-``None`` embedding entry, and
        finally the first id in the index.

        Raises:
            RuntimeError: if the index is empty.
        """
        if not self.index:
            raise RuntimeError("Avatar bank is empty.")

        configured_id = self.defaults.get("default_avatar")
        if configured_id and configured_id in self.index:
            return configured_id

        for avatar_id, meta in self.index.items():
            if normalize_gender_name(meta.get("gender")) == "male" and self._style_is_real(meta.get("style")):
                return avatar_id

        for avatar_id, meta in self.index.items():
            if self._style_is_real(meta.get("style")):
                return avatar_id

        for avatar_id, meta in self.index.items():
            if normalize_gender_name(meta.get("gender")) == "male":
                return avatar_id

        for avatar_id, emb in self.embeddings.items():
            if emb is not None:
                return avatar_id

        return next(iter(self.index.keys()))

    @staticmethod
    def _to_cpu(value: Any) -> Any:
        """Return tensors detached on the CPU; pass any other value through."""
        if torch.is_tensor(value):
            return value.detach().cpu()
        return value

    def build_avatar_condition(self, avatar_id: str) -> Dict[str, Any]:
        """Assemble the conditioning dictionary for ``avatar_id``.

        ``coeff_3dmm`` is taken from ``motion_3dmm`` and falls back to
        ``full_3dmm``.  ``crop_preview`` comes from the embedding entry when
        present, otherwise from the decoded preview blob (which may be
        ``None``).  Tensor values are detached and moved to the CPU.

        Raises:
            KeyError: if the avatar has no embedding entry.
            ValueError: if neither ``motion_3dmm`` nor ``full_3dmm`` is set.
        """
        if avatar_id not in self.embeddings:
            raise KeyError(f"Avatar not found: {avatar_id}")

        meta = self.index.get(avatar_id, {}) or {}
        emb = self.embeddings[avatar_id] or {}

        coeff = emb.get("motion_3dmm")
        if coeff is None:
            coeff = emb.get("full_3dmm")
        if coeff is None:
            raise ValueError(f"Avatar '{avatar_id}' is missing motion_3dmm/full_3dmm")

        crop_preview = emb.get("crop_preview")
        if crop_preview is None:
            crop_preview = self._preview_to_numpy(avatar_id)
        elif torch.is_tensor(crop_preview):
            crop_preview = crop_preview.detach().cpu()
        elif not isinstance(crop_preview, np.ndarray):
            # Lists and other array-likes are normalized to numpy.
            crop_preview = np.asarray(crop_preview)

        return {
            "avatar_id": avatar_id,
            "gender": meta.get("gender"),
            "style": meta.get("style"),
            "coeff_3dmm": self._to_cpu(coeff),
            "motion_3dmm": self._to_cpu(emb.get("motion_3dmm")),
            "full_3dmm": self._to_cpu(emb.get("full_3dmm")),
            "crop_info": emb.get("crop_info"),
            "crop_preview": crop_preview,
        }
|
|
|
|
| |
| |
| |
|
|
class BriaBackgroundRemover:
    """
    Best-effort loader for the packed briaaiRMBG-2.0 directory.

    It searches for a likely inference script and tries callable or CLI-based
    execution patterns. If the local folder layout differs, the search list
    below is the only part that usually needs adjustment.
    """

    # File names tried first, in priority order, when looking for the script.
    _PREFERRED_ENTRYPOINTS = (
        "inference.py",
        "predict.py",
        "app.py",
        "main.py",
        "run.py",
    )

    # Module-level function names tried, in priority order.
    _CALLABLE_NAMES = (
        "remove_background",
        "predict_image",
        "predict",
        "run",
        "inference",
        "main",
    )

    def __init__(self, root: Path):
        """Remember ``root`` and look for an entrypoint script beneath it."""
        self.root = root
        self.entrypoint = self._discover_entrypoint()

    def _discover_entrypoint(self) -> Optional[Path]:
        """Return the most likely inference script under ``root``, or ``None``.

        Preferred file names win; otherwise the first ``.py`` file whose path
        mentions ``bria``, ``rmbg`` or ``background`` is used.  Candidates are
        sorted so the choice does not depend on directory listing order.
        """
        if not self.root.exists():
            return None

        for name in self._PREFERRED_ENTRYPOINTS:
            hits = sorted(self.root.rglob(name))
            if hits:
                return hits[0]

        # NOTE(review): the keyword test runs on the full path, so a matching
        # ``root`` directory name makes every .py file qualify -- confirm this
        # is intended.
        for candidate in sorted(self.root.rglob("*.py")):
            lower = str(candidate).lower()
            if "bria" in lower or "rmbg" in lower or "background" in lower:
                return candidate
        return None

    def _import_module_from_path(self, py_file: Path):
        """Import ``py_file`` as a uniquely named module and return it.

        Importing executes the file's top-level code.

        Raises:
            RuntimeError: if no import spec/loader can be created.
        """
        module_name = f"packed_bria_{sha256_bytes(str(py_file).encode('utf-8'))[:12]}"
        spec = importlib.util.spec_from_file_location(module_name, str(py_file))
        if spec is None or spec.loader is None:
            raise RuntimeError(f"Could not import module from {py_file}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

    def _store_result(self, result, output_path: Path) -> bool:
        """Persist one callable's return value to ``output_path``.

        Returns ``True`` when ``output_path`` holds the result afterwards.
        """
        if isinstance(result, (str, os.PathLike)):
            result_path = Path(result)
            if not result_path.exists():
                return False
            # A callable that wrote ``output_path`` itself and returns that
            # path must not be copied onto itself (SameFileError).
            if result_path.resolve() != output_path.resolve():
                shutil.copy2(result_path, output_path)
            return True
        if result is None:
            # The callable may have written the output file as a side effect.
            return output_path.exists()
        if torch.is_tensor(result):
            arr = result.detach().cpu().numpy()
            if arr.ndim == 3 and arr.shape[-1] in (3, 4):
                Image.fromarray(arr.astype(np.uint8)).save(output_path)
                return True
            return False
        if isinstance(result, Image.Image):
            result.save(output_path)
            return True
        return False

    def _call_module_callable(self, module, image_path: Path, output_path: Path) -> bool:
        """Try known function names on ``module`` with several call signatures.

        For each callable the signatures ``(input, output)``, ``(input,)``,
        ``(PIL image,)`` and ``()`` are attempted in that order; any exception
        moves on to the next attempt.

        Returns:
            ``True`` as soon as one attempt produced ``output_path``.
        """
        callables = [getattr(module, name, None) for name in self._CALLABLE_NAMES]
        callables = [fn for fn in callables if callable(fn)]

        for fn in callables:
            for signature in ("paths", "path", "image", "none"):
                try:
                    if signature == "paths":
                        stored = self._store_result(
                            fn(str(image_path), str(output_path)), output_path
                        )
                    elif signature == "path":
                        stored = self._store_result(fn(str(image_path)), output_path)
                    elif signature == "image":
                        # Open lazily and close deterministically; the result
                        # is stored while the source image is still open.
                        with Image.open(image_path) as source:
                            stored = self._store_result(fn(source), output_path)
                    else:
                        stored = self._store_result(fn(), output_path)
                except Exception:
                    # Wrong signature or a failing callable: try the next one.
                    continue
                if stored:
                    return True
        return False

    def _call_cli_with_patterns(self, image_path: Path, output_path: Path) -> bool:
        """Run the entrypoint as a script with several common argument layouts.

        Returns:
            ``True`` once a run exits with code 0 and ``output_path`` exists.
        """
        if self.entrypoint is None:
            return False

        script = str(self.entrypoint)
        src = str(image_path)
        dst = str(output_path)
        cmd_patterns = [
            [script, src, dst],
            [script, "--input", src, "--output", dst],
            [script, "--image", src, "--output", dst],
            [script, "--input_path", src, "--output_path", dst],
            [script, "-i", src, "-o", dst],
        ]

        for args in cmd_patterns:
            try:
                proc = subprocess.run(
                    [sys.executable, *args],
                    cwd=str(self.root),
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                    check=False,
                )
            except Exception:
                continue
            if proc.returncode == 0 and output_path.exists():
                return True
        return False

    def remove_background(self, image_path: Path, output_dir: Path) -> Path:
        """Remove the background of ``image_path`` and return the result path.

        The result is written to ``output_dir / "<stem>_rmbg.png"``.  The
        entrypoint is first imported and called in-process; if that fails it
        is run as a command-line script.

        Raises:
            RuntimeError: if no entrypoint was found or every strategy failed.
        """
        if self.entrypoint is None:
            raise RuntimeError(
                f"No usable background-removal entrypoint found under {self.root}."
            )

        ensure_dir(output_dir)
        output_path = output_dir / f"{image_path.stem}_rmbg.png"

        # Success is detected through the existence of ``output_path``, so a
        # leftover file from an earlier run must not be mistaken for a result.
        try:
            output_path.unlink()
        except FileNotFoundError:
            pass

        try:
            module = self._import_module_from_path(self.entrypoint)
            if self._call_module_callable(module, image_path, output_path):
                return output_path
        except Exception:
            # Import failures are expected for script-style entrypoints; fall
            # through to the CLI strategy.
            pass

        if self._call_cli_with_patterns(image_path, output_path):
            return output_path

        raise RuntimeError(
            f"Could not execute background removal with entrypoint {self.entrypoint}. "
            f"You may need to adjust the call patterns in BriaBackgroundRemover."
        )
|
|
|
|
| |
| |
| |
|
|
class SadTalkerRunner:
    """Run SadTalker preprocessing and talking-head video generation.

    The SadTalker modules are imported inside ``_load_modules`` rather than at
    module import time.  Models are instantiated by ``_ensure_models`` at the
    start of ``generate`` and ``extract_embeddings``; ``generate`` releases
    them again before it returns or raises.
    """

    # Checkpoint file names probed, in priority order, in ``checkpoint_path``.
    _CHECKPOINT_CANDIDATES = (
        "SadTalker_V0.0.2_512.safetensors",
        "SadTalker_V0.0.2_256.safetensors",
        "SadTalker_V0.0.2_512.pth",
        "SadTalker_V0.0.2_256.pth",
    )

    # ref_info -> (reference drives pose, reference drives eye blink)
    _REF_INFO_USAGE = {
        "pose": (True, False),
        "blink": (False, True),
        "pose+blink": (True, True),
        "all": (False, False),
    }

    def __init__(self, checkpoint_path: str, config_path: str, device: str = "cpu"):
        """Store the configuration and import the SadTalker modules.

        Args:
            checkpoint_path: Directory searched for the SadTalker checkpoint.
            config_path: Passed through to SadTalker's ``init_path``.
            device: Device string handed to the SadTalker models.
        """
        self.checkpoint_path = checkpoint_path
        self.config_path = config_path
        self.device = device
        self._mods_loaded = False
        self._load_modules()

    def _load_modules(self):
        """Import the SadTalker entry points once and keep them on ``self``."""
        if self._mods_loaded:
            return

        from SadTalker.src.facerender.pirender_animate import AnimateFromCoeff_PIRender
        from SadTalker.src.utils.preprocess import CropAndExtract
        from SadTalker.src.test_audio2coeff import Audio2Coeff
        from SadTalker.src.facerender.animate import AnimateFromCoeff
        from SadTalker.src.generate_batch import get_data
        from SadTalker.src.generate_facerender_batch import get_facerender_data
        from SadTalker.src.utils.init_path import init_path

        self.AnimateFromCoeff_PIRender = AnimateFromCoeff_PIRender
        self.CropAndExtract = CropAndExtract
        self.Audio2Coeff = Audio2Coeff
        self.AnimateFromCoeff = AnimateFromCoeff
        self.get_data = get_data
        self.get_facerender_data = get_facerender_data
        self.init_path = init_path
        self._mods_loaded = True

    @staticmethod
    def _mp3_to_wav(mp3_filename: str, wav_filename: str, frame_rate: int):
        """Re-encode an audio file as WAV at ``frame_rate`` Hz."""
        mp3_file = AudioSegment.from_file(file=mp3_filename)
        mp3_file.set_frame_rate(frame_rate).export(wav_filename, format="wav")

    def _to_numpy(self, x):
        """Convert tensors and array-likes to numpy; ``None`` stays ``None``."""
        if x is None:
            return None
        if isinstance(x, np.ndarray):
            return x
        if torch.is_tensor(x):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    @staticmethod
    def _first_present(bundle, keys):
        """Return the first non-``None`` value stored under ``keys``, else ``None``."""
        for key in keys:
            value = bundle.get(key, None)
            if value is not None:
                return value
        return None

    def _save_png_from_bundle(self, bundle, out_path):
        """Write the first usable image field of ``bundle`` to ``out_path``.

        Fields are tried in the order ``crop_preview``, ``aligned_face``,
        ``image``, ``png``.  Only 3-D arrays with 1, 3 or 4 trailing channels
        are accepted; the file is saved as RGB.

        Raises:
            ValueError: if no field holds a usable image.
        """
        for key in ("crop_preview", "aligned_face", "image", "png"):
            if key not in bundle or bundle[key] is None:
                continue
            arr = self._to_numpy(bundle[key])
            if arr.ndim != 3 or arr.shape[-1] not in (1, 3, 4):
                continue
            if arr.dtype != np.uint8:
                # NOTE(review): assumes values are already on a 0-255 scale;
                # 0-1 float images would come out almost black -- confirm.
                arr = np.clip(arr, 0, 255).astype(np.uint8)

            channels = arr.shape[-1]
            if channels == 1:
                # A (H, W, 1) array cannot be read as RGB; drop the channel
                # axis and expand the grayscale image instead.
                img = Image.fromarray(arr[..., 0]).convert("RGB")
            elif channels == 4:
                img = Image.fromarray(arr).convert("RGB")
            else:
                img = Image.fromarray(arr)
            img.save(out_path)
            return out_path

        raise ValueError(
            "Avatar conditioning bundle needs at least one image-like field such as crop_preview or aligned_face."
        )

    def _save_mat_from_avatar_bundle(self, bundle, out_path):
        """Write the avatar coefficients of ``bundle`` to a ``.mat`` file.

        ``coeff_3dmm`` is taken from the first of ``coeff_3dmm``,
        ``motion_3dmm``, ``full_3dmm`` that is set; ``full_3dmm`` is stored
        additionally when present.

        Raises:
            ValueError: if none of the three keys is set.
        """
        coeff_3dmm = self._first_present(bundle, ("coeff_3dmm", "motion_3dmm", "full_3dmm"))
        if coeff_3dmm is None:
            raise ValueError("Avatar bundle must contain coeff_3dmm, motion_3dmm, or full_3dmm.")

        mat_dict = {"coeff_3dmm": self._to_numpy(coeff_3dmm)}
        full_3dmm = bundle.get("full_3dmm", None)
        if full_3dmm is not None:
            mat_dict["full_3dmm"] = self._to_numpy(full_3dmm)

        savemat(out_path, mat_dict)
        return out_path

    def _save_mat_from_motion_bundle(self, bundle, out_path):
        """Write the motion coefficients of ``bundle`` to a ``.mat`` file.

        ``coeff_3dmm`` is taken from the first of ``motion_3dmm``,
        ``coeff_3dmm``, ``full_3dmm_seq``, ``full_3dmm`` that is set.
        ``full_3dmm`` is stored from ``full_3dmm`` or, failing that, from
        ``full_3dmm_seq`` (its first element when it has 3+ dimensions).

        Raises:
            ValueError: if none of the four keys is set.
        """
        motion = self._first_present(
            bundle, ("motion_3dmm", "coeff_3dmm", "full_3dmm_seq", "full_3dmm")
        )
        if motion is None:
            raise ValueError(
                "Motion bundle must contain motion_3dmm, coeff_3dmm, full_3dmm_seq, or full_3dmm."
            )

        mat_dict = {"coeff_3dmm": self._to_numpy(motion)}

        if "full_3dmm" in bundle and bundle["full_3dmm"] is not None:
            mat_dict["full_3dmm"] = self._to_numpy(bundle["full_3dmm"])
        elif "full_3dmm_seq" in bundle and bundle["full_3dmm_seq"] is not None:
            seq = self._to_numpy(bundle["full_3dmm_seq"])
            mat_dict["full_3dmm"] = seq[0] if seq.ndim >= 3 else seq

        savemat(out_path, mat_dict)
        return out_path

    def _bundle_from_preprocess_output(
        self,
        coeff_path,
        crop_pic_path,
        crop_info,
    ):
        """Build a conditioning bundle from SadTalker preprocessing outputs.

        The bundle contains the variables of the coefficient ``.mat`` file
        (when readable), the given paths and crop info, a ``crop_preview``
        array (when the cropped picture is readable) and the aliases
        ``coeff_3dmm``/``motion_3dmm``/``full_3dmm`` filled from each other
        where missing.
        """
        bundle = {}

        # Coefficients stored by the preprocessor (best-effort).
        if coeff_path is not None and os.path.isfile(coeff_path):
            try:
                raw = loadmat(coeff_path)
                for key, value in raw.items():
                    # Skip loadmat's bookkeeping entries such as __header__.
                    if not key.startswith("__"):
                        bundle[key] = value
            except Exception:
                pass

        # Plain pass-through values.
        if coeff_path is not None:
            bundle["coeff_path"] = str(coeff_path)
        if crop_pic_path is not None:
            bundle["crop_pic_path"] = str(crop_pic_path)
        if crop_info is not None:
            bundle["crop_info"] = crop_info

        # Cropped picture as an array (best-effort).
        try:
            if crop_pic_path is not None and os.path.isfile(crop_pic_path):
                with Image.open(crop_pic_path) as im:
                    bundle["crop_preview"] = pil_to_numpy_rgb(im)
        except Exception:
            pass

        # Make both coefficient aliases available.
        if "coeff_3dmm" in bundle and "motion_3dmm" not in bundle:
            bundle["motion_3dmm"] = bundle["coeff_3dmm"]
        if "motion_3dmm" in bundle and "coeff_3dmm" not in bundle:
            bundle["coeff_3dmm"] = bundle["motion_3dmm"]

        if "full_3dmm" not in bundle:
            if "full_3dmm_seq" in bundle:
                seq = bundle["full_3dmm_seq"]
                if hasattr(seq, "ndim") and seq.ndim >= 3:
                    bundle["full_3dmm"] = seq[0]
                else:
                    bundle["full_3dmm"] = seq
            elif "motion_3dmm" in bundle:
                bundle["full_3dmm"] = bundle["motion_3dmm"]

        return bundle

    def extract_embeddings(
        self,
        input_path,
        crop_or_resize: str = "crop",
        pic_size: int = 256,
        save_dir: Optional[str] = None,
    ):
        """
        Public preprocessing helper.

        Accepts either a source image or a reference video, runs the packed
        SadTalker preprocessing, and returns the extracted conditioning bundle.

        Args:
            input_path: Image or video file; the file suffix decides which.
            crop_or_resize: Preprocess mode passed to SadTalker.
            pic_size: Picture size; only passed on for image inputs.
            save_dir: Working directory; a temporary one is created if omitted.

        Raises:
            FileNotFoundError: if ``input_path`` does not exist.
        """
        self._load_modules()
        self._ensure_models(size=pic_size, preprocess=crop_or_resize, facerender="facevid2vid")

        input_path = Path(input_path)
        if not input_path.exists():
            raise FileNotFoundError(str(input_path))

        if save_dir is None:
            save_dir = tempfile.mkdtemp(prefix="packedavatar_embeddings_")
        else:
            ensure_dir(Path(save_dir))

        work_dir = Path(save_dir)
        video_exts = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".flv", ".wmv", ".m4v", ".gif"}

        if input_path.suffix.lower() in video_exts:
            frame_dir = work_dir / f"{input_path.stem}_frames"
            ensure_dir(frame_dir)
            coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
                str(input_path),
                str(frame_dir),
                crop_or_resize,
                source_image_flag=False,
            )
        else:
            staged = work_dir / input_path.name
            # When save_dir is the input's own directory the file is already
            # in place; copying it onto itself would raise SameFileError.
            if not (staged.exists() and staged.samefile(input_path)):
                shutil.copy2(input_path, staged)

            first_frame_dir = work_dir / "first_frame_dir"
            ensure_dir(first_frame_dir)
            coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
                str(staged),
                str(first_frame_dir),
                crop_or_resize,
                True,
                pic_size,
            )

        return self._bundle_from_preprocess_output(coeff_path, crop_pic_path, crop_info)

    def ExtractEmbeddings(
        self,
        input_path,
        crop_or_resize: str = "crop",
        pic_size: int = 256,
        save_dir: Optional[str] = None,
    ):
        """Alias of ``extract_embeddings``."""
        return self.extract_embeddings(
            input_path=input_path,
            crop_or_resize=crop_or_resize,
            pic_size=pic_size,
            save_dir=save_dir,
        )

    def _materialize_avatar_condition(self, avatar_condition, save_dir):
        """Turn an avatar condition into files usable by SadTalker.

        Existing ``coeff_path``/``crop_pic_path`` files are reused; missing
        ones are written into ``save_dir``.

        Returns:
            ``(coeff_path, crop_pic_path, crop_info)``, or three ``None``
            values when no bundle was given.
        """
        bundle = safe_load_bundle(avatar_condition)
        if bundle is None:
            return None, None, None

        coeff_path = bundle.get("coeff_path", None)
        crop_pic_path = bundle.get("crop_pic_path", None)
        crop_info = bundle.get("crop_info", None)

        if coeff_path is None or not os.path.isfile(coeff_path):
            coeff_path = os.path.join(save_dir, "avatar_condition.mat")
            self._save_mat_from_avatar_bundle(bundle, coeff_path)

        if crop_pic_path is None or not os.path.isfile(crop_pic_path):
            crop_pic_path = os.path.join(save_dir, "avatar_condition.png")
            self._save_png_from_bundle(bundle, crop_pic_path)

        return coeff_path, crop_pic_path, crop_info

    def _materialize_motion_condition(self, motion_condition, save_dir):
        """Return a coefficient ``.mat`` path for a motion condition, or ``None``.

        An existing ``coeff_path`` file is reused; otherwise the coefficients
        are written to ``save_dir / "motion_condition.mat"``.
        """
        bundle = safe_load_bundle(motion_condition)
        if bundle is None:
            return None

        coeff_path = bundle.get("coeff_path", None)
        if coeff_path is not None and os.path.isfile(coeff_path):
            return coeff_path

        coeff_path = os.path.join(save_dir, "motion_condition.mat")
        self._save_mat_from_motion_bundle(bundle, coeff_path)
        return coeff_path

    def _resolve_checkpoint(self):
        """Return the first known checkpoint file found in ``checkpoint_path``.

        Raises:
            FileNotFoundError: if none of the candidates exists.
        """
        for name in self._CHECKPOINT_CANDIDATES:
            p = Path(self.checkpoint_path) / name
            if p.exists():
                return str(p)

        raise FileNotFoundError(
            f"No SadTalker checkpoint found in {self.checkpoint_path}"
        )

    def _ensure_models(self, size: int, preprocess: str, facerender: str):
        """Instantiate the SadTalker models for the given settings.

        The PIRender renderer is used unless ``facerender`` is
        ``"facevid2vid"`` and the device is not ``"mps"``.
        """
        self.sadtalker_paths = self.init_path(
            self.checkpoint_path,
            self.config_path,
            size,
            False,
            preprocess,
        )

        # Override whatever init_path chose with the checkpoint found on disk.
        self.sadtalker_paths["checkpoint"] = self._resolve_checkpoint()

        print("\n[PackedAvatar] Using checkpoint:")
        print(self.sadtalker_paths["checkpoint"])

        self.audio_to_coeff = self.Audio2Coeff(self.sadtalker_paths, self.device)
        self.preprocess_model = self.CropAndExtract(self.sadtalker_paths, self.device)

        if facerender == "facevid2vid" and self.device != "mps":
            self.animate_from_coeff = self.AnimateFromCoeff(self.sadtalker_paths, self.device)
        else:
            self.animate_from_coeff = self.AnimateFromCoeff_PIRender(
                self.sadtalker_paths, self.device
            )

    def _release_models(self):
        """Drop the model attributes and free cached GPU memory."""
        for attr in ("preprocess_model", "audio_to_coeff", "animate_from_coeff"):
            if hasattr(self, attr):
                delattr(self, attr)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        import gc

        gc.collect()

    def _stage_audio(
        self,
        driven_audio,
        input_dir,
        use_idle_mode,
        length_of_audio,
        use_ref_video,
        ref_info,
    ):
        """Place the driving audio inside ``input_dir`` and return its path.

        An existing ``driven_audio`` file is copied (``.mp3`` files are
        converted to 16 kHz WAV).  In idle mode a silent WAV of
        ``length_of_audio`` seconds is generated.  Returns ``None`` when the
        audio is expected to come from the reference video instead.

        Raises:
            AssertionError: if no audio source was selected at all.
        """
        if driven_audio is not None and os.path.isfile(driven_audio):
            audio_name = os.path.basename(driven_audio)
            audio_path = os.path.join(input_dir, audio_name)
            if audio_name.lower().endswith(".mp3"):
                wav_path = os.path.splitext(audio_path)[0] + ".wav"
                self._mp3_to_wav(driven_audio, wav_path, 16000)
                return wav_path
            shutil.copy2(driven_audio, audio_path)
            return audio_path

        if use_idle_mode:
            audio_path = os.path.join(input_dir, f"idlemode_{str(length_of_audio)}.wav")
            silence = AudioSegment.silent(duration=1000 * length_of_audio)
            silence.export(audio_path, format="wav")
            return audio_path

        # Raised explicitly (instead of ``assert``) so the check also runs
        # under ``python -O``; the exception type is kept for compatibility.
        if not (use_ref_video is True and ref_info == "all"):
            raise AssertionError(
                "Either driven_audio, use_idle_mode, or use_ref_video/ref_info='all' must be provided."
            )
        return None

    @staticmethod
    def _extract_reference_audio(ref_video, save_dir):
        """Extract the audio track of ``ref_video`` into ``save_dir`` with ffmpeg.

        Returns the target WAV path.  ffmpeg failures are reported on stdout
        but not raised.
        """
        audio_path = os.path.join(save_dir, os.path.basename(ref_video) + ".wav")
        # Argument list instead of a shell string: paths containing quotes or
        # shell metacharacters can neither break nor inject into the command.
        cmd = [
            "ffmpeg",
            "-y",
            "-hide_banner",
            "-loglevel",
            "error",
            "-i",
            str(ref_video),
            str(audio_path),
        ]
        try:
            proc = subprocess.run(cmd, check=False)
        except OSError as exc:
            print(f"[PackedAvatar] Could not run ffmpeg: {exc}")
        else:
            if proc.returncode != 0:
                print(f"[PackedAvatar] ffmpeg exited with code {proc.returncode}")
        return audio_path

    def _resolve_reference_coeffs(
        self,
        motion_condition,
        use_ref_video,
        ref_video,
        ref_info,
        preprocess,
        save_dir,
    ):
        """Work out which reference coefficients drive pose and eye blink.

        A ``motion_condition`` takes precedence and drives both.  Otherwise a
        reference video is preprocessed and used according to ``ref_info``.

        Returns:
            ``(ref_video_coeff_path, ref_pose_coeff_path,
            ref_eyeblink_coeff_path)``; each may be ``None``.

        Raises:
            ValueError: if ``ref_info`` is not a known mode.
        """
        if motion_condition is not None:
            coeff = self._materialize_motion_condition(motion_condition, save_dir)
            return coeff, coeff, coeff

        if not (use_ref_video and ref_video is not None):
            return None, None, None

        ref_video_videoname = os.path.splitext(os.path.split(ref_video)[-1])[0]
        ref_video_frame_dir = os.path.join(save_dir, ref_video_videoname)
        os.makedirs(ref_video_frame_dir, exist_ok=True)

        print("3DMM Extraction for the reference video providing pose")
        coeff, _, _ = self.preprocess_model.generate(
            ref_video,
            ref_video_frame_dir,
            preprocess,
            source_image_flag=False,
        )

        if ref_info not in self._REF_INFO_USAGE:
            raise ValueError("error in ref_info")
        for_pose, for_blink = self._REF_INFO_USAGE[ref_info]
        return coeff, (coeff if for_pose else None), (coeff if for_blink else None)

    def generate(
        self,
        source_image=None,
        driven_audio=None,
        preprocess="crop",
        still_mode=False,
        use_enhancer=False,
        batch_size=1,
        size=256,
        pose_style=0,
        facerender="facevid2vid",
        exp_scale=1.0,
        use_ref_video=False,
        ref_video=None,
        ref_info=None,
        use_idle_mode=False,
        length_of_audio=0,
        use_blink=True,
        result_dir="./results/",
        avatar_condition=None,
        motion_condition=None,
    ):
        """Render a talking-head video and return where the results live.

        The face comes from ``avatar_condition`` (a precomputed bundle or a
        path to one) or, failing that, from ``source_image``.  The audio comes
        from ``driven_audio``, from generated silence (``use_idle_mode``) or
        from ``ref_video`` when ``use_ref_video`` is set with
        ``ref_info="all"``.  ``motion_condition`` or ``ref_video`` can
        additionally drive pose and/or eye blink.

        All intermediate and final files are written to a fresh UUID-named
        sub-directory of ``result_dir``.  The models are released when the
        call finishes, whether it succeeded or not.

        Returns:
            ``(return_path, audio_path, save_dir)`` where ``return_path`` is
            the value returned by the renderer.

        Raises:
            AssertionError: if no audio source was selected.
            ValueError: if no audio could be resolved, ``source_image`` is
                missing without ``avatar_condition``, or ``ref_info`` is
                unknown.
            AttributeError: if no face coefficients could be obtained.
        """
        try:
            self._load_modules()
            self._ensure_models(size=size, preprocess=preprocess, facerender=facerender)

            time_tag = str(uuid.uuid4())
            save_dir = os.path.join(result_dir, time_tag)
            os.makedirs(save_dir, exist_ok=True)

            input_dir = os.path.join(save_dir, "input")
            os.makedirs(input_dir, exist_ok=True)

            # --- Audio ------------------------------------------------
            audio_path = self._stage_audio(
                driven_audio,
                input_dir,
                use_idle_mode,
                length_of_audio,
                use_ref_video,
                ref_info,
            )
            if use_ref_video and ref_info == "all" and ref_video is not None:
                audio_path = self._extract_reference_audio(ref_video, save_dir)
            if audio_path is None:
                # Previously surfaced later as an UnboundLocalError.
                raise ValueError(
                    "ref_video is required when the audio is taken from the reference video "
                    "(use_ref_video=True, ref_info='all')."
                )

            # --- Source face ------------------------------------------
            if avatar_condition is not None:
                first_coeff_path, crop_pic_path, crop_info = self._materialize_avatar_condition(
                    avatar_condition, save_dir
                )
                if first_coeff_path is None:
                    raise AttributeError("Invalid avatar_condition bundle.")
                pic_path = crop_pic_path
            else:
                if source_image is None:
                    raise ValueError("source_image is required when avatar_condition is not provided.")

                pic_path = os.path.join(input_dir, os.path.basename(source_image))
                shutil.copy2(source_image, pic_path)

                first_frame_dir = os.path.join(save_dir, "first_frame_dir")
                os.makedirs(first_frame_dir, exist_ok=True)

                first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
                    pic_path,
                    first_frame_dir,
                    preprocess,
                    True,
                    size,
                )
                if first_coeff_path is None:
                    raise AttributeError("No face is detected")

            # --- Reference motion -------------------------------------
            (
                ref_video_coeff_path,
                ref_pose_coeff_path,
                ref_eyeblink_coeff_path,
            ) = self._resolve_reference_coeffs(
                motion_condition,
                use_ref_video,
                ref_video,
                ref_info,
                preprocess,
                save_dir,
            )

            # --- Motion coefficients ----------------------------------
            if use_ref_video and ref_info == "all" and ref_video_coeff_path is not None:
                # The reference provides the complete motion; skip audio2coeff.
                coeff_path = ref_video_coeff_path
            else:
                batch = self.get_data(
                    first_coeff_path,
                    audio_path,
                    self.device,
                    ref_eyeblink_coeff_path=ref_eyeblink_coeff_path,
                    still=still_mode,
                    idlemode=use_idle_mode,
                    length_of_audio=length_of_audio,
                    use_blink=use_blink,
                )
                coeff_path = self.audio_to_coeff.generate(
                    batch,
                    save_dir,
                    pose_style,
                    ref_pose_coeff_path,
                )

            # --- Rendering --------------------------------------------
            data = self.get_facerender_data(
                coeff_path,
                crop_pic_path,
                first_coeff_path,
                audio_path,
                batch_size,
                still_mode=still_mode,
                preprocess=preprocess,
                size=size,
                expression_scale=exp_scale,
                facemodel=facerender,
            )

            # With an avatar_condition ``pic_path`` already equals
            # ``crop_pic_path``, so it is the right picture in both cases.
            return_path = self.animate_from_coeff.generate(
                data,
                save_dir,
                pic_path,
                crop_info,
                enhancer="gfpgan" if use_enhancer else None,
                preprocess=preprocess,
                img_size=size,
            )

            video_name = data.get("video_name", "output")
            print(f"The generated video is named {video_name} in {save_dir}")

            return return_path, audio_path, save_dir
        finally:
            # Release the models on failure as well, so a crashed run does
            # not keep them (and their GPU memory) alive.
            self._release_models()
|
|
|
|
| |
| |
| |
|
|
class PackedAvatar:
    """Runtime facade over a single-file ``PackedAvatar.pt`` bundle.

    On construction the bundle is loaded with ``torch.load``, its two zip
    assets (``checkpoints_zip`` and ``sadtalker_zip``) are extracted into a
    per-bundle cache directory, the extraction root is put on ``sys.path``,
    and the avatar bank plus background remover are instantiated.
    ``generate`` then drives a ``SadTalkerRunner`` and, optionally, a
    Wav2Lip post-processing pass.
    """

    def __init__(
        self,
        packed_pt_path: Optional[str] = None,
        cache_dir: Optional[str] = None,
        device: Optional[str] = None,
    ):
        """Load the bundle and prepare the on-disk runtime.

        Args:
            packed_pt_path: Path to the packed bundle. Defaults to
                ``checkpoints/PackedAvatar.pt`` next to this file.
            cache_dir: Directory that receives the extracted runtime.
                Defaults to ``<system temp dir>/PackedAvatarCache``.
            device: Device string handed to the runner. When omitted it is
                chosen by ``_default_device``.

        Raises:
            FileNotFoundError: If the bundle file does not exist.
        """
        self.packed_pt_path = Path(
            packed_pt_path or (Path(__file__).resolve().parent / "checkpoints" / "PackedAvatar.pt")
        )
        if not self.packed_pt_path.exists():
            raise FileNotFoundError(f"Packed bundle not found: {self.packed_pt_path}")

        self.device = device or self._default_device()

        self.cache_dir = Path(cache_dir) if cache_dir else Path(tempfile.gettempdir()) / "PackedAvatarCache"
        ensure_dir(self.cache_dir)

        self.bundle = self._load_bundle(self.packed_pt_path)
        self.manifest = self.bundle.get("manifest", {}) or {}

        # Order matters: extraction defines runtime_root / extracted_root /
        # checkpoints_dir / sadtalker_dir, which everything below reads.
        self._extract_and_mount()
        self._mount_python_path()

        self.avatar_bank = self._load_avatar_bank()
        self.bria_root = self.extracted_root / "checkpoints" / "briaaiRMBG-2.0"
        self.background_remover = BriaBackgroundRemover(self.bria_root)
        self._runner_cache: Dict[Tuple[int, str, str], SadTalkerRunner] = {}

    @staticmethod
    def _default_device() -> str:
        """Pick ``"cuda"``, then ``"mps"``, then ``"cpu"``.

        ``"mps"` is only returned on macOS when the installed torch build
        actually reports the MPS backend as available; merely running on
        Darwin is not enough (e.g. older torch builds or unsupported hosts).
        """
        if torch.cuda.is_available():
            return "cuda"
        mps_backend = getattr(torch.backends, "mps", None)
        if platform.system() == "Darwin" and mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    @staticmethod
    def _load_bundle(path: Path) -> Dict[str, Any]:
        """Deserialize the packed bundle onto the CPU.

        Raises:
            ValueError: If the file does not contain a dictionary.
        """
        # SECURITY NOTE: weights_only=False performs full pickle
        # deserialization, which can execute arbitrary code. Only load
        # bundles that come from a trusted source.
        bundle = torch.load(str(path), map_location="cpu", weights_only=False)
        if not isinstance(bundle, dict):
            raise ValueError("PackedAvatar.pt did not contain a dictionary bundle.")
        return bundle

    def _asset_bytes(self, key: str) -> bytes:
        """Return the raw bytes of the bundle asset stored under ``key``.

        Raises:
            KeyError: If the bundle has no such asset.
        """
        asset = (self.bundle.get("assets") or {}).get(key)
        if asset is None:
            raise KeyError(f"Missing asset in bundle: {key}")
        return tensor_to_bytes(asset)

    def _archive_sha256(self, archive_name: str, default: Optional[str] = None) -> Optional[str]:
        """Look up ``manifest["archives"][archive_name]["sha256"]``.

        Missing (or ``None``) intermediate sections yield ``default``.
        """
        archives = self.manifest.get("archives") or {}
        entry = archives.get(archive_name) or {}
        return entry.get("sha256", default)

    def _bundle_id(self) -> str:
        """Derive a 16-hex-char id from the two archive hashes in the manifest.

        NOTE: when the manifest carries neither hash the seed is just ``":"``,
        so all such bundles map to the same id (and hence the same cache
        directory).
        """
        ck_hash = self._archive_sha256("checkpoints_zip", "")
        sd_hash = self._archive_sha256("sadtalker_zip", "")
        seed = f"{ck_hash}:{sd_hash}".encode("utf-8")
        return sha256_bytes(seed)[:16]

    def _mount_is_current(self, marker: Path, expected: Dict[str, Any]) -> bool:
        """Tell whether a previous extraction can be reused as-is.

        True only when the marker file parses to exactly ``expected`` AND
        both extracted folders are still present, so a cache that was
        partially cleaned up (or never validated) is not trusted.
        """
        try:
            existing = json.loads(marker.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # Missing/unreadable marker or invalid JSON/encoding: not reusable.
            return False
        return (
            existing == expected
            and self.checkpoints_dir.is_dir()
            and self.sadtalker_dir.is_dir()
        )

    def _extract_and_mount(self) -> None:
        """Extract both zip assets into the per-bundle cache directory.

        Sets ``runtime_root``, ``extracted_root``, ``checkpoints_dir`` and
        ``sadtalker_dir``. Extraction is skipped when ``mount.json`` matches
        this bundle and the extracted folders exist.

        Raises:
            RuntimeError: If either expected folder is missing after
                extraction. The marker is not written in that case, so the
                next run retries instead of trusting a broken cache.
        """
        bundle_id = self._bundle_id()
        runtime_root = self.cache_dir / f"packedavatar_{bundle_id}"
        self.runtime_root = runtime_root
        self.extracted_root = runtime_root / "extracted"
        self.checkpoints_dir = self.extracted_root / "checkpoints"
        self.sadtalker_dir = self.extracted_root / "SadTalker"
        ensure_dir(self.extracted_root)

        marker = runtime_root / "mount.json"
        expected = {
            "bundle_id": bundle_id,
            "checkpoints_sha256": self._archive_sha256("checkpoints_zip"),
            "sadtalker_sha256": self._archive_sha256("sadtalker_zip"),
        }

        if self._mount_is_current(marker, expected):
            return

        # Invalidate the marker before touching the tree so an interrupted
        # re-extraction can never be mistaken for a complete one.
        try:
            marker.unlink()
        except OSError:
            pass

        # Best-effort wipe of whatever a previous extraction left behind.
        for child in list(self.extracted_root.iterdir()):
            if child.is_dir() and not child.is_symlink():
                shutil.rmtree(child, ignore_errors=True)
            else:
                try:
                    child.unlink()
                except OSError:
                    pass

        checkpoints_zip = self._asset_bytes("checkpoints_zip")
        sadtalker_zip = self._asset_bytes("sadtalker_zip")

        extract_zip_bytes_to_dir(checkpoints_zip, self.extracted_root)
        extract_zip_bytes_to_dir(sadtalker_zip, self.extracted_root)

        if not self.checkpoints_dir.exists():
            raise RuntimeError(f"checkpoints folder missing after extraction: {self.checkpoints_dir}")
        if not self.sadtalker_dir.exists():
            raise RuntimeError(f"SadTalker folder missing after extraction: {self.sadtalker_dir}")

        # Only a validated extraction gets a marker.
        marker.write_text(json.dumps(expected, indent=2), encoding="utf-8")

    def _mount_python_path(self) -> None:
        """Put the extraction root at the front of ``sys.path`` (once)."""
        extracted = str(self.extracted_root)
        if extracted not in sys.path:
            sys.path.insert(0, extracted)

    def _load_avatar_bank(self) -> AvatarBankRuntime:
        """Load ``AvatarBank.pt`` from the extracted checkpoints folder.

        Defaults passed to the bank come from ``manifest["defaults"]``.

        Raises:
            FileNotFoundError: If ``AvatarBank.pt`` is not present.
        """
        bank_path = self.checkpoints_dir / "AvatarBank.pt"
        if not bank_path.exists():
            raise FileNotFoundError(f"AvatarBank.pt not found inside packed checkpoints: {bank_path}")
        manifest_defaults = self.manifest.get("defaults") or {}
        defaults = {
            "default_avatar": manifest_defaults.get("default_avatar", ""),
            "real_style_aliases": manifest_defaults.get("real_style_aliases", list(REAL_STYLE_ALIASES)),
        }
        return AvatarBankRuntime.load(bank_path, defaults=defaults)

    def _get_runner(self, size: int, preprocess: str, facerender: str) -> SadTalkerRunner:
        """Return the runner cached for ``(size, preprocess, facerender)``.

        A new ``SadTalkerRunner`` is created on a cache miss. The key parts
        only select the cache slot; they are not passed to the constructor.
        """
        key = (int(size), preprocess, facerender)
        runner = self._runner_cache.get(key)
        if runner is None:
            runner = SadTalkerRunner(
                checkpoint_path=str(self.checkpoints_dir),
                config_path=str(self.sadtalker_dir / "src" / "config"),
                device=self.device,
            )
            self._runner_cache[key] = runner
        return runner

    def extract_embeddings(
        self,
        input_path: str,
        crop_or_resize: str = "crop",
        pic_size: int = 256,
        save_dir: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Extract a conditioning bundle from a source image or reference video.

        The returned dictionary is the same kind of bundle the runtime uses
        internally for avatar conditioning and motion conditioning.
        """
        runner = self._get_runner(size=pic_size, preprocess=crop_or_resize, facerender="facevid2vid")
        return runner.extract_embeddings(
            input_path=input_path,
            crop_or_resize=crop_or_resize,
            pic_size=pic_size,
            save_dir=save_dir,
        )

    def ExtractEmbeddings(
        self,
        input_path: str,
        crop_or_resize: str = "crop",
        pic_size: int = 256,
        save_dir: Optional[str] = None,
    ) -> Dict[str, Any]:
        """CamelCase alias that forwards unchanged to ``extract_embeddings``."""
        return self.extract_embeddings(
            input_path=input_path,
            crop_or_resize=crop_or_resize,
            pic_size=pic_size,
            save_dir=save_dir,
        )

    def _resolve_avatar_condition_from_bank(self, avatar_id: Optional[str]) -> Dict[str, Any]:
        """Build an avatar condition from the bank, using its default id when none is given."""
        if avatar_id is None:
            avatar_id = self.avatar_bank.resolve_default_avatar_id()
        return self.avatar_bank.build_avatar_condition(avatar_id)

    def _normalize_avatar_condition(self, avatar_condition: Any) -> Optional[Dict[str, Any]]:
        """Load a caller-supplied avatar condition and ensure it has ``coeff_3dmm``.

        When ``coeff_3dmm`` is absent it is filled from ``motion_3dmm`` or,
        failing that, ``full_3dmm`` (whichever is present and not ``None``).
        Returns ``None`` when ``safe_load_bundle`` yields nothing.
        """
        bundle = safe_load_bundle(avatar_condition)
        if bundle is None:
            return None
        if "coeff_3dmm" not in bundle:
            for fallback_key in ("motion_3dmm", "full_3dmm"):
                if bundle.get(fallback_key) is not None:
                    bundle["coeff_3dmm"] = bundle[fallback_key]
                    break
        return bundle

    def _remove_background_if_requested(
        self,
        source_image: Optional[str],
        remove_background: bool,
        work_dir: Path,
    ) -> Optional[Path]:
        """Stage the source image in ``work_dir`` and prepare it for the runner.

        Args:
            source_image: Path of the input image, or ``None``.
            remove_background: Run the Bria background remover first.
            work_dir: Directory that receives the staged copy.

        Returns:
            Whatever ``prepare_image_for_sadtalker`` returns, or ``None``
            when no source image was given.

        Raises:
            FileNotFoundError: If ``source_image`` does not exist.
            RuntimeError: If background removal was requested and failed.
        """
        if source_image is None:
            return None

        src = Path(source_image)
        if not src.exists():
            raise FileNotFoundError(str(src))

        ensure_dir(work_dir)
        staged = work_dir / src.name
        # copy2 raises SameFileError when the image already lives in
        # work_dir under the same name; nothing to copy in that case.
        if not (staged.exists() and src.samefile(staged)):
            shutil.copy2(src, staged)

        if not remove_background:
            return prepare_image_for_sadtalker(staged)

        try:
            removed = self.background_remover.remove_background(staged, work_dir)
            return prepare_image_for_sadtalker(staged, removed)
        except Exception as e:
            raise RuntimeError(
                f"remove_background=True was requested, but Bria RMBG execution failed: {e}"
            ) from e

    def _run_wav2lip_gan(
        self,
        face_video: str,
        audio_path: str,
        save_dir: str,
        wav2lip_repo: Optional[str] = None,
    ) -> str:
        """Post-process ``face_video`` with Wav2Lip's ``inference.py``.

        The repo is searched in this order: ``wav2lip_repo`` (if given),
        ``<checkpoints>/Wav2Lip``, ``<SadTalker>/Wav2Lip`` and ``Wav2Lip``
        next to this file. If none contains ``inference.py`` a notice is
        printed and ``face_video`` is returned unchanged.

        Returns:
            ``<save_dir>/<face stem>_wav2lip_gan.mp4``, or ``face_video``
            when no Wav2Lip repo was found.

        Raises:
            FileNotFoundError: If ``wav2lip_gan.pth`` is not in the
                checkpoints folder.
            subprocess.CalledProcessError: If the inference script exits
                with a non-zero status.
        """
        wav2lip_checkpoint = self.checkpoints_dir / "wav2lip_gan.pth"
        if not wav2lip_checkpoint.is_file():
            raise FileNotFoundError(
                f"Could not find bundled Wav2Lip GAN checkpoint at: {wav2lip_checkpoint}"
            )

        candidate_repos: List[Path] = []
        if wav2lip_repo:
            candidate_repos.append(Path(wav2lip_repo))
        candidate_repos.extend([
            self.checkpoints_dir / "Wav2Lip",
            self.sadtalker_dir / "Wav2Lip",
            Path(__file__).resolve().parent / "Wav2Lip",
        ])

        repo = next(
            (candidate for candidate in candidate_repos if (candidate / "inference.py").is_file()),
            None,
        )
        if repo is None:
            print(
                "[PackedAvatar] Wav2Lip inference code was not found; "
                "skipping Wav2Lip post-processing and returning the SadTalker video."
            )
            return face_video

        out_video = os.path.join(save_dir, f"{Path(face_video).stem}_wav2lip_gan.mp4")

        # The child process runs with cwd=repo, so every path handed to it
        # must be absolute; otherwise relative inputs (e.g. "./results/...")
        # would be resolved against the Wav2Lip repo instead of our cwd.
        repo_abs = os.path.abspath(str(repo))
        cmd = [
            sys.executable,
            os.path.join(repo_abs, "inference.py"),
            "--checkpoint_path",
            os.path.abspath(str(wav2lip_checkpoint)),
            "--face",
            os.path.abspath(str(face_video)),
            "--audio",
            os.path.abspath(str(audio_path)),
            "--outfile",
            os.path.abspath(out_video),
        ]
        subprocess.run(cmd, cwd=repo_abs, check=True)
        # Returned in the caller's own (possibly relative) form.
        return out_video

    def list_avatars(self):
        """Forward to ``avatar_bank.list_avatars()``."""
        return self.avatar_bank.list_avatars()

    def query_avatars(self, *args, **kwargs):
        """Forward all arguments to ``avatar_bank.query``."""
        return self.avatar_bank.query(*args, **kwargs)

    def fuzzy_search_avatar(self, query, n=5, cutoff=0.5):
        """Forward to ``avatar_bank.fuzzy_search_id(query, n, cutoff)``."""
        return self.avatar_bank.fuzzy_search_id(query, n, cutoff)

    def get_avatar(self, avatar_id):
        """Forward to ``avatar_bank.get_avatar``."""
        return self.avatar_bank.get_avatar(avatar_id)

    def get_avatar_preview(self, avatar_id):
        """Forward to ``avatar_bank.get_preview``."""
        return self.avatar_bank.get_preview(avatar_id)

    def get_avatar_metadata(self, avatar_id):
        """Forward to ``avatar_bank.get_metadata``."""
        return self.avatar_bank.get_metadata(avatar_id)

    def delete_avatar(self, avatar_id):
        """Forward to ``avatar_bank.delete_avatar``."""
        self.avatar_bank.delete_avatar(avatar_id)

    def save_avatar_bank(self, path):
        """Forward to ``avatar_bank.save``."""
        self.avatar_bank.save(path)

    def generate(
        self,
        source_image: Optional[str] = None,
        driven_audio: Optional[str] = None,
        preprocess: str = "crop",
        still_mode: bool = False,
        use_enhancer: bool = False,
        batch_size: int = 1,
        size: int = 256,
        pose_style: int = 0,
        facerender: str = "facevid2vid",
        exp_scale: float = 1.0,
        use_ref_video: bool = False,
        ref_video: Optional[str] = None,
        ref_info: Optional[str] = None,
        use_idle_mode: bool = False,
        length_of_audio: int = 0,
        use_blink: bool = True,
        result_dir: str = "./results/",
        avatar_id: Optional[str] = None,
        avatar_condition: Optional[Any] = None,
        motion_condition: Optional[Any] = None,
        remove_background: bool = False,
        use_wav2lip: bool = False,
        wav2lip_repo: Optional[str] = None,
    ) -> str:
        """Render a talking-head video and return its path.

        The avatar source is chosen in this priority order:

        1. ``avatar_condition`` when it loads to a bundle; ``source_image``
           is then ignored.
        2. ``source_image``, staged under ``<runtime_root>/source_work`` and
           optionally background-removed.
        3. The avatar bank entry for ``avatar_id`` (or the bank default).

        All remaining keyword arguments are forwarded unchanged to the
        runner's ``generate``. When ``use_wav2lip`` is set, the runner's
        video is additionally passed through ``_run_wav2lip_gan``.
        """
        runner = self._get_runner(size=size, preprocess=preprocess, facerender=facerender)
        ensure_dir(Path(result_dir))

        resolved_avatar_condition = self._normalize_avatar_condition(avatar_condition)
        source_image_for_runner: Optional[str] = source_image

        if resolved_avatar_condition is not None:
            # An explicit condition wins over any source image.
            source_image_for_runner = None
        elif source_image_for_runner is None:
            resolved_avatar_condition = self._resolve_avatar_condition_from_bank(avatar_id)
        else:
            source_work_dir = self.runtime_root / "source_work"
            ensure_dir(source_work_dir)
            prepared = self._remove_background_if_requested(
                source_image_for_runner, remove_background, source_work_dir
            )
            if prepared is not None:
                source_image_for_runner = str(prepared)

        return_path, audio_path, save_dir = runner.generate(
            source_image=source_image_for_runner,
            driven_audio=driven_audio,
            preprocess=preprocess,
            still_mode=still_mode,
            use_enhancer=use_enhancer,
            batch_size=batch_size,
            size=size,
            pose_style=pose_style,
            facerender=facerender,
            exp_scale=exp_scale,
            use_ref_video=use_ref_video,
            ref_video=ref_video,
            ref_info=ref_info,
            use_idle_mode=use_idle_mode,
            length_of_audio=length_of_audio,
            use_blink=use_blink,
            result_dir=result_dir,
            avatar_condition=resolved_avatar_condition,
            motion_condition=motion_condition,
        )

        if use_wav2lip:
            return_path = self._run_wav2lip_gan(
                face_video=return_path,
                audio_path=audio_path,
                save_dir=save_dir,
                wav2lip_repo=wav2lip_repo,
            )

        return return_path


# Alternate public name for the same class.
PackedAvatarModel = PackedAvatar
|
|
|
|
| |
| |
| |
|
|
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for running the packed avatar bundle.

    Boolean features that default to on (``--use-wav2lip``, ``--use-blink``)
    each have a matching ``--no-...`` switch so they can be turned off; the
    positive flags are kept so existing invocations continue to work.
    """
    p = argparse.ArgumentParser(description="Run the packed avatar bundle.")

    # Bundle location and runtime environment.
    p.add_argument("--packed-pt", type=Path, default=Path(__file__).resolve().parent / "PackedAvatar.pt")
    p.add_argument("--cache-dir", type=Path, default=None)
    p.add_argument("--device", type=str, default=None)

    # Inputs and conditioning.
    p.add_argument("--source-image", type=Path, default=None)
    p.add_argument("--driven-audio", type=Path, default=Path("speech.wav"))
    p.add_argument("--avatar-id", type=str, default=None)
    p.add_argument("--avatar-condition", type=Path, default=None)
    p.add_argument("--motion-condition", type=Path, default=None)
    p.add_argument("--remove-background", action="store_true")

    # Wav2Lip post-processing is on by default; --no-wav2lip disables it.
    # (Previously the store_true flag with default=True could never be
    # switched off from the command line.)
    p.add_argument("--use-wav2lip", action="store_true", default=True)
    p.add_argument("--no-wav2lip", action="store_false", dest="use_wav2lip")
    p.add_argument("--wav2lip-repo", type=Path, default=None)

    # Rendering options.
    p.add_argument("--result-dir", type=Path, default=Path("./results"))
    p.add_argument("--preprocess", type=str, default="crop")
    p.add_argument("--size", type=int, default=256)
    p.add_argument("--facerender", type=str, default="facevid2vid")
    p.add_argument("--still-mode", action="store_true")
    p.add_argument("--use-enhancer", action="store_true")
    p.add_argument("--batch-size", type=int, default=1)
    p.add_argument("--pose-style", type=int, default=0)
    p.add_argument("--exp-scale", type=float, default=1.0)

    # Reference-video and motion options.
    p.add_argument("--use-ref-video", action="store_true")
    p.add_argument("--ref-video", type=Path, default=None)
    p.add_argument("--ref-info", type=str, default=None)
    p.add_argument("--use-idle-mode", action="store_true")
    p.add_argument("--length-of-audio", type=int, default=0)
    p.add_argument("--use-blink", action="store_true", default=True)
    p.add_argument("--no-blink", action="store_false", dest="use_blink")
    p.add_argument("--manual-audio", action="store_true", help="Alias for driven-audio handling; kept for clarity.")
    return p
|
|
|
|
def main() -> None:
    """Command-line entry point.

    Parses the arguments defined by ``build_parser``, constructs a
    ``PackedAvatar`` and prints whatever its ``generate`` call returns.
    """
    args = build_parser().parse_args()

    def as_optional_str(value: Any) -> Optional[str]:
        # Options that were not supplied (falsy, i.e. None) stay None;
        # anything else is passed on as a plain string.
        return str(value) if value else None

    model = PackedAvatar(
        packed_pt_path=str(args.packed_pt),
        cache_dir=as_optional_str(args.cache_dir),
        device=args.device,
    )

    generation_kwargs = {
        "source_image": as_optional_str(args.source_image),
        "driven_audio": as_optional_str(args.driven_audio),
        "preprocess": args.preprocess,
        "still_mode": args.still_mode,
        "use_enhancer": args.use_enhancer,
        "batch_size": args.batch_size,
        "size": args.size,
        "pose_style": args.pose_style,
        "facerender": args.facerender,
        "exp_scale": args.exp_scale,
        "use_ref_video": args.use_ref_video,
        "ref_video": as_optional_str(args.ref_video),
        "ref_info": args.ref_info,
        "use_idle_mode": args.use_idle_mode,
        "length_of_audio": args.length_of_audio,
        "use_blink": args.use_blink,
        "result_dir": str(args.result_dir),
        "avatar_id": args.avatar_id,
        "avatar_condition": as_optional_str(args.avatar_condition),
        "motion_condition": as_optional_str(args.motion_condition),
        "remove_background": args.remove_background,
        "use_wav2lip": args.use_wav2lip,
        "wav2lip_repo": as_optional_str(args.wav2lip_repo),
    }

    print(model.generate(**generation_kwargs))


# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|