File size: 5,303 Bytes
e391a84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""
infrastructure/processing/numba_vgtlnet_preprocessor.py
────────────────────────────────────────────────────────
Numba/NumPy implementation of VGTLNetSignalPreprocessor.
"""
from __future__ import annotations
import torch

import numpy as np

from src.domain.exceptions.pipeline_exceptions import PreprocessingError
from src.domain.interfaces.services.vgtlnet_preprocessor import VGTLNetSignalPreprocessor
from src.shared.constants import VGTLNET_WINDOW_SIZE
from src.shared.logger import get_logger

logger = get_logger(__name__)


# Define the Numba visibility-graph kernel once at module level (so its
# compiled form can be cached and reused); if Numba cannot be imported, record
# that so callers switch to the NumPy fallback instead.
try:
    from numba import njit
except ImportError:
    _has_numba = False
    logger.debug("Numba is not installed. Falling back to NumPy for visibility graphs.")
else:
    @njit(cache=True)
    def _vg_numba_compiled(y_arr: np.ndarray) -> np.ndarray:
        """
        Return the (N, N) uint8 visibility-graph adjacency matrix of ``y_arr``.

        Samples ``left < right`` are connected (value 255, symmetric) iff every
        intermediate sample lies strictly below the straight line joining them;
        all other entries, including the diagonal, stay 0.
        """
        size = len(y_arr)
        adjacency = np.zeros((size, size), dtype=np.uint8)
        for left in range(size - 1):
            y_left = y_arr[left]
            for right in range(left + 1, size):
                rise = y_arr[right] - y_left
                run = right - left
                blocked = False
                k = left + 1
                while k < right and not blocked:
                    # Height of the left->right sight line at sample k.
                    sightline = y_left + rise * (k - left) / run
                    blocked = y_arr[k] >= sightline
                    k += 1
                if not blocked:
                    adjacency[left, right] = 255
                    adjacency[right, left] = 255
        return adjacency

    _has_numba = True


class NumbaVGTLNetPreprocessor(VGTLNetSignalPreprocessor):
    """
    VGTL-Net preprocessing pipeline.

    Each (PPG, ECG) window pair becomes a 3-channel image of visibility-graph
    adjacency matrices (R = PPG, G = ECG, B = np.gradient of PPG). Each channel
    is (W, W) with W = VGTLNET_WINDOW_SIZE (documented as 224 — NOTE(review):
    confirm against the constant). The image is converted to a PyTorch tensor
    and normalized with mean=0.5, std=0.5.

    The adjacency kernel is the Numba-compiled one when Numba is importable,
    otherwise a vectorized NumPy implementation; both evaluate the same
    floating-point expression, so they produce identical images.
    """

    def preprocess_signals(
        self,
        ppg_segments: np.ndarray,
        ecg_segments: np.ndarray,
    ) -> torch.Tensor:
        """
        Build the VGTL-Net input batch from paired PPG & ECG windows.

        Args:
            ppg_segments: Batch of PPG windows, each of length VGTLNET_WINDOW_SIZE.
            ecg_segments: Batch of ECG windows, paired index-by-index with
                ``ppg_segments``; must contain the same number of windows.

        Returns:
            Tensor of shape (N, 3, W, W), W = VGTLNET_WINDOW_SIZE, in torch's
            default float dtype (float32 unless changed). Since adjacency
            entries are 0 or 255, every value is exactly -1.0 or 1.0.

        Raises:
            PreprocessingError: if the batch is empty, the PPG/ECG window counts
                differ, a window has the wrong length, or any other error occurs
                (the original exception is chained as the cause).
        """
        try:
            n_wins = len(ppg_segments)
            if n_wins == 0:
                raise PreprocessingError("preprocess_signals", "PPG/ECG segment batch is empty")
            # Windows are paired by position; a count mismatch would otherwise
            # either fail with an opaque IndexError or silently drop ECG windows.
            if len(ecg_segments) != n_wins:
                raise PreprocessingError(
                    "preprocess_signals",
                    f"Window count mismatch: ppg={n_wins}, ecg={len(ecg_segments)}",
                )

            tensors = []
            for ppg_win, ecg_win in zip(ppg_segments, ecg_segments):
                if len(ppg_win) != VGTLNET_WINDOW_SIZE or len(ecg_win) != VGTLNET_WINDOW_SIZE:
                    raise PreprocessingError(
                        "preprocess_signals",
                        f"Window size mismatch: ppg={len(ppg_win)}, ecg={len(ecg_win)}, expected={VGTLNET_WINDOW_SIZE}"
                    )

                # (W, W, 3) uint8 adjacency image -> normalized (3, W, W) tensor.
                rgb_img = self._build_rgb_adjacency_image(ppg_win, ecg_win)
                tensors.append(self._vg_image_to_tensor(rgb_img))

            return torch.stack(tensors)

        except PreprocessingError:
            raise
        except Exception as e:
            raise PreprocessingError("preprocess_signals", f"Unexpected error: {e}") from e

    # ── Helper Processing Methods ───────────────────────────────────────────────

    def _build_rgb_adjacency_image(self, ppg: np.ndarray, ecg: np.ndarray) -> np.ndarray:
        """
        Stack the visibility graphs of PPG, ECG and d(PPG) into one image.

        Returns:
            (N, N, 3) uint8 array, N = len(ppg); channel order is
            (PPG, ECG, np.gradient(PPG)).
        """
        ppg64 = np.asarray(ppg, dtype=np.float64)
        ecg64 = np.asarray(ecg, dtype=np.float64)
        dppg = np.gradient(ppg64)

        # Pick the kernel once; _vg_numba_compiled only exists when _has_numba.
        vg = _vg_numba_compiled if _has_numba else self._vg_numpy_fallback
        channels = [vg(signal) for signal in (ppg64, ecg64, dppg)]
        return np.stack(channels, axis=-1)

    @staticmethod
    def _vg_numpy_fallback(y: np.ndarray) -> np.ndarray:
        """
        Vectorized NumPy visibility graph, used when Numba is unavailable.

        Samples i < j are linked iff every intermediate sample k (i < k < j)
        lies strictly below the straight line from (i, y[i]) to (j, y[j]).
        For each i, the line heights for all (j, k) pairs are computed at once
        with the same expression as the Numba kernel,
        ``y[i] + (y[j] - y[i]) * (k - i) / (j - i)``, so results match it
        exactly. Peak extra memory is O(N^2) per row i, which is small for
        window-sized inputs.

        Returns:
            (N, N) symmetric uint8 matrix: 255 for visible pairs, 0 elsewhere
            (including the diagonal).
        """
        y = np.asarray(y, dtype=np.float64)
        n = len(y)
        adj = np.zeros((n, n), dtype=np.uint8)
        idx = np.arange(n)

        for i in range(n - 1):
            later = idx[i + 1:]            # candidate endpoints j (and intermediates k)
            offset = later - i             # j - i / k - i, always >= 1
            rise = y[later] - y[i]         # y[j] - y[i]
            # line[a, b] = height at k = later[b] of the sight line i -> j = later[a]
            line = y[i] + (rise[:, None] * offset[None, :]) / offset[:, None]
            # Only samples strictly between i and j (k < j) can block visibility.
            between = offset[None, :] < offset[:, None]
            blocked = (y[later][None, :] >= line) & between
            visible = later[~blocked.any(axis=1)]
            adj[i, visible] = 255
            adj[visible, i] = 255
        return adj

    @staticmethod
    def _vg_image_to_tensor(rgb_image: np.ndarray) -> torch.Tensor:
        """
        Convert an (H, W, 3) uint8 image to a normalized (3, H, W) tensor.

        Equivalent to torchvision ``ToTensor()`` followed by
        ``Normalize(mean=[0.5]*3, std=[0.5]*3)``: scale [0, 255] -> [0, 1] in
        torch's default float dtype, then map [0, 1] -> [-1, 1]. Done directly
        in torch to avoid rebuilding a transform pipeline and a PIL round-trip
        for every window.
        """
        hwc = torch.from_numpy(np.ascontiguousarray(rgb_image))
        chw = hwc.permute(2, 0, 1).contiguous().to(dtype=torch.get_default_dtype())
        # .to() made a fresh float copy, so in-place ops do not touch rgb_image.
        return chw.div_(255.0).sub_(0.5).div_(0.5)