DatSplit commited on
Commit
e0d3805
·
verified ·
1 Parent(s): 2010180

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ import spaces
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ import onnxruntime
8
+ import torch
9
+ from PIL import Image, ImageColor
10
+ from torchvision.utils import draw_bounding_boxes
11
+ import rfdetr.datasets.transforms as T
12
+
13
+
14
def process_categories() -> tuple:
    """Load category metadata and assign each category a deterministic color.

    Reads ``categories.json`` from the working directory; it is expected to be
    a list of dicts carrying at least ``id`` and ``name`` keys (COCO-style
    category list — TODO confirm exact schema against the data file).

    Returns:
        tuple: ``(category_id_to_name, category_id_to_color)`` — the first maps
        category id -> name, the second maps category id -> RGB tuple.

    Raises:
        ValueError: if there are more categories than named PIL colors
            (``sample`` cannot draw more items than the population holds).
    """
    with open("categories.json") as fp:
        categories = json.load(fp)

    category_id_to_name = {d["id"]: d["name"] for d in categories}

    # Fix: use a local fixed-seed RNG instead of random.seed(42). A Random(42)
    # instance yields the exact same sample as seeding the module-level RNG,
    # so the colors are unchanged — but we no longer clobber the global
    # `random` state as a hidden side effect on every call.
    rng = random.Random(42)
    color_names = list(ImageColor.colormap.keys())
    sampled_colors = rng.sample(color_names, len(categories))
    rgb_colors = [ImageColor.getrgb(color_name) for color_name in sampled_colors]
    category_id_to_color = {category["id"]: color for category, color in zip(categories, rgb_colors)}

    return category_id_to_name, category_id_to_color
27
+
28
+
29
+
30
def draw_predictions(boxes, labels, scores, img, score_threshold=0.5):
    """Overlay thresholded detections on `img` and return it as a 1-image list.

    Args:
        boxes: (N, 4) array of box coordinates (numpy).
        labels: (N,) array of category ids.
        scores: (N,) array of confidences.
        img: uint8 CHW image tensor, as expected by torchvision's
            ``draw_bounding_boxes`` (the final permute converts back to HWC).
        score_threshold: detections with score <= threshold are dropped.

    Returns:
        list: a single HWC numpy image, the format Gradio's Gallery consumes.
    """
    id_to_name, id_to_color = process_categories()

    # Keep only confident detections.
    keep = scores > score_threshold
    kept_boxes = boxes[keep]
    kept_labels = labels[keep]
    kept_scores = scores[keep]

    captions = [
        f"{id_to_name[int(lbl)]}: {score:.2f}"
        for lbl, score in zip(kept_labels, kept_scores)
    ]
    box_colors = [id_to_color[int(lbl)] for lbl in kept_labels]

    annotated = draw_bounding_boxes(
        img,
        boxes=torch.from_numpy(kept_boxes),
        labels=captions,
        colors=box_colors,
        width=4
    )

    # CHW -> HWC for Gradio, wrapped in a list for the Gallery output.
    return [annotated.permute(1, 2, 0).numpy()]
52
+
53
+
54
@spaces.CPU(duration=20)
def inference(image_path, model_name, bbox_threshold):
    """Run ONNX RF-DETR detection on one image and return annotated output.

    Args:
        image_path: path to the uploaded image (Gradio `filepath` input).
        model_name: dropdown selection; currently unused because a single
            ONNX model is hard-wired below.
        bbox_threshold: minimum confidence for a detection to be drawn.

    Returns:
        list: HWC numpy image(s) for the Gradio Gallery.
    """
    transforms = T.Compose([
        T.SquareResize([1120]),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    image = Image.open(image_path).convert("RGB")
    tensor_img, _ = transforms(image, None)
    tensor_img = tensor_img.unsqueeze(0)  # add batch dimension

    # NOTE(review): absolute local path — this will not resolve on a deployed
    # Space; consider bundling the model with the repo or downloading it.
    model_path = "/home/datsplit/FashionVeil/models/rfdetr/onnx-models/rfdetrl_finetuned_fashionveil.onnx"

    sess_options = onnxruntime.SessionOptions()
    sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
    ort_session = onnxruntime.InferenceSession(
        model_path,
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        sess_options=sess_options
    )

    # Bug fix: the original fed an undefined name (`img_transformed`) to the
    # session (its earlier ort_inputs dict was overwritten); feed the
    # preprocessed tensor converted to numpy instead.
    ort_inputs = {ort_session.get_inputs()[0].name: tensor_img.cpu().numpy()}
    ort_outs = ort_session.run(None, ort_inputs)

    boxes, labels, scores = ort_outs
    # Bug fix: `img` was undefined — draw on the loaded PIL image, converted
    # to the uint8 CHW tensor layout draw_predictions expects (its final
    # permute(1, 2, 0) implies CHW input).
    img_chw = torch.from_numpy(np.array(image)).permute(2, 0, 1)
    return draw_predictions(boxes, labels, scores, img_chw, score_threshold=bbox_threshold)
86
+
87
+
88
+
89
title = "FashionUnveil - Demo"
description = r"""This is the demo of the research project <a href="https://github.com/DatSplit/FashionVeil">FashionUnveil</a>. Upload your image for inference."""

# UI wiring: an image upload, a model selector (single option today), and a
# confidence-threshold slider feeding `inference`; results land in a gallery.
_inputs = [
    gr.Image(type="filepath", label="Input Image"),
    gr.Dropdown(["RF-DETR-L"], value="RF-DETR-L", label="Model"),
    gr.Slider(value=0.5, minimum=0.0, maximum=0.9, step=0.05, label="BBox threshold"),
]

demo = gr.Interface(
    fn=inference,
    inputs=_inputs,
    outputs=gr.Gallery(label="Output", preview=True, height=500),
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch()