Spaces:
Sleeping
Sleeping
File size: 3,712 Bytes
eb8e8ab | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 | import torch
from torch import Tensor
from scipy.ndimage import gaussian_filter
from typing import Optional, List, Tuple
def get_id(x: str) -> int:
    """
    Extract the integer image id from an image file name.

    Only the part of ``x`` before its first "." is parsed. The token that is
    converted to ``int`` depends on which naming pattern ``x`` matches:

    * contains "_" (ShanghaiTech / QNRF style): the second "_"-separated token;
    * contains "img" but no "_": the same rule is applied, which raises
      ``IndexError`` because there is no second token;
    * contains "image-": the last "-"-separated token;
    * anything else (NWPU style): the whole part before the first ".".

    Raises ``ValueError`` when the selected token is not an integer literal.
    """
    if "_" in x or "img" in x:
        # shanghaitech / qnrf: the id is the second underscore-separated token
        separator, position = "_", 1
    elif "image-" in x:
        # oto: the id is whatever follows the last dash
        separator, position = "-", -1
    else:
        # nwpu: the stem itself is the id
        return int(x.split(".")[0])

    stem = x.split(".")[0]
    return int(stem.split(separator)[position])
def generate_density_map(label: Tensor, height: int, width: int, sigma: Optional[float] = None, num_classes: int = 1) -> Tensor:
    """
    Generate a per-class density map from dot annotations.

    Every annotated point sets a single pixel to 1.0 in its class channel.
    Points that fall on the same pixel of the same channel collapse into one
    1.0 (the map is binary before smoothing). Coordinates are converted to
    integers with ``Tensor.long()`` and clamped into the image; class ids are
    clamped into ``[0, num_classes - 1]``. The caller's ``label`` tensor is
    never modified.

    Args:
        label: 2-D tensor with one row per point and columns ``(x, y)`` or
            ``(x, y, class_id)``. The class column is honoured only when the
            tensor has exactly 3 columns and ``num_classes > 1``; in every
            other case all points are written to channel 0.
        height: Number of rows of the output map.
        width: Number of columns of the output map.
        sigma: If given, each channel is smoothed with
            ``scipy.ndimage.gaussian_filter`` using this standard deviation.
        num_classes: Number of output channels.

    Returns:
        A float32 tensor of shape ``(num_classes, height, width)``.

    Raises:
        AssertionError: If a non-empty ``label`` has fewer than 2 columns, or
            if ``sigma`` is given and is not positive.
    """
    density_map = torch.zeros((num_classes, height, width), dtype=torch.float32)

    if len(label) > 0:
        assert label.shape[1] >= 2, f"label should have at least 2 columns, got {label.shape}."
        points = label.long()

        # Clamp out of place: `.long()` hands back the very same tensor when
        # `label` is already int64, so an in-place clamp would silently
        # overwrite the caller's annotations.
        xs = points[:, 0].clamp(min=0, max=width - 1)
        ys = points[:, 1].clamp(min=0, max=height - 1)

        if points.shape[1] == 3 and num_classes > 1:
            # Out-of-range class ids are clamped rather than rejected.
            channels = points[:, 2].clamp(min=0, max=num_classes - 1)
        else:
            # Single class, or no usable class column: everything in channel 0.
            channels = torch.zeros_like(xs)

        # One advanced-indexing write marks every (channel, row, col) triple.
        density_map[channels, ys, xs] = 1.0

    if sigma is not None:
        assert sigma > 0, f"sigma should be positive if not None, got {sigma}."
        # Smooth each channel independently so classes do not bleed into each other.
        for c in range(num_classes):
            smoothed = gaussian_filter(density_map[c].numpy(), sigma=sigma)
            density_map[c] = torch.from_numpy(smoothed)

    return density_map
def collate_fn(batch: List[Tensor]) -> Tuple[Tensor, List[Tensor], Tensor]:
    """
    Merge a list of per-sample tuples into one batch.

    Each sample is a tuple whose first field is a 4-D image tensor; image
    tensors are concatenated along dimension 0. What is returned depends on
    how many fields each sample carries:

    * 4 fields (image, label, density_map, image_name):
      ``(images, points, densities, image_names)``
    * 3 fields (image, label, density_map): ``(images, points, densities)``
    * 2 fields (image, image_name; NWPU test dataset): ``(images, image_names)``
    * any other count: ``images`` only

    Labels and image names arrive as one list per sample and are flattened
    into a single list; density maps are concatenated along dimension 0.
    """

    def _flatten(groups):
        # Collapse a sequence of per-sample lists into one flat list.
        return [item for group in groups for item in group]

    fields = list(zip(*batch))
    image_group = fields[0]
    assert len(image_group[0].shape) == 4, f"images should be a 4D tensor, got {image_group[0].shape}."

    images = torch.cat(image_group, 0)
    num_fields = len(fields)

    if num_fields == 4:
        points = _flatten(fields[1])
        densities = torch.cat(fields[2], 0)
        image_names = _flatten(fields[3])
        return images, points, densities, image_names

    if num_fields == 3:
        points = _flatten(fields[1])
        densities = torch.cat(fields[2], 0)
        return images, points, densities

    if num_fields == 2:
        return images, _flatten(fields[1])

    return images
|