File size: 4,859 Bytes
0917e8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import torch
import numpy as np

from tqdm import tqdm
from gpim import gprutils
from gpim.gpreg import skgpr


class GPSTRUCT(torch.nn.Module):
    """
    Structured Gaussian-process (GP) image reconstruction as a ``torch.nn.Module``.

    The module holds no trainable parameters. Each call to :meth:`forward`
    fits a fresh sparse GP (``gpim.gpreg.skgpr.skreconstructor``) to the
    observed pixels of every input image. It returns the GP posterior mean,
    mapped back to the image's original intensity range.
    """

    # Added to the value range so a constant image does not divide by zero
    # during min-max normalization (and is undone consistently afterwards).
    _EPS = 1e-8

    def __init__(self):
        super().__init__()

    @staticmethod
    def GP_Structured(
        imgdata: torch.Tensor,
        imgdata_gt: torch.Tensor,
        R2: torch.Tensor,
        iter__: int = 50,
    ) -> np.ndarray:
        """
        Reconstruct image(s) with a structured GP.

        Parameters
        ----------
        imgdata : torch.Tensor
            Input image(s) of shape ``(B, H, W)`` or ``(H, W)``. In each
            image, every pixel whose value equals the pixel at ``[1, 1]``
            is treated as missing and is inferred by the GP.
        imgdata_gt : torch.Tensor
            Ground-truth image(s). Currently unused; kept for interface
            compatibility.
        R2 : torch.Tensor
            Mask tensor. Currently unused; kept for interface compatibility.
        iter__ : int
            Number of GP training iterations per image (must be >= 0).

        Returns
        -------
        np.ndarray
            Reconstruction(s) of shape ``(B, H, W)`` (``(1, H, W)`` for 2-D
            input), in each image's original intensity scale.

        Raises
        ------
        ValueError
            If ``iter__`` is negative, ``imgdata`` is not 2-D or 3-D, or an
            image is smaller than 2x2 (the ``[1, 1]`` placeholder must exist).
        """
        if iter__ < 0:
            raise ValueError(f"iter__ must be non-negative, got {iter__}")

        images = imgdata.detach().cpu().numpy()
        if images.ndim == 2:
            # Treat a single (H, W) image as a batch of one.
            images = images[None, ...]
        if images.ndim != 3:
            raise ValueError(
                "expected imgdata of shape (B, H, W) or (H, W), "
                f"got {tuple(images.shape)}"
            )
        if images.shape[1] < 2 or images.shape[2] < 2:
            raise ValueError(
                "each image must be at least 2x2 pixels, "
                f"got {tuple(images.shape[1:])}"
            )
        if images.shape[0] == 0:
            return np.empty(images.shape)

        # Each image gets an independent GP fit; results are re-stacked
        # into a (B, H, W) array.
        return np.stack(
            [GPSTRUCT._reconstruct_single(image, iter__) for image in images],
            axis=0,
        )

    @staticmethod
    def _normalize(image: np.ndarray):
        """Min-max scale *image* to [0, 1]; return ``(scaled, min, range)``."""
        orig_min = np.min(image)
        orig_ptp = np.ptp(image)  # max - min
        return (image - orig_min) / (orig_ptp + GPSTRUCT._EPS), orig_min, orig_ptp

    @staticmethod
    def _mark_missing(scaled: np.ndarray) -> np.ndarray:
        """Set (in place) every pixel equal to the ``[1, 1]`` placeholder to NaN."""
        # The value at [1, 1] is used as the "missing data" marker; NaN is
        # what gprutils.corrupt_data_xy is fed to mark unobserved points.
        scaled[scaled == scaled[1, 1]] = np.nan
        return scaled

    @staticmethod
    def _fit_gp(observed: np.ndarray, iterations: int) -> np.ndarray:
        """Fit a sparse GP to the non-NaN pixels of *observed*; return its mean as (H, W)."""
        height, width = observed.shape
        xx, yy = np.mgrid[:height, :width]
        X_true = np.array([xx.astype(float), yy.astype(float)])

        # Build the sparse (coordinates, values) training set from the data.
        X, R_sparse = gprutils.corrupt_data_xy(X_true, observed)

        # GP training needs autograd even if the caller is under torch.no_grad().
        with torch.enable_grad():
            reconstructor = skgpr.skreconstructor(
                X,
                R_sparse,
                X_true,
                "RBF",
                lengthscale=[[1.0, 1.0], [4.0, 4.0]],
                input_dim=2,
                grid_points_ratio=1.0,
                learning_rate=0.1,
                iterations=iterations,
                calculate_sd=True,
                num_batches=1,
                use_gpu=torch.cuda.is_available(),
                verbose=False,
            )
            mean, _sd, _hyperparams = reconstructor.run()

        return np.asarray(mean).reshape(height, width)

    @staticmethod
    def _reconstruct_single(image: np.ndarray, iterations: int) -> np.ndarray:
        """Reconstruct one 2-D image and return it in its original intensity scale."""
        scaled, orig_min, orig_ptp = GPSTRUCT._normalize(image)
        observed = GPSTRUCT._mark_missing(scaled)
        gp_mean = GPSTRUCT._fit_gp(observed, iterations)
        # Invert the normalization with the same epsilon used in _normalize.
        return gp_mean * (orig_ptp + GPSTRUCT._EPS) + orig_min

    def forward(
        self,
        y: torch.Tensor,
        y_sparse: torch.Tensor,
        y_mask: torch.Tensor,
        iter__: int = 20,
    ) -> torch.Tensor:
        """
        Reconstruct ``y_sparse`` with the GP and return the result as a tensor.

        Parameters
        ----------
        y : torch.Tensor
            Ground-truth image(s); forwarded to ``GP_Structured`` (unused there).
        y_sparse : torch.Tensor
            Image(s) to reconstruct, shape ``(B, H, W)`` or ``(H, W)``.
        y_mask : torch.Tensor
            Mask tensor; forwarded to ``GP_Structured`` (unused there).
        iter__ : int
            GP training iterations per image.

        Returns
        -------
        torch.Tensor
            Reconstruction of shape ``(B, H, W)`` on ``y_sparse``'s device.
        """
        reconstruction = GPSTRUCT.GP_Structured(y_sparse, y, y_mask, iter__)
        # Return on the caller's device instead of hard-coding CUDA, so the
        # module also works on CPU-only machines.
        return torch.as_tensor(reconstruction).to(y_sparse.device)

    @staticmethod
    def get(weights=None):
        """
        Return a new GPSTRUCT instance.

        ``weights`` is accepted for API compatibility and ignored: the module
        has no trainable state to load.
        """
        return GPSTRUCT()


if __name__ == "__main__":
    # Smoke run: reconstruct one random 128x128 image (batch size 1).
    model = GPSTRUCT()
    demo_shape = (1, 128, 128)
    ground_truth = torch.rand(demo_shape)
    sparse_input = torch.rand(demo_shape)
    mask = torch.rand(demo_shape)
    model(y=ground_truth, y_sparse=sparse_input, y_mask=mask)