|
|
""" |
|
|
Prompt Refactor Plus — вкладка для A1111 (Stable Diffusion WebUI). |
|
|
|
|
|
Почему раньше таб не показывался: |
|
|
- Нельзя вызывать gradio.launch() внутри скрипта расширения. |
|
|
- Скрипт должен регистрировать вкладку через script_callbacks.on_ui_tabs. |
|
|
- В A1111 падает dataclasses со строковыми аннотациями, поэтому НЕ используем |
|
|
`from __future__ import annotations`. |
|
|
|
|
|
Что сделано: |
|
|
- Перенос в on_ui_tabs() + регистрация callback. |
|
|
- Безопасная загрузка внешнего парсера из PROMPT_PARSER_PATH. Ошибка не ломает UI. |
|
|
- Базовые параметры в явном виде. Остальные спрятаны в «Advanced». |
|
|
- Анализ промпта, QuickFix, Wrap Wizard, Timeline Wizard, Token Counter. |
|
|
- Дополнения: Curve Editor, Prob-Alternate Wizard, Test Generator — в Advanced. |
|
|
- Кнопка Reload parser для горячей подгрузки парсера после правки файла. |
|
|
|
|
|
Совместимость: |
|
|
- Python 3.10.x (как в A1111). |
|
|
- Gradio 3.x, поставляемый с WebUI. |
|
|
- Внешние пакеты (transformers/open_clip) опциональны: есть «approx» fallback. |
|
|
|
|
|
Файл предназначен для пути: extensions/<your-ext>/scripts/prompt-refactor.py |
|
|
""" |
|
|
|
|
|
import os |
|
|
import re |
|
|
import csv |
|
|
import json |
|
|
import math |
|
|
import importlib |
|
|
import importlib.util |
|
|
import importlib.machinery |
|
|
from dataclasses import dataclass |
|
|
from typing import List, Tuple, Optional, Any, Dict |
|
|
|
|
|
|
|
|
# A1111 runtime import: script_callbacks is what lets us register the tab.
# Outside the WebUI (e.g. standalone testing) the import fails and the module
# still loads with script_callbacks = None, so nothing crashes at import time.
try:
    from modules import script_callbacks
except Exception as _e:
    script_callbacks = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Filesystem path of the external prompt parser; override via the
# PROMPT_PARSER_PATH environment variable (default targets a Colab layout).
PROMPT_PARSER_PATH = os.environ.get("PROMPT_PARSER_PATH", "/content/A1111/modules/prompt_parser.py")
# Name the external parser is registered under in sys.modules (used so a
# reload can evict the stale copy first).
PP_MOD_NAME = "pp_prompt_parser"
# Loaded parser module (None until/unless loading succeeds) and the last
# load error message ("" when the load was OK).
_pp = None
_pp_error = ""
|
|
|
|
|
def import_from_path(mod_name: str, path: str):
    """Load a Python module from an explicit file path under the given name.

    Uses a SourceFileLoader directly so the file does not have to live on
    sys.path; the module object is returned without being registered in
    sys.modules.
    """
    file_loader = importlib.machinery.SourceFileLoader(mod_name, path)
    module_spec = importlib.util.spec_from_loader(file_loader.name, file_loader)
    module = importlib.util.module_from_spec(module_spec)
    file_loader.exec_module(module)
    return module
|
|
|
|
|
def load_parser_safely() -> Tuple[Optional[Any], str]:
    """Load the external prompt parser without breaking the UI.

    Returns (module, "") on success, or (None, error_message) when the file
    is missing, fails to execute, or lacks the required entry points.
    """
    if not os.path.isfile(PROMPT_PARSER_PATH):
        return None, f"Parser file not found: {PROMPT_PARSER_PATH}"
    try:
        import sys as _sys
        # Evict any stale copy so edits to the parser file take effect.
        _sys.modules.pop(PP_MOD_NAME, None)
        module = import_from_path(PP_MOD_NAME, PROMPT_PARSER_PATH)
        # Fail fast if the module lacks the two entry points we rely on.
        getattr(module, "schedule_parser")
        getattr(module, "resolve_tree")
        return module, ""
    except Exception as e:
        return None, f"{type(e).__name__}: {e}"
|
|
|
|
|
|
|
|
# Load the external parser once at import time; a failure is recorded in
# _pp_error and shown in the UI instead of breaking extension startup.
_pp, _pp_error = load_parser_safely()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EnvState:
    """Snapshot of the four parser-related environment variables.

    apply() pushes new values into os.environ and reports whether anything
    actually changed — a change means the parser must be reloaded for the
    settings to take effect.
    """

    def __init__(self):
        self.allow_empty_alt = os.environ.get("ALLOW_EMPTY_ALTERNATE", "0")
        self.expand_alt_per_step = os.environ.get("EXPAND_ALTERNATE_PER_STEP", "1")
        self.group_combo_limit = os.environ.get("GROUP_COMBO_LIMIT", "100")
        self.suppress_standalone_colon = os.environ.get("SUPPRESS_STANDALONE_COLON", "1")

    def apply(self, allow_empty_alt: bool, expand_alt: bool, combo_limit: int, suppress_colon: bool) -> bool:
        """Write the four settings to os.environ; return True if any differed."""
        # (attribute, env var, normalized string value) — one row per setting.
        updates = (
            ("allow_empty_alt", "ALLOW_EMPTY_ALTERNATE", "1" if allow_empty_alt else "0"),
            ("expand_alt_per_step", "EXPAND_ALTERNATE_PER_STEP", "1" if expand_alt else "0"),
            ("group_combo_limit", "GROUP_COMBO_LIMIT", str(int(combo_limit))),
            ("suppress_standalone_colon", "SUPPRESS_STANDALONE_COLON", "1" if suppress_colon else "0"),
        )
        changed = False
        for attr, env_key, value in updates:
            if value != getattr(self, attr):
                os.environ[env_key] = value
                setattr(self, attr, value)
                changed = True
        return changed
|
|
|
|
|
# Module-level singleton tracking the current env-var settings.
ENV_STATE = EnvState()
|
|
|
|
|
def reload_parser_after_env(allow_empty_alt: bool, expand_alt: bool, combo_limit: int, suppress_colon: bool):
    """Apply the env-var settings and reload the parser only if they changed."""
    global _pp, _pp_error
    if ENV_STATE.apply(allow_empty_alt, expand_alt, combo_limit, suppress_colon):
        _pp, _pp_error = load_parser_safely()
|
|
|
|
|
def reload_parser_manual() -> str:
    """Reload the parser on demand (the "Reload parser" button).

    Returns the load error text, or "OK" on success — shown in the status box.
    """
    global _pp, _pp_error
    _pp, _pp_error = load_parser_safely()
    return _pp_error or "OK"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _scan_balanced_parens(s: str, start_idx: int) -> Optional[Tuple[int, str]]: |
