| | |
| |
|
| | """ |
| | utility functions and classes to handle feature extraction and model loading |
| | """ |
| |
|
| | import os |
| | import os.path as osp |
| | import cv2 |
| | import torch |
| | import yaml |
| | import argparse |
| | import locale |
| | import numpy as np |
| | from PIL import Image |
| | from rich.console import Console |
| | from collections import OrderedDict |
| |
|
| | from ..live_portrait.spade_generator import SPADEDecoder |
| | from ..live_portrait.warping_network import WarpingNetwork |
| | from ..live_portrait.motion_extractor import MotionExtractor |
| | from ..live_portrait.appearance_feature_extractor import AppearanceFeatureExtractor |
| | from ..live_portrait.stitching_retargeting_network import StitchingRetargetingNetwork |
| | from .rprint import rlog as log |
| |
|
| |
|
def suffix(filename):
    """Return the extension of *filename* without the leading dot.

    Only the last path component is inspected, so a dot that belongs to a
    directory name is never mistaken for an extension. Returns ``""`` when the
    file name has no dot.

    Examples: ``a.jpg -> jpg``, ``dir/a.tar.gz -> gz``, ``my.dir/noext -> ""``.
    """
    name = osp.basename(filename)  # ignore dots in directory components
    pos = name.rfind(".")
    if pos == -1:
        return ""
    return name[pos + 1:]
| |
|
| |
|
def prefix(filename):
    """Return *filename* with its final extension (last ``.xxx``) removed.

    Directory components are kept, and only a dot inside the last path
    component counts as an extension separator. A path whose last component
    has no dot is returned unchanged.

    Examples: ``a.jpg -> a``, ``dir/a.tar.gz -> dir/a.tar``,
    ``my.dir/noext -> my.dir/noext``.
    """
    # osp.basename returns a suffix of the path, so its start offset bounds
    # the search to the last component.
    name_start = len(filename) - len(osp.basename(filename))
    pos = filename.rfind(".", name_start)
    if pos == -1:
        return filename
    return filename[:pos]
| |
|
| |
|
def basename(filename):
    """Return the bare file name of *filename*: no directories, no extension.

    Example: ``a/b/c.jpg -> c``.
    """
    leaf = osp.basename(filename)  # drop directory components first
    return prefix(leaf)
| |
|
| |
|
def is_video(file_path):
    """Return True if *file_path* should be treated as a video source.

    That is the case when the path ends with a known video extension
    (case-insensitive) or when it is an existing directory (presumably a
    folder of frames -- confirm with callers).
    """
    has_video_ext = file_path.lower().endswith((".mp4", ".mov", ".avi", ".webm"))
    # The extension test runs first; the filesystem is only queried when it fails.
    return has_video_ext or osp.isdir(file_path)
| |
|
def is_template(file_path):
    """Return True if *file_path* ends with ``.pkl`` (case-sensitive match)."""
    return file_path.endswith(".pkl")
| |
|
| |
|
def mkdir(d, log=False):
    """Create directory *d* (with any missing parents) if the path does not exist.

    Args:
        d: directory path to create.
        log: when True, print a message if the directory was actually created.
            (Note: this parameter shadows the module-level ``log`` import
            inside this function.)

    Returns:
        *d*, unchanged. If *d* already exists (as a directory or any other
        kind of path), nothing is created.
    """
    if osp.exists(d):
        return d
    os.makedirs(d, exist_ok=True)
    if log:
        print(f"Make dir: {d}")
    return d
| |
|
| |
|
def squeeze_tensor_to_numpy(tensor):
    """Convert *tensor* to a NumPy array on the CPU, dropping dim 0 if it has size 1.

    ``.data`` is used so the result is detached from any autograd graph.
    """
    squeezed = tensor.data.squeeze(0)
    return squeezed.cpu().numpy()
| |
|
| |
|
def dct2cuda(dct: dict, device_id: int):
    """Convert every value of *dct* to a tensor on CUDA device *device_id*.

    The dict is updated in place (keys unchanged) and the same object is returned.
    """
    for name, value in list(dct.items()):
        dct[name] = torch.tensor(value).cuda(device_id)
    return dct
