File size: 5,150 Bytes
e391a84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
application/use_cases/process_and_predict.py
─────────────────────────────────────────────
ProcessAndPredictUseCase — ETL #2 orchestrator.

Responsibility:
    Triggered by a queue message, this use case:
      1. Fetches the full PPG signal from the database.
      2. Preprocesses it (filter → normalize → segment).
      3. Runs AI inference (GAN → VGTL-Net).
      4. Persists the BPPrediction result.

Dependencies (constructor-injected):
    • PPGRepository        — to fetch the stored signal
    • PredictionRepository — to store the result
    • SignalProcessor      — to preprocess the signal
    • ModelService         — to run inference
"""
from __future__ import annotations

from src.domain.entities.prediction import BPPrediction
from src.domain.exceptions.domain_exceptions import EntityNotFoundError
from src.domain.interfaces.repositories.ppg_repository import PPGRepository
from src.domain.interfaces.repositories.prediction_repository import PredictionRepository
from src.domain.interfaces.services.model_service import ModelService
from src.domain.interfaces.services.signal_processor import SignalProcessor
from src.shared.logger import get_logger

logger = get_logger(__name__)


class ProcessAndPredictUseCase:
    """
    ETL #2: Extract → Transform → Load for AI-based BP prediction.

    Steps:
        1. Extract   — fetch the PPGSignal entity from the DB using the ``"id"``
                       value carried by the queue message.
        2. Transform — run SignalProcessor.process() (filter → normalize → segment).
        3. Load      — run ModelService.predict() → validate → persist BPPrediction.

    All four collaborators are injected through the constructor and stored on
    private attributes; the use case itself holds no other state.

    Usage::

        use_case = ProcessAndPredictUseCase(
            ppg_repo=...,
            prediction_repo=...,
            signal_processor=...,
            model_service=...,
        )
        prediction = await use_case.execute(message_payload)
    """

    def __init__(
        self,
        ppg_repo: PPGRepository,
        prediction_repo: PredictionRepository,
        signal_processor: SignalProcessor,
        model_service: ModelService,
    ) -> None:
        """
        Store the injected collaborators.

        Args:
            ppg_repo:         Repository used to fetch the stored PPG signal.
            prediction_repo:  Repository used to persist the prediction.
            signal_processor: Preprocessing service applied to the raw signal.
            model_service:    Inference service producing the prediction.
        """
        self._ppg_repo = ppg_repo
        self._prediction_repo = prediction_repo
        self._signal_processor = signal_processor
        self._model_service = model_service

    async def execute(self, message: dict) -> BPPrediction:
        """
        Process one queue message end-to-end.

        Args:
            message: Dict payload from the queue. Must contain an ``"id"`` key
                     identifying a stored PPGSignal.

        Returns:
            The entity returned by ``PredictionRepository.add()`` for the new
            prediction.

        Raises:
            EntityNotFoundError: If the message carries no usable ``"id"``
                (missing, ``None`` or empty string) or if the repository
                returns ``None`` for that id.

        Note:
            Nothing is caught here: any exception raised by
            ``SignalProcessor.process()``, ``ModelService.load_model()`` /
            ``predict()``, ``BPPrediction.validate()`` or
            ``PredictionRepository.add()`` propagates to the caller unchanged.
            If the model is not loaded yet it is loaded on demand rather than
            treated as an error.
        """
        signal_id = message.get("id", "")
        logger.info("ProcessAndPredictUseCase.execute — signal_id=%s", signal_id)

        # A missing/blank id can never match a stored signal, so fail fast with
        # the same error the lookup would produce instead of querying the DB.
        if signal_id is None or signal_id == "":
            logger.warning("Queue message has no usable 'id'; skipping signal lookup")
            raise EntityNotFoundError("PPGSignal", signal_id)

        # ── Step 1: EXTRACT — Fetch PPG signal from database ─────────────────
        signal = await self._ppg_repo.get_by_id(signal_id)
        if signal is None:
            raise EntityNotFoundError("PPGSignal", signal_id)

        logger.debug("Fetched signal: %s", signal)

        # ── Step 2: TRANSFORM — Preprocess the signal ─────────────────────────
        logger.info(
            "Running signal processing pipeline (filter → normalize → segment) "
            "on signal_id=%s",
            signal_id,
        )
        # NOTE(review): process() is called synchronously (not awaited); if it
        # is CPU-heavy it will occupy the event loop for its whole duration.
        segments = self._signal_processor.process(
            raw_signal=signal.ppg_values,
            sampling_rate=signal.sampling_rate,
        )
        # Assumes ``segments`` exposes an array-style ``.shape`` tuple whose
        # first axis counts segments — TODO confirm against SignalProcessor.
        logger.info(
            "Preprocessing complete: %d segments of shape %s",
            segments.shape[0],
            segments.shape,
        )

        # ── Step 3a: LOAD — Run model inference ───────────────────────────────
        # Lazy-load so a worker whose model was not warmed up can still serve.
        if not self._model_service.is_loaded():
            logger.warning("Model not loaded — calling load_model() now")
            await self._model_service.load_model()

        prediction = await self._model_service.predict(
            ppg_signal_id=signal.id,
            segments=segments,
        )
        logger.info(
            "Inference complete: SBP=%.1f DBP=%.1f (model=%s, time=%.1f ms)",
            prediction.predicted_sbp,
            prediction.predicted_dbp,
            prediction.model_version,
            prediction.inference_time_ms,
        )

        # Validate domain rules before anything is written, so a prediction
        # that fails validation is never persisted.
        prediction.validate()

        # ── Step 3b: LOAD — Persist prediction ───────────────────────────────
        persisted_prediction = await self._prediction_repo.add(prediction)
        logger.info("Prediction persisted with id=%s", persisted_prediction.id)

        return persisted_prediction