File size: 20,129 Bytes
bddd311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
import math
import os
import random
import shutil
import time

import torch
from torchvision import transforms as T
from torch.nn import functional as F
import numpy as np
import cv2
import scipy.stats as st
# import config
from tqdm import tqdm
import matplotlib.pyplot as plt


def get_gaussian_kernel(kernlen=21, nsig=5):
    """
    生成二维高斯核并返回一个不可训练的核权重,用于创建高斯热图
    
    参数:
        kernlen (int): 核的大小(边长),必须是奇数
        nsig (float): 高斯分布的标准差(控制分布的宽度)
    
    返回:
        torch.Tensor: 形状为(1, 1, kernlen, kernlen)的高斯核张量
    """
    # 1. 计算采样间隔
    # 在[-nsig, nsig]范围内均匀采样kernlen个点
    interval = (2 * nsig + 1.) / kernlen
    # 2. 创建一维坐标点数组
    # 从-nsig-interval/2到nsig+interval/2,共kernlen+1个点
    x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1)
    # 3. 计算一维高斯核
    # 使用标准正态分布的累积分布函数(CDF)的差分
    kern1d = np.diff(st.norm.cdf(x))
    # 4. 创建二维高斯核
    # 通过两个一维高斯核的外积(outer product)
    kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
    # 5. 归一化核
    # 使所有元素之和为1,确保核的总"能量"为1
    kernel = kernel_raw / kernel_raw.sum()
    # 6. 转换为PyTorch张量并调整维度
    # 添加批次和通道维度 (1, 1, H, W)
    kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
    # 7. 创建不可训练的高斯核权重参数
    weight = torch.nn.Parameter(data=kernel, requires_grad=False)
    # 8. 将值归一化到[0, 1]范围
    # 使最小值为0,最大值为1
    weight = (weight - weight.min()) / (weight.max() - weight.min())

    return weight


