DawnC committed on
Commit
0e80bb9
·
verified ·
1 Parent(s): dc7f2ba

Upload 14 files

Browse files
Files changed (6) hide show
  1. BackgroundEngine.py +205 -3
  2. app.py +3 -1
  3. mask_generator.py +185 -2
  4. requirements.txt +4 -0
  5. style_transfer.py +708 -0
  6. ui_manager.py +557 -13
BackgroundEngine.py CHANGED
@@ -10,7 +10,7 @@ from typing import Optional, Dict, Any, Callable
10
  import warnings
11
  warnings.filterwarnings("ignore")
12
 
13
- from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
14
  import open_clip
15
  from mask_generator import MaskGenerator
16
  from image_blender import ImageBlender
@@ -39,10 +39,12 @@ class BackgroundEngine:
39
  self.clip_pretrained = "openai"
40
 
41
  self.pipeline = None
 
42
  self.clip_model = None
43
  self.clip_preprocess = None
44
  self.clip_tokenizer = None
45
  self.is_initialized = False
 
46
 
47
  self.max_image_size = 1024
48
  self.default_steps = 25
@@ -336,13 +338,15 @@ class BackgroundEngine:
336
  guidance_scale: float = 7.5,
337
  progress_callback: Optional[Callable] = None,
338
  enable_prompt_enhancement: bool = True,
339
- feather_radius: int = 0
 
340
  ) -> Dict[str, Any]:
341
  """
342
  Generate background and combine with foreground.
343
 
344
  Args:
345
  feather_radius: Gaussian blur radius for mask edge softening (0-20, default 0)
 
346
 
347
  Returns dict with: combined_image, generated_scene, original_image, mask, success
348
  """
@@ -391,7 +395,8 @@ class BackgroundEngine:
391
  combination_mask = self.mask_generator.create_gradient_based_mask(
392
  processed_original,
393
  combination_mode,
394
- focus_mode
 
395
  )
396
 
397
  if progress_callback:
@@ -430,3 +435,200 @@ class BackgroundEngine:
430
  "success": False,
431
  "error": str(e)
432
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  import warnings
11
  warnings.filterwarnings("ignore")
12
 
13
+ from diffusers import StableDiffusionXLPipeline, StableDiffusionXLInpaintPipeline, DPMSolverMultistepScheduler
14
  import open_clip
15
  from mask_generator import MaskGenerator
16
  from image_blender import ImageBlender
 
39
  self.clip_pretrained = "openai"
40
 
41
  self.pipeline = None
42
+ self.inpaint_pipeline = None
43
  self.clip_model = None
44
  self.clip_preprocess = None
45
  self.clip_tokenizer = None
46
  self.is_initialized = False
47
+ self.inpaint_initialized = False
48
 
49
  self.max_image_size = 1024
50
  self.default_steps = 25
 
338
  guidance_scale: float = 7.5,
339
  progress_callback: Optional[Callable] = None,
340
  enable_prompt_enhancement: bool = True,
341
+ feather_radius: int = 0,
342
+ enhance_dark_edges: bool = False
343
  ) -> Dict[str, Any]:
344
  """
345
  Generate background and combine with foreground.
346
 
347
  Args:
348
  feather_radius: Gaussian blur radius for mask edge softening (0-20, default 0)
349
+ enhance_dark_edges: Enhance mask edges for dark background images (default False)
350
 
351
  Returns dict with: combined_image, generated_scene, original_image, mask, success
352
  """
 
395
  combination_mask = self.mask_generator.create_gradient_based_mask(
396
  processed_original,
397
  combination_mode,
398
+ focus_mode,
399
+ enhance_dark_edges=enhance_dark_edges
400
  )
401
 
402
  if progress_callback:
 
435
  "success": False,
436
  "error": str(e)
437
  }
438
+
439
+ def _load_inpaint_pipeline(self) -> bool:
440
+ """Lazy load SDXL inpainting pipeline"""
441
+ if self.inpaint_initialized:
442
+ return True
443
+
444
+ try:
445
+ logger.info("Loading SDXL inpainting pipeline...")
446
+ actual_device = "cuda" if torch.cuda.is_available() else self.device
447
+
448
+ self.inpaint_pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
449
+ "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
450
+ torch_dtype=torch.float16 if actual_device == "cuda" else torch.float32,
451
+ variant="fp16" if actual_device == "cuda" else None,
452
+ use_safetensors=True
453
+ )
454
+ self.inpaint_pipeline.to(actual_device)
455
+
456
+ # Use fast scheduler
457
+ self.inpaint_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
458
+ self.inpaint_pipeline.scheduler.config
459
+ )
460
+
461
+ # Memory optimization
462
+ if actual_device == "cuda":
463
+ try:
464
+ self.inpaint_pipeline.enable_xformers_memory_efficient_attention()
465
+ except Exception:
466
+ pass
467
+
468
+ self.inpaint_initialized = True
469
+ logger.info("βœ“ SDXL inpainting pipeline loaded")
470
+ return True
471
+
472
+ except Exception as e:
473
+ logger.error(f"Failed to load inpainting pipeline: {e}")
474
+ self.inpaint_initialized = False
475
+ return False
476
+
477
+ def inpaint_region(
478
+ self,
479
+ image: Image.Image,
480
+ mask: Image.Image,
481
+ prompt: str,
482
+ negative_prompt: str = "blurry, low quality, artifacts, seams",
483
+ num_inference_steps: int = 20,
484
+ guidance_scale: float = 7.5,
485
+ strength: float = 0.99
486
+ ) -> Dict[str, Any]:
487
+ """
488
+ Inpaint marked regions with background content.
489
+
490
+ Args:
491
+ image: The combined image with artifacts to fix
492
+ mask: Binary mask where white = areas to inpaint
493
+ prompt: Background description for inpainting
494
+ negative_prompt: What to avoid
495
+ num_inference_steps: Denoising steps (20 is usually enough)
496
+ guidance_scale: How closely to follow prompt
497
+ strength: How much to change masked area (0.99 = almost complete replacement)
498
+
499
+ Returns:
500
+ Dict with inpainted_image, success, error
501
+ """
502
+ try:
503
+ # Load inpainting pipeline if not already loaded
504
+ if not self._load_inpaint_pipeline():
505
+ # Fallback to OpenCV inpainting
506
+ return self._opencv_inpaint_fallback(image, mask)
507
+
508
+ logger.info("Starting region inpainting...")
509
+
510
+ # Prepare images
511
+ image = self._prepare_image(image)
512
+ mask = mask.resize(image.size, Image.LANCZOS).convert('L')
513
+
514
+ # Ensure mask is properly binarized
515
+ mask_array = np.array(mask)
516
+ mask_array = (mask_array > 127).astype(np.uint8) * 255
517
+ mask = Image.fromarray(mask_array, mode='L')
518
+
519
+ # Dilate mask slightly for better blending
520
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
521
+ mask_dilated = cv2.dilate(mask_array, kernel, iterations=1)
522
+ mask = Image.fromarray(mask_dilated, mode='L')
523
+
524
+ actual_device = "cuda" if torch.cuda.is_available() else self.device
525
+
526
+ with torch.inference_mode():
527
+ result = self.inpaint_pipeline(
528
+ prompt=prompt,
529
+ negative_prompt=negative_prompt,
530
+ image=image,
531
+ mask_image=mask,
532
+ width=image.size[0],
533
+ height=image.size[1],
534
+ num_inference_steps=num_inference_steps,
535
+ guidance_scale=guidance_scale,
536
+ strength=strength,
537
+ generator=torch.Generator(device=actual_device).manual_seed(42)
538
+ )
539
+
540
+ inpainted = result.images[0]
541
+
542
+ # Blend edges for smoother transition
543
+ inpainted = self._blend_inpaint_edges(image, inpainted, mask)
544
+
545
+ self._memory_cleanup()
546
+
547
+ logger.info("βœ“ Region inpainting completed")
548
+ return {
549
+ "inpainted_image": inpainted,
550
+ "success": True
551
+ }
552
+
553
+ except Exception as e:
554
+ logger.error(f"Inpainting failed: {e}")
555
+ self._memory_cleanup()
556
+ return {
557
+ "success": False,
558
+ "error": str(e)
559
+ }
560
+
561
+ def _opencv_inpaint_fallback(
562
+ self,
563
+ image: Image.Image,
564
+ mask: Image.Image
565
+ ) -> Dict[str, Any]:
566
+ """Fallback to OpenCV inpainting for small areas or when SDXL unavailable"""
567
+ try:
568
+ logger.info("Using OpenCV inpainting fallback...")
569
+
570
+ img_array = np.array(image.convert('RGB'))
571
+ mask_array = np.array(mask.convert('L'))
572
+
573
+ # Binarize mask
574
+ mask_binary = (mask_array > 127).astype(np.uint8) * 255
575
+
576
+ # Use Telea algorithm for natural results
577
+ inpainted = cv2.inpaint(
578
+ img_array,
579
+ mask_binary,
580
+ inpaintRadius=5,
581
+ flags=cv2.INPAINT_TELEA
582
+ )
583
+
584
+ result = Image.fromarray(inpainted)
585
+
586
+ logger.info("βœ“ OpenCV inpainting completed")
587
+ return {
588
+ "inpainted_image": result,
589
+ "success": True
590
+ }
591
+
592
+ except Exception as e:
593
+ logger.error(f"OpenCV inpainting failed: {e}")
594
+ return {
595
+ "success": False,
596
+ "error": str(e)
597
+ }
598
+
599
+ def _blend_inpaint_edges(
600
+ self,
601
+ original: Image.Image,
602
+ inpainted: Image.Image,
603
+ mask: Image.Image,
604
+ feather_pixels: int = 8
605
+ ) -> Image.Image:
606
+ """Blend inpainted region edges for seamless transition"""
607
+ try:
608
+ orig_array = np.array(original).astype(np.float32)
609
+ inpaint_array = np.array(inpainted).astype(np.float32)
610
+ mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0
611
+
612
+ # Create feathered mask for smooth blending
613
+ if feather_pixels > 0:
614
+ kernel_size = feather_pixels * 2 + 1
615
+ mask_feathered = cv2.GaussianBlur(
616
+ mask_array,
617
+ (kernel_size, kernel_size),
618
+ feather_pixels / 2
619
+ )
620
+ else:
621
+ mask_feathered = mask_array
622
+
623
+ # Expand mask to 3 channels
624
+ mask_3d = mask_feathered[:, :, np.newaxis]
625
+
626
+ # Blend: inpainted in masked area, original elsewhere
627
+ blended = inpaint_array * mask_3d + orig_array * (1 - mask_3d)
628
+ blended = np.clip(blended, 0, 255).astype(np.uint8)
629
+
630
+ return Image.fromarray(blended)
631
+
632
+ except Exception as e:
633
+ logger.warning(f"Edge blending failed: {e}, returning inpainted directly")
634
+ return inpainted
app.py CHANGED
@@ -16,6 +16,7 @@ import sentencepiece
16
 
