linoyts HF Staff commited on
Commit
ed65396
·
verified ·
1 Parent(s): 11c1efe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +210 -53
app.py CHANGED
@@ -8,56 +8,137 @@ from transformers import AutoProcessor, VitPoseForPoseEstimation, RTDetrForObjec
8
  from PIL import Image
9
  import torch
10
 
 
 
 
11
 
12
- # COCO keypoint connections for skeleton visualization
13
- SKELETON_EDGES = [
14
  (0, 1), (0, 2), (1, 3), (2, 4), # head
15
  (5, 6), (5, 7), (7, 9), (6, 8), (8, 10), # arms
16
  (5, 11), (6, 12), (11, 12), # torso
17
  (11, 13), (13, 15), (12, 14), (14, 16) # legs
18
  ]
19
 
20
- KEYPOINT_COLORS = [
21
  (255, 0, 0), (255, 85, 0), (255, 170, 0), (255, 255, 0),
22
  (170, 255, 0), (85, 255, 0), (0, 255, 0), (0, 255, 85),
23
  (0, 255, 170), (0, 255, 255), (0, 170, 255), (0, 85, 255),
24
  (0, 0, 255), (85, 0, 255), (170, 0, 255), (255, 0, 255), (255, 0, 170)
25
  ]
26
 
27
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- # Load models
30
  print("Loading models...")
 
 
31
  person_detector = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365").to(device)
32
  person_processor = AutoProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
33
 
34
- vitpose_model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple").to(device)
35
- vitpose_processor = AutoProcessor.from_pretrained("usyd-community/vitpose-base-simple")
36
- print("Models loaded!")
 
 
 
 
 
 
37
 
 
38
 
39
- def detect_person(image, processor, model):
40
- """Detect person bounding box in image."""
41
- inputs = processor(images=image, return_tensors="pt").to(device)
42
  with torch.no_grad():
43
- outputs = model(**inputs)
44
 
45
- results = processor.post_process_object_detection(
46
  outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.3
47
  )
48
 
49
- # Find person detections (class 0 in COCO)
50
  boxes = []
51
  for result in results:
52
  for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
53
- if label.item() == 0: # person class
54
  boxes.append(box.cpu().numpy())
55
 
56
  return boxes if boxes else None
57
 
 
58
  @spaces.GPU
59
- def estimate_pose(image, boxes, processor, model):
60
- """Estimate pose keypoints for detected persons."""
 
 
 
 
 
 
 
61
  inputs = processor(images=image, boxes=[boxes], return_tensors="pt").to(device)
62
  with torch.no_grad():
63
  outputs = model(**inputs)
@@ -65,35 +146,90 @@ def estimate_pose(image, boxes, processor, model):
65
  pose_results = processor.post_process_pose_estimation(outputs, boxes=[boxes])
66
  return pose_results[0] if pose_results else None
67
 
 
68
 
69
- def draw_skeleton(frame, keypoints, scores, threshold=0.3):
70
- """Draw skeleton on a black background."""
71
  h, w = frame.shape[:2]
72
  skeleton_frame = np.zeros((h, w, 3), dtype=np.uint8)
73
 
74
  if keypoints is None:
75
- return skeleton_frame
76
 
77
- # Draw edges
78
- for start_idx, end_idx in SKELETON_EDGES:
79
- if scores[start_idx] > threshold and scores[end_idx] > threshold:
80
- start_point = (int(keypoints[start_idx][0]), int(keypoints[start_idx][1]))
81
- end_point = (int(keypoints[end_idx][0]), int(keypoints[end_idx][1]))
82
- color = KEYPOINT_COLORS[start_idx % len(KEYPOINT_COLORS)]
83
- cv2.line(skeleton_frame, start_point, end_point, color, 3)
 
 
 
84
 
85
- # Draw keypoints
86
  for i, (kp, score) in enumerate(zip(keypoints, scores)):
87
  if score > threshold:
88
  x, y = int(kp[0]), int(kp[1])
89
- color = KEYPOINT_COLORS[i % len(KEYPOINT_COLORS)]
90
  cv2.circle(skeleton_frame, (x, y), 6, color, -1)
91
  cv2.circle(skeleton_frame, (x, y), 8, (255, 255, 255), 2)
92
 
93
- return skeleton_frame
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  @spaces.GPU
96
- def process_video(video_path, progress=gr.Progress()):
97
  """Process video to extract skeleton frames and first frame."""
98
  if video_path is None:
99
  return None, None
@@ -104,7 +240,6 @@ def process_video(video_path, progress=gr.Progress()):
104
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
105
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
106
 
