Upload folder using huggingface_hub
Browse files- inference.py +11 -16
- inference2.py +9 -12
- internals/pipelines/safety_checker.py +17 -7
inference.py
CHANGED
|
@@ -14,18 +14,12 @@ from internals.pipelines.prompt_modifier import PromptModifier
|
|
| 14 |
from internals.pipelines.safety_checker import SafetyChecker
|
| 15 |
from internals.util.args import apply_style_args
|
| 16 |
from internals.util.avatar import Avatar
|
| 17 |
-
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda,
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
)
|
| 24 |
-
from internals.util.config import (
|
| 25 |
-
num_return_sequences,
|
| 26 |
-
set_configs_from_task,
|
| 27 |
-
set_root_dir,
|
| 28 |
-
)
|
| 29 |
from internals.util.failure_hander import FailureHandler
|
| 30 |
from internals.util.lora_style import LoraStyle
|
| 31 |
from internals.util.slack import Slack
|
|
@@ -455,10 +449,6 @@ def model_fn(model_dir):
|
|
| 455 |
img2img_pipe.create(text2img_pipe)
|
| 456 |
inpainter.create(text2img_pipe)
|
| 457 |
|
| 458 |
-
safety_checker.apply(text2img_pipe)
|
| 459 |
-
safety_checker.apply(img2img_pipe)
|
| 460 |
-
safety_checker.apply(controlnet)
|
| 461 |
-
|
| 462 |
print("Logs: model loaded ....")
|
| 463 |
return
|
| 464 |
|
|
@@ -474,6 +464,11 @@ def predict_fn(data, pipe):
|
|
| 474 |
# Set set_environment
|
| 475 |
set_configs_from_task(task)
|
| 476 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 477 |
# Apply arguments
|
| 478 |
apply_style_args(data)
|
| 479 |
|
|
|
|
| 14 |
from internals.pipelines.safety_checker import SafetyChecker
|
| 15 |
from internals.util.args import apply_style_args
|
| 16 |
from internals.util.avatar import Avatar
|
| 17 |
+
from internals.util.cache import (auto_clear_cuda_and_gc, clear_cuda,
|
| 18 |
+
clear_cuda_and_gc)
|
| 19 |
+
from internals.util.commons import (download_image, pickPoses, upload_image,
|
| 20 |
+
upload_images)
|
| 21 |
+
from internals.util.config import (num_return_sequences, set_configs_from_task,
|
| 22 |
+
set_root_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
from internals.util.failure_hander import FailureHandler
|
| 24 |
from internals.util.lora_style import LoraStyle
|
| 25 |
from internals.util.slack import Slack
|
|
|
|
| 449 |
img2img_pipe.create(text2img_pipe)
|
| 450 |
inpainter.create(text2img_pipe)
|
| 451 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
print("Logs: model loaded ....")
|
| 453 |
return
|
| 454 |
|
|
|
|
| 464 |
# Set set_environment
|
| 465 |
set_configs_from_task(task)
|
| 466 |
|
| 467 |
+
# Apply safety checkers based on environment
|
| 468 |
+
safety_checker.apply(text2img_pipe)
|
| 469 |
+
safety_checker.apply(img2img_pipe)
|
| 470 |
+
safety_checker.apply(controlnet)
|
| 471 |
+
|
| 472 |
# Apply arguments
|
| 473 |
apply_style_args(data)
|
| 474 |
|
inference2.py
CHANGED
|
@@ -7,18 +7,17 @@ from internals.data.task import ModelType, Task, TaskType
|
|
| 7 |
from internals.pipelines.inpainter import InPainter
|
| 8 |
from internals.pipelines.object_remove import ObjectRemoval
|
| 9 |
from internals.pipelines.prompt_modifier import PromptModifier
|
| 10 |
-
from internals.pipelines.remove_background import RemoveBackground,
|
|
|
|
| 11 |
from internals.pipelines.replace_background import ReplaceBackground
|
| 12 |
from internals.pipelines.safety_checker import SafetyChecker
|
| 13 |
from internals.pipelines.upscaler import Upscaler
|
| 14 |
from internals.util.avatar import Avatar
|
| 15 |
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
|
| 16 |
-
from internals.util.commons import construct_default_s3_url, upload_image,
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
set_root_dir,
|
| 21 |
-
)
|
| 22 |
from internals.util.failure_hander import FailureHandler
|
| 23 |
from internals.util.slack import Slack
|
| 24 |
|
|
@@ -173,8 +172,6 @@ def model_fn(model_dir):
|
|
| 173 |
|
| 174 |
replace_background.load(upscaler, remove_background_v2)
|
| 175 |
|
| 176 |
-
safety_checker.apply(inpainter)
|
| 177 |
-
|
| 178 |
print("Logs: model loaded ....")
|
| 179 |
return
|
| 180 |
|
|
@@ -186,13 +183,13 @@ def predict_fn(data, pipe):
|
|
| 186 |
|
| 187 |
FailureHandler.handle(task)
|
| 188 |
|
| 189 |
-
# Set set_environment
|
| 190 |
-
set_configs_from_task(task)
|
| 191 |
-
|
| 192 |
try:
|
| 193 |
# Set set_environment
|
| 194 |
set_configs_from_task(task)
|
| 195 |
|
|
|
|
|
|
|
|
|
|
| 196 |
# Fetch avatars
|
| 197 |
avatar.fetch_from_network(task.get_model_id())
|
| 198 |
|
|
|
|
| 7 |
from internals.pipelines.inpainter import InPainter
|
| 8 |
from internals.pipelines.object_remove import ObjectRemoval
|
| 9 |
from internals.pipelines.prompt_modifier import PromptModifier
|
| 10 |
+
from internals.pipelines.remove_background import (RemoveBackground,
|
| 11 |
+
RemoveBackgroundV2)
|
| 12 |
from internals.pipelines.replace_background import ReplaceBackground
|
| 13 |
from internals.pipelines.safety_checker import SafetyChecker
|
| 14 |
from internals.pipelines.upscaler import Upscaler
|
| 15 |
from internals.util.avatar import Avatar
|
| 16 |
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
|
| 17 |
+
from internals.util.commons import (construct_default_s3_url, upload_image,
|
| 18 |
+
upload_images)
|
| 19 |
+
from internals.util.config import (num_return_sequences, set_configs_from_task,
|
| 20 |
+
set_root_dir)
|
|
|
|
|
|
|
| 21 |
from internals.util.failure_hander import FailureHandler
|
| 22 |
from internals.util.slack import Slack
|
| 23 |
|
|
|
|
| 172 |
|
| 173 |
replace_background.load(upscaler, remove_background_v2)
|
| 174 |
|
|
|
|
|
|
|
| 175 |
print("Logs: model loaded ....")
|
| 176 |
return
|
| 177 |
|
|
|
|
| 183 |
|
| 184 |
FailureHandler.handle(task)
|
| 185 |
|
|
|
|
|
|
|
|
|
|
| 186 |
try:
|
| 187 |
# Set set_environment
|
| 188 |
set_configs_from_task(task)
|
| 189 |
|
| 190 |
+
# Apply safety checker based on environment
|
| 191 |
+
safety_checker.apply(inpainter)
|
| 192 |
+
|
| 193 |
# Fetch avatars
|
| 194 |
avatar.fetch_from_network(task.get_model_id())
|
| 195 |
|
internals/pipelines/safety_checker.py
CHANGED
|
@@ -4,6 +4,7 @@ import cv2
|
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
|
|
|
| 7 |
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
|
| 8 |
|
| 9 |
from internals.pipelines.commons import AbstractPipeline
|
|
@@ -23,10 +24,17 @@ class SafetyChecker:
|
|
| 23 |
).to("cuda")
|
| 24 |
|
| 25 |
def apply(self, pipeline: AbstractPipeline):
|
|
|
|
| 26 |
if hasattr(pipeline, "pipe"):
|
| 27 |
-
pipeline.pipe.safety_checker =
|
| 28 |
if hasattr(pipeline, "pipe2"):
|
| 29 |
-
pipeline.pipe2.safety_checker =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
class StableDiffusionSafetyCheckerV2(PreTrainedModel):
|
|
@@ -102,7 +110,7 @@ class StableDiffusionSafetyCheckerV2(PreTrainedModel):
|
|
| 102 |
result_img["concept_scores"][concept_idx] = round(
|
| 103 |
concept_cos - concept_threshold + adjustment, 3
|
| 104 |
)
|
| 105 |
-
if result_img["concept_scores"][concept_idx] >
|
| 106 |
result_img["bad_concepts"].append(concept_idx)
|
| 107 |
|
| 108 |
result.append(result_img)
|
|
@@ -115,11 +123,12 @@ class StableDiffusionSafetyCheckerV2(PreTrainedModel):
|
|
| 115 |
if any(has_nsfw_concepts) and not get_nsfw_access():
|
| 116 |
if torch.is_tensor(images) or torch.is_tensor(images[0]):
|
| 117 |
image = images[idx].cpu().numpy().astype(np.float32)
|
| 118 |
-
image =
|
|
|
|
| 119 |
image = torch.from_numpy(image)
|
| 120 |
images[idx] = image
|
| 121 |
else:
|
| 122 |
-
images[idx] =
|
| 123 |
|
| 124 |
if any(has_nsfw_concepts):
|
| 125 |
print("NSFW")
|
|
@@ -150,13 +159,14 @@ class StableDiffusionSafetyCheckerV2(PreTrainedModel):
|
|
| 150 |
|
| 151 |
concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
|
| 152 |
# concept_scores = concept_scores.round(decimals=3)
|
| 153 |
-
has_nsfw_concepts = torch.any(concept_scores >
|
| 154 |
|
|
|
|
| 155 |
# Blur images based on NSFW score
|
| 156 |
# -------------------------------
|
| 157 |
if not get_nsfw_access():
|
| 158 |
image = images[has_nsfw_concepts].cpu().numpy().astype(np.float32)
|
| 159 |
-
image =
|
| 160 |
image = torch.from_numpy(image)
|
| 161 |
images[has_nsfw_concepts] = image
|
| 162 |
|
|
|
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
| 7 |
+
from scipy.ndimage.filters import gaussian_filter
|
| 8 |
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
|
| 9 |
|
| 10 |
from internals.pipelines.commons import AbstractPipeline
|
|
|
|
| 24 |
).to("cuda")
|
| 25 |
|
| 26 |
def apply(self, pipeline: AbstractPipeline):
|
| 27 |
+
model = self.model if not get_nsfw_access() else None
|
| 28 |
if hasattr(pipeline, "pipe"):
|
| 29 |
+
pipeline.pipe.safety_checker = model
|
| 30 |
if hasattr(pipeline, "pipe2"):
|
| 31 |
+
pipeline.pipe2.safety_checker = model
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def cosine_distance(image_embeds, text_embeds):
|
| 35 |
+
normalized_image_embeds = nn.functional.normalize(image_embeds)
|
| 36 |
+
normalized_text_embeds = nn.functional.normalize(text_embeds)
|
| 37 |
+
return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
|
| 38 |
|
| 39 |
|
| 40 |
class StableDiffusionSafetyCheckerV2(PreTrainedModel):
|
|
|
|
| 110 |
result_img["concept_scores"][concept_idx] = round(
|
| 111 |
concept_cos - concept_threshold + adjustment, 3
|
| 112 |
)
|
| 113 |
+
if result_img["concept_scores"][concept_idx] > 0:
|
| 114 |
result_img["bad_concepts"].append(concept_idx)
|
| 115 |
|
| 116 |
result.append(result_img)
|
|
|
|
| 123 |
if any(has_nsfw_concepts) and not get_nsfw_access():
|
| 124 |
if torch.is_tensor(images) or torch.is_tensor(images[0]):
|
| 125 |
image = images[idx].cpu().numpy().astype(np.float32)
|
| 126 |
+
image = gaussian_filter(image, sigma=7)
|
| 127 |
+
# image = cv2.blur(image, (30, 30))
|
| 128 |
image = torch.from_numpy(image)
|
| 129 |
images[idx] = image
|
| 130 |
else:
|
| 131 |
+
images[idx] = gaussian_filter(images[idx], sigma=7)
|
| 132 |
|
| 133 |
if any(has_nsfw_concepts):
|
| 134 |
print("NSFW")
|
|
|
|
| 159 |
|
| 160 |
concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
|
| 161 |
# concept_scores = concept_scores.round(decimals=3)
|
| 162 |
+
has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
|
| 163 |
|
| 164 |
+
# images[has_nsfw_concepts] = 0.0 # black image
|
| 165 |
# Blur images based on NSFW score
|
| 166 |
# -------------------------------
|
| 167 |
if not get_nsfw_access():
|
| 168 |
image = images[has_nsfw_concepts].cpu().numpy().astype(np.float32)
|
| 169 |
+
image = gaussian_filter(image, sigma=7)
|
| 170 |
image = torch.from_numpy(image)
|
| 171 |
images[has_nsfw_concepts] = image
|
| 172 |
|