Spaces:
Sleeping
Sleeping
| import os | |
| import shutil | |
| import numpy as np | |
| import streamlit as st | |
| import torch | |
| from PIL import Image | |
| from matplotlib import pyplot as plt, patches | |
| from torch import optim | |
| from torch.utils.data import DataLoader | |
| import config | |
| from dataset import YOLODataset | |
| from model import YOLOv3 | |
| from utils import load_checkpoint, cells_to_bboxes, non_max_suppression | |
def plot_image(image, boxes, output_path="upload/output.png"):
    """Draw bounding boxes with class labels on ``image`` and save the figure.

    Args:
        image: Array-like image; ``np.array(image)`` must be 3-D and is
            unpacked as (height, width, channels).
        boxes: Iterable of 6-element boxes laid out as
            ``[class_pred, confidence, x_center, y_center, box_width, box_height]``.
            Centre coordinates and sizes are fractions of the image size (they
            are multiplied by the pixel width/height before drawing). The
            confidence entry is not used here.
        output_path: File the rendered figure is saved to. Defaults to
            ``"upload/output.png"``.

    Raises:
        ValueError: If a box does not contain exactly six values.
    """
    cmap = plt.get_cmap("tab20b")
    class_labels = (
        config.COCO_LABELS if config.DATASET == "COCO" else config.PASCAL_CLASSES
    )
    # One colour per class, spread evenly across the colormap.
    colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
    im = np.array(image)
    height, width, _ = im.shape

    fig, ax = plt.subplots(1)
    try:
        ax.imshow(im)
        for box in boxes:
            # Explicit check instead of `assert`, which is stripped under -O.
            if len(box) != 6:
                raise ValueError(
                    "box should contain class pred, confidence, x, y, width, height"
                )
            class_idx = int(box[0])
            x_center, y_center, box_w, box_h = box[2:]
            # Boxes are centre-based; matplotlib wants the upper-left corner
            # (in pixels) plus the width/height.
            upper_left_x = (x_center - box_w / 2) * width
            upper_left_y = (y_center - box_h / 2) * height
            ax.add_patch(
                patches.Rectangle(
                    (upper_left_x, upper_left_y),
                    box_w * width,
                    box_h * height,
                    linewidth=2,
                    edgecolor=colors[class_idx],
                    facecolor="none",
                )
            )
            ax.text(
                upper_left_x,
                upper_left_y,
                s=class_labels[class_idx],
                color="white",
                verticalalignment="top",
                bbox={"color": colors[class_idx], "pad": 0},
            )
        fig.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until it is closed; without this each
        # call leaks a figure in a long-running process such as a Streamlit app.
        plt.close(fig)
def plot_couple_examples(model, loader, thresh, iou_thresh, anchors):
    """Predict on the first batch from ``loader`` and plot NMS-filtered boxes.

    Args:
        model: Detector whose forward pass returns a sequence of 5-D tensors,
            one per prediction scale, unpacked here as (batch, A, S, _, _).
        loader: Iterable of input image batches; only the first batch is used.
        thresh: Confidence threshold passed to ``non_max_suppression``.
        iou_thresh: IoU threshold passed to ``non_max_suppression``.
        anchors: Per-scale anchors, in the same order as the model outputs.

    The model's previous train/eval mode is restored afterwards, even if
    inference raises.
    """
    was_training = model.training
    model.eval()
    try:
        x = next(iter(loader)).to(config.DEVICE)
        batch_size = x.shape[0]
        with torch.no_grad():
            out = model(x)
        bboxes = [[] for _ in range(batch_size)]
        # Walk however many scales the model produced (rather than a hard-coded
        # 3), pairing each output with its own anchors.
        for scale_out, anchor in zip(out, anchors):
            _, _, S, _, _ = scale_out.shape
            boxes_scale = cells_to_bboxes(scale_out, anchor, S=S, is_preds=True)
            for idx, box in enumerate(boxes_scale):
                bboxes[idx] += box
    finally:
        model.train(was_training)

    # NOTE: every plot is saved to the same file by plot_image, so with
    # batch_size > 1 only the last image's plot survives.
    for image, image_boxes in zip(x, bboxes):
        nms_boxes = non_max_suppression(
            image_boxes,
            iou_threshold=iou_thresh,
            threshold=thresh,
            box_format="midpoint",
        )
        plot_image(image.permute(1, 2, 0).detach().cpu(), nms_boxes)
