add stream params for chat and char_crop
#24
by
weege007 - opened
- modeling_GOT.py +5 -5
modeling_GOT.py
CHANGED
|
@@ -484,7 +484,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 484 |
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
| 485 |
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
| 486 |
|
| 487 |
-
def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
|
| 488 |
|
| 489 |
self.disable_torch_init()
|
| 490 |
|
|
@@ -563,8 +563,8 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 563 |
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 564 |
keywords = [stop_str]
|
| 565 |
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
| 566 |
-
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 567 |
-
|
| 568 |
if stream_flag:
|
| 569 |
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 570 |
output_ids = self.generate(
|
|
@@ -728,7 +728,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 728 |
return processed_images
|
| 729 |
|
| 730 |
|
| 731 |
-
def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
|
| 732 |
# Model
|
| 733 |
self.disable_torch_init()
|
| 734 |
multi_page=False
|
|
@@ -817,7 +817,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 817 |
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 818 |
keywords = [stop_str]
|
| 819 |
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
| 820 |
-
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 821 |
|
| 822 |
if stream_flag:
|
| 823 |
with torch.autocast("cuda", dtype=torch.bfloat16):
|
|
|
|
| 484 |
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
| 485 |
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
| 486 |
|
| 487 |
+
def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False, streamer=None):
|
| 488 |
|
| 489 |
self.disable_torch_init()
|
| 490 |
|
|
|
|
| 563 |
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 564 |
keywords = [stop_str]
|
| 565 |
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
| 566 |
+
streamer = streamer if streamer else TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 567 |
+
|
| 568 |
if stream_flag:
|
| 569 |
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 570 |
output_ids = self.generate(
|
|
|
|
| 728 |
return processed_images
|
| 729 |
|
| 730 |
|
| 731 |
+
def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False, streamer=None):
|
| 732 |
# Model
|
| 733 |
self.disable_torch_init()
|
| 734 |
multi_page=False
|
|
|
|
| 817 |
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 818 |
keywords = [stop_str]
|
| 819 |
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
| 820 |
+
streamer = streamer if streamer else TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 821 |
|
| 822 |
if stream_flag:
|
| 823 |
with torch.autocast("cuda", dtype=torch.bfloat16):
|