Spaces:
Runtime error
Runtime error
| import torch | |
| import lzma | |
| import tops | |
| from pathlib import Path | |
| from dp2.detection.base import BaseDetector | |
| from face_detection import build_detector as build_face_detector | |
| from .structures import FaceDetection | |
| from tops import logger | |
def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor):
    """Flag every box in ``box1`` that lies strictly inside some box in ``box2``.

    Columns 0-1 of each box are compared as the left/top corner and columns
    2-3 as the right/bottom corner. Containment is strict: a box that shares
    an edge coordinate with the candidate container counts as outside it.

    Args:
        box1: 2-D tensor with one box per row (N rows).
        box2: 2-D tensor with one candidate container per row (M rows).

    Returns:
        Boolean tensor of shape ``(N,)``; entry ``i`` is True when ``box1[i]``
        is strictly inside at least one row of ``box2``. All entries are
        False when ``box2`` has no rows.
    """
    assert len(box1.shape) == 2
    assert len(box2.shape) == 2
    # Broadcast every box1 row against every box2 row in a single (N, M, 2)
    # comparison instead of looping over box1 in Python.
    lefttop1 = box1[:, [0, 1]].unsqueeze(1)   # (N, 1, 2)
    lefttop2 = box2[:, [0, 1]].unsqueeze(0)   # (1, M, 2)
    rightbot1 = box1[:, [2, 3]].unsqueeze(1)  # (N, 1, 2)
    rightbot2 = box2[:, [2, 3]].unsqueeze(0)  # (1, M, 2)

    # A pair is "outside" as soon as any single coordinate touches or crosses
    # the container's border.
    is_outside_lefttop = (lefttop1 <= lefttop2).any(dim=2)
    is_outside_rightbot = (rightbot1 >= rightbot2).any(dim=2)
    is_outside = is_outside_lefttop.logical_or(is_outside_rightbot)  # (N, M)

    # A box counts as inside when at least one container does not reject it.
    return is_outside.logical_not().any(dim=1)
class FaceDetector(BaseDetector):
    """Detector that runs a ``face_detection`` model and wraps its boxes in ``FaceDetection``."""

    def __init__(
        self,
        face_detector_cfg: dict,
        score_threshold: float,
        face_post_process_cfg: dict,
        **kwargs
    ) -> None:
        """Build the underlying face detector and store the post-processing config.

        Args:
            face_detector_cfg: Keyword arguments forwarded to ``build_face_detector``.
            score_threshold: Passed to the detector as ``confidence_threshold``.
            face_post_process_cfg: Keyword arguments forwarded to ``FaceDetection``
                (and to ``from_state_dict`` when loading from cache).
            **kwargs: Forwarded to ``BaseDetector.__init__``.
        """
        super().__init__(**kwargs)
        self.face_detector = build_face_detector(
            **face_detector_cfg, confidence_threshold=score_threshold
        )
        # Reshape the detector's mean to (3, 1, 1) so it broadcasts over a
        # 3-channel image in `_detect_faces`.
        mean = torch.from_numpy(self.face_detector.mean)
        self.face_mean = tops.to_cuda(mean.view(3, 1, 1))
        self.face_post_process_cfg = face_post_process_cfg

    def __call__(self, *args, **kwargs):
        """Delegate directly to :meth:`forward`."""
        return self.forward(*args, **kwargs)

    def _detect_faces(self, im: torch.Tensor):
        """Run the face detector on one image and return integer boxes on the CPU.

        Args:
            im: Image tensor whose trailing two dimensions are (height, width).
                NOTE(review): presumably (3, H, W) to match ``face_mean`` -- confirm.

        Returns:
            Long tensor of boxes (score column removed), rounded after the
            x-columns were scaled by the width and the y-columns by the height.
        """
        height, width = im.shape[1:]
        centered = im.float() - self.face_mean
        # The detector works on batches, so add a leading batch axis.
        batch = self.face_detector.resize(centered[None], 1.0)
        detections = self.face_detector._batched_detect(batch)[0]
        boxes_XYXY = detections[:, :-1]  # Remove score
        boxes_XYXY[:, [0, 2]] *= width
        boxes_XYXY[:, [1, 3]] *= height
        return boxes_XYXY.round().long().cpu()

    def forward(self, im: torch.Tensor):
        """Detect faces in ``im`` and return them as a one-element list of ``FaceDetection``."""
        boxes = self._detect_faces(im)
        detection = FaceDetection(boxes, **self.face_post_process_cfg)
        return [detection]

    def load_from_cache(self, cache_path: Path):
        """Rebuild detections from an lzma-compressed ``torch.save`` cache file.

        Each cached entry must carry its class under the ``"cls"`` key; that
        class's ``from_state_dict`` is called with the entry and the stored
        post-processing config.
        """
        logger.log(f"Loading detection from cache path: {cache_path}")
        # NOTE(review): torch.load unpickles the file contents; only point this
        # at cache files produced by a trusted source.
        with lzma.open(cache_path, "rb") as fp:
            state_dict = torch.load(fp)
        detections = []
        for state in state_dict:
            restored = state["cls"].from_state_dict(
                state_dict=state, **self.face_post_process_cfg
            )
            detections.append(restored)
        return detections