Spaces:
Running on Zero
Running on Zero
| import torch | |
| from torch import nn | |
class LocalFusion(nn.Module):
    """Refine each category's input with its own attention module and blend them.

    One ``Self_Attn`` module is created per category. ``forward`` pairs the
    modules with the slices of ``x`` along its first axis, runs each slice
    through its module and merges the results into a single tensor.

    Args:
        att_in_dim: Channel count handed to every ``Self_Attn`` as ``in_dim``.
        num_categories: Number of attention modules to create.
        max_pool_ksize1: Forwarded unchanged to every ``Self_Attn``.
        max_pool_ksize2: Forwarded unchanged to every ``Self_Attn``.
        encoder_dims: Forwarded unchanged to every ``Self_Attn``.
    """

    def __init__(self, att_in_dim=3, num_categories=6, max_pool_ksize1=4,
                 max_pool_ksize2=2, encoder_dims=(8, 16)):
        super().__init__()
        self.num_categories = num_categories
        self.att_in_dim = att_in_dim
        self.attention_fusion = nn.ModuleList([
            Self_Attn(
                in_dim=att_in_dim,
                max_pool_ksize1=max_pool_ksize1,
                max_pool_ksize2=max_pool_ksize2,
                encoder_dims=encoder_dims,
            )
            for _ in range(num_categories)
        ])

    def forward(self, x, color_naming_probs=None, q=None):
        """Blend the attention-refined category slices of ``x``.

        Args:
            x: Iterable of per-category tensors (a tensor is iterated along
                its first axis). It is zipped with the attention modules, so
                extra slices or extra modules are silently ignored.
            color_naming_probs: Optional per-category weights with one fewer
                dimension than the stacked outputs (no channel axis); the
                code indexes it as ``(category, batch, H, W)``. When
                ``None`` the outputs are averaged uniformly.
            q: Optional query tensor shared by every category. When ``None``
                each category slice is used as its own query.

        Returns:
            The blended tensor with the category axis reduced away. With
            ``color_naming_probs`` the result is additionally clipped to
            ``[0, 1]``.
        """
        # Each module attends over its own slice. Without an explicit query
        # the slice itself is the query (not the whole stacked input, which
        # would hand a 5-D tensor to the 2-D convolutions).
        refined = torch.stack(
            [
                att(x_color, q=x_color if q is None else q)
                for att, x_color in zip(self.attention_fusion, x)
            ],
            dim=0,
        )

        if color_naming_probs is None:
            # Plain average over the category axis only, so the spatial
            # layout of the blend is preserved.
            return torch.mean(refined, dim=0)

        # A category takes part in a pixel's blend only when its probability
        # exceeds the threshold; participating categories weigh equally.
        selected = (color_naming_probs > 0.20).float()
        hits = selected.sum(dim=0, keepdim=True)
        # Where no category passes the threshold the original division was
        # 0 / 0 (NaN); fall back to a uniform average over all categories.
        selected = torch.where(hits > 0, selected, torch.ones_like(selected))
        denom = selected.sum(dim=0).unsqueeze(1)

        # unsqueeze adds a singleton channel axis that broadcasts against
        # the refined outputs, so no fixed channel count is assumed here.
        blended = torch.sum(refined * selected.unsqueeze(2), dim=0) / denom
        return torch.clip(blended, 0, 1)
