# Provenance: kenji999's copy, duplicated from Phr00t/LTX2-Rapid-Merges (rev 1cc67f2)
#!/usr/bin/env python3
# THIS IS FOR ADVANCED LORA INTO BASE MODEL MERGING
# Designed for use with the LTX2 model
#
# Make a text file with a list of LORAs you want to merge in this format for each line:
# <path to safetensors>,<strength>,<lerp>
#
# The "lerp" parameter means "how much should I overwrite tensors that compete with LORAs listed above?". A value of "0" just mixes them all together, while "1" hard applies the LORA delta
#
# You can also supply a separate audio and video strengths like this:
# <path to safetensors>,<video strength>,<lerp>,<audio strength>
#
# Use this script like this:
# python fancy-apply.py <base model safetensors> <lora list txt file> <merged output filename>
import argparse
import os
from typing import Dict, Tuple, List, Optional
from collections import defaultdict
import math
import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open
# ----------------- Tuning knobs ----------------- #
# If True, the normalized component uses:
#   scale_norm = eff_strength / max(1.0, sum_eff_strengths_for_key)
# so that multiple LoRAs touching the same base key share its update
# proportionally to their strengths rather than stacking at full power.
NORMALIZE_OVERLAPS = True
# Per-LoRA clipping threshold:
# If not None, each LoRA's delta is clipped so that:
#   ||delta|| <= CLIP_RATIO * ||W||
# (W being the base weight). Set to None to disable clipping.
CLIP_RATIO: Optional[float] = 1.0
# ----------------- Parsing LoRA list ----------------- #
def parse_lora_list(path: str) -> List[Tuple[str, float, float, float]]:
    """
    Read a LoRA spec file and return one tuple per non-comment line.

    Each line has the form:
        <path>,<video_strength>,<lerp>[,<audio_strength>]

    Returns:
        List of (path, video_strength, lerp, audio_strength) tuples, where
        audio_strength defaults to video_strength when omitted and lerp is
        clamped into [0, 1] (0.0 -> fully normalized, 1.0 -> fully direct,
        values between blend the two scales).

    Raises:
        ValueError: if a non-blank, non-comment line has fewer than 3 fields.
    """
    specs: List[Tuple[str, float, float, float]] = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            entry = raw.strip()
            # Blank lines and '#' comments are ignored.
            if not entry or entry.startswith("#"):
                continue
            fields = [field.strip() for field in entry.split(",")]
            if len(fields) < 3:
                raise ValueError(f"Invalid LoRA line (need at least file,video_strength,lerp): {entry}")
            lora_file = fields[0]
            vid = float(fields[1])
            # Clamp the blend factor into its documented [0, 1] range.
            blend = min(1.0, max(0.0, float(fields[2])))
            aud = float(fields[3]) if len(fields) >= 4 else vid
            specs.append((lora_file, vid, blend, aud))
    return specs
# ----------------- Base loading ----------------- #
def load_base_with_metadata(path: str):
    """
    Load a safetensors checkpoint onto CPU together with its header metadata.

    Returns:
        (tensors, metadata) where metadata is the checkpoint's header dict,
        or {} when the file carries none.
    """
    # safe_open reads only the header, so grabbing metadata first is cheap.
    with safe_open(path, framework="pt", device="cpu") as reader:
        meta = reader.metadata() or {}
    weights = load_file(path, device="cpu")
    return weights, meta
# ----------------- LoRA key grouping ----------------- #
def group_lora_pairs(lora_tensors: Dict[str, torch.Tensor]):
    """
    Yield (prefix, A_key, B_key, alpha_key_or_None) for each complete LoRA
    pair in the tensor dict.

    Keys are grouped by stripping the ".lora_A.weight", ".lora_B.weight"
    and ".alpha" suffixes; prefixes missing either the A or B matrix are
    reported with a warning and skipped. The optional ".alpha" scalar is
    passed through as its key (or None when absent).
    """
    suffix_roles = {
        ".lora_A.weight": "A",
        ".lora_B.weight": "B",
        ".alpha": "alpha",
    }
    grouped: Dict[str, Dict[str, str]] = {}
    for key in lora_tensors:
        for suffix, role in suffix_roles.items():
            if key.endswith(suffix):
                grouped.setdefault(key[: -len(suffix)], {})[role] = key
                break
    for prefix, roles in grouped.items():
        if "A" in roles and "B" in roles:
            yield prefix, roles["A"], roles["B"], roles.get("alpha")
        else:
            print(f"Warning: incomplete LoRA prefix {prefix}")
