File size: 12,573 Bytes
f179fb3
bfd74f2
 
f179fb3
 
 
f6a0f97
f587361
bfd74f2
 
 
f587361
bfd74f2
 
f179fb3
dc93517
 
bfd74f2
176aa63
f179fb3
bfd74f2
 
8bbe1e4
f179fb3
 
d4170e9
f179fb3
 
 
f786a41
bfd74f2
f179fb3
 
 
 
 
bfd74f2
 
f786a41
 
 
 
bfd74f2
 
 
 
 
 
f786a41
bfd74f2
f179fb3
bfd74f2
 
f179fb3
bfd74f2
f179fb3
 
bfd74f2
f179fb3
bfd74f2
f179fb3
 
 
 
bfd74f2
 
5d3624b
bfd74f2
927c643
 
 
5d3624b
bfd74f2
927c643
 
bfd74f2
 
 
5d3624b
bfd74f2
09560c2
5d3624b
f786a41
f179fb3
bfd74f2
 
f179fb3
dc93517
927c643
 
bfd74f2
 
 
b9e8f75
bfd74f2
b9e8f75
f786a41
f179fb3
 
bfd74f2
 
 
f786a41
d4170e9
bfd74f2
 
b867149
bfd74f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f179fb3
 
 
5d3624b
bfd74f2
 
 
 
 
5d3624b
bfd74f2
1830448
bfd74f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bcd4d0
bfd74f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dae5ba2
bfd74f2
 
 
 
 
 
 
 
 
 
 
 
 
 
dae5ba2
bfd74f2
 
 
 
 
 
 
 
 
 
f651fe9
 
f786a41
bfd74f2
 
f179fb3
f786a41
bfd74f2
 
f786a41
f179fb3
bfd74f2
f786a41
f179fb3
 
bfd74f2
f651fe9
bfd74f2
 
 
f651fe9
bfd74f2
f651fe9
bfd74f2
 
 
 
 
97a7af1
bfd74f2
 
 
f651fe9
bfd74f2
 
5d00e28
bfd74f2
 
 
5d00e28
beb501c
c1defd0
5d00e28
bfd74f2
5d00e28
 
 
f651fe9
 
176aa63
bfd74f2
 
176aa63
 
 
 
 
 
 
bfd74f2
f786a41
176aa63
bfd74f2
f786a41
f179fb3
 
 
bfd74f2
 
 
 
f179fb3
 
 
bfd74f2
 
f179fb3
 
 
 
bfd74f2
 
f179fb3
 
 
bfd74f2
 
 
 
8bbe1e4
bfd74f2
 
f179fb3
8bbe1e4
bfd74f2
 
 
 
 
 
 
 
 
 
 
 
 
a70cb97
8bbe1e4
bfd74f2
 
 
 
 
 
 
 
 
 
 
 
8bbe1e4
f179fb3
 
 
bfd74f2
f179fb3
bfd74f2
ec60550
f651fe9
bfd74f2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
"""
Model loading and initialization for Pixagram AI Pixel Art Generator
HYBRID VERSION - Supports both local files and HuggingFace repos
"""
import torch
import time
import os
from diffusers import (
    ControlNetModel,
    AutoencoderKL,
    LCMScheduler
)
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import CLIPVisionModelWithProjection
from insightface.app import FaceAnalysis
from controlnet_aux import LeresDetector
from controlnet_aux.processor import Processor
from huggingface_hub import hf_hub_download
from compel import Compel, ReturnedEmbeddingsType

# Import the custom pipeline that has load_ip_adapter_instantid method
from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline

from config import (
    device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
    FACE_DETECTION_CONFIG, CLIP_SKIP, DOWNLOAD_CONFIG
)


def download_model_with_retry(repo_id, filename, max_retries=None):
    """Fetch ``filename`` from the Hub repo ``repo_id``, retrying on failure.

    Args:
        repo_id: Repository id forwarded to ``hf_hub_download``.
        filename: File inside the repository to download.
        max_retries: Total number of attempts. When ``None`` the value of
            ``DOWNLOAD_CONFIG['max_retries']`` is used.

    Returns:
        Whatever ``hf_hub_download`` returns for the first successful attempt,
        or ``None`` when ``max_retries`` is zero or negative (no attempt is
        made in that case).

    Raises:
        Exception: The error from the final failed attempt is re-raised.
    """
    attempts_allowed = (
        DOWNLOAD_CONFIG['max_retries'] if max_retries is None else max_retries
    )

    for attempt_no in range(1, attempts_allowed + 1):
        try:
            print(f"  Attempting to download {filename} (attempt {attempt_no}/{attempts_allowed})...")

            # The token is only forwarded when one is configured.
            hub_options = {"repo_type": "model"}
            if HUGGINGFACE_TOKEN:
                hub_options["token"] = HUGGINGFACE_TOKEN

            local_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                **hub_options
            )
            print(f"  [OK] Downloaded: {filename}")
            return local_path

        except Exception as err:
            print(f"  [WARNING] Download attempt {attempt_no} failed: {err}")

            # Out of attempts: report and let the last error propagate.
            if attempt_no >= attempts_allowed:
                print(f"  [ERROR] Failed to download {filename} after {attempts_allowed} attempts")
                raise

            print(f"  Retrying in {DOWNLOAD_CONFIG['retry_delay']} seconds...")
            time.sleep(DOWNLOAD_CONFIG['retry_delay'])

    # Only reached when the loop body never ran (non-positive attempt count).
    return None


