| import unittest | |
| from opencompass.models.base_api import APITemplateParser | |
| from opencompass.utils.prompt import PromptList | |
| class TestAPITemplateParser(unittest.TestCase): | |
| def setUp(self): | |
| self.parser = APITemplateParser() | |
| self.prompt = PromptList([ | |
| { | |
| 'section': 'begin', | |
| 'pos': 'begin' | |
| }, | |
| 'begin', | |
| { | |
| 'role': 'SYSTEM', | |
| 'fallback_role': 'HUMAN', | |
| 'prompt': 'system msg' | |
| }, | |
| { | |
| 'section': 'ice', | |
| 'pos': 'begin' | |
| }, | |
| { | |
| 'role': 'HUMAN', | |
| 'prompt': 'U0' | |
| }, | |
| { | |
| 'role': 'BOT', | |
| 'prompt': 'B0' | |
| }, | |
| { | |
| 'section': 'ice', | |
| 'pos': 'end' | |
| }, | |
| { | |
| 'section': 'begin', | |
| 'pos': 'end' | |
| }, | |
| { | |
| 'section': 'round', | |
| 'pos': 'begin' | |
| }, | |
| { | |
| 'role': 'HUMAN', | |
| 'prompt': 'U1' | |
| }, | |
| { | |
| 'role': 'BOT', | |
| 'prompt': 'B1' | |
| }, | |
| { | |
| 'role': 'HUMAN', | |
| 'prompt': 'U2' | |
| }, | |
| { | |
| 'role': 'BOT', | |
| 'prompt': 'B2' | |
| }, | |
| { | |
| 'section': 'round', | |
| 'pos': 'end' | |
| }, | |
| { | |
| 'section': 'end', | |
| 'pos': 'begin' | |
| }, | |
| 'end', | |
| { | |
| 'section': 'end', | |
| 'pos': 'end' | |
| }, | |
| ]) | |
| def test_parse_template_str_input(self): | |
| prompt = self.parser.parse_template('Hello, world!', mode='gen') | |
| self.assertEqual(prompt, 'Hello, world!') | |
| prompt = self.parser.parse_template('Hello, world!', mode='ppl') | |
| self.assertEqual(prompt, 'Hello, world!') | |
| def test_parse_template_list_input(self): | |
| prompt = self.parser.parse_template(['Hello', 'world'], mode='gen') | |
| self.assertEqual(prompt, ['Hello', 'world']) | |
| prompt = self.parser.parse_template(['Hello', 'world'], mode='ppl') | |
| self.assertEqual(prompt, ['Hello', 'world']) | |
| def test_parse_template_PromptList_input_no_meta_template(self): | |
| prompt = self.parser.parse_template(self.prompt, mode='gen') | |
| self.assertEqual(prompt, | |
| 'begin\nsystem msg\nU0\nB0\nU1\nB1\nU2\nB2\nend') | |
| prompt = self.parser.parse_template(self.prompt, mode='ppl') | |
| self.assertEqual(prompt, | |
| 'begin\nsystem msg\nU0\nB0\nU1\nB1\nU2\nB2\nend') | |
| def test_parse_template_PromptList_input_with_meta_template(self): | |
| parser = APITemplateParser(meta_template=dict(round=[ | |
| dict(role='HUMAN', api_role='HUMAN'), | |
| dict(role='BOT', api_role='BOT', generate=True) | |
| ], )) | |
| with self.assertWarns(Warning): | |
| prompt = parser.parse_template(self.prompt, mode='gen') | |
| self.assertEqual( | |
| prompt, | |
| PromptList([ | |
| { | |
| 'role': 'HUMAN', | |
| 'prompt': 'system msg\nU0' | |
| }, | |
| { | |
| 'role': 'BOT', | |
| 'prompt': 'B0' | |
| }, | |
| { | |
| 'role': 'HUMAN', | |
| 'prompt': 'U1' | |
| }, | |
| { | |
| 'role': 'BOT', | |
| 'prompt': 'B1' | |
| }, | |
| { | |
| 'role': 'HUMAN', | |
| 'prompt': 'U2' | |
| }, | |
| ])) | |
| with self.assertWarns(Warning): | |
| prompt = parser.parse_template(self.prompt, mode='ppl') | |
| self.assertEqual( | |
| prompt, | |
| PromptList([ | |
| { | |
| 'role': 'HUMAN', | |
| 'prompt': 'system msg\nU0' | |
| }, | |
| { | |
| 'role': 'BOT', | |
| 'prompt': 'B0' | |
| }, | |
| { | |
| 'role': 'HUMAN', | |
| 'prompt': 'U1' | |
| }, | |
| { | |
| 'role': 'BOT', | |
| 'prompt': 'B1' | |
| }, | |
| { | |
| 'role': 'HUMAN', | |
| 'prompt': 'U2' | |
| }, | |
| { | |
| 'role': 'BOT', | |
| 'prompt': 'B2' | |
| }, | |
| ])) | |
| parser = APITemplateParser(meta_template=dict( | |
| round=[ | |
| dict(role='HUMAN', api_role='HUMAN'), | |
| dict(role='BOT', api_role='BOT', generate=True) | |
| ], | |
| reserved_roles=[ | |
| dict(role='SYSTEM', api_role='SYSTEM'), | |
| ], | |
| )) | |
| with self.assertWarns(Warning): | |
| prompt = parser.parse_template(self.prompt, mode='gen') | |
| self.assertEqual( | |
| prompt, | |
| PromptList([ | |
| { | |
| 'role': 'SYSTEM', | |
| 'prompt': 'system msg' | |
| }, | |
| { | |
| 'role': 'HUMAN', | |
| 'prompt': 'U0' | |
| }, | |
| { | |
| 'role': 'BOT', | |
| 'prompt': 'B0' | |
| }, | |
| { | |
| 'role': 'HUMAN', | |
| 'prompt': 'U1' | |
| }, | |
| { | |
| 'role': 'BOT', | |
| 'prompt': 'B1' | |
| }, | |
| { | |
| 'role': 'HUMAN', | |
| 'prompt': 'U2' | |
| }, | |
| ])) | |
| with self.assertWarns(Warning): | |
| prompt = parser.parse_template(self.prompt, mode='ppl') | |
| self.assertEqual( | |
| prompt, | |
| PromptList([ | |
| { | |
| 'role': 'SYSTEM', | |
| 'prompt': 'system msg' | |
| }, | |
| { | |
| 'role': 'HUMAN', | |
| 'prompt': 'U0' | |
| }, | |
| { | |
| 'role': 'BOT', | |
| 'prompt': 'B0' | |
| }, | |
| { | |
| 'role': 'HUMAN', | |
| 'prompt': 'U1' | |
| }, | |
| { | |
| 'role': 'BOT', | |
| 'prompt': 'B1' | |
| }, | |
| { | |
| 'role': 'HUMAN', | |
| 'prompt': 'U2' | |
| }, | |
| { | |
| 'role': 'BOT', | |
| 'prompt': 'B2' | |
| }, | |
| ])) | |