| import abc |
| import hashlib |
| import json |
| import logging |
| import os |
| from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union |
|
|
| import transformers |
| from sqlitedict import SqliteDict |
| from tqdm import tqdm |
|
|
| from dllm_eval import utils |
|
|
|
|
# Logger shared by every class in this module.
eval_logger = logging.getLogger(name=__name__)

# Type variable for the alternate constructors on ``LM`` so that calling them on a
# subclass is typed as returning that subclass. Bound by name because ``LM`` is
# defined further down in this module.
T = TypeVar("T", bound="LM")
|
|
|
|
class LM(abc.ABC):
    def __init__(self) -> None:
        """Base interface that every LM subclass must implement.

        LMs take text (strings) as input and produce strings as output, so
        inputs and outputs are tokenization-agnostic.
        """
        # Single-process defaults: rank 0 out of a world of size 1.
        self._rank, self._world_size = 0, 1
        # Hook with no backing cache; CachingLM replaces it via set_cache_hook().
        self.cache_hook = CacheHook(None)

    @abc.abstractmethod
    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
        """Compute the log-likelihood of a continuation given a context.

        Downstream tasks should prefer this over other LM calls whenever
        possible.

        :param requests: list[Instance]
            Instances whose `args` property is a tuple (context, continuation).
            `context: str`
                Context string. Implementations must accept an empty context.
            `continuation: str`
                Text whose log-likelihood is scored. A word-boundary space
                belongs to the continuation, e.g. context="hello",
                continuation=" world".

        :return: list[tuple[float, bool]]
            One (logprob, isgreedy) pair per request.
            `logprob: float`
                Log probability of `continuation`.
            `isgreedy`:
                Whether greedy decoding from `context` would produce `continuation`.
        """

    @abc.abstractmethod
    def loglikelihood_rolling(self, requests) -> List[float]:
        """Compute the full, untruncated log-likelihood of a string (for perplexity).

        - The model's full maximum context length is used.
        - Inputs longer than the max context length are split into chunks of
          at most that length.
        - IMPORTANT: each document's loglikelihood/perplexity is computed
          *separately*, unlike implementations that concatenate documents.
        - IMPORTANT: context is maximised for every prediction; when an input
          is split into several chunks, the last chunk still gets a
          full-sized context.
        Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: BOS/EOS
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  BOS   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Note that:
            1. Every token is predicted exactly once.
            2. The last pair receives a full context but only its final two
               tokens are scored.

        :param requests: list[Instance]
            Instances whose `args` property is a tuple (context,).
            string: str
                String whose overall loglikelihood is computed.
        :return: list[float]
            One log probability per request: the log probability of `context`
            conditioned on the BOS/EOS token (overridable via `prefix_token_id`
            in subclasses that define it).
            NOTE(review): earlier docs described this as a list of `(logprob,)`
            tuples while the annotation says list[float] — confirm against
            implementations.
        """

    @abc.abstractmethod
    def generate_until(self, requests) -> List[str]:
        """Generate greedily until a stopping sequence is produced.

        :param requests: list[Instance]
            Instances whose `args` property is a tuple (context, gen_kwargs).
            context: str
                Context string.
            gen_kwargs: dict
                Keyword arguments for the generation function, e.g. top_k, until.
        :return: list[str]
            One generated continuation (str) per request.
        """

    def apply_chat_template(
        self, chat_history: List[Dict[str, str]], add_generation_prompt=True
    ) -> str:
        """Render few-shot examples given as chat history into an LM input string.

        :param chat_history: list[dict[str, str]]
            Messages as dictionaries with keys 'role' and 'content'.
        :param add_generation_prompt: bool
            Whether to append an assistant generation prefix (e.g.
            <|assistant|>) after the messages. Pass False when prefilling an
            assistant message.
        :return: str
            The formatted chat history.
        :raises NotImplementedError: always, in this base class; subclasses
            that support chat templates must override it.
        """
        raise NotImplementedError(
            "To use this model with chat templates, please implement the 'apply_chat_template' method for your model type."
        )

    @classmethod
    def create_from_arg_string(
        cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
    ) -> T:
        """Instantiate the class from a ``key1=value1,key2=value2`` string.

        Parameters:
        - arg_string: comma-separated ``key=value`` pairs, parsed by
          ``utils.simple_parse_args_string``.
        - additional_config: optional extra keyword arguments; entries whose
          value is None are dropped.

        Returns:
        - A new instance of ``cls``. A key present in both sources is passed
          twice and therefore raises ``TypeError``.
        """
        parsed_args = utils.simple_parse_args_string(arg_string)
        extra_args = {
            key: value
            for key, value in (additional_config or {}).items()
            if value is not None
        }
        return cls(**parsed_args, **extra_args)

    @classmethod
    def create_from_arg_obj(
        cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
    ) -> T:
        """Instantiate the class from an already-parsed argument dictionary.

        Parameters:
        - arg_dict: keyword arguments passed to the constructor as-is.
        - additional_config: optional extra keyword arguments; entries whose
          value is None are dropped.

        Returns:
        - A new instance of ``cls``. A key present in both dictionaries is
          passed twice and therefore raises ``TypeError``.
        """
        extra_args = {
            key: value
            for key, value in (additional_config or {}).items()
            if value is not None
        }
        return cls(**arg_dict, **extra_args)

    @property
    def rank(self):
        """Rank of this process (0 unless a subclass sets ``_rank``)."""
        return self._rank

    @property
    def world_size(self):
        """Number of processes (1 unless a subclass sets ``_world_size``)."""
        return self._world_size

    @property
    def tokenizer_name(self) -> str:
        """Name of the tokenizer or chat template used by this model.

        Subclasses that implement chat templating must define it. It is only
        used to fingerprint request caches when `--cache_requests` is enabled.
        """
        raise NotImplementedError(
            "To use this model with chat templates, please implement the 'tokenizer_name' property."
        )

    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
        """Return the chat template string for this model.

        Subclasses that support chat templates override this. The base
        implementation ignores ``chat_template`` and returns the empty string.
        """
        return ""

    def set_cache_hook(self, cache_hook) -> None:
        """Install the hook used to record partial results into a cache."""
        self.cache_hook = cache_hook
