Spaces:
Sleeping
Sleeping
| import json | |
| import random | |
| import spaces | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import onnxruntime | |
| import torch | |
| import torchvision.transforms.functional as F | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image, ImageColor | |
| from torchvision.io import read_image | |
| from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights | |
| from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks | |
# Load the preprocessing transforms bundled with the default pre-trained Mask R-CNN weights.
# `transforms` is applied to every input image in `inference` before it is fed to the ONNX model.
weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()
def fix_category_id(cat_ids: list) -> list:
    """
    Map category IDs from the reduced label space back to the original label space.

    The reduced label space is the 27 original categories (0..26) with the categories
    {2, 12, 16, 19, 20} removed and the remaining ones renumbered consecutively. Both the
    incoming and the returned IDs are 1-based. This is applied to the labels predicted by
    the "facere+" model in `draw_predictions`.

    Args:
        - cat_ids: list - 1-based category IDs in the reduced label space (1..22).
    Returns:
        - A list with the corresponding 1-based category IDs in the original label space.
    Raises:
        - KeyError: if an ID is outside the reduced label space.
    """
    # Define the excluded category ids and the remaining ones
    excluded_indices = {2, 12, 16, 19, 20}
    # Sort explicitly: the mapping below is positional, and set iteration order is not
    # guaranteed by the language.
    remaining_categories = sorted(set(range(27)) - excluded_indices)
    # Create a dictionary that maps new IDs to old(original) IDs (both 0-based here)
    new_id_to_org_id = dict(enumerate(remaining_categories))
    # Shift to 0-based for the lookup, then back to 1-based for the result
    return [new_id_to_org_id[i - 1] + 1 for i in cat_ids]
def process_categories() -> tuple:
    """
    Load and process category information from the `categories.json` file.

    The file is read from the current working directory and is expected to contain a list of
    objects that each have an "id" and a "name" key.

    Returns:
        tuple: A tuple containing two dictionaries:
            - `category_id_to_name`: a dictionary mapping category IDs to their names.
            - `category_id_to_color`: a dictionary mapping category IDs to an RGB color. Colors are
              sampled from the PIL colormap with a fixed seed, so they are the same on every call.
    """
    # Load raw categories from JSON file. JSON is UTF-8 by specification, so do not rely on the
    # platform's default encoding.
    with open("categories.json", encoding="utf-8") as fp:
        categories = json.load(fp)

    # Map category IDs to names
    category_id_to_name = {d["id"]: d["name"] for d in categories}

    # Use a private, seeded generator instead of `random.seed(42)`: it yields the same sample but
    # does not reset the process-wide random state every time this function is called.
    rng = random.Random(42)

    # Get a list of all the color names in the PIL colormap
    color_names = list(ImageColor.colormap.keys())
    # Sample 46 unique colors from the list of color names
    sampled_colors = rng.sample(color_names, 46)
    # Convert the color names to RGB values
    rgb_colors = [ImageColor.getrgb(color_name) for color_name in sampled_colors]

    # Map category IDs to colors. `zip` stops at the shorter sequence, so only the first 46
    # categories receive a color.
    category_id_to_color = {
        category["id"]: color for category, color in zip(categories, rgb_colors)
    }
    return category_id_to_name, category_id_to_color
def draw_predictions(
    boxes, labels, scores, masks, img, model_name, score_threshold, proba_threshold
):
    """
    Draw predictions on the input image based on the provided boxes, labels, scores, and masks. Only predictions
    with scores above the `score_threshold` will be included, and masks with probabilities exceeding the
    `proba_threshold` will be displayed.

    Args:
        - boxes: numpy.ndarray - an array of bounding box coordinates, one row per prediction.
        - labels: numpy.ndarray - an array of integers representing the predicted class for each bounding box.
        - scores: numpy.ndarray - an array of confidence scores for each bounding box.
        - masks: numpy.ndarray - an array of per-pixel mask probabilities for each bounding box, with a
            singleton axis 1 (it is removed with `squeeze(1)` before drawing).
        - img: torch.Tensor - the input image tensor as loaded by `read_image` in `inference`.
        - model_name: str - name of the model given by the dropdown menu, either "facere" or "facere+".
        - score_threshold: float - a confidence score threshold for filtering out low-scoring bbox predictions.
        - proba_threshold: float - a threshold for filtering out low-probability (pixel-wise) mask predictions.
    Returns:
        - A list of strings, each representing the path to an image file containing the input image with a different
        set of predictions drawn (masks, bounding boxes, masks with bounding box labels and scores). The files are
        written to the current working directory under fixed names and are overwritten on every call.
    """
    imgs_list = []

    # Map label IDs to names and colors
    label_id_to_name, label_id_to_color = process_categories()

    # Filter out predictions using thresholds. The selection is computed once and applied to every
    # per-prediction array (including `scores`) so that they stay aligned with each other.
    keep = scores > score_threshold
    labels_id = labels[keep].tolist()
    if model_name == "facere+":
        labels_id = fix_category_id(labels_id)
    # models output is in range: [1,class_id+1], hence re-map to: [0,class_id]
    labels = [label_id_to_name[int(i) - 1] for i in labels_id]
    masks = (masks[keep] > proba_threshold).astype(np.uint8)
    boxes = boxes[keep]
    scores = scores[keep]

    # Draw masks to input image and save
    img_masks = draw_segmentation_masks(
        image=img,
        masks=torch.from_numpy(masks.squeeze(1).astype(bool)),
        alpha=0.9,
        colors=[label_id_to_color[int(i) - 1] for i in labels_id],
    )
    img_masks = F.to_pil_image(img_masks)
    img_masks.save("img_masks.png")
    imgs_list.append("img_masks.png")

    # Draw bboxes to input image and save
    img_bbox = draw_bounding_boxes(img, boxes=torch.from_numpy(boxes), width=4)
    img_bbox = F.to_pil_image(img_bbox)
    img_bbox.save("img_bbox.png")
    imgs_list.append("img_bbox.png")

    # Save masks with their bbox labels & bbox scores
    for col, (mask, label, score) in enumerate(zip(masks, labels, scores)):
        mask = Image.fromarray(mask.squeeze())
        plt.imshow(mask)
        plt.axis("off")
        plt.title(f"{label}: {score:.2f}", fontsize=9)
        plt.savefig(f"mask-{col}.png")
        plt.close()
        imgs_list.append(f"mask-{col}.png")

    return imgs_list
