"""Prompt Refactor Plus — a tab for A1111 (Stable Diffusion WebUI).

Why the tab did not show up before:
- gradio.launch() must not be called inside an extension script.
- The script has to register its tab through script_callbacks.on_ui_tabs.
- In A1111, dataclasses break on string annotations, so this file does NOT
  use `from __future__ import annotations`.

What was done:
- The UI lives in on_ui_tabs() and is registered as a callback.
- The external parser is loaded safely from PROMPT_PARSER_PATH; a load error
  does not break the UI.
- Basic parameters are shown explicitly; the rest are hidden under "Advanced".
- Prompt analysis, QuickFix, Wrap Wizard, Timeline Wizard, Token Counter.
- Extras in "Advanced": Curve Editor, Prob-Alternate Wizard, Test Generator.
- A "Reload parser" button hot-reloads the parser after its file is edited.

Compatibility:
- Python 3.10.x (as shipped with A1111).
- Gradio 3.x bundled with the WebUI.
- External packages (transformers/open_clip) are optional: an "approx"
  fallback is used when they are unavailable.

Intended location: extensions/EXTENSION_NAME/scripts/prompt-refactor.py
"""
import os
import re
import csv
import json
import math
import sys
import importlib
import importlib.util
import importlib.machinery
from dataclasses import dataclass
from typing import List, Tuple, Optional, Any, Dict

# A1111 hooks
try:
    from modules import script_callbacks
except Exception:
    script_callbacks = None  # lets the file be imported as a plain module outside the WebUI

# ---------------------------------------------------------
# Config: path to the external parser
# ---------------------------------------------------------
PROMPT_PARSER_PATH = os.environ.get("PROMPT_PARSER_PATH", "/content/A1111/modules/prompt_parser.py")
PP_MOD_NAME = "pp_prompt_parser"

_pp = None
_pp_error = ""


def import_from_path(mod_name: str, path: str):
    """Execute the Python source file at `path` as a module named `mod_name`.

    The module is registered in sys.modules *before* its code runs, as the
    importlib "import a source file directly" recipe does: dataclasses (with
    string annotations) and typing.get_type_hints look the defining module up
    in sys.modules while the module body is still executing. If execution
    fails, the half-initialised module is removed again and the error is
    re-raised.
    """
    loader = importlib.machinery.SourceFileLoader(mod_name, path)
    spec = importlib.util.spec_from_loader(loader.name, loader)
    mod = importlib.util.module_from_spec(spec)
    sys.modules[mod_name] = mod
    try:
        loader.exec_module(mod)
    except BaseException:
        sys.modules.pop(mod_name, None)
        raise
    return mod


def load_parser_safely() -> Tuple[Optional[Any], str]:
    """Load the external parser from PROMPT_PARSER_PATH without raising.

    Returns:
        (parser_module, "") on success, or (None, error_message) when the file
        is missing, fails to execute, or lacks `schedule_parser` /
        `resolve_tree`.
    """
    if not os.path.isfile(PROMPT_PARSER_PATH):
        return None, f"Parser file not found: {PROMPT_PARSER_PATH}"
    # Drop any previous copy so the file is re-executed (hot reload).
    sys.modules.pop(PP_MOD_NAME, None)
    try:
        mod = import_from_path(PP_MOD_NAME, PROMPT_PARSER_PATH)
        # Quick API check: both names are used unconditionally by the analyzer.
        getattr(mod, "schedule_parser")
        getattr(mod, "resolve_tree")
        return mod, ""
    except Exception as e:
        return None, f"{type(e).__name__}: {e}"


# Load once at import time; a failure is reported in the UI instead of raising.
_pp, _pp_error = load_parser_safely()


# ---------------------------------------------------------
# ENV and hot reload
# ---------------------------------------------------------
class EnvState:
    """Cached copy of the parser-related environment variables.

    Values are stored as the strings written to os.environ ("0"/"1" for the
    flags, a decimal string for the limit), so apply() can detect changes.
    """

    def __init__(self):
        self.allow_empty_alt = os.environ.get("ALLOW_EMPTY_ALTERNATE", "0")
        self.expand_alt_per_step = os.environ.get("EXPAND_ALTERNATE_PER_STEP", "1")
        self.group_combo_limit = os.environ.get("GROUP_COMBO_LIMIT", "100")
        self.suppress_standalone_colon = os.environ.get("SUPPRESS_STANDALONE_COLON", "1")

    def apply(self, allow_empty_alt: bool, expand_alt: bool, combo_limit: int, suppress_colon: bool) -> bool:
        """Write changed values to os.environ and the cache.

        A `combo_limit` that is not convertible to int (e.g. None from an
        emptied gr.Number) keeps the current limit instead of raising.

        Returns:
            True if at least one variable changed.
        """
        try:
            new_limit = str(int(combo_limit))
        except (TypeError, ValueError, OverflowError):
            new_limit = self.group_combo_limit

        updates = (
            ("allow_empty_alt", "ALLOW_EMPTY_ALTERNATE", "1" if allow_empty_alt else "0"),
            ("expand_alt_per_step", "EXPAND_ALTERNATE_PER_STEP", "1" if expand_alt else "0"),
            ("group_combo_limit", "GROUP_COMBO_LIMIT", new_limit),
            ("suppress_standalone_colon", "SUPPRESS_STANDALONE_COLON", "1" if suppress_colon else "0"),
        )
        changed = False
        for attr, env_key, new_value in updates:
            if getattr(self, attr) != new_value:
                os.environ[env_key] = new_value
                setattr(self, attr, new_value)
                changed = True
        return changed
ENV_STATE = EnvState()


def reload_parser_after_env(allow_empty_alt: bool, expand_alt: bool, combo_limit: int, suppress_colon: bool):
    """Apply the ENV flags and reload the parser only if one of them changed.

    NOTE(review): presumably the external parser reads these variables when
    its module executes, which is why a change triggers a reload — confirm.
    """
    global _pp, _pp_error
    if ENV_STATE.apply(allow_empty_alt, expand_alt, combo_limit, suppress_colon):
        _pp, _pp_error = load_parser_safely()


def reload_parser_manual() -> str:
    """Reload the parser unconditionally ("Reload parser" button).

    Returns:
        "OK" on success, otherwise the load error message.
    """
    global _pp, _pp_error
    _pp, _pp_error = load_parser_safely()
    return _pp_error or "OK"


# ---------------------------------------------------------
# Preprocessor: repeat and prob-alternates
# ---------------------------------------------------------
_REPEAT_RX = re.compile(r"\brepeat\s+(\d+)\s*\(", re.IGNORECASE)

# Trailing "{weight}" of one alternative (sign and exponent allowed).
_PROB_WEIGHT_TAIL_RX = re.compile(r"\{\s*([+\-]?\d*\.?\d+(?:[eE][+\-]?\d+)?)\s*\}\s*$")

# One alternative's text. It may not cross commas, brackets, braces or bars,
# so "x, cat{0.3} | dog{0.7}" only rewrites "cat{0.3} | dog{0.7}" instead of
# swallowing the preceding "x," into the first alternative.
_PROB_SEGMENT = r"[^\n\[\](){}|,]+?"
_PROB_WEIGHT = r"\{\s*\d*\.?\d+\s*\}"
_PROB_BLOCK_RX = re.compile(
    "((?:" + _PROB_SEGMENT + _PROB_WEIGHT + r"\s*\|\s*)+" + _PROB_SEGMENT + _PROB_WEIGHT + ")"
)


