alex committed on
Commit
7357e70
·
1 Parent(s): 9cde9cc

LoRA offloading

Browse files
app.py CHANGED
@@ -30,9 +30,7 @@ from ltx_pipelines.utils.constants import (
30
  DEFAULT_FRAME_RATE,
31
  DEFAULT_LORA_STRENGTH,
32
  )
33
- from ltx_core.loader.single_gpu_model_builder import set_lora_enabled
34
-
35
-
36
 
37
  MAX_SEED = np.iinfo(np.int32).max
38
  # Import from public LTX-2 package
@@ -195,6 +193,15 @@ distilled_lora_path = get_hub_or_local_checkpoint(
195
  DEFAULT_DISTILLED_LORA_FILENAME,
196
  )
197
 
 
 
 
 
 
 
 
 
 
198
  dolly_in_lora_path = get_hub_or_local_checkpoint(
199
  "Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In",
200
  "ltx-2-19b-lora-camera-control-dolly-in.safetensors",
@@ -203,6 +210,24 @@ dolly_out_lora_path = get_hub_or_local_checkpoint(
203
  "Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out",
204
  "ltx-2-19b-lora-camera-control-dolly-out.safetensors",
205
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
 
208
  # Load distilled LoRA as a regular LoRA
@@ -210,22 +235,30 @@ loras = [
210
  # --- fused / base behavior ---
211
  LoraPathStrengthAndSDOps(
212
  path=distilled_lora_path,
213
- strength=DEFAULT_LORA_STRENGTH,
214
  sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
215
  ),
216
- # # --- runtime-toggle camera controls ---#
217
  LoraPathStrengthAndSDOps(dolly_in_lora_path, DEFAULT_LORA_STRENGTH, LTXV_LORA_COMFY_RENAMING_MAP),
218
  LoraPathStrengthAndSDOps(dolly_out_lora_path, DEFAULT_LORA_STRENGTH, LTXV_LORA_COMFY_RENAMING_MAP),
 
 
 
 
219
  ]
220
 
221
  # Runtime-toggle LoRAs (exclude fused distilled at index 0)
222
  RUNTIME_LORA_CHOICES = [
223
  ("No LoRA", -1),
224
- ("Dolly In", 0),
225
- ("Dolly Out", 1),
 
 
 
 
 
226
  ]
227
 
228
-
229
  # Initialize pipeline WITHOUT text encoder (gemma_root=None)
230
  # Text encoding will be done by external space
231
  pipeline = DistilledPipeline(
@@ -240,15 +273,11 @@ pipeline = DistilledPipeline(
240
 
241
  pipeline._video_encoder = pipeline.model_ledger.video_encoder()
242
  pipeline._transformer = pipeline.model_ledger.transformer()
243
- # pipeline.device = torch.device("cuda")
244
- # pipeline.model_ledger.device = torch.device("cuda")
245
-
246
 
247
  print("=" * 80)
248
  print("Pipeline fully loaded and ready!")
249
  print("=" * 80)
250
 
251
-
252
  class RadioAnimated(gr.HTML):
253
  """
254
  Animated segmented radio (like iOS pill selector).
@@ -541,7 +570,7 @@ class CameraDropdown(gr.HTML):
541
  )
542
 
543
 
544
- def generate_video_example(input_image, prompt, duration, progress=gr.Progress(track_tqdm=True)):
545
 
546
  output_video, seed = generate_video(
547
  input_image,
@@ -552,7 +581,7 @@ def generate_video_example(input_image, prompt, duration, progress=gr.Progress(t
552
  True, # randomize_seed
553
  DEFAULT_1_STAGE_HEIGHT, # height
554
  DEFAULT_1_STAGE_WIDTH, # width
555
- "No LoRA",
556
  progress
557
  )
558
 
@@ -614,61 +643,55 @@ def generate_video(
614
  - GPU cache is cleared after generation to reduce VRAM pressure.
615
  - If an input image is provided, it is temporarily saved to disk for processing.
616
  """
 
 
 
 
 
617
  print(f'generating with duration:{duration} and LoRA:{camera_lora} in {width}x{height}')
618
- try:
619
- # Randomize seed if checkbox is enabled
620
- current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
621
 
622
- # Calculate num_frames from duration (using fixed 24 fps)
623
- frame_rate = 24.0
624
- num_frames = int(duration * frame_rate) + 1 # +1 to ensure we meet the duration
625
 
626
- with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
627
- output_path = tmpfile.name
 
628
 
629
- # Handle image input
630
- images = []
631
- temp_image_path = None # Initialize to None
632
-
633
- images = []
634
- if input_image is not None:
635
- images = [(input_image, 0, 1.0)] # input_image is already a path
636
-
637
- # Prepare image for upload if it exists
638
- image_input = None
639
 
640
 
641
- embeddings, final_prompt, status = encode_prompt(
642
- prompt=prompt,
643
- enhance_prompt=enhance_prompt,
644
- input_image=input_image,
645
- seed=current_seed,
646
- negative_prompt="",
647
- )
648
-
649
- video_context = embeddings["video_context"].to("cuda", non_blocking=True)
650
- audio_context = embeddings["audio_context"].to("cuda", non_blocking=True)
651
- print("✓ Embeddings loaded successfully")
652
 
653
- # free prompt enhancer / encoder temps ASAP
654
- del embeddings, final_prompt, status
655
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
656
 
 
 
 
657
 
658
- # Map dropdown name -> adapter index
659
- name_to_idx = {name: idx for name, idx in RUNTIME_LORA_CHOICES}
660
- selected_idx = name_to_idx.get(camera_lora, -1)
661
 
662
- # Disable all runtime adapters first (0..N-1)
663
- # N here is len(RUNTIME_LORA_CHOICES)-1 because "None" isn't an adapter
664
- for i in range(len(RUNTIME_LORA_CHOICES) - 1):
665
- set_lora_enabled(pipeline._transformer, i, False)
666
 
667
- # Enable selected one (if any)
668
- if selected_idx >= 0:
669
- set_lora_enabled(pipeline._transformer, selected_idx, True)
670
 
671
- # Run inference - progress automatically tracks tqdm from pipeline
 
672
  pipeline(
673
  prompt=prompt,
674
  output_path=str(output_path),
@@ -682,17 +705,12 @@ def generate_video(
682
  video_context=video_context,
683
  audio_context=audio_context,
684
  )
685
- del video_context, audio_context
686
- torch.cuda.empty_cache()
687
- print("successful generation")
688
 
689
- return str(output_path), current_seed
690
 
691
- except Exception as e:
692
- import traceback
693
- error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
694
- print(error_msg)
695
- return None, current_seed
696
 
697
 
698
  def apply_resolution(resolution: str):
@@ -1209,24 +1227,28 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
1209
  examples=[
1210
  [
1211
  "supergirl.png",
1212
- "A fuzzy puppet superhero character resembling a female puppet with blonde hair and a blue superhero suit stands inside an icy cave made of frozen walls and icicles, she looks panicked and frantic, rapidly turning her head left and right and scanning the cave while waving her arms and shouting angrily and desperately, mouthing the words “where the hell is my dog,” her movements exaggerated and puppet-like with high energy and urgency, suddenly a second puppet dog bursts into frame from the side, jumping up excitedly and tackling her affectionately while licking her face repeatedly, she freezes in surprise and then breaks into relief and laughter as the dog continues licking her, the scene feels chaotic, comedic, and emotional with expressive puppet reactions, cinematic lighting, smooth camera motion, shallow depth of field, and high-quality puppet-style animation"
 
1213
  ],
1214
  [
1215
  "highland.png",
1216
  "Realistic POV selfie-style video in a snowy, foggy field. Two shaggy Highland cows with long curved horns stand ahead. The camera is handheld and slightly shaky. The woman filming talks nervously and excitedly in a vlog tone: \"Oh my god guys… look how big those horns are… I’m kinda scared.\" The cow on the left walks toward the camera in a cute, bouncy, hopping way, curious and gentle. Snow crunches under its hooves, breath visible in the cold air. The horns look massive from the POV. As the cow gets very close, its wet nose with slight dripping fills part of the frame. She laughs nervously but reaches out and pets the cow. The cow makes deep, soft, interesting mooing and snorting sounds, calm and friendly. Ultra-realistic, natural lighting, immersive audio, documentary-style realism.",
 
1217
  ],
1218
  [
1219
  "wednesday.png",
1220
  "A cinematic close-up of Wednesday Addams frozen mid-dance on a dark, blue-lit ballroom floor as students move indistinctly behind her, their footsteps and muffled music reduced to a distant, underwater thrum; the audio foregrounds her steady breathing and the faint rustle of fabric as she slowly raises one arm, never breaking eye contact with the camera, then after a deliberately long silence she speaks in a flat, dry, perfectly controlled voice, “I don’t dance… I vibe code,” each word crisp and unemotional, followed by an abrupt cutoff of her voice as the background sound swells slightly, reinforcing the deadpan humor, with precise lip sync, minimal facial movement, stark gothic lighting, and cinematic realism.",
 
1221
  ],
1222
  [
1223
  "astronaut.png",
1224
  "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot.",
 
1225
  ]
1226
 
1227
  ],
1228
  fn=generate_video_example,
1229
- inputs=[input_image, prompt_ui],
1230
  outputs = [output_video],
1231
  label="Example",
1232
  cache_examples=True,
 
30
  DEFAULT_FRAME_RATE,
31
  DEFAULT_LORA_STRENGTH,
32
  )
33
+ from ltx_core.loader.single_gpu_model_builder import enable_only_lora
 
 
34
 
35
  MAX_SEED = np.iinfo(np.int32).max
36
  # Import from public LTX-2 package
 
193
  DEFAULT_DISTILLED_LORA_FILENAME,
194
  )
195
 
196
+ distilled_lora_path = get_hub_or_local_checkpoint(
197
+ DEFAULT_REPO_ID,
198
+ DEFAULT_DISTILLED_LORA_FILENAME,
199
+ )
200
+
201
+ static_lora_path = get_hub_or_local_checkpoint(
202
+ "Lightricks/LTX-2-19b-LoRA-Camera-Control-Static",
203
+ "ltx-2-19b-lora-camera-control-static.safetensors",
204
+ )
205
  dolly_in_lora_path = get_hub_or_local_checkpoint(
206
  "Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In",
207
  "ltx-2-19b-lora-camera-control-dolly-in.safetensors",
 
210
  "Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out",
211
  "ltx-2-19b-lora-camera-control-dolly-out.safetensors",
212
  )
213
+ dolly_left_lora_path = get_hub_or_local_checkpoint(
214
+ "Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left",
215
+ "ltx-2-19b-lora-camera-control-dolly-left.safetensors",
216
+ )
217
+ dolly_right_lora_path = get_hub_or_local_checkpoint(
218
+ "Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right",
219
+ "ltx-2-19b-lora-camera-control-dolly-right.safetensors",
220
+ )
221
+ jib_down_lora_path = get_hub_or_local_checkpoint(
222
+ "Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down",
223
+ "ltx-2-19b-lora-camera-control-jib-down.safetensors",
224
+ )
225
+ jib_up_lora_path = get_hub_or_local_checkpoint(
226
+ "Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up",
227
+ "ltx-2-19b-lora-camera-control-jib-up.safetensors",
228
+ )
229
+
230
+
231
 
232
 
233
  # Load distilled LoRA as a regular LoRA
 
235
  # --- fused / base behavior ---
236
  LoraPathStrengthAndSDOps(
237
  path=distilled_lora_path,
238
+ strength=0.6,
239
  sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
240
  ),
241
+ LoraPathStrengthAndSDOps(static_lora_path, DEFAULT_LORA_STRENGTH, LTXV_LORA_COMFY_RENAMING_MAP),
242
  LoraPathStrengthAndSDOps(dolly_in_lora_path, DEFAULT_LORA_STRENGTH, LTXV_LORA_COMFY_RENAMING_MAP),
243
  LoraPathStrengthAndSDOps(dolly_out_lora_path, DEFAULT_LORA_STRENGTH, LTXV_LORA_COMFY_RENAMING_MAP),
244
+ LoraPathStrengthAndSDOps(dolly_left_lora_path, DEFAULT_LORA_STRENGTH, LTXV_LORA_COMFY_RENAMING_MAP),
245
+ LoraPathStrengthAndSDOps(dolly_right_lora_path, DEFAULT_LORA_STRENGTH, LTXV_LORA_COMFY_RENAMING_MAP),
246
+ LoraPathStrengthAndSDOps(jib_down_lora_path, DEFAULT_LORA_STRENGTH, LTXV_LORA_COMFY_RENAMING_MAP),
247
+ LoraPathStrengthAndSDOps(jib_up_lora_path, DEFAULT_LORA_STRENGTH, LTXV_LORA_COMFY_RENAMING_MAP),
248
  ]
249
 
250
  # Runtime-toggle LoRAs (exclude fused distilled at index 0)
251
  RUNTIME_LORA_CHOICES = [
252
  ("No LoRA", -1),
253
+ ("Static", 0),
254
+ ("Dolly In", 1),
255
+ ("Dolly Out", 2),
256
+ ("Dolly Left", 3),
257
+ ("Dolly Right", 4),
258
+ ("Jib Down", 5),
259
+ ("Jib Up", 6),
260
  ]
261
 
 
262
  # Initialize pipeline WITHOUT text encoder (gemma_root=None)
263
  # Text encoding will be done by external space
264
  pipeline = DistilledPipeline(
 
273
 
274
  pipeline._video_encoder = pipeline.model_ledger.video_encoder()
275
  pipeline._transformer = pipeline.model_ledger.transformer()
 
 
 
276
 
277
  print("=" * 80)
278
  print("Pipeline fully loaded and ready!")
279
  print("=" * 80)
280
 
 
281
  class RadioAnimated(gr.HTML):
282
  """
283
  Animated segmented radio (like iOS pill selector).
 
570
  )
571
 
572
 
573
+ def generate_video_example(input_image, prompt, camera_lora, progress=gr.Progress(track_tqdm=True)):
574
 
575
  output_video, seed = generate_video(
576
  input_image,
 
581
  True, # randomize_seed
582
  DEFAULT_1_STAGE_HEIGHT, # height
583
  DEFAULT_1_STAGE_WIDTH, # width
584
+ camera_lora,
585
  progress
586
  )
587
 
 
643
  - GPU cache is cleared after generation to reduce VRAM pressure.
644
  - If an input image is provided, it is temporarily saved to disk for processing.
645
  """
646
+
647
+ if camera_lora != "No LoRA" and duration == 15:
648
+ gr.Info("15s not available when a LoRA is activated, reducing to 10s for this generation")
649
+ duration = 10
650
+
651
  print(f'generating with duration:{duration} and LoRA:{camera_lora} in {width}x{height}')
 
 
 
652
 
653
+ # Randomize seed if checkbox is enabled
654
+ current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
 
655
 
656
+ # Calculate num_frames from duration (using fixed 24 fps)
657
+ frame_rate = 24.0
658
+ num_frames = int(duration * frame_rate) + 1 # +1 to ensure we meet the duration
659
 
660
+ with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
661
+ output_path = tmpfile.name
 
 
 
 
 
 
 
 
662
 
663
 
664
+ images = []
665
+
666
+ if input_image is not None:
667
+ images = [(input_image, 0, 1.0)]
 
 
 
 
 
 
 
668
 
669
+ embeddings, final_prompt, status = encode_prompt(
670
+ prompt=prompt,
671
+ enhance_prompt=enhance_prompt,
672
+ input_image=input_image,
673
+ seed=current_seed,
674
+ negative_prompt="",
675
+ )
676
+
677
+ video_context = embeddings["video_context"].to("cuda", non_blocking=True)
678
+ audio_context = embeddings["audio_context"].to("cuda", non_blocking=True)
679
+ print("✓ Embeddings loaded successfully")
680
 
681
+ # free prompt enhancer / encoder temps ASAP
682
+ del embeddings, final_prompt, status
683
+ torch.cuda.empty_cache()
684
 
 
 
 
685
 
686
+ # Map dropdown name -> adapter index
687
+ name_to_idx = {name: idx for name, idx in RUNTIME_LORA_CHOICES}
688
+ selected_idx = name_to_idx.get(camera_lora, -1)
 
689
 
690
+ enable_only_lora(pipeline._transformer, selected_idx)
691
+ torch.cuda.empty_cache()
 
692
 
693
+ # Run inference - progress automatically tracks tqdm from pipeline
694
+ with torch.inference_mode():
695
  pipeline(
696
  prompt=prompt,
697
  output_path=str(output_path),
 
705
  video_context=video_context,
706
  audio_context=audio_context,
707
  )
708
+ del video_context, audio_context
709
+ torch.cuda.empty_cache()
710
+ print("successful generation")
711
 
712
+ return str(output_path), current_seed
713
 
 
 
 
 
 
714
 
715
 
716
  def apply_resolution(resolution: str):
 
1227
  examples=[
1228
  [
1229
  "supergirl.png",
1230
+ "A fuzzy puppet superhero character resembling a female puppet with blonde hair and a blue superhero suit stands inside an icy cave made of frozen walls and icicles, she looks panicked and frantic, rapidly turning her head left and right and scanning the cave while waving her arms and shouting angrily and desperately, mouthing the words “where the hell is my dog,” her movements exaggerated and puppet-like with high energy and urgency, suddenly a second puppet dog bursts into frame from the side, jumping up excitedly and tackling her affectionately while licking her face repeatedly, she freezes in surprise and then breaks into relief and laughter as the dog continues licking her, the scene feels chaotic, comedic, and emotional with expressive puppet reactions, cinematic lighting, smooth camera motion, shallow depth of field, and high-quality puppet-style animation",
1231
+ "No LoRA",
1232
  ],
1233
  [
1234
  "highland.png",
1235
  "Realistic POV selfie-style video in a snowy, foggy field. Two shaggy Highland cows with long curved horns stand ahead. The camera is handheld and slightly shaky. The woman filming talks nervously and excitedly in a vlog tone: \"Oh my god guys… look how big those horns are… I’m kinda scared.\" The cow on the left walks toward the camera in a cute, bouncy, hopping way, curious and gentle. Snow crunches under its hooves, breath visible in the cold air. The horns look massive from the POV. As the cow gets very close, its wet nose with slight dripping fills part of the frame. She laughs nervously but reaches out and pets the cow. The cow makes deep, soft, interesting mooing and snorting sounds, calm and friendly. Ultra-realistic, natural lighting, immersive audio, documentary-style realism.",
1236
+ "No LoRA",
1237
  ],
1238
  [
1239
  "wednesday.png",
1240
  "A cinematic close-up of Wednesday Addams frozen mid-dance on a dark, blue-lit ballroom floor as students move indistinctly behind her, their footsteps and muffled music reduced to a distant, underwater thrum; the audio foregrounds her steady breathing and the faint rustle of fabric as she slowly raises one arm, never breaking eye contact with the camera, then after a deliberately long silence she speaks in a flat, dry, perfectly controlled voice, “I don’t dance… I vibe code,” each word crisp and unemotional, followed by an abrupt cutoff of her voice as the background sound swells slightly, reinforcing the deadpan humor, with precise lip sync, minimal facial movement, stark gothic lighting, and cinematic realism.",
1241
+ "Dolly Out",
1242
  ],
1243
  [
1244
  "astronaut.png",
1245
  "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot.",
1246
+ "Static",
1247
  ]
1248
 
1249
  ],
1250
  fn=generate_video_example,
1251
+ inputs=[input_image, prompt_ui, camera_lora_ui],
1252
  outputs = [output_video],
1253
  label="Example",
1254
  cache_examples=True,
packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py CHANGED
@@ -55,41 +55,120 @@ class MultiLoraLinear(nn.Module):
55
  def __init__(self, base: nn.Linear):
56
  super().__init__()
57
  self.base = base
58
- self.adapters: list[tuple[torch.Tensor, torch.Tensor, float]] = []
 
59
  self.enabled: list[bool] = []
60
 
61
- def add_adapter(self, A: torch.Tensor, B: torch.Tensor, scale: float, enabled: bool = True):
62
- # store as buffers for inference (keeps them off .parameters())
63
- idx = len(self.adapters)
64
- self.register_buffer(f"lora_A_{idx}", A, persistent=False)
65
- self.register_buffer(f"lora_B_{idx}", B, persistent=False)
66
- self.adapters.append((A, B, float(scale)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  self.enabled.append(bool(enabled))
68
 
69
- def set_enabled(self, idx: int, enabled: bool):
70
- if 0 <= idx < len(self.enabled):
71
- self.enabled[idx] = enabled
72
 
73
- def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  out = self.base(x)
75
- # add enabled adapters
76
  for i, on in enumerate(self.enabled):
77
  if not on:
78
  continue
79
- A = getattr(self, f"lora_A_{i}")
80
- B = getattr(self, f"lora_B_{i}")
81
- scale = self.adapters[i][2]
82
- out = out + ((x @ A.t()) @ B.t()) * scale
 
 
 
 
83
  return out
84
 
85
- def set_lora_enabled(model: nn.Module, adapter_idx: int, enabled: bool):
 
 
 
 
 
 
 
86
  for m in model.modules():
87
  if isinstance(m, MultiLoraLinear):
88
- m.set_enabled(adapter_idx, enabled)
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  def patch_only_affected_linears(
91
  model: nn.Module,
92
- lora_sd: dict,
93
  affected_modules: list[str],
94
  strength: float,
95
  adapter_idx: int,
@@ -98,7 +177,6 @@ def patch_only_affected_linears(
98
  for prefix in affected_modules:
99
  _, _, mod = get_submodule_and_parent(model, prefix)
100
 
101
- # unwrap / wrap
102
  if isinstance(mod, MultiLoraLinear):
103
  wrapped = mod
104
  else:
@@ -107,23 +185,30 @@ def patch_only_affected_linears(
107
  wrapped = MultiLoraLinear(mod)
108
  set_submodule(model, prefix, wrapped)
109
 
 
 
 
 
110
  key_a = f"{prefix}.lora_A.weight"
111
  key_b = f"{prefix}.lora_B.weight"
112
  if key_a not in lora_sd or key_b not in lora_sd:
 
113
  continue
114
 
115
- base_device = wrapped.base.weight.device
116
- base_dtype = wrapped.base.weight.dtype
117
-
118
- A = lora_sd[key_a].to(device=base_device, dtype=base_dtype)
119
- B = lora_sd[key_b].to(device=base_device, dtype=base_dtype)
120
 
121
- # parity with your current merge behavior:
122
- scale = strength
123
-
124
- # Ensure adapter list indices align across layers
125
- # If adapters are added sequentially per adapter_idx, this will line up.
126
- wrapped.add_adapter(A, B, scale=scale, enabled=default_enabled)
 
 
 
 
 
127
 
128
 
129
  @dataclass(frozen=True)
@@ -188,9 +273,28 @@ class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType],
188
  meta_model.load_state_dict(sd, strict=False, assign=True)
189
  return self._return_model(meta_model, device)
190
 
191
- lora_state_dicts = [
192
- self.load_sd([lora.path], sd_ops=lora.sd_ops, registry=self.registry, device=device) for lora in self.loras
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  ]
 
 
 
 
194
  lora_sd_and_strengths = [
195
  LoraStateDictWithStrength(sd, strength)
196
  for sd, strength in zip(lora_state_dicts, lora_strengths, strict=True)
@@ -206,7 +310,7 @@ class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType],
206
 
207
  _, affected_modules = apply_loras(
208
  model_sd=model_state_dict,
209
- lora_sd_and_strengths=lora_sd_and_strengths,
210
  dtype=dtype,
211
  destination_sd=None,
212
  return_affected=True,
 
55
  def __init__(self, base: nn.Linear):
56
  super().__init__()
57
  self.base = base
58
+ # Each entry: dict with CPU tensors always, and optional CUDA cache
59
+ self.adapters: list[dict] = []
60
  self.enabled: list[bool] = []
61
 
62
+ def add_adapter_cpu(self, A_cpu, B_cpu, scale=1.0, enabled=False, pin_memory=False):
63
+ if A_cpu is None or B_cpu is None:
64
+ self.adapters.append({
65
+ "A_cpu": None,
66
+ "B_cpu": None,
67
+ "scale": float(scale),
68
+ "A_gpu": None,
69
+ "B_gpu": None,
70
+ "gpu_dtype": None,
71
+ "gpu_device": None,
72
+ })
73
+ self.enabled.append(bool(enabled))
74
+ return
75
+
76
+ A_cpu = A_cpu.contiguous()
77
+ B_cpu = B_cpu.contiguous()
78
+ self.adapters.append({
79
+ "A_cpu": A_cpu,
80
+ "B_cpu": B_cpu,
81
+ "scale": float(scale),
82
+ "A_gpu": None,
83
+ "B_gpu": None,
84
+ "gpu_dtype": None,
85
+ "gpu_device": None,
86
+ })
87
  self.enabled.append(bool(enabled))
88
 
 
 
 
89
 
90
+
91
+ def _materialize_to_gpu(self, idx: int):
92
+ entry = self.adapters[idx]
93
+ if entry["A_cpu"] is None or entry["B_cpu"] is None:
94
+ return
95
+ """Move adapter idx to the base weight device/dtype if not already there."""
96
+ entry = self.adapters[idx]
97
+ dev = self.base.weight.device
98
+ dt = self.base.weight.dtype
99
+
100
+ if (
101
+ entry["A_gpu"] is not None
102
+ and entry["B_gpu"] is not None
103
+ and entry["gpu_device"] == dev
104
+ and entry["gpu_dtype"] == dt
105
+ ):
106
+ return # already good
107
+
108
+ A = entry["A_cpu"].to(device=dev, dtype=dt, non_blocking=True)
109
+ B = entry["B_cpu"].to(device=dev, dtype=dt, non_blocking=True)
110
+
111
+ entry["A_gpu"] = A
112
+ entry["B_gpu"] = B
113
+ entry["gpu_device"] = dev
114
+ entry["gpu_dtype"] = dt
115
+
116
+ def _evict_from_gpu(self, idx: int):
117
+ """Drop CUDA copies (free VRAM). CPU tensors remain."""
118
+ entry = self.adapters[idx]
119
+ entry["A_gpu"] = None
120
+ entry["B_gpu"] = None
121
+ entry["gpu_device"] = None
122
+ entry["gpu_dtype"] = None
123
+
124
+ def set_enabled(self, idx: int, enabled: bool, offload_when_disabled: bool = True):
125
+ if not (0 <= idx < len(self.enabled)):
126
+ return
127
+ self.enabled[idx] = enabled
128
+ if not enabled and offload_when_disabled:
129
+ self._evict_from_gpu(idx)
130
+
131
+ def forward(self, x):
132
  out = self.base(x)
 
133
  for i, on in enumerate(self.enabled):
134
  if not on:
135
  continue
136
+ entry = self.adapters[i]
137
+ if entry["A_cpu"] is None or entry["B_cpu"] is None:
138
+ continue
139
+ if entry["A_gpu"] is None or entry["B_gpu"] is None:
140
+ self._materialize_to_gpu(i)
141
+ A = entry["A_gpu"]
142
+ B = entry["B_gpu"]
143
+ out = out + ((x @ A.t()) @ B.t()) * entry["scale"]
144
  return out
145
 
146
+ def set_lora_enabled(model: nn.Module, adapter_idx: int, enabled: bool, offload_when_disabled: bool = True):
147
+ for m in model.modules():
148
+ if isinstance(m, MultiLoraLinear):
149
+ m.set_enabled(adapter_idx, enabled, offload_when_disabled=offload_when_disabled)
150
+
151
+ def enable_only_lora(model: nn.Module, adapter_idx: int | None):
152
+ # disable all
153
+ # (assumes all layers have same number of adapters; true if you patch consistently)
154
  for m in model.modules():
155
  if isinstance(m, MultiLoraLinear):
156
+ for i in range(len(m.enabled)):
157
+ m.set_enabled(i, False, offload_when_disabled=True)
158
+
159
+ torch.cuda.empty_cache()
160
+
161
+ # enable selected
162
+ if adapter_idx is not None and adapter_idx >= 0:
163
+ set_lora_enabled(model, adapter_idx, True)
164
+
165
+ torch.cuda.empty_cache()
166
+
167
+
168
 
169
  def patch_only_affected_linears(
170
  model: nn.Module,
171
+ lora_sd: dict, # can be CPU state dict
172
  affected_modules: list[str],
173
  strength: float,
174
  adapter_idx: int,
 
177
  for prefix in affected_modules:
178
  _, _, mod = get_submodule_and_parent(model, prefix)
179
 
 
180
  if isinstance(mod, MultiLoraLinear):
181
  wrapped = mod
182
  else:
 
185
  wrapped = MultiLoraLinear(mod)
186
  set_submodule(model, prefix, wrapped)
187
 
188
+ # ensure adapter slots exist up to adapter_idx
189
+ while len(wrapped.adapters) <= adapter_idx:
190
+ wrapped.add_adapter_cpu(None, None, scale=0.0, enabled=False)
191
+
192
  key_a = f"{prefix}.lora_A.weight"
193
  key_b = f"{prefix}.lora_B.weight"
194
  if key_a not in lora_sd or key_b not in lora_sd:
195
+ # leave the padded empty slot
196
  continue
197
 
198
+ A_cpu = lora_sd[key_a]
199
+ B_cpu = lora_sd[key_b]
 
 
 
200
 
201
+ # overwrite the placeholder slot
202
+ wrapped.adapters[adapter_idx] = {
203
+ "A_cpu": A_cpu.contiguous(),
204
+ "B_cpu": B_cpu.contiguous(),
205
+ "scale": float(strength),
206
+ "A_gpu": None,
207
+ "B_gpu": None,
208
+ "gpu_dtype": None,
209
+ "gpu_device": None,
210
+ }
211
+ wrapped.enabled[adapter_idx] = default_enabled
212
 
213
 
214
  @dataclass(frozen=True)
 
273
  meta_model.load_state_dict(sd, strict=False, assign=True)
274
  return self._return_model(meta_model, device)
275
 
276
+ # Load LoRA[0] (fused) on GPU (or CPU—GPU is fine since you fuse immediately)
277
+ lora0_sd = self.load_sd(
278
+ [self.loras[0].path],
279
+ sd_ops=self.loras[0].sd_ops,
280
+ registry=self.registry,
281
+ device=device,
282
+ )
283
+
284
+ # Load runtime LoRAs on CPU so they don't sit in VRAM
285
+ runtime_lora_sds = [
286
+ self.load_sd(
287
+ [lora.path],
288
+ sd_ops=lora.sd_ops,
289
+ registry=self.registry,
290
+ device=torch.device("cpu"),
291
+ )
292
+ for lora in self.loras[1:]
293
  ]
294
+
295
+ # Rebuild lists to match your later code expectations
296
+ lora_state_dicts = [lora0_sd, *runtime_lora_sds]
297
+
298
  lora_sd_and_strengths = [
299
  LoraStateDictWithStrength(sd, strength)
300
  for sd, strength in zip(lora_state_dicts, lora_strengths, strict=True)
 
310
 
311
  _, affected_modules = apply_loras(
312
  model_sd=model_state_dict,
313
+ lora_sd_and_strengths=lora_sd_and_strengths[1:],
314
  dtype=dtype,
315
  destination_sd=None,
316
  return_affected=True,