def process(thresh=0.6, iou_thresh=0.5):
    """Load the YOLOv3 checkpoint and plot detections for the image in ``upload/``.

    Builds the model, restores weights from ``config.CHECKPOINT_FILE``, wraps
    the ``upload`` directory in a ``YOLODataset`` and hands the first batch to
    ``plot_couple_examples``.

    Args:
        thresh: Confidence threshold for non-max suppression (default 0.6).
        iou_thresh: IoU threshold for non-max suppression (default 0.5).
    """
    model = YOLOv3(num_classes=config.NUM_CLASSES).to(config.DEVICE)
    # load_checkpoint takes an optimizer argument; this one is created only to
    # satisfy that call and is not otherwise used here.
    optimizer = optim.Adam(
        model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY
    )
    load_checkpoint(config.CHECKPOINT_FILE, model, optimizer, config.LEARNING_RATE)

    image_size = config.IMAGE_SIZE
    # Grid sizes of the three prediction scales. Computed once and used for both
    # the dataset and the anchor scaling so the two cannot drift apart.
    grid_sizes = [image_size // 32, image_size // 16, image_size // 8]

    # NOTE(review): reads config.DATASET + "/train.csv" with img_dir="upload";
    # confirm how YOLODataset(test=True) selects the uploaded image.
    dataset = YOLODataset(
        config.DATASET + "/train.csv",
        transform=config.test_transforms,
        S=grid_sizes,
        img_dir="upload",
        label_dir=config.LABEL_DIR,
        anchors=config.ANCHORS,
        test=True,
    )
    loader = DataLoader(
        dataset=dataset,
        batch_size=1,
        num_workers=config.NUM_WORKERS,
        pin_memory=config.PIN_MEMORY,
        shuffle=True,
        drop_last=False,
    )
    # Scale each scale's anchors by that scale's grid size; repeat(1, 3, 2)
    # assumes three (w, h) anchors per scale.
    scaled_anchors = (
        torch.tensor(config.ANCHORS)
        * torch.tensor(grid_sizes).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    ).to(config.DEVICE)
    plot_couple_examples(model, loader, thresh, iou_thresh, scaled_anchors)
def main():
    """Render the Streamlit page: upload a JPEG, run detection, show both images."""
    st.title("YOLOv3 Object Detection")
    output_directory = "upload"
    os.makedirs(output_directory, exist_ok=True)

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg"])
    if uploaded_file is None:
        return

    # Remove leftovers from a previous upload; process() reads its images from
    # this directory. Cleanup is best-effort: failures are reported, not raised.
    # NOTE(review): the directory is shared by every session of this app, so
    # concurrent users can overwrite each other's files.
    for file_name in os.listdir(output_directory):
        file_path = os.path.join(output_directory, file_name)
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except OSError as e:
            st.error(f"Error deleting file: {e}")

    image_path = os.path.join(output_directory, "uploaded_image.jpg")
    with open(image_path, "wb") as f:
        f.write(uploaded_file.getvalue())

    process()

    st.image(image_path, caption="Uploaded Image", use_column_width=True)
    # Pass the path instead of Image.open(...) so no file handle is left open.
    st.image(
        os.path.join(output_directory, "output.png"),
        caption="Object Detected",
        use_column_width=True,
    )
# Entry point: build the page when this file is executed as a script.
if __name__ == "__main__":
    main()