xuqinyang committed
Commit 27a2322 · 1 Parent(s): 89859ea

Upload 5 files
Files changed (5)
  1. __init__.py +1 -0
  2. hf_model.py +102 -0
  3. libfastllm_tools.so +0 -0
  4. llm.py +166 -0
  5. torch2flm.py +89 -0
__init__.py ADDED
@@ -0,0 +1 @@
+ __all__ = ["llm"]
hf_model.py ADDED
@@ -0,0 +1,102 @@
+ import ctypes
+ 
+ import numpy as np
+ import torch
+ 
+ from fastllm_pytools import llm
+ 
+ # fastllm type codes for the supported target dtypes
+ fastllm_data_type_dict = {
+     "int4": 8,
+     "int8": 3,
+     "float16": 7
+ }
+ fastllm_weight_type_dict = {
+     "linear": 1,
+     "embedding": 2
+ }
+ 
+ def create(model,
+            tokenizer = None,
+            pre_prompt = None,
+            user_role = None,
+            bot_role = None,
+            history_sep = None,
+            dtype = "float16"):
+     if dtype not in fastllm_data_type_dict:
+         print("dtype should be one of", list(fastllm_data_type_dict.keys()))
+         exit(1)
+ 
+     # 0.1 model info
+     modelInfo = model.config.__dict__
+     if pre_prompt:
+         modelInfo["pre_prompt"] = pre_prompt
+     if user_role:
+         modelInfo["user_role"] = user_role
+     if bot_role:
+         modelInfo["bot_role"] = bot_role
+     if history_sep:
+         modelInfo["history_sep"] = history_sep
+     if modelInfo["model_type"] == "baichuan" and hasattr(model, "model") and hasattr(model.model, "get_alibi_mask"):
+         # Baichuan 2: uses ALiBi, and chat roles are encoded as special tokens
+         modelInfo["use_alibi"] = "1"
+         modelInfo["pre_prompt"] = ""
+         modelInfo["user_role"] = tokenizer.decode([model.generation_config.user_token_id])
+         modelInfo["bot_role"] = tokenizer.decode([model.generation_config.assistant_token_id])
+         modelInfo["history_sep"] = ""
+ 
+     # map state-dict keys to fastllm weight categories
+     weight_type_dict = {}
+     module_dict = {}
+     for key, m in model.named_modules():
+         if isinstance(m, torch.nn.Linear):
+             weight_type_dict[key + ".weight"] = "linear"
+             module_dict[key + ".weight"] = m
+         if isinstance(m, torch.nn.Embedding):
+             weight_type_dict[key] = "embedding"
+ 
+     model = model.cpu()
+     state_dict = model.state_dict()
+     model_type = model.config.__dict__["model_type"]
+     handle = llm.fastllm_lib.create_empty_llm_model(model_type.encode())
+     for it in modelInfo.keys():
+         llm.fastllm_lib.add_dict_llm_model(handle, str(it).encode(), str(modelInfo[it]).encode())
+ 
+     # 1. vocab
+     if tokenizer:
+         if hasattr(tokenizer, "sp_model"):
+             piece_size = tokenizer.sp_model.piece_size()
+             for i in range(piece_size):
+                 llm.fastllm_lib.add_tokenizer_word_llm_model(handle, tokenizer.sp_model.id_to_piece(i).encode(), i)
+         else:
+             vocab = tokenizer.get_vocab()
+             for v in vocab.keys():
+                 llm.fastllm_lib.add_tokenizer_word_llm_model(handle, v.encode(), vocab[v])
+ 
+     # 2. weights: linear layers are converted to the requested dtype,
+     # everything else stays float32 for now
+     tot = 0
+     for key in state_dict:
+         ori_data_type = 0
+         ori_np_data_type = np.float32
+         cur_weight_type = 0
+         if key in weight_type_dict and weight_type_dict[key] in fastllm_weight_type_dict:
+             cur_weight_type = fastllm_weight_type_dict[weight_type_dict[key]]
+         to_data_type = 0
+ 
+         if cur_weight_type == 1:
+             to_data_type = fastllm_data_type_dict[dtype]
+             if to_data_type == 7:
+                 ori_data_type = 7
+                 ori_np_data_type = np.float16
+         elif cur_weight_type == 2:
+             # TODO bfloat
+             to_data_type = 0
+ 
+         llm.fastllm_lib.add_weight_llm_model(handle, key.encode(),
+                                              len(state_dict[key].shape),
+                                              (ctypes.c_int * len(state_dict[key].shape))(*list(state_dict[key].shape)),
+                                              to_data_type, cur_weight_type, ori_data_type,
+                                              state_dict[key].numpy().astype(ori_np_data_type).ctypes.data_as(ctypes.c_void_p))
+         tot += 1
+         print("convert (", tot, "/", len(state_dict), end = " )\r")
+ 
+     print("")
+     llm.fastllm_lib.init_params_llm_model(handle)
+     llm.fastllm_lib.warmup_llm_model(handle)
+     return llm.model("", id = handle)
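
For reference, hf_model.create is normally reached through llm.from_hf (defined in llm.py below). A minimal usage sketch — the checkpoint name is illustrative, and trust_remote_code is only needed for models that ship custom code:

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from fastllm_pytools import llm

    path = "baichuan-inc/Baichuan2-7B-Chat"  # illustrative checkpoint
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code = True)
    hf = AutoModelForCausalLM.from_pretrained(path, trust_remote_code = True).float()

    # convert in memory; linear weights are quantized to int8, the rest stays float32
    flm = llm.from_hf(hf, tokenizer, dtype = "int8")
    print(flm.response("Hello"))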
libfastllm_tools.so ADDED
Binary file (746 kB)
 
llm.py ADDED
@@ -0,0 +1,166 @@
+ import ctypes
+ import os
+ import platform
+ from typing import List, Tuple
+ 
+ # load the native fastllm tools library that sits next to this file
+ if platform.system() == 'Windows':
+     fastllm_lib = ctypes.cdll.LoadLibrary(os.path.join(os.path.split(os.path.realpath(__file__))[0], "fastllm_tools.dll"))
+ else:
+     fastllm_lib = ctypes.cdll.LoadLibrary(os.path.join(os.path.split(os.path.realpath(__file__))[0], "libfastllm_tools.so"))
+ 
+ fastllm_lib.create_llm_model.argtypes = [ctypes.c_char_p]
+ fastllm_lib.create_llm_model.restype = ctypes.c_int
+ 
+ fastllm_lib.launch_response_llm_model.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_void_p,
+                                                   ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
+                                                   ctypes.c_float, ctypes.c_float]
+ fastllm_lib.launch_response_llm_model.restype = ctypes.c_int
+ 
+ fastllm_lib.fetch_response_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]
+ fastllm_lib.fetch_response_llm_model.restype = ctypes.c_int
+ 
+ fastllm_lib.response_str_llm_model.argtypes = [ctypes.c_int, ctypes.c_char_p,
+                                                ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
+                                                ctypes.c_float, ctypes.c_float]
+ fastllm_lib.response_str_llm_model.restype = ctypes.c_char_p
+ 
+ fastllm_lib.launch_response_str_llm_model.argtypes = [ctypes.c_int, ctypes.c_char_p,
+                                                       ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
+                                                       ctypes.c_float, ctypes.c_float]
+ fastllm_lib.launch_response_str_llm_model.restype = ctypes.c_int
+ 
+ fastllm_lib.fetch_response_str_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]
+ fastllm_lib.fetch_response_str_llm_model.restype = ctypes.c_char_p
+ 
+ fastllm_lib.make_history_llm_model.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p]
+ fastllm_lib.make_history_llm_model.restype = ctypes.c_char_p
+ 
+ fastllm_lib.make_input_llm_model.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p]
+ fastllm_lib.make_input_llm_model.restype = ctypes.c_char_p
+ 
+ def set_cpu_threads(threads: int):
+     fastllm_lib.set_cpu_threads(threads)
+ 
+ def get_cpu_threads() -> int:
+     return fastllm_lib.get_cpu_threads()
+ 
+ def print_ins_info():
+     fastllm_lib.print_cpu_ins()
+ 
+ def set_cpu_kvcache(cpu_kvcache):
+     fastllm_lib.set_kvcache_in_cpu(ctypes.c_bool(cpu_kvcache))
+ 
+ def get_cpu_kvcache():
+     return fastllm_lib.get_kvcache_in_cpu()
+ 
+ def set_cpu_low_mem(low_mem):
+     fastllm_lib.set_cpu_low_mem(ctypes.c_bool(low_mem))
+ 
+ def get_cpu_low_mem():
+     return fastllm_lib.get_cpu_low_mem()
+ 
+ def from_hf(model,
+             tokenizer = None,
+             dtype = "float16"):
+     from fastllm_pytools import hf_model
+     return hf_model.create(model, tokenizer, dtype = dtype)
+ 
+ class model:
+     def __init__(self, path: str,
+                  id: int = -99999):
+         if id != -99999:
+             # wrap an already-created native model handle
+             self.model = id
+         else:
+             self.model = fastllm_lib.create_llm_model(path.encode())
+         self.direct_query = False
+ 
+     def get_prompt(self,
+                    query: str,
+                    history: List[Tuple[str, str]] = None) -> str:
+         if not history:
+             history = []
+         prompt = ""
+         for i, (old_query, response) in enumerate(history):
+             prompt = fastllm_lib.make_history_llm_model(self.model, prompt.encode(), i, old_query.encode(), response.encode()).decode()
+         prompt = fastllm_lib.make_input_llm_model(self.model, prompt.encode(), len(history), query.encode()).decode()
+         return prompt
+ 
+     def save(self, path: str):
+         fastllm_lib.save_llm_model(self.model, path.encode())
+ 
+     def response(self,
+                  query: str,
+                  history: List[Tuple[str, str]] = None,
+                  max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0) -> str:
+         prompt = query if self.direct_query else self.get_prompt(query, history)
+         ret = fastllm_lib.response_str_llm_model(self.model, prompt.encode(),
+                                                  max_length, do_sample, top_p, top_k, temperature, repeat_penalty).decode()
+         return ret
+ 
+     def stream_response(self,
+                         query: str,
+                         history: List[Tuple[str, str]] = None,
+                         max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
+                         one_by_one = True):
+         prompt = query if self.direct_query else self.get_prompt(query, history)
+         handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
+                                                            ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
+                                                            ctypes.c_float(temperature), ctypes.c_float(repeat_penalty))
+         res = ""
+         ret = b''
+         while True:
+             ret += fastllm_lib.fetch_response_str_llm_model(self.model, handle)
+             cur = ""
+             try:
+                 # the bytes may end mid-way through a multi-byte UTF-8 character;
+                 # keep buffering until they decode cleanly
+                 cur = ret.decode()
+                 ret = b''
+             except UnicodeDecodeError:
+                 pass
+             # "<flmeos>" is the native library's end-of-stream sentinel
+             if cur == "<flmeos>":
+                 break
+             if one_by_one:
+                 yield cur
+             else:
+                 res += cur
+                 yield res
+ 
+     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192,
+              do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0, **kwargs):
+         if not history:
+             history = []
+         prompt = query if self.direct_query else self.get_prompt(query, history)
+         input_ids = tokenizer.encode(prompt)
+         handle = fastllm_lib.launch_response_llm_model(self.model, len(input_ids), (ctypes.c_int * len(input_ids))(*input_ids),
+                                                        max_length, do_sample, top_p, top_k, temperature, repeat_penalty)
+         result = []
+         while True:
+             cur = fastllm_lib.fetch_response_llm_model(self.model, handle)
+             if cur == -1:
+                 break
+             result.append(cur)
+         response = tokenizer.decode(result)
+         history = history + [(query, response)]
+         return response, history
+ 
+     def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values = None,
+                     max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
+                     return_past_key_values = False, **kwargs):
+         if not history:
+             history = []
+         prompt = query if self.direct_query else self.get_prompt(query, history)
+         input_ids = tokenizer.encode(prompt)
+         handle = fastllm_lib.launch_response_llm_model(self.model, len(input_ids), (ctypes.c_int * len(input_ids))(*input_ids),
+                                                        max_length, do_sample, top_p, top_k, temperature, repeat_penalty)
+         tokens = []
+         while True:
+             cur = fastllm_lib.fetch_response_llm_model(self.model, handle)
+             if cur == -1:
+                 break
+             tokens.append(cur)
+             response = tokenizer.decode(tokens)
+             new_history = history + [(query, response)]
+             if return_past_key_values:
+                 yield response, new_history, None
+             else:
+                 yield response, new_history
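
A minimal sketch of driving these bindings once a converted model exists on disk (the .flm path and thread count are illustrative):

    from fastllm_pytools import llm

    llm.set_cpu_threads(8)                # size of the native thread pool
    m = llm.model("baichuan2-int8.flm")   # illustrative path to a converted model

    # blocking, single-shot reply
    print(m.response("What is fastllm?"))

    # incremental decoding: pieces are yielded as soon as they decode as valid UTF-8
    for piece in m.stream_response("Tell me a story", one_by_one = True):
        print(piece, end = "", flush = True)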
torch2flm.py ADDED
@@ -0,0 +1,89 @@
+ import struct
+ 
+ import numpy as np
+ 
+ def writeString(fo, s):
+     # length-prefixed UTF-8 string; the length counts bytes, not characters
+     b = s.encode()
+     fo.write(struct.pack('i', len(b)))
+     fo.write(b)
+ 
+ def writeKeyValue(fo, key, value):
+     writeString(fo, key)
+     writeString(fo, value)
+ 
+ def tofile(exportPath,
+            model,
+            tokenizer = None,
+            pre_prompt = None,
+            user_role = None,
+            bot_role = None,
+            history_sep = None):
+     state_dict = model.state_dict()
+     fo = open(exportPath, "wb")
+ 
+     # 0. version id
+     fo.write(struct.pack('i', 2))
+ 
+     # 0.1 model info
+     modelInfo = model.config.__dict__
+     if "model_type" not in modelInfo:
+         print("unknown model_type.")
+         exit(1)
+ 
+     if pre_prompt:
+         modelInfo["pre_prompt"] = pre_prompt
+     if user_role:
+         modelInfo["user_role"] = user_role
+     if bot_role:
+         modelInfo["bot_role"] = bot_role
+     if history_sep:
+         modelInfo["history_sep"] = history_sep
+     if modelInfo["model_type"] == "baichuan" and hasattr(model, "model") and hasattr(model.model, "get_alibi_mask"):
+         # Baichuan 2: uses ALiBi, and chat roles are encoded as special tokens
+         modelInfo["use_alibi"] = "1"
+         modelInfo["pre_prompt"] = ""
+         modelInfo["user_role"] = tokenizer.decode([model.generation_config.user_token_id])
+         modelInfo["bot_role"] = tokenizer.decode([model.generation_config.assistant_token_id])
+         modelInfo["history_sep"] = ""
+ 
+     fo.write(struct.pack('i', len(modelInfo)))
+     for it in modelInfo.keys():
+         writeKeyValue(fo, str(it), str(modelInfo[it]))
+ 
+     # 1. vocab: each entry is (byte count, the bytes as ints, token id)
+     if tokenizer:
+         if hasattr(tokenizer, "sp_model"):
+             piece_size = tokenizer.sp_model.piece_size()
+             fo.write(struct.pack('i', piece_size))
+             for i in range(piece_size):
+                 s = tokenizer.sp_model.id_to_piece(i).encode()
+                 fo.write(struct.pack('i', len(s)))
+                 for c in s:
+                     fo.write(struct.pack('i', c))
+                 fo.write(struct.pack('i', i))
+         else:
+             vocab = tokenizer.get_vocab()
+             fo.write(struct.pack('i', len(vocab)))
+             for v in vocab.keys():
+                 s = v.encode()
+                 fo.write(struct.pack('i', len(s)))
+                 for c in s:
+                     fo.write(struct.pack('i', c))
+                 fo.write(struct.pack('i', vocab[v]))
+     else:
+         fo.write(struct.pack('i', 0))
+ 
+     # 2. weight: name, shape, data type (0 = float32), raw data
+     fo.write(struct.pack('i', len(state_dict)))
+     tot = 0
+     for key in state_dict:
+         cur = state_dict[key].numpy().astype(np.float32)
+         fo.write(struct.pack('i', len(key)))
+         fo.write(key.encode())
+         fo.write(struct.pack('i', len(cur.shape)))
+         for i in cur.shape:
+             fo.write(struct.pack('i', i))
+         fo.write(struct.pack('i', 0))
+         fo.write(cur.data)
+         tot += 1
+         print("output (", tot, "/", len(state_dict), end = " )\r")
+     print("\nfinish.")
+     fo.close()
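
A minimal export sketch, assuming torch2flm.py is importable from the working directory and a Hugging Face checkpoint is already downloaded (the model name is illustrative). Everything is written as float32, so any quantization happens later on the fastllm side:

    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch2flm

    path = "THUDM/chatglm2-6b"  # illustrative checkpoint
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code = True)
    model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code = True).float().cpu()

    # writes version id, config dict, vocab, and float32 weights into one .flm file
    torch2flm.tofile("chatglm2-6b-fp32.flm", model, tokenizer)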