# NOTE(review): the first three lines were Hugging Face Spaces build-log
# residue ("Spaces:" / "Build error") pasted into the source — replaced
# with this comment so the module parses.
import csv
import os.path
import time

import cv2
import gdown
import numpy as np
import streamlit as st
import torch
from PIL import Image
def load_classes(csv_reader):
    """
    Parse (class_name, class_id) rows from a csv reader into a dict.

    :param csv_reader: csv reader yielding [class_name, class_id] rows
    :return: dict mapping class_name -> int class_id
    :raises ValueError: on malformed rows, non-integer ids, or duplicate names
    """
    result = {}
    # enumerate from 1 so error messages use human-friendly line numbers
    for line, row in enumerate(csv_reader, start=1):
        try:
            class_name, class_id = row
        except ValueError:
            raise ValueError('line {}: format should be \'class_name,class_id\''.format(line))
        # int() can raise its own ValueError; wrap it so the user sees
        # which line of the csv is broken instead of a bare conversion error
        try:
            class_id = int(class_id)
        except ValueError:
            raise ValueError('line {}: malformed class_id: \'{}\''.format(line, class_id))
        if class_name in result:
            raise ValueError('line {}: duplicate class name: \'{}\''.format(line, class_name))
        result[class_name] = class_id
    return result
def draw_caption(image, box, caption):
    """
    Render a caption just above a bounding box, in place.

    The text is drawn twice — a thick black pass then a thin white pass —
    so it stays readable on both light and dark backgrounds.

    :param image: image to draw on (modified in place)
    :param box: bounding box (x1, y1, x2, y2); only the top-left corner is used
    :param caption: caption text
    :return: None
    """
    anchor = np.array(box).astype(int)
    position = (anchor[0], anchor[1] - 10)
    # black outline, then white fill
    cv2.putText(image, caption, position, cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 0), 2)
    cv2.putText(image, caption, position, cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255), 1)
def load_labels():
    """
    Read data/labels.csv and build an id -> class-name mapping.

    :return: dict mapping int class_id -> class_name
    """
    with open("data/labels.csv", 'r') as f:
        name_to_id = load_classes(csv.reader(f, delimiter=','))
    # Invert the mapping: model predictions are indexed by class id.
    return {class_id: name for name, class_id in name_to_id.items()}
def download_models(ids):
    """
    Fetch model weights from Google Drive into model/ unless already present.

    :param ids: mapping of model name -> Google Drive file id
    :return: None
    """
    with st.spinner('Downloading models, this may take a minute...'):
        for name, drive_id in ids.items():
            destination = f"model/{name}.pt"
            if os.path.isfile(destination):
                # already downloaded on a previous run
                continue
            gdown.download(url=f"https://drive.google.com/uc?id={drive_id}", output=destination)
def load_model(model_path, prefix: str = 'model/'):
    """
    Load a serialized torch model from disk and prepare it for inference.

    :param model_path: model name (file name without the .pt extension)
    :param prefix: directory prefix where model files live
    :return: model in eval mode, on GPU if available else CPU
    """
    checkpoint = f"{prefix}{model_path}.pt"
    if torch.cuda.is_available():
        model = torch.load(checkpoint).to('cuda')
    else:
        model = torch.load(checkpoint, map_location=torch.device('cpu'))
        # NOTE(review): assumes the checkpoint was saved wrapped (DataParallel-style
        # `.module`) and only unwraps it on the CPU path — confirm intended.
        model = model.module.cpu()
    model.training = False
    model.eval()
    return model
