airlabshare committed on
Commit
8fe9811
·
verified ·
1 Parent(s): 096760b

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +70 -0
model.py CHANGED
@@ -38,5 +38,75 @@ class AnyThermalSegmentationModel(PreTrainedModel):
38
  # Upscale to original resolution (14x) [cite: 131]
39
  return F.interpolate(logits, scale_factor=14, mode='bilinear', align_corners=False)
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  # Register for AutoModel discovery
42
  AnyThermalSegmentationModel.register_for_auto_class("AutoModel")
 
38
  # Upscale to original resolution (14x) [cite: 131]
39
  return F.interpolate(logits, scale_factor=14, mode='bilinear', align_corners=False)
40
 
41
+
42
+ # 1. Custom Config to handle SALAD parameters
43
class AnyThermalVPRConfig(Dinov2Config):
    """DINOv2 configuration extended with the sizes used by the SALAD head.

    Args:
        num_clusters: Number of output channels of the score branch in
            ``SALADHead`` (one per cluster).
        cluster_dim: Number of output channels of the per-patch feature
            branch in ``SALADHead``.
        token_dim: Output width of the MLP applied to the CLS token in
            ``SALADHead``.
        **kwargs: Forwarded unchanged to ``Dinov2Config``.
    """

    model_type = "anythermal_vpr"

    def __init__(
        self,
        num_clusters: int = 64,
        cluster_dim: int = 128,
        token_dim: int = 256,
        **kwargs,
    ):
        # Let the backbone config consume its own options first.
        super().__init__(**kwargs)
        # Head-specific sizes, stored as plain attributes on the config.
        self.num_clusters = num_clusters
        self.cluster_dim = cluster_dim
        self.token_dim = token_dim
50
+
51
+ # 2. SALAD Aggregator (Logic from salad.py)
52
class SALADHead(nn.Module):
    """Aggregate patch features and a CLS token into one global descriptor.

    The head has three branches, all fed with ``config.hidden_size`` channels:

    * ``token_features``: MLP mapping the CLS token to ``config.token_dim``.
    * ``cluster_features``: 1x1 convs mapping each patch to
      ``config.cluster_dim`` channels.
    * ``score``: 1x1 convs mapping each patch to ``config.num_clusters``
      assignment scores.

    The returned descriptor has ``token_dim + cluster_dim * num_clusters``
    entries per sample and is L2-normalized.
    """

    def __init__(self, config):
        super().__init__()
        self.num_channels = config.hidden_size
        self.num_clusters = config.num_clusters
        self.cluster_dim = config.cluster_dim
        self.token_dim = config.token_dim

        # NOTE: layer layout and creation order are kept as-is so parameter
        # names (state-dict keys) and default initialization stay unchanged.
        self.token_features = nn.Sequential(
            nn.Linear(self.num_channels, 512),
            nn.ReLU(),
            nn.Linear(512, self.token_dim),
        )
        self.cluster_features = nn.Sequential(
            nn.Conv2d(self.num_channels, 512, 1),
            nn.ReLU(),
            nn.Conv2d(512, self.cluster_dim, 1),
        )
        self.score = nn.Sequential(
            nn.Conv2d(self.num_channels, 512, 1),
            nn.ReLU(),
            nn.Conv2d(512, self.num_clusters, 1),
        )

    def forward(self, x_tuple):
        """Compute the global descriptor.

        Args:
            x_tuple: Pair ``(x, t)`` where ``x`` holds patch features
                ``[B, C, H/14, W/14]`` and ``t`` the CLS token ``[B, C]``.

        Returns:
            Tensor of shape ``[B, token_dim + cluster_dim * num_clusters]``
            with unit L2 norm along the last dimension.
        """
        x, t = x_tuple  # patch features [B, C, H/14, W/14], cls token [B, C]

        f = self.cluster_features(x).flatten(2)  # [B, cluster_dim, N]
        # Softmax over the cluster axis (simplified Sinkhorn for inference).
        p = F.softmax(self.score(x).flatten(2), dim=1)  # [B, num_clusters, N]
        t = self.token_features(t)  # [B, token_dim]

        # Assignment-weighted sum over the N patches. Broadcasting yields the
        # same [B, cluster_dim, num_clusters, N] product as an explicit
        # ``repeat`` of ``f`` but without materializing the extra copy.
        vlad = (f.unsqueeze(2) * p.unsqueeze(1)).sum(dim=-1)
        # Normalize each cluster column over the feature axis, then flatten.
        vlad = F.normalize(vlad, p=2, dim=1).flatten(1)

        combined = torch.cat([F.normalize(t, p=2, dim=-1), vlad], dim=-1)
        return F.normalize(combined, p=2, dim=-1)
87
+
88
+ # 3. Final VPR Model
89
class AnyThermalVPRModel(PreTrainedModel):
    """DINOv2 backbone followed by a SALAD head producing a global descriptor."""

    config_class = AnyThermalVPRConfig

    def __init__(self, config):
        super().__init__(config)
        self.backbone = Dinov2Model(config)
        self.vpr_head = SALADHead(config)
        self.post_init()

    def forward(self, pixel_values, **kwargs):
        """Encode images into L2-normalized global descriptors.

        Args:
            pixel_values: Image batch passed to the backbone; its last two
                dimensions are taken as height and width.
            **kwargs: Forwarded unchanged to the backbone.

        Returns:
            The output of ``self.vpr_head`` for the backbone's patch tokens
            and CLS token.

        Raises:
            ValueError: If the number of patch tokens returned by the
                backbone does not match the grid implied by the input size
                and ``config.patch_size``.
        """
        outputs = self.backbone(pixel_values, **kwargs)
        hidden = outputs.last_hidden_state

        # Token 0 is the CLS token; the remaining tokens are patches.
        cls_token = hidden[:, 0, :]
        patch_tokens = hidden[:, 1:, :].permute(0, 2, 1)
        batch, channels, num_patches = patch_tokens.shape

        # Derive the patch grid from the input size instead of assuming a
        # square grid (int(L ** 0.5)), which fails on non-square images.
        patch = self.config.patch_size
        patch_h, patch_w = (patch, patch) if isinstance(patch, int) else patch
        grid_h = pixel_values.shape[-2] // patch_h
        grid_w = pixel_values.shape[-1] // patch_w
        if grid_h * grid_w != num_patches:
            raise ValueError(
                f"Backbone returned {num_patches} patch tokens, but the input "
                f"size implies a {grid_h}x{grid_w} patch grid."
            )
        patch_tokens = patch_tokens.reshape(batch, channels, grid_h, grid_w)

        # Global descriptor
        return self.vpr_head((patch_tokens, cls_token))


# Register for AutoModel discovery
AnyThermalVPRModel.register_for_auto_class("AutoModel")
111
  # Register for AutoModel discovery
112
  AnyThermalSegmentationModel.register_for_auto_class("AutoModel")