# Provenance (captured from the Hugging Face file view):
#   uploader: dikdimon
#   commit message: "Upload stable-diffusion-webui-prompt-parser8 using SD-Hub"
#   commit: 6d53c19 (verified)
#   file size: 45.1 kB
"""
Prompt Refactor Plus — вкладка для A1111 (Stable Diffusion WebUI).
Почему раньше таб не показывался:
- Нельзя вызывать gradio.launch() внутри скрипта расширения.
- Скрипт должен регистрировать вкладку через script_callbacks.on_ui_tabs.
- В A1111 падает dataclasses со строковыми аннотациями, поэтому НЕ используем
`from __future__ import annotations`.
Что сделано:
- Перенос в on_ui_tabs() + регистрация callback.
- Безопасная загрузка внешнего парсера из PROMPT_PARSER_PATH. Ошибка не ломает UI.
- Базовые параметры в явном виде. Остальные спрятаны в «Advanced».
- Анализ промпта, QuickFix, Wrap Wizard, Timeline Wizard, Token Counter.
- Дополнения: Curve Editor, Prob-Alternate Wizard, Test Generator — в Advanced.
- Кнопка Reload parser для горячей подгрузки парсера после правки файла.
Совместимость:
- Python 3.10.x (как в A1111).
- Gradio 3.x, поставляемый с WebUI.
- Внешние пакеты (transformers/open_clip) опциональны: есть «approx» fallback.
Файл предназначен для пути: extensions/<your-ext>/scripts/prompt-refactor.py
"""
import os
import re
import csv
import json
import math
import importlib
import importlib.util
import importlib.machinery
from dataclasses import dataclass
from typing import List, Tuple, Optional, Any, Dict
# A1111 hooks
# The import is optional: outside the WebUI the `modules` package is absent.
try:
    from modules import script_callbacks
except Exception as _e:
    script_callbacks = None  # lets the file be imported locally as an ordinary module
# ---------------------------------------------------------
# Конфиг: путь до внешнего парсера
# ---------------------------------------------------------
# Path of the external parser source file; the PROMPT_PARSER_PATH environment
# variable overrides the built-in default.
PROMPT_PARSER_PATH = os.environ.get("PROMPT_PARSER_PATH", "/content/A1111/modules/prompt_parser.py")
# Module name passed to import_from_path() when the parser is loaded.
PP_MOD_NAME = "pp_prompt_parser"
# Loaded parser module, or None while no parser is available.
_pp = None
# Message describing the last load failure ('' when the load succeeded).
_pp_error = ""
def import_from_path(mod_name: str, path: str):
    """Load the Python source file at ``path`` as a module named ``mod_name``.

    The new module is placed in ``sys.modules`` before its body runs, as the
    importlib "importing a source file directly" recipe prescribes. Code that
    looks its own module up by name while it is being executed (for example
    ``dataclasses`` resolving string annotations) needs that entry to exist.

    Args:
        mod_name: Name given to the module (also its ``sys.modules`` key).
        path: Filesystem path of the ``.py`` file to execute.

    Returns:
        The freshly executed module object.

    Raises:
        Whatever the loader or the module body raises. In that case the
        previous ``sys.modules`` entry for ``mod_name`` (if any) is restored.
    """
    import sys

    loader = importlib.machinery.SourceFileLoader(mod_name, path)
    spec = importlib.util.spec_from_loader(loader.name, loader)
    mod = importlib.util.module_from_spec(spec)

    missing = object()
    previous = sys.modules.get(mod_name, missing)
    sys.modules[mod_name] = mod
    try:
        loader.exec_module(mod)
    except BaseException:
        # Do not leave a half-initialised module visible to other importers.
        if previous is missing:
            sys.modules.pop(mod_name, None)
        else:
            sys.modules[mod_name] = previous
        raise
    return mod
def load_parser_safely() -> Tuple[Optional[Any], str]:
    """Try to load the external parser without ever raising.

    Returns:
        ``(module, "")`` on success, or ``(None, message)`` when the file is
        missing, fails to import, or lacks the ``schedule_parser`` /
        ``resolve_tree`` attributes.
    """
    import sys as _sys

    if not os.path.isfile(PROMPT_PARSER_PATH):
        return None, f"Parser file not found: {PROMPT_PARSER_PATH}"

    try:
        # Drop any earlier entry so the file is executed again from disk.
        _sys.modules.pop(PP_MOD_NAME, None)
        parser_mod = import_from_path(PP_MOD_NAME, PROMPT_PARSER_PATH)
        # Quick API check: attribute access raises if the parser is incomplete.
        for required in ("schedule_parser", "resolve_tree"):
            getattr(parser_mod, required)
    except Exception as exc:
        return None, f"{type(exc).__name__}: {exc}"
    return parser_mod, ""
# Load once at import time; a failure is stored in _pp_error rather than raised.
_pp, _pp_error = load_parser_safely()
# ---------------------------------------------------------
# ENV и горячая перезагрузка
# ---------------------------------------------------------
class EnvState:
    """Mirror of the four parser-related environment variables.

    Each attribute holds the string value last seen or written for its
    variable, so ``apply`` can tell whether a requested setting differs.
    """

    def __init__(self):
        read = os.environ.get
        self.allow_empty_alt = read("ALLOW_EMPTY_ALTERNATE", "0")
        self.expand_alt_per_step = read("EXPAND_ALTERNATE_PER_STEP", "1")
        self.group_combo_limit = read("GROUP_COMBO_LIMIT", "100")
        self.suppress_standalone_colon = read("SUPPRESS_STANDALONE_COLON", "1")

    def apply(self, allow_empty_alt: bool, expand_alt: bool, combo_limit: int, suppress_colon: bool) -> bool:
        """Write the requested settings to ``os.environ``.

        Only values that differ from the remembered ones are written.

        Returns:
            True when at least one variable was changed.
        """
        # All target values are computed first, so a bad ``combo_limit``
        # raises before anything is modified.
        wanted = (
            ("ALLOW_EMPTY_ALTERNATE", "allow_empty_alt", "1" if allow_empty_alt else "0"),
            ("EXPAND_ALTERNATE_PER_STEP", "expand_alt_per_step", "1" if expand_alt else "0"),
            ("GROUP_COMBO_LIMIT", "group_combo_limit", str(int(combo_limit))),
            ("SUPPRESS_STANDALONE_COLON", "suppress_standalone_colon", "1" if suppress_colon else "0"),
        )
        changed = False
        for env_key, attr_name, value in wanted:
            if getattr(self, attr_name) == value:
                continue
            os.environ[env_key] = value
            setattr(self, attr_name, value)
            changed = True
        return changed


# Shared instance used by reload_parser_after_env().
ENV_STATE = EnvState()
def reload_parser_after_env(allow_empty_alt: bool, expand_alt: bool, combo_limit: int, suppress_colon: bool):
    """Push the settings into the environment and reload the parser if any changed."""
    global _pp, _pp_error
    env_changed = ENV_STATE.apply(allow_empty_alt, expand_alt, combo_limit, suppress_colon)
    if not env_changed:
        return
    _pp, _pp_error = load_parser_safely()
def reload_parser_manual() -> str:
    """Reload the parser on demand (button handler).

    Returns:
        The load error message, or ``"OK"`` when the load succeeded.
    """
    global _pp, _pp_error
    _pp, _pp_error = load_parser_safely()
    if _pp_error:
        return _pp_error
    return "OK"
