Spaces:
Runtime error
Runtime error
Update models.py
Browse files
models.py
CHANGED
|
@@ -177,9 +177,11 @@ def load_lora(pipe):
|
|
| 177 |
|
| 178 |
def fuse_lora_with_scale(pipe, lora_scale):
|
| 179 |
"""
|
| 180 |
-
Following examplewithface.py lines
|
| 181 |
-
|
| 182 |
-
|
|
|
|
|
|
|
| 183 |
"""
|
| 184 |
global lora_path_cached
|
| 185 |
|
|
@@ -187,38 +189,39 @@ def fuse_lora_with_scale(pipe, lora_scale):
|
|
| 187 |
return False
|
| 188 |
|
| 189 |
try:
|
| 190 |
-
#
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
|
|
|
| 194 |
|
| 195 |
-
#
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
lora_path_cached, # file path
|
| 201 |
-
pipe.vae,
|
| 202 |
-
pipe.text_encoder, # Single text encoder (example line 227)
|
| 203 |
-
pipe.unet,
|
| 204 |
-
for_inference=True,
|
| 205 |
-
)
|
| 206 |
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
|
|
|
|
|
|
| 212 |
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
|
| 217 |
-
|
|
|
|
|
|
|
| 218 |
|
|
|
|
| 219 |
return True
|
|
|
|
| 220 |
except Exception as e:
|
| 221 |
-
print(f" [ERROR] LoRA
|
| 222 |
import traceback
|
| 223 |
traceback.print_exc()
|
| 224 |
return False
|
|
|
|
| 177 |
|
| 178 |
def fuse_lora_with_scale(pipe, lora_scale):
|
| 179 |
"""
|
| 180 |
+
Following examplewithface.py lines 266-267 EXACTLY:
|
| 181 |
+
pipe.load_lora_weights(loaded_state_dict)
|
| 182 |
+
pipe.fuse_lora(lora_scale)
|
| 183 |
+
|
| 184 |
+
Uses DIFFUSERS built-in LoRA (NOT Kohya lora.py!)
|
| 185 |
"""
|
| 186 |
global lora_path_cached
|
| 187 |
|
|
|
|
| 189 |
return False
|
| 190 |
|
| 191 |
try:
|
| 192 |
+
# Unfuse previous LoRA (example line 259)
|
| 193 |
+
try:
|
| 194 |
+
pipe.unfuse_lora()
|
| 195 |
+
except:
|
| 196 |
+
pass
|
| 197 |
|
| 198 |
+
# Unload previous LoRA (example line 260)
|
| 199 |
+
try:
|
| 200 |
+
pipe.unload_lora_weights()
|
| 201 |
+
except:
|
| 202 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
+
print(f" [LORA] Loading state dict from file...")
|
| 205 |
+
# Load state dict like example (lines 75-78)
|
| 206 |
+
if lora_path_cached.endswith('.safetensors'):
|
| 207 |
+
from safetensors.torch import load_file
|
| 208 |
+
state_dict = load_file(lora_path_cached)
|
| 209 |
+
else:
|
| 210 |
+
state_dict = torch.load(lora_path_cached, map_location="cpu")
|
| 211 |
|
| 212 |
+
print(f" [LORA] Loading weights into pipeline...")
|
| 213 |
+
# examplewithface.py line 266
|
| 214 |
+
pipe.load_lora_weights(state_dict)
|
| 215 |
|
| 216 |
+
# examplewithface.py line 267
|
| 217 |
+
print(f" [LORA] Fusing with scale {lora_scale}...")
|
| 218 |
+
pipe.fuse_lora(lora_scale)
|
| 219 |
|
| 220 |
+
print(f" [OK] LoRA fused into model (diffusers method)")
|
| 221 |
return True
|
| 222 |
+
|
| 223 |
except Exception as e:
|
| 224 |
+
print(f" [ERROR] LoRA fusion failed: {e}")
|
| 225 |
import traceback
|
| 226 |
traceback.print_exc()
|
| 227 |
return False
|