NeMo / nemo /collections /nlp /modules /common /text_generation_utils.py
camenduru's picture
thanks to NVIDIA ❤
7934b29
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for generating text."""
from collections.abc import Iterable
import numpy as np
import torch
import torch.nn.functional as F
from nemo.collections.common.tokenizers.tabular_tokenizer import TabularTokenizer
from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids
from nemo.collections.nlp.modules.common.text_generation_strategy import model_inference_strategy_dispatcher
from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, OutputType, SamplingParam
from nemo.utils import AppState
try:
from apex.transformer import parallel_state, tensor_parallel
from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator
HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
HAVE_APEX = False
# Names exported by ``from <this module> import *``; helpers not listed here
# are still importable by name but are not part of the declared public API.
__all__ = [
    "get_default_sampling_params",
    "get_default_length_params",
    "megatron_gpt_generate",
    "get_computeprob_response",
    "generate",
    "sample_token_greedy",
    "sample_token_topk",
]
def get_default_sampling_params():
    """Return a new dict holding the default sampling settings.

    The defaults select greedy decoding (``use_greedy`` is True) with a
    neutral temperature, no top-k/top-p filtering and no repetition penalty.
    A fresh dict is built on every call, so callers may modify the result.
    """
    defaults: SamplingParam = dict(
        use_greedy=True,
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        repetition_penalty=1.0,
        add_BOS=True,
        all_probs=False,
        compute_logprob=False,
    )
    return defaults
def get_default_length_params():
    """Return a new dict holding the default generation-length limits.

    ``min_length`` defaults to 0 and ``max_length`` to 30. A fresh dict is
    built on every call, so callers may modify the result.
    """
    defaults: LengthParam = dict(min_length=0, max_length=30)
    return defaults
def megatron_gpt_generate(model, inputs, tokenizer, length_params, sampling_params, **strategy_args):
    """Generate text, or score the given inputs, through ``generate``.

    Args:
        model: model object forwarded to ``generate``.
        inputs: for plain generation, a list/tuple whose first element is a
            ``str`` or a ``torch.Tensor``. In compute-logprob mode the value
            is forwarded without a type check.
        tokenizer: tokenizer used to build the compute-logprob response.
        length_params: mapping with the keys ``min_length`` and ``max_length``.
        sampling_params: mapping with the keys ``use_greedy``, ``temperature``,
            ``top_k``, ``top_p``, ``repetition_penalty``, ``add_BOS``,
            ``all_probs`` and ``compute_logprob``.
        **strategy_args: forwarded unchanged to ``generate``.

    Returns:
        The output of ``generate``; when ``sampling_params['compute_logprob']``
        is truthy, that output post-processed by ``get_computeprob_response``.

    Raises:
        NotImplementedError: in plain generation mode, when ``inputs`` is not
            a list/tuple of strings or tensors.
    """
    compute_logprob = sampling_params['compute_logprob']
    if compute_logprob:
        # Reproduces the old compute_prob method (a very special case).
        # Work on copies so the caller's dicts are never modified.
        sampling_params = sampling_params.copy()
        length_params = length_params.copy()
        length_params['max_length'] = 1
        sampling_params['all_probs'] = True
        sampling_params['add_BOS'] = False
        # The key consumed by the generate() call below is 'use_greedy';
        # writing to 'greedy' had no effect on decoding.
        sampling_params['use_greedy'] = True
    elif not isinstance(inputs, (list, tuple)):
        raise NotImplementedError("unknown type is not implemented")
    elif isinstance(inputs[0], dict):
        raise NotImplementedError("json object not implemented")
    elif not isinstance(inputs[0], (str, torch.Tensor)):
        raise NotImplementedError("unknown type is not implemented")

    response = generate(
        model,
        inputs=inputs,
        tokens_to_generate=length_params['max_length'],
        all_probs=sampling_params['all_probs'],
        temperature=sampling_params['temperature'],
        add_BOS=sampling_params['add_BOS'],
        top_k=sampling_params['top_k'],
        top_p=sampling_params['top_p'],
        greedy=sampling_params['use_greedy'],
        repetition_penalty=sampling_params['repetition_penalty'],
        min_tokens_to_generate=length_params['min_length'],
        **strategy_args,
    )
    if compute_logprob:
        return get_computeprob_response(tokenizer, response, inputs)
    return response
def get_computeprob_response(tokenizer, response, inputs):
    """Trim a generation response down to the original inputs.

    For every batch entry the response lists are cut to the number of input
    tokens, so only the scores of the provided inputs remain.

    Args:
        tokenizer: object providing ``text_to_ids`` (used for string inputs)
            and ``ids_to_text`` (used for tensor inputs).
        response: dict with the keys ``tokens``, ``logprob``, ``full_logprob``
            and ``offsets``, each indexable by batch id.
        inputs: either a list/tuple of prompt strings, or a pair
            ``(tokens_tensor, lengths_tensor)`` whose first element is a
            ``torch.Tensor``.

    Returns:
        dict with the keys ``sentences``, ``tokens``, ``token_ids``,
        ``logprob``, ``full_logprob`` and ``offsets`` (one entry per batch
        item; ``offsets`` has its last element dropped).

    Raises:
        NotImplementedError: if the response is non-empty and ``inputs`` is
            neither of the supported forms.
    """
    # The input kind is the same for every batch entry, so detect it once.
    is_sequence = isinstance(inputs, (list, tuple)) and len(inputs) > 0
    text_inputs = is_sequence and isinstance(inputs[0], str)
    tensor_inputs = is_sequence and isinstance(inputs[0], torch.Tensor)

    new_token_ids = []
    new_tokens = []
    new_texts = []
    log_probs = []
    full_logprobs = []
    offsets = []
    for batch_id, batch_tokens in enumerate(response['tokens']):
        if text_inputs:
            new_text = inputs[batch_id]
            new_token_id = tokenizer.text_to_ids(new_text)
            token_len = len(new_token_id)
        elif tensor_inputs:
            token_len = int(inputs[1][batch_id].item())
            new_token_id = inputs[0][batch_id][:token_len].tolist()
            new_text = tokenizer.ids_to_text(new_token_id)
        else:
            # Fail explicitly instead of hitting unbound (or stale) locals.
            raise NotImplementedError("unknown type is not implemented")
        new_token_ids.append(new_token_id)
        new_tokens.append(batch_tokens[:token_len])
        new_texts.append(new_text)
        log_probs.append(response['logprob'][batch_id][:token_len])
        full_logprobs.append(response['full_logprob'][batch_id][:token_len])
        offsets.append(response['offsets'][batch_id][:-1])

    return {
        'sentences': new_texts,
        'tokens': new_tokens,
        'token_ids': new_token_ids,
        'logprob': log_probs,
        'full_logprob': full_logprobs,
        'offsets': offsets,
    }
def get_batch(model, tokenizer, context_tokens):
    """Build the model inputs for a batch of context tokens.

    The tokens are made contiguous and moved to the GPU, then passed to
    ``get_ltor_masks_and_position_ids`` together with the tokenizer's EOS id
    and three flags read from ``model.cfg`` (each defaulting to False).

    Returns:
        tuple ``(tokens, attention_mask, position_ids)``.
    """
    tokens = context_tokens.contiguous().cuda()

    eos_id = tokenizer.eos_id
    reset_position_ids = model.cfg.get('reset_position_ids', False)
    reset_attention_mask = model.cfg.get('reset_attention_mask', False)
    eod_mask_loss = model.cfg.get('eod_mask_loss', False)

    # The helper returns three values; the middle one is not needed here.
    attention_mask, _unused, position_ids = get_ltor_masks_and_position_ids(
        tokens, eos_id, reset_position_ids, reset_attention_mask, eod_mask_loss,
    )
    return tokens, attention_mask, position_ids
def tab_logits(logits, min_id, max_id, filter_value=-float('Inf')):
    """Restrict logits to the id range ``[min_id, max_id)`` in place.

    Every column below ``min_id`` and every column from ``max_id`` onwards is
    overwritten with ``filter_value``. The input is modified and returned.
    """
    below_range = slice(None, min_id)
    above_range = slice(max_id, None)
    for excluded in (below_range, above_range):
        logits[:, excluded] = filter_value
    return logits
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Apply top-k and/or nucleus (top-p) filtering to logits, in place.

    This function has been mostly taken from the Hugging Face conversational
    AI code at
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313

    Args:
        logits: tensor whose last dimension is the vocabulary; it is modified
            in place and also returned.
        top_k: if > 0, keep only the ``top_k`` largest logits per row. Values
            larger than the vocabulary dimension are clamped to it.
        top_p: if > 0.0, keep the smallest set of highest-probability tokens
            whose cumulative probability exceeds ``top_p`` (the first token
            crossing the threshold is kept).
        filter_value: value written over the removed logits.

    Returns:
        The filtered ``logits`` tensor (same object as the input).
    """
    if top_k > 0:
        # torch.topk raises when k exceeds the size of the dimension, so
        # clamp to the vocabulary size (nothing is filtered in that case).
        top_k = min(top_k, logits.size(-1))
        # Remove all tokens with a logit below the k-th largest one.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask to the right so the first token above the threshold
        # is kept as well; the most likely token is always kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order for all
        # rows at once instead of looping over the batch in Python.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
