MogensR committed on
Commit
d1fd07a
·
1 Parent(s): 63f9af1

Update processing/two_stage/two_stage_processor.py

Browse files
processing/two_stage/two_stage_processor.py CHANGED
@@ -14,13 +14,14 @@
14
 
15
  from __future__ import annotations
16
 
17
- import cv2, numpy as np, os, io, gc, pickle, logging, tempfile, traceback, math, threading
18
  from pathlib import Path
19
  from typing import Optional, Dict, Any, Callable, Tuple, List
20
 
21
  from utils.cv_processing import segment_person_hq, refine_mask_hq
22
 
23
- try: # project logger if available
 
24
  from utils.logger import get_logger
25
  logger = get_logger(__name__)
26
  except Exception:
@@ -28,7 +29,7 @@
28
 
29
 
30
  # ---------------------------------------------------------------------------
31
- # ――― Local video-writer helper (unchanged from your previous file) ―――
32
  # ---------------------------------------------------------------------------
33
  def create_video_writer(output_path: str, fps: float, width: int, height: int, prefer_mp4: bool = True):
34
  try:
@@ -57,7 +58,7 @@ def create_video_writer(output_path: str, fps: float, width: int, height: int, p
57
 
58
 
59
  # ---------------------------------------------------------------------------
60
- # ――― NEW: key-colour helpers (fast, no external deps) ―――
61
  # ---------------------------------------------------------------------------
62
  def _bgr_to_hsv_hue_deg(bgr: np.ndarray) -> np.ndarray:
63
  hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
@@ -95,7 +96,7 @@ def _choose_best_key_color(frame_bgr: np.ndarray, mask_uint8: np.ndarray) -> dic
95
  best_name, best_score = None, -1.0
96
  for name, info in _key_candidates_bgr().items():
97
  cand_hue = info["hue"]
98
- score = min(_hue_distance(cand_hue, th) for th in top_hues)
99
  if score > best_score:
100
  best_name, best_score = name, score
101
  return _key_candidates_bgr().get(best_name, _key_candidates_bgr()["green"])
@@ -104,7 +105,7 @@ def _choose_best_key_color(frame_bgr: np.ndarray, mask_uint8: np.ndarray) -> dic
104
 
105
 
106
  # ---------------------------------------------------------------------------
107
- # ――― Chroma presets (same keys, but tolerance now gets overwritten) ―――
108
  # ---------------------------------------------------------------------------
109
  CHROMA_PRESETS: Dict[str, Dict[str, Any]] = {
110
  'standard': {'key_color': [0,255,0], 'tolerance': 38, 'edge_softness': 2, 'spill_suppression': 0.35},
@@ -114,13 +115,14 @@ def _choose_best_key_color(frame_bgr: np.ndarray, mask_uint8: np.ndarray) -> dic
114
 
115
 
116
  # ---------------------------------------------------------------------------
117
- # ――― Two-Stage Processor ―――
118
  # ---------------------------------------------------------------------------
119
  class TwoStageProcessor:
120
  def __init__(self, sam2_predictor=None, matanyone_model=None):
121
  self.sam2 = self._unwrap_sam2(sam2_predictor)
122
  self.matanyone = matanyone_model
123
- self.mask_cache_dir = Path("/tmp/mask_cache"); self.mask_cache_dir.mkdir(parents=True, exist_ok=True)
 
124
  logger.info(f"TwoStageProcessor init – SAM2: {self.sam2 is not None} | MatAnyOne: {self.matanyone is not None}")
125
 
126
  # ---------------------------------------------------------------------
@@ -132,59 +134,70 @@ def stage1_extract_to_greenscreen(
132
  output_path: str,
133
  *,
134
  key_color_mode: str = "auto", # "auto" | "green" | "blue" | "cyan" | "magenta"
135
- progress_callback: Optional[Callable[[float,str],None]] = None,
136
  stop_event: Optional["threading.Event"] = None,
137
  ) -> Tuple[Optional[dict], str]:
138
- def _prog(p,d):
139
- if progress_callback:
140
- try: progress_callback(float(p), str(d)); except Exception: pass
 
 
 
 
141
 
142
  try:
143
  _prog(0.0, "Stage 1: opening video…")
144
  cap = cv2.VideoCapture(video_path)
145
- if not cap.isOpened(): return None, "Could not open input video"
 
146
 
147
  fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
148
  total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
149
  w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
150
  h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
151
 
152
- writer,out_path = create_video_writer(output_path, fps, w, h)
153
- if writer is None:
154
- cap.release(); return None, "Could not create output writer"
 
155
 
156
- key_info : dict | None = None
157
- chosen_bgr = np.array([0,255,0], np.uint8) # default
158
- probe_done = False
159
- masks : List[np.ndarray] = []
160
  frame_idx = 0
161
 
162
- green_bg_template = np.zeros((h,w,3), np.uint8) # we’ll overwrite per-frame
163
 
164
  while True:
165
  if stop_event and stop_event.is_set():
166
- _prog(1.0, "Stage 1: cancelled"); break
 
167
 
168
- ok,frame = cap.read()
169
- if not ok: break
 
170
 
171
  mask = self._get_mask(frame)
172
 
173
- # -------- decide key colour once --------
174
  if not probe_done:
175
  if key_color_mode.lower() == "auto":
176
- key_info = _choose_best_key_color(frame, mask)
177
- chosen_bgr= key_info["bgr"]
178
  else:
179
  cand = _key_candidates_bgr().get(key_color_mode.lower())
180
- chosen_bgr = cand["bgr"] if cand is not None else chosen_bgr
 
181
  probe_done = True
182
  logger.info(f"[TwoStage] Using key colour: {key_color_mode} → {chosen_bgr.tolist()}")
183
 
184
  # optional refine
185
  if self.matanyone and frame_idx % 3 == 0:
186
- try: mask = refine_mask_hq(frame, mask, self.matanyone, fallback_enabled=True)
187
- except Exception as e: logger.warning(f"MatAnyOne refine fail f={frame_idx}: {e}")
 
 
188
 
189
  # composite
190
  green_bg_template[:] = chosen_bgr
@@ -193,18 +206,21 @@ def _prog(p,d):
193
  masks.append(self._to_binary_mask(mask))
194
 
195
  frame_idx += 1
196
- pct = 0.05 + 0.9 * (frame_idx/total) if total else min(0.95, 0.05+frame_idx*0.002)
197
  _prog(pct, f"Stage 1: {frame_idx}/{total or '?'}")
198
 
199
- cap.release(); writer.release()
 
200
 
201
  # save mask cache
202
  try:
203
  cache_file = self.mask_cache_dir / (Path(out_path).stem + "_masks.pkl")
204
- with open(cache_file,"wb") as f: pickle.dump(masks,f)
205
- except Exception as e: logger.warning(f"mask cache save fail: {e}")
 
 
206
 
207
- _prog(1.0,"Stage 1: complete")
208
  return (
209
  {"path": out_path, "frames": frame_idx, "key_bgr": chosen_bgr.tolist()},
210
  f"Green-screen video created ({frame_idx} frames)"
@@ -223,186 +239,20 @@ def stage2_greenscreen_to_final(
223
  background: np.ndarray | str,
224
  output_path: str,
225
  *,
226
- chroma_settings: Optional[Dict[str,Any]] = None,
227
- progress_callback: Optional[Callable[[float,str],None]] = None,
228
  stop_event: Optional["threading.Event"] = None,
229
  ) -> Tuple[Optional[str], str]:
230
- def _prog(p,d):
231
- if progress_callback:
232
- try: progress_callback(float(p),str(d)); except Exception: pass
233
-
234
- try:
235
- _prog(0.0,"Stage 2: opening keyed video…")
236
- cap = cv2.VideoCapture(gs_path)
237
- if not cap.isOpened(): return None,"Could not open keyed video"
238
-
239
- fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
240
- total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
241
- w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
242
- h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
243
-
244
- writer,out_path = create_video_writer(output_path, fps, w, h)
245
- if writer is None: cap.release(); return None,"Could not create output writer"
246
-
247
- # background
248
- if isinstance(background,str):
249
- bg = cv2.imread(background, cv2.IMREAD_COLOR)
250
- if bg is None: cap.release(); writer.release(); return None,"Could not load background"
251
- else: bg = background
252
- bg = cv2.resize(bg,(w,h),interpolation=cv2.INTER_LANCZOS4).astype(np.uint8)
253
 
254
- # settings
255
- settings = dict(CHROMA_PRESETS['standard'])
256
- if chroma_settings: settings.update(chroma_settings)
 
 
 
257
 
258
- # load cached masks if any
259
- cache_file = self.mask_cache_dir / (Path(gs_path).stem + "_masks.pkl")
260
- cached_masks = None
261
- if cache_file.exists():
262
- try: cached_masks = pickle.load(open(cache_file,'rb'))
263
- except Exception as e: logger.warning(f"mask cache load fail: {e}")
264
-
265
- frame_idx=0
266
- while True:
267
- if stop_event and stop_event.is_set(): _prog(1.0,"Stage 2: cancelled"); break
268
- ok,frame = cap.read()
269
- if not ok: break
270
-
271
- seg_mask = None
272
- if cached_masks and frame_idx < len(cached_masks):
273
- seg_mask = cached_masks[frame_idx]
274
- else:
275
- seg_mask = self._segmentation_mask_on_stage2(frame)
276
-
277
- composite = self._chroma_key_advanced(frame, bg, settings, seg_mask)
278
-
279
- writer.write(composite)
280
- frame_idx += 1
281
- pct = 0.05 + 0.9*(frame_idx/total) if total else min(0.95,0.05+frame_idx*0.002)
282
- _prog(pct,f"Stage 2: {frame_idx}/{total or '?'}")
283
-
284
- cap.release(); writer.release()
285
- _prog(1.0,"Stage 2: complete")
286
- return out_path, f"Final video created ({frame_idx} frames)"
287
- except Exception as e:
288
- logger.error(f"Stage 2 error: {e}\n{traceback.format_exc()}")
289
- return None, f"Stage 2 failed: {e}"
290
-
291
- # ---------------------------------------------------------------------
292
- # Full pipeline – now passes chosen key into Stage 2
293
- # ---------------------------------------------------------------------
294
- def process_full_pipeline(
295
- self,
296
- video_path: str,
297
- background: np.ndarray | str,
298
- final_output: str,
299
- *,
300
- key_color_mode: str = "auto",
301
- chroma_settings: Optional[Dict[str,Any]] = None,
302
- progress_callback: Optional[Callable[[float,str],None]] = None,
303
- stop_event: Optional["threading.Event"] = None,
304
- ) -> Tuple[Optional[str], str]:
305
- gs_tmp = tempfile.mktemp(suffix="_gs.mp4")
306
  try:
307
- gs_info,msg1 = self.stage1_extract_to_greenscreen(
308
- video_path, gs_tmp,
309
- key_color_mode=key_color_mode,
310
- progress_callback=progress_callback, stop_event=stop_event
311
- )
312
- if gs_info is None: return None,msg1
313
-
314
- # inject key colour into chroma settings for Stage 2
315
- chosen_key = gs_info.get("key_bgr",[0,255,0])
316
- cs = dict(chroma_settings or CHROMA_PRESETS['standard'])
317
- cs['key_color'] = chosen_key
318
-
319
- result,msg2 = self.stage2_greenscreen_to_final(
320
- gs_info["path"], background, final_output,
321
- chroma_settings=cs, progress_callback=progress_callback, stop_event=stop_event
322
- )
323
- return result,msg2
324
- finally:
325
- try: os.remove(gs_tmp)
326
- except Exception: pass
327
- gc.collect()
328
-
329
- # ---------------------------------------------------------------------
330
- # Internal helpers (mostly unchanged + new hybrid / seg)
331
- # ---------------------------------------------------------------------
332
- def _unwrap_sam2(self,obj):
333
- try:
334
- if obj is None: return None
335
- if all(hasattr(obj,attr) for attr in ("set_image","predict")): return obj
336
- for attr in ("model","predictor"):
337
- inner=getattr(obj,attr,None)
338
- if inner and all(hasattr(inner,a) for a in ("set_image","predict")): return inner
339
- except Exception as e: logger.warning(f"SAM2 unwrap fail: {e}")
340
- return None
341
-
342
- def _get_mask(self,frame:np.ndarray)->np.ndarray:
343
- try: return segment_person_hq(frame,self.sam2,fallback_enabled=True)
344
- except Exception as e:
345
- logger.warning(f"Segmentation fallback: {e}")
346
- h,w=frame.shape[:2]; m=np.zeros((h,w),np.uint8); m[h//6:5*h//6,w//4:3*w//4]=255; return m
347
-
348
- # ---------- stage-1 composite (same as before) ----------
349
- def _apply_greenscreen_hard(self,frame,mask,green_bg):
350
- mask_u8=self._to_binary_mask(mask)
351
- mk=cv2.cvtColor(mask_u8,cv2.COLOR_GRAY2BGR).astype(np.float32)/255.0
352
- out=frame.astype(np.float32)*mk+green_bg.astype(np.float32)*(1.0-mk)
353
- return np.clip(out,0,255).astype(np.uint8)
354
-
355
- @staticmethod
356
- def _to_binary_mask(mask:np.ndarray)->np.ndarray:
357
- if mask.ndim==3: mask=cv2.cvtColor(mask,cv2.COLOR_BGR2GRAY)
358
- if mask.dtype!=np.uint8:
359
- mask=(np.clip(mask,0,1)*255).astype(np.uint8) if mask.max()<=1.0 else np.clip(mask,0,255).astype(np.uint8)
360
- _,binm=cv2.threshold(mask,127,255,cv2.THRESH_BINARY); return binm
361
-
362
- # ---------- segmentation rescue for stage-2 ----------
363
- def _segmentation_mask_on_stage2(self,frame_bgr:np.ndarray)->Optional[np.ndarray]:
364
- try:
365
- if self.sam2 is None: return None
366
- return self._get_mask(frame_bgr)
367
- except Exception: return None
368
-
369
- # ---------- hybrid chroma key ----------
370
- def _chroma_key_advanced(
371
- self,
372
- frame_bgr: np.ndarray,
373
- bg_bgr: np.ndarray,
374
- settings: Dict[str,Any],
375
- seg_mask: Optional[np.ndarray] = None,
376
- )->np.ndarray:
377
- try:
378
- key = np.array(settings.get("key_color",[0,255,0]),dtype=np.float32)
379
- tol = float(settings.get("tolerance",40))
380
- soft = int (settings.get("edge_softness",2))
381
- spill= float(settings.get("spill_suppression",0.3))
382
-
383
- f = frame_bgr.astype(np.float32)
384
- b = bg_bgr.astype(np.float32)
385
-
386
- diff = np.linalg.norm(f-key,axis=2)
387
- alpha = np.clip((diff - tol*0.6) / max(1e-6,tol*0.4), 0.0, 1.0)
388
- if soft>0:
389
- k=soft*2+1; alpha=cv2.GaussianBlur(alpha,(k,k),soft)
390
-
391
- # ---------- segmentation rescue ----------
392
- if seg_mask is not None:
393
- if seg_mask.ndim==3: seg_mask=cv2.cvtColor(seg_mask,cv2.COLOR_BGR2GRAY)
394
- seg = seg_mask.astype(np.float32)/255.0
395
- seg = cv2.GaussianBlur(seg,(5,5),1.0)
396
- alpha=np.clip(np.maximum(alpha,seg*0.85),0.0,1.0)
397
-
398
- # ---------- spill suppression ----------
399
- if spill>0:
400
- zone = 1.0-alpha
401
- g=f[:,:,1]; f[:,:,1]=np.clip(g - g*zone*spill,0,255)
402
-
403
- mask3=np.stack([alpha]*3,axis=2)
404
- out = f*mask3 + b*(1.0-mask3)
405
- return np.clip(out,0,255).astype(np.uint8)
406
- except Exception as e:
407
- logger.error(f"Chroma key error: {e}")
408
- return frame_bgr
 
14
 
15
  from __future__ import annotations
16
 
17
+ import cv2, numpy as np, os, gc, pickle, logging, tempfile, traceback, threading
18
  from pathlib import Path
19
  from typing import Optional, Dict, Any, Callable, Tuple, List
20
 
21
  from utils.cv_processing import segment_person_hq, refine_mask_hq
22
 
23
+ # Project logger if available
24
+ try:
25
  from utils.logger import get_logger
26
  logger = get_logger(__name__)
27
  except Exception:
 
29
 
30
 
31
  # ---------------------------------------------------------------------------
32
+ # Local video-writer helper
33
  # ---------------------------------------------------------------------------
34
  def create_video_writer(output_path: str, fps: float, width: int, height: int, prefer_mp4: bool = True):
35
  try:
 
58
 
59
 
60
  # ---------------------------------------------------------------------------
61
+ # Key-colour helpers (fast, no external deps)
62
  # ---------------------------------------------------------------------------
63
  def _bgr_to_hsv_hue_deg(bgr: np.ndarray) -> np.ndarray:
64
  hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
 
96
  best_name, best_score = None, -1.0
97
  for name, info in _key_candidates_bgr().items():
98
  cand_hue = info["hue"]
99
+ score = min(abs((cand_hue - th + 180) % 360 - 180) for th in top_hues)
100
  if score > best_score:
101
  best_name, best_score = name, score
102
  return _key_candidates_bgr().get(best_name, _key_candidates_bgr()["green"])
 
105
 
106
 
107
  # ---------------------------------------------------------------------------
108
+ # Chroma presets
109
  # ---------------------------------------------------------------------------
110
  CHROMA_PRESETS: Dict[str, Dict[str, Any]] = {
111
  'standard': {'key_color': [0,255,0], 'tolerance': 38, 'edge_softness': 2, 'spill_suppression': 0.35},
 
115
 
116
 
117
  # ---------------------------------------------------------------------------
118
+ # Two-Stage Processor
119
  # ---------------------------------------------------------------------------
120
  class TwoStageProcessor:
121
  def __init__(self, sam2_predictor=None, matanyone_model=None):
122
  self.sam2 = self._unwrap_sam2(sam2_predictor)
123
  self.matanyone = matanyone_model
124
+ self.mask_cache_dir = Path("/tmp/mask_cache")
125
+ self.mask_cache_dir.mkdir(parents=True, exist_ok=True)
126
  logger.info(f"TwoStageProcessor init – SAM2: {self.sam2 is not None} | MatAnyOne: {self.matanyone is not None}")
127
 
128
  # ---------------------------------------------------------------------
 
134
  output_path: str,
135
  *,
136
  key_color_mode: str = "auto", # "auto" | "green" | "blue" | "cyan" | "magenta"
137
+ progress_callback: Optional[Callable[[float, str], None]] = None,
138
  stop_event: Optional["threading.Event"] = None,
139
  ) -> Tuple[Optional[dict], str]:
140
+
141
+ def _prog(p, d):
142
+ if progress_callback:
143
+ try:
144
+ progress_callback(float(p), str(d))
145
+ except Exception:
146
+ pass
147
 
148
  try:
149
  _prog(0.0, "Stage 1: opening video…")
150
  cap = cv2.VideoCapture(video_path)
151
+ if not cap.isOpened():
152
+ return None, "Could not open input video"
153
 
154
  fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
155
  total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
156
  w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
157
  h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
158
 
159
+ writer, out_path = create_video_writer(output_path, fps, w, h)
160
+ if writer is None:
161
+ cap.release()
162
+ return None, "Could not create output writer"
163
 
164
+ key_info: dict | None = None
165
+ chosen_bgr = np.array([0, 255, 0], np.uint8) # default
166
+ probe_done = False
167
+ masks: List[np.ndarray] = []
168
  frame_idx = 0
169
 
170
+ green_bg_template = np.zeros((h, w, 3), np.uint8) # overwritten per-frame
171
 
172
  while True:
173
  if stop_event and stop_event.is_set():
174
+ _prog(1.0, "Stage 1: cancelled")
175
+ break
176
 
177
+ ok, frame = cap.read()
178
+ if not ok:
179
+ break
180
 
181
  mask = self._get_mask(frame)
182
 
183
+ # decide key colour once
184
  if not probe_done:
185
  if key_color_mode.lower() == "auto":
186
+ key_info = _choose_best_key_color(frame, mask)
187
+ chosen_bgr = key_info["bgr"]
188
  else:
189
  cand = _key_candidates_bgr().get(key_color_mode.lower())
190
+ if cand is not None:
191
+ chosen_bgr = cand["bgr"]
192
  probe_done = True
193
  logger.info(f"[TwoStage] Using key colour: {key_color_mode} → {chosen_bgr.tolist()}")
194
 
195
  # optional refine
196
  if self.matanyone and frame_idx % 3 == 0:
197
+ try:
198
+ mask = refine_mask_hq(frame, mask, self.matanyone, fallback_enabled=True)
199
+ except Exception as e:
200
+ logger.warning(f"MatAnyOne refine fail f={frame_idx}: {e}")
201
 
202
  # composite
203
  green_bg_template[:] = chosen_bgr
 
206
  masks.append(self._to_binary_mask(mask))
207
 
208
  frame_idx += 1
209
+ pct = 0.05 + 0.9 * (frame_idx / total) if total else min(0.95, 0.05 + frame_idx * 0.002)
210
  _prog(pct, f"Stage 1: {frame_idx}/{total or '?'}")
211
 
212
+ cap.release()
213
+ writer.release()
214
 
215
  # save mask cache
216
  try:
217
  cache_file = self.mask_cache_dir / (Path(out_path).stem + "_masks.pkl")
218
+ with open(cache_file, "wb") as f:
219
+ pickle.dump(masks, f)
220
+ except Exception as e:
221
+ logger.warning(f"mask cache save fail: {e}")
222
 
223
+ _prog(1.0, "Stage 1: complete")
224
  return (
225
  {"path": out_path, "frames": frame_idx, "key_bgr": chosen_bgr.tolist()},
226
  f"Green-screen video created ({frame_idx} frames)"
 
239
  background: np.ndarray | str,
240
  output_path: str,
241
  *,
242
+ chroma_settings: Optional[Dict[str, Any]] = None,
243
+ progress_callback: Optional[Callable[[float, str], None]] = None,
244
  stop_event: Optional["threading.Event"] = None,
245
  ) -> Tuple[Optional[str], str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
+ def _prog(p, d):
248
+ if progress_callback:
249
+ try:
250
+ progress_callback(float(p), str(d))
251
+ except Exception:
252
+ pass
253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  try:
255
+ _prog(0.0, "Stage 2: opening keyed video…")
256
+ cap = cv2.VideoCapture(gs_path)
257
+ if not cap.isOpened():
258
+ return None, "