import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import mediapipe as mp
import json
import time
from collections import deque
import argparse
class FeatureExtractor:
    def __init__(self, use_segmentation=True):
        # Initialize MediaPipe models
        self.mp_holistic = mp.solutions.holistic
        self.mp_drawing = mp.solutions.drawing_utils
        self.mp_drawing_styles = mp.solutions.drawing_styles
        self.mp_selfie_segmentation = mp.solutions.selfie_segmentation

        # Segmentation settings
        self.use_segmentation = use_segmentation
        self.segmentation = None
        if self.use_segmentation:
            self.segmentation = self.mp_selfie_segmentation.SelfieSegmentation(model_selection=1)

        # Optical flow parameters
        self.optical_flow_params = dict(
            flow=None,
            pyr_scale=0.5,
            levels=3,
            winsize=15,
            iterations=3,
            poly_n=5,
            poly_sigma=1.2,
            flags=0
        )

    def extract_pose_keypoints(self, frame, holistic_results):
        """Extract hand and body pose keypoints as a flat feature vector"""
        keypoints = []

        # Extract hand keypoints
        if holistic_results.left_hand_landmarks:
            for landmark in holistic_results.left_hand_landmarks.landmark:
                keypoints.extend([landmark.x, landmark.y, landmark.z])
        else:
            keypoints.extend([0] * (21 * 3))

        if holistic_results.right_hand_landmarks:
            for landmark in holistic_results.right_hand_landmarks.landmark:
                keypoints.extend([landmark.x, landmark.y, landmark.z])
        else:
            keypoints.extend([0] * (21 * 3))

        # Extract pose keypoints
        if holistic_results.pose_landmarks:
            for landmark in holistic_results.pose_landmarks.landmark:
                keypoints.extend([landmark.x, landmark.y, landmark.z])
        else:
            keypoints.extend([0] * (33 * 3))

        return np.array(keypoints)
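    # Note: the flattened keypoint vector has 21*3 + 21*3 + 33*3 = 225 values
    # (left hand, right hand, body pose), which is where the input_dim=225
    # used when constructing SignLanguageModel below comes from.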
    def create_hand_mask(self, frame, left_hand_landmarks, right_hand_landmarks, pose_landmarks):
        """Create ROI mask for hands and upper body"""
        h, w = frame.shape[:2]
        mask = np.zeros((h, w), dtype=np.uint8)

        def draw_landmarks_on_mask(landmarks, radius=15):
            if landmarks:
                for landmark in landmarks.landmark:
                    x, y = int(landmark.x * w), int(landmark.y * h)
                    if 0 <= x < w and 0 <= y < h:
                        cv2.circle(mask, (x, y), radius=radius, color=255, thickness=-1)

        # Draw hand keypoints
        draw_landmarks_on_mask(left_hand_landmarks, radius=20)
        draw_landmarks_on_mask(right_hand_landmarks, radius=20)

        # Draw upper body keypoints
        if pose_landmarks:
            upper_body_indices = list(range(0, 25))
            for idx in upper_body_indices:
                if idx < len(pose_landmarks.landmark):
                    landmark = pose_landmarks.landmark[idx]
                    x, y = int(landmark.x * w), int(landmark.y * h)
                    if 0 <= x < w and 0 <= y < h:
                        cv2.circle(mask, (x, y), radius=10, color=255, thickness=-1)

        # Dilate mask
        kernel = np.ones((15, 15), np.uint8)
        dilated_mask = cv2.dilate(mask, kernel, iterations=1)
        return dilated_mask

    def compute_regional_optical_flow(self, prev_frame, curr_frame, mask, downscale=0.5):
        """Compute optical flow only in masked regions"""
        if downscale < 1.0:
            h, w = prev_frame.shape[:2]
            new_h, new_w = int(h * downscale), int(w * downscale)
            prev_small = cv2.resize(prev_frame, (new_w, new_h))
            curr_small = cv2.resize(curr_frame, (new_w, new_h))
            mask_small = cv2.resize(mask, (new_w, new_h))
        else:
            prev_small = prev_frame
            curr_small = curr_frame
            mask_small = mask

        # Convert to grayscale
        prev_gray = cv2.cvtColor(prev_small, cv2.COLOR_BGR2GRAY)
        curr_gray = cv2.cvtColor(curr_small, cv2.COLOR_BGR2GRAY)

        # Compute optical flow
        flow = cv2.calcOpticalFlowFarneback(
            prev_gray, curr_gray,
            self.optical_flow_params['flow'],
            self.optical_flow_params['pyr_scale'],
            self.optical_flow_params['levels'],
            self.optical_flow_params['winsize'],
            self.optical_flow_params['iterations'],
            self.optical_flow_params['poly_n'],
            self.optical_flow_params['poly_sigma'],
            self.optical_flow_params['flags']
        )

        # Extract flow features from masked region
        bool_mask = mask_small > 0
        if np.any(bool_mask):
            fx = flow[..., 0][bool_mask]
            fy = flow[..., 1][bool_mask]
            flow_features = np.array([
                np.mean(fx), np.std(fx),
                np.mean(fy), np.std(fy),
                np.percentile(fx, 25), np.percentile(fx, 75),
                np.percentile(fy, 25), np.percentile(fy, 75),
                np.max(np.abs(fx)), np.max(np.abs(fy))
            ], dtype=np.float16)
        else:
            flow_features = np.zeros(10, dtype=np.float16)

        return flow_features
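    # The 10 masked-flow statistics above (mean/std, 25th/75th percentiles, and
    # max |fx| / |fy|) make up the motion feature vector that the model consumes
    # as flow_dim=10.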
    def apply_segmentation_mask(self, frame):
        """Apply human segmentation to focus on person area"""
        if not self.use_segmentation or self.segmentation is None:
            return frame, None

        try:
            # Convert BGR to RGB for MediaPipe
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_rgb.flags.writeable = False

            # Process segmentation
            results = self.segmentation.process(frame_rgb)
            segmentation_mask = results.segmentation_mask

            if segmentation_mask is not None:
                # Resize mask to match frame size
                h, w = frame.shape[:2]
                mask = cv2.resize(segmentation_mask, (w, h))

                # Convert to 3-channel mask
                mask_3channel = np.stack((mask,) * 3, axis=-1)

                # Apply Gaussian blur to smooth edges
                mask_3channel = cv2.GaussianBlur(mask_3channel, (5, 5), 0)

                # Create segmented frame
                segmented_frame = frame * mask_3channel

                # Binary mask for optical flow processing
                binary_mask = (mask > 0.5).astype(np.uint8) * 255

                return segmented_frame.astype(np.uint8), binary_mask
            else:
                return frame, None
        except Exception as e:
            print(f"Segmentation error: {e}")
            return frame, None

    def create_enhanced_hand_mask(self, frame, left_hand_landmarks, right_hand_landmarks, pose_landmarks, seg_mask=None):
        """Create enhanced ROI mask combining landmarks and segmentation"""
        h, w = frame.shape[:2]
        mask = np.zeros((h, w), dtype=np.uint8)

        def draw_landmarks_on_mask(landmarks, radius=15):
            if landmarks:
                for landmark in landmarks.landmark:
                    x, y = int(landmark.x * w), int(landmark.y * h)
                    if 0 <= x < w and 0 <= y < h:
                        cv2.circle(mask, (x, y), radius=radius, color=255, thickness=-1)

        # Draw hand keypoints with larger radius
        draw_landmarks_on_mask(left_hand_landmarks, radius=25)
        draw_landmarks_on_mask(right_hand_landmarks, radius=25)

        # Draw upper body keypoints
        if pose_landmarks:
            upper_body_indices = list(range(0, 25))
            for idx in upper_body_indices:
                if idx < len(pose_landmarks.landmark):
                    landmark = pose_landmarks.landmark[idx]
                    x, y = int(landmark.x * w), int(landmark.y * h)
                    if 0 <= x < w and 0 <= y < h:
                        cv2.circle(mask, (x, y), radius=12, color=255, thickness=-1)

        # Combine with segmentation mask if available
        if seg_mask is not None:
            seg_mask_resized = cv2.resize(seg_mask, (w, h))
            mask = cv2.bitwise_and(mask, seg_mask_resized)

        # Dilate mask
        kernel = np.ones((20, 20), np.uint8)
        dilated_mask = cv2.dilate(mask, kernel, iterations=2)
        return dilated_mask
class SignLanguageModel(nn.Module):
    """Sign Language Recognition Model"""

    def __init__(self, input_dim, hidden_dim, num_layers, num_classes, dropout=0.5, flow_dim=10):
        super(SignLanguageModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_classes = num_classes

        # Keypoint feature projection
        self.keypoint_projection = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout / 2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout / 2)
        )

        # Flow feature projection
        self.flow_projection = nn.Sequential(
            nn.Linear(flow_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout / 2),
            nn.Linear(hidden_dim // 2, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout / 2)
        )

        # Feature fusion
        self.fusion_layer = nn.Sequential(
            nn.Linear(hidden_dim + (hidden_dim // 2), hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout / 2)
        )

        # Bidirectional LSTM
        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )

        # GRU for additional temporal features
        self.gru = nn.GRU(
            input_size=hidden_dim * 2,
            hidden_size=hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )

        # Batch normalization
        self.lstm_bn = nn.BatchNorm1d(hidden_dim * 2)
        self.gru_bn = nn.BatchNorm1d(hidden_dim * 2)

        # Multi-head attention
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim * 2,
            num_heads=4,
            dropout=dropout,
            batch_first=True
        )

        # Attention mechanism
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Softmax(dim=1)
        )

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.BatchNorm1d(hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout / 2),
            nn.Linear(hidden_dim, num_classes)
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize model weights"""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.LSTM, nn.GRU)):
                for name, param in m.named_parameters():
                    if 'weight' in name:
                        nn.init.orthogonal_(param)
                    elif 'bias' in name:
                        nn.init.zeros_(param)
    def forward(self, keypoints, flow=None):
        """Forward pass"""
        batch_size, seq_len, _ = keypoints.size()

        # Process keypoint features
        kp_reshaped = keypoints.reshape(-1, keypoints.size(-1))

        # First layer
        kp_projected = self.keypoint_projection[0](kp_reshaped)
        kp_projected = kp_projected.reshape(batch_size, seq_len, -1)
        kp_projected = kp_projected.transpose(1, 2)
        kp_projected = self.keypoint_projection[1](kp_projected)
        kp_projected = kp_projected.transpose(1, 2)
        kp_projected = self.keypoint_projection[2](kp_projected)
        kp_projected = self.keypoint_projection[3](kp_projected)

        # Second layer
        kp_projected_reshaped = kp_projected.reshape(-1, kp_projected.size(-1))
        kp_projected = self.keypoint_projection[4](kp_projected_reshaped)
        kp_projected = kp_projected.reshape(batch_size, seq_len, -1)
        kp_projected = kp_projected.transpose(1, 2)
        kp_projected = self.keypoint_projection[5](kp_projected)
        kp_projected = kp_projected.transpose(1, 2)
        kp_projected = self.keypoint_projection[6](kp_projected)
        kp_projected = self.keypoint_projection[7](kp_projected)

        # Process flow features if provided
        if flow is not None:
            flow_reshaped = flow.reshape(-1, flow.size(-1))

            # First layer
            flow_projected = self.flow_projection[0](flow_reshaped)
            flow_projected = flow_projected.reshape(batch_size, seq_len, -1)
            flow_projected = flow_projected.transpose(1, 2)
            flow_projected = self.flow_projection[1](flow_projected)
            flow_projected = flow_projected.transpose(1, 2)
            flow_projected = self.flow_projection[2](flow_projected)
            flow_projected = self.flow_projection[3](flow_projected)

            # Second layer
            flow_projected_reshaped = flow_projected.reshape(-1, flow_projected.size(-1))
            flow_projected = self.flow_projection[4](flow_projected_reshaped)
            flow_projected = flow_projected.reshape(batch_size, seq_len, -1)
            flow_projected = flow_projected.transpose(1, 2)
            flow_projected = self.flow_projection[5](flow_projected)
            flow_projected = flow_projected.transpose(1, 2)
            flow_projected = self.flow_projection[6](flow_projected)
            flow_projected = self.flow_projection[7](flow_projected)

            # Feature fusion
            combined_features = torch.cat([kp_projected, flow_projected], dim=2)
            combined_reshaped = combined_features.reshape(-1, combined_features.size(-1))
            fused_features = self.fusion_layer[0](combined_reshaped)
            fused_features = fused_features.reshape(batch_size, seq_len, -1)
            fused_features = fused_features.transpose(1, 2)
            fused_features = self.fusion_layer[1](fused_features)
            fused_features = fused_features.transpose(1, 2)
            fused_features = self.fusion_layer[2](fused_features)
            fused_features = self.fusion_layer[3](fused_features)
            x_projected = fused_features
        else:
            x_projected = kp_projected

        # Residual connection
        x_residual = x_projected

        # LSTM processing
        lstm_out, _ = self.lstm(x_projected)

        # Residual connection
        x_residual_expanded = torch.cat([x_residual, x_residual], dim=2)
        lstm_out_with_residual = lstm_out + x_residual_expanded

        # BatchNorm for LSTM output
        lstm_out_bn = lstm_out_with_residual.transpose(1, 2)
        lstm_out_bn = self.lstm_bn(lstm_out_bn)
        lstm_out = lstm_out_bn.transpose(1, 2)

        # GRU processing
        gru_out, _ = self.gru(lstm_out)

        # BatchNorm for GRU output
        gru_out_bn = gru_out.transpose(1, 2)
        gru_out_bn = self.gru_bn(gru_out_bn)
        gru_out = gru_out_bn.transpose(1, 2)

        # Multi-head attention
        attn_output, _ = self.multihead_attn(lstm_out, lstm_out, lstm_out)

        # Traditional attention
        attention_weights = self.attention(gru_out)
        context_gru = torch.bmm(gru_out.transpose(1, 2), attention_weights)
        context_gru = context_gru.squeeze(-1)

        attention_weights_attn = self.attention(attn_output)
        context_attn = torch.bmm(attn_output.transpose(1, 2), attention_weights_attn)
        context_attn = context_attn.squeeze(-1)

        # Combine contexts
        combined_context = torch.cat([context_gru, context_attn], dim=1)

        # Final classification
        output = self.classifier(combined_context)
        return output
class RealtimeSignPredictor:
    def __init__(self, model_path, config_path, sequence_length=50, confidence_threshold=0.5, use_segmentation=True):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.sequence_length = sequence_length
        self.confidence_threshold = confidence_threshold

        # Load configuration and label mapping
        with open(config_path, 'r') as f:
            config = json.load(f)
        self.label_mapping = config['label_mapping']
        self.idx_to_label = {int(k): v for k, v in self.label_mapping.items()}

        # Initialize model
        self.model = SignLanguageModel(
            input_dim=225,  # keypoint dimension
            hidden_dim=256,
            num_layers=2,
            num_classes=len(self.label_mapping),
            dropout=0.5,
            flow_dim=10
        )

        # Load trained weights
        checkpoint = torch.load(model_path, map_location=self.device)
        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.model.load_state_dict(checkpoint)
        self.model.to(self.device)
        self.model.eval()

        # Initialize feature extractor with segmentation
        self.feature_extractor = FeatureExtractor(use_segmentation=use_segmentation)

        # Initialize sequences for storing features
        self.keypoint_sequence = deque(maxlen=sequence_length)
        self.flow_sequence = deque(maxlen=sequence_length)

        # Variables for optical flow
        self.prev_frame = None
        self.prev_mask = None

        print(f"Model loaded successfully. Using device: {self.device}")
        print(f"Recognized classes: {list(self.idx_to_label.values())}")

    def _linear_interpolate_sequence(self, data, target_length):
        """Linear interpolation to adjust sequence length"""
        if len(data) == target_length:
            return np.array(data)

        data = np.array(data)
        original_length = len(data)
        feature_dim = data.shape[1]
        interpolated_data = np.zeros((target_length, feature_dim))
        for dim in range(feature_dim):
            original_indices = np.linspace(0, original_length - 1, original_length)
            target_indices = np.linspace(0, original_length - 1, target_length)
            interpolated_data[:, dim] = np.interp(target_indices, original_indices, data[:, dim])
        return interpolated_data

    def process_frame(self, frame):
        """Process a single frame and extract features with segmentation"""
        # Apply segmentation mask first
        segmented_frame, seg_mask = self.feature_extractor.apply_segmentation_mask(frame)

        # Convert to RGB for MediaPipe
        frame_rgb = cv2.cvtColor(segmented_frame, cv2.COLOR_BGR2RGB)
        frame_rgb.flags.writeable = False

        # Process with MediaPipe
        with self.feature_extractor.mp_holistic.Holistic(
                min_detection_confidence=0.5,
                min_tracking_confidence=0.5,
                model_complexity=1) as holistic:
            results = holistic.process(frame_rgb)
        frame_rgb.flags.writeable = True

        # Extract keypoints
        keypoints = self.feature_extractor.extract_pose_keypoints(segmented_frame, results)

        # Create enhanced hand mask with segmentation
        hand_mask = self.feature_extractor.create_enhanced_hand_mask(
            segmented_frame,
            results.left_hand_landmarks,
            results.right_hand_landmarks,
            results.pose_landmarks,
            seg_mask
        )

        # Calculate optical flow on segmented frame
        flow_features = np.zeros(10, dtype=np.float16)
        if self.prev_frame is not None and self.prev_mask is not None:
            flow_features = self.feature_extractor.compute_regional_optical_flow(
                self.prev_frame, segmented_frame, hand_mask, downscale=0.5
            )

        # Update previous frame and mask
        self.prev_frame = segmented_frame.copy()
        self.prev_mask = hand_mask

        # Add to sequences
        self.keypoint_sequence.append(keypoints)
        self.flow_sequence.append(flow_features)

        return results, keypoints, flow_features

    def predict(self):
        """Make prediction based on current sequence"""
        if len(self.keypoint_sequence) < self.sequence_length:
            return None, 0.0

        # Convert sequences to arrays and interpolate
        keypoints_array = self._linear_interpolate_sequence(
            list(self.keypoint_sequence), self.sequence_length
        )
        flow_array = self._linear_interpolate_sequence(
            list(self.flow_sequence), self.sequence_length
        )

        # Convert to tensors
        keypoints_tensor = torch.FloatTensor(keypoints_array).unsqueeze(0).to(self.device)
        flow_tensor = torch.FloatTensor(flow_array).unsqueeze(0).to(self.device)

        # Make prediction
        with torch.no_grad():
            outputs = self.model(keypoints_tensor, flow_tensor)
            probabilities = F.softmax(outputs, dim=1)
            max_prob, max_idx = torch.max(probabilities, 1)

        predicted_label = self.idx_to_label[max_idx.item()]
        confidence = max_prob.item()
        return predicted_label, confidence

    def get_top_predictions(self, top_k=3):
        """Get top-k predictions"""
        if len(self.keypoint_sequence) < self.sequence_length:
            return []

        # Convert sequences to arrays and interpolate
        keypoints_array = self._linear_interpolate_sequence(
            list(self.keypoint_sequence), self.sequence_length
        )
        flow_array = self._linear_interpolate_sequence(
            list(self.flow_sequence), self.sequence_length
        )

        # Convert to tensors
        keypoints_tensor = torch.FloatTensor(keypoints_array).unsqueeze(0).to(self.device)
        flow_tensor = torch.FloatTensor(flow_array).unsqueeze(0).to(self.device)

        # Make prediction
        with torch.no_grad():
            outputs = self.model(keypoints_tensor, flow_tensor)
            probabilities = F.softmax(outputs, dim=1)
            top_probs, top_indices = torch.topk(probabilities, k=min(top_k, len(self.idx_to_label)))

        predictions = []
        for i in range(top_indices.size(1)):
            idx = top_indices[0, i].item()
            prob = top_probs[0, i].item()
            label = self.idx_to_label[idx]
            predictions.append((label, prob))
        return predictions

    def draw_landmarks(self, frame, results):
        """Draw MediaPipe landmarks on frame"""
        if results.left_hand_landmarks:
            self.feature_extractor.mp_drawing.draw_landmarks(
                frame, results.left_hand_landmarks,
                self.feature_extractor.mp_holistic.HAND_CONNECTIONS,
                self.feature_extractor.mp_drawing_styles.get_default_hand_landmarks_style(),
                self.feature_extractor.mp_drawing_styles.get_default_hand_connections_style()
            )
        if results.right_hand_landmarks:
            self.feature_extractor.mp_drawing.draw_landmarks(
                frame, results.right_hand_landmarks,
                self.feature_extractor.mp_holistic.HAND_CONNECTIONS,
                self.feature_extractor.mp_drawing_styles.get_default_hand_landmarks_style(),
                self.feature_extractor.mp_drawing_styles.get_default_hand_connections_style()
            )
        if results.pose_landmarks:
            self.feature_extractor.mp_drawing.draw_landmarks(
                frame, results.pose_landmarks,
                self.feature_extractor.mp_holistic.POSE_CONNECTIONS,
                self.feature_extractor.mp_drawing_styles.get_default_pose_landmarks_style()
            )
        return frame
class SingleSignPredictor:
    def __init__(self, model_path, config_path, sequence_length=50, recording_duration=4.0, use_segmentation=True):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.sequence_length = sequence_length
        self.recording_duration = recording_duration  # seconds to record each sign

        # Load configuration and label mapping
        with open(config_path, 'r') as f:
            config = json.load(f)
        self.label_mapping = config['label_mapping']
        self.idx_to_label = {int(k): v for k, v in self.label_mapping.items()}

        # Initialize model
        self.model = SignLanguageModel(
            input_dim=225,  # keypoint dimension
            hidden_dim=256,
            num_layers=2,
            num_classes=len(self.label_mapping),
            dropout=0.5,
            flow_dim=10
        )

        # Load trained weights
        checkpoint = torch.load(model_path, map_location=self.device)
        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.model.load_state_dict(checkpoint)
        self.model.to(self.device)
        self.model.eval()

        # Initialize feature extractor with segmentation
        self.feature_extractor = FeatureExtractor(use_segmentation=use_segmentation)

        # Recording state
        self.is_recording = False
        self.recording_start_time = None
        self.recorded_keypoints = []
        self.recorded_flow = []
        self.prev_frame = None
        self.prev_mask = None

        # Results
        self.last_prediction = None
        self.last_confidence = 0.0
        self.last_top_predictions = []

        print(f"Model loaded successfully. Using device: {self.device}")
        print(f"Recording duration: {self.recording_duration} seconds")
        print(f"Recognized classes: {list(self.idx_to_label.values())}")

    def _linear_interpolate_sequence(self, data, target_length):
        """Linear interpolation to adjust sequence length"""
        if len(data) == target_length:
            return np.array(data)

        data = np.array(data)
        original_length = len(data)
        feature_dim = data.shape[1]
        interpolated_data = np.zeros((target_length, feature_dim))
        for dim in range(feature_dim):
            original_indices = np.linspace(0, original_length - 1, original_length)
            target_indices = np.linspace(0, original_length - 1, target_length)
            interpolated_data[:, dim] = np.interp(target_indices, original_indices, data[:, dim])
        return interpolated_data

    def start_recording(self):
        """Start recording a new sign"""
        self.is_recording = True
        self.recording_start_time = time.time()
        self.recorded_keypoints = []
        self.recorded_flow = []
        self.prev_frame = None
        self.prev_mask = None
        print("Recording started...")

    def stop_recording(self):
        """Stop recording and make prediction"""
        if not self.is_recording:
            return

        self.is_recording = False
        print(f"Recording stopped. Collected {len(self.recorded_keypoints)} frames")

        if len(self.recorded_keypoints) < 10:  # Need minimum frames
            print("Not enough frames for prediction")
            self.last_prediction = "Not enough data"
            self.last_confidence = 0.0
            self.last_top_predictions = []
            return

        # Interpolate to target length
        keypoints_array = self._linear_interpolate_sequence(
            self.recorded_keypoints, self.sequence_length
        )
        flow_array = self._linear_interpolate_sequence(
            self.recorded_flow, self.sequence_length
        )

        # Convert to tensors
        keypoints_tensor = torch.FloatTensor(keypoints_array).unsqueeze(0).to(self.device)
        flow_tensor = torch.FloatTensor(flow_array).unsqueeze(0).to(self.device)

        # Make prediction
        with torch.no_grad():
            outputs = self.model(keypoints_tensor, flow_tensor)
            probabilities = F.softmax(outputs, dim=1)

            # Get top-5 predictions
            top_probs, top_indices = torch.topk(probabilities, k=min(5, len(self.idx_to_label)))

        predictions = []
        for i in range(top_indices.size(1)):
            idx = top_indices[0, i].item()
            prob = top_probs[0, i].item()
            label = self.idx_to_label[idx]
            predictions.append((label, prob))

        # Store results
        self.last_prediction = predictions[0][0]
        self.last_confidence = predictions[0][1]
        self.last_top_predictions = predictions
        print(f"Prediction: {self.last_prediction} (confidence: {self.last_confidence:.3f})")

    def process_frame(self, frame):
        """Process a single frame with segmentation"""
        # Apply segmentation mask first
        segmented_frame, seg_mask = self.feature_extractor.apply_segmentation_mask(frame)

        # Convert to RGB for MediaPipe
        frame_rgb = cv2.cvtColor(segmented_frame, cv2.COLOR_BGR2RGB)
        frame_rgb.flags.writeable = False

        # Process with MediaPipe
        with self.feature_extractor.mp_holistic.Holistic(
                min_detection_confidence=0.5,
                min_tracking_confidence=0.5,
                model_complexity=1) as holistic:
            results = holistic.process(frame_rgb)
        frame_rgb.flags.writeable = True

        # Extract keypoints
        keypoints = self.feature_extractor.extract_pose_keypoints(segmented_frame, results)

        # Create enhanced hand mask with segmentation
        hand_mask = self.feature_extractor.create_enhanced_hand_mask(
            segmented_frame,
            results.left_hand_landmarks,
            results.right_hand_landmarks,
            results.pose_landmarks,
            seg_mask
        )

        # Calculate optical flow on segmented frame
        flow_features = np.zeros(10, dtype=np.float16)
        if self.prev_frame is not None and self.prev_mask is not None:
            flow_features = self.feature_extractor.compute_regional_optical_flow(
                self.prev_frame, segmented_frame, hand_mask, downscale=0.5
            )

        # If recording, store the features
        if self.is_recording:
            self.recorded_keypoints.append(keypoints)
            self.recorded_flow.append(flow_features)

            # Check if recording duration is reached
            if time.time() - self.recording_start_time >= self.recording_duration:
                self.stop_recording()

        # Update previous frame and mask
        self.prev_frame = segmented_frame.copy()
        self.prev_mask = hand_mask

        return results

    def draw_landmarks(self, frame, results):
        """Draw MediaPipe landmarks on frame"""
        if results.left_hand_landmarks:
            self.feature_extractor.mp_drawing.draw_landmarks(
                frame, results.left_hand_landmarks,
                self.feature_extractor.mp_holistic.HAND_CONNECTIONS,
                self.feature_extractor.mp_drawing_styles.get_default_hand_landmarks_style(),
                self.feature_extractor.mp_drawing_styles.get_default_hand_connections_style()
            )
        if results.right_hand_landmarks:
            self.feature_extractor.mp_drawing.draw_landmarks(
                frame, results.right_hand_landmarks,
                self.feature_extractor.mp_holistic.HAND_CONNECTIONS,
                self.feature_extractor.mp_drawing_styles.get_default_hand_landmarks_style(),
                self.feature_extractor.mp_drawing_styles.get_default_hand_connections_style()
            )
        if results.pose_landmarks:
            self.feature_extractor.mp_drawing.draw_landmarks(
                frame, results.pose_landmarks,
                self.feature_extractor.mp_holistic.POSE_CONNECTIONS,
                self.feature_extractor.mp_drawing_styles.get_default_pose_landmarks_style()
            )
        return frame
def main():
    parser = argparse.ArgumentParser(description='Sign Language Recognition - Choose Mode')
    parser.add_argument('--model', default='tsflow/models/best_model.pt',
                        help='Path to trained model')
    parser.add_argument('--config', default='tsflow/results/test_results.json',
                        help='Path to config file with label mappings')
    parser.add_argument('--camera', type=int, default=0,
                        help='Camera index')
    parser.add_argument('--sequence_length', type=int, default=50,
                        help='Sequence length for prediction')
    parser.add_argument('--confidence_threshold', type=float, default=0.5,
                        help='Confidence threshold for predictions')
    parser.add_argument('--mode', choices=['realtime', 'single'], default='single',
                        help='Recognition mode: realtime (continuous) or single (one-by-one)')
    parser.add_argument('--recording_duration', type=float, default=4.0,
                        help='Duration to record each sign in single mode (seconds)')
    parser.add_argument('--use_segmentation', action='store_true', default=True,
                        help='Enable human segmentation for background removal')
    parser.add_argument('--no_segmentation', action='store_true', default=False,
                        help='Disable human segmentation')
    args = parser.parse_args()

    # Check if model and config files exist
    if not os.path.exists(args.model):
        print(f"Model file not found: {args.model}")
        return
    if not os.path.exists(args.config):
        print(f"Config file not found: {args.config}")
        return

    # Determine segmentation setting
    use_segmentation = args.use_segmentation and not args.no_segmentation

    if args.mode == 'single':
        # Single sign mode
        predictor = SingleSignPredictor(
            model_path=args.model,
            config_path=args.config,
            sequence_length=args.sequence_length,
            recording_duration=args.recording_duration,
            use_segmentation=use_segmentation
        )

        # Initialize camera
        cap = cv2.VideoCapture(args.camera)
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
        cap.set(cv2.CAP_PROP_FPS, 30)
        if not cap.isOpened():
            print(f"Cannot open camera {args.camera}")
            return

        print("\n" + "=" * 60)
        print("Single Sign Language Recognition")
        print("=" * 60)
        print("Controls:")
        print("  SPACE: Start/Stop recording a sign")
        print("  'c': Clear last prediction")
        print("  'q': Quit")
        print("=" * 60)

        # FPS calculation
        fps_counter = 0
        fps_start_time = time.time()
        current_fps = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                print("Failed to read frame from camera")
                break

            # Mirror frame horizontally
            frame = cv2.flip(frame, 1)

            # Process frame
            results = predictor.process_frame(frame)

            # Draw landmarks
            frame = predictor.draw_landmarks(frame, results)

            # Calculate FPS
            fps_counter += 1
            if fps_counter % 30 == 0:
                fps_end_time = time.time()
                current_fps = 30 / (fps_end_time - fps_start_time)
                fps_start_time = fps_end_time

            # Draw UI
            h, w, _ = frame.shape

            # Main info panel
            cv2.rectangle(frame, (10, 10), (w - 10, 200), (0, 0, 0), -1)
            cv2.rectangle(frame, (10, 10), (w - 10, 200), (255, 255, 255), 2)

            # FPS
            cv2.putText(frame, f"FPS: {current_fps:.1f}", (20, 35),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

            # Recording status
            if predictor.is_recording:
                elapsed = time.time() - predictor.recording_start_time
                remaining = max(0, args.recording_duration - elapsed)
                progress = elapsed / args.recording_duration

                # Recording indicator
                cv2.putText(frame, "RECORDING", (20, 65),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
                cv2.putText(frame, f"Time: {remaining:.1f}s", (20, 90),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)

                # Progress bar
                bar_width = w - 40
                bar_height = 15
                cv2.rectangle(frame, (20, 100), (20 + bar_width, 100 + bar_height), (100, 100, 100), -1)
                cv2.rectangle(frame, (20, 100), (20 + int(bar_width * progress), 100 + bar_height), (0, 0, 255), -1)

                # Recording circle (blinking effect)
                if int(elapsed * 4) % 2 == 0:  # Blink every 0.25 seconds
                    cv2.circle(frame, (w - 40, 40), 15, (0, 0, 255), -1)
            else:
                cv2.putText(frame, "READY - Press SPACE to record", (20, 65),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

                # Show last prediction if available
                if predictor.last_prediction and predictor.last_prediction != "Not enough data":
                    y_offset = 100
                    cv2.putText(frame, "Last Prediction:", (20, y_offset),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

                    # Top prediction
                    cv2.putText(frame, f"1. {predictor.last_prediction}: {predictor.last_confidence:.3f}",
                                (20, y_offset + 25), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

                    # Remaining top predictions (ranks 2-4)
                    for i, (label, conf) in enumerate(predictor.last_top_predictions[1:4], 2):
                        cv2.putText(frame, f"{i}. {label}: {conf:.3f}",
                                    (20, y_offset + 25 * i), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

            # Instructions
            cv2.putText(frame, "SPACE: Record | C: Clear | Q: Quit", (20, h - 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

            # Show frame
            cv2.imshow('Single Sign Language Recognition', frame)

            # Handle key presses
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                break
            elif key == ord(' '):  # Space bar
                if not predictor.is_recording:
                    predictor.start_recording()
                else:
                    predictor.stop_recording()
            elif key == ord('c'):
                predictor.last_prediction = None
                predictor.last_confidence = 0.0
                predictor.last_top_predictions = []
                print("Prediction cleared")
    else:
        # Realtime mode
        predictor = RealtimeSignPredictor(
            model_path=args.model,
            config_path=args.config,
            sequence_length=args.sequence_length,
            confidence_threshold=args.confidence_threshold,
            use_segmentation=use_segmentation
        )

        # Initialize camera
        cap = cv2.VideoCapture(args.camera)
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
        cap.set(cv2.CAP_PROP_FPS, 30)
        if not cap.isOpened():
            print(f"Cannot open camera {args.camera}")
            return

        print("Starting real-time sign language recognition...")
        print("Press 'q' to quit, 'r' to reset sequence")

        # FPS calculation
        fps_counter = 0
        fps_start_time = time.time()
        current_fps = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                print("Failed to read frame from camera")
                break

            # Mirror frame horizontally
            frame = cv2.flip(frame, 1)

            # Process frame
            results, keypoints, flow_features = predictor.process_frame(frame)

            # Draw landmarks
            frame = predictor.draw_landmarks(frame, results)

            # Get predictions
            top_predictions = predictor.get_top_predictions(top_k=3)

            # Calculate FPS
            fps_counter += 1
            if fps_counter % 30 == 0:
                fps_end_time = time.time()
                current_fps = 30 / (fps_end_time - fps_start_time)
                fps_start_time = fps_end_time

            # Draw information on frame
            h, w, _ = frame.shape

            # Background for text
            cv2.rectangle(frame, (10, 10), (w - 10, 150), (0, 0, 0), -1)
            cv2.rectangle(frame, (10, 10), (w - 10, 150), (255, 255, 255), 2)

            # FPS
            cv2.putText(frame, f"FPS: {current_fps:.1f}", (20, 35),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

            # Sequence progress
            progress = len(predictor.keypoint_sequence) / args.sequence_length
            cv2.putText(frame, f"Sequence: {len(predictor.keypoint_sequence)}/{args.sequence_length}",
                        (20, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

            # Progress bar
            bar_width = w - 40
            bar_height = 10
            cv2.rectangle(frame, (20, 70), (20 + bar_width, 70 + bar_height), (100, 100, 100), -1)
            cv2.rectangle(frame, (20, 70), (20 + int(bar_width * progress), 70 + bar_height), (0, 255, 0), -1)

            # Predictions
            y_offset = 100
            if top_predictions:
                for i, (label, confidence) in enumerate(top_predictions):
                    color = (0, 255, 0) if confidence > args.confidence_threshold else (0, 255, 255)
                    text = f"{i + 1}. {label}: {confidence:.2f}"
                    cv2.putText(frame, text, (20, y_offset + i * 25),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
            else:
                cv2.putText(frame, "Collecting frames...", (20, y_offset),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

            # Instructions
            cv2.putText(frame, "Press 'q' to quit, 'r' to reset", (20, h - 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

            # Show frame
            cv2.imshow('Real-time Sign Language Recognition', frame)

            # Handle key presses
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                break
            elif key == ord('r'):
                predictor.keypoint_sequence.clear()
                predictor.flow_sequence.clear()
                print("Sequence reset")

    # Cleanup
    cap.release()
    cv2.destroyAllWindows()
    print("Recognition stopped")


if __name__ == "__main__":
    main()
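
# Example invocation (the script's file name is not given above, so
# "sign_inference.py" here is a placeholder; adjust the model/config paths
# to wherever your checkpoint and label-mapping JSON actually live):
#   python sign_inference.py --mode single --recording_duration 4.0
#   python sign_inference.py --mode realtime --sequence_length 50 --no_segmentation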