zlf18 committed on
Commit
8299893
·
verified ·
1 Parent(s): 4af1644

Create demo/vis.py

Browse files
Files changed (1) hide show
  1. demo/vis.py +356 -0
demo/vis.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import argparse
3
+ import cv2
4
+ from lib.preprocess import h36m_coco_format, revise_kpts
5
+ from lib.hrnet.gen_kpts import gen_video_kpts as hrnet_pose
6
+ import os
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import glob
11
+ from tqdm import tqdm
12
+ import copy
13
+
14
+ sys.path.append(os.getcwd())
15
+ from common.model_poseformer import PoseTransformerV2 as Model
16
+ from common.camera import *
17
+
18
+ import matplotlib
19
+ import matplotlib.pyplot as plt
20
+ from mpl_toolkits.mplot3d import Axes3D
21
+ import matplotlib.gridspec as gridspec
22
+
23
+ plt.switch_backend('agg')
24
+ matplotlib.rcParams['pdf.fonttype'] = 42
25
+ matplotlib.rcParams['ps.fonttype'] = 42
26
+
27
def show2Dpose(kps, img):
    """Draw a 17-joint H3.6M 2D skeleton onto ``img`` (modified in place).

    kps: (17, 2) array-like of pixel coordinates.
    img: BGR image (OpenCV convention).
    Returns the same image with bones and joint markers drawn.
    """
    skeleton = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5],
                [5, 6], [0, 7], [7, 8], [8, 9], [9, 10],
                [8, 11], [11, 12], [12, 13], [8, 14], [14, 15], [15, 16]]

    # True -> draw with lcolor, False -> rcolor (matches original mapping).
    LR = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], dtype=bool)

    lcolor = (255, 0, 0)
    rcolor = (0, 0, 255)
    thickness = 3

    for idx, (a, b) in enumerate(skeleton):
        pt_a = (int(kps[a][0]), int(kps[a][1]))
        pt_b = (int(kps[b][0]), int(kps[b][1]))
        cv2.line(img, pt_a, pt_b, lcolor if LR[idx] else rcolor, thickness)
        # Filled green dots mark the two joint endpoints of each bone.
        cv2.circle(img, pt_a, thickness=-1, color=(0, 255, 0), radius=3)
        cv2.circle(img, pt_b, thickness=-1, color=(0, 255, 0), radius=3)

    return img
48
+
49
+
50
def show3Dpose(vals, ax):
    """Render a 17-joint 3D pose on a matplotlib 3D axis.

    vals: (17, 3) array of joint positions.
    ax:   mpl_toolkits Axes3D to draw on (ticks and panes are hidden).
    """
    ax.view_init(elev=15., azim=70)

    lcolor = (0, 0, 1)
    rcolor = (1, 0, 0)

    # Bone list as (parent, child) joint indices, plus a left/right flag
    # selecting the line colour for each bone.
    parents = np.array([0, 0, 1, 4, 2, 5, 0, 7, 8, 8, 14, 15, 11, 12, 8, 9])
    children = np.array([1, 4, 2, 5, 3, 6, 7, 8, 14, 11, 15, 16, 12, 13, 9, 10])
    is_left = np.array([0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0], dtype=bool)

    for p, c, left in zip(parents, children, is_left):
        xs, ys, zs = (np.array([vals[p, axis], vals[c, axis]]) for axis in range(3))
        ax.plot(xs, ys, zs, lw=2, color=lcolor if left else rcolor)

    RADIUS = 0.72
    RADIUS_Z = 0.7

    # Centre the viewing cube on the root joint (index 0).
    xroot, yroot, zroot = vals[0, 0], vals[0, 1], vals[0, 2]
    ax.set_xlim3d([-RADIUS + xroot, RADIUS + xroot])
    ax.set_ylim3d([-RADIUS + yroot, RADIUS + yroot])
    ax.set_zlim3d([-RADIUS_Z + zroot, RADIUS_Z + zroot])
    ax.set_aspect('auto')  # works fine in matplotlib==2.2.2

    # Fully transparent panes so only the skeleton is visible.
    white = (1.0, 1.0, 1.0, 0.0)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color(white)

    ax.tick_params('x', labelbottom=False)
    ax.tick_params('y', labelleft=False)
    ax.tick_params('z', labelleft=False)
