import collections
import fnmatch
import gc
import itertools
import time
from functools import wraps
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Literal,
    Optional,
    Tuple,
    Type,
    Union,
)

import torch
import transformers

from lm_eval.utils import eval_logger


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizerBase
    from transformers.configuration_utils import PretrainedConfig


def chunks(_iter, n: int = 0, fn=None):
    """
    Divides an iterable into chunks of specified size or based on a given function.
    Useful for batching.

    Parameters:
    - _iter: The input iterable to be divided into chunks.
    - n: An integer representing the size of each chunk. Default is 0.
    - fn: A function that takes the current index and the iterable as arguments
      and returns the size of the chunk. Default is None.

    Returns:
    An iterator that yields chunks of the input iterable.

    Example usage:
    ```
    data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    for chunk in chunks(data, 3):
        print(chunk)
    ```
    Output:
    ```
    [1, 2, 3]
    [4, 5, 6]
    [7, 8, 9]
    [10]
    ```
    """
    arr = []
    for i, x in enumerate(_iter):
        arr.append(x)
        if len(arr) == (fn(i, _iter) if fn else n):
            yield arr
            arr = []

    if arr:
        yield arr


class MultiChoice:
    def __init__(self, choices) -> None:
        self.choices = choices

    # Simple wildcard support (fnmatch-style filename patterns)
    def __contains__(self, values) -> bool:
        for value in values.split(","):
            if len(fnmatch.filter(self.choices, value)) == 0:
                eval_logger.info("Available tasks to choose:")
                for choice in self.choices:
                    eval_logger.info(f"  - {choice}")
                raise ValueError(f"'{value}' is not in task list")
        return True

    def __iter__(self) -> Iterator:
        for choice in self.choices:
            yield choice

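# A minimal usage sketch for `MultiChoice`, kept as comments so the module stays
# import-safe; the task names below are hypothetical:
#
#   tasks = MultiChoice(["arc_easy", "arc_challenge", "hellaswag"])
#   "arc_*" in tasks               # True: wildcards follow fnmatch semantics
#   "arc_easy,hellaswag" in tasks  # True: comma-separated values are checked one by one
#   "winogrande" in tasks          # logs the available choices, then raises ValueError
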
class Grouper:
    """
    Takes an array `arr` and a function `fn` and returns a dictionary
    with keys `fn(ob)` for each ob in `arr` and with values `self.arr[key]`, a list of all
    objects in `arr` satisfying `key == fn(ob)`.
    """

    def __init__(self, arr, fn) -> None:
        # remember the original length so the order can be restored later
        self.size = len(arr)
        arr = list(enumerate(arr))

        def group_return_dict(arr, fn):
            res = collections.defaultdict(list)

            for ob in arr:
                res[fn(ob)].append(ob)
            return res

        arr = group_return_dict(arr, lambda x: fn(x[1]))

        # self.arr maps each key to a list of (original_index, element) tuples
        self.arr = arr
        self._grouped = None

    def get_grouped(self):
        # return the contents, but not the indices, of our grouped dict
        if self._grouped:
            return self._grouped
        grouped = {}
        for key in self.arr.keys():
            # drop the index from each element of self.arr
            grouped[key] = [y[1] for y in self.arr[key]]
        self._grouped = grouped
        return grouped

    def get_original(self, grouped_dict):
        # take in a grouped dictionary with, e.g., results for each key listed
        # in the same order as the instances in `self.arr`, and
        # return the results in the same (single list) order as the original input
        res = [None] * self.size
        cov = [False] * self.size

        assert grouped_dict.keys() == self.arr.keys()

        for key in grouped_dict.keys():
            for (ind, _), v in zip(self.arr[key], grouped_dict[key]):
                res[ind] = v
                cov[ind] = True

        assert all(cov)

        return res

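# Illustrative usage sketch for `Grouper` (inputs are hypothetical):
#
#   grouper = Grouper(["a", "bb", "c", "ddd"], fn=len)
#   grouper.get_grouped()
#   # -> {1: ["a", "c"], 2: ["bb"], 3: ["ddd"]}
#   grouper.get_original({1: ["A", "C"], 2: ["BB"], 3: ["DDD"]})
#   # -> ["A", "BB", "C", "DDD"]   (original order restored)
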
def pad_and_concat(
    max_length: int,
    tensors: List[torch.Tensor],
    padding_side: Literal["right", "left"] = "right",
):
    """
    Pads a list of tensors up to the maximum tensor
    length in the batch and concatenates them. Used for batching inputs and
    continuations in seq2seq models.
    """
    assert padding_side in ("left", "right"), (
        f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
    )

    for i, tensor in enumerate(tensors):
        if len(tensor.shape) == 2:
            tensor = tensor.squeeze(0)  # in case a [1, seq] tensor was passed
        tensor_len = tensor.shape[0]
        if tensor_len < max_length:
            if padding_side == "right":
                # right-pad
                tensors[i] = torch.cat(
                    [
                        tensor,  # [seq]
                        torch.zeros(
                            max_length - tensor_len,
                            dtype=torch.long,
                            device=tensor.device,
                        ),  # [max_length - seq]
                    ],
                    dim=0,
                ).unsqueeze(0)
            else:
                # left-pad
                tensors[i] = torch.cat(
                    [
                        torch.zeros(
                            max_length - tensor_len,
                            dtype=torch.long,
                            device=tensor.device,
                        ),  # [max_length - seq]
                        tensor,  # [seq]
                    ],
                    dim=0,
                ).unsqueeze(0)
        else:
            tensors[i] = tensor.unsqueeze(0)

    return torch.cat(tensors, dim=0)

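# Illustrative sketch: left-pad two token sequences to a common length of 4.
#
#   a = torch.tensor([1, 2, 3])
#   b = torch.tensor([4])
#   pad_and_concat(4, [a, b], padding_side="left")
#   # -> tensor([[0, 1, 2, 3],
#   #            [0, 0, 0, 4]])   shape [2, 4]; the pad value is 0
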
def clear_torch_cache() -> None:
    gc.collect()
    torch.cuda.empty_cache()


def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig."""
    if isinstance(dtype, str) and dtype != "auto":
        # Convert `str` args to a torch dtype: "float16" -> torch.float16
        _torch_dtype = getattr(torch, dtype)
    else:
        _torch_dtype = dtype
    return _torch_dtype

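# Illustrative sketch:
#
#   get_dtype("bfloat16")     # -> torch.bfloat16
#   get_dtype(torch.float32)  # -> torch.float32 (passed through unchanged)
#   get_dtype("auto")         # -> "auto" (left for HF to resolve downstream)
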
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence."""

    def __init__(
        self,
        sequence: str,
        tokenizer: transformers.PreTrainedTokenizer,
        initial_decoder_input_length: int,
        batch_size: int,
    ) -> None:
        self.initial_decoder_input_length = initial_decoder_input_length
        self.done_tracker = [False] * batch_size
        self.sequence = sequence
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        # We look back 2 more tokens than it takes to encode our stop sequence,
        # because the model may produce the same string under a different
        # tokenization (e.g. generate ['\n', '\n'] while `sequence` encodes to
        # ['\n\n']), and we don't want to miss a stop because of that mismatch.
        #
        # NOTE: the extra lookback carries a minor risk of peeking 2 tokens into
        # the prompt and stopping generation immediately; slicing off the first
        # `initial_decoder_input_length` tokens in __call__ prevents that.
        self.sequence_id_len = len(self.sequence_ids) + 2
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # For efficiency, only decode and compare the tail of the generation:
        # drop the prompt, then keep the last `sequence_id_len` tokens.
        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :]

        lookback_ids_batch = lookback_ids_batch[:, -self.sequence_id_len :]

        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)

        for i, done in enumerate(self.done_tracker):
            if not done:
                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
        return False not in self.done_tracker

def stop_sequences_criteria(
    tokenizer: transformers.PreTrainedTokenizer,
    stop_sequences: List[str],
    initial_decoder_input_length: int,
    batch_size: int,
) -> transformers.StoppingCriteriaList:
    return transformers.StoppingCriteriaList(
        [
            MultiTokenEOSCriteria(
                sequence, tokenizer, initial_decoder_input_length, batch_size
            )
            for sequence in stop_sequences
        ]
    )

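# Illustrative sketch of wiring the stopping criteria into HF `generate`
# (the "gpt2" checkpoint is just an example; any causal LM works):
#
#   tok = transformers.AutoTokenizer.from_pretrained("gpt2")
#   model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
#   enc = tok("Q: What is 2+2?\nA:", return_tensors="pt")
#   criteria = stop_sequences_criteria(
#       tok, ["\n\n"], enc["input_ids"].shape[1], batch_size=1
#   )
#   out = model.generate(**enc, stopping_criteria=criteria, max_new_tokens=32)
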
def undistribute(iterable):
    """
    Undoes https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distribute .

    Re-interleaves results that have been split using more_itertools.distribute:
    >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
    >>> list(group_1)
    [1, 3, 5]
    >>> list(group_2)
    [2, 4, 6]
    >>> undistribute([group_1, group_2])
    [1, 2, 3, 4, 5, 6]

    Handles non-uniform component lengths:

    >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
    >>> [list(c) for c in children]
    [[1, 4, 7], [2, 5], [3, 6]]
    >>> undistribute(children)
    [1, 2, 3, 4, 5, 6, 7]

    Also handles when some iterables are empty:

    >>> children = distribute(5, [1, 2, 3])
    >>> [list(c) for c in children]
    [[1], [2], [3], [], []]
    >>> undistribute(children)
    [1, 2, 3]

    """
    return [
        x
        for x in itertools.chain.from_iterable(
            itertools.zip_longest(*[list(x) for x in iterable])
        )
        # zip_longest pads the shorter groups with None, so the fill values are
        # filtered back out here (this assumes None never appears in the data)
        if x is not None
    ]


def retry_on_specific_exceptions(
    on_exceptions: List[Type[Exception]],
    max_retries: Optional[int] = None,
    backoff_time: float = 3.0,
    backoff_multiplier: float = 1.5,
    on_exception_callback: Optional[Callable[[Exception, float], Any]] = None,
):
    """Retry on an LLM Provider's rate limit error with exponential backoff
    For example, to use for OpenAI, do the following:
    ```
    from openai import RateLimitError

    # Recommend specifying max_retries to avoid infinite loops!
    @retry_on_specific_exceptions([RateLimitError], max_retries=3)
    def completion(...):
        # Wrap OpenAI completion function here
        ...
    ```
    """

    def decorator(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            sleep_time = backoff_time
            attempt = 0
            while max_retries is None or attempt < max_retries:
                try:
                    return func(*args, **kwargs)
                except tuple(on_exceptions) as e:
                    if on_exception_callback is not None:
                        on_exception_callback(e, sleep_time)
                    time.sleep(sleep_time)
                    sleep_time *= backoff_multiplier
                    attempt += 1
                    # Re-raise once retries are exhausted, rather than falling
                    # out of the loop and silently returning None.
                    if max_retries is not None and attempt >= max_retries:
                        raise

        return wrapper

    return decorator


class Collator:
    """
    A class for reordering and batching elements of an array.

    This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data.

    Objects of this class have the group_by attribute, which determines the method for grouping
    the data while batching it. The three options are "gen_kwargs", "contexts", or None:
        If group_by == "gen_kwargs", then requests will be grouped by gen_kwargs.
        If group_by == "contexts", then requests will be grouped by context + cont[:-1].
        If None, then requests will just be reordered by length, descending.
    """

    def __init__(
        self,
        arr: List,
        sort_fn: Callable = lambda x: x,
        group_fn: Callable = lambda x: x[1],
        group_by: Union[Literal["gen_kwargs", "contexts"], None] = None,
    ) -> None:
        self._group_by = group_by
        # elements are stored as (index, element) tuples, so the user-provided
        # sort/group fns are applied to the element (position 1), not the index
        self._sort_fn = lambda x: sort_fn(x[1])
        self._group_fn = lambda x: group_fn(x[1])
        self._reorder_indices: List = []
        self._size = len(arr)
        self._arr_with_indices: Union[Dict, Tuple[Tuple[int, Any], ...]] = tuple(
            enumerate(arr)
        )
        if self._group_by == "contexts":
            self._group_by_context()
        elif self._group_by == "gen_kwargs":
            self._group_by_index()

    def _group_by_index(self) -> None:
        """Group the elements of a list based on their indices."""
        self._arr_with_indices = self.group(
            self._arr_with_indices, fn=self._group_fn, group_by="gen_kwargs"
        )

    def _group_by_context(self) -> None:
        """Group the array with indices by context."""
        self._arr_with_indices = self.group(
            self._arr_with_indices, fn=self._group_fn, group_by="contexts"
        )

    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
        """
        Generates and yields batches from the reordered array. The method of grouping and batching
        depends on the parameter `group_by`.
        If `group_by` is set to "gen_kwargs", it batches the re-ordered values
        with the same gen_kwargs for each batch.
        If `group_by` is "contexts", it caches the requests by context before batching.
        If `group_by` is neither "gen_kwargs" nor "contexts", it yields the reordered array in batches.

        Parameters:
        - n (int): The size of each batch. Defaults to 1.
        - batch_fn (Optional[Callable[[int, Iterable], int]]): A function to determine the size of
          each batch. Optional, defaults to None.

        Returns:
        Iterator: An iterator over batches of reordered elements grouped as per the `group_by`
          attribute.

        Yields:
        List of batched elements according to the `group_by` attribute.
        """
        if self._group_by == "gen_kwargs":
            for (
                key,
                values,
            ) in self._arr_with_indices.items():
                values = self._reorder(values)
                batch = self.get_chunks(values, n=n, fn=batch_fn)
                yield from batch
        elif self._group_by == "contexts":
            # Get one sample from each key
            values = self._reorder(
                [value[0] for value in self._arr_with_indices.values()]
            )
            batch = self.get_chunks(values, n=n, fn=batch_fn)
            yield from batch
        else:
            values = self._reorder(self._arr_with_indices)
            batch = self.get_chunks(values, n=n, fn=batch_fn)
            yield from batch

    def get_cache(
        self,
        req_str: Optional[Tuple[str, str]] = None,
        cxt_toks: Optional[List[int]] = None,
        cont_toks: Optional[List[int]] = None,
        logits: Optional[torch.Tensor] = None,
    ) -> Iterator[Tuple[Tuple[str, str], List[int], torch.Tensor]]:
        """
        Retrieves cached single-token continuations and their associated arguments, updating indices as necessary.

        The behavior of this function varies depending on how the `group_by` attribute is set:

        - When `group_by` is "contexts":
            The function identifies single-token continuations by checking for keys that equate to
            [context+continuation][:-1] and logs the indices for re-ordering.
            In this mode, this function can work in two scenarios:

            1. Cache Hit - Single Match:
                If a single matching context-continuation pair is found in the cache,
                the function yields the original arguments.

            2. Cache Hit - Multiple Matches:
                If multiple matching context-continuation pairs are found in the cache,
                the function expands the logits batch dimension to match the number of cache hits.
                It updates the original requests and continuation tokens.

        - When `group_by` is not set to "contexts":
            This method yields the original arguments, logits and continuation tokens,
            without checking for one-token continuations.

        Parameters:
        - req_str (tuple[str, str]): Original strings used for CachingLM.
        - cxt_toks (list[int]): Full context tokens used for lookup.
        - cont_toks (list[int]): Continuation tokens for which logits were generated.
        - logits (torch.Tensor [1, seq_length, vocab_size]): Logits generated by the model given context and continuation keys.

        Yields:
        - Iterator:
            - req_str (tuple[str, str]): Strings used for CachingLM.
            - cont_toks (list[int]): Continuation tokens.
            - logits (torch.Tensor [1, seq_length, vocab_size]): The original logits (repeated cache hit times)
        """
        if self._group_by == "contexts":
            cache_hit: List[
                Tuple[int, Tuple[Tuple[str, str], List[int], List[int]]]
            ] = self._arr_with_indices.pop(tuple(cxt_toks + cont_toks[:-1]))
            if (cache_size := len(cache_hit)) == 1:
                self._reorder_indices.extend(x[0] for x in cache_hit)
                yield req_str, cont_toks, logits
            else:
                # If we have matching requests, expand the batch dimension
                # (no-op) and yield each along with its corresponding args.
                multilogits = logits.expand(cache_size, -1, -1).chunk(cache_size)
                indices, req_str, cont_toks = zip(
                    *[(x[0], x[1][0], x[-1][-1]) for x in cache_hit]
                )
                self._reorder_indices.extend(indices)
                for c_key, cont_tok, logit in zip(req_str, cont_toks, multilogits):
                    yield c_key, cont_tok, logit
        else:
            yield req_str, cont_toks, logits

    def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> Iterator:
        """
        Reorders the elements in the array based on the sorting function.

        Parameters:
        - arr (list | tuple[tuple[int, Any], ...]): The array or iterable to be reordered.

        Yields:
        Iterator
        """
        arr = sorted(arr, key=self._sort_fn)
        if self._group_by != "contexts":
            # if grouped by contexts, the re-ordering indices are recorded in get_cache() instead
            self._reorder_indices.extend([x[0] for x in arr])
        yield from [x[1] for x in arr]

    def get_original(self, newarr: List) -> List:
        """
        Restores the original order of elements from the reordered list.

        Parameters:
        - newarr (list): The reordered array.

        Returns:
        list: The array with elements restored to their original order.
        """
        res = [None] * self._size
        cov = [False] * self._size

        for ind, v in zip(self._reorder_indices, newarr):
            res[ind] = v
            cov[ind] = True

        assert all(cov)

        return res

    def __len__(self):
        return self._size

    @staticmethod
    def group(
        arr: Iterable,
        fn: Callable,
        group_by: Literal["gen_kwargs", "contexts"] = "gen_kwargs",
    ) -> dict:
        """
        Groups elements of an iterable based on a provided function.

        The `group_by` parameter determines the method of grouping.
        If `group_by` is "contexts", the elements are grouped by [context + cont][:-1].
        If `group_by` is "gen_kwargs", the elements are grouped based on the gen_kwargs dict.

        Parameters:
        - arr (Iterable): The iterable to be grouped.
        - fn (Callable): The function to determine the grouping.
        - group_by (Literal["gen_kwargs", "contexts"]): The method of grouping. Defaults to "gen_kwargs".

        Returns:
        dict: A dictionary mapping group keys to lists of grouped elements.
        """
        res = collections.defaultdict(list)
        for ob in arr:
            # fn(ob) returns the group key: a token list for "contexts",
            # or the gen_kwargs dict for "gen_kwargs"
            if group_by == "contexts":
                res[tuple(fn(ob))].append(ob)
            else:
                try:
                    hashable_dict = tuple(
                        (
                            key,
                            tuple(value)
                            if isinstance(value, collections.abc.Iterable)
                            else value,
                        )
                        for key, value in sorted(fn(ob).items())
                    )
                    res[hashable_dict].append(ob)
                except (TypeError, AttributeError):
                    res[tuple(fn(ob))].append(ob)
        return res

    @staticmethod
    def get_chunks(_iter, n: int = 0, fn=None):
        """
        Divides an iterable into chunks of specified size or based on a given function.
        Useful for batching.

        Parameters:
        - _iter: The input iterable to be divided into chunks.
        - n: An integer representing the size of each chunk. Default is 0.
        - fn: A function that takes the current index and the iterable as arguments
          and returns the size of the chunk. Default is None.

        Returns:
        An iterator that yields chunks of the input iterable.

        Example usage:
        ```
        data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        for chunk in Collator.get_chunks(data, 3):
            print(chunk)
        ```
        Output:
        ```
        [1, 2, 3]
        [4, 5, 6]
        [7, 8, 9]
        [10]
        ```
        """
        arr = []
        _iter = tuple(_iter)
        for i, x in enumerate(_iter):
            arr.append(x)
            if len(arr) == (fn(i, _iter) if fn else n):
                yield arr
                arr = []

        if arr:
            yield arr

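# Illustrative end-to-end sketch for `Collator` (the requests are hypothetical):
#
#   reqs = ["bbbb", "a", "ccc", "dd"]
#   collator = Collator(reqs, sort_fn=lambda x: -len(x))  # longest-first
#   batches = list(collator.get_batched(n=2))
#   # -> [["bbbb", "ccc"], ["dd", "a"]]
#   results = [r.upper() for batch in batches for r in batch]
#   collator.get_original(results)
#   # -> ["BBBB", "A", "CCC", "DD"]   (original request order restored)
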
def configure_pad_token(
    tokenizer: "PreTrainedTokenizerBase",
    model_config: Optional["PretrainedConfig"] = None,
) -> "PreTrainedTokenizerBase":
    """
    This function checks if the (Hugging Face) tokenizer has a padding token and sets it if not present.
    Some tokenizers require special handling.

    Args:
        tokenizer: The tokenizer for which the padding token is to be handled.
        model_config: The configuration of the model. Default is None.

    Returns:
        The tokenizer after the padding token has been handled.

    Raises:
        AssertionError: If the tokenizer is of type RWKVWorldTokenizer or Rwkv5Tokenizer and the padding token id is not 0.
    """
    if tokenizer.pad_token:
        pass
    elif tokenizer.unk_token:
        tokenizer.pad_token_id = tokenizer.unk_token_id
    elif tokenizer.eos_token:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    else:
        # handle special cases
        if model_config and getattr(model_config, "model_type", None) == "qwen":
            # Qwen's trust_remote_code tokenizer does not allow for adding special tokens
            tokenizer.pad_token = "<|endoftext|>"
        elif (
            tokenizer.__class__.__name__ == "RWKVWorldTokenizer"
            or tokenizer.__class__.__name__ == "Rwkv5Tokenizer"
        ):
            # The RWKV world tokenizer does not allow adding special tokens or
            # setting the pad token (which is hard-coded as id 0); the extra
            # class-name check is needed because some RWKV models ship with a
            # NeoX tokenizer instead.
            assert tokenizer.pad_token_id == 0
        else:
            tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

    return tokenizer

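# Illustrative sketch ("gpt2" is just an example checkpoint):
#
#   tok = transformers.AutoTokenizer.from_pretrained("gpt2")
#   tok.pad_token is None   # True: GPT-2 ships without a pad token
#   tok = configure_pad_token(tok)
#   tok.pad_token           # "<|endoftext|>": falls back to GPT-2's unk/eos token
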
def replace_placeholders(
    string: str, default_placeholder: str, image_token: str, max_images: int
):
    """
    A utility function used for local multimodal models. It locates all occurrences of
    `default_placeholder` in the given input `string` and replaces the first `max_images`
    instances with `image_token`; any further occurrences are left as `default_placeholder`
    (or dropped entirely when `default_placeholder == image_token`).

    This is used to replace <image> placeholder tags with model-specific image tokens like <|image_pad|>
    and to allow only the first `max_images` images to be passed to a model if desired.

    :param string: The original string containing placeholders.
    :param default_placeholder: The placeholder text to be replaced.
    :param image_token: The token to replace the placeholder with.
    :param max_images: The maximum number of replacements to make.
    :return: The string with placeholders replaced.
    """
    count = 0
    result = []

    parts = string.split(default_placeholder)
    for part in parts[:-1]:
        result.append(part)
        if count < max_images:
            result.append(image_token)
            count += 1
        elif default_placeholder != image_token:
            result.append(default_placeholder)

    # add back the text after the final placeholder
    result.append(parts[-1])
    return "".join(result)

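# Illustrative sketch:
#
#   replace_placeholders(
#       "<image> a cat <image> a dog", "<image>", "<|image_pad|>", max_images=1
#   )
#   # -> "<|image_pad|> a cat <image> a dog"
#   # (first placeholder swapped for the model token; the rest left as-is)
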
def flatten_image_list(images: List[List]):
    """
    Takes in a list of lists of images, and returns a single list of all images in order.
    Used for some multimodal models like Llava-1.5 which expect this flattened-list format for their image processor.

    :param images: A list of lists of PIL images.
    :return: a list of PIL images, via concatenating all the sub-lists in order.
    """
    return [image for image_list in images for image in image_list]

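# Illustrative sketch (PIL images in practice; ints shown for brevity):
#
#   flatten_image_list([[1, 2], [3], [4, 5]])
#   # -> [1, 2, 3, 4, 5]
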
def handle_stop_sequences(
    until: Union[str, List[str], None], eos: Optional[str]
) -> List[str]:
    """Ensures that the `until` parameter is a list of stop sequences and includes the EOS token."""
    if isinstance(until, str):
        until = [until]
    elif until is None:
        until = []
    elif isinstance(until, list):
        # copy so that appending the EOS token below never mutates the caller's list
        until = list(until)
    else:
        raise ValueError(
            f"Expected `kwargs['until']` to be of type Union[str, list] but got {until}"
        )

    if eos is not None and eos not in until:
        until.append(eos)
    return until
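# Illustrative sketch:
#
#   handle_stop_sequences("###", eos="</s>")         # -> ["###", "</s>"]
#   handle_stop_sequences(None, eos="</s>")          # -> ["</s>"]
#   handle_stop_sequences(["\n\n", "</s>"], "</s>")  # -> ["\n\n", "</s>"] (no duplicate EOS)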