def _scan_balanced_parens(s: str, start_idx: int) -> Optional[Tuple[int, str]]:
    """Scan a balanced "(...)" group that starts at s[start_idx].

    Returns:
        (index just past the matching ")", inner text without the outer
        parentheses), or None if s[start_idx] is not "(" or the group never
        closes. Backslash escapes are not interpreted.
    """
    if start_idx >= len(s) or s[start_idx] != '(':
        return None
    depth = 0
    inner: List[str] = []
    for pos in range(start_idx, len(s)):
        ch = s[pos]
        if ch == '(':
            depth += 1
            if depth > 1:
                inner.append(ch)
        elif ch == ')':
            depth -= 1
            if depth == 0:
                return pos + 1, "".join(inner)
            inner.append(ch)
        else:
            inner.append(ch)
    return None


def _pp_repeat_macro(src: str) -> Tuple[str, List[str]]:
    """Expand "repeat N (body)" into N comma-separated copies of body.

    N below 1 is treated as 1. An unclosed group is left verbatim.

    Returns:
        (rewritten text, list of change descriptions)
    """
    changes: List[str] = []
    out: List[str] = []
    i = 0
    while i < len(src):
        m = _REPEAT_RX.search(src, i)
        if not m:
            out.append(src[i:])
            break
        out.append(src[i:m.start()])
        n = max(1, int(m.group(1)))
        scanned = _scan_balanced_parens(src, m.end() - 1)  # m.end()-1 is the "("
        if not scanned:
            out.append(src[m.start():m.end()])
            i = m.end()
            continue
        end_idx, inner = scanned
        out.append(", ".join([inner.strip()] * n))
        changes.append(f"repeat {n} (...) → {n} повторов")
        i = end_idx
    return "".join(out), changes


def _split_top_level_bar(block: str) -> List[str]:
    """Split on "|" that is outside (), [] and {}; pieces are stripped.

    A backslash escapes the following character (kept as-is in the piece).
    """
    parts: List[str] = []
    buf: List[str] = []
    d_round = d_square = d_curly = 0
    i = 0
    while i < len(block):
        ch = block[i]
        if ch == '\\' and i + 1 < len(block):
            buf.append(block[i:i + 2])
            i += 2
            continue
        if ch == '(':
            d_round += 1
        elif ch == ')':
            d_round = max(0, d_round - 1)
        elif ch == '[':
            d_square += 1
        elif ch == ']':
            d_square = max(0, d_square - 1)
        elif ch == '{':
            d_curly += 1
        elif ch == '}':
            d_curly = max(0, d_curly - 1)
        elif ch == '|' and d_round == d_square == d_curly == 0:
            parts.append("".join(buf))
            buf = []
            i += 1
            continue
        buf.append(ch)
        i += 1
    parts.append("".join(buf))
    return [p.strip() for p in parts]


def _pp_prob_alt(src: str, total_steps: int) -> Tuple[str, List[str]]:
    """Rewrite "a{p} | b{q} | ..." into a scheduled "[ a : b : ... ] : b1, ..." block.

    Each alternative gets a share of `total_steps` proportional to its weight.
    Boundaries that round to 0 or to `total_steps` are dropped.

    Returns:
        (rewritten text, list of change descriptions)
    """
    changes: List[str] = []

    def convert_block(block: str) -> str:
        items: List[Tuple[str, float]] = []
        total = 0.0
        for part in _split_top_level_bar(block):
            m = _PROB_WEIGHT_TAIL_RX.search(part)
            if not m:
                return block
            w = float(m.group(1))
            items.append((part[:m.start()].strip(), w))
            total += w
        if total <= 0:
            return block
        boundaries: List[str] = []
        acc = 0.0
        for text_w in items[:-1]:
            acc += text_w[1] / total
            b = int(round(acc * total_steps))
            if 0 < b < total_steps:
                boundaries.append(str(b))
        core = " : ".join(t for t, _ in items)
        extra = f" : {', '.join(boundaries)}" if boundaries else ""
        return f"[ {core} ]{extra}"

    def repl(m):
        orig = m.group(1)
        body = orig.lstrip()
        lead = orig[:len(orig) - len(body)]  # keep the separator before the block
        out = convert_block(body)
        if out != body:
            changes.append("prob a{p}|b{q} → scheduled")
        return lead + out

    return _PROB_BLOCK_RX.sub(repl, src), changes


def preprocess(text: str, total_steps: int = 20) -> Tuple[str, List[str]]:
    """Run the repeat macro, then the prob-alternate rewrite.

    Returns:
        (preprocessed text, combined change list)
    """
    out1, ch1 = _pp_repeat_macro(text)
    out2, ch2 = _pp_prob_alt(out1, total_steps)
    return out2, ch1 + ch2
# ---------------------------------------------------------
# Analysis
# ---------------------------------------------------------
@dataclass
class AnalysisResult:
    """Result of analyze_prompt().

    Fields belonging to a stage that was not reached are None. `changes`
    lists the preprocessor rewrites applied before parsing.
    """
    ok: bool
    error: Optional[str]
    parse_tree: Optional[Any]
    resolved_text: Optional[str]
    schedule: Optional[List[Tuple[int, str]]]
    timeline: Optional[List[Tuple[int, str]]]
    changes: List[str]


def _normalize_schedule(raw: Any) -> List[Tuple[int, str]]:
    """Keep 2-item (boundary, text) entries, cast them and sort by boundary."""
    norm = [
        (int(item[0]), str(item[1]))
        for item in raw
        if isinstance(item, (list, tuple)) and len(item) == 2
    ]
    norm.sort(key=lambda pair: pair[0])
    return norm


def _build_timeline(schedule: List[Tuple[int, str]], steps: int) -> List[Tuple[int, str]]:
    """Expand sorted (boundary, text) entries into one (step, text) row per step.

    An entry covers every step up to and including its boundary, so step s
    uses the first entry whose boundary >= s. This matches the per-step mode
    of analyze_prompt, where entry (s, text) is the text at step s. Steps past
    the last boundary reuse the last text; with no entries the text is "".
    """
    timeline: List[Tuple[int, str]] = []
    idx = 0
    for step in range(1, steps + 1):
        while idx < len(schedule) and schedule[idx][0] < step:
            idx += 1
        if idx < len(schedule):
            cur = schedule[idx][1]
        elif schedule:
            cur = schedule[-1][1]
        else:
            cur = ""
        timeline.append((step, cur))
    return timeline


