Update app.py
app.py CHANGED
@@ -193,95 +193,101 @@ holistic = mp_holistic.Holistic(
 # ----------------------------
 # Gradio inference with state
 # ----------------------------
-def run(frame, sequence_state):
-    """
-    frame: numpy array from webcam (RGB)
-    sequence_state: list of last keypoint vectors
-    returns: annotated_frame (RGB), label dict, updated sequence_state
-    """
-    if frame is None:
-        return None, {"(waiting for camera)": 1.0}, (sequence_state or [])
-
-    if sequence_state is None:
-        sequence_state = []
-
-    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
-
-    image_bgr, results = mediapipe_detection(frame_bgr, holistic)
-
-    keypoints = extract_keypoints(results)
-    sequence_state.append(keypoints)
-    sequence_state = sequence_state[-SEQ_LEN:]
-
-    probs_dict = {}
-    pred_text = "Waiting..."
-
-    if not hands_present:
-        pred_text = "No hands detected"
-    elif len(sequence_state) == SEQ_LEN:
-        x = torch.tensor(np.expand_dims(sequence_state, axis=0), dtype=torch.float32)
-        with torch.no_grad():
-            logits = model(x)
-            probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
-
-    out_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
+
+# Make landmarks thicker & clearer
+HAND_LANDMARK_STYLE = mp_drawing.DrawingSpec(color=(0, 255, 0), thickness=4, circle_radius=6)
+HAND_CONNECTION_STYLE = mp_drawing.DrawingSpec(color=(255, 0, 0), thickness=4)
+
+POSE_LANDMARK_STYLE = mp_drawing.DrawingSpec(color=(0, 255, 255), thickness=3, circle_radius=5)
+POSE_CONNECTION_STYLE = mp_drawing.DrawingSpec(color=(0, 128, 255), thickness=3)
+
+def draw_styled_landmarks(image, results):
+    # Pose (optional; comment out if you want faster)
+    if results.pose_landmarks:
+        mp_drawing.draw_landmarks(
+            image,
+            results.pose_landmarks,
+            mp_holistic.POSE_CONNECTIONS,
+            POSE_LANDMARK_STYLE,
+            POSE_CONNECTION_STYLE,
+        )
+
+    # Hands
+    if results.left_hand_landmarks:
+        mp_drawing.draw_landmarks(
+            image,
+            results.left_hand_landmarks,
+            mp_holistic.HAND_CONNECTIONS,
+            HAND_LANDMARK_STYLE,
+            HAND_CONNECTION_STYLE,
+        )
+
+    if results.right_hand_landmarks:
+        mp_drawing.draw_landmarks(
+            image,
+            results.right_hand_landmarks,
+            mp_holistic.HAND_CONNECTIONS,
+            HAND_LANDMARK_STYLE,
+            HAND_CONNECTION_STYLE,
+        )
+
+def run(frame, state):
+    """
+    One-screen mode:
+    - input: webcam frame
+    - output: the same frame with tracking overlay (and optional top-1 text)
+    state: dict that keeps sequence + last text + frame counter
+    """
+    if frame is None:
+        return None, state
+
+    if state is None:
+        state = {"seq": [], "t": 0, "text": "Warming up..."}
+
+    state["t"] += 1
+
-    # Gradio gives RGB; MediaPipe
+    # Gradio gives RGB; MediaPipe expects RGB internally but our helper uses cv2 BGR conversions
     frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

+    # OPTIONAL speed boost: smaller frame (uncomment if needed)
+    # frame_bgr = cv2.resize(frame_bgr, (640, 360))
+
     image_bgr, results = mediapipe_detection(frame_bgr, holistic)
+
+    # Draw overlay landmarks on the SAME frame
     draw_styled_landmarks(image_bgr, results)

+    # Build sequence for your model
     keypoints = extract_keypoints(results)
-    probs_dict = {}
-    pred_text = "Waiting..."
-    conf = 0.0
+    state["seq"].append(keypoints)
+    state["seq"] = state["seq"][-SEQ_LEN:]

     hands_present = (results.left_hand_landmarks is not None) or (results.right_hand_landmarks is not None)

     if not hands_present:
-
-    elif len(
+        state["text"] = "No hands detected"
+    elif len(state["seq"]) < SEQ_LEN:
+        state["text"] = f"Warming up... {len(state['seq'])}/{SEQ_LEN}"
+    else:
+        # Run model less frequently to reduce lag on CPU
+        PRED_EVERY = 3
+        if state["t"] % PRED_EVERY == 0:
+            x = torch.tensor(np.expand_dims(state["seq"], axis=0), dtype=torch.float32)
+            with torch.no_grad():
+                logits = model(x)
+                probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
+
+            top_idx = int(np.argmax(probs))
+            conf = float(probs[top_idx])
+            state["text"] = f"{LABELS[top_idx]} ({conf:.0%})"
+        # else: keep last state["text"] so it looks stable
+
+    # Optional: overlay top-1 text on the same screen
+    cv2.rectangle(image_bgr, (0, 0), (640, 45), (0, 0, 0), -1)
     cv2.putText(
         image_bgr,
-
+        state["text"],
         (10, 30),
         cv2.FONT_HERSHEY_SIMPLEX,
         0.9,

@@ -290,32 +296,26 @@ def run(frame, sequence_state):
             cv2.LINE_AA
         )

-    # Back to RGB for Gradio display
     out_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

-    # If probs_dict is empty (e.g., still warming up), show something stable
-    if not probs_dict:
-        probs_dict = {"(warming up)": 1.0}
-
-    return out_rgb, probs_dict, sequence_state
+    return out_rgb, state

+# ----------------------------
+# One-screen Gradio UI
+# ----------------------------
 with gr.Blocks() as demo:
-    gr.Markdown("# Live
-    gr.Markdown("
-
-    seq_state = gr.State([])
-
-    cam = gr.Image(sources=["webcam"], type="numpy", label="Webcam")
-    out_img = gr.Image(type="numpy", label="Output (Annotated)")
-
-    # Stream
+    gr.Markdown("# Live Hand Gesture Tracking (Single Screen)")
+    gr.Markdown("Webcam shows the tracking overlay directly. Prediction text appears on the same screen.")
+
+    st = gr.State(None)
+
+    cam = gr.Image(sources=["webcam"], type="numpy", label="Webcam (Overlay)")
+
+    # Stream: output back into the SAME webcam component
     cam.stream(
         fn=run,
-        inputs=[cam,
-        outputs=[
+        inputs=[cam, st],
+        outputs=[cam, st],
     )

 if __name__ == "__main__":
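The hunks above call two helpers defined earlier in app.py that this diff does not touch: mediapipe_detection and extract_keypoints. A minimal sketch of the usual MediaPipe Holistic versions, assuming the common pose-plus-two-hands keypoint layout (the real helpers in this Space may differ):

import cv2
import numpy as np

def mediapipe_detection(image_bgr, model):
    # MediaPipe wants RGB; convert, run the graph, convert back
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    image_rgb.flags.writeable = False   # small speed win; MediaPipe only reads it
    results = model.process(image_rgb)
    image_rgb.flags.writeable = True
    return cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR), results

def extract_keypoints(results):
    # Flatten pose (33 x 4) and each hand (21 x 3) into one vector; zeros when absent
    pose = (np.array([[p.x, p.y, p.z, p.visibility] for p in results.pose_landmarks.landmark]).flatten()
            if results.pose_landmarks else np.zeros(33 * 4))
    lh = (np.array([[p.x, p.y, p.z] for p in results.left_hand_landmarks.landmark]).flatten()
          if results.left_hand_landmarks else np.zeros(21 * 3))
    rh = (np.array([[p.x, p.y, p.z] for p in results.right_hand_landmarks.landmark]).flatten()
          if results.right_hand_landmarks else np.zeros(21 * 3))
    return np.concatenate([pose, lh, rh])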
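For reference, np.expand_dims(state["seq"], axis=0) in the new run() turns the keypoint window into the (batch, time, features) tensor that model(x) consumes. A quick shape check, with SEQ_LEN and the feature width as assumed placeholder values:

import numpy as np
import torch

SEQ_LEN = 30       # assumed; must match the window length the model was trained on
FEATURE_DIM = 258  # assumed: 33*4 pose + 2 * 21*3 hands, per the helper sketch above

seq = [np.zeros(FEATURE_DIM, dtype=np.float32) for _ in range(SEQ_LEN)]
x = torch.tensor(np.expand_dims(seq, axis=0), dtype=torch.float32)
print(x.shape)  # torch.Size([1, 30, 258])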
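The UI change is what makes this single-screen: cam.stream feeds each webcam frame through run and writes the annotated result back into the same gr.Image component. A stripped-down sketch of just that wiring, with a pass-through in place of the real pipeline:

import gradio as gr

def passthrough(frame, state):
    # stand-in for run(): return the (optionally annotated) frame plus state
    return frame, state

with gr.Blocks() as demo:
    st = gr.State(None)
    cam = gr.Image(sources=["webcam"], type="numpy", label="Webcam (Overlay)")
    # stream fires on every new frame; writing to `cam` replaces the live preview
    cam.stream(fn=passthrough, inputs=[cam, st], outputs=[cam, st])

if __name__ == "__main__":
    demo.launch()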