PVIT carboncoo committed on
Commit
3576efa
·
0 Parent(s):

Duplicate from carboncoo/PVIT

Browse files

Co-authored-by: Chi Chen <carboncoo@users.noreply.huggingface.co>

Files changed (7) hide show
  1. .gitattributes +35 -0
  2. Home.py +616 -0
  3. README.md +14 -0
  4. configs/chat.yaml +2 -0
  5. figures/bbox.png +0 -0
  6. figures/upload_image.png +0 -0
  7. requirements.txt +4 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Home.py ADDED
@@ -0,0 +1,616 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import copy
4
+ import json
5
+ import yaml
6
+ import random
7
+ import streamlit as st
8
+ from PIL import Image, ImageDraw
9
+ import requests
10
+ import base64
11
+ from io import BytesIO
12
+ import seaborn as sns
13
+ import matplotlib.pyplot as plt
14
+ import pandas as pd
15
+
16
+ from collections import defaultdict
17
+ import datetime
18
+ import json
19
+ import os
20
+ import time
21
+
22
+ import gradio as gr
23
+ import requests
24
+
25
+ import hashlib
26
+ import time
27
+
28
+ import streamlit as st
29
+ import streamlit.components.v1 as components
30
+ from streamlit_chat import message as st_message
31
+ from streamlit_drawable_canvas import st_canvas
32
+
33
+ st.set_page_config(page_title="Model Chat", page_icon="🌍", layout="wide", initial_sidebar_state="collapsed")
34
+
35
+ col_img, col_chat = st.columns([1, 1])
36
+ with col_chat:
37
+ with st.container():
38
+ input_area = st.container()
39
+ chatbox = st.container()
40
+
41
+ # ==================== Conversation =================== #
42
+ import dataclasses
43
+ from enum import auto, Enum
44
+ from typing import List, Tuple
45
+
46
+
47
class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()  # one separator string appended after every turn
    TWO = auto()     # alternating separators (sep after user, sep2 after assistant)
51
+
52
+ import re
53
+ # Hack for displaying Region in Chatbot
54
def convert_region_tags(text):
    """HTML-escape <Region>...</Region> spans so they display literally in the chat UI."""
    def _escape(match):
        inner = match.group(1).replace('<', '&lt;').replace('>', '&gt;')
        return '&lt;Region&gt;' + inner + '&lt;/Region&gt;'
    return re.sub(r'<Region>(.*?)<\/Region>', _escape, text)