82
+
83
+
84
def get_pose2D(video_path, output_dir):
    """Run HRNet 2D pose estimation over a video and save the keypoints.

    video_path: path to the input video file.
    output_dir: demo output directory; keypoints are written to
                ``<output_dir>input_2D/keypoints.npz`` under key
                ``reconstruction`` with shape as produced by HRNet
                (people, frames, 17, 2 after H3.6M conversion).
    """
    # Fix: the original opened a cv2.VideoCapture only to read width/height
    # that were never used, and never released it — dead code + leaked handle.
    print('\nGenerating 2D pose...')
    keypoints, scores = hrnet_pose(video_path, det_dim=416, num_peroson=1, gen_output=True)
    keypoints, scores, valid_frames = h36m_coco_format(keypoints, scores)
    # NOTE(review): revise_kpts output is computed but not saved; the raw
    # H3.6M-format keypoints are what downstream get_pose3D loads.
    re_kpts = revise_kpts(keypoints, scores, valid_frames)
    print('Generating 2D pose successful!')

    output_dir += 'input_2D/'
    os.makedirs(output_dir, exist_ok=True)

    output_npz = output_dir + 'keypoints.npz'
    np.savez_compressed(output_npz, reconstruction=keypoints)
100
+
101
+
102
def img2video(video_path, output_dir):
    """Stitch the rendered ``pose/*.png`` frames into an mp4 next to them.

    video_path: source video (used for its FPS and its base name).
    output_dir: directory containing the ``pose/`` frame folder; the mp4
                is written as ``<output_dir><video_name>.mp4``.
    """
    cap = cv2.VideoCapture(video_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS)) + 5  # original pacing choice: +5 fps
    cap.release()  # fix: capture was never released in the original

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")

    names = sorted(glob.glob(os.path.join(output_dir + 'pose/', '*.png')))
    first = cv2.imread(names[0])
    size = (first.shape[1], first.shape[0])

    # Fix: derive the clip name locally from video_path instead of relying on
    # the module-level `video_name` global, which only exists when the script
    # is run via __main__ (calling this as a library raised NameError).
    # Same expression as __main__ uses, so the output path is unchanged.
    video_name = video_path.split('/')[-1].split('.')[0]
    videoWrite = cv2.VideoWriter(output_dir + video_name + '.mp4', fourcc, fps, size)

    for name in names:
        videoWrite.write(cv2.imread(name))

    videoWrite.release()
119
+
120
+
121
def showimage(ax, img):
    """Show ``img`` on axis ``ax`` with all ticks and axis frames hidden."""
    ax.set_xticks(())
    ax.set_yticks(())
    plt.axis('off')  # applies to the current pyplot axes, as in the original
    ax.imshow(img)
