| from __future__ import annotations |
|
|
| import hashlib |
| import json |
| from copy import deepcopy |
| from typing import Dict, List, Union |
|
|
| from mmengine.config import ConfigDict |
|
|
|
|
def safe_format(input_str: str, **kwargs) -> str:
    """Safely formats a string with the given keyword arguments.

    Every ``{key}`` placeholder whose ``key`` is present in ``kwargs`` is
    replaced by ``str(value)``. Placeholders without a matching keyword are
    left untouched, and keywords that do not occur in the string are ignored.

    The substitution is done in a single pass over the original string, so
    text inserted for one keyword is never re-scanned for other placeholders
    (a value such as ``'{answer}'`` stays literal).

    Args:
        input_str (str): The string to be formatted.
        **kwargs: The keyword arguments to be used for formatting.

    Returns:
        str: The formatted string.
    """
    import re

    if not kwargs:
        return input_str

    # Map each literal placeholder token, e.g. '{name}', to its replacement.
    replacements = {f'{{{k}}}': str(v) for k, v in kwargs.items()}
    # Longest tokens first so an unusual key can never shadow a longer one.
    tokens = sorted(replacements, key=len, reverse=True)
    pattern = re.compile('|'.join(re.escape(token) for token in tokens))
    return pattern.sub(lambda match: replacements[match.group(0)], input_str)
|
|
|
|
def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
    """Compute a SHA-256 digest identifying a dataset's prompt setup.

    A list with exactly one element is treated as that element. Any other
    list is hashed element by element and the comma-joined digests are
    hashed again. For a single configuration the digest is taken over the
    key-sorted JSON dump of ``dataset_cfg.infer_cfg``.

    Note that ``dataset_cfg.infer_cfg`` is normalised in place before it is
    dumped, so the caller's configuration object is modified.

    Args:
        dataset_cfg (ConfigDict or list[ConfigDict]): The dataset
            configuration.

    Returns:
        str: The hexadecimal hash of the prompt configuration.
    """
    if isinstance(dataset_cfg, list):
        if len(dataset_cfg) != 1:
            joined = ','.join(get_prompt_hash(cfg) for cfg in dataset_cfg)
            return hashlib.sha256(joined.encode()).hexdigest()
        dataset_cfg = dataset_cfg[0]

    infer_cfg = dataset_cfg.infer_cfg

    if 'reader_cfg' in infer_cfg:
        # Build a `reader` entry from the top-level reader configuration.
        top_reader = dataset_cfg.reader_cfg
        infer_cfg.reader = dict(type='DatasetReader',
                                input_columns=top_reader.input_columns,
                                output_column=top_reader.output_column)
        # Copy the split names from the embedded reader onto the retriever.
        embedded_reader = infer_cfg.reader_cfg
        if 'train_split' in embedded_reader:
            infer_cfg.retriever['index_split'] = embedded_reader[
                'train_split']
        if 'test_split' in embedded_reader:
            infer_cfg.retriever['test_split'] = embedded_reader.test_split
        # Keep only the last dotted component of every component type.
        for key, component in infer_cfg.items():
            infer_cfg[key]['type'] = component['type'].split('.')[-1]

    # Move `fix_id_list` from the retriever to the inferencer before hashing.
    retriever = infer_cfg.retriever
    if 'fix_id_list' in retriever:
        infer_cfg.inferencer['fix_id_list'] = retriever.pop('fix_id_list')

    serialized = json.dumps(infer_cfg.to_dict(), sort_keys=True)
    return hashlib.sha256(serialized.encode()).hexdigest()
