Upload folder using huggingface_hub
Browse files- inference.py +8 -8
- internals/pipelines/realtime_draw.py +10 -2
inference.py
CHANGED
|
@@ -727,9 +727,6 @@ def load_model_by_task(task_type: TaskType, model_id=-1):
|
|
| 727 |
inpainter.init(text2img_pipe)
|
| 728 |
controlnet.init(text2img_pipe)
|
| 729 |
|
| 730 |
-
safety_checker.apply(text2img_pipe)
|
| 731 |
-
safety_checker.apply(img2img_pipe)
|
| 732 |
-
|
| 733 |
if task_type == TaskType.INPAINT:
|
| 734 |
inpainter.load()
|
| 735 |
safety_checker.apply(inpainter)
|
|
@@ -756,7 +753,11 @@ def load_model_by_task(task_type: TaskType, model_id=-1):
|
|
| 756 |
elif task_type == TaskType.POSE:
|
| 757 |
controlnet.load_model("pose")
|
| 758 |
|
| 759 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 760 |
|
| 761 |
|
| 762 |
def model_fn(model_dir):
|
|
@@ -798,6 +799,9 @@ def predict_fn(data, pipe):
|
|
| 798 |
task.get_type() or TaskType.TEXT_TO_IMAGE, task.get_model_id()
|
| 799 |
)
|
| 800 |
|
|
|
|
|
|
|
|
|
|
| 801 |
# Realtime generation apis
|
| 802 |
if task_type == TaskType.RT_DRAW_SEG:
|
| 803 |
return rt_draw_seg(task)
|
|
@@ -814,10 +818,6 @@ def predict_fn(data, pipe):
|
|
| 814 |
avatar.fetch_from_network(task.get_model_id())
|
| 815 |
|
| 816 |
if task_type == TaskType.TEXT_TO_IMAGE:
|
| 817 |
-
# character sheet
|
| 818 |
-
# if "character sheet" in task.get_prompt().lower():
|
| 819 |
-
# return pose(task, s3_outkey="", poses=pickPoses())
|
| 820 |
-
# else:
|
| 821 |
return text2img(task)
|
| 822 |
elif task_type == TaskType.IMAGE_TO_IMAGE:
|
| 823 |
return img2img(task)
|
|
|
|
| 727 |
inpainter.init(text2img_pipe)
|
| 728 |
controlnet.init(text2img_pipe)
|
| 729 |
|
|
|
|
|
|
|
|
|
|
| 730 |
if task_type == TaskType.INPAINT:
|
| 731 |
inpainter.load()
|
| 732 |
safety_checker.apply(inpainter)
|
|
|
|
| 753 |
elif task_type == TaskType.POSE:
|
| 754 |
controlnet.load_model("pose")
|
| 755 |
|
| 756 |
+
|
| 757 |
+
def apply_safety_checkers():
|
| 758 |
+
safety_checker.apply(text2img_pipe)
|
| 759 |
+
safety_checker.apply(img2img_pipe)
|
| 760 |
+
safety_checker.apply(controlnet)
|
| 761 |
|
| 762 |
|
| 763 |
def model_fn(model_dir):
|
|
|
|
| 799 |
task.get_type() or TaskType.TEXT_TO_IMAGE, task.get_model_id()
|
| 800 |
)
|
| 801 |
|
| 802 |
+
# Apply safety checkers
|
| 803 |
+
apply_safety_checkers()
|
| 804 |
+
|
| 805 |
# Realtime generation apis
|
| 806 |
if task_type == TaskType.RT_DRAW_SEG:
|
| 807 |
return rt_draw_seg(task)
|
|
|
|
| 818 |
avatar.fetch_from_network(task.get_model_id())
|
| 819 |
|
| 820 |
if task_type == TaskType.TEXT_TO_IMAGE:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 821 |
return text2img(task)
|
| 822 |
elif task_type == TaskType.IMAGE_TO_IMAGE:
|
| 823 |
return img2img(task)
|
internals/pipelines/realtime_draw.py
CHANGED
|
@@ -72,10 +72,16 @@ class RealtimeDraw(AbstractPipeline):
|
|
| 72 |
torch.manual_seed(seed)
|
| 73 |
|
| 74 |
if not image:
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
if not image2:
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
image = ImageUtil.resize_image(image, 512)
|
| 81 |
|
|
@@ -91,6 +97,8 @@ class RealtimeDraw(AbstractPipeline):
|
|
| 91 |
negative_prompt=negative_prompt,
|
| 92 |
guidance_scale=10,
|
| 93 |
strength=0.9,
|
|
|
|
|
|
|
| 94 |
controlnet_conditioning_scale=[1.0, 0.8],
|
| 95 |
).images[0]
|
| 96 |
|
|
|
|
| 72 |
torch.manual_seed(seed)
|
| 73 |
|
| 74 |
if not image:
|
| 75 |
+
size = (512, 512)
|
| 76 |
+
if image2:
|
| 77 |
+
size = image2.size
|
| 78 |
+
image = Image.new("RGB", size, color=0)
|
| 79 |
|
| 80 |
if not image2:
|
| 81 |
+
size = (512, 512)
|
| 82 |
+
if image:
|
| 83 |
+
size = image.size
|
| 84 |
+
image2 = Image.new("RGB", size, color=0)
|
| 85 |
|
| 86 |
image = ImageUtil.resize_image(image, 512)
|
| 87 |
|
|
|
|
| 97 |
negative_prompt=negative_prompt,
|
| 98 |
guidance_scale=10,
|
| 99 |
strength=0.9,
|
| 100 |
+
width=image.size[0],
|
| 101 |
+
height=image.size[1],
|
| 102 |
controlnet_conditioning_scale=[1.0, 0.8],
|
| 103 |
).images[0]
|
| 104 |
|