ake178178 commited on
Commit
1d51954
·
verified ·
1 Parent(s): d5ef1ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -38
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import streamlit as st
2
- import cv2
3
- import numpy as np
4
  import tensorflow as tf
 
5
  from tensorflow.keras.applications.mobilenet import preprocess_input, decode_predictions
6
 
7
  st.title("物体识别应用")
@@ -10,43 +10,24 @@ st.write("通过摄像头识别物体,从左到右显示主要物体的名称"
10
  # 加载 MobileNet 预训练模型
11
  model = tf.keras.applications.MobileNet(weights="imagenet")
12
 
13
- # 打开摄像头函数
14
- def open_camera():
15
- video_capture = cv2.VideoCapture(0)
16
- return video_capture
17
 
18
- # 物体识别函数
19
- def detect_objects(frame):
20
- image_resized = cv2.resize(frame, (224, 224))
21
- image_array = np.expand_dims(image_resized, axis=0)
22
- processed_image = preprocess_input(image_array)
23
-
24
- preds = model.predict(processed_image)
25
- decoded_preds = decode_predictions(preds, top=3)[0]
26
- objects = [f"{label}: {round(score * 100, 2)}%" for (_, label, score) in decoded_preds]
27
- return objects
28
 
29
- # 按钮控制摄像头启动
30
- if st.button("开始摄像头"):
31
- video_capture = open_camera()
32
- stframe = st.empty()
33
-
34
- while video_capture.isOpened():
35
- ret, frame = video_capture.read()
36
- if not ret:
37
- st.write("无法读取摄像头数据。")
38
- break
39
-
40
- objects = detect_objects(frame)
41
  detected_text = " | ".join(objects)
42
-
43
- # 显示检测结果在帧上
44
- cv2.putText(frame, detected_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA)
45
-
46
- # 转换并显示在 Streamlit
47
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
48
- stframe.image(frame_rgb, caption="检测到的物体", channels="RGB")
49
 
50
- video_capture.release()
51
- else:
52
- st.write("点击上方按钮以启动摄像头")
 
 
 
 
1
  import streamlit as st
2
+ from streamlit_webrtc import VideoTransformerBase, webrtc_streamer
 
3
  import tensorflow as tf
4
+ import numpy as np
5
  from tensorflow.keras.applications.mobilenet import preprocess_input, decode_predictions
6
 
7
  st.title("物体识别应用")
 
10
  # 加载 MobileNet 预训练模型
11
  model = tf.keras.applications.MobileNet(weights="imagenet")
12
 
13
+ class ObjectDetectionTransformer(VideoTransformerBase):
14
+ def transform(self, frame):
15
+ img = frame.to_ndarray(format="bgr24")
 
16
 
17
+ # 将图像调整大小并进行预处理
18
+ image_resized = cv2.resize(img, (224, 224))
19
+ image_array = np.expand_dims(image_resized, axis=0)
20
+ processed_image = preprocess_input(image_array)
 
 
 
 
 
 
21
 
22
+ # 进行物体识别
23
+ preds = model.predict(processed_image)
24
+ decoded_preds = decode_predictions(preds, top=3)[0]
25
+ objects = [f"{label}: {round(score * 100, 2)}%" for (_, label, score) in decoded_preds]
 
 
 
 
 
 
 
 
26
  detected_text = " | ".join(objects)
 
 
 
 
 
 
 
27
 
28
+ # 将检测结果写在图像上
29
+ cv2.putText(img, detected_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA)
30
+ return img
31
+
32
+ # 使用 streamlit-webrtc 启动摄像头流
33
+ webrtc_streamer(key="object-detection", video_transformer_factory=ObjectDetectionTransformer)