| | |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
class Histogram_Matching(nn.Module):
    """Match the per-channel intensity histogram of ``dst`` to that of ``ref``.

    Each (batch, channel) plane is handled independently:
    1. a 256-bin cumulative histogram (CDF) is computed for it;
    2. a 256-entry lookup table mapping ``dst`` levels to ``ref`` levels is
       derived from the two CDFs;
    3. every pixel of the ``dst`` plane is remapped through that plane's
       own table.

    Args:
        differentiable: If True, histograms are built with a sigmoid-windowed
            soft binning (``soft_histc_batch``) instead of ``torch.histc``.

    NOTE(review): the lookup table is derived with a hard ``>=`` comparison and
    applied through integer indexing. Even with ``differentiable=True``, no
    gradient reaches ``dst``/``ref`` through ``forward``. Confirm this is
    intended.
    """

    def __init__(self, differentiable=False):
        super().__init__()
        self.differentiable = differentiable

    def forward(self, dst, ref):
        """Remap ``dst`` so its per-channel histograms follow those of ``ref``.

        Args:
            dst: Tensor of shape (B, C, H, W). Values are expected in [0, 1];
                they are scaled by 255 and truncated to integer levels.
            ref: Reference tensor on the same device, presumably with the same
                B and C as ``dst``. H and W may differ because histograms are
                normalized.

        Returns:
            Tensor with the shape and dtype of ``dst``, values in [0, 1].
        """
        B, C, H, W = dst.size()

        assert dst.device == ref.device

        cdf_dst = self.cal_hist(dst)
        cdf_ref = self.cal_hist(ref)

        # One lookup table per plane; rows are ordered b-major, so the table
        # for plane (b, c) is row b * C + c.
        tables = self.cal_trans_batch(cdf_dst, cdf_ref)

        # Truncate to integer levels, matching the unit-width histogram bins.
        # Clamping keeps values slightly outside [0, 1] from raising
        # IndexError or wrapping around via negative indices.
        levels = (dst * 255).long().clamp(0, tables.size(1) - 1)
        rst = torch.gather(tables, 1, levels.reshape(B * C, -1))

        return (rst.reshape(B, C, H, W) / 255.).to(dst.dtype)

    def cal_hist(self, img):
        """Return the normalized cumulative histogram (CDF) of every plane.

        Args:
            img: Tensor of shape (B, C, H, W), values expected in [0, 1].

        Returns:
            Float32 tensor of shape (B * C, 256), rows ordered b-major
            (row ``b * C + c``). Each row is non-decreasing and ends at ~1,
            or is all zeros when the plane has no in-range values.
        """
        B, C, H, W = img.size()

        if self.differentiable:
            hists = self.soft_histc_batch(img * 255, bins=256, min=0, max=256, sigma=3 * 25)
        else:
            # Unit-width bins over [0, 256): bin k holds levels in [k, k + 1).
            # This matches the soft path and the truncating lookup in forward().
            planes = (img * 255).reshape(B * C, -1)
            hists = torch.stack([torch.histc(p, bins=256, min=0, max=256) for p in planes])

        # L1-normalize each row so it sums to 1, then accumulate into a CDF.
        hists = F.normalize(hists.float(), p=1)
        return torch.cumsum(hists, dim=1)

    def soft_histc_batch(self, x, bins=256, min=0, max=256, sigma=3*25):
        """Soft (sigmoid-windowed) histogram of each (b, c) plane of ``x``.

        Let ``delta = (max - min) / bins``. Bin k is centered at
        ``min + (k + 0.5) * delta``. A value at distance ``d`` from a center
        contributes ``sigmoid(sigma*(d + delta/2)) - sigmoid(sigma*(d - delta/2))``
        to that bin, which tends to a hard 0/1 count as ``sigma`` grows.

        Args:
            x: Tensor of shape (B, C, H, W).
            bins: Number of bins.
            min: Lower edge of the binned range.
            max: Upper edge of the binned range.
            sigma: Steepness of the sigmoid window.

        Returns:
            Float32 tensor of shape (B * C, bins), rows ordered b-major.
        """
        B, C, H, W = x.size()

        x = x.reshape(B * C, 1, -1)

        delta = float(max - min) / float(bins)

        # Centers are built in float32. bfloat16 has only 8 significant bits,
        # so half-integer centers >= 128 (e.g. 128.5) would round to integers,
        # merging neighbouring bins and leaving gaps between others.
        centers = float(min) + delta * (
            torch.arange(bins, device=x.device, dtype=torch.float32) + 0.5
        )

        # (B*C, bins, H*W) distances to every center. The large intermediate is
        # kept in bfloat16 (presumably to limit memory; confirm).
        d = (x - centers[None, :, None]).type(torch.bfloat16)

        d = torch.sigmoid(sigma * (d + delta / 2)) - torch.sigmoid(sigma * (d - delta / 2))

        # Accumulate counts in float32: bfloat16 cannot represent all
        # integers above 256, so per-bin pixel counts would be rounded.
        return d.sum(dim=2, dtype=torch.float32)

    def cal_trans_batch(self, hist_dst, hist_ref):
        """Build one level-mapping lookup table per row from two batches of CDFs.

        For each row and source level ``j``, the entry is
        ``#{i : cdf_ref[i] <= cdf_dst[j]} - 1``, clamped to ``[0, n - 1]``.
        For a non-decreasing ``cdf_ref`` this is the largest level ``i`` with
        ``cdf_ref[i] <= cdf_dst[j]``.

        NOTE(review): when ``cdf_ref`` has flat stretches (empty bins), a level
        maps to the end of the stretch rather than its start. Self-matching a
        sparse histogram is therefore not the identity.

        Args:
            hist_dst: CDFs of shape (N, n) for the image being remapped.
            hist_ref: CDFs of shape (N, n), or broadcastable to it, for the
                reference.

        Returns:
            Tensor of shape (N, n) with ``hist_dst``'s dtype, holding target
            levels (as floats) in ``[0, n - 1]``.
        """
        n = hist_dst.size(-1)

        # diff[r, i, j] = cdf_dst[r, j] - cdf_ref[r, i]. Broadcasting avoids
        # materialising two repeated (N, n, n) tensors.
        diff = hist_dst[:, None, :] - hist_ref[:, :, None]

        counts = (diff >= 0).sum(dim=1).to(hist_dst.dtype)

        return torch.clamp(counts - 1, min=0, max=n - 1)
| |
|