Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from pathlib import Path | |
| from pprint import pprint | |
| from omegaconf import OmegaConf | |
| from PIL import Image, ImageDraw | |
| import streamlit as st | |
# Export the absolute, symlink-resolved directory of this script as $ROOT;
# main() builds all of its input file paths relative to it.
os.environ['ROOT'] = os.path.split(os.path.realpath(__file__))[0]
class ImageRetriever:
    """Locate dataset images by key and optionally mark annotated bounding boxes.

    Args:
        root_path: Directory containing the image sub-folders searched by
            ``key2img_path``.
        anno_path: Path to a JSON annotation file, loaded into ``self.anno``.
            ``key2img`` reads ``self.anno[key]['details'][-1]['bounding_box']``
            (presumably a dict keyed "1".."3" -> box dicts; confirm against data).
    """

    # RGBA fill colours for the 1st, 2nd and 3rd box: pink, cyan, yellow,
    # each with alpha 0x3c (60/255, roughly 24% opacity).
    _COLOR_FILLS = ('#ff05cd3c', '#00F1E83c', '#F2D4003c')
    # Modes that draw translucent colour fills with a green outline.
    _HIGHLIGHT_MODES = (2, 5, 7, 8, 9)

    def __init__(self, root_path, anno_path):
        self.root_path = Path(root_path)
        # Context manager so the annotation file handle is closed promptly.
        with open(anno_path, encoding='utf-8') as f:
            self.anno = json.load(f)

    def key2img_path(self, key):
        """Return the first existing image file for *key*, or ``None``.

        Candidates under ``root_path`` are tried in order:
        ``images/{key}.jpg``, ``images/{key}.png``, ``var_images/{key}.jpg``,
        ``var_images/{key}.png``, then ``img/{train,val,test}/{prefix}/{key}.png``
        where ``prefix`` is the part of *key* before the first ``_``.
        """
        prefix = key.split('_')[0]
        candidates = [
            self.root_path / f"images/{key}.jpg",
            self.root_path / f"images/{key}.png",
            self.root_path / f"var_images/{key}.jpg",
            self.root_path / f"var_images/{key}.png",
        ]
        candidates += [
            self.root_path / f"img/{split}/{prefix}/{key}.png"
            for split in ('train', 'val', 'test')
        ]
        return next((path for path in candidates if path.exists()), None)

    def key2img(self, key, draw_bbox=True):
        """Open the image for *key*, optionally with its annotated boxes marked.

        Args:
            key: Image key, also used to look up ``self.anno``.
            draw_bbox: When true, boxes "1".."3" from the last ``details`` entry
                are rendered via ``hide_region`` (which returns an RGBA image).

        Raises:
            FileNotFoundError: if no candidate image file exists for *key*.
        """
        file_path = self.key2img_path(key)
        if file_path is None:
            # Fail clearly instead of letting Image.open(None) raise an obscure error.
            raise FileNotFoundError(f"No image found for key {key!r} under {self.root_path}")
        image = Image.open(file_path)
        if draw_bbox:
            meta = self.anno[key]['details'][-1]
            bboxes = [meta['bounding_box'].get(str(box_idx + 1)) for box_idx in range(3)]
            image = self.hide_region(image, bboxes)
        return image

    def hide_region(self, image, bboxes, mode=2):
        """Return an RGBA copy of *image* with the given boxes rendered.

        Args:
            image: PIL image; converted to RGBA first.
            bboxes: Sequence of dicts with ``left``, ``top``, ``width`` and
                ``height`` keys; ``None`` entries are skipped. Fill colour is
                chosen by position, cycling through three colours.
            mode: Rendering mode, also stored on ``self.hide_true_bbox``:
                1 - paint each box solid grey directly on the image;
                2, 5, 7, 8, 9 - highlight: translucent colour fill plus a 3px
                green outline, composited over the image;
                3 - blackout: grey out everything except the boxes;
                6 - position only: grey out everything, boxes show through
                with a translucent colour tint;
                anything else - the RGBA copy is returned unchanged.
                Defaults to 2, the mode this method previously always forced.
        """
        self.hide_true_bbox = mode
        image = image.convert('RGBA')
        overlay = None
        if mode == 1:
            draw = ImageDraw.Draw(image, 'RGBA')
        elif mode in self._HIGHLIGHT_MODES:
            overlay = Image.new('RGBA', image.size, '#00000000')
            draw = ImageDraw.Draw(overlay, 'RGBA')
        elif mode in (3, 6):
            # Opaque grey layer; boxes are cut out of / tinted on it below.
            overlay = Image.new('RGBA', image.size, '#7B7575ff')
            draw = ImageDraw.Draw(overlay, 'RGBA')
        else:
            return image

        for idx, bbox in enumerate(bboxes):
            if bbox is None:
                continue
            x, y = bbox['left'], bbox['top']
            rect = [(x, y), (x + bbox['width'], y + bbox['height'])]
            color_fill = self._COLOR_FILLS[idx % len(self._COLOR_FILLS)]
            if mode == 1:
                draw.rectangle(rect, fill='#7B7575')
            elif mode in self._HIGHLIGHT_MODES:
                draw.rectangle(rect, fill=color_fill, outline='#05ff37ff', width=3)
            elif mode == 3:
                # On an RGBA draw target pixels are written, not blended, so a
                # fully transparent fill punches a window in the grey layer.
                draw.rectangle(rect, fill='#00000000')
            elif mode == 6:
                draw.rectangle(rect, fill=color_fill)

        if overlay is not None:
            image = Image.alpha_composite(image, overlay)
        return image
