dikdimon commited on
Commit
6b08e8c
·
verified ·
1 Parent(s): 8d19821

Update z-sd-webui-neutral-prompt-workYEAH3/lib_neutral_prompt/cfg_denoiser_hijack.py

Browse files
z-sd-webui-neutral-prompt-workYEAH3/lib_neutral_prompt/cfg_denoiser_hijack.py CHANGED
@@ -36,14 +36,11 @@ def combine_denoised_hijack(
36
  if not global_state.prompt_exprs and batch_cond_indices:
37
  global_state.prompt_exprs = convert_to_prompt_expr_from_multicond(batch_cond_indices, [])
38
 
39
- # Получить расписания для всех промптов
40
  prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(
41
- [expr.accept(WebuiPromptVisitor()) for expr in global_state.prompt_exprs],
42
  shared.state.sampling_steps
43
  )
44
- active_prompts = [get_active_prompt(schedule, shared.state.sampling_step) for schedule in prompt_schedules]
45
-
46
- # Перепарсить активные промпты в PromptExpr
47
  global_state.prompt_exprs = [neutral_prompt_parser.parse_root(p) for p in active_prompts]
48
 
49
  denoised = get_webui_denoised(x_out, batch_cond_indices, text_uncond, cond_scale, original_function)
@@ -105,28 +102,35 @@ def gather_webui_conds(
105
  sliced_x_out = []
106
  sliced_cond_indices = []
107
 
108
- if isinstance(prompt, neutral_prompt_parser.LeafPrompt) and prompt.conciliation is None:
 
109
  child_x_out = args.x_out[args.cond_indices[index_in][0]]
110
  index_offset = index_out + len(sliced_x_out)
111
  sliced_x_out.append(child_x_out)
112
  sliced_cond_indices.append((index_offset, prompt.weight))
113
  return sliced_x_out, sliced_cond_indices
114
 
 
 
115
  if isinstance(prompt, neutral_prompt_parser.CompositePrompt):
116
  for child in prompt.children:
117
- if child.conciliation is None:
118
- if isinstance(child, neutral_prompt_parser.LeafPrompt):
119
- child_x_out = args.x_out[args.cond_indices[index_in][0]]
120
- else:
121
- child_x_out = child.accept(CondDeltaVisitor(), args, index_in)
122
- child_x_out += child.accept(AuxCondDeltaVisitor(), args, child_x_out, index_in)
123
- child_x_out += args.uncond
124
- index_offset = index_out + len(sliced_x_out)
125
- sliced_x_out.append(child_x_out)
126
- sliced_cond_indices.append((index_offset, child.weight))
 
 
127
  index_in += child.accept(neutral_prompt_parser.FlatSizeVisitor())
 
128
  return sliced_x_out, sliced_cond_indices
129
 
 
130
  class CondDeltaVisitor:
131
  def visit_leaf_prompt(
132
  self,
 
36
  if not global_state.prompt_exprs and batch_cond_indices:
37
  global_state.prompt_exprs = convert_to_prompt_expr_from_multicond(batch_cond_indices, [])
38
 
 
39
  prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(
40
+ getattr(global_state, "raw_prompts", [expr.accept(WebuiPromptVisitor()) for expr in global_state.prompt_exprs]),
41
  shared.state.sampling_steps
42
  )
43
+ active_prompts = [get_active_prompt(s, shared.state.sampling_step) for s in prompt_schedules]
 
 
44
  global_state.prompt_exprs = [neutral_prompt_parser.parse_root(p) for p in active_prompts]
45
 
46
  denoised = get_webui_denoised(x_out, batch_cond_indices, text_uncond, cond_scale, original_function)
 
102
  sliced_x_out = []
103
  sliced_cond_indices = []
104
 
105
+ # 1) Любой лист -> резервируем индекс (убрали проверку conciliation is None)
106
+ if isinstance(prompt, neutral_prompt_parser.LeafPrompt):
107
  child_x_out = args.x_out[args.cond_indices[index_in][0]]
108
  index_offset = index_out + len(sliced_x_out)
109
  sliced_x_out.append(child_x_out)
110
  sliced_cond_indices.append((index_offset, prompt.weight))
111
  return sliced_x_out, sliced_cond_indices
112
 
113
+ # 2) Для композита — для каждого ребёнка тоже резервируем индекс,
114
+ # а не только когда child.conciliation is None
115
  if isinstance(prompt, neutral_prompt_parser.CompositePrompt):
116
  for child in prompt.children:
117
+ if isinstance(child, neutral_prompt_parser.LeafPrompt):
118
+ child_x_out = args.x_out[args.cond_indices[index_in][0]]
119
+ else:
120
+ child_x_out = child.accept(CondDeltaVisitor(), args, index_in)
121
+ child_x_out += child.accept(AuxCondDeltaVisitor(), args, child_x_out, index_in)
122
+ child_x_out += args.uncond
123
+
124
+ index_offset = index_out + len(sliced_x_out)
125
+ sliced_x_out.append(child_x_out)
126
+ sliced_cond_indices.append((index_offset, child.weight))
127
+
128
+ # важно: шагать по плоскому размеру ребёнка
129
  index_in += child.accept(neutral_prompt_parser.FlatSizeVisitor())
130
+
131
  return sliced_x_out, sliced_cond_indices
132
 
133
+
134
  class CondDeltaVisitor:
135
  def visit_leaf_prompt(
136
  self,