File size: 12,095 Bytes
bddd311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
from torch.nn import functional as F
import torch
import matplotlib.pyplot as plt

from torchvision import transforms
import numpy as np
from PIL import Image
import cv2
import torch.nn as nn
import scipy.stats as st


# 从提供的关键点和对应分数中移除那些过于靠近图像边缘的关键点
# 这个函数是在推理的时候用的,训练的时候没有使用
def remove_borders(keypoints, scores, border: int, height: int, width: int):
    """ Removes keypoints too close to the border """
    '''
    keypoints: 关键点坐标的二维数组,形状为 (N, 2),其中每行表示一个关键点的 (y, x) 坐标。
    scores: 每个关键点对应的分数,形状为 (N,)。
    border: 表示需要移除的边界宽度。    推理时预设的是4像素
    height: 图像的高度。
    width: 图像的宽度。
    '''
    # 创建高度方向掩码: 关键点必须在 [border, height-border) 范围内
    mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border))
    # 创建宽度方向掩码: 关键点必须在 [border, width-border) 范围内
    mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border))
    # 组合掩码 (必须同时满足高度和宽度条件)
    mask = mask_h & mask_w  # 所以这个mask是判定条件
    # 返回过滤后的关键点和分数
    return keypoints[mask], scores[mask]


def simple_nms(scores, nms_radius: int):
    """
    快速非极大值抑制 (NMS) 算法,用于移除相邻关键点
    
    参数:
        scores: 关键点分数图 (B, H, W) 或 (H, W)
        nms_radius: NMS邻域半径
    
    返回:
        NMS处理后的分数图
    """
    assert (nms_radius >= 0)
    # 计算NMS窗口大小 (2*半径+1)
    size = nms_radius * 2 + 1
    avg_size = 2
    # 定义最大池化函数 (使用固定步长1和适当填充)
    def max_pool(x):
        return torch.nn.functional.max_pool2d(x, kernel_size=size, stride=1, padding=nms_radius)
    
    # 创建与输入相同形状的零张量
    zeros = torch.zeros_like(scores)
    # max_map = max_pool(scores)

    # 步骤1: 识别局部最大值点
    # 比较每个点与其邻域内的最大值
    max_mask = scores == max_pool(scores)   # max_pool(scores):每个像素点被替换为其局部窗口内的最大值。
    # 步骤2: 添加微小随机扰动 (避免多个相同最大值)
    # 生成 [0, 0.1) 范围内的随机数
    max_mask_ = torch.rand(max_mask.shape).to(max_mask.device) / 10 
    # 生成与 max_mask 相同形状的随机数(范围在 [0, 0.1)),作为微小扰动。
    # 非局部最大值点置零
    max_mask_[~max_mask] = 0

    # 步骤3: 对扰动后的图再次应用NMS
    # 识别扰动后仍然是局部最大值的点
    mask = ((max_mask_ == max_pool(max_mask_)) & (max_mask_ > 0))   # mask:布尔掩码,仅保留扰动后仍然是局部最大值的点。

    # 步骤4: 保留局部最大值点,其他点置零
    return torch.where(mask, scores, zeros) # 如果 mask 为 True,保留原始分数。否则,将得分设置为零。


def pre_processing(data):
    """ Enhance retinal images """
    train_imgs = datasets_normalized(data)
    train_imgs = clahe_equalized(train_imgs)
    train_imgs = adjust_gamma(train_imgs, 1.2)

    train_imgs = train_imgs / 255.

    return train_imgs.astype(np.float32)


def rgb2gray(rgb):
    """ Convert RGB image to gray image """
    r, g, b = rgb.split()
    return g


# 对输入图像的 CLAHE(Contrast Limited Adaptive Histogram Equalization)增强处理
# 用于提高图像的对比度,特别是在光照不均或细节难以分辨的情况下。
def clahe_equalized(images):
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    # clipLimit=2.0: 限制对比度的参数,值越低,增强的对比度越小,避免高对比度导致的过曝光区域。
    # tileGridSize=(8, 8): 将图像分成大小为 8x8 的网格,每个网格单独进行直方图均衡化,减少全局对比度增强引入的伪影。
    images_equalized = np.empty(images.shape)
    images_equalized[:, :] = clahe.apply(np.array(images[:, :],
                                                  dtype=np.uint8))

    return images_equalized


