|
|
import abc |
|
|
|
|
|
from typing import Any |
|
|
|
|
|
import numpy as np |
|
|
import numpy.typing as npt |
|
|
|
|
|
|
|
|
class LlamaDraftModel(abc.ABC):
    """Abstract interface for draft models.

    A draft model is a callable: it receives the token ids seen so far and
    returns an array of proposed follow-up token ids.
    """

    @abc.abstractmethod
    def __call__(
        self,
        input_ids: npt.NDArray[np.intc],
        /,
        **kwargs: Any,
    ) -> npt.NDArray[np.intc]:
        """Return draft token ids for ``input_ids``.

        Args:
            input_ids: Token ids of the sequence so far (positional-only).
            **kwargs: Extra implementation-specific options.

        Returns:
            An array of proposed token ids.

        Raises:
            NotImplementedError: Always, in this abstract base; concrete
                subclasses must override the method.
        """
        raise NotImplementedError
|
|
|
|
|
|
|
|
class LlamaPromptLookupDecoding(LlamaDraftModel):
    """Draft model that proposes tokens by looking them up in the input itself.

    The trailing n-gram of the token sequence is searched for earlier in the
    same sequence; the tokens that followed that earlier occurrence are
    returned as the draft.

    Based on https://github.com/apoorvumang/prompt-lookup-decoding
    """

    def __init__(self, max_ngram_size: int = 2, num_pred_tokens: int = 10):
        """Store the lookup settings.

        Args:
            max_ngram_size: Longest trailing n-gram to try matching.
            num_pred_tokens: Upper bound on the number of draft tokens returned.
        """
        self.num_pred_tokens = num_pred_tokens
        self.max_ngram_size = max_ngram_size

    @staticmethod
    def find_candidate_pred_tokens(
        input_ids: npt.NDArray[np.intc],
        max_ngram_size: int,
        num_pred_tokens: int,
    ) -> npt.NDArray[np.intc]:
        """Find a continuation for the trailing n-gram of ``input_ids``.

        N-gram sizes are tried from the largest allowed down to 1. For each
        size, occurrences of the trailing n-gram are visited from the start
        of the sequence, and the first one that is followed by at least one
        token wins.

        Args:
            input_ids: 1-D array of token ids.
            max_ngram_size: Largest n-gram size to try; it is capped at
                ``len(input_ids) - 1``.
            num_pred_tokens: Maximum number of tokens to return.

        Returns:
            A slice of ``input_ids`` holding up to ``num_pred_tokens`` tokens
            that followed the matched n-gram, or an empty ``np.intc`` array
            when no occurrence has a continuation.
        """
        total = input_ids.shape[0]
        largest = min(max_ngram_size, total - 1)

        for size in range(largest, 0, -1):
            # The n-gram we want to find again is the tail of the sequence.
            tail = input_ids[-size:]

            # Every contiguous run of `size` tokens, one row per start offset.
            candidates = np.lib.stride_tricks.sliding_window_view(
                input_ids, (size,)
            )

            # Start offsets whose window equals the tail, in ascending order.
            hits = np.nonzero((candidates == tail).all(axis=1))[0]

            for hit in hits:
                begin = hit + size
                stop = min(begin + num_pred_tokens, total)
                if begin >= stop:
                    # Nothing follows this occurrence (the tail matches
                    # itself at the very end); look at the next one.
                    continue
                return input_ids[begin:stop]

        # No occurrence of any n-gram size had a continuation.
        return np.array([], dtype=np.intc)

    def __call__(
        self,
        input_ids: npt.NDArray[np.intc],
        /,
        **kwargs: Any,
    ) -> npt.NDArray[np.intc]:
        """Return draft tokens for ``input_ids`` using the stored settings.

        Args:
            input_ids: 1-D array of token ids (positional-only).
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The result of :meth:`find_candidate_pred_tokens` called with this
            instance's ``max_ngram_size`` and ``num_pred_tokens``.
        """
        return self.find_candidate_pred_tokens(
            input_ids=input_ids,
            num_pred_tokens=self.num_pred_tokens,
            max_ngram_size=self.max_ngram_size,
        )
|
|
|