sdas / negative_rejection_steering /scripts /negative_rejection_steering_script.py
dikdimon's picture
Update negative_rejection_steering/scripts/negative_rejection_steering_script.py
d7cdd26 verified
import torch
import gradio as gr
from modules import scripts, script_callbacks, sd_samplers_cfg_denoiser, shared
# ==============================================================================
# ЧАСТЬ 1: Математическое ядро NRS
# ==============================================================================
def calc_nrs(x_orig, cond, uncond, sigma, skew, stretch, squash):
# 1. Проверка режима V-prediction
is_v_pred = False
if hasattr(shared.sd_model, 'parameterization'):
is_v_pred = shared.sd_model.parameterization == "v"
# 2. Подготовка тензора Sigma
if isinstance(sigma, torch.Tensor):
sig_tens = sigma[0]
else:
sig_tens = torch.tensor(sigma, device=cond.device, dtype=cond.dtype)
if sig_tens.dtype != cond.dtype:
sig_tens = sig_tens.to(dtype=cond.dtype)
sig_tens = sig_tens.view(1, 1, 1, 1)
sig_root = (sig_tens ** 2 + 1).sqrt()
# 3. Конвертация EPS -> V
if is_v_pred:
nrs_cond, nrs_uncond = cond, uncond
x_div = None
else:
x_div = x_orig / (sig_tens ** 2 + 1)
factor = sig_tens / sig_root
nrs_cond = x_orig - (x_div - cond * factor)
nrs_uncond = x_orig - (x_div - uncond * factor)
# 4. Векторные операции (Core NRS Math)
def _dot(a, b): return (a * b).sum(dim=1, keepdim=True)
def _nrm2(v): return _dot(v, v)
eps_safe = 1e-6
# Проекция Uncond на Cond
c_dot_c = _nrm2(nrs_cond) + eps_safe
u_dot_c = _dot(nrs_uncond, nrs_cond)
u_on_c = (u_dot_c / c_dot_c) * nrs_cond
# Stretch: Усиление
proj_diff = nrs_cond - u_on_c
stretched = nrs_cond + (stretch * proj_diff)
# Skew: Отклонение
u_rej_c = nrs_uncond - u_on_c
skewed = stretched - (skew * u_rej_c)
# Squash: Нормализация
cond_len = nrs_cond.norm(dim=1, keepdim=True)
nrs_len = skewed.norm(dim=1, keepdim=True) + eps_safe
squash_scale = (1 - squash) + (squash * (cond_len / nrs_len))
x_final = skewed * squash_scale
# 5. Возврат в исходное пространство
if is_v_pred:
return x_final
else:
return (x_div - (x_orig - x_final)) * (sig_root / sig_tens)
# ==============================================================================
# ЧАСТЬ 1.5: Утилиты для управления шагами
# ==============================================================================
def should_apply_at_step(current_step, total_steps, start_step, end_step, start_frac, end_frac, step_mode):
"""
Определяет, нужно ли применять эффект на текущем шаге
Args:
current_step: Текущий шаг сэмплера
total_steps: Общее количество шагов
start_step: Начальный абсолютный шаг
end_step: Конечный абсолютный шаг
start_frac: Начальная доля (0.0-1.0)
end_frac: Конечная доля (0.0-1.0)
step_mode: Режим ("Absolute Steps" или "Fraction of Steps")
Returns:
bool: True если эффект должен применяться
"""
if step_mode == "Absolute Steps":
# Абсолютные шаги
effective_start = max(0, start_step)
effective_end = min(total_steps, end_step) if end_step > 0 else total_steps
return effective_start <= current_step < effective_end
else:
# Доли от общего количества шагов
effective_start = int(total_steps * max(0.0, min(1.0, start_frac)))
effective_end = int(total_steps * max(0.0, min(1.0, end_frac)))
if effective_end == 0:
effective_end = total_steps
return effective_start <= current_step < effective_end
def get_param_value_at_step(base_value, current_step, total_steps, start_step, end_step,
start_frac, end_frac, step_mode, enabled):
"""
Возвращает значение параметра для текущего шага
Если шаг вне диапазона - возвращает 0.0
"""
if not enabled:
return base_value
if should_apply_at_step(current_step, total_steps, start_step, end_step,
start_frac, end_frac, step_mode):
return base_value
else:
return 0.0
# ==============================================================================
# ЧАСТЬ 2: Перехват управления (Hooking)
# ==============================================================================
# Хук 1: Сохраняем Sigma, Input Latent и текущий шаг
def hook_cfg_denoiser_params(params):
if hasattr(params.denoiser, 'p') and getattr(params.denoiser.p, '_nrs_enabled', False):
params.denoiser.p._nrs_current_sigma = params.sigma
params.denoiser.p._nrs_current_x_in = params.x
# Определяем текущий шаг
if hasattr(params, 'sampling_step'):
params.denoiser.p._nrs_current_step = params.sampling_step
elif hasattr(params.denoiser, 'step'):
params.denoiser.p._nrs_current_step = params.denoiser.step
else:
# Fallback: пытаемся определить по sigma
params.denoiser.p._nrs_current_step = getattr(params.denoiser.p, '_nrs_current_step', 0)
script_callbacks.on_cfg_denoiser(hook_cfg_denoiser_params)
# Бэкап оригинала
if not hasattr(sd_samplers_cfg_denoiser.CFGDenoiser, 'original_combine_denoised_nrs_backup'):
sd_samplers_cfg_denoiser.CFGDenoiser.original_combine_denoised_nrs_backup = sd_samplers_cfg_denoiser.CFGDenoiser.combine_denoised
# Хук 2: Подменная функция с поддержкой step control
def hijacked_combine_denoised(self, x_out, conds_list, uncond, cond_scale):
# Проверка активности NRS
if not getattr(self, 'p', None) or not getattr(self.p, '_nrs_enabled', False):
return sd_samplers_cfg_denoiser.CFGDenoiser.original_combine_denoised_nrs_backup(self, x_out, conds_list, uncond, cond_scale)
if not hasattr(self.p, '_nrs_current_sigma') or not hasattr(self.p, '_nrs_current_x_in'):
return sd_samplers_cfg_denoiser.CFGDenoiser.original_combine_denoised_nrs_backup(self, x_out, conds_list, uncond, cond_scale)
try:
# Получаем базовые параметры
base_skew, base_stretch, base_squash = self.p._nrs_params
# Получаем настройки step control
step_control_enabled = getattr(self.p, '_nrs_step_control_enabled', False)
step_control_mode = getattr(self.p, '_nrs_step_control_mode', 'Global')
# Определяем текущий и общий шаги
current_step = getattr(self.p, '_nrs_current_step', 0)
total_steps = getattr(self.p, 'steps', 20)
# Если step control включен, вычисляем эффективные значения параметров
if step_control_enabled:
if step_control_mode == 'Global':
# Глобальный режим - одни настройки для всех
global_settings = getattr(self.p, '_nrs_global_step_settings', {})
start_step = global_settings.get('start_step', 0)
end_step = global_settings.get('end_step', total_steps)
start_frac = global_settings.get('start_frac', 0.0)
end_frac = global_settings.get('end_frac', 1.0)
step_mode = global_settings.get('step_mode', 'Absolute Steps')
# Проверяем, нужно ли применять эффекты
if not should_apply_at_step(current_step, total_steps, start_step, end_step,
start_frac, end_frac, step_mode):
# Вне диапазона - используем fallback
return sd_samplers_cfg_denoiser.CFGDenoiser.original_combine_denoised_nrs_backup(
self, x_out, conds_list, uncond, cond_scale)
skew, stretch, squash = base_skew, base_stretch, base_squash
else:
# Индивидуальный режим - отдельные настройки для каждого параметра
individual_settings = getattr(self.p, '_nrs_individual_step_settings', {})
skew_settings = individual_settings.get('skew', {})
skew = get_param_value_at_step(
base_skew, current_step, total_steps,
skew_settings.get('start_step', 0),
skew_settings.get('end_step', total_steps),
skew_settings.get('start_frac', 0.0),
skew_settings.get('end_frac', 1.0),
skew_settings.get('step_mode', 'Absolute Steps'),
skew_settings.get('enabled', True)
)
stretch_settings = individual_settings.get('stretch', {})
stretch = get_param_value_at_step(
base_stretch, current_step, total_steps,
stretch_settings.get('start_step', 0),
stretch_settings.get('end_step', total_steps),
stretch_settings.get('start_frac', 0.0),
stretch_settings.get('end_frac', 1.0),
stretch_settings.get('step_mode', 'Absolute Steps'),
stretch_settings.get('enabled', True)
)
squash_settings = individual_settings.get('squash', {})
squash = get_param_value_at_step(
base_squash, current_step, total_steps,
squash_settings.get('start_step', 0),
squash_settings.get('end_step', total_steps),
squash_settings.get('start_frac', 0.0),
squash_settings.get('end_frac', 1.0),
squash_settings.get('step_mode', 'Absolute Steps'),
squash_settings.get('enabled', True)
)
else:
# Step control выключен - используем базовые значения
skew, stretch, squash = base_skew, base_stretch, base_squash
# Основная логика NRS
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
x_orig_all = self.p._nrs_current_x_in
x_orig_uncond = x_orig_all[-uncond.shape[0]:]
for i, conds in enumerate(conds_list):
for idx, (cond_index, weight) in enumerate(conds):
current_cond = x_out[cond_index]
# NRS только для основного промпта
if idx == 0:
current_x_orig = x_orig_uncond[i].unsqueeze(0)
c_in = current_cond.unsqueeze(0)
u_in = denoised_uncond[i].unsqueeze(0)
nrs_result = calc_nrs(
current_x_orig, c_in, u_in,
self.p._nrs_current_sigma,
skew, stretch, squash
)
if len(conds) == 1:
denoised[i] = nrs_result.squeeze(0)
else:
delta = nrs_result.squeeze(0) - denoised_uncond[i]
denoised[i] += delta * weight
else:
denoised[i] += (current_cond - denoised_uncond[i]) * (weight * cond_scale)
return denoised
except Exception as e:
print(f"!!! NRS Error (Fallback): {e}")
return sd_samplers_cfg_denoiser.CFGDenoiser.original_combine_denoised_nrs_backup(self, x_out, conds_list, uncond, cond_scale)
# ==============================================================================
# ЧАСТЬ 3: Интерфейс (UI) с расширенными настройками
# ==============================================================================
class NRSScript(scripts.Script):
def title(self):
return "Negative Rejection Steering (Enhanced)"
def show(self, is_img2img):
return scripts.AlwaysVisible
def ui(self, is_img2img):
with gr.Accordion("Negative Rejection Steering", open=False):
with gr.Row():
enabled = gr.Checkbox(label="Включить NRS (Enable)", value=False)
# --- БЛОК СПРАВКИ ---
with gr.Accordion("❓ Как этим пользоваться (Инструкция)", open=False):
gr.Markdown("""
### Что это такое?
**NRS** — это замена стандартному CFG Scale. Вместо простого усиления промпта, он дает вам 3 рычага управления для более точного контроля.
### Параметры:
* **Skew (Сдвиг/Руление):** *Влияет на:* **Композицию и Геометрию**.
Отклоняет генерацию *прочь* от того, что указано в Negative prompt. Если картинка "ломается" или лезут лишние объекты — крутите это.
* **Stretch (Растяжение):** *Влияет на:* **Цвета и Текстуры**.
Усиливает элементы, которые совпадают с вашим Positive prompt. Работает как "газ" для ваших идей.
* **Squash (Сплющивание):** *Влияет на:* **Детализацию и Контраст**.
Ограничитель скорости. При `0.0` эффект максимальный (может "пережарить" картинку). Увеличение добавляет микро-деталей и смягчает цвета.
### 🚀 С чего начать (Советы от автора):
1. **Skew** ≈ Половина вашего обычного CFG (например, **3.0** - **4.0**).
2. **Stretch** ≈ Ваш обычный CFG (например, **6.0** - **7.0**).
3. **Squash** ≈ Оставьте **0.0** для начала.
### ⏱️ Step Control (Управление шагами):
Позволяет применять NRS только на определенных шагах генерации.
- **Global Mode**: Одинаковые настройки для всех трех параметров
- **Individual Mode**: Отдельные настройки для каждого параметра (Skew, Stretch, Squash)
""")
# --- ОСНОВНЫЕ ПАРАМЕТРЫ ---
gr.HTML("<div style='margin-bottom: 0.5em; opacity: 0.8; font-size: 0.9em; border-bottom: 1px solid #444;'>Основные настройки</div>")
with gr.Row():
skew = gr.Slider(
label="Skew (Композиция)",
minimum=-30.0, maximum=30.0, step=0.05, value=4.0,
info="Сила отклонения от Negative prompt. Аналог 'силы' CFG для структуры."
)
stretch = gr.Slider(
label="Stretch (Цвета/Текстура)",
minimum=-30.0, maximum=30.0, step=0.05, value=2.0,
info="Сила притяжения к Positive prompt. Усиливает цвета и стиль."
)
squash = gr.Slider(
label="Squash (Защита от пережарки)",
minimum=0.0, maximum=1.0, step=0.01, value=0.0,
info="0.0 = Макс. эффект. 1.0 = Ослабление (больше деталей, меньше контраста)."
)
# --- STEP CONTROL ---
with gr.Accordion("⏱️ Step Control (Управление шагами)", open=False):
with gr.Row():
step_control_enabled = gr.Checkbox(
label="Включить Step Control",
value=False,
info="Применять NRS только на определенных шагах"
)
step_control_mode = gr.Radio(
label="Режим управления",
choices=["Global", "Individual"],
value="Global",
info="Global: общие настройки | Individual: настройки для каждого параметра"
)
# ГЛОБАЛЬНЫЕ НАСТРОЙКИ
with gr.Group(visible=True) as global_group:
gr.HTML("<div style='margin: 0.5em 0; font-weight: bold;'>Глобальные настройки (для всех параметров)</div>")
global_step_mode = gr.Radio(
label="Режим шагов",
choices=["Absolute Steps", "Fraction of Steps"],
value="Absolute Steps",
info="Absolute: номера шагов | Fraction: доли от общего числа"
)
with gr.Row():
global_start_step = gr.Slider(
label="Start Step (абсолютный)",
minimum=0, maximum=150, step=1, value=0,
visible=True
)
global_end_step = gr.Slider(
label="End Step (абсолютный, 0=конец)",
minimum=0, maximum=150, step=1, value=0,
visible=True
)
with gr.Row():
global_start_frac = gr.Slider(
label="Start (доля от общего числа)",
minimum=0.0, maximum=1.0, step=0.01, value=0.0,
visible=False
)
global_end_frac = gr.Slider(
label="End (доля от общего числа)",
minimum=0.0, maximum=1.0, step=0.01, value=1.0,
visible=False
)
# ИНДИВИДУАЛЬНЫЕ НАСТРОЙКИ
with gr.Group(visible=False) as individual_group:
gr.HTML("<div style='margin: 0.5em 0; font-weight: bold;'>Индивидуальные настройки</div>")
# SKEW
with gr.Accordion("Skew (Композиция) - Step Settings", open=False):
skew_step_enabled = gr.Checkbox(label="Включить управление шагами для Skew", value=True)
skew_step_mode = gr.Radio(
label="Режим",
choices=["Absolute Steps", "Fraction of Steps"],
value="Absolute Steps"
)
with gr.Row():
skew_start_step = gr.Slider(label="Start Step", minimum=0, maximum=150, step=1, value=0, visible=True)
skew_end_step = gr.Slider(label="End Step (0=конец)", minimum=0, maximum=150, step=1, value=0, visible=True)
with gr.Row():
skew_start_frac = gr.Slider(label="Start (fraction)", minimum=0.0, maximum=1.0, step=0.01, value=0.0, visible=False)
skew_end_frac = gr.Slider(label="End (fraction)", minimum=0.0, maximum=1.0, step=0.01, value=1.0, visible=False)
# STRETCH
with gr.Accordion("Stretch (Цвета/Текстура) - Step Settings", open=False):
stretch_step_enabled = gr.Checkbox(label="Включить управление шагами для Stretch", value=True)
stretch_step_mode = gr.Radio(
label="Режим",
choices=["Absolute Steps", "Fraction of Steps"],
value="Absolute Steps"
)
with gr.Row():
stretch_start_step = gr.Slider(label="Start Step", minimum=0, maximum=150, step=1, value=0, visible=True)
stretch_end_step = gr.Slider(label="End Step (0=конец)", minimum=0, maximum=150, step=1, value=0, visible=True)
with gr.Row():
stretch_start_frac = gr.Slider(label="Start (fraction)", minimum=0.0, maximum=1.0, step=0.01, value=0.0, visible=False)
stretch_end_frac = gr.Slider(label="End (fraction)", minimum=0.0, maximum=1.0, step=0.01, value=1.0, visible=False)
# SQUASH
with gr.Accordion("Squash (Защита от пережарки) - Step Settings", open=False):
squash_step_enabled = gr.Checkbox(label="Включить управление шагами для Squash", value=True)
squash_step_mode = gr.Radio(
label="Режим",
choices=["Absolute Steps", "Fraction of Steps"],
value="Absolute Steps"
)
with gr.Row():
squash_start_step = gr.Slider(label="Start Step", minimum=0, maximum=150, step=1, value=0, visible=True)
squash_end_step = gr.Slider(label="End Step (0=конец)", minimum=0, maximum=150, step=1, value=0, visible=True)
with gr.Row():
squash_start_frac = gr.Slider(label="Start (fraction)", minimum=0.0, maximum=1.0, step=0.01, value=0.0, visible=False)
squash_end_frac = gr.Slider(label="End (fraction)", minimum=0.0, maximum=1.0, step=0.01, value=1.0, visible=False)
# Переключение видимости групп при изменении режима
def update_mode_visibility(mode):
return {
global_group: gr.update(visible=(mode == "Global")),
individual_group: gr.update(visible=(mode == "Individual"))
}
step_control_mode.change(
fn=update_mode_visibility,
inputs=[step_control_mode],
outputs=[global_group, individual_group]
)
# Переключение между абсолютными и дробными шагами (GLOBAL)
def update_global_step_inputs(mode):
is_absolute = (mode == "Absolute Steps")
return {
global_start_step: gr.update(visible=is_absolute),
global_end_step: gr.update(visible=is_absolute),
global_start_frac: gr.update(visible=not is_absolute),
global_end_frac: gr.update(visible=not is_absolute)
}
global_step_mode.change(
fn=update_global_step_inputs,
inputs=[global_step_mode],
outputs=[global_start_step, global_end_step, global_start_frac, global_end_frac]
)
# Переключение для SKEW
def update_skew_step_inputs(mode):
is_absolute = (mode == "Absolute Steps")
return {
skew_start_step: gr.update(visible=is_absolute),
skew_end_step: gr.update(visible=is_absolute),
skew_start_frac: gr.update(visible=not is_absolute),
skew_end_frac: gr.update(visible=not is_absolute)
}
skew_step_mode.change(
fn=update_skew_step_inputs,
inputs=[skew_step_mode],
outputs=[skew_start_step, skew_end_step, skew_start_frac, skew_end_frac]
)
# Переключение для STRETCH
def update_stretch_step_inputs(mode):
is_absolute = (mode == "Absolute Steps")
return {
stretch_start_step: gr.update(visible=is_absolute),
stretch_end_step: gr.update(visible=is_absolute),
stretch_start_frac: gr.update(visible=not is_absolute),
stretch_end_frac: gr.update(visible=not is_absolute)
}
stretch_step_mode.change(
fn=update_stretch_step_inputs,
inputs=[stretch_step_mode],
outputs=[stretch_start_step, stretch_end_step, stretch_start_frac, stretch_end_frac]
)
# Переключение для SQUASH
def update_squash_step_inputs(mode):
is_absolute = (mode == "Absolute Steps")
return {
squash_start_step: gr.update(visible=is_absolute),
squash_end_step: gr.update(visible=is_absolute),
squash_start_frac: gr.update(visible=not is_absolute),
squash_end_frac: gr.update(visible=not is_absolute)
}
squash_step_mode.change(
fn=update_squash_step_inputs,
inputs=[squash_step_mode],
outputs=[squash_start_step, squash_end_step, squash_start_frac, squash_end_frac]
)
return [
enabled, skew, stretch, squash,
step_control_enabled, step_control_mode,
# Global
global_step_mode, global_start_step, global_end_step, global_start_frac, global_end_frac,
# Skew
skew_step_enabled, skew_step_mode, skew_start_step, skew_end_step, skew_start_frac, skew_end_frac,
# Stretch
stretch_step_enabled, stretch_step_mode, stretch_start_step, stretch_end_step, stretch_start_frac, stretch_end_frac,
# Squash
squash_step_enabled, squash_step_mode, squash_start_step, squash_end_step, squash_start_frac, squash_end_frac
]
def process(self, p, enabled, skew, stretch, squash,
step_control_enabled, step_control_mode,
global_step_mode, global_start_step, global_end_step, global_start_frac, global_end_frac,
skew_step_enabled, skew_step_mode, skew_start_step, skew_end_step, skew_start_frac, skew_end_frac,
stretch_step_enabled, stretch_step_mode, stretch_start_step, stretch_end_step, stretch_start_frac, stretch_end_frac,
squash_step_enabled, squash_step_mode, squash_start_step, squash_end_step, squash_start_frac, squash_end_frac):
p._nrs_enabled = enabled
p._nrs_params = (skew, stretch, squash)
p._nrs_step_control_enabled = step_control_enabled
p._nrs_step_control_mode = step_control_mode
# Сохраняем глобальные настройки
p._nrs_global_step_settings = {
'step_mode': global_step_mode,
'start_step': global_start_step,
'end_step': global_end_step,
'start_frac': global_start_frac,
'end_frac': global_end_frac
}
# Сохраняем индивидуальные настройки
p._nrs_individual_step_settings = {
'skew': {
'enabled': skew_step_enabled,
'step_mode': skew_step_mode,
'start_step': skew_start_step,
'end_step': skew_end_step,
'start_frac': skew_start_frac,
'end_frac': skew_end_frac
},
'stretch': {
'enabled': stretch_step_enabled,
'step_mode': stretch_step_mode,
'start_step': stretch_start_step,
'end_step': stretch_end_step,
'start_frac': stretch_start_frac,
'end_frac': stretch_end_frac
},
'squash': {
'enabled': squash_step_enabled,
'step_mode': squash_step_mode,
'start_step': squash_start_step,
'end_step': squash_end_step,
'start_frac': squash_start_frac,
'end_frac': squash_end_frac
}
}
if enabled:
sd_samplers_cfg_denoiser.CFGDenoiser.combine_denoised = hijacked_combine_denoised
# Инициализируем счетчик шагов
p._nrs_current_step = 0
def postprocess(self, p, processed, *args):
pass