| from typing import Union |
|
|
| import torch |
| from torch import nn, Tensor |
| from torch.nn.functional import normalize |
|
|
| from auxiliary.settings import USE_CONFIDENCE_WEIGHTED_POOLING |
| from classes.fc4.squeezenet.SqueezeNetLoader import SqueezeNetLoader |
|
|
| """ |
| FC4: Fully Convolutional Color Constancy with Confidence-weighted Pooling |
| * Original code: https://github.com/yuanming-hu/fc4 |
| * Paper: https://www.microsoft.com/en-us/research/publication/fully-convolutional-color-constancy-confidence-weighted-pooling/ |
| """ |
|
|
|
|
class FC4(torch.nn.Module):
    """
    FC4 illuminant estimator: a truncated SqueezeNet backbone followed by a small convolutional head that
    produces per-patch colour estimates, which are pooled into a single unit-norm RGB estimate per image.
    """

    def __init__(self, squeezenet_version: float = 1.1, *, pretrained: bool = True,
                 use_confidence_weighted_pooling: Union[bool, None] = None):
        """
        Build the backbone and the final convolutional head
        @param squeezenet_version: SqueezeNet version handed to SqueezeNetLoader
        @param pretrained: forwarded to SqueezeNetLoader.load(); True (the default) keeps the previous behaviour
        @param use_confidence_weighted_pooling: whether the head emits an extra confidence channel and forward()
        uses confidence-weighted pooling. None (the default) falls back to the USE_CONFIDENCE_WEIGHTED_POOLING
        setting. The value is resolved once here so that the head's channel count and forward() always agree
        """
        super().__init__()

        if use_confidence_weighted_pooling is None:
            use_confidence_weighted_pooling = USE_CONFIDENCE_WEIGHTED_POOLING
        # Stored on the instance (instead of re-reading the global in forward) so the pooling branch always matches
        # the number of output channels the head below was built with
        self.use_confidence_weighted_pooling = bool(use_confidence_weighted_pooling)

        # Backbone: the first 12 layers of the loaded network's first child module
        # (presumably SqueezeNet's `features` stack — confirm against SqueezeNetLoader)
        squeezenet = SqueezeNetLoader(squeezenet_version).load(pretrained=pretrained)
        self.backbone = nn.Sequential(*list(squeezenet.children())[0][:12])

        # Head: assumes the truncated backbone outputs 512 channels (the first Conv2d below requires it).
        # The last 1x1 conv emits 3 RGB channels, plus a 4th confidence channel when confidence pooling is on
        self.final_convs = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True),
            nn.Conv2d(512, 64, kernel_size=6, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(64, 4 if self.use_confidence_weighted_pooling else 3, kernel_size=1, stride=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: Tensor) -> Union[tuple, Tensor]:
        """
        Estimate an RGB colour for the illuminant of the input image
        @param x: the image for which the colour of the illuminant has to be estimated
        @return: the colour estimate as a (batch, 3) Tensor normalised along dim 1. If confidence-weighted pooling
        is used, a tuple (pred, rgb, confidence) is returned instead, where rgb holds the per-patch colour
        estimates and confidence the per-patch weights (used for visualizations)
        """
        # Conv2d layers produce (N, C, H, W) maps, so dims 2 and 3 are the spatial axes pooled below
        out = self.final_convs(self.backbone(x))

        if self.use_confidence_weighted_pooling:
            # Channels 0-2: per-patch colour estimates, unit-normalised along the channel axis
            rgb = normalize(out[:, :3, :, :], dim=1)

            # Channel 3: per-patch confidence; sliced as 3:4 to keep a singleton channel that broadcasts over rgb
            confidence = out[:, 3:4, :, :]

            # Confidence-weighted pooling: weight each patch estimate, sum over space, then renormalise
            pred = normalize((rgb * confidence).sum(dim=(2, 3)), dim=1)

            return pred, rgb, confidence

        # Summation pooling: sum the raw per-patch estimates over space, then normalise
        return normalize(out.sum(dim=(2, 3)), dim=1)
|
|