primerz committed on
Commit
ff176e8
·
verified ·
1 Parent(s): c6fda39

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +334 -0
  2. requirements.txt +19 -0
app.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers import (
4
+ StableDiffusionXLPipeline,
5
+ StableDiffusionXLControlNetPipeline,
6
+ ControlNetModel,
7
+ AutoencoderKL,
8
+ DPMSolverMultistepScheduler
9
+ )
10
+ from diffusers.models.attention_processor import AttnProcessor2_0
11
+ from insightface.app import FaceAnalysis
12
+ from PIL import Image
13
+ import numpy as np
14
+ import cv2
15
+ from transformers import pipeline as transformers_pipeline
16
+ import os
17
+
18
+ # Device configuration
19
# Pick the accelerator and a matching precision in one step: half precision
# on CUDA, full precision on CPU.
device, dtype = (
    ("cuda", torch.float16)
    if torch.cuda.is_available()
    else ("cpu", torch.float32)
)

print(f"Using device: {device}")
23
+
24
class RetroArtConverter:
    """Convert images to retro game art with an SDXL + depth-ControlNet pipeline.

    Construction loads a face detector (insightface ``antelopev2``), a depth
    ControlNet, a VAE, a depth estimator and the SDXL pipeline.  Local files
    under ``./models`` are preferred; when the custom VAE or checkpoint is
    missing, the Hugging Face fallbacks named in ``__init__`` are used.  The
    RetroArt LoRA is optional: when its file is absent generation still works,
    just without the LoRA.
    """

    def __init__(self):
        """Load every model and configure the pipeline for inference."""
        self.device = device
        self.dtype = dtype
        # Remembers whether the optional LoRA file was found, so generation
        # only applies a LoRA scale when there is a LoRA to scale.
        self.lora_loaded = False

        # Initialize face analysis
        print("Loading face analysis model...")
        self.face_app = FaceAnalysis(
            name='antelopev2',
            root='./models/insightface',
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
        )
        self.face_app.prepare(ctx_id=0, det_size=(640, 640))

        # Load ControlNet for depth
        print("Loading ControlNet depth model...")
        self.controlnet_depth = ControlNetModel.from_pretrained(
            "diffusers/controlnet-zoe-depth-sdxl-1.0",
            torch_dtype=self.dtype
        ).to(self.device)

        # Load custom VAE, falling back to a public SDXL VAE
        print("Loading custom VAE (pixelate)...")
        vae_path = "./models/vae/pixelate.safetensors"
        if os.path.exists(vae_path):
            self.vae = AutoencoderKL.from_single_file(
                vae_path,
                torch_dtype=self.dtype
            ).to(self.device)
        else:
            print("Warning: Custom VAE not found, using default SDXL VAE")
            self.vae = AutoencoderKL.from_pretrained(
                "madebyollin/sdxl-vae-fp16-fix",
                torch_dtype=self.dtype
            ).to(self.device)

        # Load depth estimator for preprocessing
        print("Loading depth estimator...")
        self.depth_estimator = transformers_pipeline(
            'depth-estimation',
            model="Intel/dpt-hybrid-midas"
        )

        # Load SDXL base model with custom checkpoint, falling back to base SDXL
        print("Loading SDXL model (horizon)...")
        model_path = "./models/checkpoints/horizon.safetensors"

        if os.path.exists(model_path):
            self.pipe = StableDiffusionXLControlNetPipeline.from_single_file(
                model_path,
                controlnet=self.controlnet_depth,
                vae=self.vae,
                torch_dtype=self.dtype,
                use_safetensors=True
            ).to(self.device)
        else:
            print("Warning: Custom checkpoint not found, using default SDXL")
            self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
                "stabilityai/stable-diffusion-xl-base-1.0",
                controlnet=self.controlnet_depth,
                vae=self.vae,
                torch_dtype=self.dtype,
                use_safetensors=True
            ).to(self.device)

        # Load custom LORA (optional)
        print("Loading LORA (retroart)...")
        lora_path = "./models/lora/retroart.safetensors"
        if os.path.exists(lora_path):
            self.pipe.load_lora_weights(lora_path)
            self.lora_loaded = True
            print("LORA loaded successfully")
        else:
            print("Warning: Custom LORA not found")

        # Optimize pipeline
        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipe.scheduler.config
        )
        # Model CPU offload moves sub-models onto an accelerator on demand,
        # so it is only enabled when CUDA is the target device.
        if self.device == "cuda":
            self.pipe.enable_model_cpu_offload()
        self.pipe.enable_vae_slicing()

        # Use the PyTorch 2.0 scaled-dot-product attention processor
        self.pipe.unet.set_attn_processor(AttnProcessor2_0())

        # Best effort: prefer xformers attention when it is installed
        if hasattr(self.pipe, 'enable_xformers_memory_efficient_attention'):
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
            except Exception as e:
                print(f"xformers not available: {e}")

        print("Model initialization complete!")

    def get_depth_map(self, image):
        """Estimate depth for ``image`` and return it as a 3-channel PIL image.

        The depth values are min-max normalised to 0-255.  A completely flat
        depth map (max == min) yields an all-black image instead of dividing
        by zero.
        """
        depth = self.depth_estimator(image)
        depth_image = depth['depth']

        # float64 keeps the arithmetic identical to NumPy's true division
        depth_array = np.asarray(depth_image, dtype=np.float64)

        # Normalize to 0-255, guarding against a zero value range
        lowest = depth_array.min()
        value_range = depth_array.max() - lowest
        if value_range > 0:
            depth_normalized = (depth_array - lowest) / value_range * 255
        else:
            depth_normalized = np.zeros_like(depth_array)
        depth_normalized = depth_normalized.astype(np.uint8)

        # Convert to 3-channel image
        depth_colored = cv2.cvtColor(depth_normalized, cv2.COLOR_GRAY2RGB)

        return Image.fromarray(depth_colored)

    def detect_faces(self, image):
        """Detect faces in a PIL image using antelopev2 and return the results.

        The image is forced to 3-channel RGB (dropping alpha, expanding
        grayscale) and converted to BGR channel order, which is the layout
        insightface works with.
        """
        rgb_array = np.array(image.convert("RGB"))
        bgr_array = cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)
        return self.face_app.get(bgr_array)

    def calculate_target_size(self, original_width, original_height, max_dimension=1024):
        """Calculate a target size that keeps the aspect ratio.

        The longer side is capped at ``max_dimension`` (images are never
        upscaled), the other side is scaled to match, and both are rounded
        down to a multiple of 8, with a floor of 8 so that tiny or extremely
        elongated images never produce a zero-sized dimension.

        Returns:
            ``(width, height)`` tuple of ints.

        Raises:
            ValueError: if either original dimension is not positive.
        """
        if original_width <= 0 or original_height <= 0:
            raise ValueError(
                f"Image dimensions must be positive, got {original_width}x{original_height}"
            )

        # Integer arithmetic avoids float rounding that could land just below
        # a multiple of 8 and lose 8 pixels after flooring.
        if original_width > original_height:
            new_width = min(original_width, max_dimension)
            new_height = new_width * original_height // original_width
        else:
            new_height = min(original_height, max_dimension)
            new_width = new_height * original_width // original_height

        # Round down to a multiple of 8 (required for diffusion models),
        # but never below 8.
        new_width = max(8, (new_width // 8) * 8)
        new_height = max(8, (new_height // 8) * 8)

        return new_width, new_height

    def generate_retro_art(
        self,
        input_image,
        prompt="retro pixel art game, 16-bit style, vibrant colors",
        negative_prompt="blurry, low quality, modern, photorealistic, 3d render",
        num_inference_steps=30,
        guidance_scale=7.5,
        controlnet_conditioning_scale=0.8,
        lora_scale=0.85,
        seed=42
    ):
        """Generate a retro-styled version of ``input_image``.

        Args:
            input_image: PIL image to convert.
            prompt: Positive prompt; face-related keywords are prepended when
                at least one face is detected.
            negative_prompt: Negative prompt.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.
            controlnet_conditioning_scale: Weight of the depth ControlNet.
            lora_scale: Strength of the RetroArt LoRA; ignored when the LoRA
                file was not found at start-up.
            seed: Seed for the random generator (defaults to the previously
                hard-coded 42, so results stay reproducible).

        Returns:
            The first generated PIL image.
        """
        # Resize image maintaining aspect ratio
        original_width, original_height = input_image.size
        target_width, target_height = self.calculate_target_size(original_width, original_height)

        print(f"Resizing from {original_width}x{original_height} to {target_width}x{target_height}")

        resized_image = input_image.resize((target_width, target_height), Image.LANCZOS)

        # Detect faces
        faces = self.detect_faces(resized_image)

        if len(faces) > 0:
            print(f"Detected {len(faces)} face(s)")
            # Enhance prompt for face preservation
            prompt = f"portrait, detailed face, {prompt}"

        # Generate depth map
        print("Generating depth map...")
        depth_image = self.get_depth_map(resized_image)
        depth_image = depth_image.resize((target_width, target_height), Image.LANCZOS)

        # The LoRA is loaded without an adapter name, so its strength is set
        # per call through cross_attention_kwargs instead of addressing an
        # adapter called "retroart" that was never registered.
        lora_kwargs = {"scale": lora_scale} if self.lora_loaded else None

        # Generate image
        print("Generating retro art...")
        result = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=depth_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            cross_attention_kwargs=lora_kwargs,
            width=target_width,
            height=target_height,
            generator=torch.Generator(device=self.device).manual_seed(seed)
        )

        return result.images[0]
208
+
209
+ # Initialize the converter
210
# Build the converter once at import time; process_image reuses this single
# instance for every request.
print("Initializing RetroArt Converter...")
converter = RetroArtConverter()
212
+
213
+ # Gradio interface
214
def process_image(
    image,
    prompt,
    negative_prompt,
    steps,
    guidance_scale,
    controlnet_scale,
    lora_scale
):
    """Gradio callback that runs the converter on one uploaded image.

    Args:
        image: PIL image from the input component, or ``None`` when nothing
            has been uploaded.
        prompt: Positive prompt text.
        negative_prompt: Negative prompt text.
        steps: Number of inference steps; cast to ``int`` before use because
            slider values may arrive as floats.
        guidance_scale: Classifier-free guidance scale.
        controlnet_scale: Depth ControlNet conditioning scale.
        lora_scale: RetroArt LoRA strength.

    Returns:
        The generated image, or ``None`` when no input image was given.

    Raises:
        gr.Error: if generation fails; the original exception is chained as
            the cause so the full traceback stays available in the logs.
    """
    if image is None:
        return None

    try:
        return converter.generate_retro_art(
            input_image=image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=int(steps),
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_scale,
            lora_scale=lora_scale
        )
    except Exception as e:
        # UI boundary: log, then surface a readable message in the browser.
        print(f"Error: {e}")
        raise gr.Error(f"Generation failed: {e}") from e
240
+
241
+ # Create Gradio interface
242
with gr.Blocks(title="RetroArt Converter") as demo:
    gr.Markdown("""
    # 🎮 RetroArt Converter

    Convert any image into retro game art style!

    **Features:**
    - Custom SDXL checkpoint (Horizon)
    - Pixelate VAE for authentic retro look
    - RetroArt LORA for style enhancement
    - Face preservation with InstantID
    - Depth-aware generation with ControlNet
    """)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")

            prompt = gr.Textbox(
                label="Prompt",
                value="retro pixel art game, 16-bit style, vibrant colors, detailed",
                lines=3
            )

            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="blurry, low quality, modern, photorealistic, 3d render, ugly, distorted",
                lines=2
            )

            with gr.Accordion("Advanced Settings", open=False):
                steps = gr.Slider(
                    minimum=20,
                    maximum=50,
                    value=30,
                    step=1,
                    label="Inference Steps"
                )

                guidance_scale = gr.Slider(
                    minimum=1,
                    maximum=15,
                    value=7.5,
                    step=0.5,
                    label="Guidance Scale"
                )

                controlnet_scale = gr.Slider(
                    minimum=0,
                    maximum=2,
                    value=0.8,
                    step=0.1,
                    label="ControlNet Depth Scale"
                )

                lora_scale = gr.Slider(
                    minimum=0,
                    maximum=2,
                    value=0.85,
                    step=0.05,
                    label="RetroArt LORA Scale"
                )

            generate_btn = gr.Button("🎨 Generate Retro Art", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Retro Art Output")

    # Both the examples and the button feed process_image the same components,
    # in the order of its parameters.
    generation_inputs = [
        input_image,
        prompt,
        negative_prompt,
        steps,
        guidance_scale,
        controlnet_scale,
        lora_scale,
    ]

    # Only offer the example when its image is actually present, so the UI
    # still builds when the sample file has not been uploaded.
    example_image_path = "example_portrait.jpg"
    if os.path.exists(example_image_path):
        gr.Examples(
            examples=[
                [example_image_path, "retro pixel art portrait, 16-bit game character", "blurry, modern", 30, 7.5, 0.8, 0.85],
            ],
            inputs=generation_inputs,
            outputs=[output_image],
            fn=process_image,
            cache_examples=False
        )

    generate_btn.click(
        fn=process_image,
        inputs=generation_inputs,
        outputs=[output_image]
    )
324
+ )
325
+
326
+ # Launch with API enabled
327
if __name__ == "__main__":
    # Listen on all interfaces at the fixed port, without a public share
    # link, and keep the API page enabled.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=True,
    )
    demo.queue(max_size=20)
    demo.launch(**launch_options)
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
torch==2.1.0
torchvision==0.16.0
diffusers==0.25.0
transformers==4.36.0
accelerate==0.25.0
gradio==4.12.0
pillow==10.1.0
numpy==1.24.3
opencv-python==4.8.1.78
safetensors==0.4.1
insightface==0.7.3
onnxruntime-gpu==1.16.3
onnx==1.15.0
scikit-image==0.22.0
scipy==1.11.4
omegaconf==2.3.0
einops==0.7.0
# xformers wheels pin an exact torch version; 0.0.22.post7 is the build that
# matches torch==2.1.0 (0.0.23 targets torch 2.1.1 and conflicts with the pin above).
xformers==0.0.22.post7
huggingface-hub==0.20.1