| """Forward step utilities.""" |
|
|
| from collections.abc import Iterable |
|
|
| import torch |
|
|
| from megatron import ( |
| get_args, |
| mpu) |
| from .communication import ( |
| send_to_next_pipeline_rank, |
| recv_from_prev_pipeline_rank_) |
|
|
|
|
|
|
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    def __init__(self, max_batch_size, max_sequence_len):
        """Note that offsets are set to zero and we always set the
        flag to allocate memory. After the first call, make sure to
        set this flag to False."""
        self.max_sequence_len = max_sequence_len
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_idx):
        """Swap the key/value cache entries between batches."""
        if len(self.key_value_memory_dict) == 0:
            raise ValueError("should not swap when dict is empty")

        for layer_number in self.key_value_memory_dict.keys():
            inference_key_memory, inference_value_memory = \
                self.key_value_memory_dict[layer_number]
            assert len(batch_idx) == inference_key_memory.shape[1]
            new_inference_key_memory = inference_key_memory[:, batch_idx]
            new_inference_value_memory = inference_value_memory[:, batch_idx]
            self.key_value_memory_dict[layer_number] = (
                new_inference_key_memory, new_inference_value_memory)

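
# Illustrative sketch (not part of the original module): how an attention
# layer is typically expected to consume InferenceParams during incremental
# decoding. The layer number, head count, and head size are hypothetical;
# only the offsets and key_value_memory_dict come from the class above.
def _example_kv_cache_update(inference_params, key, value, layer_number=1):
    """Lazily allocate the per-layer cache and write the new key/value slice.

    Assumes `key` and `value` have shape [seq_len, batch, num_heads, head_dim].
    """
    seq_len, batch, num_heads, head_dim = key.shape
    if layer_number not in inference_params.key_value_memory_dict:
        # First call for this layer: allocate the full-size cache once.
        cache_shape = (inference_params.max_sequence_len,
                       inference_params.max_batch_size,
                       num_heads, head_dim)
        key_memory = torch.empty(cache_shape, dtype=key.dtype,
                                 device=key.device)
        value_memory = torch.empty(cache_shape, dtype=value.dtype,
                                   device=value.device)
        inference_params.key_value_memory_dict[layer_number] = (
            key_memory, value_memory)
    key_memory, value_memory = \
        inference_params.key_value_memory_dict[layer_number]
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + batch
    seq_start = inference_params.sequence_len_offset
    seq_end = seq_start + seq_len
    # Write the freshly computed keys/values into the cache ...
    key_memory[seq_start:seq_end, batch_start:batch_end, ...] = key
    value_memory[seq_start:seq_end, batch_start:batch_end, ...] = value
    # ... and return everything cached so far so attention can span the
    # full context generated up to this step.
    return (key_memory[:seq_end, batch_start:batch_end, ...],
            value_memory[:seq_end, batch_start:batch_end, ...])

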
class ForwardStep:
    """Forward step function with all the communications.
    We use a class here to hide the inference parameters
    from the outside caller."""

    def __init__(self, model, max_batch_size, max_sequence_len):
        """Set values so we don't need to do it multiple times."""
        # Make sure the model is in eval mode.
        assert not isinstance(model, Iterable), \
            'interleaving schedule is not supported for inference'
        model.eval()
        self.model = model
        # Initialize the inference parameters.
        self.inference_params = InferenceParams(max_batch_size,
                                                max_sequence_len)
        # Pipelining arguments.
        args = get_args()
        self.pipeline_size_larger_than_one = (
            args.pipeline_model_parallel_size > 1)
        # Threshold for switching to pipelined micro batches.
        self.pipelining_batch_x_seqlen = \
            args.inference_batch_times_seqlen_threshold

    def __call__(self, tokens, position_ids, attention_mask):
        """Invocation of the forward methods. Note that self.inference_params
        is being modified by the forward step."""
        # Pipelining case: split the batch into micro batches when the
        # batch-times-sequence-length product exceeds the threshold.
        if self.pipeline_size_larger_than_one:
            current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
            if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
                micro_batch_size = \
                    max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
                return _with_pipelining_forward_step(self.model,
                                                     tokens,
                                                     position_ids,
                                                     attention_mask,
                                                     self.inference_params,
                                                     micro_batch_size)

        return _no_pipelining_forward_step(self.model,
                                           tokens,
                                           position_ids,
                                           attention_mask,
                                           self.inference_params)

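
# Illustrative sketch (not part of the original module): how a caller might
# drive ForwardStep for one greedy decoding step. The `model`, `tokens`,
# `position_ids`, and `attention_mask` arguments are hypothetical inputs;
# for multi-token generation the same ForwardStep instance would be reused
# so the key/value cache in its inference_params carries over between steps.
def _example_greedy_step(model, tokens, position_ids, attention_mask,
                         max_sequence_len=2048):
    """Run one forward step and pick the next token greedily."""
    forward_step = ForwardStep(model, tokens.size(0), max_sequence_len)
    logits = forward_step(tokens, position_ids, attention_mask)
    if logits is None:
        # Not the last pipeline stage; there are no logits to sample from.
        return None
    # Greedy choice from the logits of the last position.
    return torch.argmax(logits[:, -1, :], dim=-1)

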
def _get_recv_buffer_dtype(args):
    """Receive happens between the layers."""
    if args.fp32_residual_connection:
        return torch.float
    return args.params_dtype


def _allocate_recv_buffer(batch_size, sequence_length):
    """Receive happens between the layers with size [s, b, h]."""
    if mpu.is_pipeline_first_stage():
        return None
    args = get_args()
    recv_size = (sequence_length, batch_size, args.hidden_size)
    return torch.empty(recv_size,
                       dtype=_get_recv_buffer_dtype(args),
                       device=torch.cuda.current_device())


def _forward_step_helper(model, tokens, position_ids, attention_mask,
                         inference_params, recv_buffer=None):
    """Single forward step. Update the allocate memory flag so
    only the first time the memory is allocated."""
    batch_size = tokens.size(0)
    sequence_length = tokens.size(1)
    if recv_buffer is None:
        recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)

    # Receive from the previous pipeline stage.
    recv_from_prev_pipeline_rank_(recv_buffer)

    # Forward pass through the model.
    model.set_input_tensor(recv_buffer)
    output_tensor = model(tokens, position_ids, attention_mask,
                          inference_params=inference_params)

    # Send the output to the next pipeline stage.
    send_to_next_pipeline_rank(output_tensor)

    return output_tensor


def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                inference_params, recv_buffer=None):
    """If recv_buffer is None, we will allocate one on the fly."""
    # Run a simple forward pass.
    output_tensor = _forward_step_helper(model, tokens, position_ids,
                                         attention_mask, inference_params,
                                         recv_buffer=recv_buffer)
    # Update the sequence length offset.
    inference_params.sequence_len_offset += tokens.size(1)

    logits = None
    if mpu.is_pipeline_last_stage():
        logits = output_tensor

    return logits


def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                  inference_params, micro_batch_size):
    """No interleaving is supported."""
    sequence_length = tokens.size(1)
    batch_size = tokens.size(0)

    # Divide the batch dimension into micro batches.
    num_micro_batches, last_chunk = divmod(batch_size,
                                           micro_batch_size)
    if last_chunk > 0:
        num_micro_batches += 1
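    # Worked example (illustrative): with batch_size=10 and micro_batch_size=4,
    # divmod(10, 4) == (2, 2), so the trailing partial chunk bumps
    # num_micro_batches to 3 and the loop below uses slices [0:4], [4:8], [8:10].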

    # Preallocate memory for the output logits.
    logits = None
    if mpu.is_pipeline_last_stage():
        args = get_args()
        logits = torch.empty(
            (batch_size, sequence_length, args.padded_vocab_size),
            dtype=torch.float32, device=torch.cuda.current_device())

    # Preallocate the receive buffer.
    recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)

    for micro_batch_index in range(num_micro_batches):
        # Slice along the batch dimension.
        start = micro_batch_index * micro_batch_size
        end = min(start + micro_batch_size, batch_size)
        this_micro_batch_size = end - start
        tokens2use = tokens[start:end, ...]
        position_ids2use = position_ids[start:end, ...]

        # Run the forward pass; a smaller trailing micro batch cannot reuse
        # the preallocated receive buffer, so allocate one on the fly.
        if this_micro_batch_size != micro_batch_size:
            recv_buffer = None
        output = _forward_step_helper(model, tokens2use, position_ids2use,
                                      attention_mask, inference_params,
                                      recv_buffer=recv_buffer)

        # Adjust the batch size offset to account for this micro batch.
        inference_params.batch_size_offset += this_micro_batch_size

        # Copy the logits for this slice of the batch.
        if mpu.is_pipeline_last_stage():
            logits[start:end, ...] = output

    # Once all micro batches are done, advance the sequence length offset
    # and reset the batch size offset for the next call.
    inference_params.sequence_len_offset += sequence_length
    inference_params.batch_size_offset = 0

    return logits