"""
Models.py - Following examplewithface.py EXACTLY
NO MultiControlNetModel wrapper!
Using Kohya-style LoRA from lora.py (examplewithface.py lines 223-235)
"""
import torch
import time
import os
from diffusers import ControlNetModel, AutoencoderKL, LCMScheduler
from insightface.app import FaceAnalysis
from controlnet_aux import ZoeDetector
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
from compel import Compel, ReturnedEmbeddingsType
from transformers import CLIPVisionModelWithProjection
from pipeline_stable_diffusion_xl_instantid_img2img import (
    StableDiffusionXLInstantIDImg2ImgPipeline,
    draw_kps
)
# Resampler and the IP-Adapter attention processors used by setup_ip_adapter come from
# the InstantID ip_adapter package (assumed to be vendored with this Space; adjust the
# import paths if they live elsewhere in the repo)
from ip_adapter.resampler import Resampler
from ip_adapter.attention_processor import IPAttnProcessor2_0, AttnProcessor2_0
from config import (
device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
FACE_DETECTION_CONFIG, CLIP_SKIP, DOWNLOAD_CONFIG
)
def download_model_with_retry(repo_id, filename, max_retries=None):
if max_retries is None:
max_retries = DOWNLOAD_CONFIG['max_retries']
for attempt in range(max_retries):
try:
kwargs = {"repo_type": "model"}
if HUGGINGFACE_TOKEN:
kwargs["token"] = HUGGINGFACE_TOKEN
path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
return path
        except Exception as e:
            if attempt < max_retries - 1:
                print(f" [RETRY] {filename} download failed ({e}); retrying...")
                time.sleep(DOWNLOAD_CONFIG['retry_delay'])
else:
raise
return None
def load_face_analysis():
"""examplewithface.py line 113"""
print("Loading face analysis...")
try:
# Download antelopev2 model
snapshot_download(
repo_id="DIAMONIK7777/antelopev2",
local_dir="/data/models/antelopev2"
)
# examplewithface.py line 113 pattern
app = FaceAnalysis(name='antelopev2', root='/data', providers=['CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))
print(" [OK] Face analysis loaded")
return app, True
except Exception as e:
print(f" [ERROR] Face analysis failed: {e}")
import traceback
traceback.print_exc()
return None, False
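# Downstream usage sketch (illustrative comments only; assumes cv2 and numpy are imported):
# the FaceAnalysis app is queried on a BGR numpy array, and the largest detected face
# supplies the 512-D embedding and keypoints used by InstantID, e.g.:
#   faces = app.get(cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR))
#   face = sorted(faces, key=lambda f: (f['bbox'][2]-f['bbox'][0])*(f['bbox'][3]-f['bbox'][1]))[-1]
#   face_emb, face_kps = face['embedding'], draw_kps(pil_image, face['kps'])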
def load_depth_detector():
"""examplewithface.py line 151-155"""
print("Loading Zoe Depth...")
try:
zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
zoe.to(device) # examplewithface.py line 155
print(" [OK] Zoe Depth loaded")
return zoe, True
except Exception as e:
print(f" [WARNING] Zoe unavailable: {e}")
return None, False
def load_controlnets():
"""examplewithface.py lines 122-126"""
print("Loading ControlNets...")
# Load but don't move to device yet - pipe.to(device) will handle it
identitynet = ControlNetModel.from_pretrained(
"InstantX/InstantID",
subfolder="ControlNetModel",
torch_dtype=dtype
)
print(" [OK] InstantID ControlNet")
zoedepthnet = ControlNetModel.from_pretrained(
"diffusers/controlnet-zoe-depth-sdxl-1.0",
torch_dtype=dtype
)
print(" [OK] Zoe Depth ControlNet")
return identitynet, zoedepthnet
def load_sdxl_pipeline(controlnets):
"""
examplewithface.py lines 128-145
CRITICAL: Pass controlnets as LIST - NO MultiControlNetModel!
"""
print("Loading pipeline...")
# Load VAE (line 128)
vae = AutoencoderKL.from_pretrained(
"madebyollin/sdxl-vae-fp16-fix",
torch_dtype=dtype
)
print(" [OK] VAE loaded")
# Create pipeline (line 134) - controlnets as LIST!
pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
"frankjoshua/albedobaseXL_v21",
vae=vae,
controlnet=controlnets, # ← LIST [identitynet, zoedepthnet] - NO WRAPPER!
torch_dtype=dtype
)
print(" [OK] Pipeline created with direct controlnet list")
# LCM scheduler
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
print(" [OK] LCM scheduler")
# IP-Adapter (line 139)
ip_adapter_path = download_model_with_retry("InstantX/InstantID", "ip-adapter.bin")
pipe.load_ip_adapter_instantid(ip_adapter_path)
pipe.set_ip_adapter_scale(0.8)
print(" [OK] IP-Adapter loaded")
pipe = pipe.to(device)
print(" [OK] Pipeline ready (following examplewithface.py EXACTLY)")
return pipe, True
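# Note on the controlnet list (illustrative): because controlnet=[identitynet, zoedepthnet]
# is passed as a plain list, the inference call is expected to supply control images and
# conditioning scales as parallel lists in the same order (identity keypoints first,
# depth map second). Exact keyword names depend on the local InstantID img2img pipeline;
# check its __call__ signature before relying on them.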
# Global LoRA state
lora_path_cached = None
def load_lora(pipe):
"""Load LoRA - store path for later use"""
print("Loading LoRA...")
global lora_path_cached
try:
lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
lora_path_cached = lora_path
print(f" [OK] LoRA path stored")
return True
except Exception as e:
print(f" [WARNING] LoRA failed: {e}")
return False
def fuse_lora_with_scale(pipe, lora_scale):
"""
Following examplewithface.py lines 223-235:
Use the Kohya-style LoRA loader from lora.py (NOT diffusers built-in)
"""
global lora_path_cached
if lora_path_cached is None:
return False
try:
# Import the local lora module (Kohya-style)
import lora
print(f" [LORA] Creating network from weights...")
# examplewithface.py lines 223-229
# Note: SDXL has two text encoders, pass both as a list
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
lora_model, weights_sd = lora.create_network_from_weights(
lora_scale, # multiplier
lora_path_cached, # file path
pipe.vae,
text_encoders, # Both SDXL text encoders
pipe.unet,
for_inference=True,
)
# examplewithface.py lines 231-233
print(f" [LORA] Merging to model with scale {lora_scale}...")
        lora_model.merge_to(
            text_encoders, pipe.unet, weights_sd, dtype, device
        )
# Cleanup
del weights_sd
del lora_model
print(f" [OK] LoRA merged into model using Kohya loader")
return True
except Exception as e:
print(f" [ERROR] LoRA merge failed: {e}")
import traceback
traceback.print_exc()
return False
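# Note: merge_to() folds the LoRA deltas directly into the UNet/text-encoder weights,
# so the merge is permanent for this pipeline instance; using a different lora_scale
# requires reloading the base pipeline and merging again.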
def setup_compel(pipe):
"""examplewithface.py line 145"""
print("Setting up Compel...")
try:
compel = Compel(
tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
requires_pooled=[False, True]
)
print(" [OK] Compel ready")
return compel, True
except Exception as e:
print(f" [WARNING] Compel unavailable: {e}")
return None, False
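# Usage sketch (illustrative): with requires_pooled=[False, True] Compel returns both the
# token embeddings and the pooled embedding that SDXL needs, e.g.:
#   conditioning, pooled = compel("a portrait photo, detailed++")
#   pipe(prompt_embeds=conditioning, pooled_prompt_embeds=pooled, ...)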
def setup_scheduler(pipe):
    """No-op: the LCM scheduler is already attached in load_sdxl_pipeline()."""
    pass
def optimize_pipeline(pipe):
if device == "cuda":
try:
pipe.enable_xformers_memory_efficient_attention()
print(" [OK] xformers enabled")
        except Exception as e:
            print(f" [INFO] xformers unavailable: {e}")
if hasattr(pipe, 'enable_vae_slicing'):
pipe.enable_vae_slicing()
if hasattr(pipe, 'enable_vae_tiling'):
pipe.enable_vae_tiling()
def load_caption_model():
print("Loading caption model...")
try:
from transformers import AutoProcessor, AutoModelForCausalLM
processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco", torch_dtype=dtype).to("cpu")
print(" [OK] GIT-Large")
return processor, model, True, 'git'
    except Exception:
try:
from transformers import BlipProcessor, BlipForConditionalGeneration
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=dtype).to("cpu")
print(" [OK] BLIP")
return processor, model, True, 'blip'
        except Exception:
return None, None, False, 'none'
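# Captioning sketch (illustrative, GIT branch; assumes a PIL image `img`):
#   pixel_values = processor(images=img, return_tensors="pt").pixel_values
#   generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
#   caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]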
def set_clip_skip(pipe):
    # CLIP skip is effectively applied via Compel's penultimate-hidden-states embeddings
    # (see setup_compel); this helper only reports the configured value.
    if hasattr(pipe, 'text_encoder'):
        print(f" [OK] CLIP skip {CLIP_SKIP}")
def load_image_encoder():
"""Load CLIP Image Encoder for IP-Adapter."""
print("Loading CLIP Image Encoder for IP-Adapter...")
try:
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"h94/IP-Adapter",
subfolder="models/image_encoder",
torch_dtype=dtype
).to(device)
print(" [OK] CLIP Image Encoder loaded successfully")
return image_encoder
except Exception as e:
print(f" [ERROR] Could not load image encoder: {e}")
return None
def setup_ip_adapter(pipe, image_encoder):
"""
Setup IP-Adapter for InstantID face embeddings - PROPER IMPLEMENTATION.
Based on the reference InstantID pipeline.
"""
if image_encoder is None:
return None, False
print("Setting up IP-Adapter for InstantID face embeddings (proper implementation)...")
try:
# Download InstantID weights
ip_adapter_path = download_model_with_retry(
"InstantX/InstantID",
"ip-adapter.bin"
)
# Load full state dict
state_dict = torch.load(ip_adapter_path, map_location="cpu")
# Extract image_proj and ip_adapter weights
image_proj_state_dict = {}
ip_adapter_state_dict = {}
for key, value in state_dict.items():
if key.startswith("image_proj."):
image_proj_state_dict[key.replace("image_proj.", "")] = value
elif key.startswith("ip_adapter."):
ip_adapter_state_dict[key.replace("ip_adapter.", "")] = value
# Create Resampler (image projection model) with CORRECT parameters from reference
print("Creating Resampler (Perceiver architecture)...")
image_proj_model = Resampler(
dim=1280, # Hidden dimension
depth=4, # IMPORTANT: 4 layers (not 8!)
dim_head=64, # Dimension per head
heads=20, # Number of heads
num_queries=16, # Number of output tokens
embedding_dim=512, # InsightFace embedding dim
output_dim=pipe.unet.config.cross_attention_dim, # SDXL cross-attention dim (2048)
ff_mult=4 # Feedforward multiplier
)
image_proj_model.eval()
image_proj_model = image_proj_model.to(device, dtype=dtype)
# Load image_proj weights
if image_proj_state_dict:
try:
image_proj_model.load_state_dict(image_proj_state_dict, strict=True)
print(" [OK] Resampler loaded with pretrained weights")
except Exception as e:
print(f" [WARNING] Could not load Resampler weights: {e}")
print(" Using randomly initialized Resampler")
else:
print(" [WARNING] No image_proj weights found, using random initialization")
# Setup IP-Adapter attention processors
print("Setting up IP-Adapter attention processors...")
attn_procs = {}
num_tokens = 16 # Match Resampler num_queries
for name in pipe.unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else pipe.unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = pipe.unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(pipe.unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = pipe.unet.config.block_out_channels[block_id]
else:
hidden_size = pipe.unet.config.block_out_channels[-1]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor2_0()
else:
attn_procs[name] = IPAttnProcessor2_0(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=num_tokens
).to(device, dtype=dtype)
# Set attention processors
pipe.unet.set_attn_processor(attn_procs)
# Load IP-Adapter weights into attention processors
if ip_adapter_state_dict:
try:
ip_layers = torch.nn.ModuleList(pipe.unet.attn_processors.values())
ip_layers.load_state_dict(ip_adapter_state_dict, strict=False)
print(" [OK] IP-Adapter attention weights loaded")
except Exception as e:
print(f" [WARNING] Could not load IP-Adapter weights: {e}")
else:
print(" [WARNING] No ip_adapter weights found")
        # Store the image encoder on the pipeline; the projection model is returned to the caller
pipe.image_encoder = image_encoder
print(" [OK] IP-Adapter fully loaded with InstantID architecture")
print(f" - Resampler: 4 layers, 20 heads, 16 output tokens")
print(f" - Face embeddings: 512D → 16x2048D")
return image_proj_model, True
except Exception as e:
print(f" [ERROR] Could not setup IP-Adapter: {e}")
import traceback
traceback.print_exc()
return None, False
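# Shape sketch (illustrative, not executed at import time): the Resampler maps one 512-D
# InsightFace embedding to num_queries tokens of width cross_attention_dim, e.g.:
#   emb = torch.as_tensor(face_emb, device=device, dtype=dtype).reshape(1, -1, 512)
#   tokens = image_proj_model(emb)   # -> (1, 16, 2048) for SDXL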
__all__ = ['draw_kps', 'fuse_lora_with_scale', 'load_image_encoder', 'setup_ip_adapter']
print("[OK] models.py ready - NO MultiControlNetModel, following examplewithface.py")