Add STHN one-stage and two-stage models with demo and examples
Browse files- .gitattributes +1 -0
- STHN_demo.py +459 -0
- examples/gt.png +3 -0
- examples/img1.png +0 -0
- examples/img2.png +0 -0
- one_stage/README.md +10 -0
- one_stage/config.json +11 -0
- one_stage/model.safetensors +3 -0
- two_stages/README.md +10 -0
- two_stages/config.json +11 -0
- two_stages/model.safetensors +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
examples/gt.png filter=lfs diff=lfs merge=lfs -text
|
STHN_demo.py
ADDED
|
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
STHN Demo: Satellite-Thermal Homography Network
|
| 3 |
+
Supports uploading to / loading from HuggingFace Hub.
|
| 4 |
+
Input: 1 RGB satellite image + 1 thermal image
|
| 5 |
+
Output: 4-point displacement + visualization
|
| 6 |
+
"""
|
| 7 |
+
import sys
|
| 8 |
+
import os
|
| 9 |
+
import argparse
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
import kornia.geometry.transform as tgm
|
| 16 |
+
import kornia.geometry.bbox as bbox_utils
|
| 17 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 18 |
+
from PIL import Image
|
| 19 |
+
import torchvision.transforms as transforms
|
| 20 |
+
import matplotlib.pyplot as plt
|
| 21 |
+
import cv2
|
| 22 |
+
|
| 23 |
+
# Import model building blocks from local_pipeline
|
| 24 |
+
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'local_pipeline'))
|
| 25 |
+
from extractor import BasicEncoderQuarter
|
| 26 |
+
from corr import CorrBlock
|
| 27 |
+
from update import CNN_64
|
| 28 |
+
from utils import coords_grid
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ==============================================================================
|
| 32 |
+
# Model Components (redefined without args dependency for HuggingFace)
|
| 33 |
+
# ==============================================================================
|
| 34 |
+
|
| 35 |
+
class GMA(nn.Module):
    """Update block that predicts a delta 4-point displacement from correlation and flow.

    Redefined from local_pipeline/update.py to remove args dependency.

    Args:
        corr_level: Number of correlation pyramid levels. Each level contributes
            81 channels (a (2*4+1)^2 lookup window for radius 4), so the CNN input
            width is ``corr_level * 81 + 2`` (the +2 is the flow's x/y channels).
        sz: Spatial size of the 1/4-resolution feature map; only 64 is supported.

    Raises:
        NotImplementedError: If ``sz`` is not 64, or ``corr_level`` is not one of
            the configurations the pretrained checkpoints use (2, 4, 6).
    """

    def __init__(self, corr_level, sz):
        super().__init__()
        if sz != 64:
            raise NotImplementedError(f"GMA with sz={sz} not supported in this demo")
        if corr_level not in (2, 4, 6):
            raise NotImplementedError(f"corr_level={corr_level} not supported")
        # One formula instead of a per-level constant table:
        # 164 / 326 / 488 for corr_level 2 / 4 / 6.
        init_dim = corr_level * 81 + 2
        self.cnn = CNN_64(128, init_dim=init_dim)

    def forward(self, corr, flow):
        """Concatenate correlation volume and flow along channels; return the delta 4-point."""
        return self.cnn(torch.cat((corr, flow), dim=1))
+
|
| 58 |
+
class IHN(nn.Module):
    """Iterative Homography Network.

    Redefined from local_pipeline/model/network.py to remove args dependency.
    State dict keys are compatible with original checkpoints (after stripping 'module.').

    Iteratively refines a 4-point homography displacement between two images by
    looking up a correlation volume at the current warped coordinates and
    predicting a delta displacement each iteration.
    """

    def __init__(self, resize_width, corr_level):
        super().__init__()
        self.resize_width = resize_width
        # Shared feature encoder producing 1/4-resolution, 256-channel features.
        self.fnet1 = BasicEncoderQuarter(output_dim=256, norm_fn='instance')
        sz = resize_width // 4
        self.update_block_4 = GMA(corr_level, sz)
        # ImageNet normalization constants; created lazily on the input's device
        # in forward() so the module can be built before a device is chosen.
        self.imagenet_mean = None
        self.imagenet_std = None

    def get_flow_now_4(self, four_point):
        """Convert a 4-point displacement (image scale) to a dense flow field at 1/4 scale.

        Builds the perspective transform induced by moving the feature map's four
        corners by ``four_point``, then applies it to a regular grid of feature
        coordinates. Relies on ``self.sz`` being set by forward() beforehand.

        Raises:
            KeyError: If the transformed coordinates contain NaNs (degenerate H).
        """
        four_point = four_point / 4  # displacements are at image scale; features are 1/4
        # Corner coordinates of the feature map: [x/y, top/bottom, left/right].
        four_point_org = torch.zeros((2, 2, 2)).to(four_point.device)
        four_point_org[:, 0, 0] = torch.Tensor([0, 0])
        four_point_org[:, 0, 1] = torch.Tensor([self.sz[3] - 1, 0])
        four_point_org[:, 1, 0] = torch.Tensor([0, self.sz[2] - 1])
        four_point_org[:, 1, 1] = torch.Tensor([self.sz[3] - 1, self.sz[2] - 1])

        four_point_org = four_point_org.unsqueeze(0).repeat(self.sz[0], 1, 1, 1)
        four_point_new = four_point_org + four_point
        # Flatten corner tensors to [B, 4, 2] point lists for kornia.
        four_point_org = four_point_org.flatten(2).permute(0, 2, 1).contiguous()
        four_point_new = four_point_new.flatten(2).permute(0, 2, 1).contiguous()
        H = tgm.get_perspective_transform(four_point_org, four_point_new)

        # Regular pixel grid in homogeneous coordinates [x; y; 1], shape [B, 3, HW].
        gridy, gridx = torch.meshgrid(
            torch.linspace(0, self.resize_width // 4 - 1, steps=self.resize_width // 4),
            torch.linspace(0, self.resize_width // 4 - 1, steps=self.resize_width // 4))
        points = torch.cat(
            (gridx.flatten().unsqueeze(0), gridy.flatten().unsqueeze(0),
             torch.ones((1, self.resize_width // 4 * self.resize_width // 4))),
            dim=0).unsqueeze(0).repeat(H.shape[0], 1, 1).to(four_point.device)
        points_new = H.bmm(points)
        if torch.isnan(points_new).any():
            # NOTE(review): KeyError is an odd exception type for this condition;
            # forward() relies on its broad `except Exception` to catch it and
            # roll back to the previous iteration's estimate.
            raise KeyError("Some of transformed coords are NaN!")
        points_new = points_new / points_new[:, 2, :].unsqueeze(1)  # homogeneous divide
        points_new = points_new[:, 0:2, :]
        # NOTE(review): reshape uses (sz[3], sz[2]) ordering — benign here because
        # the feature map is square, but worth confirming for non-square inputs.
        flow = torch.cat(
            (points_new[:, 0, :].reshape(self.sz[0], self.sz[3], self.sz[2]).unsqueeze(1),
             points_new[:, 1, :].reshape(self.sz[0], self.sz[3], self.sz[2]).unsqueeze(1)),
            dim=1)
        return flow

    def forward(self, image1, image2, iters_lev0=6, corr_level=2, corr_radius=4):
        """Estimate the 4-point displacement aligning image2 to image1.

        Args:
            image1: [B, 3, H, W] reference image, values in [0, 1].
            image2: [B, 3, H, W] moving image, values in [0, 1].
            iters_lev0: Number of refinement iterations.
            corr_level: Correlation pyramid levels (must match the GMA config).
            corr_radius: Correlation lookup radius.

        Returns:
            Tuple of (list of per-iteration [B, 2, 2, 2] displacements,
            final [B, 2, 2, 2] displacement).
        """
        # Lazily build ImageNet normalization tensors on the input's device.
        if self.imagenet_mean is None:
            self.imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(image1.device)
            self.imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(image1.device)
        image1 = (image1.contiguous() - self.imagenet_mean) / self.imagenet_std
        image2 = (image2.contiguous() - self.imagenet_mean) / self.imagenet_std

        fmap1 = self.fnet1(image1).float()
        fmap2 = self.fnet1(image2).float()

        corr_fn = CorrBlock(fmap1, fmap2, num_levels=corr_level, radius=corr_radius)

        # Identity coordinate grids at 1/4 resolution; coords1 is warped each iteration.
        N, C, H, W = image1.shape
        coords0 = coords_grid(N, H // 4, W // 4).to(image1.device)
        coords1 = coords_grid(N, H // 4, W // 4).to(image1.device)

        sz = fmap1.shape
        self.sz = sz  # consumed by get_flow_now_4()
        four_point_disp = torch.zeros((sz[0], 2, 2, 2)).to(fmap1.device)
        four_point_predictions = []

        for itr in range(iters_lev0):
            corr = corr_fn(coords1)
            flow = coords1 - coords0
            delta_four_point = self.update_block_4(corr, flow)
            try:
                last_four_point_disp = four_point_disp
                four_point_disp = four_point_disp + delta_four_point
                coords1 = self.get_flow_now_4(four_point_disp)
                four_point_predictions.append(four_point_disp)
            except Exception:
                # A degenerate update (NaN homography) raised inside
                # get_flow_now_4; keep the previous estimate for this iteration.
                four_point_disp = last_four_point_disp
                coords1 = self.get_flow_now_4(four_point_disp)
                four_point_predictions.append(four_point_disp)

        return four_point_predictions, four_point_disp
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# ==============================================================================
|
| 143 |
+
# STHN HuggingFace Model
|
| 144 |
+
# ==============================================================================
|
| 145 |
+
|
| 146 |
+
class STHN(nn.Module, PyTorchModelHubMixin):
    """
    Satellite-Thermal Homography Network with HuggingFace Hub support.

    Estimates 4-point homography displacement between a satellite RGB image
    and a thermal image for UAV geo-localization.

    Optionally runs a second, fine-stage IHN on a crop of the full-resolution
    satellite image around the coarse prediction ("two_stages" mode).
    """

    def __init__(self, model_config):
        """Build the network from a plain-dict config (see ONE/TWO_STAGE_CONFIG)."""
        super().__init__()
        self.model_config = model_config

        self.resize_width = model_config.get('resize_width', 256)    # working resolution
        self.database_size = model_config.get('database_size', 1536) # full satellite resolution
        self.corr_level = model_config.get('corr_level', 4)
        self.two_stages = model_config.get('two_stages', False)
        self.iters_lev0 = model_config.get('iters_lev0', 6)
        self.iters_lev1 = model_config.get('iters_lev1', 6)
        self.fine_padding = model_config.get('fine_padding', 0)

        self.netG = IHN(self.resize_width, self.corr_level)
        if self.two_stages:
            # Fine stage uses corr_level=2 (matches IHN.forward's default).
            self.netG_fine = IHN(self.resize_width, 2)

    def forward(self, satellite_image, thermal_image):
        """
        Args:
            satellite_image: [B, 3, database_size, database_size] RGB satellite (values in [0, 1])
            thermal_image: [B, 3, resize_width, resize_width] 3-channel thermal (values in [0, 1])
        Returns:
            four_pred: [B, 2, 2, 2] predicted 4-point displacement at resize_width scale
                       Shape meaning: [batch, x/y, top/bottom, left/right]
        """
        # Coarse stage operates on the satellite image downsampled to resize_width.
        image_1 = F.interpolate(satellite_image, size=self.resize_width,
                                mode='bilinear', align_corners=True, antialias=True)
        image_2 = thermal_image

        _, four_pred = self.netG(
            image1=image_1, image2=image_2,
            iters_lev0=self.iters_lev0, corr_level=self.corr_level)

        if self.two_stages:
            # Re-crop the full-resolution satellite around the coarse estimate,
            # refine on the crop, then map the fine prediction back.
            image_1_crop, delta, flow_bbox = self._crop_for_refinement(
                satellite_image, four_pred)
            _, four_pred_fine = self.netG_fine(
                image1=image_1_crop, image2=image_2,
                iters_lev0=self.iters_lev1)
            four_pred = self._combine_coarse_fine(four_pred_fine, delta, flow_bbox)

        return four_pred

    def _get_four_point_org(self, size, device):
        """Return the corners of a size x size frame as [1, 2(x/y), 2(top/bot), 2(l/r)]."""
        fp = torch.zeros((1, 2, 2, 2), device=device)
        fp[0, :, 0, 0] = torch.tensor([0.0, 0.0])
        fp[0, :, 0, 1] = torch.tensor([size - 1.0, 0.0])
        fp[0, :, 1, 0] = torch.tensor([0.0, size - 1.0])
        fp[0, :, 1, 1] = torch.tensor([size - 1.0, size - 1.0])
        return fp

    def _crop_for_refinement(self, image_1_ori, four_pred):
        """Crop a square region of the full-resolution satellite around the coarse prediction.

        Args:
            image_1_ori: [B, 3, database_size, database_size] satellite image.
            four_pred: [B, 2, 2, 2] coarse displacement at resize_width scale.

        Returns:
            (image_1_crop, delta, flow_bbox), all detached:
              image_1_crop: [B, 3, resize_width, resize_width] crop resized for the fine net.
              delta: crop-width / resize_width scale factor, shaped for broadcasting.
              flow_bbox: crop bbox corners minus the full-frame corners (database scale).
        """
        device = four_pred.device
        rw = self.resize_width
        ds = self.database_size
        alpha = ds / rw  # resize_width scale -> database scale

        four_point_org = self._get_four_point_org(rw, device)
        four_point = four_pred + four_point_org  # absolute predicted corners

        x = four_point[:, 0].clone()
        y = four_point[:, 1].clone()

        # Scale corner coords to database resolution. The +1 on the right/bottom
        # corners makes the scaled extent cover the full pixel span — TODO confirm
        # against the training-time crop convention.
        x[:, :, 0] = x[:, :, 0] * alpha
        x[:, :, 1] = (x[:, :, 1] + 1) * alpha
        y[:, 0, :] = y[:, 0, :] * alpha
        y[:, 1, :] = (y[:, 1, :] + 1) * alpha

        # Axis-aligned bounding box of the predicted quadrilateral.
        left = torch.min(x.view(x.shape[0], -1), dim=1)[0]
        right = torch.max(x.view(x.shape[0], -1), dim=1)[0]
        top = torch.min(y.view(y.shape[0], -1), dim=1)[0]
        bottom = torch.max(y.view(y.shape[0], -1), dim=1)[0]

        # Make the crop square (side = larger bbox extent), centered on the bbox.
        w = torch.max(torch.stack([right - left, bottom - top], dim=1), dim=1)[0]
        c = torch.stack([(left + right) / 2, (bottom + top) / 2], dim=1)

        w_padded = w + 2 * self.fine_padding
        crop_top_left = c + torch.stack([-w_padded / 2, -w_padded / 2], dim=1)
        x_start = crop_top_left[:, 0]
        y_start = crop_top_left[:, 1]

        bbox_s = bbox_utils.bbox_generator(x_start, y_start, w_padded, w_padded)
        # Scale factor between the crop and the fine network's input resolution.
        delta = (w_padded / rw).unsqueeze(1).unsqueeze(1).unsqueeze(1)
        image_1_crop = tgm.crop_and_resize(image_1_ori, bbox_s, (rw, rw))

        # Reorder bbox corners (TL, TR, BR, BL -> TL, TR, BL, BR) so the reshape
        # below matches the [x/y, top/bottom, left/right] corner layout.
        bbox_s_swap = torch.stack(
            [bbox_s[:, 0], bbox_s[:, 1], bbox_s[:, 3], bbox_s[:, 2]], dim=1)
        four_cor_bbox = bbox_s_swap.permute(0, 2, 1).view(-1, 2, 2, 2)
        four_point_org_large = self._get_four_point_org(ds, device)
        flow_bbox = four_cor_bbox - four_point_org_large

        # Detach: the fine stage should not backprop through the crop geometry.
        return image_1_crop.detach(), delta.detach(), flow_bbox.detach()

    def _combine_coarse_fine(self, four_pred_fine, delta, flow_bbox):
        """Map the fine-stage displacement back to resize_width scale.

        Rescales the fine prediction by the crop's relative size (kappa) and adds
        the crop's offset within the full frame, converted to resize_width scale.
        """
        alpha = self.database_size / self.resize_width
        kappa = delta / alpha
        return four_pred_fine * kappa + flow_bbox / alpha

    @classmethod
    def from_local_checkpoint(cls, checkpoint_path, model_config):
        """Load model from a local training checkpoint (.pth file).

        The checkpoint is expected to hold 'netG' (and optionally 'netG_fine')
        state dicts, possibly with DataParallel 'module.' prefixes.

        SECURITY: weights_only=False unpickles arbitrary objects — only load
        checkpoints from trusted sources.
        """
        model = cls(model_config)
        ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

        # Strip DataParallel's 'module.' prefix so keys match this module tree.
        netG_state = {k.replace('module.', ''): v for k, v in ckpt['netG'].items()}
        model.netG.load_state_dict(netG_state, strict=True)

        if model_config.get('two_stages', False) and ckpt.get('netG_fine') is not None:
            netG_fine_state = {k.replace('module.', ''): v
                               for k, v in ckpt['netG_fine'].items()}
            model.netG_fine.load_state_dict(netG_fine_state, strict=True)

        return model
| 266 |
+
|
| 267 |
+
|
| 268 |
+
# ==============================================================================
|
| 269 |
+
# Preprocessing & Visualization
|
| 270 |
+
# ==============================================================================
|
| 271 |
+
|
| 272 |
+
def load_and_preprocess_satellite(image_path, database_size):
    """Read an RGB satellite image and return a [1, 3, database_size, database_size]
    float tensor with values in [0, 1]."""
    pipeline = transforms.Compose([
        transforms.Resize([database_size, database_size]),
        transforms.ToTensor(),
    ])
    rgb = Image.open(image_path).convert('RGB')
    batch = pipeline(rgb)
    return batch.unsqueeze(0)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def load_and_preprocess_thermal(image_path, resize_width):
    """Read a thermal image, replicate it to 3 channels, and return a
    [1, 3, resize_width, resize_width] float tensor with values in [0, 1]."""
    pipeline = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize([resize_width, resize_width]),
        transforms.ToTensor(),
    ])
    gray = Image.open(image_path).convert('L')
    batch = pipeline(gray)
    return batch.unsqueeze(0)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def visualize_result(satellite_image, thermal_image, four_pred, resize_width,
                     database_size, save_path='examples/STHN_result.png',
                     gt_image_path=None):
    """Render the alignment result as a multi-panel figure and print displacements.

    Panels: satellite | thermal | satellite with predicted bbox | thermal warped
    onto the satellite | (optional) ground-truth overlay image.

    Args:
        satellite_image: [B, 3, database_size, database_size] tensor in [0, 1].
        thermal_image: [B, 3, resize_width, resize_width] tensor in [0, 1].
        four_pred: [B, 2, 2, 2] displacement at resize_width scale.
        resize_width / database_size: model / full-map resolutions.
        save_path: where the figure PNG is written.
        gt_image_path: optional ground-truth overlay; shown only if the file exists.
    """
    alpha = database_size / resize_width  # resize_width scale -> database scale

    # Corners of the resize_width frame: [batch, x/y, top/bottom, left/right].
    four_point_org = torch.zeros((1, 2, 2, 2))
    four_point_org[:, :, 0, 0] = torch.tensor([0, 0])
    four_point_org[:, :, 0, 1] = torch.tensor([resize_width - 1, 0])
    four_point_org[:, :, 1, 0] = torch.tensor([0, resize_width - 1])
    four_point_org[:, :, 1, 1] = torch.tensor([resize_width - 1, resize_width - 1])

    # Absolute predicted corner positions.
    four_point_pred = four_pred.cpu() + four_point_org

    # Downsample the satellite to the display/model resolution and convert both
    # images to HWC uint8 arrays for OpenCV/matplotlib.
    sat_display = F.interpolate(satellite_image, size=resize_width,
                                mode='bilinear', align_corners=True, antialias=True)
    sat_np = (sat_display[0].permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
    thermal_np = (thermal_image[0].permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)

    # Corner order TL, TR, BR, BL as cv2.polylines expects for a closed quad.
    pred_pts = four_point_pred[0].numpy()
    pts = np.array([
        [pred_pts[0, 0, 0], pred_pts[1, 0, 0]],  # TL
        [pred_pts[0, 0, 1], pred_pts[1, 0, 1]],  # TR
        [pred_pts[0, 1, 1], pred_pts[1, 1, 1]],  # BR
        [pred_pts[0, 1, 0], pred_pts[1, 1, 0]],  # BL
    ], dtype=np.int32).reshape((-1, 1, 2))

    sat_with_bbox = sat_np.copy()
    cv2.polylines(sat_with_bbox, [pts], True, (0, 255, 0), 2)

    # Warp the thermal image into the satellite frame via the predicted homography.
    four_point_org_flat = four_point_org.flatten(2).permute(0, 2, 1).contiguous()
    four_point_pred_flat = four_point_pred.flatten(2).permute(0, 2, 1).contiguous()
    H = tgm.get_perspective_transform(four_point_org_flat, four_point_pred_flat)
    warped_thermal = tgm.warp_perspective(thermal_image.cpu(), H,
                                          (resize_width, resize_width))
    warped_np = (warped_thermal[0].permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8)

    # Determine layout based on whether ground truth is available
    has_gt = gt_image_path is not None and os.path.exists(gt_image_path)
    ncols = 5 if has_gt else 4
    fig, axes = plt.subplots(1, ncols, figsize=(5 * ncols, 5))

    axes[0].imshow(sat_np)
    axes[0].set_title('Satellite Image')
    axes[0].axis('off')

    axes[1].imshow(thermal_np, cmap='gray')
    axes[1].set_title('Thermal Image')
    axes[1].axis('off')

    axes[2].imshow(sat_with_bbox)
    axes[2].set_title('Predicted Alignment (green bbox)')
    axes[2].axis('off')

    axes[3].imshow(sat_np)
    axes[3].imshow(warped_np, alpha=0.5)
    axes[3].set_title('Overlay')
    axes[3].axis('off')

    if has_gt:
        gt_img = np.array(Image.open(gt_image_path).convert('RGB'))
        axes[4].imshow(gt_img)
        axes[4].set_title('Ground Truth')
        axes[4].axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"\nVisualization saved to {save_path}")

    # Report the raw displacement and its equivalent at full database resolution.
    disp = four_pred[0].cpu()
    disp_scaled = disp * alpha
    print(f"\n4-Point Displacement (pixels at {resize_width}x{resize_width} scale):")
    print(f"  Top-Left:     dx={disp[0, 0, 0]:.2f}, dy={disp[1, 0, 0]:.2f}")
    print(f"  Top-Right:    dx={disp[0, 0, 1]:.2f}, dy={disp[1, 0, 1]:.2f}")
    print(f"  Bottom-Left:  dx={disp[0, 1, 0]:.2f}, dy={disp[1, 1, 0]:.2f}")
    print(f"  Bottom-Right: dx={disp[0, 1, 1]:.2f}, dy={disp[1, 1, 1]:.2f}")
    print(f"\n4-Point Displacement (scaled to {database_size}x{database_size}):")
    print(f"  Top-Left:     dx={disp_scaled[0, 0, 0]:.2f}, dy={disp_scaled[1, 0, 0]:.2f}")
    print(f"  Top-Right:    dx={disp_scaled[0, 0, 1]:.2f}, dy={disp_scaled[1, 0, 1]:.2f}")
    print(f"  Bottom-Left:  dx={disp_scaled[0, 1, 0]:.2f}, dy={disp_scaled[1, 1, 0]:.2f}")
    print(f"  Bottom-Right: dx={disp_scaled[0, 1, 1]:.2f}, dy={disp_scaled[1, 1, 1]:.2f}")
| 371 |
+
|
| 372 |
+
|
| 373 |
+
# ==============================================================================
|
| 374 |
+
# Main
|
| 375 |
+
# ==============================================================================
|
| 376 |
+
|
| 377 |
+
# Default configurations for the published checkpoints. The two configs are
# identical except for the 'two_stages' flag, so the second is derived from the
# first instead of duplicating every key (keeps them from drifting apart).
ONE_STAGE_CONFIG = {
    'resize_width': 256,    # working resolution fed to the IHN networks
    'database_size': 1536,  # full satellite map resolution
    'corr_level': 4,        # correlation pyramid levels for the coarse stage
    'iters_lev0': 6,        # coarse refinement iterations
    'iters_lev1': 6,        # fine refinement iterations (used in two-stage mode)
    'two_stages': False,
    'fine_padding': 0,      # extra padding around the refinement crop
}

TWO_STAGE_CONFIG = {**ONE_STAGE_CONFIG, 'two_stages': True}
|
| 396 |
+
|
| 397 |
+
if __name__ == "__main__":
    # Command-line demo: load an STHN checkpoint, optionally push it to the Hub,
    # then run one satellite/thermal pair and save a visualization.
    parser = argparse.ArgumentParser(description='STHN Demo: Satellite-Thermal Homography Estimation')
    parser.add_argument('--satellite_image', type=str, default='examples/img1.png',
                        help='Path to satellite RGB image')
    parser.add_argument('--thermal_image', type=str, default='examples/img2.png',
                        help='Path to thermal image')
    parser.add_argument('--gt_image', type=str, default='examples/gt.png',
                        help='Path to ground truth overlay image')
    parser.add_argument('--two_stages', action='store_true',
                        help='Use two-stage model for higher accuracy')
    parser.add_argument('--save_path', type=str, default=None,
                        help='Output visualization path')
    parser.add_argument('--hf_model', type=str, default=None,
                        help='HuggingFace model name (e.g., arplaboratory/STHN_one_stage)')
    parser.add_argument('--local_checkpoint', type=str, default=None,
                        help='Path to local checkpoint (.pth)')
    parser.add_argument('--push_to_hub', type=str, default=None,
                        help='Upload model to HuggingFace Hub (e.g., arplaboratory/STHN_one_stage)')
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    config = TWO_STAGE_CONFIG if args.two_stages else ONE_STAGE_CONFIG
    # Default output filename depends on which model variant is run.
    if args.save_path is None:
        args.save_path = 'examples/STHN_result_two_stage.png' if args.two_stages else 'examples/STHN_result_one_stage.png'

    # ---- Load Model ----
    # Precedence: explicit local checkpoint > HuggingFace Hub > default local path.
    if args.local_checkpoint:
        print(f"Loading model from local checkpoint: {args.local_checkpoint}")
        model = STHN.from_local_checkpoint(args.local_checkpoint, config)
    elif args.hf_model:
        print(f"Loading model from HuggingFace Hub: {args.hf_model}")
        model = STHN.from_pretrained(args.hf_model)
    else:
        default_ckpt = '1536_one_stage/STHN.pth' if not args.two_stages else '1536_two_stages/STHN.pth'
        if os.path.exists(default_ckpt):
            print(f"Loading model from default checkpoint: {default_ckpt}")
            model = STHN.from_local_checkpoint(default_ckpt, config)
        else:
            print("No checkpoint found. Please specify --hf_model or --local_checkpoint")
            sys.exit(1)

    model = model.to(device)
    model.eval()

    # ---- Push to HuggingFace Hub ----
    # Runs after loading so the uploaded weights are exactly what inference uses.
    if args.push_to_hub:
        print(f"Pushing model to HuggingFace Hub: {args.push_to_hub}")
        model.push_to_hub(args.push_to_hub)
        print("Done!")

    # ---- Run Inference ----
    print(f"Running inference on {device}...")
    satellite = load_and_preprocess_satellite(
        args.satellite_image, config['database_size']).to(device)
    thermal = load_and_preprocess_thermal(
        args.thermal_image, config['resize_width']).to(device)

    with torch.no_grad():
        four_pred = model(satellite, thermal)

    visualize_result(satellite, thermal, four_pred,
                     config['resize_width'], config['database_size'],
                     args.save_path, gt_image_path=args.gt_image)
|
examples/gt.png
ADDED
|
Git LFS Details
|
examples/img1.png
ADDED
|
examples/img2.png
ADDED
|
one_stage/README.md
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
tags:
|
| 3 |
+
- model_hub_mixin
|
| 4 |
+
- pytorch_model_hub_mixin
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
|
| 8 |
+
- Code: [More Information Needed]
|
| 9 |
+
- Paper: [More Information Needed]
|
| 10 |
+
- Docs: [More Information Needed]
|
one_stage/config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_config": {
|
| 3 |
+
"corr_level": 4,
|
| 4 |
+
"database_size": 1536,
|
| 5 |
+
"fine_padding": 0,
|
| 6 |
+
"iters_lev0": 6,
|
| 7 |
+
"iters_lev1": 6,
|
| 8 |
+
"resize_width": 256,
|
| 9 |
+
"two_stages": false
|
| 10 |
+
}
|
| 11 |
+
}
|
one_stage/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:359345e2f46f6c65ee2ca23ffa1ff6ab73a8af5ae348862799d1843edbd068e2
|
| 3 |
+
size 6508704
|
two_stages/README.md
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
tags:
|
| 3 |
+
- model_hub_mixin
|
| 4 |
+
- pytorch_model_hub_mixin
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
|
| 8 |
+
- Code: [More Information Needed]
|
| 9 |
+
- Paper: [More Information Needed]
|
| 10 |
+
- Docs: [More Information Needed]
|
two_stages/config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_config": {
|
| 3 |
+
"corr_level": 4,
|
| 4 |
+
"database_size": 1536,
|
| 5 |
+
"fine_padding": 0,
|
| 6 |
+
"iters_lev0": 6,
|
| 7 |
+
"iters_lev1": 6,
|
| 8 |
+
"resize_width": 256,
|
| 9 |
+
"two_stages": true
|
| 10 |
+
}
|
| 11 |
+
}
|
two_stages/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:baefb3ae87be966349a3dab2f46b60aa7333489b2e2ce9edebfd06ac21711a47
|
| 3 |
+
size 12271248
|