|
|
if start_idx >= len(s) or s[start_idx] != '(': |
|
|
return None |
|
|
depth = 0 |
|
|
i = start_idx |
|
|
buf: List[str] = [] |
|
|
while i < len(s): |
|
|
ch = s[i] |
|
|
if ch == '(': |
|
|
depth += 1 |
|
|
if depth > 1: buf.append(ch) |
|
|
elif ch == ')': |
|
|
depth -= 1 |
|
|
if depth == 0: return i + 1, "".join(buf) |
|
|
else: buf.append(ch) |
|
|
else: |
|
|
buf.append(ch) |
|
|
i += 1 |
|
|
return None |
|
|
|
|
|
def _pp_repeat_macro(src: str) -> Tuple[str, List[str]]:
    """Expand `repeat N ( body )` macros into N comma-joined copies of body.

    Returns (rewritten_text, change_descriptions). A macro whose parentheses
    never balance is left verbatim in the output.
    """
    pattern = re.compile(r"\brepeat\s+(\d+)\s*\(", re.IGNORECASE)
    pieces: List[str] = []
    notes: List[str] = []
    pos = 0
    while pos < len(src):
        match = pattern.search(src, pos)
        if match is None:
            pieces.append(src[pos:])
            break
        pieces.append(src[pos:match.start()])
        count = max(1, int(match.group(1)))
        scan = _scan_balanced_parens(src, match.end() - 1)
        if scan is None:
            # Unbalanced parens: keep the macro header verbatim and move on.
            pieces.append(src[match.start():match.end()])
            pos = match.end()
            continue
        after, inner = scan
        pieces.append(", ".join([inner.strip()] * count))
        notes.append(f"repeat {count} (...) → {count} повторов")
        pos = after
    return "".join(pieces), notes
|
|
|
|
|
def _split_top_level_bar(block: str) -> List[str]: |
|
|
parts: List[str] = []; buf: List[str] = [] |
|
|
d_round = d_square = d_curly = 0; i = 0 |
|
|
while i < len(block): |
|
|
ch = block[i] |
|
|
if ch == '\\': |
|
|
if i + 1 < len(block): buf.append(block[i:i+2]); i += 2; continue |
|
|
if ch == '(': d_round += 1 |
|
|
elif ch == ')': d_round = max(0, d_round - 1) |
|
|
elif ch == '[': d_square += 1 |
|
|
elif ch == ']': d_square = max(0, d_square - 1) |
|
|
elif ch == '{': d_curly += 1 |
|
|
elif ch == '}': d_curly = max(0, d_curly - 1) |
|
|
elif ch == '|' and d_round == d_square == d_curly == 0: |
|
|
parts.append("".join(buf)); buf = []; i += 1; continue |
|
|
buf.append(ch); i += 1 |
|
|
parts.append("".join(buf)) |
|
|
return [p.strip() for p in parts] |
|
|
|
|
|
def _pp_prob_alt(src: str, total_steps: int) -> Tuple[str, List[str]]:
    """Convert weighted alternates `a{p} | b{q}` into a scheduled block.

    Each weight becomes a share of total_steps and the run is rewritten as
    `[ a : b ] : boundary, ...` with boundaries at the cumulative shares.
    Returns (rewritten_text, change_descriptions).
    """
    changes: List[str] = []
    # Matches a run of "text{weight}" items joined by '|' (no newlines/brackets).
    token_rx = re.compile(r"((?:[^\n\[\]]+?\{\s*\d*\.?\d+\s*\}\s*\|\s*)+[^\n\[\]]+?\{\s*\d*\.?\d+\s*\})")
    def convert_block(block: str) -> str:
        # Split alternates on top-level '|' and pull each trailing {weight}.
        parts = _split_top_level_bar(block)
        items: List[Tuple[str, float]] = []; total = 0.0
        for p in parts:
            m = re.search(r"\{\s*([+\-]?\d*\.?\d+(?:[eE][+\-]?\d+)?)\s*\}\s*$", p)
            # If any part lacks a weight, leave the whole block untouched.
            if not m: return block
            w = float(m.group(1))
            text = re.sub(r"\{\s*[+\-]?\d*\.?\d+(?:[eE][+\-]?\d+)?\s*\}\s*$", "", p).strip()
            items.append((text, w)); total += w
        # Degenerate (zero/negative) total weight: nothing sensible to schedule.
        if total <= 0: return block
        # Cumulative weight shares → step boundaries, excluding 0 and the end.
        boundaries: List[str] = []; acc = 0.0
        for i in range(len(items) - 1):
            acc += items[i][1] / total
            b = int(round(acc * total_steps))
            if 0 < b < total_steps: boundaries.append(str(b))
        core = " : ".join([t for t, _ in items])
        extra = f" : {', '.join(boundaries)}" if boundaries else ""
        return f"[ {core} ]{extra}"
    def repl(m):
        orig = m.group(1)
        out = convert_block(orig)
        # Record a change note only when the block was actually rewritten.
        if out != orig: changes.append("prob a{p}|b{q} → scheduled")
        return out
    return token_rx.sub(repl, src), changes
|
|
|
|
|
def preprocess(text: str, total_steps: int = 20) -> Tuple[str, List[str]]:
    """Run the prompt preprocessors in order: repeat-macro, then prob-alternates.

    Returns the rewritten prompt plus the combined list of change descriptions.
    """
    out1, ch1 = _pp_repeat_macro(text)
    out2, ch2 = _pp_prob_alt(out1, total_steps)
    return out2, ch1 + ch2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class AnalysisResult:
    """Outcome of analyze_prompt(); later fields stay None after a failure."""
    ok: bool                                    # True when the full pipeline succeeded
    error: Optional[str]                        # message of the first failed stage
    parse_tree: Optional[Any]                   # raw tree from schedule_parser.parse
    resolved_text: Optional[str]                # prompt after resolve_tree
    schedule: Optional[List[Tuple[int, str]]]   # (boundary_step, text) pairs
    timeline: Optional[List[Tuple[int, str]]]   # dense per-step (step, text) rows
    changes: List[str]                          # preprocessor change descriptions
|
|
|
|
|
def analyze_prompt(
    text: str,
    steps: int = 20,
    use_preprocessor: bool = True,
    visitor_mode: bool = True,
    seed: Optional[int] = None,
    allow_empty_alt: bool = False,
    expand_alt_per_step: bool = True,
    group_combo_limit: int = 100,
    suppress_standalone_colon: bool = True,
) -> AnalysisResult:
    """Parse, resolve and schedule a prompt with the external parser.

    Pipeline: optional preprocessing → parse → resolve → per-step schedule →
    dense step timeline. Each stage failure returns early with ok=False and
    whatever partial results were produced up to that point.
    """
    # The parser may have failed to load at import time; report, don't crash.
    if _pp is None:
        return AnalysisResult(False, f"Parser not loaded: {_pp_error}", None, None, None, None, [])
    # Sync env vars first: any change forces a parser reload so it sees them.
    reload_parser_after_env(allow_empty_alt, expand_alt_per_step, group_combo_limit, suppress_standalone_colon)

    changes: List[str] = []
    src = text
    if use_preprocessor:
        src, ch = preprocess(src, steps)
        changes.extend(ch)

    try:
        tree = _pp.schedule_parser.parse(src)
    except Exception as e:
        return AnalysisResult(False, f"{e}", None, None, None, None, changes)

    try:
        resolved = _pp.resolve_tree(tree, keep_spacing=True)
    except Exception as e:
        return AnalysisResult(False, f"Resolve error: {e}", tree, None, None, None, changes)

    try:
        if visitor_mode:
            # Visitor API: builds the whole schedule in a single traversal.
            collector = _pp.CollectSteps(steps, prefix="", suffix="", depth=0, use_scheduling=True, seed=seed)
            schedules = collector.visit(tree) or []
        else:
            # Transformer API: re-transform the tree once per sampling step.
            schedules = []
            for step in range(1, steps + 1):
                transformer = _pp.ScheduleTransformer(total_steps=steps, current_step=step, seed=seed)
                txt = transformer.transform(tree)
                schedules.append((step, txt))
            schedules = [[b, t] for (b, t) in schedules]
    except Exception as e:
        return AnalysisResult(False, f"Schedule error: {e}", tree, resolved, None, None, changes)

    try:
        # Normalize to sorted (boundary, text) pairs, then expand into a dense
        # per-step timeline by carrying the current text forward.
        norm: List[Tuple[int, str]] = []
        for item in schedules:
            if isinstance(item, (list, tuple)) and len(item) == 2:
                b, t = int(item[0]), str(item[1]); norm.append((b, t))
        norm.sort(key=lambda x: x[0])
        timeline: List[Tuple[int, str]] = []
        idx = 0; cur = norm[0][1] if norm else ""
        for step in range(1, steps + 1):
            # NOTE(review): an entry is advanced once `step` passes its
            # boundary value, i.e. a boundary marks the LAST step of its text
            # — confirm against the parser's schedule semantics.
            while idx < len(norm) and step > norm[idx][0]:
                cur = norm[idx][1]; idx += 1
            timeline.append((step, cur))
    except Exception as e:
        return AnalysisResult(False, f"Timeline error: {e}", tree, resolved, [(int(b), str(t)) for b, t in schedules], None, changes)

    return AnalysisResult(True, None, tree, resolved, [(int(b), str(t)) for b, t in schedules], timeline, changes)
