# LLMDet Gradio demo — compares tiny/base/large zero-shot object detectors.
# Author: Darius Morawiec — commit f0f166f ("Add initial space")
import gradio as gr
import PIL.Image
import torch
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
# Inference device. GPU autodetect is intentionally disabled and kept as a
# comment for easy re-enabling; this deployment forces CPU.
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = "cpu"
class Detector:
    """Zero-shot object detector wrapping a HuggingFace grounding model.

    Loads the processor and model for ``model_id`` once at construction and
    keeps them on the module-level ``DEVICE``.
    """

    def __init__(self, model_id: str):
        self.device = DEVICE
        self.processor = AutoProcessor.from_pretrained(model_id)
        model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
        self.model = model.to(self.device)

    def detect(
        self,
        image: PIL.Image.Image,
        text_labels: list[str],
        threshold: float = 0.4,
    ):
        """Detect ``text_labels`` in ``image``.

        Args:
            image: input PIL image.
            text_labels: open-vocabulary label strings to search for.
            threshold: minimum confidence for a detection to be kept.

        Returns:
            A list of dicts with keys ``label``, ``confidence`` (rounded to
            3 decimals) and ``box`` ([x0, y0, x1, y1], rounded to 2 decimals).
        """
        batch = self.processor(
            images=image, text=[text_labels], return_tensors="pt"
        ).to(self.device)

        # Inference only — no gradients needed.
        with torch.no_grad():
            raw_outputs = self.model(**batch)

        # Single image in the batch, so take the first (only) result.
        processed = self.processor.post_process_grounded_object_detection(
            raw_outputs,
            threshold=threshold,
            target_sizes=[(image.height, image.width)],
        )[0]

        return [
            {
                "label": label,
                "confidence": round(score.item(), 3),
                "box": [round(coord, 2) for coord in box.tolist()],
            }
            for box, score, label in zip(
                processed["boxes"], processed["scores"], processed["labels"]
            )
        ]
# One eagerly-loaded detector per model size; keys are the size names used
# by the UI callback.
models = {
    size: Detector(f"iSEE-Laboratory/llmdet_{size}")
    for size in ("tiny", "base", "large")
}
def _postprocess(detections):
annotations = []
for detection in detections:
box = detection["box"]
mask = (int(box[0]), int(box[1]), int(box[2]), int(box[3]))
label = f"{detection['label']} ({detection['confidence']:.2f})"
annotations.append((mask, label))
return annotations
def detect_objects(image, labels, confidence_threshold):
    """Run all three model sizes on *image* and return one payload per size.

    Args:
        image: PIL image from the Gradio image input.
        labels: comma-separated label string from the textbox.
        confidence_threshold: minimum score for a detection to be kept.

    Returns:
        A 3-tuple of ``(image, annotations)`` pairs — one per model size,
        in (tiny, base, large) order — suitable for gr.AnnotatedImage.
    """
    # Skip blank entries so inputs like "shirt,, jeans," don't send
    # empty-string labels to the models.
    text_labels = [part.strip() for part in labels.split(",") if part.strip()]

    # The three model sizes are handled identically; loop instead of
    # duplicating the call three times.
    return tuple(
        (
            image,
            _postprocess(
                models[size].detect(
                    image,
                    text_labels,
                    threshold=confidence_threshold,
                )
            ),
        )
        for size in ("tiny", "base", "large")
    )
with gr.Blocks() as demo:
    gr.Markdown("# LLMDet Open Vocabulary Object Detection")

    # Minimum score a detection must reach to be displayed.
    threshold_slider = gr.Slider(
        0,
        1,
        value=0.4,
        step=0.01,
        interactive=True,
        label="Confidence threshold",
    )

    # Default fashion/accessory vocabulary pre-filled into the textbox.
    default_labels = (
        "backpack bag belt blouse boot bracelet cap cardigan coat dress "
        "earring flipflop glasses glove handbag hat heels jacket jeans "
        "loafer necklace pullover raincoat ring sandal scarf shirt shoe "
        "shorts skirt slippers sneaker socks suitcase sunglasses sweater "
        "tshirt tie top trouser umbrella vest watch"
    ).split()

    # Requested labels
    label_box = gr.Textbox(
        label="Object labels (comma separated)!",
        placeholder="shirt, jeans, shoe",
        lines=1,
        value=",".join(default_labels),
    )

    with gr.Row():
        input_image = gr.Image(type="pil", image_mode="RGB")

    # One output panel per model size, side by side.
    with gr.Row():
        annotated_tiny = gr.AnnotatedImage(label="TINY")
        annotated_base = gr.AnnotatedImage(label="BASE")
        annotated_large = gr.AnnotatedImage(label="LARGE")

    run_button = gr.Button("Detect")

    # Connect the button to the detection function
    run_button.click(
        fn=detect_objects,
        inputs=[input_image, label_box, threshold_slider],
        outputs=[annotated_tiny, annotated_base, annotated_large],
    )

if __name__ == "__main__":
    demo.launch()