mdokl commited on
Commit
ddca8f7
·
verified ·
1 Parent(s): a7081d9

Upload 4 files

Browse files

优化推理时缓存管理与重复惩罚,添加历史记录输出!

Files changed (4) hide show
  1. app.py +299 -294
  2. install_ac.py +11 -0
  3. tokenizer.py +91 -104
  4. train_and_use.py +25 -13
app.py CHANGED
@@ -1,294 +1,299 @@
1
- # 公开库
2
- import time
3
- import html
4
- import uuid
5
- import torch
6
- import threading
7
- import numpy as np
8
- import gradio as gr
9
- from queue import Queue
10
- # 私有库
11
- from make_model import make_model
12
- from LazyCache import ExpiringDict
13
- from train_and_use import El_text_continue_stream
14
- from tokenizer import tokenizer,vocab_size,token2str
15
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
-
17
- # 加载模型
18
- model = make_model(
19
- #token是从1开始的,0填充,剩下的用来覆盖全部字节
20
- vocab_size = vocab_size+1+255,
21
- embedding_dim = 768,
22
- key_dim = 128,
23
- head_number = 12,
24
- position_information_type = "mask",
25
- enable_affine = True,
26
- enable_talking_head = True,
27
- use_diff = False,
28
- self_attention_block_size = 0,
29
- feed_forward_dim = 1536,
30
- enable_layer_norm = True,
31
- deep = 12,
32
- dropout_rate = 0.1,
33
- enable_el_cache = True
34
- ).to(device)
35
- model.load_state_dict(torch.load('large_model_instruct_09291732.weight',map_location=device,weights_only=True))
36
- model = model.eval()
37
-
38
- # token包装函数 - 使用HTML span标签确保每个token在独立矩形中
39
- def token_wapper(token):
40
- # 对特殊字符进行HTML转义处理
41
- escaped_token = html.escape(token)
42
- return f'<span class="token-box">{escaped_token}</span>'
43
-
44
- # 多token包装函数 - 使用HTML span标签确保每个token在独立矩形中
45
- def token_split_wapper(token):
46
- # 对特殊字符进行HTML转义处理
47
- escaped_token = html.escape(token)
48
- return f'<span class="multi-token-box">({escaped_token})[多token]</span>'
49
-
50
- # 处理用户输入的token,返回安全的显示格式
51
- def process_user_tokens(user_message):
52
- # 通过分词器转化为token
53
- user_tokens = tokenizer(user_message, 5.0)
54
-
55
- # 将token还原并进行安全包装
56
- words = [] # token列表
57
- temp = [] # token是特殊字节,要合并
58
- for token in user_tokens:
59
- if token > 0:
60
- # 将合并成功的加入列表
61
- if len(temp):
62
- words.append(token_split_wapper(token2str(temp)))
63
- temp = []
64
- # 将新的token加入列表
65
- words.append(token_wapper(token2str([token])))
66
- else:
67
- # 将字节送去合并
68
- temp.append(token)
69
- # 结束的时候要进行收尾
70
- if len(temp):
71
- words.append(token_split_wapper(token2str(temp)))
72
- # 返回包装好的token列表
73
- return ''.join(words)
74
-
75
- # 全局字典,存 per-session 的不可 deepcopy 对象 / 状态
76
- user_queues = ExpiringDict(ttl=550) # session_id -> Queue(list([string,string])),用于流式输出
77
- user_queues.start_auto_cleanup()
78
- user_stop_flags = ExpiringDict(ttl=550) # session_id -> bool (True 表示停止)
79
- user_stop_flags.start_auto_cleanup()
80
- user_history_sessions_show = ExpiringDict(ttl=550) # session_id -> 用于显示的历史记录,list([string,string])
81
- user_history_sessions_show.start_auto_cleanup()
82
- user_history_sessions_text = ExpiringDict(ttl=550) # session_id -> 纯文本历史记录,string
83
- user_history_sessions_text.start_auto_cleanup()
84
-
85
- # 后台生成函数(只访问全局字典,通过 session_id 定位)
86
- def generate_text(sess, user_message, session_id, temperature, repeat_penalty, max_length, decay):
87
- out = ""
88
- q = user_queues.get(session_id)
89
- # 立即刷出用户问题
90
- q.put(out, block=False)
91
- # 构建完整的对话历史输入
92
- if len(sess) == 1:
93
- user_history_sessions_text[session_id] = f"<|im_start|>user {user_message}<|im_end|><|im_start|>assistant "
94
- else:
95
- user_history_sessions_text[session_id] += f"<|im_start|>user {user_message}<|im_end|><|im_start|>assistant "
96
- # 转换为模型输入格式
97
- tokens_batch = [tokenizer(user_history_sessions_text[session_id], 5.0)]
98
- tokens_batch = np.array(tokens_batch, dtype=np.int64) + 255
99
- inputs = torch.from_numpy(tokens_batch).to(device).data
100
- last_len = -1
101
- # 模型输出
102
- with torch.no_grad():
103
- for o in El_text_continue_stream(
104
- model, inputs, out_length=max_length,
105
- repeat_penalty_value=repeat_penalty,
106
- temperature=temperature,decay=decay,session_id=session_id):
107
- # 如果当前位置可以完整解码
108
- if o[0,-1] > 255:
109
- # 将未解码的部分一起解码
110
- temp = token2str(o[0][last_len:].cpu().numpy()-255)
111
- out += temp
112
- user_history_sessions_text[session_id] += temp
113
- # 重置为解码光标
114
- last_len = -1
115
- q.put(out, block=False)
116
- else:
117
- # 无法解码,光标固定
118
- last_len -= 1
119
- # 如果用户主动断开连接,停止生成,去除潜在标记
120
- if user_stop_flags.get(session_id, True):
121
- if '<' + out.split('<')[-1] in '<|im_end|>':
122
- # 显示的部分去除标记
123
- out = '<'+'<'.join(out.split('<')[:-1])
124
- # 历史的部分保留标记
125
- user_history_sessions_text[session_id] = '<'+'<'.join(user_history_sessions_text[session_id].split('<')[:-1])+'<|im_end|>'
126
- break
127
-
128
- # 如果是输出终止标记
129
- if '<|im_end|>' in out:
130
- # 显示的部分,去除标记
131
- out = out.split('<|im_end|>')[0]
132
- q.put(out, block=False)
133
- break
134
- # 如果用户中断
135
- if user_stop_flags[session_id] == True:
136
- break
137
- # 更新标记为暂停
138
- user_stop_flags[session_id] = True
139
-
140
- # 按钮处理逻辑:发送消息 / 止生成 / 清空会话
141
- def send_message(sess, btn_label, user_message, session_id, temperature, repeat_penalty, max_length, decay):
142
- # 发送消息按钮 - 启动生成线程
143
- if btn_label == "发送消息" and user_message:
144
- # 设置当前用户正在生成的标志
145
- user_stop_flags[session_id] = False
146
- # 立即在UI中显示用户消息
147
- user_tokens_display = process_user_tokens(user_message)
148
- # 添加用户消息到当前会话
149
- user_history_sessions_show[session_id] = sess
150
- user_history_sessions_show[session_id] += [[user_tokens_display, ""]]
151
- if session_id not in user_history_sessions_text:
152
- return "", "会话过期!"
153
- # 在这里开始流式输出
154
- thread = threading.Thread(target=generate_text, args=(sess, user_message, session_id, temperature, repeat_penalty, max_length, decay))
155
- thread.daemon = True #主进程退出时退出
156
- thread.start() #启动
157
- user_stop_flags[session_id] = False
158
- # 清空入框,更新按钮文本
159
- return "", "停止生成"
160
- else:
161
- # 停止生成按钮 - 设置标志位
162
- user_stop_flags[session_id] = True
163
- # 更新返回给前端的 state/stop_flag
164
- return user_message, "发送消息"
165
-
166
- # 清空会话
167
- def clear_session():
168
- return []
169
-
170
- # 流式输出,无限循环刷新页面
171
- def stream_output(sess):
172
- global user_queues, user_stop_flags, user_history_sessions_show, user_history_sessions_text
173
- # 页面加载时初始化 session
174
- session_id = str(uuid.uuid4())
175
- user_queues[session_id] = Queue()
176
- user_stop_flags[session_id] = True
177
- user_history_sessions_show[session_id] = [] # 初始化历史会话记录,用于显示
178
- user_history_sessions_text[session_id] = "" # 初始化历史会话记录,用于文本存储
179
- # 返回初始状态
180
- yield [], "发送消息", session_id
181
- # 不断刷新
182
- while True:
183
- time.sleep(0.01) # 防止 busy-wait 占满 CPU
184
- # 等待队列有数据
185
- q = user_queues.get(session_id)
186
- if q is None:
187
- continue
188
- # 处理队列中的消息
189
- if not q.empty():
190
- # 取到最后一个加入的数据
191
- while q.qsize() > 1:
192
- q.get()
193
- out = q.get()
194
- sess = user_history_sessions_show[session_id]
195
- sess[-1][1] = out
196
- # 更新UI状态
197
- current_stopped = user_stop_flags.get(session_id, True)
198
- button_label = "停止生成" if not current_stopped else "发送消息"
199
- yield sess, button_label, session_id
200
-
201
- # UI美化
202
- css = """
203
- /* 大标题居中 */
204
- .title {
205
- text-align: center;
206
- }
207
- /* 高级选项字体居中 */
208
- #adv-param button {
209
- justify-content: center;
210
- }
211
- /* 高级选项字体放大 */
212
- #adv-param > button > span {
213
- font-size: 16px !important;
214
- font-weight: 600 !important;
215
- }
216
- /* 自定义token样式 */
217
- .token-box {
218
- display: inline-block;
219
- background-color: #f0f0f0;
220
- border: 1px solid #ddd;
221
- border-radius: 4px;
222
- padding: 2px 4px;
223
- margin: 2px;
224
- font-family: monospace;
225
- }
226
- .multi-token-box {
227
- display: inline-block;
228
- background-color: #e6f7ff;
229
- border: 1px solid #91d5ff;
230
- border-radius: 4px;
231
- padding: 2px 4px;
232
- margin: 2px;
233
- font-family: monospace;
234
- }
235
- """
236
- # ========== Gradio UI ==========
237
- with gr.Blocks(css=css) as demo:
238
- with gr.Column(elem_classes="container"):
239
- gr.Markdown("# 0.18B中文大语言模型在线体验", elem_classes="title")
240
- # 聊天界面
241
- chatbot = gr.Chatbot(
242
- label="对话",
243
- autoscroll=False,
244
- show_copy_button=True,
245
- elem_classes="chatbox",
246
- type="tuples",
247
- height=400
248
- )
249
- # 输入区域
250
- with gr.Column(elem_classes="input-area"):
251
- msg = gr.Textbox(
252
- placeholder="请输入你的问题...",
253
- label="",
254
- lines=3,
255
- show_label=False
256
- )
257
- # 按钮区域
258
- with gr.Row(elem_classes="button-row"):
259
- send_btn = gr.Button("发送消息", elem_classes="send-btn")
260
- clear_btn = gr.Button("清空会话", elem_classes="clear-btn")
261
- # 参数设置区域(可折叠)
262
- with gr.Accordion("高级参数设置", open=False, elem_classes="parameter-row", elem_id="adv-param"):
263
- with gr.Row():
264
- temperature = gr.Slider(0.0001, 3.0001, value=0.0001, step=0.1, label="Temperature")
265
- repeat_penalty = gr.Slider(0.0, 5.0, value=0.5, step=0.1, label="Repeat Penalty")
266
- with gr.Row():
267
- max_length = gr.Slider(64, 8192, value=512, step=64, label="Max Length")
268
- decay = gr.Slider(0.90, 1.0, value=0.99, step=0.005, label="Repeat Penalty Decay Rate")
269
- # gr.State 用来在前端保存可 deepcopied session 值
270
- session_id = gr.State()
271
- # 发送按钮处理
272
- send_btn.click(
273
- send_message,
274
- inputs=[chatbot, send_btn, msg, session_id, temperature, repeat_penalty, max_length, decay],
275
- outputs=[msg, send_btn],
276
- )
277
- # 会话清空按钮处理
278
- clear_btn.click(
279
- clear_session,
280
- inputs=[],
281
- outputs=[chatbot],
282
- )
283
- # 无限循环,一直更新聊天界面
284
- demo.load(
285
- stream_output,
286
- inputs=[chatbot],
287
- outputs=[chatbot, send_btn, session_id],
288
- )
289
- if __name__ == "__main__":
290
- """主函数:启动Gradio界面"""
291
- # 设置队列参数以提高并发处理能力
292
- demo.queue(max_size=128, default_concurrency_limit=128)
293
- # 启动Gradio应用,不公开分享,并应用CSS样式
294
- demo.launch(share=False)
 
 
 
 
 
 
1
+ # 公开库
2
+ import time
3
+ import html
4
+ import uuid
5
+ import torch
6
+ import threading
7
+ import numpy as np
8
+ import gradio as gr
9
+ from queue import Queue
10
+ # 私有库
11
+ from make_model import make_model
12
+ from LazyCache import ExpiringDict
13
+ from train_and_use import El_text_continue_stream
14
+ from tokenizer import tokenizer,vocab_size,token2str
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
# Build and load the language model.
model = make_model(
    # Token ids start at 1; 0 is padding; the remaining 255 ids cover raw bytes.
    vocab_size = vocab_size + 1 + 255,
    embedding_dim = 768,
    key_dim = 128,
    head_number = 12,
    position_information_type = "mask",
    enable_affine = True,
    enable_talking_head = True,
    use_diff = False,
    self_attention_block_size = 0,
    feed_forward_dim = 1536,
    enable_layer_norm = True,
    deep = 12,
    dropout_rate = 0.1,
    enable_el_cache = True
).to(device)
# weights_only=True: load tensors only, never unpickle arbitrary objects.
model.load_state_dict(torch.load('large_model_instruct_09291732.weight', map_location=device, weights_only=True))
model = model.eval()
37
+
38
# Single-token wrapper — an HTML span so each token renders in its own box.
def token_wapper(token):
    """Return *token* HTML-escaped and wrapped in a ``token-box`` span."""
    return f'<span class="token-box">{html.escape(token)}</span>'
