| """ |
| Factory for creating segmentation models. |
| """ |
|
|
| from typing import Dict, Type, Optional, List |
| from .base import BaseSegmenter |
| from .segformer import SegFormerSegmenter |
| from .yolo import YOLOSegmenter |
| from .mask2former import Mask2FormerSegmenter |
| from .maskrcnn import MaskRCNNSegmenter |
| from .sam3 import SAM3Segmenter |
|
|
|
|
| |
# Built-in segmentation backends, keyed by the lowercase method name that
# create_segmenter() looks up. register_segmenter() adds entries at runtime.
SEGMENTER_REGISTRY: Dict[str, Type[BaseSegmenter]] = dict(
    segformer=SegFormerSegmenter,
    yolo=YOLOSegmenter,
    mask2former=Mask2FormerSegmenter,
    maskrcnn=MaskRCNNSegmenter,
    sam3=SAM3Segmenter,
)
|
|
|
|
def register_segmenter(name: str, segmenter_class: Type[BaseSegmenter]) -> None:
    """
    Register a new segmentation method.

    The name is stored lowercased. Registering a name that is already
    present replaces the class previously stored under it.

    Args:
        name: Method name (e.g., 'sam', 'drone_detector')
        segmenter_class: Segmenter class that extends BaseSegmenter

    Raises:
        TypeError: If segmenter_class is not a class (e.g. an instance)
        ValueError: If segmenter_class is a class that does not extend
            BaseSegmenter
    """
    # issubclass() rejects non-classes with a generic TypeError of its own;
    # check first so the caller gets a message naming the offending argument.
    if not isinstance(segmenter_class, type):
        raise TypeError(
            f"segmenter_class must be a class extending BaseSegmenter, "
            f"got {segmenter_class!r}"
        )
    if not issubclass(segmenter_class, BaseSegmenter):
        raise ValueError(f"{segmenter_class} must extend BaseSegmenter")
    SEGMENTER_REGISTRY[name.lower()] = segmenter_class
|
|
|
|
def create_segmenter(
    method: str,
    device: str = 'cuda',
    **kwargs
) -> BaseSegmenter:
    """
    Build a segmenter instance for the named method.

    The method name is lowercased before it is looked up in
    SEGMENTER_REGISTRY, so the lookup is case-insensitive.

    Args:
        method: Segmentation method name ('segformer', 'yolo', etc.)
        device: Device to run on ('cuda' or 'cpu')
        **kwargs: Method-specific parameters, forwarded to the segmenter
            class constructor together with ``device``

    Returns:
        Instance of the requested segmenter

    Raises:
        ValueError: If method is not present in the registry

    Example:
        >>> segmenter = create_segmenter('yolo', device='cuda', conf_threshold=0.3)
        >>> mask = segmenter(image, target_classes=['car', 'person'])
    """
    key = method.lower()

    # Success path first: a registered name is instantiated right away.
    if key in SEGMENTER_REGISTRY:
        return SEGMENTER_REGISTRY[key](device=device, **kwargs)

    # Unknown name: report it as the caller spelled it, with the valid choices.
    known_methods = ', '.join(SEGMENTER_REGISTRY)
    raise ValueError(
        f"Unknown segmentation method: '{method}'. "
        f"Available methods: {known_methods}"
    )
|
|
|
|
def get_available_methods() -> List[str]:
    """
    Return the names of all currently registered segmentation methods.

    Returns:
        A new list of the registry's keys, in registry order
    """
    # Unpacking a dict yields its keys; the result is a fresh list each call.
    return [*SEGMENTER_REGISTRY]
|
|