huaichang commited on
Commit
615421b
·
verified ·
1 Parent(s): de64499

Delete torch2trt.py

Browse files
Files changed (1) hide show
  1. torch2trt.py +0 -131
torch2trt.py DELETED
@@ -1,131 +0,0 @@
1
- import os
2
- import torch
3
- from omegaconf import OmegaConf
4
- import gc
5
- from src.modeling.framed_models import unet_work
6
- from diffusers import AutoencoderKL
7
- from diffusers.models.attention_processor import AttnProcessor
8
- from src.scheduler.scheduler_ddim import DDIMScheduler
9
- from src.models.unet_3d_explicit_reference import UNet3DConditionModel
10
- from src.models.motion_encoder.encoder import MotEncoder
11
- from src.models.pose_guider import PoseGuider
12
-
13
- from src.modeling.onnx_export import export_onnx, handle_onnx_batch_norm, optimize_onnx
14
- from polygraphy.backend.trt import engine_from_network, network_from_onnx_path, save_engine, CreateConfig
15
- from polygraphy.logger import G_LOGGER
16
- import tensorrt as trt
17
- G_LOGGER.severity = G_LOGGER.VERBOSE
18
-
19
- def map_device(device_or_str):
20
- return device_or_str if isinstance(device_or_str, torch.device) else torch.device(device_or_str)
21
-
22
- # parameters
23
- batch_size = 1
24
- height = 256
25
- width = 256
26
- onnx_opset = 17
27
- device = torch.device("cuda:0")
28
- device = map_device(device)
29
- dtype = torch.float16
30
-
31
- config_path = './configs/prompts/personalive_trt.yaml'
32
- cfg = OmegaConf.load(config_path)
33
- onnx_path = cfg.onnx_path
34
- onnx_opt_path = cfg.onnx_opt_path
35
- tensorrt_target_model = cfg.tensorrt_target_model
36
-
37
- infer_config = OmegaConf.load(cfg.inference_config)
38
- sched_kwargs = OmegaConf.to_container(
39
- infer_config.noise_scheduler_kwargs
40
- )
41
-
42
- pose_guider = PoseGuider().to(device=device, dtype=dtype)
43
- pose_guider_state_dict = torch.load(cfg.pose_guider_path, map_location="cpu")
44
- pose_guider.load_state_dict(pose_guider_state_dict)
45
- del pose_guider_state_dict
46
-
47
- motion_encoder:MotEncoder = MotEncoder().to(dtype=dtype, device=device).eval()
48
- motion_encoder.set_attn_processor(AttnProcessor())
49
- motion_encoder_state_dict = torch.load(cfg.motion_encoder_path, map_location="cpu")
50
- motion_encoder.load_state_dict(motion_encoder_state_dict)
51
- del motion_encoder_state_dict
52
-
53
- denoising_unet:UNet3DConditionModel = UNet3DConditionModel.from_pretrained_2d(
54
- cfg.pretrained_base_model_path,
55
- "",
56
- subfolder="unet",
57
- unet_additional_kwargs=infer_config.unet_additional_kwargs,
58
- ).to(dtype=dtype, device=device)
59
-
60
- denoising_unet.load_state_dict(
61
- torch.load(cfg.denoising_unet_path, map_location="cpu"), strict=False
62
- )
63
- denoising_unet.load_state_dict(
64
- torch.load(
65
- cfg.temporal_module_path,
66
- map_location="cpu",
67
- ),
68
- strict=False,
69
- )
70
- denoising_unet.set_attn_processor(AttnProcessor())
71
-
72
- vae:AutoencoderKL = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(device=device, dtype=dtype)
73
- vae.set_default_attn_processor()
74
-
75
- scheduler = DDIMScheduler(**sched_kwargs)
76
- scheduler.to(device)
77
- timesteps = torch.tensor([0,0,0,0,333,333,333,333,666,666,666,666,999,999,999,999], device=device).long()
78
- scheduler.set_step_length(333)
79
-
80
- model = unet_work(
81
- pose_guider,
82
- motion_encoder,
83
- denoising_unet,
84
- vae,
85
- scheduler,
86
- timesteps,
87
- )
88
-
89
- if(not os.path.exists(os.path.dirname(onnx_path))):
90
- os.mkdir(os.path.dirname(onnx_path))
91
-
92
- if not os.path.exists(onnx_path):
93
- export_onnx(
94
- model,
95
- onnx_path=onnx_path,
96
- opt_image_height=height,
97
- opt_image_width=width,
98
- opt_batch_size=batch_size,
99
- onnx_opset=onnx_opset,
100
- auto_cast=True,
101
- dtype=dtype,
102
- device=device,
103
- )
104
-
105
- batch_size = 1
106
- height = 512
107
- width = 512
108
- profile = model.get_dynamic_map(batch_size, height, width)
109
- del model
110
- gc.collect()
111
- torch.cuda.empty_cache()
112
-
113
- print('finished')
114
-
115
- print("Optimizing Onnx Model...")
116
- if(not os.path.exists(os.path.dirname(onnx_opt_path))):
117
- os.mkdir(os.path.dirname(onnx_opt_path))
118
- optimize_onnx(
119
- onnx_path=onnx_path,
120
- onnx_opt_path=onnx_opt_path,
121
- )
122
-
123
- engine = engine_from_network(
124
- network_from_onnx_path(onnx_opt_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
125
- config=CreateConfig(
126
- fp16=True, refittable=False, profiles=[profile]
127
- ),
128
- )
129
- save_engine(engine, path=tensorrt_target_model)
130
- gc.collect()
131
- torch.cuda.empty_cache()