File size: 35,373 Bytes
f179fb3
fe30f16
f179fb3
 
 
 
 
 
 
 
fe30f16
a70cb97
fe30f16
 
 
 
 
 
 
 
176aa63
fe30f16
 
f179fb3
 
 
fe30f16
f179fb3
 
 
 
 
 
 
 
69e6233
f179fb3
 
 
fe30f16
f179fb3
 
69e6233
 
 
f179fb3
fe30f16
f179fb3
 
 
 
 
fe30f16
f179fb3
 
 
 
 
fe30f16
f179fb3
 
fe30f16
f179fb3
 
fe30f16
f179fb3
fe30f16
f179fb3
 
 
fe30f16
f179fb3
 
 
fe30f16
f179fb3
 
 
 
fe30f16
f179fb3
 
 
176aa63
 
f179fb3
fe30f16
f179fb3
 
fe30f16
f179fb3
 
fe30f16
29a6101
f179fb3
29a6101
a70cb97
29a6101
a70cb97
29a6101
a70cb97
29a6101
 
 
a70cb97
fe30f16
f179fb3
 
fe30f16
f179fb3
fe30f16
 
 
f179fb3
fe30f16
 
f179fb3
 
fe30f16
f179fb3
 
fe30f16
f179fb3
fe30f16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f179fb3
 
69e6233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe30f16
 
 
69e6233
f179fb3
 
fe30f16
f179fb3
 
 
 
 
fe30f16
 
 
 
f179fb3
 
fe30f16
f179fb3
fe30f16
f179fb3
 
 
fe30f16
 
f179fb3
fe30f16
 
f179fb3
 
fe30f16
f179fb3
 
 
fe30f16
f179fb3
 
fe30f16
f179fb3
fe30f16
f179fb3
fe30f16
f179fb3
fe30f16
f179fb3
fe30f16
f179fb3
 
 
fe30f16
 
 
 
f179fb3
 
 
 
 
fe30f16
f179fb3
 
fe30f16
 
f179fb3
 
fe30f16
 
f179fb3
 
 
 
 
fe30f16
f179fb3
fe30f16
 
f179fb3
fe30f16
 
 
f179fb3
 
 
 
 
fe30f16
 
 
f179fb3
fe30f16
f179fb3
 
fe30f16
f179fb3
 
 
 
fe30f16
f179fb3
fe30f16
f179fb3
fe30f16
f179fb3
fe30f16
f179fb3
fe30f16
f179fb3
fe30f16
f179fb3
fe30f16
 
 
 
 
 
f179fb3
fe30f16
f179fb3
fe30f16
f179fb3
fe30f16
f179fb3
fe30f16
f179fb3
fe30f16
 
 
 
 
 
 
f179fb3
fe30f16
f179fb3
 
fe30f16
 
f179fb3
fe30f16
 
 
 
 
 
 
 
 
 
 
f179fb3
fe30f16
 
 
 
 
f179fb3
 
 
 
a70cb97
f179fb3
 
 
a70cb97
f179fb3
a70cb97
 
 
 
 
 
 
f179fb3
 
 
 
a70cb97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe30f16
a70cb97
fe30f16
f179fb3
fe30f16
f179fb3
 
 
 
 
fe30f16
 
f179fb3
 
 
 
 
 
 
 
 
 
 
fe30f16
f179fb3
fe30f16
f179fb3
 
 
fe30f16
f179fb3
fe30f16
f179fb3
 
 
 
 
 
fe30f16
f179fb3
 
a70cb97
f179fb3
a70cb97
f179fb3
fe30f16
 
 
 
 
f179fb3
 
fe30f16
69e6233
f179fb3
 
 
 
fe30f16
f179fb3
 
 
 
 
 
 
 
fe30f16
f179fb3
 
 
 
 
fe30f16
 
 
f179fb3
 
fe30f16
f179fb3
 
 
 
 
 
 
 
fe30f16
f179fb3
 
fe30f16
f179fb3
 
 
 
fe30f16
f179fb3
 
 
 
 
 
 
 
 
fe30f16
f179fb3
fe30f16
 
f179fb3
fe30f16
 
f179fb3
 
fe30f16
f179fb3
 
 
fe30f16
 
f179fb3
fe30f16
 
f179fb3
fe30f16
 
 
 
 
 
 
 
 
f179fb3
69e6233
 
f179fb3
69e6233
 
