Spaces:
Runtime error
Runtime error
nsfw filter
Browse files- StableDiffuser.py +15 -1
StableDiffuser.py
CHANGED
|
@@ -5,10 +5,12 @@ from baukit import TraceDict
|
|
| 5 |
from diffusers import AutoencoderKL, UNet2DConditionModel
|
| 6 |
from PIL import Image
|
| 7 |
from tqdm.auto import tqdm
|
| 8 |
-
from transformers import CLIPTextModel, CLIPTokenizer
|
|
|
|
| 9 |
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
| 10 |
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
| 11 |
from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
|
|
|
|
| 12 |
import util
|
| 13 |
|
| 14 |
|
|
@@ -53,6 +55,9 @@ class StableDiffuser(torch.nn.Module):
|
|
| 53 |
self.unet = UNet2DConditionModel.from_pretrained(
|
| 54 |
"CompVis/stable-diffusion-v1-4", subfolder="unet")
|
| 55 |
|
|
|
|
|
|
|
|
|
|
| 56 |
if scheduler == 'LMS':
|
| 57 |
self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
|
| 58 |
elif scheduler == 'DDIM':
|
|
@@ -237,6 +242,15 @@ class StableDiffuser(torch.nn.Module):
|
|
| 237 |
latents_steps = [self.decode(latents.to(self.unet.device)) for latents in latents_steps]
|
| 238 |
images_steps = [self.to_image(latents) for latents in latents_steps]
|
| 239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
images_steps = list(zip(*images_steps))
|
| 241 |
|
| 242 |
if trace_steps:
|
|
|
|
| 5 |
from diffusers import AutoencoderKL, UNet2DConditionModel
|
| 6 |
from PIL import Image
|
| 7 |
from tqdm.auto import tqdm
|
| 8 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
|
| 9 |
+
from diffusers.schedulers import EulerAncestralDiscreteScheduler
|
| 10 |
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
| 11 |
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
| 12 |
from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
|
| 13 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
| 14 |
import util
|
| 15 |
|
| 16 |
|
|
|
|
| 55 |
self.unet = UNet2DConditionModel.from_pretrained(
|
| 56 |
"CompVis/stable-diffusion-v1-4", subfolder="unet")
|
| 57 |
|
| 58 |
+
self.feature_extractor = CLIPFeatureExtractor.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="feature_extractor")
|
| 59 |
+
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="safety_checker")
|
| 60 |
+
|
| 61 |
if scheduler == 'LMS':
|
| 62 |
self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
|
| 63 |
elif scheduler == 'DDIM':
|
|
|
|
| 242 |
latents_steps = [self.decode(latents.to(self.unet.device)) for latents in latents_steps]
|
| 243 |
images_steps = [self.to_image(latents) for latents in latents_steps]
|
| 244 |
|
| 245 |
+
for i in range(len(images_steps)):
|
| 246 |
+
self.safety_checker = self.safety_checker.float()
|
| 247 |
+
safety_checker_input = self.feature_extractor(images_steps[i], return_tensors="pt").to(latents_steps[0].device)
|
| 248 |
+
image, has_nsfw_concept = self.safety_checker(
|
| 249 |
+
images=latents_steps[i].float().cpu().numpy(), clip_input=safety_checker_input.pixel_values.float()
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
images_steps[i][0] = self.to_image(torch.from_numpy(image))[0]
|
| 253 |
+
|
| 254 |
images_steps = list(zip(*images_steps))
|
| 255 |
|
| 256 |
if trace_steps:
|