Spaces:
Paused
Paused
| """ARTalk inference engine for HuggingFace ZeroGPU deployment. | |
| 원본 C:/DK/AT/ARTalk/inference.py 의 ARTAvatarInferEngine 을 기반으로, | |
| ZeroGPU 환경에 맞춰 다음을 수정: | |
| 1. 모델 가중치를 HF Hub 에서 자동 다운로드 (build_resources.sh 대체) | |
| 2. FLAME 라이선스 인터랙티브 프롬프트 제거 | |
| 3. 기본 동작은 동일 (inference + rendering) | |
| app.py 에서 이 모듈의 ARTAvatarInferEngine 를 import 해서 사용. | |
| """ | |
| import os | |
| import json | |
| import shutil | |
| import torch | |
| import numpy as np | |
| from pathlib import Path | |
| from huggingface_hub import hf_hub_download | |
| from app import BitwiseARModel | |
| from app.flame_model import FLAMEModel, RenderMesh | |
| from app.utils_videos import write_video | |
# ═══════════════════════════════════════════════════════════
# Model weights — downloaded from the HuggingFace Hub at runtime
# ═══════════════════════════════════════════════════════════
ARTALK_REPO = "xg-chu/ARTalk"
GAGA_REPO = "xg-chu/GAGAvatar"
ASSETS_DIR = Path("assets")

# Required files. Each entry is a triple:
#   (repo_id, path of the file inside that repo, destination relative to ASSETS_DIR)
WEIGHTS_MANIFEST = [
    (GAGA_REPO, "assets/FLAME_with_eye.pt", "FLAME_with_eye.pt"),
    (GAGA_REPO, "assets/GAGAvatar.pt", "GAGAvatar/GAGAvatar.pt"),
    (ARTALK_REPO, "ARTalk_wav2vec.pt", "ARTalk_wav2vec.pt"),
    (ARTALK_REPO, "config.json", "config.json"),
    (ARTALK_REPO, "GAGAvatar/tracked.pt", "GAGAvatar/tracked.pt"),
]

# Style-motion preset names ("<mood>_<index>"), fetched from style_motion/ in the
# ARTalk repo. Each pair below is (mood, number of clips for that mood).
STYLE_MOTIONS = [
    f"{mood}_{index}"
    for mood, clip_count in (
        ("angry", 1),
        ("curious", 1),
        ("doubtful", 2),
        ("happy", 3),
        ("natural", 8),
    )
    for index in range(clip_count)
]
| def _place_file(downloaded: str, target: Path): | |
| """HF cache 에서 다운로드된 파일을 target 경로에 배치. | |
| os.link 는 cross-filesystem 에서 실패 가능 → shutil.copy 사용. | |
| """ | |
| target.parent.mkdir(parents=True, exist_ok=True) | |
| if target.exists() or target.is_symlink(): | |
| target.unlink() | |
| shutil.copy(downloaded, target) | |
| assert target.exists(), f"copy failed: {target}" | |
def download_all_weights():
    """Download every model file ARTalk needs from the HuggingFace Hub.

    Files already present under ``ASSETS_DIR`` as non-empty regular files are
    skipped; ``hf_hub_download`` additionally uses its own local cache.
    A failure on any ``WEIGHTS_MANIFEST`` entry is fatal (the exception is
    logged and re-raised). A missing style-motion clip only logs a warning.

    Per the original note, the FLAME assets are hosted by the xg-chu repos
    with the license already accepted, so no interactive prompt is shown.

    Raises:
        FileNotFoundError: If ``ARTalk_wav2vec.pt`` is still absent at the end.
    """

    def _present(path):
        # A zero-byte file is what a failed/interrupted copy can leave behind,
        # so it is treated as missing for both main weights and style clips.
        return path.is_file() and path.stat().st_size > 0

    ASSETS_DIR.mkdir(exist_ok=True)
    (ASSETS_DIR / "GAGAvatar").mkdir(exist_ok=True)
    (ASSETS_DIR / "style_motion").mkdir(exist_ok=True)

    # Main weights: every entry is required.
    for repo_id, repo_file, local_rel in WEIGHTS_MANIFEST:
        local_path = ASSETS_DIR / local_rel
        if _present(local_path):
            print(f"[weights] exists: {local_path}")
            continue
        print(f"[weights] downloading {repo_id}:{repo_file} -> {local_path}", flush=True)
        try:
            downloaded = hf_hub_download(repo_id=repo_id, filename=repo_file)
            _place_file(downloaded, local_path)
        except Exception as e:
            print(f"[weights] FAIL {repo_id}:{repo_file} -> {e}", flush=True)
            raise
        size_mb = local_path.stat().st_size / 1e6
        print(f"[weights] OK: {local_path} ({size_mb:.1f}MB)", flush=True)

    # Style-motion clips: best effort; a missing clip only makes that preset
    # name unavailable for loading later.
    for style_name in STYLE_MOTIONS:
        local_path = ASSETS_DIR / "style_motion" / f"{style_name}.pt"
        if _present(local_path):
            continue
        try:
            downloaded = hf_hub_download(
                repo_id=ARTALK_REPO,
                filename=f"style_motion/{style_name}.pt",
            )
            _place_file(downloaded, local_path)
        except Exception as e:
            print(f"[weights] warning: style {style_name} unavailable ({e})", flush=True)

    # Final sanity check on the main ARTalk checkpoint.
    total = sum(1 for _ in ASSETS_DIR.rglob('*.pt'))
    main = ASSETS_DIR / "ARTalk_wav2vec.pt"
    main_ok = main.exists()
    main_gb = main.stat().st_size / 1e9 if main_ok else 0
    print(f"[weights] ready - {total} .pt files; main={main_ok} size={main_gb:.2f}GB", flush=True)
    if not main_ok:
        raise FileNotFoundError(f"Main weight not placed: {main}")
