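"""SuperPoint keypoint-detection adapters.

Provides an ONNX Runtime-based adapter and an optional HuggingFace
``transformers``-backed adapter. The heavy dependencies (torch, Pillow,
transformers) are imported defensively so the module stays importable
without them.
"""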
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
import cv2
import numpy as np
from .base import AdapterContext, DLAdapter, _first_input
try: # Optional runtime dependency
import torch
except ImportError: # pragma: no cover - handled dynamically
torch = None # type: ignore
try: # Optional runtime dependency
from PIL import Image
except ImportError: # pragma: no cover - handled dynamically
Image = None # type: ignore
try: # Optional runtime dependency
from transformers import AutoImageProcessor, SuperPointForKeypointDetection
except ImportError: # pragma: no cover - handled dynamically
AutoImageProcessor = None # type: ignore
SuperPointForKeypointDetection = None # type: ignore
class _SuperPointBaseAdapter(DLAdapter):
"""Common utilities shared by SuperPoint adapters."""
_PIXEL_VALUES_KEY = "pixel_values"
def _make_hw_div8(self, H: int, W: int) -> Tuple[int, int]:
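        """Round ``(H, W)`` down to the nearest multiples of 8, with a floor of 8.

        SuperPoint decodes its detection head in 8x8 cells, so input dimensions
        must be divisible by 8.
        """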
H8 = max(8, (H // 8) * 8)
W8 = max(8, (W // 8) * 8)
return H8, W8
def _prepare_input(
self,
rgb: np.ndarray,
input_name: str,
target_hw: Tuple[int, int],
) -> Tuple[Dict[str, np.ndarray], AdapterContext]:
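        """Convert an RGB image into the grayscale NCHW float32 tensor SuperPoint expects.

        The image is resized to ``target_hw``, scaled to [0, 1], and wrapped in an
        AdapterContext recording the original and resized sizes.
        """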
H, W = rgb.shape[:2]
th, tw = target_hw
gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
gray_r = cv2.resize(gray, (tw, th), interpolation=cv2.INTER_AREA).astype(np.float32) / 255.0
x = gray_r[None, None, ...]
ctx = AdapterContext(
input_name=input_name,
in_size=(th, tw),
orig_size=(H, W),
resize_size=(th, tw),
extra={},
)
return {input_name: x}, ctx
def _pick_outputs(self, outputs: List[np.ndarray]) -> Tuple[np.ndarray, Optional[np.ndarray]]:
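        """Heuristically split ONNX outputs into detector logits and descriptors.

        The detector head ("semi") has 64 or 65 channels (8x8 cell bins plus an
        optional dustbin); the descriptor head has 256 channels.
        """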
semi = None
desc = None
for o in outputs:
if o.ndim == 4 and o.shape[1] in (64, 65):
semi = o
elif o.ndim == 4 and o.shape[1] == 256:
desc = o
if semi is None:
semi = outputs[0]
return semi, desc
def _softmax_channel(self, x: np.ndarray, axis: int = 1) -> np.ndarray:
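        """Numerically stable softmax along ``axis``."""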
x = x - np.max(x, axis=axis, keepdims=True)
e = np.exp(x)
return e / np.sum(e, axis=axis, keepdims=True)
def _semi_to_heat(self, semi: np.ndarray) -> np.ndarray:
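        """Decode cell-wise detector logits into a dense per-pixel heatmap.

        Applies a channel softmax, drops the dustbin channel when 65 channels are
        present, and rearranges each cell's 64 values into an 8x8 pixel block
        (the inverse of SuperPoint's space-to-depth encoding).
        """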
if semi.ndim != 4:
semi = semi.reshape(1, semi.shape[0], semi.shape[1], semi.shape[2])
semi = self._softmax_channel(semi, axis=1)
semi = semi[0]
if semi.shape[0] == 65:
semi = semi[:-1, ...]
Hc, Wc = semi.shape[1], semi.shape[2]
semi = semi.transpose(1, 2, 0)
semi = semi.reshape(Hc, Wc, 8, 8)
semi = semi.transpose(0, 2, 1, 3)
heat = semi.reshape(Hc * 8, Wc * 8)
return heat
def _nms_points(
self, heat: np.ndarray, thresh: float = 0.015, nms_size: int = 3, max_kp: int = 1000
) -> Tuple[np.ndarray, np.ndarray]:
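        """Non-maximum suppression via grayscale dilation.

        Keeps pixels that equal the local maximum in an ``nms_size`` x ``nms_size``
        window and exceed ``thresh``, returning at most ``max_kp`` strongest points
        as ``(ys, xs)`` index arrays.
        """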
H, W = heat.shape
dil = cv2.dilate(heat, np.ones((nms_size, nms_size), np.float32))
maxima = (heat == dil) & (heat > thresh)
ys, xs = np.where(maxima)
if len(xs) > max_kp:
vals = heat[ys, xs]
idx = np.argsort(vals)[-max_kp:]
ys, xs = ys[idx], xs[idx]
return ys, xs
def postprocess(
self,
outputs: List[np.ndarray],
rgb: np.ndarray,
ctx: AdapterContext,
detector: str,
) -> Tuple[np.ndarray, Dict[str, Any]]:
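        """Decode the detector output, run NMS, and draw keypoints on a copy of the image.

        Returns the annotated overlay and metadata (corner count, mean heat value,
        and descriptor shape when a descriptor head is present).
        """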
semi, desc = self._pick_outputs(outputs)
heat_r = self._semi_to_heat(semi)
H0, W0 = ctx.orig_size
heat = cv2.resize(heat_r, (W0, H0), interpolation=cv2.INTER_CUBIC)
ys, xs = self._nms_points(heat)
overlay = rgb.copy()
for y, x in zip(ys.tolist(), xs.tolist()):
cv2.circle(overlay, (int(x), int(y)), 2, (255, 255, 0), -1)
meta: Dict[str, Any] = {
"num_corners": int(len(xs)),
"heat_mean": float(np.mean(heat)),
}
if desc is not None:
meta["descriptors_shape"] = list(desc.shape)
return overlay, meta
class SuperPointAdapter(_SuperPointBaseAdapter):
"""SuperPoint-style keypoint detector."""
def preprocess(self, rgb: np.ndarray, sess) -> Tuple[Dict[str, np.ndarray], AdapterContext]:
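        """Build the ONNX Runtime feed dict for a single RGB image.

        If the session reports a dynamic input size, the image's own size is used,
        rounded down to a multiple of 8 as SuperPoint's decoder requires.
        """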
input_name, in_wh = _first_input(sess)
H, W = rgb.shape[:2]
th, tw = in_wh
if th is None or tw is None:
th, tw = self._make_hw_div8(H, W)
else:
th, tw = self._make_hw_div8(th, tw)
feed, ctx = self._prepare_input(rgb, input_name, (th, tw))
ctx.extra["adapter"] = "superpoint_onnx"
ctx.extra["backend"] = "onnxruntime"
return feed, ctx
def postprocess(
self,
outputs: List[np.ndarray],
rgb: np.ndarray,
ctx: AdapterContext,
detector: str,
) -> Tuple[np.ndarray, Dict[str, Any]]:
overlay, meta = super().postprocess(outputs, rgb, ctx, detector)
meta.setdefault("adapter", "superpoint_onnx")
meta.setdefault("backend", "onnxruntime")
return overlay, meta
class SuperPointTransformersAdapter(_SuperPointBaseAdapter):
"""SuperPoint adapter backed by the HuggingFace transformers implementation."""
def __init__(
self,
model_name: str = "magic-leap-community/superpoint",
device: Optional[str] = None,
model: Optional[Any] = None,
processor: Optional[Any] = None,
) -> None:
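        """Validate optional dependencies and store configuration.

        ``model`` and ``processor`` may be injected directly (e.g. pre-loaded
        instances); otherwise they are loaded lazily from ``model_name`` on first
        use and placed on ``device`` (CUDA when available, CPU otherwise).
        """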
if SuperPointForKeypointDetection is None or AutoImageProcessor is None:
raise ImportError(
"transformers is required for SuperPointTransformersAdapter. "
"Install with `pip install transformers`."
)
if torch is None:
raise ImportError(
"PyTorch is required for SuperPointTransformersAdapter. "
"Install with `pip install torch`."
)
if Image is None:
raise ImportError(
"Pillow is required for SuperPointTransformersAdapter. Install with `pip install Pillow`."
)
self._model_name = model_name
self._device = torch.device(device) if device is not None else torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
self._model: Optional[Any] = model
self._processor: Optional[Any] = processor
def _ensure_model(self) -> SuperPointForKeypointDetection:
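        """Lazily load the pretrained model, move it to the target device, and switch to eval mode."""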
if self._model is None:
self._model = SuperPointForKeypointDetection.from_pretrained(self._model_name)
self._model.to(self._device)
self._model.eval()
return self._model
def _ensure_processor(self):
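        """Lazily load the image processor for the configured checkpoint."""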
if self._processor is None:
self._processor = AutoImageProcessor.from_pretrained(self._model_name)
return self._processor
def preprocess(self, rgb: np.ndarray, sess=None) -> Tuple[Dict[str, torch.Tensor], AdapterContext]:
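        """Run the HuggingFace image processor and move ``pixel_values`` to the device.

        The ``sess`` argument is accepted only for interface compatibility and is
        ignored by this backend.
        """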
        self._ensure_model()  # Eager load so a missing checkpoint fails here, not in the forward pass.
processor = self._ensure_processor()
H, W = rgb.shape[:2]
image = Image.fromarray(rgb)
processed = processor(images=image, return_tensors="pt")
pixel_values = processed["pixel_values"].to(self._device)
ctx = AdapterContext(
input_name=self._PIXEL_VALUES_KEY,
in_size=(pixel_values.shape[-2], pixel_values.shape[-1]),
orig_size=(H, W),
resize_size=(H, W),
extra={
"adapter": "superpoint_transformers",
"backend": "transformers",
"model_name": self._model_name,
"device": str(self._device),
},
)
return {self._PIXEL_VALUES_KEY: pixel_values}, ctx
def _forward(self, pixel_values: torch.Tensor):
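        """Run a forward pass under ``torch.no_grad`` and return the raw model output."""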
model = self._ensure_model()
with torch.no_grad():
outputs = model(pixel_values=pixel_values, return_dict=True)
return outputs
def postprocess(
self,
outputs,
rgb: np.ndarray,
ctx: AdapterContext,
detector: str,
) -> Tuple[np.ndarray, Dict[str, Any]]:
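        """Draw detected keypoints and collect score/descriptor statistics.

        The model returns keypoints as relative (x, y) coordinates, so they are
        scaled by the original image width and height before drawing.
        """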
overlay = rgb.copy()
meta: Dict[str, Any] = {
"adapter": "superpoint_transformers",
"backend": "transformers",
"model_name": self._model_name,
"device": str(self._device),
}
keypoints = outputs.keypoints[0]
scores = outputs.scores[0]
mask = outputs.mask[0] if outputs.mask is not None else torch.ones_like(scores, dtype=torch.bool)
mask = mask.bool()
keypoints = keypoints[mask]
scores = scores[mask]
descriptors = outputs.descriptors[0][mask] if outputs.descriptors is not None else None
num_keypoints = int(keypoints.shape[0])
meta["num_keypoints"] = num_keypoints
if num_keypoints > 0:
H, W = ctx.orig_size
keypoints_np = keypoints.detach().cpu().numpy()
scores_np = scores.detach().cpu().numpy()
abs_pts = np.zeros_like(keypoints_np)
abs_pts[:, 0] = np.clip(keypoints_np[:, 0] * W, 0, W - 1)
abs_pts[:, 1] = np.clip(keypoints_np[:, 1] * H, 0, H - 1)
            for x_f, y_f in abs_pts:
x_i = int(round(float(x_f)))
y_i = int(round(float(y_f)))
cv2.circle(overlay, (x_i, y_i), 2, (255, 255, 0), -1)
meta["scores_mean"] = float(np.mean(scores_np))
meta["scores_std"] = float(np.std(scores_np))
meta["scores_max"] = float(np.max(scores_np))
meta["scores_min"] = float(np.min(scores_np))
else:
meta["scores_mean"] = 0.0
meta["scores_std"] = 0.0
meta["scores_max"] = 0.0
meta["scores_min"] = 0.0
if descriptors is not None:
descriptors_np = descriptors.detach().cpu().numpy()
meta["descriptors_shape"] = list(descriptors_np.shape)
else:
meta["descriptors_shape"] = None
return overlay, meta
def infer(self, rgb: np.ndarray, detector: str) -> Tuple[np.ndarray, Dict[str, Any]]:
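        """Full pipeline for the transformers backend: preprocess, forward, postprocess."""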
feed, ctx = self.preprocess(rgb, None)
outputs = self._forward(feed[self._PIXEL_VALUES_KEY])
overlay, meta = self.postprocess(outputs, rgb, ctx, detector)
return overlay, meta
__all__ = ["SuperPointAdapter", "SuperPointTransformersAdapter"]
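
# Minimal usage sketch (illustrative only: "superpoint.onnx", the `rgb` image array
# and the `detector` label are placeholders supplied by the caller, not this module):
#
#     import onnxruntime as ort
#
#     sess = ort.InferenceSession("superpoint.onnx")
#     adapter = SuperPointAdapter()
#     feed, ctx = adapter.preprocess(rgb, sess)
#     outputs = sess.run(None, feed)
#     overlay, meta = adapter.postprocess(outputs, rgb, ctx, detector="superpoint")
#
#     # Transformers backend (downloads the checkpoint on first use):
#     adapter = SuperPointTransformersAdapter()
#     overlay, meta = adapter.infer(rgb, detector="superpoint")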