# Spaces: Sleeping
import json

import cv2
import datasets
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
# Both sample datasets are loaded eagerly at import time via
# datasets.load_dataset; only their "validation" split is used.
sample_dataset1, sample_dataset2 = (
    datasets.load_dataset(repo_id, split="validation")
    for repo_id in ("asgaardlab/SampleDataset", "asgaardlab/SampleDataset2")
)
def overlay_with_transparency(background, overlay, alpha_mask):
    """
    Blend ``overlay`` additively onto ``background``.

    Computes ``background * 1 + overlay * alpha_mask + 0`` per pixel with
    ``cv2.addWeighted`` (results saturate to the dtype's range). The
    background is not attenuated, so wherever ``overlay`` is non-zero the
    result gets brighter; where ``overlay`` is zero the background is
    returned unchanged.

    Args:
    - background: The base image.
    - overlay: Image with the same size and dtype as ``background``.
    - alpha_mask: Scalar weight applied to ``overlay``. Despite the name this
      is a single number, not a per-pixel mask.

    Returns:
    - The blended image (a new array).
    """
    background_weight = 1
    brightness_offset = 0
    return cv2.addWeighted(
        background, background_weight, overlay, alpha_mask, brightness_offset
    )
def generate_overlay_image(buggy_image, objects, segmentation_image_rgb, font_scale=0.5, font_color=(0, 255, 255)):
    """
    Generate an overlaid image using the provided annotations.

    For every object, the pixels of ``segmentation_image_rgb`` whose first
    three channels equal the object's colour are tinted on a copy of
    ``buggy_image``, and the object's ``labelName`` is drawn at the centroid
    of those pixels. Objects whose colour does not occur in the segmentation
    image are skipped (no tint, no label).

    Args:
    - buggy_image: The image to be overlaid (numpy array). It is not
      modified; drawing happens on a copy.
    - objects: Parsed object list. Each entry needs a ``"color"`` sequence
      (its last component is dropped, presumably alpha) and a
      ``"labelName"`` string.
    - segmentation_image_rgb: Segmentation image (numpy array, channels last)
      with the same height and width as ``buggy_image``; only its first three
      channels are compared.
    - font_scale: Scale factor for the font size.
    - font_color: Colour of the label text. cv2 writes these values as-is, so
      they follow ``buggy_image``'s channel order (not necessarily BGR).

    Returns:
    - The overlaid image.
    """
    overlaid_img = buggy_image.copy()
    for obj in objects:
        # Drop the last colour component so it lines up with the 3 channels
        # compared below.
        color = tuple(obj["color"])[:-1]
        mask = np.all(segmentation_image_rgb[:, :, :3] == np.array(color), axis=-1)
        mask_coords = np.argwhere(mask)
        if mask_coords.size == 0:
            # Object not visible in this frame: the centroid of an empty mask
            # is NaN and would hand cv2.putText an invalid position.
            continue

        # Paint only the colour channels so 4-channel (RGBA) images work too;
        # any extra channel stays 0 in the overlay.
        colored_mask = np.zeros_like(overlaid_img)
        colored_mask[mask, :3] = color
        # Overlay the colored mask onto the image with weight 0.3.
        overlaid_img = overlay_with_transparency(overlaid_img, colored_mask, 0.3)

        # argwhere yields (row, col) pairs, i.e. (y, x).
        y_center, x_center = np.mean(mask_coords, axis=0).astype(int)
        cv2.putText(overlaid_img, obj["labelName"], (int(x_center), int(y_center)),
                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_color, 1, cv2.LINE_AA)
    return overlaid_img
def generate_annotations(selected_dataset, image_index):
    """
    Collect everything the UI displays for one sample of the chosen dataset.

    Args:
    - selected_dataset: Dropdown value. ``'Western Scene'`` selects
      ``sample_dataset1``; anything else (including ``None``) selects
      ``sample_dataset2``.
    - image_index: Row index into that dataset (slider value, cast to int).

    Returns:
    - Tuple in the order of the Visualize button's ``outputs``:
      (buggy image, correct image, (correct image, [(mask, label), ...]),
      overlaid image, parsed objects, object count, victim name, bug tag).
    """
    bugs_ds = sample_dataset1 if selected_dataset == 'Western Scene' else sample_dataset2
    # Fetch the row once: indexing a datasets.Dataset builds a dict of all
    # columns (decoding every image column) on each access.
    row = bugs_ds[int(image_index)]

    objects = json.loads(row["Objects JSON (Correct)"])
    segmentation_image_rgb = np.array(row["Segmentation Image (Correct)"])

    annotations = []
    for obj in objects:
        # Drop the last colour component so it lines up with the 3 channels
        # compared below.
        color = tuple(obj["color"])[:-1]
        mask = np.all(segmentation_image_rgb[:, :, :3] == np.array(color), axis=-1).astype(np.float32)
        annotations.append((mask, obj["labelName"]))

    bug_image = row["Buggy Image"]
    correct_image = row["Correct Image"]
    overlaid_image = generate_overlay_image(np.array(bug_image), objects, segmentation_image_rgb)

    return (
        bug_image,
        correct_image,
        (correct_image, annotations),
        overlaid_image,
        objects,
        len(objects),  # number of annotated objects in the JSON
        row["Victim Name"],
        row["Tag"],
    )
def update_slider(selected_dataset):
    """
    Re-range the image-index slider for the newly selected dataset.

    Uses the same dataset choice as ``generate_annotations``: 'Western Scene'
    maps to ``sample_dataset1``, anything else to ``sample_dataset2``.
    The slider value is reset to 0 so it cannot point past the end of a
    smaller dataset after switching.
    """
    dataset = sample_dataset1 if selected_dataset == 'Western Scene' else sample_dataset2
    # max(..., 0) keeps the range valid even for an empty dataset.
    last_index = max(len(dataset) - 1, 0)
    return gr.update(minimum=0, maximum=last_index, step=1, value=0)
# Setting up the Gradio interface using the Blocks API.
with gr.Blocks() as demo:
    gr.Markdown(
        "Choose a dataset, set the image index and click **Visualize** to view the segmentation annotations."
    )
    with gr.Row():
        # Default to the first dataset and size the slider for it up front;
        # update_slider re-ranges the slider whenever the dropdown changes.
        selected_dataset = gr.Dropdown(
            ['Western Scene', 'Viking Village'], value='Western Scene', label="Dataset"
        )
        input_slider = gr.Slider(
            minimum=0,
            maximum=max(len(sample_dataset1) - 1, 0),
            step=1,
            value=0,
            label="Image Index",
        )
    btn = gr.Button("Visualize")
    with gr.Row():
        bug_image = gr.Image(label="Buggy Image")
        correct_image = gr.Image(label="Correct Image")
    with gr.Row():
        seg_img = gr.AnnotatedImage(label="Segmentation Annotations")
        overlaid_img = gr.Image(label="Overlay")
    with gr.Row():
        object_count = gr.Number(label="Object Count")
        victim_name = gr.Textbox(label="Victim Name")
        bug_type = gr.Textbox(label="Bug Type")
    with gr.Row():
        json_data = gr.JSON(label="Objects JSON")

    # Output order must match the tuple returned by generate_annotations.
    btn.click(
        fn=generate_annotations,
        inputs=[selected_dataset, input_slider],
        outputs=[bug_image, correct_image, seg_img, overlaid_img, json_data, object_count, victim_name, bug_type],
    )
    selected_dataset.change(
        fn=update_slider,
        inputs=[selected_dataset],
        outputs=[input_slider],
    )

demo.launch()