Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
def freeze_weights(module):
    """Freeze ``module`` in place by turning off gradient tracking on all of its parameters.

    Args:
        module: object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        None.
    """
    params = module.parameters()
    for tensor in params:
        tensor.requires_grad_(False)
def l1_regularize(module):
    """L1 penalty over the trainable weight tensors listed in ``module.reg_params``.

    Only entries whose key contains the substring "weight" and whose tensor
    has ``requires_grad`` set contribute; each contributes the sum of its
    absolute values.

    Args:
        module: object with a ``reg_params`` mapping of name -> tensor
            (not a standard ``nn.Module`` attribute; presumably populated
            elsewhere — TODO confirm against callers).

    Returns:
        The Python float ``0.`` when no entry qualifies, otherwise a scalar tensor.
    """
    per_tensor_l1 = (
        tensor.abs().sum()
        for name, tensor in module.reg_params.items()
        if "weight" in name and tensor.requires_grad
    )
    return sum(per_tensor_l1, 0.)
class SeparableConv2d(nn.Module):
    """Depthwise-separable 2-D convolution.

    A depthwise ``kernel_size`` convolution (``groups=in_channels``, so one
    filter per input channel) followed by a 1x1 pointwise convolution that
    maps ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()
        # depthwise stage: channel count unchanged, spatial filtering only
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        # pointwise stage: 1x1 conv mixing channels
        self.pointwise = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=bias,
        )

    def forward(self, x):
        depthwise_out = self.conv1(x)
        return self.pointwise(depthwise_out)
class Block(nn.Module):
    """Residual block built from separable convolutions.

    The main path is ``reps`` units of (ReLU -> 3x3 SeparableConv2d ->
    optional BatchNorm2d), followed by a 3x3 max-pool with stride
    ``strides`` when ``strides != 1``. The shortcut is a 1x1 strided conv
    (plus BatchNorm2d when ``with_bn``) whenever the channel count or the
    stride changes, and the identity otherwise.

    Args:
        in_channels, out_channels: channel counts of the block input / output.
        reps: number of ReLU/SeparableConv2d(/BN) units.
        strides: downsampling factor applied by the trailing max-pool and the shortcut conv.
        start_with_relu: if False the leading ReLU is dropped; if True it is
            made non-in-place so the block input (reused by the shortcut) is
            not overwritten.
        grow_first: widen to ``out_channels`` in the first unit (True) or only
            in the last unit (False).
        with_bn: add BatchNorm2d after each separable conv and on the shortcut.
    """

    def __init__(self, in_channels, out_channels, reps, strides=1,
                 start_with_relu=True, grow_first=True, with_bn=True):
        super(Block, self).__init__()
        self.with_bn = with_bn

        needs_projection = out_channels != in_channels or strides != 1
        if not needs_projection:
            self.skip = None
        else:
            self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=strides, bias=False)
            if with_bn:
                self.skipbn = nn.BatchNorm2d(out_channels)

        layers = []
        for idx in range(reps):
            if grow_first:
                src = out_channels if idx > 0 else in_channels
                dst = out_channels
            else:
                src = in_channels
                dst = out_channels if idx == reps - 1 else in_channels
            layers.append(nn.ReLU(inplace=True))
            layers.append(SeparableConv2d(src, dst, 3, stride=1, padding=1))
            if with_bn:
                layers.append(nn.BatchNorm2d(dst))

        if start_with_relu:
            # the first activation sees the raw block input, which the shortcut
            # still needs, so it must not run in place
            layers[0] = nn.ReLU(inplace=False)
        else:
            layers = layers[1:]

        if strides != 1:
            layers.append(nn.MaxPool2d(3, strides, 1))
        self.rep = nn.Sequential(*layers)

    def forward(self, inp):
        out = self.rep(inp)
        if self.skip is None:
            shortcut = inp
        else:
            shortcut = self.skip(inp)
            if self.with_bn:
                shortcut = self.skipbn(shortcut)
        out += shortcut
        return out
class GraphReasoning(nn.Module):
    """ Graph Reasoning Module for information aggregation.

    Every spatial position of ``vert_a`` is a vertex connected to one
    ``spatial_ratio[0] x spatial_ratio[0]`` patch of ``vert_b`` and one
    ``spatial_ratio[1] x spatial_ratio[1]`` patch of ``vert_c``. The patches
    are non-overlapping because ``nn.Unfold`` is built with kernel == stride.
    The embeddings of each patch are combined with a softmax weighting over
    the patch pixels and gated by ``1 - sigmoid(conv(vert_a))``. The result is
    concatenated with the ``vert_a`` embedding, passed through a sigmoid and
    projected back to ``va_in`` channels.

    Args:
        va_in, va_out: input / embedding channels for ``vert_a``.
        vb_in, vb_out: input / embedding channels for ``vert_b``.
        vc_in, vc_out: input / embedding channels for ``vert_c``.
        spatial_ratio: pair of patch sizes ``(rb, rc)`` for ``vert_b`` / ``vert_c``.
        drop_rate: dropout probability after re-projection, or ``None`` to disable.

    Note: the gates have ``va_out`` channels and multiply the ``vb_out`` /
    ``vc_out`` aggregates elementwise. Those widths must therefore be equal,
    or broadcastable.
    """

    def __init__(self, va_in, va_out, vb_in, vb_out, vc_in, vc_out, spatial_ratio, drop_rate):
        super(GraphReasoning, self).__init__()
        self.ratio = spatial_ratio
        # vertex-a embedding: two 1x1 convs on the feature map
        self.va_embedding = nn.Sequential(
            nn.Conv2d(va_in, va_out, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(va_out, va_out, 1, bias=False),
        )
        # sigmoid gates computed from vert_a; forward applies (1 - gate)
        self.va_gated_b = nn.Sequential(
            nn.Conv2d(va_in, va_out, 1, bias=False),
            nn.Sigmoid()
        )
        self.va_gated_c = nn.Sequential(
            nn.Conv2d(va_in, va_out, 1, bias=False),
            nn.Sigmoid()
        )
        # per-pixel MLP embeddings (applied on the channel axis) for vert_b / vert_c
        self.vb_embedding = nn.Sequential(
            nn.Linear(vb_in, vb_out, bias=False),
            nn.ReLU(True),
            nn.Linear(vb_out, vb_out, bias=False),
        )
        self.vc_embedding = nn.Sequential(
            nn.Linear(vc_in, vc_out, bias=False),
            nn.ReLU(True),
            nn.Linear(vc_out, vc_out, bias=False),
        )
        # kernel == stride -> non-overlapping patches
        self.unfold_b = nn.Unfold(kernel_size=spatial_ratio[0], stride=spatial_ratio[0])
        self.unfold_c = nn.Unfold(kernel_size=spatial_ratio[1], stride=spatial_ratio[1])
        # one score per (vertex-a, patch pixel) pair; Softmax(dim=1) normalises
        # over the patch pixels (dim 1 in every input fed to these modules)
        self.reweight_ab = nn.Sequential(
            nn.Linear(va_out + vb_out, 1, bias=False),
            nn.ReLU(True),
            nn.Softmax(dim=1)
        )
        self.reweight_ac = nn.Sequential(
            nn.Linear(va_out + vc_out, 1, bias=False),
            nn.ReLU(True),
            nn.Softmax(dim=1)
        )
        self.reproject = nn.Sequential(
            nn.Conv2d(va_out + vb_out + vc_out, va_in, kernel_size=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(va_in, va_in, kernel_size=1, bias=False),
            nn.Dropout(drop_rate) if drop_rate is not None else nn.Identity(),
        )

    @staticmethod
    def _unfold_patches(vert, unfold, ratio):
        """Split ``vert`` (B, C, H', W') into patches laid out as (B, ratio**2, N_patches, C)."""
        patches = unfold(vert).reshape(vert.shape[0], vert.shape[1], ratio * ratio, -1)
        return patches.permute(0, 2, 3, 1)

    @staticmethod
    def _propagate(emb_a, emb_n, reweight, gate):
        """Softmax-weighted aggregation of patch embeddings into every vertex-a position.

        Args:
            emb_a: (B, Ca, N) vertex-a embeddings.
            emb_n: (B, R, N, Cn) patch embeddings, R pixels per patch.
            reweight: module mapping (..., Ca + Cn) -> (..., 1) with softmax over dim 1.
            gate: (B, Ca, N) multiplicative gate.

        Returns:
            (B, Cn, N) gated aggregate (broadcast against ``gate``).
        """
        num_neigh = emb_n.shape[1]
        # repeat each vertex-a embedding across its R patch pixels: (B, R, N, Ca)
        emb_a = emb_a.transpose(1, 2).unsqueeze(1).expand(-1, num_neigh, -1, -1)
        weights = reweight(torch.cat([emb_a, emb_n], dim=-1))  # (B, R, N, 1)
        # weighted sum over the patch pixels; no squeeze, so a batch of 1 keeps its dim
        agg = (emb_n * weights).sum(dim=1)  # (B, N, Cn)
        return agg.transpose(1, 2) * gate

    def forward(self, vert_a, vert_b, vert_c):
        """Aggregate ``vert_b`` / ``vert_c`` patch information into ``vert_a``.

        Args:
            vert_a: (B, va_in, H, W) feature map; each of its H*W positions is a vertex.
            vert_b: (B, vb_in, Hb, Wb) map split into ``ratio[0] x ratio[0]`` patches.
            vert_c: (B, vc_in, Hc, Wc) map split into ``ratio[1] x ratio[1]`` patches.

        The number of patches in ``vert_b`` and in ``vert_c`` must each equal H*W.

        Returns:
            (B, va_in, H, W) output of ``self.reproject``.
        """
        batch, _, height, width = vert_a.shape
        emb_vert_a = self.va_embedding(vert_a).flatten(2)  # (B, va_out, N), N = H*W
        gate_vert_b = (1 - self.va_gated_b(vert_a)).flatten(2)  # (B, va_out, N)
        gate_vert_c = (1 - self.va_gated_c(vert_a)).flatten(2)  # (B, va_out, N)

        emb_vert_b = self.vb_embedding(self._unfold_patches(vert_b, self.unfold_b, self.ratio[0]))
        emb_vert_c = self.vc_embedding(self._unfold_patches(vert_c, self.unfold_c, self.ratio[1]))

        agg_vert_b = self._propagate(emb_vert_a, emb_vert_b, self.reweight_ab, gate_vert_b)
        agg_vert_c = self._propagate(emb_vert_a, emb_vert_c, self.reweight_ac, gate_vert_c)

        agg_vert_abc = torch.sigmoid(torch.cat([agg_vert_b, agg_vert_c, emb_vert_a], dim=1))
        agg_vert_abc = agg_vert_abc.reshape(batch, -1, height, width)
        return self.reproject(agg_vert_abc)
class GuidedAttention(nn.Module):
    """ Reconstruction Guided Attention.

    The absolute difference between two 3-channel tensors ``x`` and
    ``pred_x`` (presumably an input and its reconstruction) is resized to
    the spatial size of ``embedding``. It is then turned into a single-channel
    sigmoid map, which scales ``h(embedding)``; a dropout copy of
    ``embedding`` is added as a residual.

    Args:
        depth: channel count of ``embedding``.
        drop_rate: dropout probability on the residual branch.
    """

    def __init__(self, depth=728, drop_rate=0.2):
        super(GuidedAttention, self).__init__()
        self.depth = depth
        # 3-channel residual -> 1-channel gate in (0, 1)
        gate_layers = [
            nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(3, 1, 1, bias=False),
            nn.Sigmoid(),
        ]
        self.gated = nn.Sequential(*gate_layers)
        # 1x1 conv + BN + ReLU transform of the embedding
        self.h = nn.Sequential(
            nn.Conv2d(depth, depth, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(depth),
            nn.ReLU(True),
        )
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, x, pred_x, embedding):
        target_hw = embedding.shape[-2:]
        diff = (x - pred_x).abs()
        diff = F.interpolate(diff, size=target_hw, mode='bilinear', align_corners=True)
        attention = self.gated(diff)
        transformed = self.h(embedding)
        return attention * transformed + self.dropout(embedding)