Spaces:
Sleeping
Sleeping
File size: 12,095 Bytes
bddd311 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 | from torch.nn import functional as F
import torch
import matplotlib.pyplot as plt
from torchvision import transforms
import numpy as np
from PIL import Image
import cv2
import torch.nn as nn
import scipy.stats as st
# 从提供的关键点和对应分数中移除那些过于靠近图像边缘的关键点
# 这个函数是在推理的时候用的,训练的时候没有使用
def remove_borders(keypoints, scores, border: int, height: int, width: int):
""" Removes keypoints too close to the border """
'''
keypoints: 关键点坐标的二维数组,形状为 (N, 2),其中每行表示一个关键点的 (y, x) 坐标。
scores: 每个关键点对应的分数,形状为 (N,)。
border: 表示需要移除的边界宽度。 推理时预设的是4像素
height: 图像的高度。
width: 图像的宽度。
'''
# 创建高度方向掩码: 关键点必须在 [border, height-border) 范围内
mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border))
# 创建宽度方向掩码: 关键点必须在 [border, width-border) 范围内
mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border))
# 组合掩码 (必须同时满足高度和宽度条件)
mask = mask_h & mask_w # 所以这个mask是判定条件
# 返回过滤后的关键点和分数
return keypoints[mask], scores[mask]
def simple_nms(scores, nms_radius: int):
"""
快速非极大值抑制 (NMS) 算法,用于移除相邻关键点
参数:
scores: 关键点分数图 (B, H, W) 或 (H, W)
nms_radius: NMS邻域半径
返回:
NMS处理后的分数图
"""
assert (nms_radius >= 0)
# 计算NMS窗口大小 (2*半径+1)
size = nms_radius * 2 + 1
avg_size = 2
# 定义最大池化函数 (使用固定步长1和适当填充)
def max_pool(x):
return torch.nn.functional.max_pool2d(x, kernel_size=size, stride=1, padding=nms_radius)
# 创建与输入相同形状的零张量
zeros = torch.zeros_like(scores)
# max_map = max_pool(scores)
# 步骤1: 识别局部最大值点
# 比较每个点与其邻域内的最大值
max_mask = scores == max_pool(scores) # max_pool(scores):每个像素点被替换为其局部窗口内的最大值。
# 步骤2: 添加微小随机扰动 (避免多个相同最大值)
# 生成 [0, 0.1) 范围内的随机数
max_mask_ = torch.rand(max_mask.shape).to(max_mask.device) / 10
# 生成与 max_mask 相同形状的随机数(范围在 [0, 0.1)),作为微小扰动。
# 非局部最大值点置零
max_mask_[~max_mask] = 0
# 步骤3: 对扰动后的图再次应用NMS
# 识别扰动后仍然是局部最大值的点
mask = ((max_mask_ == max_pool(max_mask_)) & (max_mask_ > 0)) # mask:布尔掩码,仅保留扰动后仍然是局部最大值的点。
# 步骤4: 保留局部最大值点,其他点置零
return torch.where(mask, scores, zeros) # 如果 mask 为 True,保留原始分数。否则,将得分设置为零。
def pre_processing(data):
""" Enhance retinal images """
train_imgs = datasets_normalized(data)
train_imgs = clahe_equalized(train_imgs)
train_imgs = adjust_gamma(train_imgs, 1.2)
train_imgs = train_imgs / 255.
return train_imgs.astype(np.float32)
def rgb2gray(rgb):
""" Convert RGB image to gray image """
r, g, b = rgb.split()
return g
# 对输入图像的 CLAHE(Contrast Limited Adaptive Histogram Equalization)增强处理
# 用于提高图像的对比度,特别是在光照不均或细节难以分辨的情况下。
def clahe_equalized(images):
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
# clipLimit=2.0: 限制对比度的参数,值越低,增强的对比度越小,避免高对比度导致的过曝光区域。
# tileGridSize=(8, 8): 将图像分成大小为 8x8 的网格,每个网格单独进行直方图均衡化,减少全局对比度增强引入的伪影。
images_equalized = np.empty(images.shape)
images_equalized[:, :] = clahe.apply(np.array(images[:, :],
dtype=np.uint8))
return images_equalized
def datasets_normalized(images):
# 归一化之后还需要把值映射到0到255
# images_normalized = np.empty(images.shape)
images_std = np.std(images)
images_mean = np.mean(images)
images_normalized = (images - images_mean) / (images_std + 1e-6)
minv = np.min(images_normalized)
images_normalized = ((images_normalized - minv) /
(np.max(images_normalized) - minv)) * 255
return images_normalized
def adjust_gamma(images, gamma=1.0):
invGamma = 1.0 / gamma # invGamma: 伽马值的倒数,用于生成查找表。
table = np.array([((i / 255.0) ** invGamma) * 255
for i in np.arange(0, 256)]).astype("uint8")
# 预计算伽马校正的转换值,用于快速查找。
# 每个输入像素值(0-255)都映射到一个经过伽马变换的输出值。
# 生成过程:
# i / 255.0: 将像素值归一化到 [0, 1] 范围。
# (i / 255.0) ** invGamma: 应用伽马校正公式。
# * 255: 将归一化后的值还原到 [0, 255] 范围。
# astype("uint8"): 转换为 uint8 数据类型,适配图像格式。
new_images = np.empty(images.shape)
new_images[:, :] = cv2.LUT(np.array(images[:, :],
dtype=np.uint8), table)
# cv2.LUT: OpenCV 的快速像素值映射函数。
# 输入图像的每个像素值通过查找表 table 进行伽马校正。
# 大幅提高效率,避免逐像素计算。
return new_images
def nms(detector_pred, nms_thresh=0.1, nms_size=10, detector_label=None, mask=False):
"""
在检测器预测上应用非极大值抑制 (NMS)
参数:
detector_pred: 检测器预测 (B, 1, H, W)
nms_thresh: NMS阈值
nms_size: NMS邻域大小
detector_label: 检测器标签 (可选)
mask: 是否使用标签掩码 (当前未实现)
返回:
关键点位置列表 (每个元素是 (N, 2) 的数组)
"""
# 创建预测副本 (避免修改原始数据)
detector_pred = detector_pred.clone().detach()
# 获取批次大小和图像尺寸
B, _, h, w = detector_pred.shape
# if mask:
# assert detector_label is not None
# detector_pred[detector_pred < nms_thresh] = 0
# label_mask = detector_label
#
# # more area
#
# detector_label = detector_label.long().cpu().numpy()
# detector_label = detector_label.astype(np.uint8)
# kernel = np.ones((3, 3), np.uint8)
# label_mask = np.array([cv2.dilate(detector_label[s, 0], kernel, iterations=1)
# for s in range(len(detector_label))])
# label_mask = torch.from_numpy(label_mask).unsqueeze(1)
# detector_pred[label_mask > 1e-6] = 0
# 应用快速NMS算法
scores = simple_nms(detector_pred, nms_size)
# 重塑分数图形状 (B, H, W)
scores = scores.reshape(B, h, w)
# 找出分数高于阈值的点
points = [
torch.nonzero(s > nms_thresh)
for s in scores]
# 提取这些点的分数值
scores = [s[tuple(k.t())] for s, k in zip(scores, points)]
# 移除靠近边界的点
points, scores = list(zip(*[
remove_borders(k, s, 8, h, w)
for k, s in zip(points, scores)]))
# 翻转坐标顺序: [y, x] -> [x, y]
points = [torch.flip(k, [1]).long() for k in points]
return points
# 实际上模型生成的描述符是1/8尺寸*1/8尺寸的描述符特征,这里是上采样到原尺寸的
# 这个是在PKE算损失的时候用的
def sample_keypoint_desc(keypoints, descriptors, s: int = 8):
"""
在关键点位置采样描述符
参数:
keypoints: 关键点坐标 (B, N, 2) 格式 [x, y]
descriptors: 描述符图 (B, C, H, W)
s: 描述符图相对于原始图像的下采样比例
返回:
采样后的描述符 (B, C, N)
"""
# 获取描述符张量的形状信息,用于后续处理
b, c, h, w = descriptors.shape # 原始输入 descriptors: (b, c, h, w)
# 克隆关键点并将其转换为浮点类型,以便进行坐标计算
keypoints = keypoints.clone().float()
# 将关键点坐标归一化到范围 (0, 1)
keypoints /= torch.tensor([(w * s - 1), (h * s - 1)]).to(keypoints)[None]
# 将关键点坐标缩放到范围 (-1, 1),以适应 grid_sample 函数的要求
keypoints = keypoints * 2 - 1
# 根据 PyTorch 版本准备 grid_sample 函数的参数,确保兼容性
args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {}
# 使用 grid_sample 函数在关键点位置插值描述符
descriptors = torch.nn.functional.grid_sample(
descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args) # 经过 grid_sample: (b, c, 1, n) n是关键点数量
# 对描述符进行 L2 归一化,使其长度为 1(1个像素),以便后续处理
descriptors = torch.nn.functional.normalize(
descriptors.reshape(b, c, -1), p=2, dim=1) # reshape 后: (b, c, n) channel=256
# 返回处理后的描述符
return descriptors
# 这个是在模型算损失的时候用的,可以同时处理关键点和被映射后的关键点损失
def sample_descriptors(detector_pred, descriptor_pred, affine_descriptor_pred, grid_inverse,
nms_size=10, nms_thresh=0.1, scale=8, affine_detector_pred=None):
"""
基于关键点采样描述符
参数:
detector_pred: 原始图像的检测器预测 (B, 1, H, W)
descriptor_pred: 原始图像的描述符预测 (B, C, H, W)
affine_descriptor_pred: 仿射图像的描述符预测 (B, C, H, W)
grid_inverse: 逆变换网格 (B, H, W, 2)
nms_size: NMS邻域大小
nms_thresh: NMS阈值
scale: 描述符图相对于原始图像的下采样比例
affine_detector_pred: 仿射图像的检测器预测 (可选)
返回:
descriptors: 原始图像关键点的描述符列表
affine_descriptors: 仿射图像对应关键点的描述符列表
keypoints: 原始图像的关键点位置列表
"""
# 获取批次大小和图像尺寸
B, _, h, w = detector_pred.shape
# 应用NMS获取关键点位置
keypoints = nms(detector_pred, nms_size=nms_size, nms_thresh=nms_thresh)
# 使用逆变换网格将关键点映射到仿射空间
affine_keypoints = [
grid_inverse[s, k[:, 1].long(), k[:, 0].long()] # 使用网格插值
for s, k in enumerate(keypoints)
]
# 初始化存储列表
kp = [] # 过滤后的原始关键点
affine_kp = [] # 过滤后的仿射关键点
# 处理每个样本
for s, k in enumerate(affine_keypoints):
# 过滤超出仿射图像边界的点
idx = (k[:, 0] < 1) & (k[:, 0] > -1) & (k[:, 1] < 1) & (k[:, 1] > -1)
# 存储过滤后的原始关键点
kp.append(keypoints[s][idx])
# 获取过滤后的仿射关键点
ak = k[idx]
# 将归一化坐标转换回像素坐标
ak[:, 0] = (ak[:, 0] + 1) / 2 * (w - 1) # x坐标
ak[:, 1] = (ak[:, 1] + 1) / 2 * (h - 1) # y坐标
# 存储转换后的仿射关键点
affine_kp.append(ak)
# 在原始图像关键点位置采样描述符
descriptors = [
sample_keypoint_desc(k[None], d[None], s=scale)[0]
for k, d in zip(kp, descriptor_pred)
]
# 在仿射图像关键点位置采样描述符
affine_descriptors = [
sample_keypoint_desc(k[None], d[None], s=scale)[0]
for k, d in zip(affine_kp, affine_descriptor_pred)
]
return descriptors, affine_descriptors, keypoints
|