import os
import re
import copy
import json
import yaml
import random
import base64
from io import BytesIO

import requests
import seaborn as sns
from PIL import Image

import streamlit as st
import streamlit.components.v1 as components
from streamlit_chat import message as st_message
from streamlit_drawable_canvas import st_canvas
st.set_page_config(page_title="Model Chat", page_icon="🌍", layout="wide", initial_sidebar_state="collapsed")
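# Two-column layout: model selector, image upload and drawing canvas on the left; chat input and history on the right.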
col_img, col_chat = st.columns([1, 1])
with col_chat:
with st.container():
input_area = st.container()
chatbox = st.container()
# ==================== Conversation =================== #
import dataclasses
from enum import auto, Enum
from typing import List, Tuple
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
# Hack for displaying Region in Chatbot
def convert_region_tags(text):
pattern = r'<Region>(.*?)<\/Region>'
replaced_text = re.sub(pattern, lambda m: '&lt;Region&gt;' + m.group(1).replace('<', '&lt;').replace('>', '&gt;') + '&lt;/Region&gt;', text)
return replaced_text
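# e.g. '<Region><L1><L2></Region>' -> '&lt;Region&gt;&lt;L1&gt;&lt;L2&gt;&lt;/Region&gt;',
# so the region tags are shown literally instead of being parsed as HTML.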
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None
version: str = "Unknown"
skip_next: bool = False
def get_prompt(self):
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system + self.sep
for role, message in self.messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + self.sep
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def append_message(self, role, message):
self.messages.append([role, message])
def get_images(self, return_pil=False):
images = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
from PIL import Image
msg, image, image_process_mode = msg
if image_process_mode == "Pad":
def expand2square(pil_img, background_color=(122, 116, 104)):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image)
elif image_process_mode == "Crop":
pass
elif image_process_mode == "Resize":
image = image.resize((224, 224))
else:
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
if return_pil:
images.append(image)
else:
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
images.append(img_b64_str)
return images
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
msg, image, image_process_mode = msg
msg = convert_region_tags(msg)
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
# image = image.resize((224, 224))
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
msg = msg.replace('<image>', img_str)
else:
msg = convert_region_tags(msg)
ret.append([msg, None])
else:
if isinstance(msg, str):
msg = convert_region_tags(msg)
ret[-1][-1] = msg
return ret
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2)
def dict(self):
if len(self.get_images()) > 0:
return {
"system": self.system,
"roles": self.roles,
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
conv_vicuna_v1_1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
default_conversation = conv_vicuna_v1_1
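# With SeparatorStyle.TWO, sep=' ' and sep2='</s>', get_prompt() yields e.g.
# '{system} USER: <image>\n{question} ASSISTANT:' and the worker is asked to stop at '</s>'.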
# ==================== Chat =================== #
def convert_bbox_to_region(bbox_xywh, image_width, image_height):
bbox_x, bbox_y, bbox_w, bbox_h = bbox_xywh
x1 = bbox_x
y1 = bbox_y
x2 = bbox_x + bbox_w
y2 = bbox_y + bbox_h
x1_normalized = x1 / image_width
y1_normalized = y1 / image_height
x2_normalized = x2 / image_width
y2_normalized = y2 / image_height
x1_norm = int(x1_normalized * 1000)
y1_norm = int(y1_normalized * 1000)
x2_norm = int(x2_normalized * 1000)
y2_norm = int(y2_normalized * 1000)
region_format = "<Region><L{}><L{}><L{}><L{}></Region>".format(x1_norm, y1_norm, x2_norm, y2_norm)
return region_format
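# Coordinates are normalized to a 0-1000 grid, e.g.
# convert_bbox_to_region([50, 100, 200, 300], 1000, 1000) -> '<Region><L50><L100><L250><L400></Region>'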
def load_config(config_fn, field='chat'):
with open(config_fn) as f:
config = yaml.load(f, Loader=yaml.Loader)
return config[field]
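# configs/chat.yaml is expected to provide at least `model_addr` under the `chat` field
# (and `save_path` for the currently disabled save feature further below).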
chat_config = load_config('configs/chat.yaml')
def get_model_list():
return ['PVIT_v1.0']
def change_model(model_name):
if model_name != st.session_state.get('model_name', ''):
st.session_state['model_name'] = model_name
st.session_state['model_addr'] = chat_config['model_addr']
st.session_state['messages'] = []
def init_chat(image=None):
st.session_state['image'] = image
if 'input_message' not in st.session_state:
st.session_state['input_message'] = ''
if 'messages' not in st.session_state:
st.session_state['messages'] = []
def clear_messages():
st.session_state['messages'] = []
st.session_state['input_message'] = ''
def encode_img(img):
if isinstance(img, str):
img = Image.open(img).convert('RGB')
im_file = BytesIO()
img.save(im_file, format="JPEG")
elif isinstance(img, Image.Image):
im_file = BytesIO()
# convert to RGB so RGBA / paletted images can be saved as JPEG
img.convert('RGB').save(im_file, format="JPEG")
else:
im_file = img
im_bytes = im_file.getvalue() # im_bytes: image in binary format.
im_b64 = base64.b64encode(im_bytes).decode()
return im_b64
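# Append the user message to the chat history, build a Vicuna-style prompt from it,
# substitute any [REGION-i] placeholders with their <Region> tokens, and stream the
# model's reply from the worker's /worker_generate_stream endpoint.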
def send_one_message(message, max_new_tokens=32, temperature=0.7):
conv = default_conversation.copy()
# for role, msg in st.session_state['messages']:
# with chatbox:
# st_message(msg.lstrip('<image>\n'), is_user=(role==conv.roles[0]))
# # show message
# with chatbox:
# st_message(message, is_user=True)
if 'messages' not in st.session_state:
st.session_state['messages'] = []
if len(st.session_state['messages']) == 0:
if '<image>' not in message:
message = '<image>\n' + message
st.session_state['messages'].append([conv.roles[0], message])
conv.messages = copy.deepcopy(st.session_state['messages'])
# conv.append_message(conv.roles[0], message)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
if 'canvas_result' in st.session_state:
objects = st.session_state['canvas_result'].get('objects', [])
for i, obj in enumerate(objects):
prompt = prompt.replace(f'[REGION-{i}]', obj['bbox_label'])
headers = {"User-Agent": "LLaVA Client"}
pload = {
"prompt": prompt,
"images": [st.session_state['image']],
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"stop": conv.sep2,
}
print(prompt)
response = requests.post(st.session_state['model_addr'] + "/worker_generate_stream", headers=headers,
json=pload, stream=True)
result = ""
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
if chunk:
data_t = json.loads(chunk.decode("utf-8"))
output = data_t["text"].split(conv.roles[1]+':')[-1]
result = output
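# Chunks are NUL-delimited JSON objects; each chunk's 'text' field holds the running output
# including the role markers, so only the part after the last 'ASSISTANT:' is kept.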
# # show response
# with chatbox:
# st_message(result)
st.session_state['messages'].append([conv.roles[1], result])
# Customize the Streamlit UI with CSS
st.markdown("""
<style>
div.stButton > button:first-child {
background-color: #eb5424;
color: white;
font-size: 20px;
font-weight: bold;
border-radius: 0.5rem;
padding: 0.5rem 1rem;
border: none;
box-shadow: 0 0.5rem 1rem rgba(0,0,0,0.15);
width: 300px;
height: 42px;
transition: all 0.2s ease-in-out;
}
div.stButton > button:first-child:hover {
transform: translateY(-3px);
box-shadow: 0 1rem 2rem rgba(0,0,0,0.15);
}
div.stButton > button:first-child:active {
transform: translateY(-1px);
box-shadow: 0 0.5rem 1rem rgba(0,0,0,0.15);
}
div.stButton > button:focus:not(:focus-visible) {
color: #FFFFFF;
}
@media only screen and (min-width: 768px) {
/* For desktop: */
div.stButton > button:first-child {
background-color: #eb5424;
color: white;
font-size: 20px;
font-weight: bold;
border-radius: 0.5rem;
padding: 0.5rem 1rem;
border: none;
box-shadow: 0 0.5rem 1rem rgba(0,0,0,0.15);
width: 300px;
height: 42px;
transition: all 0.2s ease-in-out;
position: relative;
bottom: -32px;
right: 0px;
}
div.stButton > button:first-child:hover {
transform: translateY(-3px);
box-shadow: 0 1rem 2rem rgba(0,0,0,0.15);
}
div.stButton > button:first-child:active {
transform: translateY(-1px);
box-shadow: 0 0.5rem 1rem rgba(0,0,0,0.15);
}
div.stButton > button:focus:not(:focus-visible) {
color: #FFFFFF;
}
input {
border-radius: 0.5rem;
padding: 0.5rem 1rem;
border: none;
box-shadow: 0 0.5rem 1rem rgba(0,0,0,0.15);
transition: all 0.2s ease-in-out;
height: 40px;
}
}
</style>
""", unsafe_allow_html=True)
# ==================== Draw Bounding Boxes =================== #
COLORS = sns.color_palette("tab10", n_colors=10).as_hex()
random.Random(32).shuffle(COLORS)
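# The palette is shuffled with a fixed seed so region colors stay stable across Streamlit reruns.
# Canvas objects come back in display-pixel coordinates; `ratio` rescales them to the
# original image resolution before each box is converted to a <Region> token.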
def update_annotation_states(canvas_result, ratio, img_size):
for obj in canvas_result['objects']:
top = obj["top"] * ratio
left = obj["left"] * ratio
width = obj["width"] * ratio
height = obj["height"] * ratio
obj['bbox_label'] = convert_bbox_to_region([left, top, width, height], img_size[0], img_size[1])
st.session_state['canvas_result'] = canvas_result
st.session_state['label_color'] = COLORS[(len(st.session_state['canvas_result']['objects']) + 1) % len(COLORS)]  # wrap around so extra boxes don't raise IndexError
def init_canvas():
if 'canvas_result' not in st.session_state:
st.session_state['canvas_result'] = None
if 'label_color' not in st.session_state:
st.session_state['label_color'] = COLORS[0]
def input_message(msg):
st.session_state['input_message'] = msg
def get_objects():
canvas_result = st.session_state.get('canvas_result') or {}
return canvas_result.get('objects', [])
def format_object_str(input_str):
if 'canvas_result' in st.session_state:
objects = st.session_state['canvas_result'].get('objects', [])
for i, obj in enumerate(objects):
input_str = input_str.replace(f'[REGION-{i}]', obj['bbox_label'])
return input_str
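# Mirrors the substitution done in send_one_message; used by the (currently disabled)
# save feature below to persist <Region> tokens instead of [REGION-i] placeholders.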
# select model
model_list = get_model_list()
with col_img:
model_name = st.selectbox(
'Choose a model to chat with',
model_list
)
change_model(model_name)
css = ''
# upload image
with col_img:
image = st.file_uploader("Chat with Image", type=["png", "jpg", "jpeg"], on_change=clear_messages)
img_fn = image.name if image is not None else None
if image:
init_chat(encode_img(image))
init_canvas()
img = Image.open(image).convert('RGB')
width = 700
height = round(width * img.size[1] * 1.0 / img.size[0])
ratio = img.size[0] / width
with st.sidebar:
max_new_tokens = st.number_input('max_new_tokens', min_value=1, max_value=1024, value=128)
temperature = st.number_input('temperature', min_value=0.0, max_value=1.0, value=0.0)
drawing_mode = st.selectbox(
"Drawing tool:", ("rect", "point", "line", "circle"),
)
drawing_mode = "transform" if st.checkbox("Move ROIs", False) else drawing_mode
stroke_width = st.slider("Stroke width: ", 1, 25, 3)
# bg_color = st.color_picker("Background color: ", "#eee", key="bg_color")
# save_file = st.text_input("Save File", value="saved.jsonl")
# save_button = st.button(label='Save')
# if save_button:
# if img_fn is None:
# st.warning("Please upload an image first!")
# else:
# conversations_to_save = [{'from': role, 'value': format_object_str(conv)} for (role, conv) in st.session_state['messages']]
# model_name = st.session_state['model_name']
# save_dict = {
# 'image': img_fn,
# 'conversations': conversations_to_save,
# 'info': {
# 'model_name': model_name
# }
# }
# save_image_path = os.path.join(chat_config['save_path'], 'images')
# os.makedirs(save_image_path, exist_ok=True)
# img.save(os.path.join(save_image_path, img_fn))
# chat_save_path = os.path.join(chat_config['save_path'], save_file)
# with open(chat_save_path, 'a+') as fout:
# fout.write(json.dumps(save_dict) + '\n')
# st.success('Save successfully!')
with col_img:
canvas_result = st_canvas(
fill_color=st.session_state['label_color'] + "77", # Fixed fill color with some opacity
stroke_width=stroke_width,
stroke_color=st.session_state['label_color'] + "77",
background_color="#eee",
background_image=Image.open(image) if image else None,
update_streamlit=True,
width=width,
height=height,
drawing_mode=drawing_mode,
point_display_radius=3 if drawing_mode == 'point' else 0,
key="canvas"
)
if canvas_result.json_data is not None:
update_annotation_states(canvas_result.json_data, ratio, img.size)
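# At this point 'submit_btn' and 'input_message' still hold last run's widget values;
# handling the click here, before the widgets are recreated, lets the text box be cleared safely.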
if st.session_state.get('submit_btn', False):
send_one_message(st.session_state['input_message'], max_new_tokens=max_new_tokens, temperature=temperature)
st.session_state['input_message'] = ""
with input_area:
col3, col4, col5 = st.columns([5, 1, 1])
with col3:
message = st.text_input('User', key="input_message")
with col4:
submit_btn = st.button(label='submit', key='submit_btn')
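# Inject a small script into the parent page so pressing Enter clicks the 'submit' button.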
components.html(
"""
<script>
const doc = window.parent.document;
const buttons = Array.from(doc.querySelectorAll('button[kind=secondary]'));
const submit = buttons.find(el => el.innerText === 'submit');
doc.addEventListener('keydown', function(e) {
switch (e.keyCode) {
case 13: // 13 = Enter
submit.click();
}
});
</script>
""",
height=0,
width=0,
)
with col5:
clear_btn = st.button(label='clear', on_click=clear_messages)
objects = get_objects()
if len(objects):
bbox_cols = st.columns([1 for _ in range(len(objects))])
def on_bbox_button_click(region_token):
def f():
st.session_state['input_message'] += region_token
return f
for i, (obj, bbox_col) in enumerate(zip(objects, bbox_cols)):
with bbox_col:
st.button(label=f'Region-{i}', on_click=on_bbox_button_click(f'[REGION-{i}]'))
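# Brittle CSS hack: these selectors target Streamlit's generated class names to color each
# Region-i button like its bounding box; the commented variants are for other Streamlit versions.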
# css += f"#root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.css-uf99v8.e1g8pov65 > div.block-container.css-z5fcl4.e1g8pov64 > div:nth-child(1) > div > div.css-ocqkz7.esravye3 > div:nth-child(2) > div:nth-child(1) > div > div:nth-child(1) > div > div:nth-child(1) > div > div:nth-child(2) > div:nth-child({i+1}) > div:nth-child(1) > div > div > div > button {{background-color:{obj['stroke'][:7]}; bottom: 0px}} \n" + '\n'
css += f"#root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.css-uf99v8.ea3mdgi5 > div.block-container.css-awvpbp.ea3mdgi4 > div:nth-child(1) > div > div.css-ocqkz7.e1f1d6gn3 > div:nth-child(2) > div:nth-child(1) > div > div:nth-child(1) > div > div:nth-child(1) > div > div:nth-child(3) > div:nth-child({i+1}) > div:nth-child(1) > div > div > div > button {{background-color:{obj['stroke'][:7]}; bottom: 0px}} \n" + '\n'
# css += f"#root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.css-uf99v8.ea3mdgi5 > div.block-container.css-awvpbp.ea3mdgi4 > div:nth-child(1) > div > div.css-ocqkz7.e1f1d6gn3 > div:nth-child(2) > div:nth-child(1) > div > div:nth-child(1) > div > div:nth-child(1) > div > div:nth-child(2) > div:nth-child({i+1}) > div:nth-child(1) > div > div > div > button {{background-color:{obj['stroke'][:7]}; bottom: 0px}} \n" + '\n'
for i, (role, msg) in enumerate(st.session_state['messages']):
with chatbox:
st_message(msg.replace('<image>\n', '', 1), is_user=(role==default_conversation.roles[0]), key=f'{i}-{msg}')
st.markdown("<style>\n" + css + "</style>", unsafe_allow_html=True)
st.markdown(
"""
--------------------
### User Manual
- **Step 1.** Upload an image here
""")
st.image("figures/upload_image.png")
st.markdown(
"""
- **Step 2.** (Optional) You can draw bounding boxes on the image. Each box you draw creates a corresponding button of the same color.
""")
st.image("figures/bbox.png", width=512)
st.markdown(
"""
- **Step 3.** Ask questions. Insert region tokens in the question by clicking on the `Region-i` button. For example:
> What color is the dog in [REGION-0]?
>
> What is the relationship between the dog in [REGION-0] and the dog in [REGION-1]?
**Note**: This demo is in its experimental stage, and we are actively working on improvements.
""")