File size: 16,643 Bytes
f179fb3
bde5828
 
f179fb3
 
 
 
 
bde5828
f179fb3
fe30f16
bde5828
 
fe30f16
 
bde5828
 
fe30f16
 
d432eb2
 
fe30f16
 
f179fb3
 
 
bde5828
f179fb3
 
 
 
 
 
 
 
d432eb2
f179fb3
 
bde5828
f179fb3
 
bde5828
69e6233
 
f179fb3
bde5828
d432eb2
 
 
f179fb3
d432eb2
f179fb3
d432eb2
f179fb3
 
 
fe30f16
f179fb3
 
 
176aa63
 
f179fb3
bde5828
f179fb3
 
bde5828
f179fb3
 
fe30f16
29a6101
f179fb3
fe30f16
f179fb3
 
bde5828
f179fb3
fe30f16
bde5828
f179fb3
 
fe30f16
f179fb3
 
fe30f16
f179fb3
bde5828
 
fe30f16
f179fb3
f30dc7a
bde5828
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4236be3
bde5828
 
 
 
 
 
 
 
 
 
 
 
d432eb2
bde5828
d432eb2
 
 
 
 
bde5828
 
d432eb2
 
bde5828
 
 
d432eb2
 
 
 
f179fb3
bde5828
f179fb3
 
 
 
 
bde5828
 
f179fb3
d432eb2
 
f179fb3
 
d432eb2
 
f179fb3
 
 
 
bde5828
0775e46
bde5828
 
 
 
 
d432eb2
bde5828
d432eb2
bde5828
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d432eb2
bde5828
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a22a70f
bde5828
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d432eb2
bde5828
 
 
 
 
 
 
 
 
 
 
8f934d6
 
fe30f16
bde5828
 
 
d432eb2
bde5828
 
d432eb2
bde5828
 
 
 
 
d432eb2
bde5828
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d432eb2
bde5828
 
 
 
 
d432eb2
bde5828
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f179fb3
fe30f16
bde5828
 
 
 
d432eb2
bde5828
 
 
 
 
f179fb3
 
954ca3f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
"""
Generation logic for Pixagram AI Pixel Art Generator  
UPDATED VERSION with InstantID pipeline integration
"""
import torch
import numpy as np
import cv2
from PIL import Image
import gc

from config import (
    device, dtype, TRIGGER_WORD,
    ADAPTIVE_THRESHOLDS, ADAPTIVE_PARAMS, CAPTION_CONFIG
)
from utils import (
    sanitize_text, enhanced_color_match, color_match,
    get_demographic_description, calculate_optimal_size, safe_image_size
)
from models import (
    load_face_analysis, load_depth_detector, load_controlnets,
    load_sdxl_pipeline, load_lora, setup_compel,
    setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip
)


