# Packed-Avatar / PackedAvatar.py
# Source: Hugging Face upload by HiMind ("Upload 2 files", commit 64cfabb, verified), 59.2 kB.
from __future__ import annotations

import argparse
import hashlib
import importlib.util
import io
import json
import os
import platform
import shutil
import subprocess
import sys
import tempfile
import time
import uuid
import zipfile
from dataclasses import dataclass
from difflib import get_close_matches
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
import torch
import zstandard as zstd
from PIL import Image
from pydub import AudioSegment
from scipy.io import loadmat, savemat
# ============================================================
# GENERAL HELPERS
# ============================================================
# Normalised (stripped, lower-cased) style labels treated as photographic
# avatars when picking a default avatar.
REAL_STYLE_ALIASES = set(("liveaction", "photo", "photoreal", "real", "realistic"))
def ensure_dir(path: Path) -> None:
    """Make sure directory ``path`` exists, creating missing parents; existing dirs are left alone."""
    path.mkdir(exist_ok=True, parents=True)
def utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string truncated to whole seconds."""
    from datetime import datetime, timezone

    moment = datetime.now(tz=timezone.utc)
    return moment.isoformat(timespec="seconds")
def sha256_bytes(data: bytes) -> str:
    """Return the lowercase hexadecimal SHA-256 digest of ``data``."""
    digest = hashlib.sha256()
    digest.update(data)
    return digest.hexdigest()
def sha256_file(path: Path, chunk_size: int = 1024 * 1024) -> str:
    """Return the hex SHA-256 digest of the file at ``path``.

    The file is streamed in ``chunk_size``-byte pieces so large files are never
    held in memory at once.
    """
    digest = hashlib.sha256()
    with path.open("rb") as handle:
        while True:
            block = handle.read(chunk_size)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
def tensor_to_bytes(obj: Any) -> bytes:
    """Return ``obj`` as an immutable ``bytes`` object.

    Torch tensors are detached, moved to CPU, made contiguous and serialised
    through NumPy's ``tobytes`` (raw element bytes in native order).
    ``bytes``/``bytearray`` inputs are copied to ``bytes``.

    Raises:
        TypeError: for any other input type.
    """
    if torch.is_tensor(obj):
        host_array = obj.detach().cpu().contiguous().numpy()
        return host_array.tobytes()
    if isinstance(obj, (bytes, bytearray)):
        return bytes(obj)
    raise TypeError(f"Expected bytes or tensor, got {type(obj)!r}")
def bytes_to_tensor(data: bytes) -> torch.Tensor:
    """Copy raw bytes into a new 1-D ``torch.uint8`` tensor (one element per byte).

    The data is first copied into a private ``bytearray``. The tensor therefore
    owns writable memory that does not alias the caller's object, and
    ``torch.frombuffer`` does not emit its non-writable-buffer warning (which
    it does for ``bytes``).
    """
    buffer = bytearray(data)
    if not buffer:
        # torch.frombuffer rejects zero-length buffers.
        return torch.empty(0, dtype=torch.uint8)
    return torch.frombuffer(buffer, dtype=torch.uint8)
def decode_png_or_zstd_image(blob: bytes) -> Image.Image:
    """Decode a preview blob that is either raw image bytes or zstd-compressed image bytes.

    Blobs beginning with the zstd frame magic number are decompressed with a
    streaming decompressor object. Unlike ``ZstdDecompressor.decompress``, it
    does not require the frame header to record the content size. If
    decompression fails, the blob is handed to PIL unchanged.

    Returns:
        The decoded image converted to RGB. It is fully loaded, and the
        underlying file object is closed.
    """
    raw = blob
    if blob[:4] == b"\x28\xb5\x2f\xfd":  # zstd frame magic number
        try:
            raw = zstd.ZstdDecompressor().decompressobj().decompress(blob)
        except zstd.ZstdError:
            raw = blob
    with Image.open(io.BytesIO(raw)) as im:
        return im.convert("RGB")
def pil_to_numpy_rgb(img: Image.Image) -> np.ndarray:
    """Return ``img`` converted to RGB as a ``uint8`` NumPy array."""
    rgb_image = img.convert("RGB")
    return np.asarray(rgb_image, dtype=np.uint8)
def normalize_style_name(style: Optional[str]) -> str:
    """Canonicalise a style label: falsy input gives ``""``, otherwise strip and lower-case."""
    if not style:
        return ""
    return style.strip().lower()
def normalize_gender_name(gender: Optional[str]) -> str:
    """Canonicalise a gender label: falsy input gives ``""``, otherwise strip and lower-case."""
    cleaned = (gender or "").strip()
    return cleaned.lower()
def safe_load_bundle(path_or_bundle: Any) -> Optional[Dict[str, Any]]:
    """Resolve a conditioning input into an in-memory dictionary.

    Accepted inputs:

    * ``None`` is returned unchanged.
    * A ``dict`` is returned as-is (not copied).
    * A ``.pt``/``.pth`` path is loaded with ``torch.load`` onto the CPU.
    * A ``.mat`` path is loaded with ``scipy.io.loadmat``.

    NOTE: ``.pt``/``.pth`` files are loaded with ``weights_only=False``, which
    unpickles arbitrary objects; only pass trusted files.

    Raises:
        TypeError: for any other input type or file extension.
    """
    if path_or_bundle is None or isinstance(path_or_bundle, dict):
        return path_or_bundle
    if isinstance(path_or_bundle, (str, os.PathLike)):
        source = Path(path_or_bundle)
        suffix = source.suffix.lower()
        if suffix == ".mat":
            return loadmat(str(source))
        if suffix in (".pt", ".pth"):
            return torch.load(str(source), map_location="cpu", weights_only=False)
    raise TypeError("Conditioning input must be None, a dict, or a .pt/.pth/.mat path")
def _resolve_checkpoint(self):
candidates = [
"SadTalker_V0.0.2_512.safetensors",
"SadTalker_V0.0.2_256.safetensors",
"SadTalker_V0.0.2_512.pth",
"SadTalker_V0.0.2_256.pth",
]
for name in candidates:
p = Path(self.checkpoint_path) / name
if p.exists():
return str(p)
raise FileNotFoundError(
f"No SadTalker checkpoint found in {self.checkpoint_path}"
)
def composite_alpha_to_rgb(image_path: Path, bg_rgb=(255, 255, 255)) -> Path:
    """Flatten an image onto a solid background and save it as ``<stem>_rgb.png``.

    The image is always converted to RGBA, so images without alpha become fully
    opaque. It is then alpha-composited over an opaque ``bg_rgb`` canvas. The
    RGB result is written next to the input, and its path is returned.
    """
    with Image.open(image_path) as source:
        foreground = source.convert("RGBA")
    backdrop = Image.new("RGBA", foreground.size, (*bg_rgb, 255))
    flattened = Image.alpha_composite(backdrop, foreground).convert("RGB")
    destination = image_path.with_name(image_path.stem + "_rgb.png")
    flattened.save(destination)
    return destination
def prepare_image_for_sadtalker(image_path: Path, remove_background_result: Optional[Path] = None) -> Path:
    """Return a path to an opaque RGB source image SadTalker can consume.

    If a background-removal result is supplied, it is always flattened onto
    white. Otherwise the original image is flattened only when it carries
    transparency: either an alpha band (RGBA, LA, PA, premultiplied RGBa/La),
    or a ``transparency`` entry in the image info (palette/greyscale PNGs).
    Opaque images are returned unchanged.
    """
    if remove_background_result is not None:
        return composite_alpha_to_rgb(remove_background_result)
    with Image.open(image_path) as im:
        # Checking band names covers every alpha-carrying mode, not just RGBA/LA.
        has_transparency = (
            any(band in ("A", "a") for band in im.getbands())
            or "transparency" in im.info
        )
    if has_transparency:
        return composite_alpha_to_rgb(image_path)
    return image_path
# ============================================================
# ARCHIVE EXTRACTION
# ============================================================
@dataclass()
class MountedArchive:
    """Describes a zip payload that has been extracted into a directory."""

    name: str  # label recorded in the marker file
    zip_sha256: str  # hash identifying the extracted payload
    target_dir: Path  # directory holding the extracted files
    marker_path: Path  # JSON marker file recording the mount
def extract_zip_bytes_to_dir(zip_bytes: bytes, dest_dir: Path) -> None:
    """Extract an in-memory zip archive into ``dest_dir``, creating the directory if needed."""
    dest_dir.mkdir(parents=True, exist_ok=True)
    archive = zipfile.ZipFile(io.BytesIO(zip_bytes), mode="r")
    with archive:
        archive.extractall(path=dest_dir)
def mount_zip_payload(zip_bytes: bytes, zip_sha256: str, target_dir: Path, marker_name: str) -> MountedArchive:
    """Extract ``zip_bytes`` into ``target_dir`` unless that payload is already mounted there.

    The JSON marker ``target_dir / marker_name`` records which archive hash
    was extracted. If it reports ``mounted: true`` for ``zip_sha256``, the
    directory is reused as-is. Otherwise every entry in ``target_dir`` is
    deleted, the archive is extracted, and a fresh marker is written.

    The old marker is removed *before* the directory is cleared. The new
    marker is written only after extraction succeeds. This way an interrupted
    remount can never leave a marker vouching for files that were deleted.

    Returns:
        A ``MountedArchive`` describing the mounted payload.

    Raises:
        zipfile.BadZipFile: if ``zip_bytes`` is not a valid zip archive.
    """
    ensure_dir(target_dir)
    marker_path = target_dir / marker_name

    try:
        existing = json.loads(marker_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # Missing, unreadable or malformed marker: treat as not mounted.
        existing = None
    if (
        isinstance(existing, dict)
        and existing.get("zip_sha256") == zip_sha256
        and existing.get("mounted") is True
    ):
        return MountedArchive(
            name=existing.get("name", marker_name),
            zip_sha256=zip_sha256,
            target_dir=target_dir,
            marker_path=marker_path,
        )

    # Invalidate the old marker first so it cannot outlive the files it describes.
    try:
        marker_path.unlink()
    except FileNotFoundError:
        pass

    # Clear any stale contents before extracting.
    for child in list(target_dir.iterdir()):
        if child.is_dir() and not child.is_symlink():
            shutil.rmtree(child, ignore_errors=True)
        else:
            # Files and symlinks (including symlinks to directories, which rmtree refuses).
            try:
                child.unlink()
            except OSError:
                pass

    extract_zip_bytes_to_dir(zip_bytes, target_dir)
    marker_path.write_text(
        json.dumps(
            {
                "mounted": True,
                "zip_sha256": zip_sha256,
                "name": marker_name,
                "created_at": utc_now_iso(),
            },
            indent=2,
        ),
        encoding="utf-8",
    )
    return MountedArchive(
        name=marker_name,
        zip_sha256=zip_sha256,
        target_dir=target_dir,
        marker_path=marker_path,
    )
# ============================================================
# AVATAR BANK RUNTIME
# ============================================================
class AvatarBankRuntime:
    """In-memory avatar bank: per-avatar metadata, 3DMM embeddings and previews.

    The payload dictionary holds three mappings keyed by avatar id:

    * ``index``: metadata dicts (``gender`` and ``style`` are read here).
    * ``embeddings``: conditioning dicts (``motion_3dmm``/``full_3dmm``
      coefficients, ``crop_info``, optional ``crop_preview``).
    * ``previews``: PNG or zstd-compressed PNG bytes.

    ``defaults`` may contain ``default_avatar``, which
    ``resolve_default_avatar_id`` prefers when it names an indexed avatar.
    """

    def __init__(
        self,
        payload: Dict[str, Any],
        defaults: Optional[Dict[str, Any]] = None,
    ):
        self.index: Dict[str, Dict[str, Any]] = payload.get("index", {}) or {}
        self.embeddings: Dict[str, Dict[str, Any]] = payload.get("embeddings", {}) or {}
        self.previews: Dict[str, Any] = payload.get("previews", {}) or {}
        self.defaults = defaults or {}

    @classmethod
    def load(
        cls,
        path: Path,
        defaults: Optional[Dict[str, Any]] = None,
    ) -> "AvatarBankRuntime":
        """Load a bank payload dictionary from ``path`` with ``torch.load`` (on CPU).

        NOTE: ``weights_only=False`` unpickles arbitrary objects; only load
        trusted files.

        Raises:
            ValueError: if the file does not contain a dictionary.
        """
        payload = torch.load(str(path), map_location="cpu", weights_only=False)
        if not isinstance(payload, dict):
            raise ValueError(f"Avatar bank file did not contain a dictionary: {path}")
        return cls(payload, defaults=defaults)

    def save(self, path: Union[str, Path]) -> None:
        """Write ``index``, ``embeddings`` and ``previews`` to ``path`` (``defaults`` are not saved)."""
        torch.save(
            {
                "index": self.index,
                "embeddings": self.embeddings,
                "previews": self.previews,
            },
            str(path),
        )

    # --------------------------------------------------------
    # BASIC ACCESS
    # --------------------------------------------------------
    def __contains__(self, avatar_id: str) -> bool:
        return avatar_id in self.index

    def exists(self, avatar_id: str) -> bool:
        return avatar_id in self.index

    def available_ids(self) -> List[str]:
        """Return all indexed avatar ids in index order."""
        return list(self.index.keys())

    def list_avatars(self) -> List[str]:
        return self.available_ids()

    def get_metadata(self, avatar_id: str) -> Dict[str, Any]:
        """Return a shallow copy of the metadata of ``avatar_id`` (KeyError if unknown)."""
        if avatar_id not in self.index:
            raise KeyError(f"Avatar not found: {avatar_id}")
        return dict(self.index[avatar_id])

    def get_avatar(self, avatar_id: str) -> Dict[str, Any]:
        return self.build_avatar_condition(avatar_id)

    def get_embedding_bundle(self, avatar_id: str) -> Dict[str, Any]:
        """Return the stored (uncopied) embedding dict of ``avatar_id`` (KeyError if unknown)."""
        if avatar_id not in self.embeddings:
            raise KeyError(f"Avatar not found: {avatar_id}")
        return self.embeddings[avatar_id]

    # --------------------------------------------------------
    # FUZZY SEARCH
    # --------------------------------------------------------
    def _fuzzy_match_single(
        self,
        query: str,
        choices: set,
        cutoff: float = 0.6,
    ):
        """Return the choice closest to ``query`` (case-insensitive), or ``None``.

        Uses ``difflib.get_close_matches``; candidates scoring below ``cutoff``
        are rejected. The original casing of the chosen value is returned.
        """
        if not query:
            return None
        choice_map = {c.lower(): c for c in choices}
        matches = get_close_matches(query.lower(), list(choice_map.keys()), n=1, cutoff=cutoff)
        return choice_map[matches[0]] if matches else None

    def fuzzy_search_id(
        self,
        query_id: str,
        n: int = 5,
        cutoff: float = 0.5,
    ) -> List[str]:
        """Return up to ``n`` indexed ids closest to ``query_id`` (case-insensitive), best first."""
        id_map = {aid.lower(): aid for aid in self.index.keys()}
        matches = get_close_matches(query_id.lower(), list(id_map.keys()), n=n, cutoff=cutoff)
        return [id_map[m] for m in matches]

    # --------------------------------------------------------
    # QUERY
    # --------------------------------------------------------
    def query(
        self,
        gender=None,
        style=None,
        fuzzy=True,
        cutoff=0.6,
    ) -> List[str]:
        """Return ids whose metadata matches ``gender`` and/or ``style``.

        With ``fuzzy`` enabled, each filter is first snapped to the closest
        value actually present in the index (falling back to the literal filter
        when nothing scores above ``cutoff``). Matching itself is exact.
        Missing or ``None`` metadata entries are treated as empty.
        """
        metas = {aid: (meta or {}) for aid, meta in self.index.items()}

        target_gender = gender
        target_style = style
        if fuzzy:
            if gender:
                known_genders = {m["gender"] for m in metas.values() if m.get("gender")}
                target_gender = self._fuzzy_match_single(gender, known_genders, cutoff) or gender
            if style:
                known_styles = {m["style"] for m in metas.values() if m.get("style")}
                target_style = self._fuzzy_match_single(style, known_styles, cutoff) or style

        results = []
        for aid, meta in metas.items():
            if target_gender and meta.get("gender") != target_gender:
                continue
            if target_style and meta.get("style") != target_style:
                continue
            results.append(aid)
        return results

    # --------------------------------------------------------
    # PREVIEWS
    # --------------------------------------------------------
    def get_preview(self, avatar_id: str):
        """Decode and return the RGB preview image of ``avatar_id`` (KeyError if absent)."""
        if avatar_id not in self.previews:
            raise KeyError(f"Avatar not found: {avatar_id}")
        return decode_png_or_zstd_image(self.previews[avatar_id])

    def get_preview_numpy(
        self,
        avatar_id: str,
    ) -> Optional[np.ndarray]:
        return self._preview_to_numpy(avatar_id)

    def _preview_to_numpy(
        self,
        avatar_id: str,
    ) -> Optional[np.ndarray]:
        """Return the preview as a ``uint8`` RGB array, or ``None`` if absent or undecodable."""
        blob = self.previews.get(avatar_id)
        if blob is None:
            return None
        try:
            return pil_to_numpy_rgb(decode_png_or_zstd_image(blob))
        except Exception:
            # Best-effort: a corrupt preview must not break condition building.
            return None

    # --------------------------------------------------------
    # MUTATION
    # --------------------------------------------------------
    def delete_avatar(
        self,
        avatar_id: str,
    ) -> None:
        """Remove ``avatar_id`` from index, embeddings and previews (no error if absent)."""
        self.index.pop(avatar_id, None)
        self.embeddings.pop(avatar_id, None)
        self.previews.pop(avatar_id, None)

    # --------------------------------------------------------
    # DEFAULT SELECTION
    # --------------------------------------------------------
    def _style_is_real(self, style: Optional[str]) -> bool:
        return normalize_style_name(style) in REAL_STYLE_ALIASES

    def resolve_default_avatar_id(self) -> str:
        """Pick the avatar to use when none is requested.

        The candidates are tried in this order:

        1. ``defaults["default_avatar"]``, if it is indexed.
        2. The first real-style male avatar.
        3. The first real-style avatar.
        4. The first male avatar.
        5. The first indexed avatar that has a non-``None`` embedding.
        6. The first indexed avatar.

        Raises:
            RuntimeError: if the bank is empty.
        """
        if not self.index:
            raise RuntimeError("Avatar bank is empty.")

        default_id = self.defaults.get("default_avatar")
        if default_id and default_id in self.index:
            return default_id

        entries = [(aid, meta or {}) for aid, meta in self.index.items()]

        def is_male(meta: Dict[str, Any]) -> bool:
            return normalize_gender_name(meta.get("gender")) == "male"

        def is_real(meta: Dict[str, Any]) -> bool:
            return self._style_is_real(meta.get("style"))

        preferences = (
            lambda meta: is_male(meta) and is_real(meta),
            is_real,
            is_male,
        )
        for matches in preferences:
            for avatar_id, meta in entries:
                if matches(meta):
                    return avatar_id

        # Then any indexed avatar that has an embedding bundle; ids that exist
        # only in `embeddings` are skipped so the result always passes `exists()`.
        for avatar_id, emb in self.embeddings.items():
            if emb is not None and avatar_id in self.index:
                return avatar_id

        # Finally the first available entry.
        return next(iter(self.index.keys()))

    # --------------------------------------------------------
    # CONDITIONING
    # --------------------------------------------------------
    @staticmethod
    def _detach_cpu(value: Any) -> Any:
        """Return tensors detached and on CPU; other values unchanged."""
        return value.detach().cpu() if torch.is_tensor(value) else value

    def build_avatar_condition(self, avatar_id: str) -> Dict[str, Any]:
        """Assemble the conditioning dict SadTalker needs for ``avatar_id``.

        ``coeff_3dmm`` is ``motion_3dmm`` when present, else ``full_3dmm``.
        ``crop_preview`` comes from the embedding: a tensor is detached to
        CPU, anything else becomes a NumPy array. When the embedding has none,
        the decoded preview is used instead (possibly ``None``).

        Raises:
            KeyError: if ``avatar_id`` has no embedding entry.
            ValueError: if neither ``motion_3dmm`` nor ``full_3dmm`` is present.
        """
        if avatar_id not in self.embeddings:
            raise KeyError(f"Avatar not found: {avatar_id}")

        meta = self.index.get(avatar_id) or {}
        emb = self.embeddings[avatar_id] or {}

        motion = emb.get("motion_3dmm")
        full = emb.get("full_3dmm")
        coeff = motion if motion is not None else full
        if coeff is None:
            raise ValueError(f"Avatar '{avatar_id}' is missing motion_3dmm/full_3dmm")

        crop_preview = emb.get("crop_preview")
        if crop_preview is None:
            crop_preview = self._preview_to_numpy(avatar_id)
        elif torch.is_tensor(crop_preview):
            crop_preview = crop_preview.detach().cpu()
        elif not isinstance(crop_preview, np.ndarray):
            crop_preview = np.asarray(crop_preview)

        return {
            "avatar_id": avatar_id,
            "gender": meta.get("gender"),
            "style": meta.get("style"),
            "coeff_3dmm": self._detach_cpu(coeff),
            "motion_3dmm": self._detach_cpu(motion),
            "full_3dmm": self._detach_cpu(full),
            "crop_info": emb.get("crop_info"),
            "crop_preview": crop_preview,
        }
# ============================================================
# BRIA RMBG BACKGROUND REMOVER (BEST-EFFORT)
# ============================================================
class BriaBackgroundRemover:
    """
    Best-effort adapter for the packed briaaiRMBG-2.0 directory.

    It searches the directory for a likely inference script. It then tries
    calling known function names from that script with several argument
    shapes, and finally runs the script as a CLI with common flag patterns.
    If the local folder layout differs, the name/flag lists below are usually
    the only part that needs adjusting.
    """

    def __init__(self, root: Path):
        self.root = root
        # Entrypoint module, imported lazily on first use and then reused.
        self._module = None
        self.entrypoint = self._discover_entrypoint()

    def _discover_entrypoint(self) -> Optional[Path]:
        """Return the most likely inference script under ``self.root``, or ``None``.

        Well-known script names are preferred (shallowest match first, ties
        broken by path). Otherwise the first ``.py`` file whose path *relative
        to the root* mentions bria/rmbg/background is returned. The root's own
        name is excluded from that test because it always contains "bria".
        """
        if not self.root.exists():
            return None

        def shallow_first(paths):
            return sorted(paths, key=lambda p: (len(p.parts), str(p)))

        preferred = ("inference.py", "predict.py", "app.py", "main.py", "run.py")
        for name in preferred:
            hits = shallow_first(self.root.rglob(name))
            if hits:
                return hits[0]

        keywords = ("bria", "rmbg", "background")
        for p in shallow_first(self.root.rglob("*.py")):
            relative = str(p.relative_to(self.root)).lower()
            if any(word in relative for word in keywords):
                return p
        return None

    def _import_module_from_path(self, py_file: Path):
        """Import ``py_file`` as a uniquely named module (executing its top-level code)."""
        module_name = f"packed_bria_{sha256_bytes(str(py_file).encode('utf-8'))[:12]}"
        spec = importlib.util.spec_from_file_location(module_name, str(py_file))
        if spec is None or spec.loader is None:
            raise RuntimeError(f"Could not import module from {py_file}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

    def _load_entrypoint_module(self):
        """Import the entrypoint once and reuse it.

        Re-executing the script on every call would redo any work it performs
        at import time.
        """
        module = getattr(self, "_module", None)
        if module is None:
            module = self._import_module_from_path(self.entrypoint)
            self._module = module
        return module

    @staticmethod
    def _store_result(result: Any, output_path: Path) -> bool:
        """Persist a callable's return value to ``output_path``; return whether output now exists.

        Handled cases:

        * ``None``: the callable is assumed to have written ``output_path`` itself.
        * A path to an existing file: copied, unless it already *is* ``output_path``.
        * An ``H x W x 3|4`` tensor or array: floating data with max <= 1 is
          scaled to 0-255, then clipped.
        * A PIL image: saved directly.

        Anything else counts as failure.
        """
        if result is None:
            return output_path.exists()
        if isinstance(result, (str, os.PathLike)):
            produced = Path(result)
            if not produced.exists():
                return False
            # Copying a file onto itself raises shutil.SameFileError.
            if produced.resolve() != output_path.resolve():
                shutil.copy2(produced, output_path)
            return True
        if torch.is_tensor(result):
            result = result.detach().cpu().numpy()
        if isinstance(result, np.ndarray):
            if result.ndim != 3 or result.shape[-1] not in (3, 4):
                return False
            arr = result
            if arr.dtype != np.uint8:
                values = arr.astype(np.float64)
                if np.issubdtype(arr.dtype, np.floating) and values.size and values.max() <= 1.0:
                    values = values * 255.0
                arr = np.clip(values, 0, 255).astype(np.uint8)
            Image.fromarray(np.ascontiguousarray(arr)).save(output_path)
            return True
        if isinstance(result, Image.Image):
            result.save(output_path)
            return True
        return False

    def _call_module_callable(self, module, image_path: Path, output_path: Path) -> bool:
        """Try known entry-point functions of ``module``; return ``True`` once one produces output.

        Each callable found (``remove_background``, ``predict_image``,
        ``predict``, ``run``, ``inference``, ``main``, in that order) is tried
        with these argument shapes, in order:

        1. ``(input_path, output_path)``
        2. ``(input_path,)``
        3. An opened PIL image.
        4. No arguments.

        Any exception, including ``SystemExit`` raised by argparse-style
        ``main`` functions, just moves on to the next attempt.
        """
        names = ("remove_background", "predict_image", "predict", "run", "inference", "main")
        functions = [fn for fn in (getattr(module, n, None) for n in names) if callable(fn)]
        arg_shapes = ("input_and_output", "input_only", "pil_image", "no_args")

        for fn in functions:
            for shape in arg_shapes:
                opened = None
                try:
                    if shape == "input_and_output":
                        args = (str(image_path), str(output_path))
                    elif shape == "input_only":
                        args = (str(image_path),)
                    elif shape == "pil_image":
                        # Opened only when needed and always closed below.
                        opened = Image.open(image_path)
                        args = (opened,)
                    else:
                        args = ()
                    if self._store_result(fn(*args), output_path):
                        return True
                except (Exception, SystemExit):
                    continue
                finally:
                    if opened is not None:
                        opened.close()
        return False

    def _call_cli_with_patterns(self, image_path: Path, output_path: Path) -> bool:
        """Run the entrypoint as a script with common CLI flag patterns.

        Returns ``True`` when a run exits 0 and ``output_path`` exists. The
        child's stdin/stdout/stderr are detached so it cannot block on input.
        """
        if self.entrypoint is None:
            return False
        script = str(self.entrypoint)
        src, dst = str(image_path), str(output_path)
        cmd_patterns = [
            [script, src, dst],
            [script, "--input", src, "--output", dst],
            [script, "--image", src, "--output", dst],
            [script, "--input_path", src, "--output_path", dst],
            [script, "-i", src, "-o", dst],
        ]
        for args in cmd_patterns:
            try:
                proc = subprocess.run(
                    [sys.executable, *args],
                    cwd=str(self.root),
                    stdin=subprocess.DEVNULL,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                    check=False,
                )
            except Exception:
                continue
            if proc.returncode == 0 and output_path.exists():
                return True
        return False

    def remove_background(self, image_path: Path, output_dir: Path) -> Path:
        """Remove the background of ``image_path``; return the written ``<stem>_rmbg.png`` path.

        Python-level callables of the entrypoint are tried first, then CLI
        invocations.

        Raises:
            RuntimeError: if no entrypoint was found or every strategy failed.
        """
        if self.entrypoint is None:
            raise RuntimeError(
                f"No usable background-removal entrypoint found under {self.root}."
            )
        ensure_dir(output_dir)
        output_path = output_dir / f"{image_path.stem}_rmbg.png"
        # Success is partly judged by the output file existing, so a leftover
        # file from an earlier run must not be mistaken for fresh output.
        try:
            output_path.unlink()
        except FileNotFoundError:
            pass

        try:
            module = self._load_entrypoint_module()
        except (Exception, SystemExit):
            module = None
        if module is not None and self._call_module_callable(module, image_path, output_path):
            return output_path

        if self._call_cli_with_patterns(image_path, output_path):
            return output_path

        raise RuntimeError(
            f"Could not execute background removal with entrypoint {self.entrypoint}. "
            f"You may need to adjust the call patterns in BriaBackgroundRemover."
        )
# ============================================================
# SADTALKER CORE RUNTIME
# ============================================================
class SadTalkerRunner:
    """Driver for the packed SadTalker pipeline.

    Wraps SadTalker's three stages: preprocessing (``CropAndExtract``),
    audio-to-coefficient (``Audio2Coeff``), and face rendering
    (``AnimateFromCoeff`` or its PIRender variant). It also adds helpers that
    turn pre-extracted avatar/motion bundles (dicts or ``.pt``/``.pth``/``.mat``
    files) into the on-disk files SadTalker reads.
    """

    # Inputs with these suffixes are preprocessed as reference videos.
    _VIDEO_EXTS = frozenset(
        {".mp4", ".mov", ".avi", ".mkv", ".webm", ".flv", ".wmv", ".m4v", ".gif"}
    )
    # Checkpoint file names, in order of preference.
    _CHECKPOINT_CANDIDATES = (
        "SadTalker_V0.0.2_512.safetensors",
        "SadTalker_V0.0.2_256.safetensors",
        "SadTalker_V0.0.2_512.pth",
        "SadTalker_V0.0.2_256.pth",
    )
    # Attributes created by _ensure_models and dropped by _release_models.
    _MODEL_ATTRS = ("preprocess_model", "audio_to_coeff", "animate_from_coeff")

    def __init__(self, checkpoint_path: str, config_path: str, device: str = "cpu"):
        self.checkpoint_path = checkpoint_path
        self.config_path = config_path
        self.device = device
        self._mods_loaded = False
        self._load_modules()

    def _load_modules(self):
        """Import the SadTalker entry points once and cache them on the instance."""
        if self._mods_loaded:
            return
        from SadTalker.src.facerender.pirender_animate import AnimateFromCoeff_PIRender
        from SadTalker.src.utils.preprocess import CropAndExtract
        from SadTalker.src.test_audio2coeff import Audio2Coeff
        from SadTalker.src.facerender.animate import AnimateFromCoeff
        from SadTalker.src.generate_batch import get_data
        from SadTalker.src.generate_facerender_batch import get_facerender_data
        from SadTalker.src.utils.init_path import init_path

        self.AnimateFromCoeff_PIRender = AnimateFromCoeff_PIRender
        self.CropAndExtract = CropAndExtract
        self.Audio2Coeff = Audio2Coeff
        self.AnimateFromCoeff = AnimateFromCoeff
        self.get_data = get_data
        self.get_facerender_data = get_facerender_data
        self.init_path = init_path
        self._mods_loaded = True

    @staticmethod
    def _mp3_to_wav(mp3_filename: str, wav_filename: str, frame_rate: int):
        """Decode ``mp3_filename`` and write it as WAV resampled to ``frame_rate`` Hz."""
        mp3_file = AudioSegment.from_file(file=mp3_filename)
        mp3_file.set_frame_rate(frame_rate).export(wav_filename, format="wav")

    def _to_numpy(self, x):
        """Return ``x`` as a NumPy array (tensors detached to CPU); ``None`` stays ``None``."""
        if x is None:
            return None
        if isinstance(x, np.ndarray):
            return x
        if torch.is_tensor(x):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    @staticmethod
    def _first_present(bundle, *keys):
        """Return the first non-``None`` value among ``keys`` in ``bundle`` (``None`` if none)."""
        for key in keys:
            value = bundle.get(key)
            if value is not None:
                return value
        return None

    # --------------------------------------------------------
    # BUNDLE -> FILES
    # --------------------------------------------------------
    def _save_png_from_bundle(self, bundle, out_path):
        """Write the first image-like field of ``bundle`` to ``out_path`` as an RGB image.

        Fields are tried in the order ``crop_preview``, ``aligned_face``,
        ``image``, ``png``. A field is used when it converts to an
        ``H x W x C`` array with C in (1, 3, 4).

        Non-``uint8`` data is clipped to [0, 255] before conversion. If it is
        floating point with a maximum <= 1 (i.e. looks normalised), it is first
        scaled by 255. Single-channel data is saved as greyscale expanded to
        RGB; RGBA data has its alpha dropped.

        Raises:
            ValueError: if no usable image field is present.
        """
        for key in ("crop_preview", "aligned_face", "image", "png"):
            value = bundle.get(key)
            if value is None:
                continue
            arr = self._to_numpy(value)
            if arr.ndim != 3 or arr.shape[-1] not in (1, 3, 4):
                continue
            if arr.dtype != np.uint8:
                values = arr.astype(np.float64)
                if np.issubdtype(arr.dtype, np.floating) and values.size and values.max() <= 1.0:
                    values = values * 255.0
                arr = np.clip(values, 0, 255).astype(np.uint8)
            if arr.shape[-1] == 1:
                arr = arr[..., 0]
            # Pillow infers L/RGB/RGBA from the array shape (the `mode=` argument is deprecated).
            Image.fromarray(np.ascontiguousarray(arr)).convert("RGB").save(out_path)
            return out_path
        raise ValueError(
            "Avatar conditioning bundle needs at least one image-like field such as crop_preview or aligned_face."
        )

    def _save_mat_from_avatar_bundle(self, bundle, out_path):
        """Write the avatar's 3DMM coefficients to a ``.mat`` file at ``out_path``.

        ``coeff_3dmm`` is the first of ``coeff_3dmm``/``motion_3dmm``/``full_3dmm``
        present. ``full_3dmm`` is also stored when available.

        Raises:
            ValueError: if none of the coefficient fields is present.
        """
        coeff_3dmm = self._first_present(bundle, "coeff_3dmm", "motion_3dmm", "full_3dmm")
        if coeff_3dmm is None:
            raise ValueError("Avatar bundle must contain coeff_3dmm, motion_3dmm, or full_3dmm.")
        mat_dict = {"coeff_3dmm": self._to_numpy(coeff_3dmm)}
        full_3dmm = bundle.get("full_3dmm", None)
        if full_3dmm is not None:
            mat_dict["full_3dmm"] = self._to_numpy(full_3dmm)
        savemat(out_path, mat_dict)
        return out_path

    def _save_mat_from_motion_bundle(self, bundle, out_path):
        """Write a motion bundle's coefficients to a ``.mat`` file at ``out_path``.

        ``coeff_3dmm`` is the first of ``motion_3dmm``/``coeff_3dmm``/
        ``full_3dmm_seq``/``full_3dmm`` present. ``full_3dmm`` is taken from
        ``full_3dmm`` if given. Otherwise it comes from ``full_3dmm_seq``: its
        first entry when the sequence is 3-D or higher, else the whole array.

        Raises:
            ValueError: if none of the coefficient fields is present.
        """
        motion = self._first_present(bundle, "motion_3dmm", "coeff_3dmm", "full_3dmm_seq", "full_3dmm")
        if motion is None:
            raise ValueError(
                "Motion bundle must contain motion_3dmm, coeff_3dmm, full_3dmm_seq, or full_3dmm."
            )
        mat_dict = {"coeff_3dmm": self._to_numpy(motion)}
        full = bundle.get("full_3dmm")
        if full is not None:
            mat_dict["full_3dmm"] = self._to_numpy(full)
        else:
            seq = bundle.get("full_3dmm_seq")
            if seq is not None:
                seq = self._to_numpy(seq)
                mat_dict["full_3dmm"] = seq[0] if seq.ndim >= 3 else seq
        savemat(out_path, mat_dict)
        return out_path

    def _bundle_from_preprocess_output(
        self,
        coeff_path,
        crop_pic_path,
        crop_info,
    ):
        """Collect SadTalker preprocessing outputs into one conditioning dict.

        The following steps are applied:

        * Load every non-``__`` variable of the ``.mat`` file at ``coeff_path``
          (best-effort).
        * Record the source paths and ``crop_info``.
        * Add a ``crop_preview`` RGB array read from ``crop_pic_path``
          (best-effort).
        * Fill the ``coeff_3dmm``/``motion_3dmm``/``full_3dmm`` aliases from
          one another.
        """
        bundle = {}

        if coeff_path is not None and os.path.isfile(coeff_path):
            try:
                raw = loadmat(coeff_path)
            except Exception:
                # Best-effort: an unreadable .mat just leaves coefficients out.
                raw = {}
            bundle.update({k: v for k, v in raw.items() if not k.startswith("__")})

        # Preserve the paths used to generate the bundle.
        if coeff_path is not None:
            bundle["coeff_path"] = str(coeff_path)
        if crop_pic_path is not None:
            bundle["crop_pic_path"] = str(crop_pic_path)
        if crop_info is not None:
            bundle["crop_info"] = crop_info

        # Keep a usable preview in memory.
        if crop_pic_path is not None and os.path.isfile(crop_pic_path):
            try:
                with Image.open(crop_pic_path) as im:
                    bundle["crop_preview"] = pil_to_numpy_rgb(im)
            except Exception:
                # Best-effort: the preview is optional.
                pass

        # Normalize common aliases so downstream code can rely on them.
        if "coeff_3dmm" in bundle and "motion_3dmm" not in bundle:
            bundle["motion_3dmm"] = bundle["coeff_3dmm"]
        if "motion_3dmm" in bundle and "coeff_3dmm" not in bundle:
            bundle["coeff_3dmm"] = bundle["motion_3dmm"]
        if "full_3dmm" not in bundle:
            if "full_3dmm_seq" in bundle:
                seq = bundle["full_3dmm_seq"]
                bundle["full_3dmm"] = seq[0] if getattr(seq, "ndim", 0) >= 3 else seq
            elif "motion_3dmm" in bundle:
                bundle["full_3dmm"] = bundle["motion_3dmm"]
        return bundle

    # --------------------------------------------------------
    # PUBLIC PREPROCESSING
    # --------------------------------------------------------
    def extract_embeddings(
        self,
        input_path,
        crop_or_resize: str = "crop",
        pic_size: int = 256,
        save_dir: Optional[str] = None,
    ):
        """
        Public preprocessing helper.

        Accepts either a source image or a reference video, runs the packed
        SadTalker preprocessing, and returns the extracted conditioning bundle.
        Outputs are written under ``save_dir``, or a fresh temporary directory
        when ``save_dir`` is omitted. That directory is not removed, because
        the bundle references files inside it.
        """
        self._load_modules()
        self._ensure_models(size=pic_size, preprocess=crop_or_resize, facerender="facevid2vid")
        input_path = Path(input_path)
        if not input_path.exists():
            raise FileNotFoundError(str(input_path))
        if save_dir is None:
            save_dir = tempfile.mkdtemp(prefix="packedavatar_embeddings_")
        else:
            ensure_dir(Path(save_dir))
        work_dir = Path(save_dir)

        if input_path.suffix.lower() in self._VIDEO_EXTS:
            frame_dir = work_dir / f"{input_path.stem}_frames"
            ensure_dir(frame_dir)
            coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
                str(input_path),
                str(frame_dir),
                crop_or_resize,
                source_image_flag=False,
            )
        else:
            staged = work_dir / input_path.name
            shutil.copy2(input_path, staged)
            first_frame_dir = work_dir / "first_frame_dir"
            ensure_dir(first_frame_dir)
            coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
                str(staged),
                str(first_frame_dir),
                crop_or_resize,
                True,
                pic_size,
            )
        return self._bundle_from_preprocess_output(coeff_path, crop_pic_path, crop_info)

    def ExtractEmbeddings(
        self,
        input_path,
        crop_or_resize: str = "crop",
        pic_size: int = 256,
        save_dir: Optional[str] = None,
    ):
        """Alias of ``extract_embeddings``."""
        return self.extract_embeddings(
            input_path=input_path,
            crop_or_resize=crop_or_resize,
            pic_size=pic_size,
            save_dir=save_dir,
        )

    def _materialize_avatar_condition(self, avatar_condition, save_dir):
        """Return ``(coeff_path, crop_pic_path, crop_info)`` for an avatar condition.

        Paths already recorded in the bundle are reused when they point at
        existing files. Otherwise ``avatar_condition.mat``/``.png`` are written
        into ``save_dir``. Returns three ``None`` values for a ``None`` input.
        """
        bundle = safe_load_bundle(avatar_condition)
        if bundle is None:
            return None, None, None
        coeff_path = bundle.get("coeff_path", None)
        crop_pic_path = bundle.get("crop_pic_path", None)
        crop_info = bundle.get("crop_info", None)
        if coeff_path is None or not os.path.isfile(coeff_path):
            coeff_path = os.path.join(save_dir, "avatar_condition.mat")
            self._save_mat_from_avatar_bundle(bundle, coeff_path)
        if crop_pic_path is None or not os.path.isfile(crop_pic_path):
            crop_pic_path = os.path.join(save_dir, "avatar_condition.png")
            self._save_png_from_bundle(bundle, crop_pic_path)
        return coeff_path, crop_pic_path, crop_info

    def _materialize_motion_condition(self, motion_condition, save_dir):
        """Return a ``.mat`` coefficient path for a motion condition (``None`` for ``None``).

        A recorded ``coeff_path`` is reused if it exists; otherwise
        ``motion_condition.mat`` is written into ``save_dir``.
        """
        bundle = safe_load_bundle(motion_condition)
        if bundle is None:
            return None
        coeff_path = bundle.get("coeff_path", None)
        if coeff_path is not None and os.path.isfile(coeff_path):
            return coeff_path
        coeff_path = os.path.join(save_dir, "motion_condition.mat")
        self._save_mat_from_motion_bundle(bundle, coeff_path)
        return coeff_path

    # --------------------------------------------------------
    # MODEL LIFECYCLE
    # --------------------------------------------------------
    def _resolve_checkpoint(self):
        """Return the first existing checkpoint in ``checkpoint_path`` (512 before 256, safetensors before pth).

        Raises:
            FileNotFoundError: if none of the candidates exists.
        """
        root = Path(self.checkpoint_path)
        for name in self._CHECKPOINT_CANDIDATES:
            p = root / name
            if p.exists():
                return str(p)
        raise FileNotFoundError(
            f"No SadTalker checkpoint found in {self.checkpoint_path}"
        )

    def _ensure_models(self, size: int, preprocess: str, facerender: str):
        """Instantiate the preprocessing, audio-to-coeff and renderer models for this call.

        The PIRender renderer is used unless ``facerender == "facevid2vid"``
        and the device is not ``"mps"``.
        """
        self.sadtalker_paths = self.init_path(
            self.checkpoint_path,
            self.config_path,
            size,
            False,
            preprocess,
        )
        # Override whatever init_path guessed.
        self.sadtalker_paths["checkpoint"] = self._resolve_checkpoint()
        print("\n[PackedAvatar] Using checkpoint:")
        print(self.sadtalker_paths["checkpoint"])

        self.audio_to_coeff = self.Audio2Coeff(self.sadtalker_paths, self.device)
        self.preprocess_model = self.CropAndExtract(self.sadtalker_paths, self.device)
        if facerender == "facevid2vid" and self.device != "mps":
            self.animate_from_coeff = self.AnimateFromCoeff(self.sadtalker_paths, self.device)
        else:
            self.animate_from_coeff = self.AnimateFromCoeff_PIRender(self.sadtalker_paths, self.device)

    def _release_models(self):
        """Drop per-call model instances (if present) and free cached GPU memory."""
        import gc

        for attr in self._MODEL_ATTRS:
            vars(self).pop(attr, None)
        # Collect first so released tensors are actually returned to the CUDA cache.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

    # --------------------------------------------------------
    # GENERATION
    # --------------------------------------------------------
    def _prepare_audio(self, driven_audio, use_idle_mode, length_of_audio, input_dir):
        """Stage the driving audio under ``input_dir`` and return its path, or ``None``.

        The audio source is chosen as follows:

        * An existing MP3 ``driven_audio`` is converted to 16 kHz WAV.
        * Any other existing ``driven_audio`` file is copied as-is.
        * In idle mode, a silent WAV of ``length_of_audio`` seconds is written.
        * Otherwise ``None`` is returned; the audio then comes from the
          reference video.
        """
        if driven_audio is not None and os.path.isfile(driven_audio):
            audio_name = os.path.basename(driven_audio)
            audio_path = os.path.join(input_dir, audio_name)
            if audio_name.lower().endswith(".mp3"):
                wav_path = os.path.splitext(audio_path)[0] + ".wav"
                self._mp3_to_wav(driven_audio, wav_path, 16000)
                return wav_path
            shutil.copy2(driven_audio, audio_path)
            return audio_path
        if use_idle_mode:
            audio_path = os.path.join(input_dir, f"idlemode_{str(length_of_audio)}.wav")
            AudioSegment.silent(duration=1000 * length_of_audio).export(audio_path, format="wav")
            return audio_path
        return None

    @staticmethod
    def _extract_ref_audio(ref_video, save_dir):
        """Extract ``ref_video``'s audio to ``<save_dir>/<video file name>.wav`` with ffmpeg.

        ffmpeg is run without a shell, so the path cannot inject shell syntax.

        Raises:
            subprocess.CalledProcessError: if ffmpeg fails.
        """
        audio_path = os.path.join(save_dir, os.path.basename(ref_video) + ".wav")
        subprocess.run(
            ["ffmpeg", "-y", "-hide_banner", "-loglevel", "error", "-i", str(ref_video), audio_path],
            check=True,
        )
        return audio_path

    def _reference_coeff_paths(self, motion_condition, use_ref_video, ref_video, ref_info, preprocess, save_dir):
        """Return ``(ref_video_coeff, ref_pose_coeff, ref_eyeblink_coeff)`` paths (entries may be ``None``).

        A ``motion_condition`` drives all three. Otherwise, with
        ``use_ref_video`` and a ``ref_video``, the video is preprocessed into
        3DMM coefficients, and ``ref_info`` selects whether they drive pose
        and/or blink. With ``"all"``, the coefficients are used directly, so
        neither the pose nor the blink path is set.

        Raises:
            ValueError: for an unsupported ``ref_info`` (checked before preprocessing).
        """
        if motion_condition is not None:
            motion_coeff_path = self._materialize_motion_condition(motion_condition, save_dir)
            return motion_coeff_path, motion_coeff_path, motion_coeff_path
        if not (use_ref_video and ref_video is not None):
            return None, None, None

        # (drives_pose, drives_blink) for each supported ref_info value.
        routing = {
            "pose": (True, False),
            "blink": (False, True),
            "pose+blink": (True, True),
            "all": (False, False),
        }
        if ref_info not in routing:
            raise ValueError("error in ref_info")
        drives_pose, drives_blink = routing[ref_info]

        video_stem = os.path.splitext(os.path.split(ref_video)[-1])[0]
        frame_dir = os.path.join(save_dir, video_stem)
        os.makedirs(frame_dir, exist_ok=True)
        print("3DMM Extraction for the reference video providing pose")
        ref_video_coeff_path, _, _ = self.preprocess_model.generate(
            ref_video,
            frame_dir,
            preprocess,
            source_image_flag=False,
        )
        return (
            ref_video_coeff_path,
            ref_video_coeff_path if drives_pose else None,
            ref_video_coeff_path if drives_blink else None,
        )

    def generate(
        self,
        source_image=None,
        driven_audio=None,
        preprocess="crop",
        still_mode=False,
        use_enhancer=False,
        batch_size=1,
        size=256,
        pose_style=0,
        facerender="facevid2vid",
        exp_scale=1.0,
        use_ref_video=False,
        ref_video=None,
        ref_info=None,
        use_idle_mode=False,
        length_of_audio=0,
        use_blink=True,
        result_dir="./results/",
        avatar_condition=None,
        motion_condition=None,
    ):
        """Render a talking-head video; return ``(video_path, audio_path, save_dir)``.

        The driving audio comes from one of these sources:

        * ``driven_audio`` (MP3 is converted to WAV).
        * A silent clip, in idle mode.
        * The reference video's own audio track, when ``use_ref_video`` with
          ``ref_info="all"`` and a ``ref_video`` are given. This overrides
          ``driven_audio``.

        The source face comes from ``avatar_condition`` if given, otherwise
        from ``source_image`` via SadTalker preprocessing.
        ``motion_condition``, or a preprocessed ``ref_video``, can supply the
        pose/blink coefficients. Outputs go to ``<result_dir>/<uuid4>``.

        Inputs are validated before any model is loaded, and the model
        instances are released when the call ends, including on failure.

        Raises:
            ValueError: if no audio source or no source face is available, or
                an unsupported ``ref_info`` is used with a reference video.
            AttributeError: if no face is detected or the avatar bundle is invalid.
            subprocess.CalledProcessError: if ffmpeg fails to extract reference audio.
        """
        # -----------------------------
        # Validate before loading heavy models
        # -----------------------------
        has_audio_file = driven_audio is not None and os.path.isfile(driven_audio)
        if not has_audio_file and not use_idle_mode:
            if not (use_ref_video is True and ref_info == "all"):
                raise ValueError(
                    "Either driven_audio, use_idle_mode, or use_ref_video/ref_info='all' must be provided."
                )
            if ref_video is None:
                raise ValueError(
                    "ref_video is required to provide audio when use_ref_video=True and ref_info='all'."
                )
        if avatar_condition is None and source_image is None:
            raise ValueError("source_image is required when avatar_condition is not provided.")

        self._load_modules()
        try:
            self._ensure_models(size=size, preprocess=preprocess, facerender=facerender)

            save_dir = os.path.join(result_dir, str(uuid.uuid4()))
            input_dir = os.path.join(save_dir, "input")
            os.makedirs(input_dir, exist_ok=True)

            # -----------------------------
            # Audio handling
            # -----------------------------
            audio_path = self._prepare_audio(driven_audio, use_idle_mode, length_of_audio, input_dir)
            if use_ref_video and ref_info == "all" and ref_video is not None:
                # Full reference mode: the reference video's audio drives the result.
                audio_path = self._extract_ref_audio(ref_video, save_dir)

            # -----------------------------
            # Avatar / source conditioning
            # -----------------------------
            if avatar_condition is not None:
                first_coeff_path, crop_pic_path, crop_info = self._materialize_avatar_condition(
                    avatar_condition, save_dir
                )
                if first_coeff_path is None:
                    raise AttributeError("Invalid avatar_condition bundle.")
                pic_path = crop_pic_path
            else:
                pic_path = os.path.join(input_dir, os.path.basename(source_image))
                shutil.copy2(source_image, pic_path)
                first_frame_dir = os.path.join(save_dir, "first_frame_dir")
                os.makedirs(first_frame_dir, exist_ok=True)
                first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
                    pic_path,
                    first_frame_dir,
                    preprocess,
                    True,
                    size,
                )
                if first_coeff_path is None:
                    raise AttributeError("No face is detected")

            # -----------------------------
            # Motion conditioning / reference video
            # -----------------------------
            ref_video_coeff_path, ref_pose_coeff_path, ref_eyeblink_coeff_path = self._reference_coeff_paths(
                motion_condition, use_ref_video, ref_video, ref_info, preprocess, save_dir
            )

            # -----------------------------
            # Audio -> coeff
            # -----------------------------
            if use_ref_video and ref_info == "all" and ref_video_coeff_path is not None:
                coeff_path = ref_video_coeff_path
            else:
                batch = self.get_data(
                    first_coeff_path,
                    audio_path,
                    self.device,
                    ref_eyeblink_coeff_path=ref_eyeblink_coeff_path,
                    still=still_mode,
                    idlemode=use_idle_mode,
                    length_of_audio=length_of_audio,
                    use_blink=use_blink,
                )
                coeff_path = self.audio_to_coeff.generate(
                    batch,
                    save_dir,
                    pose_style,
                    ref_pose_coeff_path,
                )

            # -----------------------------
            # coeff -> video
            # -----------------------------
            data = self.get_facerender_data(
                coeff_path,
                crop_pic_path,
                first_coeff_path,
                audio_path,
                batch_size,
                still_mode=still_mode,
                preprocess=preprocess,
                size=size,
                expression_scale=exp_scale,
                facemodel=facerender,
            )
            return_path = self.animate_from_coeff.generate(
                data,
                save_dir,
                pic_path,
                crop_info,
                enhancer="gfpgan" if use_enhancer else None,
                preprocess=preprocess,
                img_size=size,
            )
            video_name = data.get("video_name", "output")
            print(f"The generated video is named {video_name} in {save_dir}")
            return return_path, audio_path, save_dir
        finally:
            self._release_models()
# ============================================================
# PACKED AVATAR ORCHESTRATOR
# ============================================================
class PackedAvatar:
    """Run a packed avatar bundle end to end.

    On construction the ``.pt`` bundle is loaded on the CPU. Its embedded
    ``checkpoints_zip`` and ``sadtalker_zip`` archives are extracted into a
    cache directory, which is reused while the manifest hashes match. The
    extracted root is put on ``sys.path``, and the AvatarBank and the Bria
    background remover are loaded. :meth:`generate` then drives a
    ``SadTalkerRunner`` and can optionally post-process the result with
    Wav2Lip GAN.
    """

    def __init__(
        self,
        packed_pt_path: Optional[str] = None,
        cache_dir: Optional[str] = None,
        device: Optional[str] = None,
    ):
        """Load the bundle, mount its archives and initialise runtime helpers.

        Args:
            packed_pt_path: Path to the packed bundle. Defaults to
                ``checkpoints/PackedAvatar.pt`` next to this file.
                NOTE(review): the CLI default is ``PackedAvatar.pt`` next to
                this file; confirm which location is intended.
            cache_dir: Root for extracted archives. Defaults to
                ``<system temp>/PackedAvatarCache``.
            device: Torch device string. Auto-detected when omitted.

        Raises:
            FileNotFoundError: If the bundle file does not exist.
        """
        default_path = Path(__file__).resolve().parent / "checkpoints" / "PackedAvatar.pt"
        self.packed_pt_path = Path(packed_pt_path or default_path)
        if not self.packed_pt_path.exists():
            raise FileNotFoundError(f"Packed bundle not found: {self.packed_pt_path}")
        self.device = device or self._default_device()
        self.cache_dir = Path(cache_dir) if cache_dir else Path(tempfile.gettempdir()) / "PackedAvatarCache"
        ensure_dir(self.cache_dir)
        self.bundle = self._load_bundle(self.packed_pt_path)
        self.manifest = self.bundle.get("manifest", {}) or {}
        self._extract_and_mount()
        self._mount_python_path()
        self.avatar_bank = self._load_avatar_bank()
        self.bria_root = self.extracted_root / "checkpoints" / "briaaiRMBG-2.0"
        self.background_remover = BriaBackgroundRemover(self.bria_root)
        self._runner_cache: Dict[Tuple[int, str, str], SadTalkerRunner] = {}

    @staticmethod
    def _default_device() -> str:
        """Pick CUDA, then Apple MPS (only if actually usable), then CPU."""
        if torch.cuda.is_available():
            return "cuda"
        # Being on macOS is not enough: Intel Macs and older torch builds have no MPS.
        mps_backend = getattr(torch.backends, "mps", None)
        if platform.system() == "Darwin" and mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    @staticmethod
    def _load_bundle(path: Path) -> Dict[str, Any]:
        """Load the packed bundle dictionary onto the CPU.

        ``weights_only=False`` unpickles arbitrary objects, so only load
        bundles from trusted sources.

        Raises:
            ValueError: If the file does not hold a dict.
        """
        bundle = torch.load(str(path), map_location="cpu", weights_only=False)
        if not isinstance(bundle, dict):
            raise ValueError("PackedAvatar.pt did not contain a dictionary bundle.")
        return bundle

    def _asset_bytes(self, key: str) -> bytes:
        """Return the raw bytes of bundle asset *key*, stored as bytes or a tensor.

        Raises:
            KeyError: If the asset is absent.
        """
        assets = self.bundle.get("assets") or {}
        asset = assets.get(key)
        if asset is None:
            raise KeyError(f"Missing asset in bundle: {key}")
        return tensor_to_bytes(asset)

    def _archive_sha256(self, archive: str, default: Any = None) -> Any:
        """Return ``manifest["archives"][archive]["sha256"]``, or *default* if absent."""
        archives = self.manifest.get("archives") or {}
        entry = archives.get(archive) or {}
        return entry.get("sha256", default)

    def _bundle_id(self) -> str:
        """Return a short id derived from the archive hashes; it names the cache directory."""
        ck_hash = self._archive_sha256("checkpoints_zip", "")
        sd_hash = self._archive_sha256("sadtalker_zip", "")
        return sha256_bytes(f"{ck_hash}:{sd_hash}".encode("utf-8"))[:16]

    @staticmethod
    def _marker_matches(marker: Path, expected: Dict[str, Any]) -> bool:
        """Return True if *marker* exists and holds exactly *expected*; unreadable means False."""
        try:
            return json.loads(marker.read_text(encoding="utf-8")) == expected
        except (OSError, ValueError):
            return False

    @staticmethod
    def _clear_dir(directory: Path) -> None:
        """Remove everything inside *directory* on a best-effort basis (errors are ignored)."""
        for child in list(directory.iterdir()):
            if child.is_dir() and not child.is_symlink():
                shutil.rmtree(child, ignore_errors=True)
            else:
                try:
                    child.unlink()
                except OSError:
                    pass

    def _extract_and_mount(self) -> None:
        """Extract both bundled archives into the cache, reusing a valid earlier extraction.

        Sets ``runtime_root``, ``extracted_root``, ``checkpoints_dir`` and
        ``sadtalker_dir``. The ``mount.json`` marker records the archive hashes
        of the last extraction. It is written only after both expected folders
        are confirmed present, and it is removed before re-extracting. A
        partial or broken extraction is therefore redone on the next start.

        Raises:
            RuntimeError: If an expected folder is missing after extraction.
        """
        bundle_id = self._bundle_id()
        self.runtime_root = self.cache_dir / f"packedavatar_{bundle_id}"
        self.extracted_root = self.runtime_root / "extracted"
        self.checkpoints_dir = self.extracted_root / "checkpoints"
        self.sadtalker_dir = self.extracted_root / "SadTalker"
        ensure_dir(self.extracted_root)

        marker = self.runtime_root / "mount.json"
        expected = {
            "bundle_id": bundle_id,
            "checkpoints_sha256": self._archive_sha256("checkpoints_zip"),
            "sadtalker_sha256": self._archive_sha256("sadtalker_zip"),
        }
        if (
            self._marker_matches(marker, expected)
            and self.checkpoints_dir.is_dir()
            and self.sadtalker_dir.is_dir()
        ):
            return

        # Stale, missing or incomplete extraction: invalidate the marker first so
        # an interrupted re-extraction is never mistaken for a valid one.
        try:
            marker.unlink()
        except OSError:
            pass
        self._clear_dir(self.extracted_root)

        # Both archives are extracted into the same root.
        extract_zip_bytes_to_dir(self._asset_bytes("checkpoints_zip"), self.extracted_root)
        extract_zip_bytes_to_dir(self._asset_bytes("sadtalker_zip"), self.extracted_root)

        if not self.checkpoints_dir.exists():
            raise RuntimeError(f"checkpoints folder missing after extraction: {self.checkpoints_dir}")
        if not self.sadtalker_dir.exists():
            raise RuntimeError(f"SadTalker folder missing after extraction: {self.sadtalker_dir}")
        marker.write_text(json.dumps(expected, indent=2), encoding="utf-8")

    def _mount_python_path(self) -> None:
        """Prepend the extracted root to ``sys.path`` (once) so its packages import."""
        extracted = str(self.extracted_root)
        if extracted not in sys.path:
            sys.path.insert(0, extracted)

    def _load_avatar_bank(self) -> AvatarBankRuntime:
        """Load ``AvatarBank.pt`` from the extracted checkpoints, using the manifest defaults.

        Raises:
            FileNotFoundError: If the bank file is missing.
        """
        bank_path = self.checkpoints_dir / "AvatarBank.pt"
        if not bank_path.exists():
            raise FileNotFoundError(f"AvatarBank.pt not found inside packed checkpoints: {bank_path}")
        manifest_defaults = self.manifest.get("defaults") or {}
        defaults = {
            "default_avatar": manifest_defaults.get("default_avatar", ""),
            "real_style_aliases": manifest_defaults.get("real_style_aliases", list(REAL_STYLE_ALIASES)),
        }
        return AvatarBankRuntime.load(bank_path, defaults=defaults)

    @staticmethod
    def _runner_key(size: int, preprocess: str, facerender: str) -> Tuple[int, str, str]:
        """Return the cache key for a runner configuration."""
        return (int(size), preprocess, facerender)

    def _get_runner(self, size: int, preprocess: str, facerender: str) -> SadTalkerRunner:
        """Return the cached SadTalkerRunner for this key, building one if needed.

        The runner is constructed identically for every key. The size,
        preprocess and facerender values are only passed later, to its
        methods.
        """
        key = self._runner_key(size, preprocess, facerender)
        runner = self._runner_cache.get(key)
        if runner is None:
            runner = SadTalkerRunner(
                checkpoint_path=str(self.checkpoints_dir),
                config_path=str(self.sadtalker_dir / "src" / "config"),
                device=self.device,
            )
            self._runner_cache[key] = runner
        return runner

    def extract_embeddings(
        self,
        input_path: str,
        crop_or_resize: str = "crop",
        pic_size: int = 256,
        save_dir: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Extract a conditioning bundle from a source image or reference video.
        The returned dictionary is the same kind of bundle the runtime uses
        internally for avatar conditioning and motion conditioning.
        """
        runner = self._get_runner(size=pic_size, preprocess=crop_or_resize, facerender="facevid2vid")
        return runner.extract_embeddings(
            input_path=input_path,
            crop_or_resize=crop_or_resize,
            pic_size=pic_size,
            save_dir=save_dir,
        )

    def ExtractEmbeddings(
        self,
        input_path: str,
        crop_or_resize: str = "crop",
        pic_size: int = 256,
        save_dir: Optional[str] = None,
    ) -> Dict[str, Any]:
        """CamelCase alias of :meth:`extract_embeddings`."""
        return self.extract_embeddings(
            input_path=input_path,
            crop_or_resize=crop_or_resize,
            pic_size=pic_size,
            save_dir=save_dir,
        )

    def _resolve_avatar_condition_from_bank(self, avatar_id: Optional[str]) -> Dict[str, Any]:
        """Build a condition for *avatar_id*, or for the bank's default avatar when None."""
        if avatar_id is None:
            avatar_id = self.avatar_bank.resolve_default_avatar_id()
        return self.avatar_bank.build_avatar_condition(avatar_id)

    def _normalize_avatar_condition(self, avatar_condition: Any) -> Optional[Dict[str, Any]]:
        """Load *avatar_condition* via ``safe_load_bundle`` and ensure ``coeff_3dmm`` is set.

        When ``coeff_3dmm`` is missing or None it is filled from ``motion_3dmm``,
        falling back to ``full_3dmm``. The bundle is shallow-copied before
        that, so a dict supplied by the caller is not mutated.
        Returns None when nothing could be loaded.
        """
        bundle = safe_load_bundle(avatar_condition)
        if bundle is None:
            return None
        if bundle.get("coeff_3dmm") is None:
            for fallback_key in ("motion_3dmm", "full_3dmm"):
                fallback = bundle.get(fallback_key)
                if fallback is not None:
                    bundle = dict(bundle)
                    bundle["coeff_3dmm"] = fallback
                    break
        return bundle

    def _remove_background_if_requested(
        self,
        source_image: Optional[str],
        remove_background: bool,
        work_dir: Path,
    ) -> Optional[Path]:
        """Stage *source_image* into *work_dir* and prepare it for SadTalker.

        Returns None when *source_image* is None. Otherwise the image is copied
        into *work_dir* and passed through ``prepare_image_for_sadtalker``.
        When *remove_background* is set, a Bria RMBG cutout is passed along
        with it.

        Raises:
            FileNotFoundError: If *source_image* does not exist.
            RuntimeError: If background removal was requested but failed.
        """
        if source_image is None:
            return None
        src = Path(source_image)
        if not src.exists():
            raise FileNotFoundError(str(src))
        ensure_dir(work_dir)
        staged = work_dir / src.name
        # Copying a file onto itself raises shutil.SameFileError; skip if already staged.
        if src.resolve() != staged.resolve():
            shutil.copy2(src, staged)
        if not remove_background:
            return prepare_image_for_sadtalker(staged)
        # Background removal uses the packed Bria folder.
        try:
            removed = self.background_remover.remove_background(staged, work_dir)
            return prepare_image_for_sadtalker(staged, removed)
        except Exception as e:
            raise RuntimeError(
                f"remove_background=True was requested, but Bria RMBG execution failed: {e}"
            ) from e

    def _run_wav2lip_gan(
        self,
        face_video: str,
        audio_path: str,
        save_dir: str,
        wav2lip_repo: Optional[str] = None,
    ) -> str:
        """Re-sync the lips in *face_video* to *audio_path* with the bundled Wav2Lip GAN.

        ``inference.py`` is looked for in *wav2lip_repo* (if given), then in
        ``checkpoints/Wav2Lip``, ``SadTalker/Wav2Lip`` and a ``Wav2Lip`` folder
        next to this file. If none has it, *face_video* is returned unchanged.

        Returns:
            ``<save_dir>/<face stem>_wav2lip_gan.mp4`` after a successful run,
            otherwise *face_video*.

        Raises:
            FileNotFoundError: If ``wav2lip_gan.pth`` is missing from the checkpoints.
            subprocess.CalledProcessError: If the Wav2Lip script exits non-zero.
        """
        wav2lip_checkpoint = self.checkpoints_dir / "wav2lip_gan.pth"
        if not wav2lip_checkpoint.is_file():
            raise FileNotFoundError(
                f"Could not find bundled Wav2Lip GAN checkpoint at: {wav2lip_checkpoint}"
            )

        # An explicit repo comes first, then the packed locations.
        candidate_repos: List[Path] = [Path(wav2lip_repo)] if wav2lip_repo else []
        candidate_repos += [
            self.checkpoints_dir / "Wav2Lip",
            self.sadtalker_dir / "Wav2Lip",
            Path(__file__).resolve().parent / "Wav2Lip",
        ]
        repo = next((c for c in candidate_repos if (c / "inference.py").is_file()), None)

        # A missing Wav2Lip codebase is not fatal: fall back to the SadTalker video.
        if repo is None:
            if wav2lip_repo:
                print(f"[PackedAvatar] No inference.py found in wav2lip_repo={wav2lip_repo!r}.")
            print(
                "[PackedAvatar] Wav2Lip inference code was not found; "
                "skipping Wav2Lip post-processing and returning the SadTalker video."
            )
            return face_video

        repo = repo.resolve()
        out_video = os.path.join(save_dir, f"{Path(face_video).stem}_wav2lip_gan.mp4")
        # The script runs with cwd=repo, so every path must be absolute. Otherwise it
        # would be resolved against the Wav2Lip folder instead of the caller's cwd.
        cmd = [
            sys.executable,
            str(repo / "inference.py"),
            "--checkpoint_path",
            str(wav2lip_checkpoint.resolve()),
            "--face",
            str(Path(face_video).resolve()),
            "--audio",
            str(Path(audio_path).resolve()),
            "--outfile",
            str(Path(out_video).resolve()),
        ]
        subprocess.run(cmd, cwd=str(repo), check=True)
        return out_video

    # ---- AvatarBank pass-throughs -------------------------------------------------

    def list_avatars(self):
        """Delegate to ``AvatarBankRuntime.list_avatars``."""
        return self.avatar_bank.list_avatars()

    def query_avatars(self, *args, **kwargs):
        """Delegate to ``AvatarBankRuntime.query`` with the same arguments."""
        return self.avatar_bank.query(*args, **kwargs)

    def fuzzy_search_avatar(self, query, n=5, cutoff=0.5):
        """Delegate to ``AvatarBankRuntime.fuzzy_search_id(query, n, cutoff)``."""
        return self.avatar_bank.fuzzy_search_id(query, n, cutoff)

    def get_avatar(self, avatar_id):
        """Delegate to ``AvatarBankRuntime.get_avatar``."""
        return self.avatar_bank.get_avatar(avatar_id)

    def get_avatar_preview(self, avatar_id):
        """Delegate to ``AvatarBankRuntime.get_preview``."""
        return self.avatar_bank.get_preview(avatar_id)

    def get_avatar_metadata(self, avatar_id):
        """Delegate to ``AvatarBankRuntime.get_metadata``."""
        return self.avatar_bank.get_metadata(avatar_id)

    def delete_avatar(self, avatar_id):
        """Delegate to ``AvatarBankRuntime.delete_avatar``."""
        self.avatar_bank.delete_avatar(avatar_id)

    def save_avatar_bank(self, path):
        """Delegate to ``AvatarBankRuntime.save``."""
        self.avatar_bank.save(path)

    # ---- generation -------------------------------------------------------------

    def generate(
        self,
        source_image: Optional[str] = None,
        driven_audio: Optional[str] = None,
        preprocess: str = "crop",
        still_mode: bool = False,
        use_enhancer: bool = False,
        batch_size: int = 1,
        size: int = 256,
        pose_style: int = 0,
        facerender: str = "facevid2vid",
        exp_scale: float = 1.0,
        use_ref_video: bool = False,
        ref_video: Optional[str] = None,
        ref_info: Optional[str] = None,
        use_idle_mode: bool = False,
        length_of_audio: int = 0,
        use_blink: bool = True,
        result_dir: str = "./results/",
        avatar_id: Optional[str] = None,
        avatar_condition: Optional[Any] = None,
        motion_condition: Optional[Any] = None,
        remove_background: bool = False,
        use_wav2lip: bool = False,
        wav2lip_repo: Optional[str] = None,
    ) -> str:
        """Render a talking-head video and return the path reported for it.

        Conditioning precedence:
          1. *avatar_condition*, if it loads to a bundle. It overrides both
             *source_image* and *avatar_id*.
          2. Otherwise *source_image*, staged under ``runtime_root/source_work``
             and background-removed if *remove_background* is set.
          3. Otherwise the AvatarBank entry *avatar_id* (the bank default when None).

        The remaining arguments are forwarded to ``SadTalkerRunner.generate``.
        When *use_wav2lip* is set, the result is post-processed by
        :meth:`_run_wav2lip_gan`.
        """
        runner_key = self._runner_key(size, preprocess, facerender)
        runner = self._get_runner(size=size, preprocess=preprocess, facerender=facerender)
        ensure_dir(Path(result_dir))

        resolved_avatar_condition = self._normalize_avatar_condition(avatar_condition)
        # An explicit avatar_condition supersedes source_image-driven conditioning.
        source_image_for_runner: Optional[str] = None
        if resolved_avatar_condition is None:
            if source_image is None:
                resolved_avatar_condition = self._resolve_avatar_condition_from_bank(avatar_id)
            else:
                # The source image goes straight to SadTalker, optionally background-removed.
                source_work_dir = self.runtime_root / "source_work"
                ensure_dir(source_work_dir)
                prepared = self._remove_background_if_requested(source_image, remove_background, source_work_dir)
                source_image_for_runner = str(prepared) if prepared is not None else source_image

        try:
            return_path, audio_path, save_dir = runner.generate(
                source_image=source_image_for_runner,
                driven_audio=driven_audio,
                preprocess=preprocess,
                still_mode=still_mode,
                use_enhancer=use_enhancer,
                batch_size=batch_size,
                size=size,
                pose_style=pose_style,
                facerender=facerender,
                exp_scale=exp_scale,
                use_ref_video=use_ref_video,
                ref_video=ref_video,
                ref_info=ref_info,
                use_idle_mode=use_idle_mode,
                length_of_audio=length_of_audio,
                use_blink=use_blink,
                result_dir=result_dir,
                avatar_condition=resolved_avatar_condition,
                motion_condition=motion_condition,
            )
        finally:
            # SadTalkerRunner.generate deletes preprocess_model / audio_to_coeff /
            # animate_from_coeff before returning, so a used runner cannot be reused.
            # Evict it so the next call builds a fresh one.
            self._runner_cache.pop(runner_key, None)

        if use_wav2lip:
            return_path = self._run_wav2lip_gan(
                face_video=return_path,
                audio_path=audio_path,
                save_dir=save_dir,
                wav2lip_repo=wav2lip_repo,
            )
        return return_path