43
+
44
# Multi-token wrapper — marks a merged run of byte tokens in its own box.
def token_split_wapper(token):
    """Return *token* HTML-escaped and wrapped in a ``multi-token-box`` span."""
    escaped = html.escape(token)
    return f'<span class="multi-token-box">({escaped})[多token]</span>'
49
+
50
# Render the user's message as token-wrapped, HTML-safe markup.
def process_user_tokens(user_message):
    """Tokenize *user_message* and return it as wrapped HTML spans.

    Vocabulary tokens (id > 0) are wrapped one by one; consecutive raw byte
    tokens (id <= 0) are merged and wrapped together as one multi-token span.
    """
    pieces = []    # rendered HTML fragments
    byte_run = []  # pending raw-byte tokens waiting to be merged
    for tok in tokenizer(user_message, 5.0):
        if tok > 0:
            # Flush any accumulated byte run before the normal token.
            if byte_run:
                pieces.append(token_split_wapper(token2str(byte_run)))
                byte_run = []
            pieces.append(token_wapper(token2str([tok])))
        else:
            byte_run.append(tok)
    # Flush a trailing byte run at end of input.
    if byte_run:
        pieces.append(token_split_wapper(token2str(byte_run)))
    return ''.join(pieces)
74
+
75
# Global per-session state that cannot be deep-copied into gr.State.
# Every dict expires idle entries after 550 s.
user_queues = ExpiringDict(ttl=550)                 # session_id -> Queue of partial outputs for streaming
user_queues.start_auto_cleanup()
user_stop_flags = ExpiringDict(ttl=550)             # session_id -> bool (True means generation is stopped)
user_stop_flags.start_auto_cleanup()
user_history_sessions_show = ExpiringDict(ttl=550)  # session_id -> display history, list of [user, bot] pairs
user_history_sessions_show.start_auto_cleanup()
user_history_sessions_text = ExpiringDict(ttl=550)  # session_id -> plain-text history; unused for inference (KV cache) but kept as a log
user_history_sessions_text.start_auto_cleanup()
84
+
85
# Background generation worker (touches only the global dicts, keyed by session_id).
def generate_text(sess, user_message, session_id, temperature, repeat_penalty, max_length, decay):
    """Stream model output for one session into its queue.

    Runs in a daemon thread started by send_message. Appends the new user
    turn to the plain-text history, tokenizes only that turn (the model's
    per-session KV cache holds earlier context), and pushes partial decodes
    into the session's queue for the UI refresher to pick up.
    """
    out = ""
    # NOTE(review): assumes the queue entry is alive for the thread's
    # lifetime; if the ExpiringDict entry lapses, .get returns None — confirm.
    q = user_queues.get(session_id)
    # Flush immediately so the UI shows the user's question right away.
    q.put(out, block=False)
    # Number of earlier turns already held in the model's session cache.
    history_len = len(sess) - 1
    # Append the new turn to the plain-text history (kept as a log).
    turn = f"<|im_start|>user {user_message}<|im_end|><|im_start|>assistant "
    if history_len:
        user_history_sessions_text[session_id] += turn
    else:
        user_history_sessions_text[session_id] = turn
    # Tokenize only the new turn; shift ids by 255 into the model's id space.
    tokens_batch = [tokenizer(turn, 5.0)]
    tokens_batch = np.array(tokens_batch, dtype=np.int64) + 255
    inputs = torch.from_numpy(tokens_batch).to(device).data
    last_len = -1
    # Generation loop.
    with torch.no_grad():
        for o in El_text_continue_stream(
                model, inputs, out_length=max_length,
                repeat_penalty_value=repeat_penalty,
                temperature=temperature, decay=decay,
                session_id=session_id, history_len=history_len):
            # An id > 255 means the tail is a full vocabulary token: decodable.
            if o[0, -1] > 255:
                # Decode everything accumulated since the last decodable point.
                temp = token2str(o[0][last_len:].cpu().numpy() - 255)
                out += temp
                user_history_sessions_text[session_id] += temp
                # Reset the decode cursor.
                last_len = -1
                q.put(out, block=False)
            else:
                # Raw byte that cannot be decoded alone; pin the cursor.
                last_len -= 1
                # User disconnected/stopped mid-byte: strip a partly emitted tag.
                if user_stop_flags.get(session_id, True):
                    if '<' + out.split('<')[-1] in '<|im_end|>':
                        # Displayed text drops the partial tag ...
                        out = '<' + '<'.join(out.split('<')[:-1])
                        # ... but the stored history keeps a complete end tag.
                        user_history_sessions_text[session_id] = '<' + '<'.join(user_history_sessions_text[session_id].split('<')[:-1]) + '<|im_end|>'
                    break

            # Model emitted the end-of-turn marker: trim it and finish.
            if '<|im_end|>' in out:
                out = out.split('<|im_end|>')[0]
                q.put(out, block=False)
                break
            # User pressed stop.
            # BUGFIX: use .get() with a default — the ExpiringDict entry can
            # expire mid-generation, and direct indexing would raise KeyError
            # inside this worker thread (the sibling check above already
            # uses .get(session_id, True)).
            if user_stop_flags.get(session_id, True):
                break
    # Mark generation as stopped.
    user_stop_flags[session_id] = True
    # Log the conversation for later analysis of user questions.
    print(user_history_sessions_text[session_id])
