Spaces:
Sleeping
Sleeping
| import cv2 | |
| import tempfile | |
| import requests | |
| import os | |
| from PIL import Image | |
| from transformers import pipeline | |
| import torch | |
# 🔥 SPEED BOOST SETTINGS
# Inference only: switch autograd off globally so no gradient graphs are
# recorded, and cap intra-op CPU parallelism at two threads.
torch.set_grad_enabled(mode=False)
torch.set_num_threads(2)

# 🔥 Faster NSFW model
# Built once when the module is imported; check_image_nsfw() reuses it.
classifier = pipeline(
    task="image-classification",
    model="Falconsai/nsfw_image_detection",
)
# -----------------------------
# Download with retry + headers (FIX CATBOX)
# -----------------------------
def download_file(url, retries=3, timeout=10):
    """Download ``url`` into a named temporary file and return its path.

    Browser-like headers are sent with every attempt. The request is tried up
    to ``retries`` times; an unexpected HTTP status or any ``requests`` error
    moves on to the next attempt.

    Args:
        url: Address of the file to fetch.
        retries: Maximum number of attempts (default 3).
        timeout: Timeout in seconds passed to ``requests.get`` (default 10).

    Returns:
        Path of a temporary file created with ``delete=False`` that holds the
        response body. The caller is responsible for deleting it.

    Raises:
        Exception: if every attempt fails (chained to the last error seen).
    """
    headers = {
        "User-Agent": "Mozilla/5.0",
        "Accept": "*/*",
        "Connection": "keep-alive",
        "Range": "bytes=0-",
    }
    last_error = None
    for _ in range(retries):
        try:
            # ``with`` closes the streamed response (releasing its connection)
            # on every path, including the retry ``continue`` below.
            with requests.get(
                url, headers=headers, stream=True, timeout=timeout
            ) as response:
                # Because we send a Range header, a server may legitimately
                # answer 206 Partial Content; for "bytes=0-" that range is the
                # whole body, so it is as good as a 200.
                if response.status_code not in (200, 206):
                    last_error = requests.exceptions.HTTPError(
                        f"HTTP {response.status_code}", response=response
                    )
                    continue
                return _save_stream(response)
        except requests.exceptions.RequestException as err:
            last_error = err
    raise Exception("Failed to fetch file") from last_error


def _save_stream(response):
    """Write a streamed response body to a new temp file and return its path."""
    tmp = tempfile.NamedTemporaryFile(delete=False)
    try:
        with tmp:
            for chunk in response.iter_content(1024 * 1024):
                if chunk:
                    tmp.write(chunk)
    except BaseException:
        # Don't leave a partial download on disk; re-raise the original error.
        try:
            os.remove(tmp.name)
        except OSError:
            pass
        raise
    return tmp.name
# -----------------------------
# Video duration
# -----------------------------
def get_video_duration(video_path):
    """Return the length of ``video_path`` in seconds (frame count / FPS).

    Both numbers are read from OpenCV's capture properties. If the reported
    FPS is not strictly positive, 0 is returned instead of dividing.
    """
    capture = cv2.VideoCapture(video_path)
    frame_rate = capture.get(cv2.CAP_PROP_FPS)
    frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT)
    capture.release()
    # Written as ``not > 0`` (rather than ``<= 0``) so a NaN rate also maps to 0.
    if not frame_rate > 0:
        return 0
    return frame_total / frame_rate
# -----------------------------
# Extract frame
# -----------------------------
def extract_frame(video_path, second):
    """Grab the frame ``second`` seconds into ``video_path`` and save it as JPEG.

    The frame index is ``int(fps * second)``, using the FPS OpenCV reports
    for the file.

    Args:
        video_path: Path of the video file.
        second: Timestamp to sample, in seconds.

    Returns:
        Path of a new temporary ``.jpg`` file (created with ``delete=False``;
        the caller must delete it), or None if the frame could not be read
        or written.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(fps * second))
        success, frame = cap.read()
    finally:
        cap.release()
    if not success or frame is None:
        return None

    # Only reserve a unique name here and close the handle immediately:
    # otherwise it leaks, and on Windows an open NamedTemporaryFile cannot
    # be reopened by name, which cv2.imwrite needs to do.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        path = tmp.name
    if not cv2.imwrite(path, frame):
        # Encoding/writing failed: don't hand back an empty, unreadable file.
        os.remove(path)
        return None
    return path
# -----------------------------
# FAST frame selection
# -----------------------------
def get_frame_times(duration, size_mb=None):
    """Pick the timestamps (seconds) to sample from a video of ``duration`` seconds.

    At most two frames are chosen, to keep the check fast:

    * ``duration <= 1`` (including 0, which get_video_duration returns when
      the FPS is unavailable) -> ``[0]``
    * ``duration <= 3``  -> ``[1]``
    * ``duration <= 10`` -> ``[2]``
    * otherwise          -> ``[3, 8]``

    Args:
        duration: Video length in seconds.
        size_mb: File size in MiB. Accepted so callers can pass it, but it
            does not currently affect the selection.

    Returns:
        List of sample times in seconds.
    """
    if duration <= 1:
        # Seeking to 1 s in a clip of 1 s or less lands at/after the last
        # frame, so nothing would be checked and the video would pass
        # unchecked. Frame 0 exists whenever the video has any frames.
        return [0]
    if duration <= 3:
        return [1]
    if duration <= 10:
        return [2]
    return [3, 8]  # max 2 frames (FAST)
# -----------------------------
# Image NSFW check (OPTIMIZED)
# -----------------------------
def check_image_nsfw(image_path, threshold=0.5):
    """Return True if the module-level ``classifier`` flags the image as NSFW.

    The image is converted to RGB and classified. It counts as NSFW when any
    prediction has label ``"nsfw"`` with a score strictly greater than
    ``threshold``.

    Args:
        image_path: Path of an image file readable by PIL.
        threshold: Exclusive lower bound on the ``"nsfw"`` score that counts
            as a hit (default 0.5).
    """
    # The context manager closes the file handle even if decoding fails;
    # convert() returns an independent, fully loaded copy.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    predictions = classifier(rgb)
    return any(
        p["label"] == "nsfw" and p["score"] > threshold for p in predictions
    )
# -----------------------------
# Video NSFW check
# -----------------------------
def check_video_nsfw(video_path):
    """Return True if any sampled frame of ``video_path`` is classified NSFW.

    Sample times come from ``get_frame_times`` for the video's duration. Each
    frame is written to a temporary JPEG by ``extract_frame``, classified
    with ``check_image_nsfw``, and then deleted. Checking stops at the first
    NSFW frame.

    Raises:
        FileNotFoundError: if ``video_path`` is not an existing file.
    """
    # Fail loudly on a bad path instead of going on to sample a video that
    # isn't there (presumably OpenCV would just yield no frames -> "safe").
    if not os.path.isfile(video_path):
        raise FileNotFoundError(f"Video file not found: {video_path}")

    duration = get_video_duration(video_path)
    for t in get_frame_times(duration):
        frame_path = extract_frame(video_path, t)
        if frame_path is None:
            continue
        try:
            if check_image_nsfw(frame_path):
                return True  # 🚨 return immediately if ANY frame is NSFW
        finally:
            # extract_frame creates the file with delete=False; clean it up.
            try:
                os.remove(frame_path)
            except OSError:
                pass
    # NOTE(review): also reached when no frame could be extracted at all, so
    # an unreadable video is reported as safe — confirm fail-open is intended.
    return False