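"""Gradio Space: Alphabet Sign Language (ASL) detection on video.

Runs the prithivMLmods/Alphabet-Sign-Language-Detection SigLIP classifier on
each frame of an uploaded or webcam video, streams per-frame letter
predictions, and finishes with the list of unique letters seen.
"""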
# Import necessary libraries
import traceback
import gradio as gr
from transformers import AutoImageProcessor
from transformers import SiglipForImageClassification
from transformers.image_utils import load_image
from PIL import Image
import torch
import cv2 # Import cv2 for video frame processing
# Load model and processor for Alphabet Sign Language Detection
model_name = "prithivMLmods/Alphabet-Sign-Language-Detection"
model = SiglipForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)
# A prediction must repeat for more than this many consecutive frames before it is yielded (simple debounce)
MAX_CONSECUTIVE_REPETITIONS = 3
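# Optional (not in the original Space): move the model to a GPU when one is
# available. Frame inputs would then also need .to(device) before inference.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)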
def sign_language_classification(video):
    """
    Predict the ASL alphabet letter for each frame of a video.

    Yields (real-time prediction, summary) tuples as frames are processed,
    debouncing repeated predictions, and finishes with the ordered list of
    unique predicted letters.
    """
    print("sign_language_classification function called.")  # Debug: function entry
    if video is None:
        print("No video provided.")  # Debug: missing input
        yield "No video provided.", ""  # Empty second output when there is no video
        return
    print(f"Video input type: {type(video)}")  # Debug: input type
    print(f"Video value: {video}")  # Debug: input value
    predicted_letters = []  # All predicted letters, one per frame
    last_predicted_label = None  # Last prediction, used for repetition handling
    consecutive_repetitions = 0  # Count of consecutive identical predictions
    try:
        print("Starting frame processing loop.")  # Debug: start of frame processing
        frames = []
        if isinstance(video, str):
            # Video is a filepath (e.g., an uploaded file): read frames with OpenCV
            cap = cv2.VideoCapture(video)
            if not cap.isOpened():
                yield "Error: Could not open video file.", ""
                return
            while True:
                ret, frame = cap.read()
                if not ret:  # No more frames
                    break
                frames.append(frame)
            cap.release()  # Release the video capture object
        elif isinstance(video, list):
            # Video is already a list of frames (e.g., webcam input in some Gradio versions)
            frames = video
        else:
            yield "Error: Unsupported video input type.", ""
            return
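        # Optional tweak (not in the original Space): long videos are slow to
        # process frame by frame; subsampling, e.g. keeping every 5th frame,
        # trades temporal resolution for speed:
        # frames = frames[::5]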
        # Mapping from model output indices to ASL alphabet letters
        # (constant, so it is defined once, outside the frame loop)
        labels = {
            "0": "A", "1": "B", "2": "C", "3": "D", "4": "E", "5": "F", "6": "G", "7": "H", "8": "I", "9": "J",
            "10": "K", "11": "L", "12": "M", "13": "N", "14": "O", "15": "P", "16": "Q", "17": "R", "18": "S", "19": "T",
            "20": "U", "21": "V", "22": "W", "23": "X", "24": "Y", "25": "Z"
        }
        for i, frame in enumerate(frames):
            # OpenCV frames are BGR; swap channels before building the RGB PIL image
            # (PIL's .convert("RGB") alone does not reorder the channels)
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            # Prepare the frame for model input
            inputs = processor(images=image, return_tensors="pt")
            # Run inference without tracking gradients
            with torch.no_grad():
                outputs = model(**inputs)
            logits = outputs.logits  # Raw output scores
            probs = torch.nn.functional.softmax(logits, dim=1).squeeze().tolist()  # Class probabilities
            # Pick the most probable class and map it to its letter
            predicted_label_index = probs.index(max(probs))
            predicted_label = labels[str(predicted_label_index)]
            predicted_letters.append(predicted_label)  # Record the letter for the summary
            # Debounce: count how many consecutive frames produced the same label
            if predicted_label == last_predicted_label:
                consecutive_repetitions += 1
            else:
                consecutive_repetitions = 1  # Prediction changed, reset the count
            # Yield only once a label has been stable for more than
            # MAX_CONSECUTIVE_REPETITIONS frames, or on the very first prediction
            if consecutive_repetitions > MAX_CONSECUTIVE_REPETITIONS or last_predicted_label is None:
                yield predicted_label, ""  # Real-time prediction; summary comes at the end
            last_predicted_label = predicted_label  # Remember the latest label
        print("Finished frame processing loop.")  # Debug: end of frame processing
        # Unique predicted letters, in order of first appearance
        unique_predicted_letters = list(dict.fromkeys(predicted_letters))
        final_output = ", ".join(unique_predicted_letters)  # Comma-separated summary
        # Yield the last prediction (or empty string) together with the summary list
        yield last_predicted_label if last_predicted_label is not None else "", final_output
    except Exception as e:
        print(f"Error caught: {e}")  # Debug: exception encountered
        # Flatten the traceback onto one line so it displays cleanly
        flat_traceback = traceback.format_exc().replace(chr(10), ' | ').replace(chr(13), '')
        # Yield the error message and the flattened traceback
        yield f"Error processing video: {e}", f"Error processing video: {e}\n{flat_traceback}"
# Custom CSS for styling (commented out)
# custom_css = """
# body {
# background-color: #add8e6;
# }
# """
# Create Gradio interface with video input and multiple outputs
iface = gr.Interface(
    fn=sign_language_classification,  # Generator function driving the interface
    inputs=gr.Video(sources=["upload", "webcam"]),  # Video input: upload or webcam
    outputs=[
        gr.Label(label="Real-time Prediction"),  # Live per-frame prediction
        gr.Textbox(label="Unique Predicted Letters")  # Final summary of unique letters
    ],
    title="ASL Translator",  # Title of the Gradio interface
    description="Upload a video or use your webcam to translate ASL into one of the 26 sign language alphabet categories, with predictions shown in real time plus a summary list. ASL Words Translator coming soon!",
    # css=custom_css  # Apply custom CSS (commented out)
)
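# Note (version-dependent): older Gradio 3.x releases require the queue to be
# enabled for generator functions to stream outputs; recent versions enable it
# by default. If streaming does not work, try:
# iface.queue()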
# Launch the Gradio app
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)