h1r41 commited on
Commit
3be698e
·
1 Parent(s): ea2437a

initial commit

Browse files
Files changed (5) hide show
  1. app.py +84 -0
  2. packages.txt +1 -0
  3. requirements.txt +0 -0
  4. yolov5n.pt +3 -0
  5. yolov5s.pt +3 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import math
3
+ from decimal import ROUND_HALF_UP, Decimal
4
+
5
+ import av
6
+ import requests
7
+ import streamlit as st
8
+ import torch
9
+ from matplotlib import pyplot as plt
10
+ from PIL import Image, ImageDraw, ImageFont
11
+ from streamlit_webrtc import webrtc_streamer
12
+
13
+ # モデルの読み込み
14
+ yolov5n = torch.hub.load("ultralytics/yolov5", "yolov5n", pretrained=True)
15
+ yolov5s = torch.hub.load("ultralytics/yolov5", "yolov5s", pretrained=True)
16
+ truetype_url = "https://github.com/JotJunior/PHP-Boleto-ZF2/blob/master/public/assets/fonts/arial.ttf?raw=true"
17
+ r = requests.get(truetype_url, allow_redirects=True)
18
+ classes = list(yolov5n.model.names.values())
19
+
20
+
21
+ def object_detection(image: Image) -> Image:
22
+ # 物体検出の実行
23
+ pred = model(image)
24
+ # 色の一覧を作成
25
+ cmap = plt.get_cmap("hsv", len(model.model.names))
26
+ # フォントサイズ設定
27
+ sqrt = math.sqrt(image.size[0] * image.size[1] / 10000)
28
+ size = int(sqrt * 5)
29
+ font = ImageFont.truetype(io.BytesIO(r.content), size=size)
30
+ # BBoxの線の太さ設定
31
+ rec_width = int(sqrt / 2)
32
+
33
+ # 検出結果の描画
34
+ for detections in pred.xyxy:
35
+ for detection in detections:
36
+ class_id = int(detection[5])
37
+ class_name = str(model.model.names[class_id])
38
+ bbox = [int(x) for x in detection[:4].tolist()]
39
+ conf = float(detection[4])
40
+ # 閾値以上のconfidenceの場合のみ描画
41
+ if conf >= threshold:
42
+ color = cmap(class_id, bytes=True)
43
+ draw = ImageDraw.Draw(image)
44
+ draw.rectangle(bbox, outline=color, width=rec_width)
45
+ conf_str = Decimal(str(conf * 100)).quantize(
46
+ Decimal("0.01"), rounding=ROUND_HALF_UP
47
+ )
48
+ draw.text(
49
+ [bbox[0] + 5, bbox[1] + 10],
50
+ f"{class_name} {conf_str}%",
51
+ fill=color,
52
+ font=font,
53
+ )
54
+
55
+ return image
56
+
57
+
58
+ def callback(frame: av.VideoFrame) -> av.VideoFrame:
59
+ img = frame.to_image()
60
+
61
+ img = object_detection(image=img)
62
+
63
+ return av.VideoFrame.from_image(img)
64
+
65
+
66
+ # Streamlitの画面設定
67
+ st.set_page_config(page_title="Real-time object detection", page_icon=":shark:")
68
+
69
+ # サイドバー表示
70
+ classes_str = "\n".join(f"- {item}" for item in classes)
71
+ st.sidebar.markdown(f"データセットに含まれるクラス一覧:\n{classes_str}")
72
+
73
+ # メイン画面表示
74
+ st.title("Real-time object detection")
75
+ model_name = st.selectbox("Model", ["yolov5n", "yolov5s"])
76
+ model = yolov5n if model_name == "yolov5n" else yolov5s
77
+ threshold = st.slider("Confidence threshold", 0.0, 1.0, 0.25, 0.01)
78
+
79
+ webrtc_streamer(
80
+ key="object_detection",
81
+ video_frame_callback=callback,
82
+ rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
83
+ media_stream_constraints={"video": True, "audio": False},
84
+ )
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ libgl1
requirements.txt ADDED
Binary file (5.03 kB). View file
 
yolov5n.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f180cf23ba0717ada0badd6c685026d73d48f184d00fc159c2641284b2ac0a3
3
+ size 4062133
yolov5s.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b3b748c1e592ddd8868022e8732fde20025197328490623cc16c6f24d0782ee
3
+ size 14808437