Spaces:
Paused
Paused
Upload app.py
Browse files
app.py
CHANGED
|
@@ -26,12 +26,13 @@ from transformers import CLIPImageProcessor
|
|
| 26 |
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
| 27 |
|
| 28 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 29 |
|
| 30 |
BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"
|
| 31 |
|
| 32 |
# Initialize both pipelines
|
| 33 |
-
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=
|
| 34 |
-
controlnet = ControlNetModel.from_pretrained("monster-labs/control_v1p_sd15_qrcode_monster", torch_dtype=
|
| 35 |
|
| 36 |
# Initialize the safety checker conditionally
|
| 37 |
SAFETY_CHECKER_ENABLED = os.environ.get("SAFETY_CHECKER", "0") == "1"
|
|
@@ -47,7 +48,7 @@ main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
|
| 47 |
vae=vae,
|
| 48 |
safety_checker=safety_checker,
|
| 49 |
feature_extractor=feature_extractor,
|
| 50 |
-
torch_dtype=
|
| 51 |
).to(device)
|
| 52 |
|
| 53 |
# Function to check NSFW images
|
|
|
|
| 26 |
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
| 27 |
|
| 28 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 29 |
+
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 30 |
|
| 31 |
BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"
|
| 32 |
|
| 33 |
# Initialize both pipelines
|
| 34 |
+
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype)
|
| 35 |
+
controlnet = ControlNetModel.from_pretrained("monster-labs/control_v1p_sd15_qrcode_monster", torch_dtype=dtype)
|
| 36 |
|
| 37 |
# Initialize the safety checker conditionally
|
| 38 |
SAFETY_CHECKER_ENABLED = os.environ.get("SAFETY_CHECKER", "0") == "1"
|
|
|
|
| 48 |
vae=vae,
|
| 49 |
safety_checker=safety_checker,
|
| 50 |
feature_extractor=feature_extractor,
|
| 51 |
+
torch_dtype=dtype,
|
| 52 |
).to(device)
|
| 53 |
|
| 54 |
# Function to check NSFW images
|