kalpniks committed on
Commit
13b84eb
·
verified ·
1 Parent(s): 4b02806

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +102 -4
  2. requirements.txt +3 -0
app.py CHANGED
@@ -7,6 +7,9 @@ from transformers import AutoImageProcessor, SiglipForImageClassification
7
  from PIL import Image
8
  import torch
9
  import cv2
 
 
 
10
 
11
  os.environ["HF_HOME"] = "/tmp/huggingface"
12
  os.makedirs("/tmp/huggingface", exist_ok=True)
@@ -33,6 +36,64 @@ labels = {
33
  "20": "U", "21": "V", "22": "W", "23": "X", "24": "Y", "25": "Z"
34
  }
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def sign_language_classification_streamlit(video_path):
37
  print("sign_language_classification_streamlit function called.")
38
  predicted_letters = []
@@ -84,7 +145,11 @@ def sign_language_classification_streamlit(video_path):
84
 
85
  st.set_page_config(page_title="ASL Translator", layout="centered")
86
  st.title("ASL Translator")
87
- st.markdown("Upload a video to translate ASL into one of the 26 sign language alphabet categories and see predictions. ASL Words Translator coming soon!")
 
 
 
 
88
 
89
  uploaded_file = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov", "webm"])
90
 
@@ -95,17 +160,50 @@ if uploaded_file is not None:
95
  f.write(uploaded_file.getbuffer())
96
  st.video(video_path)
97
 
98
- if st.button("Translate ASL"):
99
  with st.spinner("Translating video... This might take a while depending on video length."):
100
  realtime_pred, unique_letters = sign_language_classification_streamlit(video_path)
101
  st.success("Translation Complete!")
102
 
103
- st.subheader("Last Predicted Sign (Real-time Equivalent)")
104
  st.write(realtime_pred)
105
 
106
- st.subheader("Unique Predicted Letters")
107
  st.write(unique_letters)
108
 
109
  os.remove(video_path) # Clean up temporary file
110
  else:
111
  st.info("Please upload a video file to start the translation.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from PIL import Image
8
  import torch
9
  import cv2
10
+ import numpy as np # Required for opencv and streamlit-webrtc frame processing
11
+ import av # Required for streamlit-webrtc
12
+ from streamlit_webrtc import webrtc_streamer, VideoProcessorBase, ClientSettings
13
 
14
  os.environ["HF_HOME"] = "/tmp/huggingface"
15
  os.makedirs("/tmp/huggingface", exist_ok=True)
 
36
  "20": "U", "21": "V", "22": "W", "23": "X", "24": "Y", "25": "Z"
37
  }
38
 
39
# Expose the loaded model artefacts through session state so the
# streamlit-webrtc VideoProcessor can look them up when it is constructed.
if "model_obj" not in st.session_state:
    st.session_state["model_obj"] = model
if "processor_obj" not in st.session_state:
    st.session_state["processor_obj"] = processor
if "labels_dict" not in st.session_state:
    st.session_state["labels_dict"] = labels
if "max_consecutive_repetitions_val" not in st.session_state:
    st.session_state["max_consecutive_repetitions_val"] = MAX_CONSECUTIVE_REPETITIONS

# Seed the live-prediction slots consumed by the webcam UI.
if "live_realtime_pred" not in st.session_state:
    st.session_state["live_realtime_pred"] = ""
if "live_unique_letters" not in st.session_state:
    st.session_state["live_unique_letters"] = ""
if "live_predicted_frames_buffer" not in st.session_state:
    st.session_state["live_predicted_frames_buffer"] = []
