# pixagram-neo-backup / models.py
# (HuggingFace Hub file-viewer residue preserved as comments:
#  uploaded by primerz, "Upload 11 files", commit f179fb3 verified, 13.6 kB)
"""
Model loading and initialization for Pixagram AI Pixel Art Generator
Torch 2.1.1 optimized with Depth Anything V2
"""
import torch
import time
from diffusers import (
StableDiffusionXLControlNetImg2ImgPipeline,
ControlNetModel,
AutoencoderKL,
LCMScheduler
)
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import CLIPVisionModelWithProjection
from transformers import BlipProcessor, BlipForConditionalGeneration
from insightface.app import FaceAnalysis
from controlnet_aux import ZoeDetector
from huggingface_hub import hf_hub_download
from compel import Compel, ReturnedEmbeddingsType
from ip_attention_processor_compatible import IPAttnProcessorCompatible as IPAttnProcessor2_0
from resampler_compatible import create_compatible_resampler
from config import (
device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
FACE_DETECTION_CONFIG, CLIP_SKIP, DOWNLOAD_CONFIG
)
def download_model_with_retry(repo_id, filename, max_retries=None):
    """Download a file from the HuggingFace Hub with retry logic.

    Args:
        repo_id: Hub repository id (e.g. "InstantX/InstantID").
        filename: Path of the file within the repository.
        max_retries: Number of attempts; defaults to DOWNLOAD_CONFIG['max_retries'].

    Returns:
        Local filesystem path of the cached/downloaded file.

    Raises:
        Exception: re-raises the last download error once all attempts fail.
    """
    if max_retries is None:
        max_retries = DOWNLOAD_CONFIG['max_retries']
    for attempt in range(max_retries):
        try:
            # Fix: log the actual filename (was garbled to a literal "(unknown)").
            print(f" Attempting to download {filename} (attempt {attempt + 1}/{max_retries})...")
            kwargs = {"repo_type": "model"}
            # Only forward a token when one is configured; an absent token
            # means anonymous access for public repos.
            if HUGGINGFACE_TOKEN:
                kwargs["token"] = HUGGINGFACE_TOKEN
            path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                **kwargs
            )
            print(f" [OK] Downloaded: {filename}")
            return path
        except Exception as e:
            print(f" [WARNING] Download attempt {attempt + 1} failed: {e}")
            if attempt < max_retries - 1:
                print(f" Retrying in {DOWNLOAD_CONFIG['retry_delay']} seconds...")
                time.sleep(DOWNLOAD_CONFIG['retry_delay'])
            else:
                print(f" [ERROR] Failed to download {filename} after {max_retries} attempts")
                raise
    return None  # Unreachable in practice: the loop either returns or raises.
def load_face_analysis():
    """
    Load the InsightFace analysis model, preferring GPU with a CPU fallback.

    InsightFace initialization frequently fails on GPU, so the CPU retry is
    essential for availability.

    Returns:
        tuple: (FaceAnalysis instance or None, bool success flag)
    """
    print("Loading face analysis model...")

    def _build(providers, ctx_id):
        # Shared construction + prepare for either execution provider set.
        app = FaceAnalysis(
            name=FACE_DETECTION_CONFIG['model_name'],
            root='./models/insightface',
            providers=providers
        )
        app.prepare(ctx_id=ctx_id, det_size=FACE_DETECTION_CONFIG['det_size'])
        return app

    # First choice: CUDA provider (with CPU listed as ONNX-level fallback).
    try:
        app = _build(
            ['CUDAExecutionProvider', 'CPUExecutionProvider'],
            FACE_DETECTION_CONFIG['ctx_id']
        )
        print(" [OK] Face analysis loaded (GPU)")
        return app, True
    except Exception as gpu_err:
        print(f" [WARNING] GPU face detection failed: {gpu_err}")

    # Second choice: pure CPU (ctx_id=-1 selects the CPU context).
    try:
        print(" [INFO] Trying CPU fallback...")
        app = _build(['CPUExecutionProvider'], -1)
        print(" [OK] Face analysis loaded (CPU fallback)")
        return app, True
    except Exception as cpu_err:
        print(f" [ERROR] Face detection not available: {cpu_err}")
        import traceback
        traceback.print_exc()
        return None, False
