Spaces:
Runtime error
Runtime error
| import io | |
| import math | |
| from decimal import ROUND_HALF_UP, Decimal | |
| import av | |
| import requests | |
| import streamlit as st | |
| import torch | |
| from matplotlib import pyplot as plt | |
| from PIL import Image, ImageDraw, ImageFont | |
| from streamlit_webrtc import webrtc_streamer | |
# Load both YOLOv5 variants once at import time.
yolov5n = torch.hub.load("ultralytics/yolov5", "yolov5n", pretrained=True)
yolov5s = torch.hub.load("ultralytics/yolov5", "yolov5s", pretrained=True)

# Download the TrueType font whose bytes are later used to render labels.
truetype_url = "https://github.com/JotJunior/PHP-Boleto-ZF2/blob/master/public/assets/fonts/arial.ttf?raw=true"
# The timeout keeps a stalled connection from blocking startup forever;
# raise_for_status() fails fast on an HTTP error instead of letting an error
# page be passed on as font data.
r = requests.get(truetype_url, allow_redirects=True, timeout=30)
r.raise_for_status()

# Class names taken from the yolov5n model's name table.
classes = list(yolov5n.model.names.values())
def object_detection(image: Image.Image) -> Image.Image:
    """Run the selected model on ``image`` and draw its detections onto it.

    Every detection whose confidence is at least the module-level
    ``threshold`` gets a class-coloured bounding box and a
    ``"<class> <confidence>%"`` label. Font size and box line width are
    derived from the image area.

    Reads the module-level names ``model`` (detector to run), ``threshold``
    (confidence cut-off) and ``r`` (HTTP response holding the font bytes).

    Args:
        image: PIL image to analyse. It is drawn on in place.

    Returns:
        The same ``image`` object with the annotations drawn on it.
    """
    # Run object detection.
    pred = model(image)
    names = model.model.names

    # Build a colour map with one entry per class.
    cmap = plt.get_cmap("hsv", len(names))

    # Scale factor: square root of the image area measured in 100x100 px units.
    scale = math.sqrt(image.size[0] * image.size[1] / 10000)

    # Clamp both sizes to at least 1: on small frames the integer truncation
    # yields 0, which gives an unusable font size and a rectangle outline
    # that is not drawn at all.
    font_size = max(1, int(scale * 5))
    line_width = max(1, int(scale / 2))
    font = ImageFont.truetype(io.BytesIO(r.content), size=font_size)

    # One drawing context per image is enough; no need to rebuild it for
    # every detection.
    draw = ImageDraw.Draw(image)

    # Draw the detection results. Each row is indexed as
    # [x1, y1, x2, y2, confidence, class_id].
    for detections in pred.xyxy:
        for detection in detections:
            conf = float(detection[4])
            # Only draw detections at or above the confidence threshold.
            if conf < threshold:
                continue

            class_id = int(detection[5])
            class_name = str(names[class_id])
            bbox = [int(x) for x in detection[:4].tolist()]
            color = cmap(class_id, bytes=True)

            draw.rectangle(bbox, outline=color, width=line_width)

            # Confidence as a percentage rounded half-up to two decimals.
            conf_str = Decimal(str(conf * 100)).quantize(
                Decimal("0.01"), rounding=ROUND_HALF_UP
            )
            draw.text(
                [bbox[0] + 5, bbox[1] + 10],
                f"{class_name} {conf_str}%",
                fill=color,
                font=font,
            )
    return image
def callback(frame: av.VideoFrame) -> av.VideoFrame:
    """Annotate one video frame with object-detection results.

    Args:
        frame: Incoming video frame.

    Returns:
        A new frame built from the annotated image.
    """
    # Convert to a PIL image, annotate it, then wrap it back into a frame.
    annotated = object_detection(image=frame.to_image())
    return av.VideoFrame.from_image(annotated)
# Streamlit page configuration.
st.set_page_config(page_title="Real-time object detection", page_icon=":shark:")

# Sidebar: a Markdown bullet list of the class names. The Japanese heading
# reads "List of classes contained in the dataset".
classes_str = "\n".join(map("- {}".format, classes))
st.sidebar.markdown(f"データセットに含まれるクラス一覧:\n{classes_str}")

# Main page.
st.title("Real-time object detection")

# Model choice; `model` is read by object_detection().
model_name = st.selectbox("Model", ["yolov5n", "yolov5s"])
if model_name == "yolov5n":
    model = yolov5n
else:
    model = yolov5s

# Confidence cut-off; `threshold` is read by object_detection().
threshold = st.slider(
    "Confidence threshold",
    min_value=0.0,
    max_value=1.0,
    value=0.25,
    step=0.01,
)

# Start the WebRTC stream: video only, each frame passed through callback().
rtc_configuration = {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
media_stream_constraints = {"video": True, "audio": False}
webrtc_streamer(
    key="object_detection",
    video_frame_callback=callback,
    rtc_configuration=rtc_configuration,
    media_stream_constraints=media_stream_constraints,
)