"""
Abstract base class for segmentation models.
"""

import warnings
from abc import ABC, abstractmethod
from typing import List, Optional, Union, Dict, Any

import numpy as np
from PIL import Image


class BaseSegmenter(ABC):
    """
    Abstract base class for all segmentation models.

    This class defines the common interface that all segmentation models
    must implement. Subclasses can handle different types of inputs:
    - Class-based segmentation (YOLO, SegFormer): List of class names
    - Natural language segmentation (SAM, CLIP-based): Text prompts
    - Point/box-based segmentation (SAM): Coordinates

    Subclasses implement ``load_model``, ``segment`` and
    ``get_available_classes``. Calling an instance (``segmenter(image)``)
    loads the model on first use and then delegates to ``segment``.
    """

    def __init__(self, device: str = 'cuda', **kwargs: Any) -> None:
        """
        Initialize the segmenter.

        The model is not loaded here; loading is deferred until
        ``ensure_loaded`` (or ``__call__``) is first invoked.

        Args:
            device: Device to run inference on ('cuda' or 'cpu')
            **kwargs: Model-specific parameters. This base class ignores
                them; subclasses presumably consume the ones they need.
        """
        self.device = device
        self.model = None
        # Flipped to True by ensure_loaded() once load_model() has returned.
        self._is_loaded = False

    @abstractmethod
    def load_model(self) -> None:
        """
        Load the segmentation model.

        This method should load the model weights and prepare the model
        for inference. It is invoked by ``ensure_loaded`` (and therefore by
        ``__call__``) the first time the segmenter is used; calling
        ``segment`` directly does NOT trigger it.
        """

    @abstractmethod
    def segment(
        self,
        image: Image.Image,
        target_classes: Optional[List[str]] = None,
        **kwargs: Any
    ) -> np.ndarray:
        """
        Create binary segmentation mask for ROI.

        This method does not load the model itself. Call the segmenter
        instance (or ``ensure_loaded`` first) to get lazy loading.

        Args:
            image: PIL Image to segment
            target_classes: List of target classes or text prompts.
                Implementations that pass this through ``validate_classes``
                fall back to ``get_default_classes()`` when it is ``None``.
            **kwargs: Model-specific parameters (e.g., confidence threshold)

        Returns:
            Binary mask as numpy array (H, W) with values 0 or 1
            - 1: Region of Interest (ROI)
            - 0: Background
        """

    @abstractmethod
    def get_available_classes(self) -> Union[List[str], Dict[str, int]]:
        """
        Get list or mapping of classes this model can segment.

        Returns:
            List of class names or dict mapping class names to IDs.
            ``validate_classes`` only looks at the names (the dict keys).
        """

    def validate_classes(self, target_classes: Optional[List[str]]) -> List[str]:
        """
        Validate and filter target classes against available classes.

        Matching is case-insensitive. Each accepted name is returned in the
        spelling used by ``get_available_classes()``, so it can be used
        directly as a key into the model's class mapping. Unknown names are
        dropped and reported with a ``UserWarning``.

        Args:
            target_classes: List of requested class names, or ``None`` to
                use ``get_default_classes()``.

        Returns:
            List of valid class names in request order. If none of the
            requested names are valid (including an empty request), a
            warning is emitted and ``get_default_classes()`` is returned.
        """
        if target_classes is None:
            return self.get_default_classes()

        # Build the case-insensitive lookup once, rather than re-lowering
        # every available class for each requested name. Iterating a dict
        # yields its keys, so this covers both return forms of
        # get_available_classes(). If two names differ only by case, the
        # first spelling wins.
        canonical_by_lower: Dict[str, str] = {}
        for available_name in self.get_available_classes():
            canonical_by_lower.setdefault(available_name.lower(), available_name)

        valid_classes: List[str] = []
        for requested in target_classes:
            canonical = canonical_by_lower.get(requested.lower())
            if canonical is None:
                # warnings.warn (not print): the caller can fix this, and the
                # message goes through the standard, filterable channel.
                warnings.warn(
                    f"'{requested}' not in {type(self).__name__} classes.",
                    stacklevel=2,
                )
            else:
                valid_classes.append(canonical)

        if not valid_classes:
            warnings.warn("No valid classes found. Using defaults.", stacklevel=2)
            valid_classes = self.get_default_classes()

        return valid_classes

    def get_default_classes(self) -> List[str]:
        """
        Get default classes to segment if none specified.

        Override in subclasses whose class vocabulary does not contain
        ``'car'``. A new list is returned on every call, so callers may
        mutate it freely.

        Returns:
            List of default class names
        """
        return ['car']

    def ensure_loaded(self) -> None:
        """
        Load the model on first use; later calls are no-ops.

        ``_is_loaded`` is only set after ``load_model()`` returns, so if
        loading raises, the next call will attempt to load again.
        """
        if not self._is_loaded:
            self.load_model()
            self._is_loaded = True

    def __call__(
        self,
        image: Image.Image,
        target_classes: Optional[List[str]] = None,
        **kwargs: Any
    ) -> np.ndarray:
        """
        Load the model if needed, then run ``segment()``.

        Args:
            image: PIL Image to segment
            target_classes: List of target classes or text prompts
            **kwargs: Model-specific parameters, forwarded to ``segment``

        Returns:
            Binary mask as numpy array (H, W)
        """
        self.ensure_loaded()
        return self.segment(image, target_classes, **kwargs)