def main(mode='spec', split='val', which='best', best_max_rank=20, worst_min_rank=100):
    """Render a Streamlit page of demo cases for one result file.

    Reads ``$ROOT/{mode}_{split}_demo.json``; every top-level key except
    ``'settings'`` is a case with ``annotated_hazard``, ``annotated_rank``,
    ``rank``, ``chatgpt_score`` and ``hazard`` fields. Images are resolved
    through ``ImageRetriever`` using ``$ROOT/anno_{split}_combined.json``.

    Args:
        mode: Result-file prefix.
        split: Dataset split name used in both file names.
        which: ``'best'`` for cases with ``1 <= annotated_rank <= best_max_rank``,
            ``'worst'`` for cases with ``annotated_rank > worst_min_rank``.
            Ranks in between appear in neither list.
        best_max_rank: Upper (inclusive) rank bound for ``'best'``.
        worst_min_rank: Lower (exclusive) rank bound for ``'worst'``.

    Raises:
        KeyError: if *which* is not ``'best'`` or ``'worst'``.
    """
    root = os.environ['ROOT']
    with open(os.path.join(root, f"{mode}_{split}_demo.json"), encoding='utf-8') as f:
        result_dict = json.load(f)

    which_keys = {'best': [], 'worst': []}
    message_dict = {}
    for file_key, res in result_dict.items():
        if file_key == 'settings':
            continue
        annotated_hazard = res['annotated_hazard']
        annotated_rank = res['annotated_rank']
        message = {'annotated_hazard': f"{int(annotated_rank)} - {annotated_hazard}"}
        message['hazard'] = [
            f"rank: {int(rank)}, score: {score}, sent: {sent}"
            for rank, score, sent in zip(res['rank'], res['chatgpt_score'], res['hazard'])
        ]
        message_dict[file_key] = message
        if 1 <= annotated_rank <= best_max_rank:
            which_keys['best'].append(file_key)
        if annotated_rank > worst_min_rank:
            which_keys['worst'].append(file_key)

    # ImageRetriever loads the annotation file itself; no separate load needed here.
    img_retriever = ImageRetriever(
        root_path=os.path.join(root, ''),
        anno_path=os.path.join(root, f'anno_{split}_combined.json'),
    )
    st.write("---")
    for file_key in which_keys[which]:
        st.image(img_retriever.key2img(file_key))
        st.json(message_dict[file_key])
        st.write('---')
if __name__ == '__main__':
    # Script entry point: show the 'best' cases of the spec/val demo results.
    main(which='best', mode='spec', split='val')