# 对用于监督的点标签进行高斯模糊
def label_gaussian_blur(label_point_positions, kernel, stride=1):
    '''
    对用于监督的点标签进行高斯模糊
    该函数通过应用高斯卷积核来模糊标签点,以生成更平滑的热图
    参数:
        label_point_positions: 标签点的位置,形状为(B, C, H, W)的张量
        kernel: 高斯卷积核,形状为(1, 1, Hk, Wk)的张量
        stride: 卷积步长,默认为1
    返回值:
        blurred_label: 模糊后的标签热图,形状与输入相同
    '''
    # 应用高斯卷积进行模糊
    blurred_label = F.conv2d(label_point_positions, kernel, stride=stride, padding=(kernel.shape[-1] - 1) // 2)
    # 裁剪值范围:限制热图最大值为1
    blurred_label[blurred_label > 1] = 1

    return blurred_label


def affine_images(images, used_for='detector'):
    '''
    根据detector和descriptor两种计算loss和训练的情况分别生成两种不同的仿射变及其逆变换,支持数据增强
    参数:
        images: 输入图像张量,形状为 (B, C, H, W)
        used_for: 指定变换用途 - 'detector'用于检测器,'descriptor'用于描述器
    返回值:
        CPU上的仿射变换后的图像 output.detach().clone()
        仿射变换网格 grid
        逆仿射变换网格 grid_inverse
    '''

    """
    Perform affine transformation on images
    :param images: (B, C, H, W)
    :param keypoint_labels: corresponding labels
    :param value_map: value maps, used to record history learned geo_points
    :return: results of affine images, affine labels, affine value maps, affine transformed grid_inverse, inverse transformed grid_inverse
    """

    h, w = images.shape[2:] # 获取图像的尺寸
    theta = None
    thetaI = None

    # 对一个batch中的每个图像进行随机仿射变换参数的生成
    for i in range(len(images)):
        if used_for == 'detector':
            affine_params = T.RandomAffine(20).get_params(degrees=[-15, 15],                # 旋转角度范围:±15度
                                                        translate=[0.2, 0.2],               # 平移范围:±20%
                                                        scale_ranges=[0.9, 1.35],           # 缩放范围:0.9-1.35倍
                                                        shears=None,                        # 不使用剪切变换
                                                        img_size=[h, w])
        else:
            affine_params = T.RandomAffine(20).get_params(degrees=[-3, 3],                  # 旋转角度范围:±3度
                                                            translate=[0.1, 0.1],           # 平移范围:宽高的10%
                                                            scale_ranges=[0.9, 1.1],        # 缩放范围:0.9到1.1倍
                                                            shears=None,                    # 不使用剪切变换
                                                            img_size=[h, w])
        # 根据仿射变换参数计算变换矩阵和逆变换矩阵
        angle = -affine_params[0] * math.pi / 180
        theta_ = torch.tensor([
            [1 / affine_params[2] * math.cos(angle), math.sin(-angle), -affine_params[1][0] / images.shape[2]],
            [math.sin(angle), 1 / affine_params[2] * math.cos(angle), -affine_params[1][1] / images.shape[3]],
            [0, 0, 1]
        ], dtype=torch.float).to(images)
        thetaI_ = theta_.inverse()
        theta_ = theta_[:2]
        thetaI_ = thetaI_[:2]

        # 将变换矩阵和逆变换矩阵叠成一个batch
        theta_ = theta_.unsqueeze(0)
        thetaI_ = thetaI_.unsqueeze(0)
        theta = theta_ if theta is None else torch.cat((theta, theta_)) # 如果theta为None,意味着这是第一个,否则则说明前面已经有参数值了
        thetaI = thetaI_ if thetaI is None else torch.cat((thetaI, thetaI_))

    # 根据变换矩阵生成网格
    # 变换网格(Transformation Grid)是在进行图像的仿射变换时生成的一个规则的坐标网格。它定义了原始图像中每个像素在变换后的目标位置。
    grid = F.affine_grid(theta, images.size(), align_corners=True)
    grid = grid.to(images)  # 将参数移动到GPU上
    grid_inverse = F.affine_grid(thetaI, images.size(), align_corners=True)
    grid_inverse = grid_inverse.to(images)  # 将参数移动到GPU上
    output = F.grid_sample(images, grid, align_corners=True)     # 对图像进行采样得到仿射变换后的图像

    # 对于用于描述符的情况进行一些额外的随机增强操作
    if used_for == 'descriptor':
        if random.random() >= 0.4:
            # 颜色抖动(其实是灰度变化)
            output = output.repeat(1, 3, 1, 1)   # 将单通道图像复制为三通道(颜色变换需要RGB)
            output = T.ColorJitter(brightness=0.4, contrast=0.3, saturation=0.3, hue=0.2)(output)
            output = T.Grayscale()(output)      # 灰度化

    return output.detach().clone(), grid, grid_inverse


def affine_images_with_mask(images, used_for='detector'):
    '''
    仿射变换图像后,额外返回一个公共区域mask
    根据detector和descriptor两种计算loss和训练的情况分别生成两种不同的仿射变及其逆变换,支持数据增强
    参数:
        images: 输入图像张量,形状为 (B, C, H, W)
        used_for: 指定变换用途 - 'detector'用于检测器,'descriptor'用于描述器
    返回值:
        CPU上的仿射变换后的图像 output.detach().clone()
        仿射变换网格 grid
        逆仿射变换网格 grid_inverse
    '''

    """
    Perform affine transformation on images
    :param images: (B, C, H, W)
    :param keypoint_labels: corresponding labels
    :param value_map: value maps, used to record history learned geo_points
    :return: results of affine images, affine labels, affine value maps, affine transformed grid_inverse, inverse transformed grid_inverse
    """

    h, w = images.shape[2:] # 获取图像的尺寸
    theta = None
    thetaI = None

    # ---------------------------------------------------------------------
    # 新增: 创建原始图像坐标网格
    base_grid = F.affine_grid(torch.eye(2, 3).unsqueeze(0), (1, 1, h, w), align_corners=True)
    base_grid = base_grid.expand(images.size(0), *base_grid.shape[1:])
    # ---------------------------------------------------------------------

    # 对一个batch中的每个图像进行随机仿射变换参数的生成
    for i in range(len(images)):
        if used_for == 'detector':
            affine_params = T.RandomAffine(20).get_params(degrees=[-15, 15],                # 旋转角度范围:±15度
                                                        translate=[0.2, 0.2],               # 平移范围:±20%
                                                        scale_ranges=[0.9, 1.35],           # 缩放范围:0.9-1.35倍
                                                        shears=None,                        # 不使用剪切变换
                                                        img_size=[h, w])
        else:
            affine_params = T.RandomAffine(20).get_params(degrees=[-3, 3],                  # 旋转角度范围:±3度
                                                            translate=[0.1, 0.1],           # 平移范围:宽高的10%
                                                            scale_ranges=[0.9, 1.1],        # 缩放范围:0.9到1.1倍
                                                            shears=None,                    # 不使用剪切变换
                                                            img_size=[h, w])
        # 根据仿射变换参数计算变换矩阵和逆变换矩阵
        angle = -affine_params[0] * math.pi / 180
        theta_ = torch.tensor([
            [1 / affine_params[2] * math.cos(angle), math.sin(-angle), -affine_params[1][0] / images.shape[2]],
            [math.sin(angle), 1 / affine_params[2] * math.cos(angle), -affine_params[1][1] / images.shape[3]],
            [0, 0, 1]
        ], dtype=torch.float).to(images)
        thetaI_ = theta_.inverse()
        theta_ = theta_[:2]
        thetaI_ = thetaI_[:2]

        # 将变换矩阵和逆变换矩阵叠成一个batch
        theta_ = theta_.unsqueeze(0)
        thetaI_ = thetaI_.unsqueeze(0)
        theta = theta_ if theta is None else torch.cat((theta, theta_)) # 如果theta为None,意味着这是第一个,否则则说明前面已经有参数值了
        thetaI = thetaI_ if thetaI is None else torch.cat((thetaI, thetaI_))

    # 根据变换矩阵生成网格
    # 变换网格(Transformation Grid)是在进行图像的仿射变换时生成的一个规则的坐标网格。它定义了原始图像中每个像素在变换后的目标位置。
    grid = F.affine_grid(theta, images.size(), align_corners=True)
    grid = grid.to(images)  # 将参数移动到GPU上
    grid_inverse = F.affine_grid(thetaI, images.size(), align_corners=True)
    grid_inverse = grid_inverse.to(images)  # 将参数移动到GPU上
    output = F.grid_sample(images, grid, align_corners=True)     # 对图像进行采样得到仿射变换后的图像

    # ---------------------------------------------------------------------
    # 新增: 计算有效区域mask
    # 1. 创建全1mask (表示所有像素初始有效)
    valid_mask = torch.ones_like(images[:, :1])  # (B, 1, H, W)
    # 2. 应用相同的变换到mask上
    transformed_mask = F.grid_sample(valid_mask, grid, align_corners=True)
    # 3. 反向映射回原始图像坐标
    valid_mask = F.grid_sample(transformed_mask, grid_inverse, align_corners=True)
    # 4. 二值化 (大于0.5视为有效)
    valid_mask = (valid_mask > 0.5).float()
    # ---------------------------------------------------------------------

    # 对于用于描述符的情况进行一些额外的随机增强操作
    if used_for == 'descriptor':
        if random.random() >= 0.4:
            # 颜色抖动(其实是灰度变化)
            output = output.repeat(1, 3, 1, 1)   # 将单通道图像复制为三通道(颜色变换需要RGB)
            output = T.ColorJitter(brightness=0.4, contrast=0.3, saturation=0.3, hue=0.2)(output)
            output = T.Grayscale()(output)      # 灰度化

    return output.detach().clone(), grid, grid_inverse, valid_mask.detach().clone()


def pre_processing(data: np.ndarray) -> np.ndarray:
    """
    视网膜图像预处理流水线
    包含标准化、CLAHE对比度增强和伽马校正
    
    参数:
        image: 输入视网膜图像 (单通道,numpy数组)
    
    返回:
        预处理后的图像 (float32, 范围[0,1])
    """
    train_imgs = datasets_normalized(data)
    train_imgs = clahe_equalized(train_imgs)
    train_imgs = adjust_gamma(train_imgs, 1.2)
    train_imgs = train_imgs / 255.  # 最终归一化到[0,1]范围
    return train_imgs.astype(np.float32)


def datasets_normalized(images: np.ndarray) -> np.ndarray:
    # 归一化之后还需要把值映射到0到255
    # images_normalized = np.empty(images.shape)
    # 计算全局统计量
    images_std = np.std(images)
    images_mean = np.mean(images)
    # 应用标准化: (x - mean) / std
    # 添加微小值避免除以零
    images_normalized = (images - images_mean) / (images_std + 1e-6)
    # 线性映射到[0,255]范围
    minv = np.min(images_normalized)
    images_normalized = ((images_normalized - minv) / (np.max(images_normalized) - minv)) * 255

    return images_normalized


def clahe_equalized(images):
    # 对输入图像的 CLAHE(Contrast Limited Adaptive Histogram Equalization)增强处理
    # 用于提高图像的对比度,特别是在光照不均或细节难以分辨的情况下。
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    # clipLimit=2.0: 限制对比度的参数,值越低,增强的对比度越小,避免高对比度导致的过曝光区域。
    # tileGridSize=(8, 8): 将图像分成大小为 8x8 的网格,每个网格单独进行直方图均衡化,减少全局对比度增强引入的伪影。
    images_equalized = np.empty(images.shape)
    images_equalized[:, :] = clahe.apply(np.array(images[:, :], dtype=np.uint8))

    return images_equalized


def adjust_gamma(images, gamma=1.0):
    # γ>1: 压缩高光区域,增强暗部细节 (适合视网膜图像)
    # 创建伽马校正查找表
    invGamma = 1.0 / gamma  # invGamma: 伽马值的倒数,用于生成查找表。
    table = np.array([((i / 255.0) ** invGamma) * 255 for i in np.arange(0, 256)]).astype("uint8")  
    # 预计算伽马校正的转换值,用于快速查找。
    # 每个输入像素值(0-255)都映射到一个经过伽马变换的输出值。
    # 生成过程:
    #     i / 255.0: 将像素值归一化到 [0, 1] 范围。
    #    (i / 255.0) ** invGamma: 应用伽马校正公式。
    #    * 255: 将归一化后的值还原到 [0, 255] 范围。
    #   astype("uint8"): 转换为 uint8 数据类型,适配图像格式。
    new_images = np.empty(images.shape)
    # 应用伽马校正
    new_images[:, :] = cv2.LUT(np.array(images[:, :], dtype=np.uint8), table)
    # cv2.LUT: OpenCV 的快速像素值映射函数。
    # 输入图像的每个像素值通过查找表 table 进行伽马校正。
    # 大幅提高效率,避免逐像素计算。

    return new_images


def simple_nms(scores, nms_radius: int):
    """
    快速非极大值抑制 (NMS) 算法,用于移除相邻关键点
    参数:
        scores: 关键点分数图 (B, H, W) 或 (H, W)
        nms_radius: NMS邻域半径
    返回:
        NMS处理后的分数图
    """
    assert (nms_radius >= 0)
    # 计算NMS窗口大小 (2*半径+1)
    size = nms_radius * 2 + 1
    avg_size = 2
    # 定义最大池化函数 (使用固定步长1和适当填充)
    def max_pool(x):
        return torch.nn.functional.max_pool2d(x, kernel_size=size, stride=1, padding=nms_radius)
    
    # 创建与输入相同形状的零张量
    zeros = torch.zeros_like(scores)
    # max_map = max_pool(scores)

    # 步骤1: 识别局部最大值点
    # 比较每个点与其邻域内的最大值
    max_mask = scores == max_pool(scores)   # max_pool(scores):每个像素点被替换为其局部窗口内的最大值。
    # 步骤2: 添加微小随机扰动 (避免多个相同最大值)
    # 生成 [0, 0.1) 范围内的随机数
    max_mask_ = torch.rand(max_mask.shape).to(max_mask.device) / 10 
    # 生成与 max_mask 相同形状的随机数(范围在 [0, 0.1)),作为微小扰动。
    # 非局部最大值点置零
    max_mask_[~max_mask] = 0
    # 步骤3: 对扰动后的图再次应用NMS
    # 识别扰动后仍然是局部最大值的点
    mask = ((max_mask_ == max_pool(max_mask_)) & (max_mask_ > 0))   # mask:布尔掩码,仅保留扰动后仍然是局部最大值的点。
    # 步骤4: 保留局部最大值点,其他点置零
    return torch.where(mask, scores, zeros) # 如果 mask 为 True,保留原始分数。否则,将得分设置为零。


def remove_borders(keypoints, scores, border: int, height: int, width: int):
    '''
    从提供的关键点和对应分数中移除那些过于靠近图像边缘的关键点
    这个函数是在推理的时候用的,训练的时候没有使用
    keypoints: 关键点坐标的二维数组,形状为 (N, 2),其中每行表示一个关键点的 (y, x) 坐标。
    scores: 每个关键点对应的分数,形状为 (N,)。
    border: 表示需要移除的边界宽度。    推理时预设的是4像素
    height: 图像的高度。
    width: 图像的宽度。
    '''
    mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border))    # 创建高度方向掩码: 关键点必须在 [border, height-border) 范围内
    mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border))     # 创建宽度方向掩码: 关键点必须在 [border, width-border) 范围内
    mask = mask_h & mask_w  # 组合掩码 (必须同时满足高度和宽度条件)
    return keypoints[mask], scores[mask]    # 返回过滤后的关键点和分数


