| | """ |
| | Fusion service for combining submodel predictions. |
| | """ |
| |
|
| | from typing import Any, Dict |
| |
|
| | from app.core.errors import FusionError |
| | from app.core.logging import get_logger |
| | from app.services.model_registry import get_model_registry |
| |
|
# Module-scoped logger obtained from the project's logging helper
# (``app.core.logging.get_logger``), named after this module.
logger = get_logger(__name__)
| |
|
| |
|
class FusionService:
    """
    Service for running fusion predictions.

    Thin wrapper around the fusion model held by the model registry: it
    forwards submodel outputs to the fusion model and normalises any
    unexpected failure into a ``FusionError``.
    """

    def __init__(self):
        # The registry handle is captured once per service instance; the
        # fusion model itself is looked up from it on every ``fuse`` call.
        self._registry = get_model_registry()

    def fuse(
        self,
        submodel_outputs: Dict[str, Dict[str, Any]],
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """
        Run fusion on submodel outputs.

        Args:
            submodel_outputs: Dictionary mapping submodel name to its
                prediction output.
            **kwargs: Additional arguments, forwarded unchanged to the fusion
                model's ``predict`` method.

        Returns:
            The result of the fusion model's ``predict`` call (expected to be
            a standardized prediction dictionary).

        Raises:
            FusionError: If fetching the fusion model or running the
                prediction fails. A ``FusionError`` raised by the registry or
                the model is propagated unchanged; any other exception is
                wrapped, with the original preserved as ``__cause__``.
        """
        try:
            fusion = self._registry.get_fusion()
            return fusion.predict(submodel_outputs=submodel_outputs, **kwargs)
        except FusionError:
            # Already a domain error: re-raise as-is rather than wrapping twice.
            raise
        except Exception as e:
            logger.error(f"Fusion failed: {e}")
            # Chain explicitly so the original exception and its traceback
            # remain attached to the FusionError seen by callers.
            raise FusionError(
                message="Fusion prediction failed",
                details={"error": str(e)},
            ) from e
| |
|
| |
|
| | |
# Cached FusionService shared by every caller of get_fusion_service();
# stays None until the first call creates it.
_fusion_service = None


def get_fusion_service() -> FusionService:
    """
    Return the shared FusionService, creating it on first use.

    Returns:
        The module-wide FusionService instance.
    """
    global _fusion_service
    if _fusion_service is not None:
        # Already built on an earlier call: reuse it.
        return _fusion_service
    _fusion_service = FusionService()
    return _fusion_service
| |
|