def load_face_analysis():
    """Create and prepare the ``antelopev2`` FaceAnalysis application.

    The model root is ``/data`` and only ``CPUExecutionProvider`` is requested;
    ``prepare`` is called with ``ctx_id=0`` and a 640x640 detection size.

    Returns:
        ``(face_app, True)`` on success, ``(None, False)`` if construction or
        preparation raised any ``Exception``.
    """
    print("Loading face analysis model...")
    try:
        analyzer = FaceAnalysis(
            name='antelopev2',
            root='/data',
            providers=['CPUExecutionProvider'],
        )
        analyzer.prepare(ctx_id=0, det_size=(640, 640))
    except Exception as err:
        # Face detection is optional: report and signal unavailability.
        print(f"  [WARNING] Face detection not available: {err}")
        return None, False
    else:
        print("  [OK] Face analysis model loaded successfully")
        return analyzer, True


def load_depth_detector():
    """Load the depth annotator from ``lllyasviel/Annotators`` onto ``device``.

    NOTE(review): the log messages and the original naming say "Zoe Depth",
    but the class instantiated here is ``LeresDetector`` - confirm which
    detector is intended.

    Returns:
        ``(detector, True)`` on success, ``(None, False)`` if loading or the
        device move raised any ``Exception``.
    """
    print("Loading Zoe Depth detector...")
    try:
        detector = LeresDetector.from_pretrained("lllyasviel/Annotators")
        # The return value of .to() is intentionally not used; the detector
        # object itself is what gets returned.
        detector.to(device)
    except Exception as err:
        print(f"  [WARNING] Zoe Depth not available: {err}")
        return None, False
    else:
        print("  [OK] Zoe Depth loaded successfully")
        return detector, True


def load_controlnets():
    """Load the depth ControlNet and, if possible, the InstantID ControlNet.

    The depth ControlNet load is not guarded: any error there propagates to
    the caller. The InstantID ControlNet is optional.

    Returns:
        ``(controlnet_depth, controlnet_instantid, True)`` when both loaded,
        ``(controlnet_depth, None, False)`` when the InstantID load failed.
    """
    print("Loading ControlNet Zoe Depth model...")
    depth_net = ControlNetModel.from_pretrained(
        "diffusers/controlnet-zoe-depth-sdxl-1.0",
        torch_dtype=dtype,
    )
    depth_net = depth_net.to(device)
    print("  [OK] ControlNet Depth loaded")

    print("Loading InstantID ControlNet...")
    try:
        instantid_net = ControlNetModel.from_pretrained(
            "InstantX/InstantID",
            subfolder="ControlNetModel",
            torch_dtype=dtype,
        )
        instantid_net = instantid_net.to(device)
    except Exception as err:
        # InstantID is optional; keep the depth ControlNet and report failure.
        print(f"  [WARNING] InstantID ControlNet not available: {err}")
        return depth_net, None, False

    print("  [OK] InstantID ControlNet loaded successfully")
    return depth_net, instantid_net, True


def load_image_encoder():
    """Load the CLIP vision encoder from ``h94/IP-Adapter`` onto ``device``.

    Returns:
        The loaded ``CLIPVisionModelWithProjection`` or ``None`` if loading
        raised any ``Exception``.
    """
    print("Loading CLIP Image Encoder for IP-Adapter...")
    try:
        encoder = CLIPVisionModelWithProjection.from_pretrained(
            "h94/IP-Adapter",
            subfolder="models/image_encoder",
            torch_dtype=dtype,
        )
        encoder = encoder.to(device)
    except Exception as err:
        print(f"  [ERROR] Could not load image encoder: {err}")
        return None

    print("  [OK] CLIP Image Encoder loaded successfully")
    return encoder


