Spaces:
Running
Running
| import os | |
| import cv2 | |
| import imghdr | |
| import shutil | |
| import warnings | |
| import numpy as np | |
| import gradio as gr | |
| from dataclasses import dataclass | |
| from mivolo.predictor import Predictor | |
| from utils import is_url, download_file, get_jpg_files, _L, MODEL_DIR, TMP_DIR | |
# Without @dataclass the class has no generated __init__, so the positional
# construction Cfg(detector_path, age_gender_path) would raise TypeError.
@dataclass
class Cfg:
    """Configuration object passed to ``mivolo.predictor.Predictor``.

    Built positionally as ``Cfg(detector_weights, checkpoint)``; the remaining
    fields keep their defaults unless overridden by keyword.
    """

    detector_weights: str  # path to the person/face detector weights file
    checkpoint: str  # path to the age/gender model checkpoint file
    device: str = "cpu"  # device string forwarded to the predictor
    with_persons: bool = True
    disable_faces: bool = False
    draw: bool = True  # presumably asks the predictor for an annotated image — TODO confirm
class ValidImgDetector:
    """Run person/face detection plus age and gender estimation on images.

    Wraps one ``Predictor`` built from the YOLOv8 person/face detector weights
    and the MiVOLO age/gender checkpoint stored under ``MODEL_DIR``.
    """

    # Class-level default; replaced by a real Predictor in __init__.
    predictor = None

    # UI mode string -> (use_persons, disable_faces) written into the
    # age/gender model's meta before recognition.
    _MODES = {
        "Use persons and faces": (True, False),
        "Use persons only": (True, True),
        "Use faces only": (False, False),
    }

    def __init__(self):
        detector_path = f"{MODEL_DIR}/yolov8x_person_face.pt"
        age_gender_path = f"{MODEL_DIR}/model_imdb_cross_person_4.22_99.46.pth.tar"
        predictor_cfg = Cfg(detector_path, age_gender_path)
        self.predictor = Predictor(predictor_cfg)

    def _detect(
        self,
        image: np.ndarray,
        score_threshold: float,
        iou_threshold: float,
        mode: str,
        predictor: Predictor,
    ) -> tuple:
        """Detect people/faces in *image* and summarise age/gender findings.

        Args:
            image: Image array; ``valid_img`` passes the result of
                ``cv2.imread`` (OpenCV BGR channel order).
            score_threshold: Forwarded as the detector's ``conf`` kwarg.
            iou_threshold: Forwarded as the detector's ``iou`` kwarg.
            mode: One of the keys of ``_MODES``.
            predictor: Predictor to run. Its detector kwargs and age/gender
                model meta are overwritten in place, so the settings persist
                on that predictor after this call.

        Returns:
            ``(annotated_image, has_child, has_female, has_male)``. The image
            has its last (channel) axis reversed for display. Each flag is
            ``_L("是")`` or ``_L("否")``; ``_L`` is presumably a localisation
            helper.

        Raises:
            ValueError: If *mode* is not a recognised mode string.
        """
        # Validate the mode before touching any predictor state.
        try:
            use_persons, disable_faces = self._MODES[mode]
        except KeyError:
            raise ValueError(f"Unknown detection mode: {mode!r}") from None

        predictor.detector.detector_kwargs["conf"] = score_threshold
        predictor.detector.detector_kwargs["iou"] = iou_threshold
        predictor.age_gender_model.meta.use_persons = use_persons
        predictor.age_gender_model.meta.disable_faces = disable_faces

        detected_objects, out_im = predictor.recognize(image)

        # Default to the localised "no" so every output slot carries the same
        # kind of value even when nothing is detected.
        has_child = has_female = has_male = _L("否")
        if len(detected_objects.ages) > 0:
            # Skip entries without an age estimate (None), if any, so min()
            # never compares None with a number.
            known_ages = [age for age in detected_objects.ages if age is not None]
            if known_ages and min(known_ages) < 18:  # under 18 counts as a child
                has_child = _L("是")
            if "female" in detected_objects.genders:
                has_female = _L("是")
            if "male" in detected_objects.genders:
                has_male = _L("是")

        # Reverse channel order (presumably BGR from OpenCV -> RGB for Gradio).
        return out_im[:, :, ::-1], has_child, has_female, has_male

    def valid_img(
        self,
        img_path,
        score_threshold: float = 0.4,
        iou_threshold: float = 0.7,
        mode: str = "Use persons and faces",
    ):
        """Load the image at *img_path* and run ``_detect`` with this predictor.

        Raises:
            ValueError: If OpenCV cannot read the file, or *mode* is unknown.
        """
        image = cv2.imread(img_path)
        # cv2.imread signals an unreadable/unsupported file by returning None
        # rather than raising.
        if image is None:
            raise ValueError(f"Failed to read image: {img_path}")
        return self._detect(image, score_threshold, iou_threshold, mode, self.predictor)
