Update all files for SegEarth-OV
Browse files- OV/upsamplers.py +251 -0
OV/upsamplers.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SimFeatUp upsamplers for dense feature restoration.
|
| 3 |
+
From SegEarth-OV/OV-2 simfeatup_dev. Used by CLIP-based variants (OV, OV-2).
|
| 4 |
+
"""
|
| 5 |
+
import math
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from featup.adaptive_conv_cuda.adaptive_conv import AdaptiveConv
|
| 14 |
+
except Exception:
|
| 15 |
+
AdaptiveConv = None
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def adaptive_conv_py_simple(input, filters):
    """Pure PyTorch fallback when the featup CUDA AdaptiveConv is unavailable.

    Applies a spatially-varying (per-pixel) convolution: each output location
    (i, j) is the inner product of its own ``f1 x f2`` kernel with the
    corresponding input patch.

    Args:
        input: (b, c, h1, w1) feature map, already padded by the caller so
            that ``h1 == h2 + f1 - 1`` and ``w1 == w2 + f2 - 1``.
        filters: (b, h2, w2, f1, f2) per-pixel square kernels.

    Returns:
        (b, c, h2, w2) filtered features.
    """
    b, c, h1, w1 = input.shape
    fb, h2, w2, f1, f2 = filters.shape
    # Explicit checks instead of letting einsum/Unfold fail with an opaque
    # shape error (the original silently re-bound ``b`` from the second unpack).
    assert b == fb, "input and filters must share the batch dimension"
    assert f1 == f2, "per-pixel kernels must be square"
    # Unfold below uses stride 1 and no padding, so the caller must pre-pad.
    assert h1 == h2 + f1 - 1 and w1 == w2 + f2 - 1, "input must be pre-padded to h2+f1-1, w2+f2-1"
    t_filters = filters.reshape(b, h2, w2, f1 * f2)
    patches = torch.nn.Unfold(f1)(input).view((b, c, f1 * f2, h2, w2))
    # Sum each patch against its own kernel over the f1*f2 dimension.
    return torch.einsum("bhwf,bcfhw->bchw", t_filters, patches)
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _meshgrid(device, diameter):
|
| 29 |
+
dist_range = torch.linspace(-1, 1, diameter, device=device)
|
| 30 |
+
x, y = torch.meshgrid(dist_range, dist_range, indexing="ij")
|
| 31 |
+
return torch.cat([x.unsqueeze(0), y.unsqueeze(0)], dim=0)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Bilinear(torch.nn.Module):
    """Baseline upsampler: plain bilinear interpolation to the guidance resolution."""

    def forward(self, source, guidance):
        # Only the guidance's spatial size matters; its content is ignored.
        return F.interpolate(source, size=guidance.shape[-2:], mode="bilinear")
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class LayeredResizeConv(torch.nn.Module):
    """16x upsampler: four stages of (2x bilinear resize -> conv over [features, guidance]).

    Each stage is residual: the resized features are added back to the conv
    output. The last stage uses no activation.
    """

    def __init__(self, dim, kernel_size=1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # +3 input channels for the RGB guidance concatenated at each stage.
        self.conv1 = nn.Conv2d(dim + 3, dim, kernel_size, padding="same")
        self.conv2 = nn.Conv2d(dim + 3, dim, kernel_size, padding="same")
        self.conv3 = nn.Conv2d(dim + 3, dim, kernel_size, padding="same")
        self.conv4 = nn.Conv2d(dim + 3, dim, kernel_size, padding="same")

    def apply_conv(self, source, guidance, conv, activation):
        """One 2x stage: resize features, match guidance size, conv, residual add."""
        upsized = F.interpolate(source, scale_factor=2, mode="bilinear")
        guide = F.interpolate(guidance, upsized.shape[-2:], mode="bilinear")
        residual = activation(conv(torch.cat([upsized, guide], dim=1)))
        return upsized + residual

    def forward(self, source, guidance):
        stages = (
            (self.conv1, F.relu),
            (self.conv2, F.relu),
            (self.conv3, F.relu),
            (self.conv4, lambda t: t),  # identity on the final stage
        )
        feats = source
        for conv, act in stages:
            feats = self.apply_conv(feats, guidance, conv, act)
        return feats
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class SimpleImplicitFeaturizer(torch.nn.Module):
    """Fourier positional featurizer.

    Builds a 2-channel coordinate grid over the input's spatial extent,
    multiplies it by ``n_freqs`` geometrically spaced frequencies, and returns
    sin and cos of every band concatenated with the original input
    (output channels = 2 * n_freqs * 2 + C).
    """

    def __init__(self, n_freqs=20):
        super().__init__()
        self.n_freqs = n_freqs
        self.dim_multiplier = 2  # two coordinate channels (row, col)

    def forward(self, x):
        b, c, h, w = x.shape
        dtype = x.dtype
        rows = torch.linspace(-1, 1, h, device=x.device, dtype=dtype)
        cols = torch.linspace(-1, 1, w, device=x.device, dtype=dtype)
        coords = torch.stack(torch.meshgrid(rows, cols, indexing="ij")).unsqueeze(0)
        coords = coords.broadcast_to((b, coords.shape[1], h, w))
        # Geometric frequency ladder from e^-2 to e^10.
        freqs = torch.exp(torch.linspace(-2, 10, self.n_freqs, device=x.device)).to(dtype)
        freqs = freqs.reshape(1, self.n_freqs, 1, 1, 1)
        banded = (coords.unsqueeze(1) * freqs).reshape(b, self.n_freqs * self.dim_multiplier, h, w)
        return torch.cat([torch.sin(banded), torch.cos(banded), x], dim=1)
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class IFA(torch.nn.Module):
    """Implicit Feature Alignment upsampler.

    Repeatedly doubles the spatial resolution of ``source`` with a
    coordinate-conditioned MLP until it reaches the guidance resolution,
    then bilinearly resizes to absorb any remainder.
    """

    def __init__(self, feat_dim, num_scales=20):
        super().__init__()
        self.feat_dim = feat_dim
        # Fourier featurizer applied to the 2-channel coordinate offsets.
        self.sin_feats = SimpleImplicitFeaturizer()
        # Input channels: feat_dim features + num_scales * 4 sin/cos bands
        # (2 coord channels x 2 for sin&cos) + the 2 raw coordinate channels
        # passed through by the featurizer.
        # NOTE(review): num_scales * 4 matches SimpleImplicitFeaturizer output
        # only when num_scales equals its n_freqs (both default 20) — confirm.
        self.mlp = nn.Sequential(
            nn.Conv2d(feat_dim + (num_scales * 4) + 2, feat_dim, 1),
            nn.BatchNorm2d(feat_dim),
            nn.LeakyReLU(),
            nn.Conv2d(feat_dim, feat_dim, 1),
        )

    def _upsample_2x(self, source):
        """Double H and W via nearest upsampling plus a coordinate-offset MLP.

        NOTE(review): the coordinate grids below use ``h`` for both axes, so
        this implicitly assumes square inputs (h == w) — confirm upstream.
        """
        b, c, h, w = source.shape
        dtype = source.dtype
        up_source = F.interpolate(source, (h * 2, w * 2), mode="nearest")
        # Coordinates of the low-res and high-res grids on the same [0, h] span.
        lr_cord = torch.linspace(0, h, steps=h, device=source.device, dtype=dtype)
        hr_cord = torch.linspace(0, h, steps=2 * h, device=source.device, dtype=dtype)
        lr_coords = torch.stack(torch.meshgrid(lr_cord, lr_cord, indexing="ij")).unsqueeze(0)
        hr_coords = torch.stack(torch.meshgrid(hr_cord, hr_cord, indexing="ij")).unsqueeze(0)
        # Nearest-upsample the LR coordinates so each HR pixel sees the offset
        # to the LR sample its value was copied from.
        up_lr_coords = F.interpolate(lr_coords, (h * 2, w * 2), mode="nearest")
        coord_diff = up_lr_coords - hr_coords
        coord_diff_feats = self.sin_feats(coord_diff).to(dtype)
        bcast_coord_feats = coord_diff_feats.broadcast_to((b, coord_diff_feats.shape[1], h * 2, w * 2))
        return self.mlp(torch.cat([up_source, bcast_coord_feats], dim=1))

    def forward(self, source, guidance):
        """Upsample ``source`` to the spatial size of ``guidance``."""
        _, _, gh, gw = guidance.shape
        x = source
        # Double until we meet or exceed the guidance resolution ...
        while x.shape[2] < gh or x.shape[3] < gw:
            x = self._upsample_2x(x)
        # ... then correct any overshoot or non-power-of-two remainder.
        if x.shape[2] != gh or x.shape[3] != gw:
            x = F.interpolate(x, (gh, gw), mode="bilinear")
        return x
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class JBULearnedRange(torch.nn.Module):
    """Joint bilateral upsampler with a learned range kernel (FeatUp/SimFeatUp style).

    Upsamples ``source`` to the guidance resolution by applying, at every
    output pixel, a ``diameter x diameter`` kernel formed as the product of a
    learned range (guidance-similarity) term and a Gaussian spatial term,
    plus a small learned correction.
    """

    def __init__(self, guidance_dim, feat_dim, key_dim, scale=2, radius=3):
        super().__init__()
        self.scale = scale  # NOTE(review): stored but not used below — confirm callers rely on it
        self.radius = radius
        self.diameter = self.radius * 2 + 1  # kernel window size per pixel
        self.guidance_dim = guidance_dim
        self.key_dim = key_dim
        self.feat_dim = feat_dim
        # Log-temperature for the range-kernel softmax (exp'd in get_range_kernel).
        self.range_temp = nn.Parameter(torch.tensor(0.0))
        # Projects guidance pixels into a key space for similarity comparison.
        self.range_proj = nn.Sequential(
            nn.Conv2d(guidance_dim, key_dim, 1, 1),
            nn.GELU(),
            nn.Dropout2d(0.1),
            nn.Conv2d(key_dim, key_dim, 1, 1),
        )
        # Predicts a per-pixel additive correction to the combined kernel.
        self.fixup_proj = nn.Sequential(
            nn.Conv2d(guidance_dim + self.diameter ** 2, self.diameter ** 2, 1, 1),
            nn.GELU(),
            nn.Dropout2d(0.1),
            nn.Conv2d(self.diameter ** 2, self.diameter ** 2, 1, 1),
        )
        # Learnable bandwidth of the Gaussian spatial kernel.
        self.sigma_spatial = nn.Parameter(torch.tensor(1.0))

    def get_range_kernel(self, x):
        """Return a (B, diameter^2, H, W) softmax kernel from guidance similarity."""
        GB, GC, GH, GW = x.shape
        proj_x = self.range_proj(x)
        # Reflect-pad so every pixel has a full diameter x diameter neighborhood.
        proj_x_padded = F.pad(proj_x, pad=[self.radius] * 4, mode="reflect")
        queries = (
            torch.nn.Unfold(self.diameter)(proj_x_padded)
            .reshape((GB, self.key_dim, self.diameter * self.diameter, GH, GW))
            .permute(0, 1, 3, 4, 2)
        )
        # Clamps keep the softmax temperature in a numerically safe range.
        pos_temp = self.range_temp.exp().clamp_min(1e-4).clamp_max(1e4)
        # Dot product of each center pixel with its neighborhood, softmaxed
        # over the neighborhood (p) dimension.
        return F.softmax(pos_temp * torch.einsum("bchwp,bchw->bphw", queries, proj_x), dim=1)

    def get_spatial_kernel(self, device):
        """Return a (1, diameter^2, 1, 1) Gaussian kernel over patch offsets."""
        patch = _meshgrid(device, self.diameter)
        return torch.exp(-patch.square().sum(0) / (2 * self.sigma_spatial ** 2)).reshape(
            1, self.diameter * self.diameter, 1, 1
        )

    def forward(self, source, guidance):
        """Upsample ``source`` (B, C, h, w) to the guidance resolution (B, C, GH, GW)."""
        GB, GC, GH, GW = guidance.shape
        SB, SC, SH, SQ = source.shape
        assert SB == GB
        dtype = source.dtype
        guidance = guidance.to(dtype)
        spatial_kernel = self.get_spatial_kernel(source.device).to(dtype)
        range_kernel = self.get_range_kernel(guidance).to(dtype)
        combined_kernel = (range_kernel * spatial_kernel).to(dtype)
        # Normalize so each per-pixel kernel sums to 1 (clamp avoids div-by-zero).
        combined_kernel /= combined_kernel.sum(1, keepdim=True).clamp(1e-7)
        # Small learned correction conditioned on the kernel and the guidance.
        combined_kernel += 0.1 * self.fixup_proj(torch.cat([combined_kernel, guidance], dim=1))
        combined_kernel = combined_kernel.permute(0, 2, 3, 1).reshape(
            GB, GH, GW, self.diameter, self.diameter
        )
        # Bicubic-upsample the source, pad to cover the kernel footprint, then
        # apply the spatially-varying kernel (CUDA op when available).
        hr_source = F.interpolate(source, size=(GH, GW), mode="bicubic", align_corners=False)
        hr_source_padded = F.pad(hr_source, pad=[self.radius] * 4, mode="reflect")
        combined_kernel = combined_kernel.to(hr_source_padded.dtype)
        if AdaptiveConv is not None:
            result = AdaptiveConv.apply(hr_source_padded, combined_kernel)
        else:
            result = adaptive_conv_py_simple(hr_source_padded, combined_kernel)
        return result
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class JBUStack(torch.nn.Module):
    """Four chained learned-JBU stages (16x total) plus a residual fixup projection."""

    def __init__(self, feat_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Independent 2x stages; guidance is RGB (3 channels), keys are 32-d.
        self.up1 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.up2 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.up3 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.up4 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.fixup_proj = nn.Sequential(
            nn.Dropout2d(0.2),
            nn.Conv2d(feat_dim, feat_dim, kernel_size=1),
        )

    def upsample(self, source, guidance, up):
        """Run one 2x stage with the guidance pooled to twice the source size."""
        _, _, h, w = source.shape
        pooled_guidance = F.adaptive_avg_pool2d(guidance, (h * 2, w * 2))
        return up(source, pooled_guidance)

    def forward(self, source, guidance):
        feats = source
        for stage in (self.up1, self.up2, self.up3, self.up4):
            feats = self.upsample(feats, guidance, stage)
        # Lightly weighted residual correction on the final features.
        return self.fixup_proj(feats) * 0.1 + feats
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class JBUOne(torch.nn.Module):
    """16x upsampler that reuses ONE learned-JBU stage four times (shared weights),
    followed by a residual fixup projection."""

    def __init__(self, feat_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Single shared stage (wider radius than JBUStack's stages).
        self.up = JBULearnedRange(3, feat_dim, 32, radius=5)
        self.fixup_proj = nn.Sequential(
            nn.Dropout2d(0.2),
            nn.Conv2d(feat_dim, feat_dim, kernel_size=1),
        )

    def upsample(self, source, guidance, up):
        """Run one 2x stage with the guidance pooled to twice the source size."""
        _, _, h, w = source.shape
        pooled_guidance = F.adaptive_avg_pool2d(guidance, (h * 2, w * 2))
        return up(source, pooled_guidance)

    def forward(self, source, guidance):
        feats = source
        for _ in range(4):  # 2^4 = 16x total, same module each time
            feats = self.upsample(feats, guidance, self.up)
        # Lightly weighted residual correction on the final features.
        return self.fixup_proj(feats) * 0.1 + feats
| 230 |
+
|
| 231 |
+
|
| 232 |
+
# Relative paths to pretrained SimFeatUp upsampler checkpoints, keyed by
# upsampler variant name.
# NOTE(review): "jbu_stack_maskclip" has no corresponding branch in
# get_upsampler — presumably selected elsewhere by the loading code; confirm.
FEATUP_CHECKPOINTS = {
    "jbu_one": "simfeatup/xclip_jbu_one_million_aid.ckpt",
    "jbu_stack": "simfeatup/clip_jbu_stack_cocostuff.ckpt",
    "jbu_stack_maskclip": "simfeatup/maskclip_jbu_stack_cocostuff.ckpt",
}
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def get_upsampler(name: str, feat_dim: int):
    """Build an upsampler module by name.

    Args:
        name: one of "bilinear", "jbu_one", "jbu_stack", "resize_conv", "ifa".
        feat_dim: channel count of the features to be upsampled.

    Returns:
        A torch.nn.Module mapping (source, guidance) -> upsampled features.

    Raises:
        ValueError: if ``name`` is not a known upsampler.
    """
    # Lambdas defer construction so only the requested module is built.
    builders = {
        "bilinear": lambda: Bilinear(),
        "jbu_one": lambda: JBUOne(feat_dim),
        "jbu_stack": lambda: JBUStack(feat_dim),
        "resize_conv": lambda: LayeredResizeConv(feat_dim, 1),
        "ifa": lambda: IFA(feat_dim),
    }
    if name not in builders:
        raise ValueError(f"Unknown upsampler: {name}. Use: bilinear, jbu_one, jbu_stack, resize_conv, ifa")
    return builders[name]()