CG commited on
Commit
5693df1
·
1 Parent(s): 05b8fd1

Preliminary results

Browse files
Files changed (1) hide show
  1. app.py +79 -3
app.py CHANGED
@@ -4,16 +4,92 @@
4
 
5
  import streamlit as st
6
  from datasets import load_dataset, Image
7
- from PIL import Image
 
 
 
8
 
9
  # Load dataset from Hugging Face
10
-
11
  dataset = load_dataset("gcesar/spinach")
12
 
13
  # Call image using datasets[vision]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
 
 
15
 
16
- st.write(dataset["train"][0]["image"])
 
 
 
 
 
 
 
 
17
 
18
 
 
19
 
 
 
 
4
 
5
  import streamlit as st
6
  from datasets import load_dataset, Image
7
+ from torch.utils.tensorboard.summary import draw_boxes
8
+ from transformers import pipeline
9
+ from PIL import Image, ImageDraw, ImageFont
10
+ import torch
11
 
12
  # Load dataset from Hugging Face
 
13
  dataset = load_dataset("gcesar/spinach")
14
 
15
  # Call image using datasets[vision]
16
+ # dataset["train"][0]["image"]
17
+
18
+ # Check for mps
19
+ # torch.backends.mps.is_built()
20
+
21
+ # Assign GPU
22
+ #device = torch.device("mps")
23
+
24
+ # Use GPU
25
+ # .to(device)
26
+ # pipeline(device=device)
27
+
28
+ # Create pipeline model
29
+ pipe = pipeline(task="object-detection", model="haiquanua/weed_detr")
30
+ # Create pipeline model with mps
31
+ # pipe = pipeline(task="object-detection", model="haiquanua/weed_detr", device=device)
32
+
33
+
34
+
35
+ # Professor Haiquan Li function draw_boxes from haiquanua/BAT102
36
+ def draw_boxes(im: Image.Image, preds, threshold: float = 0.25,
37
+ class_map={"LABEL_0": "Weed", "LABEL_1": "lettuce", "LABEL_2": "Spinach"}) -> Image.Image:
38
+ """Draw bounding boxes + labels on a PIL image."""
39
+ im = im.convert("RGB")
40
+ draw = ImageDraw.Draw(im)
41
+ try:
42
+ # A small default bitmap font (portable in Spaces)
43
+ font = ImageFont.load_default()
44
+ except Exception:
45
+ font = None
46
+
47
+ for p in preds:
48
+ if p.get("score", 0) < threshold:
49
+ continue
50
+ box = p["box"] # {'xmin','ymin','xmax','ymax'}
51
+ class_label = class_map.get(p['label'], p['label'])
52
+ label = f"{class_label} {p['score']:.2f}"
53
+ xy = [(box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])]
54
+
55
+ if p['label'] == 'LABEL_0':
56
+ col = (255, 0, 0) # red
57
+ elif p['label'] == 'LABEL_1':
58
+ col = (0, 255, 0) # green
59
+ else:
60
+ col = 'yellow'
61
+
62
+ # rectangle + label background
63
+ draw.rectangle(xy, outline=(255, 0, 0), width=3)
64
+ tw, th = draw.textlength(label, font=font), 14 if font is None else font.size + 6
65
+ x0, y0 = box["xmin"], max(0, box["ymin"] - th - 2)
66
+ draw.rectangle([x0, y0, x0 + tw + 6, y0 + th + 2], fill=(0, 0, 0))
67
+ draw.text((x0 + 3, y0 + 2), label, fill=(255, 255, 255), font=font)
68
+
69
+ counts = {}
70
+ for p in preds:
71
+ if p.get("score", 0) >= threshold:
72
+ counts[p["label"]] = counts.get(p["label"], 0) + 1
73
+ caption = ", ".join(f"{k}: {v}" for k, v in sorted(counts.items())) or "No detections"
74
+ return im
75
+
76
+
77
 
78
+ # Set tittle
79
+ st.title("Weed Detector")
80
 
81
+ # Iterate images
82
+ for i in range(0, 20):
83
+ im = dataset["train"][i]["image"]
84
+ # Predict pipe
85
+ preds = pipe(im)
86
+ # Draw boxes
87
+ img = draw_boxes(im, preds)
88
+ # Display images with streamlit
89
+ st.write(img)
90
 
91
 
92
+ # img = draw_boxes(im, preds)
93
 
94
+ # st.write(img)
95
+ # st.image(img)