| |
| from functools import partial |
|
|
| from mmengine.utils.misc import get_object_from_string |
|
|
|
|
| def template_map_fn(example, template): |
| conversation = example.get('conversation', []) |
| for i, single_turn_conversation in enumerate(conversation): |
| input = single_turn_conversation.get('input', '') |
| if input is None: |
| input = '' |
| input_text = template.INSTRUCTION.format(input=input, round=i + 1) |
| system = single_turn_conversation.get('system', '') |
| if system != '' and system is not None: |
| system = template.SYSTEM.format(system=system) |
| input_text = system + input_text |
| single_turn_conversation['input'] = input_text |
|
|
| if template.get('SUFFIX', None): |
| output_text = single_turn_conversation.get('output', '') |
| output_text += template.SUFFIX |
| single_turn_conversation['output'] = output_text |
|
|
| |
| single_turn_conversation['need_eos_token'] = \ |
| not template.get('SUFFIX_AS_EOS', False) |
| single_turn_conversation['sep'] = template.get('SEP', '') |
|
|
| return {'conversation': conversation} |
|
|
|
|
| def template_map_fn_factory(template): |
| if isinstance(template, str): |
| template = get_object_from_string(template) |
| return partial(template_map_fn, template=template) |
|
|