Upload folder using huggingface_hub
Browse files- handler.py +2 -23
- inference.py +7 -4
- inference2.py +4 -2
- internals/pipelines/controlnets.py +8 -13
- internals/pipelines/inpainter.py +12 -2
- internals/pipelines/replace_background.py +28 -50
- internals/util/config.py +14 -8
- internals/util/model_loader.py +187 -0
- requirements.txt +1 -1
handler.py
CHANGED
|
@@ -4,8 +4,8 @@ from pathlib import Path
|
|
| 4 |
from typing import Any, Dict, List
|
| 5 |
|
| 6 |
from inference import model_fn, predict_fn
|
| 7 |
-
from internals.util.config import set_hf_cache_dir
|
| 8 |
-
from internals.util.
|
| 9 |
|
| 10 |
|
| 11 |
class EndpointHandler:
|
|
@@ -13,27 +13,6 @@ class EndpointHandler:
|
|
| 13 |
set_hf_cache_dir(Path.home() / ".cache" / "hf_cache")
|
| 14 |
self.model_dir = path
|
| 15 |
|
| 16 |
-
if os.path.exists(path + "/inference.json"):
|
| 17 |
-
with open(path + "/inference.json", "r") as f:
|
| 18 |
-
config = json.loads(f.read())
|
| 19 |
-
if config.get("model_type") == "huggingface":
|
| 20 |
-
self.model_dir = config["model_path"]
|
| 21 |
-
if config.get("model_type") == "s3":
|
| 22 |
-
s3_config = config["model_path"]["s3"]
|
| 23 |
-
base_url = s3_config["base_url"]
|
| 24 |
-
|
| 25 |
-
urls = [base_url + item for item in s3_config["paths"]]
|
| 26 |
-
out_dir = Path.home() / ".cache" / "base_model"
|
| 27 |
-
if out_dir.exists():
|
| 28 |
-
print("Model already exist")
|
| 29 |
-
else:
|
| 30 |
-
print("Downloading model")
|
| 31 |
-
BaseModelDownloader(
|
| 32 |
-
urls, s3_config["paths"], out_dir
|
| 33 |
-
).download()
|
| 34 |
-
|
| 35 |
-
self.model_dir = str(out_dir)
|
| 36 |
-
|
| 37 |
return model_fn(self.model_dir)
|
| 38 |
|
| 39 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
|
|
|
| 4 |
from typing import Any, Dict, List
|
| 5 |
|
| 6 |
from inference import model_fn, predict_fn
|
| 7 |
+
from internals.util.config import set_hf_cache_dir, set_model_config
|
| 8 |
+
from internals.util.model_loader import load_model_from_config
|
| 9 |
|
| 10 |
|
| 11 |
class EndpointHandler:
|
|
|
|
| 13 |
set_hf_cache_dir(Path.home() / ".cache" / "hf_cache")
|
| 14 |
self.model_dir = path
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
return model_fn(self.model_dir)
|
| 17 |
|
| 18 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
inference.py
CHANGED
|
@@ -21,10 +21,11 @@ from internals.util.avatar import Avatar
|
|
| 21 |
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
|
| 22 |
from internals.util.commons import download_image, upload_image, upload_images
|
| 23 |
from internals.util.config import (get_model_dir, num_return_sequences,
|
| 24 |
-
set_configs_from_task,
|
| 25 |
set_root_dir)
|
| 26 |
from internals.util.failure_hander import FailureHandler
|
| 27 |
from internals.util.lora_style import LoraStyle
|
|
|
|
| 28 |
from internals.util.slack import Slack
|
| 29 |
|
| 30 |
torch.backends.cudnn.benchmark = True
|
|
@@ -496,13 +497,14 @@ def load_model_by_task(task: Task):
|
|
| 496 |
):
|
| 497 |
text2img_pipe.load(get_model_dir())
|
| 498 |
img2img_pipe.create(text2img_pipe)
|
| 499 |
-
inpainter.
|
| 500 |
high_res.load(img2img_pipe)
|
| 501 |
|
| 502 |
safety_checker.apply(text2img_pipe)
|
| 503 |
safety_checker.apply(img2img_pipe)
|
|
|
|
| 504 |
elif task.get_type() == TaskType.REPLACE_BG:
|
| 505 |
-
replace_background.load(
|
| 506 |
else:
|
| 507 |
if task.get_type() == TaskType.TILE_UPSCALE:
|
| 508 |
controlnet.load_tile_upscaler()
|
|
@@ -521,7 +523,8 @@ def load_model_by_task(task: Task):
|
|
| 521 |
def model_fn(model_dir):
|
| 522 |
print("Logs: model loaded .... starts")
|
| 523 |
|
| 524 |
-
|
|
|
|
| 525 |
set_root_dir(__file__)
|
| 526 |
|
| 527 |
FailureHandler.register()
|
|
|
|
| 21 |
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
|
| 22 |
from internals.util.commons import download_image, upload_image, upload_images
|
| 23 |
from internals.util.config import (get_model_dir, num_return_sequences,
|
| 24 |
+
set_configs_from_task, set_model_config,
|
| 25 |
set_root_dir)
|
| 26 |
from internals.util.failure_hander import FailureHandler
|
| 27 |
from internals.util.lora_style import LoraStyle
|
| 28 |
+
from internals.util.model_loader import load_model_from_config
|
| 29 |
from internals.util.slack import Slack
|
| 30 |
|
| 31 |
torch.backends.cudnn.benchmark = True
|
|
|
|
| 497 |
):
|
| 498 |
text2img_pipe.load(get_model_dir())
|
| 499 |
img2img_pipe.create(text2img_pipe)
|
| 500 |
+
inpainter.load()
|
| 501 |
high_res.load(img2img_pipe)
|
| 502 |
|
| 503 |
safety_checker.apply(text2img_pipe)
|
| 504 |
safety_checker.apply(img2img_pipe)
|
| 505 |
+
safety_checker.apply(inpainter)
|
| 506 |
elif task.get_type() == TaskType.REPLACE_BG:
|
| 507 |
+
replace_background.load(inpainter=inpainter, high_res=high_res)
|
| 508 |
else:
|
| 509 |
if task.get_type() == TaskType.TILE_UPSCALE:
|
| 510 |
controlnet.load_tile_upscaler()
|
|
|
|
| 523 |
def model_fn(model_dir):
|
| 524 |
print("Logs: model loaded .... starts")
|
| 525 |
|
| 526 |
+
config = load_model_from_config(model_dir)
|
| 527 |
+
set_model_config(config)
|
| 528 |
set_root_dir(__file__)
|
| 529 |
|
| 530 |
FailureHandler.register()
|
inference2.py
CHANGED
|
@@ -23,9 +23,10 @@ from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
|
|
| 23 |
from internals.util.commons import (construct_default_s3_url, upload_image,
|
| 24 |
upload_images)
|
| 25 |
from internals.util.config import (num_return_sequences, set_configs_from_task,
|
| 26 |
-
|
| 27 |
from internals.util.failure_hander import FailureHandler
|
| 28 |
from internals.util.lora_style import LoraStyle
|
|
|
|
| 29 |
from internals.util.slack import Slack
|
| 30 |
|
| 31 |
torch.backends.cudnn.benchmark = True
|
|
@@ -214,7 +215,8 @@ def upscale_image(task: Task):
|
|
| 214 |
def model_fn(model_dir):
|
| 215 |
print("Logs: model loaded .... starts")
|
| 216 |
|
| 217 |
-
|
|
|
|
| 218 |
set_root_dir(__file__)
|
| 219 |
|
| 220 |
FailureHandler.register()
|
|
|
|
| 23 |
from internals.util.commons import (construct_default_s3_url, upload_image,
|
| 24 |
upload_images)
|
| 25 |
from internals.util.config import (num_return_sequences, set_configs_from_task,
|
| 26 |
+
set_model_config, set_root_dir)
|
| 27 |
from internals.util.failure_hander import FailureHandler
|
| 28 |
from internals.util.lora_style import LoraStyle
|
| 29 |
+
from internals.util.model_loader import load_model_from_config
|
| 30 |
from internals.util.slack import Slack
|
| 31 |
|
| 32 |
torch.backends.cudnn.benchmark = True
|
|
|
|
| 215 |
def model_fn(model_dir):
|
| 216 |
print("Logs: model loaded .... starts")
|
| 217 |
|
| 218 |
+
config = load_model_from_config(model_dir)
|
| 219 |
+
set_model_config(config)
|
| 220 |
set_root_dir(__file__)
|
| 221 |
|
| 222 |
FailureHandler.register()
|
internals/pipelines/controlnets.py
CHANGED
|
@@ -4,15 +4,11 @@ import cv2
|
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
| 6 |
from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
|
| 7 |
-
from diffusers import (
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
)
|
| 13 |
-
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import (
|
| 14 |
-
MultiControlNetModel,
|
| 15 |
-
)
|
| 16 |
from PIL import Image
|
| 17 |
from torch.nn import Linear
|
| 18 |
from tqdm import gui
|
|
@@ -22,9 +18,8 @@ import internals.util.image as ImageUtil
|
|
| 22 |
from external.midas import apply_midas
|
| 23 |
from internals.data.result import Result
|
| 24 |
from internals.pipelines.commons import AbstractPipeline
|
| 25 |
-
from internals.pipelines.tileUpscalePipeline import
|
| 26 |
-
StableDiffusionControlNetImg2ImgPipeline
|
| 27 |
-
)
|
| 28 |
from internals.util.cache import clear_cuda_and_gc
|
| 29 |
from internals.util.commons import download_image
|
| 30 |
from internals.util.config import get_hf_cache_dir, get_hf_token, get_model_dir
|
|
@@ -86,7 +81,7 @@ class ControlNet(AbstractPipeline):
|
|
| 86 |
if self.__current_task_name == "pose":
|
| 87 |
return
|
| 88 |
pose = ControlNetModel.from_pretrained(
|
| 89 |
-
"lllyasviel/
|
| 90 |
torch_dtype=torch.float16,
|
| 91 |
cache_dir=get_hf_cache_dir(),
|
| 92 |
).to("cuda")
|
|
|
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
| 6 |
from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
|
| 7 |
+
from diffusers import (ControlNetModel, DiffusionPipeline,
|
| 8 |
+
StableDiffusionControlNetPipeline,
|
| 9 |
+
UniPCMultistepScheduler)
|
| 10 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import \
|
| 11 |
+
MultiControlNetModel
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
from PIL import Image
|
| 13 |
from torch.nn import Linear
|
| 14 |
from tqdm import gui
|
|
|
|
| 18 |
from external.midas import apply_midas
|
| 19 |
from internals.data.result import Result
|
| 20 |
from internals.pipelines.commons import AbstractPipeline
|
| 21 |
+
from internals.pipelines.tileUpscalePipeline import \
|
| 22 |
+
StableDiffusionControlNetImg2ImgPipeline
|
|
|
|
| 23 |
from internals.util.cache import clear_cuda_and_gc
|
| 24 |
from internals.util.commons import download_image
|
| 25 |
from internals.util.config import get_hf_cache_dir, get_hf_token, get_model_dir
|
|
|
|
| 81 |
if self.__current_task_name == "pose":
|
| 82 |
return
|
| 83 |
pose = ControlNetModel.from_pretrained(
|
| 84 |
+
"lllyasviel/control_v11p_sd15_openpose",
|
| 85 |
torch_dtype=torch.float16,
|
| 86 |
cache_dir=get_hf_cache_dir(),
|
| 87 |
).to("cuda")
|
internals/pipelines/inpainter.py
CHANGED
|
@@ -5,18 +5,28 @@ from diffusers import StableDiffusionInpaintPipeline
|
|
| 5 |
|
| 6 |
from internals.pipelines.commons import AbstractPipeline
|
| 7 |
from internals.util.commons import disable_safety_checker, download_image
|
| 8 |
-
from internals.util.config import get_hf_cache_dir
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class InPainter(AbstractPipeline):
|
|
|
|
|
|
|
| 12 |
def load(self):
|
|
|
|
|
|
|
|
|
|
| 13 |
self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
| 14 |
-
|
| 15 |
torch_dtype=torch.float16,
|
| 16 |
cache_dir=get_hf_cache_dir(),
|
|
|
|
| 17 |
).to("cuda")
|
|
|
|
| 18 |
disable_safety_checker(self.pipe)
|
| 19 |
|
|
|
|
|
|
|
| 20 |
def create(self, pipeline: AbstractPipeline):
|
| 21 |
self.pipe = StableDiffusionInpaintPipeline(**pipeline.pipe.components).to(
|
| 22 |
"cuda"
|
|
|
|
| 5 |
|
| 6 |
from internals.pipelines.commons import AbstractPipeline
|
| 7 |
from internals.util.commons import disable_safety_checker, download_image
|
| 8 |
+
from internals.util.config import (get_hf_cache_dir, get_hf_token,
|
| 9 |
+
get_inpaint_model_path)
|
| 10 |
|
| 11 |
|
| 12 |
class InPainter(AbstractPipeline):
|
| 13 |
+
__loaded = False
|
| 14 |
+
|
| 15 |
def load(self):
|
| 16 |
+
if self.__loaded:
|
| 17 |
+
return
|
| 18 |
+
|
| 19 |
self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
| 20 |
+
get_inpaint_model_path(),
|
| 21 |
torch_dtype=torch.float16,
|
| 22 |
cache_dir=get_hf_cache_dir(),
|
| 23 |
+
use_auth_token=get_hf_token(),
|
| 24 |
).to("cuda")
|
| 25 |
+
|
| 26 |
disable_safety_checker(self.pipe)
|
| 27 |
|
| 28 |
+
self.__loaded = True
|
| 29 |
+
|
| 30 |
def create(self, pipeline: AbstractPipeline):
|
| 31 |
self.pipe = StableDiffusionInpaintPipeline(**pipeline.pipe.components).to(
|
| 32 |
"cuda"
|
internals/pipelines/replace_background.py
CHANGED
|
@@ -2,6 +2,7 @@ from io import BytesIO
|
|
| 2 |
from typing import List, Optional, Union
|
| 3 |
|
| 4 |
import torch
|
|
|
|
| 5 |
from diffusers import (ControlNetModel,
|
| 6 |
StableDiffusionControlNetInpaintPipeline,
|
| 7 |
StableDiffusionInpaintPipeline, UniPCMultistepScheduler)
|
|
@@ -12,10 +13,12 @@ from internals.data.result import Result
|
|
| 12 |
from internals.pipelines.commons import AbstractPipeline
|
| 13 |
from internals.pipelines.controlnets import ControlNet
|
| 14 |
from internals.pipelines.high_res import HighRes
|
|
|
|
| 15 |
from internals.pipelines.remove_background import RemoveBackgroundV2
|
| 16 |
from internals.pipelines.upscaler import Upscaler
|
| 17 |
from internals.util.commons import download_image
|
| 18 |
-
from internals.util.config import get_hf_cache_dir,
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
class ReplaceBackground(AbstractPipeline):
|
|
@@ -25,7 +28,7 @@ class ReplaceBackground(AbstractPipeline):
|
|
| 25 |
self,
|
| 26 |
upscaler: Optional[Upscaler] = None,
|
| 27 |
remove_background: Optional[RemoveBackgroundV2] = None,
|
| 28 |
-
|
| 29 |
high_res: Optional[HighRes] = None,
|
| 30 |
):
|
| 31 |
if self.__loaded:
|
|
@@ -35,18 +38,19 @@ class ReplaceBackground(AbstractPipeline):
|
|
| 35 |
torch_dtype=torch.float16,
|
| 36 |
cache_dir=get_hf_cache_dir(),
|
| 37 |
).to("cuda")
|
| 38 |
-
if
|
| 39 |
-
|
| 40 |
pipe = StableDiffusionControlNetInpaintPipeline(
|
| 41 |
-
**
|
|
|
|
| 42 |
)
|
| 43 |
-
pipe.controlnet = controlnet_model
|
| 44 |
else:
|
| 45 |
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
| 46 |
"runwayml/stable-diffusion-inpainting",
|
| 47 |
controlnet=controlnet_model,
|
| 48 |
torch_dtype=torch.float16,
|
| 49 |
cache_dir=get_hf_cache_dir(),
|
|
|
|
| 50 |
)
|
| 51 |
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
| 52 |
pipe.to("cuda")
|
|
@@ -104,14 +108,14 @@ class ReplaceBackground(AbstractPipeline):
|
|
| 104 |
|
| 105 |
print(width, height, n_width, n_height)
|
| 106 |
|
|
|
|
| 107 |
if extend_object:
|
| 108 |
-
condition_image = ControlNet.linearart_condition_image(image)
|
| 109 |
-
|
| 110 |
-
)
|
| 111 |
condition_image = ImageUtil.padd_image(condition_image, width, height)
|
| 112 |
condition_image = condition_image.convert("RGB")
|
| 113 |
|
| 114 |
-
image =
|
| 115 |
image = ImageUtil.padd_image(image, width, height)
|
| 116 |
|
| 117 |
mask = image.copy()
|
|
@@ -130,46 +134,20 @@ class ReplaceBackground(AbstractPipeline):
|
|
| 130 |
condition_image = ControlNet.linearart_condition_image(image)
|
| 131 |
mask = mask.convert("RGB")
|
| 132 |
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
)
|
| 148 |
-
for i, _ in enumerate(result.images):
|
| 149 |
-
out_bytes = self.upscaler.upscale(
|
| 150 |
-
image=result.images[i],
|
| 151 |
-
width=w,
|
| 152 |
-
height=h,
|
| 153 |
-
face_enhance=False,
|
| 154 |
-
resize_dimension=max(width, height),
|
| 155 |
-
)
|
| 156 |
-
result.images[i] = Image.open(BytesIO(out_bytes)).convert("RGB")
|
| 157 |
-
result = Result.from_result(result)
|
| 158 |
-
else:
|
| 159 |
-
result = self.pipe.__call__(
|
| 160 |
-
prompt=prompt,
|
| 161 |
-
negative_prompt=negative_prompt,
|
| 162 |
-
image=image,
|
| 163 |
-
mask_image=mask,
|
| 164 |
-
control_image=condition_image,
|
| 165 |
-
controlnet_conditioning_scale=conditioning_scale,
|
| 166 |
-
guidance_scale=9,
|
| 167 |
-
strength=1,
|
| 168 |
-
height=height,
|
| 169 |
-
num_inference_steps=steps,
|
| 170 |
-
width=width,
|
| 171 |
-
)
|
| 172 |
-
result = Result.from_result(result)
|
| 173 |
|
| 174 |
images, has_nsfw = result
|
| 175 |
|
|
|
|
| 2 |
from typing import List, Optional, Union
|
| 3 |
|
| 4 |
import torch
|
| 5 |
+
from cv2 import inpaint
|
| 6 |
from diffusers import (ControlNetModel,
|
| 7 |
StableDiffusionControlNetInpaintPipeline,
|
| 8 |
StableDiffusionInpaintPipeline, UniPCMultistepScheduler)
|
|
|
|
| 13 |
from internals.pipelines.commons import AbstractPipeline
|
| 14 |
from internals.pipelines.controlnets import ControlNet
|
| 15 |
from internals.pipelines.high_res import HighRes
|
| 16 |
+
from internals.pipelines.inpainter import InPainter
|
| 17 |
from internals.pipelines.remove_background import RemoveBackgroundV2
|
| 18 |
from internals.pipelines.upscaler import Upscaler
|
| 19 |
from internals.util.commons import download_image
|
| 20 |
+
from internals.util.config import (get_hf_cache_dir, get_hf_token,
|
| 21 |
+
get_inpaint_model_path, get_model_dir)
|
| 22 |
|
| 23 |
|
| 24 |
class ReplaceBackground(AbstractPipeline):
|
|
|
|
| 28 |
self,
|
| 29 |
upscaler: Optional[Upscaler] = None,
|
| 30 |
remove_background: Optional[RemoveBackgroundV2] = None,
|
| 31 |
+
inpainter: Optional[InPainter] = None,
|
| 32 |
high_res: Optional[HighRes] = None,
|
| 33 |
):
|
| 34 |
if self.__loaded:
|
|
|
|
| 38 |
torch_dtype=torch.float16,
|
| 39 |
cache_dir=get_hf_cache_dir(),
|
| 40 |
).to("cuda")
|
| 41 |
+
if inpainter:
|
| 42 |
+
inpainter.load()
|
| 43 |
pipe = StableDiffusionControlNetInpaintPipeline(
|
| 44 |
+
**inpainter.pipe.components,
|
| 45 |
+
controlnet=controlnet_model,
|
| 46 |
)
|
|
|
|
| 47 |
else:
|
| 48 |
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
| 49 |
"runwayml/stable-diffusion-inpainting",
|
| 50 |
controlnet=controlnet_model,
|
| 51 |
torch_dtype=torch.float16,
|
| 52 |
cache_dir=get_hf_cache_dir(),
|
| 53 |
+
use_auth_token=get_hf_token(),
|
| 54 |
)
|
| 55 |
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
| 56 |
pipe.to("cuda")
|
|
|
|
| 108 |
|
| 109 |
print(width, height, n_width, n_height)
|
| 110 |
|
| 111 |
+
resolution = min(n_width, n_height)
|
| 112 |
if extend_object:
|
| 113 |
+
condition_image = ControlNet.linearart_condition_image(image)
|
| 114 |
+
condition_image = ImageUtil.resize_image(condition_image, resolution)
|
|
|
|
| 115 |
condition_image = ImageUtil.padd_image(condition_image, width, height)
|
| 116 |
condition_image = condition_image.convert("RGB")
|
| 117 |
|
| 118 |
+
image = ImageUtil.resize_image(image, resolution)
|
| 119 |
image = ImageUtil.padd_image(image, width, height)
|
| 120 |
|
| 121 |
mask = image.copy()
|
|
|
|
| 134 |
condition_image = ControlNet.linearart_condition_image(image)
|
| 135 |
mask = mask.convert("RGB")
|
| 136 |
|
| 137 |
+
result = self.pipe.__call__(
|
| 138 |
+
prompt=prompt,
|
| 139 |
+
negative_prompt=negative_prompt,
|
| 140 |
+
image=image,
|
| 141 |
+
mask_image=mask,
|
| 142 |
+
control_image=condition_image,
|
| 143 |
+
controlnet_conditioning_scale=conditioning_scale,
|
| 144 |
+
guidance_scale=9,
|
| 145 |
+
strength=1,
|
| 146 |
+
height=height,
|
| 147 |
+
num_inference_steps=steps,
|
| 148 |
+
width=width,
|
| 149 |
+
)
|
| 150 |
+
result = Result.from_result(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
images, has_nsfw = result
|
| 153 |
|
internals/util/config.py
CHANGED
|
@@ -3,13 +3,14 @@ from pathlib import Path
|
|
| 3 |
from typing import Union
|
| 4 |
|
| 5 |
from internals.data.task import Task
|
|
|
|
| 6 |
|
| 7 |
env = "prod"
|
| 8 |
nsfw_threshold = 0.0
|
| 9 |
nsfw_access = False
|
| 10 |
access_token = ""
|
| 11 |
root_dir = ""
|
| 12 |
-
|
| 13 |
hf_token = "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"
|
| 14 |
hf_cache_dir = "/tmp/hf_hub"
|
| 15 |
|
|
@@ -28,16 +29,16 @@ def get_hf_cache_dir():
|
|
| 28 |
return hf_cache_dir
|
| 29 |
|
| 30 |
|
| 31 |
-
def set_model_dir(dir: str):
|
| 32 |
-
global model_dir
|
| 33 |
-
model_dir = dir
|
| 34 |
-
|
| 35 |
-
|
| 36 |
def set_root_dir(main_file: str):
|
| 37 |
global root_dir
|
| 38 |
root_dir = os.path.dirname(os.path.abspath(main_file))
|
| 39 |
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
def set_configs_from_task(task: Task):
|
| 42 |
global env, nsfw_threshold, nsfw_access, access_token
|
| 43 |
name = task.get_queue_name()
|
|
@@ -51,8 +52,13 @@ def set_configs_from_task(task: Task):
|
|
| 51 |
|
| 52 |
|
| 53 |
def get_model_dir():
|
| 54 |
-
global
|
| 55 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
|
| 58 |
def get_root_dir():
|
|
|
|
| 3 |
from typing import Union
|
| 4 |
|
| 5 |
from internals.data.task import Task
|
| 6 |
+
from internals.util.model_loader import ModelConfig
|
| 7 |
|
| 8 |
env = "prod"
|
| 9 |
nsfw_threshold = 0.0
|
| 10 |
nsfw_access = False
|
| 11 |
access_token = ""
|
| 12 |
root_dir = ""
|
| 13 |
+
model_config = None
|
| 14 |
hf_token = "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"
|
| 15 |
hf_cache_dir = "/tmp/hf_hub"
|
| 16 |
|
|
|
|
| 29 |
return hf_cache_dir
|
| 30 |
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
def set_root_dir(main_file: str):
|
| 33 |
global root_dir
|
| 34 |
root_dir = os.path.dirname(os.path.abspath(main_file))
|
| 35 |
|
| 36 |
|
| 37 |
+
def set_model_config(config: ModelConfig):
|
| 38 |
+
global model_config
|
| 39 |
+
model_config = config
|
| 40 |
+
|
| 41 |
+
|
| 42 |
def set_configs_from_task(task: Task):
|
| 43 |
global env, nsfw_threshold, nsfw_access, access_token
|
| 44 |
name = task.get_queue_name()
|
|
|
|
| 52 |
|
| 53 |
|
| 54 |
def get_model_dir():
|
| 55 |
+
global model_config
|
| 56 |
+
return model_config.base_model_path # pyright: ignore
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_inpaint_model_path():
|
| 60 |
+
global model_config
|
| 61 |
+
return model_config.base_inpaint_model_path # pyright: ignore
|
| 62 |
|
| 63 |
|
| 64 |
def get_root_dir():
|
internals/util/model_loader.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from threading import Thread
|
| 7 |
+
from typing import Any, Dict, List, Optional
|
| 8 |
+
|
| 9 |
+
import requests
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class ModelConfig:
|
| 15 |
+
base_model_path: str
|
| 16 |
+
base_inpaint_model_path: str
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def load_model_from_config(path):
|
| 20 |
+
m_config = ModelConfig(path, path)
|
| 21 |
+
if os.path.exists(path + "/inference.json"):
|
| 22 |
+
with open(path + "/inference.json", "r") as f:
|
| 23 |
+
config = json.loads(f.read())
|
| 24 |
+
model_path = config.get("model_path", path)
|
| 25 |
+
inpaint_model_path = config.get("inpaint_model_path", path)
|
| 26 |
+
|
| 27 |
+
m_config.base_model_path = model_path
|
| 28 |
+
m_config.base_inpaint_model_path = inpaint_model_path
|
| 29 |
+
|
| 30 |
+
#
|
| 31 |
+
# if config.get("model_type") == "huggingface":
|
| 32 |
+
# model_dir = config["model_path"]
|
| 33 |
+
# if config.get("model_type") == "s3":
|
| 34 |
+
# s3_config = config["model_path"]["s3"]
|
| 35 |
+
# base_url = s3_config["base_url"]
|
| 36 |
+
#
|
| 37 |
+
# urls = [base_url + item for item in s3_config["paths"]]
|
| 38 |
+
# out_dir = Path.home() / ".cache" / "base_model"
|
| 39 |
+
# if out_dir.exists():
|
| 40 |
+
# print("Model already exist")
|
| 41 |
+
# else:
|
| 42 |
+
# print("Downloading model")
|
| 43 |
+
# BaseModelDownloader(urls, s3_config["paths"], out_dir).download()
|
| 44 |
+
# model_dir = str(out_dir)
|
| 45 |
+
return m_config
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class BaseModelDownloader:
|
| 49 |
+
"""
|
| 50 |
+
A utility for fast download of base model from S3 or any CDN served storage.
|
| 51 |
+
Works by downloading multiple files in parallel and dividing large files
|
| 52 |
+
into smaller chunks and combining them at the end.
|
| 53 |
+
|
| 54 |
+
Currently it uses multithreading (not multiprocessing) assuming GIL won't
|
| 55 |
+
interfere with network/disk IO.
|
| 56 |
+
|
| 57 |
+
Created by: KP
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
def __init__(self, urls: List[str], url_paths: List[str], out_dir: Path):
|
| 61 |
+
self.urls = urls
|
| 62 |
+
self.url_paths = url_paths
|
| 63 |
+
shutil.rmtree(out_dir, ignore_errors=True)
|
| 64 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 65 |
+
self.out_dir = out_dir
|
| 66 |
+
|
| 67 |
+
def download(self):
|
| 68 |
+
threads = []
|
| 69 |
+
batch_urls = {}
|
| 70 |
+
|
| 71 |
+
for url, url_path in zip(self.urls, self.url_paths):
|
| 72 |
+
out_dir = self.out_dir / url_path
|
| 73 |
+
self.out_dir.parent.mkdir(parents=True, exist_ok=True)
|
| 74 |
+
if url.endswith(".bin"):
|
| 75 |
+
if "unet/" in url_path:
|
| 76 |
+
thread = Thread(
|
| 77 |
+
target=self.__download_parallel, args=(url, out_dir, 6)
|
| 78 |
+
)
|
| 79 |
+
thread.start()
|
| 80 |
+
threads.append(thread)
|
| 81 |
+
else:
|
| 82 |
+
thread = Thread(
|
| 83 |
+
target=self.__download_files, args=([url], [out_dir])
|
| 84 |
+
)
|
| 85 |
+
thread.start()
|
| 86 |
+
threads.append(thread)
|
| 87 |
+
pass
|
| 88 |
+
else:
|
| 89 |
+
batch_urls[url] = out_dir
|
| 90 |
+
|
| 91 |
+
if batch_urls:
|
| 92 |
+
thread = Thread(
|
| 93 |
+
target=self.__download_files,
|
| 94 |
+
args=(list(batch_urls.keys()), list(batch_urls.values())),
|
| 95 |
+
)
|
| 96 |
+
thread.start()
|
| 97 |
+
threads.append(thread)
|
| 98 |
+
pass
|
| 99 |
+
|
| 100 |
+
for thread in threads:
|
| 101 |
+
thread.join()
|
| 102 |
+
|
| 103 |
+
def __download_parallel(self, url, output_filename, num_parts=4):
|
| 104 |
+
response = requests.head(url)
|
| 105 |
+
total_size = int(response.headers.get("content-length", 0))
|
| 106 |
+
print("total_size", total_size)
|
| 107 |
+
|
| 108 |
+
chunk_size = total_size // num_parts
|
| 109 |
+
ranges = [
|
| 110 |
+
(i * chunk_size, (i + 1) * chunk_size - 1) for i in range(num_parts - 1)
|
| 111 |
+
]
|
| 112 |
+
ranges.append((ranges[-1][1] + 1, total_size))
|
| 113 |
+
|
| 114 |
+
print(ranges)
|
| 115 |
+
|
| 116 |
+
save_dir = Path.home() / ".cache" / "download_parts"
|
| 117 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 118 |
+
|
| 119 |
+
threads = []
|
| 120 |
+
for i, (start, end) in enumerate(ranges):
|
| 121 |
+
thread = Thread(
|
| 122 |
+
target=self.__download_part, args=(url, start, end, i, save_dir)
|
| 123 |
+
)
|
| 124 |
+
thread.start()
|
| 125 |
+
threads.append(thread)
|
| 126 |
+
|
| 127 |
+
for thread in threads:
|
| 128 |
+
thread.join()
|
| 129 |
+
|
| 130 |
+
self.__combine_parts(save_dir, output_filename, num_parts)
|
| 131 |
+
os.rmdir(save_dir)
|
| 132 |
+
|
| 133 |
+
def __combine_parts(self, save_dir, output_filename, num_parts):
|
| 134 |
+
part_files = [os.path.join(save_dir, f"part_{i}.tmp") for i in range(num_parts)]
|
| 135 |
+
|
| 136 |
+
output_filename.parent.mkdir(parents=True, exist_ok=True)
|
| 137 |
+
with open(output_filename, "wb") as output_file:
|
| 138 |
+
for part_file in part_files:
|
| 139 |
+
print("combining: ", part_file)
|
| 140 |
+
with open(part_file, "rb") as part:
|
| 141 |
+
output_file.write(part.read())
|
| 142 |
+
|
| 143 |
+
out_file_size = output_file.tell()
|
| 144 |
+
print("out_file_size", out_file_size)
|
| 145 |
+
|
| 146 |
+
for part_file in part_files:
|
| 147 |
+
os.remove(part_file)
|
| 148 |
+
|
| 149 |
+
def __download_part(self, url, start_byte, end_byte, part_num, save_dir):
|
| 150 |
+
headers = {"Range": f"bytes={start_byte}-{end_byte}"}
|
| 151 |
+
response = requests.get(url, headers=headers, stream=True)
|
| 152 |
+
|
| 153 |
+
part_filename = os.path.join(save_dir, f"part_{part_num}.tmp")
|
| 154 |
+
print("Downloading part: ", url, part_filename, end_byte - start_byte)
|
| 155 |
+
|
| 156 |
+
with open(part_filename, "wb") as part_file, tqdm(
|
| 157 |
+
desc=str(part_filename),
|
| 158 |
+
total=end_byte - start_byte,
|
| 159 |
+
unit="B",
|
| 160 |
+
unit_scale=True,
|
| 161 |
+
unit_divisor=1024,
|
| 162 |
+
) as bar:
|
| 163 |
+
for chunk in response.iter_content(chunk_size=8192):
|
| 164 |
+
if chunk:
|
| 165 |
+
size = part_file.write(chunk)
|
| 166 |
+
bar.update(size)
|
| 167 |
+
|
| 168 |
+
return part_filename
|
| 169 |
+
|
| 170 |
+
def __download_files(self, urls, out_paths: List[Path]):
|
| 171 |
+
for url, out_path in zip(urls, out_paths):
|
| 172 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 173 |
+
with requests.get(url, stream=True) as r:
|
| 174 |
+
print("Downloading: ", url)
|
| 175 |
+
total_size = int(r.headers.get("content-length", 0))
|
| 176 |
+
chunk_size = 8192
|
| 177 |
+
r.raise_for_status()
|
| 178 |
+
with open(out_path, "wb") as f, tqdm(
|
| 179 |
+
desc=str(out_path),
|
| 180 |
+
total=total_size,
|
| 181 |
+
unit="B",
|
| 182 |
+
unit_scale=True,
|
| 183 |
+
unit_divisor=1024,
|
| 184 |
+
) as bar:
|
| 185 |
+
for data in r.iter_content(chunk_size=chunk_size):
|
| 186 |
+
size = f.write(data)
|
| 187 |
+
bar.update(size)
|
requirements.txt
CHANGED
|
@@ -5,7 +5,7 @@ fastapi==0.87.0
|
|
| 5 |
Pillow==9.3.0
|
| 6 |
redis==4.3.4
|
| 7 |
requests==2.28.1
|
| 8 |
-
transformers
|
| 9 |
rembg==2.0.30
|
| 10 |
gfpgan==1.3.8
|
| 11 |
rembg==2.0.30
|
|
|
|
| 5 |
Pillow==9.3.0
|
| 6 |
redis==4.3.4
|
| 7 |
requests==2.28.1
|
| 8 |
+
transformers==4.34.1
|
| 9 |
rembg==2.0.30
|
| 10 |
gfpgan==1.3.8
|
| 11 |
rembg==2.0.30
|