MogensR committed on
Commit
dd52427
·
verified ·
1 Parent(s): ec9ba45

Delete pipeline/integrated_pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline/integrated_pipeline.py +0 -586
pipeline/integrated_pipeline.py DELETED
@@ -1,586 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- integrated_pipeline.py - Two-stage pipeline with proper GPU management and error handling
4
- """
5
- import os
6
- import sys
7
- import gc
8
- import json
9
- import subprocess
10
- import tempfile
11
- from pathlib import Path
12
- from typing import Dict, Any, Optional, Tuple
13
- import numpy as np
14
- import cv2
15
- import torch
16
- import logging
17
- import shutil
18
- import traceback
19
- from concurrent.futures import ThreadPoolExecutor, TimeoutError
20
-
21
- # --- Project Setup ---
22
- current_dir = Path(__file__).parent
23
- parent_dir = current_dir.parent
24
- sys.path.append(str(parent_dir))
25
-
26
- # --- Logging Configuration ---
27
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
28
- logger = logging.getLogger(__name__)
29
-
30
# --- TwoStageProcessor Class ---
class TwoStageProcessor:
    """Two-stage video matting pipeline: SAM2 mask generation (stage 1),
    then optional MatAnyone refinement and background compositing (stage 2)."""

    def __init__(self, temp_dir: Optional[str] = None):
        # Working directory for intermediate artifacts; a fresh system temp
        # dir is allocated when the caller does not supply one.
        if temp_dir:
            self.temp_dir = Path(temp_dir)
        else:
            self.temp_dir = Path(tempfile.mkdtemp())
        self.temp_dir.mkdir(exist_ok=True)
        logger.info(f"Initialized temp_dir: {self.temp_dir}")

        # Stage 1 outputs consumed by stage 2.
        self.masks_path = self.temp_dir / "masks.mkv"      # lossless FFV1 mask stream
        self.metadata_path = self.temp_dir / "meta.json"   # fps/size/frame-count info

        # Pin CUDA device 0 as the default when a GPU is present.
        if torch.cuda.is_available():
            torch.cuda.set_device(0)
            logger.info(f"GPU set as default device: {torch.cuda.get_device_name(0)}")
        else:
            logger.warning("CUDA not available, using CPU")
    def process_video(self, input_video: str, background_video: str,
                      click_points: list, output_path: str,
                      use_matanyone: bool = True, progress_callback=None) -> bool:
        """Run the full two-stage pipeline and write the composited video.

        Args:
            input_video: Path to the foreground source video.
            background_video: Path to a replacement background video or image
                (stage 2 falls back to a green screen when it is missing).
            click_points: Seed points for SAM2; presumably normalized (x, y)
                in [0, 1] since stage 1 multiplies them by width/height.
            output_path: Destination path for the final composite.
            use_matanyone: When True, stage 2 refines SAM2 masks with MatAnyone.
            progress_callback: Optional callable(progress: float, message: str).

        Returns:
            True on success, False on any failure (errors are logged, never raised).
        """
        try:
            logger.info("="*60)
            logger.info("STARTING TWO-STAGE PROCESSING")
            logger.info("="*60)

            # Refuse to start without headroom for the intermediate mask stream.
            free_gb = shutil.disk_usage(self.temp_dir).free/1e9
            logger.info(f"Disk free: {free_gb:.2f}GB")

            if free_gb < 5.0:
                logger.error("Insufficient disk space (need at least 5GB)")
                return False

            # Stage 1: Generate masks (progress range 0.05-0.45).
            logger.info("="*60)
            logger.info("STAGE 1: MASK GENERATION")
            logger.info("="*60)

            if progress_callback:
                progress_callback(0.05, "Stage 1: Starting SAM2...")

            stage1_success = self._stage1_generate_masks(input_video, click_points, progress_callback)

            if not stage1_success:
                logger.error("STAGE 1 FAILED - Aborting")
                return False

            logger.info("STAGE 1 COMPLETED SUCCESSFULLY")

            # Verify stage 1 outputs exist before committing to stage 2.
            if not self.masks_path.exists():
                logger.error(f"Masks file not found: {self.masks_path}")
                return False

            if not self.metadata_path.exists():
                logger.error(f"Metadata file not found: {self.metadata_path}")
                return False

            masks_size = self.masks_path.stat().st_size / 1e6
            logger.info(f"Masks file size: {masks_size:.2f}MB")

            # Force GPU cleanup so stage 2 starts with maximum free VRAM.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                logger.info(f"GPU memory after Stage 1: {torch.cuda.memory_allocated()/1e9:.2f}GB")

            gc.collect()

            # Stage 2: Process and composite (progress range 0.5-1.0).
            logger.info("="*60)
            logger.info("STAGE 2: MATTING & COMPOSITING")
            logger.info("="*60)

            if progress_callback:
                progress_callback(0.5, "Stage 2: Starting MatAnyone...")

            stage2_success = self._stage2_composite(
                input_video, background_video,
                output_path, use_matanyone, progress_callback
            )

            if not stage2_success:
                logger.error("STAGE 2 FAILED")
                return False

            logger.info("STAGE 2 COMPLETED SUCCESSFULLY")
            logger.info("="*60)
            logger.info("TWO-STAGE PROCESSING COMPLETE")
            logger.info("="*60)

            return True

        except Exception as e:
            # Top-level boundary: convert any exception to a logged False.
            logger.error(f"Two-stage processing exception: {str(e)}")
            logger.error(traceback.format_exc())
            return False
- def _stage1_generate_masks(self, input_video: str, click_points: list,
131
- progress_callback=None) -> bool:
132
- """Stage 1: SAM2 mask generation"""
133
- predictor = None
134
- inference_state = None
135
- ffmpeg_process = None
136
-
137
- try:
138
- logger.info("Loading SAM2...")
139
-
140
- # Use the SAM2Predictor wrapper
141
- from models.sam2_loader import SAM2Predictor
142
-
143
- # Force GPU device
144
- if torch.cuda.is_available():
145
- device = torch.device("cuda:0")
146
- torch.cuda.set_device(0)
147
- else:
148
- device = torch.device("cpu")
149
-
150
- predictor = SAM2Predictor(device=device, model_size="large")
151
-
152
- if torch.cuda.is_available():
153
- logger.info(f"SAM2 loaded on GPU: {torch.cuda.get_device_name(0)}")
154
- logger.info(f"GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f}GB")
155
- else:
156
- logger.info("SAM2 loaded on CPU")
157
-
158
- # Get video info
159
- cap = cv2.VideoCapture(input_video)
160
- if not cap.isOpened():
161
- raise RuntimeError(f"Cannot open video: {input_video}")
162
-
163
- fps = cap.get(cv2.CAP_PROP_FPS)
164
- frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
165
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
166
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
167
- cap.release()
168
-
169
- logger.info(f"Video: {width}x{height}, {frame_count} frames @ {fps:.2f}fps")
170
-
171
- # Save metadata
172
- metadata = {
173
- "fps": fps,
174
- "frame_count": frame_count,
175
- "width": width,
176
- "height": height,
177
- "click_points": click_points
178
- }
179
- with open(self.metadata_path, 'w') as f:
180
- json.dump(metadata, f, indent=2)
181
- logger.info(f"Metadata saved")
182
-
183
- # Initialize inference
184
- logger.info("Initializing SAM2 inference state...")
185
- inference_state = predictor.init_state(video_path=input_video)
186
- logger.info("Inference state initialized")
187
-
188
- # Add prompts
189
- logger.info(f"Adding {len(click_points)} prompts...")
190
- for i, point in enumerate(click_points):
191
- x, y = point
192
- points = np.array([[x * width, y * height]], dtype=np.float32)
193
- labels = np.array([1], dtype=np.int32)
194
- predictor.add_new_points_or_box(
195
- inference_state=inference_state,
196
- frame_idx=0,
197
- obj_id=i,
198
- points=points,
199
- labels=labels,
200
- )
201
- logger.info("Prompts added")
202
-
203
- # Setup FFmpeg for lossless encoding
204
- ffmpeg_cmd = [
205
- 'ffmpeg', '-y', '-f', 'rawvideo',
206
- '-pix_fmt', 'gray', '-s', f'{width}x{height}',
207
- '-r', str(fps), '-i', '-',
208
- '-c:v', 'ffv1', '-level', '3', '-pix_fmt', 'gray',
209
- str(self.masks_path)
210
- ]
211
-
212
- logger.info("Starting FFmpeg...")
213
- ffmpeg_process = subprocess.Popen(
214
- ffmpeg_cmd,
215
- stdin=subprocess.PIPE,
216
- stderr=subprocess.PIPE,
217
- stdout=subprocess.PIPE
218
- )
219
- logger.info("FFmpeg started")
220
-
221
- # Generate and stream masks
222
- logger.info(f"Propagating masks through {frame_count} frames...")
223
- frame_idx = 0
224
-
225
- for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
226
- # Progress update
227
- if progress_callback and (out_frame_idx % 30 == 0 or out_frame_idx == frame_count - 1):
228
- progress = 0.05 + (out_frame_idx + 1) / frame_count * 0.4 # 5% to 45%
229
- progress_callback(progress, f"SAM2: Frame {out_frame_idx + 1}/{frame_count}")
230
-
231
- # Combine masks
232
- combined_mask = np.zeros((height, width), dtype=np.uint8)
233
-
234
- if isinstance(out_obj_ids, torch.Tensor):
235
- obj_ids = out_obj_ids.cpu().numpy()
236
- else:
237
- obj_ids = out_obj_ids
238
-
239
- for i, obj_id in enumerate(obj_ids):
240
- if i < len(out_mask_logits):
241
- mask = (out_mask_logits[i] > 0.0)
242
- if isinstance(mask, torch.Tensor):
243
- mask = mask.cpu().numpy()
244
- mask = mask.squeeze().astype(np.uint8) * 255
245
- combined_mask = np.maximum(combined_mask, mask)
246
-
247
- # Write to FFmpeg
248
- try:
249
- ffmpeg_process.stdin.write(combined_mask.tobytes())
250
- except BrokenPipeError:
251
- logger.error("FFmpeg pipe broken")
252
- return False
253
-
254
- frame_idx = out_frame_idx
255
-
256
- # Memory management every 50 frames
257
- if (out_frame_idx + 1) % 50 == 0:
258
- if torch.cuda.is_available():
259
- torch.cuda.empty_cache()
260
- gc.collect()
261
-
262
- logger.info(f"Processed {frame_idx + 1} frames")
263
-
264
- # Close FFmpeg
265
- logger.info("Finalizing FFmpeg...")
266
- ffmpeg_process.stdin.close()
267
-
268
- # Wait for FFmpeg to finish (increased timeout)
269
- try:
270
- ffmpeg_process.wait(timeout=300) # 5 minutes timeout
271
- except subprocess.TimeoutExpired:
272
- logger.error("FFmpeg timeout after 5 minutes")
273
- ffmpeg_process.kill()
274
- return False
275
-
276
- if ffmpeg_process.returncode != 0:
277
- error = ffmpeg_process.stderr.read().decode()
278
- logger.error(f"FFmpeg failed: {error}")
279
- return False
280
-
281
- logger.info("FFmpeg completed successfully")
282
-
283
- # Verify output
284
- if not self.masks_path.exists():
285
- logger.error("Masks file was not created")
286
- return False
287
-
288
- return True
289
-
290
- except Exception as e:
291
- logger.error(f"Stage 1 exception: {str(e)}")
292
- logger.error(traceback.format_exc())
293
- return False
294
-
295
- finally:
296
- # CRITICAL: Complete cleanup
297
- logger.info("Cleaning up Stage 1...")
298
-
299
- if ffmpeg_process is not None:
300
- try:
301
- ffmpeg_process.kill()
302
- except:
303
- pass
304
-
305
- if predictor is not None:
306
- del predictor
307
- if inference_state is not None:
308
- del inference_state
309
-
310
- if torch.cuda.is_available():
311
- torch.cuda.empty_cache()
312
- torch.cuda.synchronize()
313
- logger.info(f"GPU memory after cleanup: {torch.cuda.memory_allocated()/1e9:.2f}GB")
314
-
315
- gc.collect()
316
- logger.info("Stage 1 cleanup complete")
317
-
318
- def _stage2_composite(self, input_video: str, background_video: str,
319
- output_path: str, use_matanyone: bool, progress_callback=None) -> bool:
320
- """Stage 2: Read masks, refine, and composite"""
321
- try:
322
- # Load metadata
323
- with open(self.metadata_path, 'r') as f:
324
- metadata = json.load(f)
325
- logger.info(f"Metadata loaded")
326
-
327
- frame_count = metadata["frame_count"]
328
-
329
- # Read masks
330
- if progress_callback:
331
- progress_callback(0.5, "Reading masks...")
332
-
333
- logger.info("Reading mask stream...")
334
- masks = self._read_mask_stream()
335
-
336
- if masks is None or len(masks) == 0:
337
- logger.error("Failed to read masks")
338
- return False
339
-
340
- logger.info(f"Read {len(masks)} masks")
341
-
342
- # MatAnyone refinement
343
- if use_matanyone:
344
- if progress_callback:
345
- progress_callback(0.6, "Refining with MatAnyone...")
346
-
347
- logger.info("Starting MatAnyone refinement...")
348
- refined_masks = self._refine_with_matanyone(input_video, masks, progress_callback)
349
-
350
- if refined_masks is not None and len(refined_masks) > 0:
351
- masks = refined_masks
352
- logger.info(f"Using {len(refined_masks)} refined masks")
353
- else:
354
- logger.warning("MatAnyone failed, using SAM2 masks")
355
-
356
- # Final composition
357
- if progress_callback:
358
- progress_callback(0.8, "Compositing final video...")
359
-
360
- logger.info("Starting final composition...")
361
- return self._composite_final_video(
362
- input_video, background_video,
363
- masks, output_path, metadata, progress_callback
364
- )
365
-
366
- except Exception as e:
367
- logger.error(f"Stage 2 exception: {str(e)}")
368
- logger.error(traceback.format_exc())
369
- return False
370
-
371
- def _read_mask_stream(self) -> Optional[list]:
372
- """Read masks from FFV1 stream"""
373
- try:
374
- with open(self.metadata_path, 'r') as f:
375
- metadata = json.load(f)
376
-
377
- width = metadata["width"]
378
- height = metadata["height"]
379
- frame_count = metadata["frame_count"]
380
-
381
- logger.info(f"Reading {frame_count} masks ({width}x{height})...")
382
-
383
- # FFmpeg decode
384
- ffmpeg_cmd = [
385
- 'ffmpeg', '-i', str(self.masks_path),
386
- '-f', 'rawvideo', '-pix_fmt', 'gray', '-'
387
- ]
388
-
389
- process = subprocess.Popen(
390
- ffmpeg_cmd,
391
- stdout=subprocess.PIPE,
392
- stderr=subprocess.PIPE
393
- )
394
-
395
- masks = []
396
- frame_size = width * height
397
-
398
- for frame_idx in range(frame_count):
399
- frame_data = process.stdout.read(frame_size)
400
-
401
- if len(frame_data) != frame_size:
402
- logger.error(f"Unexpected frame size at {frame_idx}: {len(frame_data)} vs {frame_size}")
403
- break
404
-
405
- mask = np.frombuffer(frame_data, dtype=np.uint8).reshape((height, width))
406
- masks.append(mask)
407
-
408
- process.stdout.close()
409
- process.wait(timeout=60)
410
-
411
- if process.returncode != 0:
412
- error = process.stderr.read().decode()
413
- logger.error(f"FFmpeg decode error: {error}")
414
- return None
415
-
416
- logger.info(f"Successfully read {len(masks)} masks")
417
- return masks
418
-
419
- except Exception as e:
420
- logger.error(f"Mask reading exception: {str(e)}")
421
- logger.error(traceback.format_exc())
422
- return None
423
-
424
    def _refine_with_matanyone(self, input_video: str, masks: list, progress_callback=None) -> Optional[list]:
        """Refine SAM2 masks with MatAnyone, seeded by the first SAM2 mask.

        Only masks[0] is given to MatAnyone as the seed; MatAnyone then
        re-propagates the matte through the whole video itself.

        Returns:
            List of refined alpha masks (uint8, grayscale), or None on any
            failure so the caller can fall back to the SAM2 masks.
        """
        try:
            from models.matanyone_loader import MatAnyoneSession
            logger.info("Loading MatAnyone...")

            # Dedicated scratch directory under the pipeline temp dir.
            matanyone_temp = self.temp_dir / "matanyone"
            matanyone_temp.mkdir(exist_ok=True)

            # Persist the first SAM2 mask as the seed image.
            first_mask_path = matanyone_temp / "first_mask.png"
            cv2.imwrite(str(first_mask_path), masks[0])

            # Initialize on GPU when available.
            if torch.cuda.is_available():
                device = "cuda"
                torch.cuda.set_device(0)
            else:
                device = "cpu"

            session = MatAnyoneSession(device=device)

            if torch.cuda.is_available():
                logger.info(f"MatAnyone on GPU, Memory: {torch.cuda.memory_allocated()/1e9:.2f}GB")

            # Run MatAnyone; it writes an alpha video (and a foreground video,
            # unused here) into matanyone_temp.
            alpha_path, fg_path = session.process_stream(
                video_path=Path(input_video),
                seed_mask_path=first_mask_path,
                out_dir=matanyone_temp,
                progress_cb=progress_callback
            )

            if not alpha_path or not alpha_path.exists():
                logger.warning("MatAnyone produced no output")
                return None

            # Read the alpha video back as per-frame grayscale masks.
            # NOTE(review): frame count is not checked against len(masks) —
            # the compositor simply uses however many frames came back.
            refined_masks = []
            cap = cv2.VideoCapture(str(alpha_path))

            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                alpha_mask = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                refined_masks.append(alpha_mask)

            cap.release()

            # Release GPU memory held by the session before returning.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

            logger.info(f"MatAnyone produced {len(refined_masks)} refined masks")
            return refined_masks if len(refined_masks) > 0 else None

        except Exception as e:
            # Warning (not error): refinement is best-effort; caller falls back.
            logger.warning(f"MatAnyone exception: {str(e)}")
            logger.warning(traceback.format_exc())
            return None
- def _composite_final_video(self, input_video: str, background_video: str,
488
- masks: list, output_path: str, metadata: Dict[str, Any],
489
- progress_callback=None) -> bool:
490
- """Create final composite"""
491
- try:
492
- fg_cap = cv2.VideoCapture(input_video)
493
-
494
- fps = metadata["fps"]
495
- width = metadata["width"]
496
- height = metadata["height"]
497
-
498
- # Handle background
499
- if background_video and os.path.exists(background_video):
500
- if background_video.lower().endswith(('.png', '.jpg', '.jpeg')):
501
- bg_image = cv2.imread(background_video)
502
- bg_image = cv2.resize(bg_image, (width, height))
503
- bg_cap = None
504
- else:
505
- bg_cap = cv2.VideoCapture(background_video)
506
- else:
507
- bg_image = np.full((height, width, 3), (0, 255, 0), dtype=np.uint8)
508
- bg_cap = None
509
-
510
- # Output writer
511
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
512
- out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
513
-
514
- frame_idx = 0
515
- total_frames = len(masks)
516
-
517
- logger.info(f"Compositing {total_frames} frames...")
518
-
519
- while frame_idx < total_frames:
520
- ret_fg, fg_frame = fg_cap.read()
521
- if not ret_fg:
522
- logger.error(f"Failed to read foreground frame {frame_idx}")
523
- break
524
-
525
- # Background
526
- if bg_cap is not None:
527
- ret_bg, bg_frame = bg_cap.read()
528
- if not ret_bg:
529
- bg_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
530
- ret_bg, bg_frame = bg_cap.read()
531
- if ret_bg:
532
- bg_frame = cv2.resize(bg_frame, (width, height))
533
- else:
534
- bg_frame = bg_image
535
- else:
536
- bg_frame = bg_image
537
-
538
- # Composite
539
- mask = masks[frame_idx]
540
- mask_norm = mask.astype(np.float32) / 255.0
541
- mask_3ch = np.stack([mask_norm, mask_norm, mask_norm], axis=-1)
542
-
543
- composite = (fg_frame * mask_3ch + bg_frame * (1 - mask_3ch)).astype(np.uint8)
544
- out.write(composite)
545
-
546
- frame_idx += 1
547
-
548
- if progress_callback and frame_idx % 30 == 0:
549
- progress = 0.8 + (frame_idx / total_frames) * 0.2
550
- progress_callback(progress, f"Compositing: {frame_idx}/{total_frames}")
551
-
552
- fg_cap.release()
553
- if bg_cap is not None:
554
- bg_cap.release()
555
- out.release()
556
-
557
- logger.info(f"Composite complete: {output_path}")
558
- return True
559
-
560
- except Exception as e:
561
- logger.error(f"Compositing exception: {str(e)}")
562
- logger.error(traceback.format_exc())
563
- return False
564
-
565
- def cleanup(self):
566
- """Clean up temp files"""
567
- try:
568
- if self.temp_dir.exists():
569
- shutil.rmtree(self.temp_dir)
570
- logger.info("Temp cleaned")
571
- except Exception as e:
572
- logger.error(f"Cleanup error: {str(e)}")
573
-
574
# --- Compatibility Wrapper ---
def process_video_two_stage(input_video: str, background_video: str,
                            click_points: list, output_path: str,
                            use_matanyone: bool = True, progress_callback=None) -> bool:
    """Drop-in replacement: run the two-stage pipeline with a throwaway
    TwoStageProcessor, always removing its temp files afterwards."""
    worker = TwoStageProcessor()
    try:
        result = worker.process_video(
            input_video,
            background_video,
            click_points,
            output_path,
            use_matanyone,
            progress_callback,
        )
        return result
    finally:
        worker.cleanup()