from dataclasses import dataclass
import sys

import numpy as np
import torch

from .sample import sample
from .tokenizer import CharLevelTokenizer
from .utils import print_rank_0


class Generator:
    def __init__(self, model, tokenizer, top_k=50, top_p=0.7, temperature=1):
        self.model = model
        self.tokenizer = tokenizer
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        self.untils = ["\n\n"]

    def generate(
        self,
        device: str,
        input_string: str = None,
        input_ids: torch.Tensor = None,
        num_tokens: int = 32,
        cached_generation: bool = True,
        force_prompt_threshold: int = None,
        max_seqlen: int = None,
        print_generation: bool = True,
        verbose: bool = False,
        skip_special_tokens: bool = False,
        stop_at_eos: bool = True,
        inference_params_dict: dict = None,
        token_callback=lambda i: None,
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
""" |
|
|
Generates using the model with optional cached sampling replay. |
|
|
|
|
|
This method enables passing in and returning the `inference_params_dict` for |
|
|
replaying cached sampling from a given state, for example for beam search. |
|
|
|
|
|
Args: |
|
|
device: The device to run the model on. |
|
|
input_string: The input prompt to generate from. |
|
|
input_ids: The input prompt token ids to generate from. |
|
|
num_tokens: The number of tokens to generate. |
|
|
cached_generation: Whether to use cached generation. Defaults to False. |
|
|
force_prompt_threshold: Number of tokens to prefill in parallel before |
|
|
switching to prompt forcing. Used to reduce peak memory usage and |
|
|
support longer prompts. Defaults to None. |
|
|
max_seqlen: Maximum sequence length to generate. Determines the max size |
|
|
of the cache if larger. Otherwise automatically determined using |
|
|
prompt length + max_tokens. Defaults to None. |
|
|
print_generation: Whether to print generated tokens. Defaults to False. |
|
|
verbose: Whether to print verbose output. Defaults to False. |
|
|
skip_special_tokens: Whether to skip special tokens. Defaults to True. |
|
|
stop_at_eos: Whether to stop generation at EOS token. Defaults to True. |
|
|
inference_params_dict: Dictionary of inference parameters to use for |
|
|
replaying cached sampling. Defaults to None. |
|
|
token_callback: Optional callback function called after each token is |
|
|
generated. Defaults to None. |
|
|
|
|
|
Returns: |
|
|
dict: The inference parameters dictionary used for generation, which can |
|
|
be used to replay the exact same sampling sequence. |
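
        Example (illustrative; `model` and `tokenizer` are placeholders for a
        compatible model and tokenizer, not objects defined in this module):
            >>> generator = Generator(model, tokenizer, top_k=4, temperature=1.0)
            >>> tokens, logits, cache = generator.generate(
            ...     device="cuda:0",
            ...     input_string="ACGT",
            ...     num_tokens=16,
            ... )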
""" |
|
|
        if isinstance(self.tokenizer.eos, int):
            eos_token_ids = torch.LongTensor([self.tokenizer.eos]).to(device)
        else:
            eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)

        if input_ids is None:
            input = self.tokenizer.tokenize(input_string)
            if isinstance(input, list):
                input = torch.LongTensor(input).unsqueeze(0).to(device)
            else:
                input = input.unsqueeze(0).to(device)
        else:
            input = input_ids
        x = input

        if max_seqlen is not None:
            x = x[:, -max_seqlen:]

        num_tokens = int(num_tokens)
        batch_size = x.shape[0]

        # Prompt forcing: prefill only the first `force_prompt_threshold` prompt
        # tokens in parallel, then feed the remaining prompt tokens one at a time
        # to reduce peak memory usage.
        prompt_length = x.shape[1]
        prompt_forcing = (
            inference_params_dict is None
            and force_prompt_threshold is not None
            and prompt_length > force_prompt_threshold
        )
        if prompt_forcing:
            forced_prompt_length = prompt_length - force_prompt_threshold
            x_force = x[:, force_prompt_threshold:]
            x = x[:, :force_prompt_threshold]
        else:
            forced_prompt_length = 0

        tot_length = prompt_length + num_tokens
        if max_seqlen is not None and max_seqlen > tot_length:
            tot_length = max_seqlen

        # Preallocate buffers for the generated token ids and their per-step logits.
        generation = torch.empty(
            x.shape[0],
            num_tokens,
            dtype=torch.long,
            device=x.device,
        )

        scores = torch.empty(
            x.shape[0],
            num_tokens,
            self.tokenizer.vocab_size,
            dtype=torch.float,
            device=x.device,
        )

        if inference_params_dict is not None:
            # A caller-provided cache implies we are replaying cached sampling from
            # a previously prefilled state.
            cached_generation = True
            prefilled = True

            # Move any cached state tensors that live on a different device onto the
            # input's device before reuse.
            if any(data.device != x.device for data in inference_params_dict["hcl"].fir_state_dict.values()):
                for key, data in inference_params_dict["mha"].key_value_memory_dict.items():
                    inference_params_dict["mha"].key_value_memory_dict[key] = data.to(x.device)
                for key, data in inference_params_dict["hcl"].fir_state_dict.items():
                    inference_params_dict["hcl"].fir_state_dict[key] = data.to(x.device)
                for key, data in inference_params_dict["hcl"].state_dict.items():
                    inference_params_dict["hcl"].state_dict[key] = data.to(x.device)
                for key, data in inference_params_dict["hcm"].fir_inner_state_dict.items():
                    inference_params_dict["hcm"].fir_inner_state_dict[key] = data.to(x.device)
                for key, data in inference_params_dict["hcm"].fir_state_dict.items():
                    inference_params_dict["hcm"].fir_state_dict[key] = data.to(x.device)
                for key, data in inference_params_dict["hcm"].state_dict.items():
                    inference_params_dict["hcm"].state_dict[key] = data.to(x.device)
                for key, data in inference_params_dict["hcs"].fir_state_dict.items():
                    inference_params_dict["hcs"].fir_state_dict[key] = data.to(x.device)
                for key, data in inference_params_dict["hcs"].fir_inner_state_dict.items():
                    inference_params_dict["hcs"].fir_inner_state_dict[key] = data.to(x.device)
                for key, data in inference_params_dict["hcs"].state_dict.items():
                    inference_params_dict["hcs"].state_dict[key] = data.to(x.device)
            inference_params_dict["mha"].max_batch_size = batch_size
        elif cached_generation:
            inference_params_dict = self.model.initialize_inference_params(max_seqlen=tot_length)
            inference_params_dict["mha"].max_batch_size = batch_size
            prefilled = False
        else:
            inference_params_dict = None
            prefilled = False

        if verbose:
            mem_after_tok = torch.cuda.memory_allocated(device=x.device) / 1e9
            print_rank_0(f"Memory after tokenization: {mem_after_tok} GB")
            print_rank_0("Starting generation...")
            if input_string is not None:
                print_rank_0("Prompt: " + input_string)
            else:
                print_rank_0(f"Prompt ids: {input_ids} {input_ids.shape}")

        i = 0
        for i in range(forced_prompt_length + num_tokens):
            post_prefill = prefilled or (cached_generation and i > 0)

            if post_prefill:
                # After prefill, only the newest token is fed to the model; keep the
                # sequence-length offsets of all cache groups in sync.
                x = x[:, -1:]
                seqlen_offset = inference_params_dict["mha"].seqlen_offset

                if seqlen_offset == 0:
                    if prompt_forcing:
                        seqlen_offset = force_prompt_threshold
                    else:
                        seqlen_offset = input.shape[-1]
                    inference_params_dict["mha"].seqlen_offset = seqlen_offset
                    inference_params_dict["hcl"].seqlen_offset = seqlen_offset
                    inference_params_dict["hcm"].seqlen_offset = seqlen_offset
                    inference_params_dict["hcs"].seqlen_offset = seqlen_offset
                else:
                    inference_params_dict["mha"].seqlen_offset += 1
                    inference_params_dict["hcl"].seqlen_offset += 1
                    inference_params_dict["hcm"].seqlen_offset += 1
                    inference_params_dict["hcs"].seqlen_offset += 1

            with torch.inference_mode():
                logits, inference_params_dict = self.model(
                    x,
                    inference_params_dict=inference_params_dict,
                )

            token_callback(i)

            last_logits = logits[:, -1]

            # While forcing the prompt, the next token comes from the prompt itself;
            # otherwise it is sampled from the logits of the last position.
            if prompt_forcing and i < forced_prompt_length:
                new_idx = x_force[:, i]
            else:
                new_idx = sample(
                    last_logits,
                    top_k=self.top_k,
                    top_p=self.top_p,
                    temperature=self.temperature,
                )

            # Note: this only logs; the loop still runs for the full number of steps.
            if stop_at_eos and (generation[0, -1:] == eos_token_ids).all():
                print("Stopping generation at EOS")

            if print_generation and verbose and batch_size == 1:
                print(
                    f"{self.tokenizer.detokenize([new_idx.item()])}",
                    end=" ",
                    flush=True,
                )

            if prompt_forcing:
                # Forced prompt tokens are not recorded; outputs are indexed relative
                # to the end of the forced prompt.
                if i >= forced_prompt_length:
                    scores[:, i - forced_prompt_length] = last_logits
                    generation[:, i - forced_prompt_length] = new_idx
            else:
                scores[:, i] = last_logits
                generation[:, i] = new_idx

            if post_prefill:
                x = new_idx[:, None]
            else:
                x = torch.cat([x, new_idx[:, None]], dim=-1)

        if verbose:
            y = self.tokenizer.detokenize_batch(generation[:, : i + 1])

            for until in self.untils:
                if until in y:
                    y = y.split(until)[0]
                    break

            print(f"\nInput: {input_string}, Output: {y}")

            mem_end = torch.cuda.memory_allocated(device=x.device) / 1e9
            print(f"Memory after generation: {mem_end} GB")

        return generation[:, : i + 1], scores[:, : i + 1], inference_params_dict


def logits_to_logprobs(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Gather the log-probabilities of the given tokens from the logits."""
    probs = torch.log_softmax(logits, dim=-1)
    return torch.gather(probs, -1, tokens.unsqueeze(-1)).squeeze(-1)


def prepare_batch(
    seqs: list[str], tokenizer: CharLevelTokenizer, prepend_bos: bool = False, device: str = "cuda:0"
) -> tuple[torch.Tensor, list[int]]:
    """Prepare a batch of sequences for the model."""
    if prepend_bos:
        seqs = [tokenizer.bos + seq for seq in seqs]

    tokens = [tokenizer.tokenize(seq) for seq in seqs]
    if isinstance(tokens[0], list):
        tokens = [torch.tensor(t, dtype=torch.long) for t in tokens]

    # Right-pad with zeros to the length of the longest sequence in the batch.
    max_len = max(len(t) for t in tokens)
    batch = torch.zeros((len(tokens), max_len), dtype=torch.long)

    for i, t in enumerate(tokens):
        batch[i, : len(t)] = t

    return batch.to(device), [len(t) for t in tokens]
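
# Illustrative use of prepare_batch (a sketch, not executed at import time; assumes
# `tokenizer` is a CharLevelTokenizer instance and a CUDA device is available):
#
#   batch, lengths = prepare_batch(["ACGT", "AC"], tokenizer, device="cuda:0")
#   batch    -> LongTensor of shape (2, max_len), right-padded with zeros
#   lengths  -> number of tokens in each input sequence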


@dataclass(kw_only=True)
class GenerationOutput:
    sequences: list[str]
    logits: list[torch.Tensor]
    logprobs_mean: list[float]


def generate(
    *,
    prompt_seqs: list[str],
    model,
    tokenizer: CharLevelTokenizer,
    n_tokens: int = 100,
    temperature: float = 0.0,
    top_k: int = 1,
    top_p: float = 1.0,
    batched: bool = True,
    prepend_bos: bool = False,
    force_prompt_threshold: int = 1000,
    cached_generation: bool = True,
    verbose: int = 1,
    device: str = "cuda:0",
    **kwargs,
) -> GenerationOutput:
    """
    Performs generation from a list of prompts.

    If `batched` is True and all prompts have the same length, they are generated
    together in a single batch; otherwise each prompt is generated separately.
    Also supports cached generation for efficient sampling.
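
    Example (illustrative; `model` and `tokenizer` are placeholders for a compatible
    model and CharLevelTokenizer instance):
        >>> out = generate(
        ...     prompt_seqs=["ACGT", "TTAA"],
        ...     model=model,
        ...     tokenizer=tokenizer,
        ...     n_tokens=32,
        ...     temperature=0.0,
        ...     top_k=1,
        ... )
        >>> out.sequences      # one generated string per prompt
        >>> out.logprobs_mean  # mean log-probability of each generated sequence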
""" |
|
|
    model.eval()

    g = Generator(
        model,
        tokenizer,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
    )

    # When prompts share a length, tokenize them together as one padded batch;
    # otherwise fall back to generating one prompt at a time.
    uniform_lengths = all(len(s) == len(prompt_seqs[0]) for s in prompt_seqs)

    if batched and uniform_lengths:
        input_ids_list = [
            prepare_batch(
                prompt_seqs,
                tokenizer,
                prepend_bos=prepend_bos,
                device=device,
            )[0]
        ]
    else:
        sys.stderr.write("WARNING: Batched generation is turned off.\n")
        input_ids_list = [
            prepare_batch(
                [prompt_seq],
                tokenizer,
                prepend_bos=prepend_bos,
                device=device,
            )[0]
            for prompt_seq in prompt_seqs
        ]

    generated_seqs, generated_scores, logitss = [], [], []
    for input_ids in input_ids_list:
        batch_size = input_ids.shape[0]

        output_ids, logits, _ = g.generate(
            input_ids=input_ids,
            num_tokens=n_tokens,
            device=device,
            print_generation=(verbose > 1),
            verbose=(verbose > 1),
            stop_at_eos=False,
            force_prompt_threshold=force_prompt_threshold,
            cached_generation=cached_generation,
            **kwargs,
        )

        if verbose > 1:
            print("input_ids.shape", input_ids.shape)
            print("output_ids.shape", output_ids.shape)
            print("logits.shape", logits.shape)

        generated_seqs_batch = list(tokenizer.detokenize_batch(output_ids))
        assert len(generated_seqs_batch) == batch_size
        generated_seqs += generated_seqs_batch
        logitss.append(logits)

        # The mean log-probability of the generated tokens is used as the
        # per-sequence score.
        logprobs = logits_to_logprobs(logits, output_ids)
        logprobs = logprobs.float().cpu().numpy()

        generated_scores += [np.mean(logprobs[idx]) for idx in range(batch_size)]

    assert len(generated_seqs) == len(generated_scores) == len(prompt_seqs)
    if verbose:
        for seq, score, prompt in zip(generated_seqs, generated_scores, prompt_seqs):
            print(f'Prompt: "{prompt}",\tOutput: "{seq}",\tScore: {score}')

    return GenerationOutput(
        sequences=generated_seqs,
        logits=logitss,
        logprobs_mean=generated_scores,
    )
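

if __name__ == "__main__":
    # Minimal smoke test for logits_to_logprobs with random tensors (a sketch with
    # illustrative shapes; no model or tokenizer is required for this check).
    example_logits = torch.randn(2, 5, 512)          # (batch, seq, vocab)
    example_tokens = torch.randint(0, 512, (2, 5))   # (batch, seq)
    example_logprobs = logits_to_logprobs(example_logits, example_tokens)
    print("logits_to_logprobs output shape:", tuple(example_logprobs.shape))  # (2, 5)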