# Import necessary libraries
import traceback

import cv2  # OpenCV, for reading video frames
import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, SiglipForImageClassification
# Load the model and processor for Alphabet Sign Language Detection
model_name = "prithivMLmods/Alphabet-Sign-Language-Detection"
model = SiglipForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)

# Mapping from model output indices to ASL alphabet letters
LABELS = {
    "0": "A", "1": "B", "2": "C", "3": "D", "4": "E", "5": "F", "6": "G", "7": "H", "8": "I", "9": "J",
    "10": "K", "11": "L", "12": "M", "13": "N", "14": "O", "15": "P", "16": "Q", "17": "R", "18": "S", "19": "T",
    "20": "U", "21": "V", "22": "W", "23": "X", "24": "Y", "25": "Z",
}

# Maximum number of consecutive repetitions of the same prediction to yield
MAX_CONSECUTIVE_REPETITIONS = 3
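# Optional (a sketch, not part of the original logic): move the model to a GPU
# when one is available. The per-frame `inputs` below would then also need
# .to(device) before the forward pass, so this is left commented out.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)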
def sign_language_classification(video):
    """
    Predicts the sign language alphabet letter for each frame of a video,
    yields predictions in real time with repetition handling, and finally
    returns the list of unique predicted letters.
    """
    print("sign_language_classification function called.")  # Debug print to indicate the function was called
    if video is None:
        print("No video provided.")  # Debug print if there is no video input
        yield "No video provided.", ""  # Yield an empty string for the second output when there is no video
        return
    print(f"Video input type: {type(video)}")  # Debug print showing the video input type
    print(f"Video value: {video}")  # Debug print showing the video input value
    predicted_letters = []  # All predicted letters, one per frame
    last_predicted_label = None  # Last predicted label, used for repetition handling
    consecutive_repetitions = 0  # Counter for consecutive repetitions of the same prediction
    try:
        print("Starting frame processing loop.")  # Debug print marking the start of frame processing
        frames = []
        if isinstance(video, str):
            # If video is a filepath (e.g., an uploaded file), read its frames with OpenCV
            cap = cv2.VideoCapture(video)
            if not cap.isOpened():
                yield "Error: Could not open video file.", ""  # Yield an error if the video file cannot be opened
                return
            while True:
                ret, frame = cap.read()
                if not ret:  # Stop once no more frames are returned
                    break
                # OpenCV decodes frames as BGR; convert to RGB so PIL and the model see correct colors
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            cap.release()  # Release the video capture object
        elif isinstance(video, list):
            # If video is already a list of (RGB) frames, e.g., from the webcam in some Gradio versions
            frames = video
        else:
            yield "Error: Unsupported video input type.", ""  # Yield an error for unsupported input types
            return
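        # Optional speed-up (a sketch, not part of the original logic): classifying
        # every frame of a long clip is slow on CPU, so you could subsample, e.g.,
        # keep every 5th frame. FRAME_STRIDE is a hypothetical tuning knob.
        # FRAME_STRIDE = 5
        # frames = frames[::FRAME_STRIDE]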
        for i, frame in enumerate(frames):
            # Convert the RGB numpy frame to a PIL Image for the processor
            image = Image.fromarray(frame).convert("RGB")
            # Prepare the frame for model input
            inputs = processor(images=image, return_tensors="pt")
            # Run inference with gradient calculation disabled
            with torch.no_grad():
                outputs = model(**inputs)
                logits = outputs.logits  # Raw output scores
                probs = torch.nn.functional.softmax(logits, dim=1).squeeze().tolist()  # Class probabilities as a list
            # Map the highest-probability index to its ASL letter (LABELS is defined above)
            predicted_label_index = probs.index(max(probs))
            predicted_label = LABELS[str(predicted_label_index)]
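            # Optional debugging aid (a sketch): inspect the top-3 candidate letters
            # and their probabilities for the current frame with torch.topk.
            # top_probs, top_idx = torch.topk(logits.softmax(dim=1), k=3)
            # print([(LABELS[str(idx)], round(p, 3))
            #        for p, idx in zip(top_probs.squeeze().tolist(), top_idx.squeeze().tolist())])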
| # print(f"Frame {i} prediction: {predicted_label}") # Debug print - Removed for cleaner output | |
| predicted_letters.append(predicted_label) # Append predicted letter to the list of all predictions | |
| # Check for consecutive repetitions and yield only if the rule is met | |
| if predicted_label == last_predicted_label: | |
| consecutive_repetitions += 1 | |
| else: | |
| consecutive_repetitions = 1 # Reset consecutive count if prediction changes | |
| # Yield the prediction if it's not a consecutive repetition beyond the limit or if it's the first prediction | |
| if consecutive_repetitions > MAX_CONSECUTIVE_REPETITIONS or last_predicted_label is None: | |
| yield predicted_label, "" # Yield real-time prediction and empty string for the second output | |
| last_predicted_label = predicted_label # Update the last predicted label | |
| print("Finished frame processing loop.") # Debug print to indicate end of frame processing | |
| # Get unique predicted letters while maintaining order of appearance | |
| unique_predicted_letters = list(dict.fromkeys(predicted_letters)) | |
| final_output = ", ".join(unique_predicted_letters) # Join unique letters into a comma-separated string | |
| # Yield the last predicted label (or empty string if none) and the final list of unique letters | |
| yield last_predicted_label if last_predicted_label is not None else "", final_output | |
    except Exception as e:
        print(f"Error caught: {e}")  # Debug print if an error occurs
        # Flatten the traceback onto one line so it displays cleanly in a Textbox
        flat_traceback = traceback.format_exc().replace("\n", " | ").replace("\r", "")
        yield f"Error processing video: {e}", f"Error processing video: {e} | {flat_traceback}"
# Custom CSS for styling (currently unused)
# custom_css = """
# body {
#     background-color: #add8e6;
# }
# """
# Create the Gradio interface with a video input and two outputs
iface = gr.Interface(
    fn=sign_language_classification,  # Generator function run on each submission
    inputs=gr.Video(sources=["upload", "webcam"]),  # Video input, via upload or webcam
    outputs=[
        gr.Label(label="Real-time Prediction"),  # Streams the per-frame prediction
        gr.Textbox(label="Unique Predicted Letters")  # Final list of unique predicted letters
    ],
    title="ASL Translator",
    description="Upload a video or use your webcam to translate ASL fingerspelling into the 26 alphabet letters. Predictions stream in real time, followed by a summary of the unique letters. ASL Words Translator coming soon!",
    # css=custom_css  # Apply custom CSS (currently unused)
)
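# Note (an assumption about your Gradio version): on Gradio 3.x, generator
# functions only stream when the queue is enabled, while Gradio 4.x queues by
# default. If streaming stalls, enabling the queue explicitly may help:
# iface.queue()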
# Launch the Gradio app (0.0.0.0:7860 is the host/port Hugging Face Spaces expects)
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)