# NOTE: Hugging Face Spaces page chrome ("Spaces / Paused") removed from scrape.
# app.py
#
# Live webcam demo: Malaysian Sign Language recognition. MediaPipe Holistic
# extracts pose/hand keypoints per frame; a CNN+BiLSTM classifier predicts a
# sign over a sliding window; Gradio streams frames in and annotated frames out.
import json
from collections import deque

import cv2
import numpy as np
import torch
import torch.nn as nn
import gradio as gr
import mediapipe as mp

# ----------------------------
# Config (must match the values used during training)
# ----------------------------
SEQ_LEN = 30        # frames per inference window <-- change if training used a different sequence length
FEATURES = 258      # per-frame feature size: 33*4 (pose) + 21*3 (left hand) + 21*3 (right hand)
CONF_THRESH = 0.60  # demo threshold: smoothed top-1 below this shows a placeholder
SMOOTH_K = 5        # moving average over last K probability vectors

# ----------------------------
# Load labels (90 classes)
# ----------------------------
# NOTE(review): assumes labels.json is a JSON array ordered by class index,
# since predict_from_probs indexes it with argmax — verify against the
# training export.
with open("labels.json", "r", encoding="utf-8") as f:
    LABELS = json.load(f)
NUM_CLASSES = len(LABELS)

