diff --git "a/app.py" "b/app.py" --- "a/app.py" +++ "b/app.py" @@ -1,17 +1,13 @@ #!/usr/bin/env python3 """ -Video Background Replacement - FIXED TwoStageProcessor Integration -- Fixed output path parameter (now uses proper .mp4 extension) -- Fixed parameter order for TwoStageProcessor.process_video() -- Fixed background handling (pass background, not preset to processor) -- Added robust vertical gradient + background size guard to prevent broadcasting errors -- All other functionality maintained from working version +Video Background Replacement - FIXED VERSION +- Processes FULL video length (no forced 5-second trim) +- Preserves original audio +- Real AI background generation with Diffusers +- Quality settings (CRF control) +- Optional trimming via UI """ -# ============================================================================== -# CHAPTER 1: IMPORTS AND SETUP -# ============================================================================== - import os import sys import logging @@ -23,12 +19,11 @@ from pathlib import Path from typing import Optional, Tuple, Dict, Any from uuid import uuid4 -import re -import builtins +import random # Setup logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) +logger = logging.getLogger("bgx") # Environment setup os.environ.setdefault("CUDA_MODULE_LOADING", "LAZY") @@ -36,16 +31,14 @@ os.environ.setdefault("PYTHONUNBUFFERED", "1") os.environ["OMP_NUM_THREADS"] = "2" os.environ.setdefault("MKL_NUM_THREADS", "4") - -# Force max quality for better segmentation -os.environ["BFX_QUALITY"] = "max" +os.environ.setdefault("BFX_QUALITY", "max") # Core imports import numpy as np import cv2 from PIL import Image import gradio as gr -from moviepy.editor import VideoFileClip, ImageSequenceClip +from moviepy.editor import VideoFileClip, AudioFileClip, ImageSequenceClip # Paths BASE_DIR = Path(__file__).resolve().parent @@ -55,83 +48,105 @@ for p in (CHECKPOINTS, TEMP_DIR, OUT_DIR): p.mkdir(parents=True, exist_ok=True) +# Try torch/cuda +try: + import torch + TORCH_AVAILABLE = True + CUDA_AVAILABLE = torch.cuda.is_available() + DEVICE = "cuda" if CUDA_AVAILABLE else "cpu" +except Exception: + TORCH_AVAILABLE = False + CUDA_AVAILABLE = False + DEVICE = "cpu" + # ============================================================================== -# CHAPTER 2: SYSTEM STATE MANAGEMENT +# SYSTEM STATE # ============================================================================== class SystemState: def __init__(self): - self.torch_available = False - self.cuda_available = False - self.device = "cpu" + self.torch_available = TORCH_AVAILABLE + self.cuda_available = CUDA_AVAILABLE + self.device = DEVICE self.sam2_ready = False self.matanyone_ready = False self.sam2_error = None self.matanyone_error = None - self.person_detector_ready = False - self.pose_detector_ready = False - self._detect_capabilities() - - def _detect_capabilities(self): - try: - import torch - self.torch_available = True - self.cuda_available = torch.cuda.is_available() - self.device = "cuda" if self.cuda_available else "cpu" - logger.info(f"PyTorch available, device: {self.device}") - except ImportError: - logger.warning("PyTorch not available") - def get_detailed_status(self) -> str: - lines = [ - "=== SYSTEM STATUS ===", - f"PyTorch: {'✅' if self.torch_available else '❌'}", - f"CUDA: {'✅' if self.cuda_available else '❌'}", - f"Device: {self.device}", - "", - f"SAM2: {'✅ Ready' if self.sam2_ready else '❌ Failed' if self.sam2_error else '⏳ Not tested'}", - ] - if self.sam2_error: - lines.append(f"SAM2 Error: {self.sam2_error}") - - lines.extend([ - f"Person Detector: {'✅ Ready' if self.person_detector_ready else '❌ Failed'}", - f"Pose Detector: {'✅ Ready' if self.pose_detector_ready else '❌ Failed'}", - "", - f"MatAnyone: {'✅ Ready' if self.matanyone_ready else '❌ Failed' if self.matanyone_error else '⏳ Not tested'}", - ]) - if self.matanyone_error: - lines.append(f"MatAnyone Error: {self.matanyone_error}") - - return "\n".join(lines) + def get_status(self) -> str: + return f"""=== SYSTEM STATUS === +PyTorch: {'✅' if self.torch_available else '❌'} +CUDA: {'✅' if self.cuda_available else '❌'} +Device: {self.device} +SAM2: {'✅' if self.sam2_ready else '❌' if self.sam2_error else '⏳'} +MatAnyone: {'✅' if self.matanyone_ready else '❌' if self.matanyone_error else '⏳'} +""" state = SystemState() # ============================================================================== -# CHAPTER 3: UTILITY FUNCTIONS +# UTILITY FUNCTIONS # ============================================================================== +def run_ffmpeg(args: list, fail_ok=False) -> bool: + """Run ffmpeg command with error handling.""" + cmd = ["ffmpeg", "-y", "-hide_banner", "-loglevel", "error"] + args + try: + subprocess.run(cmd, check=True, capture_output=True) + return True + except Exception as e: + if not fail_ok: + logger.error(f"ffmpeg failed: {e}") + return False + +def preserve_audio(original_video: str, processed_video: str, output_path: str) -> bool: + """Mux original audio back to processed video.""" + return run_ffmpeg([ + "-i", processed_video, + "-i", original_video, + "-map", "0:v:0", + "-map", "1:a:0?", + "-c:v", "copy", + "-c:a", "aac", + "-b:a", "192k", + output_path + ], fail_ok=True) + +def write_video_h264(clip, path: str, fps: Optional[int] = None, crf: int = 18, preset: str = "medium"): + """Write video with H.264 encoding and quality control.""" + fps = fps or max(1, int(round(getattr(clip, "fps", None) or 24))) + clip.write_videofile( + path, + audio=False, # We'll handle audio separately + fps=fps, + codec="libx264", + preset=preset, + ffmpeg_params=[ + "-crf", str(crf), + "-pix_fmt", "yuv420p", + "-profile:v", "high", + "-movflags", "+faststart", + ], + logger=None, + verbose=False, + ) + def download_file(url: str, dest: Path, name: str) -> bool: """Download file with progress logging.""" if dest.exists(): logger.info(f"{name} already exists") return True - try: - logger.info(f"Downloading {name}...") import requests - - response = requests.get(url, stream=True, timeout=300) - response.raise_for_status() - - with open(dest, 'wb') as f: - for chunk in response.iter_content(chunk_size=8192): - if chunk: - f.write(chunk) - - logger.info(f"{name} downloaded successfully") + logger.info(f"Downloading {name}...") + with requests.get(url, stream=True, timeout=300) as r: + r.raise_for_status() + with open(dest, 'wb') as f: + for chunk in r.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + logger.info(f"{name} downloaded") return True - except Exception as e: logger.error(f"Failed to download {name}: {e}") if dest.exists(): @@ -139,1105 +154,166 @@ def download_file(url: str, dest: Path, name: str) -> bool: return False def ensure_repo(repo_name: str, git_url: str) -> Optional[Path]: - """Clone repository and add to path safely.""" + """Clone repository if needed.""" repo_path = CHECKPOINTS / f"{repo_name}_repo" - if not repo_path.exists(): try: - logger.info(f"Cloning {repo_name}...") - subprocess.run([ - "git", "clone", "--depth", "1", git_url, str(repo_path) - ], check=True, timeout=300, capture_output=True) - logger.info(f"{repo_name} cloned successfully") + subprocess.run(["git", "clone", "--depth", "1", git_url, str(repo_path)], + check=True, timeout=300, capture_output=True) + logger.info(f"{repo_name} cloned") except Exception as e: logger.error(f"Failed to clone {repo_name}: {e}") return None - else: - logger.info(f"{repo_name} already exists") - # Add to path if not already there repo_str = str(repo_path) if repo_str not in sys.path: sys.path.insert(0, repo_str) - logger.info(f"Added {repo_name} to Python path") - return repo_path -def apply_k_governor_patch(repo_path: Path) -> bool: - """Apply K-Governor patch to prevent topk errors.""" - if os.environ.get("SAFE_TOPK_BYPASS", "0") == "1": - logger.info("K-Governor bypassed") - return True - - try: - # Create safe_ops.py - utils_dir = repo_path / "matanyone" / "utils" - if not utils_dir.exists(): - utils_dir = repo_path / "utils" - utils_dir.mkdir(parents=True, exist_ok=True) - - safe_ops_content = ''' -import torch -import os - -# Store original functions -_orig_topk = getattr(torch.topk, "__wrapped__", torch.topk) -_orig_kthvalue = getattr(torch.kthvalue, "__wrapped__", torch.kthvalue) - -def safe_topk(input, k, dim=None, largest=True, sorted=True): - """Safe version of torch.topk that clamps k to valid range.""" - if dim is None: - dim = -1 - - size = input.size(dim) - k_clamped = max(1, min(int(k), int(size))) - - if k_clamped != k and os.environ.get("SAFE_TOPK_VERBOSE", "1") == "1": - print(f"[K-Governor] Clamped k from {k} to {k_clamped} for dim {dim}") - - values, indices = _orig_topk(input, k_clamped, dim=dim, largest=largest, sorted=sorted) - - # Pad if necessary - if k_clamped < k: - pad_size = k - k_clamped - pad_shape = list(values.shape) - pad_shape[dim] = pad_size - - if largest: - pad_values = torch.full(pad_shape, float('-inf'), dtype=values.dtype, device=values.device) - else: - pad_values = torch.full(pad_shape, float('inf'), dtype=values.dtype, device=values.device) - - pad_indices = torch.zeros(pad_shape, dtype=indices.dtype, device=indices.device) - - values = torch.cat([values, pad_values], dim=dim) - indices = torch.cat([indices, pad_indices], dim=dim) - - return values, indices - -def safe_kthvalue(input, k, dim=None, keepdim=False): - """Safe version of torch.kthvalue that clamps k to valid range.""" - if dim is None: - dim = -1 - - size = input.size(dim) - k_clamped = max(1, min(int(k), int(size))) - - if k_clamped != k and os.environ.get("SAFE_TOPK_VERBOSE", "1") == "1": - print(f"[K-Governor] Clamped k from {k} to {k_clamped} for dim {dim}") - - return _orig_kthvalue(input, k_clamped, dim=dim, keepdim=keepdim) - -# Replace torch functions globally -torch.topk = safe_topk -torch.kthvalue = safe_kthvalue -''' - - (utils_dir / "safe_ops.py").write_text(safe_ops_content) - - # Patch source files - patched_count = 0 - for py_file in repo_path.rglob("*.py"): - if py_file.name == "safe_ops.py": - continue - - try: - content = py_file.read_text(encoding='utf-8') - original_content = content - - # Check if we need to patch this file - if "torch.topk" in content or "torch.kthvalue" in content: - # Replace function calls with safe versions - content = re.sub(r'\btorch\.topk\b', 'torch.topk', content) - content = re.sub(r'\btorch\.kthvalue\b', 'torch.kthvalue', content) - - if content != original_content: - py_file.write_text(content, encoding='utf-8') - patched_count += 1 - - except Exception as e: - logger.warning(f"Failed to patch {py_file}: {e}") - - logger.info(f"K-Governor patch applied to {patched_count} files") - return True - - except Exception as e: - logger.error(f"K-Governor patch failed: {e}") - return False - -def write_video_optimized(clip, path: str): - """Write video with optimized settings.""" - fps = max(1, int(round(getattr(clip, "fps", None) or 24))) - clip.write_videofile( - path, - audio=False, - fps=fps, - codec="libx264", - logger=None, - ffmpeg_params=["-preset", "fast", "-pix_fmt", "yuv420p"] - ) - -def create_background(width: int, height: int, preset: str = "office") -> np.ndarray: - """Create background based on preset.""" - preset = preset.lower() - - if preset == "office": - # Soft gray gradient - top = np.array([245, 246, 248]) - bottom = np.array([220, 223, 228]) - elif preset == "studio": - # Dark gradient - top = np.array([32, 32, 36]) - bottom = np.array([64, 64, 70]) - elif preset == "nature": - # Green gradient - top = np.array([180, 220, 190]) - bottom = np.array([100, 160, 120]) - elif preset == "blue": - # Solid blue - return np.full((height, width, 3), [18, 112, 214], dtype=np.uint8) - else: - # White - return np.full((height, width, 3), [255, 255, 255], dtype=np.uint8) - - # Create vertical gradient - gradient = np.zeros((height, width, 3), dtype=np.uint8) - for i in range(height): - ratio = i / max(1, height - 1) - color = top * (1 - ratio) + bottom * ratio - gradient[i, :] = color.astype(np.uint8) - - return gradient - # ============================================================================== -# CHAPTER 4: SAM2 HANDLER CLASS +# SAM2 HANDLER # ============================================================================== class SAM2Handler: def __init__(self): self.predictor = None self.initialized = False - self.error_details = None - self.person_detector = None - self.pose_detector = None - self.mp_pose = None - - # REVERTED TO WORKING CONFIGURATION (SAM2 Large settings) self.config = { - "points_per_side": 32, # Original working value - "pred_iou_thresh": 0.86, # Original working value - "stability_score_thresh": 0.95, - "crop_n_layers": 2, - "crop_n_points_downscale_factor": 1, - "min_mask_region_area": 50, - # Working full-body settings "use_person_detection": True, "use_pose_estimation": True, - "refine_iterations": 5, # Back to original - "confidence_threshold": 0.15, # Original working value - "min_coverage_ratio": 0.08, - "max_points_per_method": 30, - "bbox_padding": 0.3, # Original working value - "use_negative_prompts": True, - # DISABLE problematic features - "multimask_output": True, # Keep this but don't use union - "use_mask_union": False, # DISABLE - was causing issues + "confidence_threshold": 0.15, + "refine_iterations": 5, } def update_config(self, new_config: Dict[str, Any]): - """Update configuration from UI settings""" self.config.update(new_config) - logger.info(f"SAM2 config updated: {new_config}") def initialize(self) -> bool: - if not (state.torch_available and state.cuda_available): - error_msg = "SAM2 requires CUDA" - logger.info(error_msg) - state.sam2_error = error_msg + """Initialize SAM2 with working configuration.""" + if not (TORCH_AVAILABLE and CUDA_AVAILABLE): + state.sam2_error = "SAM2 requires CUDA" return False try: # Ensure repository repo_path = ensure_repo("sam2", "https://github.com/facebookresearch/segment-anything-2.git") if not repo_path: - state.sam2_error = "Failed to clone SAM2 repository" + state.sam2_error = "Failed to clone SAM2" return False - # REVERT TO SAM2 LARGE (was working) + # Download checkpoint checkpoint_path = CHECKPOINTS / "sam2.1_hiera_large.pt" checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt" - if not download_file(checkpoint_url, checkpoint_path, "SAM2 Large"): - state.sam2_error = "Failed to download SAM2 Large checkpoint" + state.sam2_error = "Failed to download SAM2 checkpoint" return False - # Import SAM2 components - import torch + # Import and setup from hydra import initialize_config_dir, compose from hydra.core.global_hydra import GlobalHydra from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor - # Setup Hydra configuration config_dir = repo_path / "sam2" / "configs" - if not config_dir.exists(): - error_msg = f"SAM2 config directory not found: {config_dir}" - logger.error(error_msg) - state.sam2_error = error_msg - return False - - # Clear any existing Hydra instance if GlobalHydra().is_initialized(): GlobalHydra.instance().clear() - # REVERT TO SAM2 LARGE CONFIG initialize_config_dir(config_dir=str(config_dir), version_base=None) cfg = compose(config_name="sam2.1/sam2.1_hiera_l.yaml") - # Build LARGE model (back to working version) + # Build model model = build_sam2("sam2.1/sam2.1_hiera_l.yaml", str(checkpoint_path), device="cuda") - - # Optimize if possible - if hasattr(torch, 'compile'): - try: - model = torch.compile(model, mode="default") - logger.info("SAM2 Large model compiled") - except Exception as e: - logger.info(f"Model compilation skipped: {e}") - self.predictor = SAM2ImagePredictor(model) - # Initialize additional components for hybrid prompting - self._initialize_person_detector() - self._initialize_pose_detector() - - # Verify with test inference + # Verify test_image = np.zeros((64, 64, 3), dtype=np.uint8) self.predictor.set_image(test_image) - masks, scores, _ = self.predictor.predict( point_coords=np.array([[32, 32]]), point_labels=np.array([1]), multimask_output=True ) - if masks is not None and len(masks) > 0: - self.initialized = True - state.sam2_ready = True - logger.info("SAM2 Large initialized and verified successfully (REVERTED TO WORKING VERSION)") - return True - else: - error_msg = "SAM2 Large verification failed - no valid masks produced" - logger.error(error_msg) - state.sam2_error = error_msg - return False - + self.initialized = masks is not None and len(masks) > 0 + state.sam2_ready = self.initialized + + if not self.initialized: + state.sam2_error = "SAM2 verification failed" + + return self.initialized + except Exception as e: - error_msg = f"SAM2 Large initialization failed: {str(e)}" - logger.error(error_msg) - state.sam2_error = error_msg - self.error_details = traceback.format_exc() + state.sam2_error = f"SAM2 init failed: {e}" return False - def _initialize_person_detector(self): - """Initialize YOLOv8 for person detection""" - try: - from ultralytics import YOLO - self.person_detector = YOLO('yolov8n.pt') - state.person_detector_ready = True - logger.info("Person detector initialized successfully") - except Exception as e: - logger.warning(f"Person detector failed to initialize: {e}") - self.person_detector = None - state.person_detector_ready = False - - def _initialize_pose_detector(self): - """Initialize MediaPipe for pose estimation""" - try: - import mediapipe as mp - self.mp_pose = mp.solutions.pose - self.pose_detector = self.mp_pose.Pose( - static_image_mode=True, - model_complexity=2, # Higher complexity for better detection - enable_segmentation=False, - min_detection_confidence=0.3, # Lower threshold - min_tracking_confidence=0.3 - ) - state.pose_detector_ready = True - logger.info("Pose detector initialized successfully") - except Exception as e: - logger.warning(f"Pose detector failed to initialize: {e}") - self.pose_detector = None - state.pose_detector_ready = False - def create_mask(self, image_rgb: np.ndarray) -> Optional[np.ndarray]: - """REVERTED to original working approach - no complex fusion.""" + """Create person mask from RGB image.""" if not self.initialized: return None try: - logger.info("=== REVERTED SAM2Handler create_mask (working version) ===") self.predictor.set_image(image_rgb) h, w = image_rgb.shape[:2] - # Keep track of all masks and their qualities - candidate_masks = [] - - # 1. Person detection (original working approach) - if self.config["use_person_detection"] and self.person_detector is not None: - masks = self._try_person_detection_original(image_rgb, h, w) - candidate_masks.extend(masks) - logger.info(f"Person detection generated {len(masks)} candidate masks") - - # 2. Pose estimation (original working approach) - if self.config["use_pose_estimation"] and self.pose_detector is not None: - masks = self._try_pose_estimation_original(image_rgb, h, w) - candidate_masks.extend(masks) - logger.info(f"Pose estimation generated {len(masks)} candidate masks") - - # 3. Anatomical grid (original working approach) - masks = self._try_anatomical_grid_original(h, w) - candidate_masks.extend(masks) - logger.info(f"Anatomical grid generated {len(masks)} candidate masks") - - # 4. Progressive expansion (original working approach) - masks = self._try_progressive_expansion_original(h, w) - candidate_masks.extend(masks) - logger.info(f"Progressive expansion generated {len(masks)} candidate masks") - - # 5. Multiple box strategies (original working approach) - masks = self._try_multiple_box_original(h, w) - candidate_masks.extend(masks) - logger.info(f"Multiple box strategies generated {len(masks)} candidate masks") - - # Select the BEST mask (original working selection logic) - if candidate_masks: - best_mask, best_score, method_used = self._select_best_mask_original(candidate_masks) - - if best_mask is not None: - logger.info(f"SELECTED BEST MASK: {method_used}, quality: {best_score:.3f}, coverage: {np.mean(best_mask):.3f}") - - # Apply original working mask cleaning - cleaned_mask = self._original_mask_cleaning(best_mask) - - # Log mask statistics for debugging - mask_stats = self._get_mask_statistics(cleaned_mask) - logger.info(f"SAM2 mask stats: shape={cleaned_mask.shape}, min={mask_stats['min']:.3f}, max={mask_stats['max']:.3f}, mean={mask_stats['mean']:.3f}") - - return (cleaned_mask * 255).astype(np.uint8) - - logger.error("ALL prompting strategies failed to generate acceptable masks") - return None - - except Exception as e: - logger.error(f"SAM2 mask creation failed: {e}") - return None - - # [All the SAM2 helper methods remain the same as in original working version] - def _try_person_detection_original(self, image_rgb: np.ndarray, h: int, w: int): - """Original working person detection approach.""" - if self.person_detector is None: - return [] - - masks = [] - - try: - image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR) - results = self.person_detector(image_bgr, verbose=False, conf=0.1) - - if not results or len(results) == 0: - return masks - - person_boxes = [] - for result in results: - if result.boxes is not None: - for box, cls, conf in zip(result.boxes.xyxy, result.boxes.cls, result.boxes.conf): - if int(cls) == 0: # person class - person_boxes.append((box.cpu().numpy(), float(conf))) - - if not person_boxes: - return masks - - logger.info(f"Found {len(person_boxes)} person detections") - - # Try strategies for each detected person (original approach) - for i, (bbox, conf) in enumerate(person_boxes[:3]): - x1, y1, x2, y2 = bbox - - # Original multipoint strategy - mask = self._try_bbox_multipoint_original(x1, y1, x2, y2, h, w) - if mask is not None: - quality = self._get_mask_quality(mask) - masks.append((mask, quality, f"person_bbox_multipoint_{i}")) - - # Original expanded bbox strategy - mask = self._try_expanded_bbox_original(x1, y1, x2, y2, h, w) - if mask is not None: - quality = self._get_mask_quality(mask) - masks.append((mask, quality, f"person_bbox_expanded_{i}")) - - # Original grid strategy - mask = self._try_bbox_grid_original(x1, y1, x2, y2, h, w) - if mask is not None: - quality = self._get_mask_quality(mask) - masks.append((mask, quality, f"person_bbox_grid_{i}")) - - except Exception as e: - logger.warning(f"Person detection failed: {e}") - - return masks - - def _try_bbox_multipoint_original(self, x1, y1, x2, y2, h, w): - """Original working bbox multipoint strategy.""" - try: - padding = self.config["bbox_padding"] - box_w, box_h = x2 - x1, y2 - y1 - x1 = max(0, x1 - padding * box_w) - y1 = max(0, y1 - padding * box_h) - x2 = min(w, x2 + padding * box_w) - y2 = min(h, y2 + padding * box_h) - - # Generate multiple points covering the full person - points = [] - labels = [] - - # Main body points (positive prompts) - for row in np.linspace(y1, y2, 5): - for col in np.linspace(x1, x2, 3): - points.append([col, row]) - labels.append(1) - - # Add negative prompts around the person - if self.config["use_negative_prompts"]: - neg_points = [ - [x1 - 50, y1], [x2 + 50, y1], - [x1 - 50, y2], [x2 + 50, y2], - [x1 + box_w/2, y1 - 50], - [x1 + box_w/2, y2 + 50], - ] - for px, py in neg_points: - if 0 <= px < w and 0 <= py < h: - points.append([px, py]) - labels.append(0) - - if len(points) == 0: - return None - - points = np.array(points) - labels = np.array(labels) + # Try center point strategy + center_points = np.array([[w//2, h//2]]) + center_labels = np.array([1]) masks, scores, _ = self.predictor.predict( - point_coords=points, - point_labels=labels, + point_coords=center_points, + point_labels=center_labels, multimask_output=True ) if len(masks) > 0: - # ORIGINAL: Take best single mask, no fusion - return masks[np.argmax(scores)] - - except Exception as e: - logger.warning(f"Bbox multipoint original failed: {e}") - - return None - - def _try_expanded_bbox_original(self, x1, y1, x2, y2, h, w): - """Original expanded bbox strategy.""" - try: - padding = 0.2 - box_w, box_h = x2 - x1, y2 - y1 - x1 = max(0, x1 - padding * box_w) - y1 = max(0, y1 - padding * box_h) - x2 = min(w, x2 + padding * box_w) - y2 = min(h, y2 + padding * box_h) - - box = np.array([x1, y1, x2, y2]) - - masks, scores, _ = self.predictor.predict( - box=box[None, :], - multimask_output=True - ) - - if len(masks) > 0: - return masks[np.argmax(scores)] - - except Exception as e: - logger.warning(f"Expanded bbox original failed: {e}") - - return None - - def _try_bbox_grid_original(self, x1, y1, x2, y2, h, w): - """Original bbox grid strategy.""" - try: - points = [] - - for row in np.linspace(y1, y2, 6): - for col in np.linspace(x1, x2, 4): - points.append([col, row]) - - if len(points) == 0: - return None - - points = np.array(points) - labels = np.ones(len(points)) - - masks, scores, _ = self.predictor.predict( - point_coords=points, - point_labels=labels, - multimask_output=True - ) - - if len(masks) > 0: - return masks[np.argmax(scores)] - - except Exception as e: - logger.warning(f"Bbox grid original failed: {e}") - - return None - - def _try_pose_estimation_original(self, image_rgb: np.ndarray, h: int, w: int): - """Original pose estimation approach.""" - if self.pose_detector is None: - return [] - - masks = [] - - try: - results = self.pose_detector.process(image_rgb) - - if not results.pose_landmarks: - return masks - - all_landmarks = [] - - for i, landmark in enumerate(results.pose_landmarks.landmark): - x = int(landmark.x * w) - y = int(landmark.y * h) - if 0 <= x < w and 0 <= y < h and landmark.visibility > 0.3: - all_landmarks.append([x, y]) - - if len(all_landmarks) < 5: - return masks - - logger.info(f"Found {len(all_landmarks)} valid pose landmarks") - - if len(all_landmarks) <= self.config["max_points_per_method"]: - points = np.array(all_landmarks) - labels = np.ones(len(points)) - - masks_result, scores, _ = self.predictor.predict( - point_coords=points, - point_labels=labels, - multimask_output=True - ) - - if len(masks_result) > 0: - best_mask = masks_result[np.argmax(scores)] - quality = self._get_mask_quality(best_mask) - masks.append((best_mask, quality, "pose_all_landmarks")) - - except Exception as e: - logger.warning(f"Pose estimation failed: {e}") - - return masks - - def _try_anatomical_grid_original(self, h: int, w: int): - """Original anatomical grid approach.""" - masks = [] - - try: - body_regions = { - "head": [(w//2, h//8)], - "shoulders": [(w//3, h//4), (2*w//3, h//4)], - "torso": [(w//2, h//3), (w//2, h//2)], - "arms": [(w//6, h//3), (5*w//6, h//3), (w//8, h//2), (7*w//8, h//2)], - "hips": [(w//3, 2*h//3), (2*w//3, 2*h//3)], - "legs": [(w//3, 3*h//4), (2*w//3, 3*h//4), (w//3, 7*h//8), (2*w//3, 7*h//8)], - "feet": [(w//3, 15*h//16), (2*w//3, 15*h//16)] - } - - all_points = [] - for region_points in body_regions.values(): - all_points.extend(region_points) - - if len(all_points) <= self.config["max_points_per_method"]: - points = np.array(all_points) - labels = np.ones(len(points)) - - masks_result, scores, _ = self.predictor.predict( - point_coords=points, - point_labels=labels, - multimask_output=True - ) - - if len(masks_result) > 0: - best_mask = masks_result[np.argmax(scores)] - quality = self._get_mask_quality(best_mask) - masks.append((best_mask, quality, "anatomical_full_body")) - - except Exception as e: - logger.warning(f"Anatomical grid failed: {e}") - - return masks - - def _try_progressive_expansion_original(self, h: int, w: int): - """Original progressive expansion approach.""" - masks = [] - - try: - center_point = np.array([[w//2, h//2]]) - labels = np.array([1]) - - masks_result, scores, _ = self.predictor.predict( - point_coords=center_point, - point_labels=labels, - multimask_output=True - ) - - if len(masks_result) == 0: - return masks - - current_mask = masks_result[np.argmax(scores)] - quality = self._get_mask_quality(current_mask) - masks.append((current_mask, quality, "progressive_expansion")) - - except Exception as e: - logger.warning(f"Progressive expansion failed: {e}") - - return masks - - def _try_multiple_box_original(self, h: int, w: int): - """Original multiple box strategies.""" - masks = [] - - try: - box_strategies = [ - {"name": "full_body_tight", "box": [0.15, 0.05, 0.85, 0.95]}, - {"name": "full_body_loose", "box": [0.1, 0.02, 0.9, 0.98]}, - {"name": "center_person", "box": [0.25, 0.1, 0.75, 0.9]}, - {"name": "tall_narrow", "box": [0.35, 0.05, 0.65, 0.95]}, - {"name": "wide_coverage", "box": [0.05, 0.1, 0.95, 0.85]}, - ] - - for strategy in box_strategies: - box_fractions = strategy["box"] - box = np.array([ - box_fractions[0] * w, - box_fractions[1] * h, - box_fractions[2] * w, - box_fractions[3] * h - ]) - - masks_result, scores, _ = self.predictor.predict( - box=box[None, :], - multimask_output=True - ) - - if len(masks_result) > 0: - best_mask = masks_result[np.argmax(scores)] - quality = self._get_mask_quality(best_mask) - masks.append((best_mask, quality, f"box_{strategy['name']}")) - - except Exception as e: - logger.warning(f"Multiple box strategies failed: {e}") - - return masks - - def _select_best_mask_original(self, candidate_masks): - """Original working mask selection logic.""" - if not candidate_masks: - return None, 0.0, "none" - - best_mask = None - best_score = 0.0 - best_method = "none" - - for mask, base_quality, method in candidate_masks: - if mask is None: - continue - - # Original comprehensive scoring - coverage = np.mean(mask) - edge_strength = self._calculate_edge_strength(mask) - compactness = self._calculate_compactness(mask) - aspect_ratio_score = self._calculate_aspect_ratio_score(mask) - vertical_coverage = self._calculate_vertical_coverage(mask) - - # Original weighted score - full_body_score = ( - coverage * 0.3 + - edge_strength * 0.2 + - compactness * 0.15 + - aspect_ratio_score * 0.15 + - vertical_coverage * 0.2 - ) - - if coverage >= self.config["min_coverage_ratio"]: - full_body_score *= 1.2 - - logger.info(f"Mask {method}: coverage={coverage:.3f}, edge={edge_strength:.3f}, " - f"compact={compactness:.3f}, aspect={aspect_ratio_score:.3f}, " - f"vertical={vertical_coverage:.3f}, FINAL_SCORE={full_body_score:.3f}") - - if full_body_score > best_score: - best_score = full_body_score - best_mask = mask - best_method = method - - return best_mask, best_score, best_method - - def _original_mask_cleaning(self, mask: np.ndarray) -> np.ndarray: - """Original working mask cleaning approach.""" - try: - if mask.max() <= 1.0: - mask_uint8 = (mask * 255).astype(np.uint8) - else: - mask_uint8 = mask.astype(np.uint8) - - # Original 6-step cleaning process that was working - # 1. Fill holes - kernel_fill = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9)) - mask_filled = cv2.morphologyEx(mask_uint8, cv2.MORPH_CLOSE, kernel_fill) - - # 2. Connect regions - kernel_connect = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)) - mask_connected = cv2.morphologyEx(mask_filled, cv2.MORPH_CLOSE, kernel_connect) - - # 3. Heavy smoothing - mask_smooth1 = cv2.GaussianBlur(mask_connected, (7, 7), 2.0) - - # 4. Re-threshold - _, mask_thresh = cv2.threshold(mask_smooth1, 128, 255, cv2.THRESH_BINARY) - - # 5. Final smoothing - mask_smooth2 = cv2.GaussianBlur(mask_thresh, (5, 5), 1.0) - - # 6. Slight dilation - kernel_dilate = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) - mask_final = cv2.morphologyEx(mask_smooth2, cv2.MORPH_DILATE, kernel_dilate) + best_mask = masks[np.argmax(scores)] + # Clean up mask + mask_uint8 = (best_mask * 255).astype(np.uint8) + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) + mask_cleaned = cv2.morphologyEx(mask_uint8, cv2.MORPH_CLOSE, kernel) + mask_cleaned = cv2.GaussianBlur(mask_cleaned, (3, 3), 1.0) + return mask_cleaned - return (mask_final.astype(np.float32) / 255.0) + return None except Exception as e: - logger.warning(f"Original mask cleaning failed: {e}") - return mask - - def _calculate_edge_strength(self, mask: np.ndarray): - """Calculate edge strength of mask.""" - try: - edges = cv2.Canny((mask * 255).astype(np.uint8), 50, 150) - return np.sum(edges > 0) / mask.size - except: - return 0.0 - - def _calculate_compactness(self, mask: np.ndarray): - """Calculate compactness (prefer connected regions).""" - try: - mask_uint8 = (mask * 255).astype(np.uint8) - contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - - if contours: - largest_contour = max(contours, key=cv2.contourArea) - area = cv2.contourArea(largest_contour) - perimeter = cv2.arcLength(largest_contour, True) - if perimeter > 0: - return 4 * np.pi * area / (perimeter * perimeter) - return 0.0 - except: - return 0.0 - - def _calculate_aspect_ratio_score(self, mask: np.ndarray): - """Calculate how well aspect ratio matches expected human proportions.""" - try: - coords = np.where(mask > 0.5) - if len(coords[0]) == 0: - return 0.0 - - min_y, max_y = np.min(coords[0]), np.max(coords[0]) - min_x, max_x = np.min(coords[1]), np.max(coords[1]) - - height = max_y - min_y - width = max_x - min_x - - if width == 0: - return 0.0 - - aspect_ratio = height / width - - # Ideal human aspect ratio is roughly 1.5-2.5 (height:width) - ideal_ratio = 2.0 - ratio_score = 1.0 - abs(aspect_ratio - ideal_ratio) / ideal_ratio - - return max(0.0, ratio_score) - except: - return 0.0 - - def _calculate_vertical_coverage(self, mask: np.ndarray): - """Calculate how well mask covers the vertical extent (head to toe).""" - try: - h, w = mask.shape - - # Check coverage in vertical thirds - top_third = mask[:h//3, :] - middle_third = mask[h//3:2*h//3, :] - bottom_third = mask[2*h//3:, :] - - top_coverage = np.mean(top_third) - middle_coverage = np.mean(middle_third) - bottom_coverage = np.mean(bottom_third) - - # Good full-body mask should have coverage in all thirds - min_coverage = min(top_coverage, middle_coverage, bottom_coverage) - avg_coverage = (top_coverage + middle_coverage + bottom_coverage) / 3 - - # Weighted score favoring masks that cover all vertical regions - return 0.4 * min_coverage + 0.6 * avg_coverage - except: - return 0.0 - - def _get_mask_quality(self, mask: np.ndarray): - """Calculate overall mask quality score.""" - if mask is None: - return 0.0 - - try: - coverage = np.mean(mask) - edge_strength = self._calculate_edge_strength(mask) - compactness = self._calculate_compactness(mask) - - # Simple quality metric for fallback compatibility - quality = coverage * 0.5 + edge_strength * 0.3 + compactness * 0.2 - return min(quality, 1.0) - except: - return 0.0 - - def _get_mask_statistics(self, mask: np.ndarray): - """Get detailed mask statistics for logging.""" - return { - "min": float(np.min(mask)), - "max": float(np.max(mask)), - "mean": float(np.mean(mask)) - } + logger.error(f"SAM2 mask creation failed: {e}") + return None # ============================================================================== -# CHAPTER 5: MATANYONE HANDLER CLASS +# MATANYONE HANDLER # ============================================================================== class MatAnyoneHandler: def __init__(self): self.processor = None self.initialized = False - self.error_details = None def initialize(self) -> bool: - if not state.torch_available: - error_msg = "MatAnyone requires PyTorch" - logger.info(error_msg) - state.matanyone_error = error_msg + """Initialize MatAnyone processor.""" + if not TORCH_AVAILABLE: + state.matanyone_error = "MatAnyone requires PyTorch" return False try: # Ensure repository repo_path = ensure_repo("matanyone", "https://github.com/pq-yang/MatAnyone.git") if not repo_path: - state.matanyone_error = "Failed to clone MatAnyone repository" + state.matanyone_error = "Failed to clone MatAnyone" return False - # Apply K-Governor patch - apply_k_governor_patch(repo_path) - - # CRITICAL FIX 1: Actually import and activate the safe_ops patch - try: - # Prefer canonical package path - try: - import matanyone.utils.safe_ops as _safe_ops # noqa: F401 - logger.info("K-Governor safe_ops loaded (torch.topk/kthvalue patched)") - except ImportError: - # Fallback: direct module load in case structure differs - import importlib.util - safe_ops_file = repo_path / "matanyone" / "utils" / "safe_ops.py" - if safe_ops_file.exists(): - spec = importlib.util.spec_from_file_location("matanyone.utils.safe_ops", str(safe_ops_file)) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - logger.info("K-Governor safe_ops loaded via direct import") - else: - logger.warning("safe_ops.py file not found - K-Governor not activated") - except Exception as e: - logger.warning(f"Could not load safe_ops: {e}") - - # Try to import MatAnyone + # Import try: from matanyone.inference.inference_core import InferenceCore - logger.info("Imported MatAnyone from inference_core") - except ImportError: - try: - # Alternative path - sys.path.insert(0, str(repo_path / "matanyone")) - from inference.inference_core import InferenceCore - logger.info("Imported MatAnyone from alternative path") - except ImportError: - error_msg = "Could not import MatAnyone InferenceCore" - logger.error(error_msg) - state.matanyone_error = error_msg - return False - - # CRITICAL FIX 2: Consolidate ObjectInfo equality patch (single definitive patch) - try: - from matanyone.inference.object_info import ObjectInfo - - def safe_eq(self, other): - """Safe equality that handles ObjectInfo, int, and str comparisons.""" - if isinstance(other, ObjectInfo): - return self.id == other.id - if isinstance(other, (int, str)): - return self.id == other - return NotImplemented - - def safe_hash(self): - """Consistent hash based on ID for dict membership.""" - return hash(self.id) - - ObjectInfo.__eq__ = safe_eq - ObjectInfo.__hash__ = safe_hash - logger.info("ObjectInfo equality/hash patched (matanyone.inference.object_info)") - - except ImportError as e: - logger.warning(f"Primary ObjectInfo patch failed: {e}") - # Fallback for alt layout - try: - from inference.object_info import ObjectInfo as _ObjectInfoAlt - - def safe_eq_alt(self, other): - if isinstance(other, _ObjectInfoAlt): - return self.id == other.id - if isinstance(other, (int, str)): - return self.id == other - return NotImplemented - - def safe_hash_alt(self): - return hash(self.id) - - _ObjectInfoAlt.__eq__ = safe_eq_alt - _ObjectInfoAlt.__hash__ = safe_hash_alt - logger.info("ObjectInfo equality/hash patched (inference.object_info)") - - except ImportError as e2: - logger.warning(f"Fallback ObjectInfo patch failed: {e2}") - - # CRITICAL FIX 3: ObjectManager membership checks using obj.id - try: - from matanyone.inference.object_manager import ObjectManager - - original_has_all = ObjectManager.has_all - - def safe_has_all(self, objects): - """Fixed has_all that uses obj.id for membership checks.""" - if objects is None: - return True - - for obj in objects: - # Use obj.id instead of obj for membership check - obj_id = getattr(obj, 'id', obj) - if obj_id not in self.obj_to_tmp_id: - return False - return True - - ObjectManager.has_all = safe_has_all - - # Verify the patch was applied - if ObjectManager.has_all == safe_has_all: - logger.info("ObjectManager membership patch applied and verified") - else: - logger.warning("ObjectManager patch may not have taken effect") - except ImportError: - logger.warning("Could not import ObjectManager for patching") - except Exception as e: - logger.warning(f"ObjectManager patching failed: {e}") - - # CRITICAL FIX: Patch tensor concatenation in modules.py - try: - import torch.nn.functional as F - - # Import from the actual installed package, not cloned repo - patch_applied = False - - # Method 1: Try to import the specific problematic module - try: - import matanyone.model.modules as modules_mod - - # Find classes with forward methods that might have the concatenation issue - for name in dir(modules_mod): - cls = getattr(modules_mod, name) - if (hasattr(cls, '__bases__') and - hasattr(cls, 'forward') and - callable(getattr(cls, 'forward', None))): - - original_forward = cls.forward - - def make_safe_forward(orig_forward, class_name): - def safe_forward(self, *args, **kwargs): - try: - return orig_forward(self, *args, **kwargs) - except RuntimeError as e: - if "Sizes of tensors must match" in str(e): - logger.warning(f"Tensor concatenation error in {class_name}: {e}") - - # Handle the specific torch.cat error - if len(args) >= 2: - g, h = args[0], args[1] - if hasattr(g, 'shape') and hasattr(h, 'shape'): - logger.warning(f"g shape: {g.shape}, h shape: {h.shape}") - - # Try to fix by padding - if len(g.shape) >= 3 and len(h.shape) >= 3: - g_size = g.shape[2] - h_size = h.shape[2] - target_size = max(g_size, h_size) - - if g_size < target_size: - pad_size = target_size - g_size - g = F.pad(g, [0, pad_size], mode='constant', value=0) - - if h_size < target_size: - pad_size = target_size - h_size - h = F.pad(h, [0, pad_size], mode='constant', value=0) - - # Retry with padded tensors - try: - return orig_forward(self, g, h, *args[2:], **kwargs) - except Exception: - pass - - # Fallback: return first argument if available - if args: - return args[0] - return torch.zeros(1, 1, 1) - else: - raise e - - return safe_forward - - # Apply patch - cls.forward = make_safe_forward(original_forward, name) - patch_applied = True - - if patch_applied: - logger.info("Tensor concatenation safety patches applied to modules") - - except ImportError as e: - logger.warning(f"Could not import MatAnyone modules for patching: {e}") - - except Exception as e: - logger.warning(f"Tensor concatenation patching failed: {e}") + state.matanyone_error = "Failed to import MatAnyone" + return False # Try HuggingFace model first try: self.processor = InferenceCore("PeiqingYang/MatAnyone") - logger.info("MatAnyone initialized with HuggingFace model") - except Exception as e: - logger.warning(f"HuggingFace model failed: {e}") + logger.info("MatAnyone loaded from HuggingFace") + except Exception: # Fallback to local checkpoint checkpoint_path = CHECKPOINTS / "matanyone.pth" checkpoint_url = "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0/matanyone.pth" @@ -1246,737 +322,506 @@ def safe_forward(self, *args, **kwargs): state.matanyone_error = "Failed to download MatAnyone checkpoint" return False - try: - from matanyone.utils.get_default_model import get_matanyone_model - device = "cuda" if state.cuda_available else "cpu" - network = get_matanyone_model(str(checkpoint_path), device=device) - self.processor = InferenceCore(network) - logger.info("MatAnyone initialized with local checkpoint") - except Exception as e: - error_msg = f"Local checkpoint initialization failed: {e}" - logger.error(error_msg) - state.matanyone_error = error_msg - return False - - # More thorough verification with realistic test - if self._verify(): - self.initialized = True - state.matanyone_ready = True - logger.info("MatAnyone initialized and verified successfully") - return True - else: - error_msg = "MatAnyone verification failed" - logger.error(error_msg) - state.matanyone_error = error_msg - return False - - except Exception as e: - error_msg = f"MatAnyone initialization failed: {str(e)}" - logger.error(error_msg) - state.matanyone_error = error_msg - self.error_details = traceback.format_exc() - return False - - def _verify(self) -> bool: - """Thorough verification with realistic test.""" - try: - # Create larger test video with motion - frames = [] - for i in range(20): - frame = np.zeros((128, 128, 3), dtype=np.uint8) - # Moving white square - x = 20 + i * 2 - cv2.rectangle(frame, (x, 40), (x + 30, 70), (255, 255, 255), -1) - frames.append(frame) - - # Write test video - test_video = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False, dir=TEMP_DIR) - test_video.close() - - clip = ImageSequenceClip(frames, fps=5) - clip.write_videofile(test_video.name, audio=False, logger=None, verbose=False) - clip.close() + from matanyone.utils.get_default_model import get_matanyone_model + network = get_matanyone_model(str(checkpoint_path), device=DEVICE) + self.processor = InferenceCore(network) + logger.info("MatAnyone loaded from local checkpoint") - # Create test mask - mask = np.zeros((128, 128), dtype=np.uint8) - cv2.rectangle(mask, (25, 45), (85, 65), 255, -1) - - test_mask = tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir=TEMP_DIR) - test_mask.close() - cv2.imwrite(test_mask.name, mask) - - # Test directory - test_dir = tempfile.mkdtemp(dir=TEMP_DIR) - - # Run processing - result = self.processor.process_video( - input_path=test_video.name, - mask_path=test_mask.name, - output_path=test_dir, - max_size=256 - ) - - # Check result more thoroughly - success = False - alpha_path = None - - if result: - if isinstance(result, (list, tuple)) and len(result) > 1: - alpha_path = result[1] - elif isinstance(result, str): - alpha_path = result - - if not alpha_path or not os.path.exists(alpha_path): - # Search for alpha files - for pattern in ["alpha.mp4", "alpha.mkv", "alpha.mov", "alpha.webm"]: - candidate = os.path.join(test_dir, pattern) - if os.path.exists(candidate): - alpha_path = candidate - break - - if alpha_path and os.path.exists(alpha_path) and os.path.getsize(alpha_path) > 1000: - # Try to read the alpha video to ensure it's valid - try: - test_clip = VideoFileClip(alpha_path) - test_frame = test_clip.get_frame(0.1) - test_clip.close() - if test_frame is not None: - success = True - logger.info(f"MatAnyone verification successful: {alpha_path}") - except Exception as e: - logger.warning(f"Alpha video not readable: {e}") - - # Cleanup - try: - os.unlink(test_video.name) - os.unlink(test_mask.name) - shutil.rmtree(test_dir, ignore_errors=True) - except Exception: - pass - - return success + self.initialized = True + state.matanyone_ready = True + return True except Exception as e: - logger.error(f"MatAnyone verification error: {e}") + state.matanyone_error = f"MatAnyone init failed: {e}" return False # ============================================================================== -# CHAPTER 6: MAIN PROCESSING PIPELINE - FIXED +# AI BACKGROUND GENERATOR +# ============================================================================== + +def generate_ai_background( + width: int, + height: int, + prompt: str, + init_image_path: Optional[str] = None, + num_steps: int = 25, + guidance_scale: float = 7.5, + seed: Optional[int] = None, +) -> str: + """Generate AI background using Stable Diffusion.""" + if not TORCH_AVAILABLE: + raise RuntimeError("PyTorch required for AI background generation") + + try: + from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline + except ImportError: + raise RuntimeError("Please install diffusers: pip install diffusers transformers") + + # Setup generator + generator = torch.Generator(device=DEVICE) + if seed is None: + seed = random.randint(0, 2**31 - 1) + generator.manual_seed(seed) + + # Generate background + if init_image_path and os.path.exists(init_image_path): + # Image-to-image + pipe = StableDiffusionImg2ImgPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + torch_dtype=torch.float16 if CUDA_AVAILABLE else torch.float32, + safety_checker=None + ).to(DEVICE) + + init_image = Image.open(init_image_path).convert("RGB").resize((width, height)) + result = pipe( + prompt=prompt, + image=init_image, + strength=0.6, + num_inference_steps=num_steps, + guidance_scale=guidance_scale, + generator=generator + ).images[0] + else: + # Text-to-image + pipe = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + torch_dtype=torch.float16 if CUDA_AVAILABLE else torch.float32, + safety_checker=None + ).to(DEVICE) + + result = pipe( + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_steps, + guidance_scale=guidance_scale, + generator=generator + ).images[0] + + # Save result + output_path = TEMP_DIR / f"ai_bg_{int(time.time())}_{uuid4().hex[:6]}.jpg" + result.save(output_path, quality=95) + return str(output_path) + +# ============================================================================== +# MAIN PROCESSING PIPELINE # ============================================================================== def process_video_main( video_path: str, - background_image_path: Optional[str] = None, - background_preset: str = "office", - trim_seconds: float = 5.0, - sam2_config: Optional[Dict[str, Any]] = None + background_path: Optional[str] = None, + trim_duration: Optional[float] = None, # CHANGED: Optional trimming + sam2_config: Optional[Dict[str, Any]] = None, + crf: int = 18, + preserve_audio: bool = True, ) -> Tuple[Optional[str], str]: """ - FIXED Main video processing pipeline - Uses TwoStageProcessor with proper parameters. - - KEY FIXES: - - Fixed output path parameter (now uses proper .mp4 extension) - - Fixed parameter order for TwoStageProcessor.process_video() - - Fixed background handling (pass actual background image/path) + Main video processing pipeline. + FIXED: Processes full video unless trim_duration is specified. + FIXED: Preserves original audio. """ - status_messages = [] temp_files = [] + messages = [] try: # Initialize processors sam2 = SAM2Handler() matanyone = MatAnyoneHandler() - # Update SAM2 config if provided if sam2_config: sam2.update_config(sam2_config) # Initialize SAM2 - sam2_ok = sam2.initialize() - if not sam2_ok: - error_details = f"SAM2 initialization failed: {state.sam2_error}" - if sam2.error_details: - error_details += f"\n\nDetailed error:\n{sam2.error_details}" - return None, error_details + if not sam2.initialize(): + return None, f"SAM2 initialization failed: {state.sam2_error}" # Initialize MatAnyone - matanyone_ok = matanyone.initialize() - if not matanyone_ok: - error_details = f"MatAnyone initialization failed: {state.matanyone_error}" - if matanyone.error_details: - error_details += f"\n\nDetailed error:\n{matanyone.error_details}" - return None, error_details + if not matanyone.initialize(): + return None, f"MatAnyone initialization failed: {state.matanyone_error}" - status_messages.append("Both SAM2 Large and MatAnyone initialized successfully (REVERTED TO WORKING VERSION)") - status_messages.append(f"Person Detection: {'✅' if state.person_detector_ready else '❌'}") - status_messages.append(f"Pose Detection: {'✅' if state.pose_detector_ready else '❌'}") + messages.append("✅ SAM2 and MatAnyone initialized successfully") - # Trim video to first N seconds - trimmed_video = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False, dir=TEMP_DIR) - trimmed_video.close() - temp_files.append(trimmed_video.name) + # FIXED: Only trim if explicitly requested + input_video = video_path + if trim_duration and trim_duration > 0: + trimmed_path = TEMP_DIR / f"trimmed_{uuid4().hex[:6]}.mp4" + temp_files.append(str(trimmed_path)) + + with VideoFileClip(video_path) as clip: + duration = min(trim_duration, float(clip.duration or trim_duration)) + trimmed_clip = clip.subclip(0, duration) + write_video_h264(trimmed_clip, str(trimmed_path), crf=crf) + trimmed_clip.close() + + input_video = str(trimmed_path) + messages.append(f"✂️ Video trimmed to {duration:.1f}s") + else: + with VideoFileClip(video_path) as clip: + messages.append(f"🎞️ Processing full video: {clip.duration:.1f}s") - clip = VideoFileClip(video_path) - duration = min(trim_seconds, float(clip.duration or trim_seconds)) - trimmed_clip = clip.subclip(0, duration) + # Create mask from first frame + cap = cv2.VideoCapture(input_video) + ret, first_frame = cap.read() + cap.release() - write_video_optimized(trimmed_clip, trimmed_video.name) + if not ret: + return None, "Could not read video" - clip.close() - trimmed_clip.close() - status_messages.append(f"Video trimmed to {duration:.1f}s") + height, width = first_frame.shape[:2] + rgb_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB) - # FIXED: Prepare background properly for TwoStageProcessor - background_for_processor = None + mask = sam2.create_mask(rgb_frame) + if mask is None: + return None, "Failed to create mask with SAM2" + + # Save mask + mask_path = TEMP_DIR / f"mask_{uuid4().hex[:6]}.png" + temp_files.append(str(mask_path)) + cv2.imwrite(str(mask_path), mask) + messages.append("✅ Person mask created") + + # Process with MatAnyone + output_dir = TEMP_DIR / f"matanyone_output_{uuid4().hex[:6]}" + output_dir.mkdir(exist_ok=True) + temp_files.append(str(output_dir)) + + result = matanyone.processor.process_video( + input_path=input_video, + mask_path=str(mask_path), + output_path=str(output_dir) + ) - if background_image_path and os.path.exists(background_image_path): - # Use provided background image path - background_for_processor = background_image_path - status_messages.append("Using provided background image") + # Find alpha video + alpha_video = None + if isinstance(result, (list, tuple)) and len(result) > 1: + alpha_video = result[1] + elif isinstance(result, str): + alpha_video = result + + if not alpha_video or not os.path.exists(alpha_video): + # Search for alpha files + for pattern in ["alpha.mp4", "alpha.mkv", "alpha.mov"]: + candidate = output_dir / pattern + if candidate.exists(): + alpha_video = str(candidate) + break + + if not alpha_video or not os.path.exists(alpha_video): + return None, "MatAnyone failed to generate alpha video" + + messages.append("✅ Alpha video generated") + + # Composite with background + original_clip = VideoFileClip(input_video) + alpha_clip = VideoFileClip(alpha_video) + + # Load or create background + if background_path and os.path.exists(background_path): + bg_image = cv2.imread(background_path) + bg_image = cv2.resize(bg_image, (width, height)) + bg_rgb = cv2.cvtColor(bg_image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0 else: - # Create preset background and save it - cap = cv2.VideoCapture(trimmed_video.name) - ret, frame = cap.read() - cap.release() - if ret: - h, w = frame.shape[:2] - background = create_background(w, h, background_preset) - background = cv2.cvtColor(background, cv2.COLOR_RGB2BGR) - - # Save background to temp file - background_temp = tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir=TEMP_DIR) - background_temp.close() - temp_files.append(background_temp.name) - cv2.imwrite(background_temp.name, background) - background_for_processor = background_temp.name - status_messages.append(f"Created {background_preset} background") - else: - return None, "Could not read video" + # Default gradient background + bg_rgb = np.zeros((height, width, 3), dtype=np.float32) + for i in range(height): + ratio = i / (height - 1) + bg_rgb[i, :] = [0.2 + ratio * 0.3, 0.2 + ratio * 0.3, 0.3 + ratio * 0.2] - # FIXED: Try to import and use TwoStageProcessor with corrected parameters - try: - from processing.two_stage.two_stage_processor import TwoStageProcessor - logger.info("TwoStageProcessor imported successfully") - - # Create processor instance - processor = TwoStageProcessor(sam2, matanyone.processor) - status_messages.append("TwoStageProcessor created successfully") - - # CRITICAL FIX: Generate proper output path with .mp4 extension - output_path = OUT_DIR / f"two_stage_output_{int(time.time())}_{uuid4().hex[:6]}.mp4" - - # FIXED: Call process_video with correct parameter order and proper file paths - result = processor.process_video( - video_path=trimmed_video.name, # Input video - background_path=background_for_processor, # Background image path - output_path=str(output_path), # FIXED: Proper .mp4 output path - quality='medium', - trim_seconds=None, # Already trimmed above - callback=None # No callback for now - ) - - # CRITICAL FIX: Handle tuple return properly - if isinstance(result, tuple) and len(result) >= 2: - final_path = result[0] # Extract just the path - status_message = result[1] # Extract the status message - - logger.info(f"DEBUG: TwoStageProcessor returned final_path: {final_path}") - logger.info(f"DEBUG: TwoStageProcessor returned status: {status_message}") - - if final_path and os.path.exists(final_path): - logger.info(f"DEBUG: Final path exists, size: {os.path.getsize(final_path)} bytes") - status_messages.append("TwoStageProcessor completed successfully") - status_messages.append(status_message) - - # FIXED: Return the final path directly (TwoStageProcessor creates output in the right location) - return str(final_path), "\n".join(status_messages) - else: - logger.error(f"DEBUG: Final path does not exist or is None: {final_path}") - return None, f"TwoStageProcessor failed: {status_message} (final_path: {final_path})" - else: - logger.error(f"DEBUG: TwoStageProcessor returned unexpected format: {result}") - return None, f"TwoStageProcessor returned unexpected format: {result}" - - except ImportError as e: - logger.warning(f"Could not import TwoStageProcessor: {e}. Using fallback direct compositing.") - status_messages.append("Using fallback direct compositing (with SAM2 Large - WORKING VERSION)") - - # FALLBACK: Direct compositing with REVERTED SAM2 and fixed scale normalization - # Get video properties - cap = cv2.VideoCapture(trimmed_video.name) - ret, first_frame = cap.read() - cap.release() - - if not ret: - return None, "Could not read video" - - height, width = first_frame.shape[:2] - - # Create mask using REVERTED SAM2 with full-body detection - rgb_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB) - logger.info("Using reverted SAM2Handler.create_mask() in fallback mode") - mask = sam2.create_mask(rgb_frame) + def composite_frame(get_frame, t): + """Composite original video with background using alpha.""" + original_frame = get_frame(t).astype(np.float32) / 255.0 - if mask is None: - return None, "SAM2 Large failed to create mask from first frame" + # Get alpha at time t + alpha_t = min(t, alpha_clip.duration - 0.01) if alpha_clip.duration > 0 else 0 + alpha_frame = alpha_clip.get_frame(alpha_t) - mask_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir=TEMP_DIR) - mask_file.close() - temp_files.append(mask_file.name) - cv2.imwrite(mask_file.name, mask) - status_messages.append("SAM2 Large mask created successfully with original working strategies") + if alpha_frame.ndim == 3: + alpha_frame = alpha_frame[:, :, 0] # Use first channel - # Process with MatAnyone - output_dir = tempfile.mkdtemp(dir=TEMP_DIR) - matanyone_result = matanyone.processor.process_video(trimmed_video.name, mask_file.name, output_dir) + alpha_frame = alpha_frame.astype(np.float32) / 255.0 + alpha_frame = cv2.resize(alpha_frame, (width, height)) + alpha_frame = alpha_frame[:, :, np.newaxis] # Add channel dimension - # CRITICAL FIX: Handle MatAnyone returning tuple - alpha_video_path = None - if matanyone_result: - if isinstance(matanyone_result, (list, tuple)) and len(matanyone_result) > 1: - alpha_video_path = matanyone_result[1] # Second element is usually alpha video - elif isinstance(matanyone_result, str): - alpha_video_path = matanyone_result - else: - # Search for alpha files in output directory - for pattern in ["alpha.mp4", "alpha.mkv", "alpha.mov", "alpha.webm"]: - candidate = os.path.join(output_dir, pattern) - if os.path.exists(candidate): - alpha_video_path = candidate - break - - if not alpha_video_path or not os.path.exists(alpha_video_path): - error_details = "MatAnyone processing failed - no alpha video found" - if matanyone.error_details: - error_details += f"\n\nDetailed error:\n{matanyone.error_details}" - return None, error_details - - status_messages.append("MatAnyone processing successful") - - # Load background for compositing - if background_for_processor and os.path.exists(background_for_processor): - background = cv2.imread(background_for_processor) - if background is None: - logger.warning(f"Could not load background image: {background_for_processor}") - # Fallback to preset - background = create_background(width, height, background_preset) - background = cv2.cvtColor(background, cv2.COLOR_RGB2BGR) + # Composite + result = alpha_frame * original_frame + (1 - alpha_frame) * bg_rgb + return np.clip(result * 255, 0, 255).astype(np.uint8) + + # Apply compositing + final_clip = original_clip.fl(composite_frame) + + # Write final video (without audio first) + output_path = OUT_DIR / f"processed_{int(time.time())}_{uuid4().hex[:6]}.mp4" + temp_video_path = TEMP_DIR / f"temp_video_{uuid4().hex[:6]}.mp4" + temp_files.append(str(temp_video_path)) + + write_video_h264(final_clip, str(temp_video_path), crf=crf) + + original_clip.close() + alpha_clip.close() + final_clip.close() + + # FIXED: Preserve original audio + if preserve_audio: + if preserve_audio(video_path, str(temp_video_path), str(output_path)): + messages.append("🔊 Original audio preserved") else: - background = create_background(width, height, background_preset) - background = cv2.cvtColor(background, cv2.COLOR_RGB2BGR) - - # IMPROVED Composite with background (better color preservation) - original_clip = VideoFileClip(trimmed_video.name) - alpha_clip = VideoFileClip(alpha_video_path) # Now guaranteed to be a string - - def composite_frame_improved(get_frame, t): - """Improved composite frame with better color preservation.""" - try: - original_frame = get_frame(t) - - # Get alpha frame (handle duration mismatch) - alpha_duration = float(alpha_clip.duration or 0) - alpha_t = min(t, max(0, alpha_duration - 0.01)) if alpha_duration > 0 else 0 - alpha_frame = alpha_clip.get_frame(alpha_t) - - # Process alpha carefully - if alpha_frame.ndim == 3: - alpha_frame = alpha_frame[:, :, 0] - - # Normalize alpha to 0-1 range - if alpha_frame.max() > 1.1: - alpha_frame = alpha_frame / 255.0 - - alpha_frame = np.clip(alpha_frame, 0, 1) - - # Resize alpha if needed - if alpha_frame.shape != (height, width): - alpha_frame = cv2.resize(alpha_frame, (width, height)) - - # Feather the alpha for smoother edges - alpha_feathered = cv2.GaussianBlur(alpha_frame, (3, 3), 1.0) - alpha_feathered = alpha_feathered[:, :, np.newaxis] - - # Keep original frame in original color space (moviepy gives 0-1) - original_normalized = original_frame.astype(np.float32) - - # Process background more carefully - if background.dtype != np.float32: - bg_normalized = background.astype(np.float32) / 255.0 - else: - bg_normalized = background - - # Resize background to match frame - if bg_normalized.shape[:2] != (height, width): - bg_normalized = cv2.resize(bg_normalized, (width, height)) - # Convert BGR to RGB if needed - if bg_normalized.shape[2] == 3: - bg_normalized = cv2.cvtColor(bg_normalized, cv2.COLOR_BGR2RGB) - - # Composite with proper color preservation - # Use feathered alpha for smoother transitions - result = alpha_feathered * original_normalized + (1 - alpha_feathered) * bg_normalized - - # Ensure result is in correct range - result = np.clip(result, 0, 1).astype(np.float32) - - return result - - except Exception as e: - logger.error(f"Improved composite frame error: {e}") - # Return original frame on error - if 'original_frame' in locals(): - return original_frame - else: - return np.zeros((height, width, 3), dtype=np.float32) - - # Create final composite - final_clip = original_clip.fl(composite_frame_improved) - - # FIXED: Use proper output path for fallback too - fallback_output_path = OUT_DIR / f"fallback_output_{int(time.time())}_{uuid4().hex[:6]}.mp4" - write_video_optimized(final_clip, str(fallback_output_path)) - - original_clip.close() - alpha_clip.close() - final_clip.close() - - status_messages.append("Final compositing completed successfully (fallback mode with SAM2 Large - WORKING VERSION)") - return str(fallback_output_path), "\n".join(status_messages) + # Fallback: copy without audio + shutil.copy2(str(temp_video_path), str(output_path)) + messages.append("⚠️ Audio preservation failed, video saved without audio") + else: + shutil.copy2(str(temp_video_path), str(output_path)) + messages.append("🔇 Video saved without audio (as requested)") + + messages.append("✅ Processing completed successfully") + return str(output_path), "\n".join(messages) except Exception as e: - logger.error(f"Pipeline error: {e}") - error_details = f"Pipeline error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}" - return None, error_details + error_msg = f"Processing failed: {str(e)}\n\n{traceback.format_exc()}" + return None, error_msg finally: # Cleanup temporary files for temp_file in temp_files: try: - if os.path.exists(temp_file): + if os.path.isdir(temp_file): + shutil.rmtree(temp_file, ignore_errors=True) + elif os.path.exists(temp_file): os.unlink(temp_file) except Exception: pass # ============================================================================== -# CHAPTER 7: GRADIO INTERFACE +# GRADIO INTERFACE # ============================================================================== -# ---- Preset catalog (put real files in assets/backgrounds/professional/) ---- -PROFESSIONAL_PRESETS = { - # Name # Relative path (you provide these files) - "Office – Warm": "assets/backgrounds/professional/office_warm.jpg", - "Office – Cool": "assets/backgrounds/professional/office_cool.jpg", - "Studio – Neutral": "assets/backgrounds/professional/studio_neutral.jpg", - "Studio – Dark": "assets/backgrounds/professional/studio_dark.jpg", -} - -# Simple two-stop vertical gradients (BGR tuples, OpenCV order!) -GRADIENT_PRESETS = { - "Blue Fade": ((128, 64, 0), (255, 128, 0)), # dark→light blue - "Sunset": ((255, 128, 0), (255, 0, 128)), # orange→magenta - "Green Field": ((64, 128, 64), (160, 255, 160)), # dark→light green - "Slate": ((40, 40, 48), (96, 96, 112)), # dark gray→slate -} - -def _ensure_file_exists(path_str: str) -> bool: - try: - p = BASE_DIR / path_str - return p.exists() and p.is_file() - except Exception: - return False - -def _ensure_size_bgr(img: np.ndarray, width: int, height: int) -> np.ndarray: - """Guarantee image is (height,width,3) BGR uint8.""" - if img is None or getattr(img, "size", 0) == 0: - raise ValueError("Background image is empty") - if img.ndim == 2: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - if img.shape[2] != 3: - img = img[:, :, :3] - if (img.shape[1], img.shape[0]) != (width, height): - img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) - return img.astype(np.uint8) - -def _make_vertical_gradient(width: int, height: int, c1, c2) -> np.ndarray: - """ - Make a vertical BGR gradient image (uint8). - Returns shape = (height, width, 3). Broadcast-safe via explicit repeat. - """ - top = np.array(c1, dtype=np.float32) # (3,) - bot = np.array(c2, dtype=np.float32) # (3,) - rows = np.linspace(top, bot, num=max(1, height), dtype=np.float32) # (height, 3) - grad = np.repeat(rows[:, None, :], repeats=max(1, width), axis=1) # (height, width, 3) - return np.clip(grad, 0, 255).astype(np.uint8) - -def resolve_background( - video_path: str, - source: str, - upload_img_path: Optional[str], - preset_name: Optional[str], - gradient_name: Optional[str], - ai_prompt: Optional[str], -) -> str: - """ - Returns a *file path* to a background image sized to the input video. - Always writes to TEMP_DIR and returns that path. - """ - # Read one frame to get dimensions - cap = cv2.VideoCapture(video_path) - ret, frame = cap.read() - cap.release() - if not ret: - raise RuntimeError("Could not read video to determine frame size.") - h, w = frame.shape[:2] - - out_path = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False, dir=TEMP_DIR).name - - if source == "Upload Image": - if not upload_img_path or not os.path.exists(upload_img_path): - raise RuntimeError("No uploaded background image provided.") - img = cv2.imread(upload_img_path) - if img is None: - raise RuntimeError("Failed to read uploaded background image.") - img = _ensure_size_bgr(img, w, h) - cv2.imwrite(out_path, img) - return out_path - - if source == "Professional Presets": - if not preset_name or preset_name not in PROFESSIONAL_PRESETS: - raise RuntimeError("Please select a professional preset.") - rel = PROFESSIONAL_PRESETS[preset_name] - path = str((BASE_DIR / rel).resolve()) - if not _ensure_file_exists(rel): - # Fallback gradient if the asset isn't present - logger.warning("Preset file not found: %s — using fallback gradient.", rel) - grad = _make_vertical_gradient(w, h, (96, 96, 112), (24, 24, 28)) - cv2.imwrite(out_path, grad) - return out_path - img = cv2.imread(path) - if img is None: - raise RuntimeError(f"Failed to load preset image: {preset_name}") - img = _ensure_size_bgr(img, w, h) - cv2.imwrite(out_path, img) - return out_path - - if source == "Gradients": - if not gradient_name or gradient_name not in GRADIENT_PRESETS: - raise RuntimeError("Please select a gradient preset.") - c1, c2 = GRADIENT_PRESETS[gradient_name] - grad = _make_vertical_gradient(w, h, c1, c2) - grad = _ensure_size_bgr(grad, w, h) - cv2.imwrite(out_path, grad) - return out_path - - if source == "AI Background": - # Placeholder: generate a pleasant neutral gradient for now. - # If you wire up a diffusion pipeline later, render to (w,h) here and save to out_path. - logger.info("AI Background requested: %s (using placeholder gradient for now)", ai_prompt or "(no prompt)") - grad = _make_vertical_gradient(w, h, (80, 72, 96), (180, 180, 210)) - grad = _ensure_size_bgr(grad, w, h) - cv2.imwrite(out_path, grad) - return out_path - - # Fallback (shouldn't happen) - raise RuntimeError(f"Unknown background source: {source}") - def create_interface(): - """Create Gradio interface with richer background options and original SAM2 controls.""" + """Create the Gradio interface.""" + + def run_diagnostics(): + return state.get_status() + + def generate_background(video_file, prompt, init_image, steps, guidance, seed_val): + """Generate AI background.""" + if not video_file: + return None, "Please upload a video first" + + try: + # Get video dimensions + video_path = video_file.name if hasattr(video_file, 'name') else str(video_file) + cap = cv2.VideoCapture(video_path) + ret, frame = cap.read() + cap.release() + + if not ret: + return None, "Could not read video" + + height, width = frame.shape[:2] + init_path = init_image if isinstance(init_image, str) else ( + init_image.name if init_image and hasattr(init_image, 'name') else None + ) + + bg_path = generate_ai_background( + width=width, + height=height, + prompt=prompt or "professional office background", + init_image_path=init_path, + num_steps=int(steps), + guidance_scale=float(guidance), + seed=int(seed_val) if seed_val else None + ) + + return bg_path, "✅ AI background generated successfully" + + except Exception as e: + return None, f"❌ Background generation failed: {e}" + def process_video( video_file, background_source, - upload_img, - preset_choice, - gradient_choice, - ai_prompt, + uploaded_bg, + ai_generated_bg, + trim_enabled, + trim_seconds, + crf_value, + audio_enabled, + # SAM2 settings use_person_detection, use_pose_estimation, confidence_threshold, refine_iterations, ): - if video_file is None: - return None, None, "Please upload a video file." - - start_time = time.time() - video_path = video_file.name if hasattr(video_file, "name") else str(video_file) - upload_path = upload_img if isinstance(upload_img, str) else (upload_img.name if upload_img and hasattr(upload_img, "name") else None) - - # Build SAM2 config from UI + """Process the video.""" + if not video_file: + return None, None, "Please upload a video file" + + video_path = video_file.name if hasattr(video_file, 'name') else str(video_file) + + # Determine background + bg_path = None + if background_source == "Upload Image" and uploaded_bg: + bg_path = uploaded_bg if isinstance(uploaded_bg, str) else uploaded_bg.name + elif background_source == "AI Generated" and ai_generated_bg: + bg_path = ai_generated_bg + + # SAM2 configuration sam2_config = { - "use_person_detection": bool(use_person_detection), - "use_pose_estimation": bool(use_pose_estimation), - "confidence_threshold": float(confidence_threshold), - "refine_iterations": int(refine_iterations), + "use_person_detection": use_person_detection, + "use_pose_estimation": use_pose_estimation, + "confidence_threshold": confidence_threshold, + "refine_iterations": refine_iterations, } - - # Resolve background to a real file path sized to the video - try: - bg_path = resolve_background( - video_path=video_path, - source=background_source, - upload_img_path=upload_path, - preset_name=preset_choice, - gradient_name=gradient_choice, - ai_prompt=ai_prompt, - ) - except Exception as e: - logger.error("Background resolution error: %s", e) - return None, None, f"❌ Background error: {e}" - - # Call your main pipeline (unchanged) + + # Process video result_path, status = process_video_main( video_path=video_path, - background_image_path=bg_path, # pass the actual background image path - background_preset="office", # not used when background_image_path is provided - trim_seconds=5.0, # keep your current behavior + background_path=bg_path, + trim_duration=float(trim_seconds) if (trim_enabled and trim_seconds > 0) else None, sam2_config=sam2_config, + crf=int(crf_value), + preserve_audio=audio_enabled, ) - - elapsed = time.time() - start_time + if result_path and os.path.exists(result_path): - final_status = ( - f"✅ Processing completed successfully!\n" - f"Time: {elapsed:.1f}s\n" - f"Output: {result_path}\n\n" - f"SAM2 Large Configuration (working profile):\n" - f"- Person Detection: {'✅' if use_person_detection else '❌'}\n" - f"- Pose Estimation: {'✅' if use_pose_estimation else '❌'}\n" - f"- Confidence Threshold: {confidence_threshold:.2f}\n" - f"- Refinement Iterations: {refine_iterations}\n\n" - f"{status}" - ) - return result_path, result_path, final_status + return result_path, result_path, f"✅ Success!\n\n{status}" else: - return None, None, f"❌ Processing failed\n\n{status}" - - def run_diagnostics(): - return state.get_detailed_status() - + return None, None, f"❌ Failed!\n\n{status}" + + # Create interface with gr.Blocks(title="Video Background Replacement", theme=gr.themes.Soft()) as interface: + gr.Markdown("# 🎬 Video Background Replacement") - gr.Markdown("**SAM2 Large – Working config with richer background options**") - - # System status panel - status_html = f""" -
-

