update chat
Files changed:
- app.py (+25 -9)
- multimodal/open_flamingo/chat/conversation.py (+46 -10)
app.py CHANGED

@@ -237,30 +237,36 @@ def upload_img(gr_img, text_input, chat_state,chatbot):
             value="Start Chatting", interactive=False), chat_state, img_list,chatbot
 
 
-def gradio_ask(user_message, chatbot, chat_state):
+def gradio_ask(user_message, chatbot, chat_state,radio):
     if len(user_message) == 0:
         return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
 
 
-    chat.ask(user_message, chat_state)
+    chat.ask(user_message, chat_state,radio)
     chatbot = chatbot + [[user_message, None]]
     return '', chatbot, chat_state
 
 
-def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
+def gradio_answer(chatbot, chat_state, img_list, radio, text,num_beams, temperature,radio):
     llm_message,image = \
         chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature,
-                    max_length=2000)
+                    max_length=2000,radio = radio,text_input = text)
 
     chatbot[-1][1] = llm_message
     if image==None:
         return chatbot, chat_state, img_list
     else:
         path = build_image(image)
-        chatbot = chatbot + [[(path,)]]
+        chatbot = chatbot + [[None,(path,)]]
         return chatbot, chat_state, img_list
 
-
+task_template = {
+    "Cap": "Summarize the content of the photo <image>.",
+    "VQA": "For this image <image>, I want a simple and direct answer to my question: <question>",
+    "REC": "Can you point out <expr> in the image <image> and provide the coordinates of its location?",
+    "GC": "Can you give me a description of the region <boxes> in image <image>?",
+    "Advanced": "<question>",
+}
 
 with gr.Blocks() as demo:
     gr.Markdown(title)
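Note on this hunk: as committed, the new gradio_answer signature lists radio twice, which Python rejects as a duplicate argument, and image==None is more idiomatically image is None. A minimal corrected sketch of the two handlers, keeping the diff's names and assuming chat, build_image, and gr (gradio) from the surrounding app:

# Sketch only: the committed handlers with the duplicate `radio`
# parameter removed and the None check made idiomatic.
def gradio_ask(user_message, chatbot, chat_state, radio):
    if len(user_message) == 0:
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
    chat.ask(user_message, chat_state, radio)   # radio = selected task template
    chatbot = chatbot + [[user_message, None]]
    return '', chatbot, chat_state

def gradio_answer(chatbot, chat_state, img_list, radio, text, num_beams, temperature):
    llm_message, image = chat.answer(
        conv=chat_state, img_list=img_list, max_new_tokens=300,
        num_beams=1, temperature=temperature, max_length=2000,
        radio=radio, text_input=text,
    )
    chatbot[-1][1] = llm_message
    if image is None:
        return chatbot, chat_state, img_list
    path = build_image(image)                   # save the returned image, get its path
    chatbot = chatbot + [[None, (path,)]]       # append it as a bot-side image message
    return chatbot, chat_state, img_list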
@@ -273,6 +279,9 @@ with gr.Blocks() as demo:
         image = gr.Image(type="pil")
         upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
         clear = gr.Button("Restart")
+        radio = gr.Radio(
+            ["Cap", "VQA", "REC", "Advanced"], label="Task Template", value='Cap',
+        )
 
         num_beams = gr.Slider(
             minimum=1,
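The radio exposes four of the five task_template keys; "GC" is defined in the dict but not selectable here. Nothing else in this diff reads task_template, so if the templates were applied it would presumably be by placeholder substitution, e.g. (a sketch, not committed code; the expression is hypothetical):

# Hypothetical use of task_template: fill the <expr>/<question> slots.
prompt = task_template["REC"].replace("<expr>", "the red mug")
# -> "Can you point out the red mug in the image <image> and provide
#     the coordinates of its location?"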
@@ -296,13 +305,20 @@ with gr.Blocks() as demo:
         chat_state = gr.State()
         img_list = gr.State()
         chatbot = gr.Chatbot(label='Compositional-VLM')
-
+
+
+        # template = gr.Textbox(label='Template', show_label=True, lines=1, interactive=False,
+        #                       value='Provide a comprehensive description of the image <image> and specify the positions of any mentioned objects in square brackets.')
+        # text_input = gr.Textbox(label='<question>', show_label=True, placeholder="Please upload your image first, then input...", lines=3,
+        #                         value=None, visible=False, interactive=False)
+
+        text_input = gr.Textbox(label='User', placeholder='Please upload your image first, then input...', interactive=False)
 
     upload_button.click(upload_img, [image, text_input, chat_state,chatbot],
                         [image, text_input, upload_button, chat_state, img_list,chatbot])
 
-    text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
-        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
+    text_input.submit(gradio_ask, [text_input, chatbot, chat_state,radio], [text_input, chatbot, chat_state]).then(
+        gradio_answer, [chatbot, chat_state, img_list, radio, text_input,num_beams, temperature, radio], [chatbot, chat_state, img_list]
     )
     clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
                 queue=False)
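The submit(...).then(...) chain first runs gradio_ask (append the templated question) and then gradio_answer (generate the reply); the input list passes radio twice to match the duplicated parameter above, and text_input is forwarded so the REC branch can reuse the raw query. With the handlers corrected as sketched after the first hunk, the wiring would drop the second radio:

# Sketch: wiring that matches the corrected gradio_answer signature.
text_input.submit(
    gradio_ask, [text_input, chatbot, chat_state, radio],
    [text_input, chatbot, chat_state],
).then(
    gradio_answer,
    [chatbot, chat_state, img_list, radio, text_input, num_beams, temperature],
    [chatbot, chat_state, img_list],
)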
multimodal/open_flamingo/chat/conversation.py CHANGED

@@ -278,18 +278,34 @@ class Chat:
         # torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
         # self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
 
-    def ask(self, text, conv):
-
-
-
-
+    def ask(self, text, conv,radio):
+        if radio in ["Cap"]:
+            conv.append({
+                "from": "human",
+                "value": "",
+            })
+        elif radio in ["VQA"]:
+            conv.append({
+                "from": "human",
+                "value": f"Answer the question using a single word or phrase.{text}",
+            })
+        elif radio in ["REC"]:
+            conv.append({
+                "from": "human",
+                "value": f"Please provide the bounding box coordinate of the region this sentence describes: {text}.",
+            })
+        else:
+            conv.append({
+                "from": "human",
+                "value": text,
+            })
         # if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
         #         and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
         #     conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
         # else:
         #     conv.append_message(conv.roles[0], text)
 
-    def answer(self, conv, img_list, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9,
+    def answer(self, conv, img_list, radio, text_input, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9,
                repetition_penalty=1.0, length_penalty=1, temperature=1, max_length=2000):
         # conv.append_message(conv.roles[1], None)
         # embs = self.get_context_emb(conv, img_list)
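ask now appends a human turn whose wording depends on the selected task: "Cap" appends an empty value (the caption is elicited by the assistant-side seed added in a later hunk), "VQA" and "REC" wrap the user text in fixed instructions, and the fallback ("Advanced") passes the text through unchanged. Note the committed VQA f-string has no separator before {text}. A sketch of the resulting state, assuming conv is the plain list this method appends to:

conv = []
chat.ask("What color is the car?", conv, "VQA")
# conv[-1] == {
#     "from": "human",
#     "value": "Answer the question using a single word or phrase.What color is the car?",
# }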
@@ -315,7 +331,14 @@
         # output_text = output_text.split('###')[0] # remove the stop sign '###'
         # output_text = output_text.split('Assistant:')[-1].strip()
         # conv.messages[-1][1] = output_text
-
+        visual_token = "<|#visual#|>"
+        previsual_token = "<|#previsual#|>"
+        box_token = "<|#box#|>"
+        prebox_token = "<|#prebox#|>"
+        end_token = "<|#endofobject#|>"
+        object_token = "<|#object#|>"
+        end_of_attr_token = "<|#endofattr#|>"
+        preend_of_attr_token = "<|#preendofattr#|>"
         media_token_id = self.tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
         box_token_id = self.tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
         endofobject_token_id = self.tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
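Eight grounding-marker strings are defined alongside the existing special-token id lookups; in this diff only object_token, end_token, and visual_token are consumed (by the REC branch in the next hunk). If an id were needed for one of the new markers, the file's existing lookup pattern would apply:

# Sketch: same lookup pattern as the media/box/endofobject ids above.
visual_token_id = self.tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]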
@@ -336,10 +359,23 @@
 
         # conversation = []
         human_sentence = None
-
+        if radio in ["Cap","VQA"]:
+            conv.append({
+                "from": "gpt",
+                "value": "",
+            })
+        elif radio in ["REC"]:
+            conv.append(
+                {
                     "from": "gpt",
-            "value":
-        }
+                    "value": object_token + text_input + end_token + visual_token
+                }
+            )
+        else:
+            conv.append({
+                "from": "gpt",
+                "value": "",
+            })
         # while True:
         #     human_sentence = input("### Human: ")
         #     if human_sentence == "#end#":
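For "Cap", "VQA", and the fallback, the assistant turn is seeded empty and generation starts from nothing; for "REC" it is pre-filled with the grounding markers so the model only has to produce the box. With a hypothetical text_input of "the red mug", the seeded value is:

seed = object_token + "the red mug" + end_token + visual_token
# seed == "<|#object#|>the red mug<|#endofobject#|><|#visual#|>"
# Generation continues after <|#visual#|>, where the box prediction for
# the named object is expected.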