def load_depth_anything_v2():
    """
    Load the Depth Anything V2 (Small) depth-estimation pipeline.

    Faster than Zoe with sharper detail; the Small variant is Apache 2.0.

    Returns:
        tuple: (transformers pipeline or None, bool success flag)
    """
    print("Loading Depth Anything V2 (3-5x faster than Zoe)...")
    try:
        from transformers import pipeline
        # transformers convention: device 0 = first CUDA GPU, -1 = CPU.
        target_device = 0 if device == "cuda" else -1
        depth_pipe = pipeline(
            task="depth-estimation",
            model="depth-anything/Depth-Anything-V2-Small",
            device=target_device
        )
        print(" [OK] Depth Anything V2 loaded (state-of-the-art quality)")
        return depth_pipe, True
    except Exception as e:
        print(f" [WARNING] Depth Anything V2 not available: {e}")
        return None, False
def load_depth_detector():
    """
    Load a depth detector, trying backends in order of preference.

    Fallback chain:
        1. Depth Anything V2 (fastest, best quality)
        2. Zoe Depth
        3. "grayscale" marker (emergency fallback, no model loaded)

    Returns:
        tuple: (detector or None, bool success flag, backend name string)
    """
    detector, ok = load_depth_anything_v2()
    if ok:
        return detector, True, "depth_anything_v2"

    print("Loading Zoe Depth detector (fallback)...")
    try:
        zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
        zoe.to(device)
        print(" [OK] Zoe Depth loaded")
        return zoe, True, "zoe"
    except Exception as e:
        print(f" [WARNING] Zoe Depth not available: {e}")
        return None, False, "grayscale"
def load_controlnets():
    """
    Load the depth ControlNet and (best-effort) the InstantID ControlNet.

    Returns:
        tuple: (depth ControlNet, InstantID ControlNet or None,
                bool InstantID success flag)
    """
    print("Loading ControlNet Zoe Depth model...")
    depth_cn = ControlNetModel.from_pretrained(
        "diffusers/controlnet-zoe-depth-sdxl-1.0",
        torch_dtype=dtype
    )
    depth_cn = depth_cn.to(device)
    print(" [OK] ControlNet Depth loaded")

    print("Loading InstantID ControlNet...")
    try:
        instantid_cn = ControlNetModel.from_pretrained(
            "InstantX/InstantID",
            subfolder="ControlNetModel",
            torch_dtype=dtype
        )
        instantid_cn = instantid_cn.to(device)
        print(" [OK] InstantID ControlNet loaded")
        return depth_cn, instantid_cn, True
    except Exception as e:
        print(f" [WARNING] InstantID ControlNet not available: {e}")
        return depth_cn, None, False
def load_image_encoder():
    """
    Load the CLIP vision encoder used by the IP-Adapter.

    Returns:
        CLIPVisionModelWithProjection moved to `device`, or None on failure.
    """
    print("Loading CLIP Image Encoder...")
    try:
        encoder = CLIPVisionModelWithProjection.from_pretrained(
            "h94/IP-Adapter",
            subfolder="models/image_encoder",
            torch_dtype=dtype
        )
        encoder = encoder.to(device)
        print(" [OK] CLIP Image Encoder loaded")
        return encoder
    except Exception as e:
        print(f" [ERROR] Could not load image encoder: {e}")
        return None
def load_sdxl_pipeline(controlnets):
    """
    Build the SDXL ControlNet img2img pipeline.

    Tries the custom "horizon" single-file checkpoint first; falls back to
    the stock SDXL base model if it cannot be loaded.

    Args:
        controlnets: ControlNet model(s) to attach to the pipeline.

    Returns:
        tuple: (pipeline, bool: True when the custom checkpoint was used)
    """
    print("Loading SDXL checkpoint (horizon) from HuggingFace Hub...")
    try:
        ckpt_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])
        custom_pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
            ckpt_path,
            controlnet=controlnets,
            torch_dtype=dtype,
            use_safetensors=True
        ).to(device)
        print(" [OK] Custom checkpoint loaded")
        return custom_pipe, True
    except Exception as e:
        print(f" [WARNING] Could not load custom checkpoint: {e}")
        print(" Using default SDXL base")

    # Fallback: stock SDXL base weights from the Hub.
    base_pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        controlnet=controlnets,
        torch_dtype=dtype,
        use_safetensors=True
    ).to(device)
    return base_pipe, False
