Spaces:
Runtime error
Runtime error
honeytian commited on
Commit ·
071150e
0
Parent(s):
first commit
Browse files- .dockerignore +0 -0
- .gitattributes +35 -0
- .gitignore +11 -0
- Dockerfile +23 -0
- README.md +12 -0
- examples/__init__.py +3 -0
- examples/keypoint_match/__init__.py +3 -0
- examples/keypoint_match/sift_multi_image_match.py +118 -0
- examples/keypoint_match/superpoint_multi_image_match.py +118 -0
- examples/keypoint_match/video_track_and_collect_templates.py +278 -0
- examples/keypoints/kornia/test.py +69 -0
- examples/keypoints/superpoint/test.py +111 -0
- main.py +0 -0
- project_settings.py +27 -0
- requirements.txt +7 -0
- toolbox/__init__.py +0 -0
- toolbox/json/__init__.py +6 -0
- toolbox/json/misc.py +63 -0
- toolbox/keypoint_match/__init__.py +64 -0
- toolbox/keypoint_match/base.py +141 -0
- toolbox/keypoint_match/detector.py +52 -0
- toolbox/keypoint_match/keypoint_detector/__init__.py +25 -0
- toolbox/keypoint_match/keypoint_detector/multi_image_detector.py +239 -0
- toolbox/keypoint_match/keypoint_detector/single_image_detector.py +206 -0
- toolbox/keypoint_match/keypoint_extracter/__init__.py +13 -0
- toolbox/keypoint_match/keypoint_extracter/sift.py +184 -0
- toolbox/keypoint_match/keypoint_extracter/superpoint.py +191 -0
- toolbox/keypoint_match/keypoint_match/__init__.py +10 -0
- toolbox/keypoint_match/keypoint_match/single_image_match.py +280 -0
- toolbox/keypoint_match/types.py +112 -0
- toolbox/os/__init__.py +6 -0
- toolbox/os/command.py +59 -0
- toolbox/os/environment.py +114 -0
- toolbox/os/other.py +9 -0
.dockerignore
ADDED
|
File without changes
|
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
.idea/
|
| 3 |
+
|
| 4 |
+
__pycache__/
|
| 5 |
+
|
| 6 |
+
data/
|
| 7 |
+
logs/
|
| 8 |
+
temp/
|
| 9 |
+
trainede_models/
|
| 10 |
+
|
| 11 |
+
|
Dockerfile
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.12
|
| 2 |
+
|
| 3 |
+
WORKDIR /code
|
| 4 |
+
|
| 5 |
+
COPY . /code
|
| 6 |
+
|
| 7 |
+
RUN apt-get update
|
| 8 |
+
RUN apt-get install -y ffmpeg build-essential
|
| 9 |
+
RUN apt-get install -y libnss3
|
| 10 |
+
|
| 11 |
+
RUN pip install --upgrade pip
|
| 12 |
+
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
| 13 |
+
|
| 14 |
+
RUN useradd -m -u 1000 user
|
| 15 |
+
USER user
|
| 16 |
+
ENV HOME=/home/user \
|
| 17 |
+
PATH=/home/user/.local/bin:$PATH
|
| 18 |
+
|
| 19 |
+
WORKDIR $HOME/app
|
| 20 |
+
|
| 21 |
+
COPY --chown=user . $HOME/app
|
| 22 |
+
|
| 23 |
+
CMD ["python3", "main.py"]
|
README.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: KeyPoints
|
| 3 |
+
emoji: 📈
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
license: apache-2.0
|
| 9 |
+
short_description: KeyPoints
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
examples/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
examples/keypoint_match/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
examples/keypoint_match/sift_multi_image_match.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
示例:使用 SIFT 提取关键点 + 多模板匹配(不做仿射/单应性)。
|
| 6 |
+
|
| 7 |
+
运行:
|
| 8 |
+
python -m examples.keypoint_match.sift_multi_image_match
|
| 9 |
+
|
| 10 |
+
或指定多个模板:
|
| 11 |
+
python -m examples.keypoint_match.sift_multi_image_match --template_paths a.png b.png --search_path big.png
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import List, Tuple
|
| 19 |
+
|
| 20 |
+
import cv2
|
| 21 |
+
import numpy as np
|
| 22 |
+
|
| 23 |
+
from project_settings import project_path
|
| 24 |
+
from toolbox.keypoint_match import (
|
| 25 |
+
MultiImageDetector,
|
| 26 |
+
MultiImageDetectorConfig,
|
| 27 |
+
SiftExtractConfig,
|
| 28 |
+
SiftKeyPointExtract,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _crop_template(image_bgr: np.ndarray, xyxy: Tuple[int, int, int, int]) -> np.ndarray:
|
| 33 |
+
x1, y1, x2, y2 = [int(v) for v in xyxy]
|
| 34 |
+
h, w = image_bgr.shape[:2]
|
| 35 |
+
x1 = max(0, min(w - 1, x1))
|
| 36 |
+
x2 = max(0, min(w, x2))
|
| 37 |
+
y1 = max(0, min(h - 1, y1))
|
| 38 |
+
y2 = max(0, min(h, y2))
|
| 39 |
+
if x2 <= x1 or y2 <= y1:
|
| 40 |
+
raise ValueError(f"无效裁剪框: {xyxy}, image_size={(h, w)}")
|
| 41 |
+
return image_bgr[y1:y2, x1:x2].copy()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def get_args():
|
| 45 |
+
parser = argparse.ArgumentParser()
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
"--search_path",
|
| 48 |
+
type=str,
|
| 49 |
+
# default=(project_path / "data/images/keyboard/g98-v2-pink/model/keyboard1.png").as_posix(),
|
| 50 |
+
# default=(project_path / "data/images/keyboard/g98-v2-pink/keyboard1.jpg").as_posix(),
|
| 51 |
+
# default=(project_path / "data/images/keyboard/g98-v2-pink/keyboard2.jpg").as_posix(),
|
| 52 |
+
default=(project_path / "data/images/keyboard/g98-v2-pink/keyboard3.jpg").as_posix(),
|
| 53 |
+
)
|
| 54 |
+
parser.add_argument(
|
| 55 |
+
"--template_paths",
|
| 56 |
+
type=str,
|
| 57 |
+
nargs="*",
|
| 58 |
+
default=[
|
| 59 |
+
(project_path / "data/images/keyboard/g98-v2-pink/model/local/roller/roller1.png").as_posix(),
|
| 60 |
+
(project_path / "data/images/keyboard/g98-v2-pink/local/roller/roller1.png").as_posix(),
|
| 61 |
+
(project_path / "data/images/keyboard/g98-v2-pink/local/roller/roller2.png").as_posix(),
|
| 62 |
+
(project_path / "data/images/keyboard/g98-v2-pink/local/roller/roller3.png").as_posix(),
|
| 63 |
+
],
|
| 64 |
+
help="提供多个小图模板路径;为空则使用默认模板 + 从大图裁剪一个 patch 当作第二模板",
|
| 65 |
+
)
|
| 66 |
+
parser.add_argument("--max_keypoints", type=int, default=2000)
|
| 67 |
+
parser.add_argument("--ratio", type=float, default=0.95)
|
| 68 |
+
parser.add_argument("--max_matches", type=int, default=120)
|
| 69 |
+
return parser.parse_args()
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def main():
|
| 73 |
+
args = get_args()
|
| 74 |
+
|
| 75 |
+
search = cv2.imread(args.search_path)
|
| 76 |
+
if search is None:
|
| 77 |
+
raise FileNotFoundError(f"无法读取搜索图: {args.search_path}")
|
| 78 |
+
|
| 79 |
+
templates: List[np.ndarray] = []
|
| 80 |
+
template_ids: List[str] = []
|
| 81 |
+
|
| 82 |
+
for p in args.template_paths:
|
| 83 |
+
img = cv2.imread(p)
|
| 84 |
+
if img is None:
|
| 85 |
+
raise FileNotFoundError(f"无法读取模板图: {p}")
|
| 86 |
+
templates.append(img)
|
| 87 |
+
template_ids.append(Path(p).stem)
|
| 88 |
+
|
| 89 |
+
extractor = SiftKeyPointExtract(
|
| 90 |
+
SiftExtractConfig(
|
| 91 |
+
max_keypoints=int(args.max_keypoints),
|
| 92 |
+
)
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
detector = MultiImageDetector(
|
| 96 |
+
extractor=extractor,
|
| 97 |
+
config=MultiImageDetectorConfig(
|
| 98 |
+
ratio=float(args.ratio),
|
| 99 |
+
max_matches=int(args.max_matches),
|
| 100 |
+
max_keypoints=int(args.max_keypoints),
|
| 101 |
+
),
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
result = detector.detect(templates, search, template_ids=template_ids)
|
| 105 |
+
|
| 106 |
+
for it in result.items:
|
| 107 |
+
print(f"template={it.template_id} kp_t={it.template_kp.n} kp_s={it.search_kp.n} matches={it.matches.m}")
|
| 108 |
+
|
| 109 |
+
title = "sift_multi_image_match | " + " , ".join([f"{it.template_id}:{it.matches.m}" for it in result.items])
|
| 110 |
+
cv2.imshow(title, result.vis_search)
|
| 111 |
+
cv2.waitKey(0)
|
| 112 |
+
cv2.destroyAllWindows()
|
| 113 |
+
return
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
if __name__ == "__main__":
|
| 117 |
+
main()
|
| 118 |
+
|
examples/keypoint_match/superpoint_multi_image_match.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
示例:使用 SuperPoint 提取关键点 + 多模板匹配(不做仿射/单应性)。
|
| 6 |
+
|
| 7 |
+
运行:
|
| 8 |
+
python -m examples.keypoint_match.superpoint_multi_image_match
|
| 9 |
+
|
| 10 |
+
或指定多个模板:
|
| 11 |
+
python -m examples.keypoint_match.superpoint_multi_image_match --template_paths a.png b.png --search_path big.png
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import List, Tuple
|
| 19 |
+
|
| 20 |
+
import cv2
|
| 21 |
+
import numpy as np
|
| 22 |
+
|
| 23 |
+
from project_settings import project_path
|
| 24 |
+
from toolbox.keypoint_match import (
|
| 25 |
+
MultiImageDetector,
|
| 26 |
+
MultiImageDetectorConfig,
|
| 27 |
+
SuperPointExtractConfig,
|
| 28 |
+
SuperPointKeyPointExtract,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _crop_template(image_bgr: np.ndarray, xyxy: Tuple[int, int, int, int]) -> np.ndarray:
|
| 33 |
+
x1, y1, x2, y2 = [int(v) for v in xyxy]
|
| 34 |
+
h, w = image_bgr.shape[:2]
|
| 35 |
+
x1 = max(0, min(w - 1, x1))
|
| 36 |
+
x2 = max(0, min(w, x2))
|
| 37 |
+
y1 = max(0, min(h - 1, y1))
|
| 38 |
+
y2 = max(0, min(h, y2))
|
| 39 |
+
if x2 <= x1 or y2 <= y1:
|
| 40 |
+
raise ValueError(f"无效裁剪框: {xyxy}, image_size={(h, w)}")
|
| 41 |
+
return image_bgr[y1:y2, x1:x2].copy()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def get_args():
|
| 45 |
+
parser = argparse.ArgumentParser()
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
"--search_path",
|
| 48 |
+
type=str,
|
| 49 |
+
default=(project_path / "data/images/keyboard/g98-v2-pink/model/keyboard1.png").as_posix(),
|
| 50 |
+
)
|
| 51 |
+
parser.add_argument(
|
| 52 |
+
"--template_paths",
|
| 53 |
+
type=str,
|
| 54 |
+
nargs="*",
|
| 55 |
+
default=[
|
| 56 |
+
(project_path / "data/images/keyboard/g98-v2-pink/model/local/roller/roller1.png").as_posix(),
|
| 57 |
+
(project_path / "data/images/keyboard/g98-v2-pink/local/roller/roller1.png").as_posix(),
|
| 58 |
+
(project_path / "data/images/keyboard/g98-v2-pink/local/roller/roller2.png").as_posix(),
|
| 59 |
+
(project_path / "data/images/keyboard/g98-v2-pink/local/roller/roller3.png").as_posix(),
|
| 60 |
+
],
|
| 61 |
+
help="提供多个小图模板路径;为空则使用默认模板 + 从大图裁剪一个 patch 当作第二模板",
|
| 62 |
+
)
|
| 63 |
+
parser.add_argument("--device", type=str, default="cpu")
|
| 64 |
+
parser.add_argument("--max_keypoints", type=int, default=2000)
|
| 65 |
+
parser.add_argument("--ratio", type=float, default=0.99)
|
| 66 |
+
parser.add_argument("--max_matches", type=int, default=120)
|
| 67 |
+
return parser.parse_args()
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def main():
|
| 71 |
+
args = get_args()
|
| 72 |
+
|
| 73 |
+
search = cv2.imread(args.search_path)
|
| 74 |
+
if search is None:
|
| 75 |
+
raise FileNotFoundError(f"无法读取搜索图: {args.search_path}")
|
| 76 |
+
|
| 77 |
+
templates: List[np.ndarray] = []
|
| 78 |
+
template_ids: List[str] = []
|
| 79 |
+
|
| 80 |
+
for p in args.template_paths:
|
| 81 |
+
img = cv2.imread(p)
|
| 82 |
+
if img is None:
|
| 83 |
+
raise FileNotFoundError(f"无法读取模板图: {p}")
|
| 84 |
+
templates.append(img)
|
| 85 |
+
template_ids.append(Path(p).stem)
|
| 86 |
+
|
| 87 |
+
extractor = SuperPointKeyPointExtract(
|
| 88 |
+
SuperPointExtractConfig(
|
| 89 |
+
device=str(args.device),
|
| 90 |
+
max_keypoints=int(args.max_keypoints),
|
| 91 |
+
)
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
detector = MultiImageDetector(
|
| 95 |
+
extractor=extractor,
|
| 96 |
+
config=MultiImageDetectorConfig(
|
| 97 |
+
ratio=float(args.ratio),
|
| 98 |
+
max_matches=int(args.max_matches),
|
| 99 |
+
max_keypoints=int(args.max_keypoints),
|
| 100 |
+
),
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
result = detector.detect(templates, search, template_ids=template_ids)
|
| 104 |
+
|
| 105 |
+
# 打印每个模板的匹配数量(越多通常说明越像)
|
| 106 |
+
for it in result.items:
|
| 107 |
+
print(f"template={it.template_id} kp_t={it.template_kp.n} kp_s={it.search_kp.n} matches={it.matches.m}")
|
| 108 |
+
|
| 109 |
+
title = "superpoint_multi_image_match | " + " , ".join([f"{it.template_id}:{it.matches.m}" for it in result.items])
|
| 110 |
+
cv2.imshow(title, result.vis_search)
|
| 111 |
+
cv2.waitKey(0)
|
| 112 |
+
cv2.destroyAllWindows()
|
| 113 |
+
return
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
if __name__ == "__main__":
|
| 117 |
+
main()
|
| 118 |
+
|
examples/keypoint_match/video_track_and_collect_templates.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
通过视频录制目标多角度,用跟踪器持续估计目标在画面中的位置与大致范围(边界框),
|
| 6 |
+
便于按帧裁剪保存为后续关键点匹配的模板小图。
|
| 7 |
+
|
| 8 |
+
用法:
|
| 9 |
+
# 摄像头(Windows 上设备号多为 0)
|
| 10 |
+
python -m examples.keypoint_match.video_track_and_collect_templates --source 0
|
| 11 |
+
|
| 12 |
+
# 视频文件
|
| 13 |
+
python -m examples.keypoint_match.video_track_and_collect_templates --source path/to/video.mp4
|
| 14 |
+
|
| 15 |
+
操作:
|
| 16 |
+
- 第一帧(或按 r 重选):按住鼠标左键拖拽画矩形框选目标,松开鼠标后自动开始跟踪
|
| 17 |
+
- s:将当前帧中跟踪框区域(可带边距)裁剪保存到 --output_dir
|
| 18 |
+
- r:进入重选模式,在当前帧重新拖拽框选(跟踪漂移或丢目标时使用)
|
| 19 |
+
- 空格:暂停/继续
|
| 20 |
+
- q 或 ESC:退出
|
| 21 |
+
|
| 22 |
+
跟踪器:
|
| 23 |
+
默认 mil(OpenCV 自带)。若本机有 Nano 权重可试 nano(需自行准备模型路径,见 --help)。
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import argparse
|
| 29 |
+
import sys
|
| 30 |
+
import time
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
from typing import Optional, Tuple
|
| 33 |
+
|
| 34 |
+
import cv2
|
| 35 |
+
import numpy as np
|
| 36 |
+
|
| 37 |
+
from project_settings import temp_directory
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
Rect = Tuple[int, int, int, int] # x, y, w, h
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def create_tracker(kind: str, nano_model: Optional[str] = None, nano_backend: Optional[str] = None):
|
| 44 |
+
"""创建 OpenCV Tracker,按环境能力回退。"""
|
| 45 |
+
kind = (kind or "mil").lower().strip()
|
| 46 |
+
|
| 47 |
+
if kind == "nano":
|
| 48 |
+
if not hasattr(cv2, "TrackerNano_create"):
|
| 49 |
+
raise RuntimeError("当前 OpenCV 未提供 TrackerNano_create,请改用 --tracker mil")
|
| 50 |
+
params = cv2.TrackerNano_Params()
|
| 51 |
+
if nano_model:
|
| 52 |
+
params.backbone = nano_model
|
| 53 |
+
if nano_backend:
|
| 54 |
+
params.NNBackend = nano_backend
|
| 55 |
+
return cv2.TrackerNano_create(params)
|
| 56 |
+
|
| 57 |
+
if kind == "goturn":
|
| 58 |
+
if not hasattr(cv2, "TrackerGOTURN_create"):
|
| 59 |
+
raise RuntimeError("当前 OpenCV 未提供 TrackerGOTURN_create")
|
| 60 |
+
return cv2.TrackerGOTURN_create()
|
| 61 |
+
|
| 62 |
+
if kind == "dasiamrpn":
|
| 63 |
+
if not hasattr(cv2, "TrackerDaSiamRPN_create"):
|
| 64 |
+
raise RuntimeError("当前 OpenCV 未提供 TrackerDaSiamRPN_create")
|
| 65 |
+
return cv2.TrackerDaSiamRPN_create()
|
| 66 |
+
|
| 67 |
+
if kind == "mil":
|
| 68 |
+
return cv2.TrackerMIL_create()
|
| 69 |
+
|
| 70 |
+
raise ValueError(f"未知 tracker 类型: {kind},可选: mil, nano, goturn, dasiamrpn")
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def clamp_rect(x: int, y: int, w: int, h: int, frame_w: int, frame_h: int) -> Rect:
|
| 74 |
+
x = max(0, min(x, frame_w - 1))
|
| 75 |
+
y = max(0, min(y, frame_h - 1))
|
| 76 |
+
w = max(1, min(w, frame_w - x))
|
| 77 |
+
h = max(1, min(h, frame_h - y))
|
| 78 |
+
return x, y, w, h
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class ROISelector:
|
| 82 |
+
"""鼠标拖拽选框。"""
|
| 83 |
+
|
| 84 |
+
def __init__(self, window_name: str):
|
| 85 |
+
self.window_name = window_name
|
| 86 |
+
self.drawing = False
|
| 87 |
+
self.x0 = self.y0 = self.x1 = self.y1 = 0
|
| 88 |
+
|
| 89 |
+
def on_mouse(self, event, mx, my, flags, param):
|
| 90 |
+
if event == cv2.EVENT_LBUTTONDOWN:
|
| 91 |
+
self.drawing = True
|
| 92 |
+
self.x0 = self.x1 = mx
|
| 93 |
+
self.y0 = self.y1 = my
|
| 94 |
+
elif event == cv2.EVENT_MOUSEMOVE and self.drawing:
|
| 95 |
+
self.x1, self.y1 = mx, my
|
| 96 |
+
elif event == cv2.EVENT_LBUTTONUP:
|
| 97 |
+
self.drawing = False
|
| 98 |
+
self.x1, self.y1 = mx, my
|
| 99 |
+
|
| 100 |
+
def attach(self):
|
| 101 |
+
cv2.setMouseCallback(self.window_name, self.on_mouse)
|
| 102 |
+
|
| 103 |
+
def current_rect(self, frame_w: int, frame_h: int) -> Optional[Rect]:
|
| 104 |
+
x1, y1, x2, y2 = self.x0, self.y0, self.x1, self.y1
|
| 105 |
+
if abs(x2 - x1) < 3 or abs(y2 - y1) < 3:
|
| 106 |
+
return None
|
| 107 |
+
xa, xb = sorted((x1, x2))
|
| 108 |
+
ya, yb = sorted((y1, y2))
|
| 109 |
+
return clamp_rect(xa, ya, xb - xa, yb - ya, frame_w, frame_h)
|
| 110 |
+
|
| 111 |
+
def draw_preview(self, frame: np.ndarray) -> np.ndarray:
|
| 112 |
+
vis = frame.copy()
|
| 113 |
+
if self.drawing or abs(self.x1 - self.x0) > 1:
|
| 114 |
+
xa, xb = sorted((self.x0, self.x1))
|
| 115 |
+
ya, yb = sorted((self.y0, self.y1))
|
| 116 |
+
cv2.rectangle(vis, (xa, ya), (xb, yb), (0, 255, 255), 2)
|
| 117 |
+
return vis
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def open_capture(source: str) -> cv2.VideoCapture:
|
| 121 |
+
if source.isdigit():
|
| 122 |
+
cap = cv2.VideoCapture(int(source))
|
| 123 |
+
else:
|
| 124 |
+
cap = cv2.VideoCapture(source)
|
| 125 |
+
if not cap.isOpened():
|
| 126 |
+
raise RuntimeError(f"无法打开视频源: {source}")
|
| 127 |
+
return cap
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def get_args():
|
| 131 |
+
p = argparse.ArgumentParser(description="视频目标跟踪 + 裁剪保存模板小图")
|
| 132 |
+
p.add_argument(
|
| 133 |
+
"--source",
|
| 134 |
+
type=str,
|
| 135 |
+
default="0",
|
| 136 |
+
help="摄像头设备号(如 0)或视频文件路径",
|
| 137 |
+
)
|
| 138 |
+
p.add_argument(
|
| 139 |
+
"--output_dir",
|
| 140 |
+
type=str,
|
| 141 |
+
default=(temp_directory / "template_crops").as_posix(),
|
| 142 |
+
help="按 s 保存裁剪图到此目录",
|
| 143 |
+
)
|
| 144 |
+
p.add_argument(
|
| 145 |
+
"--tracker",
|
| 146 |
+
type=str,
|
| 147 |
+
default="mil",
|
| 148 |
+
choices=["mil", "nano", "goturn", "dasiamrpn"],
|
| 149 |
+
help="OpenCV 跟踪器��型(默认可用 mil)",
|
| 150 |
+
)
|
| 151 |
+
p.add_argument("--nano_model", type=str, default="", help="TrackerNano backbone 模型路径(可选)")
|
| 152 |
+
p.add_argument("--nano_backend", type=str, default="", help="TrackerNano NNBackend(可选,依 OpenCV 文档)")
|
| 153 |
+
p.add_argument(
|
| 154 |
+
"--crop_pad",
|
| 155 |
+
type=float,
|
| 156 |
+
default=0.08,
|
| 157 |
+
help="保存裁剪时在框四周按比例扩展(相对框宽高),便于包含上下文",
|
| 158 |
+
)
|
| 159 |
+
return p.parse_args()
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def expand_rect(rect: Rect, pad_ratio: float, fw: int, fh: int) -> Rect:
|
| 163 |
+
x, y, w, h = rect
|
| 164 |
+
px = int(w * pad_ratio)
|
| 165 |
+
py = int(h * pad_ratio)
|
| 166 |
+
return clamp_rect(x - px, y - py, w + 2 * px, h + 2 * py, fw, fh)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def main():
|
| 170 |
+
args = get_args()
|
| 171 |
+
out_dir = Path(args.output_dir)
|
| 172 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 173 |
+
|
| 174 |
+
cap = open_capture(args.source)
|
| 175 |
+
win = "track_and_collect | drag ROI then release | s save | r reselect | space pause | q quit"
|
| 176 |
+
cv2.namedWindow(win, cv2.WINDOW_NORMAL)
|
| 177 |
+
selector = ROISelector(win)
|
| 178 |
+
selector.attach()
|
| 179 |
+
|
| 180 |
+
tracker = None
|
| 181 |
+
paused = False
|
| 182 |
+
need_init = True
|
| 183 |
+
ok_track = False
|
| 184 |
+
bbox: Optional[Rect] = None
|
| 185 |
+
save_idx = 0
|
| 186 |
+
nano_model = args.nano_model or None
|
| 187 |
+
nano_backend = args.nano_backend or None
|
| 188 |
+
prev_drawing = False
|
| 189 |
+
|
| 190 |
+
print(__doc__)
|
| 191 |
+
print(f"输出目录: {out_dir.resolve()}")
|
| 192 |
+
|
| 193 |
+
while True:
|
| 194 |
+
if not paused:
|
| 195 |
+
ret, frame = cap.read()
|
| 196 |
+
if not ret or frame is None:
|
| 197 |
+
print("视频结束或读取失败")
|
| 198 |
+
break
|
| 199 |
+
fh, fw = frame.shape[:2]
|
| 200 |
+
else:
|
| 201 |
+
fh, fw = frame.shape[:2]
|
| 202 |
+
|
| 203 |
+
key = cv2.waitKey(1) & 0xFF
|
| 204 |
+
if key in (ord("q"), 27):
|
| 205 |
+
break
|
| 206 |
+
if key == ord(" "):
|
| 207 |
+
paused = not paused
|
| 208 |
+
continue
|
| 209 |
+
if key == ord("r"):
|
| 210 |
+
need_init = True
|
| 211 |
+
tracker = None
|
| 212 |
+
ok_track = False
|
| 213 |
+
paused = True
|
| 214 |
+
print("重选模式:在本帧拖拽框选目标,松开鼠标后开始跟踪;可按空格继续/暂停")
|
| 215 |
+
|
| 216 |
+
display = frame.copy()
|
| 217 |
+
|
| 218 |
+
if need_init:
|
| 219 |
+
display = selector.draw_preview(display)
|
| 220 |
+
preview_rect = selector.current_rect(fw, fh)
|
| 221 |
+
if preview_rect is not None:
|
| 222 |
+
cv2.rectangle(
|
| 223 |
+
display,
|
| 224 |
+
(preview_rect[0], preview_rect[1]),
|
| 225 |
+
(preview_rect[0] + preview_rect[2], preview_rect[1] + preview_rect[3]),
|
| 226 |
+
(0, 255, 0),
|
| 227 |
+
2,
|
| 228 |
+
)
|
| 229 |
+
# 松开鼠标:从拖拽 -> 非拖拽 的瞬间,若框有效则自动初始化跟踪器
|
| 230 |
+
if prev_drawing and not selector.drawing:
|
| 231 |
+
rect = selector.current_rect(fw, fh)
|
| 232 |
+
if rect is not None:
|
| 233 |
+
tracker = create_tracker(args.tracker, nano_model=nano_model, nano_backend=nano_backend)
|
| 234 |
+
tracker.init(frame, rect)
|
| 235 |
+
bbox = rect
|
| 236 |
+
need_init = False
|
| 237 |
+
ok_track = True
|
| 238 |
+
paused = False
|
| 239 |
+
print(f"已初始化跟踪,bbox= {rect}")
|
| 240 |
+
else:
|
| 241 |
+
if tracker is not None and not paused:
|
| 242 |
+
ok_track, bbox = tracker.update(frame)
|
| 243 |
+
if bbox is not None:
|
| 244 |
+
x, y, w, h = [int(v) for v in bbox]
|
| 245 |
+
x, y, w, h = clamp_rect(x, y, w, h, fw, fh)
|
| 246 |
+
color = (0, 255, 0) if ok_track else (0, 0, 255)
|
| 247 |
+
cv2.rectangle(display, (x, y), (x + w, y + h), color, 2)
|
| 248 |
+
label = "tracking OK" if ok_track else "tracking LOST?"
|
| 249 |
+
cv2.putText(display, label, (x, max(0, y - 8)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1, cv2.LINE_AA)
|
| 250 |
+
|
| 251 |
+
if key == ord("s") and bbox is not None and ok_track:
|
| 252 |
+
x, y, w, h = [int(v) for v in bbox]
|
| 253 |
+
x, y, w, h = clamp_rect(x, y, w, h, fw, fh)
|
| 254 |
+
ex = expand_rect((x, y, w, h), float(args.crop_pad), fw, fh)
|
| 255 |
+
crop = frame[ex[1] : ex[1] + ex[3], ex[0] : ex[0] + ex[2]]
|
| 256 |
+
ts = int(time.time() * 1000)
|
| 257 |
+
path = out_dir / f"template_{save_idx:04d}_{ts}.png"
|
| 258 |
+
cv2.imwrite(str(path), crop)
|
| 259 |
+
save_idx += 1
|
| 260 |
+
print(f"已保存: {path}")
|
| 261 |
+
|
| 262 |
+
hint = (
|
| 263 |
+
"Drag ROI (release=init) | s save | r reselect | space pause | q quit"
|
| 264 |
+
if need_init
|
| 265 |
+
else "s save | r reselect | space pause | q quit"
|
| 266 |
+
)
|
| 267 |
+
cv2.putText(display, hint, (10, fh - 12), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (200, 200, 200), 1, cv2.LINE_AA)
|
| 268 |
+
cv2.imshow(win, display)
|
| 269 |
+
|
| 270 |
+
prev_drawing = selector.drawing
|
| 271 |
+
|
| 272 |
+
cap.release()
|
| 273 |
+
cv2.destroyAllWindows()
|
| 274 |
+
return 0
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
if __name__ == "__main__":
|
| 278 |
+
sys.exit(main())
|
examples/keypoints/kornia/test.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
https://github.com/kornia/kornia
|
| 3 |
+
"""
|
| 4 |
+
import cv2
|
| 5 |
+
import torch
|
| 6 |
+
import kornia as K
|
| 7 |
+
import kornia.feature as KF
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
from kornia_moons.viz import draw_LAF_matches
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def match_keyboards(img_path1, img_path2):
|
| 13 |
+
# 1. 加载图片并转换为 Tensor (Kornia 处理的是 Tensor)
|
| 14 |
+
# 键盘细节丰富,建议保持较高分辨率,不要缩放得太厉害
|
| 15 |
+
img1 = K.image_to_tensor(cv2.imread(img_path1), keepdim=False).float() / 255.0
|
| 16 |
+
img2 = K.image_to_tensor(cv2.imread(img_path2), keepdim=False).float() / 255.0
|
| 17 |
+
|
| 18 |
+
# 转换为灰度图用于特征提取
|
| 19 |
+
img1_gray = K.color.rgb_to_grayscale(K.color.bgr_to_rgb(img1))
|
| 20 |
+
img2_gray = K.color.rgb_to_grayscale(K.color.bgr_to_rgb(img2))
|
| 21 |
+
|
| 22 |
+
# 2. 定义特征提取器 (KeyNet + AffNet + HardNet)
|
| 23 |
+
# 这是一套非常强大的组合,对键盘这种重复纹理有很好的识别力
|
| 24 |
+
num_features = 2000 # 键盘字符多,建议点数设多一点
|
| 25 |
+
matcher = KF.LocalFeatureMatcher(
|
| 26 |
+
KF.KeyNetAffNetHardNet(num_features, pretrained=True),
|
| 27 |
+
KF.DescriptorMatcher('adalam', torch.device('cpu')) # 'adalam' 是目前非常快的离群点过滤匹配算法
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
# 3. 执行匹配
|
| 31 |
+
input_dict = {"image0": img1_gray, "image1": img2_gray}
|
| 32 |
+
with torch.no_grad():
|
| 33 |
+
correspondences = matcher(input_dict)
|
| 34 |
+
|
| 35 |
+
# 提取匹配后的坐标点
|
| 36 |
+
mkpts0 = correspondences['keypoints0'].cpu().numpy()
|
| 37 |
+
mkpts1 = correspondences['keypoints1'].cpu().numpy()
|
| 38 |
+
|
| 39 |
+
print(f"找到匹配点数量: {len(mkpts0)}")
|
| 40 |
+
|
| 41 |
+
# 4. 可视化结果
|
| 42 |
+
# 我们用 OpenCV 把两张图拼在一起看连线
|
| 43 |
+
img1_cv = cv2.cvtColor(cv2.imread(img_path1), cv2.COLOR_BGR2RGB)
|
| 44 |
+
img2_cv = cv2.cvtColor(cv2.imread(img_path2), cv2.COLOR_BGR2RGB)
|
| 45 |
+
|
| 46 |
+
# 简单的可视化:画出前 50 个最强的匹配
|
| 47 |
+
fig, ax = plt.subplots(figsize=(15, 8))
|
| 48 |
+
draw_LAF_matches(
|
| 49 |
+
KF.laf_from_center_scale_ori(correspondences['keypoints0'].view(1, -1, 2)),
|
| 50 |
+
KF.laf_from_center_scale_ori(correspondences['keypoints1'].view(1, -1, 2)),
|
| 51 |
+
torch.arange(len(mkpts0)).view(1, -1, 1),
|
| 52 |
+
img1_cv,
|
| 53 |
+
img2_cv,
|
| 54 |
+
draw_dict={'inliers_fallback': True}
|
| 55 |
+
)
|
| 56 |
+
plt.show()
|
| 57 |
+
|
| 58 |
+
# 5. 判断逻辑
|
| 59 |
+
# 如果匹配点数量很多(比如 > 100)且连线平行度高,则极有可能是同一个键盘
|
| 60 |
+
if len(mkpts0) > 50:
|
| 61 |
+
print("结论:两张图片匹配度极高,很可能是同一个键盘或极其相似。")
|
| 62 |
+
else:
|
| 63 |
+
print("结论:匹配点过少,可能是不同键盘或拍摄角度偏差过大。")
|
| 64 |
+
|
| 65 |
+
# 使用示例
|
| 66 |
+
match_keyboards(
|
| 67 |
+
'keyboard_left.jpg',
|
| 68 |
+
'keyboard_right.jpg'
|
| 69 |
+
)
|
examples/keypoints/superpoint/test.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
https://github.com/rpautrat/SuperPointPretrainedNetwork
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
import argparse
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import cv2
|
| 12 |
+
import numpy as np
|
| 13 |
+
import requests
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from transformers import AutoImageProcessor
|
| 16 |
+
from transformers.models.superpoint.modeling_superpoint import SuperPointForKeypointDetection, SuperPointKeypointDescriptionOutput
|
| 17 |
+
from project_settings import project_path, temp_directory
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_args():
|
| 21 |
+
parser = argparse.ArgumentParser()
|
| 22 |
+
parser.add_argument(
|
| 23 |
+
"--model_name",
|
| 24 |
+
default="magic-leap-community/superpoint",
|
| 25 |
+
type=str
|
| 26 |
+
)
|
| 27 |
+
parser.add_argument(
|
| 28 |
+
"--model_cache_dir",
|
| 29 |
+
default=(project_path / "../../hf_hub_models").as_posix(),
|
| 30 |
+
type=str
|
| 31 |
+
)
|
| 32 |
+
parser.add_argument(
|
| 33 |
+
"--image_path",
|
| 34 |
+
default=(project_path / "data/images/keyboard/g98-v2-pink/model/keyboard1.png").as_posix(),
|
| 35 |
+
type=str
|
| 36 |
+
)
|
| 37 |
+
args = parser.parse_args()
|
| 38 |
+
return args
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def show_image(image):
|
| 42 |
+
cv2.imshow("image", image)
|
| 43 |
+
cv2.waitKey(0)
|
| 44 |
+
cv2.destroyAllWindows()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def main():
|
| 48 |
+
args = get_args()
|
| 49 |
+
|
| 50 |
+
processor = AutoImageProcessor.from_pretrained(
|
| 51 |
+
pretrained_model_name_or_path=args.model_name,
|
| 52 |
+
cache_dir=args.model_cache_dir,
|
| 53 |
+
)
|
| 54 |
+
model = SuperPointForKeypointDetection.from_pretrained(
|
| 55 |
+
pretrained_model_name_or_path=args.model_name,
|
| 56 |
+
cache_dir=args.model_cache_dir,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
image = Image.open(args.image_path).convert("RGB")
|
| 60 |
+
|
| 61 |
+
inputs = processor(image, return_tensors="pt")
|
| 62 |
+
output: SuperPointKeypointDescriptionOutput = model(**inputs)
|
| 63 |
+
|
| 64 |
+
# 使用 processor 的后处理,将相对坐标转换为像素坐标
|
| 65 |
+
image_size = (image.height, image.width)
|
| 66 |
+
processed = processor.post_process_keypoint_detection(
|
| 67 |
+
output,
|
| 68 |
+
[image_size],
|
| 69 |
+
)
|
| 70 |
+
# processed 是长度为 batch_size 的 list,这里只有一张图
|
| 71 |
+
keypoints = processed[0]["keypoints"] # [N, 2],(x, y) 为像素坐标
|
| 72 |
+
scores = processed[0]["scores"] # [N]
|
| 73 |
+
descriptors = processed[0]["descriptors"] # [N, D]
|
| 74 |
+
scores = scores.detach().cpu().numpy()
|
| 75 |
+
|
| 76 |
+
print(f"检测到关键点数量: {keypoints.shape[0]}")
|
| 77 |
+
print(f"描述符维度: {descriptors.shape}")
|
| 78 |
+
|
| 79 |
+
# 5. 使用 OpenCV 的 drawKeypoints 在图像中画出关键点并展示
|
| 80 |
+
# PIL 图像 -> numpy -> BGR
|
| 81 |
+
image_np = np.array(image) # RGB
|
| 82 |
+
image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
|
| 83 |
+
|
| 84 |
+
# 将 SuperPoint 的关键点转换为 OpenCV 的 KeyPoint 列表
|
| 85 |
+
cv2_keypoints = []
|
| 86 |
+
for (x, y), score in zip(keypoints, scores):
|
| 87 |
+
# x, y 是像素坐标;score 作为响应值
|
| 88 |
+
# OpenCV 只有在 angle != -1 时,DRAW_RICH_KEYPOINTS 才会画出“半径线”
|
| 89 |
+
kp = cv2.KeyPoint(
|
| 90 |
+
x=float(x),
|
| 91 |
+
y=float(y),
|
| 92 |
+
size=7,
|
| 93 |
+
response=float(score),
|
| 94 |
+
)
|
| 95 |
+
cv2_keypoints.append(kp)
|
| 96 |
+
|
| 97 |
+
# 使用 drawKeypoints 画关键点
|
| 98 |
+
image_with_kp = cv2.drawKeypoints(
|
| 99 |
+
image_bgr,
|
| 100 |
+
cv2_keypoints,
|
| 101 |
+
None,
|
| 102 |
+
color=(0, 0, 255),
|
| 103 |
+
flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
show_image(image_with_kp)
|
| 107 |
+
return
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
if __name__ == "__main__":
|
| 111 |
+
main()
|
main.py
ADDED
|
File without changes
|
project_settings.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
from toolbox.os.environment import EnvironmentManager
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
project_path = os.path.abspath(os.path.dirname(__file__))
|
| 10 |
+
project_path = Path(project_path)
|
| 11 |
+
|
| 12 |
+
time_zone_info = "Asia/Shanghai"
|
| 13 |
+
|
| 14 |
+
log_directory = project_path / "logs"
|
| 15 |
+
log_directory.mkdir(parents=True, exist_ok=True)
|
| 16 |
+
|
| 17 |
+
temp_directory = project_path / "temp"
|
| 18 |
+
temp_directory.mkdir(parents=True, exist_ok=True)
|
| 19 |
+
|
| 20 |
+
environment = EnvironmentManager(
|
| 21 |
+
path=os.path.join(project_path, "dotenv"),
|
| 22 |
+
env=os.environ.get("environment", "dev"),
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if __name__ == "__main__":
|
| 27 |
+
pass
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
python-dotenv
|
| 2 |
+
transformers
|
| 3 |
+
torch
|
| 4 |
+
torchvision
|
| 5 |
+
opencv-python
|
| 6 |
+
Pillow
|
| 7 |
+
requests
|
toolbox/__init__.py
ADDED
|
File without changes
|
toolbox/json/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
if __name__ == '__main__':
|
| 6 |
+
pass
|
toolbox/json/misc.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
from typing import Callable
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def traverse(js, callback: Callable, *args, **kwargs):
|
| 7 |
+
if isinstance(js, list):
|
| 8 |
+
result = list()
|
| 9 |
+
for l in js:
|
| 10 |
+
l = traverse(l, callback, *args, **kwargs)
|
| 11 |
+
result.append(l)
|
| 12 |
+
return result
|
| 13 |
+
elif isinstance(js, tuple):
|
| 14 |
+
result = list()
|
| 15 |
+
for l in js:
|
| 16 |
+
l = traverse(l, callback, *args, **kwargs)
|
| 17 |
+
result.append(l)
|
| 18 |
+
return tuple(result)
|
| 19 |
+
elif isinstance(js, dict):
|
| 20 |
+
result = dict()
|
| 21 |
+
for k, v in js.items():
|
| 22 |
+
k = traverse(k, callback, *args, **kwargs)
|
| 23 |
+
v = traverse(v, callback, *args, **kwargs)
|
| 24 |
+
result[k] = v
|
| 25 |
+
return result
|
| 26 |
+
elif isinstance(js, int):
|
| 27 |
+
return callback(js, *args, **kwargs)
|
| 28 |
+
elif isinstance(js, str):
|
| 29 |
+
return callback(js, *args, **kwargs)
|
| 30 |
+
else:
|
| 31 |
+
return js
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def demo1():
|
| 35 |
+
d = {
|
| 36 |
+
"env": "ppe",
|
| 37 |
+
"mysql_connect": {
|
| 38 |
+
"host": "$mysql_connect_host",
|
| 39 |
+
"port": 3306,
|
| 40 |
+
"user": "callbot",
|
| 41 |
+
"password": "NxcloudAI2021!",
|
| 42 |
+
"database": "callbot_ppe",
|
| 43 |
+
"charset": "utf8"
|
| 44 |
+
},
|
| 45 |
+
"es_connect": {
|
| 46 |
+
"hosts": ["10.20.251.8"],
|
| 47 |
+
"http_auth": ["elastic", "ElasticAI2021!"],
|
| 48 |
+
"port": 9200
|
| 49 |
+
}
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
def callback(s):
|
| 53 |
+
if isinstance(s, str) and s.startswith('$'):
|
| 54 |
+
return s[1:]
|
| 55 |
+
return s
|
| 56 |
+
|
| 57 |
+
result = traverse(d, callback=callback)
|
| 58 |
+
print(result)
|
| 59 |
+
return
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
if __name__ == '__main__':
|
| 63 |
+
demo1()
|
toolbox/keypoint_match/__init__.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from toolbox.keypoint_match.types import (
|
| 5 |
+
Detection,
|
| 6 |
+
DetectionResult,
|
| 7 |
+
ImageArray,
|
| 8 |
+
KeyPointSet,
|
| 9 |
+
MatchSet,
|
| 10 |
+
)
|
| 11 |
+
from toolbox.keypoint_match.base import (
|
| 12 |
+
KeyPointExtract,
|
| 13 |
+
KeyPointExtractConfig,
|
| 14 |
+
KeyPointMatch,
|
| 15 |
+
KeyPointMatchConfig,
|
| 16 |
+
KeyPointTemplateDetector,
|
| 17 |
+
RegionScorer,
|
| 18 |
+
RegionScorerConfig,
|
| 19 |
+
)
|
| 20 |
+
from toolbox.keypoint_match.detector import SimpleKeyPointTemplateDetector
|
| 21 |
+
from toolbox.keypoint_match.keypoint_extracter.sift import SiftExtractConfig, SiftKeyPointExtract
|
| 22 |
+
from toolbox.keypoint_match.keypoint_extracter.superpoint import SuperPointExtractConfig, SuperPointKeyPointExtract
|
| 23 |
+
from toolbox.keypoint_match.keypoint_match.single_image_match import SingleImageMatcher, SingleImageMatcherConfig
|
| 24 |
+
from toolbox.keypoint_match.keypoint_detector.single_image_detector import (
|
| 25 |
+
SingleImageDetector,
|
| 26 |
+
SingleImageDetectorConfig,
|
| 27 |
+
SingleImageDetectorResult,
|
| 28 |
+
)
|
| 29 |
+
from toolbox.keypoint_match.keypoint_detector.multi_image_detector import (
|
| 30 |
+
MultiImageDetector,
|
| 31 |
+
MultiImageDetectorConfig,
|
| 32 |
+
MultiImageDetectorItem,
|
| 33 |
+
MultiImageDetectorResult,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
__all__ = [
|
| 37 |
+
"Detection",
|
| 38 |
+
"DetectionResult",
|
| 39 |
+
"ImageArray",
|
| 40 |
+
"KeyPointSet",
|
| 41 |
+
"MatchSet",
|
| 42 |
+
"KeyPointExtract",
|
| 43 |
+
"KeyPointExtractConfig",
|
| 44 |
+
"KeyPointMatch",
|
| 45 |
+
"KeyPointMatchConfig",
|
| 46 |
+
"KeyPointTemplateDetector",
|
| 47 |
+
"RegionScorer",
|
| 48 |
+
"RegionScorerConfig",
|
| 49 |
+
"SimpleKeyPointTemplateDetector",
|
| 50 |
+
"SiftExtractConfig",
|
| 51 |
+
"SiftKeyPointExtract",
|
| 52 |
+
"SuperPointExtractConfig",
|
| 53 |
+
"SuperPointKeyPointExtract",
|
| 54 |
+
"SingleImageMatcher",
|
| 55 |
+
"SingleImageMatcherConfig",
|
| 56 |
+
"SingleImageDetector",
|
| 57 |
+
"SingleImageDetectorConfig",
|
| 58 |
+
"SingleImageDetectorResult",
|
| 59 |
+
"MultiImageDetector",
|
| 60 |
+
"MultiImageDetectorConfig",
|
| 61 |
+
"MultiImageDetectorItem",
|
| 62 |
+
"MultiImageDetectorResult",
|
| 63 |
+
]
|
| 64 |
+
|
toolbox/keypoint_match/base.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from abc import ABC, abstractmethod
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Any, Mapping, Optional, Sequence, Tuple
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
from toolbox.keypoint_match.types import (
|
| 13 |
+
DetectionResult,
|
| 14 |
+
ImageArray,
|
| 15 |
+
KeyPointSet,
|
| 16 |
+
MatchSet,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass(frozen=True)
|
| 21 |
+
class KeyPointExtractConfig:
|
| 22 |
+
"""关键点/描述子提取器的通用配置。"""
|
| 23 |
+
|
| 24 |
+
max_keypoints: Optional[int] = None
|
| 25 |
+
# 允许实现方在 meta 中存放更多参数,例如:nms_radius、score_threshold、scale_pyramid 等
|
| 26 |
+
extra: Optional[Mapping[str, Any]] = None
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class KeyPointExtract(ABC):
|
| 30 |
+
"""
|
| 31 |
+
抽象:从一张图片提取关键点与描述子。
|
| 32 |
+
|
| 33 |
+
适配来源:
|
| 34 |
+
- OpenCV: ORB/SIFT/AKAZE 等
|
| 35 |
+
- Kornia: KeyNetAffNetHardNet、DISK 等
|
| 36 |
+
- SuperPoint/LightGlue 等深度模型
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(self, config: Optional[KeyPointExtractConfig] = None):
|
| 40 |
+
self.config = config or KeyPointExtractConfig()
|
| 41 |
+
|
| 42 |
+
@abstractmethod
|
| 43 |
+
def extract(self, image: ImageArray) -> KeyPointSet:
|
| 44 |
+
"""输入一张图,输出关键点集合(包含可选描述子)。"""
|
| 45 |
+
|
| 46 |
+
def batch_extract(self, images: Sequence[ImageArray]) -> Sequence[KeyPointSet]:
|
| 47 |
+
return [self.extract(im) for im in images]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclass(frozen=True)
|
| 51 |
+
class KeyPointMatchConfig:
|
| 52 |
+
"""匹配器的通用配置。"""
|
| 53 |
+
|
| 54 |
+
# 常见:ratio test 的阈值(若实现方使用 KNN)
|
| 55 |
+
ratio: Optional[float] = None
|
| 56 |
+
# 允许实现方控制返回的匹配数量上限
|
| 57 |
+
max_matches: Optional[int] = None
|
| 58 |
+
extra: Optional[Mapping[str, Any]] = None
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class KeyPointMatch(ABC):
|
| 62 |
+
"""
|
| 63 |
+
抽象:对两张图(或两组描述子)进行匹配,输出匹配对。
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
def __init__(self, config: Optional[KeyPointMatchConfig] = None):
|
| 67 |
+
self.config = config or KeyPointMatchConfig()
|
| 68 |
+
|
| 69 |
+
@abstractmethod
|
| 70 |
+
def match(self, query: KeyPointSet, train: KeyPointSet) -> MatchSet:
|
| 71 |
+
"""
|
| 72 |
+
query: 通常来自“小图/模板”
|
| 73 |
+
train: 通常来自“大图/搜索图”
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclass(frozen=True)
|
| 78 |
+
class RegionScorerConfig:
|
| 79 |
+
"""
|
| 80 |
+
匹配点聚集成“区域”的通用配置。
|
| 81 |
+
|
| 82 |
+
说明:你的核心算法描述是“某个区域匹配点特别多 => 目标被找到”,
|
| 83 |
+
因此我们把“如何聚类/评分/生成 bbox”抽象成 RegionScorer。
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
# 大图上聚类半径(像素)。例如:DBSCAN eps 或网格统计的 cell_size
|
| 87 |
+
radius_px: float = 24.0
|
| 88 |
+
# 认为“找到”的最低匹配点数阈值
|
| 89 |
+
min_match_count: int = 12
|
| 90 |
+
# 最多输出多少个候选区域
|
| 91 |
+
topk: int = 10
|
| 92 |
+
extra: Optional[Mapping[str, Any]] = None
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class RegionScorer(ABC):
|
| 96 |
+
"""
|
| 97 |
+
抽象:把匹配关系映射成候选区域(bbox + score)。
|
| 98 |
+
|
| 99 |
+
输入包含关键点坐标与匹配对,因此实现可以:
|
| 100 |
+
- 做简单的网格投票 / 密度聚类
|
| 101 |
+
- 用单应性 / RANSAC 过滤外点后再成簇
|
| 102 |
+
- 用匹配点的局部一致性做评分
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
def __init__(self, config: Optional[RegionScorerConfig] = None):
|
| 106 |
+
self.config = config or RegionScorerConfig()
|
| 107 |
+
|
| 108 |
+
@abstractmethod
|
| 109 |
+
def score(
|
| 110 |
+
self,
|
| 111 |
+
query: KeyPointSet,
|
| 112 |
+
train: KeyPointSet,
|
| 113 |
+
matches: MatchSet,
|
| 114 |
+
*,
|
| 115 |
+
template_id: Optional[str] = None,
|
| 116 |
+
template_size: Optional[Tuple[int, int]] = None, # (h, w)
|
| 117 |
+
) -> DetectionResult:
|
| 118 |
+
"""
|
| 119 |
+
template_size: 小图大小,便于从匹配点推断 bbox 尺寸(可选)
|
| 120 |
+
"""
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class KeyPointTemplateDetector(ABC):
|
| 124 |
+
"""
|
| 125 |
+
抽象:把 (提取器 + 匹配器 + 区域评分器) 组合成“模板检测”。
|
| 126 |
+
|
| 127 |
+
你可以实现一个具体 Detector,将其用于:
|
| 128 |
+
- 多模板检索(小图集合在大图中找出现位置)
|
| 129 |
+
- 单模板定位(某个小图在大图的哪里)
|
| 130 |
+
"""
|
| 131 |
+
|
| 132 |
+
@abstractmethod
|
| 133 |
+
def detect(
|
| 134 |
+
self,
|
| 135 |
+
template_image: ImageArray,
|
| 136 |
+
search_image: ImageArray,
|
| 137 |
+
*,
|
| 138 |
+
template_id: Optional[str] = None,
|
| 139 |
+
) -> DetectionResult:
|
| 140 |
+
...
|
| 141 |
+
|
toolbox/keypoint_match/detector.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Optional, Tuple
|
| 8 |
+
|
| 9 |
+
from toolbox.keypoint_match.base import KeyPointExtract, KeyPointMatch, KeyPointTemplateDetector, RegionScorer
|
| 10 |
+
from toolbox.keypoint_match.types import DetectionResult, ImageArray, KeyPointSet, MatchSet
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class SimpleKeyPointTemplateDetector(KeyPointTemplateDetector):
|
| 15 |
+
"""
|
| 16 |
+
一个“组合式”的默认 Detector 实现。
|
| 17 |
+
|
| 18 |
+
说明:
|
| 19 |
+
- 这是一个可直接工作的拼装类:extract(template) + extract(search) + match + score
|
| 20 |
+
- 具体的“找区域”逻辑由 RegionScorer 决定(因此仍保持算法可替换/可扩展)
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
extractor: KeyPointExtract
|
| 24 |
+
matcher: KeyPointMatch
|
| 25 |
+
region_scorer: RegionScorer
|
| 26 |
+
|
| 27 |
+
def detect(
|
| 28 |
+
self,
|
| 29 |
+
template_image: ImageArray,
|
| 30 |
+
search_image: ImageArray,
|
| 31 |
+
*,
|
| 32 |
+
template_id: Optional[str] = None,
|
| 33 |
+
) -> DetectionResult:
|
| 34 |
+
template_kp: KeyPointSet = self.extractor.extract(template_image)
|
| 35 |
+
search_kp: KeyPointSet = self.extractor.extract(search_image)
|
| 36 |
+
matches: MatchSet = self.matcher.match(template_kp, search_kp)
|
| 37 |
+
|
| 38 |
+
template_size: Optional[Tuple[int, int]] = None
|
| 39 |
+
try:
|
| 40 |
+
h, w = int(template_image.shape[0]), int(template_image.shape[1])
|
| 41 |
+
template_size = (h, w)
|
| 42 |
+
except Exception:
|
| 43 |
+
template_size = None
|
| 44 |
+
|
| 45 |
+
return self.region_scorer.score(
|
| 46 |
+
template_kp,
|
| 47 |
+
search_kp,
|
| 48 |
+
matches,
|
| 49 |
+
template_id=template_id,
|
| 50 |
+
template_size=template_size,
|
| 51 |
+
)
|
| 52 |
+
|
toolbox/keypoint_match/keypoint_detector/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from toolbox.keypoint_match.keypoint_detector.single_image_detector import (
|
| 5 |
+
SingleImageDetector,
|
| 6 |
+
SingleImageDetectorConfig,
|
| 7 |
+
SingleImageDetectorResult,
|
| 8 |
+
)
|
| 9 |
+
from toolbox.keypoint_match.keypoint_detector.multi_image_detector import (
|
| 10 |
+
MultiImageDetector,
|
| 11 |
+
MultiImageDetectorConfig,
|
| 12 |
+
MultiImageDetectorItem,
|
| 13 |
+
MultiImageDetectorResult,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
__all__ = [
|
| 17 |
+
"SingleImageDetector",
|
| 18 |
+
"SingleImageDetectorConfig",
|
| 19 |
+
"SingleImageDetectorResult",
|
| 20 |
+
"MultiImageDetector",
|
| 21 |
+
"MultiImageDetectorConfig",
|
| 22 |
+
"MultiImageDetectorItem",
|
| 23 |
+
"MultiImageDetectorResult",
|
| 24 |
+
]
|
| 25 |
+
|
toolbox/keypoint_match/keypoint_detector/multi_image_detector.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
| 10 |
+
|
| 11 |
+
import cv2
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
from project_settings import project_path
|
| 15 |
+
from toolbox.keypoint_match.base import KeyPointExtract
|
| 16 |
+
from toolbox.keypoint_match.keypoint_match.single_image_match import (
|
| 17 |
+
SingleImageMatcher,
|
| 18 |
+
SingleImageMatcherConfig,
|
| 19 |
+
)
|
| 20 |
+
from toolbox.keypoint_match.types import ImageArray, KeyPointSet, MatchSet
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass(frozen=True)
|
| 24 |
+
class MultiImageDetectorConfig:
|
| 25 |
+
"""
|
| 26 |
+
多模板(多个小图)在同一张大图中的关键点匹配检测配置。
|
| 27 |
+
|
| 28 |
+
- 不做仿射/单应性估计
|
| 29 |
+
- 只画出“大图中匹配得很好的关键点”
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
ratio: float = 0.75
|
| 33 |
+
max_matches: int = 120
|
| 34 |
+
max_keypoints: int = 2000
|
| 35 |
+
point_radius: int = 4
|
| 36 |
+
point_thickness: int = 2
|
| 37 |
+
# 每个模板对应一个颜色(BGR),不足时循环使用
|
| 38 |
+
colors_bgr: Sequence[Tuple[int, int, int]] = (
|
| 39 |
+
(0, 0, 255), # red
|
| 40 |
+
(0, 255, 0), # green
|
| 41 |
+
(255, 0, 0), # blue
|
| 42 |
+
(0, 255, 255), # yellow
|
| 43 |
+
(255, 0, 255), # magenta
|
| 44 |
+
(255, 255, 0), # cyan
|
| 45 |
+
)
|
| 46 |
+
extra: Optional[Dict[str, Any]] = None
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass(frozen=True)
|
| 50 |
+
class MultiImageDetectorItem:
|
| 51 |
+
template_id: str
|
| 52 |
+
template_kp: KeyPointSet
|
| 53 |
+
search_kp: KeyPointSet
|
| 54 |
+
matches: MatchSet
|
| 55 |
+
color_bgr: Tuple[int, int, int]
|
| 56 |
+
meta: Optional[Dict[str, Any]] = None
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@dataclass(frozen=True)
|
| 60 |
+
class MultiImageDetectorResult:
|
| 61 |
+
items: Sequence[MultiImageDetectorItem]
|
| 62 |
+
vis_search: np.ndarray
|
| 63 |
+
meta: Optional[Dict[str, Any]] = None
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class MultiImageDetector:
|
| 67 |
+
"""
|
| 68 |
+
Multi image detector:给定多个模板小图,在同一张大图中找“匹配得很好的关键点”,并可视化。
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
def __init__(self, extractor: KeyPointExtract, config: Optional[MultiImageDetectorConfig] = None):
|
| 72 |
+
self.extractor = extractor
|
| 73 |
+
self.config = config or MultiImageDetectorConfig()
|
| 74 |
+
|
| 75 |
+
def _color_for_index(self, i: int) -> Tuple[int, int, int]:
|
| 76 |
+
palette = list(self.config.colors_bgr) if self.config.colors_bgr else [(0, 0, 255)]
|
| 77 |
+
return tuple(int(c) for c in palette[i % len(palette)])
|
| 78 |
+
|
| 79 |
+
def _draw_points_on_search(
|
| 80 |
+
self,
|
| 81 |
+
base_image: ImageArray,
|
| 82 |
+
search_kp: KeyPointSet,
|
| 83 |
+
matches: MatchSet,
|
| 84 |
+
color_bgr: Tuple[int, int, int],
|
| 85 |
+
) -> np.ndarray:
|
| 86 |
+
img = np.asarray(base_image).copy()
|
| 87 |
+
if img.ndim == 2:
|
| 88 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
| 89 |
+
|
| 90 |
+
if matches.m == 0 or search_kp.n == 0:
|
| 91 |
+
return img
|
| 92 |
+
|
| 93 |
+
idx = matches.train_idx.astype(np.int64)
|
| 94 |
+
idx = idx[(idx >= 0) & (idx < search_kp.n)]
|
| 95 |
+
if idx.size == 0:
|
| 96 |
+
return img
|
| 97 |
+
|
| 98 |
+
pts = search_kp.xy[idx]
|
| 99 |
+
for x, y in pts:
|
| 100 |
+
cv2.circle(
|
| 101 |
+
img,
|
| 102 |
+
center=(int(round(float(x))), int(round(float(y)))),
|
| 103 |
+
radius=int(self.config.point_radius),
|
| 104 |
+
color=color_bgr,
|
| 105 |
+
thickness=int(self.config.point_thickness),
|
| 106 |
+
lineType=cv2.LINE_AA,
|
| 107 |
+
)
|
| 108 |
+
return img
|
| 109 |
+
|
| 110 |
+
def detect(
|
| 111 |
+
self,
|
| 112 |
+
template_images: Sequence[ImageArray],
|
| 113 |
+
search_image: ImageArray,
|
| 114 |
+
*,
|
| 115 |
+
template_ids: Optional[Sequence[str]] = None,
|
| 116 |
+
) -> MultiImageDetectorResult:
|
| 117 |
+
if template_ids is None:
|
| 118 |
+
template_ids = [f"template_{i}" for i in range(len(template_images))]
|
| 119 |
+
if len(template_ids) != len(template_images):
|
| 120 |
+
raise ValueError("template_ids 的长度必须与 template_images 一致")
|
| 121 |
+
|
| 122 |
+
matcher = SingleImageMatcher(
|
| 123 |
+
extractor=self.extractor,
|
| 124 |
+
config=SingleImageMatcherConfig(
|
| 125 |
+
ratio=float(self.config.ratio),
|
| 126 |
+
max_matches=int(self.config.max_matches),
|
| 127 |
+
),
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# 为了避免重复提取大图特征点:这里直接复用 matcher 的实现会重复提取,
|
| 131 |
+
# 但为了接口简单先保持这样。后续需要性能时,可扩展 matcher 支持传入已提取的 search_kp。
|
| 132 |
+
items: List[MultiImageDetectorItem] = []
|
| 133 |
+
vis = np.asarray(search_image).copy()
|
| 134 |
+
if vis.ndim == 2:
|
| 135 |
+
vis = cv2.cvtColor(vis, cv2.COLOR_GRAY2BGR)
|
| 136 |
+
|
| 137 |
+
for i, (tid, tmpl) in enumerate(zip(template_ids, template_images)):
|
| 138 |
+
template_kp, search_kp, matches = matcher.match(tmpl, search_image, template_id=str(tid))
|
| 139 |
+
color = self._color_for_index(i)
|
| 140 |
+
|
| 141 |
+
# 叠加绘制
|
| 142 |
+
vis = self._draw_points_on_search(vis, search_kp, matches, color_bgr=color)
|
| 143 |
+
|
| 144 |
+
items.append(
|
| 145 |
+
MultiImageDetectorItem(
|
| 146 |
+
template_id=str(tid),
|
| 147 |
+
template_kp=template_kp,
|
| 148 |
+
search_kp=search_kp,
|
| 149 |
+
matches=matches,
|
| 150 |
+
color_bgr=color,
|
| 151 |
+
meta={
|
| 152 |
+
"kp_template": int(template_kp.n),
|
| 153 |
+
"kp_search": int(search_kp.n),
|
| 154 |
+
"match_count": int(matches.m),
|
| 155 |
+
},
|
| 156 |
+
)
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
meta: Dict[str, Any] = {
|
| 160 |
+
"template_count": int(len(template_images)),
|
| 161 |
+
"ratio": float(self.config.ratio),
|
| 162 |
+
"max_matches": int(self.config.max_matches),
|
| 163 |
+
}
|
| 164 |
+
if self.config.extra:
|
| 165 |
+
meta["extra"] = dict(self.config.extra)
|
| 166 |
+
|
| 167 |
+
return MultiImageDetectorResult(items=items, vis_search=vis, meta=meta)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def get_args():
|
| 171 |
+
parser = argparse.ArgumentParser()
|
| 172 |
+
parser.add_argument(
|
| 173 |
+
"--template_paths",
|
| 174 |
+
type=str,
|
| 175 |
+
nargs="+",
|
| 176 |
+
default=[
|
| 177 |
+
(project_path / "data/images/keyboard/g98-v2-pink/model/local/roller/roller1.png").as_posix(),
|
| 178 |
+
(project_path / "data/images/keyboard/g98-v2-pink/local/roller/roller1.png").as_posix(),
|
| 179 |
+
(project_path / "data/images/keyboard/g98-v2-pink/local/roller/roller2.png").as_posix(),
|
| 180 |
+
(project_path / "data/images/keyboard/g98-v2-pink/local/roller/roller3.png").as_posix(),
|
| 181 |
+
],
|
| 182 |
+
)
|
| 183 |
+
parser.add_argument(
|
| 184 |
+
"--search_path",
|
| 185 |
+
type=str,
|
| 186 |
+
# default=(project_path / "data/images/keyboard/g98-v2-pink/model/keyboard1.png").as_posix(),
|
| 187 |
+
default=(project_path / "data/images/keyboard/g98-v2-pink/keyboard2.jpg").as_posix(),
|
| 188 |
+
# default=(project_path / "data/images/keyboard/g98-v2-pink/keyboard3.jpg").as_posix(),
|
| 189 |
+
# default=(project_path / "data/images/keyboard/g98-v2-pink/keyboard4.jpg").as_posix(),
|
| 190 |
+
# default=(project_path / "data/images/keyboard/g98-v2-pink/keyboard6.jpg").as_posix(),
|
| 191 |
+
)
|
| 192 |
+
parser.add_argument("--ratio", type=float, default=0.90)
|
| 193 |
+
parser.add_argument("--max_matches", type=int, default=120)
|
| 194 |
+
parser.add_argument("--max_keypoints", type=int, default=2000)
|
| 195 |
+
return parser.parse_args()
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def main():
|
| 199 |
+
from toolbox.keypoint_match.keypoint_extracter import SiftExtractConfig, SiftKeyPointExtract
|
| 200 |
+
|
| 201 |
+
args = get_args()
|
| 202 |
+
|
| 203 |
+
search = cv2.imread(args.search_path)
|
| 204 |
+
if search is None:
|
| 205 |
+
raise FileNotFoundError(f"无法读取搜索图: {args.search_path}")
|
| 206 |
+
|
| 207 |
+
templates: List[np.ndarray] = []
|
| 208 |
+
template_ids: List[str] = []
|
| 209 |
+
for p in args.template_paths:
|
| 210 |
+
img = cv2.imread(p)
|
| 211 |
+
if img is None:
|
| 212 |
+
raise FileNotFoundError(f"无法读取模板图: {p}")
|
| 213 |
+
templates.append(img)
|
| 214 |
+
template_ids.append(Path(p).stem)
|
| 215 |
+
|
| 216 |
+
extractor = SiftKeyPointExtract(SiftExtractConfig(max_keypoints=int(args.max_keypoints)))
|
| 217 |
+
detector = MultiImageDetector(
|
| 218 |
+
extractor=extractor,
|
| 219 |
+
config=MultiImageDetectorConfig(
|
| 220 |
+
ratio=float(args.ratio),
|
| 221 |
+
max_matches=int(args.max_matches),
|
| 222 |
+
max_keypoints=int(args.max_keypoints),
|
| 223 |
+
),
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
result = detector.detect(templates, search, template_ids=template_ids)
|
| 227 |
+
|
| 228 |
+
# 在窗口标题里打印每个模板的匹配数量
|
| 229 |
+
stat = ", ".join([f"{it.template_id}:{it.matches.m}" for it in result.items])
|
| 230 |
+
title = f"multi_image_detector | {stat}"
|
| 231 |
+
cv2.imshow(title, result.vis_search)
|
| 232 |
+
cv2.waitKey(0)
|
| 233 |
+
cv2.destroyAllWindows()
|
| 234 |
+
return
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
if __name__ == "__main__":
|
| 238 |
+
main()
|
| 239 |
+
|
toolbox/keypoint_match/keypoint_detector/single_image_detector.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Any, Dict, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import cv2
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
from project_settings import project_path
|
| 14 |
+
from toolbox.keypoint_match.base import KeyPointExtract
|
| 15 |
+
from toolbox.keypoint_match.keypoint_match.single_image_match import (
|
| 16 |
+
SingleImageMatcher,
|
| 17 |
+
SingleImageMatcherConfig,
|
| 18 |
+
)
|
| 19 |
+
from toolbox.keypoint_match.types import ImageArray, KeyPointSet, MatchSet
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass(frozen=True)
|
| 23 |
+
class SingleImageDetectorConfig:
|
| 24 |
+
"""
|
| 25 |
+
单目标小图在大图中的“关键点匹配检测”配置。
|
| 26 |
+
|
| 27 |
+
注意:这里不做仿射/单应性估计,只保留“匹配得很好的点”,并把这些点在大图上画出来。
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
ratio: float = 0.75
|
| 31 |
+
max_matches: int = 120
|
| 32 |
+
max_keypoints: int = 2000
|
| 33 |
+
# 画点参数
|
| 34 |
+
point_radius: int = 4
|
| 35 |
+
point_thickness: int = 2
|
| 36 |
+
color_bgr: Tuple[int, int, int] = (0, 0, 255)
|
| 37 |
+
# 是否同时在模板图上也画出参与匹配的点(便于对照)
|
| 38 |
+
draw_on_template: bool = False
|
| 39 |
+
extra: Optional[Dict[str, Any]] = None
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclass(frozen=True)
|
| 43 |
+
class SingleImageDetectorResult:
|
| 44 |
+
template_kp: KeyPointSet
|
| 45 |
+
search_kp: KeyPointSet
|
| 46 |
+
matches: MatchSet
|
| 47 |
+
vis_search: np.ndarray
|
| 48 |
+
vis_template: Optional[np.ndarray] = None
|
| 49 |
+
meta: Optional[Dict[str, Any]] = None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class SingleImageDetector:
|
| 53 |
+
"""
|
| 54 |
+
Single image detector:输入小图与大图,输出“匹配到的大图关键点可视化”。
|
| 55 |
+
|
| 56 |
+
- 初始化时传入关键点提取器(SIFT/SuperPoint 等)
|
| 57 |
+
- 内部复用 `SingleImageMatcher` 做匹配
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
def __init__(self, extractor: KeyPointExtract, config: Optional[SingleImageDetectorConfig] = None):
|
| 61 |
+
self.extractor = extractor
|
| 62 |
+
self.config = config or SingleImageDetectorConfig()
|
| 63 |
+
|
| 64 |
+
def detect(self, template_image: ImageArray, search_image: ImageArray, *, template_id: Optional[str] = None) -> SingleImageDetectorResult:
|
| 65 |
+
matcher = SingleImageMatcher(
|
| 66 |
+
extractor=self.extractor,
|
| 67 |
+
config=SingleImageMatcherConfig(
|
| 68 |
+
ratio=float(self.config.ratio),
|
| 69 |
+
max_matches=int(self.config.max_matches),
|
| 70 |
+
),
|
| 71 |
+
)
|
| 72 |
+
template_kp, search_kp, matches = matcher.match(template_image, search_image, template_id=template_id)
|
| 73 |
+
|
| 74 |
+
vis_search = self.draw_matched_keypoints_on_search(search_image, search_kp, matches)
|
| 75 |
+
vis_template = None
|
| 76 |
+
if self.config.draw_on_template:
|
| 77 |
+
vis_template = self.draw_matched_keypoints_on_template(template_image, template_kp, matches)
|
| 78 |
+
|
| 79 |
+
meta = {
|
| 80 |
+
"template_id": template_id,
|
| 81 |
+
"kp_template": int(template_kp.n),
|
| 82 |
+
"kp_search": int(search_kp.n),
|
| 83 |
+
"match_count": int(matches.m),
|
| 84 |
+
"ratio": float(self.config.ratio),
|
| 85 |
+
}
|
| 86 |
+
if self.config.extra:
|
| 87 |
+
meta["extra"] = dict(self.config.extra)
|
| 88 |
+
|
| 89 |
+
return SingleImageDetectorResult(
|
| 90 |
+
template_kp=template_kp,
|
| 91 |
+
search_kp=search_kp,
|
| 92 |
+
matches=matches,
|
| 93 |
+
vis_search=vis_search,
|
| 94 |
+
vis_template=vis_template,
|
| 95 |
+
meta=meta,
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
def draw_matched_keypoints_on_search(self, search_image: ImageArray, search_kp: KeyPointSet, matches: MatchSet) -> np.ndarray:
|
| 99 |
+
img = np.asarray(search_image).copy()
|
| 100 |
+
if img.ndim == 2:
|
| 101 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
| 102 |
+
|
| 103 |
+
if matches.m == 0 or search_kp.n == 0:
|
| 104 |
+
return img
|
| 105 |
+
|
| 106 |
+
# 在大图上画出被匹配到的 train_idx 对应关键点
|
| 107 |
+
idx = matches.train_idx.astype(np.int64)
|
| 108 |
+
idx = idx[(idx >= 0) & (idx < search_kp.n)]
|
| 109 |
+
if idx.size == 0:
|
| 110 |
+
return img
|
| 111 |
+
|
| 112 |
+
pts = search_kp.xy[idx]
|
| 113 |
+
for x, y in pts:
|
| 114 |
+
cv2.circle(
|
| 115 |
+
img,
|
| 116 |
+
center=(int(round(float(x))), int(round(float(y)))),
|
| 117 |
+
radius=int(self.config.point_radius),
|
| 118 |
+
color=tuple(int(c) for c in self.config.color_bgr),
|
| 119 |
+
thickness=int(self.config.point_thickness),
|
| 120 |
+
lineType=cv2.LINE_AA,
|
| 121 |
+
)
|
| 122 |
+
return img
|
| 123 |
+
|
| 124 |
+
def draw_matched_keypoints_on_template(self, template_image: ImageArray, template_kp: KeyPointSet, matches: MatchSet) -> np.ndarray:
|
| 125 |
+
img = np.asarray(template_image).copy()
|
| 126 |
+
if img.ndim == 2:
|
| 127 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
| 128 |
+
|
| 129 |
+
if matches.m == 0 or template_kp.n == 0:
|
| 130 |
+
return img
|
| 131 |
+
|
| 132 |
+
idx = matches.query_idx.astype(np.int64)
|
| 133 |
+
idx = idx[(idx >= 0) & (idx < template_kp.n)]
|
| 134 |
+
if idx.size == 0:
|
| 135 |
+
return img
|
| 136 |
+
|
| 137 |
+
pts = template_kp.xy[idx]
|
| 138 |
+
for x, y in pts:
|
| 139 |
+
cv2.circle(
|
| 140 |
+
img,
|
| 141 |
+
center=(int(round(float(x))), int(round(float(y)))),
|
| 142 |
+
radius=max(2, int(self.config.point_radius) - 1),
|
| 143 |
+
color=(0, 255, 0),
|
| 144 |
+
thickness=int(self.config.point_thickness),
|
| 145 |
+
lineType=cv2.LINE_AA,
|
| 146 |
+
)
|
| 147 |
+
return img
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def get_args():
|
| 151 |
+
parser = argparse.ArgumentParser()
|
| 152 |
+
parser.add_argument(
|
| 153 |
+
"--template_path",
|
| 154 |
+
type=str,
|
| 155 |
+
default=(project_path / "data/images/keyboard/g98-v2-pink/model/local/roller.png"),
|
| 156 |
+
)
|
| 157 |
+
parser.add_argument(
|
| 158 |
+
"--search_path",
|
| 159 |
+
type=str,
|
| 160 |
+
# default=(project_path / "data/images/keyboard/g98-v2-pink/model/keyboard1.png"),
|
| 161 |
+
default=(project_path / "data/images/keyboard/g98-v2-pink/keyboard1.jpg"),
|
| 162 |
+
)
|
| 163 |
+
parser.add_argument("--ratio", type=float, default=0.75)
|
| 164 |
+
parser.add_argument("--max_matches", type=int, default=120)
|
| 165 |
+
parser.add_argument("--max_keypoints", type=int, default=2000)
|
| 166 |
+
parser.add_argument("--draw_on_template", action="store_true")
|
| 167 |
+
return parser.parse_args()
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def main():
|
| 171 |
+
from toolbox.keypoint_match.keypoint_extracter import SiftExtractConfig, SiftKeyPointExtract
|
| 172 |
+
|
| 173 |
+
args = get_args()
|
| 174 |
+
|
| 175 |
+
template = cv2.imread(args.template_path)
|
| 176 |
+
search = cv2.imread(args.search_path)
|
| 177 |
+
if template is None:
|
| 178 |
+
raise FileNotFoundError(f"无法读取模板图: {args.template_path}")
|
| 179 |
+
if search is None:
|
| 180 |
+
raise FileNotFoundError(f"无法读取搜索图: {args.search_path}")
|
| 181 |
+
|
| 182 |
+
extractor = SiftKeyPointExtract(SiftExtractConfig(max_keypoints=int(args.max_keypoints)))
|
| 183 |
+
detector = SingleImageDetector(
|
| 184 |
+
extractor=extractor,
|
| 185 |
+
config=SingleImageDetectorConfig(
|
| 186 |
+
ratio=float(args.ratio),
|
| 187 |
+
max_matches=int(args.max_matches),
|
| 188 |
+
max_keypoints=int(args.max_keypoints),
|
| 189 |
+
draw_on_template=bool(args.draw_on_template),
|
| 190 |
+
),
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
result = detector.detect(template, search, template_id="roller")
|
| 194 |
+
|
| 195 |
+
title = f"single_image_detector | matches={result.matches.m}"
|
| 196 |
+
cv2.imshow(title, result.vis_search)
|
| 197 |
+
if result.vis_template is not None:
|
| 198 |
+
cv2.imshow("template_matched_keypoints", result.vis_template)
|
| 199 |
+
cv2.waitKey(0)
|
| 200 |
+
cv2.destroyAllWindows()
|
| 201 |
+
return
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
if __name__ == "__main__":
|
| 205 |
+
main()
|
| 206 |
+
|
toolbox/keypoint_match/keypoint_extracter/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from toolbox.keypoint_match.keypoint_extracter.sift import SiftExtractConfig, SiftKeyPointExtract
|
| 5 |
+
from toolbox.keypoint_match.keypoint_extracter.superpoint import SuperPointExtractConfig, SuperPointKeyPointExtract
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"SiftExtractConfig",
|
| 9 |
+
"SiftKeyPointExtract",
|
| 10 |
+
"SuperPointExtractConfig",
|
| 11 |
+
"SuperPointKeyPointExtract",
|
| 12 |
+
]
|
| 13 |
+
|
toolbox/keypoint_match/keypoint_extracter/sift.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Dict, Optional
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import cv2
|
| 11 |
+
|
| 12 |
+
from project_settings import project_path
|
| 13 |
+
from toolbox.keypoint_match.base import KeyPointExtract, KeyPointExtractConfig
|
| 14 |
+
from toolbox.keypoint_match.types import ImageArray, KeyPointSet
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass(frozen=True)
|
| 18 |
+
class SiftExtractConfig(KeyPointExtractConfig):
|
| 19 |
+
"""
|
| 20 |
+
OpenCV SIFT 的常用参数封装。
|
| 21 |
+
|
| 22 |
+
说明:OpenCV 的 SIFT_create 参数较多,这里保留最常见的几项;其余可放到 extra。
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
n_features: int = 0
|
| 26 |
+
n_octave_layers: int = 3
|
| 27 |
+
contrast_threshold: float = 0.04
|
| 28 |
+
edge_threshold: float = 10.0
|
| 29 |
+
sigma: float = 1.6
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class SiftKeyPointExtract(KeyPointExtract):
|
| 33 |
+
"""基于 OpenCV SIFT 的特征点/描述子提取器。"""
|
| 34 |
+
|
| 35 |
+
def __init__(self, config: Optional[SiftExtractConfig] = None):
|
| 36 |
+
super().__init__(config=config or SiftExtractConfig())
|
| 37 |
+
self.config: SiftExtractConfig
|
| 38 |
+
self._sift = self._create_sift()
|
| 39 |
+
|
| 40 |
+
def _create_sift(self):
|
| 41 |
+
try:
|
| 42 |
+
import cv2 # 延迟导入,避免无 OpenCV 时影响其它模块
|
| 43 |
+
except Exception as e: # pragma: no cover
|
| 44 |
+
raise ImportError("使用 SiftKeyPointExtract 需要先安装 opencv-python 或 opencv-contrib-python") from e
|
| 45 |
+
|
| 46 |
+
if not hasattr(cv2, "SIFT_create"):
|
| 47 |
+
raise RuntimeError(
|
| 48 |
+
"当前 OpenCV 不包含 SIFT(可能缺少 contrib 模块)。"
|
| 49 |
+
"请安装/替换为 `opencv-contrib-python`。"
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
cfg = self.config
|
| 53 |
+
return cv2.SIFT_create(
|
| 54 |
+
nfeatures=int(cfg.n_features),
|
| 55 |
+
nOctaveLayers=int(cfg.n_octave_layers),
|
| 56 |
+
contrastThreshold=float(cfg.contrast_threshold),
|
| 57 |
+
edgeThreshold=float(cfg.edge_threshold),
|
| 58 |
+
sigma=float(cfg.sigma),
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
@staticmethod
|
| 62 |
+
def _to_gray_u8(image: ImageArray) -> np.ndarray:
|
| 63 |
+
"""
|
| 64 |
+
SIFT 在 OpenCV 中通常使用 8-bit 灰度图。
|
| 65 |
+
- 输入 uint8: 直接处理(若为彩色则转灰)
|
| 66 |
+
- 输入 float: 若在 [0,1],缩放到 [0,255];否则 clip 到 [0,255]
|
| 67 |
+
"""
|
| 68 |
+
import cv2
|
| 69 |
+
|
| 70 |
+
img = np.asarray(image)
|
| 71 |
+
if img.ndim == 3 and img.shape[2] >= 3:
|
| 72 |
+
# OpenCV 读图一般是 BGR;这里不强制颜色空间,仅做灰度化
|
| 73 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
| 74 |
+
elif img.ndim != 2:
|
| 75 |
+
raise ValueError(f"image 需要是 HxW 或 HxWxC,但得到 {img.shape=}")
|
| 76 |
+
|
| 77 |
+
if img.dtype == np.uint8:
|
| 78 |
+
return img
|
| 79 |
+
|
| 80 |
+
img_f = img.astype(np.float32, copy=False)
|
| 81 |
+
if img_f.size > 0 and img_f.max() <= 1.0:
|
| 82 |
+
img_f = img_f * 255.0
|
| 83 |
+
img_u8 = np.clip(img_f, 0.0, 255.0).astype(np.uint8)
|
| 84 |
+
return img_u8
|
| 85 |
+
|
| 86 |
+
def extract(self, image: ImageArray) -> KeyPointSet:
|
| 87 |
+
import cv2
|
| 88 |
+
|
| 89 |
+
gray = self._to_gray_u8(image)
|
| 90 |
+
keypoints, descriptors = self._sift.detectAndCompute(gray, mask=None)
|
| 91 |
+
|
| 92 |
+
if keypoints is None or len(keypoints) == 0:
|
| 93 |
+
empty_xy = np.zeros((0, 2), dtype=np.float32)
|
| 94 |
+
return KeyPointSet(xy=empty_xy, descriptors=None, scores=None, meta={"backend": "opencv_sift"})
|
| 95 |
+
|
| 96 |
+
xy = np.array([kp.pt for kp in keypoints], dtype=np.float32) # (x,y)
|
| 97 |
+
scores = np.array([kp.response for kp in keypoints], dtype=np.float32)
|
| 98 |
+
|
| 99 |
+
if descriptors is not None:
|
| 100 |
+
descriptors = np.asarray(descriptors)
|
| 101 |
+
|
| 102 |
+
# 若设置了 max_keypoints,则按 response 排序截断(SIFT 本身也可能受 nfeatures 影响)
|
| 103 |
+
max_kp = self.config.max_keypoints
|
| 104 |
+
if max_kp is not None and xy.shape[0] > int(max_kp):
|
| 105 |
+
idx = np.argsort(-scores)[: int(max_kp)]
|
| 106 |
+
xy = xy[idx]
|
| 107 |
+
scores = scores[idx]
|
| 108 |
+
if descriptors is not None:
|
| 109 |
+
descriptors = descriptors[idx]
|
| 110 |
+
|
| 111 |
+
meta: Dict[str, Any] = {
|
| 112 |
+
"backend": "opencv_sift",
|
| 113 |
+
"n_features": int(self.config.n_features),
|
| 114 |
+
"n_octave_layers": int(self.config.n_octave_layers),
|
| 115 |
+
"contrast_threshold": float(self.config.contrast_threshold),
|
| 116 |
+
"edge_threshold": float(self.config.edge_threshold),
|
| 117 |
+
"sigma": float(self.config.sigma),
|
| 118 |
+
}
|
| 119 |
+
if self.config.extra:
|
| 120 |
+
meta["extra"] = dict(self.config.extra)
|
| 121 |
+
|
| 122 |
+
return KeyPointSet(
|
| 123 |
+
xy=xy,
|
| 124 |
+
descriptors=descriptors,
|
| 125 |
+
scores=scores,
|
| 126 |
+
meta=meta,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def get_args():
|
| 132 |
+
parser = argparse.ArgumentParser()
|
| 133 |
+
parser.add_argument(
|
| 134 |
+
"--image_path",
|
| 135 |
+
default=(project_path / "data/images/keyboard/g98-v2-pink/model/keyboard1.png").as_posix(),
|
| 136 |
+
# default=(project_path / "data/images/keyboard/g98-v2-pink/model/local/roller.png").as_posix(),
|
| 137 |
+
type=str
|
| 138 |
+
)
|
| 139 |
+
parser.add_argument("--max_keypoints", type=int, default=8000)
|
| 140 |
+
|
| 141 |
+
args = parser.parse_args()
|
| 142 |
+
return args
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def main():
|
| 146 |
+
args = get_args()
|
| 147 |
+
|
| 148 |
+
image = cv2.imread(args.image_path)
|
| 149 |
+
if image is None:
|
| 150 |
+
raise FileNotFoundError(f"无法读取图片: {args.image_path}")
|
| 151 |
+
|
| 152 |
+
extractor = SiftKeyPointExtract(
|
| 153 |
+
SiftExtractConfig(
|
| 154 |
+
max_keypoints=int(args.max_keypoints),
|
| 155 |
+
)
|
| 156 |
+
)
|
| 157 |
+
kp_set = extractor.extract(image)
|
| 158 |
+
|
| 159 |
+
cv2_keypoints = [
|
| 160 |
+
cv2.KeyPoint(
|
| 161 |
+
x=float(x),
|
| 162 |
+
y=float(y),
|
| 163 |
+
size=7,
|
| 164 |
+
response=float(score) if kp_set.scores is not None else 0.0,
|
| 165 |
+
)
|
| 166 |
+
for (x, y), score in zip(kp_set.xy, kp_set.scores if kp_set.scores is not None else np.zeros((kp_set.n,)))
|
| 167 |
+
]
|
| 168 |
+
|
| 169 |
+
image_with_kp = cv2.drawKeypoints(
|
| 170 |
+
image,
|
| 171 |
+
cv2_keypoints,
|
| 172 |
+
None,
|
| 173 |
+
color=(0, 0, 255),
|
| 174 |
+
flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
cv2.imshow("sift_keypoints", image_with_kp)
|
| 178 |
+
cv2.waitKey(0)
|
| 179 |
+
cv2.destroyAllWindows()
|
| 180 |
+
return
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
if __name__ == "__main__":
|
| 184 |
+
main()
|
toolbox/keypoint_match/keypoint_extracter/superpoint.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import os
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Any, Dict, Optional
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
from project_settings import project_path
|
| 14 |
+
from toolbox.keypoint_match.base import KeyPointExtract, KeyPointExtractConfig
|
| 15 |
+
from toolbox.keypoint_match.types import ImageArray, KeyPointSet
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass(frozen=True)
|
| 19 |
+
class SuperPointExtractConfig(KeyPointExtractConfig):
|
| 20 |
+
model_name: str = "magic-leap-community/superpoint"
|
| 21 |
+
model_cache_dir: str = (project_path / "../../hf_hub_models").as_posix()
|
| 22 |
+
device: str = "cpu" # "cpu" / "cuda"
|
| 23 |
+
# 设置 HuggingFace 镜像(与示例保持一致)
|
| 24 |
+
hf_endpoint: Optional[str] = "https://hf-mirror.com"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SuperPointKeyPointExtract(KeyPointExtract):
|
| 28 |
+
"""
|
| 29 |
+
基于 transformers 的 SuperPoint 特征点/描述子提取器。
|
| 30 |
+
|
| 31 |
+
参考:examples/keypoints/superpoint/test.py
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(self, config: Optional[SuperPointExtractConfig] = None):
|
| 35 |
+
super().__init__(config=config or SuperPointExtractConfig())
|
| 36 |
+
self.config: SuperPointExtractConfig
|
| 37 |
+
|
| 38 |
+
if self.config.hf_endpoint:
|
| 39 |
+
os.environ.setdefault("HF_ENDPOINT", str(self.config.hf_endpoint))
|
| 40 |
+
|
| 41 |
+
self._processor = None
|
| 42 |
+
self._model = None
|
| 43 |
+
|
| 44 |
+
def _lazy_init(self):
|
| 45 |
+
if self._processor is not None and self._model is not None:
|
| 46 |
+
return
|
| 47 |
+
|
| 48 |
+
import torch
|
| 49 |
+
from transformers import AutoImageProcessor
|
| 50 |
+
from transformers.models.superpoint.modeling_superpoint import SuperPointForKeypointDetection
|
| 51 |
+
|
| 52 |
+
self._processor = AutoImageProcessor.from_pretrained(
|
| 53 |
+
pretrained_model_name_or_path=self.config.model_name,
|
| 54 |
+
cache_dir=self.config.model_cache_dir,
|
| 55 |
+
)
|
| 56 |
+
self._model = SuperPointForKeypointDetection.from_pretrained(
|
| 57 |
+
pretrained_model_name_or_path=self.config.model_name,
|
| 58 |
+
cache_dir=self.config.model_cache_dir,
|
| 59 |
+
)
|
| 60 |
+
self._model.eval()
|
| 61 |
+
|
| 62 |
+
device = torch.device(self.config.device)
|
| 63 |
+
self._model.to(device)
|
| 64 |
+
|
| 65 |
+
@staticmethod
|
| 66 |
+
def _to_pil_rgb(image: ImageArray):
|
| 67 |
+
from PIL import Image
|
| 68 |
+
import cv2
|
| 69 |
+
|
| 70 |
+
img = np.asarray(image)
|
| 71 |
+
if img.ndim == 2:
|
| 72 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
|
| 73 |
+
elif img.ndim == 3 and img.shape[2] >= 3:
|
| 74 |
+
# 默认按 OpenCV 的 BGR 输入处理
|
| 75 |
+
img = cv2.cvtColor(img[:, :, :3], cv2.COLOR_BGR2RGB)
|
| 76 |
+
else:
|
| 77 |
+
raise ValueError(f"image 需要是 HxW 或 HxWxC,但得到 {img.shape=}")
|
| 78 |
+
|
| 79 |
+
if img.dtype != np.uint8:
|
| 80 |
+
img_f = img.astype(np.float32, copy=False)
|
| 81 |
+
if img_f.size > 0 and img_f.max() <= 1.0:
|
| 82 |
+
img_f = img_f * 255.0
|
| 83 |
+
img = np.clip(img_f, 0.0, 255.0).astype(np.uint8)
|
| 84 |
+
|
| 85 |
+
return Image.fromarray(img).convert("RGB")
|
| 86 |
+
|
| 87 |
+
def extract(self, image: ImageArray) -> KeyPointSet:
|
| 88 |
+
import torch
|
| 89 |
+
|
| 90 |
+
self._lazy_init()
|
| 91 |
+
assert self._processor is not None
|
| 92 |
+
assert self._model is not None
|
| 93 |
+
|
| 94 |
+
pil = self._to_pil_rgb(image)
|
| 95 |
+
inputs = self._processor(pil, return_tensors="pt")
|
| 96 |
+
|
| 97 |
+
device = next(self._model.parameters()).device
|
| 98 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 99 |
+
|
| 100 |
+
with torch.no_grad():
|
| 101 |
+
output = self._model(**inputs)
|
| 102 |
+
|
| 103 |
+
image_size = (pil.height, pil.width)
|
| 104 |
+
processed = self._processor.post_process_keypoint_detection(
|
| 105 |
+
output,
|
| 106 |
+
[image_size],
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
keypoints = processed[0]["keypoints"] # [N,2] (x,y)
|
| 110 |
+
scores = processed[0]["scores"] # [N]
|
| 111 |
+
descriptors = processed[0]["descriptors"] # [N,D]
|
| 112 |
+
|
| 113 |
+
keypoints_np = keypoints.detach().cpu().numpy().astype(np.float32)
|
| 114 |
+
scores_np = scores.detach().cpu().numpy().astype(np.float32)
|
| 115 |
+
desc_np = descriptors.detach().cpu().numpy().astype(np.float32)
|
| 116 |
+
|
| 117 |
+
# 统一按 scores 截断到 max_keypoints
|
| 118 |
+
max_kp = self.config.max_keypoints
|
| 119 |
+
if max_kp is not None and keypoints_np.shape[0] > int(max_kp):
|
| 120 |
+
idx = np.argsort(-scores_np)[: int(max_kp)]
|
| 121 |
+
keypoints_np = keypoints_np[idx]
|
| 122 |
+
scores_np = scores_np[idx]
|
| 123 |
+
desc_np = desc_np[idx]
|
| 124 |
+
|
| 125 |
+
meta: Dict[str, Any] = {
|
| 126 |
+
"backend": "transformers_superpoint",
|
| 127 |
+
"model_name": str(self.config.model_name),
|
| 128 |
+
"device": str(self.config.device),
|
| 129 |
+
}
|
| 130 |
+
if self.config.extra:
|
| 131 |
+
meta["extra"] = dict(self.config.extra)
|
| 132 |
+
|
| 133 |
+
return KeyPointSet(
|
| 134 |
+
xy=keypoints_np,
|
| 135 |
+
descriptors=desc_np,
|
| 136 |
+
scores=scores_np,
|
| 137 |
+
meta=meta,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def main():
|
| 142 |
+
import cv2
|
| 143 |
+
|
| 144 |
+
parser = argparse.ArgumentParser()
|
| 145 |
+
parser.add_argument(
|
| 146 |
+
"--image_path",
|
| 147 |
+
type=str,
|
| 148 |
+
default=(project_path / "data/images/keyboard/g98-v2-pink/model/keyboard1.png").as_posix(),
|
| 149 |
+
)
|
| 150 |
+
parser.add_argument("--device", type=str, default="cpu")
|
| 151 |
+
parser.add_argument("--max_keypoints", type=int, default=2000)
|
| 152 |
+
args = parser.parse_args()
|
| 153 |
+
|
| 154 |
+
image = cv2.imread(args.image_path)
|
| 155 |
+
if image is None:
|
| 156 |
+
raise FileNotFoundError(f"无法读取图片: {args.image_path}")
|
| 157 |
+
|
| 158 |
+
extractor = SuperPointKeyPointExtract(
|
| 159 |
+
SuperPointExtractConfig(
|
| 160 |
+
device=str(args.device),
|
| 161 |
+
max_keypoints=int(args.max_keypoints),
|
| 162 |
+
)
|
| 163 |
+
)
|
| 164 |
+
kp_set = extractor.extract(image)
|
| 165 |
+
|
| 166 |
+
cv2_keypoints = [
|
| 167 |
+
cv2.KeyPoint(
|
| 168 |
+
x=float(x),
|
| 169 |
+
y=float(y),
|
| 170 |
+
size=7,
|
| 171 |
+
response=float(score),
|
| 172 |
+
)
|
| 173 |
+
for (x, y), score in zip(kp_set.xy, kp_set.scores if kp_set.scores is not None else np.zeros((kp_set.n,), dtype=np.float32))
|
| 174 |
+
]
|
| 175 |
+
|
| 176 |
+
image_with_kp = cv2.drawKeypoints(
|
| 177 |
+
image,
|
| 178 |
+
cv2_keypoints,
|
| 179 |
+
None,
|
| 180 |
+
color=(0, 0, 255),
|
| 181 |
+
flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS,
|
| 182 |
+
)
|
| 183 |
+
cv2.imshow(f"superpoint_keypoints | n={kp_set.n}", image_with_kp)
|
| 184 |
+
cv2.waitKey(0)
|
| 185 |
+
cv2.destroyAllWindows()
|
| 186 |
+
return
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
if __name__ == "__main__":
|
| 190 |
+
main()
|
| 191 |
+
|
toolbox/keypoint_match/keypoint_match/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from toolbox.keypoint_match.keypoint_match.single_image_match import SingleImageMatcher, SingleImageMatcherConfig
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"SingleImageMatcher",
|
| 8 |
+
"SingleImageMatcherConfig",
|
| 9 |
+
]
|
| 10 |
+
|
toolbox/keypoint_match/keypoint_match/single_image_match.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Dict, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import cv2
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
from project_settings import project_path
|
| 13 |
+
from toolbox.keypoint_match.base import KeyPointExtract
|
| 14 |
+
from toolbox.keypoint_match.types import ImageArray, KeyPointSet, MatchSet
|
| 15 |
+
from toolbox.keypoint_match.keypoint_extracter import SiftExtractConfig, SiftKeyPointExtract
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass(frozen=True)
|
| 19 |
+
class SingleImageMatcherConfig:
|
| 20 |
+
"""
|
| 21 |
+
单模板图 vs 单大图 的匹配配置。
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
# KNN ratio test 阈值(Lowe's ratio test),越小越严格
|
| 25 |
+
ratio: float = 0.75
|
| 26 |
+
# 最多保留多少条匹配(按 distance 从小到大截断)
|
| 27 |
+
max_matches: Optional[int] = 500
|
| 28 |
+
# 是否做 mutual check(A->B 与 B->A 互相最近邻一致才保留)
|
| 29 |
+
mutual_check: bool = False
|
| 30 |
+
extra: Optional[Dict[str, Any]] = None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class SingleImageMatcher:
|
| 34 |
+
"""
|
| 35 |
+
匹配类:输入一张目标小图 + 一张包含目标的大图,输出匹配对。
|
| 36 |
+
|
| 37 |
+
- 初始化时传入关键点提取器(例如 `SiftKeyPointExtract` / SuperPoint 提取器等)
|
| 38 |
+
- `match(...)` 会提取两张图的 KeyPointSet,然后用 OpenCV BFMatcher 做匹配并输出 MatchSet
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(self, extractor: KeyPointExtract, config: Optional[SingleImageMatcherConfig] = None):
|
| 42 |
+
self.extractor = extractor
|
| 43 |
+
self.config = config or SingleImageMatcherConfig()
|
| 44 |
+
|
| 45 |
+
@staticmethod
|
| 46 |
+
def _infer_norm_type(desc: np.ndarray) -> int:
|
| 47 |
+
import cv2
|
| 48 |
+
|
| 49 |
+
# ORB/BRIEF 等通常是 uint8 的二进制描述子;SIFT/SuperPoint 等通常是 float32
|
| 50 |
+
if desc.dtype == np.uint8:
|
| 51 |
+
return cv2.NORM_HAMMING
|
| 52 |
+
return cv2.NORM_L2
|
| 53 |
+
|
| 54 |
+
def _bfmatcher(self, query_desc: np.ndarray):
|
| 55 |
+
import cv2
|
| 56 |
+
|
| 57 |
+
norm = self._infer_norm_type(query_desc)
|
| 58 |
+
# mutual_check=True 时,OpenCV 的 crossCheck 只能用于 match()(不能用于 knnMatch)
|
| 59 |
+
return cv2.BFMatcher(normType=norm, crossCheck=False)
|
| 60 |
+
|
| 61 |
+
def _knn_ratio_match(self, bf, query_desc: np.ndarray, train_desc: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 62 |
+
"""
|
| 63 |
+
返回 (query_idx, train_idx, distance) 的一维数组
|
| 64 |
+
"""
|
| 65 |
+
knn = bf.knnMatch(query_desc, train_desc, k=2)
|
| 66 |
+
q_idx = []
|
| 67 |
+
t_idx = []
|
| 68 |
+
dist = []
|
| 69 |
+
ratio = float(self.config.ratio)
|
| 70 |
+
for pair in knn:
|
| 71 |
+
if len(pair) < 2:
|
| 72 |
+
continue
|
| 73 |
+
m, n = pair[0], pair[1]
|
| 74 |
+
if m.distance < ratio * n.distance:
|
| 75 |
+
q_idx.append(int(m.queryIdx))
|
| 76 |
+
t_idx.append(int(m.trainIdx))
|
| 77 |
+
dist.append(float(m.distance))
|
| 78 |
+
|
| 79 |
+
if len(q_idx) == 0:
|
| 80 |
+
return (
|
| 81 |
+
np.zeros((0,), dtype=np.int64),
|
| 82 |
+
np.zeros((0,), dtype=np.int64),
|
| 83 |
+
np.zeros((0,), dtype=np.float32),
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
q_idx_arr = np.asarray(q_idx, dtype=np.int64)
|
| 87 |
+
t_idx_arr = np.asarray(t_idx, dtype=np.int64)
|
| 88 |
+
dist_arr = np.asarray(dist, dtype=np.float32)
|
| 89 |
+
|
| 90 |
+
# 按 distance 从小到大排序,并截断
|
| 91 |
+
order = np.argsort(dist_arr)
|
| 92 |
+
if self.config.max_matches is not None:
|
| 93 |
+
order = order[: int(self.config.max_matches)]
|
| 94 |
+
return q_idx_arr[order], t_idx_arr[order], dist_arr[order]
|
| 95 |
+
|
| 96 |
+
@staticmethod
|
| 97 |
+
def _mutual_filter(
|
| 98 |
+
q_to_t: Tuple[np.ndarray, np.ndarray, np.ndarray],
|
| 99 |
+
t_to_q: Tuple[np.ndarray, np.ndarray, np.ndarray],
|
| 100 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 101 |
+
"""
|
| 102 |
+
q_to_t: (q_idx, t_idx, dist) from query->train
|
| 103 |
+
t_to_q: (t_idx, q_idx, dist) from train->query (注意顺序)
|
| 104 |
+
保留互为最近邻/通过 ratio 的匹配对。
|
| 105 |
+
"""
|
| 106 |
+
q_idx, t_idx, dist = q_to_t
|
| 107 |
+
t_idx2, q_idx2, _ = t_to_q
|
| 108 |
+
if q_idx.size == 0 or t_idx2.size == 0:
|
| 109 |
+
return q_idx, t_idx, dist
|
| 110 |
+
|
| 111 |
+
# 用集合做互相包含过滤
|
| 112 |
+
pairs_qt = {(int(q), int(t)) for q, t in zip(q_idx.tolist(), t_idx.tolist())}
|
| 113 |
+
pairs_tq = {(int(q), int(t)) for t, q in zip(t_idx2.tolist(), q_idx2.tolist())} # 统一成 (q,t)
|
| 114 |
+
keep_pairs = pairs_qt.intersection(pairs_tq)
|
| 115 |
+
if not keep_pairs:
|
| 116 |
+
return (
|
| 117 |
+
np.zeros((0,), dtype=np.int64),
|
| 118 |
+
np.zeros((0,), dtype=np.int64),
|
| 119 |
+
np.zeros((0,), dtype=np.float32),
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
keep_mask = np.array([(int(q), int(t)) in keep_pairs for q, t in zip(q_idx, t_idx)], dtype=bool)
|
| 123 |
+
return q_idx[keep_mask], t_idx[keep_mask], dist[keep_mask]
|
| 124 |
+
|
| 125 |
+
def match(
|
| 126 |
+
self,
|
| 127 |
+
template_image: ImageArray,
|
| 128 |
+
search_image: ImageArray,
|
| 129 |
+
*,
|
| 130 |
+
template_id: Optional[str] = None,
|
| 131 |
+
) -> Tuple[KeyPointSet, KeyPointSet, MatchSet]:
|
| 132 |
+
"""
|
| 133 |
+
返回:(template_kp, search_kp, matches)
|
| 134 |
+
"""
|
| 135 |
+
template_kp = self.extractor.extract(template_image)
|
| 136 |
+
search_kp = self.extractor.extract(search_image)
|
| 137 |
+
|
| 138 |
+
q_desc = np.asarray(template_kp.descriptors) if template_kp.descriptors is not None else np.zeros((0, 0), dtype=np.float32)
|
| 139 |
+
t_desc = np.asarray(search_kp.descriptors) if search_kp.descriptors is not None else np.zeros((0, 0), dtype=np.float32)
|
| 140 |
+
if q_desc.size == 0 or t_desc.size == 0:
|
| 141 |
+
matches = MatchSet(
|
| 142 |
+
query_idx=np.zeros((0,), dtype=np.int64),
|
| 143 |
+
train_idx=np.zeros((0,), dtype=np.int64),
|
| 144 |
+
distance=np.zeros((0,), dtype=np.float32),
|
| 145 |
+
meta={"backend": "opencv_bf_knn_ratio", "template_id": template_id},
|
| 146 |
+
)
|
| 147 |
+
return template_kp, search_kp, matches
|
| 148 |
+
|
| 149 |
+
bf = self._bfmatcher(q_desc)
|
| 150 |
+
q_to_t = self._knn_ratio_match(bf, q_desc, t_desc)
|
| 151 |
+
|
| 152 |
+
if self.config.mutual_check:
|
| 153 |
+
bf2 = self._bfmatcher(t_desc)
|
| 154 |
+
t_to_q = self._knn_ratio_match(bf2, t_desc, q_desc)
|
| 155 |
+
q_idx, t_idx, dist = self._mutual_filter(q_to_t, t_to_q)
|
| 156 |
+
else:
|
| 157 |
+
q_idx, t_idx, dist = q_to_t
|
| 158 |
+
|
| 159 |
+
meta: Dict[str, Any] = {
|
| 160 |
+
"backend": "opencv_bf_knn_ratio",
|
| 161 |
+
"ratio": float(self.config.ratio),
|
| 162 |
+
"mutual_check": bool(self.config.mutual_check),
|
| 163 |
+
"template_id": template_id,
|
| 164 |
+
}
|
| 165 |
+
if self.config.max_matches is not None:
|
| 166 |
+
meta["max_matches"] = int(self.config.max_matches)
|
| 167 |
+
if self.config.extra:
|
| 168 |
+
meta["extra"] = dict(self.config.extra)
|
| 169 |
+
|
| 170 |
+
matches = MatchSet(
|
| 171 |
+
query_idx=q_idx,
|
| 172 |
+
train_idx=t_idx,
|
| 173 |
+
distance=dist,
|
| 174 |
+
meta=meta,
|
| 175 |
+
)
|
| 176 |
+
return template_kp, search_kp, matches
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def _to_cv2_keypoints(kp_set: KeyPointSet):
|
| 180 |
+
import cv2
|
| 181 |
+
|
| 182 |
+
scores = (
|
| 183 |
+
kp_set.scores
|
| 184 |
+
if kp_set.scores is not None
|
| 185 |
+
else np.zeros((kp_set.n,), dtype=np.float32)
|
| 186 |
+
)
|
| 187 |
+
return [
|
| 188 |
+
cv2.KeyPoint(
|
| 189 |
+
x=float(x),
|
| 190 |
+
y=float(y),
|
| 191 |
+
size=7,
|
| 192 |
+
response=float(score),
|
| 193 |
+
)
|
| 194 |
+
for (x, y), score in zip(kp_set.xy, scores)
|
| 195 |
+
]
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def _to_cv2_matches(match_set: MatchSet):
|
| 199 |
+
import cv2
|
| 200 |
+
|
| 201 |
+
dist = (
|
| 202 |
+
match_set.distance
|
| 203 |
+
if match_set.distance is not None
|
| 204 |
+
else np.zeros((match_set.m,), dtype=np.float32)
|
| 205 |
+
)
|
| 206 |
+
return [
|
| 207 |
+
cv2.DMatch(
|
| 208 |
+
_queryIdx=int(q),
|
| 209 |
+
_trainIdx=int(t),
|
| 210 |
+
_imgIdx=0,
|
| 211 |
+
_distance=float(d),
|
| 212 |
+
)
|
| 213 |
+
for q, t, d in zip(match_set.query_idx, match_set.train_idx, dist)
|
| 214 |
+
]
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def get_args():
|
| 218 |
+
parser = argparse.ArgumentParser()
|
| 219 |
+
parser.add_argument(
|
| 220 |
+
"--template_path",
|
| 221 |
+
type=str,
|
| 222 |
+
default=(project_path / "data/images/keyboard/g98-v2-pink/model/local/roller.png"),
|
| 223 |
+
)
|
| 224 |
+
parser.add_argument(
|
| 225 |
+
"--search_path",
|
| 226 |
+
type=str,
|
| 227 |
+
default=(project_path / "data/images/keyboard/g98-v2-pink/model/keyboard1.png"),
|
| 228 |
+
)
|
| 229 |
+
parser.add_argument("--ratio", type=float, default=0.75)
|
| 230 |
+
parser.add_argument("--max_matches", type=int, default=80)
|
| 231 |
+
parser.add_argument("--max_keypoints", type=int, default=2000)
|
| 232 |
+
args = parser.parse_args()
|
| 233 |
+
|
| 234 |
+
return args
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def main():
|
| 238 |
+
args = get_args()
|
| 239 |
+
|
| 240 |
+
template = cv2.imread(args.template_path)
|
| 241 |
+
search = cv2.imread(args.search_path)
|
| 242 |
+
if template is None:
|
| 243 |
+
raise FileNotFoundError(f"无法读取模板图: {args.template_path}")
|
| 244 |
+
if search is None:
|
| 245 |
+
raise FileNotFoundError(f"无法读取搜索图: {args.search_path}")
|
| 246 |
+
|
| 247 |
+
extractor = SiftKeyPointExtract(SiftExtractConfig(max_keypoints=int(args.max_keypoints)))
|
| 248 |
+
matcher = SingleImageMatcher(
|
| 249 |
+
extractor=extractor,
|
| 250 |
+
config=SingleImageMatcherConfig(
|
| 251 |
+
ratio=float(args.ratio),
|
| 252 |
+
max_matches=int(args.max_matches),
|
| 253 |
+
),
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
template_kp, search_kp, matches = matcher.match(template, search)
|
| 257 |
+
|
| 258 |
+
template_cv2_kp = _to_cv2_keypoints(template_kp)
|
| 259 |
+
search_cv2_kp = _to_cv2_keypoints(search_kp)
|
| 260 |
+
cv2_matches = _to_cv2_matches(matches)
|
| 261 |
+
|
| 262 |
+
vis = cv2.drawMatches(
|
| 263 |
+
template,
|
| 264 |
+
template_cv2_kp,
|
| 265 |
+
search,
|
| 266 |
+
search_cv2_kp,
|
| 267 |
+
cv2_matches,
|
| 268 |
+
None,
|
| 269 |
+
flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
title = f"single_match | kp_t={template_kp.n} kp_s={search_kp.n} matches={matches.m}"
|
| 273 |
+
cv2.imshow(title, vis)
|
| 274 |
+
cv2.waitKey(0)
|
| 275 |
+
cv2.destroyAllWindows()
|
| 276 |
+
return
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
if __name__ == "__main__":
|
| 280 |
+
main()
|
toolbox/keypoint_match/types.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Mapping, Optional, Sequence, Tuple
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
ArrayF32 = np.ndarray
|
| 13 |
+
ArrayU8 = np.ndarray
|
| 14 |
+
ImageArray = np.ndarray # HxWxC or HxW, dtype不限(uint8/float32均可)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass(frozen=True)
|
| 18 |
+
class KeyPointSet:
|
| 19 |
+
"""
|
| 20 |
+
统一表示一张图片上的关键点与(可选)描述子。
|
| 21 |
+
|
| 22 |
+
约定:
|
| 23 |
+
- xy: shape [N, 2],列为 (x, y),float32/float64
|
| 24 |
+
- descriptors: shape [N, D](可选)。例如 ORB: uint8;SIFT/SuperPoint: float32
|
| 25 |
+
- scores: shape [N](可选),越大表示点越“强”
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
xy: ArrayF32
|
| 29 |
+
descriptors: Optional[np.ndarray] = None
|
| 30 |
+
scores: Optional[ArrayF32] = None
|
| 31 |
+
meta: Optional[Mapping[str, Any]] = None
|
| 32 |
+
|
| 33 |
+
def __post_init__(self) -> None:
|
| 34 |
+
xy = np.asarray(self.xy)
|
| 35 |
+
if xy.ndim != 2 or xy.shape[1] != 2:
|
| 36 |
+
raise ValueError(f"KeyPointSet.xy 必须是 [N,2],但得到 {xy.shape=}")
|
| 37 |
+
if self.descriptors is not None:
|
| 38 |
+
desc = np.asarray(self.descriptors)
|
| 39 |
+
if desc.ndim != 2 or desc.shape[0] != xy.shape[0]:
|
| 40 |
+
raise ValueError(
|
| 41 |
+
"KeyPointSet.descriptors 必须是 [N,D] 且与 xy 的 N 一致,"
|
| 42 |
+
f"但得到 {desc.shape=} vs {xy.shape=}"
|
| 43 |
+
)
|
| 44 |
+
if self.scores is not None:
|
| 45 |
+
sc = np.asarray(self.scores)
|
| 46 |
+
if sc.ndim != 1 or sc.shape[0] != xy.shape[0]:
|
| 47 |
+
raise ValueError(
|
| 48 |
+
"KeyPointSet.scores 必须是 [N] 且与 xy 的 N 一致,"
|
| 49 |
+
f"但得到 {sc.shape=} vs {xy.shape=}"
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
@property
|
| 53 |
+
def n(self) -> int:
|
| 54 |
+
return int(np.asarray(self.xy).shape[0])
|
| 55 |
+
|
| 56 |
+
@property
|
| 57 |
+
def has_descriptors(self) -> bool:
|
| 58 |
+
return self.descriptors is not None
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@dataclass(frozen=True)
|
| 62 |
+
class MatchSet:
|
| 63 |
+
"""
|
| 64 |
+
统一表示两组关键点之间的匹配关系。
|
| 65 |
+
|
| 66 |
+
- query_idx: 对应“模板/小图”关键点的索引,shape [M]
|
| 67 |
+
- train_idx: 对应“大图/搜索图”关键点的索引,shape [M]
|
| 68 |
+
- distance: 该匹配的距离/相似度度量,shape [M](越小越相似是最常见约定)
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
query_idx: np.ndarray
|
| 72 |
+
train_idx: np.ndarray
|
| 73 |
+
distance: Optional[np.ndarray] = None
|
| 74 |
+
meta: Optional[Mapping[str, Any]] = None
|
| 75 |
+
|
| 76 |
+
def __post_init__(self) -> None:
|
| 77 |
+
qi = np.asarray(self.query_idx)
|
| 78 |
+
ti = np.asarray(self.train_idx)
|
| 79 |
+
if qi.ndim != 1 or ti.ndim != 1 or qi.shape[0] != ti.shape[0]:
|
| 80 |
+
raise ValueError(f"MatchSet 索引必须是同长度的一维数组,但得到 {qi.shape=} {ti.shape=}")
|
| 81 |
+
if self.distance is not None:
|
| 82 |
+
dist = np.asarray(self.distance)
|
| 83 |
+
if dist.ndim != 1 or dist.shape[0] != qi.shape[0]:
|
| 84 |
+
raise ValueError(f"MatchSet.distance 必须是 [M],但得到 {dist.shape=} vs {qi.shape=}")
|
| 85 |
+
|
| 86 |
+
@property
|
| 87 |
+
def m(self) -> int:
|
| 88 |
+
return int(np.asarray(self.query_idx).shape[0])
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@dataclass(frozen=True)
|
| 92 |
+
class Detection:
|
| 93 |
+
"""
|
| 94 |
+
一次“模板在大图中被找到”的候选结果。
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
bbox_xyxy: Tuple[float, float, float, float] # (x1,y1,x2,y2)
|
| 98 |
+
score: float
|
| 99 |
+
match_count: int
|
| 100 |
+
template_id: Optional[str] = None
|
| 101 |
+
meta: Optional[Mapping[str, Any]] = None
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@dataclass(frozen=True)
|
| 105 |
+
class DetectionResult:
|
| 106 |
+
"""
|
| 107 |
+
一次检测(一个模板对一张大图)输出的结果集合。
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
detections: Sequence[Detection]
|
| 111 |
+
meta: Optional[Mapping[str, Any]] = None
|
| 112 |
+
|
toolbox/os/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
if __name__ == '__main__':
|
| 6 |
+
pass
|
toolbox/os/command.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Command(object):
|
| 7 |
+
custom_command = [
|
| 8 |
+
"cd"
|
| 9 |
+
]
|
| 10 |
+
|
| 11 |
+
@staticmethod
|
| 12 |
+
def _get_cmd(command):
|
| 13 |
+
command = str(command).strip()
|
| 14 |
+
if command == "":
|
| 15 |
+
return None
|
| 16 |
+
cmd_and_args = command.split(sep=" ")
|
| 17 |
+
cmd = cmd_and_args[0]
|
| 18 |
+
args = " ".join(cmd_and_args[1:])
|
| 19 |
+
return cmd, args
|
| 20 |
+
|
| 21 |
+
@classmethod
|
| 22 |
+
def popen(cls, command):
|
| 23 |
+
cmd, args = cls._get_cmd(command)
|
| 24 |
+
if cmd in cls.custom_command:
|
| 25 |
+
method = getattr(cls, cmd)
|
| 26 |
+
return method(args)
|
| 27 |
+
else:
|
| 28 |
+
resp = os.popen(command)
|
| 29 |
+
result = resp.read()
|
| 30 |
+
resp.close()
|
| 31 |
+
return result
|
| 32 |
+
|
| 33 |
+
@classmethod
|
| 34 |
+
def cd(cls, args):
|
| 35 |
+
if args.startswith("/"):
|
| 36 |
+
os.chdir(args)
|
| 37 |
+
else:
|
| 38 |
+
pwd = os.getcwd()
|
| 39 |
+
path = os.path.join(pwd, args)
|
| 40 |
+
os.chdir(path)
|
| 41 |
+
|
| 42 |
+
@classmethod
|
| 43 |
+
def system(cls, command):
|
| 44 |
+
return os.system(command)
|
| 45 |
+
|
| 46 |
+
def __init__(self):
|
| 47 |
+
pass
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def ps_ef_grep(keyword: str):
|
| 51 |
+
cmd = "ps -ef | grep {}".format(keyword)
|
| 52 |
+
rows = Command.popen(cmd)
|
| 53 |
+
rows = str(rows).split("\n")
|
| 54 |
+
rows = [row for row in rows if row.__contains__(keyword) and not row.__contains__("grep")]
|
| 55 |
+
return rows
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
if __name__ == "__main__":
|
| 59 |
+
pass
|
toolbox/os/environment.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
from dotenv.main import DotEnv
|
| 8 |
+
|
| 9 |
+
from toolbox.json.misc import traverse
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class EnvironmentManager(object):
|
| 13 |
+
def __init__(self, path, env, override=False):
|
| 14 |
+
filename = os.path.join(path, '{}.env'.format(env))
|
| 15 |
+
self.filename = filename
|
| 16 |
+
|
| 17 |
+
load_dotenv(
|
| 18 |
+
dotenv_path=filename,
|
| 19 |
+
override=override
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
self._environ = dict()
|
| 23 |
+
|
| 24 |
+
def open_dotenv(self, filename: str = None):
|
| 25 |
+
filename = filename or self.filename
|
| 26 |
+
dotenv = DotEnv(
|
| 27 |
+
dotenv_path=filename,
|
| 28 |
+
stream=None,
|
| 29 |
+
verbose=False,
|
| 30 |
+
interpolate=False,
|
| 31 |
+
override=False,
|
| 32 |
+
encoding="utf-8",
|
| 33 |
+
)
|
| 34 |
+
result = dotenv.dict()
|
| 35 |
+
return result
|
| 36 |
+
|
| 37 |
+
def get(self, key, default=None, dtype=str):
|
| 38 |
+
result = os.environ.get(key)
|
| 39 |
+
if result is None:
|
| 40 |
+
if default is None:
|
| 41 |
+
result = None
|
| 42 |
+
else:
|
| 43 |
+
result = default
|
| 44 |
+
else:
|
| 45 |
+
result = dtype(result)
|
| 46 |
+
self._environ[key] = result
|
| 47 |
+
return result
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
_DEFAULT_DTYPE_MAP = {
|
| 51 |
+
'int': int,
|
| 52 |
+
'float': float,
|
| 53 |
+
'str': str,
|
| 54 |
+
'json.loads': json.loads
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class JsonConfig(object):
|
| 59 |
+
"""
|
| 60 |
+
将 json 中, 形如 `$float:threshold` 的值, 处理为:
|
| 61 |
+
从环境变量中查到 threshold, 再将其转换为 float 类型.
|
| 62 |
+
"""
|
| 63 |
+
def __init__(self, dtype_map: dict = None, environment: EnvironmentManager = None):
|
| 64 |
+
self.dtype_map = dtype_map or _DEFAULT_DTYPE_MAP
|
| 65 |
+
self.environment = environment or os.environ
|
| 66 |
+
|
| 67 |
+
def sanitize_by_filename(self, filename: str):
|
| 68 |
+
with open(filename, 'r', encoding='utf-8') as f:
|
| 69 |
+
js = json.load(f)
|
| 70 |
+
|
| 71 |
+
return self.sanitize_by_json(js)
|
| 72 |
+
|
| 73 |
+
def sanitize_by_json(self, js):
|
| 74 |
+
js = traverse(
|
| 75 |
+
js,
|
| 76 |
+
callback=self.sanitize,
|
| 77 |
+
environment=self.environment
|
| 78 |
+
)
|
| 79 |
+
return js
|
| 80 |
+
|
| 81 |
+
def sanitize(self, string, environment):
|
| 82 |
+
"""支持 $ 符开始的, 环境变量配置"""
|
| 83 |
+
if isinstance(string, str) and string.startswith('$'):
|
| 84 |
+
dtype, key = string[1:].split(':')
|
| 85 |
+
dtype = self.dtype_map[dtype]
|
| 86 |
+
|
| 87 |
+
value = environment.get(key)
|
| 88 |
+
if value is None:
|
| 89 |
+
raise AssertionError('environment not exist. key: {}'.format(key))
|
| 90 |
+
|
| 91 |
+
value = dtype(value)
|
| 92 |
+
result = value
|
| 93 |
+
else:
|
| 94 |
+
result = string
|
| 95 |
+
return result
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def demo1():
|
| 99 |
+
import json
|
| 100 |
+
|
| 101 |
+
from project_settings import project_path
|
| 102 |
+
|
| 103 |
+
environment = EnvironmentManager(
|
| 104 |
+
path=os.path.join(project_path, 'server/callbot_server/dotenv'),
|
| 105 |
+
env='dev',
|
| 106 |
+
)
|
| 107 |
+
init_scenes = environment.get(key='init_scenes', dtype=json.loads)
|
| 108 |
+
print(init_scenes)
|
| 109 |
+
print(environment._environ)
|
| 110 |
+
return
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
if __name__ == '__main__':
|
| 114 |
+
demo1()
|
toolbox/os/other.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import inspect
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def pwd():
|
| 6 |
+
"""你在哪个文件调用此函数, 它就会返回那个文件所在的 dir 目标"""
|
| 7 |
+
frame = inspect.stack()[1]
|
| 8 |
+
module = inspect.getmodule(frame[0])
|
| 9 |
+
return os.path.dirname(os.path.abspath(module.__file__))
|