"""Gradio demo: "enhance" hair in a selfie with a Canny ControlNet pass, then
apply a cute-cartoon style with an img2img pass."""
import logging

import gradio as gr
import numpy as np
import torch
from diffusers import (
    ControlNetModel,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    StableDiffusionControlNetPipeline,
    StableDiffusionImg2ImgPipeline,
)
from PIL import Image

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use the GPU in half precision when available, otherwise the CPU in fp32.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
logger.info("Using device: %s, dtype: %s", device, torch_dtype)


def create_hair_mask(image):
    """Return a crude binary "hair" mask for *image* as a single-channel PIL image.

    Placeholder only: it never looks at pixel values, only at the image size.
    The mask is 255 over the top 40% of rows and the central 60% of columns
    (20%-80% of the width), 0 elsewhere. A real app would use a face-parsing
    model such as BiSeNet. NOTE(review): nothing in this file calls it.
    """
    img_np = np.array(image)
    height, width = img_np.shape[:2]

    mask = np.zeros((height, width), dtype=np.uint8)
    mask[0:int(height * 0.4), int(width * 0.2):int(width * 0.8)] = 255
    return Image.fromarray(mask)


def _enable_fast_inference(pipe):
    """Swap in the DPM-Solver multistep scheduler and memory-saving attention on *pipe*."""
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_attention_slicing(slice_size=1)
    if device == "cuda":
        # xformers is an optional extra: a missing or incompatible build raises here.
        # Treat it as best-effort so it cannot push load_models into its fallback
        # and discard otherwise working pipelines.
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as err:
            logger.warning("xformers attention unavailable, using sliced attention: %s", err)


def load_models():
    """Load the hair-enhancement and cartoon-style pipelines.

    Returns:
        ``(sd_pipe, style_pipe)``. On the primary path ``sd_pipe`` is a Canny
        ``StableDiffusionControlNetPipeline`` and ``style_pipe`` a
        ``StableDiffusionImg2ImgPipeline``. If anything on the primary path
        raises, BOTH names are bound to the same SD v1.4 img2img pipeline, so
        callers must not assume ``sd_pipe`` is a ControlNet pipeline.

    Raises:
        RuntimeError: if the fallback model cannot be loaded either.
    """
    # Deliberately not wrapped in torch.inference_mode(). Weights created under
    # inference mode become inference tensors, which cannot be modified in place
    # later (e.g. by LoRA fusing). The diffusers pipelines already disable
    # autograd when they are called.
    try:
        logger.info("Loading ControlNet model...")
        controlnet = ControlNetModel.from_pretrained(
            "lllyasviel/sd-controlnet-canny", torch_dtype=torch_dtype
        ).to(device)

        logger.info("Loading Stable Diffusion pipeline...")
        # NOTE(review): the "runwayml/..." repo id may no longer resolve on the
        # Hugging Face Hub; confirm, or point at a maintained mirror.
        sd_pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=controlnet,
            torch_dtype=torch_dtype,
            safety_checker=None,  # Disabled for speed
            variant="fp16" if device == "cuda" else None,
            use_safetensors=True,
        ).to(device)
        _enable_fast_inference(sd_pipe)

        logger.info("Loading Cute Cartoon style model...")
        # NOTE(review): the repo name suggests a FLUX LoRA rather than a full
        # SD 1.x img2img checkpoint. If from_pretrained cannot load it, this
        # whole function drops to the fallback below; confirm.
        style_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            "AIGCDuckBoss/fluxlora_cute-cartoon",
            torch_dtype=torch_dtype,
            safety_checker=None,
            variant="fp16" if device == "cuda" else None,
            use_safetensors=True,
        ).to(device)
        _enable_fast_inference(style_pipe)

        logger.info("All models loaded successfully!")
        return sd_pipe, style_pipe
    except Exception as e:
        logger.exception("Error loading models: %s", e)

    # Fallback: a single plain img2img model serves both stages.
    try:
        logger.info("Attempting to load fallback models...")
        sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            torch_dtype=torch_dtype,
            safety_checker=None,
        ).to(device)
    except Exception as e2:
        logger.exception("Fallback model loading failed: %s", e2)
        raise RuntimeError(
            "Failed to load any models. Please check the logs for details."
        ) from e2
    return sd_pipe, sd_pipe


def _canny_control_image(image):
    """Return a 3-channel PIL image of the Canny edges (thresholds 100/200) of *image*."""
    import cv2  # OpenCV is only needed here, so it is imported lazily.

    edges = cv2.Canny(np.array(image), 100, 200)
    # Replicate the single-channel edge map into 3 channels for the ControlNet input.
    return Image.fromarray(np.stack([edges, edges, edges], axis=2))


