artalk-youtube / inference_engine.py
Nanny7's picture
Deploy ARTalk ZeroGPU Space
e9ab57b
"""ARTalk inference engine for HuggingFace ZeroGPU deployment.
원본 C:/DK/AT/ARTalk/inference.py 의 ARTAvatarInferEngine 을 기반으로,
ZeroGPU 환경에 맞춰 다음을 수정:
1. 모델 가중치를 HF Hub 에서 자동 다운로드 (build_resources.sh 대체)
2. FLAME 라이선스 인터랙티브 프롬프트 제거
3. 기본 동작은 동일 (inference + rendering)
app.py 에서 이 모듈의 ARTAvatarInferEngine 를 import 해서 사용.
"""
import os
import json
import shutil
import torch
import numpy as np
from pathlib import Path
from huggingface_hub import hf_hub_download
from app import BitwiseARModel
from app.flame_model import FLAMEModel, RenderMesh
from app.utils_videos import write_video
# ═══════════════════════════════════════════════════════════
# Model weights — HuggingFace Hub 에서 런타임 다운로드
# ═══════════════════════════════════════════════════════════
# Hugging Face Hub repositories the files below are downloaded from.
ARTALK_REPO = "xg-chu/ARTalk"
GAGA_REPO = "xg-chu/GAGAvatar"

# Local directory (relative to the working directory) holding all assets.
ASSETS_DIR = Path("assets")

# One row per required file:
# (repo_id, filename_in_repo, local_path_under_assets)
WEIGHTS_MANIFEST = [
    (GAGA_REPO, "assets/FLAME_with_eye.pt", "FLAME_with_eye.pt"),
    (GAGA_REPO, "assets/GAGAvatar.pt", "GAGAvatar/GAGAvatar.pt"),
    (ARTALK_REPO, "ARTalk_wav2vec.pt", "ARTalk_wav2vec.pt"),
    (ARTALK_REPO, "config.json", "config.json"),
    (ARTALK_REPO, "GAGAvatar/tracked.pt", "GAGAvatar/tracked.pt"),
]

# Style-motion preset names, generated as "<emotion>_<index>" from the number
# of variants available per emotion (order matters: it is the listing order).
STYLE_MOTIONS = [
    f"{emotion}_{index}"
    for emotion, variants in (
        ("angry", 1),
        ("curious", 1),
        ("doubtful", 2),
        ("happy", 3),
        ("natural", 8),
    )
    for index in range(variants)
]
def _place_file(downloaded: str, target: Path):
"""HF cache 에서 다운로드된 파일을 target 경로에 배치.
os.link 는 cross-filesystem 에서 실패 가능 → shutil.copy 사용.
"""
target.parent.mkdir(parents=True, exist_ok=True)
if target.exists() or target.is_symlink():
target.unlink()
shutil.copy(downloaded, target)
assert target.exists(), f"copy failed: {target}"
def download_all_weights():
    """Download every model file ARTalk needs from the Hugging Face Hub.

    Files that already exist under ``ASSETS_DIR`` and are non-empty are
    skipped; ``hf_hub_download`` additionally uses its own cache. The FLAME
    assets are pulled from the xg-chu repositories, so no interactive license
    prompt is shown.

    Entries of ``WEIGHTS_MANIFEST`` are mandatory: a failure is logged and
    re-raised. Style-motion files are best-effort: a failure only logs a
    warning and the remaining styles are still attempted.

    Raises:
        Exception: whatever ``hf_hub_download`` / ``_place_file`` raised while
            fetching a mandatory file.
        FileNotFoundError: if the main ARTalk checkpoint is missing afterwards.
    """

    def _already_placed(path):
        # A zero-byte file is treated as missing so that it gets re-fetched.
        return path.exists() and path.stat().st_size > 0

    ASSETS_DIR.mkdir(exist_ok=True)
    (ASSETS_DIR / "GAGAvatar").mkdir(exist_ok=True)
    (ASSETS_DIR / "style_motion").mkdir(exist_ok=True)

    # Main weights (mandatory).
    for repo_id, repo_file, local_rel in WEIGHTS_MANIFEST:
        local_path = ASSETS_DIR / local_rel
        if _already_placed(local_path):
            print(f"[weights] exists: {local_path}")
            continue
        print(f"[weights] downloading {repo_id}:{repo_file} -> {local_path}", flush=True)
        try:
            downloaded = hf_hub_download(repo_id=repo_id, filename=repo_file)
            _place_file(downloaded, local_path)
            size_mb = local_path.stat().st_size / 1e6
            print(f"[weights] OK: {local_path} ({size_mb:.1f}MB)", flush=True)
        except Exception as e:
            print(f"[weights] FAIL {repo_id}:{repo_file} -> {e}", flush=True)
            raise

    # Style motion files (optional, best-effort).
    for style_name in STYLE_MOTIONS:
        local_path = ASSETS_DIR / "style_motion" / f"{style_name}.pt"
        # Same non-empty check as for the main weights, so an empty leftover
        # file does not permanently mask a missing style.
        if _already_placed(local_path):
            continue
        try:
            downloaded = hf_hub_download(
                repo_id=ARTALK_REPO,
                filename=f"style_motion/{style_name}.pt",
            )
            _place_file(downloaded, local_path)
        except Exception as e:
            # Deliberately non-fatal: a missing style must not block startup.
            print(f"[weights] warning: style {style_name} unavailable ({e})")

    # Verification: report what is on disk and insist on the main checkpoint.
    total = sum(1 for _ in ASSETS_DIR.rglob('*.pt'))
    main = ASSETS_DIR / "ARTalk_wav2vec.pt"
    main_present = main.exists()
    size_gb = main.stat().st_size / 1e9 if main_present else 0
    print(f"[weights] ready - {total} .pt files; main={main_present} size={size_gb:.2f}GB", flush=True)
    if not main_present:
        raise FileNotFoundError(f"Main weight not placed: {main}")
