import os
import tempfile
from io import BytesIO

import torch
import requests
from PIL import Image

import clip
from nudenet import NudeDetector
import gradio as gr


class ImageTextMatcher:
    """Validate that an image is safe-for-work and matches a text description.

    Uses OpenAI CLIP (ViT-B/32) for image/text similarity and NudeNet for
    NSFW detection. Intended to be instantiated once per process (model
    loading is expensive).
    """

    def __init__(self, nsfw_threshold=0.5):
        """
        Args:
            nsfw_threshold: mean NudeNet detection score above which an
                image is considered NSFW.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.clip_model, self.preprocess = clip.load("ViT-B/32", device=self.device)
        self.nsfw_detector = NudeDetector()
        self.nsfw_threshold = nsfw_threshold

    def load_image(self, image_path_or_url):
        """Load a PIL image from a local path or an http(s) URL.

        Returns:
            The opened PIL Image, or None if loading fails for any reason.
        """
        try:
            if image_path_or_url.startswith(('http://', 'https://')):
                # timeout= prevents a dead host from hanging the request
                # forever; raise_for_status() rejects 4xx/5xx responses whose
                # bodies are not image data.
                response = requests.get(image_path_or_url, timeout=15)
                response.raise_for_status()
                img = Image.open(BytesIO(response.content))
            else:
                img = Image.open(image_path_or_url)
            return img
        except Exception:
            # Top-level boundary: callers treat None as "failed to load".
            return None

    def check_nsfw_content(self, image, image_path=None):
        """Run the NudeNet detector on the image.

        Args:
            image: PIL Image (or None for a failed load).
            image_path: local path of the image, if it exists on disk.

        Returns:
            {"is_nsfw": bool, "nsfw_score": float} on success,
            {"is_nsfw": False, "error": str} on failure.
        """
        if image is None:
            return {"is_nsfw": False, "error": "Failed to load image"}
        temp_path = None
        try:
            # NudeDetector.detect() wants a file path. For in-memory/remote
            # images, write a *unique* temp file — a fixed filename would
            # race when two requests are processed concurrently.
            if image_path is None or image_path.startswith(('http://', 'https://')):
                fd, temp_path = tempfile.mkstemp(suffix=".jpg")
                os.close(fd)
                # JPEG cannot store alpha/palette modes; convert first so
                # RGBA/P images don't raise and get misreported as safe.
                image.convert("RGB").save(temp_path)
                image_path = temp_path

            detections = self.nsfw_detector.detect(image_path)
            if not detections:
                nsfw_score = 0.0
                is_nsfw = False
            else:
                # NOTE(review): detections may include benign classes (e.g.
                # faces); averaging every score is a rough heuristic —
                # confirm against the NudeNet label set whether filtering to
                # explicit classes would be more accurate.
                scores = [d['score'] for d in detections]
                nsfw_score = sum(scores) / len(scores)
                is_nsfw = nsfw_score > self.nsfw_threshold
            return {
                "is_nsfw": is_nsfw,
                "nsfw_score": nsfw_score
            }
        except Exception as e:
            return {"is_nsfw": False, "error": str(e)}
        finally:
            # Always clean up the temp file, even when the detector raises.
            if temp_path is not None and os.path.exists(temp_path):
                os.remove(temp_path)

    def check_image_text_match(self, image, description, threshold=0.25):
        """Score CLIP cosine similarity between the image and description.

        Args:
            image: PIL Image (or None for a failed load).
            description: free-text description to compare against.
            threshold: similarity above which the pair counts as a match.

        Returns:
            {"match": bool, "similarity_score": float}, plus "error" when
            the image failed to load.
        """
        if image is None:
            return {"match": False, "similarity_score": 0.0,
                    "error": "Failed to load image"}
        # Defensive: ensure RGB before preprocessing (idempotent for images
        # that are already RGB).
        image_input = self.preprocess(image.convert("RGB")).unsqueeze(0).to(self.device)
        text_inputs = clip.tokenize([description]).to(self.device)
        with torch.no_grad():
            image_features = self.clip_model.encode_image(image_input)
            text_features = self.clip_model.encode_text(text_inputs)
            # L2-normalize so the dot product below is cosine similarity.
            image_features = image_features / image_features.norm(dim=1, keepdim=True)
            text_features = text_features / text_features.norm(dim=1, keepdim=True)
            similarity = torch.matmul(image_features, text_features.T).item()
        return {
            "match": similarity > threshold,
            "similarity_score": float(similarity)
        }

    def validate_item(self, image_url, description, match_threshold=0.25):
        """Full validation pipeline: load -> NSFW check -> CLIP match.

        Returns:
            {"valid": 1, "reason": None} when the image loads, is safe, and
            matches the description; otherwise {"valid": 0, "reason": str}.
        """
        image = self.load_image(image_url)
        if image is None:
            return {"valid": 0, "reason": "Failed to load image"}
        nsfw_result = self.check_nsfw_content(image, image_url)
        if nsfw_result.get("is_nsfw", False):
            return {"valid": 0, "reason": "NSFW content detected"}
        match_result = self.check_image_text_match(image, description, match_threshold)
        if not match_result.get("match", False):
            return {"valid": 0, "reason": "No match with description"}
        return {"valid": 1, "reason": None}


# 🔁 Load once globally — model weights load a single time per process.
matcher = ImageTextMatcher()


# 🔍 Function for Gradio API
def run_validation(image_url, description):
    """Gradio entry point: validate one (image_url, description) pair."""
    return matcher.validate_item(image_url, description)


# 🖥️ Gradio interface
iface = gr.Interface(
    fn=run_validation,
    inputs=[
        gr.Textbox(label="Image URL"),
        gr.Textbox(label="Item Description")
    ],
    outputs="json",
    title="Image Validator API",
    description="Returns 1 if image is valid (matches description & is safe), else 0 and reason."
)

# 🌐 Run
if __name__ == "__main__":
    iface.launch()