| |
|
|
| """ |
| Benchmark the inference speed of each module in LivePortrait. |
| |
| TODO: heavy GPT style, need to refactor |
| """ |
|
|
| import yaml |
| import torch |
| import time |
| import numpy as np |
| from src.utils.helper import load_model, concat_feat |
| from src.config.inference_config import InferenceConfig |
|
|
|
|
def initialize_inputs(batch_size=1, *, device='cuda', dtype=torch.float16):
    """
    Generate random input tensors for benchmarking each module.

    Tensors are sampled directly on ``device`` with ``dtype`` rather than being
    created as float32 on the host and then copied / cast (the previous
    ``.cuda().half()`` pattern). This avoids a host allocation and a
    host-to-device copy per tensor.

    Args:
        batch_size: size of the leading dimension of every generated tensor.
        device: device to allocate on. The default ``'cuda'`` is the current
            CUDA device, matching the former ``.cuda()`` calls.
        dtype: dtype of the generated tensors. The default ``torch.float16``
            matches the former ``.half()`` calls.

    Returns:
        dict with keys 'feature_3d', 'kp_source', 'kp_driving', 'source_image',
        'generator_input', 'feat_stitching', 'feat_eye' and 'feat_lip'.
        The per-sample shapes are hard-coded below; they are presumably the
        shapes the benchmarked modules expect (TODO confirm against the model
        config).
    """
    def _rand(*shape):
        # Every input shares the same batch size, device and dtype.
        return torch.randn(batch_size, *shape, device=device, dtype=dtype)

    feature_3d = _rand(32, 16, 64, 64)
    kp_source = _rand(21, 3)
    kp_driving = _rand(21, 3)
    source_image = _rand(3, 256, 256)
    generator_input = _rand(256, 64, 64)
    eye_close_ratio = _rand(3)
    lip_close_ratio = _rand(2)

    # concat_feat is a project helper. Cast its output so the retargeting
    # inputs are guaranteed to have the requested dtype.
    feat_stitching = concat_feat(kp_source, kp_driving).to(dtype)
    feat_eye = concat_feat(kp_source, eye_close_ratio).to(dtype)
    feat_lip = concat_feat(kp_source, lip_close_ratio).to(dtype)

    return {
        'feature_3d': feature_3d,
        'kp_source': kp_source,
        'kp_driving': kp_driving,
        'source_image': source_image,
        'generator_input': generator_input,
        'feat_stitching': feat_stitching,
        'feat_eye': feat_eye,
        'feat_lip': feat_lip,
    }
|
|
|
|
def load_and_compile_models(cfg, model_config, *, compile_mode='max-autotune',
                            retargeting_parts=('stitching', 'eye', 'lip')):
    """
    Load and compile every module for inference.

    Each module is converted to half precision, wrapped with ``torch.compile``
    and put into eval mode.

    Args:
        cfg: inference config. Provides ``checkpoint_F``, ``checkpoint_M``,
            ``checkpoint_W``, ``checkpoint_G``, ``checkpoint_S`` and
            ``device_id``.
        model_config: parsed model configuration, forwarded to ``load_model``.
        compile_mode: ``mode`` argument forwarded to ``torch.compile``.
        retargeting_parts: keys of the stitching/retargeting object to compile.

    Returns:
        A tuple ``(compiled_models, stitching_retargeting_module)``:
        - ``compiled_models`` maps display names to compiled modules.
        - ``stitching_retargeting_module`` is the object returned by
          ``load_model``, with each entry in ``retargeting_parts`` replaced
          in place by its compiled version.
    """
    def _prepare(module):
        # .half() converts in place; .eval() on the compiled wrapper also
        # switches the wrapped module into eval mode.
        return torch.compile(module.half(), mode=compile_mode).eval()

    appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
    motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
    warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
    spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
    stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')

    models_with_params = [
        ('Appearance Feature Extractor', appearance_feature_extractor),
        ('Motion Extractor', motion_extractor),
        ('Warping Network', warping_module),
        ('SPADE Decoder', spade_generator),
    ]
    compiled_models = {name: _prepare(model) for name, model in models_with_params}

    for part in retargeting_parts:
        stitching_retargeting_module[part] = _prepare(stitching_retargeting_module[part])

    return compiled_models, stitching_retargeting_module
|
|
|
|
def warm_up_models(compiled_models, stitching_retargeting_module, inputs, num_iters=10):
    """
    Run every module ``num_iters`` times so one-off first-call costs are
    excluded from the later timings.

    An example of such a cost is the compilation triggered on the first call
    of a ``torch.compile``d module.

    Args:
        compiled_models: mapping with the keys 'Appearance Feature Extractor',
            'Motion Extractor', 'Warping Network' and 'SPADE Decoder'.
        stitching_retargeting_module: mapping with the keys 'stitching',
            'eye' and 'lip'.
        inputs: dict of input tensors, as built by ``initialize_inputs``.
        num_iters: number of warm-up passes over all modules.
    """
    print("Warm up start!")
    with torch.no_grad():
        for _ in range(num_iters):
            compiled_models['Appearance Feature Extractor'](inputs['source_image'])
            compiled_models['Motion Extractor'](inputs['source_image'])
            compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])
            compiled_models['SPADE Decoder'](inputs['generator_input'])
            stitching_retargeting_module['stitching'](inputs['feat_stitching'])
            stitching_retargeting_module['eye'](inputs['feat_eye'])
            stitching_retargeting_module['lip'](inputs['feat_lip'])
    # CUDA kernels are launched asynchronously. Wait for the queued warm-up
    # work so it has truly finished before "Warm up end!" is reported.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    print("Warm up end!")