# ---------------------------------------------------------
# Препроцессор: repeat и prob-alternates
# ---------------------------------------------------------
def _scan_balanced_parens(s: str, start_idx: int) -> Optional[Tuple[int, str]]:
if start_idx >= len(s) or s[start_idx] != '(':
return None
depth = 0
i = start_idx
buf: List[str] = []
while i < len(s):
ch = s[i]
if ch == '(':
depth += 1
if depth > 1: buf.append(ch)
elif ch == ')':
depth -= 1
if depth == 0: return i + 1, "".join(buf)
else: buf.append(ch)
else:
buf.append(ch)
i += 1
return None
def _pp_repeat_macro(src: str) -> Tuple[str, List[str]]:
    """Expand ``repeat N ( body )`` into ``body`` repeated N times, comma-joined.

    A count below 1 is treated as 1. A ``repeat N (`` whose parenthesis is
    never closed is copied through unchanged.

    Returns:
        ``(expanded_text, change_notes)``.
    """
    macro_rx = re.compile(r"\brepeat\s+(\d+)\s*\(", re.IGNORECASE)
    pieces: List[str] = []
    notes: List[str] = []
    pos = 0
    length = len(src)
    while pos < length:
        match = macro_rx.search(src, pos)
        if match is None:
            pieces.append(src[pos:])
            break
        pieces.append(src[pos:match.start()])
        count = max(1, int(match.group(1)))
        # The regex ends on the opening parenthesis.
        span = _scan_balanced_parens(src, match.end() - 1)
        if not span:
            pieces.append(match.group(0))
            pos = match.end()
            continue
        pos, inner = span
        pieces.append(", ".join([inner.strip()] * count))
        notes.append(f"repeat {count} (...) → {count} повторов")
    return "".join(pieces), notes
def _split_top_level_bar(block: str) -> List[str]:
parts: List[str] = []; buf: List[str] = []
d_round = d_square = d_curly = 0; i = 0
while i < len(block):
ch = block[i]
if ch == '\\':
if i + 1 < len(block): buf.append(block[i:i+2]); i += 2; continue
if ch == '(': d_round += 1
elif ch == ')': d_round = max(0, d_round - 1)
elif ch == '[': d_square += 1
elif ch == ']': d_square = max(0, d_square - 1)
elif ch == '{': d_curly += 1
elif ch == '}': d_curly = max(0, d_curly - 1)
elif ch == '|' and d_round == d_square == d_curly == 0:
parts.append("".join(buf)); buf = []; i += 1; continue
buf.append(ch); i += 1
parts.append("".join(buf))
return [p.strip() for p in parts]
def _pp_prob_alt(src: str, total_steps: int) -> Tuple[str, List[str]]:
    """Rewrite weighted alternates ``a{p} | b{q}`` as a ``[ a : b ] : n`` block.

    Boundaries are the cumulative weight shares scaled to ``total_steps``;
    a boundary that rounds to 0 or to ``total_steps`` and beyond is left out.
    A block is kept as is when a part has no trailing ``{weight}`` or the
    weights do not sum to a positive number.

    Returns:
        ``(rewritten_text, change_notes)``.
    """
    block_rx = re.compile(r"((?:[^\n\[\]]+?\{\s*\d*\.?\d+\s*\}\s*\|\s*)+[^\n\[\]]+?\{\s*\d*\.?\d+\s*\})")
    weight_rx = re.compile(r"\{\s*([+\-]?\d*\.?\d+(?:[eE][+\-]?\d+)?)\s*\}\s*$")
    strip_rx = re.compile(r"\{\s*[+\-]?\d*\.?\d+(?:[eE][+\-]?\d+)?\s*\}\s*$")
    notes: List[str] = []

    def to_schedule(block: str) -> str:
        weighted: List[Tuple[str, float]] = []
        weight_sum = 0.0
        for part in _split_top_level_bar(block):
            found = weight_rx.search(part)
            if found is None:
                return block
            weight = float(found.group(1))
            weighted.append((strip_rx.sub("", part).strip(), weight))
            weight_sum += weight
        if weight_sum <= 0:
            return block

        cuts: List[str] = []
        running = 0.0
        # The last item needs no boundary of its own.
        for _, weight in weighted[:-1]:
            running += weight / weight_sum
            cut = int(round(running * total_steps))
            if 0 < cut < total_steps:
                cuts.append(str(cut))

        core = " : ".join([label for label, _ in weighted])
        tail = f" : {', '.join(cuts)}" if cuts else ""
        return f"[ {core} ]{tail}"

    def on_match(match):
        original = match.group(1)
        converted = to_schedule(original)
        if converted != original:
            notes.append("prob a{p}|b{q} → scheduled")
        return converted

    return block_rx.sub(on_match, src), notes
def preprocess(text: str, total_steps: int = 20) -> Tuple[str, List[str]]:
    """Run the repeat-macro pass, then the weighted-alternate pass.

    Returns:
        ``(processed_text, change_notes)`` with the notes of both passes in
        execution order.
    """
    expanded, repeat_notes = _pp_repeat_macro(text)
    scheduled, prob_notes = _pp_prob_alt(expanded, total_steps)
    return scheduled, repeat_notes + prob_notes
# ---------------------------------------------------------
# Аналитика
# ---------------------------------------------------------
@dataclass
class AnalysisResult:
    """Result record returned by ``analyze_prompt``.

    Fields are filled stage by stage; stages that did not run or failed
    leave their field as ``None``.
    """
    ok: bool  # True only when every stage completed
    error: Optional[str]  # failure message, None on success
    parse_tree: Optional[Any]  # value returned by the parser's schedule_parser.parse
    resolved_text: Optional[str]  # value returned by the parser's resolve_tree
    schedule: Optional[List[Tuple[int, str]]]  # (boundary, text) pairs
    timeline: Optional[List[Tuple[int, str]]]  # (step, text) pairs, one per step
    changes: List[str]  # notes emitted by the preprocessor
def _timeline_from_schedule(schedule: List[Tuple[int, str]], steps: int) -> List[Tuple[int, str]]:
    """Expand ``(boundary, text)`` pairs into one ``(step, text)`` row per step.

    A step uses the text of the first pair whose boundary is not below it;
    steps beyond the last boundary keep the last text. An empty schedule
    yields empty texts.
    """
    ordered = sorted(schedule, key=lambda pair: pair[0])
    last = len(ordered) - 1
    timeline: List[Tuple[int, str]] = []
    idx = 0
    for step in range(1, steps + 1):
        # Move past every segment that ended before this step.
        while idx < last and step > ordered[idx][0]:
            idx += 1
        timeline.append((step, ordered[idx][1] if ordered else ""))
    return timeline


