"""ARTalk inference engine for HuggingFace ZeroGPU deployment.

Based on ``ARTAvatarInferEngine`` from the original ARTalk ``inference.py``
(C:/DK/AT/ARTalk/inference.py), adapted for a ZeroGPU environment:

1. Model weights are downloaded automatically from the HF Hub
   (replacing ``build_resources.sh``).
2. The interactive FLAME license prompt is removed.
3. Core behaviour (inference + rendering) is unchanged.

``app.py`` imports ``ARTAvatarInferEngine`` from this module.
"""
import os
import json
import shutil
from pathlib import Path

import torch
import numpy as np
from huggingface_hub import hf_hub_download

from app import BitwiseARModel
from app.flame_model import FLAMEModel, RenderMesh
from app.utils_videos import write_video

# ═══════════════════════════════════════════════════════════
# Model weights — downloaded from the HuggingFace Hub at runtime
# ═══════════════════════════════════════════════════════════
ARTALK_REPO = "xg-chu/ARTalk"
GAGA_REPO = "xg-chu/GAGAvatar"
ASSETS_DIR = Path("assets")

# (repo_id, filename_in_repo, local_path_under_assets)
WEIGHTS_MANIFEST = [
    (GAGA_REPO, "assets/FLAME_with_eye.pt", "FLAME_with_eye.pt"),
    (GAGA_REPO, "assets/GAGAvatar.pt", "GAGAvatar/GAGAvatar.pt"),
    (ARTALK_REPO, "ARTalk_wav2vec.pt", "ARTalk_wav2vec.pt"),
    (ARTALK_REPO, "config.json", "config.json"),
    (ARTALK_REPO, "GAGAvatar/tracked.pt", "GAGAvatar/tracked.pt"),
]

STYLE_MOTIONS = [
    "angry_0", "curious_0", "doubtful_0", "doubtful_1",
    "happy_0", "happy_1", "happy_2",
    "natural_0", "natural_1", "natural_2", "natural_3",
    "natural_4", "natural_5", "natural_6", "natural_7",
]


def _is_nonempty_file(path: Path) -> bool:
    """Return True if ``path`` exists and has a non-zero size."""
    # A zero-byte file is treated as missing so an interrupted copy is retried.
    return path.exists() and path.stat().st_size > 0


def _place_file(downloaded: str, target: Path):
    """Copy a file from the HF cache to ``target``.

    ``os.link`` can fail across filesystems, so the file is copied with
    ``shutil.copy`` instead. Any existing file or (possibly dangling) symlink
    at ``target`` is removed first.

    Raises:
        FileNotFoundError: if ``target`` does not exist after the copy.
    """
    target.parent.mkdir(parents=True, exist_ok=True)
    if target.exists() or target.is_symlink():
        target.unlink()
    shutil.copy(downloaded, target)
    if not target.exists():
        # Explicit raise instead of ``assert``: asserts are stripped under -O.
        raise FileNotFoundError(f"copy failed: {target}")