def load_lora(pipe):
    """Load the "retroart" LORA weights into the pipeline.

    Args:
        pipe: diffusers pipeline exposing ``load_lora_weights``.

    Returns:
        bool: True if the LORA was downloaded and applied, False otherwise.
    """
    print("Loading LORA (retroart) from HuggingFace Hub...")
    try:
        lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
        pipe.load_lora_weights(lora_path)
        # Fix: was an f-string with no placeholders (lint F541).
        print(" [OK] LORA loaded")
        return True
    except Exception as e:
        print(f" [WARNING] Could not load LORA: {e}")
        return False
def setup_ip_adapter(pipe, image_encoder):
    """Setup IP-Adapter with compatible architecture.

    Downloads the InstantID ip-adapter checkpoint, splits it into the image
    projection (resampler) part and the per-layer attention part, builds a
    compatible resampler, and replaces the UNet's attention processors with
    IP-aware ones before loading the InstantID weights into them.

    Args:
        pipe: SDXL pipeline whose UNet receives the new attention processors;
            also gains an ``image_encoder`` attribute on success.
        image_encoder: CLIP vision encoder; if None, setup is skipped.

    Returns:
        tuple: (image projection model or None, bool success flag)
    """
    if image_encoder is None:
        return None, False
    print("Setting up IP-Adapter...")
    try:
        ip_adapter_path = download_model_with_retry("InstantX/InstantID", "ip-adapter.bin")
        # NOTE(review): torch.load unpickles arbitrary data; acceptable only
        # because the InstantX/InstantID repo is trusted.
        ip_adapter_state_dict = torch.load(ip_adapter_path, map_location="cpu")
        # Split the flat checkpoint by key prefix into the two sub-state-dicts.
        image_proj_state_dict = {}
        ip_state_dict = {}
        for key, value in ip_adapter_state_dict.items():
            if key.startswith("image_proj."):
                image_proj_state_dict[key.replace("image_proj.", "")] = value
            elif key.startswith("ip_adapter."):
                ip_state_dict[key.replace("ip_adapter.", "")] = value
        print("Creating Compatible Perceiver Resampler...")
        # Create resampler with compatible architecture
        # embedding_dim=512 presumably matches InsightFace face embeddings
        # rather than the CLIP encoder's dim — TODO confirm against
        # resampler_compatible.create_compatible_resampler.
        image_proj_model = create_compatible_resampler(
            num_queries=4,
            embedding_dim=512,
            output_dim=pipe.unet.config.cross_attention_dim,
            device=device,
            dtype=dtype
        )
        # Load pretrained weights
        try:
            # 'latents' is used as a marker that the checkpoint's resampler
            # weights are shape-compatible; strict=False tolerates misses.
            if 'latents' in image_proj_state_dict:
                image_proj_model.load_state_dict(image_proj_state_dict, strict=False)
                print(" [OK] Resampler loaded with pretrained weights")
            else:
                print(" [INFO] Using randomly initialized Resampler")
        except Exception as e:
            print(f" [INFO] Resampler weights: {e}")
        # Setup attention processors
        attn_procs = {}
        for name in pipe.unet.attn_processors.keys():
            # attn1 = self-attention (no cross-attention dim); attn2 (cross)
            # gets an IP-aware processor below.
            cross_attention_dim = None if name.endswith("attn1.processor") else pipe.unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = pipe.unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                # up-block channel sizes mirror the down blocks in reverse.
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(pipe.unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = pipe.unet.config.block_out_channels[block_id]
            # NOTE(review): hidden_size stays unbound if a processor name
            # matches none of the three prefixes — relies on diffusers'
            # UNet naming scheme; would raise NameError otherwise.
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor2_0()
            else:
                attn_procs[name] = IPAttnProcessor2_0(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                    num_tokens=4
                ).to(device, dtype=dtype)
        pipe.unet.set_attn_processor(attn_procs)
        # Load the per-layer InstantID weights into the freshly installed
        # processors; ModuleList ordering follows attn_processors iteration.
        ip_layers = torch.nn.ModuleList(pipe.unet.attn_processors.values())
        ip_layers.load_state_dict(ip_state_dict, strict=False)
        print(" [OK] IP-Adapter loaded with InstantID weights")
        pipe.image_encoder = image_encoder
        return image_proj_model, True
    except Exception as e:
        print(f" [ERROR] Could not load IP-Adapter: {e}")
        import traceback
        traceback.print_exc()
        return None, False
def setup_compel(pipe):
    """
    Create a Compel prompt-weighting helper for SDXL's dual text encoders.

    Returns:
        tuple: (Compel instance or None, bool success flag)
    """
    print("Setting up Compel...")
    try:
        tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
        encoders = [pipe.text_encoder, pipe.text_encoder_2]
        compel = Compel(
            tokenizer=tokenizers,
            text_encoder=encoders,
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True]  # only the second encoder is pooled
        )
        print(" [OK] Compel loaded")
        return compel, True
    except Exception as e:
        print(f" [WARNING] Compel not available: {e}")
        return None, False
