File size: 3,246 Bytes
071150e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
https://github.com/rpautrat/SuperPointPretrainedNetwork

"""
import argparse
import os

os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import torch
import cv2
import numpy as np
import requests
from PIL import Image
from transformers import AutoImageProcessor
from transformers.models.superpoint.modeling_superpoint import SuperPointForKeypointDetection, SuperPointKeypointDescriptionOutput
from project_settings import project_path, temp_directory


def get_args():
    """Parse the command-line options of this script.

    Options (all strings):
        --model_name: model identifier, defaults to
            "magic-leap-community/superpoint".
        --model_cache_dir: cache directory for the downloaded model, defaults
            to "../../hf_hub_models" relative to ``project_path``.
        --image_path: image to process, defaults to a sample keyboard image
            under ``project_path``.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    fallback_cache_dir = (project_path / "../../hf_hub_models").as_posix()
    fallback_image = (
        project_path / "data/images/keyboard/g98-v2-pink/model/keyboard1.png"
    ).as_posix()

    # (flag, default) pairs; every option is a plain string.
    option_table = (
        ("--model_name", "magic-leap-community/superpoint"),
        ("--model_cache_dir", fallback_cache_dir),
        ("--image_path", fallback_image),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, fallback in option_table:
        arg_parser.add_argument(flag, default=fallback, type=str)
    return arg_parser.parse_args()


def show_image(image):
    """Display ``image`` in an OpenCV window titled "image".

    Blocks until a key is pressed, then destroys all OpenCV windows.
    """
    window_title = "image"
    cv2.imshow(window_title, image)
    cv2.waitKey(0)  # 0 = wait indefinitely for a key press
    cv2.destroyAllWindows()


def main():
    """Detect SuperPoint keypoints on one image and display them.

    Loads the image processor and the SuperPoint model selected by the
    command-line arguments, runs the model on ``--image_path``, converts the
    raw output to pixel coordinates with the processor's post-processing,
    prints the number of keypoints and the descriptor shape, and finally shows
    the image with the keypoints drawn on it (blocks until a key is pressed).
    """
    args = get_args()

    processor = AutoImageProcessor.from_pretrained(
        pretrained_model_name_or_path=args.model_name,
        cache_dir=args.model_cache_dir,
    )
    model = SuperPointForKeypointDetection.from_pretrained(
        pretrained_model_name_or_path=args.model_name,
        cache_dir=args.model_cache_dir,
    )

    # Close the file handle as soon as the pixels are decoded; convert()
    # returns a separate in-memory image that outlives the `with` block.
    with Image.open(args.image_path) as image_file:
        image = image_file.convert("RGB")

    inputs = processor(image, return_tensors="pt")
    # Inference only: do not record an autograd graph for the forward pass.
    with torch.no_grad():
        output: SuperPointKeypointDescriptionOutput = model(**inputs)

    # Use the processor's post-processing to turn relative coordinates into
    # pixel coordinates. Target sizes are given as (height, width).
    image_size = (image.height, image.width)
    processed = processor.post_process_keypoint_detection(
        output,
        [image_size],
    )
    # `processed` is a list with one entry per batch item; there is a single
    # image here.
    result = processed[0]
    keypoints = result["keypoints"].detach().cpu().numpy()  # [N, 2], (x, y) in pixels
    scores = result["scores"].detach().cpu().numpy()        # [N]
    descriptors = result["descriptors"]                     # [N, D]

    print(f"检测到关键点数量: {keypoints.shape[0]}")
    print(f"描述符维度: {descriptors.shape}")

    # Draw the keypoints with OpenCV: PIL image -> numpy (RGB) -> BGR.
    image_np = np.array(image)
    image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

    # Convert the SuperPoint keypoints into a list of OpenCV KeyPoint objects.
    # x, y are pixel coordinates and the score is stored as the response.
    # No angle is passed: with DRAW_RICH_KEYPOINTS OpenCV only draws the
    # orientation ("radius") line when angle != -1.
    cv2_keypoints = [
        cv2.KeyPoint(
            x=float(x),
            y=float(y),
            size=7,
            response=float(score),
        )
        for (x, y), score in zip(keypoints, scores)
    ]

    image_with_kp = cv2.drawKeypoints(
        image_bgr,
        cv2_keypoints,
        None,
        color=(0, 0, 255),
        flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS,
    )

    show_image(image_with_kp)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()