File size: 2,307 Bytes
930428f
 
 
 
 
 
 
6e22f37
930428f
 
 
 
 
 
 
 
 
6e22f37
930428f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import kornia
import torch

from .configuration_disk import DiskConfig
from transformers.models.superpoint.modeling_superpoint import (
    SuperPointKeypointDescriptionOutput,
)
from transformers import PreTrainedModel


class DiskForKeypointDetection(PreTrainedModel):
    """Hugging Face model wrapper around ``kornia.feature.DISK``.

    ``forward`` runs kornia's DISK on a batch of images and packs its
    per-image, variable-length detections into zero-padded batch tensors. It
    returns them in a ``SuperPointKeypointDescriptionOutput``, so code that
    consumes SuperPoint outputs can consume DISK outputs the same way.
    """

    config_class = DiskConfig

    def __init__(self, config: DiskConfig):
        super().__init__(config)

        self.config = config
        # The kornia model is built with the configured descriptor size;
        # forward() sizes its descriptor buffer with the same value.
        self.model = kornia.feature.DISK(self.config.descriptor_decoder_dim)

    def forward(
        self, pixel_values: torch.Tensor
    ) -> SuperPointKeypointDescriptionOutput:
        """Detect and describe keypoints for a batch of images.

        Args:
            pixel_values: Image batch passed straight to kornia's DISK. The
                last two dimensions are read as (height, width); presumably
                ``(batch, channels, height, width)`` — confirm against the
                image processor.

        Returns:
            ``SuperPointKeypointDescriptionOutput`` where ``max_kpts`` is the
            largest keypoint count found in any image of the batch:

            - ``keypoints``: ``(batch, max_kpts, 2)``; column 0 divided by the
              image width and column 1 by the image height. Padding rows are 0.
            - ``scores``: ``(batch, max_kpts)`` detection scores, 0 in padding.
            - ``descriptors``: ``(batch, max_kpts, descriptor_decoder_dim)``,
              0 in padding.
            - ``mask``: ``(batch, max_kpts)`` int tensor, 1 where a real
              keypoint is stored and 0 for padding.
        """
        detections = self.model(
            pixel_values,
            n=self.config.max_num_keypoints,
            window_size=self.config.nms_window_size,
            score_threshold=self.config.detection_threshold,
            pad_if_not_divisible=self.config.pad_if_not_divisible,
        )

        batch_size = len(detections)
        device = pixel_values.device
        # Each image can yield a different number of keypoints, so all
        # outputs are padded to the largest count. `default=0` keeps an
        # empty batch from raising ValueError.
        max_num_keypoints = max(
            (detection.keypoints.shape[0] for detection in detections),
            default=0,
        )

        keypoints = torch.zeros(batch_size, max_num_keypoints, 2, device=device)
        descriptors = torch.zeros(
            batch_size,
            max_num_keypoints,
            self.config.descriptor_decoder_dim,
            device=device,
        )
        scores = torch.zeros(batch_size, max_num_keypoints, device=device)
        # Integer mask, matching the dtype SuperPoint uses for this output.
        mask = torch.zeros(
            batch_size, max_num_keypoints, device=device, dtype=torch.int
        )

        for i, detection in enumerate(detections):
            keypoints[i, : detection.keypoints.shape[0]] = detection.keypoints
            descriptors[i, : detection.descriptors.shape[0]] = detection.descriptors
            num_scores = detection.detection_scores.shape[0]
            scores[i, :num_scores] = detection.detection_scores
            mask[i, :num_scores] = 1

        # Convert pixel coordinates to relative ones: column 0 is divided by
        # the width and column 1 by the height (padding zeros stay zero).
        height, width = pixel_values.shape[-2], pixel_values.shape[-1]
        keypoints = keypoints / torch.tensor(
            [width, height], device=device, dtype=keypoints.dtype
        )

        return SuperPointKeypointDescriptionOutput(
            keypoints=keypoints,
            scores=scores,
            descriptors=descriptors,
            mask=mask,
        )