|
|
|
|
|
|
|
|
|
|
|
|
|
|
import re |
|
|
import gradio as gr |
|
|
from modules import script_callbacks, scripts, shared |
|
|
from modules.shared import opts |
|
|
from collections import namedtuple |
|
|
import lark |
|
|
import random |
|
|
from functools import lru_cache |
|
|
import hashlib |
|
|
from itertools import product |
|
|
import os |
|
|
import logging |
|
|
import traceback |
|
|
|
|
|
# Module logger; get_schedule uses it to report prompts that fail to parse.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def _env_bool(name: str, default: str = "0") -> bool: |
|
|
v = str(os.getenv(name, default)).strip().lower() |
|
|
return v not in ("0", "", "false", "no", "off") |
|
|
|
|
|
def _env_int(name: str, default: int) -> int:
    """Read environment variable ``name`` as an int, falling back to ``default``.

    Unset or blank values yield ``default``. Non-integer values also yield
    ``default`` (with a logged warning) instead of raising at import time,
    which would otherwise stop the whole extension from loading.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError:
        logger.warning("Ignoring invalid integer %s=%r; using %d", name, raw, default)
        return default


# Feature flags and limits, overridable through environment variables.
ALLOW_EMPTY_ALTERNATE = _env_bool("ALLOW_EMPTY_ALTERNATE", "0")
EXPAND_ALTERNATE_PER_STEP = _env_bool("EXPAND_ALTERNATE_PER_STEP", "1")
# Maximum number of combinations a {grouped} block expands to.
GROUP_COMBO_LIMIT = _env_int("GROUP_COMBO_LIMIT", 100)
# maxsize for the lru_cache decorators on hash_tree / get_schedule.
CACHE_SIZE = _env_int("PROMPT_PARSER_CACHE_SIZE", 4096)
|
|
|
|
|
|
|
|
# Body of the ``alternate`` rule, spliced into the grammar below. With
# ALLOW_EMPTY_ALTERNATE enabled, every option after the first may be omitted
# ("[a|]"). Presumably Lark's default ``maybe_placeholders`` then yields a
# ``None`` child for the omitted option — confirm for the installed Lark.
_alt_rule = r' "[" prompt ("|" prompt)* "]" ' if not ALLOW_EMPTY_ALTERNATE else r' "[" prompt ("|" [prompt])+ "]" '

# Lark grammar for the extended prompt syntax. Rules prefixed with "!" keep
# their literal tokens in the tree; other rules drop anonymous literals.
#  - ``numbered`` is a "!" rule so its optional "!"/"_" distinct marker reaches
#    the tree as ``[N, mark, target]``. Without it the marker is dropped and
#    ``N!``/``N_`` can never request distinct options.
#  - Percentage step ranges are lexed as a single PERCENT_VALUE token ("50%"),
#    so the "%" survives for the consumer that converts ranges to steps.
#  - The "~" (random) terminator of ``nested_sequence`` is the named terminal
#    SEQ_RANDOM, so it is kept in the tree instead of being filtered out.
_grammar = r"""
!start: (prompt | /[][():|]/+)*

prompt: (scheduled | emphasized | grouped
    | alternate | alternate_distinct
    | alternate1 | alternate2
    | top_level_sequence | sequence
    | compound | numbered | and_rule
    | plain | WHITESPACE)*

!emphasized: "(" prompt ")"
    | "(" prompt ":" prompt ")"
    | "(" prompt ":" NUMBER ")"
    | "[" prompt "]"

scheduled: "[" [prompt (":" prompt)+] "]" ":" NUMBER (step_range_list | reverse_flag | step_range_list reverse_flag)?
reverse_flag: "reverse" | "r"
step_range_list: step_range ("," step_range)*
step_range: NUMBER "-" NUMBER | PERCENT_VALUE "-" PERCENT_VALUE

alternate: """ + _alt_rule + r"""
!alternate_distinct: "[" prompt ("|" prompt)* "]!"
alternate1: (prompt) "|" (prompt)+
alternate2: (plain | compound) ("|" (plain | compound))+

grouped: "{" ((NUMBER_Q | prompt | sequence | grouped) ("," | "|")?)+ "}"

top_level_sequence: prompt ("::" sequence)+ "!!" ("," plain)?
sequence: prompt "::" prompt ("," | WHITESPACE)* nested_sequence* ("!" | ";")
nested_sequence: "::" prompt ("," | WHITESPACE)* ("!" | ";" | SEQ_RANDOM)

compound: /[a-zA-Z0-9]+(_[a-zA-Z0-9]+)+/
!numbered: NUMBER_Q ("!" | "_")? (grouped | sequence | compound | and_rule | plain | alternate | alternate_distinct | alternate1 | alternate2)
and_rule: (plain | compound) ("&" (plain | compound))+
WHITESPACE: /\s+/
plain: /([^\\\[\]()&]|\\.)+/
PERCENT_VALUE: /\d+(\.\d+)?%/
SEQ_RANDOM: "~"