def load_sdxl_pipeline(controlnets):
    """
    Build the SDXL InstantID img2img pipeline, falling back through four sources.

    Sources are tried in this order:
    1. ``MODEL_FILES['checkpoint']`` downloaded from ``MODEL_REPO`` and loaded
       with ``from_single_file`` (skipped when no checkpoint is configured,
       and only used when the downloaded path is a ``.safetensors`` file).
    2. ``MODEL_REPO`` loaded with ``from_pretrained``.
    3. ``frankjoshua/albedobaseXL_v21`` loaded with ``from_pretrained``.
    4. ``stabilityai/stable-diffusion-xl-base-1.0``; a failure here is not
       caught and propagates to the caller.

    Args:
        controlnets: Passed unchanged as the pipeline's ``controlnet`` argument.

    Returns:
        ``(pipe, True)`` when source 1 or 2 succeeded, ``(pipe, False)`` when
        source 3 or 4 was used.
    """

    def _from_hub(repo_id):
        # Shared from_pretrained call used by sources 2-4.
        loaded = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
            repo_id,
            controlnet=controlnets,
            torch_dtype=dtype,
            use_safetensors=True,
        )
        return loaded.to(device)

    print("Loading SDXL checkpoint (hybrid approach)...")

    # Source 1: single-file checkpoint fetched from MODEL_REPO.
    checkpoint_name = MODEL_FILES.get('checkpoint')
    if checkpoint_name:
        try:
            print("  [Attempt 1] Loading from local file via from_single_file...")
            model_path = download_model_with_retry(MODEL_REPO, checkpoint_name)

            if not (model_path and os.path.exists(model_path) and model_path.endswith('.safetensors')):
                print("  [INFO] Local file not found or invalid, trying next method...")
            else:
                single_file_pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
                    model_path,
                    controlnet=controlnets,
                    torch_dtype=dtype,
                    use_safetensors=True,
                )
                single_file_pipe = single_file_pipe.to(device)
                print(f"  [OK] Checkpoint loaded from local file: {model_path}")
                return single_file_pipe, True
        except Exception as err:
            print(f"  [WARNING] from_single_file failed: {err}")
            print("  [INFO] Trying from_pretrained approach...")

    # Source 2: MODEL_REPO as a diffusers-format repository.
    try:
        print("  [Attempt 2] Loading from HuggingFace repo via from_pretrained...")
        repo_pipe = _from_hub(MODEL_REPO)
        print(f"  [OK] Checkpoint loaded from HuggingFace repo: {MODEL_REPO}")
        return repo_pipe, True
    except Exception as err:
        print(f"  [WARNING] from_pretrained failed: {err}")
        print("  [INFO] Trying fallback checkpoint...")

    # Source 3: hard-coded fallback checkpoint.
    try:
        print("  [Attempt 3] Loading fallback: frankjoshua/albedobaseXL_v21...")
        fallback_pipe = _from_hub("frankjoshua/albedobaseXL_v21")
        print("  [OK] Fallback checkpoint loaded successfully")
        return fallback_pipe, False
    except Exception as err:
        print(f"  [WARNING] Fallback also failed: {err}")
        print("  [INFO] Trying SDXL base model...")

    # Source 4: last resort, deliberately unguarded.
    print("  [Attempt 4] Loading base SDXL model...")
    base_pipe = _from_hub("stabilityai/stable-diffusion-xl-base-1.0")
    print("  [OK] Base SDXL model loaded")
    return base_pipe, False


def load_lora(pipe):
    """Download ``MODEL_FILES['lora']`` from ``MODEL_REPO`` and attach it to ``pipe``.

    The weights are registered under the adapter name ``"retroart"``.

    Args:
        pipe: Pipeline object exposing ``load_lora_weights``.

    Returns:
        ``True`` on success, ``False`` if the download or the load raised any
        ``Exception``.
    """
    print("Loading LORA (retroart) from HuggingFace Hub...")
    try:
        weights_file = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
        pipe.load_lora_weights(weights_file, adapter_name="retroart")
    except Exception as err:
        print(f"  [WARNING] Could not load LORA: {err}")
        return False

    print("  [OK] LORA loaded successfully")
    return True


def setup_ip_adapter(pipe):
    """
    Attach the InstantID IP-Adapter weights to ``pipe``.

    Downloads ``ip-adapter.bin`` from ``InstantX/InstantID``, hands it to the
    pipeline's own ``load_ip_adapter_instantid`` method and sets the initial
    adapter scale to 0.8.

    Args:
        pipe: Pipeline exposing ``load_ip_adapter_instantid`` and
            ``set_ip_adapter_scale``.

    Returns:
        ``True`` on success; ``False`` (after printing the traceback) if any
        step raised an ``Exception``.
    """
    print("Setting up IP-Adapter for InstantID face embeddings...")
    try:
        adapter_weights = download_model_with_retry(
            "InstantX/InstantID",
            "ip-adapter.bin",
        )

        # Loading is delegated entirely to the pipeline's method.
        pipe.load_ip_adapter_instantid(adapter_weights)

        # Starting scale only.
        pipe.set_ip_adapter_scale(0.8)

        success_lines = (
            "  [OK] IP-Adapter loaded successfully with built-in method",
            "  - Pipeline handles Resampler and attention processors automatically",
            "  - Face embeddings will be properly integrated during generation",
        )
        for line in success_lines:
            print(line)

        return True

    except Exception as err:
        print(f"  [ERROR] Could not setup IP-Adapter: {err}")
        import traceback
        traceback.print_exc()
        return False


