| | import io, os, sys |
| | from typing import List, Tuple |
| | from PIL import Image, ImageDraw, ImageFont |
| | from transformers import pipeline |
| | from huggingface_hub import snapshot_download |
| | |
| | import pprint |
| | from transformers.pipelines import PIPELINE_REGISTRY |
| | from mmengine.config import Config |
| | from pathlib import Path |
| | from mmdet.registry import MODELS |
| | |
| | from safetensors.torch import load_file |
| | import torch |
| | |
| | import gradio as gr |
| | from mmdet.utils import register_all_modules |
| | import supervision as sv |
| | |
| | from mmdet.apis import inference_detector |
| | import numpy as np |
| | from supervision import Detections |
| | from typing import List, Dict, Union, Optional |
| | from transformers import ( |
| | AutoConfig, AutoModelForObjectDetection, AutoImageProcessor, pipeline |
| | ) |
| |
|
# Detection post-processing thresholds.
# NOTE(review): neither constant is referenced anywhere in this chunk —
# draw_boxes/detect_multiple take their own `threshold` (default 0.25)
# and no NMS step is visible here. Confirm they are used elsewhere
# (e.g. by the remote model code) or remove them.
CONFIDENCE_THRESHOLD = 0.5
NMS_IOU_THRESHOLD = 0.5
| |
|
| |
|
| | |
| | |
| |
|
# Hugging Face Hub repo id of the custom object-detection model.
# trust_remote_code=True below means the repo's own Python modeling code is
# downloaded and executed — only acceptable because this repo is trusted.
repo_path="haiquanua/weed_swin"

# NOTE(review): the HF cache env vars (HF_HOME / TRANSFORMERS_CACHE) are set
# *after* these downloads, further down the file, so they do not affect where
# this model is cached — consider moving the os.environ setup above this point.
model = AutoModelForObjectDetection.from_pretrained(
    repo_path, trust_remote_code=True
)

# Matching image processor (resize/normalize configuration) from the same repo.
ip = AutoImageProcessor.from_pretrained(
    repo_path, trust_remote_code=True
)

# transformers pipeline wrapping model + processor; consumed by detect_multiple().
detector = pipeline(task="object-detection", model=model, image_processor=ip, trust_remote_code=True)

# Sanity check: count parameters in the detection heads (roi_head / rpn_head)
# to verify the custom head weights actually made it into the pipeline
# (a zero count would indicate the heads were not loaded).
num_head_params = sum(p.numel() for n,p in detector.model.named_parameters() if 'roi_head' in n or 'rpn_head' in n)
print("roi/rpn params after pipeline setup:", num_head_params)
| |
|
| | |
def draw_boxes(
    im: Image.Image,
    preds,
    threshold: float = 0.25,
    class_map: Optional[Dict[str, str]] = None,
) -> Image.Image:
    """Draw bounding boxes + labels on a PIL image.

    Args:
        im: Input image; converted to RGB before drawing.
        preds: Pipeline predictions — dicts with "score", "label" and a
            "box" dict holding xmin/ymin/xmax/ymax pixel coordinates.
        threshold: Predictions scoring below this are skipped.
        class_map: Optional mapping from raw model labels (e.g. "LABEL_0")
            to display names; unknown labels fall back to the raw label.

    Returns:
        The annotated RGB image.
    """
    # Fix: build the default per call instead of a shared mutable default arg.
    if class_map is None:
        class_map = {"LABEL_0": "Weed", "LABEL_1": "lettuce", "LABEL_2": "Spinach"}

    # Per-class box colors; any label beyond the known two gets yellow.
    colors = {"LABEL_0": (255, 0, 0), "LABEL_1": (0, 255, 0)}

    im = im.convert("RGB")
    draw = ImageDraw.Draw(im)
    try:
        font = ImageFont.load_default()
    except Exception:
        font = None

    for p in preds:
        if p.get("score", 0) < threshold:
            continue
        box = p["box"]
        class_label = class_map.get(p["label"], p["label"])
        label = f"{class_label} {p['score']:.2f}"
        xy = [(box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])]
        col = colors.get(p["label"], "yellow")

        # Bug fix: the original computed a per-class color and then drew every
        # box hard-coded red; actually use the per-class color here.
        draw.rectangle(xy, outline=col, width=3)
        tw = draw.textlength(label, font=font)
        th = 14 if font is None else font.size + 6
        # Label background sits just above the box, clamped to the top edge.
        x0, y0 = box["xmin"], max(0, box["ymin"] - th - 2)
        draw.rectangle([x0, y0, x0 + tw + 6, y0 + th + 2], fill=(0, 0, 0))
        draw.text((x0 + 3, y0 + 2), label, fill=(255, 255, 255), font=font)

    # Dead code removed: the original tallied per-label counts into a caption
    # string that was never drawn or returned.
    return im
