File size: 5,981 Bytes
3dabe4a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import unittest
import pathlib
import sys
sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
from lib_neutral_prompt import neutral_prompt_parser
class TestPromptParser(unittest.TestCase):
    """Exercise ``neutral_prompt_parser.parse_root`` on a fixed set of prompts.

    Each test inspects the ``children`` of the parsed root and checks their
    count, their type (``LeafPrompt`` or ``CompositePrompt``), and the
    ``weight`` / ``prompt`` attributes of individual children.
    """

    def setUp(self):
        # Every fixture is re-parsed before each test, in this order.
        parse = neutral_prompt_parser.parse_root

        self.simple_prompt = parse("hello :1.0")
        self.and_prompt = parse("hello AND goodbye :2.0")
        self.and_perp_prompt = parse("hello :1.0 AND_PERP goodbye :2.0")
        self.and_salt_prompt = parse("hello :1.0 AND_SALT goodbye :2.0")
        self.nested_and_perp_prompt = parse(
            "hello :1.0 AND_PERP [goodbye :2.0 AND_PERP welcome :3.0]"
        )
        self.nested_and_salt_prompt = parse(
            "hello :1.0 AND_SALT [goodbye :2.0 AND_SALT welcome :3.0]"
        )
        self.invalid_weight = parse("hello :not_a_float")

    # --- single prompt -----------------------------------------------------

    def test_simple_prompt_child_count(self):
        # "hello :1.0" is expected to yield exactly one child.
        children = self.simple_prompt.children
        self.assertEqual(len(children), 1)

    def test_simple_prompt_child_weight(self):
        only_child = self.simple_prompt.children[0]
        self.assertEqual(only_child.weight, 1.0)

    def test_simple_prompt_child_prompt(self):
        # The trailing space before the ":1.0" suffix is expected to be kept.
        only_child = self.simple_prompt.children[0]
        self.assertEqual(only_child.prompt, "hello ")

    def test_square_weight_prompt(self):
        # The bracketed text must come back verbatim as the first child's
        # prompt, both on its own and when followed by an AND_PERP segment.
        prompt = "a [b c d e : f g h :1.5]"
        root = neutral_prompt_parser.parse_root(prompt)
        self.assertEqual(root.children[0].prompt, prompt)

        composed_prompt = f"{prompt} AND_PERP other prompt"
        composed_root = neutral_prompt_parser.parse_root(composed_prompt)
        self.assertEqual(composed_root.children[0].prompt, prompt)

    # --- plain AND ---------------------------------------------------------

    def test_and_prompt_child_count(self):
        children = self.and_prompt.children
        self.assertEqual(len(children), 2)

    def test_and_prompt_child_weights_and_prompts(self):
        # The first segment carries no explicit weight and is expected to
        # report 1.0; surrounding whitespace is expected to be preserved.
        left = self.and_prompt.children[0]
        self.assertEqual(left.weight, 1.0)
        self.assertEqual(left.prompt, "hello ")

        right = self.and_prompt.children[1]
        self.assertEqual(right.weight, 2.0)
        self.assertEqual(right.prompt, " goodbye ")

    # --- AND_PERP ----------------------------------------------------------

    def test_and_perp_prompt_child_count(self):
        children = self.and_perp_prompt.children
        self.assertEqual(len(children), 2)

    def test_and_perp_prompt_child_types(self):
        # Without brackets both sides are expected to be leaves.
        leaf_type = neutral_prompt_parser.LeafPrompt
        self.assertIsInstance(self.and_perp_prompt.children[0], leaf_type)
        self.assertIsInstance(self.and_perp_prompt.children[1], leaf_type)

    def test_and_perp_prompt_nested_child(self):
        second = self.and_perp_prompt.children[1]
        self.assertEqual(second.weight, 2.0)
        self.assertEqual(second.prompt.strip(), "goodbye")

    def test_nested_and_perp_prompt_child_count(self):
        children = self.nested_and_perp_prompt.children
        self.assertEqual(len(children), 2)

    def test_nested_and_perp_prompt_child_types(self):
        # The bracketed right-hand side is expected to be a composite node.
        root = self.nested_and_perp_prompt
        self.assertIsInstance(root.children[0], neutral_prompt_parser.LeafPrompt)
        self.assertIsInstance(root.children[1], neutral_prompt_parser.CompositePrompt)

    def test_nested_and_perp_prompt_nested_child_types(self):
        # Both members of the bracketed group are expected to be leaves.
        leaf_type = neutral_prompt_parser.LeafPrompt
        inner_first = self.nested_and_perp_prompt.children[1].children[0]
        self.assertIsInstance(inner_first, leaf_type)
        inner_second = self.nested_and_perp_prompt.children[1].children[1]
        self.assertIsInstance(inner_second, leaf_type)

    def test_nested_and_perp_prompt_nested_child(self):
        innermost = self.nested_and_perp_prompt.children[1].children[1]
        self.assertEqual(innermost.weight, 3.0)
        self.assertEqual(innermost.prompt.strip(), "welcome")

    # --- invalid weight ----------------------------------------------------

    def test_invalid_weight_child_count(self):
        children = self.invalid_weight.children
        self.assertEqual(len(children), 1)

    def test_invalid_weight_child_weight(self):
        # A non-numeric weight suffix is expected to leave the weight at 1.0.
        only_child = self.invalid_weight.children[0]
        self.assertEqual(only_child.weight, 1.0)

    def test_invalid_weight_child_prompt(self):
        # ... and the unparsed suffix is expected to stay in the prompt text.
        only_child = self.invalid_weight.children[0]
        self.assertEqual(only_child.prompt, "hello :not_a_float")

    # --- AND_SALT ----------------------------------------------------------

    def test_and_salt_prompt_child_count(self):
        children = self.and_salt_prompt.children
        self.assertEqual(len(children), 2)

    def test_and_salt_prompt_child_types(self):
        # Without brackets both sides are expected to be leaves.
        leaf_type = neutral_prompt_parser.LeafPrompt
        self.assertIsInstance(self.and_salt_prompt.children[0], leaf_type)
        self.assertIsInstance(self.and_salt_prompt.children[1], leaf_type)

    def test_and_salt_prompt_nested_child(self):
        second = self.and_salt_prompt.children[1]
        self.assertEqual(second.weight, 2.0)
        self.assertEqual(second.prompt.strip(), "goodbye")

    def test_nested_and_salt_prompt_child_count(self):
        children = self.nested_and_salt_prompt.children
        self.assertEqual(len(children), 2)

    def test_nested_and_salt_prompt_child_types(self):
        # The bracketed right-hand side is expected to be a composite node.
        root = self.nested_and_salt_prompt
        self.assertIsInstance(root.children[0], neutral_prompt_parser.LeafPrompt)
        self.assertIsInstance(root.children[1], neutral_prompt_parser.CompositePrompt)

    def test_nested_and_salt_prompt_nested_child_types(self):
        # Both members of the bracketed group are expected to be leaves.
        leaf_type = neutral_prompt_parser.LeafPrompt
        inner_first = self.nested_and_salt_prompt.children[1].children[0]
        self.assertIsInstance(inner_first, leaf_type)
        inner_second = self.nested_and_salt_prompt.children[1].children[1]
        self.assertIsInstance(inner_second, leaf_type)

    def test_nested_and_salt_prompt_nested_child(self):
        innermost = self.nested_and_salt_prompt.children[1].children[1]
        self.assertEqual(innermost.weight, 3.0)
        self.assertEqual(innermost.prompt.strip(), "welcome")

    # --- leading bracket ---------------------------------------------------

    def test_start_with_prompt_editing(self):
        # A prompt that opens with a "[...:0.1]" span is expected to come back
        # as one untouched child with weight 1.0.
        prompt = "[(long shot:1.2):0.1] detail.."
        root = neutral_prompt_parser.parse_root(prompt)
        first = root.children[0]
        self.assertEqual(first.weight, 1.0)
        self.assertEqual(first.prompt, prompt)
# Run the suite only when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
|