|
|
import cv2 |
|
|
import torch |
|
|
import os |
|
|
import warnings |
|
|
from typing import Dict |
|
|
from dataclasses import dataclass |
|
|
|
|
|
warnings.filterwarnings('ignore')

# Import each optional heavy dependency independently: a missing package only
# disables the feature that needs it instead of leaving every later name
# undefined. A missing name is bound to None so the module namespace is always
# well-defined; model setup code reports the failure when it tries to use it.
try:
    from ultralytics import YOLO
except ImportError as e:
    print(f"Missing dependency: {e}")
    YOLO = None

try:
    from transformers import pipeline
except ImportError as e:
    print(f"Missing dependency: {e}")
    pipeline = None

try:
    from PIL import Image
except ImportError as e:
    print(f"Missing dependency: {e}")
    Image = None
|
|
|
|
|
|
|
|
@dataclass
class DetectionResult:
    """Moderation outcome for one image, or the summed outcome for a video.

    Each ``*_count`` field counts findings in one category (nudity, guns,
    knives, fights). ``is_safe`` is the overall verdict and defaults to True.
    """

    nude_count: int = 0
    gun_count: int = 0
    knife_count: int = 0
    fight_count: int = 0
    is_safe: bool = True

    def to_dict(self):
        """Return the result as a plain dict.

        Keys, in order: 'nude', 'gun', 'knife', 'fight', 'is_safe'.
        """
        return dict(
            nude=self.nude_count,
            gun=self.gun_count,
            knife=self.knife_count,
            fight=self.fight_count,
            is_safe=self.is_safe,
        )
