# Hugging Face Space app: sign/action recognition demo (Gradio + MediaPipe + Keras).
import gradio as gr
import numpy as np
import cv2
import mediapipe as mp
import tensorflow as tf
import os
import json

# --- Labels -------------------------------------------------------------
# Class order must exactly match training; labels.json is pushed alongside
# the model for that reason (index i == model output unit i).
with open("labels.json", "r", encoding="utf-8") as f:
    class_names = json.load(f)  # list of class-name strings

print(f"Loaded {len(class_names)} classes:")
for i, label in enumerate(class_names):
    print(f"{i}: {label}")

# --- Model --------------------------------------------------------------
model = tf.keras.models.load_model("action_final.h5")

# --- MediaPipe Holistic -------------------------------------------------
# NOTE: a separate mp.solutions.hands.Hands instance used to be created here
# as well, but it was never used (all inference goes through Holistic), so
# the redundant model load has been removed.
mp_holistic = mp.solutions.holistic
holistic = mp_holistic.Holistic(
    static_image_mode=False,   # video stream: allow landmark tracking across frames
    model_complexity=1,
    smooth_landmarks=True,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
)

# --- Inference constants (must match training) ---------------------------
SEQUENCE_LENGTH = 30   # frames per model input window
FEATURE_SIZE = 258     # pose 33*4 + left hand 21*3 + right hand 21*3
def extract_keypoints_inference(results):
    """
    Flatten MediaPipe Holistic landmarks into one fixed-length feature vector.

    Matches the exact concatenation order used in training:
    pose(33*4) -> left_hand(21*3) -> right_hand(21*3) = 258 values.
    Any landmark group that was not detected (attribute is None) is
    zero-filled so the output length is always constant.

    Args:
        results: MediaPipe Holistic result object exposing
            ``pose_landmarks``, ``left_hand_landmarks`` and
            ``right_hand_landmarks`` (each may be None).

    Returns:
        np.ndarray of shape (258,).
    """
    pose = (
        np.array(
            [[lm.x, lm.y, lm.z, lm.visibility] for lm in results.pose_landmarks.landmark]
        ).flatten()
        if results.pose_landmarks
        else np.zeros(33 * 4)
    )
    lh = (
        np.array(
            [[lm.x, lm.y, lm.z] for lm in results.left_hand_landmarks.landmark]
        ).flatten()
        if results.left_hand_landmarks
        else np.zeros(21 * 3)
    )
    rh = (
        np.array(
            [[lm.x, lm.y, lm.z] for lm in results.right_hand_landmarks.landmark]
        ).flatten()
        if results.right_hand_landmarks
        else np.zeros(21 * 3)
    )
    return np.concatenate([pose, lh, rh])
def video_to_features_inference(video_path):
    """
    Convert a video file into a model-ready array of shape (1, 30, 258).

    Reads at most SEQUENCE_LENGTH frames, extracts Holistic keypoints per
    frame (pose -> left hand -> right hand, exactly like training) and
    zero-pads with FEATURE_SIZE zero vectors if the video is shorter.

    Args:
        video_path: path to a video file readable by OpenCV.

    Returns:
        np.ndarray of dtype float32 and shape (1, SEQUENCE_LENGTH, FEATURE_SIZE).
    """
    cap = cv2.VideoCapture(video_path)
    frames_features = []
    try:
        while len(frames_features) < SEQUENCE_LENGTH:
            ret, frame = cap.read()
            if not ret:
                break
            # MediaPipe expects RGB; OpenCV delivers BGR.
            img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = holistic.process(img_rgb)
            frames_features.append(extract_keypoints_inference(results))
    finally:
        # Release the capture even if keypoint extraction raises.
        cap.release()

    # Zero-pad videos shorter than SEQUENCE_LENGTH frames.
    while len(frames_features) < SEQUENCE_LENGTH:
        frames_features.append(np.zeros(FEATURE_SIZE))

    X = np.array(frames_features[:SEQUENCE_LENGTH], dtype=np.float32)
    return np.expand_dims(X, axis=0)  # shape: (1, 30, 258)
def predict_video(file):
    """
    Gradio handler: sliding-window inference over an uploaded video.

    Extracts Holistic keypoints frame by frame, runs the model on every full
    SEQUENCE_LENGTH-frame window (stride 1), averages the per-window class
    probabilities, and returns the top class for ``gr.Label``.

    Videos shorter than SEQUENCE_LENGTH frames are zero-padded and predicted
    once (consistent with video_to_features_inference / training padding);
    previously such videos produced no prediction at all.

    Args:
        file: path to the uploaded video (Gradio temp file).

    Returns:
        dict mapping the predicted class name to its averaged confidence,
        or {"Unknown": 0.0} when no frames could be read.
    """
    cap = cv2.VideoCapture(file)
    window = []
    all_window_preds = []
    total_frames = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            total_frames += 1
            # MediaPipe expects RGB; OpenCV delivers BGR.
            img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = holistic.process(img_rgb)
            window.append(extract_keypoints_inference(results))
            if len(window) == SEQUENCE_LENGTH:
                X = np.expand_dims(np.array(window, dtype=np.float32), axis=0)  # (1,30,258)
                # verbose=0 keeps Keras from printing a progress bar per window.
                all_window_preds.append(model.predict(X, verbose=0)[0])
                window.pop(0)  # slide the window forward by one frame
    finally:
        cap.release()

    # Short video: never filled a window, so pad with zeros and predict once.
    if not all_window_preds and total_frames > 0:
        while len(window) < SEQUENCE_LENGTH:
            window.append(np.zeros(FEATURE_SIZE))
        X = np.expand_dims(np.array(window, dtype=np.float32), axis=0)
        all_window_preds.append(model.predict(X, verbose=0)[0])

    if not all_window_preds:
        return {"Unknown": 0.0}

    # Average probabilities across all windows before picking the winner.
    avg_preds = np.mean(all_window_preds, axis=0)
    label_idx = int(np.argmax(avg_preds))
    confidence = float(avg_preds[label_idx])

    # Debug output: top-5 averaged scores.
    top_indices = avg_preds.argsort()[-5:][::-1]
    print("Top predictions:")
    for i in top_indices:
        print(f"{class_names[i]}: {avg_preds[i]:.4f}")

    return {class_names[label_idx]: confidence}
# --- Gradio interface ----------------------------------------------------
iface = gr.Interface(
    fn=predict_video,
    inputs=gr.Video(label="Upload a Video"),
    outputs=gr.Label(num_top_classes=1),
    live=False,
    api_name="predict",  # exposes the endpoint as /predict for API clients
)

if __name__ == "__main__":
    # Bind to 0.0.0.0 so the app is reachable from inside a container/Space.
    iface.launch(server_name="0.0.0.0", server_port=7860, share=True)