# disk / modeling_disk.py
# Uploaded by stevenbucaille
# Commit message: "Upload DiskForKeypointDetection"
# Commit 6e22f37 (verified)
import kornia
import torch
from .configuration_disk import DiskConfig
from transformers.models.superpoint.modeling_superpoint import (
SuperPointKeypointDescriptionOutput,
)
from transformers import PreTrainedModel
class DiskForKeypointDetection(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around kornia's DISK detector.

    kornia's ``DISK`` returns one variable-length set of features per image.
    This wrapper zero-pads them into dense batch tensors and returns them in
    transformers' ``SuperPointKeypointDescriptionOutput`` container.
    """

    config_class = DiskConfig

    def __init__(self, config: DiskConfig):
        super().__init__(config)
        self.config = config
        # The first positional argument of kornia's DISK is the descriptor
        # dimension; it also sizes the padded descriptor tensor in forward().
        self.model = kornia.feature.DISK(self.config.descriptor_decoder_dim)

    def forward(
        self, pixel_values: torch.Tensor
    ) -> SuperPointKeypointDescriptionOutput:
        """Detect and describe keypoints for a batch of images.

        Args:
            pixel_values: Image batch forwarded unchanged to kornia's DISK. The
                last two dimensions are read as (height, width); the full
                layout is presumably (batch, channels, height, width) — TODO
                confirm against the image processor.

        Returns:
            ``SuperPointKeypointDescriptionOutput`` whose tensors are padded
            to the largest keypoint count in the batch (padding is zero):

            - ``keypoints``: (batch, max_keypoints, 2); column 0 is divided by
              the image width and column 1 by the image height.
            - ``scores``: (batch, max_keypoints) detection scores.
            - ``descriptors``: (batch, max_keypoints, descriptor_decoder_dim).
            - ``mask``: (batch, max_keypoints) ``torch.int``; 1 marks a real
              keypoint and 0 marks padding.
        """
        detections = self.model(
            pixel_values,
            n=self.config.max_num_keypoints,
            window_size=self.config.nms_window_size,
            score_threshold=self.config.detection_threshold,
            pad_if_not_divisible=self.config.pad_if_not_divisible,
        )

        batch_size = len(detections)
        device = pixel_values.device
        # default=0 keeps an empty batch from raising ValueError in max().
        max_num_keypoints = max(
            (detection.keypoints.shape[0] for detection in detections),
            default=0,
        )

        keypoints = torch.zeros(batch_size, max_num_keypoints, 2, device=device)
        descriptors = torch.zeros(
            batch_size,
            max_num_keypoints,
            self.config.descriptor_decoder_dim,
            device=device,
        )
        scores = torch.zeros(batch_size, max_num_keypoints, device=device)
        # Integer mask, matching the dtype transformers' SuperPoint model uses
        # for the mask of this same output class (it was float here before).
        mask = torch.zeros(
            batch_size, max_num_keypoints, dtype=torch.int, device=device
        )

        # Copy each image's features into the front of its padded row.
        for index, detection in enumerate(detections):
            num_keypoints = detection.keypoints.shape[0]
            keypoints[index, :num_keypoints] = detection.keypoints
            descriptors[index, :num_keypoints] = detection.descriptors
            scores[index, :num_keypoints] = detection.detection_scores
            mask[index, :num_keypoints] = 1

        # Normalize (x, y) pixel coordinates by (width, height) in one
        # broadcast division.
        height, width = pixel_values.shape[-2], pixel_values.shape[-1]
        keypoints = keypoints / torch.tensor(
            [width, height], dtype=keypoints.dtype, device=device
        )

        return SuperPointKeypointDescriptionOutput(
            keypoints=keypoints,
            scores=scores,
            descriptors=descriptors,
            mask=mask,
        )