# kfvideodt/detection_api.py
import json
import logging
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import aiofiles
import cv2
import numpy as np
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from gunicorn.app.base import BaseApplication
from pydantic import BaseModel, Field

from main import ContentModerator
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI(
title="Weapon & NSFW Detection API",
description="API for detecting knives/dao, guns, fights and NSFW content in images and videos",
version="2.0.0",
docs_url="/docs",
redoc_url="/redoc"
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
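# Note: browsers reject allow_origins=["*"] combined with
# allow_credentials=True (the CORS spec forbids a wildcard origin on
# credentialed requests). If cookie- or auth-based clients are expected,
# list explicit origins instead of "*".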
class StandaloneApplication(BaseApplication):
def __init__(self, app, options=None):
self.application = app
self.options = options or {}
super().__init__()
    def load_config(self):
        # Pass through only the options gunicorn actually recognizes.
        for key, value in self.options.items():
            if key in self.cfg.settings and value is not None:
                self.cfg.set(key, value)
def load(self):
return self.application
# Configuration
class Config:
UPLOAD_DIR = Path("uploads")
RESULTS_DIR = Path("results")
PROCESSED_DIR = Path("processed")
MAX_IMAGE_SIZE = 50 * 1024 * 1024 # 50MB for images
MAX_VIDEO_SIZE = 500 * 1024 * 1024 # 500MB for videos
ALLOWED_IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp'}
ALLOWED_VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', '.wmv'}
VIDEO_FRAME_SKIP = 5 # Process every 5th frame for performance
CLEANUP_AFTER_HOURS = 24
ENABLE_ANNOTATED_OUTPUT = True
MAX_WORKERS = 4
config = Config()
# Create necessary directories
for directory in [config.UPLOAD_DIR, config.RESULTS_DIR, config.PROCESSED_DIR]:
directory.mkdir(exist_ok=True)
(directory / "images").mkdir(exist_ok=True)
(directory / "videos").mkdir(exist_ok=True)
# Global moderator instance (initialized on startup)
moderator: Optional[ContentModerator] = None
# Thread pool for background processing
executor = ThreadPoolExecutor(max_workers=config.MAX_WORKERS)
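# Note: the detection endpoints below call moderator.process_image()
# directly inside async handlers, which blocks the event loop for the
# duration of inference. A minimal sketch of off-loading to the pool
# above (assuming ContentModerator is safe to call from worker threads):
#
#     import asyncio
#     loop = asyncio.get_running_loop()
#     result = await loop.run_in_executor(executor, moderator.process_image, image)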
# ============== Response Models ==============
class BoundingBox(BaseModel):
x1: int = Field(..., description="Top-left x coordinate")
y1: int = Field(..., description="Top-left y coordinate")
x2: int = Field(..., description="Bottom-right x coordinate")
y2: int = Field(..., description="Bottom-right y coordinate")
class WeaponDetection(BaseModel):
type: str = Field(..., description="Detection type (weapon)")
class_name: str = Field(..., description="Weapon class (knife/dao/gun)")
weapon_type: str = Field(..., description="Weapon category (blade/firearm)")
confidence: float = Field(..., ge=0, le=1, description="Detection confidence")
bbox: BoundingBox
threat_level: str = Field(..., description="Threat level (low/medium/high/critical)")
detection_method: str = Field(..., description="Detection method used")
class NSFWDetection(BaseModel):
type: str = Field(..., description="Detection type (nsfw)")
class_name: str = Field(..., description="NSFW class")
confidence: float = Field(..., ge=0, le=1, description="Detection confidence")
bbox: BoundingBox
method: str = Field(..., description="Detection method (classification/skin_detection/pose_analysis)")
skin_ratio: Optional[float] = Field(None, description="Skin exposure ratio if applicable")
class FightDetection(BaseModel):
type: str = Field(default="fight", description="Detection type")
confidence: float = Field(..., ge=0, le=1, description="Detection confidence")
bbox: BoundingBox
persons_involved: int = Field(..., description="Number of persons detected in fight")
threat_level: str = Field(..., description="Threat level")
class ImageDetectionResponse(BaseModel):
success: bool
request_id: str
timestamp: str
image_info: Dict[str, Any]
detections: Dict[str, List[Union[WeaponDetection, NSFWDetection, FightDetection]]]
summary: Dict[str, Any]
risk_level: str
action_required: bool
annotated_image_url: Optional[str] = None
processing_time_ms: float
class VideoDetectionResponse(BaseModel):
success: bool
request_id: str
timestamp: str
video_info: Dict[str, Any]
total_frames_processed: int
frame_detections: List[Dict[str, Any]]
summary: Dict[str, Any]
risk_level: str
action_required: bool
processed_video_url: Optional[str] = None
processing_time_ms: float
class ErrorResponse(BaseModel):
success: bool = False
error: str
error_code: str
timestamp: str
request_id: Optional[str] = None
# ============== Utility Functions ==============
def generate_request_id() -> str:
"""Generate unique request ID"""
return f"req_{datetime.now().strftime('%Y%m%d%H%M%S')}_{uuid.uuid4().hex[:8]}"
def validate_file_extension(filename: str, allowed_extensions: set) -> bool:
"""Validate file extension"""
return Path(filename).suffix.lower() in allowed_extensions
def validate_file_size(file_size: int, max_size: int) -> bool:
"""Validate file size"""
return file_size <= max_size
async def save_upload_file(upload_file: UploadFile, destination: Path) -> Path:
"""Save uploaded file to destination"""
try:
async with aiofiles.open(destination, 'wb') as f:
content = await upload_file.read()
await f.write(content)
return destination
except Exception as e:
logger.error(f"Error saving file: {e}")
raise
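async def save_upload_file_chunked(upload_file: UploadFile, destination: Path) -> Path:
    """Chunked variant of save_upload_file (a sketch, not wired in by default).

    save_upload_file above buffers the entire upload in memory, which is
    fine for images but costly for 500MB videos; streaming in 1 MiB
    chunks (an arbitrary size) keeps memory usage flat.
    """
    async with aiofiles.open(destination, 'wb') as f:
        while chunk := await upload_file.read(1024 * 1024):
            await f.write(chunk)
    return destination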
def detect_fight_in_frame(image: np.ndarray, persons: List[Dict]) -> Optional[FightDetection]:
"""
Detect potential fight based on person proximity and poses
This is a simplified implementation - you may want to enhance this
"""
if len(persons) < 2:
return None
# Check for overlapping or very close person bounding boxes
for i in range(len(persons)):
for j in range(i + 1, len(persons)):
bbox1 = persons[i]['bbox']
bbox2 = persons[j]['bbox']
# Calculate center points
center1_x = (bbox1[0] + bbox1[2]) / 2
center1_y = (bbox1[1] + bbox1[3]) / 2
center2_x = (bbox2[0] + bbox2[2]) / 2
center2_y = (bbox2[1] + bbox2[3]) / 2
# Calculate distance between centers
distance = np.sqrt((center1_x - center2_x) ** 2 + (center1_y - center2_y) ** 2)
# Calculate average person width
avg_width = ((bbox1[2] - bbox1[0]) + (bbox2[2] - bbox2[0])) / 2
            # Flag the pair if their centers are closer than 1.5x the average width
            if distance < avg_width * 1.5:
# Create combined bounding box
min_x = min(bbox1[0], bbox2[0])
min_y = min(bbox1[1], bbox2[1])
max_x = max(bbox1[2], bbox2[2])
max_y = max(bbox1[3], bbox2[3])
return FightDetection(
type="fight",
confidence=0.7, # Simplified confidence
bbox=BoundingBox(x1=min_x, y1=min_y, x2=max_x, y2=max_y),
persons_involved=2,
threat_level="high"
)
return None
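# Worked example of the heuristic above: two 100px-wide person boxes whose
# centers are 120px apart are flagged (120 < 1.5 * 100 = 150); at 160px
# apart they are not.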
def process_detections(raw_detections: List[Dict]) -> Dict[str, List]:
"""Process and categorize raw detections"""
processed = {
'weapons': [],
'nsfw': [],
'fights': []
}
for det in raw_detections:
if det['type'] == 'weapon':
processed['weapons'].append(WeaponDetection(
type=det['type'],
class_name=det['class'],
weapon_type=det.get('weapon_type', 'unknown'),
confidence=det['confidence'],
bbox=BoundingBox(
x1=det['bbox'][0],
y1=det['bbox'][1],
x2=det['bbox'][2],
y2=det['bbox'][3]
),
threat_level=det.get('threat_level', 'medium'),
detection_method=det.get('detection_method', 'yolo')
))
elif det['type'] == 'nsfw':
processed['nsfw'].append(NSFWDetection(
type=det['type'],
class_name=det['class'],
confidence=det['confidence'],
bbox=BoundingBox(
x1=det['bbox'][0],
y1=det['bbox'][1],
x2=det['bbox'][2],
y2=det['bbox'][3]
),
method=det.get('method', 'classification'),
skin_ratio=det.get('skin_ratio')
))
        elif det['type'] == 'fight':
            # Wrap raw fight dicts in the FightDetection model so downstream
            # code can call .dict() on every entry uniformly (field defaults
            # assumed where the raw detection omits them).
            processed['fights'].append(FightDetection(
                type='fight',
                confidence=det['confidence'],
                bbox=BoundingBox(
                    x1=det['bbox'][0],
                    y1=det['bbox'][1],
                    x2=det['bbox'][2],
                    y2=det['bbox'][3]
                ),
                persons_involved=det.get('persons_involved', 2),
                threat_level=det.get('threat_level', 'high')
            ))
return processed
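# Example raw detection consumed by process_detections (shape inferred from
# the key accesses above; the exact fields come from main.ContentModerator):
#   {'type': 'weapon', 'class': 'knife', 'confidence': 0.82,
#    'bbox': [120, 45, 210, 300], 'threat_level': 'high',
#    'detection_method': 'yolo'}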
# ============== API Endpoints ==============
@app.on_event("startup")
async def startup_event():
"""Initialize moderator on startup"""
global moderator
try:
logger.info("Initializing Content Moderator...")
# Custom configuration for API
custom_config = {
'weapon_detection': {
'enabled': True,
'confidence_threshold': 0.5,
'knife_confidence': 0.25,
'model_size': 'yolo11n',
'classes': ['knife', 'dao', 'gun', 'rifle', 'pistol', 'weapon', 'fight'],
'use_enhancement': True,
'multi_pass': True,
'boost_knife_detection': True
},
'nsfw_detection': {
'enabled': True,
'confidence_threshold': 0.7,
'skin_detection': True,
'pose_analysis': False, # Disabled for performance
'region_analysis': True
},
'performance': {
'image_size': 640,
'batch_size': 1,
'half_precision': True,
'use_flash_attention': False,
'cpu_optimization': False
},
'output': {
'save_detections': True,
'draw_boxes': True,
'log_results': True
}
}
moderator = ContentModerator(config=custom_config)
logger.info("✅ Content Moderator initialized successfully")
# Log model status
status = moderator.get_model_status()
logger.info(f"Model Status: {json.dumps(status, indent=2)}")
except Exception as e:
logger.error(f"Failed to initialize Content Moderator: {e}")
logger.error(traceback.format_exc())
@app.on_event("shutdown")
async def shutdown_event():
"""Cleanup on shutdown"""
executor.shutdown(wait=True)
logger.info("API shutdown complete")
@app.get("/", response_model=Dict[str, Any])
async def root():
"""API root endpoint with status information"""
if moderator:
status = moderator.get_model_status()
return {
"service": "Weapon & NSFW Detection API",
"version": "2.0.0",
"status": "operational",
"models": status,
"endpoints": {
"image_detection": "/detect_n_k_f_g/images",
"video_detection": "/detect_n_k_f_g/videos",
"documentation": "/docs"
}
}
else:
return {
"service": "Weapon & NSFW Detection API",
"version": "2.0.0",
"status": "initializing",
"message": "Models are being loaded..."
}
@app.post("/detect_n_k_f_g/images", response_model=ImageDetectionResponse)
async def detect_image(
file: UploadFile = File(..., description="Image file to analyze"),
enable_fight_detection: bool = Form(True, description="Enable fight detection"),
return_annotated: bool = Form(True, description="Return annotated image")
):
"""
Detect weapons (knife/dao/gun), fights, and NSFW content in images
Supports: JPG, JPEG, PNG, BMP, GIF, WEBP
Max size: 50MB
"""
    request_id = generate_request_id()
    start_time = datetime.now()
    # Fail fast if startup initialization did not complete.
    if moderator is None:
        raise HTTPException(status_code=503, detail="Detection models are not loaded yet")
    try:
# Validate file extension
if not validate_file_extension(file.filename, config.ALLOWED_IMAGE_EXTENSIONS):
raise HTTPException(
status_code=400,
detail=f"Invalid file type. Allowed: {', '.join(config.ALLOWED_IMAGE_EXTENSIONS)}"
)
# Check file size
file_content = await file.read()
file_size = len(file_content)
if not validate_file_size(file_size, config.MAX_IMAGE_SIZE):
raise HTTPException(
status_code=400,
detail=f"File too large. Maximum size: {config.MAX_IMAGE_SIZE / (1024 * 1024):.1f}MB"
)
# Save uploaded file
upload_path = config.UPLOAD_DIR / "images" / f"{request_id}_{file.filename}"
async with aiofiles.open(upload_path, 'wb') as f:
await f.write(file_content)
# Read image with OpenCV
nparr = np.frombuffer(file_content, np.uint8)
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if image is None:
raise HTTPException(status_code=400, detail="Invalid or corrupted image file")
# Get image info
height, width, channels = image.shape
image_info = {
"filename": file.filename,
"width": width,
"height": height,
"channels": channels,
"size_bytes": file_size,
"size_mb": round(file_size / (1024 * 1024), 2)
}
# Process image with ContentModerator
logger.info(f"Processing image {request_id}")
result = moderator.process_image(image)
if not result:
raise HTTPException(status_code=500, detail="Detection processing failed")
# Detect persons for potential fight detection
persons = moderator.detect_persons(image)
# Check for fights if enabled
fight_detection = None
if enable_fight_detection and len(persons) >= 2:
fight_detection = detect_fight_in_frame(image, persons)
# Process detections
processed = process_detections(result['detections'])
# Add fight detection if found
if fight_detection:
processed['fights'].append(fight_detection)
# Save annotated image if requested
annotated_url = None
if return_annotated and config.ENABLE_ANNOTATED_OUTPUT:
if 'annotated_image' in result:
annotated_path = config.PROCESSED_DIR / "images" / f"{request_id}_annotated.jpg"
cv2.imwrite(str(annotated_path), result['annotated_image'])
annotated_url = f"/results/images/{request_id}_annotated.jpg"
else:
# Draw annotations manually if not provided
annotated_image = moderator.draw_detections(image.copy(), result['detections'])
annotated_path = config.PROCESSED_DIR / "images" / f"{request_id}_annotated.jpg"
cv2.imwrite(str(annotated_path), annotated_image)
annotated_url = f"/results/images/{request_id}_annotated.jpg"
# Calculate summary
total_weapons = len(processed['weapons'])
total_nsfw = len(processed['nsfw'])
total_fights = len(processed['fights'])
knife_count = sum(
1 for w in processed['weapons'] if 'knife' in w.class_name.lower() or 'dao' in w.class_name.lower())
gun_count = sum(1 for w in processed['weapons'] if
'gun' in w.class_name.lower() or 'pistol' in w.class_name.lower() or 'rifle' in w.class_name.lower())
summary = {
"total_detections": total_weapons + total_nsfw + total_fights,
"weapons": {
"total": total_weapons,
"knives": knife_count,
"guns": gun_count
},
"nsfw": total_nsfw,
"fights": total_fights,
"persons_detected": len(persons)
}
# Determine overall risk level
if total_weapons > 0 or total_fights > 0:
risk_level = "critical" if gun_count > 0 else "high"
elif total_nsfw > 0:
risk_level = "medium"
else:
risk_level = "safe"
# Calculate processing time
processing_time = (datetime.now() - start_time).total_seconds() * 1000
return ImageDetectionResponse(
success=True,
request_id=request_id,
timestamp=datetime.now().isoformat(),
image_info=image_info,
detections=processed,
summary=summary,
risk_level=risk_level,
action_required=(summary["total_detections"] > 0),
annotated_image_url=annotated_url,
processing_time_ms=processing_time
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error processing image {request_id}: {e}")
logger.error(traceback.format_exc())
raise HTTPException(
status_code=500,
detail=f"Internal server error: {str(e)}"
)
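# The annotated_image_url values returned above point at
# /results/images/<name>, but no route served that path. A minimal sketch
# of a matching route (files are written to config.PROCESSED_DIR / "images"
# by detect_image above):
@app.get("/results/images/{filename}")
async def get_annotated_image(filename: str):
    """Serve an annotated image produced by /detect_n_k_f_g/images."""
    # Reject path traversal: only bare file names are allowed.
    if Path(filename).name != filename or '..' in filename:
        raise HTTPException(status_code=400, detail="Invalid filename")
    file_path = config.PROCESSED_DIR / "images" / filename
    if not file_path.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(file_path, media_type="image/jpeg")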
@app.post("/detect_n_k_f_g/videos", response_model=VideoDetectionResponse)
async def detect_video(
file: UploadFile = File(..., description="Video file to analyze"),
frame_skip: int = Form(5, ge=1, le=30, description="Process every Nth frame"),
max_frames: int = Form(1000, ge=10, le=5000, description="Maximum frames to process"),
enable_fight_detection: bool = Form(True, description="Enable fight detection")
):
"""
Detect weapons (knife/dao/gun), fights, and NSFW content in videos
Supports: MP4, AVI, MOV, MKV, WEBM, FLV, WMV
Max size: 500MB
Note: Videos are automatically deleted after processing to save disk space
"""
    request_id = generate_request_id()
    start_time = datetime.now()
    upload_path = None
    # Fail fast if startup initialization did not complete.
    if moderator is None:
        raise HTTPException(status_code=503, detail="Detection models are not loaded yet")
    try:
# Validate file extension
if not validate_file_extension(file.filename, config.ALLOWED_VIDEO_EXTENSIONS):
raise HTTPException(
status_code=400,
detail=f"Invalid file type. Allowed: {', '.join(config.ALLOWED_VIDEO_EXTENSIONS)}"
)
# Save uploaded video
upload_path = config.UPLOAD_DIR / "videos" / f"{request_id}_{file.filename}"
await save_upload_file(file, upload_path)
# Get file size
file_size = upload_path.stat().st_size
if not validate_file_size(file_size, config.MAX_VIDEO_SIZE):
raise HTTPException(
status_code=400,
detail=f"File too large. Maximum size: {config.MAX_VIDEO_SIZE / (1024 * 1024):.1f}MB"
)
# Open video
cap = cv2.VideoCapture(str(upload_path))
if not cap.isOpened():
raise HTTPException(status_code=400, detail="Invalid or corrupted video file")
# Get video info
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
duration = total_frames / fps if fps > 0 else 0
video_info = {
"filename": file.filename,
"width": width,
"height": height,
"fps": fps,
"total_frames": total_frames,
"duration_seconds": round(duration, 2),
"size_bytes": file_size,
"size_mb": round(file_size / (1024 * 1024), 2)
}
# Process video frames
logger.info(f"Processing video {request_id}: {total_frames} frames, skip={frame_skip}")
frame_detections = []
frame_count = 0
processed_count = 0
# Aggregated statistics
all_weapons = []
all_nsfw = []
all_fights = []
while True:
ret, frame = cap.read()
if not ret:
break
frame_count += 1
# Skip frames according to frame_skip parameter
if frame_count % frame_skip != 0:
continue
# Limit maximum frames processed
if processed_count >= max_frames:
logger.info(f"Reached max frames limit: {max_frames}")
break
processed_count += 1
# Process frame
result = moderator.process_image(frame)
if result and result['detections']:
# Get persons for fight detection
persons = moderator.detect_persons(frame)
# Check for fights
fight_detection = None
if enable_fight_detection and len(persons) >= 2:
fight_detection = detect_fight_in_frame(frame, persons)
# Process detections
processed = process_detections(result['detections'])
if fight_detection:
processed['fights'].append(fight_detection)
# Store frame detection info
if len(processed['weapons']) > 0 or len(processed['nsfw']) > 0 or len(processed['fights']) > 0:
frame_info = {
"frame_number": frame_count,
"timestamp_seconds": frame_count / fps if fps > 0 else 0,
"detections": {
"weapons": [w.dict() for w in processed['weapons']],
"nsfw": [n.dict() for n in processed['nsfw']],
"fights": [f.dict() for f in processed['fights']]
}
}
frame_detections.append(frame_info)
# Aggregate statistics
all_weapons.extend(processed['weapons'])
all_nsfw.extend(processed['nsfw'])
all_fights.extend(processed['fights'])
# Log progress every 100 frames
if processed_count % 100 == 0:
logger.info(f"Processed {processed_count} frames...")
# Release resources
cap.release()
# Calculate summary
knife_count = sum(1 for w in all_weapons if 'knife' in w.class_name.lower() or 'dao' in w.class_name.lower())
        gun_count = sum(1 for w in all_weapons if 'gun' in w.class_name.lower()
                        or 'pistol' in w.class_name.lower() or 'rifle' in w.class_name.lower())
summary = {
"total_frames_analyzed": processed_count,
"frames_with_detections": len(frame_detections),
"total_detections": len(all_weapons) + len(all_nsfw) + len(all_fights),
"weapons": {
"total": len(all_weapons),
"knives": knife_count,
"guns": gun_count,
"unique_frames": len(set(f["frame_number"] for f in frame_detections if f["detections"]["weapons"]))
},
"nsfw": {
"total": len(all_nsfw),
"unique_frames": len(set(f["frame_number"] for f in frame_detections if f["detections"]["nsfw"]))
},
"fights": {
"total": len(all_fights),
"unique_frames": len(set(f["frame_number"] for f in frame_detections if f["detections"]["fights"]))
}
}
# Determine overall risk level
if gun_count > 0 or len(all_fights) > 5:
risk_level = "critical"
elif knife_count > 0 or len(all_fights) > 0:
risk_level = "high"
elif len(all_nsfw) > 0:
risk_level = "medium"
else:
risk_level = "safe"
# Calculate processing time
processing_time = (datetime.now() - start_time).total_seconds() * 1000
return VideoDetectionResponse(
success=True,
request_id=request_id,
timestamp=datetime.now().isoformat(),
video_info=video_info,
total_frames_processed=processed_count,
frame_detections=frame_detections,
summary=summary,
risk_level=risk_level,
action_required=(summary["total_detections"] > 0),
processed_video_url=None, # Always None since we don't save processed videos
processing_time_ms=processing_time
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error processing video {request_id}: {e}")
logger.error(traceback.format_exc())
raise HTTPException(
status_code=500,
detail=f"Internal server error: {str(e)}"
)
finally:
# Always cleanup uploaded video file after processing
if upload_path and upload_path.exists():
try:
upload_path.unlink()
logger.info(f"Cleaned up uploaded video: {upload_path}")
except Exception as cleanup_error:
logger.warning(f"Failed to cleanup uploaded video {upload_path}: {cleanup_error}")
@app.delete("/cleanup")
async def cleanup_old_files(hours: int = 24):
"""Clean up old files from upload and results directories (excluding videos from uploads as they are auto-deleted)"""
try:
from datetime import timedelta
cutoff_time = datetime.now() - timedelta(hours=hours)
deleted_count = 0
# Clean up images from all directories
for directory in [config.UPLOAD_DIR, config.RESULTS_DIR, config.PROCESSED_DIR]:
images_path = directory / "images"
if images_path.exists():
for file in images_path.iterdir():
if file.is_file():
file_time = datetime.fromtimestamp(file.stat().st_mtime)
if file_time < cutoff_time:
file.unlink()
deleted_count += 1
# Clean up any remaining uploaded videos (should be rare since they're auto-deleted)
upload_videos_path = config.UPLOAD_DIR / "videos"
if upload_videos_path.exists():
for file in upload_videos_path.iterdir():
if file.is_file():
file_time = datetime.fromtimestamp(file.stat().st_mtime)
if file_time < cutoff_time:
file.unlink()
deleted_count += 1
logger.info(f"Cleaned up old uploaded video: {file}")
# Note: No need to clean processed videos since we don't save them anymore
return {
"success": True,
"deleted_files": deleted_count,
"message": f"Deleted {deleted_count} files older than {hours} hours"
}
except Exception as e:
logger.error(f"Cleanup error: {e}")
return {
"success": False,
"error": str(e)
}
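# Example usage (assuming the server is listening on localhost:7860):
#
#   curl -X POST http://localhost:7860/detect_n_k_f_g/images \
#        -F "file=@photo.jpg" -F "return_annotated=true"
#
#   curl -X POST http://localhost:7860/detect_n_k_f_g/videos \
#        -F "file=@clip.mp4" -F "frame_skip=5" -F "max_frames=1000"
#
#   curl -X DELETE "http://localhost:7860/cleanup?hours=24"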
if __name__ == "__main__":
    import os
    port = int(os.environ.get("PORT", 7860))
    options = {
        "bind": f"0.0.0.0:{port}",
        "workers": 2,
        # FastAPI is an ASGI app; gunicorn's default sync worker cannot
        # serve it, so the uvicorn worker class is required.
        "worker_class": "uvicorn.workers.UvicornWorker",
    }
    StandaloneApplication(app, options).run()