MogensR commited on
Commit
b81a3d7
·
verified ·
1 Parent(s): d4c672c

Update pipeline/video_pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline/video_pipeline.py +2 -2
pipeline/video_pipeline.py CHANGED
@@ -142,12 +142,12 @@ def generate_first_frame_mask(video_path, predictor):
142
  frame = cv2.resize(frame, (int(w * scale), int(h * scale)))
143
  with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
144
  predictor.set_image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
145
- masks, _, _ = predictor.predict(
146
  point_coords=np.array([[w//2, h//2]]),
147
  point_labels=np.array([1]),
148
  multimask_output=True
149
  )
150
- return (masks[np.argmax(predictor.get_mask_scores())] * 255).astype(np.uint8)
151
  # --- Temporal Smoothing ---
152
  def smooth_alpha_video(alpha_path, output_path, window_size=5):
153
  """Apply temporal smoothing to alpha masks"""
 
142
  frame = cv2.resize(frame, (int(w * scale), int(h * scale)))
143
  with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
144
  predictor.set_image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
145
+ masks, scores, _ = predictor.predict(
146
  point_coords=np.array([[w//2, h//2]]),
147
  point_labels=np.array([1]),
148
  multimask_output=True
149
  )
150
+ return (masks[np.argmax(scores)] * 255).astype(np.uint8)
151
  # --- Temporal Smoothing ---
152
  def smooth_alpha_video(alpha_path, output_path, window_size=5):
153
  """Apply temporal smoothing to alpha masks"""