Add novel heads: Optimal Transport (det), Info Bottleneck (seg), Harmonic Depth (depth)
Browse files- heads/__init__.py +2 -0
- heads/info_bottleneck/__init__.py +1 -0
- heads/info_bottleneck/head.py +24 -0
heads/__init__.py
CHANGED
|
@@ -8,6 +8,7 @@ from .wavelet.head import Wavelet
|
|
| 8 |
from .patch_attention.head import PatchAttention
|
| 9 |
from .graph_crf.head import GraphCRF
|
| 10 |
from .hypercolumn_linear.head import HypercolumnLinear
|
|
|
|
| 11 |
|
| 12 |
REGISTRY = {
|
| 13 |
"linear_probe": LinearProbe,
|
|
@@ -18,6 +19,7 @@ REGISTRY = {
|
|
| 18 |
"patch_attention": PatchAttention,
|
| 19 |
"graph_crf": GraphCRF,
|
| 20 |
"hypercolumn_linear": HypercolumnLinear,
|
|
|
|
| 21 |
}
|
| 22 |
|
| 23 |
ALL_NAMES = list(REGISTRY.keys())
|
|
|
|
| 8 |
from .patch_attention.head import PatchAttention
|
| 9 |
from .graph_crf.head import GraphCRF
|
| 10 |
from .hypercolumn_linear.head import HypercolumnLinear
|
| 11 |
+
from .info_bottleneck.head import InfoBottleneck
|
| 12 |
|
| 13 |
REGISTRY = {
|
| 14 |
"linear_probe": LinearProbe,
|
|
|
|
| 19 |
"patch_attention": PatchAttention,
|
| 20 |
"graph_crf": GraphCRF,
|
| 21 |
"hypercolumn_linear": HypercolumnLinear,
|
| 22 |
+
"info_bottleneck": InfoBottleneck,
|
| 23 |
}
|
| 24 |
|
| 25 |
ALL_NAMES = list(REGISTRY.keys())
|
heads/info_bottleneck/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .head import InfoBottleneck
|
heads/info_bottleneck/head.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Information Bottleneck: project to d dimensions, classify from the compressed representation.
|
| 2 |
+
|
| 3 |
+
The bottleneck dimension d is the minimum sufficient statistic for segmentation.
|
| 4 |
+
If d=8 works for 150-class ADE20K, the frozen features have at most 8 independent
|
| 5 |
+
directions relevant to semantic segmentation. That is a statement about the
|
| 6 |
+
backbone's feature geometry.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class InfoBottleneck(nn.Module):
    """Segmentation head that forces features through a narrow linear bottleneck.

    A 1x1 convolution projects the backbone features down to ``bottleneck_dim``
    channels, and a second 1x1 convolution produces per-class logits from that
    compressed representation. The bottleneck width probes how many independent
    feature directions the frozen backbone devotes to segmentation.
    """

    name = "info_bottleneck"
    needs_intermediates = False

    def __init__(self, feat_dim=768, num_classes=150, bottleneck_dim=8):
        """Build the compress/classify pair.

        Args:
            feat_dim: channel count of the incoming feature map.
            num_classes: number of output logit channels.
            bottleneck_dim: width of the compressed representation.
        """
        super().__init__()
        self.bottleneck_dim = bottleneck_dim
        # Bias-free so the compression is a purely linear projection.
        self.compress = nn.Conv2d(feat_dim, bottleneck_dim, 1, bias=False)
        self.classify = nn.Conv2d(bottleneck_dim, num_classes, 1)

    def forward(self, spatial, inter=None):
        """Return per-pixel class logits; `inter` is accepted but unused."""
        compressed = self.compress(spatial)
        logits = self.classify(compressed)
        return logits