def analyze_prompt(
    text: str,
    steps: int = 20,
    use_preprocessor: bool = True,
    visitor_mode: bool = True,
    seed: Optional[int] = None,
    allow_empty_alt: bool = False,
    expand_alt_per_step: bool = True,
    group_combo_limit: int = 100,
    suppress_standalone_colon: bool = True,
) -> AnalysisResult:
    """Preprocess, parse, resolve and schedule `text` with the external parser.

    The ENV flags are applied first, because a changed flag reloads the parser
    and may replace (or clear) the loaded module.

    Args:
        visitor_mode: True uses the parser's CollectSteps visitor; False
            evaluates ScheduleTransformer once per step (1..steps).

    Returns:
        An AnalysisResult. Errors are reported in it, never raised.
    """
    reload_parser_after_env(allow_empty_alt, expand_alt_per_step, group_combo_limit, suppress_standalone_colon)
    parser = _pp
    if parser is None:
        return AnalysisResult(False, f"Parser not loaded: {_pp_error}", None, None, None, None, [])

    changes: List[str] = []
    src = text
    if use_preprocessor:
        src, pre_changes = preprocess(src, steps)
        changes.extend(pre_changes)

    try:
        tree = parser.schedule_parser.parse(src)
    except Exception as e:
        return AnalysisResult(False, f"{e}", None, None, None, None, changes)

    try:
        resolved = parser.resolve_tree(tree, keep_spacing=True)
    except Exception as e:
        return AnalysisResult(False, f"Resolve error: {e}", tree, None, None, None, changes)

    try:
        if visitor_mode:
            collector = parser.CollectSteps(steps, prefix="", suffix="", depth=0, use_scheduling=True, seed=seed)
            raw_schedule = collector.visit(tree) or []
        else:
            raw_schedule = []
            for step in range(1, steps + 1):
                transformer = parser.ScheduleTransformer(total_steps=steps, current_step=step, seed=seed)
                raw_schedule.append((step, transformer.transform(tree)))
    except Exception as e:
        return AnalysisResult(False, f"Schedule error: {e}", tree, resolved, None, None, changes)

    try:
        # Malformed entries are skipped; the returned schedule is the
        # normalized one, so no cast can fail outside this try.
        schedule = _normalize_schedule(raw_schedule)
        timeline = _build_timeline(schedule, steps)
    except Exception as e:
        return AnalysisResult(False, f"Timeline error: {e}", tree, resolved, None, None, changes)

    return AnalysisResult(True, None, tree, resolved, schedule, timeline, changes)


def _rows_to_csv(header: List[str], rows: List[Tuple[int, str]]) -> str:
    """Render a header plus (int, str) rows as CSV text."""
    from io import StringIO
    buf = StringIO()
    writer = csv.writer(buf)
    writer.writerow(header)
    writer.writerows(rows)
    return buf.getvalue()


def schedule_to_csv(schedule: List[Tuple[int, str]]) -> str:
    """CSV with columns boundary,text."""
    return _rows_to_csv(["boundary", "text"], schedule)


def timeline_to_csv(timeline: List[Tuple[int, str]]) -> str:
    """CSV with columns step,text."""
    return _rows_to_csv(["step", "text"], timeline)


# ---------------------------------------------------------
# QuickFix
# ---------------------------------------------------------
_OPEN_TO_CLOSE = {'(': ')', '[': ']', '{': '}'}
_CLOSE_TO_OPEN = {')': '(', ']': '[', '}': '{'}

# "word:word" outside schedules, where neither side touches another ":" or a
# word character and neither word starts with a digit (so "(x:1.2)" weights
# and "a::b" labels are left alone). The prefix is captured instead of using
# a lookbehind.
# NOTE(review): the original pattern was lost when the source was damaged at
# this point; this is a reconstruction from the fix title
# "word:word → word::word" — confirm against the intended behaviour.
_LABEL_COLON_RX = re.compile(r"(^|[^:\w])([^\W\d]\w*):([^\W\d]\w*)(?![:\w])")


def _close_unbalanced_brackets(src: str) -> str:
    """Append the closers needed to balance (), [] and {} in nesting order.

    A closer with no matching opener is ignored, and backslash-escaped
    characters are skipped. Missing closers are appended innermost first,
    so "(a [b" becomes "(a [b])".
    """
    stack: List[str] = []
    i = 0
    while i < len(src):
        ch = src[i]
        if ch == '\\':
            i += 2
            continue
        if ch in _OPEN_TO_CLOSE:
            stack.append(ch)
        elif ch in _CLOSE_TO_OPEN:
            opener = _CLOSE_TO_OPEN[ch]
            for j in range(len(stack) - 1, -1, -1):  # innermost matching opener
                if stack[j] == opener:
                    del stack[j]
                    break
        i += 1
    return src + "".join(_OPEN_TO_CLOSE[o] for o in reversed(stack))


def _escape_bars_inside_parens(src: str) -> str:
    """Replace "|" inside (...) with an escaped bar; already escaped chars are kept."""
    out: List[str] = []
    depth = 0
    i = 0
    while i < len(src):
        ch = src[i]
        if ch == '\\' and i + 1 < len(src):
            out.append(src[i:i + 2])
            i += 2
            continue
        if ch == '(':
            depth += 1
            out.append(ch)
        elif ch == ')':
            depth = max(0, depth - 1)
            out.append(ch)
        elif ch == '|' and depth > 0:
            out.append(r'\|')
        else:
            out.append(ch)
        i += 1
    return "".join(out)


def _replace_single_colon_labels_outside_schedules(src: str) -> str:
    """Turn "word:word" into "word::word" everywhere outside [...] blocks.

    Text inside square brackets (schedules) is copied untouched. Pending
    outside text is flushed at every bracket, so an unmatched "]" no longer
    reorders the text around it.
    """
    out: List[str] = []
    pending: List[str] = []
    depth = 0

    def flush() -> None:
        if pending:
            out.append(_LABEL_COLON_RX.sub(r"\1\2::\3", "".join(pending)))
            pending.clear()

    for ch in src:
        if ch == '[':
            flush()
            depth += 1
            out.append(ch)
        elif ch == ']':
            flush()
            depth = max(0, depth - 1)
            out.append(ch)
        elif depth:
            out.append(ch)
        else:
            pending.append(ch)
    flush()
    return "".join(out)


def _normalize_whitespace(src: str) -> str:
    """Unify newlines, collapse runs of spaces/tabs and join lines with a space."""
    s = src.replace('\r\n', '\n').replace('\r', '\n')
    s = re.sub(r'[ \t]+', ' ', s)
    s = re.sub(r' *\n *', ' ', s)
    s = re.sub(r' {2,}', ' ', s)
    return s.strip()


def quickfix_suggestions(src: str) -> List[Dict[str, str]]:
    """List the QuickFix entries ({"id", "title"}) applicable to `src`.

    "close_brackets" is offered exactly when the fixer would change the text.
    """
    sgs: List[Dict[str, str]] = []
    if _close_unbalanced_brackets(src) != src:
        sgs.append({"id": "close_brackets", "title": "Закрыть незакрытые скобки"})
    if '|' in src:
        sgs.append({"id": "escape_bars_in_parens", "title": "Экранировать | внутри (...)"})
    if os.environ.get("SUPPRESS_STANDALONE_COLON", "1") == "1":
        sgs.append({"id": "labels_double_colon", "title": "word:word → word::word вне [..]"})
    sgs.append({"id": "normalize_ws", "title": "Нормализовать пробелы"})
    return sgs


