# custom_node / e7ai-animate.py
# Uploaded by kazuhasasd ("Upload e7ai-animate.py", revision ad4e05b, verified).
# This code is released without a license.
import json
import logging

import yaml
class TagSwap:
    """
    Rewrite a comma-delimited prompt according to YAML rules.

    The YAML must have a top-level ``replace`` or ``select`` key (``replace``
    is used if both are present). Its value holds one rule per tag, each with
    optional ``p`` (present) and ``a`` (absent) actions. Rules may be written
    either as a mapping or as a list of single-key mappings::

        replace:            # or: select:
          - tag_name:
              p: present action
              a: absent action

    replace mode
        Prompt tags are kept in order. A tag that has a rule is replaced by
        the rule's ``p`` text, or removed if the rule has no ``p``. Then the
        ``a`` text of every rule whose tag is absent from the prompt is
        appended.
    select mode
        The output is built from the rules alone, in rule order: each rule
        contributes its ``p`` text if its tag is in the prompt, otherwise its
        ``a`` text. Prompt tags without a rule are not emitted.

    An action whose value is ``null`` emits nothing; other non-string values
    are converted with ``str``. Empty prompt entries (e.g. from a trailing
    comma) are ignored.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "tags": ("STRING", {"default": '', "multiline": True, "forceInput": True}),
                "rules": ("STRING", {"default": '', "multiline": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("replacements",)
    FUNCTION = "apply"
    CATEGORY = "utils"

    @staticmethod
    def _split_tags(tags):
        """Split a comma-delimited prompt into stripped, non-empty tags."""
        stripped = (tag.strip() for tag in tags.split(','))
        return [tag for tag in stripped if tag]

    @staticmethod
    def _normalize(needles):
        """
        Return the rules as a ``{tag: {action: value}}`` dict.

        Accepts ``None`` (no rules), a mapping, or a list whose items are
        single-key mappings or bare tag names. A rule with no actions becomes
        an empty dict.

        :raises ValueError: if the rules or a rule's actions are not mappings.
        """
        if needles is None:
            return {}
        if isinstance(needles, list):
            merged = {}
            for item in needles:
                if isinstance(item, dict):
                    merged.update(item)
                else:
                    merged[item] = None  # bare "- tag" entry: rule without actions
            needles = merged
        if not isinstance(needles, dict):
            raise ValueError("Rules must be a mapping or a list of mappings")
        normalized = {}
        for tag, actions in needles.items():
            if actions is None:
                actions = {}
            elif not isinstance(actions, dict):
                raise ValueError(
                    "Rule for %r must be a mapping of 'p'/'a' actions" % (tag,))
            normalized[tag] = actions
        return normalized

    @staticmethod
    def _action_text(actions, key):
        """Return action ``key`` ('p' or 'a') as text, or None if unset/null."""
        value = actions.get(key)
        return None if value is None else str(value)

    def replace(self, needles, haystack):
        """Yield the output tags for replace mode (see class docstring)."""
        needles = self._normalize(needles)
        haystack = list(haystack)
        present = set(haystack)
        for tag in haystack:
            if tag not in needles:
                yield tag  # no rule for this tag: keep it unchanged
                continue
            # A rule without a usable 'p' action drops the tag.
            text = self._action_text(needles[tag], 'p')
            if text is not None:
                yield text
        for tag, actions in needles.items():
            if tag not in present:
                text = self._action_text(actions, 'a')
                if text is not None:
                    yield text

    def select(self, needles, haystack):
        """Yield the output tags for select mode (see class docstring)."""
        needles = self._normalize(needles)
        present = set(haystack)
        for tag, actions in needles.items():
            text = self._action_text(actions, 'p' if tag in present else 'a')
            if text is not None:
                yield text

    def apply(self, tags, rules):
        """
        Apply the YAML ``rules`` to the comma-delimited ``tags``.

        :returns: a 1-tuple holding the resulting tags joined with ``', '``.
        :raises ValueError: if ``rules`` is not a mapping with a ``replace``
            or ``select`` key, or if the rules are malformed.
        """
        haystack = self._split_tags(tags)
        # safe_load rather than load: the rules come from the node's text input.
        parsed = yaml.safe_load(rules)
        if isinstance(parsed, dict):
            if 'replace' in parsed:
                return (', '.join(self.replace(parsed['replace'], haystack)),)
            if 'select' in parsed:
                return (', '.join(self.select(parsed['select'], haystack)),)
        raise ValueError("Must use either 'replace' or 'select'")
from collections import OrderedDict
class PromptMerge:
    """
    Collapse runs of identical, adjacent prompts into a keyframe schedule.

    Takes a list of prompts (one per frame index) and keeps only the frames
    whose prompt differs from the previous frame's. The output is a single
    string of ``"<index>": "<prompt>"`` entries joined with ``', '``, intended
    as input for BatchPromptSchedule.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "prompts": ("STRING", {"default": [], "forceInput": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("prompt",)
    # ``apply`` receives ``prompts`` as a whole list (ComfyUI INPUT_IS_LIST).
    INPUT_IS_LIST = True
    FUNCTION = "apply"
    CATEGORY = "utils"

    def apply(self, prompts):
        """
        Build the schedule string from a per-frame list of prompts.

        :param prompts: sequence of prompts, position = frame index.
        :returns: a 1-tuple holding entries like ``' "0": "a cat" '`` joined
            with ``', '``; an empty string when ``prompts`` is empty.
        """
        logger = logging.getLogger(__name__)
        logger.debug("PromptMerge called with %s", prompts)
        keyframes = []
        last = None
        for index, prompt in enumerate(prompts):
            # Record a frame only when the prompt changes; identical
            # neighbours are covered by the earlier keyframe.
            if prompt != last:
                keyframes.append((index, prompt))
                last = prompt
        # json.dumps quotes and escapes the prompt, so embedded double quotes,
        # backslashes or newlines cannot break out of the quoted entry.
        travel = [
            ' "%d": %s ' % (index, json.dumps(str(prompt), ensure_ascii=False))
            for index, prompt in keyframes
        ]
        logger.debug("PromptMerge schedule: %s", travel)
        return (', '.join(travel),)
# Node registry: maps each node type name to the class implementing it.
NODE_CLASS_MAPPINGS = dict(
    TagSwap=TagSwap,
    PromptMerge=PromptMerge,
)

# Human-readable titles for the nodes, keyed by the same node type names.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    TagSwap="Tag Swap",
    PromptMerge="Prompt Merge",
)