kalpniks committed on
Commit 69f9775 · verified · 1 Parent(s): 39068b7

Upload folder using huggingface_hub

Files changed (3):
  1. Dockerfile +2 -2
  2. app.py +113 -175
  3. requirements.txt +3 -5
Dockerfile CHANGED
@@ -12,5 +12,5 @@ COPY . /app

 EXPOSE 7860

-# CMD for Streamlit application
-CMD ["streamlit", "run", "app.py", "--server.port", "7860", "--server.address", "0.0.0.0"]
+# CMD for Gradio application
+CMD ["python", "app.py"]
app.py CHANGED
@@ -1,206 +1,144 @@
- import streamlit as st
- import os
from collections import Counter
import time
import traceback
- from transformers import AutoImageProcessor, SiglipForImageClassification
from PIL import Image
import torch
- import cv2
- import numpy as np
- import av
- from streamlit_webrtc import webrtc_streamer, VideoProcessorBase, WebRtcMode
- # ClientSettings import removed as it was causing issues, using dict directly for rtc_configuration

- os.environ["HF_HOME"] = "/tmp/huggingface"
- os.makedirs("/tmp/huggingface", exist_ok=True)
-
- # Load model and processor
model_name = "prithivMLmods/Alphabet-Sign-Language-Detection"
- @st.cache_resource
- def load_model_and_processor():
-     print(f"INFO: Loading model '{model_name}'...")
-     model = SiglipForImageClassification.from_pretrained(model_name)
-     processor = AutoImageProcessor.from_pretrained(model_name)
-     print("INFO: Model and processor loaded successfully.")
-     return model, processor
-
- # Call the cached resource loader once and store in global variables
- model, processor = load_model_and_processor()
-
- # Define the maximum number of consecutive repetitions allowed for predictions (global constant)
- MAX_CONSECUTIVE_REPETITIONS = 3

- # Define labels (global constant)
- labels = {
-     "0": "A", "1": "B", "2": "C", "3": "D", "4": "E", "5": "F", "6": "G", "7": "H", "8": "I", "9": "J",
-     "10": "K", "11": "L", "12": "M", "13": "N", "14": "O", "15": "P", "16": "Q", "17": "R", "18": "S", "19": "T",
-     "20": "U", "21": "V", "22": "W", "23": "X", "24": "Y", "25": "Z"
- }
-
- # Initialize session state for live predictions if not already present
- # These are the only session state variables that need to be dynamic
- if 'live_realtime_pred' not in st.session_state:
-     st.session_state.live_realtime_pred = ""
- if 'live_unique_letters' not in st.session_state:
-     st.session_state.live_unique_letters = ""
- if 'live_predicted_frames_buffer' not in st.session_state:
-     st.session_state.live_predicted_frames_buffer = []
-
-
- class SignLanguageVideoProcessor(VideoProcessorBase):
-     def __init__(self):
-         # Directly use the global variables (which are cached resources or constants)
-         self.model = model
-         self.processor = processor
-         self.labels = labels
-         self.max_consecutive_repetitions = MAX_CONSECUTIVE_REPETITIONS
-         self.last_predicted_label = None
-         self.consecutive_repetitions = 0
-
-     def recv(self, frame: av.VideoFrame) -> av.VideoFrame:
-         img_pil = frame.to_image().convert("RGB")
-
-         inputs = self.processor(images=img_pil, return_tensors="pt")
-         with torch.no_grad():
-             outputs = self.model(**inputs)
-             logits = outputs.logits
-
-         predicted_label_index = torch.argmax(logits, dim=1).item()
-         current_predicted_label = self.labels[str(predicted_label_index)]
-
-         # Update the buffer of all predicted letters
-         st.session_state.live_predicted_frames_buffer.append(current_predicted_label)
-
-         # Apply repetition logic for real-time display
-         if current_predicted_label == self.last_predicted_label:
-             self.consecutive_repetitions += 1
-         else:
-             self.consecutive_repetitions = 1
-
-         if self.consecutive_repetitions > self.max_consecutive_repetitions or self.last_predicted_label is None:
-             st.session_state.live_realtime_pred = current_predicted_label
-             self.last_predicted_label = current_predicted_label

-         # Update unique letters from the buffer
-         unique_preds = list(dict.fromkeys(st.session_state.live_predicted_frames_buffer))
-         st.session_state.live_unique_letters = ", ".join(unique_preds)

-         return frame # Return original frame (or modified frame if drawing text)

- def sign_language_classification_streamlit(video_path):
-     print("sign_language_classification_streamlit function called.")
-     predicted_letters = []
-     last_predicted_label = None
-     consecutive_repetitions = 0

-     # Access model, processor, labels, and MAX_CONSECUTIVE_REPETITIONS from global scope
-     local_model = model
-     local_processor = processor
-     local_labels = labels
-     local_max_consecutive_repetitions = MAX_CONSECUTIVE_REPETITIONS

    try:
-         cap = cv2.VideoCapture(video_path)
-         if not cap.isOpened():
-             return "Error: Could not open video file.", ""
-
-         while True:
-             ret, frame = cap.read()
-             if not ret:
-                 break
-             image = Image.fromarray(frame).convert("RGB")
-             inputs = local_processor(images=image, return_tensors="pt")
-             with torch.no_grad():
-                 outputs = local_model(**inputs)
-                 logits = outputs.logits
-             predicted_label_index = torch.argmax(logits, dim=1).item()
-             current_predicted_label = local_labels[str(predicted_label_index)]
-
-             # Apply repetition logic
-             if current_predicted_label == last_predicted_label:
-                 consecutive_repetitions += 1
-             else:
-                 consecutive_repetitions = 1
-
-             if consecutive_repetitions > local_max_consecutive_repetitions or last_predicted_label is None:
-                 predicted_letters.append(current_predicted_label)
-                 last_predicted_label = current_predicted_label
-
-         cap.release()
-         unique_predicted_letters = list(dict.fromkeys(predicted_letters))
-         final_output_str = ", ".join(unique_predicted_letters)
-         realtime_equivalent_prediction = unique_predicted_letters[-1] if unique_predicted_letters else ""

-         return realtime_equivalent_prediction, final_output_str

-     except Exception as e:
-         print(f"Error caught: {e}")
-         error_msg = f"Error processing video: {e}"
-         full_traceback_flat = traceback.format_exc().replace('\n', ' | ').replace('\r', '')
-         return error_msg, f"{{error_msg}} (Details: {{full_traceback_flat}})"

-
- st.set_page_config(page_title="ASL Translator", layout="centered")
- st.title("ASL Translator")
- st.markdown("Upload a video or use your webcam to translate ASL into one of the 26 sign language alphabet categories and see predictions. ASL Words Translator coming soon!")


- # --- Section for Uploaded Video ---
- st.subheader("Translate from Uploaded Video")

- uploaded_file = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov", "webm"])

- if uploaded_file is not None:
-     # Save the uploaded file temporarily
-     video_path = os.path.join("/tmp", uploaded_file.name)
-     with open(video_path, "wb") as f:
-         f.write(uploaded_file.getbuffer())
-     st.video(video_path)

-     if st.button("Translate ASL (from file)"):
-         with st.spinner("Translating video... This might take a while depending on video length."):
-             realtime_pred, unique_letters = sign_language_classification_streamlit(video_path)
-         st.success("Translation Complete!")

-         st.subheader("Last Predicted Sign (from file)")
-         st.write(realtime_pred)

-         st.subheader("Unique Predicted Letters (from file)")
-         st.write(unique_letters)

-         os.remove(video_path) # Clean up temporary file
- else:
-     st.info("Please upload a video file to start the translation.")

- # The line st.markdown("--- # ---") was causing a SyntaxError, temporarily removed for testing.

- # --- Section for Live Webcam ---
- st.subheader("Live ASL Translation from Webcam")

- # Placeholders for live updates
- live_realtime_placeholder = st.empty()
- live_unique_letters_placeholder = st.empty()

- webrtc_ctx = webrtc_streamer(
-     key="webrtc_asl",
-     mode=WebRtcMode.SENDRECV,
-     rtc_configuration={
-         "iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]
-     }, # Directly providing the dict
-     video_processor_factory=SignLanguageVideoProcessor,
-     media_stream_constraints={"video": True, "audio": False},
-     async_processing=True,
)

- if webrtc_ctx.state.playing:
-     # Update placeholders based on session state. These will update on each rerun triggered by session_state changes.
-     live_realtime_placeholder.markdown(f"**Real-time Prediction:** {st.session_state.live_realtime_pred}")
-     live_unique_letters_placeholder.markdown(f"**Unique Predicted Letters:** {st.session_state.live_unique_letters}")
- else:
-     # Reset session state when webcam is not playing
-     if st.session_state.live_realtime_pred != "" or st.session_state.live_unique_letters != "":
-         st.session_state.live_realtime_pred = ""
-         st.session_state.live_unique_letters = ""
-         st.session_state.live_predicted_frames_buffer = []
-     st.info("Click 'Start' to begin live ASL translation from your webcam.")
+ # Import necessary libraries
from collections import Counter
import time
import traceback
+ import gradio as gr
+ from transformers import AutoImageProcessor
+ from transformers import SiglipForImageClassification
+ from transformers.image_utils import load_image
from PIL import Image
import torch
+ import cv2 # Import cv2 for video frame processing

+ # Load model and processor for Alphabet Sign Language Detection
model_name = "prithivMLmods/Alphabet-Sign-Language-Detection"
+ model = SiglipForImageClassification.from_pretrained(model_name)
+ processor = AutoImageProcessor.from_pretrained(model_name)

+ # Define the maximum number of consecutive repetitions allowed for predictions
+ MAX_CONSECUTIVE_REPETITIONS = 3


+ def sign_language_classification(video):
+     """
+     Predicts sign language alphabet category for each frame in a video,
+     yields predictions in real-time with repetition handling, and returns a list of unique predicted letters.
+     """
+     print("sign_language_classification function called.") # Debug print to indicate function call
+     if video is None:
+         print("No video provided.") # Debug print if no video input
+         yield "No video provided.", "" # Yield empty string for the second output if no video
+         return

+     print(f"Video input type: {type(video)}") # Debug print to show video input type
+     print(f"Video value: {video}") # Debug print to show video input value

+     predicted_letters = [] # List to store all predicted letters from each frame
+     last_predicted_label = None # Initialize variable to store the last predicted label to handle repetitions
+     consecutive_repetitions = 0 # Initialize counter for consecutive repetitions of the same prediction

    try:
+         print("Starting frame processing loop.") # Debug print to indicate start of frame processing
+         frames = []
+         if isinstance(video, str):
+             # If video is a filepath (e.g., uploaded file), load the video frames using OpenCV
+             cap = cv2.VideoCapture(video)
+             if not cap.isOpened():
+                 yield "Error: Could not open video file.", "" # Yield error if video file cannot be opened
+                 return
+             while True:
+                 ret, frame = cap.read()
+                 if not ret: # Break the loop if no more frames are returned
+                     break
+                 frames.append(frame) # Append the read frame to the frames list
+             cap.release() # Release the video capture object
+         elif isinstance(video, list):
+             # If video is already a list of frames (e.g., from webcam in some Gradio versions)
+             frames = video
+         else:
+             yield "Error: Unsupported video input type.", "" # Yield error for unsupported video input types
+             return

+         for i, frame in enumerate(frames):
+             # print(f"Processing frame {i}") # Debug print - removed for cleaner output

+             # Convert the numpy frame (BGR format from OpenCV) to a PIL Image in RGB format for the model
+             image = Image.fromarray(frame).convert("RGB")
+             # print(f"Frame {i} converted to PIL Image.") # Debug print - removed for cleaner output

+             # Process the image frame using the pre-trained processor and model
+             inputs = processor(images=image, return_tensors="pt") # Prepare image for model input
+             # print(f"Frame {i} processed by processor.") # Debug print - removed for cleaner output

+             # Perform inference with the model
+             with torch.no_grad(): # Disable gradient calculation for inference
+                 outputs = model(**inputs)
+                 logits = outputs.logits # Get the raw output scores (logits)
+                 probs = torch.nn.functional.softmax(logits, dim=1).squeeze().tolist() # Apply softmax to get probabilities and convert to list
+             # print(f"Frame {i} processed by model. Logits shape: {logits.shape}") # Debug print - removed for cleaner output

+             # Define the labels mapping model output indices to ASL alphabet letters
+             labels = {
+                 "0": "A", "1": "B", "2": "C", "3": "D", "4": "E", "5": "F", "6": "G", "7": "H", "8": "I", "9": "J",
+                 "10": "K", "11": "L", "12": "M", "13": "N", "14": "O", "15": "P", "16": "Q", "17": "R", "18": "S", "19": "T",
+                 "20": "U", "21": "V", "22": "W", "23": "X", "24": "Y", "25": "Z"
+             }
+             # Get the index of the highest probability and find the corresponding predicted label
+             predicted_label_index = probs.index(max(probs))
+             predicted_label = labels[str(predicted_label_index)]
+             # print(f"Frame {i} prediction: {predicted_label}") # Debug print - removed for cleaner output

+             predicted_letters.append(predicted_label) # Append predicted letter to the list of all predictions

+             # Check for consecutive repetitions and yield only if the rule is met
+             if predicted_label == last_predicted_label:
+                 consecutive_repetitions += 1
+             else:
+                 consecutive_repetitions = 1 # Reset consecutive count if prediction changes

+             # Yield the prediction if it's not a consecutive repetition beyond the limit or if it's the first prediction
+             if consecutive_repetitions > MAX_CONSECUTIVE_REPETITIONS or last_predicted_label is None:
+                 yield predicted_label, "" # Yield real-time prediction and empty string for the second output
+                 last_predicted_label = predicted_label # Update the last predicted label

+         print("Finished frame processing loop.") # Debug print to indicate end of frame processing
+         # Get unique predicted letters while maintaining order of appearance
+         unique_predicted_letters = list(dict.fromkeys(predicted_letters))
+         final_output = ", ".join(unique_predicted_letters) # Join unique letters into a comma-separated string
+         # Yield the last predicted label (or empty string if none) and the final list of unique letters
+         yield last_predicted_label if last_predicted_label is not None else "", final_output

+     except Exception as e:
+         print(f"Error caught: {e}") # Debug print if an error occurs
+         # Yield error message and traceback information in case of an exception
+         yield f"Error processing video: {e}", f"Error processing video: {e}\n{traceback.format_exc()}"
+
+ # Custom CSS for styling (commented out)
+ # custom_css = """
+ # body {
+ #     background-color: #add8e6;
+ # }
+ # """
+
+ # Create Gradio interface with video input and multiple outputs
+ iface = gr.Interface(
+     fn=sign_language_classification, # The function to run when the user interacts with the interface
+     inputs=gr.Video(sources=["upload", "webcam"]), # Input component: Video, allowing upload or webcam
+     outputs=[
+         gr.Label(label="Real-time Prediction"), # Output component: Label to display the real-time prediction
+         gr.Textbox(label="Unique Predicted Letters") # Output component: Textbox to display the final list of unique predicted letters
+     ],
+     title="ASL Translator", # Title of the Gradio interface
+     description="Upload a video or use your webcam to translate ASL into one of the 26 sign language alphabet categories and see predictions in real-time and a summary list. ASL Words Translator coming soon!", # Description displayed below the title
+     # css=custom_css # Apply custom CSS (commented out)
)

+ # Launch the Gradio app
+ if __name__ == "__main__":
+     iface.launch(server_name="0.0.0.0", server_port=7860)
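A note on the repetition gate in sign_language_classification: last_predicted_label is updated only inside the yield branch, so once a letter has been emitted, every frame showing a different letter resets consecutive_repetitions to 1 and the new letter can never reach the yield threshold. If the intended behavior is to emit each letter once it has been stable for a few frames, a debounce along the following lines might be closer to that intent (a minimal sketch; debounce_predictions is a hypothetical helper, not part of this commit):

    from typing import Iterable, Iterator

    MAX_CONSECUTIVE_REPETITIONS = 3  # frames a letter must persist before it is emitted

    def debounce_predictions(frame_labels: Iterable[str]) -> Iterator[str]:
        """Yield each letter once it has been stable for MAX_CONSECUTIVE_REPETITIONS frames."""
        candidate = None      # label currently accumulating a streak
        streak = 0            # consecutive frames that produced `candidate`
        last_emitted = None   # last label actually yielded
        for label in frame_labels:
            if label == candidate:
                streak += 1
            else:
                candidate, streak = label, 1
            if streak >= MAX_CONSECUTIVE_REPETITIONS and candidate != last_emitted:
                last_emitted = candidate
                yield candidate

    # Noisy per-frame predictions collapse to the stable letters:
    print(list(debounce_predictions("AAAABABBBBCC")))  # -> ['A', 'B']

A related detail worth checking in the committed loop: OpenCV returns frames in BGR channel order, and Image.fromarray(frame).convert("RGB") does not swap channels, so the model may be seeing BGR data; cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) before Image.fromarray would make the conversion explicit.
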
requirements.txt CHANGED
@@ -1,10 +1,8 @@
 
-streamlit
+gradio
 opencv-python-headless
-transformers==4.40.1
+transformers
 torch
 Pillow
-streamlit_webrtc
-av==12.0.0
+av
 numpy
-# huggingface_hub==0.20.0 # Removed pinning to let transformers choose
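
Every dependency is now unpinned (transformers was previously pinned to 4.40.1), so resolved versions can drift between Space rebuilds; note also that av remains listed even though the new app.py no longer imports it directly. A quick smoke test after pip install -r requirements.txt can confirm the resolved stack imports cleanly (a hypothetical check script, not part of this commit):

    # smoke_test.py - hypothetical post-install check, not part of this commit
    import cv2
    import numpy
    import torch
    import gradio as gr
    from transformers import AutoImageProcessor, SiglipForImageClassification

    # Print the versions the unpinned requirements actually resolved to
    for name, module in [("gradio", gr), ("torch", torch), ("numpy", numpy), ("opencv", cv2)]:
        print(name, module.__version__)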