Spaces:
Runtime error
Runtime error
| """ | |
| https://github.com/rpautrat/SuperPointPretrainedNetwork | |
| """ | |
| import argparse | |
| import os | |
| os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" | |
| import torch | |
| import cv2 | |
| import numpy as np | |
| import requests | |
| from PIL import Image | |
| from transformers import AutoImageProcessor | |
| from transformers.models.superpoint.modeling_superpoint import SuperPointForKeypointDetection, SuperPointKeypointDescriptionOutput | |
| from project_settings import project_path, temp_directory | |
def get_args(argv=None):
    """Parse command-line options for the SuperPoint keypoint demo.

    Args:
        argv: Optional sequence of argument strings to parse instead of the
            real command line. ``None`` (the default) keeps the original
            behaviour of reading ``sys.argv[1:]``; passing a list makes the
            parser usable from tests or other code.

    Returns:
        argparse.Namespace with ``model_name``, ``model_cache_dir`` and
        ``image_path`` attributes (all strings).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        default="magic-leap-community/superpoint",
        type=str,
        help="Hugging Face model id (or local path) of the SuperPoint checkpoint.",
    )
    parser.add_argument(
        "--model_cache_dir",
        default=(project_path / "../../hf_hub_models").as_posix(),
        type=str,
        help="Directory used as the Hugging Face download cache.",
    )
    parser.add_argument(
        "--image_path",
        default=(project_path / "data/images/keyboard/g98-v2-pink/model/keyboard1.png").as_posix(),
        type=str,
        help="Path of the image to run keypoint detection on.",
    )
    return parser.parse_args(argv)
def show_image(image, window_name="image", wait_ms=0):
    """Display an image in an OpenCV window, then close the windows.

    Args:
        image: Image array accepted by ``cv2.imshow`` (OpenCV expects BGR
            channel order for colour images).
        window_name: Window title; defaults to ``"image"`` as before.
        wait_ms: Delay passed to ``cv2.waitKey``. ``0`` (the default) waits
            until a key is pressed; a positive value waits at most that many
            milliseconds.
    """
    cv2.imshow(window_name, image)
    cv2.waitKey(wait_ms)
    # destroyAllWindows rather than destroyWindow(window_name): the latter can
    # raise if the user has already closed the window with its close button.
    cv2.destroyAllWindows()
def _to_cv2_keypoints(keypoints, scores, size=7.0):
    """Convert (x, y) pixel coordinates and per-point scores to ``cv2.KeyPoint``s.

    Args:
        keypoints: Iterable of (x, y) pairs in pixel coordinates.
        scores: Iterable of detection scores, one per keypoint; stored as the
            KeyPoint ``response``.
        size: Diameter used when drawing each keypoint.

    Returns:
        List of ``cv2.KeyPoint`` objects, in the same order as the input.
    """
    # Positional arguments (x, y, size, angle, response): the keyword names of
    # the Python cv2.KeyPoint binding have changed between OpenCV releases,
    # while the positional order has not.
    # angle=-1 means "no orientation", so DRAW_RICH_KEYPOINTS draws only the
    # circle and no radius line (OpenCV draws that line only when angle != -1).
    return [
        cv2.KeyPoint(float(x), float(y), float(size), -1.0, float(score))
        for (x, y), score in zip(keypoints, scores)
    ]


def main():
    """Detect SuperPoint keypoints in one image and display them.

    Loads the image processor and model named by ``--model_name`` (cached in
    ``--model_cache_dir``) and runs keypoint detection on ``--image_path``.
    It prints the number of keypoints and the descriptor tensor shape, then
    draws the keypoints with OpenCV and shows the result in a window.
    """
    args = get_args()

    processor = AutoImageProcessor.from_pretrained(
        pretrained_model_name_or_path=args.model_name,
        cache_dir=args.model_cache_dir,
    )
    model = SuperPointForKeypointDetection.from_pretrained(
        pretrained_model_name_or_path=args.model_name,
        cache_dir=args.model_cache_dir,
    )

    image = Image.open(args.image_path).convert("RGB")
    inputs = processor(image, return_tensors="pt")

    # Inference only: disabling autograd avoids building a graph, and the
    # outputs can be converted to numpy without per-tensor .detach() calls.
    with torch.no_grad():
        output: SuperPointKeypointDescriptionOutput = model(**inputs)

    # Use the processor's post-processing to convert relative coordinates into
    # pixel coordinates; target sizes are passed as (height, width).
    image_size = (image.height, image.width)
    processed = processor.post_process_keypoint_detection(output, [image_size])

    # `processed` has one entry per batch element; there is a single image here.
    result = processed[0]
    keypoints = result["keypoints"].cpu().numpy()  # [N, 2], (x, y) pixel coordinates
    scores = result["scores"].cpu().numpy()  # [N]
    descriptors = result["descriptors"]  # [N, D]

    print(f"检测到关键点数量: {keypoints.shape[0]}")
    print(f"描述符维度: {descriptors.shape}")

    # PIL image (RGB) -> numpy array -> OpenCV's BGR channel order.
    image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # Draw the SuperPoint keypoints with OpenCV's drawKeypoints and show them.
    image_with_kp = cv2.drawKeypoints(
        image_bgr,
        _to_cv2_keypoints(keypoints, scores),
        None,
        color=(0, 0, 255),
        flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS,
    )
    show_image(image_with_kp)
if __name__ == "__main__":
    # Script entry point: run the demo only when executed directly, not when
    # this module is imported.
    main()