Upload folder using huggingface_hub
Browse files- inference.py +37 -35
- inference2.py +15 -9
- internals/data/dataAccessor.py +27 -10
- internals/pipelines/commons.py +60 -16
- internals/pipelines/controlnets.py +132 -130
- internals/pipelines/high_res.py +1 -1
- internals/pipelines/inpainter.py +48 -12
- internals/pipelines/remove_background.py +54 -9
- internals/pipelines/replace_background.py +17 -7
- internals/pipelines/twoStepPipeline.py +1 -1
- internals/util/cache.py +13 -3
- internals/util/commons.py +2 -2
- internals/util/config.py +5 -0
- internals/util/lora_style.py +5 -0
- internals/util/model_loader.py +3 -0
- pyproject.toml +1 -1
- requirements.txt +4 -0
inference.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
| 1 |
import os
|
| 2 |
from typing import List, Optional
|
| 3 |
|
|
|
|
| 4 |
import torch
|
| 5 |
|
| 6 |
import internals.util.prompt as prompt_util
|
| 7 |
-
from internals.data.dataAccessor import update_db
|
| 8 |
from internals.data.task import Task, TaskType
|
| 9 |
from internals.pipelines.commons import Img2Img, Text2Img
|
| 10 |
from internals.pipelines.controlnets import ControlNet
|
|
@@ -18,11 +19,15 @@ from internals.pipelines.replace_background import ReplaceBackground
|
|
| 18 |
from internals.pipelines.safety_checker import SafetyChecker
|
| 19 |
from internals.util.args import apply_style_args
|
| 20 |
from internals.util.avatar import Avatar
|
| 21 |
-
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
|
| 22 |
from internals.util.commons import download_image, upload_image, upload_images
|
| 23 |
-
from internals.util.config import (
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
from internals.util.failure_hander import FailureHandler
|
| 27 |
from internals.util.lora_style import LoraStyle
|
| 28 |
from internals.util.model_loader import load_model_from_config
|
|
@@ -80,7 +85,7 @@ def canny(task: Task):
|
|
| 80 |
|
| 81 |
width, height = get_intermediate_dimension(task)
|
| 82 |
|
| 83 |
-
controlnet.
|
| 84 |
|
| 85 |
# pipe2 is used for canny and pose
|
| 86 |
lora_patcher = lora_style.get_patcher(
|
|
@@ -88,7 +93,7 @@ def canny(task: Task):
|
|
| 88 |
)
|
| 89 |
lora_patcher.patch()
|
| 90 |
|
| 91 |
-
images, has_nsfw = controlnet.
|
| 92 |
prompt=prompt,
|
| 93 |
imageUrl=task.get_imageUrl(),
|
| 94 |
seed=task.get_seed(),
|
|
@@ -132,12 +137,12 @@ def tile_upscale(task: Task):
|
|
| 132 |
|
| 133 |
prompt = get_patched_prompt_tile_upscale(task)
|
| 134 |
|
| 135 |
-
controlnet.
|
| 136 |
|
| 137 |
lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
|
| 138 |
lora_patcher.patch()
|
| 139 |
|
| 140 |
-
images, has_nsfw = controlnet.
|
| 141 |
imageUrl=task.get_imageUrl(),
|
| 142 |
seed=task.get_seed(),
|
| 143 |
steps=task.get_steps(),
|
|
@@ -169,14 +174,14 @@ def scribble(task: Task):
|
|
| 169 |
|
| 170 |
width, height = get_intermediate_dimension(task)
|
| 171 |
|
| 172 |
-
controlnet.
|
| 173 |
|
| 174 |
lora_patcher = lora_style.get_patcher(
|
| 175 |
[controlnet.pipe2, high_res.pipe], task.get_style()
|
| 176 |
)
|
| 177 |
lora_patcher.patch()
|
| 178 |
|
| 179 |
-
images, has_nsfw = controlnet.
|
| 180 |
imageUrl=task.get_imageUrl(),
|
| 181 |
seed=task.get_seed(),
|
| 182 |
steps=task.get_steps(),
|
|
@@ -215,14 +220,14 @@ def linearart(task: Task):
|
|
| 215 |
|
| 216 |
width, height = get_intermediate_dimension(task)
|
| 217 |
|
| 218 |
-
controlnet.
|
| 219 |
|
| 220 |
lora_patcher = lora_style.get_patcher(
|
| 221 |
[controlnet.pipe2, high_res.pipe], task.get_style()
|
| 222 |
)
|
| 223 |
lora_patcher.patch()
|
| 224 |
|
| 225 |
-
images, has_nsfw = controlnet.
|
| 226 |
imageUrl=task.get_imageUrl(),
|
| 227 |
seed=task.get_seed(),
|
| 228 |
steps=task.get_steps(),
|
|
@@ -261,7 +266,7 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
|
|
| 261 |
|
| 262 |
width, height = get_intermediate_dimension(task)
|
| 263 |
|
| 264 |
-
controlnet.
|
| 265 |
|
| 266 |
# pipe2 is used for canny and pose
|
| 267 |
lora_patcher = lora_style.get_patcher(
|
|
@@ -291,7 +296,7 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
|
|
| 291 |
)
|
| 292 |
condition_image = ControlNet.linearart_condition_image(src_image)
|
| 293 |
|
| 294 |
-
images, has_nsfw = controlnet.
|
| 295 |
prompt=prompt,
|
| 296 |
image=poses,
|
| 297 |
condition_image=[condition_image] * num_return_sequences,
|
|
@@ -440,7 +445,7 @@ def inpaint(task: Task):
|
|
| 440 |
|
| 441 |
generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())
|
| 442 |
|
| 443 |
-
|
| 444 |
|
| 445 |
return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
|
| 446 |
|
|
@@ -469,12 +474,13 @@ def replace_bg(task: Task):
|
|
| 469 |
product_scale_width=task.get_image_scale(),
|
| 470 |
apply_high_res=task.get_high_res_fix(),
|
| 471 |
conditioning_scale=task.rbg_controlnet_conditioning_scale(),
|
|
|
|
| 472 |
)
|
| 473 |
|
| 474 |
generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())
|
| 475 |
|
| 476 |
lora_patcher.cleanup()
|
| 477 |
-
|
| 478 |
|
| 479 |
return {
|
| 480 |
"modified_prompts": prompt,
|
|
@@ -484,38 +490,33 @@ def replace_bg(task: Task):
|
|
| 484 |
|
| 485 |
|
| 486 |
def load_model_by_task(task: Task):
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
if (
|
| 490 |
-
task.get_type()
|
| 491 |
-
in [
|
| 492 |
-
TaskType.TEXT_TO_IMAGE,
|
| 493 |
-
TaskType.IMAGE_TO_IMAGE,
|
| 494 |
-
TaskType.INPAINT,
|
| 495 |
-
]
|
| 496 |
-
and not text2img_pipe.is_loaded()
|
| 497 |
-
):
|
| 498 |
text2img_pipe.load(get_model_dir())
|
| 499 |
img2img_pipe.create(text2img_pipe)
|
| 500 |
-
inpainter.load()
|
| 501 |
high_res.load(img2img_pipe)
|
| 502 |
|
|
|
|
|
|
|
|
|
|
| 503 |
safety_checker.apply(text2img_pipe)
|
| 504 |
safety_checker.apply(img2img_pipe)
|
|
|
|
|
|
|
|
|
|
| 505 |
safety_checker.apply(inpainter)
|
| 506 |
elif task.get_type() == TaskType.REPLACE_BG:
|
| 507 |
replace_background.load(inpainter=inpainter, high_res=high_res)
|
| 508 |
else:
|
| 509 |
if task.get_type() == TaskType.TILE_UPSCALE:
|
| 510 |
-
controlnet.
|
| 511 |
elif task.get_type() == TaskType.CANNY:
|
| 512 |
-
controlnet.
|
| 513 |
elif task.get_type() == TaskType.SCRIBBLE:
|
| 514 |
-
controlnet.
|
| 515 |
elif task.get_type() == TaskType.LINEARART:
|
| 516 |
-
controlnet.
|
| 517 |
elif task.get_type() == TaskType.POSE:
|
| 518 |
-
controlnet.
|
| 519 |
|
| 520 |
safety_checker.apply(controlnet)
|
| 521 |
|
|
@@ -589,7 +590,8 @@ def predict_fn(data, pipe):
|
|
| 589 |
else:
|
| 590 |
raise Exception("Invalid task type")
|
| 591 |
except Exception as e:
|
| 592 |
-
print(f"Error: {e}")
|
| 593 |
slack.error_alert(task, e)
|
| 594 |
controlnet.cleanup()
|
|
|
|
|
|
|
| 595 |
return None
|
|
|
|
| 1 |
import os
|
| 2 |
from typing import List, Optional
|
| 3 |
|
| 4 |
+
import traceback
|
| 5 |
import torch
|
| 6 |
|
| 7 |
import internals.util.prompt as prompt_util
|
| 8 |
+
from internals.data.dataAccessor import update_db, update_db_source_failed
|
| 9 |
from internals.data.task import Task, TaskType
|
| 10 |
from internals.pipelines.commons import Img2Img, Text2Img
|
| 11 |
from internals.pipelines.controlnets import ControlNet
|
|
|
|
| 19 |
from internals.pipelines.safety_checker import SafetyChecker
|
| 20 |
from internals.util.args import apply_style_args
|
| 21 |
from internals.util.avatar import Avatar
|
| 22 |
+
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda, clear_cuda_and_gc
|
| 23 |
from internals.util.commons import download_image, upload_image, upload_images
|
| 24 |
+
from internals.util.config import (
|
| 25 |
+
get_model_dir,
|
| 26 |
+
num_return_sequences,
|
| 27 |
+
set_configs_from_task,
|
| 28 |
+
set_model_config,
|
| 29 |
+
set_root_dir,
|
| 30 |
+
)
|
| 31 |
from internals.util.failure_hander import FailureHandler
|
| 32 |
from internals.util.lora_style import LoraStyle
|
| 33 |
from internals.util.model_loader import load_model_from_config
|
|
|
|
| 85 |
|
| 86 |
width, height = get_intermediate_dimension(task)
|
| 87 |
|
| 88 |
+
controlnet.load_model("canny")
|
| 89 |
|
| 90 |
# pipe2 is used for canny and pose
|
| 91 |
lora_patcher = lora_style.get_patcher(
|
|
|
|
| 93 |
)
|
| 94 |
lora_patcher.patch()
|
| 95 |
|
| 96 |
+
images, has_nsfw = controlnet.process(
|
| 97 |
prompt=prompt,
|
| 98 |
imageUrl=task.get_imageUrl(),
|
| 99 |
seed=task.get_seed(),
|
|
|
|
| 137 |
|
| 138 |
prompt = get_patched_prompt_tile_upscale(task)
|
| 139 |
|
| 140 |
+
controlnet.load_model("tile_upscaler")
|
| 141 |
|
| 142 |
lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
|
| 143 |
lora_patcher.patch()
|
| 144 |
|
| 145 |
+
images, has_nsfw = controlnet.process(
|
| 146 |
imageUrl=task.get_imageUrl(),
|
| 147 |
seed=task.get_seed(),
|
| 148 |
steps=task.get_steps(),
|
|
|
|
| 174 |
|
| 175 |
width, height = get_intermediate_dimension(task)
|
| 176 |
|
| 177 |
+
controlnet.load_model("scribble")
|
| 178 |
|
| 179 |
lora_patcher = lora_style.get_patcher(
|
| 180 |
[controlnet.pipe2, high_res.pipe], task.get_style()
|
| 181 |
)
|
| 182 |
lora_patcher.patch()
|
| 183 |
|
| 184 |
+
images, has_nsfw = controlnet.process(
|
| 185 |
imageUrl=task.get_imageUrl(),
|
| 186 |
seed=task.get_seed(),
|
| 187 |
steps=task.get_steps(),
|
|
|
|
| 220 |
|
| 221 |
width, height = get_intermediate_dimension(task)
|
| 222 |
|
| 223 |
+
controlnet.load_model("linearart")
|
| 224 |
|
| 225 |
lora_patcher = lora_style.get_patcher(
|
| 226 |
[controlnet.pipe2, high_res.pipe], task.get_style()
|
| 227 |
)
|
| 228 |
lora_patcher.patch()
|
| 229 |
|
| 230 |
+
images, has_nsfw = controlnet.process(
|
| 231 |
imageUrl=task.get_imageUrl(),
|
| 232 |
seed=task.get_seed(),
|
| 233 |
steps=task.get_steps(),
|
|
|
|
| 266 |
|
| 267 |
width, height = get_intermediate_dimension(task)
|
| 268 |
|
| 269 |
+
controlnet.load_model("pose")
|
| 270 |
|
| 271 |
# pipe2 is used for canny and pose
|
| 272 |
lora_patcher = lora_style.get_patcher(
|
|
|
|
| 296 |
)
|
| 297 |
condition_image = ControlNet.linearart_condition_image(src_image)
|
| 298 |
|
| 299 |
+
images, has_nsfw = controlnet.process(
|
| 300 |
prompt=prompt,
|
| 301 |
image=poses,
|
| 302 |
condition_image=[condition_image] * num_return_sequences,
|
|
|
|
| 445 |
|
| 446 |
generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())
|
| 447 |
|
| 448 |
+
clear_cuda_and_gc()
|
| 449 |
|
| 450 |
return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
|
| 451 |
|
|
|
|
| 474 |
product_scale_width=task.get_image_scale(),
|
| 475 |
apply_high_res=task.get_high_res_fix(),
|
| 476 |
conditioning_scale=task.rbg_controlnet_conditioning_scale(),
|
| 477 |
+
model_type=task.get_modelType(),
|
| 478 |
)
|
| 479 |
|
| 480 |
generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())
|
| 481 |
|
| 482 |
lora_patcher.cleanup()
|
| 483 |
+
clear_cuda_and_gc()
|
| 484 |
|
| 485 |
return {
|
| 486 |
"modified_prompts": prompt,
|
|
|
|
| 490 |
|
| 491 |
|
| 492 |
def load_model_by_task(task: Task):
|
| 493 |
+
if not text2img_pipe.is_loaded():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
text2img_pipe.load(get_model_dir())
|
| 495 |
img2img_pipe.create(text2img_pipe)
|
|
|
|
| 496 |
high_res.load(img2img_pipe)
|
| 497 |
|
| 498 |
+
inpainter.init(text2img_pipe)
|
| 499 |
+
controlnet.init(text2img_pipe)
|
| 500 |
+
|
| 501 |
safety_checker.apply(text2img_pipe)
|
| 502 |
safety_checker.apply(img2img_pipe)
|
| 503 |
+
|
| 504 |
+
if task.get_type() == TaskType.INPAINT:
|
| 505 |
+
inpainter.load()
|
| 506 |
safety_checker.apply(inpainter)
|
| 507 |
elif task.get_type() == TaskType.REPLACE_BG:
|
| 508 |
replace_background.load(inpainter=inpainter, high_res=high_res)
|
| 509 |
else:
|
| 510 |
if task.get_type() == TaskType.TILE_UPSCALE:
|
| 511 |
+
controlnet.load_model("tile_upscaler")
|
| 512 |
elif task.get_type() == TaskType.CANNY:
|
| 513 |
+
controlnet.load_model("canny")
|
| 514 |
elif task.get_type() == TaskType.SCRIBBLE:
|
| 515 |
+
controlnet.load_model("scribble")
|
| 516 |
elif task.get_type() == TaskType.LINEARART:
|
| 517 |
+
controlnet.load_model("linearart")
|
| 518 |
elif task.get_type() == TaskType.POSE:
|
| 519 |
+
controlnet.load_model("pose")
|
| 520 |
|
| 521 |
safety_checker.apply(controlnet)
|
| 522 |
|
|
|
|
| 590 |
else:
|
| 591 |
raise Exception("Invalid task type")
|
| 592 |
except Exception as e:
|
|
|
|
| 593 |
slack.error_alert(task, e)
|
| 594 |
controlnet.cleanup()
|
| 595 |
+
traceback.print_exc()
|
| 596 |
+
update_db_source_failed(task.get_sourceId(), task.get_userId())
|
| 597 |
return None
|
inference2.py
CHANGED
|
@@ -13,17 +13,19 @@ from internals.pipelines.img_to_text import Image2Text
|
|
| 13 |
from internals.pipelines.inpainter import InPainter
|
| 14 |
from internals.pipelines.object_remove import ObjectRemoval
|
| 15 |
from internals.pipelines.prompt_modifier import PromptModifier
|
| 16 |
-
from internals.pipelines.remove_background import
|
| 17 |
-
RemoveBackgroundV2)
|
| 18 |
from internals.pipelines.replace_background import ReplaceBackground
|
| 19 |
from internals.pipelines.safety_checker import SafetyChecker
|
| 20 |
from internals.pipelines.upscaler import Upscaler
|
| 21 |
from internals.util.avatar import Avatar
|
| 22 |
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
|
| 23 |
-
from internals.util.commons import
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
| 27 |
from internals.util.failure_hander import FailureHandler
|
| 28 |
from internals.util.lora_style import LoraStyle
|
| 29 |
from internals.util.model_loader import load_model_from_config
|
|
@@ -65,7 +67,7 @@ def tile_upscale(task: Task):
|
|
| 65 |
|
| 66 |
prompt = get_patched_prompt_tile_upscale(task)
|
| 67 |
|
| 68 |
-
controlnet.
|
| 69 |
|
| 70 |
lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
|
| 71 |
lora_patcher.patch()
|
|
@@ -98,7 +100,9 @@ def tile_upscale(task: Task):
|
|
| 98 |
@slack.auto_send_alert
|
| 99 |
def remove_bg(task: Task):
|
| 100 |
# remove_background = RemoveBackground()
|
| 101 |
-
output_image = remove_background_v2.remove(
|
|
|
|
|
|
|
| 102 |
|
| 103 |
output_key = "crecoAI/{}_rmbg.png".format(task.get_taskId())
|
| 104 |
upload_image(output_image, output_key)
|
|
@@ -173,6 +177,7 @@ def replace_bg(task: Task):
|
|
| 173 |
extend_object=task.rbg_extend_object(),
|
| 174 |
product_scale_width=task.get_image_scale(),
|
| 175 |
conditioning_scale=task.rbg_controlnet_conditioning_scale(),
|
|
|
|
| 176 |
)
|
| 177 |
|
| 178 |
generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())
|
|
@@ -231,6 +236,7 @@ def model_fn(model_dir):
|
|
| 231 |
upscaler.load()
|
| 232 |
inpainter.load()
|
| 233 |
high_res.load()
|
|
|
|
| 234 |
|
| 235 |
replace_background.load(
|
| 236 |
upscaler=upscaler, remove_background=remove_background_v2, high_res=high_res
|
|
@@ -242,7 +248,7 @@ def model_fn(model_dir):
|
|
| 242 |
|
| 243 |
def load_model_by_task(task: Task):
|
| 244 |
if task.get_type() == TaskType.TILE_UPSCALE:
|
| 245 |
-
controlnet.
|
| 246 |
|
| 247 |
safety_checker.apply(controlnet)
|
| 248 |
|
|
|
|
| 13 |
from internals.pipelines.inpainter import InPainter
|
| 14 |
from internals.pipelines.object_remove import ObjectRemoval
|
| 15 |
from internals.pipelines.prompt_modifier import PromptModifier
|
| 16 |
+
from internals.pipelines.remove_background import RemoveBackground, RemoveBackgroundV2
|
|
|
|
| 17 |
from internals.pipelines.replace_background import ReplaceBackground
|
| 18 |
from internals.pipelines.safety_checker import SafetyChecker
|
| 19 |
from internals.pipelines.upscaler import Upscaler
|
| 20 |
from internals.util.avatar import Avatar
|
| 21 |
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
|
| 22 |
+
from internals.util.commons import construct_default_s3_url, upload_image, upload_images
|
| 23 |
+
from internals.util.config import (
|
| 24 |
+
num_return_sequences,
|
| 25 |
+
set_configs_from_task,
|
| 26 |
+
set_model_config,
|
| 27 |
+
set_root_dir,
|
| 28 |
+
)
|
| 29 |
from internals.util.failure_hander import FailureHandler
|
| 30 |
from internals.util.lora_style import LoraStyle
|
| 31 |
from internals.util.model_loader import load_model_from_config
|
|
|
|
| 67 |
|
| 68 |
prompt = get_patched_prompt_tile_upscale(task)
|
| 69 |
|
| 70 |
+
controlnet.load_model("tile_upscaler")
|
| 71 |
|
| 72 |
lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
|
| 73 |
lora_patcher.patch()
|
|
|
|
| 100 |
@slack.auto_send_alert
|
| 101 |
def remove_bg(task: Task):
|
| 102 |
# remove_background = RemoveBackground()
|
| 103 |
+
output_image = remove_background_v2.remove(
|
| 104 |
+
task.get_imageUrl(), model_type=task.get_modelType()
|
| 105 |
+
)
|
| 106 |
|
| 107 |
output_key = "crecoAI/{}_rmbg.png".format(task.get_taskId())
|
| 108 |
upload_image(output_image, output_key)
|
|
|
|
| 177 |
extend_object=task.rbg_extend_object(),
|
| 178 |
product_scale_width=task.get_image_scale(),
|
| 179 |
conditioning_scale=task.rbg_controlnet_conditioning_scale(),
|
| 180 |
+
model_type=task.get_modelType(),
|
| 181 |
)
|
| 182 |
|
| 183 |
generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())
|
|
|
|
| 236 |
upscaler.load()
|
| 237 |
inpainter.load()
|
| 238 |
high_res.load()
|
| 239 |
+
controlnet.init(high_res)
|
| 240 |
|
| 241 |
replace_background.load(
|
| 242 |
upscaler=upscaler, remove_background=remove_background_v2, high_res=high_res
|
|
|
|
| 248 |
|
| 249 |
def load_model_by_task(task: Task):
|
| 250 |
if task.get_type() == TaskType.TILE_UPSCALE:
|
| 251 |
+
controlnet.load_model("tile_upscaler")
|
| 252 |
|
| 253 |
safety_checker.apply(controlnet)
|
| 254 |
|
internals/data/dataAccessor.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import traceback
|
| 2 |
from typing import Dict, List, Optional
|
| 3 |
|
|
|
|
| 4 |
import requests
|
| 5 |
from pydash import includes
|
| 6 |
|
|
@@ -9,6 +10,14 @@ from internals.util.config import api_endpoint, api_headers
|
|
| 9 |
from internals.util.slack import Slack
|
| 10 |
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
def updateSource(sourceId, userId, state):
|
| 13 |
print("update source is called")
|
| 14 |
url = api_endpoint() + f"/autodraft-crecoai/source/{sourceId}"
|
|
@@ -21,7 +30,8 @@ def updateSource(sourceId, userId, state):
|
|
| 21 |
data = {"state": state}
|
| 22 |
|
| 23 |
try:
|
| 24 |
-
|
|
|
|
| 25 |
print("update source response", response)
|
| 26 |
except requests.exceptions.Timeout:
|
| 27 |
print("Request timed out while updating source")
|
|
@@ -47,7 +57,8 @@ def saveGeneratedImages(sourceId, userId, has_nsfw: bool):
|
|
| 47 |
data = {"state": "ACTIVE", "has_nsfw": has_nsfw}
|
| 48 |
|
| 49 |
try:
|
| 50 |
-
|
|
|
|
| 51 |
# print("save generation response", response)
|
| 52 |
except requests.exceptions.Timeout:
|
| 53 |
print("Request timed out while saving image")
|
|
@@ -61,11 +72,12 @@ def getStyles() -> Optional[Dict]:
|
|
| 61 |
url = api_endpoint() + "/autodraft-crecoai/style"
|
| 62 |
print(url)
|
| 63 |
try:
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
| 69 |
return response.json()
|
| 70 |
except requests.exceptions.Timeout:
|
| 71 |
print("Request timed out while fetching styles")
|
|
@@ -78,9 +90,10 @@ def getStyles() -> Optional[Dict]:
|
|
| 78 |
def getCharacters(model_id: str) -> Optional[List]:
|
| 79 |
url = api_endpoint() + "/autodraft-crecoai/model/{}".format(model_id)
|
| 80 |
try:
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
| 84 |
return response
|
| 85 |
except requests.exceptions.Timeout:
|
| 86 |
print("Request timed out while fetching characters")
|
|
@@ -89,6 +102,10 @@ def getCharacters(model_id: str) -> Optional[List]:
|
|
| 89 |
return None
|
| 90 |
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
def update_db(func):
|
| 93 |
def caller(*args, **kwargs):
|
| 94 |
if type(args[0]) is not Task:
|
|
|
|
| 1 |
import traceback
|
| 2 |
from typing import Dict, List, Optional
|
| 3 |
|
| 4 |
+
from requests.adapters import Retry, HTTPAdapter
|
| 5 |
import requests
|
| 6 |
from pydash import includes
|
| 7 |
|
|
|
|
| 10 |
from internals.util.slack import Slack
|
| 11 |
|
| 12 |
|
| 13 |
+
class RetryRequest:
|
| 14 |
+
def __new__(cls):
|
| 15 |
+
obj = Retry(total=5, backoff_factor=2, status_forcelist=[500, 502, 503, 504])
|
| 16 |
+
session = requests.Session()
|
| 17 |
+
session.mount("https://", HTTPAdapter(max_retries=obj))
|
| 18 |
+
return session
|
| 19 |
+
|
| 20 |
+
|
| 21 |
def updateSource(sourceId, userId, state):
|
| 22 |
print("update source is called")
|
| 23 |
url = api_endpoint() + f"/autodraft-crecoai/source/{sourceId}"
|
|
|
|
| 30 |
data = {"state": state}
|
| 31 |
|
| 32 |
try:
|
| 33 |
+
with RetryRequest() as session:
|
| 34 |
+
response = session.patch(url, headers=headers, json=data, timeout=10)
|
| 35 |
print("update source response", response)
|
| 36 |
except requests.exceptions.Timeout:
|
| 37 |
print("Request timed out while updating source")
|
|
|
|
| 57 |
data = {"state": "ACTIVE", "has_nsfw": has_nsfw}
|
| 58 |
|
| 59 |
try:
|
| 60 |
+
with RetryRequest() as session:
|
| 61 |
+
session.patch(url, headers=headers, json=data)
|
| 62 |
# print("save generation response", response)
|
| 63 |
except requests.exceptions.Timeout:
|
| 64 |
print("Request timed out while saving image")
|
|
|
|
| 72 |
url = api_endpoint() + "/autodraft-crecoai/style"
|
| 73 |
print(url)
|
| 74 |
try:
|
| 75 |
+
with RetryRequest() as session:
|
| 76 |
+
response = session.get(
|
| 77 |
+
url,
|
| 78 |
+
timeout=10,
|
| 79 |
+
headers={"x-api-key": "kGyEMp)oHB(zf^E5>-{o]I%go", **api_headers()},
|
| 80 |
+
)
|
| 81 |
return response.json()
|
| 82 |
except requests.exceptions.Timeout:
|
| 83 |
print("Request timed out while fetching styles")
|
|
|
|
| 90 |
def getCharacters(model_id: str) -> Optional[List]:
|
| 91 |
url = api_endpoint() + "/autodraft-crecoai/model/{}".format(model_id)
|
| 92 |
try:
|
| 93 |
+
with RetryRequest() as session:
|
| 94 |
+
response = session.get(url, timeout=10, headers=api_headers())
|
| 95 |
+
response = response.json()
|
| 96 |
+
response = response["data"]["characters"]
|
| 97 |
return response
|
| 98 |
except requests.exceptions.Timeout:
|
| 99 |
print("Request timed out while fetching characters")
|
|
|
|
| 102 |
return None
|
| 103 |
|
| 104 |
|
| 105 |
+
def update_db_source_failed(sourceId, userId):
|
| 106 |
+
updateSource(sourceId, userId, "FAILED")
|
| 107 |
+
|
| 108 |
+
|
| 109 |
def update_db(func):
|
| 110 |
def caller(*args, **kwargs):
|
| 111 |
if type(args[0]) is not Task:
|
internals/pipelines/commons.py
CHANGED
|
@@ -2,12 +2,16 @@ from dataclasses import dataclass
|
|
| 2 |
from typing import Any, Callable, Dict, List, Optional, Union
|
| 3 |
|
| 4 |
import torch
|
| 5 |
-
from diffusers import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
from internals.data.result import Result
|
| 8 |
from internals.pipelines.twoStepPipeline import two_step_pipeline
|
| 9 |
from internals.util.commons import disable_safety_checker, download_image
|
| 10 |
-
from internals.util.config import get_hf_token, num_return_sequences
|
| 11 |
|
| 12 |
|
| 13 |
class AbstractPipeline:
|
|
@@ -27,9 +31,17 @@ class Text2Img(AbstractPipeline):
|
|
| 27 |
prompt_right: List[str] = None
|
| 28 |
|
| 29 |
def load(self, model_dir: str):
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
self.__patch()
|
| 34 |
|
| 35 |
def is_loaded(self):
|
|
@@ -38,10 +50,16 @@ class Text2Img(AbstractPipeline):
|
|
| 38 |
return False
|
| 39 |
|
| 40 |
def create(self, pipeline: AbstractPipeline):
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
| 42 |
self.__patch()
|
| 43 |
|
| 44 |
def __patch(self):
|
|
|
|
|
|
|
|
|
|
| 45 |
self.pipe.enable_xformers_memory_efficient_attention()
|
| 46 |
|
| 47 |
@torch.inference_mode()
|
|
@@ -92,9 +110,19 @@ class Text2Img(AbstractPipeline):
|
|
| 92 |
# two step pipeline
|
| 93 |
modified_prompt = params.modified_prompt
|
| 94 |
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
height=height,
|
| 99 |
width=width,
|
| 100 |
num_inference_steps=num_inference_steps,
|
|
@@ -111,7 +139,7 @@ class Text2Img(AbstractPipeline):
|
|
| 111 |
callback=callback,
|
| 112 |
callback_steps=callback_steps,
|
| 113 |
cross_attention_kwargs=cross_attention_kwargs,
|
| 114 |
-
|
| 115 |
)
|
| 116 |
|
| 117 |
return Result.from_result(result)
|
|
@@ -124,22 +152,38 @@ class Img2Img(AbstractPipeline):
|
|
| 124 |
if self.__loaded:
|
| 125 |
return
|
| 126 |
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
self.__patch()
|
| 131 |
|
| 132 |
self.__loaded = True
|
| 133 |
|
| 134 |
def create(self, pipeline: AbstractPipeline):
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
self.__patch()
|
| 139 |
|
| 140 |
self.__loaded = True
|
| 141 |
|
| 142 |
def __patch(self):
|
|
|
|
|
|
|
|
|
|
| 143 |
self.pipe.enable_xformers_memory_efficient_attention()
|
| 144 |
|
| 145 |
@torch.inference_mode()
|
|
|
|
| 2 |
from typing import Any, Callable, Dict, List, Optional, Union
|
| 3 |
|
| 4 |
import torch
|
| 5 |
+
from diffusers import (
|
| 6 |
+
StableDiffusionImg2ImgPipeline,
|
| 7 |
+
StableDiffusionXLPipeline,
|
| 8 |
+
StableDiffusionXLImg2ImgPipeline,
|
| 9 |
+
)
|
| 10 |
|
| 11 |
from internals.data.result import Result
|
| 12 |
from internals.pipelines.twoStepPipeline import two_step_pipeline
|
| 13 |
from internals.util.commons import disable_safety_checker, download_image
|
| 14 |
+
from internals.util.config import get_hf_token, num_return_sequences, get_is_sdxl
|
| 15 |
|
| 16 |
|
| 17 |
class AbstractPipeline:
|
|
|
|
| 31 |
prompt_right: List[str] = None
|
| 32 |
|
| 33 |
def load(self, model_dir: str):
|
| 34 |
+
if get_is_sdxl():
|
| 35 |
+
self.pipe = StableDiffusionXLPipeline.from_pretrained(
|
| 36 |
+
model_dir,
|
| 37 |
+
torch_dtype=torch.float16,
|
| 38 |
+
use_auth_token=get_hf_token(),
|
| 39 |
+
use_safetensors=True,
|
| 40 |
+
).to("cuda")
|
| 41 |
+
else:
|
| 42 |
+
self.pipe = two_step_pipeline.from_pretrained(
|
| 43 |
+
model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
|
| 44 |
+
).to("cuda")
|
| 45 |
self.__patch()
|
| 46 |
|
| 47 |
def is_loaded(self):
|
|
|
|
| 50 |
return False
|
| 51 |
|
| 52 |
def create(self, pipeline: AbstractPipeline):
|
| 53 |
+
if get_is_sdxl():
|
| 54 |
+
self.pipe = StableDiffusionXLPipeline(**pipeline.pipe.components).to("cuda")
|
| 55 |
+
else:
|
| 56 |
+
self.pipe = two_step_pipeline(**pipeline.pipe.components).to("cuda")
|
| 57 |
self.__patch()
|
| 58 |
|
| 59 |
def __patch(self):
|
| 60 |
+
if get_is_sdxl():
|
| 61 |
+
self.pipe.enable_vae_tiling()
|
| 62 |
+
self.pipe.enable_vae_slicing()
|
| 63 |
self.pipe.enable_xformers_memory_efficient_attention()
|
| 64 |
|
| 65 |
@torch.inference_mode()
|
|
|
|
| 110 |
# two step pipeline
|
| 111 |
modified_prompt = params.modified_prompt
|
| 112 |
|
| 113 |
+
if get_is_sdxl():
|
| 114 |
+
print("Warning: Two step pipeline is not supported on SDXL")
|
| 115 |
+
kwargs = {
|
| 116 |
+
"prompt": modified_prompt,
|
| 117 |
+
}
|
| 118 |
+
else:
|
| 119 |
+
kwargs = {
|
| 120 |
+
"prompt": prompt,
|
| 121 |
+
"modified_prompts": modified_prompt,
|
| 122 |
+
"iteration": iteration,
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
result = self.pipe.__call__(
|
| 126 |
height=height,
|
| 127 |
width=width,
|
| 128 |
num_inference_steps=num_inference_steps,
|
|
|
|
| 139 |
callback=callback,
|
| 140 |
callback_steps=callback_steps,
|
| 141 |
cross_attention_kwargs=cross_attention_kwargs,
|
| 142 |
+
**kwargs
|
| 143 |
)
|
| 144 |
|
| 145 |
return Result.from_result(result)
|
|
|
|
| 152 |
if self.__loaded:
|
| 153 |
return
|
| 154 |
|
| 155 |
+
if get_is_sdxl():
|
| 156 |
+
self.pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
| 157 |
+
model_dir,
|
| 158 |
+
torch_dtype=torch.float16,
|
| 159 |
+
use_auth_token=get_hf_token(),
|
| 160 |
+
use_safetensors=True,
|
| 161 |
+
).to("cuda")
|
| 162 |
+
else:
|
| 163 |
+
self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
| 164 |
+
model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
|
| 165 |
+
).to("cuda")
|
| 166 |
self.__patch()
|
| 167 |
|
| 168 |
self.__loaded = True
|
| 169 |
|
| 170 |
def create(self, pipeline: AbstractPipeline):
|
| 171 |
+
if get_is_sdxl():
|
| 172 |
+
self.pipe = StableDiffusionXLImg2ImgPipeline(**pipeline.pipe.components).to(
|
| 173 |
+
"cuda"
|
| 174 |
+
)
|
| 175 |
+
else:
|
| 176 |
+
self.pipe = StableDiffusionImg2ImgPipeline(**pipeline.pipe.components).to(
|
| 177 |
+
"cuda"
|
| 178 |
+
)
|
| 179 |
self.__patch()
|
| 180 |
|
| 181 |
self.__loaded = True
|
| 182 |
|
| 183 |
def __patch(self):
|
| 184 |
+
if get_is_sdxl():
|
| 185 |
+
self.pipe.enable_vae_tiling()
|
| 186 |
+
self.pipe.enable_vae_slicing()
|
| 187 |
self.pipe.enable_xformers_memory_efficient_attention()
|
| 188 |
|
| 189 |
@torch.inference_mode()
|
internals/pipelines/controlnets.py
CHANGED
|
@@ -1,14 +1,20 @@
|
|
| 1 |
-
from typing import List, Union
|
| 2 |
|
| 3 |
import cv2
|
| 4 |
import numpy as np
|
|
|
|
| 5 |
import torch
|
| 6 |
from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
|
| 7 |
-
from diffusers import (
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
from PIL import Image
|
| 13 |
from torch.nn import Linear
|
| 14 |
from tqdm import gui
|
|
@@ -18,156 +24,127 @@ import internals.util.image as ImageUtil
|
|
| 18 |
from external.midas import apply_midas
|
| 19 |
from internals.data.result import Result
|
| 20 |
from internals.pipelines.commons import AbstractPipeline
|
| 21 |
-
from internals.pipelines.tileUpscalePipeline import
|
| 22 |
-
StableDiffusionControlNetImg2ImgPipeline
|
|
|
|
| 23 |
from internals.util.cache import clear_cuda_and_gc
|
| 24 |
from internals.util.commons import download_image
|
| 25 |
-
from internals.util.config import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
|
| 28 |
class ControlNet(AbstractPipeline):
|
| 29 |
__current_task_name = ""
|
| 30 |
__loaded = False
|
| 31 |
|
| 32 |
-
|
| 33 |
-
"Should not be called externally"
|
| 34 |
-
if self.__loaded:
|
| 35 |
-
return
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
controlnet=self.controlnet,
|
| 44 |
-
torch_dtype=torch.float16,
|
| 45 |
-
use_auth_token=get_hf_token(),
|
| 46 |
-
cache_dir=get_hf_cache_dir(),
|
| 47 |
-
).to("cuda")
|
| 48 |
-
# pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
| 49 |
-
pipe.enable_model_cpu_offload()
|
| 50 |
-
pipe.enable_xformers_memory_efficient_attention()
|
| 51 |
-
self.pipe = pipe
|
| 52 |
-
|
| 53 |
-
# controlnet pipeline for canny and pose
|
| 54 |
-
pipe2 = StableDiffusionControlNetPipeline(**pipe.components).to("cuda")
|
| 55 |
-
pipe2.scheduler = UniPCMultistepScheduler.from_config(pipe2.scheduler.config)
|
| 56 |
-
pipe2.enable_xformers_memory_efficient_attention()
|
| 57 |
-
self.pipe2 = pipe2
|
| 58 |
-
|
| 59 |
-
self.__loaded = True
|
| 60 |
-
|
| 61 |
-
def load_canny(self):
|
| 62 |
-
if self.__current_task_name == "canny":
|
| 63 |
return
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
torch_dtype=torch.float16,
|
| 67 |
cache_dir=get_hf_cache_dir(),
|
| 68 |
).to("cuda")
|
| 69 |
-
self.__current_task_name =
|
| 70 |
-
self.controlnet =
|
| 71 |
|
| 72 |
-
self.
|
| 73 |
|
| 74 |
if hasattr(self, "pipe"):
|
| 75 |
-
self.pipe.controlnet =
|
| 76 |
if hasattr(self, "pipe2"):
|
| 77 |
-
self.pipe2.controlnet =
|
| 78 |
clear_cuda_and_gc()
|
| 79 |
|
| 80 |
-
def
|
| 81 |
-
|
|
|
|
| 82 |
return
|
| 83 |
-
pose = ControlNetModel.from_pretrained(
|
| 84 |
-
"lllyasviel/control_v11p_sd15_openpose",
|
| 85 |
-
torch_dtype=torch.float16,
|
| 86 |
-
cache_dir=get_hf_cache_dir(),
|
| 87 |
-
).to("cuda")
|
| 88 |
-
# lineart = ControlNetModel.from_pretrained(
|
| 89 |
-
# "ControlNet-1-1-preview/control_v11p_sd15_lineart",
|
| 90 |
-
# torch_dtype=torch.float16,
|
| 91 |
-
# cache_dir=get_hf_cache_dir(),
|
| 92 |
-
# ).to("cuda")
|
| 93 |
-
self.__current_task_name = "pose"
|
| 94 |
-
self.controlnet = MultiControlNetModel([pose]).to("cuda")
|
| 95 |
-
|
| 96 |
-
self.load()
|
| 97 |
|
| 98 |
-
if hasattr(self, "
|
| 99 |
-
self.
|
| 100 |
-
if hasattr(self, "pipe2"):
|
| 101 |
-
self.pipe2.controlnet = self.controlnet
|
| 102 |
-
clear_cuda_and_gc()
|
| 103 |
-
|
| 104 |
-
def load_tile_upscaler(self):
|
| 105 |
-
if self.__current_task_name == "tile_upscaler":
|
| 106 |
-
return
|
| 107 |
-
tile_upscaler = ControlNetModel.from_pretrained(
|
| 108 |
-
"lllyasviel/control_v11f1e_sd15_tile",
|
| 109 |
-
torch_dtype=torch.float16,
|
| 110 |
-
cache_dir=get_hf_cache_dir(),
|
| 111 |
-
).to("cuda")
|
| 112 |
-
self.__current_task_name = "tile_upscaler"
|
| 113 |
-
self.controlnet = tile_upscaler
|
| 114 |
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
-
|
| 118 |
-
self.pipe.controlnet = tile_upscaler
|
| 119 |
-
if hasattr(self, "pipe2"):
|
| 120 |
-
self.pipe2.controlnet = tile_upscaler
|
| 121 |
-
clear_cuda_and_gc()
|
| 122 |
|
| 123 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
if self.__current_task_name == "scribble":
|
| 125 |
-
return
|
| 126 |
-
scribble = ControlNetModel.from_pretrained(
|
| 127 |
-
"lllyasviel/control_v11p_sd15_scribble",
|
| 128 |
-
torch_dtype=torch.float16,
|
| 129 |
-
cache_dir=get_hf_cache_dir(),
|
| 130 |
-
).to("cuda")
|
| 131 |
-
self.__current_task_name = "scribble"
|
| 132 |
-
self.controlnet = scribble
|
| 133 |
-
|
| 134 |
-
self.load()
|
| 135 |
-
|
| 136 |
-
if hasattr(self, "pipe"):
|
| 137 |
-
self.pipe.controlnet = scribble
|
| 138 |
-
if hasattr(self, "pipe2"):
|
| 139 |
-
self.pipe2.controlnet = scribble
|
| 140 |
-
clear_cuda_and_gc()
|
| 141 |
-
|
| 142 |
-
def load_linearart(self):
|
| 143 |
if self.__current_task_name == "linearart":
|
| 144 |
-
return
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
cache_dir=get_hf_cache_dir(),
|
| 149 |
-
).to("cuda")
|
| 150 |
-
self.__current_task_name = "linearart"
|
| 151 |
-
self.controlnet = linearart
|
| 152 |
-
|
| 153 |
-
self.load()
|
| 154 |
-
|
| 155 |
-
if hasattr(self, "pipe"):
|
| 156 |
-
self.pipe.controlnet = linearart
|
| 157 |
-
if hasattr(self, "pipe2"):
|
| 158 |
-
self.pipe2.controlnet = linearart
|
| 159 |
-
clear_cuda_and_gc()
|
| 160 |
-
|
| 161 |
-
def cleanup(self):
|
| 162 |
-
if hasattr(self, "pipe"):
|
| 163 |
-
self.pipe.controlnet = None
|
| 164 |
-
if hasattr(self, "pipe2"):
|
| 165 |
-
self.pipe2.controlnet = None
|
| 166 |
-
self.controlnet = None
|
| 167 |
-
del self.controlnet
|
| 168 |
-
self.__current_task_name = ""
|
| 169 |
-
|
| 170 |
-
clear_cuda_and_gc()
|
| 171 |
|
| 172 |
@torch.inference_mode()
|
| 173 |
def process_canny(
|
|
@@ -228,7 +205,6 @@ class ControlNet(AbstractPipeline):
|
|
| 228 |
guidance_scale=guidance_scale,
|
| 229 |
height=height,
|
| 230 |
width=width,
|
| 231 |
-
controlnet_conditioning_scale=[1.0],
|
| 232 |
)
|
| 233 |
return Result.from_result(result)
|
| 234 |
|
|
@@ -333,6 +309,17 @@ class ControlNet(AbstractPipeline):
|
|
| 333 |
)
|
| 334 |
return Result.from_result(result)
|
| 335 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
def detect_pose(self, imageUrl: str) -> Image.Image:
|
| 337 |
detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
|
| 338 |
image = download_image(imageUrl)
|
|
@@ -381,3 +368,18 @@ class ControlNet(AbstractPipeline):
|
|
| 381 |
W = int(round(W / 64.0)) * 64
|
| 382 |
img = input_image.resize((W, H), resample=Image.LANCZOS)
|
| 383 |
return img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Literal, Union
|
| 2 |
|
| 3 |
import cv2
|
| 4 |
import numpy as np
|
| 5 |
+
from pydash import has
|
| 6 |
import torch
|
| 7 |
from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
|
| 8 |
+
from diffusers import (
|
| 9 |
+
ControlNetModel,
|
| 10 |
+
DiffusionPipeline,
|
| 11 |
+
StableDiffusionControlNetPipeline,
|
| 12 |
+
UniPCMultistepScheduler,
|
| 13 |
+
StableDiffusionXLControlNetPipeline,
|
| 14 |
+
)
|
| 15 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import (
|
| 16 |
+
MultiControlNetModel,
|
| 17 |
+
)
|
| 18 |
from PIL import Image
|
| 19 |
from torch.nn import Linear
|
| 20 |
from tqdm import gui
|
|
|
|
| 24 |
from external.midas import apply_midas
|
| 25 |
from internals.data.result import Result
|
| 26 |
from internals.pipelines.commons import AbstractPipeline
|
| 27 |
+
from internals.pipelines.tileUpscalePipeline import (
|
| 28 |
+
StableDiffusionControlNetImg2ImgPipeline,
|
| 29 |
+
)
|
| 30 |
from internals.util.cache import clear_cuda_and_gc
|
| 31 |
from internals.util.commons import download_image
|
| 32 |
+
from internals.util.config import (
|
| 33 |
+
get_hf_cache_dir,
|
| 34 |
+
get_hf_token,
|
| 35 |
+
get_model_dir,
|
| 36 |
+
get_is_sdxl,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
CONTROLNET_TYPES = Literal["pose", "canny", "scribble", "linearart", "tile_upscaler"]
|
| 41 |
|
| 42 |
|
| 43 |
class ControlNet(AbstractPipeline):
|
| 44 |
__current_task_name = ""
|
| 45 |
__loaded = False
|
| 46 |
|
| 47 |
+
__pipeline: AbstractPipeline
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
+
def init(self, pipeline: AbstractPipeline):
|
| 50 |
+
self.__pipeline = pipeline
|
| 51 |
|
| 52 |
+
def load_model(self, task_name: CONTROLNET_TYPES):
|
| 53 |
+
config = self.__model_sdxl if get_is_sdxl() else self.__model_normal
|
| 54 |
+
if self.__current_task_name == task_name:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
return
|
| 56 |
+
model = config[task_name]
|
| 57 |
+
if not model:
|
| 58 |
+
raise Exception(f"ControlNet is not supported for {task_name}")
|
| 59 |
+
while model in list(config.keys()):
|
| 60 |
+
task_name = config[model] # pyright: ignore
|
| 61 |
+
model = config[task_name]
|
| 62 |
+
|
| 63 |
+
controlnet = ControlNetModel.from_pretrained(
|
| 64 |
+
model,
|
| 65 |
torch_dtype=torch.float16,
|
| 66 |
cache_dir=get_hf_cache_dir(),
|
| 67 |
).to("cuda")
|
| 68 |
+
self.__current_task_name = task_name
|
| 69 |
+
self.controlnet = controlnet
|
| 70 |
|
| 71 |
+
self.__load()
|
| 72 |
|
| 73 |
if hasattr(self, "pipe"):
|
| 74 |
+
self.pipe.controlnet = controlnet
|
| 75 |
if hasattr(self, "pipe2"):
|
| 76 |
+
self.pipe2.controlnet = controlnet
|
| 77 |
clear_cuda_and_gc()
|
| 78 |
|
| 79 |
+
def __load(self):
|
| 80 |
+
"Should not be called externally"
|
| 81 |
+
if self.__loaded:
|
| 82 |
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
+
if not hasattr(self, "controlnet"):
|
| 85 |
+
self.load_model("pose")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
+
# controlnet pipeline for tile upscaler
|
| 88 |
+
if get_is_sdxl():
|
| 89 |
+
print("Warning: Tile upscale is not supported on SDXL")
|
| 90 |
+
|
| 91 |
+
if self.__pipeline:
|
| 92 |
+
pipe = StableDiffusionXLControlNetPipeline(
|
| 93 |
+
controlnet=self.controlnet, **self.__pipeline.pipe.components
|
| 94 |
+
).to("cuda")
|
| 95 |
+
else:
|
| 96 |
+
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
|
| 97 |
+
get_model_dir(),
|
| 98 |
+
controlnet=self.controlnet,
|
| 99 |
+
torch_dtype=torch.float16,
|
| 100 |
+
use_auth_token=get_hf_token(),
|
| 101 |
+
cache_dir=get_hf_cache_dir(),
|
| 102 |
+
use_safetensors=True,
|
| 103 |
+
).to("cuda")
|
| 104 |
+
pipe.enable_vae_tiling()
|
| 105 |
+
pipe.enable_vae_slicing()
|
| 106 |
+
pipe.enable_xformers_memory_efficient_attention()
|
| 107 |
+
self.pipe2 = pipe
|
| 108 |
+
else:
|
| 109 |
+
if hasattr(self, "__pipeline"):
|
| 110 |
+
pipe = StableDiffusionControlNetImg2ImgPipeline(
|
| 111 |
+
controlnet=self.controlnet, **self.__pipeline.pipe.components
|
| 112 |
+
).to("cuda")
|
| 113 |
+
else:
|
| 114 |
+
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
| 115 |
+
get_model_dir(),
|
| 116 |
+
controlnet=self.controlnet,
|
| 117 |
+
torch_dtype=torch.float16,
|
| 118 |
+
use_auth_token=get_hf_token(),
|
| 119 |
+
cache_dir=get_hf_cache_dir(),
|
| 120 |
+
).to("cuda")
|
| 121 |
+
# pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
| 122 |
+
pipe.enable_model_cpu_offload()
|
| 123 |
+
pipe.enable_xformers_memory_efficient_attention()
|
| 124 |
+
self.pipe = pipe
|
| 125 |
+
|
| 126 |
+
# controlnet pipeline for canny and pose
|
| 127 |
+
pipe2 = StableDiffusionControlNetPipeline(**pipe.components).to("cuda")
|
| 128 |
+
pipe2.scheduler = UniPCMultistepScheduler.from_config(
|
| 129 |
+
pipe2.scheduler.config
|
| 130 |
+
)
|
| 131 |
+
pipe2.enable_xformers_memory_efficient_attention()
|
| 132 |
+
self.pipe2 = pipe2
|
| 133 |
|
| 134 |
+
self.__loaded = True
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
+
def process(self, **kwargs):
|
| 137 |
+
if self.__current_task_name == "pose":
|
| 138 |
+
return self.process_pose(**kwargs)
|
| 139 |
+
if self.__current_task_name == "canny":
|
| 140 |
+
return self.process_canny(**kwargs)
|
| 141 |
if self.__current_task_name == "scribble":
|
| 142 |
+
return self.process_scribble(**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
if self.__current_task_name == "linearart":
|
| 144 |
+
return self.process_linearart(**kwargs)
|
| 145 |
+
if self.__current_task_name == "tile_upscaler":
|
| 146 |
+
return self.process_tile_upscaler(**kwargs)
|
| 147 |
+
raise Exception("ControlNet is not loaded with any model")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
@torch.inference_mode()
|
| 150 |
def process_canny(
|
|
|
|
| 205 |
guidance_scale=guidance_scale,
|
| 206 |
height=height,
|
| 207 |
width=width,
|
|
|
|
| 208 |
)
|
| 209 |
return Result.from_result(result)
|
| 210 |
|
|
|
|
| 309 |
)
|
| 310 |
return Result.from_result(result)
|
| 311 |
|
| 312 |
+
def cleanup(self):
|
| 313 |
+
if hasattr(self, "pipe") and hasattr(self.pipe, "controlnet"):
|
| 314 |
+
del self.pipe.controlnet
|
| 315 |
+
if hasattr(self, "pipe2") and hasattr(self.pipe2, "controlnet"):
|
| 316 |
+
del self.pipe2.controlnet
|
| 317 |
+
if hasattr(self, "controlnet"):
|
| 318 |
+
del self.controlnet
|
| 319 |
+
self.__current_task_name = ""
|
| 320 |
+
|
| 321 |
+
clear_cuda_and_gc()
|
| 322 |
+
|
| 323 |
def detect_pose(self, imageUrl: str) -> Image.Image:
|
| 324 |
detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
|
| 325 |
image = download_image(imageUrl)
|
|
|
|
| 368 |
W = int(round(W / 64.0)) * 64
|
| 369 |
img = input_image.resize((W, H), resample=Image.LANCZOS)
|
| 370 |
return img
|
| 371 |
+
|
| 372 |
+
__model_normal = {
|
| 373 |
+
"pose": "lllyasviel/control_v11p_sd15_openpose",
|
| 374 |
+
"canny": "lllyasviel/control_v11p_sd15_canny",
|
| 375 |
+
"linearart": "lllyasviel/control_v11p_sd15_lineart",
|
| 376 |
+
"scribble": "lllyasviel/control_v11p_sd15_scribble",
|
| 377 |
+
"tile_upscaler": "lllyasviel/control_v11f1e_sd15_tile",
|
| 378 |
+
}
|
| 379 |
+
__model_sdxl = {
|
| 380 |
+
"pose": "thibaud/controlnet-openpose-sdxl-1.0",
|
| 381 |
+
"canny": "diffusers/controlnet-canny-sdxl-1.0",
|
| 382 |
+
"linearart": "canny",
|
| 383 |
+
"scribble": "canny",
|
| 384 |
+
"tile_upscaler": None,
|
| 385 |
+
}
|
internals/pipelines/high_res.py
CHANGED
|
@@ -42,7 +42,7 @@ class HighRes(AbstractPipeline):
|
|
| 42 |
|
| 43 |
@staticmethod
|
| 44 |
def get_intermediate_dimension(target_width: int, target_height: int):
|
| 45 |
-
def_size =
|
| 46 |
|
| 47 |
desired_pixel_count = def_size * def_size
|
| 48 |
actual_pixel_count = target_width * target_height
|
|
|
|
| 42 |
|
| 43 |
@staticmethod
|
| 44 |
def get_intermediate_dimension(target_width: int, target_height: int):
|
| 45 |
+
def_size = 1024
|
| 46 |
|
| 47 |
desired_pixel_count = def_size * def_size
|
| 48 |
actual_pixel_count = target_width * target_height
|
internals/pipelines/inpainter.py
CHANGED
|
@@ -1,38 +1,74 @@
|
|
| 1 |
from typing import List, Union
|
| 2 |
|
| 3 |
import torch
|
| 4 |
-
from diffusers import StableDiffusionInpaintPipeline
|
| 5 |
|
| 6 |
from internals.pipelines.commons import AbstractPipeline
|
| 7 |
from internals.util.commons import disable_safety_checker, download_image
|
| 8 |
-
from internals.util.config import (
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
class InPainter(AbstractPipeline):
|
| 13 |
__loaded = False
|
| 14 |
|
|
|
|
|
|
|
|
|
|
| 15 |
def load(self):
|
| 16 |
if self.__loaded:
|
| 17 |
return
|
| 18 |
|
| 19 |
-
self
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
disable_safety_checker(self.pipe)
|
| 27 |
|
|
|
|
|
|
|
| 28 |
self.__loaded = True
|
| 29 |
|
| 30 |
def create(self, pipeline: AbstractPipeline):
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
disable_safety_checker(self.pipe)
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
@torch.inference_mode()
|
| 37 |
def process(
|
| 38 |
self,
|
|
|
|
| 1 |
from typing import List, Union
|
| 2 |
|
| 3 |
import torch
|
| 4 |
+
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionXLInpaintPipeline
|
| 5 |
|
| 6 |
from internals.pipelines.commons import AbstractPipeline
|
| 7 |
from internals.util.commons import disable_safety_checker, download_image
|
| 8 |
+
from internals.util.config import (
|
| 9 |
+
get_hf_cache_dir,
|
| 10 |
+
get_hf_token,
|
| 11 |
+
get_is_sdxl,
|
| 12 |
+
get_inpaint_model_path,
|
| 13 |
+
get_model_dir,
|
| 14 |
+
)
|
| 15 |
|
| 16 |
|
| 17 |
class InPainter(AbstractPipeline):
|
| 18 |
__loaded = False
|
| 19 |
|
| 20 |
+
def init(self, pipeline: AbstractPipeline):
|
| 21 |
+
self.__base = pipeline
|
| 22 |
+
|
| 23 |
def load(self):
|
| 24 |
if self.__loaded:
|
| 25 |
return
|
| 26 |
|
| 27 |
+
if hasattr(self, "__base") and get_inpaint_model_path() == get_model_dir():
|
| 28 |
+
self.create(self.__base)
|
| 29 |
+
self.__loaded = True
|
| 30 |
+
return
|
| 31 |
+
|
| 32 |
+
if get_is_sdxl():
|
| 33 |
+
self.pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
|
| 34 |
+
get_inpaint_model_path(),
|
| 35 |
+
torch_dtype=torch.float16,
|
| 36 |
+
cache_dir=get_hf_cache_dir(),
|
| 37 |
+
use_auth_token=get_hf_token(),
|
| 38 |
+
).to("cuda")
|
| 39 |
+
else:
|
| 40 |
+
self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
| 41 |
+
get_inpaint_model_path(),
|
| 42 |
+
torch_dtype=torch.float16,
|
| 43 |
+
cache_dir=get_hf_cache_dir(),
|
| 44 |
+
use_auth_token=get_hf_token(),
|
| 45 |
+
).to("cuda")
|
| 46 |
|
| 47 |
disable_safety_checker(self.pipe)
|
| 48 |
|
| 49 |
+
self.__patch()
|
| 50 |
+
|
| 51 |
self.__loaded = True
|
| 52 |
|
| 53 |
def create(self, pipeline: AbstractPipeline):
|
| 54 |
+
if get_is_sdxl():
|
| 55 |
+
self.pipe = StableDiffusionXLInpaintPipeline(**pipeline.pipe.components).to(
|
| 56 |
+
"cuda"
|
| 57 |
+
)
|
| 58 |
+
else:
|
| 59 |
+
self.pipe = StableDiffusionInpaintPipeline(**pipeline.pipe.components).to(
|
| 60 |
+
"cuda"
|
| 61 |
+
)
|
| 62 |
disable_safety_checker(self.pipe)
|
| 63 |
|
| 64 |
+
self.__patch()
|
| 65 |
+
|
| 66 |
+
def __patch(self):
|
| 67 |
+
if get_is_sdxl():
|
| 68 |
+
self.pipe.enable_vae_tiling()
|
| 69 |
+
self.pipe.enable_vae_slicing()
|
| 70 |
+
self.pipe.enable_xformers_memory_efficient_attention()
|
| 71 |
+
|
| 72 |
@torch.inference_mode()
|
| 73 |
def process(
|
| 74 |
self,
|
internals/pipelines/remove_background.py
CHANGED
|
@@ -1,15 +1,20 @@
|
|
| 1 |
import io
|
| 2 |
from pathlib import Path
|
| 3 |
from typing import Union
|
|
|
|
|
|
|
| 4 |
|
| 5 |
import torch
|
| 6 |
import torch.nn.functional as F
|
| 7 |
from PIL import Image
|
| 8 |
from rembg import remove
|
|
|
|
| 9 |
|
| 10 |
import internals.util.image as ImageUtil
|
| 11 |
from carvekit.api.high import HiInterface
|
| 12 |
from internals.util.commons import download_image, read_url
|
|
|
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
class RemoveBackground:
|
|
@@ -23,6 +28,11 @@ class RemoveBackground:
|
|
| 23 |
|
| 24 |
class RemoveBackgroundV2:
|
| 25 |
def __init__(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
self.interface = HiInterface(
|
| 27 |
object_type="object", # Can be "object" or "hairs-like".
|
| 28 |
batch_size_seg=5,
|
|
@@ -36,16 +46,51 @@ class RemoveBackgroundV2:
|
|
| 36 |
fp16=False,
|
| 37 |
)
|
| 38 |
|
| 39 |
-
def remove(
|
| 40 |
-
|
|
|
|
| 41 |
if type(image) is str:
|
| 42 |
image = download_image(image)
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import io
|
| 2 |
from pathlib import Path
|
| 3 |
from typing import Union
|
| 4 |
+
import numpy as np
|
| 5 |
+
import cv2
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import torch.nn.functional as F
|
| 9 |
from PIL import Image
|
| 10 |
from rembg import remove
|
| 11 |
+
from internals.data.task import ModelType
|
| 12 |
|
| 13 |
import internals.util.image as ImageUtil
|
| 14 |
from carvekit.api.high import HiInterface
|
| 15 |
from internals.util.commons import download_image, read_url
|
| 16 |
+
import onnxruntime as rt
|
| 17 |
+
import huggingface_hub
|
| 18 |
|
| 19 |
|
| 20 |
class RemoveBackground:
|
|
|
|
| 28 |
|
| 29 |
class RemoveBackgroundV2:
|
| 30 |
def __init__(self):
|
| 31 |
+
model_path = huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.onnx")
|
| 32 |
+
self.anime_rembg = rt.InferenceSession(
|
| 33 |
+
model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
self.interface = HiInterface(
|
| 37 |
object_type="object", # Can be "object" or "hairs-like".
|
| 38 |
batch_size_seg=5,
|
|
|
|
| 46 |
fp16=False,
|
| 47 |
)
|
| 48 |
|
| 49 |
+
def remove(
|
| 50 |
+
self, image: Union[str, Image.Image], model_type: ModelType = ModelType.REAL
|
| 51 |
+
) -> Image.Image:
|
| 52 |
if type(image) is str:
|
| 53 |
image = download_image(image)
|
| 54 |
|
| 55 |
+
if model_type == ModelType.ANIME or model_type == ModelType.COMIC:
|
| 56 |
+
print("Using Anime Background remover")
|
| 57 |
+
_, img = self.__rmbg_fn(np.array(image))
|
| 58 |
+
|
| 59 |
+
return Image.fromarray(img)
|
| 60 |
+
else:
|
| 61 |
+
print("Using Real Background remover")
|
| 62 |
+
img_path = Path.home() / ".cache" / "rm_bg.png"
|
| 63 |
+
|
| 64 |
+
w, h = image.size
|
| 65 |
+
if max(w, h) > 1536:
|
| 66 |
+
image = ImageUtil.resize_image(image, dimension=1024)
|
| 67 |
+
|
| 68 |
+
image.save(img_path)
|
| 69 |
+
images_without_background = self.interface([img_path])
|
| 70 |
+
out = images_without_background[0]
|
| 71 |
+
return out
|
| 72 |
+
|
| 73 |
+
def __get_mask(self, img, s=1024):
|
| 74 |
+
img = (img / 255).astype(np.float32)
|
| 75 |
+
h, w = h0, w0 = img.shape[:-1]
|
| 76 |
+
h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
|
| 77 |
+
ph, pw = s - h, s - w
|
| 78 |
+
img_input = np.zeros([s, s, 3], dtype=np.float32)
|
| 79 |
+
img_input[ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w] = cv2.resize(
|
| 80 |
+
img, (w, h)
|
| 81 |
+
)
|
| 82 |
+
img_input = np.transpose(img_input, (2, 0, 1))
|
| 83 |
+
img_input = img_input[np.newaxis, :]
|
| 84 |
+
mask = self.anime_rembg.run(None, {"img": img_input})[0][0]
|
| 85 |
+
mask = np.transpose(mask, (1, 2, 0))
|
| 86 |
+
mask = mask[ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w]
|
| 87 |
+
mask = cv2.resize(mask, (w0, h0))[:, :, np.newaxis]
|
| 88 |
+
return mask
|
| 89 |
|
| 90 |
+
def __rmbg_fn(self, img):
|
| 91 |
+
mask = self.__get_mask(img)
|
| 92 |
+
img = (mask * img + 255 * (1 - mask)).astype(np.uint8)
|
| 93 |
+
mask = (mask * 255).astype(np.uint8)
|
| 94 |
+
img = np.concatenate([img, mask], axis=2, dtype=np.uint8)
|
| 95 |
+
mask = mask.repeat(3, axis=2)
|
| 96 |
+
return mask, img
|
internals/pipelines/replace_background.py
CHANGED
|
@@ -3,10 +3,14 @@ from typing import List, Optional, Union
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
from cv2 import inpaint
|
| 6 |
-
from diffusers import (
|
| 7 |
-
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
| 9 |
from PIL import Image, ImageFilter, ImageOps
|
|
|
|
| 10 |
|
| 11 |
import internals.util.image as ImageUtil
|
| 12 |
from internals.data.result import Result
|
|
@@ -17,8 +21,12 @@ from internals.pipelines.inpainter import InPainter
|
|
| 17 |
from internals.pipelines.remove_background import RemoveBackgroundV2
|
| 18 |
from internals.pipelines.upscaler import Upscaler
|
| 19 |
from internals.util.commons import download_image
|
| 20 |
-
from internals.util.config import (
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
class ReplaceBackground(AbstractPipeline):
|
|
@@ -52,7 +60,8 @@ class ReplaceBackground(AbstractPipeline):
|
|
| 52 |
cache_dir=get_hf_cache_dir(),
|
| 53 |
use_auth_token=get_hf_token(),
|
| 54 |
)
|
| 55 |
-
pipe.
|
|
|
|
| 56 |
pipe.to("cuda")
|
| 57 |
|
| 58 |
self.pipe = pipe
|
|
@@ -87,6 +96,7 @@ class ReplaceBackground(AbstractPipeline):
|
|
| 87 |
seed: int,
|
| 88 |
steps: int,
|
| 89 |
apply_high_res: bool = False,
|
|
|
|
| 90 |
):
|
| 91 |
# image = Image.open("original.png")
|
| 92 |
if type(image) is str:
|
|
@@ -98,7 +108,7 @@ class ReplaceBackground(AbstractPipeline):
|
|
| 98 |
image = image.convert("RGB")
|
| 99 |
if max(image.size) > 1024:
|
| 100 |
image = ImageUtil.resize_image(image, dimension=1024)
|
| 101 |
-
image = self.remove_background.remove(image)
|
| 102 |
|
| 103 |
width = int(width)
|
| 104 |
height = int(height)
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
from cv2 import inpaint
|
| 6 |
+
from diffusers import (
|
| 7 |
+
ControlNetModel,
|
| 8 |
+
StableDiffusionControlNetInpaintPipeline,
|
| 9 |
+
StableDiffusionInpaintPipeline,
|
| 10 |
+
UniPCMultistepScheduler,
|
| 11 |
+
)
|
| 12 |
from PIL import Image, ImageFilter, ImageOps
|
| 13 |
+
from internals.data.task import ModelType
|
| 14 |
|
| 15 |
import internals.util.image as ImageUtil
|
| 16 |
from internals.data.result import Result
|
|
|
|
| 21 |
from internals.pipelines.remove_background import RemoveBackgroundV2
|
| 22 |
from internals.pipelines.upscaler import Upscaler
|
| 23 |
from internals.util.commons import download_image
|
| 24 |
+
from internals.util.config import (
|
| 25 |
+
get_hf_cache_dir,
|
| 26 |
+
get_hf_token,
|
| 27 |
+
get_inpaint_model_path,
|
| 28 |
+
get_model_dir,
|
| 29 |
+
)
|
| 30 |
|
| 31 |
|
| 32 |
class ReplaceBackground(AbstractPipeline):
|
|
|
|
| 60 |
cache_dir=get_hf_cache_dir(),
|
| 61 |
use_auth_token=get_hf_token(),
|
| 62 |
)
|
| 63 |
+
pipe.enable_xformers_memory_efficient_attention()
|
| 64 |
+
pipe.enable_vae_slicing()
|
| 65 |
pipe.to("cuda")
|
| 66 |
|
| 67 |
self.pipe = pipe
|
|
|
|
| 96 |
seed: int,
|
| 97 |
steps: int,
|
| 98 |
apply_high_res: bool = False,
|
| 99 |
+
model_type: ModelType = ModelType.REAL,
|
| 100 |
):
|
| 101 |
# image = Image.open("original.png")
|
| 102 |
if type(image) is str:
|
|
|
|
| 108 |
image = image.convert("RGB")
|
| 109 |
if max(image.size) > 1024:
|
| 110 |
image = ImageUtil.resize_image(image, dimension=1024)
|
| 111 |
+
image = self.remove_background.remove(image, model_type=model_type)
|
| 112 |
|
| 113 |
width = int(width)
|
| 114 |
height = int(height)
|
internals/pipelines/twoStepPipeline.py
CHANGED
|
@@ -12,7 +12,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
|
| 12 |
|
| 13 |
class two_step_pipeline(StableDiffusionPipeline):
|
| 14 |
@torch.no_grad()
|
| 15 |
-
def
|
| 16 |
self,
|
| 17 |
prompt: Union[str, List[str]] = None,
|
| 18 |
modified_prompts: Union[str, List[str]] = None,
|
|
|
|
| 12 |
|
| 13 |
class two_step_pipeline(StableDiffusionPipeline):
|
| 14 |
@torch.no_grad()
|
| 15 |
+
def __call__(
|
| 16 |
self,
|
| 17 |
prompt: Union[str, List[str]] = None,
|
| 18 |
modified_prompts: Union[str, List[str]] = None,
|
internals/util/cache.py
CHANGED
|
@@ -1,15 +1,25 @@
|
|
| 1 |
import gc
|
| 2 |
-
|
|
|
|
| 3 |
import torch
|
| 4 |
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
def clear_cuda_and_gc():
|
| 7 |
-
|
|
|
|
| 8 |
clear_gc()
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
def clear_cuda():
|
| 12 |
-
torch.
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
def clear_gc():
|
|
|
|
| 1 |
import gc
|
| 2 |
+
import os
|
| 3 |
+
import psutil
|
| 4 |
import torch
|
| 5 |
|
| 6 |
|
| 7 |
+
def print_memory_usage():
|
| 8 |
+
process = psutil.Process(os.getpid())
|
| 9 |
+
print(f"Memory usage: {process.memory_info().rss / 1024 ** 2:2f} MB")
|
| 10 |
+
|
| 11 |
+
|
| 12 |
def clear_cuda_and_gc():
|
| 13 |
+
print_memory_usage()
|
| 14 |
+
print("Clearing cuda and gc")
|
| 15 |
clear_gc()
|
| 16 |
+
clear_cuda()
|
| 17 |
+
print_memory_usage()
|
| 18 |
|
| 19 |
|
| 20 |
def clear_cuda():
|
| 21 |
+
with torch.no_grad():
|
| 22 |
+
torch.cuda.empty_cache()
|
| 23 |
|
| 24 |
|
| 25 |
def clear_gc():
|
internals/util/commons.py
CHANGED
|
@@ -150,9 +150,9 @@ def upload_image(image: Union[Image.Image, BytesIO], out_path):
|
|
| 150 |
return image_url
|
| 151 |
|
| 152 |
|
| 153 |
-
def download_image(url) -> Image.Image:
|
| 154 |
response = requests.get(url)
|
| 155 |
-
return Image.open(BytesIO(response.content)).convert(
|
| 156 |
|
| 157 |
|
| 158 |
def download_file(url, out_path: Path):
|
|
|
|
| 150 |
return image_url
|
| 151 |
|
| 152 |
|
| 153 |
+
def download_image(url, mode="RGB") -> Image.Image:
|
| 154 |
response = requests.get(url)
|
| 155 |
+
return Image.open(BytesIO(response.content)).convert(mode)
|
| 156 |
|
| 157 |
|
| 158 |
def download_file(url, out_path: Path):
|
internals/util/config.py
CHANGED
|
@@ -61,6 +61,11 @@ def get_inpaint_model_path():
|
|
| 61 |
return model_config.base_inpaint_model_path # pyright: ignore
|
| 62 |
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
def get_root_dir():
|
| 65 |
global root_dir
|
| 66 |
return root_dir
|
|
|
|
| 61 |
return model_config.base_inpaint_model_path # pyright: ignore
|
| 62 |
|
| 63 |
|
| 64 |
+
def get_is_sdxl():
|
| 65 |
+
global model_config
|
| 66 |
+
return model_config.is_sdxl # pyright: ignore
|
| 67 |
+
|
| 68 |
+
|
| 69 |
def get_root_dir():
|
| 70 |
global root_dir
|
| 71 |
return root_dir
|
internals/util/lora_style.py
CHANGED
|
@@ -10,6 +10,7 @@ from lora_diffusion import patch_pipe, tune_lora_scale
|
|
| 10 |
from pydash import chain
|
| 11 |
|
| 12 |
from internals.data.dataAccessor import getStyles
|
|
|
|
| 13 |
from internals.util.commons import download_file
|
| 14 |
|
| 15 |
|
|
@@ -112,6 +113,10 @@ class LoraStyle:
|
|
| 112 |
) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
|
| 113 |
"Returns a lora patcher for the given `key` and `pipe`. `pipe` can also be a list of pipes"
|
| 114 |
pipe = [pipe] if not isinstance(pipe, list) else pipe
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
if key in self.__styles:
|
| 116 |
style = self.__styles[key]
|
| 117 |
if style["type"] == "diffuser":
|
|
|
|
| 10 |
from pydash import chain
|
| 11 |
|
| 12 |
from internals.data.dataAccessor import getStyles
|
| 13 |
+
from internals.util.config import get_is_sdxl
|
| 14 |
from internals.util.commons import download_file
|
| 15 |
|
| 16 |
|
|
|
|
| 113 |
) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
|
| 114 |
"Returns a lora patcher for the given `key` and `pipe`. `pipe` can also be a list of pipes"
|
| 115 |
pipe = [pipe] if not isinstance(pipe, list) else pipe
|
| 116 |
+
if get_is_sdxl():
|
| 117 |
+
print("Warning: Lora is not supported on SDXL")
|
| 118 |
+
return self.EmptyLoraPatcher(pipe)
|
| 119 |
+
|
| 120 |
if key in self.__styles:
|
| 121 |
style = self.__styles[key]
|
| 122 |
if style["type"] == "diffuser":
|
internals/util/model_loader.py
CHANGED
|
@@ -14,6 +14,7 @@ from tqdm import tqdm
|
|
| 14 |
class ModelConfig:
|
| 15 |
base_model_path: str
|
| 16 |
base_inpaint_model_path: str
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
def load_model_from_config(path):
|
|
@@ -23,9 +24,11 @@ def load_model_from_config(path):
|
|
| 23 |
config = json.loads(f.read())
|
| 24 |
model_path = config.get("model_path", path)
|
| 25 |
inpaint_model_path = config.get("inpaint_model_path", path)
|
|
|
|
| 26 |
|
| 27 |
m_config.base_model_path = model_path
|
| 28 |
m_config.base_inpaint_model_path = inpaint_model_path
|
|
|
|
| 29 |
|
| 30 |
#
|
| 31 |
# if config.get("model_type") == "huggingface":
|
|
|
|
| 14 |
class ModelConfig:
|
| 15 |
base_model_path: str
|
| 16 |
base_inpaint_model_path: str
|
| 17 |
+
is_sdxl: bool = False
|
| 18 |
|
| 19 |
|
| 20 |
def load_model_from_config(path):
|
|
|
|
| 24 |
config = json.loads(f.read())
|
| 25 |
model_path = config.get("model_path", path)
|
| 26 |
inpaint_model_path = config.get("inpaint_model_path", path)
|
| 27 |
+
is_sdxl = config.get("is_sdxl", False)
|
| 28 |
|
| 29 |
m_config.base_model_path = model_path
|
| 30 |
m_config.base_inpaint_model_path = inpaint_model_path
|
| 31 |
+
m_config.is_sdxl = is_sdxl
|
| 32 |
|
| 33 |
#
|
| 34 |
# if config.get("model_type") == "huggingface":
|
pyproject.toml
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
[tool.pyright]
|
| 2 |
-
venvPath = "
|
| 3 |
venv = "env"
|
| 4 |
exclude = ["env"]
|
|
|
|
| 1 |
[tool.pyright]
|
| 2 |
+
venvPath = "."
|
| 3 |
venv = "env"
|
| 4 |
exclude = ["env"]
|
requirements.txt
CHANGED
|
@@ -15,6 +15,7 @@ realesrgan==0.3.0
|
|
| 15 |
compel==1.0.4
|
| 16 |
scikit-image>=0.19.3
|
| 17 |
six==1.16.0
|
|
|
|
| 18 |
tifffile==2021.8.30
|
| 19 |
easydict==1.9.0
|
| 20 |
albumentations
|
|
@@ -32,10 +33,13 @@ xformers==0.0.21
|
|
| 32 |
scikit-image==0.19.3
|
| 33 |
omegaconf==2.3.0
|
| 34 |
webdataset==0.2.48
|
|
|
|
| 35 |
https://comic-assets.s3.ap-south-1.amazonaws.com/packages/mmcv_full-1.7.0-cp39-cp39-linux_x86_64.whl
|
| 36 |
python-dateutil==2.8.2
|
| 37 |
PyYAML
|
| 38 |
invisible-watermark
|
| 39 |
torchvision==0.15.2
|
|
|
|
|
|
|
| 40 |
imgaug==0.4.0
|
| 41 |
tqdm==4.64.1
|
|
|
|
| 15 |
compel==1.0.4
|
| 16 |
scikit-image>=0.19.3
|
| 17 |
six==1.16.0
|
| 18 |
+
psutil
|
| 19 |
tifffile==2021.8.30
|
| 20 |
easydict==1.9.0
|
| 21 |
albumentations
|
|
|
|
| 33 |
scikit-image==0.19.3
|
| 34 |
omegaconf==2.3.0
|
| 35 |
webdataset==0.2.48
|
| 36 |
+
invisible-watermark
|
| 37 |
https://comic-assets.s3.ap-south-1.amazonaws.com/packages/mmcv_full-1.7.0-cp39-cp39-linux_x86_64.whl
|
| 38 |
python-dateutil==2.8.2
|
| 39 |
PyYAML
|
| 40 |
invisible-watermark
|
| 41 |
torchvision==0.15.2
|
| 42 |
+
onnx
|
| 43 |
+
onnxruntime-gpu
|
| 44 |
imgaug==0.4.0
|
| 45 |
tqdm==4.64.1
|