primerz committed on
Commit 6557cf9 · verified · 1 Parent(s): bbcd03c

Update generator.py

Files changed (1)
  1. generator.py +313 -427
generator.py CHANGED
@@ -1,32 +1,31 @@
 """
-Generation logic for Pixagram AI Pixel Art Generator
-CORRECTED VERSION - Following examplewithface.py pattern
 """
 import torch
 import numpy as np
 import cv2
 from PIL import Image
-import gc
 
 from config import (
-    device, dtype, TRIGGER_WORD,
-    ADAPTIVE_THRESHOLDS, ADAPTIVE_PARAMS, CAPTION_CONFIG
 )
 from utils import (
-    sanitize_text, enhanced_color_match, color_match,
-    get_demographic_description, calculate_optimal_size, safe_image_size
 )
-from models import (
-    load_face_analysis, load_depth_detector, load_controlnets,
-    load_sdxl_pipeline, load_lora, setup_compel,
     setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip
 )
-from pipeline_stable_diffusion_xl_instantid_img2img import draw_kps
-from memory_utils import MemoryManager, ModelOffloader
 
 
 class RetroArtConverter:
-    """Main class for retro art generation with InstantID"""
 
     def __init__(self):
         self.device = device
@@ -35,29 +34,38 @@ class RetroArtConverter:
             'custom_checkpoint': False,
             'lora': False,
             'instantid': False,
-            'zoe_depth': False
         }
 
-        # Initialize memory manager
-        self.memory_manager = MemoryManager(device=device, dtype=dtype, verbose=True)
 
-        # Load face analysis (like examplewithface.py line 113)
-        self.face_app, face_detection_success = load_face_analysis()
-        if not face_detection_success or self.face_app is None:
-            raise RuntimeError("[ERROR] Face detection is required! Check InsightFace installation.")
-
-        # Load depth detector (starts on CPU) - single assignment, no alias
         self.zoe_depth, zoe_success = load_depth_detector()
         self.models_loaded['zoe_depth'] = zoe_success
 
-        # Load ControlNets AS LIST
-        controlnet_instantid, controlnet_depth = load_controlnets()
-        controlnets = [controlnet_instantid, controlnet_depth]
-        self.models_loaded['instantid'] = True
 
-        print("Initializing InstantID pipeline with Face + Depth ControlNets")
 
-        # Load SDXL pipeline with InstantID (handles IP-Adapter internally)
         self.pipe, checkpoint_success = load_sdxl_pipeline(controlnets)
         self.models_loaded['custom_checkpoint'] = checkpoint_success
@@ -65,28 +73,51 @@
         lora_success = load_lora(self.pipe)
         self.models_loaded['lora'] = lora_success
 
         # Setup Compel
         self.compel, self.use_compel = setup_compel(self.pipe)
 
-        # Setup scheduler
         setup_scheduler(self.pipe)
 
-        # Optimize
         optimize_pipeline(self.pipe)
 
-        # Load caption model (starts on CPU)
         self.caption_processor, self.caption_model, self.caption_enabled, self.caption_model_type = load_caption_model()
 
         # Set CLIP skip
         set_clip_skip(self.pipe)
 
-        # Print status
-        self._print_status()
 
-        # Initial memory cleanup
-        self.memory_manager.cleanup_memory(aggressive=True)
 
-        print(" [OK] RetroArtConverter initialized with optimized memory management!")
 
     def _print_status(self):
         """Print model loading status"""
@@ -94,20 +125,31 @@
         for model, loaded in self.models_loaded.items():
             status = "[OK] LOADED" if loaded else "[FALLBACK/DISABLED]"
             print(f"{model}: {status}")
-        print("InstantID Pipeline: [OK] ACTIVE")
-        print("IP-Adapter: [OK] Built into pipeline")
-        print(f"Face Detection: [OK] {'READY' if self.face_app else 'UNAVAILABLE'}")
         print("===================\n")
 
     def get_depth_map(self, image):
-        """Generate depth map using Zoe Depth with optimized GPU usage"""
         if self.zoe_depth is not None:
             try:
                 if image.mode != 'RGB':
                     image = image.convert('RGB')
 
-                # Use safe size helper to avoid numpy.int64 issues
-                orig_width, orig_height = safe_image_size(image)
 
                 # Use multiples of 64
                 target_width = int((orig_width // 64) * 64)
@@ -119,412 +161,256 @@
                 size_for_depth = (int(target_width), int(target_height))
                 image_for_depth = image.resize(size_for_depth, Image.LANCZOS)
 
-                # Move depth model to GPU temporarily
-                try:
-                    if torch.cuda.is_available():
-                        self.zoe_depth = self.zoe_depth.to(self.device)
-
-                    # Generate depth map
-                    depth_output = self.zoe_depth(image_for_depth, detect_resolution=512, image_resolution=1024)
-
-                    # Handle different output types
-                    if isinstance(depth_output, Image.Image):
-                        depth_image = depth_output
-                    elif isinstance(depth_output, np.ndarray):
-                        depth_image = Image.fromarray(depth_output.astype(np.uint8))
-                    elif isinstance(depth_output, torch.Tensor):
-                        depth_array = depth_output.cpu().numpy()
-                        if depth_array.ndim == 3 and depth_array.shape[0] == 3:
-                            depth_array = depth_array.transpose(1, 2, 0)
-                        depth_image = Image.fromarray((depth_array * 255).astype(np.uint8))
-                    else:
-                        print(f"[DEPTH] Unexpected output type: {type(depth_output)}")
-                        depth_image = image_for_depth.convert('L').convert('RGB')
-
-                    # Move back to CPU to free GPU memory
-                    if torch.cuda.is_available():
-                        self.zoe_depth = self.zoe_depth.to("cpu")
-                        torch.cuda.empty_cache()
-
-                except Exception as inner_e:
-                    print(f"[DEPTH] GPU processing failed: {inner_e}, trying on CPU")
-                    self.zoe_depth = self.zoe_depth.to("cpu")
-
-                    depth_output = self.zoe_depth(image_for_depth, detect_resolution=512, image_resolution=1024)
-
-                    if isinstance(depth_output, Image.Image):
-                        depth_image = depth_output
-                    elif isinstance(depth_output, np.ndarray):
-                        depth_image = Image.fromarray(depth_output.astype(np.uint8))
-                    else:
-                        depth_image = image_for_depth.convert('L').convert('RGB')
-
-                # Ensure depth image is RGB
-                if depth_image.mode != 'RGB':
-                    depth_image = depth_image.convert('RGB')
 
-                if depth_image.size != image.size:
-                    depth_image = depth_image.resize(image.size, Image.LANCZOS)
-
-                print(f"[DEPTH] Generated depth map: {depth_image.size}")
-                return depth_image
 
             except Exception as e:
-                print(f"[DEPTH] Generation failed: {e}, using grayscale fallback")
-                fallback = image.convert('L').convert('RGB')
-                return fallback
-        else:
-            print("[DEPTH] Detector not available, using grayscale")
-            fallback = image.convert('L').convert('RGB')
-            return fallback
-
-    def add_trigger_word(self, prompt):
-        """Add trigger word to prompt if not present"""
-        if TRIGGER_WORD.lower() not in prompt.lower():
-            if not prompt or not prompt.strip():
-                return TRIGGER_WORD
-            return f"{TRIGGER_WORD}, {prompt}"
-        return prompt
-
-    def detect_face_quality(self, face):
-        """Detect face quality and adaptively adjust parameters"""
-        try:
-            bbox = face.bbox
-            face_size = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
-            det_score = float(face.det_score) if hasattr(face, 'det_score') else 1.0
-
-            # Small face -> boost preservation
-            if face_size < ADAPTIVE_THRESHOLDS['small_face_size']:
-                return ADAPTIVE_PARAMS['small_face'].copy()
-
-            # Low confidence -> boost preservation
-            elif det_score < ADAPTIVE_THRESHOLDS['low_confidence']:
-                return ADAPTIVE_PARAMS['low_confidence'].copy()
-
-            # Check for profile view
-            elif hasattr(face, 'pose') and len(face.pose) > 1:
-                try:
-                    yaw = float(face.pose[1])
-                    if abs(yaw) > ADAPTIVE_THRESHOLDS['profile_angle']:
-                        return ADAPTIVE_PARAMS['profile_view'].copy()
-                except (ValueError, TypeError, IndexError):
-                    pass
-
-            return None
-
-        except Exception as e:
-            print(f"[ADAPTIVE] Quality detection failed: {e}")
-            return None
 
-    def generate_caption(self, image):
-        """Generate caption for image with optimized GPU usage"""
-        if not self.caption_enabled or self.caption_model is None:
-            return None
-
-        try:
-            # Move caption model to GPU temporarily
-            original_device = "cpu"
-            if hasattr(self.caption_model, 'device'):
-                original_device = str(self.caption_model.device)
-
-            try:
-                # Move to GPU for processing
-                if torch.cuda.is_available():
-                    self.caption_model = self.caption_model.to(self.device)
-
-                if self.caption_model_type == 'git':
-                    inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
-                    generated_ids = self.caption_model.generate(**inputs, max_length=CAPTION_CONFIG['max_length'])
-                    caption = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-                elif self.caption_model_type == 'blip':
-                    inputs = self.caption_processor(image, return_tensors="pt").to(self.device)
-                    generated_ids = self.caption_model.generate(**inputs, max_length=CAPTION_CONFIG['max_length'])
-                    caption = self.caption_processor.decode(generated_ids[0], skip_special_tokens=True)
-                else:
-                    return None
-
-                # Move back to CPU to free GPU memory
-                if torch.cuda.is_available() and "cpu" in original_device:
-                    self.caption_model = self.caption_model.to("cpu")
-                    torch.cuda.empty_cache()
-
-            except Exception as gpu_error:
-                print(f"[CAPTION] GPU processing failed: {gpu_error}, trying on CPU")
-                self.caption_model = self.caption_model.to("cpu")
-
-                if self.caption_model_type == 'git':
-                    inputs = self.caption_processor(images=image, return_tensors="pt")
-                    generated_ids = self.caption_model.generate(**inputs, max_length=CAPTION_CONFIG['max_length'])
-                    caption = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-                elif self.caption_model_type == 'blip':
-                    inputs = self.caption_processor(image, return_tensors="pt")
-                    generated_ids = self.caption_model.generate(**inputs, max_length=CAPTION_CONFIG['max_length'])
-                    caption = self.caption_processor.decode(generated_ids[0], skip_special_tokens=True)
-                else:
-                    return None
-
-            return sanitize_text(caption)
-        except Exception as e:
-            print(f"[CAPTION] Generation failed: {e}")
-            return None
-
-    def generate_retro_art(
         self,
-        input_image,
-        prompt=" ",
-        negative_prompt=" ",
-        num_inference_steps=12,
-        guidance_scale=1.3,
-        depth_control_scale=0.75,
-        identity_control_scale=0.85,
         lora_scale=1.0,
-        identity_preservation=1.2,
-        strength=0.50,
-        enable_color_matching=False,
-        consistency_mode=True,
-        seed=-1
     ):
-        """Generate retro art with InstantID face preservation"""
 
-        try:
-            # Add trigger word
-            prompt = self.add_trigger_word(prompt)
-            prompt = sanitize_text(prompt)
-            negative_prompt = sanitize_text(negative_prompt)
-
-            print(f"[PROMPT] {prompt}")
-
-            # Calculate optimal size
-            orig_width, orig_height = safe_image_size(input_image)
-            optimal_width, optimal_height = calculate_optimal_size(orig_width, orig_height)
-
-            # Resize image
-            resized_image = input_image.resize((optimal_width, optimal_height), Image.LANCZOS)
-            print(f"[SIZE] Resized to {optimal_width}x{optimal_height}")
-
-            # Generate depth map
-            depth_image = self.get_depth_map(resized_image)
-
-            # ═══════════════════════════════════════════════════════════
-            # FACE DETECTION
-            # ═══════════════════════════════════════════════════════════
-            has_detected_faces = False
-            face_kps_image = None
-            face_embeddings = None
-            face_bbox_original = None
-
-            # FACE DETECTION (examplewithface.py line 321-327)
             try:
-                image_array = cv2.cvtColor(np.array(resized_image), cv2.COLOR_RGB2BGR)
-                faces = self.face_app.get(image_array)
 
-                if len(faces) == 0:
-                    raise ValueError("No faces detected in image")
-
-                # Get largest face (examplewithface.py line 322)
-                face = sorted(faces, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]
-
-                # Get embeddings and keypoints
-                face_embeddings = face['embedding']
-                face_kps_image = draw_kps(resized_image, face['kps'])
-                face_bbox_original = face.get('bbox', None)
-
-                # Adaptive parameter adjustment
-                adaptive_params = self.detect_face_quality(face)
-                if adaptive_params:
-                    print(f"[ADAPTIVE] {adaptive_params['reason']}")
-                    identity_preservation = adaptive_params.get('identity_preservation', identity_preservation)
-                    identity_control_scale = adaptive_params.get('identity_control_scale', identity_control_scale)
-                    guidance_scale = adaptive_params.get('guidance_scale', guidance_scale)
-                    lora_scale = adaptive_params.get('lora_scale', lora_scale)
 
-                print(f"[FACE] Detected face with {face.get('det_score', 1.0):.2f} confidence")
-                print(f"[FACE] Embeddings shape: {face_embeddings.shape}")
-                has_detected_faces = True
 
             except Exception as e:
-                print(f"[FACE] Face detection failed: {str(e)[:100]}")
-                raise ValueError(f"No face found in image. Only face images work. Error: {str(e)}")
-
-            # Fuse LORA with scale (following working example approach)
-            if self.models_loaded['lora']:
-                try:
-                    from models import fuse_lora_with_scale
-                    fuse_lora_with_scale(self.pipe, lora_scale)
-                    print(f"[LORA] Fused with scale: {lora_scale}")
-                except Exception as e:
-                    print(f"[LORA] Could not fuse: {e}")
-
-            # ═══════════════════════════════════════════════════════════
-            # PIPELINE CONFIGURATION
-            # ═══════════════════════════════════════════════════════════
-            pipe_kwargs = {
-                "image": resized_image,
-                "strength": strength,
-                "num_inference_steps": num_inference_steps,
-                "guidance_scale": guidance_scale,
-            }
-
-            # Setup generator with seed
-            if seed == -1:
-                generator = torch.Generator(device=self.device)
-                actual_seed = generator.seed()
-                print(f"[SEED] Random: {actual_seed}")
-            else:
-                generator = torch.Generator(device=self.device).manual_seed(seed)
-                actual_seed = seed
-                print(f"[SEED] Fixed: {actual_seed}")
-
-            pipe_kwargs["generator"] = generator
-
-            # Use Compel for prompt encoding
-            if self.use_compel and self.compel is not None:
-                try:
-                    conditioning = self.compel(prompt)
-                    negative_conditioning = self.compel(negative_prompt)
-
-                    pipe_kwargs["prompt_embeds"] = conditioning[0]
-                    pipe_kwargs["pooled_prompt_embeds"] = conditioning[1]
-                    pipe_kwargs["negative_prompt_embeds"] = negative_conditioning[0]
-                    pipe_kwargs["negative_pooled_prompt_embeds"] = negative_conditioning[1]
-
-                    print("[OK] Using Compel-encoded prompts")
-                except Exception as e:
-                    print(f"[COMPEL] Failed, using standard prompts: {e}")
-                    pipe_kwargs["prompt"] = prompt
-                    pipe_kwargs["negative_prompt"] = negative_prompt
-            else:
                 pipe_kwargs["prompt"] = prompt
                 pipe_kwargs["negative_prompt"] = negative_prompt
 
-            # ═══════════════════════════════════════════════════════════
-            # CONTROLNET + IP-ADAPTER CONFIGURATION
-            # ═══════════════════════════════════════════════════════════
-
-            if has_detected_faces and face_kps_image is not None and face_embeddings is not None:
-                print("═" * 60)
-                print("MODE: InstantID (Face Keypoints + Depth + IP-Adapter)")
-                print("═" * 60)
-
-                # Set IP-Adapter scale
-                self.pipe.set_ip_adapter_scale(identity_preservation)
-                print(f" [IP-ADAPTER] Scale set to: {identity_preservation}")
-
-                # Control images: [face keypoints, depth map]
-                pipe_kwargs["control_image"] = [face_kps_image, depth_image]
-
-                # ControlNet scales: [identity keypoints, depth]
-                pipe_kwargs["controlnet_conditioning_scale"] = [
-                    identity_control_scale,
-                    depth_control_scale
-                ]
-
-                # Control guidance timing
-                pipe_kwargs["control_guidance_start"] = [0.0, 0.0]
-                pipe_kwargs["control_guidance_end"] = [1.0, 1.0]
-
-                # Pass raw face embeddings - pipeline handles everything
-                pipe_kwargs["image_embeds"] = face_embeddings
-
-                print(f" [CONTROLNET] Identity scale: {identity_control_scale}")
-                print(f" [CONTROLNET] Depth scale: {depth_control_scale}")
-                print(f" [EMBEDDINGS] Shape: {face_embeddings.shape} (raw)")
-                print(" [INFO] Pipeline will handle: Resampler → Concatenation → Attention")
-                print("═" * 60)
-
-            elif has_detected_faces and face_kps_image is not None:
-                print("═" * 60)
-                print("MODE: InstantID Keypoints Only (no embeddings)")
-                print("═" * 60)
-
-                # Disable IP-Adapter
-                self.pipe.set_ip_adapter_scale(0.0)
-                print(" [IP-ADAPTER] Disabled (no embeddings)")
-
-                # Use keypoints + depth
-                pipe_kwargs["control_image"] = [face_kps_image, depth_image]
-                pipe_kwargs["controlnet_conditioning_scale"] = [
-                    identity_control_scale,
-                    depth_control_scale
-                ]
-                pipe_kwargs["control_guidance_start"] = [0.0, 0.0]
-                pipe_kwargs["control_guidance_end"] = [1.0, 1.0]
-
-                # Pass zero embeddings
-                zero_embeddings = np.zeros(512, dtype=np.float32)
-                pipe_kwargs["image_embeds"] = zero_embeddings
-
-                print(" [INFO] Using keypoints for structure only (zero embeddings)")
-                print("═" * 60)
-
-            else:
-                print("═" * 60)
-                print("MODE: Depth Only (no face detection)")
-                print("═" * 60)
-
-                # Disable IP-Adapter
-                self.pipe.set_ip_adapter_scale(0.0)
-                print(" [IP-ADAPTER] Disabled (no face)")
-
-                # Use depth only
-                pipe_kwargs["control_image"] = [depth_image, depth_image]
-                pipe_kwargs["controlnet_conditioning_scale"] = [0.0, depth_control_scale]
-                pipe_kwargs["control_guidance_start"] = [0.0, 0.0]
-                pipe_kwargs["control_guidance_end"] = [1.0, 1.0]
-
-                # Pass zero embeddings
-                zero_embeddings = np.zeros(512, dtype=np.float32)
-                pipe_kwargs["image_embeds"] = zero_embeddings
-
-                print(f" [CONTROLNET] Depth scale: {depth_control_scale}")
-                print(" [INFO] Generating without face preservation (zero embeddings)")
-                print("═" * 60)
 
-            # ═══════════════════════════════════════════════════════════
-            # GENERATION
-            # ═══════════════════════════════════════════════════════════
-            print(f"\nGenerating: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}")
 
-            result = self.pipe(**pipe_kwargs)
 
-            generated_image = result.images[0]
 
-            # ═══════════════════════════════════════════════════════════
-            # POST-PROCESSING
-            # ═══════════════════════════════════════════════════════════
-            if enable_color_matching and has_detected_faces:
-                print("Applying enhanced face-aware color matching...")
-                try:
-                    if face_bbox_original is not None:
-                        generated_image = enhanced_color_match(
-                            generated_image,
-                            resized_image,
-                            face_bbox=face_bbox_original
-                        )
-                        print("[OK] Enhanced color matching applied")
-                    else:
-                        generated_image = color_match(generated_image, resized_image, mode='mkl')
-                        print("[OK] Standard color matching applied")
-                except Exception as e:
-                    print(f"[COLOR] Matching failed: {e}")
-            elif enable_color_matching:
-                print("Applying standard color matching...")
-                try:
                     generated_image = color_match(generated_image, resized_image, mode='mkl')
-                    print("[OK] Standard color matching applied")
-                except Exception as e:
-                    print(f"[COLOR] Matching failed: {e}")
-
-            return generated_image
 
-        finally:
-            # Memory cleanup
-            self.memory_manager.cleanup_memory(aggressive=True)
-
-            # Final memory status
-            if self.memory_manager.verbose:
-                print("[MEMORY] Final status after generation:")
-                self.memory_manager.print_memory_status()
 
 
-print("[OK] Generator class ready with cleaned code")
 
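The core behavioral change sits in how face identity reaches the IP-Adapter. The removed version above owns the embeddings itself: it passes an `image_embeds` argument on every pipeline call, substituting zero vectors when no usable face exists. The rewritten version, reproduced in full below, wires the IP-Adapter into the pipeline once at startup (`setup_ip_adapter`) and afterwards only tunes its scale at call time. A minimal sketch of the two call patterns, with hypothetical placeholder variables (`pipe`, `kps_image`, `depth_image`, `common_kwargs`) standing in for the objects the class builds in `__init__`:

```python
import numpy as np

# Removed pattern: the caller supplies embeddings on every call,
# zero-filled when detection found no face.
embeds = face_embeddings if face_found else np.zeros(512, dtype=np.float32)
result = pipe(
    control_image=[kps_image, depth_image],
    controlnet_conditioning_scale=[0.85, 0.75],
    image_embeds=embeds,
    **common_kwargs,
)

# New pattern: the adapter was attached once at init; the caller only
# adjusts its influence and provides the control images.
pipe.set_ip_adapter_scale(0.8 * identity_preservation)
result = pipe(
    control_image=[kps_image, depth_image],
    controlnet_conditioning_scale=[0.8, 0.8],
    **common_kwargs,
)
```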
 """
+Generation logic for Pixagram AI Pixel Art Generator
+FIXED VERSION - Proper embedding integration following exampleapp.py pattern
 """
 import torch
 import numpy as np
 import cv2
 from PIL import Image
+import torch.nn.functional as F
+from torchvision import transforms
 
 from config import (
+    device, dtype, TRIGGER_WORD, MULTI_SCALE_FACTORS,
+    ADAPTIVE_THRESHOLDS, ADAPTIVE_PARAMS, CAPTION_CONFIG, IDENTITY_BOOST_MULTIPLIER
 )
 from utils import (
+    sanitize_text, enhanced_color_match, color_match, create_face_mask,
+    draw_kps, get_demographic_description, calculate_optimal_size, enhance_face_crop
 )
+from models_fixed import (  # Use the fixed version
+    load_face_analysis, load_depth_detector, load_controlnets, load_image_encoder,
+    load_sdxl_pipeline, load_lora, setup_ip_adapter, setup_compel,
     setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip
 )
 
 
 class RetroArtConverter:
+    """Main class for retro art generation - FIXED VERSION"""
 
     def __init__(self):
         self.device = device
 
         self.models_loaded = {
             'custom_checkpoint': False,
             'lora': False,
             'instantid': False,
+            'zoe_depth': False,
+            'ip_adapter': False
         }
 
+        # Initialize face analysis
+        self.face_app, self.face_detection_enabled = load_face_analysis()
 
+        # Load Zoe Depth detector
         self.zoe_depth, zoe_success = load_depth_detector()
         self.models_loaded['zoe_depth'] = zoe_success
 
+        # Load ControlNets
+        controlnet_depth, self.controlnet_instantid, instantid_success = load_controlnets()
+        self.controlnet_depth = controlnet_depth
+        self.instantid_enabled = instantid_success
+        self.models_loaded['instantid'] = instantid_success
+
+        # Load image encoder (still needed for some pipeline functions)
+        if self.instantid_enabled:
+            self.image_encoder = load_image_encoder()
+        else:
+            self.image_encoder = None
 
+        # Determine which controlnets to use
+        if self.instantid_enabled and self.controlnet_instantid is not None:
+            controlnets = [self.controlnet_instantid, controlnet_depth]
+            print(f"Initializing with multiple ControlNets: InstantID + Depth")
+        else:
+            controlnets = controlnet_depth
+            print(f"Initializing with single ControlNet: Depth only")
 
+        # CRITICAL FIX: Load SDXL pipeline with from_pretrained()
         self.pipe, checkpoint_success = load_sdxl_pipeline(controlnets)
         self.models_loaded['custom_checkpoint'] = checkpoint_success
 
         lora_success = load_lora(self.pipe)
         self.models_loaded['lora'] = lora_success
 
+        # CRITICAL FIX: Setup IP-Adapter using pipeline's built-in method
+        if self.instantid_enabled and self.image_encoder is not None:
+            ip_adapter_success = setup_ip_adapter(self.pipe)
+            self.models_loaded['ip_adapter'] = ip_adapter_success
+
+            # The pipeline now has these attributes after load_ip_adapter_instantid:
+            # - self.pipe.image_proj_model (the Resampler)
+            # - self.pipe.ip_adapter_scale (current scale)
+            # We don't need to manually manage these anymore!
+        else:
+            print("[INFO] Face preservation: InstantID ControlNet keypoints only")
+            self.models_loaded['ip_adapter'] = False
+
         # Setup Compel
         self.compel, self.use_compel = setup_compel(self.pipe)
 
+        # Setup LCM scheduler
         setup_scheduler(self.pipe)
 
+        # Optimize pipeline
         optimize_pipeline(self.pipe)
 
+        # Load caption model
         self.caption_processor, self.caption_model, self.caption_enabled, self.caption_model_type = load_caption_model()
 
+        # Report caption model status
+        if self.caption_enabled and self.caption_model is not None:
+            if self.caption_model_type == "git":
+                print(" [OK] Using GIT for detailed captions")
+            elif self.caption_model_type == "blip":
+                print(" [OK] Using BLIP for standard captions")
+            else:
+                print(" [OK] Caption model loaded")
+
         # Set CLIP skip
         set_clip_skip(self.pipe)
 
+        # Track controlnet configuration
+        self.using_multiple_controlnets = isinstance(controlnets, list)
+        print(f"Pipeline initialized with {'multiple' if self.using_multiple_controlnets else 'single'} ControlNet(s)")
 
+        # Print model status
+        self._print_status()
 
+        print(" [OK] Model initialization complete!")
 
     def _print_status(self):
         """Print model loading status"""
         for model, loaded in self.models_loaded.items():
             status = "[OK] LOADED" if loaded else "[FALLBACK/DISABLED]"
             print(f"{model}: {status}")
         print("===================\n")
+
+        print("=== IP-ADAPTER STATUS ===")
+        if self.models_loaded.get('ip_adapter', False):
+            if hasattr(self.pipe, 'image_proj_model'):
+                print("[OK] IP-Adapter fully loaded via pipeline method")
+                print(" - Resampler: Available at pipe.image_proj_model")
+                print(" - Scale control: Available via pipe.set_ip_adapter_scale()")
+                print(" - Expected improvement: High face similarity")
+            else:
+                print("[WARNING] IP-Adapter loaded but Resampler not accessible")
+        else:
+            print("[INFO] IP-Adapter not active (using keypoints only)")
+        print("=========================\n")
 
     def get_depth_map(self, image):
+        """Generate depth map using Zoe Depth"""
         if self.zoe_depth is not None:
             try:
                 if image.mode != 'RGB':
                     image = image.convert('RGB')
 
+                orig_width, orig_height = image.size
+                orig_width = int(orig_width)
+                orig_height = int(orig_height)
 
                 # Use multiples of 64
                 target_width = int((orig_width // 64) * 64)
 
                 size_for_depth = (int(target_width), int(target_height))
                 image_for_depth = image.resize(size_for_depth, Image.LANCZOS)
 
+                depth_map = self.zoe_depth(image_for_depth, detect_resolution=512, image_resolution=512)
 
+                if depth_map.size != image.size:
+                    depth_map = depth_map.resize(image.size, Image.LANCZOS)
 
+                return depth_map
             except Exception as e:
+                print(f"Depth generation failed: {e}")
+                return None
+        return None
 
+    def generate(
         self,
+        image,
+        prompt="a person",
+        negative_prompt="",
+        num_inference_steps=4,
+        guidance_scale=0.0,
+        strength=0.75,
         lora_scale=1.0,
+        identity_control_scale=0.8,
+        depth_control_scale=0.8,
+        identity_preservation=1.0,
+        enable_color_matching=True,
+        seed=-1,
+        **kwargs
     ):
+        """
+        Generate retro art with InstantID face preservation.
+        FIXED: Proper IP-Adapter integration following exampleapp.py pattern.
+        """
 
+        print(f"\n{'='*60}")
+        print(f"Starting generation with:")
+        print(f" Prompt: {prompt}")
+        print(f" Steps: {num_inference_steps}, CFG: {guidance_scale}, Strength: {strength}")
+        print(f" Identity scale: {identity_control_scale}, Depth scale: {depth_control_scale}")
+        print(f" Face preservation: {identity_preservation}")
+        print(f"{'='*60}\n")
+
+        # Prepare input image
+        if image.mode != 'RGB':
+            image = image.convert('RGB')
+
+        optimal_width, optimal_height = calculate_optimal_size(image.size)
+        resized_image = image.resize((optimal_width, optimal_height), Image.LANCZOS)
+
+        print(f"Image resized: {image.size} → {resized_image.size}")
+
+        # Generate depth map
+        print("Generating depth map...")
+        depth_image = self.get_depth_map(resized_image)
+
+        if depth_image is None:
+            raise RuntimeError("Could not generate depth map")
+
+        # Face detection and processing
+        has_detected_faces = False
+        face_kps_image = None
+        face_embeddings = None
+        face_crop = None
+        face_crop_enhanced = None
+        face_bbox_original = None
+
+        if self.face_app is not None:
+            print("Detecting faces...")
             try:
+                image_np = cv2.cvtColor(np.array(resized_image), cv2.COLOR_RGB2BGR)
+                faces = self.face_app.get(image_np)
 
+                if len(faces) > 0:
+                    has_detected_faces = True
+                    face = faces[0]
+                    print(f" [OK] Face detected (score: {face.det_score:.3f})")
+
+                    # Get face keypoints image
+                    face_kps_image = draw_kps(resized_image, face.kps)
+
+                    # Get face embeddings (512D from InsightFace)
+                    if hasattr(face, 'normed_embedding'):
+                        face_embeddings = face.normed_embedding
+                        print(f" Face embedding shape: {face_embeddings.shape}")
+                    elif hasattr(face, 'embedding'):
+                        face_embeddings = face.embedding / np.linalg.norm(face.embedding)
+                        print(f" Face embedding shape: {face_embeddings.shape}")
+
+                    # Store face bbox for color matching
+                    if hasattr(face, 'bbox'):
+                        face_bbox_original = face.bbox
+
+                    # Get face crop for enhanced processing
+                    bbox = face.bbox.astype(int)
+                    face_crop = resized_image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
+                    face_crop_enhanced = enhance_face_crop(face_crop)
+
+                    # Debug info
+                    if hasattr(face, 'age') and hasattr(face, 'gender'):
+                        facial_attrs = {
+                            'age': face.age,
+                            'gender': face.gender,
+                            'quality': face.det_score
+                        }
+                        age = facial_attrs['age']
+                        gender_code = facial_attrs['gender']
+                        det_score = facial_attrs['quality']
+                        gender_str = 'M' if gender_code == 1 else ('F' if gender_code == 0 else 'N/A')
+                        print(f" Face info: age={age if age else 'N/A'}, gender={gender_str}, quality={det_score:.3f}")
+                else:
+                    print(" [INFO] No faces detected")
+            except Exception as e:
+                print(f" [WARNING] Face detection failed: {e}")
+
+        # CRITICAL FIX: Set IP-Adapter scale dynamically
+        # The pipeline's built-in method allows runtime adjustment
+        if self.models_loaded.get('ip_adapter', False) and has_detected_faces:
+            try:
+                # Scale based on identity_preservation parameter
+                adjusted_scale = 0.8 * identity_preservation
+                self.pipe.set_ip_adapter_scale(adjusted_scale)
+                print(f" IP-Adapter scale adjusted to: {adjusted_scale:.2f}")
+            except Exception as e:
+                print(f" [WARNING] Could not adjust IP-Adapter scale: {e}")
+
+        # Set LORA scale
+        if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
+            try:
+                self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
+                print(f" LORA scale: {lora_scale}")
+            except Exception as e:
+                print(f" [WARNING] Could not set LORA scale: {e}")
+
+        # Prepare generation kwargs
+        pipe_kwargs = {
+            "image": resized_image,
+            "strength": strength,
+            "num_inference_steps": num_inference_steps,
+            "guidance_scale": guidance_scale,
+        }
+
+        # Setup generator with seed control
+        if seed == -1:
+            generator = torch.Generator(device=self.device)
+            actual_seed = generator.seed()
+            print(f"[SEED] Using random seed: {actual_seed}")
+        else:
+            generator = torch.Generator(device=self.device).manual_seed(seed)
+            actual_seed = seed
+            print(f"[SEED] Using fixed seed: {actual_seed}")
+
+        pipe_kwargs["generator"] = generator
+
+        # Use Compel for prompt encoding if available
+        if self.use_compel and self.compel is not None:
+            try:
+                print("Encoding prompts with Compel...")
+                conditioning = self.compel(prompt)
+                negative_conditioning = self.compel(negative_prompt)
 
+                pipe_kwargs["prompt_embeds"] = conditioning[0]
+                pipe_kwargs["pooled_prompt_embeds"] = conditioning[1]
+                pipe_kwargs["negative_prompt_embeds"] = negative_conditioning[0]
+                pipe_kwargs["negative_pooled_prompt_embeds"] = negative_conditioning[1]
 
+                print(" [OK] Using Compel-encoded prompts")
             except Exception as e:
+                print(f" Compel encoding failed, using standard prompts: {e}")
                 pipe_kwargs["prompt"] = prompt
                 pipe_kwargs["negative_prompt"] = negative_prompt
+        else:
+            pipe_kwargs["prompt"] = prompt
+            pipe_kwargs["negative_prompt"] = negative_prompt
+
+        # Add CLIP skip
+        if hasattr(self.pipe, 'text_encoder'):
+            pipe_kwargs["clip_skip"] = 2
+
+        # Configure ControlNet inputs
+        using_multiple_controlnets = self.using_multiple_controlnets
+
+        if using_multiple_controlnets and has_detected_faces and face_kps_image is not None:
+            print("Using InstantID (keypoints + embeddings) + Depth ControlNets")
+            control_images = [face_kps_image, depth_image]
+            conditioning_scales = [identity_control_scale, depth_control_scale]
 
+            pipe_kwargs["control_image"] = control_images
+            pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
 
+            # CRITICAL FIX: The pipeline handles face embeddings automatically!
+            # When load_ip_adapter_instantid() was called, the pipeline was configured
+            # to automatically process face embeddings through the Resampler and
+            # integrate them with text embeddings during generation.
+            #
+            # We just need to provide the face image via control_image and the
+            # pipeline does the rest. No manual concatenation needed!
 
+            if face_embeddings is not None and self.models_loaded.get('ip_adapter', False):
+                print(" [OK] Face embeddings will be processed by pipeline")
+                print(" - Pipeline automatically handles Resampler projection")
+                print(" - Face features integrated via IP-Adapter attention")
 
+        elif using_multiple_controlnets and not has_detected_faces:
+            print("Multiple ControlNets available but no faces detected, using depth only")
+            control_images = [depth_image, depth_image]
+            conditioning_scales = [0.0, depth_control_scale]
 
+            pipe_kwargs["control_image"] = control_images
+            pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
+
+        else:
+            print("Using Depth ControlNet only")
+            pipe_kwargs["control_image"] = depth_image
+            pipe_kwargs["controlnet_conditioning_scale"] = depth_control_scale
+
+        # Generate
+        print(f"\nGenerating with LCM:")
+        print(f" Steps: {num_inference_steps}, CFG: {guidance_scale}, Strength: {strength}")
+        print(f" ControlNet scales - Identity: {identity_control_scale}, Depth: {depth_control_scale}")
+
+        result = self.pipe(**pipe_kwargs)
+        generated_image = result.images[0]
+
+        # Post-processing
+        if enable_color_matching and has_detected_faces:
+            print("\nApplying enhanced face-aware color matching...")
+            try:
+                if face_bbox_original is not None:
+                    generated_image = enhanced_color_match(
+                        generated_image,
+                        resized_image,
+                        face_bbox=face_bbox_original
+                    )
+                    print(" [OK] Enhanced color matching applied (face-aware)")
+                else:
                     generated_image = color_match(generated_image, resized_image, mode='mkl')
+                    print(" [OK] Standard color matching applied")
+            except Exception as e:
+                print(f" [WARNING] Color matching failed: {e}")
+        elif enable_color_matching:
+            print("\nApplying standard color matching...")
+            try:
+                generated_image = color_match(generated_image, resized_image, mode='mkl')
+                print(" [OK] Standard color matching applied")
+            except Exception as e:
+                print(f" [WARNING] Color matching failed: {e}")
 
+        print(f"\n{'='*60}")
+        print("Generation complete!")
+        print(f"{'='*60}\n")
+
+        return generated_image
 
 
+print("[OK] Generator class ready (FIXED VERSION)")
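For orientation, the new public entry point is `RetroArtConverter.generate()`. A minimal usage sketch, assuming the modules this file imports (`config`, `models_fixed`, `utils`) and their model weights are available in the repository environment; the file name and parameter values below come from the diff itself:

```python
from PIL import Image
from generator import RetroArtConverter

# Loads face analysis, Zoe Depth, ControlNets, the SDXL pipeline,
# IP-Adapter, and LORA as part of __init__.
converter = RetroArtConverter()

source = Image.open("portrait.jpg")
art = converter.generate(
    source,
    prompt="retro pixel art portrait",
    num_inference_steps=4,    # LCM-style low step count, per the new defaults
    guidance_scale=0.0,
    strength=0.75,
    identity_preservation=1.0,
    seed=42,                  # -1 selects a random seed
)
art.save("portrait_retro.png")
```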