%import common.SIGNED_NUMBER -> NUMBER
%import common.INT -> NUMBER_Q
"""

# Shared parser instance; lark.Lark uses its default (Earley) parser here.
schedule_parser = lark.Lark(_grammar)
|
|
|
|
|
@lru_cache(maxsize=CACHE_SIZE)
def hash_tree(tree: lark.Tree | lark.Token) -> str:
    """Return an MD5 hex digest fingerprinting ``tree``.

    For a ``lark.Tree`` the digest covers the rule name followed by the
    digests of its children, computed recursively through this cached
    function. Any other node is digested through ``str()``. Results are
    memoised with ``lru_cache``.
    """
    if not isinstance(tree, lark.Tree):
        payload = str(tree)
    else:
        payload = tree.data + "".join(map(hash_tree, tree.children))
    return hashlib.md5(payload.encode()).hexdigest()
|
|
|
|
|
def resolve_tree(tree: lark.Tree | lark.Token, keep_spacing: bool = True) -> str:
    """Flatten a parse (sub)tree back into stripped text.

    WHITESPACE tokens become a single space when ``keep_spacing`` is true and
    are dropped otherwise. Children are resolved recursively and concatenated.
    With ``keep_spacing``, runs of whitespace (including U+2028/U+2029) are
    collapsed to one space. Non-tree input is returned as ``str(tree).strip()``.
    """
    if not isinstance(tree, lark.Tree):
        return str(tree).strip()

    pieces = []
    for node in tree.children:
        is_space = isinstance(node, lark.Token) and node.type == "WHITESPACE"
        if not is_space:
            pieces.append(resolve_tree(node, keep_spacing))
        elif keep_spacing:
            pieces.append(" ")

    joined = "".join(str(piece) for piece in pieces if piece)
    if not keep_spacing:
        return joined.strip()
    return re.sub(r"[\s\u2028\u2029]+", " ", joined).strip()
|
|
|
|
|
class ScheduleTransformer(lark.Transformer):
    """Render a parse tree to the prompt text active at one sampling step.

    Lark's ``Transformer`` calls the method named after each rule bottom-up, so
    every method below receives children that were already transformed.
    Results of rules with a method are plain ``str``. Tokens without a callback
    (``NUMBER``, ``WHITESPACE`` and the literal punctuation kept by ``!`` rules)
    arrive as ``lark.Token``. Sub-rules without a method (``step_range_list``,
    ``reverse_flag``) arrive as ``lark.Tree``.

    Args:
        total_steps: Number of sampling steps in the schedule.
        current_step: 1-based step whose text is rendered.
        seed: Seed for a private RNG used by random choices; ``None`` uses the
            global ``random`` module.
    """

    # Token types that carry content rather than literal punctuation.
    _CONTENT_TOKEN_TYPES = ("NUMBER", "NUMBER_Q", "WHITESPACE")

    def __init__(self, total_steps: int, current_step: int = 1, seed: int | None = 42):
        super().__init__()
        self.total_steps = total_steps
        self.current_step = current_step
        self.seed = seed
        self.rng = random.Random(seed) if seed is not None else random

    @staticmethod
    def _is_punctuation(arg) -> bool:
        """Return True for literal tokens ("(", ":", "|", "]!" ...) kept by ``!`` rules."""
        return isinstance(arg, lark.Token) and arg.type not in ScheduleTransformer._CONTENT_TOKEN_TYPES

    def start(self, args):
        return "".join(str(arg) for arg in args if arg)

    def prompt(self, args):
        return "".join(str(arg) for arg in args if arg)

    def plain(self, args):
        return args[0].value

    def compound(self, args):
        return "_".join(str(arg) for arg in args)

    def and_rule(self, args):
        return " and ".join(resolve_tree(arg, keep_spacing=True) for arg in args if resolve_tree(arg))

    def grouped(self, args):
        return ", ".join(resolve_tree(arg, keep_spacing=True) for arg in args if resolve_tree(arg).strip(" ,|"))

    def alternate(self, args):
        """Cycle through the options, one per step; empty options are kept."""
        # An omitted option (a ``None`` placeholder) renders as "" rather than "None".
        vals = ["" if arg is None else resolve_tree(arg, keep_spacing=True) for arg in args]
        return vals[(self.current_step - 1) % len(vals)] if vals else "empty_prompt"

    def alternate_distinct(self, args):
        """Pick one non-empty option at random."""
        # ``!alternate_distinct`` keeps its "[", "|" and "]!" tokens; they are not options.
        options = [resolve_tree(arg, keep_spacing=True) for arg in args if not self._is_punctuation(arg)]
        options = [option for option in options if option]
        return self.rng.choice(options) if options else "empty_prompt"

    def alternate1(self, args):
        options = [resolve_tree(arg, keep_spacing=True) for arg in args if resolve_tree(arg)]
        return self.rng.choice(options) if options else "empty_prompt"

    def alternate2(self, args):
        options = [resolve_tree(arg, keep_spacing=True) for arg in args if resolve_tree(arg)]
        return self.rng.choice(options) if options else "empty_prompt"

    def numbered(self, args):
        """Render ``quantity`` picks from the target, joined with ", ".

        The target has normally already been transformed to a single string,
        which then forms the only option. If it is still a ``lark.Tree``, its
        resolved non-empty children are the options. In distinct mode each
        option is used once before any repeats are drawn.
        """
        quantity = int(args[0])

        # The distinct marker is only present when the grammar keeps it
        # (``[N, mark, target]``); with two args, ``args[1]`` is the target itself.
        distinct = len(args) > 2 and str(args[1]) in ("!", "_")

        target = args[-1]
        if isinstance(target, lark.Tree):
            resolved = (resolve_tree(child, keep_spacing=True) for child in target.children if child is not None)
            options = [text for text in resolved if text]
        else:
            options = [resolve_tree(target, keep_spacing=True)]

        if not options:
            return "empty_prompt"

        if distinct:
            if quantity >= len(options):
                unique = self.rng.sample(options, len(options))
                pad = self.rng.choices(options, k=quantity - len(unique)) if quantity > len(unique) else []
                selected = unique + pad
            else:
                selected = self.rng.sample(options, quantity)
        else:
            selected = self.rng.choices(options, k=quantity)

        return ", ".join(selected)

    def sequence(self, args, parent=None):
        """Render ``owner: descriptor, ...``; ``parent`` replaces the owner when given."""
        owner = resolve_tree(args[0], keep_spacing=True) if parent is None else parent
        descriptors = [resolve_tree(arg, keep_spacing=True).strip(" ,~!;") for arg in args[1:] if resolve_tree(arg).strip(" ,~!;")]
        return f"{owner}: {', '.join(descriptors)}"

    def top_level_sequence(self, args):
        """Render ``owner -> part, part, ...``.

        Nested sequences have normally already been rendered to strings, so
        they and any trailing plain text are joined in order. An untransformed
        ``sequence`` tree is re-rendered with ``owner`` as its owner.
        """
        owner = resolve_tree(args[0], keep_spacing=True).strip()
        parts = []
        for child in args[1:]:
            if isinstance(child, lark.Tree) and child.data == "sequence":
                parts.append(self.sequence(child.children, owner))
                continue
            text = resolve_tree(child, keep_spacing=True).strip(" ,")
            if text and text != "!!":
                parts.append(text)
        return f"{owner} -> {', '.join(parts)}"

    def nested_sequence(self, args):
        """Render ``[a | b]``, or a random element when terminated by "~"."""
        items = list(args)
        terminator = None
        # Only a kept terminator token ends the list; otherwise the last item
        # is real content and must not be dropped.
        if items and isinstance(items[-1], lark.Token) and str(items[-1]) in ("!", ";", "~"):
            terminator = str(items.pop())
        elements = []
        for item in items:
            text = resolve_tree(item, keep_spacing=True).strip(" ,~!;")
            if text:
                elements.append(text)
        if terminator == "~":
            return self.rng.choice(elements) if elements else "empty_prompt"
        return f"[{' | '.join(elements)}]"

    def emphasized(self, args):
        """Render ``(prompt:weight)``, ``(prompt:other)`` or ``(prompt:1.1)``."""
        # ``!emphasized`` keeps its "(", ")", ":", "[" and "]" tokens; drop them.
        parts = [arg for arg in args if not self._is_punctuation(arg)]
        prompt = parts[0] if parts else ""
        if len(parts) > 1:
            second = parts[1]
            if isinstance(second, lark.Token) and second.type == "NUMBER":
                return f"({prompt}:{float(second.value)})"
            return f"({prompt}:{second})"
        return f"({prompt}:1.1)"

    def scheduled(self, args):
        """Render the scheduled prompt active at ``current_step`` as ``(text:weight)``.

        The split mirrors ``CollectSteps.visit_scheduled``. The boundary is
        ``round(weight * total_steps)`` for weights <= 1, otherwise
        ``round(weight)``. Prompt ``i`` (except the last) is active up to step
        ``round((i + 1) * boundary / n)``, and the "reverse" flag reverses the
        prompt order. Explicit step ranges are not interpreted here.

        NOTE(review): the schedule number is also used as the attention weight
        in the output, unlike the visitor, which emits the bare text.
        """
        number_tok = next((a for a in args if isinstance(a, lark.Token) and a.type == "NUMBER"), None)
        is_reverse = any(isinstance(a, lark.Tree) and a.data == "reverse_flag" for a in args)
        prompts = [
            a for a in args
            if a is not None
            and not (isinstance(a, lark.Token) and a.type == "NUMBER")
            and not (isinstance(a, lark.Tree) and a.data in ("step_range_list", "reverse_flag"))
        ]

        try:
            weight = float(number_tok) if number_tok is not None else 1.0
        except ValueError:
            weight = 1.0

        raw_boundary = weight * self.total_steps if weight <= 1.0 else weight
        boundary = max(1, min(int(round(raw_boundary)), self.total_steps))

        if not prompts:
            return "empty_prompt"
        if is_reverse:
            prompts = prompts[::-1]
        if len(prompts) == 1:
            return f"({resolve_tree(prompts[0], keep_spacing=True)}:{weight})" if self.current_step >= boundary else ""

        boundary = max(boundary, len(prompts))
        step_size = boundary / len(prompts)
        chosen = prompts[-1]
        for i, candidate in enumerate(prompts[:-1]):
            end = max(1, min(int(round((i + 1) * step_size)), self.total_steps))
            if self.current_step <= end:
                chosen = candidate
                break
        return f"({resolve_tree(chosen, keep_spacing=True)}:{weight})"
|
|
|
|
|
class CollectSteps(lark.Visitor):
    """Collect a prompt schedule ``[[end_step, text], ...]`` from a parse tree.

    Each ``visit_*`` method returns ``[end_step, text]`` entries whose ``text``
    is wrapped in this collector's ``prefix``/``suffix``, i.e. the resolved
    text of the siblings around the visited node. Several entries may share an
    ``end_step`` (e.g. grouped combinations); ``at_step_from_schedule`` uses
    the first entry whose ``end_step`` covers the requested step.

    Args:
        steps: Total number of sampling steps.
        prefix: Text placed before every emitted entry.
        suffix: Text placed after every emitted entry.
        depth: Nesting depth; incremented for child collectors, not otherwise read.
        use_scheduling: Stored and forwarded to child collectors.
        seed: Seed for a private RNG; ``None`` uses the global ``random`` module.
    """

    def __init__(self, steps, prefix="", suffix="", depth=0, use_scheduling=True, seed=None):
        super().__init__()
        self.steps = steps
        self.prefix = prefix
        self.suffix = suffix
        self.depth = depth
        self.use_scheduling = use_scheduling
        self.seed = seed
        self.rng = random.Random(seed) if seed is not None else random
        self.schedules = []

    def visit(self, tree):
        """Dispatch trees to ``visit_<rule>`` (or ``_default_visit``) and tokens to ``_visit_token``."""
        if isinstance(tree, lark.Tree):
            method = getattr(self, f"visit_{tree.data}", self._default_visit)
            return method(tree)
        if isinstance(tree, lark.Token):
            return self._visit_token(tree)
        return []

    def _child_collector(self):
        """Return a collector for nested content, with no surrounding prefix/suffix.

        Callers wrap nested results in ``self.prefix``/``self.suffix`` themselves.
        Visiting through ``self`` would therefore emit the surrounding text
        twice. The RNG is shared so random choices continue this collector's
        sequence.
        """
        sub = CollectSteps(self.steps, depth=self.depth + 1, use_scheduling=self.use_scheduling, seed=self.seed)
        sub.rng = self.rng
        return sub

    def _child_options(self, child):
        """Return the texts ``child`` can render to (stripped of " ,|"), without outer context."""
        if child is None:
            # Omitted optional child (a Lark ``[...]`` placeholder): one empty option.
            return [""]
        schedules = self._child_collector().visit(child)
        options = [sched[1].strip(" ,|") for sched in schedules if sched[1].strip(" ,|")]
        return options or [resolve_tree(child, keep_spacing=True).strip(" ,|")]

    def _clamp_step(self, x: int) -> int:
        """Clamp ``x`` into ``[1, steps]``."""
        return max(1, min(x, self.steps))

    def _to_steps(self, txt: str) -> int:
        """Convert "N" or "N%" (of ``steps``) to a step number; unparsable text gives 1."""
        s = txt.strip()
        try:
            if s.endswith("%"):
                return int(round(float(s[:-1]) / 100.0 * self.steps))
            return int(round(float(s)))
        except ValueError:
            return 1

    def _default_visit(self, tree):
        """Visit each non-whitespace child, giving it its siblings' text as context.

        NOTE(review): the children's schedules are concatenated, not merged by
        step. ``at_step_from_schedule`` returns the first entry whose end step
        covers the requested step, so an earlier sibling's full-length entry
        shadows any per-step entries produced by later siblings.
        """
        children = [c for c in tree.children if not (isinstance(c, lark.Token) and c.type == "WHITESPACE")]
        # Resolve every child once instead of re-resolving all siblings per child;
        # an omitted optional child contributes "" rather than "None".
        texts = ["" if c is None else resolve_tree(c, keep_spacing=True) for c in children]
        schedules = []
        for i, child in enumerate(children):
            collector = CollectSteps(
                self.steps,
                prefix=self.prefix + "".join(texts[:i]),
                suffix="".join(texts[i + 1:]) + self.suffix,
                depth=self.depth + 1,
                use_scheduling=self.use_scheduling,
                seed=self.seed,
            )
            schedules.extend(collector.visit(child))
        return schedules

    def _visit_token(self, token):
        if token.type == "WHITESPACE":
            return []
        return [[self.steps, self.prefix + str(token) + self.suffix]]

    def visit_plain(self, tree):
        text = resolve_tree(tree, keep_spacing=True)
        return [[self.steps, self.prefix + text + self.suffix]]

    def visit_top_level_sequence(self, tree):
        transformer = ScheduleTransformer(self.steps, 1, self.seed)
        text = transformer.transform(tree)
        return [[self.steps, self.prefix + text + self.suffix]]

    def _scheduled_texts(self, prompt):
        """Return the texts to emit for one scheduled prompt (without prefix/suffix)."""
        if isinstance(prompt, lark.Tree):
            texts = [sched[1] for sched in self._child_collector().visit(prompt)]
            return texts or [resolve_tree(prompt, keep_spacing=True)]
        return [resolve_tree(prompt, keep_spacing=True)]

    def visit_scheduled(self, tree):
        """Emit one entry per active interval of a scheduled prompt.

        Intervals come from explicit step ranges (range ``k`` shows prompt ``k``,
        clamped to the last prompt) or from splitting the boundary evenly
        between the prompts. "reverse" reverses the prompt order over the same
        intervals. Steps before the first interval show only the surrounding
        context, and the last prompt extends to ``steps``.
        """
        if not tree.children:
            return [[self.steps, self.prefix + "empty_prompt" + self.suffix]]

        prompts = [
            p for p in tree.children
            if p is not None
            and not (isinstance(p, lark.Token) and p.type == "NUMBER")
            and not (isinstance(p, lark.Tree) and p.data in ("step_range_list", "reverse_flag"))
        ]
        if not prompts:
            # "[]:N" has nothing to schedule; avoid IndexError / ZeroDivisionError below.
            return [[self.steps, self.prefix + "empty_prompt" + self.suffix]]

        number_node = next((p for p in tree.children if isinstance(p, lark.Token) and p.type == "NUMBER"), None)
        step_range_list = next((p for p in tree.children if isinstance(p, lark.Tree) and p.data == "step_range_list"), None)
        is_reverse = any(isinstance(p, lark.Tree) and p.data == "reverse_flag" for p in tree.children)
        weight = float(number_node.value) if number_node else 1.0

        # (start_step, end_step, prompt_index), both steps inclusive.
        intervals = []
        if step_range_list is not None:
            for sr in step_range_list.children:
                if not (isinstance(sr, lark.Tree) and sr.data == "step_range") or len(sr.children) != 2:
                    continue
                start_step = self._clamp_step(self._to_steps(resolve_tree(sr.children[0], keep_spacing=False)))
                end_step = self._clamp_step(self._to_steps(resolve_tree(sr.children[1], keep_spacing=False)))
                # "<=" keeps single-step ranges such as "5-5".
                if start_step <= end_step:
                    intervals.append((start_step, end_step, min(len(intervals), len(prompts) - 1)))
        else:
            num_prompts = len(prompts)
            boundary = self._clamp_step(int(round(weight * self.steps)) if weight <= 1.0 else int(round(weight)))
            if num_prompts == 1:
                last_text = resolve_tree(prompts[0], keep_spacing=True)
                return [
                    [boundary - 1, self.prefix + self.suffix],
                    [self.steps, self.prefix + last_text + self.suffix],
                ]
            boundary = max(boundary, num_prompts)
            step_size = boundary / num_prompts
            for i in range(num_prompts):
                start = self._clamp_step(int(round(i * step_size)) + 1)
                end = self._clamp_step(int(round((i + 1) * step_size)))
                # "<=" keeps one-step intervals; each interval remembers its own
                # prompt so a skipped interval cannot shift later prompts.
                if start <= end:
                    intervals.append((start, end, i))

        if is_reverse:
            # Reverse the prompt order only; the intervals stay ascending.
            prompts = prompts[::-1]

        schedules = []
        if intervals and intervals[0][0] > 1:
            schedules.append([intervals[0][0] - 1, self.prefix + self.suffix])
        for _start, end, index in intervals:
            for text in self._scheduled_texts(prompts[index]):
                schedules.append([end, self.prefix + text + self.suffix])
        if intervals and intervals[-1][1] < self.steps:
            tail_text = resolve_tree(prompts[-1], keep_spacing=True)
            schedules.append([self.steps, self.prefix + tail_text + self.suffix])

        if not schedules:
            return [[self.steps, self.prefix + resolve_tree(tree, keep_spacing=True) + self.suffix]]
        return schedules

    def visit_alternate(self, tree):
        """Cycle the options per step (EXPAND_ALTERNATE_PER_STEP) or pick one at random."""
        options = [
            self._child_options(child)
            for child in tree.children
            if not (isinstance(child, lark.Token) and child.type == "WHITESPACE")
        ]
        if not options:
            return [[self.steps, self.prefix + "empty_prompt" + self.suffix]]

        if EXPAND_ALTERNATE_PER_STEP:
            return [
                [step, self.prefix + text + self.suffix]
                for step in range(1, self.steps + 1)
                for text in options[(step - 1) % len(options)]
            ]

        group = options[self.rng.randrange(len(options))]
        choice = self.rng.choice(group) if group else "empty_prompt"
        return [[self.steps, self.prefix + choice + self.suffix]]

    def visit_alternate_distinct(self, tree):
        """Pick one non-empty option at random."""
        # ``!alternate_distinct`` keeps its "[", "|" and "]!" tokens; only the
        # prompt subtrees are options.
        flat = [
            option
            for child in tree.children
            if not isinstance(child, lark.Token)
            for option in self._child_options(child)
            if option
        ]
        if not flat:
            return [[self.steps, self.prefix + "empty_prompt" + self.suffix]]
        selected = self.rng.choice(flat)
        return [[self.steps, self.prefix + selected + self.suffix]]

    def visit_alternate1(self, tree):
        return self.visit_alternate_distinct(tree)

    def visit_alternate2(self, tree):
        options = [resolve_tree(c).strip() for c in tree.children]
        # Options without "_" borrow the part after the last "_" of the first option.
        shared_suffix = options[0].split("_")[-1] if options and "_" in options[0] else ""
        combined_options = [
            option if "_" in option else (f"{option}_{shared_suffix}" if shared_suffix else option)
            for option in options
        ]
        return [[self.steps, self.prefix + "|".join(combined_options) + self.suffix]]

    def visit_grouped(self, tree):
        """Emit up to GROUP_COMBO_LIMIT combinations of the children's options."""
        all_options = [
            self._child_options(child)
            for child in tree.children
            if not (isinstance(child, lark.Token) and child.type == "WHITESPACE")
        ]
        out = []
        for i, combo in enumerate(product(*all_options)):
            if i >= GROUP_COMBO_LIMIT:
                break
            text = ", ".join(combo).strip()
            if text:
                out.append([self.steps, self.prefix + text + self.suffix])
        return out or [[self.steps, self.prefix + "empty_prompt" + self.suffix]]

    def visit_sequence(self, tree):
        transformer = ScheduleTransformer(self.steps, 1, self.seed)
        text = transformer.transform(tree)
        return [[self.steps, self.prefix + text + self.suffix]]

    def visit_nested_sequence(self, tree):
        """Render ``[a | b]``, or a random element when terminated by "~"."""
        items = list(tree.children)
        terminator = None
        # Only a kept terminator token ends the list; otherwise the last child
        # is real content and must not be dropped.
        if items and isinstance(items[-1], lark.Token) and str(items[-1]) in ("!", ";", "~"):
            terminator = str(items.pop())
        elements = []
        for child in items:
            text = resolve_tree(child, keep_spacing=True).strip(" ,~!;")
            if text:
                elements.append(text)
        if terminator == "~":
            text = self.rng.choice(elements) if elements else "empty_prompt"
        else:
            text = f"[{' | '.join(elements)}]"
        return [[self.steps, self.prefix + text + self.suffix]]

    def visit_numbered(self, tree):
        """Emit ``quantity`` picks from the target's options, joined with ", "."""
        quantity = int(tree.children[0])

        # The distinct marker only exists when the grammar keeps it
        # (``[N, mark, target]``); with two children the second is the target.
        distinct = (
            len(tree.children) > 2
            and isinstance(tree.children[1], lark.Token)
            and str(tree.children[1]) in ("!", "_")
        )

        target = tree.children[-1]
        options = self._child_options(target)
        if not options:
            return [[self.steps, self.prefix + "empty_prompt" + self.suffix]]

        if distinct:
            if quantity >= len(options):
                unique = self.rng.sample(options, len(options))
                pad = self.rng.choices(options, k=quantity - len(unique)) if quantity > len(unique) else []
                selected = unique + pad
            else:
                selected = self.rng.sample(options, quantity)
        else:
            selected = self.rng.choices(options, k=quantity)

        return [[self.steps, self.prefix + ", ".join(selected) + self.suffix]]

    def visit_and_rule(self, tree):
        parts = [resolve_tree(c, keep_spacing=True) for c in tree.children]
        text = " and ".join(part for part in parts if part)
        return [[self.steps, self.prefix + text + self.suffix]]

    def visit_emphasized(self, tree):
        """Render ``(prompt:weight)``, ``(prompt:other)`` or ``(prompt:1.1)``."""
        # ``!emphasized`` keeps its "(", ")", ":", "[" and "]" tokens; drop them
        # so only the prompt and the optional weight/second prompt remain.
        parts = [c for c in tree.children if not (isinstance(c, lark.Token) and c.type != "NUMBER")]
        prompt = resolve_tree(parts[0], keep_spacing=True) if parts else ""
        if len(parts) > 1:
            second = parts[1]
            if isinstance(second, lark.Token) and second.type == "NUMBER":
                text = f"({prompt}:{float(second.value)})"
            else:
                text = f"({prompt}:{resolve_tree(second, keep_spacing=True)})"
        else:
            text = f"({prompt}:1.1)"
        return [[self.steps, self.prefix + text + self.suffix]]

    def __call__(self, tree):
        self.schedules = self.visit(tree)
        return self.schedules or [[self.steps, self.prefix + resolve_tree(tree, keep_spacing=True) + self.suffix]]
