Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, DDIMScheduler | |
| from diffusers import StableDiffusionImg2ImgPipeline | |
| import numpy as np | |
| from PIL import Image | |
| import logging | |
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Check if CUDA is available; use half precision only on GPU (fp16 is poorly
# supported on CPU), full fp32 otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
logger.info(f"Using device: {device}, dtype: {torch_dtype}")
# Function to create hair mask (simplified version)
def create_hair_mask(image, top_fraction=0.4, left_fraction=0.2, right_fraction=0.8):
    """Create a crude rectangular mask over the typical hair region.

    A real app would use a proper face-parsing model (e.g. BiSeNet); this
    placeholder simply marks a band across the top of the image. The region
    fractions are parameterized (defaults match the original hard-coded
    values) so callers can tune the band without editing this function.

    Args:
        image: source image (anything ``np.array`` accepts, e.g. a PIL image).
        top_fraction (float): fraction of the height, from the top, to mask.
        left_fraction (float): left edge of the band as a fraction of width.
        right_fraction (float): right edge of the band as a fraction of width.

    Returns:
        PIL.Image.Image: single-channel mask, 255 inside the band, 0 elsewhere.
    """
    img_np = np.array(image)
    height, width = img_np.shape[:2]
    # Start all-black (excluded), then paint the hair band white (included).
    mask = np.zeros((height, width), dtype=np.uint8)
    mask[0:int(height * top_fraction), int(width * left_fraction):int(width * right_fraction)] = 255
    return Image.fromarray(mask)
# Load models at startup to avoid reloading for each inference
def load_models():
    """Download/load the diffusion pipelines used by the app.

    Returns:
        tuple: ``(sd_pipe, style_pipe)`` — a ControlNet text-to-image
        pipeline for the hair-enhancement pass and an img2img pipeline for
        the cartoon-style pass. In fallback mode the same pipeline object is
        returned for both.

    Raises:
        RuntimeError: if both the primary and the fallback models fail to load.
    """
    try:
        logger.info("Loading ControlNet model...")
        # Canny-edge ControlNet conditions generation on the input's edges,
        # preserving the subject's structure/composition.
        controlnet = ControlNetModel.from_pretrained(
            "lllyasviel/sd-controlnet-canny", torch_dtype=torch_dtype
        ).to(device)
        logger.info("Loading Stable Diffusion pipeline...")
        # Use a smaller, faster model instead of the full SD model
        sd_pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=controlnet,
            torch_dtype=torch_dtype,
            safety_checker=None,  # Disable safety checker for speed
            # Use low-memory fp16 weight variant when running on GPU
            variant="fp16" if device == "cuda" else None,
            use_safetensors=True
        ).to(device)
        # Swap in a faster multistep scheduler so few inference steps suffice
        from diffusers import DPMSolverMultistepScheduler
        sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
        # Performance optimizations (lower peak memory; xformers is CUDA-only)
        sd_pipe.enable_attention_slicing(slice_size=1)
        if device == "cuda":
            sd_pipe.enable_xformers_memory_efficient_attention()
        logger.info("Loading Cute Cartoon style model...")
        # NOTE(review): this repo id looks like a Flux LoRA, not a full SD
        # checkpoint — loading it via StableDiffusionImg2ImgPipeline may fail
        # and land in the fallback branch below. Confirm the intended model.
        style_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            "AIGCDuckBoss/fluxlora_cute-cartoon",
            torch_dtype=torch_dtype,
            safety_checker=None,
            variant="fp16" if device == "cuda" else None,
            use_safetensors=True
        ).to(device)
        # Use the same faster scheduler for style_pipe
        style_pipe.scheduler = DPMSolverMultistepScheduler.from_config(style_pipe.scheduler.config)
        # Performance optimizations for style_pipe
        style_pipe.enable_attention_slicing(slice_size=1)
        if device == "cuda":
            style_pipe.enable_xformers_memory_efficient_attention()
        logger.info("All models loaded successfully!")
        return sd_pipe, style_pipe
    except Exception as e:
        logger.error(f"Error loading models: {str(e)}")
        # Fallback to a simpler model if the main ones fail
        try:
            logger.info("Attempting to load fallback models...")
            sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                "CompVis/stable-diffusion-v1-4",
                torch_dtype=torch_dtype,
                safety_checker=None
            ).to(device)
            # Use the same model for both pipelines in fallback mode
            # (no ControlNet support in this degraded configuration).
            return sd_pipe, sd_pipe
        except Exception as e2:
            logger.error(f"Fallback model loading failed: {str(e2)}")
            raise RuntimeError("Failed to load any models. Please check the logs for details.")