|
|
|
|
def measure_inference_times(compiled_models, stitching_retargeting_module, inputs, num_runs=100):
    """
    Measure the inference time of each module over ``num_runs`` iterations.

    Intervals are measured with ``time.perf_counter``, which is monotonic and
    high-resolution. ``time.time`` can jump with wall-clock adjustments and is
    coarser on some platforms.

    CUDA work is asynchronous, so the device is synchronized before a stage's
    clock is stopped. The synchronization is skipped when CUDA is unavailable.

    Args:
        compiled_models: mapping with the keys 'Appearance Feature Extractor',
            'Motion Extractor', 'Warping Network' and 'SPADE Decoder'.
        stitching_retargeting_module: mapping with the keys 'stitching',
            'eye' and 'lip'.
        inputs: dict of input tensors, as built by ``initialize_inputs``.
        num_runs: number of timed iterations.

    Returns:
        A tuple ``(times, overall_times)``:
        - ``times`` maps every key of ``compiled_models``, plus
          'Retargeting Models', to a list of per-run durations in seconds.
        - ``overall_times`` holds the duration in seconds of each complete
          pass over all stages.
    """
    # Decide once, outside the timed region, whether a device sync is needed.
    _sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

    def _run_retargeting():
        # The three retargeting parts are timed together as a single stage.
        stitching_retargeting_module['stitching'](inputs['feat_stitching'])
        stitching_retargeting_module['eye'](inputs['feat_eye'])
        stitching_retargeting_module['lip'](inputs['feat_lip'])

    stages = (
        ('Appearance Feature Extractor',
         lambda: compiled_models['Appearance Feature Extractor'](inputs['source_image'])),
        ('Motion Extractor',
         lambda: compiled_models['Motion Extractor'](inputs['source_image'])),
        ('Warping Network',
         lambda: compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])),
        ('SPADE Decoder',
         lambda: compiled_models['SPADE Decoder'](inputs['generator_input'])),
        ('Retargeting Models', _run_retargeting),
    )

    times = {name: [] for name in compiled_models}
    times['Retargeting Models'] = []
    overall_times = []

    with torch.no_grad():
        for _ in range(num_runs):
            # Drain pending device work so it is not charged to this run.
            _sync()
            overall_start = time.perf_counter()
            for name, run_stage in stages:
                start = time.perf_counter()
                run_stage()
                _sync()
                times[name].append(time.perf_counter() - start)
            overall_times.append(time.perf_counter() - overall_start)

    return times, overall_times
|
|
|
|
def print_benchmark_results(compiled_models, stitching_retargeting_module, retargeting_models, times, overall_times):
    """
    Print parameter counts and timing statistics for every benchmarked module.

    Args:
        compiled_models: mapping from display name to module. The parameter
            count of each module is printed.
        stitching_retargeting_module: mapping from retargeting part name to
            module.
        retargeting_models: part names (keys of
            ``stitching_retargeting_module``) whose parameter counts are
            printed.
        times: mapping from name to a list of per-run durations in seconds.
        overall_times: durations in seconds of complete passes over all
            modules.

    The run count in each timing line is derived from the data rather than
    hard-coded. Timings are reported as mean and population standard
    deviation, in milliseconds.
    """
    for name, model in compiled_models.items():
        num_params = sum(p.numel() for p in model.parameters())
        print(f"Number of parameters for {name}: {num_params / 1e6:.2f} M")

    for index, retarget in enumerate(retargeting_models):
        num_params = sum(p.numel() for p in stitching_retargeting_module[retarget].parameters())
        print(f"Number of parameters for part_{index} in Stitching and Retargeting Modules: {num_params / 1e6:.2f} M")

    def _report(label, samples):
        # Skip empty series: np.mean([]) warns and returns nan.
        if not samples:
            return
        avg_ms = np.mean(samples) * 1000
        std_ms = np.std(samples) * 1000
        print(f"Average inference time for {label} over {len(samples)} runs: {avg_ms:.2f} ms (std: {std_ms:.2f} ms)")

    for name, samples in times.items():
        _report(name, samples)

    # The end-to-end time of one pass over all modules.
    _report("all modules combined", overall_times)
|
|
|
|
def main():
    """
    Run the full benchmark.

    Steps, in order: build random inputs, read the model config, load and
    compile the modules, warm them up, time them, then print parameter counts
    and timings.
    """
    # The random inputs are drawn before any model is loaded.
    inputs = initialize_inputs()

    cfg = InferenceConfig(device_id=0)
    with open(cfg.models_config, 'r') as config_file:
        model_config = yaml.safe_load(config_file)

    compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config)

    warm_up_models(compiled_models, stitching_retargeting_module, inputs)

    times, overall_times = measure_inference_times(compiled_models, stitching_retargeting_module, inputs)

    retargeting_parts = ['stitching', 'eye', 'lip']
    print_benchmark_results(compiled_models, stitching_retargeting_module, retargeting_parts, times, overall_times)
|
|
|
|
if __name__ == '__main__':
    # Run the benchmark only when executed as a script, not when imported.
    main()
|
|