File size: 5,150 Bytes
e391a84 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 | """
application/use_cases/process_and_predict.py
─────────────────────────────────────────────
ProcessAndPredictUseCase — ETL #2 orchestrator.
Responsibility:
Triggered by a queue message, this use case:
1. Fetches the full PPG signal from the database.
2. Preprocesses it (filter → normalize → segment).
3. Runs AI inference (GAN → VGTL-Net).
4. Persists the BPPrediction result.
Dependencies (constructor-injected):
• PPGRepository — to fetch the stored signal
• PredictionRepository — to store the result
• SignalProcessor — to preprocess the signal
• ModelService — to run inference
"""
from __future__ import annotations
from src.domain.entities.prediction import BPPrediction
from src.domain.exceptions.domain_exceptions import EntityNotFoundError
from src.domain.interfaces.repositories.ppg_repository import PPGRepository
from src.domain.interfaces.repositories.prediction_repository import PredictionRepository
from src.domain.interfaces.services.model_service import ModelService
from src.domain.interfaces.services.signal_processor import SignalProcessor
from src.shared.logger import get_logger
# Module-level logger, obtained from the shared logger factory under this
# module's name; used by every log call in this file.
logger = get_logger(__name__)
class ProcessAndPredictUseCase:
    """
    ETL #2: Extract → Transform → Load for AI-based BP prediction.

    Steps:
        1. Extract   — fetch the PPGSignal entity from the DB using the
           signal ID carried under the message's ``"id"`` key.
        2. Transform — run ``SignalProcessor.process()``
           (filter → normalize → segment).
        3. Load      — run ``ModelService.predict()`` → validate → persist
           the BPPrediction.

    Usage::

        use_case = ProcessAndPredictUseCase(
            ppg_repo=...,
            prediction_repo=...,
            signal_processor=...,
            model_service=...,
        )
        prediction = await use_case.execute(message_payload)
    """

    def __init__(
        self,
        ppg_repo: PPGRepository,
        prediction_repo: PredictionRepository,
        signal_processor: SignalProcessor,
        model_service: ModelService,
    ) -> None:
        """
        Store the injected collaborators; no I/O happens here.

        Args:
            ppg_repo: Repository used to fetch the stored signal.
            prediction_repo: Repository used to store the prediction.
            signal_processor: Service used to preprocess the signal.
            model_service: Service used to run inference.
        """
        self._ppg_repo = ppg_repo
        self._prediction_repo = prediction_repo
        self._signal_processor = signal_processor
        self._model_service = model_service

    async def execute(self, message: dict) -> BPPrediction:
        """
        Process one queue message end-to-end.

        Args:
            message: Dict payload from the queue. Must contain an ``"id"`` key
                matching a stored PPGSignal UUID.

        Returns:
            The BPPrediction entity returned by the prediction repository
            after it has been added.

        Raises:
            EntityNotFoundError: If the message carries no usable ``"id"`` or
                no PPGSignal matches it.
            InvalidSignalError: May propagate from the collaborators if the
                signal fails re-validation (not raised directly here).
            RuntimeError: May propagate from the model service if the model
                has not been loaded (not raised directly here).
        """
        signal_id: str = message.get("id", "")
        logger.info("ProcessAndPredictUseCase.execute — signal_id=%s", signal_id)

        # -- Step 1: EXTRACT - fetch the PPG signal from the database ---------
        signal = await self._fetch_signal(signal_id)

        # -- Step 2: TRANSFORM - preprocess the signal ------------------------
        logger.info(
            "Running signal processing pipeline (filter → normalize → segment) "
            "on signal_id=%s",
            signal_id,
        )
        segments = self._signal_processor.process(
            raw_signal=signal.ppg_values,
            sampling_rate=signal.sampling_rate,
        )
        logger.info(
            "Preprocessing complete: %d segments of shape %s",
            segments.shape[0],
            segments.shape,
        )

        # -- Step 3a: LOAD - run model inference ------------------------------
        await self._ensure_model_loaded()
        prediction = await self._model_service.predict(
            ppg_signal_id=signal.id,
            segments=segments,
        )
        logger.info(
            "Inference complete: SBP=%.1f DBP=%.1f (model=%s, time=%.1f ms)",
            prediction.predicted_sbp,
            prediction.predicted_dbp,
            prediction.model_version,
            prediction.inference_time_ms,
        )

        # Validate domain rules (physiological bounds) before anything is
        # written, so an invalid prediction is never persisted.
        prediction.validate()

        # -- Step 3b: LOAD - persist the prediction ---------------------------
        persisted_prediction = await self._prediction_repo.add(prediction)
        logger.info("Prediction persisted with id=%s", persisted_prediction.id)
        return persisted_prediction

    async def _fetch_signal(self, signal_id: str):
        """Return the stored signal for ``signal_id`` or raise EntityNotFoundError."""
        # A missing or empty ID identifies no signal, so fail fast with the
        # same error the lookup would produce instead of querying the
        # repository with an empty key.
        if not signal_id:
            raise EntityNotFoundError("PPGSignal", signal_id)

        signal = await self._ppg_repo.get_by_id(signal_id)
        if signal is None:
            raise EntityNotFoundError("PPGSignal", signal_id)

        logger.debug("Fetched signal: %s", signal)
        return signal

    async def _ensure_model_loaded(self) -> None:
        """Load the model on demand when the service reports it is not loaded."""
        if self._model_service.is_loaded():
            return
        logger.warning("Model not loaded — calling load_model() now")
        await self._model_service.load_model()
|