class Self_Attn(nn.Module):
    """Self-attention computed on a pooled feature map and added to the input.

    Query, key and value are produced by three identically shaped encoders
    (1x1 conv -> ReLU -> max-pool, twice). Attention is evaluated at the
    pooled resolution, decoded back to 3 channels, resized to the input's
    spatial size and added to the input as a residual.

    Args:
        in_dim: Channel count of the tensors fed to the encoders.
        max_pool_ksize1: Kernel size and stride of the first pooling stage;
            also the scale of the intermediate nearest-neighbour upsampling.
        max_pool_ksize2: Kernel size and stride of the second pooling stage.
        encoder_dims: Two channel widths, one per encoder stage.

    Note:
        The decoder always emits 3 channels, so the residual addition in
        ``forward`` needs ``x`` to have a channel count that broadcasts
        against 3.
    """

    def __init__(self, in_dim, max_pool_ksize1=4, max_pool_ksize2=2, encoder_dims=(8, 16)):
        super().__init__()
        self.chanel_in = in_dim
        self.max_pool_ksize1 = max_pool_ksize1
        self.max_pool_ksize2 = max_pool_ksize2
        self.down_ratio = max_pool_ksize1 * max_pool_ksize2

        # Construction order and layer indices are kept as before so that
        # parameter initialisation and state_dict keys are unchanged.
        self.query_conv = self._make_encoder(in_dim, encoder_dims, max_pool_ksize1, max_pool_ksize2)
        self.key_conv = self._make_encoder(in_dim, encoder_dims, max_pool_ksize1, max_pool_ksize2)
        self.value_conv = self._make_encoder(in_dim, encoder_dims, max_pool_ksize1, max_pool_ksize2)

        self.upsample = nn.Sequential(
            nn.Conv2d(in_channels=encoder_dims[1], out_channels=encoder_dims[0], kernel_size=1),
            nn.ReLU(),
            # Undo the first pooling stage; the remaining size gap is closed
            # by the bilinear resize in forward().
            nn.UpsamplingNearest2d(scale_factor=max_pool_ksize1),
            nn.Conv2d(in_channels=encoder_dims[0], out_channels=encoder_dims[0], kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=encoder_dims[0], out_channels=3, kernel_size=1),
            nn.ReLU())
        self.last_conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1)

        # Neither gamma nor max_pool is used by forward(); gamma is kept so
        # that existing checkpoints still load with strict key matching.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.max_pool = nn.MaxPool2d(2, 2)
        self.softmax = nn.Softmax(dim=-1)

    @staticmethod
    def _make_encoder(in_dim, encoder_dims, ksize1, ksize2):
        """Build one two-stage 1x1-conv / ReLU / max-pool encoder."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_dim, out_channels=encoder_dims[0], kernel_size=1),
            nn.ReLU(),
            # Pooling follows the configured kernel sizes so the actual
            # downsampling always agrees with ``down_ratio``.
            nn.MaxPool2d(ksize1, ksize1),
            nn.Conv2d(in_channels=encoder_dims[0], out_channels=encoder_dims[1], kernel_size=1),
            nn.ReLU(),
            nn.MaxPool2d(ksize2, ksize2))

    def forward(self, x, q=None):
        """Apply attention to ``x`` and return ``x`` plus the attended residual.

        Args:
            x: 4-D input tensor; keys and values are computed from it.
            q: Optional tensor the queries are computed from; defaults to
                ``x``. Presumably expected to match ``x`` in batch and
                spatial size -- TODO confirm against callers.

        Returns:
            ``last_conv(resized attention output) + x``.
        """
        if q is None:
            q = x

        query_map = self.query_conv(q)
        pooled_hw = query_map.shape[2:]

        # Flatten the pooled spatial grid into a sequence of positions.
        proj_query = query_map.flatten(2).permute(0, 2, 1)   # (B, N, C)
        proj_key = self.key_conv(x).flatten(2)               # (B, C, N)
        proj_value = self.value_conv(x).flatten(2)           # (B, C, N)

        # Softmax over the last axis normalises each query's weights
        # across all key positions.
        attention = self.softmax(torch.bmm(proj_query, proj_key))
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))

        # Restore the pooled spatial grid; the channel count is taken from
        # the tensor itself rather than assumed to be 16.
        out = out.reshape(out.size(0), out.size(1), *pooled_hw)
        out = self.upsample(out)
        # Bilinear resize to the exact input size, which also covers inputs
        # whose sides are not multiples of the pooling factor.
        out = nn.functional.interpolate(out, size=x.shape[2:], mode='bilinear', align_corners=False)
        out = self.last_conv(out)
        return out + x