File size: 12,569 Bytes
1cc67f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
#!/usr/bin/env python3

# THIS IS FOR ADVANCED LORA INTO BASE MODEL MERGING
# Designed for use with the LTX2 model
#
# Make a text file with a list of LORAs you want to merge in this format for each line:
# <path to safetensors>,<strength>,<lerp>
#
# The "lerp" parameter controls how strongly this LoRA overrides tensors that earlier LoRAs in the list also touch: "0" blends all contributions together (normalized by total strength), while "1" applies this LoRA's delta at full, unnormalized strength
#
# You can also supply a separate audio and video strengths like this:
# <path to safetensors>,<video strength>,<lerp>,<audio strength>
#
# Use this script like this:
# python fancy-apply.py <base model safetensors> <lora list txt file> <merged output filename>

import argparse
import os
from typing import Dict, Tuple, List, Optional
from collections import defaultdict
import math

import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open


# ----------------- Tuning knobs ----------------- #

# If True, the normalized component of each LoRA's scale is divided by the
# total effective strength of every LoRA touching the same base key:
#   scale_norm = eff_strength / max(1.0, sum_eff_strengths_for_key)
# (See compute_strength_sums / apply_loras_streaming below.)
NORMALIZE_OVERLAPS = True

# Per‑LoRA clipping threshold:
# If not None, each LoRA's delta is clipped so that:
#   ||delta|| <= CLIP_RATIO * ||W||
# Set to None to disable clipping entirely.
CLIP_RATIO: Optional[float] = 1.0


# ----------------- Parsing LoRA list ----------------- #