def nms(detector_pred, nms_thresh=0.1, nms_size=10, detector_label=None, mask=False):
    """
    在检测器预测上应用非极大值抑制 (NMS)
    
    参数:
        detector_pred: 检测器预测 (B, 1, H, W)
        nms_thresh: NMS阈值
        nms_size: NMS邻域大小
        detector_label: 检测器标签 (可选)
        mask: 是否使用标签掩码 (当前未实现)
    
    返回:
        关键点位置列表 (每个元素是 (N, 2) 的数组)
    """
    detector_pred = detector_pred.clone().detach()  # 创建预测副本 (避免修改原始数据)
    B, _, h, w = detector_pred.shape    # 获取批次大小和图像尺寸
    # if mask:
    #     assert detector_label is not None
    #     detector_pred[detector_pred < nms_thresh] = 0
    #     label_mask = detector_label
    #
    #     # more area
    #
    #     detector_label = detector_label.long().cpu().numpy()
    #     detector_label = detector_label.astype(np.uint8)
    #     kernel = np.ones((3, 3), np.uint8)
    #     label_mask = np.array([cv2.dilate(detector_label[s, 0], kernel, iterations=1)
    #                            for s in range(len(detector_label))])
    #     label_mask = torch.from_numpy(label_mask).unsqueeze(1)
    #     detector_pred[label_mask > 1e-6] = 0
    scores = simple_nms(detector_pred, nms_size)    # 应用快速NMS算法
    scores = scores.reshape(B, h, w)    # 重塑分数图形状 (B, H, W)
    # print(f"scores before thresh {nms_thresh}", scores)
    points = [torch.nonzero(s > nms_thresh) for s in scores]    # 找出分数高于阈值的点
    # print("points after thresh", points)
    scores = [s[tuple(k.t())] for s, k in zip(scores, points)]  # 提取这些点的分数值
    points, scores = list(zip(*[ remove_borders(k, s, 8, h, w) for k, s in zip(points, scores)]))   # 移除靠近边界的点
    points = [torch.flip(k, [1]).long() for k in points]    # 翻转坐标顺序: [y, x] -> [x, y]

    return points