144
+
145
# Button handler: send message / stop generation.
def send_message(sess, btn_label, user_message, session_id, temperature, repeat_penalty, max_length, decay):
    """Start a generation thread, or flag the running one to stop.

    Returns ``(textbox_value, button_label)`` for the UI.
    """
    if btn_label == "发送消息" and user_message:
        # BUGFIX: check for an expired session *before* mutating any state.
        # Previously the expiry check ran after the message was appended to
        # the display history and the stop flag cleared, leaving an orphan
        # message and a permanently "generating" state for dead sessions.
        if session_id not in user_history_sessions_text:
            return "", "会话过期!"
        # Mark this session as actively generating.
        user_stop_flags[session_id] = False
        # Echo the user's message (token-wrapped HTML) into the display history.
        user_tokens_display = process_user_tokens(user_message)
        user_history_sessions_show[session_id] = sess
        user_history_sessions_show[session_id] += [[user_tokens_display, ""]]
        # Stream generation from a background daemon thread.
        thread = threading.Thread(target=generate_text, args=(sess, user_message, session_id, temperature, repeat_penalty, max_length, decay))
        thread.daemon = True  # let the process exit without waiting
        thread.start()
        # Clear the input box and flip the button to "stop".
        # (Duplicate user_stop_flags[...] = False removed — already set above.)
        return "", "停止生成"
    else:
        # Stop button: raise the stop flag; the worker thread observes it.
        user_stop_flags[session_id] = True
        return user_message, "发送消息"
