Update app.py
Browse files
app.py
CHANGED
|
@@ -35,9 +35,9 @@ def load_sb_pipe(controlnet_version, sb_path="runwayml/stable-diffusion-v1-5"):
|
|
| 35 |
|
| 36 |
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
| 37 |
sb_path,
|
| 38 |
-
controlnet=controlnet,
|
| 39 |
-
|
| 40 |
-
|
| 41 |
)
|
| 42 |
|
| 43 |
pipe.scheduler = scheduler
|
|
@@ -56,9 +56,9 @@ high_threshold = 200
|
|
| 56 |
|
| 57 |
pipe, params = load_sb_pipe(controlnet_version)
|
| 58 |
|
| 59 |
-
pipe.enable_xformers_memory_efficient_attention()
|
| 60 |
-
pipe.enable_model_cpu_offload()
|
| 61 |
-
pipe.enable_attention_slicing()
|
| 62 |
|
| 63 |
def pipe_inference(
|
| 64 |
image,
|
|
@@ -78,18 +78,20 @@ def pipe_inference(
|
|
| 78 |
resized_image = resize_image(image, resolution)
|
| 79 |
|
| 80 |
if not is_canny:
|
| 81 |
-
resized_image = preprocess_canny(resized_image)
|
| 82 |
|
| 83 |
rng = create_key(seed)
|
| 84 |
-
|
| 85 |
-
|
| 86 |
prompt_ids = pipe.prepare_text_inputs([prompt] * num_samples)
|
| 87 |
negative_prompt_ids = pipe.prepare_text_inputs([negative_prompt] * num_samples)
|
| 88 |
processed_image = pipe.prepare_image_inputs([resized_image] * num_samples)
|
|
|
|
| 89 |
p_params = replicate(params)
|
| 90 |
prompt_ids = shard(prompt_ids)
|
| 91 |
negative_prompt_ids = shard(negative_prompt_ids)
|
| 92 |
processed_image = shard(processed_image)
|
|
|
|
| 93 |
output = pipe(
|
| 94 |
prompt_ids=prompt_ids,
|
| 95 |
image=processed_image,
|
|
@@ -122,15 +124,6 @@ def resize_image(image, resolution):
|
|
| 122 |
|
| 123 |
|
| 124 |
def preprocess_canny(image, resolution=128):
|
| 125 |
-
h, w = image.shape
|
| 126 |
-
ratio = w/h
|
| 127 |
-
if ratio > 1 :
|
| 128 |
-
resized_image = cv2.resize(image, (int(resolution*ratio), resolution), interpolation=cv2.INTER_NEAREST)
|
| 129 |
-
elif ratio < 1 :
|
| 130 |
-
resized_image = cv2.resize(image, (resolution, int(resolution/ratio)), interpolation=cv2.INTER_NEAREST)
|
| 131 |
-
else:
|
| 132 |
-
resized_image = cv2.resize(image, (resolution, resolution), interpolation=cv2.INTER_NEAREST)
|
| 133 |
-
|
| 134 |
processed_image = cv2.Canny(resized_image, low_threshold, high_threshold)
|
| 135 |
processed_image = processed_image[:, :, None]
|
| 136 |
processed_image = np.concatenate([processed_image, processed_image, processed_image], axis=2)
|
|
@@ -139,6 +132,7 @@ def preprocess_canny(image, resolution=128):
|
|
| 139 |
processed_image = Image.fromarray(processed_image)
|
| 140 |
return resized_image, processed_image
|
| 141 |
|
|
|
|
| 142 |
def create_demo(process, max_images=12, default_num_images=4):
|
| 143 |
with gr.Blocks() as demo:
|
| 144 |
with gr.Row():
|
|
@@ -218,14 +212,12 @@ def create_demo(process, max_images=12, default_num_images=4):
|
|
| 218 |
inputs=inputs,
|
| 219 |
outputs=result,
|
| 220 |
api_name='canny')
|
| 221 |
-
return demo
|
| 222 |
|
| 223 |
|
| 224 |
if __name__ == '__main__':
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
demo = create_demo(model.process_canny)
|
| 229 |
demo.queue().launch()
|
| 230 |
-
|
| 231 |
-
|
|
|
|
| 35 |
|
| 36 |
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
| 37 |
sb_path,
|
| 38 |
+
controlnet=controlnet,
|
| 39 |
+
revision="flax",
|
| 40 |
+
dtype=jnp.bfloat16
|
| 41 |
)
|
| 42 |
|
| 43 |
pipe.scheduler = scheduler
|
|
|
|
| 56 |
|
| 57 |
pipe, params = load_sb_pipe(controlnet_version)
|
| 58 |
|
| 59 |
+
# pipe.enable_xformers_memory_efficient_attention()
|
| 60 |
+
# pipe.enable_model_cpu_offload()
|
| 61 |
+
# pipe.enable_attention_slicing()
|
| 62 |
|
| 63 |
def pipe_inference(
|
| 64 |
image,
|
|
|
|
| 78 |
resized_image = resize_image(image, resolution)
|
| 79 |
|
| 80 |
if not is_canny:
|
| 81 |
+
resized_image = preprocess_canny(resized_image, resolution)
|
| 82 |
|
| 83 |
rng = create_key(seed)
|
| 84 |
+
rng = jax.random.split(rng, jax.device_count())
|
| 85 |
+
|
| 86 |
prompt_ids = pipe.prepare_text_inputs([prompt] * num_samples)
|
| 87 |
negative_prompt_ids = pipe.prepare_text_inputs([negative_prompt] * num_samples)
|
| 88 |
processed_image = pipe.prepare_image_inputs([resized_image] * num_samples)
|
| 89 |
+
|
| 90 |
p_params = replicate(params)
|
| 91 |
prompt_ids = shard(prompt_ids)
|
| 92 |
negative_prompt_ids = shard(negative_prompt_ids)
|
| 93 |
processed_image = shard(processed_image)
|
| 94 |
+
|
| 95 |
output = pipe(
|
| 96 |
prompt_ids=prompt_ids,
|
| 97 |
image=processed_image,
|
|
|
|
| 124 |
|
| 125 |
|
| 126 |
def preprocess_canny(image, resolution=128):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
processed_image = cv2.Canny(resized_image, low_threshold, high_threshold)
|
| 128 |
processed_image = processed_image[:, :, None]
|
| 129 |
processed_image = np.concatenate([processed_image, processed_image, processed_image], axis=2)
|
|
|
|
| 132 |
processed_image = Image.fromarray(processed_image)
|
| 133 |
return resized_image, processed_image
|
| 134 |
|
| 135 |
+
|
| 136 |
def create_demo(process, max_images=12, default_num_images=4):
|
| 137 |
with gr.Blocks() as demo:
|
| 138 |
with gr.Row():
|
|
|
|
| 212 |
inputs=inputs,
|
| 213 |
outputs=result,
|
| 214 |
api_name='canny')
|
|
|
|
| 215 |
|
| 216 |
|
| 217 |
if __name__ == '__main__':
|
| 218 |
+
|
| 219 |
+
pipe_inference
|
| 220 |
+
demo = create_demo(pipe_inference)
|
|
|
|
| 221 |
demo.queue().launch()
|
| 222 |
+
# gr.Interface(create_demo).launch()
|
| 223 |
+
|