Update model.py
model.py CHANGED

@@ -6,7 +6,8 @@ from config import Config
 
 from diffusers import (
     ControlNetModel,
-
+    LCMScheduler,
+    # AutoencoderKL # Removed as requested
 )
 from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel
 

@@ -109,27 +110,26 @@ class ModelHandler:
 
         self.pipeline.to(Config.DEVICE)
 
-        # Enable xFormers
+        # --- NEW: Enable xFormers ---
         try:
             self.pipeline.enable_xformers_memory_efficient_attention()
             print(" [OK] xFormers memory efficient attention enabled.")
         except Exception as e:
             print(f" [WARNING] Failed to enable xFormers: {e}")
-
-
-
-        # ---
-
-
-
-
-        )
-
-
-        # 5. Load Adapters (IP-Adapter
-        print("Loading Adapters...")
+        # --- END NEW ---
+
+        # 4. Set Scheduler
+        # --- MODIFIED: Disable clipping to prevent NaN artifacts ---
+        print("Configuring LCMScheduler...")
+        scheduler_config = self.pipeline.scheduler.config
+        scheduler_config['clip_sample'] = False  # <-- THIS IS THE FIX
+        self.pipeline.scheduler = LCMScheduler.from_config(scheduler_config)
+        print(" [OK] LCMScheduler loaded (clip_sample=False).")
+        # --- END MODIFIED ---
+
+        # 5. Load Adapters (IP-Adapter & LoRA)
+        print("Loading Adapters (IP-Adapter & LoRA)...")
 
-        # 5a. IP-Adapter
         ip_adapter_filename = "ip-adapter.bin"
         ip_adapter_local_path = os.path.join("./models", ip_adapter_filename)
 

@@ -145,29 +145,12 @@ class ModelHandler:
         print(f"Loading IP-Adapter from local file: {ip_adapter_local_path}")
         self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)
 
-
-        print("Loading TCD-SDXL-LoRA...")
-        tcd_lora_filename = "pytorch_lora_weights.safetensors"
-        tcd_lora_path = os.path.join("./models", tcd_lora_filename)
-
-        if not os.path.exists(tcd_lora_path):
-            hf_hub_download(
-                repo_id="h1t/TCD-SDXL-LoRA",
-                filename=tcd_lora_filename,
-                local_dir="./models",
-                local_dir_use_symlinks=False
-            )
-        self.pipeline.load_lora_weights("./models", weight_name=tcd_lora_filename)
-        self.pipeline.fuse_lora(lora_scale=1.0)
-        print(" [OK] TCD LoRA fused.")
-
-        # 5c. Load Style LoRA
-        print("Loading Style LoRA weights...")
+        print("Loading LoRA weights...")
         self.pipeline.load_lora_weights(Config.REPO_ID, weight_name=Config.LORA_FILENAME)
 
-        print(f"Fusing
+        print(f"Fusing LoRA with scale {Config.LORA_STRENGTH}...")
         self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
-        print(" [OK]
+        print(" [OK] LoRA fused.")
 
         # 6. Load Preprocessors
         print("Loading Preprocessors (LeReS, LineArtAnime)...")
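
For readers skimming the second hunk: the core change is rebuilding the pipeline's scheduler as an LCMScheduler from the existing scheduler config with clip_sample disabled. A minimal standalone sketch of the same swap, assuming a plain diffusers SDXL pipeline; the checkpoint id and dtype below are placeholders, not this Space's actual configuration.

import torch
from diffusers import StableDiffusionXLPipeline, LCMScheduler

# Placeholder checkpoint; the Space builds its own ControlNet/InstantID pipeline instead.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Rebuild the scheduler from the current config, but with sample clipping disabled,
# mirroring the clip_sample=False fix in this commit.
scheduler_config = dict(pipe.scheduler.config)
scheduler_config["clip_sample"] = False
pipe.scheduler = LCMScheduler.from_config(scheduler_config)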
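The LoRA handling in the last hunk uses the standard diffusers calls (load_lora_weights followed by fuse_lora). A rough sketch of the same pattern, continuing from the pipe object in the previous sketch; the repo id, filename, and scale are placeholders standing in for Config.REPO_ID, Config.LORA_FILENAME, and Config.LORA_STRENGTH.

# Placeholders for the Space's Config values.
LORA_REPO = "some-user/some-style-lora"   # stands in for Config.REPO_ID
LORA_FILE = "style_lora.safetensors"      # stands in for Config.LORA_FILENAME
LORA_SCALE = 0.8                          # stands in for Config.LORA_STRENGTH

# Load the LoRA from the Hub, then fuse it into the base weights so no adapter
# bookkeeping is needed at inference time.
pipe.load_lora_weights(LORA_REPO, weight_name=LORA_FILE)
pipe.fuse_lora(lora_scale=LORA_SCALE)

Fusing bakes the LoRA into the base weights, which avoids per-call adapter overhead at inference; the removed TCD-LoRA block fused a second LoRA the same way at scale 1.0.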
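Both the kept IP-Adapter path handling and the removed TCD-LoRA block rely on a download-if-missing pattern with huggingface_hub. A generic sketch under the same assumption of a ./models cache directory; the repo id below is a placeholder.

import os
from huggingface_hub import hf_hub_download

repo_id = "some-user/some-adapter-repo"   # placeholder; not the Space's actual repo
filename = "ip-adapter.bin"
local_path = os.path.join("./models", filename)

# Fetch the file from the Hub only when it is not already present locally.
if not os.path.exists(local_path):
    hf_hub_download(repo_id=repo_id, filename=filename, local_dir="./models")

The removed block also passed local_dir_use_symlinks=False, which recent huggingface_hub versions deprecate, so it is omitted here.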