def apply_quickfix(src: str, fix_id: str) -> str:
    """Apply the fix named `fix_id`; unknown ids return `src` unchanged."""
    fixers = {
        "close_brackets": _close_unbalanced_brackets,
        "escape_bars_in_parens": _escape_bars_inside_parens,
        "labels_double_colon": _replace_single_colon_labels_outside_schedules,
        "normalize_ws": _normalize_whitespace,
    }
    fixer = fixers.get(fix_id)
    return fixer(src) if fixer else src


# ---------------------------------------------------------
# Wraps / Builders
# ---------------------------------------------------------
PALETTE_TEMPLATES: Dict[str, str] = {
    "Emphasis (1.2)": "(text:1.2)",
    "Emphasis range": "(text:start:end)",
    "Square schedule": "[ a : b ] : 20",
    "Triple schedule": "[ a : b : c ] : 20, 40",
    "Alternate": "a | b",
    "Prob alternate": "a{0.5} | b{0.5}",
    "Group label": "owner::label::description",
    "AND": "AND( term1, term2 )",
    "BREAK": "BREAK",
}


def palette_insert(cur_prompt: str, key: str) -> str:
    """Append the template for `key`, separated by ", " when needed.

    An unknown key leaves the prompt unchanged instead of adding a stray ", ".
    """
    tpl = PALETTE_TEMPLATES.get(key, "")
    if not tpl:
        return cur_prompt or ""
    sep = "" if (not cur_prompt or cur_prompt.endswith((" ", "\n", ","))) else ", "
    return (cur_prompt or "") + sep + tpl


def _split_top_level_comma(block: str) -> List[str]:
    """Split on commas outside (), [] and {}; empty pieces are dropped."""
    parts: List[str] = []
    buf: List[str] = []
    d_round = d_square = d_curly = 0
    i = 0
    while i < len(block):
        ch = block[i]
        if ch == '\\' and i + 1 < len(block):
            buf.append(block[i:i + 2])
            i += 2
            continue
        if ch == '(':
            d_round += 1
        elif ch == ')':
            d_round = max(0, d_round - 1)
        elif ch == '[':
            d_square += 1
        elif ch == ']':
            d_square = max(0, d_square - 1)
        elif ch == '{':
            d_curly += 1
        elif ch == '}':
            d_curly = max(0, d_curly - 1)
        elif ch == ',' and d_round == d_square == d_curly == 0:
            parts.append("".join(buf))
            buf = []
            i += 1
            continue
        buf.append(ch)
        i += 1
    parts.append("".join(buf))
    return [p.strip() for p in parts if p.strip()]


def segments_from_text(src: str, mode: str = "auto") -> List[str]:
    """Split `src` into segments.

    "lines" splits by line. "auto" splits by line when the text has a newline,
    otherwise by top-level comma. "comma" and any other mode split by
    top-level comma.
    """
    s = (src or "").strip()
    if not s:
        return []
    if mode == "lines" or (mode == "auto" and '\n' in s):
        return [ln.strip() for ln in s.splitlines() if ln.strip()]
    return _split_top_level_comma(s)


def wrap_sequence(src: str, seg_mode: str = "auto") -> str:
    """Join the segments with ", "."""
    return ", ".join(segments_from_text(src, seg_mode))


def wrap_top_level2(src: str, steps: int = 20, boundary: Optional[int] = None, seg_mode: str = "auto") -> str:
    """Build "[ a : b ] : boundary" from the first two segments.

    A single segment is duplicated. Empty input yields "". The default
    boundary is steps/2, clamped to [1, steps-1].
    """
    segs = segments_from_text(src, seg_mode)
    if not segs:
        return ""
    if len(segs) < 2:
        segs = segs + segs
    core = " : ".join(segs[:2])
    b = boundary if boundary is not None else max(1, min(steps - 1, round(steps / 2)))
    return f"[ {core} ] : {b}"


def wrap_top_level3(src: str, steps: int = 30, boundaries: Optional[Tuple[int, int]] = None, seg_mode: str = "auto") -> str:
    """Build "[ a : b : c ] : b1, b2" from the first three segments.

    Missing segments become "segN" placeholders. The default boundaries split
    `steps` into thirds.
    """
    segs = segments_from_text(src, seg_mode)
    if len(segs) < 3:
        segs = segs + ([""] * (3 - len(segs)))
    segs = [s or f"seg{i + 1}" for i, s in enumerate(segs)]
    core = " : ".join(segs[:3])
    if boundaries is None:
        b1 = max(1, round(steps / 3))
        b2 = max(b1 + 1, round(2 * steps / 3))
    else:
        b1, b2 = int(boundaries[0]), int(boundaries[1])
    return f"[ {core} ] : {b1}, {b2}"


def wrap_alternate(src: str, seg_mode: str = "auto") -> str:
    """Join the segments with " | "."""
    return " | ".join(segments_from_text(src, seg_mode))


def wrap_prob_equal(src: str, weight: float = 1.0, seg_mode: str = "auto") -> str:
    """Build "a{w} | b{w} | ..." with the same weight for every segment."""
    segs = segments_from_text(src, seg_mode)
    if not segs:
        return ""
    w = f"{float(weight):g}"
    return " | ".join(f"{s}{{{w}}}" for s in segs)


# ---------------------------------------------------------
# Curve Editor / Boundaries
# ---------------------------------------------------------
def _ease_in(t: float, p: float) -> float:
    return t ** max(1.0, p)


def _ease_out(t: float, p: float) -> float:
    return 1.0 - (1.0 - t) ** max(1.0, p)


def _ease_in_out(t: float) -> float:
    return 0.5 * (1 - math.cos(math.pi * t))


def _cosine(t: float) -> float:
    return (1 - math.cos(math.pi * t)) / 2.0


def _bezier_1d(s: float, p1: float, p2: float) -> float:
    """One coordinate of a cubic Bezier with P0=0 and P3=1 at parameter s."""
    ms = 1.0 - s
    return 3 * ms * ms * s * p1 + 3 * ms * s * s * p2 + s ** 3


def _cubic_bezier_y(t: float, x1: float, y1: float, x2: float, y2: float) -> float:
    """CSS-style cubic-bezier(x1, y1, x2, y2) evaluated at progress `t`.

    `t` is the x coordinate. The curve parameter s with X(s) == t is found by
    bisection, then Y(s) is returned. x1 and x2 are clamped to [0, 1], which
    keeps X monotonic. Previously t was used as the parameter directly, so
    x1 and x2 had no effect.
    """
    x1 = min(1.0, max(0.0, x1))
    x2 = min(1.0, max(0.0, x2))
    lo, hi = 0.0, 1.0
    for _ in range(50):
        mid = 0.5 * (lo + hi)
        if _bezier_1d(mid, x1, x2) < t:
            lo = mid
        else:
            hi = mid
    return _bezier_1d(0.5 * (lo + hi), y1, y2)