|
|
|
|
| |
def hash_args(attr, args):
    """Return a stable SHA-256 hex digest identifying a call to ``attr`` with ``args``.

    The method name and its positional arguments are serialized together as a
    JSON list, so every element of ``args`` must be JSON-serializable.
    """
    payload = json.dumps([attr, *args])
    digest = hashlib.sha256(payload.encode("utf-8"))
    return digest.hexdigest()
|
|
|
|
class CacheHook:
    """Writes individual results into a CachingLM's cache.

    Created with ``None`` it has no backing store and every write is ignored.
    """

    def __init__(self, cachinglm) -> None:
        # Share the caching wrapper's dict-like store, or nothing at all.
        self.dbdict = None if cachinglm is None else cachinglm.dbdict

    def add_partial(self, attr, req, res) -> None:
        """Store ``res`` under the cache key for method ``attr`` called with ``req``."""
        if self.dbdict is not None:
            self.dbdict[hash_args(attr, req)] = res
|
|
|
|
class CachingLM:
    """Wrapper around an LM that serves request results from an on-disk cache.

    Calls to ``loglikelihood``, ``loglikelihood_rolling`` and ``generate_until``
    are answered from the cache where possible; only the misses are forwarded
    to the wrapped LM, and their results are stored. Every other attribute is
    forwarded to the wrapped LM unchanged.
    """

    # Request-level methods whose results are looked up in / written to the cache.
    _CACHED_METHODS = ("loglikelihood", "loglikelihood_rolling", "generate_until")

    def __init__(self, lm, cache_db) -> None:
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db; its parent directory is created if missing.
        """
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # Let the wrapped LM write partial results straight into this cache.
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr: str):
        # ``__getattr__`` only runs when normal lookup fails. Read ``lm`` from the
        # instance dict instead of ``self.lm``: on an instance whose ``__init__``
        # has not run (e.g. during copy/unpickling), ``self.lm`` would re-enter
        # ``__getattr__`` and recurse until RecursionError.
        try:
            lm = self.__dict__["lm"]
        except KeyError:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{attr}'"
            ) from None

        lm_attr = getattr(lm, attr)
        if attr not in self._CACHED_METHODS:
            eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
            return lm_attr

        def fn(requests):
            # One result slot per request, in input order. Cache hits are filled
            # immediately; misses are filled after calling the underlying LM.
            res = []
            remaining_reqs = []
            # For each miss: (its slot in ``res``, its cache key or None when the
            # result must not be cached).
            remaining_slots = []
            warned = False

            eval_logger.info(
                f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
            )
            for req in tqdm(requests, desc="Checking cached requests"):
                if attr == "generate_until" and req.args[1].get("do_sample", False):
                    # Sampled generations are non-deterministic: neither read them
                    # from nor write them to the cache.
                    if not warned:
                        eval_logger.warning(
                            f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
                        )
                        warned = True
                    remaining_slots.append((len(res), None))
                    remaining_reqs.append(req)
                    res.append(None)
                    continue

                hsh = hash_args(attr, req.args)
                if hsh in self.dbdict:
                    res.append(self.dbdict[hsh])
                else:
                    remaining_slots.append((len(res), hsh))
                    remaining_reqs.append(req)
                    res.append(None)

            eval_logger.info(
                f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
            )
            rem_res = getattr(self.lm, attr)(remaining_reqs) if remaining_reqs else []

            # Put fresh results into their recorded slots. Slots are tracked
            # explicitly, not found by scanning for the next ``None``, so a
            # ``None`` result cannot shift later results into the wrong position.
            wrote = False
            for (slot, hsh), r in zip(remaining_slots, rem_res):
                res[slot] = r
                if hsh is not None:
                    self.dbdict[hsh] = r
                    wrote = True
            if wrote:
                self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        """Return a CacheHook that writes into this wrapper's cache."""
        return CacheHook(self)
