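"""Request/response dataclasses and a base Decoder for an LLM inference pipeline.

Decoder wires together preprocessing, (optionally streaming or speculative)
generation, and postprocessing; model-specific subclasses implement the
abstract hooks.
"""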
from collections.abc import Generator
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import torch


class RequestValidationError(Exception):
    """Raised when a Request fails validation."""


def _validate_that(condition: bool, msg: str):
    """Raise RequestValidationError with msg unless condition holds."""
    if not condition:
        raise RequestValidationError(msg)


def _validate_non_empty(data, msg: str):
    """Validate that data is a non-empty torch tensor or numpy array."""
    if isinstance(data, torch.Tensor):
        _validate_that(data is not None and data.numel() > 0, msg)
    else:
        _validate_that(data is not None and data.size > 0, msg)


def _validate_single_gt_0(data, msg: str):
    """Validate that data holds a single value greater than zero."""
    _validate_non_empty(data, msg)
    _validate_that(data.flatten()[0] > 0, msg)


def _single_value(data: Optional[np.ndarray]):
    """Return the first element of data, or None if data is None."""
    if data is None:
        return None
    return data.flatten()[0]


@dataclass
class Request:
    """A single generation request.

    Scalar options arrive as [batch_size, 1]-shaped numpy arrays, which is
    why single values are read via _single_value / flatten()[0].
    """
    # A default_factory is required here: numpy arrays are unhashable, so a
    # plain `np.array([])` default is rejected by dataclasses on Python 3.11+.
    text_input: np.ndarray = field(default_factory=lambda: np.array([]))
    decoder_text_input: Optional[np.ndarray] = None
    image_input: Optional[np.ndarray] = None
    max_tokens: Optional[np.ndarray] = None
    bad_words: Optional[np.ndarray] = None
    stop_words: Optional[np.ndarray] = None
    end_id: Optional[np.ndarray] = None
    pad_id: Optional[np.ndarray] = None
    top_k: Optional[np.ndarray] = None
    top_p: Optional[np.ndarray] = None
    temperature: Optional[np.ndarray] = None
    length_penalty: Optional[np.ndarray] = None
    repetition_penalty: Optional[np.ndarray] = None
    min_length: Optional[np.ndarray] = None
    return_log_probs: Optional[np.ndarray] = None
    prompt_embedding_table: Optional[np.ndarray] = None
    prompt_vocab_size: Optional[np.ndarray] = None
    embedding_bias_words: Optional[np.ndarray] = None
    embedding_bias_weights: Optional[np.ndarray] = None
    num_draft_tokens: Optional[np.ndarray] = None
    use_draft_logits: Optional[np.ndarray] = None
    stream: Optional[np.ndarray] = None
    beam_width: Optional[np.ndarray] = None
    return_context_logits: Optional[np.ndarray] = None
    return_generation_logits: Optional[np.ndarray] = None
    random_seed: Optional[np.ndarray] = None
    presence_penalty: Optional[np.ndarray] = None
    frequency_penalty: Optional[np.ndarray] = None

    def validate(self):
        """Check required fields and unsupported option combinations."""
        _validate_non_empty(self.text_input, "text_input is required")
        _validate_single_gt_0(self.max_tokens,
                              "max_tokens must be a single value > 0")

        num_draft_tokens = _single_value(self.num_draft_tokens)
        context_logits = _single_value(self.return_context_logits)

        if num_draft_tokens:
            _validate_that(
                self.stream is None or not self.stream.any(),
                "streaming is not supported with speculative decoding")
            _validate_that(
                not context_logits,
                "context logits are not supported with speculative decoding")
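
# Example request (illustrative only; values are hypothetical). Scalar options
# follow the [batch_size, 1] shape convention assumed by validate() and
# _spec_generate():
#
#   request = Request(text_input=np.array([["What is 2 + 2?"]]),
#                     max_tokens=np.array([[16]], dtype=np.int32),
#                     stream=np.array([[False]]),
#                     num_draft_tokens=np.array([[4]], dtype=np.int32))
#   request.validate()  # raises RequestValidationError on bad input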


@dataclass
class DraftRequest:
    """Draft tokens (and optional logits) passed to the target model."""
    draft_input_ids: Optional[np.ndarray] = None
    draft_logits: Optional[np.ndarray] = None


@dataclass
class PreprocResponse:
    """Tokenized model inputs produced by Decoder.preprocess."""
    input_ids: np.ndarray = field(default_factory=lambda: np.array([]))
    decoder_input_ids: Optional[np.ndarray] = None
    input_lengths: np.ndarray = field(default_factory=lambda: np.array([]))
    decoder_input_lengths: Optional[np.ndarray] = None
    bad_words_list: Optional[np.ndarray] = None
    stop_words_list: Optional[np.ndarray] = None
    embedding_bias: Optional[np.ndarray] = None
    end_id: Optional[np.ndarray] = None
    pad_id: Optional[np.ndarray] = None

    @classmethod
    def with_new_inputs(cls,
                        other,
                        input_ids: Optional[np.ndarray] = None,
                        input_lengths: Optional[np.ndarray] = None):
        """Copy other, optionally overriding input_ids and input_lengths."""
        return cls(input_ids=(input_ids
                              if input_ids is not None else other.input_ids),
                   input_lengths=(input_lengths if input_lengths is not None
                                  else other.input_lengths),
                   decoder_input_ids=other.decoder_input_ids,
                   decoder_input_lengths=other.decoder_input_lengths,
                   bad_words_list=other.bad_words_list,
                   stop_words_list=other.stop_words_list,
                   end_id=other.end_id,
                   pad_id=other.pad_id)


@dataclass
class MultimodalEncResponse:
    """Output of the multimodal encoder."""
    prompt_embedding_table: Optional[torch.Tensor] = None
    prompt_vocab_size: Optional[np.ndarray] = None


@dataclass
class GenerationResponse:
    """Raw model output for one generation step or one full generation."""
    output_ids: np.ndarray = field(default_factory=lambda: np.array([]))
    sequence_length: np.ndarray = field(default_factory=lambda: np.array([]))
    cum_log_probs: Optional[np.ndarray] = None
    output_log_probs: Optional[np.ndarray] = None
    context_logits: Optional[np.ndarray] = None
    generation_logits: Optional[np.ndarray] = None
    batch_index: Optional[np.ndarray] = None


@dataclass
class Response:
    """Final, postprocessed response returned to the client."""
    text_output: np.ndarray = field(default_factory=lambda: np.array([]))
    cum_log_probs: Optional[np.ndarray] = None
    output_log_probs: Optional[np.ndarray] = None
    context_logits: Optional[np.ndarray] = None
    generation_logits: Optional[np.ndarray] = None
    batch_index: Optional[np.ndarray] = None

    def __eq__(self, o) -> bool:
        """Just for testing"""
        if not isinstance(o, Response):
            return False
        return (np.array_equal(self.text_output, o.text_output)
                and np.array_equal(self.cum_log_probs, o.cum_log_probs)
                and np.array_equal(self.output_log_probs, o.output_log_probs)
                and np.array_equal(self.context_logits, o.context_logits)
                and np.array_equal(self.generation_logits, o.generation_logits)
                and np.array_equal(self.batch_index, o.batch_index))


class Decoder:
    """Base class wiring preprocessing, generation, and postprocessing.

    Subclasses implement the model-specific hooks: preprocess, _generate,
    _generate_non_streaming, _draft_generate_non_streaming,
    _multimodal_enc_generate, and _postprocess.
    """

    def __init__(self, streaming=False, accumulate=False):
        self._streaming = streaming
        self._accumulate = accumulate

        self._accumulated_tokens = []
    def decode(self,
               request: Request,
               speculative_decoding=False,
               is_multimodal=False) -> Generator[Response, None, None]:
        """Yield one or more Responses for the given request."""
        batch_size = request.text_input.shape[0]
        self._accumulated_tokens = [None] * batch_size
        preproc_response = self.preprocess(request)

        multimodal_enc_response = None
        if is_multimodal:
            multimodal_enc_response = self._multimodal_enc_generate(request)

        if speculative_decoding:
            if batch_size > 1:
                raise Exception(
                    "speculative decoding is not supported with batch size > 1"
                )
            for gen_response in self._spec_generate(preproc_response, request):
                yield self.postprocess(gen_response, batch_size)
        else:
            if not self._streaming and batch_size == 1:
                gen_response = self._generate_non_streaming(
                    preproc_response,
                    request,
                    multimodal_enc_response=multimodal_enc_response)
                yield self.postprocess(gen_response, batch_size)
            else:
                for gen_response in self._generate(
                        preproc_response,
                        request,
                        multimodal_enc_response=multimodal_enc_response):
                    yield self.postprocess(gen_response, batch_size)

    def encountered_stop_words(self, input_ids, stop_words_ids):
        """Return True if input_ids ends with any of the stop sequences."""
        for stop_word_ids in stop_words_ids:
            if np.array_equal(input_ids[-len(stop_word_ids):], stop_word_ids):
                return True
        return False
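    # Example (hypothetical ids): with input_ids == [5, 6, 7, 8] and
    # stop_words_ids == [[9], [7, 8]], the suffix [7, 8] matches the second
    # stop sequence, so encountered_stop_words returns True.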

    def _spec_generate(
            self, preproc: PreprocResponse,
            request: Request) -> Generator[GenerationResponse, None, None]:
        """Run the draft/target speculative-decoding loop (batch size 1 only)."""
        if preproc.input_ids.shape[0] > 1:
            raise Exception(
                "Speculative decoding does not support batch size > 1.")

        prompt_input_ids: np.ndarray = preproc.input_ids[0]
        input_ids: np.ndarray = prompt_input_ids
        output_len: int = request.max_tokens[0][0]
        last_input_ids: Optional[np.ndarray] = None
        draft_output_ids: Optional[np.ndarray] = None
        draft_logits: Optional[np.ndarray] = None

        target_response: Optional[GenerationResponse] = None

        cur_preproc = preproc

        while True:
            # Never draft more tokens than the remaining output budget allows.
            num_draft_tokens = min(
                request.num_draft_tokens[0][0],
                len(prompt_input_ids) + output_len - len(input_ids) - 1)

            draft_request = None
            if num_draft_tokens > 0:
                draft_response: GenerationResponse = \
                    self._draft_generate_non_streaming(cur_preproc, request,
                                                       num_draft_tokens)
                seq_len: int = draft_response.sequence_length[0][0]
                draft_output_ids = draft_response.output_ids[0][0]
                use_draft_logits = (request.use_draft_logits is not None
                                    and request.use_draft_logits[0])
                if (use_draft_logits
                        and draft_response.generation_logits is not None):
                    draft_logits = draft_response.generation_logits[0][0]

                # Hand only the newly drafted tokens (past the current
                # context) to the target model.
                input_draft_tokens = draft_output_ids[len(input_ids):seq_len]
                draft_request = DraftRequest(
                    draft_input_ids=np.expand_dims(input_draft_tokens, 0))
                if use_draft_logits:
                    draft_request.draft_logits = np.expand_dims(
                        draft_logits[-len(input_draft_tokens):], 0)
            else:
                draft_request = DraftRequest()

            # The target model verifies the draft tokens and extends the
            # accepted sequence.
            target_response = self._generate_non_streaming(
                cur_preproc, request, draft_request)
            last_input_ids = input_ids
            input_ids = target_response.output_ids[0][0]
            cur_preproc = PreprocResponse.with_new_inputs(
                cur_preproc, np.expand_dims(input_ids, 0),
                np.array([[len(input_ids)]], dtype=np.int32))

            # Stop once the requested number of output tokens is reached.
            length_stop = (len(input_ids) >=
                           len(prompt_input_ids) + output_len)
            if length_stop:
                break

            # Stop if the target accepted the draft output verbatim.
            target_draft_equal = draft_output_ids is not None and np.array_equal(
                draft_output_ids, input_ids)
            if target_draft_equal:
                break

            # Stop if the target produced no new tokens this iteration.
            last_current_equal = np.array_equal(last_input_ids, input_ids)
            if last_current_equal:
                break

            # Stop on any configured stop-word sequence.
            hit_stop_words = (preproc.stop_words_list is not None
                              and self.encountered_stop_words(
                                  input_ids, preproc.stop_words_list[0]))
            if hit_stop_words:
                break

        yield target_response
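    # One speculative step, illustrated with hypothetical numbers: for a
    # 4-token prompt with output_len == 8 and request.num_draft_tokens == 5,
    # the first iteration drafts min(5, 4 + 8 - 4 - 1) == 5 tokens; if the
    # target accepts three of them and adds one token of its own,
    # len(input_ids) grows from 4 to 8 and the next iteration can draft only
    # min(5, 4 + 8 - 8 - 1) == 3 tokens.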

    def _draft_generate_non_streaming(
            self, preproc: PreprocResponse, request: Request,
            num_draft_tokens: int) -> GenerationResponse:
        """Generate up to num_draft_tokens tokens with the draft model."""
        raise NotImplementedError()

    def _multimodal_enc_generate(
            self,
            request: Request,
    ) -> MultimodalEncResponse:
        """Run the multimodal encoder to produce prompt embeddings."""
        raise NotImplementedError()

    def _generate(
            self,
            preproc: PreprocResponse,
            request: Request,
            draft_request: Optional[DraftRequest] = None,
            multimodal_enc_response: Optional[MultimodalEncResponse] = None,
    ) -> Generator[GenerationResponse, None, None]:
        """Stream GenerationResponses one step at a time."""
        raise NotImplementedError()

    def _generate_non_streaming(
            self,
            preproc: PreprocResponse,
            request: Request,
            draft_request: Optional[DraftRequest] = None,
            multimodal_enc_response: Optional[MultimodalEncResponse] = None,
    ) -> GenerationResponse:
        """Generate a single, complete GenerationResponse."""
        raise NotImplementedError()

    def postprocess(self, gen_response: GenerationResponse,
                    batch_size) -> Response:
        """Convert a GenerationResponse into a Response.

        In streaming mode with accumulation enabled, tokens from successive
        responses are concatenated per batch entry before postprocessing.
        """
        if self._accumulate and self._streaming:
            new_tokens: np.ndarray = gen_response.output_ids
            if new_tokens.ndim != 3:
                raise Exception("Expected output_ids tensor to have 3 dims.")
            if new_tokens.shape[0] != 1:
                raise Exception("Expected batch size of 1")
            if new_tokens.shape[1] != 1:
                raise Exception(
                    "Accumulation of tokens is only implemented for beam width = 1"
                )

            batch_index = gen_response.batch_index
            if batch_index.ndim != 2:
                raise Exception("Expected batch_index tensor to have 2 dims.")
            if batch_index.shape[0] != 1:
                raise Exception("Expected batch size of 1")
            if batch_index.shape[1] != 1:
                raise Exception("Expected only one batch_index")

            batch_index = batch_index[0][0]

            # Append the new tokens along the token axis (dim 2).
            self._accumulated_tokens[batch_index] = new_tokens if (
                self._accumulated_tokens[batch_index] is None
            ) else np.concatenate(
                (self._accumulated_tokens[batch_index], new_tokens), axis=2)
            sequence_lengths = np.array(
                [[self._accumulated_tokens[batch_index].shape[2]]],
                dtype=np.int32)
            return self._postprocess(self._accumulated_tokens[batch_index],
                                     sequence_lengths, gen_response)
        else:
            return self._postprocess(gen_response.output_ids, None,
                                     gen_response)
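    # Accumulation shapes, illustrated with hypothetical values: three
    # streamed output_ids chunks of shape [1, 1, 1] carrying tokens 10, 11,
    # 12 for batch_index 0 accumulate to [[[10]]], [[[10, 11]]] and
    # [[[10, 11, 12]]], with sequence_lengths [[1]], [[2]] and [[3]].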

    def _postprocess(self, tokens: np.ndarray,
                     sequence_lengths: Optional[np.ndarray],
                     gen_response: GenerationResponse) -> Response:
        """Detokenize tokens and assemble the final Response."""
        raise NotImplementedError()

    def preprocess(self, request: Request) -> PreprocResponse:
        """Tokenize the request and build model inputs."""
        raise NotImplementedError()

    def reset_decoder(self):
        """Clear tokens accumulated across streaming responses."""
        self._accumulated_tokens = []
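
if __name__ == "__main__":
    # Minimal runnable sketch (not part of the original module): a toy
    # subclass that fakes preprocessing and generation, just to exercise the
    # streaming + accumulation path of Decoder.postprocess end to end. All
    # names and token values below are hypothetical.
    class _EchoDecoder(Decoder):

        def preprocess(self, request: Request) -> PreprocResponse:
            # Pretend every input row tokenizes to the same three ids.
            batch = request.text_input.shape[0]
            return PreprocResponse(
                input_ids=np.tile(np.array([[1, 2, 3]], dtype=np.int32),
                                  (batch, 1)),
                input_lengths=np.full((batch, 1), 3, dtype=np.int32))

        def _generate(self,
                      preproc,
                      request,
                      draft_request=None,
                      multimodal_enc_response=None):
            # Stream one token per response; output_ids is [batch, beam, tokens].
            for token in (10, 11, 12):
                yield GenerationResponse(
                    output_ids=np.array([[[token]]], dtype=np.int32),
                    sequence_length=np.array([[1]], dtype=np.int32),
                    batch_index=np.array([[0]], dtype=np.int32))

        def _postprocess(self, tokens, sequence_lengths, gen_response):
            # A real decoder would detokenize here; just pass the ids through.
            return Response(text_output=tokens)

    decoder = _EchoDecoder(streaming=True, accumulate=True)
    demo_request = Request(text_input=np.array([["hello"]]),
                           max_tokens=np.array([[8]], dtype=np.int32))
    for response in decoder.decode(demo_request):
        print(response.text_output)  # [[[10]]], then [[[10 11]]], [[[10 11 12]]]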