|
|
|
|
|
|
|
|
class SmartSequentialModerator:
    """
    Smart Sequential Pipeline with balanced thresholds:

    1. NSFW check with a BALANCED threshold.
    2. Only if NSFW is clean → check weapons/fights.

    Attributes set in ``__init__``:
        device: 'cuda' when torch reports a GPU, otherwise 'cpu'.
        nsfw_classifier: transformers image-classification pipeline, or None
            if it failed to load.
        weapon_model: ultralytics YOLO model, or None if it failed to load.
        nsfw_threshold: NSFW score above which an image is flagged and the
            weapon/fight stage is skipped.
        nsfw_safe_threshold: below this score the NSFW stage logs
            "Definitely safe"; it only affects logging, not the verdict.
        gun_threshold / knife_threshold / fight_threshold: minimum detector
            confidence for a box of that category to be counted.
    """

    # Lower-cased substrings used to map detector class names to categories.
    # 'súng' (gun) and 'dao' (knife) are Vietnamese words, presumably matching
    # the custom model's class names.
    _GUN_WORDS = ('gun', 'pistol', 'rifle', 'firearm', 'súng')
    _KNIFE_WORDS = ('knife', 'dao', 'blade', 'sword')
    _FIGHT_WORDS = ('fight', 'combat', 'violence')

    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.nsfw_classifier = None
        self.weapon_model = None

        self.nsfw_threshold = 0.75
        self.nsfw_safe_threshold = 0.25
        self.gun_threshold = 0.7
        self.knife_threshold = 0.65
        self.fight_threshold = 0.75

        print(f"🚀 Smart Sequential Moderator initialized on {self.device}")
        # Read the threshold from the attribute so the log can't drift from the config.
        print(f"📊 Pipeline: NSFW ({self.nsfw_threshold}) → Weapons/Fights")

        self._setup_models()

    def _setup_models(self):
        """Load the NSFW classifier and the weapon/fight detector (best effort).

        Each loader handles its own failure by leaving its model as None; this
        method reports whether both models ended up available.
        """
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            self._setup_nsfw()
            self._setup_weapons()

            # The loaders swallow their own errors, so check what actually loaded
            # instead of unconditionally claiming success.
            if self.nsfw_classifier is not None and self.weapon_model is not None:
                print("✅ All models ready!")
            else:
                print("⚠️ Some models failed to load - see messages above")

        except Exception as e:
            print(f"❌ Setup error: {e}")

    def _setup_nsfw(self):
        """Load the Falconsai/nsfw_image_detection image-classification pipeline.

        On failure the error is printed and ``self.nsfw_classifier`` is None.
        """
        try:
            print("🔞 Loading NSFW classifier...")

            # transformers pipelines take a device index: 0 = first GPU, -1 = CPU.
            device_id = 0 if self.device == 'cuda' else -1

            self.nsfw_classifier = pipeline(
                "image-classification",
                model="Falconsai/nsfw_image_detection",
                device=device_id,
            )
            print("✅ NSFW classifier loaded")

        except Exception as e:
            print(f"⚠️ NSFW failed: {e}")
            self.nsfw_classifier = None

    def _setup_weapons(self):
        """Load the weapon/fight YOLO model.

        Uses ``models/best_ft4.pt`` when that file exists. Otherwise falls back
        to the general ``yolo11n.pt`` weights. NOTE(review): the general weights
        presumably lack gun/fight classes - confirm. On failure the error is
        printed and ``self.weapon_model`` is None.
        """
        try:
            print("🔫 Loading weapon/fight model...")

            custom_path = "models/best_ft4.pt"
            if os.path.exists(custom_path):
                self.weapon_model = YOLO(custom_path)
                print("✅ Custom model loaded")

                if hasattr(self.weapon_model, 'names'):
                    classes = list(self.weapon_model.names.values())
                    print(f" Classes: {classes}")
            else:
                self.weapon_model = YOLO('yolo11n.pt')
                print("✅ General model loaded")

        except Exception as e:
            print(f"⚠️ Weapon model failed: {e}")
            self.weapon_model = None

    @staticmethod
    def _update_is_safe(result) -> None:
        """Set ``result.is_safe`` to True iff every category count is zero."""
        threats = result.nude_count + result.gun_count + result.knife_count + result.fight_count
        result.is_safe = threats == 0

    def process_image(self, image) -> "DetectionResult":
        """Moderate one image with a strict sequential pipeline.

        1. NSFW first (balanced threshold).
        2. If NSFW is detected → STOP and report only nudity.
        3. Otherwise → run the weapon/fight detector.

        Args:
            image: a file path (``str`` or ``os.PathLike``, read with
                ``cv2.imread``) or an already-decoded OpenCV image array
                (BGR, BGRA or grayscale).

        Returns:
            A DetectionResult whose ``is_safe`` always agrees with its counts.
            NOTE(review): an unreadable image, or an error while processing,
            yields a result with ``is_safe=True`` unless something was already
            counted (fail-open). Confirm this is intended for moderation.
        """
        result = DetectionResult()

        try:
            if isinstance(image, (str, os.PathLike)):
                image = cv2.imread(os.fspath(image))
            if image is None:
                return result

            print(f"\n{'=' * 40}")
            print(f"📸 Processing: {image.shape}")

            # Stage 1: NSFW
            print("\n🔍 Stage 1: NSFW Check")
            nsfw_score = self._check_nsfw(image)

            if nsfw_score > self.nsfw_threshold:
                print(f" 🚨 NSFW DETECTED: {nsfw_score:.3f}")
                print(" ⛔ STOPPING - Returning NSFW only")
                result.nude_count = 1
                result.is_safe = False
                return result

            if nsfw_score < self.nsfw_safe_threshold:
                print(f" ✅ Definitely safe: {nsfw_score:.3f}")
            else:
                print(f" ⚠️ Borderline safe: {nsfw_score:.3f} - Continuing checks")

            # Stage 2: weapons & fights
            print("\n🔫 Stage 2: Weapons & Fights")
            if self.weapon_model is not None:
                detections = self._detect_threats(image)
                result.gun_count = detections['guns']
                result.knife_count = detections['knives']
                result.fight_count = detections['fights']

                if detections['total'] > 0:
                    print(f" Found: G:{detections['guns']} K:{detections['knives']} F:{detections['fights']}")

            self._update_is_safe(result)

            print(
                f"\n📊 Result: N:{result.nude_count} G:{result.gun_count} K:{result.knife_count} F:{result.fight_count} Safe:{result.is_safe}")
            print(f"{'=' * 40}\n")

            return result

        except Exception as e:
            print(f"❌ Error: {e}")
            # Counts may already be set; keep the verdict consistent with them.
            self._update_is_safe(result)
            return result

    def _check_nsfw(self, image) -> float:
        """Return the highest NSFW-like label score (0-1) for *image*.

        Labels containing 'nsfw', 'unsafe' or 'explicit' (case-insensitive)
        count as NSFW. NOTE(review): when the classifier is unavailable or
        errors, this returns 0.0, which the caller treats as safe (fail-open).
        """
        if self.nsfw_classifier is None:
            return 0.0

        try:
            # Convert OpenCV channel order to RGB for PIL. Grayscale and BGRA
            # inputs would make a plain BGR2RGB conversion raise.
            channels = 1 if image.ndim == 2 else image.shape[2]
            if channels == 1:
                rgb_image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
            elif channels == 4:
                rgb_image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
            else:
                rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(rgb_image)

            predictions = self.nsfw_classifier(pil_image)

            nsfw_score = 0.0
            for prediction in predictions:
                label = prediction['label'].lower()
                score = prediction['score']
                if any(key in label for key in ('nsfw', 'unsafe', 'explicit')):
                    nsfw_score = max(nsfw_score, score)
                print(f" {label}: {score:.3f}")

            return nsfw_score

        except Exception as e:
            print(f" ⚠️ NSFW error: {e}")
            return 0.0

    def _detect_threats(self, image) -> Dict[str, int]:
        """Count guns, knives and fights found by ``self.weapon_model``.

        The model is queried with a 0.4 confidence pre-filter. Each box is then
        kept only if it beats its category's own threshold. Returns a dict with
        keys 'guns', 'knives', 'fights' and 'total'. 'total' always equals the
        sum of the other three, even if detection fails part-way.
        """
        counts = {
            'guns': 0,
            'knives': 0,
            'fights': 0,
            'total': 0,
        }

        try:
            results = self.weapon_model(
                image,
                conf=0.4,
                device=self.device,
                verbose=False,
            )

            for result in results:
                if result.boxes is None:
                    continue
                names = getattr(result, 'names', None)
                if names is None:
                    continue

                for box in result.boxes:
                    class_id = int(box.cls[0])
                    confidence = float(box.conf[0])
                    class_name = names[class_id].lower()

                    if self._is_gun(class_name) and confidence > self.gun_threshold:
                        counts['guns'] += 1
                    elif self._is_knife(class_name) and confidence > self.knife_threshold:
                        counts['knives'] += 1
                    elif self._is_fight(class_name) and confidence > self.fight_threshold:
                        counts['fights'] += 1

        except Exception as e:
            print(f" ⚠️ Detection error: {e}")

        counts['total'] = counts['guns'] + counts['knives'] + counts['fights']
        return counts

    def _is_gun(self, name: str) -> bool:
        """Return True if the lower-cased class *name* contains a gun keyword."""
        return any(word in name for word in self._GUN_WORDS)

    def _is_knife(self, name: str) -> bool:
        """Return True if the lower-cased class *name* contains a knife keyword."""
        return any(word in name for word in self._KNIFE_WORDS)

    def _is_fight(self, name: str) -> bool:
        """Return True if the lower-cased class *name* contains a fight keyword."""
        return any(word in name for word in self._FIGHT_WORDS)

    @staticmethod
    def _sampling_plan(duration: float):
        """Return ``(frame_skip, max_frames)`` for a video *duration* in seconds.

        Longer videos are sampled more sparsely but get a larger frame budget.
        """
        if duration <= 10:
            return 5, 100
        if duration <= 30:
            return 10, 150
        if duration <= 60:
            return 15, 200
        return 30, 300

    def process_video(self, video_path: str) -> Dict:
        """Moderate a video by running process_image on sampled frames.

        Frame sampling is chosen from the duration (see ``_sampling_plan``).
        Scanning stops early once three sampled frames are flagged as NSFW.

        Args:
            video_path: anything accepted by ``cv2.VideoCapture``.

        Returns:
            ``DetectionResult.to_dict()`` of the per-frame counts summed over the
            analysed frames. ``is_safe`` agrees with those counts. NOTE(review):
            a video that cannot be opened returns zero counts with
            ``is_safe=True`` (fail-open).
        """
        total = DetectionResult()
        cap = None

        try:
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                return total.to_dict()

            fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # A non-positive FPS means the duration is unknown; treat it as 0.
            duration = total_frames / fps if fps > 0 else 0

            frame_skip, max_frames = self._sampling_plan(duration)

            print(f"\n📹 Video: {duration:.1f}s, {total_frames} frames")
            print(f" Auto settings: skip={frame_skip}, max={max_frames}")

            frame_count = 0
            processed = 0
            nsfw_strikes = 0

            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                frame_count += 1
                # Analyse frames frame_skip, 2*frame_skip, ... (1-based count).
                if frame_count % frame_skip != 0:
                    continue
                if processed >= max_frames:
                    break
                processed += 1

                result = self.process_image(frame)
                total.nude_count += result.nude_count
                total.gun_count += result.gun_count
                total.knife_count += result.knife_count
                total.fight_count += result.fight_count

                if result.nude_count > 0:
                    nsfw_strikes += 1
                    if nsfw_strikes >= 3:
                        print(f"⛔ Early stop: {nsfw_strikes} NSFW frames")
                        break

                if processed % 50 == 0:
                    print(f" Processed {processed} frames...")

            self._update_is_safe(total)

            print(f"\n📊 Video complete: {processed} frames analyzed")
            print(f" Total: N:{total.nude_count} G:{total.gun_count} K:{total.knife_count} F:{total.fight_count}")

            return total.to_dict()

        except Exception as e:
            print(f"❌ Video error: {e}")
            # Partial counts may exist; don't report them alongside is_safe=True.
            self._update_is_safe(total)
            return total.to_dict()

        finally:
            # Release the capture on every path, including early returns and errors.
            if cap is not None:
                cap.release()
|
|
|
|
|
|
|
|
def main():
    """Build a moderator, print its configuration and moderate ./test.jpg if it exists."""
    moderator = SmartSequentialModerator()

    print("\n" + "=" * 50)
    print("🎯 SMART SEQUENTIAL MODERATOR")
    print("=" * 50)
    # Report the configured value rather than a hard-coded copy of it.
    print(f"• Balanced NSFW threshold: {moderator.nsfw_threshold}")
    print("• Auto frame skipping for videos")
    print("• Simple output: counts + boolean")
    print("=" * 50)

    test_image = "test.jpg"
    if os.path.exists(test_image):
        result = moderator.process_image(test_image)
        print(f"\nResult: {result.to_dict()}")
    else:
        print(f"\nNo '{test_image}' found - skipping sample run")


if __name__ == "__main__":
    main()