"""
Inference service for running model predictions.
"""
from typing import Any, Dict, Optional
from PIL import Image
from app.core.errors import InferenceError, ModelNotFoundError
from app.core.logging import get_logger
from app.services.model_registry import get_model_registry
from app.utils.timing import Timer
# Module-level logger named after this module, per the project's logging setup.
logger = get_logger(__name__)
class InferenceService:
    """
    Service for running inference on individual submodels.

    Thin orchestration layer over the model registry: it resolves submodels
    by key, delegates prediction to them, and wraps unexpected failures in
    InferenceError so callers see a uniform error type.
    """

    def __init__(self) -> None:
        # Resolve the shared registry once; submodels are looked up per call.
        self._registry = get_model_registry()

    def predict_single(
        self,
        model_key: str,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Run prediction on a single submodel.

        Args:
            model_key: Submodel name or repo_id
            image: PIL Image object
            image_bytes: Raw image bytes (alternative to image)
            **kwargs: Additional arguments forwarded to the model

        Returns:
            Standardized prediction dictionary

        Raises:
            ModelNotFoundError: If model not found
            InferenceError: If prediction fails
        """
        try:
            submodel = self._registry.get_submodel(model_key)
            return submodel.predict(image=image, image_bytes=image_bytes, **kwargs)
        except ModelNotFoundError:
            # Let lookup failures propagate untouched so callers can
            # distinguish "unknown model" from "prediction failed".
            raise
        except Exception as e:
            # logger.exception records the full traceback; lazy %-style args
            # avoid string formatting when the log level is disabled.
            logger.exception("Inference failed for %s", model_key)
            # Chain the cause (``from e``) so the original traceback is
            # preserved on the wrapped error.
            raise InferenceError(
                message=f"Inference failed for model {model_key}",
                details={"model": model_key, "error": str(e)}
            ) from e

    def predict_all_submodels(
        self,
        image: Optional[Image.Image] = None,
        image_bytes: Optional[bytes] = None,
        **kwargs
    ) -> Dict[str, Dict[str, Any]]:
        """
        Run prediction on all loaded submodels.

        Args:
            image: PIL Image object
            image_bytes: Raw image bytes (alternative to image)
            **kwargs: Additional arguments forwarded to the models

        Returns:
            Dictionary mapping submodel name to prediction result

        Raises:
            InferenceError: If any prediction fails (fail-fast: the first
                failing submodel aborts the whole batch)
        """
        submodels = self._registry.get_all_submodels()
        results: Dict[str, Dict[str, Any]] = {}
        for name, submodel in submodels.items():
            try:
                results[name] = submodel.predict(
                    image=image, image_bytes=image_bytes, **kwargs
                )
            except Exception as e:
                logger.exception("Inference failed for submodel %s", name)
                # Chain the cause so the original traceback is preserved.
                raise InferenceError(
                    message=f"Inference failed for submodel {name}",
                    details={"model": name, "error": str(e)}
                ) from e
        return results
# Module-level singleton; created lazily by get_inference_service().
_inference_service: Optional[InferenceService] = None


def get_inference_service() -> InferenceService:
    """
    Return the process-wide InferenceService, creating it on first use.

    Returns:
        InferenceService instance (same object on every call)
    """
    global _inference_service
    service = _inference_service
    if service is None:
        service = InferenceService()
        _inference_service = service
    return service
|