File size: 3,947 Bytes
49a8512
 
 
38b2eeb
49a8512
 
6f27afb
05dba08
38b2eeb
5693df1
38b2eeb
6f27afb
49a8512
38b2eeb
 
 
 
 
 
 
 
 
 
49a8512
 
5693df1
 
 
 
 
 
0444951
5693df1
 
 
 
 
 
0444951
5693df1
0444951
5693df1
38b2eeb
 
 
 
 
 
 
5693df1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38b2eeb
 
49a8512
5693df1
38b2eeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49a8512
5693df1
49a8512
5693df1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Cesar Garcia

# Imports
from transformers import pipeline
import streamlit as st
from datasets import load_dataset, Image
import torch
# from torch.utils.tensorboard.summary import draw_boxes
from transformers import (AutoConfig, AutoModelForObjectDetection, AutoImageProcessor, pipeline)
from PIL import Image, ImageDraw, ImageFont
import pkgutil

# Streamlit re-executes this whole script on every user interaction. Loading is
# cached once per server process so that reruns do not force a fresh download of
# both datasets and the model weights each time.
@st.cache_resource
def _load_datasets():
    """Download the initial and the augmented spinach datasets from the Hugging Face Hub.

    Returns:
        A ``(initial, augmented)`` tuple of ``DatasetDict`` objects; images are
        read later as ``dataset["train"][i]["image"]``.
    """
    initial = load_dataset("gcesar/spinach", download_mode="force_redownload")
    augmented = load_dataset("gcesar/spinach_augment", download_mode="force_redownload")
    return initial, augmented


@st.cache_resource
def _load_pipeline():
    """Create the object-detection pipeline for the ``haiquanua/weed_detr`` model.

    Notes kept from earlier experiments:
    - To run on an Apple-silicon GPU, check ``torch.backends.mps.is_built()`` and
      pass ``device=torch.device("mps")`` to ``pipeline``.
    - An alternative model is ``haiquanua/weed_swin``, built from
      ``AutoModelForObjectDetection`` and ``AutoImageProcessor`` (both loaded with
      ``trust_remote_code=True``) and passed to ``pipeline`` as
      ``model=`` / ``image_processor=``.
    """
    return pipeline(task="object-detection", model="haiquanua/weed_detr",
                    trust_remote_code=True, force_download=True)


# Load datasets from Hugging Face
dataset1, dataset2 = _load_datasets()
# Create pipeline model
pipe = _load_pipeline()


# Display name for each raw model label (default for draw_boxes' class_map).
_DEFAULT_CLASS_MAP = {"LABEL_0": "Weed", "LABEL_1": "lettuce", "LABEL_2": "Spinach"}
# Box outline colour per raw model label; any other label is drawn in yellow.
_CLASS_COLORS = {"LABEL_0": (255, 0, 0), "LABEL_1": (0, 255, 0)}


# Professor Haiquan Li function draw_boxes from haiquanua/BAT102
def draw_boxes(im: Image.Image, preds, threshold: float = 0.25,
               class_map=None) -> Image.Image:
    """Draw bounding boxes and score labels on an RGB copy of a PIL image.

    Args:
        im: Source image. It is not modified; ``convert("RGB")`` produces the
            image that is drawn on and returned.
        preds: Iterable of detection dicts with ``"label"``, ``"score"`` and
            ``"box"`` (``{"xmin", "ymin", "xmax", "ymax"}``), as returned by the
            object-detection pipeline.
        threshold: Detections whose score is below this value are skipped.
        class_map: Maps raw model labels (e.g. ``"LABEL_0"``) to display names.
            Defaults to Weed / lettuce / Spinach; unmapped labels are shown as-is.

    Returns:
        The RGB image with the boxes and labels drawn.
    """
    if class_map is None:
        class_map = _DEFAULT_CLASS_MAP
    im = im.convert("RGB")
    draw = ImageDraw.Draw(im)
    try:
        # A small default font (portable in Spaces)
        font = ImageFont.load_default()
    except Exception:
        # Best effort: with font=None Pillow falls back to its built-in font.
        font = None

    for p in preds:
        score = p.get("score", 0)
        if score < threshold:
            continue
        box = p["box"]  # {'xmin','ymin','xmax','ymax'}
        raw_label = p["label"]
        label = f"{class_map.get(raw_label, raw_label)} {score:.2f}"
        color = _CLASS_COLORS.get(raw_label, "yellow")

        # Bounding box, outlined in the colour of its class.
        draw.rectangle([(box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])],
                       outline=color, width=3)

        # Size the label tag from the rendered text. This works for both bitmap
        # and TrueType default fonts (bitmap fonts have no ``.size`` attribute).
        _, _, text_right, text_bottom = draw.textbbox((0, 0), label, font=font)
        tag_w, tag_h = text_right + 6, text_bottom + 4
        x0 = box["xmin"]
        y0 = max(0, box["ymin"] - tag_h)  # just above the box, clamped to the image top
        draw.rectangle([x0, y0, x0 + tag_w, y0 + tag_h], fill=(0, 0, 0))
        draw.text((x0 + 3, y0 + 2), label, fill=(255, 255, 255), font=font)

    return im


def _show_predictions(column, title, dataset, n_images=20):
    """Run the detector on the first images of ``dataset["train"]`` and render them in ``column``.

    Args:
        column: Streamlit container (e.g. one of the ``st.columns`` results).
        title: Subheader shown above the images.
        dataset: A ``DatasetDict`` with a ``"train"`` split whose rows have an ``"image"``.
        n_images: Maximum number of images to show. Fewer are shown if the split
            is shorter, instead of raising IndexError.
    """
    with column:
        st.subheader(title)
        split = dataset["train"]
        for i in range(min(n_images, len(split))):
            im = split[i]["image"]
            preds = pipe(im)
            # st.image renders PIL images on all Streamlit versions; st.write
            # may fall back to printing the object's repr.
            st.image(draw_boxes(im, preds))


# Set title
st.title("Weed Detector")
# Side-by-side comparison: initial dataset vs. augmented dataset.
col1, col2 = st.columns(2)
_show_predictions(col1, "Initial Dataset", dataset1)
_show_predictions(col2, "Augmented Dataset", dataset2)