|
|
import os |
|
|
import re |
|
|
import copy |
|
|
import json |
|
|
import yaml |
|
|
import random |
|
|
import streamlit as st |
|
|
from PIL import Image, ImageDraw |
|
|
import requests |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
import seaborn as sns |
|
|
import matplotlib.pyplot as plt |
|
|
import pandas as pd |
|
|
|
|
|
from collections import defaultdict |
|
|
import datetime |
|
|
import json |
|
|
import os |
|
|
import time |
|
|
|
|
|
import gradio as gr |
|
|
import requests |
|
|
|
|
|
import hashlib |
|
|
import time |
|
|
|
|
|
import streamlit as st |
|
|
import streamlit.components.v1 as components |
|
|
from streamlit_chat import message as st_message |
|
|
from streamlit_drawable_canvas import st_canvas |
|
|
|
|
|
# Page-wide Streamlit configuration; must be the first st.* call in the script.
st.set_page_config(page_title="Model Chat", page_icon="🌍", layout="wide", initial_sidebar_state="collapsed")

# Two-column layout: image/canvas on the left, chat transcript on the right.
col_img, col_chat = st.columns([1, 1])

with col_chat:
    with st.container():
        # input_area is created first so the text box renders above the transcript.
        input_area = st.container()
        chatbox = st.container()
|
|
|
|
|
|
|
|
import dataclasses |
|
|
from enum import auto, Enum |
|
|
from typing import List, Tuple |
|
|
|
|
|
|
|
|
class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()  # one separator (sep) after every turn
    TWO = auto()     # alternate sep / sep2 between turns (see Conversation.get_prompt)
|
|
|
|
|
import re |
|
|
|
|
|
def convert_region_tags(text):
    """HTML-escape the contents of ``<Region>...</Region>`` spans.

    The inner location tokens (e.g. ``<L100>``) look like HTML tags and would
    be swallowed by the chat renderer; escaping them makes them display as
    literal text. The original code replaced ``'<'`` with ``'<'`` (a no-op,
    presumably de-entitized ``&lt;``/``&gt;``), so it never escaped anything.

    Args:
        text: Message text possibly containing region spans.

    Returns:
        The text with ``<``/``>`` inside each region span replaced by
        ``&lt;``/``&gt;``; text outside spans is untouched.
    """
    pattern = r'<Region>(.*?)<\/Region>'
    replaced_text = re.sub(
        pattern,
        lambda m: '<Region>' + m.group(1).replace('<', '&lt;').replace('>', '&gt;') + '</Region>',
        text,
    )
    return replaced_text