107
- # Temp output for skeleton video
108
  temp_dir = tempfile.mkdtemp()
109
  skeleton_path = os.path.join(temp_dir, "skeleton.mp4")
110
  first_frame_path = os.path.join(temp_dir, "first_frame.png")
@@ -115,6 +250,10 @@ def process_video(video_path, progress=gr.Progress()):
115
  first_frame_saved = False
116
  frame_idx = 0
117
 
 
 
 
 
118
  while True:
119
  ret, frame = cap.read()
120
  if not ret:
@@ -122,34 +261,25 @@ def process_video(video_path, progress=gr.Progress()):
122
 
123
  progress((frame_idx + 1) / total_frames, desc=f"Processing frame {frame_idx + 1}/{total_frames}")
124
 
125
- # Convert BGR to RGB for PIL
126
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
127
  pil_image = Image.fromarray(frame_rgb)
128
 
129
- # Detect person
130
- boxes = detect_person(pil_image, person_processor, person_detector)
 
 
131
 
132
  if boxes is not None:
133
- # Save first frame with person
134
- if not first_frame_saved:
135
- cv2.imwrite(first_frame_path, frame)
136
- first_frame_saved = True
137
-
138
- # Estimate pose
139
- pose_results = estimate_pose(pil_image, boxes, vitpose_processor, vitpose_model)
140
 
141
  if pose_results:
142
- # Use first person's pose
143
  keypoints = pose_results[0]["keypoints"].cpu().numpy()
144
  scores = pose_results[0]["scores"].cpu().numpy()
145
- skeleton_frame = draw_skeleton(frame, keypoints, scores)
146
- else:
147
- skeleton_frame = np.zeros((height, width, 3), dtype=np.uint8)
148
- else:
149
- skeleton_frame = np.zeros((height, width, 3), dtype=np.uint8)
150
- if not first_frame_saved:
151
- cv2.imwrite(first_frame_path, frame)
152
- first_frame_saved = True
153
 
154
  out.write(skeleton_frame)
155
  frame_idx += 1
@@ -157,24 +287,51 @@ def process_video(video_path, progress=gr.Progress()):
157
  cap.release()
158
  out.release()
159
 
 
 
 
 
 
 
 
 
160
  return skeleton_path, first_frame_path if first_frame_saved else None
161
 
 
162
 
163
  with gr.Blocks() as demo:
164
  gr.Markdown("## ViTPose Skeleton Extractor")
165
-
 
166
  with gr.Row():
167
  video_input = gr.Video(label="Input Video")
168
 
 
 
 
 
 
 
 
 
169
  process_btn = gr.Button("Extract Skeleton", variant="primary")
170
 
171
  with gr.Row():
172
  skeleton_output = gr.Video(label="Skeleton Frames", interactive=True)
173
  first_frame_output = gr.Image(label="First Frame (Reference)", interactive=True)
174
 
 
 
 
 
 
 
 
 
 
175
  process_btn.click(
176
  fn=process_video,
177
- inputs=video_input,
178
  outputs=[skeleton_output, first_frame_output]
179
  )
180
 
 
8
  from PIL import Image
9
  import torch
10
 
11
device = "cuda" if torch.cuda.is_available() else "cpu"

# ============== SKELETON DEFINITIONS ==============

# Simple model: 17 COCO keypoints.
# Each pair is (start_index, end_index) into the keypoint array.
SIMPLE_EDGES = [
    # head
    (0, 1), (0, 2), (1, 3), (2, 4),
    # arms
    (5, 6), (5, 7), (7, 9), (6, 8), (8, 10),
    # torso
    (5, 11), (6, 12), (11, 12),
    # legs
    (11, 13), (13, 15), (12, 14), (14, 16),
]

# One color per COCO keypoint index (cycled via modulo when indexing).
SIMPLE_COLORS = [
    (255, 0, 0), (255, 85, 0), (255, 170, 0), (255, 255, 0),
    (170, 255, 0), (85, 255, 0), (0, 255, 0), (0, 255, 85),
    (0, 255, 170), (0, 255, 255), (0, 170, 255), (0, 85, 255),
    (0, 0, 255), (85, 0, 255), (170, 0, 255), (255, 0, 255), (255, 0, 170),
]
29
 
30
# WholeBody model: 133 keypoints
# 0-16: body, 17-22: feet, 23-90: face, 91-111: left hand, 112-132: right hand

# Body connectivity mirrors the 17-point COCO skeleton.
BODY_EDGES = [
    (0, 1), (0, 2), (1, 3), (2, 4),
    (5, 6), (5, 7), (7, 9), (6, 8), (8, 10),
    (5, 11), (6, 12), (11, 12),
    (11, 13), (13, 15), (12, 14), (14, 16),
]

