| | """ |
| | Fallback strategies for BackgroundFX Pro. |
| | Implements robust fallback mechanisms when primary processing fails. |
| | """ |
| |
|
| | import cv2 |
| | import numpy as np |
| | import torch |
| | from typing import Dict, List, Optional, Tuple, Any |
| | from dataclasses import dataclass |
| | from enum import Enum |
| | import logging |
| | import traceback |
| |
|
| | from ..utils.logger import setup_logger |
| | from ..utils.device import DeviceManager |
| | from ..utils.config import ConfigManager |
| | from ..core.quality import QualityAnalyzer |
| |
|
# Module-level logger. FallbackStrategy logs through this name, while
# ProcessingFallback creates its own logger named "<module>.ProcessingFallback".
logger = setup_logger(__name__)
| |
|
| |
|
class FallbackLevel(Enum):
    """Degradation levels, ordered from untouched processing to passthrough.

    The integer value grows with how much the processing has been degraded:
    ``NONE`` means the primary path ran unchanged, and ``PASSTHROUGH`` means
    every attempt failed and the input is handed back as-is.
    """

    # Primary processing succeeded without any fallback.
    NONE = 0
    # An input image was shrunk or a quality setting was lowered.
    QUALITY_REDUCTION = 1
    # A different way of running was chosen (e.g. CPU instead of GPU, a smaller model).
    METHOD_SWITCH = 2
    # Optional refinement stages were switched off or the model was bypassed.
    BASIC_PROCESSING = 3
    # Only the cheapest processing options are requested.
    MINIMAL_PROCESSING = 4
    # Every attempt failed; the input (when available) is returned unchanged.
    PASSTHROUGH = 5
| |
|
| |
|
@dataclass
class FallbackConfig:
    """Tunable knobs for :class:`FallbackStrategy`.

    Field order is part of the public constructor signature (dataclass
    positional arguments), so it must not be rearranged.
    """

    # Maximum number of times the wrapped function is invoked.
    max_retries: int = 3
    # Multiplier applied to image size / quality on each reduction step.
    quality_reduction_factor: float = 0.75
    # Floor for any "quality" keyword argument after reduction.
    min_quality: float = 0.3
    # NOTE(review): enable_caching, cache_size and timeout_seconds are not
    # read by FallbackStrategy as written; confirm whether other code uses them.
    enable_caching: bool = True
    cache_size: int = 10
    timeout_seconds: float = 30.0
    # When False, GPU/CUDA errors are not handled and go straight to the final fallback.
    gpu_fallback_to_cpu: bool = True
    # Allow shrinking the input image when a memory error is detected.
    progressive_downscale: bool = True
    # Smallest (width, height) an image may be downscaled to.
    min_resolution: Tuple[int, int] = (320, 240)
