|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
from transformers import PreTrainedModel, Dinov2Model, Dinov2Config |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def log_otp_solver(log_a, log_b, M, num_iters: int = 20, reg: float = 1.0) -> torch.Tensor:
    """Run Sinkhorn iterations in log space and return the log transport plan.

    Args:
        log_a: Log row marginals, shape (B, m).
        log_b: Log column marginals, shape (B, n).
        M: Score matrix, shape (B, m, n); it is divided by ``reg`` before use.
        num_iters: Number of alternating row/column updates.
        reg: Regularisation strength used to scale ``M``.

    Returns:
        Tensor of shape (B, m, n) equal to ``M / reg + u[:, :, None] + v[:, None, :]``
        where ``u`` and ``v`` are the row/column potentials after the last update.
    """
    M = M / reg

    u, v = torch.zeros_like(log_a), torch.zeros_like(log_b)

    for _ in range(num_iters):
        # logsumexp already drops the reduced dimension, so no squeeze is
        # needed. A bare squeeze() would also collapse any size-1 axis; with a
        # single column and batch > 1 that makes the next subtraction
        # broadcast to the wrong shape.
        u = log_a - torch.logsumexp(M + v.unsqueeze(1), dim=2)
        v = log_b - torch.logsumexp(M + u.unsqueeze(2), dim=1)

    return M + u.unsqueeze(2) + v.unsqueeze(1)
|
|
|
|
|
def get_matching_probs(S, dustbin_score=1.0, num_iters=3, reg=1.0):
    """Compute log assignment scores for ``S`` with an extra "dustbin" row.

    Args:
        S: Score tensor of shape (batch, m, n).
        dustbin_score: Value written into every entry of the appended row
            (a float or a tensor that can be assigned into that row).
        num_iters: Number of Sinkhorn iterations passed to ``log_otp_solver``.
        reg: Regularisation passed to ``log_otp_solver``.

    Returns:
        Tensor of shape (batch, m + 1, n): the log transport plan from
        ``log_otp_solver`` shifted by ``+log(n + m)``. The last row is the
        dustbin.

    Raises:
        ValueError: from ``math.log`` when ``n <= m``.
    """
    batch_size, m, n = S.size()

    # Scores with one additional row holding the dustbin score.
    augmented = torch.empty(batch_size, m + 1, n, dtype=S.dtype, device=S.device)
    augmented[:, :m, :n] = S
    augmented[:, m, :] = dustbin_score

    # Every regular row and every column gets log-mass -log(n + m).
    neg_log_total = -torch.tensor(math.log(n + m), device=S.device)

    log_rows = neg_log_total.expand(m + 1).contiguous()
    log_cols = neg_log_total.expand(n).contiguous()
    # The dustbin row receives mass (n - m) / (n + m), so the row and column
    # totals agree.
    log_rows[-1] = log_rows[-1] + math.log(n - m)

    log_rows = log_rows.expand(batch_size, -1)
    log_cols = log_cols.expand(batch_size, -1)

    log_plan = log_otp_solver(
        log_rows, log_cols, augmented, num_iters=num_iters, reg=reg
    )
    return log_plan - neg_log_total
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AnyThermalConfig(Dinov2Config):
    """Dinov2Config subclass that only overrides ``model_type``.

    Used as ``config_class`` by ``AnyThermalSegmentationModel``.
    """

    model_type = "anythermal"