class RetroArtConverter:
    """Main class for retro art generation with InstantID.

    Every model is loaded once in ``__init__``; :meth:`generate_retro_art` is
    the main entry point. Optional components (face analysis, Zoe depth,
    Compel prompt weighting, captioning) are skipped when their loaders report
    them as unavailable.
    """

    def __init__(self):
        """Load all models and configure the InstantID SDXL pipeline."""
        self.device = device
        self.dtype = dtype
        # Per-component load flags, reported by _print_status().
        self.models_loaded = {
            'custom_checkpoint': False,
            'lora': False,
            'instantid': False,
            'zoe_depth': False
        }

        # Load face analysis
        self.face_app, self.face_detection_enabled = load_face_analysis()

        # Load depth detector
        self.zoe_depth, zoe_success = load_depth_detector()
        self.models_loaded['zoe_depth'] = zoe_success

        # Load ControlNets AS LIST; order matters: [identity, depth] must match
        # the order of control_image / controlnet_conditioning_scale below.
        controlnet_instantid, controlnet_depth = load_controlnets()
        controlnets = [controlnet_instantid, controlnet_depth]
        # NOTE(review): set unconditionally - load_controlnets() success is not checked here.
        self.models_loaded['instantid'] = True

        print("Initializing InstantID pipeline with Face + Depth ControlNets")

        # Load SDXL pipeline with InstantID (handles IP-Adapter internally)
        self.pipe, checkpoint_success = load_sdxl_pipeline(controlnets)
        self.models_loaded['custom_checkpoint'] = checkpoint_success

        # Load LORA
        lora_success = load_lora(self.pipe)
        self.models_loaded['lora'] = lora_success

        # Setup Compel
        self.compel, self.use_compel = setup_compel(self.pipe)

        # Setup scheduler
        setup_scheduler(self.pipe)

        # Optimize
        optimize_pipeline(self.pipe)

        # Load caption model
        self.caption_processor, self.caption_model, self.caption_enabled, self.caption_model_type = load_caption_model()

        # Set CLIP skip
        set_clip_skip(self.pipe)

        # Print status
        self._print_status()

        print("  [OK] RetroArtConverter initialized with InstantID!")

    def _print_status(self):
        """Print the load status of every tracked model."""
        print("\n=== MODEL STATUS ===")
        for model, loaded in self.models_loaded.items():
            status = "[OK] LOADED" if loaded else "[FALLBACK/DISABLED]"
            print(f"{model}: {status}")
        print("InstantID Pipeline: [OK] ACTIVE")
        print("IP-Adapter: [OK] Built into pipeline")
        print("===================\n")

    def get_depth_map(self, image):
        """Generate a depth map for ``image`` using Zoe Depth.

        Args:
            image: PIL image to estimate depth for.

        Returns:
            ``(depth_image, depth_array)``: ``depth_image`` is resized to match
            ``image`` and ``depth_array`` is the raw detector output. When the
            detector is unavailable or raises, a grayscale RGB copy of
            ``image`` and ``None`` are returned instead.
        """
        if self.zoe_depth is None:
            print("[DEPTH] Detector not available, using grayscale")
            return image.convert('L').convert('RGB'), None

        try:
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # safe_image_size is used to avoid numpy.int64 issues with PIL.
            orig_width, orig_height = safe_image_size(image)

            # Round each side down to a multiple of 64 (minimum 64).
            target_width = max(64, int(orig_width // 64) * 64)
            target_height = max(64, int(orig_height // 64) * 64)

            image_for_depth = image.resize((target_width, target_height), Image.LANCZOS)

            depth_array = self.zoe_depth(image_for_depth, detect_resolution=512, image_resolution=1024)
            depth_image = Image.fromarray(depth_array)

            if depth_image.size != image.size:
                depth_image = depth_image.resize(image.size, Image.LANCZOS)

            print(f"[DEPTH] Generated depth map: {depth_image.size}")
            return depth_image, depth_array
        except Exception as e:
            print(f"[DEPTH] Generation failed: {e}, using grayscale")
            return image.convert('L').convert('RGB'), None

    def add_trigger_word(self, prompt):
        """Return ``prompt`` with the LoRA trigger word prepended if missing.

        An empty, whitespace-only or ``None`` prompt becomes just the trigger
        word. The presence check is case-insensitive.
        """
        if not prompt or not prompt.strip():
            return TRIGGER_WORD
        if TRIGGER_WORD.lower() in prompt.lower():
            return prompt
        return f"{TRIGGER_WORD}, {prompt}"

    @staticmethod
    def _face_area(face):
        """Return the area of ``face.bbox``, read as ``(x1, y1, x2, y2)``."""
        x1, y1, x2, y2 = face.bbox[:4]
        return float((x2 - x1) * (y2 - y1))

    def detect_face_quality(self, face):
        """Pick adaptive generation parameters based on face quality.

        Checks, in priority order: small face area, low detection confidence,
        then a large yaw angle (profile view).

        Returns:
            A copy of the matching ``ADAPTIVE_PARAMS`` entry, or ``None`` when
            no adjustment applies or the check itself fails.
        """
        try:
            face_size = self._face_area(face)
            det_score = float(face.det_score) if hasattr(face, 'det_score') else 1.0

            # Small face -> boost preservation
            if face_size < ADAPTIVE_THRESHOLDS['small_face_size']:
                return ADAPTIVE_PARAMS['small_face'].copy()

            # Low confidence -> boost preservation
            if det_score < ADAPTIVE_THRESHOLDS['low_confidence']:
                return ADAPTIVE_PARAMS['low_confidence'].copy()

            # Profile view: pose[1] is treated as yaw
            if hasattr(face, 'pose') and len(face.pose) > 1:
                try:
                    yaw = float(face.pose[1])
                    if abs(yaw) > ADAPTIVE_THRESHOLDS['profile_angle']:
                        return ADAPTIVE_PARAMS['profile_view'].copy()
                except (ValueError, TypeError, IndexError):
                    pass

            return None

        except Exception as e:
            print(f"[ADAPTIVE] Quality detection failed: {e}")
            return None

    def generate_caption(self, image):
        """Caption ``image`` with the loaded GIT or BLIP model.

        Returns:
            The sanitized caption, or ``None`` when captioning is disabled,
            the model type is unknown, or generation fails.
        """
        if not self.caption_enabled or self.caption_model is None:
            return None

        try:
            if self.caption_model_type == 'git':
                inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
                generated_ids = self.caption_model.generate(**inputs, max_length=CAPTION_CONFIG['max_length'])
                caption = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            elif self.caption_model_type == 'blip':
                inputs = self.caption_processor(image, return_tensors="pt").to(self.device)
                generated_ids = self.caption_model.generate(**inputs, max_length=CAPTION_CONFIG['max_length'])
                caption = self.caption_processor.decode(generated_ids[0], skip_special_tokens=True)
            else:
                return None

            return sanitize_text(caption)
        except Exception as e:
            print(f"[CAPTION] Generation failed: {e}")
            return None

    def _detect_primary_face(self, image):
        """Detect faces in ``image`` and prepare InstantID inputs for the largest one.

        Args:
            image: RGB PIL image (the same image that is fed to the pipeline).

        Returns:
            ``None`` when face detection is disabled, finds nothing, or fails
            at any step; otherwise a dict with keys ``face`` (the raw
            detection), ``embeddings`` (``face.normed_embedding``),
            ``kps_image`` (the keypoint drawing) and ``bbox``. The result is
            all-or-nothing, so callers never see half-initialized face data.
        """
        if not self.face_detection_enabled or self.face_app is None:
            return None

        try:
            image_array = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            faces = self.face_app.get(image_array)

            if len(faces) == 0:
                print("[FACE] No faces detected")
                return None

            # Condition on the largest face (as the InstantID reference code does)
            # instead of whichever face the detector happens to list first.
            face = max(faces, key=self._face_area)
            if len(faces) > 1:
                print(f"[FACE] {len(faces)} faces detected, using the largest")

            from pipeline_stable_diffusion_xl_instantid_img2img import draw_kps
            kps_image = draw_kps(image, face.kps)

            return {
                "face": face,
                "embeddings": face.normed_embedding,
                "kps_image": kps_image,
                "bbox": face.bbox,
            }
        except Exception as e:
            print(f"[FACE] Detection failed: {e}")
            return None

    def _build_generator(self, seed):
        """Create a ``torch.Generator`` on ``self.device``.

        Args:
            seed: An int-like seed. ``None`` or any negative value (the UI
                uses ``-1``) selects a fresh random seed. Floats such as
                ``42.0`` are accepted and truncated to int.

        Returns:
            ``(generator, actual_seed)``.
        """
        generator = torch.Generator(device=self.device)
        if seed is None or int(seed) < 0:
            actual_seed = generator.seed()
            print(f"[SEED] Random: {actual_seed}")
        else:
            actual_seed = int(seed)
            generator.manual_seed(actual_seed)
            print(f"[SEED] Fixed: {actual_seed}")
        return generator, actual_seed

    def _encode_prompts(self, prompt, negative_prompt):
        """Return the prompt-related pipeline kwargs.

        Uses Compel embeddings when available; it falls back to raw
        ``prompt`` / ``negative_prompt`` strings if Compel is disabled or
        raises.
        """
        if self.use_compel and self.compel is not None:
            try:
                conditioning = self.compel(prompt)
                negative_conditioning = self.compel(negative_prompt)

                prompt_embeds = conditioning[0]
                negative_embeds = negative_conditioning[0]
                # Long prompts can encode to different token lengths, and the
                # pipeline rejects prompt/negative embeddings of mismatched
                # shape. Padding is a no-op when the lengths already agree.
                pad = getattr(self.compel, "pad_conditioning_tensors_to_same_length", None)
                if pad is not None:
                    prompt_embeds, negative_embeds = pad([prompt_embeds, negative_embeds])

                kwargs = {
                    "prompt_embeds": prompt_embeds,
                    "pooled_prompt_embeds": conditioning[1],
                    "negative_prompt_embeds": negative_embeds,
                    "negative_pooled_prompt_embeds": negative_conditioning[1],
                }
                print("[OK] Using Compel-encoded prompts")
                return kwargs
            except Exception as e:
                print(f"[COMPEL] Failed, using standard prompts: {e}")

        return {"prompt": prompt, "negative_prompt": negative_prompt}

    def _apply_color_matching(self, generated_image, reference_image, face_bbox):
        """Color-match ``generated_image`` to ``reference_image``.

        Uses face-aware matching when ``face_bbox`` is known; otherwise uses
        standard MKL matching. On failure, the unmodified image is returned.
        """
        if face_bbox is not None:
            print("Applying enhanced face-aware color matching...")
            try:
                matched = enhanced_color_match(
                    generated_image,
                    reference_image,
                    face_bbox=face_bbox
                )
                print("[OK] Enhanced color matching applied")
                return matched
            except Exception as e:
                print(f"[COLOR] Matching failed: {e}")
                return generated_image

        print("Applying standard color matching...")
        try:
            matched = color_match(generated_image, reference_image, mode='mkl')
            print("[OK] Standard color matching applied")
            return matched
        except Exception as e:
            print(f"[COLOR] Matching failed: {e}")
            return generated_image

    def generate_retro_art(
        self,
        input_image,
        prompt=" ",
        negative_prompt=" ",
        num_inference_steps=12,
        guidance_scale=1.3,
        depth_control_scale=0.75,
        identity_control_scale=0.85,
        lora_scale=1.0,
        identity_preservation=1.2,
        strength=0.50,
        enable_color_matching=False,
        consistency_mode=True,
        seed=-1
    ):
        """Generate retro art with InstantID face preservation.

        Args:
            input_image: Source PIL image; converted to RGB if needed.
            prompt: Positive prompt; the LoRA trigger word is added if missing.
            negative_prompt: Negative prompt.
            num_inference_steps: Passed to the pipeline.
            guidance_scale: CFG scale (may be overridden by adaptive params).
            depth_control_scale: Conditioning scale of the depth ControlNet.
            identity_control_scale: Conditioning scale of the InstantID
                keypoint ControlNet (may be overridden by adaptive params).
            lora_scale: Weight of the ``retroart`` LoRA adapter (may be
                overridden by adaptive params).
            identity_preservation: Passed as ``ip_adapter_scale`` alongside
                the face embeddings (may be overridden by adaptive params).
            strength: img2img strength.
            enable_color_matching: Color-match the output to the resized input.
            consistency_mode: Accepted for API compatibility; currently unused.
            seed: Fixed seed, or ``-1`` / ``None`` / any negative for random.

        Returns:
            The first image of the pipeline result.

        Raises:
            Any exception from the diffusion pipeline call propagates. GPU
            cache and GC cleanup still run.
        """
        try:
            # Face analysis (cv2 RGB->BGR) and the pipeline both expect 3-channel RGB;
            # RGBA/L/P uploads would otherwise break face detection silently.
            if input_image.mode != "RGB":
                input_image = input_image.convert("RGB")

            # Add trigger word
            prompt = sanitize_text(self.add_trigger_word(prompt))
            negative_prompt = sanitize_text(negative_prompt)

            print(f"[PROMPT] {prompt}")

            # Calculate optimal size; cast to plain ints because PIL rejects numpy ints
            orig_width, orig_height = safe_image_size(input_image)
            optimal_width, optimal_height = calculate_optimal_size(orig_width, orig_height)
            optimal_width, optimal_height = int(optimal_width), int(optimal_height)

            resized_image = input_image.resize((optimal_width, optimal_height), Image.LANCZOS)
            print(f"[SIZE] Resized to {optimal_width}x{optimal_height}")

            # Generate depth map (raw array is not needed here)
            depth_image, _ = self.get_depth_map(resized_image)

            # Detect faces
            face_info = self._detect_primary_face(resized_image)
            face_embeddings = None
            face_bbox_original = None

            if face_info is not None:
                face = face_info["face"]
                face_embeddings = face_info["embeddings"]
                face_bbox_original = face_info["bbox"]

                # Adaptive parameter adjustment
                adaptive_params = self.detect_face_quality(face)
                if adaptive_params:
                    print(f"[ADAPTIVE] {adaptive_params.get('reason')}")
                    identity_preservation = adaptive_params.get('identity_preservation', identity_preservation)
                    identity_control_scale = adaptive_params.get('identity_control_scale', identity_control_scale)
                    guidance_scale = adaptive_params.get('guidance_scale', guidance_scale)
                    lora_scale = adaptive_params.get('lora_scale', lora_scale)

                det_score = getattr(face, 'det_score', None)
                if det_score is not None:
                    print(f"[FACE] Detected face with {float(det_score):.2f} confidence")
                if face_embeddings is not None:
                    print(f"[FACE] Embeddings shape: {np.shape(face_embeddings)}")

            # Set LORA scale
            if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
                try:
                    self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
                    print(f"[LORA] Scale: {lora_scale}")
                except Exception as e:
                    print(f"[LORA] Could not set scale: {e}")

            # Prepare generation kwargs
            generator, _ = self._build_generator(seed)
            pipe_kwargs = {
                "image": resized_image,
                "strength": strength,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
                "generator": generator,
            }
            pipe_kwargs.update(self._encode_prompts(prompt, negative_prompt))

            # Configure ControlNets + IP-Adapter. Both lists are ordered
            # [identity, depth] to match the ControlNet list built in __init__.
            if face_info is not None:
                print("Using InstantID (keypoints + embeddings) + Depth ControlNets")

                pipe_kwargs["control_image"] = [face_info["kps_image"], depth_image]
                pipe_kwargs["controlnet_conditioning_scale"] = [
                    identity_control_scale,
                    depth_control_scale
                ]

                if face_embeddings is not None:
                    print("Adding face embeddings for IP-Adapter...")
                    pipe_kwargs["image_embeds"] = face_embeddings
                    pipe_kwargs["ip_adapter_scale"] = identity_preservation

                    print(f"  - Face embeddings shape: {np.shape(face_embeddings)}")
                    print(f"  - IP-Adapter scale: {identity_preservation}")
                    print("  [OK] Face embeddings configured")
                else:
                    print("  [WARNING] No face embeddings - using keypoints only")

            else:
                print("No faces detected - using Depth ControlNet only")

                # Depth fills both slots; the identity slot is disabled with scale 0
                pipe_kwargs["control_image"] = [depth_image, depth_image]
                pipe_kwargs["controlnet_conditioning_scale"] = [0.0, depth_control_scale]

            # Generate
            print(f"Generating: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}")
            result = self.pipe(**pipe_kwargs)

            generated_image = result.images[0]

            # Post-processing: Color matching
            if enable_color_matching:
                generated_image = self._apply_color_matching(
                    generated_image, resized_image, face_bbox_original
                )

            return generated_image

        finally:
            # Memory cleanup
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()


# Import-time side effect: confirms this module (and RetroArtConverter) loaded.
print("[OK] Generator class ready with InstantID support")