|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import argparse
|
| import os
|
| from typing import Dict, Tuple, List, Optional
|
| from collections import defaultdict
|
| import math
|
|
|
| import torch
|
| from safetensors.torch import load_file, save_file
|
| from safetensors import safe_open
|
|
|
|
|
|
|
|
|
|
|
|
|
# When True, weights touched by multiple LoRAs are divided by the summed
# effective strengths of every LoRA hitting that weight (computed in pass 1),
# so overlapping LoRAs blend instead of stacking additively.
NORMALIZE_OVERLAPS = True


# Per-pair safety clip: a scaled LoRA delta's Frobenius norm is capped at
# CLIP_RATIO times the base weight's norm. Set to None to disable clipping.
CLIP_RATIO: Optional[float] = 1.0
|
|
|
|
|
|
|
|
|
def parse_lora_list(path: str) -> List[Tuple[str, float, float, float]]:
    """Read a LoRA spec file into (path, video_strength, lerp, audio_strength) tuples.

    Each non-empty, non-comment line has the form:

        filename.safetensors,<video_strength>,<lerp>[,<audio_strength>]

    audio_strength defaults to video_strength when omitted. lerp is clamped
    into [0, 1]: 0.0 means fully normalized application, 1.0 means fully
    direct, values between blend the two.

    Raises:
        ValueError: if a line has fewer than three comma-separated fields.
    """
    specs: List[Tuple[str, float, float, float]] = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            entry = raw.strip()
            # Skip blank lines and '#' comment lines.
            if not entry or entry.startswith("#"):
                continue

            fields = [field.strip() for field in entry.split(",")]
            if len(fields) < 3:
                raise ValueError(f"Invalid LoRA line (need at least file,video_strength,lerp): {entry}")

            name = fields[0]
            vid = float(fields[1])
            raw_lerp = float(fields[2])
            aud = float(fields[3]) if len(fields) >= 4 else vid

            # Clamp the blend factor into the valid [0, 1] range.
            clamped_lerp = min(1.0, max(0.0, raw_lerp))
            specs.append((name, vid, clamped_lerp, aud))

    return specs
|
|
|
|
|
|
|
|
|
def load_base_with_metadata(path: str):
    """Load a safetensors checkpoint onto CPU and return (tensors, metadata).

    The metadata dict is read via safe_open (empty dict when absent) so it can
    be written back verbatim when saving the merged checkpoint.
    """
    with safe_open(path, framework="pt", device="cpu") as reader:
        meta = reader.metadata() or {}
        weights = load_file(path, device="cpu")
    return weights, meta
|
|
|
|
|
|
|
|
|
def group_lora_pairs(lora_tensors: Dict[str, torch.Tensor]):
    """Yield complete LoRA pairs found in a LoRA state dict.

    Groups keys by their shared prefix and yields
    ``(prefix, A_key, B_key, alpha_key_or_None)`` for every prefix that has
    both an A (down) and a B (up) matrix. Prefixes missing either half are
    reported with a warning and skipped.

    Generalization over the original: besides the PEFT-style
    ``.lora_A.weight`` / ``.lora_B.weight`` names, the kohya-style aliases
    ``.lora_down.weight`` / ``.lora_up.weight`` are also recognized, mapping
    to A and B respectively. Existing inputs group exactly as before.
    """
    # suffix -> role table; first match wins, checked in order.
    suffix_roles = (
        (".lora_A.weight", "A"),
        (".lora_down.weight", "A"),   # kohya alias for the down/A matrix
        (".lora_B.weight", "B"),
        (".lora_up.weight", "B"),     # kohya alias for the up/B matrix
        (".alpha", "alpha"),
    )

    pairs: Dict[str, Dict[str, str]] = {}
    for key in lora_tensors.keys():
        for suffix, role in suffix_roles:
            if key.endswith(suffix):
                prefix = key[: -len(suffix)]
                pairs.setdefault(prefix, {})[role] = key
                break

    for prefix, roles in pairs.items():
        if "A" not in roles or "B" not in roles:
            print(f"Warning: incomplete LoRA prefix {prefix}")
            continue
        yield prefix, roles["A"], roles["B"], roles.get("alpha")
