# Hugging Face Space: image validation API (CLIP text-image matching + NudeNet NSFW screening).
import os
import tempfile
from io import BytesIO

import requests
import torch
from PIL import Image

import clip
import gradio as gr
from nudenet import NudeDetector
class ImageTextMatcher:
    """Validate product images: reject NSFW content and verify that an
    image matches a text description using CLIP cosine similarity.

    All public methods return plain dicts so they can be serialized
    directly as a JSON API response.
    """

    def __init__(self, nsfw_threshold=0.5):
        """Load CLIP and the NudeNet detector once; models are reused per call.

        Args:
            nsfw_threshold: mean NudeNet detection score above which an
                image is flagged as NSFW.
        """
        # CLIP inference is much faster on CUDA when it is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.clip_model, self.preprocess = clip.load("ViT-B/32", device=self.device)
        self.nsfw_detector = NudeDetector()
        self.nsfw_threshold = nsfw_threshold

    def load_image(self, image_path_or_url):
        """Load a PIL image from a local path or an http(s) URL.

        Returns the PIL Image, or None on any failure (network error,
        HTTP error status, bad path, corrupt file).
        """
        try:
            if image_path_or_url.startswith(('http://', 'https://')):
                # Timeout keeps the API from hanging on a dead host;
                # raise_for_status stops us from parsing a 404 HTML page as an image.
                response = requests.get(image_path_or_url, timeout=10)
                response.raise_for_status()
                img = Image.open(BytesIO(response.content))
            else:
                img = Image.open(image_path_or_url)
            return img
        except Exception:
            # API boundary: callers treat None as "could not load".
            return None

    def check_nsfw_content(self, image, image_path=None):
        """Run NudeNet on the image and average its detection scores.

        Args:
            image: PIL Image (may be None, which yields an error dict).
            image_path: local path if the image came from disk; None or a
                URL means the image must be written to a temp file first
                (NudeNet only accepts file paths).

        Returns:
            {"is_nsfw": bool, "nsfw_score": float} on success, or
            {"is_nsfw": False, "error": str} on failure.
        """
        if image is None:
            return {"is_nsfw": False, "error": "Failed to load image"}
        temp_path = None
        try:
            if image_path is None or image_path.startswith(('http://', 'https://')):
                # Unique temp file per call so concurrent requests don't
                # clobber each other (the original reused "temp_image.jpg").
                fd, temp_path = tempfile.mkstemp(suffix=".jpg")
                os.close(fd)
                # JPEG cannot store an alpha channel; convert RGBA/P first
                # or Image.save raises OSError.
                image.convert("RGB").save(temp_path)
                image_path = temp_path
            detections = self.nsfw_detector.detect(image_path)
            if not detections:
                nsfw_score = 0.0
                is_nsfw = False
            else:
                # Mean of all detection confidences, compared to the threshold.
                nsfw_scores = [d['score'] for d in detections]
                nsfw_score = sum(nsfw_scores) / len(nsfw_scores)
                is_nsfw = nsfw_score > self.nsfw_threshold
            return {
                "is_nsfw": is_nsfw,
                "nsfw_score": nsfw_score
            }
        except Exception as e:
            return {"is_nsfw": False, "error": str(e)}
        finally:
            # Always clean up the temp file (the original leaked it).
            if temp_path is not None and os.path.exists(temp_path):
                os.remove(temp_path)

    def check_image_text_match(self, image, description, threshold=0.25):
        """Score CLIP similarity between the image and a text description.

        Returns:
            {"match": bool, "similarity_score": float}; similarity is the
            cosine similarity of L2-normalized CLIP embeddings.
        """
        if image is None:
            return {"match": False, "similarity_score": 0.0, "error": "Failed to load image"}
        image_input = self.preprocess(image).unsqueeze(0).to(self.device)
        text_inputs = clip.tokenize([description]).to(self.device)
        with torch.no_grad():
            image_features = self.clip_model.encode_image(image_input)
            text_features = self.clip_model.encode_text(text_inputs)
            # Normalize so the dot product below is cosine similarity.
            image_features = image_features / image_features.norm(dim=1, keepdim=True)
            text_features = text_features / text_features.norm(dim=1, keepdim=True)
            similarity = torch.matmul(image_features, text_features.T).item()
        return {
            "match": similarity > threshold,
            "similarity_score": float(similarity)
        }

    def validate_item(self, image_url, description, match_threshold=0.25):
        """Full validation pipeline: load -> NSFW screen -> description match.

        Returns:
            {"valid": 1, "reason": None} when the image loads, is safe,
            and matches the description; otherwise {"valid": 0, "reason": str}.
        """
        image = self.load_image(image_url)
        if image is None:
            return {"valid": 0, "reason": "Failed to load image"}
        nsfw_result = self.check_nsfw_content(image, image_url)
        if nsfw_result.get("is_nsfw", False):
            return {"valid": 0, "reason": "NSFW content detected"}
        match_result = self.check_image_text_match(image, description, match_threshold)
        if not match_result.get("match", False):
            return {"valid": 0, "reason": "No match with description"}
        return {"valid": 1, "reason": None}
# Shared matcher instance: models are loaded once at import time, not per request.
matcher = ImageTextMatcher()


def run_validation(image_url, description):
    """Gradio API entry point: validate one (image URL, description) pair.

    Returns a JSON-serializable dict: {"valid": 0|1, "reason": str | None}.
    """
    verdict = matcher.validate_item(image_url, description)
    return verdict
# Gradio UI: two text inputs (image URL + item description), JSON verdict out.
_input_fields = [
    gr.Textbox(label="Image URL"),
    gr.Textbox(label="Item Description"),
]
iface = gr.Interface(
    fn=run_validation,
    inputs=_input_fields,
    outputs="json",
    title="Image Validator API",
    description="Returns 1 if image is valid (matches description & is safe), else 0 and reason.",
)
# Launch the Gradio web server only when executed as a script (not on import).
if __name__ == "__main__":
    iface.launch()