# feat: upgrade to advanced counting system with RT-DETR and proper line crossing (ba0c288)
"""CCTV Customer Analytics - Advanced Object Counting System
This Space provides accurate object detection, tracking, and counting
across a user-defined line. Optimized for counting large numbers of
animals (sheep, cows) and vehicles in crowded scenes.
Key Features:
- RT-DETR and YOLOv8 model support
- Optimized ByteTrack for dense scenes
- Proper geometric line crossing detection
- Multi-class object support
"""
import os
import subprocess
import tempfile
from collections import defaultdict
from typing import Dict, List, Tuple, Optional

import cv2
import gradio as gr
import numpy as np
import spaces
import supervision as sv
from ultralytics import YOLO, RTDETR
# Detection modes with COCO class IDs.
# Each mode maps a UI label to the COCO class IDs to keep ("class_ids")
# and a class_id -> human-readable name table ("labels") used for overlays.
DETECTION_MODES = {
    "All Objects (Street)": {
        "class_ids": [0, 1, 2, 3, 5, 7, 17, 18, 19],
        "labels": {0: "person", 1: "bicycle", 2: "car", 3: "motorcycle",
                   5: "bus", 7: "truck", 17: "horse", 18: "sheep", 19: "cow"},
    },
    "People Only": {
        "class_ids": [0],
        "labels": {0: "person"},
    },
    "Vehicles Only": {
        "class_ids": [1, 2, 3, 5, 7],
        "labels": {1: "bicycle", 2: "car", 3: "motorcycle", 5: "bus", 7: "truck"},
    },
    "Animals (Sheep/Cow/Horse)": {
        "class_ids": [17, 18, 19],
        "labels": {17: "horse", 18: "sheep", 19: "cow"},
    },
    "Sheep Only": {
        "class_ids": [18],
        "labels": {18: "sheep"},
    },
}

# Process-wide cache: UI model name -> loaded YOLO/RTDETR instance,
# so repeated runs do not reload weights (see get_model).
MODEL_CACHE: Dict[str, object] = {}
def get_model(model_name: str):
    """Return the detector for *model_name*, loading and caching it on first use.

    Unknown names fall back to YOLOv8s. Loaded models are kept in the
    module-level MODEL_CACHE so weights are only read once per process.
    """
    cached = MODEL_CACHE.get(model_name)
    if cached is not None:
        return cached

    # UI label -> (weights file, model family).
    model_map = {
        "YOLOv8n (Fast)": ("yolov8n.pt", "yolo"),
        "YOLOv8s (Balanced)": ("yolov8s.pt", "yolo"),
        "YOLOv8m (Accurate)": ("yolov8m.pt", "yolo"),
        "YOLOv8x (Best YOLO)": ("yolov8x.pt", "yolo"),
        "RT-DETR-L (Dense Scenes)": ("rtdetr-l.pt", "rtdetr"),
    }
    weights_file, family = model_map.get(model_name, ("yolov8s.pt", "yolo"))
    loader = RTDETR if family == "rtdetr" else YOLO
    model = loader(weights_file)
    MODEL_CACHE[model_name] = model
    return model
def point_side(point: Tuple[float, float], line: Tuple[Tuple[float, float], Tuple[float, float]]) -> float:
    """Return the sign of *point* relative to *line* via the 2D cross product.

    Positive for one half-plane, negative for the other, and exactly 0.0
    when the point lies on the infinite line through the two endpoints.
    """
    (x1, y1), (x2, y2) = line
    x, y = point
    # Cross product of (P - A) with (B - A): its sign encodes the side.
    return (x - x1) * (y2 - y1) - (y - y1) * (x2 - x1)


def crossed_line(prev_point: Tuple[float, float], curr_point: Tuple[float, float],
                 line: Tuple[Tuple[float, float], Tuple[float, float]]) -> bool:
    """Check if movement from prev_point to curr_point crosses the line.

    A point lying exactly on the line (side == 0) is grouped with the
    non-negative half-plane — the same ``>= 0`` convention the direction
    logic in process_video uses. The previous ``prev * curr < 0`` test
    evaluated to False whenever either side was 0, so a track whose
    center landed exactly on the line in one frame was never counted.
    """
    prev_side = point_side(prev_point, line)
    curr_side = point_side(curr_point, line)
    # Crossing occurred iff the two points fall in different half-planes.
    return (prev_side >= 0) != (curr_side >= 0)
