Spaces:
Runtime error
Runtime error
| import numpy as np | |
def init(cfg):
    """Set up the status-bar button: preprocessing, generators and wiring.

    Reads the shared objects it needs from ``cfg``, pre-tokenizes every entry
    of ``cfg['btn_status_bar_list']`` (key, optional description and value
    prefixes), attaches a value generator ``fn`` to each ``combine`` part,
    stores the click-handler description in ``cfg['btn_status_bar_fn']`` and
    registers the ``click -> success -> success`` chain on
    ``cfg['btn_status_bar']``.

    Args:
        cfg: Mutable configuration mapping. Entries of
            ``cfg['btn_status_bar_list']`` are updated in place and the key
            ``'btn_status_bar_fn'`` is added.
    """
    chat_template = cfg['chat_template']
    model = cfg['model']
    s_info = cfg['s_info']
    lock = cfg['session_lock']
    text_format = cfg['text_format']

    def str_tokenize(s):
        """Tokenize ``s`` as text that follows a newline, without that newline."""
        tokens = model.tokenize(
            (chat_template.nl + s).encode('utf-8'), add_bos=False, special=False
        )
        # ``chat_template.nl`` is prepended before tokenizing; if the first
        # token is one of ``chat_template.onenl`` it is stripped again.
        if tokens[0] in chat_template.onenl:
            return tokens[1:]
        return tokens

    # ========== factory: logits mask ==========
    def btn_status_bar_fn_mask():
        """Return a float32 vector of ``-inf`` sized like ``model.scores``' last axis."""
        width = model.scores.shape[-1]
        return np.full((width,), -np.inf, dtype=np.single)

    # ========== factory: integer values ==========
    def btn_status_bar_fn_int(unit: str):
        """Build a generator that samples digits until EOS or the unit starts.

        Args:
            unit: Optional unit text appended after the digits; falsy for none.

        Returns:
            ``inner(eval_t, sample_t)`` returning the generated text.
        """
        digit_tokens = str_tokenize('0123456789')
        assert len(digit_tokens) == 10, 'each digit must map to exactly one token'
        allowed = btn_status_bar_fn_mask()
        allowed[chat_template.eos] = 0
        allowed[digit_tokens] = 0
        if unit:
            unit_t = str_tokenize(unit)
            # Only the first unit token is sampled; the rest is force-fed below.
            allowed[unit_t[0]] = 0

        def logits_processor(_input_ids, logits):
            return logits + allowed

        def inner(eval_t, sample_t):
            generated = []
            while True:
                token = sample_t(logits_processor)
                # ========== anything that is not a digit ends the number ==========
                if token in chat_template.eos:
                    break
                if unit and token == unit_t[0]:
                    break
                # ========== a digit: keep it and continue ==========
                generated.append(token)
                eval_t([token])
            if unit:
                eval_t(unit_t)  # append the unit
                generated.extend(unit_t)
            return model.str_detokenize(generated)

        return inner

    # ========== factory: one value out of a set ==========
    def btn_status_bar_fn_set(value):
        """Build a generator that picks one of ``value`` by its first token.

        Candidates are keyed by their first token, so of several candidates
        sharing a first token only the last one stays selectable.
        """
        candidates = {}
        for text in value:
            tokens = str_tokenize(text)
            candidates[tokens[0]] = (tokens, text)
        allowed = btn_status_bar_fn_mask()
        allowed[list(candidates)] = 0

        def logits_processor(_input_ids, logits):
            return logits + allowed

        def inner(eval_t, sample_t):
            tokens, text = candidates[sample_t(logits_processor)]
            eval_t(tokens)
            return text

        return inner

    # ========== factory: free-form single-line string ==========
    def btn_status_bar_fn_str():
        """Build a generator that samples unconstrained text up to a line end."""

        def inner(eval_t, sample_t):
            generated = []
            text = ''
            while True:
                token = sample_t(None)
                if token in chat_template.eos:
                    break
                generated.append(token)
                text = model.str_detokenize(generated)
                # The token that completes the line break is not evaluated.
                if text.endswith(('\n', '\r')):
                    break
                # ========== continue ==========
                eval_t([token])
            return text.strip()

        return inner

    def _prepare_entries():
        """Pre-tokenize all status-bar entries and attach their generators.

        Runs in its own scope so the loop variables cannot leak into the
        closures defined in ``init``; this also works when the entry list or
        every ``combine`` list is empty.
        """
        entries = cfg['btn_status_bar_list']
        # ========== preprocess key / desc ==========
        for entry in entries:
            entry['key'] = text_format(entry['key'],
                                       char=cfg['role_char'].value,
                                       user=cfg['role_usr'].value)
            entry['key_t'] = str_tokenize(entry['key'])
            entry['desc'] = text_format(entry['desc'],
                                        char=cfg['role_char'].value,
                                        user=cfg['role_usr'].value)
            if entry['desc']:
                entry['desc_t'] = str_tokenize(entry['desc'])
        # ========== preprocess values ==========
        for entry in entries:
            for part in entry['combine']:
                if part['prefix']:
                    part['prefix_t'] = str_tokenize(part['prefix'])
                if part['type'] == 'int':
                    part['fn'] = btn_status_bar_fn_int(part['unit'])
                elif part['type'] == 'set':
                    part['fn'] = btn_status_bar_fn_set(part['value'])
                elif part['type'] == 'str':
                    part['fn'] = btn_status_bar_fn_str()
                # Any other type gets no 'fn'; btn_status_bar will raise
                # KeyError when it reaches such a part.
        # ========== add separator token ==========
        # Every entry except the first is preceded by the last token of
        # ``chat_template.im_end_nl``.
        for entry in entries[1:]:
            entry['key_t'] = chat_template.im_end_nl[-1:] + entry['key_t']

    _prepare_entries()

    # ========== emit the status bar ==========
    def btn_status_bar(_n_keep, _n_discard,
                       _temperature, _repeat_penalty, _frequency_penalty,
                       _presence_penalty, _repeat_last_n, _top_k,
                       _top_p, _min_p, _typical_p,
                       _tfs_z, _mirostat_mode, _mirostat_eta,
                       _mirostat_tau, _usr, _char,
                       _rag, _max_tokens):
        """Generate the status bar row by row, yielding partial tables.

        The sampling settings are forwarded to ``model.sample_t`` and
        ``_n_keep`` / ``_n_discard`` to ``model.eval_t``; ``_usr``, ``_char``,
        ``_rag`` and ``_max_tokens`` are accepted but not used here.

        Yields:
            ``(rows, model.venv_info)`` where ``rows`` is a list of
            ``[key, value]`` pairs; the same list object is extended and
            re-yielded as generation progresses.

        Raises:
            RuntimeError: If ``cfg['session_active']`` is falsy.
        """
        # The lock stays held for the generator's whole lifetime, including
        # while it is suspended at a ``yield``.
        with lock:
            if not cfg['session_active']:
                raise RuntimeError('session is not active')
            if cfg['btn_stop_status']:
                yield [], model.venv_info
                return

            # ========== eval / sample bound to this call's settings ==========
            def eval_t(tokens):
                return model.eval_t(
                    tokens=tokens,
                    n_keep=_n_keep,
                    n_discard=_n_discard,
                    im_start=chat_template.im_start_token
                )

            def sample_t(logits_processor):
                return model.sample_t(
                    top_k=_top_k,
                    top_p=_top_p,
                    min_p=_min_p,
                    typical_p=_typical_p,
                    temp=_temperature,
                    repeat_penalty=_repeat_penalty,
                    repeat_last_n=_repeat_last_n,
                    frequency_penalty=_frequency_penalty,
                    presence_penalty=_presence_penalty,
                    tfs_z=_tfs_z,
                    mirostat_mode=_mirostat_mode,
                    mirostat_tau=_mirostat_tau,
                    mirostat_eta=_mirostat_eta,
                    logits_processor=logits_processor
                )

            # ========== initialise the output template ==========
            model.venv_create('status')  # create an isolated environment
            eval_t(chat_template('状态'))  # start marker

            # ========== streaming output ==========
            df = []
            for _x in cfg['btn_status_bar_list']:
                # ========== attribute ==========
                row = [_x['key'], '']
                df.append(row)
                eval_t(_x['key_t'])
                if _x['desc']:
                    eval_t(_x['desc_t'])
                yield df, model.venv_info
                # ========== value ==========
                for _y in _x['combine']:
                    prefix = _y['prefix']
                    if prefix:
                        # At the start of a cell a leading ':' is not shown;
                        # the model is still fed the full prefix.
                        row[-1] += prefix if row[-1] else prefix.lstrip(':')
                        eval_t(_y['prefix_t'])
                    row[-1] += _y['fn'](eval_t, sample_t)
                    yield df, model.venv_info
            eval_t(chat_template.im_end_nl)  # end marker
            # ========== clean up the previously generated status bar ==========
            model.venv_remove('status', keep_last=1)
            yield df, model.venv_info

    cfg['btn_status_bar_fn'] = {
        'fn': btn_status_bar,
        'inputs': cfg['setting'],
        'outputs': [cfg['status_bar'], s_info]
    }
    cfg['btn_status_bar_fn'].update(cfg['btn_concurrency'])

    cfg['btn_status_bar'].click(
        **cfg['btn_start']
    ).success(
        **cfg['btn_status_bar_fn']
    ).success(
        **cfg['btn_finish']
    )