| | |
| |
|
| | """ |
| | config dataclass used for inference |
| | """ |
| |
|
import os.path as osp
from dataclasses import dataclass, field
from typing import Any, Literal, Tuple

from .base_config import PrintableConfig, make_abs_path
| |
|
| |
|
@dataclass(repr=False)  # no generated __repr__; the one inherited from PrintableConfig is used
class InferenceConfig(PrintableConfig):
    """Settings for running inference.

    Holds checkpoint paths for the base and retargeting models, feature
    switches (half precision, lip zeroing, retargeting, stitching, relative
    motion), output encoding options, paste-back options and device/cropping
    flags. Every field has a default, so ``InferenceConfig()`` is valid and
    individual settings can be overridden by keyword.
    """

    # --- model files ---------------------------------------------------------
    # Model configuration file (YAML); its schema is defined where it is read.
    models_config: str = make_abs_path('./models.yaml')
    # Base-model weights. The letter suffixes match the checkpoint file names:
    # F = appearance feature extractor, M = motion extractor,
    # G = SPADE generator, W = warping module.
    checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth')
    checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth')
    checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth')
    checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth')

    # S = stitching / retargeting module weights.
    checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth')
    # NOTE(review): presumably selects fp16 execution of the networks -- confirm in the model wrapper.
    flag_use_half_precision: bool = True

    # --- lip zeroing -----------------------------------------------------------
    # Switch and threshold for lip zeroing.
    # TODO: document the unit/scale of lip_zero_threshold where it is consumed.
    flag_lip_zero: bool = True
    lip_zero_threshold: float = 0.03

    # --- retargeting / stitching ---------------------------------------------
    flag_eye_retargeting: bool = False
    flag_lip_retargeting: bool = False
    flag_stitching: bool = True

    # --- motion transfer -----------------------------------------------------
    # NOTE(review): presumably, when flag_relative is set, driving motion is
    # applied relative to the driving frame at index anchor_frame -- confirm.
    flag_relative: bool = True
    anchor_frame: int = 0

    # --- network input / video output ----------------------------------------
    # NOTE(review): presumably the (height, width) fed to the networks -- confirm.
    input_shape: Tuple[int, int] = (256, 256)
    output_format: Literal['mp4', 'gif'] = 'mp4'
    output_fps: int = 30
    # NOTE(review): presumably the video encoder's constant rate factor
    # (lower means higher quality) -- confirm in the writer.
    crf: int = 15

    # --- result writing / paste-back -----------------------------------------
    flag_write_result: bool = True
    flag_pasteback: bool = True
    flag_write_gif: bool = False
    size_gif: int = 256
    # NOTE(review): the meaning of ref_max_shape / ref_shape_n is not visible
    # here (possibly a size cap and a divisibility constraint for the reference
    # image) -- confirm before relying on it.
    ref_max_shape: int = 1280
    ref_shape_n: int = 2

    # --- runtime -------------------------------------------------------------
    # NOTE(review): presumably the GPU index to run on -- confirm.
    device_id: int = 0
    flag_do_crop: bool = False
    flag_do_rot: bool = True

    # Mask used for pasting the crop back; its type is not fixed here
    # (presumably an image array). It is annotated so that @dataclass treats
    # it as a real field: settable through __init__ and listed by
    # dataclasses.fields(). Without an annotation it would be a plain class
    # attribute. It is declared last so every other field keeps its positional
    # order. compare=False keeps the generated __eq__ unchanged, so comparing
    # array-valued masks cannot raise an ambiguous-truth-value error.
    mask_crop: Any = field(default=None, compare=False)
| |
|