# Alternative public name for PackedAvatar; both names refer to the same class object.
PackedAvatarModel = PackedAvatar
# ============================================================
# CLI
# ============================================================
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for running a packed avatar bundle.

    Every option maps onto a ``PackedAvatar`` constructor argument or a
    ``PackedAvatar.generate`` argument (see ``main``). Boolean options that
    default to on (Wav2Lip, blinking) come as ``--use-x`` / ``--no-x`` pairs,
    so they can actually be switched off. A lone ``store_true`` with
    ``default=True`` can never be false.
    """
    p = argparse.ArgumentParser(description="Run the packed avatar bundle.")
    p.add_argument("--packed-pt", type=Path, default=Path(__file__).resolve().parent / "PackedAvatar.pt")
    p.add_argument("--cache-dir", type=Path, default=None)
    p.add_argument("--device", type=str, default=None)
    p.add_argument("--source-image", type=Path, default=None)
    p.add_argument("--driven-audio", type=Path, default="speech.wav")
    p.add_argument("--avatar-id", type=str, default=None)
    p.add_argument("--avatar-condition", type=Path, default=None)
    p.add_argument("--motion-condition", type=Path, default=None)
    p.add_argument("--remove-background", action="store_true")
    # Wav2Lip post-processing is on by default; --no-wav2lip turns it off.
    p.add_argument("--use-wav2lip", action="store_true", default=True)
    p.add_argument("--no-wav2lip", action="store_false", dest="use_wav2lip")
    p.add_argument("--wav2lip-repo", type=Path, default=None)
    p.add_argument("--result-dir", type=Path, default=Path("./results"))
    p.add_argument("--preprocess", type=str, default="crop")
    p.add_argument("--size", type=int, default=256)
    p.add_argument("--facerender", type=str, default="facevid2vid")
    p.add_argument("--still-mode", action="store_true")
    p.add_argument("--use-enhancer", action="store_true")
    p.add_argument("--batch-size", type=int, default=1)
    p.add_argument("--pose-style", type=int, default=0)
    p.add_argument("--exp-scale", type=float, default=1.0)
    p.add_argument("--use-ref-video", action="store_true")
    p.add_argument("--ref-video", type=Path, default=None)
    p.add_argument("--ref-info", type=str, default=None)
    p.add_argument("--use-idle-mode", action="store_true")
    p.add_argument("--length-of-audio", type=int, default=0)
    # Blinking is on by default; --no-blink turns it off.
    p.add_argument("--use-blink", action="store_true", default=True)
    p.add_argument("--no-blink", action="store_false", dest="use_blink")
    # NOTE: accepted for compatibility but not read by main().
    p.add_argument("--manual-audio", action="store_true", help="Alias for driven-audio handling; kept for clarity.")
    return p
def main() -> None:
    """CLI entry point: parse arguments, build the model, run ``generate``, print the result."""
    args = build_parser().parse_args()

    def optional_str(value):
        # Path-like CLI values become plain strings; unset (falsy) values become None.
        return str(value) if value else None

    model = PackedAvatar(
        packed_pt_path=str(args.packed_pt),
        cache_dir=optional_str(args.cache_dir),
        device=args.device,
    )
    generate_kwargs = dict(
        source_image=optional_str(args.source_image),
        driven_audio=optional_str(args.driven_audio),
        preprocess=args.preprocess,
        still_mode=args.still_mode,
        use_enhancer=args.use_enhancer,
        batch_size=args.batch_size,
        size=args.size,
        pose_style=args.pose_style,
        facerender=args.facerender,
        exp_scale=args.exp_scale,
        use_ref_video=args.use_ref_video,
        ref_video=optional_str(args.ref_video),
        ref_info=args.ref_info,
        use_idle_mode=args.use_idle_mode,
        length_of_audio=args.length_of_audio,
        use_blink=args.use_blink,
        result_dir=str(args.result_dir),
        avatar_id=args.avatar_id,
        avatar_condition=optional_str(args.avatar_condition),
        motion_condition=optional_str(args.motion_condition),
        remove_background=args.remove_background,
        use_wav2lip=args.use_wav2lip,
        wav2lip_repo=optional_str(args.wav2lip_repo),
    )
    print(model.generate(**generate_kwargs))
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()