# ═══════════════════════════════════════════════════════════
# ARTAvatarInferEngine — same structure as the original inference.py
# ═══════════════════════════════════════════════════════════
class ARTAvatarInferEngine:
    """ARTalk inference + rendering engine.

    A copy of the ``ARTAvatarInferEngine`` class from the original ARTalk
    ``inference.py``, except that construction first calls
    ``download_all_weights()`` so the ``./assets`` files it loads exist.
    """

    def __init__(self, load_gaga=True, fix_pose=False, clip_length=750, device="cuda"):
        """Download missing weights and build the ARTalk model and renderers.

        Args:
            load_gaga: Also load GAGAvatar. Required for rendering with any
                ``shape_id`` other than ``"mesh"`` and for ``add_custom_avatars``.
            fix_pose: If True, ``inference`` zeroes motion channels 100:103
                and 104: (presumably the pose part of the motion code — TODO confirm).
            clip_length: Default maximum number of frames kept by ``inference``.
            device: Torch device the models are moved to.
        """
        # Fetch any weights that are not on disk yet.
        download_all_weights()
        self.device = device
        self.fix_pose = fix_pose
        self.clip_length = clip_length

        audio_encoder = "wav2vec"
        ckpt = torch.load(
            f"./assets/ARTalk_{audio_encoder}.pt",
            map_location="cpu",
            weights_only=True,
        )
        # Explicit UTF-8 so the JSON config does not depend on the OS locale.
        with open("./assets/config.json", encoding="utf-8") as f:
            configs = json.load(f)
        configs["AR_CONFIG"]["AUDIO_ENCODER"] = audio_encoder
        self.ARTalk = BitwiseARModel(configs).eval().to(device)
        self.ARTalk.load_state_dict(ckpt, strict=True)

        self.flame_model = FLAMEModel(
            n_shape=300, n_exp=100, scale=1.0, no_lmks=True
        ).to(device)
        self.mesh_renderer = RenderMesh(
            image_size=512, faces=self.flame_model.get_faces(), scale=1.0
        )
        self.output_dir = f"render_results/ARTAvatar_{audio_encoder}"
        os.makedirs(self.output_dir, exist_ok=True)

        # No style prompt until set_style_motion() is called.
        self.style_motion = None
        if load_gaga:
            from app.GAGAvatar import GAGAvatar

            self.GAGAvatar = GAGAvatar().to(device)
            self.GAGAvatar_flame = FLAMEModel(
                n_shape=300, n_exp=100, scale=5.0, no_lmks=True
            ).to(device)
            # Disable GAGAvatar's baked-in watermark (MIT license allows modification).
            self.GAGAvatar.add_water_mark = lambda image: image

    def add_custom_avatars(self, tracked_pt_path):
        """Merge a user-supplied tracked.pt into ``GAGAvatar.all_gagavatar_id``.

        Accepted formats:
          1. dict-of-dicts: ``{"alice": {image, shapecode, transform_matrix, ...}}``
             -> every entry is merged under its own key.
          2. single-entry dict: ``{image, shapecode, transform_matrix, ...}``
             -> registered under the file stem with a leading ``tracked_``
             removed (``"custom"`` if nothing is left).

        All entries are validated before any is registered, so a file with a
        malformed entry leaves the avatar table unchanged.

        Returns:
            List of the newly registered avatar ids, in file order.

        Raises:
            RuntimeError: GAGAvatar was not loaded (``load_gaga=False``).
            ValueError: The file matches neither format, or an entry lacks
                one of the required keys.
        """
        if not hasattr(self, "GAGAvatar"):
            raise RuntimeError("GAGAvatar not loaded (load_gaga=False)")
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects,
        # so a crafted tracked.pt can execute code on load. If this path comes
        # from a user upload, switch to weights_only=True once the expected file
        # contents are confirmed to be tensors/primitives only.
        loaded = torch.load(tracked_pt_path, map_location="cpu", weights_only=False)
        if not isinstance(loaded, dict):
            raise ValueError(f"tracked.pt must be a dict, got {type(loaded).__name__}")
        required = {"image", "shapecode", "transform_matrix"}

        # Detect format: dict-of-dicts vs single entry.
        values_are_dicts = all(isinstance(v, dict) for v in loaded.values()) if loaded else False
        looks_like_single = required.issubset(loaded.keys())
        if looks_like_single and not values_are_dicts:
            stem = Path(tracked_pt_path).stem
            if stem.startswith("tracked_"):
                stem = stem[len("tracked_"):]
            avatar_id = stem or "custom"
            entries = {avatar_id: loaded}
        elif values_are_dicts:
            entries = loaded
        else:
            raise ValueError(
                "tracked.pt must be either a single-avatar dict with keys "
                "{image, shapecode, transform_matrix} or a dict-of-dicts mapping "
                "avatar_id -> {image, shapecode, transform_matrix, ...}"
            )

        # Validate everything first so a bad entry cannot cause a partial merge.
        for avatar_id, entry in entries.items():
            missing = required - set(entry.keys())
            if missing:
                raise ValueError(f"avatar '{avatar_id}' missing keys: {missing}")
        self.GAGAvatar.all_gagavatar_id.update(entries)
        return list(entries)

    def set_style_motion(self, style_motion):
        """Set the style-motion prompt used by subsequent ``inference`` calls.

        Args:
            style_motion: Either a preset name, loaded from
                ``assets/style_motion/<name>.pt``, or a tensor of shape (50, 106).
        """
        if isinstance(style_motion, str):
            style_motion = torch.load(
                f"assets/style_motion/{style_motion}.pt",
                map_location="cpu",
                weights_only=True,
            )
        assert style_motion.shape == (50, 106), (
            f"Invalid style_motion shape: {style_motion.shape}"
        )
        # Add a batch dimension and move to the model device.
        self.style_motion = style_motion[None].to(self.device)

    def inference(self, audio, clip_length=None):
        """Predict a smoothed motion sequence for ``audio``.

        Args:
            audio: Audio tensor without a batch dimension (one is added here).
                ``rendering`` slices it assuming 16 kHz — TODO confirm against
                the model config.
            clip_length: Maximum number of frames to keep; defaults to
                ``self.clip_length``.

        Returns:
            Motion tensor (one row per frame), Savitzky-Golay smoothed and
            truncated to ``clip_length`` frames. Channels 100:103 and 104: are
            zeroed when ``fix_pose`` is set.
        """
        audio_batch = {
            "audio": audio[None].to(self.device),
            "style_motion": self.style_motion,
        }
        pred_motions = self.ARTalk.inference(audio_batch, with_gtmotion=False)[0]
        if clip_length is None:
            clip_length = self.clip_length
        pred_motions = self.smooth_motion_savgol(pred_motions)[:clip_length]
        if self.fix_pose:
            pred_motions[..., 100:103] *= 0.0
            pred_motions[..., 104:] *= 0.0
        return pred_motions

    def rendering(self, audio, pred_motions, shape_id="mesh", shape_code=None, save_name="ARTAvatar.mp4"):
        """Render ``pred_motions`` to a 25 fps .mp4 muxed with ``audio``.

        Args:
            audio: Audio tensor; trimmed (assuming 16 kHz) to the video length.
            pred_motions: Motion tensor from ``inference`` (one row per frame).
            shape_id: ``"mesh"`` renders a FLAME mesh; any other value is used
                as a GAGAvatar avatar id.
            shape_code: Optional (1, 300) FLAME shape code for mesh rendering;
                zeros are used when None.
            save_name: Output file name under ``output_dir``; a trailing
                ``.mp4`` is optional.

        Returns:
            Path of the written video file.
        """
        pred_images = []
        if shape_id == "mesh":
            if shape_code is None:
                shape_code = audio.new_zeros(1, 300).to(self.device).expand(
                    pred_motions.shape[0], -1
                )
            else:
                assert shape_code.dim() == 2
                assert shape_code.shape[0] == 1
                shape_code = shape_code.to(self.device).expand(
                    pred_motions.shape[0], -1
                )
            verts = self.ARTalk.basic_vae.get_flame_verts(
                self.flame_model, shape_code, pred_motions, with_global=True
            )
            for v in verts:
                rgb = self.mesh_renderer(v[None])[0]
                # Mesh renderer output is divided by 255 here to match the
                # [0, 1] scale that is multiplied back by 255 before writing.
                pred_images.append(rgb.cpu()[0] / 255.0)
        else:
            self.GAGAvatar.set_avatar_id(shape_id)
            for motion in pred_motions:
                batch = self.GAGAvatar.build_forward_batch(
                    motion[None], self.GAGAvatar_flame
                )
                rgb = self.GAGAvatar.forward_expression(batch)
                pred_images.append(rgb.cpu()[0])
        pred_images = torch.stack(pred_images)

        dump_path = self._video_path(save_name)
        # 25 fps video; keep only the audio samples that overlap the frames.
        audio_slice = audio[: int(pred_images.shape[0] / 25.0 * 16000)]
        write_video(pred_images * 255.0, dump_path, 25.0, audio_slice, 16000, "aac")
        return dump_path

    def _video_path(self, save_name):
        """Return ``output_dir/<save_name>.mp4`` without doubling an existing ``.mp4``.

        The ``rendering`` default ``"ARTAvatar.mp4"`` used to produce
        ``ARTAvatar.mp4.mp4``; names without the suffix are unchanged.
        """
        suffix = ".mp4"
        stem = save_name[: -len(suffix)] if save_name.lower().endswith(suffix) else save_name
        return os.path.join(self.output_dir, f"{stem}{suffix}")

    @staticmethod
    def smooth_motion_savgol(motion_codes):
        """Temporally smooth a motion sequence with Savitzky-Golay filters.

        All channels use a window of 5 frames (poly order 2); channels 100:103
        are re-filtered from the raw input with a window of 9 (poly order 3).
        Declared ``@staticmethod`` because it takes no ``self`` yet is called
        as ``self.smooth_motion_savgol(...)``.

        Sequences shorter than a window are filtered with the largest odd
        window that fits; if that window is not larger than the poly order,
        those channels are returned unsmoothed instead of raising.

        Args:
            motion_codes: Tensor with frames along dim 0.

        Returns:
            Tensor with the same shape, dtype and device as ``motion_codes``.
        """
        from scipy.signal import savgol_filter

        def _filter(x, window, order):
            n_frames = x.shape[0]
            if n_frames < window:
                # savgol_filter (mode="interp") requires window <= length; use
                # the largest odd window that fits.
                window = n_frames if n_frames % 2 == 1 else n_frames - 1
            if window <= order:
                return x.copy()
            return savgol_filter(x, window_length=window, polyorder=order, axis=0)

        motion_np = motion_codes.clone().detach().cpu().numpy()
        motion_np_smoothed = _filter(motion_np, 5, 2)
        motion_np_smoothed[..., 100:103] = _filter(motion_np[..., 100:103], 9, 3)
        return torch.tensor(motion_np_smoothed).type_as(motion_codes)