Upload 14 files
Browse files- BackgroundEngine.py +205 -3
- app.py +3 -1
- mask_generator.py +185 -2
- requirements.txt +4 -0
- style_transfer.py +708 -0
- ui_manager.py +557 -13
BackgroundEngine.py
CHANGED
|
@@ -10,7 +10,7 @@ from typing import Optional, Dict, Any, Callable
|
|
| 10 |
import warnings
|
| 11 |
warnings.filterwarnings("ignore")
|
| 12 |
|
| 13 |
-
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
|
| 14 |
import open_clip
|
| 15 |
from mask_generator import MaskGenerator
|
| 16 |
from image_blender import ImageBlender
|
|
@@ -39,10 +39,12 @@ class BackgroundEngine:
|
|
| 39 |
self.clip_pretrained = "openai"
|
| 40 |
|
| 41 |
self.pipeline = None
|
|
|
|
| 42 |
self.clip_model = None
|
| 43 |
self.clip_preprocess = None
|
| 44 |
self.clip_tokenizer = None
|
| 45 |
self.is_initialized = False
|
|
|
|
| 46 |
|
| 47 |
self.max_image_size = 1024
|
| 48 |
self.default_steps = 25
|
|
@@ -336,13 +338,15 @@ class BackgroundEngine:
|
|
| 336 |
guidance_scale: float = 7.5,
|
| 337 |
progress_callback: Optional[Callable] = None,
|
| 338 |
enable_prompt_enhancement: bool = True,
|
| 339 |
-
feather_radius: int = 0
|
|
|
|
| 340 |
) -> Dict[str, Any]:
|
| 341 |
"""
|
| 342 |
Generate background and combine with foreground.
|
| 343 |
|
| 344 |
Args:
|
| 345 |
feather_radius: Gaussian blur radius for mask edge softening (0-20, default 0)
|
|
|
|
| 346 |
|
| 347 |
Returns dict with: combined_image, generated_scene, original_image, mask, success
|
| 348 |
"""
|
|
@@ -391,7 +395,8 @@ class BackgroundEngine:
|
|
| 391 |
combination_mask = self.mask_generator.create_gradient_based_mask(
|
| 392 |
processed_original,
|
| 393 |
combination_mode,
|
| 394 |
-
focus_mode
|
|
|
|
| 395 |
)
|
| 396 |
|
| 397 |
if progress_callback:
|
|
@@ -430,3 +435,200 @@ class BackgroundEngine:
|
|
| 430 |
"success": False,
|
| 431 |
"error": str(e)
|
| 432 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
import warnings
|
| 11 |
warnings.filterwarnings("ignore")
|
| 12 |
|
| 13 |
+
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLInpaintPipeline, DPMSolverMultistepScheduler
|
| 14 |
import open_clip
|
| 15 |
from mask_generator import MaskGenerator
|
| 16 |
from image_blender import ImageBlender
|
|
|
|
| 39 |
self.clip_pretrained = "openai"
|
| 40 |
|
| 41 |
self.pipeline = None
|
| 42 |
+
self.inpaint_pipeline = None
|
| 43 |
self.clip_model = None
|
| 44 |
self.clip_preprocess = None
|
| 45 |
self.clip_tokenizer = None
|
| 46 |
self.is_initialized = False
|
| 47 |
+
self.inpaint_initialized = False
|
| 48 |
|
| 49 |
self.max_image_size = 1024
|
| 50 |
self.default_steps = 25
|
|
|
|
| 338 |
guidance_scale: float = 7.5,
|
| 339 |
progress_callback: Optional[Callable] = None,
|
| 340 |
enable_prompt_enhancement: bool = True,
|
| 341 |
+
feather_radius: int = 0,
|
| 342 |
+
enhance_dark_edges: bool = False
|
| 343 |
) -> Dict[str, Any]:
|
| 344 |
"""
|
| 345 |
Generate background and combine with foreground.
|
| 346 |
|
| 347 |
Args:
|
| 348 |
feather_radius: Gaussian blur radius for mask edge softening (0-20, default 0)
|
| 349 |
+
enhance_dark_edges: Enhance mask edges for dark background images (default False)
|
| 350 |
|
| 351 |
Returns dict with: combined_image, generated_scene, original_image, mask, success
|
| 352 |
"""
|
|
|
|
| 395 |
combination_mask = self.mask_generator.create_gradient_based_mask(
|
| 396 |
processed_original,
|
| 397 |
combination_mode,
|
| 398 |
+
focus_mode,
|
| 399 |
+
enhance_dark_edges=enhance_dark_edges
|
| 400 |
)
|
| 401 |
|
| 402 |
if progress_callback:
|
|
|
|
| 435 |
"success": False,
|
| 436 |
"error": str(e)
|
| 437 |
}
|
| 438 |
+
|
| 439 |
+
def _load_inpaint_pipeline(self) -> bool:
|
| 440 |
+
"""Lazy load SDXL inpainting pipeline"""
|
| 441 |
+
if self.inpaint_initialized:
|
| 442 |
+
return True
|
| 443 |
+
|
| 444 |
+
try:
|
| 445 |
+
logger.info("Loading SDXL inpainting pipeline...")
|
| 446 |
+
actual_device = "cuda" if torch.cuda.is_available() else self.device
|
| 447 |
+
|
| 448 |
+
self.inpaint_pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
|
| 449 |
+
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
| 450 |
+
torch_dtype=torch.float16 if actual_device == "cuda" else torch.float32,
|
| 451 |
+
variant="fp16" if actual_device == "cuda" else None,
|
| 452 |
+
use_safetensors=True
|
| 453 |
+
)
|
| 454 |
+
self.inpaint_pipeline.to(actual_device)
|
| 455 |
+
|
| 456 |
+
# Use fast scheduler
|
| 457 |
+
self.inpaint_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
| 458 |
+
self.inpaint_pipeline.scheduler.config
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
# Memory optimization
|
| 462 |
+
if actual_device == "cuda":
|
| 463 |
+
try:
|
| 464 |
+
self.inpaint_pipeline.enable_xformers_memory_efficient_attention()
|
| 465 |
+
except Exception:
|
| 466 |
+
pass
|
| 467 |
+
|
| 468 |
+
self.inpaint_initialized = True
|
| 469 |
+
logger.info("β SDXL inpainting pipeline loaded")
|
| 470 |
+
return True
|
| 471 |
+
|
| 472 |
+
except Exception as e:
|
| 473 |
+
logger.error(f"Failed to load inpainting pipeline: {e}")
|
| 474 |
+
self.inpaint_initialized = False
|
| 475 |
+
return False
|
| 476 |
+
|
| 477 |
+
def inpaint_region(
|
| 478 |
+
self,
|
| 479 |
+
image: Image.Image,
|
| 480 |
+
mask: Image.Image,
|
| 481 |
+
prompt: str,
|
| 482 |
+
negative_prompt: str = "blurry, low quality, artifacts, seams",
|
| 483 |
+
num_inference_steps: int = 20,
|
| 484 |
+
guidance_scale: float = 7.5,
|
| 485 |
+
strength: float = 0.99
|
| 486 |
+
) -> Dict[str, Any]:
|
| 487 |
+
"""
|
| 488 |
+
Inpaint marked regions with background content.
|
| 489 |
+
|
| 490 |
+
Args:
|
| 491 |
+
image: The combined image with artifacts to fix
|
| 492 |
+
mask: Binary mask where white = areas to inpaint
|
| 493 |
+
prompt: Background description for inpainting
|
| 494 |
+
negative_prompt: What to avoid
|
| 495 |
+
num_inference_steps: Denoising steps (20 is usually enough)
|
| 496 |
+
guidance_scale: How closely to follow prompt
|
| 497 |
+
strength: How much to change masked area (0.99 = almost complete replacement)
|
| 498 |
+
|
| 499 |
+
Returns:
|
| 500 |
+
Dict with inpainted_image, success, error
|
| 501 |
+
"""
|
| 502 |
+
try:
|
| 503 |
+
# Load inpainting pipeline if not already loaded
|
| 504 |
+
if not self._load_inpaint_pipeline():
|
| 505 |
+
# Fallback to OpenCV inpainting
|
| 506 |
+
return self._opencv_inpaint_fallback(image, mask)
|
| 507 |
+
|
| 508 |
+
logger.info("Starting region inpainting...")
|
| 509 |
+
|
| 510 |
+
# Prepare images
|
| 511 |
+
image = self._prepare_image(image)
|
| 512 |
+
mask = mask.resize(image.size, Image.LANCZOS).convert('L')
|
| 513 |
+
|
| 514 |
+
# Ensure mask is properly binarized
|
| 515 |
+
mask_array = np.array(mask)
|
| 516 |
+
mask_array = (mask_array > 127).astype(np.uint8) * 255
|
| 517 |
+
mask = Image.fromarray(mask_array, mode='L')
|
| 518 |
+
|
| 519 |
+
# Dilate mask slightly for better blending
|
| 520 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
|
| 521 |
+
mask_dilated = cv2.dilate(mask_array, kernel, iterations=1)
|
| 522 |
+
mask = Image.fromarray(mask_dilated, mode='L')
|
| 523 |
+
|
| 524 |
+
actual_device = "cuda" if torch.cuda.is_available() else self.device
|
| 525 |
+
|
| 526 |
+
with torch.inference_mode():
|
| 527 |
+
result = self.inpaint_pipeline(
|
| 528 |
+
prompt=prompt,
|
| 529 |
+
negative_prompt=negative_prompt,
|
| 530 |
+
image=image,
|
| 531 |
+
mask_image=mask,
|
| 532 |
+
width=image.size[0],
|
| 533 |
+
height=image.size[1],
|
| 534 |
+
num_inference_steps=num_inference_steps,
|
| 535 |
+
guidance_scale=guidance_scale,
|
| 536 |
+
strength=strength,
|
| 537 |
+
generator=torch.Generator(device=actual_device).manual_seed(42)
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
inpainted = result.images[0]
|
| 541 |
+
|
| 542 |
+
# Blend edges for smoother transition
|
| 543 |
+
inpainted = self._blend_inpaint_edges(image, inpainted, mask)
|
| 544 |
+
|
| 545 |
+
self._memory_cleanup()
|
| 546 |
+
|
| 547 |
+
logger.info("β Region inpainting completed")
|
| 548 |
+
return {
|
| 549 |
+
"inpainted_image": inpainted,
|
| 550 |
+
"success": True
|
| 551 |
+
}
|
| 552 |
+
|
| 553 |
+
except Exception as e:
|
| 554 |
+
logger.error(f"Inpainting failed: {e}")
|
| 555 |
+
self._memory_cleanup()
|
| 556 |
+
return {
|
| 557 |
+
"success": False,
|
| 558 |
+
"error": str(e)
|
| 559 |
+
}
|
| 560 |
+
|
| 561 |
+
def _opencv_inpaint_fallback(
|
| 562 |
+
self,
|
| 563 |
+
image: Image.Image,
|
| 564 |
+
mask: Image.Image
|
| 565 |
+
) -> Dict[str, Any]:
|
| 566 |
+
"""Fallback to OpenCV inpainting for small areas or when SDXL unavailable"""
|
| 567 |
+
try:
|
| 568 |
+
logger.info("Using OpenCV inpainting fallback...")
|
| 569 |
+
|
| 570 |
+
img_array = np.array(image.convert('RGB'))
|
| 571 |
+
mask_array = np.array(mask.convert('L'))
|
| 572 |
+
|
| 573 |
+
# Binarize mask
|
| 574 |
+
mask_binary = (mask_array > 127).astype(np.uint8) * 255
|
| 575 |
+
|
| 576 |
+
# Use Telea algorithm for natural results
|
| 577 |
+
inpainted = cv2.inpaint(
|
| 578 |
+
img_array,
|
| 579 |
+
mask_binary,
|
| 580 |
+
inpaintRadius=5,
|
| 581 |
+
flags=cv2.INPAINT_TELEA
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
result = Image.fromarray(inpainted)
|
| 585 |
+
|
| 586 |
+
logger.info("β OpenCV inpainting completed")
|
| 587 |
+
return {
|
| 588 |
+
"inpainted_image": result,
|
| 589 |
+
"success": True
|
| 590 |
+
}
|
| 591 |
+
|
| 592 |
+
except Exception as e:
|
| 593 |
+
logger.error(f"OpenCV inpainting failed: {e}")
|
| 594 |
+
return {
|
| 595 |
+
"success": False,
|
| 596 |
+
"error": str(e)
|
| 597 |
+
}
|
| 598 |
+
|
| 599 |
+
def _blend_inpaint_edges(
|
| 600 |
+
self,
|
| 601 |
+
original: Image.Image,
|
| 602 |
+
inpainted: Image.Image,
|
| 603 |
+
mask: Image.Image,
|
| 604 |
+
feather_pixels: int = 8
|
| 605 |
+
) -> Image.Image:
|
| 606 |
+
"""Blend inpainted region edges for seamless transition"""
|
| 607 |
+
try:
|
| 608 |
+
orig_array = np.array(original).astype(np.float32)
|
| 609 |
+
inpaint_array = np.array(inpainted).astype(np.float32)
|
| 610 |
+
mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0
|
| 611 |
+
|
| 612 |
+
# Create feathered mask for smooth blending
|
| 613 |
+
if feather_pixels > 0:
|
| 614 |
+
kernel_size = feather_pixels * 2 + 1
|
| 615 |
+
mask_feathered = cv2.GaussianBlur(
|
| 616 |
+
mask_array,
|
| 617 |
+
(kernel_size, kernel_size),
|
| 618 |
+
feather_pixels / 2
|
| 619 |
+
)
|
| 620 |
+
else:
|
| 621 |
+
mask_feathered = mask_array
|
| 622 |
+
|
| 623 |
+
# Expand mask to 3 channels
|
| 624 |
+
mask_3d = mask_feathered[:, :, np.newaxis]
|
| 625 |
+
|
| 626 |
+
# Blend: inpainted in masked area, original elsewhere
|
| 627 |
+
blended = inpaint_array * mask_3d + orig_array * (1 - mask_3d)
|
| 628 |
+
blended = np.clip(blended, 0, 255).astype(np.uint8)
|
| 629 |
+
|
| 630 |
+
return Image.fromarray(blended)
|
| 631 |
+
|
| 632 |
+
except Exception as e:
|
| 633 |
+
logger.warning(f"Edge blending failed: {e}, returning inpainted directly")
|
| 634 |
+
return inpainted
|
app.py
CHANGED
|
@@ -16,6 +16,7 @@ import sentencepiece
|
|
| 16 |
|
| 17 |
from FlowFacade import FlowFacade
|
| 18 |
from BackgroundEngine import BackgroundEngine
|
|
|
|
| 19 |
from ui_manager import UIManager
|
| 20 |
|
| 21 |
|
|
@@ -126,7 +127,8 @@ def main():
|
|
| 126 |
try:
|
| 127 |
facade = FlowFacade()
|
| 128 |
background_engine = BackgroundEngine()
|
| 129 |
-
|
|
|
|
| 130 |
interface = ui_manager.create_interface()
|
| 131 |
is_colab = 'google.colab' in sys.modules
|
| 132 |
|
|
|
|
| 16 |
|
| 17 |
from FlowFacade import FlowFacade
|
| 18 |
from BackgroundEngine import BackgroundEngine
|
| 19 |
+
from style_transfer import StyleTransferEngine
|
| 20 |
from ui_manager import UIManager
|
| 21 |
|
| 22 |
|
|
|
|
| 127 |
try:
|
| 128 |
facade = FlowFacade()
|
| 129 |
background_engine = BackgroundEngine()
|
| 130 |
+
style_engine = StyleTransferEngine()
|
| 131 |
+
ui_manager = UIManager(facade, background_engine, style_engine)
|
| 132 |
interface = ui_manager.create_interface()
|
| 133 |
is_colab = 'google.colab' in sys.modules
|
| 134 |
|
mask_generator.py
CHANGED
|
@@ -15,6 +15,13 @@ from rembg import remove, new_session
|
|
| 15 |
logger = logging.getLogger(__name__)
|
| 16 |
logger.setLevel(logging.INFO)
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
class MaskGenerator:
|
| 19 |
"""
|
| 20 |
Intelligent mask generation using deep learning models with traditional fallback.
|
|
@@ -92,6 +99,146 @@ class MaskGenerator:
|
|
| 92 |
gc.collect()
|
| 93 |
logger.info("π§Ή BiRefNet model unloaded")
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
def apply_guided_filter(
|
| 96 |
self,
|
| 97 |
mask: np.ndarray,
|
|
@@ -481,13 +628,25 @@ class MaskGenerator:
|
|
| 481 |
logger.error(f"β Scene focus adjustment failed: {e}")
|
| 482 |
return mask
|
| 483 |
|
| 484 |
-
def create_gradient_based_mask(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 485 |
"""
|
| 486 |
Intelligent foreground extraction: prioritize deep learning models, fallback to traditional methods
|
| 487 |
Focus mode: 'person' for tight crop around person, 'scene' for including nearby objects
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
"""
|
| 489 |
width, height = original_image.size
|
| 490 |
-
logger.info(f"π― Creating mask for {width}x{height} image, mode: {mode}, focus: {focus_mode}")
|
| 491 |
|
| 492 |
if mode == "center":
|
| 493 |
# Try using deep learning models for intelligent foreground extraction
|
|
@@ -495,9 +654,33 @@ class MaskGenerator:
|
|
| 495 |
dl_mask = self.try_deep_learning_mask(original_image)
|
| 496 |
if dl_mask is not None:
|
| 497 |
logger.info("β
Using deep learning generated mask")
|
|
|
|
| 498 |
# Apply focus mode adjustments to deep learning mask
|
| 499 |
if focus_mode == "scene":
|
| 500 |
dl_mask = self._adjust_mask_for_scene_focus(dl_mask, original_image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 501 |
return dl_mask
|
| 502 |
|
| 503 |
# Fallback to traditional method
|
|
|
|
| 15 |
logger = logging.getLogger(__name__)
|
| 16 |
logger.setLevel(logging.INFO)
|
| 17 |
|
| 18 |
+
# Dark background detection thresholds
|
| 19 |
+
DARK_BG_LUMINANCE_THRESHOLD = 50 # Average luminance below this = dark background
|
| 20 |
+
DARK_BG_EDGE_SAMPLE_WIDTH = 20 # Pixels from edge to sample for background detection
|
| 21 |
+
DARK_BG_DILATION_PIXELS = 5 # Default dilation for dark backgrounds
|
| 22 |
+
DARK_BG_ENHANCED_DILATION = 8 # Enhanced dilation when user enables option
|
| 23 |
+
|
| 24 |
+
|
| 25 |
class MaskGenerator:
|
| 26 |
"""
|
| 27 |
Intelligent mask generation using deep learning models with traditional fallback.
|
|
|
|
| 99 |
gc.collect()
|
| 100 |
logger.info("π§Ή BiRefNet model unloaded")
|
| 101 |
|
| 102 |
+
def detect_dark_background(self, image: Image.Image, mask: Optional[np.ndarray] = None) -> Tuple[bool, float]:
|
| 103 |
+
"""
|
| 104 |
+
Detect if the image has a dark background.
|
| 105 |
+
|
| 106 |
+
Analyzes the edge regions of the image (where background is likely) to determine
|
| 107 |
+
if the background is predominantly dark, which can cause mask detection issues.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
image: Input PIL Image
|
| 111 |
+
mask: Optional existing mask to exclude foreground from analysis
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
Tuple of (is_dark_background: bool, avg_luminance: float)
|
| 115 |
+
"""
|
| 116 |
+
try:
|
| 117 |
+
img_array = np.array(image.convert('RGB'))
|
| 118 |
+
height, width = img_array.shape[:2]
|
| 119 |
+
|
| 120 |
+
# Convert to grayscale for luminance analysis
|
| 121 |
+
gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
|
| 122 |
+
|
| 123 |
+
# Sample from edge regions (likely background)
|
| 124 |
+
edge_width = min(DARK_BG_EDGE_SAMPLE_WIDTH, width // 10, height // 10)
|
| 125 |
+
|
| 126 |
+
# Create edge sampling mask
|
| 127 |
+
edge_sample_mask = np.zeros((height, width), dtype=bool)
|
| 128 |
+
edge_sample_mask[:edge_width, :] = True # Top
|
| 129 |
+
edge_sample_mask[-edge_width:, :] = True # Bottom
|
| 130 |
+
edge_sample_mask[:, :edge_width] = True # Left
|
| 131 |
+
edge_sample_mask[:, -edge_width:] = True # Right
|
| 132 |
+
|
| 133 |
+
# Exclude foreground if mask is provided
|
| 134 |
+
if mask is not None:
|
| 135 |
+
foreground_mask = mask > 127
|
| 136 |
+
edge_sample_mask = edge_sample_mask & (~foreground_mask)
|
| 137 |
+
|
| 138 |
+
if not np.any(edge_sample_mask):
|
| 139 |
+
# Fallback: use corners only
|
| 140 |
+
corner_pixels = np.array([
|
| 141 |
+
gray[0, 0], gray[0, -1],
|
| 142 |
+
gray[-1, 0], gray[-1, -1]
|
| 143 |
+
])
|
| 144 |
+
avg_luminance = np.mean(corner_pixels)
|
| 145 |
+
else:
|
| 146 |
+
avg_luminance = np.mean(gray[edge_sample_mask])
|
| 147 |
+
|
| 148 |
+
is_dark = avg_luminance < DARK_BG_LUMINANCE_THRESHOLD
|
| 149 |
+
|
| 150 |
+
logger.info(f"π Background analysis - Avg luminance: {avg_luminance:.1f}, Dark: {is_dark}")
|
| 151 |
+
|
| 152 |
+
return is_dark, avg_luminance
|
| 153 |
+
|
| 154 |
+
except Exception as e:
|
| 155 |
+
logger.error(f"β Dark background detection failed: {e}")
|
| 156 |
+
return False, 128.0 # Default: not dark
|
| 157 |
+
|
| 158 |
+
def enhance_mask_for_dark_background(
|
| 159 |
+
self,
|
| 160 |
+
mask: Image.Image,
|
| 161 |
+
original_image: Image.Image,
|
| 162 |
+
dilation_pixels: int = DARK_BG_DILATION_PIXELS,
|
| 163 |
+
enhance_gray_areas: bool = True
|
| 164 |
+
) -> Image.Image:
|
| 165 |
+
"""
|
| 166 |
+
Enhance mask for images with dark backgrounds.
|
| 167 |
+
|
| 168 |
+
Applies dilation and gray area enhancement to capture foreground elements
|
| 169 |
+
that may have been missed due to low contrast with dark backgrounds.
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
mask: Input mask PIL Image (L mode)
|
| 173 |
+
original_image: Original image for reference
|
| 174 |
+
dilation_pixels: Number of pixels to dilate the mask
|
| 175 |
+
enhance_gray_areas: Whether to boost gray (uncertain) areas
|
| 176 |
+
|
| 177 |
+
Returns:
|
| 178 |
+
Enhanced mask PIL Image
|
| 179 |
+
"""
|
| 180 |
+
try:
|
| 181 |
+
mask_array = np.array(mask)
|
| 182 |
+
orig_array = np.array(original_image.convert('RGB'))
|
| 183 |
+
|
| 184 |
+
logger.info(f"π§ Enhancing mask for dark background (dilation: {dilation_pixels}px)")
|
| 185 |
+
|
| 186 |
+
# Step 1: Identify gray (uncertain) areas in the mask
|
| 187 |
+
if enhance_gray_areas:
|
| 188 |
+
gray_areas = (mask_array > 30) & (mask_array < 200)
|
| 189 |
+
|
| 190 |
+
if np.any(gray_areas):
|
| 191 |
+
# For gray areas, check if they're near high-confidence foreground
|
| 192 |
+
high_conf = mask_array >= 200
|
| 193 |
+
|
| 194 |
+
# Dilate high confidence area to find nearby gray pixels
|
| 195 |
+
kernel_check = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
|
| 196 |
+
high_conf_dilated = cv2.dilate(high_conf.astype(np.uint8), kernel_check, iterations=2)
|
| 197 |
+
|
| 198 |
+
# Gray pixels near high confidence foreground -> boost them
|
| 199 |
+
boost_candidates = gray_areas & (high_conf_dilated > 0)
|
| 200 |
+
|
| 201 |
+
# Boost gray areas near foreground
|
| 202 |
+
mask_array[boost_candidates] = np.clip(
|
| 203 |
+
mask_array[boost_candidates] * 1.5 + 50,
|
| 204 |
+
0, 255
|
| 205 |
+
).astype(np.uint8)
|
| 206 |
+
|
| 207 |
+
logger.info(f"π Boosted {np.sum(boost_candidates)} gray pixels near foreground")
|
| 208 |
+
|
| 209 |
+
# Step 2: Apply dilation to expand foreground coverage
|
| 210 |
+
if dilation_pixels > 0:
|
| 211 |
+
kernel = cv2.getStructuringElement(
|
| 212 |
+
cv2.MORPH_ELLIPSE,
|
| 213 |
+
(dilation_pixels * 2 + 1, dilation_pixels * 2 + 1)
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
# Threshold to get foreground region for dilation
|
| 217 |
+
fg_binary = (mask_array > 50).astype(np.uint8) * 255
|
| 218 |
+
fg_dilated = cv2.dilate(fg_binary, kernel, iterations=1)
|
| 219 |
+
|
| 220 |
+
# Blend: keep original high values, expand into new areas
|
| 221 |
+
# New areas from dilation get moderate confidence
|
| 222 |
+
new_areas = (fg_dilated > 0) & (mask_array < 50)
|
| 223 |
+
mask_array[new_areas] = 180 # Moderate confidence for expanded areas
|
| 224 |
+
|
| 225 |
+
logger.info(f"π Dilated mask by {dilation_pixels}px, added {np.sum(new_areas)} pixels")
|
| 226 |
+
|
| 227 |
+
# Step 3: Smooth the transitions
|
| 228 |
+
mask_array = cv2.GaussianBlur(mask_array, (3, 3), 0.8)
|
| 229 |
+
|
| 230 |
+
# Step 4: Re-strengthen core foreground
|
| 231 |
+
core_fg = np.array(mask) >= 220
|
| 232 |
+
mask_array[core_fg] = 255
|
| 233 |
+
|
| 234 |
+
logger.info(f"β
Dark background enhancement complete - Final mean: {mask_array.mean():.1f}")
|
| 235 |
+
|
| 236 |
+
return Image.fromarray(mask_array, mode='L')
|
| 237 |
+
|
| 238 |
+
except Exception as e:
|
| 239 |
+
logger.error(f"β Mask enhancement failed: {e}")
|
| 240 |
+
return mask
|
| 241 |
+
|
| 242 |
def apply_guided_filter(
|
| 243 |
self,
|
| 244 |
mask: np.ndarray,
|
|
|
|
| 628 |
logger.error(f"β Scene focus adjustment failed: {e}")
|
| 629 |
return mask
|
| 630 |
|
| 631 |
+
def create_gradient_based_mask(
|
| 632 |
+
self,
|
| 633 |
+
original_image: Image.Image,
|
| 634 |
+
mode: str = "center",
|
| 635 |
+
focus_mode: str = "person",
|
| 636 |
+
enhance_dark_edges: bool = False
|
| 637 |
+
) -> Image.Image:
|
| 638 |
"""
|
| 639 |
Intelligent foreground extraction: prioritize deep learning models, fallback to traditional methods
|
| 640 |
Focus mode: 'person' for tight crop around person, 'scene' for including nearby objects
|
| 641 |
+
|
| 642 |
+
Args:
|
| 643 |
+
original_image: Input PIL Image
|
| 644 |
+
mode: Composition mode (center, left_half, right_half, full)
|
| 645 |
+
focus_mode: 'person' for tight crop, 'scene' for including nearby objects
|
| 646 |
+
enhance_dark_edges: User toggle to enhance mask for dark backgrounds
|
| 647 |
"""
|
| 648 |
width, height = original_image.size
|
| 649 |
+
logger.info(f"π― Creating mask for {width}x{height} image, mode: {mode}, focus: {focus_mode}, enhance_dark: {enhance_dark_edges}")
|
| 650 |
|
| 651 |
if mode == "center":
|
| 652 |
# Try using deep learning models for intelligent foreground extraction
|
|
|
|
| 654 |
dl_mask = self.try_deep_learning_mask(original_image)
|
| 655 |
if dl_mask is not None:
|
| 656 |
logger.info("β
Using deep learning generated mask")
|
| 657 |
+
|
| 658 |
# Apply focus mode adjustments to deep learning mask
|
| 659 |
if focus_mode == "scene":
|
| 660 |
dl_mask = self._adjust_mask_for_scene_focus(dl_mask, original_image)
|
| 661 |
+
|
| 662 |
+
# === Dark background detection and enhancement ===
|
| 663 |
+
mask_array = np.array(dl_mask)
|
| 664 |
+
is_dark_bg, avg_luminance = self.detect_dark_background(original_image, mask_array)
|
| 665 |
+
|
| 666 |
+
if is_dark_bg or enhance_dark_edges:
|
| 667 |
+
# Determine dilation amount
|
| 668 |
+
if enhance_dark_edges:
|
| 669 |
+
# User explicitly enabled - use stronger dilation
|
| 670 |
+
dilation = DARK_BG_ENHANCED_DILATION
|
| 671 |
+
logger.info(f"π User enabled dark edge enhancement (dilation: {dilation}px)")
|
| 672 |
+
else:
|
| 673 |
+
# Auto-detected dark background - use moderate dilation
|
| 674 |
+
dilation = DARK_BG_DILATION_PIXELS
|
| 675 |
+
logger.info(f"π Auto-detected dark background (luminance: {avg_luminance:.1f}), applying enhancement")
|
| 676 |
+
|
| 677 |
+
dl_mask = self.enhance_mask_for_dark_background(
|
| 678 |
+
dl_mask,
|
| 679 |
+
original_image,
|
| 680 |
+
dilation_pixels=dilation,
|
| 681 |
+
enhance_gray_areas=True
|
| 682 |
+
)
|
| 683 |
+
|
| 684 |
return dl_mask
|
| 685 |
|
| 686 |
# Fallback to traditional method
|
requirements.txt
CHANGED
|
@@ -20,6 +20,10 @@ rembg[gpu]
|
|
| 20 |
scipy
|
| 21 |
opencv-contrib-python
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
# Core Dependencies
|
| 24 |
torch>=2.5.0
|
| 25 |
numpy
|
|
|
|
| 20 |
scipy
|
| 21 |
opencv-contrib-python
|
| 22 |
|
| 23 |
+
# 3D Cartoon Style Dependencies (SDXL + Pixar LoRA)
|
| 24 |
+
# Note: diffusers is already included above for I2V
|
| 25 |
+
# SDXL uses the same diffusers library
|
| 26 |
+
|
| 27 |
# Core Dependencies
|
| 28 |
torch>=2.5.0
|
| 29 |
numpy
|
style_transfer.py
ADDED
|
@@ -0,0 +1,708 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import gc
|
| 3 |
+
from typing import Tuple, Optional, Dict, Any
|
| 4 |
+
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
# Optional Hugging Face Spaces runtime (GPU decorators etc.); absent locally.
try:
    import spaces
    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False


# Identity preservation keywords (added to all styles) - kept short for CLIP 77 token limit
IDENTITY_PRESERVE = "same person, same face, same ethnicity, same age"
IDENTITY_NEGATIVE = "different person, altered face, changed ethnicity, age change, distorted features"

# Enhanced face restore mode - concise weighted keywords
FACE_RESTORE_PRESERVE = "(same person:1.4), (preserve face:1.3), (same ethnicity:1.2), same pose, same lighting"
FACE_RESTORE_NEGATIVE = "(different person:1.4), (deformed face:1.3), wrong ethnicity, age change, western features"

# IP-Adapter settings for stronger identity preservation
# Using standard IP-Adapter (not face-specific) to avoid image encoder dependency
IP_ADAPTER_REPO = "h94/IP-Adapter"
IP_ADAPTER_SUBFOLDER = "sdxl_models"
IP_ADAPTER_WEIGHT = "ip-adapter_sdxl.bin"  # Standard model, no extra encoder needed
IP_ADAPTER_SCALE_DEFAULT = 0.5  # Balance between identity and style

# Style-specific face_restore settings (some styles are more transformative).
# max_strength caps img2img strength, lora_scale_mult damps the style LoRA,
# ip_scale sets the IP-Adapter influence when face restore is enabled.
FACE_RESTORE_STYLE_SETTINGS = {
    "3d_cartoon": {"max_strength": 0.45, "lora_scale_mult": 0.7, "ip_scale": 0.4},
    "anime": {"max_strength": 0.45, "lora_scale_mult": 0.7, "ip_scale": 0.4},
    "illustrated_fantasy": {"max_strength": 0.42, "lora_scale_mult": 0.65, "ip_scale": 0.45},
    "watercolor": {"max_strength": 0.40, "lora_scale_mult": 0.6, "ip_scale": 0.5},
    "oil_painting": {"max_strength": 0.35, "lora_scale_mult": 0.5, "ip_scale": 0.6},  # Most transformative
    "pixel_art": {"max_strength": 0.50, "lora_scale_mult": 0.8, "ip_scale": 0.3},
}

# Style configurations: one entry per selectable single style.
# lora_repo may be None (base SDXL only, lora_scale ignored).
STYLE_CONFIGS = {
    "3d_cartoon": {
        "name": "3D Cartoon",
        "emoji": "π¬",
        "lora_repo": "imagepipeline/Samaritan-3d-Cartoon-SDXL",
        "lora_weight": "Samaritan 3d Cartoon.safetensors",
        "prompt": "3D cartoon style, smooth rounded features, soft ambient lighting, CGI quality, vibrant colors, cel-shaded, studio render",
        "negative_prompt": "ugly, deformed, noisy, blurry, low quality, flat, sketch",
        "lora_scale": 0.75,
        "recommended_strength": 0.55,
    },
    "anime": {
        "name": "Anime Illustration",
        "emoji": "πΈ",
        "lora_repo": None,
        "lora_weight": None,
        "prompt": "anime illustration, soft lighting, rich colors, delicate linework, smooth gradients, expressive eyes, cel shading, masterpiece",
        "negative_prompt": "ugly, deformed, bad anatomy, bad hands, blurry, low quality",
        "lora_scale": 0.0,
        "recommended_strength": 0.50,
    },
    "illustrated_fantasy": {
        "name": "Illustrated Fantasy",
        "emoji": "π",
        "lora_repo": "ntc-ai/SDXL-LoRA-slider.Studio-Ghibli-style",
        "lora_weight": "Studio Ghibli style.safetensors",
        "prompt": "Ghibli style illustration, hand-painted look, soft watercolor textures, dreamy atmosphere, pastel colors, golden hour lighting, storybook quality",
        "negative_prompt": "ugly, dark, horror, scary, blurry, low quality, modern",
        "lora_scale": 1.0,
        "recommended_strength": 0.50,
    },
    "watercolor": {
        "name": "Watercolor Art",
        "emoji": "π",
        "lora_repo": "ostris/watercolor_style_lora_sdxl",
        "lora_weight": "watercolor_style_lora.safetensors",
        "prompt": "watercolor painting, wet-on-wet technique, soft color bleeds, paper texture, transparent washes, feathered edges, hand-painted",
        "negative_prompt": "sharp edges, solid flat colors, harsh lines, vector art, airbrushed",
        "lora_scale": 1.0,
        "recommended_strength": 0.50,
    },
    "oil_painting": {
        "name": "Classic Oil Paint",
        "emoji": "πΌοΈ",
        "lora_repo": "EldritchAdam/ClassipeintXL",
        "lora_weight": "ClassipeintXL.safetensors",
        "prompt": "oil painting style, impasto technique, palette knife strokes, visible canvas texture, rich saturated pigments, masterful lighting, museum quality",
        "negative_prompt": "flat, smooth, cartoon, anime, blurry, low quality, modern, airbrushed",
        "lora_scale": 0.9,
        "recommended_strength": 0.50,
    },
    "pixel_art": {
        "name": "Pixel Art",
        "emoji": "πΎ",
        "lora_repo": "nerijs/pixel-art-xl",
        "lora_weight": "pixel-art-xl.safetensors",
        "prompt": "pixel art style, crisp blocky pixels, limited color palette, 16-bit aesthetic, retro game vibes, dithering effects, sprite art",
        "negative_prompt": "smooth, blurry, anti-aliased, soft gradient, painterly",
        "lora_scale": 0.9,
        "recommended_strength": 0.60,
    },
}

# Style Blend Presets - combining multiple styles (prompts kept short for CLIP 77 token limit)
STYLE_BLENDS = {
    "cartoon_anime": {
        "name": "3D Anime Fusion",
        # FIX: original value was two U+FFFD replacement characters (corrupted
        # data that rendered as broken glyphs in the UI); use a valid emoji.
        "emoji": "🎭",
        "description": "70% 3D Cartoon + 30% Anime linework",
        "primary_style": "3d_cartoon",
        "secondary_style": "anime",
        "primary_weight": 0.7,
        "secondary_weight": 0.3,
        "prompt": "3D cartoon with anime linework, smooth features, soft lighting, CGI quality, vibrant colors, cel-shaded",
        "negative_prompt": "ugly, deformed, noisy, blurry, low quality",
        "strength": 0.52,
    },
    "fantasy_watercolor": {
        "name": "Dreamy Watercolor",
        "emoji": "π",
        "description": "60% Illustrated Fantasy + 40% Watercolor",
        "primary_style": "illustrated_fantasy",
        "secondary_style": "watercolor",
        "primary_weight": 0.6,
        "secondary_weight": 0.4,
        "prompt": "Ghibli style with watercolor washes, soft color bleeds, storybook atmosphere, paper texture, warm golden lighting",
        "negative_prompt": "dark, horror, harsh lines, solid colors",
        "strength": 0.50,
    },
    "anime_fantasy": {
        "name": "Anime Storybook",
        "emoji": "π",
        "description": "50% Anime + 50% Illustrated Fantasy",
        "primary_style": "anime",
        "secondary_style": "illustrated_fantasy",
        "primary_weight": 0.5,
        "secondary_weight": 0.5,
        "prompt": "Ghibli anime illustration, hand-painted storybook, soft lighting, pastel colors, expressive eyes, warm glow",
        "negative_prompt": "ugly, deformed, bad anatomy, dark, horror, blurry",
        "strength": 0.48,
    },
    "oil_classical": {
        "name": "Renaissance Portrait",
        "emoji": "π",
        "description": "Classical oil painting style",
        "primary_style": "oil_painting",
        "secondary_style": "oil_painting",
        "primary_weight": 1.0,
        "secondary_weight": 0.0,
        "prompt": "classical oil portrait, impasto technique, palette knife strokes, chiaroscuro lighting, canvas texture, museum quality",
        "negative_prompt": "flat, cartoon, anime, modern, minimalist, overexposed",
        "strength": 0.50,
    },
    "pixel_retro": {
        "name": "Retro Game Art",
        "emoji": "πΉοΈ",
        "description": "Pixel art with enhanced retro feel",
        "primary_style": "pixel_art",
        "secondary_style": "pixel_art",
        "primary_weight": 1.0,
        "secondary_weight": 0.0,
        "prompt": "retro pixel art, crisp blocky pixels, limited palette, arcade aesthetic, dithering, 16-bit charm, sprite art",
        "negative_prompt": "smooth, blurry, anti-aliased, modern, gradient",
        "strength": 0.58,
    },
}
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class StyleTransferEngine:
    """
    Multi-style image transformation engine using SDXL + LoRAs.

    Supports: 3D Cartoon, Anime, Illustrated Fantasy, Watercolor, Oil Painting,
    and Pixel Art styles, plus blend presets (STYLE_BLENDS). Optional
    IP-Adapter support provides stronger identity preservation when
    ``face_restore`` is enabled.

    Lifecycle: call ``load_model()`` (lazy — generation methods call it for
    you), then ``generate_styled_image`` / ``generate_blended_style`` /
    ``generate_all_outputs``; call ``unload_model()`` to free GPU memory.
    """

    BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"

    def __init__(self):
        # Preferred device; refined to the actually-used device in load_model().
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipe = None            # AutoPipelineForImage2Image once loaded
        self.current_lora = None    # repo id of the currently loaded LoRA, or None
        self.is_loaded = False
        self.ip_adapter_loaded = False

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _resolve_seed(seed: int) -> int:
        """Return *seed* unchanged, or draw a fresh random seed when seed == -1."""
        if seed == -1:
            seed = torch.randint(0, 2147483647, (1,)).item()
        return seed

    @staticmethod
    def _free_memory() -> None:
        """Best-effort release of Python heap and CUDA cache after generation."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # ------------------------------------------------------------------ #
    # Model / adapter lifecycle
    # ------------------------------------------------------------------ #

    def load_model(self) -> None:
        """Load the SDXL img2img base pipeline (idempotent)."""
        if self.is_loaded:
            return

        print("β Loading SDXL base model...")

        from diffusers import AutoPipelineForImage2Image

        actual_device = "cuda" if torch.cuda.is_available() else self.device

        self.pipe = AutoPipelineForImage2Image.from_pretrained(
            self.BASE_MODEL,
            torch_dtype=torch.float16 if actual_device == "cuda" else torch.float32,
            variant="fp16" if actual_device == "cuda" else None,
            use_safetensors=True,
        )
        self.pipe.to(actual_device)

        # xformers attention is an optional optimization; ignore if unavailable.
        if actual_device == "cuda":
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
            except Exception:
                pass

        self.is_loaded = True
        self.device = actual_device
        print(f"β SDXL base loaded ({actual_device})")

    def _load_lora(self, style_key: str) -> None:
        """Ensure the LoRA matching *style_key* is active (loading/unloading as needed)."""
        config = STYLE_CONFIGS.get(style_key)
        if not config:
            return

        lora_repo = config.get("lora_repo")

        # Style uses base SDXL only: make sure no stale LoRA stays active.
        if lora_repo is None:
            if self.current_lora is not None:
                print("β Unloading previous LoRA...")
                self.pipe.unload_lora_weights()
                self.current_lora = None
            return

        # Already active — nothing to do.
        if self.current_lora == lora_repo:
            return

        # Swap: unload the previous LoRA before loading the new one.
        if self.current_lora is not None:
            print(f"β Unloading previous LoRA: {self.current_lora}")
            self.pipe.unload_lora_weights()

        print(f"β Loading LoRA: {config['name']}...")
        try:
            lora_weight = config.get("lora_weight")
            if lora_weight:
                self.pipe.load_lora_weights(lora_repo, weight_name=lora_weight)
            else:
                self.pipe.load_lora_weights(lora_repo)

            self.current_lora = lora_repo
            print(f"β LoRA loaded: {config['name']}")
        except Exception as e:
            # Degrade gracefully: generation continues with base SDXL only.
            print(f"β LoRA loading failed: {e}, continuing without LoRA")
            self.current_lora = None

    def _load_ip_adapter(self) -> bool:
        """Load the IP-Adapter used for identity preservation.

        Returns:
            True when the adapter is (or already was) loaded, False on failure
            or when the pipeline itself is not loaded yet.
        """
        if self.ip_adapter_loaded:
            return True

        if self.pipe is None:
            return False

        print("β Loading IP-Adapter for face preservation...")
        try:
            self.pipe.load_ip_adapter(
                IP_ADAPTER_REPO,
                subfolder=IP_ADAPTER_SUBFOLDER,
                weight_name=IP_ADAPTER_WEIGHT
            )
            self.ip_adapter_loaded = True
            print("β IP-Adapter loaded")
            return True
        except Exception as e:
            print(f"β IP-Adapter loading failed: {e}")
            self.ip_adapter_loaded = False
            return False

    def _unload_ip_adapter(self) -> None:
        """Unload the IP-Adapter to free memory (no-op when not loaded)."""
        if not self.ip_adapter_loaded or self.pipe is None:
            return

        try:
            self.pipe.unload_ip_adapter()
            self.ip_adapter_loaded = False
            print("β IP-Adapter unloaded")
        except Exception as e:
            print(f"β IP-Adapter unload failed: {e}")

    def unload_model(self) -> None:
        """Unload the pipeline, LoRA state and IP-Adapter, then free memory."""
        if not self.is_loaded:
            return

        # Unload IP-Adapter first so the pipeline can release it cleanly.
        if self.ip_adapter_loaded:
            self._unload_ip_adapter()

        if self.pipe is not None:
            del self.pipe
            self.pipe = None

        self.current_lora = None
        self.ip_adapter_loaded = False

        self._free_memory()

        self.is_loaded = False
        print("β Model unloaded")

    # ------------------------------------------------------------------ #
    # Image preparation
    # ------------------------------------------------------------------ #

    def _preprocess_image(self, image: Image.Image) -> Image.Image:
        """Resize *image* for SDXL: RGB, longest side 1024, dims multiple of 8, min 512."""
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # SDXL works best around 1024px; keep the aspect ratio.
        max_size = 1024
        width, height = image.size

        if width > height:
            new_width = max_size
            new_height = int(height * (max_size / width))
        else:
            new_height = max_size
            new_width = int(width * (max_size / height))

        # Round down to a multiple of 8 (SDXL latent-space requirement).
        new_width = (new_width // 8) * 8
        new_height = (new_height // 8) * 8

        # Guard against degenerate inputs.
        new_width = max(new_width, 512)
        new_height = max(new_height, 512)

        return image.resize((new_width, new_height), Image.LANCZOS)

    # ------------------------------------------------------------------ #
    # Generation
    # ------------------------------------------------------------------ #

    def generate_styled_image(
        self,
        image: Image.Image,
        style_key: str = "3d_cartoon",
        strength: float = 0.65,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        custom_prompt: str = "",
        seed: int = -1,
        face_restore: bool = False
    ) -> Tuple[Image.Image, int]:
        """
        Convert image to the specified style.

        Args:
            image: Input PIL Image
            style_key: One of: 3d_cartoon, anime, illustrated_fantasy, watercolor, oil_painting, pixel_art
            strength: How much to transform (0.0-1.0); capped per style when face_restore is on
            guidance_scale: How closely to follow the prompt
            num_inference_steps: Number of denoising steps
            custom_prompt: Additional prompt text
            seed: Random seed (-1 for random)
            face_restore: Enable enhanced face preservation mode (IP-Adapter + weighted prompts)

        Returns:
            Tuple of (Stylized PIL Image, seed used)
        """
        if not self.is_loaded:
            self.load_model()

        config = STYLE_CONFIGS.get(style_key, STYLE_CONFIGS["3d_cartoon"])
        self._load_lora(style_key)

        print("β Preprocessing image...")
        processed_image = self._preprocess_image(image)

        face_settings = FACE_RESTORE_STYLE_SETTINGS.get(style_key, {
            "max_strength": 0.45, "lora_scale_mult": 0.7, "ip_scale": 0.5
        })

        base_prompt = config["prompt"]
        ip_adapter_image = None
        ip_scale = 0.0

        if face_restore:
            # Enhanced face preservation: weighted prompts + IP-Adapter + capped strength.
            preserve_prompt = FACE_RESTORE_PRESERVE
            negative_base = FACE_RESTORE_NEGATIVE

            strength = min(strength, face_settings["max_strength"])
            print(f"β Face Restore enabled: strength capped at {strength} (style: {style_key})")

            if self._load_ip_adapter():
                ip_adapter_image = processed_image
                ip_scale = face_settings["ip_scale"]
                print(f"β IP-Adapter scale: {ip_scale}")
        else:
            preserve_prompt = IDENTITY_PRESERVE
            negative_base = IDENTITY_NEGATIVE
            # Unload IP-Adapter when unused to save memory.
            if self.ip_adapter_loaded:
                self._unload_ip_adapter()

        if custom_prompt:
            prompt = f"{preserve_prompt}, {base_prompt}, {custom_prompt}"
        else:
            prompt = f"{preserve_prompt}, {base_prompt}"

        negative_prompt = f"{negative_base}, {config['negative_prompt']}"

        # Damp the LoRA in face-restore mode (style-specific multiplier).
        lora_scale = config.get("lora_scale", 1.0)
        if face_restore:
            lora_scale = lora_scale * face_settings["lora_scale_mult"]

        seed = self._resolve_seed(seed)
        generator = torch.Generator(device=self.device).manual_seed(seed)

        print(f"β Generating {config['name']} style (strength: {strength}, steps: {num_inference_steps}, seed: {seed})...")

        gen_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "image": processed_image,
            "strength": strength,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
        }

        # cross_attention_kwargs is only valid when a LoRA is actually active.
        if self.current_lora is not None:
            gen_kwargs["cross_attention_kwargs"] = {"scale": lora_scale}

        if ip_adapter_image is not None and self.ip_adapter_loaded:
            self.pipe.set_ip_adapter_scale(ip_scale)
            gen_kwargs["ip_adapter_image"] = ip_adapter_image

        result = self.pipe(**gen_kwargs).images[0]

        print(f"β {config['name']} style generated (seed: {seed})")

        self._free_memory()

        return result, seed

    def generate_blended_style(
        self,
        image: Image.Image,
        blend_key: str,
        custom_prompt: str = "",
        seed: int = -1,
        face_restore: bool = False
    ) -> Tuple[Image.Image, int]:
        """
        Generate image using a style blend preset.

        Args:
            image: Input PIL Image
            blend_key: Key from STYLE_BLENDS (unknown keys fall back to 3d_cartoon)
            custom_prompt: Additional prompt text
            seed: Random seed (-1 for random)
            face_restore: Enable enhanced face preservation mode

        Returns:
            Tuple of (Stylized PIL Image, seed used)
        """
        if not self.is_loaded:
            self.load_model()

        blend_config = STYLE_BLENDS.get(blend_key)
        if not blend_config:
            return self.generate_styled_image(image, "3d_cartoon", seed=seed, face_restore=face_restore)

        # Only the primary style's LoRA is loaded; the blend lives in the prompt.
        primary_style = blend_config["primary_style"]
        self._load_lora(primary_style)

        print("β Preprocessing image...")
        processed_image = self._preprocess_image(image)

        # Face-restore settings follow the primary style.
        face_settings = FACE_RESTORE_STYLE_SETTINGS.get(primary_style, {
            "max_strength": 0.45, "lora_scale_mult": 0.7, "ip_scale": 0.5
        })

        base_prompt = blend_config["prompt"]
        ip_adapter_image = None
        ip_scale = 0.0

        if face_restore:
            preserve_prompt = FACE_RESTORE_PRESERVE
            negative_base = FACE_RESTORE_NEGATIVE

            strength = min(blend_config["strength"], face_settings["max_strength"])
            print(f"β Face Restore enabled: strength capped at {strength} (blend: {blend_key})")

            if self._load_ip_adapter():
                ip_adapter_image = processed_image
                ip_scale = face_settings["ip_scale"]
                print(f"β IP-Adapter scale: {ip_scale}")
        else:
            preserve_prompt = IDENTITY_PRESERVE
            negative_base = IDENTITY_NEGATIVE
            strength = blend_config["strength"]
            if self.ip_adapter_loaded:
                self._unload_ip_adapter()

        if custom_prompt:
            prompt = f"{preserve_prompt}, {base_prompt}, {custom_prompt}"
        else:
            prompt = f"{preserve_prompt}, {base_prompt}"

        negative_prompt = f"{negative_base}, {blend_config['negative_prompt']}"

        # LoRA influence = primary style scale x blend weight (damped for face restore).
        primary_config = STYLE_CONFIGS.get(primary_style, {})
        lora_scale = primary_config.get("lora_scale", 1.0) * blend_config["primary_weight"]
        if face_restore:
            lora_scale = lora_scale * face_settings["lora_scale_mult"]

        seed = self._resolve_seed(seed)
        generator = torch.Generator(device=self.device).manual_seed(seed)

        print(f"β Generating {blend_config['name']} blend (seed: {seed})...")

        gen_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "image": processed_image,
            "strength": strength,
            "guidance_scale": 7.5,
            "num_inference_steps": 30,
            "generator": generator,
        }

        if self.current_lora is not None:
            gen_kwargs["cross_attention_kwargs"] = {"scale": lora_scale}

        if ip_adapter_image is not None and self.ip_adapter_loaded:
            self.pipe.set_ip_adapter_scale(ip_scale)
            gen_kwargs["ip_adapter_image"] = ip_adapter_image

        result = self.pipe(**gen_kwargs).images[0]

        print(f"β {blend_config['name']} blend generated (seed: {seed})")

        self._free_memory()

        return result, seed

    def generate_all_outputs(
        self,
        image: Image.Image,
        style_key: str = "3d_cartoon",
        strength: float = 0.65,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        custom_prompt: str = "",
        seed: int = -1,
        is_blend: bool = False,
        face_restore: bool = False
    ) -> dict:
        """
        Generate styled image output.

        When is_blend is True, *style_key* is treated as a STYLE_BLENDS key and
        the strength/guidance/steps arguments are ignored (the preset drives
        them). Errors are caught and reported in the result dict.

        Returns dict with: success, stylized_image, preview_image, style_name,
        seed_used, error.
        """
        result = {
            "success": False,
            "stylized_image": None,
            "preview_image": None,
            "style_name": "",
            "seed_used": 0,
            "error": None
        }

        try:
            if is_blend:
                blend_config = STYLE_BLENDS.get(style_key, {})
                result["style_name"] = blend_config.get("name", "Unknown Blend")

                stylized, seed_used = self.generate_blended_style(
                    image=image,
                    blend_key=style_key,
                    custom_prompt=custom_prompt,
                    seed=seed,
                    face_restore=face_restore
                )
            else:
                config = STYLE_CONFIGS.get(style_key, STYLE_CONFIGS["3d_cartoon"])
                result["style_name"] = config["name"]

                stylized, seed_used = self.generate_styled_image(
                    image=image,
                    style_key=style_key,
                    strength=strength,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    custom_prompt=custom_prompt,
                    seed=seed,
                    face_restore=face_restore
                )

            result["stylized_image"] = stylized
            result["preview_image"] = stylized
            result["seed_used"] = seed_used
            result["success"] = True
            print(f"β {result['style_name']} conversion completed (seed: {seed_used})")

        except Exception as e:
            result["error"] = str(e)
            print(f"β Style conversion failed: {e}")

        return result

    # ------------------------------------------------------------------ #
    # UI helpers (static)
    # ------------------------------------------------------------------ #

    @staticmethod
    def get_available_styles() -> Dict[str, Dict[str, Any]]:
        """Return available style configurations (key -> {name, emoji})."""
        return {
            key: {
                "name": config["name"],
                "emoji": config["emoji"],
            }
            for key, config in STYLE_CONFIGS.items()
        }

    @staticmethod
    def get_style_choices() -> list:
        """Return style choices for UI dropdown."""
        return [
            f"{config['emoji']} {config['name']}"
            for config in STYLE_CONFIGS.values()
        ]

    @staticmethod
    def get_style_key_from_choice(choice: str) -> str:
        """Convert UI choice back to style key (defaults to 3d_cartoon)."""
        for key, config in STYLE_CONFIGS.items():
            if config["name"] in choice:
                return key
        return "3d_cartoon"

    @staticmethod
    def get_blend_choices() -> list:
        """Return blend preset choices for UI dropdown."""
        return [
            f"{config['emoji']} {config['name']} - {config['description']}"
            for config in STYLE_BLENDS.values()
        ]

    @staticmethod
    def get_blend_key_from_choice(choice: str) -> str:
        """Convert UI blend choice back to blend key (defaults to cartoon_anime)."""
        for key, config in STYLE_BLENDS.items():
            if config["name"] in choice:
                return key
        return "cartoon_anime"

    @staticmethod
    def get_all_choices() -> dict:
        """Return both style and blend choices for UI, plus a combined list."""
        styles = [
            f"{config['emoji']} {config['name']}"
            for config in STYLE_CONFIGS.values()
        ]
        blends = [
            f"{config['emoji']} {config['name']}"
            for config in STYLE_BLENDS.values()
        ]
        return {
            "styles": styles,
            "blends": blends,
            "all": styles + ["βββ Style Blends βββ"] + blends
        }
ui_manager.py
CHANGED
|
@@ -6,6 +6,7 @@ import logging
|
|
| 6 |
|
| 7 |
from FlowFacade import FlowFacade
|
| 8 |
from BackgroundEngine import BackgroundEngine
|
|
|
|
| 9 |
from scene_templates import SceneTemplateManager
|
| 10 |
from css_style import DELTAFLOW_CSS
|
| 11 |
from prompt_examples import PROMPT_EXAMPLES
|
|
@@ -20,9 +21,10 @@ logger = logging.getLogger(__name__)
|
|
| 20 |
|
| 21 |
|
| 22 |
class UIManager:
|
| 23 |
-
def __init__(self, facade: FlowFacade, background_engine: BackgroundEngine):
|
| 24 |
self.facade = facade
|
| 25 |
self.background_engine = background_engine
|
|
|
|
| 26 |
self.template_manager = SceneTemplateManager()
|
| 27 |
|
| 28 |
def create_interface(self) -> gr.Blocks:
|
|
@@ -45,15 +47,19 @@ class UIManager:
|
|
| 45 |
|
| 46 |
# Main Tabs
|
| 47 |
with gr.Tabs() as main_tabs:
|
| 48 |
-
|
| 49 |
-
# Tab 1: Image to Video
|
| 50 |
with gr.Tab("π¬ Image to Video"):
|
| 51 |
self._create_i2v_tab()
|
| 52 |
-
|
| 53 |
-
# Tab 2: Background Generation
|
| 54 |
with gr.Tab("π¨ Background Generation"):
|
| 55 |
self._create_background_tab()
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
# Footer
|
| 58 |
gr.HTML("""
|
| 59 |
<div class="footer">
|
|
@@ -341,8 +347,21 @@ class UIManager:
|
|
| 341 |
gr.HTML("""
|
| 342 |
<div style="padding: 8px; background: #f0f4ff; border-radius: 6px; margin-bottom: 12px; font-size: 13px;">
|
| 343 |
<strong>π‘ When to Adjust:</strong><br>
|
|
|
|
| 344 |
β’ <strong>Feather Radius:</strong> Use 5-10 for complex scenes with fine details (hair, fur, foliage). 0 = sharp edges for clean portraits.<br>
|
| 345 |
-
β’ <strong>Mask Preview:</strong> Check the "Mask Preview" tab after generation. White = kept, Black = replaced.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
</div>
|
| 347 |
""")
|
| 348 |
|
|
@@ -393,7 +412,7 @@ class UIManager:
|
|
| 393 |
|
| 394 |
gr.HTML("""
|
| 395 |
<div class="patience-banner">
|
| 396 |
-
<strong>β±οΈ First-time users:</strong> Initial model loading takes
|
| 397 |
Subsequent generations are much faster (~30s).
|
| 398 |
</div>
|
| 399 |
""")
|
|
@@ -443,6 +462,77 @@ class UIManager:
|
|
| 443 |
elem_classes=["secondary-button"]
|
| 444 |
)
|
| 445 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
# Event handlers for Background Generation tab
|
| 447 |
def apply_template(display_name: str, current_negative: str) -> Tuple[str, str, float]:
|
| 448 |
if not display_name:
|
|
@@ -474,7 +564,7 @@ class UIManager:
|
|
| 474 |
inputs=[
|
| 475 |
bg_image_input, bg_prompt_input, combination_mode,
|
| 476 |
focus_mode, bg_negative_prompt, bg_steps_slider, bg_guidance_slider,
|
| 477 |
-
feather_radius_slider
|
| 478 |
],
|
| 479 |
outputs=[
|
| 480 |
bg_combined_output, bg_generated_output,
|
|
@@ -495,6 +585,132 @@ class UIManager:
|
|
| 495 |
outputs=[bg_status_output]
|
| 496 |
)
|
| 497 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 498 |
def _generate_background_handler(
|
| 499 |
self,
|
| 500 |
image: Image.Image,
|
|
@@ -504,7 +720,8 @@ class UIManager:
|
|
| 504 |
negative_prompt: str,
|
| 505 |
steps: int,
|
| 506 |
guidance: float,
|
| 507 |
-
feather_radius: int
|
|
|
|
| 508 |
) -> Tuple[Optional[Image.Image], Optional[Image.Image], Optional[Image.Image], Optional[Image.Image], str]:
|
| 509 |
"""Handler for background generation"""
|
| 510 |
if image is None:
|
|
@@ -522,7 +739,7 @@ class UIManager:
|
|
| 522 |
|
| 523 |
result = generate_fn(
|
| 524 |
image, prompt, combination_mode, focus_mode,
|
| 525 |
-
negative_prompt, steps, guidance, feather_radius
|
| 526 |
)
|
| 527 |
|
| 528 |
if result["success"]:
|
|
@@ -550,7 +767,8 @@ class UIManager:
|
|
| 550 |
negative_prompt: str,
|
| 551 |
steps: int,
|
| 552 |
guidance: float,
|
| 553 |
-
feather_radius: int
|
|
|
|
| 554 |
) -> Dict[str, Any]:
|
| 555 |
"""Core background generation with models"""
|
| 556 |
if not self.background_engine.is_initialized:
|
|
@@ -566,7 +784,333 @@ class UIManager:
|
|
| 566 |
num_inference_steps=int(steps),
|
| 567 |
guidance_scale=float(guidance),
|
| 568 |
enable_prompt_enhancement=True,
|
| 569 |
-
feather_radius=int(feather_radius)
|
|
|
|
| 570 |
)
|
| 571 |
|
| 572 |
-
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
from FlowFacade import FlowFacade
|
| 8 |
from BackgroundEngine import BackgroundEngine
|
| 9 |
+
from style_transfer import StyleTransferEngine
|
| 10 |
from scene_templates import SceneTemplateManager
|
| 11 |
from css_style import DELTAFLOW_CSS
|
| 12 |
from prompt_examples import PROMPT_EXAMPLES
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
class UIManager:
|
| 24 |
+
def __init__(self, facade: FlowFacade, background_engine: BackgroundEngine, style_engine: StyleTransferEngine):
|
| 25 |
self.facade = facade
|
| 26 |
self.background_engine = background_engine
|
| 27 |
+
self.style_engine = style_engine
|
| 28 |
self.template_manager = SceneTemplateManager()
|
| 29 |
|
| 30 |
def create_interface(self) -> gr.Blocks:
|
|
|
|
| 47 |
|
| 48 |
# Main Tabs
|
| 49 |
with gr.Tabs() as main_tabs:
|
| 50 |
+
|
| 51 |
+
# Tab 1: Image to Video
|
| 52 |
with gr.Tab("π¬ Image to Video"):
|
| 53 |
self._create_i2v_tab()
|
| 54 |
+
|
| 55 |
+
# Tab 2: Background Generation
|
| 56 |
with gr.Tab("π¨ Background Generation"):
|
| 57 |
self._create_background_tab()
|
| 58 |
|
| 59 |
+
# Tab 3: AI Style Transfer
|
| 60 |
+
with gr.Tab("β¨ Style Transfer"):
|
| 61 |
+
self._create_3d_tab()
|
| 62 |
+
|
| 63 |
# Footer
|
| 64 |
gr.HTML("""
|
| 65 |
<div class="footer">
|
|
|
|
| 347 |
gr.HTML("""
|
| 348 |
<div style="padding: 8px; background: #f0f4ff; border-radius: 6px; margin-bottom: 12px; font-size: 13px;">
|
| 349 |
<strong>π‘ When to Adjust:</strong><br>
|
| 350 |
+
β’ <strong>Enhance Dark Edges:</strong> Enable for images with dark/black backgrounds where foreground parts get lost.<br>
|
| 351 |
β’ <strong>Feather Radius:</strong> Use 5-10 for complex scenes with fine details (hair, fur, foliage). 0 = sharp edges for clean portraits.<br>
|
| 352 |
+
β’ <strong>Mask Preview:</strong> Check the "Mask Preview" tab after generation. White = kept, Black = replaced.
|
| 353 |
+
</div>
|
| 354 |
+
""")
|
| 355 |
+
|
| 356 |
+
enhance_dark_edges = gr.Checkbox(
|
| 357 |
+
label="π Enhance Dark Edges",
|
| 358 |
+
value=False,
|
| 359 |
+
info="Enable if dark foreground parts blend into dark backgrounds"
|
| 360 |
+
)
|
| 361 |
+
gr.HTML("""
|
| 362 |
+
<div style="padding: 6px 8px; background: #fff3cd; border-radius: 4px; font-size: 11px; margin-bottom: 12px;">
|
| 363 |
+
<strong>When to use:</strong> If mask preview shows gray areas where foreground should be white (e.g., dark hair/clothing on dark background).
|
| 364 |
+
Auto-detection is enabled by default, but this toggle forces stronger enhancement.
|
| 365 |
</div>
|
| 366 |
""")
|
| 367 |
|
|
|
|
| 412 |
|
| 413 |
gr.HTML("""
|
| 414 |
<div class="patience-banner">
|
| 415 |
+
<strong>β±οΈ First-time users:</strong> Initial model loading takes 30-60 seconds.
|
| 416 |
Subsequent generations are much faster (~30s).
|
| 417 |
</div>
|
| 418 |
""")
|
|
|
|
| 462 |
elem_classes=["secondary-button"]
|
| 463 |
)
|
| 464 |
|
| 465 |
+
# Touch Up Section for manual artifact removal
|
| 466 |
+
with gr.Accordion("ποΈ Touch Up (Remove Artifacts)", open=False) as touchup_accordion:
|
| 467 |
+
gr.HTML("""
|
| 468 |
+
<div style="padding: 10px; background: #e8f4fd; border-radius: 6px; margin-bottom: 12px; font-size: 13px;">
|
| 469 |
+
<strong>β¨ How to Use Touch Up:</strong><br>
|
| 470 |
+
1. After generating, if you see unwanted artifacts (gray edges, leftover objects)<br>
|
| 471 |
+
2. Click "Load Result for Touch Up" to load the image<br>
|
| 472 |
+
3. Use the brush to paint over areas you want to remove<br>
|
| 473 |
+
4. Click "Remove & Fill" to replace painted areas with background
|
| 474 |
+
</div>
|
| 475 |
+
""")
|
| 476 |
+
|
| 477 |
+
# State to store the current result and prompt
|
| 478 |
+
touchup_source_image = gr.State(value=None)
|
| 479 |
+
touchup_background_prompt = gr.State(value="")
|
| 480 |
+
|
| 481 |
+
load_touchup_btn = gr.Button(
|
| 482 |
+
"π₯ Load Result for Touch Up",
|
| 483 |
+
elem_classes=["secondary-button"]
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
touchup_editor = gr.ImageEditor(
|
| 487 |
+
label="Draw on areas to remove (use brush tool)",
|
| 488 |
+
type="pil",
|
| 489 |
+
height=400,
|
| 490 |
+
brush=gr.Brush(
|
| 491 |
+
colors=["#FF0000"],
|
| 492 |
+
default_color="#FF0000",
|
| 493 |
+
default_size=20
|
| 494 |
+
),
|
| 495 |
+
layers=False,
|
| 496 |
+
interactive=True,
|
| 497 |
+
visible=True
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
with gr.Row():
|
| 501 |
+
brush_size_slider = gr.Slider(
|
| 502 |
+
label="Brush Size",
|
| 503 |
+
minimum=5,
|
| 504 |
+
maximum=50,
|
| 505 |
+
value=20,
|
| 506 |
+
step=5,
|
| 507 |
+
scale=2
|
| 508 |
+
)
|
| 509 |
+
touchup_strength = gr.Slider(
|
| 510 |
+
label="Fill Strength",
|
| 511 |
+
minimum=0.8,
|
| 512 |
+
maximum=1.0,
|
| 513 |
+
value=0.99,
|
| 514 |
+
step=0.01,
|
| 515 |
+
scale=2,
|
| 516 |
+
info="Higher = more complete replacement"
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
remove_fill_btn = gr.Button(
|
| 520 |
+
"π¨ Remove & Fill",
|
| 521 |
+
variant="primary",
|
| 522 |
+
elem_classes="primary-button"
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
touchup_result = gr.Image(
|
| 526 |
+
label="Touch Up Result",
|
| 527 |
+
elem_classes=["result-gallery"]
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
touchup_status = gr.Textbox(
|
| 531 |
+
label="Touch Up Status",
|
| 532 |
+
value="Load an image to start touch up.",
|
| 533 |
+
interactive=False
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
# Event handlers for Background Generation tab
|
| 537 |
def apply_template(display_name: str, current_negative: str) -> Tuple[str, str, float]:
|
| 538 |
if not display_name:
|
|
|
|
| 564 |
inputs=[
|
| 565 |
bg_image_input, bg_prompt_input, combination_mode,
|
| 566 |
focus_mode, bg_negative_prompt, bg_steps_slider, bg_guidance_slider,
|
| 567 |
+
feather_radius_slider, enhance_dark_edges
|
| 568 |
],
|
| 569 |
outputs=[
|
| 570 |
bg_combined_output, bg_generated_output,
|
|
|
|
| 585 |
outputs=[bg_status_output]
|
| 586 |
)
|
| 587 |
|
| 588 |
+
# Touch Up event handlers
|
| 589 |
+
def load_for_touchup(combined_image, prompt):
|
| 590 |
+
"""Load the generated result into touch up editor"""
|
| 591 |
+
if combined_image is None:
|
| 592 |
+
return None, None, "", "Please generate a background first!"
|
| 593 |
+
return combined_image, combined_image, prompt, "β Image loaded! Use brush to paint areas to remove."
|
| 594 |
+
|
| 595 |
+
load_touchup_btn.click(
|
| 596 |
+
fn=load_for_touchup,
|
| 597 |
+
inputs=[bg_combined_output, bg_prompt_input],
|
| 598 |
+
outputs=[touchup_editor, touchup_source_image, touchup_background_prompt, touchup_status]
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
remove_fill_btn.click(
|
| 602 |
+
fn=self._touchup_inpaint_handler,
|
| 603 |
+
inputs=[touchup_editor, touchup_background_prompt, touchup_strength],
|
| 604 |
+
outputs=[touchup_result, touchup_status]
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
def _touchup_inpaint_handler(
|
| 608 |
+
self,
|
| 609 |
+
editor_data: dict,
|
| 610 |
+
background_prompt: str,
|
| 611 |
+
strength: float
|
| 612 |
+
) -> Tuple[Optional[Image.Image], str]:
|
| 613 |
+
"""Handler for touch up inpainting"""
|
| 614 |
+
if editor_data is None:
|
| 615 |
+
return None, "Please load an image first!"
|
| 616 |
+
|
| 617 |
+
try:
|
| 618 |
+
# Extract image and mask from editor
|
| 619 |
+
# Gradio ImageEditor returns a dict with 'background', 'layers', 'composite'
|
| 620 |
+
if isinstance(editor_data, dict):
|
| 621 |
+
base_image = editor_data.get("background") or editor_data.get("composite")
|
| 622 |
+
layers = editor_data.get("layers", [])
|
| 623 |
+
|
| 624 |
+
if base_image is None:
|
| 625 |
+
return None, "No image found in editor!"
|
| 626 |
+
|
| 627 |
+
# Create mask from drawn layers (red brush strokes)
|
| 628 |
+
mask = self._extract_mask_from_editor(base_image, layers)
|
| 629 |
+
|
| 630 |
+
if mask is None or not self._has_painted_area(mask):
|
| 631 |
+
return None, "Please draw on areas you want to remove!"
|
| 632 |
+
|
| 633 |
+
else:
|
| 634 |
+
# Fallback for PIL Image
|
| 635 |
+
return None, "Invalid editor data format!"
|
| 636 |
+
|
| 637 |
+
# Apply ZeroGPU decorator if available
|
| 638 |
+
if SPACES_AVAILABLE:
|
| 639 |
+
inpaint_fn = spaces.GPU(duration=60)(self._touchup_inpaint_core)
|
| 640 |
+
else:
|
| 641 |
+
inpaint_fn = self._touchup_inpaint_core
|
| 642 |
+
|
| 643 |
+
result = inpaint_fn(base_image, mask, background_prompt, strength)
|
| 644 |
+
|
| 645 |
+
if result["success"]:
|
| 646 |
+
return result["inpainted_image"], "β Touch up completed!"
|
| 647 |
+
else:
|
| 648 |
+
return None, f"Error: {result.get('error', 'Unknown error')}"
|
| 649 |
+
|
| 650 |
+
except Exception as e:
|
| 651 |
+
logger.error(f"Touch up failed: {e}")
|
| 652 |
+
return None, f"Error: {str(e)}"
|
| 653 |
+
|
| 654 |
+
def _extract_mask_from_editor(self, base_image: Image.Image, layers: list) -> Optional[Image.Image]:
|
| 655 |
+
"""Extract painted mask from ImageEditor layers"""
|
| 656 |
+
import numpy as np
|
| 657 |
+
|
| 658 |
+
if not layers:
|
| 659 |
+
return None
|
| 660 |
+
|
| 661 |
+
# Create blank mask
|
| 662 |
+
width, height = base_image.size
|
| 663 |
+
mask_array = np.zeros((height, width), dtype=np.uint8)
|
| 664 |
+
|
| 665 |
+
for layer in layers:
|
| 666 |
+
if layer is None:
|
| 667 |
+
continue
|
| 668 |
+
|
| 669 |
+
# Convert layer to numpy array
|
| 670 |
+
if isinstance(layer, Image.Image):
|
| 671 |
+
layer_array = np.array(layer.convert('RGBA'))
|
| 672 |
+
else:
|
| 673 |
+
continue
|
| 674 |
+
|
| 675 |
+
# Find non-transparent pixels (painted areas)
|
| 676 |
+
# The alpha channel indicates where user drew
|
| 677 |
+
if layer_array.shape[2] >= 4:
|
| 678 |
+
alpha = layer_array[:, :, 3]
|
| 679 |
+
# Also check for red color (our brush color)
|
| 680 |
+
red = layer_array[:, :, 0]
|
| 681 |
+
# Painted areas have high alpha and red channel
|
| 682 |
+
painted = (alpha > 50) | (red > 100)
|
| 683 |
+
mask_array[painted] = 255
|
| 684 |
+
|
| 685 |
+
return Image.fromarray(mask_array, mode='L')
|
| 686 |
+
|
| 687 |
+
def _has_painted_area(self, mask: Image.Image) -> bool:
|
| 688 |
+
"""Check if mask has any painted area"""
|
| 689 |
+
import numpy as np
|
| 690 |
+
mask_array = np.array(mask)
|
| 691 |
+
return np.sum(mask_array > 127) > 100 # At least 100 white pixels
|
| 692 |
+
|
| 693 |
+
def _touchup_inpaint_core(
|
| 694 |
+
self,
|
| 695 |
+
image: Image.Image,
|
| 696 |
+
mask: Image.Image,
|
| 697 |
+
prompt: str,
|
| 698 |
+
strength: float
|
| 699 |
+
) -> dict:
|
| 700 |
+
"""Core inpainting function"""
|
| 701 |
+
# Use the background prompt to fill in the masked areas
|
| 702 |
+
inpaint_prompt = f"{prompt}, seamless, natural continuation, no artifacts" if prompt else "natural background, seamless continuation"
|
| 703 |
+
|
| 704 |
+
return self.background_engine.inpaint_region(
|
| 705 |
+
image=image,
|
| 706 |
+
mask=mask,
|
| 707 |
+
prompt=inpaint_prompt,
|
| 708 |
+
negative_prompt="blurry, artifacts, seams, inconsistent, unnatural",
|
| 709 |
+
num_inference_steps=20,
|
| 710 |
+
guidance_scale=7.5,
|
| 711 |
+
strength=float(strength)
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
def _generate_background_handler(
|
| 715 |
self,
|
| 716 |
image: Image.Image,
|
|
|
|
| 720 |
negative_prompt: str,
|
| 721 |
steps: int,
|
| 722 |
guidance: float,
|
| 723 |
+
feather_radius: int,
|
| 724 |
+
enhance_dark_edges: bool = False
|
| 725 |
) -> Tuple[Optional[Image.Image], Optional[Image.Image], Optional[Image.Image], Optional[Image.Image], str]:
|
| 726 |
"""Handler for background generation"""
|
| 727 |
if image is None:
|
|
|
|
| 739 |
|
| 740 |
result = generate_fn(
|
| 741 |
image, prompt, combination_mode, focus_mode,
|
| 742 |
+
negative_prompt, steps, guidance, feather_radius, enhance_dark_edges
|
| 743 |
)
|
| 744 |
|
| 745 |
if result["success"]:
|
|
|
|
| 767 |
negative_prompt: str,
|
| 768 |
steps: int,
|
| 769 |
guidance: float,
|
| 770 |
+
feather_radius: int,
|
| 771 |
+
enhance_dark_edges: bool = False
|
| 772 |
) -> Dict[str, Any]:
|
| 773 |
"""Core background generation with models"""
|
| 774 |
if not self.background_engine.is_initialized:
|
|
|
|
| 784 |
num_inference_steps=int(steps),
|
| 785 |
guidance_scale=float(guidance),
|
| 786 |
enable_prompt_enhancement=True,
|
| 787 |
+
feather_radius=int(feather_radius),
|
| 788 |
+
enhance_dark_edges=enhance_dark_edges
|
| 789 |
)
|
| 790 |
|
| 791 |
+
return result
|
| 792 |
+
|
| 793 |
+
def _create_3d_tab(self):
|
| 794 |
+
"""Create Style Transfer tab - converts images to various artistic styles"""
|
| 795 |
+
with gr.Row():
|
| 796 |
+
# Left Panel: Input & Settings
|
| 797 |
+
with gr.Column(scale=1, elem_classes="feature-card"):
|
| 798 |
+
gr.Markdown("### π¨ AI Style Transfer")
|
| 799 |
+
|
| 800 |
+
# How It Works Guide
|
| 801 |
+
gr.HTML("""
|
| 802 |
+
<div class="quality-banner">
|
| 803 |
+
<strong>π Transform Your Photos</strong><br><br>
|
| 804 |
+
Convert your images into <strong>stunning artistic styles</strong>!<br><br>
|
| 805 |
+
<strong>π¨ Single Styles:</strong> Pure artistic transformations<br>
|
| 806 |
+
<strong>π Style Blends:</strong> Unique combinations for distinctive looks<br><br>
|
| 807 |
+
<strong>π‘ Tips:</strong><br>
|
| 808 |
+
β’ Use <strong>Seed</strong> to recreate the exact same result<br>
|
| 809 |
+
β’ Try different blends for unique artistic effects
|
| 810 |
+
</div>
|
| 811 |
+
""")
|
| 812 |
+
|
| 813 |
+
# Step 1: Upload
|
| 814 |
+
gr.Markdown("#### Step 1: Upload Image")
|
| 815 |
+
style3d_image_input = gr.Image(
|
| 816 |
+
label="Upload Your Image",
|
| 817 |
+
type="pil",
|
| 818 |
+
height=280
|
| 819 |
+
)
|
| 820 |
+
|
| 821 |
+
# Step 2: Choose Style
|
| 822 |
+
gr.Markdown("#### Step 2: Choose Style")
|
| 823 |
+
|
| 824 |
+
# Hidden state to track which mode is active (updated by tab selection)
|
| 825 |
+
is_blend_mode = gr.State(value=False)
|
| 826 |
+
|
| 827 |
+
with gr.Tabs() as style_tabs:
|
| 828 |
+
with gr.TabItem("π¨ Single Styles", id="single_tab") as single_tab:
|
| 829 |
+
style_dropdown = gr.Dropdown(
|
| 830 |
+
choices=self.style_engine.get_style_choices(),
|
| 831 |
+
value="π¬ 3D Cartoon",
|
| 832 |
+
label="Art Style",
|
| 833 |
+
info="Select a single artistic style"
|
| 834 |
+
)
|
| 835 |
+
|
| 836 |
+
style_strength = gr.Slider(
|
| 837 |
+
label="Style Strength",
|
| 838 |
+
minimum=0.3,
|
| 839 |
+
maximum=0.7,
|
| 840 |
+
value=0.50,
|
| 841 |
+
step=0.05,
|
| 842 |
+
info="Lower = keep more original | Higher = stronger style (0.45-0.55 recommended)"
|
| 843 |
+
)
|
| 844 |
+
|
| 845 |
+
with gr.TabItem("π Style Blends", id="blend_tab") as blend_tab:
|
| 846 |
+
blend_dropdown = gr.Dropdown(
|
| 847 |
+
choices=self.style_engine.get_blend_choices(),
|
| 848 |
+
value=self.style_engine.get_blend_choices()[0] if self.style_engine.get_blend_choices() else None,
|
| 849 |
+
label="Blend Preset",
|
| 850 |
+
info="Pre-configured style combinations"
|
| 851 |
+
)
|
| 852 |
+
gr.HTML("""
|
| 853 |
+
<div style="padding: 8px; background: #f0f4ff; border-radius: 6px; font-size: 12px; margin-top: 8px;">
|
| 854 |
+
<strong>Available Blends:</strong><br>
|
| 855 |
+
β’ π 3D Anime Fusion - 3D + Anime linework<br>
|
| 856 |
+
β’ π Dreamy Watercolor - Fantasy + Watercolor<br>
|
| 857 |
+
β’ π Anime Storybook - Anime + Fantasy<br>
|
| 858 |
+
β’ π Renaissance Portrait - Classical oil painting<br>
|
| 859 |
+
β’ πΉοΈ Retro Game Art - Enhanced pixel art
|
| 860 |
+
</div>
|
| 861 |
+
""")
|
| 862 |
+
|
| 863 |
+
# Face Restore option for identity preservation
|
| 864 |
+
face_restore = gr.Checkbox(
|
| 865 |
+
label="π‘οΈ Face Restore (Preserve Identity)",
|
| 866 |
+
value=False,
|
| 867 |
+
info="Enable to better preserve facial features and prevent identity changes"
|
| 868 |
+
)
|
| 869 |
+
gr.HTML("""
|
| 870 |
+
<div style="padding: 6px 8px; background: #fff3cd; border-radius: 4px; font-size: 11px; margin-top: 4px;">
|
| 871 |
+
<strong>π‘ When to use:</strong> Enable if the style changes the person's face, age, or ethnicity too much.
|
| 872 |
+
Auto-reduces strength to preserve original features.
|
| 873 |
+
</div>
|
| 874 |
+
""")
|
| 875 |
+
|
| 876 |
+
with gr.Accordion("βοΈ Advanced Settings", open=False):
|
| 877 |
+
guidance_scale = gr.Slider(
|
| 878 |
+
label="Guidance Scale",
|
| 879 |
+
minimum=5.0,
|
| 880 |
+
maximum=12.0,
|
| 881 |
+
value=7.5,
|
| 882 |
+
step=0.5,
|
| 883 |
+
info="How closely to follow the style"
|
| 884 |
+
)
|
| 885 |
+
|
| 886 |
+
num_steps = gr.Slider(
|
| 887 |
+
label="Quality Steps",
|
| 888 |
+
minimum=20,
|
| 889 |
+
maximum=50,
|
| 890 |
+
value=30,
|
| 891 |
+
step=5,
|
| 892 |
+
info="More steps = better quality but slower"
|
| 893 |
+
)
|
| 894 |
+
|
| 895 |
+
custom_prompt = gr.Textbox(
|
| 896 |
+
label="Additional Description (optional)",
|
| 897 |
+
placeholder="e.g., smiling, dramatic lighting, vibrant colors...",
|
| 898 |
+
lines=2
|
| 899 |
+
)
|
| 900 |
+
|
| 901 |
+
gr.Markdown("##### π² Seed Control")
|
| 902 |
+
randomize_seed = gr.Checkbox(
|
| 903 |
+
label="Randomize Seed",
|
| 904 |
+
value=True,
|
| 905 |
+
info="Uncheck to use manual seed for reproducible results"
|
| 906 |
+
)
|
| 907 |
+
|
| 908 |
+
seed_input = gr.Number(
|
| 909 |
+
label="Manual Seed",
|
| 910 |
+
value=42,
|
| 911 |
+
precision=0,
|
| 912 |
+
info="Use same seed to reproduce exact results"
|
| 913 |
+
)
|
| 914 |
+
|
| 915 |
+
# Step 3: Generate
|
| 916 |
+
gr.Markdown("#### Step 3: Generate")
|
| 917 |
+
|
| 918 |
+
gr.HTML("""
|
| 919 |
+
<div class="patience-banner">
|
| 920 |
+
<strong>β±οΈ Generation Time:</strong> ~20-30 seconds.
|
| 921 |
+
First-time model loading may take 30-60 seconds.
|
| 922 |
+
</div>
|
| 923 |
+
""")
|
| 924 |
+
|
| 925 |
+
generate_style_btn = gr.Button(
|
| 926 |
+
"π¨ Transform Image",
|
| 927 |
+
variant="primary",
|
| 928 |
+
elem_classes="primary-button",
|
| 929 |
+
size="lg"
|
| 930 |
+
)
|
| 931 |
+
|
| 932 |
+
# Right Panel: Output
|
| 933 |
+
with gr.Column(scale=1, elem_classes="feature-card"):
|
| 934 |
+
gr.Markdown("### π€ Results")
|
| 935 |
+
|
| 936 |
+
with gr.Tabs():
|
| 937 |
+
with gr.TabItem("Stylized Result"):
|
| 938 |
+
style3d_output = gr.Image(
|
| 939 |
+
label="Stylized Result",
|
| 940 |
+
elem_classes=["result-gallery"]
|
| 941 |
+
)
|
| 942 |
+
|
| 943 |
+
with gr.TabItem("Original"):
|
| 944 |
+
style3d_original = gr.Image(
|
| 945 |
+
label="Original Image",
|
| 946 |
+
elem_classes=["result-gallery"]
|
| 947 |
+
)
|
| 948 |
+
|
| 949 |
+
with gr.TabItem("Comparison"):
|
| 950 |
+
with gr.Row():
|
| 951 |
+
style3d_compare_original = gr.Image(
|
| 952 |
+
label="Before",
|
| 953 |
+
elem_classes=["result-gallery"]
|
| 954 |
+
)
|
| 955 |
+
style3d_compare_result = gr.Image(
|
| 956 |
+
label="After",
|
| 957 |
+
elem_classes=["result-gallery"]
|
| 958 |
+
)
|
| 959 |
+
|
| 960 |
+
with gr.Row():
|
| 961 |
+
style3d_status_output = gr.Textbox(
|
| 962 |
+
label="Status",
|
| 963 |
+
value="Ready! Upload an image and select a style to transform.",
|
| 964 |
+
interactive=False,
|
| 965 |
+
elem_classes=["status-panel"],
|
| 966 |
+
scale=3
|
| 967 |
+
)
|
| 968 |
+
seed_output = gr.Number(
|
| 969 |
+
label="Seed Used",
|
| 970 |
+
value=0,
|
| 971 |
+
interactive=False,
|
| 972 |
+
precision=0,
|
| 973 |
+
scale=1
|
| 974 |
+
)
|
| 975 |
+
|
| 976 |
+
with gr.Row():
|
| 977 |
+
clear_style_btn = gr.Button(
|
| 978 |
+
"Clear All",
|
| 979 |
+
elem_classes=["secondary-button"]
|
| 980 |
+
)
|
| 981 |
+
memory_style_btn = gr.Button(
|
| 982 |
+
"Clean Memory",
|
| 983 |
+
elem_classes=["secondary-button"]
|
| 984 |
+
)
|
| 985 |
+
|
| 986 |
+
# Event handlers - detect mode from TAB selection (not just dropdown)
|
| 987 |
+
single_tab.select(
|
| 988 |
+
fn=lambda: False, # Single Styles tab clicked -> is_blend = False
|
| 989 |
+
inputs=[],
|
| 990 |
+
outputs=[is_blend_mode]
|
| 991 |
+
)
|
| 992 |
+
|
| 993 |
+
blend_tab.select(
|
| 994 |
+
fn=lambda: True, # Style Blends tab clicked -> is_blend = True
|
| 995 |
+
inputs=[],
|
| 996 |
+
outputs=[is_blend_mode]
|
| 997 |
+
)
|
| 998 |
+
|
| 999 |
+
generate_style_btn.click(
|
| 1000 |
+
fn=self._generate_3d_style_handler,
|
| 1001 |
+
inputs=[
|
| 1002 |
+
style3d_image_input, style_dropdown, blend_dropdown, is_blend_mode,
|
| 1003 |
+
style_strength, guidance_scale, num_steps, custom_prompt,
|
| 1004 |
+
randomize_seed, seed_input, face_restore
|
| 1005 |
+
],
|
| 1006 |
+
outputs=[
|
| 1007 |
+
style3d_output, style3d_original,
|
| 1008 |
+
style3d_compare_original, style3d_compare_result,
|
| 1009 |
+
style3d_status_output, seed_output
|
| 1010 |
+
]
|
| 1011 |
+
)
|
| 1012 |
+
|
| 1013 |
+
clear_style_btn.click(
|
| 1014 |
+
fn=lambda: (None, None, None, None, "Ready! Upload an image and select a style to transform.", 0),
|
| 1015 |
+
outputs=[
|
| 1016 |
+
style3d_output, style3d_original,
|
| 1017 |
+
style3d_compare_original, style3d_compare_result,
|
| 1018 |
+
style3d_status_output, seed_output
|
| 1019 |
+
]
|
| 1020 |
+
)
|
| 1021 |
+
|
| 1022 |
+
memory_style_btn.click(
|
| 1023 |
+
fn=self._cleanup_3d_memory,
|
| 1024 |
+
outputs=[style3d_status_output]
|
| 1025 |
+
)
|
| 1026 |
+
|
| 1027 |
+
def _generate_3d_style_handler(
|
| 1028 |
+
self,
|
| 1029 |
+
image: Image.Image,
|
| 1030 |
+
style_choice: str,
|
| 1031 |
+
blend_choice: str,
|
| 1032 |
+
is_blend_mode: bool,
|
| 1033 |
+
strength: float,
|
| 1034 |
+
guidance_scale: float,
|
| 1035 |
+
num_steps: int,
|
| 1036 |
+
custom_prompt: str,
|
| 1037 |
+
randomize_seed: bool,
|
| 1038 |
+
manual_seed: int,
|
| 1039 |
+
face_restore: bool = False
|
| 1040 |
+
) -> Tuple[Optional[Image.Image], Optional[Image.Image], Optional[Image.Image], Optional[Image.Image], str, int]:
|
| 1041 |
+
"""Handler for style transfer generation"""
|
| 1042 |
+
if image is None:
|
| 1043 |
+
return None, None, None, None, "Please upload an image first!", 0
|
| 1044 |
+
|
| 1045 |
+
try:
|
| 1046 |
+
# Determine style key based on mode (detected from last dropdown interaction)
|
| 1047 |
+
if is_blend_mode:
|
| 1048 |
+
style_key = self.style_engine.get_blend_key_from_choice(blend_choice)
|
| 1049 |
+
is_blend = True
|
| 1050 |
+
else:
|
| 1051 |
+
style_key = self.style_engine.get_style_key_from_choice(style_choice)
|
| 1052 |
+
is_blend = False
|
| 1053 |
+
|
| 1054 |
+
# Handle seed
|
| 1055 |
+
seed = -1 if randomize_seed else int(manual_seed)
|
| 1056 |
+
|
| 1057 |
+
if SPACES_AVAILABLE:
|
| 1058 |
+
generate_fn = spaces.GPU(duration=120)(self._3d_style_generate_core)
|
| 1059 |
+
else:
|
| 1060 |
+
generate_fn = self._3d_style_generate_core
|
| 1061 |
+
|
| 1062 |
+
result = generate_fn(
|
| 1063 |
+
image, style_key, is_blend, strength,
|
| 1064 |
+
guidance_scale, num_steps, custom_prompt, seed, face_restore
|
| 1065 |
+
)
|
| 1066 |
+
|
| 1067 |
+
if result["success"]:
|
| 1068 |
+
stylized = result["stylized_image"]
|
| 1069 |
+
style_name = result.get("style_name", "Style")
|
| 1070 |
+
seed_used = result.get("seed_used", 0)
|
| 1071 |
+
return (
|
| 1072 |
+
stylized,
|
| 1073 |
+
image,
|
| 1074 |
+
image,
|
| 1075 |
+
stylized,
|
| 1076 |
+
f"β {style_name} completed! (seed: {seed_used})",
|
| 1077 |
+
seed_used
|
| 1078 |
+
)
|
| 1079 |
+
else:
|
| 1080 |
+
error_msg = result.get("error", "Unknown error")
|
| 1081 |
+
return None, None, None, None, f"Error: {error_msg}", 0
|
| 1082 |
+
|
| 1083 |
+
except Exception as e:
|
| 1084 |
+
logger.error(f"Style generation failed: {e}")
|
| 1085 |
+
return None, None, None, None, f"Error: {str(e)}", 0
|
| 1086 |
+
|
| 1087 |
+
def _3d_style_generate_core(
|
| 1088 |
+
self,
|
| 1089 |
+
image: Image.Image,
|
| 1090 |
+
style_key: str,
|
| 1091 |
+
is_blend: bool,
|
| 1092 |
+
strength: float,
|
| 1093 |
+
guidance_scale: float,
|
| 1094 |
+
num_steps: int,
|
| 1095 |
+
custom_prompt: str,
|
| 1096 |
+
seed: int,
|
| 1097 |
+
face_restore: bool = False
|
| 1098 |
+
) -> dict:
|
| 1099 |
+
"""Core style transfer generation"""
|
| 1100 |
+
return self.style_engine.generate_all_outputs(
|
| 1101 |
+
image=image,
|
| 1102 |
+
style_key=style_key,
|
| 1103 |
+
strength=float(strength),
|
| 1104 |
+
guidance_scale=float(guidance_scale),
|
| 1105 |
+
num_inference_steps=int(num_steps),
|
| 1106 |
+
custom_prompt=custom_prompt if custom_prompt else "",
|
| 1107 |
+
seed=seed,
|
| 1108 |
+
is_blend=is_blend,
|
| 1109 |
+
face_restore=face_restore
|
| 1110 |
+
)
|
| 1111 |
+
|
| 1112 |
+
def _cleanup_3d_memory(self) -> str:
|
| 1113 |
+
"""Clean up 3D engine memory"""
|
| 1114 |
+
self.style_engine.unload_model()
|
| 1115 |
+
return "Memory cleaned!"
|
| 1116 |
+
|