def parse_lora_list(path: str) -> List[Tuple[str, float, float, float]]:
    """
    Parse a LoRA list file with one comma-separated spec per line:

      filename.safetensors,0.7,0.0
      filename2.safetensors,1.0,0.5,0.3

    Blank lines and lines starting with '#' are ignored.

    Returns a list of tuples:
      (path, video_strength, lerp_with_existing, audio_strength)

    Where:
      video_strength: base strength for video/shared weights
      audio_strength: base strength for audio weights
                      (defaults to video_strength if omitted)
      lerp_with_existing in [0, 1] (out-of-range values are clamped):
        0.0 -> fully normalized
        1.0 -> fully direct
        between -> blend between normalized and direct

    Raises:
      ValueError: if a line has fewer than three fields, or a numeric
        field cannot be parsed as a float (error includes the line number).
    """
    loras: List[Tuple[str, float, float, float]] = []
    with open(path, "r", encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line or line.startswith("#"):
                continue

            parts = [p.strip() for p in line.split(",")]
            if len(parts) < 3:
                raise ValueError(f"Invalid LoRA line (need at least file,video_strength,lerp): {line}")

            filename = parts[0]
            try:
                video_strength = float(parts[1])
                lerp = float(parts[2])
                audio_strength = float(parts[3]) if len(parts) >= 4 else video_strength
            except ValueError as err:
                # Re-raise with context so the user can find the bad line.
                raise ValueError(f"Line {lineno}: non-numeric strength/lerp field in: {line}") from err

            # Clamp lerp into its documented [0, 1] range.
            lerp = max(0.0, min(1.0, lerp))

            loras.append((filename, video_strength, lerp, audio_strength))

    return loras


# ----------------- Base loading ----------------- #

def load_base_with_metadata(path: str):
    """Load a safetensors checkpoint onto CPU along with its metadata dict.

    Returns (tensors, metadata); metadata is {} when the file carries none.
    """
    with safe_open(path, framework="pt", device="cpu") as handle:
        meta = handle.metadata()
    return load_file(path, device="cpu"), (meta or {})


# ----------------- LoRA key grouping ----------------- #

def group_lora_pairs(lora_tensors: Dict[str, torch.Tensor]):
    """
    Yield (prefix, A_key, B_key, alpha_key_or_None) for every complete
    lora_A/lora_B pair found among the tensor keys. Prefixes missing
    either half are reported with a warning and skipped.
    """
    suffix_roles = {
        ".lora_A.weight": "A",
        ".lora_B.weight": "B",
        ".alpha": "alpha",
    }

    grouped: Dict[str, Dict[str, str]] = {}
    for key in lora_tensors:
        for suffix, role in suffix_roles.items():
            if key.endswith(suffix):
                grouped.setdefault(key[: -len(suffix)], {})[role] = key
                break

    for prefix, roles in grouped.items():
        if "A" in roles and "B" in roles:
            yield prefix, roles["A"], roles["B"], roles.get("alpha")
        else:
            print(f"Warning: incomplete LoRA prefix {prefix}")


def find_base_weight_key(base_tensors, lora_prefix):
    """Map a LoRA prefix to the matching base-checkpoint key, or None.

    Tries the bare '<prefix>.weight' form first, then the 'model.'-prefixed
    variants used by some checkpoint layouts.
    """
    for candidate in (
        f"{lora_prefix}.weight",
        f"model.{lora_prefix}.weight",
        lora_prefix,
        f"model.{lora_prefix}",
    ):
        if candidate in base_tensors:
            return candidate
    return None


# ----------------- Audio / video classification ----------------- #

def classify_prefix(prefix: str) -> str:
    """Bucket a LoRA prefix into 'audio', 'video', 'cross', or 'shared'.

    Matching is case-insensitive and purely substring-based. Cross-modal
    markers win over single-modality ones; anything unrecognized falls
    through to 'shared' (treated at video strength by the caller).
    """
    lowered = prefix.lower()

    # Cross-modal projections take precedence over plain audio/video tags.
    if "audio_to_video" in lowered or "video_to_audio" in lowered:
        return "cross"

    if any(tag in lowered for tag in ("audio_attn", "audio_ff", ".audio_")):
        return "audio"

    if any(tag in lowered for tag in ("video_attn", "video_ff", ".video_")):
        return "video"

    return "shared"


def effective_strength_for_prefix(prefix: str, video_strength: float, audio_strength: float) -> float:
    """Return the strength to use for *prefix* based on its modality.

    'audio' -> audio_strength; 'video' and 'shared' -> video_strength;
    'cross' -> geometric mean of the two (negatives clamped to 0 first).
    """
    kind = classify_prefix(prefix)
    if kind == "audio":
        return audio_strength
    if kind == "cross":
        # Geometric-mean blend for cross-modal weights.
        return math.sqrt(max(video_strength, 0.0) * max(audio_strength, 0.0))
    # 'video' and 'shared' both use the video strength.
    return video_strength


# ----------------- Pass 1: strength sums per key ----------------- #

def compute_strength_sums(base_tensors, lora_specs: List[Tuple[str, float, float, float]]) -> Dict[str, float]:
    """
    Pass 1: for each base weight key, total the effective strengths of all
    LoRAs that touch it (classified as audio/video/cross/shared per prefix).

    LoRA files are loaded one at a time and released immediately to keep
    peak memory low.
    """
    strength_sum: Dict[str, float] = defaultdict(float)

    for lora_path, video_strength, lerp, audio_strength in lora_specs:
        print(f"[Pass 1] Scanning {lora_path} (video={video_strength}, audio={audio_strength}, lerp={lerp})")
        tensors = load_file(lora_path, device="cpu")

        for prefix, _a_key, _b_key, _alpha_key in group_lora_pairs(tensors):
            base_key = find_base_weight_key(base_tensors, prefix)
            if base_key is not None:
                strength_sum[base_key] += effective_strength_for_prefix(
                    prefix, video_strength, audio_strength
                )

        # Drop the LoRA tensors before loading the next file.
        del tensors

    print(f"[Pass 1] Keys with strength contributions: {len(strength_sum)}")
    return strength_sum


# ----------------- Pass 2: streaming application ----------------- #

def apply_loras_streaming(

    base_tensors,

    lora_specs: List[Tuple[str, float, float, float]],

    strength_sum: Dict[str, float],

    clip_ratio: Optional[float] = CLIP_RATIO,

):
    """
    Pass 2: stream each LoRA file from disk and merge its deltas into
    ``base_tensors`` in place.

    For every complete lora_A/lora_B pair whose prefix maps onto a base
    key, the delta B @ A is scaled by a LERP blend of:
      - a normalized scale (divided by the total strength touching that
        key, from ``strength_sum``), and
      - the direct, unnormalized scale,
    mixed by the per-LoRA ``lerp`` value. The scaled delta is then
    optionally clipped so ||delta|| <= clip_ratio * ||W|| before being
    added to the base weight.

    Raises:
      ValueError: when a computed delta's shape does not match the base
        weight it targets.
    """
    for lora_path, video_strength, lerp, audio_strength in lora_specs:
        print(f"[Pass 2] Applying {lora_path} (video={video_strength}, audio={audio_strength}, lerp={lerp})")
        lora_tensors = load_file(lora_path, device="cpu")

        applied = 0
        skipped = 0

        for prefix, A_key, B_key, alpha_key in group_lora_pairs(lora_tensors):
            base_key = find_base_weight_key(base_tensors, prefix)
            if base_key is None:
                skipped += 1
                continue

            W = base_tensors[base_key]

            # Compute the low-rank delta in float32 for numeric headroom.
            A = lora_tensors[A_key].to(torch.float32)
            B = lora_tensors[B_key].to(torch.float32)
            delta = B @ A

            if delta.shape != W.shape:
                raise ValueError(
                    f"Shape mismatch for {prefix}: delta {delta.shape} vs base {W.shape}"
                )

            # Assumes lora_A is laid out (rank, in_features) — TODO confirm
            # against the LoRA files actually in use.
            rank = A.shape[0] if A.dim() == 2 else A.numel()

            # Effective strength for this prefix (audio/video/cross/shared)
            eff_strength = effective_strength_for_prefix(prefix, video_strength, audio_strength)

            # Base strength + alpha scaling (standard LoRA alpha/rank factor
            # when an alpha tensor is present)
            if alpha_key is not None:
                alpha = float(lora_tensors[alpha_key].to(torch.float32).item())
                base_scale = eff_strength * alpha / max(rank, 1)
            else:
                base_scale = eff_strength

            # Weighted normalization: divide by the total strength of all
            # LoRAs competing for this key (floored at 1.0 so lone LoRAs
            # are not amplified).
            if NORMALIZE_OVERLAPS:
                total_strength = strength_sum.get(base_key, 0.0)
                denom = max(1.0, total_strength)
                scale_norm = base_scale / denom
            else:
                scale_norm = base_scale

            # Direct (unnormalized) component
            scale_direct = base_scale

            # LERP between normalized and direct
            scale = (1.0 - lerp) * scale_norm + lerp * scale_direct

            delta_scaled = delta * scale

            # Per‑LoRA clipping: shrink the delta so its Frobenius norm does
            # not exceed clip_ratio times the base weight's norm.
            if clip_ratio is not None:
                Wf = W.to(torch.float32)
                base_norm = Wf.norm().item()
                delta_norm = delta_scaled.norm().item()

                if delta_norm > clip_ratio * base_norm and delta_norm > 0:
                    delta_scaled *= (clip_ratio * base_norm) / delta_norm

            # Apply update in float32, then cast back to the base dtype.
            W_new = W.to(torch.float32) + delta_scaled
            base_tensors[base_key] = W_new.to(W.dtype)

            applied += 1

        print(f"[Pass 2] {lora_path}: applied {applied}, skipped {skipped}")
        del lora_tensors


def apply_loras_to_base(base_tensors, lora_specs):
    """Two-pass merge: tally per-key strengths, then stream-apply each LoRA."""
    per_key_strength = compute_strength_sums(base_tensors, lora_specs)
    apply_loras_streaming(base_tensors, lora_specs, per_key_strength)


# ----------------- FP8 conversion ----------------- #

def is_vae_key(key: str) -> bool:
    """True if *key* belongs to the VAE (kept out of FP8 conversion)."""
    return key.startswith((
        "first_stage_model.",
        "model.first_stage_model.",
        "vae.",
        "model.vae.",
    ))


def is_text_encoder_key(key: str) -> bool:
    """True if *key* belongs to the text encoder (eligible for FP8)."""
    return key.startswith((
        "text_encoder.",
        "model.text_encoder.",
        "cond_stage_model.",
        "model.cond_stage_model.",
    ))


def is_unet_key(key: str) -> bool:
    """True if *key* belongs to the diffusion model (eligible for FP8)."""
    return key.startswith((
        "model.diffusion_model.",
        "diffusion_model.",
    ))


def convert_to_fp8_inplace(tensors: Dict[str, torch.Tensor]):
    """Cast UNet and text-encoder floating tensors to float8_e4m3fn in place.

    VAE tensors and non-floating tensors are left untouched; a summary of
    converted/skipped counts is printed at the end.
    """
    target_dtype = torch.float8_e4m3fn

    converted = 0
    skipped_vae = 0
    skipped_other = 0

    # Iterate over a snapshot of items since we reassign entries in place.
    for key, value in list(tensors.items()):
        if not torch.is_floating_point(value):
            skipped_other += 1
        elif is_vae_key(key):
            skipped_vae += 1
        elif is_unet_key(key) or is_text_encoder_key(key):
            tensors[key] = value.to(target_dtype)
            converted += 1
        else:
            skipped_other += 1

    print(
        f"FP8 conversion: converted={converted}, "
        f"skipped_vae={skipped_vae}, skipped_other={skipped_other}"
    )


# ----------------- Main CLI ----------------- #

def main():
    """CLI entry point: merge the listed LoRAs into a base checkpoint and save as FP8."""
    parser = argparse.ArgumentParser(
        description=(
            "Apply LTX2-style LoRAs with separate video/audio strengths, "
            "strength-weighted normalization, LERP blending, per‑LoRA clipping, "
            "FP8 conversion, and metadata preservation (streaming, memory‑efficient)."
        )
    )
    parser.add_argument("base", help="Base checkpoint (.safetensors)")
    parser.add_argument("lora_list", help="Text file: path,video_strength,lerp[,audio_strength]")
    parser.add_argument("output", help="Output FP8 checkpoint (.safetensors)")
    args = parser.parse_args()

    # Fail fast on missing or empty inputs before touching the big checkpoint.
    if not os.path.isfile(args.base):
        raise FileNotFoundError(args.base)

    specs = parse_lora_list(args.lora_list)
    if not specs:
        raise ValueError("No LoRAs specified.")

    print(f"Loading base checkpoint: {args.base}")
    tensors, metadata = load_base_with_metadata(args.base)
    print(f"Base checkpoint has {len(tensors)} tensors.")

    apply_loras_to_base(tensors, specs)

    print("Converting UNet + text encoder to FP8 (leaving VAE untouched)...")
    convert_to_fp8_inplace(tensors)

    print(f"Saving merged FP8 checkpoint to: {args.output}")
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    save_file(tensors, args.output, metadata=metadata)
    print("Done.")


if __name__ == "__main__":
    main()