# NOTE(review): the three lines below were stray commit/merge text left at the
# top of the file ("revert back, no live" / "1619dda" are not valid Python and
# made the module unimportable). Preserved here as comments.
# esther
# revert back, no live
# 1619dda
import gradio as gr
import numpy as np
import cv2
import mediapipe as mp
import tensorflow as tf
import os
import json
# The class order must exactly match the order used during training;
# labels.json is shipped alongside the model file for that reason.
with open("labels.json", "r", encoding="utf-8") as labels_file:
    class_names = json.load(labels_file)  # list of class-label strings

print(f"Loaded {len(class_names)} classes:")
for idx, name in enumerate(class_names):
    print(f"{idx}: {name}")
# Load the trained action-recognition model (architecture + weights).
model = tf.keras.models.load_model("action_final.h5")

# MediaPipe Holistic setup — inference uses pose + both hands, matching the
# keypoint layout the model was trained on (see extract_keypoints_inference).
# NOTE(review): a standalone mp.solutions.hands.Hands instance used to be
# created here as well, but it was never referenced anywhere in this file
# (the comment "Use Holistic instead of Hands" marked the switch); it only
# wasted startup time and memory, so it has been removed.
mp_holistic = mp.solutions.holistic
holistic = mp_holistic.Holistic(
    static_image_mode=False,   # treat input as a video stream (enables tracking)
    model_complexity=1,
    smooth_landmarks=True,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5
)
# Feature-extraction configuration — both values must match the training
# pipeline exactly, or the model input shape will be wrong.
SEQUENCE_LENGTH = 30  # frames per prediction window (same as training)
FEATURE_SIZE = 258  # pose 33*4 + left hand 21*3 + right hand 21*3
def extract_keypoints_inference(results):
    """
    Flatten MediaPipe Holistic landmarks into one feature vector.

    Matches the exact concatenation order used in training:
    pose (33*4) -> left_hand (21*3) -> right_hand (21*3).
    Any missing landmark group is replaced by a zero vector, so the
    output always has length 258 (FEATURE_SIZE).
    """
    def _flatten(group, to_row, width):
        # Falsy group (absent landmarks) -> zero-filled placeholder.
        if not group:
            return np.zeros(width)
        return np.array([to_row(lm) for lm in group.landmark]).flatten()

    pose = _flatten(results.pose_landmarks,
                    lambda lm: (lm.x, lm.y, lm.z, lm.visibility), 33 * 4)
    lh = _flatten(results.left_hand_landmarks,
                  lambda lm: (lm.x, lm.y, lm.z), 21 * 3)
    rh = _flatten(results.right_hand_landmarks,
                  lambda lm: (lm.x, lm.y, lm.z), 21 * 3)
    return np.concatenate([pose, lh, rh])
def video_to_features_inference(video_path):
    """
    Convert a video into a (1, SEQUENCE_LENGTH, FEATURE_SIZE) float32 array,
    exactly like the training pipeline. Reads at most SEQUENCE_LENGTH frames;
    shorter clips are zero-padded at the end.
    """
    capture = cv2.VideoCapture(video_path)
    features = []
    while len(features) < SEQUENCE_LENGTH:
        ok, frame = capture.read()
        if not ok:  # end of stream or decode failure
            break
        # OpenCV decodes BGR; MediaPipe expects RGB.
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        features.append(extract_keypoints_inference(holistic.process(rgb)))
    capture.release()
    # Zero-pad short clips so the model always sees a full window.
    features.extend(np.zeros(FEATURE_SIZE)
                    for _ in range(SEQUENCE_LENGTH - len(features)))
    window = np.array(features[:SEQUENCE_LENGTH], dtype=np.float32)
    return np.expand_dims(window, axis=0)  # shape: (1, 30, 258)
# Prediction function for Gradio (sliding-window, with debug output)
def predict_video(file):
    """
    Predict the action label for an uploaded video.

    Runs MediaPipe Holistic on every frame, slides a SEQUENCE_LENGTH-frame
    window (stride 1) over the keypoint sequence, averages the model's
    per-window probabilities, and returns the single best label.

    Args:
        file: path to the uploaded video (Gradio temp file).

    Returns:
        dict mapping the predicted class label to its confidence, or
        {"Unknown": 0.0} if no frames could be read at all.
    """
    cap = cv2.VideoCapture(file)
    window = []            # rolling buffer of per-frame keypoint vectors
    all_window_preds = []  # model outputs, one per full window
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # OpenCV decodes BGR; MediaPipe expects RGB.
        img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = holistic.process(img_rgb)
        window.append(extract_keypoints_inference(results))
        # Once the buffer is full, predict and slide by one frame.
        if len(window) == SEQUENCE_LENGTH:
            X = np.expand_dims(np.array(window, dtype=np.float32), axis=0)  # (1, 30, 258)
            all_window_preds.append(model.predict(X)[0])
            window.pop(0)
    cap.release()
    # Robustness fix: clips shorter than SEQUENCE_LENGTH previously produced
    # no window at all and always returned "Unknown". Zero-pad (exactly like
    # video_to_features_inference does for training parity) and predict once.
    if not all_window_preds and window:
        window.extend(np.zeros(FEATURE_SIZE)
                      for _ in range(SEQUENCE_LENGTH - len(window)))
        X = np.expand_dims(np.array(window, dtype=np.float32), axis=0)
        all_window_preds.append(model.predict(X)[0])
    if not all_window_preds:
        # Unreadable or empty video.
        return {"Unknown": 0.0}
    # Average probabilities across windows for a clip-level prediction.
    avg_preds = np.mean(all_window_preds, axis=0)
    label_idx = int(np.argmax(avg_preds))
    confidence = float(np.max(avg_preds))
    class_label = class_names[label_idx]
    # Print top 5 for debugging
    top_n = 5
    top_indices = avg_preds.argsort()[-top_n:][::-1]
    print("Top predictions:")
    for i in top_indices:
        print(f"{class_names[i]}: {avg_preds[i]:.4f}")
    return {class_label: confidence}
# Gradio interface: upload a video, get back the single top predicted label.
iface = gr.Interface(
fn=predict_video,
inputs=gr.Video(label="Upload a Video"),
outputs=gr.Label(num_top_classes=1),
live=False,  # predict only on explicit submit, not on every input change
api_name="predict"  # exposes a named /predict endpoint for API clients
)
if __name__ == "__main__":
# 0.0.0.0:7860 serves locally; share=True also requests a public tunnel URL.
iface.launch(server_name="0.0.0.0", server_port=7860, share=True)