| """ |
| Prompt Refactor Plus — вкладка для A1111 (Stable Diffusion WebUI). |
| |
| Почему раньше таб не показывался: |
| - Нельзя вызывать gradio.launch() внутри скрипта расширения. |
| - Скрипт должен регистрировать вкладку через script_callbacks.on_ui_tabs. |
- В A1111 `dataclasses` падает на строковых аннотациях: скрипты расширений
  загружаются без регистрации в sys.modules, поэтому НЕ используем
  `from __future__ import annotations`.
| |
| Что сделано: |
| - Перенос в on_ui_tabs() + регистрация callback. |
| - Безопасная загрузка внешнего парсера из PROMPT_PARSER_PATH. Ошибка не ломает UI. |
| - Базовые параметры в явном виде. Остальные спрятаны в «Advanced». |
| - Анализ промпта, QuickFix, Wrap Wizard, Timeline Wizard, Token Counter. |
| - Дополнения: Curve Editor, Prob-Alternate Wizard, Test Generator — в Advanced. |
| - Кнопка Reload parser для горячей подгрузки парсера после правки файла. |
| |
| Совместимость: |
| - Python 3.10.x (как в A1111). |
| - Gradio 3.x, поставляемый с WebUI. |
| - Внешние пакеты (transformers/open_clip) опциональны: есть «approx» fallback. |
| |
| Файл предназначен для пути: extensions/<your-ext>/scripts/prompt-refactor.py |
| """ |
|
|
| import os |
| import re |
| import csv |
| import json |
| import math |
| import importlib |
| import importlib.util |
| import importlib.machinery |
| from dataclasses import dataclass |
| from typing import List, Tuple, Optional, Any, Dict |
|
|
| |
| try: |
| from modules import script_callbacks |
| except Exception as _e: |
| script_callbacks = None |
|
|
| |
| |
| |
# Location of the external prompt parser; override with the PROMPT_PARSER_PATH
# environment variable (the default points at a Colab-style A1111 checkout).
PROMPT_PARSER_PATH: str = os.getenv("PROMPT_PARSER_PATH", "/content/A1111/modules/prompt_parser.py")
# Module name under which the external parser is loaded.
PP_MOD_NAME: str = "pp_prompt_parser"
# Currently loaded parser module, or None when it is not (yet) loaded.
_pp: Optional[Any] = None
# Last load error message; "" means no error has been recorded.
_pp_error: str = ""
|
|
def import_from_path(mod_name: str, path: str):
    """Load the Python source file at *path* as a module named *mod_name*.

    The module is registered in ``sys.modules`` before its code runs, as in the
    importlib "importing a source file directly" recipe. Code that looks itself
    up there (for example ``dataclasses`` resolving string annotations) then
    works, and the ``sys.modules`` cleanup in load_parser_safely() has an entry
    to remove. If executing the module raises, the previous ``sys.modules``
    entry (or its absence) is restored and the exception propagates.

    Raises:
        ImportError: if no module spec can be built for *path*.
        Any exception raised while executing the module's code.
    """
    import sys

    loader = importlib.machinery.SourceFileLoader(mod_name, path)
    spec = importlib.util.spec_from_loader(loader.name, loader)
    if spec is None:
        raise ImportError(f"Cannot create a module spec for {path!r}", name=mod_name, path=path)
    mod = importlib.util.module_from_spec(spec)

    previous = sys.modules.get(mod_name)
    sys.modules[mod_name] = mod
    try:
        loader.exec_module(mod)
    except BaseException:
        # Do not leave a half-initialised module behind.
        if previous is None:
            sys.modules.pop(mod_name, None)
        else:
            sys.modules[mod_name] = previous
        raise
    return mod
|
|
def load_parser_safely() -> Tuple[Optional[Any], str]:
    """Load the external parser from PROMPT_PARSER_PATH, converting failures to a message.

    Returns:
        ``(module, "")`` on success, or ``(None, message)`` when the file does not
        exist, raises while importing, or lacks ``schedule_parser`` /
        ``resolve_tree``. Import-time errors are formatted as
        ``"<ExceptionType>: <text>"``.
    """
    if not os.path.isfile(PROMPT_PARSER_PATH):
        return None, f"Parser file not found: {PROMPT_PARSER_PATH}"
    try:
        import sys as _sys

        # Drop any previously registered copy before importing again.
        _sys.modules.pop(PP_MOD_NAME, None)
        module = import_from_path(PP_MOD_NAME, PROMPT_PARSER_PATH)
        # Touch the attributes the analysis code needs; a missing one raises
        # AttributeError and is reported like any other load failure.
        for required in ("schedule_parser", "resolve_tree"):
            getattr(module, required)
    except Exception as exc:
        return None, f"{type(exc).__name__}: {exc}"
    return module, ""
|
|
| |
# Load the external parser once at import time. load_parser_safely() reports
# failures through its return value: on error _pp is None and _pp_error holds
# the message that the UI shows as "Parser status".
_pp, _pp_error = load_parser_safely()
|
|
| |
| |
| |
class EnvState:
    """Cached view of the environment flags read by the external prompt parser.

    Attributes (``allow_empty_alt``, ``expand_alt_per_step``,
    ``group_combo_limit``, ``suppress_standalone_colon``) hold the exact strings
    exported to ``os.environ``: "0"/"1" for the booleans and a decimal string
    for the combo limit. They are initialised from the current environment,
    falling back to the defaults listed in ``_FIELDS``.
    """

    # (attribute name, environment variable, default when the variable is unset)
    _FIELDS = (
        ("allow_empty_alt", "ALLOW_EMPTY_ALTERNATE", "0"),
        ("expand_alt_per_step", "EXPAND_ALTERNATE_PER_STEP", "1"),
        ("group_combo_limit", "GROUP_COMBO_LIMIT", "100"),
        ("suppress_standalone_colon", "SUPPRESS_STANDALONE_COLON", "1"),
    )

    def __init__(self):
        for attr, env_name, default in self._FIELDS:
            setattr(self, attr, os.environ.get(env_name, default))

    def apply(self, allow_empty_alt: bool, expand_alt: bool, combo_limit: int, suppress_colon: bool) -> bool:
        """Export the given UI values to ``os.environ``.

        A value is (re)exported when it differs from the cached copy *or* from the
        live environment. The second check covers two cases: variables that were
        unset at start-up (the parser may then use its own default, which need not
        match ``_FIELDS``) and variables changed by other code since the last call.

        Returns:
            True if at least one environment variable was written, meaning the
            parser should be reloaded.

        Raises:
            TypeError / ValueError: if *combo_limit* cannot be converted with int().
        """
        wanted = (
            "1" if allow_empty_alt else "0",
            "1" if expand_alt else "0",
            str(int(combo_limit)),
            "1" if suppress_colon else "0",
        )
        changed = False
        for (attr, env_name, _default), value in zip(self._FIELDS, wanted):
            if value != getattr(self, attr) or os.environ.get(env_name) != value:
                os.environ[env_name] = value
                setattr(self, attr, value)
                changed = True
        return changed
|
|
# Module-level record of the parser-related environment flags.
# reload_parser_after_env() updates it (and os.environ) through apply().
ENV_STATE = EnvState()
|
|
def reload_parser_after_env(allow_empty_alt: bool, expand_alt: bool, combo_limit: int, suppress_colon: bool):
    """Push the UI flags into the environment and reload the parser if a flag changed.

    The reload result is stored in the module globals ``_pp`` / ``_pp_error``;
    the function itself returns None.
    """
    global _pp, _pp_error
    if not ENV_STATE.apply(allow_empty_alt, expand_alt, combo_limit, suppress_colon):
        return
    _pp, _pp_error = load_parser_safely()
|
|
def reload_parser_manual() -> str:
    """Reload the parser on demand (the "Reload parser" button).

    Updates the module globals ``_pp`` / ``_pp_error`` and returns the load
    error message, or "OK" when loading succeeded.
    """
    global _pp, _pp_error
    module, error = load_parser_safely()
    _pp, _pp_error = module, error
    return error if error else "OK"
