farrell236 committed on
Commit
c67f469
·
verified ·
1 Parent(s): c320b82

Upload heatmap_utils.py

Browse files
Files changed (1) hide show
  1. heatmap_utils.py +329 -0
heatmap_utils.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def _gaussian_blur_heatmaps(heatmaps: torch.Tensor, kernel: int = 11) -> torch.Tensor:
7
+ if kernel % 2 == 0:
8
+ raise ValueError("kernel must be odd")
9
+
10
+ sigma = kernel / 6.0
11
+ radius = kernel // 2
12
+
13
+ x = torch.arange(kernel, device=heatmaps.device, dtype=heatmaps.dtype) - radius
14
+ g = torch.exp(-(x ** 2) / (2 * sigma * sigma))
15
+ g = g / g.sum()
16
+
17
+ g_x = g.view(1, 1, 1, kernel)
18
+ g_y = g.view(1, 1, kernel, 1)
19
+
20
+ B, N, H, W = heatmaps.shape
21
+
22
+ # 🔥 FIX HERE
23
+ x_in = heatmaps.reshape(B * N, 1, H, W)
24
+
25
+ x_in = F.pad(x_in, (radius, radius, 0, 0), mode="reflect")
26
+ x_in = F.conv2d(x_in, g_x)
27
+
28
+ x_in = F.pad(x_in, (0, 0, radius, radius), mode="reflect")
29
+ x_in = F.conv2d(x_in, g_y)
30
+
31
+ return x_in.reshape(B, N, H, W)
32
+
33
+
34
def heatmaps_to_coords_dark(
    heatmaps: torch.Tensor,
    blur_kernel: int = 11,
    eps: float = 1e-10,
) -> torch.Tensor:
    """
    DARK-style decoding with second-order local refinement.

    The heatmaps are Gaussian-blurred, clamped to ``eps`` and log-transformed.
    The per-channel argmax gives a coarse integer location, which is then
    shifted by one Newton step ``-H^{-1} g`` computed from finite differences
    of the log-heatmap around the peak. The step is applied only when the
    peak is not on the border, the Hessian determinant is at least 1e-6 in
    magnitude, and both offset components are within 1.5 pixels.

    Args:
        heatmaps: [B, N, H, W] or [N, H, W]
        blur_kernel: Gaussian blur kernel before log
        eps: numerical stability for log

    Returns:
        coords: [B, N, 2] or [N, 2] in heatmap coordinates, (x, y) order,
        float32.

    Raises:
        ValueError: if ``heatmaps`` is neither 3-D nor 4-D.
    """
    squeeze_batch = False
    if heatmaps.ndim == 3:
        heatmaps = heatmaps.unsqueeze(0)
        squeeze_batch = True

    if heatmaps.ndim != 4:
        raise ValueError(f"Expected [B, N, H, W] or [N, H, W], got {heatmaps.shape}")

    B, N, H, W = heatmaps.shape

    # Blur then log, as in DARK-style refinement
    hm = _gaussian_blur_heatmaps(heatmaps, kernel=blur_kernel)
    hm = torch.clamp(hm, min=eps).log()

    # Coarse argmax
    flat = hm.reshape(B, N, -1)
    idx = flat.argmax(dim=-1)

    py = idx // W
    px = idx % W

    coords = torch.stack([px.float(), py.float()], dim=-1)

    # A 3x3 neighborhood is needed for the derivatives; maps smaller than
    # that have no interior pixel, so only the coarse argmax is returned.
    if H >= 3 and W >= 3:
        interior = (px >= 1) & (px <= W - 2) & (py >= 1) & (py <= H - 2)

        # Clamp border peaks inward so every gather stays in bounds; their
        # results are discarded through the `interior` mask below.
        cx = px.clamp(min=1, max=W - 2)
        cy = py.clamp(min=1, max=H - 2)

        def _sample(dy: int, dx: int) -> torch.Tensor:
            # Value of the log-heatmap at (cy + dy, cx + dx) for every [B, N].
            lin = (cy + dy) * W + (cx + dx)
            return flat.gather(-1, lin.unsqueeze(-1)).squeeze(-1)

        center = _sample(0, 0)
        right = _sample(0, 1)
        left = _sample(0, -1)
        down = _sample(1, 0)
        up = _sample(-1, 0)

        # Central first and second differences of the log-heatmap.
        dx = 0.5 * (right - left)
        dy = 0.5 * (down - up)
        dxx = right - 2 * center + left
        dyy = down - 2 * center + up
        dxy = 0.25 * (
            _sample(1, 1) - _sample(1, -1) - _sample(-1, 1) + _sample(-1, -1)
        )

        # Solve offset = -H^{-1} g in closed form for the symmetric 2x2
        # Hessian [[dxx, dxy], [dxy, dyy]].
        det = dxx * dyy - dxy * dxy
        solvable = interior & (det.abs() >= 1e-6)
        safe_det = torch.where(solvable, det, torch.ones_like(det))

        off_x = -(dyy * dx - dxy * dy) / safe_det
        off_y = -(dxx * dy - dxy * dx) / safe_det

        # Keep refinement bounded; if huge, it's unstable
        accept = solvable & (off_x.abs() <= 1.5) & (off_y.abs() <= 1.5)
        offset = torch.stack([off_x, off_y], dim=-1)
        offset = torch.where(accept.unsqueeze(-1), offset, torch.zeros_like(offset))

        coords = coords + offset.to(coords.dtype)

    if squeeze_batch:
        coords = coords[0]

    return coords
