# pixagram-neo-backup / models.py
"""
Model loading and initialization for Pixagram AI Pixel Art Generator
FIXED VERSION with proper IP-Adapter and caption-model (GIT/BLIP) support
"""
import torch
import time
from diffusers import (
StableDiffusionXLControlNetImg2ImgPipeline,
ControlNetModel,
LCMScheduler
)
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import (
CLIPVisionModelWithProjection, CLIPTokenizer,
CLIPTextModel, CLIPTextModelWithProjection
)
from insightface.app import FaceAnalysis
from controlnet_aux import ZoeDetector, OpenposeDetector, LeresDetector, MidasDetector, MediapipeFaceDetector
from huggingface_hub import hf_hub_download, snapshot_download
# --- START FIX: Import Compel ---
from compel import Compel, ReturnedEmbeddingsType
# --- END FIX ---
# Use reference implementation's attention processor
from attention_processor import IPAttnProcessor2_0
from resampler import Resampler
from config import (
device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
FACE_DETECTION_CONFIG, CLIP_SKIP, DOWNLOAD_CONFIG
)
def download_model_with_retry(repo_id, filename, max_retries=None, **kwargs):
"""Download model with retry logic and proper token handling."""
if max_retries is None:
max_retries = DOWNLOAD_CONFIG['max_retries']
# Ensure token is passed if available
if HUGGINGFACE_TOKEN and "token" not in kwargs:
kwargs["token"] = HUGGINGFACE_TOKEN
for attempt in range(max_retries):
try:
print(f" Attempting to download {filename} (attempt {attempt + 1}/{max_retries})...")
return hf_hub_download(
repo_id=repo_id,
filename=filename,
**kwargs
)
except Exception as e:
print(f" [WARNING] Download attempt {attempt + 1} failed: {e}")
if attempt < max_retries - 1:
print(f" Retrying in {DOWNLOAD_CONFIG['retry_delay']} seconds...")
time.sleep(DOWNLOAD_CONFIG['retry_delay'])
else:
print(f" [ERROR] Failed to download {filename} after {max_retries} attempts")
raise
def load_face_analysis():
"""
Load face analysis model with proper model downloading from HuggingFace.
Downloads from DIAMONIK7777/antelopev2 which has the correct model structure.
"""
print("Loading face analysis model...")
try:
        # Download into the path FaceAnalysis(root='/data') resolves to: /data/models/antelopev2
        snapshot_download(repo_id="DIAMONIK7777/antelopev2", local_dir="/data/models/antelopev2")
# --- FIX: Load InsightFace on CPU to save VRAM ---
face_app = FaceAnalysis(name='antelopev2', root='/data', providers=['CPUExecutionProvider'])
face_app.prepare(ctx_id=0, det_size=(640, 640))
print(" [OK] Face analysis loaded (on CPU)")
return face_app, True
except Exception as e:
print(f" [ERROR] Face detection not available: {e}")
import traceback
traceback.print_exc()
return None, False
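# Usage sketch (assumption, not called anywhere in this file): extract the
# 512-D InsightFace embedding that InstantID consumes. Assumes a BGR numpy
# image, e.g. cv2.imread(path) or np.array(pil_img)[:, :, ::-1].
def _example_face_embedding(face_app, bgr_image):
    """Sketch: return the largest detected face's 512-D embedding, or None."""
    faces = face_app.get(bgr_image)
    if not faces:
        return None
    # Pick the largest face by bounding-box area.
    largest = max(faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]))
    return largest.embedding  # numpy array of shape (512,)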
def load_depth_detector():
"""
Load depth detector with fallback hierarchy: Leres → Zoe → Midas.
Returns (detector, detector_type, success).
"""
print("Loading depth detector with fallback hierarchy...")
# Try LeresDetector first (best quality)
try:
print(" Attempting LeresDetector (highest quality)...")
# --- FIX: Load on CPU ---
leres_depth = LeresDetector.from_pretrained("lllyasviel/Annotators")
# leres_depth.to(device) # Removed
print(" [OK] LeresDetector loaded successfully (on CPU)")
return leres_depth, 'leres', True
except Exception as e:
print(f" [INFO] LeresDetector not available: {e}")
# Fallback to ZoeDetector
try:
print(" Attempting ZoeDetector (fallback #1)...")
# --- FIX: Load on CPU ---
zoe_depth = ZoeDetector.from_pretrained("lllyasviel/Annotators")
# zoe_depth.to(device) # Removed
print(" [OK] ZoeDetector loaded successfully (on CPU)")
return zoe_depth, 'zoe', True
except Exception as e:
print(f" [INFO] ZoeDetector not available: {e}")
# Final fallback to MidasDetector
try:
print(" Attempting MidasDetector (fallback #2)...")
# --- FIX: Load on CPU ---
midas_depth = MidasDetector.from_pretrained("lllyasviel/Annotators")
# midas_depth.to(device) # Removed
print(" [OK] MidasDetector loaded successfully (on CPU)")
return midas_depth, 'midas', True
except Exception as e:
print(f" [WARNING] MidasDetector not available: {e}")
print(" [ERROR] No depth detector available")
return None, None, False
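# Usage sketch (assumption): controlnet_aux detectors are callable on PIL
# images and return a PIL depth map usable as a ControlNet conditioning image.
def _example_depth_map(detector, pil_image):
    """Sketch: produce a depth conditioning image from whichever detector loaded."""
    # detect_resolution / image_resolution are standard controlnet_aux kwargs.
    return detector(pil_image, detect_resolution=512, image_resolution=1024)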
# --- NEW FUNCTION ---
def load_openpose_detector():
"""Load OpenPose detector."""
print("Loading OpenPose detector...")
try:
# --- FIX: Load on CPU ---
openpose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
# openpose.to(device) # Removed
print(" [OK] OpenPose loaded successfully (on CPU)")
return openpose, True
except Exception as e:
print(f" [WARNING] OpenPose not available: {e}")
return None, False
# --- END NEW FUNCTION ---
# --- NEW FUNCTION ---
def load_mediapipe_face_detector():
"""Load MediapipeFaceDetector for advanced face detection."""
print("Loading MediapipeFaceDetector...")
try:
face_detector = MediapipeFaceDetector()
print(" [OK] MediapipeFaceDetector loaded successfully")
return face_detector, True
except Exception as e:
print(f" [WARNING] MediapipeFaceDetector not available: {e}")
return None, False
# --- END NEW FUNCTION ---
def load_controlnets():
"""Load ControlNet models."""
print("Loading ControlNet Zoe Depth model...")
# --- FIX: Load core models on GPU ---
controlnet_depth = ControlNetModel.from_pretrained(
"xinsir/controlnet-depth-sdxl-1.0",
torch_dtype=dtype
).to(device)
print(" [OK] ControlNet Depth loaded (on GPU)")
# --- NEW: Load OpenPose ControlNet ---
print("Loading ControlNet OpenPose model...")
try:
# --- FIX: Load core models on GPU ---
controlnet_openpose = ControlNetModel.from_pretrained(
"xinsir/controlnet-openpose-sdxl-1.0",
torch_dtype=dtype
).to(device)
print(" [OK] ControlNet OpenPose loaded (on GPU)")
except Exception as e:
print(f" [WARNING] ControlNet OpenPose not available: {e}")
controlnet_openpose = None
# --- END NEW ---
print("Loading InstantID ControlNet...")
try:
# --- FIX: Load core models on GPU ---
controlnet_instantid = ControlNetModel.from_pretrained(
"InstantX/InstantID",
subfolder="ControlNetModel",
torch_dtype=dtype
).to(device)
print(" [OK] InstantID ControlNet loaded successfully (on GPU)")
# Return all three models
return controlnet_depth, controlnet_instantid, controlnet_openpose, True
except Exception as e:
print(f" [WARNING] InstantID ControlNet not available: {e}")
# Return models, indicating InstantID failure
return controlnet_depth, None, controlnet_openpose, False
def load_image_encoder():
"""Load CLIP Image Encoder for IP-Adapter."""
print("Loading CLIP Image Encoder for IP-Adapter...")
try:
# --- FIX: Load core models on GPU ---
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"h94/IP-Adapter",
subfolder="models/image_encoder",
torch_dtype=dtype
).to(device)
print(" [OK] CLIP Image Encoder loaded successfully (on GPU)")
return image_encoder
except Exception as e:
print(f" [ERROR] Could not load image encoder: {e}")
return None
def load_sdxl_pipeline(controlnets):
"""Load SDXL checkpoint from HuggingFace Hub."""
print("Loading SDXL checkpoint (horizon) with bundled VAE from HuggingFace Hub...")
# --- START FIX ---
# Load tokenizers and text encoders from the base model first
# This guarantees they exist, even if the single file doesn't have them
print(" Loading base tokenizers and text encoders...")
BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
try:
tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer")
tokenizer_2 = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer_2")
text_encoder = CLIPTextModel.from_pretrained(
BASE_MODEL, subfolder="text_encoder", torch_dtype=dtype
).to(device)
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
BASE_MODEL, subfolder="text_encoder_2", torch_dtype=dtype
).to(device)
print(" [OK] Base text/token models loaded")
except Exception as e:
print(f" [ERROR] Could not load base text models: {e}")
print(" Pipeline will likely fail. Check HF connection/model access.")
# Allow it to continue, but it will likely fail below
tokenizer = None
tokenizer_2 = None
text_encoder = None
text_encoder_2 = None
# --- END FIX ---
try:
model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'], repo_type="model")
# --- START FIX ---
# Pass the pre-loaded models to from_single_file
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
model_path,
controlnet=controlnets,
torch_dtype=dtype,
use_safetensors=True,
# Explicitly provide the models
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
).to(device) # This main pipe MUST be on device
# --- END FIX ---
print(" [OK] Custom checkpoint loaded successfully (VAE bundled)")
return pipe, True
except Exception as e:
print(f" [WARNING] Could not load custom checkpoint: {e}")
print(" Using default SDXL base model")
# The fallback logic is already correct
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
controlnet=controlnets,
torch_dtype=dtype,
use_safetensors=True
).to(device) # This main pipe MUST be on device
return pipe, False
def load_loras(pipe):
"""Load all LORAs from HuggingFace Hub."""
print("Loading all LORAs from HuggingFace Hub...")
loaded_loras = {}
lora_files = {
"retroart": MODEL_FILES.get("lora_retroart"),
"vga": MODEL_FILES.get("lora_vga"),
"lucasart": MODEL_FILES.get("lora_lucasart")
}
for adapter_name, filename in lora_files.items():
if not filename:
print(f" [INFO] No file specified for LORA '{adapter_name}', skipping.")
loaded_loras[adapter_name] = False
continue
try:
lora_path = download_model_with_retry(MODEL_REPO, filename, repo_type="model")
pipe.load_lora_weights(lora_path, adapter_name=adapter_name)
print(f" [OK] LORA loaded successfully: {filename} as '{adapter_name}'")
loaded_loras[adapter_name] = True
except Exception as e:
print(f" [WARNING] Could not load LORA {filename}: {e}")
loaded_loras[adapter_name] = False
success = any(loaded_loras.values())
if not success:
print(" [WARNING] No LORAs were loaded successfully.")
return loaded_loras, success
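# Usage sketch (assumption; requires diffusers' PEFT-backed LoRA support):
# activate one style adapter at a time from the dict returned above.
def _example_activate_lora(pipe, loaded_loras, name="retroart", weight=0.8):
    """Sketch: enable a single named LORA adapter if it loaded successfully."""
    if loaded_loras.get(name):
        pipe.set_adapters([name], adapter_weights=[weight])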
def setup_ip_adapter(pipe, image_encoder):
"""
Setup IP-Adapter for InstantID face embeddings.
This is CRITICAL for face preservation.
"""
if image_encoder is None:
return None, False
print("Setting up IP-Adapter for InstantID face embeddings...")
try:
# Download InstantID weights
ip_adapter_path = download_model_with_retry(
"InstantX/InstantID",
"ip-adapter.bin",
repo_type="model"
)
# Load full state dict
state_dict = torch.load(ip_adapter_path, map_location="cpu")
        # InstantID ships a nested dict: {"image_proj": {...}, "ip_adapter": {...}}
        # (its reference loader indexes those keys directly). Support flat
        # dotted keys too, in case of a re-exported checkpoint.
        if "image_proj" in state_dict and "ip_adapter" in state_dict:
            image_proj_state_dict = state_dict["image_proj"]
            ip_adapter_state_dict = state_dict["ip_adapter"]
        else:
            image_proj_state_dict = {}
            ip_adapter_state_dict = {}
            for key, value in state_dict.items():
                if key.startswith("image_proj."):
                    image_proj_state_dict[key.replace("image_proj.", "")] = value
                elif key.startswith("ip_adapter."):
                    ip_adapter_state_dict[key.replace("ip_adapter.", "")] = value
# Create Resampler with CORRECT parameters
print("Creating Resampler (Perceiver architecture)...")
image_proj_model = Resampler(
dim=1280,
depth=4,
dim_head=64,
heads=20,
num_queries=16,
embedding_dim=512, # CRITICAL: Must match InsightFace embedding size
output_dim=pipe.unet.config.cross_attention_dim,
ff_mult=4
)
image_proj_model.eval()
image_proj_model = image_proj_model.to(device, dtype=dtype)
# Load image_proj weights
if image_proj_state_dict:
try:
image_proj_model.load_state_dict(image_proj_state_dict, strict=True)
print(" [OK] Resampler loaded with pretrained weights")
except Exception as e:
print(f" [WARNING] Could not load Resampler weights: {e}")
# Setup IP-Adapter attention processors
print("Setting up IP-Adapter attention processors...")
attn_procs = {}
num_tokens = 16
for name in pipe.unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else pipe.unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = pipe.unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(pipe.unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = pipe.unet.config.block_out_channels[block_id]
else:
hidden_size = pipe.unet.config.block_out_channels[-1]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor2_0()
else:
attn_procs[name] = IPAttnProcessor2_0(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=num_tokens
).to(device, dtype=dtype)
# Set attention processors
pipe.unet.set_attn_processor(attn_procs)
# Load IP-Adapter weights
if ip_adapter_state_dict:
try:
ip_layers = torch.nn.ModuleList(pipe.unet.attn_processors.values())
ip_layers.load_state_dict(ip_adapter_state_dict, strict=False)
print(" [OK] IP-Adapter attention weights loaded")
except Exception as e:
print(f" [WARNING] Could not load IP-Adapter weights: {e}")
# Store image encoder
pipe.image_encoder = image_encoder
print(" [OK] IP-Adapter fully loaded with InstantID architecture")
print(f" - Resampler: 4 layers, 20 heads, 16 output tokens")
print(f" - Face embeddings: 512D -> 16x{pipe.unet.config.cross_attention_dim}D")
return image_proj_model, True
except Exception as e:
print(f" [ERROR] Could not setup IP-Adapter: {e}")
import traceback
traceback.print_exc()
return None, False
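# Usage sketch (assumption): project a 512-D InsightFace embedding into the
# 16 cross-attention tokens the IP-Adapter processors consume. Shapes follow
# the Resampler config above.
def _example_project_face_embedding(image_proj_model, face_embedding):
    """Sketch: (512,) numpy embedding -> (1, 16, cross_attention_dim) tokens."""
    emb = torch.from_numpy(face_embedding).to(device, dtype=dtype)
    emb = emb.reshape(1, 1, 512)  # (batch, sequence, embedding_dim)
    with torch.no_grad():
        return image_proj_model(emb)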
# --- START FIX: Replace setup_cappella with setup_compel ---
def setup_compel(pipe):
"""Setup Compel for prompt encoding."""
print("Setting up Compel (prompt encoder)...")
try:
compel = Compel(
tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
requires_pooled=[False, True]
)
print(" [OK] Compel loaded successfully.")
return compel, True
except Exception as e:
print(f" [WARNING] Compel not available: {e}")
import traceback
traceback.print_exc()
return None, False
# --- END FIX ---
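# Usage sketch (assumption): with requires_pooled=[False, True], Compel returns
# (prompt_embeds, pooled_prompt_embeds) for SDXL's dual text encoders.
def _example_encode_prompt(compel):
    """Sketch: encode weighted prompts for the SDXL pipeline call."""
    conditioning, pooled = compel("pixel art portrait, (detailed)1.2")
    negative, negative_pooled = compel("blurry, low quality")
    # Pass these instead of raw `prompt` / `negative_prompt` strings:
    return dict(
        prompt_embeds=conditioning,
        pooled_prompt_embeds=pooled,
        negative_prompt_embeds=negative,
        negative_pooled_prompt_embeds=negative_pooled,
    )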
def setup_scheduler(pipe):
"""Setup LCM scheduler."""
print("Setting up LCM scheduler...")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
print(" [OK] LCM scheduler configured")
def optimize_pipeline(pipe):
"""Apply optimizations to pipeline."""
    # --- FIX: Removed enable_model_cpu_offload() ---
    # Try to enable xformers.
    # NOTE: enable_xformers_memory_efficient_attention() swaps out the UNet's
    # attention processors, so call this BEFORE setup_ip_adapter() (or skip it)
    # so the IP-Adapter processors installed there are not clobbered.
if device == "cuda":
try:
pipe.enable_xformers_memory_efficient_attention()
print(" [OK] xformers enabled")
except Exception as e:
print(f" [INFO] xformers not available: {e}")
def load_caption_model():
"""
Load caption model with proper error handling.
Tries multiple models in order of quality.
"""
print("Loading caption model...")
# Try GIT-Large first (good balance of quality and compatibility)
try:
from transformers import AutoProcessor, AutoModelForCausalLM
print(" Attempting GIT-Large (recommended)...")
caption_processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
# --- FIX: Load on CPU ---
caption_model = AutoModelForCausalLM.from_pretrained(
"microsoft/git-large-coco",
torch_dtype=dtype
) # .to(device) removed
print(" [OK] GIT-Large model loaded (produces detailed captions, on CPU)")
return caption_processor, caption_model, True, 'git'
except Exception as e1:
print(f" [INFO] GIT-Large not available: {e1}")
# Try BLIP base as fallback
try:
from transformers import BlipProcessor, BlipForConditionalGeneration
print(" Attempting BLIP base (fallback)...")
caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
# --- FIX: Load on CPU ---
caption_model = BlipForConditionalGeneration.from_pretrained(
"Salesforce/blip-image-captioning-base",
torch_dtype=dtype
) # .to(device) removed
print(" [OK] BLIP base model loaded (standard captions, on CPU)")
return caption_processor, caption_model, True, 'blip'
except Exception as e2:
print(f" [WARNING] Caption models not available: {e2}")
print(" Caption generation will be disabled")
return None, None, False, 'none'
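# Usage sketch (assumption) for the GIT path; the BLIP path is analogous via
# its own processor. pixel_values are cast to the model dtype because the
# model above may be loaded in fp16 (fp16 generate on CPU can be slow or
# unsupported; float32 is the safe CPU choice).
def _example_caption(caption_processor, caption_model, pil_image):
    """Sketch: generate a caption with GIT as loaded above."""
    inputs = caption_processor(images=pil_image, return_tensors="pt")
    pixel_values = inputs.pixel_values.to(caption_model.dtype)
    ids = caption_model.generate(pixel_values=pixel_values, max_length=50)
    return caption_processor.batch_decode(ids, skip_special_tokens=True)[0]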
def set_clip_skip(pipe):
    """Log the configured CLIP skip value.

    Nothing is mutated here: CLIP skip takes effect at prompt-encoding time
    (e.g. Compel's penultimate-hidden-states setting above, or the pipeline's
    `clip_skip` call argument).
    """
    if hasattr(pipe, 'text_encoder'):
        print(f" [OK] CLIP skip set to {CLIP_SKIP}")
print("[OK] Model loading functions ready")