|
|
| |
| |
| |
| def _scan_balanced_parens(s: str, start_idx: int) -> Optional[Tuple[int, str]]: |
| if start_idx >= len(s) or s[start_idx] != '(': |
| return None |
| depth = 0 |
| i = start_idx |
| buf: List[str] = [] |
| while i < len(s): |
| ch = s[i] |
| if ch == '(': |
| depth += 1 |
| if depth > 1: buf.append(ch) |
| elif ch == ')': |
| depth -= 1 |
| if depth == 0: return i + 1, "".join(buf) |
| else: buf.append(ch) |
| else: |
| buf.append(ch) |
| i += 1 |
| return None |
|
|
# Matches "repeat N (" case-insensitively; the match ends on the opening parenthesis.
_REPEAT_MACRO_RX = re.compile(r"\brepeat\s+(\d+)\s*\(", re.IGNORECASE)


def _pp_repeat_macro(src: str) -> Tuple[str, List[str]]:
    """Expand ``repeat N (body)`` into N comma-separated copies of *body*.

    Macros nested inside *body* are expanded first, so
    ``repeat 2 (repeat 2 (a))`` becomes ``a, a, a, a``. The outer scan resumes
    after the closing parenthesis and would otherwise never see them.
    N below 1 is treated as 1. A macro whose parenthesis is never closed is
    left exactly as written.

    Returns:
        The rewritten text and one change note per expanded macro, with inner
        macros listed before the macro that contains them.
    """
    changes: List[str] = []
    out: List[str] = []
    pos = 0
    while pos < len(src):
        m = _REPEAT_MACRO_RX.search(src, pos)
        if not m:
            out.append(src[pos:])
            break
        out.append(src[pos:m.start()])
        scanned = _scan_balanced_parens(src, m.end() - 1)
        if not scanned:
            # Unbalanced: keep the macro head as literal text and move on.
            out.append(m.group(0))
            pos = m.end()
            continue
        end_idx, inner = scanned
        n = max(1, int(m.group(1)))
        body, inner_changes = _pp_repeat_macro(inner.strip())
        changes.extend(inner_changes)
        out.append(", ".join([body] * n))
        changes.append(f"repeat {n} (...) → {n} повторов")
        pos = end_idx
    return "".join(out), changes
|
|
| def _split_top_level_bar(block: str) -> List[str]: |
| parts: List[str] = []; buf: List[str] = [] |
| d_round = d_square = d_curly = 0; i = 0 |
| while i < len(block): |
| ch = block[i] |
| if ch == '\\': |
| if i + 1 < len(block): buf.append(block[i:i+2]); i += 2; continue |
| if ch == '(': d_round += 1 |
| elif ch == ')': d_round = max(0, d_round - 1) |
| elif ch == '[': d_square += 1 |
| elif ch == ']': d_square = max(0, d_square - 1) |
| elif ch == '{': d_curly += 1 |
| elif ch == '}': d_curly = max(0, d_curly - 1) |
| elif ch == '|' and d_round == d_square == d_curly == 0: |
| parts.append("".join(buf)); buf = []; i += 1; continue |
| buf.append(ch); i += 1 |
| parts.append("".join(buf)) |
| return [p.strip() for p in parts] |
|
|
def _pp_prob_alt(src: str, total_steps: int) -> Tuple[str, List[str]]:
    """Rewrite weighted alternates ``a{p} | b{q} | ...`` as a scheduled block.

    Every run of ``text{weight}`` items joined by top-level ``|`` becomes
    ``[ a : b : ... ] : s1, s2, ...``. Each boundary is the cumulative weight
    share multiplied by *total_steps* and rounded; boundaries outside the open
    range (0, total_steps) are dropped. A run is left unchanged when its weight
    sum is not positive or when one of its items lacks a trailing ``{weight}``.

    NOTE(review): the match is not anchored at a comma, so leading text is
    absorbed into the first alternative ("cat, dog{1}|fox{1}" becomes
    "[ cat, dog : fox ] : N"). Confirm this is intended.

    Returns:
        The rewritten text and one change note per converted run.
    """
    changes: List[str] = []
    block_rx = re.compile(r"((?:[^\n\[\]]+?\{\s*\d*\.?\d+\s*\}\s*\|\s*)+[^\n\[\]]+?\{\s*\d*\.?\d+\s*\})")
    weight_rx = re.compile(r"\{\s*([+\-]?\d*\.?\d+(?:[eE][+\-]?\d+)?)\s*\}\s*$")
    weight_tail_rx = re.compile(r"\{\s*[+\-]?\d*\.?\d+(?:[eE][+\-]?\d+)?\s*\}\s*$")

    def to_schedule(block: str) -> str:
        weighted: List[Tuple[str, float]] = []
        weight_sum = 0.0
        for part in _split_top_level_bar(block):
            found = weight_rx.search(part)
            if found is None:
                return block
            weight = float(found.group(1))
            label = weight_tail_rx.sub("", part).strip()
            weighted.append((label, weight))
            weight_sum += weight
        if weight_sum <= 0:
            return block

        cuts: List[str] = []
        running = 0.0
        for _, weight in weighted[:-1]:
            running += weight / weight_sum
            step = int(round(running * total_steps))
            if 0 < step < total_steps:
                cuts.append(str(step))

        body = " : ".join(label for label, _ in weighted)
        suffix = f" : {', '.join(cuts)}" if cuts else ""
        return f"[ {body} ]{suffix}"

    def on_match(match) -> str:
        original = match.group(1)
        converted = to_schedule(original)
        if converted != original:
            changes.append("prob a{p}|b{q} → scheduled")
        return converted

    rewritten = block_rx.sub(on_match, src)
    return rewritten, changes
|
|
def preprocess(text: str, total_steps: int = 20) -> Tuple[str, List[str]]:
    """Run the custom prompt macros before parsing.

    First ``repeat N (...)`` macros are expanded, then weighted alternates
    ``a{p} | b{q}`` are turned into scheduled blocks for *total_steps*.

    Returns:
        The rewritten prompt and the change notes of both passes, in order.
    """
    expanded, repeat_notes = _pp_repeat_macro(text)
    rewritten, prob_notes = _pp_prob_alt(expanded, total_steps)
    return rewritten, [*repeat_notes, *prob_notes]
|
|
| |
| |
| |
@dataclass
class AnalysisResult:
    """Outcome of analyze_prompt().

    On failure ``ok`` is False and ``error`` holds a message. Only the fields
    produced by the stages that ran before the failure are filled in; the
    others are None.
    """

    ok: bool  # True only when every stage (parse, resolve, schedule, timeline) succeeded
    error: Optional[str]  # failure message; None on success
    parse_tree: Optional[Any]  # tree returned by the parser's schedule_parser.parse()
    resolved_text: Optional[str]  # result of the parser's resolve_tree(tree, keep_spacing=True)
    schedule: Optional[List[Tuple[int, str]]]  # (boundary, text) pairs from the scheduling stage
    timeline: Optional[List[Tuple[int, str]]]  # one (step, text) entry per step, steps counted from 1
    changes: List[str]  # human-readable notes from the preprocessor
|
|
def _timeline_from_schedule(schedule: List[Tuple[int, str]], steps: int) -> List[Tuple[int, str]]:
    """Expand ``(boundary, text)`` schedule entries into one ``(step, text)`` per step.

    Each boundary is the last step its text applies to. The per-step branch of
    analyze_prompt emits ``(step, text_at_step)`` pairs, which must map back onto
    the same step. Steps after the last boundary keep the last text; an empty
    schedule yields "" for every step.
    """
    ordered = sorted(schedule, key=lambda item: item[0])
    timeline: List[Tuple[int, str]] = []
    idx = 0
    for step in range(1, steps + 1):
        # Skip entries whose range ended before this step.
        while idx < len(ordered) and ordered[idx][0] < step:
            idx += 1
        if idx < len(ordered):
            text = ordered[idx][1]
        else:
            text = ordered[-1][1] if ordered else ""
        timeline.append((step, text))
    return timeline


