# BAT102 / app.py
# (Hugging Face Space page-header residue — "haiquanua's picture / Update app.py /
# dd38055 verified" — commented out so the file parses as Python.)
import io, os, sys
from typing import List, Tuple
from PIL import Image, ImageDraw, ImageFont
from transformers import pipeline
from huggingface_hub import snapshot_download
#import transformers
import pprint
from transformers.pipelines import PIPELINE_REGISTRY
from mmengine.config import Config
from pathlib import Path
from mmdet.registry import MODELS
#from mmengine.runner import load_state_dict
from safetensors.torch import load_file
import torch
#import platform
import gradio as gr
from mmdet.utils import register_all_modules
import supervision as sv
#import mmcv
from mmdet.apis import inference_detector
import numpy as np
from supervision import Detections
from typing import List, Dict, Union, Optional
from transformers import (
AutoConfig, AutoModelForObjectDetection, AutoImageProcessor, pipeline
)
# Detection thresholds.
# NOTE(review): neither constant is referenced in the visible code — the UI
# slider value is what reaches the pipeline; confirm before relying on these.
CONFIDENCE_THRESHOLD = 0.5
NMS_IOU_THRESHOLD = 0.5
#detector = pipeline("object-detection", model="facebook/detr-resnet-50")
#detector = pipeline("object-detection", model="haiquanua/weed_detr")
# Hub repo hosting a custom detection model + image processor (remote code).
repo_path="haiquanua/weed_swin"
model = AutoModelForObjectDetection.from_pretrained(
    repo_path, trust_remote_code=True
)
#print("Model class:", type(model).__name__) # expect: MmdetBridge
ip = AutoImageProcessor.from_pretrained(
    repo_path, trust_remote_code=True
)
#print("Processor class:", type(ip).__name__) # expect: MmdetImageProcessor
#detector = pipeline(task="mmdet-detection", model=repo_path, trust_remote_code=True)
# Wrap the custom model/processor in a standard object-detection pipeline.
detector = pipeline(task="object-detection", model=model, image_processor=ip, trust_remote_code=True)
# Sanity check that detection-head weights actually loaded (non-zero count).
num_head_params = sum(p.numel() for n,p in detector.model.named_parameters() if 'roi_head' in n or 'rpn_head' in n)
print("roi/rpn params after pipeline setup:", num_head_params)
def draw_boxes(im: Image.Image, preds, threshold: float = 0.25, class_map={"LABEL_0":"Weed", "LABEL_1":"lettuce","LABEL_2":"Spinach"}) -> Image.Image:
    """Draw bounding boxes + labels on a PIL image.

    Args:
        im: Input image; converted to RGB before drawing.
        preds: Iterable of pipeline prediction dicts, each with keys
            'score', 'label' and 'box' ({'xmin','ymin','xmax','ymax'}).
        threshold: Predictions scoring below this are skipped.
        class_map: Maps raw pipeline labels to display names. (The default
            dict is read-only here, so sharing it across calls is safe.)

    Returns:
        The annotated RGB image.
    """
    # Per-class outline colors; unmapped labels fall back to yellow.
    colors = {"LABEL_0": (255, 0, 0), "LABEL_1": (0, 255, 0)}
    im = im.convert("RGB")
    draw = ImageDraw.Draw(im)
    try:
        # A small default bitmap font (portable in Spaces)
        font = ImageFont.load_default()
    except Exception:
        font = None
    for p in preds:
        if p.get("score", 0) < threshold:
            continue
        box = p["box"]  # {'xmin','ymin','xmax','ymax'}
        class_label = class_map.get(p['label'], p['label'])
        label = f"{class_label} {p['score']:.2f}"
        xy = [(box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])]
        col = colors.get(p['label'], 'yellow')
        # Bug fix: use the per-class color — the original computed `col` but
        # always drew the rectangle hard-coded red.
        draw.rectangle(xy, outline=col, width=3)
        tw = draw.textlength(label, font=font)
        th = 14 if font is None else font.size + 6
        # Clamp the label background so it stays on-canvas when the box
        # touches the top edge.
        x0, y0 = box["xmin"], max(0, box["ymin"] - th - 2)
        draw.rectangle([x0, y0, x0 + tw + 6, y0 + th + 2], fill=(0, 0, 0))
        draw.text((x0 + 3, y0 + 2), label, fill=(255, 255, 255), font=font)
    # (Removed dead code: a per-label `counts`/`caption` summary was computed
    # here but its result was discarded — only the image is returned.)
    return im