def repetition_penalty(logits, repetition_penalty, used_tokens):
    """Penalize the logits of tokens that were already used.

    Implements the repetition penalty from https://arxiv.org/pdf/1909.05858.pdf

    Args:
        logits: 2-D tensor indexed as [row, vocabulary id]; not modified.
        repetition_penalty: penalty factor; 1.0 disables the penalty.
        used_tokens: integer tensor of token ids per row (indexed along
            dimension 1 of ``logits``), or None to disable the penalty.

    Returns:
        ``logits`` itself when the penalty is disabled, otherwise a new tensor
        in which the logits of the used tokens have been penalized.
    """
    if used_tokens is not None and repetition_penalty != 1.0:
        used_logits = torch.gather(logits, 1, used_tokens)
        # Dividing a negative logit by a penalty > 1 would move it towards
        # zero and make the token MORE likely, so negative logits are
        # multiplied by the penalty instead.
        penalized = torch.where(
            used_logits < 0, used_logits * repetition_penalty, used_logits / repetition_penalty
        )
        logits = torch.scatter(logits, 1, used_tokens, penalized)
    return logits
def get_model_parallel_src_rank():
    """Return the smallest global rank of this rank's model parallel group.

    All global ranks are arranged in a grid of shape
    [pipeline parallel, data parallel, tensor parallel]; the result is the
    minimum rank in the slice belonging to the current data parallel rank.
    """
    world_size = torch.distributed.get_world_size()
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    pp_size = parallel_state.get_pipeline_model_parallel_world_size()

    # Axes: [pipeline dim, data parallel, tensor dim]
    rank_grid = np.arange(world_size).reshape(pp_size, -1, tp_size)

    dp_rank = parallel_state.get_data_parallel_rank()
    own_group_ranks = rank_grid[:, dp_rank, :]
    return own_group_ranks.min()