17
  from FlowFacade import FlowFacade
18
  from BackgroundEngine import BackgroundEngine
 
19
  from ui_manager import UIManager
20
 
21
 
@@ -126,7 +127,8 @@ def main():
126
  try:
127
  facade = FlowFacade()
128
  background_engine = BackgroundEngine()
129
- ui_manager = UIManager(facade, background_engine)
 
130
  interface = ui_manager.create_interface()
131
  is_colab = 'google.colab' in sys.modules
132
 
 
16
 
17
  from FlowFacade import FlowFacade
18
  from BackgroundEngine import BackgroundEngine
19
+ from style_transfer import StyleTransferEngine
20
  from ui_manager import UIManager
21
 
22
 
 
127
  try:
128
  facade = FlowFacade()
129
  background_engine = BackgroundEngine()
130
+ style_engine = StyleTransferEngine()
131
+ ui_manager = UIManager(facade, background_engine, style_engine)
132
  interface = ui_manager.create_interface()
133
  is_colab = 'google.colab' in sys.modules
134
 
mask_generator.py CHANGED
@@ -15,6 +15,13 @@ from rembg import remove, new_session
15
  logger = logging.getLogger(__name__)
16
  logger.setLevel(logging.INFO)
17
 
 
 
 
 
 
 
 
18
  class MaskGenerator:
19
  """
20
  Intelligent mask generation using deep learning models with traditional fallback.
@@ -92,6 +99,146 @@ class MaskGenerator:
92
  gc.collect()
93
  logger.info("🧹 BiRefNet model unloaded")
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  def apply_guided_filter(
96
  self,
97
  mask: np.ndarray,
@@ -481,13 +628,25 @@ class MaskGenerator:
481
  logger.error(f"❌ Scene focus adjustment failed: {e}")
482
  return mask
483
 
484
- def create_gradient_based_mask(self, original_image: Image.Image, mode: str = "center", focus_mode: str = "person") -> Image.Image:
 
 
 
 
 
 
485
  """
486
  Intelligent foreground extraction: prioritize deep learning models, fallback to traditional methods
487
  Focus mode: 'person' for tight crop around person, 'scene' for including nearby objects
 
 
 
 
 
 
488
  """
489
  width, height = original_image.size
490
- logger.info(f"🎯 Creating mask for {width}x{height} image, mode: {mode}, focus: {focus_mode}")
491
 
492
  if mode == "center":
493
  # Try using deep learning models for intelligent foreground extraction
@@ -495,9 +654,33 @@ class MaskGenerator:
495
  dl_mask = self.try_deep_learning_mask(original_image)
496
  if dl_mask is not None:
497
  logger.info("βœ… Using deep learning generated mask")
 
498
  # Apply focus mode adjustments to deep learning mask
499
  if focus_mode == "scene":
500
  dl_mask = self._adjust_mask_for_scene_focus(dl_mask, original_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
501
  return dl_mask
502
 
503
  # Fallback to traditional method
 
15
  logger = logging.getLogger(__name__)
16
  logger.setLevel(logging.INFO)
17
 
18
+ # Dark background detection thresholds
19
+ DARK_BG_LUMINANCE_THRESHOLD = 50 # Average luminance below this = dark background
20
+ DARK_BG_EDGE_SAMPLE_WIDTH = 20 # Pixels from edge to sample for background detection
21
+ DARK_BG_DILATION_PIXELS = 5 # Default dilation for dark backgrounds
22
+ DARK_BG_ENHANCED_DILATION = 8 # Enhanced dilation when user enables option
23
+
24
+
25
  class MaskGenerator:
26
  """
27
  Intelligent mask generation using deep learning models with traditional fallback.
 
99
  gc.collect()
100
  logger.info("🧹 BiRefNet model unloaded")
101
 