def analyze_prompt(
    text: str,
    steps: int = 20,
    use_preprocessor: bool = True,
    visitor_mode: bool = True,
    seed: Optional[int] = None,
    allow_empty_alt: bool = False,
    expand_alt_per_step: bool = True,
    group_combo_limit: int = 100,
    suppress_standalone_colon: bool = True,
) -> AnalysisResult:
    """Parse ``text`` with the external parser and derive schedule and timeline.

    Stages: apply env settings (reloading the parser when they changed),
    optional preprocessing, parse, resolve, schedule collection, timeline
    expansion. The first failing stage stops the run; results of earlier
    stages are still returned.

    Args:
        text: Prompt source.
        steps: Number of sampling steps to lay the schedule over.
        use_preprocessor: Run ``preprocess`` on the text first.
        visitor_mode: Collect the schedule with the parser's ``CollectSteps``;
            otherwise transform the tree once per step with
            ``ScheduleTransformer``.
        seed: Passed through to the parser classes.
        allow_empty_alt, expand_alt_per_step, group_combo_limit,
        suppress_standalone_colon: Parser settings exported as env variables.

    Returns:
        An ``AnalysisResult``; ``ok`` is False and ``error`` is set on failure.
    """
    if _pp is None:
        return AnalysisResult(False, f"Parser not loaded: {_pp_error}", None, None, None, None, [])
    reload_parser_after_env(allow_empty_alt, expand_alt_per_step, group_combo_limit, suppress_standalone_colon)
    # The env-triggered reload can fail and leave the parser unset.
    parser = _pp
    if parser is None:
        return AnalysisResult(False, f"Parser not loaded: {_pp_error}", None, None, None, None, [])

    changes: List[str] = []
    src = text
    if use_preprocessor:
        src, ch = preprocess(src, steps)
        changes.extend(ch)

    try:
        tree = parser.schedule_parser.parse(src)
    except Exception as e:
        return AnalysisResult(False, f"{e}", None, None, None, None, changes)

    try:
        resolved = parser.resolve_tree(tree, keep_spacing=True)
    except Exception as e:
        return AnalysisResult(False, f"Resolve error: {e}", tree, None, None, None, changes)

    try:
        if visitor_mode:
            collector = parser.CollectSteps(steps, prefix="", suffix="", depth=0, use_scheduling=True, seed=seed)
            schedules = collector.visit(tree) or []
        else:
            schedules = []
            for step in range(1, steps + 1):
                transformer = parser.ScheduleTransformer(total_steps=steps, current_step=step, seed=seed)
                schedules.append((step, transformer.transform(tree)))
        # Unpacking doubles as a shape check: every entry must be a pair.
        schedules = [[b, t] for (b, t) in schedules]
    except Exception as e:
        return AnalysisResult(False, f"Schedule error: {e}", tree, resolved, None, None, changes)

    try:
        schedule = [(int(b), str(t)) for b, t in schedules]
    except Exception as e:
        # The pairs could not be normalised, so no schedule can be reported.
        return AnalysisResult(False, f"Timeline error: {e}", tree, resolved, None, None, changes)

    try:
        timeline = _timeline_from_schedule(schedule, steps)
    except Exception as e:
        return AnalysisResult(False, f"Timeline error: {e}", tree, resolved, schedule, None, changes)

    return AnalysisResult(True, None, tree, resolved, schedule, timeline, changes)
def schedule_to_csv(schedule: List[Tuple[int, str]]) -> str:
    """Serialise ``(boundary, text)`` pairs as CSV with a header row."""
    from io import StringIO

    sink = StringIO()
    writer = csv.writer(sink)
    writer.writerow(["boundary", "text"])
    writer.writerows([boundary, text] for boundary, text in schedule)
    return sink.getvalue()
def timeline_to_csv(timeline: List[Tuple[int, str]]) -> str:
    """Serialise ``(step, text)`` pairs as CSV with a header row."""
    from io import StringIO

    sink = StringIO()
    writer = csv.writer(sink)
    writer.writerow(["step", "text"])
    writer.writerows([step, text] for step, text in timeline)
    return sink.getvalue()
# ---------------------------------------------------------
# QuickFix
# ---------------------------------------------------------
def _close_unbalanced_brackets(src: str) -> str:
opens = {'(': 0, '[': 0, '{': 0}
for ch in src:
if ch in opens: opens[ch] += 1
elif ch == ')': opens['('] = max(0, opens['('] - 1)
elif ch == ']': opens['['] = max(0, opens['['] - 1)
elif ch == '}': opens['{'] = max(0, opens['{'] - 1)
return src + (')' * opens['(']) + (']' * opens['[']) + ('}' * opens['{'])
def _escape_bars_inside_parens(src: str) -> str:
out: List[str] = []; depth = 0; i = 0
while i < len(src):
ch = src[i]
if ch == '\\' and i + 1 < len(src):
out.append(src[i:i+2]); i += 2; continue
if ch == '(':
depth += 1; out.append(ch)
elif ch == ')':
depth = max(0, depth - 1); out.append(ch)
elif ch == '|' and depth > 0:
out.append(r'\|')
else:
out.append(ch)
i += 1
return "".join(out)
def _replace_single_colon_labels_outside_schedules(src: str) -> str:
segments: List[Tuple[bool, str]] = []; inside = 0; buf: List[str] = []
for ch in src:
if ch == '[':
if buf: segments.append((False, "".join(buf))); buf = []
inside += 1; segments.append((True, ch))
elif ch == ']':
segments.append((True, ch)); inside = max(0, inside - 1)
else:
if inside: segments.append((True, ch))
else: buf.append(ch)
if buf: segments.append((False, "".join(buf)))
def repl(text: str) -> str:
pattern = re.compile(r'(?<!:)\b([A-Za-z0-9_]+)\s*:\s*(?!:)\b([A-Za-z0-9_]+)\b')
return pattern.sub(r'\1::\2', text)
out: List[str] = []
for inside_sq, chunk in segments:
out.append(chunk if inside_sq else repl(chunk))
return "".join(out)
def quickfix_suggestions(src: str) -> List[Dict[str, str]]:
    """List the quick fixes that apply to ``src`` as ``{"id", "title"}`` dicts.

    The whitespace normalisation entry is always present and always last.
    """
    bracket_pairs = (('(', ')'), ('[', ']'), ('{', '}'))
    has_unclosed = any(src.count(opener) > src.count(closer) for opener, closer in bracket_pairs)
    colon_mode_on = os.environ.get("SUPPRESS_STANDALONE_COLON", "1") == "1"

    candidates = (
        (has_unclosed, "close_brackets", "Закрыть незакрытые скобки"),
        ('|' in src, "escape_bars_in_parens", "Экранировать | внутри (...)"),
        (colon_mode_on, "labels_double_colon", "word:word → word::word вне [..]"),
        (True, "normalize_ws", "Нормализовать пробелы"),
    )
    return [{"id": fix_id, "title": title} for enabled, fix_id, title in candidates if enabled]
def apply_quickfix(src: str, fix_id: str) -> str:
    """Apply the quick fix named ``fix_id`` to ``src``.

    An unknown ``fix_id`` returns ``src`` unchanged. ``normalize_ws``
    collapses tabs, spaces and line breaks into single spaces and strips
    the ends.
    """
    def normalize_ws() -> str:
        text = src.replace('\r\n', '\n').replace('\r', '\n')
        for pattern in (r'[ \t]+', r' *\n *', r' {2,}'):
            text = re.sub(pattern, ' ', text)
        return text.strip()

    # Lambdas keep the helper lookups lazy: only the selected fix is touched.
    handlers = {
        "close_brackets": lambda: _close_unbalanced_brackets(src),
        "escape_bars_in_parens": lambda: _escape_bars_inside_parens(src),
        "labels_double_colon": lambda: _replace_single_colon_labels_outside_schedules(src),
        "normalize_ws": normalize_ws,
    }
    handler = handlers.get(fix_id)
    return src if handler is None else handler()
# ---------------------------------------------------------
# Wraps / Builders
# ---------------------------------------------------------
# Snippets offered in the editor's template dropdown, keyed by display name.
PALETTE_TEMPLATES: Dict[str, str] = {
    "Emphasis (1.2)": "(text:1.2)",
    "Emphasis range": "(text:start:end)",
    "Square schedule": "[ a : b ] : 20",
    "Triple schedule": "[ a : b : c ] : 20, 40",
    "Alternate": "a | b",
    "Prob alternate": "a{0.5} | b{0.5}",
    "Group label": "owner::label::description",
    "AND": "AND( term1, term2 )",
    "BREAK": "BREAK",
}