def analyze_prompt(
    text: str,
    steps: int = 20,
    use_preprocessor: bool = True,
    visitor_mode: bool = True,
    seed: Optional[int] = None,
    allow_empty_alt: bool = False,
    expand_alt_per_step: bool = True,
    group_combo_limit: int = 100,
    suppress_standalone_colon: bool = True,
) -> AnalysisResult:
    """Parse, resolve and schedule *text* with the external parser.

    The env flags are applied first through reload_parser_after_env(), which
    reloads the parser when a flag changed. The availability check runs *after*
    that reload, so a failed reload is reported as "Parser not loaded".

    Args:
        text: the prompt to analyse.
        steps: number of sampling steps used for scheduling and the timeline.
        use_preprocessor: run preprocess() (repeat / prob-alternate macros) first.
        visitor_mode: use the parser's CollectSteps visitor; otherwise run a
            ScheduleTransformer once per step.
        seed: forwarded to the parser's scheduler.
        allow_empty_alt, expand_alt_per_step, group_combo_limit,
        suppress_standalone_colon: parser env flags.

    Returns:
        An AnalysisResult. Each stage failure is reported in ``error`` instead
        of being raised. Schedule items that are not 2-element lists/tuples are
        ignored.
    """
    reload_parser_after_env(allow_empty_alt, expand_alt_per_step, group_combo_limit, suppress_standalone_colon)
    parser = _pp
    if parser is None:
        return AnalysisResult(False, f"Parser not loaded: {_pp_error}", None, None, None, None, [])

    changes: List[str] = []
    src = text
    if use_preprocessor:
        src, pre_changes = preprocess(src, steps)
        changes.extend(pre_changes)

    try:
        tree = parser.schedule_parser.parse(src)
    except Exception as e:
        return AnalysisResult(False, f"{e}", None, None, None, None, changes)

    try:
        resolved = parser.resolve_tree(tree, keep_spacing=True)
    except Exception as e:
        return AnalysisResult(False, f"Resolve error: {e}", tree, None, None, None, changes)

    try:
        if visitor_mode:
            collector = parser.CollectSteps(steps, prefix="", suffix="", depth=0, use_scheduling=True, seed=seed)
            raw = collector.visit(tree) or []
        else:
            raw = []
            for step in range(1, steps + 1):
                transformer = parser.ScheduleTransformer(total_steps=steps, current_step=step, seed=seed)
                raw.append((step, transformer.transform(tree)))
    except Exception as e:
        return AnalysisResult(False, f"Schedule error: {e}", tree, resolved, None, None, changes)

    schedule: Optional[List[Tuple[int, str]]] = None
    try:
        schedule = [
            (int(item[0]), str(item[1]))
            for item in raw
            if isinstance(item, (list, tuple)) and len(item) == 2
        ]
        timeline = _timeline_from_schedule(schedule, steps)
    except Exception as e:
        # schedule stays None if the normalisation itself failed.
        return AnalysisResult(False, f"Timeline error: {e}", tree, resolved, schedule, None, changes)

    return AnalysisResult(True, None, tree, resolved, schedule, timeline, changes)
|
|
def schedule_to_csv(schedule: List[Tuple[int, str]]) -> str:
    """Render ``(boundary, text)`` pairs as CSV with a ``boundary,text`` header.

    Uses the csv module's default dialect: rows end with "\\r\\n", and fields
    containing commas or quotes are quoted.
    """
    from io import StringIO

    out = StringIO()
    csv.writer(out).writerows([["boundary", "text"]] + [[b, t] for b, t in schedule])
    return out.getvalue()
|
|
def timeline_to_csv(timeline: List[Tuple[int, str]]) -> str:
    """Render ``(step, text)`` pairs as CSV with a ``step,text`` header.

    Uses the csv module's default dialect (rows end with "\\r\\n").
    """
    from io import StringIO

    buffer = StringIO()
    writer = csv.writer(buffer)
    rows = [["step", "text"]]
    rows.extend([step, text] for step, text in timeline)
    writer.writerows(rows)
    return buffer.getvalue()
|
|
| |
| |
| |
| def _close_unbalanced_brackets(src: str) -> str: |
| opens = {'(': 0, '[': 0, '{': 0} |
| for ch in src: |
| if ch in opens: opens[ch] += 1 |
| elif ch == ')': opens['('] = max(0, opens['('] - 1) |
| elif ch == ']': opens['['] = max(0, opens['['] - 1) |
| elif ch == '}': opens['{'] = max(0, opens['{'] - 1) |
| return src + (')' * opens['(']) + (']' * opens['[']) + ('}' * opens['{']) |
|
|
| def _escape_bars_inside_parens(src: str) -> str: |
| out: List[str] = []; depth = 0; i = 0 |
| while i < len(src): |
| ch = src[i] |
| if ch == '\\' and i + 1 < len(src): |
| out.append(src[i:i+2]); i += 2; continue |
| if ch == '(': |
| depth += 1; out.append(ch) |
| elif ch == ')': |
| depth = max(0, depth - 1); out.append(ch) |
| elif ch == '|' and depth > 0: |
| out.append(r'\|') |
| else: |
| out.append(ch) |
| i += 1 |
| return "".join(out) |
|
|
| def _replace_single_colon_labels_outside_schedules(src: str) -> str: |
| segments: List[Tuple[bool, str]] = []; inside = 0; buf: List[str] = [] |
| for ch in src: |
| if ch == '[': |
| if buf: segments.append((False, "".join(buf))); buf = [] |
| inside += 1; segments.append((True, ch)) |
| elif ch == ']': |
| segments.append((True, ch)); inside = max(0, inside - 1) |
| else: |
| if inside: segments.append((True, ch)) |
| else: buf.append(ch) |
| if buf: segments.append((False, "".join(buf))) |
|
|
| def repl(text: str) -> str: |
| pattern = re.compile(r'(?<!:)\b([A-Za-z0-9_]+)\s*:\s*(?!:)\b([A-Za-z0-9_]+)\b') |
| return pattern.sub(r'\1::\2', text) |
|
|
| out: List[str] = [] |
| for inside_sq, chunk in segments: |
| out.append(chunk if inside_sq else repl(chunk)) |
| return "".join(out) |
|
|
def quickfix_suggestions(src: str) -> List[Dict[str, str]]:
    """List the QuickFix actions that would actually change *src*.

    Each suggestion is ``{"id": ..., "title": ...}``, in a fixed priority order
    (the Auto QuickFix button applies the first one). A fix is offered only when
    apply_quickfix() changes the text, so Auto QuickFix never picks a no-op such
    as escaping a top-level ``|``. The word:word → word::word fix is considered
    only while the SUPPRESS_STANDALONE_COLON environment variable is "1"
    (the default).
    """
    text = src or ""
    candidates: List[Tuple[str, str]] = [
        ("close_brackets", "Закрыть незакрытые скобки"),
        ("escape_bars_in_parens", "Экранировать | внутри (...)"),
    ]
    if os.environ.get("SUPPRESS_STANDALONE_COLON", "1") == "1":
        candidates.append(("labels_double_colon", "word:word → word::word вне [..]"))
    candidates.append(("normalize_ws", "Нормализовать пробелы"))
    return [
        {"id": fix_id, "title": title}
        for fix_id, title in candidates
        if apply_quickfix(text, fix_id) != text
    ]
|
|
def _normalize_whitespace(src: str) -> str:
    """Fold line breaks and runs of spaces/tabs into single spaces, then trim."""
    text = src.replace("\r\n", "\n").replace("\r", "\n")
    for pattern, replacement in ((r"[ \t]+", " "), (r" *\n *", " "), (r" {2,}", " ")):
        text = re.sub(pattern, replacement, text)
    return text.strip()


def apply_quickfix(src: str, fix_id: str) -> str:
    """Apply the QuickFix identified by *fix_id* to *src*.

    Known ids: "close_brackets", "escape_bars_in_parens", "labels_double_colon"
    and "normalize_ws". Any other id returns *src* unchanged.
    """
    if fix_id == "normalize_ws":
        return _normalize_whitespace(src)
    if fix_id == "close_brackets":
        return _close_unbalanced_brackets(src)
    if fix_id == "escape_bars_in_parens":
        return _escape_bars_inside_parens(src)
    if fix_id == "labels_double_colon":
        return _replace_single_colon_labels_outside_schedules(src)
    return src