|
|
|
|
|
def find_base_weight_key(base_tensors, lora_prefix):
    """Map a LoRA prefix to its key in the base state dict, or None.

    Tries the most specific candidates first ('.weight' suffixed, with and
    without a leading 'model.' namespace), then the bare prefix forms.
    """
    for candidate in (
        f"{lora_prefix}.weight",
        f"model.{lora_prefix}.weight",
        lora_prefix,
        f"model.{lora_prefix}",
    ):
        if candidate in base_tensors:
            return candidate
    return None
|
|
|
|
|
|
|
|
|
def classify_prefix(prefix: str) -> str:
    """Classify a LoRA prefix as 'audio', 'video', 'cross', or 'shared'.

    Matching is case-insensitive and substring based. The cross-modal check
    runs first on purpose: names like 'audio_to_video_attn' also contain
    'video_attn' and must not be misclassified as plain video.
    """
    lowered = prefix.lower()

    if any(tag in lowered for tag in ("audio_to_video", "video_to_audio")):
        return "cross"
    if any(tag in lowered for tag in ("audio_attn", "audio_ff", ".audio_")):
        return "audio"
    if any(tag in lowered for tag in ("video_attn", "video_ff", ".video_")):
        return "video"
    return "shared"
|
|
|
|
|
def effective_strength_for_prefix(
    prefix: str,
    video_strength: float,
    audio_strength: float,
) -> float:
    """Return the merge strength for one prefix based on its modality.

    'audio' prefixes use audio_strength, 'video' uses video_strength,
    'shared' falls back to video_strength, and cross-modal prefixes use the
    geometric mean of the two (negative strengths clamped to zero so the
    sqrt is always defined).
    """
    kind = classify_prefix(prefix)
    if kind == "cross":
        return math.sqrt(max(video_strength, 0.0) * max(audio_strength, 0.0))
    return audio_strength if kind == "audio" else video_strength
|
|
|
|
|
|
|
|
|
def compute_strength_sums(
    base_tensors,
    lora_specs: List[Tuple[str, float, float, float]],
) -> Dict[str, float]:
    """Pass 1 of the merge: total effective strength per base weight key.

    Streams through every LoRA file once and accumulates, for each base key
    it touches, the sum of that LoRA's effective strength (per the
    video/audio/cross classification). The result is the normalization
    denominator used by pass 2 when NORMALIZE_OVERLAPS is enabled.
    """
    totals: Dict[str, float] = defaultdict(float)

    for lora_path, video_strength, lerp, audio_strength in lora_specs:
        print(f"[Pass 1] Scanning {lora_path} (video={video_strength}, audio={audio_strength}, lerp={lerp})")
        tensors = load_file(lora_path, device="cpu")

        for prefix, _a_key, _b_key, _alpha_key in group_lora_pairs(tensors):
            target = find_base_weight_key(base_tensors, prefix)
            if target is not None:
                totals[target] += effective_strength_for_prefix(
                    prefix, video_strength, audio_strength
                )

        # Free this LoRA's tensors before loading the next file.
        del tensors

    print(f"[Pass 1] Keys with strength contributions: {len(totals)}")
    return totals
