liamsch committed on
Commit
8b1eedf
·
1 Parent(s): 8455092

speed up video processing by putting frame loading on bg thread

Browse files
Files changed (2) hide show
  1. gradio_demo.py +2 -11
  2. video_demo.py +87 -36
gradio_demo.py CHANGED
@@ -26,6 +26,7 @@ import torch.hub
26
  import torchvision.transforms.functional as TF
27
  from PIL import Image
28
  from torch.utils.data import DataLoader
 
29
 
30
  try:
31
  import spaces
@@ -41,13 +42,7 @@ except ImportError:
41
  from demo import create_rendering_image
42
  from sheap import load_sheap_model
43
  from sheap.tiny_flame import TinyFlame, pose_components_to_rotmats
44
-
45
- try:
46
- import face_alignment
47
- except ImportError:
48
- raise ImportError(
49
- "The 'face_alignment' package is required. Please install it via 'pip install face-alignment'."
50
- )
51
  from sheap.fa_landmark_utils import detect_face_and_crop
52
 
53
  # Global variables for models (load once)
@@ -148,10 +143,6 @@ def process_image(image: np.ndarray) -> Image.Image:
148
  return combined
149
 
150
 
151
- # --- Import video utilities from video_demo.py ---
152
- from video_demo import RenderingThread, VideoFrameDataset, _tensor_to_numpy_image
153
-
154
-
155
  @spaces.GPU
156
  def process_video(video_path: str, progress=gr.Progress()) -> str:
157
  """
 
26
  import torchvision.transforms.functional as TF
27
  from PIL import Image
28
  from torch.utils.data import DataLoader
29
+ import face_alignment
30
 
31
  try:
32
  import spaces
 
42
  from demo import create_rendering_image
43
  from sheap import load_sheap_model
44
  from sheap.tiny_flame import TinyFlame, pose_components_to_rotmats
45
+ from video_demo import RenderingThread, VideoFrameDataset, _tensor_to_numpy_image
 
 
 
 
 
 
46
  from sheap.fa_landmark_utils import detect_face_and_crop
47
 
48
  # Global variables for models (load once)
 
143
  return combined
144
 
145
 
 
 
 
 
146
  @spaces.GPU
147
  def process_video(video_path: str, progress=gr.Progress()) -> str:
148
  """
video_demo.py CHANGED
@@ -106,13 +106,17 @@ class RenderingThread(threading.Thread):
106
 
107
 
108
  class VideoFrameDataset(IterableDataset):
109
- """Iterable dataset for streaming video frames with face detection and cropping."""
 
 
 
110
 
111
  def __init__(
112
  self,
113
  video_path: str,
114
  fa_model: face_alignment.FaceAlignment,
115
  smoothing_alpha: float = 0.3,
 
116
  ):
117
  """
118
  Initialize video frame dataset.
@@ -122,11 +126,13 @@ class VideoFrameDataset(IterableDataset):
122
  fa_model: FaceAlignment model instance for face detection
123
  smoothing_alpha: Smoothing factor for bounding box (0=no smoothing, 1=no change).
124
  Lower values = more smoothing
 
125
  """
126
  super().__init__()
127
  self.video_path = video_path
128
  self.fa_model = fa_model
129
  self.smoothing_alpha = smoothing_alpha
 
130
  self.prev_bbox: Optional[Tuple[int, int, int, int]] = None
131
 
132
  # Get video metadata (don't keep capture open)
@@ -144,9 +150,43 @@ class VideoFrameDataset(IterableDataset):
144
  f"Video info: {self.num_frames} frames, {self.fps:.2f} fps, {self.width}x{self.height}"
145
  )
146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  def __iter__(self):
148
  """
149
  Iterate through video frames sequentially.
 
 
 
150
 
151
  Yields:
152
  Dictionary containing frame_idx, processed image, and bounding box
@@ -154,48 +194,59 @@ class VideoFrameDataset(IterableDataset):
154
  # Reset smoothing state for new iteration
155
  self.prev_bbox = None
156
 
157
- # Open video capture for this iteration
158
- cap = cv2.VideoCapture(self.video_path)
159
- if not cap.isOpened():
160
- raise RuntimeError(f"Could not open video file: {self.video_path}")
161
-
162
- frame_idx = 0
163
- while True:
164
- # Read frame
165
- ret, frame_bgr = cap.read()
166
- if not ret:
167
- break
168
-
169
- # Convert BGR to RGB
170
- frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
171
-
172
- # Convert to torch tensor (C, H, W) with values in [0, 1]
173
- image = torch.from_numpy(frame_rgb).permute(2, 0, 1).float() / 255.0
 
 
 
 
 
 
 
174
 
175
- # Detect face and crop
176
- bbox = detect_face_and_crop(image, self.fa_model, margin=0.9, shift_up=0.5)
177
 
178
- # Apply smoothing using exponential moving average
179
- bbox = self._smooth_bbox(bbox)
180
- x0, y0, x1, y1 = bbox
181
 
182
- cropped = image[:, y0:y1, x0:x1]
 
 
183
 
184
- # Resize to 224x224 for SHEAP model
185
- cropped_resized = TF.resize(cropped, [224, 224], antialias=True)
186
- cropped_for_render = TF.resize(cropped, [512, 512], antialias=True)
187
 
188
- yield {
189
- "frame_idx": frame_idx,
190
- "image": cropped_resized,
191
- "bbox": bbox,
192
- "original_frame": frame_rgb, # Keep original for reference (as numpy array)
193
- "cropped_frame": cropped_for_render, # Cropped region resized to 512x512
194
- }
195
 
196
- frame_idx += 1
 
 
 
 
 
 
197
 
198
- cap.release()
 
 
 
199
 
200
  def _smooth_bbox(self, bbox: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]:
201
  """Apply exponential moving average smoothing to bounding box."""
 
106
 
107
 
108
  class VideoFrameDataset(IterableDataset):
109
+ """Iterable dataset for streaming video frames with face detection and cropping.
110
+
111
+ Uses a background thread for video frame loading while face detection runs in the main thread.
112
+ """
113
 
114
  def __init__(
115
  self,
116
  video_path: str,
117
  fa_model: face_alignment.FaceAlignment,
118
  smoothing_alpha: float = 0.3,
119
+ frame_buffer_size: int = 32,
120
  ):
121
  """