# ═══════════════════════════════════════════════════════════
# ARTAvatarInferEngine — 원본 inference.py 와 동일 구조
# ═══════════════════════════════════════════════════════════
class ARTAvatarInferEngine:
    """ARTalk inference + rendering engine.

    Replicates the ARTAvatarInferEngine class of the original
    C:/DK/AT/ARTalk/inference.py, except that construction first makes sure
    the required assets have been downloaded.
    """

    def __init__(self, load_gaga=True, fix_pose=False, clip_length=750, device="cuda"):
        """Download missing weights and build the models.

        Args:
            load_gaga: also load GAGAvatar; required for rendering with a
                ``shape_id`` other than "mesh" and for add_custom_avatars().
            fix_pose: when true, inference() zeroes motion channels 100:103
                and 104: of its result.
            clip_length: default maximum number of frames kept by inference().
            device: torch device the models are moved to.
        """
        # Download the weights if they are not present yet.
        download_all_weights()
        self.device = device
        self.fix_pose = fix_pose
        self.clip_length = clip_length
        audio_encoder = "wav2vec"
        ckpt = torch.load(
            f"./assets/ARTalk_{audio_encoder}.pt",
            map_location="cpu",
            weights_only=True,
        )
        with open("./assets/config.json") as f:
            configs = json.load(f)
        configs["AR_CONFIG"]["AUDIO_ENCODER"] = audio_encoder
        self.ARTalk = BitwiseARModel(configs).eval().to(device)
        self.ARTalk.load_state_dict(ckpt, strict=True)
        self.flame_model = FLAMEModel(
            n_shape=300, n_exp=100, scale=1.0, no_lmks=True
        ).to(device)
        self.mesh_renderer = RenderMesh(
            image_size=512, faces=self.flame_model.get_faces(), scale=1.0
        )
        self.output_dir = f"render_results/ARTAvatar_{audio_encoder}"
        os.makedirs(self.output_dir, exist_ok=True)
        # No style is selected until set_style_motion() is called.
        self.style_motion = None
        if load_gaga:
            from app.GAGAvatar import GAGAvatar
            self.GAGAvatar = GAGAvatar().to(device)
            self.GAGAvatar_flame = FLAMEModel(
                n_shape=300, n_exp=100, scale=5.0, no_lmks=True
            ).to(device)
            # Disable GAGAvatar's baked-in watermark (MIT license allows modification).
            self.GAGAvatar.add_water_mark = lambda image: image

    def add_custom_avatars(self, tracked_pt_path):
        """Merge a user-supplied tracked.pt into GAGAvatar.all_gagavatar_id.

        Accepted formats:
        1. dict-of-dicts: {"alice": {image, shapecode, transform_matrix, ...}}
           -> all entries merged as-is.
        2. single entry dict: {image, shapecode, transform_matrix, ...}
           -> registered under a key derived from the filename (without the
           'tracked_' prefix and '.pt'), or "custom" if that is empty.

        Returns:
            List of newly registered avatar_ids.

        Raises:
            RuntimeError: if the engine was built with load_gaga=False.
            ValueError: if the file is not a dict in one of the accepted
                formats, or an entry lacks a required key.
        """
        if not hasattr(self, "GAGAvatar"):
            raise RuntimeError("GAGAvatar not loaded (load_gaga=False)")
        # NOTE(security): weights_only=False unpickles arbitrary objects, so a
        # crafted tracked.pt can execute code on load. Only feed this method
        # files from a trusted source, or switch to weights_only=True if the
        # tracked.pt payload turns out to be plain tensors/containers.
        loaded = torch.load(tracked_pt_path, map_location="cpu", weights_only=False)
        if not isinstance(loaded, dict):
            raise ValueError(f"tracked.pt must be a dict, got {type(loaded).__name__}")
        required = {"image", "shapecode", "transform_matrix"}
        added = []
        # Detect format: dict-of-dicts vs single entry
        values_are_dicts = all(isinstance(v, dict) for v in loaded.values()) if loaded else False
        looks_like_single = required.issubset(loaded.keys())
        if looks_like_single and not values_are_dicts:
            stem = Path(tracked_pt_path).stem
            if stem.startswith("tracked_"):
                stem = stem[len("tracked_"):]
            avatar_id = stem or "custom"
            entries = {avatar_id: loaded}
        elif values_are_dicts:
            entries = loaded
        else:
            raise ValueError(
                "tracked.pt must be either a single-avatar dict with keys "
                "{image, shapecode, transform_matrix} or a dict-of-dicts mapping "
                "avatar_id -> {image, shapecode, transform_matrix, ...}"
            )
        for avatar_id, entry in entries.items():
            missing = required - set(entry.keys())
            if missing:
                raise ValueError(f"avatar '{avatar_id}' missing keys: {missing}")
            self.GAGAvatar.all_gagavatar_id[avatar_id] = entry
            added.append(avatar_id)
        return added

    def set_style_motion(self, style_motion):
        """Select the style motion passed to the model by inference().

        Args:
            style_motion: either a preset name, loaded from
                assets/style_motion/<name>.pt, or a tensor. In both cases the
                tensor must have shape (50, 106).
        """
        if isinstance(style_motion, str):
            style_motion = torch.load(
                f"assets/style_motion/{style_motion}.pt",
                map_location="cpu",
                weights_only=True,
            )
        assert style_motion.shape == (50, 106), (
            f"Invalid style_motion shape: {style_motion.shape}"
        )
        # Stored with a leading batch dimension of 1.
        self.style_motion = style_motion[None].to(self.device)

    def inference(self, audio, clip_length=None):
        """Predict smoothed motion codes for an audio tensor.

        Args:
            audio: audio tensor; a batch dimension is added before it is
                handed to the model. Presumably a 16 kHz waveform, matching
                the 16000 used in rendering() -- TODO confirm.
            clip_length: maximum number of frames to keep; defaults to the
                value given at construction.

        Returns:
            The smoothed motion codes, truncated to ``clip_length`` frames,
            with channels 100:103 and 104: zeroed when ``fix_pose`` is set.
        """
        audio_batch = {
            "audio": audio[None].to(self.device),
            "style_motion": self.style_motion,
        }
        pred_motions = self.ARTalk.inference(audio_batch, with_gtmotion=False)[0]
        clip_length = clip_length if clip_length is not None else self.clip_length
        pred_motions = self.smooth_motion_savgol(pred_motions)[:clip_length]
        if self.fix_pose:
            pred_motions[..., 100:103] *= 0.0
            pred_motions[..., 104:] *= 0.0
        return pred_motions

    def rendering(self, audio, pred_motions, shape_id="mesh", shape_code=None, save_name="ARTAvatar.mp4"):
        """Render predicted motions to a video file with the audio attached.

        Args:
            audio: audio tensor the motions were predicted from; it is trimmed
                to the video duration (frames / 25.0 * 16000 samples).
            pred_motions: per-frame motion codes as returned by inference().
            shape_id: "mesh" renders the FLAME mesh; any other value is used
                as a GAGAvatar avatar id (requires load_gaga=True).
            shape_code: optional (1, 300) shape code for mesh rendering;
                zeros are used when omitted.
            save_name: output file name inside ``self.output_dir``; a ".mp4"
                suffix is added only if it is not already present.

        Returns:
            Path of the written video file.
        """
        pred_images = []
        if shape_id == "mesh":
            if shape_code is None:
                shape_code = audio.new_zeros(1, 300).to(self.device).expand(
                    pred_motions.shape[0], -1
                )
            else:
                assert shape_code.dim() == 2
                assert shape_code.shape[0] == 1
                shape_code = shape_code.to(self.device).expand(
                    pred_motions.shape[0], -1
                )
            verts = self.ARTalk.basic_vae.get_flame_verts(
                self.flame_model, shape_code, pred_motions, with_global=True
            )
            for v in verts:
                rgb = self.mesh_renderer(v[None])[0]
                # Scaled down here and back up by 255.0 before writing, so both
                # branches feed write_video the same value range.
                pred_images.append(rgb.cpu()[0] / 255.0)
        else:
            self.GAGAvatar.set_avatar_id(shape_id)
            for motion in pred_motions:
                batch = self.GAGAvatar.build_forward_batch(
                    motion[None], self.GAGAvatar_flame
                )
                rgb = self.GAGAvatar.forward_expression(batch)
                pred_images.append(rgb.cpu()[0])
        pred_images = torch.stack(pred_images)
        # The default save_name already ends in ".mp4"; appending the suffix
        # unconditionally used to produce "ARTAvatar.mp4.mp4".
        dump_path = os.path.join(self.output_dir, self._video_filename(save_name))
        audio_slice = audio[: int(pred_images.shape[0] / 25.0 * 16000)]
        write_video(pred_images * 255.0, dump_path, 25.0, audio_slice, 16000, "aac")
        return dump_path

    @staticmethod
    def _video_filename(save_name):
        """Return ``save_name`` as a string carrying exactly one ".mp4" suffix."""
        name = str(save_name)
        if name.endswith(".mp4"):
            return name
        return f"{name}.mp4"

    @staticmethod
    def smooth_motion_savgol(motion_codes):
        """Smooth motion codes along the frame axis with Savitzky-Golay filters.

        All channels are filtered with window 5 / polyorder 2; channels
        100:103 are instead filtered from the raw input with window 9 /
        polyorder 3. For clips shorter than a window, the window shrinks to
        the largest odd length that fits; if that is not larger than the
        polynomial order, the affected channels are returned unsmoothed.

        Args:
            motion_codes: tensor whose first axis is the frame axis.

        Returns:
            A new tensor with the dtype and device of ``motion_codes``.
        """
        from scipy.signal import savgol_filter

        def _smooth(values, window_length, polyorder):
            # savgol_filter (default mode="interp") raises when the window is
            # longer than the data, which used to crash on very short audio.
            window = min(window_length, values.shape[0])
            if window % 2 == 0:
                window -= 1
            if window <= polyorder:
                return values.copy()
            return savgol_filter(
                values, window_length=window, polyorder=polyorder, axis=0
            )

        motion_np = motion_codes.clone().detach().cpu().numpy()
        motion_np_smoothed = _smooth(motion_np, 5, 2)
        motion_np_smoothed[..., 100:103] = _smooth(motion_np[..., 100:103], 9, 3)
        return torch.tensor(motion_np_smoothed).type_as(motion_codes)