102
+ def detect_dark_background(self, image: Image.Image, mask: Optional[np.ndarray] = None) -> Tuple[bool, float]:
103
+ """
104
+ Detect if the image has a dark background.
105
+
106
+ Analyzes the edge regions of the image (where background is likely) to determine
107
+ if the background is predominantly dark, which can cause mask detection issues.
108
+
109
+ Args:
110
+ image: Input PIL Image
111
+ mask: Optional existing mask to exclude foreground from analysis
112
+
113
+ Returns:
114
+ Tuple of (is_dark_background: bool, avg_luminance: float)
115
+ """
116
+ try:
117
+ img_array = np.array(image.convert('RGB'))
118
+ height, width = img_array.shape[:2]
119
+
120
+ # Convert to grayscale for luminance analysis
121
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
122
+
123
+ # Sample from edge regions (likely background)
124
+ edge_width = min(DARK_BG_EDGE_SAMPLE_WIDTH, width // 10, height // 10)
125
+
126
+ # Create edge sampling mask
127
+ edge_sample_mask = np.zeros((height, width), dtype=bool)
128
+ edge_sample_mask[:edge_width, :] = True # Top
129
+ edge_sample_mask[-edge_width:, :] = True # Bottom
130
+ edge_sample_mask[:, :edge_width] = True # Left
131
+ edge_sample_mask[:, -edge_width:] = True # Right
132
+
133
+ # Exclude foreground if mask is provided
134
+ if mask is not None:
135
+ foreground_mask = mask > 127
136
+ edge_sample_mask = edge_sample_mask & (~foreground_mask)
137
+
138
+ if not np.any(edge_sample_mask):
139
+ # Fallback: use corners only
140
+ corner_pixels = np.array([
141
+ gray[0, 0], gray[0, -1],
142
+ gray[-1, 0], gray[-1, -1]
143
+ ])
144
+ avg_luminance = np.mean(corner_pixels)
145
+ else:
146
+ avg_luminance = np.mean(gray[edge_sample_mask])
147
+
148
+ is_dark = avg_luminance < DARK_BG_LUMINANCE_THRESHOLD
149
+
150
+ logger.info(f"πŸ” Background analysis - Avg luminance: {avg_luminance:.1f}, Dark: {is_dark}")
151
+
152
+ return is_dark, avg_luminance
153
+
154
+ except Exception as e:
155
+ logger.error(f"❌ Dark background detection failed: {e}")
156
+ return False, 128.0 # Default: not dark
157
+
158
+ def enhance_mask_for_dark_background(
159
+ self,
160
+ mask: Image.Image,
161
+ original_image: Image.Image,
162
+ dilation_pixels: int = DARK_BG_DILATION_PIXELS,
163
+ enhance_gray_areas: bool = True
164
+ ) -> Image.Image:
165
+ """
166
+ Enhance mask for images with dark backgrounds.
167
+
168
+ Applies dilation and gray area enhancement to capture foreground elements
169
+ that may have been missed due to low contrast with dark backgrounds.
170
+
171
+ Args:
172
+ mask: Input mask PIL Image (L mode)
173
+ original_image: Original image for reference
174
+ dilation_pixels: Number of pixels to dilate the mask
175
+ enhance_gray_areas: Whether to boost gray (uncertain) areas
176
+
177
+ Returns:
178
+ Enhanced mask PIL Image
179
+ """
180
+ try:
181
+ mask_array = np.array(mask)
182
+ orig_array = np.array(original_image.convert('RGB'))
183
+
184
+ logger.info(f"πŸ”§ Enhancing mask for dark background (dilation: {dilation_pixels}px)")
185
+
186
+ # Step 1: Identify gray (uncertain) areas in the mask
187
+ if enhance_gray_areas:
188
+ gray_areas = (mask_array > 30) & (mask_array < 200)
189
+
190
+ if np.any(gray_areas):
191
+ # For gray areas, check if they're near high-confidence foreground
192
+ high_conf = mask_array >= 200
193
+
194
+ # Dilate high confidence area to find nearby gray pixels
195
+ kernel_check = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
196
+ high_conf_dilated = cv2.dilate(high_conf.astype(np.uint8), kernel_check, iterations=2)
197
+
198
+ # Gray pixels near high confidence foreground -> boost them
199
+ boost_candidates = gray_areas & (high_conf_dilated > 0)
200
+
201
+ # Boost gray areas near foreground
202
+ mask_array[boost_candidates] = np.clip(
203
+ mask_array[boost_candidates] * 1.5 + 50,
204
+ 0, 255
205
+ ).astype(np.uint8)
206
+
207
+ logger.info(f"πŸ“ˆ Boosted {np.sum(boost_candidates)} gray pixels near foreground")
208
+
209
+ # Step 2: Apply dilation to expand foreground coverage
210
+ if dilation_pixels > 0:
211
+ kernel = cv2.getStructuringElement(
212
+ cv2.MORPH_ELLIPSE,
213
+ (dilation_pixels * 2 + 1, dilation_pixels * 2 + 1)
214
+ )
215
+
216
+ # Threshold to get foreground region for dilation
217
+ fg_binary = (mask_array > 50).astype(np.uint8) * 255
218
+ fg_dilated = cv2.dilate(fg_binary, kernel, iterations=1)
219
+
220
+ # Blend: keep original high values, expand into new areas
221
+ # New areas from dilation get moderate confidence
222
+ new_areas = (fg_dilated > 0) & (mask_array < 50)
223
+ mask_array[new_areas] = 180 # Moderate confidence for expanded areas
224
+
225
+ logger.info(f"πŸ“ Dilated mask by {dilation_pixels}px, added {np.sum(new_areas)} pixels")
226
+
227
+ # Step 3: Smooth the transitions
228
+ mask_array = cv2.GaussianBlur(mask_array, (3, 3), 0.8)
229
+
230
+ # Step 4: Re-strengthen core foreground
231
+ core_fg = np.array(mask) >= 220
232
+ mask_array[core_fg] = 255
233
+
234
+ logger.info(f"βœ… Dark background enhancement complete - Final mean: {mask_array.mean():.1f}")
235
+
236
+ return Image.fromarray(mask_array, mode='L')
237
+
238
+ except Exception as e:
239
+ logger.error(f"❌ Mask enhancement failed: {e}")
240
+ return mask
241
+
242
  def apply_guided_filter(
243
  self,
244
  mask: np.ndarray,
 
628
  logger.error(f"❌ Scene focus adjustment failed: {e}")
629
  return mask
630
 
631
+ def create_gradient_based_mask(
632
+ self,
633
+ original_image: Image.Image,
634
+ mode: str = "center",
635
+ focus_mode: str = "person",
636
+ enhance_dark_edges: bool = False
637
+ ) -> Image.Image:
638
  """
639
  Intelligent foreground extraction: prioritize deep learning models, fallback to traditional methods
640
  Focus mode: 'person' for tight crop around person, 'scene' for including nearby objects
641
+
642
+ Args:
643
+ original_image: Input PIL Image
644
+ mode: Composition mode (center, left_half, right_half, full)
645
+ focus_mode: 'person' for tight crop, 'scene' for including nearby objects
646
+ enhance_dark_edges: User toggle to enhance mask for dark backgrounds
647
  """
648
  width, height = original_image.size
649
+ logger.info(f"🎯 Creating mask for {width}x{height} image, mode: {mode}, focus: {focus_mode}, enhance_dark: {enhance_dark_edges}")
650
 
651
  if mode == "center":
652
  # Try using deep learning models for intelligent foreground extraction
 
654
  dl_mask = self.try_deep_learning_mask(original_image)
655
  if dl_mask is not None:
656
  logger.info("βœ… Using deep learning generated mask")
657
+
658
  # Apply focus mode adjustments to deep learning mask
659
  if focus_mode == "scene":
660
  dl_mask = self._adjust_mask_for_scene_focus(dl_mask, original_image)
661
+
662
+ # === Dark background detection and enhancement ===
663
+ mask_array = np.array(dl_mask)
664
+ is_dark_bg, avg_luminance = self.detect_dark_background(original_image, mask_array)
665
+
666
+ if is_dark_bg or enhance_dark_edges:
667
+ # Determine dilation amount
668
+ if enhance_dark_edges:
669
+ # User explicitly enabled - use stronger dilation
670
+ dilation = DARK_BG_ENHANCED_DILATION
671
+ logger.info(f"πŸŒ™ User enabled dark edge enhancement (dilation: {dilation}px)")
672
+ else:
673
+ # Auto-detected dark background - use moderate dilation
674
+ dilation = DARK_BG_DILATION_PIXELS
675
+ logger.info(f"πŸŒ™ Auto-detected dark background (luminance: {avg_luminance:.1f}), applying enhancement")
676
+
677
+ dl_mask = self.enhance_mask_for_dark_background(
678
+ dl_mask,
679
+ original_image,
680
+ dilation_pixels=dilation,
681
+ enhance_gray_areas=True
682
+ )
683
+
684
  return dl_mask
685
 
686
  # Fallback to traditional method
requirements.txt CHANGED
@@ -20,6 +20,10 @@ rembg[gpu]
20
  scipy
21
  opencv-contrib-python
22
 
 
 
 
 
23
  # Core Dependencies
24
  torch>=2.5.0
25
  numpy
 
20
  scipy
21
  opencv-contrib-python
22
 
23
+ # 3D Cartoon Style Dependencies (SDXL + Pixar LoRA)
24
+ # Note: diffusers is already included above for I2V
25
+ # SDXL uses the same diffusers library
26
+
27
  # Core Dependencies
28
  torch>=2.5.0
29
  numpy
style_transfer.py ADDED
@@ -0,0 +1,708 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ from typing import Tuple, Optional, Dict, Any
4
+
5
+ from PIL import Image
6
+ import torch
7
+
8
+ try:
9
+ import spaces
10
+ HAS_SPACES = True
11
+ except ImportError:
12
+ HAS_SPACES = False
13
+
14
+
15
+ # Identity preservation keywords (added to all styles) - kept short for CLIP 77 token limit
16
+ IDENTITY_PRESERVE = "same person, same face, same ethnicity, same age"
17
+ IDENTITY_NEGATIVE = "different person, altered face, changed ethnicity, age change, distorted features"
18
+
19
+ # Enhanced face restore mode - concise weighted keywords
20
+ FACE_RESTORE_PRESERVE = "(same person:1.4), (preserve face:1.3), (same ethnicity:1.2), same pose, same lighting"
21
+ FACE_RESTORE_NEGATIVE = "(different person:1.4), (deformed face:1.3), wrong ethnicity, age change, western features"
22
+
23
+ # IP-Adapter settings for stronger identity preservation
24
+ # Using standard IP-Adapter (not face-specific) to avoid image encoder dependency
25
+ IP_ADAPTER_REPO = "h94/IP-Adapter"
26
+ IP_ADAPTER_SUBFOLDER = "sdxl_models"
27
+ IP_ADAPTER_WEIGHT = "ip-adapter_sdxl.bin" # Standard model, no extra encoder needed
28
+ IP_ADAPTER_SCALE_DEFAULT = 0.5 # Balance between identity and style
29
+
30
+ # Style-specific face_restore settings (some styles are more transformative)
31
+ FACE_RESTORE_STYLE_SETTINGS = {
32
+ "3d_cartoon": {"max_strength": 0.45, "lora_scale_mult": 0.7, "ip_scale": 0.4},
33
+ "anime": {"max_strength": 0.45, "lora_scale_mult": 0.7, "ip_scale": 0.4},
34
+ "illustrated_fantasy": {"max_strength": 0.42, "lora_scale_mult": 0.65, "ip_scale": 0.45},
35
+ "watercolor": {"max_strength": 0.40, "lora_scale_mult": 0.6, "ip_scale": 0.5},
36
+ "oil_painting": {"max_strength": 0.35, "lora_scale_mult": 0.5, "ip_scale": 0.6}, # Most transformative
37
+ "pixel_art": {"max_strength": 0.50, "lora_scale_mult": 0.8, "ip_scale": 0.3},
38
+ }
39
+
40
+ # Style configurations
41
+ STYLE_CONFIGS = {
42
+ "3d_cartoon": {
43
+ "name": "3D Cartoon",
44
+ "emoji": "🎬",
45
+ "lora_repo": "imagepipeline/Samaritan-3d-Cartoon-SDXL",
46
+ "lora_weight": "Samaritan 3d Cartoon.safetensors",
47
+ "prompt": "3D cartoon style, smooth rounded features, soft ambient lighting, CGI quality, vibrant colors, cel-shaded, studio render",
48
+ "negative_prompt": "ugly, deformed, noisy, blurry, low quality, flat, sketch",
49
+ "lora_scale": 0.75,
50
+ "recommended_strength": 0.55,
51
+ },
52
+ "anime": {
53
+ "name": "Anime Illustration",
54
+ "emoji": "🌸",
55
+ "lora_repo": None,
56
+ "lora_weight": None,
57
+ "prompt": "anime illustration, soft lighting, rich colors, delicate linework, smooth gradients, expressive eyes, cel shading, masterpiece",
58
+ "negative_prompt": "ugly, deformed, bad anatomy, bad hands, blurry, low quality",
59
+ "lora_scale": 0.0,
60
+ "recommended_strength": 0.50,
61
+ },
62
+ "illustrated_fantasy": {
63
+ "name": "Illustrated Fantasy",
64
+ "emoji": "πŸƒ",
65
+ "lora_repo": "ntc-ai/SDXL-LoRA-slider.Studio-Ghibli-style",
66
+ "lora_weight": "Studio Ghibli style.safetensors",
67
+ "prompt": "Ghibli style illustration, hand-painted look, soft watercolor textures, dreamy atmosphere, pastel colors, golden hour lighting, storybook quality",
68
+ "negative_prompt": "ugly, dark, horror, scary, blurry, low quality, modern",
69
+ "lora_scale": 1.0,
70
+ "recommended_strength": 0.50,
71
+ },
72
+ "watercolor": {
73
+ "name": "Watercolor Art",
74
+ "emoji": "🌊",
75
+ "lora_repo": "ostris/watercolor_style_lora_sdxl",
76
+ "lora_weight": "watercolor_style_lora.safetensors",
77
+ "prompt": "watercolor painting, wet-on-wet technique, soft color bleeds, paper texture, transparent washes, feathered edges, hand-painted",
78
+ "negative_prompt": "sharp edges, solid flat colors, harsh lines, vector art, airbrushed",
79
+ "lora_scale": 1.0,
80
+ "recommended_strength": 0.50,
81
+ },
82
+ "oil_painting": {
83
+ "name": "Classic Oil Paint",
84
+ "emoji": "πŸ–ΌοΈ",
85
+ "lora_repo": "EldritchAdam/ClassipeintXL",
86
+ "lora_weight": "ClassipeintXL.safetensors",
87
+ "prompt": "oil painting style, impasto technique, palette knife strokes, visible canvas texture, rich saturated pigments, masterful lighting, museum quality",
88
+ "negative_prompt": "flat, smooth, cartoon, anime, blurry, low quality, modern, airbrushed",
89
+ "lora_scale": 0.9,
90
+ "recommended_strength": 0.50,
91
+ },
92
+ "pixel_art": {
93
+ "name": "Pixel Art",
94
+ "emoji": "πŸ‘Ύ",
95
+ "lora_repo": "nerijs/pixel-art-xl",
96
+ "lora_weight": "pixel-art-xl.safetensors",
97
+ "prompt": "pixel art style, crisp blocky pixels, limited color palette, 16-bit aesthetic, retro game vibes, dithering effects, sprite art",
98
+ "negative_prompt": "smooth, blurry, anti-aliased, soft gradient, painterly",
99
+ "lora_scale": 0.9,
100
+ "recommended_strength": 0.60,
101
+ },
102
+ }
103
+
104
# Style Blend Presets - combining multiple styles (prompts kept short for CLIP 77 token limit)
# Each preset names a primary/secondary STYLE_CONFIGS key plus a pre-mixed prompt.
STYLE_BLENDS = {
    "cartoon_anime": {
        "name": "3D Anime Fusion",
        # Fix: original emoji was two U+FFFD replacement characters (lost in
        # a previous encoding round-trip); replaced with a valid emoji.
        "emoji": "🎭",
        "description": "70% 3D Cartoon + 30% Anime linework",
        "primary_style": "3d_cartoon",
        "secondary_style": "anime",
        "primary_weight": 0.7,
        "secondary_weight": 0.3,
        "prompt": "3D cartoon with anime linework, smooth features, soft lighting, CGI quality, vibrant colors, cel-shaded",
        "negative_prompt": "ugly, deformed, noisy, blurry, low quality",
        "strength": 0.52,
    },
    "fantasy_watercolor": {
        "name": "Dreamy Watercolor",
        "emoji": "🌈",
        "description": "60% Illustrated Fantasy + 40% Watercolor",
        "primary_style": "illustrated_fantasy",
        "secondary_style": "watercolor",
        "primary_weight": 0.6,
        "secondary_weight": 0.4,
        "prompt": "Ghibli style with watercolor washes, soft color bleeds, storybook atmosphere, paper texture, warm golden lighting",
        "negative_prompt": "dark, horror, harsh lines, solid colors",
        "strength": 0.50,
    },
    "anime_fantasy": {
        "name": "Anime Storybook",
        "emoji": "📖",
        "description": "50% Anime + 50% Illustrated Fantasy",
        "primary_style": "anime",
        "secondary_style": "illustrated_fantasy",
        "primary_weight": 0.5,
        "secondary_weight": 0.5,
        "prompt": "Ghibli anime illustration, hand-painted storybook, soft lighting, pastel colors, expressive eyes, warm glow",
        "negative_prompt": "ugly, deformed, bad anatomy, dark, horror, blurry",
        "strength": 0.48,
    },
    "oil_classical": {
        "name": "Renaissance Portrait",
        "emoji": "👑",
        "description": "Classical oil painting style",
        "primary_style": "oil_painting",
        "secondary_style": "oil_painting",
        "primary_weight": 1.0,
        "secondary_weight": 0.0,
        "prompt": "classical oil portrait, impasto technique, palette knife strokes, chiaroscuro lighting, canvas texture, museum quality",
        "negative_prompt": "flat, cartoon, anime, modern, minimalist, overexposed",
        "strength": 0.50,
    },
    "pixel_retro": {
        "name": "Retro Game Art",
        "emoji": "🕹️",
        "description": "Pixel art with enhanced retro feel",
        "primary_style": "pixel_art",
        "secondary_style": "pixel_art",
        "primary_weight": 1.0,
        "secondary_weight": 0.0,
        "prompt": "retro pixel art, crisp blocky pixels, limited palette, arcade aesthetic, dithering, 16-bit charm, sprite art",
        "negative_prompt": "smooth, blurry, anti-aliased, modern, gradient",
        "strength": 0.58,
    },
}
167
+
168
+
169
class StyleTransferEngine:
    """
    Multi-style image transformation engine using SDXL + LoRAs.

    Converts an input photo into artistic styles (3D cartoon, anime,
    watercolor, oil painting, pixel art, ...) declared in the module-level
    STYLE_CONFIGS table, or into combined looks declared in STYLE_BLENDS.
    Optionally attaches an IP-Adapter for stronger identity (face)
    preservation.

    The base pipeline is loaded lazily on first use; at most one LoRA is
    attached at a time and is swapped when the requested style changes.
    """

    # SDXL img2img base checkpoint (Hugging Face Hub repo id).
    BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"

    def __init__(self):
        # Preferred device; re-evaluated in load_model() because CUDA
        # availability can differ at call time (e.g. ZeroGPU spaces attach
        # the GPU only inside the decorated handler).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipe = None              # AutoPipelineForImage2Image once loaded
        self.current_lora = None      # repo id of the LoRA currently attached
        self.is_loaded = False        # True after load_model() succeeds
        self.ip_adapter_loaded = False  # True while IP-Adapter weights are attached

    def load_model(self) -> None:
        """Load the SDXL base img2img pipeline (idempotent)."""
        if self.is_loaded:
            return

        print("β†’ Loading SDXL base model...")

        # Imported lazily so module import stays cheap when the engine is unused.
        from diffusers import AutoPipelineForImage2Image

        # Re-check CUDA here: availability may have changed since __init__.
        actual_device = "cuda" if torch.cuda.is_available() else self.device

        self.pipe = AutoPipelineForImage2Image.from_pretrained(
            self.BASE_MODEL,
            # fp16 dtype/variant only when running on GPU.
            torch_dtype=torch.float16 if actual_device == "cuda" else torch.float32,
            variant="fp16" if actual_device == "cuda" else None,
            use_safetensors=True,
        )

        self.pipe.to(actual_device)

        # Enable memory optimizations (best effort: xformers may be absent).
        if actual_device == "cuda":
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
            except Exception:
                pass

        self.is_loaded = True
        self.device = actual_device
        print(f"βœ“ SDXL base loaded ({actual_device})")

    def _load_lora(self, style_key: str) -> None:
        """Attach the LoRA configured for *style_key*, swapping out any
        previously attached LoRA. A style with ``lora_repo`` of None runs on
        the bare base model. Failures are logged and treated as "no LoRA"."""
        config = STYLE_CONFIGS.get(style_key)
        if not config:
            return

        lora_repo = config.get("lora_repo")

        # Style needs no LoRA: make sure nothing stale stays attached.
        if lora_repo is None:
            if self.current_lora is not None:
                print("β†’ Unloading previous LoRA...")
                self.pipe.unload_lora_weights()
                self.current_lora = None
            return

        # Requested LoRA is already the active one.
        if self.current_lora == lora_repo:
            return

        # Unload previous LoRA if different.
        if self.current_lora is not None:
            print(f"β†’ Unloading previous LoRA: {self.current_lora}")
            self.pipe.unload_lora_weights()

        # Load new LoRA; optional "lora_weight" selects a file within the repo.
        print(f"β†’ Loading LoRA: {config['name']}...")
        try:
            lora_weight = config.get("lora_weight")
            if lora_weight:
                self.pipe.load_lora_weights(lora_repo, weight_name=lora_weight)
            else:
                self.pipe.load_lora_weights(lora_repo)

            self.current_lora = lora_repo
            print(f"βœ“ LoRA loaded: {config['name']}")
        except Exception as e:
            # Degrade gracefully: generation still works with the base model.
            print(f"⚠ LoRA loading failed: {e}, continuing without LoRA")
            self.current_lora = None

    def _load_ip_adapter(self) -> bool:
        """Attach IP-Adapter weights for identity preservation.

        Returns True when the adapter is (already) loaded, False on any
        failure so callers can fall back to prompt-only preservation.
        IP_ADAPTER_REPO / _SUBFOLDER / _WEIGHT are module-level constants
        defined above this chunk.
        """
        if self.ip_adapter_loaded:
            return True

        if self.pipe is None:
            return False

        print("β†’ Loading IP-Adapter for face preservation...")
        try:
            self.pipe.load_ip_adapter(
                IP_ADAPTER_REPO,
                subfolder=IP_ADAPTER_SUBFOLDER,
                weight_name=IP_ADAPTER_WEIGHT
            )
            self.ip_adapter_loaded = True
            print("βœ“ IP-Adapter loaded")
            return True
        except Exception as e:
            print(f"⚠ IP-Adapter loading failed: {e}")
            self.ip_adapter_loaded = False
            return False

    def _unload_ip_adapter(self) -> None:
        """Detach IP-Adapter weights to free memory (no-op if not loaded)."""
        if not self.ip_adapter_loaded or self.pipe is None:
            return

        try:
            self.pipe.unload_ip_adapter()
            self.ip_adapter_loaded = False
            print("βœ“ IP-Adapter unloaded")
        except Exception as e:
            # NOTE(review): on failure ip_adapter_loaded stays True, so a later
            # unload attempt will retry.
            print(f"⚠ IP-Adapter unload failed: {e}")

    def unload_model(self) -> None:
        """Release the pipeline, LoRA state, and IP-Adapter; free GPU memory."""
        if not self.is_loaded:
            return

        # Unload IP-Adapter first if loaded.
        if self.ip_adapter_loaded:
            self._unload_ip_adapter()

        if self.pipe is not None:
            del self.pipe
            self.pipe = None

        self.current_lora = None
        self.ip_adapter_loaded = False

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.is_loaded = False
        print("βœ“ Model unloaded")

    def _preprocess_image(self, image: Image.Image) -> Image.Image:
        """Resize the input to SDXL-friendly dimensions.

        Keeps aspect ratio, caps the longer side at 1024 px, rounds both
        sides down to multiples of 8 (SDXL requirement), and enforces a
        512 px minimum per side (512 is itself a multiple of 8, so the
        constraint is preserved; very wide/tall inputs may be distorted).
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # SDXL works best with 1024x1024; maintain aspect ratio.
        max_size = 1024
        width, height = image.size

        if width > height:
            new_width = max_size
            new_height = int(height * (max_size / width))
        else:
            new_height = max_size
            new_width = int(width * (max_size / height))

        # Round to nearest 8 (SDXL requirement).
        new_width = (new_width // 8) * 8
        new_height = (new_height // 8) * 8

        # Ensure minimum size.
        new_width = max(new_width, 512)
        new_height = max(new_height, 512)

        image = image.resize((new_width, new_height), Image.LANCZOS)
        return image

    def generate_styled_image(
        self,
        image: Image.Image,
        style_key: str = "3d_cartoon",
        strength: float = 0.65,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        custom_prompt: str = "",
        seed: int = -1,
        face_restore: bool = False
    ) -> Tuple[Image.Image, int]:
        """
        Convert image to the specified single style.

        Args:
            image: Input PIL Image.
            style_key: One of: 3d_cartoon, anime, illustrated_fantasy,
                watercolor, oil_painting, pixel_art (unknown keys fall back
                to 3d_cartoon).
            strength: How much to transform (0.0-1.0); capped per-style when
                face_restore is on.
            guidance_scale: How closely to follow the prompt.
            num_inference_steps: Number of denoising steps.
            custom_prompt: Additional prompt text appended to the style prompt.
            seed: Random seed (-1 draws a fresh random seed).
            face_restore: Enable enhanced face preservation (IP-Adapter +
                reduced strength/LoRA scale).

        Returns:
            Tuple of (stylized PIL Image, seed actually used).
        """
        if not self.is_loaded:
            self.load_model()

        # Get style config (defaults to 3d_cartoon for unknown keys).
        config = STYLE_CONFIGS.get(style_key, STYLE_CONFIGS["3d_cartoon"])

        # Load appropriate LoRA.
        self._load_lora(style_key)

        # Preprocess.
        print("β†’ Preprocessing image...")
        processed_image = self._preprocess_image(image)

        # Style-specific face_restore settings (module-level table; the
        # literal below is the fallback for styles without an entry).
        face_settings = FACE_RESTORE_STYLE_SETTINGS.get(style_key, {
            "max_strength": 0.45, "lora_scale_mult": 0.7, "ip_scale": 0.5
        })

        # Build prompt based on face_restore mode.
        base_prompt = config["prompt"]
        ip_adapter_image = None
        ip_scale = 0.0

        if face_restore:
            # Enhanced face preservation mode with style-specific settings.
            preserve_prompt = FACE_RESTORE_PRESERVE
            negative_base = FACE_RESTORE_NEGATIVE

            # Apply style-specific strength cap.
            max_str = face_settings["max_strength"]
            strength = min(strength, max_str)
            print(f"β†’ Face Restore enabled: strength capped at {strength} (style: {style_key})")

            # Load IP-Adapter for stronger identity preservation; on failure
            # we silently continue with prompt-only preservation (ip_scale 0).
            if self._load_ip_adapter():
                ip_adapter_image = processed_image
                ip_scale = face_settings["ip_scale"]
                print(f"β†’ IP-Adapter scale: {ip_scale}")
        else:
            preserve_prompt = IDENTITY_PRESERVE
            negative_base = IDENTITY_NEGATIVE
            # Unload IP-Adapter if not using face_restore (save memory).
            if self.ip_adapter_loaded:
                self._unload_ip_adapter()

        if custom_prompt:
            prompt = f"{preserve_prompt}, {base_prompt}, {custom_prompt}"
        else:
            prompt = f"{preserve_prompt}, {base_prompt}"

        # Build negative prompt.
        negative_prompt = f"{negative_base}, {config['negative_prompt']}"

        # Set LoRA scale (reduced in face-restore mode by the style multiplier).
        lora_scale = config.get("lora_scale", 1.0)
        if face_restore:
            lora_scale = lora_scale * face_settings["lora_scale_mult"]

        # Handle seed: -1 means pick a random 31-bit seed so the caller can
        # still reproduce the result from the returned value.
        if seed == -1:
            seed = torch.randint(0, 2147483647, (1,)).item()
        generator = torch.Generator(device=self.device).manual_seed(seed)

        # Generate.
        print(f"β†’ Generating {config['name']} style (strength: {strength}, steps: {num_inference_steps}, seed: {seed})...")

        # Build generation kwargs.
        gen_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "image": processed_image,
            "strength": strength,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
        }

        # Add cross_attention_kwargs only if a LoRA is actually attached.
        if self.current_lora is not None:
            gen_kwargs["cross_attention_kwargs"] = {"scale": lora_scale}

        # Add IP-Adapter settings for face restoration.
        if ip_adapter_image is not None and self.ip_adapter_loaded:
            self.pipe.set_ip_adapter_scale(ip_scale)
            gen_kwargs["ip_adapter_image"] = ip_adapter_image

        result = self.pipe(**gen_kwargs).images[0]

        print(f"βœ“ {config['name']} style generated (seed: {seed})")

        # Cleanup transient allocations.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return result, seed

    def generate_blended_style(
        self,
        image: Image.Image,
        blend_key: str,
        custom_prompt: str = "",
        seed: int = -1,
        face_restore: bool = False
    ) -> Tuple[Image.Image, int]:
        """
        Generate an image using a style-blend preset.

        Only the blend's *primary* style contributes a LoRA; the secondary
        style influences the result through the blend's prompt text.

        Args:
            image: Input PIL Image.
            blend_key: Key from STYLE_BLENDS; unknown keys fall back to the
                plain 3d_cartoon single style.
            custom_prompt: Additional prompt text.
            seed: Random seed (-1 for random).
            face_restore: Enable enhanced face preservation mode.

        Returns:
            Tuple of (stylized PIL Image, seed used).
        """
        if not self.is_loaded:
            self.load_model()

        blend_config = STYLE_BLENDS.get(blend_key)
        if not blend_config:
            return self.generate_styled_image(image, "3d_cartoon", seed=seed, face_restore=face_restore)

        # Use the primary style's LoRA for the blend.
        primary_style = blend_config["primary_style"]
        self._load_lora(primary_style)

        # Preprocess.
        print("β†’ Preprocessing image...")
        processed_image = self._preprocess_image(image)

        # Face-restore settings keyed on the primary style.
        face_settings = FACE_RESTORE_STYLE_SETTINGS.get(primary_style, {
            "max_strength": 0.45, "lora_scale_mult": 0.7, "ip_scale": 0.5
        })

        # Build prompt based on face_restore mode.
        base_prompt = blend_config["prompt"]
        ip_adapter_image = None
        ip_scale = 0.0

        if face_restore:
            preserve_prompt = FACE_RESTORE_PRESERVE
            negative_base = FACE_RESTORE_NEGATIVE

            # Cap the blend's configured strength per the primary style.
            max_str = face_settings["max_strength"]
            strength = min(blend_config["strength"], max_str)
            print(f"β†’ Face Restore enabled: strength capped at {strength} (blend: {blend_key})")

            # Load IP-Adapter for stronger identity preservation.
            if self._load_ip_adapter():
                ip_adapter_image = processed_image
                ip_scale = face_settings["ip_scale"]
                print(f"β†’ IP-Adapter scale: {ip_scale}")
        else:
            preserve_prompt = IDENTITY_PRESERVE
            negative_base = IDENTITY_NEGATIVE
            strength = blend_config["strength"]
            # Unload IP-Adapter if not using face_restore.
            if self.ip_adapter_loaded:
                self._unload_ip_adapter()

        if custom_prompt:
            prompt = f"{preserve_prompt}, {base_prompt}, {custom_prompt}"
        else:
            prompt = f"{preserve_prompt}, {base_prompt}"

        # Build negative prompt.
        negative_prompt = f"{negative_base}, {blend_config['negative_prompt']}"

        # LoRA scale: primary style's scale weighted by the blend's
        # primary_weight, further reduced in face-restore mode.
        primary_config = STYLE_CONFIGS.get(primary_style, {})
        lora_scale = primary_config.get("lora_scale", 1.0) * blend_config["primary_weight"]
        if face_restore:
            lora_scale = lora_scale * face_settings["lora_scale_mult"]

        # Handle seed.
        if seed == -1:
            seed = torch.randint(0, 2147483647, (1,)).item()
        generator = torch.Generator(device=self.device).manual_seed(seed)

        # Generate (blends use fixed guidance/steps, unlike single styles).
        print(f"β†’ Generating {blend_config['name']} blend (seed: {seed})...")

        gen_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "image": processed_image,
            "strength": strength,
            "guidance_scale": 7.5,
            "num_inference_steps": 30,
            "generator": generator,
        }

        if self.current_lora is not None:
            gen_kwargs["cross_attention_kwargs"] = {"scale": lora_scale}

        # Add IP-Adapter settings for face restoration.
        if ip_adapter_image is not None and self.ip_adapter_loaded:
            self.pipe.set_ip_adapter_scale(ip_scale)
            gen_kwargs["ip_adapter_image"] = ip_adapter_image

        result = self.pipe(**gen_kwargs).images[0]

        print(f"βœ“ {blend_config['name']} blend generated (seed: {seed})")

        # Cleanup transient allocations.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return result, seed

    def generate_all_outputs(
        self,
        image: Image.Image,
        style_key: str = "3d_cartoon",
        strength: float = 0.65,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        custom_prompt: str = "",
        seed: int = -1,
        is_blend: bool = False,
        face_restore: bool = False
    ) -> dict:
        """
        Generate styled image output, wrapping errors into a result dict.

        When ``is_blend`` is True, ``style_key`` is interpreted as a
        STYLE_BLENDS key and the per-call strength/guidance/steps arguments
        are ignored (the blend preset supplies them).

        Returns:
            Dict with keys: success, stylized_image, preview_image,
            style_name, seed_used, error.
        """
        result = {
            "success": False,
            "stylized_image": None,
            "preview_image": None,
            "style_name": "",
            "seed_used": 0,
            "error": None
        }

        try:
            if is_blend:
                # Use blend preset.
                blend_config = STYLE_BLENDS.get(style_key, {})
                result["style_name"] = blend_config.get("name", "Unknown Blend")

                stylized, seed_used = self.generate_blended_style(
                    image=image,
                    blend_key=style_key,
                    custom_prompt=custom_prompt,
                    seed=seed,
                    face_restore=face_restore
                )
            else:
                # Use single style.
                config = STYLE_CONFIGS.get(style_key, STYLE_CONFIGS["3d_cartoon"])
                result["style_name"] = config["name"]

                stylized, seed_used = self.generate_styled_image(
                    image=image,
                    style_key=style_key,
                    strength=strength,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    custom_prompt=custom_prompt,
                    seed=seed,
                    face_restore=face_restore
                )

            result["stylized_image"] = stylized
            # preview_image is currently the same object as stylized_image.
            result["preview_image"] = stylized
            result["seed_used"] = seed_used
            result["success"] = True
            print(f"βœ“ {result['style_name']} conversion completed (seed: {seed_used})")

        except Exception as e:
            # Broad catch is deliberate: UI handlers consume this dict and
            # must never see an exception escape.
            result["error"] = str(e)
            print(f"βœ— Style conversion failed: {e}")

        return result

    @staticmethod
    def get_available_styles() -> Dict[str, Dict[str, Any]]:
        """Return {style_key: {name, emoji}} for all configured styles."""
        return {
            key: {
                "name": config["name"],
                "emoji": config["emoji"],
            }
            for key, config in STYLE_CONFIGS.items()
        }

    @staticmethod
    def get_style_choices() -> list:
        """Return "emoji name" labels for the UI style dropdown."""
        return [
            f"{config['emoji']} {config['name']}"
            for config in STYLE_CONFIGS.values()
        ]

    @staticmethod
    def get_style_key_from_choice(choice: str) -> str:
        """Map a UI label back to its style key (defaults to 3d_cartoon).

        NOTE(review): substring matching returns the first style whose name
        occurs in the label; verify no style name is a substring of another.
        """
        for key, config in STYLE_CONFIGS.items():
            if config["name"] in choice:
                return key
        return "3d_cartoon"

    @staticmethod
    def get_blend_choices() -> list:
        """Return "emoji name - description" labels for the blend dropdown."""
        return [
            f"{config['emoji']} {config['name']} - {config['description']}"
            for config in STYLE_BLENDS.values()
        ]

    @staticmethod
    def get_blend_key_from_choice(choice: str) -> str:
        """Map a UI blend label back to its key (defaults to cartoon_anime)."""
        for key, config in STYLE_BLENDS.items():
            if config["name"] in choice:
                return key
        return "cartoon_anime"

    @staticmethod
    def get_all_choices() -> dict:
        """Return style and blend label lists, plus a combined list with a
        visual separator, for UI dropdowns."""
        styles = [
            f"{config['emoji']} {config['name']}"
            for config in STYLE_CONFIGS.values()
        ]
        blends = [
            f"{config['emoji']} {config['name']}"
            for config in STYLE_BLENDS.values()
        ]
        return {
            "styles": styles,
            "blends": blends,
            "all": styles + ["─── Style Blends ───"] + blends
        }
ui_manager.py CHANGED
@@ -6,6 +6,7 @@ import logging
6
 
7
  from FlowFacade import FlowFacade
8
  from BackgroundEngine import BackgroundEngine
 
9
  from scene_templates import SceneTemplateManager
10
  from css_style import DELTAFLOW_CSS
11
  from prompt_examples import PROMPT_EXAMPLES
@@ -20,9 +21,10 @@ logger = logging.getLogger(__name__)
20
 
21
 
22
  class UIManager:
23
- def __init__(self, facade: FlowFacade, background_engine: BackgroundEngine):
24
  self.facade = facade
25
  self.background_engine = background_engine
 
26
  self.template_manager = SceneTemplateManager()
27
 
28
  def create_interface(self) -> gr.Blocks:
@@ -45,15 +47,19 @@ class UIManager:
45
 
46
  # Main Tabs
47
  with gr.Tabs() as main_tabs:
48
-
49
- # Tab 1: Image to Video (Original Functionality)
50
  with gr.Tab("🎬 Image to Video"):
51
  self._create_i2v_tab()
52
-
53
- # Tab 2: Background Generation (New Feature)
54
  with gr.Tab("🎨 Background Generation"):
55
  self._create_background_tab()
56
 
 
 
 
 
57
  # Footer
58
  gr.HTML("""
59
  <div class="footer">
@@ -341,8 +347,21 @@ class UIManager:
341
  gr.HTML("""
342
  <div style="padding: 8px; background: #f0f4ff; border-radius: 6px; margin-bottom: 12px; font-size: 13px;">
343
  <strong>πŸ’‘ When to Adjust:</strong><br>
 
344
  β€’ <strong>Feather Radius:</strong> Use 5-10 for complex scenes with fine details (hair, fur, foliage). 0 = sharp edges for clean portraits.<br>
345
- β€’ <strong>Mask Preview:</strong> Check the "Mask Preview" tab after generation. White = kept, Black = replaced. Helps diagnose edge issues.
 
 
 
 
 
 
 
 
 
 
 
 
346
  </div>
347
  """)
348
 
@@ -393,7 +412,7 @@ class UIManager:
393
 
394
  gr.HTML("""
395
  <div class="patience-banner">
396
- <strong>⏱️ First-time users:</strong> Initial model loading takes 1-2 minutes.
397
  Subsequent generations are much faster (~30s).
398
  </div>
399
  """)
@@ -443,6 +462,77 @@ class UIManager:
443
  elem_classes=["secondary-button"]
444
  )
445
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
446
  # Event handlers for Background Generation tab
447
  def apply_template(display_name: str, current_negative: str) -> Tuple[str, str, float]:
448
  if not display_name:
@@ -474,7 +564,7 @@ class UIManager:
474
  inputs=[
475
  bg_image_input, bg_prompt_input, combination_mode,
476
  focus_mode, bg_negative_prompt, bg_steps_slider, bg_guidance_slider,
477
- feather_radius_slider
478
  ],
479
  outputs=[
480
  bg_combined_output, bg_generated_output,
@@ -495,6 +585,132 @@ class UIManager:
495
  outputs=[bg_status_output]
496
  )
497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
  def _generate_background_handler(
499
  self,
500
  image: Image.Image,
@@ -504,7 +720,8 @@ class UIManager:
504
  negative_prompt: str,
505
  steps: int,
506
  guidance: float,
507
- feather_radius: int
 
508
  ) -> Tuple[Optional[Image.Image], Optional[Image.Image], Optional[Image.Image], Optional[Image.Image], str]:
509
  """Handler for background generation"""
510
  if image is None:
@@ -522,7 +739,7 @@ class UIManager:
522
 
523
  result = generate_fn(
524
  image, prompt, combination_mode, focus_mode,
525
- negative_prompt, steps, guidance, feather_radius
526
  )
527
 
528
  if result["success"]:
@@ -550,7 +767,8 @@ class UIManager:
550
  negative_prompt: str,
551
  steps: int,
552
  guidance: float,
553
- feather_radius: int
 
554
  ) -> Dict[str, Any]:
