Fix streaming_chat
Browse files- modeling_internlm.py +53 -21
modeling_internlm.py
CHANGED
|
@@ -20,6 +20,7 @@
|
|
| 20 |
""" PyTorch InternLM model."""
|
| 21 |
import math
|
| 22 |
from typing import List, Optional, Tuple, Union
|
|
|
|
| 23 |
|
| 24 |
import torch
|
| 25 |
import torch.utils.checkpoint
|
|
@@ -784,7 +785,6 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
|
|
| 784 |
do_sample: bool = True,
|
| 785 |
temperature: float = 0.8,
|
| 786 |
top_p: float = 0.8,
|
| 787 |
-
eos_token_id = (2, 103028),
|
| 788 |
**kwargs):
|
| 789 |
inputs = self.build_inputs(tokenizer, query, history)
|
| 790 |
inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
|
|
@@ -794,7 +794,6 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
|
|
| 794 |
do_sample=do_sample,
|
| 795 |
temperature=temperature,
|
| 796 |
top_p=top_p,
|
| 797 |
-
eos_token_id=list(eos_token_id),
|
| 798 |
**kwargs)
|
| 799 |
outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]):]
|
| 800 |
response = tokenizer.decode(outputs, skip_special_tokens=True)
|
|
@@ -811,38 +810,71 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
|
|
| 811 |
do_sample: bool = True,
|
| 812 |
temperature: float = 0.8,
|
| 813 |
top_p: float = 0.8,
|
| 814 |
-
eos_token_id = (2, 103028),
|
| 815 |
**kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 816 |
class ChatStreamer(BaseStreamer):
|
| 817 |
def __init__(self, tokenizer) -> None:
|
| 818 |
super().__init__()
|
| 819 |
self.tokenizer = tokenizer
|
| 820 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 821 |
def put(self, value):
|
| 822 |
if len(value.shape) > 1 and value.shape[0] > 1:
|
| 823 |
raise ValueError("ChatStreamer only supports batch size 1")
|
| 824 |
elif len(value.shape) > 1:
|
| 825 |
value = value[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 826 |
token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
|
| 827 |
if token.strip() != "<eoa>":
|
| 828 |
-
|
| 829 |
-
|
|
|
|
|
|
|
| 830 |
def end(self):
|
| 831 |
-
|
| 832 |
-
|
| 833 |
-
|
| 834 |
-
|
| 835 |
-
|
| 836 |
-
|
| 837 |
-
|
| 838 |
-
|
| 839 |
-
|
| 840 |
-
|
| 841 |
-
|
| 842 |
-
|
| 843 |
-
|
| 844 |
-
|
| 845 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 846 |
|
| 847 |
@add_start_docstrings(
|
| 848 |
"""
|
|
|
|
| 20 |
""" PyTorch InternLM model."""
|
| 21 |
import math
|
| 22 |
from typing import List, Optional, Tuple, Union
|
| 23 |
+
import threading, queue
|
| 24 |
|
| 25 |
import torch
|
| 26 |
import torch.utils.checkpoint
|
|
|
|
| 785 |
do_sample: bool = True,
|
| 786 |
temperature: float = 0.8,
|
| 787 |
top_p: float = 0.8,
|
|
|
|
| 788 |
**kwargs):
|
| 789 |
inputs = self.build_inputs(tokenizer, query, history)
|
| 790 |
inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
|
|
|
|
| 794 |
do_sample=do_sample,
|
| 795 |
temperature=temperature,
|
| 796 |
top_p=top_p,
|
|
|
|
| 797 |
**kwargs)
|
| 798 |
outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]):]
|
| 799 |
response = tokenizer.decode(outputs, skip_special_tokens=True)
|
|
|
|
| 810 |
do_sample: bool = True,
|
| 811 |
temperature: float = 0.8,
|
| 812 |
top_p: float = 0.8,
|
|
|
|
| 813 |
**kwargs):
|
| 814 |
+
"""
|
| 815 |
+
Return a generator in format: (response, history)
|
| 816 |
+
Eg.
|
| 817 |
+
('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
|
| 818 |
+
('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
|
| 819 |
+
"""
|
| 820 |
+
|
| 821 |
+
response_queue = queue.Queue(maxsize=20)
|
| 822 |
+
|
| 823 |
class ChatStreamer(BaseStreamer):
    """Streamer that publishes the growing chat response on ``response_queue``.

    Each queue item is a ``(response, history)`` tuple, where ``history`` is
    the caller's history extended by the ``(query, response)`` pair built so
    far. ``None`` is enqueued by ``end`` as the end-of-stream marker.
    """

    def __init__(self, tokenizer) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        # ``response_queue``, ``query`` and ``history`` come from the enclosing scope.
        self.queue = response_queue
        self.query = query
        self.history = history
        self.response = ""
        self.received_inputs = False
        # Publish an initial item carrying the still-empty response.
        initial_history = history + [(self.query, self.response)]
        self.queue.put((self.response, initial_history))

    def put(self, value):
        """Decode the last element of ``value`` and enqueue the updated response.

        Raises:
            ValueError: if ``value`` has more than one dimension and its first
                dimension is larger than 1.
        """
        if len(value.shape) > 1:
            if value.shape[0] > 1:
                raise ValueError("ChatStreamer only supports batch size 1")
            value = value[0]

        if not self.received_inputs:
            # The first received value is input_ids, ignore here
            self.received_inputs = True
            return

        piece = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
        if piece.strip() == "<eoa>":
            # The "<eoa>" marker is never appended to the response.
            return

        self.response += piece
        updated_history = self.history + [(self.query, self.response)]
        self.queue.put((self.response, updated_history))

    def end(self):
        """Enqueue ``None`` to mark the end of the stream."""
        self.queue.put(None)
|
| 853 |
+
|
| 854 |
+
def stream_producer():
    """Run ``self.chat`` with a ``ChatStreamer`` attached.

    Used as the target of the producer thread. All arguments are taken from
    the enclosing scope and forwarded to ``self.chat`` unchanged.

    Returns:
        Whatever ``self.chat`` returns.

    Raises:
        Any exception raised by ``self.chat`` or by building the streamer;
        it is re-raised after the end-of-stream marker has been enqueued.
    """
    try:
        return self.chat(
            tokenizer=tokenizer,
            query=query,
            streamer=ChatStreamer(tokenizer=tokenizer),
            history=history,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            **kwargs
        )
    except Exception:
        # On the success path ``ChatStreamer.end`` enqueues the ``None`` marker.
        # If generation fails first, nothing would ever enqueue it and the
        # reader blocked in ``response_queue.get()`` would wait forever, so
        # enqueue the marker here before letting the error propagate.
        response_queue.put(None)
        raise
|
| 866 |
+
|
| 867 |
+
def consumer():
    """Start the producer thread and yield queue items until ``None`` arrives.

    Yields:
        Each item read from ``response_queue``, in order, stopping (without
        yielding it) at the first ``None``.
    """
    worker = threading.Thread(target=stream_producer)
    worker.start()
    # ``iter(callable, sentinel)`` calls ``response_queue.get()`` repeatedly and
    # stops as soon as it returns the ``None`` end-of-stream marker.
    for item in iter(response_queue.get, None):
        yield item
|
| 875 |
+
|
| 876 |
+
return consumer()
|
| 877 |
+
|
| 878 |
|
| 879 |
@add_start_docstrings(
|
| 880 |
"""
|