Upload folder using huggingface_hub
Browse files- external/scripts/__init__.py +13 -0
- external/scripts/day_night_ip2p.py +63 -0
- inference.py +34 -0
- internals/data/dataAccessor.py +7 -3
- internals/data/task.py +5 -0
- internals/pipelines/controlnets.py +5 -1
- internals/pipelines/high_res.py +4 -1
- internals/util/__init__.py +7 -0
- internals/util/lora_style.py +1 -4
- internals/util/slack.py +6 -1
external/scripts/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
from internals.util import getcwd
|
| 5 |
+
|
| 6 |
+
path = os.path.join(getcwd(), "external/scripts")
|
| 7 |
+
|
| 8 |
+
__scripts__ = []
|
| 9 |
+
for name in os.listdir(path):
|
| 10 |
+
name = name.split("/")[-1].replace(".py", "")
|
| 11 |
+
imp = importlib.import_module(f"external.scripts.{name}")
|
| 12 |
+
if hasattr(imp, "Script") and imp not in __scripts__:
|
| 13 |
+
__scripts__.append(imp)
|
external/scripts/day_night_ip2p.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from diffusers import StableDiffusionInstructPix2PixPipeline
|
| 3 |
+
|
| 4 |
+
import internals.util.image as ImageUtil
|
| 5 |
+
from internals.data.dataAccessor import update_db
|
| 6 |
+
from internals.data.task import Task
|
| 7 |
+
from internals.util.cache import clear_cuda_and_gc
|
| 8 |
+
from internals.util.commons import download_image, upload_images
|
| 9 |
+
from internals.util.config import get_hf_token
|
| 10 |
+
from internals.util.slack import Slack
|
| 11 |
+
|
| 12 |
+
slack = Slack()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Script:
|
| 16 |
+
def __init__(self, **kwargs):
|
| 17 |
+
self.__name__ = "day_night_ip2p"
|
| 18 |
+
|
| 19 |
+
@update_db
|
| 20 |
+
@slack.auto_send_alert
|
| 21 |
+
def __call__(self, task: Task, args: dict):
|
| 22 |
+
clear_cuda_and_gc()
|
| 23 |
+
|
| 24 |
+
model_id = args.get("model_id", None)
|
| 25 |
+
steps = args.get("steps", 50)
|
| 26 |
+
image_guidance_scale = args.get("image_guidance_scale", 1.5)
|
| 27 |
+
guidance_scale = args.get("guidance_scale", 7.5)
|
| 28 |
+
|
| 29 |
+
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
| 30 |
+
model_id,
|
| 31 |
+
use_auth_token=get_hf_token(),
|
| 32 |
+
torch_dtype=torch.float16,
|
| 33 |
+
safety_checker=None,
|
| 34 |
+
).to("cuda")
|
| 35 |
+
pipe.enable_xformers_memory_efficient_attention()
|
| 36 |
+
|
| 37 |
+
prompt = ["convert to night", "convert to evening", "convert to midnight"]
|
| 38 |
+
image = download_image(task.get_imageUrl())
|
| 39 |
+
image = ImageUtil.resize_image(image, 1024)
|
| 40 |
+
|
| 41 |
+
images = []
|
| 42 |
+
for p in prompt:
|
| 43 |
+
print("Generating: ", p)
|
| 44 |
+
image = pipe.__call__(
|
| 45 |
+
prompt=p,
|
| 46 |
+
num_inference_steps=steps,
|
| 47 |
+
image=image,
|
| 48 |
+
guidance_scale=guidance_scale,
|
| 49 |
+
num_images_per_prompt=1,
|
| 50 |
+
image_guidance_scale=image_guidance_scale,
|
| 51 |
+
).images[0]
|
| 52 |
+
images.append(image)
|
| 53 |
+
|
| 54 |
+
generated_image_urls = upload_images(
|
| 55 |
+
images, "_" + self.__name__, task.get_taskId()
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
pipe = None
|
| 59 |
+
del pipe
|
| 60 |
+
|
| 61 |
+
clear_cuda_and_gc()
|
| 62 |
+
|
| 63 |
+
return {"generated_image_urls": generated_image_urls}
|
inference.py
CHANGED
|
@@ -2,7 +2,9 @@ import os
|
|
| 2 |
import traceback
|
| 3 |
from typing import List, Optional
|
| 4 |
|
|
|
|
| 5 |
import torch
|
|
|
|
| 6 |
|
| 7 |
import internals.util.prompt as prompt_util
|
| 8 |
from internals.data.dataAccessor import update_db, update_db_source_failed
|
|
@@ -54,6 +56,8 @@ safety_checker = SafetyChecker()
|
|
| 54 |
slack = Slack()
|
| 55 |
avatar = Avatar()
|
| 56 |
|
|
|
|
|
|
|
| 57 |
|
| 58 |
def get_patched_prompt(task: Task):
|
| 59 |
return prompt_util.get_patched_prompt(task, avatar, lora_style, prompt_modifier)
|
|
@@ -533,6 +537,32 @@ def replace_bg(task: Task):
|
|
| 533 |
}
|
| 534 |
|
| 535 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
def load_model_by_task(task: Task):
|
| 537 |
if not text2img_pipe.is_loaded():
|
| 538 |
text2img_pipe.load(get_model_dir())
|
|
@@ -587,6 +617,8 @@ def predict_fn(data, pipe):
|
|
| 587 |
task = Task(data)
|
| 588 |
print("task is ", data)
|
| 589 |
|
|
|
|
|
|
|
| 590 |
FailureHandler.handle(task)
|
| 591 |
|
| 592 |
try:
|
|
@@ -629,6 +661,8 @@ def predict_fn(data, pipe):
|
|
| 629 |
return linearart(task)
|
| 630 |
elif task_type == TaskType.REPLACE_BG:
|
| 631 |
return replace_bg(task)
|
|
|
|
|
|
|
| 632 |
elif task_type == TaskType.SYSTEM_CMD:
|
| 633 |
os.system(task.get_prompt())
|
| 634 |
else:
|
|
|
|
| 2 |
import traceback
|
| 3 |
from typing import List, Optional
|
| 4 |
|
| 5 |
+
import pydash as _
|
| 6 |
import torch
|
| 7 |
+
from numpy import who
|
| 8 |
|
| 9 |
import internals.util.prompt as prompt_util
|
| 10 |
from internals.data.dataAccessor import update_db, update_db_source_failed
|
|
|
|
| 56 |
slack = Slack()
|
| 57 |
avatar = Avatar()
|
| 58 |
|
| 59 |
+
custom_scripts: List = []
|
| 60 |
+
|
| 61 |
|
| 62 |
def get_patched_prompt(task: Task):
|
| 63 |
return prompt_util.get_patched_prompt(task, avatar, lora_style, prompt_modifier)
|
|
|
|
| 537 |
}
|
| 538 |
|
| 539 |
|
| 540 |
+
def custom_action(task: Task):
|
| 541 |
+
from external.scripts import __scripts__
|
| 542 |
+
|
| 543 |
+
global custom_scripts
|
| 544 |
+
kwargs = {
|
| 545 |
+
"CONTROLNET": controlnet,
|
| 546 |
+
"LORASTYLE": lora_style,
|
| 547 |
+
}
|
| 548 |
+
|
| 549 |
+
torch.manual_seed(task.get_seed())
|
| 550 |
+
|
| 551 |
+
for script in __scripts__:
|
| 552 |
+
script = script.Script(**kwargs)
|
| 553 |
+
existing_script = _.find(
|
| 554 |
+
custom_scripts, lambda x: x.__name__ == script.__name__
|
| 555 |
+
)
|
| 556 |
+
if existing_script:
|
| 557 |
+
script = existing_script
|
| 558 |
+
else:
|
| 559 |
+
custom_scripts.append(script)
|
| 560 |
+
|
| 561 |
+
data = task.get_action_data()
|
| 562 |
+
if data["name"] == script.__name__:
|
| 563 |
+
return script(task, data)
|
| 564 |
+
|
| 565 |
+
|
| 566 |
def load_model_by_task(task: Task):
|
| 567 |
if not text2img_pipe.is_loaded():
|
| 568 |
text2img_pipe.load(get_model_dir())
|
|
|
|
| 617 |
task = Task(data)
|
| 618 |
print("task is ", data)
|
| 619 |
|
| 620 |
+
clear_cuda_and_gc()
|
| 621 |
+
|
| 622 |
FailureHandler.handle(task)
|
| 623 |
|
| 624 |
try:
|
|
|
|
| 661 |
return linearart(task)
|
| 662 |
elif task_type == TaskType.REPLACE_BG:
|
| 663 |
return replace_bg(task)
|
| 664 |
+
elif task_type == TaskType.CUSTOM_ACTION:
|
| 665 |
+
return custom_action(task)
|
| 666 |
elif task_type == TaskType.SYSTEM_CMD:
|
| 667 |
os.system(task.get_prompt())
|
| 668 |
else:
|
internals/data/dataAccessor.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import traceback
|
| 2 |
from typing import Dict, List, Optional
|
| 3 |
|
| 4 |
-
from requests.adapters import Retry, HTTPAdapter
|
| 5 |
import requests
|
| 6 |
from pydash import includes
|
|
|
|
| 7 |
|
| 8 |
from internals.data.task import Task
|
| 9 |
from internals.util.config import api_endpoint, api_headers
|
|
@@ -104,9 +104,13 @@ def update_db_source_failed(sourceId, userId):
|
|
| 104 |
|
| 105 |
def update_db(func):
|
| 106 |
def caller(*args, **kwargs):
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
raise Exception("First argument must be a Task object")
|
| 109 |
-
task = args[0]
|
| 110 |
try:
|
| 111 |
updateSource(task.get_sourceId(), task.get_userId(), "INPROGRESS")
|
| 112 |
rargs = func(*args, **kwargs)
|
|
|
|
| 1 |
import traceback
|
| 2 |
from typing import Dict, List, Optional
|
| 3 |
|
|
|
|
| 4 |
import requests
|
| 5 |
from pydash import includes
|
| 6 |
+
from requests.adapters import HTTPAdapter, Retry
|
| 7 |
|
| 8 |
from internals.data.task import Task
|
| 9 |
from internals.util.config import api_endpoint, api_headers
|
|
|
|
| 104 |
|
| 105 |
def update_db(func):
|
| 106 |
def caller(*args, **kwargs):
|
| 107 |
+
task = None
|
| 108 |
+
for arg in args:
|
| 109 |
+
if type(arg) is Task:
|
| 110 |
+
task = arg
|
| 111 |
+
break
|
| 112 |
+
if task is None:
|
| 113 |
raise Exception("First argument must be a Task object")
|
|
|
|
| 114 |
try:
|
| 115 |
updateSource(task.get_sourceId(), task.get_userId(), "INPROGRESS")
|
| 116 |
rargs = func(*args, **kwargs)
|
internals/data/task.py
CHANGED
|
@@ -18,6 +18,7 @@ class TaskType(Enum):
|
|
| 18 |
SCRIBBLE = "SCRIBBLE"
|
| 19 |
LINEARART = "LINEARART"
|
| 20 |
REPLACE_BG = "REPLACE_BG"
|
|
|
|
| 21 |
SYSTEM_CMD = "SYSTEM_CMD"
|
| 22 |
|
| 23 |
|
|
@@ -148,6 +149,10 @@ class Task:
|
|
| 148 |
def get_base_dimension(self):
|
| 149 |
return self.__data.get("base_dimension", None)
|
| 150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
def get_raw(self) -> dict:
|
| 152 |
return self.__data.copy()
|
| 153 |
|
|
|
|
| 18 |
SCRIBBLE = "SCRIBBLE"
|
| 19 |
LINEARART = "LINEARART"
|
| 20 |
REPLACE_BG = "REPLACE_BG"
|
| 21 |
+
CUSTOM_ACTION = "CUSTOM_ACTION"
|
| 22 |
SYSTEM_CMD = "SYSTEM_CMD"
|
| 23 |
|
| 24 |
|
|
|
|
| 149 |
def get_base_dimension(self):
|
| 150 |
return self.__data.get("base_dimension", None)
|
| 151 |
|
| 152 |
+
def get_action_data(self) -> dict:
|
| 153 |
+
"If task_type is CUSTOM_ACTION, then this will return the action data with 'name' as key"
|
| 154 |
+
return self.__data.get("action_data", {})
|
| 155 |
+
|
| 156 |
def get_raw(self) -> dict:
|
| 157 |
return self.__data.copy()
|
| 158 |
|
internals/pipelines/controlnets.py
CHANGED
|
@@ -151,7 +151,6 @@ class ControlNet(AbstractPipeline):
|
|
| 151 |
|
| 152 |
self.__load_pipeline(model, pipeline_type)
|
| 153 |
|
| 154 |
-
self.network_model = model
|
| 155 |
self.__current_task_name = task_name
|
| 156 |
|
| 157 |
clear_cuda_and_gc()
|
|
@@ -247,6 +246,11 @@ class ControlNet(AbstractPipeline):
|
|
| 247 |
if hasattr(self, "pipe2"):
|
| 248 |
setattr(self.pipe2, "adapter", network_model)
|
| 249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
clear_cuda_and_gc()
|
| 251 |
|
| 252 |
def process(self, **kwargs):
|
|
|
|
| 151 |
|
| 152 |
self.__load_pipeline(model, pipeline_type)
|
| 153 |
|
|
|
|
| 154 |
self.__current_task_name = task_name
|
| 155 |
|
| 156 |
clear_cuda_and_gc()
|
|
|
|
| 246 |
if hasattr(self, "pipe2"):
|
| 247 |
setattr(self.pipe2, "adapter", network_model)
|
| 248 |
|
| 249 |
+
if hasattr(self, "pipe"):
|
| 250 |
+
self.pipe = self.pipe.to("cuda")
|
| 251 |
+
if hasattr(self, "pipe2"):
|
| 252 |
+
self.pipe2 = self.pipe2.to("cuda")
|
| 253 |
+
|
| 254 |
clear_cuda_and_gc()
|
| 255 |
|
| 256 |
def process(self, **kwargs):
|
internals/pipelines/high_res.py
CHANGED
|
@@ -5,7 +5,8 @@ from PIL import Image
|
|
| 5 |
|
| 6 |
from internals.data.result import Result
|
| 7 |
from internals.pipelines.commons import AbstractPipeline, Img2Img
|
| 8 |
-
from internals.util.
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class HighRes(AbstractPipeline):
|
|
@@ -32,6 +33,8 @@ class HighRes(AbstractPipeline):
|
|
| 32 |
guidance_scale: int = 9,
|
| 33 |
**kwargs,
|
| 34 |
):
|
|
|
|
|
|
|
| 35 |
images = [image.resize((width, height)) for image in images]
|
| 36 |
kwargs = {
|
| 37 |
"prompt": prompt,
|
|
|
|
| 5 |
|
| 6 |
from internals.data.result import Result
|
| 7 |
from internals.pipelines.commons import AbstractPipeline, Img2Img
|
| 8 |
+
from internals.util.cache import clear_cuda_and_gc
|
| 9 |
+
from internals.util.config import get_base_dimension, get_model_dir
|
| 10 |
|
| 11 |
|
| 12 |
class HighRes(AbstractPipeline):
|
|
|
|
| 33 |
guidance_scale: int = 9,
|
| 34 |
**kwargs,
|
| 35 |
):
|
| 36 |
+
clear_cuda_and_gc()
|
| 37 |
+
|
| 38 |
images = [image.resize((width, height)) for image in images]
|
| 39 |
kwargs = {
|
| 40 |
"prompt": prompt,
|
internals/util/__init__.py
CHANGED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from internals.util.config import get_root_dir
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def getcwd():
|
| 7 |
+
return get_root_dir()
|
internals/util/lora_style.py
CHANGED
|
@@ -10,8 +10,8 @@ from lora_diffusion import patch_pipe, tune_lora_scale
|
|
| 10 |
from pydash import chain
|
| 11 |
|
| 12 |
from internals.data.dataAccessor import getStyles
|
| 13 |
-
from internals.util.config import get_is_sdxl
|
| 14 |
from internals.util.commons import download_file
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
class LoraStyle:
|
|
@@ -113,9 +113,6 @@ class LoraStyle:
|
|
| 113 |
) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
|
| 114 |
"Returns a lora patcher for the given `key` and `pipe`. `pipe` can also be a list of pipes"
|
| 115 |
pipe = [pipe] if not isinstance(pipe, list) else pipe
|
| 116 |
-
if get_is_sdxl():
|
| 117 |
-
print("Warning: Lora is not supported on SDXL")
|
| 118 |
-
return self.EmptyLoraPatcher(pipe)
|
| 119 |
|
| 120 |
if key in self.__styles:
|
| 121 |
style = self.__styles[key]
|
|
|
|
| 10 |
from pydash import chain
|
| 11 |
|
| 12 |
from internals.data.dataAccessor import getStyles
|
|
|
|
| 13 |
from internals.util.commons import download_file
|
| 14 |
+
from internals.util.config import get_is_sdxl
|
| 15 |
|
| 16 |
|
| 17 |
class LoraStyle:
|
|
|
|
| 113 |
) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
|
| 114 |
"Returns a lora patcher for the given `key` and `pipe`. `pipe` can also be a list of pipes"
|
| 115 |
pipe = [pipe] if not isinstance(pipe, list) else pipe
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
if key in self.__styles:
|
| 118 |
style = self.__styles[key]
|
internals/util/slack.py
CHANGED
|
@@ -55,7 +55,12 @@ class Slack:
|
|
| 55 |
def auto_send_alert(self, func):
|
| 56 |
def inner(*args, **kwargs):
|
| 57 |
rargs = func(*args, **kwargs)
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
return rargs
|
| 60 |
|
| 61 |
return inner
|
|
|
|
| 55 |
def auto_send_alert(self, func):
|
| 56 |
def inner(*args, **kwargs):
|
| 57 |
rargs = func(*args, **kwargs)
|
| 58 |
+
task = Task({})
|
| 59 |
+
for arg in args:
|
| 60 |
+
if type(arg) is Task:
|
| 61 |
+
task = arg
|
| 62 |
+
break
|
| 63 |
+
self.send_alert(task, rargs)
|
| 64 |
return rargs
|
| 65 |
|
| 66 |
return inner
|