# kfvideodt / detection_api.py
# Uploaded to Hugging Face by Haiss123 (commit 9adeec5, 29 kB).
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any, Union
import cv2
import numpy as np
from datetime import datetime
import aiofiles
import json
from pathlib import Path
import uuid
import traceback
from concurrent.futures import ThreadPoolExecutor
import logging
from main import ContentModerator
# Setup logging: INFO and above, one line per record.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app (Swagger UI at /docs, ReDoc at /redoc).
app = FastAPI(
    title="Weapon & NSFW Detection API",
    description="API for detecting knives/dao, guns, fights and NSFW content in images and videos",
    version="2.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Add CORS middleware: every origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Configuration
class Config:
    """Static service settings, shared through the module-level ``config`` instance."""

    # Storage roots; at import time each also gets "images" and "videos" sub-directories.
    UPLOAD_DIR: Path = Path("uploads")
    RESULTS_DIR: Path = Path("results")
    PROCESSED_DIR: Path = Path("processed")

    # Upload size limits, in bytes.
    MAX_IMAGE_SIZE: int = 50 * 2**20  # 50MB for images
    MAX_VIDEO_SIZE: int = 500 * 2**20  # 500MB for videos

    # Accepted extensions (lower-case, with the leading dot); callers compare
    # them against Path(filename).suffix.lower().
    ALLOWED_IMAGE_EXTENSIONS: set = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp'}
    ALLOWED_VIDEO_EXTENSIONS: set = {'.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', '.wmv'}

    # NOTE(review): not read by the video endpoint, whose frame_skip form field
    # has its own hard-coded default of 5.
    VIDEO_FRAME_SKIP: int = 5  # Process every 5th frame for performance
    CLEANUP_AFTER_HOURS: int = 24
    ENABLE_ANNOTATED_OUTPUT: bool = True
    MAX_WORKERS: int = 4  # size of the module-level ThreadPoolExecutor
config = Config()

# Create the storage tree: <root>, <root>/images and <root>/videos for every root.
for root_dir in (config.UPLOAD_DIR, config.RESULTS_DIR, config.PROCESSED_DIR):
    root_dir.mkdir(exist_ok=True)
    for media_kind in ("images", "videos"):
        (root_dir / media_kind).mkdir(exist_ok=True)

# Global moderator instance; assigned by the startup hook and left as None
# until then (or if initialization fails).
moderator: Optional[ContentModerator] = None

# Thread pool for background processing; shut down by the shutdown hook.
executor = ThreadPoolExecutor(max_workers=config.MAX_WORKERS)
# ============== Response Models ==============
class BoundingBox(BaseModel):
    """Axis-aligned box in integer image coordinates (top-left / bottom-right corners)."""
    x1: int = Field(..., description="Top-left x coordinate")
    y1: int = Field(..., description="Top-left y coordinate")
    x2: int = Field(..., description="Bottom-right x coordinate")
    y2: int = Field(..., description="Bottom-right y coordinate")
class WeaponDetection(BaseModel):
    """A single weapon hit, built by ``process_detections`` from a raw detection with type 'weapon'."""
    type: str = Field(..., description="Detection type (weapon)")
    class_name: str = Field(..., description="Weapon class (knife/dao/gun)")
    # Defaults to 'unknown' when the raw detection carries no weapon_type.
    weapon_type: str = Field(..., description="Weapon category (blade/firearm)")
    confidence: float = Field(..., ge=0, le=1, description="Detection confidence")
    bbox: BoundingBox
    # Defaults to 'medium' when the raw detection carries no threat_level.
    threat_level: str = Field(..., description="Threat level (low/medium/high/critical)")
    # Defaults to 'yolo' when the raw detection carries no detection_method.
    detection_method: str = Field(..., description="Detection method used")
class NSFWDetection(BaseModel):
    """A single NSFW hit, built by ``process_detections`` from a raw detection with type 'nsfw'."""
    type: str = Field(..., description="Detection type (nsfw)")
    class_name: str = Field(..., description="NSFW class")
    confidence: float = Field(..., ge=0, le=1, description="Detection confidence")
    bbox: BoundingBox
    # Defaults to 'classification' when the raw detection carries no method.
    method: str = Field(..., description="Detection method (classification/skin_detection/pose_analysis)")
    # None when the raw detection has no skin_ratio entry.
    skin_ratio: Optional[float] = Field(None, description="Skin exposure ratio if applicable")
class FightDetection(BaseModel):
    """A possible fight, produced by the person-proximity heuristic ``detect_fight_in_frame``."""
    type: str = Field(default="fight", description="Detection type")
    confidence: float = Field(..., ge=0, le=1, description="Detection confidence")
    # Union of the involved persons' boxes.
    bbox: BoundingBox
    persons_involved: int = Field(..., description="Number of persons detected in fight")
    threat_level: str = Field(..., description="Threat level")
class ImageDetectionResponse(BaseModel):
    """Response body of the image detection endpoint."""
    success: bool
    request_id: str
    # ISO-8601 timestamp of when the response was built.
    timestamp: str
    # filename, width, height, channels, size_bytes, size_mb.
    image_info: Dict[str, Any]
    # Keyed by 'weapons', 'nsfw' and 'fights'.
    detections: Dict[str, List[Union[WeaponDetection, NSFWDetection, FightDetection]]]
    summary: Dict[str, Any]
    # One of "safe", "medium", "high", "critical".
    risk_level: str
    # True when at least one detection of any kind was found.
    action_required: bool
    # Set only when an annotated image was requested and written.
    annotated_image_url: Optional[str] = None
    # Wall-clock time spent handling the request.
    processing_time_ms: float
class VideoDetectionResponse(BaseModel):
    """Response body of the video detection endpoint."""
    success: bool
    request_id: str
    # ISO-8601 timestamp of when the response was built.
    timestamp: str
    # filename, width, height, fps, total_frames, duration_seconds, size_bytes, size_mb.
    video_info: Dict[str, Any]
    # Frames actually analysed (after frame skipping and the max_frames cap).
    total_frames_processed: int
    # One entry per analysed frame with at least one detection.
    frame_detections: List[Dict[str, Any]]
    summary: Dict[str, Any]
    # One of "safe", "medium", "high", "critical".
    risk_level: str
    # True when at least one detection of any kind was found.
    action_required: bool
    # Set only when an annotated video was requested and written.
    processed_video_url: Optional[str] = None
    # Wall-clock time spent handling the request.
    processing_time_ms: float
class ErrorResponse(BaseModel):
    """Generic error payload.

    NOTE(review): not referenced by the endpoints shown here, which raise
    HTTPException instead; confirm whether it is still needed.
    """
    success: bool = False
    error: str
    error_code: str
    timestamp: str
    request_id: Optional[str] = None
# ============== Utility Functions ==============
def generate_request_id() -> str:
    """Return a unique request identifier.

    Shape: ``req_<YYYYmmddHHMMSS>_<8 hex digits>``. The timestamp is local
    time and the suffix is the first eight hex digits of a random UUID4.
    """
    stamp = datetime.now().strftime('%Y%m%d%H%M%S')
    suffix = uuid.uuid4().hex[:8]
    return "req_" + stamp + "_" + suffix
def validate_file_extension(filename: Optional[str], allowed_extensions: set) -> bool:
    """Return True if *filename*'s extension is one of *allowed_extensions*.

    Only the last suffix counts (``a.tar.mp4`` -> ``.mp4``). It is lower-cased
    before the lookup, so entries in *allowed_extensions* should be
    lower-case and include the leading dot. Any container supporting ``in``
    works.

    ``UploadFile.filename`` may be ``None`` or empty. Such names are rejected
    (False) instead of raising ``TypeError`` from ``Path(None)``.
    """
    if not filename:
        return False
    return Path(filename).suffix.lower() in allowed_extensions
def validate_file_size(file_size: int, max_size: int) -> bool:
    """Report whether *file_size* fits within *max_size* (inclusive upper bound, both in bytes)."""
    too_large = file_size > max_size
    return not too_large
async def save_upload_file(upload_file: UploadFile, destination: Path, chunk_size: int = 1024 * 1024) -> Path:
    """Stream *upload_file* to *destination* and return *destination*.

    The upload is copied in ``chunk_size``-byte pieces (1 MiB by default), so
    a large video is never held in memory all at once.

    Raises:
        Whatever the read/write raised. The error is logged with its
        traceback, and any partially written *destination* is removed first so
        no truncated file is left behind.
    """
    try:
        async with aiofiles.open(destination, 'wb') as out_file:
            while True:
                chunk = await upload_file.read(chunk_size)
                if not chunk:  # empty bytes => end of upload
                    break
                await out_file.write(chunk)
        return destination
    except Exception as e:
        logger.exception(f"Error saving file: {e}")
        # Best effort: do not let cleanup hide the original error.
        try:
            if destination.exists():
                destination.unlink()
        except OSError:
            logger.warning(f"Could not remove partial upload {destination}")
        raise
def detect_fight_in_frame(
    image: np.ndarray,
    persons: List[Dict],
    distance_factor: float = 1.5,
    confidence: float = 0.7,
) -> Optional[FightDetection]:
    """Flag a potential fight when two detected persons are very close together.

    This is a simple proximity heuristic, not a pose or motion analysis. For
    every pair of persons, the distance between their box centres is compared
    with ``distance_factor`` times the pair's average box width. The first
    pair below that threshold is reported.

    Args:
        image: Frame the persons were detected in. Currently unused; kept for
            interface compatibility (e.g. future pose-based checks).
        persons: Person detections. Each needs a ``'bbox'`` entry indexable as
            ``(x1, y1, x2, y2)``.
        distance_factor: Multiple of the average width under which two
            persons count as "very close" (default 1.5, the original value).
        confidence: Fixed confidence reported for a hit (must lie in [0, 1]).

    Returns:
        A FightDetection whose box encloses both persons, or None if fewer
        than two persons are given or no pair is close enough.
    """
    if len(persons) < 2:
        return None

    for index, first in enumerate(persons):
        for second in persons[index + 1:]:
            box_a = first['bbox']
            box_b = second['bbox']

            # Distance between the two box centres.
            dx = (box_a[0] + box_a[2]) / 2 - (box_b[0] + box_b[2]) / 2
            dy = (box_a[1] + box_a[3]) / 2 - (box_b[1] + box_b[3]) / 2
            distance = np.hypot(dx, dy)

            avg_width = ((box_a[2] - box_a[0]) + (box_b[2] - box_b[0])) / 2

            if distance < avg_width * distance_factor:
                # Coordinates are cast to int: raw person boxes may be floats,
                # and BoundingBox declares int fields (pydantic v2 rejects
                # floats with a fractional part).
                return FightDetection(
                    type="fight",
                    confidence=confidence,
                    bbox=BoundingBox(
                        x1=int(min(box_a[0], box_b[0])),
                        y1=int(min(box_a[1], box_b[1])),
                        x2=int(max(box_a[2], box_b[2])),
                        y2=int(max(box_a[3], box_b[3])),
                    ),
                    persons_involved=2,
                    threat_level="high",
                )
    return None
def _bbox_from_coords(coords) -> BoundingBox:
    """Build a BoundingBox from an ``(x1, y1, x2, y2)`` sequence, truncating each value to int."""
    # Raw detector boxes may be floats; BoundingBox declares int fields, and
    # pydantic v2 rejects floats with a fractional part.
    return BoundingBox(
        x1=int(coords[0]),
        y1=int(coords[1]),
        x2=int(coords[2]),
        y2=int(coords[3]),
    )


def process_detections(raw_detections: List[Dict]) -> Dict[str, List]:
    """Sort raw moderator detections into typed response models.

    Args:
        raw_detections: Dicts with at least ``'type'``. 'weapon' and 'nsfw'
            entries also need ``'class'``, ``'confidence'`` and ``'bbox'``;
            'fight' entries need ``'confidence'`` and ``'bbox'``. Optional
            keys fall back to the defaults below.

    Returns:
        ``{'weapons': [WeaponDetection...], 'nsfw': [NSFWDetection...],
        'fights': [FightDetection...]}``. Entries of any other type are
        ignored.
    """
    processed = {
        'weapons': [],
        'nsfw': [],
        'fights': []
    }
    for det in raw_detections:
        det_type = det['type']
        if det_type == 'weapon':
            processed['weapons'].append(WeaponDetection(
                type=det_type,
                class_name=det['class'],
                weapon_type=det.get('weapon_type', 'unknown'),
                confidence=det['confidence'],
                bbox=_bbox_from_coords(det['bbox']),
                threat_level=det.get('threat_level', 'medium'),
                detection_method=det.get('detection_method', 'yolo')
            ))
        elif det_type == 'nsfw':
            processed['nsfw'].append(NSFWDetection(
                type=det_type,
                class_name=det['class'],
                confidence=det['confidence'],
                bbox=_bbox_from_coords(det['bbox']),
                method=det.get('method', 'classification'),
                skin_ratio=det.get('skin_ratio')
            ))
        elif det_type == 'fight':
            # Convert to the model like the other categories. A raw dict here
            # broke callers that call .dict() on every entry and failed
            # response-model validation.
            # NOTE(review): persons_involved / threat_level defaults mirror
            # detect_fight_in_frame; confirm what the moderator actually emits.
            processed['fights'].append(FightDetection(
                type=det_type,
                confidence=det['confidence'],
                bbox=_bbox_from_coords(det['bbox']),
                persons_involved=det.get('persons_involved', 2),
                threat_level=det.get('threat_level', 'high')
            ))
    return processed
# ============== API Endpoints ==============
@app.on_event("startup")
async def startup_event():
    """Create the global ContentModerator when the application starts.

    Initialization errors are logged and not re-raised, so ``moderator`` stays
    None and the status endpoints keep reporting "initializing". A failure
    while only *logging* the model status is reported separately and does not
    count as an initialization failure.
    """
    global moderator
    # Custom configuration for API
    custom_config = {
        'weapon_detection': {
            'enabled': True,
            'confidence_threshold': 0.5,
            'knife_confidence': 0.25,
            'model_size': 'yolo11n',
            'classes': ['knife', 'dao', 'gun', 'rifle', 'pistol', 'weapon', 'fight'],
            'use_enhancement': True,
            'multi_pass': True,
            'boost_knife_detection': True
        },
        'nsfw_detection': {
            'enabled': True,
            'confidence_threshold': 0.7,
            'skin_detection': True,
            'pose_analysis': False,  # Disabled for performance
            'region_analysis': True
        },
        'performance': {
            'image_size': 640,
            'batch_size': 1,
            'half_precision': True,
            'use_flash_attention': False,
            'cpu_optimization': False
        },
        'output': {
            'save_detections': True,
            'draw_boxes': True,
            'log_results': True
        }
    }

    try:
        logger.info("Initializing Content Moderator...")
        moderator = ContentModerator(config=custom_config)
    except Exception as e:
        logger.exception(f"Failed to initialize Content Moderator: {e}")
        return

    logger.info("✅ Content Moderator initialized successfully")

    # Log model status. Serialization problems (e.g. non-JSON values) must
    # not be reported as an initialization failure; default=str is used
    # deliberately because this output is only for the log.
    try:
        status = moderator.get_model_status()
        logger.info(f"Model Status: {json.dumps(status, indent=2, default=str)}")
    except Exception:
        logger.warning("Could not log model status", exc_info=True)
@app.on_event("shutdown")
async def shutdown_event():
    """Stop the shared thread pool (blocking until queued work finishes), then log."""
    pool = executor
    pool.shutdown(wait=True)
    logger.info("API shutdown complete")
@app.get("/", response_model=Dict[str, Any])
async def root():
    """API root endpoint with status information"""
    # Fields common to both states come first so the key order matches the
    # original responses: service, version, status, then state-specific keys.
    info: Dict[str, Any] = {
        "service": "Weapon & NSFW Detection API",
        "version": "2.0.0",
    }

    # Models not loaded yet (or startup failed): report initializing.
    if not moderator:
        info["status"] = "initializing"
        info["message"] = "Models are being loaded..."
        return info

    info["status"] = "operational"
    info["models"] = moderator.get_model_status()
    info["endpoints"] = {
        "image_detection": "/detect_n_k_f_g/images",
        "video_detection": "/detect_n_k_f_g/videos",
        "documentation": "/docs"
    }
    return info
@app.post("/detect_n_k_f_g/images", response_model=ImageDetectionResponse)
async def detect_image(
    file: UploadFile = File(..., description="Image file to analyze"),
    enable_fight_detection: bool = Form(True, description="Enable fight detection"),
    return_annotated: bool = Form(True, description="Return annotated image")
):
    """
    Detect weapons (knife/dao/gun), fights, and NSFW content in images
    Supports: JPG, JPEG, PNG, BMP, GIF, WEBP
    Max size: 50MB
    """
    # Flow: validate -> store upload -> decode -> moderator + fight heuristic
    # -> optional annotated copy -> summary and risk level.
    # Status codes: 400 bad input, 503 models unavailable, 500 anything else.
    request_id = generate_request_id()
    start_time = datetime.now()
    try:
        # Validate file extension; UploadFile.filename may be None or empty.
        if not file.filename or not validate_file_extension(file.filename, config.ALLOWED_IMAGE_EXTENSIONS):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid file type. Allowed: {', '.join(config.ALLOWED_IMAGE_EXTENSIONS)}"
            )

        # The moderator stays None until startup succeeds. Report that as 503
        # instead of failing later with an AttributeError (500).
        if moderator is None:
            raise HTTPException(status_code=503, detail="Detection models are not loaded")

        # Check file size
        file_content = await file.read()
        file_size = len(file_content)
        if not validate_file_size(file_size, config.MAX_IMAGE_SIZE):
            raise HTTPException(
                status_code=400,
                detail=f"File too large. Maximum size: {config.MAX_IMAGE_SIZE / (1024 * 1024):.1f}MB"
            )

        # Save uploaded file. Only the final component of the client-supplied
        # name is used, so segments like "../" cannot escape UPLOAD_DIR.
        safe_name = Path(file.filename).name
        upload_path = config.UPLOAD_DIR / "images" / f"{request_id}_{safe_name}"
        async with aiofiles.open(upload_path, 'wb') as f:
            await f.write(file_content)

        # Read image with OpenCV
        nparr = np.frombuffer(file_content, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if image is None:
            raise HTTPException(status_code=400, detail="Invalid or corrupted image file")

        # Get image info
        height, width, channels = image.shape
        image_info = {
            "filename": file.filename,
            "width": width,
            "height": height,
            "channels": channels,
            "size_bytes": file_size,
            "size_mb": round(file_size / (1024 * 1024), 2)
        }

        # Process image with ContentModerator
        logger.info(f"Processing image {request_id}")
        result = moderator.process_image(image)
        if not result:
            raise HTTPException(status_code=500, detail="Detection processing failed")

        # Persons are always detected: the count is reported in the summary
        # even when fight detection is disabled.
        persons = moderator.detect_persons(image)

        # Check for fights if enabled
        fight_detection = None
        if enable_fight_detection and len(persons) >= 2:
            fight_detection = detect_fight_in_frame(image, persons)

        # Process detections
        processed = process_detections(result['detections'])
        if fight_detection:
            processed['fights'].append(fight_detection)

        # Save annotated image if requested
        annotated_url = None
        if return_annotated and config.ENABLE_ANNOTATED_OUTPUT:
            if 'annotated_image' in result:
                annotated_image = result['annotated_image']
            else:
                # Draw annotations manually if not provided
                annotated_image = moderator.draw_detections(image.copy(), result['detections'])
            annotated_name = f"{request_id}_annotated.jpg"
            annotated_path = config.PROCESSED_DIR / "images" / annotated_name
            # cv2.imwrite signals failure through its return value rather than
            # raising; only return a URL for a file that was actually written.
            if cv2.imwrite(str(annotated_path), annotated_image):
                annotated_url = f"/results/images/{annotated_name}"
            else:
                logger.warning(f"Failed to write annotated image for {request_id}")

        # Calculate summary
        total_weapons = len(processed['weapons'])
        total_nsfw = len(processed['nsfw'])
        total_fights = len(processed['fights'])

        weapon_names = [w.class_name.lower() for w in processed['weapons']]
        knife_count = sum(1 for name in weapon_names if 'knife' in name or 'dao' in name)
        gun_count = sum(
            1 for name in weapon_names if any(key in name for key in ('gun', 'pistol', 'rifle'))
        )

        summary = {
            "total_detections": total_weapons + total_nsfw + total_fights,
            "weapons": {
                "total": total_weapons,
                "knives": knife_count,
                "guns": gun_count
            },
            "nsfw": total_nsfw,
            "fights": total_fights,
            "persons_detected": len(persons)
        }

        # Determine overall risk level
        if total_weapons > 0 or total_fights > 0:
            risk_level = "critical" if gun_count > 0 else "high"
        elif total_nsfw > 0:
            risk_level = "medium"
        else:
            risk_level = "safe"

        # Calculate processing time
        processing_time = (datetime.now() - start_time).total_seconds() * 1000

        return ImageDetectionResponse(
            success=True,
            request_id=request_id,
            timestamp=datetime.now().isoformat(),
            image_info=image_info,
            detections=processed,
            summary=summary,
            risk_level=risk_level,
            action_required=(summary["total_detections"] > 0),
            annotated_image_url=annotated_url,
            processing_time_ms=processing_time
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error processing image {request_id}: {e}")
        logger.error(traceback.format_exc())
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error: {str(e)}"
        )
def _detection_to_dict(detection: Any) -> Dict[str, Any]:
    """Serialize a detection model, or an already-plain mapping, for the JSON response."""
    if isinstance(detection, BaseModel):
        return detection.dict()
    return dict(detection)


@app.post("/detect_n_k_f_g/videos", response_model=VideoDetectionResponse)
async def detect_video(
    file: UploadFile = File(..., description="Video file to analyze"),
    frame_skip: int = Form(5, ge=1, le=30, description="Process every Nth frame"),
    max_frames: int = Form(1000, ge=10, le=5000, description="Maximum frames to process"),
    enable_fight_detection: bool = Form(True, description="Enable fight detection"),
    save_processed: bool = Form(False, description="Save processed video with annotations")
):
    """
    Detect weapons (knife/dao/gun), fights, and NSFW content in videos
    Supports: MP4, AVI, MOV, MKV, WEBM, FLV, WMV
    Max size: 500MB
    """
    # Status codes: 400 bad input, 503 models unavailable, 500 anything else.
    # The uploaded file is deleted afterwards unless save_processed is set.
    request_id = generate_request_id()
    start_time = datetime.now()

    # Bound before ``try`` so ``finally`` can always inspect them. Previously a
    # validation failure before the upload was saved raised UnboundLocalError
    # in ``finally`` and turned the 400 into a 500.
    upload_path: Optional[Path] = None
    cap = None
    out_writer = None

    try:
        # Validate file extension; UploadFile.filename may be None or empty.
        if not file.filename or not validate_file_extension(file.filename, config.ALLOWED_VIDEO_EXTENSIONS):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid file type. Allowed: {', '.join(config.ALLOWED_VIDEO_EXTENSIONS)}"
            )

        # The moderator stays None until startup succeeds. Report that as 503
        # instead of failing later with an AttributeError (500).
        if moderator is None:
            raise HTTPException(status_code=503, detail="Detection models are not loaded")

        # Save uploaded video. Only the final component of the client-supplied
        # name is used, so segments like "../" cannot escape UPLOAD_DIR.
        safe_name = Path(file.filename).name
        upload_path = config.UPLOAD_DIR / "videos" / f"{request_id}_{safe_name}"
        await save_upload_file(file, upload_path)

        # Get file size
        file_size = upload_path.stat().st_size
        if not validate_file_size(file_size, config.MAX_VIDEO_SIZE):
            upload_path.unlink()  # Delete the file
            raise HTTPException(
                status_code=400,
                detail=f"File too large. Maximum size: {config.MAX_VIDEO_SIZE / (1024 * 1024):.1f}MB"
            )

        # Open video
        cap = cv2.VideoCapture(str(upload_path))
        if not cap.isOpened():
            raise HTTPException(status_code=400, detail="Invalid or corrupted video file")

        # Get video info
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        duration = total_frames / fps if fps > 0 else 0

        video_info = {
            "filename": file.filename,
            "width": width,
            "height": height,
            "fps": fps,
            "total_frames": total_frames,
            "duration_seconds": round(duration, 2),
            "size_bytes": file_size,
            "size_mb": round(file_size / (1024 * 1024), 2)
        }

        # Prepare output video if requested
        processed_video_path = None
        if save_processed:
            processed_video_path = config.PROCESSED_DIR / "videos" / f"{request_id}_processed.mp4"
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            # Only every frame_skip-th frame is written, so the frame rate is
            # divided to keep the output as long as the source. If the source
            # reports no FPS, fall back to one frame per second.
            writer_fps = fps / frame_skip if fps > 0 else 1.0
            out_writer = cv2.VideoWriter(
                str(processed_video_path),
                fourcc,
                writer_fps,
                (width, height)
            )

        # Process video frames
        logger.info(f"Processing video {request_id}: {total_frames} frames, skip={frame_skip}")

        frame_detections = []
        frame_count = 0
        processed_count = 0

        # Aggregated statistics
        all_weapons = []
        all_nsfw = []
        all_fights = []

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1

            # Skip frames according to frame_skip parameter
            if frame_count % frame_skip != 0:
                continue

            # Limit maximum frames processed
            if processed_count >= max_frames:
                logger.info(f"Reached max frames limit: {max_frames}")
                break

            processed_count += 1

            # Process frame
            result = moderator.process_image(frame)

            if result:
                processed = process_detections(result['detections'] or [])

                # Fight detection runs on every analysed frame, as in the
                # image endpoint. Previously it was skipped whenever the frame
                # had no weapon/NSFW detections.
                if enable_fight_detection:
                    persons = moderator.detect_persons(frame)
                    if len(persons) >= 2:
                        fight_detection = detect_fight_in_frame(frame, persons)
                        if fight_detection:
                            processed['fights'].append(fight_detection)

                # Store frame detection info
                if processed['weapons'] or processed['nsfw'] or processed['fights']:
                    frame_detections.append({
                        "frame_number": frame_count,
                        "timestamp_seconds": frame_count / fps if fps > 0 else 0,
                        "detections": {
                            "weapons": [_detection_to_dict(w) for w in processed['weapons']],
                            "nsfw": [_detection_to_dict(n) for n in processed['nsfw']],
                            "fights": [_detection_to_dict(f) for f in processed['fights']]
                        }
                    })

                    # Aggregate statistics
                    all_weapons.extend(processed['weapons'])
                    all_nsfw.extend(processed['nsfw'])
                    all_fights.extend(processed['fights'])

            # Write the annotated frame when available, otherwise the raw one.
            # ``result`` may be None/empty, and the ``in`` test must not run on it.
            if out_writer is not None:
                if result and 'annotated_image' in result:
                    out_writer.write(result['annotated_image'])
                else:
                    out_writer.write(frame)

            # Log progress every 100 frames
            if processed_count % 100 == 0:
                logger.info(f"Processed {processed_count} frames...")

        # Release resources now so the output file is finalized before the
        # existence check below. ``finally`` only handles the error path.
        cap.release()
        cap = None
        if out_writer is not None:
            out_writer.release()
            out_writer = None

        # Calculate summary (gun keywords match the image endpoint)
        weapon_names = [w.class_name.lower() for w in all_weapons]
        knife_count = sum(1 for name in weapon_names if 'knife' in name or 'dao' in name)
        gun_count = sum(
            1 for name in weapon_names if any(key in name for key in ('gun', 'pistol', 'rifle'))
        )

        summary = {
            "total_frames_analyzed": processed_count,
            "frames_with_detections": len(frame_detections),
            "total_detections": len(all_weapons) + len(all_nsfw) + len(all_fights),
            "weapons": {
                "total": len(all_weapons),
                "knives": knife_count,
                "guns": gun_count,
                "unique_frames": len({f["frame_number"] for f in frame_detections if f["detections"]["weapons"]})
            },
            "nsfw": {
                "total": len(all_nsfw),
                "unique_frames": len({f["frame_number"] for f in frame_detections if f["detections"]["nsfw"]})
            },
            "fights": {
                "total": len(all_fights),
                "unique_frames": len({f["frame_number"] for f in frame_detections if f["detections"]["fights"]})
            }
        }

        # Determine overall risk level
        if gun_count > 0 or len(all_fights) > 5:
            risk_level = "critical"
        elif knife_count > 0 or len(all_fights) > 0:
            risk_level = "high"
        elif len(all_nsfw) > 0:
            risk_level = "medium"
        else:
            risk_level = "safe"

        # Calculate processing time
        processing_time = (datetime.now() - start_time).total_seconds() * 1000

        # Prepare processed video URL if saved
        processed_video_url = None
        if save_processed and processed_video_path and processed_video_path.exists():
            processed_video_url = f"/results/videos/{request_id}_processed.mp4"

        return VideoDetectionResponse(
            success=True,
            request_id=request_id,
            timestamp=datetime.now().isoformat(),
            video_info=video_info,
            total_frames_processed=processed_count,
            frame_detections=frame_detections,
            summary=summary,
            risk_level=risk_level,
            action_required=(summary["total_detections"] > 0),
            processed_video_url=processed_video_url,
            processing_time_ms=processing_time
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error processing video {request_id}: {e}")
        logger.error(traceback.format_exc())
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error: {str(e)}"
        )
    finally:
        # Release capture/writer if an error interrupted processing.
        if cap is not None:
            cap.release()
        if out_writer is not None:
            out_writer.release()
        # Cleanup uploaded file if needed
        if upload_path is not None and not save_processed and upload_path.exists():
            try:
                upload_path.unlink()
            except OSError as cleanup_error:
                logger.warning(f"Could not delete upload {upload_path}: {cleanup_error}")
