Sree1234Job committed on
Commit
fc9b493
·
verified ·
1 Parent(s): 5f37345

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -224
app.py DELETED
@@ -1,224 +0,0 @@
1
- import gradio as gr
2
- import torch
3
- from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, DDIMScheduler
4
- from diffusers import StableDiffusionImg2ImgPipeline
5
- import numpy as np
6
- from PIL import Image
7
- import logging
8
-
9
# Configure logging once and keep a module-scoped logger for the app.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use the GPU with float16 weights when CUDA is available; otherwise run on
# the CPU with float32 weights.
if torch.cuda.is_available():
    device = "cuda"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32

logger.info(f"Using device: {device}, dtype: {torch_dtype}")
18
-
19
- # Function to create hair mask (simplified version)
20
def create_hair_mask(image):
    """Return a placeholder "hair" mask for ``image``.

    No segmentation is performed: the mask is a fixed rectangle covering the
    top 40% of the rows and the columns between 20% and 80% of the width,
    set to 255 on a zero background. A real application would use a proper
    face-parsing model (e.g. BiSeNet) instead.

    Args:
        image: Anything ``np.array`` converts to an array whose first two
            dimensions are (height, width), such as a PIL image.

    Returns:
        A PIL image built from a ``uint8`` array of shape (height, width).
    """
    rows, cols = np.array(image).shape[:2]

    # Bounds of the rectangle where hair is assumed to be.
    bottom = int(rows * 0.4)
    left, right = int(cols * 0.2), int(cols * 0.8)

    mask = np.zeros((rows, cols), dtype=np.uint8)
    mask[:bottom, left:right] = 255
    return Image.fromarray(mask)
31
-
32
- # Load models at startup to avoid reloading for each inference
33
def _enable_memory_optimizations(pipe):
    """Turn on attention slicing and, on CUDA, xformers attention for ``pipe``."""
    pipe.enable_attention_slicing(slice_size=1)
    if device == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as exc:
            # xformers is only an optimization. If it is missing or unusable,
            # keep the already-loaded pipeline instead of letting the error
            # push load_models() into its fallback branch.
            logger.warning(
                "xformers attention unavailable, continuing without it: %s", exc
            )


@torch.inference_mode()
def load_models():
    """Load the ControlNet pipeline and the Ghibli img2img pipeline.

    Intended to be called once at startup so inference calls can reuse the
    pipelines. Both pipelines are moved to the module-level ``device`` and
    use the module-level ``torch_dtype``.

    Returns:
        A ``(sd_pipe, style_pipe)`` tuple. Normally ``sd_pipe`` is a
        ``StableDiffusionControlNetPipeline`` (canny ControlNet on
        Stable Diffusion v1.5) and ``style_pipe`` is a
        ``StableDiffusionImg2ImgPipeline`` for ``nitrosocke/Ghibli-Diffusion``.
        If loading those fails, a single ``CompVis/stable-diffusion-v1-4``
        img2img pipeline is returned in both positions.

    Raises:
        RuntimeError: If the fallback pipeline cannot be loaded either. The
            fallback failure is attached as ``__cause__``.
    """
    try:
        # Kept as a local import inside the try so that an import failure
        # still leads to the fallback attempt below.
        from diffusers import DPMSolverMultistepScheduler

        logger.info("Loading ControlNet model...")
        controlnet = ControlNetModel.from_pretrained(
            "lllyasviel/sd-controlnet-canny", torch_dtype=torch_dtype
        ).to(device)

        logger.info("Loading Stable Diffusion pipeline...")
        sd_pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=controlnet,
            torch_dtype=torch_dtype,
            safety_checker=None,  # Disable safety checker for speed
            # Request the fp16 weight files only when running on CUDA
            variant="fp16" if device == "cuda" else None,
            use_safetensors=True,
        ).to(device)

        # Swap in the multistep DPM solver, built from the existing config.
        sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            sd_pipe.scheduler.config
        )
        _enable_memory_optimizations(sd_pipe)

        logger.info("Loading Ghibli style model...")
        style_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            "nitrosocke/Ghibli-Diffusion",
            torch_dtype=torch_dtype,
            safety_checker=None,
            variant="fp16" if device == "cuda" else None,
            use_safetensors=True,
        ).to(device)

        # Same scheduler and optimizations as the first pipeline.
        style_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            style_pipe.scheduler.config
        )
        _enable_memory_optimizations(style_pipe)

        logger.info("All models loaded successfully!")
        return sd_pipe, style_pipe

    except Exception as primary_error:
        logger.exception("Error loading models: %s", primary_error)
        # Fallback to a simpler model if the main ones fail
        try:
            logger.info("Attempting to load fallback models...")
            fallback_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                "CompVis/stable-diffusion-v1-4",
                torch_dtype=torch_dtype,
                safety_checker=None,
            ).to(device)

            # Use the same model for both pipelines in fallback mode
            return fallback_pipe, fallback_pipe
        except Exception as fallback_error:
            logger.exception("Fallback model loading failed: %s", fallback_error)
            raise RuntimeError(
                "Failed to load any models. Please check the logs for details."
            ) from fallback_error