|
|
|
|
|
|
|
|
def at_step_from_schedule(step, schedule):
    """Return the text of the first ``[end_step, text]`` entry with ``step <= end_step``.

    An empty schedule yields "". If ``step`` lies beyond every end step, the
    last entry's text is returned.
    """
    if not schedule:
        return ""
    for limit, text in schedule:
        if int(limit) >= step:
            return text
    final_entry = schedule[-1]
    return final_entry[1]
|
|
|
|
|
def at_step(step: int, prompt_or_schedule, *, steps: int | None = None, |
|
|
seed: int | None = 42, use_visitor: bool = True) -> str: |
|
|
if isinstance(prompt_or_schedule, list) and prompt_or_schedule and isinstance(prompt_or_schedule[0], list): |
|
|
return at_step_from_schedule(step, prompt_or_schedule) |
|
|
|
|
|
prompt = str(prompt_or_schedule) |
|
|
if steps is None: |
|
|
raise ValueError("steps is required when passing a prompt string to at_step(...)") |
|
|
sched = get_schedule(prompt, steps, True, seed, use_visitor) |
|
|
return at_step_from_schedule(step, sched) |
|
|
|
|
|
@lru_cache(maxsize=CACHE_SIZE)
def get_schedule(prompt: str, steps: int, use_scheduling: bool, seed: int | None, use_visitor: bool = True):
    """Parse ``prompt`` and return its schedule as ``[[end_step, text], ...]``.

    A prompt that fails to parse is logged and returned unchanged as
    ``[[steps, prompt]]``. With ``use_visitor`` the ``CollectSteps`` schedule is
    returned directly. Otherwise each of its end steps is re-rendered with
    ``ScheduleTransformer``. Results are memoised by ``lru_cache``; the
    returned list object is shared between calls, so callers should not
    mutate it.
    """
    try:
        parsed = schedule_parser.parse(prompt)
    except lark.exceptions.LarkError as err:
        logger.warning("Prompt parse failed: '%s' — %s", prompt, err)
        return [[steps, prompt]]

    visitor_schedule = CollectSteps(steps, use_scheduling=use_scheduling, seed=seed)(parsed)
    if use_visitor:
        return visitor_schedule

    return [
        [end, ScheduleTransformer(steps, end, seed).transform(parsed)]
        for end, _ in visitor_schedule
    ]
