Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- Dockerfile +2 -2
- app.py +113 -175
- requirements.txt +3 -5
Dockerfile
CHANGED
|
@@ -12,5 +12,5 @@ COPY . /app
|
|
| 12 |
|
| 13 |
EXPOSE 7860
|
| 14 |
|
| 15 |
-
# CMD for
|
| 16 |
-
CMD ["
|
|
|
|
| 12 |
|
| 13 |
EXPOSE 7860
|
| 14 |
|
| 15 |
+
# CMD for Gradio application
|
| 16 |
+
CMD ["python", "app.py"]
|
app.py
CHANGED
|
@@ -1,206 +1,144 @@
|
|
| 1 |
-
|
| 2 |
-
import os
|
| 3 |
from collections import Counter
|
| 4 |
import time
|
| 5 |
import traceback
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
| 7 |
from PIL import Image
|
| 8 |
import torch
|
| 9 |
-
import cv2
|
| 10 |
-
import numpy as np
|
| 11 |
-
import av
|
| 12 |
-
from streamlit_webrtc import webrtc_streamer, VideoProcessorBase, WebRtcMode
|
| 13 |
-
# ClientSettings import removed as it was causing issues, using dict directly for rtc_configuration
|
| 14 |
|
| 15 |
-
|
| 16 |
-
os.makedirs("/tmp/huggingface", exist_ok=True)
|
| 17 |
-
|
| 18 |
-
# Load model and processor
|
| 19 |
model_name = "prithivMLmods/Alphabet-Sign-Language-Detection"
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
print(f"INFO: Loading model '{model_name}'...")
|
| 23 |
-
model = SiglipForImageClassification.from_pretrained(model_name)
|
| 24 |
-
processor = AutoImageProcessor.from_pretrained(model_name)
|
| 25 |
-
print("INFO: Model and processor loaded successfully.")
|
| 26 |
-
return model, processor
|
| 27 |
-
|
| 28 |
-
# Call the cached resource loader once and store in global variables
|
| 29 |
-
model, processor = load_model_and_processor()
|
| 30 |
-
|
| 31 |
-
# Define the maximum number of consecutive repetitions allowed for predictions (global constant)
|
| 32 |
-
MAX_CONSECUTIVE_REPETITIONS = 3
|
| 33 |
|
| 34 |
-
# Define
|
| 35 |
-
|
| 36 |
-
"0": "A", "1": "B", "2": "C", "3": "D", "4": "E", "5": "F", "6": "G", "7": "H", "8": "I", "9": "J",
|
| 37 |
-
"10": "K", "11": "L", "12": "M", "13": "N", "14": "O", "15": "P", "16": "Q", "17": "R", "18": "S", "19": "T",
|
| 38 |
-
"20": "U", "21": "V", "22": "W", "23": "X", "24": "Y", "25": "Z"
|
| 39 |
-
}
|
| 40 |
-
|
| 41 |
-
# Initialize session state for live predictions if not already present
|
| 42 |
-
# These are the only session state variables that need to be dynamic
|
| 43 |
-
if 'live_realtime_pred' not in st.session_state:
|
| 44 |
-
st.session_state.live_realtime_pred = ""
|
| 45 |
-
if 'live_unique_letters' not in st.session_state:
|
| 46 |
-
st.session_state.live_unique_letters = ""
|
| 47 |
-
if 'live_predicted_frames_buffer' not in st.session_state:
|
| 48 |
-
st.session_state.live_predicted_frames_buffer = []
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
class SignLanguageVideoProcessor(VideoProcessorBase):
|
| 52 |
-
def __init__(self):
|
| 53 |
-
# Directly use the global variables (which are cached resources or constants)
|
| 54 |
-
self.model = model
|
| 55 |
-
self.processor = processor
|
| 56 |
-
self.labels = labels
|
| 57 |
-
self.max_consecutive_repetitions = MAX_CONSECUTIVE_REPETITIONS
|
| 58 |
-
self.last_predicted_label = None
|
| 59 |
-
self.consecutive_repetitions = 0
|
| 60 |
-
|
| 61 |
-
def recv(self, frame: av.VideoFrame) -> av.VideoFrame:
|
| 62 |
-
img_pil = frame.to_image().convert("RGB")
|
| 63 |
-
|
| 64 |
-
inputs = self.processor(images=img_pil, return_tensors="pt")
|
| 65 |
-
with torch.no_grad():
|
| 66 |
-
outputs = self.model(**inputs)
|
| 67 |
-
logits = outputs.logits
|
| 68 |
-
|
| 69 |
-
predicted_label_index = torch.argmax(logits, dim=1).item()
|
| 70 |
-
current_predicted_label = self.labels[str(predicted_label_index)]
|
| 71 |
-
|
| 72 |
-
# Update the buffer of all predicted letters
|
| 73 |
-
st.session_state.live_predicted_frames_buffer.append(current_predicted_label)
|
| 74 |
-
|
| 75 |
-
# Apply repetition logic for real-time display
|
| 76 |
-
if current_predicted_label == self.last_predicted_label:
|
| 77 |
-
self.consecutive_repetitions += 1
|
| 78 |
-
else:
|
| 79 |
-
self.consecutive_repetitions = 1
|
| 80 |
-
|
| 81 |
-
if self.consecutive_repetitions > self.max_consecutive_repetitions or self.last_predicted_label is None:
|
| 82 |
-
st.session_state.live_realtime_pred = current_predicted_label
|
| 83 |
-
self.last_predicted_label = current_predicted_label
|
| 84 |
|
| 85 |
-
# Update unique letters from the buffer
|
| 86 |
-
unique_preds = list(dict.fromkeys(st.session_state.live_predicted_frames_buffer))
|
| 87 |
-
st.session_state.live_unique_letters = ", ".join(unique_preds)
|
| 88 |
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
-
|
| 92 |
-
print("
|
| 93 |
-
predicted_letters = []
|
| 94 |
-
last_predicted_label = None
|
| 95 |
-
consecutive_repetitions = 0
|
| 96 |
|
| 97 |
-
#
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
local_labels = labels
|
| 101 |
-
local_max_consecutive_repetitions = MAX_CONSECUTIVE_REPETITIONS
|
| 102 |
|
| 103 |
try:
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
consecutive_repetitions = 1
|
| 125 |
-
|
| 126 |
-
if consecutive_repetitions > local_max_consecutive_repetitions or last_predicted_label is None:
|
| 127 |
-
predicted_letters.append(current_predicted_label)
|
| 128 |
-
last_predicted_label = current_predicted_label
|
| 129 |
-
|
| 130 |
-
cap.release()
|
| 131 |
-
unique_predicted_letters = list(dict.fromkeys(predicted_letters))
|
| 132 |
-
final_output_str = ", ".join(unique_predicted_letters)
|
| 133 |
-
realtime_equivalent_prediction = unique_predicted_letters[-1] if unique_predicted_letters else ""
|
| 134 |
|
| 135 |
-
return realtime_equivalent_prediction, final_output_str
|
| 136 |
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
error_msg = f"Error processing video: {e}"
|
| 140 |
-
full_traceback_flat = traceback.format_exc().replace('\n', ' | ').replace('\r', '')
|
| 141 |
-
return error_msg, f"{{error_msg}} (Details: {{full_traceback_flat}})"
|
| 142 |
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
st.markdown("Upload a video or use your webcam to translate ASL into one of the 26 sign language alphabet categories and see predictions. ASL Words Translator coming soon!")
|
| 147 |
|
| 148 |
|
| 149 |
-
#
|
| 150 |
-
|
|
|
|
| 151 |
|
| 152 |
-
uploaded_file = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov", "webm"])
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
|
| 161 |
-
if st.button("Translate ASL (from file)"):
|
| 162 |
-
with st.spinner("Translating video... This might take a while depending on video length."):
|
| 163 |
-
realtime_pred, unique_letters = sign_language_classification_streamlit(video_path)
|
| 164 |
-
st.success("Translation Complete!")
|
| 165 |
|
| 166 |
-
|
| 167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
-
|
| 170 |
-
st.write(unique_letters)
|
| 171 |
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
|
|
|
|
|
|
| 175 |
|
| 176 |
-
#
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
-
# --- Section for Live Webcam ---
|
| 179 |
-
st.subheader("Live ASL Translation from Webcam")
|
| 180 |
|
| 181 |
-
#
|
| 182 |
-
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
)
|
| 195 |
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
live_unique_letters_placeholder.markdown(f"**Unique Predicted Letters:** {st.session_state.live_unique_letters}")
|
| 200 |
-
else:
|
| 201 |
-
# Reset session state when webcam is not playing
|
| 202 |
-
if st.session_state.live_realtime_pred != "" or st.session_state.live_unique_letters != "":
|
| 203 |
-
st.session_state.live_realtime_pred = ""
|
| 204 |
-
st.session_state.live_unique_letters = ""
|
| 205 |
-
st.session_state.live_predicted_frames_buffer = []
|
| 206 |
-
st.info("Click 'Start' to begin live ASL translation from your webcam.")
|
|
|
|
| 1 |
+
# Import necessary libraries
|
|
|
|
| 2 |
from collections import Counter
|
| 3 |
import time
|
| 4 |
import traceback
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from transformers import AutoImageProcessor
|
| 7 |
+
from transformers import SiglipForImageClassification
|
| 8 |
+
from transformers.image_utils import load_image
|
| 9 |
from PIL import Image
|
| 10 |
import torch
|
| 11 |
+
import cv2 # Import cv2 for video frame processing
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
+
# Load model and processor for Alphabet Sign Language Detection
|
|
|
|
|
|
|
|
|
|
| 14 |
model_name = "prithivMLmods/Alphabet-Sign-Language-Detection"
|
| 15 |
+
model = SiglipForImageClassification.from_pretrained(model_name)
|
| 16 |
+
processor = AutoImageProcessor.from_pretrained(model_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
+
# Define the maximum number of consecutive repetitions allowed for predictions
|
| 19 |
+
MAX_CONSECUTIVE_REPETITIONS = 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
+
def sign_language_classification(video):
|
| 23 |
+
"""
|
| 24 |
+
Predicts sign language alphabet category for each frame in a video,
|
| 25 |
+
yields predictions in real-time with repetition handling, and returns a list of unique predicted letters.
|
| 26 |
+
"""
|
| 27 |
+
print("sign_language_classification function called.") # Debug print to indicate function call
|
| 28 |
+
if video is None:
|
| 29 |
+
print("No video provided.") # Debug print if no video input
|
| 30 |
+
yield "No video provided.", "" # Yield empty string for the second output if no video
|
| 31 |
+
return
|
| 32 |
|
| 33 |
+
print(f"Video input type: {type(video)}") # Debug print to show video input type
|
| 34 |
+
print(f"Video value: {video}") # Debug print to show video input value
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
+
predicted_letters = [] # List to store all predicted letters from each frame
|
| 37 |
+
last_predicted_label = None # Initialize variable to store the last predicted label to handle repetitions
|
| 38 |
+
consecutive_repetitions = 0 # Initialize counter for consecutive repetitions of the same prediction
|
|
|
|
|
|
|
| 39 |
|
| 40 |
try:
|
| 41 |
+
print("Starting frame processing loop.") # Debug print to indicate start of frame processing
|
| 42 |
+
frames = []
|
| 43 |
+
if isinstance(video, str):
|
| 44 |
+
# If video is a filepath (e.g., uploaded file), load the video frames using OpenCV
|
| 45 |
+
cap = cv2.VideoCapture(video)
|
| 46 |
+
if not cap.isOpened():
|
| 47 |
+
yield "Error: Could not open video file.", "" # Yield error if video file cannot be opened
|
| 48 |
+
return
|
| 49 |
+
while True:
|
| 50 |
+
ret, frame = cap.read()
|
| 51 |
+
if not ret: # Break the loop if no more frames are returned
|
| 52 |
+
break
|
| 53 |
+
frames.append(frame) # Append the read frame to the frames list
|
| 54 |
+
cap.release() # Release the video capture object
|
| 55 |
+
elif isinstance(video, list):
|
| 56 |
+
# If video is already a list of frames (e.g., from webcam in some Gradio versions)
|
| 57 |
+
frames = video
|
| 58 |
+
else:
|
| 59 |
+
yield "Error: Unsupported video input type.", "" # Yield error for unsupported video input types
|
| 60 |
+
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
|
|
|
| 62 |
|
| 63 |
+
for i, frame in enumerate(frames):
|
| 64 |
+
# print(f"Processing frame {i}") # Debug print - Removed for cleaner output
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
+
# Convert the numpy frame (BGR format from OpenCV) to a PIL Image in RGB format for the model
|
| 67 |
+
image = Image.fromarray(frame).convert("RGB")
|
| 68 |
+
# print(f"Frame {i} converted to PIL Image.") # Debug print - Removed for cleaner output
|
|
|
|
| 69 |
|
| 70 |
|
| 71 |
+
# Process the image frame using the pre-trained processor and model
|
| 72 |
+
inputs = processor(images=image, return_tensors="pt") # Prepare image for model input
|
| 73 |
+
# print(f"Frame {i} processed by processor.") # Debug print - Removed for cleaner output
|
| 74 |
|
|
|
|
| 75 |
|
| 76 |
+
# Perform inference with the model
|
| 77 |
+
with torch.no_grad(): # Disable gradient calculation for inference
|
| 78 |
+
outputs = model(**inputs)
|
| 79 |
+
logits = outputs.logits # Get the raw output scores (logits)
|
| 80 |
+
probs = torch.nn.functional.softmax(logits, dim=1).squeeze().tolist() # Apply softmax to get probabilities and convert to list
|
| 81 |
+
# print(f"Frame {i} processed by model. Logits shape: {logits.shape}") # Debug print - Removed for cleaner output
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
+
# Define the labels mapping model output indices to ASL alphabet letters
|
| 85 |
+
labels = {
|
| 86 |
+
"0": "A", "1": "B", "2": "C", "3": "D", "4": "E", "5": "F", "6": "G", "7": "H", "8": "I", "9": "J",
|
| 87 |
+
"10": "K", "11": "L", "12": "M", "13": "N", "14": "O", "15": "P", "16": "Q", "17": "R", "18": "S", "19": "T",
|
| 88 |
+
"20": "U", "21": "V", "22": "W", "23": "X", "24": "Y", "25": "Z"
|
| 89 |
+
}
|
| 90 |
+
# Get the index of the highest probability and find the corresponding predicted label
|
| 91 |
+
predicted_label_index = probs.index(max(probs))
|
| 92 |
+
predicted_label = labels[str(predicted_label_index)]
|
| 93 |
+
# print(f"Frame {i} prediction: {predicted_label}") # Debug print - Removed for cleaner output
|
| 94 |
|
| 95 |
+
predicted_letters.append(predicted_label) # Append predicted letter to the list of all predictions
|
|
|
|
| 96 |
|
| 97 |
+
# Check for consecutive repetitions and yield only if the rule is met
|
| 98 |
+
if predicted_label == last_predicted_label:
|
| 99 |
+
consecutive_repetitions += 1
|
| 100 |
+
else:
|
| 101 |
+
consecutive_repetitions = 1 # Reset consecutive count if prediction changes
|
| 102 |
|
| 103 |
+
# Yield the prediction if it's not a consecutive repetition beyond the limit or if it's the first prediction
|
| 104 |
+
if consecutive_repetitions > MAX_CONSECUTIVE_REPETITIONS or last_predicted_label is None:
|
| 105 |
+
yield predicted_label, "" # Yield real-time prediction and empty string for the second output
|
| 106 |
+
last_predicted_label = predicted_label # Update the last predicted label
|
| 107 |
|
|
|
|
|
|
|
| 108 |
|
| 109 |
+
print("Finished frame processing loop.") # Debug print to indicate end of frame processing
|
| 110 |
+
# Get unique predicted letters while maintaining order of appearance
|
| 111 |
+
unique_predicted_letters = list(dict.fromkeys(predicted_letters))
|
| 112 |
+
final_output = ", ".join(unique_predicted_letters) # Join unique letters into a comma-separated string
|
| 113 |
+
# Yield the last predicted label (or empty string if none) and the final list of unique letters
|
| 114 |
+
yield last_predicted_label if last_predicted_label is not None else "", final_output
|
| 115 |
|
| 116 |
+
except Exception as e:
|
| 117 |
+
print(f"Error caught: {e}") # Debug print if an error occurs
|
| 118 |
+
# Yield error message and traceback information in case of an exception
|
| 119 |
+
yield f"Error processing video: {e}", f"Error processing video: {e}
|
| 120 |
+
{traceback.format_exc()}"
|
| 121 |
+
|
| 122 |
+
# Custom CSS for styling (commented out)
|
| 123 |
+
# custom_css = """
|
| 124 |
+
# body {
|
| 125 |
+
# background-color: #add8e6;
|
| 126 |
+
# }
|
| 127 |
+
# """
|
| 128 |
+
|
| 129 |
+
# Create Gradio interface with video input and multiple outputs
|
| 130 |
+
iface = gr.Interface(
|
| 131 |
+
fn=sign_language_classification, # The function to run when the user interacts with the interface
|
| 132 |
+
inputs=gr.Video(sources=["upload", "webcam"]), # Input component: Video, allowing upload or webcam
|
| 133 |
+
outputs=[
|
| 134 |
+
gr.Label(label="Real-time Prediction"), # Output component: Label to display the real-time prediction
|
| 135 |
+
gr.Textbox(label="Unique Predicted Letters") # Output component: Textbox to display the final list of unique predicted letters
|
| 136 |
+
],
|
| 137 |
+
title="ASL Translator", # Title of the Gradio interface
|
| 138 |
+
description="Upload a video or use your webcam to translate ASL into one of the 26 sign language alphabet categories and see predictions in real-time and a summary list. ASL Words Translator coming soon!", # Description displayed below the title
|
| 139 |
+
# css=custom_css # Apply custom CSS (commented out)
|
| 140 |
)
|
| 141 |
|
| 142 |
+
# Launch the Gradio app
|
| 143 |
+
if __name__ == "__main__":
|
| 144 |
+
iface.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,10 +1,8 @@
|
|
| 1 |
|
| 2 |
-
|
| 3 |
opencv-python-headless
|
| 4 |
-
transformers
|
| 5 |
torch
|
| 6 |
Pillow
|
| 7 |
-
|
| 8 |
-
av==12.0.0
|
| 9 |
numpy
|
| 10 |
-
# huggingface_hub==0.20.0 # Removed pinning to let transformers choose
|
|
|
|
| 1 |
|
| 2 |
+
gradio
|
| 3 |
opencv-python-headless
|
| 4 |
+
transformers
|
| 5 |
torch
|
| 6 |
Pillow
|
| 7 |
+
av
|
|
|
|
| 8 |
numpy
|
|
|