K00B404 committed

Commit 780f5a3 · Parent: 93ebb51

Files changed (1)
  1. generation.py +310 -0

generation.py ADDED
@@ -0,0 +1,310 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import json
import os
import sys
import time
from pathlib import Path
from typing import List, Literal, Optional, Tuple, TypedDict

import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
    get_model_parallel_rank,
    initialize_model_parallel,
    model_parallel_is_initialized,
)

from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer

Role = Literal["system", "user", "assistant"]


class Message(TypedDict):
    role: Role
    content: str


class CompletionPrediction(TypedDict, total=False):
    generation: str
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


class ChatPrediction(TypedDict, total=False):
    generation: Message
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


Dialog = List[Message]

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>"]
UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."

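# Illustration: with these tags, chat_completion (below) assembles each dialog
# roughly as
#   <s> [INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user 1} [/INST] {assistant 1} </s>
#   <s> [INST] {user 2} [/INST]
# where <s>/</s> stand for the tokenizer's BOS/EOS tokens added per encoded turn.
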
class Llama:
    @staticmethod
    def build(
        ckpt_dir: str,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
        model_parallel_size: Optional[int] = None,
    ) -> "Llama":
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group("nccl")
        if not model_parallel_is_initialized():
            if model_parallel_size is None:
                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
            initialize_model_parallel(model_parallel_size)

        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)

        # seed must be the same in all processes
        torch.manual_seed(1)

        # suppress stdout on non-master ranks
        if local_rank > 0:
            sys.stdout = open(os.devnull, "w")

        start_time = time.time()
        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
        assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
        assert model_parallel_size == len(
            checkpoints
        ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
        # each model-parallel rank loads its own checkpoint shard
        ckpt_path = checkpoints[get_model_parallel_rank()]
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())

        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            **params,
        )
        tokenizer = Tokenizer(model_path=tokenizer_path)
        model_args.vocab_size = tokenizer.n_words
        torch.set_default_tensor_type(torch.cuda.HalfTensor)
        model = Transformer(model_args)
        model.load_state_dict(checkpoint, strict=False)
        print(f"Loaded in {time.time() - start_time:.2f} seconds")

        return Llama(model, tokenizer)

    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    @torch.inference_mode()
    def generate(
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        logprobs: bool = False,
        echo: bool = False,
    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
        params = self.model.params
        bsz = len(prompt_tokens)
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)
        assert max_prompt_len <= params.max_seq_len
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

        prev_pos = 0
        eos_reached = torch.tensor([False] * bsz, device="cuda")
        input_text_mask = tokens != pad_id
        # autoregressive decoding, one position at a time, starting after the shortest prompt
        for cur_pos in range(min_prompt_len, total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if logprobs:
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
                )
            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            # a sequence is finished once it emits EOS outside its prompt region
            eos_reached |= (~input_text_mask[:, cur_pos]) & (
                next_token == self.tokenizer.eos_id
            )
            prev_pos = cur_pos
            if all(eos_reached):
                break

        if logprobs:
            token_logprobs = token_logprobs.tolist()
        out_tokens, out_logprobs = [], []
        for i, toks in enumerate(tokens.tolist()):
            # cut to max gen len
            start = 0 if echo else len(prompt_tokens[i])
            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
            probs = None
            if logprobs:
                probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
            # cut to eos tok if any
            if self.tokenizer.eos_id in toks:
                eos_idx = toks.index(self.tokenizer.eos_id)
                toks = toks[:eos_idx]
                probs = probs[:eos_idx] if logprobs else None
            out_tokens.append(toks)
            out_logprobs.append(probs)
        return (out_tokens, out_logprobs if logprobs else None)

    def text_completion(
        self,
        prompts: List[str],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        echo: bool = False,
    ) -> List[CompletionPrediction]:
        if max_gen_len is None:
            max_gen_len = self.model.params.max_seq_len - 1
        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=echo,
        )
        if logprobs:
            return [
                {
                    "generation": self.tokenizer.decode(t),
                    "tokens": [self.tokenizer.decode(x) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]
        return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]

    def chat_completion(
        self,
        dialogs: List[Dialog],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
    ) -> List[ChatPrediction]:
        if max_gen_len is None:
            max_gen_len = self.model.params.max_seq_len - 1
        prompt_tokens = []
        unsafe_requests = []
        for dialog in dialogs:
            # flag dialogs containing special control tags; their outputs are replaced with an error message
            unsafe_requests.append(
                any([tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog])
            )
            # fold a leading system message into the first user turn
            if dialog[0]["role"] == "system":
                dialog = [
                    {
                        "role": dialog[1]["role"],
                        "content": B_SYS
                        + dialog[0]["content"]
                        + E_SYS
                        + dialog[1]["content"],
                    }
                ] + dialog[2:]
            assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
                [msg["role"] == "assistant" for msg in dialog[1::2]]
            ), (
                "model only supports 'system', 'user' and 'assistant' roles, "
                "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
            )
            dialog_tokens: List[int] = sum(
                [
                    self.tokenizer.encode(
                        f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
                        bos=True,
                        eos=True,
                    )
                    for prompt, answer in zip(
                        dialog[::2],
                        dialog[1::2],
                    )
                ],
                [],
            )
            assert (
                dialog[-1]["role"] == "user"
            ), f"Last message must be from user, got {dialog[-1]['role']}"
            dialog_tokens += self.tokenizer.encode(
                f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
                bos=True,
                eos=False,
            )
            prompt_tokens.append(dialog_tokens)

        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
        )
        if logprobs:
            return [
                {
                    "generation": {
                        "role": "assistant",
                        "content": self.tokenizer.decode(t)
                        if not unsafe
                        else UNSAFE_ERROR,
                    },
                    "tokens": [self.tokenizer.decode(x) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i, unsafe in zip(
                    generation_tokens, generation_logprobs, unsafe_requests
                )
            ]
        return [
            {
                "generation": {
                    "role": "assistant",
                    "content": self.tokenizer.decode(t) if not unsafe else UNSAFE_ERROR,
                }
            }
            for t, unsafe in zip(generation_tokens, unsafe_requests)
        ]

def sample_top_p(probs, p):
    # nucleus (top-p) sampling: keep the smallest set of tokens whose
    # cumulative probability exceeds p, renormalize, and sample from it
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p  # tokens outside the nucleus
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)  # map back to vocab ids
    return next_token
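
For context, a minimal usage sketch of the class added here, assuming a single-GPU setup, placeholder checkpoint and tokenizer paths, and a launch via torchrun (as in the Llama 2 reference example scripts):

# example_usage.py -- illustrative sketch only; paths, sizes, and the dialog are placeholders.
# Launch with: torchrun --nproc_per_node 1 example_usage.py
from generation import Llama  # the module added in this commit

generator = Llama.build(
    ckpt_dir="llama-2-7b-chat/",       # placeholder checkpoint directory
    tokenizer_path="tokenizer.model",  # placeholder tokenizer path
    max_seq_len=512,
    max_batch_size=4,
)

dialogs = [
    [{"role": "user", "content": "What is the capital of France?"}],
]
results = generator.chat_completion(dialogs, temperature=0.6, top_p=0.9)
for result in results:
    print(result["generation"]["content"])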