| from typing import List, Tuple, Dict, Callable
|
|
|
|
|
| def preparse_loras_multipliers(loras_multipliers):
|
| if isinstance(loras_multipliers, list):
|
| return [multi.strip(" \r\n") if isinstance(multi, str) else multi for multi in loras_multipliers]
|
|
|
| loras_multipliers = loras_multipliers.strip(" \r\n")
|
| loras_mult_choices_list = loras_multipliers.replace("\r", "").split("\n")
|
| loras_mult_choices_list = [multi.strip() for multi in loras_mult_choices_list if len(multi)>0 and not multi.startswith("#")]
|
| loras_multipliers = " ".join(loras_mult_choices_list)
|
| return loras_multipliers.replace("|"," ").strip().split(" ")
|
|
|
| def expand_slist(slists_dict, mult_no, num_inference_steps, model_switch_step, model_switch_step2 ):
|
| def expand_one(slist, num_inference_steps):
|
| if not isinstance(slist, list): slist = [slist]
|
| new_slist= []
|
| if num_inference_steps <=0:
|
| return new_slist
|
| inc = len(slist) / num_inference_steps
|
| pos = 0
|
| for i in range(num_inference_steps):
|
| new_slist.append(slist[ int(pos)])
|
| pos += inc
|
| return new_slist
|
|
|
| phase1 = slists_dict["phase1"][mult_no]
|
| phase2 = slists_dict["phase2"][mult_no]
|
| phase3 = slists_dict["phase3"][mult_no]
|
| shared = slists_dict["shared"][mult_no]
|
| if shared:
|
| if isinstance(phase1, float): return phase1
|
| return expand_one(phase1, num_inference_steps)
|
| else:
|
| if isinstance(phase1, float) and isinstance(phase2, float) and isinstance(phase3, float) and phase1 == phase2 and phase2 == phase3: return phase1
|
| return expand_one(phase1, model_switch_step) + expand_one(phase2, model_switch_step2 - model_switch_step) + expand_one(phase3, num_inference_steps - model_switch_step2)
|
|
|
| def parse_loras_multipliers(loras_multipliers, nb_loras, num_inference_steps, merge_slist = None, nb_phases = 2, model_switch_step = None, model_switch_step2 = None, model_switch_phase = 1):
|
| if "|" in loras_multipliers:
|
| pos = loras_multipliers.find("|")
|
| if "|" in loras_multipliers[pos+1:]: return "", "", "There can be only one '|' character in Loras Multipliers Sequence"
|
|
|
| if model_switch_step is None:
|
| model_switch_step = num_inference_steps
|
| if model_switch_step2 is None:
|
| model_switch_step2 = num_inference_steps
|
| def is_float(element: any) -> bool:
|
| if element is None:
|
| return False
|
| try:
|
| float(element)
|
| return True
|
| except ValueError:
|
| return False
|
| loras_list_mult_choices_nums = []
|
| slists_dict = { "model_switch_step": model_switch_step}
|
| slists_dict = { "model_switch_step2": model_switch_step2}
|
| slists_dict["phase1"] = phase1 = [1.] * nb_loras
|
| slists_dict["phase2"] = phase2 = [1.] * nb_loras
|
| slists_dict["phase3"] = phase3 = [1.] * nb_loras
|
| slists_dict["shared"] = shared = [False] * nb_loras
|
|
|
| if isinstance(loras_multipliers, list) or len(loras_multipliers) > 0:
|
| list_mult_choices_list = preparse_loras_multipliers(loras_multipliers)[:nb_loras]
|
| for i, mult in enumerate(list_mult_choices_list):
|
| current_phase = phase1
|
| if isinstance(mult, str):
|
| mult = mult.strip()
|
| phase_mult = mult.split(";")
|
| shared_phases = len(phase_mult) <=1
|
| if not shared_phases and len(phase_mult) != nb_phases :
|
| if len(phase_mult) > nb_phases:
|
| return "", "", f"if the ';' syntax is used for one Lora multiplier, there should be at most {nb_phases} phases for this multiplier"
|
| phase_mult = (phase_mult[:1] + phase_mult) if model_switch_phase == 2 else (phase_mult + phase_mult[-1:])
|
| for phase_no, mult in enumerate(phase_mult):
|
| if phase_no == 1:
|
| current_phase = phase2
|
| elif phase_no == 2:
|
| current_phase = phase3
|
| if "," in mult:
|
| multlist = mult.split(",")
|
| slist = []
|
| for smult in multlist:
|
| if not is_float(smult):
|
| return "", "", f"Lora sub value no {i+1} ({smult}) in Multiplier definition '{multlist}' is invalid in Phase {phase_no+1}"
|
| slist.append(float(smult))
|
| else:
|
| if not is_float(mult):
|
| return "", "", f"Lora Multiplier no {i+1} ({mult}) is invalid"
|
| slist = float(mult)
|
| if shared_phases:
|
| phase1[i] = phase2[i] = phase3[i] = slist
|
| shared[i] = True
|
| else:
|
| current_phase[i] = slist
|
| else:
|
| phase1[i] = phase2[i] = phase3[i] = float(mult)
|
| shared[i] = True
|
|
|
| if merge_slist is not None:
|
| slists_dict["phase1"] = phase1 = merge_slist["phase1"] + phase1
|
| slists_dict["phase2"] = phase2 = merge_slist["phase2"] + phase2
|
| slists_dict["phase3"] = phase3 = merge_slist["phase3"] + phase3
|
| slists_dict["shared"] = shared = merge_slist["shared"] + shared
|
|
|
| loras_list_mult_choices_nums = [ expand_slist(slists_dict, i, num_inference_steps, model_switch_step, model_switch_step2 ) for i in range(len(phase1)) ]
|
| loras_list_mult_choices_nums = [ slist[0] if isinstance(slist, list) else slist for slist in loras_list_mult_choices_nums ]
|
|
|
| return loras_list_mult_choices_nums, slists_dict, ""
|
|
|
| def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_step = None, phase_switch_step2 = None):
|
| from mmgp import offload
|
| sz = len(slists_dict["phase1"])
|
| slists = [ expand_slist(slists_dict, i, num_inference_steps, phase_switch_step, phase_switch_step2 ) for i in range(sz) ]
|
| nos = [str(l) for l in range(sz)]
|
| offload.activate_loras(trans, nos, slists )
|
|
|
|
|
|
|
| def get_model_switch_steps(timesteps, guide_phases, model_switch_phase, switch_threshold, switch2_threshold ):
|
| total_num_steps = len(timesteps)
|
| model_switch_step = model_switch_step2 = None
|
| for i, t in enumerate(timesteps):
|
| if guide_phases >=2 and model_switch_step is None and t <= switch_threshold: model_switch_step = i
|
| if guide_phases >=3 and model_switch_step2 is None and t <= switch2_threshold: model_switch_step2 = i
|
| if model_switch_step is None: model_switch_step = total_num_steps
|
| if model_switch_step2 is None: model_switch_step2 = total_num_steps
|
| phases_description = ""
|
| if guide_phases > 1:
|
| phases_description = "Denoising Steps: "
|
| phases_description += f" Phase 1 = None" if model_switch_step == 0 else f" Phase 1 = 1:{ min(model_switch_step,total_num_steps) }"
|
| if model_switch_step < total_num_steps:
|
| phases_description += f", Phase 2 = None" if model_switch_step == model_switch_step2 else f", Phase 2 = {model_switch_step +1}:{ min(model_switch_step2,total_num_steps) }"
|
| if guide_phases > 2 and model_switch_step2 < total_num_steps:
|
| phases_description += f", Phase 3 = {model_switch_step2 +1}:{ total_num_steps}"
|
| return model_switch_step, model_switch_step2, phases_description
|
|
|
|
|
|
|
| from typing import List, Tuple, Dict, Callable
|
|
|
| _ALWD = set(":;,.0123456789")
|
|
|
|
|
|
|
| def _find_bar(s: str) -> int:
|
| com = False
|
| for i, ch in enumerate(s):
|
| if ch in ('\n', '\r'):
|
| com = False
|
| elif ch == '#':
|
| com = True
|
| elif ch == '|' and not com:
|
| return i
|
| return -1
|
|
|
| def _spans(text: str) -> List[Tuple[int, int]]:
|
| res, com, in_tok, st = [], False, False, 0
|
| for i, ch in enumerate(text):
|
| if ch in ('\n', '\r'):
|
| if in_tok: res.append((st, i)); in_tok = False
|
| com = False
|
| elif ch == '#':
|
| if in_tok: res.append((st, i)); in_tok = False
|
| com = True
|
| elif not com:
|
| if ch in _ALWD:
|
| if not in_tok: in_tok, st = True, i
|
| else:
|
| if in_tok: res.append((st, i)); in_tok = False
|
| if in_tok: res.append((st, len(text)))
|
| return res
|
|
|
| def _choose_sep(text: str, spans: List[Tuple[int, int]]) -> str:
|
| if len(spans) >= 2:
|
| a, b = spans[-2][1], spans[-1][0]
|
| return '\n' if ('\n' in text[a:b] or '\r' in text[a:b]) else ' '
|
| return '\n' if ('\n' in text or '\r' in text) else ' '
|
|
|
| def _ends_in_comment_line(text: str) -> bool:
|
| ln = text.rfind('\n')
|
| seg = text[ln + 1:] if ln != -1 else text
|
| return '#' in seg
|
|
|
| def _append_tokens(text: str, k: int, sep: str) -> str:
|
| if k <= 0: return text
|
| t = text
|
| if _ends_in_comment_line(t) and (not t.endswith('\n')): t += '\n'
|
| parts = []
|
| if t and not t[-1].isspace(): parts.append(sep)
|
| parts.append('1')
|
| for _ in range(k - 1):
|
| parts.append(sep); parts.append('1')
|
| return t + ''.join(parts)
|
|
|
| def _erase_span_and_one_sep(text: str, st: int, en: int) -> str:
|
| n = len(text)
|
| r = en
|
| while r < n and text[r] in (' ', '\t'): r += 1
|
| if r > en: return text[:st] + text[r:]
|
| l = st
|
| while l > 0 and text[l-1] in (' ', '\t'): l -= 1
|
| if l < st: return text[:l] + text[en:]
|
| return text[:st] + text[en:]
|
|
|
| def _trim_last_tokens(text: str, spans: List[Tuple[int, int]], drop: int) -> str:
|
| if drop <= 0: return text
|
| new_text = text
|
| for st, en in reversed(spans[-drop:]):
|
| new_text = _erase_span_and_one_sep(new_text, st, en)
|
| while new_text and new_text[-1] in (' ', '\t'):
|
| new_text = new_text[:-1]
|
| return new_text
|
|
|
| def _enforce_count(text: str, target: int) -> str:
|
| sp = _spans(text); cur = len(sp)
|
| if cur == target: return text
|
| if cur > target: return _trim_last_tokens(text, sp, cur - target)
|
| sep = _choose_sep(text, sp)
|
| return _append_tokens(text, target - cur, sep)
|
|
|
| def _strip_bars_outside_comments(s: str) -> str:
|
| com, out = False, []
|
| for ch in s:
|
| if ch in ('\n', '\r'): com = False; out.append(ch)
|
| elif ch == '#': com = True; out.append(ch)
|
| elif ch == '|' and not com: continue
|
| else: out.append(ch)
|
| return ''.join(out)
|
|
|
| def _replace_tokens(text: str, repl: Dict[int, str]) -> str:
|
| if not repl: return text
|
| sp = _spans(text)
|
| for idx in sorted(repl.keys(), reverse=True):
|
| if 0 <= idx < len(sp):
|
| st, en = sp[idx]
|
| text = text[:st] + repl[idx] + text[en:]
|
| return text
|
|
|
| def _drop_tokens_by_indices(text: str, idxs: List[int]) -> str:
|
| if not idxs: return text
|
| out = text
|
| for idx in sorted(set(idxs), reverse=True):
|
| sp = _spans(out)
|
| if 0 <= idx < len(sp):
|
| st, en = sp[idx]
|
| out = _erase_span_and_one_sep(out, st, en)
|
| return out
|
|
|
|
|
|
|
| def _default_path_key(p: str) -> str:
|
| s = p.strip().replace('\\', '/')
|
| while '//' in s: s = s.replace('//', '/')
|
| if len(s) > 1 and s.endswith('/'): s = s[:-1]
|
| return s
|
|
|
|
|
|
|
| def _select_new_side(
|
| loras_new: List[str],
|
| mult_new: str,
|
| mode: str,
|
| ) -> Tuple[List[str], str]:
|
| """
|
| Split mult_new on '|' (outside comments) and split loras_new accordingly.
|
| Return ONLY the side relevant to `mode`. Extras loras (if any) are appended to the selected side.
|
| """
|
| bi = _find_bar(mult_new)
|
| if bi == -1:
|
| return loras_new, _strip_bars_outside_comments(mult_new)
|
|
|
| left, right = mult_new[:bi], mult_new[bi + 1:]
|
| nL, nR = len(_spans(left)), len(_spans(right))
|
| L = len(loras_new)
|
|
|
|
|
| b_count = min(nL, L)
|
| rem = max(0, L - b_count)
|
| a_count = min(nR, rem)
|
| extras = max(0, L - (b_count + a_count))
|
|
|
| if mode == "merge before":
|
|
|
| l_sel = loras_new[:b_count] + (loras_new[b_count + a_count : b_count + a_count + extras] if extras else [])
|
| m_sel = left
|
| else:
|
|
|
| start_after = b_count
|
| l_sel = loras_new[start_after:start_after + a_count] + (loras_new[start_after + a_count : start_after + a_count + extras] if extras else [])
|
| m_sel = right
|
|
|
| return l_sel, _strip_bars_outside_comments(m_sel)
|
|
|
|
|
|
|
| def merge_loras_settings(
|
| loras_old: List[str],
|
| mult_old: str,
|
| loras_new: List[str],
|
| mult_new: str,
|
| mode: str = "merge before",
|
| path_key: Callable[[str], str] = _default_path_key,
|
| ) -> Tuple[List[str], str]:
|
| """
|
| Merge settings with full formatting/comment preservation and correct handling of `mult_new` with '|'.
|
| Dedup rule: when merging AFTER (resp. BEFORE), if a new lora already exists in preserved BEFORE (resp. AFTER),
|
| update that preserved multiplier and drop the duplicate from the replaced side.
|
| """
|
| assert mode in ("merge before", "merge after")
|
| mult_old= mult_old.strip()
|
| mult_new= mult_new.strip()
|
|
|
|
|
| bi_old = _find_bar(mult_old)
|
| before_old, after_old = (mult_old[:bi_old], mult_old[bi_old + 1:]) if bi_old != -1 else ("", mult_old)
|
| orig_had_bar = (bi_old != -1)
|
|
|
| sp_b_old, sp_a_old = _spans(before_old), _spans(after_old)
|
| n_b_old = len(sp_b_old)
|
| total_old = len(loras_old)
|
|
|
| if n_b_old <= total_old:
|
| keep_b = n_b_old
|
| keep_a = total_old - keep_b
|
| before_old_aligned = before_old
|
| after_old_aligned = _enforce_count(after_old, keep_a)
|
| else:
|
| keep_b = total_old
|
| keep_a = 0
|
| before_old_aligned = _enforce_count(before_old, keep_b)
|
| after_old_aligned = _enforce_count(after_old, 0)
|
|
|
|
|
| loras_new_sel, mult_new_sel = _select_new_side(loras_new, mult_new, mode)
|
| mult_new_aligned = _enforce_count(mult_new_sel, len(loras_new_sel))
|
| sp_new = _spans(mult_new_aligned)
|
| new_tokens = [mult_new_aligned[st:en] for st, en in sp_new]
|
|
|
| if mode == "merge after":
|
|
|
| preserved_loras = loras_old[:keep_b]
|
| preserved_text = before_old_aligned
|
| preserved_spans = _spans(preserved_text)
|
| pos_by_key: Dict[str, int] = {}
|
| for i, lp in enumerate(preserved_loras):
|
| k = path_key(lp)
|
| if k not in pos_by_key: pos_by_key[k] = i
|
|
|
| repl_map: Dict[int, str] = {}
|
| drop_idxs: List[int] = []
|
| for i, lp in enumerate(loras_new_sel):
|
| j = pos_by_key.get(path_key(lp))
|
| if j is not None and j < len(preserved_spans):
|
| repl_map[j] = new_tokens[i] if i < len(new_tokens) else "1"
|
| drop_idxs.append(i)
|
|
|
| before_text = _replace_tokens(preserved_text, repl_map)
|
| after_text = _drop_tokens_by_indices(mult_new_aligned, drop_idxs)
|
| loras_keep = [lp for i, lp in enumerate(loras_new_sel) if i not in set(drop_idxs)]
|
| loras_out = preserved_loras + loras_keep
|
|
|
| else:
|
|
|
| preserved_loras = loras_old[keep_b:]
|
| preserved_text = after_old_aligned
|
| preserved_spans = _spans(preserved_text)
|
| pos_by_key: Dict[str, int] = {}
|
| for i, lp in enumerate(preserved_loras):
|
| k = path_key(lp)
|
| if k not in pos_by_key: pos_by_key[k] = i
|
|
|
| repl_map: Dict[int, str] = {}
|
| drop_idxs: List[int] = []
|
| for i, lp in enumerate(loras_new_sel):
|
| j = pos_by_key.get(path_key(lp))
|
| if j is not None and j < len(preserved_spans):
|
| repl_map[j] = new_tokens[i] if i < len(new_tokens) else "1"
|
| drop_idxs.append(i)
|
|
|
| after_text = _replace_tokens(preserved_text, repl_map)
|
| before_text = _drop_tokens_by_indices(mult_new_aligned, drop_idxs)
|
| loras_keep = [lp for i, lp in enumerate(loras_new_sel) if i not in set(drop_idxs)]
|
| loras_out = loras_keep + preserved_loras
|
|
|
|
|
| has_before = len(_spans(before_text)) > 0
|
| has_after = len(_spans(after_text)) > 0
|
| if has_before and has_after:
|
| mult_out = f"{before_text}|{after_text}"
|
| elif has_before:
|
| mult_out = before_text + ('|' if (mode == 'merge before' or orig_had_bar) else '')
|
| else:
|
| mult_out = after_text
|
|
|
| return loras_out, mult_out
|
|
|
|
|
|
|
| def extract_loras_side(
|
| loras: List[str],
|
| mult: str,
|
| which: str = "before",
|
| ) -> Tuple[List[str], str]:
|
| assert which in ("before", "after")
|
| bi = _find_bar(mult)
|
| before_txt, after_txt = (mult[:bi], mult[bi + 1:]) if bi != -1 else ("", mult)
|
|
|
| sp_b = _spans(before_txt)
|
| n_b = len(sp_b)
|
| total = len(loras)
|
|
|
| if n_b <= total:
|
| keep_b = n_b
|
| keep_a = total - keep_b
|
| else:
|
| keep_b = total
|
| keep_a = 0
|
|
|
| if which == "before":
|
| return loras[:keep_b], _enforce_count(before_txt, keep_b)
|
| else:
|
| return loras[keep_b:keep_b + keep_a], _enforce_count(after_txt, keep_a)
|
|
|