"""
Model loading and initialization for Pixagram AI Pixel Art Generator
"""
import torch
import time
from diffusers import (
StableDiffusionXLControlNetImg2ImgPipeline,
ControlNetModel,
AutoencoderKL,
LCMScheduler
)
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import CLIPVisionModelWithProjection
from transformers import BlipProcessor, BlipForConditionalGeneration
from insightface.app import FaceAnalysis
from controlnet_aux import ZoeDetector
from huggingface_hub import hf_hub_download
from compel import Compel, ReturnedEmbeddingsType
from ip_attention_processor_compatible import IPAttnProcessorCompatible as IPAttnProcessor2_0
from resampler_compatible import create_compatible_resampler as create_enhanced_resampler
from config import (
device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
FACE_DETECTION_CONFIG, CLIP_SKIP, DOWNLOAD_CONFIG
)
def download_model_with_retry(repo_id, filename, max_retries=None):
    """
    Download a model file with retry logic and proper token handling.

    Args:
        repo_id: HuggingFace repository ID
        filename: File to download
        max_retries: Maximum number of retries (uses config default if None)

    Returns:
        Path to the downloaded file
    """
    if max_retries is None:
        max_retries = DOWNLOAD_CONFIG['max_retries']
    for attempt in range(max_retries):
        try:
            print(f" Attempting to download {filename} (attempt {attempt + 1}/{max_retries})...")
            # Use the token only if one is configured
            kwargs = {"repo_type": "model"}
            if HUGGINGFACE_TOKEN:
                kwargs["token"] = HUGGINGFACE_TOKEN
            path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                **kwargs
            )
            print(f" [OK] Downloaded: {filename}")
            return path
        except Exception as e:
            print(f" [WARNING] Download attempt {attempt + 1} failed: {e}")
            if attempt < max_retries - 1:
                print(f" Retrying in {DOWNLOAD_CONFIG['retry_delay']} seconds...")
                time.sleep(DOWNLOAD_CONFIG['retry_delay'])
            else:
                print(f" [ERROR] Failed to download {filename} after {max_retries} attempts")
                raise
    return None  # only reached when max_retries is 0
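
# Usage sketch: this mirrors how the loaders below fetch their weights, e.g.
#
#     ckpt_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])
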
def load_face_analysis():
    """
    Load face analysis model with proper error handling.

    Returns:
        Tuple of (face_app, success_bool)
    """
    print("Loading face analysis model...")
    try:
        face_app = FaceAnalysis(
            name=FACE_DETECTION_CONFIG['model_name'],
            root='./models/insightface',
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
        )
        face_app.prepare(
            ctx_id=FACE_DETECTION_CONFIG['ctx_id'],
            det_size=FACE_DETECTION_CONFIG['det_size']
        )
        print(" [OK] Face analysis model loaded successfully")
        return face_app, True
    except Exception as e:
        print(f" [WARNING] Face detection not available: {e}")
        return None, False
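
# Usage sketch (hypothetical caller): InsightFace expects BGR numpy arrays, and
# each detected face exposes the 512-d ArcFace embedding InstantID consumes.
#
#     import cv2
#     face_app, ok = load_face_analysis()
#     if ok:
#         faces = face_app.get(cv2.imread("portrait.jpg"))
#         embedding = faces[0].embedding if faces else None
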
def load_depth_detector():
    """
    Load Zoe Depth detector.

    Returns:
        Tuple of (zoe_depth, success_bool)
    """
    print("Loading Zoe Depth detector...")
    try:
        zoe_depth = ZoeDetector.from_pretrained("lllyasviel/Annotators")
        zoe_depth.to(device)
        print(" [OK] Zoe Depth loaded successfully")
        return zoe_depth, True
    except Exception as e:
        print(f" [WARNING] Zoe Depth not available: {e}")
        return None, False
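
# Usage sketch (hypothetical caller): the detector is callable on a PIL image
# and returns the depth map as a PIL image for the depth ControlNet.
#
#     from PIL import Image
#     zoe_depth, ok = load_depth_detector()
#     if ok:
#         depth_map = zoe_depth(Image.open("photo.jpg"))
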
def load_controlnets():
    """
    Load ControlNet models.

    Returns:
        Tuple of (controlnet_depth, controlnet_instantid, instantid_success)
    """
    # Load ControlNet for depth
    print("Loading ControlNet Zoe Depth model...")
    controlnet_depth = ControlNetModel.from_pretrained(
        "diffusers/controlnet-zoe-depth-sdxl-1.0",
        torch_dtype=dtype
    ).to(device)
    print(" [OK] ControlNet Depth loaded")
    # Load InstantID ControlNet
    print("Loading InstantID ControlNet...")
    try:
        controlnet_instantid = ControlNetModel.from_pretrained(
            "InstantX/InstantID",
            subfolder="ControlNetModel",
            torch_dtype=dtype
        ).to(device)
        print(" [OK] InstantID ControlNet loaded successfully")
        return controlnet_depth, controlnet_instantid, True
    except Exception as e:
        print(f" [WARNING] InstantID ControlNet not available: {e}")
        return controlnet_depth, None, False
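
# Note: when both ControlNets are handed to the pipeline as a list, the control
# images and conditioning scales passed at inference time must follow the same
# order as the models in that list.
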
def load_image_encoder():
    """
    Load CLIP Image Encoder for IP-Adapter.

    Returns:
        Image encoder or None
    """
    print("Loading CLIP Image Encoder for IP-Adapter...")
    try:
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            "h94/IP-Adapter",
            subfolder="models/image_encoder",
            torch_dtype=dtype
        ).to(device)
        print(" [OK] CLIP Image Encoder loaded successfully")
        return image_encoder
    except Exception as e:
        print(f" [ERROR] Could not load image encoder: {e}")
        return None

def load_sdxl_pipeline(controlnets):
    """
    Load SDXL checkpoint from HuggingFace Hub.

    Args:
        controlnets: ControlNet model(s) to use

    Returns:
        Tuple of (pipeline, checkpoint_loaded_bool)
    """
    print("Loading SDXL checkpoint (horizon) with bundled VAE from HuggingFace Hub...")
    try:
        model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])
        pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
            model_path,
            controlnet=controlnets,
            torch_dtype=dtype,
            use_safetensors=True
        ).to(device)
        print(" [OK] Custom checkpoint loaded successfully (VAE bundled)")
        return pipe, True
    except Exception as e:
        print(f" [WARNING] Could not load custom checkpoint: {e}")
        print(" Using default SDXL base model")
        pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            controlnet=controlnets,
            torch_dtype=dtype,
            use_safetensors=True
        ).to(device)
        return pipe, False
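
# Usage sketch (hypothetical wiring) combining the loaders above:
#
#     depth_cn, instantid_cn, has_instantid = load_controlnets()
#     controlnets = [depth_cn, instantid_cn] if has_instantid else depth_cn
#     pipe, custom_ckpt = load_sdxl_pipeline(controlnets)
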
def load_lora(pipe):
    """
    Load LORA from HuggingFace Hub.

    Args:
        pipe: Pipeline to load LORA into

    Returns:
        Boolean indicating success
    """
    print("Loading LORA (retroart) from HuggingFace Hub...")
    try:
        lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
        pipe.load_lora_weights(lora_path)
        print(" [OK] LORA loaded successfully")
        return True
    except Exception as e:
        print(f" [WARNING] Could not load LORA: {e}")
        return False
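
# Note: diffusers also allows tuning the LORA strength after loading, e.g.
# baking it in once with pipe.fuse_lora(lora_scale=0.8) instead of keeping the
# weights dynamic; treat the 0.8 here as an illustrative value.
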
def setup_ip_adapter(pipe, image_encoder):
    """
    Setup IP-Adapter for InstantID face embeddings.

    Args:
        pipe: Pipeline to setup IP-Adapter on
        image_encoder: CLIP image encoder

    Returns:
        Tuple of (image_proj_model, success_bool)
    """
    if image_encoder is None:
        return None, False
    print("Setting up IP-Adapter for InstantID face embeddings...")
    try:
        # Download InstantID IP-Adapter weights
        ip_adapter_path = download_model_with_retry(
            "InstantX/InstantID",
            "ip-adapter.bin"
        )
        # Load IP-Adapter state dict
        ip_adapter_state_dict = torch.load(ip_adapter_path, map_location="cpu")
        # Separate image projection and IP-adapter weights
        image_proj_state_dict = {}
        ip_state_dict = {}
        for key, value in ip_adapter_state_dict.items():
            if key.startswith("image_proj."):
                image_proj_state_dict[key.replace("image_proj.", "")] = value
            elif key.startswith("ip_adapter."):
                ip_state_dict[key.replace("ip_adapter.", "")] = value
        print("Setting up Enhanced Perceiver Resampler for face embedding refinement...")
        # Create enhanced resampler
        image_proj_model = create_enhanced_resampler(
            quality_mode='quality',
            num_queries=4,
            output_dim=pipe.unet.config.cross_attention_dim,
            device=device,
            dtype=dtype
        )
        # Try to load pretrained Resampler weights if available
        try:
            if 'latents' in image_proj_state_dict:
                image_proj_model.load_state_dict(image_proj_state_dict, strict=True)
                print(" [OK] Resampler loaded with pretrained weights")
            else:
                print(" [INFO] No pretrained Resampler weights found")
                print(" Using randomly initialized Resampler")
                print(" Expected +8-10% face similarity improvement")
        except Exception as e:
            print(f" [INFO] Resampler initialization: {e}")
            print(" Using randomly initialized Resampler")
        # Set up IP-Adapter attention processors: self-attention layers
        # (attn1) keep the default processor, cross-attention layers get the
        # IP-Adapter variant.
        attn_procs = {}
        for name in pipe.unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else pipe.unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = pipe.unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(pipe.unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = pipe.unet.config.block_out_channels[block_id]
            else:
                # Unexpected layer name; skip rather than reference an unbound hidden_size
                continue
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor2_0()
            else:
                attn_procs[name] = IPAttnProcessor2_0(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                    num_tokens=4
                ).to(device, dtype=dtype)
        pipe.unet.set_attn_processor(attn_procs)
        # Load IP-adapter weights into the attention processors (strict=False
        # tolerates the plain AttnProcessor2_0 entries, which have no weights)
        ip_layers = torch.nn.ModuleList(pipe.unet.attn_processors.values())
        ip_layers.load_state_dict(ip_state_dict, strict=False)
        print(" [OK] IP-Adapter attention processors loaded")
        # Store the image encoder
        pipe.image_encoder = image_encoder
        print(" [OK] IP-Adapter fully loaded with InstantID weights")
        return image_proj_model, True
    except Exception as e:
        print(f" [ERROR] Could not load IP-Adapter: {e}")
        print(" InstantID will work with keypoints only (no face embeddings)")
        import traceback
        traceback.print_exc()
        return None, False
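
# Note: given the configuration above, the resampler maps the InsightFace face
# embedding (512-d for ArcFace models) to num_queries=4 tokens of width
# cross_attention_dim, matching the num_tokens=4 the IP-Adapter attention
# processors are constructed with. The exact input width is an assumption here,
# since create_compatible_resampler() lives in a sibling module.
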
def setup_compel(pipe):
    """
    Setup Compel for better SDXL prompt handling.

    Args:
        pipe: Pipeline to setup Compel on

    Returns:
        Tuple of (compel, success_bool)
    """
    print("Setting up Compel for enhanced prompt processing...")
    try:
        compel = Compel(
            tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
            text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True]
        )
        print(" [OK] Compel loaded successfully")
        return compel, True
    except Exception as e:
        print(f" [WARNING] Compel not available: {e}")
        return None, False
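
# Usage sketch (hypothetical caller): with requires_pooled=[False, True], Compel
# returns both the token embeddings and the pooled embeddings that SDXL expects,
# and supports weighting syntax such as "pixel art++".
#
#     conditioning, pooled = compel("pixel art++ portrait of a knight")
#     image = pipe(prompt_embeds=conditioning, pooled_prompt_embeds=pooled, ...)
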
def setup_scheduler(pipe):
    """
    Setup LCM scheduler.

    Args:
        pipe: Pipeline to setup scheduler on
    """
    print("Setting up LCM scheduler...")
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    print(" [OK] LCM scheduler configured")
def optimize_pipeline(pipe):
    """
    Apply optimizations to pipeline.

    Args:
        pipe: Pipeline to optimize
    """
    # Enable attention optimizations. Note: this replaces ALL attention
    # processors, so it should run before setup_ip_adapter(), which installs
    # its own processors on the UNet.
    pipe.unet.set_attn_processor(AttnProcessor2_0())
    # Try to enable xformers
    if device == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
            print(" [OK] xformers enabled")
        except Exception as e:
            print(f" [INFO] xformers not available: {e}")

def load_caption_model():
    """
    Load BLIP model for optional caption generation.

    Returns:
        Tuple of (processor, model, success_bool)
    """
    print("Loading BLIP model for optional caption generation...")
    try:
        caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        caption_model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base",
            torch_dtype=dtype
        ).to(device)
        print(" [OK] BLIP model loaded successfully")
        return caption_processor, caption_model, True
    except Exception as e:
        print(f" [WARNING] BLIP model not available: {e}")
        print(" Caption generation will be disabled")
        return None, None, False
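
# Usage sketch (hypothetical caller), following the standard BLIP API:
#
#     inputs = caption_processor(image, return_tensors="pt").to(device, dtype)
#     out = caption_model.generate(**inputs, max_new_tokens=30)
#     caption = caption_processor.decode(out[0], skip_special_tokens=True)
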
def set_clip_skip(pipe):
    """
    Report the configured CLIP skip value.

    Note: this helper only logs CLIP_SKIP; the skip itself is expected to take
    effect at prompt-encoding time (see the penultimate-hidden-states setting
    in setup_compel()).

    Args:
        pipe: Pipeline the CLIP skip applies to
    """
    if hasattr(pipe, 'text_encoder'):
        print(f" [OK] CLIP skip set to {CLIP_SKIP}")

print("[OK] Model loading functions ready")