Upload 2 files
- generator.py +8 -4
- models.py +45 -13
generator.py
CHANGED

@@ -550,10 +550,14 @@ class RetroArtConverter:
         # Set LORA scale
         if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
             try:
-                #
-
-
-
+                # Try retroart first (if loaded with that name), then default_0
+                for adapter_name in ["retroart", "default_0"]:
+                    try:
+                        self.pipe.set_adapters([adapter_name], adapter_weights=[lora_scale])
+                        print(f"[LORA] Set adapter '{adapter_name}' with scale: {lora_scale}")
+                        break
+                    except:
+                        continue
 
             except Exception as e:
                 print(f"[WARNING] LORA set_adapters failed: {e}")
models.py
CHANGED

@@ -164,12 +164,19 @@ def load_lora(pipe):
     print("Loading LORA (retroart) from HuggingFace Hub...")
     try:
         lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
-
-
+        # Load with explicit adapter name to avoid default_0
+        pipe.load_lora_weights(lora_path, adapter_name="retroart")
+        print(f" [OK] LORA loaded successfully as 'retroart' adapter")
         return True
     except Exception as e:
-
-
+        # Fallback to default loading
+        try:
+            pipe.load_lora_weights(lora_path)
+            print(f" [OK] LORA loaded successfully (default adapter)")
+            return True
+        except Exception as e2:
+            print(f" [WARNING] Could not load LORA: {e2}")
+            return False
 
 
 def setup_ip_adapter(pipe, image_encoder):
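One caveat in the fallback path above: the except block reuses lora_path, which is only bound if download_model_with_retry() succeeded, so a failed download would surface as a NameError inside the handler rather than the intended warning. A hedged sketch of a variant that separates the two failure modes (it reuses the module's own names; the restructuring is only a suggestion, not what this commit does):

    def load_lora(pipe):
        print("Loading LORA (retroart) from HuggingFace Hub...")
        try:
            lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
        except Exception as e:
            print(f" [WARNING] Could not download LORA: {e}")
            return False
        # Prefer the explicit adapter name, then fall back to anonymous loading.
        for kwargs in ({"adapter_name": "retroart"}, {}):
            try:
                pipe.load_lora_weights(lora_path, **kwargs)
                print(f" [OK] LORA loaded ({kwargs.get('adapter_name', 'default')} adapter)")
                return True
            except Exception as e:
                print(f" [WARNING] Could not load LORA: {e}")
        return False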
@@ -191,15 +198,29 @@ def setup_ip_adapter(pipe, image_encoder):
     # Load full state dict
     state_dict = torch.load(ip_adapter_path, map_location="cpu")
 
-    #
+    # Debug: Print available keys
+    print(f"[DEBUG] State dict keys sample: {list(state_dict.keys())[:5]}")
+
+    # Extract image_proj and ip_adapter weights with flexible key matching
     image_proj_state_dict = {}
     ip_adapter_state_dict = {}
 
     for key, value in state_dict.items():
-
-
-
-
+        # Handle different possible key formats
+        if "image_proj" in key:
+            # Remove any prefix before image_proj
+            clean_key = key.split("image_proj.")[-1] if "image_proj." in key else key
+            image_proj_state_dict[clean_key] = value
+        elif "ip_adapter" in key or "to_k_ip" in key or "to_v_ip" in key:
+            # IP adapter weights might not have prefix
+            if "ip_adapter." in key:
+                clean_key = key.replace("ip_adapter.", "")
+            else:
+                clean_key = key
+            ip_adapter_state_dict[clean_key] = value
+
+    print(f"[DEBUG] Found {len(image_proj_state_dict)} image_proj weights")
+    print(f"[DEBUG] Found {len(ip_adapter_state_dict)} ip_adapter weights")
 
     # Create Resampler (image projection model) with CORRECT parameters from reference
     print("Creating Resampler (Perceiver architecture)...")
@@ -220,13 +241,24 @@ def setup_ip_adapter(pipe, image_encoder):
     # Load image_proj weights
     if image_proj_state_dict:
         try:
+            # Try strict loading first
             image_proj_model.load_state_dict(image_proj_state_dict, strict=True)
-            print(" [OK] Resampler loaded with pretrained weights")
+            print(" [OK] Resampler loaded with pretrained weights (strict)")
         except Exception as e:
-
-
+            # Try non-strict if strict fails
+            try:
+                missing, unexpected = image_proj_model.load_state_dict(image_proj_state_dict, strict=False)
+                print(f" [OK] Resampler loaded with pretrained weights (non-strict)")
+                if missing:
+                    print(f" Missing keys: {missing[:5]}...")  # Show first 5
+                if unexpected:
+                    print(f" Unexpected keys: {unexpected[:5]}...")  # Show first 5
+            except Exception as e2:
+                print(f" [WARNING] Could not load Resampler weights: {e2}")
+                print(" Using randomly initialized Resampler")
     else:
-        print(" [WARNING] No image_proj weights found")
+        print(" [WARNING] No image_proj weights found in state dict")
+        print(" Using randomly initialized Resampler")
 
     # Setup IP-Adapter attention processors
     print("Setting up IP-Adapter attention processors...")
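On the strict/non-strict fallback in the last hunk: with strict=False, torch's load_state_dict() reports mismatches instead of raising, returning a named tuple of missing and unexpected keys, which is what the missing, unexpected unpacking relies on. A toy demonstration (the module and shapes are arbitrary stand-ins for the Resampler):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)  # state dict keys: 'weight', 'bias'

    # Checkpoint with one key missing and one extra, as happens with prefix drift.
    ckpt = {"weight": torch.zeros(4, 4), "extra.bias": torch.zeros(4)}

    try:
        model.load_state_dict(ckpt, strict=True)   # raises RuntimeError
    except RuntimeError:
        missing, unexpected = model.load_state_dict(ckpt, strict=False)
        print(missing)      # ['bias']
        print(unexpected)   # ['extra.bias']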