ake178178's picture
Update app.py
079dc08 verified
raw
history blame
1.16 kB
import streamlit as st
import cv2
import numpy as np
import numpy
import torch
from torchvision import transforms
from PIL import Image
# Replace the hub repository / model name below with the path and class names
# of your own trained model if you are not using the stock checkpoint.
# NOTE(review): this load runs at module level, i.e. every time the script is
# executed; presumably Streamlit re-executes the script on each interaction,
# so caching the loaded model may be worthwhile -- confirm before changing.
model = torch.hub.load('ultralytics/yolov5', 'yolov5s') # load the YOLOv5 model
classes = model.names  # label lookup, indexed by integer class id in show_results
def object_detection(img_path):
    """Run the module-level YOLOv5 ``model`` on a single image.

    Args:
        img_path: The image handed to the model unchanged. Despite the
            name, the caller in this file passes a NumPy array built from
            a PIL image rather than a filesystem path.

    Returns:
        The first element of ``results.pandas().xyxy`` for the model
        output, i.e. the detections table for the one input image.
    """
    detections = model(img_path)
    per_image_tables = detections.pandas().xyxy
    return per_image_tables[0]
def show_results(results):
    """Write one line per detection to the Streamlit page.

    Args:
        results: A table exposing ``iterrows()`` whose rows provide the
            ``xmin``, ``ymin``, ``xmax``, ``ymax`` and ``class`` entries
            (the value returned by ``object_detection`` in this file).
    """
    box_keys = ("xmin", "ymin", "xmax", "ymax")
    for _, detection in results.iterrows():
        left, top, right, bottom = (detection[key] for key in box_keys)
        # Map the numeric class id to its label via the module-level table.
        label = classes[int(detection["class"])]
        st.write(f"物体:{label}, 位置:({left}, {top}) - ({right}, {bottom})")
if __name__ == "__main__":
    # st.button() is True only during the single rerun triggered by the
    # click. Creating the uploader directly under it means the uploader
    # vanishes on the next rerun (the one caused by picking a file), so the
    # upload is never processed. Remember the click in session state instead.
    if st.button("选择图片"):
        st.session_state["show_uploader"] = True

    if st.session_state.get("show_uploader", False):
        uploaded_file = st.file_uploader(
            "Choose an image file", type=["png", "jpg", "jpeg"]
        )
        if uploaded_file is not None:
            # Normalise to 3-channel RGB so palette or alpha PNGs reach the
            # detector as an ordinary H x W x 3 colour array.
            img = Image.open(uploaded_file).convert("RGB")
            img_array = np.array(img)
            results = object_detection(img_array)
            show_results(results)