| """ |
| infrastructure/processing/numba_vgtlnet_preprocessor.py |
| ─────────────────────────────────────────────────────── |
| Numba/NumPy implementation of VGTLNetSignalPreprocessor. |
| """ |
| from __future__ import annotations |
| import torch |
|
|
| import numpy as np |
|
|
| from src.domain.exceptions.pipeline_exceptions import PreprocessingError |
| from src.domain.interfaces.services.vgtlnet_preprocessor import VGTLNetSignalPreprocessor |
| from src.shared.constants import VGTLNET_WINDOW_SIZE |
| from src.shared.logger import get_logger |
|
|
logger = get_logger(__name__)


# Optional acceleration: when Numba can be imported, a JIT-compiled visibility
# graph kernel is defined and `_has_numba` is set; otherwise only the flag is
# set and callers are expected to use a non-Numba code path.
try:
    from numba import njit

    @njit(cache=True)
    def _vg_numba_compiled(y_arr: np.ndarray) -> np.ndarray:
        """
        Build the visibility-graph adjacency matrix of a 1-D series.

        Samples ``a < b`` are connected when every sample strictly between
        them lies strictly below the straight line joining ``(a, y_arr[a])``
        and ``(b, y_arr[b])``. Neighbouring samples are therefore always
        connected. Connected pairs are written symmetrically as 255, all other
        cells (including the diagonal) stay 0.

        Returns a ``(len(y_arr), len(y_arr))`` ``uint8`` array.
        """
        n_points = len(y_arr)
        adjacency = np.zeros((n_points, n_points), dtype=np.uint8)
        for left in range(n_points - 1):
            y_left = y_arr[left]
            for right in range(left + 1, n_points):
                rise = y_arr[right] - y_left
                span = right - left
                blocked = False
                for mid in range(left + 1, right):
                    # Height of the sight line above sample `mid`.
                    if y_arr[mid] >= y_left + rise * (mid - left) / span:
                        blocked = True
                        break
                if not blocked:
                    adjacency[left, right] = 255
                    adjacency[right, left] = 255
        return adjacency

    _has_numba = True
except ImportError:
    _has_numba = False
    logger.debug("Numba is not installed. Falling back to NumPy for visibility graphs.")
