Spaces:
Runtime error
Runtime error
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| from __future__ import annotations | |
| from abc import ABC, abstractmethod | |
| from dataclasses import dataclass | |
| from typing import Any, Mapping, Optional, Sequence, Tuple | |
| import numpy as np | |
| from toolbox.keypoint_match.types import ( | |
| DetectionResult, | |
| ImageArray, | |
| KeyPointSet, | |
| MatchSet, | |
| ) | |
@dataclass
class KeyPointExtractConfig:
    """Common configuration for keypoint/descriptor extractors."""

    # Upper bound on the number of keypoints to keep; None means no explicit limit.
    max_keypoints: Optional[int] = None
    # Implementation-specific extra parameters, e.g. nms_radius,
    # score_threshold, scale_pyramid.
    extra: Optional[Mapping[str, Any]] = None
class KeyPointExtract(ABC):
    """
    Abstract interface: extract keypoints and descriptors from a single image.

    Intended adapters:
    - OpenCV: ORB / SIFT / AKAZE, etc.
    - Kornia: KeyNetAffNetHardNet, DISK, etc.
    - Deep models such as SuperPoint / LightGlue.
    """

    def __init__(self, config: Optional[KeyPointExtractConfig] = None):
        # Use an explicit None check so only a missing config falls back to defaults.
        self.config = config if config is not None else KeyPointExtractConfig()

    @abstractmethod
    def extract(self, image: ImageArray) -> KeyPointSet:
        """Take one image and return its keypoint set (with optional descriptors)."""

    def batch_extract(self, images: Sequence[ImageArray]) -> Sequence[KeyPointSet]:
        """Apply ``extract`` to each image in order and return the results as a list."""
        return [self.extract(im) for im in images]
@dataclass
class KeyPointMatchConfig:
    """Common configuration for matchers."""

    # Common case: threshold for the ratio test (if the implementation uses KNN matching).
    ratio: Optional[float] = None
    # Lets implementations cap the number of returned matches.
    max_matches: Optional[int] = None
    # Implementation-specific extra parameters.
    extra: Optional[Mapping[str, Any]] = None
class KeyPointMatch(ABC):
    """
    Abstract interface: match two images (or two descriptor sets) and return
    the matched pairs.
    """

    def __init__(self, config: Optional[KeyPointMatchConfig] = None):
        # Use an explicit None check so only a missing config falls back to defaults.
        self.config = config if config is not None else KeyPointMatchConfig()

    @abstractmethod
    def match(self, query: KeyPointSet, train: KeyPointSet) -> MatchSet:
        """
        query: typically from the "small image / template".
        train: typically from the "large image / search image".
        """
@dataclass
class RegionScorerConfig:
    """
    Common configuration for grouping matched points into "regions".

    The core idea is: "a region with an unusually large number of matches
    => the target has been found". How to cluster, score and produce a bbox
    is therefore abstracted into RegionScorer.
    """

    # Clustering radius on the large image, in pixels
    # (e.g. DBSCAN eps, or the cell_size for grid voting).
    radius_px: float = 24.0
    # Minimum number of matches for a region to count as "found".
    min_match_count: int = 12
    # Maximum number of candidate regions to output.
    topk: int = 10
    # Implementation-specific extra parameters.
    extra: Optional[Mapping[str, Any]] = None
class RegionScorer(ABC):
    """
    Abstract interface: map match relations to candidate regions (bbox + score).

    The inputs include keypoint coordinates and match pairs, so an
    implementation can:
    - do simple grid voting / density clustering;
    - filter outliers with a homography / RANSAC before clustering;
    - score using the local consistency of the matched points.
    """

    def __init__(self, config: Optional[RegionScorerConfig] = None):
        # Use an explicit None check so only a missing config falls back to defaults.
        self.config = config if config is not None else RegionScorerConfig()

    @abstractmethod
    def score(
        self,
        query: KeyPointSet,
        train: KeyPointSet,
        matches: MatchSet,
        *,
        template_id: Optional[str] = None,
        template_size: Optional[Tuple[int, int]] = None,  # (h, w)
    ) -> DetectionResult:
        """
        template_size: size (h, w) of the small image; optionally used to
        infer the bbox dimensions from the matched points.
        """
class KeyPointTemplateDetector(ABC):
    """
    Abstract interface: combine (extractor + matcher + region scorer) into
    "template detection".

    A concrete detector can be used for:
    - multi-template retrieval (find where each small image appears in a large image);
    - single-template localization (where a given small image is in the large image).
    """

    @abstractmethod
    def detect(
        self,
        template_image: ImageArray,
        search_image: ImageArray,
        *,
        template_id: Optional[str] = None,
    ) -> DetectionResult:
        """Locate ``template_image`` inside ``search_image``."""
        ...