|
|
|
|
class TemplateLM(LM):
    """
    A class acting as intermediary between the LM base class
    and boilerplate often included in other LM subclasses.
    """

    # Subclasses that wrap a tokenizer assign it here; ``chat_template`` reads it
    # and treats ``None`` as "no tokenizer".
    tokenizer = None

    @property
    @abc.abstractmethod
    def eot_token_id(self):
        """Token id that subclasses report as end-of-text.

        Also serves as the default ``prefix_token_id``.
        """
        pass

    @property
    def prefix_token_id(self):
        """Token id used as the whole context when a loglikelihood request has an
        empty context. Defaults to ``eot_token_id``; subclasses may override it."""
        return self.eot_token_id

    @abc.abstractmethod
    def tok_encode(self, string: str, **kwargs) -> List[int]:
        """
        Tokenize a string using the model's tokenizer and return a list of token IDs.
        """
        pass

    @abc.abstractmethod
    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
        """Score pre-tokenized requests.

        ``requests`` is a list of ``((context, continuation), context_enc,
        continuation_enc)`` tuples, as built by ``loglikelihood``.
        """
        pass

    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
        """Tokenize a (context, continuation) pair for loglikelihood scoring.

        Trailing whitespace (as removed by ``str.rstrip``) is moved from the end
        of ``context`` to the start of ``continuation`` before tokenizing.

        For seq2seq models (``AUTO_MODEL_CLASS`` is
        ``transformers.AutoModelForSeq2SeqLM``) the two parts are tokenized
        separately, with no special tokens added to the continuation. Otherwise
        the concatenation is tokenized, and the continuation tokens are the
        tokens that follow the length of the separately-tokenized context.

        :return: (context_enc, continuation_enc)
        """
        # NOTE(review): presumably moved so the continuation carries its own
        # leading word-boundary whitespace, as the loglikelihood contract asks.
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]

        model_class = getattr(self, "AUTO_MODEL_CLASS", None)

        if model_class == transformers.AutoModelForSeq2SeqLM:
            context_enc = self.tok_encode(context)
            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
        else:
            whole_enc = self.tok_encode(context + continuation)
            context_enc = self.tok_encode(context)
            continuation_enc = whole_enc[len(context_enc):]

        return context_enc, continuation_enc

    def loglikelihood(
        self, requests, disable_tqdm: bool = False
    ) -> List[Tuple[float, bool]]:
        """Tokenize each (context, continuation) request and score it.

        An empty context is replaced by ``[self.prefix_token_id]``. Scoring is
        delegated to ``_loglikelihood_tokens``.
        """
        new_reqs = []
        for req in requests:
            context, continuation = req.args
            if context == "":
                # Nothing to condition on: use the prefix token as the context.
                context_enc = [self.prefix_token_id]
                continuation_enc = self.tok_encode(continuation)
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)

    @abc.abstractmethod
    def loglikelihood_rolling(
        self, requests, disable_tqdm: bool = False
    ) -> List[float]:
        """See ``LM.loglikelihood_rolling``; ``disable_tqdm`` turns off progress bars."""
        pass

    @abc.abstractmethod
    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
        """See ``LM.generate_until``; ``disable_tqdm`` turns off progress bars."""
        pass

    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
        """
        Set and get the appropriate chat template for the model.
        This method selects the tokenizer's chat template and returns the template string for reproducibility.

        The template selection logic is adapted from the Transformers library's `apply_chat_template`
        method in the Tokenizer class. The original implementation can be found at:
        https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1687

        This method ensures that the right template is chosen based on the following:
        0. If `self.tokenizer` is None: assumes that there is only a single possible chat template, handled on the model provider side internally. Returns the empty string.
        1. If `chat_template` is False or None: logs a warning and returns None.
        2. If the tokenizer has no usable `chat_template` and no `default_chat_template` attribute
           (removed in newer Transformers releases): returns None.
        3. If the template is a dictionary of named templates:
            a. Use the specified template if it exists in the dictionary.
            b. Use the "default" entry if no specific template is provided.
            c. Raise an error if the named template is missing, or if no name is given and there is no "default" entry.
        4. If the template is a single template or no template:
            a. Use the tokenizer's chat template if it is set (non-empty).
            b. Otherwise fall back to the tokenizer's class-level default chat template.
        A deprecation warning is logged whenever the selected template comes from the
        class-level default rather than the tokenizer's own `chat_template`.

        Args:
            chat_template (Union[bool, str]): Specifies the chat template to use.
                - If False or None, no template is applied.
                - If True, the default or only available template is used.
                - If a string, the template with the matching name is used.

        Returns:
            Optional[str]: The selected chat template, or None if no template is applied.

        Raises:
            ValueError: if a named template is requested but missing from a template
                dictionary, or no name is given and the dictionary has no "default".
        """
        if self.tokenizer is None:
            return ""

        if chat_template is False or chat_template is None:
            eval_logger.warning(
                "model.chat_template was called with the chat_template set to False or None. "
                "Therefore no chat template will be applied. Make sure this is an intended behavior."
            )
            return None

        # ``True`` means "use the default/only template", i.e. no specific name.
        if isinstance(chat_template, bool):
            chat_template = None
        using_default_template = False

        # ``default_chat_template`` is absent from newer Transformers releases; with
        # no tokenizer template either, there is nothing to select.
        try:
            template = (
                self.tokenizer.chat_template or self.tokenizer.default_chat_template
            )
        except AttributeError:
            return None

        # Mirror the ``or`` above: an empty tokenizer template counts as unset, so
        # ``template`` came from the class-level default.
        from_class_default = not self.tokenizer.chat_template

        if isinstance(template, dict):
            if chat_template is not None:
                if chat_template not in template:
                    raise ValueError(
                        f"The specified chat template '{chat_template}' is not available. "
                        f"Available template names are {sorted(template.keys())}."
                    )
                selected_template = template[chat_template]
            elif "default" in template:
                selected_template = template["default"]
            else:
                raise ValueError(
                    "This model has multiple chat templates with no default specified! Please either pass a chat "
                    "template or the name of the template you wish to use to the `chat_template` argument. Available "
                    f"template names are {sorted(template.keys())}."
                )
            # Only warn about a legacy default when the dict itself is the
            # class-level default, not the tokenizer's own template dict.
            using_default_template = from_class_default
        else:
            if isinstance(chat_template, str):
                eval_logger.warning(
                    "Chat template name provided, but the tokenizer's chat template is not a dictionary. "
                    "Using the tokenizer's chat template or the default template instead."
                )
            selected_template = template
            using_default_template = from_class_default

        if using_default_template:
            eval_logger.warning(
                "No chat template is set for this tokenizer, falling back to a default class-level template. This is "
                "very error-prone, because models are often trained with templates different from the class default! "
                "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
                "point any code depending on them will stop working. We recommend setting a valid chat template before "
                "then to ensure that this model continues working without issues."
            )

        return selected_template
|
|