170
+
171
# Clear-session handler.
def clear_session():
    """Reset the chat display to an empty history."""
    return []
174
+
175
# Streaming refresher: initialise the session, then poll its queue forever.
def stream_output(sess):
    """Generator bound to demo.load: yields (history, button_label, session_id)."""
    global user_queues, user_stop_flags, user_history_sessions_show, user_history_sessions_text
    # Fresh per-session state on page load.
    session_id = str(uuid.uuid4())
    user_queues[session_id] = Queue()
    user_stop_flags[session_id] = True
    user_history_sessions_show[session_id] = []  # display history
    user_history_sessions_text[session_id] = ""  # plain-text history
    # Initial UI state.
    yield [], "发送消息", session_id
    # Poll loop.
    while True:
        time.sleep(0.01)  # keep the busy-wait from pegging a CPU core
        q = user_queues.get(session_id)
        if q is None:
            continue
        if q.empty():
            continue
        # Drain to the newest snapshot; older partial outputs are obsolete.
        while q.qsize() > 1:
            q.get()
        out = q.get()
        sess = user_history_sessions_show[session_id]
        sess[-1][1] = out
        # Reflect generation state in the button label.
        current_stopped = user_stop_flags.get(session_id, True)
        button_label = "停止生成" if not current_stopped else "发送消息"
        yield sess, button_label, session_id