124
+
125
+
126
def heatmap_coords_to_image_coords(
    coords: torch.Tensor,
    image_size: tuple,
    heatmap_size: tuple,
) -> torch.Tensor:
    """
    Map coordinates from heatmap space back to image space.

    The input tensor is never modified; a scaled copy is returned.

    Args:
        coords: [B, N, 2] or [N, 2], last axis ordered (x, y)
        image_size: (H_img, W_img)
        heatmap_size: (H_hm, W_hm)

    Returns:
        Tensor with the same shape as ``coords``; x is scaled by
        ``W_img / W_hm`` and y by ``H_img / H_hm``. Integer inputs are
        promoted to float32 so fractional results are representable.
    """
    # NOTE(review): a function with this same name is defined again later in
    # this file; at import time the later definition replaces this one.
    H_img, W_img = image_size
    H_hm, W_hm = heatmap_size

    # In-place multiplication by a float fails on integer tensors, so work
    # on a floating-point copy in that case.
    out = coords.clone() if coords.is_floating_point() else coords.float()
    out[..., 0] *= (W_img / W_hm)
    out[..., 1] *= (H_img / H_hm)
    return out
146
+
147
+
148
def gaussian2d(size: int, sigma: float, device=None) -> torch.Tensor:
    """
    Build an unnormalized 2D Gaussian of shape [size, size].

    The peak value is 1 at the geometric center ``(size - 1) / 2`` and the
    result is float32 on ``device``.
    """
    # Signed distance of every pixel index from the kernel center.
    offsets = torch.arange(size, device=device, dtype=torch.float32) - (size - 1) / 2.0
    rows, cols = torch.meshgrid(offsets, offsets, indexing="ij")
    squared_distance = cols**2 + rows**2
    return torch.exp(-squared_distance / (2 * sigma * sigma))
