# NOTE: captured from a Hugging Face Space page (the "Spaces: Runtime error"
# banner was page chrome, not part of the module).
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.autograd import Variable | |
| import numpy as np | |
| from math import exp | |
class MultiScaleDerivativeLoss(nn.Module):
    """Multi-scale image-derivative loss.

    Compares first-order (Scharr) or second-order (Laplacian) derivatives of
    prediction and ground truth at several pyramid scales and averages the
    per-scale losses.
    """

    def __init__(self, operator='scharr', p=1, reduction='mean',
                 normalize_input=False, num_scales=4):
        """
        Args:
            operator: 'scharr' (first derivative) or 'laplace' (second derivative).
            p: 1 for L1, 2 for L2 penalty on the derivative difference.
            reduction: 'mean' or 'sum', applied per scale.
            normalize_input: L2-normalize the channel vectors first
                (useful when inputs are normal maps).
            num_scales: pyramid depth (e.g. 4 = full, 1/2, 1/4, 1/8 resolution).

        Raises:
            ValueError: on an invalid operator / p / reduction / num_scales.
        """
        super().__init__()
        # Validate eagerly with real exceptions (assert is stripped under -O).
        if operator not in ('scharr', 'laplace'):
            raise ValueError(f"operator must be 'scharr' or 'laplace', got {operator!r}")
        if p not in (1, 2):
            raise ValueError(f"p must be 1 or 2, got {p!r}")
        if reduction not in ('mean', 'sum'):
            raise ValueError(f"reduction must be 'mean' or 'sum', got {reduction!r}")
        if num_scales < 1:
            raise ValueError(f"num_scales must be >= 1, got {num_scales!r}")
        self.operator = operator
        self.p = p
        self.reduction = reduction
        self.normalize_input = normalize_input
        self.num_scales = num_scales

    def forward(self, pred, gt):
        """Compute the loss between two [B, C, H, W] tensors.

        Returns:
            Scalar tensor: mean over scales of the reduced derivative difference.
        """
        pred_pyramid = self._build_pyramid(pred)
        gt_pyramid = self._build_pyramid(gt)
        total_loss = 0.0
        for pred_i, gt_i in zip(pred_pyramid, gt_pyramid):
            if self.normalize_input:
                pred_i = F.normalize(pred_i, dim=1)
                gt_i = F.normalize(gt_i, dim=1)
            grad_pred = self._compute_gradient(pred_i)
            grad_gt = self._compute_gradient(gt_i)
            diff = grad_pred - grad_gt
            diff = torch.abs(diff) if self.p == 1 else diff ** 2
            total_loss += diff.mean() if self.reduction == 'mean' else diff.sum()
        return total_loss / self.num_scales

    def _build_pyramid(self, img):
        """Return [img, 1/2, 1/4, ...]: each level halves the previous one.

        BUGFIX: the original passed ``scale_factor=0.5**i`` to the *already
        downsampled* image, so the effective scales compounded to
        1, 1/2, 1/8, 1/64, ... instead of the documented 1, 1/2, 1/4, 1/8.
        Halving once per step restores the intended pyramid.
        """
        pyramid = [img]
        for _ in range(1, self.num_scales):
            img = F.interpolate(img, scale_factor=0.5, mode='bicubic',
                                align_corners=False,
                                recompute_scale_factor=True, antialias=True)
            pyramid.append(img)
        return pyramid

    def _compute_gradient(self, img):
        """Apply the configured derivative operator channel-wise (depthwise conv).

        Returns [B, 2C, H, W] (grad_x ++ grad_y) for 'scharr',
        or [B, C, H, W] for 'laplace'.
        """
        B, C, H, W = img.shape
        device, dtype = img.device, img.dtype
        if self.operator == 'scharr':
            # Match input dtype so float64 inputs don't break F.conv2d.
            kernel_x = torch.tensor([[[-3., 0., 3.],
                                      [-10., 0., 10.],
                                      [-3., 0., 3.]]],
                                    device=device, dtype=dtype) / 16.0
            kernel_y = torch.tensor([[[-3., -10., -3.],
                                      [0., 0., 0.],
                                      [3., 10., 3.]]],
                                    device=device, dtype=dtype) / 16.0
            kernel_x = kernel_x.unsqueeze(0).expand(C, 1, 3, 3)
            kernel_y = kernel_y.unsqueeze(0).expand(C, 1, 3, 3)
            # groups=C -> each channel convolved independently.
            grad_x = F.conv2d(img, kernel_x, padding=1, groups=C)
            grad_y = F.conv2d(img, kernel_y, padding=1, groups=C)
            return torch.cat([grad_x, grad_y], dim=1)  # [B, 2C, H, W]
        # operator == 'laplace' (validated in __init__)
        kernel = torch.tensor([[[0., 1., 0.],
                                [1., -4., 1.],
                                [0., 1., 0.]]], device=device, dtype=dtype)
        kernel = kernel.unsqueeze(0).expand(C, 1, 3, 3)
        return F.conv2d(img, kernel, padding=1, groups=C)  # [B, C, H, W]
