| | |
| | """ |
| | Enhanced utilities.py - Core computer vision functions with auto-best quality |
| | VERSION: 2.0-auto-best |
| | ROLLBACK: Set USE_ENHANCED_SEGMENTATION = False to revert to original behavior |
| | """ |
| |
|
| | import os |
| | import cv2 |
| | import numpy as np |
| | import torch |
| | from PIL import Image, ImageDraw |
| | import logging |
| | import time |
| | from typing import Optional, Dict, Any, Tuple, List |
| | from pathlib import Path |
| |
|
| | |
| | |
| | |
| |
|
| | |
# Feature flags.
# USE_ENHANCED_SEGMENTATION = False makes segment_person_hq() delegate to
# segment_person_hq_original(); USE_INTELLIGENT_PROMPTING selects the
# automatic prompt generator over the fixed prompt layout; and
# USE_ITERATIVE_REFINEMENT enables the corrective-prompt refinement pass.
# USE_AUTO_TEMPORAL_CONSISTENCY is not read anywhere in this section.
USE_ENHANCED_SEGMENTATION: bool = True
USE_AUTO_TEMPORAL_CONSISTENCY: bool = True
USE_INTELLIGENT_PROMPTING: bool = True
USE_ITERATIVE_REFINEMENT: bool = True
| |
|
| | |
# Module-level logger, plus root logging configured at INFO level
# (basicConfig is a no-op if the root logger already has handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
| |
|
| | |
# Preset background definitions, keyed by preset id. Each entry has a display
# "name", a render "type" ("gradient" or "color"), a list of hex "colors"
# written as 6-digit "#rrggbb", an optional gradient "direction", a
# "description", and "brightness" / "contrast" values.
PROFESSIONAL_BACKGROUNDS = {
    "office_modern": {
        "name": "Modern Office",
        "type": "gradient",
        "colors": ["#f8f9fa", "#e9ecef", "#dee2e6"],
        "direction": "diagonal",
        "description": "Clean, contemporary office environment",
        "brightness": 0.95,
        "contrast": 1.1,
    },
    "studio_blue": {
        "name": "Professional Blue",
        "type": "gradient",
        "colors": ["#1e3c72", "#2a5298", "#3498db"],
        "direction": "radial",
        "description": "Broadcast-quality blue studio",
        "brightness": 0.9,
        "contrast": 1.2,
    },
    "studio_green": {
        "name": "Broadcast Green",
        "type": "color",
        "colors": ["#00b894"],
        "chroma_key": True,
        "description": "Professional green screen replacement",
        "brightness": 1.0,
        "contrast": 1.0,
    },
    "minimalist": {
        "name": "Minimalist White",
        "type": "gradient",
        # Was "#ddd": the hex parsers in this module slice [0:2], [2:4], [4:6],
        # so 3-digit shorthand made int("", 16) raise and broke this preset.
        "colors": ["#ffffff", "#f1f2f6", "#dddddd"],
        "direction": "soft_radial",
        "description": "Clean, minimal background",
        "brightness": 0.98,
        "contrast": 0.9,
    },
    "warm_gradient": {
        "name": "Warm Sunset",
        "type": "gradient",
        "colors": ["#ff7675", "#fd79a8", "#fdcb6e"],
        "direction": "diagonal",
        "description": "Warm, inviting atmosphere",
        "brightness": 0.85,
        "contrast": 1.15,
    },
    "tech_dark": {
        "name": "Tech Dark",
        "type": "gradient",
        "colors": ["#0c0c0c", "#2d3748", "#4a5568"],
        "direction": "vertical",
        "description": "Modern tech/gaming setup",
        "brightness": 0.7,
        "contrast": 1.3,
    },
    "corporate_blue": {
        "name": "Corporate Blue",
        "type": "gradient",
        "colors": ["#667eea", "#764ba2", "#f093fb"],
        "direction": "diagonal",
        "description": "Professional corporate background",
        "brightness": 0.88,
        "contrast": 1.1,
    },
    "nature_blur": {
        "name": "Soft Nature",
        "type": "gradient",
        "colors": ["#a8edea", "#fed6e3", "#d299c2"],
        "direction": "radial",
        "description": "Soft blurred nature effect",
        "brightness": 0.92,
        "contrast": 0.95,
    },
}
| |
|
| | |
class SegmentationError(Exception):
    """Raised when person segmentation fails (bad input, missing predictor, or unusable mask)."""
| |
|
class MaskRefinementError(Exception):
    """Raised when mask refinement fails and no fallback is permitted."""
| |
|
class BackgroundReplacementError(Exception):
    """Raised when compositing a frame onto a new background fails."""
