yolov3 / app.py
reputation's picture
Create app.py
6a89a50 verified
import os
import shutil
import numpy as np
import streamlit as st
import torch
from PIL import Image
from matplotlib import pyplot as plt, patches
from torch import optim
from torch.utils.data import DataLoader
import config
from dataset import YOLODataset
from model import YOLOv3
from utils import load_checkpoint, cells_to_bboxes, non_max_suppression
def plot_image(image, boxes, output_path="upload/output.png"):
    """Draw labelled bounding boxes over ``image`` and save the figure.

    Args:
        image: Image data; ``np.array(image)`` must yield a 3-D array of
            shape (height, width, channels).
        boxes: Iterable of 6-element boxes
            ``[class_pred, confidence, x, y, w, h]``. ``x``/``y`` is the box
            centre and ``w``/``h`` its size, as fractions of the image size
            (they are multiplied by the pixel width/height below). The
            confidence value is not drawn.
        output_path: File the rendered figure is written to. Defaults to
            ``upload/output.png``.

    Raises:
        ValueError: If a box does not have exactly 6 elements.
    """
    cmap = plt.get_cmap("tab20b")
    class_labels = (
        config.COCO_LABELS if config.DATASET == "COCO" else config.PASCAL_CLASSES
    )
    # One evenly spaced colour per class so each class keeps a stable colour.
    colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
    im = np.array(image)
    height, width, _ = im.shape

    fig, ax = plt.subplots(1)
    try:
        ax.imshow(im)
        for box in boxes:
            # Explicit check instead of `assert`, which is stripped under -O.
            if len(box) != 6:
                raise ValueError(
                    "box should contain class pred, confidence, x, y, width, height"
                )
            class_pred, _confidence, x, y, box_w, box_h = box
            color = colors[int(class_pred)]
            # Convert the centre-based box to its upper-left corner.
            upper_left_x = x - box_w / 2
            upper_left_y = y - box_h / 2
            ax.add_patch(
                patches.Rectangle(
                    (upper_left_x * width, upper_left_y * height),
                    box_w * width,
                    box_h * height,
                    linewidth=2,
                    edgecolor=color,
                    facecolor="none",
                )
            )
            ax.text(
                upper_left_x * width,
                upper_left_y * height,
                s=class_labels[int(class_pred)],
                color="white",
                verticalalignment="top",
                bbox={"color": color, "pad": 0},
            )
        fig.savefig(output_path)
    finally:
        # pyplot keeps a reference to every open figure; without closing it,
        # each call in the long-running app leaks one.
        plt.close(fig)
def plot_couple_examples(model, loader, thresh, iou_thresh, anchors):
    """Run ``model`` on the first batch of ``loader`` and plot filtered boxes.

    Predictions from every output scale are converted to boxes with
    ``cells_to_bboxes``. They are then filtered per image with
    ``non_max_suppression`` and drawn with ``plot_image``. The model's
    train/eval mode is restored afterwards, even if inference raises.

    Args:
        model: Detector whose forward pass returns one prediction tensor per
            scale. Each is unpacked as a 5-D tensor whose third dimension is
            the grid size ``S``.
        loader: Iterable of input batches; only the first batch is used.
        thresh: Confidence threshold passed to ``non_max_suppression``.
        iou_thresh: IoU threshold passed to ``non_max_suppression``.
        anchors: Per-scale anchors, in the same order as the model outputs.
    """
    was_training = model.training
    model.eval()
    try:
        x = next(iter(loader)).to(config.DEVICE)
        batch_size = x.shape[0]
        with torch.no_grad():
            out = model(x)
            bboxes = [[] for _ in range(batch_size)]
            # Pair each output scale with its anchors instead of assuming 3.
            for preds, anchor in zip(out, anchors):
                _, _, S, _, _ = preds.shape
                boxes_scale = cells_to_bboxes(preds, anchor, S=S, is_preds=True)
                for idx, box in enumerate(boxes_scale):
                    bboxes[idx] += box
    finally:
        # Restore whatever mode the caller had, rather than forcing train().
        model.train(was_training)

    for i in range(batch_size):
        nms_boxes = non_max_suppression(
            bboxes[i], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
        )
        plot_image(x[i].permute(1, 2, 0).detach().cpu(), nms_boxes)
def process():
    """Load the checkpointed YOLOv3 model and plot detections for ``upload``.

    Builds a ``YOLODataset`` from ``config.DATASET + "/train.csv"``, with
    images read from the ``upload`` directory. Inference and plotting are
    delegated to ``plot_couple_examples`` with a confidence threshold of 0.6
    and an IoU threshold of 0.5.
    """
    net = YOLOv3(num_classes=config.NUM_CLASSES).to(config.DEVICE)
    # load_checkpoint takes an optimizer argument, so one is built even though
    # this path never trains.
    adam = optim.Adam(
        net.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY
    )
    load_checkpoint(config.CHECKPOINT_FILE, net, adam, config.LEARNING_RATE)

    size = config.IMAGE_SIZE
    grid_sizes = [size // stride for stride in (32, 16, 8)]
    dataset = YOLODataset(
        config.DATASET + "/train.csv",
        transform=config.test_transforms,
        S=grid_sizes,
        img_dir="upload",
        label_dir=config.LABEL_DIR,
        anchors=config.ANCHORS,
        test=True,
    )
    loader = DataLoader(
        dataset=dataset,
        batch_size=1,
        num_workers=config.NUM_WORKERS,
        pin_memory=config.PIN_MEMORY,
        shuffle=True,
        drop_last=False,
    )

    # Broadcast each scale's grid size over its 3 anchors and both (w, h).
    per_scale = torch.tensor(config.S)[:, None, None].repeat(1, 3, 2)
    scaled_anchors = (torch.tensor(config.ANCHORS) * per_scale).to(config.DEVICE)

    plot_couple_examples(net, loader, 0.6, 0.5, scaled_anchors)
def _clear_directory(directory):
    """Delete every file and subdirectory inside ``directory`` (best effort).

    Failures are reported in the Streamlit UI rather than raised, so one
    undeletable entry does not block the rest of the page.
    """
    for name in os.listdir(directory):
        path = os.path.join(directory, name)
        try:
            if os.path.isfile(path):
                os.unlink(path)
            elif os.path.isdir(path):
                shutil.rmtree(path)
        except OSError as e:
            st.error(f"Error deleting file: {e}")


def main():
    """Streamlit page: upload a JPG, run detection, and show both images.

    On each upload the ``upload`` directory is emptied and the file is saved
    there as ``uploaded_image.jpg``. ``process`` is then run, and the uploaded
    image is shown next to ``output.png`` from the same directory.
    """
    st.title("YOLOv3 Object Detection")
    output_directory = "upload"
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_directory, exist_ok=True)

    uploaded_file = st.file_uploader("Choose an image...", type="jpg")
    if uploaded_file is None:
        return

    _clear_directory(output_directory)
    image_path = os.path.join(output_directory, "uploaded_image.jpg")
    with open(image_path, "wb") as f:
        f.write(uploaded_file.getvalue())

    process()

    st.image(image_path, caption="Uploaded Image", use_column_width=True)
    # Pass the path instead of a lazily-loaded Image.open() handle, which
    # would keep output.png open while the next upload tries to delete it.
    st.image(
        os.path.join(output_directory, "output.png"),
        caption="Object Detected",
        use_column_width=True,
    )
# Entry point: build the Streamlit page when this file is executed as a script.
if __name__ == "__main__":
    main()