def setup_compel(pipe):
    """Create a Compel prompt processor wired to both text encoders of ``pipe``.

    Both tokenizer/text-encoder pairs of the pipeline are passed in; pooled
    output is requested only from the second encoder.

    Args:
        pipe: Pipeline exposing ``tokenizer``, ``tokenizer_2``,
            ``text_encoder`` and ``text_encoder_2``.

    Returns:
        ``(compel, True)`` on success, ``(None, False)`` if construction
        raised any ``Exception``.
    """
    print("Setting up Compel for enhanced prompt processing...")
    try:
        compel_options = {
            "tokenizer": [pipe.tokenizer, pipe.tokenizer_2],
            "text_encoder": [pipe.text_encoder, pipe.text_encoder_2],
            "returned_embeddings_type": ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            "requires_pooled": [False, True],
        }
        prompt_processor = Compel(**compel_options)
    except Exception as err:
        print(f"  [WARNING] Compel not available: {err}")
        return None, False

    print("  [OK] Compel loaded successfully")
    return prompt_processor, True


def setup_scheduler(pipe):
    """Replace ``pipe.scheduler`` with an ``LCMScheduler`` built from its config.

    Args:
        pipe: Pipeline whose ``scheduler`` attribute is read and then
            overwritten in place.
    """
    print("Setting up LCM scheduler...")
    existing_config = pipe.scheduler.config
    pipe.scheduler = LCMScheduler.from_config(existing_config)
    print("  [OK] LCM scheduler configured")


def optimize_pipeline(pipe):
    """Enable xformers memory-efficient attention on ``pipe`` when on CUDA.

    Does nothing unless ``device`` compares equal to the string ``"cuda"``.
    A failure to enable xformers is reported and otherwise ignored.

    NOTE(review): this assumes ``device`` from config is a plain string -
    confirm; a ``torch.device`` object may not compare equal to ``"cuda"``.

    Args:
        pipe: Pipeline exposing ``enable_xformers_memory_efficient_attention``.
    """
    if device != "cuda":
        return

    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as err:
        print(f"  [INFO] xformers not available: {err}")
    else:
        print("  [OK] xformers enabled")


def load_caption_model():
    """
    Load an image-captioning model, preferring GIT-Large over BLIP base.

    ``microsoft/git-large-coco`` is tried first; if anything in that attempt
    raises (including the ``transformers`` import), the function falls back to
    ``Salesforce/blip-image-captioning-base``.

    Returns:
        ``(processor, model, True, 'git')`` or ``(processor, model, True,
        'blip')`` for the model that loaded, or ``(None, None, False, 'none')``
        when both attempts failed.
    """
    print("Loading caption model...")

    # First choice: GIT-Large.
    try:
        from transformers import AutoProcessor, AutoModelForCausalLM

        print("  Attempting GIT-Large (recommended)...")
        git_processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
        git_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/git-large-coco",
            torch_dtype=dtype,
        )
        git_model = git_model.to(device)
        print("  [OK] GIT-Large model loaded (produces detailed captions)")
        return git_processor, git_model, True, 'git'
    except Exception as git_err:
        print(f"  [INFO] GIT-Large not available: {git_err}")

    # Second choice: BLIP base, reached only after the GIT attempt failed.
    try:
        from transformers import BlipProcessor, BlipForConditionalGeneration

        print("  Attempting BLIP base (fallback)...")
        blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        blip_model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base",
            torch_dtype=dtype,
        )
        blip_model = blip_model.to(device)
        print("  [OK] BLIP base model loaded (standard captions)")
        return blip_processor, blip_model, True, 'blip'
    except Exception as blip_err:
        print(f"  [WARNING] Caption models not available: {blip_err}")
        print("  Caption generation will be disabled")
        return None, None, False, 'none'


def set_clip_skip(pipe):
    """Report the configured ``CLIP_SKIP`` value for pipelines with a text encoder.

    NOTE(review): despite the name and the log message, nothing is written to
    ``pipe`` here - the function only prints. Confirm where the CLIP skip
    value is actually applied.

    Args:
        pipe: Object checked for a ``text_encoder`` attribute; without one
            the function returns silently.
    """
    if not hasattr(pipe, 'text_encoder'):
        return
    print(f"  [OK] CLIP skip set to {CLIP_SKIP}")


# Module-level statement: runs when this module is imported and announces
# that the loader functions above have been defined.
print("[OK] Model loading functions ready (HYBRID VERSION)")