def palette_insert(cur_prompt: str, key: str) -> str:
    """Append the template named ``key`` to ``cur_prompt``.

    A ``", "`` separator is inserted unless the prompt is empty or already
    ends with a space, newline or comma. When ``key`` names no template
    (including ``None`` from a cleared dropdown) the prompt is returned
    without a dangling separator.

    Returns:
        The new prompt text (``""`` instead of ``None`` for an empty prompt).
    """
    prompt = cur_prompt or ""
    tpl = PALETTE_TEMPLATES.get(key, "")
    if not tpl:
        return prompt
    sep = "" if (not prompt or prompt.endswith((" ", "\n", ","))) else ", "
    return prompt + sep + tpl
def _split_top_level_comma(block: str) -> List[str]:
parts: List[str] = []; buf: List[str] = []
d_round = d_square = d_curly = 0; i = 0
while i < len(block):
ch = block[i]
if ch == '\\' and i + 1 < len(block):
buf.append(block[i:i+2]); i += 2; continue
if ch == '(': d_round += 1
elif ch == ')': d_round = max(0, d_round - 1)
elif ch == '[': d_square += 1
elif ch == ']': d_square = max(0, d_square - 1)
elif ch == '{': d_curly += 1
elif ch == '}': d_curly = max(0, d_curly - 1)
elif ch == ',' and d_round == d_square == d_curly == 0:
parts.append("".join(buf)); buf = []; i += 1; continue
buf.append(ch); i += 1
parts.append("".join(buf))
return [p.strip() for p in parts if p.strip()]
def segments_from_text(src: str, mode: str = "auto") -> List[str]:
    """Cut ``src`` into non-empty, stripped segments.

    ``"lines"`` splits on line breaks; ``"auto"`` does the same when the text
    spans several lines. Every other case splits on top-level commas.
    """
    stripped = (src or "").strip()
    if not stripped:
        return []
    by_lines = mode == "lines" or (mode == "auto" and '\n' in stripped)
    if not by_lines:
        return _split_top_level_comma(stripped)
    return [line.strip() for line in stripped.splitlines() if line.strip()]
def wrap_sequence(src: str, seg_mode: str = "auto") -> str:
    """Join the segments of ``src`` with ``", "``."""
    segments = segments_from_text(src, seg_mode)
    return ", ".join(segments)
def wrap_top_level2(src: str, steps: int = 20, boundary: Optional[int] = None, seg_mode: str = "auto") -> str:
    """Build a two-segment schedule block ``[ a : b ] : boundary``.

    Only the first two segments are used; a single segment is used for both
    slots. Without an explicit ``boundary`` the midpoint of ``steps`` is
    taken, kept within ``1 .. steps - 1`` (and never below 1).

    Returns:
        The schedule block, or ``""`` when ``src`` holds no segment at all
        (instead of an empty ``[  ]`` block).
    """
    segs = segments_from_text(src, seg_mode)
    if not segs:
        return ""
    if len(segs) < 2:
        segs = segs + segs
    core = " : ".join(segs[:2])
    b = boundary if boundary is not None else max(1, min(steps - 1, round(steps / 2)))
    return f"[ {core} ] : {b}"
def wrap_top_level3(src: str, steps: int = 30, boundaries: Optional[Tuple[int, int]] = None, seg_mode: str = "auto") -> str:
    """Build a three-segment schedule block ``[ a : b : c ] : b1, b2``.

    Missing segments are filled with ``segN`` placeholders. Without explicit
    ``boundaries`` the cuts are placed at one and two thirds of ``steps``,
    with the second cut kept above the first.
    """
    found = segments_from_text(src, seg_mode)
    padded = found + [""] * (3 - len(found))  # a negative count adds nothing
    named = [seg if seg else f"seg{idx + 1}" for idx, seg in enumerate(padded)]
    core = " : ".join(named[:3])

    if boundaries is not None:
        first_cut, second_cut = int(boundaries[0]), int(boundaries[1])
    else:
        first_cut = max(1, round(steps / 3))
        second_cut = max(first_cut + 1, round(2 * steps / 3))
    return f"[ {core} ] : {first_cut}, {second_cut}"
def wrap_alternate(src: str, seg_mode: str = "auto") -> str:
    """Join the segments of ``src`` with ``" | "``."""
    segments = segments_from_text(src, seg_mode)
    return " | ".join(segments)
def wrap_prob_equal(src: str, weight: float = 1.0, seg_mode: str = "auto") -> str:
    """Attach the same ``{weight}`` suffix to every segment and join with ``" | "``.

    Returns ``""`` when ``src`` holds no segment.
    """
    segments = segments_from_text(src, seg_mode)
    if not segments:
        return ""
    suffix = "{" + f"{float(weight):g}" + "}"
    return " | ".join(segment + suffix for segment in segments)
# ---------------------------------------------------------
# Curve Editor / Boundaries
# ---------------------------------------------------------
def _ease_in(t: float, p: float) -> float: return t ** max(1.0, p)
def _ease_out(t: float, p: float) -> float: return 1.0 - (1.0 - t) ** max(1.0, p)
def _ease_in_out(t: float) -> float: return 0.5 * (1 - math.cos(math.pi * t))
def _cosine(t: float) -> float: return (1 - math.cos(math.pi * t)) / 2.0
def _cubic_bezier_y(t: float, x1: float, y1: float, x2: float, y2: float) -> float:
mt = 1 - t
return 3 * mt * mt * t * y1 + 3 * mt * t * t * y2 + t**3
def _curve_f(t: float, kind: str, p: float = 2.0, bezier: Optional[Tuple[float,float,float,float]] = None) -> float:
t = max(0.0, min(1.0, float(t)))
if kind == "linear": return t
if kind == "ease-in": return _ease_in(t, p)
if kind == "ease-out": return _ease_out(t, p)
if kind == "ease-in-out": return _ease_in_out(t)
if kind == "cosine": return _cosine(t)
if kind == "bezier" and bezier:
x1,y1,x2,y2 = bezier
return _cubic_bezier_y(t, x1, y1, x2, y2)
return t
def _inv_curve(u: float, kind: str, steps: int = 60, p: float = 2.0, bezier: Optional[Tuple[float,float,float,float]] = None) -> float:
    """Approximate the ``t`` in ``[0, 1]`` at which the curve reaches ``u``.

    Runs ``steps`` rounds of bisection over ``_curve_f`` and returns the
    midpoint of the final bracket.
    """
    low = 0.0
    high = 1.0
    remaining = steps
    while remaining > 0:
        middle = 0.5 * (low + high)
        if _curve_f(middle, kind, p=p, bezier=bezier) < u:
            low = middle
        else:
            high = middle
        remaining -= 1
    return 0.5 * (low + high)
def curve_boundaries(num_segments: int, total_steps: int, kind: str, p: float = 2.0, bezier: Optional[Tuple[float,float,float,float]] = None) -> List[int]:
    """Place ``num_segments - 1`` step boundaries along an easing curve.

    Boundary ``k`` is the step at which the curve reaches ``k / num_segments``.
    Both arguments are raised to at least 2, boundaries stay within
    ``1 .. total_steps - 1``, and each boundary is pushed above the previous
    one where the upper limit allows.
    """
    segment_count = max(2, int(num_segments))
    step_count = max(2, int(total_steps))
    ceiling = step_count - 1

    result: List[int] = []
    for k in range(1, segment_count):
        t = _inv_curve(k / segment_count, kind=kind, p=p, bezier=bezier)
        boundary = max(1, min(ceiling, int(round(t * step_count))))
        if result and boundary <= result[-1]:
            boundary = min(ceiling, result[-1] + 1)
        result.append(boundary)
    return result