205
+
206
# UI styling (raw CSS injected into gr.Blocks).
css = """
/* 大标题居中 */
.title {
    text-align: center;
}
/* 高级选项字体居中 */
#adv-param button {
    justify-content: center;
}
/* 高级选项字体放大 */
#adv-param > button > span {
    font-size: 16px !important;
    font-weight: 600 !important;
}
/* 自定义token样式 */
.token-box {
    display: inline-block;
    background-color: #f0f0f0;
    border: 1px solid #ddd;
    border-radius: 4px;
    padding: 2px 4px;
    margin: 2px;
    font-family: monospace;
}
.multi-token-box {
    display: inline-block;
    background-color: #e6f7ff;
    border: 1px solid #91d5ff;
    border-radius: 4px;
    padding: 2px 4px;
    margin: 2px;
    font-family: monospace;
}
"""
241
# ========== Gradio UI ==========
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_classes="container"):
        gr.Markdown("# 0.18B中文大语言模型", elem_classes="title")
        # Chat panel.
        chatbot = gr.Chatbot(
            label="对话",
            autoscroll=False,
            show_copy_button=True,
            elem_classes="chatbox",
            type="tuples",
            height=400,
        )
        # Input area.
        with gr.Column(elem_classes="input-area"):
            msg = gr.Textbox(
                placeholder="请输入你的问题...",
                label="",
                lines=3,
                show_label=False,
            )
        # Action buttons.
        with gr.Row(elem_classes="button-row"):
            send_btn = gr.Button("发送消息", elem_classes="send-btn")
            clear_btn = gr.Button("清空会话", elem_classes="clear-btn")
        # Sampling parameters (collapsible).
        with gr.Accordion("高级参数设置", open=False, elem_classes="parameter-row", elem_id="adv-param"):
            with gr.Row():
                temperature = gr.Slider(0.0001, 3.0001, value=0.0001, step=0.1, label="Temperature")
                repeat_penalty = gr.Slider(0.0, 5.0, value=0.5, step=0.1, label="Repeat Penalty")
            with gr.Row():
                max_length = gr.Slider(64, 8192, value=512, step=64, label="Max Length")
                decay = gr.Slider(0.90, 1.0, value=0.99, step=0.005, label="Repeat Penalty Decay Rate")
    # gr.State holds the deep-copyable per-client session id.
    session_id = gr.State()
    # Send button.
    send_btn.click(
        send_message,
        inputs=[chatbot, send_btn, msg, session_id, temperature, repeat_penalty, max_length, decay],
        outputs=[msg, send_btn],
    )
    # Clear-session button.
    clear_btn.click(
        clear_session,
        inputs=[],
        outputs=[chatbot],
    )
    # Endless refresh generator that keeps the chat panel current.
    demo.load(
        stream_output,
        inputs=[chatbot],
        outputs=[chatbot, send_btn, session_id],
    )
