Upload folder using huggingface_hub

app.py CHANGED
@@ -4,13 +4,13 @@ from collections import Counter
 import time
 import traceback
 from transformers import SiglipForImageClassification
-from transformers.image_processing_utils import AutoImageProcessor
+from transformers.image_processing_utils import AutoImageProcessor
 from PIL import Image
 import torch
 import cv2
-import numpy as np
-import av
-from streamlit_webrtc import webrtc_streamer, VideoProcessorBase, WebRtcMode
+import numpy as np
+import av
+from streamlit_webrtc import webrtc_streamer, VideoProcessorBase, WebRtcMode
 
 os.environ["HF_HOME"] = "/tmp/huggingface"
 os.makedirs("/tmp/huggingface", exist_ok=True)
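Note on the cache setup in this hunk: huggingface_hub typically resolves HF_HOME when the library is first imported, so assigning the variable after the transformers imports (as lines 15-16 do here) may come too late to move the default cache. A minimal sketch of the safer ordering, assuming the same /tmp layout this Space uses:

import os

# Point the Hugging Face cache at a writable location *before* any
# transformers/huggingface_hub import, since the hub library reads
# HF_HOME when it is first imported.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.makedirs("/tmp/huggingface", exist_ok=True)

from transformers import SiglipForImageClassification  # now uses /tmp/huggingface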
@@ -25,8 +25,6 @@ def load_model_and_processor():
     print("INFO: Model and processor loaded successfully.")
     return model, processor
 
-model, processor = load_model_and_processor()
-
 # Define the maximum number of consecutive repetitions allowed for predictions
 MAX_CONSECUTIVE_REPETITIONS = 3
 
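This hunk drops the module-level model, processor = load_model_and_processor() call; loading is deferred to the session-state guard added in the next hunk. A common Streamlit alternative, sketched below (not what this commit does), is st.cache_resource, which loads the model once per process and shares it across all sessions and reruns:

import streamlit as st

@st.cache_resource
def get_model_and_processor():
    # Delegates to the loader defined above; cached after the first call.
    return load_model_and_processor()

model, processor = get_model_and_processor()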
@@ -37,11 +35,13 @@ labels = {
     "20": "U", "21": "V", "22": "W", "23": "X", "24": "Y", "25": "Z"
 }
 
-#
-st.session_state
-st.session_state.processor_obj =
-st.session_state
-st.session_state.
+# Initialize all necessary session state variables using conditional checks
+if 'model_obj' not in st.session_state:
+    st.session_state.model_obj, st.session_state.processor_obj = load_model_and_processor()
+if 'labels_dict' not in st.session_state:
+    st.session_state.labels_dict = labels
+if 'max_consecutive_repetitions_val' not in st.session_state:
+    st.session_state.max_consecutive_repetitions_val = MAX_CONSECUTIVE_REPETITIONS
 
 # Initialize session state for live predictions if not already present
 if 'live_realtime_pred' not in st.session_state:
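The conditional checks matter because Streamlit re-executes the entire script on every user interaction; guarding with 'not in st.session_state' makes the expensive model load run once per browser session instead of on every rerun. A minimal illustration with a hypothetical counter:

import streamlit as st

if "rerun_count" not in st.session_state:
    st.session_state.rerun_count = 0   # executes only on the first run of a session
st.session_state.rerun_count += 1      # executes on every rerun
st.write(f"Script executions this session: {st.session_state.rerun_count}")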
@@ -66,7 +66,7 @@ class SignLanguageVideoProcessor(VideoProcessorBase):
 
         inputs = self.processor(images=img_pil, return_tensors="pt")
         with torch.no_grad():
-            outputs = model(**inputs)
+            outputs = self.model(**inputs)
         logits = outputs.logits
 
         predicted_label_index = torch.argmax(logits, dim=1).item()
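This fix replaces the reference to the now-removed module-level model with the instance attribute, which implies the class binds the model itself, presumably in __init__ (outside this hunk). A sketch of the presumed wiring; note that streamlit_webrtc calls recv from a worker thread, so reading st.session_state inside the processor may need extra care:

import av
import streamlit as st
from PIL import Image
from streamlit_webrtc import VideoProcessorBase

class SignLanguageVideoProcessor(VideoProcessorBase):
    def __init__(self):
        # Bind session-scoped objects once per processor instance (assumed).
        self.model = st.session_state.model_obj
        self.processor = st.session_state.processor_obj

    def recv(self, frame: av.VideoFrame) -> av.VideoFrame:
        img_pil = Image.fromarray(frame.to_ndarray(format="rgb24"))
        # ... inference as shown in the hunk above ...
        return frame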
@@ -97,6 +97,12 @@ def sign_language_classification_streamlit(video_path):
     last_predicted_label = None
     consecutive_repetitions = 0
 
+    # Access model, processor, labels, and MAX_CONSECUTIVE_REPETITIONS from session state
+    local_model = st.session_state.model_obj
+    local_processor = st.session_state.processor_obj
+    local_labels = st.session_state.labels_dict
+    local_max_consecutive_repetitions = st.session_state.max_consecutive_repetitions_val
+
     try:
         cap = cv2.VideoCapture(video_path)
         if not cap.isOpened():
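Binding the session-state entries to locals at the top of the function keeps the per-frame loop free of repeated state lookups, and makes the function fail immediately (with an AttributeError) if the initialization block has not run. A hypothetical defensive variant would surface a clearer message:

# Hypothetical guard, not in the commit: raise a readable error instead of
# an AttributeError when session state was never initialized.
if "model_obj" not in st.session_state:
    raise RuntimeError("Model not initialized; run the app's init block first.")
local_model = st.session_state.model_obj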
@@ -107,12 +113,12 @@ def sign_language_classification_streamlit(video_path):
             if not ret:
                 break
             image = Image.fromarray(frame).convert("RGB")
-            inputs =
+            inputs = local_processor(images=image, return_tensors="pt")
             with torch.no_grad():
-                outputs =
+                outputs = local_model(**inputs)
             logits = outputs.logits
             predicted_label_index = torch.argmax(logits, dim=1).item()
-            current_predicted_label =
+            current_predicted_label = local_labels[str(predicted_label_index)]
 
             # Apply repetition logic
             if current_predicted_label == last_predicted_label:
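One caveat that predates this hunk: cv2.VideoCapture yields frames in BGR channel order, and Image.fromarray(frame).convert("RGB") does not reorder channels (the array is already interpreted as a 3-channel RGB image, so the convert call is a no-op). If the classifier expects RGB input, a sketch of the explicit conversion:

# Sketch: reorder BGR -> RGB explicitly before building the PIL image.
rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
image = Image.fromarray(rgb_frame)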
@@ -120,21 +126,19 @@ def sign_language_classification_streamlit(video_path):
             else:
                 consecutive_repetitions = 1
 
-            if consecutive_repetitions >
+            if consecutive_repetitions > local_max_consecutive_repetitions or last_predicted_label is None:
                 predicted_letters.append(current_predicted_label)
                 last_predicted_label = current_predicted_label
 
         cap.release()
         unique_predicted_letters = list(dict.fromkeys(predicted_letters))
         final_output_str = ", ".join(unique_predicted_letters)
-        # For 'Real-time Prediction' equivalent, let's use the last valid unique prediction or the most frequent
         realtime_equivalent_prediction = unique_predicted_letters[-1] if unique_predicted_letters else ""
 
         return realtime_equivalent_prediction, final_output_str
 
     except Exception as e:
         print(f"Error caught: {e}")
-        # Modify the return to ensure the traceback is flattened into a single line
         error_msg = f"Error processing video: {e}"
         full_traceback_flat = traceback.format_exc().replace('\n', ' | ').replace('\r', '')
         return error_msg, f"{{error_msg}} (Details: {{full_traceback_flat}})"
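Two notes on this hunk. The new append condition records a letter only after it has persisted for more than local_max_consecutive_repetitions consecutive frames (or when it is the very first prediction), which filters single-frame flicker. And the final return still contains doubled braces inside an f-string: {{error_msg}} renders as the literal text {error_msg} rather than interpolating the variable. The likely intended form:

# Single braces interpolate; doubled braces are escaped literals in f-strings.
return error_msg, f"{error_msg} (Details: {full_traceback_flat})"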
|