def _curve_f(t: float, kind: str, p: float = 2.0, bezier: Optional[Tuple[float, float, float, float]] = None) -> float:
    """Easing curve `kind` at t (clamped to [0, 1]); unknown kinds are linear."""
    t = max(0.0, min(1.0, float(t)))
    if kind == "linear":
        return t
    if kind == "ease-in":
        return _ease_in(t, p)
    if kind == "ease-out":
        return _ease_out(t, p)
    if kind == "ease-in-out":
        return _ease_in_out(t)
    if kind == "cosine":
        return _cosine(t)
    if kind == "bezier" and bezier:
        x1, y1, x2, y2 = bezier
        return _cubic_bezier_y(t, x1, y1, x2, y2)
    return t


def _inv_curve(u: float, kind: str, steps: int = 60, p: float = 2.0, bezier: Optional[Tuple[float, float, float, float]] = None) -> float:
    """Find t in [0, 1] with curve(t) ~= u by bisection (`steps` iterations)."""
    lo, hi = 0.0, 1.0
    for _ in range(steps):
        mid = 0.5 * (lo + hi)
        if _curve_f(mid, kind, p=p, bezier=bezier) < u:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def curve_boundaries(num_segments: int, total_steps: int, kind: str, p: float = 2.0, bezier: Optional[Tuple[float, float, float, float]] = None) -> List[int]:
    """Step boundaries where the curve crosses i/N, for i = 1..N-1.

    Each boundary is clamped to [1, total_steps-1] and pushed up to stay
    strictly increasing where possible.
    """
    num_segments = max(2, int(num_segments))
    total_steps = max(2, int(total_steps))
    bounds: List[int] = []
    for i in range(1, num_segments):
        t = _inv_curve(i / num_segments, kind=kind, p=p, bezier=bezier)
        b = max(1, min(total_steps - 1, int(round(t * total_steps))))
        if bounds and b <= bounds[-1]:
            b = min(total_steps - 1, bounds[-1] + 1)
        bounds.append(b)
    return bounds


def build_curve_schedule(segments_text: str, num_segments: int, total_steps: int, kind: str, p: float = 2.0, bezier: Optional[Tuple[float, float, float, float]] = None) -> Tuple[str, str]:
    """Build a "[ s1 : ... : sN ] : b1, ..." block from curve boundaries.

    Segments come one per line. Missing segments become "segN" placeholders
    and extra ones are dropped. N is clamped to >= 2, as in curve_boundaries,
    so segment and boundary counts stay consistent.

    Returns:
        (block, comma-separated boundaries)
    """
    num_segments = max(2, int(num_segments))
    segs = [s.strip() for s in (segments_text or "").splitlines() if s.strip()]
    if len(segs) < num_segments:
        segs.extend(f"seg{i + 1}" for i in range(len(segs), num_segments))
    segs = segs[:num_segments]
    bounds = curve_boundaries(num_segments, total_steps, kind=kind, p=p, bezier=bezier)
    core = " : ".join(segs)
    csv_bounds = ", ".join(str(b) for b in bounds)
    block = f"[ {core} ]" + (f" : {csv_bounds}" if csv_bounds else "")
    return block, csv_bounds


# ---------------------------------------------------------
# Token Counter (CLIP) — with fallback
# ---------------------------------------------------------
class TokenCountResult:
    """Token count summary.

    `payload` excludes the two special tokens when the backend adds them.
    `trimmed_text` is `text` cut to `limit` payload tokens (approximate for
    non-transformers backends).
    """

    def __init__(self, backend: str, total: int, payload: int, limit: int, overflow: int, trimmed_text: str):
        self.backend = backend
        self.total = total
        self.payload = payload
        self.limit = limit
        self.overflow = overflow
        self.trimmed_text = trimmed_text


# Tokenizers are expensive to construct; build each one once per process.
_TOKENIZER_CACHE: Dict[str, Any] = {}


def _tokenize_transformers(text: str):
    """Tokenize with the HF CLIP tokenizer. Returns (ids, char offsets, backend)."""
    tok = _TOKENIZER_CACHE.get("transformers")
    if tok is None:
        from transformers import AutoTokenizer  # type: ignore
        tok = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14", use_fast=True)
        _TOKENIZER_CACHE["transformers"] = tok
    enc = tok(text, add_special_tokens=True, return_offsets_mapping=True)
    ids = enc["input_ids"]
    offsets = [(int(a), int(b)) for (a, b) in enc["offset_mapping"]]
    return ids, offsets, "transformers"


def _tokenize_open_clip(text: str):
    """Tokenize with open_clip. Returns (ids, word offsets, backend).

    open_clip returns a zero-padded, context-length tensor, so trailing zero
    padding is stripped. The offsets are whitespace-word spans, not token spans.
    NOTE(review): open_clip truncates to its context length, so large
    overflows are presumably under-reported here — confirm.
    """
    tokenizer = _TOKENIZER_CACHE.get("open_clip")
    if tokenizer is None:
        import open_clip  # type: ignore
        tokenizer = open_clip.get_tokenizer("ViT-L-14")
        _TOKENIZER_CACHE["open_clip"] = tokenizer
    ids = tokenizer(text)
    offs = [(m.start(), m.end()) for m in re.finditer(r"\S+", text)]
    if hasattr(ids, "squeeze"):
        ids = ids.squeeze(0).tolist()
    ids = [int(i) for i in ids]
    while ids and ids[-1] == 0:
        ids.pop()
    return ids, offs, "open_clip"


def _tokenize_approx(text: str):
    """Regex fallback: words and single punctuation marks count as tokens."""
    toks: List[str] = []
    offs: List[Tuple[int, int]] = []
    for m in re.finditer(r"\w+|[^\s\w]", text, flags=re.UNICODE):
        toks.append(m.group(0))
        offs.append((m.start(), m.end()))
    return toks, offs, "approx"


def count_clip_tokens(text: str, limit_payload: int = 77, backend: str = "auto") -> TokenCountResult:
    """Count tokens and trim `text` to `limit_payload` payload tokens.

    Tries transformers, then open_clip (per `backend`), and falls back to the
    regex approximation if those are unavailable or fail.
    """
    if backend in ("auto", "transformers"):
        try:
            ids, offsets, used = _tokenize_transformers(text)
            total = len(ids)
            has_specials = total >= 2
            payload = total - 2 if has_specials else total
            keep_payload = min(payload, limit_payload)
            cut_char = None
            if keep_payload < payload and offsets:
                # ids[0] is BOS, so the last kept payload token is ids[keep_payload].
                cut_index = keep_payload if has_specials else keep_payload - 1
                if 0 <= cut_index < len(offsets):
                    cut_char = offsets[cut_index][1]
            trimmed = text[:cut_char] if cut_char is not None else text
            overflow = max(0, payload - limit_payload)
            return TokenCountResult(used, total, payload, limit_payload, overflow, trimmed)
        except Exception:
            pass  # backend unavailable or failed: try the next one

    if backend in ("auto", "open_clip"):
        try:
            ids, offsets, used = _tokenize_open_clip(text)
            total = len(ids)
            payload = max(0, total - 2) if total >= 2 else total
            keep_payload = min(payload, limit_payload)
            # Offsets are words, not tokens: cut after word #keep_payload (approximate).
            cut_index = min(keep_payload - 1, len(offsets) - 1)
            cut_char = offsets[cut_index][1] if keep_payload < payload and cut_index >= 0 else None
            trimmed = text[:cut_char] if cut_char is not None else text
            overflow = max(0, payload - limit_payload)
            return TokenCountResult(used, total, payload, limit_payload, overflow, trimmed)
        except Exception:
            pass  # backend unavailable or failed: use the approximation

    toks, offsets, used = _tokenize_approx(text)
    total = len(toks)
    payload = total
    keep_payload = min(payload, limit_payload)
    cut_index = keep_payload - 1
    cut_char = offsets[cut_index][1] if keep_payload < payload and 0 <= cut_index < len(offsets) else None
    trimmed = text[:cut_char] if cut_char is not None else text
    overflow = max(0, payload - limit_payload)
    return TokenCountResult(used, total, payload, limit_payload, overflow, trimmed)


