upload
Browse files
alignedthreeattn_model.py
CHANGED
|
@@ -54,7 +54,12 @@ class ThreeAttnNodes(nn.Module):
|
|
| 54 |
x = F.interpolate(x, size=(588, 588), mode="bilinear")
|
| 55 |
feat2 = self.backbone2(x)
|
| 56 |
feats = torch.cat([feat1, feat2, feat3], dim=1)
|
| 57 |
-
out = torch.einsum("b l p i, l o i -> b l p o", feats, self.align_weights)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
out = rearrange(out[:, :, 1:], "b l (h w) o -> b l h w o", h=42, w=42)
|
| 59 |
return out
|
| 60 |
|
|
|
|
| 54 |
x = F.interpolate(x, size=(588, 588), mode="bilinear")
|
| 55 |
feat2 = self.backbone2(x)
|
| 56 |
feats = torch.cat([feat1, feat2, feat3], dim=1)
|
| 57 |
+
# out = torch.einsum("b l p i, l o i -> b l p o", feats, self.align_weights)
|
| 58 |
+
outs = []
|
| 59 |
+
for i_layer in range(36):
|
| 60 |
+
out = torch.einsum("b p i, o i -> b p o", feats[:, i_layer], self.align_weights[i_layer])
|
| 61 |
+
outs.append(out)
|
| 62 |
+
out = torch.stack(outs, dim=1)
|
| 63 |
out = rearrange(out[:, :, 1:], "b l (h w) o -> b l h w o", h=42, w=42)
|
| 64 |
return out
|
| 65 |
|