Sree1234Job commited on
Commit
a0f12de
·
verified ·
1 Parent(s): fc9b493

Create app,py

Browse files
Files changed (1) hide show
  1. app,py +224 -0
app,py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, DDIMScheduler
4
+ from diffusers import StableDiffusionImg2ImgPipeline
5
+ import numpy as np
6
+ from PIL import Image
7
+ import logging
8
+
9
+ # Set up logging
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
+
13
+ # Check if CUDA is available
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ torch_dtype = torch.float16 if device == "cuda" else torch.float32
16
+
17
+ logger.info(f"Using device: {device}, dtype: {torch_dtype}")
18
+
19
+ # Function to create hair mask (simplified version)
20
+ def create_hair_mask(image):
21
+ # For a real app, you'd use a proper face parsing model like BiSeNet
22
+ # This is a simplified placeholder that creates a basic top-of-head mask
23
+ img_np = np.array(image)
24
+ height, width = img_np.shape[:2]
25
+
26
+ # Create a simple mask for the top portion of the image (where hair typically is)
27
+ mask = np.zeros((height, width), dtype=np.uint8)
28
+ mask[0:int(height * 0.4), int(width * 0.2):int(width * 0.8)] = 255
29
+
30
+ return Image.fromarray(mask)
31
+
32
+ # Load models at startup to avoid reloading for each inference
33
+ @torch.inference_mode()
34
+ def load_models():
35
+ try:
36
+ logger.info("Loading ControlNet model...")
37
+ # Use a more reliable ControlNet model
38
+ controlnet = ControlNetModel.from_pretrained(
39
+ "lllyasviel/sd-controlnet-canny", torch_dtype=torch_dtype
40
+ ).to(device)
41
+
42
+ logger.info("Loading Stable Diffusion pipeline...")
43
+ # Use a smaller, faster model instead of the full SD model
44
+ sd_pipe = StableDiffusionControlNetPipeline.from_pretrained(
45
+ "runwayml/stable-diffusion-v1-5",
46
+ controlnet=controlnet,
47
+ torch_dtype=torch_dtype,
48
+ safety_checker=None, # Disable safety checker for speed
49
+ # Use low-memory variant with VAE
50
+ variant="fp16" if device == "cuda" else None,
51
+ use_safetensors=True
52
+ ).to(device)
53
+
54
+ # Set scheduler to a faster one
55
+ from diffusers import DPMSolverMultistepScheduler
56
+ sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
57
+
58
+ # Performance optimizations
59
+ sd_pipe.enable_attention_slicing(slice_size=1)
60
+ if device == "cuda":
61
+ sd_pipe.enable_xformers_memory_efficient_attention()
62
+
63
+ logger.info("Loading Ghibli style model...")
64
+ # Load a smaller Ghibli style model
65
+ style_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
66
+ "nitrosocke/Ghibli-Diffusion",
67
+ torch_dtype=torch_dtype,
68
+ safety_checker=None,
69
+ variant="fp16" if device == "cuda" else None,
70
+ use_safetensors=True
71
+ ).to(device)
72
+
73
+ # Use the same faster scheduler for style_pipe
74
+ style_pipe.scheduler = DPMSolverMultistepScheduler.from_config(style_pipe.scheduler.config)
75
+
76
+ # Performance optimizations for style_pipe
77
+ style_pipe.enable_attention_slicing(slice_size=1)
78
+ if device == "cuda":
79
+ style_pipe.enable_xformers_memory_efficient_attention()
80
+
81
+ logger.info("All models loaded successfully!")
82
+ return sd_pipe, style_pipe
83
+
84
+ except Exception as e:
85
+ logger.error(f"Error loading models: {str(e)}")
86
+ # Fallback to a simpler model if the main ones fail
87
+ try:
88
+ logger.info("Attempting to load fallback models...")
89
+ sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
90
+ "CompVis/stable-diffusion-v1-4",
91
+ torch_dtype=torch_dtype,
92
+ safety_checker=None
93
+ ).to(device)
94
+
95
+ # Use the same model for both pipelines in fallback mode
96
+ return sd_pipe, sd_pipe
97
+ except Exception as e2:
98
+ logger.error(f"Fallback model loading failed: {str(e2)}")
99
+ raise RuntimeError("Failed to load any models. Please check the logs for details.")
100
+
101
+ # Function to enhance hair and apply Ghibli style
102
+ def enhance_and_stylize(input_image, sd_pipe, style_pipe, enhancement_strength=0.6, ghibli_strength=0.7):
103
+ if input_image is None:
104
+ return None
105
+
106
+ try:
107
+ # Resize image to even smaller dimensions for faster processing
108
+ input_image = input_image.resize((256, 256))
109
+
110
+ # Create hair mask
111
+ hair_mask = create_hair_mask(input_image)
112
+
113
+ # Convert mask to expected format
114
+ mask_image = hair_mask.convert("L")
115
+
116
+ # Generate canny edges for ControlNet
117
+ import cv2
118
+ img_np = np.array(input_image)
119
+ canny_img = cv2.Canny(img_np, 100, 200)
120
+ canny_img = canny_img[:, :, None]
121
+ canny_img = np.concatenate([canny_img, canny_img, canny_img], axis=2)
122
+ canny_image = Image.fromarray(canny_img)
123
+
124
+ # Enhance hair - use even fewer steps for faster generation
125
+ hair_prompt = "portrait photo of person with slightly fuller, naturally grown hair, same face, detailed"
126
+ negative_prompt = "unrealistic, cartoon, distorted face, bad anatomy"
127
+
128
+ # First pass: Enhance hair using ControlNet with fewer steps
129
+ logger.info("Generating enhanced image...")
130
+ enhanced_image = sd_pipe(
131
+ prompt=hair_prompt,
132
+ negative_prompt=negative_prompt,
133
+ image=canny_image,
134
+ guidance_scale=6.0 * enhancement_strength, # Reduced guidance scale
135
+ num_inference_steps=8, # Reduced from 15 to 8
136
+ ).images[0]
137
+
138
+ # Second pass: Apply Ghibli style to the entire image with fewer steps
139
+ ghibli_prompt = "portrait in Studio Ghibli style, soft watercolor, whimsical, warm lighting, detailed background"
140
+
141
+ logger.info("Applying Ghibli style...")
142
+ ghibli_image = style_pipe(
143
+ prompt=ghibli_prompt,
144
+ image=enhanced_image,
145
+ strength=ghibli_strength,
146
+ guidance_scale=6.5, # Reduced guidance scale
147
+ num_inference_steps=8, # Reduced from 15 to 8
148
+ ).images[0]
149
+
150
+ # Resize back to a reasonable size for display
151
+ ghibli_image = ghibli_image.resize((512, 512), Image.LANCZOS)
152
+
153
+ return ghibli_image
154
+
155
+ except Exception as e:
156
+ logger.error(f"Error in image processing: {str(e)}")
157
+ # Return original image if processing fails
158
+ return input_image
159
+
160
+ # Load models at startup
161
+ try:
162
+ logger.info("Starting model loading...")
163
+ sd_pipe, style_pipe = load_models()
164
+ except Exception as e:
165
+ logger.error(f"Failed to initialize models: {str(e)}")
166
+ # We'll handle this in the process_image function
167
+
168
+ # Create Gradio interface
169
+ def process_image(input_image, hair_enhancement, ghibli_style):
170
+ if input_image is None:
171
+ return None, None
172
+
173
+ try:
174
+ # Check if models are loaded
175
+ if 'sd_pipe' not in globals() or 'style_pipe' not in globals():
176
+ return input_image, gr.update(value="Failed to load models. Please check the logs.")
177
+
178
+ # Process the image
179
+ result = enhance_and_stylize(
180
+ input_image,
181
+ sd_pipe,
182
+ style_pipe,
183
+ enhancement_strength=hair_enhancement,
184
+ ghibli_strength=ghibli_style
185
+ )
186
+
187
+ # Return both original and processed images for comparison
188
+ return input_image, result
189
+ except Exception as e:
190
+ logger.error(f"Error in process_image: {str(e)}")
191
+ return input_image, input_image
192
+
193
+ # Create the Gradio interface
194
+ with gr.Blocks(title="Ghibli Hair Enhancement") as demo:
195
+ gr.Markdown("# Ghibli-Style Hair Enhancement")
196
+ gr.Markdown("Upload a selfie to enhance hair and apply a Studio Ghibli art style")
197
+
198
+ with gr.Row():
199
+ with gr.Column():
200
+ input_image = gr.Image(label="Upload Selfie", type="pil")
201
+ with gr.Row():
202
+ hair_enhancement = gr.Slider(minimum=0.1, maximum=1.0, value=0.6, step=0.1, label="Hair Enhancement Strength")
203
+ ghibli_style = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Ghibli Style Strength")
204
+ process_btn = gr.Button("Enhance & Stylize")
205
+
206
+ with gr.Column():
207
+ output_original = gr.Image(label="Original Image")
208
+ output_stylized = gr.Image(label="Ghibli-Style with Enhanced Hair")
209
+
210
+ process_btn.click(
211
+ fn=process_image,
212
+ inputs=[input_image, hair_enhancement, ghibli_style],
213
+ outputs=[output_original, output_stylized]
214
+ )
215
+
216
+ gr.Markdown("### How it works")
217
+ gr.Markdown("1. Identifies the hair region in your selfie")
218
+ gr.Markdown("2. Enhances hair volume/fullness using AI")
219
+ gr.Markdown("3. Applies Studio Ghibli art style to the entire image")
220
+ gr.Markdown("4. Displays the before and after comparison")
221
+
222
+ # Launch the app
223
+ if __name__ == "__main__":
224
+ demo.launch()