xjh19972 committed on
Commit
a9dfbf3
·
verified ·
1 Parent(s): 5610575

Add STHN one-stage and two-stage models with demo and examples

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/gt.png filter=lfs diff=lfs merge=lfs -text
STHN_demo.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ STHN Demo: Satellite-Thermal Homography Network
3
+ Supports uploading to / loading from HuggingFace Hub.
4
+ Input: 1 RGB satellite image + 1 thermal image
5
+ Output: 4-point displacement + visualization
6
+ """
7
+ import sys
8
+ import os
9
+ import argparse
10
+ import numpy as np
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import kornia.geometry.transform as tgm
16
+ import kornia.geometry.bbox as bbox_utils
17
+ from huggingface_hub import PyTorchModelHubMixin
18
+ from PIL import Image
19
+ import torchvision.transforms as transforms
20
+ import matplotlib.pyplot as plt
21
+ import cv2
22
+
23
+ # Import model building blocks from local_pipeline
24
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'local_pipeline'))
25
+ from extractor import BasicEncoderQuarter
26
+ from corr import CorrBlock
27
+ from update import CNN_64
28
+ from utils import coords_grid
29
+
30
+
31
+ # ==============================================================================
32
+ # Model Components (redefined without args dependency for HuggingFace)
33
+ # ==============================================================================
34
+
35
class GMA(nn.Module):
    """Update block that predicts a delta 4-point displacement from correlation and flow.

    Redefined from local_pipeline/update.py to remove args dependency.
    """

    def __init__(self, corr_level, sz):
        """
        Args:
            corr_level: Number of correlation pyramid levels. Each level
                contributes an 81-channel cost slice, so the CNN input width
                is ``corr_level * 81 + 2`` (the +2 is for the flow channels).
            sz: Spatial size of the feature map; only 64 is supported here.

        Raises:
            NotImplementedError: If ``sz`` is not 64 or ``corr_level`` is not
                one of the checkpoint-compatible values (2, 4, 6).
        """
        super().__init__()
        if sz != 64:
            raise NotImplementedError(f"GMA with sz={sz} not supported in this demo")
        if corr_level not in (2, 4, 6):
            raise NotImplementedError(f"corr_level={corr_level} not supported")
        # 164 / 326 / 488 in the original code — all instances of the same
        # formula, so compute it instead of enumerating the constants.
        init_dim = corr_level * 81 + 2
        self.cnn = CNN_64(128, init_dim=init_dim)

    def forward(self, corr, flow):
        # Concatenate correlation features and current flow along channels,
        # then regress the delta 4-point displacement.
        return self.cnn(torch.cat((corr, flow), dim=1))
56
+
57
+
58
class IHN(nn.Module):
    """Iterative Homography Network.

    Redefined from local_pipeline/model/network.py to remove args dependency.
    State dict keys are compatible with original checkpoints (after stripping 'module.').
    """
    def __init__(self, resize_width, corr_level):
        super().__init__()
        self.resize_width = resize_width
        # Single shared encoder applied to both images; produces features at
        # 1/4 of the input resolution.
        self.fnet1 = BasicEncoderQuarter(output_dim=256, norm_fn='instance')
        sz = resize_width // 4
        self.update_block_4 = GMA(corr_level, sz)
        # Normalization stats are created lazily in forward() so they are
        # allocated on the same device as the first input batch.
        self.imagenet_mean = None
        self.imagenet_std = None

    def get_flow_now_4(self, four_point):
        """Convert a 4-point displacement (at input scale) into a dense flow
        field at 1/4 resolution by solving the induced homography.

        Uses self.sz (the feature-map shape cached by forward()); must be
        called after forward() has set it.
        """
        # Displacement is expressed at full input scale; features are 1/4 scale.
        four_point = four_point / 4
        # Corners of the un-displaced feature map: [x/y, top/bottom, left/right].
        four_point_org = torch.zeros((2, 2, 2)).to(four_point.device)
        four_point_org[:, 0, 0] = torch.Tensor([0, 0])
        four_point_org[:, 0, 1] = torch.Tensor([self.sz[3] - 1, 0])
        four_point_org[:, 1, 0] = torch.Tensor([0, self.sz[2] - 1])
        four_point_org[:, 1, 1] = torch.Tensor([self.sz[3] - 1, self.sz[2] - 1])

        four_point_org = four_point_org.unsqueeze(0).repeat(self.sz[0], 1, 1, 1)
        four_point_new = four_point_org + four_point
        # Flatten corner grids to [B, 4, 2] point lists for kornia.
        four_point_org = four_point_org.flatten(2).permute(0, 2, 1).contiguous()
        four_point_new = four_point_new.flatten(2).permute(0, 2, 1).contiguous()
        H = tgm.get_perspective_transform(four_point_org, four_point_new)

        # NOTE(review): torch.meshgrid is called without indexing=, which
        # defaults to 'ij' and warns on newer torch — confirm intended order.
        gridy, gridx = torch.meshgrid(
            torch.linspace(0, self.resize_width // 4 - 1, steps=self.resize_width // 4),
            torch.linspace(0, self.resize_width // 4 - 1, steps=self.resize_width // 4))
        # Homogeneous pixel coordinates [B, 3, H*W] of the 1/4-res grid.
        points = torch.cat(
            (gridx.flatten().unsqueeze(0), gridy.flatten().unsqueeze(0),
             torch.ones((1, self.resize_width // 4 * self.resize_width // 4))),
            dim=0).unsqueeze(0).repeat(H.shape[0], 1, 1).to(four_point.device)
        points_new = H.bmm(points)
        if torch.isnan(points_new).any():
            raise KeyError("Some of transformed coords are NaN!")
        # Perspective divide, then drop the homogeneous row.
        points_new = points_new / points_new[:, 2, :].unsqueeze(1)
        points_new = points_new[:, 0:2, :]
        # NOTE(review): reshape uses (sz[3], sz[2]) i.e. (W, H) — this is only
        # consistent for square feature maps; verify for non-square inputs.
        flow = torch.cat(
            (points_new[:, 0, :].reshape(self.sz[0], self.sz[3], self.sz[2]).unsqueeze(1),
             points_new[:, 1, :].reshape(self.sz[0], self.sz[3], self.sz[2]).unsqueeze(1)),
            dim=1)
        return flow

    def forward(self, image1, image2, iters_lev0=6, corr_level=2, corr_radius=4):
        """Iteratively estimate the 4-point displacement between image1 and image2.

        Args:
            image1, image2: [N, 3, H, W] tensors, values expected in [0, 1]
                (they are ImageNet-normalized here).
            iters_lev0: Number of update iterations.
            corr_level: Correlation pyramid levels for CorrBlock.
            corr_radius: Correlation lookup radius.

        Returns:
            (four_point_predictions, four_point_disp): the list of per-iteration
            displacement estimates and the final [N, 2, 2, 2] displacement.
        """
        # Lazily create ImageNet normalization tensors on the input device.
        if self.imagenet_mean is None:
            self.imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(image1.device)
            self.imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(image1.device)
        image1 = (image1.contiguous() - self.imagenet_mean) / self.imagenet_std
        image2 = (image2.contiguous() - self.imagenet_mean) / self.imagenet_std

        # Quarter-resolution feature maps from the shared encoder.
        fmap1 = self.fnet1(image1).float()
        fmap2 = self.fnet1(image2).float()

        corr_fn = CorrBlock(fmap1, fmap2, num_levels=corr_level, radius=corr_radius)

        N, C, H, W = image1.shape
        # coords0 stays fixed; coords1 is re-warped each iteration so that
        # (coords1 - coords0) is the current flow estimate.
        coords0 = coords_grid(N, H // 4, W // 4).to(image1.device)
        coords1 = coords_grid(N, H // 4, W // 4).to(image1.device)

        # Cache the feature shape for get_flow_now_4().
        sz = fmap1.shape
        self.sz = sz
        four_point_disp = torch.zeros((sz[0], 2, 2, 2)).to(fmap1.device)
        four_point_predictions = []

        for itr in range(iters_lev0):
            corr = corr_fn(coords1)
            flow = coords1 - coords0
            delta_four_point = self.update_block_4(corr, flow)
            try:
                # Keep the previous estimate so a failed homography solve
                # (e.g. NaN coords raising in get_flow_now_4) can be rolled back.
                last_four_point_disp = four_point_disp
                four_point_disp = four_point_disp + delta_four_point
                coords1 = self.get_flow_now_4(four_point_disp)
                four_point_predictions.append(four_point_disp)
            except Exception:
                # Degenerate update: revert to the last valid displacement and
                # continue iterating from there.
                four_point_disp = last_four_point_disp
                coords1 = self.get_flow_now_4(four_point_disp)
                four_point_predictions.append(four_point_disp)

        return four_point_predictions, four_point_disp
140
+
141
+
142
+ # ==============================================================================
143
+ # STHN HuggingFace Model
144
+ # ==============================================================================
145
+
146
class STHN(nn.Module, PyTorchModelHubMixin):
    """
    Satellite-Thermal Homography Network with HuggingFace Hub support.

    Estimates 4-point homography displacement between a satellite RGB image
    and a thermal image for UAV geo-localization.
    """
    def __init__(self, model_config):
        super().__init__()
        # Plain dict config (kept so PyTorchModelHubMixin can serialize it).
        self.model_config = model_config

        self.resize_width = model_config.get('resize_width', 256)
        self.database_size = model_config.get('database_size', 1536)
        self.corr_level = model_config.get('corr_level', 4)
        self.two_stages = model_config.get('two_stages', False)
        self.iters_lev0 = model_config.get('iters_lev0', 6)
        self.iters_lev1 = model_config.get('iters_lev1', 6)
        self.fine_padding = model_config.get('fine_padding', 0)

        # Coarse network; a second IHN (fixed corr_level=2) refines the crop
        # when two-stage mode is enabled.
        self.netG = IHN(self.resize_width, self.corr_level)
        if self.two_stages:
            self.netG_fine = IHN(self.resize_width, 2)

    def forward(self, satellite_image, thermal_image):
        """
        Args:
            satellite_image: [B, 3, database_size, database_size] RGB satellite (values in [0, 1])
            thermal_image: [B, 3, resize_width, resize_width] 3-channel thermal (values in [0, 1])
        Returns:
            four_pred: [B, 2, 2, 2] predicted 4-point displacement at resize_width scale
                       Shape meaning: [batch, x/y, top/bottom, left/right]
        """
        # Coarse stage runs on the satellite image downsampled to resize_width.
        image_1 = F.interpolate(satellite_image, size=self.resize_width,
                                mode='bilinear', align_corners=True, antialias=True)
        image_2 = thermal_image

        _, four_pred = self.netG(
            image1=image_1, image2=image_2,
            iters_lev0=self.iters_lev0, corr_level=self.corr_level)

        if self.two_stages:
            # Crop the full-resolution satellite around the coarse prediction,
            # refine on the crop, then map the result back to global coords.
            image_1_crop, delta, flow_bbox = self._crop_for_refinement(
                satellite_image, four_pred)
            _, four_pred_fine = self.netG_fine(
                image1=image_1_crop, image2=image_2,
                iters_lev0=self.iters_lev1)
            four_pred = self._combine_coarse_fine(four_pred_fine, delta, flow_bbox)

        return four_pred

    def _get_four_point_org(self, size, device):
        """Corner coordinates of an un-displaced size x size square,
        shaped [1, 2(x/y), 2(top/bottom), 2(left/right)]."""
        fp = torch.zeros((1, 2, 2, 2), device=device)
        fp[0, :, 0, 0] = torch.tensor([0.0, 0.0])
        fp[0, :, 0, 1] = torch.tensor([size - 1.0, 0.0])
        fp[0, :, 1, 0] = torch.tensor([0.0, size - 1.0])
        fp[0, :, 1, 1] = torch.tensor([size - 1.0, size - 1.0])
        return fp

    def _crop_for_refinement(self, image_1_ori, four_pred):
        """Build a square crop of the full-res satellite image that bounds the
        coarse prediction, for the fine stage.

        Returns:
            image_1_crop: [B, 3, resize_width, resize_width] resized crop.
            delta: [B, 1, 1, 1] crop-size / resize_width scale factor.
            flow_bbox: [B, 2, 2, 2] crop-corner displacement from the
                database-size square, used to map fine results back.
        All returns are detached from the coarse-stage graph.
        """
        device = four_pred.device
        rw = self.resize_width
        ds = self.database_size
        # Scale factor from resize_width coords to database_size coords.
        alpha = ds / rw

        four_point_org = self._get_four_point_org(rw, device)
        # Absolute predicted corner positions at resize_width scale.
        four_point = four_pred + four_point_org

        x = four_point[:, 0].clone()
        y = four_point[:, 1].clone()

        # Scale to database_size; the +1 on the far edges makes the box
        # inclusive of the last pixel before scaling.
        x[:, :, 0] = x[:, :, 0] * alpha
        x[:, :, 1] = (x[:, :, 1] + 1) * alpha
        y[:, 0, :] = y[:, 0, :] * alpha
        y[:, 1, :] = (y[:, 1, :] + 1) * alpha

        # Axis-aligned bounding box of the (possibly skewed) quadrilateral.
        left = torch.min(x.view(x.shape[0], -1), dim=1)[0]
        right = torch.max(x.view(x.shape[0], -1), dim=1)[0]
        top = torch.min(y.view(y.shape[0], -1), dim=1)[0]
        bottom = torch.max(y.view(y.shape[0], -1), dim=1)[0]

        # Square side = max extent; c = box center (x, y).
        w = torch.max(torch.stack([right - left, bottom - top], dim=1), dim=1)[0]
        c = torch.stack([(left + right) / 2, (bottom + top) / 2], dim=1)

        w_padded = w + 2 * self.fine_padding
        crop_top_left = c + torch.stack([-w_padded / 2, -w_padded / 2], dim=1)
        x_start = crop_top_left[:, 0]
        y_start = crop_top_left[:, 1]

        bbox_s = bbox_utils.bbox_generator(x_start, y_start, w_padded, w_padded)
        # delta rescales fine-stage displacements from crop scale to rw scale.
        delta = (w_padded / rw).unsqueeze(1).unsqueeze(1).unsqueeze(1)
        image_1_crop = tgm.crop_and_resize(image_1_ori, bbox_s, (rw, rw))

        # Reorder kornia's corner order (TL, TR, BR, BL) to this model's
        # [top/bottom, left/right] layout before reshaping to [B, 2, 2, 2].
        bbox_s_swap = torch.stack(
            [bbox_s[:, 0], bbox_s[:, 1], bbox_s[:, 3], bbox_s[:, 2]], dim=1)
        four_cor_bbox = bbox_s_swap.permute(0, 2, 1).view(-1, 2, 2, 2)
        four_point_org_large = self._get_four_point_org(ds, device)
        flow_bbox = four_cor_bbox - four_point_org_large

        # Detach: the fine stage should not backprop into the coarse stage.
        return image_1_crop.detach(), delta.detach(), flow_bbox.detach()

    def _combine_coarse_fine(self, four_pred_fine, delta, flow_bbox):
        """Map the fine-stage displacement (crop coordinates) back into the
        resize_width coordinate frame of the coarse prediction."""
        alpha = self.database_size / self.resize_width
        # kappa converts crop-scale displacements to resize_width scale.
        kappa = delta / alpha
        return four_pred_fine * kappa + flow_bbox / alpha

    @classmethod
    def from_local_checkpoint(cls, checkpoint_path, model_config):
        """Load model from a local training checkpoint (.pth file)."""
        model = cls(model_config)
        # weights_only=False: training checkpoints contain non-tensor objects.
        # NOTE(review): only load checkpoints from trusted sources — this uses
        # pickle under the hood.
        ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

        # Strip DataParallel's 'module.' prefix so keys match this model.
        netG_state = {k.replace('module.', ''): v for k, v in ckpt['netG'].items()}
        model.netG.load_state_dict(netG_state, strict=True)

        if model_config.get('two_stages', False) and ckpt.get('netG_fine') is not None:
            netG_fine_state = {k.replace('module.', ''): v
                               for k, v in ckpt['netG_fine'].items()}
            model.netG_fine.load_state_dict(netG_fine_state, strict=True)

        return model
266
+
267
+
268
+ # ==============================================================================
269
+ # Preprocessing & Visualization
270
+ # ==============================================================================
271
+
272
def load_and_preprocess_satellite(image_path, database_size):
    """Read a satellite image as an RGB tensor.

    Returns a [1, 3, database_size, database_size] float tensor in [0, 1].
    """
    rgb = Image.open(image_path).convert('RGB')
    resized = transforms.Resize([database_size, database_size])(rgb)
    batch = transforms.ToTensor()(resized).unsqueeze(0)
    return batch
279
+
280
+
281
def load_and_preprocess_thermal(image_path, resize_width):
    """Read a thermal image and replicate it to three channels.

    Returns a [1, 3, resize_width, resize_width] float tensor in [0, 1].
    """
    gray = Image.open(image_path).convert('L')
    three_channel = transforms.Grayscale(num_output_channels=3)(gray)
    resized = transforms.Resize([resize_width, resize_width])(three_channel)
    batch = transforms.ToTensor()(resized).unsqueeze(0)
    return batch
289
+
290
+
291
def visualize_result(satellite_image, thermal_image, four_pred, resize_width,
                     database_size, save_path='examples/STHN_result.png',
                     gt_image_path=None):
    """Render and save a side-by-side visualization of the prediction, and
    print the 4-point displacement at both scales.

    Args:
        satellite_image: [1, 3, database_size, database_size] tensor in [0, 1].
        thermal_image: [1, 3, resize_width, resize_width] tensor in [0, 1].
        four_pred: [1, 2, 2, 2] predicted displacement at resize_width scale.
        resize_width, database_size: model working / full satellite sizes.
        save_path: output figure path.
        gt_image_path: optional ground-truth overlay image; adds a 5th panel
            when the file exists.
    """
    alpha = database_size / resize_width

    # Corners of the un-displaced resize_width square,
    # layout [1, x/y, top/bottom, left/right].
    four_point_org = torch.zeros((1, 2, 2, 2))
    four_point_org[:, :, 0, 0] = torch.tensor([0, 0])
    four_point_org[:, :, 0, 1] = torch.tensor([resize_width - 1, 0])
    four_point_org[:, :, 1, 0] = torch.tensor([0, resize_width - 1])
    four_point_org[:, :, 1, 1] = torch.tensor([resize_width - 1, resize_width - 1])

    # Absolute predicted corner positions at resize_width scale.
    four_point_pred = four_pred.cpu() + four_point_org

    # Downsample the satellite image for display at the working resolution.
    sat_display = F.interpolate(satellite_image, size=resize_width,
                                mode='bilinear', align_corners=True, antialias=True)
    sat_np = (sat_display[0].permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
    thermal_np = (thermal_image[0].permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)

    # Polygon vertices in drawing order TL -> TR -> BR -> BL for cv2.polylines.
    pred_pts = four_point_pred[0].numpy()
    pts = np.array([
        [pred_pts[0, 0, 0], pred_pts[1, 0, 0]],  # TL
        [pred_pts[0, 0, 1], pred_pts[1, 0, 1]],  # TR
        [pred_pts[0, 1, 1], pred_pts[1, 1, 1]],  # BR
        [pred_pts[0, 1, 0], pred_pts[1, 1, 0]],  # BL
    ], dtype=np.int32).reshape((-1, 1, 2))

    sat_with_bbox = sat_np.copy()
    cv2.polylines(sat_with_bbox, [pts], True, (0, 255, 0), 2)

    # Warp the thermal image by the predicted homography (org -> pred corners)
    # so it can be overlaid on the satellite image.
    four_point_org_flat = four_point_org.flatten(2).permute(0, 2, 1).contiguous()
    four_point_pred_flat = four_point_pred.flatten(2).permute(0, 2, 1).contiguous()
    H = tgm.get_perspective_transform(four_point_org_flat, four_point_pred_flat)
    warped_thermal = tgm.warp_perspective(thermal_image.cpu(), H,
                                          (resize_width, resize_width))
    warped_np = (warped_thermal[0].permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8)

    # Determine layout based on whether ground truth is available
    has_gt = gt_image_path is not None and os.path.exists(gt_image_path)
    ncols = 5 if has_gt else 4
    fig, axes = plt.subplots(1, ncols, figsize=(5 * ncols, 5))

    axes[0].imshow(sat_np)
    axes[0].set_title('Satellite Image')
    axes[0].axis('off')

    axes[1].imshow(thermal_np, cmap='gray')
    axes[1].set_title('Thermal Image')
    axes[1].axis('off')

    axes[2].imshow(sat_with_bbox)
    axes[2].set_title('Predicted Alignment (green bbox)')
    axes[2].axis('off')

    # Semi-transparent warped thermal on top of the satellite image.
    axes[3].imshow(sat_np)
    axes[3].imshow(warped_np, alpha=0.5)
    axes[3].set_title('Overlay')
    axes[3].axis('off')

    if has_gt:
        gt_img = np.array(Image.open(gt_image_path).convert('RGB'))
        axes[4].imshow(gt_img)
        axes[4].set_title('Ground Truth')
        axes[4].axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"\nVisualization saved to {save_path}")

    # Print per-corner displacements at both the working and full scales.
    disp = four_pred[0].cpu()
    disp_scaled = disp * alpha
    print(f"\n4-Point Displacement (pixels at {resize_width}x{resize_width} scale):")
    print(f"  Top-Left:     dx={disp[0, 0, 0]:.2f}, dy={disp[1, 0, 0]:.2f}")
    print(f"  Top-Right:    dx={disp[0, 0, 1]:.2f}, dy={disp[1, 0, 1]:.2f}")
    print(f"  Bottom-Left:  dx={disp[0, 1, 0]:.2f}, dy={disp[1, 1, 0]:.2f}")
    print(f"  Bottom-Right: dx={disp[0, 1, 1]:.2f}, dy={disp[1, 1, 1]:.2f}")
    print(f"\n4-Point Displacement (scaled to {database_size}x{database_size}):")
    print(f"  Top-Left:     dx={disp_scaled[0, 0, 0]:.2f}, dy={disp_scaled[1, 0, 0]:.2f}")
    print(f"  Top-Right:    dx={disp_scaled[0, 0, 1]:.2f}, dy={disp_scaled[1, 0, 1]:.2f}")
    print(f"  Bottom-Left:  dx={disp_scaled[0, 1, 0]:.2f}, dy={disp_scaled[1, 1, 0]:.2f}")
    print(f"  Bottom-Right: dx={disp_scaled[0, 1, 1]:.2f}, dy={disp_scaled[1, 1, 1]:.2f}")
371
+
372
+
373
+ # ==============================================================================
374
+ # Main
375
+ # ==============================================================================
376
+
377
# Default configs mirror the Hub config.json files (one_stage / two_stages);
# they differ only in the 'two_stages' flag.
ONE_STAGE_CONFIG = {
    'resize_width': 256,
    'database_size': 1536,
    'corr_level': 4,
    'iters_lev0': 6,
    'iters_lev1': 6,
    'two_stages': False,
    'fine_padding': 0,
}

TWO_STAGE_CONFIG = {
    'resize_width': 256,
    'database_size': 1536,
    'corr_level': 4,
    'iters_lev0': 6,
    'iters_lev1': 6,
    'two_stages': True,
    'fine_padding': 0,
}

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='STHN Demo: Satellite-Thermal Homography Estimation')
    parser.add_argument('--satellite_image', type=str, default='examples/img1.png',
                        help='Path to satellite RGB image')
    parser.add_argument('--thermal_image', type=str, default='examples/img2.png',
                        help='Path to thermal image')
    parser.add_argument('--gt_image', type=str, default='examples/gt.png',
                        help='Path to ground truth overlay image')
    parser.add_argument('--two_stages', action='store_true',
                        help='Use two-stage model for higher accuracy')
    parser.add_argument('--save_path', type=str, default=None,
                        help='Output visualization path')
    parser.add_argument('--hf_model', type=str, default=None,
                        help='HuggingFace model name (e.g., arplaboratory/STHN_one_stage)')
    parser.add_argument('--local_checkpoint', type=str, default=None,
                        help='Path to local checkpoint (.pth)')
    parser.add_argument('--push_to_hub', type=str, default=None,
                        help='Upload model to HuggingFace Hub (e.g., arplaboratory/STHN_one_stage)')
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    config = TWO_STAGE_CONFIG if args.two_stages else ONE_STAGE_CONFIG
    # Default output name depends on which model variant was requested.
    if args.save_path is None:
        args.save_path = 'examples/STHN_result_two_stage.png' if args.two_stages else 'examples/STHN_result_one_stage.png'

    # ---- Load Model ----
    # Priority: explicit local checkpoint > Hub model > default local path.
    if args.local_checkpoint:
        print(f"Loading model from local checkpoint: {args.local_checkpoint}")
        model = STHN.from_local_checkpoint(args.local_checkpoint, config)
    elif args.hf_model:
        print(f"Loading model from HuggingFace Hub: {args.hf_model}")
        model = STHN.from_pretrained(args.hf_model)
    else:
        default_ckpt = '1536_one_stage/STHN.pth' if not args.two_stages else '1536_two_stages/STHN.pth'
        if os.path.exists(default_ckpt):
            print(f"Loading model from default checkpoint: {default_ckpt}")
            model = STHN.from_local_checkpoint(default_ckpt, config)
        else:
            print("No checkpoint found. Please specify --hf_model or --local_checkpoint")
            sys.exit(1)

    model = model.to(device)
    model.eval()

    # ---- Push to HuggingFace Hub ----
    # Optional: upload the loaded weights+config, then continue with the demo.
    if args.push_to_hub:
        print(f"Pushing model to HuggingFace Hub: {args.push_to_hub}")
        model.push_to_hub(args.push_to_hub)
        print("Done!")

    # ---- Run Inference ----
    print(f"Running inference on {device}...")
    satellite = load_and_preprocess_satellite(
        args.satellite_image, config['database_size']).to(device)
    thermal = load_and_preprocess_thermal(
        args.thermal_image, config['resize_width']).to(device)

    with torch.no_grad():
        four_pred = model(satellite, thermal)

    visualize_result(satellite, thermal, four_pred,
                     config['resize_width'], config['database_size'],
                     args.save_path, gt_image_path=args.gt_image)
examples/gt.png ADDED

Git LFS Details

  • SHA256: fba93ce0c34caa7b90981d0723b784c650c1b3d9872e8db905f0de3101fd8e7f
  • Pointer size: 131 Bytes
  • Size of remote file: 116 kB
examples/img1.png ADDED
examples/img2.png ADDED
one_stage/README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - model_hub_mixin
4
+ - pytorch_model_hub_mixin
5
+ ---
6
+
7
+ This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
8
+ - Code: [More Information Needed]
9
+ - Paper: [More Information Needed]
10
+ - Docs: [More Information Needed]
one_stage/config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_config": {
3
+ "corr_level": 4,
4
+ "database_size": 1536,
5
+ "fine_padding": 0,
6
+ "iters_lev0": 6,
7
+ "iters_lev1": 6,
8
+ "resize_width": 256,
9
+ "two_stages": false
10
+ }
11
+ }
one_stage/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:359345e2f46f6c65ee2ca23ffa1ff6ab73a8af5ae348862799d1843edbd068e2
3
+ size 6508704
two_stages/README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - model_hub_mixin
4
+ - pytorch_model_hub_mixin
5
+ ---
6
+
7
+ This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
8
+ - Code: [More Information Needed]
9
+ - Paper: [More Information Needed]
10
+ - Docs: [More Information Needed]
two_stages/config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_config": {
3
+ "corr_level": 4,
4
+ "database_size": 1536,
5
+ "fine_padding": 0,
6
+ "iters_lev0": 6,
7
+ "iters_lev1": 6,
8
+ "resize_width": 256,
9
+ "two_stages": true
10
+ }
11
+ }
two_stages/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:baefb3ae87be966349a3dab2f46b60aa7333489b2e2ce9edebfd06ac21711a47
3
+ size 12271248