BiliSakura committed on
Commit
6ee6ac8
·
verified ·
1 Parent(s): 001049a

Update all files for SegEarth-OV

Browse files
Files changed (1) hide show
  1. OV/upsamplers.py +251 -0
OV/upsamplers.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SimFeatUp upsamplers for dense feature restoration.
3
+ From SegEarth-OV/OV-2 simfeatup_dev. Used by CLIP-based variants (OV, OV-2).
4
+ """
5
+ import math
6
+ from typing import Optional
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ try:
13
+ from featup.adaptive_conv_cuda.adaptive_conv import AdaptiveConv
14
+ except Exception:
15
+ AdaptiveConv = None
16
+
17
+
18
def adaptive_conv_py_simple(input, filters):
    """Pure PyTorch fallback when featup CUDA is unavailable.

    Applies a spatially varying convolution: each output pixel (i, j) is the
    dot product of its own f1 x f2 filter with the matching input patch
    (valid padding, so filters' (h2, w2) equals input's (h1-f1+1, w1-f2+1)).
    """
    b, c, _, _ = input.shape
    fb, h2, w2, f1, f2 = filters.shape
    assert f1 == f2
    # Flatten each per-pixel kernel into a vector of f1*f2 taps.
    flat_filters = filters.reshape(fb, h2, w2, f1 * f2)
    # Extract every f1 x f2 patch, keeping channels separate.
    patches = F.unfold(input, f1).view(b, c, f1 * f2, h2, w2)
    return torch.einsum("bhwf,bcfhw->bchw", flat_filters, patches)
26
+
27
+
28
+ def _meshgrid(device, diameter):
29
+ dist_range = torch.linspace(-1, 1, diameter, device=device)
30
+ x, y = torch.meshgrid(dist_range, dist_range, indexing="ij")
31
+ return torch.cat([x.unsqueeze(0), y.unsqueeze(0)], dim=0)
32
+
33
+
34
class Bilinear(torch.nn.Module):
    """Baseline upsampler: plain bilinear interpolation to the guidance size."""

    def forward(self, source, guidance):
        target_h, target_w = guidance.shape[-2:]
        return F.interpolate(source, (target_h, target_w), mode="bilinear")
38
+
39
+
40
class LayeredResizeConv(torch.nn.Module):
    """Upsampler that doubles resolution four times (16x total).

    Each stage bilinearly upsamples the features 2x, concatenates a resized
    copy of the RGB guidance, and adds a conv residual on top. The first three
    stages use ReLU; the last stage is left linear.
    """

    def __init__(self, dim, kernel_size=1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Each conv sees the feature channels plus 3 guidance (RGB) channels.
        self.conv1 = nn.Conv2d(dim + 3, dim, kernel_size, padding="same")
        self.conv2 = nn.Conv2d(dim + 3, dim, kernel_size, padding="same")
        self.conv3 = nn.Conv2d(dim + 3, dim, kernel_size, padding="same")
        self.conv4 = nn.Conv2d(dim + 3, dim, kernel_size, padding="same")

    def apply_conv(self, source, guidance, conv, activation):
        """One 2x stage: upsample, condition on guidance, add a residual."""
        big_source = F.interpolate(source, scale_factor=2, mode="bilinear")
        _, _, h, w = big_source.shape
        small_guidance = F.interpolate(guidance, (h, w), mode="bilinear")
        residual = activation(conv(torch.cat([big_source, small_guidance], dim=1)))
        return big_source + residual

    def forward(self, source, guidance):
        feats = source
        for conv, act in (
            (self.conv1, F.relu),
            (self.conv2, F.relu),
            (self.conv3, F.relu),
            (self.conv4, lambda t: t),  # final stage: no nonlinearity
        ):
            feats = self.apply_conv(feats, guidance, conv, act)
        return feats
61
+
62
+
63
class SimpleImplicitFeaturizer(torch.nn.Module):
    """Fourier positional encoding.

    Builds a 2-channel coordinate grid over the input's spatial extent,
    multiplies it by n_freqs exponentially spaced frequencies, and returns
    sin and cos of those phases concatenated with the original input.
    Output channels: 2 * (n_freqs * 2) + input channels.
    """

    def __init__(self, n_freqs=20):
        super().__init__()
        self.n_freqs = n_freqs
        self.dim_multiplier = 2  # two coordinate channels

    def forward(self, x):
        b, c, h, w = x.shape
        dtype = x.dtype
        rows = torch.linspace(-1, 1, h, device=x.device, dtype=dtype)
        cols = torch.linspace(-1, 1, w, device=x.device, dtype=dtype)
        coords = torch.stack(torch.meshgrid(rows, cols, indexing="ij")).unsqueeze(0)
        coords = coords.broadcast_to((b, coords.shape[1], h, w))
        # Frequencies span exp(-2) .. exp(10); broadcast over (coord, h, w).
        freqs = torch.exp(torch.linspace(-2, 10, self.n_freqs, device=x.device))
        freqs = freqs.to(dtype).reshape(1, self.n_freqs, 1, 1, 1)
        phases = (coords.unsqueeze(1) * freqs).reshape(
            b, self.n_freqs * self.dim_multiplier, h, w
        )
        return torch.cat([torch.sin(phases), torch.cos(phases), x], dim=1)
81
+
82
+
83
class IFA(torch.nn.Module):
    """Implicit feature alignment upsampler.

    Repeatedly doubles the feature resolution (nearest-neighbour upsample plus
    an MLP conditioned on a Fourier encoding of the low-res/high-res coordinate
    offsets) until the guidance size is reached, then snaps exactly with one
    bilinear resize if needed.
    """

    def __init__(self, feat_dim, num_scales=20):
        super().__init__()
        self.feat_dim = feat_dim
        self.sin_feats = SimpleImplicitFeaturizer()
        # MLP input channels: features + sin/cos encodings of the 2 offset
        # channels at num_scales frequencies (num_scales * 4) + the 2 raw
        # offset channels passed through by SimpleImplicitFeaturizer.
        self.mlp = nn.Sequential(
            nn.Conv2d(feat_dim + (num_scales * 4) + 2, feat_dim, 1),
            nn.BatchNorm2d(feat_dim),
            nn.LeakyReLU(),
            nn.Conv2d(feat_dim, feat_dim, 1),
        )

    def _upsample_2x(self, source):
        """Double the spatial resolution of *source* once.

        Fix: the coordinate grids are now built per axis (h and w separately).
        The previous version used the height for both axes, so non-square
        feature maps crashed on the coordinate-difference subtraction; square
        inputs behave identically to before.
        """
        b, c, h, w = source.shape
        dtype = source.dtype
        device = source.device
        up_source = F.interpolate(source, (h * 2, w * 2), mode="nearest")
        lr_h = torch.linspace(0, h, steps=h, device=device, dtype=dtype)
        lr_w = torch.linspace(0, w, steps=w, device=device, dtype=dtype)
        hr_h = torch.linspace(0, h, steps=2 * h, device=device, dtype=dtype)
        hr_w = torch.linspace(0, w, steps=2 * w, device=device, dtype=dtype)
        lr_coords = torch.stack(torch.meshgrid(lr_h, lr_w, indexing="ij")).unsqueeze(0)
        hr_coords = torch.stack(torch.meshgrid(hr_h, hr_w, indexing="ij")).unsqueeze(0)
        up_lr_coords = F.interpolate(lr_coords, (h * 2, w * 2), mode="nearest")
        # Offset between each high-res pixel and its nearest low-res source.
        coord_diff = up_lr_coords - hr_coords
        coord_diff_feats = self.sin_feats(coord_diff).to(dtype)
        bcast_coord_feats = coord_diff_feats.broadcast_to(
            (b, coord_diff_feats.shape[1], h * 2, w * 2)
        )
        return self.mlp(torch.cat([up_source, bcast_coord_feats], dim=1))

    def forward(self, source, guidance):
        _, _, gh, gw = guidance.shape
        x = source
        # Double until at least guidance-sized, then resize exactly if needed.
        while x.shape[2] < gh or x.shape[3] < gw:
            x = self._upsample_2x(x)
        if x.shape[2] != gh or x.shape[3] != gw:
            x = F.interpolate(x, (gh, gw), mode="bilinear")
        return x
117
+
118
+
119
class JBULearnedRange(torch.nn.Module):
    """Joint bilateral upsampling with a learned range kernel.

    Combines a Gaussian spatial kernel (learned sigma) with a
    guidance-conditioned, learned range kernel, adds a small learned "fixup"
    correction, and applies the resulting per-pixel filters to a bicubically
    upsampled source via adaptive convolution.
    """

    def __init__(self, guidance_dim, feat_dim, key_dim, scale=2, radius=3):
        super().__init__()
        self.scale = scale
        self.radius = radius
        self.diameter = self.radius * 2 + 1  # side length of the filter window
        self.guidance_dim = guidance_dim
        self.key_dim = key_dim
        self.feat_dim = feat_dim
        # Log-temperature for the range-kernel softmax (exponentiated in use).
        self.range_proj = nn.Sequential(
            nn.Conv2d(guidance_dim, key_dim, 1, 1),
            nn.GELU(),
            nn.Dropout2d(0.1),
            nn.Conv2d(key_dim, key_dim, 1, 1),
        )
        self.range_temp = nn.Parameter(torch.tensor(0.0))
        # Small correction head over [kernel weights, guidance] -> kernel delta.
        self.fixup_proj = nn.Sequential(
            nn.Conv2d(guidance_dim + self.diameter ** 2, self.diameter ** 2, 1, 1),
            nn.GELU(),
            nn.Dropout2d(0.1),
            nn.Conv2d(self.diameter ** 2, self.diameter ** 2, 1, 1),
        )
        self.sigma_spatial = nn.Parameter(torch.tensor(1.0))

    def get_range_kernel(self, x):
        """Per-pixel softmax affinity between each pixel's key and the keys in
        its (diameter x diameter) neighbourhood. Returns (B, d*d, H, W)."""
        GB, GC, GH, GW = x.shape
        proj_x = self.range_proj(x)
        # Reflect-pad so every pixel has a full neighbourhood of keys.
        proj_x_padded = F.pad(proj_x, pad=[self.radius] * 4, mode="reflect")
        queries = (
            torch.nn.Unfold(self.diameter)(proj_x_padded)
            .reshape((GB, self.key_dim, self.diameter * self.diameter, GH, GW))
            .permute(0, 1, 3, 4, 2)
        )
        # Clamp the exponentiated temperature to keep the softmax stable.
        pos_temp = self.range_temp.exp().clamp_min(1e-4).clamp_max(1e4)
        return F.softmax(pos_temp * torch.einsum("bchwp,bchw->bphw", queries, proj_x), dim=1)

    def get_spatial_kernel(self, device):
        """Gaussian over window offsets, shaped (1, d*d, 1, 1) for broadcasting."""
        patch = _meshgrid(device, self.diameter)
        return torch.exp(-patch.square().sum(0) / (2 * self.sigma_spatial ** 2)).reshape(
            1, self.diameter * self.diameter, 1, 1
        )

    def forward(self, source, guidance):
        """Upsample `source` to the spatial size of `guidance`.

        Args:
            source: (B, C, h, w) low-resolution features.
            guidance: (B, guidance_dim, GH, GW) high-resolution guidance image.

        Returns:
            (B, C, GH, GW) filtered, upsampled features.
        """
        GB, GC, GH, GW = guidance.shape
        SB, SC, SH, SQ = source.shape
        assert SB == GB
        # Carry out kernel math in the source dtype (e.g. fp16 features).
        dtype = source.dtype
        guidance = guidance.to(dtype)
        spatial_kernel = self.get_spatial_kernel(source.device).to(dtype)
        range_kernel = self.get_range_kernel(guidance).to(dtype)
        combined_kernel = (range_kernel * spatial_kernel).to(dtype)
        # Normalize so each pixel's filter taps sum to 1 (clamped for safety).
        combined_kernel /= combined_kernel.sum(1, keepdim=True).clamp(1e-7)
        # Learned correction on top of the bilateral kernel (scaled by 0.1).
        combined_kernel += 0.1 * self.fixup_proj(torch.cat([combined_kernel, guidance], dim=1))
        # Reshape taps back into (B, GH, GW, d, d) per-pixel filters.
        combined_kernel = combined_kernel.permute(0, 2, 3, 1).reshape(
            GB, GH, GW, self.diameter, self.diameter
        )
        hr_source = F.interpolate(source, size=(GH, GW), mode="bicubic", align_corners=False)
        hr_source_padded = F.pad(hr_source, pad=[self.radius] * 4, mode="reflect")
        combined_kernel = combined_kernel.to(hr_source_padded.dtype)
        # Use the CUDA adaptive conv when featup is installed, else the
        # pure-PyTorch fallback defined above.
        if AdaptiveConv is not None:
            result = AdaptiveConv.apply(hr_source_padded, combined_kernel)
        else:
            result = adaptive_conv_py_simple(hr_source_padded, combined_kernel)
        return result
183
+
184
+
185
class JBUStack(torch.nn.Module):
    """Four chained 2x JBU upsamplers (16x total) with a residual projection."""

    def __init__(self, feat_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.up1 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.up2 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.up3 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.up4 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.fixup_proj = nn.Sequential(
            nn.Dropout2d(0.2),
            nn.Conv2d(feat_dim, feat_dim, kernel_size=1),
        )

    def upsample(self, source, guidance, up):
        """Run one stage: pool guidance to 2x the source size, then apply `up`."""
        height, width = source.shape[-2:]
        stage_guidance = F.adaptive_avg_pool2d(guidance, (height * 2, width * 2))
        return up(source, stage_guidance)

    def forward(self, source, guidance):
        feats = source
        for stage in (self.up1, self.up2, self.up3, self.up4):
            feats = self.upsample(feats, guidance, stage)
        # Small learned residual correction on the final features.
        return self.fixup_proj(feats) * 0.1 + feats
208
+
209
+
210
class JBUOne(torch.nn.Module):
    """One weight-shared 2x JBU module applied four times (16x total)."""

    def __init__(self, feat_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # A single module reused at every scale (larger radius than JBUStack).
        self.up = JBULearnedRange(3, feat_dim, 32, radius=5)
        self.fixup_proj = nn.Sequential(
            nn.Dropout2d(0.2),
            nn.Conv2d(feat_dim, feat_dim, kernel_size=1),
        )

    def upsample(self, source, guidance, up):
        """Run one stage: pool guidance to 2x the source size, then apply `up`."""
        height, width = source.shape[-2:]
        stage_guidance = F.adaptive_avg_pool2d(guidance, (height * 2, width * 2))
        return up(source, stage_guidance)

    def forward(self, source, guidance):
        feats = source
        for _ in range(4):
            feats = self.upsample(feats, guidance, self.up)
        # Small learned residual correction on the final features.
        return self.fixup_proj(feats) * 0.1 + feats
230
+
231
+
232
# Relative paths to pretrained SimFeatUp/FeatUp upsampler checkpoints,
# keyed by the upsampler names accepted by get_upsampler.
FEATUP_CHECKPOINTS = {
    "jbu_one": "simfeatup/xclip_jbu_one_million_aid.ckpt",
    "jbu_stack": "simfeatup/clip_jbu_stack_cocostuff.ckpt",
    "jbu_stack_maskclip": "simfeatup/maskclip_jbu_stack_cocostuff.ckpt",
}
237
+
238
+
239
def get_upsampler(name: str, feat_dim: int):
    """Build an upsampler module by name.

    Args:
        name: one of "bilinear", "jbu_one", "jbu_stack", "resize_conv", "ifa".
        feat_dim: channel count of the features to be upsampled.

    Returns:
        A freshly constructed upsampler module.

    Raises:
        ValueError: if *name* is not a known upsampler.
    """
    factories = {
        "bilinear": lambda: Bilinear(),
        "jbu_one": lambda: JBUOne(feat_dim),
        "jbu_stack": lambda: JBUStack(feat_dim),
        "resize_conv": lambda: LayeredResizeConv(feat_dim, 1),
        "ifa": lambda: IFA(feat_dim),
    }
    if name not in factories:
        raise ValueError(f"Unknown upsampler: {name}. Use: bilinear, jbu_one, jbu_stack, resize_conv, ifa")
    return factories[name]()