# VitalyVorobyev
# refactor adapters
# b4d1348
from __future__ import annotations
from typing import Any, Dict, List, Tuple
import cv2
import numpy as np
from .base import AdapterContext, DLAdapter, _ensure_3ch, _first_input
class EdgesAdapter(DLAdapter):
    """Generic single-channel edge detector (DexiNed/HED-style).

    Feeds an RGB image to the model and reduces the first output to a 2-D
    edge map. The map is min-max normalized to [0, 1] and thresholded, and
    the detected edge pixels are painted green on a copy of the input image.
    """

    # Fallback model input size (height, width), used for any spatial
    # dimension the model reports as dynamic/unknown.
    default_input_size: Tuple[int, int] = (512, 512)
    # A pixel counts as an edge when its normalized response exceeds this.
    threshold: float = 0.5

    @staticmethod
    def _resolve_dim(dim: Any, default: int) -> int:
        """Return *dim* when it is a fixed positive size, otherwise *default*.

        Dynamic axes may show up as ``None``, a symbolic name string or a
        non-positive integer depending on the runtime; all fall back.
        """
        if isinstance(dim, (int, np.integer)) and dim > 0:
            return int(dim)
        return default

    def preprocess(self, rgb: np.ndarray, sess) -> Tuple[Dict[str, np.ndarray], AdapterContext]:
        """Resize and scale *rgb* into a single-image NCHW float32 batch.

        Args:
            rgb: Input image, H x W[, C]. The ``/ 255.0`` scaling presumes
                8-bit values in [0, 255] — confirm with callers.
            sess: Inference session handed to ``_first_input`` to look up the
                model's first input name and spatial size.

        Returns:
            ``({input_name: x}, ctx)``. ``x`` has shape (1, C, th, tw), with C
            expected to be 3 after ``_ensure_3ch``. ``ctx`` records the model
            input size and the original image size.
        """
        input_name, in_wh = _first_input(sess)
        H, W = rgb.shape[:2]
        # NOTE(review): the name ``in_wh`` suggests (width, height), but the
        # value is unpacked as (height, width); confirm against _first_input.
        th, tw = in_wh
        default_h, default_w = self.default_input_size
        # Resolve each axis independently so a static dimension is kept even
        # when the other one is dynamic.
        th = self._resolve_dim(th, default_h)
        tw = self._resolve_dim(tw, default_w)
        # cv2.resize takes the target size as (width, height).
        resized = cv2.resize(rgb, (tw, th), interpolation=cv2.INTER_AREA)
        x = _ensure_3ch(resized.astype(np.float32) / 255.0)
        # HWC -> NCHW. The transpose is only a strided view, so make it
        # contiguous before handing it to the runtime.
        x = np.ascontiguousarray(np.transpose(x, (2, 0, 1))[None, ...])
        ctx = AdapterContext(
            input_name=input_name,
            in_size=(th, tw),
            orig_size=(H, W),
            resize_size=(th, tw),
            extra={},
        )
        return {input_name: x}, ctx

    def _extract_edge_prob(self, outputs: List[np.ndarray]) -> np.ndarray:
        """Reduce the first model output to a 2-D float32 map in [0, 1].

        Rank handling (only the first 2-D slice is kept):
            * 2-D and higher: every leading axis is collapsed and the first
              slice over the trailing two axes is used. This equals ``pred``,
              ``pred[0]`` and ``pred[0, 0]`` for 2-, 3- and 4-D outputs.
            * 0-/1-D: interpreted as a flattened square map.

        The map is then min-max normalized. A constant map (or one whose
        min/max is NaN) yields all zeros.

        Raises:
            ValueError: if the output is empty, or if a 0-/1-D output's
                length is not a perfect square.
        """
        pred = np.asarray(outputs[0])
        if pred.size == 0:
            raise ValueError("edge model returned an empty output")
        if pred.ndim >= 2:
            prob = pred.reshape((-1,) + pred.shape[-2:])[0]
        else:
            flat = pred.reshape(-1)
            side = int(round(float(np.sqrt(flat.size))))
            if side * side != flat.size:
                raise ValueError(
                    f"cannot interpret 1-D edge output of length {flat.size} as a square map"
                )
            prob = flat.reshape(side, side)
        # Normalize in float32 so low-precision (e.g. float16) outputs do not
        # lose resolution; astype also copies, so outputs[0] is never mutated.
        prob = prob.astype(np.float32)
        pmin, pmax = float(np.min(prob)), float(np.max(prob))
        if pmax > pmin:
            prob = (prob - pmin) / (pmax - pmin)
        else:
            prob = np.zeros_like(prob)
        return prob.astype(np.float32, copy=False)

    def postprocess(
        self, outputs: List[np.ndarray], rgb: np.ndarray, ctx: AdapterContext, detector: str
    ) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Threshold the edge map and paint edge pixels green over *rgb*.

        Args:
            outputs: Raw model outputs; only ``outputs[0]`` is used.
            rgb: The original image the overlay is drawn on.
            ctx: Context produced by ``preprocess`` (original/model sizes).
            detector: Not used by this adapter.

        Returns:
            ``(overlay, meta)``. ``overlay`` is an RGB copy of *rgb* with edge
            pixels set to (0, 255, 0). ``meta`` holds the mean normalized edge
            response and the model input size.
        """
        H, W = ctx.orig_size
        prob = self._extract_edge_prob(outputs)
        mask = (prob > self.threshold).astype(np.uint8) * 255
        # Resize the binary mask (not the probabilities) with nearest
        # neighbour so it stays strictly binary at the original resolution.
        mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_NEAREST)
        # Paint on a converted copy so the caller's *rgb* is untouched. Pure
        # green is identical in BGR and RGB order, so painting in BGR is safe.
        bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        bgr[mask > 0] = (0, 255, 0)
        overlay = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        meta: Dict[str, Any] = {
            "edge_prob_mean": float(np.mean(prob)),
            "resize": {"h": ctx.in_size[0], "w": ctx.in_size[1]},
        }
        return overlay, meta
# Public API of this module: only the adapter class is exported by ``import *``.
__all__ = ["EdgesAdapter"]