def find_base_weight_key(base_tensors, lora_prefix):
    """
    Resolve a LoRA prefix to the matching base checkpoint key, or None.

    Tries, in priority order: the prefix with a ".weight" suffix, the same
    under a "model." namespace, then the bare prefix with and without the
    "model." namespace.
    """
    for candidate in (
        f"{lora_prefix}.weight",
        f"model.{lora_prefix}.weight",
        lora_prefix,
        f"model.{lora_prefix}",
    ):
        if candidate in base_tensors:
            return candidate
    return None
# ----------------- Audio / video classification ----------------- #
def classify_prefix(prefix: str) -> str:
    """
    Bucket a LoRA prefix into 'audio', 'video', 'cross', or 'shared'.

    Substring heuristics are checked case-insensitively in priority order:
    cross-modal markers first, then audio, then video; anything left is
    'shared' (later treated with the video strength).
    """
    lowered = prefix.lower()
    if any(tag in lowered for tag in ("audio_to_video", "video_to_audio")):
        return "cross"
    if any(tag in lowered for tag in ("audio_attn", "audio_ff", ".audio_")):
        return "audio"
    if any(tag in lowered for tag in ("video_attn", "video_ff", ".video_")):
        return "video"
    return "shared"
def effective_strength_for_prefix(
    prefix: str,
    video_strength: float,
    audio_strength: float,
) -> float:
    """
    Pick the strength that applies to this prefix based on its modality.

    'audio' -> audio_strength; 'video' and 'shared' -> video_strength;
    'cross' -> geometric mean of the two (negatives floored at 0, so a
    non-positive strength on either side zeroes the cross term).
    """
    kind = classify_prefix(prefix)
    if kind == "cross":
        return math.sqrt(max(video_strength, 0.0) * max(audio_strength, 0.0))
    if kind == "audio":
        return audio_strength
    # Both 'video' and the default 'shared' bucket use the video strength.
    return video_strength
# ----------------- Pass 1: strength sums per key ----------------- #
def compute_strength_sums(
    base_tensors,
    lora_specs: List[Tuple[str, float, float, float]],
) -> Dict[str, float]:
    """
    Pass 1: total the effective strength that every LoRA contributes to
    each base weight key.

    Modality classification (audio/video/cross/shared) decides which of a
    spec's strengths applies per prefix. LoRA files are loaded one at a
    time and released immediately to keep peak memory low.
    """
    totals: Dict[str, float] = defaultdict(float)
    for lora_path, vid, lerp, aud in lora_specs:
        print(f"[Pass 1] Scanning {lora_path} (video={vid}, audio={aud}, lerp={lerp})")
        tensors = load_file(lora_path, device="cpu")
        for prefix, _a_key, _b_key, _alpha_key in group_lora_pairs(tensors):
            key = find_base_weight_key(base_tensors, prefix)
            # Prefixes without a matching base weight contribute nothing.
            if key is not None:
                totals[key] += effective_strength_for_prefix(prefix, vid, aud)
        del tensors
    print(f"[Pass 1] Keys with strength contributions: {len(totals)}")
    return totals
