Update z-prompt-fusion-extension/scripts/promptlang.py
Browse files
z-prompt-fusion-extension/scripts/promptlang.py
CHANGED
|
@@ -134,7 +134,8 @@ def _adapt_flattened_schedules(result: Any, total_steps: int):
|
|
| 134 |
if isinstance(src, dict):
|
| 135 |
out = dict(dst)
|
| 136 |
for k, v in src.items():
|
| 137 |
-
|
|
|
|
| 138 |
return out
|
| 139 |
return dst + src * w
|
| 140 |
|
|
@@ -214,41 +215,54 @@ def _build_tensor_for_prompt(
|
|
| 214 |
]
|
| 215 |
wrapped_conds.append(wrapped_sched)
|
| 216 |
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
slerp_scale = global_state.get_slerp_scale()
|
| 221 |
-
slerp_epsilon = global_state.get_slerp_epsilon()
|
| 222 |
-
|
| 223 |
-
# Сформировать расписание, склеивая одинаковые соседние сегменты
|
| 224 |
-
schedules: List[prompt_parser.ScheduledPromptConditioning] = []
|
| 225 |
-
prev_wrapper = None
|
| 226 |
-
for step in range(total_steps):
|
| 227 |
-
params = interpolation_tensor.InterpolationParams(
|
| 228 |
-
t=step / max(1, total_steps - 1),
|
| 229 |
-
step=step,
|
| 230 |
-
total_steps=total_steps,
|
| 231 |
-
slerp_scale=slerp_scale,
|
| 232 |
-
slerp_epsilon=slerp_epsilon,
|
| 233 |
-
)
|
| 234 |
-
origin = global_state.get_origin_cond_at(step, is_hires=is_hires)
|
| 235 |
-
cond_wrapper = tensor.interpolate(params, origin, empty_cond.get())
|
| 236 |
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
schedules
|
| 243 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
)
|
| 245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
-
|
| 248 |
-
if schedules:
|
| 249 |
-
schedules[-1].end_at_step = total_steps - 1
|
| 250 |
|
| 251 |
-
return schedules
|
| 252 |
|
| 253 |
|
| 254 |
# -----------------------------------------------------------------------------
|
|
|
|
| 134 |
if isinstance(src, dict):
|
| 135 |
out = dict(dst)
|
| 136 |
for k, v in src.items():
|
| 137 |
+
prev = out.get(k)
|
| 138 |
+
out[k] = (prev + v * w) if prev is not None else (v * w)
|
| 139 |
return out
|
| 140 |
return dst + src * w
|
| 141 |
|
|
|
|
| 215 |
]
|
| 216 |
wrapped_conds.append(wrapped_sched)
|
| 217 |
|
| 218 |
+
# 4) Собрать интерполяционный тензор и вычислить расписания по шагам
|
| 219 |
+
tensor = tensor_builder.build(wrapped_conds, empty_cond.get())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
+
slerp_scale = global_state.get_slerp_scale()
|
| 222 |
+
slerp_epsilon = global_state.get_slerp_epsilon()
|
| 223 |
+
|
| 224 |
+
# Вспомогалка: безопасно «продлить» последний сегмент до step
|
| 225 |
+
def _extend_last_segment_to(step_val: int):
|
| 226 |
+
last = schedules[-1]
|
| 227 |
+
if getattr(last, "end_at_step", None) != step_val:
|
| 228 |
+
schedules[-1] = prompt_parser.ScheduledPromptConditioning(
|
| 229 |
+
end_at_step=step_val,
|
| 230 |
+
cond=last.cond,
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
# Сформировать расписание, склеивая одинаковые соседние сегменты
|
| 234 |
+
schedules: List[prompt_parser.ScheduledPromptConditioning] = []
|
| 235 |
+
prev_wrapper = None
|
| 236 |
+
for step in range(total_steps):
|
| 237 |
+
params = interpolation_tensor.InterpolationParams(
|
| 238 |
+
t=step / max(1, total_steps - 1),
|
| 239 |
+
step=step,
|
| 240 |
+
total_steps=total_steps,
|
| 241 |
+
slerp_scale=slerp_scale,
|
| 242 |
+
slerp_epsilon=slerp_epsilon,
|
| 243 |
)
|
| 244 |
+
origin = global_state.get_origin_cond_at(step, is_hires=is_hires)
|
| 245 |
+
cond_wrapper = tensor.interpolate(params, origin, empty_cond.get())
|
| 246 |
+
|
| 247 |
+
if prev_wrapper is not None and cond_wrapper == prev_wrapper:
|
| 248 |
+
# Было: schedules[-1].end_at_step = step (иммутабельно -> нельзя)
|
| 249 |
+
_extend_last_segment_to(step)
|
| 250 |
+
else:
|
| 251 |
+
raw = getattr(cond_wrapper, "original_cond", cond_wrapper)
|
| 252 |
+
schedules.append(
|
| 253 |
+
prompt_parser.ScheduledPromptConditioning(
|
| 254 |
+
end_at_step=step,
|
| 255 |
+
cond=raw,
|
| 256 |
+
)
|
| 257 |
+
)
|
| 258 |
+
prev_wrapper = cond_wrapper
|
| 259 |
+
|
| 260 |
+
# Финализировать последний сегмент на последний шаг
|
| 261 |
+
if schedules:
|
| 262 |
+
_extend_last_segment_to(total_steps - 1)
|
| 263 |
|
| 264 |
+
return schedules
|
|
|
|
|
|
|
| 265 |
|
|
|
|
| 266 |
|
| 267 |
|
| 268 |
# -----------------------------------------------------------------------------
|