def _resolve_processed_file(media_kind: str, filename: str) -> Path:
    """Return ``PROCESSED_DIR/<media_kind>/<filename>`` if it is a servable file.

    Raises:
        HTTPException(404): if *filename* is not a bare file name (empty,
            ``.``/``..``, or containing directory parts), or if it does not
            name an existing regular file. Directories used to pass the old
            ``exists()`` check.
    """
    # Only bare names are accepted, so the lookup cannot leave the directory.
    if not filename or filename in (".", "..") or Path(filename).name != filename:
        raise HTTPException(status_code=404, detail="File not found")
    file_path = config.PROCESSED_DIR / media_kind / filename
    if not file_path.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    return file_path


@app.get("/results/images/{filename}")
async def get_processed_image(filename: str):
    """Get processed/annotated image"""
    return FileResponse(_resolve_processed_file("images", filename))


@app.get("/results/videos/{filename}")
async def get_processed_video(filename: str):
    """Get processed/annotated video"""
    return FileResponse(_resolve_processed_file("videos", filename))
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    # Guard clause: models not loaded yet (or startup failed).
    if not moderator:
        return {
            "status": "initializing",
            "models_loaded": False
        }

    model_details = moderator.get_model_status()
    return {
        "status": "healthy",
        "models_loaded": True,
        "model_details": model_details
    }
