""" Abstract base class for segmentation models. """ from abc import ABC, abstractmethod from typing import List, Optional, Union, Dict, Any import numpy as np from PIL import Image class BaseSegmenter(ABC): """ Abstract base class for all segmentation models. This class defines the common interface that all segmentation models must implement. Subclasses can handle different types of inputs: - Class-based segmentation (YOLO, SegFormer): List of class names - Natural language segmentation (SAM, CLIP-based): Text prompts - Point/box-based segmentation (SAM): Coordinates """ def __init__(self, device: str = 'cuda', **kwargs): """ Initialize the segmenter. Args: device: Device to run inference on ('cuda' or 'cpu') **kwargs: Model-specific parameters """ self.device = device self.model = None self._is_loaded = False @abstractmethod def load_model(self): """ Load the segmentation model. This method should load the model weights and prepare the model for inference. Called automatically before first use. """ pass @abstractmethod def segment( self, image: Image.Image, target_classes: Optional[List[str]] = None, **kwargs ) -> np.ndarray: """ Create binary segmentation mask for ROI. Args: image: PIL Image to segment target_classes: List of target classes or text prompts **kwargs: Model-specific parameters (e.g., confidence threshold) Returns: Binary mask as numpy array (H, W) with values 0 or 1 - 1: Region of Interest (ROI) - 0: Background """ pass @abstractmethod def get_available_classes(self) -> Union[List[str], Dict[str, int]]: """ Get list or mapping of classes this model can segment. Returns: List of class names or dict mapping class names to IDs """ pass def validate_classes(self, target_classes: Optional[List[str]]) -> List[str]: """ Validate and filter target classes against available classes. Args: target_classes: List of requested class names Returns: List of valid class names """ if target_classes is None: return self.get_default_classes() available_classes = self.get_available_classes() if isinstance(available_classes, dict): available_classes = list(available_classes.keys()) valid_classes = [] for cls in target_classes: cls_lower = cls.lower() if cls_lower in [c.lower() for c in available_classes]: valid_classes.append(cls) else: print(f"Warning: '{cls}' not in {self.__class__.__name__} classes.") if not valid_classes: print(f"Warning: No valid classes found. Using defaults.") valid_classes = self.get_default_classes() return valid_classes def get_default_classes(self) -> List[str]: """ Get default classes to segment if none specified. Returns: List of default class names """ return ['car'] # Default fallback def ensure_loaded(self): """Ensure model is loaded before use.""" if not self._is_loaded: self.load_model() self._is_loaded = True def __call__( self, image: Image.Image, target_classes: Optional[List[str]] = None, **kwargs ) -> np.ndarray: """ Convenience method to call segment(). Args: image: PIL Image to segment target_classes: List of target classes or text prompts **kwargs: Model-specific parameters Returns: Binary mask as numpy array (H, W) """ self.ensure_loaded() return self.segment(image, target_classes, **kwargs)