File size: 739 Bytes
b30e7a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c36daa
 
b30e7a3
 
 
5c36daa
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from typing import NamedTuple, Optional, Sequence

import numpy as np


class DetectionResult(NamedTuple):
    """Detections returned by an ``ObjectDetector`` for a single frame.

    A plain tuple subclass with four fields; only ``label_names`` is
    optional and falls back to ``None`` when omitted.
    """

    # Detected boxes; box layout (e.g. xyxy vs xywh) is not fixed here —
    # NOTE(review): confirm against the concrete detectors.
    boxes: np.ndarray
    # Per-detection scores (presumably confidences — confirm).
    scores: Sequence[float]
    # Per-detection integer label ids.
    labels: Sequence[int]
    # Optional human-readable label names; None when not provided.
    label_names: Optional[Sequence[str]] = None


class ObjectDetector:
    """Detector interface to keep inference agnostic to model details.

    Subclasses must override :meth:`predict`. Subclasses that can process
    several frames at once may also override :meth:`predict_batch` and
    advertise that through ``supports_batch`` / ``max_batch_size``. This
    base class itself does not read those two attributes.
    """

    # Identifier of the concrete detector. Declared but never assigned here,
    # so subclasses (or instances) are expected to provide it.
    name: str
    # Presumably True when predict_batch is a genuinely batched override
    # rather than the sequential fallback below — confirm with callers.
    supports_batch: bool = False
    # Presumably the largest number of frames callers should pass to
    # predict_batch at once; not enforced anywhere in this class.
    max_batch_size: int = 1

    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
        """Run detection on a single frame.

        Args:
            frame: Image to run detection on. Layout/dtype are
                detector-specific (TODO: document per implementation).
            queries: Query strings forwarded to the detector.

        Returns:
            The detections found in ``frame``.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this.
        """
        # Name the concrete class so a missing override is easy to diagnose.
        raise NotImplementedError(
            f"{type(self).__name__} must implement predict()"
        )

    def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
        """Run detection on several frames that share the same queries.

        Default: sequential fallback that calls :meth:`predict` once per
        frame, in order, and returns the results as a list.

        ``queries`` is passed unchanged to every ``predict`` call. It must
        therefore be re-iterable (a real Sequence), not a one-shot iterator.
        """
        return [self.predict(frame, queries) for frame in frames]