if __name__ == "__main__":
    # Entry point: launch the Gradio app.
    # Queue settings sized for higher concurrency.
    demo.queue(max_size=128, default_concurrency_limit=128)
    # Local launch, no public share link.
    demo.launch(share=False)
install_ac.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import sys
import subprocess

# Build pyahocorasick in bytes mode (the tokenizer feeds it raw bytes).
os.environ['AHOCORASICK_BYTES'] = '1'
# Install with the pip that belongs to the current interpreter.
subprocess.check_call([
    sys.executable, "-m", "pip", "install",
    "--force-reinstall",  # always rebuild so the env var takes effect
    "--no-cache-dir",     # skip the cache to fetch the latest revision
    "git+https://github.com/WojciechMula/pyahocorasick.git"
])
tokenizer.py CHANGED
@@ -1,104 +1,91 @@
1
- import numpy as np
2
- import subprocess
3
- import sys
4
- import os
5
-
6
- # 设置环境变量
7
- os.environ['AHOCORASICK_BYTES'] = '1'
8
- # 使用当前解释器对应的pip进行安装
9
- subprocess.check_call([
10
- sys.executable, "-m", "pip", "install",
11
- "--force-reinstall", # 强制重新安装
12
- "--no-cache-dir", # 不使用缓存确保获取最新版本
13
- "git+https://github.com/WojciechMula/pyahocorasick.git"
14
- ])
15
-
16
- #加载词表
17
- with open('vocab_b_65544.txt','r',encoding='utf-8') as f:
18
- # with open('vocab_tiny_random.txt','r',encoding='utf-8') as f:
19
- words_count = dict()
20
- for word in f:
21
- if word[0] != '\t':
22
- k,v = word.split('\t')
23
- words_count[k] = int(v[:-1])
24
-
25
- #补充缺失词,但尽量不要改变词频
26
- if '.' in words_count:
27
- words_count[','] = words_count['.']
28
- words_count['\r'] = 1
29
- words_count['\n'] = 1
30
- words_count['\t'] = 1
31
-
32
- #计算每个片段长度的单词
33
- N = 7
34
- count_sum = [0 for _ in range(N)]
35
- for k,v in words_count.items():
36
- count_sum[len(k)-1] += v
37
-
38
- #创建AC自动机
39
- import ahocorasick as ah
40
- aca= ah.Automaton()
41
- for k,v in words_count.items():
42
- aca.add_word(k.encode(),(len(k.encode()),np.log(v/count_sum[len(k)-1])))
43
- aca.make_automaton()
44
-
45
- #单词与整数互转字典
46
- words = [k for k in words_count]
47
- words.sort()
48
- word2idx = {k.encode():i+1 for i,k in enumerate(words)}
49
- idx2word = {i:k for k,i in word2idx.items()}
50
- vocab_size = len(word2idx)
51
-
52
- #分词器函数
53
- def tokenizer(text,alpha=1.0):
54
- encode_text = text.encode()
55
- #路径,记录起始位置和
56
- LOT = len(encode_text)
57
- BOW = 0 #表示最佳词的起始位置
58
- VOW = 1 #表示最佳路径的累积值
59
- VOID = 5 #表示没有记录
60
- routes = [(i,VOID) for i in range(LOT)] + [(-1,0.0)]
61
- tokens = [] #保存分词结果
62
- #遍历所有匹配成功的词
63
- # low:len_of_word
64
- # vow:value_of_word
65
- for eow, (low,vow) in aca.iter(encode_text):
66
- #匹配词起点序号 = 匹配词终点序号 -(匹配词长度-1)
67
- bow = eow - low + 1
68
- #得分是负数,但负的程度越小约好,
69
- #数值为起始位置的得分 + 当前词的分数
70
- #起始位置无记录就往前找
71
- i = 0
72
- while routes[bow - 1 - i][VOW] == VOID:
73
- i += 1
74
- v = routes[bow - 1 -i][VOW] + vow
75
- # 超过5.0直接使用确定算法
76
- if alpha >= 5.0:
77
- #更短的路径或第一个到达,更新
78
- if v > routes[eow][VOW] and i == 0 or routes[eow][VOW] == VOID:
79
- routes[eow] = bow,v #记录起始位置以及累积值
80
- else:
81
- # 随机算法
82
- if routes[eow][VOW] == VOID:
83
- base = v
84
- temp = 1.0
85
- denominator = 1.0
86
- routes[eow] = bow,v
87
- else:
88
- temp = np.exp(alpha * (v - base))
89
- denominator += temp
90
- if np.random.rand() < temp/denominator and i == 0:
91
- routes[eow] = bow,v #记录起始位置以及累积值
92
- #从后往前查找分割点
93
- eow = LOT - 1
94
- while encode_text:
95
- bow = routes[eow][BOW] #找到最佳词的起始位置
96
- tokens.append(encode_text[bow:eow+1]) #记录该词语
97
- encode_text,eow = encode_text[:bow],bow - 1 #继续分上一个词
98
- #从后往前找,需要反序得到正序的分词结果
99
- return [word2idx[w] if w in word2idx else -ord(w) for w in tokens[::-1]]
100
-
101
- def token2str(tokens,split=''):
102
- return b''.join([(int(-token)).to_bytes(1,'big') if token < 0 else idx2word[token] + split.encode() for token in tokens]).decode(errors="ignore")
103
-
104
-
 
