Plana-Archive commited on
Commit
e97aeb0
·
verified ·
1 Parent(s): 067bc7c

Upload anime_object_detection/detection/booru_yolo.py with huggingface_hub

Browse files
anime_object_detection/detection/booru_yolo.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+
3
+ from imgutils.data import ImageTyping
4
+ from imgutils.detect.booru_yolo import detect_with_booru_yolo, _DEFAULT_MODEL
5
+ from imgutils.generic.yolo import _open_models_for_repo_id
6
+
7
+ from .base import DeepGHSObjectDetection
8
+
9
+
10
+ class BooruYOLODetection(DeepGHSObjectDetection):
11
+ def __init__(self):
12
+ DeepGHSObjectDetection.__init__(self, repo_id='deepghs/booru_yolo')
13
+
14
+ def _get_default_model(self) -> str:
15
+ return _DEFAULT_MODEL
16
+
17
+ def _get_default_iou_and_score(self, model_name: str) -> Tuple[float, float]:
18
+ return 0.7, 0.25
19
+
20
+ def _get_labels(self, model_name: str) -> List[str]:
21
+ _, _, labels = _open_models_for_repo_id(self._repo_id)._open_model(model_name)
22
+ return labels
23
+
24
+ def detect(self, image: ImageTyping, model_name: str,
25
+ iou_threshold: float = 0.7, score_threshold: float = 0.25) -> \
26
+ List[Tuple[Tuple[float, float, float, float], str, float]]:
27
+ return detect_with_booru_yolo(image=image, model_name=model_name,
28
+ iou_threshold=iou_threshold, conf_threshold=score_threshold)