VitalyVorobyev's picture
refactor adapters
b4d1348
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import numpy as np
@dataclass
class AdapterContext:
    """State carried alongside a model call for use in postprocessing.

    Instances are what ``DLAdapter.postprocess`` receives as its ``ctx``
    argument. All five fields are required and are accepted positionally
    in the order declared below.
    """

    # Name of the model input tensor.
    input_name: str
    # Model input size as a pair of ints.
    # NOTE(review): ``_first_input`` can report ``(None, None)`` for dynamic
    # inputs — confirm whether such a value can reach this field.
    in_size: Tuple[int, int]
    # Size of the original image before any resizing (axis order not
    # established here — presumably (h, w); verify against callers).
    orig_size: Tuple[int, int]
    # Size the image was resized to before being fed to the model.
    resize_size: Tuple[int, int]
    # Free-form, adapter-specific extra data.
    extra: Dict[str, Any]
class DLAdapter:
    """Base class for model adapters.

    Defines the two hooks an adapter provides. Both base implementations
    raise ``NotImplementedError``; concrete adapters override them.
    """

    def preprocess(self, rgb: np.ndarray, sess):  # pragma: no cover - runtime dependent
        """Prepare model input from an RGB image for the session ``sess``.

        The return contract is defined by subclasses and is not visible
        here.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError()

    def postprocess(  # pragma: no cover
        self,
        outputs,
        rgb: np.ndarray,
        ctx: AdapterContext,
        detector: str,
    ):
        """Turn raw model ``outputs`` into results for the image ``rgb``.

        ``ctx`` carries the per-call context; ``detector`` is a string
        selector whose meaning is defined by subclasses.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError()
def _first_input(sess) -> Tuple[str, Tuple[int, int]]:
inp = sess.get_inputs()[0]
name = inp.name
shape = inp.shape
if len(shape) == 4:
h = shape[2] if isinstance(shape[2], int) and shape[2] > 0 else None
w = shape[3] if isinstance(shape[3], int) and shape[3] > 0 else None
if h is None or w is None:
return name, (None, None) # type: ignore
return name, (int(h), int(w))
return name, (None, None) # type: ignore
def _ensure_3ch(x: np.ndarray) -> np.ndarray:
if x.ndim == 2:
x = np.expand_dims(x, -1)
if x.shape[2] == 1:
x = np.repeat(x, 3, axis=2)
return x
# Names exported by ``from <module> import *``. Underscore-prefixed helpers
# would normally be skipped by a star import; listing them here exports them
# as well.
__all__ = [
    "AdapterContext",
    "DLAdapter",
    "_first_input",
    "_ensure_3ch",
]