11f9790
69e6233
f179fb3
fe30f16
f179fb3
 
 
 
 
 
 
fe30f16
f179fb3
 
 
fe30f16
f179fb3
 
 
fe30f16
f179fb3
 
 
176aa63
 
 
 
69e6233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176aa63
69e6233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176aa63
69e6233
 
 
 
176aa63
69e6233
176aa63
 
69e6233
 
 
f179fb3
fe30f16
f179fb3
 
 
fe30f16
f179fb3
fe30f16
f179fb3
 
fe30f16
f179fb3
 
 
fe30f16
f179fb3
29a6101
fe30f16
f179fb3
29a6101
 
fe30f16
 
29a6101
 
 
 
f179fb3
69e6233
29a6101
fe30f16
29a6101
 
 
 
 
 
 
 
69e6233
176aa63
69e6233
176aa63
 
 
69e6233
176aa63
 
 
 
 
 
 
 
 
 
 
 
69e6233
 
176aa63
69e6233
f179fb3
fe30f16
29a6101
fe30f16
29a6101
f179fb3
fe30f16
 
 
 
 
 
 
f179fb3
 
fe30f16
 
 
29a6101
fe30f16
 
 
 
 
f179fb3
 
 
fe30f16
f179fb3
fe30f16
f179fb3
 
fe30f16
 
 
 
 
 
f179fb3
 
fe30f16
 
 
f179fb3
fe30f16
f179fb3
 
fe30f16
 
 
f179fb3
 
 
 
69e6233
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
"""
Generation logic for Pixagram AI Pixel Art Generator
"""
import torch
import numpy as np
import cv2
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms

from config import (
    device, dtype, TRIGGER_WORD, MULTI_SCALE_FACTORS,
    ADAPTIVE_THRESHOLDS, ADAPTIVE_PARAMS, CAPTION_CONFIG, IDENTITY_BOOST_MULTIPLIER
)
from utils import (
    sanitize_text, enhanced_color_match, color_match, create_face_mask,
    draw_kps, get_demographic_description, calculate_optimal_size, enhance_face_crop
)
from models import (
    load_face_analysis, load_depth_detector, load_controlnets, load_image_encoder,
    load_sdxl_pipeline, load_lora, setup_ip_adapter, setup_compel,
    setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip
)