|
|
|
|
|
def schedule_to_csv(schedule: List[Tuple[int, str]]) -> str:
    """Serialize (boundary, text) pairs as CSV text with a header row."""
    from io import StringIO
    sink = StringIO()
    writer = csv.writer(sink)
    writer.writerow(["boundary", "text"])
    writer.writerows(schedule)
    return sink.getvalue()
|
|
|
|
|
def timeline_to_csv(timeline: List[Tuple[int, str]]) -> str:
    """Serialize per-step (step, text) rows as CSV text with a header row."""
    from io import StringIO
    sink = StringIO()
    writer = csv.writer(sink)
    writer.writerow(["step", "text"])
    writer.writerows(timeline)
    return sink.getvalue()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _close_unbalanced_brackets(src: str) -> str: |
|
|
opens = {'(': 0, '[': 0, '{': 0} |
|
|
for ch in src: |
|
|
if ch in opens: opens[ch] += 1 |
|
|
elif ch == ')': opens['('] = max(0, opens['('] - 1) |
|
|
elif ch == ']': opens['['] = max(0, opens['['] - 1) |
|
|
elif ch == '}': opens['{'] = max(0, opens['{'] - 1) |
|
|
return src + (')' * opens['(']) + (']' * opens['[']) + ('}' * opens['{']) |
|
|
|
|
|
def _escape_bars_inside_parens(src: str) -> str: |
|
|
out: List[str] = []; depth = 0; i = 0 |
|
|
while i < len(src): |
|
|
ch = src[i] |
|
|
if ch == '\\' and i + 1 < len(src): |
|
|
out.append(src[i:i+2]); i += 2; continue |
|
|
if ch == '(': |
|
|
depth += 1; out.append(ch) |
|
|
elif ch == ')': |
|
|
depth = max(0, depth - 1); out.append(ch) |
|
|
elif ch == '|' and depth > 0: |
|
|
out.append(r'\|') |
|
|
else: |
|
|
out.append(ch) |
|
|
i += 1 |
|
|
return "".join(out) |
|
|
|
|
|
def _replace_single_colon_labels_outside_schedules(src: str) -> str: |
|
|
segments: List[Tuple[bool, str]] = []; inside = 0; buf: List[str] = [] |
|
|
for ch in src: |
|
|
if ch == '[': |
|
|
if buf: segments.append((False, "".join(buf))); buf = [] |
|
|
inside += 1; segments.append((True, ch)) |
|
|
elif ch == ']': |
|
|
segments.append((True, ch)); inside = max(0, inside - 1) |
|
|
else: |
|
|
if inside: segments.append((True, ch)) |
|
|
else: buf.append(ch) |
|
|
if buf: segments.append((False, "".join(buf))) |
|
|
|
|
|
def repl(text: str) -> str: |
|
|
pattern = re.compile(r'(?<!:)\b([A-Za-z0-9_]+)\s*:\s*(?!:)\b([A-Za-z0-9_]+)\b') |
|
|
return pattern.sub(r'\1::\2', text) |
|
|
|
|
|
out: List[str] = [] |
|
|
for inside_sq, chunk in segments: |
|
|
out.append(chunk if inside_sq else repl(chunk)) |
|
|
return "".join(out) |
|
|
|
|
|
def quickfix_suggestions(src: str) -> List[Dict[str, str]]:
    """Build the list of applicable QuickFix actions for a prompt string.

    Each suggestion is {"id": ..., "title": ...}; "normalize_ws" is always
    offered, the others depend on the prompt content / env settings.
    """
    suggestions: List[Dict[str, str]] = []
    bracket_pairs = (('(', ')'), ('[', ']'), ('{', '}'))
    if any(src.count(left) > src.count(right) for left, right in bracket_pairs):
        suggestions.append({"id": "close_brackets", "title": "Закрыть незакрытые скобки"})
    if '|' in src:
        suggestions.append({"id": "escape_bars_in_parens", "title": "Экранировать | внутри (...)"})
    # Only meaningful while the parser treats single colons as schedule syntax.
    if os.environ.get("SUPPRESS_STANDALONE_COLON", "1") == "1":
        suggestions.append({"id": "labels_double_colon", "title": "word:word → word::word вне [..]"})
    suggestions.append({"id": "normalize_ws", "title": "Нормализовать пробелы"})
    return suggestions
|
|
|
|
|
def apply_quickfix(src: str, fix_id: str) -> str:
    """Apply one QuickFix (by id) to src; unknown ids return src unchanged."""
    if fix_id == "close_brackets":
        return _close_unbalanced_brackets(src)
    if fix_id == "escape_bars_in_parens":
        return _escape_bars_inside_parens(src)
    if fix_id == "labels_double_colon":
        return _replace_single_colon_labels_outside_schedules(src)
    if fix_id == "normalize_ws":
        # Collapse all runs of whitespace (including newlines) to single spaces.
        flat = src.replace('\r\n', '\n').replace('\r', '\n')
        flat = re.sub(r'[ \t]+', ' ', flat)
        flat = re.sub(r' *\n *', ' ', flat)
        flat = re.sub(r' {2,}', ' ', flat)
        return flat.strip()
    return src
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Prompt-syntax snippets offered in the Editor tab's template palette:
# key = dropdown label, value = text inserted into the prompt.
PALETTE_TEMPLATES: Dict[str, str] = {
    "Emphasis (1.2)": "(text:1.2)",
    "Emphasis range": "(text:start:end)",
    "Square schedule": "[ a : b ] : 20",
    "Triple schedule": "[ a : b : c ] : 20, 40",
    "Alternate": "a | b",
    "Prob alternate": "a{0.5} | b{0.5}",
    "Group label": "owner::label::description",
    "AND": "AND( term1, term2 )",
    "BREAK": "BREAK",
}
|
|
|
|
|
def palette_insert(cur_prompt: str, key: str) -> str:
    """Append the palette template selected by `key` to the current prompt.

    A ", " separator is inserted unless the prompt is empty or already ends
    with a space, newline or comma. An unknown key leaves the prompt
    unchanged (previously it appended a stray ", ").
    """
    tpl = PALETTE_TEMPLATES.get(key, "")
    if not tpl:
        # Bug fix: an unknown/empty template used to append a dangling ", ".
        return cur_prompt or ""
    sep = "" if (not cur_prompt or cur_prompt.endswith((" ", "\n", ","))) else ", "
    return (cur_prompt or "") + sep + tpl