def process_img(model, image, labels, caption: bool = True, thickness=2):
    """
    Run the detector on one image and draw the detections onto a copy.

    The image is rescaled so its smallest side is 608 px (capped so the
    largest side stays <= 1024), zero-padded to multiples of 32, and
    ImageNet-normalized before inference. Detections with score > 0.5 are
    drawn back onto the original-resolution copy.

    :param model: inference model returning (scores, classification, transformed_anchors)
    :param image: RGB image as an HxWxC numpy array
    :param labels: dict mapping int class id -> label name
    :param caption: whether to draw label captions above boxes
    :param thickness: NOTE(review): currently unused — bbox thickness is
        derived from the image size below; confirm intended behavior
    :return: copy of the input image with detections drawn
    """
    # Color per label (RGB); hoisted out of the detection loop.
    colors = {
        'with_mask': (0, 255, 0),
        'without_mask': (255, 0, 0),
        'mask_weared_incorrect': (190, 100, 20)
    }
    image_orig = image.copy()
    rows, cols, cns = image.shape
    smallest_side = min(rows, cols)
    # Rescale so the smallest side hits min_side, unless that would push
    # the largest side past max_side.
    min_side = 608
    max_side = 1024
    scale = min_side / smallest_side
    largest_side = max(rows, cols)
    if largest_side * scale > max_side:
        scale = max_side / largest_side
    image = cv2.resize(image, (int(round(cols * scale)), int(round(rows * scale))))
    rows, cols, cns = image.shape
    # Zero-pad up to a multiple of 32 (network stride). The trailing "% 32"
    # avoids padding a needless full 32 px when the side already is a multiple.
    pad_w = (32 - rows % 32) % 32
    pad_h = (32 - cols % 32) % 32
    new_image = np.zeros((rows + pad_w, cols + pad_h, cns), dtype=np.float32)
    new_image[:rows, :cols, :] = image.astype(np.float32)
    image = new_image
    # ImageNet mean/std normalization.
    image /= 255
    image -= [0.485, 0.456, 0.406]
    image /= [0.229, 0.224, 0.225]
    # HWC -> NCHW with a leading batch dimension.
    image = np.expand_dims(image, 0)
    image = np.transpose(image, (0, 3, 1, 2))
    with torch.no_grad():
        image = torch.from_numpy(image)
        if torch.cuda.is_available():
            image = image.cuda()
        # (Removed an unused timing block whose local `st = time.time()`
        # shadowed the module-level streamlit import.)
        scores, classification, transformed_anchors = model(image.float())
        # Keep detections above a fixed 0.5 confidence threshold.
        idxs = np.where(scores.cpu() > 0.5)
        for j in range(idxs[0].shape[0]):
            bbox = transformed_anchors[idxs[0][j], :]
            # Map box coordinates back to the original image resolution.
            x1 = int(bbox[0] / scale)
            y1 = int(bbox[1] / scale)
            x2 = int(bbox[2] / scale)
            y2 = int(bbox[3] / scale)
            label_name = labels[int(classification[idxs[0][j]])]
            cap = '{}'.format(label_name) if caption else ''
            draw_caption(image_orig, (x1, y1, x2, y2), cap)
            cv2.rectangle(image_orig, (x1, y1), (x2, y2), color=colors[label_name],
                          thickness=int(1 * (smallest_side / 100)))
    return image_orig
# --- Page config ---
st.set_page_config(layout="centered")
st.title("Face Mask Detection")
st.write('Face Mask Detection on images, videos and webcam feed with ResNet[18~152] models. ')
# Plain string: the original used an f-string with no placeholders.
st.markdown("__Labels:__ with_mask, without_mask, mask_weared_incorrect")
# Google Drive ids of the downloadable model weights (kept in Streamlit secrets).
ids = {
    'resnet50_20': st.secrets['resnet50'],
    'resnet152_20': st.secrets['resnet152'],
}
# Download any model not already cached locally.
download_models(ids)
# Split page into columns.
left, right = st.columns([5, 3])
# Model selection.
labels = load_labels()
model_path = right.selectbox('Choose a model', options=list(ids), index=0)
model = load_model(model_path=model_path) if model_path != '' else None
# Validation-example selection.
index = left.number_input('', min_value=0, max_value=852, value=495, help='Choose an image. ')
# Uploader: a user image takes precedence over the selected example.
uploaded = st.file_uploader("Try it out with your own image!", type=['.jpg', '.png', '.jfif'])
if uploaded is not None:
    # Convert the uploaded file to an RGB numpy array.
    image = Image.open(uploaded)
    image = np.array(image)
else:
    # Fall back to the selected validation example
    # (OpenCV loads BGR, so convert to RGB).
    image = cv2.imread(f'data/validation/image/maksssksksss{index}.jpg')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Run detection and show the annotated image.
with st.spinner('Please wait while the image is being processed... This may take a while. '):
    image = process_img(model, image, labels, caption=False)
    left.image(cv2.resize(image, (450, 300)))
# Legend: bbox color -> label.
right.write({
    'green': 'with_mask',
    'orange': 'mask_weared_incorrect',
    'red': 'without_mask'
})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
right.write(device)
# Example gallery. List the directory once (the original listed it twice and
# built an unused `captions` list from the first call).
images = [Image.open(f'data/examples/{name}') for name in os.listdir('data/examples/')]
st.image(images, width=350)