@app.delete("/cleanup")
async def cleanup_old_files(hours: int = config.CLEANUP_AFTER_HOURS):
    """Clean up old files from upload and results directories"""
    # Deletes regular files in <root>/images and <root>/videos (for the
    # upload, results and processed roots) whose mtime is older than
    # ``hours`` ago. The default comes from config.CLEANUP_AFTER_HOURS (24).
    # A negative age would put the cutoff in the future and wipe every stored
    # file, including results that are still being served.
    if hours < 0:
        raise HTTPException(status_code=400, detail="hours must be >= 0")

    try:
        from datetime import timedelta
        cutoff_time = datetime.now() - timedelta(hours=hours)
        deleted_count = 0

        for directory in [config.UPLOAD_DIR, config.RESULTS_DIR, config.PROCESSED_DIR]:
            for subdir in ["images", "videos"]:
                path = directory / subdir
                if not path.exists():
                    continue
                for file in path.iterdir():
                    try:
                        if not file.is_file():
                            continue
                        if datetime.fromtimestamp(file.stat().st_mtime) < cutoff_time:
                            file.unlink()
                            deleted_count += 1
                    except FileNotFoundError:
                        # Removed concurrently (e.g. by an endpoint's own
                        # cleanup); nothing left to delete.
                        continue

        return {
            "success": True,
            "deleted_files": deleted_count,
            "message": f"Deleted {deleted_count} files older than {hours} hours"
        }

    except Exception as e:
        logger.error(f"Cleanup error: {e}")
        return {
            "success": False,
            "error": str(e)
        }