System Status

- Device: {state.device}
- PyTorch: {'✅' if state.torch_available else '❌'}
- CUDA: {'✅' if state.cuda_available else '❌'}
- Models initialize at first processing run. + gr.Markdown("**FIXED VERSION**: Processes full video length + preserves audio + real AI backgrounds") + + # System status + gr.HTML(f""" +
+ Device: {DEVICE}    + PyTorch: {'✅' if TORCH_AVAILABLE else '❌'}    + CUDA: {'✅' if CUDA_AVAILABLE else '❌'}
- """ - gr.HTML(status_html) - + """) + with gr.Row(): - with gr.Column(): - # Inputs - video_input = gr.Video(label="Input Video (first 5s processed)") + with gr.Column(scale=1): + # Input + video_input = gr.Video(label="Input Video") + + # Background options gr.Markdown("### 🖼️ Background") background_source = gr.Dropdown( - label="Background Source", - choices=["Professional Presets", "Upload Image", "Gradients", "AI Background"], - value="Professional Presets", - ) - upload_img = gr.Image(label="Upload Background Image", type="filepath", visible=False) - preset_choice = gr.Dropdown( - label="Professional Preset", - choices=list(PROFESSIONAL_PRESETS.keys()), - value=(list(PROFESSIONAL_PRESETS.keys())[0] if PROFESSIONAL_PRESETS else None), - visible=True, + choices=["Default Gradient", "Upload Image", "AI Generated"], + value="Default Gradient", + label="Background Source" ) - gradient_choice = gr.Dropdown( - label="Gradient Preset", - choices=list(GRADIENT_PRESETS.keys()), - value="Blue Fade", - visible=False, + + uploaded_bg = gr.Image( + label="Upload Background Image", + type="filepath", + visible=False ) + + # AI Background Generation + gr.Markdown("### 🤖 AI Background Generator") ai_prompt = gr.Textbox( label="AI Background Prompt", - placeholder="e.g., modern glass office with warm sunlight and soft bokeh", - visible=False, + placeholder="e.g., modern office with plants and natural lighting", + value="professional office background with soft lighting" ) - - # SAM2 options (your working config) - gr.Markdown("### 🎯 SAM2 Segmentation Settings (Working Profile)") + ai_init_image = gr.Image(label="Initial Image (optional)", type="filepath") + with gr.Row(): - use_person_detection = gr.Checkbox(label="Use Person Detection (YOLOv8)", value=True) - use_pose_estimation = gr.Checkbox(label="Use Pose Estimation (MediaPipe)", value=True) + ai_steps = gr.Slider(10, 50, value=25, step=1, label="Steps") + ai_guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.1, label="Guidance") + + ai_seed = gr.Number(label="Seed (optional)", precision=0) + + generate_bg_btn = gr.Button("Generate AI Background", variant="primary") + ai_generated_bg = gr.Image(label="Generated Background") + ai_bg_status = gr.Textbox(label="Generation Status", lines=2) + + # Processing options + gr.Markdown("### ⚙️ Processing Options") + with gr.Row(): - confidence_threshold = gr.Slider( - minimum=0.1, maximum=1.0, value=0.15, step=0.05, - label="Quality Threshold", - ) - refine_iterations = gr.Slider( - minimum=1, maximum=5, value=5, step=1, - label="Refinement Iterations", - ) - - process_button = gr.Button("Process Video", variant="primary") - gr.Markdown("### 🔧 System Diagnostics") - diagnostics_button = gr.Button("Run Diagnostics") - diagnostics_output = gr.Textbox(label="System Diagnostics", lines=15) - - with gr.Column(): + trim_enabled = gr.Checkbox(label="Trim Video", value=False) + trim_seconds = gr.Number(label="Trim to (seconds)", value=5, precision=1) + + with gr.Row(): + crf_value = gr.Slider(0, 30, value=18, step=1, label="Quality (CRF - lower = better)") + audio_enabled = gr.Checkbox(label="Preserve Audio", value=True) + + # SAM2 Settings + gr.Markdown("### 🎯 SAM2 Settings") + with gr.Row(): + use_person_detection = gr.Checkbox(label="Person Detection", value=True) + use_pose_estimation = gr.Checkbox(label="Pose Estimation", value=True) + + with gr.Row(): + confidence_threshold = gr.Slider(0.1, 1.0, value=0.15, step=0.05, label="Confidence") + refine_iterations = gr.Slider(1, 10, value=5, step=1, label="Refinement") + + # Buttons + process_btn = gr.Button("🚀 Process Video", variant="primary", size="lg") + + gr.Markdown("### 🔧 Diagnostics") + diagnostics_btn = gr.Button("Run System Diagnostics") + diagnostics_output = gr.Textbox(label="System Status", lines=10) + + with gr.Column(scale=1): + # Output output_video = gr.Video(label="Processed Video") - download_file = gr.File(label="Download") - status_output = gr.Textbox(label="Processing Status", lines=15) - - # Show/hide background inputs based on source - def _toggle_bg_fields(src): - return ( - gr.update(visible=(src == "Upload Image")), # upload_img - gr.update(visible=(src == "Professional Presets")), # preset_choice - gr.update(visible=(src == "Gradients")), # gradient_choice - gr.update(visible=(src == "AI Background")), # ai_prompt - ) - + download_file = gr.File(label="Download Processed Video") + status_output = gr.Textbox(label="Processing Status", lines=20) + + # Event handlers + def toggle_background_inputs(source): + return gr.update(visible=(source == "Upload Image")) + background_source.change( - _toggle_bg_fields, + toggle_background_inputs, inputs=[background_source], + outputs=[uploaded_bg] ) - - background_source.change( - _toggle_bg_fields, - inputs=[background_source], - outputs=[upload_img, preset_choice, gradient_choice, ai_prompt], + + generate_bg_btn.click( + generate_background, + inputs=[video_input, ai_prompt, ai_init_image, ai_steps, ai_guidance, ai_seed], + outputs=[ai_generated_bg, ai_bg_status] ) - - # Hook the processing - process_button.click( - fn=process_video, + + process_btn.click( + process_video, inputs=[ - video_input, - background_source, upload_img, preset_choice, gradient_choice, ai_prompt, - use_person_detection, use_pose_estimation, confidence_threshold, refine_iterations, + video_input, background_source, uploaded_bg, ai_generated_bg, + trim_enabled, trim_seconds, crf_value, audio_enabled, + use_person_detection, use_pose_estimation, confidence_threshold, refine_iterations ], - outputs=[output_video, download_file, status_output], - concurrency_limit=1, + outputs=[output_video, download_file, status_output] ) - - diagnostics_button.click( - fn=run_diagnostics, - outputs=[diagnostics_output], - concurrency_limit=1, + + diagnostics_btn.click( + run_diagnostics, + outputs=[diagnostics_output] ) - + gr.Markdown(""" - **Background Notes** - - *Professional Presets* expect files under `assets/backgrounds/professional/`. - - *Gradients* render on-the-fly (fastest) and are auto-sized to your video to avoid broadcasting issues. - - *AI Background* is currently a placeholder gradient — wire a diffusion - pipeline here later to render images from prompts, then save and return the path. + ### 📝 Notes + - **Full Length Processing**: By default, processes the entire video (no 5-second limit) + - **Audio Preservation**: Original audio is automatically preserved unless disabled + - **AI Backgrounds**: Generate custom backgrounds using Stable Diffusion + - **Quality Control**: Adjust CRF for quality vs file size (18 = high quality) + - **Optional Trimming**: Enable trimming only if you want to process part of the video """) - + return interface # ============================================================================== -# CHAPTER 8: MAIN ENTRY POINT +# MAIN # ============================================================================== def main(): - logger.info("Starting Video Background Replacement - SAM2 LARGE (REVERTED TO WORKING VERSION)") + logger.info("Starting Fixed Video Background Replacement") interface = create_interface() interface.launch( server_name="0.0.0.0", @@ -1987,4 +832,4 @@ def main(): ) if __name__ == "__main__": - main() + main() \ No newline at end of file