|
|
|
|
|
|
|
|
|
def apply_loras_streaming(
    base_tensors,
    lora_specs: List[Tuple[str, float, float, float]],
    strength_sum: Dict[str, float],
    clip_ratio: Optional[float] = CLIP_RATIO,
):
    """
    Pass 2 of the merge: apply every LoRA delta into the base weights.

    Loads one LoRA file at a time (streaming, to bound peak memory), forms
    delta = B @ A in float32 for each complete pair, scales it by the blended
    normalized/direct strength, optionally clips its norm relative to the
    base weight, and adds it into base_tensors in place (cast back to the
    base weight's original dtype).

    Args:
        base_tensors: base state dict; mutated in place.
        lora_specs: (path, video_strength, lerp, audio_strength) tuples.
        strength_sum: per-base-key summed effective strengths from pass 1,
            used as the normalization denominator when NORMALIZE_OVERLAPS.
        clip_ratio: cap on ||delta|| / ||W|| per pair; None disables clipping.

    Raises:
        ValueError: if a computed delta's shape does not match its base weight.
    """
    for lora_path, video_strength, lerp, audio_strength in lora_specs:
        print(f"[Pass 2] Applying {lora_path} (video={video_strength}, audio={audio_strength}, lerp={lerp})")
        lora_tensors = load_file(lora_path, device="cpu")

        applied = 0
        skipped = 0

        for prefix, A_key, B_key, alpha_key in group_lora_pairs(lora_tensors):
            base_key = find_base_weight_key(base_tensors, prefix)
            if base_key is None:
                # No matching base weight for this prefix; count and move on.
                skipped += 1
                continue

            W = base_tensors[base_key]

            # Compute the low-rank update in float32 for numerical stability.
            A = lora_tensors[A_key].to(torch.float32)
            B = lora_tensors[B_key].to(torch.float32)
            delta = B @ A

            if delta.shape != W.shape:
                raise ValueError(
                    f"Shape mismatch for {prefix}: delta {delta.shape} vs base {W.shape}"
                )

            # NOTE(review): assumes lora_A is laid out (rank, in_features) so
            # shape[0] is the rank — confirm for non-2D A tensors, where
            # numel() is used as a fallback.
            rank = A.shape[0] if A.dim() == 2 else A.numel()

            eff_strength = effective_strength_for_prefix(prefix, video_strength, audio_strength)

            # Standard LoRA alpha scaling: strength * alpha / rank. Without an
            # alpha tensor the raw strength is used unscaled.
            if alpha_key is not None:
                alpha = float(lora_tensors[alpha_key].to(torch.float32).item())
                base_scale = eff_strength * alpha / max(rank, 1)
            else:
                base_scale = eff_strength

            # Normalized scale: divide by the total strength of all LoRAs
            # touching this key (from pass 1), never boosting (denom >= 1).
            if NORMALIZE_OVERLAPS:
                total_strength = strength_sum.get(base_key, 0.0)
                denom = max(1.0, total_strength)
                scale_norm = base_scale / denom
            else:
                scale_norm = base_scale

            # Direct (un-normalized) scale.
            scale_direct = base_scale

            # lerp blends the two: 0 -> fully normalized, 1 -> fully direct.
            scale = (1.0 - lerp) * scale_norm + lerp * scale_direct

            delta_scaled = delta * scale

            # Optional safety clip: rescale the delta so its Frobenius norm is
            # at most clip_ratio times the base weight's norm.
            if clip_ratio is not None:
                Wf = W.to(torch.float32)
                base_norm = Wf.norm().item()
                delta_norm = delta_scaled.norm().item()

                if delta_norm > clip_ratio * base_norm and delta_norm > 0:
                    delta_scaled *= (clip_ratio * base_norm) / delta_norm

            # Accumulate in float32, then cast back to the original dtype.
            W_new = W.to(torch.float32) + delta_scaled
            base_tensors[base_key] = W_new.to(W.dtype)

            applied += 1

        print(f"[Pass 2] {lora_path}: applied {applied}, skipped {skipped}")
        # Free this LoRA's tensors before loading the next file.
        del lora_tensors
|
|
|
|
|
def apply_loras_to_base(base_tensors, lora_specs):
    """Run the two-pass merge: strength accounting, then weighted application."""
    totals = compute_strength_sums(base_tensors, lora_specs)
    apply_loras_streaming(base_tensors, lora_specs, totals)