def enhance_and_stylize(input_image, sd_pipe, style_pipe, enhancement_strength=0.6,
                        cartoon_strength=0.7):
    """Run the two-stage hair-enhancement + cute-cartoon pipeline on a PIL image.

    Stage 1 runs only when *sd_pipe* is a ControlNet pipeline. It generates a
    new 384x384 image from the input's Canny edges plus a text prompt.
    *enhancement_strength* scales ``controlnet_conditioning_scale``
    (``0.8 * enhancement_strength``), i.e. how strongly the edges steer
    generation. In load_models' fallback mode this stage is skipped.

    Stage 2 runs img2img with *style_pipe* at
    ``strength = 0.6 * cartoon_strength``. The result is resized back to the
    input's original size.

    Returns:
        The stylized PIL image. Returns ``None`` when *input_image* is
        ``None``. Returns the caller's unmodified *input_image* when any stage
        raises; the error is logged.
    """
    if input_image is None:
        return None

    try:
        original_size = input_image.size
        # Squash to a fixed 384x384 working resolution (a multiple of 8, as the
        # SD pipelines require). This does NOT keep the aspect ratio; the final
        # resize restores the original dimensions. RGB conversion guards
        # against RGBA/greyscale uploads.
        work_image = input_image.convert("RGB").resize((384, 384), Image.LANCZOS)

        if isinstance(sd_pipe, StableDiffusionControlNetPipeline):
            hair_prompt = "portrait photo of the exact same person with slightly fuller hair, preserve facial features, same composition, same colors"
            negative_prompt = "different person, unrealistic, distorted face, bad anatomy, different composition"

            logger.info("Generating enhanced image...")
            # This pass is text-to-image guided by the edge map only; the
            # input's colours are not fed in.
            enhanced_image = sd_pipe(
                prompt=hair_prompt,
                negative_prompt=negative_prompt,
                image=_canny_control_image(work_image),
                guidance_scale=5.5,  # Lower guidance for more faithful reproduction
                num_inference_steps=10,
                controlnet_conditioning_scale=0.8 * enhancement_strength,
            ).images[0]
        else:
            # Fallback mode: sd_pipe is a plain img2img pipeline and rejects
            # ControlNet arguments, so skip the hair pass instead of failing
            # the whole request.
            logger.warning("ControlNet pipeline unavailable; skipping hair enhancement pass.")
            enhanced_image = work_image

        cartoon_prompt = "portrait of the exact same person in cute cartoon style, preserve facial features, same composition, same colors, adorable, charming"

        logger.info("Applying Cute Cartoon style...")
        cartoon_image = style_pipe(
            prompt=cartoon_prompt,
            image=enhanced_image,
            # Lower img2img strength keeps more of the stage-1 image.
            strength=0.6 * cartoon_strength,
            guidance_scale=6.0,
            num_inference_steps=10,
        ).images[0]

        return cartoon_image.resize(original_size, Image.LANCZOS)
    except Exception as e:
        logger.exception("Error in image processing: %s", e)
        # Hand back the caller's image, not the resized working copy.
        return input_image


# Load models at startup so each request does not reload them. On failure the
# pipelines stay None and process_image reports the problem in the UI.
sd_pipe = style_pipe = None
try:
    logger.info("Starting model loading...")
    sd_pipe, style_pipe = load_models()
except Exception as e:
    logger.exception("Failed to initialize models: %s", e)


def process_image(input_image, hair_enhancement, cartoon_style):
    """Gradio click handler: return ``(original, stylized)`` for side-by-side display.

    Raises:
        gr.Error: if the models failed to load at startup. This surfaces the
            message in the UI; an Image output cannot display a text message.
    """
    if input_image is None:
        return None, None

    # Checked outside the try below so the catch-all does not swallow gr.Error.
    if sd_pipe is None or style_pipe is None:
        raise gr.Error("Failed to load models. Please check the logs.")

    try:
        result = enhance_and_stylize(
            input_image,
            sd_pipe,
            style_pipe,
            enhancement_strength=hair_enhancement,
            cartoon_strength=cartoon_style,
        )
        return input_image, result
    except Exception as e:
        logger.exception("Error in process_image: %s", e)
        return input_image, input_image


# Create the Gradio interface
with gr.Blocks(title="Cute Cartoon Hair Enhancement") as demo:
    gr.Markdown("# Cute Cartoon-Style Hair Enhancement")
    gr.Markdown("Upload a selfie to enhance hair and apply a Cute Cartoon art style")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload Selfie", type="pil")
            with gr.Row():
                hair_enhancement = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.6, step=0.1,
                    label="Hair Enhancement Strength",
                )
                cartoon_style = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.7, step=0.1,
                    label="Cartoon Style Strength",
                )
            process_btn = gr.Button("Enhance & Stylize")
        with gr.Column():
            output_original = gr.Image(label="Original Image")
            output_stylized = gr.Image(label="Cute Cartoon with Enhanced Hair")

    process_btn.click(
        fn=process_image,
        inputs=[input_image, hair_enhancement, cartoon_style],
        outputs=[output_original, output_stylized],
    )

    # NOTE(review): step 1 says the hair region is identified, but
    # create_hair_mask is not wired into the pipeline.
    gr.Markdown("### How it works")
    gr.Markdown("1. Identifies the hair region in your selfie")
    gr.Markdown("2. Enhances hair volume/fullness using AI")
    gr.Markdown("3. Applies Cute Cartoon art style to the entire image")
    gr.Markdown("4. Displays the before and after comparison")

# Launch the app
if __name__ == "__main__":
    demo.launch()