| | import unittest |
| |
|
| | from opencompass.openicl.icl_prompt_template import PromptTemplate |
| | from opencompass.utils.prompt import PromptList |
| |
|
| |
|
class TestPromptTemplate(unittest.TestCase):
    """Tests for ``PromptTemplate``.

    Covers three kinds of template:

    * a plain format string;
    * a meta template with a ``begin`` section holding the ice token and a
      single question/answer ``round``;
    * a meta template whose ``round`` is a multi-turn dialogue.
    """

    def setUp(self) -> None:
        """Build the meta templates and the data entry shared by all tests."""
        self.qa_template = dict(
            begin=[
                dict(role='SYSTEM', fallback_role='HUMAN', prompt='instruct'),
                '</E>',
            ],
            round=[
                dict(role='HUMAN', prompt='{input}'),
                dict(role='BOT', prompt='Answer: {answer}'),
            ],
        )
        self.multiround_qa_template = dict(round=[
            dict(role='HUMAN', prompt='{input}'),
            dict(role='BOT', prompt='A1', end='\n'),
            dict(role='HUMAN', prompt='Q1'),
            dict(role='BOT', prompt='A2', end='\n\n'),
            dict(role='HUMAN', prompt='Q2', begin='HUMAN:'),
            dict(role='BOT', prompt='Answer: {answer}'),
        ])
        self.entry = {'input': 'Hello, how are you?', 'answer': 'Good.'}

    # ------------------------------------------------------------------
    # Expected-value builders. Each returns a brand-new object so that no
    # expected value can be affected by what the code under test does to
    # the objects it receives.
    # ------------------------------------------------------------------

    @staticmethod
    def _marked(section, *items):
        """Return ``items`` enclosed by ``section``'s begin/end markers."""
        return [
            {'section': section, 'pos': 'begin'},
            *items,
            {'section': section, 'pos': 'end'},
        ]

    @staticmethod
    def _qa_round():
        """Turns of ``qa_template``'s round once ``entry`` is filled in."""
        return [
            dict(role='HUMAN', prompt='Hello, how are you?'),
            dict(role='BOT', prompt='Answer: Good.'),
        ]

    @staticmethod
    def _multiround_rounds():
        """Turns of ``multiround_qa_template`` once ``entry`` is filled in."""
        return [
            dict(role='HUMAN', prompt='Hello, how are you?'),
            dict(role='BOT', prompt='A1', end='\n'),
            dict(role='HUMAN', prompt='Q1'),
            dict(role='BOT', prompt='A2', end='\n\n'),
            dict(role='HUMAN', prompt='Q2', begin='HUMAN:'),
            dict(role='BOT', prompt='Answer: Good.'),
        ]

    @classmethod
    def _ice_prompt(cls):
        """In-context example ``PromptList`` with a single h1/b1 exchange."""
        return PromptList(
            cls._marked('ice', dict(role='HUMAN', prompt='h1'),
                        dict(role='BOT', prompt='b1')))

    @classmethod
    def _qa_prompt_with_ice(cls):
        """Expected ``qa_template`` prompt with ``_ice_prompt`` in ``begin``.

        The ice block takes the place of the ``'</E>'`` token. It therefore
        sits inside the ``begin`` section, right after the SYSTEM
        instruction, and the filled round follows.
        """
        begin = cls._marked(
            'begin',
            dict(role='SYSTEM', fallback_role='HUMAN', prompt='instruct'),
            *cls._ice_prompt())
        return PromptList(begin + cls._marked('round', *cls._qa_round()))

    @classmethod
    def _multiround_prompt(cls):
        """Expected prompt for ``multiround_qa_template`` (round only)."""
        return PromptList(cls._marked('round', *cls._multiround_rounds()))

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_init(self):
        source = 'Translate the following English text to French: {input}.'
        self.assertEqual(PromptTemplate(source).template, source)

    def test_generate_ice_item(self):
        # Plain string template: the entry is formatted straight in.
        plain = PromptTemplate(
            'Translate the following English text to French: {input}.')
        self.assertEqual(
            plain.generate_ice_item(self.entry, None),
            'Translate the following English text to French: '
            'Hello, how are you?.')

        # Meta template: only the filled round is expected, wrapped in
        # ``ice`` markers. The ``begin`` section does not appear.
        meta = PromptTemplate(self.qa_template, ice_token='</E>')
        self.assertEqual(
            meta.generate_ice_item(self.entry, None),
            PromptList(self._marked('ice', *self._qa_round())))

        # Multi-round template: every turn is kept, with its per-turn
        # begin/end strings intact.
        multiround = PromptTemplate(self.multiround_qa_template,
                                    ice_token='</E>')
        self.assertEqual(
            multiround.generate_ice_item(self.entry, None),
            PromptList(self._marked('ice', *self._multiround_rounds())))

    def test_generate_label_prompt_item(self):
        # Plain string template: the ice token is replaced by the ice text.
        plain = PromptTemplate(
            '</E> Translate the following English text to French: {input}.',
            ice_token='</E>')
        self.assertEqual(
            plain.generate_label_prompt_item(self.entry, 'ICE', None),
            'ICE Translate the following English text to French: '
            'Hello, how are you?.')

        # One ice object is deliberately shared by both calls below.
        ice = self._ice_prompt()

        meta = PromptTemplate(self.qa_template, ice_token='</E>')
        self.assertEqual(
            meta.generate_label_prompt_item(self.entry, ice, None),
            self._qa_prompt_with_ice())

        # This template has no ice token, so the ice is expected to be
        # absent from the output.
        multiround = PromptTemplate(self.multiround_qa_template,
                                    ice_token='</E>')
        self.assertEqual(
            multiround.generate_label_prompt_item(self.entry, ice, None),
            self._multiround_prompt())

    def test_generate_item(self):
        plain = PromptTemplate(
            'Translate the following English text to French: {input}.')
        self.assertEqual(
            plain.generate_item(self.entry),
            'Translate the following English text to French: '
            'Hello, how are you?.')

        # One ice object is deliberately shared by both calls below.
        ice = self._ice_prompt()

        meta = PromptTemplate(self.qa_template, ice_token='</E>')
        self.assertEqual(
            meta.generate_item(self.entry, ice_field_replace_token=ice),
            self._qa_prompt_with_ice())

        multiround = PromptTemplate(self.multiround_qa_template,
                                    ice_token='</E>')
        self.assertEqual(
            multiround.generate_item(self.entry, ice_field_replace_token=ice),
            self._multiround_prompt())
|