"""
Model loading and initialization for Pixagram AI Pixel Art Generator
HYBRID VERSION - Supports both local files and HuggingFace repos
"""
import torch
import time
import os
from diffusers import (
ControlNetModel,
AutoencoderKL,
LCMScheduler
)
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import CLIPVisionModelWithProjection
from insightface.app import FaceAnalysis
from controlnet_aux import LeresDetector
from controlnet_aux.processor import Processor
from huggingface_hub import hf_hub_download
from compel import Compel, ReturnedEmbeddingsType
# Import the custom pipeline that has load_ip_adapter_instantid method
from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline
from config import (
device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
FACE_DETECTION_CONFIG, CLIP_SKIP, DOWNLOAD_CONFIG
)
def download_model_with_retry(repo_id, filename, max_retries=None):
"""Download model with retry logic and proper token handling."""
if max_retries is None:
max_retries = DOWNLOAD_CONFIG['max_retries']
for attempt in range(max_retries):
try:
print(f" Attempting to download {filename} (attempt {attempt + 1}/{max_retries})...")
kwargs = {"repo_type": "model"}
if HUGGINGFACE_TOKEN:
kwargs["token"] = HUGGINGFACE_TOKEN
path = hf_hub_download(
repo_id=repo_id,
filename=filename,
**kwargs
)
print(f" [OK] Downloaded: {filename}")
return path
except Exception as e:
print(f" [WARNING] Download attempt {attempt + 1} failed: {e}")
if attempt < max_retries - 1:
print(f" Retrying in {DOWNLOAD_CONFIG['retry_delay']} seconds...")
time.sleep(DOWNLOAD_CONFIG['retry_delay'])
else:
print(f" [ERROR] Failed to download {filename} after {max_retries} attempts")
raise
return None
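# Example usage (sketch; the keys come from config.MODEL_FILES):
#   ckpt_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])
#   lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])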
def load_face_analysis():
"""Load face analysis model with proper error handling."""
print("Loading face analysis model...")
try:
face_app = FaceAnalysis(
name='antelopev2',
root='/data',
providers=['CPUExecutionProvider']
)
face_app.prepare(
ctx_id=0,
det_size=(640, 640)
)
print(" [OK] Face analysis model loaded successfully")
return face_app, True
except Exception as e:
print(f" [WARNING] Face detection not available: {e}")
return None, False
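# Example usage (sketch; `bgr_image` is a hypothetical BGR numpy array,
# e.g. from cv2.imread):
#   face_app, ok = load_face_analysis()
#   faces = face_app.get(bgr_image)
#   if faces:
#       face_emb = faces[0].embedding  # 512-d ArcFace identity embedding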
def load_depth_detector():
    """Load depth detector (LeReS, via controlnet_aux)."""
    print("Loading LeReS depth detector...")
    try:
        # Note: LeresDetector is the LeReS depth estimator (not ZoeDepth,
        # despite the original labels); its maps feed the zoe-depth ControlNet.
        depth_detector = LeresDetector.from_pretrained(
            "lllyasviel/Annotators"
        )
        depth_detector.to(device)
        print(" [OK] LeReS depth detector loaded successfully")
        return depth_detector, True
    except Exception as e:
        print(f" [WARNING] Depth detector not available: {e}")
        return None, False
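# Example usage (sketch; `pil_image` is a hypothetical PIL.Image):
#   depth_detector, ok = load_depth_detector()
#   depth_map = depth_detector(pil_image)  # controlnet_aux detectors return a PIL image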
def load_controlnets():
"""Load ControlNet models."""
print("Loading ControlNet Zoe Depth model...")
controlnet_depth = ControlNetModel.from_pretrained(
"diffusers/controlnet-zoe-depth-sdxl-1.0",
torch_dtype=dtype
).to(device)
print(" [OK] ControlNet Depth loaded")
print("Loading InstantID ControlNet...")
try:
controlnet_instantid = ControlNetModel.from_pretrained(
"InstantX/InstantID",
subfolder="ControlNetModel",
torch_dtype=dtype
).to(device)
print(" [OK] InstantID ControlNet loaded successfully")
return controlnet_depth, controlnet_instantid, True
except Exception as e:
print(f" [WARNING] InstantID ControlNet not available: {e}")
return controlnet_depth, None, False
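# Note: both ControlNets are passed to the pipeline as a list (multi-ControlNet);
# per-net strengths are then chosen at call time (sketch, illustrative scales):
#   pipe(..., controlnet_conditioning_scale=[0.5, 0.8])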
def load_image_encoder():
"""Load CLIP Image Encoder for IP-Adapter."""
print("Loading CLIP Image Encoder for IP-Adapter...")
try:
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"h94/IP-Adapter",
subfolder="models/image_encoder",
torch_dtype=dtype
).to(device)
print(" [OK] CLIP Image Encoder loaded successfully")
return image_encoder
except Exception as e:
print(f" [ERROR] Could not load image encoder: {e}")
return None
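# Note (assumption): InstantID conditions on ArcFace embeddings rather than CLIP
# image features, so this encoder matters only if classic IP-Adapter image
# prompts are also used; it could then be attached via (sketch):
#   pipe.register_modules(image_encoder=image_encoder)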
def load_sdxl_pipeline(controlnets):
"""
Load SDXL checkpoint - HYBRID APPROACH.
Tries in order:
1. Local file via from_single_file (like examplemodels.py)
2. HuggingFace repo via from_pretrained (like exampleapp.py)
3. Fallback to known working checkpoint
"""
print("Loading SDXL checkpoint (hybrid approach)...")
# ATTEMPT 1: Try loading from local file using from_single_file
# This is the examplemodels.py approach
if MODEL_FILES.get('checkpoint'):
try:
print(f" [Attempt 1] Loading from local file via from_single_file...")
model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])
# Check if file exists and is a safetensors file
if model_path and os.path.exists(model_path) and model_path.endswith('.safetensors'):
pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
model_path,
controlnet=controlnets,
torch_dtype=dtype,
use_safetensors=True
).to(device)
print(f" [OK] Checkpoint loaded from local file: {model_path}")
return pipe, True
else:
print(f" [INFO] Local file not found or invalid, trying next method...")
except Exception as e:
print(f" [WARNING] from_single_file failed: {e}")
print(f" [INFO] Trying from_pretrained approach...")
# ATTEMPT 2: Try loading from HuggingFace repo using from_pretrained
# This is the exampleapp.py approach
try:
print(f" [Attempt 2] Loading from HuggingFace repo via from_pretrained...")
pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
MODEL_REPO,
controlnet=controlnets,
torch_dtype=dtype,
use_safetensors=True
).to(device)
print(f" [OK] Checkpoint loaded from HuggingFace repo: {MODEL_REPO}")
return pipe, True
except Exception as e:
print(f" [WARNING] from_pretrained failed: {e}")
print(f" [INFO] Trying fallback checkpoint...")
# ATTEMPT 3: Fallback to known working checkpoint
try:
print(f" [Attempt 3] Loading fallback: frankjoshua/albedobaseXL_v21...")
pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
"frankjoshua/albedobaseXL_v21",
controlnet=controlnets,
torch_dtype=dtype,
use_safetensors=True
).to(device)
print(" [OK] Fallback checkpoint loaded successfully")
return pipe, False
except Exception as e:
print(f" [WARNING] Fallback also failed: {e}")
print(" [INFO] Trying SDXL base model...")
# ATTEMPT 4: Last resort - SDXL base
print(f" [Attempt 4] Loading base SDXL model...")
pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
controlnet=controlnets,
torch_dtype=dtype,
use_safetensors=True
).to(device)
print(" [OK] Base SDXL model loaded")
return pipe, False
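# Example usage (sketch): the boolean flags whether the intended checkpoint
# loaded (True) or a fallback had to be used (False):
#   pipe, custom_ok = load_sdxl_pipeline([controlnet_depth, controlnet_instantid])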
def load_lora(pipe):
"""Load LORA from HuggingFace Hub."""
print("Loading LORA (retroart) from HuggingFace Hub...")
try:
lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
pipe.load_lora_weights(lora_path, adapter_name="retroart")
print(f" [OK] LORA loaded successfully")
return True
except Exception as e:
print(f" [WARNING] Could not load LORA: {e}")
return False
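# Example usage (sketch; the adapter weight is illustrative):
#   if load_lora(pipe):
#       pipe.set_adapters(["retroart"], adapter_weights=[0.8])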
def setup_ip_adapter(pipe):
"""
Setup IP-Adapter for InstantID - SIMPLIFIED VERSION.
Uses pipeline's built-in method (like exampleapp.py lines 139-140).
This is much simpler and more reliable than manual Resampler setup.
"""
print("Setting up IP-Adapter for InstantID face embeddings...")
try:
# Download InstantID IP-Adapter weights
face_adapter_path = download_model_with_retry(
"InstantX/InstantID",
"ip-adapter.bin"
)
# Use the pipeline's built-in method
# This handles all the complex Resampler setup automatically
pipe.load_ip_adapter_instantid(face_adapter_path)
# Set initial scale (can be adjusted later during generation)
pipe.set_ip_adapter_scale(0.8)
print(" [OK] IP-Adapter loaded successfully with built-in method")
print(" - Pipeline handles Resampler and attention processors automatically")
print(" - Face embeddings will be properly integrated during generation")
return True
except Exception as e:
print(f" [ERROR] Could not setup IP-Adapter: {e}")
import traceback
traceback.print_exc()
return False
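# Example usage (sketch): identity strength can be retuned per generation;
# typical InstantID values are ~0.5-0.8. The face embedding extracted by
# insightface is then passed at call time (kwarg name per InstantID pipeline):
#   pipe.set_ip_adapter_scale(0.6)
#   result = pipe(..., image_embeds=face_emb)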
def setup_compel(pipe):
"""Setup Compel for better SDXL prompt handling."""
print("Setting up Compel for enhanced prompt processing...")
try:
compel = Compel(
tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
requires_pooled=[False, True]
)
print(" [OK] Compel loaded successfully")
return compel, True
except Exception as e:
print(f" [WARNING] Compel not available: {e}")
return None, False
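# Example usage (sketch): with requires_pooled=[False, True], Compel returns
# both tensors SDXL needs:
#   conditioning, pooled = compel("pixel art portrait")
#   pipe(prompt_embeds=conditioning, pooled_prompt_embeds=pooled, ...)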
def setup_scheduler(pipe):
"""Setup LCM scheduler."""
print("Setting up LCM scheduler...")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
print(" [OK] LCM scheduler configured")
def optimize_pipeline(pipe):
    """Apply attention optimizations to the pipeline."""
    # Try xformers first; fall back to PyTorch 2.x scaled-dot-product attention
    if device == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
            print(" [OK] xformers enabled")
        except Exception as e:
            print(f" [INFO] xformers not available, using SDPA attention: {e}")
            # SDPA is the diffusers default on PyTorch >= 2.0; set it explicitly
            pipe.unet.set_attn_processor(AttnProcessor2_0())
def load_caption_model():
"""
Load caption model with proper error handling.
Tries multiple models in order of quality.
"""
print("Loading caption model...")
# Try GIT-Large first (good balance of quality and compatibility)
try:
from transformers import AutoProcessor, AutoModelForCausalLM
print(" Attempting GIT-Large (recommended)...")
caption_processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
caption_model = AutoModelForCausalLM.from_pretrained(
"microsoft/git-large-coco",
torch_dtype=dtype
).to(device)
print(" [OK] GIT-Large model loaded (produces detailed captions)")
return caption_processor, caption_model, True, 'git'
except Exception as e1:
print(f" [INFO] GIT-Large not available: {e1}")
# Try BLIP base as fallback
try:
from transformers import BlipProcessor, BlipForConditionalGeneration
print(" Attempting BLIP base (fallback)...")
caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
caption_model = BlipForConditionalGeneration.from_pretrained(
"Salesforce/blip-image-captioning-base",
torch_dtype=dtype
).to(device)
print(" [OK] BLIP base model loaded (standard captions)")
return caption_processor, caption_model, True, 'blip'
except Exception as e2:
print(f" [WARNING] Caption models not available: {e2}")
print(" Caption generation will be disabled")
return None, None, False, 'none'
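# Example usage (sketch; `pil_image` hypothetical, GIT branch shown):
#   processor, model, ok, kind = load_caption_model()
#   if ok and kind == 'git':
#       pixel_values = processor(images=pil_image, return_tensors="pt").pixel_values
#       ids = model.generate(pixel_values=pixel_values.to(device, dtype), max_length=50)
#       caption = processor.batch_decode(ids, skip_special_tokens=True)[0]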
def set_clip_skip(pipe):
    """Record the configured CLIP skip value on the pipeline.

    Note: diffusers SDXL pipelines apply clip skip at call time (via the
    `clip_skip` argument), so this only stores the value for the generation
    code to pass through; the text encoders themselves are not modified.
    """
    if hasattr(pipe, 'text_encoder'):
        pipe.clip_skip = CLIP_SKIP
        print(f" [OK] CLIP skip set to {CLIP_SKIP} (applied at generation time)")
print("[OK] Model loading functions ready (HYBRID VERSION)")