|
|
| |
| |
| |
# Insertable prompt snippets offered by the Editor tab's template dropdown,
# in menu order (label -> template text).
PALETTE_TEMPLATES: Dict[str, str] = dict([
    ("Emphasis (1.2)", "(text:1.2)"),
    ("Emphasis range", "(text:start:end)"),
    ("Square schedule", "[ a : b ] : 20"),
    ("Triple schedule", "[ a : b : c ] : 20, 40"),
    ("Alternate", "a | b"),
    ("Prob alternate", "a{0.5} | b{0.5}"),
    ("Group label", "owner::label::description"),
    ("AND", "AND( term1, term2 )"),
    ("BREAK", "BREAK"),
])
|
|
def palette_insert(cur_prompt: str, key: str) -> str:
    """Append the palette template named *key* to *cur_prompt*.

    A ", " separator is inserted unless the prompt is empty or already ends
    with a space, newline or comma. An unknown or empty *key* leaves the prompt
    unchanged (None is returned as ""), instead of appending a dangling ", ".
    """
    base = cur_prompt or ""
    template = PALETTE_TEMPLATES.get(key, "")
    if not template:
        return base
    sep = "" if (not base or base.endswith((" ", "\n", ","))) else ", "
    return base + sep + template
|
|
| def _split_top_level_comma(block: str) -> List[str]: |
| parts: List[str] = []; buf: List[str] = [] |
| d_round = d_square = d_curly = 0; i = 0 |
| while i < len(block): |
| ch = block[i] |
| if ch == '\\' and i + 1 < len(block): |
| buf.append(block[i:i+2]); i += 2; continue |
| if ch == '(': d_round += 1 |
| elif ch == ')': d_round = max(0, d_round - 1) |
| elif ch == '[': d_square += 1 |
| elif ch == ']': d_square = max(0, d_square - 1) |
| elif ch == '{': d_curly += 1 |
| elif ch == '}': d_curly = max(0, d_curly - 1) |
| elif ch == ',' and d_round == d_square == d_curly == 0: |
| parts.append("".join(buf)); buf = []; i += 1; continue |
| buf.append(ch); i += 1 |
| parts.append("".join(buf)) |
| return [p.strip() for p in parts if p.strip()] |
|
|
def segments_from_text(src: str, mode: str = "auto") -> List[str]:
    """Split *src* into prompt segments.

    Mode "lines" (or "auto" when the text contains a newline) splits on line
    breaks. Every other mode splits on top-level commas. Segments are stripped,
    empty ones are dropped, and empty/None input gives [].
    """
    text = (src or "").strip()
    if not text:
        return []
    by_lines = mode == "lines" or (mode == "auto" and "\n" in text)
    if not by_lines:
        return _split_top_level_comma(text)
    return [line.strip() for line in text.splitlines() if line.strip()]
|
|
def wrap_sequence(src: str, seg_mode: str = "auto") -> str:
    """Re-join the segments of *src* (see segments_from_text) with ", "."""
    segments = segments_from_text(src, seg_mode)
    return ", ".join(segments)
|
|
def wrap_top_level2(src: str, steps: int = 20, boundary: Optional[int] = None, seg_mode: str = "auto") -> str:
    """Build a two-stage schedule ``[ a : b ] : boundary`` from the first two segments.

    - A single segment is used for both stages.
    - With no segments at all, the placeholders ``seg1``/``seg2`` are used
      (the convention of wrap_top_level3 and build_curve_schedule) instead of
      emitting an empty ``[  ]`` block.
    - Segments past the second are ignored.
    - When *boundary* is None, the rounded midpoint clamped to [1, steps - 1]
      is used.
    """
    segs = segments_from_text(src, seg_mode)
    if not segs:
        segs = ["seg1", "seg2"]
    elif len(segs) < 2:
        segs = segs + segs
    core = " : ".join(segs[:2])
    b = boundary if boundary is not None else max(1, min(steps - 1, round(steps / 2)))
    return f"[ {core} ] : {b}"
|
|
def wrap_top_level3(src: str, steps: int = 30, boundaries: Optional[Tuple[int, int]] = None, seg_mode: str = "auto") -> str:
    """Build a three-stage schedule ``[ a : b : c ] : b1, b2``.

    Uses the first three segments; missing ones become ``seg1``..``seg3``
    placeholders by position. Without explicit *boundaries*, the thirds of
    *steps* are used (b1 at least 1, b2 at least b1 + 1).
    """
    first_three = segments_from_text(src, seg_mode)[:3]
    labels = [
        (first_three[i] if i < len(first_three) else "") or f"seg{i + 1}"
        for i in range(3)
    ]
    if boundaries is not None:
        b1, b2 = int(boundaries[0]), int(boundaries[1])
    else:
        b1 = max(1, round(steps / 3))
        b2 = max(b1 + 1, round(2 * steps / 3))
    return f"[ {' : '.join(labels)} ] : {b1}, {b2}"
|
|
def wrap_alternate(src: str, seg_mode: str = "auto") -> str:
    """Join the segments of *src* into an alternation ``a | b | ...``."""
    segments = segments_from_text(src, seg_mode)
    return " | ".join(segments)
|
|
def wrap_prob_equal(src: str, weight: float = 1.0, seg_mode: str = "auto") -> str:
    """Join segments as a weighted alternation ``a{w} | b{w} | ...``.

    The weight is formatted with ``:g``. Empty input gives "".
    """
    segs = segments_from_text(src, seg_mode)
    if not segs:
        return ""
    tag = "{" + f"{float(weight):g}" + "}"
    return " | ".join(seg + tag for seg in segs)
|
|
| |
| |
| |
| def _ease_in(t: float, p: float) -> float: return t ** max(1.0, p) |
| def _ease_out(t: float, p: float) -> float: return 1.0 - (1.0 - t) ** max(1.0, p) |
| def _ease_in_out(t: float) -> float: return 0.5 * (1 - math.cos(math.pi * t)) |
| def _cosine(t: float) -> float: return (1 - math.cos(math.pi * t)) / 2.0 |
| def _cubic_bezier_y(t: float, x1: float, y1: float, x2: float, y2: float) -> float: |
| mt = 1 - t |
| return 3 * mt * mt * t * y1 + 3 * mt * t * t * y2 + t**3 |
|
|
| def _curve_f(t: float, kind: str, p: float = 2.0, bezier: Optional[Tuple[float,float,float,float]] = None) -> float: |
| t = max(0.0, min(1.0, float(t))) |
| if kind == "linear": return t |
| if kind == "ease-in": return _ease_in(t, p) |
| if kind == "ease-out": return _ease_out(t, p) |
| if kind == "ease-in-out": return _ease_in_out(t) |
| if kind == "cosine": return _cosine(t) |
| if kind == "bezier" and bezier: |
| x1,y1,x2,y2 = bezier |
| return _cubic_bezier_y(t, x1, y1, x2, y2) |
| return t |
|
|
def _inv_curve(u: float, kind: str, steps: int = 60, p: float = 2.0, bezier: Optional[Tuple[float, float, float, float]] = None) -> float:
    """Invert _curve_f numerically: find t in [0, 1] with curve(t) ≈ *u*.

    Runs *steps* bisection iterations (assumes the curve is non-decreasing)
    and returns the midpoint of the final interval.
    """
    low, high = 0.0, 1.0
    for _iteration in range(steps):
        middle = 0.5 * (low + high)
        below_target = _curve_f(middle, kind, p=p, bezier=bezier) < u
        if below_target:
            low = middle
        else:
            high = middle
    return 0.5 * (low + high)
|
|
def curve_boundaries(num_segments: int, total_steps: int, kind: str, p: float = 2.0, bezier: Optional[Tuple[float, float, float, float]] = None) -> List[int]:
    """Place ``num_segments - 1`` step boundaries so each segment covers an equal share of the curve.

    Both counts are floored at 2. Each boundary is rounded and clamped to
    [1, total_steps - 1], and is pushed up by one when it would not exceed its
    predecessor (it can still repeat at total_steps - 1 when steps run out).
    """
    n = max(2, int(num_segments))
    steps = max(2, int(total_steps))
    result: List[int] = []
    for k in range(1, n):
        t = _inv_curve(k / n, kind=kind, p=p, bezier=bezier)
        step = min(steps - 1, max(1, int(round(t * steps))))
        if result and step <= result[-1]:
            step = min(steps - 1, result[-1] + 1)
        result.append(step)
    return result
