shunxing1234 committed on
Commit
7fe8bf7
·
verified ·
1 Parent(s): c54a749

Delete generation_utils.py

Browse files
Files changed (1) hide show
  1. generation_utils.py +0 -162
generation_utils.py DELETED
@@ -1,162 +0,0 @@
1
- from typing import Optional
2
- from collections import deque
3
- from queue import Queue
4
- import copy
5
-
6
-
7
class History:
    """Conversation history backed by a deque of message dicts.

    Each message is a dict with at least ``"role"`` and ``"content"``; on
    insertion the tokenizer is applied to the content so every stored message
    also carries ``"input_ids"`` and ``"attention_mask"``.
    """

    def __init__(self, tokenizer, history):
        """Initialize from a list of message dicts (may be empty or falsy).

        :param tokenizer: callable mapping a content string to a dict with
            at least "input_ids" and "attention_mask".
        :param history: iterable of message dicts, or a falsy value for empty.
        """
        # deque so both ends support O(1) push/pop (see append_left/pop_left)
        self.input_history = deque()
        self.tokenizer = tokenizer
        if history:
            self._transfer_from_list(history)

    def _transfer_from_list(self, history):
        # Tokenize each incoming message. NOTE: the token ids produced here
        # may differ from the ids the model actually generated for this text.
        for message in history:
            content = message.get("content")
            message.update(self.tokenizer(content))
            self.input_history.append(message)

    def append(self, message):
        """Append a message on the right; tokenize it if not yet tokenized."""
        content = message.get("content")
        if "input_ids" not in message or "attention_mask" not in message:
            message.update(self.tokenizer(content))
        self.input_history.append(message)

    def append_left(self, message):
        """Append a message on the left; tokenize it if not yet tokenized."""
        content = message.get("content")
        if "input_ids" not in message or "attention_mask" not in message:
            message.update(self.tokenizer(content))
        self.input_history.appendleft(message)

    def pop(self):
        """Remove and return the rightmost (most recent) message."""
        return self.input_history.pop()

    def pop_left(self):
        """Remove and return the leftmost (oldest) message."""
        # BUG FIX: deque has no pop_left(); the correct method is popleft().
        # The original raised AttributeError on every call.
        return self.input_history.popleft()

    def update(self, message):
        """Replace the most recent message with *message*."""
        self.input_history.pop()
        self.append(message)

    def __len__(self):
        return len(self.input_history)

    def __str__(self):
        return str(self.input_history)

    def __copy__(self):
        new_instance = type(self)(self.tokenizer, [])
        new_instance.input_history = copy.copy(self.input_history)
        return new_instance

    def __deepcopy__(self, memodict=None):
        # BUG FIX: mutable default argument ({}) replaced with None; the
        # memo dict is unused by this implementation anyway.
        new_instance = type(self)(self.tokenizer, [])
        new_instance.input_history = copy.deepcopy(self.input_history)
        return new_instance
65
-
66
-
67
class TelechatIterTextStreamer:
    """Iterator-style text streamer, rewritten with reference to
    transformers' TextIteratorStreamer.

    ``put`` receives generated token ids, re-decodes the running sequence,
    and pushes newly printable text chunks (paired with the running history)
    onto a queue that callers consume by iterating this object. ``end``
    flushes the remainder and enqueues a stop sentinel.
    """

    def __init__(
            self, tokenizer, history: Optional["History"] = None, skip_prompt: bool = False,
            timeout: Optional[float] = None, **decode_kwargs
    ):
        self.tokenizer = tokenizer
        self.history = history
        self.skip_prompt = skip_prompt
        self.timeout = timeout
        self.decode_kwargs = decode_kwargs

        self.text_queue = Queue()
        self.cache_time = 0          # consecutive non-printable decodes so far
        self.text_until = ""         # all text emitted so far
        self.token_until = []        # all token ids received so far
        self.stop_signal = None      # sentinel that terminates iteration
        self.next_tokens_are_prompt = True

        # BUG FIX: history defaults to None; the original unconditionally
        # called self.history.append(...) and crashed when none was given.
        if self.history is not None:
            self.history.append({"role": "bot", "content": self.text_until})

    def put(self, value):
        """Accept a batch of generated token ids and queue printable text.

        ``value`` is assumed to be a tensor-like object exposing ``.shape``,
        indexing and ``.tolist()`` — TODO confirm against the generate() caller.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        # Optionally swallow the first call, which carries the prompt tokens.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Never emit the EOS token itself.
        if value[-1] == self.tokenizer.eos_token_id:
            return

        # Re-decode the whole sequence each time: decoding a partial
        # multi-byte token in isolation would yield replacement characters.
        self.token_until.extend(value.tolist())
        text = self.tokenizer.decode(self.token_until, **self.decode_kwargs)

        if self._is_printable(text) or self.cache_time >= 6:
            output_text = text[len(self.text_until):]
            self.text_until = text
            # BUG FIX: reset the counter after a flush so the next incomplete
            # sequence gets a fresh grace window; the original never reset it,
            # so after 6 waits every subsequent token was force-flushed.
            self.cache_time = 0
        else:
            # Not printable yet: wait for more tokens (at most 6 times).
            self.cache_time += 1
            return

        self.on_finalized_text(output_text)

    def end(self):
        """Flush any remaining cached text and signal the end of the stream."""
        text = self.tokenizer.decode(self.token_until, **self.decode_kwargs)
        output_text = text[len(self.text_until):]
        self.text_until = text
        self.on_finalized_text(output_text, stream_end=True)
        self.clear_cache()

    def clear_cache(self):
        """Reset all incremental state (drops the history reference too)."""
        self.cache_time = 0
        self.token_until = []
        self.text_until = ""
        self.history = None
        self.next_tokens_are_prompt = True

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Record the running text in the history and enqueue (text, history)."""
        # BUG FIX: tolerate a missing history (see __init__).
        if self.history is not None:
            self.history.update({"role": "bot", "content": self.text_until,
                                 "input_ids": self.token_until,
                                 "attention_mask": [1] * len(self.token_until)})
        self.text_queue.put((text, self.history), timeout=self.timeout)
        if stream_end:
            self.text_queue.put((self.stop_signal, self.history), timeout=self.timeout)

    @staticmethod
    def _is_printable(cp):
        """False while the decoded text still contains a replacement char."""
        return "�" not in cp

    def __iter__(self):
        return self

    def __next__(self):
        value_now, history_until = self.text_queue.get(timeout=self.timeout)
        if value_now == self.stop_signal:
            raise StopIteration()
        return value_now, history_until