def send_generate_info(
    context_tokens_tensor,
    context_length_tensor,
    tokens_to_generate,
    all_probs,
    temperature,
    top_k,
    top_p,
    greedy,
    repetition_penalty,
    min_tokens_to_generate,
):
    """Broadcast the generation request from the source rank.

    Three broadcasts are issued over the model parallel group, in this order:
    a float header with ten entries (batch size, sequence length and the
    eight scalar settings), the context lengths, and the context tokens.

    Needs to be kept in sync with ``receive_generate_info``.
    """
    group = parallel_state.get_model_parallel_group()
    src_rank = get_model_parallel_src_rank()

    # The two tensor sizes lead the header so receivers can allocate buffers.
    batch_size = context_tokens_tensor.size(0)
    seq_len = context_tokens_tensor.size(1)
    header = torch.cuda.FloatTensor(
        [
            batch_size,
            seq_len,
            tokens_to_generate,
            all_probs,
            temperature,
            top_k,
            top_p,
            greedy,
            repetition_penalty,
            min_tokens_to_generate,
        ]
    )
    torch.distributed.broadcast(header, src_rank, group)

    # Payload tensors follow the header: lengths first, then tokens.
    for payload in (context_length_tensor, context_tokens_tensor):
        torch.distributed.broadcast(payload, src_rank, group)
def receive_generate_info():
    """Receive the generation request broadcast by the source rank.

    Mirrors ``send_generate_info``: first a float header with ten entries is
    received and decoded, then int64 buffers are allocated for the context
    lengths and the context tokens and filled by two further broadcasts.

    Needs to be kept in sync with ``send_generate_info``.

    Returns:
        tuple ``(context_length_tensor, context_tokens_tensor,
        tokens_to_generate, all_probs, temperature, top_k, top_p, greedy,
        repetition_penalty, min_tokens_to_generate)``.
    """
    group = parallel_state.get_model_parallel_group()
    src_rank = get_model_parallel_src_rank()
    device = torch.cuda.current_device()

    header = torch.empty(10, dtype=torch.float32, device=device)
    torch.distributed.broadcast(header, src_rank, group)

    # Decode the header; every field travels as a float32.
    fields = header.tolist()
    batch_size = int(fields[0])
    seq_len = int(fields[1])
    tokens_to_generate = int(fields[2])
    all_probs = bool(fields[3])
    temperature = float(fields[4])
    top_k = int(fields[5])
    top_p = float(fields[6])
    greedy = bool(fields[7])
    repetition_penalty = float(fields[8])
    min_tokens_to_generate = int(fields[9])

    # Receive the payload tensors: lengths first, then tokens.
    context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=device)
    context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=device)
    torch.distributed.broadcast(context_length_tensor, src_rank, group)
    torch.distributed.broadcast(context_tokens_tensor, src_rank, group)

    return (
        context_length_tensor,
        context_tokens_tensor,
        tokens_to_generate,
        all_probs,
        temperature,
        top_k,
        top_p,
        greedy,
        repetition_penalty,
        min_tokens_to_generate,
    )