58
+
59
@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    Messages are ``[role, message]`` pairs; a message that carries an image is
    a tuple ``(text, PIL.Image, image_process_mode)``.
    """
    system: str                 # system prompt prepended to every serialized prompt
    roles: List[str]            # (user_role, assistant_role)
    messages: List[List[str]]
    offset: int                 # number of leading messages hidden from display
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None            # second separator, used only by SeparatorStyle.TWO
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        """Serialize the conversation into a single prompt string for the model."""
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    # Empty message: open the turn so the model completes it.
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        """Append one ``[role, message]`` turn to the history."""
        self.messages.append([role, message])

    @staticmethod
    def _expand2square(pil_img, background_color=(122, 116, 104)):
        """Pad `pil_img` onto a square canvas of `background_color`, centered."""
        if pil_img.size[0] == pil_img.size[1]:
            return pil_img
        width, height = pil_img.size
        if width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
        else:
            result = Image.new(pil_img.mode, (height, height), background_color)
            result.paste(pil_img, ((height - width) // 2, 0))
        return result

    @staticmethod
    def _resize_for_display(image):
        """Downscale keeping aspect ratio: short edge <= 400 px, long edge <= 800 px.

        Extracted from the previously duplicated logic in get_images() and
        to_gradio_chatbot().
        """
        max_hw, min_hw = max(image.size), min(image.size)
        aspect_ratio = max_hw / min_hw
        max_len, min_len = 800, 400
        shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
        longest_edge = int(shortest_edge * aspect_ratio)
        W, H = image.size
        if H > W:
            H, W = longest_edge, shortest_edge
        else:
            H, W = shortest_edge, longest_edge
        return image.resize((W, H))

    def get_images(self, return_pil=False):
        """Collect images from user turns.

        Returns PIL images when `return_pil` is True, otherwise base64 JPEG strings.
        Raises ValueError on an unknown image_process_mode.
        """
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:  # even positions are user turns, which carry the images
                if type(msg) is tuple:
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        image = self._expand2square(image)
                    elif image_process_mode == "Crop":
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((224, 224))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
                    image = self._resize_for_display(image)
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="JPEG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        """Render the history as ``[[user, assistant], ...]`` with inline <img> tags."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    msg, image, image_process_mode = msg
                    msg = convert_region_tags(msg)
                    image = self._resize_for_display(image)
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = msg.replace('<image>', img_str)
                else:
                    msg = convert_region_tags(msg)
                ret.append([msg, None])
            else:
                if isinstance(msg, str):
                    msg = convert_region_tags(msg)
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a copy with freshly created message pairs.

        Fix: the original omitted `version`, so every copy silently reset it
        to "Unknown".
        """
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        """JSON-serializable snapshot; image tuples collapse to their text part."""
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }
213
+
214
# Vicuna v1.1 prompt template: two-separator style, " " after user turns and
# "</s>" after assistant turns.
conv_vicuna_v1_1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    # Tuple default keeps the module-level template immutable; callers must
    # .copy() (which rebuilds messages as lists) before appending turns.
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

default_conversation = conv_vicuna_v1_1
227
+
228
+ # ==================== Chat =================== #
229
+
230
+
231
def convert_bbox_to_region(bbox_xywh, image_width, image_height):
    """Convert an (x, y, w, h) pixel box into the model's <Region> token string.

    Corner coordinates are normalized by the image size and scaled to the
    integer range [0, 1000).
    """
    x, y, w, h = bbox_xywh
    corners = (x, y, x + w, y + h)
    scales = (image_width, image_height, image_width, image_height)
    coords = [int(c / s * 1000) for c, s in zip(corners, scales)]
    return "<Region><L{}><L{}><L{}><L{}></Region>".format(*coords)
250
+
251
def load_config(config_fn, field='chat'):
    """Load one top-level section from a YAML config file.

    Args:
        config_fn: path to the YAML file.
        field: top-level key to return (defaults to 'chat').

    Raises:
        KeyError: if `field` is not present in the file.
    """
    # `with` closes the handle (the original leaked an open file), and
    # safe_load refuses arbitrary Python object tags, unlike yaml.Loader.
    with open(config_fn) as f:
        config = yaml.safe_load(f)
    return config[field]
254
+
255
+ chat_config = load_config('configs/chat.yaml')
256
+
257
def get_model_list():
    """Names of the chat models this demo can talk to (currently one fixed entry)."""
    available = ['PVIT_v1.0']
    return available
259
+
260
def change_model(model_name):
    """Switch the active model and reset the chat when the selection changes."""
    if model_name != st.session_state.get('model_name', ''):
        # Store the actual selection. The original hard-coded 'PVIT_v1.0',
        # which only worked because get_model_list() has a single entry.
        st.session_state['model_name'] = model_name
        st.session_state['model_addr'] = chat_config['model_addr']
        st.session_state['messages'] = []
265
+
266
+
267
def init_chat(image=None):
    """Store the uploaded image and make sure chat session keys exist."""
    st.session_state['image'] = image
    for key, default in (('input_message', ''), ('messages', [])):
        if key not in st.session_state:
            st.session_state[key] = default
273
+
274
def clear_messages():
    """Wipe the conversation history and any pending input text."""
    for key, value in (('messages', []), ('input_message', '')):
        st.session_state[key] = value
277
+
278
def encode_img(img):
    """Return the image as a base64-encoded JPEG string.

    Accepts a file path, a PIL image, or a file-like object that already
    holds encoded bytes.
    """
    if isinstance(img, str):
        buffer = BytesIO()
        Image.open(img).convert('RGB').save(buffer, format="JPEG")
    elif isinstance(img, Image.Image):
        buffer = BytesIO()
        img.save(buffer, format="JPEG")
    else:
        buffer = img
    encoded = base64.b64encode(buffer.getvalue())  # image in binary format
    return encoded.decode()
291
+
292
+
293
def send_one_message(message, max_new_tokens=32, temperature=0.7):
    """Append a user message, query the model worker, and record its reply.

    Serializes the whole conversation, streams the worker's response, and
    stores only the final (complete) chunk in session state.
    """
    conv = default_conversation.copy()
    if 'messages' not in st.session_state:
        st.session_state['messages'] = []
    if len(st.session_state['messages']) == 0:
        # The first turn must carry the <image> placeholder for the model.
        if '<image>' not in message:
            message = '<image>\n' + message
    st.session_state['messages'].append([conv.roles[0], message])
    conv.messages = copy.deepcopy(st.session_state['messages'])
    conv.append_message(conv.roles[1], None)  # open the assistant turn for completion
    prompt = conv.get_prompt()

    # Substitute [REGION-i] placeholders with serialized bounding-box tokens.
    # NOTE(review): init_canvas() sets st.session_state['canvas_result'] to
    # None, so .get() here can raise before any canvas update — confirm.
    if 'canvas_result' in st.session_state:
        objects = st.session_state['canvas_result'].get('objects', [])
        for i, obj in enumerate(objects):
            prompt = prompt.replace(f'[REGION-{i}]', obj['bbox_label'])

    headers = {"User-Agent": "LLaVA Client"}
    pload = {
        "prompt": prompt,
        "images": [st.session_state['image']],  # base64 JPEG from encode_img()
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "stop": conv.sep2,
    }
    print(prompt)
    response = requests.post(st.session_state['model_addr'] + "/worker_generate_stream", headers=headers,
                             json=pload, stream=True)
    result = ""
    # The worker streams NUL-delimited JSON chunks, each holding the text so
    # far; only the last chunk is kept.
    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data_t = json.loads(chunk.decode("utf-8"))
            output = data_t["text"].split(conv.roles[1]+':')[-1]
            result = output

    st.session_state['messages'].append([conv.roles[1], result])
340
+
341
+
342
+ # Customize Streamlit UI using CSS # background-color: #eb5424;
343
+ st.markdown("""
344
+ <style>
345
+ div.stButton > button:first-child {
346
+ background-color: #eb5424;
347
+ color: white;
348
+ font-size: 20px;
349
+ font-weight: bold;
350
+ border-radius: 0.5rem;
351
+ padding: 0.5rem 1rem;
352
+ border: none;
353
+ box-shadow: 0 0.5rem 1rem rgba(0,0,0,0.15);
354
+ width: 300 px;
355
+ height: 42px;
356
+ transition: all 0.2s ease-in-out;
357
+ }
358
+ div.stButton > button:first-child:hover {
359
+ transform: translateY(-3px);
360
+ box-shadow: 0 1rem 2rem rgba(0,0,0,0.15);
361
+ }
362
+ div.stButton > button:first-child:active {
363
+ transform: translateY(-1px);
364
+ box-shadow: 0 0.5rem 1rem rgba(0,0,0,0.15);
365
+ }
366
+ div.stButton > button:focus:not(:focus-visible) {
367
+ color: #FFFFFF;
368
+ }
369
+ @media only screen and (min-width: 768px) {
370
+ /* For desktop: */
371
+ div.stButton > button:first-child {
372
+ background-color: #eb5424;
373
+ color: white;
374
+ font-size: 20px;
375
+ font-weight: bold;
376
+ border-radius: 0.5rem;
377
+ padding: 0.5rem 1rem;
378
+ border: none;
379
+ box-shadow: 0 0.5rem 1rem rgba(0,0,0,0.15);
380
+ width: 300 px;
381
+ height: 42px;
382
+ transition: all 0.2s ease-in-out;
383
+ position: relative;
384
+ bottom: -32px;
385
+ right: 0px;
386
+ }
387
+ div.stButton > button:first-child:hover {
388
+ transform: translateY(-3px);
389
+ box-shadow: 0 1rem 2rem rgba(0,0,0,0.15);
390
+ }
391
+ div.stButton > button:first-child:active {
392
+ transform: translateY(-1px);
393
+ box-shadow: 0 0.5rem 1rem rgba(0,0,0,0.15);
394
+ }
395
+ div.stButton > button:focus:not(:focus-visible) {
396
+ color: #FFFFFF;
397
+ }
398
+ input {
399
+ border-radius: 0.5rem;
400
+ padding: 0.5rem 1rem;
401
+ border: none;
402
+ box-shadow: 0 0.5rem 1rem rgba(0,0,0,0.15);
403
+ transition: all 0.2s ease-in-out;
404
+ height: 40px;
405
+ }
406
+ }
407
+ </style>
408
+ """, unsafe_allow_html=True)
409
+
410
+ # ==================== Draw Bounding Boxes =================== #
411
+
412
+ COLORS = sns.color_palette("tab10", n_colors=10).as_hex()
413
+ random.Random(32).shuffle(COLORS)
414
+
415
def update_annotation_states(canvas_result, ratio, img_size):
    """Attach <Region> labels to canvas shapes and advance the label color.

    Args:
        canvas_result: canvas JSON data; each entry in 'objects' gets a
            'bbox_label' added.
        ratio: canvas-to-original-image pixel scale factor.
        img_size: (width, height) of the original image.
    """
    for obj in canvas_result['objects']:
        top = obj["top"] * ratio
        left = obj["left"] * ratio
        width = obj["width"] * ratio
        height = obj["height"] * ratio
        obj['bbox_label'] = convert_bbox_to_region([left, top, width, height], img_size[0], img_size[1])
    st.session_state['canvas_result'] = canvas_result
    # Cycle through the palette: the original indexed COLORS[len+1] directly
    # and raised IndexError once enough boxes were drawn.
    st.session_state['label_color'] = COLORS[(len(st.session_state['canvas_result']['objects']) + 1) % len(COLORS)]
424
+
425
def init_canvas():
    """Make sure canvas-related session keys exist before the first render."""
    for key, value in (('canvas_result', None), ('label_color', COLORS[0])):
        if key not in st.session_state:
            st.session_state[key] = value
430
+
431
def input_message(msg):
    # Callback: copy the text-input widget's value into session state.
    st.session_state['input_message'] = msg
433
+
434
+
435
def get_objects():
    """Return the list of shapes drawn on the canvas (empty if none yet)."""
    canvas_result = st.session_state.get('canvas_result', {})
    if canvas_result is None:
        return []
    return canvas_result.get('objects', [])
442
+
443
def format_object_str(input_str):
    """Replace [REGION-i] placeholders with the serialized region tokens.

    Returns `input_str` unchanged when no canvas annotations exist.
    """
    # `or {}` guards against the None that init_canvas() stores; the original
    # raised AttributeError on .get() in that case.
    canvas_result = st.session_state.get('canvas_result') or {}
    for i, obj in enumerate(canvas_result.get('objects', [])):
        input_str = input_str.replace(f'[REGION-{i}]', obj['bbox_label'])
    return input_str
449
+
450
+ # select model
451
+ model_list = get_model_list()
452
+ with col_img:
453
+ model_name = st.selectbox(
454
+ 'Choose a model to chat with',
455
+ model_list
456
+ )
457
+ change_model(model_name)
458
+
459
+ css = ''
460
+ # upload image
461
+ with col_img:
462
+ image = st.file_uploader("Chat with Image", type=["png", "jpg", "jpeg"], on_change=clear_messages)
463
+ img_fn = image.name if image is not None else None
464
+ if image:
465
+ init_chat(encode_img(image))
466
+ init_canvas()
467
+
468
+ img = Image.open(image).convert('RGB')
469
+
470
+ width = 700
471
+ height = round(width * img.size[1] * 1.0 / img.size[0])
472
+ ratio = img.size[0] / width
473
+
474
+ with st.sidebar:
475
+ max_new_tokens = st.number_input('max_new_tokens', min_value=1, max_value=1024, value=128)
476
+ temperature = st.number_input('temperature', min_value=0.0, max_value=1.0, value=0.0)
477
+ drawing_mode = st.selectbox(
478
+ "Drawing tool:", ("rect", "point", "line", "circle"),
479
+ )
480
+ drawing_mode = "transform" if st.checkbox("Move ROIs", False) else drawing_mode
481
+ stroke_width = st.slider("Stroke width: ", 1, 25, 3)
482
+ # bg_color = st.color_picker("Background color: ", "#eee", key="bg_color")
483
+
484
+ # save_file = st.text_input("Save File", value="saved.jsonl")
485
+ # save_button = st.button(label='Save')
486
+
487
+ # if save_button:
488
+ # if img_fn is None:
489
+ # st.warning("Please upload an image first!")
490
+ # else:
491
+ # conversations_to_save = [{'from': role, 'value': format_object_str(conv)} for (role, conv) in st.session_state['messages']]
492
+ # model_name = st.session_state['model_name']
493
+ # save_dict = {
494
+ # 'image': img_fn,
495
+ # 'conversations': conversations_to_save,
496
+ # 'info': {
497
+ # 'model_name': model_name
498
+ # }
499
+ # }
500
+
501
+ # save_image_path = os.path.join(chat_config['save_path'], 'images')
502
+ # os.makedirs(save_image_path, exist_ok=True)
503
+
504
+ # img.save(os.path.join(save_image_path, img_fn))
505
+
506
+ # chat_save_path = os.path.join(chat_config['save_path'], save_file)
507
+ # with open(chat_save_path, 'a+') as fout:
508
+ # fout.write(json.dumps(save_dict) + '\n')
509
+
510
+ # st.success('Save successfully!')
511
+
512
+ with col_img:
513
+ canvas_result = st_canvas(
514
+ fill_color=st.session_state['label_color'] + "77", # Fixed fill color with some opacity
515
+ stroke_width=stroke_width,
516
+ stroke_color=st.session_state['label_color'] + "77",
517
+ background_color="#eee",
518
+ background_image=Image.open(image) if image else None,
519
+ update_streamlit=True,
520
+ width=width,
521
+ height=height,
522
+ drawing_mode=drawing_mode,
523
+ point_display_radius=3 if drawing_mode == 'point' else 0,
524
+ key="canvas"
525
+ )
526
+
527
+ if canvas_result.json_data is not None:
528
+ update_annotation_states(canvas_result.json_data, ratio, img.size)
529
+
530
+ if st.session_state.get('submit_btn', False):
531
+ send_one_message(st.session_state['input_message'], max_new_tokens=max_new_tokens, temperature=temperature)
532
+ st.session_state['input_message'] = ""
533
+
534
+ with input_area:
535
+ col3, col4, col5 = st.columns([5, 1, 1])
536
+
537
+ with col3:
538
+ message = st.text_input('User', key="input_message")
539
+
540
+ with col4:
541
+ submit_btn = st.button(label='submit', key='submit_btn')
542
+
543
+ components.html(
544
+ """
545
+ <script>
546
+ const doc = window.parent.document;
547
+ buttons = Array.from(doc.querySelectorAll('button[kind=secondary]'));
548
+ const submit = buttons.find(el => el.innerText === 'submit');
549
+
550
+ doc.addEventListener('keydown', function(e) {
551
+ switch (e.keyCode) {
552
+ case 13: // (37 = enter)
553
+ submit.click();
554
+ }
555
+ });
556
+ </script>
557
+ """,
558
+ height=0,
559
+ width=0,
560
+ )
561
+
562
+ with col5:
563
+ clear_btn = st.button(label='clear', on_click=clear_messages)
564
+
565
+
566
+ objects = get_objects()
567
+
568
+ if len(objects):
569
+ bbox_cols = st.columns([1 for _ in range(len(objects))])
570
+
571
def on_bbox_button_click(text):
    """Return a button callback that appends `text` to the pending input.

    The parameter was renamed from `str`, which shadowed the builtin.
    """
    def f():
        st.session_state['input_message'] += text
    return f
575
+
576
+ for i, (obj, bbox_col) in enumerate(zip(objects, bbox_cols)):
577
+ with bbox_col:
578
+ st.button(label=f'Region-{i}', on_click=on_bbox_button_click(f'[REGION-{i}]'))
579
+ # css += f"#root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.css-uf99v8.e1g8pov65 > div.block-container.css-z5fcl4.e1g8pov64 > div:nth-child(1) > div > div.css-ocqkz7.esravye3 > div:nth-child(2) > div:nth-child(1) > div > div:nth-child(1) > div > div:nth-child(1) > div > div:nth-child(2) > div:nth-child({i+1}) > div:nth-child(1) > div > div > div > button {{background-color:{obj['stroke'][:7]}; bottom: 0px}} \n" + '\n'
580
+ css += f"#root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.css-uf99v8.ea3mdgi5 > div.block-container.css-awvpbp.ea3mdgi4 > div:nth-child(1) > div > div.css-ocqkz7.e1f1d6gn3 > div:nth-child(2) > div:nth-child(1) > div > div:nth-child(1) > div > div:nth-child(1) > div > div:nth-child(3) > div:nth-child({i+1}) > div:nth-child(1) > div > div > div > button {{background-color:{obj['stroke'][:7]}; bottom: 0px}} \n" + '\n'
581
+ # css += f"#root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.css-uf99v8.ea3mdgi5 > div.block-container.css-awvpbp.ea3mdgi4 > div:nth-child(1) > div > div.css-ocqkz7.e1f1d6gn3 > div:nth-child(2) > div:nth-child(1) > div > div:nth-child(1) > div > div:nth-child(1) > div > div:nth-child(2) > div:nth-child({i+1}) > div:nth-child(1) > div > div > div > button {{background-color:{obj['stroke'][:7]}; bottom: 0px}} \n" + '\n'
582
+
583
+ for i, (role, msg) in enumerate(st.session_state['messages']):
584
+ with chatbox:
585
+ st_message(msg.lstrip('<image>\n'), is_user=(role==default_conversation.roles[0]), key=f'{i}-{msg}')
586
+
587
+ st.markdown("<style>\n" + css + "</style>", unsafe_allow_html=True)
588
+
589
+ st.markdown(
590
+ """
591
+ --------------------
592
+ ### User Manual
593
+
594
+ - **Step 1.** Upload an image here
595
+ """)
596
+
597
+ st.image("figures/upload_image.png")
598
+
599
+ st.markdown(
600
+ """
601
+ - **Step 2.** (Optional) You can draw bounding boxes on the image. Each box you draw creates a corresponding button of the same color.
602
+ """)
603
+
604
+ st.image("figures/bbox.png", width=512)
605
+
606
+ st.markdown(
607
+ """
608
+ - **Step 3.** Ask questions. Insert region tokens in the question by clicking on the `Region-i` button. For example:
609
+
610
+ > What color is the dog in [REGION-0]?
611
+
612
+ > What is the relationship between the dog in [REGION-0] and the dog in [REGION-1]?
613
+
614
+ **Note**: This demo is in its experimental stage, and we are actively working on improvements.
615
+
616
+ """)
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: PVIT
3
+ emoji: 🐢
4
+ colorFrom: blue
5
+ colorTo: yellow
6
+ sdk: streamlit
7
+ sdk_version: 1.26.0
8
+ app_file: Home.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ duplicated_from: carboncoo/PVIT
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
configs/chat.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ chat:
2
+ model_addr: http://demo.shuzibeijing.cn:40021
figures/bbox.png ADDED
figures/upload_image.png ADDED
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ streamlit-drawable-canvas
2
+ streamlit_chat
3
+ seaborn
4
+ pandas