122
  Initialize video frame dataset.
 
126
  fa_model: FaceAlignment model instance for face detection
127
  smoothing_alpha: Smoothing factor for bounding box (0=no smoothing, 1=no change).
128
  Lower values = more smoothing
129
+ frame_buffer_size: Size of the frame buffer queue for the background thread
130
  """
131
  super().__init__()
132
  self.video_path = video_path
133
  self.fa_model = fa_model
134
  self.smoothing_alpha = smoothing_alpha
135
+ self.frame_buffer_size = frame_buffer_size
136
  self.prev_bbox: Optional[Tuple[int, int, int, int]] = None
137
 
138
  # Get video metadata (don't keep capture open)
 
150
  f"Video info: {self.num_frames} frames, {self.fps:.2f} fps, {self.width}x{self.height}"
151
  )
152
 
153
+ def _video_reader_thread(self, frame_queue: Queue, stop_event: threading.Event):
154
+ """Background thread that reads video frames and puts them in a queue.
155
+
156
+ Args:
157
+ frame_queue: Queue to put (frame_idx, frame_rgb) tuples
158
+ stop_event: Event to signal thread to stop
159
+ """
160
+ cap = cv2.VideoCapture(self.video_path)
161
+ if not cap.isOpened():
162
+ frame_queue.put(("error", f"Could not open video file: {self.video_path}"))
163
+ return
164
+
165
+ frame_idx = 0
166
+ try:
167
+ while not stop_event.is_set():
168
+ ret, frame_bgr = cap.read()
169
+ if not ret:
170
+ break
171
+
172
+ # Convert BGR to RGB
173
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
174
+
175
+ # Put frame in queue (blocks if queue is full)
176
+ frame_queue.put((frame_idx, frame_rgb))
177
+ frame_idx += 1
178
+
179
+ finally:
180
+ cap.release()
181
+ # Signal end of video
182
+ frame_queue.put(None)
183
+
184
  def __iter__(self):
185
  """
186
  Iterate through video frames sequentially.
187
+
188
+ Video frame loading happens in a background thread, while face detection
189
+ and processing happen in the main thread.
190
 
191
  Yields:
192
  Dictionary containing frame_idx, processed image, and bounding box
 
194
  # Reset smoothing state for new iteration
195
  self.prev_bbox = None
196
 
197
+ # Create queue and start background thread for video reading
198
+ frame_queue = Queue(maxsize=self.frame_buffer_size)
199
+ stop_event = threading.Event()
200
+ reader_thread = threading.Thread(
201
+ target=self._video_reader_thread,
202
+ args=(frame_queue, stop_event),
203
+ daemon=True
204
+ )
205
+ reader_thread.start()
206
+
207
+ try:
208
+ while True:
209
+ # Get frame from background thread
210
+ item = frame_queue.get()
211
+
212
+ # Check for end of video
213
+ if item is None:
214
+ break
215
+
216
+ # Check for error
217
+ if isinstance(item, tuple) and len(item) == 2 and item[0] == "error":
218
+ raise RuntimeError(item[1])
219
+
220
+ frame_idx, frame_rgb = item
221
 
222
+ # Convert to torch tensor (C, H, W) with values in [0, 1]
223
+ image = torch.from_numpy(frame_rgb).permute(2, 0, 1).float() / 255.0
224
 
225
+ # Detect face and crop (runs in main thread, can use GPU)
226
+ bbox = detect_face_and_crop(image, self.fa_model, margin=0.9, shift_up=0.5)
 
227
 
228
+ # Apply smoothing using exponential moving average
229
+ bbox = self._smooth_bbox(bbox)
230
+ x0, y0, x1, y1 = bbox
231
 
232
+ cropped = image[:, y0:y1, x0:x1]
 
 
233
 
234
+ # Resize to 224x224 for SHEAP model
235
+ cropped_resized = TF.resize(cropped, [224, 224], antialias=True)
236
+ cropped_for_render = TF.resize(cropped, [512, 512], antialias=True)
 
 
 
 
237
 
238
+ yield {
239
+ "frame_idx": frame_idx,
240
+ "image": cropped_resized,
241
+ "bbox": bbox,
242
+ "original_frame": frame_rgb, # Keep original for reference (as numpy array)
243
+ "cropped_frame": cropped_for_render, # Cropped region resized to 512x512
244
+ }
245
 
246
+ finally:
247
+ # Clean up background thread
248
+ stop_event.set()
249
+ reader_thread.join(timeout=1.0)
250
 
251
  def _smooth_bbox(self, bbox: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]:
252
  """Apply exponential moving average smoothing to bounding box."""