Spaces:
Sleeping
Sleeping
File size: 2,109 Bytes
29c07d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import re
import requests
import base64
from PIL import Image
from io import BytesIO
from transformers import ImageClassificationPipeline
def get_normal_classifier(items: list[object])->object | None:
normal_classifier = next((item for item in items if item["label"] == "normal"), None)
return normal_classifier
def get_nsfw_classifier(items: list[object])->object | None:
nsfw_classifier = next((item for item in items if item["label"] == "nsfw"), None)
return nsfw_classifier
# Matches an RFC 2397 data URL carrying base64-encoded image bytes, e.g.
# "data:image/png;base64,iVBOR..." or "data:image/x-icon;name=a.ico;base64,...".
# DOTALL lets the payload span line breaks (line-wrapped base64); without it
# ".*" stops at the first newline and silently truncates the image data.
_BASE64_IMAGE_DATA_URL = re.compile(
    r"data:image/[\w.+-]+(?:;[^;,]+)*?;base64,(?P<data>.*)",
    re.DOTALL,
)

# Default number of seconds to wait on the remote server when downloading.
_DEFAULT_DOWNLOAD_TIMEOUT = 30.0


def _open_image_from_data_url(image_url: str):
    """Decode a base64 ``data:image/...`` URL and open it as a PIL image.

    Raises:
        ValueError: if *image_url* is not a base64 image data URL.
    """
    match = _BASE64_IMAGE_DATA_URL.match(image_url)
    if not match:
        raise ValueError("Invalid base64 data URL format")
    # Non-validating b64decode discards non-alphabet characters, such as the
    # line breaks that DOTALL allows into the captured payload.
    image_data = base64.b64decode(match.group("data"))
    return Image.open(BytesIO(image_data))


def _download_image(image_url: str, timeout: float | None):
    """Fetch *image_url* with ``requests`` and open the body as a PIL image.

    Raises:
        requests.HTTPError: when the server answers with an error status.
        requests.Timeout: when the server does not respond within *timeout*.
    """
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


def classify_image_if_nsfw(
    classifier: ImageClassificationPipeline,
    image_url: str,
    *,
    timeout: float | None = _DEFAULT_DOWNLOAD_TIMEOUT,
):
    """Load an image from a URL or base64 data URL and run *classifier* on it.

    Args:
        classifier: image-classification pipeline; it is called with a single
            RGB PIL image.
        image_url: either a ``data:image/<type>[;params];base64,<payload>`` URL
            or any URL that ``requests.get`` can download.
        timeout: seconds to wait for the remote server when downloading a
            regular URL; ``None`` waits indefinitely. Ignored for data URLs.

    Returns:
        The classifier's output, unchanged. This is presumably a list of
        ``{"label", "score"}`` dicts; ``get_normal_classifier`` and
        ``get_nsfw_classifier`` look entries up by ``"label"``.

    Raises:
        ValueError: if a ``data:image`` URL is not base64-encoded as expected.
        Any exception from downloading, decoding, or classifying is printed
        and then re-raised unchanged.
    """
    try:
        if image_url.startswith('data:image'):
            print("Processing base64 data URL")
            img = _open_image_from_data_url(image_url)
        else:
            print("Processing regular URL")
            img = _download_image(image_url, timeout)
        print("Image size:", img.size)
        print("Image format:", img.format)
        print("Image mode:", img.mode)
        # Most image-classification models expect 3-channel RGB input
        # (palette, RGBA, and grayscale images are converted).
        if img.mode != 'RGB':
            img = img.convert('RGB')
        classifier_response = classifier(img)
        print("Classifier Response:", classifier_response)
        return classifier_response
    except Exception as e:
        print(f"Error processing image: {e}")
        raise