# ---------------------------------------------------------
# Unit-test generator
# ---------------------------------------------------------
def _generate_unittests_code(prompt_src: str, steps: int, env: Dict[str, str], seed: Optional[int], analysis: AnalysisResult) -> str:
    """Render a standalone unittest module that pins `analysis` for `prompt_src`.

    `prompt_src` must be the exact text that was parsed (i.e. after
    preprocessing). The generated schedule check normalizes entries the same
    way analyze_prompt does (2-item entries, sorted by boundary).
    """
    resolved = analysis.resolved_text or ""
    schedule = analysis.schedule or []
    sched_sample = schedule if len(schedule) <= 6 else schedule[:3] + schedule[-3:]
    env_lines = "\n".join(f"        os.environ[{k!r}] = {v!r}" for k, v in env.items()) or "        pass"
    return f'''\
import os, sys, unittest, importlib.machinery, importlib.util

PROMPT_PARSER_PATH = os.environ.get("PROMPT_PARSER_PATH", {PROMPT_PARSER_PATH!r})


def import_from_path(mod_name, path):
    sys.modules.pop(mod_name, None)
    loader = importlib.machinery.SourceFileLoader(mod_name, path)
    spec = importlib.util.spec_from_loader(loader.name, loader)
    mod = importlib.util.module_from_spec(spec)
    sys.modules[mod_name] = mod  # dataclasses look the module up while it executes
    loader.exec_module(mod)
    return mod


class TestPromptParserGenerated(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
{env_lines}
        cls.pp = import_from_path("pp", PROMPT_PARSER_PATH)

    def test_parse_and_resolve(self):
        pp = self.pp
        src = {prompt_src!r}
        tree = pp.schedule_parser.parse(src)
        resolved = pp.resolve_tree(tree, keep_spacing=True)
        self.assertEqual(resolved, {resolved!r})

    def test_schedule_sample(self):
        pp = self.pp
        src = {prompt_src!r}
        steps = {int(steps)}
        collector = pp.CollectSteps(steps, prefix="", suffix="", depth=0, use_scheduling=True, seed={seed!r})
        tree = pp.schedule_parser.parse(src)
        schedules = collector.visit(tree) or []
        got = sorted(
            [(int(item[0]), str(item[1])) for item in schedules if isinstance(item, (list, tuple)) and len(item) == 2],
            key=lambda pair: pair[0],
        )
        sample_got = (got[:3] + got[-3:]) if len(got) > 6 else got
        self.assertEqual(sample_got, {sched_sample!r})


if __name__ == "__main__":
    unittest.main()
'''