|
|
|
|
|
@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    # System prompt prepended to every generated prompt string.
    system: str
    # Role names, e.g. ("USER", "ASSISTANT").
    roles: List[str]
    # [role, message] pairs; a message may also be a
    # (text, image, image_process_mode) tuple for image-bearing turns.
    messages: List[List[str]]
    # Messages before this index are skipped by get_images/to_gradio_chatbot.
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    # Second separator, used on alternating turns when sep_style is TWO.
    sep2: str = None
    version: str = "Unknown"

    # UI flag: when True, the next generation step should be skipped.
    skip_next: bool = False

    def get_prompt(self):
        """Serialize system prompt + messages into one prompt string.

        A turn whose message is falsy renders as ``ROLE:`` with no separator,
        leaving the prompt open for the model to complete.
        """
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    if type(message) is tuple:
                        # Image turns store (text, image, mode); keep text only.
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            # Alternate sep (even turns) and sep2 (odd turns).
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        """Append one [role, message] turn to the history."""
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        """Collect images from user turns (even indices past *offset*).

        Args:
            return_pil: If True return PIL images; otherwise base64-encoded
                JPEG strings.

        Returns:
            List of images, processed per each turn's image_process_mode and
            downscaled (short edge capped at 400px, long edge at 800px).
        """
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                # Only user turns (even indices) carry image tuples.
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    from PIL import Image
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            # Pad the short side so the image becomes square.
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result
                        image = expand2square(image)
                    elif image_process_mode == "Crop":
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((224, 224))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
                    # Downscale keeping aspect ratio: short edge becomes
                    # min(800 / aspect, 400, current short edge).
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="JPEG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        """Render the history as [[user_html, assistant_html], ...] pairs."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                # User turn: starts a new [user, None] pair.
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    msg = convert_region_tags(msg)
                    # Same downscaling rule as get_images.
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))

                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    # Inline the image as a data URI replacing the <image> token.
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = msg.replace('<image>', img_str)
                else:
                    msg = convert_region_tags(msg)
                ret.append([msg, None])
            else:
                if isinstance(msg, str):
                    msg = convert_region_tags(msg)
                # Assistant turn: fill the open slot of the last pair.
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a copy with a freshly-built messages list.

        NOTE(review): version and skip_next are not carried over to the
        copy — confirm that is intentional.
        """
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2)

    def dict(self):
        """Serialize to a plain dict; image tuples are reduced to their text."""
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }
|
|
|
|
|
# Vicuna v1.1 conversation template: TWO-separator style, with a space after
# user turns and the EOS token "</s>" after assistant turns.
conv_vicuna_v1_1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

# Template used by send_one_message for every request.
default_conversation = conv_vicuna_v1_1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_bbox_to_region(bbox_xywh, image_width, image_height):
    """Serialize an absolute xywh box as the model's region token string.

    Args:
        bbox_xywh: (x, y, w, h) box in image pixel coordinates.
        image_width: Width of the source image in pixels.
        image_height: Height of the source image in pixels.

    Returns:
        A string of the form ``<Region><Lx1><Ly1><Lx2><Ly2></Region>`` where
        each coordinate is normalized to the [0, 1000) integer grid.
    """
    x, y, w, h = bbox_xywh
    # Corner coordinates paired with the dimension each is normalized by.
    corners = (x, y, x + w, y + h)
    scales = (image_width, image_height, image_width, image_height)
    quantized = [int(c / s * 1000) for c, s in zip(corners, scales)]
    return "<Region><L{}><L{}><L{}><L{}></Region>".format(*quantized)
|
|
|
|
|
def load_config(config_fn, field='chat'):
    """Load one top-level section of a YAML config file.

    Args:
        config_fn: Path to the YAML file.
        field: Top-level key whose value is returned.

    Returns:
        The value stored under *field*.

    Raises:
        KeyError: If *field* is not present in the file.
    """
    # Use a context manager: the original `yaml.load(open(...))` leaked the
    # file handle until garbage collection.
    with open(config_fn) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    return config[field]
|
|
|
|
|
# Loaded once at startup; supplies the model worker address (see change_model).
chat_config = load_config('configs/chat.yaml')
|
|
|
|
|
def get_model_list():
    """Names of the model workers this UI can talk to."""
    available_models = ['PVIT_v1.0']
    return available_models
|
|
|
|
|
def change_model(model_name):
    """Switch the session to *model_name*'s worker and reset the chat.

    No-op when the selected model is already active, so reruns don't wipe
    the transcript.

    Args:
        model_name: Model name chosen in the selectbox.
    """
    if model_name != st.session_state.get('model_name', ''):
        # Store the selected name. The original hard-coded 'PVIT_v1.0' here,
        # which would mis-record the selection as soon as a second model is
        # added to get_model_list(); behavior is identical while there is one.
        st.session_state['model_name'] = model_name
        st.session_state['model_addr'] = chat_config['model_addr']
        st.session_state['messages'] = []
|
|
|
|
|
|
|
|
def init_chat(image=None):
    """Store the uploaded image and ensure chat state keys exist.

    Args:
        image: Base64-encoded image string (or None to clear it).
    """
    st.session_state['image'] = image
    # Seed chat keys only when absent so reruns keep existing state.
    for key, default in (('input_message', ''), ('messages', [])):
        if key not in st.session_state:
            st.session_state[key] = default
|
|
|
|
|
def clear_messages():
    """Reset the transcript and pending input (also the uploader callback)."""
    st.session_state['input_message'] = ''
    st.session_state['messages'] = []
|
|
|
|
|
def encode_img(img):
    """Return an image as a base64-encoded JPEG string.

    Args:
        img: A file path, a PIL image, or a file-like object with getvalue()
            (e.g. an uploaded file / BytesIO already holding encoded bytes).

    Returns:
        Base64 text of the JPEG bytes.
    """
    if isinstance(img, str):
        # Path on disk: load, force RGB, re-encode as JPEG.
        buffer = BytesIO()
        Image.open(img).convert('RGB').save(buffer, format="JPEG")
    elif isinstance(img, Image.Image):
        buffer = BytesIO()
        img.save(buffer, format="JPEG")
    else:
        # Already a buffer of encoded bytes; use it as-is.
        buffer = img
    return base64.b64encode(buffer.getvalue()).decode()
|
|
|
|
|
|
|
|
def send_one_message(message, max_new_tokens=32, temperature=0.7):
    """Send one user message to the model worker and append its reply.

    Builds a Vicuna-style prompt from the session transcript, substitutes
    drawn [REGION-i] placeholders with their <Region> token strings, streams
    the worker's response, and stores the final answer in session state.

    Args:
        message: Raw user text (may contain [REGION-i] placeholders).
        max_new_tokens: Generation length limit forwarded to the worker.
        temperature: Sampling temperature forwarded to the worker.
    """
    conv = default_conversation.copy()

    if 'messages' not in st.session_state:
        st.session_state['messages'] = []
    if len(st.session_state['messages']) == 0:
        # The first turn must reference the uploaded image.
        if '<image>' not in message:
            message = '<image>\n' + message
    st.session_state['messages'].append([conv.roles[0], message])
    conv.messages = copy.deepcopy(st.session_state['messages'])

    # Open an empty assistant turn so get_prompt ends with "ASSISTANT:".
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    if 'canvas_result' in st.session_state:
        # Replace [REGION-i] placeholders with their serialized bbox tokens.
        objects = st.session_state['canvas_result'].get('objects', [])
        for i, obj in enumerate(objects):
            prompt = prompt.replace(f'[REGION-{i}]', obj['bbox_label'])

    headers = {"User-Agent": "LLaVA Client"}
    pload = {
        "prompt": prompt,
        "images": [st.session_state['image']],  # base64 JPEG from encode_img
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "stop": conv.sep2,
    }
    print(prompt)
    response = requests.post(st.session_state['model_addr'] + "/worker_generate_stream", headers=headers,
                             json=pload, stream=True)
    result = ""
    # The worker streams cumulative JSON generations separated by NUL bytes;
    # keep the text after the assistant marker from the last chunk received.
    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data_t = json.loads(chunk.decode("utf-8"))
            output = data_t["text"].split(conv.roles[1]+':')[-1]
            result = output

    st.session_state['messages'].append([conv.roles[1], result])
|
|
|
|
|
|
|
|
|
|
|
# Global CSS overrides styling the submit button and text input, with a
# desktop media query that pins the button next to the input row.
# NOTE(review): "width: 300 px" (with a space) is invalid CSS and is ignored
# by browsers — presumably "300px" was intended; confirm before changing.
st.markdown("""
<style>
div.stButton > button:first-child {
background-color: #eb5424;
color: white;
font-size: 20px;
font-weight: bold;
border-radius: 0.5rem;
padding: 0.5rem 1rem;
border: none;
box-shadow: 0 0.5rem 1rem rgba(0,0,0,0.15);
width: 300 px;
height: 42px;
transition: all 0.2s ease-in-out;
}
div.stButton > button:first-child:hover {
transform: translateY(-3px);
box-shadow: 0 1rem 2rem rgba(0,0,0,0.15);
}
div.stButton > button:first-child:active {
transform: translateY(-1px);
box-shadow: 0 0.5rem 1rem rgba(0,0,0,0.15);
}
div.stButton > button:focus:not(:focus-visible) {
color: #FFFFFF;
}
@media only screen and (min-width: 768px) {
/* For desktop: */
div.stButton > button:first-child {
background-color: #eb5424;
color: white;
font-size: 20px;
font-weight: bold;
border-radius: 0.5rem;
padding: 0.5rem 1rem;
border: none;
box-shadow: 0 0.5rem 1rem rgba(0,0,0,0.15);
width: 300 px;
height: 42px;
transition: all 0.2s ease-in-out;
position: relative;
bottom: -32px;
right: 0px;
}
div.stButton > button:first-child:hover {
transform: translateY(-3px);
box-shadow: 0 1rem 2rem rgba(0,0,0,0.15);
}
div.stButton > button:first-child:active {
transform: translateY(-1px);
box-shadow: 0 0.5rem 1rem rgba(0,0,0,0.15);
}
div.stButton > button:focus:not(:focus-visible) {
color: #FFFFFF;
}
input {
border-radius: 0.5rem;
padding: 0.5rem 1rem;
border: none;
box-shadow: 0 0.5rem 1rem rgba(0,0,0,0.15);
transition: all 0.2s ease-in-out;
height: 40px;
}
}
</style>
""", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
# Ten distinct hex colors for drawn regions, shuffled with a fixed seed so
# the box -> color assignment is stable across Streamlit reruns.
COLORS = sns.color_palette("tab10", n_colors=10).as_hex()
random.Random(32).shuffle(COLORS)
|
|
|
|
|
def update_annotation_states(canvas_result, ratio, img_size):
    """Attach <Region> labels to drawn boxes and refresh canvas state.

    Args:
        canvas_result: JSON data from st_canvas; its 'objects' entries carry
            top/left/width/height in canvas coordinates.
        ratio: Scale factor from canvas coordinates back to original image
            pixels (image_width / canvas_width).
        img_size: (width, height) of the original image.
    """
    for obj in canvas_result['objects']:
        # Map the canvas-space rect back to image pixel coordinates.
        top = obj["top"] * ratio
        left = obj["left"] * ratio
        width = obj["width"] * ratio
        height = obj["height"] * ratio
        obj['bbox_label'] = convert_bbox_to_region([left, top, width, height], img_size[0], img_size[1])
    st.session_state['canvas_result'] = canvas_result
    # Pick the color for the next box. Wrap around the palette: the original
    # COLORS[len(objects) + 1] raised IndexError once len(objects) reached
    # len(COLORS) - 1.
    n_objects = len(st.session_state['canvas_result']['objects'])
    st.session_state['label_color'] = COLORS[(n_objects + 1) % len(COLORS)]
|
|
|
|
|
def init_canvas():
    """Seed canvas-related session keys if they are not set yet."""
    defaults = (('canvas_result', None), ('label_color', COLORS[0]))
    for key, value in defaults:
        if key not in st.session_state:
            st.session_state[key] = value
|
|
|
|
|
def input_message(msg):
    """Overwrite the pending chat input with *msg*."""
    st.session_state['input_message'] = msg
|
|
|
|
|
|
|
|
def get_objects():
    """Return the drawn canvas objects, or [] when nothing is drawn yet.

    canvas_result is None right after init_canvas() and missing before it,
    so both cases fall back to an empty list.
    """
    canvas_result = st.session_state.get('canvas_result', {})
    if canvas_result is None:
        return []
    return canvas_result.get('objects', [])
|
|
|
|
|
def format_object_str(input_str):
    """Substitute [REGION-i] placeholders with their <Region> token strings.

    Args:
        input_str: Text possibly containing [REGION-i] placeholders.

    Returns:
        The text with every placeholder of a drawn box replaced; unchanged
        when no canvas data exists.
    """
    if 'canvas_result' in st.session_state:
        # Guard against None: init_canvas() seeds 'canvas_result' with None,
        # on which the original's .get() raised AttributeError.
        canvas_result = st.session_state['canvas_result'] or {}
        for i, obj in enumerate(canvas_result.get('objects', [])):
            input_str = input_str.replace(f'[REGION-{i}]', obj['bbox_label'])
    return input_str
|
|
|
|
|
|
|
|
# --- Model selection (left column) ---
model_list = get_model_list()

with col_img:
    model_name = st.selectbox(
        'Choose a model to chat with',
        model_list
    )
    change_model(model_name)

# Accumulates per-button CSS rules (region colors) injected at the bottom.
css = ''
|
|
|
|
|
with col_img:
    # Uploading (or replacing) an image resets the conversation.
    image = st.file_uploader("Chat with Image", type=["png", "jpg", "jpeg"], on_change=clear_messages)
    img_fn = image.name if image is not None else None
    if image:
        init_chat(encode_img(image))
        init_canvas()

        img = Image.open(image).convert('RGB')

        # Canvas display size: fixed width, height keeps the aspect ratio.
        width = 700
        height = round(width * img.size[1] * 1.0 / img.size[0])
        # Scale factor from canvas coordinates back to image pixels.
        ratio = img.size[0] / width

with st.sidebar:
    # Generation parameters forwarded to the worker.
    max_new_tokens = st.number_input('max_new_tokens', min_value=1, max_value=1024, value=128)
    temperature = st.number_input('temperature', min_value=0.0, max_value=1.0, value=0.0)
    drawing_mode = st.selectbox(
        "Drawing tool:", ("rect", "point", "line", "circle"),
    )
    # "Move ROIs" switches the canvas into transform mode to edit existing boxes.
    drawing_mode = "transform" if st.checkbox("Move ROIs", False) else drawing_mode
    stroke_width = st.slider("Stroke width: ", 1, 25, 3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with col_img:
    # Drawable canvas over the uploaded image; boxes drawn here become regions.
    canvas_result = st_canvas(
        fill_color=st.session_state['label_color'] + "77",  # "77" hex alpha suffix
        stroke_width=stroke_width,
        stroke_color=st.session_state['label_color'] + "77",
        background_color="#eee",
        background_image=Image.open(image) if image else None,
        update_streamlit=True,
        width=width,
        height=height,
        drawing_mode=drawing_mode,
        point_display_radius=3 if drawing_mode == 'point' else 0,
        key="canvas"
    )

if canvas_result.json_data is not None:
    # Re-label every drawn object in original-image coordinates.
    update_annotation_states(canvas_result.json_data, ratio, img.size)

if st.session_state.get('submit_btn', False):
    # Submit was pressed on the previous run: send the pending message.
    send_one_message(st.session_state['input_message'], max_new_tokens=max_new_tokens, temperature=temperature)
    st.session_state['input_message'] = ""
|
|
|
|
|
with input_area:
    # Text box, submit button, and clear button on one row.
    col3, col4, col5 = st.columns([5, 1, 1])

    with col3:
        message = st.text_input('User', key="input_message")

    with col4:
        submit_btn = st.button(label='submit', key='submit_btn')

    # Invisible component: pressing Enter anywhere clicks the submit button.
    components.html(
        """
