Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change list.
- .DS_Store +0 -0
- .gitattributes +6 -0
- README.md +0 -3
- configs/.DS_Store +0 -0
- configs/inference/inference_stage3.yaml +47 -0
- configs/prompts/personalive_offline.yaml +13 -0
- configs/prompts/personalive_online.yaml +28 -0
- demo/driving_video.mp4 +3 -0
- demo/ref_image.png +3 -0
- pose2vid_offline.py +254 -0
- pose2vid_online.py +323 -0
- pretrained_weights/.DS_Store +0 -0
- pretrained_weights/onnx/.DS_Store +0 -0
- pretrained_weights/onnx/unet_opt/unet_opt.onnx +3 -0
- pretrained_weights/onnx/unet_opt/unet_opt.onnx.data +3 -0
- pretrained_weights/personalive/denoising_unet.pth +3 -0
- pretrained_weights/personalive/motion_encoder.pth +3 -0
- pretrained_weights/personalive/motion_extractor.pth +3 -0
- pretrained_weights/personalive/pose_guider.pth +3 -0
- pretrained_weights/personalive/reference_unet.pth +3 -0
- pretrained_weights/personalive/temporal_module.pth +3 -0
- pretrained_weights/tensorrt/.DS_Store +0 -0
- pretrained_weights/tensorrt/unet_work(H100).engine +3 -0
- results/20251209--personalive_offline/concat_vid/ref_image_driving_video.mp4 +3 -0
- results/20251209--personalive_offline/split_vid/ref_image_driving_video.mp4 +3 -0
- src/.DS_Store +0 -0
- src/__pycache__/wrapper.cpython-310.pyc +0 -0
- src/__pycache__/wrapper_trt.cpython-310.pyc +0 -0
- src/liveportrait/__pycache__/camera.cpython-310.pyc +0 -0
- src/liveportrait/__pycache__/camera.cpython-39.pyc +0 -0
- src/liveportrait/__pycache__/convnextv2.cpython-310.pyc +0 -0
- src/liveportrait/__pycache__/convnextv2.cpython-39.pyc +0 -0
- src/liveportrait/__pycache__/motion_extractor.cpython-310.pyc +0 -0
- src/liveportrait/__pycache__/motion_extractor.cpython-39.pyc +0 -0
- src/liveportrait/__pycache__/util.cpython-310.pyc +0 -0
- src/liveportrait/__pycache__/util.cpython-39.pyc +0 -0
- src/liveportrait/camera.py +73 -0
- src/liveportrait/convnextv2.py +216 -0
- src/liveportrait/motion_extractor.py +212 -0
- src/liveportrait/util.py +492 -0
- src/modeling/__pycache__/engine_model.cpython-310.pyc +0 -0
- src/modeling/__pycache__/framed_models.cpython-310.pyc +0 -0
- src/modeling/__pycache__/onnx_export.cpython-310.pyc +0 -0
- src/modeling/engine_model.py +308 -0
- src/modeling/framed_models.py +177 -0
- src/modeling/onnx_export.py +102 -0
- src/models/__pycache__/attention.cpython-310.pyc +0 -0
- src/models/__pycache__/attention.cpython-39.pyc +0 -0
- src/models/__pycache__/motion_module.cpython-310.pyc +0 -0
- src/models/__pycache__/motion_module.cpython-39.pyc +0 -0
.DS_Store
ADDED
Binary file (8.2 kB)
.gitattributes
CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+demo/driving_video.mp4 filter=lfs diff=lfs merge=lfs -text
+demo/ref_image.png filter=lfs diff=lfs merge=lfs -text
+pretrained_weights/onnx/unet_opt/unet_opt.onnx.data filter=lfs diff=lfs merge=lfs -text
+pretrained_weights/tensorrt/unet_work(H100).engine filter=lfs diff=lfs merge=lfs -text
+results/20251209--personalive_offline/concat_vid/ref_image_driving_video.mp4 filter=lfs diff=lfs merge=lfs -text
+results/20251209--personalive_offline/split_vid/ref_image_driving_video.mp4 filter=lfs diff=lfs merge=lfs -text
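These new filter rules route the added large binaries (the demo video, reference image, optimized ONNX weights, the TensorRT engine, and the result videos) through Git LFS, so the repository itself stores lightweight pointer files rather than the multi-gigabyte payloads.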
README.md
CHANGED
@@ -1,3 +0,0 @@
----
-license: apache-2.0
----
configs/.DS_Store
ADDED
Binary file (6.15 kB)
configs/inference/inference_stage3.yaml
ADDED
@@ -0,0 +1,47 @@
unet_additional_kwargs:
  use_inflated_groupnorm: true
  unet_use_cross_frame_attention: false
  unet_use_temporal_attention: false
  use_motion_module: true
  motion_module_resolutions:
  - 1
  - 2
  - 4
  - 8
  motion_module_mid_block: true
  motion_module_decoder_only: false
  motion_module_type: Vanilla
  motion_module_kwargs:
    num_attention_heads: 8
    num_transformer_block: 1
    cross_attention_dim: 16
    attention_block_types:
    - Spatial_Cross
    - Spatial_Cross
    temporal_position_encoding: false
    temporal_position_encoding_max_len: 32
    temporal_attention_dim_div: 1

  use_temporal_module: true
  temporal_module_type: Vanilla
  temporal_module_kwargs:
    num_attention_heads: 8
    num_transformer_block: 1
    attention_block_types:
    - Temporal_Self
    - Temporal_Self
    temporal_position_encoding: true
    temporal_position_encoding_max_len: 32
    temporal_attention_dim_div: 1


noise_scheduler_kwargs:
  beta_start: 0.00085
  beta_end: 0.02
  beta_schedule: "scaled_linear"
  clip_sample: false
  steps_offset: 1
  prediction_type: "epsilon"
  timestep_spacing: "trailing"

sampler: DDIM
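For orientation, this is the config that the offline script below loads via OmegaConf: unet_additional_kwargs is forwarded to UNet3DConditionModel.from_pretrained_2d when the 3D denoising UNet is built, and noise_scheduler_kwargs is unpacked into the project's DDIM scheduler. A minimal sketch of that consumption path, mirroring the imports used in pose2vid_offline.py:

from omegaconf import OmegaConf
from src.scheduler.scheduler_ddim import DDIMScheduler

infer_config = OmegaConf.load("configs/inference/inference_stage3.yaml")

# scheduler settings are plain keyword arguments for the repo's DDIMScheduler
sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
scheduler = DDIMScheduler(**sched_kwargs)

# motion/temporal module settings ride along as unet_additional_kwargs
print(infer_config.unet_additional_kwargs.motion_module_type)  # "Vanilla"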
configs/prompts/personalive_offline.yaml
ADDED
@@ -0,0 +1,13 @@
pretrained_base_model_path: '/public_hw/home/cit_xdcun/zyli/x-nemo-inference/pretrained_weights/sd-image-variations-diffusers'
image_encoder_path: '/public_hw/home/cit_xdcun/zyli/x-nemo-inference/pretrained_weights/sd-image-variations-diffusers/image_encoder'
vae_path: '/public_hw/home/cit_xdcun/zyli/x-nemo-inference/pretrained_weights/sd-vae-ft-mse'
vae_tiny_path: '/public_hw/home/cit_xdcun/zyli/x-nemo-inference/pretrained_weights/taesd'

denoising_unet_path: "./pretrained_weights/personalive/denoising_unet.pth"

inference_config: "configs/inference/inference_stage3.yaml"
weight_dtype: 'fp16'

test_cases:
  'demo/ref_image.png':
  - 'demo/driving_video.mp4'
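Note that the first four entries are absolute paths into a specific cluster home directory; they need to be repointed at local copies of sd-image-variations-diffusers, sd-vae-ft-mse and taesd before the offline script can run. test_cases maps each reference image to a list of driving videos, and pose2vid_offline.py simply iterates that mapping, as in this small sketch:

from omegaconf import OmegaConf

test_cases = OmegaConf.load("configs/prompts/personalive_offline.yaml")["test_cases"]
for ref_image_path in test_cases:                         # 'demo/ref_image.png'
    for pose_video_path in test_cases[ref_image_path]:    # 'demo/driving_video.mp4'
        print(ref_image_path, "->", pose_video_path)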
configs/prompts/personalive_online.yaml
ADDED
@@ -0,0 +1,28 @@
batch_size: 1
height: 512
width: 512
reference_image_height: 512
reference_image_width: 512
temporal_adaptive_step: 4
temporal_window_size: 4
num_inference_steps: 4
dtype: "fp16"
fps: 16

vae_model_path: '/public_hw/home/cit_xdcun/zyli/x-nemo-inference/pretrained_weights/sd-vae-ft-mse'
image_encoder_path: '/public_hw/home/cit_xdcun/zyli/x-nemo-inference/pretrained_weights/sd-image-variations-diffusers/image_encoder'
pretrained_base_model_path: '/public_hw/home/cit_xdcun/zyli/x-nemo-inference/pretrained_weights/sd-image-variations-diffusers'

reference_unet_weight_path: "./pretrained_weights/personalive/reference_unet.pth"
denoising_unet_path: "./pretrained_weights/personalive/denoising_unet.pth"
pose_guider_path: "./pretrained_weights/personalive/pose_guider.pth"
motion_encoder_path: './pretrained_weights/personalive/motion_encoder.pth'
temporal_module_path: "./pretrained_weights/personalive/temporal_module.pth"
pose_encoder_path: './pretrained_weights/personalive/motion_extractor.pth'

onnx_path: './pretrained_weights/onnx/unet/unet.onnx'
onnx_opt_path: './pretrained_weights/onnx/unet_opt/unet_opt.onnx'
tensorrt_target_model: './pretrained_weights/tensorrt/unet_work.engine'

inference_config: "./configs/inference/inference_stage3.yaml"
seed: 42
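The ONNX and TensorRT entries describe a conversion chain (exported UNet → optimized ONNX → serialized engine). tensorrt_target_model points at ./pretrained_weights/tensorrt/unet_work.engine, while the engine shipped in this upload is named unet_work(H100).engine and was presumably built on an H100; since TensorRT engines are tied to the GPU they were built on, other hardware will likely need the engine rebuilt from the optimized ONNX model (and the path adjusted accordingly).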
demo/driving_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a67895a319bf48323ba63d15050299908e2cd6d99f79f766033423eb53662e07
size 2923884

demo/ref_image.png
ADDED
Git LFS Details
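The entries above (and the weight files further down) are Git LFS pointer files: the repository stores only the version/oid/size stanza, and the actual payload lives in LFS storage. When cloning without LFS, the binaries can be materialized with git lfs pull, or fetched directly through huggingface_hub; a minimal sketch, with the repo id as a placeholder to be replaced by this repository's actual id on the Hub:

from huggingface_hub import snapshot_download

# "<user>/<repo>" is a placeholder, not the real repo id
local_dir = snapshot_download(
    repo_id="<user>/<repo>",
    allow_patterns=["demo/*", "pretrained_weights/personalive/*"],
)
print(local_dir)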
pose2vid_offline.py
ADDED
@@ -0,0 +1,254 @@
import argparse
import os
import sys
from datetime import datetime
import mediapipe as mp
import numpy as np
import cv2
import torch
from skimage.transform import resize
from diffusers import AutoencoderKLTemporalDecoder, AutoencoderKL, AutoencoderTiny
from src.scheduler.scheduler_ddim import DDIMScheduler
import random
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid import Pose2VideoPipeline
from src.utils.util import save_videos_grid, crop_face
from decord import VideoReader
from diffusers.utils.import_utils import is_xformers_available

from src.models.motion_encoder.encoder import MotEncoder
from src.liveportrait.motion_extractor import MotionExtractor
from src.models.pose_guider import PoseGuider
from tqdm import tqdm


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default='configs/prompts/personalive_offline.yaml')
    parser.add_argument("--name", type=str, default='personalive_offline')
    parser.add_argument("-W", type=int, default=512)
    parser.add_argument("-H", type=int, default=512)
    parser.add_argument("-L", type=int, default=1500)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=str, default="cuda")
    args = parser.parse_args()

    return args


def main(args):
    device = args.device
    print('device', device)
    config = OmegaConf.load(args.config)

    if config.weight_dtype == "fp16":
        weight_dtype = torch.float16
    else:
        weight_dtype = torch.float32

    vae = AutoencoderKL.from_pretrained(config.vae_path).to(device, dtype=weight_dtype)
    # if use tiny VAE
    # vae_tiny = AutoencoderTiny.from_pretrained(config.vae_tiny_path).to(device, dtype=weight_dtype)

    infer_config = OmegaConf.load(config.inference_config)
    reference_unet = UNet2DConditionModel.from_pretrained(
        config.pretrained_base_model_path,
        subfolder="unet",
    ).to(device=device, dtype=weight_dtype)
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        config.pretrained_base_model_path,
        "",
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    ).to(dtype=weight_dtype, device=device)

    motion_encoder = MotEncoder().to(dtype=weight_dtype, device=device).eval()
    pose_guider = PoseGuider().to(device=device, dtype=weight_dtype)
    pose_encoder = MotionExtractor(num_kp=21).to(device=device, dtype=weight_dtype).eval()

    image_enc = CLIPVisionModelWithProjection.from_pretrained(
        config.image_encoder_path
    ).to(dtype=weight_dtype, device=device)

    sched_kwargs = OmegaConf.to_container(
        OmegaConf.load(config.inference_config).noise_scheduler_kwargs
    )
    scheduler = DDIMScheduler(**sched_kwargs)

    generator = torch.manual_seed(args.seed)
    width, height = args.W, args.H

    # load pretrained weights
    denoising_unet.load_state_dict(
        torch.load(config.denoising_unet_path, map_location="cpu"), strict=False
    )
    reference_unet.load_state_dict(
        torch.load(
            config.denoising_unet_path.replace('denoising_unet', 'reference_unet'),
            map_location="cpu",
        ),
        strict=True,
    )
    motion_encoder.load_state_dict(
        torch.load(
            config.denoising_unet_path.replace('denoising_unet', 'motion_encoder'),
            map_location="cpu",
        ),
        strict=True,
    )
    pose_guider.load_state_dict(
        torch.load(
            config.denoising_unet_path.replace('denoising_unet', 'pose_guider'),
            map_location="cpu",
        ),
        strict=True,
    )
    denoising_unet.load_state_dict(
        torch.load(
            config.denoising_unet_path.replace('denoising_unet', 'temporal_module'),
            map_location="cpu",
        ),
        strict=False,
    )
    pose_encoder.load_state_dict(
        torch.load(
            config.denoising_unet_path.replace('denoising_unet', 'motion_extractor'),
            map_location="cpu",
        ),
        strict=False,
    )

    if is_xformers_available():
        reference_unet.enable_xformers_memory_efficient_attention()
        denoising_unet.enable_xformers_memory_efficient_attention()
    else:
        raise ValueError(
            "xformers is not available. Make sure it is installed correctly"
        )

    mp_face_mesh = mp.solutions.face_mesh
    face_mesh = mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1)

    pipe = Pose2VideoPipeline(
        vae=vae,
        # vae_tiny=vae_tiny,
        image_encoder=image_enc,
        reference_unet=reference_unet,
        denoising_unet=denoising_unet,
        motion_encoder=motion_encoder,
        pose_encoder=pose_encoder,
        pose_guider=pose_guider,
        scheduler=scheduler,
    )
    pipe = pipe.to(device)

    date_str = datetime.now().strftime("%Y%m%d")
    if args.name is None:
        time_str = datetime.now().strftime("%H%M")
        save_dir_name = f"{date_str}--{time_str}"
    else:
        save_dir_name = f"{date_str}--{args.name}"
    save_vid_dir = os.path.join('results', save_dir_name, 'concat_vid')
    os.makedirs(save_vid_dir, exist_ok=True)
    save_split_vid_dir = os.path.join('results', save_dir_name, 'split_vid')
    os.makedirs(save_split_vid_dir, exist_ok=True)

    pose_transform = transforms.Compose(
        [transforms.Resize((height, width)), transforms.ToTensor()]
    )

    args.test_cases = OmegaConf.load(args.config)["test_cases"]

    for ref_image_path in list(args.test_cases.keys()):
        for pose_video_path in args.test_cases[ref_image_path]:
            video_name = os.path.basename(pose_video_path).split(".")[0]
            source_name = os.path.basename(ref_image_path).split(".")[0]

            vid_name = f"{source_name}_{video_name}.mp4"
            save_vid_path = os.path.join(save_vid_dir, vid_name)
            print(save_vid_path)
            if os.path.exists(save_vid_path):
                continue

            if ref_image_path.endswith('.mp4'):
                src_vid = VideoReader(ref_image_path)
                ref_img = src_vid[0].asnumpy()
                ref_img = Image.fromarray(ref_img).convert("RGB")
            else:
                ref_img = Image.open(ref_image_path).convert("RGB")

            control = VideoReader(pose_video_path)
            video_length = min(len(control) // 4 * 4, args.L)
            sel_idx = range(len(control))[:video_length]
            control = control.get_batch([sel_idx]).asnumpy()  # N, H, W, C

            ref_image_pil = ref_img.copy()
            ref_patch = crop_face(ref_image_pil, face_mesh)
            ref_face_pil = Image.fromarray(ref_patch).convert("RGB")

            size = args.H
            generator = torch.Generator(device=device)
            generator.manual_seed(42)

            dri_faces = []
            ori_pose_images = []
            for idx_control, pose_image_pil in tqdm(enumerate(control[:video_length]), total=video_length, desc='cropping faces'):
                pose_image_pil = Image.fromarray(pose_image_pil).convert("RGB")
                ori_pose_images.append(pose_image_pil)
                dri_face = crop_face(pose_image_pil, face_mesh)
                dri_face_pil = Image.fromarray(dri_face).convert("RGB")
                dri_faces.append(dri_face_pil)

            face_tensor_list = []
            ori_pose_tensor_list = []
            ref_tensor_list = []

            for idx, pose_image_pil in enumerate(ori_pose_images):
                face_tensor_list.append(pose_transform(dri_faces[idx]))
                ori_pose_tensor_list.append(pose_transform(pose_image_pil))
                ref_tensor_list.append(pose_transform(ref_image_pil))

            ref_tensor = torch.stack(ref_tensor_list, dim=0)  # (f, c, h, w)
            ref_tensor = ref_tensor.transpose(0, 1).unsqueeze(0)  # (c, f, h, w)

            face_tensor = torch.stack(face_tensor_list, dim=0)  # (f, c, h, w)
            face_tensor = face_tensor.transpose(0, 1).unsqueeze(0)

            ori_pose_tensor = torch.stack(ori_pose_tensor_list, dim=0)  # (f, c, h, w)
            ori_pose_tensor = ori_pose_tensor.transpose(0, 1).unsqueeze(0)

            gen_video = pipe(
                ori_pose_images,
                ref_image_pil,
                dri_faces,
                ref_face_pil,
                width,
                height,
                len(dri_faces),
                num_inference_steps=4,
                guidance_scale=1.0,
                generator=generator,
                temporal_window_size=4,
                temporal_adaptive_step=4,
            ).videos

            # Concat it with pose tensor
            video = torch.cat([ref_tensor, face_tensor, ori_pose_tensor, gen_video], dim=0)

            save_videos_grid(
                video,
                save_vid_path,
                n_rows=4,
                fps=25,
            )

            if True:
                save_vid_path = save_vid_path.replace(save_vid_dir, save_split_vid_dir)
                save_videos_grid(gen_video, save_vid_path, n_rows=1, fps=25, crf=18)


if __name__ == "__main__":
    args = parse_args()
    main(args)
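Given the argparse defaults above, the offline demo can be launched as `python pose2vid_offline.py --config configs/prompts/personalive_offline.yaml` (512x512, up to 1500 frames, seed 42). Outputs land under results/<date>--personalive_offline/ in the concat_vid (reference | driving face | driving pose | generated) and split_vid (generated only) subfolders, matching the two result videos included in this upload.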
pose2vid_online.py
ADDED
@@ -0,0 +1,323 @@
import os
import signal
import sys
import json

from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi import Request

import markdown2
import threading
import logging
import uuid
import time
from types import SimpleNamespace
import asyncio
import mimetypes
import torch

from webcam.config import config, Args
from webcam.util import pil_to_frame, bytes_to_pil, is_firefox, bytes_to_tensor
from webcam.connection_manager import ConnectionManager, ServerFullException
import multiprocessing as mp

use_trt = True

if use_trt:
    from webcam.vid2vid_trt import Pipeline
else:
    from webcam.vid2vid import Pipeline

mimetypes.add_type("application/javascript", ".js")

THROTTLE = 0.001


class App:
    def __init__(self, config: Args, pipeline: Pipeline):
        self.args = config
        self.pipeline = pipeline
        self.app = FastAPI()
        self.conn_manager = ConnectionManager()

        self.produce_predictions_stop_event = None
        self.produce_predictions_task = None
        self.shutdown_event = asyncio.Event()

        self.init_app()

    def init_app(self):
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        @self.app.websocket("/api/ws/{user_id}")
        async def websocket_endpoint(user_id: uuid.UUID, websocket: WebSocket):
            try:
                await self.conn_manager.connect(
                    user_id, websocket, self.args.max_queue_size
                )

                sender_task = asyncio.create_task(push_results_to_client(user_id, websocket))

                if self.produce_predictions_task is None or self.produce_predictions_task.done():
                    start_prediction_thread(user_id)

                await handle_websocket_input(user_id, websocket)

            except ServerFullException as e:
                logging.error(f"Server Full: {e}")
            except WebSocketDisconnect:
                logging.info(f"User disconnected: {user_id}")
            except Exception as e:
                logging.error(f"WS Error: {e}")
            finally:
                if 'sender_task' in locals():
                    sender_task.cancel()

                await self.conn_manager.disconnect(user_id, self.pipeline)

                if self.produce_predictions_stop_event is not None:
                    self.produce_predictions_stop_event.set()
                logging.info(f"Cleaned up user: {user_id}")

        async def handle_websocket_input(user_id: uuid.UUID, websocket: WebSocket):
            if not self.conn_manager.check_user(user_id):
                raise HTTPException(status_code=404, detail="User not found")

            try:
                while True:
                    message = await websocket.receive()

                    if "text" in message:
                        try:
                            text_data = message["text"]
                            data = json.loads(text_data)
                            status = data.get("status")

                            if status == "pause":
                                params = SimpleNamespace(**{"restart": True})
                                await self.conn_manager.update_data(user_id, params)
                            elif status == "resume":
                                await self.conn_manager.send_json(user_id, {"status": "send_frame"})
                        except Exception as e:
                            logging.error(f"JSON Parse Error: {e}")

                    elif "bytes" in message:
                        image_data = message["bytes"]
                        if len(image_data) > 0:
                            input_tensor = bytes_to_tensor(image_data)
                            params = SimpleNamespace()
                            params.image = input_tensor
                            self.pipeline.accept_new_params(params)

            except WebSocketDisconnect:
                raise
            except Exception as e:
                logging.error(f"Input Loop Error: {e}")
                raise

        async def push_results_to_client(user_id: uuid.UUID, websocket: WebSocket):
            MIN_FPS = 10
            MAX_FPS = 30
            SMOOTHING = 0.8  # EMA smoothing factor

            last_burst_time = time.time()
            last_queue_size = 0
            sleep_time = 1 / 40  # Initial guess

            last_frame_time = None
            frame_time_list = []

            ema_frame_interval = sleep_time

            try:
                while True:
                    queue_size = await self.conn_manager.get_output_queue_size(user_id)
                    if queue_size > last_queue_size:
                        current_burst_time = time.time()
                        elapsed = current_burst_time - last_burst_time

                        if queue_size > 0 and elapsed > 0:
                            raw_interval = elapsed / queue_size
                            ema_frame_interval = SMOOTHING * ema_frame_interval + (1 - SMOOTHING) * raw_interval
                            sleep_time = min(max(ema_frame_interval, 1 / MAX_FPS), 1 / MIN_FPS)

                        last_burst_time = current_burst_time

                    last_queue_size = queue_size

                    frame = await self.conn_manager.get_frame(user_id)
                    if frame is None:
                        await asyncio.sleep(0.001)
                        continue

                    await websocket.send_bytes(frame)

                    if last_frame_time is None:
                        last_frame_time = time.time()
                    else:
                        frame_time_list.append(time.time() - last_frame_time)
                        if len(frame_time_list) > 100:
                            frame_time_list.pop(0)
                        last_frame_time = time.time()

                    await asyncio.sleep(sleep_time)

            except asyncio.CancelledError:
                pass
            except Exception as e:
                logging.error(f"Push Result Error: {e}")

        def start_prediction_thread(user_id):
            self.produce_predictions_stop_event = threading.Event()

            def prediction_loop(uid, loop, stop_event):
                while not stop_event.is_set():
                    images = self.pipeline.produce_outputs()
                    if len(images) == 0:
                        time.sleep(THROTTLE)
                        continue

                    frames = list(map(pil_to_frame, images))
                    asyncio.run_coroutine_threadsafe(
                        self.conn_manager.put_frames_to_output_queue(uid, frames),
                        loop
                    )

            self.produce_predictions_task = asyncio.create_task(asyncio.to_thread(
                prediction_loop, user_id, asyncio.get_running_loop(), self.produce_predictions_stop_event
            ))

        @self.app.get("/api/queue")
        async def get_queue_size():
            queue_size = self.conn_manager.get_user_count()
            return JSONResponse({"queue_size": queue_size})

        @self.app.get("/api/settings")
        async def settings():
            info_schema = pipeline.Info.schema()
            info = pipeline.Info()
            if info.page_content:
                page_content = markdown2.markdown(info.page_content)

            input_params = pipeline.InputParams.schema()
            return JSONResponse(
                {
                    "info": info_schema,
                    "input_params": input_params,
                    "max_queue_size": self.args.max_queue_size,
                    "page_content": page_content if info.page_content else "",
                }
            )

        @self.app.post("/api/upload_reference_image")
        async def upload_reference_image(ref_image: UploadFile = File(...)):
            try:
                data = await ref_image.read()
                img = bytes_to_pil(data)
                self.pipeline.fuse_reference(img)
                return {"status": "ok"}
            except Exception as e:
                logging.error(f"Reference image error: {e}")
                raise HTTPException(status_code=500, detail="Failed to process reference image")

        if not os.path.exists("./demo_w_camera/frontend/public"):
            os.makedirs("./demo_w_camera/frontend/public")

        self.app.mount(
            "/", StaticFiles(directory="./demo_w_camera/frontend/public", html=True), name="public"
        )

        @self.app.on_event("shutdown")
        async def shutdown_event():
            await self.cleanup()

    async def cleanup(self):
        print("[App] Starting cleanup process...")
        self.shutdown_event.set()

        if self.produce_predictions_stop_event is not None:
            self.produce_predictions_stop_event.set()

        if self.produce_predictions_task is not None:
            self.produce_predictions_task.cancel()
            try:
                await self.produce_predictions_task
            except asyncio.CancelledError:
                pass

        try:
            await self.conn_manager.disconnect_all(self.pipeline)
        except Exception as e:
            print(f"[App] Error during disconnect_all: {e}")

        print("[App] Cleanup completed")


app_instance = None


def signal_handler(signum, frame):
    print(f"\n[Main] Received signal {signum}, shutting down gracefully...")
    if app_instance:
        import threading

        def trigger_cleanup():
            try:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                loop.run_until_complete(app_instance.cleanup())
                loop.close()
            except Exception as e:
                print(f"[Main] Error during cleanup: {e}")

        cleanup_thread = threading.Thread(target=trigger_cleanup)
        cleanup_thread.daemon = True
        cleanup_thread.start()
        cleanup_thread.join(timeout=5)

    sys.exit(0)


if __name__ == "__main__":
    import uvicorn
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
    mp.set_start_method("spawn", force=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    pipeline = Pipeline(config, device)

    app_obj = App(config, pipeline)
    app = app_obj.app
    app_instance = app_obj

    print('init done')

    try:
        uvicorn.run(
            app,
            host=config.host,
            port=config.port,
            reload=config.reload,
            ssl_certfile=config.ssl_certfile,
            ssl_keyfile=config.ssl_keyfile,
        )
    except KeyboardInterrupt:
        try:
            import asyncio
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            loop.run_until_complete(app_obj.cleanup())
            loop.close()
        except Exception as e:
            print(f"[Main] Error during cleanup: {e}")
        sys.exit(0)
    except Exception as e:
        print(f"[Main] Error: {e}")
        sys.exit(1)
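Running `python pose2vid_online.py` starts the FastAPI app itself via uvicorn, with host/port/SSL taken from webcam.config and the TensorRT pipeline selected when use_trt is True; the browser frontend under demo_w_camera/frontend/public is the intended client. The wire protocol visible above (POST a reference image, then stream frame bytes over the websocket and read generated frame bytes back) could also be exercised from a script. A rough sketch only: the host/port, the example image, and the third-party websockets/requests packages are assumptions, not part of this repo, and the server may interleave JSON status messages with frame bytes.

import asyncio
import uuid
import requests
import websockets

BASE = "http://localhost:7860"  # assumed host/port; the real values come from webcam.config
user_id = uuid.uuid4()

# 1) register the reference identity (endpoint and field name as defined above)
requests.post(f"{BASE}/api/upload_reference_image",
              files={"ref_image": open("demo/ref_image.png", "rb")})

# 2) send one driving frame as binary and read one generated frame back
async def drive_one_frame(jpeg_bytes: bytes) -> bytes:
    async with websockets.connect(f"ws://localhost:7860/api/ws/{user_id}") as ws:
        await ws.send(jpeg_bytes)   # binary messages are treated as input frames
        return await ws.recv()      # generated frames are pushed back as bytes

# asyncio.run(drive_one_frame(open("some_frame.jpg", "rb").read()))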
pretrained_weights/.DS_Store
ADDED
Binary file (8.2 kB)

pretrained_weights/onnx/.DS_Store
ADDED
Binary file (6.15 kB)

pretrained_weights/onnx/unet_opt/unet_opt.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:484aee7e8c45cddaac227b6ad331a88a77121dee0886f2152cc4bd0e9974b6fa
size 96224343

pretrained_weights/onnx/unet_opt/unet_opt.onnx.data
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aa08ee8770f202be841e00f2bb94809c2ca6ca95ad8663c2917c4c6fa35d963e
size 3593537864

pretrained_weights/personalive/denoising_unet.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d0446c4d2387f259d5f3c1ac54a5aefa93400f4672f942856bff2538df046162
size 4927015578

pretrained_weights/personalive/motion_encoder.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ff7c6b0a84cd750046e7687f7a6f6bbc21317055bfcacef950ed347debae4d2c
size 246719031

pretrained_weights/personalive/motion_extractor.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:251e6a94ad667a1d0c69526d292677165110ef7f0cf0f6d199f0e414e8aa0ca5
size 112545506

pretrained_weights/personalive/pose_guider.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8b997db63343a6a5d489778172d9544bcccaf27e6756505dc6353d84e877269d
size 4351790

pretrained_weights/personalive/reference_unet.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:85eb03e6c34fab69f9246ff14b3016789232e56dc4892d0581fea21a3a8480f6
size 3438324340

pretrained_weights/personalive/temporal_module.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:295e8942a453adb48756432d99de103ecba9b840b5b8f6635a0687311cdff30e
size 1817903018

pretrained_weights/tensorrt/.DS_Store
ADDED
Binary file (6.15 kB)

pretrained_weights/tensorrt/unet_work(H100).engine
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:34bd6f7693300be8cf72a099f1160bfaedab7a677bcaf66f18ee33a5b871de50
size 3697605036

results/20251209--personalive_offline/concat_vid/ref_image_driving_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9bf93d55acd386d689cda4588e636545219acf9910f1d6292eb6db0bed82c64b
size 7700854

results/20251209--personalive_offline/split_vid/ref_image_driving_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a064eb1a2effcb3514450e157ec6903973bca4c1d50a888e9c94c0f40a397213
size 7605688
src/.DS_Store
ADDED
Binary file (8.2 kB)

src/__pycache__/wrapper.cpython-310.pyc
ADDED
Binary file (11 kB)

src/__pycache__/wrapper_trt.cpython-310.pyc
ADDED
Binary file (10.4 kB)

src/liveportrait/__pycache__/camera.cpython-310.pyc
ADDED
Binary file (1.78 kB)

src/liveportrait/__pycache__/camera.cpython-39.pyc
ADDED
Binary file (1.77 kB)

src/liveportrait/__pycache__/convnextv2.cpython-310.pyc
ADDED
Binary file (6.19 kB)

src/liveportrait/__pycache__/convnextv2.cpython-39.pyc
ADDED
Binary file (6.45 kB)

src/liveportrait/__pycache__/motion_extractor.cpython-310.pyc
ADDED
Binary file (6.61 kB)

src/liveportrait/__pycache__/motion_extractor.cpython-39.pyc
ADDED
Binary file (6.61 kB)

src/liveportrait/__pycache__/util.cpython-310.pyc
ADDED
Binary file (15.7 kB)

src/liveportrait/__pycache__/util.cpython-39.pyc
ADDED
Binary file (16.1 kB)
src/liveportrait/camera.py
ADDED
@@ -0,0 +1,73 @@
# coding: utf-8

"""
functions for processing and transforming 3D facial keypoints
"""

import numpy as np
import torch
import torch.nn.functional as F

PI = np.pi


def headpose_pred_to_degree(pred):
    """
    pred: (bs, 66) or (bs, 1) or others
    """
    if pred.ndim > 1 and pred.shape[1] == 66:
        # NOTE: note that the average is modified to 97.5
        device = pred.device
        idx_tensor = [idx for idx in range(0, 66)]
        idx_tensor = torch.FloatTensor(idx_tensor).to(device)
        pred = F.softmax(pred, dim=1)
        degree = torch.sum(pred * idx_tensor, axis=1) * 3 - 97.5

        return degree

    return pred


def get_rotation_matrix(pitch_, yaw_, roll_):
    """ the input is in degree
    """
    # transform to radian
    pitch = pitch_ / 180 * PI
    yaw = yaw_ / 180 * PI
    roll = roll_ / 180 * PI

    device = pitch.device

    if pitch.ndim == 1:
        pitch = pitch.unsqueeze(1)
    if yaw.ndim == 1:
        yaw = yaw.unsqueeze(1)
    if roll.ndim == 1:
        roll = roll.unsqueeze(1)

    # calculate the euler matrix
    bs = pitch.shape[0]
    ones = torch.ones([bs, 1]).to(device)
    zeros = torch.zeros([bs, 1]).to(device)
    x, y, z = pitch, yaw, roll

    rot_x = torch.cat([
        ones, zeros, zeros,
        zeros, torch.cos(x), -torch.sin(x),
        zeros, torch.sin(x), torch.cos(x)
    ], dim=1).reshape([bs, 3, 3])

    rot_y = torch.cat([
        torch.cos(y), zeros, torch.sin(y),
        zeros, ones, zeros,
        -torch.sin(y), zeros, torch.cos(y)
    ], dim=1).reshape([bs, 3, 3])

    rot_z = torch.cat([
        torch.cos(z), -torch.sin(z), zeros,
        torch.sin(z), torch.cos(z), zeros,
        zeros, zeros, ones
    ], dim=1).reshape([bs, 3, 3])

    rot = rot_z @ rot_y @ rot_x
    return rot.permute(0, 2, 1)  # transpose
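As a quick sanity check of the convention used here (angles in degrees in, transposed rotation matrices out), a small usage sketch with made-up example values:

import torch
from src.liveportrait.camera import get_rotation_matrix

# batch of 2 head poses, angles in degrees
pitch = torch.tensor([0.0, 10.0])
yaw = torch.tensor([0.0, -20.0])
roll = torch.tensor([0.0, 5.0])

R = get_rotation_matrix(pitch, yaw, roll)    # (2, 3, 3), returned transposed
print(R.shape)
print(torch.allclose(R[0], torch.eye(3)))    # zero angles -> identity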
src/liveportrait/convnextv2.py
ADDED
@@ -0,0 +1,216 @@
# coding: utf-8

"""
This module is adapted to the ConvNeXtV2 version for the extraction of implicit keypoints, poses, and expression deformation.
"""

import torch
import torch.nn as nn
# from timm.models.layers import trunc_normal_, DropPath
from src.liveportrait.util import LayerNorm, DropPath, trunc_normal_, GRN
from einops import rearrange

__all__ = ['convnextv2_tiny']


class Block(nn.Module):
    """ ConvNeXtV2 Block.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
    """

    def __init__(self, dim, drop_path=0.):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        x = input + self.drop_path(x)
        return x


class ConvNeXtV2(nn.Module):
    """ ConvNeXt V2

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """

    def __init__(
        self,
        in_chans=3,
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        drop_path_rate=0.,
        **kwargs
    ):
        super().__init__()
        self.depths = depths
        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer

        # NOTE: the output semantic items
        num_bins = kwargs.get('num_bins', 66)
        num_kp = kwargs.get('num_kp', 24)  # the number of implicit keypoints
        self.fc_kp = nn.Linear(dims[-1], 3 * num_kp)  # implicit keypoints

        # print('dims[-1]: ', dims[-1])
        self.fc_scale = nn.Linear(dims[-1], 1)  # scale
        self.fc_pitch = nn.Linear(dims[-1], num_bins)  # pitch bins
        self.fc_yaw = nn.Linear(dims[-1], num_bins)  # yaw bins
        self.fc_roll = nn.Linear(dims[-1], num_bins)  # roll bins
        self.fc_t = nn.Linear(dims[-1], 3)  # translation
        self.fc_exp = nn.Linear(dims[-1], 3 * num_kp)  # expression / delta

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([-2, -1]))  # global average pooling, (N, C, H, W) -> (N, C)

    def forward(self, x):
        x = self.forward_features(x)

        # implicit keypoints
        kp = self.fc_kp(x)

        # pose and expression deformation
        pitch = self.fc_pitch(x)
        yaw = self.fc_yaw(x)
        roll = self.fc_roll(x)
        t = self.fc_t(x)
        # exp = self.fc_exp(x)
        scale = self.fc_scale(x)

        ret_dct = {
            'pitch': pitch,
            'yaw': yaw,
            'roll': roll,
            't': t,
            # 'exp': exp,
            'scale': scale,

            'kp': kp,  # canonical keypoint
        }

        return ret_dct


def convnextv2_tiny(**kwargs):
    model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
    return model


class ConvNeXt(nn.Module):
    """ ConvNeXt V2

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """

    def __init__(
        self,
        in_chans=3,
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        drop_path_rate=0.,
        **kwargs
    ):
        super().__init__()
        self.depths = depths
        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([-2, -1]))  # global average pooling, (N, C, H, W) -> (N, C)

    def forward(self, x):
        x = self.forward_features(x)
        return x


def convnextv2(**kwargs):
    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
    return model
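convnextv2_tiny is the backbone that MotionExtractor (next file) wraps, and pose2vid_offline.py builds it with num_kp=21. A minimal sketch of the output dict it produces for a single face crop; the 256x256 input size is an assumption here, any size the conv stem can downsample works:

import torch
from src.liveportrait.convnextv2 import convnextv2_tiny

net = convnextv2_tiny(num_kp=21)      # 21 implicit keypoints, as in pose2vid_offline.py
x = torch.randn(1, 3, 256, 256)       # example crop (assumed size)
out = net(x)

for k, v in out.items():
    print(k, tuple(v.shape))
# pitch/yaw/roll: (1, 66) angle bins, t: (1, 3), scale: (1, 1), kp: (1, 63)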
src/liveportrait/motion_extractor.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding: utf-8
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
Motion extractor(M), which directly predicts the canonical keypoints, head pose and expression deformation of the input image
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from torch import nn
|
| 8 |
+
import torch
|
| 9 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 10 |
+
from src.liveportrait.convnextv2 import convnextv2_tiny
|
| 11 |
+
from src.liveportrait.util import filter_state_dict
|
| 12 |
+
from src.liveportrait.camera import headpose_pred_to_degree, get_rotation_matrix
|
| 13 |
+
|
| 14 |
+
model_dict = {
|
| 15 |
+
'convnextv2_tiny': convnextv2_tiny,
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class MotionExtractor(ModelMixin):
|
| 20 |
+
def __init__(self, **kwargs):
|
| 21 |
+
super(MotionExtractor, self).__init__()
|
| 22 |
+
|
| 23 |
+
# default is convnextv2_base
|
| 24 |
+
backbone = kwargs.get('backbone', 'convnextv2_tiny')
|
| 25 |
+
self.detector = model_dict.get(backbone)(**kwargs)
|
| 26 |
+
self.register_buffer('idx_tensor', torch.arange(66, dtype=torch.float32))
|
| 27 |
+
|
| 28 |
+
def headpose_pred_to_degree(self, pred):
|
| 29 |
+
"""
|
| 30 |
+
pred: (bs, 66) or (bs, 1) or others
|
| 31 |
+
"""
|
| 32 |
+
if pred.ndim > 1 and pred.shape[1] == 66:
|
| 33 |
+
# NOTE: note that the average is modified to 97.5
|
| 34 |
+
prob = torch.nn.functional.softmax(pred, dim=1)
|
| 35 |
+
degree = torch.matmul(prob, self.idx_tensor)
|
| 36 |
+
degree = degree * 3 - 97.5
|
| 37 |
+
|
| 38 |
+
return degree
|
| 39 |
+
|
| 40 |
+
return pred
|
| 41 |
+
|
| 42 |
+
def load_pretrained(self, init_path: str):
|
| 43 |
+
if init_path not in (None, ''):
|
| 44 |
+
state_dict = torch.load(init_path, map_location=lambda storage, loc: storage)['model']
|
| 45 |
+
state_dict = filter_state_dict(state_dict, remove_name='head')
|
| 46 |
+
ret = self.detector.load_state_dict(state_dict, strict=False)
|
| 47 |
+
print(f'Load pretrained model from {init_path}, ret: {ret}')
|
| 48 |
+
|
| 49 |
+
def forward(self, x):
|
| 50 |
+
kp_info = self.detector(x)
|
| 51 |
+
return self.get_kp(kp_info)
|
| 52 |
+
|
| 53 |
+
def get_kp(self, kp_info):
|
| 54 |
+
bs = kp_info['kp'].shape[0]
|
| 55 |
+
|
| 56 |
+
angles_raw = torch.cat([kp_info['pitch'], kp_info['yaw'], kp_info['roll']], dim=0) # (3, 66)
|
| 57 |
+
angles_deg = self.headpose_pred_to_degree(angles_raw)[:, None] # (B, 3)
|
| 58 |
+
pitch, yaw, roll = torch.chunk(angles_deg, chunks=3, dim=0)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
kp = kp_info['kp'].reshape(bs, -1, 3) # BxNx3
|
| 62 |
+
t, scale = kp_info['t'], kp_info['scale']
|
| 63 |
+
|
| 64 |
+
rot_mat = get_rotation_matrix(pitch, yaw, roll).to(self.dtype) # (bs, 3, 3)
|
| 65 |
+
|
| 66 |
+
if kp.ndim == 2:
|
| 67 |
+
num_kp = kp.shape[1] // 3 # Bx(num_kpx3)
|
| 68 |
+
else:
|
| 69 |
+
num_kp = kp.shape[1] # Bxnum_kpx3
|
| 70 |
+
|
| 71 |
+
# Eqn.2: s * (R * x_c,s) + t
|
| 72 |
+
kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat# + exp.view(bs, num_kp, 3)
|
| 73 |
+
kp_transformed *= scale[..., None] # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
|
| 74 |
+
kp_transformed[:, :, 0:2] += t[:, None, 0:2] # remove z, only apply tx ty
|
| 75 |
+
|
| 76 |
+
return kp_transformed
|
| 77 |
+
|
| 78 |
+
def interpolate_tensors(self, a: torch.Tensor, b: torch.Tensor, num: int = 10) -> torch.Tensor:
|
| 79 |
+
if a.shape != b.shape:
|
| 80 |
+
raise ValueError(f"Shape mismatch: a.shape={a.shape}, b.shape={b.shape}")
|
| 81 |
+
|
| 82 |
+
B, *rest = a.shape
|
| 83 |
+
alphas = torch.linspace(0, 1, num, device=a.device, dtype=a.dtype)
|
| 84 |
+
view_shape = (num,) + (1,) * len(rest)
|
| 85 |
+
alphas = alphas.view(view_shape) # (num, 1, 1, ...)
|
| 86 |
+
|
| 87 |
+
result = (1 - alphas) * a + alphas * b
|
| 88 |
+
return result[:-1]
|
| 89 |
+
|
| 90 |
+
def interpolate_kps(self, ref, motion, num_interp, t_scale=0.5, s_scale=0):
|
| 91 |
+
kp1 = self.detector(ref.to(self.dtype))
|
| 92 |
+
kp2_list = []
|
| 93 |
+
for i in range(0, motion.shape[0], 256):
|
| 94 |
+
motion_chunk = motion[i:i+256]
|
| 95 |
+
kp2_chunk = self.detector(motion_chunk.to(self.dtype))
|
| 96 |
+
kp2_list.append(kp2_chunk)
|
| 97 |
+
kp2 = {}
|
| 98 |
+
for key in kp2_list[0].keys():
|
| 99 |
+
kp2[key] = torch.cat([kp2_chunk[key] for kp2_chunk in kp2_list], dim=0)
|
| 100 |
+
|
| 101 |
+
angles_raw = torch.cat([kp1['pitch'], kp1['yaw'], kp1['roll']], dim=0) # (3, 66)
|
| 102 |
+
angles_deg = self.headpose_pred_to_degree(angles_raw) # (B, 3)
|
| 103 |
+
pitch_1, yaw_1, roll_1 = torch.chunk(angles_deg, chunks=3, dim=0)
|
| 104 |
+
|
| 105 |
+
angles_raw = torch.cat([kp2['pitch'], kp2['yaw'], kp2['roll']], dim=0) # (3, 66)
|
| 106 |
+
angles_deg = self.headpose_pred_to_degree(angles_raw) # (B, 3)
|
| 107 |
+
pitch_2, yaw_2, roll_2 = torch.chunk(angles_deg, chunks=3, dim=0)
|
| 108 |
+
|
| 109 |
+
pitch_interp = self.interpolate_tensors(pitch_1, pitch_2[:1], num_interp) # Bx(num_interp)x1
|
| 110 |
+
yaw_interp = self.interpolate_tensors(yaw_1, yaw_2[:1], num_interp) # Bx(num_interp)x1
|
| 111 |
+
roll_interp = self.interpolate_tensors(roll_1, roll_2[:1], num_interp) # Bx(num_interp)x1
|
| 112 |
+
|
| 113 |
+
t_1 = kp1['t']
|
| 114 |
+
t_2 = kp2['t']
|
| 115 |
+
t_2 = (t_2 - t_2[0]) * t_scale + t_1
|
| 116 |
+
t_interp = self.interpolate_tensors(t_1, t_2[:1], num_interp)
|
| 117 |
+
|
| 118 |
+
s_1 = kp1['scale']
|
| 119 |
+
s_2 = kp2['scale']
|
| 120 |
+
s_2 = s_2 * s_scale + s_1
|
| 121 |
+
s_interp = self.interpolate_tensors(s_1, s_2[:1], num_interp)
|
| 122 |
+
|
| 123 |
+
kp = kp1['kp'].repeat(num_interp+motion.shape[0]-1, 1)
|
| 124 |
+
|
| 125 |
+
kps_interp = {
|
| 126 |
+
'pitch': torch.cat([pitch_interp, pitch_2], dim=0),
|
| 127 |
+
'yaw': torch.cat([yaw_interp, yaw_2], dim=0),
|
| 128 |
+
'roll': torch.cat([roll_interp, roll_2], dim=0),
|
| 129 |
+
't': torch.cat([t_interp, t_2], dim=0),
|
| 130 |
+
'scale': torch.cat([s_interp, s_2], dim=0),
|
| 131 |
+
'kp': kp
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
kp_intrep = self.get_kp(kps_interp)
|
| 135 |
+
|
| 136 |
+
return kp_intrep
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def interpolate_kps_online(self, ref, motion, num_interp, t_scale=0.5, s_scale=0):
|
| 140 |
+
kp1 = self.detector(ref.to(self.dtype))
|
| 141 |
+
kp_frame1 = self.detector(motion[:1].to(self.dtype))
|
| 142 |
+
kp2 = self.detector(motion.to(self.dtype))
|
| 143 |
+
|
| 144 |
+
angles_raw = torch.cat([kp1['pitch'], kp1['yaw'], kp1['roll']], dim=0) # (3, 66)
|
| 145 |
+
angles_deg = self.headpose_pred_to_degree(angles_raw) # (B, 3)
|
| 146 |
+
pitch_1, yaw_1, roll_1 = torch.chunk(angles_deg, chunks=3, dim=0)
|
| 147 |
+
|
| 148 |
+
angles_raw = torch.cat([kp2['pitch'], kp2['yaw'], kp2['roll']], dim=0) # (3, 66)
|
| 149 |
+
angles_deg = self.headpose_pred_to_degree(angles_raw) # (B, 3)
|
| 150 |
+
pitch_2, yaw_2, roll_2 = torch.chunk(angles_deg, chunks=3, dim=0)
|
| 151 |
+
|
| 152 |
+
pitch_interp = self.interpolate_tensors(pitch_1, pitch_2[:1], num_interp) # Bx(num_interp)x1
|
| 153 |
+
yaw_interp = self.interpolate_tensors(yaw_1, yaw_2[:1], num_interp) # Bx(num_interp)x1
|
| 154 |
+
roll_interp = self.interpolate_tensors(roll_1, roll_2[:1], num_interp) # Bx(num_interp)x1
|
| 155 |
+
|
| 156 |
+
t_1 = kp1['t']
|
| 157 |
+
t_2 = kp2['t']
|
| 158 |
+
t_2 = (t_2 - t_2[0]) * t_scale + t_1
|
| 159 |
+
t_interp = self.interpolate_tensors(t_1, t_2[:1], num_interp)
|
| 160 |
+
|
| 161 |
+
s_1 = kp1['scale']
|
| 162 |
+
s_2 = kp2['scale']
|
| 163 |
+
s_2 = s_2 * s_scale + s_1
|
| 164 |
+
s_interp = self.interpolate_tensors(s_1, s_2[:1], num_interp)
|
| 165 |
+
|
| 166 |
+
kp = kp1['kp'].repeat(num_interp+motion.shape[0]-1, 1)
|
| 167 |
+
|
| 168 |
+
kps_interp = {
|
| 169 |
+
'pitch': torch.cat([pitch_interp, pitch_2], dim=0),
|
| 170 |
+
'yaw': torch.cat([yaw_interp, yaw_2], dim=0),
|
| 171 |
+
'roll': torch.cat([roll_interp, roll_2], dim=0),
|
| 172 |
+
't': torch.cat([t_interp, t_2], dim=0),
|
| 173 |
+
'scale': torch.cat([s_interp, s_2], dim=0),
|
| 174 |
+
'kp': kp
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
kp_intrep = self.get_kp(kps_interp)
|
| 178 |
+
|
| 179 |
+
kp_dri = self.get_kp(kp2)
|
| 180 |
+
|
| 181 |
+
return kp_intrep, kp1, kp_frame1, kp_dri
|
| 182 |
+
|
| 183 |
+
def get_kps(self, kp_ref, kp_frame1, motion, t_scale=0.5, s_scale=0):
|
| 184 |
+
kps_motion = self.detector(motion.to(self.dtype))
|
| 185 |
+
|
| 186 |
+
kps_dri = self.get_kp(kps_motion)
|
| 187 |
+
|
| 188 |
+
t_ref = kp_ref['t']
|
| 189 |
+
t_frame1 = kp_frame1['t']
|
| 190 |
+
t_motion = kps_motion['t']
|
| 191 |
+
kps_motion['t'] = (t_motion - t_frame1) * t_scale + t_ref
|
| 192 |
+
|
| 193 |
+
s_ref = kp_ref['scale']
|
| 194 |
+
s_motion = kps_motion['scale']
|
| 195 |
+
kps_motion['scale'] = s_motion * s_scale + s_ref
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
kps_motion['kp'] = kp_ref['kp'].repeat(motion.shape[0], 1)
|
| 199 |
+
|
| 200 |
+
kps_motion = self.get_kp(kps_motion)
|
| 201 |
+
|
| 202 |
+
return kps_motion, kps_dri
|
| 203 |
+
|
| 204 |
+
def inference(self, ref, motion):
|
| 205 |
+
kps_ref = self.detector(ref.to(self.dtype))
|
| 206 |
+
kps_motion = self.detector(motion.to(self.dtype))
|
| 207 |
+
kps_motion['kp'] = kps_ref['kp']
|
| 208 |
+
|
| 209 |
+
kp_s = self.get_kp(kps_ref)
|
| 210 |
+
kp_d = self.get_kp(kps_motion)
|
| 211 |
+
|
| 212 |
+
return kp_s, kp_d
|
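For orientation, a minimal usage sketch of the MotionExtractor defined above; the checkpoint path, constructor arguments, and input resolution are illustrative assumptions, not part of this upload.

```python
# Hypothetical usage of MotionExtractor; paths and shapes are assumptions for illustration.
import torch
from src.liveportrait.motion_extractor import MotionExtractor

extractor = MotionExtractor()  # constructor kwargs (backbone, num_kp, ...) follow the inference config
extractor.load_pretrained('path/to/motion_extractor.pth')  # assumed checkpoint location
extractor = extractor.eval().cuda()

ref = torch.rand(1, 3, 256, 256, device='cuda')     # reference frame, values in [0, 1]
motion = torch.rand(8, 3, 256, 256, device='cuda')  # driving frames

with torch.no_grad():
    kp_source, kp_driving = extractor.inference(ref, motion)
print(kp_source.shape, kp_driving.shape)  # (1, num_kp, 3) and (8, num_kp, 3)
```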
src/liveportrait/util.py
ADDED
|
@@ -0,0 +1,492 @@
|
| 1 |
+
# coding: utf-8
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
This file defines various neural network modules and utility functions, including convolutional and residual blocks,
|
| 5 |
+
normalizations, and functions for spatial transformation and tensor manipulation.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from torch import nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.utils.spectral_norm as spectral_norm
|
| 12 |
+
import math
|
| 13 |
+
import warnings
|
| 14 |
+
import collections.abc
|
| 15 |
+
from itertools import repeat
|
| 16 |
+
|
| 17 |
+
def kp2gaussian(kp, spatial_size, kp_variance):
|
| 18 |
+
"""
|
| 19 |
+
Transform a keypoint into a Gaussian-like representation
|
| 20 |
+
"""
|
| 21 |
+
mean = kp
|
| 22 |
+
|
| 23 |
+
coordinate_grid = make_coordinate_grid(spatial_size, mean)
|
| 24 |
+
number_of_leading_dimensions = len(mean.shape) - 1
|
| 25 |
+
shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
|
| 26 |
+
coordinate_grid = coordinate_grid.view(*shape)
|
| 27 |
+
repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1)
|
| 28 |
+
coordinate_grid = coordinate_grid.repeat(*repeats)
|
| 29 |
+
|
| 30 |
+
# Preprocess kp shape
|
| 31 |
+
shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3)
|
| 32 |
+
mean = mean.view(*shape)
|
| 33 |
+
|
| 34 |
+
mean_sub = (coordinate_grid - mean)
|
| 35 |
+
|
| 36 |
+
out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
|
| 37 |
+
|
| 38 |
+
return out
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def make_coordinate_grid(spatial_size, ref, **kwargs):
|
| 42 |
+
d, h, w = spatial_size
|
| 43 |
+
x = torch.arange(w).type(ref.dtype).to(ref.device)
|
| 44 |
+
y = torch.arange(h).type(ref.dtype).to(ref.device)
|
| 45 |
+
z = torch.arange(d).type(ref.dtype).to(ref.device)
|
| 46 |
+
|
| 47 |
+
# NOTE: must be right-down-in
|
| 48 |
+
x = (2 * (x / (w - 1)) - 1) # the x axis faces to the right
|
| 49 |
+
y = (2 * (y / (h - 1)) - 1) # the y axis faces to the bottom
|
| 50 |
+
z = (2 * (z / (d - 1)) - 1) # the z axis faces to the inner
|
| 51 |
+
|
| 52 |
+
yy = y.view(1, -1, 1).repeat(d, 1, w)
|
| 53 |
+
xx = x.view(1, 1, -1).repeat(d, h, 1)
|
| 54 |
+
zz = z.view(-1, 1, 1).repeat(1, h, w)
|
| 55 |
+
|
| 56 |
+
meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3)
|
| 57 |
+
|
| 58 |
+
return meshed
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class ConvT2d(nn.Module):
|
| 62 |
+
"""
|
| 63 |
+
Upsampling block for use in decoder.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
def __init__(self, in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1):
|
| 67 |
+
super(ConvT2d, self).__init__()
|
| 68 |
+
|
| 69 |
+
self.convT = nn.ConvTranspose2d(in_features, out_features, kernel_size=kernel_size, stride=stride,
|
| 70 |
+
padding=padding, output_padding=output_padding)
|
| 71 |
+
self.norm = nn.InstanceNorm2d(out_features)
|
| 72 |
+
|
| 73 |
+
def forward(self, x):
|
| 74 |
+
out = self.convT(x)
|
| 75 |
+
out = self.norm(out)
|
| 76 |
+
out = F.leaky_relu(out)
|
| 77 |
+
return out
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class ResBlock3d(nn.Module):
|
| 81 |
+
"""
|
| 82 |
+
Res block, preserve spatial resolution.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
def __init__(self, in_features, kernel_size, padding):
|
| 86 |
+
super(ResBlock3d, self).__init__()
|
| 87 |
+
self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding)
|
| 88 |
+
self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding)
|
| 89 |
+
self.norm1 = nn.BatchNorm3d(in_features, affine=True)
|
| 90 |
+
self.norm2 = nn.BatchNorm3d(in_features, affine=True)
|
| 91 |
+
|
| 92 |
+
def forward(self, x):
|
| 93 |
+
out = self.norm1(x)
|
| 94 |
+
out = F.relu(out)
|
| 95 |
+
out = self.conv1(out)
|
| 96 |
+
out = self.norm2(out)
|
| 97 |
+
out = F.relu(out)
|
| 98 |
+
out = self.conv2(out)
|
| 99 |
+
out += x
|
| 100 |
+
return out
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class UpBlock3d(nn.Module):
|
| 104 |
+
"""
|
| 105 |
+
Upsampling block for use in decoder.
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
|
| 109 |
+
super(UpBlock3d, self).__init__()
|
| 110 |
+
|
| 111 |
+
self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
|
| 112 |
+
padding=padding, groups=groups)
|
| 113 |
+
self.norm = nn.BatchNorm3d(out_features, affine=True)
|
| 114 |
+
|
| 115 |
+
def forward(self, x):
|
| 116 |
+
out = F.interpolate(x, scale_factor=(1, 2, 2))
|
| 117 |
+
out = self.conv(out)
|
| 118 |
+
out = self.norm(out)
|
| 119 |
+
out = F.relu(out)
|
| 120 |
+
return out
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class DownBlock2d(nn.Module):
|
| 124 |
+
"""
|
| 125 |
+
Downsampling block for use in encoder.
|
| 126 |
+
"""
|
| 127 |
+
|
| 128 |
+
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
|
| 129 |
+
super(DownBlock2d, self).__init__()
|
| 130 |
+
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups)
|
| 131 |
+
self.norm = nn.BatchNorm2d(out_features, affine=True)
|
| 132 |
+
self.pool = nn.AvgPool2d(kernel_size=(2, 2))
|
| 133 |
+
|
| 134 |
+
def forward(self, x):
|
| 135 |
+
out = self.conv(x)
|
| 136 |
+
out = self.norm(out)
|
| 137 |
+
out = F.relu(out)
|
| 138 |
+
out = self.pool(out)
|
| 139 |
+
return out
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class DownBlock3d(nn.Module):
|
| 143 |
+
"""
|
| 144 |
+
Downsampling block for use in encoder.
|
| 145 |
+
"""
|
| 146 |
+
|
| 147 |
+
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
|
| 148 |
+
super(DownBlock3d, self).__init__()
|
| 149 |
+
'''
|
| 150 |
+
self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
|
| 151 |
+
padding=padding, groups=groups, stride=(1, 2, 2))
|
| 152 |
+
'''
|
| 153 |
+
self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
|
| 154 |
+
padding=padding, groups=groups)
|
| 155 |
+
self.norm = nn.BatchNorm3d(out_features, affine=True)
|
| 156 |
+
self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2))
|
| 157 |
+
|
| 158 |
+
def forward(self, x):
|
| 159 |
+
out = self.conv(x)
|
| 160 |
+
out = self.norm(out)
|
| 161 |
+
out = F.relu(out)
|
| 162 |
+
out = self.pool(out)
|
| 163 |
+
return out
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class SameBlock2d(nn.Module):
|
| 167 |
+
"""
|
| 168 |
+
Simple block, preserve spatial resolution.
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False):
|
| 172 |
+
super(SameBlock2d, self).__init__()
|
| 173 |
+
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups)
|
| 174 |
+
self.norm = nn.BatchNorm2d(out_features, affine=True)
|
| 175 |
+
if lrelu:
|
| 176 |
+
self.ac = nn.LeakyReLU()
|
| 177 |
+
else:
|
| 178 |
+
self.ac = nn.ReLU()
|
| 179 |
+
|
| 180 |
+
def forward(self, x):
|
| 181 |
+
out = self.conv(x)
|
| 182 |
+
out = self.norm(out)
|
| 183 |
+
out = self.ac(out)
|
| 184 |
+
return out
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class Encoder(nn.Module):
|
| 188 |
+
"""
|
| 189 |
+
Hourglass Encoder
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
| 193 |
+
super(Encoder, self).__init__()
|
| 194 |
+
|
| 195 |
+
down_blocks = []
|
| 196 |
+
for i in range(num_blocks):
|
| 197 |
+
down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), min(max_features, block_expansion * (2 ** (i + 1))), kernel_size=3, padding=1))
|
| 198 |
+
self.down_blocks = nn.ModuleList(down_blocks)
|
| 199 |
+
|
| 200 |
+
def forward(self, x):
|
| 201 |
+
outs = [x]
|
| 202 |
+
for down_block in self.down_blocks:
|
| 203 |
+
outs.append(down_block(outs[-1]))
|
| 204 |
+
return outs
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class Decoder(nn.Module):
|
| 208 |
+
"""
|
| 209 |
+
Hourglass Decoder
|
| 210 |
+
"""
|
| 211 |
+
|
| 212 |
+
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
| 213 |
+
super(Decoder, self).__init__()
|
| 214 |
+
|
| 215 |
+
up_blocks = []
|
| 216 |
+
|
| 217 |
+
for i in range(num_blocks)[::-1]:
|
| 218 |
+
in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
|
| 219 |
+
out_filters = min(max_features, block_expansion * (2 ** i))
|
| 220 |
+
up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))
|
| 221 |
+
|
| 222 |
+
self.up_blocks = nn.ModuleList(up_blocks)
|
| 223 |
+
self.out_filters = block_expansion + in_features
|
| 224 |
+
|
| 225 |
+
self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1)
|
| 226 |
+
self.norm = nn.BatchNorm3d(self.out_filters, affine=True)
|
| 227 |
+
|
| 228 |
+
def forward(self, x):
|
| 229 |
+
out = x.pop()
|
| 230 |
+
for up_block in self.up_blocks:
|
| 231 |
+
out = up_block(out)
|
| 232 |
+
skip = x.pop()
|
| 233 |
+
out = torch.cat([out, skip], dim=1)
|
| 234 |
+
out = self.conv(out)
|
| 235 |
+
out = self.norm(out)
|
| 236 |
+
out = F.relu(out)
|
| 237 |
+
return out
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class Hourglass(nn.Module):
|
| 241 |
+
"""
|
| 242 |
+
Hourglass architecture.
|
| 243 |
+
"""
|
| 244 |
+
|
| 245 |
+
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
| 246 |
+
super(Hourglass, self).__init__()
|
| 247 |
+
self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
|
| 248 |
+
self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
|
| 249 |
+
self.out_filters = self.decoder.out_filters
|
| 250 |
+
|
| 251 |
+
def forward(self, x):
|
| 252 |
+
return self.decoder(self.encoder(x))
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class SPADE(nn.Module):
|
| 256 |
+
def __init__(self, norm_nc, label_nc):
|
| 257 |
+
super().__init__()
|
| 258 |
+
|
| 259 |
+
self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
|
| 260 |
+
nhidden = 128
|
| 261 |
+
|
| 262 |
+
self.mlp_shared = nn.Sequential(
|
| 263 |
+
nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
|
| 264 |
+
nn.ReLU())
|
| 265 |
+
self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
|
| 266 |
+
self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
|
| 267 |
+
|
| 268 |
+
def forward(self, x, segmap):
|
| 269 |
+
normalized = self.param_free_norm(x)
|
| 270 |
+
segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
|
| 271 |
+
actv = self.mlp_shared(segmap)
|
| 272 |
+
gamma = self.mlp_gamma(actv)
|
| 273 |
+
beta = self.mlp_beta(actv)
|
| 274 |
+
out = normalized * (1 + gamma) + beta
|
| 275 |
+
return out
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class SPADEResnetBlock(nn.Module):
|
| 279 |
+
def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
|
| 280 |
+
super().__init__()
|
| 281 |
+
# Attributes
|
| 282 |
+
self.learned_shortcut = (fin != fout)
|
| 283 |
+
fmiddle = min(fin, fout)
|
| 284 |
+
self.use_se = use_se
|
| 285 |
+
# create conv layers
|
| 286 |
+
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
|
| 287 |
+
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
|
| 288 |
+
if self.learned_shortcut:
|
| 289 |
+
self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
|
| 290 |
+
# apply spectral norm if specified
|
| 291 |
+
if 'spectral' in norm_G:
|
| 292 |
+
self.conv_0 = spectral_norm(self.conv_0)
|
| 293 |
+
self.conv_1 = spectral_norm(self.conv_1)
|
| 294 |
+
if self.learned_shortcut:
|
| 295 |
+
self.conv_s = spectral_norm(self.conv_s)
|
| 296 |
+
# define normalization layers
|
| 297 |
+
self.norm_0 = SPADE(fin, label_nc)
|
| 298 |
+
self.norm_1 = SPADE(fmiddle, label_nc)
|
| 299 |
+
if self.learned_shortcut:
|
| 300 |
+
self.norm_s = SPADE(fin, label_nc)
|
| 301 |
+
|
| 302 |
+
def forward(self, x, seg1):
|
| 303 |
+
x_s = self.shortcut(x, seg1)
|
| 304 |
+
dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
|
| 305 |
+
dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
|
| 306 |
+
out = x_s + dx
|
| 307 |
+
return out
|
| 308 |
+
|
| 309 |
+
def shortcut(self, x, seg1):
|
| 310 |
+
if self.learned_shortcut:
|
| 311 |
+
x_s = self.conv_s(self.norm_s(x, seg1))
|
| 312 |
+
else:
|
| 313 |
+
x_s = x
|
| 314 |
+
return x_s
|
| 315 |
+
|
| 316 |
+
def actvn(self, x):
|
| 317 |
+
return F.leaky_relu(x, 2e-1)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def filter_state_dict(state_dict, remove_name='fc'):
|
| 321 |
+
new_state_dict = {}
|
| 322 |
+
for key in state_dict:
|
| 323 |
+
if remove_name in key:
|
| 324 |
+
continue
|
| 325 |
+
new_state_dict[key] = state_dict[key]
|
| 326 |
+
return new_state_dict
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class GRN(nn.Module):
|
| 330 |
+
""" GRN (Global Response Normalization) layer
|
| 331 |
+
"""
|
| 332 |
+
|
| 333 |
+
def __init__(self, dim):
|
| 334 |
+
super().__init__()
|
| 335 |
+
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
| 336 |
+
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
| 337 |
+
|
| 338 |
+
def forward(self, x):
|
| 339 |
+
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
|
| 340 |
+
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
| 341 |
+
return self.gamma * (x * Nx) + self.beta + x
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
class LayerNorm(nn.Module):
|
| 345 |
+
r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
| 346 |
+
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
|
| 347 |
+
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
|
| 348 |
+
with shape (batch_size, channels, height, width).
|
| 349 |
+
"""
|
| 350 |
+
|
| 351 |
+
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
|
| 352 |
+
super().__init__()
|
| 353 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape, dtype=torch.float32))
|
| 354 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape, dtype=torch.float32))
|
| 355 |
+
self.eps = float(eps)
|
| 356 |
+
self.data_format = data_format
|
| 357 |
+
if self.data_format not in ["channels_last", "channels_first"]:
|
| 358 |
+
raise NotImplementedError
|
| 359 |
+
self.normalized_shape = (normalized_shape, )
|
| 360 |
+
|
| 361 |
+
def _apply(self, fn):
|
| 362 |
+
"""
|
| 363 |
+
Override _apply to take full control of parameter conversion.
|
| 364 |
+
Intercepts all .cuda(), .cpu(), .half(), and .to() calls.
|
| 365 |
+
"""
|
| 366 |
+
|
| 367 |
+
for name, param in self._parameters.items():
|
| 368 |
+
if param is not None:
|
| 369 |
+
dummy_probe = param.data.view(-1)[:1]
|
| 370 |
+
|
| 371 |
+
try:
|
| 372 |
+
target_tensor = fn(dummy_probe)
|
| 373 |
+
|
| 374 |
+
target_device = target_tensor.device
|
| 375 |
+
target_dtype = target_tensor.dtype
|
| 376 |
+
except:
|
| 377 |
+
target_device = param.device
|
| 378 |
+
target_dtype = param.dtype
|
| 379 |
+
|
| 380 |
+
if name in ['weight', 'bias']:
|
| 381 |
+
# Core logic: for weight/bias, force FP32 even when the target dtype is half precision
|
| 382 |
+
if target_dtype in [torch.float16, torch.bfloat16]:
|
| 383 |
+
new_data = param.data.to(device=target_device, dtype=torch.float32)
|
| 384 |
+
else:
|
| 385 |
+
new_data = fn(param.data)
|
| 386 |
+
else:
|
| 387 |
+
new_data = fn(param.data)
|
| 388 |
+
|
| 389 |
+
param.data = new_data
|
| 390 |
+
|
| 391 |
+
if param.grad is not None:
|
| 392 |
+
param.grad.data = param.grad.data.to(device=new_data.device, dtype=new_data.dtype)
|
| 393 |
+
|
| 394 |
+
for name, buf in self._buffers.items():
|
| 395 |
+
if buf is not None:
|
| 396 |
+
self._buffers[name] = fn(buf)
|
| 397 |
+
|
| 398 |
+
return self
|
| 399 |
+
|
| 400 |
+
def forward(self, x):
|
| 401 |
+
dtype = x.dtype
|
| 402 |
+
x = x.float()
|
| 403 |
+
if self.data_format == "channels_last":
|
| 404 |
+
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
| 405 |
+
elif self.data_format == "channels_first":
|
| 406 |
+
x = x.permute(0, 2, 3, 1) # BCHW → BHWC
|
| 407 |
+
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
| 408 |
+
x = x.permute(0, 3, 1, 2) # BHWC → BCHW
|
| 409 |
+
return x.to(dtype)
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
| 413 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
| 414 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 415 |
+
def norm_cdf(x):
|
| 416 |
+
# Computes standard normal cumulative distribution function
|
| 417 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
| 418 |
+
|
| 419 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
| 420 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
| 421 |
+
"The distribution of values may be incorrect.",
|
| 422 |
+
stacklevel=2)
|
| 423 |
+
|
| 424 |
+
with torch.no_grad():
|
| 425 |
+
# Values are generated by using a truncated uniform distribution and
|
| 426 |
+
# then using the inverse CDF for the normal distribution.
|
| 427 |
+
# Get upper and lower cdf values
|
| 428 |
+
l = norm_cdf((a - mean) / std)
|
| 429 |
+
u = norm_cdf((b - mean) / std)
|
| 430 |
+
|
| 431 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
| 432 |
+
# [2l-1, 2u-1].
|
| 433 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
| 434 |
+
|
| 435 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
| 436 |
+
# standard normal
|
| 437 |
+
tensor.erfinv_()
|
| 438 |
+
|
| 439 |
+
# Transform to proper mean, std
|
| 440 |
+
tensor.mul_(std * math.sqrt(2.))
|
| 441 |
+
tensor.add_(mean)
|
| 442 |
+
|
| 443 |
+
# Clamp to ensure it's in the proper range
|
| 444 |
+
tensor.clamp_(min=a, max=b)
|
| 445 |
+
return tensor
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def drop_path(x, drop_prob=0., training=False, scale_by_keep=True):
|
| 449 |
+
""" Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 450 |
+
|
| 451 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
| 452 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
| 453 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
| 454 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
| 455 |
+
'survival rate' as the argument.
|
| 456 |
+
|
| 457 |
+
"""
|
| 458 |
+
if drop_prob == 0. or not training:
|
| 459 |
+
return x
|
| 460 |
+
keep_prob = 1 - drop_prob
|
| 461 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 462 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
| 463 |
+
if keep_prob > 0.0 and scale_by_keep:
|
| 464 |
+
random_tensor.div_(keep_prob)
|
| 465 |
+
return x * random_tensor
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
class DropPath(nn.Module):
|
| 469 |
+
""" Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 470 |
+
"""
|
| 471 |
+
|
| 472 |
+
def __init__(self, drop_prob=None, scale_by_keep=True):
|
| 473 |
+
super(DropPath, self).__init__()
|
| 474 |
+
self.drop_prob = drop_prob
|
| 475 |
+
self.scale_by_keep = scale_by_keep
|
| 476 |
+
|
| 477 |
+
def forward(self, x):
|
| 478 |
+
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
| 482 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
| 483 |
+
|
| 484 |
+
# From PyTorch internals
|
| 485 |
+
def _ntuple(n):
|
| 486 |
+
def parse(x):
|
| 487 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
| 488 |
+
return tuple(x)
|
| 489 |
+
return tuple(repeat(x, n))
|
| 490 |
+
return parse
|
| 491 |
+
|
| 492 |
+
to_2tuple = _ntuple(2)
|
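As a quick illustration of the keypoint-to-heatmap helpers in util.py, a small sketch (batch size, keypoint count, and volume size are arbitrary choices for illustration):

```python
# Illustrative shapes only: exercises make_coordinate_grid and kp2gaussian from util.py.
import torch
from src.liveportrait.util import make_coordinate_grid, kp2gaussian

kp = torch.zeros(2, 21, 3)        # (batch, num_kp, xyz), keypoints at the volume centre
spatial_size = (16, 64, 64)       # (depth, height, width)

grid = make_coordinate_grid(spatial_size, kp)               # (16, 64, 64, 3), coords in [-1, 1]
heatmaps = kp2gaussian(kp, spatial_size, kp_variance=0.01)  # one Gaussian blob per keypoint
print(grid.shape, heatmaps.shape)  # torch.Size([16, 64, 64, 3]) torch.Size([2, 21, 16, 64, 64])
```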
src/modeling/__pycache__/engine_model.cpython-310.pyc
ADDED
|
Binary file (9.13 kB). View file
|
|
|
src/modeling/__pycache__/framed_models.cpython-310.pyc
ADDED
|
Binary file (5.97 kB). View file
|
|
|
src/modeling/__pycache__/onnx_export.cpython-310.pyc
ADDED
|
Binary file (2.52 kB). View file
|
|
|
src/modeling/engine_model.py
ADDED
|
@@ -0,0 +1,308 @@
|
| 1 |
+
import tensorrt as trt
|
| 2 |
+
import pycuda.driver as cuda
|
| 3 |
+
import pycuda.autoinit
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import traceback
|
| 7 |
+
import os
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
TRT_LOGGER = trt.Logger()
|
| 11 |
+
SKIP_ENGINE_MODEL_CHECK = True
|
| 12 |
+
|
| 13 |
+
def get_engine(engine_file_path):
|
| 14 |
+
if os.path.exists(engine_file_path):
|
| 15 |
+
print(f"Loading engine from file {engine_file_path}...")
|
| 16 |
+
with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
|
| 17 |
+
return runtime.deserialize_cuda_engine(f.read())
|
| 18 |
+
else:
|
| 19 |
+
print(f"No file named {engine_file_path}! Please check the input.")
|
| 20 |
+
return None
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def numpy_to_torch_dtype(np_dtype):
|
| 24 |
+
mapping = {
|
| 25 |
+
np.float32: torch.float,
|
| 26 |
+
np.float64: torch.double,
|
| 27 |
+
np.float16: torch.half,
|
| 28 |
+
np.int32: torch.int32,
|
| 29 |
+
np.int64: torch.int64,
|
| 30 |
+
np.int16: torch.int16,
|
| 31 |
+
np.int8: torch.int8,
|
| 32 |
+
np.uint8: torch.uint8,
|
| 33 |
+
np.bool_: torch.bool
|
| 34 |
+
}
|
| 35 |
+
return mapping.get(np_dtype, None)
|
| 36 |
+
|
| 37 |
+
def match_shape(a, b):
|
| 38 |
+
if(len(a) == len(b)):
|
| 39 |
+
return tuple(a) == tuple(b)
|
| 40 |
+
elif len(a) > len(b):
|
| 41 |
+
if(a[0] == 1):
|
| 42 |
+
return match_shape(a[1:], b)
|
| 43 |
+
else:
|
| 44 |
+
if(b[0] == 1):
|
| 45 |
+
return match_shape(a, b[1:])
|
| 46 |
+
return False
|
| 47 |
+
|
| 48 |
+
def match_dtype(a, b):
|
| 49 |
+
if(a.__class__ == torch.dtype):
|
| 50 |
+
a = torch.tensor(0,dtype=a).numpy().dtype
|
| 51 |
+
return a == b
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class EngineModel:
|
| 55 |
+
def __init__(self, engine_file_path, stream = None, device_int = 0, extra_lock = None):
|
| 56 |
+
self.device_int = device_int
|
| 57 |
+
self.extra_lock = extra_lock
|
| 58 |
+
if not(self.extra_lock is None):
|
| 59 |
+
self.extra_lock.acquire()
|
| 60 |
+
assert os.path.exists(engine_file_path), "Engine model path not exists!"
|
| 61 |
+
self.ctx = cuda.Device(self.device_int).make_context()
|
| 62 |
+
try:
|
| 63 |
+
self.engine = get_engine(engine_file_path) # load the TensorRT engine
|
| 64 |
+
input_nvars = 0
|
| 65 |
+
output_nvars = 0
|
| 66 |
+
self.input_names = []
|
| 67 |
+
self.output_names = []
|
| 68 |
+
|
| 69 |
+
# Helper: resolve a safe shape (replace -1 dynamic dimensions)
|
| 70 |
+
def get_safe_shape(engine, name):
|
| 71 |
+
shape = engine.get_tensor_shape(name)
|
| 72 |
+
# If the shape contains -1 (a dynamic dimension)
|
| 73 |
+
if -1 in shape:
|
| 74 |
+
# Get (min, opt, max) from optimization profile 0
|
| 75 |
+
# Index [2] is the max shape, ensuring enough device memory is allocated
|
| 76 |
+
profile = engine.get_tensor_profile_shape(name, 0)
|
| 77 |
+
if profile:
|
| 78 |
+
print(f"[EngineModel] Detected dynamic shape for {name}: {shape} -> Using Max Profile: {profile[2]}")
|
| 79 |
+
return profile[2]
|
| 80 |
+
else:
|
| 81 |
+
# If no profile is available (usually for outputs), this is a risk
|
| 82 |
+
# Print a warning here instead of raising an error
|
| 83 |
+
print(f"[EngineModel] Warning: Dynamic output {name} has no profile. Mem alloc might fail.")
|
| 84 |
+
return shape
|
| 85 |
+
|
| 86 |
+
for binding in self.engine: # iterate over all tensors, separating inputs from outputs
|
| 87 |
+
mode = self.engine.get_tensor_mode(binding)
|
| 88 |
+
if(mode== trt.TensorIOMode.INPUT):
|
| 89 |
+
input_nvars += 1
|
| 90 |
+
self.input_names.append(binding)
|
| 91 |
+
elif(mode == trt.TensorIOMode.OUTPUT):
|
| 92 |
+
output_nvars += 1
|
| 93 |
+
self.output_names.append(binding)
|
| 94 |
+
|
| 95 |
+
self.input_nvars = input_nvars # number of inputs
|
| 96 |
+
self.output_nvars = output_nvars # number of outputs
|
| 97 |
+
|
| 98 |
+
self.input_shapes = {name : get_safe_shape(self.engine, name) for name in self.input_names} # record the shape and dtype of each I/O tensor
|
| 99 |
+
self.input_dtypes = {name : self.engine.get_tensor_dtype(name) for name in self.input_names}
|
| 100 |
+
self.input_nbytes = {
|
| 101 |
+
name : trt.volume(self.input_shapes[name]) * trt.nptype(self.input_dtypes[name])().itemsize
|
| 102 |
+
for name in self.input_names
|
| 103 |
+
} # nbytes = CUDA memory footprint of each tensor, in bytes
|
| 104 |
+
self.output_shapes = {name : get_safe_shape(self.engine, name) for name in self.output_names}
|
| 105 |
+
self.output_dtypes = {name : self.engine.get_tensor_dtype(name) for name in self.output_names}
|
| 106 |
+
self.output_nbytes = {
|
| 107 |
+
name : trt.volume(self.output_shapes[name]) * trt.nptype(self.output_dtypes[name])().itemsize
|
| 108 |
+
for name in self.output_names
|
| 109 |
+
}
|
| 110 |
+
self.dinputs = {name : cuda.mem_alloc(self.input_nbytes[name]) for name in self.input_names} # allocate CUDA device memory for each input/output
|
| 111 |
+
self.doutputs = {name :cuda.mem_alloc(self.output_nbytes[name]) for name in self.output_names}
|
| 112 |
+
self.context = self.engine.create_execution_context() # create the TensorRT execution context
|
| 113 |
+
if stream is None:
|
| 114 |
+
self.stream = cuda.Stream()
|
| 115 |
+
else:
|
| 116 |
+
self.stream = stream
|
| 117 |
+
for name in self.input_names: # bind tensor addresses to the context
|
| 118 |
+
self.context.set_tensor_address(name, int(self.dinputs[name]))
|
| 119 |
+
for name in self.output_names:
|
| 120 |
+
self.context.set_tensor_address(name, int(self.doutputs[name]))
|
| 121 |
+
self.houtputs = {
|
| 122 |
+
name :
|
| 123 |
+
cuda.pagelocked_empty(
|
| 124 |
+
trt.volume(self.output_shapes[name]), dtype=trt.nptype(self.output_dtypes[name])
|
| 125 |
+
) for name in self.output_names
|
| 126 |
+
} # allocate page-locked host memory for the outputs
|
| 127 |
+
except:
|
| 128 |
+
self.ctx.pop()
|
| 129 |
+
raise Exception("CUDA Initialization Failed!")
|
| 130 |
+
self.ctx.pop()
|
| 131 |
+
if not(self.extra_lock is None):
|
| 132 |
+
self.extra_lock.release()
|
| 133 |
+
|
| 134 |
+
def __call__(self, skip_check=SKIP_ENGINE_MODEL_CHECK, output_list=[], return_tensor=False, **inputs):
|
| 135 |
+
if not skip_check:
|
| 136 |
+
for name in inputs:
|
| 137 |
+
assert name in self.input_names
|
| 138 |
+
assert match_shape(inputs[name].shape, self.input_shapes[name])
|
| 139 |
+
assert match_dtype(inputs[name].dtype, trt.nptype(self.input_dtypes[name]))
|
| 140 |
+
if not(self.extra_lock is None):
|
| 141 |
+
self.extra_lock.acquire()
|
| 142 |
+
self.ctx.push()
|
| 143 |
+
r = {}
|
| 144 |
+
try:
|
| 145 |
+
|
| 146 |
+
for name in inputs:
|
| 147 |
+
hinput = inputs[name]
|
| 148 |
+
if (isinstance(hinput,torch.Tensor) and hinput.device.type=="cuda" and hinput.device.index==self.device_int):
|
| 149 |
+
hinput_con = hinput.contiguous()
|
| 150 |
+
ptr = hinput_con.data_ptr()
|
| 151 |
+
cuda.memcpy_dtod_async(self.dinputs[name], ptr, self.input_nbytes[name], self.stream)
|
| 152 |
+
else:
|
| 153 |
+
hinput_con = np.ascontiguousarray(hinput)
|
| 154 |
+
cuda.memcpy_htod_async(self.dinputs[name], hinput_con, self.stream)
|
| 155 |
+
for name in self.input_names:
|
| 156 |
+
if name not in inputs:
|
| 157 |
+
self.context.set_input_shape(name, self.input_shapes[name])
|
| 158 |
+
self.context.execute_async_v3(self.stream.handle)
|
| 159 |
+
if(return_tensor):
|
| 160 |
+
for name in output_list:
|
| 161 |
+
t = torch.zeros(trt.volume(self.output_shapes[name]), device=f"cuda:{self.device_int}", dtype=numpy_to_torch_dtype(trt.nptype(self.output_dtypes[name])))
|
| 162 |
+
ptr = t.data_ptr()
|
| 163 |
+
cuda.memcpy_dtod_async(ptr, self.doutputs[name], self.output_nbytes[name], self.stream)
|
| 164 |
+
t = t.reshape(tuple(self.output_shapes[name]))
|
| 165 |
+
r[name] = t
|
| 166 |
+
else:
|
| 167 |
+
for name in output_list:
|
| 168 |
+
cuda.memcpy_dtoh_async(self.houtputs[name], self.doutputs[name], self.stream)
|
| 169 |
+
r[name] = self.houtputs[name]
|
| 170 |
+
self.stream.synchronize()
|
| 171 |
+
except Exception as e:
|
| 172 |
+
print("TensorRT Execution Failed!")
|
| 173 |
+
traceback.print_exc()
|
| 174 |
+
self.ctx.pop()
|
| 175 |
+
if not(self.extra_lock is None):
|
| 176 |
+
self.extra_lock.release()
|
| 177 |
+
return None
|
| 178 |
+
self.ctx.pop()
|
| 179 |
+
if not(self.extra_lock is None):
|
| 180 |
+
self.extra_lock.release()
|
| 181 |
+
return r
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def prefill(self, skip_check=SKIP_ENGINE_MODEL_CHECK, **inputs):
|
| 185 |
+
if not (skip_check):
|
| 186 |
+
for name in inputs:
|
| 187 |
+
in_input = (name in self.input_names)
|
| 188 |
+
assert in_input or (name in self.output_names)
|
| 189 |
+
assert match_shape(inputs[name].shape, self.input_shapes[name] if in_input else self.output_shapes[name])
|
| 190 |
+
assert match_dtype(inputs[name].dtype, trt.nptype(self.input_dtypes[name] if in_input else self.output_dtypes[name]))
|
| 191 |
+
if not(self.extra_lock is None):
|
| 192 |
+
self.extra_lock.acquire()
|
| 193 |
+
self.ctx.push()
|
| 194 |
+
try:
|
| 195 |
+
for name in inputs:
|
| 196 |
+
in_input = (name in self.input_names)
|
| 197 |
+
hinput = inputs[name]
|
| 198 |
+
|
| 199 |
+
dst_ptr = self.dinputs[name] if in_input else self.doutputs[name]
|
| 200 |
+
real_nbytes = 0
|
| 201 |
+
if isinstance(hinput, torch.Tensor):
|
| 202 |
+
real_nbytes = hinput.numel() * hinput.element_size()
|
| 203 |
+
else:
|
| 204 |
+
# assume it is a numpy array
|
| 205 |
+
real_nbytes = hinput.nbytes
|
| 206 |
+
|
| 207 |
+
if (isinstance(hinput,torch.Tensor) and hinput.device.type=="cuda" and hinput.device.index==self.device_int):
|
| 208 |
+
hinput_con = hinput.contiguous()
|
| 209 |
+
ptr = hinput_con.data_ptr()
|
| 210 |
+
cuda.memcpy_dtod_async(dst_ptr, ptr, real_nbytes, self.stream)
|
| 211 |
+
else:
|
| 212 |
+
hinput_con = np.ascontiguousarray(hinput)
|
| 213 |
+
cuda.memcpy_htod_async(dst_ptr, hinput_con, self.stream)
|
| 214 |
+
self.stream.synchronize()
|
| 215 |
+
except Exception as e:
|
| 216 |
+
traceback.print_exc()
|
| 217 |
+
self.ctx.pop()
|
| 218 |
+
if not(self.extra_lock is None):
|
| 219 |
+
self.extra_lock.release()
|
| 220 |
+
return False
|
| 221 |
+
self.ctx.pop()
|
| 222 |
+
if not(self.extra_lock is None):
|
| 223 |
+
self.extra_lock.release()
|
| 224 |
+
return True
|
| 225 |
+
|
| 226 |
+
def __repr__(self):
|
| 227 |
+
r = "TensorRTEngineModel(\n\tInput=[\n"
|
| 228 |
+
for name in self.input_names:
|
| 229 |
+
r += f"\t\t{name}: \t{trt.nptype(self.input_dtypes[name]).__name__}{self.input_shapes[name]},\n"
|
| 230 |
+
r += "\t],Output=[\n"
|
| 231 |
+
for name in self.output_names:
|
| 232 |
+
r += f"\t\t{name}: \t{trt.nptype(self.output_dtypes[name]).__name__}{self.output_shapes[name]},\n"
|
| 233 |
+
r+="\t]\n)"
|
| 234 |
+
return r
|
| 235 |
+
|
| 236 |
+
def link(self, other, var_map, skip_check=SKIP_ENGINE_MODEL_CHECK):
|
| 237 |
+
assert self.device_int == other.device_int
|
| 238 |
+
if not (skip_check):
|
| 239 |
+
for source in var_map:
|
| 240 |
+
assert source in other.output_names
|
| 241 |
+
target = var_map[source]
|
| 242 |
+
assert target in self.input_names
|
| 243 |
+
assert match_shape(other.output_shapes[source], self.input_shapes[target])
|
| 244 |
+
assert match_dtype(other.output_dtypes[source], self.input_dtypes[target])
|
| 245 |
+
|
| 246 |
+
if not(self.extra_lock is None):
|
| 247 |
+
self.extra_lock.acquire()
|
| 248 |
+
self.ctx.push()
|
| 249 |
+
try:
|
| 250 |
+
for source in var_map:
|
| 251 |
+
target = var_map[source]
|
| 252 |
+
self.context.set_tensor_address(target, int(other.doutputs[source]))
|
| 253 |
+
except Exception as e:
|
| 254 |
+
traceback.print_exc()
|
| 255 |
+
self.ctx.pop()
|
| 256 |
+
if not(self.extra_lock is None):
|
| 257 |
+
self.extra_lock.release()
|
| 258 |
+
return False
|
| 259 |
+
self.ctx.pop()
|
| 260 |
+
if not(self.extra_lock is None):
|
| 261 |
+
self.extra_lock.release()
|
| 262 |
+
return True
|
| 263 |
+
|
| 264 |
+
def bind(self, var_map, skip_check=SKIP_ENGINE_MODEL_CHECK):
|
| 265 |
+
if not (skip_check):
|
| 266 |
+
for source in var_map:
|
| 267 |
+
assert source in self.output_names
|
| 268 |
+
target = var_map[source]
|
| 269 |
+
assert target in self.input_names
|
| 270 |
+
assert match_shape(self.output_shapes[source], self.input_shapes[target])
|
| 271 |
+
assert match_dtype(self.output_dtypes[source], self.input_dtypes[target])
|
| 272 |
+
|
| 273 |
+
if not(self.extra_lock is None):
|
| 274 |
+
self.extra_lock.acquire()
|
| 275 |
+
self.ctx.push()
|
| 276 |
+
try:
|
| 277 |
+
for source in var_map:
|
| 278 |
+
target = var_map[source]
|
| 279 |
+
self.context.set_tensor_address(target, int(self.doutputs[source]))
|
| 280 |
+
except Exception as e:
|
| 281 |
+
traceback.print_exc()
|
| 282 |
+
self.ctx.pop()
|
| 283 |
+
if not(self.extra_lock is None):
|
| 284 |
+
self.extra_lock.release()
|
| 285 |
+
return False
|
| 286 |
+
self.ctx.pop()
|
| 287 |
+
if not(self.extra_lock is None):
|
| 288 |
+
self.extra_lock.release()
|
| 289 |
+
return True
|
| 290 |
+
|
| 291 |
+
def unlink(self):
|
| 292 |
+
|
| 293 |
+
if not(self.extra_lock is None):
|
| 294 |
+
self.extra_lock.acquire()
|
| 295 |
+
self.ctx.push()
|
| 296 |
+
try:
|
| 297 |
+
for name in self.input_names:
|
| 298 |
+
self.context.set_tensor_address(name, int(self.dinputs[name]))
|
| 299 |
+
except:
|
| 300 |
+
self.ctx.pop()
|
| 301 |
+
if not(self.extra_lock is None):
|
| 302 |
+
self.extra_lock.release()
|
| 303 |
+
return False
|
| 304 |
+
self.ctx.pop()
|
| 305 |
+
if not(self.extra_lock is None):
|
| 306 |
+
self.extra_lock.release()
|
| 307 |
+
return True
|
| 308 |
+
|
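For reference, a hedged sketch of driving the EngineModel wrapper above; the engine path is a placeholder, and the input/output names are whatever the loaded engine exposes.

```python
# Hypothetical EngineModel usage; the engine path is a placeholder.
import numpy as np
import tensorrt as trt
from src.modeling.engine_model import EngineModel

model = EngineModel("path/to/unet_work.engine", device_int=0)
print(model)  # repr lists the parsed input/output names, dtypes and (max-profile) shapes

# Zero-filled host inputs matching the shapes/dtypes parsed from the engine.
inputs = {
    name: np.zeros(tuple(model.input_shapes[name]),
                   dtype=trt.nptype(model.input_dtypes[name]))
    for name in model.input_names
}

# Execute once and copy every named output back to page-locked host memory.
outputs = model(output_list=model.output_names, return_tensor=False, **inputs)
for name, arr in outputs.items():
    print(name, arr.shape)  # note: host outputs come back as flat 1-D buffers
```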
src/modeling/framed_models.py
ADDED
|
@@ -0,0 +1,177 @@
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from einops import rearrange
|
| 4 |
+
from polygraphy.backend.trt import Profile
|
| 5 |
+
|
| 6 |
+
class unet_work(nn.Module): # Ugly Power Strip
|
| 7 |
+
def __init__(self, pose_guider, motion_encoder, unet, vae, scheduler, timestep):
|
| 8 |
+
super().__init__()
|
| 9 |
+
self.pose_guider = pose_guider
|
| 10 |
+
self.motion_encoder = motion_encoder
|
| 11 |
+
self.unet = unet
|
| 12 |
+
self.vae = vae
|
| 13 |
+
self.scheduler = scheduler
|
| 14 |
+
self.timesteps = timestep
|
| 15 |
+
|
| 16 |
+
def decode_slice(self, vae, x):
|
| 17 |
+
x = x / 0.18215
|
| 18 |
+
x = vae.decode(x).sample
|
| 19 |
+
x = rearrange(x, "b c h w -> b h w c")
|
| 20 |
+
x = (x / 2 + 0.5).clamp(0, 1)
|
| 21 |
+
return x
|
| 22 |
+
|
| 23 |
+
def forward(self, sample, encoder_hidden_states, motion_hidden_states, motion, pose_cond_fea, pose, new_noise,
|
| 24 |
+
d00, d01, d10, d11, d20, d21, m, u10, u11, u12, u20, u21, u22, u30, u31, u32
|
| 25 |
+
):
|
| 26 |
+
new_pose_cond_fea = self.pose_guider(pose)
|
| 27 |
+
pose_cond_fea = torch.cat([pose_cond_fea, new_pose_cond_fea], dim=2)
|
| 28 |
+
new_motion_hidden_states = self.motion_encoder(motion)
|
| 29 |
+
motion_hidden_states = torch.cat([motion_hidden_states, new_motion_hidden_states], dim=1)
|
| 30 |
+
encoder_hidden_states = [encoder_hidden_states, motion_hidden_states]
|
| 31 |
+
score = self.unet(sample, self.timesteps, encoder_hidden_states, pose_cond_fea, d00, d01, d10, d11, d20, d21, m, u10, u11, u12, u20, u21, u22, u30, u31, u32)
|
| 32 |
+
score = rearrange(score, 'b c f h w -> (b f) c h w')
|
| 33 |
+
sample = rearrange(sample, 'b c f h w -> (b f) c h w')
|
| 34 |
+
latents_model_input, pred_original_sample = self.scheduler.step(
|
| 35 |
+
score, self.timesteps, sample, return_dict=False
|
| 36 |
+
)
|
| 37 |
+
latents_model_input = latents_model_input.to(sample.dtype)
|
| 38 |
+
pred_original_sample = pred_original_sample.to(sample.dtype)
|
| 39 |
+
latents_model_input = rearrange(latents_model_input, '(b f) c h w -> b c f h w', f=16)
|
| 40 |
+
pred_video = self.decode_slice(self.vae, pred_original_sample[:4])
|
| 41 |
+
latents = torch.cat([latents_model_input[:, :, 4:, :, :], new_noise], dim=2)
|
| 42 |
+
pose_cond_fea_out = pose_cond_fea[:, :, 4:, :, :]
|
| 43 |
+
motion_hidden_states_out = motion_hidden_states[:, 4:, :, :]
|
| 44 |
+
motion_out = motion_hidden_states[:, :1, :, :]
|
| 45 |
+
return pred_video, latents, pose_cond_fea_out, motion_hidden_states_out, motion_out, pred_original_sample[:1]
|
| 46 |
+
|
| 47 |
+
def get_sample_input(self, batchsize, height, width, dtype, device):
|
| 48 |
+
tw, ts, tb = 4, 4, 16 # temporal window size | temporal adaptive steps | temporal batch size
|
| 49 |
+
ml, mc, mh, mw = 32, 16, 224, 224 # motion latent size | motion channels | motion height | motion width
|
| 50 |
+
b, h, w = batchsize, height, width
|
| 51 |
+
lh, lw = height // 8, width // 8 # latent height | width
|
| 52 |
+
cd0, cd1, cd2, cm, cu1, cu2, cu3 = 320, 640, 1280, 1280, 1280, 640, 320 # unet channels
|
| 53 |
+
emb = 768 # CLIP Embedding Dims | TAESDV Channels
|
| 54 |
+
lc, ic = 4, 3 # latent | image channels
|
| 55 |
+
profile = {
|
| 56 |
+
"sample" : [b, lc, tb, lh, lw],
|
| 57 |
+
"encoder_hidden_states" : [b, 1, emb],
|
| 58 |
+
"motion_hidden_states" : [b, tw * (ts - 1), ml, mc],
|
| 59 |
+
"motion": [b, ic, tw, mh, mw],
|
| 60 |
+
"pose_cond_fea" : [b, cd0, tw * (ts - 1), lh, lw],
|
| 61 |
+
"pose" : [b, ic, tw, h, w],
|
| 62 |
+
"new_noise" : [b, lc, tw, lh, lw],
|
| 63 |
+
"d00" : [b, lh * lw, cd0],
|
| 64 |
+
"d01" : [b, lh * lw, cd0],
|
| 65 |
+
"d10" : [b, lh * lw // 4, cd1],
|
| 66 |
+
"d11" : [b, lh * lw // 4, cd1],
|
| 67 |
+
"d20" : [b, lh * lw // 16, cd2],
|
| 68 |
+
"d21" : [b, lh * lw // 16, cd2],
|
| 69 |
+
"m" : [b, lh * lw // 64, cm],
|
| 70 |
+
"u10" : [b, lh * lw // 16, cu1],
|
| 71 |
+
"u11" : [b, lh * lw // 16, cu1],
|
| 72 |
+
"u12" : [b, lh * lw // 16, cu1],
|
| 73 |
+
"u20" : [b, lh * lw // 4, cu2],
|
| 74 |
+
"u21" : [b, lh * lw // 4, cu2],
|
| 75 |
+
"u22" : [b, lh * lw // 4, cu2],
|
| 76 |
+
"u30" : [b, lh * lw, cu3],
|
| 77 |
+
"u31" : [b, lh * lw, cu3],
|
| 78 |
+
"u32" : [b, lh * lw, cu3],
|
| 79 |
+
}
|
| 80 |
+
return {k: torch.randn(profile[k], dtype=dtype, device=device) for k in profile}
|
| 81 |
+
|
| 82 |
+
def get_input_names(self):
|
| 83 |
+
return ["sample", "encoder_hidden_states", "motion_hidden_states",
|
| 84 |
+
"motion", "pose_cond_fea", "pose", "new_noise",
|
| 85 |
+
"d00", "d01", "d10", "d11", "d20", "d21", "m", "u10", "u11", "u12",
|
| 86 |
+
"u20", "u21", "u22", "u30", "u31", "u32"]
|
| 87 |
+
|
| 88 |
+
def get_output_names(self):
|
| 89 |
+
return ["pred_video", "latents", "pose_cond_fea_out",
|
| 90 |
+
"motion_hidden_states_out", "motion_out", "latent_first"]
|
| 91 |
+
|
| 92 |
+
def get_dynamic_axes(self):
|
| 93 |
+
dynamic_axes = {
|
| 94 |
+
"sample": {3:"h_64", 4:"w_64"},
|
| 95 |
+
"pose_cond_fea": {3:"h_64", 4:"w_64"},
|
| 96 |
+
"pose": {3:"h_512", 4:"h_512"},
|
| 97 |
+
"new_noise": {3: "h_64", 4: "w_64"},
|
| 98 |
+
"d00" : {1: "len_4096"},
|
| 99 |
+
"d01" : {1: "len_4096"},
|
| 100 |
+
"u30" : {1: "len_4096"},
|
| 101 |
+
"u31" : {1: "len_4096"},
|
| 102 |
+
"u32" : {1: "len_4096"},
|
| 103 |
+
"d10" : {1: "len_1024"},
|
| 104 |
+
"d11" : {1: "len_1024"},
|
| 105 |
+
"u20" : {1: "len_1024"},
|
| 106 |
+
"u21" : {1: "len_1024"},
|
| 107 |
+
"u22" : {1: "len_1024"},
|
| 108 |
+
"d20" : {1: "len_256"},
|
| 109 |
+
"d21" : {1: "len_256"},
|
| 110 |
+
"u10" : {1: "len_256"},
|
| 111 |
+
"u11" : {1: "len_256"},
|
| 112 |
+
"u12" : {1: "len_256"},
|
| 113 |
+
"m" : {1: "len_64"},
|
| 114 |
+
}
|
| 115 |
+
return dynamic_axes
|
| 116 |
+
|
| 117 |
+
def get_dynamic_map(self, batchsize, height, width):
|
| 118 |
+
tw, ts, tb = 4, 4, 16 # temporal window size | temporal adaptive steps | temporal batch size
|
| 119 |
+
ml, mc, mh, mw = 32, 16, 224, 224 # motion latent size | motion channels | motion height | motion width
|
| 120 |
+
b, h, w = batchsize, height, width
|
| 121 |
+
lh, lw = height // 8, width // 8 # latent height | width
|
| 122 |
+
cd0, cd1, cd2, cm, cu1, cu2, cu3 = 320, 640, 1280, 1280, 1280, 640, 320 # unet channels
|
| 123 |
+
emb = 768 # CLIP Embedding Dims | TAESDV Channels
|
| 124 |
+
lc, ic = 4, 3 # latent | image channels
|
| 125 |
+
|
| 126 |
+
fixed_inputs_map = {
|
| 127 |
+
"sample": (b, lc, tb, lh, lw),
|
| 128 |
+
"encoder_hidden_states": (b, 1, emb),
|
| 129 |
+
"motion_hidden_states": (b, tw * (ts - 1), ml, mc),
|
| 130 |
+
"motion": (b, ic, tw, mh, mw),
|
| 131 |
+
"pose_cond_fea": (b, cd0, tw * (ts - 1), lh, lw),
|
| 132 |
+
"pose": (b, ic, tw, h, w),
|
| 133 |
+
"new_noise": (b, lc, tw, lh, lw),
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
dynamic_inputs_map = {
|
| 137 |
+
"d00": (b, lh * lw, cd0),
|
| 138 |
+
"d01": (b, lh * lw, cd0),
|
| 139 |
+
"d10": (b, lh * lw // 4, cd1),
|
| 140 |
+
"d11": (b, lh * lw // 4, cd1),
|
| 141 |
+
"d20": (b, lh * lw // 16, cd2),
|
| 142 |
+
"d21": (b, lh * lw // 16, cd2),
|
| 143 |
+
"m": (b, lh * lw // 64, cm),
|
| 144 |
+
"u10": (b, lh * lw // 16, cu1),
|
| 145 |
+
"u11": (b, lh * lw // 16, cu1),
|
| 146 |
+
"u12": (b, lh * lw // 16, cu1),
|
| 147 |
+
"u20": (b, lh * lw // 4, cu2),
|
| 148 |
+
"u21": (b, lh * lw // 4, cu2),
|
| 149 |
+
"u22": (b, lh * lw // 4, cu2),
|
| 150 |
+
"u30": (b, lh * lw, cu3),
|
| 151 |
+
"u31": (b, lh * lw, cu3),
|
| 152 |
+
"u32": (b, lh * lw, cu3),
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
profile = Profile()
|
| 156 |
+
|
| 157 |
+
for name, shape in fixed_inputs_map.items():
|
| 158 |
+
shape_tuple = tuple(shape)
|
| 159 |
+
profile.add(name, min=shape_tuple, opt=shape_tuple, max=shape_tuple)
|
| 160 |
+
|
| 161 |
+
for name, base_shape in dynamic_inputs_map.items():
|
| 162 |
+
|
| 163 |
+
dim0, dim1_base, dim2 = base_shape
|
| 164 |
+
|
| 165 |
+
val_1x = dim1_base * 1
|
| 166 |
+
val_2x = dim1_base * 2
|
| 167 |
+
val_4x = dim1_base * 4
|
| 168 |
+
|
| 169 |
+
min_shape = (dim0, val_1x, dim2)
|
| 170 |
+
opt_shape = (dim0, val_2x, dim2)
|
| 171 |
+
max_shape = (dim0, val_4x, dim2)
|
| 172 |
+
|
| 173 |
+
profile.add(name, min=min_shape, opt=opt_shape, max=max_shape)
|
| 174 |
+
|
| 175 |
+
print(f"Dynamic: {name:<5} | Base(1x): {dim1_base:<5} | Range: {val_1x} ~ {val_4x} | Opt: {val_2x}")
|
| 176 |
+
|
| 177 |
+
return profile
|
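For context, a sketch of how the Profile produced by get_dynamic_map above could feed a polygraphy engine build; the ONNX path, engine path, and the resolution are assumptions, and the wrapper is borrowed unbound since get_dynamic_map does not touch instance state.

```python
# Hypothetical TensorRT engine build using the optimization profile above (paths are placeholders).
from polygraphy.backend.trt import CreateConfig, engine_from_network, network_from_onnx_path, save_engine
from src.modeling.framed_models import unet_work

# get_dynamic_map does not use self, so it can be called unbound purely for illustration.
profile = unet_work.get_dynamic_map(None, batchsize=1, height=512, width=512)

network = network_from_onnx_path("path/to/unet_opt.onnx")   # assumed exported ONNX file
config = CreateConfig(fp16=True, profiles=[profile])
engine = engine_from_network(network, config=config)
save_engine(engine, "path/to/unet_work.engine")             # placeholder output path
```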
src/modeling/onnx_export.py
ADDED
|
@@ -0,0 +1,102 @@
|
| 1 |
+
# adapted from https://github.com/NVIDIA/TensorRT/blob/main/demo/Diffusion/utilities.py
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
| 4 |
+
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 5 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 6 |
+
#
|
| 7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 8 |
+
# you may not use this file except in compliance with the License.
|
| 9 |
+
# You may obtain a copy of the License at
|
| 10 |
+
#
|
| 11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 12 |
+
#
|
| 13 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 16 |
+
# See the License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
#
|
| 19 |
+
|
| 20 |
+
import onnx
|
| 21 |
+
import gc
|
| 22 |
+
import onnx_graphsurgeon as gs
|
| 23 |
+
import torch
|
| 24 |
+
from onnx import shape_inference
|
| 25 |
+
from polygraphy.backend.onnx.loader import fold_constants
|
| 26 |
+
import os
|
| 27 |
+
from onnxsim import simplify
|
| 28 |
+
|
| 29 |
+
@torch.no_grad()
|
| 30 |
+
def export_onnx(
|
| 31 |
+
model,
|
| 32 |
+
onnx_path: str,
|
| 33 |
+
opt_image_height: int,
|
| 34 |
+
opt_image_width: int,
|
| 35 |
+
opt_batch_size: int,
|
| 36 |
+
onnx_opset: int,
|
| 37 |
+
dtype,
|
| 38 |
+
device,
|
| 39 |
+
auto_cast: bool = True,
|
| 40 |
+
):
|
| 41 |
+
from contextlib import contextmanager
|
| 42 |
+
|
| 43 |
+
@contextmanager
|
| 44 |
+
def auto_cast_manager(enabled):
|
| 45 |
+
if enabled:
|
| 46 |
+
with torch.inference_mode(), torch.autocast("cuda"):
|
| 47 |
+
yield
|
| 48 |
+
else:
|
| 49 |
+
yield
|
| 50 |
+
|
| 51 |
+
# make sure the parent directory exists
|
| 52 |
+
os.makedirs(os.path.dirname(onnx_path), exist_ok=True)
|
| 53 |
+
|
| 54 |
+
with auto_cast_manager(auto_cast):
|
| 55 |
+
inputs = model.get_sample_input(opt_batch_size, opt_image_height, opt_image_width, dtype, device)
|
| 56 |
+
|
| 57 |
+
print(model.get_output_names())
|
| 58 |
+
print(f"开始导出 ONNX 模型到: {onnx_path} ...")
|
| 59 |
+
torch.onnx.utils.export(
|
| 60 |
+
model,
|
| 61 |
+
inputs,
|
| 62 |
+
onnx_path,
|
| 63 |
+
export_params=True,
|
| 64 |
+
opset_version=onnx_opset,
|
| 65 |
+
do_constant_folding=True,
|
| 66 |
+
input_names=model.get_input_names(),
|
| 67 |
+
output_names=model.get_output_names(),
|
| 68 |
+
dynamic_axes=model.get_dynamic_axes(),
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
del model
|
| 72 |
+
gc.collect()
|
| 73 |
+
torch.cuda.empty_cache()
|
| 74 |
+
|
| 75 |
+
def optimize_onnx(onnx_path, onnx_opt_path):
|
| 76 |
+
model = onnx.load(onnx_path)
|
| 77 |
+
name = os.path.splitext(os.path.basename(onnx_opt_path))[0]
|
| 78 |
+
model_opt = model
|
| 79 |
+
|
| 80 |
+
print(f"Saving to {onnx_opt_path}...")
|
| 81 |
+
onnx.save(
|
| 82 |
+
model_opt,
|
| 83 |
+
onnx_opt_path,
|
| 84 |
+
save_as_external_data=True,
|
| 85 |
+
all_tensors_to_one_file=True,
|
| 86 |
+
location=f"{name}.onnx.data",
|
| 87 |
+
size_threshold=1024
|
| 88 |
+
)
|
| 89 |
+
print("Optimization done.")
|
| 90 |
+
|
| 91 |
+
def handle_onnx_batch_norm(onnx_path: str):
|
| 92 |
+
onnx_model = onnx.load(onnx_path)
|
| 93 |
+
for node in onnx_model.graph.node:
|
| 94 |
+
if node.op_type == "BatchNormalization":
|
| 95 |
+
for attribute in node.attribute:
|
| 96 |
+
if attribute.name == "training_mode":
|
| 97 |
+
if attribute.i == 1:
|
| 98 |
+
node.output.remove(node.output[1])
|
| 99 |
+
node.output.remove(node.output[1])
|
| 100 |
+
attribute.i = 0
|
| 101 |
+
|
| 102 |
+
onnx.save_model(onnx_model, onnx_path)
|
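Putting the three helpers above together, a hedged export sketch; the wrapper construction, opset number, and output paths are placeholders rather than the project's actual export script.

```python
# Hypothetical ONNX export pipeline built from the helpers above (model and paths are placeholders).
import torch
from src.modeling.onnx_export import export_onnx, handle_onnx_batch_norm, optimize_onnx

def build_wrapper():
    # Placeholder: assemble and return the unet_work wrapper from framed_models.py
    # (pose_guider, motion_encoder, unet, vae, scheduler, timestep loaded elsewhere).
    raise NotImplementedError

model = build_wrapper().cuda().eval()

onnx_path = "path/to/unet_work.onnx"
export_onnx(model, onnx_path,
            opt_image_height=512, opt_image_width=512, opt_batch_size=1,
            onnx_opset=17, dtype=torch.float16, device="cuda")

handle_onnx_batch_norm(onnx_path)                       # force BatchNorm nodes into inference mode
optimize_onnx(onnx_path, "path/to/unet_work_opt.onnx")  # re-save with weights stored as external data
```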
src/models/__pycache__/attention.cpython-310.pyc
ADDED
|
Binary file (9.71 kB). View file
|
|
|
src/models/__pycache__/attention.cpython-39.pyc
ADDED
|
Binary file (9.57 kB). View file
|
|
|
src/models/__pycache__/motion_module.cpython-310.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
src/models/__pycache__/motion_module.cpython-39.pyc
ADDED
|
Binary file (10.4 kB). View file
|
|
|