# Copied verbatim from vortex
# Copyright (c) 2024, Michael Poli.
from dataclasses import dataclass
import torch
import sys
import numpy as np
from .sample import sample
from .tokenizer import CharLevelTokenizer
from .utils import print_rank_0
class Generator:
def __init__(self, model, tokenizer, top_k=50, top_p=0.7, temperature=1):
self.model = model
self.tokenizer = tokenizer
self.top_k = top_k
self.top_p = top_p
self.temperature = temperature
        self.untils = ["\n\n"]  # stop strings used to truncate verbose output
def generate(
self,
        device: str,
        input_string: str | None = None,
        input_ids: torch.Tensor | None = None,
        num_tokens: int = 32,
        cached_generation: bool = True,
        force_prompt_threshold: int | None = None,
        max_seqlen: int | None = None,
        print_generation: bool = True,
        verbose: bool = False,
        skip_special_tokens: bool = False,
        stop_at_eos: bool = True,
        inference_params_dict: dict | None = None,
        token_callback=lambda i: None,
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
"""
Generates using the model with optional cached sampling replay.
This method enables passing in and returning the `inference_params_dict` for
replaying cached sampling from a given state, for example for beam search.
Args:
device: The device to run the model on.
input_string: The input prompt to generate from.
input_ids: The input prompt token ids to generate from.
num_tokens: The number of tokens to generate.
            cached_generation: Whether to use cached generation. Defaults to True.
            force_prompt_threshold: Number of tokens to prefill in parallel before
                switching to prompt forcing. Used to reduce peak memory usage and
                support longer prompts. Defaults to None.
            max_seqlen: Maximum sequence length to generate. Determines the max size
                of the cache if larger. Otherwise automatically determined using
                prompt length + num_tokens. Defaults to None.
            print_generation: Whether to print generated tokens. Defaults to True.
            verbose: Whether to print verbose output. Defaults to False.
            skip_special_tokens: Whether to skip special tokens. Defaults to False.
                (Currently unused by this implementation.)
            stop_at_eos: Whether to stop generation at the EOS token. Defaults to True.
            inference_params_dict: Dictionary of inference parameters to use for
                replaying cached sampling. Defaults to None.
            token_callback: Optional callback called with the step index after each
                token is generated. Defaults to a no-op.
        Returns:
            A tuple of (generated token ids, per-step logits, inference parameters
            dictionary). The returned dictionary can be used to replay the exact
            same sampling sequence.
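
        Example (illustrative sketch; assumes a loaded model and tokenizer):
            >>> generator = Generator(model, tokenizer, top_k=4, temperature=1.0)
            >>> tokens, logits, cache = generator.generate(
            ...     device="cuda:0",
            ...     input_string="ACGT",
            ...     num_tokens=16,
            ... )
            >>> tokens.shape  # (batch_size, num_generated_tokens)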
"""
if isinstance(self.tokenizer.eos, int):
eos_token_ids = torch.LongTensor([self.tokenizer.eos]).to(device)
else:
eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)
if input_ids is None:
input = self.tokenizer.tokenize(input_string)
if isinstance(input, list):
input = torch.LongTensor(input).unsqueeze(0).to(device)
else:
input = input.unsqueeze(0).to(device)
else:
input = input_ids
x = input
if max_seqlen is not None:
x = x[:, -max_seqlen:]
num_tokens = int(num_tokens)
batch_size = x.shape[0]
prompt_length = x.shape[1]
        prompt_forcing = (
            inference_params_dict is None
            and force_prompt_threshold is not None
            and prompt_length > force_prompt_threshold
        )
if prompt_forcing:
forced_prompt_length = prompt_length - force_prompt_threshold
x_force = x[:, force_prompt_threshold:]
x = x[:, :force_prompt_threshold]
else:
forced_prompt_length = 0
        tot_length = prompt_length + num_tokens
        if max_seqlen is not None:
            tot_length = max(tot_length, max_seqlen)
generation = torch.empty(
x.shape[0],
num_tokens,
dtype=torch.long,
device=x.device,
)
scores = torch.empty(
x.shape[0],
num_tokens,
self.tokenizer.vocab_size,
dtype=torch.float,
device=x.device,
)
if inference_params_dict is not None:
cached_generation = True
prefilled = True
            # Ensure that all cached state tensors live on the input's device.
            if any(data.device != x.device for data in inference_params_dict["hcl"].fir_state_dict.values()):
                state_attrs = {
                    "mha": ("key_value_memory_dict",),
                    "hcl": ("fir_state_dict", "state_dict"),
                    "hcm": ("fir_inner_state_dict", "fir_state_dict", "state_dict"),
                    "hcs": ("fir_state_dict", "fir_inner_state_dict", "state_dict"),
                }
                for params_key, attrs in state_attrs.items():
                    for attr in attrs:
                        cache = getattr(inference_params_dict[params_key], attr)
                        for key, data in cache.items():
                            cache[key] = data.to(x.device)
inference_params_dict["mha"].max_batch_size = batch_size
elif cached_generation:
inference_params_dict = self.model.initialize_inference_params(max_seqlen=tot_length)
inference_params_dict["mha"].max_batch_size = batch_size
prefilled = False
else:
inference_params_dict = None
prefilled = False
if verbose:
mem_after_tok = torch.cuda.memory_allocated(device=x.device) / 1e9
print_rank_0(f"Memory after tokenization: {mem_after_tok} GB")
print_rank_0("Starting generation...")
if input_string is not None:
print_rank_0("Prompt: " + input_string)
else:
print_rank_0(f"Prompt ids: {input_ids} {input_ids.shape}")
        i = 0  # keep `i` defined for the return slice even if the loop never runs
for i in range(forced_prompt_length + num_tokens):
post_prefill = prefilled or (cached_generation and i > 0)
# prefill then process only the last token
if post_prefill:
x = x[:, -1:]
                seqlen_offset = inference_params_dict["mha"].seqlen_offset
                if seqlen_offset == 0:
                    # First decode step: the offset equals the number of prefilled tokens.
                    seqlen_offset = force_prompt_threshold if prompt_forcing else input.shape[-1]
                    for key in ("mha", "hcl", "hcm", "hcs"):
                        inference_params_dict[key].seqlen_offset = seqlen_offset
                else:
                    for key in ("mha", "hcl", "hcm", "hcs"):
                        inference_params_dict[key].seqlen_offset += 1
# do forward pass with no gradient
with torch.inference_mode():
logits, inference_params_dict = self.model(
x,
inference_params_dict=inference_params_dict,
)
token_callback(i)
last_logits = logits[:, -1]
if prompt_forcing and i < forced_prompt_length:
new_idx = x_force[:, i]
else:
new_idx = sample(
last_logits,
top_k=self.top_k,
top_p=self.top_p,
temperature=self.temperature,
)
            if print_generation and verbose and batch_size == 1:
                print(
                    f"{self.tokenizer.detokenize([new_idx.item()])}",
                    end=" ",
                    flush=True,
                )

            if prompt_forcing:
                if i >= forced_prompt_length:
                    scores[:, i - forced_prompt_length] = last_logits
                    generation[:, i - forced_prompt_length] = new_idx
            else:
                scores[:, i] = last_logits
                generation[:, i] = new_idx

            # Stop once a sampled (non-forced) token matches EOS, checked on the
            # first batch element.
            if (
                stop_at_eos
                and not (prompt_forcing and i < forced_prompt_length)
                and (new_idx[0] == eos_token_ids).all()
            ):
                print("Stopping generation at EOS")
                break

            if post_prefill:
                x = new_idx[:, None]
            else:
                x = torch.cat([x, new_idx[:, None]], dim=-1)
        if verbose:
            # Report the first sequence in the batch, truncated at any stop string.
            y = self.tokenizer.detokenize_batch(generation[:, : i + 1])[0]
            for until in self.untils:
                if until in y:
                    y = y.split(until)[0]
                    break
            print(f"\nInput: {input_string}, Output: {y}")

            mem_end = torch.cuda.memory_allocated(device=x.device) / 1e9
            print(f"Memory after generation: {mem_end} GB")
        # forced_prompt_length is 0 unless prompt forcing was used.
        num_generated = i + 1 - forced_prompt_length
        return generation[:, :num_generated], scores[:, :num_generated], inference_params_dict


def logits_to_logprobs(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Gather the log-probability of each sampled token from its logits."""
    logprobs = torch.log_softmax(logits, dim=-1)
    return torch.gather(logprobs, -1, tokens.unsqueeze(-1)).squeeze(-1)
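
# Illustrative check of logits_to_logprobs (shapes only; values are random):
#   logits = torch.randn(2, 8, 512)           # (batch, seq, vocab)
#   tokens = torch.randint(512, (2, 8))       # (batch, seq)
#   logits_to_logprobs(logits, tokens).shape  # -> torch.Size([2, 8])
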
def prepare_batch(
seqs: list[str], tokenizer: CharLevelTokenizer, prepend_bos: bool = False, device: str = "cuda:0"
) -> tuple[torch.Tensor, list[int]]:
"""Prepare a batch of sequences for the model."""
if prepend_bos:
seqs = [tokenizer.bos + seq for seq in seqs]
tokens = [tokenizer.tokenize(seq) for seq in seqs]
if isinstance(tokens[0], list):
tokens = [torch.tensor(t, dtype=torch.long) for t in tokens]
max_len = max(len(t) for t in tokens)
batch = torch.zeros((len(tokens), max_len), dtype=torch.long)
for i, t in enumerate(tokens):
batch[i, : len(t)] = t
return batch.to(device), [len(t) for t in tokens]
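
# Example (sketch; assumes a loaded CharLevelTokenizer and a CUDA device):
#   batch, lengths = prepare_batch(["ACGT", "AC"], tokenizer, device="cuda:0")
#   batch.shape  # -> torch.Size([2, 4]); the shorter row is zero-padded
#   lengths      # -> [4, 2]
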
@dataclass(kw_only=True)
class GenerationOutput:
    sequences: list[str]  # generated strings, one per prompt
    logits: list[torch.Tensor]  # one logits tensor per generation call, shape (batch, n_tokens, vocab)
    logprobs_mean: list[float]  # mean per-token log-probability of each generated sequence
def generate(
*,
prompt_seqs: list[str],
model,
tokenizer: CharLevelTokenizer,
n_tokens: int = 100,
temperature: float = 0.0,
top_k: int = 1,
top_p: float = 1.0,
batched: bool = True,
prepend_bos: bool = False,
force_prompt_threshold: int = 1000,
cached_generation: bool = True,
verbose: int = 1,
device: str = "cuda:0",
**kwargs,
) -> GenerationOutput:
"""
Performs generation from a list of prompts.
If all prompts are the same length, this can do batched generation.
Also supports cached generation for efficient sampling.
"""
model.eval()
g = Generator(
model,
tokenizer,
top_k=top_k,
top_p=top_p,
temperature=temperature,
)
uniform_lengths = all(len(s) == len(prompt_seqs[0]) for s in prompt_seqs)
if batched and uniform_lengths:
input_ids_list = [
prepare_batch(
prompt_seqs,
tokenizer,
prepend_bos=prepend_bos,
device=device,
)[0]
]
else:
sys.stderr.write("WARNING: Batched generation is turned off.\n")
input_ids_list = [
prepare_batch(
[prompt_seq],
tokenizer,
prepend_bos=prepend_bos,
device=device,
)[0]
for prompt_seq in prompt_seqs
]
generated_seqs, generated_scores, logitss = [], [], []
for input_ids in input_ids_list:
batch_size = input_ids.shape[0]
output_ids, logits, _ = g.generate(
input_ids=input_ids,
num_tokens=n_tokens,
device=device,
print_generation=(verbose > 1),
verbose=(verbose > 1),
stop_at_eos=False,
force_prompt_threshold=force_prompt_threshold,
cached_generation=cached_generation,
**kwargs,
)
if verbose > 1:
print("input_ids.shape", input_ids.shape)
print("output_ids.shape", output_ids.shape)
print("logits.shape", logits.shape)
generated_seqs_batch = list(tokenizer.detokenize_batch(output_ids))
assert len(generated_seqs_batch) == batch_size
generated_seqs += generated_seqs_batch
logitss.append(logits)
logprobs = logits_to_logprobs(logits, output_ids)
logprobs = logprobs.float().cpu().numpy()
generated_scores += [np.mean(logprobs[idx]) for idx in range(batch_size)]
assert len(generated_seqs) == len(generated_scores) == len(prompt_seqs)
if verbose:
for seq, score, prompt in zip(generated_seqs, generated_scores, prompt_seqs):
print(f'Prompt: "{prompt}",\tOutput: "{seq}",\tScore: {score}')
return GenerationOutput(
sequences=generated_seqs,
logits=logitss,
logprobs_mean=generated_scores,
)
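

# Example usage (sketch; assumes `model` and a CharLevelTokenizer are already
# loaded and a CUDA device is available):
#   out = generate(
#       prompt_seqs=["ACGT", "TTAA"],
#       model=model,
#       tokenizer=tokenizer,
#       n_tokens=64,
#       temperature=1.0,
#       top_k=4,
#       device="cuda:0",
#   )
#   print(out.sequences)       # generated strings, one per prompt
#   print(out.logprobs_mean)   # mean per-token log-probability of each sample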