Update app.py
app.py
CHANGED
(The old-version column of this side-by-side diff was garbled in extraction and most removed lines are truncated. What is recoverable: the previous app loaded a single ViTPose model, `estimate_pose` was `def estimate_pose(image, boxes, processor, model):` and was called as `estimate_pose(pil_image, boxes, vitpose_processor, vitpose_model)`, a single drawing function returned only the skeleton frame (`return skeleton_frame`), `process_video` was `def process_video(video_path, progress=gr.Progress()):`, the reference frame was saved as soon as a person box was detected, and the button was wired with `inputs=video_input`. The updated file follows; added lines are marked `+`.)
@@ -8,56 +8,137 @@ from transformers import AutoProcessor, VitPoseForPoseEstimation, RTDetrForObjectDetection
 from PIL import Image
 import torch
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# ============== SKELETON DEFINITIONS ==============
 
+# Simple model: 17 COCO keypoints
+SIMPLE_EDGES = [
     (0, 1), (0, 2), (1, 3), (2, 4),             # head
     (5, 6), (5, 7), (7, 9), (6, 8), (8, 10),    # arms
     (5, 11), (6, 12), (11, 12),                 # torso
     (11, 13), (13, 15), (12, 14), (14, 16)      # legs
 ]
 
+SIMPLE_COLORS = [
     (255, 0, 0), (255, 85, 0), (255, 170, 0), (255, 255, 0),
     (170, 255, 0), (85, 255, 0), (0, 255, 0), (0, 255, 85),
     (0, 255, 170), (0, 255, 255), (0, 170, 255), (0, 85, 255),
     (0, 0, 255), (85, 0, 255), (170, 0, 255), (255, 0, 255), (255, 0, 170)
 ]
 
+# WholeBody model: 133 keypoints
+# 0-16: body, 17-22: feet, 23-90: face, 91-111: left hand, 112-132: right hand
+
+BODY_EDGES = [
+    (0, 1), (0, 2), (1, 3), (2, 4),
+    (5, 6), (5, 7), (7, 9), (6, 8), (8, 10),
+    (5, 11), (6, 12), (11, 12),
+    (11, 13), (13, 15), (12, 14), (14, 16)
+]
+
+FEET_EDGES = [
+    (15, 17), (17, 18), (18, 19),
+    (16, 20), (20, 21), (21, 22)
+]
+
+# Face edges
+FACE_EDGES = []
+for i in range(16):   # Jaw
+    FACE_EDGES.append((23 + i, 23 + i + 1))
+for i in range(4):    # Left eyebrow
+    FACE_EDGES.append((40 + i, 40 + i + 1))
+for i in range(4):    # Right eyebrow
+    FACE_EDGES.append((45 + i, 45 + i + 1))
+for i in range(3):    # Nose bridge
+    FACE_EDGES.append((50 + i, 50 + i + 1))
+for i in range(4):    # Nose bottom
+    FACE_EDGES.append((54 + i, 54 + i + 1))
+for i in range(5):    # Left eye
+    FACE_EDGES.append((59 + i, 59 + i + 1))
+FACE_EDGES.append((64, 59))
+for i in range(5):    # Right eye
+    FACE_EDGES.append((65 + i, 65 + i + 1))
+FACE_EDGES.append((70, 65))
+for i in range(11):   # Outer lip
+    FACE_EDGES.append((71 + i, 71 + i + 1))
+FACE_EDGES.append((82, 71))
+for i in range(7):    # Inner lip
+    FACE_EDGES.append((83 + i, 83 + i + 1))
+FACE_EDGES.append((90, 83))
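+# (Tally from the loops above: jaw 17 + eyebrows 10 + nose 9 + eyes 12 + lips 20
+# = 68 face points, the standard 68-landmark layout, offset here by 23.)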
+
+def get_hand_edges(start_idx):
+    edges = []
+    edges.append((start_idx, start_idx + 1))
+    edges.append((start_idx, start_idx + 5))
+    edges.append((start_idx, start_idx + 9))
+    edges.append((start_idx, start_idx + 13))
+    edges.append((start_idx, start_idx + 17))
+    for finger_start in [1, 5, 9, 13, 17]:
+        for i in range(3):
+            edges.append((start_idx + finger_start + i, start_idx + finger_start + i + 1))
+    return edges
+
+LEFT_HAND_EDGES = get_hand_edges(91)
+RIGHT_HAND_EDGES = get_hand_edges(112)
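+# (Each hand covers 21 keypoints: the wrist at start_idx plus four-joint chains
+# at offsets 1-4, 5-8, 9-12, 13-16, and 17-20, one per finger.)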
+
+WHOLEBODY_COLORS = {
+    'body': (0, 255, 255),
+    'face': (255, 255, 255),
+    'left_hand': (0, 255, 0),
+    'right_hand': (0, 0, 255),
+    'feet': (255, 0, 255),
+}
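+# (These tuples are in OpenCV's BGR channel order, so 'body' (0, 255, 255)
+# renders yellow and 'right_hand' (0, 0, 255) renders red.)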
+
+# ============== LOAD MODELS ==============
 
 print("Loading models...")
+
+# Person detector (shared)
 person_detector = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365").to(device)
 person_processor = AutoProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
 
+# Simple ViTPose (17 keypoints)
+vitpose_simple = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple").to(device)
+vitpose_simple_processor = AutoProcessor.from_pretrained("usyd-community/vitpose-base-simple")
+
+# WholeBody ViTPose (133 keypoints)
+vitpose_wholebody = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-plus-huge-wholebody").to(device)
+vitpose_wholebody_processor = AutoProcessor.from_pretrained("usyd-community/vitpose-plus-huge-wholebody")
+
+print("All models loaded!")
 
+# ============== DETECTION & POSE FUNCTIONS ==============
 
+def detect_person(image):
+    """Detect person bounding boxes in image."""
+    inputs = person_processor(images=image, return_tensors="pt").to(device)
     with torch.no_grad():
+        outputs = person_detector(**inputs)
 
+    results = person_processor.post_process_object_detection(
         outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.3
     )
 
     boxes = []
     for result in results:
         for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
+            if label.item() == 0:  # class 0 = person in COCO
                 boxes.append(box.cpu().numpy())
 
     return boxes if boxes else None
 
+
 @spaces.GPU
+def estimate_pose(image, boxes, model_choice):
+    """Estimate pose keypoints using selected model."""
+    if model_choice == "Simple (17 keypoints)":
+        processor = vitpose_simple_processor
+        model = vitpose_simple
+    else:
+        processor = vitpose_wholebody_processor
+        model = vitpose_wholebody
+
     inputs = processor(images=image, boxes=[boxes], return_tensors="pt").to(device)
     with torch.no_grad():
         outputs = model(**inputs)
@@ -65,35 +146,90 @@ def estimate_pose(image, boxes, processor, model):
     pose_results = processor.post_process_pose_estimation(outputs, boxes=[boxes])
     return pose_results[0] if pose_results else None
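+# (estimate_pose returns a per-person list; callers read entry["keypoints"] and
+# entry["scores"], as process_video does below for the first person.)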
 
+# ============== DRAWING FUNCTIONS ==============
 
+def draw_simple_skeleton(frame, keypoints, scores, threshold=0.3):
+    """Draw 17-keypoint skeleton."""
     h, w = frame.shape[:2]
     skeleton_frame = np.zeros((h, w, 3), dtype=np.uint8)
 
     if keypoints is None:
+        return skeleton_frame, False
 
+    has_valid = False
+
+    for start_idx, end_idx in SIMPLE_EDGES:
+        if start_idx < len(scores) and end_idx < len(scores):
+            if scores[start_idx] > threshold and scores[end_idx] > threshold:
+                start_point = (int(keypoints[start_idx][0]), int(keypoints[start_idx][1]))
+                end_point = (int(keypoints[end_idx][0]), int(keypoints[end_idx][1]))
+                color = SIMPLE_COLORS[start_idx % len(SIMPLE_COLORS)]
+                cv2.line(skeleton_frame, start_point, end_point, color, 3)
+                has_valid = True
 
     for i, (kp, score) in enumerate(zip(keypoints, scores)):
         if score > threshold:
             x, y = int(kp[0]), int(kp[1])
+            color = SIMPLE_COLORS[i % len(SIMPLE_COLORS)]
             cv2.circle(skeleton_frame, (x, y), 6, color, -1)
             cv2.circle(skeleton_frame, (x, y), 8, (255, 255, 255), 2)
 
+    return skeleton_frame, has_valid
+
+
+def draw_edges(frame, keypoints, scores, edges, color, threshold=0.3, thickness=2):
+    """Draw edges for a set of connections."""
+    for start_idx, end_idx in edges:
+        if start_idx < len(scores) and end_idx < len(scores):
+            if scores[start_idx] > threshold and scores[end_idx] > threshold:
+                start_point = (int(keypoints[start_idx][0]), int(keypoints[start_idx][1]))
+                end_point = (int(keypoints[end_idx][0]), int(keypoints[end_idx][1]))
+                cv2.line(frame, start_point, end_point, color, thickness)
+
+
+def draw_keypoints(frame, keypoints, scores, indices, color, threshold=0.3, radius=3):
+    """Draw keypoints for specified indices."""
+    for idx in indices:
+        if idx < len(scores) and scores[idx] > threshold:
+            x, y = int(keypoints[idx][0]), int(keypoints[idx][1])
+            cv2.circle(frame, (x, y), radius, color, -1)
+
+
+def draw_wholebody_skeleton(frame, keypoints, scores, threshold=0.3):
+    """Draw 133-keypoint skeleton with color coding."""
+    h, w = frame.shape[:2]
+    skeleton_frame = np.zeros((h, w, 3), dtype=np.uint8)
+
+    if keypoints is None:
+        return skeleton_frame, False
+
+    # Body
+    draw_edges(skeleton_frame, keypoints, scores, BODY_EDGES, WHOLEBODY_COLORS['body'], threshold, 3)
+    draw_keypoints(skeleton_frame, keypoints, scores, range(17), WHOLEBODY_COLORS['body'], threshold, 5)
+
+    # Feet
+    draw_edges(skeleton_frame, keypoints, scores, FEET_EDGES, WHOLEBODY_COLORS['feet'], threshold, 2)
+    draw_keypoints(skeleton_frame, keypoints, scores, range(17, 23), WHOLEBODY_COLORS['feet'], threshold, 3)
+
+    # Face
+    draw_edges(skeleton_frame, keypoints, scores, FACE_EDGES, WHOLEBODY_COLORS['face'], threshold, 1)
+
+    # Hands
+    draw_edges(skeleton_frame, keypoints, scores, LEFT_HAND_EDGES, WHOLEBODY_COLORS['left_hand'], threshold, 2)
+    draw_keypoints(skeleton_frame, keypoints, scores, range(91, 112), WHOLEBODY_COLORS['left_hand'], threshold, 2)
+
+    draw_edges(skeleton_frame, keypoints, scores, RIGHT_HAND_EDGES, WHOLEBODY_COLORS['right_hand'], threshold, 2)
+    draw_keypoints(skeleton_frame, keypoints, scores, range(112, 133), WHOLEBODY_COLORS['right_hand'], threshold, 2)
+
+    # Require at least 5 confident body keypoints before counting this as a person
+    body_scores = scores[:17] if len(scores) >= 17 else scores
+    has_valid = np.sum(body_scores > threshold) >= 5
+
+    return skeleton_frame, has_valid
+
+# ============== MAIN PROCESSING ==============
 
 @spaces.GPU
+def process_video(video_path, model_choice, progress=gr.Progress()):
     """Process video to extract skeleton frames and first frame."""
     if video_path is None:
         return None, None
@@ -104,7 +240,6 @@ def process_video(video_path, progress=gr.Progress()):
     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
     temp_dir = tempfile.mkdtemp()
     skeleton_path = os.path.join(temp_dir, "skeleton.mp4")
     first_frame_path = os.path.join(temp_dir, "first_frame.png")
@@ -115,6 +250,10 @@ def process_video(video_path, progress=gr.Progress()):
     first_frame_saved = False
     frame_idx = 0
 
+    # Select drawing function based on model
+    use_wholebody = model_choice == "WholeBody (133 keypoints)"
+    draw_fn = draw_wholebody_skeleton if use_wholebody else draw_simple_skeleton
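+    # (Both drawers share one contract: (frame, keypoints, scores) ->
+    # (skeleton_frame, has_valid), which is what makes this dispatch safe.)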
+
     while True:
         ret, frame = cap.read()
         if not ret:
@@ -122,34 +261,25 @@ def process_video(video_path, progress=gr.Progress()):
 
         progress((frame_idx + 1) / total_frames, desc=f"Processing frame {frame_idx + 1}/{total_frames}")
 
         frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
         pil_image = Image.fromarray(frame_rgb)
 
+        boxes = detect_person(pil_image)
+
+        skeleton_frame = np.zeros((height, width, 3), dtype=np.uint8)
+        person_detected = False
 
         if boxes is not None:
+            pose_results = estimate_pose(pil_image, boxes, model_choice)
 
             if pose_results:
                 keypoints = pose_results[0]["keypoints"].cpu().numpy()
                 scores = pose_results[0]["scores"].cpu().numpy()
+                skeleton_frame, person_detected = draw_fn(frame, keypoints, scores)
+
+                if person_detected and not first_frame_saved:
+                    cv2.imwrite(first_frame_path, frame)
+                    first_frame_saved = True
 
         out.write(skeleton_frame)
         frame_idx += 1
@@ -157,24 +287,51 @@ def process_video(video_path, progress=gr.Progress()):
     cap.release()
     out.release()
 
+    if not first_frame_saved:
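+        # (Fallback: no confident person in any frame, so reuse the raw first
+        # frame as the reference image.)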
+        cap = cv2.VideoCapture(video_path)
+        ret, frame = cap.read()
+        if ret:
+            cv2.imwrite(first_frame_path, frame)
+            first_frame_saved = True
+        cap.release()
+
     return skeleton_path, first_frame_path if first_frame_saved else None
 
+# ============== GRADIO UI ==============
 
 with gr.Blocks() as demo:
     gr.Markdown("## ViTPose Skeleton Extractor")
+    gr.Markdown("Choose between fast 17-keypoint extraction and detailed 133-keypoint whole-body extraction.")
+
     with gr.Row():
        video_input = gr.Video(label="Input Video")
 
+    with gr.Row():
+        model_choice = gr.Radio(
+            choices=["Simple (17 keypoints)", "WholeBody (133 keypoints)"],
+            value="Simple (17 keypoints)",
+            label="Model Selection",
+            info="WholeBody includes hands, face & feet but is slower"
+        )
+
     process_btn = gr.Button("Extract Skeleton", variant="primary")
 
     with gr.Row():
         skeleton_output = gr.Video(label="Skeleton Frames", interactive=True)
         first_frame_output = gr.Image(label="First Frame (Reference)", interactive=True)
 
+    with gr.Accordion("Color Legend (WholeBody mode)", open=False):
+        gr.Markdown("""
+        - 🟡 **Yellow**: Body (17 keypoints)
+        - 🟣 **Magenta**: Feet (6 keypoints)
+        - ⚪ **White**: Face (68 keypoints)
+        - 🟢 **Green**: Left hand (21 keypoints)
+        - 🔴 **Red**: Right hand (21 keypoints)
+        """)
+
     process_btn.click(
         fn=process_video,
+        inputs=[video_input, model_choice],
         outputs=[skeleton_output, first_frame_output]
     )