h1r41's picture
initial commit
3be698e
import io
import math
from decimal import ROUND_HALF_UP, Decimal
import av
import requests
import streamlit as st
import torch
from matplotlib import pyplot as plt
from PIL import Image, ImageDraw, ImageFont
from streamlit_webrtc import webrtc_streamer
# Load the pretrained YOLOv5 models (nano and small variants) from PyTorch Hub.
# NOTE(review): Streamlit re-executes the whole script on widget interaction, so
# these loads and the font download below are likely repeated on every rerun;
# consider caching them (e.g. st.cache_resource) if the installed Streamlit
# version supports it.
yolov5n = torch.hub.load("ultralytics/yolov5", "yolov5n", pretrained=True)
yolov5s = torch.hub.load("ultralytics/yolov5", "yolov5s", pretrained=True)

# TrueType font used to render detection labels; object_detection() reads the
# raw .ttf bytes from `r.content`.
truetype_url = "https://github.com/JotJunior/PHP-Boleto-ZF2/blob/master/public/assets/fonts/arial.ttf?raw=true"
# The timeout keeps a stalled download from hanging start-up indefinitely.
# raise_for_status() fails fast here, rather than letting an HTTP error page
# reach ImageFont.truetype() later, where it would fail on every frame.
r = requests.get(truetype_url, allow_redirects=True, timeout=30)
r.raise_for_status()

# Class names shown in the sidebar. They are taken from yolov5n only;
# presumably yolov5s has the same label set (both pretrained on the same
# dataset) -- confirm if other models are added.
classes = list(yolov5n.model.names.values())
def object_detection(image: Image.Image) -> Image.Image:
    """Run the selected YOLOv5 model on *image* and draw its detections onto it.

    Reads three module-level globals:
        ``model``: the YOLOv5 model selected in the UI.
        ``threshold``: the minimum confidence a detection needs to be drawn.
        ``r``: the HTTP response whose ``content`` holds the TrueType label font.

    Args:
        image: The input frame. It is annotated in place.

    Returns:
        The same ``image`` object, with a bounding box and a
        ``"<class name> <confidence>%"`` label drawn for every detection whose
        confidence is at least ``threshold``. Box and label colours are
        chosen per class.
    """
    # Run inference.
    pred = model(image)

    # One colour per class, sampled evenly from the HSV colormap.
    cmap = plt.get_cmap("hsv", len(model.model.names))

    # Scale the label font and the box outline with the image size: `sqrt` is
    # the square root of the image area measured in 100x100-pixel blocks.
    sqrt = math.sqrt(image.size[0] * image.size[1] / 10000)
    # Clamp both to at least 1. For small images, int() truncates to 0:
    # Pillow rejects a zero font size, and a zero outline width draws no box.
    size = max(1, int(sqrt * 5))
    rec_width = max(1, int(sqrt / 2))
    font = ImageFont.truetype(io.BytesIO(r.content), size=size)

    # A single drawing context serves every detection on this image.
    draw = ImageDraw.Draw(image)

    # pred.xyxy is iterated as a collection of per-image detection tensors.
    # Each row is indexed as [x1, y1, x2, y2, confidence, class_id].
    for detections in pred.xyxy:
        for detection in detections:
            conf = float(detection[4])
            # Draw only detections at or above the confidence threshold.
            if conf >= threshold:
                class_id = int(detection[5])
                class_name = str(model.model.names[class_id])
                bbox = [int(x) for x in detection[:4].tolist()]
                color = cmap(class_id, bytes=True)
                draw.rectangle(bbox, outline=color, width=rec_width)
                # Percentage with two decimals, rounded half-up rather than
                # with float/banker's rounding.
                conf_str = Decimal(str(conf * 100)).quantize(
                    Decimal("0.01"), rounding=ROUND_HALF_UP
                )
                draw.text(
                    [bbox[0] + 5, bbox[1] + 10],
                    f"{class_name} {conf_str}%",
                    fill=color,
                    font=font,
                )
    return image
def callback(frame: av.VideoFrame) -> av.VideoFrame:
    """Annotate one incoming WebRTC video frame with object detections.

    Converts *frame* to a PIL image, runs object_detection() on it, and wraps
    the annotated image in a new ``av.VideoFrame``, which is returned.
    """
    annotated = object_detection(image=frame.to_image())
    return av.VideoFrame.from_image(annotated)
# Streamlit page configuration.
st.set_page_config(page_title="Real-time object detection", page_icon=":shark:")

# Sidebar: list the model's class names as a Markdown bullet list.
classes_str = "\n".join(f"- {item}" for item in classes)
st.sidebar.markdown(f"データセットに含まれるクラス一覧:\n{classes_str}")

# Main area: model selection and confidence threshold.
st.title("Real-time object detection")
model_name = st.selectbox("Model", ["yolov5n", "yolov5s"])
model = yolov5n if model_name == "yolov5n" else yolov5s
threshold = st.slider("Confidence threshold", 0.0, 1.0, 0.25, 0.01)
# YOLOv5 Hub models already discard boxes below their own NMS confidence
# threshold (model.conf, 0.25 by default). Without this line, slider values
# below 0.25 could never reveal extra detections. Passing the slider value
# through makes the model's NMS filter agree with the drawing filter in
# object_detection().
model.conf = threshold

# Start the webcam stream. Each received video frame goes through callback(),
# and the annotated frame is sent back. Audio is disabled, and Google's public
# STUN server is used for ICE connectivity.
webrtc_streamer(
    key="object_detection",
    video_frame_callback=callback,
    rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
    media_stream_constraints={"video": True, "audio": False},
)