# SignView2.0 / realtime_sign_prediction.py
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import mediapipe as mp
import json
import time
from collections import deque
import argparse
class FeatureExtractor:
def __init__(self, use_segmentation=True):
        # Initialize MediaPipe models
        self.mp_holistic = mp.solutions.holistic
        self.mp_drawing = mp.solutions.drawing_utils
        self.mp_drawing_styles = mp.solutions.drawing_styles
        self.mp_selfie_segmentation = mp.solutions.selfie_segmentation
        # One persistent Holistic instance; re-creating it for every frame
        # would discard landmark tracking state (so min_tracking_confidence
        # never takes effect) and reload the model each frame
        self.holistic = self.mp_holistic.Holistic(
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
            model_complexity=1)
# Segmentation settings
self.use_segmentation = use_segmentation
self.segmentation = None
if self.use_segmentation:
self.segmentation = self.mp_selfie_segmentation.SelfieSegmentation(model_selection=1)
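            # model_selection=1 selects MediaPipe's landscape segmentation
            # model, which suits 16:9 webcam frames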
# Optical flow parameters
self.optical_flow_params = dict(
flow=None,
pyr_scale=0.5,
levels=3,
winsize=15,
iterations=3,
poly_n=5,
poly_sigma=1.2,
flags=0
)
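        # The dict above maps positionally onto cv2.calcOpticalFlowFarneback:
        # pyr_scale=0.5 halves the image at each pyramid level, levels=3
        # pyramid levels, winsize=15 averaging window, iterations=3 per level,
        # and poly_n/poly_sigma control the per-pixel polynomial expansion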
def extract_pose_keypoints(self, frame, holistic_results):
"""Extract pose keypoints"""
keypoints = []
# Extract hand keypoints
if holistic_results.left_hand_landmarks:
for landmark in holistic_results.left_hand_landmarks.landmark:
keypoints.extend([landmark.x, landmark.y, landmark.z])
else:
keypoints.extend([0] * (21 * 3))
if holistic_results.right_hand_landmarks:
for landmark in holistic_results.right_hand_landmarks.landmark:
keypoints.extend([landmark.x, landmark.y, landmark.z])
else:
keypoints.extend([0] * (21 * 3))
# Extract pose keypoints
if holistic_results.pose_landmarks:
for landmark in holistic_results.pose_landmarks.landmark:
keypoints.extend([landmark.x, landmark.y, landmark.z])
else:
keypoints.extend([0] * (33 * 3))
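        # Layout: 21 left-hand + 21 right-hand + 33 pose landmarks, 3 coords
        # each, i.e. (21 + 21 + 33) * 3 = 225 values, matching the model's
        # input_dim=225 below; undetected parts are zero-filled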
return np.array(keypoints)
def create_hand_mask(self, frame, left_hand_landmarks, right_hand_landmarks, pose_landmarks):
"""Create ROI mask for hands and upper body"""
h, w = frame.shape[:2]
mask = np.zeros((h, w), dtype=np.uint8)
def draw_landmarks_on_mask(landmarks, radius=15):
if landmarks:
for landmark in landmarks.landmark:
x, y = int(landmark.x * w), int(landmark.y * h)
if 0 <= x < w and 0 <= y < h:
cv2.circle(mask, (x, y), radius=radius, color=255, thickness=-1)
# Draw hand keypoints
draw_landmarks_on_mask(left_hand_landmarks, radius=20)
draw_landmarks_on_mask(right_hand_landmarks, radius=20)
# Draw upper body keypoints
if pose_landmarks:
upper_body_indices = list(range(0, 25))
for idx in upper_body_indices:
if idx < len(pose_landmarks.landmark):
landmark = pose_landmarks.landmark[idx]
x, y = int(landmark.x * w), int(landmark.y * h)
if 0 <= x < w and 0 <= y < h:
cv2.circle(mask, (x, y), radius=10, color=255, thickness=-1)
# Dilate mask
kernel = np.ones((15, 15), np.uint8)
dilated_mask = cv2.dilate(mask, kernel, iterations=1)
return dilated_mask
def compute_regional_optical_flow(self, prev_frame, curr_frame, mask, downscale=0.5):
"""Compute optical flow only in masked regions"""
if downscale < 1.0:
h, w = prev_frame.shape[:2]
new_h, new_w = int(h * downscale), int(w * downscale)
prev_small = cv2.resize(prev_frame, (new_w, new_h))
curr_small = cv2.resize(curr_frame, (new_w, new_h))
mask_small = cv2.resize(mask, (new_w, new_h))
else:
prev_small = prev_frame
curr_small = curr_frame
mask_small = mask
# Convert to grayscale
prev_gray = cv2.cvtColor(prev_small, cv2.COLOR_BGR2GRAY)
curr_gray = cv2.cvtColor(curr_small, cv2.COLOR_BGR2GRAY)
# Compute optical flow
flow = cv2.calcOpticalFlowFarneback(
prev_gray, curr_gray,
self.optical_flow_params['flow'],
self.optical_flow_params['pyr_scale'],
self.optical_flow_params['levels'],
self.optical_flow_params['winsize'],
self.optical_flow_params['iterations'],
self.optical_flow_params['poly_n'],
self.optical_flow_params['poly_sigma'],
self.optical_flow_params['flags']
)
# Extract flow features from masked region
bool_mask = mask_small > 0
if np.any(bool_mask):
fx = flow[..., 0][bool_mask]
fy = flow[..., 1][bool_mask]
flow_features = np.array([
np.mean(fx), np.std(fx),
np.mean(fy), np.std(fy),
np.percentile(fx, 25), np.percentile(fx, 75),
np.percentile(fy, 25), np.percentile(fy, 75),
np.max(np.abs(fx)), np.max(np.abs(fy))
], dtype=np.float16)
else:
flow_features = np.zeros(10, dtype=np.float16)
return flow_features
def apply_segmentation_mask(self, frame):
"""Apply human segmentation to focus on person area"""
if not self.use_segmentation or self.segmentation is None:
return frame, None
try:
# Convert BGR to RGB for MediaPipe
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_rgb.flags.writeable = False
# Process segmentation
results = self.segmentation.process(frame_rgb)
segmentation_mask = results.segmentation_mask
if segmentation_mask is not None:
# Resize mask to match frame size
h, w = frame.shape[:2]
mask = cv2.resize(segmentation_mask, (w, h))
# Convert to 3-channel mask
mask_3channel = np.stack((mask,) * 3, axis=-1)
# Apply Gaussian blur to smooth edges
mask_3channel = cv2.GaussianBlur(mask_3channel, (5, 5), 0)
# Create segmented frame
segmented_frame = frame * mask_3channel
# Convert binary mask for optical flow processing
binary_mask = (mask > 0.5).astype(np.uint8) * 255
return segmented_frame.astype(np.uint8), binary_mask
else:
return frame, None
except Exception as e:
print(f"Segmentation error: {e}")
return frame, None
def create_enhanced_hand_mask(self, frame, left_hand_landmarks, right_hand_landmarks, pose_landmarks, seg_mask=None):
"""Create enhanced ROI mask combining landmarks and segmentation"""
h, w = frame.shape[:2]
mask = np.zeros((h, w), dtype=np.uint8)
def draw_landmarks_on_mask(landmarks, radius=15):
if landmarks:
for landmark in landmarks.landmark:
x, y = int(landmark.x * w), int(landmark.y * h)
if 0 <= x < w and 0 <= y < h:
cv2.circle(mask, (x, y), radius=radius, color=255, thickness=-1)
# Draw hand keypoints with larger radius
draw_landmarks_on_mask(left_hand_landmarks, radius=25)
draw_landmarks_on_mask(right_hand_landmarks, radius=25)
# Draw upper body keypoints
if pose_landmarks:
upper_body_indices = list(range(0, 25))
for idx in upper_body_indices:
if idx < len(pose_landmarks.landmark):
landmark = pose_landmarks.landmark[idx]
x, y = int(landmark.x * w), int(landmark.y * h)
if 0 <= x < w and 0 <= y < h:
cv2.circle(mask, (x, y), radius=12, color=255, thickness=-1)
# Combine with segmentation mask if available
if seg_mask is not None:
seg_mask_resized = cv2.resize(seg_mask, (w, h))
mask = cv2.bitwise_and(mask, seg_mask_resized)
# Dilate mask
kernel = np.ones((20, 20), np.uint8)
dilated_mask = cv2.dilate(mask, kernel, iterations=2)
return dilated_mask
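# Minimal smoke test for FeatureExtractor; a sketch for local debugging only
# (assumes OpenCV and MediaPipe are installed; random frames normally yield
# zero-filled keypoints, so only the optical-flow path is exercised here).
def _feature_extractor_demo():
    extractor = FeatureExtractor(use_segmentation=False)
    prev = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
    curr = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
    roi = np.full((480, 640), 255, dtype=np.uint8)  # whole frame as ROI
    flow_feats = extractor.compute_regional_optical_flow(prev, curr, roi)
    print("flow feature vector:", flow_feats.shape)  # (10,)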
class SignLanguageModel(nn.Module):
"""Sign Language Recognition Model"""
def __init__(self, input_dim, hidden_dim, num_layers, num_classes, dropout=0.5, flow_dim=10):
super(SignLanguageModel, self).__init__()
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.num_classes = num_classes
# Keypoint feature projection
self.keypoint_projection = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout/2),
nn.Linear(hidden_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout/2)
)
# Flow feature projection
self.flow_projection = nn.Sequential(
nn.Linear(flow_dim, hidden_dim // 2),
nn.BatchNorm1d(hidden_dim // 2),
nn.ReLU(),
nn.Dropout(dropout/2),
nn.Linear(hidden_dim // 2, hidden_dim // 2),
nn.BatchNorm1d(hidden_dim // 2),
nn.ReLU(),
nn.Dropout(dropout/2)
)
# Feature fusion
self.fusion_layer = nn.Sequential(
nn.Linear(hidden_dim + (hidden_dim // 2), hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout/2)
)
# Bidirectional LSTM
self.lstm = nn.LSTM(
input_size=hidden_dim,
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
dropout=dropout if num_layers > 1 else 0,
bidirectional=True
)
# GRU for additional temporal features
self.gru = nn.GRU(
input_size=hidden_dim * 2,
hidden_size=hidden_dim,
num_layers=1,
batch_first=True,
bidirectional=True
)
# Batch normalization
self.lstm_bn = nn.BatchNorm1d(hidden_dim * 2)
self.gru_bn = nn.BatchNorm1d(hidden_dim * 2)
# Multi-head attention
self.multihead_attn = nn.MultiheadAttention(
embed_dim=hidden_dim * 2,
num_heads=4,
dropout=dropout,
batch_first=True
)
# Attention mechanism
self.attention = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, 1),
nn.Softmax(dim=1)
)
# Classifier
self.classifier = nn.Sequential(
nn.Linear(hidden_dim * 4, hidden_dim * 2),
nn.BatchNorm1d(hidden_dim * 2),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim * 2, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout/2),
nn.Linear(hidden_dim, num_classes)
)
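        # Classifier input is hidden_dim * 4: two pooled contexts (GRU branch
        # and multi-head attention branch), each of width hidden_dim * 2,
        # are concatenated in forward()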
self._init_weights()
def _init_weights(self):
"""Initialize model weights"""
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.LSTM, nn.GRU)):
for name, param in m.named_parameters():
if 'weight' in name:
nn.init.orthogonal_(param)
elif 'bias' in name:
nn.init.zeros_(param)
    def _project_sequence(self, projection, x):
        """Apply a [Linear, BatchNorm1d, ReLU, Dropout] x2 stack to a
        (batch, seq, features) tensor. BatchNorm1d normalizes over the
        channel axis, so the tensor is transposed to (batch, features, seq)
        around each norm."""
        batch_size, seq_len, _ = x.size()
        for start in (0, 4):
            linear, bn, act, drop = projection[start:start + 4]
            x = linear(x.reshape(-1, x.size(-1))).reshape(batch_size, seq_len, -1)
            x = bn(x.transpose(1, 2)).transpose(1, 2)
            x = drop(act(x))
        return x
    def forward(self, keypoints, flow=None):
        """Forward pass"""
        batch_size, seq_len, _ = keypoints.size()
        # Project keypoint features: (B, T, input_dim) -> (B, T, hidden_dim)
        kp_projected = self._project_sequence(self.keypoint_projection, keypoints)
        if flow is not None:
            # Project flow features: (B, T, flow_dim) -> (B, T, hidden_dim // 2)
            flow_projected = self._project_sequence(self.flow_projection, flow)
            # Fuse the two streams back down to hidden_dim
            combined = torch.cat([kp_projected, flow_projected], dim=2)
            linear, bn, act, drop = self.fusion_layer
            fused = linear(combined.reshape(-1, combined.size(-1)))
            fused = fused.reshape(batch_size, seq_len, -1)
            fused = bn(fused.transpose(1, 2)).transpose(1, 2)
            x_projected = drop(act(fused))
        else:
            x_projected = kp_projected
# Residual connection
x_residual = x_projected
# LSTM processing
lstm_out, _ = self.lstm(x_projected)
# Residual connection
x_residual_expanded = torch.cat([x_residual, x_residual], dim=2)
lstm_out_with_residual = lstm_out + x_residual_expanded
# BatchNorm for LSTM output
lstm_out_bn = lstm_out_with_residual.transpose(1, 2)
lstm_out_bn = self.lstm_bn(lstm_out_bn)
lstm_out = lstm_out_bn.transpose(1, 2)
# GRU processing
gru_out, _ = self.gru(lstm_out)
# BatchNorm for GRU output
gru_out_bn = gru_out.transpose(1, 2)
gru_out_bn = self.gru_bn(gru_out_bn)
gru_out = gru_out_bn.transpose(1, 2)
# Multi-head attention
attn_output, _ = self.multihead_attn(lstm_out, lstm_out, lstm_out)
# Traditional attention
attention_weights = self.attention(gru_out)
context_gru = torch.bmm(gru_out.transpose(1, 2), attention_weights)
context_gru = context_gru.squeeze(-1)
attention_weights_attn = self.attention(attn_output)
context_attn = torch.bmm(attn_output.transpose(1, 2), attention_weights_attn)
context_attn = context_attn.squeeze(-1)
# Combine contexts
combined_context = torch.cat([context_gru, context_attn], dim=1)
# Final classification
output = self.classifier(combined_context)
return output
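# Quick shape check with random tensors; a sketch for local debugging only
# (hyperparameters mirror the defaults used by the predictors below, and
# num_classes=10 is an arbitrary placeholder).
def _model_shape_check():
    model = SignLanguageModel(input_dim=225, hidden_dim=256, num_layers=2,
                              num_classes=10, dropout=0.5, flow_dim=10)
    model.eval()  # BatchNorm/Dropout in inference mode
    keypoints = torch.randn(2, 50, 225)  # (batch, seq_len, keypoint features)
    flow = torch.randn(2, 50, 10)        # (batch, seq_len, flow features)
    with torch.no_grad():
        logits = model(keypoints, flow)
    assert logits.shape == (2, 10)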
class RealtimeSignPredictor:
def __init__(self, model_path, config_path, sequence_length=50, confidence_threshold=0.5, use_segmentation=True):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.sequence_length = sequence_length
self.confidence_threshold = confidence_threshold
# Load configuration and label mapping
with open(config_path, 'r') as f:
config = json.load(f)
self.label_mapping = config['label_mapping']
self.idx_to_label = {int(k): v for k, v in self.label_mapping.items()}
# Initialize model
self.model = SignLanguageModel(
input_dim=225, # keypoint dimension
hidden_dim=256,
num_layers=2,
num_classes=len(self.label_mapping),
dropout=0.5,
flow_dim=10
)
# Load trained weights
checkpoint = torch.load(model_path, map_location=self.device)
if 'model_state_dict' in checkpoint:
self.model.load_state_dict(checkpoint['model_state_dict'])
else:
self.model.load_state_dict(checkpoint)
self.model.to(self.device)
self.model.eval()
# Initialize feature extractor with segmentation
self.feature_extractor = FeatureExtractor(use_segmentation=use_segmentation)
# Initialize sequences for storing features
self.keypoint_sequence = deque(maxlen=sequence_length)
self.flow_sequence = deque(maxlen=sequence_length)
# Variables for optical flow
self.prev_frame = None
self.prev_mask = None
print(f"Model loaded successfully. Using device: {self.device}")
print(f"Recognized classes: {list(self.idx_to_label.values())}")
def _linear_interpolate_sequence(self, data, target_length):
"""Linear interpolation to adjust sequence length"""
if len(data) == target_length:
return np.array(data)
data = np.array(data)
original_length = len(data)
feature_dim = data.shape[1]
interpolated_data = np.zeros((target_length, feature_dim))
for dim in range(feature_dim):
original_indices = np.linspace(0, original_length - 1, original_length)
target_indices = np.linspace(0, original_length - 1, target_length)
interpolated_data[:, dim] = np.interp(target_indices, original_indices, data[:, dim])
return interpolated_data
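    # Example (illustration only): resampling 3 frames of 2-D features to
    # length 5 with the method above:
    #   [[0, 0], [1, 10], [2, 20]]
    #   -> [[0, 0], [0.5, 5], [1, 10], [1.5, 15], [2, 20]]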
def process_frame(self, frame):
"""Process a single frame and extract features with segmentation"""
# Apply segmentation mask first
segmented_frame, seg_mask = self.feature_extractor.apply_segmentation_mask(frame)
# Convert to RGB for MediaPipe
frame_rgb = cv2.cvtColor(segmented_frame, cv2.COLOR_BGR2RGB)
frame_rgb.flags.writeable = False
        # Run MediaPipe Holistic (persistent instance created in FeatureExtractor)
        results = self.feature_extractor.holistic.process(frame_rgb)
        frame_rgb.flags.writeable = True
# Extract keypoints
keypoints = self.feature_extractor.extract_pose_keypoints(segmented_frame, results)
# Create enhanced hand mask with segmentation
hand_mask = self.feature_extractor.create_enhanced_hand_mask(
segmented_frame,
results.left_hand_landmarks,
results.right_hand_landmarks,
results.pose_landmarks,
seg_mask
)
# Calculate optical flow on segmented frame
flow_features = np.zeros(10, dtype=np.float16)
if self.prev_frame is not None and self.prev_mask is not None:
flow_features = self.feature_extractor.compute_regional_optical_flow(
self.prev_frame, segmented_frame, hand_mask, downscale=0.5
)
# Update previous frame and mask
self.prev_frame = segmented_frame.copy()
self.prev_mask = hand_mask
# Add to sequences
self.keypoint_sequence.append(keypoints)
self.flow_sequence.append(flow_features)
return results, keypoints, flow_features
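    # The deques filled above are bounded (maxlen=sequence_length), so once
    # full they act as a sliding window and predict() re-scores every frame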
def predict(self):
"""Make prediction based on current sequence"""
if len(self.keypoint_sequence) < self.sequence_length:
return None, 0.0
# Convert sequences to arrays and interpolate
keypoints_array = self._linear_interpolate_sequence(
list(self.keypoint_sequence), self.sequence_length
)
flow_array = self._linear_interpolate_sequence(
list(self.flow_sequence), self.sequence_length
)
# Convert to tensors
keypoints_tensor = torch.FloatTensor(keypoints_array).unsqueeze(0).to(self.device)
flow_tensor = torch.FloatTensor(flow_array).unsqueeze(0).to(self.device)
# Make prediction
with torch.no_grad():
outputs = self.model(keypoints_tensor, flow_tensor)
probabilities = F.softmax(outputs, dim=1)
max_prob, max_idx = torch.max(probabilities, 1)
predicted_label = self.idx_to_label[max_idx.item()]
confidence = max_prob.item()
return predicted_label, confidence
def get_top_predictions(self, top_k=3):
"""Get top-k predictions"""
if len(self.keypoint_sequence) < self.sequence_length:
return []
# Convert sequences to arrays and interpolate
keypoints_array = self._linear_interpolate_sequence(
list(self.keypoint_sequence), self.sequence_length
)
flow_array = self._linear_interpolate_sequence(
list(self.flow_sequence), self.sequence_length
)
# Convert to tensors
keypoints_tensor = torch.FloatTensor(keypoints_array).unsqueeze(0).to(self.device)
flow_tensor = torch.FloatTensor(flow_array).unsqueeze(0).to(self.device)
# Make prediction
with torch.no_grad():
outputs = self.model(keypoints_tensor, flow_tensor)
probabilities = F.softmax(outputs, dim=1)
top_probs, top_indices = torch.topk(probabilities, k=min(top_k, len(self.idx_to_label)))
predictions = []
for i in range(top_indices.size(1)):
idx = top_indices[0, i].item()
prob = top_probs[0, i].item()
label = self.idx_to_label[idx]
predictions.append((label, prob))
return predictions
def draw_landmarks(self, frame, results):
"""Draw MediaPipe landmarks on frame"""
if results.left_hand_landmarks:
self.feature_extractor.mp_drawing.draw_landmarks(
frame, results.left_hand_landmarks,
self.feature_extractor.mp_holistic.HAND_CONNECTIONS,
self.feature_extractor.mp_drawing_styles.get_default_hand_landmarks_style(),
self.feature_extractor.mp_drawing_styles.get_default_hand_connections_style()
)
if results.right_hand_landmarks:
self.feature_extractor.mp_drawing.draw_landmarks(
frame, results.right_hand_landmarks,
self.feature_extractor.mp_holistic.HAND_CONNECTIONS,
self.feature_extractor.mp_drawing_styles.get_default_hand_landmarks_style(),
self.feature_extractor.mp_drawing_styles.get_default_hand_connections_style()
)
if results.pose_landmarks:
self.feature_extractor.mp_drawing.draw_landmarks(
frame, results.pose_landmarks,
self.feature_extractor.mp_holistic.POSE_CONNECTIONS,
self.feature_extractor.mp_drawing_styles.get_default_pose_landmarks_style()
)
return frame
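# Usage sketch for RealtimeSignPredictor outside of main() (illustration only;
# 'example_sign.mp4' is a hypothetical file, and the model/config paths are
# the argparse defaults below):
#   predictor = RealtimeSignPredictor('tsflow/models/best_model.pt',
#                                     'tsflow/results/test_results.json')
#   cap = cv2.VideoCapture('example_sign.mp4')
#   while cap.isOpened():
#       ret, frame = cap.read()
#       if not ret:
#           break
#       predictor.process_frame(frame)
#       label, conf = predictor.predict()  # (None, 0.0) until the window fills
#   cap.release()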
class SingleSignPredictor:
def __init__(self, model_path, config_path, sequence_length=50, recording_duration=4.0, use_segmentation=True):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.sequence_length = sequence_length
self.recording_duration = recording_duration # seconds to record each sign
# Load configuration and label mapping
with open(config_path, 'r') as f:
config = json.load(f)
self.label_mapping = config['label_mapping']
self.idx_to_label = {int(k): v for k, v in self.label_mapping.items()}
# Initialize model
self.model = SignLanguageModel(
input_dim=225, # keypoint dimension
hidden_dim=256,
num_layers=2,
num_classes=len(self.label_mapping),
dropout=0.5,
flow_dim=10
)
# Load trained weights
checkpoint = torch.load(model_path, map_location=self.device)
if 'model_state_dict' in checkpoint:
self.model.load_state_dict(checkpoint['model_state_dict'])
else:
self.model.load_state_dict(checkpoint)
self.model.to(self.device)
self.model.eval()
# Initialize feature extractor with segmentation
self.feature_extractor = FeatureExtractor(use_segmentation=use_segmentation)
# Recording state
self.is_recording = False
self.recording_start_time = None
self.recorded_keypoints = []
self.recorded_flow = []
self.prev_frame = None
self.prev_mask = None
# Results
self.last_prediction = None
self.last_confidence = 0.0
self.last_top_predictions = []
print(f"Model loaded successfully. Using device: {self.device}")
print(f"Recording duration: {self.recording_duration} seconds")
print(f"Recognized classes: {list(self.idx_to_label.values())}")
def _linear_interpolate_sequence(self, data, target_length):
"""Linear interpolation to adjust sequence length"""
if len(data) == target_length:
return np.array(data)
data = np.array(data)
original_length = len(data)
feature_dim = data.shape[1]
interpolated_data = np.zeros((target_length, feature_dim))
for dim in range(feature_dim):
original_indices = np.linspace(0, original_length - 1, original_length)
target_indices = np.linspace(0, original_length - 1, target_length)
interpolated_data[:, dim] = np.interp(target_indices, original_indices, data[:, dim])
return interpolated_data
def start_recording(self):
"""Start recording a new sign"""
self.is_recording = True
self.recording_start_time = time.time()
self.recorded_keypoints = []
self.recorded_flow = []
self.prev_frame = None
self.prev_mask = None
print("Recording started...")
def stop_recording(self):
"""Stop recording and make prediction"""
if not self.is_recording:
return
self.is_recording = False
print(f"Recording stopped. Collected {len(self.recorded_keypoints)} frames")
if len(self.recorded_keypoints) < 10: # Need minimum frames
print("Not enough frames for prediction")
self.last_prediction = "Not enough data"
self.last_confidence = 0.0
self.last_top_predictions = []
return
# Interpolate to target length
keypoints_array = self._linear_interpolate_sequence(
self.recorded_keypoints, self.sequence_length
)
flow_array = self._linear_interpolate_sequence(
self.recorded_flow, self.sequence_length
)
# Convert to tensors
keypoints_tensor = torch.FloatTensor(keypoints_array).unsqueeze(0).to(self.device)
flow_tensor = torch.FloatTensor(flow_array).unsqueeze(0).to(self.device)
# Make prediction
with torch.no_grad():
outputs = self.model(keypoints_tensor, flow_tensor)
probabilities = F.softmax(outputs, dim=1)
# Get top-5 predictions
top_probs, top_indices = torch.topk(probabilities, k=min(5, len(self.idx_to_label)))
predictions = []
for i in range(top_indices.size(1)):
idx = top_indices[0, i].item()
prob = top_probs[0, i].item()
label = self.idx_to_label[idx]
predictions.append((label, prob))
# Store results
self.last_prediction = predictions[0][0]
self.last_confidence = predictions[0][1]
self.last_top_predictions = predictions
print(f"Prediction: {self.last_prediction} (confidence: {self.last_confidence:.3f})")
def process_frame(self, frame):
"""Process a single frame with segmentation"""
# Apply segmentation mask first
segmented_frame, seg_mask = self.feature_extractor.apply_segmentation_mask(frame)
# Convert to RGB for MediaPipe
frame_rgb = cv2.cvtColor(segmented_frame, cv2.COLOR_BGR2RGB)
frame_rgb.flags.writeable = False
        # Run MediaPipe Holistic (persistent instance created in FeatureExtractor)
        results = self.feature_extractor.holistic.process(frame_rgb)
        frame_rgb.flags.writeable = True
# Extract keypoints
keypoints = self.feature_extractor.extract_pose_keypoints(segmented_frame, results)
# Create enhanced hand mask with segmentation
hand_mask = self.feature_extractor.create_enhanced_hand_mask(
segmented_frame,
results.left_hand_landmarks,
results.right_hand_landmarks,
results.pose_landmarks,
seg_mask
)
# Calculate optical flow on segmented frame
flow_features = np.zeros(10, dtype=np.float16)
if self.prev_frame is not None and self.prev_mask is not None:
flow_features = self.feature_extractor.compute_regional_optical_flow(
self.prev_frame, segmented_frame, hand_mask, downscale=0.5
)
# If recording, store the features
if self.is_recording:
self.recorded_keypoints.append(keypoints)
self.recorded_flow.append(flow_features)
# Check if recording duration is reached
if time.time() - self.recording_start_time >= self.recording_duration:
self.stop_recording()
# Update previous frame and mask
self.prev_frame = segmented_frame.copy()
self.prev_mask = hand_mask
return results
def draw_landmarks(self, frame, results):
"""Draw MediaPipe landmarks on frame"""
if results.left_hand_landmarks:
self.feature_extractor.mp_drawing.draw_landmarks(
frame, results.left_hand_landmarks,
self.feature_extractor.mp_holistic.HAND_CONNECTIONS,
self.feature_extractor.mp_drawing_styles.get_default_hand_landmarks_style(),
self.feature_extractor.mp_drawing_styles.get_default_hand_connections_style()
)
if results.right_hand_landmarks:
self.feature_extractor.mp_drawing.draw_landmarks(
frame, results.right_hand_landmarks,
self.feature_extractor.mp_holistic.HAND_CONNECTIONS,
self.feature_extractor.mp_drawing_styles.get_default_hand_landmarks_style(),
self.feature_extractor.mp_drawing_styles.get_default_hand_connections_style()
)
if results.pose_landmarks:
self.feature_extractor.mp_drawing.draw_landmarks(
frame, results.pose_landmarks,
self.feature_extractor.mp_holistic.POSE_CONNECTIONS,
self.feature_extractor.mp_drawing_styles.get_default_pose_landmarks_style()
)
return frame
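# Sketch: classify one pre-recorded clip with SingleSignPredictor instead of a
# live webcam. Illustration only; `video_path` is a hypothetical caller-supplied
# file. The wall-clock cutoff in process_frame is disabled so every frame of
# the clip is recorded regardless of processing speed.
def _predict_sign_from_file(predictor, video_path):
    predictor.recording_duration = float('inf')  # disable the timed cutoff
    cap = cv2.VideoCapture(video_path)
    predictor.start_recording()
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        predictor.process_frame(frame)
    cap.release()
    predictor.stop_recording()  # interpolates, runs the model, stores results
    return predictor.last_top_predictions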
def main():
parser = argparse.ArgumentParser(description='Sign Language Recognition - Choose Mode')
parser.add_argument('--model', default='tsflow/models/best_model.pt',
help='Path to trained model')
parser.add_argument('--config', default='tsflow/results/test_results.json',
help='Path to config file with label mappings')
parser.add_argument('--camera', type=int, default=0,
help='Camera index')
parser.add_argument('--sequence_length', type=int, default=50,
help='Sequence length for prediction')
parser.add_argument('--confidence_threshold', type=float, default=0.5,
help='Confidence threshold for predictions')
parser.add_argument('--mode', choices=['realtime', 'single'], default='single',
help='Recognition mode: realtime (continuous) or single (one-by-one)')
parser.add_argument('--recording_duration', type=float, default=4.0,
help='Duration to record each sign in single mode (seconds)')
    parser.add_argument('--no_segmentation', action='store_true',
                        help='Disable human segmentation (enabled by default)')
args = parser.parse_args()
# Check if model and config files exist
if not os.path.exists(args.model):
print(f"Model file not found: {args.model}")
return
if not os.path.exists(args.config):
print(f"Config file not found: {args.config}")
return
    # Segmentation is enabled by default; --no_segmentation turns it off
    use_segmentation = not args.no_segmentation
if args.mode == 'single':
# Single sign mode
predictor = SingleSignPredictor(
model_path=args.model,
config_path=args.config,
sequence_length=args.sequence_length,
recording_duration=args.recording_duration,
use_segmentation=use_segmentation
)
# Initialize camera
cap = cv2.VideoCapture(args.camera)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
cap.set(cv2.CAP_PROP_FPS, 30)
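        # Note: cap.set() calls are requests; the backend may silently ignore
        # unsupported width/height/FPS values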
if not cap.isOpened():
print(f"Cannot open camera {args.camera}")
return
print("\n" + "="*60)
print("Single Sign Language Recognition")
print("="*60)
print("Controls:")
print(" SPACE: Start/Stop recording a sign")
print(" 'c': Clear last prediction")
print(" 'q': Quit")
print("="*60)
# FPS calculation
fps_counter = 0
fps_start_time = time.time()
current_fps = 0
while True:
ret, frame = cap.read()
if not ret:
print("Failed to read frame from camera")
break
# Mirror frame horizontally
frame = cv2.flip(frame, 1)
# Process frame
results = predictor.process_frame(frame)
# Draw landmarks
frame = predictor.draw_landmarks(frame, results)
# Calculate FPS
fps_counter += 1
if fps_counter % 30 == 0:
fps_end_time = time.time()
current_fps = 30 / (fps_end_time - fps_start_time)
fps_start_time = fps_end_time
# Draw UI
h, w, _ = frame.shape
# Main info panel
cv2.rectangle(frame, (10, 10), (w-10, 200), (0, 0, 0), -1)
cv2.rectangle(frame, (10, 10), (w-10, 200), (255, 255, 255), 2)
# FPS
cv2.putText(frame, f"FPS: {current_fps:.1f}", (20, 35),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
# Recording status
if predictor.is_recording:
elapsed = time.time() - predictor.recording_start_time
remaining = max(0, args.recording_duration - elapsed)
progress = elapsed / args.recording_duration
# Recording indicator
cv2.putText(frame, "RECORDING", (20, 65),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
cv2.putText(frame, f"Time: {remaining:.1f}s", (20, 90),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
# Progress bar
bar_width = w - 40
bar_height = 15
cv2.rectangle(frame, (20, 100), (20 + bar_width, 100 + bar_height), (100, 100, 100), -1)
cv2.rectangle(frame, (20, 100), (20 + int(bar_width * progress), 100 + bar_height), (0, 0, 255), -1)
# Recording circle (blinking effect)
                if int(elapsed * 4) % 2 == 0:  # toggle every 0.25 s (0.5 s blink cycle)
cv2.circle(frame, (w - 40, 40), 15, (0, 0, 255), -1)
else:
cv2.putText(frame, "READY - Press SPACE to record", (20, 65),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
# Show last prediction if available
if predictor.last_prediction and predictor.last_prediction != "Not enough data":
y_offset = 100
cv2.putText(frame, "Last Prediction:", (20, y_offset),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
# Top prediction
cv2.putText(frame, f"1. {predictor.last_prediction}: {predictor.last_confidence:.3f}",
(20, y_offset + 25), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
                # Runner-up predictions (ranks 2-4)
for i, (label, conf) in enumerate(predictor.last_top_predictions[1:4], 2):
cv2.putText(frame, f"{i}. {label}: {conf:.3f}",
(20, y_offset + 25 * i), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
# Instructions
cv2.putText(frame, "SPACE: Record | C: Clear | Q: Quit", (20, h-20),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
# Show frame
cv2.imshow('Single Sign Language Recognition', frame)
# Handle key presses
key = cv2.waitKey(1) & 0xFF
if key == ord('q'):
break
elif key == ord(' '): # Space bar
if not predictor.is_recording:
predictor.start_recording()
else:
predictor.stop_recording()
elif key == ord('c'):
predictor.last_prediction = None
predictor.last_confidence = 0.0
predictor.last_top_predictions = []
print("Prediction cleared")
else:
# Realtime mode
predictor = RealtimeSignPredictor(
model_path=args.model,
config_path=args.config,
sequence_length=args.sequence_length,
confidence_threshold=args.confidence_threshold,
use_segmentation=use_segmentation
)
# Initialize camera
cap = cv2.VideoCapture(args.camera)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
cap.set(cv2.CAP_PROP_FPS, 30)
if not cap.isOpened():
print(f"Cannot open camera {args.camera}")
return
print("Starting real-time sign language recognition...")
print("Press 'q' to quit, 'r' to reset sequence")
# FPS calculation
fps_counter = 0
fps_start_time = time.time()
current_fps = 0
while True:
ret, frame = cap.read()
if not ret:
print("Failed to read frame from camera")
break
# Mirror frame horizontally
frame = cv2.flip(frame, 1)
# Process frame
results, keypoints, flow_features = predictor.process_frame(frame)
# Draw landmarks
frame = predictor.draw_landmarks(frame, results)
# Get predictions
top_predictions = predictor.get_top_predictions(top_k=3)
# Calculate FPS
fps_counter += 1
if fps_counter % 30 == 0:
fps_end_time = time.time()
current_fps = 30 / (fps_end_time - fps_start_time)
fps_start_time = fps_end_time
# Draw information on frame
h, w, _ = frame.shape
# Background for text
cv2.rectangle(frame, (10, 10), (w-10, 150), (0, 0, 0), -1)
cv2.rectangle(frame, (10, 10), (w-10, 150), (255, 255, 255), 2)
# FPS
cv2.putText(frame, f"FPS: {current_fps:.1f}", (20, 35),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
# Sequence progress
progress = len(predictor.keypoint_sequence) / args.sequence_length
cv2.putText(frame, f"Sequence: {len(predictor.keypoint_sequence)}/{args.sequence_length}",
(20, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
# Progress bar
bar_width = w - 40
bar_height = 10
cv2.rectangle(frame, (20, 70), (20 + bar_width, 70 + bar_height), (100, 100, 100), -1)
cv2.rectangle(frame, (20, 70), (20 + int(bar_width * progress), 70 + bar_height), (0, 255, 0), -1)
# Predictions
y_offset = 100
if top_predictions:
for i, (label, confidence) in enumerate(top_predictions):
color = (0, 255, 0) if confidence > args.confidence_threshold else (0, 255, 255)
text = f"{i+1}. {label}: {confidence:.2f}"
cv2.putText(frame, text, (20, y_offset + i * 25),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
else:
cv2.putText(frame, "Collecting frames...", (20, y_offset),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
# Instructions
cv2.putText(frame, "Press 'q' to quit, 'r' to reset", (20, h-20),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
# Show frame
cv2.imshow('Real-time Sign Language Recognition', frame)
# Handle key presses
key = cv2.waitKey(1) & 0xFF
if key == ord('q'):
break
elif key == ord('r'):
predictor.keypoint_sequence.clear()
predictor.flow_sequence.clear()
print("Sequence reset")
# Cleanup
cap.release()
cv2.destroyAllWindows()
print("Recognition stopped")
if __name__ == "__main__":
main()