# Toe/heel chains hang off the ankles (15 = left ankle, 16 = right ankle).
FEET_EDGES = [
    (15, 17), (17, 18), (18, 19),
    (16, 20), (20, 21), (21, 22),
]
44
+
45
# Face edges: each facial feature is a chain of consecutive keypoint
# indices; eyes and lips are closed back into loops.
# Runs are (first_index, point_count, closed?).
_FACE_RUNS = [
    (23, 17, False),  # jaw
    (40, 5, False),   # left eyebrow
    (45, 5, False),   # right eyebrow
    (50, 4, False),   # nose bridge
    (54, 5, False),   # nose bottom
    (59, 6, True),    # left eye
    (65, 6, True),    # right eye
    (71, 12, True),   # outer lip
    (83, 8, True),    # inner lip
]

FACE_EDGES = []
for _start, _count, _closed in _FACE_RUNS:
    # Chain consecutive points, then optionally close last -> first.
    FACE_EDGES.extend((_start + j, _start + j + 1) for j in range(_count - 1))
    if _closed:
        FACE_EDGES.append((_start + _count - 1, _start))
69
+
70
def get_hand_edges(start_idx):
    """Build the 20 skeleton edges for one 21-keypoint hand.

    Keypoint layout relative to *start_idx*: offset 0 is the wrist,
    and each finger occupies four consecutive points starting at
    offsets 1, 5, 9, 13 and 17.

    Args:
        start_idx: Absolute index of the hand's wrist keypoint in the
            133-point wholebody layout.

    Returns:
        List of (start, end) index pairs: the five wrist-to-finger-base
        edges first, then three segments along each finger.
    """
    finger_bases = [1, 5, 9, 13, 17]
    edges = [(start_idx, start_idx + base) for base in finger_bases]
    for base in finger_bases:
        edges.extend(
            (start_idx + base + seg, start_idx + base + seg + 1)
            for seg in range(3)
        )
    return edges


LEFT_HAND_EDGES = get_hand_edges(91)    # left hand spans indices 91-111
RIGHT_HAND_EDGES = get_hand_edges(112)  # right hand spans indices 112-132

# Per-part colors used by the wholebody renderer.
WHOLEBODY_COLORS = {
    'body': (0, 255, 255),
    'face': (255, 255, 255),
    'left_hand': (0, 255, 0),
    'right_hand': (0, 0, 255),
    'feet': (255, 0, 255),
}
92
+
93
# ============== LOAD MODELS ==============

print("Loading models...")

# Person detector (shared by both pose pipelines).
person_detector = RTDetrForObjectDetection.from_pretrained(
    "PekingU/rtdetr_r50vd_coco_o365"
).to(device)
person_processor = AutoProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")

# Simple ViTPose (17 keypoints).
vitpose_simple = VitPoseForPoseEstimation.from_pretrained(
    "usyd-community/vitpose-base-simple"
).to(device)
vitpose_simple_processor = AutoProcessor.from_pretrained(
    "usyd-community/vitpose-base-simple"
)

# WholeBody ViTPose (133 keypoints).
vitpose_wholebody = VitPoseForPoseEstimation.from_pretrained(
    "usyd-community/vitpose-plus-huge-wholebody"
).to(device)
vitpose_wholebody_processor = AutoProcessor.from_pretrained(
    "usyd-community/vitpose-plus-huge-wholebody"
)

print("All models loaded!")
110
 
111
# ============== DETECTION & POSE FUNCTIONS ==============

def detect_person(image):
    """Detect person bounding boxes in *image*.

    Runs the module-level RT-DETR detector and keeps only detections
    whose label is 0 (COCO "person") at a fixed 0.3 score threshold.

    Args:
        image: PIL image; its ``height``/``width`` define the target
            size for box rescaling.

    Returns:
        List of numpy ``[x0, y0, x1, y1]`` boxes, or ``None`` when no
        person was detected.
    """
    inputs = person_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = person_detector(**inputs)

    target_sizes = torch.tensor([(image.height, image.width)])
    detections = person_processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.3
    )

    person_boxes = [
        box.cpu().numpy()
        for detection in detections
        for label, box in zip(detection["labels"], detection["boxes"])
        if label.item() == 0
    ]
    return person_boxes or None
130
 
