File size: 6,892 Bytes
e0d3805
 
 
 
 
 
 
 
 
 
 
cbda6dd
e0d3805
1aa7016
 
cbda6dd
 
e0d3805
cbda6dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0d3805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97d2e19
 
e0d3805
 
 
 
 
 
 
 
 
 
 
 
 
c589de5
3283a6b
e0d3805
875a271
3b89f33
1aa7016
e0d3805
1aa7016
e0d3805
 
 
 
97d2e19
e0d3805
 
 
 
 
 
 
 
 
 
 
11defca
075f8de
53e08b8
 
075f8de
53e08b8
e0d3805
 
 
 
 
 
6b42faf
e0d3805
 
11defca
 
cbda6dd
3f46474
cbda6dd
 
 
 
 
c589de5
cbda6dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c589de5
3f46474
cbda6dd
e0d3805
 
 
 
 
 
 
 
 
 
 
075f8de
e0d3805
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import json
import random
import spaces

import gradio as gr
import numpy as np
import onnxruntime
import torch
from PIL import Image, ImageColor
from torchvision.utils import draw_bounding_boxes
import rfdetr.datasets.transforms as T
from torchvision.ops import box_convert

# Adapted from https://huggingface.co/spaces/rizavelioglu/fashionfail

def _box_yxyx_to_xyxy(boxes: torch.Tensor) -> torch.Tensor:
    """Convert bounding boxes from (y1, x1, y2, x2) format to (x1, y1, x2, y2) format.

    Args:
        boxes (torch.Tensor): A tensor of bounding boxes in the (y1, x1, y2, x2) format.

    Returns:
        torch.Tensor: A tensor of bounding boxes in the (x1, y1, x2, y2) format.
    """
    y1, x1, y2, x2 = boxes.unbind(-1)
    boxes = torch.stack((x1, y1, x2, y2), dim=-1)
    return boxes


def _box_xyxy_to_yxyx(boxes: torch.Tensor) -> torch.Tensor:
    """Convert bounding boxes from (x1, y1, x2, y2) format to (y1, x1, y2, x2) format.

    Args:
        boxes (torch.Tensor): A tensor of bounding boxes in the (x1, y1, x2, y2) format.

    Returns:
        torch.Tensor: A tensor of bounding boxes in the (y1, x1, y2, x2) format.
    """
    x1, y1, x2, y2 = boxes.unbind(-1)
    boxes = torch.stack((y1, x1, y2, x2), dim=-1)
    return boxes

    
# Adapted from: https://github.com/pytorch/vision/blob/main/torchvision/ops/boxes.py#L168
def extended_box_convert(
    boxes: torch.Tensor, in_fmt: str, out_fmt: str
) -> torch.Tensor:
    """
    Converts boxes from given in_fmt to out_fmt.

    Supported in_fmt and out_fmt are:
        - 'xyxy': boxes are represented via corners, x1, y1 being top left and x2, y2 being bottom right. This is the format that torchvision utilities expect.
        - 'xywh' : boxes are represented via corner, width and height, x1, y2 being top left, w, h being width and height.
        - 'cxcywh' : boxes are represented via centre, width and height, cx, cy being center of box, w, h being width and height.
        - 'yxyx': boxes are represented via corners, y1, x1 being top left and y2, x2 being bottom right. This is the format that `amrcnn` model outputs.

    Args:
        boxes (Tensor[N, 4]): boxes which will be converted.
        in_fmt (str): Input format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh', 'yxyx'].
        out_fmt (str): Output format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh', 'yxyx'].

    Returns:
        Tensor[N, 4]: Boxes into converted format.
    """
    # Normalize a 'yxyx' input to 'xyxy' so only torchvision-supported formats remain.
    if in_fmt == "yxyx":
        boxes = _box_yxyx_to_xyxy(boxes)
        in_fmt = "xyxy"

    if out_fmt != "yxyx":
        # Every remaining conversion is natively supported by torchvision.
        return box_convert(boxes, in_fmt=in_fmt, out_fmt=out_fmt)

    # For a 'yxyx' target, route through 'xyxy' first, then swap the pairs.
    if in_fmt != "xyxy":
        boxes = box_convert(boxes, in_fmt=in_fmt, out_fmt="xyxy")
    return _box_xyxy_to_yxyx(boxes)

    
def process_categories() -> tuple:
    """Load category metadata and assign each category id a fixed display color.

    Reads ``categories.json`` from the working directory.

    Returns:
        tuple: (mapping of category id -> name, mapping of category id -> RGB color tuple).
    """
    with open("categories.json") as fp:
        categories = json.load(fp)

    id_to_name = {entry["id"]: entry["name"] for entry in categories}

    # Fixed seed keeps the id -> color assignment stable across calls.
    random.seed(42)
    palette = random.sample(list(ImageColor.colormap.keys()), len(categories))
    id_to_color = {
        entry["id"]: ImageColor.getrgb(color_name)
        for entry, color_name in zip(categories, palette)
    }

    return id_to_name, id_to_color




