| |
| """ |
| Enhanced utilities.py - Core computer vision functions with auto-best quality |
| VERSION: 2.0-auto-best |
| ROLLBACK: Set USE_ENHANCED_SEGMENTATION = False to revert to original behavior |
| """ |
|
|
| import os |
| import cv2 |
| import numpy as np |
| import torch |
| from PIL import Image, ImageDraw |
| import logging |
| import time |
| from typing import Optional, Dict, Any, Tuple, List |
| from pathlib import Path |
|
|
| |
| |
| |
|
|
| |
# Feature flags for the v2.0 pipeline. Setting USE_ENHANCED_SEGMENTATION to
# False makes segment_person_hq() delegate to segment_person_hq_original().
USE_ENHANCED_SEGMENTATION: bool = True
USE_AUTO_TEMPORAL_CONSISTENCY: bool = True  # presumably gates temporal smoothing elsewhere - confirm
USE_INTELLIGENT_PROMPTING: bool = True  # saliency-driven prompts instead of a fixed prompt layout
USE_ITERATIVE_REFINEMENT: bool = True  # corrective SAM2 passes after the first mask

# NOTE(review): basicConfig() configures the root logger at import time.
# Kept for backward compatibility; applications usually own this call.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
| |
# Preset backgrounds for create_professional_background().
# Each entry holds:
# - a display "name";
# - a "type" of "gradient" or "color";
# - hex "colors" written in full #RRGGBB form;
# - an optional gradient "direction";
# - a "description";
# - "brightness"/"contrast" values (presumably read by
#   _apply_background_adjustments - confirm).
PROFESSIONAL_BACKGROUNDS = {
    "office_modern": {
        "name": "Modern Office",
        "type": "gradient",
        "colors": ["#f8f9fa", "#e9ecef", "#dee2e6"],
        "direction": "diagonal",
        "description": "Clean, contemporary office environment",
        "brightness": 0.95,
        "contrast": 1.1
    },
    "studio_blue": {
        "name": "Professional Blue",
        "type": "gradient",
        "colors": ["#1e3c72", "#2a5298", "#3498db"],
        "direction": "radial",
        "description": "Broadcast-quality blue studio",
        "brightness": 0.9,
        "contrast": 1.2
    },
    "studio_green": {
        "name": "Broadcast Green",
        "type": "color",
        "colors": ["#00b894"],
        "chroma_key": True,
        "description": "Professional green screen replacement",
        "brightness": 1.0,
        "contrast": 1.0
    },
    "minimalist": {
        "name": "Minimalist White",
        "type": "gradient",
        # Full 6-digit form: the shorthand "#ddd" splits into "dd", "d", "" in
        # the hex parsers, and int("", 16) raises. That made this preset
        # render as flat grey.
        "colors": ["#ffffff", "#f1f2f6", "#dddddd"],
        "direction": "soft_radial",
        "description": "Clean, minimal background",
        "brightness": 0.98,
        "contrast": 0.9
    },
    "warm_gradient": {
        "name": "Warm Sunset",
        "type": "gradient",
        "colors": ["#ff7675", "#fd79a8", "#fdcb6e"],
        "direction": "diagonal",
        "description": "Warm, inviting atmosphere",
        "brightness": 0.85,
        "contrast": 1.15
    },
    "tech_dark": {
        "name": "Tech Dark",
        "type": "gradient",
        "colors": ["#0c0c0c", "#2d3748", "#4a5568"],
        "direction": "vertical",
        "description": "Modern tech/gaming setup",
        "brightness": 0.7,
        "contrast": 1.3
    },
    "corporate_blue": {
        "name": "Corporate Blue",
        "type": "gradient",
        "colors": ["#667eea", "#764ba2", "#f093fb"],
        "direction": "diagonal",
        "description": "Professional corporate background",
        "brightness": 0.88,
        "contrast": 1.1
    },
    "nature_blur": {
        "name": "Soft Nature",
        "type": "gradient",
        "colors": ["#a8edea", "#fed6e3", "#d299c2"],
        "direction": "radial",
        "description": "Soft blurred nature effect",
        "brightness": 0.92,
        "contrast": 0.95
    }
}
|
|
| |
class SegmentationError(Exception):
    """Raised when person segmentation cannot produce a usable mask."""


class MaskRefinementError(Exception):
    """Raised when mask refinement fails and no fallback may be used."""


class BackgroundReplacementError(Exception):
    """Raised when compositing a new background onto a frame fails."""
|
|
| |
| |
| |
|
|
def segment_person_hq(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """
    ENHANCED VERSION 2.0: High-quality person segmentation with intelligent automation.

    ROLLBACK: Set USE_ENHANCED_SEGMENTATION = False to revert to original behavior.

    Pipeline:
    1. ``predictor.set_image``.
    2. Prompt-based SAM2 prediction, with smart or basic prompts depending on
       USE_INTELLIGENT_PROMPTING.
    3. Optional iterative refinement (USE_ITERATIVE_REFINEMENT).
    4. Validation with ``_validate_mask_quality``.

    Args:
        image: Input image (H, W, 3).
        predictor: SAM2 predictor instance (uses ``set_image`` and ``predict``).
        fallback_enabled: When True, any failure returns
            ``_fallback_segmentation(image)`` instead of raising. Failures
            include a missing predictor, predictor errors, no masks and a
            mask that fails validation.

    Returns:
        Binary mask (H, W) with values 0-255.

    Raises:
        SegmentationError: If ``image`` is None/empty, or if segmentation
            fails while ``fallback_enabled`` is False.
    """
    if not USE_ENHANCED_SEGMENTATION:
        return segment_person_hq_original(image, predictor, fallback_enabled)

    logger.debug("Using ENHANCED segmentation with intelligent automation")

    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    try:
        if predictor is None:
            if fallback_enabled:
                logger.warning("SAM2 predictor not available, using fallback")
                return _fallback_segmentation(image)
            raise SegmentationError("SAM2 predictor not available")

        try:
            predictor.set_image(image)
        except Exception as e:
            logger.error(f"Failed to set image in predictor: {e}")
            if fallback_enabled:
                return _fallback_segmentation(image)
            raise SegmentationError(f"Predictor setup failed: {e}") from e

        if USE_INTELLIGENT_PROMPTING:
            mask = _segment_with_intelligent_prompts(image, predictor)
        else:
            mask = _segment_with_basic_prompts(image, predictor)

        if USE_ITERATIVE_REFINEMENT and mask is not None:
            mask = _auto_refine_mask_iteratively(image, mask, predictor)

        # _validate_mask_quality() returns True when it errors internally, so
        # reject a missing mask explicitly rather than relying on it.
        if mask is None or not _validate_mask_quality(mask, image.shape[:2]):
            logger.warning("Mask quality validation failed")
            if fallback_enabled:
                return _fallback_segmentation(image)
            raise SegmentationError("Poor mask quality")

        logger.debug(f"Enhanced segmentation successful - mask range: {mask.min()}-{mask.max()}")
        return mask

    except SegmentationError as e:
        # The prompting helpers raise SegmentationError (e.g. "No masks
        # generated"). Honour fallback_enabled for those, as
        # segment_person_hq_original does. With fallback disabled, every
        # SegmentationError propagates unchanged.
        if fallback_enabled:
            logger.warning(f"Segmentation failed ({e}), using fallback")
            return _fallback_segmentation(image)
        raise
    except Exception as e:
        logger.error(f"Unexpected segmentation error: {e}")
        if fallback_enabled:
            return _fallback_segmentation(image)
        raise SegmentationError(f"Unexpected error: {e}") from e
|
|
def _segment_with_intelligent_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    """Run SAM2 with automatically generated positive/negative point prompts.

    Prompts come from ``_generate_smart_prompts``. If it yields no positive
    points, a single click at the image centre is used instead. Among the
    multimask outputs, the highest-scoring mask is kept; the first mask is
    used when no scores are returned.

    Returns:
        The selected mask after ``_process_mask``.

    Raises:
        SegmentationError: If the predictor returns no masks.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    try:
        height, width = image.shape[:2]

        positives, negatives = _generate_smart_prompts(image)
        if len(positives) == 0:
            positives = np.array([[width // 2, height // 2]], dtype=np.float32)

        point_coords = np.vstack([positives, negatives])
        point_labels = np.concatenate([
            np.ones(len(positives), dtype=np.int32),
            np.zeros(len(negatives), dtype=np.int32),
        ])
        logger.debug(f"Using {len(positives)} positive, {len(negatives)} negative points")

        with torch.no_grad():
            masks, scores, _ = predictor.predict(
                point_coords=point_coords,
                point_labels=point_labels,
                multimask_output=True,
            )

        if masks is None or len(masks) == 0:
            raise SegmentationError("No masks generated")

        has_scores = scores is not None and len(scores) > 0
        if not has_scores:
            return _process_mask(masks[0])

        best = np.argmax(scores)
        logger.debug(f"Selected mask {best} with score {scores[best]:.3f}")
        return _process_mask(masks[best])

    except Exception as e:
        logger.error(f"Intelligent prompting failed: {e}")
        raise
|
|
def _generate_smart_prompts(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Generate positive/negative SAM2 point prompts for ``image`` automatically.

    Positive points are the centroids of up to three of the largest regions
    of a spectral-residual saliency map thresholded at 0.7. This needs
    ``cv2.saliency``, which may be missing from some OpenCV builds. If
    saliency is unavailable, reports failure, or yields no usable centroid,
    three points on the vertical centre line are used instead (at 1/3, 1/2
    and 2/3 of the height).

    Negative points are the four corners (inset 10 px) and the top/bottom edge
    centres (inset 5 px), clamped to the image bounds.

    Returns:
        ``(positive_points, negative_points)`` as float32 arrays of shape
        (N, 2) in (x, y) order. If anything unexpected fails, one centre
        positive point and two top-corner negative points are returned.
    """
    try:
        h, w = image.shape[:2]
        upper = np.array([max(w - 1, 0), max(h - 1, 0)], dtype=np.float32)

        try:
            saliency = cv2.saliency.StaticSaliencySpectralResidual_create()
            success, saliency_map = saliency.computeSaliency(image)
            if not success:
                # computeSaliency can report failure without raising. Treat it
                # like any other saliency failure so the three-point fallback
                # below is used; before, positive_points was never assigned.
                raise RuntimeError("computeSaliency reported failure")

            saliency_thresh = cv2.threshold(saliency_map, 0.7, 1, cv2.THRESH_BINARY)[1]
            contours, _ = cv2.findContours((saliency_thresh * 255).astype(np.uint8),
                                           cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            positive_points = []
            # Centroids of the three largest salient regions.
            for contour in sorted(contours, key=cv2.contourArea, reverse=True)[:3]:
                M = cv2.moments(contour)
                if M["m00"] != 0:
                    cx = int(M["m10"] / M["m00"])
                    cy = int(M["m01"] / M["m00"])
                    if 0 < cx < w and 0 < cy < h:
                        positive_points.append([cx, cy])

            if not positive_points:
                raise RuntimeError("No valid saliency points found")

            logger.debug(f"Generated {len(positive_points)} saliency-based points")
            positive_points = np.array(positive_points, dtype=np.float32)

        except Exception as e:
            logger.debug(f"Saliency method failed: {e}, using fallback")
            positive_points = np.array([
                [w // 2, h // 3],
                [w // 2, h // 2],
                [w // 2, 2 * h // 3],
            ], dtype=np.float32)

        negative_points = np.array([
            [10, 10],
            [w - 10, 10],
            [10, h - 10],
            [w - 10, h - 10],
            [w // 2, 5],
            [w // 2, h - 5],
        ], dtype=np.float32)
        # Keep the fixed insets inside the frame; this only matters for very small images.
        negative_points = np.clip(negative_points, 0, upper).astype(np.float32)

        return positive_points, negative_points

    except Exception as e:
        logger.warning(f"Smart prompt generation failed: {e}")
        h, w = image.shape[:2]
        positive_points = np.array([[w // 2, h // 2]], dtype=np.float32)
        negative_points = np.array([[10, 10], [w - 10, 10]], dtype=np.float32)
        return positive_points, negative_points
|
|
def _auto_refine_mask_iteratively(image: np.ndarray, initial_mask: np.ndarray,
                                  predictor: Any, max_iterations: int = 2) -> np.ndarray:
    """Iteratively refine a mask with corrective SAM2 point prompts.

    Each iteration scores the current mask with ``_assess_mask_quality`` and
    stops early when any of these holds:
    - the score exceeds 0.85;
    - ``_find_mask_errors`` reports no problem areas;
    - no corrective prompts can be generated;
    - the predictor fails or returns nothing;
    - the refined mask does not raise the score.

    Args:
        image: The image being segmented. Assumes ``predictor.set_image(image)``
            was already called (segment_person_hq does this first).
        initial_mask: uint8 mask (0/255) from the first prediction.
        predictor: SAM2 predictor exposing ``predict``.
        max_iterations: Upper bound on refinement passes.

    Returns:
        The best mask found, or ``initial_mask`` if refinement itself errors.
    """
    try:
        current_mask = initial_mask.copy()

        for iteration in range(max_iterations):
            quality_score = _assess_mask_quality(current_mask, image)
            logger.debug(f"Iteration {iteration}: quality score = {quality_score:.3f}")

            if quality_score > 0.85:
                logger.debug(f"Quality sufficient after {iteration} iterations")
                break

            problem_areas = _find_mask_errors(current_mask, image)
            if not np.any(problem_areas):
                logger.debug("No problem areas detected")
                break

            corrective_points, corrective_labels = _generate_corrective_prompts(
                image, current_mask, problem_areas
            )
            if len(corrective_points) == 0:
                # The mask is unchanged, so further iterations would repeat this one.
                logger.debug(f"No corrective prompts in iteration {iteration}")
                break

            try:
                with torch.no_grad():
                    masks, scores, _ = predictor.predict(
                        point_coords=corrective_points,
                        point_labels=corrective_labels,
                        # predict() takes low-resolution logits (1x256x256) as
                        # its mask prior, not a full-resolution 0-255 mask.
                        mask_input=_mask_to_sam_logits(current_mask),
                        multimask_output=False,
                    )

                if masks is None or len(masks) == 0:
                    logger.debug(f"No refined mask returned in iteration {iteration}")
                    break

                refined_mask = _process_mask(masks[0])
                if _assess_mask_quality(refined_mask, image) > quality_score:
                    current_mask = refined_mask
                    logger.debug(f"Improved mask in iteration {iteration}")
                else:
                    logger.debug(f"Refinement didn't improve quality in iteration {iteration}")
                    break

            except Exception as e:
                logger.debug(f"Refinement iteration {iteration} failed: {e}")
                break

        return current_mask

    except Exception as e:
        logger.warning(f"Iterative refinement failed: {e}")
        return initial_mask


def _mask_to_sam_logits(mask: np.ndarray, size: int = 256, magnitude: float = 10.0) -> np.ndarray:
    """Convert a full-resolution 0-255 mask into a (1, size, size) SAM2 logit prior.

    Foreground (255) maps to +magnitude and background (0) to -magnitude,
    with SAM2's 0.0 mask threshold in between.
    NOTE(review): a plain square resize assumes SAM2-style preprocessing,
    which resizes the image to a square without padding. ``magnitude`` is a
    heuristic logit strength; tune if priors are too weak or too strong.
    """
    low_res = cv2.resize(mask, (size, size), interpolation=cv2.INTER_LINEAR)
    probability = low_res.astype(np.float32) / 255.0
    return ((probability * 2.0 - 1.0) * magnitude)[None, :, :]
|
|
def _assess_mask_quality(mask: np.ndarray, image: np.ndarray) -> float:
    """Heuristically score a uint8 mask in [0, 1] (higher is better).

    Weighted average of four sub-scores, each clamped to [0, 1]:
      * area (weight 0.3): 1.0 when the foreground (pixels > 127) covers
        5%-80% of the image, decaying linearly outside that range;
      * centrality (0.2): 1 minus the larger of the normalised x/y offsets of
        the foreground centroid from the image centre (0 for an empty mask);
      * smoothness (0.3): 1 - 10 * Canny edge density of the mask;
      * connectivity (0.2): 1.0 for one connected component, minus 0.2 per
        extra component.

    Returns:
        The weighted score, or 0.5 if scoring raises.
    """
    try:
        h, w = image.shape[:2]
        total_area = h * w
        mask_binary = mask > 127

        area_ratio = np.sum(mask_binary) / total_area
        if 0.05 <= area_ratio <= 0.8:
            area_score = 1.0
        elif area_ratio < 0.05:
            area_score = area_ratio / 0.05
        else:
            area_score = max(0.0, 1.0 - (area_ratio - 0.8) / 0.2)

        if np.any(mask_binary):
            ys, xs = np.nonzero(mask_binary)
            # Use the larger offset: with min(), a mask pressed against the
            # left edge but vertically centred scored a perfect 1.0.
            offset = max(abs(np.mean(xs) / w - 0.5), abs(np.mean(ys) / h - 0.5))
            center_score = 1.0 - offset
        else:
            center_score = 0.0

        edges = cv2.Canny(mask, 50, 150)
        edge_density = np.sum(edges > 0) / total_area
        smoothness_score = max(0.0, 1.0 - edge_density * 10)

        # num_labels includes the background label, so a single blob gives 2.
        # Clamp the result: an empty mask (num_labels == 1) previously scored 1.2.
        num_labels, _ = cv2.connectedComponents(mask)
        connectivity_score = min(1.0, max(0.0, 1.0 - (num_labels - 2) * 0.2))

        scores = [area_score, center_score, smoothness_score, connectivity_score]
        return float(np.average(scores, weights=[0.3, 0.2, 0.3, 0.2]))

    except Exception as e:
        logger.warning(f"Quality assessment failed: {e}")
        return 0.5
|
|
def _find_mask_errors(mask: np.ndarray, image: np.ndarray) -> np.ndarray:
    """Flag regions where image edges and mask edges disagree.

    Canny edges (thresholds 50/150) are computed for the BGR-to-grey image
    and for the mask. Pixels that are an edge in exactly one of the two
    (XOR) are flagged, then dilated once with a 5x5 elliptical kernel. Any
    image texture edge that is not also a mask edge is therefore flagged.

    Returns:
        Boolean array (True = suspected error). On failure, an all-False
        array shaped like ``mask``.
    """
    try:
        image_edges = cv2.Canny(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY), 50, 150)
        mask_edges = cv2.Canny(mask, 50, 150)
        disagreement = cv2.bitwise_xor(image_edges, mask_edges)
        ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        return cv2.dilate(disagreement, ellipse, iterations=1) > 0
    except Exception as e:
        logger.warning(f"Error detection failed: {e}")
        return np.zeros_like(mask, dtype=bool)
|
|
def _generate_corrective_prompts(image: np.ndarray, mask: np.ndarray,
                                 problem_areas: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Turn problem regions into SAM2 corrective point prompts.

    Each external contour of ``problem_areas`` with area > 100 px contributes
    one point at its centroid. The label asks SAM2 to flip the mask's current
    decision there: 1 (foreground) when ``mask`` < 127 at that pixel,
    otherwise 0 (background). ``image`` is accepted for interface
    compatibility but not used.

    Returns:
        ``(points, labels)``: float32 (N, 2) coordinates in (x, y) order and
        int32 (N,) labels. Both are empty, with the same dtypes, when nothing
        qualifies or an error occurs.
    """
    # Consistent dtypes for the empty result; np.array([]) defaults to float64.
    empty = (np.empty((0, 2), dtype=np.float32), np.empty(0, dtype=np.int32))
    try:
        contours, _ = cv2.findContours(problem_areas.astype(np.uint8),
                                       cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        corrective_points = []
        corrective_labels = []
        for contour in contours:
            if cv2.contourArea(contour) <= 100:
                continue
            M = cv2.moments(contour)
            if M["m00"] == 0:
                continue
            cx = int(M["m10"] / M["m00"])
            cy = int(M["m01"] / M["m00"])
            corrective_points.append([cx, cy])
            corrective_labels.append(1 if mask[cy, cx] < 127 else 0)

        if not corrective_points:
            return empty
        return (np.array(corrective_points, dtype=np.float32),
                np.array(corrective_labels, dtype=np.int32))

    except Exception as e:
        logger.warning(f"Corrective prompt generation failed: {e}")
        return empty
|
|
def _segment_with_basic_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    """Run SAM2 with a fixed prompt layout (used when intelligent prompting is off).

    Uses three positive clicks on the vertical centre line (at 1/3, 1/2 and
    2/3 of the height). Four negative clicks sit at 10%/90% of each dimension,
    near the corners. The highest-scoring multimask output is kept (index 0
    when no scores are returned).

    Raises:
        SegmentationError: If the predictor returns no masks.
    """
    h, w = image.shape[:2]
    x_mid = w // 2

    foreground = [[x_mid, h // 3], [x_mid, h // 2], [x_mid, 2 * h // 3]]
    background = [
        [w // 10, h // 10],
        [9 * w // 10, h // 10],
        [w // 10, 9 * h // 10],
        [9 * w // 10, 9 * h // 10],
    ]
    points = np.array(foreground + background, dtype=np.float32)
    labels = np.array([1] * len(foreground) + [0] * len(background), dtype=np.int32)

    with torch.no_grad():
        masks, scores, _ = predictor.predict(
            point_coords=points,
            point_labels=labels,
            multimask_output=True,
        )

    if masks is None or len(masks) == 0:
        raise SegmentationError("No masks generated")

    if scores is None or len(scores) == 0:
        chosen = 0
    else:
        chosen = np.argmax(scores)
    return _process_mask(masks[chosen])
|
|
| |
| |
| |
|
|
def segment_person_hq_original(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """
    ORIGINAL VERSION: Preserved for rollback capability.

    Segments a person with SAM2 using eight fixed positive clicks placed
    around the image centre. The highest-scoring multimask output is kept
    and checked with ``_validate_mask_quality``.

    Args:
        image: Input image (H, W, 3).
        predictor: SAM2 predictor instance, or None.
        fallback_enabled: On any failure, return ``_fallback_segmentation(image)``
            instead of raising.

    Returns:
        Binary mask (H, W) with values 0-255.

    Raises:
        SegmentationError: For a None/empty image, or for any failure when
            ``fallback_enabled`` is False.
    """
    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    def _give_up(reason: str) -> np.ndarray:
        # Single exit for every failure path: a fallback mask, or a SegmentationError.
        if fallback_enabled:
            return _fallback_segmentation(image)
        raise SegmentationError(reason)

    try:
        if predictor is None:
            if fallback_enabled:
                logger.warning("SAM2 predictor not available, using fallback")
            return _give_up("SAM2 predictor not available")

        try:
            predictor.set_image(image)
        except Exception as e:
            logger.error(f"Failed to set image in predictor: {e}")
            return _give_up(f"Predictor setup failed: {e}")

        h, w = image.shape[:2]
        # Positive clicks: three on the centre column, two flanking the centre
        # row, one near the top, and two in the lower third.
        click_points = np.array([
            [w // 2, h // 4],
            [w // 2, h // 2],
            [w // 2, 3 * h // 4],
            [w // 3, h // 2],
            [2 * w // 3, h // 2],
            [w // 2, h // 6],
            [w // 4, 2 * h // 3],
            [3 * w // 4, 2 * h // 3],
        ], dtype=np.float32)
        click_labels = np.ones(len(click_points), dtype=np.int32)

        try:
            with torch.no_grad():
                masks, scores, _ = predictor.predict(
                    point_coords=click_points,
                    point_labels=click_labels,
                    multimask_output=True,
                )
        except Exception as e:
            logger.error(f"SAM2 prediction failed: {e}")
            return _give_up(f"Prediction failed: {e}")

        if masks is None or len(masks) == 0:
            logger.warning("SAM2 returned no masks")
            return _give_up("No masks generated")

        if scores is None or len(scores) == 0:
            logger.warning("SAM2 returned no scores")
            chosen = masks[0]
        else:
            best = np.argmax(scores)
            chosen = masks[best]
            logger.debug(f"Selected mask {best} with score {scores[best]:.3f}")

        mask = _process_mask(chosen)

        if not _validate_mask_quality(mask, image.shape[:2]):
            logger.warning("Mask quality validation failed")
            return _give_up("Poor mask quality")

        logger.debug(f"Segmentation successful - mask range: {mask.min()}-{mask.max()}")
        return mask

    except SegmentationError:
        raise
    except Exception as e:
        logger.error(f"Unexpected segmentation error: {e}")
        return _give_up(f"Unexpected error: {e}")
|
|
| |
| |
| |
|
|
def _process_mask(mask: np.ndarray) -> np.ndarray:
    """Normalise a raw mask to a 2-D uint8 binary mask (values 0 or 255).

    Accepted inputs:
    - bool masks;
    - floating-point masks of any precision, either 0-1 probabilities or
      0-255 values;
    - integer masks, either 0/1 or 0-255.
    Singleton dimensions are squeezed, and (H, W, C) masks keep channel 0.
    A 3x3 close followed by a 3x3 open removes pinholes and specks before
    thresholding at 127.

    Returns:
        (H, W) uint8 array with values in {0, 255}. If processing fails, a
        rectangle covering the middle half of each dimension.
    """
    try:
        mask = np.asarray(mask)

        if mask.ndim > 2:
            mask = mask.squeeze()
        if mask.ndim > 2:
            mask = mask[:, :, 0] if mask.shape[2] > 0 else mask.sum(axis=2)

        if mask.dtype == bool:
            mask = mask.astype(np.uint8) * 255
        elif np.issubdtype(mask.dtype, np.floating):
            # issubdtype also covers float16, which previously fell through to
            # a bare uint8 cast that truncated 0-1 values to 0.
            if mask.max() <= 1.0:
                mask = mask * 255
            # Clip before casting: negative values (e.g. logits) and values
            # above 255 would otherwise wrap around in the uint8 cast.
            mask = np.clip(mask, 0, 255).astype(np.uint8)
        elif mask.max() <= 1:
            # Integer 0/1 masks: scale to 0/255, as the compositing helpers in
            # this module do for masks with max <= 1. Without this, the
            # threshold at 127 below erased them entirely.
            mask = np.where(mask > 0, 255, 0).astype(np.uint8)
        else:
            mask = np.clip(mask.astype(np.int64), 0, 255).astype(np.uint8)

        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

        _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
        return mask

    except Exception as e:
        logger.error(f"Mask processing failed: {e}")
        shape = getattr(mask, "shape", ())
        h, w = shape[:2] if len(shape) >= 2 else (256, 256)
        fallback = np.zeros((h, w), dtype=np.uint8)
        fallback[h // 4:3 * h // 4, w // 4:3 * w // 4] = 255
        return fallback
|
|
def _validate_mask_quality(mask: np.ndarray, image_shape: Tuple[int, int]) -> bool:
    """Cheap sanity check that a mask plausibly contains a person.

    The mask is rejected when any of these holds:
    - foreground (pixels > 127) covers less than 5% or more than 80% of
      ``image_shape``;
    - the mask is empty;
    - the foreground centroid's row lies outside 20%-90% of the height.
    The horizontal position is not checked.

    Returns:
        True if the mask passes. NOTE: it also returns True when the check
        itself raises, so validation errors count as a pass.
    """
    try:
        h, w = image_shape
        foreground = mask > 127
        area_ratio = np.sum(foreground) / (h * w)

        if area_ratio < 0.05 or area_ratio > 0.8:
            logger.warning(f"Suspicious mask area ratio: {area_ratio:.3f}")
            return False

        rows, _cols = np.nonzero(foreground)
        if len(rows) == 0:
            logger.warning("Empty mask")
            return False

        centroid_y = np.mean(rows)
        if centroid_y < h * 0.2 or centroid_y > h * 0.9:
            logger.warning(f"Mask center too far from expected person location: y={centroid_y/h:.2f}")
            return False

        return True

    except Exception as e:
        logger.warning(f"Mask validation error: {e}")
        return True
|
|
def _fallback_segmentation(image: np.ndarray) -> np.ndarray:
    """Produce a person mask without any AI model.

    Strategy 1: take the median grey level of the border pixels as the
    background colour and keep pixels that differ from it by more than 30.
    The result is cleaned with a 7x7 elliptical close/open and accepted only
    if it passes ``_validate_mask_quality``.
    Strategy 2: a centred ellipse with semi-axes w//3 and h//2.5.
    Last resort: a fixed rectangle covering the middle half of the width and
    the middle two-thirds of the height.

    Returns:
        (H, W) uint8 mask with values 0/255.
    """
    try:
        logger.info("Using fallback segmentation strategy")
        h, w = image.shape[:2]

        try:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            border = np.concatenate([gray[0, :], gray[-1, :], gray[:, 0], gray[:, -1]])
            background_level = np.median(border)

            foreground = (np.abs(gray.astype(float) - background_level) > 30).astype(np.uint8) * 255
            ellipse7 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
            foreground = cv2.morphologyEx(foreground, cv2.MORPH_CLOSE, ellipse7)
            foreground = cv2.morphologyEx(foreground, cv2.MORPH_OPEN, ellipse7)

            if _validate_mask_quality(foreground, image.shape[:2]):
                logger.info("Background subtraction fallback successful")
                return foreground
        except Exception as e:
            logger.warning(f"Background subtraction fallback failed: {e}")

        ellipse_mask = np.zeros((h, w), dtype=np.uint8)
        cx, cy = w // 2, h // 2
        rx, ry = w // 3, h // 2.5
        yy, xx = np.ogrid[:h, :w]
        inside = ((xx - cx) / rx) ** 2 + ((yy - cy) / ry) ** 2 <= 1
        ellipse_mask[inside] = 255

        logger.info("Using geometric fallback mask")
        return ellipse_mask

    except Exception as e:
        logger.error(f"All fallback strategies failed: {e}")
        h, w = image.shape[:2]
        rectangle = np.zeros((h, w), dtype=np.uint8)
        rectangle[h // 6:5 * h // 6, w // 4:3 * w // 4] = 255
        return rectangle
|
|
| |
| |
| |
|
|
def refine_mask_hq(image: np.ndarray, mask: np.ndarray, matanyone_processor: Any,
                   fallback_enabled: bool = True) -> np.ndarray:
    """
    Refine a segmentation mask, preferring MatAnyone and falling back to OpenCV.

    Behaviour is kept unchanged for rollback safety. The mask is first
    normalised with ``_process_mask``. A MatAnyone result is accepted only
    if it passes ``_validate_mask_quality``. Otherwise
    ``enhance_mask_opencv_advanced`` is used, if ``fallback_enabled``.

    Raises:
        MaskRefinementError: If ``image`` or ``mask`` is None, or if no
            acceptable refinement exists and ``fallback_enabled`` is False.
    """
    if image is None or mask is None:
        raise MaskRefinementError("Invalid input image or mask")

    try:
        mask = _process_mask(mask)

        if matanyone_processor is not None:
            try:
                logger.debug("Attempting MatAnyone refinement")
                candidate = _matanyone_refine(image, mask, matanyone_processor)
                acceptable = candidate is not None and _validate_mask_quality(candidate, image.shape[:2])
                if acceptable:
                    logger.debug("MatAnyone refinement successful")
                    return candidate
                logger.warning("MatAnyone produced poor quality mask")
            except Exception as e:
                logger.warning(f"MatAnyone refinement failed: {e}")

        if not fallback_enabled:
            raise MaskRefinementError("MatAnyone failed and fallback disabled")
        logger.debug("Using enhanced OpenCV refinement")
        return enhance_mask_opencv_advanced(image, mask)

    except MaskRefinementError:
        raise
    except Exception as e:
        logger.error(f"Unexpected mask refinement error: {e}")
        if not fallback_enabled:
            raise MaskRefinementError(f"Unexpected error: {e}")
        return enhance_mask_opencv_advanced(image, mask)
|
|
def _matanyone_refine(image: np.ndarray, mask: np.ndarray, processor: Any) -> Optional[np.ndarray]:
    """Call a MatAnyone-like processor through whichever interface it exposes.

    The lookup order is ``processor.infer(image, mask)``, then
    ``processor.process(image, mask)``, then ``processor(image, mask)``. A
    non-None result is normalised with ``_process_mask``.

    Returns:
        The refined uint8 mask, or None if the interface is unknown, the
        processor returns None, or anything raises.
    """
    try:
        if hasattr(processor, 'infer'):
            run = processor.infer
        elif hasattr(processor, 'process'):
            run = processor.process
        elif callable(processor):
            run = processor
        else:
            logger.warning("Unknown MatAnyone interface")
            return None

        raw = run(image, mask)
        if raw is None:
            return None

        refined = _process_mask(raw)
        logger.debug("MatAnyone refinement successful")
        return refined

    except Exception as e:
        logger.warning(f"MatAnyone processing error: {e}")
        return None
|
|
def enhance_mask_opencv_advanced(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Refine a mask using only OpenCV operations.

    Steps, in order:
    1. Bilateral filter (d=9, sigmas 75).
    2. Guided-filter approximation with ``image`` as the guide.
    3. 5x5 elliptical close.
    4. 3x3 elliptical open.
    5. 3x3 Gaussian blur.
    6. Threshold at 127.

    Args:
        image: Guide image (BGR or greyscale).
        mask: Mask shaped (H, W), (H, W, 1) or (H, W, 3). It may be in 0-1
            or 0-255 range, with any numeric or bool dtype.

    Returns:
        (H, W) uint8 mask with values 0/255. If refinement fails, a 5x5
        Gaussian-blurred (not binarised) version of the mask.
    """
    try:
        if mask.ndim == 3:
            # (H, W, 1) cannot go through BGR2GRAY; take its only channel.
            mask = mask[:, :, 0] if mask.shape[2] == 1 else cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

        if mask.max() <= 1.0:
            mask = mask.astype(np.float32) * 255.0
        # bilateralFilter only accepts 8-bit or float32 input. Normalise every
        # dtype to uint8 so int32/int64/float64 masks do not fail here and
        # then again in the GaussianBlur fallback below.
        if mask.dtype != np.uint8:
            mask = np.clip(mask.astype(np.float32), 0, 255).astype(np.uint8)

        refined_mask = cv2.bilateralFilter(mask, 9, 75, 75)
        refined_mask = _guided_filter_approx(image, refined_mask, radius=8, eps=0.2)

        kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        refined_mask = cv2.morphologyEx(refined_mask, cv2.MORPH_CLOSE, kernel_close)

        kernel_open = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        refined_mask = cv2.morphologyEx(refined_mask, cv2.MORPH_OPEN, kernel_open)

        refined_mask = cv2.GaussianBlur(refined_mask, (3, 3), 0.8)
        _, refined_mask = cv2.threshold(refined_mask, 127, 255, cv2.THRESH_BINARY)
        return refined_mask

    except Exception as e:
        logger.warning(f"Enhanced OpenCV refinement failed: {e}")
        return cv2.GaussianBlur(mask, (5, 5), 1.0)
|
|
def _guided_filter_approx(guide: np.ndarray, mask: np.ndarray, radius: int = 8, eps: float = 0.2) -> np.ndarray:
    """Edge-aware smoothing of ``mask``, guided by the greyscale ``guide`` image.

    Implements the guided filter with a single-channel guide, using box
    filters of size (2 * radius + 1):
    - a = cov(I, p) / (var(I) + eps) and b = mean(p) - a * mean(I);
    - output q = mean(a) * I + mean(b).
    Here I and p are the guide and mask scaled to [0, 1].

    Returns:
        uint8 mask in 0-255, or ``mask`` unchanged if filtering raises.
    """
    try:
        if len(guide.shape) == 3:
            guide_gray = cv2.cvtColor(guide, cv2.COLOR_BGR2GRAY)
        else:
            guide_gray = guide
        I = guide_gray.astype(np.float32) / 255.0
        p = mask.astype(np.float32) / 255.0
        window = (2 * radius + 1, 2 * radius + 1)

        def box(values):
            return cv2.boxFilter(values, -1, window)

        mean_I = box(I)
        mean_p = box(p)
        cov_Ip = box(I * p) - mean_I * mean_p
        var_I = box(I * I) - mean_I * mean_I

        a = cov_Ip / (var_I + eps)
        b = mean_p - a * mean_I

        q = box(a) * I + box(b)
        return np.clip(q * 255, 0, 255).astype(np.uint8)

    except Exception as e:
        logger.warning(f"Guided filter approximation failed: {e}")
        return mask
|
|
def replace_background_hq(frame: np.ndarray, mask: np.ndarray, background: np.ndarray,
                          fallback_enabled: bool = True) -> np.ndarray:
    """Composite ``frame`` over ``background`` using ``mask`` as foreground alpha.

    The background is resized (Lanczos) to the frame size. The mask is
    converted to single-channel uint8 0-255, and 0-1 masks are scaled up
    before the cast. ``_advanced_compositing`` is tried first; if it raises,
    ``_simple_compositing`` is used when ``fallback_enabled``.

    Raises:
        BackgroundReplacementError: On None inputs, or on failure when
            ``fallback_enabled`` is False.
    """
    if frame is None or mask is None or background is None:
        raise BackgroundReplacementError("Invalid input frame, mask, or background")

    try:
        background = cv2.resize(background, (frame.shape[1], frame.shape[0]),
                                interpolation=cv2.INTER_LANCZOS4)

        if len(mask.shape) == 3:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

        # Scale 0-1 masks *before* casting. Casting first truncated every
        # fractional alpha (e.g. 0.99) to 0, so soft float masks were lost.
        if mask.max() <= 1.0:
            logger.debug("Converting normalized mask to 0-255 range")
            mask = mask.astype(np.float32) * 255.0
        if mask.dtype != np.uint8:
            mask = np.clip(mask.astype(np.float32), 0, 255).astype(np.uint8)

        try:
            result = _advanced_compositing(frame, mask, background)
            logger.debug("Advanced compositing successful")
            return result
        except Exception as e:
            logger.warning(f"Advanced compositing failed: {e}")
            if fallback_enabled:
                return _simple_compositing(frame, mask, background)
            raise BackgroundReplacementError(f"Advanced compositing failed: {e}") from e

    except BackgroundReplacementError:
        raise
    except Exception as e:
        logger.error(f"Unexpected background replacement error: {e}")
        if fallback_enabled:
            return _simple_compositing(frame, mask, background)
        raise BackgroundReplacementError(f"Unexpected error: {e}") from e
|
|
def _advanced_compositing(frame: np.ndarray, mask: np.ndarray, background: np.ndarray) -> np.ndarray:
    """Alpha-blend ``frame`` over ``background`` using a feathered version of ``mask``.

    The alpha matte is built as follows:
    1. Threshold the mask at 100.
    2. Apply a 5x5 elliptical close, then open.
    3. Gaussian-blur with a 5x5 kernel, sigma 1.0, and scale to 0-1.
    4. Apply a gamma of 0.8.
    5. Raise alpha above 0.5 by 10% (capped at 1.0) and lower the rest by 10%.
    Edge colours are nudged toward the background by ``_color_match_edges``
    before blending.

    Returns:
        uint8 composite. Errors are logged and re-raised.
    """
    try:
        _, hard = cv2.threshold(mask, 100, 255, cv2.THRESH_BINARY)
        ellipse5 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        hard = cv2.morphologyEx(hard, cv2.MORPH_CLOSE, ellipse5)
        hard = cv2.morphologyEx(hard, cv2.MORPH_OPEN, ellipse5)

        alpha = cv2.GaussianBlur(hard.astype(np.float32), (5, 5), 1.0) / 255.0
        alpha = np.power(alpha, 0.8)
        alpha = np.where(alpha > 0.5, np.minimum(alpha * 1.1, 1.0), alpha * 0.9)

        foreground = _color_match_edges(frame, background, alpha).astype(np.float32)
        alpha3 = np.repeat(alpha[:, :, None], 3, axis=2)

        blended = foreground * alpha3 + background.astype(np.float32) * (1 - alpha3)
        return np.clip(blended, 0, 255).astype(np.uint8)

    except Exception as e:
        logger.error(f"Advanced compositing error: {e}")
        raise
|
|
def _color_match_edges(frame: np.ndarray, background: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    """Pull foreground colours 10% toward the background along the alpha transition.

    Edge pixels are those where the Sobel gradient magnitude of ``alpha``
    (3x3 kernel) exceeds 0.1. At those pixels, each colour channel becomes
    0.9 * frame + 0.1 * background; all other pixels are left unchanged.

    Returns:
        uint8 image. ``frame`` itself is returned unchanged when no edge
        pixels are found or when anything raises.
    """
    try:
        # Gradient magnitude from separate x and y derivatives. A single
        # Sobel(dx=1, dy=1) is the mixed second derivative d2/dxdy, which is
        # zero along purely horizontal or vertical alpha edges, so those edges
        # were never colour-matched.
        grad_x = cv2.Sobel(alpha, cv2.CV_64F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(alpha, cv2.CV_64F, 0, 1, ksize=3)
        edge_areas = np.hypot(grad_x, grad_y) > 0.1

        if not np.any(edge_areas):
            return frame

        adjustment_strength = 0.1
        frame_float = frame.astype(np.float32)
        blended = (frame_float * (1 - adjustment_strength)
                   + background.astype(np.float32) * adjustment_strength)
        adjusted = np.where(edge_areas[:, :, None], blended, frame_float)
        return np.clip(adjusted, 0, 255).astype(np.uint8)

    except Exception as e:
        logger.warning(f"Color matching failed: {e}")
        return frame
|
|
def _simple_compositing(frame: np.ndarray, mask: np.ndarray, background: np.ndarray) -> np.ndarray:
    """Hard-edged fallback compositor.

    The background is resized to the frame and the mask is binarised at 127.
    Each pixel is taken from ``frame`` where the mask is set and from the
    background elsewhere. The result is truncated to uint8.

    Returns:
        uint8 composite, or ``frame`` unchanged if anything raises.
    """
    try:
        logger.info("Using simple compositing fallback")
        background = cv2.resize(background, (frame.shape[1], frame.shape[0]))

        if len(mask.shape) == 3:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
        if mask.max() <= 1.0:
            mask = (mask * 255).astype(np.uint8)

        keep = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1].astype(np.float32) / 255.0
        keep = np.dstack([keep, keep, keep])
        return (frame * keep + background * (1 - keep)).astype(np.uint8)

    except Exception as e:
        logger.error(f"Simple compositing failed: {e}")
        return frame
|
|
def create_professional_background(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
    """Render a background described by a PROFESSIONAL_BACKGROUNDS-style config.

    Rendering depends on ``bg_config["type"]``:
    - "color": a solid fill;
    - "gradient": a multi-stop gradient;
    - any other value: flat grey (128, 128, 128).
    The image is then passed through ``_apply_background_adjustments``
    (presumably applies the config's brightness/contrast - confirm).

    Returns:
        (height, width, 3) uint8 BGR image. Flat grey, without adjustments,
        if anything raises.
    """
    neutral_grey = (128, 128, 128)
    try:
        kind = bg_config["type"]
        if kind == "color":
            canvas = _create_solid_background(bg_config, width, height)
        elif kind == "gradient":
            canvas = _create_gradient_background_enhanced(bg_config, width, height)
        else:
            canvas = np.full((height, width, 3), neutral_grey, dtype=np.uint8)

        return _apply_background_adjustments(canvas, bg_config)

    except Exception as e:
        logger.error(f"Background creation error: {e}")
        return np.full((height, width, 3), neutral_grey, dtype=np.uint8)
|
|
def _create_solid_background(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
    """Create a solid-colour background from the first entry of ``bg_config["colors"]``.

    Accepts ``#RRGGBB`` and the CSS shorthand ``#RGB``. Each shorthand digit
    is doubled, so ``#ddd`` equals ``#dddddd``.

    Returns:
        (height, width, 3) uint8 image in BGR channel order.

    Raises:
        KeyError / IndexError: If the config has no colours.
        ValueError: If the colour is not valid hex.
    """
    color_hex = bg_config["colors"][0].lstrip('#')
    if len(color_hex) == 3:
        # Without expansion, "ddd" splits into "dd", "d", "" and int("", 16) raises.
        color_hex = "".join(digit * 2 for digit in color_hex)
    red, green, blue = (int(color_hex[i:i + 2], 16) for i in (0, 2, 4))
    return np.full((height, width, 3), (blue, green, red), dtype=np.uint8)
|
|
def _create_gradient_background_enhanced(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
    """Render a multi-stop gradient background as a BGR uint8 image.

    ``bg_config["colors"]`` holds ``#RRGGBB`` or shorthand ``#RGB`` stops.
    An empty list falls back to mid-grey. ``bg_config["direction"]`` selects
    the layout and defaults to "vertical":
    - "vertical", "horizontal" or "diagonal";
    - "radial" or "soft_radial";
    - any unknown value renders as vertical.

    Returns:
        (height, width, 3) uint8 BGR image. Flat grey if anything raises.
    """
    try:
        colors = bg_config["colors"]
        direction = bg_config.get("direction", "vertical")

        rgb_colors = []
        for color_hex in colors:
            color_hex = color_hex.lstrip('#')
            if len(color_hex) == 3:
                # CSS shorthand: without expansion, a value like "#ddd" splits
                # into "dd", "d", "" and int("", 16) raises. That dropped the
                # whole gradient to flat grey.
                color_hex = "".join(digit * 2 for digit in color_hex)
            rgb_colors.append(tuple(int(color_hex[i:i + 2], 16) for i in (0, 2, 4)))

        if not rgb_colors:
            rgb_colors = [(128, 128, 128)]

        if direction == "vertical":
            background = _create_vertical_gradient(rgb_colors, width, height)
        elif direction == "horizontal":
            background = _create_horizontal_gradient(rgb_colors, width, height)
        elif direction == "diagonal":
            background = _create_diagonal_gradient(rgb_colors, width, height)
        elif direction in ["radial", "soft_radial"]:
            background = _create_radial_gradient(rgb_colors, width, height, direction == "soft_radial")
        else:
            background = _create_vertical_gradient(rgb_colors, width, height)

        # Gradients are built in RGB; OpenCV consumers expect BGR.
        return cv2.cvtColor(background, cv2.COLOR_RGB2BGR)

    except Exception as e:
        logger.error(f"Gradient creation error: {e}")
        return np.full((height, width, 3), (128, 128, 128), dtype=np.uint8)
|
|
def _create_vertical_gradient(colors: list, width: int, height: int) -> np.ndarray:
    """Build a top-to-bottom gradient, one interpolated color per row.

    Row ``y`` is filled with ``_interpolate_color(colors, y / height)``, so the
    top row is exactly the first color and progress never reaches 1.0 on the
    bottom row. Interpolation runs once per row in a Python loop; NumPy only
    broadcasts each row's color across the width.

    Returns:
        A ``(height, width, 3)`` uint8 array, channel order following ``colors``.
    """
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    # Iterating an ndarray yields writable views of its rows.
    for row_index, row in enumerate(canvas):
        row[...] = _interpolate_color(colors, row_index / height)
    return canvas
|
|
def _create_horizontal_gradient(colors: list, width: int, height: int) -> np.ndarray:
    """Build a left-to-right gradient, one interpolated color per column.

    Column ``x`` is filled with ``_interpolate_color(colors, x / width)``, so the
    leftmost column is exactly the first color and progress never reaches 1.0
    on the rightmost column. Interpolation runs once per column in a Python
    loop; NumPy only broadcasts each column's color down the height.

    Returns:
        A ``(height, width, 3)`` uint8 array, channel order following ``colors``.
    """
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    # swapaxes returns a view, so columns[x] aliases canvas[:, x].
    columns = canvas.swapaxes(0, 1)
    for col_index, column in enumerate(columns):
        column[...] = _interpolate_color(colors, col_index / width)
    return canvas
|
|
def _create_diagonal_gradient(colors: list, width: int, height: int) -> np.ndarray:
    """Build a top-left to bottom-right gradient with array operations.

    Each pixel's progress is ``(x + y) / (width + height)``, clipped to [0, 1];
    the bottom-right pixel therefore sits slightly short of the last color.

    Returns:
        A ``(height, width, 3)`` uint8 array, channel order following ``colors``.
    """
    # Outer sum gives a (height, width) grid holding x + y for every pixel.
    pixel_steps = np.add.outer(np.arange(height), np.arange(width))
    ramp = np.clip(pixel_steps / (width + height), 0, 1)

    channels = [_vectorized_color_interpolation(colors, ramp, c) for c in range(3)]
    return np.stack(channels, axis=-1)
|
|
def _create_radial_gradient(colors: list, width: int, height: int, soft: bool = False) -> np.ndarray:
    """Create a radial gradient centred at ``(width // 2, height // 2)``.

    Each pixel's progress is its Euclidean distance from the center divided by
    the distance from the center to the (0, 0) corner, clipped to [0, 1]. With
    ``soft=True`` progress is raised to the power 0.7, which pushes values
    toward the outer colors sooner.

    Args:
        colors: Color tuples, first at the center, last at the corner.
        width: Output width in pixels.
        height: Output height in pixels.
        soft: Apply the 0.7 power curve to progress.

    Returns:
        A ``(height, width, 3)`` uint8 array, channel order following ``colors``.
    """
    center_x, center_y = width // 2, height // 2
    # For a 1x1 canvas the center is (0, 0) and this distance is 0, which used
    # to produce 0/0 = NaN progress (undefined uint8 output). Guarding the
    # denominator maps that single pixel to progress 0 (the first color) and
    # leaves every other size unchanged, since they all have distance >= 1.
    max_distance = max(float(np.sqrt(center_x**2 + center_y**2)), 1.0)

    y_coords, x_coords = np.mgrid[0:height, 0:width]
    distances = np.sqrt((x_coords - center_x)**2 + (y_coords - center_y)**2)
    progress = np.clip(distances / max_distance, 0, 1)

    if soft:
        progress = np.power(progress, 0.7)

    gradient = np.zeros((height, width, 3), dtype=np.uint8)
    for c in range(3):
        gradient[:, :, c] = _vectorized_color_interpolation(colors, progress, c)

    return gradient
|
|
| def _vectorized_color_interpolation(colors: list, progress: np.ndarray, channel: int) -> np.ndarray: |
| """Vectorized color interpolation for performance""" |
| if len(colors) == 1: |
| return np.full_like(progress, colors[0][channel], dtype=np.uint8) |
| |
| num_segments = len(colors) - 1 |
| segment_progress = progress * num_segments |
| segment_indices = np.floor(segment_progress).astype(int) |
| segment_indices = np.clip(segment_indices, 0, num_segments - 1) |
| local_progress = segment_progress - segment_indices |
| |
| |
| start_colors = np.array([colors[i][channel] for i in range(len(colors))]) |
| end_colors = np.array([colors[min(i + 1, len(colors) - 1)][channel] for i in range(len(colors))]) |
| |
| start_vals = start_colors[segment_indices] |
| end_vals = end_colors[segment_indices] |
| |
| result = start_vals + (end_vals - start_vals) * local_progress |
| return np.clip(result, 0, 255).astype(np.uint8) |
|
|
| def _interpolate_color(colors: list, progress: float) -> tuple: |
| """Interpolate between multiple colors""" |
| if len(colors) == 1: |
| return colors[0] |
| elif len(colors) == 2: |
| r = int(colors[0][0] + (colors[1][0] - colors[0][0]) * progress) |
| g = int(colors[0][1] + (colors[1][1] - colors[0][1]) * progress) |
| b = int(colors[0][2] + (colors[1][2] - colors[0][2]) * progress) |
| return (r, g, b) |
| else: |
| segment = progress * (len(colors) - 1) |
| idx = int(segment) |
| local_progress = segment - idx |
| if idx >= len(colors) - 1: |
| return colors[-1] |
| c1, c2 = colors[idx], colors[idx + 1] |
| r = int(c1[0] + (c2[0] - c1[0]) * local_progress) |
| g = int(c1[1] + (c2[1] - c1[1]) * local_progress) |
| b = int(c1[2] + (c2[2] - c1[2]) * local_progress) |
| return (r, g, b) |
|
|
| def _apply_background_adjustments(background: np.ndarray, bg_config: Dict[str, Any]) -> np.ndarray: |
| """Apply brightness and contrast adjustments to background""" |
| try: |
| brightness = bg_config.get("brightness", 1.0) |
| contrast = bg_config.get("contrast", 1.0) |
| |
| if brightness != 1.0 or contrast != 1.0: |
| background = background.astype(np.float32) |
| background = background * contrast * brightness |
| background = np.clip(background, 0, 255).astype(np.uint8) |
| |
| return background |
| |
| except Exception as e: |
| logger.warning(f"Background adjustment failed: {e}") |
| return background |
|
|
def validate_video_file(video_path: str) -> Tuple[bool, str]:
    """Check that a path points to a usable video within the supported limits.

    Checks, in order:
      * the path is an existing regular file (directories are rejected);
      * the file is non-empty and at most 2 GB;
      * OpenCV can open it;
      * it reports a positive frame count;
      * frame rate is in (0, 120] fps (NaN is rejected);
      * width and height are positive and at most 4096;
      * duration (frame_count / fps) is at most 300 s.

    Returns:
        ``(True, summary)`` on success, otherwise ``(False, reason)``. Any
        unexpected exception is reported as ``(False, "Error validating video: ...")``.
    """
    if not video_path or not os.path.isfile(video_path):
        return False, "Video file not found"

    try:
        file_size = os.path.getsize(video_path)
        if file_size == 0:
            return False, "Video file is empty"

        if file_size > 2 * 1024 * 1024 * 1024:
            return False, "Video file too large (>2GB)"

        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                return False, "Cannot open video file"

            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        finally:
            # Release on every path, including the early return above.
            cap.release()

        if frame_count == 0:
            return False, "Video appears to be empty (0 frames)"

        # A negative count means the backend could not determine it; without
        # it the duration limit below cannot be enforced.
        if frame_count < 0:
            return False, "Cannot determine video frame count"

        # Written as a positive range test so NaN fps is rejected too.
        if not 0 < fps <= 120:
            return False, f"Invalid frame rate: {fps}"

        if width <= 0 or height <= 0:
            return False, f"Invalid resolution: {width}x{height}"

        if width > 4096 or height > 4096:
            return False, f"Resolution too high: {width}x{height} (max 4096x4096)"

        duration = frame_count / fps
        if duration > 300:
            return False, f"Video too long: {duration:.1f}s (max 300s)"

        return True, f"Valid video: {width}x{height}, {fps:.1f}fps, {duration:.1f}s"

    except Exception as e:
        return False, f"Error validating video: {str(e)}"