1
import numpy as np
import install_ac  # side effect on import: (re)installs pyahocorasick in bytes mode

# Load the vocabulary: one "word<TAB>count" entry per line.
with open('vocab_b_65544.txt', 'r', encoding='utf-8') as f:
# with open('vocab_tiny_random.txt','r',encoding='utf-8') as f:
    words_count = dict()
    for word in f:
        if word[0] != '\t':
            k, v = word.split('\t')
            words_count[k] = int(v[:-1])

# Patch in missing words while disturbing the frequency profile as little as possible.
if '.' in words_count:
    words_count[','] = words_count['.']
words_count['\r'] = 1
words_count['\n'] = 1
words_count['\t'] = 1

# Total count per word length — normalisation denominators for the scores.
N = 7
count_sum = [0 for _ in range(N)]
for k, v in words_count.items():
    count_sum[len(k) - 1] += v

# Build the Aho-Corasick automaton.
# Payload per word: (byte length, log relative frequency within its length class).
import ahocorasick as ah
aca = ah.Automaton()
for k, v in words_count.items():
    aca.add_word(k.encode(), (len(k.encode()), np.log(v / count_sum[len(k) - 1])))
aca.make_automaton()

# Word <-> integer id mappings; ids start at 1 (0 is reserved for padding).
words = sorted(words_count)
word2idx = {k.encode(): i + 1 for i, k in enumerate(words)}
idx2word = {i: k for k, i in word2idx.items()}
vocab_size = len(word2idx)
38
+
39
# Tokenizer: segment text by best (or sampled) path over dictionary matches.
def tokenizer(text, alpha=1.0):
    """Segment *text* into token ids.

    alpha >= 5.0 selects the deterministic best-path algorithm; smaller
    values sample among candidate segmentations with softmax-like weights.
    Returns vocabulary ids (> 0); bytes not in the vocabulary come back as
    the negated byte value (<= 0).
    """
    encode_text = text.encode()
    # Path table: for each end position, the chosen word start and path score.
    LOT = len(encode_text)  # length of text in bytes
    BOW = 0                 # tuple index: begin-of-word position
    VOW = 1                 # tuple index: accumulated path value
    VOID = 5                # sentinel: position has no record yet
    routes = [(i, VOID) for i in range(LOT)] + [(-1, 0.0)]
    tokens = []  # collected word slices
    # Scan every dictionary match found by the automaton.
    # low: len_of_word (bytes); vow: value_of_word (log relative frequency).
    for eow, (low, vow) in aca.iter(encode_text):
        # Match start index = end index - (length - 1).
        bow = eow - low + 1
        # Scores are negative; less negative is better.
        # Candidate value = score at the start position + this word's score;
        # if the start position has no record, walk backwards to one that does.
        i = 0
        while routes[bow - 1 - i][VOW] == VOID:
            i += 1
        v = routes[bow - 1 - i][VOW] + vow
        if alpha >= 5.0:
            # Deterministic: keep the better-scoring path (or the first arrival).
            if v > routes[eow][VOW] and i == 0 or routes[eow][VOW] == VOID:
                routes[eow] = bow, v  # record start position and accumulated value
        else:
            # Stochastic: accept later candidates with softmax-weighted probability.
            if routes[eow][VOW] == VOID:
                base = v
                temp = 1.0
                denominator = 1.0
                routes[eow] = bow, v
            else:
                temp = np.exp(alpha * (v - base))
                denominator += temp
                if np.random.rand() < temp / denominator and i == 0:
                    routes[eow] = bow, v  # record start position and accumulated value
    # Walk backwards from the end, cutting at each recorded word start.
    eow = LOT - 1
    while encode_text:
        bow = routes[eow][BOW]                            # best start for the word ending here
        tokens.append(encode_text[bow:eow + 1])           # record the word slice
        encode_text, eow = encode_text[:bow], bow - 1     # continue with the prefix
    # Collected back-to-front, so reverse; unknown single bytes map to -byte.
    return [word2idx[w] if w in word2idx else -ord(w) for w in tokens[::-1]]
