computervision / app.py
aziraarshad's picture
Update app.py
62845ea verified
# app.py
import json
import unicodedata
from collections import deque

import cv2
import gradio as gr
import mediapipe as mp
import numpy as np
import torch
import torch.nn as nn
# ----------------------------
# Config (match your training)
# ----------------------------
# Frames per window fed to the model; must equal the training sequence length.
SEQ_LEN = 30
# Per-frame feature vector: pose 33 x (x, y, z, visibility)
# + left hand 21 x (x, y, z) + right hand 21 x (x, y, z) = 258.
FEATURES = 33 * 4 + 21 * 3 + 21 * 3
# Minimum smoothed probability before a label is shown (demo value).
CONF_THRESH = 0.60
# How many recent probability vectors are averaged together.
SMOOTH_K = 5
# ----------------------------
# Load labels (90 classes)
# ----------------------------
# LABELS is indexed by the model's argmax class index when a prediction is
# reported; NUM_CLASSES sizes the model's output layer.
with open("labels.json", encoding="utf-8") as labels_file:
    LABELS = json.load(labels_file)
NUM_CLASSES = len(LABELS)
# ----------------------------
# Model (MATCHES your trained_model.pth)
# ----------------------------
class CNNLSTMHybrid(nn.Module):
    """Conv1d feature extractor -> two stacked BiLSTMs -> small MLP classifier.

    Layer attribute names and sizes must match the trained checkpoint
    (``trained_model.pth``); renaming or resizing any of them makes
    ``load_state_dict`` fail.

    Args:
        input_size: features per time step (in-channels of the first Conv1d).
        num_classes: number of output logits.
        dropout: drop probability shared by every Dropout layer.
    """

    def __init__(self, input_size=258, num_classes=90, dropout=0.4):
        super().__init__()
        # Temporal convolutions: input_size -> 128 -> 256 -> 128 channels.
        # kernel_size=3 with padding=1 keeps the sequence length unchanged.
        self.conv1 = nn.Conv1d(input_size, 128, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(128)
        self.conv2 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(256)
        self.conv3 = nn.Conv1d(256, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(128)
        self.dropout_cnn = nn.Dropout(dropout)

        # Recurrent stage: BiLSTM(128 -> 2 x 256), then BiLSTM(512 -> 2 x 128).
        self.lstm1 = nn.LSTM(
            input_size=128, hidden_size=256, num_layers=1,
            batch_first=True, bidirectional=True,
        )
        self.lstm2 = nn.LSTM(
            input_size=512, hidden_size=128, num_layers=1,
            batch_first=True, bidirectional=True,
        )
        self.dropout_lstm = nn.Dropout(dropout)

        # Classifier head: 256 -> 128 -> 64 -> num_classes.
        self.fc1 = nn.Linear(256, 128)
        self.bn_fc = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, 64)
        self.dropout_fc = nn.Dropout(dropout)
        self.output_layer = nn.Linear(64, num_classes)

    def forward(self, x):
        """Map a (batch, time, input_size) tensor to (batch, num_classes) logits."""
        feats = x.transpose(1, 2)  # Conv1d expects (batch, channels, time)
        conv_stack = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in conv_stack:
            feats = torch.relu(norm(conv(feats)))
        feats = self.dropout_cnn(feats).transpose(1, 2)  # back to (batch, time, 128)

        seq_out, _ = self.lstm1(feats)  # (batch, time, 512)
        seq_out = self.dropout_lstm(seq_out)
        _, (h_n, _) = self.lstm2(seq_out)  # h_n: (2 directions, batch, 128)
        # Final forward-direction state followed by final backward-direction
        # state for each sample -> (batch, 256).
        summary = torch.cat((h_n[0], h_n[1]), dim=1)

        hidden = torch.relu(self.bn_fc(self.fc1(summary)))
        hidden = self.dropout_fc(hidden)
        hidden = torch.relu(self.fc2(hidden))
        hidden = self.dropout_fc(hidden)
        return self.output_layer(hidden)
# ----------------------------
# Load trained weights
# ----------------------------
DEVICE = torch.device("cpu")
model = CNNLSTMHybrid(input_size=FEATURES, num_classes=NUM_CLASSES).to(DEVICE)

ckpt = torch.load("trained_model.pth", map_location="cpu")
# The file may hold either a bare state_dict or a training dict that keeps
# the weights under "model_state_dict".
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
    state_dict = ckpt["model_state_dict"]
else:
    state_dict = ckpt
model.load_state_dict(state_dict)
# Inference mode: BatchNorm uses running statistics and Dropout is disabled.
model.eval()
# ----------------------------
# MediaPipe Holistic + drawing (clear hands overlay)
# ----------------------------
mp_holistic = mp.solutions.holistic
mp_drawing = mp.solutions.drawing_utils

# Colours are BGR tuples because landmarks are drawn on the BGR copy of the
# frame before it is converted back to RGB: red dots, green connections.
LANDMARK_STYLE = mp_drawing.DrawingSpec(
    color=(0, 0, 255),  # red (BGR)
    thickness=2,
    circle_radius=4,
)
CONNECTION_STYLE = mp_drawing.DrawingSpec(
    color=(0, 255, 0),  # green (BGR)
    thickness=3,
)

# A single Holistic instance is created here and reused for every frame
# rather than being rebuilt per call.
holistic = mp_holistic.Holistic(
    min_detection_confidence=0.7,
    min_tracking_confidence=0.7,
    smooth_landmarks=True,
    model_complexity=1,
)
def extract_keypoints(results) -> np.ndarray:
    """Flatten Holistic landmarks into one fixed-layout float32 vector.

    Layout: pose 33 x (x, y, z, visibility), then left hand 21 x (x, y, z),
    then right hand 21 x (x, y, z). A part with no landmarks contributes a
    block of zeros, so the layout is the same whether or not it was detected.

    Args:
        results: object exposing ``pose_landmarks``, ``left_hand_landmarks``
            and ``right_hand_landmarks`` (each falsy when not detected, or
            holding a ``.landmark`` sequence of points).

    Returns:
        1-D float32 array (258 values when all parts use the counts above).
    """
    part_specs = (
        (results.pose_landmarks, ("x", "y", "z", "visibility"), 33),
        (results.left_hand_landmarks, ("x", "y", "z"), 21),
        (results.right_hand_landmarks, ("x", "y", "z"), 21),
    )
    pieces = []
    for landmark_list, fields, count in part_specs:
        if not landmark_list:
            pieces.append(np.zeros(count * len(fields), dtype=np.float32))
            continue
        rows = [[getattr(point, name) for name in fields] for point in landmark_list.landmark]
        pieces.append(np.array(rows).flatten())
    return np.concatenate(pieces).astype(np.float32)
def _hershey_safe(text) -> str:
    """Reduce *text* to characters OpenCV's Hershey fonts can draw.

    Hershey fonts only contain glyphs for printable ASCII; cv2.putText draws
    every other byte as '?', so a UTF-8 "…" would appear as "???". NFKD folds
    compatibility characters (e.g. "…" -> "...") and splits accented letters
    into base letter + combining mark; the marks are dropped and anything
    still non-ASCII becomes a single '?'.
    """
    folded = unicodedata.normalize("NFKD", str(text))
    base = "".join(ch for ch in folded if not unicodedata.combining(ch))
    return base.encode("ascii", "replace").decode("ascii")


def overlay_header(frame_bgr, label, conf, ready, seq_len):
    """Draw the status/prediction banner across the top of a BGR frame.

    Args:
        frame_bgr: BGR image; drawn on in place.
        label: current prediction text (any object; converted with str()).
        conf: confidence, shown with two decimals.
        ready: True once the keypoint window holds SEQ_LEN frames.
        seq_len: number of frames currently buffered (shown while collecting).

    Returns:
        The same ``frame_bgr`` array.
    """
    width = frame_bgr.shape[1]
    # Solid black 70-px strip keeps the text legible over any background.
    cv2.rectangle(frame_bgr, (0, 0), (width, 70), (0, 0, 0), -1)
    status = "READY" if ready else f"COLLECTING {seq_len}/{SEQ_LEN}"
    cv2.putText(frame_bgr, status, (10, 28), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
    prediction = _hershey_safe(f"{label} ({conf:.2f})")
    cv2.putText(frame_bgr, prediction, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
    return frame_bgr
def predict_from_probs(prob_hist: deque, labels=None, conf_thresh=None):
    """Average recent class-probability vectors and pick the winning label.

    Args:
        prob_hist: non-empty sequence of equal-length 1-D probability vectors;
            they are averaged element-wise to smooth frame-to-frame jitter.
        labels: index -> label lookup; defaults to the module-level LABELS.
        conf_thresh: minimum averaged probability for the top class to be
            reported; defaults to the module-level CONF_THRESH.

    Returns:
        (label, conf): conf is the averaged probability of the top class;
        label is "..." when conf is below the threshold.

    Raises:
        ValueError: if ``prob_hist`` is empty.
    """
    if labels is None:
        labels = LABELS
    if conf_thresh is None:
        conf_thresh = CONF_THRESH
    if len(prob_hist) == 0:
        raise ValueError("prob_hist must contain at least one probability vector")

    avg = np.mean(np.stack(prob_hist, axis=0), axis=0)  # (C,)
    idx = int(np.argmax(avg))
    conf = float(avg[idx])
    if conf < conf_thresh:
        # Plain ASCII on purpose: the label is drawn with cv2.putText, whose
        # Hershey fonts cannot render "…" (U+2026) and would show "???".
        return "...", conf
    return labels[idx], conf
def _new_stream_state():
    """Build the empty per-session state dict used by stream_fn."""
    return {
        "seq": deque(maxlen=SEQ_LEN),         # sliding window of keypoint vectors
        "prob_hist": deque(maxlen=SMOOTH_K),  # recent softmax outputs for smoothing
        "last_label": "Collecting...",
        "last_conf": 0.0,
    }


def _draw_hands(frame_bgr, results):
    """Draw whichever hands Holistic detected onto frame_bgr, in place."""
    for hand_landmarks in (results.left_hand_landmarks, results.right_hand_landmarks):
        if hand_landmarks:
            mp_drawing.draw_landmarks(
                image=frame_bgr,
                landmark_list=hand_landmarks,
                connections=mp_holistic.HAND_CONNECTIONS,
                landmark_drawing_spec=LANDMARK_STYLE,
                connection_drawing_spec=CONNECTION_STYLE,
            )


def stream_fn(frame_rgb, state):
    """Process one webcam frame: track landmarks, update the window, predict.

    Args:
        frame_rgb: (H, W, 3) RGB numpy frame from the Gradio webcam stream,
            or None when no frame is available.
        state: per-session dict persisted by Gradio (None on the first call).

    Returns:
        (out_rgb, state): the annotated RGB frame (None when no input frame
        was given) and the updated state to persist.
    """
    if state is None:
        state = _new_stream_state()
    # cv2.cvtColor raises on None; skip the frame instead of failing the stream.
    if frame_rgb is None:
        return None, state

    # Drawing happens on a BGR copy; MediaPipe Holistic is fed the RGB frame.
    frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
    results = holistic.process(frame_rgb)
    _draw_hands(frame_bgr, results)

    state["seq"].append(extract_keypoints(results))
    ready = len(state["seq"]) == SEQ_LEN
    if ready:
        window = np.stack(state["seq"], axis=0)[np.newaxis]  # (1, SEQ_LEN, features)
        x_t = torch.tensor(window, dtype=torch.float32, device=DEVICE)
        with torch.no_grad():
            logits = model(x_t)
        probs = torch.softmax(logits, dim=1).cpu().numpy()[0]  # (C,)
        state["prob_hist"].append(probs)
        state["last_label"], state["last_conf"] = predict_from_probs(state["prob_hist"])

    frame_bgr = overlay_header(frame_bgr, state["last_label"], state["last_conf"], ready, len(state["seq"]))
    return cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB), state
# ----------------------------
# Gradio UI
# ----------------------------
# Seconds between webcam frames sent to stream_fn. With Gradio's default
# interval (0.5 s) a SEQ_LEN-frame window would span about 15 s of motion.
# A shorter interval keeps the window nearer the clip length the model was
# presumably trained on.
# NOTE(review): tune to the training frame rate and to server throughput.
STREAM_EVERY = 0.1

with gr.Blocks() as demo:
    gr.Markdown("# Malaysian Sign Language (Keypoints) — Live Webcam Demo")
    st = gr.State(None)  # per-session state, created lazily by stream_fn
    cam = gr.Image(
        sources=["webcam"],
        streaming=True,
        type="numpy",
        label="Webcam",
    )
    out = gr.Image(label="Output (Hand tracking + Prediction)")
    cam.stream(
        fn=stream_fn,
        inputs=[cam, st],
        outputs=[out, st],
        time_limit=60,
        stream_every=STREAM_EVERY,
    )

demo.launch()