555
  """Core background generation with models"""
556
  if not self.background_engine.is_initialized:
@@ -566,7 +784,333 @@ class UIManager:
566
  num_inference_steps=int(steps),
567
  guidance_scale=float(guidance),
568
  enable_prompt_enhancement=True,
569
- feather_radius=int(feather_radius)
 
570
  )
571
 
572
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  from FlowFacade import FlowFacade
8
  from BackgroundEngine import BackgroundEngine
9
+ from style_transfer import StyleTransferEngine
10
  from scene_templates import SceneTemplateManager
11
  from css_style import DELTAFLOW_CSS
12
  from prompt_examples import PROMPT_EXAMPLES
 
21
 
22
 
23
  class UIManager:
24
    def __init__(self, facade: FlowFacade, background_engine: BackgroundEngine, style_engine: StyleTransferEngine):
        """Wire the UI to its three backend engines.

        Args:
            facade: Image-to-video pipeline facade.
            background_engine: SDXL background generation/inpainting engine.
            style_engine: SDXL+LoRA artistic style transfer engine.
        """
        self.facade = facade
        self.background_engine = background_engine
        self.style_engine = style_engine
        # Scene template presets for the background-generation tab.
        self.template_manager = SceneTemplateManager()
29
 
30
  def create_interface(self) -> gr.Blocks:
 
47
 
48
  # Main Tabs
49
  with gr.Tabs() as main_tabs:
50
+
51
+ # Tab 1: Image to Video
52
  with gr.Tab("🎬 Image to Video"):
53
  self._create_i2v_tab()
54
+
55
+ # Tab 2: Background Generation
56
  with gr.Tab("🎨 Background Generation"):
57
  self._create_background_tab()
58
 
59
+ # Tab 3: AI Style Transfer
60
+ with gr.Tab("✨ Style Transfer"):
61
+ self._create_3d_tab()
62
+
63
  # Footer
64
  gr.HTML("""
65
  <div class="footer">
 
347
  gr.HTML("""
348
  <div style="padding: 8px; background: #f0f4ff; border-radius: 6px; margin-bottom: 12px; font-size: 13px;">
349
  <strong>πŸ’‘ When to Adjust:</strong><br>
350
+ β€’ <strong>Enhance Dark Edges:</strong> Enable for images with dark/black backgrounds where foreground parts get lost.<br>
351
  β€’ <strong>Feather Radius:</strong> Use 5-10 for complex scenes with fine details (hair, fur, foliage). 0 = sharp edges for clean portraits.<br>
352
+ β€’ <strong>Mask Preview:</strong> Check the "Mask Preview" tab after generation. White = kept, Black = replaced.
353
+ </div>
354
+ """)
355
+
356
+ enhance_dark_edges = gr.Checkbox(
357
+ label="πŸŒ™ Enhance Dark Edges",
358
+ value=False,
359
+ info="Enable if dark foreground parts blend into dark backgrounds"
360
+ )
361
+ gr.HTML("""
362
+ <div style="padding: 6px 8px; background: #fff3cd; border-radius: 4px; font-size: 11px; margin-bottom: 12px;">
363
+ <strong>When to use:</strong> If mask preview shows gray areas where foreground should be white (e.g., dark hair/clothing on dark background).
364
+ Auto-detection is enabled by default, but this toggle forces stronger enhancement.
365
  </div>
366
  """)
367
 
 
412
 
413
  gr.HTML("""
414
  <div class="patience-banner">
415
+ <strong>⏱️ First-time users:</strong> Initial model loading takes 30-60 seconds.
416
  Subsequent generations are much faster (~30s).
417
  </div>
418
  """)
 
462
  elem_classes=["secondary-button"]
463
  )
464
 
465
+ # Touch Up Section for manual artifact removal
466
+ with gr.Accordion("πŸ–ŒοΈ Touch Up (Remove Artifacts)", open=False) as touchup_accordion:
467
+ gr.HTML("""
468
+ <div style="padding: 10px; background: #e8f4fd; border-radius: 6px; margin-bottom: 12px; font-size: 13px;">
469
+ <strong>✨ How to Use Touch Up:</strong><br>
470
+ 1. After generating, if you see unwanted artifacts (gray edges, leftover objects)<br>
471
+ 2. Click "Load Result for Touch Up" to load the image<br>
472
+ 3. Use the brush to paint over areas you want to remove<br>
473
+ 4. Click "Remove & Fill" to replace painted areas with background
474
+ </div>
475
+ """)
476
+
477
+ # State to store the current result and prompt
478
+ touchup_source_image = gr.State(value=None)
479
+ touchup_background_prompt = gr.State(value="")
480
+
481
+ load_touchup_btn = gr.Button(
482
+ "πŸ“₯ Load Result for Touch Up",
483
+ elem_classes=["secondary-button"]
484
+ )
485
+
486
+ touchup_editor = gr.ImageEditor(
487
+ label="Draw on areas to remove (use brush tool)",
488
+ type="pil",
489
+ height=400,
490
+ brush=gr.Brush(
491
+ colors=["#FF0000"],
492
+ default_color="#FF0000",
493
+ default_size=20
494
+ ),
495
+ layers=False,
496
+ interactive=True,
497
+ visible=True
498
+ )
499
+
500
+ with gr.Row():
501
+ brush_size_slider = gr.Slider(
502
+ label="Brush Size",
503
+ minimum=5,
504
+ maximum=50,
505
+ value=20,
506
+ step=5,
507
+ scale=2
508
+ )
509
+ touchup_strength = gr.Slider(
510
+ label="Fill Strength",
511
+ minimum=0.8,
512
+ maximum=1.0,
513
+ value=0.99,
514
+ step=0.01,
515
+ scale=2,
516
+ info="Higher = more complete replacement"
517
+ )
518
+
519
+ remove_fill_btn = gr.Button(
520
+ "🎨 Remove & Fill",
521
+ variant="primary",
522
+ elem_classes="primary-button"
523
+ )
524
+
525
+ touchup_result = gr.Image(
526
+ label="Touch Up Result",
527
+ elem_classes=["result-gallery"]
528
+ )
529
+
530
+ touchup_status = gr.Textbox(
531
+ label="Touch Up Status",
532
+ value="Load an image to start touch up.",
533
+ interactive=False
534
+ )
535
+
536
  # Event handlers for Background Generation tab
537
  def apply_template(display_name: str, current_negative: str) -> Tuple[str, str, float]:
538
  if not display_name:
 
564
  inputs=[
565
  bg_image_input, bg_prompt_input, combination_mode,
566
  focus_mode, bg_negative_prompt, bg_steps_slider, bg_guidance_slider,
567
+ feather_radius_slider, enhance_dark_edges
568
  ],
569
  outputs=[
570
  bg_combined_output, bg_generated_output,
 
585
  outputs=[bg_status_output]
586
  )
587
 
588
+ # Touch Up event handlers
589
+ def load_for_touchup(combined_image, prompt):
590
+ """Load the generated result into touch up editor"""
591
+ if combined_image is None:
592
+ return None, None, "", "Please generate a background first!"
593
+ return combined_image, combined_image, prompt, "βœ“ Image loaded! Use brush to paint areas to remove."
594
+
595
+ load_touchup_btn.click(
596
+ fn=load_for_touchup,
597
+ inputs=[bg_combined_output, bg_prompt_input],
598
+ outputs=[touchup_editor, touchup_source_image, touchup_background_prompt, touchup_status]
599
+ )
600
+
601
+ remove_fill_btn.click(
602
+ fn=self._touchup_inpaint_handler,
603
+ inputs=[touchup_editor, touchup_background_prompt, touchup_strength],
604
+ outputs=[touchup_result, touchup_status]
605
+ )
606
+
607
    def _touchup_inpaint_handler(
        self,
        editor_data: dict,
        background_prompt: str,
        strength: float
    ) -> Tuple[Optional[Image.Image], str]:
        """Gradio handler: inpaint the areas the user painted in the editor.

        Args:
            editor_data: gr.ImageEditor value — a dict with 'background',
                'layers', and 'composite' entries.
            background_prompt: Prompt of the original generation, reused so
                the fill matches the surrounding background.
            strength: Inpaint fill strength (forwarded to the core).

        Returns:
            (inpainted image or None, status message for the UI).
        """
        if editor_data is None:
            return None, "Please load an image first!"

        try:
            # Extract image and mask from editor.
            # Gradio ImageEditor returns a dict with 'background', 'layers', 'composite'.
            if isinstance(editor_data, dict):
                # Prefer the untouched background so brush strokes don't end
                # up baked into the image being inpainted.
                base_image = editor_data.get("background") or editor_data.get("composite")
                layers = editor_data.get("layers", [])

                if base_image is None:
                    return None, "No image found in editor!"

                # Create mask from drawn layers (red brush strokes).
                mask = self._extract_mask_from_editor(base_image, layers)

                if mask is None or not self._has_painted_area(mask):
                    return None, "Please draw on areas you want to remove!"

            else:
                # Non-dict values are rejected; only the dict editor format is supported.
                return None, "Invalid editor data format!"

            # Apply ZeroGPU decorator if available (SPACES_AVAILABLE/spaces
            # are module-level; decoration happens per call by design so the
            # GPU is only reserved while inpainting runs).
            if SPACES_AVAILABLE:
                inpaint_fn = spaces.GPU(duration=60)(self._touchup_inpaint_core)
            else:
                inpaint_fn = self._touchup_inpaint_core

            result = inpaint_fn(base_image, mask, background_prompt, strength)

            if result["success"]:
                return result["inpainted_image"], "βœ“ Touch up completed!"
            else:
                return None, f"Error: {result.get('error', 'Unknown error')}"

        except Exception as e:
            # Broad catch: Gradio handlers must return a status, not raise.
            logger.error(f"Touch up failed: {e}")
            return None, f"Error: {str(e)}"
653
+
654
    def _extract_mask_from_editor(self, base_image: Image.Image, layers: list) -> Optional[Image.Image]:
        """Build a binary (mode 'L') inpainting mask from editor brush layers.

        A pixel is marked 255 (to be inpainted) when any layer painted it:
        alpha > 50 or red channel > 100 (the brush color is pure red).
        Returns None when there are no layers at all.

        NOTE(review): assumes every layer has the same dimensions as
        base_image — a size mismatch would raise on the boolean-index
        assignment. Confirm gr.ImageEditor guarantees this.
        """
        import numpy as np

        if not layers:
            return None

        # Blank mask sized to the base image (numpy uses (rows, cols)).
        width, height = base_image.size
        mask_array = np.zeros((height, width), dtype=np.uint8)

        for layer in layers:
            if layer is None:
                continue

            # Only PIL layers are supported; anything else is skipped.
            if isinstance(layer, Image.Image):
                layer_array = np.array(layer.convert('RGBA'))
            else:
                continue

            # Find painted pixels. The alpha channel marks where the user
            # drew; the red check adds robustness for the red brush color.
            if layer_array.shape[2] >= 4:
                alpha = layer_array[:, :, 3]
                red = layer_array[:, :, 0]
                painted = (alpha > 50) | (red > 100)
                mask_array[painted] = 255

        return Image.fromarray(mask_array, mode='L')
686
+
687
+ def _has_painted_area(self, mask: Image.Image) -> bool:
688
+ """Check if mask has any painted area"""
689
+ import numpy as np
690
+ mask_array = np.array(mask)
691
+ return np.sum(mask_array > 127) > 100 # At least 100 white pixels
692
+
693
    def _touchup_inpaint_core(
        self,
        image: Image.Image,
        mask: Image.Image,
        prompt: str,
        strength: float
    ) -> dict:
        """Run the actual inpainting through the background engine.

        Args:
            image: Base image to repair.
            mask: 'L'-mode mask; white (255) pixels get replaced.
            prompt: Background prompt from the original generation (may be "").
            strength: Fill strength, coerced to float for the pipeline.

        Returns:
            Result dict from BackgroundEngine.inpaint_region (expected keys:
            success, inpainted_image, error).

        NOTE(review): inpaint_region is not visible in this chunk of
        BackgroundEngine — confirm the method exists and returns this shape.
        """
        # Steer the fill toward the original background so it blends in;
        # fall back to a generic background prompt when none was given.
        inpaint_prompt = f"{prompt}, seamless, natural continuation, no artifacts" if prompt else "natural background, seamless continuation"

        return self.background_engine.inpaint_region(
            image=image,
            mask=mask,
            prompt=inpaint_prompt,
            negative_prompt="blurry, artifacts, seams, inconsistent, unnatural",
            num_inference_steps=20,
            guidance_scale=7.5,
            strength=float(strength)
        )
713
+
714
  def _generate_background_handler(
715
  self,
716
  image: Image.Image,
 
720
  negative_prompt: str,
721
  steps: int,
722
  guidance: float,
723
+ feather_radius: int,
724
+ enhance_dark_edges: bool = False
725
  ) -> Tuple[Optional[Image.Image], Optional[Image.Image], Optional[Image.Image], Optional[Image.Image], str]:
726
  """Handler for background generation"""
727
  if image is None:
 
739
 
740
  result = generate_fn(
741
  image, prompt, combination_mode, focus_mode,
742
+ negative_prompt, steps, guidance, feather_radius, enhance_dark_edges
743
  )
744
 
745
  if result["success"]:
 
767
  negative_prompt: str,
768
  steps: int,
769
  guidance: float,
770
+ feather_radius: int,
771
+ enhance_dark_edges: bool = False
772
  ) -> Dict[str, Any]:
773
  """Core background generation with models"""
774
  if not self.background_engine.is_initialized:
 
784
  num_inference_steps=int(steps),
785
  guidance_scale=float(guidance),
786
  enable_prompt_enhancement=True,
787
+ feather_radius=int(feather_radius),
788
+ enhance_dark_edges=enhance_dark_edges
789
  )
790
 
791
+ return result
792
+
793
    def _create_3d_tab(self):
        """Build the Style Transfer tab UI and wire its event handlers.

        Creates the input panel (image upload, single-style / blend selection,
        advanced settings, seed control) and the output panel (stylized result,
        original, side-by-side comparison), then connects tab-selection,
        generate, clear, and memory-cleanup events to the corresponding
        handler methods on this object.
        """
        # Two-column layout: inputs/settings on the left, results on the right.
        with gr.Row():
            # Left Panel: Input & Settings
            with gr.Column(scale=1, elem_classes="feature-card"):
                gr.Markdown("### 🎨 AI Style Transfer")

                # How It Works Guide
                gr.HTML("""
                <div class="quality-banner">
                    <strong>πŸ“– Transform Your Photos</strong><br><br>
                    Convert your images into <strong>stunning artistic styles</strong>!<br><br>
                    <strong>🎨 Single Styles:</strong> Pure artistic transformations<br>
                    <strong>🎭 Style Blends:</strong> Unique combinations for distinctive looks<br><br>
                    <strong>πŸ’‘ Tips:</strong><br>
                    β€’ Use <strong>Seed</strong> to recreate the exact same result<br>
                    β€’ Try different blends for unique artistic effects
                </div>
                """)

                # Step 1: Upload
                gr.Markdown("#### Step 1: Upload Image")
                style3d_image_input = gr.Image(
                    label="Upload Your Image",
                    type="pil",  # handlers receive a PIL Image (see _generate_3d_style_handler)
                    height=280
                )

                # Step 2: Choose Style
                gr.Markdown("#### Step 2: Choose Style")

                # Hidden state to track which mode is active (updated by tab selection).
                # False -> single style, True -> blend preset.
                is_blend_mode = gr.State(value=False)

                with gr.Tabs() as style_tabs:
                    with gr.TabItem("🎨 Single Styles", id="single_tab") as single_tab:
                        style_dropdown = gr.Dropdown(
                            choices=self.style_engine.get_style_choices(),
                            value="🎬 3D Cartoon",
                            label="Art Style",
                            info="Select a single artistic style"
                        )

                        style_strength = gr.Slider(
                            label="Style Strength",
                            minimum=0.3,
                            maximum=0.7,
                            value=0.50,
                            step=0.05,
                            info="Lower = keep more original | Higher = stronger style (0.45-0.55 recommended)"
                        )

                    with gr.TabItem("🎭 Style Blends", id="blend_tab") as blend_tab:
                        # Default to the first preset when any blends exist.
                        blend_dropdown = gr.Dropdown(
                            choices=self.style_engine.get_blend_choices(),
                            value=self.style_engine.get_blend_choices()[0] if self.style_engine.get_blend_choices() else None,
                            label="Blend Preset",
                            info="Pre-configured style combinations"
                        )
                        gr.HTML("""
                        <div style="padding: 8px; background: #f0f4ff; border-radius: 6px; font-size: 12px; margin-top: 8px;">
                            <strong>Available Blends:</strong><br>
                            β€’ 🎭 3D Anime Fusion - 3D + Anime linework<br>
                            β€’ 🌈 Dreamy Watercolor - Fantasy + Watercolor<br>
                            β€’ πŸ“– Anime Storybook - Anime + Fantasy<br>
                            β€’ πŸ‘‘ Renaissance Portrait - Classical oil painting<br>
                            β€’ πŸ•ΉοΈ Retro Game Art - Enhanced pixel art
                        </div>
                        """)

                # Face Restore option for identity preservation (applies to both modes)
                face_restore = gr.Checkbox(
                    label="πŸ›‘οΈ Face Restore (Preserve Identity)",
                    value=False,
                    info="Enable to better preserve facial features and prevent identity changes"
                )
                gr.HTML("""
                <div style="padding: 6px 8px; background: #fff3cd; border-radius: 4px; font-size: 11px; margin-top: 4px;">
                    <strong>πŸ’‘ When to use:</strong> Enable if the style changes the person's face, age, or ethnicity too much.
                    Auto-reduces strength to preserve original features.
                </div>
                """)

                with gr.Accordion("βš™οΈ Advanced Settings", open=False):
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=5.0,
                        maximum=12.0,
                        value=7.5,
                        step=0.5,
                        info="How closely to follow the style"
                    )

                    num_steps = gr.Slider(
                        label="Quality Steps",
                        minimum=20,
                        maximum=50,
                        value=30,
                        step=5,
                        info="More steps = better quality but slower"
                    )

                    custom_prompt = gr.Textbox(
                        label="Additional Description (optional)",
                        placeholder="e.g., smiling, dramatic lighting, vibrant colors...",
                        lines=2
                    )

                    # Seed controls: random by default; manual seed enables
                    # reproducible results (seed is echoed back in seed_output).
                    gr.Markdown("##### 🎲 Seed Control")
                    randomize_seed = gr.Checkbox(
                        label="Randomize Seed",
                        value=True,
                        info="Uncheck to use manual seed for reproducible results"
                    )

                    seed_input = gr.Number(
                        label="Manual Seed",
                        value=42,
                        precision=0,
                        info="Use same seed to reproduce exact results"
                    )

                # Step 3: Generate
                gr.Markdown("#### Step 3: Generate")

                gr.HTML("""
                <div class="patience-banner">
                    <strong>⏱️ Generation Time:</strong> ~20-30 seconds.
                    First-time model loading may take 30-60 seconds.
                </div>
                """)

                generate_style_btn = gr.Button(
                    "🎨 Transform Image",
                    variant="primary",
                    elem_classes="primary-button",
                    size="lg"
                )

            # Right Panel: Output
            with gr.Column(scale=1, elem_classes="feature-card"):
                gr.Markdown("### πŸ“€ Results")

                with gr.Tabs():
                    with gr.TabItem("Stylized Result"):
                        style3d_output = gr.Image(
                            label="Stylized Result",
                            elem_classes=["result-gallery"]
                        )

                    with gr.TabItem("Original"):
                        style3d_original = gr.Image(
                            label="Original Image",
                            elem_classes=["result-gallery"]
                        )

                    with gr.TabItem("Comparison"):
                        with gr.Row():
                            style3d_compare_original = gr.Image(
                                label="Before",
                                elem_classes=["result-gallery"]
                            )
                            style3d_compare_result = gr.Image(
                                label="After",
                                elem_classes=["result-gallery"]
                            )

                with gr.Row():
                    style3d_status_output = gr.Textbox(
                        label="Status",
                        value="Ready! Upload an image and select a style to transform.",
                        interactive=False,
                        elem_classes=["status-panel"],
                        scale=3
                    )
                    seed_output = gr.Number(
                        label="Seed Used",
                        value=0,
                        interactive=False,
                        precision=0,
                        scale=1
                    )

                with gr.Row():
                    clear_style_btn = gr.Button(
                        "Clear All",
                        elem_classes=["secondary-button"]
                    )
                    memory_style_btn = gr.Button(
                        "Clean Memory",
                        elem_classes=["secondary-button"]
                    )

        # Event handlers - detect mode from TAB selection (not just dropdown),
        # so the generate handler knows which dropdown's value to use.
        single_tab.select(
            fn=lambda: False,  # Single Styles tab clicked -> is_blend = False
            inputs=[],
            outputs=[is_blend_mode]
        )

        blend_tab.select(
            fn=lambda: True,  # Style Blends tab clicked -> is_blend = True
            inputs=[],
            outputs=[is_blend_mode]
        )

        # Main generation: passes both dropdowns plus the mode flag; the
        # handler picks the relevant one (see _generate_3d_style_handler).
        generate_style_btn.click(
            fn=self._generate_3d_style_handler,
            inputs=[
                style3d_image_input, style_dropdown, blend_dropdown, is_blend_mode,
                style_strength, guidance_scale, num_steps, custom_prompt,
                randomize_seed, seed_input, face_restore
            ],
            outputs=[
                style3d_output, style3d_original,
                style3d_compare_original, style3d_compare_result,
                style3d_status_output, seed_output
            ]
        )

        # Reset all result widgets and the status line.
        clear_style_btn.click(
            fn=lambda: (None, None, None, None, "Ready! Upload an image and select a style to transform.", 0),
            outputs=[
                style3d_output, style3d_original,
                style3d_compare_original, style3d_compare_result,
                style3d_status_output, seed_output
            ]
        )

        memory_style_btn.click(
            fn=self._cleanup_3d_memory,
            outputs=[style3d_status_output]
        )
1027
+ def _generate_3d_style_handler(
1028
+ self,
1029
+ image: Image.Image,
1030
+ style_choice: str,
1031
+ blend_choice: str,
1032
+ is_blend_mode: bool,
1033
+ strength: float,
1034
+ guidance_scale: float,
1035
+ num_steps: int,
1036
+ custom_prompt: str,
1037
+ randomize_seed: bool,
1038
+ manual_seed: int,
1039
+ face_restore: bool = False
1040
+ ) -> Tuple[Optional[Image.Image], Optional[Image.Image], Optional[Image.Image], Optional[Image.Image], str, int]:
1041
+ """Handler for style transfer generation"""
1042
+ if image is None:
1043
+ return None, None, None, None, "Please upload an image first!", 0
1044
+
1045
+ try:
1046
+ # Determine style key based on mode (detected from last dropdown interaction)
1047
+ if is_blend_mode:
1048
+ style_key = self.style_engine.get_blend_key_from_choice(blend_choice)
1049
+ is_blend = True
1050
+ else:
1051
+ style_key = self.style_engine.get_style_key_from_choice(style_choice)
1052
+ is_blend = False
1053
+
1054
+ # Handle seed
1055
+ seed = -1 if randomize_seed else int(manual_seed)
1056
+
1057
+ if SPACES_AVAILABLE:
1058
+ generate_fn = spaces.GPU(duration=120)(self._3d_style_generate_core)
1059
+ else:
1060
+ generate_fn = self._3d_style_generate_core
1061
+
1062
+ result = generate_fn(
1063
+ image, style_key, is_blend, strength,
1064
+ guidance_scale, num_steps, custom_prompt, seed, face_restore
1065
+ )
1066
+
1067
+ if result["success"]:
1068
+ stylized = result["stylized_image"]
1069
+ style_name = result.get("style_name", "Style")
1070
+ seed_used = result.get("seed_used", 0)
1071
+ return (
1072
+ stylized,
1073
+ image,
1074
+ image,
1075
+ stylized,
1076
+ f"βœ“ {style_name} completed! (seed: {seed_used})",
1077
+ seed_used
1078
+ )
1079
+ else:
1080
+ error_msg = result.get("error", "Unknown error")
1081
+ return None, None, None, None, f"Error: {error_msg}", 0
1082
+
1083
+ except Exception as e:
1084
+ logger.error(f"Style generation failed: {e}")
1085
+ return None, None, None, None, f"Error: {str(e)}", 0
1086
+
1087
+ def _3d_style_generate_core(
1088
+ self,
1089
+ image: Image.Image,
1090
+ style_key: str,
1091
+ is_blend: bool,
1092
+ strength: float,
1093
+ guidance_scale: float,
1094
+ num_steps: int,
1095
+ custom_prompt: str,
1096
+ seed: int,
1097
+ face_restore: bool = False
1098
+ ) -> dict:
1099
+ """Core style transfer generation"""
1100
+ return self.style_engine.generate_all_outputs(
1101
+ image=image,
1102
+ style_key=style_key,
1103
+ strength=float(strength),
1104
+ guidance_scale=float(guidance_scale),
1105
+ num_inference_steps=int(num_steps),
1106
+ custom_prompt=custom_prompt if custom_prompt else "",
1107
+ seed=seed,
1108
+ is_blend=is_blend,
1109
+ face_restore=face_restore
1110
+ )
1111
+
1112
+ def _cleanup_3d_memory(self) -> str:
1113
+ """Clean up 3D engine memory"""
1114
+ self.style_engine.unload_model()
1115
+ return "Memory cleaned!"
1116
+