rlogh committed on
Commit 239c574 · verified · 1 Parent(s): ee549e7

Delete vis.py

Files changed (1)
  1. vis.py +0 -364
vis.py DELETED
@@ -1,364 +0,0 @@
import sys
import argparse
import cv2
from lib.preprocess import h36m_coco_format, revise_kpts
from lib.hrnet.gen_kpts import gen_video_kpts as hrnet_pose
import os
import numpy as np
import torch
import torch.nn as nn
import glob
from tqdm import tqdm
import copy

sys.path.append(os.getcwd())
from common.model_poseformer import PoseTransformerV2 as Model
from common.camera import *  # provides normalize_screen_coordinates and camera_to_world

import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.gridspec as gridspec

plt.switch_backend('agg')  # headless rendering: figures are saved to disk, never shown
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

def show2Dpose(kps, img):
    # bones of the 17-joint Human3.6M skeleton as (start, end) index pairs
    connections = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5],
                   [5, 6], [0, 7], [7, 8], [8, 9], [9, 10],
                   [8, 11], [11, 12], [12, 13], [8, 14], [14, 15], [15, 16]]

    # per-bone colour flag: True -> lcolor, False -> rcolor
    LR = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], dtype=bool)

    lcolor = (255, 0, 0)
    rcolor = (0, 0, 255)
    thickness = 3

    for j, c in enumerate(connections):
        start = list(map(int, kps[c[0]]))
        end = list(map(int, kps[c[1]]))
        cv2.line(img, (start[0], start[1]), (end[0], end[1]), lcolor if LR[j] else rcolor, thickness)
        cv2.circle(img, (start[0], start[1]), thickness=-1, color=(0, 255, 0), radius=3)
        cv2.circle(img, (end[0], end[1]), thickness=-1, color=(0, 255, 0), radius=3)

    return img

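# Both drawing helpers assume the 17-joint Human3.6M skeleton: show2Dpose takes
# a (17, 2) array of pixel coordinates, show3Dpose below a (17, 3) array of
# world-space positions with the root (pelvis) at index 0.
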
def show3Dpose(vals, ax):
    ax.view_init(elev=15., azim=70)

    lcolor = (0, 0, 1)
    rcolor = (1, 0, 0)

    # bone list as index pairs (I[i] -> J[i]) over the 17 joints
    I = np.array([0, 0, 1, 4, 2, 5, 0, 7, 8, 8, 14, 15, 11, 12, 8, 9])
    J = np.array([1, 4, 2, 5, 3, 6, 7, 8, 14, 11, 15, 16, 12, 13, 9, 10])

    # per-bone colour flag: True -> lcolor, False -> rcolor
    LR = np.array([0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0], dtype=bool)

    for i in np.arange(len(I)):
        x, y, z = [np.array([vals[I[i], j], vals[J[i], j]]) for j in range(3)]
        ax.plot(x, y, z, lw=2, color=lcolor if LR[i] else rcolor)

    RADIUS = 0.72
    RADIUS_Z = 0.7

    # centre a fixed-size viewing volume on the root joint
    xroot, yroot, zroot = vals[0, 0], vals[0, 1], vals[0, 2]
    ax.set_xlim3d([-RADIUS + xroot, RADIUS + xroot])
    ax.set_ylim3d([-RADIUS + yroot, RADIUS + yroot])
    ax.set_zlim3d([-RADIUS_Z + zroot, RADIUS_Z + zroot])
    ax.set_aspect('auto')  # works fine in matplotlib==2.2.2

    # fully transparent axis panes (RGBA alpha = 0)
    white = (1.0, 1.0, 1.0, 0.0)
    ax.xaxis.set_pane_color(white)
    ax.yaxis.set_pane_color(white)
    ax.zaxis.set_pane_color(white)

    ax.tick_params('x', labelbottom=False)
    ax.tick_params('y', labelleft=False)
    ax.tick_params('z', labelleft=False)


def get_pose2D(video_path, output_dir):
    cap = cv2.VideoCapture(video_path)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)

    print('\nGenerating 2D pose...')
    # 'num_peroson' matches the parameter name in lib.hrnet.gen_kpts, misspelling included
    keypoints, scores = hrnet_pose(video_path, det_dim=416, num_peroson=1, gen_output=True)
    keypoints, scores, valid_frames = h36m_coco_format(keypoints, scores)
    # re_kpts is computed but the H36M-format keypoints (not the revised ones)
    # are what get saved below
    re_kpts = revise_kpts(keypoints, scores, valid_frames)
    print('2D pose generation successful!')

    output_dir += 'input_2D/'
    os.makedirs(output_dir, exist_ok=True)

    output_npz = output_dir + 'keypoints.npz'
    np.savez_compressed(output_npz, reconstruction=keypoints)

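# The archive written by get_pose2D round-trips as
#   np.load(output_npz, allow_pickle=True)['reconstruction']
# and has shape (num_person, num_frames, 17, 2) -> (1, F, 17, 2) for the
# single tracked person configured above.
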
def img2video(video_path, output_dir):
    cap = cv2.VideoCapture(video_path)
    # note: the +5 writes the demo at a slightly higher frame rate than the
    # source, so playback is a little faster
    fps = int(cap.get(cv2.CAP_PROP_FPS)) + 5

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")

    names = sorted(glob.glob(os.path.join(output_dir + 'pose/', '*.png')))
    img = cv2.imread(names[0])
    size = (img.shape[1], img.shape[0])

    # derive the clip name locally instead of relying on the global set in __main__
    video_name = os.path.basename(video_path).split('.')[0]
    videoWrite = cv2.VideoWriter(output_dir + video_name + '.mp4', fourcc, fps, size)

    for name in names:
        img = cv2.imread(name)
        videoWrite.write(img)

    videoWrite.release()


def showimage(ax, img):
    ax.set_xticks([])
    ax.set_yticks([])
    plt.axis('off')
    ax.imshow(img)


def get_pose3D(video_path, output_dir):
    # build an empty namespace and fill in the PoseFormerV2 hyper-parameters
    args, _ = argparse.ArgumentParser().parse_known_args()
    args.embed_dim_ratio, args.depth, args.frames = 32, 4, 243
    args.number_of_kept_frames, args.number_of_kept_coeffs = 27, 27
    args.pad = (args.frames - 1) // 2
    args.previous_dir = 'checkpoint/'
    args.n_joints, args.out_joints = 17, 17

    ## Reload
    cuda_available = torch.cuda.is_available()
    print(f"CUDA available in get_pose3D: {cuda_available}")
    if cuda_available:
        print(f"CUDA device count: {torch.cuda.device_count()}")
        print(f"CUDA device name: {torch.cuda.get_device_name(0)}")

    device = torch.device('cuda' if cuda_available else 'cpu')
    print(f"Using device: {device}")

    base_model = Model(args=args)

    # Always use DataParallel when CUDA is available (checkpoint expects it)
    if cuda_available:
        model = nn.DataParallel(base_model).to(device)
    else:
        model = base_model.to(device)

    model_dict = model.state_dict()
    # Put the pretrained model of PoseFormerV2 in 'checkpoint/'
    # model_path = sorted(glob.glob(os.path.join(args.previous_dir, '27_243_45.2.bin')))
    # Support both the local structure and the HF Spaces structure
    if os.path.exists("./demo/lib/checkpoint/27_243_45.2.bin"):
        model_path = "./demo/lib/checkpoint/27_243_45.2.bin"
    elif os.path.exists("./lib/checkpoint/27_243_45.2.bin"):
        model_path = "./lib/checkpoint/27_243_45.2.bin"
    else:
        model_path = "./checkpoint/27_243_45.2.bin"

    map_location = device
    pre_dict = torch.load(model_path, map_location=map_location, weights_only=False)

    # Handle DataParallel checkpoint mismatch
    state_dict = pre_dict['model_pos']
    from collections import OrderedDict
    new_state_dict = OrderedDict()

    # Check if we need to add or remove the "module." prefix
    checkpoint_has_module = any(k.startswith('module.') for k in state_dict.keys())
    model_has_module = isinstance(model, nn.DataParallel)

    if checkpoint_has_module and not model_has_module:
        # Remove the "module." prefix
        for k, v in state_dict.items():
            name = k[7:] if k.startswith('module.') else k
            new_state_dict[name] = v
    elif not checkpoint_has_module and model_has_module:
        # Add the "module." prefix
        for k, v in state_dict.items():
            name = 'module.' + k if not k.startswith('module.') else k
            new_state_dict[name] = v
    else:
        # No change needed
        new_state_dict = state_dict
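    # e.g. a DataParallel checkpoint key 'module.head.weight' maps to
    # 'head.weight' on a bare model, and vice versa (key names illustrative)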

    model.load_state_dict(new_state_dict, strict=True)

    model.eval()

    ## input
    keypoints = np.load(output_dir + 'input_2D/keypoints.npz', allow_pickle=True)['reconstruction']

    cap = cv2.VideoCapture(video_path)
    video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    ## 3D
    print('\nGenerating 3D pose...')
    keypoints_3D = []
    for i in tqdm(range(video_length)):
        ret, img = cap.read()
        if img is None:
            continue
        img_size = img.shape

        ## input frames: a 243-frame window centred on frame i
        start = max(0, i - args.pad)
        end = min(i + args.pad, len(keypoints[0]) - 1)

        input_2D_no = keypoints[0][start:end+1]

        # edge-pad the window wherever it runs off either end of the clip
        left_pad, right_pad = 0, 0
        if input_2D_no.shape[0] != args.frames:
            if i < args.pad:
                left_pad = args.pad - i
            if i > len(keypoints[0]) - args.pad - 1:
                right_pad = i + args.pad - (len(keypoints[0]) - 1)

            input_2D_no = np.pad(input_2D_no, ((left_pad, right_pad), (0, 0), (0, 0)), 'edge')

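        # e.g. with frames=243 (pad=121): at i=0 the window covers source
        # frames [0, 121] only, so 121 edge-copies are prepended to reach 243
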
        joints_left = [4, 5, 6, 11, 12, 13]
        joints_right = [1, 2, 3, 14, 15, 16]

        # input_2D_no += np.random.normal(loc=0.0, scale=5, size=input_2D_no.shape)
        input_2D = normalize_screen_coordinates(input_2D_no, w=img_size[1], h=img_size[0])

        # test-time flip augmentation: mirror x and swap left/right joints
        input_2D_aug = copy.deepcopy(input_2D)
        input_2D_aug[:, :, 0] *= -1
        input_2D_aug[:, joints_left + joints_right] = input_2D_aug[:, joints_right + joints_left]
        input_2D = np.concatenate((np.expand_dims(input_2D, axis=0), np.expand_dims(input_2D_aug, axis=0)), 0)
        # (2, 243, 17, 2)

        input_2D = input_2D[np.newaxis, :, :, :, :]

        input_2D = torch.from_numpy(input_2D.astype('float32')).to(device)

        N = input_2D.size(0)

        ## estimation
        with torch.no_grad():  # inference only; skip building an autograd graph
            output_3D_non_flip = model(input_2D[:, 0])
            output_3D_flip = model(input_2D[:, 1])
        # [1, 1, 17, 3]

        # undo the mirroring on the flipped prediction before averaging
        output_3D_flip[:, :, :, 0] *= -1
        output_3D_flip[:, :, joints_left + joints_right, :] = output_3D_flip[:, :, joints_right + joints_left, :]

        output_3D = (output_3D_non_flip + output_3D_flip) / 2
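        # i.e. output = (f(x) + unflip(f(flip(x)))) / 2, a standard test-time
        # flip augmentation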

        # root-relative output: zero the pelvis, then keep the camera-space
        # prediction for the saved .npz
        output_3D[:, :, 0, :] = 0
        post_out = output_3D[0, 0].cpu().detach().numpy()
        keypoints_3D.append(post_out)
        # print(f'Output 3D shape: {output_3D.shape}, post_out shape: {post_out.shape}, output 3D sample: {output_3D[0]}, post out sample: {post_out}')

        # fixed camera-to-world rotation (quaternion), then ground the pose so
        # its lowest joint sits at z = 0 for rendering
        rot = [0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088]
        rot = np.array(rot, dtype='float32')
        post_out = camera_to_world(post_out, R=rot, t=0)
        post_out[:, 2] -= np.min(post_out[:, 2])

        # centre frame of the padded window = the current frame
        input_2D_no = input_2D_no[args.pad]

        ## 2D
        image = show2Dpose(input_2D_no, copy.deepcopy(img))

        output_dir_2D = output_dir + 'pose2D/'
        os.makedirs(output_dir_2D, exist_ok=True)
        cv2.imwrite(output_dir_2D + str(('%04d'% i)) + '_2D.png', image)

        ## 3D
        fig = plt.figure(figsize=(9.6, 5.4))
        gs = gridspec.GridSpec(1, 1)
        gs.update(wspace=-0.00, hspace=0.05)
        ax = plt.subplot(gs[0], projection='3d')
        show3Dpose(post_out, ax)

        output_dir_3D = output_dir + 'pose3D/'
        os.makedirs(output_dir_3D, exist_ok=True)
        plt.savefig(output_dir_3D + str(('%04d'% i)) + '_3D.png', dpi=200, format='png', bbox_inches='tight')
        plt.clf()
        plt.close(fig)

    output_npz = output_dir + 'keypoints_3D.npz'
    np.savez_compressed(output_npz, reconstruction=keypoints_3D)
    print('3D pose generation successful!')

    ## all: stitch the saved 2D and 3D frames into side-by-side demo images
    image_dir = 'results/'  # unused
    image_2d_dir = sorted(glob.glob(os.path.join(output_dir_2D, '*.png')))
    image_3d_dir = sorted(glob.glob(os.path.join(output_dir_3D, '*.png')))

    print('\nGenerating demo...')
    for i in tqdm(range(len(image_2d_dir))):
        image_2d = plt.imread(image_2d_dir[i])
        image_3d = plt.imread(image_3d_dir[i])

        ## crop the wider 2D frame to a square
        edge = (image_2d.shape[1] - image_2d.shape[0]) // 2
        image_2d = image_2d[:, edge:image_2d.shape[1] - edge]

        ## crop a fixed 130-pixel border off the 3D render
        edge = 130
        image_3d = image_3d[edge:image_3d.shape[0] - edge, edge:image_3d.shape[1] - edge]

        ## show
        font_size = 12
        fig = plt.figure(figsize=(15.0, 5.4))
        ax = plt.subplot(121)
        showimage(ax, image_2d)
        ax.set_title("Input", fontsize=font_size)

        ax = plt.subplot(122)
        showimage(ax, image_3d)
        ax.set_title("Reconstruction", fontsize=font_size)

        ## save
        output_dir_pose = output_dir + 'pose/'
        os.makedirs(output_dir_pose, exist_ok=True)
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.savefig(output_dir_pose + str(('%04d'% i)) + '_pose.png', dpi=200, bbox_inches='tight')
        plt.clf()
        plt.close(fig)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--video', type=str, default='sample_video.mp4', help='input video')
    parser.add_argument('--gpu', type=str, default='0', help='GPU device ID (set CUDA_VISIBLE_DEVICES before running if needed)')
    args = parser.parse_args()

    # Note: CUDA_VISIBLE_DEVICES must be set BEFORE importing torch.
    # Since torch is imported at the top, setting it here won't work.
    # Set it in your environment before running: $env:CUDA_VISIBLE_DEVICES="0" (PowerShell) or export CUDA_VISIBLE_DEVICES=0 (bash)

    # Verify CUDA availability
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA device count: {torch.cuda.device_count()}")
        print(f"Current device: {torch.cuda.current_device()}")
        print(f"Device name: {torch.cuda.get_device_name(0)}")
        if "CUDA_VISIBLE_DEVICES" in os.environ:
            print(f"CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}")
    else:
        print("WARNING: CUDA is not available!")
        print("This might be because:")
        print("  1. CUDA_VISIBLE_DEVICES was set incorrectly")
        print("  2. PyTorch was installed without CUDA support")
        print("  3. GPU drivers are not installed")
        print("\nTo use the GPU, set CUDA_VISIBLE_DEVICES BEFORE running Python:")
        print("  PowerShell: $env:CUDA_VISIBLE_DEVICES='0'")
        print("  Bash: export CUDA_VISIBLE_DEVICES=0")
        print("\nOr don't set it at all to use the default GPU")

    video_path = './demo/video/' + args.video
    video_name = video_path.split('/')[-1].split('.')[0]
    output_dir = './demo/output/' + video_name + '/'

    get_pose2D(video_path, output_dir)
    get_pose3D(video_path, output_dir)
    img2video(video_path, output_dir)
    print('Demo generation successful!')
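
# Typical invocation (the defaults assume the repo's ./demo/video/ layout):
#   python vis.py --video sample_video.mp4
# Outputs land in ./demo/output/<video_name>/: input_2D/keypoints.npz,
# pose2D/, pose3D/, pose/, keypoints_3D.npz and <video_name>.mp4.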