def setup_scheduler(pipe):
    """Swap the pipeline's scheduler for an LCM scheduler, in place."""
    print("Setting up LCM scheduler...")
    # Rebuild from the existing scheduler's config so its settings carry over.
    current_config = pipe.scheduler.config
    pipe.scheduler = LCMScheduler.from_config(current_config)
    print(" [OK] LCM scheduler configured")
def optimize_pipeline(pipe):
    """Apply torch 2.1.1 optimizations.

    Installs SDPA attention processors, attempts xformers memory-efficient
    attention on CUDA, and compiles the UNet with torch.compile if available.

    NOTE(review): set_attn_processor(AttnProcessor2_0()) replaces ALL of the
    UNet's attention processors — if setup_ip_adapter() has already run, its
    IP-adapter processors would be clobbered. Confirm the caller invokes this
    before setup_ip_adapter.
    """
    # Enable attention optimizations
    pipe.unet.set_attn_processor(AttnProcessor2_0())
    # xformers
    if device == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
            print(" [OK] xformers enabled")
        except Exception as e:
            print(f" [INFO] xformers not available: {e}")
    # TORCH 2.1.1: Compile UNet for 50-100% speedup
    if hasattr(torch, 'compile') and device == "cuda":
        try:
            print(" [TORCH 2.1] Compiling UNet (first run +30s, then 50-100% faster)...")
            pipe.unet = torch.compile(
                pipe.unet,
                mode="reduce-overhead", # Faster for repeated inference
                fullgraph=False # More stable with ControlNet
            )
            print(" [OK] UNet compiled")
        except Exception as e:
            print(f" [INFO] torch.compile not available: {e}")
def load_caption_model():
    """
    Load the BLIP image-captioning processor and model.

    Returns:
        tuple: (processor or None, model or None, bool success flag)
    """
    print("Loading BLIP model...")
    blip_repo = "Salesforce/blip-image-captioning-base"
    try:
        processor = BlipProcessor.from_pretrained(blip_repo)
        model = BlipForConditionalGeneration.from_pretrained(
            blip_repo,
            torch_dtype=dtype
        ).to(device)
        print(" [OK] BLIP model loaded")
        return processor, model, True
    except Exception as e:
        print(f" [WARNING] BLIP not available: {e}")
        return None, None, False
def set_clip_skip(pipe):
    """Set CLIP skip.

    NOTE(review): this is effectively a no-op — it only prints a confirmation
    and never modifies the pipeline. CLIP skip presumably must be applied at
    inference time (e.g. via a `clip_skip` call argument); confirm the intent
    with the caller.
    """
    if hasattr(pipe, 'text_encoder'):
        print(f" [OK] CLIP skip set to {CLIP_SKIP}")
# Import-time banner: confirms this loader module was imported successfully.
print("[OK] Model loading functions ready (Torch 2.1.1 + Depth Anything V2)")