ake178178 commited on
Commit
04a66e5
·
verified ·
1 Parent(s): bb7ec90

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -43
app.py CHANGED
@@ -1,47 +1,35 @@
1
- # app.py
2
-
3
  import streamlit as st
 
 
4
  import torch
 
5
  from PIL import Image
6
- import camera # 引入拍照功能
7
- import numpy as np
8
- import cv2
9
-
10
- # 物体识别函数
11
- def detect_objects(image_path):
12
- model = torch.hub.load('ultralytics/yolov5', 'yolov5s') # 使用YOLOv5模型
13
- img = Image.open(image_path)
14
- results = model(img)
15
- return results
16
-
17
- def main():
18
- st.title("摄像头拍照并进行物体识别")
19
-
20
- # 拍照
21
- if st.button('拍照'):
22
- camera.take_picture()
23
- st.write("照片已拍摄并保存")
24
-
25
- # 物体识别
26
- if st.button('物体识别'):
27
- st.write("正在进行物体识别...")
28
- image_path = "captured_image.jpg"
29
-
30
- results = detect_objects(image_path)
31
- # 显示原始图片
32
- st.image(image_path, caption='原始图片', use_column_width=True)
33
-
34
- # 显示识别的结果
35
- results.render() # 在图片上绘制检测到的物体
36
- detected_img = Image.fromarray(results.imgs[0])
37
-
38
- st.image(detected_img, caption='物体识别结果', use_column_width=True)
39
-
40
- # 显示从左到右的物体列表
41
- st.write("识别到的物体:")
42
- objects = results.pandas().xyxy[0]['name'].tolist()
43
- objects_sorted_by_x = sorted(objects)
44
- st.write(objects_sorted_by_x)
45
 
46
- if __name__ == '__main__':
47
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import cv2
3
+ import numpy as np
4
  import torch
5
+ from torchvision import transforms
6
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ # 这里你需要替换成你训练好的模型的路径和类名
9
+ model = torch.hub.load('ultralytics/yolov5', 'yolov5s') # 加载 YOLOv5 模型
10
+ classes = model.names
11
+
12
+ def object_detection(img_path):
13
+ """
14
+ 对图像进行物体识别。
15
+ """
16
+ results = model(img_path)
17
+ return results.pandas().xyxy[0]
18
+
19
+ def show_results(results):
20
+ """
21
+ 在 Streamlit 中显示物体识别结果。
22
+ """
23
+ for index, row in results.iterrows():
24
+ left, top, right, bottom = row['xmin'], row['ymin'], row['xmax'], row['ymax']
25
+ label = classes[int(row['class'])]
26
+ st.write(f"物体:{label}, 位置:({left}, {top}) - ({right}, {bottom})")
27
+
28
+ if __name__ == "__main__":
29
+ if st.button("选择图片"):
30
+ uploaded_file = st.file_uploader("Choose an image file", type=["png", "jpg", "jpeg"])
31
+ if uploaded_file is not None:
32
+ img = Image.open(uploaded_file)
33
+ img_array = np.array(img)
34
+ results = object_detection(img_array)
35
+ show_results(results)