|
|
def build_curve_schedule(segments_text: str, num_segments: int, total_steps: int, kind: str,
                         p: float = 2.0, bezier: Optional[Tuple[float, float, float, float]] = None) -> Tuple[str, str]:
    """Build ``[ s1 : s2 : ... ] : b1, b2, ...`` with boundaries placed along an easing curve.

    *segments_text* holds one segment per line. The list is padded with
    ``segN`` placeholders or truncated to the segment count. The count is
    floored at 2, the same clamp curve_boundaries applies, so there is always
    exactly one boundary fewer than there are segments.

    Returns:
        ``(block, csv_bounds)``: the schedule text and its comma-separated boundaries.
    """
    n = max(2, int(num_segments))
    segs = [s.strip() for s in (segments_text or "").splitlines() if s.strip()][:n]
    segs.extend(f"seg{i + 1}" for i in range(len(segs), n))
    bounds = curve_boundaries(n, total_steps, kind=kind, p=p, bezier=bezier)
    core = " : ".join(segs)
    csv_bounds = ", ".join(str(b) for b in bounds)
    block = f"[ {core} ]" + (f" : {csv_bounds}" if csv_bounds else "")
    return block, csv_bounds
|
|
| |
| |
| |
class TokenCountResult:
    """Plain result record returned by count_clip_tokens().

    Attributes:
        backend: name of the tokenizer backend that produced the numbers.
        total: number of tokens reported by that backend.
        payload: tokens counted against the limit (for the CLIP tokenizers,
            *total* minus the two special tokens).
        limit: the payload limit that was applied.
        overflow: payload tokens beyond the limit (0 when it fits).
        trimmed_text: the input cut back to the limit, or unchanged when it fits.
    """

    backend: str
    total: int
    payload: int
    limit: int
    overflow: int
    trimmed_text: str

    def __init__(self, backend: str, total: int, payload: int, limit: int, overflow: int, trimmed_text: str):
        self.backend = backend
        self.total = total
        self.payload = payload
        self.limit = limit
        self.overflow = overflow
        self.trimmed_text = trimmed_text
|
|
# Loaded Hugging Face tokenizers, keyed by model id (filled on first use).
_HF_CLIP_TOKENIZERS: Dict[str, Any] = {}


def _tokenize_transformers(text: str):
    """Tokenize *text* with the Hugging Face CLIP tokenizer (openai/clip-vit-large-patch14).

    The tokenizer is loaded once and cached, because from_pretrained() reads
    from disk (and may touch the network) and would otherwise run on every call.

    Returns:
        ``(ids, offsets, "transformers")``: *ids* include the special tokens
        (``add_special_tokens=True``), and *offsets* is the tokenizer's
        ``offset_mapping`` converted to int pairs, aligned with *ids*.

    Raises:
        ImportError if transformers is not installed; loading and tokenizing
        errors propagate to the caller.
    """
    model_id = "openai/clip-vit-large-patch14"
    tok = _HF_CLIP_TOKENIZERS.get(model_id)
    if tok is None:
        from transformers import AutoTokenizer

        tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
        _HF_CLIP_TOKENIZERS[model_id] = tok
    enc = tok(text, add_special_tokens=True, return_offsets_mapping=True)
    ids = enc["input_ids"]
    offsets = [(int(a), int(b)) for (a, b) in enc["offset_mapping"]]
    return ids, offsets, "transformers"
|
|
def _tokenize_open_clip(text: str):
    """Tokenize *text* with open_clip's ViT-L-14 tokenizer.

    open_clip returns a fixed-length (context length) tensor zero-padded after
    the end-of-text token. The trailing zeros are stripped so that ``len(ids)``
    is the real token count; without this every prompt counted as the full
    context length. NOTE(review): open_clip also truncates at its context
    length, so counts above it are capped — confirm for the installed version.

    Returns:
        ``(ids, offsets, "open_clip")``. *offsets* are the spans of
        whitespace-separated words, not BPE tokens, so trimming based on them
        is approximate.
    """
    import open_clip

    tokenizer = open_clip.get_tokenizer("ViT-L-14")
    ids = tokenizer(text)
    offs = [(m.start(), m.end()) for m in re.finditer(r"\S+", text)]
    if hasattr(ids, "squeeze"):
        ids = ids.squeeze(0).tolist()
    ids = [int(i) for i in ids]
    # Drop the zero padding that follows the end-of-text token.
    while ids and ids[-1] == 0:
        ids.pop()
    return ids, offs, "open_clip"
|
|
| def _tokenize_approx(text: str): |
| toks = []; offs = [] |
| for m in re.finditer(r"\w+|[^\s\w]", text, flags=re.UNICODE): |
| toks.append(m.group(0)); offs.append((m.start(), m.end())) |
| return toks, offs, "approx" |
|
|
def _token_count_result(text: str, used: str, total: int, payload: int, limit_payload: int,
                        offsets: List[Tuple[int, int]], first_payload: int) -> TokenCountResult:
    """Build a TokenCountResult, cutting *text* after the last payload token that fits.

    ``offsets[first_payload + i]`` is taken as the (start, end) character span
    of payload token ``i``. Indices past the end of *offsets* are clamped to its
    last entry, because the open_clip backend only supplies word-level spans.
    """
    keep = min(payload, limit_payload)
    cut_char: Optional[int] = None
    if keep < payload:
        if keep <= 0:
            cut_char = 0  # no payload token fits at all
        elif offsets:
            last_kept = min(first_payload + keep - 1, len(offsets) - 1)
            cut_char = offsets[last_kept][1]
    trimmed = text if cut_char is None else text[:cut_char]
    overflow = max(0, payload - limit_payload)
    return TokenCountResult(used, total, payload, limit_payload, overflow, trimmed)


def count_clip_tokens(text: str, limit_payload: int = 77, backend: str = "auto") -> TokenCountResult:
    """Count CLIP tokens in *text* and trim it to *limit_payload* payload tokens.

    Backend selection:
    - "auto" tries transformers, then open_clip, then the regex approximation.
    - "transformers" or "open_clip" tries only that backend before falling
      back to the approximation.
    - Any other value (e.g. "approx") goes straight to the approximation.

    A backend that fails to import or tokenize is skipped deliberately. For the
    real CLIP tokenizers the payload excludes the start and end tokens.
    """
    if backend in ("auto", "transformers"):
        try:
            ids, offsets, used = _tokenize_transformers(text)
        except Exception:
            pass  # transformers unavailable or tokenizer failed: try the next backend
        else:
            total = len(ids)
            payload = total - 2 if total >= 2 else total
            # offsets align with ids and ids[0] is the start token, so payload
            # token i is at index i + 1.
            return _token_count_result(text, used, total, payload, limit_payload, offsets, 1 if total >= 2 else 0)

    if backend in ("auto", "open_clip"):
        try:
            ids, offsets, used = _tokenize_open_clip(text)
        except Exception:
            pass  # open_clip unavailable or tokenizer failed: use the approximation
        else:
            total = len(ids)
            payload = total - 2 if total >= 2 else total
            # Word-level offsets carry no entries for the special tokens.
            return _token_count_result(text, used, total, payload, limit_payload, offsets, 0)

    toks, offsets, used = _tokenize_approx(text)
    return _token_count_result(text, used, len(toks), len(toks), limit_payload, offsets, 0)