|
|
|
|
|
def _split_top_level_comma(block: str) -> List[str]: |
|
|
parts: List[str] = []; buf: List[str] = [] |
|
|
d_round = d_square = d_curly = 0; i = 0 |
|
|
while i < len(block): |
|
|
ch = block[i] |
|
|
if ch == '\\' and i + 1 < len(block): |
|
|
buf.append(block[i:i+2]); i += 2; continue |
|
|
if ch == '(': d_round += 1 |
|
|
elif ch == ')': d_round = max(0, d_round - 1) |
|
|
elif ch == '[': d_square += 1 |
|
|
elif ch == ']': d_square = max(0, d_square - 1) |
|
|
elif ch == '{': d_curly += 1 |
|
|
elif ch == '}': d_curly = max(0, d_curly - 1) |
|
|
elif ch == ',' and d_round == d_square == d_curly == 0: |
|
|
parts.append("".join(buf)); buf = []; i += 1; continue |
|
|
buf.append(ch); i += 1 |
|
|
parts.append("".join(buf)) |
|
|
return [p.strip() for p in parts if p.strip()] |
|
|
|
|
|
def segments_from_text(src: str, mode: str = "auto") -> List[str]:
    """Split src into prompt segments.

    mode "lines" splits on newlines; "comma" on top-level commas; "auto"
    picks lines when the text is multi-line, otherwise commas.
    """
    text = (src or "").strip()
    if not text:
        return []
    if mode == "lines" or (mode == "auto" and '\n' in text):
        return [line.strip() for line in text.splitlines() if line.strip()]
    # "comma" mode and single-line "auto" both use the comma splitter.
    return _split_top_level_comma(text)
|
|
|
|
|
def wrap_sequence(src: str, seg_mode: str = "auto") -> str:
    """Join the segments of src into one flat comma-separated sequence."""
    return ", ".join(segments_from_text(src, seg_mode))
|
|
|
|
|
def wrap_top_level2(src: str, steps: int = 20, boundary: Optional[int] = None, seg_mode: str = "auto") -> str:
    """Wrap the first two segments into a `[ a : b ] : N` schedule block.

    With fewer than two segments the list is doubled so there are two; the
    boundary defaults to the midpoint of `steps`, clamped to [1, steps-1].
    """
    segments = segments_from_text(src, seg_mode)
    if len(segments) < 2:
        segments = segments + segments
    head = " : ".join(segments[:2])
    if boundary is None:
        boundary = max(1, min(steps - 1, round(steps / 2)))
    return f"[ {head} ] : {boundary}"
|
|
|
|
|
def wrap_top_level3(src: str, steps: int = 30, boundaries: Optional[Tuple[int, int]] = None, seg_mode: str = "auto") -> str:
    """Wrap the first three segments into a `[ a : b : c ] : b1, b2` block.

    Missing segments are padded with seg<N> placeholders; default boundaries
    sit at one third and two thirds of `steps` (kept strictly increasing).
    """
    segments = segments_from_text(src, seg_mode)
    if len(segments) < 3:
        segments = segments + [""] * (3 - len(segments))
    segments = [seg or f"seg{pos + 1}" for pos, seg in enumerate(segments)]
    head = " : ".join(segments[:3])
    if boundaries is None:
        b1 = max(1, round(steps / 3))
        b2 = max(b1 + 1, round(2 * steps / 3))
    else:
        b1, b2 = int(boundaries[0]), int(boundaries[1])
    return f"[ {head} ] : {b1}, {b2}"
|
|
|
|
|
def wrap_alternate(src: str, seg_mode: str = "auto") -> str:
    """Join the segments of src into an `a | b | ...` alternate expression."""
    return " | ".join(segments_from_text(src, seg_mode))
|
|
|
|
|
def wrap_prob_equal(src: str, weight: float = 1.0, seg_mode: str = "auto") -> str:
    """Join segments as alternates, tagging each with the same {weight}."""
    segs = segments_from_text(src, seg_mode)
    if not segs: return ""
    # %g formatting drops trailing zeros (1.0 → "1").
    w = f"{float(weight):g}"
    return " | ".join([f"{s}{{{w}}}" for s in segs])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ease_in(t: float, p: float) -> float: return t ** max(1.0, p) |
|
|
def _ease_out(t: float, p: float) -> float: return 1.0 - (1.0 - t) ** max(1.0, p) |
|
|
def _ease_in_out(t: float) -> float: return 0.5 * (1 - math.cos(math.pi * t)) |
|
|
def _cosine(t: float) -> float: return (1 - math.cos(math.pi * t)) / 2.0 |
|
|
def _cubic_bezier_y(t: float, x1: float, y1: float, x2: float, y2: float) -> float: |
|
|
mt = 1 - t |
|
|
return 3 * mt * mt * t * y1 + 3 * mt * t * t * y2 + t**3 |
|
|
|
|
|
def _curve_f(t: float, kind: str, p: float = 2.0, bezier: Optional[Tuple[float,float,float,float]] = None) -> float: |
|
|
t = max(0.0, min(1.0, float(t))) |
|
|
if kind == "linear": return t |
|
|
if kind == "ease-in": return _ease_in(t, p) |
|
|
if kind == "ease-out": return _ease_out(t, p) |
|
|
if kind == "ease-in-out": return _ease_in_out(t) |
|
|
if kind == "cosine": return _cosine(t) |
|
|
if kind == "bezier" and bezier: |
|
|
x1,y1,x2,y2 = bezier |
|
|
return _cubic_bezier_y(t, x1, y1, x2, y2) |
|
|
return t |
|
|
|
|
|
def _inv_curve(u: float, kind: str, steps: int = 60, p: float = 2.0, bezier: Optional[Tuple[float,float,float,float]] = None) -> float:
    """Invert a monotone easing curve by bisection: find t with f(t) ≈ u.

    Runs a fixed number of halvings (default 60), which is far below float
    precision, then returns the interval midpoint.
    """
    low, high = 0.0, 1.0
    for _ in range(steps):
        mid = 0.5 * (low + high)
        if _curve_f(mid, kind, p=p, bezier=bezier) < u:
            low = mid
        else:
            high = mid
    return 0.5 * (low + high)
|
|
|
|
|
def curve_boundaries(num_segments: int, total_steps: int, kind: str, p: float = 2.0, bezier: Optional[Tuple[float,float,float,float]] = None) -> List[int]:
    """Place num_segments-1 step boundaries along an easing curve.

    Equal shares of the segment count are mapped through the inverse curve
    onto [0, total_steps]; results are clamped to [1, total_steps-1] and
    nudged to stay strictly increasing.
    """
    num_segments = max(2, int(num_segments))
    total_steps = max(2, int(total_steps))
    # Equal cumulative shares: 1/N, 2/N, ..., (N-1)/N.
    targets = [i / num_segments for i in range(1, num_segments)]
    bounds: List[int] = []
    for u in targets:
        t = _inv_curve(u, kind=kind, p=p, bezier=bezier)
        b = int(round(t * total_steps))
        b = max(1, min(total_steps - 1, b))
        # Force strict monotonicity when rounding collapses neighbors.
        if bounds and b <= bounds[-1]: b = min(total_steps - 1, bounds[-1] + 1)
        bounds.append(b)
    return bounds