|
|
|
|
class PromptList(list):
    """An enhanced list, used for intermediate representation of a prompt.

    Items are plain strings or dictionaries. For dictionaries, only the
    value stored under the ``'prompt'`` key is treated as prompt text.
    """

    def format(self, **kwargs) -> PromptList:
        """Fills ``{key}`` placeholders in every item of the PromptList.

        Dictionary items are deep-copied and only their ``'prompt'`` value
        (if present) is formatted. Every other item is passed to
        ``safe_format`` directly.

        Args:
            **kwargs: Placeholder names and the values to substitute for
                them, forwarded to ``safe_format``.

        Returns:
            PromptList: A new PromptList with the placeholders filled in.
        """
        new_list = PromptList()
        for item in self:
            if isinstance(item, dict):
                new_item = deepcopy(item)
                if 'prompt' in item:
                    new_item['prompt'] = safe_format(item['prompt'], **kwargs)
                new_list.append(new_item)
            else:
                new_list.append(safe_format(item, **kwargs))
        return new_list

    def replace(self, src: str, dst: Union[str, PromptList]) -> PromptList:
        """Replaces all instances of 'src' in the PromptList with 'dst'.

        When 'dst' is a PromptList, each string item is split on 'src' and
        the items of 'dst' are spliced in between the non-empty pieces.

        Args:
            src (str): The string to be replaced.
            dst (str or PromptList): The string or PromptList to replace with.

        Returns:
            PromptList: A new PromptList with 'src' replaced by 'dst'.

        Raises:
            TypeError: If 'dst' is a PromptList and 'src' is in a dictionary's
                'prompt' key.
        """
        new_list = PromptList()
        for item in self:
            if isinstance(item, str):
                if isinstance(dst, str):
                    new_list.append(item.replace(src, dst))
                elif isinstance(dst, PromptList):
                    split_str = item.split(src)
                    last_index = len(split_str) - 1
                    for i, split_item in enumerate(split_str):
                        # Empty pieces (src at an edge or repeated) are
                        # dropped rather than kept as '' items.
                        if split_item:
                            new_list.append(split_item)
                        # One copy of dst per occurrence of src.
                        if i < last_index:
                            new_list += dst
            elif isinstance(item, dict):
                new_item = deepcopy(item)
                if 'prompt' in item:
                    if src in item['prompt']:
                        if isinstance(dst, PromptList):
                            raise TypeError(
                                f'Found keyword {src} in a dictionary\'s '
                                'prompt key. Cannot replace with a '
                                'PromptList.')
                        new_item['prompt'] = new_item['prompt'].replace(
                            src, dst)
                new_list.append(new_item)
            else:
                new_list.append(item.replace(src, dst))
        return new_list

    def __add__(self, other: Union[str, PromptList]) -> PromptList:
        """Adds a string or another PromptList to this PromptList.

        Args:
            other (str or PromptList): The string or PromptList to be added.
                A string is appended as a single item. A falsy value adds
                nothing.

        Returns:
            PromptList: A new PromptList that is the result of the addition.
        """
        if not other:
            return PromptList([*self])
        if isinstance(other, str):
            return PromptList([*self, other])
        return PromptList(super().__add__(other))

    def __radd__(self, other: Union[str, PromptList]) -> PromptList:
        """Implements addition when the PromptList is on the right side of the
        '+' operator.

        Args:
            other (str or PromptList): The string or list to be added. A
                string is prepended as a single item. A falsy value adds
                nothing.

        Returns:
            PromptList: A new PromptList that is the result of the addition.
        """
        if not other:
            return PromptList([*self])
        if isinstance(other, str):
            return PromptList([other, *self])
        if isinstance(other, list):
            # Concatenate directly: evaluating `other + self` with a plain
            # list on the left dispatches back to this method and would
            # recurse without end.
            return PromptList([*other, *self])
        return NotImplemented

    def __iadd__(self, other: Union[str, PromptList]) -> PromptList:
        """Implements in-place addition for the PromptList.

        Args:
            other (str or PromptList): The string or PromptList to be added.
                A string is appended as a single item. A falsy value adds
                nothing.

        Returns:
            PromptList: The updated PromptList.
        """
        if not other:
            return self
        if isinstance(other, str):
            self.append(other)
        else:
            super().__iadd__(other)
        return self

    def __str__(self) -> str:
        """Converts the PromptList into a string.

        String items and the 'prompt' values of dictionary items are
        concatenated in order. Dictionaries without a 'prompt' key
        contribute nothing.

        Returns:
            str: The string representation of the PromptList.

        Raises:
            TypeError: If there's an item in the PromptList that is not a
                string or dictionary.
        """
        res = []
        for item in self:
            if isinstance(item, str):
                res.append(item)
            elif isinstance(item, dict):
                if 'prompt' in item:
                    res.append(item['prompt'])
            else:
                raise TypeError('Invalid type in prompt list when '
                                'converting to string')
        return ''.join(res)
|
|