liamsch committed on
Commit
e34628e
·
1 Parent(s): 8b1eedf

move ffmpeg command out of GPU decorator

Browse files
Files changed (1) hide show
  1. gradio_demo.py +59 -47
gradio_demo.py CHANGED
@@ -144,60 +144,72 @@ def process_image(image: np.ndarray) -> Image.Image:
144
 
145
 
146
  @spaces.GPU
147
- def process_video(video_path: str, progress=gr.Progress()) -> str:
148
  """
149
- Process a video and return path to the rendered output video using background threads.
 
150
  """
151
  # Initialize models on first use (lazy loading for @spaces.GPU)
152
  initialize_models()
153
 
154
- temp_dir = Path(tempfile.mkdtemp())
155
  render_size = 512
156
- try:
157
- # Prepare dataset and dataloader
158
- dataset = VideoFrameDataset(video_path, fa_model)
159
- dataloader = DataLoader(dataset, batch_size=1, num_workers=0)
160
- fps = dataset.fps
161
- num_frames = len(dataset)
162
- # Prepare rendering thread and queue
163
- render_queue = Queue(maxsize=32)
164
- num_render_workers = 1
165
- rendering_threads = []
166
- for _ in range(num_render_workers):
167
- thread = RenderingThread(render_queue, temp_dir, flame.faces, c2w, render_size)
168
- thread.start()
169
- rendering_threads.append(thread)
170
- progress(0, desc="Processing video frames...")
171
- frame_idx = 0
172
- with torch.no_grad():
173
- for batch in dataloader:
174
- images = batch["image"].to(device)
175
- cropped_frames = batch["cropped_frame"]
176
- # Run inference
177
- predictions = sheap_model(images)
178
- verts = flame(
179
- shape=predictions["shape_from_facenet"],
180
- expression=predictions["expr"],
181
- pose=pose_components_to_rotmats(predictions),
182
- eyelids=predictions["eyelids"],
183
- translation=predictions["cam_trans"],
 
 
 
 
 
 
 
184
  )
185
- verts = verts.cpu()
186
- for i in range(images.shape[0]):
187
- cropped_frame = _tensor_to_numpy_image(cropped_frames[i])
188
- render_queue.put((frame_idx, cropped_frame, verts[i]))
189
- frame_idx += 1
190
- progress(
191
- frame_idx / num_frames, desc=f"Processing frame {frame_idx}/{num_frames}"
192
- )
193
- # Stop rendering threads
194
- for _ in range(num_render_workers):
195
- render_queue.put(None)
196
- for thread in rendering_threads:
197
- thread.join()
198
- if frame_idx == 0:
199
- raise ValueError("No frames were successfully processed!")
200
- # Create output video using ffmpeg
 
 
 
 
 
201
  progress(0.95, desc="Encoding video...")
202
  output_path = temp_dir / "output.mp4"
203
  ffmpeg_cmd = [
 
144
 
145
 
146
  @spaces.GPU
147
+ def process_video_frames(video_path: str, temp_dir: Path, progress=gr.Progress()):
148
  """
149
+ Process video frames with GPU (inference and rendering).
150
+ Returns fps and number of frames processed.
151
  """
152
  # Initialize models on first use (lazy loading for @spaces.GPU)
153
  initialize_models()
154
 
 
155
  render_size = 512
156
+ # Prepare dataset and dataloader
157
+ dataset = VideoFrameDataset(video_path, fa_model)
158
+ dataloader = DataLoader(dataset, batch_size=1, num_workers=0)
159
+ fps = dataset.fps
160
+ num_frames = len(dataset)
161
+ # Prepare rendering thread and queue
162
+ render_queue = Queue(maxsize=32)
163
+ num_render_workers = 1
164
+ rendering_threads = []
165
+ for _ in range(num_render_workers):
166
+ thread = RenderingThread(render_queue, temp_dir, flame.faces, c2w, render_size)
167
+ thread.start()
168
+ rendering_threads.append(thread)
169
+ progress(0, desc="Processing video frames...")
170
+ frame_idx = 0
171
+ with torch.no_grad():
172
+ for batch in dataloader:
173
+ images = batch["image"].to(device)
174
+ cropped_frames = batch["cropped_frame"]
175
+ # Run inference
176
+ predictions = sheap_model(images)
177
+ verts = flame(
178
+ shape=predictions["shape_from_facenet"],
179
+ expression=predictions["expr"],
180
+ pose=pose_components_to_rotmats(predictions),
181
+ eyelids=predictions["eyelids"],
182
+ translation=predictions["cam_trans"],
183
+ )
184
+ verts = verts.cpu()
185
+ for i in range(images.shape[0]):
186
+ cropped_frame = _tensor_to_numpy_image(cropped_frames[i])
187
+ render_queue.put((frame_idx, cropped_frame, verts[i]))
188
+ frame_idx += 1
189
+ progress(
190
+ frame_idx / num_frames, desc=f"Processing frame {frame_idx}/{num_frames}"
191
  )
192
+ # Stop rendering threads
193
+ for _ in range(num_render_workers):
194
+ render_queue.put(None)
195
+ for thread in rendering_threads:
196
+ thread.join()
197
+ if frame_idx == 0:
198
+ raise ValueError("No frames were successfully processed!")
199
+
200
+ return fps, frame_idx
201
+
202
+
203
+ def process_video(video_path: str, progress=gr.Progress()) -> str:
204
+ """
205
+ Process a video and return path to the rendered output video.
206
+ """
207
+ temp_dir = Path(tempfile.mkdtemp())
208
+ try:
209
+ # Process frames with GPU
210
+ fps, num_frames = process_video_frames(video_path, temp_dir, progress)
211
+
212
+ # Create output video using ffmpeg (CPU-only, outside GPU context)
213
  progress(0.95, desc="Encoding video...")
214
  output_path = temp_dir / "output.mp4"
215
  ffmpeg_cmd = [