|
|
|
|
|
def build_curve_schedule(segments_text: str, num_segments: int, total_steps: int, kind: str,
                         p: float = 2.0, bezier: Optional[Tuple[float,float,float,float]] = None) -> Tuple[str, str]:
    """Build a `[ a : b : ... ] : b1, b2` block with curve-placed boundaries.

    Segment names come from segments_text (one per line); the list is padded
    with seg<N> placeholders or truncated to exactly num_segments entries.
    Returns (schedule_block, boundaries_as_csv_string).
    """
    segs = [s.strip() for s in (segments_text or "").splitlines() if s.strip()]
    if not segs:
        segs = [f"seg{i+1}" for i in range(num_segments)]
    else:
        if len(segs) < num_segments:
            segs.extend([f"seg{i+1}" for i in range(len(segs), num_segments)])
        elif len(segs) > num_segments:
            segs = segs[:num_segments]
    bounds = curve_boundaries(num_segments, total_steps, kind=kind, p=p, bezier=bezier)
    core = " : ".join(segs)
    csv_bounds = ", ".join(str(b) for b in bounds)
    block = f"[ {core} ]" + (f" : {csv_bounds}" if csv_bounds else "")
    return block, csv_bounds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class TokenCountResult:
    """Result of tokenizing a prompt against a CLIP-style payload limit.

    Converted to a dataclass for consistency with AnalysisResult; the
    positional constructor signature is unchanged.
    """
    backend: str        # tokenizer actually used: transformers/open_clip/approx
    total: int          # token count including special tokens (when present)
    payload: int        # tokens that count against the limit
    limit: int          # payload limit that was applied
    overflow: int       # payload tokens beyond the limit (0 if it fits)
    trimmed_text: str   # text cut down to roughly fit the limit
|
|
|
|
|
def _tokenize_transformers(text: str):
    """Tokenize with the HuggingFace CLIP tokenizer (requires `transformers`).

    Returns (token_ids, per-token character offsets, "transformers"). The ids
    include the special tokens added by the tokenizer.
    """
    from transformers import AutoTokenizer
    tok = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14", use_fast=True)
    enc = tok(text, add_special_tokens=True, return_offsets_mapping=True)
    ids = enc["input_ids"]
    offsets = [(int(a), int(b)) for (a, b) in enc["offset_mapping"]]
    return ids, offsets, "transformers"
|
|
|
|
|
def _tokenize_open_clip(text: str):
    """Tokenize with open_clip's ViT-L-14 tokenizer (requires `open_clip`).

    Returns (token_ids, character_spans, "open_clip").

    NOTE(review): the character spans are approximated from whitespace words,
    not from actual BPE pieces, so they may not line up index-for-index with
    the returned ids — verify before trusting trim positions from this path.
    """
    import open_clip
    tokenizer = open_clip.get_tokenizer("ViT-L-14")
    ids = tokenizer(text)
    offs = [(m.start(), m.end()) for m in re.finditer(r"\S+", text)]
    # The tokenizer returns a (1, context_len) tensor; flatten it to a list.
    if hasattr(ids, "squeeze"):
        ids = ids.squeeze(0).tolist()
    return list(map(int, ids)), offs, "open_clip"