<script>
const doc = window.parent.document;
buttons = Array.from(doc.querySelectorAll('button[kind=secondary]'));
const submit = buttons.find(el => el.innerText === 'submit');

doc.addEventListener('keydown', function(e) {
switch (e.keyCode) {
case 13: // (37 = enter)
submit.click();
}
});
</script>
""",
        height=0,
        width=0,
    )

    with col5:
        clear_btn = st.button(label='clear', on_click=clear_messages)
|
|
|
|
|
|
|
|
# One colored button per drawn box; clicking it inserts the box's
# [REGION-i] placeholder into the chat input.
objects = get_objects()

if len(objects):
    bbox_cols = st.columns([1 for _ in range(len(objects))])

    def on_bbox_button_click(region_token):
        """Return a callback appending *region_token* to the input box.

        (Renamed from `str`, which shadowed the builtin.)
        """
        def f():
            st.session_state['input_message'] += region_token
        return f

    for i, (obj, bbox_col) in enumerate(zip(objects, bbox_cols)):
        with bbox_col:
            st.button(label=f'Region-{i}', on_click=on_bbox_button_click(f'[REGION-{i}]'))

        # Color each Region-i button like its box via a positional CSS selector.
        css += f"#root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.css-uf99v8.ea3mdgi5 > div.block-container.css-awvpbp.ea3mdgi4 > div:nth-child(1) > div > div.css-ocqkz7.e1f1d6gn3 > div:nth-child(2) > div:nth-child(1) > div > div:nth-child(1) > div > div:nth-child(1) > div > div:nth-child(3) > div:nth-child({i+1}) > div:nth-child(1) > div > div > div > button {{background-color:{obj['stroke'][:7]}; bottom: 0px}} \n" + '\n'
|
|
|
|
|
|
|
|
# Render the transcript, hiding the leading '<image>\n' sentinel of the
# first user message.
for i, (role, msg) in enumerate(st.session_state['messages']):
    with chatbox:
        # Bug fix: msg.lstrip('<image>\n') strips any leading characters from
        # the SET {<,i,m,a,g,e,>,\n} — e.g. "give me..." lost its leading
        # "gi". Remove the exact prefix instead.
        shown = msg[len('<image>\n'):] if msg.startswith('<image>\n') else msg
        st_message(shown, is_user=(role == default_conversation.roles[0]), key=f'{i}-{msg}')

st.markdown("<style>\n" + css + "</style>", unsafe_allow_html=True)
|
|
|
|
|
# --- Static user manual rendered below the app ---
st.markdown(
    """
--------------------
### User Manual

- **Step 1.** Upload an image here
""")

st.image("figures/upload_image.png")

st.markdown(
    """
- **Step 2.** (Optional) You can draw bounding boxes on the image. Each box you draw creates a corresponding button of the same color.
""")

st.image("figures/bbox.png", width=512)

st.markdown(
    """
- **Step 3.** Ask questions. Insert region tokens in the question by clicking on the `Region-i` button. For example:

> What color is the dog in [REGION-0]?

> What is the relationship between the dog in [REGION-0] and the dog in [REGION-1]?

**Note**: This demo is in its experimental stage, and we are actively working on improvements.

""")