def download_all_weights():
    """Download every model weight ARTalk needs from the HF Hub.

    Files already present under ``ASSETS_DIR`` with a non-zero size are
    skipped; ``hf_hub_download`` additionally maintains its own cache.
    Main weights are mandatory (any failure is re-raised); style-motion clips
    are best-effort and only produce a warning when unavailable. Per the
    original author, the xg-chu repos host the FLAME assets with the license
    already accepted, so no interactive prompt is shown.

    Raises:
        Exception: whatever ``hf_hub_download`` / ``_place_file`` raised for
            a main weight (logged, then re-raised).
        FileNotFoundError: if the main ARTalk checkpoint is missing at the end.
    """
    ASSETS_DIR.mkdir(exist_ok=True)
    (ASSETS_DIR / "GAGAvatar").mkdir(exist_ok=True)
    (ASSETS_DIR / "style_motion").mkdir(exist_ok=True)

    # Main weights — mandatory.
    for repo_id, repo_file, local_rel in WEIGHTS_MANIFEST:
        local_path = ASSETS_DIR / local_rel
        if _is_nonempty_file(local_path):
            print(f"[weights] exists: {local_path}")
            continue
        print(f"[weights] downloading {repo_id}:{repo_file} -> {local_path}", flush=True)
        try:
            downloaded = hf_hub_download(repo_id=repo_id, filename=repo_file)
            _place_file(downloaded, local_path)
            size_mb = local_path.stat().st_size / 1e6
            print(f"[weights] OK: {local_path} ({size_mb:.1f}MB)", flush=True)
        except Exception as e:
            print(f"[weights] FAIL {repo_id}:{repo_file} -> {e}", flush=True)
            raise

    # Style motion files — optional, best-effort.
    for style_name in STYLE_MOTIONS:
        local_path = ASSETS_DIR / "style_motion" / f"{style_name}.pt"
        # Same non-empty check as the main weights so a truncated file is re-fetched.
        if _is_nonempty_file(local_path):
            continue
        try:
            downloaded = hf_hub_download(
                repo_id=ARTALK_REPO,
                filename=f"style_motion/{style_name}.pt",
            )
            _place_file(downloaded, local_path)
        except Exception as e:
            print(f"[weights] warning: style {style_name} unavailable ({e})")

    # Verification
    total = sum(1 for _ in ASSETS_DIR.rglob("*.pt"))
    main = ASSETS_DIR / "ARTalk_wav2vec.pt"
    main_exists = main.exists()
    main_gb = main.stat().st_size / 1e9 if main_exists else 0
    print(f"[weights] ready - {total} .pt files; main={main_exists} size={main_gb:.2f}GB", flush=True)
    if not main_exists:
        raise FileNotFoundError(f"Main weight not placed: {main}")


