File size: 2,910 Bytes
3be698e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import io
import math
from decimal import ROUND_HALF_UP, Decimal

import av
import requests
import streamlit as st
import torch
from matplotlib import pyplot as plt
from PIL import Image, ImageDraw, ImageFont
from streamlit_webrtc import webrtc_streamer

# Load both pretrained YOLOv5 variants up front so the UI can switch between
# them without another torch.hub call.
# NOTE(review): these loads run every time the module body executes; consider
# caching them if the installed Streamlit version offers a resource cache.
yolov5n = torch.hub.load("ultralytics/yolov5", "yolov5n", pretrained=True)
yolov5s = torch.hub.load("ultralytics/yolov5", "yolov5s", pretrained=True)

# TrueType font used for the detection labels, fetched over HTTP at startup.
truetype_url = "https://github.com/JotJunior/PHP-Boleto-ZF2/blob/master/public/assets/fonts/arial.ttf?raw=true"
# Bound the request so a stalled connection cannot hang startup forever, and
# fail fast on an HTTP error so an error page is never handed to
# ImageFont.truetype (which would otherwise fail obscurely on every frame).
r = requests.get(truetype_url, allow_redirects=True, timeout=10)
r.raise_for_status()

# Class labels shown in the sidebar, taken from the nano model's name mapping.
# NOTE(review): presumably both checkpoints share the same label set - confirm.
classes = list(yolov5n.model.names.values())


def object_detection(image: Image.Image) -> Image.Image:
    """Run the selected model on *image* and draw its detections onto it.

    Reads the module-level ``model`` and ``threshold`` at call time, so the
    current UI selection is used for every frame.

    Args:
        image: PIL image to analyse. It is annotated in place.

    Returns:
        The same ``image`` object, with a bounding box and a
        ``"<class name> <confidence>%"`` label drawn for every detection
        whose confidence is at least ``threshold``.
    """
    # Run inference.
    pred = model(image)
    # One colour per class, sampled from the "hsv" colormap.
    cmap = plt.get_cmap("hsv", len(model.model.names))

    # Scale the label font and box outline with the image area.
    scale = math.sqrt(image.size[0] * image.size[1] / 10000)
    # Clamp to 1: on small frames the truncated values reach 0, which is an
    # invalid font size and makes Pillow skip the rectangle outline entirely.
    font_size = max(1, int(scale * 5))
    rec_width = max(1, int(scale / 2))
    font = ImageFont.truetype(io.BytesIO(r.content), size=font_size)

    # A single drawing context is enough for all detections in this image.
    draw = ImageDraw.Draw(image)

    # Each row is read as (x1, y1, x2, y2, confidence, class id).
    for detections in pred.xyxy:
        for detection in detections:
            conf = float(detection[4])
            # Only draw detections at or above the confidence threshold.
            if conf < threshold:
                continue

            class_id = int(detection[5])
            class_name = str(model.model.names[class_id])
            bbox = [int(x) for x in detection[:4].tolist()]
            color = cmap(class_id, bytes=True)

            draw.rectangle(bbox, outline=color, width=rec_width)
            # Show the confidence as a percentage rounded half-up to 2 places.
            conf_str = Decimal(str(conf * 100)).quantize(
                Decimal("0.01"), rounding=ROUND_HALF_UP
            )
            draw.text(
                [bbox[0] + 5, bbox[1] + 10],
                f"{class_name} {conf_str}%",
                fill=color,
                font=font,
            )

    return image


def callback(frame: av.VideoFrame) -> av.VideoFrame:
    """Annotate one incoming video frame with object-detection results.

    Converts the frame to a PIL image, passes it through
    ``object_detection``, and wraps the result in a new ``av.VideoFrame``.
    """
    annotated = object_detection(image=frame.to_image())
    return av.VideoFrame.from_image(annotated)


# Streamlit page setup
st.set_page_config(
    page_title="Real-time object detection",
    page_icon=":shark:",
)

# Sidebar: bullet list of every class name in `classes`
classes_str = "\n".join([f"- {label}" for label in classes])
st.sidebar.markdown(f"データセットに含まれるクラス一覧:\n{classes_str}")

# Main page: title, model picker and confidence threshold.
# `model` and `threshold` are module globals read by object_detection().
st.title("Real-time object detection")
model_name = st.selectbox(label="Model", options=["yolov5n", "yolov5s"])
if model_name == "yolov5n":
    model = yolov5n
else:
    model = yolov5s
threshold = st.slider(
    label="Confidence threshold",
    min_value=0.0,
    max_value=1.0,
    value=0.25,
    step=0.01,
)

# Video-only WebRTC stream; every frame is passed through `callback`.
webrtc_streamer(
    key="object_detection",
    video_frame_callback=callback,
    rtc_configuration={
        "iceServers": [
            {"urls": ["stun:stun.l.google.com:19302"]},
        ],
    },
    media_stream_constraints={"video": True, "audio": False},
)