def synced_generate(
    model,
    inference_strategy,
    context_tokens_tensor,
    context_length_tensor,
    tokens_to_generate,
    all_probs,
    temperature,
    top_k=0,
    top_p=0.0,
    greedy=False,
    repetition_penalty=1.2,
    min_tokens_to_generate=0,
):
    """Run the batch sampler to completion and share its log-probabilities.

    The sampling generator (``tab_sample_sequence_batch`` for a
    ``TabularTokenizer``, ``sample_sequence_batch`` otherwise) is exhausted,
    then the per-token log-probabilities (and, if ``all_probs`` is set, the
    full log-probabilities) are broadcast from the last pipeline stage over
    the embedding group; a first pipeline stage that is not also the last
    stage allocates buffers and receives them.

    Note that ``top_k``, ``top_p``, ``greedy``, ``repetition_penalty`` and
    ``min_tokens_to_generate`` are only forwarded on the non-tabular path.

    Returns:
        ``(tokens[:, :context_length], output_logits, full_logits)`` when the
        sampler produced tokens on this rank, otherwise ``None`` (implicit).
    """
    # Start from the shortest prompt; incremented once per generator step below.
    context_length = context_length_tensor.min().item()
    tokenizer = model.tokenizer
    if isinstance(tokenizer, TabularTokenizer):
        batch_token_iterator = tab_sample_sequence_batch(
            model,
            inference_strategy,
            context_tokens_tensor,
            context_length_tensor,
            tokens_to_generate,
            all_probs,
            temperature=temperature,
        )
    else:
        batch_token_iterator = sample_sequence_batch(
            model,
            inference_strategy,
            context_tokens_tensor,
            context_length_tensor,
            tokens_to_generate,
            all_probs,
            temperature=temperature,
            extra={
                "top_p": top_p,
                "top_k": top_k,
                "greedy": greedy,
                "repetition_penalty": repetition_penalty,
                "min_tokens_to_generate": min_tokens_to_generate,
            },
        )

    # Drain the generator; only the values of the final step are used afterwards.
    # NOTE(review): if the generator yields nothing, `tokens`, `output_logits`
    # and `full_logits` are never bound and the code below raises NameError.
    for tokens, lengths, output_logits, full_logits in batch_token_iterator:
        context_length += 1

    if parallel_state.is_pipeline_last_stage():
        # Last stage owns the log-probabilities: send them over the embedding group.
        src = parallel_state.get_pipeline_model_parallel_last_rank()
        group = parallel_state.get_embedding_group()
        torch.distributed.broadcast(output_logits, src, group)
        if all_probs:
            src = parallel_state.get_pipeline_model_parallel_last_rank()
            group = parallel_state.get_embedding_group()
            torch.distributed.broadcast(full_logits, src, group)

    else:
        if parallel_state.is_pipeline_first_stage():
            # First stage (when distinct from the last one): allocate receive
            # buffers covering positions 1..context_length-1 and receive.
            src = parallel_state.get_pipeline_model_parallel_last_rank()
            group = parallel_state.get_embedding_group()
            output_logits = torch.empty(
                tokens.size(0), context_length - 1, dtype=torch.float32, device=torch.device("cuda")
            )
            torch.distributed.broadcast(output_logits, src, group)

            if all_probs:
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_embedding_group()
                full_logits = torch.empty(
                    tokens.size(0),
                    context_length - 1,
                    model.padded_vocab_size,
                    dtype=torch.float32,
                    device=torch.device("cuda"),
                )
                torch.distributed.broadcast(full_logits, src, group)
    # Ranks whose generator yielded `None` tokens fall through and return None.
    if tokens is not None:
        return tokens[:, :context_length], output_logits, full_logits