class CosineLoss(torch.nn.Module):
    """Cosine-similarity loss for normal maps, plus a scaled MSE term.

    Returns ``(cosine_loss, mse)`` where ``cosine_loss`` is the mean of
    ``1 - <N, N_hat>`` over pixels whose ground-truth normal is non-zero.
    """

    def __init__(self):
        super(CosineLoss, self).__init__()

    def forward(self, N, N_hat):
        """
        Args:
            N: ground-truth normals, shape (B, C, H, W) -- assumed unit length
               where valid, since the dot product is used as cosine directly
               (TODO confirm with the data pipeline).
            N_hat: predicted normals, same shape as N.

        Returns:
            (loss, mse): masked mean of (1 - dot product), and the MSE scaled
            by ``H * W / 2048`` (magic normalization constant kept from the
            original implementation -- presumably a reference resolution).
        """
        _, _, H, W = N.shape
        # Valid-pixel mask: True where the GT normal has non-zero L2 norm.
        mask = (N.norm(p=2, dim=1, keepdim=True) > 0)  # (B, 1, H, W)
        mse = F.mse_loss(N, N_hat, reduction='mean') * H * W / 2048
        dot_product = torch.sum(N * N_hat, dim=1, keepdim=True)  # (B, 1, H, W)
        # Only compute the cosine loss over valid (non-zero) pixels.
        loss = (1 - dot_product)[mask]
        if loss.numel() == 0:
            # BUGFIX: an all-zero GT map made mean() of an empty tensor NaN.
            return torch.zeros((), device=N.device, dtype=N.dtype), mse
        return loss.mean(), mse
