| | |
| |
|
| | import os |
| | import shlex |
| | import subprocess |
| |
|
if os.getenv('SYSTEM') == 'spaces':
    # When running on Hugging Face Spaces, install MiVOLO from GitHub before
    # ``mivolo`` is imported further down (presumably it is not in the Space's
    # requirements — confirm).
    import sys

    git_repo = "https://github.com/WildChlamydia/MiVOLO.git"
    # Use ``python -m pip`` through the running interpreter so the package is
    # installed into the same environment this script imports from, and raise
    # immediately if the install fails instead of surfacing a confusing
    # ImportError later.
    subprocess.check_call([sys.executable, "-m", "pip", "install", f"git+{git_repo}"])
| |
|
| | import pathlib |
| | import os |
| | import gradio as gr |
| | import huggingface_hub |
| | import numpy as np |
| | import functools |
| | from dataclasses import dataclass |
| |
|
| | from mivolo.predictor import Predictor |
| |
|
| |
|
@dataclass()
class Cfg:
    """Configuration object handed to ``Predictor``.

    The two leading fields have no defaults and are filled positionally by
    ``load_models``, so their order must not change.
    """

    # Local path of the downloaded detector weights file.
    detector_weights: str
    # Local path of the downloaded age/gender model checkpoint.
    checkpoint: str
    # Device string used for inference; defaults to CPU.
    device: str = "cpu"
    # Flags consumed by Predictor (presumably whether person crops are used and
    # whether face crops are ignored — confirm against mivolo).
    with_persons: bool = True
    disable_faces: bool = False
    # Presumably toggles drawing annotations onto the returned image — confirm.
    draw: bool = True
| |
|
| |
|
# Markdown shown at the top of the demo page (rendered via gr.Markdown).
DESCRIPTION = """
# MiVOLO: Multi-input Transformer for Age and Gender Estimation

This is an official demo for https://github.com/WildChlamydia/MiVOLO.\n
Telegram channel: https://t.me/+K0i2fLGpVKBjNzUy (Russian language)
"""

# Hugging Face token passed to the checkpoint downloads; None when the
# environment variable is not set.
HF_TOKEN = os.environ.get('HF_TOKEN')
| |
|
| |
|
def load_models():
    """Download both model files from the Hugging Face Hub and build a Predictor.

    ``hf_hub_download`` returns a local path for each file; ``HF_TOKEN`` is
    forwarded as the auth token. The two paths are wrapped in a ``Cfg``
    (detector first, age/gender checkpoint second) for ``Predictor``.
    """
    detector_path = huggingface_hub.hf_hub_download(
        repo_id='iitolstykh/demo_yolov8_detector',
        filename='yolov8x_person_face.pt',
        use_auth_token=HF_TOKEN,
    )
    age_gender_path = huggingface_hub.hf_hub_download(
        repo_id='iitolstykh/demo_xnet_volo_cross',
        filename='checkpoint-377.pth.tar',
        use_auth_token=HF_TOKEN,
    )

    return Predictor(Cfg(detector_path, age_gender_path))
| |
|
| |
|
# Maps each inference-mode label shown in the UI to (use_persons, disable_faces).
_MODE_FLAGS = {
    "Use persons and faces": (True, False),
    "Use persons only": (True, True),
    "Use faces only": (False, False),
}


