|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
import os
|
|
|
from typing import Dict, Tuple, List, Optional
|
|
|
from collections import defaultdict
|
|
|
import math
|
|
|
|
|
|
import torch
|
|
|
from safetensors.torch import load_file, save_file
|
|
|
from safetensors import safe_open
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# When True, a base weight touched by several LoRAs has each contribution
# divided by the summed effective strengths for that key (computed in Pass 1),
# so overlapping LoRAs blend instead of stacking additively.
NORMALIZE_OVERLAPS = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Per-weight clipping: a scaled LoRA delta's norm is capped at CLIP_RATIO times
# the base weight's norm before being added. None disables clipping.
CLIP_RATIO: Optional[float] = 1.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_lora_list(path: str) -> List[Tuple[str, float, float, float]]:
    """
    Parse list_of_loras.txt with lines like:

        filename.safetensors,0.7,0.0
        filename2.safetensors,1.0,0.5,0.3

    Blank lines and lines starting with '#' are ignored.

    Returns list of tuples:
        (path, video_strength, lerp_with_existing, audio_strength)

    Where:
        video_strength: base strength for video/shared weights
        audio_strength: base strength for audio weights
            (defaults to video_strength if omitted)
        lerp_with_existing in [0, 1] (clamped):
            0.0 -> fully normalized
            1.0 -> fully direct
            between -> blend between normalized and direct

    Raises:
        ValueError: if a line has fewer than three fields, or a numeric
            field cannot be parsed (the error names the file, line number,
            and offending line instead of a bare float() message).
    """
    loras: List[Tuple[str, float, float, float]] = []

    with open(path, "r", encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line or line.startswith("#"):
                continue

            parts = [p.strip() for p in line.split(",")]
            if len(parts) < 3:
                raise ValueError(f"Invalid LoRA line (need at least file,video_strength,lerp): {line}")

            filename = parts[0]
            try:
                video_strength = float(parts[1])
                lerp = float(parts[2])
                # Audio strength is optional; fall back to the video strength.
                audio_strength = float(parts[3]) if len(parts) >= 4 else video_strength
            except ValueError as err:
                # Re-raise with context so the user can find the bad entry.
                raise ValueError(
                    f"Invalid numeric value on line {lineno} of {path}: {line}"
                ) from err

            # Clamp the blend factor into [0, 1] as documented above.
            lerp = max(0.0, min(1.0, lerp))

            loras.append((filename, video_strength, lerp, audio_strength))

    return loras
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_base_with_metadata(path: str):
    """Load a safetensors checkpoint together with its header metadata.

    Returns ``(tensors, metadata)`` where *metadata* is ``{}`` when the file
    carries no header metadata.
    """
    tensors = load_file(path, device="cpu")
    with safe_open(path, framework="pt", device="cpu") as handle:
        header_metadata = handle.metadata() or {}
    return tensors, header_metadata
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def group_lora_pairs(lora_tensors: Dict[str, torch.Tensor]):
    """Group LoRA tensor keys by their module prefix.

    Recognizes the ``.lora_A.weight`` / ``.lora_B.weight`` / ``.alpha`` key
    suffixes. For every prefix that has both an A and a B matrix, yields
    ``(prefix, A_key, B_key, alpha_key_or_None)``; prefixes missing either
    matrix are reported and skipped.
    """
    suffix_roles = (
        (".lora_A.weight", "A"),
        (".lora_B.weight", "B"),
        (".alpha", "alpha"),
    )

    grouped: Dict[str, Dict[str, str]] = {}
    for key in lora_tensors:
        for suffix, role in suffix_roles:
            if key.endswith(suffix):
                grouped.setdefault(key[: -len(suffix)], {})[role] = key
                break

    for prefix, roles in grouped.items():
        if "A" in roles and "B" in roles:
            yield prefix, roles["A"], roles["B"], roles.get("alpha")
        else:
            print(f"Warning: incomplete LoRA prefix {prefix}")
|
|
|
|
|
|
|
|
|
def find_base_weight_key(base_tensors, lora_prefix):
    """Return the base-checkpoint key matching a LoRA prefix, or None.

    Tries, in priority order, the prefix with a ``.weight`` suffix and/or a
    ``model.`` namespace prepended, mirroring common checkpoint layouts.
    """
    for candidate in (
        f"{lora_prefix}.weight",
        f"model.{lora_prefix}.weight",
        lora_prefix,
        f"model.{lora_prefix}",
    ):
        if candidate in base_tensors:
            return candidate
    return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_prefix(prefix: str) -> str:
    """
    Classify a LoRA prefix as 'audio', 'video', 'cross', or 'shared'.
    """
    lowered = prefix.lower()

    # Precedence matters: cross-modal markers win over pure audio/video ones.
    rules = (
        ("cross", ("audio_to_video", "video_to_audio")),
        ("audio", ("audio_attn", "audio_ff", ".audio_")),
        ("video", ("video_attn", "video_ff", ".video_")),
    )
    for kind, markers in rules:
        if any(marker in lowered for marker in markers):
            return kind

    return "shared"
|
|
|
|
|
|
|
|
|
def effective_strength_for_prefix(
    prefix: str,
    video_strength: float,
    audio_strength: float,
) -> float:
    """Resolve the merge strength for one LoRA module.

    Audio modules use the audio strength, video modules the video strength,
    and cross-modal modules the geometric mean of the two (each clamped at
    zero so the square root is defined). Everything else ('shared') falls
    back to the video strength.
    """
    kind = classify_prefix(prefix)

    if kind == "cross":
        return math.sqrt(max(video_strength, 0.0) * max(audio_strength, 0.0))
    if kind == "audio":
        return audio_strength
    # Both 'video' and 'shared' use the video/base strength.
    return video_strength
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_strength_sums(
    base_tensors,
    lora_specs: List[Tuple[str, float, float, float]],
) -> Dict[str, float]:
    """
    For each base weight key, compute the sum of effective strengths of all LoRAs
    that touch it (using video/audio/cross classification).

    LoRA files are loaded one at a time and released immediately, so this pass
    stays memory-light.
    """
    totals: Dict[str, float] = defaultdict(float)

    for lora_path, video_strength, lerp, audio_strength in lora_specs:
        print(f"[Pass 1] Scanning {lora_path} (video={video_strength}, audio={audio_strength}, lerp={lerp})")
        tensors = load_file(lora_path, device="cpu")

        for prefix, _a_key, _b_key, _alpha_key in group_lora_pairs(tensors):
            base_key = find_base_weight_key(base_tensors, prefix)
            if base_key is not None:
                totals[base_key] += effective_strength_for_prefix(
                    prefix, video_strength, audio_strength
                )

        # Drop this LoRA's tensors before loading the next file.
        del tensors

    print(f"[Pass 1] Keys with strength contributions: {len(totals)}")
    return totals
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def apply_loras_streaming(
    base_tensors,
    lora_specs: List[Tuple[str, float, float, float]],
    strength_sum: Dict[str, float],
    clip_ratio: Optional[float] = CLIP_RATIO,
):
    """
    Pass 2: merge every LoRA into `base_tensors` in place, one file at a time.

    For each (A, B) pair whose matching base weight is found, the delta B @ A
    is scaled, optionally norm-clipped against the base weight, and added to
    the base tensor (math in float32, result stored back in the base dtype).
    The per-pair scale blends a normalized scale (divided by the Pass-1
    strength sum for that key, when NORMALIZE_OVERLAPS is set) with the
    direct scale, controlled by the LoRA's `lerp` factor.

    Args:
        base_tensors: mapping of base checkpoint keys to tensors; mutated.
        lora_specs: (path, video_strength, lerp, audio_strength) tuples.
        strength_sum: per-base-key summed effective strengths from Pass 1.
        clip_ratio: max allowed ratio of scaled-delta norm to base-weight
            norm; None disables clipping.
    """
    for lora_path, video_strength, lerp, audio_strength in lora_specs:
        print(f"[Pass 2] Applying {lora_path} (video={video_strength}, audio={audio_strength}, lerp={lerp})")
        lora_tensors = load_file(lora_path, device="cpu")

        applied = 0
        skipped = 0

        for prefix, A_key, B_key, alpha_key in group_lora_pairs(lora_tensors):
            base_key = find_base_weight_key(base_tensors, prefix)
            if base_key is None:
                # No matching base weight for this prefix; count and move on.
                skipped += 1
                continue

            W = base_tensors[base_key]

            # Compute the low-rank delta in float32 for numerical stability.
            A = lora_tensors[A_key].to(torch.float32)
            B = lora_tensors[B_key].to(torch.float32)
            delta = B @ A

            if delta.shape != W.shape:
                raise ValueError(
                    f"Shape mismatch for {prefix}: delta {delta.shape} vs base {W.shape}"
                )

            # For a 2-D A the rank is its leading dimension; otherwise fall
            # back to the element count (defensive for non-matrix A tensors).
            rank = A.shape[0] if A.dim() == 2 else A.numel()

            eff_strength = effective_strength_for_prefix(prefix, video_strength, audio_strength)

            # Conventional LoRA scaling: strength * alpha / rank when an
            # alpha tensor is present, plain strength otherwise.
            if alpha_key is not None:
                alpha = float(lora_tensors[alpha_key].to(torch.float32).item())
                base_scale = eff_strength * alpha / max(rank, 1)
            else:
                base_scale = eff_strength

            # Normalized scale: divide by the total strength hitting this key
            # (denominator floored at 1.0 so lone/weak LoRAs are not boosted).
            if NORMALIZE_OVERLAPS:
                total_strength = strength_sum.get(base_key, 0.0)
                denom = max(1.0, total_strength)
                scale_norm = base_scale / denom
            else:
                scale_norm = base_scale

            scale_direct = base_scale

            # Blend normalized vs direct per this LoRA's lerp (0 -> normalized,
            # 1 -> direct).
            scale = (1.0 - lerp) * scale_norm + lerp * scale_direct

            delta_scaled = delta * scale

            # Optionally clip the scaled delta so its norm never exceeds
            # clip_ratio times the base weight's norm.
            if clip_ratio is not None:
                Wf = W.to(torch.float32)
                base_norm = Wf.norm().item()
                delta_norm = delta_scaled.norm().item()

                if delta_norm > clip_ratio * base_norm and delta_norm > 0:
                    delta_scaled *= (clip_ratio * base_norm) / delta_norm

            # Accumulate in float32, then cast back to the base dtype.
            W_new = W.to(torch.float32) + delta_scaled
            base_tensors[base_key] = W_new.to(W.dtype)

            applied += 1

        print(f"[Pass 2] {lora_path}: applied {applied}, skipped {skipped}")
        # Release this LoRA's tensors before loading the next file.
        del lora_tensors
|
|
|
|
|
|
|
|
|
def apply_loras_to_base(base_tensors, lora_specs):
    """Two-pass merge: tally per-key strength sums, then apply every LoRA."""
    apply_loras_streaming(
        base_tensors,
        lora_specs,
        compute_strength_sums(base_tensors, lora_specs),
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_vae_key(key: str) -> bool:
    """True if *key* belongs to the VAE component of the checkpoint."""
    return key.startswith((
        "first_stage_model.",
        "model.first_stage_model.",
        "vae.",
        "model.vae.",
    ))
|
|
|
|
|
|
|
|
|
def is_text_encoder_key(key: str) -> bool:
    """True if *key* belongs to the text encoder / conditioning stage."""
    return key.startswith((
        "text_encoder.",
        "model.text_encoder.",
        "cond_stage_model.",
        "model.cond_stage_model.",
    ))
|
|
|
|
|
|
|
|
|
def is_unet_key(key: str) -> bool:
    """True if *key* belongs to the diffusion (UNet) model."""
    return key.startswith(("model.diffusion_model.", "diffusion_model."))
|
|
|
|
|
|
|
|
|
def convert_to_fp8_inplace(tensors: Dict[str, torch.Tensor]):
    """Downcast UNet and text-encoder float tensors to FP8 e4m3, in place.

    VAE tensors, non-floating-point tensors, and tensors matching none of the
    known component prefixes are left untouched; a summary line is printed.
    """
    fp8_dtype = torch.float8_e4m3fn

    counts = {"converted": 0, "skipped_vae": 0, "skipped_other": 0}

    for key, value in list(tensors.items()):
        if not torch.is_floating_point(value):
            counts["skipped_other"] += 1
        elif is_vae_key(key):
            # Keep the VAE in its original precision.
            counts["skipped_vae"] += 1
        elif is_unet_key(key) or is_text_encoder_key(key):
            tensors[key] = value.to(fp8_dtype)
            counts["converted"] += 1
        else:
            counts["skipped_other"] += 1

    print(
        f"FP8 conversion: converted={counts['converted']}, "
        f"skipped_vae={counts['skipped_vae']}, skipped_other={counts['skipped_other']}"
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: parse args, merge the LoRAs, convert to FP8, save."""
    arg_parser = argparse.ArgumentParser(
        description=(
            "Apply LTX2-style LoRAs with separate video/audio strengths, "
            "strength-weighted normalization, LERP blending, per‑LoRA clipping, "
            "FP8 conversion, and metadata preservation (streaming, memory‑efficient)."
        )
    )
    arg_parser.add_argument("base", help="Base checkpoint (.safetensors)")
    arg_parser.add_argument("lora_list", help="Text file: path,video_strength,lerp[,audio_strength]")
    arg_parser.add_argument("output", help="Output FP8 checkpoint (.safetensors)")
    cli = arg_parser.parse_args()

    # Validate inputs before doing any heavy loading.
    if not os.path.isfile(cli.base):
        raise FileNotFoundError(cli.base)

    specs = parse_lora_list(cli.lora_list)
    if not specs:
        raise ValueError("No LoRAs specified.")

    print(f"Loading base checkpoint: {cli.base}")
    checkpoint, metadata = load_base_with_metadata(cli.base)
    print(f"Base checkpoint has {len(checkpoint)} tensors.")

    apply_loras_to_base(checkpoint, specs)

    print("Converting UNet + text encoder to FP8 (leaving VAE untouched)...")
    convert_to_fp8_inplace(checkpoint)

    print(f"Saving merged FP8 checkpoint to: {cli.output}")
    # Ensure the destination directory exists ('.' when output has no dir part).
    os.makedirs(os.path.dirname(cli.output) or ".", exist_ok=True)
    save_file(checkpoint, cli.output, metadata=metadata)
    print("Done.")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
main()
|
|
|
|