| |
| import os |
| import json |
| import numpy as np |
| from pathlib import Path |
| from pprint import pprint |
| from omegaconf import OmegaConf |
| from PIL import Image, ImageDraw |
| import streamlit as st |
| import random |
| |
# Absolute directory of this script (symlinks resolved); the img/ and data/
# paths below are built relative to it via os.environ['ROOT'].
os.environ.update(ROOT=os.path.dirname(os.path.realpath(__file__)))
| |
| |
|
|
def get_list_folder(PATH):
    """Return the names of the immediate sub-directories of ``PATH``.

    Only bare names (not joined paths) are returned, in ``os.listdir`` order.
    """
    subdirs = []
    for entry_name in os.listdir(PATH):
        if os.path.isdir(os.path.join(PATH, entry_name)):
            subdirs.append(entry_name)
    return subdirs
|
|
def get_file_only(PATH):
    """Return the names of the regular files directly inside ``PATH``.

    Sub-directories are skipped; names are returned in ``os.listdir`` order.
    """
    def _is_file(entry_name):
        return os.path.isfile(os.path.join(PATH, entry_name))

    return list(filter(_is_file, os.listdir(PATH)))
|
|
| |
class ImageRetriever:
    """Find the image file for an annotation key and draw its bounding boxes.

    ``root_path`` is the directory searched by ``key2img_path``. ``anno_path``
    is stored but not read anywhere in this class. ``hide_true_bbox`` selects
    how ``hide_region`` renders the boxes (default 2).
    """

    # Class-level default so instances built without __init__ still have a mode.
    hide_true_bbox = 2

    # Fill colors (alpha 0x3c) for the 1st, 2nd and 3rd box; cycled if more
    # boxes are supplied.
    _COLOR_FILLS = ('#ff05cd3c', '#00F1E83c', '#F2D4003c')
    # Modes that draw translucent colored boxes with a green outline on a
    # fully transparent overlay.
    _OUTLINE_MODES = (2, 5, 7, 8, 9)

    def __init__(self, root_path, anno_path, hide_true_bbox=2):
        """
        Args:
            root_path: Directory under which image files are looked up.
            anno_path: Annotation file path (stored; not used by this class).
            hide_true_bbox: Rendering mode used by ``hide_region``.
        """
        self.root_path = Path(root_path)
        self.anno_path = Path(anno_path)
        self.hide_true_bbox = hide_true_bbox

    def key2img_path(self, key):
        """Return the first existing image path for ``key``, or None.

        Candidates are tried in order: ``var_images/`` (.jpg, .png),
        ``images/`` (.jpg), ``img/{train,val,test}/<prefix>/`` (.png, where
        prefix is the text before the first ``_`` in ``key``), ``img/``
        (.png, .jpg), and finally ``root_path`` itself (.png, .jpg).
        """
        prefix = key.split('_')[0]
        candidates = [
            self.root_path / f"var_images/{key}.jpg",
            self.root_path / f"var_images/{key}.png",
            self.root_path / f"images/{key}.jpg",
            self.root_path / f"img/train/{prefix}/{key}.png",
            self.root_path / f"img/val/{prefix}/{key}.png",
            self.root_path / f"img/test/{prefix}/{key}.png",
            self.root_path / f"img/{key}.png",
            self.root_path / f"img/{key}.jpg",
            self.root_path / f"{key}.png",
            self.root_path / f"{key}.jpg",
        ]
        for file_path in candidates:
            if file_path.exists():
                return file_path
        return None

    def key2img(self, key, temp_data, draw_bbox=True):
        """Open the image for ``key``, optionally with its boxes drawn.

        Args:
            key: Image key, resolved to a file via ``key2img_path``.
            temp_data: Annotation record. When ``draw_bbox`` is true, its
                ``'bounding_box'`` mapping is read for keys ``'1'``-``'3'``;
                missing keys become ``None`` and are skipped.
            draw_bbox: If true, return the RGBA rendering from ``hide_region``.

        Returns:
            The PIL image.

        Raises:
            FileNotFoundError: If no candidate file exists for ``key``.
        """
        file_path = self.key2img_path(key)
        if file_path is None:
            # Previously None reached Image.open and failed with an unrelated error.
            raise FileNotFoundError(
                f"No image found for key {key!r} under {self.root_path}"
            )

        image = Image.open(file_path)
        if draw_bbox:
            boxes = temp_data['bounding_box']
            bboxes = [boxes.get(str(box_no), None) for box_no in range(1, 4)]
            image = self.hide_region(image, bboxes)
        return image

    def hide_region(self, image, bboxes):
        """Render ``bboxes`` onto ``image`` according to ``self.hide_true_bbox``.

        Each non-None bbox is a mapping with ``left``, ``top``, ``width`` and
        ``height``. Modes:

        * 1: opaque ``#7B7575`` rectangles drawn directly on the image.
        * 2, 5, 7, 8, 9: translucent per-box color with a 3px green outline,
          drawn on a transparent overlay and alpha-composited.
        * 3: an opaque ``#7B7575`` overlay with the box areas set fully
          transparent, so only the boxes remain visible.
        * 6: an opaque ``#7B7575`` overlay whose box areas are set to the
          translucent per-box color.
        * any other value: no drawing.

        Returns:
            A new RGBA image.
        """
        mode = self.hide_true_bbox
        image = image.convert('RGBA')

        overlay = None
        if mode == 1:
            draw = ImageDraw.Draw(image, 'RGBA')
        elif mode in self._OUTLINE_MODES:
            overlay = Image.new('RGBA', image.size, '#00000000')
            draw = ImageDraw.Draw(overlay, 'RGBA')
        elif mode in (3, 6):
            overlay = Image.new('RGBA', image.size, '#7B7575ff')
            draw = ImageDraw.Draw(overlay, 'RGBA')
        else:
            # Unsupported mode: nothing to draw.
            return image

        for idx, bbox in enumerate(bboxes):
            if bbox is None:
                continue
            color_fill = self._COLOR_FILLS[idx % len(self._COLOR_FILLS)]
            x, y = bbox['left'], bbox['top']
            rect = [(x, y), (x + bbox['width'], y + bbox['height'])]

            if mode == 1:
                draw.rectangle(rect, fill='#7B7575')
            elif mode in self._OUTLINE_MODES:
                draw.rectangle(rect, fill=color_fill, outline='#05ff37ff', width=3)
            elif mode == 3:
                draw.rectangle(rect, fill='#00000000')
            else:  # mode 6
                draw.rectangle(rect, fill=color_fill)

        if overlay is not None:
            image = Image.alpha_composite(image, overlay)
        return image
