# BAT102 / app.py
# Commit 0444951 (CG): "Removed mps"
# Cesar Garcia
# Imports
from transformers import pipeline
import streamlit as st
from datasets import load_dataset, Image
import torch
# from torch.utils.tensorboard.summary import draw_boxes
from transformers import (AutoConfig, AutoModelForObjectDetection, AutoImageProcessor, pipeline)
from PIL import Image, ImageDraw, ImageFont
import pkgutil
# Load datasets and model from Hugging Face.
# Streamlit re-executes this whole script on every page load / interaction, so
# the downloads are wrapped in st.cache_resource: they run once per server
# process instead of on every rerun. force_redownload / force_download are kept
# so each fresh process still picks up the latest dataset and model files.
@st.cache_resource
def _load_datasets():
    """Download the original and the augmented spinach datasets from the Hub."""
    original = load_dataset("gcesar/spinach", download_mode="force_redownload")
    augmented = load_dataset("gcesar/spinach_augment", download_mode="force_redownload")
    return original, augmented


@st.cache_resource
def _load_pipeline():
    """Build the object-detection pipeline for the haiquanua/weed_detr model."""
    # NOTE(review): trust_remote_code=True executes Python code shipped inside
    # the model repository; keep it only for repositories you trust.
    return pipeline(task="object-detection", model="haiquanua/weed_detr",
                    trust_remote_code=True, force_download=True)


dataset1, dataset2 = _load_datasets()
# Create pipeline model. No device is passed. Earlier experiments with an
# Apple "mps" device and with the "haiquanua/weed_swin" model/processor were
# removed; see version-control history if they are needed again.
pipe = _load_pipeline()
# Display names for the raw LABEL_n ids the detector emits.
_DEFAULT_CLASS_MAP = {"LABEL_0": "Weed", "LABEL_1": "lettuce", "LABEL_2": "Spinach"}
# Box outline colour per raw label; any other label falls back to yellow.
_LABEL_COLORS = {"LABEL_0": (255, 0, 0), "LABEL_1": (0, 255, 0)}


# Professor Haiquan Li function draw_boxes from haiquanua/BAT102
def draw_boxes(im: Image.Image, preds, threshold: float = 0.25,
               class_map=None) -> Image.Image:
    """Draw bounding boxes + labels on an RGB copy of a PIL image.

    Args:
        im: Source image. ``convert("RGB")`` returns a new image, so the
            caller's image is left untouched.
        preds: Iterable of detection dicts as returned by the transformers
            object-detection pipeline, each with ``"label"``, ``"score"`` and
            ``"box"`` (``xmin``/``ymin``/``xmax``/``ymax``) keys.
        threshold: Detections scoring below this value are not drawn.
        class_map: Optional mapping from raw label to display name. Defaults
            to ``_DEFAULT_CLASS_MAP``; labels not in the map are shown as-is.

    Returns:
        The annotated RGB image.
    """
    if class_map is None:
        class_map = _DEFAULT_CLASS_MAP
    im = im.convert("RGB")
    draw = ImageDraw.Draw(im)
    try:
        # A small default bitmap font (portable in Spaces)
        font = ImageFont.load_default()
    except Exception:
        font = None
    # The bitmap ImageFont.ImageFont (what load_default() returns on older
    # Pillow / without FreeType) has no ``size`` attribute, so fall back to the
    # same 14px label height used when no font could be loaded.
    text_h = 14 if font is None else getattr(font, "size", 8) + 6

    for p in preds:
        score = p.get("score", 0)
        if score < threshold:
            continue
        box = p["box"]  # {'xmin','ymin','xmax','ymax'}
        raw_label = p["label"]
        label = f"{class_map.get(raw_label, raw_label)} {score:.2f}"
        col = _LABEL_COLORS.get(raw_label, "yellow")
        # Box outline in the per-class colour.
        draw.rectangle([(box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])],
                       outline=col, width=3)
        # Label on a black background just above the box, clamped to the top edge.
        text_w = draw.textlength(label, font=font)
        x0, y0 = box["xmin"], max(0, box["ymin"] - text_h - 2)
        draw.rectangle([x0, y0, x0 + text_w + 6, y0 + text_h + 2], fill=(0, 0, 0))
        draw.text((x0 + 3, y0 + 2), label, fill=(255, 255, 255), font=font)
    return im
# Maximum number of samples shown from each dataset's train split.
_NUM_SAMPLES = 20


def _show_detections(title, dataset, num_samples=_NUM_SAMPLES):
    """Run the detector on the first images of ``dataset["train"]`` and show them.

    Renders into whichever Streamlit container is currently active (the caller
    wraps each call in a ``with column:`` block).
    """
    st.subheader(title)
    train = dataset["train"]
    # Clamp to the split size so a split with fewer rows does not IndexError.
    for i in range(min(num_samples, len(train))):
        im = train[i]["image"]
        preds = pipe(im)
        st.image(draw_boxes(im, preds))


# Set title
st.title("Weed Detector")
# Two side-by-side columns: original dataset vs. augmented dataset.
col1, col2 = st.columns(2)
with col1:
    _show_detections("Initial Dataset", dataset1)
with col2:
    _show_detections("Augmented Dataset", dataset2)