126
+
127
+
128
def get_pose3D(video_path, output_dir):
    """Lift the saved 2D keypoints to 3D with PoseFormerV2 and render results.

    Reads ``<output_dir>input_2D/keypoints.npz`` (written by get_pose2D),
    runs flip-augmented inference frame by frame, and writes per-frame
    2D overlays (``pose2D/``), 3D renders (``pose3D/``), the raw 3D
    predictions (``keypoints_3D.npz``), and side-by-side demo frames
    (``pose/``) under ``output_dir``.
    """
    # An empty parser's parse_known_args() is just a quick way to obtain an
    # argparse Namespace to attach the model hyper-parameters to.
    args, _ = argparse.ArgumentParser().parse_known_args()
    args.embed_dim_ratio, args.depth, args.frames = 32, 4, 243
    args.number_of_kept_frames, args.number_of_kept_coeffs = 27, 27
    args.pad = (args.frames - 1) // 2  # half-window of the 243-frame receptive field
    args.previous_dir = 'checkpoint/'
    args.n_joints, args.out_joints = 17, 17

    ## Reload
    cuda_available = torch.cuda.is_available()
    print(f"CUDA available in get_pose3D: {cuda_available}")
    if cuda_available:
        print(f"CUDA device count: {torch.cuda.device_count()}")
        print(f"CUDA device name: {torch.cuda.get_device_name(0)}")

    device = torch.device('cuda' if cuda_available else 'cpu')
    print(f"Using device: {device}")

    base_model = Model(args=args)

    # Always use DataParallel when CUDA is available (checkpoint expects it)
    if cuda_available:
        model = nn.DataParallel(base_model).to(device)
    else:
        model = base_model.to(device)

    model_dict = model.state_dict()
    # Put the pretrained model of PoseFormerV2 in 'checkpoint/']
    # model_path = sorted(glob.glob(os.path.join(args.previous_dir, '27_243_45.2.bin')))
    model_path = "./demo/lib/checkpoint/27_243_45.2.bin"

    map_location = device
    # weights_only=False: checkpoint is trusted local data, not untrusted input.
    pre_dict = torch.load(model_path, map_location=map_location, weights_only=False)

    # Handle DataParallel checkpoint mismatch: a checkpoint saved from a
    # DataParallel model prefixes every key with "module."; align the loaded
    # state dict's keys with whichever wrapping the current model has.
    state_dict = pre_dict['model_pos']
    from collections import OrderedDict
    new_state_dict = OrderedDict()

    # Check if we need to add or remove "module." prefix
    checkpoint_has_module = any(k.startswith('module.') for k in state_dict.keys())
    model_has_module = isinstance(model, nn.DataParallel)

    if checkpoint_has_module and not model_has_module:
        # Remove "module." prefix
        for k, v in state_dict.items():
            name = k[7:] if k.startswith('module.') else k
            new_state_dict[name] = v
    elif not checkpoint_has_module and model_has_module:
        # Add "module." prefix
        for k, v in state_dict.items():
            name = 'module.' + k if not k.startswith('module.') else k
            new_state_dict[name] = v
    else:
        # No change needed
        new_state_dict = state_dict

    model.load_state_dict(new_state_dict, strict=True)

    model.eval()

    ## input
    keypoints = np.load(output_dir + 'input_2D/keypoints.npz', allow_pickle=True)['reconstruction']

    # NOTE(review): this capture is never released; consider cap.release()
    # after the loop.
    cap = cv2.VideoCapture(video_path)
    video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    ## 3D
    print('\nGenerating 3D pose...')
    keypoints_3D = []
    for i in tqdm(range(video_length)):
        ret, img = cap.read()  # ret unused; the None check below covers failure
        if img is None:
            continue
        img_size = img.shape

        ## input frames — clamp a 243-frame window centred on frame i to the
        ## available keypoint range (person index 0 only).
        start = max(0, i - args.pad)
        end = min(i + args.pad, len(keypoints[0]) - 1)

        input_2D_no = keypoints[0][start:end + 1]

        # Edge-pad the window back up to args.frames when it was clamped at
        # either end of the sequence.
        left_pad, right_pad = 0, 0
        if input_2D_no.shape[0] != args.frames:
            if i < args.pad:
                left_pad = args.pad - i
            if i > len(keypoints[0]) - args.pad - 1:
                right_pad = i + args.pad - (len(keypoints[0]) - 1)

            input_2D_no = np.pad(input_2D_no, ((left_pad, right_pad), (0, 0), (0, 0)), 'edge')

        # H3.6M left/right joint index groups, used for the flip augmentation.
        joints_left = [4, 5, 6, 11, 12, 13]
        joints_right = [1, 2, 3, 14, 15, 16]

        # input_2D_no += np.random.normal(loc=0.0, scale=5, size=input_2D_no.shape)
        input_2D = normalize_screen_coordinates(input_2D_no, w=img_size[1], h=img_size[0])

        # Test-time flip augmentation: mirror x and swap left/right joints,
        # then stack [original, flipped] along a new leading axis.
        input_2D_aug = copy.deepcopy(input_2D)
        input_2D_aug[:, :, 0] *= -1
        input_2D_aug[:, joints_left + joints_right] = input_2D_aug[:, joints_right + joints_left]
        input_2D = np.concatenate((np.expand_dims(input_2D, axis=0), np.expand_dims(input_2D_aug, axis=0)), 0)
        # (2, 243, 17, 2)

        input_2D = input_2D[np.newaxis, :, :, :, :]  # add batch dim -> (1, 2, 243, 17, 2)

        input_2D = torch.from_numpy(input_2D.astype('float32')).to(device)

        N = input_2D.size(0)  # NOTE(review): unused

        ## estimation — run both views, un-flip the flipped prediction, average.
        # NOTE(review): no torch.no_grad() here, so autograd state is built
        # during inference — confirm whether that is intended.
        output_3D_non_flip = model(input_2D[:, 0])
        output_3D_flip = model(input_2D[:, 1])
        # [1, 1, 17, 3]

        output_3D_flip[:, :, :, 0] *= -1
        output_3D_flip[:, :, joints_left + joints_right, :] = output_3D_flip[:, :, joints_right + joints_left, :]

        output_3D = (output_3D_non_flip + output_3D_flip) / 2

        output_3D[:, :, 0, :] = 0  # pin the root joint at the origin
        post_out = output_3D[0, 0].cpu().detach().numpy()
        # NOTE(review): keypoints_3D stores the camera-space prediction; the
        # camera_to_world conversion below affects only the rendered figures.
        keypoints_3D.append(post_out)

        # Fixed camera-to-world rotation quaternion used for visualization.
        rot = [0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088]
        rot = np.array(rot, dtype='float32')
        post_out = camera_to_world(post_out, R=rot, t=0)
        post_out[:, 2] -= np.min(post_out[:, 2])  # drop the pose onto z = 0

        input_2D_no = input_2D_no[args.pad]  # centre frame of the window = frame i

        ## 2D
        image = show2Dpose(input_2D_no, copy.deepcopy(img))

        output_dir_2D = output_dir + 'pose2D/'
        os.makedirs(output_dir_2D, exist_ok=True)
        cv2.imwrite(output_dir_2D + str(('%04d' % i)) + '_2D.png', image)

        ## 3D
        fig = plt.figure(figsize=(9.6, 5.4))
        gs = gridspec.GridSpec(1, 1)
        gs.update(wspace=-0.00, hspace=0.05)
        ax = plt.subplot(gs[0], projection='3d')
        show3Dpose(post_out, ax)

        output_dir_3D = output_dir + 'pose3D/'
        os.makedirs(output_dir_3D, exist_ok=True)
        plt.savefig(output_dir_3D + str(('%04d' % i)) + '_3D.png', dpi=200, format='png', bbox_inches='tight')
        plt.clf()
        plt.close(fig)

    output_npz = output_dir + 'keypoints_3D.npz'
    np.savez_compressed(output_npz, reconstruction=keypoints_3D)
    print('Generating 3D pose successful!')

    ## all — compose each 2D overlay and 3D render side by side.
    image_dir = 'results/'  # NOTE(review): unused
    image_2d_dir = sorted(glob.glob(os.path.join(output_dir_2D, '*.png')))
    image_3d_dir = sorted(glob.glob(os.path.join(output_dir_3D, '*.png')))

    print('\nGenerating demo...')
    for i in tqdm(range(len(image_2d_dir))):
        image_2d = plt.imread(image_2d_dir[i])
        image_3d = plt.imread(image_3d_dir[i])

        ## crop — centre-crop the 2D frame to a square, trim the 3D render's
        ## whitespace border.
        edge = (image_2d.shape[1] - image_2d.shape[0]) // 2
        image_2d = image_2d[:, edge:image_2d.shape[1] - edge]

        edge = 130
        image_3d = image_3d[edge:image_3d.shape[0] - edge, edge:image_3d.shape[1] - edge]

        ## show
        font_size = 12
        fig = plt.figure(figsize=(15.0, 5.4))
        ax = plt.subplot(121)
        showimage(ax, image_2d)
        ax.set_title("Input", fontsize=font_size)

        ax = plt.subplot(122)
        showimage(ax, image_3d)
        ax.set_title("Reconstruction", fontsize=font_size)

        ## save
        output_dir_pose = output_dir + 'pose/'
        os.makedirs(output_dir_pose, exist_ok=True)
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.savefig(output_dir_pose + str(('%04d' % i)) + '_pose.png', dpi=200, bbox_inches='tight')
        plt.clf()
        plt.close(fig)
