File size: 5,981 Bytes
3dabe4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import unittest
import pathlib
import sys
sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
from lib_neutral_prompt import neutral_prompt_parser


class TestPromptParser(unittest.TestCase):
    """Unit tests for ``neutral_prompt_parser.parse_root``.

    ``setUp`` parses a fixed set of prompt strings. Each test inspects the
    ``children`` of a parsed root and the ``weight`` / ``prompt`` attributes
    of the nodes it holds.

    The AND_PERP and AND_SALT fixtures are built from structurally identical
    prompts. Their checks therefore go through the shared ``_check_*`` helpers
    so both keywords are always verified in exactly the same way.
    """

    def setUp(self):
        self.simple_prompt = neutral_prompt_parser.parse_root("hello :1.0")
        self.and_prompt = neutral_prompt_parser.parse_root("hello AND goodbye :2.0")
        self.and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP goodbye :2.0")
        self.and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT goodbye :2.0")
        self.nested_and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0 AND_PERP welcome :3.0]")
        self.nested_and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT [goodbye :2.0 AND_SALT welcome :3.0]")
        self.invalid_weight = neutral_prompt_parser.parse_root("hello :not_a_float")

    # -- shared assertion helpers ------------------------------------------

    def _assert_leaves(self, *nodes):
        """Assert that every node in *nodes* is a ``LeafPrompt``."""
        for node in nodes:
            self.assertIsInstance(node, neutral_prompt_parser.LeafPrompt)

    def _assert_leaf_value(self, node, weight, text):
        """Assert *node* has *weight* and a prompt equal to *text* once stripped.

        Surrounding whitespace is ignored because the parser keeps the spaces
        that separate a prompt from its keyword/weight (the AND tests below
        pin ``"hello "`` and ``" goodbye "`` exactly).
        """
        self.assertEqual(node.weight, weight)
        self.assertEqual(node.prompt.strip(), text)

    def _check_flat_pair_types(self, parsed):
        """Both children of ``"hello :1.0 <KW> goodbye :2.0"`` are leaves."""
        self._assert_leaves(parsed.children[0], parsed.children[1])

    def _check_flat_pair_values(self, parsed):
        """Check weight and text of both leaves of the flat two-prompt fixture."""
        self._assert_leaf_value(parsed.children[0], 1.0, "hello")
        self._assert_leaf_value(parsed.children[1], 2.0, "goodbye")

    def _check_nested_outer_types(self, parsed):
        """The nested fixture's root holds a leaf followed by a composite."""
        self.assertIsInstance(parsed.children[0], neutral_prompt_parser.LeafPrompt)
        self.assertIsInstance(parsed.children[1], neutral_prompt_parser.CompositePrompt)

    def _check_nested_inner_types(self, parsed):
        """Both children of the bracketed sub-prompt are leaves."""
        inner = parsed.children[1]
        self._assert_leaves(inner.children[0], inner.children[1])

    def _check_nested_inner_values(self, parsed):
        """Check weight and text of both leaves inside the bracketed sub-prompt."""
        inner = parsed.children[1]
        self._assert_leaf_value(inner.children[0], 2.0, "goodbye")
        self._assert_leaf_value(inner.children[1], 3.0, "welcome")

    # -- single prompt -----------------------------------------------------

    def test_simple_prompt_child_count(self):
        self.assertEqual(len(self.simple_prompt.children), 1)

    def test_simple_prompt_child_weight(self):
        self.assertEqual(self.simple_prompt.children[0].weight, 1.0)

    def test_simple_prompt_child_prompt(self):
        self.assertEqual(self.simple_prompt.children[0].prompt, "hello ")

    def test_square_weight_prompt(self):
        """A ``:<float>`` inside square brackets is prompt text, not a weight."""
        prompt = "a [b c d e : f g h :1.5]"
        parsed = neutral_prompt_parser.parse_root(prompt)
        self.assertEqual(parsed.children[0].prompt, prompt)

        # The same holds when the bracketed prompt is followed by a keyword.
        composed_prompt = f"{prompt} AND_PERP other prompt"
        parsed = neutral_prompt_parser.parse_root(composed_prompt)
        self.assertEqual(parsed.children[0].prompt, prompt)

    # -- AND ---------------------------------------------------------------

    def test_and_prompt_child_count(self):
        self.assertEqual(len(self.and_prompt.children), 2)

    def test_and_prompt_child_weights_and_prompts(self):
        """Unweighted prompts default to 1.0; surrounding spaces are kept."""
        self.assertEqual(self.and_prompt.children[0].weight, 1.0)
        self.assertEqual(self.and_prompt.children[0].prompt, "hello ")
        self.assertEqual(self.and_prompt.children[1].weight, 2.0)
        self.assertEqual(self.and_prompt.children[1].prompt, " goodbye ")

    # -- AND_PERP ----------------------------------------------------------

    def test_and_perp_prompt_child_count(self):
        self.assertEqual(len(self.and_perp_prompt.children), 2)

    def test_and_perp_prompt_child_types(self):
        self._check_flat_pair_types(self.and_perp_prompt)

    def test_and_perp_prompt_nested_child(self):
        """Checks the flat (non-nested) pair: weights and texts of both leaves."""
        self._check_flat_pair_values(self.and_perp_prompt)

    def test_nested_and_perp_prompt_child_count(self):
        self.assertEqual(len(self.nested_and_perp_prompt.children), 2)

    def test_nested_and_perp_prompt_child_types(self):
        self._check_nested_outer_types(self.nested_and_perp_prompt)

    def test_nested_and_perp_prompt_nested_child_types(self):
        self._check_nested_inner_types(self.nested_and_perp_prompt)

    def test_nested_and_perp_prompt_nested_child(self):
        self._check_nested_inner_values(self.nested_and_perp_prompt)

    # -- invalid weight ----------------------------------------------------

    def test_invalid_weight_child_count(self):
        self.assertEqual(len(self.invalid_weight.children), 1)

    def test_invalid_weight_child_weight(self):
        self.assertEqual(self.invalid_weight.children[0].weight, 1.0)

    def test_invalid_weight_child_prompt(self):
        """A non-float after ``:`` stays part of the prompt text verbatim."""
        self.assertEqual(self.invalid_weight.children[0].prompt, "hello :not_a_float")

    # -- AND_SALT ----------------------------------------------------------

    def test_and_salt_prompt_child_count(self):
        self.assertEqual(len(self.and_salt_prompt.children), 2)

    def test_and_salt_prompt_child_types(self):
        self._check_flat_pair_types(self.and_salt_prompt)

    def test_and_salt_prompt_nested_child(self):
        """Checks the flat (non-nested) pair: weights and texts of both leaves."""
        self._check_flat_pair_values(self.and_salt_prompt)

    def test_nested_and_salt_prompt_child_count(self):
        self.assertEqual(len(self.nested_and_salt_prompt.children), 2)

    def test_nested_and_salt_prompt_child_types(self):
        self._check_nested_outer_types(self.nested_and_salt_prompt)

    def test_nested_and_salt_prompt_nested_child_types(self):
        self._check_nested_inner_types(self.nested_and_salt_prompt)

    def test_nested_and_salt_prompt_nested_child(self):
        self._check_nested_inner_values(self.nested_and_salt_prompt)

    # -- prompt editing ----------------------------------------------------

    def test_start_with_prompt_editing(self):
        """A prompt starting with ``[...:0.1]`` editing syntax is kept whole at weight 1.0."""
        prompt = "[(long shot:1.2):0.1] detail.."
        res = neutral_prompt_parser.parse_root(prompt)
        self.assertEqual(res.children[0].weight, 1.0)
        self.assertEqual(res.children[0].prompt, prompt)


if __name__ == "__main__":
    # Executed as a script: let unittest discover and run this module's tests.
    unittest.main()