Spaces:
Runtime error
Runtime error
import PIL
import PIL.Image  # `import PIL` alone does not load the Image submodule
import streamlit as st
import torch

from model import predict, prepare_prediction, predict_class
from utils import plot_img_no_mask, get_models
# Checkpoint files handed to get_models(), which returns the detection model
# and the classifier built from them (judging by the filenames: an
# EfficientDet/IceVision detector and a 7-class ViT classifier).
DET_CKPT, CLASS_CKPT = (
    'efficientDet_icevision.ckpt',  # detection model checkpoint
    'class_ViT_taco_7_class.pth',   # classifier checkpoint
)
# --- Image source: a user upload or one of the bundled example images ----
st.subheader('Upload Custom Image')


def _clear_example_selection():
    """Forget any previously selected example so a fresh upload takes effect."""
    st.session_state.pop('selected_example', None)


image_file = st.file_uploader("Upload Images", type=["png", "jpg", "jpeg"],
                              on_change=_clear_example_selection)

st.subheader('Example Images')
example_imgs = [
    'example_imgs/basura_4_2.jpg',
    'example_imgs/basura_1.jpg',
    'example_imgs/basura_3.jpg',
]
# One thumbnail + "Select Image" button per example; captions and widget keys
# are numbered from 1 ('Image_1', 'Image_2', ...).
for idx, img_path in enumerate(example_imgs, start=1):
    with st.container():
        st.image(img_path, width=150, caption=str(idx))
        if st.button('Select Image', key=f'Image_{idx}'):
            # st.button is True only on the rerun triggered by the click.
            # Remember the choice so later reruns (e.g. moving a threshold
            # slider) keep showing the selected example.
            st.session_state['selected_example'] = img_path

# A selected example takes precedence over an upload until a new file is
# uploaded (the uploader's on_change callback clears the selection).
selected_example = st.session_state.get('selected_example')
if selected_example is not None:
    image_file = selected_example
# --- Detection parameters --------------------------------------------------
st.subheader('Detection parameters')
# Both thresholds share the same [0.0, 1.0] range and 0.1 step; only the
# label and the default differ. detection_threshold is passed to predict()
# and nms_threshold to prepare_prediction() further down.
_threshold_slider_kwargs = dict(min_value=0.0, max_value=1.0, step=0.1)
detection_threshold = st.slider('Detection threshold', value=0.5,
                                **_threshold_slider_kwargs)
nms_threshold = st.slider('NMS threshold', value=0.3,
                          **_threshold_slider_kwargs)
# --- Run detection + classification and display the annotated image -------
st.subheader('Prediction')
if image_file is not None:
    # NOTE(review): this runs on every Streamlit rerun (any widget change).
    # Unless utils.get_models caches its result internally, consider wrapping
    # it with st.cache_resource (Streamlit >= 1.18) — verify.
    det_model, classifier = get_models(DET_CKPT, CLASS_CKPT)

    print('Getting predictions')
    # image_file is either an uploaded file object or an example-image path
    # string; presumably predict() accepts both — confirm in model.py.
    pred_dict = predict(det_model, image_file, detection_threshold)

    print('Fixing the preds')
    # nms_threshold is applied here, so presumably this step performs
    # non-maximum suppression on the raw detections — verify in model.py.
    boxes, image = prepare_prediction(pred_dict, nms_threshold)

    print('Predicting classes')
    labels = predict_class(classifier, image, boxes)

    print('Plotting')
    # NOTE(review): the result is read back from 'img.png', so presumably
    # plot_img_no_mask saves its figure there. A fixed filename in the
    # working directory is shared by all concurrent sessions — confirm and
    # consider a per-session temporary file.
    plot_img_no_mask(image, boxes, labels)
    # Use a context manager so the file handle is released after display
    # (st.image is expected to encode the image during the call — verify).
    with PIL.Image.open('img.png') as img:
        st.image(img, width=750)