| |
def _build_message_dict(temp_data):
    """Copy the fields shown in the viewer from one annotation record, in display order."""
    message_dict = {
        'img': temp_data['img'],
        'plausible_speed': temp_data['plausible_speed'],
        'bounding_box': temp_data['bounding_box'],
    }
    # Some records carry 'rationale' instead of 'hazard'.
    try:
        message_dict['hazard'] = temp_data['hazard']
    except KeyError:
        message_dict['hazard'] = temp_data['rationale']
    for entity_key in ('Entity #1', 'Entity #2', 'Entity #3'):
        message_dict[entity_key] = temp_data[entity_key]
    return message_dict


def retrive_data(temp_data, img_key, mode='direct', split='val'):
    """Load the image for ``img_key`` and collect its annotation fields.

    Args:
        temp_data: One annotation record (a dict). It must contain ``img``,
            ``plausible_speed``, ``bounding_box``, ``Entity #1``-``#3``, and
            either ``hazard`` or ``rationale``.
        img_key: Key resolved to an image file by ``ImageRetriever``.
        mode: Annotation variant; only used to build the retriever's
            ``anno_path``.
        split: Dataset split used in ``anno_path``. Previously this was read
            from a global that exists only when the file runs as a script.

    Returns:
        ``(img, message_dict)``: the image with its boxes drawn, downscaled
        to half size, and a dict of the annotation fields for display.

    Raises:
        KeyError: If a required field is missing from ``temp_data``.
        FileNotFoundError: Propagated from ``ImageRetriever.key2img`` when no image file exists.
    """
    message_dict = _build_message_dict(temp_data)

    img_retriever = ImageRetriever(
        root_path=os.path.join(os.environ['ROOT'], ''),
        anno_path=os.path.join(os.environ['ROOT'], f'data/anno_{split}_{mode}.json'),
    )
    img = img_retriever.key2img(img_key, temp_data)
    # Clamp to 1px so tiny images don't produce an invalid zero-sized resize.
    img = img.resize((max(1, img.width // 2), max(1, img.height // 2)))

    return img, message_dict
|
|
| |
|
|
|
|
| |
if __name__ == '__main__':
    # Streamlit entry point: show one annotated sample; the button selects a
    # random sample on the rerun it triggers.
    st.title("DHPR: Driving Hazard Prediction and Reasoning")

    img_path = os.path.join(os.environ['ROOT'], 'img/')
    # Only .png/.jpg names map to files that ImageRetriever.key2img_path looks
    # for under img/. Sorting keeps the default sample (index 0) stable across reruns.
    img_path_list = sorted(
        name for name in get_file_only(img_path)
        if name.endswith(('.png', '.jpg'))
    )
    if not img_path_list:
        st.error(f"No .png/.jpg images found in {img_path}")
        st.stop()

    # Kept as a module-level name in case retrive_data reads it as a global.
    split = 'val'
    data_dir = os.path.join(os.environ['ROOT'], 'data')
    with open(os.path.join(data_dir, f'anno_{split}_direct.json'), encoding='utf-8') as f:
        main_direct_dataset = json.load(f)
    with open(os.path.join(data_dir, f'anno_{split}_indirect.json'), encoding='utf-8') as f:
        main_indirect_dataset = json.load(f)

    rand_index = 0
    if st.button('Random Data Sample'):
        # randrange(n) yields 0..n-1, always a valid index; randint's upper
        # bound is inclusive and could overrun the list.
        rand_index = random.randrange(len(img_path_list))

    st.subheader("Data Visualization")

    # splitext strips only the extension, so dotted stems stay intact.
    img_key = os.path.splitext(img_path_list[rand_index])[0]

    if img_key in main_direct_dataset:
        temp_data = main_direct_dataset[img_key]['details'][-1]
        data_mode = 'direct'
    elif img_key in main_indirect_dataset:
        temp_data = main_indirect_dataset[img_key]['details'][-1]
        data_mode = 'indirect'
    else:
        # Previously fell through with temp_data unbound (NameError).
        st.error(f"No annotation found for image key {img_key!r}")
        st.stop()

    img, message_dict = retrive_data(temp_data, img_key, mode=data_mode)

    st.write("---")

    st.image(img)
    st.subheader("Annotation Details")
    st.json(message_dict)
    st.write('---')
| |
|
|
|
|
| |
|
|
|
|