|
|
|
|
|
class AnyThermalSegmentationModel(PreTrainedModel):
    """DINOv2 backbone followed by a small convolutional per-pixel classifier.

    Patch tokens from the backbone's last hidden state are reshaped into a 2-D
    feature map, passed through a two-convolution head producing
    ``config.num_labels`` channels, and bilinearly upsampled to the
    (preprocessed) input resolution.
    """

    config_class = AnyThermalConfig

    def __init__(self, config):
        super().__init__(config)
        self.backbone = Dinov2Model(config)

        # The head lives inside a bare nn.Module so its parameters are
        # registered under the "head.model.*" prefix.
        self.head = nn.Module()
        self.head.model = nn.Sequential(
            nn.Conv2d(config.hidden_size, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=0.0),
            nn.Conv2d(64, config.num_labels, kernel_size=1),
        )

        # Per-channel normalisation constants, broadcastable over (B, 3, H, W).
        # NOTE(review): these values look like the OpenAI CLIP mean/std rather
        # than the ImageNet statistics - confirm this is intended.
        self.register_buffer(
            "norm_mean",
            torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1),
        )
        self.register_buffer(
            "norm_std",
            torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1),
        )

        self.post_init()

    def preprocess_input(self, x):
        """Resize to a patch-aligned resolution and normalise.

        Steps:
            1. Round each spatial side *down* to a multiple of the backbone
               patch size (but never below one patch) and bilinearly resize if
               that changes the shape.
            2. If any value exceeds 1.0, treat the input as 0-255 and divide
               by 255.
            3. Subtract ``norm_mean`` and divide by ``norm_std``.

        Args:
            x: Image batch of shape (B, 3, H, W).

        Returns:
            The normalised tensor, whose spatial sides are multiples of the
            patch size.
        """
        _, _, height, width = x.shape
        patch_size = self.config.patch_size

        # Clamp to one patch so tiny inputs are upsampled instead of producing
        # a zero-sized interpolation target.
        new_h = max(patch_size, (height // patch_size) * patch_size)
        new_w = max(patch_size, (width // patch_size) * patch_size)

        if new_h != height or new_w != width:
            x = F.interpolate(
                x, size=(new_h, new_w), mode='bilinear', align_corners=False
            )

        # Heuristic range detection: values above 1.0 imply a 0-255 scale.
        if x.max() > 1.0:
            x = x / 255.0

        return (x - self.norm_mean) / self.norm_std

    def forward(self, pixel_values, labels=None, **kwargs):
        """Predict per-pixel class logits.

        Args:
            pixel_values: Image batch of shape (B, 3, H, W); it is passed
                through ``preprocess_input`` first.
            labels: Optional target passed to ``nn.CrossEntropyLoss`` together
                with the logits.
            **kwargs: Forwarded to the backbone.

        Returns:
            ``logits`` of shape (B, num_labels, H', W') where (H', W') is the
            preprocessed resolution, or ``{"loss", "logits"}`` when ``labels``
            is given.
        """
        pixel_values = self.preprocess_input(pixel_values)

        outputs = self.backbone(pixel_values, **kwargs)
        # Drop the leading CLS token; the rest are patch tokens.
        features = outputs.last_hidden_state[:, 1:, :]
        batch, _, channels = features.shape

        # Derive the patch grid from the image size instead of assuming a
        # square grid: height and width are resized independently above, so
        # non-square inputs give a non-square token grid.
        patch_size = self.config.patch_size
        grid_h = pixel_values.shape[-2] // patch_size
        grid_w = pixel_values.shape[-1] // patch_size
        features = features.permute(0, 2, 1).reshape(batch, channels, grid_h, grid_w)

        logits = self.head.model(features)

        # NOTE: logits are upsampled to the preprocessed size, which differs
        # from the caller's original size when a resize happened above.
        logits = F.interpolate(
            logits, size=pixel_values.shape[-2:], mode='bilinear', align_corners=False
        )

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
            return {"loss": loss, "logits": logits}

        return logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AnyThermalVPRConfig(Dinov2Config):
    """Dinov2Config extended with the sizes read by ``SALADHead``.

    Args:
        num_clusters: Number of output channels of the head's score branch.
        cluster_dim: Number of output channels of the head's cluster-feature
            branch.
        token_dim: Output width of the head's token-feature MLP.
        **kwargs: Forwarded unchanged to ``Dinov2Config``.
    """

    model_type = "anythermal_vpr"

    def __init__(self, num_clusters=64, cluster_dim=128, token_dim=256, **kwargs):
        super().__init__(**kwargs)
        # Head hyper-parameters, stored as plain attributes on the config.
        self.num_clusters = num_clusters
        self.cluster_dim = cluster_dim
        self.token_dim = token_dim
|
|
|
|
|
class SALADHead(nn.Module):
    """Aggregate a patch feature map and a global token into one descriptor.

    Three branches read the backbone features:
      * ``token_features``: MLP applied to the global token.
      * ``cluster_features``: 1x1 convs giving ``cluster_dim`` channels per location.
      * ``score``: 1x1 convs giving ``num_clusters`` scores per location.

    The scores are turned into assignment weights by ``get_matching_probs``
    (with a learnable dustbin score), used to pool the cluster features, and
    the result is concatenated with the token descriptor.
    """

    def __init__(self, config):
        super().__init__()
        self.num_channels = config.hidden_size
        self.num_clusters = config.num_clusters
        self.cluster_dim = config.cluster_dim
        self.token_dim = config.token_dim

        self.token_features = nn.Sequential(
            nn.Linear(self.num_channels, 512),
            nn.ReLU(),
            nn.Linear(512, self.token_dim),
        )

        # Dropout(0.0) is a no-op but is kept so the Sequential indices (and
        # therefore the state-dict keys) stay the same.
        self.cluster_features = nn.Sequential(
            nn.Conv2d(self.num_channels, 512, 1),
            nn.Dropout(0.0),
            nn.ReLU(),
            nn.Conv2d(512, self.cluster_dim, 1),
        )

        self.score = nn.Sequential(
            nn.Conv2d(self.num_channels, 512, 1),
            nn.Dropout(0.0),
            nn.ReLU(),
            nn.Conv2d(512, self.num_clusters, 1),
        )

        # Learnable score written into the dustbin row of the score matrix.
        self.dust_bin = nn.Parameter(torch.tensor(1.))

    def forward(self, x_tuple):
        """Build the descriptor.

        Args:
            x_tuple: Pair ``(x, t)`` where ``x`` is the patch feature map fed
                to the 1x1 convs (assumed (B, hidden_size, H, W)) and ``t`` is
                the token fed to the MLP (assumed (B, hidden_size)).

        Returns:
            L2-normalised tensor of shape
            (B, token_dim + cluster_dim * num_clusters).
        """
        x, t = x_tuple

        f = self.cluster_features(x).flatten(2)  # (B, cluster_dim, N)
        p = self.score(x).flatten(2)  # (B, num_clusters, N)
        t = self.token_features(t)

        p = get_matching_probs(p, self.dust_bin, 3)
        p = torch.exp(p)
        # Drop the dustbin row; only real clusters contribute.
        p = p[:, :-1, :]

        # Broadcast (B, D, 1, N) * (B, 1, M, N) rather than materialising two
        # repeated (B, D, M, N) copies; the products and the sum are the same.
        vlad = (f.unsqueeze(2) * p.unsqueeze(1)).sum(dim=-1)
        vlad = F.normalize(vlad, p=2, dim=1).flatten(1)

        combined = torch.cat([F.normalize(t, p=2, dim=-1), vlad], dim=-1)
        return F.normalize(combined, p=2, dim=-1)
|
|
|
|
|
class AnyThermalVPRModel(PreTrainedModel):
    """DINOv2 backbone with a ``SALADHead`` that outputs one descriptor per image."""

    config_class = AnyThermalVPRConfig

    def __init__(self, config):
        super().__init__(config)
        self.backbone = Dinov2Model(config)

        # Wrapping the head in nn.Sequential registers its parameters under
        # "vpr_head.0.*"; forward() calls the inner module directly because it
        # takes a tuple argument.
        self.vpr_head = nn.Sequential(SALADHead(config))
        self.post_init()

    def forward(self, pixel_values, **kwargs):
        """Compute the descriptor for a batch of images.

        Args:
            pixel_values: Image batch of shape (B, C, H, W), fed to the
                backbone as-is (no preprocessing is done here).
            **kwargs: Forwarded to the backbone.

        Returns:
            The tensor returned by ``SALADHead.forward``.
        """
        outputs = self.backbone(pixel_values, **kwargs)
        hidden = outputs.last_hidden_state

        # Token 0 is the CLS token; the remaining tokens are patches.
        cls_token = hidden[:, 0, :]
        patch_tokens = hidden[:, 1:, :].permute(0, 2, 1)
        batch, channels, _ = patch_tokens.shape

        # Take the patch grid from the image size rather than assuming it is
        # square, so non-square inputs reshape correctly.
        patch_size = self.config.patch_size
        grid_h = pixel_values.shape[-2] // patch_size
        grid_w = pixel_values.shape[-1] // patch_size
        patch_tokens = patch_tokens.reshape(batch, channels, grid_h, grid_w)

        return self.vpr_head[0]((patch_tokens, cls_token))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AnyThermalDepthConfig(Dinov2Config):
    """Dinov2Config extended with the decoder width.

    Args:
        features: Channel count read by ``AnyThermalDepthModel`` for its
            projection convs and fusion blocks.
        **kwargs: Forwarded unchanged to ``Dinov2Config``.
    """

    model_type = "anythermal_depth"

    def __init__(self, features=256, **kwargs):
        super().__init__(**kwargs)
        # Decoder channel width, stored as a plain attribute on the config.
        self.features = features
|
|
|
|
|
class ResidualConvUnit(nn.Module):
    """Two 3x3 convolutions, each preceded by a ReLU, plus a skip connection.

    NOTE(review): the shared ReLU is in-place, so the first activation
    overwrites the input tensor. The skip connection therefore adds
    ``relu(x)`` rather than the raw ``x``, and the caller's tensor is
    modified. This is kept exactly as written; confirm it is intended.
    """

    def __init__(self, features):
        super().__init__()
        # Same-padding 3x3 convs that keep the channel count unchanged.
        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )
        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``conv2(relu(conv1(relu(x)))) + x`` (x is rectified in place)."""
        # The left operand is evaluated first, so x has already been rectified
        # in place by the time it is added back.
        return self.conv2(self.relu(self.conv1(self.relu(x)))) + x
|
|
|
|
|
class FeatureFusionBlock(nn.Module):
    """Merge an optional skip feature map into a path and upsample it by 2x."""

    def __init__(self, features):
        super().__init__()
        # resConfUnit1 refines the skip input; resConfUnit2 refines the merged path.
        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        """Fuse the inputs.

        Args:
            *xs: ``(path,)`` or ``(path, skip)``. With two inputs, ``path`` is
                bilinearly resized to ``skip``'s spatial size when they differ
                and the refined ``skip`` is added to it.

        Returns:
            The refined map, bilinearly upsampled by a factor of 2.
        """
        fused = xs[0]

        if len(xs) == 2:
            skip = xs[1]
            target_size = skip.shape[-2:]
            if fused.shape[-2:] != target_size:
                fused = F.interpolate(
                    fused, size=target_size, mode="bilinear", align_corners=True
                )
            fused = fused + self.resConfUnit1(skip)

        fused = self.resConfUnit2(fused)
        return F.interpolate(
            fused, scale_factor=2, mode="bilinear", align_corners=True
        )
|
|
|
|
|
class AnyThermalDepthModel(PreTrainedModel):
    """DINOv2 backbone with a multi-scale fusion decoder giving one map per image.

    Hidden states 3, 6, 9 and 12 of the backbone are reshaped into feature
    maps, brought to four different resolutions (``pretrained.act_postprocess*``),
    projected to ``config.features`` channels (``scratch.layer*_rn``), fused
    coarse-to-fine (``scratch.refinenet*``) and reduced to a single
    non-negative channel (``scratch.output_conv``).

    The channel sizes (768 in; 96/192/384/768 per level) and the patch size 14
    are hard-coded.
    """

    config_class = AnyThermalDepthConfig

    def __init__(self, config):
        super().__init__(config)
        self.backbone = Dinov2Model(config)
        width = config.features

        # Bare containers so parameters are registered under the "scratch.*"
        # and "pretrained.*" prefixes.
        self.scratch = nn.Module()
        self.pretrained = nn.Module()

        # 3x3 projections of each level to the shared decoder width.
        for level, in_channels in enumerate((96, 192, 384, 768), start=1):
            setattr(
                self.scratch,
                f"layer{level}_rn",
                nn.Conv2d(in_channels, width, 3, 1, 1, bias=False),
            )

        def reassemble(*tail):
            # Three parameter-free Identity modules come first, so the real
            # layers sit at Sequential indices 3 and above.
            return nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity(), *tail)

        # Level 1: 4x upsampling; level 2: 2x upsampling; level 3: same
        # resolution; level 4: stride-2 downsampling.
        self.pretrained.act_postprocess1 = reassemble(
            nn.Conv2d(768, 96, 1), nn.ConvTranspose2d(96, 96, 4, 4)
        )
        self.pretrained.act_postprocess2 = reassemble(
            nn.Conv2d(768, 192, 1), nn.ConvTranspose2d(192, 192, 2, 2)
        )
        self.pretrained.act_postprocess3 = reassemble(
            nn.Conv2d(768, 384, 1)
        )
        self.pretrained.act_postprocess4 = reassemble(
            nn.Conv2d(768, 768, 1), nn.Conv2d(768, 768, 3, 2, 1)
        )

        for level in (4, 3, 2, 1):
            setattr(self.scratch, f"refinenet{level}", FeatureFusionBlock(width))

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(width, 128, 3, 1, 1),
            nn.Upsample(scale_factor=1.75, mode="bilinear"),
            nn.Conv2d(128, 32, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, 1, 1, 0),
            nn.ReLU(True),
        )
        self.post_init()

    def forward(self, pixel_values):
        """Predict a single-channel map.

        Args:
            pixel_values: Image batch of shape (B, C, H, W), fed to the
                backbone as-is.

        Returns:
            Tensor with the channel dimension squeezed out, i.e. (B, H_out, W_out).
        """
        outputs = self.backbone(pixel_values, output_hidden_states=True)
        taps = [outputs.hidden_states[idx] for idx in [3, 6, 9, 12]]

        _, _, h, w = pixel_values.shape

        def to_feature_map(tokens):
            # Drop the CLS token and fold the patch tokens into a
            # (B, C, h // 14, w // 14) grid.
            tokens = tokens[:, 1:, :].transpose(1, 2)
            return tokens.reshape(tokens.shape[0], tokens.shape[1], h // 14, w // 14)

        map_1, map_2, map_3, map_4 = [to_feature_map(tap) for tap in taps]

        layer_1_rn = self.scratch.layer1_rn(self.pretrained.act_postprocess1(map_1))
        layer_2_rn = self.scratch.layer2_rn(self.pretrained.act_postprocess2(map_2))
        layer_3_rn = self.scratch.layer3_rn(self.pretrained.act_postprocess3(map_3))
        layer_4_rn = self.scratch.layer4_rn(self.pretrained.act_postprocess4(map_4))

        # Coarse-to-fine fusion: each step merges the next finer level.
        path = self.scratch.refinenet4(layer_4_rn)
        path = self.scratch.refinenet3(path, layer_3_rn)
        path = self.scratch.refinenet2(path, layer_2_rn)
        path = self.scratch.refinenet1(path, layer_1_rn)

        return self.scratch.output_conv(path).squeeze(1)
|
|
|
|
|
|
|
|
# Register each model class with the "AutoModel" auto class.
for _model_cls in (
    AnyThermalSegmentationModel,
    AnyThermalVPRModel,
    AnyThermalDepthModel,
):
    _model_cls.register_for_auto_class("AutoModel")