File size: 3,230 Bytes
ad4e05b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# This code is released without a license.

import json

import yaml

class TagSwap:

    """
    Accepts a comma-delimited prompt and YAML rules. Operates in two modes:

    In replace mode, every input tag is kept unless a rule names it. A tag
    named by a rule is replaced by the rule's present action (or dropped if
    the rule has no present action). For every rule whose tag is absent from
    the prompt, the rule's absent action is appended.

    In select mode, only rule output is emitted: for each rule, its present
    action if the tag is in the prompt, otherwise its absent action.

    The YAML format is:

    select|replace:
     - tag_name:
         p: present action
         a: absent action

    A plain mapping (``tag_name: {p: ..., a: ...}`` without the leading
    ``-`` list markers) is accepted as well.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "tags": ("STRING", {"default": '', "multiline": True, "forceInput": True}),
                "rules": ("STRING", {"default": '', "multiline": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("replacements",)
    FUNCTION = "apply"
    CATEGORY = "utils"

    def __init__(self):
        pass

    @staticmethod
    def _split_tags(tags):
        """Split a comma-delimited prompt into stripped, non-empty tags."""
        # Dropping empty entries keeps "a, b," or an empty prompt from
        # producing blank tags that would render as ", ," in the output.
        return [tag.strip() for tag in tags.split(',') if tag.strip()]

    @staticmethod
    def _normalize_rules(needles):
        """
        Return the rules as an ordered ``{tag: actions}`` dict.

        Accepts either a mapping of tag -> actions, or the documented YAML
        list of single-entry mappings. Tags are coerced to ``str`` so YAML
        scalars such as numbers still compare equal to prompt text, and a
        tag with no actions (``tag_name:`` followed by nothing) becomes an
        empty dict.

        Raises:
            ValueError: if ``needles`` is neither a mapping nor a list of
                mappings, or if a tag's actions are not a mapping.
        """
        if needles is None:
            return {}
        if isinstance(needles, dict):
            items = list(needles.items())
        elif isinstance(needles, list):
            items = []
            for entry in needles:
                if not isinstance(entry, dict):
                    raise ValueError(
                        "Each rule must be a mapping of tag to actions, got %r" % (entry,))
                items.extend(entry.items())
        else:
            raise ValueError(
                "Rules must be a mapping or a list of mappings, got %r" % (needles,))

        rules = {}
        for tag, actions in items:
            if actions is None:
                actions = {}
            elif not isinstance(actions, dict):
                raise ValueError(
                    "Actions for tag %r must be a mapping with 'p'/'a' keys, got %r"
                    % (tag, actions))
            rules[str(tag)] = actions
        return rules

    @staticmethod
    def _action(actions, key):
        """Return ``actions[key]`` as a string, or None if it is missing or null."""
        value = actions.get(key)
        return None if value is None else str(value)

    def replace(self, needles, haystack):
        """
        Yield the prompt tags with rule-named tags swapped and absent-tag
        actions appended.

        Args:
            needles: rules, as a ``{tag: {'p': ..., 'a': ...}}`` mapping or
                a list of single-entry mappings (the YAML list format).
            haystack: iterable of prompt tags, in order.

        Yields:
            str: first the input tags (replaced or dropped where a rule names
            them), then the absent actions of rules whose tag was not present.
        """
        rules = self._normalize_rules(needles)
        haystack = list(haystack)  # iterated twice below
        present = set(haystack)
        for tag in haystack:
            if tag not in rules:
                yield tag
                continue
            replacement = self._action(rules[tag], 'p')
            if replacement is not None:
                yield replacement

        for needle, actions in rules.items():
            if needle in present:
                continue
            injection = self._action(actions, 'a')
            if injection is not None:
                yield injection

    def select(self, needles, haystack):
        """
        Yield one action per rule, chosen by whether its tag is in the prompt.

        Args:
            needles: rules, in either form accepted by ``replace``.
            haystack: iterable of prompt tags.

        Yields:
            str: for each rule in order, its present action if the tag is in
            ``haystack`` else its absent action. Rules lacking that action
            yield nothing.
        """
        rules = self._normalize_rules(needles)
        present = set(haystack)
        for tag, actions in rules.items():
            chosen = self._action(actions, 'p' if tag in present else 'a')
            if chosen is not None:
                yield chosen

    def apply(self, tags, rules):
        """
        Apply YAML ``rules`` to the comma-delimited ``tags`` prompt.

        Args:
            tags: comma-delimited prompt text.
            rules: YAML text with a top-level ``replace`` or ``select`` key.

        Returns:
            tuple[str]: one-element tuple with the resulting comma-joined tags.

        Raises:
            ValueError: if the YAML is not a mapping containing ``replace``
                or ``select``, or its rules are malformed.
        """
        haystack = self._split_tags(tags)
        # safe_load so the rules text cannot construct arbitrary Python objects.
        config = yaml.safe_load(rules)
        if isinstance(config, dict):
            if 'replace' in config:
                return (', '.join(self.replace(config['replace'], haystack)),)
            if 'select' in config:
                return (', '.join(self.select(config['select'], haystack)),)
        raise ValueError("Must use either 'replace' or 'select'")

from collections import OrderedDict

class PromptMerge:

    """
    Takes a list of prompts. Merges identical, adjacent prompts. Outputs a
    format compatible with BatchPromptSchedule.

    Each run of identical consecutive prompts is emitted once, keyed by the
    index at which the run starts; e.g. ["a", "a", "b"] becomes
    ' "0": "a" ,  "2": "b" '.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "prompts": ("STRING", {"default": [], "forceInput": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("prompt",)
    INPUT_IS_LIST = True
    FUNCTION = "apply"
    CATEGORY = "utils"

    def apply(self, prompts):
        """
        Collapse adjacent duplicate prompts into ``"index": "prompt"`` entries.

        Args:
            prompts: list of prompts, one per position.

        Returns:
            tuple[str]: one-element tuple holding the comma-separated entries
            (empty string for an empty list).
        """
        entries = []
        last = None
        for i, prompt in enumerate(prompts):
            if prompt != last:
                # json.dumps quotes and escapes the prompt so embedded quotes,
                # backslashes or newlines cannot break the quoted-string entry.
                quoted = json.dumps(str(prompt), ensure_ascii=False)
                entries.append(' "%d": %s ' % (i, quoted))
                last = prompt
        return (', '.join(entries),)

# Node identifier -> implementing class. Presumably read by the ComfyUI
# custom-node loader (TODO confirm).
NODE_CLASS_MAPPINGS = dict(
    TagSwap=TagSwap,
    PromptMerge=PromptMerge,
)

# Human-readable titles for the nodes, keyed by the same identifiers as
# NODE_CLASS_MAPPINGS.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    TagSwap="Tag Swap",
    PromptMerge="Prompt Merge",
)