kalpniks committed on
Commit 00afd62 · verified · 1 Parent(s): 3cc0821

Upload folder using huggingface_hub

Files changed (1)
app.py +22 -18
app.py CHANGED
@@ -4,13 +4,13 @@ from collections import Counter
 import time
 import traceback
 from transformers import SiglipForImageClassification
-from transformers.image_processing_utils import AutoImageProcessor  # Changed import path
+from transformers.image_processing_utils import AutoImageProcessor
 from PIL import Image
 import torch
 import cv2
-import numpy as np  # Required for opencv and streamlit-webrtc frame processing
-import av  # Required for streamlit-webrtc
-from streamlit_webrtc import webrtc_streamer, VideoProcessorBase, WebRtcMode  # Import WebRtcMode
+import numpy as np
+import av
+from streamlit_webrtc import webrtc_streamer, VideoProcessorBase, WebRtcMode
 
 os.environ["HF_HOME"] = "/tmp/huggingface"
 os.makedirs("/tmp/huggingface", exist_ok=True)
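
Note: the two context lines closing this hunk redirect the Hugging Face cache to /tmp, which matters on a Space whose default home directory is not writable. A minimal standalone sketch of the pattern; placing the override before the transformers import (our choice here, not shown in the diff) ensures the cache path is picked up before any download:

    import os

    # Point the Hugging Face cache at a writable location before anything
    # from transformers/huggingface_hub resolves its cache directory.
    os.environ["HF_HOME"] = "/tmp/huggingface"
    os.makedirs("/tmp/huggingface", exist_ok=True)

    from transformers import SiglipForImageClassification  # downloads now land in /tmp/huggingface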
@@ -25,8 +25,6 @@ def load_model_and_processor():
     print("INFO: Model and processor loaded successfully.")
     return model, processor
 
-model, processor = load_model_and_processor()
-
 # Define the maximum number of consecutive repetitions allowed for predictions
 MAX_CONSECUTIVE_REPETITIONS = 3
 
@@ -37,11 +35,13 @@ labels = {
     "20": "U", "21": "V", "22": "W", "23": "X", "24": "Y", "25": "Z"
 }
 
-# Store model and processor in session state for access by VideoProcessor
-st.session_state.model_obj = model
-st.session_state.processor_obj = processor
-st.session_state.labels_dict = labels
-st.session_state.max_consecutive_repetitions_val = MAX_CONSECUTIVE_REPETITIONS
+# Initialize all necessary session state variables using conditional checks
+if 'model_obj' not in st.session_state:
+    st.session_state.model_obj, st.session_state.processor_obj = load_model_and_processor()
+if 'labels_dict' not in st.session_state:
+    st.session_state.labels_dict = labels
+if 'max_consecutive_repetitions_val' not in st.session_state:
+    st.session_state.max_consecutive_repetitions_val = MAX_CONSECUTIVE_REPETITIONS
 
 # Initialize session state for live predictions if not already present
 if 'live_realtime_pred' not in st.session_state:
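
Note: Streamlit re-executes the whole script on every widget interaction, so the old unconditional call at import time reloaded the model on each rerun; guarding the load with an 'if key not in st.session_state' check makes it happen once per session. A minimal sketch of the equivalent load-once pattern using st.cache_resource (the checkpoint ID below is a placeholder; the real one lives in the unchanged part of app.py):

    import streamlit as st
    from transformers import SiglipForImageClassification, AutoImageProcessor

    @st.cache_resource  # runs once per server process, reused across reruns and sessions
    def load_model_and_processor():
        model = SiglipForImageClassification.from_pretrained("your-org/your-checkpoint")
        processor = AutoImageProcessor.from_pretrained("your-org/your-checkpoint")
        return model, processor

    model, processor = load_model_and_processor()  # cheap after the first call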
@@ -66,7 +66,7 @@ class SignLanguageVideoProcessor(VideoProcessorBase):
 
         inputs = self.processor(images=img_pil, return_tensors="pt")
         with torch.no_grad():
-            outputs = model(**inputs)
+            outputs = self.model(**inputs)
         logits = outputs.logits
 
         predicted_label_index = torch.argmax(logits, dim=1).item()
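
Note: streamlit-webrtc invokes the processor's recv() on a worker thread, outside the Streamlit script run, where the module-level `model` global of the old code is not reliably visible; the fix routes the call through a reference the instance itself holds. A sketch of that wiring, under the assumption that the constructor (not shown in this diff) stores the objects on self:

    import av
    import torch
    from PIL import Image
    from streamlit_webrtc import VideoProcessorBase

    class SignLanguageVideoProcessor(VideoProcessorBase):
        def __init__(self, model, processor):
            self.model = model          # held on the instance, not read from globals
            self.processor = processor

        def recv(self, frame: av.VideoFrame) -> av.VideoFrame:
            img_pil = Image.fromarray(frame.to_ndarray(format="rgb24"))
            inputs = self.processor(images=img_pil, return_tensors="pt")
            with torch.no_grad():
                outputs = self.model(**inputs)  # the fixed line, in context
            return frame  # pass the frame through unchanged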
@@ -97,6 +97,12 @@ def sign_language_classification_streamlit(video_path):
     last_predicted_label = None
     consecutive_repetitions = 0
 
+    # Access model, processor, labels, and MAX_CONSECUTIVE_REPETITIONS from session state
+    local_model = st.session_state.model_obj
+    local_processor = st.session_state.processor_obj
+    local_labels = st.session_state.labels_dict
+    local_max_consecutive_repetitions = st.session_state.max_consecutive_repetitions_val
+
     try:
         cap = cv2.VideoCapture(video_path)
         if not cap.isOpened():
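
Note: pulling everything out of st.session_state into locals at the top of the function mirrors the module-level change and keeps the frame loop free of repeated attribute lookups. A hypothetical defensive variant (the st.error/st.stop guard is our addition, not in the diff) that fails fast if the state was never initialized:

    if 'model_obj' not in st.session_state:
        st.error("Model not initialized yet; reload the page.")
        st.stop()  # abort this script run cleanly
    local_model = st.session_state.model_obj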
@@ -107,12 +113,12 @@ def sign_language_classification_streamlit(video_path):
             if not ret:
                 break
             image = Image.fromarray(frame).convert("RGB")
-            inputs = processor(images=image, return_tensors="pt")
+            inputs = local_processor(images=image, return_tensors="pt")
             with torch.no_grad():
-                outputs = model(**inputs)
+                outputs = local_model(**inputs)
             logits = outputs.logits
             predicted_label_index = torch.argmax(logits, dim=1).item()
-            current_predicted_label = labels[str(predicted_label_index)]
+            current_predicted_label = local_labels[str(predicted_label_index)]
 
             # Apply repetition logic
             if current_predicted_label == last_predicted_label:
@@ -120,21 +126,19 @@
             else:
                 consecutive_repetitions = 1
 
-            if consecutive_repetitions > MAX_CONSECUTIVE_REPETITIONS or last_predicted_label is None:
+            if consecutive_repetitions > local_max_consecutive_repetitions or last_predicted_label is None:
                 predicted_letters.append(current_predicted_label)
             last_predicted_label = current_predicted_label
 
         cap.release()
         unique_predicted_letters = list(dict.fromkeys(predicted_letters))
         final_output_str = ", ".join(unique_predicted_letters)
-        # For 'Real-time Prediction' equivalent, let's use the last valid unique prediction or the most frequent
         realtime_equivalent_prediction = unique_predicted_letters[-1] if unique_predicted_letters else ""
 
         return realtime_equivalent_prediction, final_output_str
 
     except Exception as e:
         print(f"Error caught: {e}")
-        # Modify the return to ensure the traceback is flattened into a single line
         error_msg = f"Error processing video: {e}"
         full_traceback_flat = traceback.format_exc().replace('\n', ' | ').replace('\r', '')
         return error_msg, f"{error_msg} (Details: {full_traceback_flat})"
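
Note: the repetition gate records a letter on the first frame and again only once the same letter has persisted for more than MAX_CONSECUTIVE_REPETITIONS consecutive frames, which suppresses frame-to-frame flicker; the dict.fromkeys() pass then removes the duplicates this produces while preserving first-seen order. A standalone sketch of the same gate (function name and example are ours):

    def debounce_predictions(labels_per_frame, max_repetitions=3):
        # Collapse a noisy per-frame label stream: record a label on the very
        # first frame, then only after it has held for > max_repetitions frames.
        recorded, last, streak = [], None, 0
        for label in labels_per_frame:
            streak = streak + 1 if label == last else 1
            if streak > max_repetitions or last is None:
                recorded.append(label)
            last = label
        return list(dict.fromkeys(recorded))  # dedupe, keep first-seen order

    # debounce_predictions(["A", "A", "B", "B", "B", "B", "B"]) -> ["A", "B"]

In the error path, the replace('\n', ' | ') call keeps the full traceback on a single line so it survives display contexts that would collapse a multi-line string.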
 