primerz committed
Commit bddeb26 · verified · 1 Parent(s): 0b87c27

Update app.py

Files changed (1)
  1. app.py +138 -50
app.py CHANGED
@@ -30,7 +30,7 @@ class RetroArtConverter:
         self.device = device
         self.dtype = dtype
 
-        # Initialize face analysis for InstantID (optional)
+        # Initialize face analysis for InstantID
         print("Loading face analysis model...")
         try:
             self.face_app = FaceAnalysis(
@@ -54,6 +54,22 @@ class RetroArtConverter:
             torch_dtype=self.dtype
         ).to(self.device)
 
+        # Load InstantID ControlNet for identity preservation
+        print("Loading InstantID ControlNet...")
+        try:
+            self.controlnet_instantid = ControlNetModel.from_pretrained(
+                "InstantX/InstantID",
+                subfolder="ControlNetModel",
+                torch_dtype=self.dtype
+            ).to(self.device)
+            print("✓ InstantID ControlNet loaded successfully")
+            self.instantid_enabled = True
+        except Exception as e:
+            print(f"⚠️ InstantID ControlNet not available: {e}")
+            print("Running without InstantID (identity may not be preserved)")
+            self.controlnet_instantid = None
+            self.instantid_enabled = False
+
         # Load custom VAE from HuggingFace Hub
         print("Loading custom VAE (pixelate) from HuggingFace Hub...")
         try:
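
The block added above loads the ControlNet weights from the ControlNetModel subfolder of the InstantX/InstantID Hub repo. A minimal prefetch sketch, assuming that repo's published file layout, for warming the cache before the pipeline starts:

    from huggingface_hub import hf_hub_download

    # Filenames assumed from the InstantX/InstantID repo layout
    hf_hub_download("InstantX/InstantID", "ControlNetModel/config.json")
    hf_hub_download("InstantX/InstantID", "ControlNetModel/diffusion_pytorch_model.safetensors")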
@@ -83,6 +99,11 @@ class RetroArtConverter:
             device=self.device if self.device == "cuda" else -1
         )
 
+        # Determine which controlnets to use
+        controlnets = [self.controlnet_depth]
+        if self.instantid_enabled and self.controlnet_instantid is not None:
+            controlnets.append(self.controlnet_instantid)
+
         # Load SDXL checkpoint from HuggingFace Hub
         print("Loading SDXL checkpoint (horizon) from HuggingFace Hub...")
         try:
@@ -93,7 +114,7 @@ class RetroArtConverter:
             )
             self.pipe = StableDiffusionXLControlNetPipeline.from_single_file(
                 model_path,
-                controlnet=self.controlnet_depth,
+                controlnet=controlnets if len(controlnets) > 1 else controlnets[0],
                 vae=self.vae,
                 torch_dtype=self.dtype,
                 use_safetensors=True
@@ -104,7 +125,7 @@ class RetroArtConverter:
             print("Using default SDXL")
             self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
                 "stabilityai/stable-diffusion-xl-base-1.0",
-                controlnet=self.controlnet_depth,
+                controlnet=controlnets if len(controlnets) > 1 else controlnets[0],
                 vae=self.vae,
                 torch_dtype=self.dtype,
                 use_safetensors=True
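
Both constructors accept either a single ControlNetModel or a list of them; diffusers wraps a list in a MultiControlNetModel, and the call site must then pass one control image and one conditioning scale per ControlNet, in the same order. A sketch of that call shape, with variable names taken from this diff:

    # One control image and one scale per ControlNet, in matching order
    result = pipe(
        prompt="retro pixel art",
        image=[depth_image, face_control_image],
        controlnet_conditioning_scale=[0.8, 0.2],  # depth strength, InstantID strength
        num_inference_steps=30,
    )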
@@ -129,15 +150,9 @@ class RetroArtConverter:
             self.pipe.scheduler.config
         )
 
-        # For ZeroGPU, don't use model_cpu_offload
-        # self.pipe.enable_model_cpu_offload()
-
         self.pipe.enable_vae_slicing()
-
-        # Enable attention slicing for memory efficiency
         self.pipe.unet.set_attn_processor(AttnProcessor2_0())
 
-        # Try to enable xformers if available (only works on GPU)
         if self.device == "cuda":
             try:
                 self.pipe.enable_xformers_memory_efficient_attention()
@@ -152,30 +167,44 @@ class RetroArtConverter:
         depth = self.depth_estimator(image)
         depth_image = depth['depth']
 
-        # Convert to numpy array
         depth_array = np.array(depth_image)
-
-        # Normalize to 0-255
         depth_normalized = (depth_array - depth_array.min()) / (depth_array.max() - depth_array.min()) * 255
         depth_normalized = depth_normalized.astype(np.uint8)
-
-        # Convert to 3-channel image
        depth_colored = cv2.cvtColor(depth_normalized, cv2.COLOR_GRAY2RGB)
 
         return Image.fromarray(depth_colored)
 
-    def detect_faces(self, image):
-        """Detect faces in the image using antelopev2"""
+    def extract_face_embeddings(self, image):
+        """Extract face embeddings using InsightFace"""
         if not self.face_detection_enabled or self.face_app is None:
-            return []
+            return None
 
         try:
             img_array = np.array(image)
             faces = self.face_app.get(img_array)
-            return faces
+
+            if len(faces) == 0:
+                return None
+
+            # Use the largest face
+            face = sorted(faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[-1]
+            return torch.from_numpy(face.normed_embedding).unsqueeze(0)
         except Exception as e:
-            print(f"Face detection error: {e}")
-            return []
+            print(f"Face embedding extraction error: {e}")
+            return None
+
+    def prepare_face_image(self, image, face_bbox):
+        """Prepare face image for InstantID ControlNet"""
+        x1, y1, x2, y2 = map(int, face_bbox)
+        # Add some padding
+        padding = 20
+        x1 = max(0, x1 - padding)
+        y1 = max(0, y1 - padding)
+        x2 = min(image.width, x2 + padding)
+        y2 = min(image.height, y2 + padding)
+
+        face_image = image.crop((x1, y1, x2, y2))
+        return face_image
 
     def calculate_target_size(self, original_width, original_height, max_dimension=1024):
         """Calculate target size maintaining aspect ratio"""
@@ -188,7 +217,7 @@ class RetroArtConverter:
             new_height = min(original_height, max_dimension)
             new_width = int(new_height * aspect_ratio)
 
-        # Round to nearest multiple of 8 (required for diffusion models)
+        # Round to nearest multiple of 8
         new_width = (new_width // 8) * 8
         new_height = (new_height // 8) * 8
 
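
A worked example of the branch shown above: an 800×1200 portrait with max_dimension=1024 scales to 682×1024, and the multiple-of-8 rounding brings it to 680×1024:

    aspect_ratio = 800 / 1200                   # width / height
    new_height = min(1200, 1024)                # 1024
    new_width = int(new_height * aspect_ratio)  # 682
    new_width = (new_width // 8) * 8            # 680
    new_height = (new_height // 8) * 8          # 1024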
@@ -202,7 +231,9 @@ class RetroArtConverter:
         num_inference_steps=30,
         guidance_scale=7.5,
         controlnet_conditioning_scale=0.8,
-        lora_scale=0.85
+        lora_scale=0.85,
+        identity_preservation=0.8,  # NEW PARAMETER
+        image_scale=0.2  # NEW PARAMETER for InstantID strength
     ):
         """Main generation function"""
 
@@ -214,36 +245,70 @@ class RetroArtConverter:
 
         resized_image = input_image.resize((target_width, target_height), Image.LANCZOS)
 
-        # Detect faces
-        faces = self.detect_faces(resized_image)
-        has_faces = len(faces) > 0
-
-        if has_faces:
-            print(f"Detected {len(faces)} face(s)")
-            # Enhance prompt for face preservation
-            prompt = f"portrait, detailed face, {prompt}"
-
         # Generate depth map
         print("Generating depth map...")
         depth_image = self.get_depth_map(resized_image)
         depth_image = depth_image.resize((target_width, target_height), Image.LANCZOS)
 
+        # Extract face embeddings if InstantID is enabled
+        face_embeddings = None
+        control_images = [depth_image]
+        conditioning_scales = [controlnet_conditioning_scale]
+
+        if self.instantid_enabled and self.controlnet_instantid is not None:
+            print("Extracting face embeddings...")
+            img_array = np.array(resized_image)
+            faces = self.face_app.get(img_array) if self.face_app is not None else []
+
+            if len(faces) > 0:
+                print(f"Detected {len(faces)} face(s), using for identity preservation")
+                # Get the largest face
+                face = sorted(faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[-1]
+                face_embeddings = torch.from_numpy(face.normed_embedding).unsqueeze(0).to(self.device, dtype=self.dtype)
+
+                # Prepare face image for InstantID ControlNet
+                face_control_image = resized_image.resize((target_width, target_height), Image.LANCZOS)
+                control_images.append(face_control_image)
+                conditioning_scales.append(image_scale)
+
+                # Enhance prompt for face preservation
+                prompt = f"portrait, detailed face, facial features, {prompt}"
+
         # Set LORA scale
-        self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
+        if hasattr(self.pipe, 'set_adapters'):
+            try:
+                self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
+            except:
+                print("Could not set LORA adapters, continuing without")
+
+        # Prepare pipeline kwargs
+        pipe_kwargs = {
+            "prompt": prompt,
+            "negative_prompt": negative_prompt,
+            "num_inference_steps": num_inference_steps,
+            "guidance_scale": guidance_scale,
+            "width": target_width,
+            "height": target_height,
+            "generator": torch.Generator(device=self.device).manual_seed(42)
+        }
+
+        # Add control images and scales
+        if len(control_images) > 1:
+            # Multiple ControlNets
+            pipe_kwargs["image"] = control_images
+            pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
+        else:
+            # Single ControlNet (depth only)
+            pipe_kwargs["image"] = depth_image
+            pipe_kwargs["controlnet_conditioning_scale"] = controlnet_conditioning_scale
+
+        # Add face embeddings if available (for InstantID IP-Adapter)
+        if face_embeddings is not None:
+            pipe_kwargs["cross_attention_kwargs"] = {"ip_adapter_image_embeds": [face_embeddings]}
 
         # Generate image
         print("Generating retro art...")
-        result = self.pipe(
-            prompt=prompt,
-            negative_prompt=negative_prompt,
-            image=depth_image,
-            num_inference_steps=num_inference_steps,
-            guidance_scale=guidance_scale,
-            controlnet_conditioning_scale=controlnet_conditioning_scale,
-            width=target_width,
-            height=target_height,
-            generator=torch.Generator(device=self.device).manual_seed(42)
-        )
+        result = self.pipe(**pipe_kwargs)
 
         return result.images[0]
 
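
A minimal usage sketch of the updated generate() (file names hypothetical). Note that identity_preservation is accepted and wired through from the UI but is not referenced in the body of generate() in this commit; only image_scale sets the InstantID ControlNet strength:

    converter = RetroArtConverter()
    out = converter.generate(
        Image.open("photo.jpg").convert("RGB"),
        prompt="retro pixel art, 16-bit game character",
        negative_prompt="blurry, modern",
        controlnet_conditioning_scale=0.8,
        lora_scale=0.85,
        image_scale=0.2,  # InstantID ControlNet weight when a face is detected
    )
    out.save("retro.png")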
@@ -260,7 +325,9 @@ def process_image(
     steps,
     guidance_scale,
     controlnet_scale,
-    lora_scale
+    lora_scale,
+    identity_preservation,  # NEW
+    image_scale  # NEW
 ):
     if image is None:
         return None
@@ -273,11 +340,15 @@
             num_inference_steps=int(steps),
             guidance_scale=guidance_scale,
             controlnet_conditioning_scale=controlnet_scale,
-            lora_scale=lora_scale
+            lora_scale=lora_scale,
+            identity_preservation=identity_preservation,  # NEW
+            image_scale=image_scale  # NEW
         )
         return result
     except Exception as e:
         print(f"Error: {e}")
+        import traceback
+        traceback.print_exc()
         raise gr.Error(f"Generation failed: {str(e)}")
 
 # Create Gradio interface
@@ -291,7 +362,7 @@ with gr.Blocks(title="RetroArt Converter") as demo:
     - Custom SDXL checkpoint (Horizon)
     - Pixelate VAE for authentic retro look
     - RetroArt LORA for style enhancement
-    - Face preservation with InstantID
+    - Face preservation with InstantID (if available)
     - Depth-aware generation with ControlNet
     """)
 
@@ -343,6 +414,23 @@ with gr.Blocks(title="RetroArt Converter") as demo:
                 step=0.05,
                 label="RetroArt LORA Scale"
             )
+
+            # NEW PARAMETERS
+            identity_preservation = gr.Slider(
+                minimum=0,
+                maximum=1.5,
+                value=0.8,
+                step=0.1,
+                label="Identity Preservation (InstantID strength)"
+            )
+
+            image_scale = gr.Slider(
+                minimum=0,
+                maximum=1.0,
+                value=0.2,
+                step=0.05,
+                label="InstantID Image Scale"
+            )
 
             generate_btn = gr.Button("🎨 Generate Retro Art", variant="primary")
 
@@ -351,9 +439,9 @@ with gr.Blocks(title="RetroArt Converter") as demo:
 
     gr.Examples(
         examples=[
-            ["example_portrait.jpg", "retro pixel art portrait, 16-bit game character", "blurry, modern", 30, 7.5, 0.8, 0.85],
+            ["example_portrait.jpg", "retro pixel art portrait, 16-bit game character", "blurry, modern", 30, 7.5, 0.8, 0.85, 0.8, 0.2],
         ],
-        inputs=[input_image, prompt, negative_prompt, steps, guidance_scale, controlnet_scale, lora_scale],
+        inputs=[input_image, prompt, negative_prompt, steps, guidance_scale, controlnet_scale, lora_scale, identity_preservation, image_scale],
         outputs=[output_image],
         fn=process_image,
         cache_examples=False
361
 
362
  generate_btn.click(
363
  fn=process_image,
364
- inputs=[input_image, prompt, negative_prompt, steps, guidance_scale, controlnet_scale, lora_scale],
365
  outputs=[output_image]
366
  )
367
 
@@ -372,5 +460,5 @@ if __name__ == "__main__":
         server_name="0.0.0.0",
         server_port=7860,
         share=False,
-        show_api=True  # Enable API
+        show_api=True
     )
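
With show_api=True the endpoint stays callable programmatically. A sketch with gradio_client, where the Space id and api_name are assumptions and the positional arguments mirror the inputs list above:

    from gradio_client import Client, handle_file

    client = Client("primerz/retroart-converter")  # hypothetical Space id
    result = client.predict(
        handle_file("photo.jpg"),      # input image
        "retro pixel art portrait",    # prompt
        "blurry, modern",              # negative prompt
        30, 7.5, 0.8, 0.85, 0.8, 0.2,  # steps, guidance, controlnet, lora, identity, image scale
        api_name="/predict",
    )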