def datasets_normalized(images):
    # 归一化之后还需要把值映射到0到255
    # images_normalized = np.empty(images.shape)
    images_std = np.std(images)
    images_mean = np.mean(images)
    images_normalized = (images - images_mean) / (images_std + 1e-6)
    minv = np.min(images_normalized)
    images_normalized = ((images_normalized - minv) /
                         (np.max(images_normalized) - minv)) * 255

    return images_normalized


def adjust_gamma(images, gamma=1.0):
    invGamma = 1.0 / gamma  # invGamma: 伽马值的倒数,用于生成查找表。
    table = np.array([((i / 255.0) ** invGamma) * 255
                      for i in np.arange(0, 256)]).astype("uint8")  
    # 预计算伽马校正的转换值,用于快速查找。
    # 每个输入像素值(0-255)都映射到一个经过伽马变换的输出值。
    # 生成过程:
    #     i / 255.0: 将像素值归一化到 [0, 1] 范围。
    #    (i / 255.0) ** invGamma: 应用伽马校正公式。
    #    * 255: 将归一化后的值还原到 [0, 255] 范围。
    #   astype("uint8"): 转换为 uint8 数据类型,适配图像格式。
    new_images = np.empty(images.shape)
    new_images[:, :] = cv2.LUT(np.array(images[:, :],
                                        dtype=np.uint8), table)
    # cv2.LUT: OpenCV 的快速像素值映射函数。
    # 输入图像的每个像素值通过查找表 table 进行伽马校正。
    # 大幅提高效率,避免逐像素计算。

    return new_images


def nms(detector_pred, nms_thresh=0.1, nms_size=10, detector_label=None, mask=False):
    """
    在检测器预测上应用非极大值抑制 (NMS)
    
    参数:
        detector_pred: 检测器预测 (B, 1, H, W)
        nms_thresh: NMS阈值
        nms_size: NMS邻域大小
        detector_label: 检测器标签 (可选)
        mask: 是否使用标签掩码 (当前未实现)
    
    返回:
        关键点位置列表 (每个元素是 (N, 2) 的数组)
    """
    # 创建预测副本 (避免修改原始数据)
    detector_pred = detector_pred.clone().detach()
    # 获取批次大小和图像尺寸
    B, _, h, w = detector_pred.shape

    # if mask:
    #     assert detector_label is not None
    #     detector_pred[detector_pred < nms_thresh] = 0
    #     label_mask = detector_label
    #
    #     # more area
    #
    #     detector_label = detector_label.long().cpu().numpy()
    #     detector_label = detector_label.astype(np.uint8)
    #     kernel = np.ones((3, 3), np.uint8)
    #     label_mask = np.array([cv2.dilate(detector_label[s, 0], kernel, iterations=1)
    #                            for s in range(len(detector_label))])
    #     label_mask = torch.from_numpy(label_mask).unsqueeze(1)
    #     detector_pred[label_mask > 1e-6] = 0
    
    # 应用快速NMS算法
    scores = simple_nms(detector_pred, nms_size)
    # 重塑分数图形状 (B, H, W)
    scores = scores.reshape(B, h, w)
    # 找出分数高于阈值的点
    points = [
        torch.nonzero(s > nms_thresh)
        for s in scores]
    # 提取这些点的分数值
    scores = [s[tuple(k.t())] for s, k in zip(scores, points)]
    # 移除靠近边界的点
    points, scores = list(zip(*[
        remove_borders(k, s, 8, h, w)
        for k, s in zip(points, scores)]))
    # 翻转坐标顺序: [y, x] -> [x, y]
    points = [torch.flip(k, [1]).long() for k in points]

    return points