def build_curve_schedule(segments_text: str, num_segments: int, total_steps: int, kind: str,
                         p: float = 2.0, bezier: Optional[Tuple[float,float,float,float]] = None) -> Tuple[str, str]:
    """Build a schedule block whose boundaries follow an easing curve.

    ``segments_text`` supplies one segment per non-blank line; the list is
    padded with ``segN`` placeholders or truncated to ``num_segments``.

    Returns:
        ``(schedule_block, boundaries_csv)``.
    """
    names = [line.strip() for line in (segments_text or "").splitlines() if line.strip()]
    # Pad with placeholders (covers the empty case too), then cut to size.
    names.extend(f"seg{idx + 1}" for idx in range(len(names), num_segments))
    if len(names) > num_segments:
        names = names[:num_segments]

    cuts = curve_boundaries(num_segments, total_steps, kind=kind, p=p, bezier=bezier)
    csv_bounds = ", ".join(str(cut) for cut in cuts)
    block = f"[ {' : '.join(names)} ]"
    if csv_bounds:
        block += f" : {csv_bounds}"
    return block, csv_bounds
# ---------------------------------------------------------
# Token Counter (CLIP) — с fallback'ом
# ---------------------------------------------------------
class TokenCountResult:
    """Plain record holding the outcome of ``count_clip_tokens``."""

    def __init__(self, backend: str, total: int, payload: int, limit: int, overflow: int, trimmed_text: str):
        # Name of the tokenizer backend that produced the numbers.
        self.backend = backend
        # Token counts: everything, and the payload part only.
        self.total, self.payload = total, payload
        # Payload limit applied and how many payload tokens exceed it.
        self.limit, self.overflow = limit, overflow
        # Input text cut down to the limit.
        self.trimmed_text = trimmed_text
def _tokenize_transformers(text: str):
    """Tokenize ``text`` with the Hugging Face CLIP tokenizer.

    The tokenizer object is created on first use and cached on the function,
    so later calls do not load it again.

    Returns:
        ``(input_ids, offsets, "transformers")`` where ``offsets`` holds one
        ``(start, end)`` character span per id.

    Raises:
        Any error from importing ``transformers``, loading the tokenizer or
        encoding; the caller treats that as "backend unavailable".
    """
    tok = getattr(_tokenize_transformers, "_tokenizer", None)
    if tok is None:
        from transformers import AutoTokenizer  # type: ignore
        tok = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14", use_fast=True)
        _tokenize_transformers._tokenizer = tok
    enc = tok(text, add_special_tokens=True, return_offsets_mapping=True)
    ids = enc["input_ids"]
    offsets = [(int(a), int(b)) for (a, b) in enc["offset_mapping"]]
    return ids, offsets, "transformers"
def _tokenize_open_clip(text: str):
    """Tokenize ``text`` with the open_clip ViT-L-14 tokenizer.

    open_clip returns a fixed-length row padded with zeros after the
    end-of-text token; that padding is removed so the id count reflects the
    prompt rather than the context length.

    NOTE(review): open_clip also truncates to its context length, so very
    long prompts cannot report their true size through this backend.

    Returns:
        ``(ids, offsets, "open_clip")``. ``offsets`` are the spans of the
        whitespace-separated words of ``text`` — an approximation, not
        per-token spans.
    """
    import open_clip  # type: ignore
    tokenizer = open_clip.get_tokenizer("ViT-L-14")
    ids = tokenizer(text)
    offs = [(m.start(), m.end()) for m in re.finditer(r"\S+", text)]
    if hasattr(ids, "squeeze"):
        ids = ids.squeeze(0).tolist()
    ids = list(map(int, ids))
    # Trailing zeros can only be padding: real sequences end with end-of-text.
    while ids and ids[-1] == 0:
        ids.pop()
    return ids, offs, "open_clip"
def _tokenize_approx(text: str):
toks = []; offs = []
for m in re.finditer(r"\w+|[^\s\w]", text, flags=re.UNICODE):
toks.append(m.group(0)); offs.append((m.start(), m.end()))
return toks, offs, "approx"
def _trim_to_token_limit(text: str, offsets, payload: int, limit_payload: int, first_payload_idx: int) -> str:
    """Cut ``text`` after the last payload token that fits into the limit.

    ``first_payload_idx`` is the position in ``offsets`` of the first payload
    token (1 when ``offsets`` starts with a begin-of-text entry, else 0).
    """
    if payload <= limit_payload:
        return text
    if limit_payload <= 0:
        return ""
    if not offsets:
        return text
    last_kept = min(first_payload_idx + limit_payload - 1, len(offsets) - 1)
    return text[:offsets[last_kept][1]]


def count_clip_tokens(text: str, limit_payload: int = 77, backend: str = "auto") -> TokenCountResult:
    """Count the tokens of ``text`` and trim it to ``limit_payload`` tokens.

    Backends are tried in the order transformers, open_clip, approx.
    ``"auto"`` tries them all; a named backend tries only itself and then
    falls back to the built-in approximation. For the first two backends the
    payload excludes the two special tokens.

    Args:
        text: Prompt text.
        limit_payload: Maximum number of payload tokens to keep.
        backend: ``"auto"``, ``"transformers"``, ``"open_clip"`` or ``"approx"``.

    Returns:
        A ``TokenCountResult``; ``trimmed_text`` equals ``text`` when nothing
        overflows. With the open_clip backend the cut position is approximate
        because only word spans are available.
    """
    def finish(used: str, total: int, payload: int, trimmed: str) -> TokenCountResult:
        overflow = max(0, payload - limit_payload)
        return TokenCountResult(used, total, payload, limit_payload, overflow, trimmed)

    if backend in ("auto", "transformers"):
        try:
            ids, offsets, used = _tokenize_transformers(text)
            total = len(ids)
            special = 2 if total >= 2 else 0
            payload = total - special
            # offsets[0] belongs to begin-of-text, so payload starts at index 1.
            trimmed = _trim_to_token_limit(text, offsets, payload, limit_payload, 1 if special else 0)
            return finish(used, total, payload, trimmed)
        except Exception:
            # Deliberate best effort: a missing or broken backend selects the next one.
            pass

    if backend in ("auto", "open_clip"):
        try:
            ids, offsets, used = _tokenize_open_clip(text)
            total = len(ids)
            payload = total - (2 if total >= 2 else 0)
            # Word spans carry no entries for special tokens.
            trimmed = _trim_to_token_limit(text, offsets, payload, limit_payload, 0)
            return finish(used, total, payload, trimmed)
        except Exception:
            # Deliberate best effort: fall through to the approximation.
            pass

    toks, offsets, used = _tokenize_approx(text)
    total = len(toks)
    trimmed = _trim_to_token_limit(text, offsets, total, limit_payload, 0)
    return finish(used, total, total, trimmed)
