pixagram-neo-backup / models.py
primerz's picture
Update models.py
612e78f verified
raw
history blame
7.61 kB
"""
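models.py - model loaders, following examplewithface.py exactly.
- ControlNets are passed to the pipeline as a plain list (no MultiControlNetModel wrapper).
- LoRA weights are loaded without fuse_lora(); the LoRA scale is applied via cross_attention_kwargs instead.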
"""
import torch
import time
import os
from diffusers import ControlNetModel, AutoencoderKL, LCMScheduler
from insightface.app import FaceAnalysis
from controlnet_aux import ZoeDetector
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
from compel import Compel, ReturnedEmbeddingsType
from pipeline_stable_diffusion_xl_instantid_img2img import (
StableDiffusionXLInstantIDImg2ImgPipeline,
draw_kps
)
from config import (
device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
FACE_DETECTION_CONFIG, CLIP_SKIP, DOWNLOAD_CONFIG
)
def download_model_with_retry(repo_id, filename, max_retries=None):
    """Download ``filename`` from the Hugging Face Hub model repo ``repo_id``, retrying on failure.

    Args:
        repo_id: Hub repository id (``repo_type="model"``).
        filename: Path of the file inside the repository.
        max_retries: Total number of attempts. Defaults to
            ``DOWNLOAD_CONFIG['max_retries']``. Values below 1 are treated as 1,
            so at least one download is always attempted.

    Returns:
        The local path returned by ``hf_hub_download``.

    Raises:
        Whatever the final attempt raised, once all attempts have failed.
    """
    if max_retries is None:
        max_retries = DOWNLOAD_CONFIG['max_retries']
    # A non-positive count used to skip the loop entirely and return None,
    # which callers then treated as a file path.
    attempts = max(1, max_retries)

    # Request options do not change between attempts; build them once.
    kwargs = {"repo_type": "model"}
    if HUGGINGFACE_TOKEN:
        kwargs["token"] = HUGGINGFACE_TOKEN

    for attempt in range(1, attempts + 1):
        try:
            return hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
        except Exception as e:
            if attempt == attempts:
                raise
            print(f" [WARNING] Download of {filename} failed (attempt {attempt}/{attempts}): {e}")
            time.sleep(DOWNLOAD_CONFIG['retry_delay'])
def load_face_analysis():
    """Fetch the antelopev2 InsightFace models and return a prepared ``FaceAnalysis`` app.

    Follows examplewithface.py line 113. The model pack is downloaded to
    ``/data/models/antelopev2`` and the app is created with ``root='/data'``.
    That presumably matches FaceAnalysis's ``<root>/models/<name>`` lookup
    (confirm). The app uses the CPU execution provider and a 640x640
    detection size.

    Returns:
        ``(app, True)`` on success. On any exception the error and traceback
        are printed and ``(None, False)`` is returned.
    """
    print("Loading face analysis...")
    try:
        snapshot_download(repo_id="DIAMONIK7777/antelopev2", local_dir="/data/models/antelopev2")
        face_app = FaceAnalysis(
            name='antelopev2',
            root='/data',
            providers=['CPUExecutionProvider'],
        )
        face_app.prepare(ctx_id=0, det_size=(640, 640))
    except Exception as e:
        print(f" [ERROR] Face analysis failed: {e}")
        import traceback
        traceback.print_exc()
        return None, False
    print(" [OK] Face analysis loaded")
    return face_app, True
def load_depth_detector():
    """Load the Zoe depth annotator from ``lllyasviel/Annotators`` and move it to ``device``.

    Follows examplewithface.py lines 151-155.

    Returns:
        ``(detector, True)`` on success. On failure a warning is printed and
        ``(None, False)`` is returned; the error is not re-raised.
    """
    print("Loading Zoe Depth...")
    try:
        detector = ZoeDetector.from_pretrained("lllyasviel/Annotators")
        detector.to(device)
    except Exception as e:
        print(f" [WARNING] Zoe unavailable: {e}")
        return None, False
    print(" [OK] Zoe Depth loaded")
    return detector, True
def load_controlnets():
    """Load the InstantID identity ControlNet and the Zoe-depth SDXL ControlNet.

    Follows examplewithface.py lines 122-126. Both models are loaded in
    ``dtype`` and left on their default device; the pipeline's ``.to(device)``
    moves them later.

    Returns:
        ``(identity_controlnet, depth_controlnet)``. Loading errors propagate.
    """
    print("Loading ControlNets...")

    def _fetch(repo_id, banner, **extra):
        # Load one ControlNet in the shared dtype and announce it.
        net = ControlNetModel.from_pretrained(repo_id, torch_dtype=dtype, **extra)
        print(banner)
        return net

    identity_net = _fetch(
        "InstantX/InstantID",
        " [OK] InstantID ControlNet",
        subfolder="ControlNetModel",
    )
    depth_net = _fetch("diffusers/controlnet-zoe-depth-sdxl-1.0", " [OK] Zoe Depth ControlNet")
    return identity_net, depth_net
def load_sdxl_pipeline(controlnets):
    """Build the InstantID SDXL img2img pipeline with an LCM scheduler and IP-Adapter.

    Follows examplewithface.py lines 128-145.

    Args:
        controlnets: ControlNet models, handed to the pipeline's ``controlnet``
            argument as-is. They are deliberately NOT wrapped in
            ``MultiControlNetModel``.

    Returns:
        ``(pipe, True)``, with the pipeline already moved to ``device``.
        Errors are not caught here; they propagate to the caller.
    """
    print("Loading pipeline...")

    # VAE (examplewithface.py line 128).
    sdxl_vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
    print(" [OK] VAE loaded")

    # Base pipeline (line 134): the ControlNets go in unwrapped.
    pipeline = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
        "frankjoshua/albedobaseXL_v21",
        vae=sdxl_vae,
        controlnet=controlnets,
        torch_dtype=dtype,
    )
    print(" [OK] Pipeline created with direct controlnet list")

    # Swap in LCM, reusing the original scheduler's configuration.
    base_scheduler_config = pipeline.scheduler.config
    pipeline.scheduler = LCMScheduler.from_config(base_scheduler_config)
    print(" [OK] LCM scheduler")

    # IP-Adapter weights (line 139).
    adapter_file = download_model_with_retry("InstantX/InstantID", "ip-adapter.bin")
    pipeline.load_ip_adapter_instantid(adapter_file)
    pipeline.set_ip_adapter_scale(0.8)
    print(" [OK] IP-Adapter loaded")

    pipeline = pipeline.to(device)
    print(" [OK] Pipeline ready (following examplewithface.py EXACTLY)")
    return pipeline, True
# Global LoRA state.
# Local path of the downloaded LoRA file. load_lora() sets it and
# fuse_lora_with_scale() reads it. While it is None, fuse_lora_with_scale()
# returns False without touching the pipeline.
lora_path_cached = None
def load_lora(pipe):
    """Download the LoRA weights and cache their local path for later use.

    Nothing is applied to ``pipe`` here; the argument is accepted but unused.
    ``fuse_lora_with_scale`` later loads the weights from the cached path.

    Returns:
        True if a path was obtained and stored in ``lora_path_cached``.
        False if the download failed or produced no path; the previously
        cached value is left unchanged in that case.
    """
    global lora_path_cached
    print("Loading LoRA...")
    try:
        lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
    except Exception as e:
        print(f" [WARNING] LoRA failed: {e}")
        return False
    if lora_path is None:
        # Defensive: never cache None while reporting success.
        print(" [WARNING] LoRA failed: no file path returned")
        return False
    lora_path_cached = lora_path
    print(" [OK] LoRA path stored")
    return True
def fuse_lora_with_scale(pipe, lora_scale):
    """(Re)load the cached LoRA weights into ``pipe``.

    Despite the name, nothing is fused. The weights are loaded as-is, and
    ``lora_scale`` is only logged here. Per this module's design, the scale is
    meant to be applied at inference time through ``cross_attention_kwargs``.

    Args:
        pipe: Pipeline exposing ``unload_lora_weights`` / ``load_lora_weights``.
        lora_scale: Scale to be applied later; only printed here.

    Returns:
        True if the LoRA was loaded. False if no path is cached or loading failed.
    """
    # Read-only access to the module-level cache, so no ``global`` statement is needed.
    if lora_path_cached is None:
        return False
    try:
        # Best-effort removal of any previously loaded LoRA. Only ordinary
        # errors are ignored; KeyboardInterrupt/SystemExit still propagate.
        try:
            pipe.unload_lora_weights()
        except Exception:
            pass
        print(f" [LORA] Loading with scale {lora_scale}...")
        pipe.load_lora_weights(lora_path_cached)
        print(" [OK] LoRA loaded (scale will be applied via cross_attention_kwargs)")
        return True
    except Exception as e:
        print(f" [ERROR] LoRA failed: {e}")
        return False
def setup_compel(pipe):
    """Create a Compel prompt encoder over both SDXL tokenizer/text-encoder pairs.

    Follows examplewithface.py line 145. Embeddings are taken from the
    penultimate hidden states (non-normalized). Pooled output is requested
    only from the second encoder.

    Returns:
        ``(compel, True)`` on success. On failure a warning is printed and
        ``(None, False)`` is returned.
    """
    print("Setting up Compel...")
    try:
        prompt_encoder = Compel(
            tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
            text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True],
        )
    except Exception as e:
        print(f" [WARNING] Compel unavailable: {e}")
        return None, False
    print(" [OK] Compel ready")
    return prompt_encoder, True
def setup_scheduler(pipe):
    """No-op: leaves ``pipe`` untouched and returns None.

    The LCM scheduler is installed in ``load_sdxl_pipeline`` instead.
    """
def optimize_pipeline(pipe):
    """Enable optional memory/speed optimizations on ``pipe``.

    When ``device == "cuda"``, xformers memory-efficient attention is tried on
    a best-effort basis: a failure is reported and otherwise ignored. VAE
    slicing and tiling are enabled whenever the pipeline exposes them,
    regardless of device.
    """
    if device == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            # xformers is optional. Keep going without it, but say why
            # (previously a bare except hid the reason and even BaseException).
            print(f" [WARNING] xformers unavailable: {e}")
        else:
            print(" [OK] xformers enabled")
    if hasattr(pipe, 'enable_vae_slicing'):
        pipe.enable_vae_slicing()
    if hasattr(pipe, 'enable_vae_tiling'):
        pipe.enable_vae_tiling()
def load_caption_model():
    """Load an image-captioning model on CPU, preferring GIT-Large and falling back to BLIP.

    Returns:
        ``(processor, model, True, 'git')`` or ``(processor, model, True, 'blip')``
        on success, or ``(None, None, False, 'none')`` if both fail. Each
        failure is reported instead of being silently swallowed.
    """
    print("Loading caption model...")
    # NOTE(review): the model is kept on CPU but loaded in the shared `dtype`.
    # If that is float16, some CPU ops may be unsupported; confirm.
    try:
        from transformers import AutoProcessor, AutoModelForCausalLM
        processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
        model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco", torch_dtype=dtype).to("cpu")
        print(" [OK] GIT-Large")
        return processor, model, True, 'git'
    except Exception as e:
        print(f" [WARNING] GIT-Large unavailable: {e}")

    try:
        from transformers import BlipProcessor, BlipForConditionalGeneration
        processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=dtype).to("cpu")
        print(" [OK] BLIP")
        return processor, model, True, 'blip'
    except Exception as e:
        print(f" [WARNING] BLIP unavailable: {e}")

    return None, None, False, 'none'
def set_clip_skip(pipe):
    """Print the configured ``CLIP_SKIP`` if ``pipe`` has a ``text_encoder``.

    ``pipe`` is not modified and None is returned.
    NOTE(review): the penultimate-layer selection appears to come from
    setup_compel's ``returned_embeddings_type``. Confirm that ``CLIP_SKIP`` is
    actually honored somewhere.
    """
    if not hasattr(pipe, 'text_encoder'):
        return
    print(f" [OK] CLIP skip {CLIP_SKIP}")
# Names exported by ``import *``. draw_kps is re-exported from the InstantID
# pipeline module. The loader functions are not listed; presumably callers
# import them by name (confirm).
__all__ = ['draw_kps', 'fuse_lora_with_scale']
# Import-time side effect: announce that the module finished loading.
print("[OK] models.py ready - NO MultiControlNetModel, following examplewithface.py")