131
+
132
@spaces.GPU
def estimate_pose(image, boxes, model_choice):
    """Estimate pose keypoints using selected model.

    Args:
        image: PIL image.
        boxes: Person boxes (from ``detect_person``) for this image.
        model_choice: ``"Simple (17 keypoints)"`` picks the 17-point
            model; any other value picks the 133-point wholebody model.

    Returns:
        Pose results for the single input image, or ``None`` when the
        post-processor returns nothing.
    """
    if model_choice == "Simple (17 keypoints)":
        processor, model = vitpose_simple_processor, vitpose_simple
    else:
        processor, model = vitpose_wholebody_processor, vitpose_wholebody

    inputs = processor(images=image, boxes=[boxes], return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    pose_results = processor.post_process_pose_estimation(outputs, boxes=[boxes])
    return pose_results[0] if pose_results else None
148
 
149
# ============== DRAWING FUNCTIONS ==============

def draw_simple_skeleton(frame, keypoints, scores, threshold=0.3):
    """Draw the 17-keypoint skeleton on a black canvas sized like *frame*.

    Args:
        frame: Source frame; only its height/width are used.
        keypoints: Sequence of (x, y) keypoint positions, or ``None``.
        scores: Per-keypoint confidence scores.
        threshold: Minimum confidence for a point/edge to be drawn.

    Returns:
        ``(canvas, has_valid)``: the rendered frame and whether at
        least one edge had both endpoints above the threshold.
    """
    h, w = frame.shape[:2]
    canvas = np.zeros((h, w, 3), dtype=np.uint8)

    if keypoints is None:
        return canvas, False

    def _point(idx):
        return int(keypoints[idx][0]), int(keypoints[idx][1])

    n = len(scores)
    has_valid = False
    for a, b in SIMPLE_EDGES:
        # Guard indices so shorter score arrays cannot raise.
        if a < n and b < n and scores[a] > threshold and scores[b] > threshold:
            edge_color = SIMPLE_COLORS[a % len(SIMPLE_COLORS)]
            cv2.line(canvas, _point(a), _point(b), edge_color, 3)
            has_valid = True

    for i, (kp, score) in enumerate(zip(keypoints, scores)):
        if score > threshold:
            center = (int(kp[0]), int(kp[1]))
            cv2.circle(canvas, center, 6, SIMPLE_COLORS[i % len(SIMPLE_COLORS)], -1)
            # White ring makes the dot readable on any edge color.
            cv2.circle(canvas, center, 8, (255, 255, 255), 2)

    return canvas, has_valid
178
+
179
+
180
def draw_edges(frame, keypoints, scores, edges, color, threshold=0.3, thickness=2):
    """Draw each (start, end) edge whose endpoints both clear *threshold*.

    Mutates *frame* in place; out-of-range indices are skipped.
    """
    n = len(scores)
    for a, b in edges:
        if a < n and b < n and scores[a] > threshold and scores[b] > threshold:
            p0 = (int(keypoints[a][0]), int(keypoints[a][1]))
            p1 = (int(keypoints[b][0]), int(keypoints[b][1]))
            cv2.line(frame, p0, p1, color, thickness)
188
+
189
+
190
def draw_keypoints(frame, keypoints, scores, indices, color, threshold=0.3, radius=3):
    """Draw a filled circle at every index in *indices* above *threshold*.

    Mutates *frame* in place; out-of-range indices are skipped.
    """
    n = len(scores)
    for idx in indices:
        if idx < n and scores[idx] > threshold:
            center = (int(keypoints[idx][0]), int(keypoints[idx][1]))
            cv2.circle(frame, center, radius, color, -1)
196
+
197
+
198
def draw_wholebody_skeleton(frame, keypoints, scores, threshold=0.3):
    """Draw the 133-keypoint wholebody skeleton, color-coded per part.

    Args:
        frame: Source frame; only its height/width are used.
        keypoints: Sequence of (x, y) keypoint positions, or ``None``.
        scores: Per-keypoint confidence scores.
        threshold: Minimum confidence for a point/edge to be drawn.

    Returns:
        ``(canvas, has_valid)``: the rendered black-background frame,
        and ``True`` when at least 5 of the 17 body keypoints clear
        the threshold.
    """
    h, w = frame.shape[:2]
    canvas = np.zeros((h, w, 3), dtype=np.uint8)

    if keypoints is None:
        return canvas, False

    # (edges, keypoint index range, color key, line thickness, dot radius).
    # A None index range means edges only — face dots are intentionally
    # skipped to keep the render readable.
    parts = [
        (BODY_EDGES, range(17), 'body', 3, 5),
        (FEET_EDGES, range(17, 23), 'feet', 2, 3),
        (FACE_EDGES, None, 'face', 1, 0),
        (LEFT_HAND_EDGES, range(91, 112), 'left_hand', 2, 2),
        (RIGHT_HAND_EDGES, range(112, 133), 'right_hand', 2, 2),
    ]
    for edges, indices, key, thickness, radius in parts:
        color = WHOLEBODY_COLORS[key]
        draw_edges(canvas, keypoints, scores, edges, color, threshold, thickness)
        if indices is not None:
            draw_keypoints(canvas, keypoints, scores, indices, color, threshold, radius)

    # "Person found" requires a minimally visible body, not just face/hands.
    body_scores = scores[:17] if len(scores) >= 17 else scores
    has_valid = np.sum(body_scores > threshold) >= 5

    return canvas, has_valid
228
+
229
+ # ============== MAIN PROCESSING ==============
230
 
231
  @spaces.GPU
232
+ def process_video(video_path, model_choice, progress=gr.Progress()):
233
  """Process video to extract skeleton frames and first frame."""
234
  if video_path is None:
235
  return None, None
 
240
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
241
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
242
 
 
243
  temp_dir = tempfile.mkdtemp()
244
  skeleton_path = os.path.join(temp_dir, "skeleton.mp4")
245
  first_frame_path = os.path.join(temp_dir, "first_frame.png")
 
250
  first_frame_saved = False
251
  frame_idx = 0
252
 
253
+ # Select drawing function based on model
254
+ use_wholebody = model_choice == "WholeBody (133 keypoints)"
255
+ draw_fn = draw_wholebody_skeleton if use_wholebody else draw_simple_skeleton
256
+
257
  while True:
258
  ret, frame = cap.read()
259
  if not ret:
 
261
 
262
  progress((frame_idx + 1) / total_frames, desc=f"Processing frame {frame_idx + 1}/{total_frames}")
263
 
 
264
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
265
  pil_image = Image.fromarray(frame_rgb)
266
 
267
+ boxes = detect_person(pil_image)
268
+
269
+ skeleton_frame = np.zeros((height, width, 3), dtype=np.uint8)
270
+ person_detected = False
271
 
272
  if boxes is not None:
273
+ pose_results = estimate_pose(pil_image, boxes, model_choice)
 
 
 
 
 
 
274
 
275
  if pose_results:
 
276
  keypoints = pose_results[0]["keypoints"].cpu().numpy()
277
  scores = pose_results[0]["scores"].cpu().numpy()
278
+ skeleton_frame, person_detected = draw_fn(frame, keypoints, scores)
279
+
280
+ if person_detected and not first_frame_saved:
281
+ cv2.imwrite(first_frame_path, frame)
282
+ first_frame_saved = True
 
 
 
283
 
284
  out.write(skeleton_frame)
285
  frame_idx += 1
 
287
  cap.release()
288
  out.release()
289
 
290
+ if not first_frame_saved:
291
+ cap = cv2.VideoCapture(video_path)
292
+ ret, frame = cap.read()
293
+ if ret:
294
+ cv2.imwrite(first_frame_path, frame)
295
+ first_frame_saved = True
296
+ cap.release()
297
+
298
  return skeleton_path, first_frame_path if first_frame_saved else None
299
 
300
# ============== GRADIO UI ==============

with gr.Blocks() as demo:
    gr.Markdown("## ViTPose Skeleton Extractor")
    gr.Markdown("Choose between fast 17-keypoint extraction or detailed 133-keypoint wholebody extraction")

    with gr.Row():
        video_input = gr.Video(label="Input Video")

    with gr.Row():
        # Choice strings must match the checks inside process_video /
        # estimate_pose exactly.
        model_choice = gr.Radio(
            choices=["Simple (17 keypoints)", "WholeBody (133 keypoints)"],
            value="Simple (17 keypoints)",
            label="Model Selection",
            info="WholeBody includes hands, face & feet but is slower",
        )

    process_btn = gr.Button("Extract Skeleton", variant="primary")

    with gr.Row():
        skeleton_output = gr.Video(label="Skeleton Frames", interactive=True)
        first_frame_output = gr.Image(label="First Frame (Reference)", interactive=True)

    with gr.Accordion("Color Legend (WholeBody mode)", open=False):
        gr.Markdown("""
        - 🔵 **Cyan**: Body (17 keypoints)
        - 🟣 **Magenta**: Feet (6 keypoints)
        - ⚪ **White**: Face (68 keypoints)
        - 🟢 **Green**: Left hand (21 keypoints)
        - 🔴 **Red**: Right hand (21 keypoints)
        """)

    process_btn.click(
        fn=process_video,
        inputs=[video_input, model_choice],
        outputs=[skeleton_output, first_frame_output],
    )
  )
337