huaichang committed on
Commit ba25f75 (verified) · Parent: c2b2486

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .DS_Store +0 -0
  2. .gitattributes +6 -0
  3. README.md +0 -3
  4. configs/.DS_Store +0 -0
  5. configs/inference/inference_stage3.yaml +47 -0
  6. configs/prompts/personalive_offline.yaml +13 -0
  7. configs/prompts/personalive_online.yaml +28 -0
  8. demo/driving_video.mp4 +3 -0
  9. demo/ref_image.png +3 -0
  10. pose2vid_offline.py +254 -0
  11. pose2vid_online.py +323 -0
  12. pretrained_weights/.DS_Store +0 -0
  13. pretrained_weights/onnx/.DS_Store +0 -0
  14. pretrained_weights/onnx/unet_opt/unet_opt.onnx +3 -0
  15. pretrained_weights/onnx/unet_opt/unet_opt.onnx.data +3 -0
  16. pretrained_weights/personalive/denoising_unet.pth +3 -0
  17. pretrained_weights/personalive/motion_encoder.pth +3 -0
  18. pretrained_weights/personalive/motion_extractor.pth +3 -0
  19. pretrained_weights/personalive/pose_guider.pth +3 -0
  20. pretrained_weights/personalive/reference_unet.pth +3 -0
  21. pretrained_weights/personalive/temporal_module.pth +3 -0
  22. pretrained_weights/tensorrt/.DS_Store +0 -0
  23. pretrained_weights/tensorrt/unet_work(H100).engine +3 -0
  24. results/20251209--personalive_offline/concat_vid/ref_image_driving_video.mp4 +3 -0
  25. results/20251209--personalive_offline/split_vid/ref_image_driving_video.mp4 +3 -0
  26. src/.DS_Store +0 -0
  27. src/__pycache__/wrapper.cpython-310.pyc +0 -0
  28. src/__pycache__/wrapper_trt.cpython-310.pyc +0 -0
  29. src/liveportrait/__pycache__/camera.cpython-310.pyc +0 -0
  30. src/liveportrait/__pycache__/camera.cpython-39.pyc +0 -0
  31. src/liveportrait/__pycache__/convnextv2.cpython-310.pyc +0 -0
  32. src/liveportrait/__pycache__/convnextv2.cpython-39.pyc +0 -0
  33. src/liveportrait/__pycache__/motion_extractor.cpython-310.pyc +0 -0
  34. src/liveportrait/__pycache__/motion_extractor.cpython-39.pyc +0 -0
  35. src/liveportrait/__pycache__/util.cpython-310.pyc +0 -0
  36. src/liveportrait/__pycache__/util.cpython-39.pyc +0 -0
  37. src/liveportrait/camera.py +73 -0
  38. src/liveportrait/convnextv2.py +216 -0
  39. src/liveportrait/motion_extractor.py +212 -0
  40. src/liveportrait/util.py +492 -0
  41. src/modeling/__pycache__/engine_model.cpython-310.pyc +0 -0
  42. src/modeling/__pycache__/framed_models.cpython-310.pyc +0 -0
  43. src/modeling/__pycache__/onnx_export.cpython-310.pyc +0 -0
  44. src/modeling/engine_model.py +308 -0
  45. src/modeling/framed_models.py +177 -0
  46. src/modeling/onnx_export.py +102 -0
  47. src/models/__pycache__/attention.cpython-310.pyc +0 -0
  48. src/models/__pycache__/attention.cpython-39.pyc +0 -0
  49. src/models/__pycache__/motion_module.cpython-310.pyc +0 -0
  50. src/models/__pycache__/motion_module.cpython-39.pyc +0 -0
.DS_Store ADDED
Binary file (8.2 kB).
 
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ demo/driving_video.mp4 filter=lfs diff=lfs merge=lfs -text
+ demo/ref_image.png filter=lfs diff=lfs merge=lfs -text
+ pretrained_weights/onnx/unet_opt/unet_opt.onnx.data filter=lfs diff=lfs merge=lfs -text
+ pretrained_weights/tensorrt/unet_work(H100).engine filter=lfs diff=lfs merge=lfs -text
+ results/20251209--personalive_offline/concat_vid/ref_image_driving_video.mp4 filter=lfs diff=lfs merge=lfs -text
+ results/20251209--personalive_offline/split_vid/ref_image_driving_video.mp4 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +0,0 @@
- ---
- license: apache-2.0
- ---
configs/.DS_Store ADDED
Binary file (6.15 kB).
 
configs/inference/inference_stage3.yaml ADDED
@@ -0,0 +1,47 @@
+ unet_additional_kwargs:
+   use_inflated_groupnorm: true
+   unet_use_cross_frame_attention: false
+   unet_use_temporal_attention: false
+   use_motion_module: true
+   motion_module_resolutions:
+     - 1
+     - 2
+     - 4
+     - 8
+   motion_module_mid_block: true
+   motion_module_decoder_only: false
+   motion_module_type: Vanilla
+   motion_module_kwargs:
+     num_attention_heads: 8
+     num_transformer_block: 1
+     cross_attention_dim: 16
+     attention_block_types:
+       - Spatial_Cross
+       - Spatial_Cross
+     temporal_position_encoding: false
+     temporal_position_encoding_max_len: 32
+     temporal_attention_dim_div: 1
+
+   use_temporal_module: true
+   temporal_module_type: Vanilla
+   temporal_module_kwargs:
+     num_attention_heads: 8
+     num_transformer_block: 1
+     attention_block_types:
+       - Temporal_Self
+       - Temporal_Self
+     temporal_position_encoding: true
+     temporal_position_encoding_max_len: 32
+     temporal_attention_dim_div: 1
+
+
+ noise_scheduler_kwargs:
+   beta_start: 0.00085
+   beta_end: 0.02
+   beta_schedule: "scaled_linear"
+   clip_sample: false
+   steps_offset: 1
+   prediction_type: "epsilon"
+   timestep_spacing: "trailing"
+
+ sampler: DDIM
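For orientation, this file is consumed by the inference scripts roughly as sketched below; the sketch mirrors the loading code in pose2vid_offline.py rather than adding anything to the commit, and the base-model path is a placeholder.

# Illustrative sketch: how inference_stage3.yaml is loaded (see pose2vid_offline.py).
from omegaconf import OmegaConf
from src.models.unet_3d import UNet3DConditionModel
from src.scheduler.scheduler_ddim import DDIMScheduler

infer_config = OmegaConf.load("configs/inference/inference_stage3.yaml")

# unet_additional_kwargs configures the motion/temporal modules when the 2D UNet is inflated to 3D.
denoising_unet = UNet3DConditionModel.from_pretrained_2d(
    "pretrained_weights/sd-image-variations-diffusers",  # placeholder base-model path
    "",
    subfolder="unet",
    unet_additional_kwargs=infer_config.unet_additional_kwargs,
)

# noise_scheduler_kwargs maps directly onto the DDIM scheduler constructor; "sampler: DDIM" selects it.
scheduler = DDIMScheduler(**OmegaConf.to_container(infer_config.noise_scheduler_kwargs))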
configs/prompts/personalive_offline.yaml ADDED
@@ -0,0 +1,13 @@
+ pretrained_base_model_path: '/public_hw/home/cit_xdcun/zyli/x-nemo-inference/pretrained_weights/sd-image-variations-diffusers'
+ image_encoder_path: '/public_hw/home/cit_xdcun/zyli/x-nemo-inference/pretrained_weights/sd-image-variations-diffusers/image_encoder'
+ vae_path: '/public_hw/home/cit_xdcun/zyli/x-nemo-inference/pretrained_weights/sd-vae-ft-mse'
+ vae_tiny_path: '/public_hw/home/cit_xdcun/zyli/x-nemo-inference/pretrained_weights/taesd'
+
+ denoising_unet_path: "./pretrained_weights/personalive/denoising_unet.pth"
+
+ inference_config: "configs/inference/inference_stage3.yaml"
+ weight_dtype: 'fp16'
+
+ test_cases:
+   'demo/ref_image.png':
+     - 'demo/driving_video.mp4'
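test_cases maps each reference image to a list of driving videos, and every pair produces one output clip. A minimal sketch of how pose2vid_offline.py walks this mapping (illustrative, not a new file in the commit):

# Illustrative sketch: iterating test_cases from personalive_offline.yaml.
from omegaconf import OmegaConf

test_cases = OmegaConf.load("configs/prompts/personalive_offline.yaml")["test_cases"]
for ref_image_path, driving_videos in test_cases.items():
    for pose_video_path in driving_videos:
        print(ref_image_path, "->", pose_video_path)  # one generated video per pair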
configs/prompts/personalive_online.yaml ADDED
@@ -0,0 +1,28 @@
+ batch_size: 1
+ height: 512
+ width: 512
+ reference_image_height: 512
+ reference_image_width: 512
+ temporal_adaptive_step: 4
+ temporal_window_size: 4
+ num_inference_steps: 4
+ dtype: "fp16"
+ fps: 16
+
+ vae_model_path: '/public_hw/home/cit_xdcun/zyli/x-nemo-inference/pretrained_weights/sd-vae-ft-mse'
+ image_encoder_path: '/public_hw/home/cit_xdcun/zyli/x-nemo-inference/pretrained_weights/sd-image-variations-diffusers/image_encoder'
+ pretrained_base_model_path: '/public_hw/home/cit_xdcun/zyli/x-nemo-inference/pretrained_weights/sd-image-variations-diffusers'
+
+ reference_unet_weight_path: "./pretrained_weights/personalive/reference_unet.pth"
+ denoising_unet_path: "./pretrained_weights/personalive/denoising_unet.pth"
+ pose_guider_path: "./pretrained_weights/personalive/pose_guider.pth"
+ motion_encoder_path: './pretrained_weights/personalive/motion_encoder.pth'
+ temporal_module_path: "./pretrained_weights/personalive/temporal_module.pth"
+ pose_encoder_path: './pretrained_weights/personalive/motion_extractor.pth'
+
+ onnx_path: './pretrained_weights/onnx/unet/unet.onnx'
+ onnx_opt_path: './pretrained_weights/onnx/unet_opt/unet_opt.onnx'
+ tensorrt_target_model: './pretrained_weights/tensorrt/unet_work.engine'
+
+ inference_config: "./configs/inference/inference_stage3.yaml"
+ seed: 42
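The online config adds streaming parameters (a 4-frame temporal window and adaptive step, 4 denoising steps, fp16, 16 fps) plus paths for the exported ONNX and TensorRT UNet. A hedged sketch of how a caller might pick the accelerated backend, following the module-level use_trt flag in pose2vid_online.py (the selection logic here is illustrative only):

# Illustrative sketch: selecting the accelerated UNet artifact from personalive_online.yaml.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/prompts/personalive_online.yaml")
use_trt = True  # pose2vid_online.py hard-codes this switch at module level

unet_artifact = cfg.tensorrt_target_model if use_trt else cfg.onnx_opt_path
print(f"streaming at {cfg.fps} fps, {cfg.num_inference_steps} steps, backend: {unet_artifact}")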
demo/driving_video.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a67895a319bf48323ba63d15050299908e2cd6d99f79f766033423eb53662e07
+ size 2923884
demo/ref_image.png ADDED

Git LFS Details

  • SHA256: a0b1e353e33cda46135494c5625e689d9ffa42d65bfd83690dd0cd4449a74e3f
  • Pointer size: 131 Bytes
  • Size of remote file: 451 kB
pose2vid_offline.py ADDED
@@ -0,0 +1,254 @@
+ import argparse
+ import os
+ import sys
+ from datetime import datetime
+ import mediapipe as mp
+ import numpy as np
+ import cv2
+ import torch
+ from skimage.transform import resize
+ from diffusers import AutoencoderKLTemporalDecoder, AutoencoderKL, AutoencoderTiny
+ from src.scheduler.scheduler_ddim import DDIMScheduler
+ import random
+ from omegaconf import OmegaConf
+ from PIL import Image
+ from torchvision import transforms
+ from transformers import CLIPVisionModelWithProjection
+ from src.models.unet_2d_condition import UNet2DConditionModel
+ from src.models.unet_3d import UNet3DConditionModel
+ from src.pipelines.pipeline_pose2vid import Pose2VideoPipeline
+ from src.utils.util import save_videos_grid, crop_face
+ from decord import VideoReader
+ from diffusers.utils.import_utils import is_xformers_available
+
+ from src.models.motion_encoder.encoder import MotEncoder
+ from src.liveportrait.motion_extractor import MotionExtractor
+ from src.models.pose_guider import PoseGuider
+ from tqdm import tqdm
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--config", type=str, default='configs/prompts/personalive_offline.yaml')
+     parser.add_argument("--name", type=str, default='personalive_offline')
+     parser.add_argument("-W", type=int, default=512)
+     parser.add_argument("-H", type=int, default=512)
+     parser.add_argument("-L", type=int, default=1500)
+     parser.add_argument("--seed", type=int, default=42)
+     parser.add_argument("--device", type=str, default="cuda")
+     args = parser.parse_args()
+
+     return args
+
+ def main(args):
+     device = args.device
+     print('device', device)
+     config = OmegaConf.load(args.config)
+
+     if config.weight_dtype == "fp16":
+         weight_dtype = torch.float16
+     else:
+         weight_dtype = torch.float32
+
+     vae = AutoencoderKL.from_pretrained(config.vae_path).to(device, dtype=weight_dtype)
+     # if use tiny VAE
+     # vae_tiny = AutoencoderTiny.from_pretrained(config.vae_tiny_path).to(device, dtype=weight_dtype)
+
+     infer_config = OmegaConf.load(config.inference_config)
+     reference_unet = UNet2DConditionModel.from_pretrained(
+         config.pretrained_base_model_path,
+         subfolder="unet",
+     ).to(device=device, dtype=weight_dtype)
+     denoising_unet = UNet3DConditionModel.from_pretrained_2d(
+         config.pretrained_base_model_path,
+         "",
+         subfolder="unet",
+         unet_additional_kwargs=infer_config.unet_additional_kwargs,
+     ).to(dtype=weight_dtype, device=device)
+
+     motion_encoder = MotEncoder().to(dtype=weight_dtype, device=device).eval()
+     pose_guider = PoseGuider().to(device=device, dtype=weight_dtype)
+     pose_encoder = MotionExtractor(num_kp=21).to(device=device, dtype=weight_dtype).eval()
+
+     image_enc = CLIPVisionModelWithProjection.from_pretrained(
+         config.image_encoder_path
+     ).to(dtype=weight_dtype, device=device)
+
+     sched_kwargs = OmegaConf.to_container(
+         OmegaConf.load(config.inference_config).noise_scheduler_kwargs
+     )
+     scheduler = DDIMScheduler(**sched_kwargs)
+
+     generator = torch.manual_seed(args.seed)
+     width, height = args.W, args.H
+
+     # load pretrained weights
+     denoising_unet.load_state_dict(
+         torch.load(config.denoising_unet_path, map_location="cpu"), strict=False
+     )
+     reference_unet.load_state_dict(
+         torch.load(
+             config.denoising_unet_path.replace('denoising_unet', 'reference_unet'),
+             map_location="cpu",
+         ),
+         strict=True,
+     )
+     motion_encoder.load_state_dict(
+         torch.load(
+             config.denoising_unet_path.replace('denoising_unet', 'motion_encoder'),
+             map_location="cpu",
+         ),
+         strict=True,
+     )
+     pose_guider.load_state_dict(
+         torch.load(
+             config.denoising_unet_path.replace('denoising_unet', 'pose_guider'),
+             map_location="cpu",
+         ),
+         strict=True,
+     )
+     denoising_unet.load_state_dict(
+         torch.load(
+             config.denoising_unet_path.replace('denoising_unet', 'temporal_module'),
+             map_location="cpu",
+         ),
+         strict=False,
+     )
+     pose_encoder.load_state_dict(
+         torch.load(
+             config.denoising_unet_path.replace('denoising_unet', 'motion_extractor'),
+             map_location="cpu",
+         ),
+         strict=False,
+     )
+
+     if is_xformers_available():
+         reference_unet.enable_xformers_memory_efficient_attention()
+         denoising_unet.enable_xformers_memory_efficient_attention()
+     else:
+         raise ValueError(
+             "xformers is not available. Make sure it is installed correctly"
+         )
+
+     mp_face_mesh = mp.solutions.face_mesh
+     face_mesh = mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1)
+
+     pipe = Pose2VideoPipeline(
+         vae=vae,
+         # vae_tiny=vae_tiny,
+         image_encoder=image_enc,
+         reference_unet=reference_unet,
+         denoising_unet=denoising_unet,
+         motion_encoder=motion_encoder,
+         pose_encoder=pose_encoder,
+         pose_guider=pose_guider,
+         scheduler=scheduler,
+     )
+     pipe = pipe.to(device)
+
+     date_str = datetime.now().strftime("%Y%m%d")
+     if args.name is None:
+         time_str = datetime.now().strftime("%H%M")
+         save_dir_name = f"{date_str}--{time_str}"
+     else:
+         save_dir_name = f"{date_str}--{args.name}"
+     save_vid_dir = os.path.join('results', save_dir_name, 'concat_vid')
+     os.makedirs(save_vid_dir, exist_ok=True)
+     save_split_vid_dir = os.path.join('results', save_dir_name, 'split_vid')
+     os.makedirs(save_split_vid_dir, exist_ok=True)
+
+     pose_transform = transforms.Compose(
+         [transforms.Resize((height, width)), transforms.ToTensor()]
+     )
+
+     args.test_cases = OmegaConf.load(args.config)["test_cases"]
+
+     for ref_image_path in list(args.test_cases.keys()):
+         for pose_video_path in args.test_cases[ref_image_path]:
+             video_name = os.path.basename(pose_video_path).split(".")[0]
+             source_name = os.path.basename(ref_image_path).split(".")[0]
+
+             vid_name = f"{source_name}_{video_name}.mp4"
+             save_vid_path = os.path.join(save_vid_dir, vid_name)
+             print(save_vid_path)
+             if os.path.exists(save_vid_path):
+                 continue
+
+             if ref_image_path.endswith('.mp4'):
+                 src_vid = VideoReader(ref_image_path)
+                 ref_img = src_vid[0].asnumpy()
+                 ref_img = Image.fromarray(ref_img).convert("RGB")
+             else:
+                 ref_img = Image.open(ref_image_path).convert("RGB")
+
+             control = VideoReader(pose_video_path)
+             video_length = min(len(control) // 4 * 4, args.L)
+             sel_idx = range(len(control))[:video_length]
+             control = control.get_batch([sel_idx]).asnumpy()  # N, H, W, C
+
+             ref_image_pil = ref_img.copy()
+             ref_patch = crop_face(ref_image_pil, face_mesh)
+             ref_face_pil = Image.fromarray(ref_patch).convert("RGB")
+
+             size = args.H
+             generator = torch.Generator(device=device)
+             generator.manual_seed(42)
+
+             dri_faces = []
+             ori_pose_images = []
+             for idx_control, pose_image_pil in tqdm(enumerate(control[:video_length]), total=video_length, desc='cropping faces'):
+                 pose_image_pil = Image.fromarray(pose_image_pil).convert("RGB")
+                 ori_pose_images.append(pose_image_pil)
+                 dri_face = crop_face(pose_image_pil, face_mesh)
+                 dri_face_pil = Image.fromarray(dri_face).convert("RGB")
+                 dri_faces.append(dri_face_pil)
+
+             face_tensor_list = []
+             ori_pose_tensor_list = []
+             ref_tensor_list = []
+
+             for idx, pose_image_pil in enumerate(ori_pose_images):
+                 face_tensor_list.append(pose_transform(dri_faces[idx]))
+                 ori_pose_tensor_list.append(pose_transform(pose_image_pil))
+                 ref_tensor_list.append(pose_transform(ref_image_pil))
+
+             ref_tensor = torch.stack(ref_tensor_list, dim=0)  # (f, c, h, w)
+             ref_tensor = ref_tensor.transpose(0, 1).unsqueeze(0)  # (c, f, h, w)
+
+             face_tensor = torch.stack(face_tensor_list, dim=0)  # (f, c, h, w)
+             face_tensor = face_tensor.transpose(0, 1).unsqueeze(0)
+
+             ori_pose_tensor = torch.stack(ori_pose_tensor_list, dim=0)  # (f, c, h, w)
+             ori_pose_tensor = ori_pose_tensor.transpose(0, 1).unsqueeze(0)
+
+             gen_video = pipe(
+                 ori_pose_images,
+                 ref_image_pil,
+                 dri_faces,
+                 ref_face_pil,
+                 width,
+                 height,
+                 len(dri_faces),
+                 num_inference_steps=4,
+                 guidance_scale=1.0,
+                 generator=generator,
+                 temporal_window_size=4,
+                 temporal_adaptive_step=4,
+             ).videos
+
+             # Concat it with the pose tensor
+             video = torch.cat([ref_tensor, face_tensor, ori_pose_tensor, gen_video], dim=0)
+
+             save_videos_grid(
+                 video,
+                 save_vid_path,
+                 n_rows=4,
+                 fps=25,
+             )
+
+             if True:
+                 save_vid_path = save_vid_path.replace(save_vid_dir, save_split_vid_dir)
+                 save_videos_grid(gen_video, save_vid_path, n_rows=1, fps=25, crf=18)
+
+ if __name__ == "__main__":
+     args = parse_args()
+     main(args)
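With the demo assets and pretrained weights in place, the offline pipeline is launched as python pose2vid_offline.py --config configs/prompts/personalive_offline.yaml; -W/-H default to 512, -L caps the driving clip at 1500 frames, and outputs are written under results/<date>--personalive_offline/concat_vid and split_vid, which matches the result videos included in this commit.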
pose2vid_online.py ADDED
@@ -0,0 +1,323 @@
1
+ import os
2
+ import signal
3
+ import sys
4
+ import json
5
+
6
+ from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect, UploadFile, File
7
+ from fastapi.responses import JSONResponse
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from fastapi.staticfiles import StaticFiles
10
+ from fastapi import Request
11
+
12
+ import markdown2
13
+ import threading
14
+ import logging
15
+ import uuid
16
+ import time
17
+ from types import SimpleNamespace
18
+ import asyncio
19
+ import mimetypes
20
+ import torch
21
+
22
+ from webcam.config import config, Args
23
+ from webcam.util import pil_to_frame, bytes_to_pil, is_firefox, bytes_to_tensor
24
+ from webcam.connection_manager import ConnectionManager, ServerFullException
25
+ import multiprocessing as mp
26
+
27
+ use_trt = True
28
+
29
+ if use_trt:
30
+ from webcam.vid2vid_trt import Pipeline
31
+ else:
32
+ from webcam.vid2vid import Pipeline
33
+
34
+ mimetypes.add_type("application/javascript", ".js")
35
+
36
+ THROTTLE = 0.001
37
+
38
+
39
+ class App:
40
+ def __init__(self, config: Args, pipeline: Pipeline):
41
+ self.args = config
42
+ self.pipeline = pipeline
43
+ self.app = FastAPI()
44
+ self.conn_manager = ConnectionManager()
45
+
46
+ self.produce_predictions_stop_event = None
47
+ self.produce_predictions_task = None
48
+ self.shutdown_event = asyncio.Event()
49
+
50
+ self.init_app()
51
+
52
+ def init_app(self):
53
+ self.app.add_middleware(
54
+ CORSMiddleware,
55
+ allow_origins=["*"],
56
+ allow_credentials=True,
57
+ allow_methods=["*"],
58
+ allow_headers=["*"],
59
+ )
60
+
61
+ @self.app.websocket("/api/ws/{user_id}")
62
+ async def websocket_endpoint(user_id: uuid.UUID, websocket: WebSocket):
63
+ try:
64
+ await self.conn_manager.connect(
65
+ user_id, websocket, self.args.max_queue_size
66
+ )
67
+
68
+ sender_task = asyncio.create_task(push_results_to_client(user_id, websocket))
69
+
70
+ if self.produce_predictions_task is None or self.produce_predictions_task.done():
71
+ start_prediction_thread(user_id)
72
+
73
+ await handle_websocket_input(user_id, websocket)
74
+
75
+ except ServerFullException as e:
76
+ logging.error(f"Server Full: {e}")
77
+ except WebSocketDisconnect:
78
+ logging.info(f"User disconnected: {user_id}")
79
+ except Exception as e:
80
+ logging.error(f"WS Error: {e}")
81
+ finally:
82
+ if 'sender_task' in locals():
83
+ sender_task.cancel()
84
+
85
+ await self.conn_manager.disconnect(user_id, self.pipeline)
86
+
87
+ if self.produce_predictions_stop_event is not None:
88
+ self.produce_predictions_stop_event.set()
89
+ logging.info(f"Cleaned up user: {user_id}")
90
+
91
+ async def handle_websocket_input(user_id: uuid.UUID, websocket: WebSocket):
92
+ if not self.conn_manager.check_user(user_id):
93
+ raise HTTPException(status_code=404, detail="User not found")
94
+
95
+ try:
96
+ while True:
97
+ message = await websocket.receive()
98
+
99
+ if "text" in message:
100
+ try:
101
+ text_data = message["text"]
102
+ data = json.loads(text_data)
103
+ status = data.get("status")
104
+
105
+ if status == "pause":
106
+ params = SimpleNamespace(**{"restart": True})
107
+ await self.conn_manager.update_data(user_id, params)
108
+ elif status == "resume":
109
+ await self.conn_manager.send_json(user_id, {"status": "send_frame"})
110
+ except Exception as e:
111
+ logging.error(f"JSON Parse Error: {e}")
112
+
113
+ elif "bytes" in message:
114
+ image_data = message["bytes"]
115
+ if len(image_data) > 0:
116
+ input_tensor = bytes_to_tensor(image_data)
117
+ params = SimpleNamespace()
118
+ params.image = input_tensor
119
+ self.pipeline.accept_new_params(params)
120
+
121
+ except WebSocketDisconnect:
122
+ raise
123
+ except Exception as e:
124
+ logging.error(f"Input Loop Error: {e}")
125
+ raise
126
+
127
+ async def push_results_to_client(user_id: uuid.UUID, websocket: WebSocket):
128
+ MIN_FPS = 10
129
+ MAX_FPS = 30
130
+ SMOOTHING = 0.8 # EMA smoothing factor
131
+
132
+ last_burst_time = time.time()
133
+ last_queue_size = 0
134
+ sleep_time = 1 / 40 # Initial guess
135
+
136
+ last_frame_time = None
137
+ frame_time_list = []
138
+
139
+ ema_frame_interval = sleep_time
140
+
141
+ try:
142
+ while True:
143
+ queue_size = await self.conn_manager.get_output_queue_size(user_id)
144
+ if queue_size > last_queue_size:
145
+ current_burst_time = time.time()
146
+ elapsed = current_burst_time - last_burst_time
147
+
148
+ if queue_size > 0 and elapsed > 0:
149
+ raw_interval = elapsed / queue_size
150
+ ema_frame_interval = SMOOTHING * ema_frame_interval + (1 - SMOOTHING) * raw_interval
151
+ sleep_time = min(max(ema_frame_interval, 1 / MAX_FPS), 1 / MIN_FPS)
152
+
153
+ last_burst_time = current_burst_time
154
+
155
+ last_queue_size = queue_size
156
+
157
+ frame = await self.conn_manager.get_frame(user_id)
158
+ if frame is None:
159
+ await asyncio.sleep(0.001)
160
+ continue
161
+
162
+ await websocket.send_bytes(frame)
163
+
164
+ if last_frame_time is None:
165
+ last_frame_time = time.time()
166
+ else:
167
+ frame_time_list.append(time.time() - last_frame_time)
168
+ if len(frame_time_list) > 100:
169
+ frame_time_list.pop(0)
170
+ last_frame_time = time.time()
171
+
172
+ await asyncio.sleep(sleep_time)
173
+
174
+ except asyncio.CancelledError:
175
+ pass
176
+ except Exception as e:
177
+ logging.error(f"Push Result Error: {e}")
178
+
179
+ def start_prediction_thread(user_id):
180
+ self.produce_predictions_stop_event = threading.Event()
181
+
182
+ def prediction_loop(uid, loop, stop_event):
183
+ while not stop_event.is_set():
184
+ images = self.pipeline.produce_outputs()
185
+ if len(images) == 0:
186
+ time.sleep(THROTTLE)
187
+ continue
188
+
189
+ frames = list(map(pil_to_frame, images))
190
+ asyncio.run_coroutine_threadsafe(
191
+ self.conn_manager.put_frames_to_output_queue(uid, frames),
192
+ loop
193
+ )
194
+
195
+ self.produce_predictions_task = asyncio.create_task(asyncio.to_thread(
196
+ prediction_loop, user_id, asyncio.get_running_loop(), self.produce_predictions_stop_event
197
+ ))
198
+
199
+ @self.app.get("/api/queue")
200
+ async def get_queue_size():
201
+ queue_size = self.conn_manager.get_user_count()
202
+ return JSONResponse({"queue_size": queue_size})
203
+
204
+ @self.app.get("/api/settings")
205
+ async def settings():
206
+ info_schema = pipeline.Info.schema()
207
+ info = pipeline.Info()
208
+ if info.page_content:
209
+ page_content = markdown2.markdown(info.page_content)
210
+
211
+ input_params = pipeline.InputParams.schema()
212
+ return JSONResponse(
213
+ {
214
+ "info": info_schema,
215
+ "input_params": input_params,
216
+ "max_queue_size": self.args.max_queue_size,
217
+ "page_content": page_content if info.page_content else "",
218
+ }
219
+ )
220
+
221
+ @self.app.post("/api/upload_reference_image")
222
+ async def upload_reference_image(ref_image: UploadFile = File(...)):
223
+ try:
224
+ data = await ref_image.read()
225
+ img = bytes_to_pil(data)
226
+ self.pipeline.fuse_reference(img)
227
+ return {"status": "ok"}
228
+ except Exception as e:
229
+ logging.error(f"Reference image error: {e}")
230
+ raise HTTPException(status_code=500, detail="Failed to process reference image")
231
+
232
+ if not os.path.exists("./demo_w_camera/frontend/public"):
233
+ os.makedirs("./demo_w_camera/frontend/public")
234
+
235
+ self.app.mount(
236
+ "/", StaticFiles(directory="./demo_w_camera/frontend/public", html=True), name="public"
237
+ )
238
+
239
+ @self.app.on_event("shutdown")
240
+ async def shutdown_event():
241
+ await self.cleanup()
242
+
243
+ async def cleanup(self):
244
+ print("[App] Starting cleanup process...")
245
+ self.shutdown_event.set()
246
+
247
+ if self.produce_predictions_stop_event is not None:
248
+ self.produce_predictions_stop_event.set()
249
+
250
+ if self.produce_predictions_task is not None:
251
+ self.produce_predictions_task.cancel()
252
+ try:
253
+ await self.produce_predictions_task
254
+ except asyncio.CancelledError:
255
+ pass
256
+
257
+ try:
258
+ await self.conn_manager.disconnect_all(self.pipeline)
259
+ except Exception as e:
260
+ print(f"[App] Error during disconnect_all: {e}")
261
+
262
+ print("[App] Cleanup completed")
263
+
264
+ app_instance = None
265
+
266
+ def signal_handler(signum, frame):
267
+ print(f"\n[Main] Received signal {signum}, shutting down gracefully...")
268
+ if app_instance:
269
+ import threading
270
+ def trigger_cleanup():
271
+ try:
272
+ loop = asyncio.new_event_loop()
273
+ asyncio.set_event_loop(loop)
274
+ loop.run_until_complete(app_instance.cleanup())
275
+ loop.close()
276
+ except Exception as e:
277
+ print(f"[Main] Error during cleanup: {e}")
278
+
279
+ cleanup_thread = threading.Thread(target=trigger_cleanup)
280
+ cleanup_thread.daemon = True
281
+ cleanup_thread.start()
282
+ cleanup_thread.join(timeout=5)
283
+
284
+ sys.exit(0)
285
+
286
+
287
+ if __name__ == "__main__":
288
+ import uvicorn
289
+ signal.signal(signal.SIGINT, signal_handler)
290
+ signal.signal(signal.SIGTERM, signal_handler)
291
+ mp.set_start_method("spawn", force=True)
292
+
293
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
294
+ pipeline = Pipeline(config, device)
295
+
296
+ app_obj = App(config, pipeline)
297
+ app = app_obj.app
298
+ app_instance = app_obj
299
+
300
+ print('init done')
301
+
302
+ try:
303
+ uvicorn.run(
304
+ app,
305
+ host=config.host,
306
+ port=config.port,
307
+ reload=config.reload,
308
+ ssl_certfile=config.ssl_certfile,
309
+ ssl_keyfile=config.ssl_keyfile,
310
+ )
311
+ except KeyboardInterrupt:
312
+ try:
313
+ import asyncio
314
+ loop = asyncio.new_event_loop()
315
+ asyncio.set_event_loop(loop)
316
+ loop.run_until_complete(app_obj.cleanup())
317
+ loop.close()
318
+ except Exception as e:
319
+ print(f"[Main] Error during cleanup: {e}")
320
+ sys.exit(0)
321
+ except Exception as e:
322
+ print(f"[Main] Error: {e}")
323
+ sys.exit(1)
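The script above exposes a WebSocket at /api/ws/{user_id} that accepts binary webcam frames and streams generated frames back. A rough client sketch follows; it assumes the third-party websockets package, a placeholder host and port (the real values come from webcam.config), and that frame bytes are encoded however webcam.util.bytes_to_tensor expects, so treat it as illustrative only.

# Illustrative sketch: a minimal client for the /api/ws/{user_id} endpoint.
import asyncio
import uuid
import websockets  # third-party package, assumed available

async def stream(frames):
    url = f"ws://localhost:8000/api/ws/{uuid.uuid4()}"  # placeholder host/port
    async with websockets.connect(url) as ws:
        for frame_bytes in frames:      # raw bytes of each captured frame
            await ws.send(frame_bytes)  # binary messages feed pipeline.accept_new_params
            out = await ws.recv()       # the server pushes generated frames back as bytes
            print("received", len(out), "bytes")

# asyncio.run(stream(my_frames))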
pretrained_weights/.DS_Store ADDED
Binary file (8.2 kB).
 
pretrained_weights/onnx/.DS_Store ADDED
Binary file (6.15 kB).
 
pretrained_weights/onnx/unet_opt/unet_opt.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:484aee7e8c45cddaac227b6ad331a88a77121dee0886f2152cc4bd0e9974b6fa
+ size 96224343
pretrained_weights/onnx/unet_opt/unet_opt.onnx.data ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aa08ee8770f202be841e00f2bb94809c2ca6ca95ad8663c2917c4c6fa35d963e
+ size 3593537864
pretrained_weights/personalive/denoising_unet.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d0446c4d2387f259d5f3c1ac54a5aefa93400f4672f942856bff2538df046162
+ size 4927015578
pretrained_weights/personalive/motion_encoder.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ff7c6b0a84cd750046e7687f7a6f6bbc21317055bfcacef950ed347debae4d2c
+ size 246719031
pretrained_weights/personalive/motion_extractor.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:251e6a94ad667a1d0c69526d292677165110ef7f0cf0f6d199f0e414e8aa0ca5
+ size 112545506
pretrained_weights/personalive/pose_guider.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8b997db63343a6a5d489778172d9544bcccaf27e6756505dc6353d84e877269d
+ size 4351790
pretrained_weights/personalive/reference_unet.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:85eb03e6c34fab69f9246ff14b3016789232e56dc4892d0581fea21a3a8480f6
+ size 3438324340
pretrained_weights/personalive/temporal_module.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:295e8942a453adb48756432d99de103ecba9b840b5b8f6635a0687311cdff30e
+ size 1817903018
pretrained_weights/tensorrt/.DS_Store ADDED
Binary file (6.15 kB).
 
pretrained_weights/tensorrt/unet_work(H100).engine ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:34bd6f7693300be8cf72a099f1160bfaedab7a677bcaf66f18ee33a5b871de50
+ size 3697605036
results/20251209--personalive_offline/concat_vid/ref_image_driving_video.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9bf93d55acd386d689cda4588e636545219acf9910f1d6292eb6db0bed82c64b
+ size 7700854
results/20251209--personalive_offline/split_vid/ref_image_driving_video.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a064eb1a2effcb3514450e157ec6903973bca4c1d50a888e9c94c0f40a397213
+ size 7605688
src/.DS_Store ADDED
Binary file (8.2 kB).
 
src/__pycache__/wrapper.cpython-310.pyc ADDED
Binary file (11 kB).
 
src/__pycache__/wrapper_trt.cpython-310.pyc ADDED
Binary file (10.4 kB).
 
src/liveportrait/__pycache__/camera.cpython-310.pyc ADDED
Binary file (1.78 kB).
 
src/liveportrait/__pycache__/camera.cpython-39.pyc ADDED
Binary file (1.77 kB).
 
src/liveportrait/__pycache__/convnextv2.cpython-310.pyc ADDED
Binary file (6.19 kB).
 
src/liveportrait/__pycache__/convnextv2.cpython-39.pyc ADDED
Binary file (6.45 kB).
 
src/liveportrait/__pycache__/motion_extractor.cpython-310.pyc ADDED
Binary file (6.61 kB).
 
src/liveportrait/__pycache__/motion_extractor.cpython-39.pyc ADDED
Binary file (6.61 kB).
 
src/liveportrait/__pycache__/util.cpython-310.pyc ADDED
Binary file (15.7 kB).
 
src/liveportrait/__pycache__/util.cpython-39.pyc ADDED
Binary file (16.1 kB).
 
src/liveportrait/camera.py ADDED
@@ -0,0 +1,73 @@
+ # coding: utf-8
+
+ """
+ functions for processing and transforming 3D facial keypoints
+ """
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ PI = np.pi
+
+
+ def headpose_pred_to_degree(pred):
+     """
+     pred: (bs, 66) or (bs, 1) or others
+     """
+     if pred.ndim > 1 and pred.shape[1] == 66:
+         # NOTE: note that the average is modified to 97.5
+         device = pred.device
+         idx_tensor = [idx for idx in range(0, 66)]
+         idx_tensor = torch.FloatTensor(idx_tensor).to(device)
+         pred = F.softmax(pred, dim=1)
+         degree = torch.sum(pred * idx_tensor, axis=1) * 3 - 97.5
+
+         return degree
+
+     return pred
+
+
+ def get_rotation_matrix(pitch_, yaw_, roll_):
+     """ the input is in degree
+     """
+     # transform to radian
+     pitch = pitch_ / 180 * PI
+     yaw = yaw_ / 180 * PI
+     roll = roll_ / 180 * PI
+
+     device = pitch.device
+
+     if pitch.ndim == 1:
+         pitch = pitch.unsqueeze(1)
+     if yaw.ndim == 1:
+         yaw = yaw.unsqueeze(1)
+     if roll.ndim == 1:
+         roll = roll.unsqueeze(1)
+
+     # calculate the euler matrix
+     bs = pitch.shape[0]
+     ones = torch.ones([bs, 1]).to(device)
+     zeros = torch.zeros([bs, 1]).to(device)
+     x, y, z = pitch, yaw, roll
+
+     rot_x = torch.cat([
+         ones, zeros, zeros,
+         zeros, torch.cos(x), -torch.sin(x),
+         zeros, torch.sin(x), torch.cos(x)
+     ], dim=1).reshape([bs, 3, 3])
+
+     rot_y = torch.cat([
+         torch.cos(y), zeros, torch.sin(y),
+         zeros, ones, zeros,
+         -torch.sin(y), zeros, torch.cos(y)
+     ], dim=1).reshape([bs, 3, 3])
+
+     rot_z = torch.cat([
+         torch.cos(z), -torch.sin(z), zeros,
+         torch.sin(z), torch.cos(z), zeros,
+         zeros, zeros, ones
+     ], dim=1).reshape([bs, 3, 3])
+
+     rot = rot_z @ rot_y @ rot_x
+     return rot.permute(0, 2, 1)  # transpose
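A quick usage sketch for the two helpers above, with random logits standing in for real head-pose predictions (illustrative, not part of the commit):

# Illustrative sketch: 66-bin head-pose logits -> degrees -> batched rotation matrices.
import torch
from src.liveportrait.camera import headpose_pred_to_degree, get_rotation_matrix

pitch = headpose_pred_to_degree(torch.randn(2, 66))  # (2,) angles in degrees
yaw = headpose_pred_to_degree(torch.randn(2, 66))
roll = headpose_pred_to_degree(torch.randn(2, 66))
rot = get_rotation_matrix(pitch, yaw, roll)          # (2, 3, 3), returned pre-transposed
print(rot.shape)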
src/liveportrait/convnextv2.py ADDED
@@ -0,0 +1,216 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ This moudle is adapted to the ConvNeXtV2 version for the extraction of implicit keypoints, poses, and expression deformation.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ # from timm.models.layers import trunc_normal_, DropPath
10
+ from src.liveportrait.util import LayerNorm, DropPath, trunc_normal_, GRN
11
+ from einops import rearrange
12
+
13
+ __all__ = ['convnextv2_tiny']
14
+
15
+
16
+ class Block(nn.Module):
17
+ """ ConvNeXtV2 Block.
18
+
19
+ Args:
20
+ dim (int): Number of input channels.
21
+ drop_path (float): Stochastic depth rate. Default: 0.0
22
+ """
23
+
24
+ def __init__(self, dim, drop_path=0.):
25
+ super().__init__()
26
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
27
+ self.norm = LayerNorm(dim, eps=1e-6)
28
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
29
+ self.act = nn.GELU()
30
+ self.grn = GRN(4 * dim)
31
+ self.pwconv2 = nn.Linear(4 * dim, dim)
32
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
33
+
34
+ def forward(self, x):
35
+ input = x
36
+ x = self.dwconv(x)
37
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
38
+ x = self.norm(x)
39
+ x = self.pwconv1(x)
40
+ x = self.act(x)
41
+ x = self.grn(x)
42
+ x = self.pwconv2(x)
43
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
44
+ x = input + self.drop_path(x)
45
+ return x
46
+
47
+
48
+ class ConvNeXtV2(nn.Module):
49
+ """ ConvNeXt V2
50
+
51
+ Args:
52
+ in_chans (int): Number of input image channels. Default: 3
53
+ num_classes (int): Number of classes for classification head. Default: 1000
54
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
55
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
56
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
57
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
58
+ """
59
+
60
+ def __init__(
61
+ self,
62
+ in_chans=3,
63
+ depths=[3, 3, 9, 3],
64
+ dims=[96, 192, 384, 768],
65
+ drop_path_rate=0.,
66
+ **kwargs
67
+ ):
68
+ super().__init__()
69
+ self.depths = depths
70
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
71
+ stem = nn.Sequential(
72
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
73
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
74
+ )
75
+ self.downsample_layers.append(stem)
76
+ for i in range(3):
77
+ downsample_layer = nn.Sequential(
78
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
79
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
80
+ )
81
+ self.downsample_layers.append(downsample_layer)
82
+
83
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
84
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
85
+ cur = 0
86
+ for i in range(4):
87
+ stage = nn.Sequential(
88
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
89
+ )
90
+ self.stages.append(stage)
91
+ cur += depths[i]
92
+
93
+ self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
94
+
95
+ # NOTE: the output semantic items
96
+ num_bins = kwargs.get('num_bins', 66)
97
+ num_kp = kwargs.get('num_kp', 24) # the number of implicit keypoints
98
+ self.fc_kp = nn.Linear(dims[-1], 3 * num_kp) # implicit keypoints
99
+
100
+ # print('dims[-1]: ', dims[-1])
101
+ self.fc_scale = nn.Linear(dims[-1], 1) # scale
102
+ self.fc_pitch = nn.Linear(dims[-1], num_bins) # pitch bins
103
+ self.fc_yaw = nn.Linear(dims[-1], num_bins) # yaw bins
104
+ self.fc_roll = nn.Linear(dims[-1], num_bins) # roll bins
105
+ self.fc_t = nn.Linear(dims[-1], 3) # translation
106
+ self.fc_exp = nn.Linear(dims[-1], 3 * num_kp) # expression / delta
107
+
108
+ def _init_weights(self, m):
109
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
110
+ trunc_normal_(m.weight, std=.02)
111
+ nn.init.constant_(m.bias, 0)
112
+
113
+ def forward_features(self, x):
114
+ for i in range(4):
115
+ x = self.downsample_layers[i](x)
116
+ x = self.stages[i](x)
117
+ return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
118
+
119
+ def forward(self, x):
120
+ x = self.forward_features(x)
121
+
122
+ # implicit keypoints
123
+ kp = self.fc_kp(x)
124
+
125
+ # pose and expression deformation
126
+ pitch = self.fc_pitch(x)
127
+ yaw = self.fc_yaw(x)
128
+ roll = self.fc_roll(x)
129
+ t = self.fc_t(x)
130
+ # exp = self.fc_exp(x)
131
+ scale = self.fc_scale(x)
132
+
133
+ ret_dct = {
134
+ 'pitch': pitch,
135
+ 'yaw': yaw,
136
+ 'roll': roll,
137
+ 't': t,
138
+ # 'exp': exp,
139
+ 'scale': scale,
140
+
141
+ 'kp': kp, # canonical keypoint
142
+ }
143
+
144
+ return ret_dct
145
+
146
+
147
+ def convnextv2_tiny(**kwargs):
148
+ model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
149
+ return model
150
+
151
+ class ConvNeXt(nn.Module):
152
+ """ ConvNeXt V2
153
+
154
+ Args:
155
+ in_chans (int): Number of input image channels. Default: 3
156
+ num_classes (int): Number of classes for classification head. Default: 1000
157
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
158
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
159
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
160
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
161
+ """
162
+
163
+ def __init__(
164
+ self,
165
+ in_chans=3,
166
+ depths=[3, 3, 9, 3],
167
+ dims=[96, 192, 384, 768],
168
+ drop_path_rate=0.,
169
+ **kwargs
170
+ ):
171
+ super().__init__()
172
+ self.depths = depths
173
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
174
+ stem = nn.Sequential(
175
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
176
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
177
+ )
178
+ self.downsample_layers.append(stem)
179
+ for i in range(3):
180
+ downsample_layer = nn.Sequential(
181
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
182
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
183
+ )
184
+ self.downsample_layers.append(downsample_layer)
185
+
186
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
187
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
188
+ cur = 0
189
+ for i in range(4):
190
+ stage = nn.Sequential(
191
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
192
+ )
193
+ self.stages.append(stage)
194
+ cur += depths[i]
195
+
196
+ self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
197
+
198
+ def _init_weights(self, m):
199
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
200
+ trunc_normal_(m.weight, std=.02)
201
+ nn.init.constant_(m.bias, 0)
202
+
203
+ def forward_features(self, x):
204
+ for i in range(4):
205
+ x = self.downsample_layers[i](x)
206
+ x = self.stages[i](x)
207
+ return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
208
+
209
+ def forward(self, x):
210
+ x = self.forward_features(x)
211
+ return x
212
+
213
+
214
+ def convnextv2(**kwargs):
215
+ model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
216
+ return model
src/liveportrait/motion_extractor.py ADDED
@@ -0,0 +1,212 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ Motion extractor(M), which directly predicts the canonical keypoints, head pose and expression deformation of the input image
5
+ """
6
+
7
+ from torch import nn
8
+ import torch
9
+ from diffusers.models.modeling_utils import ModelMixin
10
+ from src.liveportrait.convnextv2 import convnextv2_tiny
11
+ from src.liveportrait.util import filter_state_dict
12
+ from src.liveportrait.camera import headpose_pred_to_degree, get_rotation_matrix
13
+
14
+ model_dict = {
15
+ 'convnextv2_tiny': convnextv2_tiny,
16
+ }
17
+
18
+
19
+ class MotionExtractor(ModelMixin):
20
+ def __init__(self, **kwargs):
21
+ super(MotionExtractor, self).__init__()
22
+
23
+ # default is convnextv2_base
24
+ backbone = kwargs.get('backbone', 'convnextv2_tiny')
25
+ self.detector = model_dict.get(backbone)(**kwargs)
26
+ self.register_buffer('idx_tensor', torch.arange(66, dtype=torch.float32))
27
+
28
+ def headpose_pred_to_degree(self, pred):
29
+ """
30
+ pred: (bs, 66) or (bs, 1) or others
31
+ """
32
+ if pred.ndim > 1 and pred.shape[1] == 66:
33
+ # NOTE: note that the average is modified to 97.5
34
+ prob = torch.nn.functional.softmax(pred, dim=1)
35
+ degree = torch.matmul(prob, self.idx_tensor)
36
+ degree = degree * 3 - 97.5
37
+
38
+ return degree
39
+
40
+ return pred
41
+
42
+ def load_pretrained(self, init_path: str):
43
+ if init_path not in (None, ''):
44
+ state_dict = torch.load(init_path, map_location=lambda storage, loc: storage)['model']
45
+ state_dict = filter_state_dict(state_dict, remove_name='head')
46
+ ret = self.detector.load_state_dict(state_dict, strict=False)
47
+ print(f'Load pretrained model from {init_path}, ret: {ret}')
48
+
49
+ def forward(self, x):
50
+ kp_info = self.detector(x)
51
+ return self.get_kp(kp_info)
52
+
53
+ def get_kp(self, kp_info):
54
+ bs = kp_info['kp'].shape[0]
55
+
56
+ angles_raw = torch.cat([kp_info['pitch'], kp_info['yaw'], kp_info['roll']], dim=0) # (3, 66)
57
+ angles_deg = self.headpose_pred_to_degree(angles_raw)[:, None] # (B, 3)
58
+ pitch, yaw, roll = torch.chunk(angles_deg, chunks=3, dim=0)
59
+
60
+
61
+ kp = kp_info['kp'].reshape(bs, -1, 3) # BxNx3
62
+ t, scale = kp_info['t'], kp_info['scale']
63
+
64
+ rot_mat = get_rotation_matrix(pitch, yaw, roll).to(self.dtype) # (bs, 3, 3)
65
+
66
+ if kp.ndim == 2:
67
+ num_kp = kp.shape[1] // 3 # Bx(num_kpx3)
68
+ else:
69
+ num_kp = kp.shape[1] # Bxnum_kpx3
70
+
71
+ # Eqn.2: s * (R * x_c,s) + t
72
+ kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat# + exp.view(bs, num_kp, 3)
73
+ kp_transformed *= scale[..., None] # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
74
+ kp_transformed[:, :, 0:2] += t[:, None, 0:2] # remove z, only apply tx ty
75
+
76
+ return kp_transformed
77
+
78
+ def interpolate_tensors(self, a: torch.Tensor, b: torch.Tensor, num: int = 10) -> torch.Tensor:
79
+ if a.shape != b.shape:
80
+ raise ValueError(f"Shape mismatch: a.shape={a.shape}, b.shape={b.shape}")
81
+
82
+ B, *rest = a.shape
83
+ alphas = torch.linspace(0, 1, num, device=a.device, dtype=a.dtype)
84
+ view_shape = (num,) + (1,) * len(rest)
85
+ alphas = alphas.view(view_shape) # (1, num, 1, 1, ...)
86
+
87
+ result = (1 - alphas) * a + alphas * b
88
+ return result[:-1]
89
+
90
+ def interpolate_kps(self, ref, motion, num_interp, t_scale=0.5, s_scale=0):
91
+ kp1 = self.detector(ref.to(self.dtype))
92
+ kp2_list = []
93
+ for i in range(0, motion.shape[0], 256):
94
+ motion_chunk = motion[i:i+256]
95
+ kp2_chunk = self.detector(motion_chunk.to(self.dtype))
96
+ kp2_list.append(kp2_chunk)
97
+ kp2 = {}
98
+ for key in kp2_list[0].keys():
99
+ kp2[key] = torch.cat([kp2_chunk[key] for kp2_chunk in kp2_list], dim=0)
100
+
101
+ angles_raw = torch.cat([kp1['pitch'], kp1['yaw'], kp1['roll']], dim=0) # (3, 66)
102
+ angles_deg = self.headpose_pred_to_degree(angles_raw) # (B, 3)
103
+ pitch_1, yaw_1, roll_1 = torch.chunk(angles_deg, chunks=3, dim=0)
104
+
105
+ angles_raw = torch.cat([kp2['pitch'], kp2['yaw'], kp2['roll']], dim=0) # (3, 66)
106
+ angles_deg = self.headpose_pred_to_degree(angles_raw) # (B, 3)
107
+ pitch_2, yaw_2, roll_2 = torch.chunk(angles_deg, chunks=3, dim=0)
108
+
109
+ pitch_interp = self.interpolate_tensors(pitch_1, pitch_2[:1], num_interp) # Bx(num_interp)x1
110
+ yaw_interp = self.interpolate_tensors(yaw_1, yaw_2[:1], num_interp) # Bx(num_interp)x1
111
+ roll_interp = self.interpolate_tensors(roll_1, roll_2[:1], num_interp) # Bx(num_interp)x1
112
+
113
+ t_1 = kp1['t']
114
+ t_2 = kp2['t']
115
+ t_2 = (t_2 - t_2[0]) * t_scale + t_1
116
+ t_interp = self.interpolate_tensors(t_1, t_2[:1], num_interp)
117
+
118
+ s_1 = kp1['scale']
119
+ s_2 = kp2['scale']
120
+ s_2 = s_2 * s_scale + s_1
121
+ s_interp = self.interpolate_tensors(s_1, s_2[:1], num_interp)
122
+
123
+ kp = kp1['kp'].repeat(num_interp+motion.shape[0]-1, 1)
124
+
125
+ kps_interp = {
126
+ 'pitch': torch.cat([pitch_interp, pitch_2], dim=0),
127
+ 'yaw': torch.cat([yaw_interp, yaw_2], dim=0),
128
+ 'roll': torch.cat([roll_interp, roll_2], dim=0),
129
+ 't': torch.cat([t_interp, t_2], dim=0),
130
+ 'scale': torch.cat([s_interp, s_2], dim=0),
131
+ 'kp': kp
132
+ }
133
+
134
+ kp_intrep = self.get_kp(kps_interp)
135
+
136
+ return kp_intrep
137
+
138
+
139
+ def interpolate_kps_online(self, ref, motion, num_interp, t_scale=0.5, s_scale=0):
140
+ kp1 = self.detector(ref.to(self.dtype))
141
+ kp_frame1 = self.detector(motion[:1].to(self.dtype))
142
+ kp2 = self.detector(motion.to(self.dtype))
143
+
144
+ angles_raw = torch.cat([kp1['pitch'], kp1['yaw'], kp1['roll']], dim=0) # (3, 66)
145
+ angles_deg = self.headpose_pred_to_degree(angles_raw) # (B, 3)
146
+ pitch_1, yaw_1, roll_1 = torch.chunk(angles_deg, chunks=3, dim=0)
147
+
148
+ angles_raw = torch.cat([kp2['pitch'], kp2['yaw'], kp2['roll']], dim=0) # (3, 66)
149
+ angles_deg = self.headpose_pred_to_degree(angles_raw) # (B, 3)
150
+ pitch_2, yaw_2, roll_2 = torch.chunk(angles_deg, chunks=3, dim=0)
151
+
152
+ pitch_interp = self.interpolate_tensors(pitch_1, pitch_2[:1], num_interp) # Bx(num_interp)x1
153
+ yaw_interp = self.interpolate_tensors(yaw_1, yaw_2[:1], num_interp) # Bx(num_interp)x1
154
+ roll_interp = self.interpolate_tensors(roll_1, roll_2[:1], num_interp) # Bx(num_interp)x1
155
+
156
+ t_1 = kp1['t']
157
+ t_2 = kp2['t']
158
+ t_2 = (t_2 - t_2[0]) * t_scale + t_1
159
+ t_interp = self.interpolate_tensors(t_1, t_2[:1], num_interp)
160
+
161
+ s_1 = kp1['scale']
162
+ s_2 = kp2['scale']
163
+ s_2 = s_2 * s_scale + s_1
164
+ s_interp = self.interpolate_tensors(s_1, s_2[:1], num_interp)
165
+
166
+ kp = kp1['kp'].repeat(num_interp+motion.shape[0]-1, 1)
167
+
168
+ kps_interp = {
169
+ 'pitch': torch.cat([pitch_interp, pitch_2], dim=0),
170
+ 'yaw': torch.cat([yaw_interp, yaw_2], dim=0),
171
+ 'roll': torch.cat([roll_interp, roll_2], dim=0),
172
+ 't': torch.cat([t_interp, t_2], dim=0),
173
+ 'scale': torch.cat([s_interp, s_2], dim=0),
174
+ 'kp': kp
175
+ }
176
+
177
+ kp_intrep = self.get_kp(kps_interp)
178
+
179
+ kp_dri = self.get_kp(kp2)
180
+
181
+ return kp_intrep, kp1, kp_frame1, kp_dri
182
+
183
+ def get_kps(self, kp_ref, kp_frame1, motion, t_scale=0.5, s_scale=0):
184
+ kps_motion = self.detector(motion.to(self.dtype))
185
+
186
+ kps_dri = self.get_kp(kps_motion)
187
+
188
+ t_ref = kp_ref['t']
189
+ t_frame1 = kp_frame1['t']
190
+ t_motion = kps_motion['t']
191
+ kps_motion['t'] = (t_motion - t_frame1) * t_scale + t_ref
192
+
193
+ s_ref = kp_ref['scale']
194
+ s_motion = kps_motion['scale']
195
+ kps_motion['scale'] = s_motion * s_scale + s_ref
196
+
197
+
198
+ kps_motion['kp'] = kp_ref['kp'].repeat(motion.shape[0], 1)
199
+
200
+ kps_motion = self.get_kp(kps_motion)
201
+
202
+ return kps_motion, kps_dri
203
+
204
+ def inference(self, ref, motion):
205
+ kps_ref = self.detector(ref.to(self.dtype))
206
+ kps_motion = self.detector(motion.to(self.dtype))
207
+ kps_motion['kp'] = kps_ref['kp']
208
+
209
+ kp_s = self.get_kp(kps_ref)
210
+ kp_d = self.get_kp(kps_motion)
211
+
212
+ return kp_s, kp_d
src/liveportrait/util.py ADDED
@@ -0,0 +1,492 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ This file defines various neural network modules and utility functions, including convolutional and residual blocks,
5
+ normalizations, and functions for spatial transformation and tensor manipulation.
6
+ """
7
+
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ import torch
11
+ import torch.nn.utils.spectral_norm as spectral_norm
12
+ import math
13
+ import warnings
14
+ import collections.abc
15
+ from itertools import repeat
16
+
17
+ def kp2gaussian(kp, spatial_size, kp_variance):
18
+ """
19
+ Transform a keypoint into gaussian like representation
20
+ """
21
+ mean = kp
22
+
23
+ coordinate_grid = make_coordinate_grid(spatial_size, mean)
24
+ number_of_leading_dimensions = len(mean.shape) - 1
25
+ shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
26
+ coordinate_grid = coordinate_grid.view(*shape)
27
+ repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1)
28
+ coordinate_grid = coordinate_grid.repeat(*repeats)
29
+
30
+ # Preprocess kp shape
31
+ shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3)
32
+ mean = mean.view(*shape)
33
+
34
+ mean_sub = (coordinate_grid - mean)
35
+
36
+ out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
37
+
38
+ return out
39
+
40
+
41
+ def make_coordinate_grid(spatial_size, ref, **kwargs):
42
+ d, h, w = spatial_size
43
+ x = torch.arange(w).type(ref.dtype).to(ref.device)
44
+ y = torch.arange(h).type(ref.dtype).to(ref.device)
45
+ z = torch.arange(d).type(ref.dtype).to(ref.device)
46
+
47
+ # NOTE: must be right-down-in
48
+ x = (2 * (x / (w - 1)) - 1) # the x axis faces to the right
49
+ y = (2 * (y / (h - 1)) - 1) # the y axis faces to the bottom
50
+ z = (2 * (z / (d - 1)) - 1) # the z axis faces to the inner
51
+
52
+ yy = y.view(1, -1, 1).repeat(d, 1, w)
53
+ xx = x.view(1, 1, -1).repeat(d, h, 1)
54
+ zz = z.view(-1, 1, 1).repeat(1, h, w)
55
+
56
+ meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3)
57
+
58
+ return meshed
59
+
60
+
61
+ class ConvT2d(nn.Module):
62
+ """
63
+ Upsampling block for use in decoder.
64
+ """
65
+
66
+ def __init__(self, in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1):
67
+ super(ConvT2d, self).__init__()
68
+
69
+ self.convT = nn.ConvTranspose2d(in_features, out_features, kernel_size=kernel_size, stride=stride,
70
+ padding=padding, output_padding=output_padding)
71
+ self.norm = nn.InstanceNorm2d(out_features)
72
+
73
+ def forward(self, x):
74
+ out = self.convT(x)
75
+ out = self.norm(out)
76
+ out = F.leaky_relu(out)
77
+ return out
78
+
79
+
80
+ class ResBlock3d(nn.Module):
81
+ """
82
+ Res block, preserve spatial resolution.
83
+ """
84
+
85
+ def __init__(self, in_features, kernel_size, padding):
86
+ super(ResBlock3d, self).__init__()
87
+ self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding)
88
+ self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding)
89
+ self.norm1 = nn.BatchNorm3d(in_features, affine=True)
90
+ self.norm2 = nn.BatchNorm3d(in_features, affine=True)
91
+
92
+ def forward(self, x):
93
+ out = self.norm1(x)
94
+ out = F.relu(out)
95
+ out = self.conv1(out)
96
+ out = self.norm2(out)
97
+ out = F.relu(out)
98
+ out = self.conv2(out)
99
+ out += x
100
+ return out
101
+
102
+
103
+ class UpBlock3d(nn.Module):
104
+ """
105
+ Upsampling block for use in decoder.
106
+ """
107
+
108
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
109
+ super(UpBlock3d, self).__init__()
110
+
111
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
112
+ padding=padding, groups=groups)
113
+ self.norm = nn.BatchNorm3d(out_features, affine=True)
114
+
115
+ def forward(self, x):
116
+ out = F.interpolate(x, scale_factor=(1, 2, 2))
117
+ out = self.conv(out)
118
+ out = self.norm(out)
119
+ out = F.relu(out)
120
+ return out
121
+
122
+
123
+ class DownBlock2d(nn.Module):
124
+ """
125
+ Downsampling block for use in encoder.
126
+ """
127
+
128
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
129
+ super(DownBlock2d, self).__init__()
130
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups)
131
+ self.norm = nn.BatchNorm2d(out_features, affine=True)
132
+ self.pool = nn.AvgPool2d(kernel_size=(2, 2))
133
+
134
+ def forward(self, x):
135
+ out = self.conv(x)
136
+ out = self.norm(out)
137
+ out = F.relu(out)
138
+ out = self.pool(out)
139
+ return out
140
+
141
+
142
+ class DownBlock3d(nn.Module):
143
+ """
144
+ Downsampling block for use in encoder.
145
+ """
146
+
147
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
148
+ super(DownBlock3d, self).__init__()
149
+ '''
150
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
151
+ padding=padding, groups=groups, stride=(1, 2, 2))
152
+ '''
153
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
154
+ padding=padding, groups=groups)
155
+ self.norm = nn.BatchNorm3d(out_features, affine=True)
156
+ self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2))
157
+
158
+ def forward(self, x):
159
+ out = self.conv(x)
160
+ out = self.norm(out)
161
+ out = F.relu(out)
162
+ out = self.pool(out)
163
+ return out
164
+
165
+
166
+ class SameBlock2d(nn.Module):
167
+ """
168
+ Simple block, preserve spatial resolution.
169
+ """
170
+
171
+ def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False):
172
+ super(SameBlock2d, self).__init__()
173
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups)
174
+ self.norm = nn.BatchNorm2d(out_features, affine=True)
175
+ if lrelu:
176
+ self.ac = nn.LeakyReLU()
177
+ else:
178
+ self.ac = nn.ReLU()
179
+
180
+ def forward(self, x):
181
+ out = self.conv(x)
182
+ out = self.norm(out)
183
+ out = self.ac(out)
184
+ return out
185
+
186
+
187
+ class Encoder(nn.Module):
188
+ """
189
+ Hourglass Encoder
190
+ """
191
+
192
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
193
+ super(Encoder, self).__init__()
194
+
195
+ down_blocks = []
196
+ for i in range(num_blocks):
197
+ down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), min(max_features, block_expansion * (2 ** (i + 1))), kernel_size=3, padding=1))
198
+ self.down_blocks = nn.ModuleList(down_blocks)
199
+
200
+ def forward(self, x):
201
+ outs = [x]
202
+ for down_block in self.down_blocks:
203
+ outs.append(down_block(outs[-1]))
204
+ return outs
205
+
206
+
207
+ class Decoder(nn.Module):
208
+ """
209
+ Hourglass Decoder
210
+ """
211
+
212
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
213
+ super(Decoder, self).__init__()
214
+
215
+ up_blocks = []
216
+
217
+ for i in range(num_blocks)[::-1]:
218
+ in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
219
+ out_filters = min(max_features, block_expansion * (2 ** i))
220
+ up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))
221
+
222
+ self.up_blocks = nn.ModuleList(up_blocks)
223
+ self.out_filters = block_expansion + in_features
224
+
225
+ self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1)
226
+ self.norm = nn.BatchNorm3d(self.out_filters, affine=True)
227
+
228
+ def forward(self, x):
229
+ out = x.pop()
230
+ for up_block in self.up_blocks:
231
+ out = up_block(out)
232
+ skip = x.pop()
233
+ out = torch.cat([out, skip], dim=1)
234
+ out = self.conv(out)
235
+ out = self.norm(out)
236
+ out = F.relu(out)
237
+ return out
238
+
239
+
240
+ class Hourglass(nn.Module):
241
+ """
242
+ Hourglass architecture.
243
+ """
244
+
245
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
246
+ super(Hourglass, self).__init__()
247
+ self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
248
+ self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
249
+ self.out_filters = self.decoder.out_filters
250
+
251
+ def forward(self, x):
252
+ return self.decoder(self.encoder(x))
253
+
254
+
255
+ class SPADE(nn.Module):
256
+ def __init__(self, norm_nc, label_nc):
257
+ super().__init__()
258
+
259
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
260
+ nhidden = 128
261
+
262
+ self.mlp_shared = nn.Sequential(
263
+ nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
264
+ nn.ReLU())
265
+ self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
266
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
267
+
268
+ def forward(self, x, segmap):
269
+ normalized = self.param_free_norm(x)
270
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
271
+ actv = self.mlp_shared(segmap)
272
+ gamma = self.mlp_gamma(actv)
273
+ beta = self.mlp_beta(actv)
274
+ out = normalized * (1 + gamma) + beta
275
+ return out
276
+
277
+
278
+ class SPADEResnetBlock(nn.Module):
279
+ def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
280
+ super().__init__()
281
+ # Attributes
282
+ self.learned_shortcut = (fin != fout)
283
+ fmiddle = min(fin, fout)
284
+ self.use_se = use_se
285
+ # create conv layers
286
+ self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
287
+ self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
288
+ if self.learned_shortcut:
289
+ self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
290
+ # apply spectral norm if specified
291
+ if 'spectral' in norm_G:
292
+ self.conv_0 = spectral_norm(self.conv_0)
293
+ self.conv_1 = spectral_norm(self.conv_1)
294
+ if self.learned_shortcut:
295
+ self.conv_s = spectral_norm(self.conv_s)
296
+ # define normalization layers
297
+ self.norm_0 = SPADE(fin, label_nc)
298
+ self.norm_1 = SPADE(fmiddle, label_nc)
299
+ if self.learned_shortcut:
300
+ self.norm_s = SPADE(fin, label_nc)
301
+
302
+ def forward(self, x, seg1):
303
+ x_s = self.shortcut(x, seg1)
304
+ dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
305
+ dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
306
+ out = x_s + dx
307
+ return out
308
+
309
+ def shortcut(self, x, seg1):
310
+ if self.learned_shortcut:
311
+ x_s = self.conv_s(self.norm_s(x, seg1))
312
+ else:
313
+ x_s = x
314
+ return x_s
315
+
316
+ def actvn(self, x):
317
+ return F.leaky_relu(x, 2e-1)
318
+
319
+
320
+ def filter_state_dict(state_dict, remove_name='fc'):
321
+ new_state_dict = {}
322
+ for key in state_dict:
323
+ if remove_name in key:
324
+ continue
325
+ new_state_dict[key] = state_dict[key]
326
+ return new_state_dict
327
+
328
+
329
+ class GRN(nn.Module):
330
+ """ GRN (Global Response Normalization) layer
331
+ """
332
+
333
+ def __init__(self, dim):
334
+ super().__init__()
335
+ self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
336
+ self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
337
+
338
+ def forward(self, x):
339
+ Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
340
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
341
+ return self.gamma * (x * Nx) + self.beta + x
342
+
343
+
344
+ class LayerNorm(nn.Module):
345
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
346
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
347
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
348
+ with shape (batch_size, channels, height, width).
349
+ """
350
+
351
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
352
+ super().__init__()
353
+ self.weight = nn.Parameter(torch.ones(normalized_shape, dtype=torch.float32))
354
+ self.bias = nn.Parameter(torch.zeros(normalized_shape, dtype=torch.float32))
355
+ self.eps = float(eps)
356
+ self.data_format = data_format
357
+ if self.data_format not in ["channels_last", "channels_first"]:
358
+ raise NotImplementedError
359
+ self.normalized_shape = (normalized_shape, )
360
+
361
+ def _apply(self, fn):
362
+ """
363
+ Override _apply to take full control of parameter conversion.
364
+ Intercepts all .cuda(), .cpu(), .half(), and .to() calls.
365
+ """
366
+
367
+ for name, param in self._parameters.items():
368
+ if param is not None:
369
+ dummy_probe = param.data.view(-1)[:1]
370
+
371
+ try:
372
+ target_tensor = fn(dummy_probe)
373
+
374
+ target_device = target_tensor.device
375
+ target_dtype = target_tensor.dtype
376
+ except:
377
+ target_device = param.device
378
+ target_dtype = param.dtype
379
+
380
+ if name in ['weight', 'bias']:
381
+ # Core logic: if the parameter is weight/bias and the target is half precision, force it to stay in FP32
382
+ if target_dtype in [torch.float16, torch.bfloat16]:
383
+ new_data = param.data.to(device=target_device, dtype=torch.float32)
384
+ else:
385
+ new_data = fn(param.data)
386
+ else:
387
+ new_data = fn(param.data)
388
+
389
+ param.data = new_data
390
+
391
+ if param.grad is not None:
392
+ param.grad.data = param.grad.data.to(device=new_data.device, dtype=new_data.dtype)
393
+
394
+ for name, buf in self._buffers.items():
395
+ if buf is not None:
396
+ self._buffers[name] = fn(buf)
397
+
398
+ return self
399
+
400
+ def forward(self, x):
401
+ dtype = x.dtype
402
+ x = x.float()
403
+ if self.data_format == "channels_last":
404
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
405
+ elif self.data_format == "channels_first":
406
+ x = x.permute(0, 2, 3, 1) # BCHW → BHWC
407
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
408
+ x = x.permute(0, 3, 1, 2) # BHWC → BCHW
409
+ return x.to(dtype)
410
+
411
+
412
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
413
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
414
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
415
+ def norm_cdf(x):
416
+ # Computes standard normal cumulative distribution function
417
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
418
+
419
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
420
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
421
+ "The distribution of values may be incorrect.",
422
+ stacklevel=2)
423
+
424
+ with torch.no_grad():
425
+ # Values are generated by using a truncated uniform distribution and
426
+ # then using the inverse CDF for the normal distribution.
427
+ # Get upper and lower cdf values
428
+ l = norm_cdf((a - mean) / std)
429
+ u = norm_cdf((b - mean) / std)
430
+
431
+ # Uniformly fill tensor with values from [l, u], then translate to
432
+ # [2l-1, 2u-1].
433
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
434
+
435
+ # Use inverse cdf transform for normal distribution to get truncated
436
+ # standard normal
437
+ tensor.erfinv_()
438
+
439
+ # Transform to proper mean, std
440
+ tensor.mul_(std * math.sqrt(2.))
441
+ tensor.add_(mean)
442
+
443
+ # Clamp to ensure it's in the proper range
444
+ tensor.clamp_(min=a, max=b)
445
+ return tensor
446
+
447
+
448
+ def drop_path(x, drop_prob=0., training=False, scale_by_keep=True):
449
+ """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
450
+
451
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
452
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
453
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
454
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
455
+ 'survival rate' as the argument.
456
+
457
+ """
458
+ if drop_prob == 0. or not training:
459
+ return x
460
+ keep_prob = 1 - drop_prob
461
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
462
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
463
+ if keep_prob > 0.0 and scale_by_keep:
464
+ random_tensor.div_(keep_prob)
465
+ return x * random_tensor
466
+
467
+
468
+ class DropPath(nn.Module):
469
+ """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
470
+ """
471
+
472
+ def __init__(self, drop_prob=None, scale_by_keep=True):
473
+ super(DropPath, self).__init__()
474
+ self.drop_prob = drop_prob
475
+ self.scale_by_keep = scale_by_keep
476
+
477
+ def forward(self, x):
478
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
479
+
480
+
481
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
482
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
483
+
484
+ # From PyTorch internals
485
+ def _ntuple(n):
486
+ def parse(x):
487
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
488
+ return tuple(x)
489
+ return tuple(repeat(x, n))
490
+ return parse
491
+
492
+ to_2tuple = _ntuple(2)
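Note on the custom LayerNorm above: the overridden _apply keeps weight and bias in FP32 even when the module is cast to half precision, while forward computes in FP32 and casts back to the input dtype. A minimal sketch of the intended behavior (not part of the repository; shapes are illustrative):

import torch

ln = LayerNorm(320, data_format="channels_first")   # the class defined above
ln = ln.cuda().half()             # .half() is intercepted by the custom _apply
print(ln.weight.dtype)            # torch.float32 -- parameters stay in FP32 for numerical stability

x = torch.randn(1, 320, 64, 64, device="cuda", dtype=torch.float16)
y = ln(x)                         # normalization runs in FP32 internally
print(y.dtype)                    # torch.float16 -- output is cast back to the input dtype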
src/modeling/__pycache__/engine_model.cpython-310.pyc ADDED
Binary file (9.13 kB). View file
 
src/modeling/__pycache__/framed_models.cpython-310.pyc ADDED
Binary file (5.97 kB). View file
 
src/modeling/__pycache__/onnx_export.cpython-310.pyc ADDED
Binary file (2.52 kB). View file
 
src/modeling/engine_model.py ADDED
@@ -0,0 +1,308 @@
1
+ import tensorrt as trt
2
+ import pycuda.driver as cuda
3
+ import pycuda.autoinit
4
+ import numpy as np
5
+ import torch
6
+ import traceback
7
+ import os
8
+ from PIL import Image
9
+
10
+ TRT_LOGGER = trt.Logger()
11
+ SKIP_ENGINE_MODEL_CHECK = True
12
+
13
+ def get_engine(engine_file_path):
14
+ if os.path.exists(engine_file_path):
15
+ print(f"Loading engine from file {engine_file_path}...")
16
+ with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
17
+ return runtime.deserialize_cuda_engine(f.read())
18
+ else:
19
+ print(f"No file named {engine_file_path}! Please check the input.")
20
+ return None
21
+
22
+
23
+ def numpy_to_torch_dtype(np_dtype):
24
+ mapping = {
25
+ np.float32: torch.float,
26
+ np.float64: torch.double,
27
+ np.float16: torch.half,
28
+ np.int32: torch.int32,
29
+ np.int64: torch.int64,
30
+ np.int16: torch.int16,
31
+ np.int8: torch.int8,
32
+ np.uint8: torch.uint8,
33
+ np.bool_: torch.bool
34
+ }
35
+ return mapping.get(np_dtype, None)
36
+
37
+ def match_shape(a, b):
38
+ if(len(a) == len(b)):
39
+ return tuple(a) == tuple(b)
40
+ elif len(a) > len(b):
41
+ if(a[0] == 1):
42
+ return match_shape(a[1:], b)
43
+ else:
44
+ if(b[0] == 1):
45
+ return match_shape(a, b[1:])
46
+ return False
47
+
48
+ def match_dtype(a, b):
49
+ if(a.__class__ == torch.dtype):
50
+ a = torch.tensor(0,dtype=a).numpy().dtype
51
+ return a == b
52
+
53
+
54
+ class EngineModel:
55
+ def __init__(self, engine_file_path, stream = None, device_int = 0, extra_lock = None):
56
+ self.device_int = device_int
57
+ self.extra_lock = extra_lock
58
+ if not(self.extra_lock is None):
59
+ self.extra_lock.acquire()
60
+ assert os.path.exists(engine_file_path), "Engine model path does not exist!"
61
+ self.ctx = cuda.Device(self.device_int).make_context()
62
+ try:
63
+ self.engine = get_engine(engine_file_path) # load the TensorRT engine
64
+ input_nvars = 0
65
+ output_nvars = 0
66
+ self.input_names = []
67
+ self.output_names = []
68
+
69
+ # Helper function: get a safe shape (resolve any -1 dynamic dims)
70
+ def get_safe_shape(engine, name):
71
+ shape = engine.get_tensor_shape(name)
72
+ # If the shape contains -1 (a dynamic dimension)
73
+ if -1 in shape:
74
+ # Get (min, opt, max) from optimization profile 0
75
+ # Take index [2], the max shape, so enough device memory is allocated
76
+ profile = engine.get_tensor_profile_shape(name, 0)
77
+ if profile:
78
+ print(f"[EngineModel] Detected dynamic shape for {name}: {shape} -> Using Max Profile: {profile[2]}")
79
+ return profile[2]
80
+ else:
81
+ # If no profile is available (this usually happens for outputs), it is a risk point
82
+ # Print a warning here to avoid raising an error
83
+ print(f"[EngineModel] Warning: Dynamic output {name} has no profile. Mem alloc might fail.")
84
+ return shape
85
+
86
+ for binding in self.engine: # iterate over all tensors and split them into inputs/outputs
87
+ mode = self.engine.get_tensor_mode(binding)
88
+ if(mode== trt.TensorIOMode.INPUT):
89
+ input_nvars += 1
90
+ self.input_names.append(binding)
91
+ elif(mode == trt.TensorIOMode.OUTPUT):
92
+ output_nvars += 1
93
+ self.output_names.append(binding)
94
+
95
+ self.input_nvars = input_nvars # number of inputs
96
+ self.output_nvars = output_nvars # number of outputs
97
+
98
+ self.input_shapes = {name : get_safe_shape(self.engine, name) for name in self.input_names} # record the shape and dtype of every I/O tensor
99
+ self.input_dtypes = {name : self.engine.get_tensor_dtype(name) for name in self.input_names}
100
+ self.input_nbytes = {
101
+ name : trt.volume(self.input_shapes[name]) * trt.nptype(self.input_dtypes[name])().itemsize
102
+ for name in self.input_names
103
+ } # nbytes = how many bytes of CUDA memory the tensor occupies
104
+ self.output_shapes = {name : get_safe_shape(self.engine, name) for name in self.output_names}
105
+ self.output_dtypes = {name : self.engine.get_tensor_dtype(name) for name in self.output_names}
106
+ self.output_nbytes = {
107
+ name : trt.volume(self.output_shapes[name]) * trt.nptype(self.output_dtypes[name])().itemsize
108
+ for name in self.output_names
109
+ }
110
+ self.dinputs = {name : cuda.mem_alloc(self.input_nbytes[name]) for name in self.input_names} # allocate CUDA device memory for every input/output
111
+ self.doutputs = {name :cuda.mem_alloc(self.output_nbytes[name]) for name in self.output_names}
112
+ self.context = self.engine.create_execution_context() # create the ExecutionContext
113
+ if stream is None:
114
+ self.stream = cuda.Stream()
115
+ else:
116
+ self.stream = stream
117
+ for name in self.input_names: # bind tensor addresses to the context
118
+ self.context.set_tensor_address(name, int(self.dinputs[name]))
119
+ for name in self.output_names:
120
+ self.context.set_tensor_address(name, int(self.doutputs[name]))
121
+ self.houtputs = {
122
+ name :
123
+ cuda.pagelocked_empty(
124
+ trt.volume(self.output_shapes[name]), dtype=trt.nptype(self.output_dtypes[name])
125
+ ) for name in self.output_names
126
+ } # allocate page-locked host memory to hold the outputs
127
+ except:
128
+ self.ctx.pop()
129
+ raise Exception("CUDA Initialization Failed!")
130
+ self.ctx.pop()
131
+ if not(self.extra_lock is None):
132
+ self.extra_lock.release()
133
+
134
+ def __call__(self, skip_check=SKIP_ENGINE_MODEL_CHECK, output_list=[], return_tensor=False, **inputs):
135
+ if not skip_check:
136
+ for name in inputs:
137
+ assert name in self.input_names
138
+ assert match_shape(inputs[name].shape, self.input_shapes[name])
139
+ assert match_dtype(inputs[name].dtype, trt.nptype(self.input_dtypes[name]))
140
+ if not(self.extra_lock is None):
141
+ self.extra_lock.acquire()
142
+ self.ctx.push()
143
+ r = {}
144
+ try:
145
+
146
+ for name in inputs:
147
+ hinput = inputs[name]
148
+ if (isinstance(hinput,torch.Tensor) and hinput.device.type=="cuda" and hinput.device.index==self.device_int):
149
+ hinput_con = hinput.contiguous()
150
+ ptr = hinput_con.data_ptr()
151
+ cuda.memcpy_dtod_async(self.dinputs[name], ptr, self.input_nbytes[name], self.stream)
152
+ else:
153
+ hinput_con = np.ascontiguousarray(hinput)
154
+ cuda.memcpy_htod_async(self.dinputs[name], hinput_con, self.stream)
155
+ for name in self.input_names:
156
+ if name not in inputs:
157
+ self.context.set_input_shape(name, self.input_shapes[name])
158
+ self.context.execute_async_v3(self.stream.handle)
159
+ if(return_tensor):
160
+ for name in output_list:
161
+ t = torch.zeros(trt.volume(self.output_shapes[name]), device=f"cuda:{self.device_int}", dtype=numpy_to_torch_dtype(trt.nptype(self.output_dtypes[name])))
162
+ ptr = t.data_ptr()
163
+ cuda.memcpy_dtod_async(ptr, self.doutputs[name], self.output_nbytes[name], self.stream)
164
+ t = t.reshape(tuple(self.output_shapes[name]))
165
+ r[name] = t
166
+ else:
167
+ for name in output_list:
168
+ cuda.memcpy_dtoh_async(self.houtputs[name], self.doutputs[name], self.stream)
169
+ r[name] = self.houtputs[name]
170
+ self.stream.synchronize()
171
+ except Exception as e:
172
+ print("TensorRT Execution Failed!")
173
+ traceback.print_exc()
174
+ self.ctx.pop()
175
+ if not(self.extra_lock is None):
176
+ self.extra_lock.release()
177
+ return None
178
+ self.ctx.pop()
179
+ if not(self.extra_lock is None):
180
+ self.extra_lock.release()
181
+ return r
182
+
183
+
184
+ def prefill(self, skip_check=SKIP_ENGINE_MODEL_CHECK, **inputs):
185
+ if not (skip_check):
186
+ for name in inputs:
187
+ in_input = (name in self.input_names)
188
+ assert in_input or (name in self.output_names)
189
+ assert match_shape(inputs[name].shape, self.input_shapes[name] if in_input else self.output_shapes[name])
190
+ assert match_dtype(inputs[name].dtype, trt.nptype(self.input_dtypes[name] if in_input else self.output_dtypes[name]))
191
+ if not(self.extra_lock is None):
192
+ self.extra_lock.acquire()
193
+ self.ctx.push()
194
+ try:
195
+ for name in inputs:
196
+ in_input = (name in self.input_names)
197
+ hinput = inputs[name]
198
+
199
+ dst_ptr = self.dinputs[name] if in_input else self.doutputs[name]
200
+ real_nbytes = 0
201
+ if isinstance(hinput, torch.Tensor):
202
+ real_nbytes = hinput.numel() * hinput.element_size()
203
+ else:
204
+ # assume it is a numpy array
205
+ real_nbytes = hinput.nbytes
206
+
207
+ if (isinstance(hinput,torch.Tensor) and hinput.device.type=="cuda" and hinput.device.index==self.device_int):
208
+ hinput_con = hinput.contiguous()
209
+ ptr = hinput_con.data_ptr()
210
+ cuda.memcpy_dtod_async(dst_ptr, ptr, real_nbytes, self.stream)
211
+ else:
212
+ hinput_con = np.ascontiguousarray(hinput)
213
+ cuda.memcpy_htod_async(dst_ptr, hinput_con, self.stream) # copy the contiguous host buffer
214
+ self.stream.synchronize()
215
+ except Exception as e:
216
+ traceback.print_exc()
217
+ self.ctx.pop()
218
+ if not(self.extra_lock is None):
219
+ self.extra_lock.release()
220
+ return False
221
+ self.ctx.pop()
222
+ if not(self.extra_lock is None):
223
+ self.extra_lock.release()
224
+ return True
225
+
226
+ def __repr__(self):
227
+ r = "TensorRTEngineModel(\n\tInput=[\n"
228
+ for name in self.input_names:
229
+ r += f"\t\t{name}: \t{trt.nptype(self.input_dtypes[name]).__name__}{self.input_shapes[name]},\n"
230
+ r += "\t],Output=[\n"
231
+ for name in self.output_names:
232
+ r += f"\t\t{name}: \t{trt.nptype(self.output_dtypes[name]).__name__}{self.output_shapes[name]},\n"
233
+ r+="\t]\n)"
234
+ return r
235
+
236
+ def link(self, other, var_map, skip_check=SKIP_ENGINE_MODEL_CHECK):
237
+ assert self.device_int == other.device_int
238
+ if not (skip_check):
239
+ for source in var_map:
240
+ assert source in other.output_names
241
+ target = var_map[source]
242
+ assert target in self.input_names
243
+ assert match_shape(other.output_shapes[source], self.input_shapes[target])
244
+ assert match_dtype(other.output_dtypes[source], self.input_dtypes[target])
245
+
246
+ if not(self.extra_lock is None):
247
+ self.extra_lock.acquire()
248
+ self.ctx.push()
249
+ try:
250
+ for source in var_map:
251
+ target = var_map[source]
252
+ self.context.set_tensor_address(target, int(other.doutputs[source]))
253
+ except Exception as e:
254
+ traceback.print_exc()
255
+ self.ctx.pop()
256
+ if not(self.extra_lock is None):
257
+ self.extra_lock.release()
258
+ return False
259
+ self.ctx.pop()
260
+ if not(self.extra_lock is None):
261
+ self.extra_lock.release()
262
+ return True
263
+
264
+ def bind(self, var_map, skip_check=SKIP_ENGINE_MODEL_CHECK):
265
+ if not (skip_check):
266
+ for source in var_map:
267
+ assert source in self.output_names
268
+ target = var_map[source]
269
+ assert target in self.input_names
270
+ assert match_shape(self.output_shapes[source], self.input_shapes[target])
271
+ assert match_dtype(self.output_dtypes[source], self.input_dtypes[target])
272
+
273
+ if not(self.extra_lock is None):
274
+ self.extra_lock.acquire()
275
+ self.ctx.push()
276
+ try:
277
+ for source in var_map:
278
+ target = var_map[source]
279
+ self.context.set_tensor_address(target, int(self.doutputs[source]))
280
+ except Exception as e:
281
+ traceback.print_exc()
282
+ self.ctx.pop()
283
+ if not(self.extra_lock is None):
284
+ self.extra_lock.release()
285
+ return False
286
+ self.ctx.pop()
287
+ if not(self.extra_lock is None):
288
+ self.extra_lock.release()
289
+ return True
290
+
291
+ def unlink(self):
292
+
293
+ if not(self.extra_lock is None):
294
+ self.extra_lock.acquire()
295
+ self.ctx.push()
296
+ try:
297
+ for name in self.input_names:
298
+ self.context.set_tensor_address(name, int(self.dinputs[name]))
299
+ except:
300
+ self.ctx.pop()
301
+ if not(self.extra_lock is None):
302
+ self.extra_lock.release()
303
+ return False
304
+ self.ctx.pop()
305
+ if not(self.extra_lock is None):
306
+ self.extra_lock.release()
307
+ return True
308
+
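A minimal usage sketch for EngineModel (not from the repository; the engine path is a placeholder and the zero-filled inputs are for illustration only):

import numpy as np
import tensorrt as trt

model = EngineModel("path/to/unet_work.engine", device_int=0)
print(model)   # __repr__ lists every input/output binding with its dtype and shape

# Host inputs keyed by binding name; CUDA torch tensors on the same device are also accepted.
inputs = {name: np.zeros(tuple(model.input_shapes[name]),
                         dtype=trt.nptype(model.input_dtypes[name]))
          for name in model.input_names}
outputs = model(output_list=list(model.output_names), return_tensor=False, **inputs)
# Each requested output comes back as a flat page-locked array; reshape it with model.output_shapes[name].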
src/modeling/framed_models.py ADDED
@@ -0,0 +1,177 @@
1
+ import torch
2
+ from torch import nn
3
+ from einops import rearrange
4
+ from polygraphy.backend.trt import Profile
5
+
6
+ class unet_work(nn.Module): # Ugly Power Strip
7
+ def __init__(self, pose_guider, motion_encoder, unet, vae, scheduler, timestep):
8
+ super().__init__()
9
+ self.pose_guider = pose_guider
10
+ self.motion_encoder = motion_encoder
11
+ self.unet = unet
12
+ self.vae = vae
13
+ self.scheduler = scheduler
14
+ self.timesteps = timestep
15
+
16
+ def decode_slice(self, vae, x):
17
+ x = x / 0.18215
18
+ x = vae.decode(x).sample
19
+ x = rearrange(x, "b c h w -> b h w c")
20
+ x = (x / 2 + 0.5).clamp(0, 1)
21
+ return x
22
+
23
+ def forward(self, sample, encoder_hidden_states, motion_hidden_states, motion, pose_cond_fea, pose, new_noise,
24
+ d00, d01, d10, d11, d20, d21, m, u10, u11, u12, u20, u21, u22, u30, u31, u32
25
+ ):
26
+ new_pose_cond_fea = self.pose_guider(pose)
27
+ pose_cond_fea = torch.cat([pose_cond_fea, new_pose_cond_fea], dim=2)
28
+ new_motion_hidden_states = self.motion_encoder(motion)
29
+ motion_hidden_states = torch.cat([motion_hidden_states, new_motion_hidden_states], dim=1)
30
+ encoder_hidden_states = [encoder_hidden_states, motion_hidden_states]
31
+ score = self.unet(sample, self.timesteps, encoder_hidden_states, pose_cond_fea, d00, d01, d10, d11, d20, d21, m, u10, u11, u12, u20, u21, u22, u30, u31, u32)
32
+ score = rearrange(score, 'b c f h w -> (b f) c h w')
33
+ sample = rearrange(sample, 'b c f h w -> (b f) c h w')
34
+ latents_model_input, pred_original_sample = self.scheduler.step(
35
+ score, self.timesteps, sample, return_dict=False
36
+ )
37
+ latents_model_input = latents_model_input.to(sample.dtype)
38
+ pred_original_sample = pred_original_sample.to(sample.dtype)
39
+ latents_model_input = rearrange(latents_model_input, '(b f) c h w -> b c f h w', f=16)
40
+ pred_video = self.decode_slice(self.vae, pred_original_sample[:4])
41
+ latents = torch.cat([latents_model_input[:, :, 4:, :, :], new_noise], dim=2)
42
+ pose_cond_fea_out = pose_cond_fea[:, :, 4:, :, :]
43
+ motion_hidden_states_out = motion_hidden_states[:, 4:, :, :]
44
+ motion_out = motion_hidden_states[:, :1, :, :]
45
+ return pred_video, latents, pose_cond_fea_out, motion_hidden_states_out, motion_out, pred_original_sample[:1]
46
+
47
+ def get_sample_input(self, batchsize, height, width, dtype, device):
48
+ tw, ts, tb = 4, 4, 16 # temporal window size| temporal adaptive steps | temporal batch size
49
+ ml, mc, mh, mw = 32, 16, 224, 224 # motion latent size | motion channels | motion input height | width
50
+ b, h, w = batchsize, height, width
51
+ lh, lw = height // 8, width // 8 # latent height | width
52
+ cd0, cd1, cd2, cm, cu1, cu2, cu3 = 320, 640, 1280, 1280, 1280, 640, 320 # unet channels
53
+ emb = 768 # CLIP Embedding Dims | TAESDV Channels
54
+ lc, ic = 4, 3 # latent | image channels
55
+ profile = {
56
+ "sample" : [b, lc, tb, lh, lw],
57
+ "encoder_hidden_states" : [b, 1, emb],
58
+ "motion_hidden_states" : [b, tw * (ts - 1), ml, mc],
59
+ "motion": [b, ic, tw, mh, mw],
60
+ "pose_cond_fea" : [b, cd0, tw * (ts - 1), lh, lw],
61
+ "pose" : [b, ic, tw, h, w],
62
+ "new_noise" : [b, lc, tw, lh, lw],
63
+ "d00" : [b, lh * lw, cd0],
64
+ "d01" : [b, lh * lw, cd0],
65
+ "d10" : [b, lh * lw // 4, cd1],
66
+ "d11" : [b, lh * lw // 4, cd1],
67
+ "d20" : [b, lh * lw // 16, cd2],
68
+ "d21" : [b, lh * lw // 16, cd2],
69
+ "m" : [b, lh * lw // 64, cm],
70
+ "u10" : [b, lh * lw // 16, cu1],
71
+ "u11" : [b, lh * lw // 16, cu1],
72
+ "u12" : [b, lh * lw // 16, cu1],
73
+ "u20" : [b, lh * lw // 4, cu2],
74
+ "u21" : [b, lh * lw // 4, cu2],
75
+ "u22" : [b, lh * lw // 4, cu2],
76
+ "u30" : [b, lh * lw, cu3],
77
+ "u31" : [b, lh * lw, cu3],
78
+ "u32" : [b, lh * lw, cu3],
79
+ }
80
+ return {k: torch.randn(profile[k], dtype=dtype, device=device) for k in profile}
81
+
82
+ def get_input_names(self):
83
+ return ["sample", "encoder_hidden_states", "motion_hidden_states",
84
+ "motion", "pose_cond_fea", "pose", "new_noise",
85
+ "d00", "d01", "d10", "d11", "d20", "d21", "m", "u10", "u11", "u12",
86
+ "u20", "u21", "u22", "u30", "u31", "u32"]
87
+
88
+ def get_output_names(self):
89
+ return ["pred_video", "latents", "pose_cond_fea_out",
90
+ "motion_hidden_states_out", "motion_out", "latent_first"]
91
+
92
+ def get_dynamic_axes(self):
93
+ dynamic_axes = {
94
+ "sample": {3:"h_64", 4:"w_64"},
95
+ "pose_cond_fea": {3:"h_64", 4:"w_64"},
96
+ "pose": {3:"h_512", 4:"h_512"},
97
+ "new_noise": {3: "h_64", 4: "w_64"},
98
+ "d00" : {1: "len_4096"},
99
+ "d01" : {1: "len_4096"},
100
+ "u30" : {1: "len_4096"},
101
+ "u31" : {1: "len_4096"},
102
+ "u32" : {1: "len_4096"},
103
+ "d10" : {1: "len_1024"},
104
+ "d11" : {1: "len_1024"},
105
+ "u20" : {1: "len_1024"},
106
+ "u21" : {1: "len_1024"},
107
+ "u22" : {1: "len_1024"},
108
+ "d20" : {1: "len_256"},
109
+ "d21" : {1: "len_256"},
110
+ "u10" : {1: "len_256"},
111
+ "u11" : {1: "len_256"},
112
+ "u12" : {1: "len_256"},
113
+ "m" : {1: "len_64"},
114
+ }
115
+ return dynamic_axes
116
+
117
+ def get_dynamic_map(self, batchsize, height, width):
118
+ tw, ts, tb = 4, 4, 16 # temporal window size| temporal adaptive steps | temporal batch size
119
+ ml, mc, mh, mw = 32, 16, 224, 224 # motion latent size | motion channels | motion input height | width
120
+ b, h, w = batchsize, height, width
121
+ lh, lw = height // 8, width // 8 # latent height | width
122
+ cd0, cd1, cd2, cm, cu1, cu2, cu3 = 320, 640, 1280, 1280, 1280, 640, 320 # unet channels
123
+ emb = 768 # CLIP Embedding Dims | TAESDV Channels
124
+ lc, ic = 4, 3 # latent | image channels
125
+
126
+ fixed_inputs_map = {
127
+ "sample": (b, lc, tb, lh, lw),
128
+ "encoder_hidden_states": (b, 1, emb),
129
+ "motion_hidden_states": (b, tw * (ts - 1), ml, mc),
130
+ "motion": (b, ic, tw, mh, mw),
131
+ "pose_cond_fea": (b, cd0, tw * (ts - 1), lh, lw),
132
+ "pose": (b, ic, tw, h, w),
133
+ "new_noise": (b, lc, tw, lh, lw),
134
+ }
135
+
136
+ dynamic_inputs_map = {
137
+ "d00": (b, lh * lw, cd0),
138
+ "d01": (b, lh * lw, cd0),
139
+ "d10": (b, lh * lw // 4, cd1),
140
+ "d11": (b, lh * lw // 4, cd1),
141
+ "d20": (b, lh * lw // 16, cd2),
142
+ "d21": (b, lh * lw // 16, cd2),
143
+ "m": (b, lh * lw // 64, cm),
144
+ "u10": (b, lh * lw // 16, cu1),
145
+ "u11": (b, lh * lw // 16, cu1),
146
+ "u12": (b, lh * lw // 16, cu1),
147
+ "u20": (b, lh * lw // 4, cu2),
148
+ "u21": (b, lh * lw // 4, cu2),
149
+ "u22": (b, lh * lw // 4, cu2),
150
+ "u30": (b, lh * lw, cu3),
151
+ "u31": (b, lh * lw, cu3),
152
+ "u32": (b, lh * lw, cu3),
153
+ }
154
+
155
+ profile = Profile()
156
+
157
+ for name, shape in fixed_inputs_map.items():
158
+ shape_tuple = tuple(shape)
159
+ profile.add(name, min=shape_tuple, opt=shape_tuple, max=shape_tuple)
160
+
161
+ for name, base_shape in dynamic_inputs_map.items():
162
+
163
+ dim0, dim1_base, dim2 = base_shape
164
+
165
+ val_1x = dim1_base * 1
166
+ val_2x = dim1_base * 2
167
+ val_4x = dim1_base * 4
168
+
169
+ min_shape = (dim0, val_1x, dim2)
170
+ opt_shape = (dim0, val_2x, dim2)
171
+ max_shape = (dim0, val_4x, dim2)
172
+
173
+ profile.add(name, min=min_shape, opt=opt_shape, max=max_shape)
174
+
175
+ print(f"Dynamic: {name:<5} | Base(1x): {dim1_base:<5} | Range: {val_1x} ~ {val_4x} | Opt: {val_2x}")
176
+
177
+ return profile
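A minimal sketch of turning the polygraphy Profile returned by get_dynamic_map into a TensorRT engine (not from the repository; the sub-modules, paths, and resolution are placeholders):

from polygraphy.backend.trt import CreateConfig, engine_from_network, network_from_onnx_path, save_engine

wrapper = unet_work(pose_guider, motion_encoder, unet, vae, scheduler, timestep)  # assumes these modules are already built
profile = wrapper.get_dynamic_map(batchsize=1, height=512, width=512)

engine = engine_from_network(
    network_from_onnx_path("path/to/unet_opt.onnx"),
    config=CreateConfig(fp16=True, profiles=[profile]),
)
save_engine(engine, "path/to/unet_work.engine")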
src/modeling/onnx_export.py ADDED
@@ -0,0 +1,102 @@
1
+ # adapted from https://github.com/NVIDIA/TensorRT/blob/main/demo/Diffusion/utilities.py
2
+ #
3
+ # Copyright 2022 The HuggingFace Inc. team.
4
+ # SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
5
+ # SPDX-License-Identifier: Apache-2.0
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+ #
19
+
20
+ import onnx
21
+ import gc
22
+ import onnx_graphsurgeon as gs
23
+ import torch
24
+ from onnx import shape_inference
25
+ from polygraphy.backend.onnx.loader import fold_constants
26
+ import os
27
+ from onnxsim import simplify
28
+
29
+ @torch.no_grad()
30
+ def export_onnx(
31
+ model,
32
+ onnx_path: str,
33
+ opt_image_height: int,
34
+ opt_image_width: int,
35
+ opt_batch_size: int,
36
+ onnx_opset: int,
37
+ dtype,
38
+ device,
39
+ auto_cast: bool = True,
40
+ ):
41
+ from contextlib import contextmanager
42
+
43
+ @contextmanager
44
+ def auto_cast_manager(enabled):
45
+ if enabled:
46
+ with torch.inference_mode(), torch.autocast("cuda"):
47
+ yield
48
+ else:
49
+ yield
50
+
51
+ # make sure the parent directory exists
52
+ os.makedirs(os.path.dirname(onnx_path), exist_ok=True)
53
+
54
+ with auto_cast_manager(auto_cast):
55
+ inputs = model.get_sample_input(opt_batch_size, opt_image_height, opt_image_width, dtype, device)
56
+
57
+ print(model.get_output_names())
58
+ print(f"开始导出 ONNX 模型到: {onnx_path} ...")
59
+ torch.onnx.utils.export(
60
+ model,
61
+ inputs,
62
+ onnx_path,
63
+ export_params=True,
64
+ opset_version=onnx_opset,
65
+ do_constant_folding=True,
66
+ input_names=model.get_input_names(),
67
+ output_names=model.get_output_names(),
68
+ dynamic_axes=model.get_dynamic_axes(),
69
+ )
70
+
71
+ del model
72
+ gc.collect()
73
+ torch.cuda.empty_cache()
74
+
75
+ def optimize_onnx(onnx_path, onnx_opt_path):
76
+ model = onnx.load(onnx_path)
77
+ name = os.path.splitext(os.path.basename(onnx_opt_path))[0]
78
+ model_opt = model
79
+
80
+ print(f"Saving to {onnx_opt_path}...")
81
+ onnx.save(
82
+ model_opt,
83
+ onnx_opt_path,
84
+ save_as_external_data=True,
85
+ all_tensors_to_one_file=True,
86
+ location=f"{name}.onnx.data",
87
+ size_threshold=1024
88
+ )
89
+ print("Optimization done.")
90
+
91
+ def handle_onnx_batch_norm(onnx_path: str):
92
+ onnx_model = onnx.load(onnx_path)
93
+ for node in onnx_model.graph.node:
94
+ if node.op_type == "BatchNormalization":
95
+ for attribute in node.attribute:
96
+ if attribute.name == "training_mode":
97
+ if attribute.i == 1:
98
+ node.output.remove(node.output[1])
99
+ node.output.remove(node.output[1])
100
+ attribute.i = 0
101
+
102
+ onnx.save_model(onnx_model, onnx_path)
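A minimal sketch chaining the helpers above (not from the repository; paths, opset, and resolution are placeholders, and `wrapper` stands for a module such as unet_work that exposes get_sample_input / get_input_names / get_output_names / get_dynamic_axes):

import torch

onnx_path = "path/to/unet_work.onnx"
export_onnx(wrapper, onnx_path,
            opt_image_height=512, opt_image_width=512, opt_batch_size=1,
            onnx_opset=17, dtype=torch.float16, device="cuda")
handle_onnx_batch_norm(onnx_path)                  # strip training-mode outputs from BatchNorm nodes
optimize_onnx(onnx_path, "path/to/unet_opt.onnx")  # re-save with weights as external data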
src/models/__pycache__/attention.cpython-310.pyc ADDED
Binary file (9.71 kB). View file
 
src/models/__pycache__/attention.cpython-39.pyc ADDED
Binary file (9.57 kB). View file
 
src/models/__pycache__/motion_module.cpython-310.pyc ADDED
Binary file (10.6 kB). View file
 
src/models/__pycache__/motion_module.cpython-39.pyc ADDED
Binary file (10.4 kB). View file