Spaces:
Sleeping
Sleeping
| import re | |
| import requests | |
| import base64 | |
| from PIL import Image | |
| from io import BytesIO | |
| from transformers import ImageClassificationPipeline | |
| def get_normal_classifier(items: list[object])->object | None: | |
| normal_classifier = next((item for item in items if item["label"] == "normal"), None) | |
| return normal_classifier | |
| def get_nsfw_classifier(items: list[object])->object | None: | |
| nsfw_classifier = next((item for item in items if item["label"] == "nsfw"), None) | |
| return nsfw_classifier | |
# Matches "data:image/<ext>;base64,<payload>". re.DOTALL lets the payload span
# line breaks: without it ".*" stops at the first newline and line-wrapped
# base64 is silently truncated. b64decode (non-validating by default) then
# discards the embedded newlines.
_DATA_URL_PATTERN = re.compile(
    r'data:image/(?P<ext>\w+);base64,(?P<data>.*)', re.DOTALL
)

# Seconds to wait on the remote server when downloading an image URL.
DEFAULT_DOWNLOAD_TIMEOUT = 30.0


def _open_image_from_data_url(image_url: str) -> "Image.Image":
    """Decode a ``data:image/<ext>;base64,...`` URL into a PIL image.

    Raises:
        ValueError: if ``image_url`` does not have the expected data-URL shape.
        binascii.Error: (a ``ValueError`` subclass) if the base64 payload is
            malformed, e.g. incorrect padding.
    """
    print("Processing base64 data URL")
    match = _DATA_URL_PATTERN.match(image_url)
    if match is None:
        raise ValueError("Invalid base64 data URL format")
    image_bytes = base64.b64decode(match.group('data'))
    return Image.open(BytesIO(image_bytes))


def _open_image_from_url(image_url: str, timeout: float | None) -> "Image.Image":
    """Download ``image_url`` and open the response body as a PIL image.

    Raises:
        requests.HTTPError: on a 4xx/5xx response (via ``raise_for_status``).
        requests.Timeout: if the server does not respond within ``timeout``
            seconds.
    """
    print("Processing regular URL")
    # requests waits indefinitely when no timeout is passed, which would hang
    # the caller on an unresponsive host.
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


def classify_image_if_nsfw(
    classifier: ImageClassificationPipeline,
    image_url: str,
    *,
    timeout: float | None = DEFAULT_DOWNLOAD_TIMEOUT,
):
    """Load an image from a URL or base64 data URL and run ``classifier`` on it.

    Args:
        classifier: Callable image-classification pipeline; it is invoked as
            ``classifier(img)`` with an RGB PIL image.
        image_url: Either a ``data:image/<ext>;base64,...`` data URL or a
            regular URL fetched with ``requests.get``.
        timeout: Seconds to wait for the remote server when ``image_url`` is a
            regular URL. ``None`` waits indefinitely, which was the previous
            behaviour.

    Returns:
        Whatever ``classifier(img)`` returns, unchanged. This is presumably a
        list of ``{"label", "score"}`` dicts; confirm against the pipeline in use.

    Raises:
        Any exception raised while loading, decoding or classifying the image.
        The exception is printed and then re-raised unchanged.
    """
    try:
        if image_url.startswith('data:image'):
            img = _open_image_from_data_url(image_url)
        else:
            # NOTE(review): if image_url is user-supplied, fetching arbitrary
            # URLs server-side is an SSRF risk, and huge images can exhaust
            # memory. Consider a host allow-list and a size cap.
            img = _open_image_from_url(image_url, timeout)

        print("Image size:", img.size)
        print("Image format:", img.format)
        print("Image mode:", img.mode)

        # Ensure the image is in RGB mode (required by most models).
        if img.mode != 'RGB':
            img = img.convert('RGB')

        classifier_response = classifier(img)
        print("Classifier Response:", classifier_response)
        return classifier_response
    except Exception as e:
        print(f"Error processing image: {e}")
        raise