import pathlib
import sys
import unittest

# Put the directory two levels above this file's own directory on sys.path
# (presumably the repository root) so that `lib_neutral_prompt` is importable.
sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))

from lib_neutral_prompt import neutral_prompt_parser  # noqa: E402 -- needs sys.path above
class TestPromptParser(unittest.TestCase):
    """Tests for ``neutral_prompt_parser.parse_root``.

    ``parse_root`` returns a root node whose ``children`` are either
    ``LeafPrompt`` nodes (carrying ``weight`` and ``prompt`` text) or
    ``CompositePrompt`` nodes (carrying their own ``children``), produced by
    splitting the input on ``AND`` / ``AND_PERP`` / ``AND_SALT`` and on
    bracketed ``[...]`` sub-expressions.
    """

    def setUp(self):
        # setUp runs before every test, so each test gets freshly parsed
        # trees and cannot observe another test's mutations.
        parse = neutral_prompt_parser.parse_root
        self.simple_prompt = parse("hello :1.0")
        self.and_prompt = parse("hello AND goodbye :2.0")
        self.and_perp_prompt = parse("hello :1.0 AND_PERP goodbye :2.0")
        self.and_salt_prompt = parse("hello :1.0 AND_SALT goodbye :2.0")
        self.nested_and_perp_prompt = parse("hello :1.0 AND_PERP [goodbye :2.0 AND_PERP welcome :3.0]")
        self.nested_and_salt_prompt = parse("hello :1.0 AND_SALT [goodbye :2.0 AND_SALT welcome :3.0]")
        self.invalid_weight = parse("hello :not_a_float")

    def _assert_leaf(self, node, weight, text):
        """Assert *node* is a LeafPrompt with *weight* and *text* (whitespace-insensitive)."""
        self.assertIsInstance(node, neutral_prompt_parser.LeafPrompt)
        self.assertEqual(node.weight, weight)
        self.assertEqual(node.prompt.strip(), text)

    # --- single prompt -----------------------------------------------------

    def test_simple_prompt_child_count(self):
        self.assertEqual(len(self.simple_prompt.children), 1)

    def test_simple_prompt_child_weight(self):
        self.assertEqual(self.simple_prompt.children[0].weight, 1.0)

    def test_simple_prompt_child_prompt(self):
        # The ":1.0" weight suffix is removed; the space before it is kept.
        self.assertEqual(self.simple_prompt.children[0].prompt, "hello ")

    def test_square_weight_prompt(self):
        """A ``:weight`` inside square brackets is not the prompt's weight.

        The text is kept verbatim, both on its own and as the first operand
        of an ``AND_PERP``.
        """
        prompt = "a [b c d e : f g h :1.5]"
        parsed = neutral_prompt_parser.parse_root(prompt)
        self.assertEqual(parsed.children[0].prompt, prompt)

        composed_prompt = f"{prompt} AND_PERP other prompt"
        parsed = neutral_prompt_parser.parse_root(composed_prompt)
        self.assertEqual(parsed.children[0].prompt, prompt)

    # --- AND ---------------------------------------------------------------

    def test_and_prompt_child_count(self):
        self.assertEqual(len(self.and_prompt.children), 2)

    def test_and_prompt_child_weights_and_prompts(self):
        # A child without an explicit weight gets 1.0; whitespace around the
        # AND keyword stays in each child's text.
        self.assertEqual(self.and_prompt.children[0].weight, 1.0)
        self.assertEqual(self.and_prompt.children[0].prompt, "hello ")
        self.assertEqual(self.and_prompt.children[1].weight, 2.0)
        self.assertEqual(self.and_prompt.children[1].prompt, " goodbye ")

    # --- AND_PERP ----------------------------------------------------------

    def test_and_perp_prompt_child_count(self):
        self.assertEqual(len(self.and_perp_prompt.children), 2)

    def test_and_perp_prompt_child_types(self):
        self.assertIsInstance(self.and_perp_prompt.children[0], neutral_prompt_parser.LeafPrompt)
        self.assertIsInstance(self.and_perp_prompt.children[1], neutral_prompt_parser.LeafPrompt)

    def test_and_perp_prompt_nested_child(self):
        # Despite the name, this checks the second top-level child (the
        # right-hand operand of AND_PERP); nothing is nested here.
        self._assert_leaf(self.and_perp_prompt.children[1], 2.0, "goodbye")

    def test_nested_and_perp_prompt_child_count(self):
        self.assertEqual(len(self.nested_and_perp_prompt.children), 2)

    def test_nested_and_perp_prompt_child_types(self):
        self.assertIsInstance(self.nested_and_perp_prompt.children[0], neutral_prompt_parser.LeafPrompt)
        self.assertIsInstance(self.nested_and_perp_prompt.children[1], neutral_prompt_parser.CompositePrompt)

    def test_nested_and_perp_prompt_nested_child_types(self):
        nested = self.nested_and_perp_prompt.children[1]
        # "[goodbye :2.0 AND_PERP welcome :3.0]" holds exactly two operands.
        self.assertEqual(len(nested.children), 2)
        self.assertIsInstance(nested.children[0], neutral_prompt_parser.LeafPrompt)
        self.assertIsInstance(nested.children[1], neutral_prompt_parser.LeafPrompt)

    def test_nested_and_perp_prompt_nested_child(self):
        # Check both operands of the bracketed group, not just the last one.
        nested = self.nested_and_perp_prompt.children[1]
        self._assert_leaf(nested.children[0], 2.0, "goodbye")
        self._assert_leaf(nested.children[1], 3.0, "welcome")

    # --- invalid weight ----------------------------------------------------

    def test_invalid_weight_child_count(self):
        self.assertEqual(len(self.invalid_weight.children), 1)

    def test_invalid_weight_child_weight(self):
        # A non-numeric weight falls back to the default weight of 1.0 ...
        self.assertEqual(self.invalid_weight.children[0].weight, 1.0)

    def test_invalid_weight_child_prompt(self):
        # ... and the ":not_a_float" suffix is kept as literal prompt text.
        self.assertEqual(self.invalid_weight.children[0].prompt, "hello :not_a_float")

    # --- AND_SALT ----------------------------------------------------------

    def test_and_salt_prompt_child_count(self):
        self.assertEqual(len(self.and_salt_prompt.children), 2)

    def test_and_salt_prompt_child_types(self):
        self.assertIsInstance(self.and_salt_prompt.children[0], neutral_prompt_parser.LeafPrompt)
        self.assertIsInstance(self.and_salt_prompt.children[1], neutral_prompt_parser.LeafPrompt)

    def test_and_salt_prompt_nested_child(self):
        # Despite the name, this checks the second top-level child (the
        # right-hand operand of AND_SALT); nothing is nested here.
        self._assert_leaf(self.and_salt_prompt.children[1], 2.0, "goodbye")

    def test_nested_and_salt_prompt_child_count(self):
        self.assertEqual(len(self.nested_and_salt_prompt.children), 2)

    def test_nested_and_salt_prompt_child_types(self):
        self.assertIsInstance(self.nested_and_salt_prompt.children[0], neutral_prompt_parser.LeafPrompt)
        self.assertIsInstance(self.nested_and_salt_prompt.children[1], neutral_prompt_parser.CompositePrompt)

    def test_nested_and_salt_prompt_nested_child_types(self):
        nested = self.nested_and_salt_prompt.children[1]
        # "[goodbye :2.0 AND_SALT welcome :3.0]" holds exactly two operands.
        self.assertEqual(len(nested.children), 2)
        self.assertIsInstance(nested.children[0], neutral_prompt_parser.LeafPrompt)
        self.assertIsInstance(nested.children[1], neutral_prompt_parser.LeafPrompt)

    def test_nested_and_salt_prompt_nested_child(self):
        # Check both operands of the bracketed group, not just the last one.
        nested = self.nested_and_salt_prompt.children[1]
        self._assert_leaf(nested.children[0], 2.0, "goodbye")
        self._assert_leaf(nested.children[1], 3.0, "welcome")

    # --- prompt editing ----------------------------------------------------

    def test_start_with_prompt_editing(self):
        """A prompt that opens with a bracketed ``[...:0.1]`` group stays one leaf.

        The group keeps the default weight 1.0 and its text verbatim.
        """
        prompt = "[(long shot:1.2):0.1] detail.."
        res = neutral_prompt_parser.parse_root(prompt)
        self.assertEqual(res.children[0].weight, 1.0)
        self.assertEqual(res.children[0].prompt, prompt)
if __name__ == '__main__':
    # Running this file directly (rather than via a test runner) executes
    # every TestCase defined in this module.
    unittest.main()