def detect(
    image: np.ndarray,
    score_threshold: float,
    iou_threshold: float,
    mode: str,
    predictor: "Predictor",
) -> np.ndarray:
    """Run detection and age/gender recognition on one image.

    Args:
        image: Image array from the ``gr.Image(type='numpy')`` input, or ``None``
            when the user submits without uploading anything.
        score_threshold: Detector confidence threshold (stored as ``conf``).
        iou_threshold: NMS IoU threshold (stored as ``iou``).
        mode: One of the keys of ``_MODE_FLAGS`` (the Radio button labels).
        predictor: Loaded predictor. Its detector kwargs and model meta flags
            are overwritten in place on every call.

    Returns:
        The annotated image returned by ``predictor.recognize``, with its
        channel axis reversed back to the input's order. Returns ``None`` if
        ``image`` is ``None``, which clears the output component.

    Raises:
        ValueError: If ``mode`` is not a known inference mode.
    """
    if image is None:
        # Nothing uploaded: clear the output instead of crashing on slicing None.
        return None

    # Validate the mode before touching shared predictor state, so a bad value
    # cannot leave the predictor half-configured.
    try:
        use_persons, disable_faces = _MODE_FLAGS[mode]
    except KeyError:
        raise ValueError(f"Unknown inference mode: {mode!r}") from None

    predictor.detector.detector_kwargs['conf'] = score_threshold
    predictor.detector.detector_kwargs['iou'] = iou_threshold
    predictor.age_gender_model.meta.use_persons = use_persons
    predictor.age_gender_model.meta.disable_faces = disable_faces

    # Reverse the channel axis on the way in and on the way out (presumably
    # RGB from Gradio <-> BGR expected by the predictor — confirm).
    flipped = image[:, :, ::-1]
    detected_objects, out_im = predictor.recognize(flipped)
    print(detected_objects)
    return out_im[:, :, ::-1]
| |
|
| |
|
def clear():
    """Reset the UI.

    Returns five values for (input image, score threshold, IoU threshold,
    mode, output image): no images, plus the default thresholds and mode.
    """
    default_score, default_iou = 0.4, 0.7
    default_mode = "Use persons and faces"
    return (None, default_score, default_iou, default_mode, None)
| |
|
| |
|
# Load the models once at startup; every request reuses this predictor.
predictor = load_models()

# Pre-bind the predictor so Gradio only has to supply the four UI inputs.
func = functools.partial(detect, predictor=predictor)

# Each example row mirrors the UI inputs: image path, score threshold,
# IoU threshold, inference mode.
image_dir = pathlib.Path('images')
examples = [
    [jpg.as_posix(), 0.4, 0.7, "Use persons and faces"]
    for jpg in sorted(image_dir.glob('*.jpg'))
]
| |
|
# Build the Gradio UI. Components are laid out in the order they are created
# inside the nested Row/Column contexts, so statement order here is the layout.
with gr.Blocks(
    theme=gr.themes.Default(),
    css="style.css"
) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Left column: input image and inference controls.
        with gr.Column():
            image = gr.Image(label='Input', type='numpy')
            score_threshold = gr.Slider(0, 1, value=0.4, step=0.05, label='Detector Score Threshold')
            iou_threshold = gr.Slider(0, 1, value=0.7, step=0.05, label='NMS Iou Threshold')
            # Labels must match the mode strings that detect() understands.
            mode = gr.Radio(["Use persons and faces", "Use persons only", "Use faces only"],
                            value="Use persons and faces",
                            label="Inference mode",
                            info="What to use for gender and age recognition")

            with gr.Row():
                clear_button = gr.Button("Clear")
                with gr.Column():
                    run_button = gr.Button("Submit", variant="primary")
        # Right column: annotated output image.
        with gr.Column():
            result = gr.Image(label='Output', type='numpy')

    # Order matches the positional parameters of detect() (minus predictor).
    inputs = [image, score_threshold, iou_threshold, mode]
    gr.Examples(examples=examples,
                inputs=inputs,
                outputs=result,
                fn=func,
                cache_examples=False)
    run_button.click(fn=func, inputs=inputs, outputs=result, api_name='predict')
    # clear() returns five values that map positionally onto these outputs.
    clear_button.click(fn=clear, inputs=None, outputs=[image, score_threshold, iou_threshold, mode, result])

# Route requests through Gradio's queue (at most 15 waiting) and start the server.
demo.queue(max_size=15).launch()
| |
|