# Cesar Garcia
# Imports
from transformers import pipeline
import streamlit as st
from datasets import load_dataset, Image
import torch
# from torch.utils.tensorboard.summary import draw_boxes
from transformers import (AutoConfig, AutoModelForObjectDetection, AutoImageProcessor, pipeline)
from PIL import Image, ImageDraw, ImageFont
import pkgutil
# Load the datasets and build the detection pipeline.
# NOTE(review): Streamlit re-executes this whole script on every user
# interaction.  The original flags (download_mode="force_redownload" and
# force_download=True) therefore re-downloaded both datasets and the model
# weights on every rerun.  Caching the loaders with @st.cache_resource keeps
# the first-run download while making every subsequent rerun instant.


@st.cache_resource
def _load_datasets():
    """Load the original and augmented spinach datasets from Hugging Face."""
    original = load_dataset("gcesar/spinach")
    augmented = load_dataset("gcesar/spinach_augment")
    return original, augmented


@st.cache_resource
def _load_pipeline():
    """Create the object-detection pipeline for the weed DETR model."""
    # trust_remote_code is required: the model repo ships custom model code.
    return pipeline(task="object-detection", model="haiquanua/weed_detr",
                    trust_remote_code=True)


dataset1, dataset2 = _load_datasets()
pipe = _load_pipeline()
# Professor Haiquan Li function draw_boxes from haiquanua/BAT102
def draw_boxes(im: Image.Image, preds, threshold: float = 0.25,
               class_map=None) -> Image.Image:
    """Draw bounding boxes + labels on a PIL image.

    Args:
        im: Input image; converted to RGB before drawing.
        preds: Iterable of pipeline predictions, each a dict with keys
            'score', 'label', and 'box' ({'xmin','ymin','xmax','ymax'}).
        threshold: Minimum confidence; predictions below it are skipped.
        class_map: Optional mapping from raw model labels to display names.
            Defaults to the weed / lettuce / spinach label set.

    Returns:
        The RGB image with boxes and score labels drawn on it.
    """
    # Avoid a mutable default argument: build the default map per call.
    if class_map is None:
        class_map = {"LABEL_0": "Weed", "LABEL_1": "lettuce", "LABEL_2": "Spinach"}
    im = im.convert("RGB")
    draw = ImageDraw.Draw(im)
    try:
        # A small default bitmap font (portable in Spaces)
        font = ImageFont.load_default()
    except Exception:
        font = None
    for p in preds:
        if p.get("score", 0) < threshold:
            continue
        box = p["box"]  # {'xmin','ymin','xmax','ymax'}
        class_label = class_map.get(p['label'], p['label'])
        label = f"{class_label} {p['score']:.2f}"
        xy = [(box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])]
        # Per-class colors: weed = red, lettuce = green, anything else yellow.
        if p['label'] == 'LABEL_0':
            col = (255, 0, 0)
        elif p['label'] == 'LABEL_1':
            col = (0, 255, 0)
        else:
            col = (255, 255, 0)
        # BUG FIX: the original computed `col` but always drew red boxes;
        # draw the rectangle in the per-class color.
        draw.rectangle(xy, outline=col, width=3)
        # Label background sized from the text.  Pillow's default bitmap font
        # has no .size attribute, so fall back to a fixed height instead of
        # raising AttributeError.
        tw = draw.textlength(label, font=font)
        th = 14 if font is None else getattr(font, "size", 8) + 6
        x0, y0 = box["xmin"], max(0, box["ymin"] - th - 2)
        draw.rectangle([x0, y0, x0 + tw + 6, y0 + th + 2], fill=(0, 0, 0))
        draw.text((x0 + 3, y0 + 2), label, fill=(255, 255, 255), font=font)
    # (Removed dead code: a per-label `counts`/`caption` summary was computed
    # here but never used or returned.)
    return im
# Set title
st.title("Weed Detector")


def _show_detections(dataset, subtitle, column, n_images=20):
    """Render detections for the first n_images of a dataset in one column."""
    with column:
        st.subheader(subtitle)
        for i in range(n_images):
            im = dataset["train"][i]["image"]
            # Run the object-detection pipeline on the image
            preds = pipe(im)
            # Overlay the predicted boxes and labels
            annotated = draw_boxes(im, preds)
            # Display with Streamlit (st.write renders PIL images, as the
            # original did)
            st.write(annotated)


# Two side-by-side columns: initial dataset vs augmented dataset.
# The per-column loops were copy-pasted in the original; one helper now
# renders both.
col1, col2 = st.columns(2)
_show_detections(dataset1, "Initial Dataset", col1)
_show_detections(dataset2, "Augmented Dataset", col2)