|
|
| |
| |
| |
| def on_ui_tabs(): |
| import gradio as gr |
|
|
| with gr.Blocks(analytics_enabled=False) as demo: |
| gr.Markdown("## Prompt Refactor Plus") |
|
|
| |
| with gr.Row(): |
| parser_status = gr.Textbox(label="Parser status", value=_pp_error or "OK", interactive=False) |
| reload_btn = gr.Button("Reload parser") |
| def _reload(): |
| return reload_parser_manual() |
| reload_btn.click(_reload, outputs=[parser_status]) |
|
|
| |
| with gr.Tab("Editor"): |
| prompt_tb = gr.Textbox(label="Prompt", lines=8, placeholder="Введите промпт...") |
| with gr.Row(): |
| palette_dd = gr.Dropdown(choices=list(PALETTE_TEMPLATES.keys()), value="Emphasis (1.2)", label="Шаблон") |
| insert_btn = gr.Button("Вставить шаблон") |
| insert_btn.click(palette_insert, inputs=[prompt_tb, palette_dd], outputs=[prompt_tb]) |
|
|
| |
| with gr.Tab("Analysis"): |
| with gr.Row(): |
| steps = gr.Slider(1, 200, value=20, step=1, label="Steps") |
| use_pre = gr.Checkbox(value=True, label="Use Preprocessor") |
| suppress_colon = gr.Checkbox(value=True, label="SUPPRESS_STANDALONE_COLON") |
| analyze_btn = gr.Button("Analyze") |
| with gr.Accordion("Advanced", open=False): |
| with gr.Row(): |
| visitor = gr.Checkbox(value=True, label="Visitor mode") |
| seed = gr.Number(value=None, label="Seed", precision=0) |
| with gr.Row(): |
| allow_empty_alt = gr.Checkbox(value=False, label="ALLOW_EMPTY_ALTERNATE") |
| expand_alt = gr.Checkbox(value=True, label="EXPAND_ALTERNATE_PER_STEP") |
| group_limit = gr.Number(value=100, label="GROUP_COMBO_LIMIT", precision=0) |
|
|
| with gr.Tabs(): |
| with gr.Tab("Resolved"): |
| parse_out = gr.Textbox(label="", lines=4) |
| changes_out = gr.Textbox(label="Preprocessor Changes", lines=3) |
| with gr.Tab("Schedule CSV"): |
| schedule_out = gr.Textbox(label="", lines=8) |
| with gr.Tab("Timeline CSV"): |
| timeline_out = gr.Textbox(label="", lines=8) |
|
|
| |
| gr.Markdown("### QuickFix") |
| fixes_state = gr.State(value=[]) |
| with gr.Row(): |
| fixes_dd = gr.Dropdown(choices=[], label="Выберите исправление") |
| apply_fix_btn = gr.Button("Применить") |
| auto_fix_btn = gr.Button("Auto QuickFix") |
|
|
| def _analyze(prompt, steps, use_pre, visitor, seed, allow_empty_alt, expand_alt, group_limit, suppress_colon): |
| res = analyze_prompt( |
| text=prompt or "", |
| steps=int(steps), |
| use_preprocessor=bool(use_pre), |
| visitor_mode=bool(visitor), |
| seed=None if seed in (None, "") else int(seed), |
| allow_empty_alt=bool(allow_empty_alt), |
| expand_alt_per_step=bool(expand_alt), |
| group_combo_limit=int(group_limit), |
| suppress_standalone_colon=bool(suppress_colon), |
| ) |
| sgs = quickfix_suggestions(prompt or "") |
| return ( |
| res.resolved_text or "", |
| "\n".join(res.changes), |
| schedule_to_csv(res.schedule) if res.schedule else "", |
| timeline_to_csv(res.timeline) if res.timeline else "", |
| sgs, |
| gr.update(choices=[s["title"] for s in sgs]), |
| _pp_error or "OK" |
| ) |
|
|
| analyze_btn.click( |
| _analyze, |
| inputs=[prompt_tb, steps, use_pre, visitor, seed, allow_empty_alt, expand_alt, group_limit, suppress_colon], |
| outputs=[parse_out, changes_out, schedule_out, timeline_out, fixes_state, fixes_dd, parser_status] |
| ) |
|
|
| def _apply_fix_cur(prompt: str, fixes: List[Dict[str, str]], title: str): |
| m = {f["title"]: f["id"] for f in fixes} |
| fid = m.get(title) |
| return apply_quickfix(prompt or "", fid) if fid else prompt |
| apply_fix_btn.click(_apply_fix_cur, inputs=[prompt_tb, fixes_state, fixes_dd], outputs=[prompt_tb]) |
|
|
| def _auto_fix(prompt: str): |
| sgs = quickfix_suggestions(prompt or "") |
| if not sgs: return prompt |
| return apply_quickfix(prompt or "", sgs[0]["id"]) |
| auto_fix_btn.click(_auto_fix, inputs=[prompt_tb], outputs=[prompt_tb]) |
|
|
| |
| with gr.Tab("Wrap Wizard"): |
| src_tb = gr.Textbox(label="Исходный текст", lines=6, placeholder="Один сегмент на строку или через верхнеуровневые запятые") |
| seg_mode_dd = gr.Radio(choices=["auto", "lines", "comma"], value="auto", label="Разделение") |
| with gr.Row(): |
| ww_steps = gr.Slider(2, 200, value=20, step=1, label="Steps (для top-level)") |
| ww_boundary = gr.Number(value=None, label="Boundary (Top-Level 2)", precision=0) |
| ww_b1 = gr.Number(value=None, label="b1 (Top-Level 3)", precision=0) |
| ww_b2 = gr.Number(value=None, label="b2 (Top-Level 3)", precision=0) |
| with gr.Row(): |
| seq_btn = gr.Button("Sequence") |
| tl2_btn = gr.Button("Top-Level (2)") |
| tl3_btn = gr.Button("Top-Level (3)") |
| with gr.Row(): |
| alt_btn = gr.Button("Alternate") |
| prob_btn = gr.Button("Prob-Equal") |
| wrap_preview = gr.Textbox(label="Предпросмотр", lines=3) |
| insert_wrap_btn = gr.Button("Вставить в Prompt") |
|
|
| def on_wrap_seq(src, seg_mode): return wrap_sequence(src or "", seg_mode) |
| def on_wrap_tl2(src, steps, boundary, seg_mode): |
| b = None |
| try: |
| if boundary not in (None, ""): b = int(boundary) |
| except Exception: b = None |
| return wrap_top_level2(src or "", steps=int(steps), boundary=b, seg_mode=seg_mode) |
| def on_wrap_tl3(src, steps, b1, b2, seg_mode): |
| bb = None |
| try: |
| if b1 not in (None, "") and b2 not in (None, ""): bb = (int(b1), int(b2)) |
| except Exception: bb = None |
| return wrap_top_level3(src or "", steps=int(steps), boundaries=bb, seg_mode=seg_mode) |
| def on_wrap_alt(src, seg_mode): return wrap_alternate(src or "", seg_mode) |
| def on_wrap_prob(src, seg_mode): return wrap_prob_equal(src or "", 1.0, seg_mode) |
|
|
| seq_btn.click(on_wrap_seq, inputs=[src_tb, seg_mode_dd], outputs=[wrap_preview]) |
| tl2_btn.click(on_wrap_tl2, inputs=[src_tb, ww_steps, ww_boundary, seg_mode_dd], outputs=[wrap_preview]) |
| tl3_btn.click(on_wrap_tl3, inputs=[src_tb, ww_steps, ww_b1, ww_b2, seg_mode_dd], outputs=[wrap_preview]) |
| alt_btn.click(on_wrap_alt, inputs=[src_tb, seg_mode_dd], outputs=[wrap_preview]) |
| prob_btn.click(on_wrap_prob, inputs=[src_tb, seg_mode_dd], outputs=[wrap_preview]) |
|
|
| def _insert_block(cur_prompt: str, block: str) -> str: |
| if not block or not block.strip(): return cur_prompt |
| sep = "" if (not cur_prompt or cur_prompt.endswith((" ", "\n", ","))) else ", " |
| return (cur_prompt or "") + sep + block |
| insert_wrap_btn.click(_insert_block, inputs=[prompt_tb, wrap_preview], outputs=[prompt_tb]) |
|
|
| |
| with gr.Tab("Timeline Wizard"): |
| seg_tb = gr.Textbox(label="Сегменты (по одному на строку)", lines=4, value="a\nb") |
| boundaries_tb = gr.Textbox(label="Границы (CSV)", value="20") |
| build_tl_btn = gr.Button("Построить") |
| tl_preview = gr.Textbox(label="Предпросмотр", lines=2) |
| insert_tl_btn = gr.Button("Вставить в Prompt") |
|
|
| def build_schedule_block(segments_text: str, boundaries_csv: str) -> str: |
| segs = [s.strip() for s in (segments_text or "").splitlines() if s.strip()] |
| if not segs: return "" |
| core = " : ".join(segs) |
| bounds = [b.strip() for b in (boundaries_csv or "").split(",") if b.strip()] |
| extra = f" : {', '.join(bounds)}" if bounds else "" |
| return f"[ {core} ]{extra}" |
|
|
| build_tl_btn.click(build_schedule_block, inputs=[seg_tb, boundaries_tb], outputs=[tl_preview]) |
| insert_tl_btn.click(_insert_block, inputs=[prompt_tb, tl_preview], outputs=[prompt_tb]) |
|
|
| |
| with gr.Tab("Token Counter"): |
| payload_limit = gr.Slider(10, 200, value=77, step=1, label="Payload limit") |
| backend_dd = gr.Dropdown(choices=["auto", "transformers", "open_clip", "approx"], value="auto", label="Backend (Advanced)") |
| count_btn = gr.Button("Count") |
| with gr.Row(): |
| backend_used = gr.Textbox(label="Used backend", interactive=False) |
| total_out = gr.Number(label="Total tokens", interactive=False) |
| payload_out = gr.Number(label="Payload tokens", interactive=False) |
| overflow_out = gr.Number(label="Overflow", interactive=False) |
| trimmed_tb = gr.Textbox(label="Trimmed text", lines=6) |
|
|
| def on_count(prompt, limit_payload, backend): |
| res = count_clip_tokens(prompt or "", limit_payload=int(limit_payload), backend=backend) |
| return res.backend, res.total, res.payload, res.overflow, res.trimmed_text |
| count_btn.click(on_count, inputs=[prompt_tb, payload_limit, backend_dd], |
| outputs=[backend_used, total_out, payload_out, overflow_out, trimmed_tb]) |
|
|
| |
| with gr.Tab("Advanced"): |
| with gr.Accordion("Curve Editor", open=False): |
| curve_segments_tb = gr.Textbox(label="Сегменты (пусто = seg1..N)", lines=6, value="") |
| with gr.Row(): |
| curve_num = gr.Slider(2, 8, value=3, step=1, label="N") |
| curve_steps = gr.Slider(2, 300, value=60, step=1, label="Steps") |
| curve_kind = gr.Dropdown(choices=["linear", "ease-in", "ease-out", "ease-in-out", "cosine", "bezier"], value="ease-in-out", label="Кривая") |
| curve_power = gr.Slider(1, 6, value=2, step=0.5, label="Показатель (ease-in/out)") |
| with gr.Row(): |
| bez_x1 = gr.Slider(0, 1, value=0.3, step=0.05, label="Bezier x1") |
| bez_y1 = gr.Slider(0, 1, value=0.0, step=0.05, label="Bezier y1") |
| bez_x2 = gr.Slider(0, 1, value=0.7, step=0.05, label="Bezier x2") |
| bez_y2 = gr.Slider(0, 1, value=1.0, step=0.05, label="Bezier y2") |
| build_curve_btn = gr.Button("Build") |
| curve_bounds_out = gr.Textbox(label="Границы", lines=1) |
| curve_preview = gr.Textbox(label="Предпросмотр", lines=2) |
| insert_curve_btn = gr.Button("Вставить в Prompt") |
|
|
def on_build_curve(segs_text, n, steps, kind, power, x1, y1, x2, y2):
    """Build a curve-driven schedule for the Curve Editor.

    The four Bezier control values are forwarded to ``build_curve_schedule``
    only when *kind* is ``"bezier"``; for every other curve ``None`` is passed.

    Returns:
        ``(csvb, block)``. ``build_curve_schedule`` yields ``(block, csvb)``,
        but the click handler's outputs are (bounds textbox, preview textbox),
        hence the swapped order.
    """
    if kind == "bezier":
        control_points = (float(x1), float(y1), float(x2), float(y2))
    else:
        control_points = None
    preview_block, bounds_csv = build_curve_schedule(
        segs_text or "",
        int(n),
        int(steps),
        str(kind),
        float(power),
        control_points,
    )
    return bounds_csv, preview_block