# ----------------- Pass 2: streaming application ----------------- #
def apply_loras_streaming(
    base_tensors,
    lora_specs: List[Tuple[str, float, float, float]],
    strength_sum: Dict[str, float],
    clip_ratio: Optional[float] = CLIP_RATIO,
):
    """
    Pass 2: stream each LoRA file and fold its deltas into the base
    weights in place.

    For every complete (A, B) pair the update is delta = B @ A, scaled by
    the modality-appropriate strength (and alpha/rank when an alpha tensor
    is present), blended between a normalized scale (divided by the total
    strength touching the same key) and the direct scale via the per-LoRA
    lerp, optionally clipped to ``clip_ratio * ||W||``, and added to W.

    Args:
        base_tensors: mutable dict of base checkpoint tensors; mutated in place.
        lora_specs: (path, video_strength, lerp, audio_strength) tuples.
        strength_sum: per-base-key effective-strength totals from pass 1.
        clip_ratio: cap on the delta's Frobenius norm relative to the base
            weight's norm; None disables clipping.

    Raises:
        ValueError: if a LoRA delta's shape does not match its base weight.
    """
    for lora_path, video_strength, lerp, audio_strength in lora_specs:
        print(f"[Pass 2] Applying {lora_path} (video={video_strength}, audio={audio_strength}, lerp={lerp})")
        lora_tensors = load_file(lora_path, device="cpu")
        applied = 0
        skipped = 0
        for prefix, A_key, B_key, alpha_key in group_lora_pairs(lora_tensors):
            base_key = find_base_weight_key(base_tensors, prefix)
            if base_key is None:
                skipped += 1
                continue
            W = base_tensors[base_key]
            # Compute the delta in fp32 regardless of storage dtype.
            A = lora_tensors[A_key].to(torch.float32)
            B = lora_tensors[B_key].to(torch.float32)
            delta = B @ A
            if delta.shape != W.shape:
                raise ValueError(
                    f"Shape mismatch for {prefix}: delta {delta.shape} vs base {W.shape}"
                )
            # Assumes A is the (rank, in_features) down-projection for the
            # 2-D case; for other ranks numel() is used as a fallback --
            # NOTE(review): confirm that fallback is right for conv LoRAs.
            rank = A.shape[0] if A.dim() == 2 else A.numel()
            # Effective strength for this prefix (audio/video/cross/shared)
            eff_strength = effective_strength_for_prefix(prefix, video_strength, audio_strength)
            # Base strength + alpha scaling (standard alpha/rank convention).
            if alpha_key is not None:
                alpha = float(lora_tensors[alpha_key].to(torch.float32).item())
                base_scale = eff_strength * alpha / max(rank, 1)
            else:
                base_scale = eff_strength
            # Weighted normalization: share the key's update across all
            # LoRAs touching it, in proportion to their strengths.
            if NORMALIZE_OVERLAPS:
                total_strength = strength_sum.get(base_key, 0.0)
                # max(1.0, ...) means normalization only kicks in once the
                # combined strength on a key exceeds 1.
                denom = max(1.0, total_strength)
                scale_norm = base_scale / denom
            else:
                scale_norm = base_scale
            # Direct (unnormalized) component
            scale_direct = base_scale
            # LERP between normalized and direct per this LoRA's lerp knob.
            scale = (1.0 - lerp) * scale_norm + lerp * scale_direct
            delta_scaled = delta * scale
            # Per-LoRA clipping: cap the delta's norm at clip_ratio * ||W||.
            if clip_ratio is not None:
                Wf = W.to(torch.float32)
                base_norm = Wf.norm().item()
                delta_norm = delta_scaled.norm().item()
                if delta_norm > clip_ratio * base_norm and delta_norm > 0:
                    delta_scaled *= (clip_ratio * base_norm) / delta_norm
            # Apply update in fp32, then cast back to the base dtype.
            W_new = W.to(torch.float32) + delta_scaled
            base_tensors[base_key] = W_new.to(W.dtype)
            applied += 1
        print(f"[Pass 2] {lora_path}: applied {applied}, skipped {skipped}")
        del lora_tensors
def apply_loras_to_base(base_tensors, lora_specs):
    """Run the two-pass merge: tally per-key strengths, then apply every LoRA."""
    per_key_totals = compute_strength_sums(base_tensors, lora_specs)
    apply_loras_streaming(base_tensors, lora_specs, per_key_totals)