| |
|
| |
|
class FallbackStrategy:
    """Intelligent fallback strategy manager.

    ``execute_with_fallback`` calls a function up to ``config.max_retries``
    times. After each failure the exception is classified by keywords in its
    type name and message (GPU/CUDA, memory, timeout, model, or generic). A
    matching handler then rewrites the positional/keyword arguments for the
    next attempt: switch to CPU, shrink the image, disable refinements, pick
    a smaller model, and so on. Degradations compound across attempts. If
    every attempt fails, the first ``np.ndarray`` positional argument is
    returned unchanged with ``success=False``.

    ``current_level`` holds the :class:`FallbackLevel` set by the most
    recently applied handler. ``fallback_history`` records one entry per
    failed attempt (it is never trimmed).
    """

    def __init__(self, config: Optional[FallbackConfig] = None):
        self.config = config or FallbackConfig()
        self.device_manager = DeviceManager()
        self.quality_analyzer = QualityAnalyzer()
        self.cache = {}
        self.fallback_history = []
        self.current_level = FallbackLevel.NONE

    @staticmethod
    def _callable_name(func) -> str:
        """Best-effort display name for *func*.

        ``functools.partial`` objects and callable instances have no
        ``__name__``, so fall back to their repr instead of raising.
        """
        return getattr(func, '__name__', None) or repr(func)

    def _reduced_quality(self, quality: float, factor: float) -> float:
        """Scale *quality* by *factor*, floored at ``config.min_quality``.

        The result never exceeds the incoming value, so a caller-supplied
        quality that is already below the floor is not raised.
        """
        return min(quality, max(self.config.min_quality, quality * factor))

    def execute_with_fallback(self, func, *args, **kwargs) -> Dict[str, Any]:
        """
        Execute function with automatic fallback on failure.

        Args:
            func: Function to execute
            *args: Function arguments
            **kwargs: Function keyword arguments

        Returns:
            Result dictionary. On success: ``success`` (True), ``result``,
            ``attempts`` and ``fallback_level`` (the level at which the call
            finally succeeded). On total failure: see ``_final_fallback``.
        """
        name = self._callable_name(func)
        max_retries = self.config.max_retries
        original_args = args
        last_error: Optional[Exception] = None
        # Start from a clean slate so a level left over from a previous call
        # is not reported for this one.
        self.current_level = FallbackLevel.NONE

        attempt = 0
        while attempt < max_retries:
            logger.info(f"Attempt {attempt + 1}/{max_retries} for {name}")
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                last_error = e
                logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")

                # Degrade from the *current* inputs rather than the originals,
                # so repeated failures compound (e.g. progressive downscaling).
                # kwargs is copied so earlier attempts' dicts stay untouched.
                fallback_result = self._apply_fallback(
                    func, e, attempt, args, dict(kwargs)
                )
                if not fallback_result['handled']:
                    break
                args = fallback_result.get('new_args', args)
                kwargs = fallback_result.get('new_kwargs', kwargs)
            else:
                # Report the level that was needed to succeed, then reset the
                # tracker for subsequent calls.
                achieved_level = self.current_level
                self.current_level = FallbackLevel.NONE
                return {
                    'success': True,
                    'result': result,
                    'attempts': attempt + 1,
                    'fallback_level': achieved_level,
                }
            attempt += 1

        logger.error(f"All attempts failed for {name}")
        return self._final_fallback(func, last_error, original_args)

    def _apply_fallback(self, func, error: Exception,
                        attempt: int, args: tuple,
                        kwargs: dict) -> Dict[str, Any]:
        """Apply appropriate fallback strategy based on error type.

        Args:
            func: The function that failed (used for history/logging only).
            error: The exception raised by the attempt.
            attempt: Zero-based index of the failed attempt.
            args: Positional arguments used by the failed attempt.
            kwargs: A private copy of the failed attempt's keyword arguments;
                handlers may mutate and return it.

        Returns:
            ``{'handled': bool, 'new_args'?: tuple, 'new_kwargs'?: dict}``
        """
        error_type = type(error).__name__
        self.fallback_history.append({
            'function': self._callable_name(func),
            'error': error_type,
            'attempt': attempt
        })

        # Include the type name: errors such as a bare MemoryError() or
        # TimeoutError() have an empty str() and would otherwise be
        # misrouted to the generic handler.
        message = f"{error_type}: {error}"
        lowered = message.lower()

        if 'CUDA' in message or 'GPU' in message:
            return self._handle_gpu_error(kwargs)
        if 'memory' in lowered:
            return self._handle_memory_error(args, kwargs)
        if 'timeout' in lowered:
            return self._handle_timeout_error(kwargs)
        if 'model' in lowered:
            return self._handle_model_error(kwargs)
        return self._handle_generic_error(attempt, kwargs)

    def _handle_gpu_error(self, kwargs: dict) -> Dict[str, Any]:
        """Handle GPU-related errors by moving to CPU and halving batch size."""
        logger.info("GPU error detected, falling back to CPU")

        if not self.config.gpu_fallback_to_cpu:
            return {'handled': False}

        self.device_manager.device = torch.device('cpu')
        # NOTE(review): 'device' is injected unconditionally; the wrapped
        # function must accept it (or **kwargs) or the retry raises TypeError.
        kwargs['device'] = 'cpu'

        if 'batch_size' in kwargs:
            kwargs['batch_size'] = max(1, kwargs['batch_size'] // 2)

        self.current_level = FallbackLevel.METHOD_SWITCH
        return {
            'handled': True,
            'new_kwargs': kwargs
        }

    def _handle_memory_error(self, args: tuple,
                             kwargs: dict) -> Dict[str, Any]:
        """Handle memory errors by downscaling the image or lowering quality."""
        logger.info("Memory error detected, reducing quality")
        factor = self.config.quality_reduction_factor

        # The first positional 3-D ndarray is treated as the image to shrink.
        image_idx = next(
            (i for i, arg in enumerate(args)
             if isinstance(arg, np.ndarray) and arg.ndim == 3),
            None,
        )

        if image_idx is not None and self.config.progressive_downscale:
            image = args[image_idx]
            h, w = image.shape[:2]
            min_w, min_h = self.config.min_resolution

            # Shrink by `factor`, never below min_resolution, and never
            # *upscale* a dimension that is already smaller than the minimum.
            new_h = min(h, max(int(h * factor), min_h))
            new_w = min(w, max(int(w * factor), min_w))

            if new_h < h or new_w < w:
                new_args = list(args)
                new_args[image_idx] = cv2.resize(image, (new_w, new_h))
                self.current_level = FallbackLevel.QUALITY_REDUCTION
                return {
                    'handled': True,
                    'new_args': tuple(new_args),
                    'new_kwargs': kwargs
                }

        # Image absent or already at minimum size: lower the quality setting.
        if 'quality' in kwargs:
            kwargs['quality'] = self._reduced_quality(kwargs['quality'], factor)
            self.current_level = FallbackLevel.QUALITY_REDUCTION

        return {
            'handled': True,
            'new_kwargs': kwargs
        }

    def _handle_timeout_error(self, kwargs: dict) -> Dict[str, Any]:
        """Handle timeout errors by simplifying processing.

        Only options the caller already passed are overridden; no new keys
        are introduced.
        """
        logger.info("Timeout detected, simplifying processing")

        simplifications = {
            'use_refinement': False,
            'use_temporal': False,
            'use_guided_filter': False,
            'iterations': 1,
            'num_samples': 1
        }

        for key, value in simplifications.items():
            if key in kwargs:
                kwargs[key] = value

        self.current_level = FallbackLevel.BASIC_PROCESSING
        return {
            'handled': True,
            'new_kwargs': kwargs
        }

    def _handle_model_error(self, kwargs: dict) -> Dict[str, Any]:
        """Handle model errors by stepping down to a smaller model.

        If no smaller model is available (or ``model_type`` is absent or
        unknown), request model-free processing via ``use_model=False``.
        """
        logger.info("Model error detected, using simpler model")

        if 'model_type' in kwargs:
            model_hierarchy = ['large', 'base', 'small', 'tiny']
            current = kwargs.get('model_type', 'base')

            if current in model_hierarchy:
                idx = model_hierarchy.index(current)
                if idx < len(model_hierarchy) - 1:
                    kwargs['model_type'] = model_hierarchy[idx + 1]
                    self.current_level = FallbackLevel.METHOD_SWITCH
                    return {
                        'handled': True,
                        'new_kwargs': kwargs
                    }

        # NOTE(review): injected unconditionally; the wrapped function must
        # accept 'use_model' (or **kwargs).
        kwargs['use_model'] = False
        self.current_level = FallbackLevel.BASIC_PROCESSING
        return {
            'handled': True,
            'new_kwargs': kwargs
        }

    def _handle_generic_error(self, attempt: int,
                              kwargs: dict) -> Dict[str, Any]:
        """Handle unclassified errors with degradation keyed on the attempt index.

        attempt 0 -> lower quality, attempt 1 -> ``method='basic'``,
        later -> ``skip_refinement`` / ``fast_mode``.
        """
        logger.info(f"Generic error, applying degradation level {attempt + 1}")

        if attempt == 0:
            self.current_level = FallbackLevel.QUALITY_REDUCTION
            if 'quality' in kwargs:
                # Respect the configured quality floor, like the memory handler.
                kwargs['quality'] = self._reduced_quality(kwargs['quality'], 0.8)
        elif attempt == 1:
            self.current_level = FallbackLevel.METHOD_SWITCH
            kwargs['method'] = 'basic'
        else:
            self.current_level = FallbackLevel.MINIMAL_PROCESSING
            kwargs['skip_refinement'] = True
            kwargs['fast_mode'] = True

        return {
            'handled': True,
            'new_kwargs': kwargs
        }

    def _final_fallback(self, func, error: Optional[Exception],
                        original_args: tuple) -> Dict[str, Any]:
        """Apply the final passthrough fallback when all attempts fail.

        Returns:
            ``success`` False, ``result`` set to the first ``np.ndarray``
            positional argument (or None if there is none),
            ``fallback_level`` PASSTHROUGH, and ``error`` as a string. The
            error is None when no attempt ran, e.g. ``max_retries <= 0``.
        """
        logger.error(f"Final fallback for {self._callable_name(func)}: {str(error)}")
        self.current_level = FallbackLevel.PASSTHROUGH

        passthrough = next(
            (arg for arg in original_args if isinstance(arg, np.ndarray)),
            None,
        )
        return {
            'success': False,
            'result': passthrough,
            'fallback_level': self.current_level,
            'error': str(error)
        }
| |
|
| |
|
class ProcessingFallback:
    """Model-free segmentation and matting fallbacks built on classic OpenCV.

    Intended for use when the ML pipeline is unavailable or has failed. Each
    public method catches its own exceptions, logs them through
    ``self.logger`` and returns a conservative default (a centred ellipse
    mask, an all-ones matte, ``None``, ...) instead of raising.
    """

    # ``cached_result`` keeps at most this many entries. When exceeded, the
    # _CACHE_EVICT_COUNT oldest insertions are dropped. This is FIFO: a cache
    # hit does not refresh an entry's position.
    _CACHE_MAX_ENTRIES = 100
    _CACHE_EVICT_COUNT = 20

    def __init__(self):
        self.logger = setup_logger(f"{__name__}.ProcessingFallback")
        self.quality_analyzer = QualityAnalyzer()

    @staticmethod
    def _to_bgr(image: np.ndarray) -> np.ndarray:
        """Return *image* as 3 channels, as ``cv2.grabCut`` requires.

        Grayscale (2-D or single-channel) and BGRA inputs are converted; any
        other layout is returned as-is.
        """
        if image.ndim == 2 or image.shape[2] == 1:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        if image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        return image

    def basic_segmentation(self, image: np.ndarray) -> np.ndarray:
        """
        Basic segmentation using GrabCut, used as fallback when ML models fail.

        Args:
            image: Input image (BGR, BGRA or grayscale; grabCut needs 8-bit data)

        Returns:
            uint8 mask with 255 for foreground and 0 for background. On any
            error a centred ellipse mask of the same height/width is returned.
        """
        try:
            h, w = image.shape[:2]
            bgr = self._to_bgr(image)

            mask = np.zeros((h, w), np.uint8)
            bgd_model = np.zeros((1, 65), np.float64)
            fgd_model = np.zeros((1, 65), np.float64)

            # Seed GrabCut with a rectangle covering the central 80% of the
            # frame, i.e. assume the subject is roughly centred.
            rect = (int(w * 0.1), int(h * 0.1),
                    int(w * 0.8), int(h * 0.8))

            cv2.grabCut(bgr, mask, rect, bgd_model, fgd_model,
                        5, cv2.GC_INIT_WITH_RECT)

            # Labels 0 (GC_BGD) and 2 (GC_PR_BGD) are background.
            return np.where((mask == 2) | (mask == 0), 0, 255).astype('uint8')

        except Exception as e:
            self.logger.error(f"Basic segmentation failed: {e}")
            return self._center_blob_mask(image.shape[:2])

    def _center_blob_mask(self, shape: Tuple[int, int]) -> np.ndarray:
        """Create a centred ellipse mask (255 inside) as the ultimate fallback.

        The ellipse is centred with semi-axes of one third of the width and
        height. It is blurred then re-thresholded to smooth its boundary.
        """
        h, w = shape
        mask = np.zeros((h, w), dtype=np.uint8)

        center = (w // 2, h // 2)
        axes = (w // 3, h // 3)
        cv2.ellipse(mask, center, axes, 0, 0, 360, 255, -1)

        mask = cv2.GaussianBlur(mask, (21, 21), 10)
        _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

        return mask

    def basic_matting(self, image: np.ndarray,
                      mask: np.ndarray) -> np.ndarray:
        """
        Basic matting using morphological operations.

        Args:
            image: Input image (unused; kept for interface symmetry)
            mask: Binary mask, either uint8 in [0, 255] or float/bool in [0, 1]

        Returns:
            float32 alpha matte in [0, 1]
        """
        try:
            if mask.dtype != np.uint8:
                # Non-uint8 masks are taken to be in [0, 1]; clip so values
                # slightly outside that range saturate instead of wrapping
                # around on the uint8 cast.
                mask = np.clip(mask * 255, 0, 255).astype(np.uint8)

            # Close small holes, then remove small specks.
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
            mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

            # Soften edges to get a fractional alpha boundary.
            mask = cv2.GaussianBlur(mask, (5, 5), 2)

            return mask.astype(np.float32) / 255.0

        except Exception as e:
            self.logger.error(f"Basic matting failed: {e}")
            if mask.dtype == np.uint8:
                return mask.astype(np.float32) / 255.0
            # Conversion did not happen; the mask is already on a [0, 1] scale.
            return np.clip(mask.astype(np.float32), 0.0, 1.0)

    @staticmethod
    def _estimate_key_color(image: np.ndarray, patch: int = 10) -> np.ndarray:
        """Estimate the background colour as the mean of the four corner patches.

        Patches are ``patch`` x ``patch`` pixels, shrunk to fit smaller images.
        """
        h, w = image.shape[:2]
        ph, pw = min(patch, h), min(patch, w)
        corners = [
            image[0:ph, 0:pw],
            image[0:ph, w - pw:w],
            image[h - ph:h, 0:pw],
            image[h - ph:h, w - pw:w],
        ]
        return np.mean([np.mean(c, axis=(0, 1)) for c in corners], axis=0)

    @staticmethod
    def _color_distance(image: np.ndarray, key_color) -> np.ndarray:
        """Per-pixel distance from *key_color*, computed in float64.

        Casting first avoids unsigned wraparound when *key_color* is a uint8
        array (e.g. a pixel sampled from the image). Colour images use the
        Euclidean distance across channels; 2-D grayscale images use the
        absolute difference.
        """
        diff = (np.asarray(image, dtype=np.float64)
                - np.asarray(key_color, dtype=np.float64))
        if diff.ndim == 3:
            return np.sqrt(np.sum(diff ** 2, axis=2))
        return np.abs(diff)

    def color_difference_keying(self, image: np.ndarray,
                                key_color: Optional[np.ndarray] = None,
                                threshold: float = 30) -> np.ndarray:
        """
        Simple color difference keying for solid backgrounds.

        Args:
            image: Input image (colour or 2-D grayscale)
            key_color: Background color to remove; estimated from the image
                corners when None
            threshold: Color difference threshold

        Returns:
            float32 alpha matte (1 = foreground). All ones on error.
        """
        try:
            if key_color is None:
                key_color = self._estimate_key_color(image)

            distance = self._color_distance(image, key_color)
            mask = (distance > threshold).astype(np.float32)

            # Soften the hard threshold edge.
            return cv2.GaussianBlur(mask, (5, 5), 2)

        except Exception as e:
            self.logger.error(f"Color keying failed: {e}")
            return np.ones(image.shape[:2], dtype=np.float32)

    def edge_based_segmentation(self, image: np.ndarray) -> np.ndarray:
        """
        Edge-based segmentation as fallback.

        Detects Canny edges, closes gaps, and fills the largest external
        contour.

        Args:
            image: Input image

        Returns:
            uint8 mask with 255 inside the largest contour (all zeros if none
            was found). A centred ellipse mask is returned on error.
        """
        try:
            if len(image.shape) == 3:
                gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            else:
                gray = image

            edges = cv2.Canny(gray, 50, 150)

            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
            closed = cv2.morphologyEx(edges, cv2.MORPH_CLOSE, kernel, iterations=2)

            # OpenCV 3.x returns (image, contours, hierarchy) while 4.x
            # returns (contours, hierarchy); [-2] is the contours in both.
            found = cv2.findContours(
                closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )
            contours = found[-2]

            mask = np.zeros(gray.shape, dtype=np.uint8)
            if contours:
                largest = max(contours, key=cv2.contourArea)
                cv2.drawContours(mask, [largest], -1, 255, -1)

            return mask

        except Exception as e:
            self.logger.error(f"Edge segmentation failed: {e}")
            return self._center_blob_mask(image.shape[:2])

    def cached_result(self, cache_key: str,
                      fallback_func, *args, **kwargs) -> Any:
        """
        Try to retrieve cached result or compute with fallback.

        Args:
            cache_key: Cache identifier
            fallback_func: Function to call if not cached
            *args, **kwargs: Function arguments

        Returns:
            The cached or freshly computed result, or None if the computation
            raised. Failures are logged and not cached. Cached objects are
            returned by reference, not copied.
        """
        # Created lazily so instances built without __init__ still work.
        if not hasattr(self, '_cache'):
            self._cache = {}
        cache = self._cache

        if cache_key in cache:
            self.logger.info(f"Using cached result for {cache_key}")
            return cache[cache_key]

        try:
            result = fallback_func(*args, **kwargs)
        except Exception as e:
            self.logger.error(f"Cached computation failed: {e}")
            return None

        cache[cache_key] = result

        # Bound memory: drop the oldest insertions once over capacity.
        if len(cache) > self._CACHE_MAX_ENTRIES:
            for key in list(cache)[:self._CACHE_EVICT_COUNT]:
                del cache[key]

        return result