"""
Detect SuperPoint keypoints on a single image and display them with OpenCV.

Reference implementation of the network:
https://github.com/rpautrat/SuperPointPretrainedNetwork
"""
import argparse
import os

# NOTE(review): assigned before `transformers` is imported, presumably so the
# Hugging Face hub client picks up the mirror endpoint - confirm.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import torch
import cv2
import numpy as np
import requests
from PIL import Image
from transformers import AutoImageProcessor
from transformers.models.superpoint.modeling_superpoint import SuperPointForKeypointDetection, SuperPointKeypointDescriptionOutput

from project_settings import project_path, temp_directory


def get_args():
    """Parse the command-line arguments.

    Returns:
        argparse.Namespace with ``model_name``, ``model_cache_dir`` and
        ``image_path`` (all strings).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        default="magic-leap-community/superpoint",
        type=str
    )
    parser.add_argument(
        "--model_cache_dir",
        default=(project_path / "../../hf_hub_models").as_posix(),
        type=str
    )
    parser.add_argument(
        "--image_path",
        default=(project_path / "data/images/keyboard/g98-v2-pink/model/keyboard1.png").as_posix(),
        type=str
    )
    args = parser.parse_args()
    return args


def show_image(image):
    """Show ``image`` in a window named "image" and block until a key is pressed.

    Args:
        image: Image array accepted by ``cv2.imshow``.
    """
    cv2.imshow("image", image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def main():
    """Run SuperPoint on one image, print keypoint statistics and display the result."""
    args = get_args()

    processor = AutoImageProcessor.from_pretrained(
        pretrained_model_name_or_path=args.model_name,
        cache_dir=args.model_cache_dir,
    )
    model = SuperPointForKeypointDetection.from_pretrained(
        pretrained_model_name_or_path=args.model_name,
        cache_dir=args.model_cache_dir,
    )

    image = Image.open(args.image_path).convert("RGB")

    inputs = processor(image, return_tensors="pt")
    # Inference only: disable autograd so no computation graph is built.
    with torch.no_grad():
        output: SuperPointKeypointDescriptionOutput = model(**inputs)

    # Use the processor's post-processing to convert relative coordinates
    # into pixel coordinates. The size is passed as (height, width).
    image_size = (image.height, image.width)
    processed = processor.post_process_keypoint_detection(
        output,
        [image_size],
    )

    # `processed` is a list of length batch_size; there is only one image here.
    keypoints = processed[0]["keypoints"]      # [N, 2], (x, y) in pixel coordinates
    scores = processed[0]["scores"]            # [N]
    descriptors = processed[0]["descriptors"]  # [N, D]

    scores = scores.detach().cpu().numpy()

    print(f"检测到关键点数量: {keypoints.shape[0]}")
    print(f"描述符维度: {descriptors.shape}")

    # Draw the keypoints onto the image with OpenCV's drawKeypoints and show it.
    # PIL image -> numpy (RGB) -> BGR, the channel order OpenCV expects.
    image_np = np.array(image)
    image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

    # Convert the SuperPoint keypoints into a list of OpenCV KeyPoint objects.
    # x, y are pixel coordinates; the score is stored as the response value.
    # Original author's note: with DRAW_RICH_KEYPOINTS OpenCV only draws the
    # "radius line" when angle != -1; no angle is passed here.
    # `.tolist()` converts everything to plain Python numbers in one step
    # instead of unwrapping one tensor element per iteration.
    cv2_keypoints = [
        cv2.KeyPoint(
            x=float(x),
            y=float(y),
            size=7,
            response=float(score),
        )
        for (x, y), score in zip(keypoints.tolist(), scores.tolist())
    ]

    # Render the keypoints in red (BGR) on top of the image.
    image_with_kp = cv2.drawKeypoints(
        image_bgr,
        cv2_keypoints,
        None,
        color=(0, 0, 255),
        flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS,
    )

    show_image(image_with_kp)


if __name__ == "__main__":
    main()