# Import necessary libraries
import traceback
import gradio as gr
from transformers import AutoImageProcessor, SiglipForImageClassification
from PIL import Image
import torch
import cv2 # Import cv2 for video frame processing
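# Suggested dependencies (package names only; exact versions are an assumption, not pinned here):
#   pip install gradio transformers torch pillow opencv-python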

# Load model and processor for Alphabet Sign Language Detection
model_name = "prithivMLmods/Alphabet-Sign-Language-Detection"
model = SiglipForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)

# Define the maximum number of consecutive repetitions allowed for predictions
MAX_CONSECUTIVE_REPETITIONS = 3


def sign_language_classification(video):
    """
    Predicts sign language alphabet category for each frame in a video,
    yields predictions in real-time with repetition handling, and returns a list of unique predicted letters.
    """
    print("sign_language_classification function called.") # Debug print to indicate function call
    if video is None:
        print("No video provided.") # Debug print if no video input
        yield "No video provided.", "" # Yield empty string for the second output if no video
        return

    print(f"Video input type: {type(video)}") # Debug print to show video input type
    print(f"Video value: {video}") # Debug print to show video input value

    predicted_letters = [] # List to store all predicted letters from each frame
    last_predicted_label = None # Initialize variable to store the last predicted label to handle repetitions
    consecutive_repetitions = 0 # Initialize counter for consecutive repetitions of the same prediction

    try:
        print("Starting frame processing loop.") # Debug print to indicate start of frame processing
        frames = []
        if isinstance(video, str):
            # If video is a filepath (e.g., uploaded file), load the video frames using OpenCV
            cap = cv2.VideoCapture(video)
            if not cap.isOpened():
                yield "Error: Could not open video file.", "" # Yield error if video file cannot be opened
                return
            while True:
                ret, frame = cap.read()
                if not ret: # Break the loop if no more frames are returned
                    break
                # OpenCV decodes frames in BGR order; convert to RGB here so later code can assume RGB
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            cap.release() # Release the video capture object
        elif isinstance(video, list):
            # If video is already a list of frames (e.g., from webcam in some Gradio versions),
            # assume the frames are already RGB numpy arrays
            frames = video
        else:
            yield "Error: Unsupported video input type.", "" # Yield error for unsupported video input types
            return


        for i, frame in enumerate(frames):
            # print(f"Processing frame {i}") # Debug print - Removed for cleaner output

            # Convert the RGB numpy frame to a PIL Image for the model
            image = Image.fromarray(frame).convert("RGB")
            # print(f"Frame {i} converted to PIL Image.") # Debug print - Removed for cleaner output


            # Process the image frame using the pre-trained processor and model
            inputs = processor(images=image, return_tensors="pt") # Prepare image for model input
            # print(f"Frame {i} processed by processor.") # Debug print - Removed for cleaner output


            # Perform inference with the model
            with torch.no_grad(): # Disable gradient calculation for inference
                outputs = model(**inputs)
                logits = outputs.logits # Get the raw output scores (logits)
                probs = torch.nn.functional.softmax(logits, dim=1).squeeze().tolist() # Apply softmax to get probabilities and convert to list
                # print(f"Frame {i} processed by model. Logits shape: {logits.shape}") # Debug print - Removed for cleaner output


            # Define the labels mapping model output indices to ASL alphabet letters
            labels = {
                "0": "A", "1": "B", "2": "C", "3": "D", "4": "E", "5": "F", "6": "G", "7": "H", "8": "I", "9": "J",
                "10": "K", "11": "L", "12": "M", "13": "N", "14": "O", "15": "P", "16": "Q", "17": "R", "18": "S", "19": "T",
                "20": "U", "21": "V", "22": "W", "23": "X", "24": "Y", "25": "Z"
            }
            # Get the index of the highest probability and find the corresponding predicted label
            predicted_label_index = probs.index(max(probs))
            predicted_label = labels[str(predicted_label_index)]
            # print(f"Frame {i} prediction: {predicted_label}") # Debug print - Removed for cleaner output

            predicted_letters.append(predicted_label) # Append predicted letter to the list of all predictions

            # Track how many consecutive frames produced the same prediction
            if predicted_label == last_predicted_label:
                consecutive_repetitions += 1
            else:
                consecutive_repetitions = 1 # Reset the count when the prediction changes

            # Yield the prediction unless it has already repeated beyond the allowed limit,
            # so the real-time output is not flooded with the same letter
            if consecutive_repetitions <= MAX_CONSECUTIVE_REPETITIONS:
                yield predicted_label, "" # Yield real-time prediction and empty string for the second output

            last_predicted_label = predicted_label # Remember this frame's prediction for the next iteration


        print("Finished frame processing loop.") # Debug print to indicate end of frame processing
        # Get unique predicted letters while maintaining order of appearance
        unique_predicted_letters = list(dict.fromkeys(predicted_letters))
        final_output = ", ".join(unique_predicted_letters) # Join unique letters into a comma-separated string
        # Yield the last predicted label (or empty string if none) and the final list of unique letters
        yield last_predicted_label if last_predicted_label is not None else "", final_output

    except Exception as e:
        print(f"Error caught: {e}") # Debug print if an error occurs
        # Yield the error message and traceback information in case of an exception.
        # Flatten the traceback onto a single line so it displays cleanly in the textbox.
        flat_traceback = traceback.format_exc().replace("\r", "").replace("\n", " | ")
        yield f"Error processing video: {e}", f"Error processing video: {e}\n{flat_traceback}"

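# A minimal, commented-out sketch for exercising the generator directly, without the Gradio UI.
# "sample_asl_clip.mp4" is a hypothetical local file used only for illustration.
# for realtime_label, summary in sign_language_classification("sample_asl_clip.mp4"):
#     print("Frame prediction:", realtime_label, "| Summary:", summary)
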
# Custom CSS for styling (commented out)
# custom_css = """
# body {
#   background-color: #add8e6;
# }
# """

# Create Gradio interface with video input and multiple outputs
iface = gr.Interface(
    fn=sign_language_classification, # The function to run when the user interacts with the interface
    inputs=gr.Video(sources=["upload", "webcam"]), # Input component: Video, allowing upload or webcam
    outputs=[
        gr.Label(label="Real-time Prediction"), # Output component: Label to display the real-time prediction
        gr.Textbox(label="Unique Predicted Letters") # Output component: Textbox to display the final list of unique predicted letters
    ],
    title="ASL Translator", # Title of the Gradio interface
    description="Upload a video or use your webcam to classify each frame into one of the 26 ASL alphabet letters, with real-time predictions and a summary of the unique letters detected. ASL Words Translator coming soon!", # Description displayed below the title
    # css=custom_css # Apply custom CSS (commented out)
)
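
# Note: sign_language_classification is a generator, so its outputs stream to the UI as frames are processed.
# On older Gradio versions streaming may require enabling the queue explicitly, e.g. iface.queue();
# recent versions enable the queue by default (this is an assumption about the installed Gradio version).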

# Launch the Gradio app
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)