159
+
160
+
161
def draw_gaussian(
    heatmap: torch.Tensor,
    center_x: float,
    center_y: float,
    sigma: float,
) -> torch.Tensor:
    """
    Draw a Gaussian on a single heatmap in-place.

    The center is rounded to the nearest pixel, a ``(2 * int(3 * sigma) + 1)``
    wide Gaussian patch is clipped to the heatmap bounds, and the patch is
    merged with the existing values by element-wise maximum.

    Args:
        heatmap: [H, W]
        center_x, center_y: landmark coordinates in heatmap space; Python
            numbers or single-element tensors are both accepted
        sigma: Gaussian sigma in heatmap pixels

    Returns:
        The same ``heatmap`` tensor (modified in place). It is returned
        unchanged when the rounded center lies outside the heatmap.
    """
    H, W = heatmap.shape
    radius = int(3 * sigma)
    size = 2 * radius + 1

    # The signature advertises floats, but callers in this file pass 0-d
    # tensors; support both instead of requiring `.item()`.
    cx = center_x.item() if isinstance(center_x, torch.Tensor) else center_x
    cy = center_y.item() if isinstance(center_y, torch.Tensor) else center_y

    mu_x = int(round(cx))
    mu_y = int(round(cy))

    # Extent of the patch on each side of the center after clipping.
    left = min(mu_x, radius)
    right = min(W - mu_x - 1, radius)
    top = min(mu_y, radius)
    bottom = min(H - mu_y - 1, radius)

    # NOTE(review): a center in the last half pixel (e.g. x = W - 0.3) rounds
    # to W and is therefore treated as out of bounds here — confirm intended.
    if left < 0 or right < 0 or top < 0 or bottom < 0:
        return heatmap

    g = gaussian2d(size=size, sigma=sigma, device=heatmap.device)

    # Matching windows in the Gaussian patch and in the heatmap.
    g_x0 = radius - left
    g_x1 = radius + right + 1
    g_y0 = radius - top
    g_y1 = radius + bottom + 1

    h_x0 = mu_x - left
    h_x1 = mu_x + right + 1
    h_y0 = mu_y - top
    h_y1 = mu_y + bottom + 1

    heatmap[h_y0:h_y1, h_x0:h_x1] = torch.maximum(
        heatmap[h_y0:h_y1, h_x0:h_x1],
        g[g_y0:g_y1, g_x0:g_x1],
    )
    return heatmap
207
+
208
+
209
def generate_heatmaps(
    landmarks: torch.Tensor,
    image_size: tuple,
    heatmap_size: tuple,
    sigma: float = 2.0,
) -> torch.Tensor:
    """
    Render one Gaussian target heatmap per landmark.

    Landmarks are rescaled from image to heatmap resolution; any landmark
    whose rescaled position falls outside the heatmap keeps an all-zero
    channel.

    Args:
        landmarks: [N, 2] tensor of (x, y) in original image coordinates
        image_size: (H_img, W_img)
        heatmap_size: (H_hm, W_hm)
        sigma: Gaussian sigma in heatmap pixels

    Returns:
        heatmaps: [N, H_hm, W_hm], float32, on the landmarks' device

    Raises:
        ValueError: if ``landmarks`` is not of shape [N, 2].
    """
    if landmarks.ndim != 2 or landmarks.shape[1] != 2:
        raise ValueError(f"Expected landmarks shape [N, 2], got {landmarks.shape}")

    H_img, W_img = image_size
    H_hm, W_hm = heatmap_size
    count = landmarks.shape[0]

    # Image -> heatmap scale factors.
    to_hm_x = W_hm / W_img
    to_hm_y = H_hm / H_img

    heatmaps = torch.zeros(
        (count, H_hm, W_hm), dtype=torch.float32, device=landmarks.device
    )

    for index in range(count):
        point = landmarks[index]
        x_hm = point[0] * to_hm_x
        y_hm = point[1] * to_hm_y

        # Skip landmarks that land outside the heatmap.
        if not (0 <= x_hm < W_hm and 0 <= y_hm < H_hm):
            continue

        # heatmaps[index] is a view, so the in-place draw updates `heatmaps`.
        draw_gaussian(heatmaps[index], x_hm, y_hm, sigma=sigma)

    return heatmaps
