Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -21,38 +21,7 @@ pipe = pipe.to(device)
|
|
| 21 |
MAX_SEED = np.iinfo(np.int32).max
|
| 22 |
MAX_IMAGE_SIZE = 1024
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
# 1. FalconsAI safety classifier (post-generation)
|
| 27 |
-
from transformers import AutoProcessor, AutoModelForImageClassification
|
| 28 |
-
|
| 29 |
-
falcon_processor = AutoProcessor.from_pretrained("Falconsai/nsfw_image_detection")
|
| 30 |
-
falcon_model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection").to(device)
|
| 31 |
-
|
| 32 |
-
def falcon_check(image):
|
| 33 |
-
inputs = falcon_processor(images=image, return_tensors="pt").to(device)
|
| 34 |
-
with torch.no_grad():
|
| 35 |
-
outputs = falcon_model(**inputs)
|
| 36 |
-
logits = outputs.logits.softmax(dim=-1)[0]
|
| 37 |
-
unsafe_prob = float(logits[1]) # class index 1 = NSFW
|
| 38 |
-
return unsafe_prob
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
# 2. Secondary NSFW checker (example: CLIP-based)
|
| 42 |
-
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
| 43 |
-
from transformers import AutoFeatureExtractor
|
| 44 |
-
|
| 45 |
-
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
| 46 |
-
clip_safety = StableDiffusionSafetyChecker.from_pretrained(
|
| 47 |
-
"CompVis/stable-diffusion-safety-checker"
|
| 48 |
-
).to(device)
|
| 49 |
-
|
| 50 |
-
def clip_check(image):
|
| 51 |
-
# convert to PIL list format
|
| 52 |
-
images = [image]
|
| 53 |
-
safety_inputs = feature_extractor(images, return_tensors="pt").to(device)
|
| 54 |
-
checked_images, has_nsfw = clip_safety(images, safety_inputs.pixel_values)
|
| 55 |
-
return bool(has_nsfw[0])
|
| 56 |
# @spaces.GPU #[uncomment to use ZeroGPU]
|
| 57 |
def infer(
|
| 58 |
prompt,
|
|
@@ -70,7 +39,6 @@ def infer(
|
|
| 70 |
|
| 71 |
generator = torch.Generator().manual_seed(seed)
|
| 72 |
|
| 73 |
-
# generate image normally
|
| 74 |
image = pipe(
|
| 75 |
prompt=prompt,
|
| 76 |
negative_prompt=negative_prompt,
|
|
@@ -81,22 +49,6 @@ def infer(
|
|
| 81 |
generator=generator,
|
| 82 |
).images[0]
|
| 83 |
|
| 84 |
-
# ---------------- SAFETY CHECK 1: FalconsAI ---------------- #
|
| 85 |
-
unsafe_prob = falcon_check(image)
|
| 86 |
-
if unsafe_prob > 0.35: # adjust threshold if needed
|
| 87 |
-
return (
|
| 88 |
-
np.zeros((512, 512, 3), dtype=np.uint8), # blank output
|
| 89 |
-
f"⚠️ Blocked by FalconsAI safety model | score={unsafe_prob:.2f}",
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
# ---------------- SAFETY CHECK 2: CLIP Checker ---------------- #
|
| 93 |
-
if clip_check(image):
|
| 94 |
-
return (
|
| 95 |
-
np.zeros((512, 512, 3), dtype=np.uint8),
|
| 96 |
-
"⛔ Blocked by CLIP NSFW detector",
|
| 97 |
-
)
|
| 98 |
-
|
| 99 |
-
# no flags — return real image
|
| 100 |
return image, seed
|
| 101 |
|
| 102 |
|
|
@@ -135,8 +87,8 @@ with gr.Blocks(css=css) as demo:
|
|
| 135 |
label="Negative prompt",
|
| 136 |
max_lines=1,
|
| 137 |
placeholder="Enter a negative prompt",
|
| 138 |
-
value = "(
|
| 139 |
-
visible=
|
| 140 |
)
|
| 141 |
|
| 142 |
seed = gr.Slider(
|
|
|
|
| 21 |
MAX_SEED = np.iinfo(np.int32).max
|
| 22 |
MAX_IMAGE_SIZE = 1024
|
| 23 |
|
| 24 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# @spaces.GPU #[uncomment to use ZeroGPU]
|
| 26 |
def infer(
|
| 27 |
prompt,
|
|
|
|
| 39 |
|
| 40 |
generator = torch.Generator().manual_seed(seed)
|
| 41 |
|
|
|
|
| 42 |
image = pipe(
|
| 43 |
prompt=prompt,
|
| 44 |
negative_prompt=negative_prompt,
|
|
|
|
| 49 |
generator=generator,
|
| 50 |
).images[0]
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
return image, seed
|
| 53 |
|
| 54 |
|
|
|
|
| 87 |
label="Negative prompt",
|
| 88 |
max_lines=1,
|
| 89 |
placeholder="Enter a negative prompt",
|
| 90 |
+
value = "(low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
|
| 91 |
+
visible=True,
|
| 92 |
)
|
| 93 |
|
| 94 |
seed = gr.Slider(
|