|
|
|
|
|
def _tokenize_approx(text: str): |
|
|
toks = []; offs = [] |
|
|
for m in re.finditer(r"\w+|[^\s\w]", text, flags=re.UNICODE): |
|
|
toks.append(m.group(0)); offs.append((m.start(), m.end())) |
|
|
return toks, offs, "approx" |
|
|
|
|
|
def count_clip_tokens(text: str, limit_payload: int = 77, backend: str = "auto") -> TokenCountResult:
    """Count prompt tokens against a CLIP-style payload limit.

    Tries the requested backend(s) in the order transformers → open_clip →
    approx; a backend that raises (typically because its package is missing)
    is silently skipped, so the regex fallback always yields a result.
    Returns a TokenCountResult with counts, overflow beyond limit_payload,
    and the text trimmed to (roughly) fit the limit.
    """
    try:
        if backend in ("auto", "transformers"):
            ids, offsets, used = _tokenize_transformers(text)
            # Payload excludes the two special tokens when they are present.
            total = len(ids); payload = max(0, total - 2) if total >= 2 else total
            keep_payload = min(payload, limit_payload)
            keep_ids = keep_payload + (2 if total >= 2 else 0)
            cut_char = None
            if keep_payload < payload and offsets:
                # NOTE(review): indexes offsets of the full id sequence;
                # assumes special tokens carry (0, 0) offsets — confirm.
                cut_index = keep_ids - 1
                if 0 <= cut_index < len(offsets): cut_char = offsets[cut_index][1]
            trimmed = text[:cut_char] if cut_char is not None else text
            overflow = max(0, payload - limit_payload)
            return TokenCountResult(used, total, payload, limit_payload, overflow, trimmed)
    except Exception:
        pass  # backend unavailable or failed: fall through to the next one
    try:
        if backend in ("auto", "open_clip"):
            ids, offsets, used = _tokenize_open_clip(text)
            total = len(ids); payload = max(0, total - 2) if total >= 2 else total
            keep_payload = min(payload, limit_payload)
            keep_ids = keep_payload + (2 if total >= 2 else 0)
            cut_index = min(keep_ids - 1, len(offsets) - 1)
            cut_char = offsets[cut_index][1] if keep_payload < payload and cut_index >= 0 else None
            trimmed = text[:cut_char] if cut_char is not None else text
            overflow = max(0, payload - limit_payload)
            return TokenCountResult(used, total, payload, limit_payload, overflow, trimmed)
    except Exception:
        pass  # fall through to the regex approximation
    # Approx backend: every regex token counts toward the payload (no specials).
    toks, offsets, used = _tokenize_approx(text)
    total = len(toks); payload = total
    keep_payload = min(payload, limit_payload)
    cut_index = keep_payload - 1
    cut_char = offsets[cut_index][1] if keep_payload < payload and 0 <= cut_index < len(offsets) else None
    trimmed = text[:cut_char] if cut_char is not None else text
    overflow = max(0, payload - limit_payload)
    return TokenCountResult(used, total, payload, limit_payload, overflow, trimmed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def on_ui_tabs(): |
|
|
import gradio as gr |
|
|
|
|
|
with gr.Blocks(analytics_enabled=False) as demo: |
|
|
gr.Markdown("## Prompt Refactor Plus") |
|
|
|
|
|
|
|
|
with gr.Row(): |
|
|
parser_status = gr.Textbox(label="Parser status", value=_pp_error or "OK", interactive=False) |
|
|
reload_btn = gr.Button("Reload parser") |
|
|
def _reload(): |
|
|
return reload_parser_manual() |
|
|
reload_btn.click(_reload, outputs=[parser_status]) |
|
|
|
|
|
|
|
|
with gr.Tab("Editor"): |
|
|
prompt_tb = gr.Textbox(label="Prompt", lines=8, placeholder="Введите промпт...") |
|
|
with gr.Row(): |
|
|
palette_dd = gr.Dropdown(choices=list(PALETTE_TEMPLATES.keys()), value="Emphasis (1.2)", label="Шаблон") |
|
|
insert_btn = gr.Button("Вставить шаблон") |
|
|
insert_btn.click(palette_insert, inputs=[prompt_tb, palette_dd], outputs=[prompt_tb]) |
|
|
|
|
|
|
|
|
with gr.Tab("Analysis"): |
|
|
with gr.Row(): |
|
|
steps = gr.Slider(1, 200, value=20, step=1, label="Steps") |
|
|
use_pre = gr.Checkbox(value=True, label="Use Preprocessor") |
|
|
suppress_colon = gr.Checkbox(value=True, label="SUPPRESS_STANDALONE_COLON") |
|
|
analyze_btn = gr.Button("Analyze") |
|
|
with gr.Accordion("Advanced", open=False): |
|
|
with gr.Row(): |
|
|
visitor = gr.Checkbox(value=True, label="Visitor mode") |
|
|
seed = gr.Number(value=None, label="Seed", precision=0) |
|
|
with gr.Row(): |
|
|
allow_empty_alt = gr.Checkbox(value=False, label="ALLOW_EMPTY_ALTERNATE") |
|
|
expand_alt = gr.Checkbox(value=True, label="EXPAND_ALTERNATE_PER_STEP") |
|
|
group_limit = gr.Number(value=100, label="GROUP_COMBO_LIMIT", precision=0) |
|
|
|
|
|
with gr.Tabs(): |
|
|
with gr.Tab("Resolved"): |
|
|
parse_out = gr.Textbox(label="", lines=4) |
|
|
changes_out = gr.Textbox(label="Preprocessor Changes", lines=3) |
|
|
with gr.Tab("Schedule CSV"): |
|
|
schedule_out = gr.Textbox(label="", lines=8) |
|
|
with gr.Tab("Timeline CSV"): |
|
|
timeline_out = gr.Textbox(label="", lines=8) |
|
|
|
|
|
|
|
|
gr.Markdown("### QuickFix") |
|
|
fixes_state = gr.State(value=[]) |
|
|
with gr.Row(): |
|
|
fixes_dd = gr.Dropdown(choices=[], label="Выберите исправление") |
|
|
apply_fix_btn = gr.Button("Применить") |
|
|
auto_fix_btn = gr.Button("Auto QuickFix") |
|
|
|
|
|
def _analyze(prompt, steps, use_pre, visitor, seed, allow_empty_alt, expand_alt, group_limit, suppress_colon): |
|
|
res = analyze_prompt( |
|
|
text=prompt or "", |
|
|
steps=int(steps), |
|
|
use_preprocessor=bool(use_pre), |
|
|
visitor_mode=bool(visitor), |
|
|
seed=None if seed in (None, "") else int(seed), |
|
|
allow_empty_alt=bool(allow_empty_alt), |
|
|
expand_alt_per_step=bool(expand_alt), |
|
|
group_combo_limit=int(group_limit), |
|
|
suppress_standalone_colon=bool(suppress_colon), |
|
|
) |
|
|
sgs = quickfix_suggestions(prompt or "") |
|
|
return ( |
|
|
res.resolved_text or "", |
|
|
"\n".join(res.changes), |
|
|
schedule_to_csv(res.schedule) if res.schedule else "", |
|
|
timeline_to_csv(res.timeline) if res.timeline else "", |
|
|
sgs, |
|
|
gr.update(choices=[s["title"] for s in sgs]), |
|
|
_pp_error or "OK" |
|
|
) |
|
|
|
|
|
analyze_btn.click( |
|
|
_analyze, |
|
|
inputs=[prompt_tb, steps, use_pre, visitor, seed, allow_empty_alt, expand_alt, group_limit, suppress_colon], |
|
|
outputs=[parse_out, changes_out, schedule_out, timeline_out, fixes_state, fixes_dd, parser_status] |
|
|
) |
|
|
|
|
|
def _apply_fix_cur(prompt: str, fixes: List[Dict[str, str]], title: str): |
|
|
m = {f["title"]: f["id"] for f in fixes} |
|
|
fid = m.get(title) |
|
|
return apply_quickfix(prompt or "", fid) if fid else prompt |
|
|
apply_fix_btn.click(_apply_fix_cur, inputs=[prompt_tb, fixes_state, fixes_dd], outputs=[prompt_tb]) |
|
|
|
|
|
def _auto_fix(prompt: str): |
|
|
sgs = quickfix_suggestions(prompt or "") |
|
|
if not sgs: return prompt |
|
|
return apply_quickfix(prompt or "", sgs[0]["id"]) |
|
|
auto_fix_btn.click(_auto_fix, inputs=[prompt_tb], outputs=[prompt_tb]) |
|
|
|
|
|
|
|
|
with gr.Tab("Wrap Wizard"): |
|
|
src_tb = gr.Textbox(label="Исходный текст", lines=6, placeholder="Один сегмент на строку или через верхнеуровневые запятые") |
|
|
seg_mode_dd = gr.Radio(choices=["auto", "lines", "comma"], value="auto", label="Разделение") |
|
|
with gr.Row(): |
|
|
ww_steps = gr.Slider(2, 200, value=20, step=1, label="Steps (для top-level)") |
|
|
ww_boundary = gr.Number(value=None, label="Boundary (Top-Level 2)", precision=0) |
|
|
ww_b1 = gr.Number(value=None, label="b1 (Top-Level 3)", precision=0) |
|
|
ww_b2 = gr.Number(value=None, label="b2 (Top-Level 3)", precision=0) |
|
|
with gr.Row(): |
|
|
seq_btn = gr.Button("Sequence") |
|
|
tl2_btn = gr.Button("Top-Level (2)") |
|
|
tl3_btn = gr.Button("Top-Level (3)") |
|
|
with gr.Row(): |
|
|
alt_btn = gr.Button("Alternate") |
|
|
prob_btn = gr.Button("Prob-Equal") |
|
|
wrap_preview = gr.Textbox(label="Предпросмотр", lines=3) |
|
|
insert_wrap_btn = gr.Button("Вставить в Prompt") |
|
|
|
|
|
def on_wrap_seq(src, seg_mode): return wrap_sequence(src or "", seg_mode) |
|
|
def on_wrap_tl2(src, steps, boundary, seg_mode): |
|
|
b = None |
|
|
try: |
|
|
if boundary not in (None, ""): b = int(boundary) |
|
|
except Exception: b = None |
|
|
return wrap_top_level2(src or "", steps=int(steps), boundary=b, seg_mode=seg_mode) |
|
|
def on_wrap_tl3(src, steps, b1, b2, seg_mode): |
|
|
bb = None |
|
|
try: |
|
|
if b1 not in (None, "") and b2 not in (None, ""): bb = (int(b1), int(b2)) |
|
|
except Exception: bb = None |
|
|
return wrap_top_level3(src or "", steps=int(steps), boundaries=bb, seg_mode=seg_mode) |
|
|
def on_wrap_alt(src, seg_mode): return wrap_alternate(src or "", seg_mode) |
|
|
def on_wrap_prob(src, seg_mode): return wrap_prob_equal(src or "", 1.0, seg_mode) |
|
|
|
|
|
seq_btn.click(on_wrap_seq, inputs=[src_tb, seg_mode_dd], outputs=[wrap_preview]) |
|
|
tl2_btn.click(on_wrap_tl2, inputs=[src_tb, ww_steps, ww_boundary, seg_mode_dd], outputs=[wrap_preview]) |
|
|
tl3_btn.click(on_wrap_tl3, inputs=[src_tb, ww_steps, ww_b1, ww_b2, seg_mode_dd], outputs=[wrap_preview]) |
|
|
alt_btn.click(on_wrap_alt, inputs=[src_tb, seg_mode_dd], outputs=[wrap_preview]) |
|
|
prob_btn.click(on_wrap_prob, inputs=[src_tb, seg_mode_dd], outputs=[wrap_preview]) |
|
|
|
|
|
def _insert_block(cur_prompt: str, block: str) -> str: |
|
|
if not block or not block.strip(): return cur_prompt |
|
|
sep = "" if (not cur_prompt or cur_prompt.endswith((" ", "\n", ","))) else ", " |
|
|
return (cur_prompt or "") + sep + block |
|
|
insert_wrap_btn.click(_insert_block, inputs=[prompt_tb, wrap_preview], outputs=[prompt_tb]) |
|
|
|
|
|
|
|
|
with gr.Tab("Timeline Wizard"): |
|
|
seg_tb = gr.Textbox(label="Сегменты (по одному на строку)", lines=4, value="a\nb") |
|
|
boundaries_tb = gr.Textbox(label="Границы (CSV)", value="20") |
|
|
build_tl_btn = gr.Button("Построить") |
|
|
tl_preview = gr.Textbox(label="Предпросмотр", lines=2) |
|
|
insert_tl_btn = gr.Button("Вставить в Prompt") |
|
|
|
|
|
def build_schedule_block(segments_text: str, boundaries_csv: str) -> str: |
|
|
segs = [s.strip() for s in (segments_text or "").splitlines() if s.strip()] |
|
|
if not segs: return "" |
|
|
core = " : ".join(segs) |
|
|
bounds = [b.strip() for b in (boundaries_csv or "").split(",") if b.strip()] |
|
|
extra = f" : {', '.join(bounds)}" if bounds else "" |
|
|
return f"[ {core} ]{extra}" |
|
|
|
|
|
build_tl_btn.click(build_schedule_block, inputs=[seg_tb, boundaries_tb], outputs=[tl_preview]) |
|
|
insert_tl_btn.click(_insert_block, inputs=[prompt_tb, tl_preview], outputs=[prompt_tb]) |
|
|
|
|
|
|
|
|
with gr.Tab("Token Counter"): |
|
|
payload_limit = gr.Slider(10, 200, value=77, step=1, label="Payload limit") |
|
|
backend_dd = gr.Dropdown(choices=["auto", "transformers", "open_clip", "approx"], value="auto", label="Backend (Advanced)") |
|
|
count_btn = gr.Button("Count") |
|
|
with gr.Row(): |
|
|
backend_used = gr.Textbox(label="Used backend", interactive=False) |
|
|
total_out = gr.Number(label="Total tokens", interactive=False) |
|
|
payload_out = gr.Number(label="Payload tokens", interactive=False) |
|
|
overflow_out = gr.Number(label="Overflow", interactive=False) |
|
|
trimmed_tb = gr.Textbox(label="Trimmed text", lines=6) |
|
|
|
|
|
def on_count(prompt, limit_payload, backend):
    """Run the token counter and unpack its result for the five UI outputs."""
    result = count_clip_tokens(prompt or "", limit_payload=int(limit_payload), backend=backend)
    return (result.backend, result.total, result.payload,
            result.overflow, result.trimmed_text)
|
|
# Count tokens of the main prompt with the chosen backend and payload limit.
count_btn.click(on_count, inputs=[prompt_tb, payload_limit, backend_dd],
                outputs=[backend_used, total_out, payload_out, overflow_out, trimmed_tb])
|
with gr.Tab("Advanced"):
    with gr.Accordion("Curve Editor", open=False):
        # Empty input auto-generates placeholder segments seg1..N.
        curve_segments_tb = gr.Textbox(label="Сегменты (пусто = seg1..N)", lines=6, value="")
        with gr.Row():
            # N = number of segments when auto-generated.
            curve_num = gr.Slider(2, 8, value=3, step=1, label="N")
            curve_steps = gr.Slider(2, 300, value=60, step=1, label="Steps")
        curve_kind = gr.Dropdown(choices=["linear", "ease-in", "ease-out", "ease-in-out", "cosine", "bezier"], value="ease-in-out", label="Кривая")
        # Exponent used only by the ease-in / ease-out curves.
        curve_power = gr.Slider(1, 6, value=2, step=0.5, label="Показатель (ease-in/out)")
        with gr.Row():
            # Cubic bezier control points; only used when curve_kind == "bezier".
            bez_x1 = gr.Slider(0, 1, value=0.3, step=0.05, label="Bezier x1")
            bez_y1 = gr.Slider(0, 1, value=0.0, step=0.05, label="Bezier y1")
            bez_x2 = gr.Slider(0, 1, value=0.7, step=0.05, label="Bezier x2")
            bez_y2 = gr.Slider(0, 1, value=1.0, step=0.05, label="Bezier y2")
        build_curve_btn = gr.Button("Build")
        # CSV of computed step boundaries and the full schedule block preview.
        curve_bounds_out = gr.Textbox(label="Границы", lines=1)
        curve_preview = gr.Textbox(label="Предпросмотр", lines=2)
        insert_curve_btn = gr.Button("Вставить в Prompt")
|
def on_build_curve(segs_text, n, steps, kind, power, x1, y1, x2, y2):
    """Build a curve-shaped schedule; bezier control points apply only
    when the selected curve kind is "bezier"."""
    control_points = None
    if kind == "bezier":
        control_points = (float(x1), float(y1), float(x2), float(y2))
    block, boundaries_csv = build_curve_schedule(
        segs_text or "", int(n), int(steps), str(kind), float(power), control_points
    )
    return boundaries_csv, block
|
|
# Build -> boundaries CSV + preview; Insert -> append preview to the main prompt.
build_curve_btn.click(on_build_curve,
                      inputs=[curve_segments_tb, curve_num, curve_steps, curve_kind, curve_power, bez_x1, bez_y1, bez_x2, bez_y2],
                      outputs=[curve_bounds_out, curve_preview])
insert_curve_btn.click(_insert_block, inputs=[prompt_tb, curve_preview], outputs=[prompt_tb])

with gr.Accordion("Prob-Alternate Wizard", open=False):
    steps_pa = gr.Slider(1, 200, value=20, step=1, label="Steps")
    # Editable (text, weight) table; weights are normalized by their sum,
    # so they do not need to add up to 1.
    table = gr.Dataframe(headers=["Текст", "Вес"], datatype=["str", "number"],
                         row_count=3, col_count=(2, "fixed"),
                         value=[["a", 0.5], ["b", 0.5], ["", ""]], label="Варианты")
    build_pa_btn = gr.Button("Build")
    pa_preview = gr.Textbox(label="Предпросмотр", lines=2)
|
|
def build_schedule_from_prob_table(rows, total_steps: int) -> str: |
|
|
items: List[Tuple[str, float]] = []; total = 0.0 |
|
|
for r in rows or []: |
|
|
if not r or len(r) < 2: continue |
|
|
text = (r[0] or "").strip() |
|
|
if not text: continue |
|
|
try: w = float(str(r[1]).strip()) |
|
|
except Exception: w = 0.0 |
|
|
items.append((text, w)); total += w |
|
|
if not items or total <= 0: return "" |
|
|
boundaries: List[str] = []; acc = 0.0 |
|
|
for i in range(len(items) - 1): |
|
|
acc += items[i][1] / total |
|
|
b = int(round(acc * int(total_steps))) |
|
|
if 0 < b < int(total_steps): boundaries.append(str(b)) |
|
|
core = " : ".join([t for t, _ in items]) |
|
|
extra = f" : {', '.join(boundaries)}" if boundaries else "" |
|
|
return f"[ {core} ]{extra}" |
|
|
# Build -> preview; Insert -> append preview to the main prompt.
build_pa_btn.click(build_schedule_from_prob_table, inputs=[table, steps_pa], outputs=[pa_preview])
insert_pa_btn = gr.Button("Вставить в Prompt")
insert_pa_btn.click(_insert_block, inputs=[prompt_tb, pa_preview], outputs=[prompt_tb])

with gr.Accordion("Test Generator (unit tests)", open=False):
    tg_steps = gr.Slider(1, 200, value=20, step=1, label="Steps")
    # Optional seed for the parser's stochastic features; None = unseeded.
    tg_seed = gr.Number(value=None, label="Seed", precision=0)
    with gr.Row():
        # These checkboxes mirror the parser's environment flags; the
        # generated test file pins the same flags in setUpClass.
        tg_allow_empty = gr.Checkbox(value=False, label="ALLOW_EMPTY_ALTERNATE")
        tg_expand_alt = gr.Checkbox(value=True, label="EXPAND_ALTERNATE_PER_STEP")
    tg_group_limit = gr.Number(value=100, label="GROUP_COMBO_LIMIT", precision=0)
    tg_suppr_colon = gr.Checkbox(value=True, label="SUPPRESS_STANDALONE_COLON")
    gen_btn = gr.Button("Generate")
    # Generated unittest source shown inline, plus a downloadable file.
    test_code_tb = gr.Textbox(label="test_prompt_parser_generated.py", lines=18)
    test_file = gr.File(label="Скачать", interactive=False)
|
def generate_unittests_code(prompt_src: str, steps: int, env: Dict[str, str], seed: Optional[int], analysis: AnalysisResult) -> str:
    """Render a standalone unittest module that pins the current parser output.

    The generated file re-imports the parser from PROMPT_PARSER_PATH, applies
    the same environment flags (*env*), and asserts that (a) resolving
    *prompt_src* reproduces ``analysis.resolved_text`` and (b) a head/tail
    sample of the step schedule matches ``analysis.schedule``.
    """
    # Expected values captured from the live analysis; embedded via repr().
    resolved = analysis.resolved_text or ""
    schedule = analysis.schedule or []

    def slice_pairs(pairs: List[Tuple[int, str]], n_head: int = 3, n_tail: int = 3):
        # Keep the generated file small: sample only the first/last few steps.
        if len(pairs) <= n_head + n_tail: return pairs
        return pairs[:n_head] + pairs[-n_tail:]

    sched_sample = slice_pairs(schedule)
    # Env-setting lines are pre-indented (8 spaces) to sit inside setUpClass.
    ENV_LINES = "\n".join([f"        os.environ[{repr(k)}] = {repr(v)}" for k, v in env.items()])

    # NOTE: the template below is generated Python, so its internal
    # indentation is significant.  The {ENV_LINES and ... or "pass"} line
    # falls back to a bare "pass" body when env is empty.
    code = f'''\
import os, sys, unittest, importlib.machinery, importlib.util

PROMPT_PARSER_PATH = os.environ.get("PROMPT_PARSER_PATH", {repr(PROMPT_PARSER_PATH)})

def import_from_path(mod_name, path):
    if mod_name in sys.modules: del sys.modules[mod_name]
    loader = importlib.machinery.SourceFileLoader(mod_name, path)
    spec = importlib.util.spec_from_loader(loader.name, loader)
    mod = importlib.util.module_from_spec(spec); loader.exec_module(mod); return mod

class TestPromptParserGenerated(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
{ENV_LINES and ENV_LINES or "        pass"}
        cls.pp = import_from_path("pp", PROMPT_PARSER_PATH)

    def test_parse_and_resolve(self):
        pp = self.pp; src = {repr(prompt_src)}
        tree = pp.schedule_parser.parse(src)
        resolved = pp.resolve_tree(tree, keep_spacing=True)
        self.assertEqual(resolved, {repr(resolved)})

    def test_schedule_sample(self):
        pp = self.pp; src = {repr(prompt_src)}; steps = {int(steps)}
        collector = pp.CollectSteps(steps, prefix="", suffix="", depth=0, use_scheduling=True, seed={repr(seed)})
        schedules = collector.visit(tree := pp.schedule_parser.parse(src)) or []
        got = [(int(b), str(t)) for b, t in schedules]
        sample_got = (got[:3] + got[-3:]) if len(got) > 6 else got
        self.assertEqual(sample_got, {repr(sched_sample)})

if __name__ == "__main__": unittest.main()
'''
    return code
|
|
|
|
|
def on_gen(prompt, steps, seed, ae, ex, gl, sc):
    """Generate a standalone unittest file for the current prompt and flags.

    Reloads the parser with the requested environment flags, re-runs the
    analysis, renders the test module, and tries to persist it for download.
    Returns (code, file_path) — file_path is None when the write fails.
    """
    # Apply the flags to the live parser first so the analysis matches them.
    reload_parser_after_env(bool(ae), bool(ex), int(gl), bool(sc))
    seed_val = None if seed in (None, "") else int(seed)
    analysis = analyze_prompt(
        text=prompt or "",
        steps=int(steps),
        use_preprocessor=True,
        visitor_mode=True,
        seed=seed_val,
        allow_empty_alt=bool(ae),
        expand_alt_per_step=bool(ex),
        group_combo_limit=int(gl),
        suppress_standalone_colon=bool(sc),
    )
    # Same flags, encoded the way the generated test's setUpClass expects.
    flags = {
        "ALLOW_EMPTY_ALTERNATE": "1" if ae else "0",
        "EXPAND_ALTERNATE_PER_STEP": "1" if ex else "0",
        "GROUP_COMBO_LIMIT": str(int(gl)),
        "SUPPRESS_STANDALONE_COLON": "1" if sc else "0",
    }
    code = generate_unittests_code(prompt or "", int(steps), flags, seed_val, analysis)
    out_path = "/mnt/data/test_prompt_parser_generated.py"
    try:
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(code)
    except Exception:
        # Best effort: still show the code even if the download file fails.
        return code, None
    return code, out_path
|
|
|
|
|
# Generate the unittest file from the main prompt and the flag widgets.
gen_btn.click(on_gen,
              inputs=[prompt_tb, tg_steps, tg_seed, tg_allow_empty, tg_expand_alt, tg_group_limit, tg_suppr_colon],
              outputs=[test_code_tb, test_file])
|
|
|
|
# (component, tab title, element id) tuple expected by A1111's on_ui_tabs.
return [(demo, "Prompt Refactor", "prompt_refactor_tab")]
|
|
# Register the tab only when running inside A1111 (the modules import at the
# top of the file succeeded); outside the WebUI this file is inert.
if script_callbacks is not None:
    script_callbacks.on_ui_tabs(on_ui_tabs)