Update modeling_vit.py
Browse files- modeling_vit.py +3 -2
modeling_vit.py
CHANGED
|
@@ -924,10 +924,11 @@ class ViTForSemanticSegmentation(ViTPreTrainedModel):
|
|
| 924 |
out = self.decoder_norm(out)
|
| 925 |
B, _, C = out.shape
|
| 926 |
# out = out.reshape(B, self.hw_shape, self.hw_shape, C).permute(0, 3, 1, 2).contiguous()
|
| 927 |
-
out = out.reshape(
|
| 928 |
out = self.decoder_mlp(out)
|
|
|
|
| 929 |
out = nn.functional.interpolate(out, scale_factor=4, mode='bilinear', align_corners=False)
|
| 930 |
-
|
| 931 |
logits = self.decoder_classifier(out)
|
| 932 |
logits = logits.permute(0, 3, 1, 2).contiguous()
|
| 933 |
|
|
|
|
| 924 |
out = self.decoder_norm(out)
|
| 925 |
B, _, C = out.shape
|
| 926 |
# out = out.reshape(B, self.hw_shape, self.hw_shape, C).permute(0, 3, 1, 2).contiguous()
|
| 927 |
+
out = out.reshape(B, self.hw_shape, self.hw_shape, C)
|
| 928 |
out = self.decoder_mlp(out)
|
| 929 |
+
out = out.permute(0, 3, 1, 2).contiguous()
|
| 930 |
out = nn.functional.interpolate(out, scale_factor=4, mode='bilinear', align_corners=False)
|
| 931 |
+
out = out.permute(0, 2, 3, 1).contiguous()
|
| 932 |
logits = self.decoder_classifier(out)
|
| 933 |
logits = logits.permute(0, 3, 1, 2).contiguous()
|
| 934 |
|