| |
|
| |
|
def concat_feat(kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
    """Flatten two keypoint batches per sample and concatenate them along dim 1.

    Args:
        kp_source: (bs, k, 3) keypoints.
        kp_driving: (bs, k, 3) keypoints; must have the same batch size.

    Returns:
        (bs, 2*k*3) tensor: the flattened source keypoints followed by the
        flattened driving keypoints.

    Raises:
        AssertionError: if the batch sizes differ (only while assertions are
            enabled; under ``python -O`` ``torch.cat`` raises instead).
    """
    bs_src = kp_source.shape[0]
    bs_dri = kp_driving.shape[0]
    assert bs_src == bs_dri, 'batch size must be equal'

    # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, also work.
    feat = torch.cat([kp_source.reshape(bs_src, -1), kp_driving.reshape(bs_dri, -1)], dim=1)
    return feat
| |
|
| |
|
| | |
def calculate_transformation(config, s_kp_info, t_0_kp_info, t_i_kp_info, R_s, R_t_0, R_t_i):
    """Combine source and driving keypoint info into a new pose.

    Args:
        config: object with a boolean-like ``relative`` attribute.
        s_kp_info, t_0_kp_info, t_i_kp_info: dicts with tensors under
            ``'exp'``, ``'t'`` and ``'scale'`` for the source, the first driving
            frame and the current driving frame.
        R_s, R_t_0, R_t_i: batched rotation matrices (rank 3; ``R_t_0`` is
            transposed via ``permute(0, 2, 1)``).

    Returns:
        Tuple ``(rotation, expression, translation, scale)``:

        - Relative mode: rotation is ``R_t_i @ R_t_0^T @ R_s`` and expression
          is the source expression plus the driving expression delta.
        - Absolute mode: rotation is ``R_t_i`` and expression is
          ``t_i_kp_info['exp']`` itself (not a copy).
        - In both modes, translation and scale are relative to the first
          driving frame. Index 2 of translation's last axis (presumably z)
          is zeroed.
    """
    if not config.relative:
        new_rotation = R_t_i
        new_expression = t_i_kp_info['exp']
    else:
        delta_rotation = R_t_i @ R_t_0.permute(0, 2, 1)
        new_rotation = delta_rotation @ R_s
        exp_delta = t_i_kp_info['exp'] - t_0_kp_info['exp']
        new_expression = s_kp_info['exp'] + exp_delta

    t_delta = t_i_kp_info['t'] - t_0_kp_info['t']
    new_translation = s_kp_info['t'] + t_delta  # fresh tensor, so fill_ below leaves inputs untouched
    new_translation[..., 2].fill_(0)

    scale_ratio = t_i_kp_info['scale'] / t_0_kp_info['scale']
    new_scale = s_kp_info['scale'] * scale_ratio
    return new_rotation, new_expression, new_translation, new_scale
| |
|
def load_description(fp):
    """Return the full contents of the UTF-8 text file at *fp*."""
    with open(fp, 'r', encoding='utf-8') as handle:
        return handle.read()
| |
|
| |
|
def resize_to_limit(img, max_dim=1280, n=2):
    """Downscale *img* so its longer side is at most *max_dim*, then crop to a multiple of *n*.

    Args:
        img: image array whose first two axes are height and width.
        max_dim: upper bound for the longer side; ``<= 0`` disables resizing.
            The aspect ratio is preserved (the shorter side is floored, min 1).
        n: both output sides are cropped down to a multiple of this value
            (values below 1 are treated as 1).

    Returns:
        The possibly resized and cropped image. If cropping would leave a
        zero-sized side, the (resized) image is returned uncropped.
    """
    h, w = img.shape[:2]
    if max_dim > 0 and max(h, w) > max_dim:
        # Integer floor division avoids float round-off (e.g. 1279.9999 -> 1279).
        # max(1, ...) keeps very thin images from producing a zero-sized target,
        # which cv2.resize rejects.
        if h > w:
            new_h = int(max_dim)
            new_w = max(1, int(w * max_dim // h))
        else:
            new_w = int(max_dim)
            new_h = max(1, int(h * max_dim // w))
        img = cv2.resize(img, (new_w, new_h))  # cv2 takes dsize as (width, height)

    n = max(n, 1)
    h, w = img.shape[:2]
    new_h = h - h % n
    new_w = w - w % n
    if new_h == 0 or new_w == 0:
        return img
    if (new_h, new_w) != (h, w):
        img = img[:new_h, :new_w]
    return img
| |
|
| |
|
def load_yaml(file_path):
    """Parse the YAML file at *file_path* with ``yaml.safe_load`` and return the result.

    The file is decoded as UTF-8 (the standard YAML encoding; a leading BOM is
    accepted). If it is not valid UTF-8, decoding falls back to the locale's
    preferred encoding, so files saved in a legacy local code page still load.
    """
    with open(file_path, 'rb') as file:
        raw = file.read()
    try:
        text = raw.decode('utf-8-sig')
    except UnicodeDecodeError:
        text = raw.decode(locale.getpreferredencoding(False))
    return yaml.safe_load(text)
| |
|
| |
|
def str2bool(v):
    """Parse a command-line style boolean (e.g. for ``argparse`` ``type=``).

    Bools pass through unchanged. Strings are matched case-insensitively
    against yes/true/t/y/1 and no/false/f/n/0.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
| |
|