def bbox_center(bbox: Tuple[int, int, int, int]) -> Tuple[float, float]:
    """Return the (x, y) midpoint of an (x1, y1, x2, y2) bounding box."""
    left, top, right, bottom = bbox
    # Multiplying by 0.5 is exact for binary floats, same as dividing by 2.
    return ((left + right) * 0.5, (top + bottom) * 0.5)
def determine_outside_side(line: Tuple[Tuple[float, float], Tuple[float, float]],
                           frame_height: int) -> float:
    """Determine which side of the line is 'outside' based on line position.

    Picks a reference point on the nearest horizontal frame edge (top edge
    when the line's midpoint sits in the upper half of the frame, bottom
    edge otherwise) and returns its point_side() sign, which callers use to
    classify crossing direction.
    """
    (sx, sy), (ex, ey) = line
    mid_x = (sx + ex) / 2.0
    mid_y = (sy + ey) / 2.0
    # Line in upper half -> outside is above (y = 0); else below (y = height).
    edge_y = 0.0 if mid_y < frame_height / 2.0 else float(frame_height)
    return point_side((mid_x, edge_y), line)
@spaces.GPU(duration=180)
def process_video(
    video_path: str,
    detection_model: str,
    detection_mode: str,
    confidence: float,
    line_position: float,
    track_buffer: int,
    activation_threshold: float,
):
    """Process video with advanced tracking and counting.

    Args:
        video_path: Path to the uploaded video, or None if nothing uploaded.
        detection_model: UI name of the detector (see get_model).
        detection_mode: Key into DETECTION_MODES selecting target classes.
        confidence: Minimum detection confidence threshold.
        line_position: Counting line's vertical position as a fraction of
            frame height (0 = top, 1 = bottom).
        track_buffer: Frames ByteTrack keeps a lost track alive.
        activation_threshold: ByteTrack track-activation threshold.

    Returns:
        Tuple of (annotated video path, markdown statistics report), or
        (None, error message) when the input is missing or unreadable.
    """
    if video_path is None:
        return None, "Please upload a video file."

    # Get model and detection config.
    model = get_model(detection_model)
    mode_config = DETECTION_MODES.get(detection_mode, DETECTION_MODES["All Objects (Street)"])
    target_class_ids = set(mode_config["class_ids"])
    class_labels = mode_config["labels"]

    # Open video.
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, "Failed to open video file."
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Setup output video. Use mkstemp, not the deprecated and race-prone
    # mktemp; close the fd since cv2.VideoWriter reopens the path itself.
    fd, output_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    # Initialize tracker with optimized parameters for dense scenes.
    tracker = sv.ByteTrack(
        track_activation_threshold=activation_threshold,
        lost_track_buffer=track_buffer,
        minimum_matching_threshold=0.7,
        frame_rate=fps,
    )

    # Setup counting line (absolute coordinates).
    line_y = int(height * line_position)
    line_start = (0, line_y)
    line_end = (width, line_y)
    abs_line = ((0.0, float(line_y)), (float(width), float(line_y)))
    outside_side = determine_outside_side(abs_line, height)

    # Annotators.
    box_annotator = sv.BoxAnnotator(thickness=2)
    label_annotator = sv.LabelAnnotator(text_scale=0.4, text_thickness=1)
    trace_annotator = sv.TraceAnnotator(thickness=1, trace_length=50)

    # Tracking state: last known center and class per track id, plus the
    # set of tracks that have already been counted (count each track once).
    track_last_center: Dict[int, Tuple[float, float]] = {}
    track_class: Dict[int, str] = {}
    counted_tracks: set = set()

    # Counters.
    total_in, total_out = 0, 0
    class_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: {"in": 0, "out": 0})
    frame_idx = 0
    max_simultaneous = 0

    # Release the capture and writer even if detection raises mid-stream.
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Run detection.
            results = model.predict(frame, conf=confidence, verbose=False)[0]

            # Filter detections down to the classes the mode cares about.
            boxes = results.boxes
            if boxes is not None and len(boxes) > 0:
                mask = np.array([int(cls) in target_class_ids for cls in boxes.cls])
                if mask.any():
                    filtered_boxes = boxes[mask]
                    detections = sv.Detections(
                        xyxy=filtered_boxes.xyxy.cpu().numpy(),
                        confidence=filtered_boxes.conf.cpu().numpy(),
                        class_id=filtered_boxes.cls.cpu().numpy().astype(int),
                    )
                else:
                    detections = sv.Detections.empty()
            else:
                detections = sv.Detections.empty()

            # Track objects.
            detections = tracker.update_with_detections(detections)

            # Update max simultaneous count.
            if len(detections) > max_simultaneous:
                max_simultaneous = len(detections)

            # Check line crossings with proper geometry.
            if detections.tracker_id is not None:
                for idx in range(len(detections)):
                    track_id = int(detections.tracker_id[idx])
                    x1, y1, x2, y2 = detections.xyxy[idx]
                    class_id = int(detections.class_id[idx]) if detections.class_id is not None else 0
                    class_name = class_labels.get(class_id, f"class_{class_id}")
                    current_center = bbox_center((int(x1), int(y1), int(x2), int(y2)))
                    track_class[track_id] = class_name
                    if track_id in track_last_center and track_id not in counted_tracks:
                        prev_center = track_last_center[track_id]
                        if crossed_line(prev_center, current_center, abs_line):
                            prev_side = point_side(prev_center, abs_line)
                            curr_side = point_side(current_center, abs_line)
                            # Direction is decided by which half-plane is "outside".
                            if prev_side * outside_side >= 0 and curr_side * outside_side < 0:
                                total_in += 1
                                class_counts[class_name]["in"] += 1
                            elif prev_side * outside_side < 0 and curr_side * outside_side >= 0:
                                total_out += 1
                                class_counts[class_name]["out"] += 1
                            counted_tracks.add(track_id)
                    track_last_center[track_id] = current_center

            # Annotate frame.
            annotated = frame.copy()

            # Draw counting line.
            cv2.line(annotated, line_start, line_end, (0, 0, 255), 3)
            cv2.putText(annotated, "COUNTING LINE", (10, line_y - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

            # Draw traces, boxes, and labels.
            annotated = trace_annotator.annotate(annotated, detections)
            annotated = box_annotator.annotate(annotated, detections)
            labels = []
            if detections.tracker_id is not None:
                for idx in range(len(detections)):
                    class_id = int(detections.class_id[idx]) if detections.class_id is not None else 0
                    class_name = class_labels.get(class_id, f"class_{class_id}")
                    track_id = int(detections.tracker_id[idx])
                    labels.append(f"{class_name} #{track_id}")
            annotated = label_annotator.annotate(annotated, detections, labels)

            # Draw stats overlay.
            overlay_h = 80
            cv2.rectangle(annotated, (5, 5), (300, overlay_h), (0, 0, 0), -1)
            cv2.putText(annotated, f"IN: {total_in} | OUT: {total_out}", (15, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
            cv2.putText(annotated, f"Net: {total_in - total_out} | Now: {len(detections)}", (15, 55),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
            cv2.putText(annotated, f"Frame: {frame_idx}/{total_frames}", (15, 75),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, (200, 200, 200), 1)

            out.write(annotated)
            frame_idx += 1
    finally:
        cap.release()
        out.release()

    # Convert to H.264 for browser compatibility. Run ffmpeg through
    # subprocess with an argument list (not os.system on an f-string) so
    # paths containing spaces or shell metacharacters cannot break or
    # inject into the command. Failure is tolerated: the mp4v file is kept.
    fd, final_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    subprocess.run(
        ["ffmpeg", "-y", "-i", output_path, "-c:v", "libx264",
         "-preset", "fast", "-crf", "23", final_path, "-loglevel", "quiet"],
        check=False,
    )
    if os.path.exists(final_path) and os.path.getsize(final_path) > 0:
        os.remove(output_path)
        output_path = final_path

    # Generate statistics report.
    unique_tracks = len(track_last_center)
    stats = "## Counting Results\n\n"
    stats += f"**Total Entered:** {total_in}\n"
    stats += f"**Total Exited:** {total_out}\n"
    stats += f"**Net Count:** {total_in - total_out}\n"
    stats += f"**Unique Tracks:** {unique_tracks}\n"
    stats += f"**Max Simultaneous:** {max_simultaneous}\n\n"
    if class_counts:
        stats += "### By Class\n"
        for cls, counts in sorted(class_counts.items()):
            net = counts['in'] - counts['out']
            stats += f"- **{cls}**: IN={counts['in']}, OUT={counts['out']}, Net={net}\n"
    stats += "\n### Video Info\n"
    stats += f"- Frames: {frame_idx}\n"
    stats += f"- Resolution: {width}x{height}\n"
    stats += f"- FPS: {fps}\n"
    return output_path, stats
# Build Gradio interface: input controls in the left column, processed
# video and statistics in the right column.
with gr.Blocks(analytics_enabled=False, title="CCTV Customer Analytics") as demo:
    gr.Markdown("""
# CCTV Customer Analytics
Advanced object detection, tracking, and counting system.
Optimized for counting large numbers of animals and vehicles in crowded scenes.
**Tips for best results:**
- Use **RT-DETR** model for dense/crowded scenes (sheep flocks, traffic)
- Lower **confidence** (0.15-0.25) to detect more objects
- Increase **track buffer** (60-90) for objects that temporarily disappear
- Adjust **line position** to where objects cross most clearly
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Input video and model/mode selectors.
            video_input = gr.Video(label="Upload Video")
            model_dropdown = gr.Dropdown(
                choices=[
                    "YOLOv8n (Fast)",
                    "YOLOv8s (Balanced)",
                    "YOLOv8m (Accurate)",
                    "YOLOv8x (Best YOLO)",
                    "RT-DETR-L (Dense Scenes)",
                ],
                value="YOLOv8s (Balanced)",
                label="Detection Model",
            )
            mode_dropdown = gr.Dropdown(
                choices=list(DETECTION_MODES.keys()),
                value="All Objects (Street)",
                label="Detection Mode",
            )
            confidence_slider = gr.Slider(
                0.05, 0.9, value=0.25, step=0.05,
                label="Confidence Threshold",
                info="Lower = more detections, higher = fewer false positives"
            )
            line_slider = gr.Slider(
                0.1, 0.9, value=0.5, step=0.05,
                label="Line Position",
                info="Vertical position of counting line (0=top, 1=bottom)"
            )
            # ByteTrack tuning knobs, hidden by default.
            with gr.Accordion("Advanced Tracking Settings", open=False):
                track_buffer = gr.Slider(
                    10, 120, value=45, step=5,
                    label="Track Buffer",
                    info="Frames to keep lost tracks (higher for crowded scenes)"
                )
                activation_threshold = gr.Slider(
                    0.1, 0.5, value=0.2, step=0.05,
                    label="Track Activation Threshold",
                    info="Lower = easier to start new tracks"
                )
            submit_btn = gr.Button("Process Video", variant="primary", size="lg")
        with gr.Column(scale=1):
            # Outputs: annotated video and markdown statistics report.
            video_output = gr.Video(label="Processed Video")
            stats_output = gr.Markdown(label="Statistics")
    # Wire the button to process_video; argument order must match its signature.
    submit_btn.click(
        fn=process_video,
        inputs=[
            video_input, model_dropdown, mode_dropdown,
            confidence_slider, line_slider, track_buffer, activation_threshold
        ],
        outputs=[video_output, stats_output],
        api_name=False,
    )
    gr.Markdown("""
---
**Models:**
- **YOLOv8n/s/m/x**: General purpose, good for most scenarios
- **RT-DETR-L**: Transformer-based, better for dense/crowded scenes (recommended for sheep counting)
**Detection Modes:**
- **All Objects**: People + vehicles + animals
- **Animals**: Sheep, cows, horses
- **Sheep Only**: Optimized for sheep counting
""")

# Launch the app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()