|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_context(
    tokenizer,
    query,
    history,
    system,
    max_input_length,
    max_window_size: int = 6144,
    chat_format: str = "chatml",
):
    """Assemble the prompt text and prompt token ids for one chat request.

    Args:
        tokenizer: Object providing ``encode(text, allowed_special=...)`` and,
            for the ``"chatml"`` format, ``im_start_id`` / ``im_end_id``.
        query: The current user message.
        history: Earlier ``(user_message, assistant_reply)`` pairs, oldest
            first; ``None`` is treated as no history. Only used by
            ``"chatml"``.
        system: System prompt text. Only used by ``"chatml"``.
        max_input_length: The returned token list is cut down to its last
            ``max_input_length`` entries (``tokens[-max_input_length:]``).
        max_window_size: Budget used when deciding which history turns to
            keep. A turn is kept only while system tokens + kept history
            tokens + that turn's tokens stay strictly below this value; the
            tokens of the current query are not part of that sum.
        chat_format: ``"chatml"`` or ``"raw"``.

    Returns:
        A ``(raw_text, context_tokens)`` tuple. ``raw_text`` is never
        truncated; only the token list is sliced.

    Raises:
        NotImplementedError: If ``chat_format`` is neither ``"chatml"`` nor
            ``"raw"``.
    """
    if chat_format == "raw":
        # Raw mode: the query is the entire prompt; history/system are unused.
        return query, tokenizer.encode(query)[-max_input_length:]
    if chat_format != "chatml":
        raise NotImplementedError(f"Unknown chat format {chat_format!r}")

    past_turns = [] if history is None else history

    im_start, im_end = "<|im_start|>", "<|im_end|>"
    start_ids = [tokenizer.im_start_id]
    end_ids = [tokenizer.im_end_id]
    newline_ids = tokenizer.encode("\n")

    def _encode_message(role, content):
        # Returns the "<role>\n<content>" text plus its ids; special tokens
        # appearing inside role/content are not allowed (allowed_special=set()).
        text = f"{role}\n{content}"
        role_ids = tokenizer.encode(role, allowed_special=set())
        content_ids = tokenizer.encode(content, allowed_special=set())
        return text, role_ids + newline_ids + content_ids

    system_text, system_body_ids = _encode_message("system", system)
    system_ids = start_ids + system_body_ids + end_ids

    history_text = ""
    history_ids = []

    # Walk from the newest turn backwards so the most recent exchanges are the
    # ones that survive; stop at the first turn that does not fit.
    for turn_query, turn_response in reversed(past_turns):
        query_text, user_body_ids = _encode_message("user", turn_query)
        response_text, reply_body_ids = _encode_message(
            "assistant", turn_response)

        turn_ids = (newline_ids + start_ids + user_body_ids + end_ids +
                    newline_ids + start_ids + reply_body_ids + end_ids)
        prev_chat = (
            f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
        )

        projected_size = len(system_ids) + len(turn_ids) + len(history_ids)
        if projected_size >= max_window_size:
            break

        # Prepend: we are iterating newest-first but the prompt is oldest-first.
        history_ids = turn_ids + history_ids
        history_text = prev_chat + history_text

    # Current user message followed by an opened (unterminated) assistant turn.
    tail_ids = (newline_ids + start_ids +
                _encode_message("user", query)[1] + end_ids +
                newline_ids + start_ids +
                tokenizer.encode("assistant") + newline_ids)

    context_tokens = system_ids + history_ids
    context_tokens.extend(tail_ids)

    raw_text = f"{im_start}{system_text}{im_end}" + history_text
    raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"

    # Keep only the trailing max_input_length ids (the head is what is dropped).
    return raw_text, context_tokens[-max_input_length:]
|
|
|
|
|
|
|
|
def get_qwen_key_list(qwen_type):
    """Return the list of key strings associated with a Qwen model variant.

    ``qwen_type == 'qwen'`` selects the ``attn.c_attn``-style keys; every
    other value selects the ``self_attn.``-style (qwen2) keys. A new list is
    created on each call, so the caller may mutate the result freely.

    NOTE(review): the keys are presumably fragments of checkpoint parameter
    names, and the two tables look positionally parallel (entry *i* of each
    naming the same logical component) -- confirm against the callers.
    """
    if qwen_type == 'qwen':
        return [
            "attn.c_attn",
            "attn.c_proj",
            "mlp.w1",
            "mlp.w2",
            "mlp.c_proj",
            "ln_1",
            "ln_2",
            "transformer.wte",
            "transformer.ln_f",
        ]

    # Fallback for any non-'qwen' type. The trailing dot in "self_attn." is
    # reproduced exactly from the original table.
    return [
        "self_attn.",
        "self_attn.o_proj",
        "mlp.up_proj",
        "mlp.gate_proj",
        "mlp.down_proj",
        "input_layernorm",
        "post_attention_layernorm",
        "model.embed_tokens",
        "model.norm",
    ]
|
|
|