def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length ``window_size`` normalized to sum to 1."""
    center = window_size // 2
    weights = [exp(-(i - center) ** 2 / float(2 * sigma ** 2))
               for i in range(window_size)]
    kernel = torch.Tensor(weights)
    return kernel / kernel.sum()
def create_window(window_size, channel):
    """Build a (channel, 1, window_size, window_size) Gaussian window for grouped conv.

    The 2-D window is the outer product of a 1-D Gaussian (sigma=1.5) with
    itself, expanded across channels for use with ``F.conv2d(groups=channel)``.

    Note: the original wrapped the result in ``torch.autograd.Variable``,
    which has been a deprecated no-op since PyTorch 0.4; a plain tensor is
    returned instead (same runtime type for all modern PyTorch versions).
    """
    _1d = gaussian(window_size, 1.5).unsqueeze(1)
    _2d = _1d.mm(_1d.t()).float().unsqueeze(0).unsqueeze(0)
    return _2d.expand(channel, 1, window_size, window_size).contiguous()
| def _ssim(img1, img2, window, window_size, channel, size_average = True, stride=None): | |
| mu1 = F.conv2d(img1, window, padding = (window_size-1)//2, groups = channel, stride=stride) | |
| mu2 = F.conv2d(img2, window, padding = (window_size-1)//2, groups = channel, stride=stride) | |
| mu1_sq = mu1.pow(2) | |
| mu2_sq = mu2.pow(2) | |
| mu1_mu2 = mu1*mu2 | |
| sigma1_sq = F.conv2d(img1*img1, window, padding = (window_size-1)//2, groups = channel, stride=stride) - mu1_sq | |
| sigma2_sq = F.conv2d(img2*img2, window, padding = (window_size-1)//2, groups = channel, stride=stride) - mu2_sq | |
| sigma12 = F.conv2d(img1*img2, window, padding = (window_size-1)//2, groups = channel, stride=stride) - mu1_mu2 | |
| C1 = 0.01**2 | |
| C2 = 0.03**2 | |
| ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) | |
| if size_average: | |
| return ssim_map.mean() | |
| else: | |
| return ssim_map.mean(1).mean(1).mean(1) | |
class SSIM(torch.nn.Module):
    """SSIM module that caches its Gaussian window between calls.

    The window is rebuilt whenever the input's channel count, dtype, or
    device no longer matches the cached one.
    """

    def __init__(self, window_size=3, size_average=True, stride=3):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1          # channels the cached window was built for
        self.stride = stride
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """
        Args:
            img1, img2: torch.Tensor([b, c, h, w])
        Returns:
            SSIM score(s) from ``_ssim`` (scalar when size_average).
        """
        (_, channel, _, _) = img1.size()
        # Replaces the deprecated `.data.type()` string comparison and
        # `.cuda(get_device())` with dtype/device checks and `.to(...)`.
        if (channel == self.channel
                and self.window.dtype == img1.dtype
                and self.window.device == img1.device):
            window = self.window
        else:
            window = create_window(self.window_size, channel).to(
                device=img1.device, dtype=img1.dtype)
            # Cache for subsequent calls with the same input layout.
            self.window = window
            self.channel = channel
        return _ssim(img1, img2, window, self.window_size, channel,
                     self.size_average, stride=self.stride)
def ssim(img1, img2, window_size=11, size_average=True):
    """Functional SSIM: builds a Gaussian window per call, stride fixed at 1.

    Args:
        img1, img2: (B, C, H, W) tensors.
        window_size: Gaussian window size (default 11, the classic SSIM choice).
        size_average: forwarded to ``_ssim``.
    """
    (_, channel, _, _) = img1.size()
    # `.to(img1)` covers both device and dtype; replaces the deprecated
    # `.cuda(img1.get_device())` + `.type_as(...)` pair.
    window = create_window(window_size, channel).to(img1)
    # BUGFIX: pass stride=1 explicitly -- _ssim's default of None was
    # forwarded to F.conv2d, which rejects a None stride.
    return _ssim(img1, img2, window, window_size, channel, size_average, stride=1)
class S3IM(torch.nn.Module):
    """Stochastic Structural SIMilarity (S3IM) loss.

    Rearranges flat pixel vectors into pseudo-image patches -- one identity
    ordering plus ``repeat_time - 1`` random shufflings, with the same index
    order applied to source and target -- and scores them with SSIM.
    """

    def __init__(self, kernel_size=4, stride=4, repeat_time=10,
                 patch_height=64, patch_width=32):
        super(S3IM, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.repeat_time = repeat_time
        self.patch_height = patch_height
        self.patch_width = patch_width
        self.ssim_loss = SSIM(window_size=self.kernel_size, stride=self.stride)

    def forward(self, src_vec, tar_vec):
        """
        Args:
            src_vec: [B, N, C] e.g., [batch, pixels, channels]
            tar_vec: [B, N, C]

        Returns:
            loss: scalar tensor, ``1 - mean SSIM``.

        Raises:
            ValueError: if N != patch_height * patch_width -- the reshape
                below would otherwise fail with an opaque runtime error.
        """
        B, N, C = src_vec.shape
        if N != self.patch_height * self.patch_width:
            raise ValueError(
                f"expected N == patch_height * patch_width "
                f"({self.patch_height * self.patch_width}), got N={N}")
        device = src_vec.device
        patch_list_src, patch_list_tar = [], []
        for b in range(B):
            index_list = []
            for i in range(self.repeat_time):
                # First pass keeps the natural pixel order; the rest shuffle.
                if i == 0:
                    tmp_index = torch.arange(N, device=device)
                else:
                    tmp_index = torch.randperm(N, device=device)
                index_list.append(tmp_index)
            res_index = torch.cat(index_list)   # [repeat_time * N]
            # Same ordering for target and source so SSIM compares like-with-like.
            tar_all = tar_vec[b][res_index]     # [repeat_time * N, C]
            src_all = src_vec[b][res_index]     # [repeat_time * N, C]
            # Fold the repeats into one wide image: [1, C, H, W * repeat_time].
            tar_patch = tar_all.permute(1, 0).reshape(
                1, C, self.patch_height, self.patch_width * self.repeat_time)
            src_patch = src_all.permute(1, 0).reshape(
                1, C, self.patch_height, self.patch_width * self.repeat_time)
            patch_list_tar.append(tar_patch)
            patch_list_src.append(src_patch)
        # Stack all batches: [B, C, H, W * repeat_time].
        tar_tensor = torch.cat(patch_list_tar, dim=0)
        src_tensor = torch.cat(patch_list_src, dim=0)
        # SSIM module averages over the batch; the loss is 1 - mean SSIM.
        ssim_scores = self.ssim_loss(src_tensor, tar_tensor)
        return 1.0 - ssim_scores
# Fixed RNG seed so the (commented-out) demo below is reproducible.
torch.manual_seed(0)
# Example usage: assume each image contributes 64 x 32 pixels, 3 channels each.
# H, W, C = 64, 32, 3
# N = H * W
# B = 4
# # Randomly generate two image feature tensors: [B, N, C]
# src_vec = torch.rand(B, N, C)  # simulated reconstructed image
# tar_vec = torch.rand(B, N, C)  # simulated ground-truth image
# # Instantiate the S3IM module
# s3im_loss_fn = S3IM(kernel_size=4, stride=4, repeat_time=10, patch_height=64, patch_width=32)
# # Compute the loss
# loss = s3im_loss_fn(src_vec, tar_vec)
def weighted_huber_loss(
    input: torch.Tensor,
    target: torch.Tensor,
    weight: torch.Tensor,  # per-element confidence weights
    reduction: str = 'mean',
    delta: float = 1.0,
) -> torch.Tensor:
    """Huber loss with per-element confidence weights.

    For d = input - target::

        loss_i = w_i * 0.5 * d_i**2                    if |d_i| <= delta
                 w_i * delta * (|d_i| - 0.5 * delta)   otherwise

    Args:
        input, target, weight: mutually broadcastable tensors.
        reduction: 'mean', 'sum', or 'none'.
        delta: transition point between the quadratic and linear regimes;
            must be positive.

    Returns:
        A scalar for 'mean'/'sum'; the broadcast element-wise loss for 'none'.

    Raises:
        ValueError: on an unsupported ``reduction`` or non-positive ``delta``.
    """
    # Fail fast on bad arguments (the original only caught bad reduction
    # after doing all the work, and accepted a meaningless delta <= 0).
    if reduction not in ('mean', 'sum', 'none'):
        raise ValueError(f"Unsupported reduction: {reduction}")
    if delta <= 0:
        raise ValueError(f"delta must be positive, got {delta}")
    # Broadcast all three operands to a common shape in one call.
    input_b, target_b, weight_b = torch.broadcast_tensors(input, target, weight)
    diff = input_b - target_b
    abs_diff = torch.abs(diff)
    # Piecewise Huber: quadratic core, linear tails.
    loss = torch.where(
        abs_diff <= delta,
        0.5 * (diff ** 2),
        delta * (abs_diff - 0.5 * delta),
    )
    weighted_loss = weight_b * loss
    if reduction == 'mean':
        return torch.mean(weighted_loss)
    if reduction == 'sum':
        return torch.sum(weighted_loss)
    return weighted_loss