| |
|
def detect_multiple(
    images: Union[Image.Image, List[Image.Image]],
    threshold: float = 0.25,
) -> List[Image.Image]:
    """Run the module-level `detector` on one image or a list of images.

    Args:
        images: A single PIL image (what gr.Image delivers) or a list of them.
        threshold: Confidence cutoff forwarded to the pipeline and to drawing.

    Returns:
        A list of annotated PIL images suitable for gr.Gallery.
        (Annotation corrected: the original hint promised (image, caption)
        tuples but only images were ever appended.)

    Raises:
        gr.Error: If the module-level detector failed to initialize.
    """
    if detector is None:
        # Bug fix: the original merely constructed gr.Error without raising it,
        # so execution fell through to detector(...) and crashed with an
        # unrelated TypeError instead of surfacing the error in the UI.
        raise gr.Error("detector is empty")

    results = detector(images, threshold=threshold)

    outputs: List[Image.Image] = []
    if isinstance(images, list):
        # Batched call: results holds one prediction list per input image.
        for img, preds in zip(images, results):
            outputs.append(draw_boxes(img.copy(), preds, threshold))
    else:
        # Single image: results is a flat list of prediction dicts.
        outputs.append(draw_boxes(images.copy(), results, threshold))
    return outputs
| |
|
| |
|
# Route all Hugging Face caches to /tmp, which is writable on restricted
# hosts (e.g. HF Spaces containers), creating each directory up front.
# NOTE(review): these variables are set *after* the model/pipeline above were
# already downloaded, so they only affect later downloads — confirm ordering.
_HF_CACHE_DIRS = {
    "HF_HOME": "/tmp/huggingface",
    "HF_DATASETS_CACHE": "/tmp/huggingface/datasets",
    "TRANSFORMERS_CACHE": "/tmp/huggingface/transformers",
}
for _var, _path in _HF_CACHE_DIRS.items():
    os.makedirs(_path, exist_ok=True)
    os.environ[_var] = _path
print("finished environment variables")
| |
|
# Build the Gradio UI. `demo` stays module-level so hosting platforms can find it.
with gr.Blocks(title="Multi-Image Object Detection") as demo:
    gr.Markdown("# Multi-Image Object Detection\nUpload several images; I’ll draw boxes and labels for each.")

    with gr.Row():
        # NOTE(review): gr.Image yields a single PIL image, so despite the
        # multi-image title only one image is processed per click;
        # gr.File(file_count="multiple") or gr.Gallery input would allow batches.
        img_in = gr.Image(type="pil", label="Upload images")
        gallery = gr.Gallery(label="Detections", columns=3, show_label=True)

    thr = gr.Slider(0.0, 1.0, value=0.25, step=0.01, label="Confidence threshold")
    btn = gr.Button("Run Detection", variant="primary")
    # detect_multiple returns a list of annotated PIL images for the gallery.
    btn.click(fn=detect_multiple, inputs=[img_in, thr], outputs=gallery)

    gr.Markdown("Tip: You can drag-select multiple files in the picker or paste from clipboard.")

    # Bug fix: gr.Info expects a string message; the original passed the raw
    # __dict__ dict, which can fail message serialization. (Called here at
    # build time, outside a request, gradio only logs it — debug output only.)
    gr.Info(str(detector.__dict__))
    gr.Info("finished blocks setting")


# Queue bounds concurrent requests; bind all interfaces for container hosting.
demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)
| |
|
| |
|