# FashionVeil / app.py
# Source: Hugging Face Space by DatSplit ("Create app.py", commit e0d3805 verified, 3.34 kB)
import json
import random
import spaces
import gradio as gr
import numpy as np
import onnxruntime
import torch
from PIL import Image, ImageColor
from torchvision.utils import draw_bounding_boxes
import rfdetr.datasets.transforms as T
def process_categories() -> tuple:
    """Load category metadata and build id -> name / id -> color lookups.

    Reads ``categories.json`` (a list of ``{"id": ..., "name": ...}`` dicts)
    and assigns each category a distinct RGB color sampled from PIL's named
    colormap. The RNG is seeded so the same category always gets the same
    color across runs.

    Returns:
        tuple: ``(category_id_to_name, category_id_to_color)`` where colors
        are ``(r, g, b)`` tuples.
    """
    with open("categories.json") as fp:
        categories = json.load(fp)

    id_to_name = {entry["id"]: entry["name"] for entry in categories}

    # Deterministic palette: seed, then sample one distinct color per category.
    random.seed(42)
    palette = random.sample(list(ImageColor.colormap.keys()), len(categories))
    id_to_color = {
        entry["id"]: ImageColor.getrgb(color_name)
        for entry, color_name in zip(categories, palette)
    }
    return id_to_name, id_to_color
def draw_predictions(boxes, labels, scores, img, score_threshold=0.5):
    """Overlay thresholded detections on an image and return it for Gradio.

    Args:
        boxes: numpy array of bounding boxes, aligned with ``labels``/``scores``.
        labels: numpy array of integer category ids.
        scores: numpy array of confidence scores.
        img: uint8 CHW image tensor as expected by torchvision's
            ``draw_bounding_boxes``.
        score_threshold: detections scoring at or below this are discarded.

    Returns:
        list: a single-element list containing the annotated image as an HWC
        numpy array (``gr.Gallery`` consumes a list of images).
    """
    id_to_name, id_to_color = process_categories()

    keep = scores > score_threshold
    kept_boxes = boxes[keep]
    kept_labels = labels[keep]
    kept_scores = scores[keep]

    captions = [
        f"{id_to_name[int(label)]}: {conf:.2f}"
        for label, conf in zip(kept_labels, kept_scores)
    ]
    box_colors = [id_to_color[int(label)] for label in kept_labels]

    annotated = draw_bounding_boxes(
        img,
        boxes=torch.from_numpy(kept_boxes),
        labels=captions,
        colors=box_colors,
        width=4,
    )
    # torchvision returns CHW; Gradio wants HWC.
    return [annotated.permute(1, 2, 0).numpy()]
# NOTE(review): the `spaces` package documents a `GPU` decorator; confirm
# `spaces.CPU` exists in the installed version, otherwise this raises
# AttributeError at import time.
@spaces.CPU(duration=20)
def inference(image_path, model_name, bbox_threshold):
    """Run RF-DETR ONNX object detection on an image and draw the results.

    Args:
        image_path: filesystem path to the input image (Gradio "filepath"
            input type).
        model_name: selected model identifier. Currently unused — only one
            model is offered in the UI.
        bbox_threshold: minimum confidence for a detection to be drawn.

    Returns:
        list: annotated image(s) as HWC numpy arrays for ``gr.Gallery``.
    """
    # Preprocessing matching the model's training setup: square resize to
    # 1120x1120, tensor conversion, ImageNet mean/std normalization.
    transforms = T.Compose([
        T.SquareResize([1120]),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image = Image.open(image_path).convert("RGB")
    tensor_img, _ = transforms(image, None)
    tensor_img = tensor_img.unsqueeze(0)  # add batch dim -> (1, C, H, W)

    # NOTE(review): absolute local path will not exist on a deployed Space;
    # consider shipping the model in the repo and using a relative path.
    model_path = "/home/datsplit/FashionVeil/models/rfdetr/onnx-models/rfdetrl_finetuned_fashionveil.onnx"
    sess_options = onnxruntime.SessionOptions()
    sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
    ort_session = onnxruntime.InferenceSession(
        model_path,
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        sess_options=sess_options,
    )

    # Fix: the original fed an undefined `img_transformed`; feed the
    # preprocessed tensor (as numpy) under the session's actual input name.
    ort_inputs = {ort_session.get_inputs()[0].name: tensor_img.cpu().numpy()}
    boxes, labels, scores = ort_session.run(None, ort_inputs)

    # Fix: the original passed an undefined `img`. Draw on the original PIL
    # image, converted HWC -> CHW uint8 as draw_bounding_boxes requires
    # (draw_predictions permutes back to HWC before returning).
    img_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1)
    return draw_predictions(boxes, labels, scores, img_tensor, score_threshold=bbox_threshold)
title = "FashionUnveil - Demo"
description = r"""This is the demo of the research project <a href="https://github.com/DatSplit/FashionVeil">FashionUnveil</a>. Upload your image for inference."""

# Gradio front-end: image + model selector + confidence slider in,
# annotated image gallery out.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type="filepath", label="Input Image"),
        gr.Dropdown(choices=["RF-DETR-L"], value="RF-DETR-L", label="Model"),
        gr.Slider(minimum=0.0, maximum=0.9, value=0.5, step=0.05, label="BBox threshold"),
    ],
    outputs=gr.Gallery(label="Output", preview=True, height=500),
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch()