primerz committed on
Commit
962b8c2
·
verified ·
1 Parent(s): 28615cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -210
app.py CHANGED
@@ -2,17 +2,18 @@ import spaces # MUST be first, before any CUDA-related imports
2
  import gradio as gr
3
  import torch
4
  from diffusers import (
 
5
  StableDiffusionXLControlNetPipeline,
6
  ControlNetModel,
7
  AutoencoderKL,
8
- LCMScheduler
9
  )
10
  from diffusers.models.attention_processor import AttnProcessor2_0
11
  from insightface.app import FaceAnalysis
12
  from PIL import Image
13
  import numpy as np
14
  import cv2
15
- from transformers import pipeline as transformers_pipeline, CLIPImageProcessor
16
  from huggingface_hub import hf_hub_download
17
  import os
18
 
@@ -21,8 +22,12 @@ MODEL_REPO = "primerz/pixagram"
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
  dtype = torch.float16 if device == "cuda" else torch.float32
23
 
 
 
 
24
  print(f"Using device: {device}")
25
  print(f"Loading models from: {MODEL_REPO}")
 
26
 
27
  class RetroArtConverter:
28
  def __init__(self):
@@ -30,7 +35,6 @@ class RetroArtConverter:
30
  self.dtype = dtype
31
  self.models_loaded = {
32
  'custom_checkpoint': False,
33
- 'custom_vae': False,
34
  'lora': False,
35
  'instantid': False
36
  }
@@ -58,7 +62,7 @@ class RetroArtConverter:
58
  torch_dtype=self.dtype
59
  ).to(self.device)
60
 
61
- # Load InstantID ControlNet
62
  print("Loading InstantID ControlNet...")
63
  try:
64
  self.controlnet_instantid = ControlNetModel.from_pretrained(
@@ -74,42 +78,6 @@ class RetroArtConverter:
74
  self.controlnet_instantid = None
75
  self.instantid_enabled = False
76
 
77
- # Load IP-Adapter for InstantID
78
- print("Loading IP-Adapter for InstantID...")
79
- try:
80
- from transformers import CLIPVisionModelWithProjection
81
- self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
82
- "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
83
- torch_dtype=self.dtype
84
- ).to(self.device)
85
- print("โœ“ IP-Adapter image encoder loaded")
86
- except Exception as e:
87
- print(f"โš ๏ธ IP-Adapter not available: {e}")
88
- self.image_encoder = None
89
-
90
- # Load custom VAE from HuggingFace Hub
91
- print("Loading custom VAE (pixelate) from HuggingFace Hub...")
92
- try:
93
- vae_path = hf_hub_download(
94
- repo_id=MODEL_REPO,
95
- filename="pixelate.safetensors",
96
- repo_type="model"
97
- )
98
- self.vae = AutoencoderKL.from_single_file(
99
- vae_path,
100
- torch_dtype=self.dtype
101
- ).to(self.device)
102
- print("โœ“ Custom VAE loaded successfully")
103
- self.models_loaded['custom_vae'] = True
104
- except Exception as e:
105
- print(f"โš ๏ธ Could not load custom VAE: {e}")
106
- print("Using high-quality SDXL VAE")
107
- self.vae = AutoencoderKL.from_pretrained(
108
- "madebyollin/sdxl-vae-fp16-fix",
109
- torch_dtype=self.dtype
110
- ).to(self.device)
111
- self.models_loaded['custom_vae'] = False
112
-
113
  # Load depth estimator
114
  print("Loading depth estimator...")
115
  self.depth_estimator = transformers_pipeline(
@@ -118,7 +86,7 @@ class RetroArtConverter:
118
  device=self.device if self.device == "cuda" else -1
119
  )
120
 
121
- # Determine controlnets configuration
122
  if self.instantid_enabled and self.controlnet_instantid is not None:
123
  controlnets = [self.controlnet_depth, self.controlnet_instantid]
124
  print(f"Initializing with multiple ControlNets: Depth + InstantID")
@@ -127,7 +95,8 @@ class RetroArtConverter:
127
  print(f"Initializing with single ControlNet: Depth only")
128
 
129
  # Load SDXL checkpoint from HuggingFace Hub
130
- print("Loading SDXL checkpoint (horizon) from HuggingFace Hub...")
 
131
  try:
132
  model_path = hf_hub_download(
133
  repo_id=MODEL_REPO,
@@ -137,19 +106,17 @@ class RetroArtConverter:
137
  self.pipe = StableDiffusionXLControlNetPipeline.from_single_file(
138
  model_path,
139
  controlnet=controlnets,
140
- vae=self.vae,
141
  torch_dtype=self.dtype,
142
  use_safetensors=True
143
  ).to(self.device)
144
- print("โœ“ Custom checkpoint loaded successfully")
145
  self.models_loaded['custom_checkpoint'] = True
146
  except Exception as e:
147
  print(f"โš ๏ธ Could not load custom checkpoint: {e}")
148
- print("Using default SDXL")
149
  self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
150
  "stabilityai/stable-diffusion-xl-base-1.0",
151
  controlnet=controlnets,
152
- vae=self.vae,
153
  torch_dtype=self.dtype,
154
  use_safetensors=True
155
  ).to(self.device)
@@ -164,24 +131,23 @@ class RetroArtConverter:
164
  repo_type="model"
165
  )
166
  self.pipe.load_lora_weights(lora_path)
167
- print("โœ“ LORA loaded successfully")
 
168
  self.models_loaded['lora'] = True
169
  except Exception as e:
170
  print(f"โš ๏ธ Could not load LORA: {e}")
171
  self.models_loaded['lora'] = False
172
 
173
- # CRITICAL: Set LCM Scheduler for fast generation
174
  print("Setting up LCM scheduler...")
175
  self.pipe.scheduler = LCMScheduler.from_config(
176
  self.pipe.scheduler.config
177
  )
178
 
179
- # Disable VAE slicing for better quality
180
- # self.pipe.enable_vae_slicing()
181
-
182
- # Enable memory optimizations
183
  self.pipe.unet.set_attn_processor(AttnProcessor2_0())
184
 
 
185
  if self.device == "cuda":
186
  try:
187
  self.pipe.enable_xformers_memory_efficient_attention()
@@ -189,8 +155,14 @@ class RetroArtConverter:
189
  except Exception as e:
190
  print(f"โš ๏ธ xformers not available: {e}")
191
 
 
 
 
 
 
192
  # Track controlnet configuration
193
  self.using_multiple_controlnets = isinstance(controlnets, list)
 
194
 
195
  print("\n=== MODEL STATUS ===")
196
  for model, loaded in self.models_loaded.items():
@@ -198,7 +170,15 @@ class RetroArtConverter:
198
  print(f"{model}: {status}")
199
  print("===================\n")
200
 
201
- print("Model initialization complete!")
 
 
 
 
 
 
 
 
202
 
203
  def get_depth_map(self, image):
204
  """Generate depth map from input image"""
@@ -215,59 +195,73 @@ class RetroArtConverter:
215
  # Slight blur to reduce noise
216
  depth_normalized = cv2.GaussianBlur(depth_normalized, (3, 3), 0)
217
 
 
218
  depth_colored = cv2.cvtColor(depth_normalized, cv2.COLOR_GRAY2RGB)
219
 
220
  return Image.fromarray(depth_colored)
221
 
222
- def calculate_target_size(self, original_width, original_height, preferred_resolution="896x1152"):
223
- """Calculate target size based on recommended SDXL resolutions"""
224
- # Recommended resolutions for this model
225
- resolutions = {
226
- "896x1152": (896, 1152), # Portrait
227
- "832x1216": (832, 1216), # Tall portrait
228
- "1152x896": (1152, 896), # Landscape
229
- "1216x832": (1216, 832), # Wide landscape
230
- "1024x1024": (1024, 1024) # Square
231
- }
232
-
233
  aspect_ratio = original_width / original_height
234
 
235
- # Choose resolution based on aspect ratio
236
- if aspect_ratio < 0.85: # Tall portrait
237
- target_width, target_height = resolutions["832x1216"]
238
- elif aspect_ratio < 1.15: # Portrait to square
239
- if aspect_ratio < 1.0:
240
- target_width, target_height = resolutions["896x1152"]
241
- else:
242
- target_width, target_height = resolutions["1024x1024"]
243
- elif aspect_ratio < 1.35: # Landscape
244
- target_width, target_height = resolutions["1152x896"]
245
- else: # Wide landscape
246
- target_width, target_height = resolutions["1216x832"]
247
-
248
- return target_width, target_height
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
  def generate_retro_art(
251
  self,
252
  input_image,
253
- prompt="retro pixel art game, 16-bit style, vibrant colors",
254
- negative_prompt="blurry, low quality, modern, photorealistic, 3d render",
255
- num_inference_steps=12, # LCM default: 12 steps
256
- guidance_scale=1.5, # LCM default: 1-1.5
257
- controlnet_conditioning_scale=0.6,
258
- lora_scale=0.85,
259
- identity_scale=0.9, # Stronger identity preservation
260
- image_scale=0.5, # Stronger InstantID influence
261
- clip_skip=2 # SDXL clip skip
262
  ):
263
- """Main generation function with LCM optimization"""
 
 
 
264
 
265
- # Calculate target size
266
  original_width, original_height = input_image.size
267
- target_width, target_height = self.calculate_target_size(original_width, original_height)
268
 
269
  print(f"Resizing from {original_width}x{original_height} to {target_width}x{target_height}")
 
270
 
 
271
  resized_image = input_image.resize((target_width, target_height), Image.LANCZOS)
272
 
273
  # Generate depth map
@@ -275,81 +269,59 @@ class RetroArtConverter:
275
  depth_image = self.get_depth_map(resized_image)
276
  depth_image = depth_image.resize((target_width, target_height), Image.LANCZOS)
277
 
278
- # IMPORTANT: Add LORA trigger word
279
- lora_trigger = "p1x3l4rt, pixel art"
280
- if lora_trigger not in prompt:
281
- prompt = f"{lora_trigger}, {prompt}"
282
- print(f"Added LORA trigger word: {lora_trigger}")
283
-
284
- # Check if using multiple controlnets
285
  using_multiple_controlnets = self.using_multiple_controlnets
286
-
287
- # Extract face embeddings for InstantID
288
  face_embeddings = None
289
  has_detected_faces = False
290
 
291
- if using_multiple_controlnets and self.face_app is not None:
292
- print("Extracting face embeddings...")
293
  img_array = np.array(resized_image)
294
- faces = self.face_app.get(img_array)
295
 
296
  if len(faces) > 0:
297
  has_detected_faces = True
298
  print(f"Detected {len(faces)} face(s)")
299
-
300
- # Get the largest face
301
  face = sorted(faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[-1]
302
-
303
- # Extract embedding
304
  face_embeddings = torch.from_numpy(face.normed_embedding).unsqueeze(0).to(self.device, dtype=self.dtype)
305
-
306
- # Enhance prompt for better face preservation
307
- prompt = f"detailed face, portrait, facial features, {prompt}"
308
- print(f"Face detected, enhanced prompt for identity preservation")
309
 
310
  # Set LORA scale
311
  if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
312
  try:
313
  self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
 
314
  except Exception as e:
315
  print(f"Could not set LORA scale: {e}")
316
 
317
- # Enhanced negative prompt
318
- full_negative = f"{negative_prompt}, worst quality, normal quality, lowres, watermark, text"
319
-
320
- # Prepare pipeline kwargs
321
  pipe_kwargs = {
322
  "prompt": prompt,
323
- "negative_prompt": full_negative,
324
  "num_inference_steps": num_inference_steps,
325
  "guidance_scale": guidance_scale,
326
  "width": target_width,
327
  "height": target_height,
328
- "generator": torch.Generator(device=self.device).manual_seed(42),
329
- "clip_skip": clip_skip
330
  }
331
 
332
- # Configure control images based on setup
 
 
 
 
333
  if using_multiple_controlnets and has_detected_faces:
334
- print(f"Using Depth + InstantID (identity_scale={identity_scale}, image_scale={image_scale})")
335
-
336
- # For InstantID, use the original image
337
  control_images = [depth_image, resized_image]
338
  conditioning_scales = [controlnet_conditioning_scale, image_scale]
339
 
340
  pipe_kwargs["image"] = control_images
341
  pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
342
 
343
- # Add face embeddings with stronger influence
344
  if face_embeddings is not None:
345
- # Scale up the face embeddings for stronger identity
346
- scaled_embeddings = face_embeddings * identity_scale
347
- pipe_kwargs["cross_attention_kwargs"] = {
348
- "ip_adapter_image_embeds": [scaled_embeddings]
349
- }
350
 
351
  elif using_multiple_controlnets and not has_detected_faces:
352
- print("Multiple ControlNets but no faces detected")
353
  control_images = [depth_image, depth_image]
354
  conditioning_scales = [controlnet_conditioning_scale, 0.0]
355
 
@@ -362,16 +334,15 @@ class RetroArtConverter:
362
  pipe_kwargs["controlnet_conditioning_scale"] = controlnet_conditioning_scale
363
 
364
  # Generate
365
- print(f"Generating with LCM: {num_inference_steps} steps, CFG {guidance_scale}")
366
  result = self.pipe(**pipe_kwargs)
367
 
368
  return result.images[0]
369
 
370
  # Initialize converter
371
- print("Initializing RetroArt Converter with LCM...")
372
  converter = RetroArtConverter()
373
 
374
- # Gradio interface
375
  @spaces.GPU
376
  def process_image(
377
  image,
@@ -381,7 +352,7 @@ def process_image(
381
  guidance_scale,
382
  controlnet_scale,
383
  lora_scale,
384
- identity_scale,
385
  image_scale
386
  ):
387
  if image is None:
@@ -396,9 +367,8 @@ def process_image(
396
  guidance_scale=guidance_scale,
397
  controlnet_conditioning_scale=controlnet_scale,
398
  lora_scale=lora_scale,
399
- identity_scale=identity_scale,
400
- image_scale=image_scale,
401
- clip_skip=2
402
  )
403
  return result
404
  except Exception as e:
@@ -407,100 +377,103 @@ def process_image(
407
  traceback.print_exc()
408
  raise gr.Error(f"Generation failed: {str(e)}")
409
 
410
- # Create Gradio interface
411
  with gr.Blocks(title="RetroArt Converter - LCM", theme=gr.themes.Soft()) as demo:
412
  gr.Markdown("""
413
- # ๐ŸŽฎ RetroArt Converter - LCM Optimized
414
 
415
- Convert images to retro pixel art using **LCM (Latent Consistency Model)** for fast generation!
416
 
417
- **Key Features:**
418
- - โšก Fast generation (12 steps)
419
- - ๐ŸŽจ LORA trigger: "p1x3l4rt, pixel art" (auto-added)
420
- - ๐Ÿ‘ค Strong InstantID for face preservation
421
- - ๐ŸŽฏ Optimized SDXL resolutions (896x1152, 832x1216)
422
- - ๐Ÿ“ Clip Skip 2
423
  """)
424
 
425
  # Model status
426
  if converter.models_loaded:
427
- status_md = "**Model Status:**\n"
428
- status_md += f"- Custom Checkpoint: {'โœ“' if converter.models_loaded['custom_checkpoint'] else 'โœ— Fallback'}\n"
429
- status_md += f"- Custom VAE: {'โœ“' if converter.models_loaded['custom_vae'] else 'โœ— Fallback'}\n"
430
- status_md += f"- LORA: {'โœ“' if converter.models_loaded['lora'] else 'โœ— Fallback'}\n"
431
- status_md += f"- InstantID: {'โœ“' if converter.models_loaded['instantid'] else 'โœ— Disabled'}\n"
432
- gr.Markdown(status_md)
 
 
 
 
 
 
 
 
433
 
434
  with gr.Row():
435
  with gr.Column():
436
  input_image = gr.Image(label="Input Image", type="pil")
437
 
438
  prompt = gr.Textbox(
439
- label='Prompt (trigger "p1x3l4rt, pixel art" auto-added)',
440
- value="retro pixel art game, 16-bit style, vibrant colors, detailed",
441
- lines=2,
442
- info="Don't include trigger word - it's added automatically"
443
  )
444
 
445
  negative_prompt = gr.Textbox(
446
  label="Negative Prompt",
447
- value="blurry, low quality, modern, photorealistic, 3d render, ugly, distorted",
448
  lines=2
449
  )
450
 
451
- gr.Markdown("### โšก LCM Settings (Optimized)")
452
-
453
- with gr.Row():
454
  steps = gr.Slider(
455
  minimum=4,
456
  maximum=20,
457
  value=12,
458
  step=1,
459
- label="Steps (LCM recommended: 12)"
460
  )
461
 
462
  guidance_scale = gr.Slider(
463
- minimum=1.0,
464
  maximum=3.0,
465
- value=1.5,
466
  step=0.1,
467
- label="CFG Scale (LCM recommended: 1-1.5)"
468
  )
469
-
470
- with gr.Accordion("Advanced Settings", open=False):
471
  controlnet_scale = gr.Slider(
472
- minimum=0,
473
- maximum=1.5,
474
- value=0.6,
475
  step=0.05,
476
  label="ControlNet Depth Scale"
477
  )
478
 
479
  lora_scale = gr.Slider(
480
- minimum=0,
481
- maximum=2,
482
- value=0.85,
483
  step=0.05,
484
  label="RetroArt LORA Scale"
485
  )
486
 
487
- gr.Markdown("### ๐Ÿ‘ค InstantID Settings (Stronger)")
488
-
489
- with gr.Row():
490
- identity_scale = gr.Slider(
491
- minimum=0.5,
492
- maximum=2.0,
493
- value=0.9,
494
  step=0.1,
495
- label="Identity Strength (higher = more truthful)"
496
  )
497
 
498
  image_scale = gr.Slider(
499
  minimum=0,
500
- maximum=1.5,
501
- value=0.5,
502
  step=0.05,
503
- label="InstantID ControlNet Scale"
504
  )
505
 
506
  generate_btn = gr.Button("๐ŸŽจ Generate Retro Art", variant="primary", size="lg")
@@ -509,42 +482,29 @@ with gr.Blocks(title="RetroArt Converter - LCM", theme=gr.themes.Soft()) as demo
509
  output_image = gr.Image(label="Retro Art Output")
510
 
511
  gr.Markdown("""
512
- ### โšก LCM Quick Tips:
513
- - **12 steps** is optimal for LCM (faster than traditional 40-50)
514
- - **CFG 1-1.5** works best (not 7-8 like traditional)
515
- - LORA trigger **"p1x3l4rt, pixel art"** is auto-added
516
- - For stronger identity: increase **Identity Strength** to 1.2-1.5
517
- - Resolution auto-selected: 896x1152 (portrait) or 1152x896 (landscape)
 
 
 
 
 
 
518
 
519
- ### ๐Ÿ‘ค Face Preservation:
520
- - **Identity Strength 0.9-1.2**: Balanced retro + identity
521
- - **Identity Strength 1.3-2.0**: Maximum face accuracy
522
- - **Image Scale 0.5-0.8**: Strong InstantID influence
523
  """)
524
 
525
- gr.Examples(
526
- examples=[
527
- [
528
- "example_portrait.jpg",
529
- "retro pixel art portrait, 16-bit game character, detailed face",
530
- "blurry, modern, low quality",
531
- 12, 1.5, 0.6, 0.85, 0.9, 0.5
532
- ],
533
- ],
534
- inputs=[
535
- input_image, prompt, negative_prompt, steps, guidance_scale,
536
- controlnet_scale, lora_scale, identity_scale, image_scale
537
- ],
538
- outputs=[output_image],
539
- fn=process_image,
540
- cache_examples=False
541
- )
542
-
543
  generate_btn.click(
544
  fn=process_image,
545
  inputs=[
546
  input_image, prompt, negative_prompt, steps, guidance_scale,
547
- controlnet_scale, lora_scale, identity_scale, image_scale
548
  ],
549
  outputs=[output_image]
550
  )
 
2
  import gradio as gr
3
  import torch
4
  from diffusers import (
5
+ StableDiffusionXLPipeline,
6
  StableDiffusionXLControlNetPipeline,
7
  ControlNetModel,
8
  AutoencoderKL,
9
+ LCMScheduler # CORRECT SCHEDULER FOR LCM
10
  )
11
  from diffusers.models.attention_processor import AttnProcessor2_0
12
  from insightface.app import FaceAnalysis
13
  from PIL import Image
14
  import numpy as np
15
  import cv2
16
+ from transformers import pipeline as transformers_pipeline
17
  from huggingface_hub import hf_hub_download
18
  import os
19
 
 
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
  dtype = torch.float16 if device == "cuda" else torch.float32
24
 
25
+ # LORA trigger word
26
+ TRIGGER_WORD = "p1x3l4rt, pixel art"
27
+
28
  print(f"Using device: {device}")
29
  print(f"Loading models from: {MODEL_REPO}")
30
+ print(f"LORA Trigger Word: {TRIGGER_WORD}")
31
 
32
  class RetroArtConverter:
33
  def __init__(self):
 
35
  self.dtype = dtype
36
  self.models_loaded = {
37
  'custom_checkpoint': False,
 
38
  'lora': False,
39
  'instantid': False
40
  }
 
62
  torch_dtype=self.dtype
63
  ).to(self.device)
64
 
65
+ # Load InstantID ControlNet (optional)
66
  print("Loading InstantID ControlNet...")
67
  try:
68
  self.controlnet_instantid = ControlNetModel.from_pretrained(
 
78
  self.controlnet_instantid = None
79
  self.instantid_enabled = False
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  # Load depth estimator
82
  print("Loading depth estimator...")
83
  self.depth_estimator = transformers_pipeline(
 
86
  device=self.device if self.device == "cuda" else -1
87
  )
88
 
89
+ # Determine which controlnets to use
90
  if self.instantid_enabled and self.controlnet_instantid is not None:
91
  controlnets = [self.controlnet_depth, self.controlnet_instantid]
92
  print(f"Initializing with multiple ControlNets: Depth + InstantID")
 
95
  print(f"Initializing with single ControlNet: Depth only")
96
 
97
  # Load SDXL checkpoint from HuggingFace Hub
98
+ # NOTE: VAE is bundled in the checkpoint, don't load separately!
99
+ print("Loading SDXL checkpoint (horizon) with bundled VAE from HuggingFace Hub...")
100
  try:
101
  model_path = hf_hub_download(
102
  repo_id=MODEL_REPO,
 
106
  self.pipe = StableDiffusionXLControlNetPipeline.from_single_file(
107
  model_path,
108
  controlnet=controlnets,
 
109
  torch_dtype=self.dtype,
110
  use_safetensors=True
111
  ).to(self.device)
112
+ print("โœ“ Custom checkpoint loaded successfully (VAE bundled)")
113
  self.models_loaded['custom_checkpoint'] = True
114
  except Exception as e:
115
  print(f"โš ๏ธ Could not load custom checkpoint: {e}")
116
+ print("Using default SDXL base model")
117
  self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
118
  "stabilityai/stable-diffusion-xl-base-1.0",
119
  controlnet=controlnets,
 
120
  torch_dtype=self.dtype,
121
  use_safetensors=True
122
  ).to(self.device)
 
131
  repo_type="model"
132
  )
133
  self.pipe.load_lora_weights(lora_path)
134
+ print(f"โœ“ LORA loaded successfully")
135
+ print(f" Trigger word: '{TRIGGER_WORD}'")
136
  self.models_loaded['lora'] = True
137
  except Exception as e:
138
  print(f"โš ๏ธ Could not load LORA: {e}")
139
  self.models_loaded['lora'] = False
140
 
141
+ # CRITICAL: Use LCM Scheduler for this model!
142
  print("Setting up LCM scheduler...")
143
  self.pipe.scheduler = LCMScheduler.from_config(
144
  self.pipe.scheduler.config
145
  )
146
 
147
+ # Enable attention optimizations
 
 
 
148
  self.pipe.unet.set_attn_processor(AttnProcessor2_0())
149
 
150
+ # Try to enable xformers
151
  if self.device == "cuda":
152
  try:
153
  self.pipe.enable_xformers_memory_efficient_attention()
 
155
  except Exception as e:
156
  print(f"โš ๏ธ xformers not available: {e}")
157
 
158
+ # Set CLIP skip to 2
159
+ if hasattr(self.pipe, 'text_encoder'):
160
+ self.clip_skip = 2
161
+ print(f"โœ“ CLIP skip set to {self.clip_skip}")
162
+
163
  # Track controlnet configuration
164
  self.using_multiple_controlnets = isinstance(controlnets, list)
165
+ print(f"Pipeline initialized with {'multiple' if self.using_multiple_controlnets else 'single'} ControlNet(s)")
166
 
167
  print("\n=== MODEL STATUS ===")
168
  for model, loaded in self.models_loaded.items():
 
170
  print(f"{model}: {status}")
171
  print("===================\n")
172
 
173
+ print("โœ“ Model initialization complete!")
174
+ print("\n=== LCM CONFIGURATION ===")
175
+ print("Scheduler: LCM")
176
+ print("Recommended Steps: 12")
177
+ print("Recommended CFG: 1.0-1.5")
178
+ print("Recommended Resolution: 896x1152 or 832x1216")
179
+ print("CLIP Skip: 2")
180
+ print(f"LORA Trigger: '{TRIGGER_WORD}'")
181
+ print("=========================\n")
182
 
183
  def get_depth_map(self, image):
184
  """Generate depth map from input image"""
 
195
  # Slight blur to reduce noise
196
  depth_normalized = cv2.GaussianBlur(depth_normalized, (3, 3), 0)
197
 
198
+ # Convert to RGB
199
  depth_colored = cv2.cvtColor(depth_normalized, cv2.COLOR_GRAY2RGB)
200
 
201
  return Image.fromarray(depth_colored)
202
 
203
def calculate_optimal_size(self, original_width, original_height):
    """Pick the recommended resolution closest to the input's aspect ratio.

    Args:
        original_width: Width of the source image in pixels.
        original_height: Height of the source image in pixels.

    Returns:
        A ``(width, height)`` tuple chosen from the fixed candidate list,
        with each side floored to a multiple of 8.
    """
    source_ratio = original_width / original_height

    # Candidate (width, height) pairs. Order matters: when two candidates
    # are equally close, the earlier one is selected.
    candidates = (
        (896, 1152),   # Portrait
        (1152, 896),   # Landscape
        (832, 1216),   # Tall portrait
        (1216, 832),   # Wide landscape
        (1024, 1024),  # Square
    )

    def ratio_gap(size):
        # Absolute (linear) distance between a candidate's aspect ratio
        # and the source image's aspect ratio.
        return abs(size[0] / size[1] - source_ratio)

    # min() returns the first item with the smallest key, so ties are
    # resolved by position in the candidate list.
    chosen_width, chosen_height = min(candidates, key=ratio_gap)

    # Floor each side to a multiple of 8.
    return chosen_width - chosen_width % 8, chosen_height - chosen_height % 8
233
+
234
def add_trigger_word(self, prompt):
    """Return *prompt* with the LORA trigger phrase prepended when missing.

    The check is case-insensitive and looks for the module-level
    ``TRIGGER_WORD`` anywhere inside the prompt.

    Args:
        prompt: User-supplied prompt text. May be ``None``, empty, or
            whitespace-only, in which case only the trigger phrase is
            returned.

    Returns:
        The prompt unchanged if it already contains the trigger phrase,
        otherwise ``"<TRIGGER_WORD>, <prompt>"``; just ``TRIGGER_WORD``
        when there is no usable prompt text.
    """
    # An empty prompt would otherwise produce a dangling "<trigger>, "
    # and None would fail on .lower(); fall back to the bare trigger.
    if prompt is None or not prompt.strip():
        return TRIGGER_WORD

    if TRIGGER_WORD.lower() in prompt.lower():
        return prompt

    return f"{TRIGGER_WORD}, {prompt}"
239
 
240
  def generate_retro_art(
241
  self,
242
  input_image,
243
+ prompt="retro game character, vibrant colors, detailed",
244
+ negative_prompt="blurry, low quality, ugly, distorted",
245
+ num_inference_steps=12, # LCM recommended: 12 steps
246
+ guidance_scale=1.0, # LCM recommended: 1.0-1.5
247
+ controlnet_conditioning_scale=0.8,
248
+ lora_scale=1.0,
249
+ identity_preservation=0.8,
250
+ image_scale=0.2
 
251
  ):
252
+ """Generate retro art with correct LCM settings"""
253
+
254
+ # Add trigger word to prompt
255
+ prompt = self.add_trigger_word(prompt)
256
 
257
+ # Calculate optimal size
258
  original_width, original_height = input_image.size
259
+ target_width, target_height = self.calculate_optimal_size(original_width, original_height)
260
 
261
  print(f"Resizing from {original_width}x{original_height} to {target_width}x{target_height}")
262
+ print(f"Prompt: {prompt}")
263
 
264
+ # Resize with high quality
265
  resized_image = input_image.resize((target_width, target_height), Image.LANCZOS)
266
 
267
  # Generate depth map
 
269
  depth_image = self.get_depth_map(resized_image)
270
  depth_image = depth_image.resize((target_width, target_height), Image.LANCZOS)
271
 
272
+ # Handle face detection for InstantID
 
 
 
 
 
 
273
  using_multiple_controlnets = self.using_multiple_controlnets
 
 
274
  face_embeddings = None
275
  has_detected_faces = False
276
 
277
+ if using_multiple_controlnets:
278
+ print("Checking for faces...")
279
  img_array = np.array(resized_image)
280
+ faces = self.face_app.get(img_array) if self.face_app is not None else []
281
 
282
  if len(faces) > 0:
283
  has_detected_faces = True
284
  print(f"Detected {len(faces)} face(s)")
 
 
285
  face = sorted(faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[-1]
 
 
286
  face_embeddings = torch.from_numpy(face.normed_embedding).unsqueeze(0).to(self.device, dtype=self.dtype)
 
 
 
 
287
 
288
  # Set LORA scale
289
  if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
290
  try:
291
  self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
292
+ print(f"LORA scale: {lora_scale}")
293
  except Exception as e:
294
  print(f"Could not set LORA scale: {e}")
295
 
296
+ # Prepare generation kwargs
 
 
 
297
  pipe_kwargs = {
298
  "prompt": prompt,
299
+ "negative_prompt": negative_prompt,
300
  "num_inference_steps": num_inference_steps,
301
  "guidance_scale": guidance_scale,
302
  "width": target_width,
303
  "height": target_height,
304
+ "generator": torch.Generator(device=self.device).manual_seed(42)
 
305
  }
306
 
307
+ # Add CLIP skip
308
+ if hasattr(self.pipe, 'text_encoder'):
309
+ pipe_kwargs["clip_skip"] = 2
310
+
311
+ # Configure ControlNet inputs
312
  if using_multiple_controlnets and has_detected_faces:
313
+ print("Using Depth + InstantID ControlNets")
 
 
314
  control_images = [depth_image, resized_image]
315
  conditioning_scales = [controlnet_conditioning_scale, image_scale]
316
 
317
  pipe_kwargs["image"] = control_images
318
  pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
319
 
 
320
  if face_embeddings is not None:
321
+ pipe_kwargs["cross_attention_kwargs"] = {"ip_adapter_image_embeds": [face_embeddings]}
 
 
 
 
322
 
323
  elif using_multiple_controlnets and not has_detected_faces:
324
+ print("Multiple ControlNets available but no faces detected")
325
  control_images = [depth_image, depth_image]
326
  conditioning_scales = [controlnet_conditioning_scale, 0.0]
327
 
 
334
  pipe_kwargs["controlnet_conditioning_scale"] = controlnet_conditioning_scale
335
 
336
  # Generate
337
+ print(f"Generating with LCM: Steps={num_inference_steps}, CFG={guidance_scale}")
338
  result = self.pipe(**pipe_kwargs)
339
 
340
  return result.images[0]
341
 
342
  # Initialize converter
343
+ print("Initializing RetroArt Converter...")
344
  converter = RetroArtConverter()
345
 
 
346
  @spaces.GPU
347
  def process_image(
348
  image,
 
352
  guidance_scale,
353
  controlnet_scale,
354
  lora_scale,
355
+ identity_preservation,
356
  image_scale
357
  ):
358
  if image is None:
 
367
  guidance_scale=guidance_scale,
368
  controlnet_conditioning_scale=controlnet_scale,
369
  lora_scale=lora_scale,
370
+ identity_preservation=identity_preservation,
371
+ image_scale=image_scale
 
372
  )
373
  return result
374
  except Exception as e:
 
377
  traceback.print_exc()
378
  raise gr.Error(f"Generation failed: {str(e)}")
379
 
380
+ # Gradio UI
381
  with gr.Blocks(title="RetroArt Converter - LCM", theme=gr.themes.Soft()) as demo:
382
  gr.Markdown("""
383
+ # ๐ŸŽฎ RetroArt Converter (LCM Optimized)
384
 
385
+ Convert images into retro pixel art style using LCM (Latent Consistency Model) for fast, high-quality generation!
386
 
387
+ **โœจ Features:**
388
+ - โšก Ultra-fast generation (12 steps!)
389
+ - ๐ŸŽจ Custom pixel art LORA with trigger word: `p1x3l4rt, pixel art`
390
+ - ๐Ÿ“ Optimized resolutions: 896x1152 / 832x1216
391
+ - ๐Ÿ–ผ๏ธ Bundled VAE for authentic retro look
392
+ - ๐ŸŽฏ CLIP Skip 2 for better style
393
  """)
394
 
395
  # Model status
396
  if converter.models_loaded:
397
+ status_text = "**๐Ÿ“ฆ Loaded Models:**\n"
398
+ status_text += f"- Custom Checkpoint (Horizon): {'โœ“ Loaded' if converter.models_loaded['custom_checkpoint'] else 'โœ— Using SDXL base'}\n"
399
+ status_text += f"- LORA (RetroArt): {'โœ“ Loaded' if converter.models_loaded['lora'] else 'โœ— Disabled'}\n"
400
+ status_text += f"- InstantID: {'โœ“ Loaded' if converter.models_loaded['instantid'] else 'โœ— Disabled'}\n"
401
+ gr.Markdown(status_text)
402
+
403
+ gr.Markdown(f"""
404
+ **โš™๏ธ LCM Configuration:**
405
+ - Scheduler: LCM (Latent Consistency Model)
406
+ - Recommended Steps: **12** (fast!)
407
+ - Recommended CFG: **1.0-1.5** (lower than normal)
408
+ - CLIP Skip: **2**
409
+ - LORA Trigger: `{TRIGGER_WORD}` (auto-added)
410
+ """)
411
 
412
  with gr.Row():
413
  with gr.Column():
414
  input_image = gr.Image(label="Input Image", type="pil")
415
 
416
  prompt = gr.Textbox(
417
+ label="Prompt (trigger word auto-added)",
418
+ value="retro game character, vibrant colors, highly detailed",
419
+ lines=3,
420
+ info=f"'{TRIGGER_WORD}' will be automatically added"
421
  )
422
 
423
  negative_prompt = gr.Textbox(
424
  label="Negative Prompt",
425
+ value="blurry, low quality, ugly, distorted, deformed, bad anatomy",
426
  lines=2
427
  )
428
 
429
+ with gr.Accordion("โšก LCM Settings (Optimized)", open=True):
 
 
430
  steps = gr.Slider(
431
  minimum=4,
432
  maximum=20,
433
  value=12,
434
  step=1,
435
+ label="Inference Steps (LCM works great with just 12!)"
436
  )
437
 
438
  guidance_scale = gr.Slider(
439
+ minimum=0.5,
440
  maximum=3.0,
441
+ value=1.0,
442
  step=0.1,
443
+ label="Guidance Scale (CFG) - LCM uses 1.0-1.5"
444
  )
445
+
 
446
  controlnet_scale = gr.Slider(
447
+ minimum=0.3,
448
+ maximum=1.2,
449
+ value=0.8,
450
  step=0.05,
451
  label="ControlNet Depth Scale"
452
  )
453
 
454
  lora_scale = gr.Slider(
455
+ minimum=0.5,
456
+ maximum=1.5,
457
+ value=1.0,
458
  step=0.05,
459
  label="RetroArt LORA Scale"
460
  )
461
 
462
+ with gr.Accordion("๐ŸŽญ Identity Settings (for portraits)", open=False):
463
+ identity_preservation = gr.Slider(
464
+ minimum=0,
465
+ maximum=1.5,
466
+ value=0.8,
 
 
467
  step=0.1,
468
+ label="Identity Preservation"
469
  )
470
 
471
  image_scale = gr.Slider(
472
  minimum=0,
473
+ maximum=1.0,
474
+ value=0.2,
475
  step=0.05,
476
+ label="InstantID Image Scale"
477
  )
478
 
479
  generate_btn = gr.Button("๐ŸŽจ Generate Retro Art", variant="primary", size="lg")
 
482
  output_image = gr.Image(label="Retro Art Output")
483
 
484
  gr.Markdown("""
485
+ ### ๐Ÿ’ก Tips for Best Results:
486
+
487
+ **For LCM Models:**
488
+ - โœ… Use **12 steps** (already optimized!)
489
+ - โœ… Keep CFG at **1.0-1.5** (not 7.5!)
490
+ - โœ… LORA trigger word is **auto-added**
491
+ - โœ… Resolution auto-optimized to 896x1152 or 832x1216
492
+
493
+ **For Quality:**
494
+ - Use high-resolution input images
495
+ - Be specific in prompts: "16-bit game character" vs "character"
496
+ - Adjust ControlNet scale: lower = more creative, higher = more faithful
497
 
498
+ **For Style:**
499
+ - Increase LORA scale (1.0-1.5) for stronger pixel art effect
500
+ - Try prompts like: "SNES style", "16-bit RPG", "Game Boy advance style"
 
501
  """)
502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
  generate_btn.click(
504
  fn=process_image,
505
  inputs=[
506
  input_image, prompt, negative_prompt, steps, guidance_scale,
507
+ controlnet_scale, lora_scale, identity_preservation, image_scale
508
  ],
509
  outputs=[output_image]
510
  )