honeytian
first commit
071150e
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Mapping, Optional, Sequence, Tuple
import numpy as np
from toolbox.keypoint_match.types import (
DetectionResult,
ImageArray,
KeyPointSet,
MatchSet,
)
@dataclass(frozen=True)
class KeyPointExtractConfig:
"""关键点/描述子提取器的通用配置。"""
max_keypoints: Optional[int] = None
# 允许实现方在 meta 中存放更多参数,例如:nms_radius、score_threshold、scale_pyramid 等
extra: Optional[Mapping[str, Any]] = None
class KeyPointExtract(ABC):
"""
抽象:从一张图片提取关键点与描述子。
适配来源:
- OpenCV: ORB/SIFT/AKAZE 等
- Kornia: KeyNetAffNetHardNet、DISK 等
- SuperPoint/LightGlue 等深度模型
"""
def __init__(self, config: Optional[KeyPointExtractConfig] = None):
self.config = config or KeyPointExtractConfig()
@abstractmethod
def extract(self, image: ImageArray) -> KeyPointSet:
"""输入一张图,输出关键点集合(包含可选描述子)。"""
def batch_extract(self, images: Sequence[ImageArray]) -> Sequence[KeyPointSet]:
return [self.extract(im) for im in images]
@dataclass(frozen=True)
class KeyPointMatchConfig:
"""匹配器的通用配置。"""
# 常见:ratio test 的阈值(若实现方使用 KNN)
ratio: Optional[float] = None
# 允许实现方控制返回的匹配数量上限
max_matches: Optional[int] = None
extra: Optional[Mapping[str, Any]] = None
class KeyPointMatch(ABC):
"""
抽象:对两张图(或两组描述子)进行匹配,输出匹配对。
"""
def __init__(self, config: Optional[KeyPointMatchConfig] = None):
self.config = config or KeyPointMatchConfig()
@abstractmethod
def match(self, query: KeyPointSet, train: KeyPointSet) -> MatchSet:
"""
query: 通常来自“小图/模板”
train: 通常来自“大图/搜索图”
"""
@dataclass(frozen=True)
class RegionScorerConfig:
"""
匹配点聚集成“区域”的通用配置。
说明:你的核心算法描述是“某个区域匹配点特别多 => 目标被找到”,
因此我们把“如何聚类/评分/生成 bbox”抽象成 RegionScorer。
"""
# 大图上聚类半径(像素)。例如:DBSCAN eps 或网格统计的 cell_size
radius_px: float = 24.0
# 认为“找到”的最低匹配点数阈值
min_match_count: int = 12
# 最多输出多少个候选区域
topk: int = 10
extra: Optional[Mapping[str, Any]] = None
class RegionScorer(ABC):
"""
抽象:把匹配关系映射成候选区域(bbox + score)。
输入包含关键点坐标与匹配对,因此实现可以:
- 做简单的网格投票 / 密度聚类
- 用单应性 / RANSAC 过滤外点后再成簇
- 用匹配点的局部一致性做评分
"""
def __init__(self, config: Optional[RegionScorerConfig] = None):
self.config = config or RegionScorerConfig()
@abstractmethod
def score(
self,
query: KeyPointSet,
train: KeyPointSet,
matches: MatchSet,
*,
template_id: Optional[str] = None,
template_size: Optional[Tuple[int, int]] = None, # (h, w)
) -> DetectionResult:
"""
template_size: 小图大小,便于从匹配点推断 bbox 尺寸(可选)
"""
class KeyPointTemplateDetector(ABC):
"""
抽象:把 (提取器 + 匹配器 + 区域评分器) 组合成“模板检测”。
你可以实现一个具体 Detector,将其用于:
- 多模板检索(小图集合在大图中找出现位置)
- 单模板定位(某个小图在大图的哪里)
"""
@abstractmethod
def detect(
self,
template_image: ImageArray,
search_image: ImageArray,
*,
template_id: Optional[str] = None,
) -> DetectionResult:
...