| | import os |
| | import re |
| | import json |
| | import streamlit as st |
| | from PIL import Image, ImageDraw |
| | import requests |
| | from io import BytesIO |
| | import seaborn as sns |
| | import matplotlib.pyplot as plt |
| | from streamlit_chat import message as st_message |
| |
|
| | import yaml |
| |
|
# Configure the Streamlit page: title, icon, wide layout, collapsed sidebar.
st.set_page_config(
    page_title="Data Exploration",
    page_icon="🌍",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Hex colour strings used to outline bounding boxes: 100 entries requested
# from seaborn's "Paired" palette.
COLORS = sns.color_palette("Paired", n_colors=100).as_hex()
| |
|
def load_config(config_fn, field='data_explore') -> dict:
    """Load one top-level section of a YAML configuration file.

    Args:
        config_fn: Path of the YAML file to read.
        field: Top-level key whose value is returned.

    Returns:
        The value stored under ``field`` in the parsed document.

    Raises:
        KeyError: If the parsed mapping has no ``field`` key.
    """
    # Use a context manager so the file handle is closed once parsing is done.
    with open(config_fn) as config_file:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects.
        # It is kept so existing configs keep loading; only point this at
        # trusted files, or switch to yaml.safe_load if no custom tags are used.
        config = yaml.load(config_file, Loader=yaml.Loader)
    return config[field]
| |
|
def convert_from_prompt_tokens(s_with_region_tokens):
    """Extract normalised boxes from region prompt tokens.

    Every ``<Region>...</Region>`` span that holds exactly four ``<L{int}>``
    tokens (1-4 digits each, optional whitespace around the tokens) yields one
    box; each integer is divided by 1000.

    Example:
        >>> convert_from_prompt_tokens("<Region><L12><L24><L101><L777></Region>")
        [(0.012, 0.024, 0.101, 0.777)]

    Args:
        s_with_region_tokens: Text that may contain region token groups.

    Returns:
        A list with one 4-tuple of floats per matched region, in order of
        appearance. Empty when nothing matches.
    """
    region_pattern = (
        r'<Region>\s*<L(\d{1,4})>\s*<L(\d{1,4})>'
        r'\s*<L(\d{1,4})>\s*<L(\d{1,4})>\s*</Region>'
    )
    # With four capture groups, findall returns one 4-tuple of digit strings
    # per region.
    return [
        tuple(int(token) / 1000 for token in groups)
        for groups in re.findall(region_pattern, s_with_region_tokens)
    ]
| |
|
def _parse_number_list(text):
    """Parse comma-separated numeric text into a tuple, or None if malformed."""
    tokens = [token.strip() for token in text.split(',')]
    if len(tokens) > 1 and not tokens[-1]:
        tokens.pop()  # tolerate a trailing comma such as "1, 2, 3, 4,"
    values = []
    for token in tokens:
        try:
            # Keep ints as ints so str() of the value still matches the
            # source text when the coordinates are searched for later.
            values.append(int(token) if token.isdigit() else float(token))
        except ValueError:
            return None
    return tuple(values)


def parse_regions(s):
    """Collect the unique bounding boxes mentioned in a piece of text.

    Two notations are recognised:

    * bracketed lists of exactly four numbers, e.g. ``[0.1, 0.2, 0.3, 0.4]``;
    * ``<Region>`` prompt-token groups, decoded by
      ``convert_from_prompt_tokens``.

    Bracketed groups that do not hold exactly four well-formed numbers
    (two-element points, indices such as ``[1]``, malformed text) are ignored
    instead of raising. The text is parsed directly and never passed to
    ``eval``.

    Args:
        s: Text to scan.

    Returns:
        A list of unique 4-tuples in order of first appearance, so the same
        text always yields the same box order.
    """
    bboxes = []
    for inner in re.findall(r"\[([\d.,\s]+)\]", s):
        coords = _parse_number_list(inner)
        if coords is not None and len(coords) == 4:
            bboxes.append(coords)

    bboxes.extend(convert_from_prompt_tokens(s))
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(bboxes))
| |
|
def get_image(image_path, bboxes=None):
    """Load an image and optionally outline bounding boxes on it.

    Args:
        image_path: Path of a local file. If no such file exists, the value
            is fetched with an HTTP GET instead.
        bboxes: Optional iterable of hashable ``(x1, y1, x2, y2)`` boxes.
            Coordinates are multiplied by the image width/height before
            drawing, i.e. they are treated as fractions of the image size.

    Returns:
        ``(image, color_mapping)``: ``image`` is an RGB PIL image with the
        boxes drawn on it; ``color_mapping`` maps each box to the hex colour
        used for its outline, or is ``None`` when ``bboxes`` is ``None``.

    Raises:
        requests.HTTPError: If the remote server answers with an error status.
        requests.RequestException: On other network failures, including the
            30 second timeout.
    """
    if os.path.exists(image_path):
        image = Image.open(image_path).convert('RGB')
    else:
        # Anything that is not an existing local path is treated as a URL.
        # The timeout keeps a stalled server from hanging the page forever.
        response = requests.get(image_path, timeout=30)
        # Fail with a clear HTTP error rather than an image-decoding error
        # raised on an HTML error page.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert('RGB')

    if bboxes is None:
        return image, None

    draw = ImageDraw.Draw(image, 'RGB')
    width, height = image.size
    color_mapping = {}
    for i, bbox_coords in enumerate(bboxes):
        # Wrap around so more boxes than palette entries cannot raise
        # IndexError.
        color = COLORS[i % len(COLORS)]
        x1, y1, x2, y2 = bbox_coords
        draw.rectangle(
            [x1 * width, y1 * height, x2 * width, y2 * height],
            outline=color,
            width=3,
        )
        color_mapping[bbox_coords] = color
    return image, color_mapping
