Spaces:
Runtime error
Runtime error
| def init(cfg): | |
| chat_template = cfg['chat_template'] | |
| model = cfg['model'] | |
| gr = cfg['gr'] | |
| lock = cfg['session_lock'] | |
| # ========== 流式输出函数 ========== | |
| def btn_com(_n_keep, _n_discard, | |
| _temperature, _repeat_penalty, _frequency_penalty, | |
| _presence_penalty, _repeat_last_n, _top_k, | |
| _top_p, _min_p, _typical_p, | |
| _tfs_z, _mirostat_mode, _mirostat_eta, | |
| _mirostat_tau, _role, _max_tokens): | |
| # ========== 初始化输出模版 ========== | |
| t_bot = chat_template(_role) | |
| completion_tokens = [] # 有可能多个 tokens 才能构成一个 utf-8 编码的文字 | |
| history = '' | |
| # ========== 流式输出 ========== | |
| for token in model.generate_t( | |
| tokens=t_bot, | |
| n_keep=_n_keep, | |
| n_discard=_n_discard, | |
| im_start=chat_template.im_start_token, | |
| top_k=_top_k, | |
| top_p=_top_p, | |
| min_p=_min_p, | |
| typical_p=_typical_p, | |
| temp=_temperature, | |
| repeat_penalty=_repeat_penalty, | |
| repeat_last_n=_repeat_last_n, | |
| frequency_penalty=_frequency_penalty, | |
| presence_penalty=_presence_penalty, | |
| tfs_z=_tfs_z, | |
| mirostat_mode=_mirostat_mode, | |
| mirostat_tau=_mirostat_tau, | |
| mirostat_eta=_mirostat_eta, | |
| ): | |
| if token in chat_template.eos or token == chat_template.nlnl: | |
| t_bot.extend(completion_tokens) | |
| print('token in eos', token) | |
| break | |
| completion_tokens.append(token) | |
| all_text = model.str_detokenize(completion_tokens) | |
| if not all_text: | |
| continue | |
| t_bot.extend(completion_tokens) | |
| history += all_text | |
| yield history | |
| if token in chat_template.onenl: | |
| # ========== 移除末尾的换行符 ========== | |
| if t_bot[-2] in chat_template.onenl: | |
| model.venv_pop_token() | |
| break | |
| if t_bot[-2] in chat_template.onerl and t_bot[-3] in chat_template.onenl: | |
| model.venv_pop_token() | |
| break | |
| if history[-2:] == '\n\n': # 各种 'x\n\n' 的token,比如'。\n\n' | |
| print('t_bot[-4:]', t_bot[-4:], repr(model.str_detokenize(t_bot[-4:])), | |
| repr(model.str_detokenize(t_bot[-1:]))) | |
| break | |
| if len(t_bot) > _max_tokens: | |
| break | |
| completion_tokens = [] | |
| # ========== 查看末尾的换行符 ========== | |
| print('history', repr(history)) | |
| # ========== 给 kv_cache 加上输出结束符 ========== | |
| model.eval_t(chat_template.im_end_nl, _n_keep, _n_discard) | |
| t_bot.extend(chat_template.im_end_nl) | |
| cfg['btn_com'] = btn_com | |
| def btn_start_or_finish(finish): | |
| tmp = gr.update(interactive=finish) | |
| def _inner(): | |
| with lock: | |
| if cfg['session_active'] != finish: | |
| raise RuntimeError | |
| cfg['session_active'] = not cfg['session_active'] | |
| return tmp, tmp, tmp | |
| return _inner | |
| btn_start_or_finish_outputs = [cfg['btn_submit'], cfg['btn_vo'], cfg['btn_suggest']] | |
| cfg['btn_start'] = { | |
| 'fn': btn_start_or_finish(False), | |
| 'outputs': btn_start_or_finish_outputs | |
| } | |
| cfg['btn_finish'] = { | |
| 'fn': btn_start_or_finish(True), | |
| 'outputs': btn_start_or_finish_outputs | |
| } | |
| cfg['setting'] = [cfg[x] for x in ('setting_n_keep', 'setting_n_discard', | |
| 'setting_temperature', 'setting_repeat_penalty', 'setting_frequency_penalty', | |
| 'setting_presence_penalty', 'setting_repeat_last_n', 'setting_top_k', | |
| 'setting_top_p', 'setting_min_p', 'setting_typical_p', | |
| 'setting_tfs_z', 'setting_mirostat_mode', 'setting_mirostat_eta', | |
| 'setting_mirostat_tau', 'role_usr', 'role_char', | |
| 'rag', 'setting_max_tokens')] | |