Spaces:
Runtime error
Runtime error
update chat
Browse files
multimodal/open_flamingo/chat/conversation.py
CHANGED
|
@@ -366,11 +366,25 @@ class Chat:
|
|
| 366 |
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
| 367 |
image_start_index_list = [[x] for x in image_start_index_list]
|
| 368 |
image_nums = [1] * len(input_ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
# and torch.cuda.amp.autocast(dtype=torch.float16)
|
| 370 |
with torch.no_grad():
|
| 371 |
-
outputs = model(
|
| 372 |
-
vision_x=
|
| 373 |
-
lang_x=
|
| 374 |
attention_mask=attention_mask,
|
| 375 |
image_nums=image_nums,
|
| 376 |
image_start_index_list=image_start_index_list,
|
|
@@ -411,7 +425,7 @@ class Chat:
|
|
| 411 |
# # conv.messages[-1][1] = output_text
|
| 412 |
# print(
|
| 413 |
# f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")
|
| 414 |
-
output_text =
|
| 415 |
return output_text, out_image
|
| 416 |
|
| 417 |
def upload_img(self, image, conv, img_list):
|
|
|
|
| 366 |
image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
|
| 367 |
image_start_index_list = [[x] for x in image_start_index_list]
|
| 368 |
image_nums = [1] * len(input_ids)
|
| 369 |
+
added_bbox_list = []
|
| 370 |
+
with torch.inference_mode():
|
| 371 |
+
text_outputs = self.model.generate(
|
| 372 |
+
batch_images,
|
| 373 |
+
input_ids,
|
| 374 |
+
attention_mask=attention_mask,
|
| 375 |
+
max_new_tokens=20,
|
| 376 |
+
# min_new_tokens=8,
|
| 377 |
+
num_beams=1,
|
| 378 |
+
# length_penalty=0,
|
| 379 |
+
image_start_index_list=image_start_index_list,
|
| 380 |
+
image_nums=image_nums,
|
| 381 |
+
added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
|
| 382 |
+
)
|
| 383 |
# and torch.cuda.amp.autocast(dtype=torch.float16)
|
| 384 |
with torch.no_grad():
|
| 385 |
+
outputs = self.model(
|
| 386 |
+
vision_x=batch_images,
|
| 387 |
+
lang_x=input_ids,
|
| 388 |
attention_mask=attention_mask,
|
| 389 |
image_nums=image_nums,
|
| 390 |
image_start_index_list=image_start_index_list,
|
|
|
|
| 425 |
# # conv.messages[-1][1] = output_text
|
| 426 |
# print(
|
| 427 |
# f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")
|
| 428 |
+
output_text = self.tokenizer.decode(text_outputs[0])
|
| 429 |
return output_text, out_image
|
| 430 |
|
| 431 |
def upload_img(self, image, conv, img_list):
|