# ═══════════════════════════════════════════════════════════
# ARTAvatarInferEngine — same structure as the original inference.py
# ═══════════════════════════════════════════════════════════
class ARTAvatarInferEngine:
    """ARTalk inference + rendering engine.

    A copy of the class from the original ARTalk ``inference.py`` that first
    ensures all assets are downloaded before the models are built.
    """

    def __init__(self, load_gaga=True, fix_pose=False, clip_length=750, device="cuda"):
        """Download weights if needed and build the models.

        Args:
            load_gaga: also load GAGAvatar; without it only
                ``shape_id="mesh"`` rendering is available.
            fix_pose: if True, ``inference`` zeroes motion channels
                ``100:103`` and ``104:``.
            clip_length: default maximum number of frames returned by
                ``inference``.
            device: torch device the models are moved to.
        """
        # Fetch any missing weights before touching the filesystem below.
        download_all_weights()

        self.device = device
        self.fix_pose = fix_pose
        self.clip_length = clip_length

        audio_encoder = "wav2vec"
        ckpt = torch.load(
            ASSETS_DIR / f"ARTalk_{audio_encoder}.pt",
            map_location="cpu",
            weights_only=True,
        )
        with open(ASSETS_DIR / "config.json", encoding="utf-8") as f:
            configs = json.load(f)
        configs["AR_CONFIG"]["AUDIO_ENCODER"] = audio_encoder

        self.ARTalk = BitwiseARModel(configs).eval().to(device)
        self.ARTalk.load_state_dict(ckpt, strict=True)

        self.flame_model = FLAMEModel(
            n_shape=300, n_exp=100, scale=1.0, no_lmks=True
        ).to(device)
        self.mesh_renderer = RenderMesh(
            image_size=512, faces=self.flame_model.get_faces(), scale=1.0
        )

        self.output_dir = f"render_results/ARTAvatar_{audio_encoder}"
        os.makedirs(self.output_dir, exist_ok=True)

        # None means "no style conditioning" (see set_style_motion / inference).
        self.style_motion = None

        if load_gaga:
            from app.GAGAvatar import GAGAvatar

            self.GAGAvatar = GAGAvatar().to(device)
            self.GAGAvatar_flame = FLAMEModel(
                n_shape=300, n_exp=100, scale=5.0, no_lmks=True
            ).to(device)
            # Disable GAGAvatar's baked-in watermark (MIT license allows modification).
            self.GAGAvatar.add_water_mark = lambda image: image

    def _require_gaga(self):
        """Raise RuntimeError if GAGAvatar was not loaded (``load_gaga=False``)."""
        if not hasattr(self, "GAGAvatar"):
            raise RuntimeError("GAGAvatar not loaded (load_gaga=False)")

    def add_custom_avatars(self, tracked_pt_path):
        """Merge a user-supplied tracked.pt into ``GAGAvatar.all_gagavatar_id``.

        Accepted formats:
            1. dict-of-dicts: ``{"alice": {image, shapecode, transform_matrix, ...}}``
               → all entries are merged as-is.
            2. single-entry dict: ``{image, shapecode, transform_matrix, ...}``
               → registered under a key derived from the file name (stem with
               any ``tracked_`` prefix removed; ``"custom"`` if that is empty).

        Existing avatar ids with the same name are overwritten.

        Returns:
            List of newly registered avatar ids.

        Raises:
            RuntimeError: if GAGAvatar was not loaded.
            ValueError: if the file does not match either format or an entry
                is missing a required key.
        """
        self._require_gaga()

        # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects
        # and can execute code embedded in the file. Only load tracked.pt files
        # from trusted sources; consider weights_only=True if the format allows.
        loaded = torch.load(tracked_pt_path, map_location="cpu", weights_only=False)
        if not isinstance(loaded, dict):
            raise ValueError(f"tracked.pt must be a dict, got {type(loaded).__name__}")

        required = {"image", "shapecode", "transform_matrix"}

        # Detect format: dict-of-dicts vs single entry.
        values_are_dicts = bool(loaded) and all(isinstance(v, dict) for v in loaded.values())
        looks_like_single = required.issubset(loaded.keys())

        if looks_like_single and not values_are_dicts:
            stem = Path(tracked_pt_path).stem
            if stem.startswith("tracked_"):
                stem = stem[len("tracked_"):]
            entries = {stem or "custom": loaded}
        elif values_are_dicts:
            entries = loaded
        else:
            raise ValueError(
                "tracked.pt must be either a single-avatar dict with keys "
                "{image, shapecode, transform_matrix} or a dict-of-dicts mapping "
                "avatar_id -> {image, shapecode, transform_matrix, ...}"
            )

        added = []
        for avatar_id, entry in entries.items():
            missing = required - set(entry.keys())
            if missing:
                raise ValueError(f"avatar '{avatar_id}' missing keys: {missing}")
            self.GAGAvatar.all_gagavatar_id[avatar_id] = entry
            added.append(avatar_id)
        return added

    def set_style_motion(self, style_motion):
        """Set (or clear) the style-motion clip used to condition inference.

        Args:
            style_motion: a style name (loaded from
                ``assets/style_motion/<name>.pt``), a tensor of shape
                ``(50, 106)``, or ``None`` to clear style conditioning.

        Raises:
            ValueError: if the tensor does not have shape ``(50, 106)``.
        """
        if style_motion is None:
            # Back to the unconditioned state set in __init__.
            self.style_motion = None
            return
        if isinstance(style_motion, str):
            style_motion = torch.load(
                f"assets/style_motion/{style_motion}.pt",
                map_location="cpu",
                weights_only=True,
            )
        if tuple(style_motion.shape) != (50, 106):
            raise ValueError(f"Invalid style_motion shape: {style_motion.shape}")
        # Add a leading batch dimension.
        self.style_motion = style_motion[None].to(self.device)

    def inference(self, audio, clip_length=None):
        """Predict motion codes for ``audio``.

        Args:
            audio: audio tensor without a batch dimension (one is added here);
                presumably a 16 kHz waveform, matching ``rendering`` — confirm.
            clip_length: maximum number of frames to return; defaults to the
                value given at construction.

        Returns:
            Smoothed motion tensor, truncated to ``clip_length`` frames.
        """
        audio_batch = {
            "audio": audio[None].to(self.device),
            "style_motion": self.style_motion,
        }
        pred_motions = self.ARTalk.inference(audio_batch, with_gtmotion=False)[0]

        clip_length = clip_length if clip_length is not None else self.clip_length
        pred_motions = self.smooth_motion_savgol(pred_motions)[:clip_length]

        if self.fix_pose:
            # Channel 103 is intentionally left untouched (as in the original).
            pred_motions[..., 100:103] *= 0.0
            pred_motions[..., 104:] *= 0.0
        return pred_motions

    def rendering(self, audio, pred_motions, shape_id="mesh", shape_code=None, save_name="ARTAvatar.mp4"):
        """Render ``pred_motions`` to an mp4 with ``audio`` muxed in.

        Args:
            audio: waveform written at 16 kHz; it is trimmed to the video
                duration (frames at 25 fps).
            pred_motions: per-frame motion codes, e.g. from ``inference``.
            shape_id: ``"mesh"`` for a plain FLAME mesh render, otherwise a
                GAGAvatar avatar id.
            shape_code: optional ``(1, 300)`` FLAME shape code for mesh
                rendering; zeros when omitted.
            save_name: output file name inside ``self.output_dir``; ``.mp4``
                is appended only if not already present.

        Returns:
            Path of the written video.

        Raises:
            ValueError: if ``shape_code`` has the wrong shape.
            RuntimeError: if a GAGAvatar id is requested but GAGAvatar is not loaded.
        """
        pred_images = []
        num_frames = pred_motions.shape[0]

        if shape_id == "mesh":
            if shape_code is None:
                shape_code = audio.new_zeros(1, 300).to(self.device).expand(num_frames, -1)
            else:
                if shape_code.dim() != 2 or shape_code.shape[0] != 1:
                    raise ValueError(
                        f"shape_code must have shape (1, N), got {tuple(shape_code.shape)}"
                    )
                shape_code = shape_code.to(self.device).expand(num_frames, -1)
            verts = self.ARTalk.basic_vae.get_flame_verts(
                self.flame_model, shape_code, pred_motions, with_global=True
            )
            for v in verts:
                rgb = self.mesh_renderer(v[None])[0]
                # Mesh renderer output is rescaled by 1/255 here.
                pred_images.append(rgb.cpu()[0] / 255.0)
        else:
            self._require_gaga()
            self.GAGAvatar.set_avatar_id(shape_id)
            for motion in pred_motions:
                batch = self.GAGAvatar.build_forward_batch(
                    motion[None], self.GAGAvatar_flame
                )
                rgb = self.GAGAvatar.forward_expression(batch)
                pred_images.append(rgb.cpu()[0])

        pred_images = torch.stack(pred_images)

        # Avoid "ARTAvatar.mp4.mp4" for the default (or any *.mp4) save_name.
        file_name = save_name if save_name.endswith(".mp4") else f"{save_name}.mp4"
        dump_path = os.path.join(self.output_dir, file_name)

        # Trim audio to the video length: frames / 25 fps * 16000 Hz.
        audio_slice = audio[: int(pred_images.shape[0] / 25.0 * 16000)]
        write_video(pred_images * 255.0, dump_path, 25.0, audio_slice, 16000, "aac")
        return dump_path

    @staticmethod
    def smooth_motion_savgol(motion_codes):
        """Temporally smooth motion codes with Savitzky-Golay filters (time axis 0).

        All channels use window 5 / order 2; channels ``100:103`` are then
        re-smoothed from the raw signal with window 9 / order 3. scipy's
        default ``mode="interp"`` requires ``window_length <= len(x)``, so a
        window longer than the clip is skipped instead of raising: clips
        shorter than 9 frames keep the window-5 result for ``100:103``, and
        clips shorter than 5 frames are returned unsmoothed.

        Returns:
            Tensor with the same dtype/device as ``motion_codes``.
        """
        from scipy.signal import savgol_filter

        motion_np = motion_codes.clone().detach().cpu().numpy()
        num_frames = motion_np.shape[0]

        motion_np_smoothed = motion_np.copy()
        if num_frames >= 5:
            motion_np_smoothed = savgol_filter(
                motion_np, window_length=5, polyorder=2, axis=0
            )
        if num_frames >= 9:
            motion_np_smoothed[..., 100:103] = savgol_filter(
                motion_np[..., 100:103], window_length=9, polyorder=3, axis=0
            )
        return torch.tensor(motion_np_smoothed).type_as(motion_codes)