# Function to enhance hair and apply Cute Cartoon style
def enhance_and_stylize(input_image, sd_pipe, style_pipe, enhancement_strength=0.6, cartoon_strength=0.7):
    """Enhance hair fullness, then restyle the image as a cute cartoon.

    Two passes: (1) a ControlNet pass conditioned on the image's Canny edges
    to add hair volume while preserving composition, then (2) an img2img
    pass that applies the cartoon style to the enhanced result.

    Args:
        input_image (PIL.Image.Image | None): the uploaded selfie.
        sd_pipe: ControlNet text-to-image pipeline (hair-enhancement pass).
        style_pipe: img2img pipeline (cartoon-style pass).
        enhancement_strength (float): scales ControlNet conditioning (0-1).
        cartoon_strength (float): scales img2img strength (0-1).

    Returns:
        PIL.Image.Image | None: stylized image at the original resolution;
        the unmodified input if processing fails; None if no input was given.
    """
    if input_image is None:
        return None
    try:
        # Remember the upload's resolution so the result can be scaled back.
        original_size = input_image.size
        # Normalize to 3-channel RGB first: uploads may be RGBA/palette/L and
        # cv2.Canny requires an 8-bit image. Then shrink to a fixed 384x384
        # working size. NOTE: this does NOT preserve aspect ratio — the final
        # resize restores the original shape, so distortion is transient.
        work_image = input_image.convert("RGB").resize((384, 384), Image.LANCZOS)
        # Generate canny edges for ControlNet to preserve structure
        import cv2
        img_np = np.array(work_image)
        canny_np = cv2.Canny(img_np, 100, 200)
        # Replicate the single edge channel into 3 channels for the pipeline.
        canny_image = Image.fromarray(np.repeat(canny_np[:, :, None], 3, axis=2))
        # Prompts emphasize identity preservation over dramatic change.
        hair_prompt = "portrait photo of the exact same person with slightly fuller hair, preserve facial features, same composition, same colors"
        negative_prompt = "different person, unrealistic, distorted face, bad anatomy, different composition"
        # First pass: enhance hair using ControlNet.
        logger.info("Generating enhanced image...")
        enhanced_image = sd_pipe(
            prompt=hair_prompt,
            negative_prompt=negative_prompt,
            image=canny_image,
            guidance_scale=5.5,  # Lower guidance for more faithful reproduction
            num_inference_steps=10,
            # Scale conditioning by the user-selected enhancement strength.
            controlnet_conditioning_scale=0.8 * enhancement_strength,
        ).images[0]
        # Second pass: apply the cartoon style with reduced strength so the
        # enhanced content is preserved.
        cartoon_prompt = "portrait of the exact same person in cute cartoon style, preserve facial features, same composition, same colors, adorable, charming"
        logger.info("Applying Cute Cartoon style...")
        cartoon_image = style_pipe(
            prompt=cartoon_prompt,
            image=enhanced_image,
            # Lower strength preserves more of the original image
            strength=0.6 * cartoon_strength,
            guidance_scale=6.0,
            num_inference_steps=10,
        ).images[0]
        # Resize back to the uploaded image's dimensions.
        return cartoon_image.resize(original_size, Image.LANCZOS)
    except Exception:
        # Fail soft: hand the original image back rather than crashing the UI.
        logger.exception("Error in image processing")
        return input_image
# Load models at startup (import time) so every Gradio request reuses them.
try:
    logger.info("Starting model loading...")
    sd_pipe, style_pipe = load_models()
except Exception as e:
    logger.error(f"Failed to initialize models: {str(e)}")
    # NOTE: on failure, sd_pipe/style_pipe remain undefined at module level;
    # process_image checks globals() before using them.
    # We'll handle this in the process_image function
# Gradio callback: run the two-pass enhancement pipeline on the upload.
def process_image(input_image, hair_enhancement, cartoon_style):
    """Handle a click of the "Enhance & Stylize" button.

    Args:
        input_image (PIL.Image.Image | None): image from the upload widget.
        hair_enhancement (float): slider value for hair-enhancement strength.
        cartoon_style (float): slider value for cartoon-style strength.

    Returns:
        tuple: ``(original, stylized)`` for the two gr.Image outputs.
        On model-load failure the second element is None (an empty image
        component) — returning a gr.update with a string here would put a
        non-image value into a gr.Image output and error inside Gradio.
    """
    if input_image is None:
        return None, None
    try:
        # Models may have failed to load at startup; bail out gracefully.
        if 'sd_pipe' not in globals() or 'style_pipe' not in globals():
            logger.error("Models are not loaded; returning original image only.")
            return input_image, None
        # Process the image
        result = enhance_and_stylize(
            input_image,
            sd_pipe,
            style_pipe,
            enhancement_strength=hair_enhancement,
            cartoon_strength=cartoon_style
        )
        # Return both original and processed images for comparison
        return input_image, result
    except Exception:
        # Fail soft: show the unprocessed image in both slots.
        logger.exception("Error in process_image")
        return input_image, input_image
# Create the Gradio interface. Component creation order inside the context
# managers determines the layout, so statement order matters here.
with gr.Blocks(title="Cute Cartoon Hair Enhancement") as demo:
    gr.Markdown("# Cute Cartoon-Style Hair Enhancement")
    gr.Markdown("Upload a selfie to enhance hair and apply a Cute Cartoon art style")
    with gr.Row():
        with gr.Column():
            # Left column: input image, strength sliders, and the action button.
            input_image = gr.Image(label="Upload Selfie", type="pil")
            with gr.Row():
                hair_enhancement = gr.Slider(minimum=0.1, maximum=1.0, value=0.6, step=0.1, label="Hair Enhancement Strength")
                cartoon_style = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Cartoon Style Strength")
            process_btn = gr.Button("Enhance & Stylize")
        with gr.Column():
            # Right column: before/after comparison outputs.
            output_original = gr.Image(label="Original Image")
            output_stylized = gr.Image(label="Cute Cartoon with Enhanced Hair")
    # Wire the button to the processing callback defined above.
    process_btn.click(
        fn=process_image,
        inputs=[input_image, hair_enhancement, cartoon_style],
        outputs=[output_original, output_stylized]
    )
    gr.Markdown("### How it works")
    gr.Markdown("1. Identifies the hair region in your selfie")
    gr.Markdown("2. Enhances hair volume/fullness using AI")
    gr.Markdown("3. Applies Cute Cartoon art style to the entire image")
    gr.Markdown("4. Displays the before and after comparison")
# Launch the app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()