# Last commit: 398659b "Add code + LFS attributes" (author: Raheeb Hassan)
"""
Abstract base class for segmentation models.
"""
from abc import ABC, abstractmethod
from typing import List, Optional, Union, Dict, Any
import numpy as np
from PIL import Image
class BaseSegmenter(ABC):
    """
    Abstract base class for all segmentation models.

    This class defines the common interface that all segmentation models
    must implement. Subclasses can handle different types of inputs:
    - Class-based segmentation (YOLO, SegFormer): List of class names
    - Natural language segmentation (SAM, CLIP-based): Text prompts
    - Point/box-based segmentation (SAM): Coordinates

    Subclasses must implement ``load_model``, ``segment`` and
    ``get_available_classes``. Calling an instance
    (``segmenter(image, target_classes)``) loads the model on first use via
    ``ensure_loaded`` and then delegates to ``segment``.
    """

    def __init__(self, device: str = 'cuda', **kwargs):
        """
        Initialize the segmenter.

        Args:
            device: Device to run inference on ('cuda' or 'cpu')
            **kwargs: Model-specific parameters. The base class neither stores
                nor uses them; subclasses are expected to consume them.
        """
        self.device = device
        # Placeholder; presumably assigned by a subclass's load_model().
        self.model = None
        # Set by ensure_loaded() once load_model() has returned successfully.
        self._is_loaded = False

    @abstractmethod
    def load_model(self):
        """
        Load the segmentation model.

        This method should load the model weights and prepare the model
        for inference. It is invoked by ``ensure_loaded`` (and therefore by
        ``__call__``) before first use.
        """
        pass

    @abstractmethod
    def segment(
        self,
        image: Image.Image,
        target_classes: Optional[List[str]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Create binary segmentation mask for ROI.

        Args:
            image: PIL Image to segment
            target_classes: List of target classes or text prompts
            **kwargs: Model-specific parameters (e.g., confidence threshold)

        Returns:
            Binary mask as numpy array (H, W) with values 0 or 1
            - 1: Region of Interest (ROI)
            - 0: Background
        """
        pass

    @abstractmethod
    def get_available_classes(self) -> Union[List[str], Dict[str, int]]:
        """
        Get list or mapping of classes this model can segment.

        Returns:
            List of class names or dict mapping class names to IDs
        """
        pass

    def validate_classes(
        self, target_classes: Optional[Union[str, List[str]]]
    ) -> List[str]:
        """
        Validate and filter target classes against available classes.

        Matching is case-insensitive. Each accepted class is returned in the
        spelling used by ``get_available_classes()``, so the result can be
        used directly as a key into a class-name -> ID mapping.

        Args:
            target_classes: Requested class names. ``None`` selects the
                defaults; a single string is treated as a one-item list.

        Returns:
            List of valid class names in their canonical spelling. Unknown
            names are reported with a printed warning and dropped. If no
            requested name matches, ``get_default_classes()`` is returned.
        """
        if target_classes is None:
            return self.get_default_classes()
        # A bare string would otherwise be iterated character by character.
        if isinstance(target_classes, str):
            target_classes = [target_classes]

        # Build the lowercase -> canonical lookup once. Iterating a dict yields
        # its keys, so this covers both return types of get_available_classes().
        # setdefault keeps the first spelling if two names differ only in case.
        canonical: Dict[str, str] = {}
        for name in self.get_available_classes():
            canonical.setdefault(name.lower(), name)

        valid_classes = []
        for cls in target_classes:
            match = canonical.get(cls.lower())
            if match is not None:
                valid_classes.append(match)
            else:
                print(f"Warning: '{cls}' not in {self.__class__.__name__} classes.")

        if not valid_classes:
            print("Warning: No valid classes found. Using defaults.")
            valid_classes = self.get_default_classes()
        return valid_classes

    def get_default_classes(self) -> List[str]:
        """
        Get default classes to segment if none specified.

        Subclasses may override this; the base implementation returns a new
        list each call, so callers may mutate the result safely.

        Returns:
            List of default class names
        """
        return ['car']  # Default fallback

    def ensure_loaded(self):
        """
        Ensure the model is loaded before use.

        Calls ``load_model()`` only on the first invocation. ``_is_loaded`` is
        set only after ``load_model()`` returns, so if loading raises, the next
        call will try again.
        """
        if not self._is_loaded:
            self.load_model()
            self._is_loaded = True

    def __call__(
        self,
        image: Image.Image,
        target_classes: Optional[List[str]] = None,
        **kwargs
    ) -> np.ndarray:
        """
        Convenience method: load the model if needed, then call segment().

        Args:
            image: PIL Image to segment
            target_classes: List of target classes or text prompts
            **kwargs: Model-specific parameters

        Returns:
            Binary mask as numpy array (H, W), as produced by ``segment``
        """
        self.ensure_loaded()
        return self.segment(image, target_classes, **kwargs)