| |
|
def insert_color(s, color_mapping):
    """Prefix every known coordinate list in ``s`` with a coloured square.

    For each ``coords -> color`` entry, the text ``[c1, c2, ...]`` (values
    joined with ``', '``) is replaced by an HTML ``<span>`` showing ``■`` in
    that colour, followed by a space and the original bracketed text.

    Args:
        s: Text to decorate.
        color_mapping: Mapping from coordinate tuples to CSS colour strings.
            ``None`` or an empty mapping leaves ``s`` unchanged.

    Returns:
        The decorated string.
    """
    if not color_mapping:
        # Nothing to highlight; a None mapping is treated like an empty one.
        return s

    for coords, color in color_mapping.items():
        coords_str = ', '.join(str(x) for x in coords)
        bracketed = '[' + coords_str + ']'
        swatch = f'<span style="color: {color}; font-weight: bold;">■</span>'
        s = s.replace(bracketed, swatch + ' ' + bracketed)
    return s
| |
|
modal_indicator = ['<image>', '<audio>', '<video>']


def show_one_msg(msg, modal_inputs):
    """Render one message, replacing modality placeholders with media widgets.

    The message is split on the placeholders listed in ``modal_indicator``.
    Each ``<image>``/``<audio>``/``<video>`` placeholder consumes (pops) the
    first remaining entry of the matching ``modal_inputs`` list; every other
    segment is written as text.

    Args:
        msg: Message text, possibly containing placeholders.
        modal_inputs: Dict with ``image``/``audio``/``video`` lists. The lists
            are mutated: consumed items are removed from the front.
    """
    # placeholder -> (Streamlit renderer, key of the matching input list)
    renderers = {
        '<image>': (st.image, 'image'),
        '<audio>': (st.audio, 'audio'),
        '<video>': (st.video, 'video'),
    }
    # The capturing group makes re.split keep the placeholders in the result.
    splitter = '(' + '|'.join(modal_indicator) + ')'
    for segment in re.split(splitter, msg):
        if segment in renderers:
            render, modality = renderers[segment]
            render(modal_inputs[modality].pop(0))
        else:
            st.write(segment)
| |
|
def show_multimodal_example(example, col1, col2):
    """Render an example whose media is referenced through ``modal_inputs``.

    Args:
        example: Dict with ``conversations`` and ``modal_inputs``; an optional
            ``info`` dict is shown as metadata (and gains a ``modal_inputs``
            entry as a side effect).
        col1: Streamlit container that receives the metadata JSON.
        col2: Streamlit container that receives the chat transcript.
    """
    with col1:
        meta = example.get('info', {})
        meta['modal_inputs'] = example['modal_inputs']
        st.json(meta)

    with col2:
        turns = example['conversations']
        media = example['modal_inputs']
        # Messages are consumed as (user, assistant) pairs; an unpaired
        # trailing message is not rendered.
        for user_turn, assistant_turn in zip(turns[0::2], turns[1::2]):
            with st.chat_message("user"):
                show_one_msg(user_turn['value'], media)
            with st.chat_message("assistant"):
                show_one_msg(assistant_turn['value'], media)