319
+
320
if __name__ == "__main__":
    # Command-line interface for the full 2D -> 3D -> video demo pipeline.
    parser = argparse.ArgumentParser()
    parser.add_argument('--video', type=str, default='sample_video.mp4', help='input video')
    parser.add_argument('--gpu', type=str, default='0', help='GPU device ID (set CUDA_VISIBLE_DEVICES before running if needed)')
    args = parser.parse_args()

    # CUDA_VISIBLE_DEVICES only takes effect when set BEFORE torch is
    # imported, so it cannot be set here — export it in the shell first:
    # PowerShell: $env:CUDA_VISIBLE_DEVICES="0"   Bash: export CUDA_VISIBLE_DEVICES=0

    # Report CUDA status up front so failures are easy to diagnose.
    cuda_ok = torch.cuda.is_available()
    print(f"CUDA available: {cuda_ok}")
    if cuda_ok:
        print(f"CUDA device count: {torch.cuda.device_count()}")
        print(f"Current device: {torch.cuda.current_device()}")
        print(f"Device name: {torch.cuda.get_device_name(0)}")
        if "CUDA_VISIBLE_DEVICES" in os.environ:
            print(f"CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}")
    else:
        for line in (
            "WARNING: CUDA is not available!",
            "This might be because:",
            "  1. CUDA_VISIBLE_DEVICES was set incorrectly",
            "  2. PyTorch was installed without CUDA support",
            "  3. GPU drivers are not installed",
            "\nTo use GPU, set CUDA_VISIBLE_DEVICES BEFORE running Python:",
            "  PowerShell: $env:CUDA_VISIBLE_DEVICES='0'",
            "  Bash: export CUDA_VISIBLE_DEVICES=0",
            "\nOr don't set it at all to use the default GPU",
        ):
            print(line)

    video_path = './demo/video/' + args.video
    video_name = video_path.split('/')[-1].split('.')[0]
    output_dir = './demo/output/' + video_name + '/'

    get_pose2D(video_path, output_dir)
    get_pose3D(video_path, output_dir)
    img2video(video_path, output_dir)
    print('Generating demo successful!')