# ASL_Translator / app.py
# Uploaded by kalpniks using huggingface_hub (commit c62ea80, 7.74 kB).
# Import necessary libraries
from collections import Counter
import time
import traceback
import gradio as gr
from transformers import AutoImageProcessor
from transformers import SiglipForImageClassification
from transformers.image_utils import load_image
from PIL import Image
import torch
import cv2 # Import cv2 for video frame processing
# A prediction is streamed to the UI only after the same letter has been
# predicted on MORE than this many consecutive frames (the very first frame of
# a video is always streamed). This suppresses short-lived, flickering
# predictions in the real-time output.
MAX_CONSECUTIVE_REPETITIONS = 3

# Hugging Face Hub checkpoint for the ASL alphabet image classifier. The image
# processor and the classification model are both loaded once, at import time,
# and shared by every request.
model_name = "prithivMLmods/Alphabet-Sign-Language-Detection"
processor = AutoImageProcessor.from_pretrained(model_name)
model = SiglipForImageClassification.from_pretrained(model_name)
# ASL alphabet letters indexed by the classifier's output position
# (index 0 -> "A", ..., index 25 -> "Z").
_ASL_LABELS = tuple("ABCDEFGHIJKLMNOPQRSTUVWXYZ")


def _iter_capture_frames(cap):
    """Yield every frame of an opened ``cv2.VideoCapture`` as an RGB array.

    OpenCV decodes frames in BGR channel order while the image processor works
    on RGB images, so each frame is converted before it is yielded. The capture
    is released when the frames are exhausted or the generator is closed.
    """
    try:
        while True:
            ok, frame = cap.read()
            if not ok:  # no more frames (end of stream or read failure)
                break
            yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    finally:
        cap.release()


def _predict_letter(frame):
    """Classify a single RGB frame and return the predicted ASL letter."""
    image = Image.fromarray(frame).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(**inputs).logits
    # Softmax is monotonic, so the arg-max of the raw logits is the most
    # probable class; probabilities themselves are never needed here.
    return _ASL_LABELS[int(logits[0].argmax())]


def sign_language_classification(video):
    """Classify the ASL alphabet letter shown in each frame of *video*.

    This is a generator used as a streaming Gradio handler. Every yielded value
    is a ``(prediction, summary)`` tuple:

    * While frames are processed, ``prediction`` is the current frame's letter
      and ``summary`` is ``""``. The first frame is always reported. After that,
      a letter is reported only once it has been predicted on more than
      ``MAX_CONSECUTIVE_REPETITIONS`` consecutive frames.
    * After the last frame, ``prediction`` is the last frame's letter (``""``
      when there were no frames). ``summary`` is a comma-separated list of
      every distinct letter predicted on any frame, in order of first
      appearance. Unstable (unreported) predictions are included.

    Args:
        video: A video file path (``str``), as supplied by ``gr.Video``, or a
            list of already-decoded frames. Frames from a list are passed to the
            model unchanged (presumably RGB arrays; confirm with the caller).

    Problems are reported as yielded message tuples rather than raised. This
    covers a missing video, an unopenable file, an unsupported input type, and
    any exception during processing (whose traceback, flattened onto one line,
    goes in ``summary``).
    """
    print("sign_language_classification function called.")
    if video is None:
        print("No video provided.")
        yield "No video provided.", ""
        return

    print(f"Video input type: {type(video)}")
    print(f"Video value: {video}")

    predicted_letters = []  # letter for every processed frame, in order
    last_predicted_label = None  # letter predicted for the previous frame
    consecutive_repetitions = 0  # length of the current run of identical letters
    try:
        print("Starting frame processing loop.")
        if isinstance(video, str):
            cap = cv2.VideoCapture(video)
            if not cap.isOpened():
                cap.release()
                yield "Error: Could not open video file.", ""
                return
            # Decode lazily so predictions stream out while the file is still
            # being read, and the whole clip never has to be held in memory.
            frames = _iter_capture_frames(cap)
        elif isinstance(video, list):
            # Already-decoded frames (e.g. webcam input in some Gradio versions).
            frames = video
        else:
            yield "Error: Unsupported video input type.", ""
            return

        for frame in frames:
            predicted_label = _predict_letter(frame)
            predicted_letters.append(predicted_label)

            if predicted_label == last_predicted_label:
                consecutive_repetitions += 1
            else:
                consecutive_repetitions = 1  # a new run of letters starts here
            # Always report the first frame. Afterwards, report a letter only
            # once it has been stable for more than MAX_CONSECUTIVE_REPETITIONS
            # frames.
            if consecutive_repetitions > MAX_CONSECUTIVE_REPETITIONS or last_predicted_label is None:
                yield predicted_label, ""
            last_predicted_label = predicted_label

        print("Finished frame processing loop.")
        # dict.fromkeys de-duplicates while preserving first-appearance order.
        final_output = ", ".join(dict.fromkeys(predicted_letters))
        yield last_predicted_label if last_predicted_label is not None else "", final_output
    except Exception as e:
        print(f"Error caught: {e}")
        # Collapse the multi-line traceback onto a single line for the textbox.
        flat_traceback = traceback.format_exc().replace("\n", " | ").replace("\r", "")
        yield f"Error processing video: {e}", f"Error processing video: {e}\n{flat_traceback}"
# Optional page styling, currently disabled. To enable it, uncomment the
# snippet below and pass css=custom_css to gr.Interface.
# custom_css = """
# body {
#     background-color: #add8e6;
# }
# """

# UI components: a single video input (file upload or webcam) and two outputs
# that receive, in order, the (prediction, summary) values yielded by
# sign_language_classification.
video_input = gr.Video(sources=["upload", "webcam"])
prediction_output = gr.Label(label="Real-time Prediction")
summary_output = gr.Textbox(label="Unique Predicted Letters")

iface = gr.Interface(
    fn=sign_language_classification,
    inputs=video_input,
    outputs=[prediction_output, summary_output],
    title="ASL Translator",
    description="Upload a video or use your webcam to translate ASL into one of the 26 sign language alphabet categories and see predictions in real-time and a summary list. ASL Words Translator coming soon!",
)

if __name__ == "__main__":
    # Listen on all network interfaces at port 7860.
    iface.launch(server_name="0.0.0.0", server_port=7860)