Spaces:
Sleeping
Sleeping
| from torch.nn import functional as F | |
| import torch | |
| import matplotlib.pyplot as plt | |
| from torchvision import transforms | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| import torch.nn as nn | |
| import scipy.stats as st | |
| # 从提供的关键点和对应分数中移除那些过于靠近图像边缘的关键点 | |
| # 这个函数是在推理的时候用的,训练的时候没有使用 | |
| def remove_borders(keypoints, scores, border: int, height: int, width: int): | |
| """ Removes keypoints too close to the border """ | |
| ''' | |
| keypoints: 关键点坐标的二维数组,形状为 (N, 2),其中每行表示一个关键点的 (y, x) 坐标。 | |
| scores: 每个关键点对应的分数,形状为 (N,)。 | |
| border: 表示需要移除的边界宽度。 推理时预设的是4像素 | |
| height: 图像的高度。 | |
| width: 图像的宽度。 | |
| ''' | |
| # 创建高度方向掩码: 关键点必须在 [border, height-border) 范围内 | |
| mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border)) | |
| # 创建宽度方向掩码: 关键点必须在 [border, width-border) 范围内 | |
| mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border)) | |
| # 组合掩码 (必须同时满足高度和宽度条件) | |
| mask = mask_h & mask_w # 所以这个mask是判定条件 | |
| # 返回过滤后的关键点和分数 | |
| return keypoints[mask], scores[mask] | |
| def simple_nms(scores, nms_radius: int): | |
| """ | |
| 快速非极大值抑制 (NMS) 算法,用于移除相邻关键点 | |
| 参数: | |
| scores: 关键点分数图 (B, H, W) 或 (H, W) | |
| nms_radius: NMS邻域半径 | |
| 返回: | |
| NMS处理后的分数图 | |
| """ | |
| assert (nms_radius >= 0) | |
| # 计算NMS窗口大小 (2*半径+1) | |
| size = nms_radius * 2 + 1 | |
| avg_size = 2 | |
| # 定义最大池化函数 (使用固定步长1和适当填充) | |
| def max_pool(x): | |
| return torch.nn.functional.max_pool2d(x, kernel_size=size, stride=1, padding=nms_radius) | |
| # 创建与输入相同形状的零张量 | |
| zeros = torch.zeros_like(scores) | |
| # max_map = max_pool(scores) | |
| # 步骤1: 识别局部最大值点 | |
| # 比较每个点与其邻域内的最大值 | |
| max_mask = scores == max_pool(scores) # max_pool(scores):每个像素点被替换为其局部窗口内的最大值。 | |
| # 步骤2: 添加微小随机扰动 (避免多个相同最大值) | |
| # 生成 [0, 0.1) 范围内的随机数 | |
| max_mask_ = torch.rand(max_mask.shape).to(max_mask.device) / 10 | |
| # 生成与 max_mask 相同形状的随机数(范围在 [0, 0.1)),作为微小扰动。 | |
| # 非局部最大值点置零 | |
| max_mask_[~max_mask] = 0 | |
| # 步骤3: 对扰动后的图再次应用NMS | |
| # 识别扰动后仍然是局部最大值的点 | |
| mask = ((max_mask_ == max_pool(max_mask_)) & (max_mask_ > 0)) # mask:布尔掩码,仅保留扰动后仍然是局部最大值的点。 | |
| # 步骤4: 保留局部最大值点,其他点置零 | |
| return torch.where(mask, scores, zeros) # 如果 mask 为 True,保留原始分数。否则,将得分设置为零。 | |
| def pre_processing(data): | |
| """ Enhance retinal images """ | |
| train_imgs = datasets_normalized(data) | |
| train_imgs = clahe_equalized(train_imgs) | |
| train_imgs = adjust_gamma(train_imgs, 1.2) | |
| train_imgs = train_imgs / 255. | |
| return train_imgs.astype(np.float32) | |
| def rgb2gray(rgb): | |
| """ Convert RGB image to gray image """ | |
| r, g, b = rgb.split() | |
| return g | |
| # 对输入图像的 CLAHE(Contrast Limited Adaptive Histogram Equalization)增强处理 | |
| # 用于提高图像的对比度,特别是在光照不均或细节难以分辨的情况下。 | |
| def clahe_equalized(images): | |
| clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) | |
| # clipLimit=2.0: 限制对比度的参数,值越低,增强的对比度越小,避免高对比度导致的过曝光区域。 | |
| # tileGridSize=(8, 8): 将图像分成大小为 8x8 的网格,每个网格单独进行直方图均衡化,减少全局对比度增强引入的伪影。 | |
| images_equalized = np.empty(images.shape) | |
| images_equalized[:, :] = clahe.apply(np.array(images[:, :], | |
| dtype=np.uint8)) | |
| return images_equalized | |
| def datasets_normalized(images): | |
| # 归一化之后还需要把值映射到0到255 | |
| # images_normalized = np.empty(images.shape) | |
| images_std = np.std(images) | |
| images_mean = np.mean(images) | |
| images_normalized = (images - images_mean) / (images_std + 1e-6) | |
| minv = np.min(images_normalized) | |
| images_normalized = ((images_normalized - minv) / | |
| (np.max(images_normalized) - minv)) * 255 | |
| return images_normalized | |
| def adjust_gamma(images, gamma=1.0): | |
| invGamma = 1.0 / gamma # invGamma: 伽马值的倒数,用于生成查找表。 | |
| table = np.array([((i / 255.0) ** invGamma) * 255 | |
| for i in np.arange(0, 256)]).astype("uint8") | |
| # 预计算伽马校正的转换值,用于快速查找。 | |
| # 每个输入像素值(0-255)都映射到一个经过伽马变换的输出值。 | |
| # 生成过程: | |
| # i / 255.0: 将像素值归一化到 [0, 1] 范围。 | |
| # (i / 255.0) ** invGamma: 应用伽马校正公式。 | |
| # * 255: 将归一化后的值还原到 [0, 255] 范围。 | |
| # astype("uint8"): 转换为 uint8 数据类型,适配图像格式。 | |
| new_images = np.empty(images.shape) | |
| new_images[:, :] = cv2.LUT(np.array(images[:, :], | |
| dtype=np.uint8), table) | |
| # cv2.LUT: OpenCV 的快速像素值映射函数。 | |
| # 输入图像的每个像素值通过查找表 table 进行伽马校正。 | |
| # 大幅提高效率,避免逐像素计算。 | |
| return new_images | |
| def nms(detector_pred, nms_thresh=0.1, nms_size=10, detector_label=None, mask=False): | |
| """ | |
| 在检测器预测上应用非极大值抑制 (NMS) | |
| 参数: | |
| detector_pred: 检测器预测 (B, 1, H, W) | |
| nms_thresh: NMS阈值 | |
| nms_size: NMS邻域大小 | |
| detector_label: 检测器标签 (可选) | |
| mask: 是否使用标签掩码 (当前未实现) | |
| 返回: | |
| 关键点位置列表 (每个元素是 (N, 2) 的数组) | |
| """ | |
| # 创建预测副本 (避免修改原始数据) | |
| detector_pred = detector_pred.clone().detach() | |
| # 获取批次大小和图像尺寸 | |
| B, _, h, w = detector_pred.shape | |
| # if mask: | |
| # assert detector_label is not None | |
| # detector_pred[detector_pred < nms_thresh] = 0 | |
| # label_mask = detector_label | |
| # | |
| # # more area | |
| # | |
| # detector_label = detector_label.long().cpu().numpy() | |
| # detector_label = detector_label.astype(np.uint8) | |
| # kernel = np.ones((3, 3), np.uint8) | |
| # label_mask = np.array([cv2.dilate(detector_label[s, 0], kernel, iterations=1) | |
| # for s in range(len(detector_label))]) | |
| # label_mask = torch.from_numpy(label_mask).unsqueeze(1) | |
| # detector_pred[label_mask > 1e-6] = 0 | |
| # 应用快速NMS算法 | |
| scores = simple_nms(detector_pred, nms_size) | |
| # 重塑分数图形状 (B, H, W) | |
| scores = scores.reshape(B, h, w) | |
| # 找出分数高于阈值的点 | |
| points = [ | |
| torch.nonzero(s > nms_thresh) | |
| for s in scores] | |
| # 提取这些点的分数值 | |
| scores = [s[tuple(k.t())] for s, k in zip(scores, points)] | |
| # 移除靠近边界的点 | |
| points, scores = list(zip(*[ | |
| remove_borders(k, s, 8, h, w) | |
| for k, s in zip(points, scores)])) | |
| # 翻转坐标顺序: [y, x] -> [x, y] | |
| points = [torch.flip(k, [1]).long() for k in points] | |
| return points | |
| # 实际上模型生成的描述符是1/8尺寸*1/8尺寸的描述符特征,这里是上采样到原尺寸的 | |
| # 这个是在PKE算损失的时候用的 | |
| def sample_keypoint_desc(keypoints, descriptors, s: int = 8): | |
| """ | |
| 在关键点位置采样描述符 | |
| 参数: | |
| keypoints: 关键点坐标 (B, N, 2) 格式 [x, y] | |
| descriptors: 描述符图 (B, C, H, W) | |
| s: 描述符图相对于原始图像的下采样比例 | |
| 返回: | |
| 采样后的描述符 (B, C, N) | |
| """ | |
| # 获取描述符张量的形状信息,用于后续处理 | |
| b, c, h, w = descriptors.shape # 原始输入 descriptors: (b, c, h, w) | |
| # 克隆关键点并将其转换为浮点类型,以便进行坐标计算 | |
| keypoints = keypoints.clone().float() | |
| # 将关键点坐标归一化到范围 (0, 1) | |
| keypoints /= torch.tensor([(w * s - 1), (h * s - 1)]).to(keypoints)[None] | |
| # 将关键点坐标缩放到范围 (-1, 1),以适应 grid_sample 函数的要求 | |
| keypoints = keypoints * 2 - 1 | |
| # 根据 PyTorch 版本准备 grid_sample 函数的参数,确保兼容性 | |
| args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {} | |
| # 使用 grid_sample 函数在关键点位置插值描述符 | |
| descriptors = torch.nn.functional.grid_sample( | |
| descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args) # 经过 grid_sample: (b, c, 1, n) n是关键点数量 | |
| # 对描述符进行 L2 归一化,使其长度为 1(1个像素),以便后续处理 | |
| descriptors = torch.nn.functional.normalize( | |
| descriptors.reshape(b, c, -1), p=2, dim=1) # reshape 后: (b, c, n) channel=256 | |
| # 返回处理后的描述符 | |
| return descriptors | |
| # 这个是在模型算损失的时候用的,可以同时处理关键点和被映射后的关键点损失 | |
| def sample_descriptors(detector_pred, descriptor_pred, affine_descriptor_pred, grid_inverse, | |
| nms_size=10, nms_thresh=0.1, scale=8, affine_detector_pred=None): | |
| """ | |
| 基于关键点采样描述符 | |
| 参数: | |
| detector_pred: 原始图像的检测器预测 (B, 1, H, W) | |
| descriptor_pred: 原始图像的描述符预测 (B, C, H, W) | |
| affine_descriptor_pred: 仿射图像的描述符预测 (B, C, H, W) | |
| grid_inverse: 逆变换网格 (B, H, W, 2) | |
| nms_size: NMS邻域大小 | |
| nms_thresh: NMS阈值 | |
| scale: 描述符图相对于原始图像的下采样比例 | |
| affine_detector_pred: 仿射图像的检测器预测 (可选) | |
| 返回: | |
| descriptors: 原始图像关键点的描述符列表 | |
| affine_descriptors: 仿射图像对应关键点的描述符列表 | |
| keypoints: 原始图像的关键点位置列表 | |
| """ | |
| # 获取批次大小和图像尺寸 | |
| B, _, h, w = detector_pred.shape | |
| # 应用NMS获取关键点位置 | |
| keypoints = nms(detector_pred, nms_size=nms_size, nms_thresh=nms_thresh) | |
| # 使用逆变换网格将关键点映射到仿射空间 | |
| affine_keypoints = [ | |
| grid_inverse[s, k[:, 1].long(), k[:, 0].long()] # 使用网格插值 | |
| for s, k in enumerate(keypoints) | |
| ] | |
| # 初始化存储列表 | |
| kp = [] # 过滤后的原始关键点 | |
| affine_kp = [] # 过滤后的仿射关键点 | |
| # 处理每个样本 | |
| for s, k in enumerate(affine_keypoints): | |
| # 过滤超出仿射图像边界的点 | |
| idx = (k[:, 0] < 1) & (k[:, 0] > -1) & (k[:, 1] < 1) & (k[:, 1] > -1) | |
| # 存储过滤后的原始关键点 | |
| kp.append(keypoints[s][idx]) | |
| # 获取过滤后的仿射关键点 | |
| ak = k[idx] | |
| # 将归一化坐标转换回像素坐标 | |
| ak[:, 0] = (ak[:, 0] + 1) / 2 * (w - 1) # x坐标 | |
| ak[:, 1] = (ak[:, 1] + 1) / 2 * (h - 1) # y坐标 | |
| # 存储转换后的仿射关键点 | |
| affine_kp.append(ak) | |
| # 在原始图像关键点位置采样描述符 | |
| descriptors = [ | |
| sample_keypoint_desc(k[None], d[None], s=scale)[0] | |
| for k, d in zip(kp, descriptor_pred) | |
| ] | |
| # 在仿射图像关键点位置采样描述符 | |
| affine_descriptors = [ | |
| sample_keypoint_desc(k[None], d[None], s=scale)[0] | |
| for k, d in zip(affine_kp, affine_descriptor_pred) | |
| ] | |
| return descriptors, affine_descriptors, keypoints | |