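"""Image Validator API.

Loads an image from a URL or local path, screens it for NSFW content with
NudeNet, checks that it matches a text description with OpenAI CLIP, and
exposes the pipeline as a Gradio app.
"""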
import os
from io import BytesIO

import requests
import torch
from PIL import Image

import clip  # OpenAI CLIP: pip install git+https://github.com/openai/CLIP.git
from nudenet import NudeDetector
import gradio as gr

class ImageTextMatcher:
    """Validates an image: NSFW screening plus CLIP image-text matching."""

    def __init__(self, nsfw_threshold=0.5):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # CLIP ViT-B/32 scores image-text similarity
        self.clip_model, self.preprocess = clip.load("ViT-B/32", device=self.device)
        # NudeNet flags NSFW content
        self.nsfw_detector = NudeDetector()
        self.nsfw_threshold = nsfw_threshold

    def load_image(self, image_path_or_url):
        """Load an image from a URL or a local path; return None on failure."""
        try:
            if image_path_or_url.startswith(('http://', 'https://')):
                response = requests.get(image_path_or_url, timeout=10)
                response.raise_for_status()
                img = Image.open(BytesIO(response.content))
            else:
                img = Image.open(image_path_or_url)
            # Normalize to RGB so the image saves cleanly as JPEG and feeds CLIP
            return img.convert("RGB")
        except Exception:
            return None

    def check_nsfw_content(self, image, image_path=None):
        """Screen the image with NudeNet; flag it when the mean detection
        score exceeds the configured threshold."""
        if image is None:
            return {"is_nsfw": False, "error": "Failed to load image"}
        temp_path = None
        try:
            # NudeNet reads from disk, so persist remote or in-memory images first
            if image_path is None or image_path.startswith(('http://', 'https://')):
                temp_path = "temp_image.jpg"
                image.save(temp_path)
                image_path = temp_path
            detections = self.nsfw_detector.detect(image_path)
            if not detections:
                nsfw_score = 0.0
                is_nsfw = False
            else:
                # Average the scores of all detected classes
                nsfw_scores = [d['score'] for d in detections]
                nsfw_score = sum(nsfw_scores) / len(nsfw_scores)
                is_nsfw = nsfw_score > self.nsfw_threshold
            return {
                "is_nsfw": is_nsfw,
                "nsfw_score": nsfw_score
            }
        except Exception as e:
            return {"is_nsfw": False, "error": str(e)}
        finally:
            # Remove the temporary file if one was written
            if temp_path is not None and os.path.exists(temp_path):
                os.remove(temp_path)

    def check_image_text_match(self, image, description, threshold=0.25):
        """Score image-description agreement with CLIP cosine similarity."""
        if image is None:
            return {"match": False, "similarity_score": 0.0, "error": "Failed to load image"}
        image_input = self.preprocess(image).unsqueeze(0).to(self.device)
        text_inputs = clip.tokenize([description]).to(self.device)
        with torch.no_grad():
            image_features = self.clip_model.encode_image(image_input)
            text_features = self.clip_model.encode_text(text_inputs)
            # L2-normalize so the dot product equals cosine similarity
            image_features = image_features / image_features.norm(dim=1, keepdim=True)
            text_features = text_features / text_features.norm(dim=1, keepdim=True)
            similarity = torch.matmul(image_features, text_features.T).item()
        return {
            "match": similarity > threshold,
            "similarity_score": float(similarity)
        }

    def validate_item(self, image_url, description, match_threshold=0.25):
        """Full pipeline: load, NSFW screen, CLIP match. Returns valid 1/0 plus a reason."""
        image = self.load_image(image_url)
        if image is None:
            return {"valid": 0, "reason": "Failed to load image"}
        nsfw_result = self.check_nsfw_content(image, image_url)
        if nsfw_result.get("is_nsfw", False):
            return {"valid": 0, "reason": "NSFW content detected"}
        match_result = self.check_image_text_match(image, description, match_threshold)
        if not match_result.get("match", False):
            return {"valid": 0, "reason": "No match with description"}
        return {"valid": 1, "reason": None}
# πŸ” Load once globally
matcher = ImageTextMatcher()
# πŸ” Function for Gradio API
def run_validation(image_url, description):
return matcher.validate_item(image_url, description)


# 🖥️ Gradio interface
iface = gr.Interface(
    fn=run_validation,
    inputs=[
        gr.Textbox(label="Image URL"),
        gr.Textbox(label="Item Description")
    ],
    outputs="json",
    title="Image Validator API",
    description="Returns 1 if the image is valid (matches the description and is safe), else 0 with a reason."
)
# 🌐 Run
if __name__ == "__main__":
iface.launch()
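
# Example remote call (a minimal sketch, assuming this file runs as a Hugging
# Face Space reachable as "usef143/Auto-enlister"; "/predict" is the default
# api_name Gradio assigns to a gr.Interface endpoint):
#
#   from gradio_client import Client
#   client = Client("usef143/Auto-enlister")
#   result = client.predict(
#       "https://example.com/item.jpg",  # hypothetical image URL
#       "red leather handbag",           # hypothetical description
#       api_name="/predict",
#   )
#   print(result)  # e.g. {"valid": 1, "reason": None}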