| build_curve_btn.click(on_build_curve, |
| inputs=[curve_segments_tb, curve_num, curve_steps, curve_kind, curve_power, bez_x1, bez_y1, bez_x2, bez_y2], |
| outputs=[curve_bounds_out, curve_preview]) |
| insert_curve_btn.click(_insert_block, inputs=[prompt_tb, curve_preview], outputs=[prompt_tb]) |
|
|
| with gr.Accordion("Prob-Alternate Wizard", open=False): |
| steps_pa = gr.Slider(1, 200, value=20, step=1, label="Steps") |
| table = gr.Dataframe(headers=["Текст", "Вес"], datatype=["str", "number"], |
| row_count=3, col_count=(2, "fixed"), |
| value=[["a", 0.5], ["b", 0.5], ["", ""]], label="Варианты") |
| build_pa_btn = gr.Button("Build") |
| pa_preview = gr.Textbox(label="Предпросмотр", lines=2) |
def _prob_table_rows(rows) -> list:
    """Normalize a gr.Dataframe value (list, numpy array, pandas DataFrame, None) to a list of rows."""
    if rows is None:
        return []
    values = getattr(rows, "values", None)
    if values is not None and hasattr(values, "tolist"):
        # pandas DataFrame (Gradio 3.x default ``type="pandas"``): truth-testing it
        # raises and iterating it yields column names, so take the cell matrix.
        return values.tolist()
    if hasattr(rows, "tolist"):
        # numpy array (``type="numpy"``).
        return rows.tolist()
    return list(rows)


def _prob_cell_text(cell) -> str:
    """Return the stripped text of a table cell; None/NaN become an empty string."""
    if isinstance(cell, str):
        return cell.strip()
    if cell is None or (isinstance(cell, float) and math.isnan(cell)):
        return ""
    return str(cell).strip()


def _prob_cell_weight(cell) -> float:
    """Parse a weight cell; unparsable, non-finite (NaN/inf) or negative values count as 0."""
    try:
        weight = float(str(cell).strip())
    except (TypeError, ValueError):
        return 0.0
    return weight if math.isfinite(weight) and weight > 0 else 0.0


def build_schedule_from_prob_table(rows, total_steps: int) -> str:
    """Turn a (text, weight) table into a step-schedule string.

    Every row with non-empty text becomes one alternative. Its weight is its
    share of the total weight. Cumulative shares are converted to step
    boundaries (``round(share * total_steps)``), and boundaries that land on
    0 or on ``total_steps`` are dropped.

    Args:
        rows: Table value from ``gr.Dataframe``: a list of rows, a numpy array,
            a pandas DataFrame (the Gradio 3.x default), or ``None``.
        total_steps: Total number of sampling steps (coerced to ``int``).

    Returns:
        A string such as ``"[ a : b ] : 10"``, or ``""`` when there are no
        usable rows or the weights do not sum to a positive number.
    """
    steps = int(total_steps)
    items: List[Tuple[str, float]] = []
    for row in _prob_table_rows(rows):
        if not row or len(row) < 2:
            continue
        text = _prob_cell_text(row[0])
        if not text:
            continue
        items.append((text, _prob_cell_weight(row[1])))

    total = sum(weight for _, weight in items)
    if not items or total <= 0:
        return ""

    boundaries: List[str] = []
    acc = 0.0
    # The last item needs no boundary: it runs until the final step.
    for _, weight in items[:-1]:
        acc += weight / total
        boundary = int(round(acc * steps))
        if 0 < boundary < steps:
            boundaries.append(str(boundary))

    core = " : ".join(text for text, _ in items)
    # NOTE(review): boundaries are emitted after the closing bracket; confirm
    # this matches the external parser's syntax.
    extra = f" : {', '.join(boundaries)}" if boundaries else ""
    return f"[ {core} ]{extra}"