58
class SignLanguageVideoProcessor(VideoProcessorBase):
    """Classifies each incoming webcam frame as an ASL alphabet letter.

    streamlit-webrtc calls ``recv`` on a worker thread that has no Streamlit
    ScriptRunContext, so ``st.session_state`` is not reliably writable there
    (writes can raise and abort the stream).  Results are therefore published
    on lock-guarded instance attributes (``realtime_pred``,
    ``unique_letters``) that the main script thread can read via
    ``webrtc_ctx.video_processor``; the session-state mirror is kept as a
    best-effort fallback for backward compatibility.
    """

    def __init__(self):
        import threading  # local import: only needed for the result lock

        # Model artefacts are fetched from session state, which is populated
        # by the main script before the streamer is started.
        self.model = st.session_state.model_obj
        self.processor = st.session_state.processor_obj
        self.labels = st.session_state.labels_dict
        self.max_consecutive_repetitions = st.session_state.max_consecutive_repetitions_val
        self.last_predicted_label = None
        self.consecutive_repetitions = 0
        # Thread-safe result mirrors readable from the main script thread.
        self._lock = threading.Lock()
        self.realtime_pred = ""
        self.unique_letters = ""
        self._frame_labels = []

    def recv(self, frame: av.VideoFrame) -> av.VideoFrame:
        """Run one classification pass and return the frame unmodified."""
        img_pil = frame.to_image().convert("RGB")

        inputs = self.processor(images=img_pil, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits

        predicted_label_index = torch.argmax(logits, dim=1).item()
        current_predicted_label = self.labels[str(predicted_label_index)]

        with self._lock:
            self._frame_labels.append(current_predicted_label)

            # Debounce: only re-publish the real-time letter after enough
            # consecutive identical frames (or on the very first frame).
            if current_predicted_label == self.last_predicted_label:
                self.consecutive_repetitions += 1
            else:
                self.consecutive_repetitions = 1

            if self.consecutive_repetitions > self.max_consecutive_repetitions or self.last_predicted_label is None:
                self.realtime_pred = current_predicted_label
                self.last_predicted_label = current_predicted_label

            # Order-preserving de-duplication of everything seen so far.
            self.unique_letters = ", ".join(dict.fromkeys(self._frame_labels))

        # Best-effort mirror into session_state for existing readers; ignore
        # failures when this worker thread has no ScriptRunContext attached.
        try:
            st.session_state.live_predicted_frames_buffer.append(current_predicted_label)
            st.session_state.live_realtime_pred = self.realtime_pred
            st.session_state.live_unique_letters = self.unique_letters
        except Exception:
            pass

        return frame  # original frame returned unchanged (no overlay drawn)
97
  def sign_language_classification_streamlit(video_path):
98
  print("sign_language_classification_streamlit function called.")
99
  predicted_letters = []
 
145
 
146
  st.set_page_config(page_title="ASL Translator", layout="centered")
147
  st.title("ASL Translator")
148
+ st.markdown("Upload a video or use your webcam to translate ASL into one of the 26 sign language alphabet categories and see predictions. ASL Words Translator coming soon!")
149
+
150
+
151
+ # --- Section for Uploaded Video ---
152
+ st.subheader("Translate from Uploaded Video")
153
 
154
  uploaded_file = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov", "webm"])
155
 
 
160
  f.write(uploaded_file.getbuffer())
161
  st.video(video_path)
162
 
163
+ if st.button("Translate ASL (from file)"):
164
  with st.spinner("Translating video... This might take a while depending on video length."):
165
  realtime_pred, unique_letters = sign_language_classification_streamlit(video_path)
166
  st.success("Translation Complete!")
167
 
168
+ st.subheader("Last Predicted Sign (from file)")
169
  st.write(realtime_pred)
170
 
171
+ st.subheader("Unique Predicted Letters (from file)")
172
  st.write(unique_letters)
173
 
174
  os.remove(video_path) # Clean up temporary file
175
  else:
176
  st.info("Please upload a video file to start the translation.")
177
+
178
+ st.markdown("--- # ---
179
+ ")
180
+
181
# --- Section for Live Webcam ---
st.subheader("Live ASL Translation from Webcam")

# Placeholders so the prediction text can be refreshed in place.
live_realtime_placeholder = st.empty()
live_unique_letters_placeholder = st.empty()

webrtc_ctx = webrtc_streamer(
    key="webrtc_asl",
    # NOTE(review): `mode` expects a WebRtcMode enum, not the bare string
    # "sendrecv"; SENDRECV is the default, so the kwarg is simply omitted.
    video_processor_factory=SignLanguageVideoProcessor,
    media_stream_constraints={"video": True, "audio": False},
    async_processing=True,
)

if webrtc_ctx.state.playing:
    # Prefer results published on the processor instance (recv() runs on a
    # worker thread where st.session_state is not reliably writable); fall
    # back to the session-state mirror for compatibility.
    proc = webrtc_ctx.video_processor
    realtime = getattr(proc, "realtime_pred", None) or st.session_state.live_realtime_pred
    unique = getattr(proc, "unique_letters", None) or st.session_state.live_unique_letters
    live_realtime_placeholder.markdown(f"**Real-time Prediction:** {realtime}")
    live_unique_letters_placeholder.markdown(f"**Unique Predicted Letters:** {unique}")
else:
    # Clear stale results once the webcam stops.
    if st.session_state.live_realtime_pred != "" or st.session_state.live_unique_letters != "":
        st.session_state.live_realtime_pred = ""
        st.session_state.live_unique_letters = ""
        st.session_state.live_predicted_frames_buffer = []
    st.info("Click 'Start' to begin live ASL translation from your webcam.")
requirements.txt CHANGED
@@ -4,3 +4,6 @@ opencv-python-headless
4
  transformers
5
  torch
6
  Pillow
 
 
 
 
4
  transformers
5
  torch
6
  Pillow
7
+ streamlit_webrtc
8
+ av
9
+ numpy