# Thiago Hersan — "add model descriptions" (commit 71cdb5c)
import gradio as gr
import numpy as np
from io import BytesIO
from PIL import Image as PImage
from torch import cuda
from sklearn.metrics.pairwise import euclidean_distances, cosine_distances
from transformers import AutoModel, AutoProcessor, pipeline
from utils import draw_results, embed_image, embed_word, idxs_along_axes, idxs_by_dist, make_image
# Pick the GPU when available; passed as the device for the HF pipelines below.
DEVICE = "cuda" if cuda.is_available() else "cpu"

# Alternative contrastive model kept for reference:
# CLIP_MODEL = "google/siglip2-large-patch16-256"
CLIP_MODEL = "openai/clip-vit-large-patch14"  # text/image contrastive embedder
DETR_MODEL = "facebook/detr-resnet-50"        # closed-set (COCO) object detector
OWL_MODEL = "google/owlv2-base-patch16"       # zero-shot / open-vocabulary detector

# Object-detection pipeline with a fixed COCO label set.
detr = pipeline(task="object-detection",
                model=DETR_MODEL,
                device=DEVICE)
# Zero-shot detector: candidate labels are supplied per request.
owl = pipeline(task="zero-shot-object-detection",
               model=OWL_MODEL,
               device=DEVICE)
clip_processor = AutoProcessor.from_pretrained(CLIP_MODEL)
# NOTE(review): device_map="auto" already handles placement; the extra
# .to(DEVICE) can conflict if accelerate shards the model — confirm intent.
clip = AutoModel.from_pretrained(CLIP_MODEL, device_map="auto").to(DEVICE)
def run_detr(img):
    """Run COCO object detection on *img* and return the annotated image."""
    return draw_results(img, detr(img))
def run_owl(img, classes_str):
    """Zero-shot detect the comma-separated labels in *classes_str* on *img*.

    Returns the image annotated with the detections.
    """
    candidate_labels = [label.strip() for label in classes_str.split(",")]
    detections = owl(img, candidate_labels=candidate_labels)
    return draw_results(img, detections)
def run_clip(files, word0, word1=""):
    """Embed uploaded images with CLIP and arrange them by text similarity.

    Args:
        files: uploads from gr.File(file_count="multiple"); items may be
            tempfile-like objects (with a ``.name`` path) or plain path
            strings depending on the Gradio version.
        word0: first text descriptor.
        word1: optional second descriptor; when non-blank, images are laid
            out along the word0→word1 axes instead of a 1-D ordering.

    Returns:
        A composite PIL image produced by make_image().
    """
    w0e = embed_word(word0, clip_processor, clip, DEVICE)
    w1e = embed_word(word1, clip_processor, clip, DEVICE)
    imgs, ies = [], []
    for f in files:
        # gr.File may hand back tempfile wrappers or bare path strings;
        # accept both instead of assuming `.name` exists.
        path = getattr(f, "name", f)
        img = PImage.open(path).convert("RGB")
        # Normalize height to 256px, preserving aspect ratio.
        img = img.resize((int(256 * img.width / img.height), 256))
        imgs.append(img)
        ies.append(embed_image(img, clip_processor, clip, DEVICE))
    # Treat a whitespace-only second descriptor as absent.
    if word1.strip() == "":
        # Single descriptor: order images by embedding distance to word0.
        ordered_idxs = idxs_by_dist(ies, w0e)
    else:
        # Two descriptors: place images along the two word axes.
        ordered_idxs = idxs_along_axes(ies, (w0e, w1e))
    return make_image(imgs, ordered_idxs)
# Example descriptor pairs for the CLIP axis layout.
# NOTE(review): not wired into any gr.Interface below — confirm whether these
# were meant to be passed as `examples=` to the "Contrastive Embedding" tab.
examples = [
    ("painted portrait young person", "painted portrait old person"),
    ("painted portrait happy person", "painted portrait worried person"),
]
# Three demo interfaces assembled inside a single Blocks app.
with gr.Blocks() as demo:
    # Closed-set detection over the fixed COCO label vocabulary.
    gr.Interface(
        title="Object Detection",
        description="[DETR](https://huggingface.co/facebook/detr-resnet-50) model from facebook (2020), trained on [COCO 2017](https://github.com/amikelive/coco-labels/blob/master/coco-labels-2014_2017.txt) dataset and labels.",
        api_name="object",
        fn=run_detr,
        inputs=gr.Image(type="pil"),
        outputs=gr.Image(format="jpeg"),
        flagging_mode="never",
    )
    # Zero-shot detection: the user types comma-separated candidate labels.
    gr.Interface(
        title="Zero-Shot Object Detection",
        # Fixed: the description previously linked owlv2-large-patch14-ensemble,
        # but OWL_MODEL actually loads google/owlv2-base-patch16.
        description="[OWLv2](https://huggingface.co/google/owlv2-base-patch16) model from google (2023).",
        api_name="zero",
        fn=run_owl,
        inputs=[gr.Image(type="pil"), gr.Textbox(label="Object", show_label=True)],
        outputs=gr.Image(format="jpeg"),
        flagging_mode="never",
    )
    # CLIP embedding layout: arrange uploaded images by text similarity.
    gr.Interface(
        title="Contrastive Embedding",
        description="[CLIP](https://huggingface.co/openai/clip-vit-large-patch14) model from openai (2021).",
        api_name="clip",
        fn=run_clip,
        inputs=[
            gr.File(file_count="multiple"),
            gr.Textbox(label="1st Descriptor", show_label=True),
            gr.Textbox(label="2nd Descriptor", show_label=True),
        ],
        outputs=gr.Image(format="jpeg"),
        flagging_mode="never",
    )
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script.
    demo.launch()