LIBRE / src /application /use_cases /process_and_predict.py
RyZ
feat: adding full working local ETL Pipeline
e391a84
Raw
History Blame Contribute Delete
5.15 kB
"""
application/use_cases/process_and_predict.py
─────────────────────────────────────────────
ProcessAndPredictUseCase — ETL #2 orchestrator.
Responsibility:
Triggered by a queue message, this use case:
  1. Fetches the full PPG signal from the database.
  2. Preprocesses it (filter → normalize → segment).
  3. Runs AI inference (GAN → VGTL-Net).
  4. Persists the BPPrediction result.
Dependencies (constructor-injected):
  • PPGRepository — to fetch the stored signal
  • PredictionRepository — to store the result
  • SignalProcessor — to preprocess the signal
  • ModelService — to run inference
"""
from __future__ import annotations
from src.domain.entities.prediction import BPPrediction
from src.domain.exceptions.domain_exceptions import EntityNotFoundError
from src.domain.interfaces.repositories.ppg_repository import PPGRepository
from src.domain.interfaces.repositories.prediction_repository import PredictionRepository
from src.domain.interfaces.services.model_service import ModelService
from src.domain.interfaces.services.signal_processor import SignalProcessor
from src.shared.logger import get_logger
# Module-level logger obtained from the project's shared get_logger helper,
# named after this module so records can be traced back to this use case.
logger = get_logger(__name__)
class ProcessAndPredictUseCase:
    """
    ETL #2: Extract → Transform → Load for AI-based BP prediction.

    Steps:
        1. Extract   — fetch the PPGSignal entity from the DB using the
                       ``"id"`` value carried by the queue message.
        2. Transform — run ``SignalProcessor.process()``
                       (filter → normalize → segment).
        3. Load      — run ``ModelService.predict()`` → validate → persist
                       the resulting BPPrediction.

    Usage::

        use_case = ProcessAndPredictUseCase(
            ppg_repo=...,
            prediction_repo=...,
            signal_processor=...,
            model_service=...,
        )
        prediction = await use_case.execute(message_payload)
    """

    def __init__(
        self,
        ppg_repo: PPGRepository,
        prediction_repo: PredictionRepository,
        signal_processor: SignalProcessor,
        model_service: ModelService,
    ) -> None:
        """
        Store the injected collaborators; no I/O is performed here.

        Args:
            ppg_repo: Repository used to fetch the stored PPG signal.
            prediction_repo: Repository used to persist the prediction.
            signal_processor: Preprocessing pipeline (filter/normalize/segment).
            model_service: Inference service producing the BPPrediction.
        """
        self._ppg_repo = ppg_repo
        self._prediction_repo = prediction_repo
        self._signal_processor = signal_processor
        self._model_service = model_service

    async def execute(self, message: dict) -> BPPrediction:
        """
        Process one queue message end-to-end.

        Args:
            message: Dict payload from the queue. Must contain an ``"id"`` key
                whose value identifies a stored PPGSignal.

        Returns:
            The BPPrediction returned by ``PredictionRepository.add()``.

        Raises:
            EntityNotFoundError: If the message has no (or an empty) ``"id"``,
                or no PPGSignal matches it.
            Exception: Errors raised by the collaborators
                (``SignalProcessor.process()``, ``ModelService.load_model()`` /
                ``predict()``, ``BPPrediction.validate()``,
                ``PredictionRepository.add()``) propagate unchanged; their
                concrete types are defined by those interfaces, not here.
        """
        # Log messages are kept ASCII-only: the previous non-ASCII dashes/arrows
        # had been mis-encoded in this file, and arrows cannot be written to
        # non-UTF-8 log streams (e.g. cp1252 consoles) without a logging error.
        signal_id: str = message.get("id", "")
        logger.info("ProcessAndPredictUseCase.execute - signal_id=%s", signal_id)

        # ── Step 1: EXTRACT — Fetch PPG signal from database ─────────────────
        if signal_id is None or signal_id == "":
            # A message without an id can never match a stored signal; fail fast
            # with the documented error instead of issuing a pointless DB lookup.
            raise EntityNotFoundError("PPGSignal", signal_id)
        signal = await self._ppg_repo.get_by_id(signal_id)
        if signal is None:
            raise EntityNotFoundError("PPGSignal", signal_id)
        logger.debug("Fetched signal: %s", signal)

        # ── Step 2: TRANSFORM — Preprocess the signal ─────────────────────────
        logger.info(
            "Running signal processing pipeline (filter -> normalize -> segment) "
            "on signal_id=%s",
            signal_id,
        )
        # NOTE(review): process() is called synchronously inside this coroutine,
        # so it blocks the event loop while it runs; consider asyncio.to_thread
        # if long signals delay other tasks served by the same loop.
        segments = self._signal_processor.process(
            raw_signal=signal.ppg_values,
            sampling_rate=signal.sampling_rate,
        )
        # assumes `segments` is array-like with `.shape`, segment count first
        # (e.g. a NumPy array) -- TODO confirm against SignalProcessor.
        logger.info(
            "Preprocessing complete: %d segments of shape %s",
            segments.shape[0],
            segments.shape,
        )

        # ── Step 3a: LOAD — Run model inference ───────────────────────────────
        # Lazily load the model so a message arriving before warm-up still works.
        if not self._model_service.is_loaded():
            logger.warning("Model not loaded - calling load_model() now")
            await self._model_service.load_model()
        prediction = await self._model_service.predict(
            ppg_signal_id=signal.id,
            segments=segments,
        )
        logger.info(
            "Inference complete: SBP=%.1f DBP=%.1f (model=%s, time=%.1f ms)",
            prediction.predicted_sbp,
            prediction.predicted_dbp,
            prediction.model_version,
            prediction.inference_time_ms,
        )
        # Validate domain rules (physiological bounds) before anything is stored.
        prediction.validate()

        # ── Step 3b: LOAD — Persist prediction ───────────────────────────────
        persisted_prediction = await self._prediction_repo.add(prediction)
        logger.info("Prediction persisted with id=%s", persisted_prediction.id)
        return persisted_prediction