249
+
250
+
251
def generate_batch_heatmaps(
    landmarks_batch: torch.Tensor,
    image_size: tuple,
    heatmap_size: tuple,
    sigma: float = 2.0,
) -> torch.Tensor:
    """
    Batch version of ``generate_heatmaps``.

    Args:
        landmarks_batch: [B, N, 2]
        image_size: (H_img, W_img)
        heatmap_size: (H_hm, W_hm)
        sigma: Gaussian sigma in heatmap pixels

    Returns:
        heatmaps: [B, N, H_hm, W_hm]. An empty batch (B == 0) yields an
        empty float32 tensor of shape [0, N, H_hm, W_hm].

    Raises:
        ValueError: if ``landmarks_batch`` is not of shape [B, N, 2].
    """
    if landmarks_batch.ndim != 3 or landmarks_batch.shape[-1] != 2:
        raise ValueError(f"Expected [B, N, 2], got {landmarks_batch.shape}")

    batch_size, num_landmarks = landmarks_batch.shape[:2]
    if batch_size == 0:
        # torch.stack rejects an empty list, so build the empty result
        # directly with the dtype/device generate_heatmaps would use.
        H_hm, W_hm = heatmap_size
        return torch.zeros(
            (0, num_landmarks, H_hm, W_hm),
            dtype=torch.float32,
            device=landmarks_batch.device,
        )

    per_sample = [
        generate_heatmaps(
            landmarks=landmarks_batch[b],
            image_size=image_size,
            heatmap_size=heatmap_size,
            sigma=sigma,
        )
        for b in range(batch_size)
    ]
    return torch.stack(per_sample, dim=0)
281
+
282
+
283
def heatmaps_to_coords_argmax(heatmaps: torch.Tensor) -> torch.Tensor:
    """
    Decode coordinates from heatmaps using argmax.

    Args:
        heatmaps: [B, N, H, W] or [N, H, W]; may be non-contiguous

    Returns:
        coords: [B, N, 2] or [N, 2] in heatmap coordinates, (x, y) order,
        float32

    Raises:
        ValueError: if ``heatmaps`` is neither 3-D nor 4-D.
    """
    squeeze_batch = False
    if heatmaps.ndim == 3:
        heatmaps = heatmaps.unsqueeze(0)
        squeeze_batch = True

    if heatmaps.ndim != 4:
        raise ValueError(f"Expected [B, N, H, W] or [N, H, W], got {heatmaps.shape}")

    B, N, H, W = heatmaps.shape
    # reshape (unlike view) also works for non-contiguous inputs such as
    # transposed or sliced heatmaps.
    flat = heatmaps.reshape(B, N, -1)
    idx = flat.argmax(dim=-1)

    # Row-major flat index -> (row, column).
    y = idx // W
    x = idx % W

    coords = torch.stack([x.float(), y.float()], dim=-1)

    if squeeze_batch:
        coords = coords[0]
    return coords
310
+
311
+
312
def heatmap_coords_to_image_coords(
    coords: torch.Tensor,
    image_size: tuple,
    heatmap_size: tuple,
) -> torch.Tensor:
    """
    Map coordinates from heatmap space back to image space.

    The input tensor is never modified; a scaled copy is returned.

    Args:
        coords: [B, N, 2] or [N, 2], last axis ordered (x, y)
        image_size: (H_img, W_img)
        heatmap_size: (H_hm, W_hm)

    Returns:
        Tensor with the same shape as ``coords``; x is scaled by
        ``W_img / W_hm`` and y by ``H_img / H_hm``. Integer inputs are
        promoted to float32 so fractional results are not truncated.
    """
    H_img, W_img = image_size
    H_hm, W_hm = heatmap_size

    scale_x = W_img / W_hm
    scale_y = H_img / H_hm

    # Assigning float results into an integer tensor would silently drop
    # the fractional part, so work on a floating-point copy in that case.
    out = coords.clone() if coords.is_floating_point() else coords.float()
    out[..., 0] = out[..., 0] * scale_x
    out[..., 1] = out[..., 1] * scale_y
    return out