| |
|
| | |
| | |
| | |
| |
|
def segment_person_hq(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """
    ENHANCED VERSION 2.0: High-quality person segmentation with intelligent automation.

    ROLLBACK: Set USE_ENHANCED_SEGMENTATION = False to revert to original behavior.

    Args:
        image: Input image (H, W, 3).
        predictor: SAM2 predictor instance (uses ``set_image`` and ``predict``).
        fallback_enabled: Whether to use fallback segmentation if AI fails.

    Returns:
        Binary mask (H, W) with values 0 or 255.

    Raises:
        SegmentationError: if ``image`` is None/empty, or if segmentation fails
            while ``fallback_enabled`` is False.
    """
    if not USE_ENHANCED_SEGMENTATION:
        return segment_person_hq_original(image, predictor, fallback_enabled)

    logger.debug("Using ENHANCED segmentation with intelligent automation")

    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    try:
        if predictor is None:
            if fallback_enabled:
                logger.warning("SAM2 predictor not available, using fallback")
                return _fallback_segmentation(image)
            raise SegmentationError("SAM2 predictor not available")

        try:
            predictor.set_image(image)
        except Exception as e:
            logger.error(f"Failed to set image in predictor: {e}")
            if fallback_enabled:
                return _fallback_segmentation(image)
            raise SegmentationError(f"Predictor setup failed: {e}") from e

        # The prompt helpers raise SegmentationError when SAM2 returns no
        # masks. Without this handler that error escaped via the
        # ``except SegmentationError: raise`` clause below even when
        # fallback_enabled=True (the original version falls back here).
        try:
            if USE_INTELLIGENT_PROMPTING:
                mask = _segment_with_intelligent_prompts(image, predictor)
            else:
                mask = _segment_with_basic_prompts(image, predictor)
        except SegmentationError as e:
            logger.warning(f"SAM2 prediction produced no usable mask: {e}")
            if fallback_enabled:
                return _fallback_segmentation(image)
            raise

        if USE_ITERATIVE_REFINEMENT and mask is not None:
            mask = _auto_refine_mask_iteratively(image, mask, predictor)

        if mask is None or not _validate_mask_quality(mask, image.shape[:2]):
            logger.warning("Mask quality validation failed")
            if fallback_enabled:
                return _fallback_segmentation(image)
            raise SegmentationError("Poor mask quality")

        logger.debug(f"Enhanced segmentation successful - mask range: {mask.min()}-{mask.max()}")
        return mask

    except SegmentationError:
        # Only reached when fallback is disabled; propagate as-is.
        raise
    except Exception as e:
        logger.error(f"Unexpected segmentation error: {e}")
        if fallback_enabled:
            return _fallback_segmentation(image)
        raise SegmentationError(f"Unexpected error: {e}") from e
| |
|
def _segment_with_intelligent_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    """
    Segment with automatically generated positive/negative point prompts.

    Points come from _generate_smart_prompts(); if it yields no positive
    points, the image centre is used instead. SAM2 is asked for multiple
    candidate masks and the highest-scoring one (or the first, when no scores
    are returned) is normalised with _process_mask().

    Raises:
        SegmentationError: if the predictor returns no masks.
        Exception: any other failure is logged and re-raised unchanged.
    """
    try:
        height, width = image.shape[:2]
        positives, negatives = _generate_smart_prompts(image)

        if not len(positives):
            positives = np.array([[width // 2, height // 2]], dtype=np.float32)

        coords = np.concatenate((positives, negatives), axis=0)
        point_labels = np.concatenate((
            np.ones(len(positives), dtype=np.int32),
            np.zeros(len(negatives), dtype=np.int32),
        ))

        logger.debug(f"Using {len(positives)} positive, {len(negatives)} negative points")

        with torch.no_grad():
            candidates, candidate_scores, _ = predictor.predict(
                point_coords=coords,
                point_labels=point_labels,
                multimask_output=True,
            )

        if candidates is None or len(candidates) == 0:
            raise SegmentationError("No masks generated")

        if candidate_scores is None or len(candidate_scores) == 0:
            chosen = candidates[0]
        else:
            top = np.argmax(candidate_scores)
            chosen = candidates[top]
            logger.debug(f"Selected mask {top} with score {candidate_scores[top]:.3f}")

        return _process_mask(chosen)

    except Exception as e:
        logger.error(f"Intelligent prompting failed: {e}")
        raise
| |
|
def _generate_smart_prompts(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate positive/negative SAM2 point prompts for ``image``.

    Positive points are the centroids of up to three of the largest regions of
    a spectral-residual saliency map thresholded at 0.7. The saliency path can
    fail in several ways:
      * ``cv2.saliency`` is unavailable;
      * ``computeSaliency`` reports failure;
      * no usable centroid is produced.
    In any of these cases, three points on the vertical centre line (at 1/3,
    1/2 and 2/3 of the height) are used instead.

    Negative points sit near the four corners and at the top/bottom centre.
    They are clamped so they stay inside very small frames.

    Returns:
        (positive_points, negative_points), each an (N, 2) float32 array of
        (x, y) pixel coordinates.
    """
    try:
        h, w = image.shape[:2]

        try:
            saliency = cv2.saliency.StaticSaliencySpectralResidual_create()
            success, saliency_map = saliency.computeSaliency(image)
            if not success:
                # Previously this path left positive_points unbound, which
                # surfaced as an UnboundLocalError in the outer handler.
                raise RuntimeError("computeSaliency reported failure")

            saliency_thresh = cv2.threshold(saliency_map, 0.7, 1, cv2.THRESH_BINARY)[1]
            contours, _ = cv2.findContours(
                (saliency_thresh * 255).astype(np.uint8),
                cv2.RETR_EXTERNAL,
                cv2.CHAIN_APPROX_SIMPLE,
            )

            candidates = []
            for contour in sorted(contours, key=cv2.contourArea, reverse=True)[:3]:
                moments = cv2.moments(contour)
                if moments["m00"] == 0:
                    continue
                cx = int(moments["m10"] / moments["m00"])
                cy = int(moments["m01"] / moments["m00"])
                if 0 < cx < w and 0 < cy < h:
                    candidates.append([cx, cy])

            if not candidates:
                raise RuntimeError("No valid saliency points found")

            logger.debug(f"Generated {len(candidates)} saliency-based points")
            positive_points = np.array(candidates, dtype=np.float32)

        except Exception as e:
            logger.debug(f"Saliency method failed: {e}, using fallback")
            positive_points = np.array([
                [w // 2, h // 3],
                [w // 2, h // 2],
                [w // 2, 2 * h // 3],
            ], dtype=np.float32)

        # NOTE(review): the bottom-centre point can land on the torso when the
        # subject extends to the bottom edge of the frame - confirm intended.
        negative_points = np.array([
            [10, 10],
            [w - 10, 10],
            [10, h - 10],
            [w - 10, h - 10],
            [w // 2, 5],
            [w // 2, h - 5],
        ], dtype=np.float32)
        # Only affects frames smaller than ~20 px; larger frames are unchanged.
        negative_points[:, 0] = np.clip(negative_points[:, 0], 0, max(w - 1, 0))
        negative_points[:, 1] = np.clip(negative_points[:, 1], 0, max(h - 1, 0))

        return positive_points, negative_points

    except Exception as e:
        logger.warning(f"Smart prompt generation failed: {e}")
        h, w = image.shape[:2]
        positive_points = np.array([[w // 2, h // 2]], dtype=np.float32)
        negative_points = np.array([[10, 10], [w - 10, 10]], dtype=np.float32)
        return positive_points, negative_points
| |
|
def _auto_refine_mask_iteratively(image: np.ndarray, initial_mask: np.ndarray,
                                  predictor: Any, max_iterations: int = 2) -> np.ndarray:
    """
    Iteratively refine a SAM2 mask with corrective point prompts.

    Each round works as follows:
      1. Score the current mask with _assess_mask_quality(); stop once the
         score exceeds 0.85.
      2. Locate suspicious regions with _find_mask_errors() and derive
         corrective points from them.
      3. Re-run the predictor with those points plus the current mask as a
         dense prompt.
      4. Keep the new mask only if it scores strictly higher.

    Args:
        image: The image previously passed to ``predictor.set_image``.
        initial_mask: (H, W) uint8 mask with values 0/255.
        predictor: SAM2 image predictor.
        max_iterations: Maximum number of refinement rounds.

    Returns:
        The best mask found. ``initial_mask`` is returned if refinement raises
        unexpectedly.
    """

    def _as_mask_prompt(mask: np.ndarray) -> np.ndarray:
        # SAM/SAM2 ``mask_input`` is a 1xHxW array of low-resolution *logits*
        # (H = W = 256 for SAM), normally taken from a previous prediction.
        # The full-size 0/255 mask passed before did not match that contract.
        # Approximate the logits: foreground -> +10, background -> -10.
        low_res = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_AREA)
        logits = (low_res.astype(np.float32) / 255.0 * 2.0 - 1.0) * 10.0
        return logits[None, :, :]

    try:
        current_mask = initial_mask.copy()

        for iteration in range(max_iterations):
            quality_score = _assess_mask_quality(current_mask, image)
            logger.debug(f"Iteration {iteration}: quality score = {quality_score:.3f}")

            if quality_score > 0.85:
                logger.debug(f"Quality sufficient after {iteration} iterations")
                break

            problem_areas = _find_mask_errors(current_mask, image)
            if not np.any(problem_areas):
                logger.debug("No problem areas detected")
                break

            corrective_points, corrective_labels = _generate_corrective_prompts(
                image, current_mask, problem_areas
            )
            if len(corrective_points) == 0:
                # Inputs would be identical next round, so stop instead of repeating.
                logger.debug("No corrective prompts generated")
                break

            try:
                with torch.no_grad():
                    masks, scores, _ = predictor.predict(
                        point_coords=corrective_points,
                        point_labels=corrective_labels,
                        mask_input=_as_mask_prompt(current_mask),
                        multimask_output=False,
                    )

                if masks is None or len(masks) == 0:
                    logger.debug(f"Refinement iteration {iteration} returned no masks")
                    break

                refined_mask = _process_mask(masks[0])

                # Accept only a strict improvement so refinement never degrades the mask.
                if _assess_mask_quality(refined_mask, image) > quality_score:
                    current_mask = refined_mask
                    logger.debug(f"Improved mask in iteration {iteration}")
                else:
                    logger.debug(f"Refinement didn't improve quality in iteration {iteration}")
                    break

            except Exception as e:
                logger.debug(f"Refinement iteration {iteration} failed: {e}")
                break

        return current_mask

    except Exception as e:
        logger.warning(f"Iterative refinement failed: {e}")
        return initial_mask
| |
|
def _assess_mask_quality(mask: np.ndarray, image: np.ndarray) -> float:
    """
    Heuristically score a person mask in [0, 1]; higher is better.

    The score is a weighted average of four sub-scores (weights 0.3/0.2/0.3/0.2):
      * area: 1.0 when foreground (values > 127) covers 5-80 % of the frame,
        falling linearly to 0 outside that band.
      * centre: 1 minus the larger of the foreground centroid's normalised
        horizontal and vertical offsets from the frame centre (0 if empty).
      * smoothness: penalises the density of Canny boundary pixels.
      * connectivity: 1.0 for one foreground component, minus 0.2 per extra
        component, clamped to [0, 1].

    Returns 0.5 if scoring raises.
    """
    try:
        h, w = image.shape[:2]
        total_area = h * w

        # Score the binarised mask so every term agrees on what is foreground
        # (values 1..127 used to count for Canny/connectivity but not for area).
        binary = (mask > 127).astype(np.uint8) * 255
        foreground = binary > 0

        area_ratio = np.count_nonzero(foreground) / total_area
        if 0.05 <= area_ratio <= 0.8:
            area_score = 1.0
        elif area_ratio < 0.05:
            area_score = area_ratio / 0.05
        else:
            area_score = max(0.0, 1.0 - (area_ratio - 0.8) / 0.2)

        if np.any(foreground):
            ys, xs = np.nonzero(foreground)
            offset_x = abs(np.mean(xs) / w - 0.5)
            offset_y = abs(np.mean(ys) / h - 0.5)
            # Use the worse axis: min() gave a mask hugging the top edge (but
            # centred horizontally) a perfect 1.0.
            center_score = 1.0 - max(offset_x, offset_y)
        else:
            center_score = 0.0

        edges = cv2.Canny(binary, 50, 150)
        edge_density = np.count_nonzero(edges) / total_area
        smoothness_score = max(0.0, 1.0 - edge_density * 10)

        # num_labels counts the background label too; clamp so an empty mask
        # (num_labels == 1) cannot score 1.2.
        num_labels, _ = cv2.connectedComponents(binary)
        connectivity_score = min(1.0, max(0.0, 1.0 - (num_labels - 2) * 0.2))

        scores = [area_score, center_score, smoothness_score, connectivity_score]
        return float(np.average(scores, weights=[0.3, 0.2, 0.3, 0.2]))

    except Exception as e:
        logger.warning(f"Quality assessment failed: {e}")
        return 0.5
| |
|
def _find_mask_errors(mask: np.ndarray, image: np.ndarray,
                      band_width: Optional[int] = 15) -> np.ndarray:
    """
    Return a boolean (H, W) map of pixels where the mask likely disagrees with the image.

    Canny edges (thresholds 50/150) of the grayscale image are XOR-ed with
    Canny edges of the binarised mask and dilated by a 5x5 ellipse.

    When ``band_width`` is a positive int, the map is restricted to a ring of
    roughly ``band_width`` px on each side of the mask boundary. Without it,
    texture anywhere in the frame (clothing patterns, background clutter) is
    reported. _generate_corrective_prompts() then turns such regions outside
    the mask into *positive* prompts. Pass ``band_width=None`` (or 0) to get
    the unrestricted map.

    Returns an all-False map if detection fails.
    """
    try:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
        binary = (mask > 127).astype(np.uint8) * 255

        image_edges = cv2.Canny(gray, 50, 150)
        mask_edges = cv2.Canny(binary, 50, 150)
        discrepancy = cv2.bitwise_xor(image_edges, mask_edges)

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        errors = cv2.dilate(discrepancy, kernel, iterations=1) > 0

        if band_width:
            size = 2 * int(band_width) + 1
            band_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
            # Morphological gradient (dilate - erode) = ring straddling the boundary.
            boundary_band = cv2.morphologyEx(binary, cv2.MORPH_GRADIENT, band_kernel) > 0
            errors &= boundary_band

        return errors

    except Exception as e:
        logger.warning(f"Error detection failed: {e}")
        return np.zeros(np.shape(mask)[:2], dtype=bool)
| |
|
def _generate_corrective_prompts(image: np.ndarray, mask: np.ndarray,
                                 problem_areas: np.ndarray,
                                 max_points: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
    """
    Turn flagged problem regions into corrective SAM2 point prompts.

    Only problem regions with contour area > 100 px are used. For each one, a
    point is placed at the contour centroid and labelled by the current mask
    value there:
      * mask < 127 (outside) -> 1, i.e. "add this";
      * otherwise (inside) -> 0, i.e. "remove this".
    This assumes a flagged region means the mask is wrong at that point
    (heuristic).

    Args:
        image: Unused; kept for interface compatibility.
        mask: (H, W) uint8 mask with values 0/255.
        problem_areas: Boolean (H, W) map from _find_mask_errors().
        max_points: If given, keep only the largest ``max_points`` regions.

    Returns:
        (points, labels): an (N, 2) float32 array of (x, y) coordinates and an
        (N,) int32 array of labels. Both are empty (with the same dtypes) when
        nothing qualifies or on error.
    """
    empty_points = np.empty((0, 2), dtype=np.float32)
    empty_labels = np.empty((0,), dtype=np.int32)

    try:
        contours, _ = cv2.findContours(problem_areas.astype(np.uint8),
                                       cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        regions = [c for c in contours if cv2.contourArea(c) > 100]
        if max_points is not None:
            regions = sorted(regions, key=cv2.contourArea, reverse=True)[:max(max_points, 0)]

        points, labels = [], []
        for contour in regions:
            moments = cv2.moments(contour)
            if moments["m00"] == 0:
                continue
            cx = int(moments["m10"] / moments["m00"])
            cy = int(moments["m01"] / moments["m00"])

            points.append([cx, cy])
            labels.append(1 if mask[cy, cx] < 127 else 0)

        if not points:
            return empty_points, empty_labels
        return np.array(points, dtype=np.float32), np.array(labels, dtype=np.int32)

    except Exception as e:
        logger.warning(f"Corrective prompt generation failed: {e}")
        return empty_points, empty_labels
| |
|
def _segment_with_basic_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    """
    Segment using a fixed prompt layout (used when intelligent prompting is off).

    Prompts:
      * positive: three points on the vertical centre line, at 1/3, 1/2 and
        2/3 of the height;
      * negative: four points at 10 % / 90 % of the width and height.
    The highest-scoring of SAM2's candidate masks (or the first, when no
    scores are returned) is normalised with _process_mask().

    Raises:
        SegmentationError: if the predictor returns no masks.
    """
    height, width = image.shape[:2]
    mid_x = width // 2

    foreground = [[mid_x, height // 3], [mid_x, height // 2], [mid_x, 2 * height // 3]]
    background = [
        [width // 10, height // 10],
        [9 * width // 10, height // 10],
        [width // 10, 9 * height // 10],
        [9 * width // 10, 9 * height // 10],
    ]

    coords = np.array(foreground + background, dtype=np.float32)
    point_labels = np.array([1] * len(foreground) + [0] * len(background), dtype=np.int32)

    with torch.no_grad():
        candidates, candidate_scores, _ = predictor.predict(
            point_coords=coords,
            point_labels=point_labels,
            multimask_output=True,
        )

    if candidates is None or len(candidates) == 0:
        raise SegmentationError("No masks generated")

    have_scores = candidate_scores is not None and len(candidate_scores) > 0
    chosen = candidates[np.argmax(candidate_scores) if have_scores else 0]
    return _process_mask(chosen)
| |
|
| | |
| | |
| | |
| |
|
def segment_person_hq_original(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """
    ORIGINAL VERSION: Preserved for rollback capability.

    Prompts SAM2 with eight fixed positive points spread over the central
    region of the frame, keeps the highest-scoring candidate mask, normalises
    it with _process_mask() and checks it with _validate_mask_quality().

    Args:
        image: Input image (H, W, 3).
        predictor: SAM2 predictor instance, or None.
        fallback_enabled: On any failure, return _fallback_segmentation(image)
            instead of raising.

    Returns:
        (H, W) uint8 mask.

    Raises:
        SegmentationError: for an empty/None image, or on failure when
            ``fallback_enabled`` is False.
    """
    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    def _bail_out(reason: str) -> np.ndarray:
        # Degrade to the model-free fallback, or surface the failure.
        if fallback_enabled:
            return _fallback_segmentation(image)
        raise SegmentationError(reason)

    try:
        if predictor is None:
            if not fallback_enabled:
                raise SegmentationError("SAM2 predictor not available")
            logger.warning("SAM2 predictor not available, using fallback")
            return _fallback_segmentation(image)

        try:
            predictor.set_image(image)
        except Exception as e:
            logger.error(f"Failed to set image in predictor: {e}")
            return _bail_out(f"Predictor setup failed: {e}")

        h, w = image.shape[:2]
        prompt_xy = [
            (w // 2, h // 4),
            (w // 2, h // 2),
            (w // 2, 3 * h // 4),
            (w // 3, h // 2),
            (2 * w // 3, h // 2),
            (w // 2, h // 6),
            (w // 4, 2 * h // 3),
            (3 * w // 4, 2 * h // 3),
        ]
        coords = np.asarray(prompt_xy, dtype=np.float32)
        point_labels = np.ones(coords.shape[0], dtype=np.int32)

        try:
            with torch.no_grad():
                candidates, candidate_scores, _ = predictor.predict(
                    point_coords=coords,
                    point_labels=point_labels,
                    multimask_output=True,
                )
        except Exception as e:
            logger.error(f"SAM2 prediction failed: {e}")
            return _bail_out(f"Prediction failed: {e}")

        if candidates is None or len(candidates) == 0:
            logger.warning("SAM2 returned no masks")
            return _bail_out("No masks generated")

        if candidate_scores is not None and len(candidate_scores) > 0:
            top = np.argmax(candidate_scores)
            chosen = candidates[top]
            logger.debug(f"Selected mask {top} with score {candidate_scores[top]:.3f}")
        else:
            logger.warning("SAM2 returned no scores")
            chosen = candidates[0]

        mask = _process_mask(chosen)

        if not _validate_mask_quality(mask, image.shape[:2]):
            logger.warning("Mask quality validation failed")
            return _bail_out("Poor mask quality")

        logger.debug(f"Segmentation successful - mask range: {mask.min()}-{mask.max()}")
        return mask

    except SegmentationError:
        raise
    except Exception as e:
        logger.error(f"Unexpected segmentation error: {e}")
        return _bail_out(f"Unexpected error: {e}")
| |
|
| | |
| | |
| | |
| |
|
def _process_mask(mask: np.ndarray) -> np.ndarray:
    """
    Normalise a raw model mask to a clean (H, W) uint8 binary mask (0/255).

    Steps:
      1. Accept NumPy arrays or torch tensors (moved to CPU).
      2. Squeeze singleton dimensions; if a channel axis remains, take channel 0.
      3. Map values to 0-255:
         * booleans -> 0/255;
         * any numeric array whose maximum is <= 1 is treated as [0, 1] and
           scaled by 255 (negatives clip to 0);
         * other non-uint8 arrays are clipped to [0, 255].
      4. Apply a 3x3 morphological close then open, and threshold at 127.

    On failure, logs an error and returns a centred rectangle covering the
    middle half of the frame in each dimension.
    """
    try:
        # Model wrappers may hand back torch tensors (possibly on GPU).
        if torch.is_tensor(mask):
            mask = mask.detach().cpu().numpy()
        mask = np.asarray(mask)

        if mask.ndim > 2:
            mask = mask.squeeze()
        if mask.ndim > 2:
            mask = mask[:, :, 0] if mask.shape[2] > 0 else mask.sum(axis=2)

        if mask.dtype == bool:
            mask = mask.astype(np.uint8) * 255
        elif mask.max() <= 1:
            # Before, only float32/float64 were scaled here, so uint8/int/float16
            # 0-1 masks were wiped out by the 127 threshold below. Clipping keeps
            # negative values (e.g. logits) from wrapping around in uint8.
            mask = np.clip(mask.astype(np.float64) * 255.0, 0, 255).astype(np.uint8)
        elif mask.dtype != np.uint8:
            mask = np.clip(mask, 0, 255).astype(np.uint8)

        # Close fills pinholes, open removes isolated specks.
        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

        _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
        return mask

    except Exception as e:
        logger.error(f"Mask processing failed: {e}")
        shape = tuple(getattr(mask, "shape", ()))
        h, w = shape[:2] if len(shape) >= 2 else (256, 256)
        fallback = np.zeros((h, w), dtype=np.uint8)
        fallback[h // 4:3 * h // 4, w // 4:3 * w // 4] = 255
        return fallback
| |
|
| | def _validate_mask_quality(mask: np.ndarray, image_shape: Tuple[int, int]) -> bool: |
| | """Validate that the mask meets quality criteria""" |
| | try: |
| | h, w = image_shape |
| | mask_area = np.sum(mask > 127) |
| | total_area = h * w |
| | |
| | |
| | area_ratio = mask_area / total_area |
| | if area_ratio < 0.05 or area_ratio > 0.8: |
| | logger.warning(f"Suspicious mask area ratio: {area_ratio:.3f}") |
| | return False |
| | |
| | |
| | mask_binary = mask > 127 |
| | mask_center_y, mask_center_x = np.where(mask_binary) |
| | |
| | if len(mask_center_y) == 0: |
| | logger.warning("Empty mask") |
| | return False |
| | |
| | center_y = np.mean(mask_center_y) |
| | center_x = np.mean(mask_center_x) |
| | |
| | |
| | if center_y < h * 0.2 or center_y > h * 0.9: |
| | logger.warning(f"Mask center too far from expected person location: y={center_y/h:.2f}") |
| | return False |
| | |
| | return True |
| | |
| | except Exception as e: |
| | logger.warning(f"Mask validation error: {e}") |
| | return True |
| |
|
def _fallback_segmentation(image: np.ndarray) -> np.ndarray:
    """
    Produce a person mask without any AI model.

    Strategies, tried in order:
      1. Border-difference mask:
         * use the median grey level of the outermost pixel ring as the
           background level;
         * mark pixels differing from it by more than 30;
         * clean up with a 7x7 elliptical close then open;
         * return it if _validate_mask_quality() accepts it.
      2. A centred ellipse with semi-axes w // 3 and h // 2.5.
      3. If everything above raises: the rectangle spanning rows h/6..5h/6 and
         columns w/4..3w/4.

    Returns:
        (H, W) uint8 mask with values 0/255.
    """
    try:
        logger.info("Using fallback segmentation strategy")
        rows, cols = image.shape[:2]

        try:
            grey = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            border = np.concatenate((grey[0, :], grey[-1, :], grey[:, 0], grey[:, -1]))
            background_level = np.median(border)

            candidate = (np.abs(grey.astype(float) - background_level) > 30).astype(np.uint8) * 255

            ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
            for operation in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
                candidate = cv2.morphologyEx(candidate, operation, ellipse)

            if _validate_mask_quality(candidate, image.shape[:2]):
                logger.info("Background subtraction fallback successful")
                return candidate

        except Exception as e:
            logger.warning(f"Background subtraction fallback failed: {e}")

        yy, xx = np.ogrid[:rows, :cols]
        cx, cy = cols // 2, rows // 2
        rx, ry = cols // 3, rows // 2.5
        inside = ((xx - cx) / rx) ** 2 + ((yy - cy) / ry) ** 2 <= 1

        result = np.zeros((rows, cols), dtype=np.uint8)
        result[inside] = 255

        logger.info("Using geometric fallback mask")
        return result

    except Exception as e:
        logger.error(f"All fallback strategies failed: {e}")
        rows, cols = image.shape[:2]
        result = np.zeros((rows, cols), dtype=np.uint8)
        result[rows // 6:5 * rows // 6, cols // 4:3 * cols // 4] = 255
        return result
| |
|
| | |
| | |
| | |
| |
|
def refine_mask_hq(image: np.ndarray, mask: np.ndarray, matanyone_processor: Any,
                   fallback_enabled: bool = True) -> np.ndarray:
    """
    Refine a person mask, preferring MatAnyone and falling back to OpenCV.

    The mask is first normalised with _process_mask(). If a MatAnyone
    processor is supplied, its output is used when it is not None and passes
    _validate_mask_quality(). Otherwise enhance_mask_opencv_advanced() is used,
    but only if ``fallback_enabled``.
    UNCHANGED behaviour, for rollback safety.

    Raises:
        MaskRefinementError: on None inputs, or when MatAnyone does not
            deliver and ``fallback_enabled`` is False.
    """
    if image is None or mask is None:
        raise MaskRefinementError("Invalid input image or mask")

    try:
        mask = _process_mask(mask)

        if matanyone_processor is not None:
            try:
                logger.debug("Attempting MatAnyone refinement")
                candidate = _matanyone_refine(image, mask, matanyone_processor)
                accepted = candidate is not None and _validate_mask_quality(candidate, image.shape[:2])
                if accepted:
                    logger.debug("MatAnyone refinement successful")
                    return candidate
                logger.warning("MatAnyone produced poor quality mask")
            except Exception as e:
                logger.warning(f"MatAnyone refinement failed: {e}")

        if not fallback_enabled:
            raise MaskRefinementError("MatAnyone failed and fallback disabled")

        logger.debug("Using enhanced OpenCV refinement")
        return enhance_mask_opencv_advanced(image, mask)

    except MaskRefinementError:
        raise
    except Exception as e:
        logger.error(f"Unexpected mask refinement error: {e}")
        if not fallback_enabled:
            raise MaskRefinementError(f"Unexpected error: {e}")
        return enhance_mask_opencv_advanced(image, mask)
| |
|
| | def _matanyone_refine(image: np.ndarray, mask: np.ndarray, processor: Any) -> Optional[np.ndarray]: |
| | """Attempt MatAnyone mask refinement - Python 3.10 compatible""" |
| | try: |
| | |
| | if hasattr(processor, 'infer'): |
| | refined_mask = processor.infer(image, mask) |
| | elif hasattr(processor, 'process'): |
| | refined_mask = processor.process(image, mask) |
| | elif callable(processor): |
| | refined_mask = processor(image, mask) |
| | else: |
| | logger.warning("Unknown MatAnyone interface") |
| | return None |
| | |
| | if refined_mask is None: |
| | return None |
| | |
| | |
| | refined_mask = _process_mask(refined_mask) |
| | |
| | logger.debug("MatAnyone refinement successful") |
| | return refined_mask |
| | |
| | except Exception as e: |
| | logger.warning(f"MatAnyone processing error: {e}") |
| | return None |
| |
|
def enhance_mask_opencv_advanced(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """
    Refine a mask with classical OpenCV filtering.

    Pipeline:
      1. bilateral filter (d=9, sigmaColor=75, sigmaSpace=75);
      2. _guided_filter_approx() with ``image`` as guide (radius 8, eps 0.2);
      3. 5x5 elliptical close;
      4. 3x3 elliptical open;
      5. 3x3 Gaussian blur (sigma 0.8);
      6. threshold at 127.

    Args:
        image: Guide image with the same (H, W) as ``mask``.
        mask: Mask of shape (H, W), (H, W, 1) or (H, W, 3), of any numeric or
            bool dtype. Values are normalised as follows:
            * bool -> 0/255;
            * maximum <= 1 -> scaled by 255;
            * otherwise clipped to [0, 255];
            then the mask is converted to uint8.

    Returns:
        (H, W) uint8 binary mask (0/255). If the filtering pipeline raises, a
        5x5 Gaussian blur of the normalised mask is returned instead.
    """
    # Normalise up front. Int32/int64 masks (e.g. np.where(m, 255, 0)) used to
    # crash bilateralFilter and then the GaussianBlur fallback as well.
    mask = np.asarray(mask)
    if mask.dtype == bool:
        mask = mask.astype(np.uint8) * 255
    elif mask.max() <= 1.0:
        mask = np.clip(mask.astype(np.float32) * 255.0, 0, 255).astype(np.uint8)
    elif mask.dtype != np.uint8:
        mask = np.clip(mask, 0, 255).astype(np.uint8)

    if mask.ndim == 3:
        mask = mask[:, :, 0] if mask.shape[2] == 1 else cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

    try:
        # Edge-preserving smoothing of the mask itself.
        refined_mask = cv2.bilateralFilter(mask, 9, 75, 75)

        # Snap mask transitions toward image edges.
        refined_mask = _guided_filter_approx(image, refined_mask, radius=8, eps=0.2)

        kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        refined_mask = cv2.morphologyEx(refined_mask, cv2.MORPH_CLOSE, kernel_close)

        kernel_open = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        refined_mask = cv2.morphologyEx(refined_mask, cv2.MORPH_OPEN, kernel_open)

        refined_mask = cv2.GaussianBlur(refined_mask, (3, 3), 0.8)

        _, refined_mask = cv2.threshold(refined_mask, 127, 255, cv2.THRESH_BINARY)
        return refined_mask

    except Exception as e:
        logger.warning(f"Enhanced OpenCV refinement failed: {e}")
        return cv2.GaussianBlur(mask, (5, 5), 1.0)
| |
|
def _guided_filter_approx(guide: np.ndarray, mask: np.ndarray, radius: int = 8, eps: float = 0.2) -> np.ndarray:
    """
    Grey-scale guided filter used for edge-aware smoothing of ``mask``.

    How it works:
      * ``guide`` is converted to grey if it is 3-D, then both guide and mask
        are divided by 255.
      * A per-window linear model ``mask ~ a * guide + b`` is fitted with box
        filters of size (2 * radius + 1).
      * ``a`` and ``b`` are averaged over the window and applied to the guide.
    ``eps`` regularises ``a``; larger values give smoother, less edge-aware
    output.

    Returns:
        uint8 array clipped to 0-255. On any error the input ``mask`` is
        returned unchanged.
    """
    try:
        window = (2 * radius + 1, 2 * radius + 1)

        def box(arr: np.ndarray) -> np.ndarray:
            return cv2.boxFilter(arr, -1, window)

        grey = guide if len(guide.shape) != 3 else cv2.cvtColor(guide, cv2.COLOR_BGR2GRAY)
        g = grey.astype(np.float32) / 255.0
        p = mask.astype(np.float32) / 255.0

        mu_g = box(g)
        mu_p = box(p)
        var_g = box(g * g) - mu_g * mu_g
        cov_gp = box(g * p) - mu_g * mu_p

        coef_a = cov_gp / (var_g + eps)
        coef_b = mu_p - coef_a * mu_g

        smoothed = box(coef_a) * g + box(coef_b)
        return np.clip(smoothed * 255, 0, 255).astype(np.uint8)

    except Exception as e:
        logger.warning(f"Guided filter approximation failed: {e}")
        return mask
| |
|
def replace_background_hq(frame: np.ndarray, mask: np.ndarray, background: np.ndarray,
                          fallback_enabled: bool = True) -> np.ndarray:
    """
    Composite ``frame`` over ``background`` using ``mask`` as the person matte.

    Input handling:
      * ``background`` is resized (Lanczos) to the frame size.
      * ``mask`` may have shape (H, W), (H, W, 1) or (H, W, 3), and may be
        bool, integer or floating point.
      * Mask values with maximum <= 1 are treated as alpha and scaled to
        0-255; larger values are clipped to 0-255. The mask is then cast to
        uint8.
    _advanced_compositing() is tried first. _simple_compositing() is the
    fallback.

    Raises:
        BackgroundReplacementError: on None inputs, or when compositing fails
            and ``fallback_enabled`` is False.
    """
    if frame is None or mask is None or background is None:
        raise BackgroundReplacementError("Invalid input frame, mask, or background")

    try:
        background = cv2.resize(background, (frame.shape[1], frame.shape[0]),
                                interpolation=cv2.INTER_LANCZOS4)

        # Scale BEFORE casting. The old order (astype(uint8), then "*255 if
        # max <= 1") truncated soft float alpha in [0, 1) to 0.
        mask = np.asarray(mask)
        if mask.dtype == bool:
            mask = mask.astype(np.uint8)
        if mask.max() <= 1.0:
            logger.debug("Converting normalized mask to 0-255 range")
            mask = np.clip(mask.astype(np.float32) * 255.0, 0, 255).astype(np.uint8)
        elif mask.dtype != np.uint8:
            mask = np.clip(mask, 0, 255).astype(np.uint8)

        if mask.ndim == 3:
            mask = mask[:, :, 0] if mask.shape[2] == 1 else cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

        try:
            result = _advanced_compositing(frame, mask, background)
            logger.debug("Advanced compositing successful")
            return result
        except Exception as e:
            logger.warning(f"Advanced compositing failed: {e}")
            if fallback_enabled:
                return _simple_compositing(frame, mask, background)
            raise BackgroundReplacementError(f"Advanced compositing failed: {e}") from e

    except BackgroundReplacementError:
        raise
    except Exception as e:
        logger.error(f"Unexpected background replacement error: {e}")
        if fallback_enabled:
            return _simple_compositing(frame, mask, background)
        raise BackgroundReplacementError(f"Unexpected error: {e}") from e
| |
|
def _advanced_compositing(frame: np.ndarray, mask: np.ndarray, background: np.ndarray) -> np.ndarray:
    """
    Alpha-blend ``frame`` over ``background`` with a softly feathered matte.

    Matte construction:
      1. threshold ``mask`` at 100;
      2. 5x5 elliptical close then open;
      3. 5x5 Gaussian blur (sigma 1), scaled to [0, 1];
      4. gamma 0.8;
      5. push values above 0.5 up by 10 % (capped at 1) and the rest down
         by 10 %.
    Before blending, edge colours of ``frame`` are nudged toward
    ``background`` by _color_match_edges(). ``frame`` and ``background`` must
    have the same shape.

    Raises:
        Exception: any failure is logged and re-raised for the caller's fallback.
    """
    try:
        _, hard = cv2.threshold(mask, 100, 255, cv2.THRESH_BINARY)

        ellipse5 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        hard = cv2.morphologyEx(hard, cv2.MORPH_CLOSE, ellipse5)
        hard = cv2.morphologyEx(hard, cv2.MORPH_OPEN, ellipse5)

        alpha = cv2.GaussianBlur(hard.astype(np.float32), (5, 5), 1.0) / 255.0
        alpha = np.power(alpha, 0.8)
        alpha = np.where(alpha > 0.5, np.minimum(alpha * 1.1, 1.0), alpha * 0.9)

        adjusted = _color_match_edges(frame, background, alpha)
        alpha3 = np.repeat(alpha[:, :, np.newaxis], 3, axis=2)

        fg = adjusted.astype(np.float32)
        bg = background.astype(np.float32)
        blended = fg * alpha3 + bg * (1 - alpha3)
        return np.clip(blended, 0, 255).astype(np.uint8)

    except Exception as e:
        logger.error(f"Advanced compositing error: {e}")
        raise
| |
|
def _color_match_edges(frame: np.ndarray, background: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    """
    Blend 10 % of the background colour into ``frame`` along the matte edge.

    Edge pixels are those where the Sobel gradient magnitude of ``alpha``
    exceeds 0.1. Pre-tinting them toward the background reduces colour halos
    when ``frame`` is later composited over ``background``.

    Args:
        frame: (H, W, 3) image.
        background: Image with the same shape as ``frame``.
        alpha: (H, W) float matte, expected in [0, 1].

    Returns:
        uint8 image. ``frame`` is returned unchanged if no edge is found or if
        anything raises.
    """
    try:
        # Gradient magnitude from separate first derivatives. The previous
        # Sobel(dx=1, dy=1) computed the mixed second derivative d2/dxdy,
        # which is ~0 along straight horizontal or vertical edges.
        grad_x = cv2.Sobel(alpha, cv2.CV_64F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(alpha, cv2.CV_64F, 0, 1, ksize=3)
        edge_areas = np.hypot(grad_x, grad_y) > 0.1

        if not np.any(edge_areas):
            return frame

        adjustment_strength = 0.1
        frame_adjusted = frame.astype(np.float32)
        blended = (frame_adjusted * (1 - adjustment_strength)
                   + background.astype(np.float32) * adjustment_strength)
        frame_adjusted[edge_areas] = blended[edge_areas]

        return np.clip(frame_adjusted, 0, 255).astype(np.uint8)

    except Exception as e:
        logger.warning(f"Color matching failed: {e}")
        return frame
| |
|
def _simple_compositing(frame: np.ndarray, mask: np.ndarray, background: np.ndarray) -> np.ndarray:
    """
    Hard-edged fallback composite.

    Pixels where ``mask`` > 127 come from ``frame``; all others come from
    ``background`` (resized to the frame size). 3-channel masks are converted
    to grey, and masks with maximum <= 1 are scaled to 0-255 first.

    Returns:
        uint8 composite. ``frame`` is returned unchanged if anything fails.
    """
    try:
        logger.info("Using simple compositing fallback")

        height, width = frame.shape[0], frame.shape[1]
        bg = cv2.resize(background, (width, height))

        if len(mask.shape) == 3:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
        if mask.max() <= 1.0:
            mask = (mask * 255).astype(np.uint8)

        hard = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
        keep = np.repeat((hard.astype(np.float32) / 255.0)[:, :, np.newaxis], 3, axis=2)

        return (frame * keep + bg * (1 - keep)).astype(np.uint8)

    except Exception as e:
        logger.error(f"Simple compositing failed: {e}")
        return frame
| |
|
def create_professional_background(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
    """Render a ``height`` x ``width`` background image from a preset config.

    ``bg_config["type"]`` picks the renderer: ``"color"`` for a solid fill,
    ``"gradient"`` for a gradient, anything else for a plain mid-gray canvas.
    The config's brightness/contrast adjustments are applied afterwards.

    Returns:
        The rendered background, or a plain (128, 128, 128) canvas when any
        step raises (including a missing ``"type"`` key).
    """
    fallback_gray = (128, 128, 128)
    try:
        bg_type = bg_config["type"]
        if bg_type == "color":
            canvas = _create_solid_background(bg_config, width, height)
        elif bg_type == "gradient":
            canvas = _create_gradient_background_enhanced(bg_config, width, height)
        else:
            canvas = np.full((height, width, 3), fallback_gray, dtype=np.uint8)
        return _apply_background_adjustments(canvas, bg_config)
    except Exception as e:
        logger.error(f"Background creation error: {e}")
        return np.full((height, width, 3), fallback_gray, dtype=np.uint8)
| |
|
| | def _create_solid_background(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray: |
| | """Create solid color background""" |
| | color_hex = bg_config["colors"][0].lstrip('#') |
| | color_rgb = tuple(int(color_hex[i:i+2], 16) for i in (0, 2, 4)) |
| | color_bgr = color_rgb[::-1] |
| | return np.full((height, width, 3), color_bgr, dtype=np.uint8) |
| |
|
def _create_gradient_background_enhanced(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
    """Create a multi-stop gradient background from a config entry.

    Args:
        bg_config: Reads ``colors`` (hex strings, ``#rrggbb`` or shorthand
            ``#rgb``) and optional ``direction``: ``"vertical"`` (default),
            ``"horizontal"``, ``"diagonal"``, ``"radial"`` or ``"soft_radial"``.
            Unknown directions fall back to vertical.
        width: Output width in pixels.
        height: Output height in pixels.

    Returns:
        The gradient converted from RGB to BGR, or a (128, 128, 128) canvas if
        anything raises.
    """
    try:
        colors = bg_config["colors"]
        direction = bg_config.get("direction", "vertical")

        rgb_colors = []
        for color_hex in colors:
            color_hex = color_hex.lstrip('#')
            if len(color_hex) == 3:
                # Expand CSS shorthand ("ddd" -> "dddddd"). Otherwise the blue
                # slice is empty, int() raises, and the whole palette (e.g. one
                # containing "#ddd") falls back to flat gray.
                color_hex = "".join(ch * 2 for ch in color_hex)
            rgb_colors.append(tuple(int(color_hex[i:i + 2], 16) for i in (0, 2, 4)))

        if not rgb_colors:
            rgb_colors = [(128, 128, 128)]

        if direction == "vertical":
            background = _create_vertical_gradient(rgb_colors, width, height)
        elif direction == "horizontal":
            background = _create_horizontal_gradient(rgb_colors, width, height)
        elif direction == "diagonal":
            background = _create_diagonal_gradient(rgb_colors, width, height)
        elif direction in ["radial", "soft_radial"]:
            background = _create_radial_gradient(rgb_colors, width, height, direction == "soft_radial")
        else:
            background = _create_vertical_gradient(rgb_colors, width, height)

        # The gradient helpers work in RGB; the rest of the pipeline uses BGR.
        return cv2.cvtColor(background, cv2.COLOR_RGB2BGR)

    except Exception as e:
        logger.error(f"Gradient creation error: {e}")
        return np.full((height, width, 3), (128, 128, 128), dtype=np.uint8)
| |
|
def _create_vertical_gradient(colors: list, width: int, height: int) -> np.ndarray:
    """Create a top-to-bottom gradient through ``colors``.

    Row ``y`` gets the palette colour at progress ``y / height`` (so the last
    row stops just short of the final colour), broadcast across the row.
    The palette is evaluated once per channel with
    ``_vectorized_color_interpolation`` instead of a Python call per row.

    Args:
        colors: Non-empty list of RGB tuples.
        width: Output width in pixels.
        height: Output height in pixels.

    Returns:
        A (height, width, 3) uint8 array in the same channel order as ``colors``.
    """
    gradient = np.zeros((height, width, 3), dtype=np.uint8)
    if height <= 0:
        return gradient

    progress = np.arange(height) / height  # one value per row, in [0, 1)
    for c in range(3):
        # (height,) -> (height, 1) so each row's value broadcasts across columns.
        gradient[:, :, c] = _vectorized_color_interpolation(colors, progress, c)[:, np.newaxis]

    return gradient
| |
|
def _create_horizontal_gradient(colors: list, width: int, height: int) -> np.ndarray:
    """Create a left-to-right gradient through ``colors``.

    Column ``x`` gets the palette colour at progress ``x / width`` (so the last
    column stops just short of the final colour), broadcast down the column.
    The palette is evaluated once per channel with
    ``_vectorized_color_interpolation`` instead of a Python call per column.

    Args:
        colors: Non-empty list of RGB tuples.
        width: Output width in pixels.
        height: Output height in pixels.

    Returns:
        A (height, width, 3) uint8 array in the same channel order as ``colors``.
    """
    gradient = np.zeros((height, width, 3), dtype=np.uint8)
    if width <= 0:
        return gradient

    progress = np.arange(width) / width  # one value per column, in [0, 1)
    for c in range(3):
        # (width,) -> (1, width) so each column's value broadcasts down the rows.
        gradient[:, :, c] = _vectorized_color_interpolation(colors, progress, c)[np.newaxis, :]

    return gradient
| |
|
def _create_diagonal_gradient(colors: list, width: int, height: int) -> np.ndarray:
    """Create a gradient running from the top-left corner toward the bottom-right.

    A pixel's progress is ``(x + y) / (width + height)``, clipped to [0, 1].
    Each channel is then interpolated across the palette in one vectorized call.

    Returns:
        A (height, width, 3) uint8 array in the same channel order as ``colors``.
    """
    rows = np.arange(height)[:, np.newaxis]
    cols = np.arange(width)[np.newaxis, :]
    # Broadcasting rows + cols gives the same (height, width) grid as mgrid.
    progress = np.clip((rows + cols) / (width + height), 0, 1)

    channels = [_vectorized_color_interpolation(colors, progress, c) for c in range(3)]
    return np.stack(channels, axis=-1)
| |
|
def _create_radial_gradient(colors: list, width: int, height: int, soft: bool = False) -> np.ndarray:
    """Create a radial gradient centred on the image, with the first colour at the centre.

    Progress is the distance from the centre pixel ``(width // 2, height // 2)``
    divided by that centre's distance to the (0, 0) corner, clipped to [0, 1].
    With ``soft=True`` the progress is reshaped as ``progress ** 0.7``.

    Args:
        colors: Non-empty list of RGB tuples.
        width: Output width in pixels.
        height: Output height in pixels.
        soft: Apply the 0.7 power curve to the falloff.

    Returns:
        A (height, width, 3) uint8 array in the same channel order as ``colors``.
    """
    center_x, center_y = width // 2, height // 2
    max_distance = np.sqrt(center_x**2 + center_y**2)

    if max_distance > 0:
        y_coords, x_coords = np.mgrid[0:height, 0:width]
        distances = np.sqrt((x_coords - center_x)**2 + (y_coords - center_y)**2)
        progress = np.clip(distances / max_distance, 0, 1)
    else:
        # Only possible for canvases of at most 1x1 (centre at (0, 0)):
        # dividing would give 0/0 = NaN and feed NaN into the interpolation.
        progress = np.zeros((height, width), dtype=np.float64)

    if soft:
        progress = np.power(progress, 0.7)

    gradient = np.zeros((height, width, 3), dtype=np.uint8)
    for c in range(3):
        gradient[:, :, c] = _vectorized_color_interpolation(colors, progress, c)

    return gradient
| |
|
| | def _vectorized_color_interpolation(colors: list, progress: np.ndarray, channel: int) -> np.ndarray: |
| | """Vectorized color interpolation for performance""" |
| | if len(colors) == 1: |
| | return np.full_like(progress, colors[0][channel], dtype=np.uint8) |
| | |
| | num_segments = len(colors) - 1 |
| | segment_progress = progress * num_segments |
| | segment_indices = np.floor(segment_progress).astype(int) |
| | segment_indices = np.clip(segment_indices, 0, num_segments - 1) |
| | local_progress = segment_progress - segment_indices |
| | |
| | |
| | start_colors = np.array([colors[i][channel] for i in range(len(colors))]) |
| | end_colors = np.array([colors[min(i + 1, len(colors) - 1)][channel] for i in range(len(colors))]) |
| | |
| | start_vals = start_colors[segment_indices] |
| | end_vals = end_colors[segment_indices] |
| | |
| | result = start_vals + (end_vals - start_vals) * local_progress |
| | return np.clip(result, 0, 255).astype(np.uint8) |
| |
|
| | def _interpolate_color(colors: list, progress: float) -> tuple: |
| | """Interpolate between multiple colors""" |
| | if len(colors) == 1: |
| | return colors[0] |
| | elif len(colors) == 2: |
| | r = int(colors[0][0] + (colors[1][0] - colors[0][0]) * progress) |
| | g = int(colors[0][1] + (colors[1][1] - colors[0][1]) * progress) |
| | b = int(colors[0][2] + (colors[1][2] - colors[0][2]) * progress) |
| | return (r, g, b) |
| | else: |
| | segment = progress * (len(colors) - 1) |
| | idx = int(segment) |
| | local_progress = segment - idx |
| | if idx >= len(colors) - 1: |
| | return colors[-1] |
| | c1, c2 = colors[idx], colors[idx + 1] |
| | r = int(c1[0] + (c2[0] - c1[0]) * local_progress) |
| | g = int(c1[1] + (c2[1] - c1[1]) * local_progress) |
| | b = int(c1[2] + (c2[2] - c1[2]) * local_progress) |
| | return (r, g, b) |
| |
|
| | def _apply_background_adjustments(background: np.ndarray, bg_config: Dict[str, Any]) -> np.ndarray: |
| | """Apply brightness and contrast adjustments to background""" |
| | try: |
| | brightness = bg_config.get("brightness", 1.0) |
| | contrast = bg_config.get("contrast", 1.0) |
| | |
| | if brightness != 1.0 or contrast != 1.0: |
| | background = background.astype(np.float32) |
| | background = background * contrast * brightness |
| | background = np.clip(background, 0, 255).astype(np.uint8) |
| | |
| | return background |
| | |
| | except Exception as e: |
| | logger.warning(f"Background adjustment failed: {e}") |
| | return background |
| |
|
def validate_video_file(video_path: str) -> Tuple[bool, str]:
    """Check that ``video_path`` is a readable video within the supported limits.

    Limits enforced:
    - the file is non-empty and at most 2 GB;
    - it opens with OpenCV and reports a positive frame count;
    - the frame rate is in (0, 120];
    - the resolution is at most 4096x4096;
    - the duration is at most 300 s.

    Args:
        video_path: Path to the video file.

    Returns:
        ``(is_valid, message)``. On success the message summarises resolution,
        fps and duration; otherwise it explains the first failed check.
    """
    # isfile (not exists) so a directory is reported as missing instead of
    # being handed to VideoCapture.
    if not video_path or not os.path.isfile(video_path):
        return False, "Video file not found"

    try:
        file_size = os.path.getsize(video_path)
        if file_size == 0:
            return False, "Video file is empty"

        if file_size > 2 * 1024 * 1024 * 1024:
            return False, "Video file too large (>2GB)"

        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                return False, "Cannot open video file"

            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        finally:
            # Release on every path: early return, success, or an exception.
            cap.release()

        # Reject negative counts too: they would otherwise produce a negative
        # duration that slips under the 300 s limit below.
        if frame_count <= 0:
            return False, "Video appears to be empty (0 frames)"

        # Chained comparison also rejects NaN (every comparison with NaN is False).
        if not 0 < fps <= 120:
            return False, f"Invalid frame rate: {fps}"

        if width <= 0 or height <= 0:
            return False, f"Invalid resolution: {width}x{height}"

        if width > 4096 or height > 4096:
            return False, f"Resolution too high: {width}x{height} (max 4096x4096)"

        duration = frame_count / fps
        if duration > 300:
            return False, f"Video too long: {duration:.1f}s (max 300s)"

        return True, f"Valid video: {width}x{height}, {fps:.1f}fps, {duration:.1f}s"

    except Exception as e:
        return False, f"Error validating video: {str(e)}"