|
|
|
|
|
|
|
|
|
def is_vae_key(key: str) -> bool:
    """Return True if the key belongs to the VAE (kept out of FP8 conversion)."""
    vae_prefixes = (
        "first_stage_model.",
        "model.first_stage_model.",
        "vae.",
        "model.vae.",
    )
    # startswith accepts a tuple of prefixes, matching any of them.
    return key.startswith(vae_prefixes)
|
|
|
|
|
def is_text_encoder_key(key: str) -> bool:
    """Return True if the key belongs to the text encoder / conditioning model."""
    encoder_prefixes = (
        "text_encoder.",
        "model.text_encoder.",
        "cond_stage_model.",
        "model.cond_stage_model.",
    )
    # startswith accepts a tuple of prefixes, matching any of them.
    return key.startswith(encoder_prefixes)
|
|
|
|
|
def is_unet_key(key: str) -> bool:
    """Return True if the key belongs to the diffusion (UNet/DiT) model."""
    unet_prefixes = (
        "model.diffusion_model.",
        "diffusion_model.",
    )
    # startswith accepts a tuple of prefixes, matching any of them.
    return key.startswith(unet_prefixes)
|
|
|
|
|
def convert_to_fp8_inplace(tensors: Dict[str, torch.Tensor]):
    """Cast UNet and text-encoder floating tensors to FP8 (e4m3fn) in place.

    VAE tensors are explicitly skipped, as are non-floating-point tensors and
    any key not recognized as UNet or text encoder. Prints a summary of how
    many tensors were converted vs skipped.
    """
    target_dtype = torch.float8_e4m3fn

    converted = 0
    skipped_vae = 0
    skipped_other = 0

    # Snapshot items() since values are reassigned during iteration.
    for key, value in list(tensors.items()):
        if not torch.is_floating_point(value):
            skipped_other += 1
        elif is_vae_key(key):
            skipped_vae += 1
        elif is_unet_key(key) or is_text_encoder_key(key):
            tensors[key] = value.to(target_dtype)
            converted += 1
        else:
            skipped_other += 1

    print(
        f"FP8 conversion: converted={converted}, "
        f"skipped_vae={skipped_vae}, skipped_other={skipped_other}"
    )
|
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: merge the listed LoRAs into a base checkpoint, convert
    to FP8, and save with the original metadata preserved."""
    parser = argparse.ArgumentParser(
        description=(
            "Apply LTX2-style LoRAs with separate video/audio strengths, "
            "strength-weighted normalization, LERP blending, per‑LoRA clipping, "
            "FP8 conversion, and metadata preservation (streaming, memory‑efficient)."
        )
    )
    parser.add_argument("base", help="Base checkpoint (.safetensors)")
    parser.add_argument("lora_list", help="Text file: path,video_strength,lerp[,audio_strength]")
    parser.add_argument("output", help="Output FP8 checkpoint (.safetensors)")
    args = parser.parse_args()

    # Fail early on a missing base checkpoint.
    if not os.path.isfile(args.base):
        raise FileNotFoundError(args.base)

    specs = parse_lora_list(args.lora_list)
    if not specs:
        raise ValueError("No LoRAs specified.")

    print(f"Loading base checkpoint: {args.base}")
    state_dict, metadata = load_base_with_metadata(args.base)
    print(f"Base checkpoint has {len(state_dict)} tensors.")

    apply_loras_to_base(state_dict, specs)

    print("Converting UNet + text encoder to FP8 (leaving VAE untouched)...")
    convert_to_fp8_inplace(state_dict)

    print(f"Saving merged FP8 checkpoint to: {args.output}")
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    save_file(state_dict, args.output, metadata=metadata)
    print("Done.")
|
|
|
|
|
| if __name__ == "__main__":
|
| main()
|
|
|