class RetroArtConverter:
    """Main class for retro art generation"""
    
    def __init__(self):
        self.device = device
        self.dtype = dtype
        self.models_loaded = {
            'custom_checkpoint': False,
            'lora': False,
            'instantid': False,
            'zoe_depth': False,
            'ip_adapter': False
        }
        
        # Initialize face analysis
        self.face_app, self.face_detection_enabled = load_face_analysis()
        
        # Load Zoe Depth detector
        self.zoe_depth, zoe_success = load_depth_detector()
        self.models_loaded['zoe_depth'] = zoe_success
        
        # Load ControlNets
        controlnet_depth, self.controlnet_instantid, instantid_success = load_controlnets()
        self.controlnet_depth = controlnet_depth
        self.instantid_enabled = instantid_success
        self.models_loaded['instantid'] = instantid_success
        
        # Load image encoder
        if self.instantid_enabled:
            self.image_encoder = load_image_encoder()
        else:
            self.image_encoder = None
        
        # Determine which controlnets to use
        if self.instantid_enabled and self.controlnet_instantid is not None:
            controlnets = [self.controlnet_instantid, controlnet_depth]
            print(f"Initializing with multiple ControlNets: InstantID + Depth")
        else:
            controlnets = controlnet_depth
            print(f"Initializing with single ControlNet: Depth only")
        
        # Load SDXL pipeline
        self.pipe, checkpoint_success = load_sdxl_pipeline(controlnets)
        self.models_loaded['custom_checkpoint'] = checkpoint_success
        
        # Load LORA
        lora_success = load_lora(self.pipe)
        self.models_loaded['lora'] = lora_success
        
        # Setup IP-Adapter
        if self.instantid_enabled and self.image_encoder is not None:
            self.image_proj_model, ip_adapter_success = setup_ip_adapter(self.pipe, self.image_encoder)
            self.models_loaded['ip_adapter'] = ip_adapter_success
        else:
            print("[INFO] Face preservation: InstantID ControlNet keypoints only")
            self.models_loaded['ip_adapter'] = False
            self.image_proj_model = None
        
        # Setup Compel
        self.compel, self.use_compel = setup_compel(self.pipe)
        
        # Setup LCM scheduler
        setup_scheduler(self.pipe)
        
        # Optimize pipeline
        optimize_pipeline(self.pipe)
        
        # Load caption model
        self.caption_processor, self.caption_model, self.caption_enabled, self.caption_model_type = load_caption_model()
        
        # Report caption model status
        if self.caption_enabled and self.caption_model is not None:
            if self.caption_model_type == "git":
                print("  [OK] Using GIT for detailed captions")
            elif self.caption_model_type == "blip":
                print("  [OK] Using BLIP for standard captions")
            else:
                print("  [OK] Caption model loaded")
        
        
        # Set CLIP skip
        set_clip_skip(self.pipe)
        
        # Track controlnet configuration
        self.using_multiple_controlnets = isinstance(controlnets, list)
        print(f"Pipeline initialized with {'multiple' if self.using_multiple_controlnets else 'single'} ControlNet(s)")
        
        # Print model status
        self._print_status()
        
        print("  [OK] Model initialization complete!")
    
    def _print_status(self):
        """Print model loading status"""
        print("\n=== MODEL STATUS ===")
        for model, loaded in self.models_loaded.items():
            status = "[OK] LOADED" if loaded else "[FALLBACK/DISABLED]"
            print(f"{model}: {status}")
        print("===================\n")
        
        print("=== UPGRADE VERIFICATION ===")
        try:
            from resampler_enhanced import EnhancedResampler
            from ip_attention_processor_enhanced import EnhancedIPAttnProcessor2_0
            
            resampler_check = isinstance(self.image_proj_model, EnhancedResampler) if hasattr(self, 'image_proj_model') and self.image_proj_model is not None else False
            custom_attn_check = any(isinstance(p, EnhancedIPAttnProcessor2_0) for p in self.pipe.unet.attn_processors.values()) if hasattr(self, 'pipe') else False
            
            print(f"Enhanced Perceiver Resampler: {'[OK] ACTIVE' if resampler_check else '[INFO] Not active'}")
            print(f"Enhanced IP-Adapter Attention: {'[OK] ACTIVE' if custom_attn_check else '[INFO] Not active'}")
            
            if resampler_check and custom_attn_check:
                print("[SUCCESS] Face preservation upgrade fully active")
                print("  Expected improvement: +10-15% face similarity")
            elif resampler_check or custom_attn_check:
                print("[PARTIAL] Some upgrades active")
            else:
                print("[INFO] Using standard components")
        except Exception as e:
            print(f"[INFO] Verification skipped: {e}")
        print("============================\n")
    
    def get_depth_map(self, image):
        """Generate depth map using Zoe Depth"""
        if self.zoe_depth is not None:
            try:
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                
                orig_width, orig_height = image.size
                orig_width = int(orig_width)
                orig_height = int(orig_height)
                
                # FIXED: Use multiples of 64 (not 32)
                target_width = int((orig_width // 64) * 64)
                target_height = int((orig_height // 64) * 64)
                
                target_width = int(max(64, target_width))
                target_height = int(max(64, target_height))
                
                if target_width != orig_width or target_height != orig_height:
                    image = image.resize((int(target_width), int(target_height)), Image.LANCZOS)
                    print(f"[DEPTH] Resized for ZoeDetector: {orig_width}x{orig_height} -> {target_width}x{target_height}")
                
                # FIXED: Add torch.no_grad() wrapper
                with torch.no_grad():
                    depth_image = self.zoe_depth(image)
                
                depth_width, depth_height = depth_image.size
                if depth_width != orig_width or depth_height != orig_height:
                    depth_image = depth_image.resize((int(orig_width), int(orig_height)), Image.LANCZOS)
                
                print(f"[DEPTH] Zoe depth map generated: {orig_width}x{orig_height}")
                return depth_image
                
            except Exception as e:
                print(f"[DEPTH] ZoeDetector failed ({e}), falling back to grayscale depth")
                gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
                depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
                return Image.fromarray(depth_colored)
        else:
            gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
            depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
            return Image.fromarray(depth_colored)

    
    def add_trigger_word(self, prompt):
        """Add trigger word to prompt if not present"""
        if TRIGGER_WORD.lower() not in prompt.lower():
            return f"{TRIGGER_WORD}, {prompt}"
        return prompt
    
    def extract_multi_scale_face(self, face_crop, face):
        """
        Extract face features at multiple scales for better detail.
        +1-2% improvement in face preservation.
        """
        try:
            multi_scale_embeds = []
            
            for scale in MULTI_SCALE_FACTORS:
                # Resize
                w, h = face_crop.size
                scaled_size = (int(w * scale), int(h * scale))
                scaled_crop = face_crop.resize(scaled_size, Image.LANCZOS)
                
                # Pad/crop back to original
                scaled_crop = scaled_crop.resize((w, h), Image.LANCZOS)
                
                # Extract features
                scaled_array = cv2.cvtColor(np.array(scaled_crop), cv2.COLOR_RGB2BGR)
                scaled_faces = self.face_app.get(scaled_array)
                
                if len(scaled_faces) > 0:
                    multi_scale_embeds.append(scaled_faces[0].normed_embedding)
            
            # Average embeddings
            if len(multi_scale_embeds) > 0:
                averaged = np.mean(multi_scale_embeds, axis=0)
                # Renormalize
                averaged = averaged / np.linalg.norm(averaged)
                print(f"[MULTI-SCALE] Combined {len(multi_scale_embeds)} scales")
                return averaged
            
            return face.normed_embedding
        
        except Exception as e:
            print(f"[MULTI-SCALE] Failed: {e}, using single scale")
            return face.normed_embedding
    
    def detect_face_quality(self, face):
        """
        Detect face quality and adaptively adjust parameters.
        +2-3% consistency improvement.
        """
        try:
            bbox = face.bbox
            face_size = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
            det_score = float(face.det_score) if hasattr(face, 'det_score') else 1.0
            
            # Small face -> boost identity preservation
            if face_size < ADAPTIVE_THRESHOLDS['small_face_size']:
                return ADAPTIVE_PARAMS['small_face'].copy()
            
            # Low confidence -> boost preservation
            elif det_score < ADAPTIVE_THRESHOLDS['low_confidence']:
                return ADAPTIVE_PARAMS['low_confidence'].copy()
            
            # Check for profile/side view (if pose available)
            elif hasattr(face, 'pose') and len(face.pose) > 1:
                try:
                    yaw = float(face.pose[1])
                    if abs(yaw) > ADAPTIVE_THRESHOLDS['profile_angle']:
                        return ADAPTIVE_PARAMS['profile_view'].copy()
                except (ValueError, TypeError, IndexError):
                    pass
            
            # Good quality face - use provided parameters
            return None
        
        except Exception as e:
            print(f"[ADAPTIVE] Quality detection failed: {e}")
            return None
    
    def validate_and_adjust_parameters(self, strength, guidance_scale, lora_scale, 
                                       identity_preservation, identity_control_scale,
                                       depth_control_scale, consistency_mode=True):
        """
        Enhanced parameter validation with stricter rules for consistency.
        """
        if consistency_mode:
            print("[CONSISTENCY] Applying strict parameter validation...")
            adjustments = []
            
            # Rule 1: Strong inverse relationship between identity and LORA
            if identity_preservation > 1.2:
                original_lora = lora_scale
                lora_scale = min(lora_scale, 1.0)
                if abs(lora_scale - original_lora) > 0.01:
                    adjustments.append(f"LORA: {original_lora:.2f}->{lora_scale:.2f} (high identity)")
            
            # Rule 2: Strength-based profile activation
            if strength < 0.5:
                # Maximum preservation mode
                if identity_preservation < 1.3:
                    original_identity = identity_preservation
                    identity_preservation = 1.3
                    adjustments.append(f"Identity: {original_identity:.2f}->{identity_preservation:.2f} (max preservation)")
                if lora_scale > 0.9:
                    original_lora = lora_scale
                    lora_scale = 0.9
                    adjustments.append(f"LORA: {original_lora:.2f}->{lora_scale:.2f} (max preservation)")
                if guidance_scale > 1.3:
                    original_cfg = guidance_scale
                    guidance_scale = 1.3
                    adjustments.append(f"CFG: {original_cfg:.2f}->{guidance_scale:.2f} (max preservation)")
                    
            elif strength > 0.7:
                # Artistic transformation mode
                if identity_preservation > 1.0:
                    original_identity = identity_preservation
                    identity_preservation = 1.0
                    adjustments.append(f"Identity: {original_identity:.2f}->{identity_preservation:.2f} (artistic mode)")
                if lora_scale < 1.2:
                    original_lora = lora_scale
                    lora_scale = 1.2
                    adjustments.append(f"LORA: {original_lora:.2f}->{lora_scale:.2f} (artistic mode)")
            
            # Rule 3: CFG-LORA relationship
            if guidance_scale > 1.4 and lora_scale > 1.2:
                original_lora = lora_scale
                lora_scale = 1.1
                adjustments.append(f"LORA: {original_lora:.2f}->{lora_scale:.2f} (high CFG detected)")
            
            # Rule 4: LCM sweet spot enforcement
            original_cfg = guidance_scale
            guidance_scale = max(1.0, min(guidance_scale, 1.5))
            if abs(guidance_scale - original_cfg) > 0.01:
                adjustments.append(f"CFG: {original_cfg:.2f}->{guidance_scale:.2f} (LCM optimal)")
            
            # Rule 5: ControlNet balance
            total_control = identity_control_scale + depth_control_scale
            if total_control > 1.7:
                scale_factor = 1.7 / total_control
                original_id_ctrl = identity_control_scale
                original_depth_ctrl = depth_control_scale
                identity_control_scale *= scale_factor
                depth_control_scale *= scale_factor
                adjustments.append(f"ControlNets balanced: ID {original_id_ctrl:.2f}->{identity_control_scale:.2f}, Depth {original_depth_ctrl:.2f}->{depth_control_scale:.2f}")
            
            # Report adjustments
            if adjustments:
                print("  [OK] Applied adjustments:")
                for adj in adjustments:
                    print(f"    - {adj}")
            else:
                print("  [OK] Parameters already optimal")
        
        return strength, guidance_scale, lora_scale, identity_preservation, identity_control_scale, depth_control_scale
    
    def generate_caption(self, image, max_length=None, num_beams=None):
        """Generate a descriptive caption for the image (supports BLIP-2, GIT, BLIP)."""
        if not self.caption_enabled or self.caption_model is None:
            return None
        
        # Set defaults based on model type
        if max_length is None:
            if self.caption_model_type == "blip2":
                max_length = 50  # BLIP-2 can handle longer captions
            elif self.caption_model_type == "git":
                max_length = 40  # GIT also produces good long captions
            else:
                max_length = CAPTION_CONFIG['max_length']  # BLIP base (20)
        
        if num_beams is None:
            num_beams = CAPTION_CONFIG['num_beams']
        
        try:
            if self.caption_model_type == "blip2":
                # BLIP-2 specific processing
                inputs = self.caption_processor(image, return_tensors="pt").to(self.device, self.dtype)
                
                with torch.no_grad():
                    output = self.caption_model.generate(
                        **inputs,
                        max_length=max_length,
                        num_beams=num_beams,
                        min_length=10,  # Encourage longer captions
                        length_penalty=1.0,
                        repetition_penalty=1.5,
                        early_stopping=True
                    )
                
                caption = self.caption_processor.decode(output[0], skip_special_tokens=True)
                
            elif self.caption_model_type == "git":
                # GIT specific processing
                inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device, self.dtype)
                
                with torch.no_grad():
                    output = self.caption_model.generate(
                        pixel_values=inputs.pixel_values,
                        max_length=max_length,
                        num_beams=num_beams,
                        min_length=10,
                        length_penalty=1.0,
                        repetition_penalty=1.5,
                        early_stopping=True
                    )
                
                caption = self.caption_processor.batch_decode(output, skip_special_tokens=True)[0]
                
            else:
                # BLIP base processing
                inputs = self.caption_processor(image, return_tensors="pt").to(self.device, self.dtype)
                
                with torch.no_grad():
                    output = self.caption_model.generate(
                        **inputs,
                        max_length=max_length,
                        num_beams=num_beams,
                        early_stopping=True
                    )
                
                caption = self.caption_processor.decode(output[0], skip_special_tokens=True)
            
            return caption.strip()
        
        except Exception as e:
            print(f"Caption generation failed: {e}")
            return None
    
    def generate_retro_art(
        self,
        input_image,
        prompt="retro game character, vibrant colors, detailed",
        negative_prompt="blurry, low quality, ugly, distorted",
        num_inference_steps=12,
        guidance_scale=1.0,
        depth_control_scale=0.8,
        identity_control_scale=0.85,
        lora_scale=1.0,
        identity_preservation=0.8,
        strength=0.75,
        enable_color_matching=False,
        consistency_mode=True,
        seed=-1
    ):
        """Generate retro art with img2img pipeline and enhanced InstantID"""
        
        # Sanitize text inputs
        prompt = sanitize_text(prompt)
        negative_prompt = sanitize_text(negative_prompt)
        
        # Apply parameter validation
        if consistency_mode:
            print("\n[CONSISTENCY] Validating and adjusting parameters...")
            strength, guidance_scale, lora_scale, identity_preservation, identity_control_scale, depth_control_scale = \
                self.validate_and_adjust_parameters(
                    strength, guidance_scale, lora_scale, identity_preservation, 
                    identity_control_scale, depth_control_scale, consistency_mode
                )
        
        # Add trigger word
        prompt = self.add_trigger_word(prompt)
        
        # Calculate optimal size with flexible aspect ratio support
        original_width, original_height = input_image.size
        target_width, target_height = calculate_optimal_size(original_width, original_height)
        
        print(f"Resizing from {original_width}x{original_height} to {target_width}x{target_height}")
        print(f"Prompt: {prompt}")
        print(f"Img2Img Strength: {strength}")
        
        # Resize with high quality
        resized_image = input_image.resize((int(target_width), int(target_height)), Image.LANCZOS)
        
        # Generate depth map
        print("Generating Zoe depth map...")
        depth_image = self.get_depth_map(resized_image)
        if depth_image.size != (target_width, target_height):
            depth_image = depth_image.resize((int(target_width), int(target_height)), Image.LANCZOS)
        
        # Handle face detection
        using_multiple_controlnets = self.using_multiple_controlnets
        face_kps_image = None
        face_embeddings = None
        face_crop_enhanced = None
        has_detected_faces = False
        face_bbox_original = None
        
        if using_multiple_controlnets and self.face_app is not None:
            print("Detecting faces and extracting keypoints...")
            img_array = cv2.cvtColor(np.array(resized_image), cv2.COLOR_RGB2BGR)
            faces = self.face_app.get(img_array)
            
            if len(faces) > 0:
                has_detected_faces = True
                print(f"Detected {len(faces)} face(s)")
                
                # Get largest face
                face = sorted(faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[-1]
                
                # ADAPTIVE PARAMETERS
                adaptive_params = self.detect_face_quality(face)
                if adaptive_params is not None:
                    print(f"[ADAPTIVE] {adaptive_params['reason']}")
                    identity_preservation = adaptive_params['identity_preservation']
                    identity_control_scale = adaptive_params['identity_control_scale']
                    guidance_scale = adaptive_params['guidance_scale']
                    lora_scale = adaptive_params['lora_scale']
                
                # Extract face embeddings
                face_embeddings_base = face.normed_embedding
                
                # Extract face crop
                bbox = face.bbox.astype(int)
                x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
                face_bbox_original = [x1, y1, x2, y2]
                
                # Add padding
                face_width = x2 - x1
                face_height = y2 - y1
                padding_x = int(face_width * 0.3)
                padding_y = int(face_height * 0.3)
                x1 = max(0, x1 - padding_x)
                y1 = max(0, y1 - padding_y)
                x2 = min(resized_image.width, x2 + padding_x)
                y2 = min(resized_image.height, y2 + padding_y)
                
                # Crop face region
                face_crop = resized_image.crop((x1, y1, x2, y2))
                
                # MULTI-SCALE PROCESSING
                face_embeddings = self.extract_multi_scale_face(face_crop, face)
                
                # Enhance face crop
                face_crop_enhanced = enhance_face_crop(face_crop)
                
                # Draw keypoints
                face_kps = face.kps
                face_kps_image = draw_kps(resized_image, face_kps)
                
                # ENHANCED: Extract comprehensive facial attributes
                from utils import get_facial_attributes, build_enhanced_prompt
                facial_attrs = get_facial_attributes(face)
                
                # Update prompt with detected attributes
                prompt = build_enhanced_prompt(prompt, facial_attrs, TRIGGER_WORD)
                
                # Legacy output for compatibility
                age = facial_attrs['age']
                gender_code = facial_attrs['gender']
                det_score = facial_attrs['quality']
                
                gender_str = 'M' if gender_code == 1 else ('F' if gender_code == 0 else 'N/A')
                print(f"Face info: bbox={face.bbox}, age={age if age else 'N/A'}, gender={gender_str}")
                print(f"Face crop size: {face_crop.size}, enhanced: {face_crop_enhanced.size if face_crop_enhanced else 'N/A'}")
        
        # Set LORA scale
        if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
            try:
                self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
                print(f"LORA scale: {lora_scale}")
            except Exception as e:
                print(f"Could not set LORA scale: {e}")
        
        # Prepare generation kwargs
        pipe_kwargs = {
            "image": resized_image,
            "strength": strength,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
        }
        
        # Setup generator with seed control
        if seed == -1:
            generator = torch.Generator(device=self.device)
            actual_seed = generator.seed()
            print(f"[SEED] Using random seed: {actual_seed}")
        else:
            generator = torch.Generator(device=self.device).manual_seed(seed)
            actual_seed = seed
            print(f"[SEED] Using fixed seed: {actual_seed}")
        
        pipe_kwargs["generator"] = generator
        
        if self.use_compel and self.compel is not None:
            try:
                print("Encoding prompts with Compel...")
                
                try:
                    # Tuple unpacking: (prompt_embeds, pooled_prompt_embeds)
                    conditioning = self.compel(prompt)
                    prompt_embeds, pooled_prompt_embeds = conditioning
                    
                    # Handle negative prompt conditionally
                    if negative_prompt and negative_prompt.strip():
                        negative_conditioning = self.compel(negative_prompt)
                        negative_prompt_embeds, negative_pooled_prompt_embeds = negative_conditioning
                    else:
                        # Use zeros for negative
                        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
                        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
                        
                except RuntimeError as e:
                    error_msg = str(e)
                    if ("size of tensor" in error_msg and "must match" in error_msg) or "dimension" in error_msg:
                        print(f"[COMPEL] Token length mismatch detected: {e}")
                        print(f"[COMPEL] Falling back to standard prompt encoding")
                        raise
                    else:
                        raise
                
                # Handle token length mismatch by padding/truncating to 77 tokens
                target_length = 77
                
                if prompt_embeds.shape[1] != target_length or negative_prompt_embeds.shape[1] != target_length:
                    print(f"[COMPEL] Adjusting token lengths: pos={prompt_embeds.shape[1]}, neg={negative_prompt_embeds.shape[1]} -> {target_length}")
                    
                    # Truncate or pad positive embeddings
                    if prompt_embeds.shape[1] > target_length:
                        prompt_embeds = prompt_embeds[:, :target_length, :]
                    elif prompt_embeds.shape[1] < target_length:
                        padding = torch.zeros(
                            prompt_embeds.shape[0], 
                            target_length - prompt_embeds.shape[1], 
                            prompt_embeds.shape[2],
                            dtype=prompt_embeds.dtype,
                            device=prompt_embeds.device
                        )
                        prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)
                    
                    # Truncate or pad negative embeddings
                    if negative_prompt_embeds.shape[1] > target_length:
                        negative_prompt_embeds = negative_prompt_embeds[:, :target_length, :]
                    elif negative_prompt_embeds.shape[1] < target_length:
                        padding = torch.zeros(
                            negative_prompt_embeds.shape[0],
                            target_length - negative_prompt_embeds.shape[1],
                            negative_prompt_embeds.shape[2],
                            dtype=negative_prompt_embeds.dtype,
                            device=negative_prompt_embeds.device
                        )
                        negative_prompt_embeds = torch.cat([negative_prompt_embeds, padding], dim=1)
                
                pipe_kwargs["prompt_embeds"] = prompt_embeds
                pipe_kwargs["pooled_prompt_embeds"] = pooled_prompt_embeds
                pipe_kwargs["negative_prompt_embeds"] = negative_prompt_embeds
                pipe_kwargs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds
                
                compel_success = True
                print("[OK] Using Compel-encoded prompts")
            except Exception as e:
                print(f"[COMPEL] Encoding failed: {e}")
                print(f"[COMPEL] Using standard prompt encoding instead")
                compel_success = False
        
        # Add CLIP skip
        if hasattr(self.pipe, 'text_encoder'):
            pipe_kwargs["clip_skip"] = 2
        
        # Configure ControlNet inputs
        if using_multiple_controlnets and has_detected_faces and face_kps_image is not None:
            print("Using InstantID (keypoints) + Depth ControlNets")
            control_images = [face_kps_image, depth_image]
            conditioning_scales = [identity_control_scale, depth_control_scale]
            
            pipe_kwargs["control_image"] = control_images
            pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
            
            # Add face embeddings for IP-Adapter if available
            if face_embeddings is not None and self.models_loaded.get('ip_adapter', False) and face_crop_enhanced is not None:
                print(f"Processing InstantID face embeddings with Resampler...")
                
                with torch.no_grad():
                    # Convert InsightFace embeddings to tensor
                    face_emb_tensor = torch.from_numpy(face_embeddings).to(
                        device=self.device, 
                        dtype=self.dtype
                    )
                    
                    # Reshape for Resampler: [1, 1, 512]
                    face_emb_tensor = face_emb_tensor.reshape(1, -1, 512)
                    
                    # Pass through Resampler: [1, 1, 512] → [1, 16, 2048]
                    face_proj_embeds = self.image_proj_model(face_emb_tensor)
                    
                    # Scale with identity preservation
                    boosted_scale = identity_preservation * IDENTITY_BOOST_MULTIPLIER
                    face_proj_embeds = face_proj_embeds * boosted_scale
                    
                    print(f"  - Face embedding: {face_emb_tensor.shape}")
                    print(f"  - Resampler output: {face_proj_embeds.shape}")
                    print(f"  - Scale: {boosted_scale:.2f}")
                    
                    # CRITICAL: Concatenate with text embeddings (not separate kwargs!)
                    if 'prompt_embeds' in pipe_kwargs:
                        # Compel encoded prompts
                        original_embeds = pipe_kwargs['prompt_embeds']
                        
                        # Handle CFG (classifier-free guidance)
                        if original_embeds.shape[0] > 1:  # Has negative + positive
                            # Duplicate for negative + positive
                            face_proj_embeds = torch.cat([
                                torch.zeros_like(face_proj_embeds),  # Negative
                                face_proj_embeds                      # Positive
                            ], dim=0)
                        
                        # Concatenate: [batch, text_tokens, 2048] + [batch, 16, 2048]
                        combined_embeds = torch.cat([original_embeds, face_proj_embeds], dim=1)
                        pipe_kwargs['prompt_embeds'] = combined_embeds
                        
                        print(f"  - Text embeds: {original_embeds.shape}")
                        print(f"  - Combined embeds: {combined_embeds.shape}")
                        print(f"  [OK] Face embeddings concatenated successfully!")
                        
                    else:
                        print(f"  [WARNING] Can't concatenate - no prompt_embeds (use Compel)")
            
            elif has_detected_faces and self.models_loaded.get('ip_adapter', False):
                # Face detected but embeddings unavailable
                print("  Face detected but embeddings unavailable, using keypoints only")
                # No need for dummy embeddings with concatenation approach
        
        elif using_multiple_controlnets and not has_detected_faces:
            print("Multiple ControlNets available but no faces detected, using depth only")
            control_images = [depth_image, depth_image]
            conditioning_scales = [0.0, depth_control_scale]
            
            pipe_kwargs["control_image"] = control_images
            pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
        
        else:
            print("Using Depth ControlNet only")
            pipe_kwargs["control_image"] = depth_image
            pipe_kwargs["controlnet_conditioning_scale"] = depth_control_scale
        
        
        # Generate
        print(f"Generating with LCM: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}")
        print(f"Controlnet scales - Identity: {identity_control_scale}, Depth: {depth_control_scale}")
        result = self.pipe(**pipe_kwargs)
        
        generated_image = result.images[0]
        
        # Post-processing
        if enable_color_matching and has_detected_faces:
            print("Applying enhanced face-aware color matching...")
            try:
                if face_bbox_original is not None:
                    generated_image = enhanced_color_match(
                        generated_image, 
                        resized_image, 
                        face_bbox=face_bbox_original
                    )
                    print("[OK] Enhanced color matching applied (face-aware)")
                else:
                    generated_image = color_match(generated_image, resized_image, mode='mkl')
                    print("[OK] Standard color matching applied")
            except Exception as e:
                print(f"Color matching failed: {e}")
        elif enable_color_matching:
            print("Applying standard color matching...")
            try:
                generated_image = color_match(generated_image, resized_image, mode='mkl')
                print("[OK] Standard color matching applied")
            except Exception as e:
                print(f"Color matching failed: {e}")
        
        return generated_image


print("[OK] Generator class ready")