def detect_multiple(images: Union[Image.Image, List[Image.Image]], threshold: float = 0.25) -> List[Image.Image]:
    """Run detection on one image (or a list) and annotate the results.

    Args:
        images: A single PIL image (what gr.Image actually delivers) or a
            list of images; both shapes are handled below.
        threshold: Confidence threshold forwarded to the pipeline and to
            the drawing step.

    Returns:
        A list of annotated image copies, suitable for gr.Gallery.

    Raises:
        gr.Error: If the module-level detector was never initialized.
    """
    if detector is None:
        # Bug fix: gr.Error is an exception class — it must be *raised* to
        # abort and surface in the UI; the original merely constructed it
        # and fell through to call a None detector.
        raise gr.Error("detector is empty")
    results = detector(images, threshold=threshold)  # list (or list of lists) of predictions
    outputs = []
    if not isinstance(images, list):
        # Single image: `results` is one flat list of prediction dicts.
        annotated = draw_boxes(images.copy(), results, threshold)
        outputs.append(annotated)
    else:
        # Batch: `results` pairs one prediction list per input image.
        for img, preds in zip(images, results):
            annotated = draw_boxes(img.copy(), preds, threshold)
            outputs.append(annotated)
    return outputs
# Point every Hugging Face cache at a writable /tmp location (Space containers
# often have a read-only home dir).
# NOTE(review): the model above is downloaded *before* these are set, so they
# only affect later downloads — consider moving this block to the top of the file.
_cache_layout = {
    "HF_HOME": "/tmp/huggingface",
    "HF_DATASETS_CACHE": "/tmp/huggingface/datasets",
    "TRANSFORMERS_CACHE": "/tmp/huggingface/transformers",
}
for _env_var, _cache_dir in _cache_layout.items():
    os.makedirs(_cache_dir, exist_ok=True)
    os.environ[_env_var] = _cache_dir
print("finished environment variables")
# --- Gradio UI: one uploaded image in, gallery of annotated images out -----
with gr.Blocks(title="Multi-Image Object Detection") as demo:
    gr.Markdown("# Multi-Image Object Detection\nUpload several images; I’ll draw boxes and labels for each.")
    with gr.Row():
        #img_in = gr.Image(type="pil", label="Upload images", tool="select", image_mode="RGB", source="upload", elem_id="img_in", interactive=True, multiple=True)
        # NOTE(review): despite the multi-image wording, gr.Image delivers a
        # single PIL image; detect_multiple handles that non-list case.
        img_in = gr.Image(type="pil", label="Upload images") # tool="select", image_mode="RGB", source="upload", elem_id="img_in", interactive=True, multiple=True)
        gallery = gr.Gallery(label="Detections", columns=3, show_label=True) #height=500,
    thr = gr.Slider(0.0, 1.0, value=0.25, step=0.01, label="Confidence threshold")
    btn = gr.Button("Run Detection", variant="primary")
    # Wire the button: (image, threshold) -> list of annotated images.
    btn.click(fn=detect_multiple, inputs=[img_in, thr], outputs=gallery)
    gr.Markdown("Tip: You can drag-select multiple files in the picker or paste from clipboard.")
    # NOTE(review): gr.Info expects a message string and is intended for use
    # inside event handlers; passing detector.__dict__ at build time looks
    # suspect — confirm this doesn't raise during app construction.
    gr.Info(detector.__dict__)
    gr.Info("finished blocks setting")
#image=Image.open(Path(__file__).resolve().parent / "test.jpg")
#print(image.size)
#results = detector(image, padding=True, threshold=0.0)
#print("final results", results)
# Serve the app: queue bounds concurrent requests; 0.0.0.0:7860 is the
# standard bind address/port for a Hugging Face Space container.
demo.queue(max_size=16).launch(server_name="0.0.0.0",server_port=7860, share=False, show_error=True)