# ---------------------------------------------------------
# Gradio UI для A1111
# ---------------------------------------------------------
def on_ui_tabs():
import gradio as gr
with gr.Blocks(analytics_enabled=False) as demo:
gr.Markdown("## Prompt Refactor Plus")
# --------- PARSER STATUS / CONTROL ----------
with gr.Row():
parser_status = gr.Textbox(label="Parser status", value=_pp_error or "OK", interactive=False)
reload_btn = gr.Button("Reload parser")
def _reload():
return reload_parser_manual()
reload_btn.click(_reload, outputs=[parser_status])
# --------- EDITOR ----------
with gr.Tab("Editor"):
prompt_tb = gr.Textbox(label="Prompt", lines=8, placeholder="Введите промпт...")
with gr.Row():
palette_dd = gr.Dropdown(choices=list(PALETTE_TEMPLATES.keys()), value="Emphasis (1.2)", label="Шаблон")
insert_btn = gr.Button("Вставить шаблон")
insert_btn.click(palette_insert, inputs=[prompt_tb, palette_dd], outputs=[prompt_tb])
# --------- ANALYSIS ----------
with gr.Tab("Analysis"):
with gr.Row():
steps = gr.Slider(1, 200, value=20, step=1, label="Steps")
use_pre = gr.Checkbox(value=True, label="Use Preprocessor")
suppress_colon = gr.Checkbox(value=True, label="SUPPRESS_STANDALONE_COLON")
analyze_btn = gr.Button("Analyze")
with gr.Accordion("Advanced", open=False):
with gr.Row():
visitor = gr.Checkbox(value=True, label="Visitor mode")
seed = gr.Number(value=None, label="Seed", precision=0)
with gr.Row():
allow_empty_alt = gr.Checkbox(value=False, label="ALLOW_EMPTY_ALTERNATE")
expand_alt = gr.Checkbox(value=True, label="EXPAND_ALTERNATE_PER_STEP")
group_limit = gr.Number(value=100, label="GROUP_COMBO_LIMIT", precision=0)
with gr.Tabs():
with gr.Tab("Resolved"):
parse_out = gr.Textbox(label="", lines=4)
changes_out = gr.Textbox(label="Preprocessor Changes", lines=3)
with gr.Tab("Schedule CSV"):
schedule_out = gr.Textbox(label="", lines=8)
with gr.Tab("Timeline CSV"):
timeline_out = gr.Textbox(label="", lines=8)
# QuickFix
gr.Markdown("### QuickFix")
fixes_state = gr.State(value=[])
with gr.Row():
fixes_dd = gr.Dropdown(choices=[], label="Выберите исправление")
apply_fix_btn = gr.Button("Применить")
auto_fix_btn = gr.Button("Auto QuickFix")
def _analyze(prompt, steps, use_pre, visitor, seed, allow_empty_alt, expand_alt, group_limit, suppress_colon):
res = analyze_prompt(
text=prompt or "",
steps=int(steps),
use_preprocessor=bool(use_pre),
visitor_mode=bool(visitor),
seed=None if seed in (None, "") else int(seed),
allow_empty_alt=bool(allow_empty_alt),
expand_alt_per_step=bool(expand_alt),
group_combo_limit=int(group_limit),
suppress_standalone_colon=bool(suppress_colon),
)
sgs = quickfix_suggestions(prompt or "")
return (
res.resolved_text or "",
"\n".join(res.changes),
schedule_to_csv(res.schedule) if res.schedule else "",
timeline_to_csv(res.timeline) if res.timeline else "",
sgs,
gr.update(choices=[s["title"] for s in sgs]),
_pp_error or "OK"
)
analyze_btn.click(
_analyze,
inputs=[prompt_tb, steps, use_pre, visitor, seed, allow_empty_alt, expand_alt, group_limit, suppress_colon],
outputs=[parse_out, changes_out, schedule_out, timeline_out, fixes_state, fixes_dd, parser_status]
)
def _apply_fix_cur(prompt: str, fixes: List[Dict[str, str]], title: str):
m = {f["title"]: f["id"] for f in fixes}
fid = m.get(title)
return apply_quickfix(prompt or "", fid) if fid else prompt
apply_fix_btn.click(_apply_fix_cur, inputs=[prompt_tb, fixes_state, fixes_dd], outputs=[prompt_tb])
def _auto_fix(prompt: str):
sgs = quickfix_suggestions(prompt or "")
if not sgs: return prompt
return apply_quickfix(prompt or "", sgs[0]["id"])
auto_fix_btn.click(_auto_fix, inputs=[prompt_tb], outputs=[prompt_tb])
# --------- WRAP ----------
with gr.Tab("Wrap Wizard"):
src_tb = gr.Textbox(label="Исходный текст", lines=6, placeholder="Один сегмент на строку или через верхнеуровневые запятые")
seg_mode_dd = gr.Radio(choices=["auto", "lines", "comma"], value="auto", label="Разделение")
with gr.Row():
ww_steps = gr.Slider(2, 200, value=20, step=1, label="Steps (для top-level)")
ww_boundary = gr.Number(value=None, label="Boundary (Top-Level 2)", precision=0)
ww_b1 = gr.Number(value=None, label="b1 (Top-Level 3)", precision=0)
ww_b2 = gr.Number(value=None, label="b2 (Top-Level 3)", precision=0)
with gr.Row():
seq_btn = gr.Button("Sequence")
tl2_btn = gr.Button("Top-Level (2)")
tl3_btn = gr.Button("Top-Level (3)")
with gr.Row():
alt_btn = gr.Button("Alternate")
prob_btn = gr.Button("Prob-Equal")
wrap_preview = gr.Textbox(label="Предпросмотр", lines=3)
insert_wrap_btn = gr.Button("Вставить в Prompt")
def on_wrap_seq(src, seg_mode): return wrap_sequence(src or "", seg_mode)
def on_wrap_tl2(src, steps, boundary, seg_mode):
b = None
try:
if boundary not in (None, ""): b = int(boundary)
except Exception: b = None
return wrap_top_level2(src or "", steps=int(steps), boundary=b, seg_mode=seg_mode)
def on_wrap_tl3(src, steps, b1, b2, seg_mode):
bb = None
try:
if b1 not in (None, "") and b2 not in (None, ""): bb = (int(b1), int(b2))
except Exception: bb = None
return wrap_top_level3(src or "", steps=int(steps), boundaries=bb, seg_mode=seg_mode)
def on_wrap_alt(src, seg_mode): return wrap_alternate(src or "", seg_mode)
def on_wrap_prob(src, seg_mode): return wrap_prob_equal(src or "", 1.0, seg_mode)
seq_btn.click(on_wrap_seq, inputs=[src_tb, seg_mode_dd], outputs=[wrap_preview])
tl2_btn.click(on_wrap_tl2, inputs=[src_tb, ww_steps, ww_boundary, seg_mode_dd], outputs=[wrap_preview])
tl3_btn.click(on_wrap_tl3, inputs=[src_tb, ww_steps, ww_b1, ww_b2, seg_mode_dd], outputs=[wrap_preview])
alt_btn.click(on_wrap_alt, inputs=[src_tb, seg_mode_dd], outputs=[wrap_preview])
prob_btn.click(on_wrap_prob, inputs=[src_tb, seg_mode_dd], outputs=[wrap_preview])
def _insert_block(cur_prompt: str, block: str) -> str:
if not block or not block.strip(): return cur_prompt
sep = "" if (not cur_prompt or cur_prompt.endswith((" ", "\n", ","))) else ", "
return (cur_prompt or "") + sep + block
insert_wrap_btn.click(_insert_block, inputs=[prompt_tb, wrap_preview], outputs=[prompt_tb])
# --------- TIMELINE ----------
with gr.Tab("Timeline Wizard"):
seg_tb = gr.Textbox(label="Сегменты (по одному на строку)", lines=4, value="a\nb")
boundaries_tb = gr.Textbox(label="Границы (CSV)", value="20")
build_tl_btn = gr.Button("Построить")
tl_preview = gr.Textbox(label="Предпросмотр", lines=2)
insert_tl_btn = gr.Button("Вставить в Prompt")
def build_schedule_block(segments_text: str, boundaries_csv: str) -> str:
segs = [s.strip() for s in (segments_text or "").splitlines() if s.strip()]
if not segs: return ""
core = " : ".join(segs)
bounds = [b.strip() for b in (boundaries_csv or "").split(",") if b.strip()]
extra = f" : {', '.join(bounds)}" if bounds else ""
return f"[ {core} ]{extra}"
build_tl_btn.click(build_schedule_block, inputs=[seg_tb, boundaries_tb], outputs=[tl_preview])
insert_tl_btn.click(_insert_block, inputs=[prompt_tb, tl_preview], outputs=[prompt_tb])
# --------- TOKEN COUNTER ----------
with gr.Tab("Token Counter"):
payload_limit = gr.Slider(10, 200, value=77, step=1, label="Payload limit")
backend_dd = gr.Dropdown(choices=["auto", "transformers", "open_clip", "approx"], value="auto", label="Backend (Advanced)")
count_btn = gr.Button("Count")
with gr.Row():
backend_used = gr.Textbox(label="Used backend", interactive=False)
total_out = gr.Number(label="Total tokens", interactive=False)
payload_out = gr.Number(label="Payload tokens", interactive=False)
overflow_out = gr.Number(label="Overflow", interactive=False)
trimmed_tb = gr.Textbox(label="Trimmed text", lines=6)
def on_count(prompt, limit_payload, backend):
res = count_clip_tokens(prompt or "", limit_payload=int(limit_payload), backend=backend)
return res.backend, res.total, res.payload, res.overflow, res.trimmed_text
count_btn.click(on_count, inputs=[prompt_tb, payload_limit, backend_dd],
outputs=[backend_used, total_out, payload_out, overflow_out, trimmed_tb])
# --------- ADVANCED (Curve/Prob/Test) ----------
with gr.Tab("Advanced"):
with gr.Accordion("Curve Editor", open=False):
curve_segments_tb = gr.Textbox(label="Сегменты (пусто = seg1..N)", lines=6, value="")
with gr.Row():
curve_num = gr.Slider(2, 8, value=3, step=1, label="N")
curve_steps = gr.Slider(2, 300, value=60, step=1, label="Steps")
curve_kind = gr.Dropdown(choices=["linear", "ease-in", "ease-out", "ease-in-out", "cosine", "bezier"], value="ease-in-out", label="Кривая")
curve_power = gr.Slider(1, 6, value=2, step=0.5, label="Показатель (ease-in/out)")
with gr.Row():
bez_x1 = gr.Slider(0, 1, value=0.3, step=0.05, label="Bezier x1")
bez_y1 = gr.Slider(0, 1, value=0.0, step=0.05, label="Bezier y1")
bez_x2 = gr.Slider(0, 1, value=0.7, step=0.05, label="Bezier x2")
bez_y2 = gr.Slider(0, 1, value=1.0, step=0.05, label="Bezier y2")
build_curve_btn = gr.Button("Build")
curve_bounds_out = gr.Textbox(label="Границы", lines=1)
curve_preview = gr.Textbox(label="Предпросмотр", lines=2)
insert_curve_btn = gr.Button("Вставить в Prompt")
def on_build_curve(segs_text, n, steps, kind, power, x1, y1, x2, y2):
bez = (float(x1), float(y1), float(x2), float(y2)) if kind == "bezier" else None
block, csvb = build_curve_schedule(segs_text or "", int(n), int(steps), str(kind), float(power), bez)
return csvb, block
build_curve_btn.click(on_build_curve,
inputs=[curve_segments_tb, curve_num, curve_steps, curve_kind, curve_power, bez_x1, bez_y1, bez_x2, bez_y2],
outputs=[curve_bounds_out, curve_preview])
insert_curve_btn.click(_insert_block, inputs=[prompt_tb, curve_preview], outputs=[prompt_tb])
with gr.Accordion("Prob-Alternate Wizard", open=False):
steps_pa = gr.Slider(1, 200, value=20, step=1, label="Steps")
table = gr.Dataframe(headers=["Текст", "Вес"], datatype=["str", "number"],
row_count=3, col_count=(2, "fixed"),
value=[["a", 0.5], ["b", 0.5], ["", ""]], label="Варианты")
build_pa_btn = gr.Button("Build")
pa_preview = gr.Textbox(label="Предпросмотр", lines=2)
def build_schedule_from_prob_table(rows, total_steps: int) -> str:
    """Build an alternation block whose switch points follow the row weights.

    Args:
        rows: Table of ``[text, weight]`` rows. May be a list of rows, a
            ``pandas.DataFrame`` (what ``gr.Dataframe`` passes by default),
            a NumPy array, or ``None``. Rows with an empty text cell are
            skipped. Weights that cannot be parsed, are not finite, or are
            negative count as ``0``.
        total_steps: Number of steps the cumulative weights are scaled to.

    Returns:
        ``"[ a : b : ... ]"``, followed by ``" : b1, b2, ..."`` when at least
        one boundary falls strictly between ``0`` and ``total_steps``.
        Returns ``""`` when no usable row exists or the weights do not sum
        to a positive number.
    """
    # Normalise array-like inputs to a plain list of rows. Testing the
    # truthiness of a DataFrame/ndarray (``rows or []``) raises ValueError.
    if rows is None:
        table = []
    elif hasattr(rows, "tolist"):
        table = rows.tolist()
    elif hasattr(getattr(rows, "values", None), "tolist"):
        table = rows.values.tolist()
    else:
        table = rows

    items: List[Tuple[str, float]] = []
    total = 0.0
    for row in table:
        if row is None or len(row) < 2:
            continue
        cell = row[0]
        # Empty cells arrive as "", None or NaN depending on the backend.
        if not cell or (isinstance(cell, float) and math.isnan(cell)):
            continue
        text = str(cell).strip()
        if not text:
            continue
        try:
            weight = float(str(row[1]).strip())
        except (TypeError, ValueError):
            weight = 0.0
        # NaN/inf would poison the running total; negatives are meaningless
        # as probabilities and would produce non-monotonic boundaries.
        if not math.isfinite(weight) or weight < 0:
            weight = 0.0
        items.append((text, weight))
        total += weight

    if not items or total <= 0:
        return ""

    step_count = int(total_steps)
    boundaries: List[str] = []
    cumulative = 0.0
    # The last item needs no boundary: it runs until the final step.
    for _, weight in items[:-1]:
        cumulative += weight / total
        boundary = int(round(cumulative * step_count))
        if 0 < boundary < step_count:
            boundaries.append(str(boundary))

    core = " : ".join(text for text, _ in items)
    extra = f" : {', '.join(boundaries)}" if boundaries else ""
    return f"[ {core} ]{extra}"
