| import numpy as np |
| from PIL import Image,ImageColor,ImageDraw,ImageFont |
| import torch |
| from torch import nn |
|
|
| import torchvision |
| from torchvision import datasets, models, transforms |
|
|
| import streamlit as st |
|
|
| |
| def plot_detection(image,prediction,idx2names,min_score = 0.8): |
| image_result = image.copy() |
| boxes,labels,scores = prediction['boxes'],prediction['labels'],prediction['scores'] |
| draw = ImageDraw.Draw(image_result) |
| for idx in range(boxes.shape[0]): |
| if scores[idx] >= min_score: |
| x1, y1, x2, y2 = boxes[idx][0], boxes[idx][1], boxes[idx][2], boxes[idx][3] |
| name = idx2names.get(str(labels[idx].item())) |
| score = scores[idx] |
| draw.rectangle((x1,y1,x2,y2), fill=None, outline ='lawngreen',width = 2) |
| draw.text((x1,y1),name+":\n"+str(round(score.item(),2)),fill="red") |
| return image_result |
|
|
|
|
| |
| @st.cache() |
| def load_model(): |
| num_classes = 91 |
| model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True,num_classes = num_classes) |
| if torch.cuda.is_available(): |
| model.to("cuda:0") |
| model.eval() |
| model.idx2names = {'0': 'background', '1': 'person', '2': 'bicycle', '3': 'car', |
| '4': 'motorcycle', '5': 'airplane', '6': 'bus', '7': 'train', '8': 'truck', '9': 'boat', |
| '10': 'traffic light', '11': 'fire hydrant', '13': 'stop sign', |
| '14': 'parking meter', '15': 'bench', '16': 'bird', '17': 'cat', |
| '18': 'dog', '19': 'horse', '20': 'sheep', '21': 'cow', '22': 'elephant', |
| '23': 'bear', '24': 'zebra', '25': 'giraffe', '27': 'backpack', |
| '28': 'umbrella', '31': 'handbag', '32': 'tie', '33': 'suitcase', |
| '34': 'frisbee', '35': 'skis', '36': 'snowboard', '37': 'sports ball', |
| '38': 'kite','39': 'baseball bat', '40': 'baseball glove', '41': 'skateboard', |
| '42': 'surfboard', '43': 'tennis racket', '44': 'bottle', '46': 'wine glass', |
| '47': 'cup', '48': 'fork', '49': 'knife', '50': 'spoon', '51': 'bowl', |
| '52': 'banana', '53': 'apple', '54': 'sandwich', '55': 'orange', |
| '56': 'broccoli', '57': 'carrot', '58': 'hot dog', '59': 'pizza', |
| '60': 'donut', '61': 'cake', '62': 'chair', '63': 'couch', |
| '64': 'potted plant', '65': 'bed', '67': 'dining table', |
| '70': 'toilet', '72': 'tv', '73': 'laptop', '74': 'mouse', |
| '75': 'remote', '76': 'keyboard', '77': 'cell phone', |
| '78': 'microwave', '79': 'oven', '80': 'toaster', |
| '81': 'sink', '82': 'refrigerator', '84': 'book', |
| '85': 'clock', '86': 'vase', '87': 'scissors', |
| '88': 'teddybear', '89': 'hair drier', '90': 'toothbrush'} |
| return model |
|
|
| def predict_detection(model,image_path,min_score=0.8): |
| |
| inputs = [] |
| img = Image.open(image_path).convert("RGB") |
| img_tensor = torch.from_numpy(np.array(img)/255.).permute(2,0,1).float() |
| if torch.cuda.is_available(): |
| img_tensor = img_tensor.cuda() |
| inputs.append(img_tensor) |
|
|
| |
| with torch.no_grad(): |
| predictions = model(inputs) |
|
|
| |
| img_result = plot_detection(img,predictions[0], |
| model.idx2names,min_score = min_score) |
| return img_result |
| |
| st.title("FasterRCNN功能演示") |
|
|
| st.header("FasterRCNN Input:") |
| image_file = st.file_uploader("upload a image file(jpg/png) to predict:") |
| if image_file is not None: |
| try: |
| st.image(image_file) |
| except Exception as err: |
| st.write(err) |
| else: |
| image_file = "horseman.png" |
| st.image(image_file) |
| |
| min_score = st.slider(label="choose the min_score parameter:",min_value=0.1,max_value=0.98,value=0.8) |
|
|
| st.header("FasterRCNN Prediction:") |
| with st.spinner('waitting for prediction...'): |
| model = load_model() |
| img_result = predict_detection(model,image_file,min_score=min_score) |
| st.image(img_result) |