primerz committed on
Commit
f1e1174
·
verified ·
1 Parent(s): 83bb9ad

Upload 2 files

Browse files
Files changed (2) hide show
  1. generator.py +8 -4
  2. models.py +45 -13
generator.py CHANGED
@@ -550,10 +550,14 @@ class RetroArtConverter:
550
  # Set LORA scale
551
  if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
552
  try:
553
- # For SDXL with LORA, use set_adapters with proper names
554
- adapter_names = ["retroart"] # The adapter name from loading
555
- self.pipe.set_adapters(adapter_names, adapter_weights=[lora_scale])
556
- print(f"[LORA] Set adapter 'retroart' with scale: {lora_scale}")
 
 
 
 
557
 
558
  except Exception as e:
559
  print(f"[WARNING] LORA set_adapters failed: {e}")
 
550
  # Set LORA scale
551
  if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
552
  try:
553
+ # Try retroart first (if loaded with that name), then default_0
554
+ for adapter_name in ["retroart", "default_0"]:
555
+ try:
556
+ self.pipe.set_adapters([adapter_name], adapter_weights=[lora_scale])
557
+ print(f"[LORA] Set adapter '{adapter_name}' with scale: {lora_scale}")
558
+ break
559
+ except:
560
+ continue
561
 
562
  except Exception as e:
563
  print(f"[WARNING] LORA set_adapters failed: {e}")
models.py CHANGED
@@ -164,12 +164,19 @@ def load_lora(pipe):
164
  print("Loading LORA (retroart) from HuggingFace Hub...")
165
  try:
166
  lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
167
- pipe.load_lora_weights(lora_path)
168
- print(f" [OK] LORA loaded successfully")
 
169
  return True
170
  except Exception as e:
171
- print(f" [WARNING] Could not load LORA: {e}")
172
- return False
 
 
 
 
 
 
173
 
174
 
175
  def setup_ip_adapter(pipe, image_encoder):
@@ -191,15 +198,29 @@ def setup_ip_adapter(pipe, image_encoder):
191
  # Load full state dict
192
  state_dict = torch.load(ip_adapter_path, map_location="cpu")
193
 
194
- # Extract image_proj and ip_adapter weights
 
 
 
195
  image_proj_state_dict = {}
196
  ip_adapter_state_dict = {}
197
 
198
  for key, value in state_dict.items():
199
- if key.startswith("image_proj."):
200
- image_proj_state_dict[key.replace("image_proj.", "")] = value
201
- elif key.startswith("ip_adapter."):
202
- ip_adapter_state_dict[key.replace("ip_adapter.", "")] = value
 
 
 
 
 
 
 
 
 
 
 
203
 
204
  # Create Resampler (image projection model) with CORRECT parameters from reference
205
  print("Creating Resampler (Perceiver architecture)...")
@@ -220,13 +241,24 @@ def setup_ip_adapter(pipe, image_encoder):
220
  # Load image_proj weights
221
  if image_proj_state_dict:
222
  try:
 
223
  image_proj_model.load_state_dict(image_proj_state_dict, strict=True)
224
- print(" [OK] Resampler loaded with pretrained weights")
225
  except Exception as e:
226
- print(f" [WARNING] Could not load Resampler weights: {e}")
227
- print(" Using randomly initialized Resampler")
 
 
 
 
 
 
 
 
 
228
  else:
229
- print(" [WARNING] No image_proj weights found, using random initialization")
 
230
 
231
  # Setup IP-Adapter attention processors
232
  print("Setting up IP-Adapter attention processors...")
 
164
  print("Loading LORA (retroart) from HuggingFace Hub...")
165
  try:
166
  lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
167
+ # Load with explicit adapter name to avoid default_0
168
+ pipe.load_lora_weights(lora_path, adapter_name="retroart")
169
+ print(f" [OK] LORA loaded successfully as 'retroart' adapter")
170
  return True
171
  except Exception as e:
172
+ # Fallback to default loading
173
+ try:
174
+ pipe.load_lora_weights(lora_path)
175
+ print(f" [OK] LORA loaded successfully (default adapter)")
176
+ return True
177
+ except Exception as e2:
178
+ print(f" [WARNING] Could not load LORA: {e2}")
179
+ return False
180
 
181
 
182
  def setup_ip_adapter(pipe, image_encoder):
 
198
  # Load full state dict
199
  state_dict = torch.load(ip_adapter_path, map_location="cpu")
200
 
201
+ # Debug: Print available keys
202
+ print(f"[DEBUG] State dict keys sample: {list(state_dict.keys())[:5]}")
203
+
204
+ # Extract image_proj and ip_adapter weights with flexible key matching
205
  image_proj_state_dict = {}
206
  ip_adapter_state_dict = {}
207
 
208
  for key, value in state_dict.items():
209
+ # Handle different possible key formats
210
+ if "image_proj" in key:
211
+ # Remove any prefix before image_proj
212
+ clean_key = key.split("image_proj.")[-1] if "image_proj." in key else key
213
+ image_proj_state_dict[clean_key] = value
214
+ elif "ip_adapter" in key or "to_k_ip" in key or "to_v_ip" in key:
215
+ # IP adapter weights might not have prefix
216
+ if "ip_adapter." in key:
217
+ clean_key = key.replace("ip_adapter.", "")
218
+ else:
219
+ clean_key = key
220
+ ip_adapter_state_dict[clean_key] = value
221
+
222
+ print(f"[DEBUG] Found {len(image_proj_state_dict)} image_proj weights")
223
+ print(f"[DEBUG] Found {len(ip_adapter_state_dict)} ip_adapter weights")
224
 
225
  # Create Resampler (image projection model) with CORRECT parameters from reference
226
  print("Creating Resampler (Perceiver architecture)...")
 
241
  # Load image_proj weights
242
  if image_proj_state_dict:
243
  try:
244
+ # Try strict loading first
245
  image_proj_model.load_state_dict(image_proj_state_dict, strict=True)
246
+ print(" [OK] Resampler loaded with pretrained weights (strict)")
247
  except Exception as e:
248
+ # Try non-strict if strict fails
249
+ try:
250
+ missing, unexpected = image_proj_model.load_state_dict(image_proj_state_dict, strict=False)
251
+ print(f" [OK] Resampler loaded with pretrained weights (non-strict)")
252
+ if missing:
253
+ print(f" Missing keys: {missing[:5]}...") # Show first 5
254
+ if unexpected:
255
+ print(f" Unexpected keys: {unexpected[:5]}...") # Show first 5
256
+ except Exception as e2:
257
+ print(f" [WARNING] Could not load Resampler weights: {e2}")
258
+ print(" Using randomly initialized Resampler")
259
  else:
260
+ print(" [WARNING] No image_proj weights found in state dict")
261
+ print(" Using randomly initialized Resampler")
262
 
263
  # Setup IP-Adapter attention processors
264
  print("Setting up IP-Adapter attention processors...")