# ----------------------------
# Model (architecture MUST match trained_model.pth)
# ----------------------------
class CNNLSTMHybrid(nn.Module):
    """1-D CNN front end + stacked bidirectional LSTMs + MLP classifier.

    Input is a keypoint sequence of shape (batch, time, 258); output is raw
    class logits of shape (batch, num_classes). Attribute names and layer
    sizes are fixed to match the checkpoint in trained_model.pth — do not
    rename them.
    """

    def __init__(self, input_size=258, num_classes=90, dropout=0.4):
        super().__init__()
        # Temporal CNN: 258 -> 128 -> 256 -> 128 channels, sequence length
        # preserved (kernel 3, padding 1).
        self.conv1 = nn.Conv1d(input_size, 128, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(128)
        self.conv2 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(256)
        self.conv3 = nn.Conv1d(256, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(128)
        self.dropout_cnn = nn.Dropout(dropout)

        # Checkpoint layout: lstm1 hidden=256 bidirectional (output 512),
        # lstm2 consumes 512, hidden=128 bidirectional.
        self.lstm1 = nn.LSTM(input_size=128, hidden_size=256, num_layers=1,
                             batch_first=True, bidirectional=True)
        self.lstm2 = nn.LSTM(input_size=512, hidden_size=128, num_layers=1,
                             batch_first=True, bidirectional=True)
        self.dropout_lstm = nn.Dropout(dropout)

        # Classifier head: 256 (2 directions x 128) -> 128 -> 64 -> classes.
        self.fc1 = nn.Linear(256, 128)
        self.bn_fc = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, 64)
        self.dropout_fc = nn.Dropout(dropout)
        self.output_layer = nn.Linear(64, num_classes)

    def forward(self, x):
        """x: (B, T, 258) float tensor -> (B, num_classes) logits."""
        feats = x.transpose(1, 2)  # (B, 258, T) — Conv1d wants channels first
        for conv, bn in ((self.conv1, self.bn1),
                         (self.conv2, self.bn2),
                         (self.conv3, self.bn3)):
            feats = torch.relu(bn(conv(feats)))
        feats = self.dropout_cnn(feats)
        feats = feats.transpose(1, 2)        # back to (B, T, 128) for LSTM

        seq, _ = self.lstm1(feats)           # (B, T, 512)
        seq = self.dropout_lstm(seq)
        _, (hidden, _) = self.lstm2(seq)     # hidden: (2, B, 128)

        # Concatenate forward/backward final hidden states -> (B, 256).
        merged = torch.cat([hidden[0], hidden[1]], dim=1)

        out = torch.relu(self.bn_fc(self.fc1(merged)))
        out = self.dropout_fc(out)
        out = torch.relu(self.fc2(out))
        out = self.dropout_fc(out)
        return self.output_layer(out)
# ----------------------------
# Load trained weights
# ----------------------------
DEVICE = torch.device("cpu")
model = CNNLSTMHybrid(input_size=FEATURES, num_classes=NUM_CLASSES).to(DEVICE)
# SECURITY NOTE(review): torch.load unpickles arbitrary objects — only load
# checkpoints from a trusted source (consider weights_only=True on torch>=1.13).
ckpt = torch.load("trained_model.pth", map_location="cpu")
# Support either a plain state_dict or a {"model_state_dict": ...} wrapper
state_dict = ckpt["model_state_dict"] if isinstance(ckpt, dict) and "model_state_dict" in ckpt else ckpt
model.load_state_dict(state_dict)
model.eval()  # inference only: disables dropout, uses running batch-norm stats

# ----------------------------
# MediaPipe Holistic + drawing (clear hands overlay)
# ----------------------------
mp_holistic = mp.solutions.holistic
mp_drawing = mp.solutions.drawing_utils
# Style: green lines + red dots (OpenCV uses BGR, so (0,0,255) is red)
LANDMARK_STYLE = mp_drawing.DrawingSpec(color=(0, 0, 255), thickness=2, circle_radius=4)  # red
CONNECTION_STYLE = mp_drawing.DrawingSpec(color=(0, 255, 0), thickness=3)  # green
# Create the Holistic graph once at import time (faster and keeps landmark
# tracking state stable across streamed frames)
holistic = mp_holistic.Holistic(
    model_complexity=1,
    smooth_landmarks=True,
    min_detection_confidence=0.7,
    min_tracking_confidence=0.7
)
def extract_keypoints(results) -> np.ndarray:
    """Flatten MediaPipe Holistic landmarks into one 258-d float32 vector.

    Layout: pose 33*(x, y, z, visibility) = 132, then left hand
    21*(x, y, z) = 63, then right hand 21*(x, y, z) = 63.
    Any missing part is zero-filled.
    """
    def _flatten(landmarks, size, with_visibility):
        # One-line helper: landmark list -> flat array, or zeros if absent.
        if not landmarks:
            return np.zeros(size, dtype=np.float32)
        if with_visibility:
            rows = [[p.x, p.y, p.z, p.visibility] for p in landmarks.landmark]
        else:
            rows = [[p.x, p.y, p.z] for p in landmarks.landmark]
        return np.asarray(rows).flatten()

    pose = _flatten(results.pose_landmarks, 33 * 4, with_visibility=True)
    left = _flatten(results.left_hand_landmarks, 21 * 3, with_visibility=False)
    right = _flatten(results.right_hand_landmarks, 21 * 3, with_visibility=False)
    return np.concatenate([pose, left, right]).astype(np.float32)
def overlay_header(frame_bgr, label, conf, ready, seq_len):
    """Draw a black status banner plus prediction text onto frame_bgr (in place)."""
    width = frame_bgr.shape[1]
    # Solid black strip across the top 70 px of the frame.
    cv2.rectangle(frame_bgr, (0, 0), (width, 70), (0, 0, 0), -1)
    if ready:
        status = "READY"
    else:
        status = f"COLLECTING {seq_len}/{SEQ_LEN}"
    cv2.putText(frame_bgr, f"{status}", (10, 28),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
    cv2.putText(frame_bgr, f"{label} ({conf:.2f})", (10, 60),
                cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
    return frame_bgr
def predict_from_probs(prob_hist: deque):
    """Average the recent probability vectors and map the top class to a label.

    Returns (label, confidence); returns the placeholder "…" when the smoothed
    top-1 confidence is below CONF_THRESH.
    """
    mean_probs = np.stack(prob_hist, axis=0).mean(axis=0)  # (C,)
    best = int(mean_probs.argmax())
    confidence = float(mean_probs[best])
    if confidence >= CONF_THRESH:
        return LABELS[best], confidence
    return "…", confidence
def stream_fn(frame_rgb, state):
    """Per-frame Gradio streaming callback.

    frame_rgb: numpy array (H, W, 3), RGB, as delivered by Gradio — may be
        None while the webcam is starting up or shutting down.
    state: per-session dict persisted across calls (sliding keypoint window,
        smoothed probabilities, last shown prediction); None on first call.
    returns: (annotated RGB frame, updated state)
    """
    if state is None:
        state = {
            "seq": deque(maxlen=SEQ_LEN),         # last SEQ_LEN keypoint vectors
            "prob_hist": deque(maxlen=SMOOTH_K),  # last SMOOTH_K probability vectors
            "last_label": "Collecting...",
            "last_conf": 0.0,
        }

    # FIX: Gradio streaming can deliver None frames at stream start/stop;
    # skip them instead of crashing inside cv2.cvtColor.
    if frame_rgb is None:
        return frame_rgb, state

    # Convert to BGR for OpenCV drawing; MediaPipe itself consumes RGB.
    frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
    results = holistic.process(frame_rgb)

    # Clearer hand-tracking overlay (hands only), same style for both hands.
    for hand_landmarks in (results.left_hand_landmarks, results.right_hand_landmarks):
        if hand_landmarks:
            mp_drawing.draw_landmarks(
                image=frame_bgr,
                landmark_list=hand_landmarks,
                connections=mp_holistic.HAND_CONNECTIONS,
                landmark_drawing_spec=LANDMARK_STYLE,
                connection_drawing_spec=CONNECTION_STYLE,
            )

    # Append this frame's keypoints to the sliding window.
    keypoints = extract_keypoints(results)
    state["seq"].append(keypoints)
    ready = len(state["seq"]) == SEQ_LEN

    if ready:
        # (1, T, 258) window -> class probabilities, smoothed over prob_hist.
        x = np.expand_dims(np.stack(list(state["seq"]), axis=0), axis=0)
        x_t = torch.tensor(x, dtype=torch.float32, device=DEVICE)
        with torch.no_grad():
            logits = model(x_t)
            probs = torch.softmax(logits, dim=1).cpu().numpy()[0]  # (C,)
        state["prob_hist"].append(probs)
        label, conf = predict_from_probs(state["prob_hist"])
        state["last_label"] = label
        state["last_conf"] = conf

    # Header text (status + last prediction).
    frame_bgr = overlay_header(frame_bgr, state["last_label"], state["last_conf"],
                               ready, len(state["seq"]))
    # Return RGB to Gradio.
    out_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    return out_rgb, state
# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Malaysian Sign Language (Keypoints) — Live Webcam Demo")
    # Per-session state: holds stream_fn's sliding window / smoothing buffers.
    st = gr.State(None)
    cam = gr.Image(
        sources=["webcam"],
        streaming=True,
        type="numpy",   # deliver frames as numpy arrays to stream_fn
        label="Webcam"
    )
    out = gr.Image(label="Output (Hand tracking + Prediction)")
    # Stream webcam frames through stream_fn; state round-trips via `st`.
    cam.stream(
        fn=stream_fn,
        inputs=[cam, st],
        outputs=[out, st],
        time_limit=60   # seconds per streaming session
    )

demo.launch()