File size: 3,576 Bytes
5b680d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import platform, gradio as gr

import io
from typing import List, Tuple
from PIL import Image, ImageDraw, ImageFont
from transformers import pipeline

# Object-detection pipeline, created once when this module is imported.
# Any other object-detection checkpoint can be substituted here,
# e.g. "facebook/detr-resnet-50".
detector = pipeline(
    task="object-detection",
    model="haiquanua/weed_detr",
)

def draw_boxes(im: Image.Image, preds, threshold: float = 0.25, class_map=None) -> Image.Image:
    """Draw bounding boxes and score labels on a PIL image.

    Args:
        im: Source image. Drawing happens on the result of
            ``im.convert("RGB")``, which is also what gets returned.
        preds: Iterable of prediction dicts, each holding a ``"label"``
            string, a ``"score"`` number and a ``"box"`` dict with the keys
            ``xmin``, ``ymin``, ``xmax`` and ``ymax``.
        threshold: Predictions whose score is below this value are skipped.
        class_map: Optional mapping from raw model labels to display names.
            Defaults to ``{"LABEL_0": "Lettuce", "LABEL_1": "Weed"}``; labels
            missing from the mapping are displayed as ``"others"``.

    Returns:
        The RGB image with one outlined rectangle and one caption per
        prediction that passed the threshold.
    """
    # Build the default per call instead of sharing one dict across calls.
    if class_map is None:
        class_map = {"LABEL_0": "Lettuce", "LABEL_1": "Weed"}
    # Outline colour per raw label; unknown labels are drawn in yellow.
    palette = {
        "LABEL_0": (0, 255, 0),  # green
        "LABEL_1": (255, 0, 0),  # red
    }

    im = im.convert("RGB")
    draw = ImageDraw.Draw(im)
    try:
        # A small default font (portable in Spaces)
        font = ImageFont.load_default()
    except Exception:
        # Deliberate best-effort fallback: keep going without an explicit font.
        font = None

    # Not every font object has a ``size`` attribute (the legacy bitmap
    # default font does not), so read it defensively and fall back to a
    # fixed caption height.
    font_size = getattr(font, "size", None)
    text_h = 14 if font_size is None else font_size + 6

    for p in preds:
        score = p.get("score", 0)
        if score < threshold:
            continue
        box = p["box"]  # {'xmin','ymin','xmax','ymax'}
        raw_label = p["label"]
        caption = f"{class_map.get(raw_label, 'others')} {score:.2f}"
        outline = palette.get(raw_label, "yellow")

        # Bounding box in the class colour.
        corners = [(box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])]
        draw.rectangle(corners, outline=outline, width=3)

        # Caption on a black strip just above the box, clamped to the top edge.
        text_w = draw.textlength(caption, font=font)
        x0 = box["xmin"]
        y0 = max(0, box["ymin"] - text_h - 2)
        draw.rectangle([x0, y0, x0 + text_w + 6, y0 + text_h + 2], fill=(0, 0, 0))
        draw.text((x0 + 3, y0 + 2), caption, fill=(255, 255, 255), font=font)

    return im

def detect_multiple(images: List[Image.Image], threshold: float = 0.25) -> List[Image.Image]:
    """Run the detector and return annotated copies of the input image(s).

    Args:
        images: Either a list of PIL images or a single PIL image. ``None``
            (e.g. nothing was uploaded) and an empty list are accepted and
            produce an empty result without calling the model.
        threshold: Minimum score a prediction needs in order to be drawn;
            forwarded to ``draw_boxes``.

    Returns:
        A list of annotated images, one per input image. The inputs are
        copied before being handed to ``draw_boxes``.
    """
    # Guard first so the model is never invoked with missing input.
    if images is None:
        return []

    if isinstance(images, list):
        if not images:
            return []
        # The pipeline accepts a list and yields one prediction list per image.
        results = detector(images)
        return [
            draw_boxes(img.copy(), preds, threshold)
            for img, preds in zip(images, results)
        ]

    # Single image: the pipeline output is that image's prediction list.
    results = detector(images)
    return [draw_boxes(images.copy(), results, threshold)]

# Gradio UI: an image input and a results gallery side by side, then a
# threshold slider and a button that runs detect_multiple.
with gr.Blocks(title="Multi-Image Object Detection") as demo:
    gr.Markdown(
        "# Multi-Image Object Detection\nUpload several images; I’ll draw boxes and labels for each."
    )

    with gr.Row():
        # Input is handed to the callback as a PIL image (type="pil").
        img_in = gr.Image(type="pil", label="Upload images")
        gallery = gr.Gallery(label="Detections", columns=3, show_label=True)

    thr = gr.Slider(
        minimum=0.0,
        maximum=1.0,
        value=0.25,
        step=0.01,
        label="Confidence threshold",
    )
    btn = gr.Button("Run Detection", variant="primary")
    # Clicking the button sends (image, threshold) to detect_multiple and
    # shows whatever it returns in the gallery.
    btn.click(fn=detect_multiple, inputs=[img_in, thr], outputs=gallery)

    gr.Markdown("Tip: You can drag-select multiple files in the picker or paste from clipboard.")

# Enable the request queue (at most 16 pending) and start the app.
demo.queue(max_size=16).launch()