|
|
|
|
|
def analyze_prompt(input_prompt, steps, seed, use_visitor, output_mode, specific_step):
    """Parse and schedule a prompt for the analyzer UI.

    Returns ``(result, tree_str, schedule_str, error_msg)``, matching the four
    Gradio outputs. On a parse or scheduling failure the fields not yet
    computed stay empty and ``error_msg`` carries the message and traceback.
    ``result`` depends on ``output_mode``: the full schedule, the parse tree,
    or the text at ``specific_step``.
    """
    result = ""
    tree_str = ""
    schedule_str = ""
    error_msg = ""

    try:
        parsed = schedule_parser.parse(input_prompt)
        tree_str = parsed.pretty()
    except lark.exceptions.LarkError as e:
        error_msg = f"Parse Error: {str(e)}\n\nPossible causes:\n- Unbalanced brackets or invalid syntax.\n- Missing colons, commas, or incorrect operators.\n- Suggestion: Validate grammar rules for scheduled, alternate, grouped, sequences, etc.\n- Ensure no invalid characters in plain text.\nTraceback: {traceback.format_exc()}"
        return result, tree_str, schedule_str, error_msg

    try:
        schedule = get_schedule(input_prompt, steps, True, seed, use_visitor)
        schedule_lines = [f"Up to step {end}: {text}" for end, text in schedule]
        schedule_str = "\n".join(schedule_lines)
    except Exception as e:
        error_msg = f"Schedule Error: {str(e)}\nTraceback: {traceback.format_exc()}"
        return result, tree_str, schedule_str, error_msg

    if output_mode == "Full Schedule":
        result = schedule_str
    elif output_mode == "Parse Tree":
        result = tree_str
    elif output_mode == "At Specific Step":
        result = at_step_from_schedule(specific_step, schedule)

    return result, tree_str, schedule_str, error_msg
