| |
|
|
| """ |
| config dataclass used for inference |
| """ |
|
|
import os.path as osp
from dataclasses import dataclass, field
from typing import Literal, Tuple

import cv2
from numpy import ndarray

from .base_config import PrintableConfig, make_abs_path
|
|
|
|
@dataclass(repr=False)
class InferenceConfig(PrintableConfig):
    """Configuration values used for inference.

    Every field has a default, so ``InferenceConfig()`` can be built without
    arguments. All paths are produced by ``make_abs_path`` from the relative
    locations written below.

    ``mask_crop`` is populated through a ``default_factory``: an ``ndarray``
    is unhashable, and ``dataclasses`` refuses unhashable objects as plain
    field defaults on Python 3.11+. Using a factory also means the template
    image is read when an instance is created rather than at import time,
    and that instances do not share one mutable array.
    """

    # --- model definition file and checkpoint locations ---
    models_config: str = make_abs_path('./models.yaml')
    checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth')
    checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth')
    checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth')
    checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth')
    checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth')

    # --- boolean switches and device selection ---
    # NOTE(review): the effect of each flag is implemented by the consumers of
    # this config, not here; the names are the only description available.
    flag_use_half_precision: bool = True
    flag_crop_driving_video: bool = False
    device_id: int = 0
    flag_lip_zero: bool = False
    flag_eye_retargeting: bool = False
    flag_lip_retargeting: bool = False
    flag_stitching: bool = False
    flag_relative_motion: bool = False
    flag_pasteback: bool = False
    flag_do_crop: bool = False
    flag_do_rot: bool = False
    flag_force_cpu: bool = False
    flag_do_torch_compile: bool = False

    # --- numeric tuning values ---
    # presumably the threshold paired with flag_lip_zero; verify against caller
    lip_zero_threshold: float = 0.03
    anchor_frame: int = 0

    # --- input / output settings ---
    input_shape: Tuple[int, int] = (256, 256)
    output_format: Literal['mp4', 'gif'] = 'mp4'
    # presumably the encoder's constant rate factor; TODO confirm
    crf: int = 15
    output_fps: int = 25

    # Template mask image, read in colour mode once per instance.
    # cv2.imread returns None instead of raising when the file cannot be read,
    # so a missing template shows up as ``mask_crop is None``.
    mask_crop: ndarray = field(
        default_factory=lambda: cv2.imread(
            make_abs_path('../utils/resources/mask_template.png'),
            cv2.IMREAD_COLOR,
        )
    )
    size_gif: int = 256
    source_max_dim: int = 1280
    source_division: int = 2
|
|