def generate(
    model,
    inputs=None,
    tokens_to_generate=0,
    all_probs=False,
    temperature=1.0,
    add_BOS=False,
    top_k=0,
    top_p=0.0,
    greedy=False,
    repetition_penalty=1.0,
    min_tokens_to_generate=0,
    **strategy_args,
) -> OutputType:
    """Generate text for a batch of prompts and decode the result.

    Args:
        model (NLPModel): text generative model
        inputs (Union[tuple, List[str]]): if it is a tuple, it is assumed to be
            (context_tokens_tensor, context_length_tensor). Otherwise it is a
            list of prompt text strings
        tokens_to_generate (int): The maximum length of the tokens to be generated.
        all_probs (bool): Return the log prob for all the tokens
        temperature (float): sampling temperature
        add_BOS (bool): add the bos token at the beginning of the prompt
        top_k (int): The number of highest probability vocabulary tokens to keep for top-k-filtering.
        top_p (float): If set to float < 1, only the most probable tokens with
            probabilities that add up to top_p or higher are kept for generation.
        greedy (bool): Greedy-decoding flag forwarded to the sampler; when it is
            False the sampler draws tokens by sampling.
        repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty
        min_tokens_to_generate (int): The minimum length of the tokens to be generated
        strategy_args: the extra arguments are treated as inference strategy
            arguments; a ready-made strategy may be passed as ``strategy``.

    Returns:
        OutputType: the output as a dictionary (passed through the strategy's
        ``post_generation_process``), or None on ranks for which
        ``synced_generate`` returns None. It has the following keys:
            sentences: List[str], output sentences
            tokens: List[List[str]], output sentences broken into tokens
            logprob: List[Tensor], log prob of generated tokens
            full_logprob: List[Tensor], log prob of all the tokens in the vocab
            token_ids: List[Tensor], output sentence token ids
            offsets: List[List[int]], list of tokens start positions in text
    """
    inference_strategy = (
        strategy_args['strategy']
        if 'strategy' in strategy_args
        else model_inference_strategy_dispatcher(model, **strategy_args)
    )
    tokenizer = model.tokenizer

    if torch.distributed.get_rank() != get_model_parallel_src_rank():
        # Every other rank gets the request from the source rank.
        (
            context_length_tensor,
            context_tokens_tensor,
            tokens_to_generate,
            all_probs,
            temperature,
            top_k,
            top_p,
            greedy,
            repetition_penalty,
            min_tokens_to_generate,
        ) = receive_generate_info()
    else:
        # The source rank prepares the prompt tensors and broadcasts the request.
        if isinstance(inputs, tuple):
            context_tokens_tensor, context_length_tensor = inputs
        else:
            context_tokens_tensor, context_length_tensor = inference_strategy.tokenize_batch(
                inputs, tokens_to_generate, add_BOS
            )
        send_generate_info(
            context_tokens_tensor,
            context_length_tensor,
            tokens_to_generate,
            all_probs,
            temperature,
            top_k,
            top_p,
            greedy,
            repetition_penalty,
            min_tokens_to_generate,
        )

    output = synced_generate(
        model,
        inference_strategy,
        context_tokens_tensor,
        context_length_tensor,
        tokens_to_generate,
        all_probs,
        temperature,
        top_k=top_k,
        top_p=top_p,
        greedy=greedy,
        repetition_penalty=repetition_penalty,
        min_tokens_to_generate=min_tokens_to_generate,
    )

    # Collect whichever special tokens the tokenizer defines; they take up no
    # width in the offsets computed below.
    special_tokens = set()
    for attr_name in (
        'pad_token',
        'eos_token',
        'bos_token',
        'cls_token',
        'unk_token',
        'sep_token',
        'mask_token',
    ):
        special = getattr(tokenizer, attr_name, None)
        if special is not None:
            special_tokens.add(special)

    if output is None:
        return output

    decode_tokens, output_logits, full_logits = output
    decode_tokens = decode_tokens.cpu().numpy().tolist()
    is_tabular = isinstance(tokenizer, TabularTokenizer)

    resp_sentences = []
    resp_sentences_seg = []
    for token_ids in decode_tokens:
        sentence = tokenizer.ids_to_text(token_ids)
        resp_sentences.append(sentence)
        if is_tabular:
            resp_sentences_seg.append(tokenizer.text_to_tokens(sentence))
            continue

        words = []
        for token in token_ids:
            wrapped = token if isinstance(token, Iterable) else [token]
            word = tokenizer.ids_to_tokens(wrapped)
            if isinstance(word, Iterable):
                word = word[0]
            if hasattr(tokenizer.tokenizer, 'byte_decoder'):
                # Map the byte-level token characters back to raw bytes.
                decoder = tokenizer.tokenizer.byte_decoder
                word = bytearray([decoder[c] for c in word]).decode('utf-8', errors='replace')
            words.append(word)
        resp_sentences_seg.append(words)

    # offsets calculation: start position of every token within its sentence
    all_offsets = []
    for words in resp_sentences_seg:
        last_index = len(words) - 1
        offsets = [0]
        for index, token in enumerate(words):
            if index == last_index:
                continue
            width = 0 if token in special_tokens else len(token)
            offsets.append(width + offsets[-1])
        all_offsets.append(offsets)

    result = {
        'sentences': resp_sentences,
        'tokens': resp_sentences_seg,
        'logprob': output_logits,
        'full_logprob': full_logits,
        'token_ids': decode_tokens,
        'offsets': all_offsets,
    }
    return inference_strategy.post_generation_process(result)
def switch(val1, val2, boolean):
    """Blend two values element-wise with a 0/1 mask.

    ``boolean`` is cast to the type of ``val1``; where it is 1 the result
    takes ``val2``, where it is 0 the result takes ``val1``.
    """
    take_second = boolean.type_as(val1)
    take_first = 1 - take_second
    return take_first * val1 + take_second * val2
