qgyd2021's picture
first commit
071150e
"""
https://github.com/rpautrat/SuperPointPretrainedNetwork
"""
import argparse
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
import torch
import cv2
import numpy as np
import requests
from PIL import Image
from transformers import AutoImageProcessor
from transformers.models.superpoint.modeling_superpoint import SuperPointForKeypointDetection, SuperPointKeypointDescriptionOutput
from project_settings import project_path, temp_directory
def get_args():
    """Parse the command-line options of this demo.

    Every option is a string:
      --model_name       Hugging Face model id to load.
      --model_cache_dir  directory used as the Hugging Face download cache.
      --image_path       image on which keypoints are detected.

    :return: the ``argparse.Namespace`` produced by ``parse_args()``.
    """
    option_defaults = (
        ("--model_name", "magic-leap-community/superpoint"),
        ("--model_cache_dir", (project_path / "../../hf_hub_models").as_posix()),
        (
            "--image_path",
            (project_path / "data/images/keyboard/g98-v2-pink/model/keyboard1.png").as_posix(),
        ),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, default_value in option_defaults:
        arg_parser.add_argument(flag, default=default_value, type=str)
    return arg_parser.parse_args()
def show_image(image):
    """Show ``image`` in an OpenCV window named "image".

    Blocks until a key is pressed (``waitKey(0)``), then closes all OpenCV windows.
    """
    window_title = "image"
    cv2.imshow(window_title, image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
def _to_cv2_keypoints(keypoints, scores, size=7):
    """Convert SuperPoint keypoints into a list of ``cv2.KeyPoint``.

    :param keypoints: iterable of ``(x, y)`` pixel coordinates, one row per keypoint.
    :param scores: iterable of per-keypoint scores aligned with ``keypoints``;
        each score is stored as the OpenCV ``response`` value.
    :param size: ``size`` passed to ``cv2.KeyPoint`` (controls the drawn circle).
    :return: list of ``cv2.KeyPoint``, in the same order as ``keypoints``.
    """
    # ``angle`` is not passed, so it keeps OpenCV's default of -1; with
    # DRAW_RICH_KEYPOINTS that means only the circle is drawn, with no
    # orientation ("radius") line.
    return [
        cv2.KeyPoint(x=float(x), y=float(y), size=size, response=float(score))
        for (x, y), score in zip(keypoints, scores)
    ]


def main():
    """Detect SuperPoint keypoints on ``--image_path`` and display them.

    Loads the image processor and ``SuperPointForKeypointDetection`` model named
    by ``--model_name`` (cached under ``--model_cache_dir``), runs detection on
    the image, rescales keypoints to pixel coordinates via the processor's
    post-processing, prints the keypoint count and descriptor shape, and shows
    the keypoints drawn in red in an OpenCV window.
    """
    args = get_args()

    processor = AutoImageProcessor.from_pretrained(
        pretrained_model_name_or_path=args.model_name,
        cache_dir=args.model_cache_dir,
    )
    model = SuperPointForKeypointDetection.from_pretrained(
        pretrained_model_name_or_path=args.model_name,
        cache_dir=args.model_cache_dir,
    )

    image = Image.open(args.image_path).convert("RGB")
    inputs = processor(image, return_tensors="pt")

    # Pure inference: disable autograd so no computation graph is built or kept.
    with torch.no_grad():
        output: SuperPointKeypointDescriptionOutput = model(**inputs)

    # Use the processor's post-processing to convert relative keypoint
    # coordinates into pixel coordinates of the original image; the target
    # size is passed as (height, width).
    image_size = (image.height, image.width)
    processed = processor.post_process_keypoint_detection(output, [image_size])

    # ``processed`` is a list with one entry per image in the batch; only one here.
    result = processed[0]
    keypoints = result["keypoints"].detach().cpu().numpy()  # [N, 2], (x, y) pixel coordinates
    scores = result["scores"].detach().cpu().numpy()  # [N]
    descriptors = result["descriptors"]  # [N, D]

    print(f"检测到关键点数量: {keypoints.shape[0]}")
    print(f"描述符维度: {descriptors.shape}")

    # PIL gives RGB; OpenCV drawing/display expects BGR.
    image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    image_with_kp = cv2.drawKeypoints(
        image_bgr,
        _to_cv2_keypoints(keypoints, scores),
        None,
        color=(0, 0, 255),  # red in BGR order
        flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS,
    )
    show_image(image_with_kp)
if __name__ == "__main__":
main()