primerz committed on
Commit
c1defd0
·
verified ·
1 Parent(s): 5c4da2e

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +31 -28
models.py CHANGED
@@ -177,9 +177,11 @@ def load_lora(pipe):
177
 
178
def fuse_lora_with_scale(pipe, lora_scale, dtype=None, device="cuda"):
    """Merge the cached LoRA into the pipeline using the Kohya-style loader.

    Following examplewithface.py lines 223-235: uses the local lora.py
    (Kohya-style) loader. CRITICAL: pass pipe.text_encoder (not a list)
    for the example to work.

    Args:
        pipe: pipeline exposing .vae, .text_encoder and .unet.
        lora_scale: multiplier applied when merging the LoRA weights.
        dtype: merge dtype; defaults to torch.float16 (the previously
            hard-coded value), kept as None default for call compatibility.
        device: merge device; defaults to "cuda" (previously hard-coded).

    Returns:
        True when the merge succeeded, False otherwise.
    """
    global lora_path_cached

    # NOTE(review): the guard's condition line is elided in the diff view;
    # presumably it bails out when no LoRA file has been cached — confirm
    # against the full file.
    if not lora_path_cached:
        return False

    try:
        # Import the local lora module (Kohya-style) lazily, so it is only
        # required when a LoRA is actually merged.
        import lora

        print(" [LORA] Creating network from weights...")

        # examplewithface.py lines 223-229.
        # IMPORTANT: the example passes pipe.text_encoder (singular), not a
        # list! The LoRA network handles SDXL's dual encoders internally.
        lora_model, weights_sd = lora.create_network_from_weights(
            lora_scale,         # multiplier
            lora_path_cached,   # file path
            pipe.vae,
            pipe.text_encoder,  # single text encoder (example line 227)
            pipe.unet,
            for_inference=True,
        )

        # examplewithface.py lines 231-233.
        print(f" [LORA] Merging to model with scale {lora_scale}...")
        merge_dtype = torch.float16 if dtype is None else dtype
        lora_model.merge_to(
            pipe.text_encoder, pipe.unet, weights_sd, merge_dtype, device
        )

        # Release the (potentially large) state dict and network promptly.
        del weights_sd
        del lora_model

        print(" [OK] LoRA merged into model using Kohya loader")

        return True
    except Exception as e:
        print(f" [ERROR] LoRA merge failed: {e}")
        import traceback
        traceback.print_exc()
        return False
 
177
 
178
def fuse_lora_with_scale(pipe, lora_scale):
    """Load and fuse the cached LoRA via diffusers' built-in LoRA support.

    Following examplewithface.py lines 266-267 EXACTLY:
        pipe.load_lora_weights(loaded_state_dict)
        pipe.fuse_lora(lora_scale)

    Uses DIFFUSERS built-in LoRA (NOT Kohya lora.py!)

    Args:
        pipe: diffusers pipeline with LoRA support
            (load_lora_weights / fuse_lora / unfuse_lora).
        lora_scale: scale passed to pipe.fuse_lora().

    Returns:
        True when the LoRA was fused, False otherwise.
    """
    global lora_path_cached

    # NOTE(review): the guard's condition line is elided in the diff view;
    # presumably it bails out when no LoRA file has been cached — confirm
    # against the full file.
    if not lora_path_cached:
        return False

    try:
        # Unfuse previous LoRA (example line 259). Best-effort; a bare
        # `except:` here would also swallow KeyboardInterrupt/SystemExit,
        # so catch Exception instead.
        try:
            pipe.unfuse_lora()
        except Exception:
            pass

        # Unload previous LoRA (example line 260); best-effort as above.
        try:
            pipe.unload_lora_weights()
        except Exception:
            pass

        print(" [LORA] Loading state dict from file...")
        # Load state dict like example (lines 75-78).
        if lora_path_cached.endswith('.safetensors'):
            from safetensors.torch import load_file
            state_dict = load_file(lora_path_cached)
        else:
            # weights_only=True blocks arbitrary-code pickle execution when
            # loading a (possibly downloaded) checkpoint; LoRA state dicts
            # contain only tensors, so this is safe.
            state_dict = torch.load(
                lora_path_cached, map_location="cpu", weights_only=True
            )

        print(" [LORA] Loading weights into pipeline...")
        # examplewithface.py line 266
        pipe.load_lora_weights(state_dict)

        # examplewithface.py line 267
        print(f" [LORA] Fusing with scale {lora_scale}...")
        pipe.fuse_lora(lora_scale)

        print(" [OK] LoRA fused into model (diffusers method)")
        return True

    except Exception as e:
        print(f" [ERROR] LoRA fusion failed: {e}")
        import traceback
        traceback.print_exc()
        return False