Spaces:
Runtime error
Runtime error
| import os | |
| import pickle | |
| import numpy as np | |
def load_pkl(pkl):
    """Unpickle and return the object stored in the file at path ``pkl``.

    The file is opened in binary mode and closed before returning.
    Note: unpickling can execute arbitrary code, so only pass files that
    come from a trusted source.
    """
    with open(pkl, mode="rb") as handle:
        obj = pickle.load(handle)
    return obj
def _resolve_path(p, data_root):
    """Return ``p`` if it names an existing file, else ``p`` joined onto ``data_root``."""
    if os.path.isfile(p):
        return p
    return os.path.join(data_root, p)


def _apply_overrides(cfg, replace_cfg):
    """Overwrite ``cfg[section][key]`` in place from a ``{section: {key: value}}`` dict.

    Does nothing when ``replace_cfg`` is not a dict. Override entries whose value
    is not a dict are skipped. A section name absent from ``cfg`` raises KeyError.
    """
    if not isinstance(replace_cfg, dict):
        return
    for section, overrides in replace_cfg.items():
        if not isinstance(overrides, dict):
            continue
        for key, value in overrides.items():
            cfg[section][key] = value


def _resolve_model_paths(base_cfg, data_root):
    """Resolve, in place, every model path in each section of ``base_cfg``.

    ``landmark478_cfg`` has several path keys; every other section has a single
    ``model_path``. Keys that are missing or empty are left untouched, so an
    optional or unused section (e.g. the wav2feat backend that is not selected)
    does not abort parsing.
    """
    for name, section in base_cfg.items():
        if name == "landmark478_cfg":
            keys = ("task_path", "blaze_face_model_path", "face_mesh_model_path")
        else:
            keys = ("model_path",)
        for key in keys:
            if section.get(key):
                section[key] = _resolve_path(section[key], data_root)


def parse_cfg(cfg_pkl, data_root, replace_cfg=None):
    """Load the pipeline config and split it into per-component configs.

    Args:
        cfg_pkl: Path to a pickled config dict, or an already-loaded config dict.
            A dict is deep-copied first, so the caller's object is not mutated.
            The config must contain ``base_cfg``, ``audio2motion_cfg`` and
            ``default_kwargs``.
        data_root: Directory that model paths are resolved against. A path that
            already names an existing file is kept as-is.
        replace_cfg: Optional ``{section: {key: value}}`` overrides, applied
            before path resolution (intended for debugging).

    Returns:
        A list of eight entries, in this order: avatar_registrar_cfg,
        condition_handler_cfg, lmdm_cfg, stitch_network_cfg, warp_network_cfg,
        decoder_cfg, wav2feat_cfg, default_kwargs.

    Raises:
        KeyError: If a required section or key is missing from the config.
    """
    if isinstance(cfg_pkl, dict):
        import copy

        cfg = copy.deepcopy(cfg_pkl)
    else:
        cfg = load_pkl(cfg_pkl)

    _apply_overrides(cfg, replace_cfg)

    base_cfg = cfg["base_cfg"]
    audio2motion_cfg = cfg["audio2motion_cfg"]
    default_kwargs = cfg["default_kwargs"]

    _resolve_model_paths(base_cfg, data_root)
    # The audio-to-motion model path is required, so it is resolved unconditionally.
    audio2motion_cfg["model_path"] = _resolve_path(audio2motion_cfg["model_path"], data_root)

    avatar_registrar_cfg = {
        k: base_cfg[k]
        for k in (
            "insightface_det_cfg",
            "landmark106_cfg",
            "landmark203_cfg",
            "landmark478_cfg",
            "appearance_extractor_cfg",
            "motion_extractor_cfg",
        )
    }
    stitch_network_cfg = base_cfg["stitch_network_cfg"]
    warp_network_cfg = base_cfg["warp_network_cfg"]
    decoder_cfg = base_cfg["decoder_cfg"]

    condition_handler_cfg = {
        k: audio2motion_cfg[k]
        for k in ("use_emo", "use_sc", "use_eye_open", "use_eye_ball", "seq_frames")
    }
    lmdm_cfg = {
        k: audio2motion_cfg[k]
        for k in ("model_path", "device", "motion_feat_dim", "audio_feat_dim", "seq_frames")
    }

    # "hubert" selects hubert_cfg; any other value falls back to wavlm_cfg.
    w2f_type = audio2motion_cfg["w2f_type"]
    wav2feat_cfg = {
        "w2f_cfg": base_cfg["hubert_cfg"] if w2f_type == "hubert" else base_cfg["wavlm_cfg"],
        "w2f_type": w2f_type,
    }

    return [
        avatar_registrar_cfg,
        condition_handler_cfg,
        lmdm_cfg,
        stitch_network_cfg,
        warp_network_cfg,
        decoder_cfg,
        wav2feat_cfg,
        default_kwargs,
    ]
def print_cfg(**kwargs):
    """Print one summary line per keyword argument: its name, type and a detail.

    - ``ch_info``: name and type only.
    - ``ctrl_info``: name, type and ``len()``.
    - any NumPy array: name, type and shape.
    - anything else: name, type and the value itself.
    """
    for name, value in kwargs.items():
        if name == "ch_info":
            fields = (type(value),)
        elif name == "ctrl_info":
            fields = (type(value), len(value))
        elif isinstance(value, np.ndarray):
            fields = (type(value), value.shape)
        else:
            fields = (type(value), value)
        print(name, *fields)