# 实际上模型生成的描述符是1/8尺寸*1/8尺寸的描述符特征,这里是上采样到原尺寸的
# 这个是在PKE算损失的时候用的
def sample_keypoint_desc(keypoints, descriptors, s: int = 8):
    """
    在关键点位置采样描述符
    
    参数:
        keypoints: 关键点坐标 (B, N, 2) 格式 [x, y]
        descriptors: 描述符图 (B, C, H, W)
        s: 描述符图相对于原始图像的下采样比例
    
    返回:
        采样后的描述符 (B, C, N)
    """
    # 获取描述符张量的形状信息,用于后续处理
    b, c, h, w = descriptors.shape  # 原始输入 descriptors: (b, c, h, w)
    
    # 克隆关键点并将其转换为浮点类型,以便进行坐标计算
    keypoints = keypoints.clone().float()

    # 将关键点坐标归一化到范围 (0, 1)
    keypoints /= torch.tensor([(w * s - 1), (h * s - 1)]).to(keypoints)[None]
    
    # 将关键点坐标缩放到范围 (-1, 1),以适应 grid_sample 函数的要求
    keypoints = keypoints * 2 - 1

    # 根据 PyTorch 版本准备 grid_sample 函数的参数,确保兼容性
    args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {}
    
    # 使用 grid_sample 函数在关键点位置插值描述符
    descriptors = torch.nn.functional.grid_sample(
        descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)  # 经过 grid_sample:     (b, c, 1, n) n是关键点数量

    # 对描述符进行 L2 归一化,使其长度为 1(1个像素),以便后续处理
    descriptors = torch.nn.functional.normalize(
        descriptors.reshape(b, c, -1), p=2, dim=1)  # reshape 后: (b, c, n)  channel=256
    
    # 返回处理后的描述符
    return descriptors


# 这个是在模型算损失的时候用的,可以同时处理关键点和被映射后的关键点损失
def sample_descriptors(detector_pred, descriptor_pred, affine_descriptor_pred, grid_inverse,
                        nms_size=10, nms_thresh=0.1, scale=8, affine_detector_pred=None):
    """
    基于关键点采样描述符
    
    参数:
        detector_pred: 原始图像的检测器预测 (B, 1, H, W)
        descriptor_pred: 原始图像的描述符预测 (B, C, H, W)
        affine_descriptor_pred: 仿射图像的描述符预测 (B, C, H, W)
        grid_inverse: 逆变换网格 (B, H, W, 2)
        nms_size: NMS邻域大小
        nms_thresh: NMS阈值
        scale: 描述符图相对于原始图像的下采样比例
        affine_detector_pred: 仿射图像的检测器预测 (可选)
    
    返回:
        descriptors: 原始图像关键点的描述符列表
        affine_descriptors: 仿射图像对应关键点的描述符列表
        keypoints: 原始图像的关键点位置列表
    """
    # 获取批次大小和图像尺寸
    B, _, h, w = detector_pred.shape
    
    # 应用NMS获取关键点位置
    keypoints = nms(detector_pred, nms_size=nms_size, nms_thresh=nms_thresh)
    
    # 使用逆变换网格将关键点映射到仿射空间
    affine_keypoints = [
        grid_inverse[s, k[:, 1].long(), k[:, 0].long()]  # 使用网格插值
        for s, k in enumerate(keypoints)
    ]
    
    # 初始化存储列表
    kp = []        # 过滤后的原始关键点
    affine_kp = []  # 过滤后的仿射关键点
    
    # 处理每个样本
    for s, k in enumerate(affine_keypoints):
        # 过滤超出仿射图像边界的点
        idx = (k[:, 0] < 1) & (k[:, 0] > -1) & (k[:, 1] < 1) & (k[:, 1] > -1)
        
        # 存储过滤后的原始关键点
        kp.append(keypoints[s][idx])
        
        # 获取过滤后的仿射关键点
        ak = k[idx]
        
        # 将归一化坐标转换回像素坐标
        ak[:, 0] = (ak[:, 0] + 1) / 2 * (w - 1)  # x坐标
        ak[:, 1] = (ak[:, 1] + 1) / 2 * (h - 1)  # y坐标
        
        # 存储转换后的仿射关键点
        affine_kp.append(ak)
    
    # 在原始图像关键点位置采样描述符
    descriptors = [
        sample_keypoint_desc(k[None], d[None], s=scale)[0]
        for k, d in zip(kp, descriptor_pred)
    ]
    
    # 在仿射图像关键点位置采样描述符
    affine_descriptors = [
        sample_keypoint_desc(k[None], d[None], s=scale)[0]
        for k, d in zip(affine_kp, affine_descriptor_pred)
    ]
    
    return descriptors, affine_descriptors, keypoints