def draw_predictions(boxes, labels, scores, img, score_threshold=0.5, font_size=20):
    """Draw labeled, score-annotated bounding boxes onto an image tensor.

    Args:
        boxes (torch.Tensor[N, 4]): Boxes in xyxy pixel coordinates.
        labels (torch.Tensor[N]): Category ids, looked up via process_categories().
        scores (torch.Tensor[N]): Confidence score per box.
        img (torch.Tensor[C, H, W]): uint8 image tensor to draw on.
        score_threshold (float): Boxes at or below this score are dropped.
        font_size (int): Font size for the box labels.

    Returns:
        list: A single HWC numpy array with boxes drawn (a list, as gr.Gallery expects).
    """
    label_id_to_name, label_id_to_color = process_categories()

    # Keep only confident detections.
    keep = scores > score_threshold
    boxes_filtered = boxes[keep]
    labels_filtered = labels[keep]
    scores_filtered = scores[keep]

    label_names = [label_id_to_name[int(i)] for i in labels_filtered]
    colors = [label_id_to_color[int(i)] for i in labels_filtered]

    img_bbox = draw_bounding_boxes(
        img,
        boxes=boxes_filtered,
        labels=[f"{name} {score:.2f}" for name, score in zip(label_names, scores_filtered)],
        colors=colors,
        width=5,
        font_size=font_size,  # fix: was hard-coded to 20, silently ignoring the parameter
        font="arial.ttf",  # NOTE(review): assumes arial.ttf is resolvable; may fail on Linux — confirm
    )
    # CHW tensor -> HWC numpy for display.
    return [img_bbox.permute(1, 2, 0).numpy()]



def inference(image_path, model_name, bbox_threshold):
    """Run RF-DETR ONNX detection on an image and return annotated output images.

    Args:
        image_path (str): Path to the input image file.
        model_name (str): One of "RF-DETR-B" or "RF-DETR-L", selecting the ONNX model.
        bbox_threshold (float): Score threshold for keeping detections.

    Returns:
        list: HWC numpy images with drawn predictions, for gr.Gallery.

    Raises:
        ValueError: If model_name is not a supported model.
    """
    transforms = T.Compose([
        T.SquareResize([1120]),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    image = Image.open(image_path).convert("RGB")
    tensor_img, _ = transforms(image, None)
    tensor_img = tensor_img.unsqueeze(0)

    # Fix: unknown model names previously fell through and raised UnboundLocalError
    # at InferenceSession(model_path, ...); fail fast with a clear error instead.
    if model_name == "RF-DETR-B":
        model_path = "rfdetr.onnx"
    elif model_name == "RF-DETR-L":
        model_path = "rfdetrl.onnx"
    else:
        raise ValueError(f"Unsupported model name: {model_name}")

    sess_options = onnxruntime.SessionOptions()
    # NOTE(review): graph optimizations are explicitly disabled — presumably a
    # workaround for an ORT optimization issue with this model; confirm before changing.
    sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
    ort_session = onnxruntime.InferenceSession(
        model_path,
        providers=["CPUExecutionProvider"],
        sess_options=sess_options,
    )

    ort_inputs = {ort_session.get_inputs()[0].name: tensor_img.cpu().numpy()}
    pred_boxes, logits = ort_session.run(['dets', 'labels'], ort_inputs)

    # Per-class sigmoid scores; keep the best class per query box.
    scores = torch.sigmoid(torch.from_numpy(logits))
    max_scores, pred_labels = scores.max(-1)
    mask = (max_scores > bbox_threshold).squeeze(0)

    # Boxes come out normalized (cxcywh in [0, 1]); scale to pixel coordinates.
    pred_boxes = torch.from_numpy(pred_boxes[0])
    image_w, image_h = image.size
    pred_boxes_abs = pred_boxes.clone()
    pred_boxes_abs[:, [0, 2]] *= image_w
    pred_boxes_abs[:, [1, 3]] *= image_h

    filtered_boxes = extended_box_convert(
        pred_boxes_abs[mask], in_fmt="cxcywh", out_fmt="xyxy"
    )
    filtered_scores = max_scores.squeeze(0)[mask]
    filtered_labels = pred_labels.squeeze(0)[mask]

    # Draw on the original-resolution image, not the resized model input.
    img_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1)
    return draw_predictions(
        filtered_boxes, filtered_labels, filtered_scores, img_tensor,
        score_threshold=bbox_threshold,
    )




# UI copy shown at the top of the Gradio page.
title = "FashionUnveil - Demo"
description = r"""This is the demo of the research project <a href="https://github.com/DatSplit/FashionVeil">FashionUnveil</a>. Upload your image for inference."""

# Gradio app: image path + model choice + score threshold in, gallery of annotated images out.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type="filepath", label="Input Image"),
        gr.Dropdown(["RF-DETR-L", "RF-DETR-B"], value="RF-DETR-B", label="Model"),
        gr.Slider(value=0.5, minimum=0.0, maximum=0.9, step=0.05, label="BBox threshold"),
    ],
    outputs=gr.Gallery(label="Output", preview=True, height=500),
    title=title,
    description=description,
)

# Start the server only when executed as a script.
if __name__ == "__main__":
    demo.launch()