# NOTE: removed non-source viewer chrome (spaces/size/commit-hash/line-number gutter)
import torch
import numpy as np
from tqdm import tqdm
from gpim import gprutils
from gpim.gpreg import skgpr
class GPSTRUCT(torch.nn.Module):
    """
    torch.nn.Module wrapper around a structured-kernel Gaussian-process
    image reconstruction (gpim ``skgpr.skreconstructor``).

    The module carries no learnable parameters of its own; ``forward``
    converts its tensor inputs to numpy, runs the GP reconstruction via
    gpim, and returns the un-normalized result as a CUDA tensor.
    """
    def __init__(self):
        super().__init__()
    @staticmethod
    def GP_Structured(
        imgdata: torch.Tensor,
        imgdata_gt: torch.Tensor,
        R2: torch.Tensor,
        iter__: int = 50,
    ):
        """
        Reconstruct a sparsely-sampled image with a structured-kernel GP.

        Parameters
        ----------
        :param imgdata:
            Input image tensor of shape (1, H, W); batch size must be 1.
        :param imgdata_gt:
            Ground-truth image tensor. Currently only converted/squeezed and
            otherwise unused; kept for interface compatibility.
        :param R2:
            Binary mask or other reference tensor. Currently only
            converted/squeezed and otherwise unused (see note near the end).
        :param iter__:
            Number of GP training iterations.

        Returns
        -------
        np.ndarray of shape (1, H, W): GP reconstruction mapped back to the
        original intensity range of ``imgdata``.
        """
        # torch -> numpy: the gpim routines operate on numpy arrays.
        imgdata = imgdata.detach().cpu().numpy()
        imgdata_gt = imgdata_gt.detach().cpu().numpy()
        R2 = R2.detach().cpu().numpy()
        # HACK: assume BS=1
        # [B, H, W] -> [H, W]
        imgdata = imgdata.squeeze(0)
        imgdata_gt = imgdata_gt.squeeze(0)
        R2 = R2.squeeze(0)
        # ---------------------------------------------
        # 1) Normalize input image into [0, 1]
        # ---------------------------------------------
        orig_min = np.min(imgdata)
        orig_ptp = np.ptp(imgdata)  # max - min
        R = (imgdata - orig_min) / (orig_ptp + 1e-8)  # +1e-8 guards ptp == 0
        # Pixels equal to the value at [1, 1] are marked as missing (NaN).
        # NOTE(review): this assumes pixel [1, 1] holds the fill/placeholder
        # value used for unsampled pixels -- confirm against the data pipeline.
        R[R == R[1, 1]] = np.nan
        # ---------------------------------------------
        # 2) Set up GP inputs
        # ---------------------------------------------
        e1, e2 = R.shape
        xx, yy = np.mgrid[:e1, :e2]
        # Ensure float dtype for the coordinate grids.
        xx = xx.astype(float)
        yy = yy.astype(float)
        X_true = np.array([xx, yy])
        # Build "sparse" (X, R_sparse) observations from the data + NaN mask.
        X, R_sparse = gprutils.corrupt_data_xy(X_true, R)
        lengthscale = [[1.0, 1.0], [4.0, 4.0]]
        kernel = "RBF"
        # ---------------------------------------------
        # 3) Train the GP once for iter__ iterations.
        # FIX: the original code rebuilt and retrained the reconstructor
        # iter__ times with iterations=0, 1, ..., iter__-1 and kept only the
        # last run -- O(iter__^2) redundant work (the first pass trained for
        # 0 iterations). A single run with iterations=iter__ yields the
        # intended final reconstruction.
        # ---------------------------------------------
        with torch.enable_grad():
            skreconstructor = skgpr.skreconstructor(
                X,
                R_sparse,
                X_true,
                kernel,
                lengthscale=lengthscale,
                input_dim=2,
                grid_points_ratio=1.0,
                learning_rate=0.1,
                iterations=iter__,
                calculate_sd=True,
                num_batches=1,
                use_gpu=True,
                verbose=False,
            )
            mean, sd, hyperparams = skreconstructor.run()
        # Reshape the GP posterior mean back to image shape (H, W).
        # gp_data_norm is already on a [0, 1] scale.
        gp_data_norm = mean.reshape(e1, e2)
        # ---------------------------------------------
        # 4) Un-normalize final GP reconstruction back to the original
        #    intensity distribution of imgdata. shape: (H, W)
        # ---------------------------------------------
        final_pred_unorm = gp_data_norm * (orig_ptp + 1e-8) + orig_min
        # NOTE(review): to respect the zeroed-out region in R2 one could do
        #   final_pred_unorm[R2 == 0] = imgdata[R2 == 0]
        # -- deliberately left out to preserve existing behavior.
        # ---------------------------------------------
        # 5) Reintroduce the squeezed batch dim: (H, W) -> (1, H, W)
        # ---------------------------------------------
        final_pred_unorm = final_pred_unorm[None, ...]  # shape: (1, H, W)
        return final_pred_unorm
    def forward(
        self,
        y: torch.Tensor,
        y_sparse: torch.Tensor,
        y_mask: torch.Tensor,
        iter__: int = 20,
    ):
        """
        Run the GP reconstruction on the sparse input.

        :param y: dense/ground-truth image, shape (1, H, W).
        :param y_sparse: sparsely-sampled image to reconstruct, shape (1, H, W).
        :param y_mask: mask/reference tensor forwarded to GP_Structured.
        :param iter__: number of GP training iterations.
        :return: reconstruction as a CUDA tensor of shape (1, H, W).
        """
        x: np.ndarray = GPSTRUCT.GP_Structured(
            y_sparse,
            y,
            y_mask,
            iter__,
        )
        # NOTE(review): .cuda() hard-codes GPU placement, matching the
        # use_gpu=True setting inside GP_Structured.
        return torch.tensor(x).cuda()
    @staticmethod
    def get(weights=None):
        """
        Factory returning a fresh GPSTRUCT instance.

        :param weights: unused; accepted for interface parity with other
            model factories.
        """
        return GPSTRUCT()
if __name__ == "__main__":
    # Smoke test: run the GP reconstructor on random (1, 128, 128) tensors.
    net = GPSTRUCT.get()
    sparse_img = torch.rand((1, 128, 128))
    dense_img = torch.rand((1, 128, 128))
    region_mask = torch.rand((1, 128, 128))
    net(sparse_img, dense_img, region_mask)