def sample_sequence_batch(
    model,
    inference_strategy,
    context_tokens,
    context_lengths,
    tokens_to_generate,
    all_probs=False,
    type_ids=None,
    temperature=None,
    extra={},
):
    """Generator that extends a batch of prompts one token position per step.

    At each step the position ``context_length`` of ``tokens`` is filled in:
    rows whose prompt still covers that position keep the prompt token, the
    other rows receive a newly chosen token (argmax when ``extra['greedy']``
    is set, multinomial sampling otherwise). The loop stops when the maximum
    length is reached or every row has produced the end-of-document id.

    Args:
        model: object exposing ``cfg`` and ``tokenizer``.
        inference_strategy: object providing ``init_batch``, ``clip_max_len``,
            ``prepare_batch_at_step``, ``forward_step`` and ``post_process``.
        context_tokens: token id tensor indexed as [batch, position]; it is
            written to in place.
        context_lengths: tensor of per-row prompt lengths.
        tokens_to_generate: added to the longest prompt length to obtain the
            maximum length (then clipped by the strategy).
        all_probs: when true, the full log-softmax outputs are accumulated
            and yielded as the fourth element.
        type_ids: not used by this function.
        temperature: divisor applied to the logits on the sampling
            (non-greedy) path; must not be None there.
        extra: options read with ``.get``: ``min_tokens_to_generate``
            (default 0), ``greedy`` (default False), ``repetition_penalty``
            (default 1.2), ``top_k`` (default 0), ``top_p`` (default 0.9).
            The dict is only read, never modified.

    Yields:
        On the last pipeline stage ``(tokens, lengths, output_logits,
        full_logits_or_None)``; on a first stage that is not the last stage
        ``(tokens, None, None, None)``; on every other stage
        ``(None, None, None, None)``.
    """
    app_state = AppState()
    # Inference runs with one micro-batch per step: global and micro batch
    # size are both set to the batch dimension of the context tokens.
    micro_batch_size = context_tokens.shape[0]
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=micro_batch_size,
        micro_batch_size=micro_batch_size,
        data_parallel_size=1,
    )
    assert (
        model.cfg.get('sequence_parallel', False) == False
    ), 'sequence_parallel should be False during inference. Disable it in the model config if restoring from nemo or in hparams.yaml if restoring from PTL checkpoint'
    assert (
        model.cfg.get('activations_checkpoint_granularity', None) is None
    ), 'activations_checkpoint_granularity should be None during inference. Disable it in the model config if restoring from nemo or in hparams.yaml if restoring from PTL checkpoint'
    assert (
        model.cfg.get('activations_checkpoint_method', None) is None
    ), 'activations_checkpoint_method should be None during inference. Disable it in the model config if restoring from nemo or in hparams.yaml if restoring from PTL checkpoint'

    tokenizer = model.tokenizer
    # initialize the batch
    with torch.no_grad():
        # Generation starts at the end of the shortest prompt.
        context_length = context_lengths.min().item()
        inference_strategy.init_batch(context_tokens, context_length)
        # added eos_id to support the function generate_samples_eval that passes
        # eos_id as an argument and needs termination when that id is found.
        eod_id = tokenizer.eos_id
        counter = 0

        batch_size = context_tokens.size(0)
        # Per-row flag: set once the row has produced the end-of-document id.
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        output_logits = None
        all_generated_indices = None  # used to track all generated indices
        # Generate enough tokens for the longest sequence
        maxlen = tokens_to_generate + context_lengths.max().item()

        maxlen = inference_strategy.clip_max_len(maxlen)

        # Per-row final length; defaults to maxlen until the row finishes.
        lengths = torch.ones([batch_size]).long().cuda() * maxlen
        while context_length < maxlen:
            batch, tensor_shape = inference_strategy.prepare_batch_at_step(
                tokens, maxlen, micro_batch_size, counter, context_length
            )
            output = inference_strategy.forward_step(batch, tensor_shape)

            if parallel_state.is_pipeline_last_stage():
                output = output[0]['logits'].float()
                output = tensor_parallel.gather_from_tensor_model_parallel_region(output)
                assert output is not None
                output = output.float()
                # Logits of the last position, one row per batch entry.
                # NOTE(review): if `output[:, -1]` is already contiguous, `logits`
                # may share storage with `output`, so the in-place edits below
                # would also alter the values passed to log_softmax further
                # down — confirm against the strategy's output shape.
                logits = output[:, -1].view(batch_size, -1).contiguous()

                # make sure it will generate at least min_length
                min_length = extra.get('min_tokens_to_generate', 0)
                if min_length > 0:
                    within_min_length = (context_length - context_lengths) < min_length
                    logits[within_min_length, eod_id] = -float('Inf')

                # make sure it won't sample outside the vocab_size range
                logits[:, tokenizer.vocab_size :] = -float('Inf')

                if extra.get('greedy', False):
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= temperature
                    # handle repetition penalty
                    logits = repetition_penalty(logits, extra.get('repetition_penalty', 1.2), all_generated_indices)
                    logits = top_k_logits(logits, top_k=extra.get('top_k', 0), top_p=extra.get('top_p', 0.9))
                    # Despite the name, these are probabilities (softmax), as
                    # required by torch.multinomial.
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                # True for rows whose prompt ends at or before this position.
                started = context_lengths <= context_length

                # Clamp the predicted out of vocabulary tokens
                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)
                # Rows still inside their prompt keep the prompt token.
                new_tokens = switch(tokens[:, context_length].view(-1), prev, started)

                # Replace sampled tokens w/ done token if EOD has already been sampled
                new_tokens = switch(new_tokens, eod_id, is_done)

                # post process the inference tokens based on the strategy
                inference_strategy.post_process(tokens, new_tokens, context_length)

                # Insert either new predicted or next prompt token
                tokens[:, context_length] = new_tokens

                if output_logits is None:
                    # First step: score every position up to context_length
                    # against the token that follows it.
                    output = F.log_softmax(output[:, :context_length, :], 2)
                    indices = torch.unsqueeze(tokens[:, 1 : context_length + 1], 2)
                    output_logits = torch.gather(output, 2, indices).squeeze(2)
                    all_generated_indices = indices[:, :, 0]
                    if all_probs:
                        full_logits = output
                else:
                    # Later steps: append the log-probability of the new token.
                    output = F.log_softmax(output, 2)
                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
                    new_output_logits = torch.gather(output, 2, indices).squeeze(2)

                    # TODO(rprenger) we're copying output_logits every time.  Should pre-allocate
                    output_logits = torch.cat([output_logits, new_output_logits], 1)
                    all_generated_indices = torch.cat([all_generated_indices, indices[:, :, 0]], 1)
                    if all_probs:
                        full_logits = torch.cat([full_logits, output], 1)

                # Send the new tokens over the embedding group.
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                # A row finishes when it samples EOD after its prompt has ended;
                # record the length the first time that happens.
                done_token = (prev == eod_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                # Tell the other pipeline stages whether every row is finished.
                done = torch.all(is_done)
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                if all_probs:
                    yield tokens, lengths, output_logits, full_logits
                else:
                    yield tokens, lengths, output_logits, None

            else:
                if parallel_state.is_pipeline_first_stage():
                    # Receive the new tokens chosen on the last stage.
                    src = parallel_state.get_pipeline_model_parallel_last_rank()
                    group = parallel_state.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None, None, None
                else:
                    yield None, None, None, None

                # Receive the termination flag from the last stage.
                done = torch.cuda.ByteTensor([0])
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break
def tab_sample_sequence_batch(
    model,
    inference_strategy,
    context_tokens,
    context_lengths,
    tokens_to_generate,
    all_probs=True,
    type_ids=None,
    temperature=None,
):
    """Generator that extends tabular prompts one token position per step.

    Works like ``sample_sequence_batch`` but always samples (no greedy,
    top-k/top-p or repetition penalty) and restricts the logits at every step
    to the id range allowed for the current position within a table row, as
    given by ``tokenizer.code_column``.

    Args:
        model: object exposing ``cfg`` and ``tokenizer``.
        inference_strategy: object providing ``init_batch``,
            ``prepare_batch_at_step``, ``forward_step`` and ``post_process``.
        context_tokens: token id tensor indexed as [batch, position]; it is
            written to in place.
        context_lengths: tensor of per-row prompt lengths.
        tokens_to_generate: added to the longest prompt length to obtain the
            maximum length (capped at ``model.cfg.encoder_seq_length``).
        all_probs: when true, the full log-softmax outputs are accumulated
            and yielded as the fourth element.
        type_ids: not used by this function.
        temperature: divisor applied to the logits before sampling; must not
            be None.

    Yields:
        On the last pipeline stage ``(tokens, lengths, output_logits,
        full_logits_or_None)``; on a first stage that is not the last stage
        ``(tokens, None, None, None)``; on every other stage
        ``(None, None, None, None)``.
    """
    app_state = AppState()
    # Inference runs with one micro-batch per step: global and micro batch
    # size are both set to the batch dimension of the context tokens.
    micro_batch_size = context_tokens.shape[0]
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=micro_batch_size,
        micro_batch_size=micro_batch_size,
        data_parallel_size=1,
    )
    tokenizer = model.tokenizer
    sizes = tokenizer.code_column.sizes
    # presumably the extra token is the row terminator — confirm against the tokenizer
    tokens_per_row = sum(sizes) + 1
    columns = tokenizer.code_column.columns
    num_columns = len(columns)
    # Flat list of id ranges, later indexed by the position within a row.
    tokenid_range = []
    for i in range(num_columns):
        tokenid_range.extend(tokenizer.code_column.get_range(i))
    # initialize the batch
    with torch.no_grad():
        # Generation starts at the end of the shortest prompt.
        context_length = context_lengths.min().item()
        inference_strategy.init_batch(context_tokens, context_length)
        context = context_tokens[:, :context_length]
        # the context may start in the middle of the row,
        # calculate the offset according to the position of '\n' or '<|endoftext|>'
        positions = torch.where(context == tokenizer.eor)[1]
        if len(positions) == 0:
            positions = torch.where(context == tokenizer.eod)[1]
        if len(positions) != 0:
            max_position = positions.max().item()
            # TODO, need to make sure context of different batch have the same offset lengths
            # otherwise, need to calculate offset per batch_id
            offset = (context_length - max_position - 1) % tokens_per_row
        else:
            offset = 0

        eod_id = tokenizer.eos_id

        counter = 0

        batch_size = context_tokens.size(0)
        # Per-row flag: set once the row has produced the end-of-document id.
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        output_logits = None

        # Generate enough tokens for the longest sequence
        maxlen = tokens_to_generate + context_lengths.max().item()

        if maxlen > model.cfg.encoder_seq_length:
            maxlen = model.cfg.encoder_seq_length

        # Per-row final length; defaults to maxlen until the row finishes.
        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length < maxlen:
            batch, tensor_shape = inference_strategy.prepare_batch_at_step(
                tokens, maxlen, micro_batch_size, counter, context_length
            )
            output = inference_strategy.forward_step(batch, tensor_shape)

            if parallel_state.is_pipeline_last_stage():
                output = output[0]['logits'].float()
                output = tensor_parallel.gather_from_tensor_model_parallel_region(output)
                assert output is not None
                output = output.float()
                # Logits of the last position, one row per batch entry.
                logits = output[:, -1].view(batch_size, -1).contiguous()
                # Position of the token about to be generated within its row.
                token_in_row = (counter + offset) % tokens_per_row
                logits = logits.float()
                logits /= temperature
                if token_in_row == tokens_per_row - 1:
                    # line break: only ids between the end-of-row and
                    # end-of-sequence ids (inclusive) stay allowed
                    eor_id = tokenizer.eor
                    eod_id = tokenizer.eos_id
                    min_id = min(eor_id, eod_id)
                    max_id = max(eor_id, eod_id) + 1
                    logits = tab_logits(logits, min_id, max_id)
                else:
                    # limit the range
                    min_id, max_id = tokenid_range[token_in_row]
                    logits = tab_logits(logits, min_id, max_id)
                # Despite the name, these are probabilities (softmax), as
                # required by torch.multinomial.
                log_probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                # True for rows whose prompt ends at or before this position.
                started = context_lengths <= context_length
                # Clamp the out of vocabulary tokens.
                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)

                # Rows still inside their prompt keep the prompt token.
                new_tokens = switch(tokens[:, context_length].view(-1), prev, started)

                # post process the inference tokens based on the strategy
                inference_strategy.post_process(tokens, new_tokens, context_length)

                tokens[:, context_length] = new_tokens

                if output_logits is None:
                    # First step: score every position up to context_length
                    # against the token that follows it.
                    output_context = F.log_softmax(output[:, :context_length, :], 2)
                    indices = torch.unsqueeze(tokens[:, 1 : context_length + 1], 2)
                    output_logits = torch.gather(output_context, 2, indices).squeeze(2)
                    if all_probs:
                        full_logits = output_context
                else:
                    # Later steps: append the log-probability of the new token.
                    output_context = F.log_softmax(output, 2)
                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
                    new_output_logits = torch.gather(output_context, 2, indices).squeeze(2)

                    # TODO(rprenger) we're copying output_logits every time.  Should pre-allocate
                    output_logits = torch.cat([output_logits, new_output_logits], 1)
                    if all_probs:
                        full_logits = torch.cat([full_logits, output_context], 1)

                # Send the new tokens over the embedding group.
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                # A row finishes when it samples EOD after its prompt has ended;
                # record the length the first time that happens.
                done_token = (prev == eod_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                # Tell the other pipeline stages whether every row is finished.
                done = torch.all(is_done)
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                if all_probs:
                    yield tokens, lengths, output_logits, full_logits
                else:
                    yield tokens, lengths, output_logits, None

            else:
                if parallel_state.is_pipeline_first_stage():
                    # Receive the new tokens chosen on the last stage.
                    src = parallel_state.get_pipeline_model_parallel_last_rank()
                    group = parallel_state.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None, None, None
                else:
                    yield None, None, None, None

                # Receive the termination flag from the last stage.
                done = torch.cuda.ByteTensor([0])
                src = parallel_state.get_pipeline_model_parallel_last_rank()
                group = parallel_state.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break
def sample_token_greedy(logits):
    """
    Greedy selection: pick the most likely token of every row.

    Args:
        logits: [batch_size, vocab_size] - unnormalized log probabilities of the next token

    Returns:
        log_probs: [batch_size] - log probabilities of the selected tokens
        token_ids: [batch_size] - selected token ids
    """
    normalized = torch.nn.functional.log_softmax(logits, dim=-1)
    best = normalized.max(dim=-1)
    return best.values, best.indices
def sample_token_topk(logits, top_k=0, top_p=0.0, temperature=1.0, filter_value=-float('Inf')):
    """
    Top-k / top-p sampling. Draws one token per row from the filtered,
    temperature-scaled distribution and returns it with its log_prob.

    Args:
        logits: [batch_size, vocab_size] - unnormalized log probabilities of the next token
            (the caller's tensor is not modified)
        top_k: int - if > 0: only sample from top k tokens with highest probability
        top_p: float - if > 0.0: only sample from the most likely candidates, selected by
            their cumulative probability (nucleus filtering)
        temperature: float - temperature for sampling
        filter_value: float - value to set filtered tokens to

    Returns:
        log_probs: [batch_size] - log probabilities of the sampled tokens
        token_ids: [batch_size] - sampled token ids
    """
    # Divide out of place: `logits.float()` hands back the very same tensor when
    # the input is already float32, so an in-place `/=` (and the in-place
    # filtering below) would silently rewrite the caller's logits.
    logits = logits.float() / temperature
    logits = top_k_logits(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    token_ids = torch.multinomial(log_probs.exp(), num_samples=1).view(-1)
    # Keep only the log-probability of the token drawn for each row.
    log_probs = log_probs.gather(1, token_ids.unsqueeze(1)).squeeze(1)

    return log_probs, token_ids