|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
import supervision as sv
|
|
|
import torch
|
|
|
import requests
|
|
|
from PIL import Image
|
|
|
import os
|
|
|
import cv2
|
|
|
from tqdm import tqdm
|
|
|
import time
|
|
|
|
|
|
from rfdetr import RFDETRNano
|
|
|
|
|
|
# Display names for the class ids found in ``detections.class_id``; ids are
# numbered from 1 in the order listed below.
THREAT_CLASSES = dict(enumerate(("Gun", "Explosive", "Grenade", "Knife"), start=1))
|
|
|
|
|
|
|
|
|
# Report the compute device and, on GPU, trade reproducibility for speed.
if not torch.cuda.is_available():
    print("CUDA not available, using CPU")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # Let cuDNN benchmark and pick its fastest kernels, and allow
    # non-deterministic algorithms.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
|
|
|
|
|
|
|
|
|
# Source clip; the annotated copy is written alongside it with a "_detr"
# suffix inserted before the extension (test_video.mp4 -> test_video_detr.mp4).
INPUT_VIDEO = "test_video.mp4"
base, ext = os.path.splitext(INPUT_VIDEO)
OUTPUT_VIDEO = "{}_detr{}".format(base, ext)

# Minimum confidence passed to model.predict for a detection to be kept.
THRESHOLD = 0.5
# Number of frames accumulated before each processing pass.
BATCH_SIZE = 32
|
|
|
|
|
|
|
|
|
# Report total GPU memory next to the fixed batch size so an out-of-memory
# risk is visible up front; gpu_memory_gb stays defined on GPU machines.
if torch.cuda.is_available():
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU memory: {gpu_memory_gb:.1f} GB")

print(f"Using batch size: {BATCH_SIZE}")
|
|
|
|
|
|
|
|
|
weights_url = "https://huggingface.co/Subh775/Threat-Detection-RF-DETR/resolve/main/checkpoint_best_total.pth"
weights_filename = "checkpoint_best_total.pth"

