Spaces:
Sleeping
Sleeping
| import math | |
| import os | |
| import random | |
| import shutil | |
| import time | |
| import torch | |
| from torchvision import transforms as T | |
| from torch.nn import functional as F | |
| import numpy as np | |
| import cv2 | |
| import scipy.stats as st | |
| # import config | |
| from tqdm import tqdm | |
| import matplotlib.pyplot as plt | |
| def get_gaussian_kernel(kernlen=21, nsig=5): | |
| """ | |
| 生成二维高斯核并返回一个不可训练的核权重,用于创建高斯热图 | |
| 参数: | |
| kernlen (int): 核的大小(边长),必须是奇数 | |
| nsig (float): 高斯分布的标准差(控制分布的宽度) | |
| 返回: | |
| torch.Tensor: 形状为(1, 1, kernlen, kernlen)的高斯核张量 | |
| """ | |
| # 1. 计算采样间隔 | |
| # 在[-nsig, nsig]范围内均匀采样kernlen个点 | |
| interval = (2 * nsig + 1.) / kernlen | |
| # 2. 创建一维坐标点数组 | |
| # 从-nsig-interval/2到nsig+interval/2,共kernlen+1个点 | |
| x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1) | |
| # 3. 计算一维高斯核 | |
| # 使用标准正态分布的累积分布函数(CDF)的差分 | |
| kern1d = np.diff(st.norm.cdf(x)) | |
| # 4. 创建二维高斯核 | |
| # 通过两个一维高斯核的外积(outer product) | |
| kernel_raw = np.sqrt(np.outer(kern1d, kern1d)) | |
| # 5. 归一化核 | |
| # 使所有元素之和为1,确保核的总"能量"为1 | |
| kernel = kernel_raw / kernel_raw.sum() | |
| # 6. 转换为PyTorch张量并调整维度 | |
| # 添加批次和通道维度 (1, 1, H, W) | |
| kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0) | |
| # 7. 创建不可训练的高斯核权重参数 | |
| weight = torch.nn.Parameter(data=kernel, requires_grad=False) | |
| # 8. 将值归一化到[0, 1]范围 | |
| # 使最小值为0,最大值为1 | |
| weight = (weight - weight.min()) / (weight.max() - weight.min()) | |
| return weight | |
| # 对用于监督的点标签进行高斯模糊 | |
| def label_gaussian_blur(label_point_positions, kernel, stride=1): | |
| ''' | |
| 对用于监督的点标签进行高斯模糊 | |
| 该函数通过应用高斯卷积核来模糊标签点,以生成更平滑的热图 | |
| 参数: | |
| label_point_positions: 标签点的位置,形状为(B, C, H, W)的张量 | |
| kernel: 高斯卷积核,形状为(1, 1, Hk, Wk)的张量 | |
| stride: 卷积步长,默认为1 | |
| 返回值: | |
| blurred_label: 模糊后的标签热图,形状与输入相同 | |
| ''' | |
| # 应用高斯卷积进行模糊 | |
| blurred_label = F.conv2d(label_point_positions, kernel, stride=stride, padding=(kernel.shape[-1] - 1) // 2) | |
| # 裁剪值范围:限制热图最大值为1 | |
| blurred_label[blurred_label > 1] = 1 | |
| return blurred_label | |
| def affine_images(images, used_for='detector'): | |
| ''' | |
| 根据detector和descriptor两种计算loss和训练的情况分别生成两种不同的仿射变及其逆变换,支持数据增强 | |
| 参数: | |
| images: 输入图像张量,形状为 (B, C, H, W) | |
| used_for: 指定变换用途 - 'detector'用于检测器,'descriptor'用于描述器 | |
| 返回值: | |
| CPU上的仿射变换后的图像 output.detach().clone() | |
| 仿射变换网格 grid | |
| 逆仿射变换网格 grid_inverse | |
| ''' | |
| """ | |
| Perform affine transformation on images | |
| :param images: (B, C, H, W) | |
| :param keypoint_labels: corresponding labels | |
| :param value_map: value maps, used to record history learned geo_points | |
| :return: results of affine images, affine labels, affine value maps, affine transformed grid_inverse, inverse transformed grid_inverse | |
| """ | |
| h, w = images.shape[2:] # 获取图像的尺寸 | |
| theta = None | |
| thetaI = None | |
| # 对一个batch中的每个图像进行随机仿射变换参数的生成 | |
| for i in range(len(images)): | |
| if used_for == 'detector': | |
| affine_params = T.RandomAffine(20).get_params(degrees=[-15, 15], # 旋转角度范围:±15度 | |
| translate=[0.2, 0.2], # 平移范围:±20% | |
| scale_ranges=[0.9, 1.35], # 缩放范围:0.9-1.35倍 | |
| shears=None, # 不使用剪切变换 | |
| img_size=[h, w]) | |
| else: | |
| affine_params = T.RandomAffine(20).get_params(degrees=[-3, 3], # 旋转角度范围:±3度 | |
| translate=[0.1, 0.1], # 平移范围:宽高的10% | |
| scale_ranges=[0.9, 1.1], # 缩放范围:0.9到1.1倍 | |
| shears=None, # 不使用剪切变换 | |
| img_size=[h, w]) | |
| # 根据仿射变换参数计算变换矩阵和逆变换矩阵 | |
| angle = -affine_params[0] * math.pi / 180 | |
| theta_ = torch.tensor([ | |
| [1 / affine_params[2] * math.cos(angle), math.sin(-angle), -affine_params[1][0] / images.shape[2]], | |
| [math.sin(angle), 1 / affine_params[2] * math.cos(angle), -affine_params[1][1] / images.shape[3]], | |
| [0, 0, 1] | |
| ], dtype=torch.float).to(images) | |
| thetaI_ = theta_.inverse() | |
| theta_ = theta_[:2] | |
| thetaI_ = thetaI_[:2] | |
| # 将变换矩阵和逆变换矩阵叠成一个batch | |
| theta_ = theta_.unsqueeze(0) | |
| thetaI_ = thetaI_.unsqueeze(0) | |
| theta = theta_ if theta is None else torch.cat((theta, theta_)) # 如果theta为None,意味着这是第一个,否则则说明前面已经有参数值了 | |
| thetaI = thetaI_ if thetaI is None else torch.cat((thetaI, thetaI_)) | |
| # 根据变换矩阵生成网格 | |
| # 变换网格(Transformation Grid)是在进行图像的仿射变换时生成的一个规则的坐标网格。它定义了原始图像中每个像素在变换后的目标位置。 | |
| grid = F.affine_grid(theta, images.size(), align_corners=True) | |
| grid = grid.to(images) # 将参数移动到GPU上 | |
| grid_inverse = F.affine_grid(thetaI, images.size(), align_corners=True) | |
| grid_inverse = grid_inverse.to(images) # 将参数移动到GPU上 | |
| output = F.grid_sample(images, grid, align_corners=True) # 对图像进行采样得到仿射变换后的图像 | |
| # 对于用于描述符的情况进行一些额外的随机增强操作 | |
| if used_for == 'descriptor': | |
| if random.random() >= 0.4: | |
| # 颜色抖动(其实是灰度变化) | |
| output = output.repeat(1, 3, 1, 1) # 将单通道图像复制为三通道(颜色变换需要RGB) | |
| output = T.ColorJitter(brightness=0.4, contrast=0.3, saturation=0.3, hue=0.2)(output) | |
| output = T.Grayscale()(output) # 灰度化 | |
| return output.detach().clone(), grid, grid_inverse | |
| def affine_images_with_mask(images, used_for='detector'): | |
| ''' | |
| 仿射变换图像后,额外返回一个公共区域mask | |
| 根据detector和descriptor两种计算loss和训练的情况分别生成两种不同的仿射变及其逆变换,支持数据增强 | |
| 参数: | |
| images: 输入图像张量,形状为 (B, C, H, W) | |
| used_for: 指定变换用途 - 'detector'用于检测器,'descriptor'用于描述器 | |
| 返回值: | |
| CPU上的仿射变换后的图像 output.detach().clone() | |
| 仿射变换网格 grid | |
| 逆仿射变换网格 grid_inverse | |
| ''' | |
| """ | |
| Perform affine transformation on images | |
| :param images: (B, C, H, W) | |
| :param keypoint_labels: corresponding labels | |
| :param value_map: value maps, used to record history learned geo_points | |
| :return: results of affine images, affine labels, affine value maps, affine transformed grid_inverse, inverse transformed grid_inverse | |
| """ | |
| h, w = images.shape[2:] # 获取图像的尺寸 | |
| theta = None | |
| thetaI = None | |
| # --------------------------------------------------------------------- | |
| # 新增: 创建原始图像坐标网格 | |
| base_grid = F.affine_grid(torch.eye(2, 3).unsqueeze(0), (1, 1, h, w), align_corners=True) | |
| base_grid = base_grid.expand(images.size(0), *base_grid.shape[1:]) | |
| # --------------------------------------------------------------------- | |
| # 对一个batch中的每个图像进行随机仿射变换参数的生成 | |
| for i in range(len(images)): | |
| if used_for == 'detector': | |
| affine_params = T.RandomAffine(20).get_params(degrees=[-15, 15], # 旋转角度范围:±15度 | |
| translate=[0.2, 0.2], # 平移范围:±20% | |
| scale_ranges=[0.9, 1.35], # 缩放范围:0.9-1.35倍 | |
| shears=None, # 不使用剪切变换 | |
| img_size=[h, w]) | |
| else: | |
| affine_params = T.RandomAffine(20).get_params(degrees=[-3, 3], # 旋转角度范围:±3度 | |
| translate=[0.1, 0.1], # 平移范围:宽高的10% | |
| scale_ranges=[0.9, 1.1], # 缩放范围:0.9到1.1倍 | |
| shears=None, # 不使用剪切变换 | |
| img_size=[h, w]) | |
| # 根据仿射变换参数计算变换矩阵和逆变换矩阵 | |
| angle = -affine_params[0] * math.pi / 180 | |
| theta_ = torch.tensor([ | |
| [1 / affine_params[2] * math.cos(angle), math.sin(-angle), -affine_params[1][0] / images.shape[2]], | |
| [math.sin(angle), 1 / affine_params[2] * math.cos(angle), -affine_params[1][1] / images.shape[3]], | |
| [0, 0, 1] | |
| ], dtype=torch.float).to(images) | |
| thetaI_ = theta_.inverse() | |
| theta_ = theta_[:2] | |
| thetaI_ = thetaI_[:2] | |
| # 将变换矩阵和逆变换矩阵叠成一个batch | |
| theta_ = theta_.unsqueeze(0) | |
| thetaI_ = thetaI_.unsqueeze(0) | |
| theta = theta_ if theta is None else torch.cat((theta, theta_)) # 如果theta为None,意味着这是第一个,否则则说明前面已经有参数值了 | |
| thetaI = thetaI_ if thetaI is None else torch.cat((thetaI, thetaI_)) | |
| # 根据变换矩阵生成网格 | |
| # 变换网格(Transformation Grid)是在进行图像的仿射变换时生成的一个规则的坐标网格。它定义了原始图像中每个像素在变换后的目标位置。 | |
| grid = F.affine_grid(theta, images.size(), align_corners=True) | |
| grid = grid.to(images) # 将参数移动到GPU上 | |
| grid_inverse = F.affine_grid(thetaI, images.size(), align_corners=True) | |
| grid_inverse = grid_inverse.to(images) # 将参数移动到GPU上 | |
| output = F.grid_sample(images, grid, align_corners=True) # 对图像进行采样得到仿射变换后的图像 | |
| # --------------------------------------------------------------------- | |
| # 新增: 计算有效区域mask | |
| # 1. 创建全1mask (表示所有像素初始有效) | |
| valid_mask = torch.ones_like(images[:, :1]) # (B, 1, H, W) | |
| # 2. 应用相同的变换到mask上 | |
| transformed_mask = F.grid_sample(valid_mask, grid, align_corners=True) | |
| # 3. 反向映射回原始图像坐标 | |
| valid_mask = F.grid_sample(transformed_mask, grid_inverse, align_corners=True) | |
| # 4. 二值化 (大于0.5视为有效) | |
| valid_mask = (valid_mask > 0.5).float() | |
| # --------------------------------------------------------------------- | |
| # 对于用于描述符的情况进行一些额外的随机增强操作 | |
| if used_for == 'descriptor': | |
| if random.random() >= 0.4: | |
| # 颜色抖动(其实是灰度变化) | |
| output = output.repeat(1, 3, 1, 1) # 将单通道图像复制为三通道(颜色变换需要RGB) | |
| output = T.ColorJitter(brightness=0.4, contrast=0.3, saturation=0.3, hue=0.2)(output) | |
| output = T.Grayscale()(output) # 灰度化 | |
| return output.detach().clone(), grid, grid_inverse, valid_mask.detach().clone() | |
| def pre_processing(data: np.ndarray) -> np.ndarray: | |
| """ | |
| 视网膜图像预处理流水线 | |
| 包含标准化、CLAHE对比度增强和伽马校正 | |
| 参数: | |
| image: 输入视网膜图像 (单通道,numpy数组) | |
| 返回: | |
| 预处理后的图像 (float32, 范围[0,1]) | |
| """ | |
| train_imgs = datasets_normalized(data) | |
| train_imgs = clahe_equalized(train_imgs) | |
| train_imgs = adjust_gamma(train_imgs, 1.2) | |
| train_imgs = train_imgs / 255. # 最终归一化到[0,1]范围 | |
| return train_imgs.astype(np.float32) | |
| def datasets_normalized(images: np.ndarray) -> np.ndarray: | |
| # 归一化之后还需要把值映射到0到255 | |
| # images_normalized = np.empty(images.shape) | |
| # 计算全局统计量 | |
| images_std = np.std(images) | |
| images_mean = np.mean(images) | |
| # 应用标准化: (x - mean) / std | |
| # 添加微小值避免除以零 | |
| images_normalized = (images - images_mean) / (images_std + 1e-6) | |
| # 线性映射到[0,255]范围 | |
| minv = np.min(images_normalized) | |
| images_normalized = ((images_normalized - minv) / (np.max(images_normalized) - minv)) * 255 | |
| return images_normalized | |
| def clahe_equalized(images): | |
| # 对输入图像的 CLAHE(Contrast Limited Adaptive Histogram Equalization)增强处理 | |
| # 用于提高图像的对比度,特别是在光照不均或细节难以分辨的情况下。 | |
| clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) | |
| # clipLimit=2.0: 限制对比度的参数,值越低,增强的对比度越小,避免高对比度导致的过曝光区域。 | |
| # tileGridSize=(8, 8): 将图像分成大小为 8x8 的网格,每个网格单独进行直方图均衡化,减少全局对比度增强引入的伪影。 | |
| images_equalized = np.empty(images.shape) | |
| images_equalized[:, :] = clahe.apply(np.array(images[:, :], dtype=np.uint8)) | |
| return images_equalized | |
| def adjust_gamma(images, gamma=1.0): | |
| # γ>1: 压缩高光区域,增强暗部细节 (适合视网膜图像) | |
| # 创建伽马校正查找表 | |
| invGamma = 1.0 / gamma # invGamma: 伽马值的倒数,用于生成查找表。 | |
| table = np.array([((i / 255.0) ** invGamma) * 255 for i in np.arange(0, 256)]).astype("uint8") | |
| # 预计算伽马校正的转换值,用于快速查找。 | |
| # 每个输入像素值(0-255)都映射到一个经过伽马变换的输出值。 | |
| # 生成过程: | |
| # i / 255.0: 将像素值归一化到 [0, 1] 范围。 | |
| # (i / 255.0) ** invGamma: 应用伽马校正公式。 | |
| # * 255: 将归一化后的值还原到 [0, 255] 范围。 | |
| # astype("uint8"): 转换为 uint8 数据类型,适配图像格式。 | |
| new_images = np.empty(images.shape) | |
| # 应用伽马校正 | |
| new_images[:, :] = cv2.LUT(np.array(images[:, :], dtype=np.uint8), table) | |
| # cv2.LUT: OpenCV 的快速像素值映射函数。 | |
| # 输入图像的每个像素值通过查找表 table 进行伽马校正。 | |
| # 大幅提高效率,避免逐像素计算。 | |
| return new_images | |
| def simple_nms(scores, nms_radius: int): | |
| """ | |
| 快速非极大值抑制 (NMS) 算法,用于移除相邻关键点 | |
| 参数: | |
| scores: 关键点分数图 (B, H, W) 或 (H, W) | |
| nms_radius: NMS邻域半径 | |
| 返回: | |
| NMS处理后的分数图 | |
| """ | |
| assert (nms_radius >= 0) | |
| # 计算NMS窗口大小 (2*半径+1) | |
| size = nms_radius * 2 + 1 | |
| avg_size = 2 | |
| # 定义最大池化函数 (使用固定步长1和适当填充) | |
| def max_pool(x): | |
| return torch.nn.functional.max_pool2d(x, kernel_size=size, stride=1, padding=nms_radius) | |
| # 创建与输入相同形状的零张量 | |
| zeros = torch.zeros_like(scores) | |
| # max_map = max_pool(scores) | |
| # 步骤1: 识别局部最大值点 | |
| # 比较每个点与其邻域内的最大值 | |
| max_mask = scores == max_pool(scores) # max_pool(scores):每个像素点被替换为其局部窗口内的最大值。 | |
| # 步骤2: 添加微小随机扰动 (避免多个相同最大值) | |
| # 生成 [0, 0.1) 范围内的随机数 | |
| max_mask_ = torch.rand(max_mask.shape).to(max_mask.device) / 10 | |
| # 生成与 max_mask 相同形状的随机数(范围在 [0, 0.1)),作为微小扰动。 | |
| # 非局部最大值点置零 | |
| max_mask_[~max_mask] = 0 | |
| # 步骤3: 对扰动后的图再次应用NMS | |
| # 识别扰动后仍然是局部最大值的点 | |
| mask = ((max_mask_ == max_pool(max_mask_)) & (max_mask_ > 0)) # mask:布尔掩码,仅保留扰动后仍然是局部最大值的点。 | |
| # 步骤4: 保留局部最大值点,其他点置零 | |
| return torch.where(mask, scores, zeros) # 如果 mask 为 True,保留原始分数。否则,将得分设置为零。 | |
| def remove_borders(keypoints, scores, border: int, height: int, width: int): | |
| ''' | |
| 从提供的关键点和对应分数中移除那些过于靠近图像边缘的关键点 | |
| 这个函数是在推理的时候用的,训练的时候没有使用 | |
| keypoints: 关键点坐标的二维数组,形状为 (N, 2),其中每行表示一个关键点的 (y, x) 坐标。 | |
| scores: 每个关键点对应的分数,形状为 (N,)。 | |
| border: 表示需要移除的边界宽度。 推理时预设的是4像素 | |
| height: 图像的高度。 | |
| width: 图像的宽度。 | |
| ''' | |
| mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border)) # 创建高度方向掩码: 关键点必须在 [border, height-border) 范围内 | |
| mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border)) # 创建宽度方向掩码: 关键点必须在 [border, width-border) 范围内 | |
| mask = mask_h & mask_w # 组合掩码 (必须同时满足高度和宽度条件) | |
| return keypoints[mask], scores[mask] # 返回过滤后的关键点和分数 | |
| def nms(detector_pred, nms_thresh=0.1, nms_size=10, detector_label=None, mask=False): | |
| """ | |
| 在检测器预测上应用非极大值抑制 (NMS) | |
| 参数: | |
| detector_pred: 检测器预测 (B, 1, H, W) | |
| nms_thresh: NMS阈值 | |
| nms_size: NMS邻域大小 | |
| detector_label: 检测器标签 (可选) | |
| mask: 是否使用标签掩码 (当前未实现) | |
| 返回: | |
| 关键点位置列表 (每个元素是 (N, 2) 的数组) | |
| """ | |
| detector_pred = detector_pred.clone().detach() # 创建预测副本 (避免修改原始数据) | |
| B, _, h, w = detector_pred.shape # 获取批次大小和图像尺寸 | |
| # if mask: | |
| # assert detector_label is not None | |
| # detector_pred[detector_pred < nms_thresh] = 0 | |
| # label_mask = detector_label | |
| # | |
| # # more area | |
| # | |
| # detector_label = detector_label.long().cpu().numpy() | |
| # detector_label = detector_label.astype(np.uint8) | |
| # kernel = np.ones((3, 3), np.uint8) | |
| # label_mask = np.array([cv2.dilate(detector_label[s, 0], kernel, iterations=1) | |
| # for s in range(len(detector_label))]) | |
| # label_mask = torch.from_numpy(label_mask).unsqueeze(1) | |
| # detector_pred[label_mask > 1e-6] = 0 | |
| scores = simple_nms(detector_pred, nms_size) # 应用快速NMS算法 | |
| scores = scores.reshape(B, h, w) # 重塑分数图形状 (B, H, W) | |
| # print(f"scores before thresh {nms_thresh}", scores) | |
| points = [torch.nonzero(s > nms_thresh) for s in scores] # 找出分数高于阈值的点 | |
| # print("points after thresh", points) | |
| scores = [s[tuple(k.t())] for s, k in zip(scores, points)] # 提取这些点的分数值 | |
| points, scores = list(zip(*[ remove_borders(k, s, 8, h, w) for k, s in zip(points, scores)])) # 移除靠近边界的点 | |
| points = [torch.flip(k, [1]).long() for k in points] # 翻转坐标顺序: [y, x] -> [x, y] | |
| return points | |