Upload folder using huggingface_hub
Browse files- inference.py +15 -17
- internals/pipelines/pose_detector.py +12 -10
- internals/util/commons.py +9 -1
inference.py
CHANGED
|
@@ -14,13 +14,11 @@ from internals.pipelines.prompt_modifier import PromptModifier
|
|
| 14 |
from internals.pipelines.safety_checker import SafetyChecker
|
| 15 |
from internals.util.args import apply_style_args
|
| 16 |
from internals.util.avatar import Avatar
|
| 17 |
-
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda,
|
|
|
|
| 18 |
from internals.util.commons import pickPoses, upload_image, upload_images
|
| 19 |
-
from internals.util.config import (
|
| 20 |
-
|
| 21 |
-
set_configs_from_task,
|
| 22 |
-
set_root_dir,
|
| 23 |
-
)
|
| 24 |
from internals.util.failure_hander import FailureHandler
|
| 25 |
from internals.util.lora_style import LoraStyle
|
| 26 |
from internals.util.slack import Slack
|
|
@@ -295,17 +293,17 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
|
|
| 295 |
lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
|
| 296 |
lora_patcher.patch()
|
| 297 |
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
|
| 310 |
images, has_nsfw = controlnet.process_pose(
|
| 311 |
prompt=prompt,
|
|
|
|
| 14 |
from internals.pipelines.safety_checker import SafetyChecker
|
| 15 |
from internals.util.args import apply_style_args
|
| 16 |
from internals.util.avatar import Avatar
|
| 17 |
+
from internals.util.cache import (auto_clear_cuda_and_gc, clear_cuda,
|
| 18 |
+
clear_cuda_and_gc)
|
| 19 |
from internals.util.commons import pickPoses, upload_image, upload_images
|
| 20 |
+
from internals.util.config import (num_return_sequences, set_configs_from_task,
|
| 21 |
+
set_root_dir)
|
|
|
|
|
|
|
|
|
|
| 22 |
from internals.util.failure_hander import FailureHandler
|
| 23 |
from internals.util.lora_style import LoraStyle
|
| 24 |
from internals.util.slack import Slack
|
|
|
|
| 293 |
lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
|
| 294 |
lora_patcher.patch()
|
| 295 |
|
| 296 |
+
try:
|
| 297 |
+
infered_pose = pose_detector.transform(
|
| 298 |
+
image=task.get_imageUrl(),
|
| 299 |
+
client_coordinates=task.get_pose_coordinates(),
|
| 300 |
+
width=task.get_width(),
|
| 301 |
+
height=task.get_height(),
|
| 302 |
+
)
|
| 303 |
+
poses = [infered_pose] * num_return_sequences
|
| 304 |
+
except Exception as e:
|
| 305 |
+
print("Failed to detect pose, using Open Pose detector", e)
|
| 306 |
+
poses = [controlnet.detect_pose(task.get_imageUrl())] * num_return_sequences
|
| 307 |
|
| 308 |
images, has_nsfw = controlnet.process_pose(
|
| 309 |
prompt=prompt,
|
internals/pipelines/pose_detector.py
CHANGED
|
@@ -4,7 +4,7 @@ from typing import Optional, Union
|
|
| 4 |
from PIL import Image, ImageDraw
|
| 5 |
from torch import ge
|
| 6 |
|
| 7 |
-
from internals.util.commons import download_file, download_image
|
| 8 |
from internals.util.config import get_root_dir
|
| 9 |
from models.pose.body import Body
|
| 10 |
|
|
@@ -77,16 +77,18 @@ class PoseDetector:
|
|
| 77 |
image = Image.new("RGB", (width, height), "black")
|
| 78 |
draw = ImageDraw.Draw(image)
|
| 79 |
|
| 80 |
-
points = data["candidate"]
|
| 81 |
for pair in self.__pose_logical_map:
|
| 82 |
-
xy = points
|
| 83 |
-
x1y1 = points
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
-
draw.line(
|
| 86 |
-
(xy[0], xy[1], x1y1[0], x1y1[1]),
|
| 87 |
-
fill=pair[2],
|
| 88 |
-
width=4,
|
| 89 |
-
)
|
| 90 |
for i, point in enumerate(points):
|
| 91 |
x = point[0]
|
| 92 |
y = point[1]
|
|
@@ -99,7 +101,7 @@ class PoseDetector:
|
|
| 99 |
subset = []
|
| 100 |
|
| 101 |
if type(image) == str:
|
| 102 |
-
image = download_image(
|
| 103 |
|
| 104 |
image = image.resize((width, height))
|
| 105 |
|
|
|
|
| 4 |
from PIL import Image, ImageDraw
|
| 5 |
from torch import ge
|
| 6 |
|
| 7 |
+
from internals.util.commons import download_file, download_image, safe_index
|
| 8 |
from internals.util.config import get_root_dir
|
| 9 |
from models.pose.body import Body
|
| 10 |
|
|
|
|
| 77 |
image = Image.new("RGB", (width, height), "black")
|
| 78 |
draw = ImageDraw.Draw(image)
|
| 79 |
|
| 80 |
+
points: list = data["candidate"]
|
| 81 |
for pair in self.__pose_logical_map:
|
| 82 |
+
xy = safe_index(points, pair[0] - 1)
|
| 83 |
+
x1y1 = safe_index(points, pair[1] - 1)
|
| 84 |
+
|
| 85 |
+
if xy and x1y1:
|
| 86 |
+
draw.line(
|
| 87 |
+
(xy[0], xy[1], x1y1[0], x1y1[1]),
|
| 88 |
+
fill=pair[2],
|
| 89 |
+
width=4,
|
| 90 |
+
)
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
for i, point in enumerate(points):
|
| 93 |
x = point[0]
|
| 94 |
y = point[1]
|
|
|
|
| 101 |
subset = []
|
| 102 |
|
| 103 |
if type(image) == str:
|
| 104 |
+
image = download_image(image)
|
| 105 |
|
| 106 |
image = image.resize((width, height))
|
| 107 |
|
internals/util/commons.py
CHANGED
|
@@ -5,7 +5,7 @@ import random
|
|
| 5 |
import re
|
| 6 |
from io import BytesIO
|
| 7 |
from pathlib import Path
|
| 8 |
-
from typing import Union
|
| 9 |
|
| 10 |
import boto3
|
| 11 |
import requests
|
|
@@ -191,6 +191,14 @@ def construct_default_s3_url(key):
|
|
| 191 |
return "https://comic-assets.s3.ap-south-1.amazonaws.com/" + key
|
| 192 |
|
| 193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
def read_url(url: str):
|
| 195 |
with urllib.request.urlopen(url) as u:
|
| 196 |
return u.read()
|
|
|
|
| 5 |
import re
|
| 6 |
from io import BytesIO
|
| 7 |
from pathlib import Path
|
| 8 |
+
from typing import Optional, Union
|
| 9 |
|
| 10 |
import boto3
|
| 11 |
import requests
|
|
|
|
| 191 |
return "https://comic-assets.s3.ap-south-1.amazonaws.com/" + key
|
| 192 |
|
| 193 |
|
| 194 |
+
def safe_index(array, index) -> Optional:
|
| 195 |
+
if index < 0:
|
| 196 |
+
return None
|
| 197 |
+
if index >= len(array):
|
| 198 |
+
return None
|
| 199 |
+
return array[index]
|
| 200 |
+
|
| 201 |
+
|
| 202 |
def read_url(url: str):
|
| 203 |
with urllib.request.urlopen(url) as u:
|
| 204 |
return u.read()
|