dikdimon commited on
Commit
4b9cd96
·
verified ·
1 Parent(s): dac353a

Update z-prompt-fusion-extension/scripts/promptlang.py

Browse files
z-prompt-fusion-extension/scripts/promptlang.py CHANGED
@@ -134,7 +134,8 @@ def _adapt_flattened_schedules(result: Any, total_steps: int):
134
  if isinstance(src, dict):
135
  out = dict(dst)
136
  for k, v in src.items():
137
- out[k] = out.get(k, 0) + v * w
 
138
  return out
139
  return dst + src * w
140
 
@@ -214,41 +215,54 @@ def _build_tensor_for_prompt(
214
  ]
215
  wrapped_conds.append(wrapped_sched)
216
 
217
- # 4) Собрать интерполяционный тензор и вычислить расписания по шагам
218
- tensor = tensor_builder.build(wrapped_conds, empty_cond.get())
219
-
220
- slerp_scale = global_state.get_slerp_scale()
221
- slerp_epsilon = global_state.get_slerp_epsilon()
222
-
223
- # Сформировать расписание, склеивая одинаковые соседние сегменты
224
- schedules: List[prompt_parser.ScheduledPromptConditioning] = []
225
- prev_wrapper = None
226
- for step in range(total_steps):
227
- params = interpolation_tensor.InterpolationParams(
228
- t=step / max(1, total_steps - 1),
229
- step=step,
230
- total_steps=total_steps,
231
- slerp_scale=slerp_scale,
232
- slerp_epsilon=slerp_epsilon,
233
- )
234
- origin = global_state.get_origin_cond_at(step, is_hires=is_hires)
235
- cond_wrapper = tensor.interpolate(params, origin, empty_cond.get())
236
 
237
- # Склейка соседних, если равны (по wrapper.__eq__)
238
- if prev_wrapper is not None and cond_wrapper == prev_wrapper:
239
- schedules[-1].end_at_step = step
240
- else:
241
- raw = cond_wrapper.original_cond if hasattr(cond_wrapper, "original_cond") else cond_wrapper
242
- schedules.append(
243
- prompt_parser.ScheduledPromptConditioning(end_at_step=step, cond=raw)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  )
245
- prev_wrapper = cond_wrapper
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
- # Финализировать последний сегмент на последний шаг
248
- if schedules:
249
- schedules[-1].end_at_step = total_steps - 1
250
 
251
- return schedules
252
 
253
 
254
  # -----------------------------------------------------------------------------
 
134
  if isinstance(src, dict):
135
  out = dict(dst)
136
  for k, v in src.items():
137
+ prev = out.get(k)
138
+ out[k] = (prev + v * w) if prev is not None else (v * w)
139
  return out
140
  return dst + src * w
141
 
 
215
  ]
216
  wrapped_conds.append(wrapped_sched)
217
 
218
+ # 4) Собрать интерполяционный тензор и вычислить расписания по шагам
219
+ tensor = tensor_builder.build(wrapped_conds, empty_cond.get())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
+ slerp_scale = global_state.get_slerp_scale()
222
+ slerp_epsilon = global_state.get_slerp_epsilon()
223
+
224
+ # Вспомогалка: безопасно «продлить» последний сегмент до step
225
+ def _extend_last_segment_to(step_val: int):
226
+ last = schedules[-1]
227
+ if getattr(last, "end_at_step", None) != step_val:
228
+ schedules[-1] = prompt_parser.ScheduledPromptConditioning(
229
+ end_at_step=step_val,
230
+ cond=last.cond,
231
+ )
232
+
233
+ # Сформировать расписание, склеивая одинаковые соседние сегменты
234
+ schedules: List[prompt_parser.ScheduledPromptConditioning] = []
235
+ prev_wrapper = None
236
+ for step in range(total_steps):
237
+ params = interpolation_tensor.InterpolationParams(
238
+ t=step / max(1, total_steps - 1),
239
+ step=step,
240
+ total_steps=total_steps,
241
+ slerp_scale=slerp_scale,
242
+ slerp_epsilon=slerp_epsilon,
243
  )
244
+ origin = global_state.get_origin_cond_at(step, is_hires=is_hires)
245
+ cond_wrapper = tensor.interpolate(params, origin, empty_cond.get())
246
+
247
+ if prev_wrapper is not None and cond_wrapper == prev_wrapper:
248
+ # Было: schedules[-1].end_at_step = step (иммутабельно -> нельзя)
249
+ _extend_last_segment_to(step)
250
+ else:
251
+ raw = getattr(cond_wrapper, "original_cond", cond_wrapper)
252
+ schedules.append(
253
+ prompt_parser.ScheduledPromptConditioning(
254
+ end_at_step=step,
255
+ cond=raw,
256
+ )
257
+ )
258
+ prev_wrapper = cond_wrapper
259
+
260
+ # Финализировать последний сегмент на последний шаг
261
+ if schedules:
262
+ _extend_last_segment_to(total_steps - 1)
263
 
264
+ return schedules
 
 
265
 
 
266
 
267
 
268
  # -----------------------------------------------------------------------------