primerz commited on
Commit
5bb4ff9
·
verified ·
1 Parent(s): ebdbde1

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +19 -36
model.py CHANGED
@@ -6,7 +6,8 @@ from config import Config
6
 
7
  from diffusers import (
8
  ControlNetModel,
9
- TCDScheduler,
 
10
  )
11
  from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel
12
 
@@ -109,27 +110,26 @@ class ModelHandler:
109
 
110
  self.pipeline.to(Config.DEVICE)
111
 
112
- # Enable xFormers
113
  try:
114
  self.pipeline.enable_xformers_memory_efficient_attention()
115
  print(" [OK] xFormers memory efficient attention enabled.")
116
  except Exception as e:
117
  print(f" [WARNING] Failed to enable xFormers: {e}")
118
-
119
- # 4. Set TCD Scheduler
120
- print("Configuring TCDScheduler...")
121
- # --- FIX: Set timestep_spacing="trailing" for proper distilled sampling ---
122
- self.pipeline.scheduler = TCDScheduler.from_config(
123
- self.pipeline.scheduler.config,
124
- use_karras_sigmas=True,
125
- timestep_spacing="trailing"
126
- )
127
- print(" [OK] TCDScheduler loaded (Karras + Trailing Spacing).")
128
-
129
- # 5. Load Adapters (IP-Adapter, TCD-LoRA & Style LoRA)
130
- print("Loading Adapters...")
131
 
132
- # 5a. IP-Adapter
133
  ip_adapter_filename = "ip-adapter.bin"
134
  ip_adapter_local_path = os.path.join("./models", ip_adapter_filename)
135
 
@@ -145,29 +145,12 @@ class ModelHandler:
145
  print(f"Loading IP-Adapter from local file: {ip_adapter_local_path}")
146
  self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)
147
 
148
- # 5b. Load TCD LoRA (Correct filename)
149
- print("Loading TCD-SDXL-LoRA...")
150
- tcd_lora_filename = "pytorch_lora_weights.safetensors"
151
- tcd_lora_path = os.path.join("./models", tcd_lora_filename)
152
-
153
- if not os.path.exists(tcd_lora_path):
154
- hf_hub_download(
155
- repo_id="h1t/TCD-SDXL-LoRA",
156
- filename=tcd_lora_filename,
157
- local_dir="./models",
158
- local_dir_use_symlinks=False
159
- )
160
- self.pipeline.load_lora_weights("./models", weight_name=tcd_lora_filename)
161
- self.pipeline.fuse_lora(lora_scale=1.0)
162
- print(" [OK] TCD LoRA fused.")
163
-
164
- # 5c. Load Style LoRA
165
- print("Loading Style LoRA weights...")
166
  self.pipeline.load_lora_weights(Config.REPO_ID, weight_name=Config.LORA_FILENAME)
167
 
168
- print(f"Fusing Style LoRA with scale {Config.LORA_STRENGTH}...")
169
  self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
170
- print(" [OK] Style LoRA fused.")
171
 
172
  # 6. Load Preprocessors
173
  print("Loading Preprocessors (LeReS, LineArtAnime)...")
 
6
 
7
  from diffusers import (
8
  ControlNetModel,
9
+ LCMScheduler,
10
+ # AutoencoderKL # Removed as requested
11
  )
12
  from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel
13
 
 
110
 
111
  self.pipeline.to(Config.DEVICE)
112
 
113
+ # --- NEW: Enable xFormers ---
114
  try:
115
  self.pipeline.enable_xformers_memory_efficient_attention()
116
  print(" [OK] xFormers memory efficient attention enabled.")
117
  except Exception as e:
118
  print(f" [WARNING] Failed to enable xFormers: {e}")
119
+ # --- END NEW ---
120
+
121
+ # 4. Set Scheduler
122
+ # --- MODIFIED: Disable clipping to prevent NaN artifacts ---
123
+ print("Configuring LCMScheduler...")
124
+ scheduler_config = self.pipeline.scheduler.config
125
+ scheduler_config['clip_sample'] = False # <-- THIS IS THE FIX
126
+ self.pipeline.scheduler = LCMScheduler.from_config(scheduler_config)
127
+ print(" [OK] LCMScheduler loaded (clip_sample=False).")
128
+ # --- END MODIFIED ---
129
+
130
+ # 5. Load Adapters (IP-Adapter & LoRA)
131
+ print("Loading Adapters (IP-Adapter & LoRA)...")
132
 
 
133
  ip_adapter_filename = "ip-adapter.bin"
134
  ip_adapter_local_path = os.path.join("./models", ip_adapter_filename)
135
 
 
145
  print(f"Loading IP-Adapter from local file: {ip_adapter_local_path}")
146
  self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)
147
 
148
+ print("Loading LoRA weights...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  self.pipeline.load_lora_weights(Config.REPO_ID, weight_name=Config.LORA_FILENAME)
150
 
151
+ print(f"Fusing LoRA with scale {Config.LORA_STRENGTH}...")
152
  self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
153
+ print(" [OK] LoRA fused.")
154
 
155
  # 6. Load Preprocessors
156
  print("Loading Preprocessors (LeReS, LineArtAnime)...")