phanerozoic committed on
Commit
d2251ba
·
verified ·
1 Parent(s): 0e8110e

Add novel heads: Optimal Transport (det), Info Bottleneck (seg), Harmonic Depth (depth)

Browse files
heads/__init__.py CHANGED
@@ -8,6 +8,7 @@ from .wavelet.head import Wavelet
8
  from .patch_attention.head import PatchAttention
9
  from .graph_crf.head import GraphCRF
10
  from .hypercolumn_linear.head import HypercolumnLinear
 
11
 
12
  REGISTRY = {
13
  "linear_probe": LinearProbe,
@@ -18,6 +19,7 @@ REGISTRY = {
18
  "patch_attention": PatchAttention,
19
  "graph_crf": GraphCRF,
20
  "hypercolumn_linear": HypercolumnLinear,
 
21
  }
22
 
23
  ALL_NAMES = list(REGISTRY.keys())
 
8
  from .patch_attention.head import PatchAttention
9
  from .graph_crf.head import GraphCRF
10
  from .hypercolumn_linear.head import HypercolumnLinear
11
+ from .info_bottleneck.head import InfoBottleneck
12
 
13
  REGISTRY = {
14
  "linear_probe": LinearProbe,
 
19
  "patch_attention": PatchAttention,
20
  "graph_crf": GraphCRF,
21
  "hypercolumn_linear": HypercolumnLinear,
22
+ "info_bottleneck": InfoBottleneck,
23
  }
24
 
25
  ALL_NAMES = list(REGISTRY.keys())
heads/info_bottleneck/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import InfoBottleneck
heads/info_bottleneck/head.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Information Bottleneck: project to d dimensions, classify from the compressed representation.
2
+
3
+ The smallest bottleneck dimension d that preserves accuracy gives the size of a minimal sufficient statistic for segmentation.
4
+ If d=8 works for 150-class ADE20K, the frozen features have at most 8 independent
5
+ directions relevant to semantic segmentation. That is a statement about the
6
+ backbone's feature geometry.
7
+ """
8
+
9
+ import torch.nn as nn
10
+
11
+
12
class InfoBottleneck(nn.Module):
    """Segmentation head that classifies from a low-dimensional projection.

    The input feature map is projected per location to ``bottleneck_dim``
    channels by a bias-free 1x1 convolution, then mapped to ``num_classes``
    logits by a second 1x1 convolution. No nonlinearity sits between the two
    layers, so the head as a whole is a linear per-location classifier whose
    weight matrix has rank at most ``bottleneck_dim``.

    Args:
        feat_dim: Number of channels in the incoming feature map.
        num_classes: Number of output channels (one logit map per class).
        bottleneck_dim: Number of channels in the compressed representation.

    Raises:
        ValueError: If any of the three sizes is smaller than 1.
    """

    # Key under which this head is registered.
    name = "info_bottleneck"
    # This head only consumes the final feature map; ``inter`` is ignored.
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=150, bottleneck_dim=8):
        super().__init__()
        # Fail early with a clear message instead of an opaque Conv2d error
        # (or a silently degenerate zero-channel layer).
        sizes = (
            ("feat_dim", feat_dim),
            ("num_classes", num_classes),
            ("bottleneck_dim", bottleneck_dim),
        )
        for label, value in sizes:
            if value < 1:
                raise ValueError(f"{label} must be a positive integer, got {value!r}")

        self.bottleneck_dim = bottleneck_dim
        # Pure linear projection: no bias, so the bottleneck is a subspace
        # of the input features rather than an affine shift of one.
        self.compress = nn.Conv2d(feat_dim, bottleneck_dim, 1, bias=False)
        self.classify = nn.Conv2d(bottleneck_dim, num_classes, 1)

    def forward(self, spatial, inter=None):
        """Return per-location class logits for ``spatial``.

        Args:
            spatial: Feature map with ``feat_dim`` channels, in a layout
                accepted by ``nn.Conv2d``.
            inter: Unused; accepted so the call signature stays compatible
                with callers that also pass intermediates.

        Returns:
            Logits with ``num_classes`` channels and the same spatial size
            as ``spatial`` (both convolutions are 1x1).
        """
        z = self.compress(spatial)
        return self.classify(z)