Upload 5 files

- __init__.py +1 -0
- hf_model.py +102 -0
- libfastllm_tools.so +0 -0
- llm.py +166 -0
- torch2flm.py +89 -0
__init__.py
ADDED
@@ -0,0 +1 @@
+__all__ = ["llm"]
hf_model.py
ADDED
@@ -0,0 +1,102 @@
+from fastllm_pytools import llm;
+import torch;
+import ctypes;
+import numpy as np;
+
+fastllm_data_type_dict = {
+    "int4": 8,
+    "int8": 3,
+    "float16": 7
+}
+fastllm_weight_type_dict = {
+    "linear": 1,
+    "embedding": 2
+}
+
+def create(model,
+           tokenizer = None,
+           pre_prompt = None,
+           user_role = None,
+           bot_role = None,
+           history_sep = None,
+           dtype = "float16"):
+    if (dtype not in fastllm_data_type_dict):
+        print("dtype should be in ", list(fastllm_data_type_dict.keys()));
+        exit(0);
+
+    # 0.1 model info
+    modelInfo = model.config.__dict__
+    if (pre_prompt):
+        modelInfo["pre_prompt"] = pre_prompt;
+    if (user_role):
+        modelInfo["user_role"] = user_role;
+    if (bot_role):
+        modelInfo["bot_role"] = bot_role;
+    if (history_sep):
+        modelInfo["history_sep"] = history_sep;
+    if (modelInfo["model_type"] == "baichuan" and hasattr(model, "model") and hasattr(model.model, "get_alibi_mask")):
+        # Baichuan 2
+        modelInfo["use_alibi"] = "1";
+        modelInfo["pre_prompt"] = "";
+        modelInfo["user_role"] = tokenizer.decode([model.generation_config.user_token_id]);
+        modelInfo["bot_role"] = tokenizer.decode([model.generation_config.assistant_token_id]);
+        modelInfo["history_sep"] = "";
+
+    weight_type_dict = {};
+    module_dict = {};
+    for key, m in model.named_modules():
+        if (isinstance(m, torch.nn.Linear)):
+            weight_type_dict[key + ".weight"] = "linear";
+            module_dict[key + ".weight"] = m;
+        if (isinstance(m, torch.nn.Embedding)):
+            weight_type_dict[key] = "embedding";
+
+    model = model.cpu();
+    dict = model.state_dict();
+    model_type = model.config.__dict__["model_type"];
+    model = llm.fastllm_lib.create_empty_llm_model(model_type.encode());
+    for it in modelInfo.keys():
+        llm.fastllm_lib.add_dict_llm_model(model, str(it).encode(), str(modelInfo[it]).encode());
+
+    # 1. vocab
+    if (tokenizer):
+        if (hasattr(tokenizer, "sp_model")):
+            piece_size = tokenizer.sp_model.piece_size();
+            for i in range(piece_size):
+                llm.fastllm_lib.add_tokenizer_word_llm_model(model, tokenizer.sp_model.id_to_piece(i).encode(), i);
+        else:
+            vocab = tokenizer.get_vocab();
+            for v in vocab.keys():
+                llm.fastllm_lib.add_tokenizer_word_llm_model(model, v.encode(), vocab[v]);
+    tot = 0;
+    for key in dict:
+        ori_data_type = 0;
+        ori_np_data_type = np.float32;
+        cur_weight_type = 0;
+        if (key in weight_type_dict and weight_type_dict[key] in fastllm_weight_type_dict):
+            cur_weight_type = fastllm_weight_type_dict[weight_type_dict[key]];
+        to_data_type = 0;
+
+        if (cur_weight_type == 1):
+            to_data_type = fastllm_data_type_dict[dtype];
+            if (to_data_type == 7):
+                ori_data_type = 7;
+                ori_np_data_type = np.float16;
+        elif (cur_weight_type == 2):
+            # TODO bfloat
+            to_data_type = 0;
+
+        llm.fastllm_lib.add_weight_llm_model(model, key.encode(),
+                                             len(dict[key].shape),
+                                             (ctypes.c_int * len(dict[key].shape))(*list(dict[key].shape)),
+                                             to_data_type, cur_weight_type, ori_data_type,
+                                             dict[key].numpy().astype(ori_np_data_type).ctypes.data_as(ctypes.c_void_p));
+        tot += 1;
+        print("convert (", tot, "/", len(dict), end = " )\r");
+
+    print("");
+    llm.fastllm_lib.init_params_llm_model(model);
+    llm.fastllm_lib.warmup_llm_model(model);
+    ret = llm.model("", id = model);
+    return ret;
+
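Taken together, create() copies the HF config into the fastllm model's key/value store, registers the vocabulary word by word, then pushes each weight tensor across the ctypes boundary (linear layers are converted to the requested dtype; everything else stays float32). A minimal usage sketch, assuming transformers is installed; the checkpoint name is illustrative, not part of this commit:

    from transformers import AutoModel, AutoTokenizer
    from fastllm_pytools import llm

    # Illustrative checkpoint; any HF model supported by fastllm works the same way.
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
    hf = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).float()

    # llm.from_hf (defined in llm.py below) forwards to hf_model.create;
    # dtype must be a key of fastllm_data_type_dict: "int4", "int8", or "float16".
    m = llm.from_hf(hf, tokenizer, dtype="float16")
    print(m.response("Hello"))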
libfastllm_tools.so
ADDED
Binary file (746 kB).
llm.py
ADDED
@@ -0,0 +1,166 @@
+import ctypes;
+import os;
+from typing import Optional, Tuple, Union, List, Callable, Dict, Any;
+
+import platform
+if platform.system() == 'Windows':
+    fastllm_lib = ctypes.cdll.LoadLibrary(os.path.join(os.path.split(os.path.realpath(__file__))[0], "fastllm_tools.dll"))
+else:
+    fastllm_lib = ctypes.cdll.LoadLibrary(os.path.join(os.path.split(os.path.realpath(__file__))[0], "libfastllm_tools.so"))
+
+fastllm_lib.create_llm_model.argtypes = [ctypes.c_char_p]
+fastllm_lib.create_llm_model.restype = ctypes.c_int
+
+fastllm_lib.launch_response_llm_model.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_void_p,
+                                                  ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
+                                                  ctypes.c_float, ctypes.c_float]
+fastllm_lib.launch_response_llm_model.restype = ctypes.c_int
+
+fastllm_lib.fetch_response_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]
+fastllm_lib.fetch_response_llm_model.restype = ctypes.c_int
+
+fastllm_lib.response_str_llm_model.argtypes = [ctypes.c_int, ctypes.c_char_p,
+                                               ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
+                                               ctypes.c_float, ctypes.c_float]
+fastllm_lib.response_str_llm_model.restype = ctypes.c_char_p
+
+fastllm_lib.launch_response_str_llm_model.argtypes = [ctypes.c_int, ctypes.c_char_p,
+                                                      ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
+                                                      ctypes.c_float, ctypes.c_float]
+fastllm_lib.launch_response_str_llm_model.restype = ctypes.c_int
+
+fastllm_lib.fetch_response_str_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]
+fastllm_lib.fetch_response_str_llm_model.restype = ctypes.c_char_p
+
+fastllm_lib.make_history_llm_model.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p]
+fastllm_lib.make_history_llm_model.restype = ctypes.c_char_p
+
+fastllm_lib.make_input_llm_model.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p]
+fastllm_lib.make_input_llm_model.restype = ctypes.c_char_p
+
+def set_cpu_threads(threads: int):
+    fastllm_lib.set_cpu_threads(threads);
+
+def get_cpu_threads() -> int:
+    return fastllm_lib.get_cpu_threads();
+
+def print_ins_info():
+    fastllm_lib.print_cpu_ins();
+
+def set_cpu_kvcache(cpu_kvcache):
+    fastllm_lib.set_kvcache_in_cpu(ctypes.c_bool(cpu_kvcache));
+
+def get_cpu_kvcache():
+    return fastllm_lib.get_kvcache_in_cpu();
+
+def set_cpu_low_mem(low_mem):
+    fastllm_lib.set_cpu_low_mem(ctypes.c_bool(low_mem));
+
+def get_cpu_low_mem():
+    return fastllm_lib.get_cpu_low_mem();
+
+def from_hf(model,
+            tokenizer = None,
+            dtype = "float16"):
+    from fastllm_pytools import hf_model;
+    return hf_model.create(model, tokenizer, dtype = dtype);
+
+class model:
+    def __init__ (self, path : str,
+                  id : int = -99999):
+        if (id != -99999):
+            self.model = id;
+        else:
+            self.model = fastllm_lib.create_llm_model(path.encode());
+        self.direct_query = False;
+
+    def get_prompt(self,
+                   query: str,
+                   history: List[Tuple[str, str]] = None) -> str:
+        if (not(history)):
+            history = [];
+        prompt = "";
+        for i, (old_query, response) in enumerate(history):
+            prompt = fastllm_lib.make_history_llm_model(self.model, prompt.encode(), i, old_query.encode(), response.encode()).decode();
+        prompt = fastllm_lib.make_input_llm_model(self.model, prompt.encode(), len(history), query.encode()).decode();
+        return prompt;
+
+    def save(self, path : str):
+        fastllm_lib.save_llm_model(self.model, path.encode());
+
+    def response(self,
+                 query: str,
+                 history: List[Tuple[str, str]] = None,
+                 max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0) -> str:
+        prompt = query if self.direct_query else self.get_prompt(query, history);
+        ret = fastllm_lib.response_str_llm_model(self.model, prompt.encode(),
+                                                 max_length, do_sample, top_p, top_k, temperature, repeat_penalty).decode();
+        return ret;
+
+    def stream_response(self,
+                        query: str,
+                        history: List[Tuple[str, str]] = None,
+                        max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
+                        one_by_one = True):
+        prompt = query if self.direct_query else self.get_prompt(query, history);
+        handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
+                                                           ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
+                                                           ctypes.c_float(temperature), ctypes.c_float(repeat_penalty));
+        res = "";
+        ret = b'';
+        while True:
+            ret += fastllm_lib.fetch_response_str_llm_model(self.model, handle);
+            cur = "";
+            try:
+                cur = ret.decode();
+                ret = b'';
+            except:
+                pass;
+            if (cur == "<flmeos>"):
+                break;
+            if one_by_one:
+                yield cur;
+            else:
+                res += cur;
+                yield res;
+
+    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192,
+             do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0, **kwargs):
+        if (not(history)):
+            history = [];
+        prompt = query if self.direct_query else self.get_prompt(query, history);
+        input = tokenizer.encode(prompt);
+        handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
+                                                       max_length, do_sample, top_p, top_k, temperature, repeat_penalty);
+
+        result = [];
+        while True:
+            cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
+            if (cur == -1):
+                break;
+            result.append(cur);
+        response = tokenizer.decode(result);
+        history = history + [(query, response)];
+        return response, history;
+
+    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values = None,
+                    max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
+                    return_past_key_values = False, **kwargs) -> str:
+        if (not(history)):
+            history = [];
+        prompt = query if self.direct_query else self.get_prompt(query, history);
+        input = tokenizer.encode(prompt);
+        handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
+                                                       max_length, do_sample, top_p, top_k, temperature, repeat_penalty);
+        tokens = [];
+        while True:
+            cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
+            if (cur == -1):
+                break;
+            tokens.append(cur);
+        response = tokenizer.decode(tokens);
+        new_history = history + [(query, response)];
+        if return_past_key_values:
+            yield response, new_history, None;
+        else:
+            yield response, new_history;
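The model class wraps an integer handle returned by the shared library; stream_response launches a generation job and polls fetch_response_str_llm_model until the "<flmeos>" sentinel arrives. A minimal sketch, assuming a previously exported .flm file (the file name is a placeholder; model.save or torch2flm.tofile produce such files):

    from fastllm_pytools import llm

    llm.set_cpu_threads(8)              # optional: size the CPU thread pool
    m = llm.model("chatglm2-6b.flm")    # opens the file via create_llm_model

    # Chunks are decoded incrementally; partial UTF-8 byte sequences are
    # buffered (the try/except above) until they form complete characters.
    for chunk in m.stream_response("Hello"):
        print(chunk, end="", flush=True)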
torch2flm.py
ADDED
@@ -0,0 +1,89 @@
+import struct
+import numpy as np
+
+def writeString(fo, s):
+    fo.write(struct.pack('i', len(s.encode())));
+    fo.write(s.encode());
+
+def writeKeyValue(fo, key, value):
+    writeString(fo, key);
+    writeString(fo, value);
+
+def tofile(exportPath,
+           model,
+           tokenizer = None,
+           pre_prompt = None,
+           user_role = None,
+           bot_role = None,
+           history_sep = None):
+    dict = model.state_dict();
+    fo = open(exportPath, "wb");
+
+    # 0. version id
+    fo.write(struct.pack('i', 2));
+
+    # 0.1 model info
+    modelInfo = model.config.__dict__
+    if ("model_type" not in modelInfo):
+        print("unknown model_type.");
+        exit(0);
+
+    if (pre_prompt):
+        modelInfo["pre_prompt"] = pre_prompt;
+    if (user_role):
+        modelInfo["user_role"] = user_role;
+    if (bot_role):
+        modelInfo["bot_role"] = bot_role;
+    if (history_sep):
+        modelInfo["history_sep"] = history_sep;
+    if (modelInfo["model_type"] == "baichuan" and hasattr(model, "model") and hasattr(model.model, "get_alibi_mask")):
+        # Baichuan 2
+        modelInfo["use_alibi"] = "1";
+        modelInfo["pre_prompt"] = "";
+        modelInfo["user_role"] = tokenizer.decode([model.generation_config.user_token_id]);
+        modelInfo["bot_role"] = tokenizer.decode([model.generation_config.assistant_token_id]);
+        modelInfo["history_sep"] = "";
+
+    fo.write(struct.pack('i', len(modelInfo)));
+    for it in modelInfo.keys():
+        writeKeyValue(fo, str(it), str(modelInfo[it]));
+
+    # 1. vocab
+    if (tokenizer):
+        if (hasattr(tokenizer, "sp_model")):
+            piece_size = tokenizer.sp_model.piece_size();
+            fo.write(struct.pack('i', piece_size));
+            for i in range(piece_size):
+                s = tokenizer.sp_model.id_to_piece(i).encode();
+                fo.write(struct.pack('i', len(s)));
+                for c in s:
+                    fo.write(struct.pack('i', c));
+                fo.write(struct.pack('i', i));
+        else:
+            vocab = tokenizer.get_vocab();
+            fo.write(struct.pack('i', len(vocab)));
+            for v in vocab.keys():
+                s = v.encode();
+                fo.write(struct.pack('i', len(s)));
+                for c in s:
+                    fo.write(struct.pack('i', c));
+                fo.write(struct.pack('i', vocab[v]));
+    else:
+        fo.write(struct.pack('i', 0));
+
+    # 2. weight
+    fo.write(struct.pack('i', len(dict)));
+    tot = 0;
+    for key in dict:
+        cur = dict[key].numpy().astype(np.float32);
+        fo.write(struct.pack('i', len(key)));
+        fo.write(key.encode());
+        fo.write(struct.pack('i', len(cur.shape)));
+        for i in cur.shape:
+            fo.write(struct.pack('i', i));
+        fo.write(struct.pack('i', 0));
+        fo.write(cur.data);
+        tot += 1;
+        print("output (", tot, "/", len(dict), end = " )\r");
+    print("\nfinish.");
+    fo.close();
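Unlike hf_model.create, which converts in memory, tofile serializes straight to disk in the .flm layout: version id, model-info key/value pairs, vocabulary, then each tensor as raw float32 (the 0 written before cur.data is the data-type tag). A minimal export sketch under the same assumptions as above (illustrative checkpoint name):

    from transformers import AutoModel, AutoTokenizer
    from fastllm_pytools import torch2flm

    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
    hf = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).float()

    # Weights are stored as float32 on disk; the resulting file can then be
    # opened with llm.model(), which reads it via create_llm_model.
    torch2flm.tofile("chatglm2-6b-fp32.flm", hf, tokenizer)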