| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from typing import Callable, Type |
|
|
| import torch |
|
|
|
|
| PromptFormatFnReturnType = dict[str, list[torch.Tensor]] |
| PromptFormatSignature = Callable[[object, object], PromptFormatFnReturnType] |
| PROMPT_FORMAT_FNS: dict[tuple[Type, Type] | Type, PromptFormatSignature] = {} |
|
|
|
|
def registered_prompt_format_fn(example_type: Type, formatter_type: Type | None = None):
    """
    Decorator for registering text prompt functions.

    It allows selecting the right prompt formatting function based on the types of the
    example and the prompt formatter, so that different types of data can be formatted
    with different strategies for different prompt formats.

    When ``formatter_type`` is ``None``, registers a default prompt format function for
    the given data type. Registering again under the same key replaces the previously
    registered function.

    Args:
        example_type: Type of the data example the decorated function formats.
        formatter_type: Type of the prompt formatter, or ``None`` to register a default
            for ``example_type``.

    Returns:
        A decorator that records the decorated function in ``PROMPT_FORMAT_FNS`` and
        returns the function unchanged.

    Example::

        >>> @registered_prompt_format_fn(SourceTargetTextExample, Llama2PromptFormatter)
        ... def my_src_tgt_text_prompt(example, formatter):
        ...     pass

        >>> @registered_prompt_format_fn(Cut, Llama2PromptFormatter)
        ... def my_audio_prompt(example, formatter):
        ...     pass

        >>> prompt_fn = get_prompt_format_fn(SourceTargetTextExample, Llama2PromptFormatter)
    """

    def _decorator(prompt_fn: PromptFormatSignature) -> PromptFormatSignature:
        # A bare example type is the per-data-type default; otherwise key on the pair.
        key = example_type if formatter_type is None else (example_type, formatter_type)
        # The registry is only mutated (never rebound), so no ``global`` statement is needed.
        PROMPT_FORMAT_FNS[key] = prompt_fn
        return prompt_fn

    return _decorator
|
|
|
|
def get_prompt_format_fn(example: Type | object, prompt: Type | object = None) -> PromptFormatSignature:
    """
    Resolve the prompt format function registered via ``registered_prompt_format_fn``.

    ``example`` and ``prompt`` may be given either as types or as instances; instances
    are replaced by their types before the lookup.

    Resolution walks the MRO of the example type, most specific class first. For each
    example (sub)type it first tries every ``(example_subtype, prompt_subtype)`` pair along
    the prompt type's MRO, and then falls back to the default registered for
    ``example_subtype`` alone. The example type therefore takes precedence over the prompt
    formatter type: a default registered for a subclass wins over a pair registered for
    one of its base classes.

    When ``prompt`` is ``None`` (the default) its type is ``NoneType``, so only pairs
    registered for ``NoneType``/``object`` or per-example defaults can match.

    Args:
        example: Data example, or its type.
        prompt: Prompt formatter, or its type.

    Returns:
        The registered prompt format function.

    Raises:
        ValueError: If no registered function matches.
    """
    # Callers may pass instances; everything below works on types.
    if not isinstance(example, type):
        example = type(example)
    if not isinstance(prompt, type):
        prompt = type(prompt)

    # Use ``__mro__`` rather than ``.mro()``: when the class is ``type`` itself,
    # ``type.mro()`` is an unbound descriptor call and raises TypeError.
    for example_subtype in example.__mro__:
        # Most specific: this example type paired with the prompt type or any of its bases.
        for prompt_subtype in prompt.__mro__:
            if (example_subtype, prompt_subtype) in PROMPT_FORMAT_FNS:
                return PROMPT_FORMAT_FNS[(example_subtype, prompt_subtype)]

        # Then the default registered for this example type alone, before moving on
        # to the example type's base classes.
        if example_subtype in PROMPT_FORMAT_FNS:
            return PROMPT_FORMAT_FNS[example_subtype]

    raise ValueError(
        f"Unknown prompt format function for ({example}, {prompt}). "
        f"Available choices are: {list(PROMPT_FORMAT_FNS.keys())}"
    )
|
|
|
|
def apply_prompt_format_fn(example: object | Type, prompt: object | Type) -> PromptFormatFnReturnType:
    """
    Resolve the prompt format function for ``example`` and ``prompt`` and apply it in one go.

    Resolution is delegated to ``get_prompt_format_fn``, which accepts types or instances.
    The resolved function is then called as ``fn(example, prompt)`` with the arguments
    exactly as they were passed in.

    Args:
        example: Data example (or its type) to format.
        prompt: Prompt formatter (or its type) to format it with.

    Returns:
        Whatever the resolved prompt format function returns.

    Raises:
        ValueError: Propagated from ``get_prompt_format_fn`` when no function is registered
            for the given combination.
    """
    prompt_format_fn = get_prompt_format_fn(example, prompt)
    return prompt_format_fn(example, prompt)
|
|