87
+
88
def token2str(tokens, split=''):
    """Decode token ids back to text.

    Negative ids are raw byte values; *split* is appended after each
    vocabulary word (not after raw bytes). Undecodable bytes are ignored.
    """
    sep = split.encode()
    parts = []
    for token in tokens:
        if token < 0:
            parts.append((int(-token)).to_bytes(1, 'big'))
        else:
            parts.append(idx2word[token] + sep)
    return b''.join(parts).decode(errors="ignore")
90
+
91
+
 
 
 
 
 
 
 
 
 
 
 
 
 
train_and_use.py CHANGED
@@ -340,25 +340,37 @@ def El_text_continue(model,inputs,out_length,repeat_penalty_value,temperature,de
340
  repeat_penalty[i][next_token[i]] -= repeat_penalty_value
341
  return inputs
342
 
343
- def El_text_continue_stream(model,inputs,out_length,repeat_penalty_value,temperature,decay=0.98,session_id='0'):
344
  if model.model_type == "generator":
345
- assert len(inputs[0]) > 1, "初始序列长度必须大于1,与增量续写进行区分"
346
- query = model.embedding(inputs)
347
- prob_dist = model.projector(model.encoder(query,inputs==inputs,session_id)[:,-1,:])
348
- repeat_penalty = torch.zeros_like(prob_dist, device=inputs.device)
349
- for index in range(inputs.size(1)):
350
- for line in range(inputs.size(0)):
351
- repeat_penalty[line][inputs[line][index]] -= repeat_penalty_value
352
- repeat_penalty *= decay
353
- prob_dist += repeat_penalty
354
- next_token = torch.multinomial(F.softmax(prob_dist/temperature, dim = -1), num_samples = 1)
355
- inputs = torch.cat([inputs,next_token.to(inputs.device)], dim=-1)[:,-4:]
 
 
 
 
 
 
 
 
 
 
 
 
356
  yield inputs
357
  for i in range(next_token.size(0)):
358
  repeat_penalty[i][next_token[i]] -= repeat_penalty_value
359
  for _ in range(0,out_length-1,1):
360
  query = model.embedding(inputs[:,[-1]])
361
- prob_dist = model.projector(model.encoder(query,(inputs==inputs)[:,[-1]],session_id)[:,-1,:])
362
  repeat_penalty *= decay
363
  prob_dist += repeat_penalty
364
  next_token = torch.multinomial(F.softmax(prob_dist/temperature, dim = -1), num_samples = 1)
 
340
  repeat_penalty[i][next_token[i]] -= repeat_penalty_value
341
  return inputs
342
 
343
+ def El_text_continue_stream(model,inputs,out_length,repeat_penalty_value,temperature,decay=0.98,session_id='0',history_len=0):
344
  if model.model_type == "generator":
345
+ if history_len == 0:
346
+ # 没有历史记录就处理整个输入,提高并行性
347
+ assert len(inputs[0]) > 1, "初始序列长度必须大于1,与增量续写进行区分"
348
+ query = model.embedding(inputs)
349
+ prob_dist = model.projector(model.encoder(query,inputs==inputs,session_id)[:,-1,:])
350
+ repeat_penalty = torch.zeros_like(prob_dist, device=inputs.device)
351
+ # 不计算用户输入的重复惩罚
352
+ # for index in range(inputs.size(1)):
353
+ # for line in range(inputs.size(0)):
354
+ # repeat_penalty[line][inputs[line][index]] -= repeat_penalty_value
355
+ # repeat_penalty *= decay
356
+ # prob_dist += repeat_penalty
357
+ next_token = torch.multinomial(F.softmax(prob_dist/temperature, dim = -1), num_samples = 1)
358
+ inputs = torch.cat([inputs,next_token.to(inputs.device)], dim=-1)[:,-4:]
359
+ else:
360
+ # 否则将用户的问题逐个token处理,减少运算量
361
+ query = model.embedding(inputs)
362
+ for i in range(query.size(1)-1):
363
+ model.encoder(query[:,i:i+1],(inputs[:,[-1]]==inputs[:,[-1]]),session_id)[:,-1,:]
364
+ prob_dist = model.projector(model.encoder(query[:,[-1]],(inputs[:,[-1]]==inputs[:,[-1]]),session_id)[:,-1,:])
365
+ repeat_penalty = torch.zeros_like(prob_dist, device=inputs.device)
366
+ next_token = torch.multinomial(F.softmax(prob_dist/temperature, dim = -1), num_samples = 1)
367
+ inputs = torch.cat([inputs,next_token.to(inputs.device)], dim=-1)[:,-4:]
368
  yield inputs
369
  for i in range(next_token.size(0)):
370
  repeat_penalty[i][next_token[i]] -= repeat_penalty_value
371
  for _ in range(0,out_length-1,1):
372
  query = model.embedding(inputs[:,[-1]])
373
+ prob_dist = model.projector(model.encoder(query,(inputs[:,[-1]]==inputs[:,[-1]]),session_id)[:,-1,:])
374
  repeat_penalty *= decay
375
  prob_dist += repeat_penalty
376
  next_token = torch.multinomial(F.softmax(prob_dist/temperature, dim = -1), num_samples = 1)