build_pa_btn.click(build_schedule_from_prob_table, inputs=[table, steps_pa], outputs=[pa_preview])
insert_pa_btn = gr.Button("Вставить в Prompt")
insert_pa_btn.click(_insert_block, inputs=[prompt_tb, pa_preview], outputs=[prompt_tb])
with gr.Accordion("Test Generator (unit tests)", open=False):
tg_steps = gr.Slider(1, 200, value=20, step=1, label="Steps")
tg_seed = gr.Number(value=None, label="Seed", precision=0)
with gr.Row():
tg_allow_empty = gr.Checkbox(value=False, label="ALLOW_EMPTY_ALTERNATE")
tg_expand_alt = gr.Checkbox(value=True, label="EXPAND_ALTERNATE_PER_STEP")
tg_group_limit = gr.Number(value=100, label="GROUP_COMBO_LIMIT", precision=0)
tg_suppr_colon = gr.Checkbox(value=True, label="SUPPRESS_STANDALONE_COLON")
gen_btn = gr.Button("Generate")
test_code_tb = gr.Textbox(label="test_prompt_parser_generated.py", lines=18)
test_file = gr.File(label="Скачать", interactive=False)
# Render the source text of a standalone unittest module that pins the values
# recorded in `analysis` for `prompt_src`.
#   prompt_src: prompt text, embedded via repr() into both generated tests.
#   steps:      step count, embedded as int(steps) into the schedule test.
#   env:        name -> value pairs emitted as os.environ assignments in setUpClass.
#   seed:       embedded via repr() as the `seed=` argument of the generated CollectSteps call.
#   analysis:   source of the expected values (`resolved_text` and `schedule`).
# Returns the generated module source as a string; nothing is written to disk here.
def generate_unittests_code(prompt_src: str, steps: int, env: Dict[str, str], seed: Optional[int], analysis: AnalysisResult) -> str:
# Missing analysis fields fall back to an empty string / empty list.
resolved = analysis.resolved_text or ""
schedule = analysis.schedule or []
# Keep only the first n_head and last n_tail pairs of a long list; a list no
# longer than n_head + n_tail is returned unchanged.
def slice_pairs(pairs: List[Tuple[int,str]], n_head: int = 3, n_tail: int = 3):
if len(pairs) <= n_head + n_tail: return pairs
return pairs[:n_head] + pairs[-n_tail:]
sched_sample = slice_pairs(schedule)
# One `os.environ[...] = ...` statement per env entry, joined with newlines.
ENV_LINES = "\n".join([f" os.environ[{repr(k)}] = {repr(v)}" for k, v in env.items()])
# The template below is an f-string: every {...} is evaluated here, and all other
# text (including its whitespace) becomes the generated module verbatim.
# When ENV_LINES is empty a `pass` line is emitted so setUpClass keeps a body.
# The generated schedule test samples the first 3 + last 3 entries when there are
# more than 6, which matches slice_pairs' defaults used for `sched_sample`.
# NOTE(review): the generated code relies on pp.schedule_parser, pp.resolve_tree and
# pp.CollectSteps existing in the parser module with these signatures — confirm.
code = f'''\
import os, sys, unittest, importlib.machinery, importlib.util
PROMPT_PARSER_PATH = os.environ.get("PROMPT_PARSER_PATH", {repr(PROMPT_PARSER_PATH)})
def import_from_path(mod_name, path):
if mod_name in sys.modules: del sys.modules[mod_name]
loader = importlib.machinery.SourceFileLoader(mod_name, path)
spec = importlib.util.spec_from_loader(loader.name, loader)
mod = importlib.util.module_from_spec(spec); loader.exec_module(mod); return mod
class TestPromptParserGenerated(unittest.TestCase):
@classmethod
def setUpClass(cls):
{ENV_LINES and ENV_LINES or " pass"}
cls.pp = import_from_path("pp", PROMPT_PARSER_PATH)
def test_parse_and_resolve(self):
pp = self.pp; src = {repr(prompt_src)}
tree = pp.schedule_parser.parse(src)
resolved = pp.resolve_tree(tree, keep_spacing=True)
self.assertEqual(resolved, {repr(resolved)})
def test_schedule_sample(self):
pp = self.pp; src = {repr(prompt_src)}; steps = {int(steps)}
collector = pp.CollectSteps(steps, prefix="", suffix="", depth=0, use_scheduling=True, seed={repr(seed)})
schedules = collector.visit(tree := pp.schedule_parser.parse(src)) or []
got = [(int(b), str(t)) for b, t in schedules]
sample_got = (got[:3] + got[-3:]) if len(got) > 6 else got
self.assertEqual(sample_got, {repr(sched_sample)})
if __name__ == "__main__": unittest.main()
'''
return code
def on_gen(prompt, steps, seed, ae, ex, gl, sc):
    """Gradio callback for the Test Generator "Generate" button.

    Reloads the parser with the selected flags, analyses the prompt, renders
    a unittest module from the analysis and saves it so it can be downloaded.

    Args:
        prompt: Prompt text from the main textbox (``None`` is treated as "").
        steps: Step count from the slider.
        seed: Optional seed; ``None`` or "" means "no seed".
        ae: ALLOW_EMPTY_ALTERNATE checkbox.
        ex: EXPAND_ALTERNATE_PER_STEP checkbox.
        gl: GROUP_COMBO_LIMIT number; an empty field falls back to 100,
            the default shown in the UI.
        sc: SUPPRESS_STANDALONE_COLON checkbox.

    Returns:
        ``(code, path)`` where ``path`` is the written file, or
        ``(code, None)`` when the file could not be written. The generated
        code is always returned so it can still be copied from the textbox.
    """
    # Function-scope import: keeps this callback self-contained.
    import tempfile

    prompt_text = prompt or ""
    step_count = int(steps)
    seed_value = None if seed in (None, "") else int(seed)
    # A cleared Number field arrives as None; int(None) would raise.
    combo_limit = 100 if gl in (None, "") else int(gl)
    allow_empty = bool(ae)
    expand_alt = bool(ex)
    suppress_colon = bool(sc)

    reload_parser_after_env(allow_empty, expand_alt, combo_limit, suppress_colon)
    res = analyze_prompt(
        text=prompt_text, steps=step_count, use_preprocessor=True, visitor_mode=True,
        seed=seed_value,
        allow_empty_alt=allow_empty, expand_alt_per_step=expand_alt,
        group_combo_limit=combo_limit, suppress_standalone_colon=suppress_colon,
    )
    env = {
        "ALLOW_EMPTY_ALTERNATE": "1" if ae else "0",
        "EXPAND_ALTERNATE_PER_STEP": "1" if ex else "0",
        "GROUP_COMBO_LIMIT": str(combo_limit),
        "SUPPRESS_STANDALONE_COLON": "1" if sc else "0",
    }
    code = generate_unittests_code(prompt_text, step_count, env, seed_value, res)

    # The previous hard-coded "/mnt/data/..." directory does not exist on most
    # installs, so the download was silently never offered. The system temp
    # directory is writable everywhere.
    path = os.path.join(tempfile.gettempdir(), "test_prompt_parser_generated.py")
    try:
        with open(path, "w", encoding="utf-8") as f:
            f.write(code)
    except (OSError, UnicodeError):
        # Best effort: the code is still shown even if it cannot be saved.
        return code, None
    return code, path
gen_btn.click(on_gen,
inputs=[prompt_tb, tg_steps, tg_seed, tg_allow_empty, tg_expand_alt, tg_group_limit, tg_suppr_colon],
outputs=[test_code_tb, test_file])
# Вернуть как вкладку WebUI
return [(demo, "Prompt Refactor", "prompt_refactor_tab")]
# Hook the tab into the WebUI. When the module is imported outside A1111 the
# `modules.script_callbacks` import fails and `script_callbacks` is None, so
# registration is skipped.
if script_callbacks is not None:
    script_callbacks.on_ui_tabs(on_ui_tabs)