# Lazily-built detector shared by all calls to infer().
_detector = None


def _get_detector():
    """Return the shared ValidImgDetector, constructing it on first use.

    Building it loads both model checkpoints, which is far too expensive to
    repeat on every request.
    """
    # NOTE(review): the instance (and its predictor, which _detect mutates) is
    # now shared by both Gradio tabs; confirm the app's concurrency settings
    # make that acceptable.
    global _detector
    if _detector is None:
        _detector = ValidImgDetector()
    return _detector


def infer(photo: str):
    """Gradio callback: detect ages and genders in a local image or image URL.

    Args:
        photo: Filesystem path of an uploaded image, or an image URL. A URL is
            first downloaded to ``TMP_DIR/download.jpg``.

    Returns:
        ``(status, result, child, female, male)``. ``status`` is ``"Success"``
        or the text of the exception that aborted processing. On failure the
        other four fields are ``None``; otherwise they are the values returned
        by ``ValidImgDetector.valid_img``.
    """
    status = "Success"
    result = child = female = male = None
    try:
        if is_url(photo):
            # Clear leftovers from a previous download before fetching again.
            if os.path.exists(TMP_DIR):
                shutil.rmtree(TMP_DIR)
            photo = download_file(photo, f"{TMP_DIR}/download.jpg")

        # Validate the input before touching the (expensive) detector.
        # NOTE(review): imghdr is deprecated and removed in Python 3.13.
        if not photo or not os.path.exists(photo) or imghdr.what(photo) is None:
            raise ValueError("请正确输入图片")

        result, child, female, male = _get_detector().valid_img(photo)
    except Exception as e:
        # UI boundary: report any failure in the status box instead of crashing.
        status = f"{e}"
    return status, result, child, female, male
if __name__ == "__main__":
    warnings.filterwarnings("ignore")

    def _result_outputs():
        """Build the five output components shared by both tabs.

        Order matches infer()'s return tuple: status, annotated image, then
        the child / female / male flags.
        """
        return [
            gr.Textbox(label=_L("状态栏"), show_copy_button=True),
            gr.Image(
                label=_L("检测结果"),
                type="numpy",
                show_share_button=False,
            ),
            gr.Textbox(label=_L("存在儿童")),
            gr.Textbox(label=_L("存在女性")),
            gr.Textbox(label=_L("存在男性")),
        ]

    with gr.Blocks() as iface:
        gr.Markdown(_L("# 性别年龄检测器"))

        # Tab 1: analyse an uploaded image file (with bundled examples).
        with gr.Tab(_L("上传模式")):
            gr.Interface(
                fn=infer,
                inputs=gr.Image(label=_L("上传照片"), type="filepath"),
                outputs=_result_outputs(),
                examples=get_jpg_files(f"{MODEL_DIR}/examples"),
                flagging_mode="never",
                cache_examples=False,
            )

        # Tab 2: analyse an image fetched from a URL.
        with gr.Tab(_L("在线模式")):
            gr.Interface(
                fn=infer,
                inputs=gr.Textbox(
                    label=_L("网络图片链接"),
                    show_copy_button=True,
                ),
                outputs=_result_outputs(),
                flagging_mode="never",
            )

    iface.launch()