raheebhassan's picture
Add code + LFS attributes
398659b
"""
Factory for creating segmentation models.
"""
from typing import Dict, Type, Optional, List
from .base import BaseSegmenter
from .segformer import SegFormerSegmenter
from .yolo import YOLOSegmenter
from .mask2former import Mask2FormerSegmenter
from .maskrcnn import MaskRCNNSegmenter
from .sam3 import SAM3Segmenter
# Registry of available segmentation methods, keyed by lower-case method name.
# create_segmenter() resolves names against this mapping and
# register_segmenter() adds (or replaces) entries at runtime.
SEGMENTER_REGISTRY: Dict[str, Type[BaseSegmenter]] = dict(
    segformer=SegFormerSegmenter,
    yolo=YOLOSegmenter,
    mask2former=Mask2FormerSegmenter,
    maskrcnn=MaskRCNNSegmenter,
    sam3=SAM3Segmenter,
)
def register_segmenter(name: str, segmenter_class: Type[BaseSegmenter]):
    """
    Register a new segmentation method.

    The name is stored lower-cased, so it matches the case-insensitive lookup
    done by ``create_segmenter``. Registering a name that already exists
    replaces the previous entry.

    Args:
        name: Method name (e.g., 'sam', 'drone_detector')
        segmenter_class: Segmenter class that extends BaseSegmenter

    Raises:
        ValueError: If ``segmenter_class`` is not a class deriving from
            BaseSegmenter (this includes passing an instance or any other
            non-class object).
    """
    # issubclass() raises TypeError when its first argument is not a class
    # (e.g. an instance passed by mistake). Check isinstance(..., type) first
    # so every invalid input is reported with the documented ValueError.
    if not (
        isinstance(segmenter_class, type)
        and issubclass(segmenter_class, BaseSegmenter)
    ):
        raise ValueError(f"{segmenter_class} must extend BaseSegmenter")
    SEGMENTER_REGISTRY[name.lower()] = segmenter_class
def create_segmenter(
    method: str,
    device: str = 'cuda',
    **kwargs
) -> BaseSegmenter:
    """
    Build a segmentation model by its registered name.

    The lookup is case-insensitive: ``method`` is lower-cased before it is
    resolved against the registry.

    Args:
        method: Segmentation method name ('segformer', 'yolo', etc.)
        device: Device to run on ('cuda' or 'cpu'); forwarded to the
            segmenter constructor as the ``device`` keyword.
        **kwargs: Method-specific parameters, forwarded unchanged to the
            segmenter constructor.

    Returns:
        A new instance of the requested segmenter class.

    Raises:
        ValueError: If no segmenter is registered under ``method``; the
            message lists every registered name.

    Example:
        >>> segmenter = create_segmenter('yolo', device='cuda', conf_threshold=0.3)
        >>> mask = segmenter(image, target_classes=['car', 'person'])
    """
    key = method.lower()

    # Happy path: the name is registered, so instantiate and return it.
    if key in SEGMENTER_REGISTRY:
        chosen_cls = SEGMENTER_REGISTRY[key]
        return chosen_cls(device=device, **kwargs)

    # Unknown name: report the original spelling plus all valid choices.
    options = ', '.join(SEGMENTER_REGISTRY)
    raise ValueError(
        f"Unknown segmentation method: '{method}'. "
        f"Available methods: {options}"
    )
def get_available_methods() -> List[str]:
    """
    Return the names of all registered segmentation methods.

    Returns:
        A fresh list of method names in registry insertion order; mutating
        it does not affect the registry.
    """
    return [*SEGMENTER_REGISTRY]