# ONNX Runtime sessions that have already been created, keyed by ONNX file name. Creating a session
# is expensive, so each model is loaded once and reused across requests.
_ORT_SESSIONS = {}


def _get_ort_session(model_name):
    """Return the ONNX Runtime session for `model_name`, creating and caching it on first use."""
    filename = "facere_plus.onnx" if model_name == "facere+" else "facere_base.onnx"
    session = _ORT_SESSIONS.get(filename)
    if session is None:
        # Download model (served from the local Hugging Face cache after the first download)
        path_onnx = hf_hub_download(
            repo_id="rizavelioglu/fashionfail",
            filename=filename,
        )
        # Session options (see https://github.com/microsoft/onnxruntime/issues/14694#issuecomment-1598429295)
        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
        # Create an inference session.
        session = onnxruntime.InferenceSession(
            path_onnx,
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
            sess_options=sess_options,
        )
        _ORT_SESSIONS[filename] = session
    return session


def inference(image, model_name, mask_threshold, bbox_threshold):
    """
    Run the selected ONNX model on the provided input `image`. Visualize the predictions and save them as
    image files, which will be shown in the Gradio app.

    Args:
        - image: str - path to the uploaded image file, or None when nothing was uploaded.
        - model_name: str - name of the model given by the dropdown menu, either "facere" or "facere+".
        - mask_threshold: float - a threshold for filtering out low-probability (pixel-wise) mask predictions.
        - bbox_threshold: float - a confidence score threshold for filtering out low-scoring bbox predictions.
    Returns:
        - The list of image file paths produced by `draw_predictions`.
    Raises:
        - gr.Error: if no image was provided.
    """
    from torchvision.io import ImageReadMode

    if image is None:
        raise gr.Error("Please upload an image first.")

    # Load image. Force 3-channel RGB so that grayscale images and images with an alpha channel
    # are accepted as well.
    img = read_image(image, mode=ImageReadMode.RGB)
    # Apply original transformation to the image.
    img_transformed = transforms(img)

    ort_session = _get_ort_session(model_name)

    # compute ONNX Runtime output prediction
    ort_inputs = {
        ort_session.get_inputs()[0].name: img_transformed.unsqueeze(dim=0).numpy()
    }
    ort_outs = ort_session.run(None, ort_inputs)
    boxes, labels, scores, masks = ort_outs

    imgs_list = draw_predictions(
        boxes,
        labels,
        scores,
        masks,
        img,
        model_name,
        score_threshold=bbox_threshold,
        proba_threshold=mask_threshold,
    )
    return imgs_list
# Static texts shown in the Gradio interface: the page title, the description and the article
# section (both passed to `gr.Interface` below).
title = "Facere - Demo"
# Short introduction with links to the paper and the project page.
description = r"""This is the demo of the paper <a href="https://arxiv.org/abs/2404.08582">FashionFail: Addressing
Failure Cases in Fashion Object Detection and Segmentation</a>. <br>Upload your image and choose the model for inference
from the dropdown menu—either `Facere` or `Facere+` <br> Check out the <a
href="https://rizavelioglu.github.io/fashionfail/">project page</a> for more information."""
# Note on the example images plus the BibTeX citation entry.
article = r"""
Example images are sampled from the `Fashionpedia-test` and `FashionFail-test` set, which the models did not see during training.
<br>**Citation** <br>If you find our work useful in your research, please consider giving a star ⭐ and
a citation:
```
@inproceedings{velioglu2024fashionfail,
author = {Velioglu, Riza and Chan, Robin and Hammer, Barbara},
title = {FashionFail: Addressing Failure Cases in Fashion Object Detection and Segmentation},
journal = {IJCNN},
eprint = {2404.08582},
year = {2024},
}
```
"""
# Gradio interface definition. The order of `inputs` matches the positional parameters of
# `inference`: image path, model name, mask threshold, bbox threshold.
# NOTE(review): no `examples` are passed here, so `cache_examples` presumably has no effect — confirm.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type="filepath", label="input"),
        gr.Dropdown(["facere", "facere+"], value="facere", label="Models"),
        gr.Slider(
            value=0.5,
            minimum=0.0,
            maximum=0.9,
            step=0.05,
            label="Mask threshold",
            info="a threshold for filtering out low-probability (pixel-wise) mask predictions",
        ),
        gr.Slider(
            value=0.7,
            minimum=0.0,
            maximum=0.9,
            step=0.05,
            label="BBox threshold",
            info="a threshold for filtering out low-scoring bbox predictions",
        ),
    ],
    outputs=gr.Gallery(label="output", preview=True, height=500),
    title=title,
    description=description,
    article=article,
    cache_examples=True,
    examples_per_page=6,
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()