Plana-Archive commited on
Commit
d05ab7b
·
verified ·
1 Parent(s): 3ad997e

Upload anime_object_detection/detection/halfbody.py with huggingface_hub

Browse files
anime_object_detection/detection/halfbody.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import List, Tuple
3
+
4
+ from imgutils.data import ImageTyping
5
+ from imgutils.detect.halfbody import detect_halfbody
6
+
7
+ from .base import DeepGHSObjectDetection
8
+
9
+
10
+ def _parse_model_name(model_name: str):
11
+ matching = re.fullmatch(r'^halfbody_detect_(?P<version>[\s\S]+?)_(?P<level>[\s\S]+?)$', model_name)
12
+ return matching.group('version'), matching.group('level')
13
+
14
+
15
+ class HalfBodyDetection(DeepGHSObjectDetection):
16
+ def __init__(self):
17
+ DeepGHSObjectDetection.__init__(self, repo_id='deepghs/anime_halfbody_detection')
18
+
19
+ def _get_default_model(self) -> str:
20
+ return 'halfbody_detect_v1.0_s'
21
+
22
+ def _get_default_iou_and_score(self, model_name: str) -> Tuple[float, float]:
23
+ return 0.7, 0.5
24
+
25
+ def _get_labels(self, model_name: str) -> List[str]:
26
+ return ['halfbody']
27
+
28
+ def detect(self, image: ImageTyping, model_name: str,
29
+ iou_threshold: float = 0.7, score_threshold: float = 0.25) \
30
+ -> List[Tuple[Tuple[float, float, float, float], str, float]]:
31
+ version, level = _parse_model_name(model_name)
32
+ return detect_halfbody(image, level=level, version=version,
33
+ conf_threshold=score_threshold, iou_threshold=iou_threshold)