Spaces:
Running
Running
| from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker | |
| from transformers import CLIPFeatureExtractor | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from typing import Optional, Tuple, Union | |
# Module-level state shared by load_model() and check(). Every name starts
# as None and is only populated when load_model() runs.
device = torch_device = torch_dtype = None
safety_checker = feature_extractor = None
def load_model():
    """Populate the module-level device and model globals read by ``check``.

    Selects CUDA when ``torch.cuda.is_available()`` is true, otherwise the CPU.
    Then loads the Stable Diffusion safety checker (moved onto that device)
    and a CLIP feature extractor via ``from_pretrained``.
    """
    global device, torch_device, torch_dtype, safety_checker, feature_extractor

    backend = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(backend)
    # Alias kept for code that reads `torch_device` rather than `device`.
    torch_device = device
    # NOTE(review): torch_dtype is published but not applied to either model
    # below; no dtype is passed to from_pretrained() or .to().
    torch_dtype = torch.float16

    checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
    safety_checker = checker.to(device)

    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )
def check(image):
    """Run the safety checker on a single image and return its NSFW flag.

    Args:
        image: The image to classify. This is a PIL image, a numpy array, or
            another object the feature extractor accepts (presumably an
            ``HxWxC`` RGB image; TODO confirm with callers).

    Returns:
        ``None`` if *image* is ``None`` or empty. Otherwise, element 0 of the
        checker's ``has_nsfw_concepts`` output.
    """
    # The original `if not image:` raised ValueError ("truth value of an
    # array ... is ambiguous") for multi-element numpy arrays. Test for None
    # and for empty arrays explicitly, and keep plain truthiness for any
    # other input type so existing behavior is unchanged.
    if image is None:
        return None
    if isinstance(image, np.ndarray):
        if image.size == 0:
            return None
    elif not image:
        return None

    # Load the models on first use. Otherwise calling check() before
    # load_model() fails with "'NoneType' object is not callable".
    if safety_checker is None or feature_extractor is None:
        load_model()

    images = [image]
    safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
    images_np = [np.array(img) for img in images]
    # Only the per-image flags are used; the first return value is discarded.
    _, has_nsfw_concepts = safety_checker(
        images=images_np,
        clip_input=safety_checker_input.pixel_values.to(torch_device),
    )
    return has_nsfw_concepts[0]