# Copyright (c) Together
# This software is distributed under the terms of the Apache License, Version 2.0
# Author: Michael Poli
# Barebones generation class for standalone inference.
import torch

from stripedhyena.sample import sample
from stripedhyena.tokenizer import CharLevelTokenizer
from stripedhyena.utils import print_rank_0


class Generator:
    def __init__(self, model, tokenizer, top_k=50, top_p=0.7, temperature=1.0):
        self.model = model
        self.tokenizer = tokenizer
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        # stop strings: verbose decoded output is truncated at the first match
        self.untils = ["\n\n"]

    def generate(
        self,
        device,
        input_string=None,
        input_ids=None,
        num_tokens=32,
        cached_generation=False,
        print_generation=True,
        verbose=False,
        skip_special_tokens=False,
        stop_at_eos=True,
        max_seqlen=None,
    ):
        # Resolve the EOS token(s): either a single integer id or a string
        # that tokenizes to a tensor of ids.
        if isinstance(self.tokenizer.eos, int):
            eos_token_ids = torch.LongTensor([self.tokenizer.eos]).to(device)
        else:
            # eos is a string; tokenize it to a tensor of ids
            eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)

        # Build the prompt tensor of shape (batch_size, prompt_len).
        if input_ids is None:
            input = self.tokenizer.tokenize(input_string)
            if isinstance(input, list):
                input = torch.LongTensor(input).unsqueeze(0).to(device)
            else:
                # already a tensor
                input = input.unsqueeze(0).to(device)
        else:
            input = input_ids

        x = input
        if max_seqlen is not None:
            # keep only the most recent max_seqlen prompt tokens
            x = x[:, -max_seqlen:]

        prompt_len = x.shape[-1]
        num_tokens = int(num_tokens)
        tot_length = prompt_len + num_tokens
        batch_size = x.shape[0]

        # Preallocate buffers for sampled token ids and per-step logits.
        generation = torch.empty(
            x.shape[0],
            num_tokens,
            dtype=torch.long,
            device=x.device,
        )
        scores = torch.empty(
            x.shape[0],
            num_tokens,
            self.tokenizer.vocab_size,
            dtype=torch.float,
            device=x.device,
        )

        if cached_generation:
            # Allocate attention and Hyena filter caches so that, after the
            # prefill pass, each step only processes the newest token.
            inference_params_dict_out = self.model.initialize_inference_params()
            inference_params_dict_out["mha"].max_batch_size = batch_size
            inference_params_dict_out["hyena"].max_batch_size = batch_size
        else:
            inference_params_dict_out = None

        if verbose:
            mem_after_tok = torch.cuda.memory_allocated(device=x.device) / 1e9
            print_rank_0(f"Memory after tokenization: {mem_after_tok} GB")
            print_rank_0("Starting generation...")
            if input_string is not None:
                print_rank_0("Prompt: " + input_string)
            else:
                print_rank_0(f"Prompt ids: {input_ids} {input_ids.shape}")

        for i in range(num_tokens):
            post_prefill = cached_generation and i > 0
            # After prefill, feed only the last sampled token and advance the
            # cached sequence-length offsets by one.
            if post_prefill:
                x = x[:, -1:]
                seqlen_offset = inference_params_dict_out["mha"].seqlen_offset
                if seqlen_offset == 0:
                    # first decode step: the offset starts at the prompt length
                    seqlen_offset = input.shape[-1]
                    inference_params_dict_out["hyena"].seqlen_offset = seqlen_offset
                    inference_params_dict_out["mha"].seqlen_offset = seqlen_offset
                else:
                    inference_params_dict_out["mha"].seqlen_offset += 1
                    inference_params_dict_out["hyena"].seqlen_offset += 1

            # forward pass with no gradient
            with torch.no_grad():
                logits, inference_params_dict_out = self.model(
                    x,
                    inference_params_dict=inference_params_dict_out,
                )

            last_logits = logits[:, -1]
            new_idx = sample(
                last_logits,
                top_k=self.top_k,
                top_p=self.top_p,
                temperature=self.temperature,
            )

            scores[:, i] = last_logits
            generation[:, i] = new_idx

            if print_generation and verbose and batch_size == 1:
                print_rank_0(
                    f"{self.tokenizer.detokenize([new_idx.item()])}",
                    end=" ",
                )

            # Stop once the most recent tokens match the EOS sequence. The
            # check runs after the write above so that uninitialized buffer
            # positions are never compared, and it actually breaks the loop.
            n_eos = eos_token_ids.numel()
            if (
                stop_at_eos
                and i + 1 >= n_eos
                and (generation[0, i + 1 - n_eos : i + 1] == eos_token_ids).all()
            ):
                print_rank_0("Stopping generation at EOS")
                break

            if post_prefill:
                x = new_idx[:, None]
            else:
                # without caching, re-feed the growing full sequence each step
                x = torch.cat([x, new_idx[:, None]], dim=-1)

        if verbose:
            kwargs = {}
            if not isinstance(self.tokenizer, CharLevelTokenizer):
                kwargs["skip_special_tokens"] = skip_special_tokens
            y = self.tokenizer.detokenize_batch(generation[:, : i + 1], **kwargs)
            if isinstance(y, list):
                # detokenize_batch may return one string per batch element
                y = y[0]
            # truncate the printed output at the first stop string
            for until in self.untils:
                if until in y:
                    y = y.split(until)[0]
                    break
            print_rank_0(f"\nInput: {input_string}, Output: {y}")
            mem_end = torch.cuda.memory_allocated(device=x.device) / 1e9
            print_rank_0(f"Memory after generation: {mem_end} GB")

        return generation[:, : i + 1], scores[:, : i + 1]
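

# --- Usage sketch (illustration only; not part of the original module). ---
# A minimal, hedged example of driving Generator end to end on the uncached
# path. DummyModel and DummyTokenizer are hypothetical stand-ins exposing just
# enough of the interface generate() relies on; a real run would pass a
# StripedHyena model and its tokenizer instead.
if __name__ == "__main__":

    class DummyTokenizer:
        # hypothetical stub: byte-level ids with an integer EOS id
        eos = 0
        vocab_size = 256

        def tokenize(self, text):
            return [ord(c) % 256 for c in text]

        def detokenize(self, ids):
            return "".join(chr(i) for i in ids)

        def detokenize_batch(self, ids, skip_special_tokens=False):
            return [self.detokenize(row.tolist()) for row in ids]

    class DummyModel(torch.nn.Module):
        # hypothetical stub: returns random logits of shape
        # (batch, seqlen, vocab); a real model would also update and
        # return its inference cache dict here
        def forward(self, x, inference_params_dict=None):
            logits = torch.randn(x.shape[0], x.shape[1], 256, device=x.device)
            return logits, inference_params_dict

    device = torch.device("cpu")
    generator = Generator(DummyModel(), DummyTokenizer(), top_k=50, top_p=0.7)
    tokens, scores = generator.generate(
        device,
        input_string="hello",
        num_tokens=16,
        cached_generation=False,  # no cache: the full sequence is re-fed each step
        verbose=False,            # the verbose path queries CUDA memory stats
    )
    print(tokens.shape, scores.shape)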