Spaces:
Runtime error
Runtime error
Update models.py
Browse files
models.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
"""
|
| 2 |
Model loading and initialization for Pixagram AI Pixel Art Generator
|
| 3 |
-
|
| 4 |
"""
|
| 5 |
import torch
|
| 6 |
import time
|
|
@@ -12,11 +12,13 @@ from diffusers import (
|
|
| 12 |
)
|
| 13 |
from insightface.app import FaceAnalysis
|
| 14 |
from controlnet_aux import ZoeDetector
|
| 15 |
-
from huggingface_hub import hf_hub_download
|
| 16 |
from compel import Compel, ReturnedEmbeddingsType
|
| 17 |
|
| 18 |
# Use InstantID pipeline
|
| 19 |
-
from pipeline_stable_diffusion_xl_instantid_img2img import
|
|
|
|
|
|
|
| 20 |
|
| 21 |
from config import (
|
| 22 |
device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
|
|
@@ -59,18 +61,79 @@ def download_model_with_retry(repo_id, filename, max_retries=None):
|
|
| 59 |
|
| 60 |
|
| 61 |
def load_face_analysis():
|
| 62 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 63 |
print("Loading face analysis model...")
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
def load_depth_detector():
|
| 76 |
"""Load Zoe Depth detector with optimized memory management."""
|
|
@@ -110,23 +173,44 @@ def load_controlnets():
|
|
| 110 |
|
| 111 |
|
| 112 |
def load_sdxl_pipeline(controlnets):
|
| 113 |
-
"""
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
| 115 |
try:
|
| 116 |
model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])
|
| 117 |
|
|
|
|
| 118 |
pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
|
| 119 |
model_path,
|
| 120 |
controlnet=controlnets,
|
| 121 |
torch_dtype=dtype,
|
| 122 |
use_safetensors=True
|
| 123 |
).to(device)
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
return pipe, True
|
|
|
|
| 126 |
except Exception as e:
|
| 127 |
-
print(f" [
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
"stabilityai/stable-diffusion-xl-base-1.0",
|
| 131 |
controlnet=controlnets,
|
| 132 |
torch_dtype=dtype,
|
|
@@ -134,6 +218,7 @@ def load_sdxl_pipeline(controlnets):
|
|
| 134 |
).to(device)
|
| 135 |
return pipe, False
|
| 136 |
|
|
|
|
| 137 |
def load_lora(pipe):
|
| 138 |
"""Load LORA from HuggingFace Hub."""
|
| 139 |
print("Loading LORA (retroart) from HuggingFace Hub...")
|
|
|
|
| 1 |
"""
|
| 2 |
Model loading and initialization for Pixagram AI Pixel Art Generator
|
| 3 |
+
CORRECTED VERSION with proper face analysis loading
|
| 4 |
"""
|
| 5 |
import torch
|
| 6 |
import time
|
|
|
|
| 12 |
)
|
| 13 |
from insightface.app import FaceAnalysis
|
| 14 |
from controlnet_aux import ZoeDetector
|
| 15 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
| 16 |
from compel import Compel, ReturnedEmbeddingsType
|
| 17 |
|
| 18 |
# Use InstantID pipeline
|
| 19 |
+
from pipeline_stable_diffusion_xl_instantid_img2img import (
|
| 20 |
+
StableDiffusionXLInstantIDImg2ImgPipeline
|
| 21 |
+
)
|
| 22 |
|
| 23 |
from config import (
|
| 24 |
device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
|
|
|
|
| 61 |
|
| 62 |
|
| 63 |
def _build_face_app(ctx_id, **face_kwargs):
    """Create, prepare and smoke-test a FaceAnalysis instance.

    Args:
        ctx_id: Value forwarded to ``FaceAnalysis.prepare`` as ``ctx_id``.
        **face_kwargs: Extra keyword arguments forwarded unchanged to the
            ``FaceAnalysis`` constructor (e.g. ``providers=[...]``). When
            nothing is passed, the constructor is called with only ``name``
            and ``root``, exactly like the fallback attempt needs.

    Returns:
        The prepared ``FaceAnalysis`` object.

    Raises:
        Exception: Whatever the constructor, ``prepare`` or the test
            inference raises; callers decide how to recover.
    """
    # Imported locally (as the original code did) so this module does not
    # need numpy at import time.
    import numpy as np

    face_app = FaceAnalysis(
        name=FACE_DETECTION_CONFIG['model_name'],
        root=FACE_DETECTION_CONFIG['root'],
        **face_kwargs
    )
    face_app.prepare(
        ctx_id=ctx_id,
        det_size=FACE_DETECTION_CONFIG['det_size']
    )

    # Run one inference on a blank 640x640 frame so that a broken setup
    # fails here, during loading, instead of on the first real image.
    blank_frame = np.zeros((640, 640, 3), dtype=np.uint8)
    face_app.get(blank_frame)
    return face_app


def load_face_analysis():
    """
    Load the face analysis model, degrading gracefully if it is unavailable.

    Two attempts are made, in order:

    1. Download the model files with ``snapshot_download`` (repo and target
       directory come from ``FACE_DETECTION_CONFIG``), then initialize
       ``FaceAnalysis`` with the configured providers (default
       ``['CPUExecutionProvider']``) and the configured ``ctx_id``.
    2. If anything in step 1 raises, retry without an explicit ``providers``
       argument and with ``ctx_id=0``.

    Each attempt ends with a test inference on a blank image.

    Returns:
        tuple: ``(face_app, True)`` when either attempt succeeds, otherwise
        ``(None, False)``. This function does not raise for loading failures;
        errors are printed instead.
    """
    import traceback

    print("Loading face analysis model...")

    # --- Attempt 1: configured providers ---------------------------------
    try:
        # Download antelopev2 model using snapshot_download
        print(" Downloading antelopev2 model files...")
        antelope_path = snapshot_download(
            repo_id=FACE_DETECTION_CONFIG['download_repo'],
            local_dir=FACE_DETECTION_CONFIG['local_dir']
        )
        print(f" [OK] Antelopev2 downloaded to: {antelope_path}")

        # CPU provider is the default; it can be overridden in config.
        providers = FACE_DETECTION_CONFIG.get('providers', ['CPUExecutionProvider'])
        print(f" Initializing FaceAnalysis with providers: {providers}")

        face_app = _build_face_app(
            FACE_DETECTION_CONFIG['ctx_id'],
            providers=providers
        )

        print(" [OK] Face analysis model loaded successfully")
        print(f" [INFO] Using providers: {providers}")
        return face_app, True

    except Exception as e:
        # Deliberately broad: any failure here (download, runtime, config)
        # should fall through to the second attempt rather than crash.
        print(f" [ERROR] Face analysis loading failed: {e}")
        traceback.print_exc()

    # --- Attempt 2: no explicit providers, ctx_id=0 ----------------------
    try:
        print(" [INFO] Trying fallback with auto-detect providers...")
        face_app = _build_face_app(0)
        print(" [OK] Face analysis loaded with auto-detect providers")
        return face_app, True

    except Exception as e2:
        # Best-effort feature: report and let the caller continue without it.
        print(f" [WARNING] Face detection not available: {e2}")
        print(" [INFO] Generation will continue without face preservation")
        print(" [TIP] Check that onnxruntime is properly installed:")
        print(" pip install onnxruntime --break-system-packages")
        return None, False
|
| 136 |
+
|
| 137 |
|
| 138 |
def load_depth_detector():
|
| 139 |
"""Load Zoe Depth detector with optimized memory management."""
|
|
|
|
| 173 |
|
| 174 |
|
| 175 |
def load_sdxl_pipeline(controlnets):
|
| 176 |
+
"""
|
| 177 |
+
Load SDXL pipeline with InstantID support.
|
| 178 |
+
controlnets MUST be a list: [identitynet, depthnet]
|
| 179 |
+
"""
|
| 180 |
+
print("Loading SDXL checkpoint with InstantID pipeline...")
|
| 181 |
try:
|
| 182 |
model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])
|
| 183 |
|
| 184 |
+
# CRITICAL: Use InstantID-enabled pipeline (not standard ControlNet pipeline)
|
| 185 |
pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
|
| 186 |
model_path,
|
| 187 |
controlnet=controlnets,
|
| 188 |
torch_dtype=dtype,
|
| 189 |
use_safetensors=True
|
| 190 |
).to(device)
|
| 191 |
+
|
| 192 |
+
# Load IP-Adapter weights for InstantID
|
| 193 |
+
print("Loading IP-Adapter for InstantID...")
|
| 194 |
+
ip_adapter_path = download_model_with_retry(
|
| 195 |
+
"InstantX/InstantID",
|
| 196 |
+
"ip-adapter.bin"
|
| 197 |
+
)
|
| 198 |
+
pipe.load_ip_adapter_instantid(ip_adapter_path)
|
| 199 |
+
# Don't set default scale - will be set dynamically based on face detection
|
| 200 |
+
print(" [OK] IP-Adapter loaded (scale will be set dynamically)")
|
| 201 |
+
|
| 202 |
+
print(" [OK] InstantID pipeline loaded successfully")
|
| 203 |
return pipe, True
|
| 204 |
+
|
| 205 |
except Exception as e:
|
| 206 |
+
print(f" [ERROR] Could not load InstantID pipeline: {e}")
|
| 207 |
+
import traceback
|
| 208 |
+
traceback.print_exc()
|
| 209 |
+
|
| 210 |
+
# Fallback to standard pipeline
|
| 211 |
+
print(" [WARNING] Falling back to standard SDXL pipeline (no InstantID)")
|
| 212 |
+
from diffusers import StableDiffusionXLControlNetImg2ImgPipeline
|
| 213 |
+
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
|
| 214 |
"stabilityai/stable-diffusion-xl-base-1.0",
|
| 215 |
controlnet=controlnets,
|
| 216 |
torch_dtype=dtype,
|
|
|
|
| 218 |
).to(device)
|
| 219 |
return pipe, False
|
| 220 |
|
| 221 |
+
|
| 222 |
def load_lora(pipe):
|
| 223 |
"""Load LORA from HuggingFace Hub."""
|
| 224 |
print("Loading LORA (retroart) from HuggingFace Hub...")
|