Haiss123 commited on
Commit
70b0f71
·
verified ·
1 Parent(s): 36baa61

Upload sequential_moderation.py

Browse files
Files changed (1) hide show
  1. sequential_moderation.py +393 -0
sequential_moderation.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import os
4
+ import warnings
5
+ from typing import Dict
6
+ from dataclasses import dataclass
7
+
8
+ warnings.filterwarnings('ignore')
9
+
10
+ try:
11
+ from ultralytics import YOLO
12
+ from transformers import pipeline
13
+ from PIL import Image
14
+ except ImportError as e:
15
+ print(f"Missing dependency: {e}")
16
+
17
+
18
+ @dataclass
19
+ class DetectionResult:
20
+ """Simple detection result"""
21
+ nude_count: int = 0
22
+ gun_count: int = 0
23
+ knife_count: int = 0
24
+ fight_count: int = 0
25
+ is_safe: bool = True
26
+
27
+ def to_dict(self):
28
+ return {
29
+ 'nude': self.nude_count,
30
+ 'gun': self.gun_count,
31
+ 'knife': self.knife_count,
32
+ 'fight': self.fight_count,
33
+ 'is_safe': self.is_safe
34
+ }
35
+
36
+
37
+ class SmartSequentialModerator:
38
+ """
39
+ Smart Sequential Pipeline with balanced thresholds:
40
+ 1. NSFW Check with BALANCED threshold
41
+ 2. Only if NSFW is clean → Check Weapons/Fights
42
+ """
43
+
44
+ def __init__(self):
45
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
46
+
47
+ # Models
48
+ self.nsfw_classifier = None
49
+ self.weapon_model = None
50
+
51
+ # BALANCED Thresholds
52
+ self.nsfw_threshold = 0.75 # Balanced: not too high, not too low
53
+ self.nsfw_safe_threshold = 0.25 # If below this, definitely safe
54
+ self.gun_threshold = 0.7
55
+ self.knife_threshold = 0.65
56
+ self.fight_threshold = 0.25
57
+
58
+ print(f"🚀 Smart Sequential Moderator initialized on {self.device}")
59
+ print(f"📋 Pipeline: NSFW (0.75) → Weapons/Fights")
60
+
61
+ self._setup_models()
62
+
63
+ def _setup_models(self):
64
+ """Initialize models"""
65
+ try:
66
+ if torch.cuda.is_available():
67
+ torch.cuda.empty_cache()
68
+
69
+ # 1. NSFW Classifier (PRIORITY)
70
+ self._setup_nsfw()
71
+
72
+ # 2. Weapon/Fight Model
73
+ self._setup_weapons()
74
+
75
+ print("✅ All models ready!")
76
+
77
+ except Exception as e:
78
+ print(f"❌ Setup error: {e}")
79
+
80
+ def _setup_nsfw(self):
81
+ """Setup NSFW classifier"""
82
+ try:
83
+ print("🔞 Loading NSFW classifier...")
84
+
85
+ device_id = 0 if self.device == 'cuda' else -1
86
+
87
+ # Use the NSFW detection model
88
+ self.nsfw_classifier = pipeline(
89
+ "image-classification",
90
+ model="Falconsai/nsfw_image_detection",
91
+ device=device_id
92
+ )
93
+ print("✅ NSFW classifier loaded")
94
+
95
+ except Exception as e:
96
+ print(f"⚠️ NSFW failed: {e}")
97
+ self.nsfw_classifier = None
98
+
99
+ def _setup_weapons(self):
100
+ """Setup weapon/fight model"""
101
+ try:
102
+ print("🔫 Loading weapon/fight model...")
103
+
104
+ # Custom model path
105
+ custom_path = "models/best_ft4.pt"
106
+ if os.path.exists(custom_path):
107
+ self.weapon_model = YOLO(custom_path)
108
+ print(f"✅ Custom model loaded")
109
+
110
+ # Show available classes
111
+ if hasattr(self.weapon_model, 'names'):
112
+ classes = list(self.weapon_model.names.values())
113
+ print(f" Classes: {classes}")
114
+ else:
115
+ # Fallback
116
+ self.weapon_model = YOLO('yolo11n.pt')
117
+ print("✅ General model loaded")
118
+
119
+ except Exception as e:
120
+ print(f"⚠️ Weapon model failed: {e}")
121
+ self.weapon_model = None
122
+
123
+ def process_image(self, image) -> DetectionResult:
124
+ """
125
+ STRICT SEQUENTIAL:
126
+ 1. NSFW first (balanced threshold)
127
+ 2. If NSFW detected → STOP
128
+ 3. If clean → check weapons/fights
129
+ """
130
+
131
+ result = DetectionResult()
132
+
133
+ try:
134
+ # Load image
135
+ if isinstance(image, str):
136
+ image = cv2.imread(image)
137
+ if image is None:
138
+ return result
139
+
140
+ print(f"\n{'=' * 40}")
141
+ print(f"📸 Processing: {image.shape}")
142
+
143
+ # ========== STAGE 1: NSFW ==========
144
+ print("\n🔞 Stage 1: NSFW Check")
145
+
146
+ nsfw_score = self._check_nsfw(image)
147
+
148
+ if nsfw_score > self.nsfw_threshold:
149
+ print(f" 🚨 NSFW DETECTED: {nsfw_score:.3f}")
150
+ print(f" ⛔ STOPPING - Returning NSFW only")
151
+
152
+ result.nude_count = 1
153
+ result.is_safe = False
154
+ return result # STOP HERE
155
+
156
+ elif nsfw_score < self.nsfw_safe_threshold:
157
+ print(f" ✅ Definitely safe: {nsfw_score:.3f}")
158
+ else:
159
+ print(f" ⚠️ Borderline safe: {nsfw_score:.3f} - Continuing checks")
160
+
161
+ # ========== STAGE 2: WEAPONS/FIGHTS ==========
162
+ print("\n🔫 Stage 2: Weapons & Fights")
163
+
164
+ if self.weapon_model:
165
+ detections = self._detect_threats(image)
166
+ result.gun_count = detections['guns']
167
+ result.knife_count = detections['knives']
168
+ result.fight_count = detections['fights']
169
+
170
+ if detections['total'] > 0:
171
+ print(f" Found: G:{detections['guns']} K:{detections['knives']} F:{detections['fights']}")
172
+
173
+ # Final safety
174
+ total = result.nude_count + result.gun_count + result.knife_count + result.fight_count
175
+ result.is_safe = (total == 0)
176
+
177
+ print(
178
+ f"\n📊 Result: N:{result.nude_count} G:{result.gun_count} K:{result.knife_count} F:{result.fight_count} Safe:{result.is_safe}")
179
+ print(f"{'=' * 40}\n")
180
+
181
+ return result
182
+
183
+ except Exception as e:
184
+ print(f"❌ Error: {e}")
185
+ return result
186
+
187
+ def _check_nsfw(self, image) -> float:
188
+ """
189
+ Check NSFW with proper scoring
190
+ Returns confidence score (0-1)
191
+ """
192
+ try:
193
+ if not self.nsfw_classifier:
194
+ return 0.0
195
+
196
+ # Convert to RGB
197
+ rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
198
+ pil_image = Image.fromarray(rgb_image)
199
+
200
+ # Run classifier
201
+ results = self.nsfw_classifier(pil_image)
202
+
203
+ # Get NSFW score
204
+ nsfw_score = 0.0
205
+ for result in results:
206
+ label = result['label'].lower()
207
+ score = result['score']
208
+
209
+ # Check for NSFW label
210
+ if 'nsfw' in label or 'unsafe' in label or 'explicit' in label:
211
+ nsfw_score = max(nsfw_score, score)
212
+ print(f" {label}: {score:.3f}")
213
+
214
+ return nsfw_score
215
+
216
+ except Exception as e:
217
+ print(f" ⚠️ NSFW error: {e}")
218
+ return 0.0
219
+
220
+ def _detect_threats(self, image) -> Dict[str, int]:
221
+ """Detect weapons and fights"""
222
+ counts = {
223
+ 'guns': 0,
224
+ 'knives': 0,
225
+ 'fights': 0,
226
+ 'total': 0
227
+ }
228
+
229
+ try:
230
+ # Run detection with low base threshold
231
+ results = self.weapon_model(
232
+ image,
233
+ conf=0.4, # Low base threshold
234
+ device=self.device,
235
+ verbose=False
236
+ )
237
+
238
+ for result in results:
239
+ if result.boxes is None:
240
+ continue
241
+
242
+ for box in result.boxes:
243
+ class_id = int(box.cls[0])
244
+ confidence = float(box.conf[0])
245
+
246
+ if hasattr(result, 'names'):
247
+ class_name = result.names[class_id].lower()
248
+ else:
249
+ continue
250
+
251
+ # Check each category with proper threshold
252
+ if self._is_gun(class_name) and confidence > self.gun_threshold:
253
+ counts['guns'] += 1
254
+
255
+ elif self._is_knife(class_name) and confidence > self.knife_threshold:
256
+ counts['knives'] += 1
257
+
258
+ elif self._is_fight(class_name) and confidence > self.fight_threshold:
259
+ counts['fights'] += 1
260
+
261
+ counts['total'] = counts['guns'] + counts['knives'] + counts['fights']
262
+ return counts
263
+
264
+ except Exception as e:
265
+ print(f" ⚠️ Detection error: {e}")
266
+ return counts
267
+
268
+ def _is_gun(self, name: str) -> bool:
269
+ gun_words = ['gun', 'pistol', 'rifle', 'firearm', 'súng']
270
+ return any(w in name for w in gun_words)
271
+
272
+ def _is_knife(self, name: str) -> bool:
273
+ knife_words = ['knife', 'dao', 'blade', 'sword']
274
+ return any(w in name for w in knife_words)
275
+
276
+ def _is_fight(self, name: str) -> bool:
277
+ fight_words = ['fight', 'fighting', 'combat', 'violence']
278
+ return any(w in name for w in fight_words)
279
+
280
+ def process_video(self, video_path: str) -> Dict:
281
+ """
282
+ Process video with SMART frame skipping
283
+ Auto-adjusts based on video duration
284
+ """
285
+
286
+ total = DetectionResult()
287
+
288
+ try:
289
+ cap = cv2.VideoCapture(video_path)
290
+ if not cap.isOpened():
291
+ return total.to_dict()
292
+
293
+ # Get video info
294
+ fps = cap.get(cv2.CAP_PROP_FPS)
295
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
296
+ duration = total_frames / fps if fps > 0 else 0
297
+
298
+ # SMART frame skip based on duration
299
+ if duration <= 10: # Short video
300
+ frame_skip = 5 # Check every 5th frame
301
+ max_frames = 100
302
+ elif duration <= 30:
303
+ frame_skip = 10 # Check every 10th frame
304
+ max_frames = 150
305
+ elif duration <= 60:
306
+ frame_skip = 15
307
+ max_frames = 200
308
+ else: # Long video
309
+ frame_skip = 30
310
+ max_frames = 300
311
+
312
+ print(f"\n📹 Video: {duration:.1f}s, {total_frames} frames")
313
+ print(f" Auto settings: skip={frame_skip}, max={max_frames}")
314
+
315
+ frame_count = 0
316
+ processed = 0
317
+ nsfw_strikes = 0 # Count NSFW detections
318
+
319
+ while True:
320
+ ret, frame = cap.read()
321
+ if not ret:
322
+ break
323
+
324
+ frame_count += 1
325
+
326
+ # Skip frames
327
+ if frame_count % frame_skip != 0:
328
+ continue
329
+
330
+ # Max frame limit
331
+ if processed >= max_frames:
332
+ break
333
+
334
+ processed += 1
335
+
336
+ # Process frame
337
+ result = self.process_image(frame)
338
+
339
+ # Accumulate
340
+ total.nude_count += result.nude_count
341
+ total.gun_count += result.gun_count
342
+ total.knife_count += result.knife_count
343
+ total.fight_count += result.fight_count
344
+
345
+ # Early stop on multiple NSFW
346
+ if result.nude_count > 0:
347
+ nsfw_strikes += 1
348
+ if nsfw_strikes >= 3: # Stop after 3 NSFW frames
349
+ print(f"⛔ Early stop: {nsfw_strikes} NSFW frames")
350
+ break
351
+
352
+ # Progress
353
+ if processed % 50 == 0:
354
+ print(f" Processed {processed} frames...")
355
+
356
+ cap.release()
357
+
358
+ # Final safety
359
+ total_threats = total.nude_count + total.gun_count + total.knife_count + total.fight_count
360
+ total.is_safe = (total_threats == 0)
361
+
362
+ print(f"\n📊 Video complete: {processed} frames analyzed")
363
+ print(f" Total: N:{total.nude_count} G:{total.gun_count} K:{total.knife_count} F:{total.fight_count}")
364
+
365
+ return total.to_dict()
366
+
367
+ except Exception as e:
368
+ print(f"❌ Video error: {e}")
369
+ return total.to_dict()
370
+
371
+
372
+ def main():
373
+ """Test the moderator"""
374
+
375
+ moderator = SmartSequentialModerator()
376
+
377
+ print("\n" + "=" * 50)
378
+ print("🎯 SMART SEQUENTIAL MODERATOR")
379
+ print("=" * 50)
380
+ print("• Balanced NSFW threshold: 0.75")
381
+ print("• Auto frame skipping for videos")
382
+ print("• Simple output: counts + boolean")
383
+ print("=" * 50)
384
+
385
+ # Test
386
+ test_image = "test.jpg"
387
+ if os.path.exists(test_image):
388
+ result = moderator.process_image(test_image)
389
+ print(f"\nResult: {result.to_dict()}")
390
+
391
+
392
+ if __name__ == "__main__":
393
+ main()