| build_pa_btn.click(build_schedule_from_prob_table, inputs=[table, steps_pa], outputs=[pa_preview]) |
| insert_pa_btn = gr.Button("Вставить в Prompt") |
| insert_pa_btn.click(_insert_block, inputs=[prompt_tb, pa_preview], outputs=[prompt_tb]) |
|
|
| with gr.Accordion("Test Generator (unit tests)", open=False): |
| tg_steps = gr.Slider(1, 200, value=20, step=1, label="Steps") |
| tg_seed = gr.Number(value=None, label="Seed", precision=0) |
| with gr.Row(): |
| tg_allow_empty = gr.Checkbox(value=False, label="ALLOW_EMPTY_ALTERNATE") |
| tg_expand_alt = gr.Checkbox(value=True, label="EXPAND_ALTERNATE_PER_STEP") |
| tg_group_limit = gr.Number(value=100, label="GROUP_COMBO_LIMIT", precision=0) |
| tg_suppr_colon = gr.Checkbox(value=True, label="SUPPRESS_STANDALONE_COLON") |
| gen_btn = gr.Button("Generate") |
| test_code_tb = gr.Textbox(label="test_prompt_parser_generated.py", lines=18) |
| test_file = gr.File(label="Скачать", interactive=False) |
|
|
def generate_unittests_code(prompt_src: str, steps: int, env: Dict[str, str], seed: Optional[int], analysis: AnalysisResult) -> str:
    """Render a standalone unittest module that pins the current parser output.

    The generated module does the following:

    * loads the parser from ``PROMPT_PARSER_PATH`` (env var, falling back to
      the path configured in this extension);
    * applies *env* before loading the parser and restores the previous
      values in ``tearDownClass``;
    * checks that ``resolve_tree(..., keep_spacing=True)`` equals
      ``analysis.resolved_text``;
    * checks that the first/last three ``(step, text)`` pairs produced by
      ``CollectSteps`` equal those in ``analysis.schedule``.

    Args:
        prompt_src: Prompt embedded verbatim (via ``repr``) in the tests.
        steps: Step count passed to ``CollectSteps``.
        env: Parser feature flags written to ``os.environ`` before loading.
        seed: Seed forwarded to ``CollectSteps`` (``None`` allowed).
        analysis: Analysis whose ``resolved_text`` / ``schedule`` become the
            expected values.

    Returns:
        Python source code of the test module.
    """
    resolved = analysis.resolved_text or ""
    # Normalize the expected pairs exactly like the generated test normalizes
    # the actual ones; otherwise list-vs-tuple or numeric-type differences
    # would make assertEqual fail spuriously.
    schedule = [(int(b), str(t)) for b, t in (analysis.schedule or [])]
    n_head, n_tail = 3, 3
    if len(schedule) <= n_head + n_tail:
        sched_sample = schedule
    else:
        sched_sample = schedule[:n_head] + schedule[-n_tail:]
    # NOTE(review): the expected values come from analyze_prompt (possibly with
    # preprocessing), while the generated test parses the raw prompt; confirm
    # both paths agree for prompts that rely on the preprocessor.
    code = f'''\
import os, sys, unittest, importlib.machinery, importlib.util

PROMPT_PARSER_PATH = os.environ.get("PROMPT_PARSER_PATH", {PROMPT_PARSER_PATH!r})


def import_from_path(mod_name, path):
    if mod_name in sys.modules:
        del sys.modules[mod_name]
    loader = importlib.machinery.SourceFileLoader(mod_name, path)
    spec = importlib.util.spec_from_loader(loader.name, loader)
    mod = importlib.util.module_from_spec(spec)
    loader.exec_module(mod)
    return mod


class TestPromptParserGenerated(unittest.TestCase):
    ENV = {dict(env)!r}

    @classmethod
    def setUpClass(cls):
        # Feature flags must be set before the parser module is executed.
        cls._saved_env = {{k: os.environ.get(k) for k in cls.ENV}}
        os.environ.update(cls.ENV)
        cls.pp = import_from_path("pp", PROMPT_PARSER_PATH)

    @classmethod
    def tearDownClass(cls):
        # Restore the previous environment so other tests are unaffected.
        for k, v in cls._saved_env.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v

    def test_parse_and_resolve(self):
        pp = self.pp
        src = {prompt_src!r}
        tree = pp.schedule_parser.parse(src)
        resolved = pp.resolve_tree(tree, keep_spacing=True)
        self.assertEqual(resolved, {resolved!r})

    def test_schedule_sample(self):
        pp = self.pp
        src = {prompt_src!r}
        steps = {int(steps)}
        collector = pp.CollectSteps(steps, prefix="", suffix="", depth=0, use_scheduling=True, seed={seed!r})
        schedules = collector.visit(pp.schedule_parser.parse(src)) or []
        got = [(int(b), str(t)) for b, t in schedules]
        sample_got = (got[:3] + got[-3:]) if len(got) > 6 else got
        self.assertEqual(sample_got, {sched_sample!r})


if __name__ == "__main__":
    unittest.main()
'''
    return code
|
|
def on_gen(prompt, steps, seed, ae, ex, gl, sc):
    """Generate a unittest module for *prompt* and save it for download.

    Reloads the parser with the selected feature flags, analyzes the prompt
    with the same flags, renders the test code, and writes it to the system
    temp directory so the ``gr.File`` output can serve it.

    Args:
        prompt: Prompt text (falsy -> empty string).
        steps: Step count from the slider.
        seed: Seed from ``gr.Number``; ``None``/``""`` means "no seed".
        ae, ex, sc: ALLOW_EMPTY_ALTERNATE / EXPAND_ALTERNATE_PER_STEP /
            SUPPRESS_STANDALONE_COLON checkbox values.
        gl: GROUP_COMBO_LIMIT from ``gr.Number``.

    Returns:
        ``(code, path)``. *path* is ``None`` if the file could not be written.
    """
    import tempfile  # local import: the module header does not import tempfile

    seed_val = None if seed in (None, "") else int(seed)
    # gr.Number may yield None when the field is cleared; fall back to the
    # UI default of 100 instead of crashing in int(None).
    group_limit = 100 if gl in (None, "") else int(gl)
    allow_empty, expand_alt, suppress_colon = bool(ae), bool(ex), bool(sc)

    reload_parser_after_env(allow_empty, expand_alt, group_limit, suppress_colon)
    res = analyze_prompt(
        text=prompt or "", steps=int(steps), use_preprocessor=True, visitor_mode=True,
        seed=seed_val,
        allow_empty_alt=allow_empty, expand_alt_per_step=expand_alt,
        group_combo_limit=group_limit, suppress_standalone_colon=suppress_colon,
    )
    env = {
        "ALLOW_EMPTY_ALTERNATE": "1" if allow_empty else "0",
        "EXPAND_ALTERNATE_PER_STEP": "1" if expand_alt else "0",
        "GROUP_COMBO_LIMIT": str(group_limit),
        "SUPPRESS_STANDALONE_COLON": "1" if suppress_colon else "0",
    }
    code = generate_unittests_code(prompt or "", int(steps), env, seed_val, res)

    # The previous hard-coded /mnt/data path normally does not exist on a
    # WebUI host, so the download was silently always empty.
    path = os.path.join(tempfile.gettempdir(), "test_prompt_parser_generated.py")
    try:
        with open(path, "w", encoding="utf-8") as f:
            f.write(code)
    except OSError:
        # Best effort: still show the code even if it cannot be saved.
        return code, None
    return code, path
|
|
| gen_btn.click(on_gen, |
| inputs=[prompt_tb, tg_steps, tg_seed, tg_allow_empty, tg_expand_alt, tg_group_limit, tg_suppr_colon], |
| outputs=[test_code_tb, test_file]) |
|
|
| |
| return [(demo, "Prompt Refactor", "prompt_refactor_tab")] |
|
|
| |
def _register_prompt_refactor_tab() -> None:
    """Register ``on_ui_tabs`` with the WebUI's UI-tab callback hook.

    ``script_callbacks`` is ``None`` when ``modules.script_callbacks`` could
    not be imported (e.g. when this file runs outside A1111); nothing is
    registered in that case.
    """
    if script_callbacks is None:
        return
    script_callbacks.on_ui_tabs(on_ui_tabs)


_register_prompt_refactor_tab()
|
|