import os
import warnings
from dataclasses import dataclass
from typing import Dict

import cv2
import torch

# Silence noisy library warnings (torch/transformers emit many at import time)
warnings.filterwarnings('ignore')

try:
    from ultralytics import YOLO
    from transformers import pipeline
    from PIL import Image
except ImportError as e:
    # Fail fast: the pipeline cannot run without these packages
    print(f"Missing dependency: {e}")
    raise
@dataclass
class DetectionResult:
"""Simple detection result"""
nude_count: int = 0
gun_count: int = 0
knife_count: int = 0
fight_count: int = 0
is_safe: bool = True
def to_dict(self):
return {
'nude': self.nude_count,
'gun': self.gun_count,
'knife': self.knife_count,
'fight': self.fight_count,
'is_safe': self.is_safe
}
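# Example of the dict shape (hypothetical values):
#   DetectionResult(gun_count=1).to_dict()
#   -> {'nude': 0, 'gun': 1, 'knife': 0, 'fight': 0, 'is_safe': True}
# Note that is_safe is set by the pipeline, not derived from the counts here.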
class SmartSequentialModerator:
    """
    Smart sequential pipeline with BALANCED thresholds:
    1. NSFW check with a balanced threshold
    2. Only if the NSFW check is clean → check weapons/fights
    """
def __init__(self):
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Models
self.nsfw_classifier = None
self.weapon_model = None
# BALANCED Thresholds
self.nsfw_threshold = 0.75 # Balanced: not too high, not too low
self.nsfw_safe_threshold = 0.25 # If below this, definitely safe
self.gun_threshold = 0.7
self.knife_threshold = 0.65
self.fight_threshold = 0.75
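        # Note: nsfw_safe_threshold only affects logging in process_image();
        # any score at or below nsfw_threshold proceeds to the weapon stage.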
print(f"πŸš€ Smart Sequential Moderator initialized on {self.device}")
print(f"πŸ“‹ Pipeline: NSFW (0.75) β†’ Weapons/Fights")
self._setup_models()
def _setup_models(self):
"""Initialize models"""
try:
if torch.cuda.is_available():
torch.cuda.empty_cache()
# 1. NSFW Classifier (PRIORITY)
self._setup_nsfw()
# 2. Weapon/Fight Model
self._setup_weapons()
print("βœ… All models ready!")
except Exception as e:
print(f"❌ Setup error: {e}")
def _setup_nsfw(self):
"""Setup NSFW classifier"""
try:
print("πŸ”ž Loading NSFW classifier...")
device_id = 0 if self.device == 'cuda' else -1
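            # transformers pipelines take an integer device: -1 = CPU,
            # 0 = first CUDA device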
# Use the NSFW detection model
self.nsfw_classifier = pipeline(
"image-classification",
model="Falconsai/nsfw_image_detection",
device=device_id
)
print("βœ… NSFW classifier loaded")
except Exception as e:
print(f"⚠️ NSFW failed: {e}")
self.nsfw_classifier = None
    def _setup_weapons(self):
        """Setup weapon/fight model"""
        try:
            print("🔫 Loading weapon/fight model...")
            # Custom fine-tuned model path
            custom_path = "models/best_ft4.pt"
            if os.path.exists(custom_path):
                self.weapon_model = YOLO(custom_path)
                print("✅ Custom model loaded")
                # Show available classes
                if hasattr(self.weapon_model, 'names'):
                    classes = list(self.weapon_model.names.values())
                    print(f"   Classes: {classes}")
            else:
                # Fallback: a general COCO-pretrained model. COCO has a 'knife'
                # class but no gun or fight classes, so weapon coverage is
                # reduced without the custom weights.
                self.weapon_model = YOLO('yolo11n.pt')
                print("✅ General model loaded")
        except Exception as e:
            print(f"⚠️ Weapon model failed: {e}")
            self.weapon_model = None
def process_image(self, image) -> DetectionResult:
"""
STRICT SEQUENTIAL:
1. NSFW first (balanced threshold)
        2. If NSFW detected → STOP
        3. If clean → check weapons/fights
"""
result = DetectionResult()
try:
# Load image
if isinstance(image, str):
image = cv2.imread(image)
if image is None:
return result
print(f"\n{'=' * 40}")
print(f"πŸ“Έ Processing: {image.shape}")
# ========== STAGE 1: NSFW ==========
print("\nπŸ”ž Stage 1: NSFW Check")
nsfw_score = self._check_nsfw(image)
if nsfw_score > self.nsfw_threshold:
print(f" 🚨 NSFW DETECTED: {nsfw_score:.3f}")
print(f" β›” STOPPING - Returning NSFW only")
result.nude_count = 1
result.is_safe = False
return result # STOP HERE
elif nsfw_score < self.nsfw_safe_threshold:
print(f" βœ… Definitely safe: {nsfw_score:.3f}")
else:
print(f" ⚠️ Borderline safe: {nsfw_score:.3f} - Continuing checks")
# ========== STAGE 2: WEAPONS/FIGHTS ==========
print("\nπŸ”« Stage 2: Weapons & Fights")
if self.weapon_model:
detections = self._detect_threats(image)
result.gun_count = detections['guns']
result.knife_count = detections['knives']
result.fight_count = detections['fights']
if detections['total'] > 0:
print(f" Found: G:{detections['guns']} K:{detections['knives']} F:{detections['fights']}")
# Final safety
total = result.nude_count + result.gun_count + result.knife_count + result.fight_count
result.is_safe = (total == 0)
            print(f"\n📊 Result: N:{result.nude_count} G:{result.gun_count} K:{result.knife_count} F:{result.fight_count} Safe:{result.is_safe}")
            print(f"{'=' * 40}\n")
return result
except Exception as e:
print(f"❌ Error: {e}")
return result
def _check_nsfw(self, image) -> float:
"""
Check NSFW with proper scoring
Returns confidence score (0-1)
"""
try:
if not self.nsfw_classifier:
return 0.0
# Convert to RGB
rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(rgb_image)
# Run classifier
results = self.nsfw_classifier(pil_image)
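            # The image-classification pipeline returns a list of dicts such as
            # [{'label': 'nsfw', 'score': 0.97}, {'label': 'normal', ...}];
            # the matching below assumes this model's 'nsfw'/'normal' labels.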
# Get NSFW score
nsfw_score = 0.0
for result in results:
label = result['label'].lower()
score = result['score']
# Check for NSFW label
if 'nsfw' in label or 'unsafe' in label or 'explicit' in label:
nsfw_score = max(nsfw_score, score)
print(f" {label}: {score:.3f}")
return nsfw_score
except Exception as e:
print(f" ⚠️ NSFW error: {e}")
return 0.0
def _detect_threats(self, image) -> Dict[str, int]:
"""Detect weapons and fights"""
counts = {
'guns': 0,
'knives': 0,
'fights': 0,
'total': 0
}
try:
# Run detection with low base threshold
results = self.weapon_model(
image,
conf=0.4, # Low base threshold
device=self.device,
verbose=False
)
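            # Ultralytics returns one Results object per image; each exposes
            # .boxes with per-detection class ids (box.cls) and confidences
            # (box.conf), which are filtered per category below.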
for result in results:
if result.boxes is None:
continue
for box in result.boxes:
class_id = int(box.cls[0])
confidence = float(box.conf[0])
if hasattr(result, 'names'):
class_name = result.names[class_id].lower()
else:
continue
# Check each category with proper threshold
if self._is_gun(class_name) and confidence > self.gun_threshold:
counts['guns'] += 1
elif self._is_knife(class_name) and confidence > self.knife_threshold:
counts['knives'] += 1
elif self._is_fight(class_name) and confidence > self.fight_threshold:
counts['fights'] += 1
counts['total'] = counts['guns'] + counts['knives'] + counts['fights']
return counts
except Exception as e:
print(f" ⚠️ Detection error: {e}")
return counts
    def _is_gun(self, name: str) -> bool:
        # 'súng' is Vietnamese for "gun"; substring matching is deliberately
        # loose, so 'gun' also catches e.g. 'handgun' or 'shotgun'
        gun_words = ['gun', 'pistol', 'rifle', 'firearm', 'súng']
        return any(w in name for w in gun_words)
    def _is_knife(self, name: str) -> bool:
        # 'dao' is Vietnamese for "knife"
        knife_words = ['knife', 'dao', 'blade', 'sword']
        return any(w in name for w in knife_words)
    def _is_fight(self, name: str) -> bool:
        fight_words = ['fight', 'fighting', 'combat', 'violence']
        return any(w in name for w in fight_words)
def process_video(self, video_path: str) -> Dict:
"""
Process video with SMART frame skipping
Auto-adjusts based on video duration
"""
total = DetectionResult()
try:
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
return total.to_dict()
# Get video info
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
duration = total_frames / fps if fps > 0 else 0
# SMART frame skip based on duration
if duration <= 10: # Short video
frame_skip = 5 # Check every 5th frame
max_frames = 100
elif duration <= 30:
frame_skip = 10 # Check every 10th frame
max_frames = 150
elif duration <= 60:
frame_skip = 15
max_frames = 200
else: # Long video
frame_skip = 30
max_frames = 300
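            # Illustrative sampling math: a 10 s clip at 30 fps has ~300
            # frames; checking every 5th frame analyses ~60 of them, well
            # under the max_frames cap.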
print(f"\nπŸ“Ή Video: {duration:.1f}s, {total_frames} frames")
print(f" Auto settings: skip={frame_skip}, max={max_frames}")
frame_count = 0
processed = 0
nsfw_strikes = 0 # Count NSFW detections
while True:
ret, frame = cap.read()
if not ret:
break
frame_count += 1
# Skip frames
if frame_count % frame_skip != 0:
continue
# Max frame limit
if processed >= max_frames:
break
processed += 1
# Process frame
result = self.process_image(frame)
# Accumulate
total.nude_count += result.nude_count
total.gun_count += result.gun_count
total.knife_count += result.knife_count
total.fight_count += result.fight_count
# Early stop on multiple NSFW
if result.nude_count > 0:
nsfw_strikes += 1
if nsfw_strikes >= 3: # Stop after 3 NSFW frames
print(f"β›” Early stop: {nsfw_strikes} NSFW frames")
break
# Progress
if processed % 50 == 0:
print(f" Processed {processed} frames...")
cap.release()
# Final safety
total_threats = total.nude_count + total.gun_count + total.knife_count + total.fight_count
total.is_safe = (total_threats == 0)
print(f"\nπŸ“Š Video complete: {processed} frames analyzed")
print(f" Total: N:{total.nude_count} G:{total.gun_count} K:{total.knife_count} F:{total.fight_count}")
return total.to_dict()
except Exception as e:
print(f"❌ Video error: {e}")
return total.to_dict()
def main():
"""Test the moderator"""
moderator = SmartSequentialModerator()
print("\n" + "=" * 50)
print("🎯 SMART SEQUENTIAL MODERATOR")
print("=" * 50)
print("β€’ Balanced NSFW threshold: 0.75")
print("β€’ Auto frame skipping for videos")
print("β€’ Simple output: counts + boolean")
print("=" * 50)
# Test
test_image = "test.jpg"
if os.path.exists(test_image):
result = moderator.process_image(test_image)
print(f"\nResult: {result.to_dict()}")
if __name__ == "__main__":
main()