primerz commited on
Commit
bde5828
·
verified ·
1 Parent(s): d4170e9

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +265 -316
generator.py CHANGED
@@ -1,22 +1,20 @@
1
  """
2
- Generation logic for Pixagram AI Pixel Art Generator
3
- UPDATED VERSION with simplified InstantID face preservation
4
  """
5
  import torch
6
  import numpy as np
7
  import cv2
8
  from PIL import Image
9
- import torch.nn.functional as F
10
- from torchvision import transforms
11
 
12
  from config import (
13
- device, dtype, TRIGGER_WORD, MULTI_SCALE_FACTORS,
14
- ADAPTIVE_THRESHOLDS, ADAPTIVE_PARAMS, CAPTION_CONFIG, IDENTITY_BOOST_MULTIPLIER
15
  )
16
  from utils import (
17
- sanitize_text, enhanced_color_match, color_match, create_face_mask,
18
- draw_kps, get_demographic_description, calculate_optimal_size, enhance_face_crop,
19
- safe_image_size, ensure_int
20
  )
21
  from models import (
22
  load_face_analysis, load_depth_detector, load_controlnets,
@@ -26,7 +24,7 @@ from models import (
26
 
27
 
28
  class RetroArtConverter:
29
- """Main class for retro art generation with InstantID face preservation"""
30
 
31
  def __init__(self):
32
  self.device = device
@@ -38,14 +36,14 @@ class RetroArtConverter:
38
  'zoe_depth': False
39
  }
40
 
41
- # Initialize face analysis
42
  self.face_app, self.face_detection_enabled = load_face_analysis()
43
 
44
- # Load Zoe Depth detector
45
  self.zoe_depth, zoe_success = load_depth_detector()
46
  self.models_loaded['zoe_depth'] = zoe_success
47
 
48
- # Load ControlNets - ALWAYS as list for InstantID pipeline
49
  controlnet_instantid, controlnet_depth = load_controlnets()
50
  controlnets = [controlnet_instantid, controlnet_depth]
51
  self.models_loaded['instantid'] = True
@@ -63,29 +61,22 @@ class RetroArtConverter:
63
  # Setup Compel
64
  self.compel, self.use_compel = setup_compel(self.pipe)
65
 
66
- # Setup LCM scheduler
67
  setup_scheduler(self.pipe)
68
 
69
- # Optimize pipeline
70
  optimize_pipeline(self.pipe)
71
 
72
  # Load caption model
73
  self.caption_processor, self.caption_model, self.caption_enabled, self.caption_model_type = load_caption_model()
74
 
75
- # Report caption model status
76
- if self.caption_enabled and self.caption_model is not None:
77
- if self.caption_model_type == "git":
78
- print(" [OK] Using GIT for detailed captions")
79
- elif self.caption_model_type == "blip":
80
- print(" [OK] Using BLIP for standard captions")
81
-
82
  # Set CLIP skip
83
  set_clip_skip(self.pipe)
84
 
85
- # Print model status
86
  self._print_status()
87
 
88
- print(" [OK] Model initialization complete with InstantID!")
89
 
90
  def _print_status(self):
91
  """Print model loading status"""
@@ -93,110 +84,110 @@ class RetroArtConverter:
93
  for model, loaded in self.models_loaded.items():
94
  status = "[OK] LOADED" if loaded else "[FALLBACK/DISABLED]"
95
  print(f"{model}: {status}")
96
- print("InstantID Pipeline: [OK] Active with built-in IP-Adapter")
 
97
  print("===================\n")
98
 
99
  def get_depth_map(self, image):
100
- """
101
- Generate depth map using available depth detector.
102
- Supports: LeresDetector, ZoeDetector, or MidasDetector.
103
- """
104
- # --- FIX 1: Check for self.zoe_depth, not self.depth_detector ---
105
- if self.zoe_depth is not None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  try:
107
- if image.mode != 'RGB':
108
- image = image.convert('RGB')
109
-
110
- orig_width, orig_height = image.size
111
- orig_width = int(orig_width)
112
- orig_height = int(orig_height)
113
-
114
- target_width = int((orig_width // 64) * 64)
115
- target_height = int((orig_height // 64) * 64)
116
-
117
- target_width = int(max(64, target_width))
118
- target_height = int(max(64, target_height))
119
-
120
- size_for_depth = (int(target_width), int(target_height))
121
- image_for_depth = image.resize(size_for_depth, Image.LANCZOS)
122
-
123
- if target_width != orig_width or target_height != orig_height:
124
- # --- FIX 2: Use "ZOE" instead of undefined self.depth_type ---
125
- print(f"[DEPTH] Resized for ZOEDetector: {orig_width}x{orig_height} -> {target_width}x{target_height}")
126
-
127
- # Use torch.no_grad() and clear cache
128
- with torch.no_grad():
129
- # --- FIX 1: Use self.zoe_depth ---
130
- self.zoe_depth.to(self.device)
131
- depth_image = self.zoe_depth(image_for_depth)
132
- self.zoe_depth.to("cpu")
133
-
134
- # ADDED: Clear GPU cache after depth detection
135
- if torch.cuda.is_available():
136
- torch.cuda.empty_cache()
137
-
138
- depth_width, depth_height = depth_image.size
139
- if depth_width != orig_width or depth_height != orig_height:
140
- depth_image = depth_image.resize((int(orig_width), int(orig_height)), Image.LANCZOS)
141
-
142
- # --- FIX 2: Use "ZOE" instead of undefined self.depth_type ---
143
- print(f"[DEPTH] ZOE depth map generated: {orig_width}x{orig_height}")
144
- return depth_image
145
-
146
- except Exception as e:
147
- # --- FIX 2: Use "ZOE" instead of undefined self.depth_type ---
148
- print(f"[DEPTH] ZOEDetector failed ({e}), falling back to grayscale depth")
149
- # ADDED: Clear cache on error
150
- if torch.cuda.is_available():
151
- torch.cuda.empty_cache()
152
-
153
- gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
154
- depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
155
- return Image.fromarray(depth_colored)
156
- else:
157
- print("[DEPTH] No depth detector available, using grayscale fallback")
158
- gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
159
- depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
160
- return Image.fromarray(depth_colored)
161
-
162
-
163
  def generate_caption(self, image):
164
- """Generate caption for image using loaded caption model"""
165
  if not self.caption_enabled or self.caption_model is None:
166
  return None
167
 
168
  try:
169
  if self.caption_model_type == 'git':
170
- # GIT model
171
- pixel_values = self.caption_processor(images=image, return_tensors="pt").pixel_values
172
- pixel_values = pixel_values.to(device=self.device, dtype=self.dtype)
173
-
174
- generated_ids = self.caption_model.generate(
175
- pixel_values=pixel_values,
176
- max_length=CAPTION_CONFIG['max_length']
177
- )
178
  caption = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
179
-
180
  elif self.caption_model_type == 'blip':
181
- # BLIP model
182
- inputs = self.caption_processor(image, return_tensors="pt").to(self.device, self.dtype)
183
- out = self.caption_model.generate(**inputs, max_new_tokens=CAPTION_CONFIG['max_length'])
184
- caption = self.caption_processor.decode(out[0], skip_special_tokens=True)
185
-
186
  else:
187
  return None
188
 
189
  return sanitize_text(caption)
190
-
191
  except Exception as e:
192
- print(f"Caption generation failed: {e}")
193
  return None
194
 
195
  def generate_retro_art(
196
  self,
197
  input_image,
198
- prompt,
199
- negative_prompt="",
200
  num_inference_steps=12,
201
  guidance_scale=1.3,
202
  depth_control_scale=0.75,
@@ -208,232 +199,190 @@ class RetroArtConverter:
208
  consistency_mode=True,
209
  seed=-1
210
  ):
211
- """
212
- Generate retro art with InstantID face preservation.
213
-
214
- UPDATED: Simplified face embedding handling using InstantID pipeline.
215
- """
216
 
217
- # Validate and adjust parameters if consistency mode is enabled
218
- if consistency_mode:
219
- # Ensure guidance scale is in optimal range for LCM
220
- if guidance_scale < 1.0:
221
- guidance_scale = 1.0
222
- elif guidance_scale > 1.8:
223
- guidance_scale = 1.8
224
 
225
- # Ensure identity preservation and lora scale balance
226
- if identity_preservation > 1.5 and lora_scale > 1.2:
227
- lora_scale = min(lora_scale, 1.0)
228
 
229
- # Ensure strength is reasonable
230
- if strength < 0.3:
231
- strength = 0.3
232
- elif strength > 0.8:
233
- strength = 0.8
234
-
235
- # Calculate optimal size
236
- orig_width, orig_height = safe_image_size(input_image)
237
- optimal_width, optimal_height = calculate_optimal_size(orig_width, orig_height)
238
-
239
- # Resize image
240
- resized_image = input_image.resize((optimal_width, optimal_height), Image.LANCZOS)
241
-
242
- # Generate depth map
243
- print("Generating depth map...")
244
- # --- FIX 3: get_depth_map only returns one value ---
245
- depth_image = self.get_depth_map(resized_image)
246
-
247
- if depth_image is None:
248
- raise RuntimeError("Failed to generate depth map")
249
-
250
- # Detect faces
251
- print("Detecting faces...")
252
- has_detected_faces = False
253
- face_kps_image = None
254
- face_embeddings = None
255
- face_bbox_original = None
256
-
257
- if self.face_detection_enabled and self.face_app is not None:
258
- try:
259
- faces = self.face_app.get(np.array(resized_image))
260
-
261
- if len(faces) > 0:
262
- has_detected_faces = True
263
- face = faces[0]
264
-
265
- # Draw keypoints
266
- face_kps_image = draw_kps(resized_image, face.kps)
267
-
268
- # Get face embeddings (512D vector from InsightFace)
269
- face_embeddings = face.embedding
270
-
271
- # Get face bounding box for color matching
272
- face_bbox_original = face.bbox
273
-
274
- print(f" [OK] Face detected")
275
- # --- FIX 4: Clarify this is the numpy shape ---
276
- print(f" - Embedding shape (numpy): {face_embeddings.shape}")
277
- print(f" - Keypoints: {face.kps.shape}")
278
- print(f" - Bbox: {face_bbox_original}")
279
-
280
- # Check for adaptive parameter adjustment
281
- face_area = (face.bbox[2] - face.bbox[0]) * (face.bbox[3] - face.bbox[1])
282
- det_score = face.det_score if hasattr(face, 'det_score') else 1.0
283
-
284
- # Apply adaptive adjustments
285
- if face_area < ADAPTIVE_THRESHOLDS['small_face_size']:
286
- print(" [ADAPTIVE] Small face detected - boosting preservation")
287
- identity_preservation = max(identity_preservation, ADAPTIVE_PARAMS['small_face']['identity_preservation'])
288
- identity_control_scale = max(identity_control_scale, ADAPTIVE_PARAMS['small_face']['identity_control_scale'])
289
 
290
- elif det_score < ADAPTIVE_THRESHOLDS['low_confidence']:
291
- print(" [ADAPTIVE] Low confidence - increasing identity weight")
292
- identity_preservation = max(identity_preservation, ADAPTIVE_PARAMS['low_confidence']['identity_preservation'])
293
- identity_control_scale = max(identity_control_scale, ADAPTIVE_PARAMS['low_confidence']['identity_control_scale'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
 
295
- else:
296
- print(" No faces detected in image")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
 
298
- except Exception as e:
299
- print(f"Face detection error: {e}")
300
- has_detected_faces = False
301
-
302
- # Enhance prompt with trigger word
303
- if TRIGGER_WORD not in prompt.lower():
304
- prompt = f"{TRIGGER_WORD}, {prompt}"
305
-
306
- # Set LORA scale
307
- if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
308
- try:
309
- self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
310
- print(f"LORA scale: {lora_scale}")
311
- except Exception as e:
312
- print(f"Could not set LORA scale: {e}")
313
-
314
- # Prepare generation kwargs
315
- pipe_kwargs = {
316
- "image": resized_image,
317
- "strength": strength,
318
- "num_inference_steps": num_inference_steps,
319
- "guidance_scale": guidance_scale,
320
- }
321
-
322
- # Setup generator with seed control
323
- if seed == -1:
324
- generator = torch.Generator(device=self.device)
325
- actual_seed = generator.seed()
326
- print(f"[SEED] Using random seed: {actual_seed}")
327
- else:
328
- generator = torch.Generator(device=self.device).manual_seed(seed)
329
- actual_seed = seed
330
- print(f"[SEED] Using fixed seed: {actual_seed}")
331
-
332
- pipe_kwargs["generator"] = generator
333
-
334
- # Use Compel for prompt encoding if available
335
- if self.use_compel and self.compel is not None:
336
- try:
337
- print("Encoding prompts with Compel...")
338
- conditioning = self.compel(prompt)
339
- negative_conditioning = self.compel(negative_prompt)
340
-
341
- pipe_kwargs["prompt_embeds"] = conditioning[0]
342
- pipe_kwargs["pooled_prompt_embeds"] = conditioning[1]
343
- pipe_kwargs["negative_prompt_embeds"] = negative_conditioning[0]
344
- pipe_kwargs["negative_pooled_prompt_embeds"] = negative_conditioning[1]
345
-
346
- print("[OK] Using Compel-encoded prompts")
347
- except Exception as e:
348
- print(f"Compel encoding failed, using standard prompts: {e}")
349
  pipe_kwargs["prompt"] = prompt
350
  pipe_kwargs["negative_prompt"] = negative_prompt
351
- else:
352
- pipe_kwargs["prompt"] = prompt
353
- pipe_kwargs["negative_prompt"] = negative_prompt
354
-
355
- # Add CLIP skip
356
- if hasattr(self.pipe, 'text_encoder'):
357
- pipe_kwargs["clip_skip"] = 2
358
-
359
- # ========================================
360
- # SIMPLIFIED: Configure ControlNets + IP-Adapter
361
- # ========================================
362
- if has_detected_faces and face_kps_image is not None:
363
- print("Using InstantID (keypoints + embeddings) + Depth ControlNets")
364
 
365
- # Control images: [face keypoints, depth map]
366
- pipe_kwargs["control_image"] = [face_kps_image, depth_image]
367
-
368
- # Conditioning scales: [identity, depth]
369
- pipe_kwargs["controlnet_conditioning_scale"] = [
370
- identity_control_scale,
371
- depth_control_scale
372
- ]
373
-
374
- # CRITICAL: Pass face embeddings for IP-Adapter
375
- # The InstantID pipeline handles the Resampler internally!
376
- if face_embeddings is not None:
377
- print(f"Adding face embeddings for IP-Adapter...")
378
 
379
- # --- FIX 4: Convert numpy array to torch tensor, add batch dim, and move to device ---
380
- face_embeds_tensor = torch.tensor(face_embeddings, dtype=self.dtype, device=self.device).unsqueeze(0)
381
- pipe_kwargs["image_embeds"] = face_embeds_tensor
382
 
383
- # Control IP-Adapter strength
384
- boosted_scale = identity_preservation * IDENTITY_BOOST_MULTIPLIER
385
- pipe_kwargs["ip_adapter_scale"] = boosted_scale
 
 
386
 
387
- # --- FIX 4: Update log to show tensor shape ---
388
- print(f" - Face embeddings tensor shape: {face_embeds_tensor.shape}")
389
- print(f" - IP-Adapter scale: {boosted_scale:.2f}")
390
- print(f" [OK] Face embeddings configured")
 
 
 
 
 
 
 
 
 
 
 
 
391
  else:
392
- print(" [WARNING] No face embeddings - using keypoints only")
393
-
394
- else:
395
- print("No faces detected - using Depth ControlNet only")
 
396
 
397
- # Use depth for both ControlNet slots (identity scale = 0)
398
- pipe_kwargs["control_image"] = [depth_image, depth_image]
399
- pipe_kwargs["controlnet_conditioning_scale"] = [0.0, depth_control_scale]
400
-
401
- # Generate
402
- print(f"Generating with LCM: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}")
403
- print(f"ControlNet scales - Identity: {identity_control_scale}, Depth: {depth_control_scale}")
404
- result = self.pipe(**pipe_kwargs)
405
-
406
- generated_image = result.images[0]
407
-
408
- # Post-processing: Color matching
409
- if enable_color_matching and has_detected_faces:
410
- print("Applying enhanced face-aware color matching...")
411
- try:
412
- if face_bbox_original is not None:
413
- generated_image = enhanced_color_match(
414
- generated_image,
415
- resized_image,
416
- face_bbox=face_bbox_original
417
- )
418
- print("[OK] Enhanced color matching applied (face-aware)")
419
- else:
 
 
420
  generated_image = color_match(generated_image, resized_image, mode='mkl')
421
  print("[OK] Standard color matching applied")
422
- except Exception as e:
423
- print(f"Color matching failed: {e}")
424
- elif enable_color_matching:
425
- print("Applying standard color matching...")
426
- try:
427
- generated_image = color_match(generated_image, resized_image, mode='mkl')
428
- print("[OK] Standard color matching applied")
429
- except Exception as e:
430
- print(f"Color matching failed: {e}")
431
-
432
- # Memory cleanup
433
- if torch.cuda.is_available():
434
- torch.cuda.empty_cache()
435
 
436
- return generated_image
 
 
 
 
437
 
438
 
439
- print("[OK] Generator class ready with InstantID pipeline")
 
1
  """
2
+ Generation logic for Pixagram AI Pixel Art Generator
3
+ UPDATED VERSION with InstantID pipeline integration
4
  """
5
  import torch
6
  import numpy as np
7
  import cv2
8
  from PIL import Image
9
+ import gc
 
10
 
11
  from config import (
12
+ device, dtype, TRIGGER_WORD,
13
+ ADAPTIVE_THRESHOLDS, ADAPTIVE_PARAMS, CAPTION_CONFIG
14
  )
15
  from utils import (
16
+ sanitize_text, enhanced_color_match, color_match,
17
+ get_demographic_description, calculate_optimal_size, safe_image_size
 
18
  )
19
  from models import (
20
  load_face_analysis, load_depth_detector, load_controlnets,
 
24
 
25
 
26
  class RetroArtConverter:
27
+ """Main class for retro art generation with InstantID"""
28
 
29
  def __init__(self):
30
  self.device = device
 
36
  'zoe_depth': False
37
  }
38
 
39
+ # Load face analysis
40
  self.face_app, self.face_detection_enabled = load_face_analysis()
41
 
42
+ # Load depth detector
43
  self.zoe_depth, zoe_success = load_depth_detector()
44
  self.models_loaded['zoe_depth'] = zoe_success
45
 
46
+ # Load ControlNets AS LIST
47
  controlnet_instantid, controlnet_depth = load_controlnets()
48
  controlnets = [controlnet_instantid, controlnet_depth]
49
  self.models_loaded['instantid'] = True
 
61
  # Setup Compel
62
  self.compel, self.use_compel = setup_compel(self.pipe)
63
 
64
+ # Setup scheduler
65
  setup_scheduler(self.pipe)
66
 
67
+ # Optimize
68
  optimize_pipeline(self.pipe)
69
 
70
  # Load caption model
71
  self.caption_processor, self.caption_model, self.caption_enabled, self.caption_model_type = load_caption_model()
72
 
 
 
 
 
 
 
 
73
  # Set CLIP skip
74
  set_clip_skip(self.pipe)
75
 
76
+ # Print status
77
  self._print_status()
78
 
79
+ print(" [OK] RetroArtConverter initialized with InstantID!")
80
 
81
  def _print_status(self):
82
  """Print model loading status"""
 
84
  for model, loaded in self.models_loaded.items():
85
  status = "[OK] LOADED" if loaded else "[FALLBACK/DISABLED]"
86
  print(f"{model}: {status}")
87
+ print("InstantID Pipeline: [OK] ACTIVE")
88
+ print("IP-Adapter: [OK] Built into pipeline")
89
  print("===================\n")
90
 
91
  def get_depth_map(self, image):
92
+ """Generate depth map using Zoe Depth"""
93
+ if self.zoe_depth is not None:
94
+ try:
95
+ if image.mode != 'RGB':
96
+ image = image.convert('RGB')
97
+
98
+ # Use safe size helper to avoid numpy.int64 issues
99
+ orig_width, orig_height = safe_image_size(image)
100
+
101
+ # Use multiples of 64
102
+ target_width = int((orig_width // 64) * 64)
103
+ target_height = int((orig_height // 64) * 64)
104
+
105
+ target_width = int(max(64, target_width))
106
+ target_height = int(max(64, target_height))
107
+
108
+ size_for_depth = (int(target_width), int(target_height))
109
+ image_for_depth = image.resize(size_for_depth, Image.LANCZOS)
110
+
111
+ depth_array = self.zoe_depth(image_for_depth, detect_resolution=512, image_resolution=1024)
112
+ depth_image = Image.fromarray(depth_array)
113
+
114
+ if depth_image.size != image.size:
115
+ depth_image = depth_image.resize(image.size, Image.LANCZOS)
116
+
117
+ print(f"[DEPTH] Generated depth map: {depth_image.size}")
118
+ return depth_image, depth_array
119
+ except Exception as e:
120
+ print(f"[DEPTH] Generation failed: {e}, using grayscale")
121
+ return image.convert('L').convert('RGB'), None
122
+ else:
123
+ print("[DEPTH] Detector not available, using grayscale")
124
+ return image.convert('L').convert('RGB'), None
125
+
126
+ def add_trigger_word(self, prompt):
127
+ """Add trigger word to prompt if not present"""
128
+ if TRIGGER_WORD.lower() not in prompt.lower():
129
+ if not prompt or not prompt.strip():
130
+ return TRIGGER_WORD
131
+ return f"{TRIGGER_WORD}, {prompt}"
132
+ return prompt
133
+
134
+ def detect_face_quality(self, face):
135
+ """Detect face quality and adaptively adjust parameters"""
136
+ try:
137
+ bbox = face.bbox
138
+ face_size = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
139
+ det_score = float(face.det_score) if hasattr(face, 'det_score') else 1.0
140
+
141
+ # Small face -> boost preservation
142
+ if face_size < ADAPTIVE_THRESHOLDS['small_face_size']:
143
+ return ADAPTIVE_PARAMS['small_face'].copy()
144
+
145
+ # Low confidence -> boost preservation
146
+ elif det_score < ADAPTIVE_THRESHOLDS['low_confidence']:
147
+ return ADAPTIVE_PARAMS['low_confidence'].copy()
148
+
149
+ # Check for profile view
150
+ elif hasattr(face, 'pose') and len(face.pose) > 1:
151
  try:
152
+ yaw = float(face.pose[1])
153
+ if abs(yaw) > ADAPTIVE_THRESHOLDS['profile_angle']:
154
+ return ADAPTIVE_PARAMS['profile_view'].copy()
155
+ except (ValueError, TypeError, IndexError):
156
+ pass
157
+
158
+ return None
159
+
160
+ except Exception as e:
161
+ print(f"[ADAPTIVE] Quality detection failed: {e}")
162
+ return None
163
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  def generate_caption(self, image):
165
+ """Generate caption for image"""
166
  if not self.caption_enabled or self.caption_model is None:
167
  return None
168
 
169
  try:
170
  if self.caption_model_type == 'git':
171
+ inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device)
172
+ generated_ids = self.caption_model.generate(**inputs, max_length=CAPTION_CONFIG['max_length'])
 
 
 
 
 
 
173
  caption = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
174
  elif self.caption_model_type == 'blip':
175
+ inputs = self.caption_processor(image, return_tensors="pt").to(self.device)
176
+ generated_ids = self.caption_model.generate(**inputs, max_length=CAPTION_CONFIG['max_length'])
177
+ caption = self.caption_processor.decode(generated_ids[0], skip_special_tokens=True)
 
 
178
  else:
179
  return None
180
 
181
  return sanitize_text(caption)
 
182
  except Exception as e:
183
+ print(f"[CAPTION] Generation failed: {e}")
184
  return None
185
 
186
  def generate_retro_art(
187
  self,
188
  input_image,
189
+ prompt=" ",
190
+ negative_prompt=" ",
191
  num_inference_steps=12,
192
  guidance_scale=1.3,
193
  depth_control_scale=0.75,
 
199
  consistency_mode=True,
200
  seed=-1
201
  ):
202
+ """Generate retro art with InstantID face preservation"""
 
 
 
 
203
 
204
+ try:
205
+ # Add trigger word
206
+ prompt = self.add_trigger_word(prompt)
207
+ prompt = sanitize_text(prompt)
208
+ negative_prompt = sanitize_text(negative_prompt)
 
 
209
 
210
+ print(f"[PROMPT] {prompt}")
 
 
211
 
212
+ # Calculate optimal size
213
+ orig_width, orig_height = safe_image_size(input_image)
214
+ optimal_width, optimal_height = calculate_optimal_size(orig_width, orig_height)
215
+
216
+ # Resize image
217
+ resized_image = input_image.resize((optimal_width, optimal_height), Image.LANCZOS)
218
+ print(f"[SIZE] Resized to {optimal_width}x{optimal_height}")
219
+
220
+ # Generate depth map
221
+ depth_image, depth_array = self.get_depth_map(resized_image)
222
+
223
+ # Detect faces
224
+ has_detected_faces = False
225
+ face_kps_image = None
226
+ face_embeddings = None
227
+ face_bbox_original = None
228
+
229
+ if self.face_detection_enabled and self.face_app is not None:
230
+ try:
231
+ image_array = cv2.cvtColor(np.array(resized_image), cv2.COLOR_RGB2BGR)
232
+ faces = self.face_app.get(image_array)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
+ if len(faces) > 0:
235
+ has_detected_faces = True
236
+ face = faces[0]
237
+
238
+ # Get face embeddings (512D array)
239
+ face_embeddings = face.normed_embedding
240
+
241
+ # Draw keypoints
242
+ from pipeline_stable_diffusion_xl_instantid_img2img import draw_kps
243
+ face_kps_image = draw_kps(resized_image, face.kps)
244
+
245
+ # Get bbox for color matching
246
+ face_bbox_original = face.bbox
247
+
248
+ # Adaptive parameter adjustment
249
+ adaptive_params = self.detect_face_quality(face)
250
+ if adaptive_params:
251
+ print(f"[ADAPTIVE] {adaptive_params['reason']}")
252
+ identity_preservation = adaptive_params.get('identity_preservation', identity_preservation)
253
+ identity_control_scale = adaptive_params.get('identity_control_scale', identity_control_scale)
254
+ guidance_scale = adaptive_params.get('guidance_scale', guidance_scale)
255
+ lora_scale = adaptive_params.get('lora_scale', lora_scale)
256
+
257
+ print(f"[FACE] Detected face with {face.det_score:.2f} confidence")
258
+ print(f"[FACE] Embeddings shape: {face_embeddings.shape}")
259
+ else:
260
+ print("[FACE] No faces detected")
261
 
262
+ except Exception as e:
263
+ print(f"[FACE] Detection failed: {e}")
264
+
265
+ # Set LORA scale
266
+ if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
267
+ try:
268
+ self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
269
+ print(f"[LORA] Scale: {lora_scale}")
270
+ except Exception as e:
271
+ print(f"[LORA] Could not set scale: {e}")
272
+
273
+ # Prepare generation kwargs
274
+ pipe_kwargs = {
275
+ "image": resized_image,
276
+ "strength": strength,
277
+ "num_inference_steps": num_inference_steps,
278
+ "guidance_scale": guidance_scale,
279
+ }
280
+
281
+ # Setup generator with seed
282
+ if seed == -1:
283
+ generator = torch.Generator(device=self.device)
284
+ actual_seed = generator.seed()
285
+ print(f"[SEED] Random: {actual_seed}")
286
+ else:
287
+ generator = torch.Generator(device=self.device).manual_seed(seed)
288
+ actual_seed = seed
289
+ print(f"[SEED] Fixed: {actual_seed}")
290
+
291
+ pipe_kwargs["generator"] = generator
292
+
293
+ # Use Compel for prompt encoding
294
+ if self.use_compel and self.compel is not None:
295
+ try:
296
+ conditioning = self.compel(prompt)
297
+ negative_conditioning = self.compel(negative_prompt)
298
 
299
+ pipe_kwargs["prompt_embeds"] = conditioning[0]
300
+ pipe_kwargs["pooled_prompt_embeds"] = conditioning[1]
301
+ pipe_kwargs["negative_prompt_embeds"] = negative_conditioning[0]
302
+ pipe_kwargs["negative_pooled_prompt_embeds"] = negative_conditioning[1]
303
+
304
+ print("[OK] Using Compel-encoded prompts")
305
+ except Exception as e:
306
+ print(f"[COMPEL] Failed, using standard prompts: {e}")
307
+ pipe_kwargs["prompt"] = prompt
308
+ pipe_kwargs["negative_prompt"] = negative_prompt
309
+ else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  pipe_kwargs["prompt"] = prompt
311
  pipe_kwargs["negative_prompt"] = negative_prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
312
 
313
+ # Configure ControlNets + IP-Adapter (SIMPLIFIED!)
314
+ if has_detected_faces and face_kps_image is not None:
315
+ print("Using InstantID (keypoints + embeddings) + Depth ControlNets")
 
 
 
 
 
 
 
 
 
 
316
 
317
+ # Control images: [face keypoints, depth map]
318
+ pipe_kwargs["control_image"] = [face_kps_image, depth_image]
 
319
 
320
+ # Conditioning scales: [identity, depth]
321
+ pipe_kwargs["controlnet_conditioning_scale"] = [
322
+ identity_control_scale,
323
+ depth_control_scale
324
+ ]
325
 
326
+ # IP-Adapter face embeddings (SIMPLE - pipeline handles everything!)
327
+ if face_embeddings is not None:
328
+ print(f"Adding face embeddings for IP-Adapter...")
329
+
330
+ # Just pass the embeddings - pipeline does the rest!
331
+ pipe_kwargs["image_embeds"] = face_embeddings
332
+
333
+ # Control IP-Adapter strength
334
+ pipe_kwargs["ip_adapter_scale"] = identity_preservation
335
+
336
+ print(f" - Face embeddings shape: {face_embeddings.shape}")
337
+ print(f" - IP-Adapter scale: {identity_preservation}")
338
+ print(f" [OK] Face embeddings configured")
339
+ else:
340
+ print(" [WARNING] No face embeddings - using keypoints only")
341
+
342
  else:
343
+ print("No faces detected - using Depth ControlNet only")
344
+
345
+ # Use depth for both ControlNet slots (identity scale = 0)
346
+ pipe_kwargs["control_image"] = [depth_image, depth_image]
347
+ pipe_kwargs["controlnet_conditioning_scale"] = [0.0, depth_control_scale]
348
 
349
+ # Generate
350
+ print(f"Generating: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}")
351
+ result = self.pipe(**pipe_kwargs)
352
+
353
+ generated_image = result.images[0]
354
+
355
+ # Post-processing: Color matching
356
+ if enable_color_matching and has_detected_faces:
357
+ print("Applying enhanced face-aware color matching...")
358
+ try:
359
+ if face_bbox_original is not None:
360
+ generated_image = enhanced_color_match(
361
+ generated_image,
362
+ resized_image,
363
+ face_bbox=face_bbox_original
364
+ )
365
+ print("[OK] Enhanced color matching applied")
366
+ else:
367
+ generated_image = color_match(generated_image, resized_image, mode='mkl')
368
+ print("[OK] Standard color matching applied")
369
+ except Exception as e:
370
+ print(f"[COLOR] Matching failed: {e}")
371
+ elif enable_color_matching:
372
+ print("Applying standard color matching...")
373
+ try:
374
  generated_image = color_match(generated_image, resized_image, mode='mkl')
375
  print("[OK] Standard color matching applied")
376
+ except Exception as e:
377
+ print(f"[COLOR] Matching failed: {e}")
378
+
379
+ return generated_image
 
 
 
 
 
 
 
 
 
380
 
381
+ finally:
382
+ # Memory cleanup
383
+ if torch.cuda.is_available():
384
+ torch.cuda.empty_cache()
385
+ gc.collect()
386
 
387
 
388
+ print("[OK] Generator class ready with InstantID support")