|
|
|
|
|
|
|
|
def convert_to_grouped(input_prompt):
    """Wrap ``input_prompt`` in braces to form a grouped prompt."""
    return "{{{}}}".format(input_prompt)
|
|
|
|
|
def convert_to_sequence(input_prompt):
    """Turn "owner, desc1, desc2" into the sequence form "owner::desc1, desc2!".

    Parts are split on "," and stripped. A prompt without any comma gets a
    placeholder descriptor appended to the untouched input.
    """
    owner, *descriptors = [piece.strip() for piece in input_prompt.split(',')]
    if not descriptors:
        return input_prompt + "::descriptor!"
    return f"{owner}::{', '.join(descriptors)}!"
|
|
|
|
|
def convert_to_numbered(input_prompt, num, distinct):
    """Build the numbered form "N{mark} {input_prompt}".

    ``mark`` is "!" when ``distinct`` is truthy, otherwise empty. An integral
    float ``num`` (e.g. 3.0 from a numeric widget) is rendered as an integer,
    because the grammar's count token is an INT and "3.0" would not parse.
    """
    if isinstance(num, float) and num.is_integer():
        num = int(num)
    mark = "!" if distinct else ""
    return f"{num}{mark} {{{input_prompt}}}"
|
|
|
|
|
class PromptParserAnalyzerScript(scripts.Script):
    """Always-visible WebUI script exposing the prompt analyzer and format converters."""

    def title(self):
        return "Prompt Parser Analyzer"

    def show(self, is_txt2img):
        return scripts.AlwaysVisible

    def ui(self, is_txt2img):
        """Build the analyzer and converter tabs and wire their buttons.

        Returns the analyzer input components, the button and the result box.
        """
        with gr.Group():
            with gr.Accordion("Prompt Parser Analyzer", open=False):
                with gr.Tab("Analyze Prompt"):
                    input_prompt = gr.Textbox(label="Input Prompt", lines=5)
                    steps = gr.Number(label="Steps", value=100, precision=0)
                    seed = gr.Number(label="Seed (for random elements)", value=42, precision=0)
                    use_visitor = gr.Checkbox(label="Use Visitor Mode (faster for simple prompts)", value=True)
                    output_mode = gr.Dropdown(choices=["Full Schedule", "Parse Tree", "At Specific Step"], label="Output Mode", value="Full Schedule")
                    specific_step = gr.Number(label="Specific Step (for At Specific Step mode)", value=50, precision=0)
                    analyze_btn = gr.Button("Analyze Prompt")
                    output_result = gr.Textbox(label="Analysis Result", lines=10)
                    output_tree = gr.Textbox(label="Parse Tree (Debug)", lines=5)
                    output_schedule = gr.Textbox(label="Full Schedule (Debug)", lines=5)
                    error_output = gr.Textbox(label="Errors & Suggestions", lines=5)

                    analyze_btn.click(
                        fn=analyze_prompt,
                        inputs=[input_prompt, steps, seed, use_visitor, output_mode, specific_step],
                        outputs=[output_result, output_tree, output_schedule, error_output]
                    )

                with gr.Tab("Format Converters"):
                    gr.HTML(value="<p>These helpers convert plain prompts to advanced formats. Read the explanations for usage.</p>")

                    with gr.Accordion("Grouped {} - For combinations of attributes"):
                        gr.HTML(value="<p>Wraps items in {}, generates combinations if | (alternates) inside. <br>Example: {red|blue, car|bike} resolves to 'red, car', 'red, bike', 'blue, car', 'blue, bike'. <br>Use for multiple independent attributes like colors, objects. Limit combos with GROUP_COMBO_LIMIT env var.</p>")
                        conv_group_input = gr.Textbox(label="Input for Grouped (e.g., red|blue, car|bike)", lines=2)
                        conv_group_output = gr.Textbox(label="Converted Grouped Prompt", lines=2)
                        conv_group_btn = gr.Button("Convert to Grouped")
                        conv_group_btn.click(convert_to_grouped, inputs=conv_group_input, outputs=conv_group_output)

                    with gr.Accordion("Sequence :: - For structured descriptions"):
                        gr.HTML(value="<p>Structures as owner::descriptors!. Nested with :: and ~ for random or ! for close. <br>Example: character::hair:ponytail, eyes:blue! -> 'character: hair ponytail, eyes blue'. <br>Use for hierarchical entities like characters, outfits. Close with ! or ;. Top-level with !!! for groups.</p>")
                        conv_seq_input = gr.Textbox(label="Input for Sequence (e.g., character, hair:ponytail, eyes:blue)", lines=2)
                        conv_seq_output = gr.Textbox(label="Converted Sequence Prompt", lines=2)
                        conv_seq_btn = gr.Button("Convert to Sequence")
                        conv_seq_btn.click(convert_to_sequence, inputs=conv_seq_input, outputs=conv_seq_output)

                    with gr.Accordion("Numbered N! or N_ - For repeats"):
                        # Both numbered implementations pad with repeats in distinct mode
                        # too (after every option has been used once), so the help says so.
                        gr.HTML(value="<p>N{group} repeats N times, ! or _ for distinct (no repeats until every option has been used). <br>Example: 3! {a|b|c} -> 'a, b, c' (unique sample). <br>Use for multiples like '3 cats'. If N > options, distinct mode uses each option once and then pads with repeats.</p>")
                        conv_num_input = gr.Textbox(label="Input for Numbered (e.g., a|b|c)", lines=2)
                        conv_num_num = gr.Number(label="Number (N)", value=3, precision=0)
                        conv_num_distinct = gr.Checkbox(label="Distinct (! or _)", value=True)
                        conv_num_output = gr.Textbox(label="Converted Numbered Prompt", lines=2)
                        conv_num_btn = gr.Button("Convert to Numbered")
                        conv_num_btn.click(convert_to_numbered, inputs=[conv_num_input, conv_num_num, conv_num_distinct], outputs=conv_num_output)

        return [input_prompt, steps, seed, use_visitor, output_mode, specific_step, analyze_btn, output_result]
|
|
|
|
|
def on_ui_settings():
    """Settings-page callback for this extension.

    No options are registered yet, so the callback currently does nothing.
    """
    # NOTE(review): the previous body only built an unused
    # ("prompt-parser-analyzer", "Prompt Parser Analyzer") section tuple.
    # Reintroduce it together with the option registrations when settings are added.


script_callbacks.on_ui_settings(on_ui_settings)