|
|
|
|
class NumbaVGTLNetPreprocessor(VGTLNetSignalPreprocessor):
    """
    VGTL-Net preprocessing pipeline.

    Each PPG/ECG window pair is turned into an RGB image whose channels are
    visibility-graph adjacency matrices (R = PPG, G = ECG, B = first
    derivative of PPG). The image is then converted to a PyTorch tensor and
    normalized per channel with mean 0.5 and std 0.5.

    The image side length equals the window length, which is enforced to be
    ``VGTLNET_WINDOW_SIZE``. NOTE(review): earlier documentation states
    224x224x3, so the constant is presumably 224 — confirm against
    ``src.shared.constants``.
    """

    def preprocess_signals(
        self,
        ppg_segments: np.ndarray,
        ecg_segments: np.ndarray,
    ) -> torch.Tensor:
        """
        Convert a batch of segmented PPG and ECG windows into a tensor batch.

        Args:
            ppg_segments: Sequence of PPG windows; every window must have
                length ``VGTLNET_WINDOW_SIZE``.
            ecg_segments: Sequence of ECG windows, one per PPG window, with
                the same length requirement.

        Returns:
            The per-window tensors stacked along a new leading dimension,
            i.e. shape ``(N, 3, W, W)`` with ``N`` windows and ``W`` the
            window length.

        Raises:
            PreprocessingError: If the batch is empty, the PPG and ECG batches
                differ in size, a window has the wrong length, or any other
                error occurs while building the images (the original
                exception is chained as the cause).
        """
        try:
            n_wins = len(ppg_segments)
            if n_wins == 0:
                raise PreprocessingError("preprocess_signals", "PPG/ECG segment batch is empty")

            # Without this check a shorter ECG batch surfaces as an opaque
            # IndexError and a longer one is silently truncated.
            n_ecg = len(ecg_segments)
            if n_ecg != n_wins:
                raise PreprocessingError(
                    "preprocess_signals",
                    f"Batch size mismatch: ppg={n_wins}, ecg={n_ecg}",
                )

            tensors = []
            for ppg_win, ecg_win in zip(ppg_segments, ecg_segments):
                if len(ppg_win) != VGTLNET_WINDOW_SIZE or len(ecg_win) != VGTLNET_WINDOW_SIZE:
                    raise PreprocessingError(
                        "preprocess_signals",
                        f"Window size mismatch: ppg={len(ppg_win)}, ecg={len(ecg_win)}, expected={VGTLNET_WINDOW_SIZE}"
                    )

                rgb_img = self._build_rgb_adjacency_image(ppg_win, ecg_win)
                tensors.append(self._vg_image_to_tensor(rgb_img))

            return torch.stack(tensors)

        except PreprocessingError:
            # Already a domain error with a precise message: pass it through.
            raise
        except Exception as exc:
            raise PreprocessingError("preprocess_signals", f"Unexpected error: {exc}") from exc

    def _build_rgb_adjacency_image(self, ppg: np.ndarray, ecg: np.ndarray) -> np.ndarray:
        """
        Build the 3-channel adjacency image for one PPG/ECG window pair.

        Args:
            ppg: One PPG window (1-D, array-like).
            ecg: One ECG window (1-D, array-like).

        Returns:
            ``uint8`` array of shape ``(len(ppg), len(ppg), 3)`` when both
            windows have equal length; channel order is PPG, ECG, d(PPG).
        """
        # Convert once to contiguous float64 so both backends see the same
        # values and plain sequences (not only ndarrays) are accepted.
        ppg64 = np.ascontiguousarray(ppg, dtype=np.float64)
        ecg64 = np.ascontiguousarray(ecg, dtype=np.float64)
        dppg = np.gradient(ppg64)

        # `_vg_numba_compiled` only exists when the Numba import succeeded.
        build_graph = _vg_numba_compiled if _has_numba else self._vg_numpy_fallback
        channels = [build_graph(signal) for signal in (ppg64, ecg64, dppg)]
        return np.stack(channels, axis=-1)

    @staticmethod
    def _vg_numpy_fallback(y: np.ndarray) -> np.ndarray:
        """
        Pure-Python visibility-graph adjacency matrix, used when Numba is absent.

        Two samples ``i < j`` are connected when every sample strictly between
        them lies strictly below the straight line joining them; connected
        pairs are marked symmetrically with 255.

        Args:
            y: 1-D series (array-like, converted to float64).

        Returns:
            ``(len(y), len(y))`` ``uint8`` array with 0/255 entries.
        """
        series = np.asarray(y, dtype=np.float64)
        n_points = len(series)
        adj = np.zeros((n_points, n_points), dtype=np.uint8)

        # Python floats are IEEE doubles, so the arithmetic below matches the
        # float64 computation exactly while avoiding NumPy scalar overhead on
        # every element access inside the triple loop.
        values = series.tolist()

        for i in range(n_points - 1):
            y_i = values[i]
            for j in range(i + 1, n_points):
                rise = values[j] - y_i
                span = j - i
                visible = True
                for k in range(i + 1, j):
                    if values[k] >= y_i + rise * (k - i) / span:
                        visible = False
                        break
                if visible:
                    adj[i, j] = 255
                    adj[j, i] = 255
        return adj

    @staticmethod
    def _vg_image_to_tensor(rgb_image: np.ndarray) -> "torch.Tensor":
        """
        Convert an RGB adjacency image into a normalized tensor.

        The array is wrapped in a PIL image, passed through torchvision's
        ``ToTensor`` and then ``Normalize`` with mean 0.5 and std 0.5 on each
        of the three channels.

        Args:
            rgb_image: Image array as produced by
                ``_build_rgb_adjacency_image`` (height x width x 3, ``uint8``).

        Returns:
            The transformed tensor for this single image.
        """
        # Imported lazily so PIL/torchvision are only required when images
        # are actually converted.
        from PIL import Image
        import torchvision.transforms as T

        to_normalized_tensor = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5],
                        std=[0.5, 0.5, 0.5]),
        ])
        return to_normalized_tensor(Image.fromarray(rgb_image))
|
|