# ---------------------------------------------------------
# Gradio UI for A1111
# ---------------------------------------------------------
def on_ui_tabs():
    """Build the "Prompt Refactor" tab.

    Returns:
        [(blocks, title, elem_id)], the shape passed back to
        script_callbacks.on_ui_tabs.
    """
    import gradio as gr

    def _as_int(value, default):
        """int(value), or `default` for None / "" / non-numeric values (e.g. an emptied gr.Number)."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return default

    def _insert_block(cur_prompt: str, block: str) -> str:
        """Append a built block to the prompt; empty blocks leave it unchanged."""
        if not block or not block.strip():
            return cur_prompt
        sep = "" if (not cur_prompt or cur_prompt.endswith((" ", "\n", ","))) else ", "
        return (cur_prompt or "") + sep + block

    with gr.Blocks(analytics_enabled=False) as demo:
        gr.Markdown("## Prompt Refactor Plus")

        # --------- PARSER STATUS / CONTROL ----------
        with gr.Row():
            parser_status = gr.Textbox(label="Parser status", value=_pp_error or "OK", interactive=False)
            reload_btn = gr.Button("Reload parser")
        reload_btn.click(lambda: reload_parser_manual(), outputs=[parser_status])

        # --------- EDITOR ----------
        with gr.Tab("Editor"):
            prompt_tb = gr.Textbox(label="Prompt", lines=8, placeholder="Введите промпт...")
            with gr.Row():
                palette_dd = gr.Dropdown(choices=list(PALETTE_TEMPLATES.keys()), value="Emphasis (1.2)", label="Шаблон")
                insert_btn = gr.Button("Вставить шаблон")
            insert_btn.click(palette_insert, inputs=[prompt_tb, palette_dd], outputs=[prompt_tb])

        # --------- ANALYSIS ----------
        with gr.Tab("Analysis"):
            with gr.Row():
                steps = gr.Slider(1, 200, value=20, step=1, label="Steps")
                use_pre = gr.Checkbox(value=True, label="Use Preprocessor")
                suppress_colon = gr.Checkbox(value=True, label="SUPPRESS_STANDALONE_COLON")
            analyze_btn = gr.Button("Analyze")
            with gr.Accordion("Advanced", open=False):
                with gr.Row():
                    visitor = gr.Checkbox(value=True, label="Visitor mode")
                    seed = gr.Number(value=None, label="Seed", precision=0)
                with gr.Row():
                    allow_empty_alt = gr.Checkbox(value=False, label="ALLOW_EMPTY_ALTERNATE")
                    expand_alt = gr.Checkbox(value=True, label="EXPAND_ALTERNATE_PER_STEP")
                    group_limit = gr.Number(value=100, label="GROUP_COMBO_LIMIT", precision=0)
            with gr.Tabs():
                with gr.Tab("Resolved"):
                    parse_out = gr.Textbox(label="", lines=4)
                    changes_out = gr.Textbox(label="Preprocessor Changes", lines=3)
                with gr.Tab("Schedule CSV"):
                    schedule_out = gr.Textbox(label="", lines=8)
                with gr.Tab("Timeline CSV"):
                    timeline_out = gr.Textbox(label="", lines=8)

            # QuickFix
            gr.Markdown("### QuickFix")
            fixes_state = gr.State(value=[])
            with gr.Row():
                fixes_dd = gr.Dropdown(choices=[], label="Выберите исправление")
                apply_fix_btn = gr.Button("Применить")
                auto_fix_btn = gr.Button("Auto QuickFix")

            def _analyze(prompt, steps_v, use_pre_v, visitor_v, seed_v, allow_empty_v, expand_v, group_limit_v, suppress_v):
                res = analyze_prompt(
                    text=prompt or "",
                    steps=int(steps_v),
                    use_preprocessor=bool(use_pre_v),
                    visitor_mode=bool(visitor_v),
                    seed=_as_int(seed_v, None),
                    allow_empty_alt=bool(allow_empty_v),
                    expand_alt_per_step=bool(expand_v),
                    group_combo_limit=_as_int(group_limit_v, 100),
                    suppress_standalone_colon=bool(suppress_v),
                )
                sgs = quickfix_suggestions(prompt or "")
                return (
                    res.resolved_text or "",
                    "\n".join(res.changes),
                    schedule_to_csv(res.schedule) if res.schedule else "",
                    timeline_to_csv(res.timeline) if res.timeline else "",
                    sgs,
                    gr.update(choices=[s["title"] for s in sgs]),
                    _pp_error or "OK",  # the parser may have been reloaded
                )

            analyze_btn.click(
                _analyze,
                inputs=[prompt_tb, steps, use_pre, visitor, seed, allow_empty_alt, expand_alt, group_limit, suppress_colon],
                outputs=[parse_out, changes_out, schedule_out, timeline_out, fixes_state, fixes_dd, parser_status],
            )

            def _apply_fix_cur(prompt: str, fixes: List[Dict[str, str]], title: str):
                by_title = {f["title"]: f["id"] for f in (fixes or [])}
                fid = by_title.get(title)
                return apply_quickfix(prompt or "", fid) if fid else prompt

            apply_fix_btn.click(_apply_fix_cur, inputs=[prompt_tb, fixes_state, fixes_dd], outputs=[prompt_tb])

            def _auto_fix(prompt: str):
                sgs = quickfix_suggestions(prompt or "")
                if not sgs:
                    return prompt
                return apply_quickfix(prompt or "", sgs[0]["id"])

            auto_fix_btn.click(_auto_fix, inputs=[prompt_tb], outputs=[prompt_tb])

        # --------- WRAP ----------
        with gr.Tab("Wrap Wizard"):
            src_tb = gr.Textbox(label="Исходный текст", lines=6, placeholder="Один сегмент на строку или через верхнеуровневые запятые")
            seg_mode_dd = gr.Radio(choices=["auto", "lines", "comma"], value="auto", label="Разделение")
            with gr.Row():
                ww_steps = gr.Slider(2, 200, value=20, step=1, label="Steps (для top-level)")
                ww_boundary = gr.Number(value=None, label="Boundary (Top-Level 2)", precision=0)
                ww_b1 = gr.Number(value=None, label="b1 (Top-Level 3)", precision=0)
                ww_b2 = gr.Number(value=None, label="b2 (Top-Level 3)", precision=0)
            with gr.Row():
                seq_btn = gr.Button("Sequence")
                tl2_btn = gr.Button("Top-Level (2)")
                tl3_btn = gr.Button("Top-Level (3)")
            with gr.Row():
                alt_btn = gr.Button("Alternate")
                prob_btn = gr.Button("Prob-Equal")
            wrap_preview = gr.Textbox(label="Предпросмотр", lines=3)
            insert_wrap_btn = gr.Button("Вставить в Prompt")

            def on_wrap_seq(src, seg_mode):
                return wrap_sequence(src or "", seg_mode)

            def on_wrap_tl2(src, steps_v, boundary, seg_mode):
                return wrap_top_level2(src or "", steps=int(steps_v), boundary=_as_int(boundary, None), seg_mode=seg_mode)

            def on_wrap_tl3(src, steps_v, b1, b2, seg_mode):
                v1, v2 = _as_int(b1, None), _as_int(b2, None)
                bb = (v1, v2) if v1 is not None and v2 is not None else None
                return wrap_top_level3(src or "", steps=int(steps_v), boundaries=bb, seg_mode=seg_mode)

            def on_wrap_alt(src, seg_mode):
                return wrap_alternate(src or "", seg_mode)

            def on_wrap_prob(src, seg_mode):
                return wrap_prob_equal(src or "", 1.0, seg_mode)

            seq_btn.click(on_wrap_seq, inputs=[src_tb, seg_mode_dd], outputs=[wrap_preview])
            tl2_btn.click(on_wrap_tl2, inputs=[src_tb, ww_steps, ww_boundary, seg_mode_dd], outputs=[wrap_preview])
            tl3_btn.click(on_wrap_tl3, inputs=[src_tb, ww_steps, ww_b1, ww_b2, seg_mode_dd], outputs=[wrap_preview])
            alt_btn.click(on_wrap_alt, inputs=[src_tb, seg_mode_dd], outputs=[wrap_preview])
            prob_btn.click(on_wrap_prob, inputs=[src_tb, seg_mode_dd], outputs=[wrap_preview])
            insert_wrap_btn.click(_insert_block, inputs=[prompt_tb, wrap_preview], outputs=[prompt_tb])

        # --------- TIMELINE ----------
        with gr.Tab("Timeline Wizard"):
            seg_tb = gr.Textbox(label="Сегменты (по одному на строку)", lines=4, value="a\nb")
            boundaries_tb = gr.Textbox(label="Границы (CSV)", value="20")
            build_tl_btn = gr.Button("Построить")
            tl_preview = gr.Textbox(label="Предпросмотр", lines=2)
            insert_tl_btn = gr.Button("Вставить в Prompt")

            def build_schedule_block(segments_text: str, boundaries_csv: str) -> str:
                segs = [s.strip() for s in (segments_text or "").splitlines() if s.strip()]
                if not segs:
                    return ""
                core = " : ".join(segs)
                bounds = [b.strip() for b in (boundaries_csv or "").split(",") if b.strip()]
                extra = f" : {', '.join(bounds)}" if bounds else ""
                return f"[ {core} ]{extra}"

            build_tl_btn.click(build_schedule_block, inputs=[seg_tb, boundaries_tb], outputs=[tl_preview])
            insert_tl_btn.click(_insert_block, inputs=[prompt_tb, tl_preview], outputs=[prompt_tb])

        # --------- TOKEN COUNTER ----------
        with gr.Tab("Token Counter"):
            payload_limit = gr.Slider(10, 200, value=77, step=1, label="Payload limit")
            backend_dd = gr.Dropdown(choices=["auto", "transformers", "open_clip", "approx"], value="auto", label="Backend (Advanced)")
            count_btn = gr.Button("Count")
            with gr.Row():
                backend_used = gr.Textbox(label="Used backend", interactive=False)
                total_out = gr.Number(label="Total tokens", interactive=False)
                payload_out = gr.Number(label="Payload tokens", interactive=False)
                overflow_out = gr.Number(label="Overflow", interactive=False)
            trimmed_tb = gr.Textbox(label="Trimmed text", lines=6)

            def on_count(prompt, limit_payload, backend):
                res = count_clip_tokens(prompt or "", limit_payload=int(limit_payload), backend=backend)
                return res.backend, res.total, res.payload, res.overflow, res.trimmed_text

            count_btn.click(on_count, inputs=[prompt_tb, payload_limit, backend_dd], outputs=[backend_used, total_out, payload_out, overflow_out, trimmed_tb])

        # --------- ADVANCED (Curve/Prob/Test) ----------
        with gr.Tab("Advanced"):
            with gr.Accordion("Curve Editor", open=False):
                curve_segments_tb = gr.Textbox(label="Сегменты (пусто = seg1..N)", lines=6, value="")
                with gr.Row():
                    curve_num = gr.Slider(2, 8, value=3, step=1, label="N")
                    curve_steps = gr.Slider(2, 300, value=60, step=1, label="Steps")
                    curve_kind = gr.Dropdown(choices=["linear", "ease-in", "ease-out", "ease-in-out", "cosine", "bezier"], value="ease-in-out", label="Кривая")
                    curve_power = gr.Slider(1, 6, value=2, step=0.5, label="Показатель (ease-in/out)")
                with gr.Row():
                    bez_x1 = gr.Slider(0, 1, value=0.3, step=0.05, label="Bezier x1")
                    bez_y1 = gr.Slider(0, 1, value=0.0, step=0.05, label="Bezier y1")
                    bez_x2 = gr.Slider(0, 1, value=0.7, step=0.05, label="Bezier x2")
                    bez_y2 = gr.Slider(0, 1, value=1.0, step=0.05, label="Bezier y2")
                build_curve_btn = gr.Button("Build")
                curve_bounds_out = gr.Textbox(label="Границы", lines=1)
                curve_preview = gr.Textbox(label="Предпросмотр", lines=2)
                insert_curve_btn = gr.Button("Вставить в Prompt")

                def on_build_curve(segs_text, n, steps_v, kind, power, x1, y1, x2, y2):
                    bez = (float(x1), float(y1), float(x2), float(y2)) if kind == "bezier" else None
                    block, csvb = build_curve_schedule(segs_text or "", int(n), int(steps_v), str(kind), float(power), bez)
                    return csvb, block

                build_curve_btn.click(
                    on_build_curve,
                    inputs=[curve_segments_tb, curve_num, curve_steps, curve_kind, curve_power, bez_x1, bez_y1, bez_x2, bez_y2],
                    outputs=[curve_bounds_out, curve_preview],
                )
                insert_curve_btn.click(_insert_block, inputs=[prompt_tb, curve_preview], outputs=[prompt_tb])

            with gr.Accordion("Prob-Alternate Wizard", open=False):
                steps_pa = gr.Slider(1, 200, value=20, step=1, label="Steps")
                # type="array": the handler receives plain lists instead of a
                # pandas DataFrame (the Gradio 3.x default).
                table = gr.Dataframe(
                    headers=["Текст", "Вес"],
                    datatype=["str", "number"],
                    row_count=3,
                    col_count=(2, "fixed"),
                    value=[["a", 0.5], ["b", 0.5], ["", ""]],
                    label="Варианты",
                    type="array",
                )
                build_pa_btn = gr.Button("Build")
                pa_preview = gr.Textbox(label="Предпросмотр", lines=2)

                def build_schedule_from_prob_table(rows, total_steps: int) -> str:
                    # Accept DataFrame/ndarray too, in case the component type differs.
                    if hasattr(rows, "values") and hasattr(rows.values, "tolist"):
                        rows = rows.values.tolist()
                    elif hasattr(rows, "tolist"):
                        rows = rows.tolist()
                    steps_i = int(total_steps)
                    items: List[Tuple[str, float]] = []
                    total = 0.0
                    for r in rows or []:
                        if not r or len(r) < 2:
                            continue
                        cell = r[0]
                        if cell is None or (isinstance(cell, float) and math.isnan(cell)):
                            continue
                        text = str(cell).strip()
                        if not text:
                            continue
                        try:
                            w = float(str(r[1]).strip())
                        except Exception:
                            w = 0.0
                        if not math.isfinite(w):  # empty numeric cells arrive as NaN
                            w = 0.0
                        items.append((text, w))
                        total += w
                    if not items or total <= 0:
                        return ""
                    boundaries: List[str] = []
                    acc = 0.0
                    for _, w in items[:-1]:
                        acc += w / total
                        b = int(round(acc * steps_i))
                        if 0 < b < steps_i:
                            boundaries.append(str(b))
                    core = " : ".join(t for t, _ in items)
                    extra = f" : {', '.join(boundaries)}" if boundaries else ""
                    return f"[ {core} ]{extra}"

                build_pa_btn.click(build_schedule_from_prob_table, inputs=[table, steps_pa], outputs=[pa_preview])
                insert_pa_btn = gr.Button("Вставить в Prompt")
                insert_pa_btn.click(_insert_block, inputs=[prompt_tb, pa_preview], outputs=[prompt_tb])

            with gr.Accordion("Test Generator (unit tests)", open=False):
                tg_steps = gr.Slider(1, 200, value=20, step=1, label="Steps")
                tg_seed = gr.Number(value=None, label="Seed", precision=0)
                with gr.Row():
                    tg_allow_empty = gr.Checkbox(value=False, label="ALLOW_EMPTY_ALTERNATE")
                    tg_expand_alt = gr.Checkbox(value=True, label="EXPAND_ALTERNATE_PER_STEP")
                    tg_group_limit = gr.Number(value=100, label="GROUP_COMBO_LIMIT", precision=0)
                    tg_suppr_colon = gr.Checkbox(value=True, label="SUPPRESS_STANDALONE_COLON")
                gen_btn = gr.Button("Generate")
                test_code_tb = gr.Textbox(label="test_prompt_parser_generated.py", lines=18)
                test_file = gr.File(label="Скачать", interactive=False)

                def on_gen(prompt, steps_v, seed_v, ae, ex, gl, sc):
                    import tempfile
                    steps_i = int(steps_v)
                    seed_i = _as_int(seed_v, None)
                    limit_i = _as_int(gl, 100)
                    res = analyze_prompt(
                        text=prompt or "",
                        steps=steps_i,
                        use_preprocessor=True,
                        visitor_mode=True,
                        seed=seed_i,
                        allow_empty_alt=bool(ae),
                        expand_alt_per_step=bool(ex),
                        group_combo_limit=limit_i,
                        suppress_standalone_colon=bool(sc),
                    )
                    if not res.ok:
                        return f"# Analysis failed: {res.error}\n", None
                    # analyze_prompt parsed the *preprocessed* text, so the
                    # generated test must parse exactly that text too.
                    parsed_src, _ = preprocess(prompt or "", steps_i)
                    env = {
                        "ALLOW_EMPTY_ALTERNATE": "1" if ae else "0",
                        "EXPAND_ALTERNATE_PER_STEP": "1" if ex else "0",
                        "GROUP_COMBO_LIMIT": str(limit_i),
                        "SUPPRESS_STANDALONE_COLON": "1" if sc else "0",
                    }
                    code = _generate_unittests_code(parsed_src, steps_i, env, seed_i, res)
                    path = os.path.join(tempfile.gettempdir(), "test_prompt_parser_generated.py")
                    try:
                        with open(path, "w", encoding="utf-8") as f:
                            f.write(code)
                    except OSError:
                        return code, None
                    return code, path

                gen_btn.click(
                    on_gen,
                    inputs=[prompt_tb, tg_steps, tg_seed, tg_allow_empty, tg_expand_alt, tg_group_limit, tg_suppr_colon],
                    outputs=[test_code_tb, test_file],
                )

    # Return as a WebUI tab
    return [(demo, "Prompt Refactor", "prompt_refactor_tab")]


# Register the tab when running inside A1111
if script_callbacks is not None:
    script_callbacks.on_ui_tabs(on_ui_tabs)