|
|
from typing import List, Tuple, Dict, Callable |
|
|
|
|
|
|
|
|
def preparse_loras_multipliers(loras_multipliers): |
|
|
if isinstance(loras_multipliers, list): |
|
|
return [multi.strip(" \r\n") if isinstance(multi, str) else multi for multi in loras_multipliers] |
|
|
|
|
|
loras_multipliers = loras_multipliers.strip(" \r\n") |
|
|
loras_mult_choices_list = loras_multipliers.replace("\r", "").split("\n") |
|
|
loras_mult_choices_list = [multi.strip() for multi in loras_mult_choices_list if len(multi)>0 and not multi.startswith("#")] |
|
|
loras_multipliers = " ".join(loras_mult_choices_list) |
|
|
return loras_multipliers.replace("|"," ").strip().split(" ") |
|
|
|
|
|
def expand_slist(slists_dict, mult_no, num_inference_steps, model_switch_step, model_switch_step2 ): |
|
|
def expand_one(slist, num_inference_steps): |
|
|
if not isinstance(slist, list): slist = [slist] |
|
|
new_slist= [] |
|
|
if num_inference_steps <=0: |
|
|
return new_slist |
|
|
inc = len(slist) / num_inference_steps |
|
|
pos = 0 |
|
|
for i in range(num_inference_steps): |
|
|
new_slist.append(slist[ int(pos)]) |
|
|
pos += inc |
|
|
return new_slist |
|
|
|
|
|
phase1 = slists_dict["phase1"][mult_no] |
|
|
phase2 = slists_dict["phase2"][mult_no] |
|
|
phase3 = slists_dict["phase3"][mult_no] |
|
|
shared = slists_dict["shared"][mult_no] |
|
|
if shared: |
|
|
if isinstance(phase1, float): return phase1 |
|
|
return expand_one(phase1, num_inference_steps) |
|
|
else: |
|
|
if isinstance(phase1, float) and isinstance(phase2, float) and isinstance(phase3, float) and phase1 == phase2 and phase2 == phase3: return phase1 |
|
|
return expand_one(phase1, model_switch_step) + expand_one(phase2, model_switch_step2 - model_switch_step) + expand_one(phase3, num_inference_steps - model_switch_step2) |
|
|
|
|
|
def parse_loras_multipliers(loras_multipliers, nb_loras, num_inference_steps, merge_slist = None, nb_phases = 2, model_switch_step = None, model_switch_step2 = None): |
|
|
if "|" in loras_multipliers: |
|
|
pos = loras_multipliers.find("|") |
|
|
if "|" in loras_multipliers[pos+1:]: return "", "", "There can be only one '|' character in Loras Multipliers Sequence" |
|
|
|
|
|
if model_switch_step is None: |
|
|
model_switch_step = num_inference_steps |
|
|
if model_switch_step2 is None: |
|
|
model_switch_step2 = num_inference_steps |
|
|
def is_float(element: any) -> bool: |
|
|
if element is None: |
|
|
return False |
|
|
try: |
|
|
float(element) |
|
|
return True |
|
|
except ValueError: |
|
|
return False |
|
|
loras_list_mult_choices_nums = [] |
|
|
slists_dict = { "model_switch_step": model_switch_step} |
|
|
slists_dict = { "model_switch_step2": model_switch_step2} |
|
|
slists_dict["phase1"] = phase1 = [1.] * nb_loras |
|
|
slists_dict["phase2"] = phase2 = [1.] * nb_loras |
|
|
slists_dict["phase3"] = phase3 = [1.] * nb_loras |
|
|
slists_dict["shared"] = shared = [False] * nb_loras |
|
|
|
|
|
if isinstance(loras_multipliers, list) or len(loras_multipliers) > 0: |
|
|
list_mult_choices_list = preparse_loras_multipliers(loras_multipliers)[:nb_loras] |
|
|
for i, mult in enumerate(list_mult_choices_list): |
|
|
current_phase = phase1 |
|
|
if isinstance(mult, str): |
|
|
mult = mult.strip() |
|
|
phase_mult = mult.split(";") |
|
|
shared_phases = len(phase_mult) <=1 |
|
|
if not shared_phases and len(phase_mult) != nb_phases : |
|
|
return "", "", f"if the ';' syntax is used for one Lora multiplier, the multipliers for its {nb_phases} denoising phases should be specified for this multiplier" |
|
|
for phase_no, mult in enumerate(phase_mult): |
|
|
if phase_no == 1: |
|
|
current_phase = phase2 |
|
|
elif phase_no == 2: |
|
|
current_phase = phase3 |
|
|
if "," in mult: |
|
|
multlist = mult.split(",") |
|
|
slist = [] |
|
|
for smult in multlist: |
|
|
if not is_float(smult): |
|
|
return "", "", f"Lora sub value no {i+1} ({smult}) in Multiplier definition '{multlist}' is invalid in Phase {phase_no+1}" |
|
|
slist.append(float(smult)) |
|
|
else: |
|
|
if not is_float(mult): |
|
|
return "", "", f"Lora Multiplier no {i+1} ({mult}) is invalid" |
|
|
slist = float(mult) |
|
|
if shared_phases: |
|
|
phase1[i] = phase2[i] = phase3[i] = slist |
|
|
shared[i] = True |
|
|
else: |
|
|
current_phase[i] = slist |
|
|
else: |
|
|
phase1[i] = phase2[i] = phase3[i] = float(mult) |
|
|
shared[i] = True |
|
|
|
|
|
if merge_slist is not None: |
|
|
slists_dict["phase1"] = phase1 = merge_slist["phase1"] + phase1 |
|
|
slists_dict["phase2"] = phase2 = merge_slist["phase2"] + phase2 |
|
|
slists_dict["phase3"] = phase3 = merge_slist["phase3"] + phase3 |
|
|
slists_dict["shared"] = shared = merge_slist["shared"] + shared |
|
|
|
|
|
loras_list_mult_choices_nums = [ expand_slist(slists_dict, i, num_inference_steps, model_switch_step, model_switch_step2 ) for i in range(len(phase1)) ] |
|
|
loras_list_mult_choices_nums = [ slist[0] if isinstance(slist, list) else slist for slist in loras_list_mult_choices_nums ] |
|
|
|
|
|
return loras_list_mult_choices_nums, slists_dict, "" |
|
|
|
|
|
def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_step = None, phase_switch_step2 = None): |
|
|
from mmgp import offload |
|
|
sz = len(slists_dict["phase1"]) |
|
|
slists = [ expand_slist(slists_dict, i, num_inference_steps, phase_switch_step, phase_switch_step2 ) for i in range(sz) ] |
|
|
nos = [str(l) for l in range(sz)] |
|
|
offload.activate_loras(trans, nos, slists ) |
|
|
|
|
|
|
|
|
|
|
|
def get_model_switch_steps(timesteps, guide_phases, model_switch_phase, switch_threshold, switch2_threshold ): |
|
|
total_num_steps = len(timesteps) |
|
|
model_switch_step = model_switch_step2 = None |
|
|
for i, t in enumerate(timesteps): |
|
|
if guide_phases >=2 and model_switch_step is None and t <= switch_threshold: model_switch_step = i |
|
|
if guide_phases >=3 and model_switch_step2 is None and t <= switch2_threshold: model_switch_step2 = i |
|
|
if model_switch_step is None: model_switch_step = total_num_steps |
|
|
if model_switch_step2 is None: model_switch_step2 = total_num_steps |
|
|
phases_description = "" |
|
|
if guide_phases > 1: |
|
|
phases_description = "Denoising Steps: " |
|
|
phases_description += f" Phase 1 = None" if model_switch_step == 0 else f" Phase 1 = 1:{ min(model_switch_step,total_num_steps) }" |
|
|
if model_switch_step < total_num_steps: |
|
|
phases_description += f", Phase 2 = None" if model_switch_step == model_switch_step2 else f", Phase 2 = {model_switch_step +1}:{ min(model_switch_step2,total_num_steps) }" |
|
|
if guide_phases > 2 and model_switch_step2 < total_num_steps: |
|
|
phases_description += f", Phase 3 = {model_switch_step2 +1}:{ total_num_steps}" |
|
|
return model_switch_step, model_switch_step2, phases_description |
|
|
|
|
|
|
|
|
|
|
|
from typing import List, Tuple, Dict, Callable |
|
|
|
|
|
_ALWD = set(":;,.0123456789") |
|
|
|
|
|
|
|
|
|
|
|
def _find_bar(s: str) -> int: |
|
|
com = False |
|
|
for i, ch in enumerate(s): |
|
|
if ch in ('\n', '\r'): |
|
|
com = False |
|
|
elif ch == '#': |
|
|
com = True |
|
|
elif ch == '|' and not com: |
|
|
return i |
|
|
return -1 |
|
|
|
|
|
def _spans(text: str) -> List[Tuple[int, int]]: |
|
|
res, com, in_tok, st = [], False, False, 0 |
|
|
for i, ch in enumerate(text): |
|
|
if ch in ('\n', '\r'): |
|
|
if in_tok: res.append((st, i)); in_tok = False |
|
|
com = False |
|
|
elif ch == '#': |
|
|
if in_tok: res.append((st, i)); in_tok = False |
|
|
com = True |
|
|
elif not com: |
|
|
if ch in _ALWD: |
|
|
if not in_tok: in_tok, st = True, i |
|
|
else: |
|
|
if in_tok: res.append((st, i)); in_tok = False |
|
|
if in_tok: res.append((st, len(text))) |
|
|
return res |
|
|
|
|
|
def _choose_sep(text: str, spans: List[Tuple[int, int]]) -> str: |
|
|
if len(spans) >= 2: |
|
|
a, b = spans[-2][1], spans[-1][0] |
|
|
return '\n' if ('\n' in text[a:b] or '\r' in text[a:b]) else ' ' |
|
|
return '\n' if ('\n' in text or '\r' in text) else ' ' |
|
|
|
|
|
def _ends_in_comment_line(text: str) -> bool: |
|
|
ln = text.rfind('\n') |
|
|
seg = text[ln + 1:] if ln != -1 else text |
|
|
return '#' in seg |
|
|
|
|
|
def _append_tokens(text: str, k: int, sep: str) -> str: |
|
|
if k <= 0: return text |
|
|
t = text |
|
|
if _ends_in_comment_line(t) and (not t.endswith('\n')): t += '\n' |
|
|
parts = [] |
|
|
if t and not t[-1].isspace(): parts.append(sep) |
|
|
parts.append('1') |
|
|
for _ in range(k - 1): |
|
|
parts.append(sep); parts.append('1') |
|
|
return t + ''.join(parts) |
|
|
|
|
|
def _erase_span_and_one_sep(text: str, st: int, en: int) -> str: |
|
|
n = len(text) |
|
|
r = en |
|
|
while r < n and text[r] in (' ', '\t'): r += 1 |
|
|
if r > en: return text[:st] + text[r:] |
|
|
l = st |
|
|
while l > 0 and text[l-1] in (' ', '\t'): l -= 1 |
|
|
if l < st: return text[:l] + text[en:] |
|
|
return text[:st] + text[en:] |
|
|
|
|
|
def _trim_last_tokens(text: str, spans: List[Tuple[int, int]], drop: int) -> str: |
|
|
if drop <= 0: return text |
|
|
new_text = text |
|
|
for st, en in reversed(spans[-drop:]): |
|
|
new_text = _erase_span_and_one_sep(new_text, st, en) |
|
|
while new_text and new_text[-1] in (' ', '\t'): |
|
|
new_text = new_text[:-1] |
|
|
return new_text |
|
|
|
|
|
def _enforce_count(text: str, target: int) -> str: |
|
|
sp = _spans(text); cur = len(sp) |
|
|
if cur == target: return text |
|
|
if cur > target: return _trim_last_tokens(text, sp, cur - target) |
|
|
sep = _choose_sep(text, sp) |
|
|
return _append_tokens(text, target - cur, sep) |
|
|
|
|
|
def _strip_bars_outside_comments(s: str) -> str: |
|
|
com, out = False, [] |
|
|
for ch in s: |
|
|
if ch in ('\n', '\r'): com = False; out.append(ch) |
|
|
elif ch == '#': com = True; out.append(ch) |
|
|
elif ch == '|' and not com: continue |
|
|
else: out.append(ch) |
|
|
return ''.join(out) |
|
|
|
|
|
def _replace_tokens(text: str, repl: Dict[int, str]) -> str: |
|
|
if not repl: return text |
|
|
sp = _spans(text) |
|
|
for idx in sorted(repl.keys(), reverse=True): |
|
|
if 0 <= idx < len(sp): |
|
|
st, en = sp[idx] |
|
|
text = text[:st] + repl[idx] + text[en:] |
|
|
return text |
|
|
|
|
|
def _drop_tokens_by_indices(text: str, idxs: List[int]) -> str: |
|
|
if not idxs: return text |
|
|
out = text |
|
|
for idx in sorted(set(idxs), reverse=True): |
|
|
sp = _spans(out) |
|
|
if 0 <= idx < len(sp): |
|
|
st, en = sp[idx] |
|
|
out = _erase_span_and_one_sep(out, st, en) |
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
def _default_path_key(p: str) -> str: |
|
|
s = p.strip().replace('\\', '/') |
|
|
while '//' in s: s = s.replace('//', '/') |
|
|
if len(s) > 1 and s.endswith('/'): s = s[:-1] |
|
|
return s |
|
|
|
|
|
|
|
|
|
|
|
def _select_new_side( |
|
|
loras_new: List[str], |
|
|
mult_new: str, |
|
|
mode: str, |
|
|
) -> Tuple[List[str], str]: |
|
|
""" |
|
|
Split mult_new on '|' (outside comments) and split loras_new accordingly. |
|
|
Return ONLY the side relevant to `mode`. Extras loras (if any) are appended to the selected side. |
|
|
""" |
|
|
bi = _find_bar(mult_new) |
|
|
if bi == -1: |
|
|
return loras_new, _strip_bars_outside_comments(mult_new) |
|
|
|
|
|
left, right = mult_new[:bi], mult_new[bi + 1:] |
|
|
nL, nR = len(_spans(left)), len(_spans(right)) |
|
|
L = len(loras_new) |
|
|
|
|
|
|
|
|
b_count = min(nL, L) |
|
|
rem = max(0, L - b_count) |
|
|
a_count = min(nR, rem) |
|
|
extras = max(0, L - (b_count + a_count)) |
|
|
|
|
|
if mode == "merge before": |
|
|
|
|
|
l_sel = loras_new[:b_count] + (loras_new[b_count + a_count : b_count + a_count + extras] if extras else []) |
|
|
m_sel = left |
|
|
else: |
|
|
|
|
|
start_after = b_count |
|
|
l_sel = loras_new[start_after:start_after + a_count] + (loras_new[start_after + a_count : start_after + a_count + extras] if extras else []) |
|
|
m_sel = right |
|
|
|
|
|
return l_sel, _strip_bars_outside_comments(m_sel) |
|
|
|
|
|
|
|
|
|
|
|
def merge_loras_settings( |
|
|
loras_old: List[str], |
|
|
mult_old: str, |
|
|
loras_new: List[str], |
|
|
mult_new: str, |
|
|
mode: str = "merge before", |
|
|
path_key: Callable[[str], str] = _default_path_key, |
|
|
) -> Tuple[List[str], str]: |
|
|
""" |
|
|
Merge settings with full formatting/comment preservation and correct handling of `mult_new` with '|'. |
|
|
Dedup rule: when merging AFTER (resp. BEFORE), if a new lora already exists in preserved BEFORE (resp. AFTER), |
|
|
update that preserved multiplier and drop the duplicate from the replaced side. |
|
|
""" |
|
|
assert mode in ("merge before", "merge after") |
|
|
mult_old= mult_old.strip() |
|
|
mult_new= mult_new.strip() |
|
|
|
|
|
|
|
|
bi_old = _find_bar(mult_old) |
|
|
before_old, after_old = (mult_old[:bi_old], mult_old[bi_old + 1:]) if bi_old != -1 else ("", mult_old) |
|
|
orig_had_bar = (bi_old != -1) |
|
|
|
|
|
sp_b_old, sp_a_old = _spans(before_old), _spans(after_old) |
|
|
n_b_old = len(sp_b_old) |
|
|
total_old = len(loras_old) |
|
|
|
|
|
if n_b_old <= total_old: |
|
|
keep_b = n_b_old |
|
|
keep_a = total_old - keep_b |
|
|
before_old_aligned = before_old |
|
|
after_old_aligned = _enforce_count(after_old, keep_a) |
|
|
else: |
|
|
keep_b = total_old |
|
|
keep_a = 0 |
|
|
before_old_aligned = _enforce_count(before_old, keep_b) |
|
|
after_old_aligned = _enforce_count(after_old, 0) |
|
|
|
|
|
|
|
|
loras_new_sel, mult_new_sel = _select_new_side(loras_new, mult_new, mode) |
|
|
mult_new_aligned = _enforce_count(mult_new_sel, len(loras_new_sel)) |
|
|
sp_new = _spans(mult_new_aligned) |
|
|
new_tokens = [mult_new_aligned[st:en] for st, en in sp_new] |
|
|
|
|
|
if mode == "merge after": |
|
|
|
|
|
preserved_loras = loras_old[:keep_b] |
|
|
preserved_text = before_old_aligned |
|
|
preserved_spans = _spans(preserved_text) |
|
|
pos_by_key: Dict[str, int] = {} |
|
|
for i, lp in enumerate(preserved_loras): |
|
|
k = path_key(lp) |
|
|
if k not in pos_by_key: pos_by_key[k] = i |
|
|
|
|
|
repl_map: Dict[int, str] = {} |
|
|
drop_idxs: List[int] = [] |
|
|
for i, lp in enumerate(loras_new_sel): |
|
|
j = pos_by_key.get(path_key(lp)) |
|
|
if j is not None and j < len(preserved_spans): |
|
|
repl_map[j] = new_tokens[i] if i < len(new_tokens) else "1" |
|
|
drop_idxs.append(i) |
|
|
|
|
|
before_text = _replace_tokens(preserved_text, repl_map) |
|
|
after_text = _drop_tokens_by_indices(mult_new_aligned, drop_idxs) |
|
|
loras_keep = [lp for i, lp in enumerate(loras_new_sel) if i not in set(drop_idxs)] |
|
|
loras_out = preserved_loras + loras_keep |
|
|
|
|
|
else: |
|
|
|
|
|
preserved_loras = loras_old[keep_b:] |
|
|
preserved_text = after_old_aligned |
|
|
preserved_spans = _spans(preserved_text) |
|
|
pos_by_key: Dict[str, int] = {} |
|
|
for i, lp in enumerate(preserved_loras): |
|
|
k = path_key(lp) |
|
|
if k not in pos_by_key: pos_by_key[k] = i |
|
|
|
|
|
repl_map: Dict[int, str] = {} |
|
|
drop_idxs: List[int] = [] |
|
|
for i, lp in enumerate(loras_new_sel): |
|
|
j = pos_by_key.get(path_key(lp)) |
|
|
if j is not None and j < len(preserved_spans): |
|
|
repl_map[j] = new_tokens[i] if i < len(new_tokens) else "1" |
|
|
drop_idxs.append(i) |
|
|
|
|
|
after_text = _replace_tokens(preserved_text, repl_map) |
|
|
before_text = _drop_tokens_by_indices(mult_new_aligned, drop_idxs) |
|
|
loras_keep = [lp for i, lp in enumerate(loras_new_sel) if i not in set(drop_idxs)] |
|
|
loras_out = loras_keep + preserved_loras |
|
|
|
|
|
|
|
|
has_before = len(_spans(before_text)) > 0 |
|
|
has_after = len(_spans(after_text)) > 0 |
|
|
if has_before and has_after: |
|
|
mult_out = f"{before_text}|{after_text}" |
|
|
elif has_before: |
|
|
mult_out = before_text + ('|' if (mode == 'merge before' or orig_had_bar) else '') |
|
|
else: |
|
|
mult_out = after_text |
|
|
|
|
|
return loras_out, mult_out |
|
|
|
|
|
|
|
|
|
|
|
def extract_loras_side( |
|
|
loras: List[str], |
|
|
mult: str, |
|
|
which: str = "before", |
|
|
) -> Tuple[List[str], str]: |
|
|
assert which in ("before", "after") |
|
|
bi = _find_bar(mult) |
|
|
before_txt, after_txt = (mult[:bi], mult[bi + 1:]) if bi != -1 else ("", mult) |
|
|
|
|
|
sp_b = _spans(before_txt) |
|
|
n_b = len(sp_b) |
|
|
total = len(loras) |
|
|
|
|
|
if n_b <= total: |
|
|
keep_b = n_b |
|
|
keep_a = total - keep_b |
|
|
else: |
|
|
keep_b = total |
|
|
keep_a = 0 |
|
|
|
|
|
if which == "before": |
|
|
return loras[:keep_b], _enforce_count(before_txt, keep_b) |
|
|
else: |
|
|
return loras[keep_b:keep_b + keep_a], _enforce_count(after_txt, keep_a) |
|
|
|