# ----------------- FP8 conversion ----------------- #
def is_vae_key(key: str) -> bool:
    """Return True when the tensor key belongs to the VAE (kept out of FP8)."""
    # str.startswith accepts a tuple: one call covers all namespaces.
    return key.startswith((
        "first_stage_model.",
        "model.first_stage_model.",
        "vae.",
        "model.vae.",
    ))
def is_text_encoder_key(key: str) -> bool:
    """Return True when the tensor key belongs to the text encoder."""
    # Covers both bare and "model."-namespaced conditioning-stage keys.
    return key.startswith((
        "text_encoder.",
        "model.text_encoder.",
        "cond_stage_model.",
        "model.cond_stage_model.",
    ))
def is_unet_key(key: str) -> bool:
    """Return True when the tensor key belongs to the diffusion model (UNet)."""
    return key.startswith(("model.diffusion_model.", "diffusion_model."))
def convert_to_fp8_inplace(tensors: Dict[str, torch.Tensor]):
    """
    Cast UNet and text-encoder floating-point tensors to float8_e4m3fn
    in place.

    VAE tensors, non-floating tensors, and anything outside the known
    namespaces are left untouched. Prints a converted/skipped summary.
    """
    fp8_dtype = torch.float8_e4m3fn
    converted = 0
    skipped_vae = 0
    skipped_other = 0
    # list() snapshots the items since the dict is mutated during the loop.
    for key, tensor in list(tensors.items()):
        if not torch.is_floating_point(tensor):
            skipped_other += 1
        elif is_vae_key(key):
            skipped_vae += 1
        elif is_unet_key(key) or is_text_encoder_key(key):
            tensors[key] = tensor.to(fp8_dtype)
            converted += 1
        else:
            skipped_other += 1
    print(
        f"FP8 conversion: converted={converted}, "
        f"skipped_vae={skipped_vae}, skipped_other={skipped_other}"
    )
# ----------------- Main CLI ----------------- #
def main():
    """
    CLI entry point: merge the listed LoRAs into the base checkpoint,
    convert UNet + text encoder to FP8, and save with original metadata.

    Usage: fancy-apply.py <base.safetensors> <lora_list.txt> <output.safetensors>

    Raises:
        FileNotFoundError: if the base checkpoint path does not exist.
        ValueError: if the LoRA list file yields no specs.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Apply LTX2-style LoRAs with separate video/audio strengths, "
            "strength-weighted normalization, LERP blending, per‑LoRA clipping, "
            "FP8 conversion, and metadata preservation (streaming, memory‑efficient)."
        )
    )
    parser.add_argument("base", help="Base checkpoint (.safetensors)")
    parser.add_argument("lora_list", help="Text file: path,video_strength,lerp[,audio_strength]")
    parser.add_argument("output", help="Output FP8 checkpoint (.safetensors)")
    args = parser.parse_args()
    if not os.path.isfile(args.base):
        raise FileNotFoundError(args.base)
    lora_specs = parse_lora_list(args.lora_list)
    if not lora_specs:
        raise ValueError("No LoRAs specified.")
    print(f"Loading base checkpoint: {args.base}")
    base_tensors, metadata = load_base_with_metadata(args.base)
    print(f"Base checkpoint has {len(base_tensors)} tensors.")
    # Two-pass merge: strength sums first, then streaming application.
    apply_loras_to_base(base_tensors, lora_specs)
    print("Converting UNet + text encoder to FP8 (leaving VAE untouched)...")
    convert_to_fp8_inplace(base_tensors)
    print(f"Saving merged FP8 checkpoint to: {args.output}")
    # 'or "."' handles bare filenames whose dirname is the empty string.
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    # Re-attach the original header metadata so downstream tools keep it.
    save_file(base_tensors, args.output, metadata=metadata)
    print("Done.")
# Entry-point guard: run the CLI only when executed as a script.
if __name__ == "__main__":
    main()