# DyStream / tools/visualization_0416/img_to_latent.py
# Uploaded by robinwitch, commit 872b1a7 ("upload ckpt").
import sys
import os
# Resolve the project root (two directory levels above this script) and put it at the
# front of sys.path when it is not already present, so that `from utils import ...`
# picks up the project's own `utils` package.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.abspath(os.path.join(_SCRIPT_DIR, "..", ".."))
if not (_PROJECT_ROOT in sys.path):
    sys.path.insert(0, _PROJECT_ROOT)
import numpy as np
import torch
from PIL import Image
import torchvision.transforms as T
from omegaconf import OmegaConf
import fire
def init_fn(config_path):
    """Build the model described by ``config_path`` and return its motion encoder.

    Args:
        config_path: Path to an OmegaConf YAML file. It must define ``model`` (passed to
            ``utils.instantiate``) and ``resume_ckpt`` (checkpoint path whose dict holds
            the weights under ``"state_dict"``).

    Returns:
        dict with two entries:
            ``"transform"``: torchvision pipeline that resizes to 512x512, converts to a
                tensor and normalizes with mean 0.5 / std 0.5.
            ``"motion_encoder"``: the ``motion_encoder`` submodule of the loaded model,
                which has been switched to eval mode.
    """
    # Imported here rather than at module top, presumably so that `utils` resolves
    # against the sys.path entries the caller inserts before calling this function.
    from utils import instantiate

    transform = T.Compose([T.Resize((512, 512)), T.ToTensor(), T.Normalize([0.5], [0.5])])

    config = OmegaConf.load(config_path)
    # NOTE(review): instantiate_module=False appears to return the model class rather
    # than an instance (it is called with config= below) -- confirm against utils.
    module = instantiate(config.model, instantiate_module=False)
    model = module(config=config)

    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if this
    # checkpoint pickles non-tensor objects it may need weights_only=False (only for
    # trusted files).
    checkpoint = torch.load(config.resume_ckpt, map_location="cpu")

    # strict=False tolerates checkpoint/model mismatches, but any motion-encoder weight
    # missing from the checkpoint would silently keep its random initialization, which
    # makes every extracted latent meaningless. Report those instead of hiding them.
    load_result = model.load_state_dict(checkpoint["state_dict"], strict=False)
    missing_keys = getattr(load_result, "missing_keys", None) or []
    missing_encoder_keys = [k for k in missing_keys if k.startswith("motion_encoder.")]
    if missing_encoder_keys:
        print(
            f"[init_fn] warning: {len(missing_encoder_keys)} motion_encoder parameter(s) "
            f"not found in {config.resume_ckpt}, e.g. {missing_encoder_keys[:5]}",
            file=sys.stderr,
        )

    model.eval()
    motion_encoder = model.motion_encoder
    return {"transform": transform, "motion_encoder": motion_encoder}
def _update_latent_npz(save_npz_path, mask_image_path, latent_np):
    """Merge a motion latent and its metadata into ``save_npz_path``, creating it if needed.

    Arrays already stored in the file are kept; the keys ``video_id``, ``mask_img_path``,
    ``ref_img_path`` and ``motion_latent`` are overwritten with the new values.
    """
    # np.savez appends ".npz" when the name lacks it. Mirror that here so the existence
    # check, video_id and ref_img_path all refer to the file that is actually written.
    if not save_npz_path.endswith('.npz'):
        save_npz_path += '.npz'
    stem = save_npz_path[:-len('.npz')]

    if os.path.exists(save_npz_path):
        # Load existing data so unrelated keys survive the rewrite; the context manager
        # closes the NpzFile even if reading one of its arrays fails.
        with np.load(save_npz_path, allow_pickle=True) as existing_data:
            data_dict = dict(existing_data)
    else:
        data_dict = {}

    data_dict.update({
        'video_id': os.path.basename(stem),
        'mask_img_path': mask_image_path,
        # Only the extension is swapped. The old str.replace('npz', 'png') also rewrote
        # any "npz" inside directory names.
        'ref_img_path': stem + '.png',
        'motion_latent': latent_np,
    })
    np.savez(save_npz_path, **data_dict)


def extract_motion_latent(
        mask_image_path='./test_case/test_img_masked.png',
        config_path='./configs/head_animator_best_0506.yaml',
        save_npz_path='./test_case/test_img_resize.npz',
        version="0506"):
    """Encode an image with the model's motion encoder and store the latent in an ``.npz``.

    Args:
        mask_image_path: Image to encode; it is converted to RGB and preprocessed by the
            transform returned from ``init_fn``.
        config_path: Model config path. Every ``"0506"`` substring is replaced by
            ``version`` before loading.
        save_npz_path: ``.npz`` file to create or update (``.npz`` is appended if missing,
            as ``np.savez`` does). Existing keys in the file are preserved.
        version: Model version tag. It selects ``./utils/model_{version}`` (relative to
            the current working directory) and the config file.

    Returns:
        None. The result is written to ``save_npz_path`` under the keys ``video_id``,
        ``mask_img_path``, ``ref_img_path`` and ``motion_latent``.
    """
    # Fire literal-evaluates CLI values, so e.g. `--version 1234` arrives as an int;
    # the str operations below require a string.
    version = str(version)

    # Make the version-specific model code importable. Skip the insert when it is already
    # first so repeated calls do not keep growing sys.path.
    # NOTE(review): modules already imported (e.g. `utils`) stay cached in sys.modules,
    # so switching `version` within one process will not re-import them.
    model_dir = f'./utils/model_{version}'
    if not sys.path or sys.path[0] != model_dir:
        sys.path.insert(0, model_dir)
    config_path = config_path.replace("0506", version)

    context = init_fn(config_path)
    transform = context["transform"]
    motion_encoder = context["motion_encoder"]

    # convert() returns a new, fully loaded image, so the source file can be closed here.
    with Image.open(mask_image_path) as src_img:
        img = src_img.convert("RGB")
    img_tensor = transform(img).unsqueeze(0)
    with torch.no_grad():
        latent = motion_encoder(img_tensor)[0]  # [1, 512]
    latent_np = latent.cpu().numpy()

    _update_latent_npz(save_npz_path, mask_image_path, latent_np)
if __name__ == "__main__":
    # Python Fire turns extract_motion_latent's keyword arguments into command-line flags,
    # e.g. `python img_to_latent.py --mask_image_path=... --version=0416`.
    fire.Fire(extract_motion_latent)