Spaces:
Runtime error
Runtime error
Create collect_data.py
Browse files- collect_data.py +113 -0
collect_data.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# collect_data.py
|
| 2 |
+
import cv2
|
| 3 |
+
import mediapipe as mp
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
from collections import deque
|
| 8 |
+
|
| 9 |
+
# --- Configuration ---
DATA_DIR = "gesture_data"  # root folder; one sub-directory per gesture label
SEQUENCE_LENGTH = 30  # number of frames per sample
EXAMPLES_PER_LABEL = 50  # collection target per label
LABELS = ["air_lock", "swipe_left", "swipe_right", "circle", "hug"]  # update as needed

# Ensure the dataset root exists before any capture starts.
os.makedirs(DATA_DIR, exist_ok=True)

# MediaPipe hand-tracking solution and its landmark-drawing utilities.
mp_hands = mp.solutions.hands
mp_drawing = mp.solutions.drawing_utils
|
| 19 |
+
|
| 20 |
+
def extract_landmarks(hand_landmarks):
    """Flatten one detected hand into a (63,) float32 vector of (x, y, z).

    Args:
        hand_landmarks: a MediaPipe hand-landmark object exposing a
            ``.landmark`` sequence of 21 points (each with ``x``/``y``/``z``
            normalized coordinates), or ``None`` when no hand was detected.

    Returns:
        np.ndarray of shape (21 * 3,), dtype float32. All zeros when the
        hand is missing, so recorded sequences keep a fixed feature width.
    """
    if hand_landmarks is None:
        return np.zeros(21 * 3, dtype=np.float32)
    # Single flat pass instead of a manual accumulate-and-extend loop.
    return np.fromiter(
        (c for lm in hand_landmarks.landmark for c in (lm.x, lm.y, lm.z)),
        dtype=np.float32,
    )
| 28 |
+
|
| 29 |
+
def _process_frame(frame, hands):
    """Detect a hand in a BGR frame, draw it in place, return its feature vector.

    Runs MediaPipe on an RGB copy of *frame*; if at least one hand is found,
    the first hand's landmarks are drawn onto *frame* (side effect for the
    preview window) and flattened via extract_landmarks. Returns a zero
    vector when no hand is detected.
    """
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    res = hands.process(rgb)
    if res.multi_hand_landmarks:
        first = res.multi_hand_landmarks[0]
        mp_drawing.draw_landmarks(frame, first, mp_hands.HAND_CONNECTIONS)
        return extract_landmarks(first)
    return extract_landmarks(None)


def capture_label_sequence(label, cap, hands):
    """Interactively capture ONE example sequence for *label* from the webcam.

    Shows a live preview with landmark overlay. Pressing 'r' records
    SEQUENCE_LENGTH consecutive frames of landmark vectors and saves them as
    DATA_DIR/<label>/NNNN.npz (key "data", shape (SEQUENCE_LENGTH, 63)).
    Pressing 'q' aborts.

    Args:
        label: gesture name; also the sub-directory the sample is saved under.
        cap: an opened cv2.VideoCapture.
        hands: an initialized mp_hands.Hands context.

    Returns:
        True once a full sequence was saved, False when the user quit.

    Raises:
        RuntimeError: if the webcam stops delivering frames while idle.
    """
    seq = deque(maxlen=SEQUENCE_LENGTH)
    print(f"Prepare to record label: {label}. Press 'r' to start recording one example.")
    while True:
        ret, frame = cap.read()
        if not ret:
            raise RuntimeError("Failed reading webcam")
        # Detection result is only used for the landmark overlay here;
        # vectors are collected in the recording loop below.
        _process_frame(frame, hands)
        cv2.putText(frame, f"Label: {label} | Press 'r' start, 'q' quit", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
        cv2.imshow("Collect Gestures", frame)
        k = cv2.waitKey(1) & 0xFF
        if k == ord('r'):
            seq.clear()
            print("Recording...")
            t0 = time.time()
            while len(seq) < SEQUENCE_LENGTH:
                ret, f2 = cap.read()
                if not ret:
                    # Camera dropped mid-recording; the partial sequence is
                    # discarded by the length check below.
                    break
                seq.append(_process_frame(f2, hands))
                cv2.putText(f2, f"Recording... {len(seq)}/{SEQUENCE_LENGTH}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
                cv2.imshow("Collect Gestures", f2)
                cv2.waitKey(1)
            t1 = time.time()
            print(f"Finished recording (took {t1-t0:.2f}s).")
            if len(seq) == SEQUENCE_LENGTH:
                arr = np.stack(seq, axis=0)  # (SEQUENCE_LENGTH, 63)
                label_dir = os.path.join(DATA_DIR, label)
                os.makedirs(label_dir, exist_ok=True)
                # Probe for the first unused index: len(listdir) alone would
                # overwrite an existing file if earlier samples were deleted.
                idx = len(os.listdir(label_dir))
                fname = os.path.join(label_dir, f"{idx:04d}.npz")
                while os.path.exists(fname):
                    idx += 1
                    fname = os.path.join(label_dir, f"{idx:04d}.npz")
                np.savez_compressed(fname, data=arr)
                print(f"Saved {fname}")
                return True
        elif k == ord('q'):
            return False
|
| 86 |
+
|
| 87 |
+
def main():
    """Collect EXAMPLES_PER_LABEL gesture sequences for every label in LABELS.

    Opens the default webcam, then for each label repeatedly invokes
    capture_label_sequence until its directory under DATA_DIR holds the
    target number of samples. A user quit ('q') stops collection early.

    Raises:
        RuntimeError: if the webcam cannot be opened.
    """
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        raise RuntimeError("Cannot open webcam")
    try:
        with mp_hands.Hands(static_image_mode=False,
                            max_num_hands=1,
                            min_detection_confidence=0.5,
                            min_tracking_confidence=0.5) as hands:
            for label in LABELS:
                label_dir = os.path.join(DATA_DIR, label)
                os.makedirs(label_dir, exist_ok=True)
                # Resume from whatever is already on disk for this label.
                cur = len(os.listdir(label_dir))
                print(f"Label '{label}' currently has {cur} examples. Target: {EXAMPLES_PER_LABEL}")
                while cur < EXAMPLES_PER_LABEL:
                    ok = capture_label_sequence(label, cap, hands)
                    if not ok:
                        print("User requested quit.")
                        return  # camera cleanup handled by finally
                    cur = len(os.listdir(label_dir))
                    print(f"Now {cur}/{EXAMPLES_PER_LABEL} for label '{label}'")
    finally:
        # Release the camera and close windows on every exit path,
        # including exceptions and the early quit above.
        cap.release()
        cv2.destroyAllWindows()
|
| 111 |
+
|
| 112 |
+
# Run interactive data collection only when executed as a script,
# not when imported for its helpers.
if __name__ == "__main__":
    main()
|