# Fetch the checkpoint only if it is not already present locally.
if not os.path.exists(weights_filename):
    print(f"Downloading weights from {weights_url}")
    # Stream into a temporary file and rename it only after the whole body has
    # been received. An interrupted download therefore never leaves a truncated
    # checkpoint that the os.path.exists() check above would accept next run.
    partial_filename = weights_filename + ".part"
    try:
        # timeout bounds the connect and each read so a stalled server cannot
        # hang the script; the with-block closes the streamed connection.
        with requests.get(weights_url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(partial_filename, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(partial_filename, weights_filename)
    except BaseException:
        # Remove the incomplete file, then propagate the original error.
        if os.path.exists(partial_filename):
            os.remove(partial_filename)
        raise
    print("Download complete.")
|
|
|
|
|
|
print("Loading model...")
# RF-DETR Nano at 640px input resolution, initialised from the checkpoint
# fetched above.
model = RFDETRNano(
    pretrain_weights=weights_filename,
    resolution=640,
)
# Called with library defaults. NOTE(review): presumably this fixes the
# inference batch size the model accepts; keep that in mind before passing
# lists of images to model.predict.
model.optimize_for_inference()
|
|
|
|
|
|
|
|
|
# Box/label colours. NOTE(review): supervision picks a palette entry from each
# detection's class_id (by default), and ids here start at 1. Confirm the
# colour-to-class pairing is the intended one.
_PALETTE_HEX = ["#1E90FF", "#32CD32", "#FF0000", "#FF8C00"]
color = sv.ColorPalette.from_hex(_PALETTE_HEX)

# Thick boxes plus black label text drawn on the same palette colours.
bbox_annotator = sv.BoxAnnotator(thickness=3, color=color)
label_annotator = sv.LabelAnnotator(
    smart_position=True,
    text_thickness=2,
    text_scale=1.0,
    text_color=sv.Color.BLACK,
    color=color,
)
|
|
|
|
|
|
def _format_labels(detections):
    """Build one "<class name> <confidence>" label per detection, in order."""
    names = (
        THREAT_CLASSES.get(class_id, f"unknown_class_{class_id}")
        for class_id in detections.class_id
    )
    return [
        f"{name} {confidence:.2f}"
        for name, confidence in zip(names, detections.confidence)
    ]


def process_frame_batch(frames):
    """Detect threats in a group of video frames and draw the results.

    Inference is run one frame at a time through ``model.predict``. Frames are
    grouped only so the caller can handle timing and bookkeeping per group.

    Args:
        frames: iterable of frames as returned by ``cv2.VideoCapture.read``
            (OpenCV BGR channel order).

    Returns:
        ``(annotated_frames, batch_detections)``: two lists aligned with
        *frames*. The first holds annotated BGR copies ready for
        ``cv2.VideoWriter.write``; the second holds the detections per frame.
        The input frames are not modified.
    """
    annotated_frames = []
    batch_detections = []

    for frame in frames:
        # The model is given RGB PIL images, while OpenCV decodes BGR.
        pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        detections = model.predict(pil_image, threshold=THRESHOLD)

        # Draw straight onto a BGR copy of the frame. supervision annotates
        # numpy images in BGR order, so the result can be written as-is. This
        # avoids a PIL copy plus an RGB->BGR conversion for every frame.
        annotated = bbox_annotator.annotate(frame.copy(), detections)
        annotated = label_annotator.annotate(
            annotated, detections, _format_labels(detections)
        )

        batch_detections.append(detections)
        annotated_frames.append(annotated)

    return annotated_frames, batch_detections
|
|
|
|
|
|
|
|
|
cap = cv2.VideoCapture(INPUT_VIDEO)
if not cap.isOpened():
    print(f"Error: Could not open video file {INPUT_VIDEO}")
    # Exit with a non-zero status so shells and callers see the failure.
    # SystemExit also works when the site module (which provides the
    # interactive exit() helper) is not loaded.
    raise SystemExit(1)
|
|
|
|
|
|
|
|
|
# Keep the frame rate as a float. Truncating e.g. 29.97 to 29 would make the
# written video play back slower than the source.
fps = cap.get(cv2.CAP_PROP_FPS)
if not fps > 0:  # also catches NaN
    # Some containers/streams report no frame rate; VideoWriter needs a
    # positive one, so fall back to a common default.
    print("Warning: source FPS unavailable, defaulting to 30")
    fps = 30.0
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

print(f"Video: {width}x{height}, {fps:.2f} FPS, {total_frames} frames")
print(f"Processing in batches of {BATCH_SIZE} frames")

# Output has the same size and frame rate as the input, MPEG-4 ('mp4v') codec.
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(OUTPUT_VIDEO, fourcc, fps, (width, height))
if not out.isOpened():
    # Without this check, every out.write() below would be silently dropped.
    cap.release()
    print(f"Error: Could not open output video file {OUTPUT_VIDEO}")
    raise SystemExit(1)
|
|
|
|
|
|
|
|
|
def _process_and_write(frames):
    """Annotate *frames*, append them to the output video, and report stats.

    Returns ``(inference_seconds, detection_count)``. The timing covers only
    ``process_frame_batch`` (inference + annotation), not the disk writes.
    """
    start = time.perf_counter()  # monotonic; unaffected by wall-clock changes
    annotated_frames, batch_detections = process_frame_batch(frames)
    elapsed = time.perf_counter() - start
    for annotated_frame in annotated_frames:
        out.write(annotated_frame)
    return elapsed, sum(len(detections) for detections in batch_detections)


print("Processing video with batch inference...")
frame_buffer = []
total_detections = 0
processed_frames = 0
processing_times = []

try:
    with tqdm(total=total_frames, desc="Batch processing") as pbar:
        while True:
            ret, frame = cap.read()
            if ret:
                frame_buffer.append(frame)

            # Flush whenever the buffer is full, and once more at end of stream
            # for any leftover partial batch. Both paths share the same handling.
            if frame_buffer and (not ret or len(frame_buffer) >= BATCH_SIZE):
                processing_time, batch_threats = _process_and_write(frame_buffer)
                processing_times.append(processing_time)
                total_detections += batch_threats
                processed_frames += len(frame_buffer)

                batch_fps = len(frame_buffer) / processing_time if processing_time > 0 else 0
                pbar.set_postfix({
                    'Batch FPS': f"{batch_fps:.1f}",
                    'Threats': batch_threats,
                    'Total': total_detections,
                })
                pbar.update(len(frame_buffer))
                frame_buffer = []

                # Every 10 batches, release PyTorch's cached GPU memory.
                if torch.cuda.is_available() and processed_frames % (BATCH_SIZE * 10) == 0:
                    torch.cuda.empty_cache()

            if not ret:
                break
except BaseException:
    # On failure, close the capture and finalize the writer so the partially
    # written output file is closed properly, then propagate the error.
    cap.release()
    out.release()
    raise
|
|
|
|
|
|
|
|
|
# Close the input and finalize the output file.
for handle in (cap, out):
    handle.release()

if torch.cuda.is_available():
    torch.cuda.empty_cache()

# total_time is the summed per-batch processing time (the timed spans in the
# loop above), so decode/encode overhead is not included.
total_time = sum(processing_times)
avg_fps = (processed_frames / total_time) if total_time > 0 else 0
# Ratio of processing throughput to the source frame rate (>1 = faster than real time).
speedup = (avg_fps / fps) if fps > 0 else 0

summary_lines = [
    f"Output: {OUTPUT_VIDEO}",
    "Stats:",
    f" • Processed: {processed_frames} frames",
    f" • Detections: {total_detections}",
    f" • Batch size: {BATCH_SIZE}",
    f" • Average speed: {avg_fps:.1f} FPS",
    f" • Speedup: {speedup:.1f}x real-time",
    f" • Processing time: {total_time:.1f}s",
]
for line in summary_lines:
    print(line)