100
-
101
- # Function to enhance hair and apply Ghibli style
102
def enhance_and_stylize(input_image, sd_pipe, style_pipe, enhancement_strength=0.6, ghibli_strength=0.7):
    """Run the two-pass "fuller hair" + Ghibli-style generation on an image.

    The image is downscaled to 256x256, its Canny edges are fed to
    ``sd_pipe`` together with a hair prompt, and the result is passed to
    ``style_pipe`` with a Ghibli prompt. The final image is resized to
    512x512.

    NOTE: no hair mask is passed to either pipeline; the first pass
    regenerates the whole image from its edge map.

    Args:
        input_image: Image object providing ``resize((w, h))`` (a PIL image
            in this app), or ``None``.
        sd_pipe: Callable pipeline accepting ``prompt``, ``negative_prompt``,
            ``image``, ``guidance_scale`` and ``num_inference_steps`` and
            returning an object with an ``images`` list.
        style_pipe: Callable pipeline accepting ``prompt``, ``image``,
            ``strength``, ``guidance_scale`` and ``num_inference_steps`` and
            returning an object with an ``images`` list.
        enhancement_strength: Multiplier applied to the first pass's base
            guidance scale of 6.0.
        ghibli_strength: Passed as ``strength`` to ``style_pipe``.

    Returns:
        The stylized 512x512 image; ``None`` when ``input_image`` is ``None``;
        the untouched ``input_image`` when any processing step fails.
    """
    if input_image is None:
        return None

    try:
        # Work on a separate name: the failure path below must hand back the
        # caller's original image, not this downscaled copy.
        small_image = input_image.resize((256, 256))

        import cv2

        # Canny edge map, replicated to three channels for ControlNet.
        edges = cv2.Canny(np.array(small_image), 100, 200)
        canny_image = Image.fromarray(np.stack([edges, edges, edges], axis=2))

        hair_prompt = "portrait photo of person with slightly fuller, naturally grown hair, same face, detailed"
        negative_prompt = "unrealistic, cartoon, distorted face, bad anatomy"

        # First pass: edge-conditioned generation with a small step count.
        logger.info("Generating enhanced image...")
        enhanced_image = sd_pipe(
            prompt=hair_prompt,
            negative_prompt=negative_prompt,
            image=canny_image,
            guidance_scale=6.0 * enhancement_strength,
            num_inference_steps=8,
        ).images[0]

        # Second pass: restyle the whole generated image.
        ghibli_prompt = "portrait in Studio Ghibli style, soft watercolor, whimsical, warm lighting, detailed background"

        logger.info("Applying Ghibli style...")
        ghibli_image = style_pipe(
            prompt=ghibli_prompt,
            image=enhanced_image,
            strength=ghibli_strength,
            guidance_scale=6.5,
            num_inference_steps=8,
        ).images[0]

        # Resize back to a reasonable size for display
        return ghibli_image.resize((512, 512), Image.LANCZOS)

    except Exception as exc:
        logger.exception("Error in image processing: %s", exc)
        # Return original image if processing fails
        return input_image
159
-
160
# Load both pipelines once at import time.
# If loading fails, `sd_pipe` and `style_pipe` are left undefined on purpose;
# process_image() checks for that before doing any work.
try:
    logger.info("Starting model loading...")
    sd_pipe, style_pipe = load_models()
except Exception as load_error:
    logger.error(f"Failed to initialize models: {load_error}")
167
-
168
- # Create Gradio interface
169
def process_image(input_image, hair_enhancement, ghibli_style):
    """Gradio click handler: return the (original, stylized) image pair.

    Args:
        input_image: The uploaded image, or ``None`` when nothing was uploaded.
        hair_enhancement: Forwarded to ``enhance_and_stylize`` as
            ``enhancement_strength``.
        ghibli_style: Forwarded to ``enhance_and_stylize`` as
            ``ghibli_strength``.

    Returns:
        ``(None, None)`` when no image was supplied; otherwise
        ``(input_image, result)``. If processing raises, the input image is
        returned in both positions.

    Raises:
        gr.Error: When the module-level pipelines were never created, so the
            UI shows the message instead of receiving text for an image
            output.
    """
    if input_image is None:
        return None, None

    # Startup leaves these globals undefined when model loading failed. This
    # check sits outside the try below so the catch-all handler cannot
    # swallow the user-facing error.
    if 'sd_pipe' not in globals() or 'style_pipe' not in globals():
        raise gr.Error("Failed to load models. Please check the logs.")

    try:
        result = enhance_and_stylize(
            input_image,
            sd_pipe,
            style_pipe,
            enhancement_strength=hair_enhancement,
            ghibli_strength=ghibli_style,
        )
    except Exception as exc:
        logger.exception("Error in process_image: %s", exc)
        return input_image, input_image

    # Return both original and processed images for comparison
    return input_image, result
192
-
193
- # Create the Gradio interface
194
# NOTE(review): the nesting below is reconstructed from a view that had lost
# its indentation -- confirm the layout against the real file.
with gr.Blocks(title="Ghibli Hair Enhancement") as demo:
    gr.Markdown("# Ghibli-Style Hair Enhancement")
    gr.Markdown("Upload a selfie to enhance hair and apply a Studio Ghibli art style")

    with gr.Row():
        # Inputs: upload, the two strength sliders, and the trigger button.
        with gr.Column():
            input_image = gr.Image(label="Upload Selfie", type="pil")
            with gr.Row():
                hair_enhancement = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.6,
                    step=0.1,
                    label="Hair Enhancement Strength",
                )
                ghibli_style = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    label="Ghibli Style Strength",
                )
            process_btn = gr.Button("Enhance & Stylize")

        # Outputs: the original next to the stylized result.
        with gr.Column():
            output_original = gr.Image(label="Original Image")
            output_stylized = gr.Image(label="Ghibli-Style with Enhanced Hair")

    process_btn.click(
        fn=process_image,
        inputs=[input_image, hair_enhancement, ghibli_style],
        outputs=[output_original, output_stylized],
    )

    # One Markdown component per line, in the same order as before.
    for _how_it_works_line in (
        "### How it works",
        "1. Identifies the hair region in your selfie",
        "2. Enhances hair volume/fullness using AI",
        "3. Applies Studio Ghibli art style to the entire image",
        "4. Displays the before and after comparison",
    ):
        gr.Markdown(_how_it_works_line)

# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()