| |
|
| |
|
def show_example(example, col1, col2, enable_scores=True):
    """Render one image-based example: annotated image left, text right.

    Args:
        example: Example dict. Reads ``image`` (required) and, when present,
            ``conversations``, ``ground_truth``, ``info``, ``id``,
            ``dataset``, ``instruction``, ``input``, ``output`` and ``query``.
        col1: Streamlit container for the image, metadata and colour legend.
        col2: Streamlit container for the conversation or instruction text.
        enable_scores: When true and the example has ``conversations``, the
            scoring radios are shown.

    Returns:
        ``None`` when no scoring widgets were shown; otherwise a dict with
        ``image``, ``conversations`` and ``scores`` (``quality``/``format``).
    """
    # Look for boxes only in the conversation text when there is one,
    # otherwise in the string form of the whole example.
    if 'conversations' in example:
        regions = parse_regions(str(example['conversations']))
    else:
        regions = parse_regions(str(example))

    image_fn = example['image']
    image, color_mapping = get_image(image_fn, regions)

    with col1:
        st.image(image)
        # Note: when the example has an 'info' dict it is updated in place.
        info = example.get('info', {})
        info['id'] = example.get('id', 'N/A')
        info['image'] = image_fn
        if 'dataset' in example:
            info['source'] = example['dataset']
        st.json(info)

        # Legend table: one row per box, pairing its colour with its coordinates.
        if len(color_mapping):
            table_md = "| 颜色 | 坐标 |\n| --- | --- |\n"
            for coords, color in color_mapping.items():
                color_cell = f'<span style="color: {color}; font-weight: bold;">■</span>'
                table_md += f"| {color_cell} | {coords} |\n"

            st.markdown(table_md, unsafe_allow_html=True)

    score_dict = None
    with col2:
        # NOTE(review): the nesting of the branches below was reconstructed
        # from a source whose indentation was lost. The `else` is assumed to
        # pair with the 'conversations' check and the 'query' section is
        # assumed to sit inside that `else` -- confirm against the real file.
        if 'conversations' in example:
            if enable_scores:
                score_dict = {'image': image_fn, 'conversations': example['conversations']}
                with st.expander("Give a score based on the result above", expanded=True):
                    quality_score = st.radio("问题质量分数",('Bad', 'Mediocre', 'Good'),key="quality", horizontal = True)
                    format_score = st.radio("格式分数",('Bad', 'Mediocre', 'Good'),key="format", horizontal = True)
                    score_dict['scores'] = {
                        'quality': quality_score, 'format': format_score
                    }
            st.subheader("Chat")
            conversations = example['conversations']
            # Messages are shown as (user, assistant) pairs; widget keys are
            # built from the image path plus the message position.
            for i in range(len(conversations) // 2):
                st_message(conversations[2*i]['value'], is_user=True, key=image_fn + str(2*i))
                st_message(conversations[2*i+1]['value'], is_user=False, key=image_fn + str(2*i+1))

            if 'ground_truth' in example:
                # Serialise to JSON text so coordinate lists can be colour-tagged.
                gt = insert_color(json.dumps(example['ground_truth']), color_mapping)
                st.markdown(f"**Ground Truth:**\n\n{gt}", unsafe_allow_html=True)
        else:
            # No conversation: show instruction / optional input / output fields.
            instruction = insert_color(example['instruction'], color_mapping)
            st.markdown(f"**Instruction:**\n\n{instruction}", unsafe_allow_html=True)

            if 'input' in example:
                input = insert_color(example['input'], color_mapping)
                st.markdown(f"**Input:**\n\n{input}", unsafe_allow_html=True)

            output = insert_color(example['output'], color_mapping)
            st.markdown(f"**Output:**\n\n{output}", unsafe_allow_html=True)

            if 'query' in example:
                # Serialise to JSON text so coordinate lists can be colour-tagged.
                query = insert_color(json.dumps(example['query']), color_mapping)
                st.markdown(f"**Query:**\n\n{query}", unsafe_allow_html=True)
    return score_dict
| |
|
def reset_state():
    """Reset the browsing index to 0 and clear the collected scores.

    Also prints ``RESET`` to stdout as a trace of the reset.
    """
    print('RESET')
    fresh_state = {'data_explore': {'idx': 0}, 'scores': {}}
    for key, value in fresh_state.items():
        st.session_state[key] = value
| |
|
def _read_dataset_records(stem):
    """Load ``<stem>.json``, ``<stem>.jsonl`` or ``<stem>.txt``, first one found."""
    json_fn = stem + '.json'
    if os.path.exists(json_fn):
        with open(json_fn) as f:
            return json.load(f)
    for ext in ('.jsonl', '.txt'):
        lines_fn = stem + ext
        if os.path.exists(lines_fn):
            with open(lines_fn) as f:
                # One JSON document per line; blank lines are skipped.
                return [json.loads(line) for line in f if line.strip()]
    raise FileNotFoundError(
        f'no .json/.jsonl/.txt file found for dataset {stem!r}'
    )


def load_dir_data(dir, dataset_configs):
    """Load every dataset listed in a directory's ``mapping.yaml``.

    ``mapping.yaml`` must provide ``image_paths`` (name -> image root) and
    ``mapping`` (dataset file stem -> image-path name). For each stem the
    first existing file among ``<stem>.json``, ``<stem>.jsonl`` and
    ``<stem>.txt`` is loaded. Every example's ``image`` is joined onto the
    image root selected for that dataset (falling back to the ``default``
    root, which itself defaults to ``'.'``) and its ``dataset`` is set to the
    stem.

    Args:
        dir: Directory containing ``mapping.yaml`` and the dataset files.
        dataset_configs: Not used by this function.

    Returns:
        A flat list with the examples of all datasets, in mapping order.

    Raises:
        FileNotFoundError: If ``mapping.yaml`` or a listed dataset file is
            missing. A missing dataset file previously surfaced as a
            NameError, or silently re-used the previous dataset's rows.
    """
    mapping_file = os.path.join(dir, 'mapping.yaml')
    if not os.path.exists(mapping_file):
        raise FileNotFoundError(f'mapping file not found: {mapping_file}')

    with open(mapping_file) as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only use this on trusted mapping files (or move to yaml.safe_load).
        config = yaml.load(f, Loader=yaml.Loader)

    image_paths = config['image_paths']
    image_paths.setdefault('default', '.')

    res = []
    for k, v in config['mapping'].items():
        data = _read_dataset_records(os.path.join(dir, k))

        image_path = image_paths.get(v, image_paths['default'])
        for example in data:
            example['image'] = os.path.join(image_path, example['image'])
            example['dataset'] = k
        res.extend(data)

    return res
| |
|
@st.cache_data
def load_data(fn, dataset_configs):
    """Load examples from a data file or a mapped dataset directory.

    * A directory is delegated to ``load_dir_data``.
    * ``.txt`` / ``.jsonl`` files are read as one JSON document per line
      (blank lines are skipped); any other file is read as a single JSON
      document.

    For file input every example then gets an ``image`` location:
    ``image`` is joined onto the dataset root, a string ``img_info`` is
    joined the same way, and a dict ``img_info`` with ``coco_url``
    contributes that URL. The dataset root is
    ``dataset_configs[example.get('dataset', 'default')]``.

    Args:
        fn: Path of a data file or of a directory holding ``mapping.yaml``.
        dataset_configs: Mapping from dataset name to image root directory.

    Returns:
        The list of example dicts.

    Raises:
        KeyError: If an example's dataset (or ``default``) is not configured.
        ValueError: If an example has none of ``image``, ``img_info`` or
            ``modal_inputs``.
    """
    if os.path.isdir(fn):
        return load_dir_data(fn, dataset_configs)

    # The context manager guarantees the handle is closed after reading.
    with open(fn) as f:
        if fn.endswith(('.txt', '.jsonl')):
            res = [json.loads(line) for line in f if line.strip()]
        else:
            res = json.load(f)

    for example in res:
        dataset_path = dataset_configs[example.get('dataset', 'default')]

        if 'image' in example:
            example['image'] = os.path.join(dataset_path, example['image'])
        elif 'img_info' in example:
            img_info = example['img_info']
            if isinstance(img_info, str):
                example['image'] = os.path.join(dataset_path, img_info)
            elif 'coco_url' in img_info:
                example['image'] = img_info['coco_url']
        elif 'modal_inputs' not in example:
            # An explicit exception, because assert is stripped under -O.
            raise ValueError(
                "example has none of 'image', 'img_info' or 'modal_inputs'"
            )

    return res
| |
|
# Read the `data_explore` section of config.yaml and echo it to stdout.
dataset_configs = load_config('config.yaml', field='data_explore')
print(dataset_configs)

# Directories to scan for data files; defaults to ./instruction_data.
data_paths = dataset_configs.get('data_paths', ['instruction_data'])
| |
|
# Data sources offered in the file selector; filled by add_file().
files = []


def add_file(path):
    """Recursively collect loadable data sources under ``path`` into ``files``.

    A directory containing ``mapping.yaml`` is added as a single entry and is
    not descended into. Otherwise its entries are visited in sorted order:
    regular files ending in ``.txt``, ``.json`` or ``.jsonl`` are added,
    sub-directories are scanned recursively, and every other file is ignored.

    Args:
        path: Directory to scan.
    """
    if os.path.exists(os.path.join(path, 'mapping.yaml')):
        files.append(path)
        return

    for f in sorted(os.listdir(path)):
        file = os.path.join(path, f)
        if os.path.isfile(file):
            # '.jsonl' is accepted because load_data can read it.
            if file.endswith(('.txt', '.json', '.jsonl')):
                files.append(file)
        elif os.path.isdir(file):
            # Recurse only into directories: calling os.listdir on an
            # unrelated regular file (e.g. a README) would raise.
            add_file(file)
| | |
# Discover every data source below the configured data directories.
for data_path in data_paths:
    add_file(data_path)

# Initialise the browsing state once per session. Resetting it on every rerun
# would make the stored index meaningless; reset_state() handles file changes.
if 'data_explore' not in st.session_state:
    st.session_state['data_explore'] = {'idx': 0}
enable_score = st.sidebar.checkbox('Score it!', value=False)
if enable_score and 'scores' not in st.session_state:
    st.session_state.scores = {}

status_placeholder = st.empty()
control_col1, control_col2 = st.columns(2)

with control_col1:
    selected_file = st.selectbox('Select a file', files, on_change=reset_state)

col1, col2 = st.columns(2)

if selected_file:
    data = load_data(selected_file, dataset_configs)

    if not data:
        # Nothing to index into; say so rather than fail on data[0].
        status_placeholder.warning("The selected file contains no examples.")
    else:
        last_idx = len(data) - 1
        with control_col2:
            stored_idx = st.session_state['data_explore'].get('idx', 0)
            idx = st.number_input(
                f'Input an idx (Total: {len(data)})',
                min_value=0,
                # Valid indices end at len(data) - 1; allowing len(data)
                # made data[idx] raise IndexError.
                max_value=last_idx,
                value=min(stored_idx, last_idx),
            )
            st.session_state['data_explore']['idx'] = idx

        if 'image' in data[idx]:
            score_dict = show_example(data[idx], col1, col2, enable_scores=enable_score)
            # Keep the scores of the example on screen so that "Submit
            # scores" has something to write; the result used to be dropped.
            if enable_score and score_dict is not None:
                st.session_state.scores[idx] = score_dict
        else:
            show_multimodal_example(data[idx], col1, col2)

    if enable_score:
        name = st.sidebar.text_input("Username", placeholder = "Enter your name", value="cc")
        if st.sidebar.button(label ="Submit scores", key = "submit"):
            # Use only the final path component so the name cannot redirect
            # the output outside score_results/.
            safe_name = os.path.basename(name.strip()) if name else ""
            if safe_name:
                os.makedirs("score_results", exist_ok=True)
                score_path = f"score_results/{os.path.basename(selected_file)}_{safe_name}.json"
                with open(score_path, "w") as score_file:
                    json.dump(st.session_state.scores, score_file, indent = 4)
                status_placeholder.success("Successfully saved!")
            else:
                status_placeholder.error("Please enter your name on the sidebar!")