|
|
|
|
|
|
|
|
|
|
|
from transformers import pipeline |
|
|
import streamlit as st |
|
|
from datasets import load_dataset, Image |
|
|
import torch |
|
|
|
|
|
from transformers import (AutoConfig, AutoModelForObjectDetection, AutoImageProcessor, pipeline) |
|
|
from PIL import Image, ImageDraw, ImageFont |
|
|
import pkgutil |
|
|
|
|
|
|
|
|
@st.cache_resource
def _load_datasets():
    """Download the original and augmented spinach datasets.

    Streamlit re-executes this script top-to-bottom on every rerun. Caching
    the result means the forced re-download happens once per server process
    instead of on every user interaction.

    Returns:
        ``(original, augmented)`` as returned by ``load_dataset``.
    """
    original = load_dataset("gcesar/spinach", download_mode="force_redownload")
    augmented = load_dataset("gcesar/spinach_augment", download_mode="force_redownload")
    return original, augmented


dataset1, dataset2 = _load_datasets()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource
def _load_detector():
    """Build the weed object-detection pipeline once per server process.

    Without caching, Streamlit would rebuild (and re-fetch) the model on
    every script rerun.
    """
    # NOTE(review): trust_remote_code=True executes Python code shipped in the
    # model repository; keep it only if "haiquanua/weed_detr" is trusted.
    return pipeline(
        task="object-detection",
        model="haiquanua/weed_detr",
        trust_remote_code=True,
        force_download=True,
    )


pipe = _load_detector()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def draw_boxes(im: Image.Image, preds, threshold: float = 0.25,
               class_map=None) -> Image.Image:
    """Draw bounding boxes + labels on a PIL image.

    Args:
        im: Source image. It is converted to RGB, which yields a new image,
            so the caller's image is left untouched.
        preds: Iterable of prediction dicts, each with ``"box"`` (keys
            ``xmin``/``ymin``/``xmax``/``ymax``), ``"label"`` and ``"score"``.
            Presumably the output of the object-detection pipeline; confirm
            against the caller.
        threshold: Predictions whose score is below this value are skipped.
        class_map: Maps raw model labels to display names. When ``None``,
            ``{"LABEL_0": "Weed", "LABEL_1": "lettuce", "LABEL_2": "Spinach"}``
            is used. Labels not in the map are shown verbatim.

    Returns:
        The annotated RGB image.
    """
    if class_map is None:
        class_map = {"LABEL_0": "Weed", "LABEL_1": "lettuce", "LABEL_2": "Spinach"}
    # Outline colour per raw label; any other label is drawn in yellow.
    outline_colours = {"LABEL_0": (255, 0, 0), "LABEL_1": (0, 255, 0)}

    im = im.convert("RGB")
    draw = ImageDraw.Draw(im)

    try:
        font = ImageFont.load_default()
    except Exception:
        font = None

    # Label-strip height is the same for every box, so compute it once.
    # Bitmap fonts returned by older Pillow versions have no ``size``
    # attribute; fall back to the fixed 14px used when no font is available.
    font_size = getattr(font, "size", None)
    text_h = 14 if font_size is None else font_size + 6

    for p in preds:
        score = p.get("score", 0)
        if score < threshold:
            continue

        raw_label = p["label"]
        box = p["box"]
        label = f"{class_map.get(raw_label, raw_label)} {score:.2f}"
        colour = outline_colours.get(raw_label, "yellow")

        draw.rectangle(
            [(box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])],
            outline=colour,
            width=3,
        )

        # Black label strip just above the box, clamped to the top edge.
        text_w = draw.textlength(label, font=font)
        x0, y0 = box["xmin"], max(0, box["ymin"] - text_h - 2)
        draw.rectangle([x0, y0, x0 + text_w + 6, y0 + text_h + 2], fill=(0, 0, 0))
        draw.text((x0 + 3, y0 + 2), label, fill=(255, 255, 255), font=font)

    return im
|
|
|
|
|
|
|
|
|
|
|
def _show_detections(dataset, title, n_images=20):
    """Render a subheader plus annotated detections for a dataset's train split.

    Runs the module-level ``pipe`` on up to ``n_images`` images of
    ``dataset["train"]`` and displays each one annotated by ``draw_boxes``.
    Fewer images are shown if the split is shorter than ``n_images``.
    """
    st.subheader(title)
    split = dataset["train"]
    for i in range(min(n_images, len(split))):
        image = split[i]["image"]
        st.image(draw_boxes(image, pipe(image)))


st.title("Weed Detector")

col1, col2 = st.columns(2)

with col1:
    _show_detections(dataset1, "Initial Dataset")

with col2:
    _show_detections(dataset2, "Augmented Dataset")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|