| import kornia | |
| import torch | |
| from .configuration_disk import DiskConfig | |
| from transformers.models.superpoint.modeling_superpoint import ( | |
| SuperPointKeypointDescriptionOutput, | |
| ) | |
| from transformers import PreTrainedModel | |
class DiskForKeypointDetection(PreTrainedModel):
    """Expose kornia's DISK keypoint detector/descriptor as a transformers model.

    Per-image DISK detections are padded to a common length and returned in a
    ``SuperPointKeypointDescriptionOutput``. A ``mask`` marks the real
    (non-padding) entries, so the output can be consumed like SuperPoint's.
    """

    config_class = DiskConfig

    def __init__(self, config: DiskConfig):
        # PreTrainedModel.__init__ already stores ``config`` on ``self.config``.
        super().__init__(config)
        self.model = kornia.feature.DISK(self.config.descriptor_decoder_dim)

    def forward(
        self, pixel_values: torch.Tensor
    ) -> SuperPointKeypointDescriptionOutput:
        """Detect and describe keypoints for a batch of images.

        Args:
            pixel_values: Image batch forwarded unchanged to kornia's DISK.
                Presumably ``(batch, channels, height, width)`` floats — TODO
                confirm against the image processor. The last two dims are
                read as height and width for keypoint normalization.

        Returns:
            SuperPointKeypointDescriptionOutput with, for ``batch`` images and
            ``K`` = the largest per-image keypoint count:
              - keypoints: ``(batch, K, 2)`` (x, y) divided by (width, height)
              - scores: ``(batch, K)`` DISK detection scores
              - descriptors: ``(batch, K, config.descriptor_decoder_dim)``
              - mask: ``(batch, K)`` int tensor, 1 for real keypoints and
                0 for padding
        """
        detections = self.model(
            pixel_values,
            n=self.config.max_num_keypoints,
            window_size=self.config.nms_window_size,
            score_threshold=self.config.detection_threshold,
            pad_if_not_divisible=self.config.pad_if_not_divisible,
        )

        batch_size = len(detections)
        device = pixel_values.device
        # Images can yield different numbers of keypoints; pad to the largest.
        # ``default=0`` keeps an empty batch from raising inside ``max``.
        max_num_keypoints = max(
            (detection.keypoints.shape[0] for detection in detections),
            default=0,
        )

        keypoints = torch.zeros(batch_size, max_num_keypoints, 2, device=device)
        descriptors = torch.zeros(
            batch_size,
            max_num_keypoints,
            self.config.descriptor_decoder_dim,
            device=device,
        )
        scores = torch.zeros(batch_size, max_num_keypoints, device=device)
        # Integer mask, matching the dtype SuperPoint uses for the same field.
        mask = torch.zeros(
            batch_size, max_num_keypoints, dtype=torch.int, device=device
        )

        for i, detection in enumerate(detections):
            keypoints[i, : detection.keypoints.shape[0]] = detection.keypoints
            descriptors[i, : detection.descriptors.shape[0]] = detection.descriptors
            num_scores = detection.detection_scores.shape[0]
            scores[i, :num_scores] = detection.detection_scores
            mask[i, :num_scores] = 1

        # Normalize (x, y) pixel coordinates to relative coordinates; padded
        # zeros stay zero.
        height, width = pixel_values.shape[-2], pixel_values.shape[-1]
        keypoints = keypoints / keypoints.new_tensor([width, height])

        return SuperPointKeypointDescriptionOutput(
            keypoints=keypoints,
            scores=scores,
            descriptors=descriptors,
            mask=mask,
        )