Update app.py
Browse files
app.py
CHANGED
|
@@ -7,7 +7,7 @@ import uuid
|
|
| 7 |
import os
|
| 8 |
|
| 9 |
from diffusers import StableDiffusionXLPipeline, StableDiffusion3Pipeline
|
| 10 |
-
from transformers import
|
| 11 |
from PIL import Image
|
| 12 |
|
| 13 |
# Pre-Initialize
|
|
@@ -32,7 +32,8 @@ footer {
|
|
| 32 |
}
|
| 33 |
'''
|
| 34 |
|
| 35 |
-
repo_nsfw_classifier =
|
|
|
|
| 36 |
|
| 37 |
repo_default = StableDiffusionXLPipeline.from_pretrained("fluently/Fluently-XL-Final", torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False)
|
| 38 |
repo_default.load_lora_weights("ehristoforu/dalle-3-xl-v2", adapter_name="base")
|
|
@@ -99,7 +100,8 @@ def generate(input=DEFAULT_INPUT, filter_input="", negative_input=DEFAULT_NEGATI
|
|
| 99 |
|
| 100 |
print(steps, guidance)
|
| 101 |
|
| 102 |
-
repo_nsfw_classifier.to(DEVICE)
|
|
|
|
| 103 |
repo.to(DEVICE)
|
| 104 |
|
| 105 |
parameters = {
|
|
@@ -119,7 +121,7 @@ def generate(input=DEFAULT_INPUT, filter_input="", negative_input=DEFAULT_NEGATI
|
|
| 119 |
|
| 120 |
print(image_paths)
|
| 121 |
|
| 122 |
-
nsfw_prediction = repo_nsfw_classifier(Image.open(image_paths[0]))
|
| 123 |
|
| 124 |
print(nsfw_prediction)
|
| 125 |
|
|
|
|
| 7 |
import os
|
| 8 |
|
| 9 |
from diffusers import StableDiffusionXLPipeline, StableDiffusion3Pipeline
|
| 10 |
+
from transformers import AutoModelForImageClassification, ViTImageProcessor
|
| 11 |
from PIL import Image
|
| 12 |
|
| 13 |
# Pre-Initialize
|
|
|
|
| 32 |
}
|
| 33 |
'''
|
| 34 |
|
| 35 |
+
repo_nsfw_classifier = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection")
|
| 36 |
+
processor_nsfw_classifier = ViTImageProcessor.from_pretrained("Falconsai/nsfw_image_detection")
|
| 37 |
|
| 38 |
repo_default = StableDiffusionXLPipeline.from_pretrained("fluently/Fluently-XL-Final", torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False)
|
| 39 |
repo_default.load_lora_weights("ehristoforu/dalle-3-xl-v2", adapter_name="base")
|
|
|
|
| 100 |
|
| 101 |
print(steps, guidance)
|
| 102 |
|
| 103 |
+
#repo_nsfw_classifier.to(DEVICE)
|
| 104 |
+
#processor_nsfw_classifier.to(DEVICE)
|
| 105 |
repo.to(DEVICE)
|
| 106 |
|
| 107 |
parameters = {
|
|
|
|
| 121 |
|
| 122 |
print(image_paths)
|
| 123 |
|
| 124 |
+
nsfw_prediction = repo_nsfw_classifier(**processor(images=Image.open(image_paths[0]), return_tensors="pt")).logits
|
| 125 |
|
| 126 |
print(nsfw_prediction)
|
| 127 |
|