# Source: pranit / realtime_engine.py
# (uploaded by RushiMane2003 in "Upload 41 files", commit 99f938a)
"""
realtime_engine.py — Real-Time Traffic Control Engine
Converts uploaded videos into infinite "live streams"
Processes frame-by-frame with YOLO + MARL decisions
Ready to plug into Streamlit dashboard
Features:
- Video loop generators (infinite streams)
- Threaded per-lane frame reading; YOLO detection with frame skipping for speed
- Rolling average for smooth decisions
- Stability constraints (min green time)
- Real-time metrics
"""
import cv2
import numpy as np
import threading
import time
from collections import deque
from queue import Queue
import warnings
warnings.filterwarnings('ignore')
# ─────────────────────────────────────────────────────────────────
# Video Stream Generators (Infinite Loop)
# ─────────────────────────────────────────────────────────────────
class VideoStreamGenerator:
    """Loop a video file forever so it behaves like a live camera feed.

    Iterating yields ``(frame, frame_count)`` tuples: ``frame`` is the decoded
    image resized to ``(resize_width, resize_height)`` and ``frame_count`` is a
    1-based counter that keeps increasing across loops of the file.
    """

    def __init__(self, video_path: str, resize_width: int = 640, resize_height: int = 360):
        self.video_path = video_path
        self.resize_width = resize_width
        self.resize_height = resize_height
        self.frame_count = 0
        self.cap = None
        self._init_capture()

    def _init_capture(self):
        """(Re)open ``video_path``, releasing any previous capture first.

        Raises:
            RuntimeError: if OpenCV cannot open the file.
        """
        if self.cap is not None:
            self.cap.release()
        self.cap = cv2.VideoCapture(self.video_path)
        if not self.cap.isOpened():
            raise RuntimeError(f"Cannot open video: {self.video_path}")

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``(frame, frame_count)``, rewinding at end of file.

        Raises:
            StopIteration: if no frame can be decoded even after rewinding and
                reopening the file (e.g. an empty or corrupt video). Previously
                this case spun in an endless busy loop.
        """
        if self.cap is None:
            # release() was called; reopen transparently.
            self._init_capture()
        ok, frame = self.cap.read()
        if not ok:
            # End of file: seek back to the first frame.
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            ok, frame = self.cap.read()
        if not ok:
            # Some backends cannot seek; reopening the file is a reliable rewind.
            self._init_capture()
            ok, frame = self.cap.read()
        if not ok:
            raise StopIteration
        # Resize for faster processing downstream.
        frame = cv2.resize(frame, (self.resize_width, self.resize_height))
        self.frame_count += 1
        return frame, self.frame_count

    def release(self):
        """Release the underlying capture handle; safe to call more than once."""
        if self.cap is not None:
            self.cap.release()
            self.cap = None
# ─────────────────────────────────────────────────────────────────
# Fast YOLO Detection with Frame Skipping
# ─────────────────────────────────────────────────────────────────
class FastYOLODetector:
    """
    YOLOv8 vehicle counter tuned for dense Indian traffic (high two-wheeler density).

    Key fixes vs original:
    - Motorcycle aspect ratio check now uses OR logic (wide OR large area OR confident)
    - Bicycle accepts aspect ratios from 0.35 (covers head-on narrow bikes)
    - Height caps raised to 180px/160px (closer cameras show taller bounding boxes)
    - Greedy per-class IoU deduplication (hard NMS) removes ghost detections
    - Class-specific confidence applied AFTER YOLO inference, which runs at 0.05
    - Frame skipping and the cached count are kept per lane, so one detector
      shared by several lanes never returns another lane's count
    """

    def __init__(self, model_path: str = "yolov8n.pt", skip_frames: int = 2, conf: float = 0.5):
        """
        Args:
            model_path: Weights passed to ``ultralytics.YOLO``.
            skip_frames: Run inference on every Nth ``detect`` call per lane
                (starting with the first call); other calls return that lane's
                cached count. Values below 1 behave like 1.
            conf: Stored as ``self.conf``; ``detect`` does not read it (inference
                runs at conf=0.05 and the per-class thresholds filter afterwards).

        Raises:
            RuntimeError: if ultralytics cannot be imported or the weights fail to load.
        """
        self.model_path = model_path
        self.skip_frames = skip_frames
        self.conf = conf
        self._init_defaults()
        try:
            from ultralytics import YOLO
            self.model = YOLO(model_path, verbose=False)
        except Exception as e:
            raise RuntimeError(f"Cannot load YOLOv8: {e}") from e

    # ------------------------------------------------------------------
    def _init_defaults(self):
        """Set class filters, thresholds and per-lane caches (everything except the model)."""
        self.frame_counter = 0  # detect() calls across all lanes
        self.last_detections = None  # count from the most recent successful inference (any lane)
        # Per-lane state, keyed by `lane` (or by `lane_region` when no lane id is given).
        self._lane_calls = {}
        self._lane_counts = {}
        # YOLO COCO classes
        self.vehicle_classes = [2, 5, 7]  # Car, Bus, Truck
        self.two_wheeler_classes = [1, 3]  # Bicycle, Motorcycle
        self.person_class = 0  # MUST EXCLUDE
        # Per-class confidence thresholds (applied after YOLO runs at 0.05)
        self.conf_car = 0.25
        self.conf_two_wheeler = 0.12  # Very lenient — dense traffic, small boxes
        self.conf_bus_truck = 0.30
        # Minimum bounding box areas (pixels²) — small values intentional
        self.min_area_motorcycle = 5
        self.min_area_bicycle = 4
        self.min_area_car = 12
        self.min_area_bus_truck = 25
        # FIX #1: Raised height caps — original 120/100 was too strict for close cameras
        self.max_height_motorcycle = 180
        self.max_height_bicycle = 160
        # IoU threshold for deduplication
        self.nms_iou_threshold = 0.45

    # ------------------------------------------------------------------
    def _box_stats(self, box):
        """Return (area, aspect_ratio w/h, height) from a YOLO box."""
        w, h = float(box.xywh[0][2]), float(box.xywh[0][3])
        area = w * h
        ar = w / h if h > 0 else 1.0
        return area, ar, h

    # ------------------------------------------------------------------
    def _is_valid_motorcycle(self, ar: float, area: float, height: float, conf: float) -> bool:
        """
        FIX #1 — original used AND logic that was too strict.

        Rejected first:
          - person-like silhouette: ar < 0.4 AND area < 60 (distant people are
            small AND narrow; motorcycles are rarely both)
          - implausibly tall boxes (height > 1.8 × max_height_motorcycle)
          - area below min_area_motorcycle
        Otherwise accepted if ANY of:
          a) wide box (ar >= 0.6), the typical side view
          b) large area (area >= 80), even if tall and narrow
          c) confident (conf >= 0.35), regardless of shape
        """
        # Hard reject: looks like a distant person silhouette
        if ar < 0.4 and area < 60:
            return False
        # Hard reject: implausibly tall (taller than 1.8× the max camera height cap)
        if height > self.max_height_motorcycle * 1.8:
            return False
        # Minimum area
        if area < self.min_area_motorcycle:
            return False
        # Accept if wide OR large OR confident
        if ar >= 0.6 or area >= 80 or conf >= 0.35:
            return True
        # Edge case: small, narrow, low-conf — skip
        return False

    # ------------------------------------------------------------------
    def _is_valid_bicycle(self, ar: float, area: float, height: float, conf: float) -> bool:
        """
        FIX #2 — original lower bound of 0.5 cut head-on bicycles (ar ~0.35).

        Rejects person-like silhouettes (ar < 0.35 AND area < 50 AND conf < 0.3),
        boxes taller than 1.8 × max_height_bicycle, and areas below
        min_area_bicycle; otherwise accepts 0.35 <= ar <= 2.5.
        """
        if ar < 0.35 and area < 50 and conf < 0.3:
            return False  # Looks like a very narrow distant person
        if height > self.max_height_bicycle * 1.8:
            return False
        if area < self.min_area_bicycle:
            return False
        return 0.35 <= ar <= 2.5

    # ------------------------------------------------------------------
    def _should_keep(self, cls_id: int, conf: float, area: float,
                     ar: float, height: float) -> bool:
        """Master gate: exclude people, apply per-class rules."""
        if cls_id == self.person_class:
            return False
        if cls_id not in self.vehicle_classes and cls_id not in self.two_wheeler_classes:
            return False
        if cls_id == 2:  # Car
            return conf >= self.conf_car and area >= self.min_area_car
        if cls_id in [5, 7]:  # Bus, Truck
            return conf >= self.conf_bus_truck and area >= self.min_area_bus_truck
        if cls_id == 3:  # Motorcycle
            return conf >= self.conf_two_wheeler and self._is_valid_motorcycle(ar, area, height, conf)
        if cls_id == 1:  # Bicycle
            return conf >= self.conf_two_wheeler and self._is_valid_bicycle(ar, area, height, conf)
        return False

    # ------------------------------------------------------------------
    def _nms_deduplicate(self, detections: list) -> list:
        """
        FIX #3 — Remove overlapping detections of the same class.

        Greedy IoU-based NMS: expects ``detections`` sorted by confidence
        descending and keeps the first box of each overlapping same-class group
        (IoU > nms_iou_threshold).
        """
        if not detections:
            return detections
        kept = []
        used = [False] * len(detections)
        for i, d in enumerate(detections):
            if used[i]:
                continue
            kept.append(d)
            x1i, y1i = d['x'] - d['w'] / 2, d['y'] - d['h'] / 2
            x2i, y2i = d['x'] + d['w'] / 2, d['y'] + d['h'] / 2
            for j in range(i + 1, len(detections)):
                if used[j]:
                    continue
                e = detections[j]
                # Only suppress same class
                if e['class'] != d['class']:
                    continue
                x1j, y1j = e['x'] - e['w'] / 2, e['y'] - e['h'] / 2
                x2j, y2j = e['x'] + e['w'] / 2, e['y'] + e['h'] / 2
                ix1, iy1 = max(x1i, x1j), max(y1i, y1j)
                ix2, iy2 = min(x2i, x2j), min(y2i, y2j)
                inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
                union = (x2i - x1i) * (y2i - y1i) + (x2j - x1j) * (y2j - y1j) - inter
                if union > 0 and inter / union > self.nms_iou_threshold:
                    used[j] = True
        return kept

    # ------------------------------------------------------------------
    def _count_vehicles(self, frame, lane_region) -> int:
        """Run YOLO once on ``frame``, filter and deduplicate boxes, return the count (max 120)."""
        results = self.model(frame, verbose=False, conf=0.05)
        detections = []
        for result in results:
            for box in result.boxes:
                cls_id = int(box.cls)
                conf = float(box.conf)
                area, ar, height = self._box_stats(box)
                if not self._should_keep(cls_id, conf, area, ar, height):
                    continue
                x, y = int(box.xywh[0][0]), int(box.xywh[0][1])
                w, h = float(box.xywh[0][2]), float(box.xywh[0][3])
                if lane_region:
                    # Keep only boxes whose centre lies inside the lane region.
                    x1r, y1r, x2r, y2r = lane_region
                    if not (x1r <= x <= x2r and y1r <= y <= y2r):
                        continue
                detections.append({
                    'x': x, 'y': y, 'w': w, 'h': h,
                    'conf': min(conf, 1.0),
                    'class': cls_id,
                    'size': area,
                })
        # Sort by confidence descending, then deduplicate
        detections.sort(key=lambda d: d['conf'], reverse=True)
        detections = self._nms_deduplicate(detections)
        return min(len(detections), 120)

    # ------------------------------------------------------------------
    def detect(self, frame: np.ndarray, lane_region: tuple = None, lane=None) -> int:
        """
        Count vehicles in ``frame``. Returns an int (for compatibility with calling code).

        Args:
            frame: Image passed straight to the YOLO model.
            lane_region: Optional ``(x1, y1, x2, y2)``; only boxes whose centre
                lies inside it are counted.
            lane: Optional lane id keying the frame-skip counter and cached
                count. When omitted, ``lane_region`` is used as the key.

        Returns:
            The vehicle count (capped at 120). On skipped calls, or when
            inference raises, the cached count for this lane (0 if none yet).
        """
        self.frame_counter += 1
        if lane is not None:
            key = lane
        elif lane_region is not None:
            key = tuple(lane_region)
        else:
            key = None
        calls = self._lane_calls.get(key, 0) + 1
        self._lane_calls[key] = calls
        cached = self._lane_counts.get(key, 0)
        # Infer on the 1st, (1+N)th, (1+2N)th ... call for this lane; reuse otherwise.
        if (calls - 1) % max(1, self.skip_frames) != 0:
            return cached
        try:
            count = self._count_vehicles(frame, lane_region)
        except Exception:
            # Best-effort: a failed inference keeps this lane's previous estimate.
            return cached
        self._lane_counts[key] = count
        self.last_detections = count
        return count
# ─────────────────────────────────────────────────────────────────
# Real-Time Queue Tracking with Rolling Average
# ─────────────────────────────────────────────────────────────────
class QueueTracker:
    """Smooth per-lane vehicle counts with a fixed-size rolling window."""

    def __init__(self, window_size: int = 10):
        self.window_size = window_size
        # One bounded window per lane; older samples drop off automatically.
        self.history = {
            lane: deque(maxlen=window_size) for lane in ('N', 'S', 'E', 'W')
        }

    def update(self, raw_counts: dict) -> dict:
        """
        Add one sample per lane and return the smoothed estimates.

        Args:
            raw_counts: {'N': count, 'S': count, ...}. A lane not seen before
                gets its own window instead of raising KeyError.

        Returns:
            {lane: int(mean of that lane's window)} for every tracked lane
            (mean truncated toward zero); lanes with no samples yet report 0.
        """
        for lane, count in raw_counts.items():
            window = self.history.get(lane)
            if window is None:
                window = self.history[lane] = deque(maxlen=self.window_size)
            window.append(count)
        return {
            lane: int(np.mean(list(window))) if window else 0
            for lane, window in self.history.items()
        }
# ─────────────────────────────────────────────────────────────────
# Stability Controller (Min Green Time)
# ─────────────────────────────────────────────────────────────────
class SignalStabilizer:
    """Prevent rapid signal flipping by enforcing a minimum green time."""

    def __init__(self, min_green_time: float = 5.0):
        self.min_green_time = min_green_time
        self.current_phase = None
        # time.monotonic() timestamp of the last accepted switch (None before the first).
        self.phase_start_time = None

    def should_switch(self, new_phase: str) -> bool:
        """
        Decide whether to adopt ``new_phase`` now.

        Returns True (and records the switch) for the very first phase, or when
        ``new_phase`` differs from the current one and at least
        ``min_green_time`` seconds have elapsed. Returns False otherwise.
        Uses a monotonic clock so wall-clock adjustments cannot shorten or
        stretch the hold time.
        """
        now = time.monotonic()
        if self.current_phase is None:
            # First phase
            self.current_phase = new_phase
            self.phase_start_time = now
            return True
        if new_phase == self.current_phase:
            # No switch needed
            return False
        if now - self.phase_start_time < self.min_green_time:
            # Keep current phase (too soon to switch)
            return False
        self.current_phase = new_phase
        self.phase_start_time = now
        return True

    def get_current_phase(self) -> str:
        """Return the active phase, defaulting to "NS_GREEN" before any decision."""
        return self.current_phase if self.current_phase else "NS_GREEN"
# ─────────────────────────────────────────────────────────────────
# Parallel Stream Reader (Threading)
# ─────────────────────────────────────────────────────────────────
class ParallelStreamReader:
    """Read frames from several looping video streams, one background thread per lane."""

    def __init__(self, video_paths: dict, fps: float = None):
        """
        Args:
            video_paths: {'N': path, 'S': path, 'E': path, 'W': path}
            fps: Per-lane playback rate. None (default) paces each lane at its
                video's native FPS (30 if unknown) so a file behaves like a
                live camera; 0 or a negative value reads as fast as possible.
        """
        self.streams = {
            lane: VideoStreamGenerator(path)
            for lane, path in video_paths.items()
        }
        self.fps = fps
        self.frames = {'N': None, 'S': None, 'E': None, 'W': None}
        self.frame_numbers = {'N': 0, 'S': 0, 'E': 0, 'W': 0}
        self.running = False
        self.threads = {}

    def start(self):
        """Start one daemon reader thread per lane (no-op while already running)."""
        if self.running:
            return
        self.running = True
        for lane, stream in self.streams.items():
            existing = self.threads.get(lane)
            if existing is not None and existing.is_alive():
                # A reader from a previous run never exited; it resumes now
                # that `running` is True again, so don't start a second one.
                continue
            # stop() releases captures; reopen them so a restart works.
            cap = getattr(stream, 'cap', None)
            if cap is None or not cap.isOpened():
                stream._init_capture()
            thread = threading.Thread(target=self._read_stream, args=(lane,), daemon=True)
            thread.start()
            self.threads[lane] = thread

    def _frame_interval(self, lane: str) -> float:
        """Seconds between frames for ``lane``; 0.0 means unthrottled."""
        if self.fps is not None:
            return 1.0 / self.fps if self.fps > 0 else 0.0
        cap = getattr(self.streams[lane], 'cap', None)
        native = cap.get(cv2.CAP_PROP_FPS) if cap is not None else 0.0
        # OpenCV reports 0 (or NaN) when the container has no FPS metadata.
        return 1.0 / native if native and native > 0 else 1.0 / 30.0

    def _read_stream(self, lane: str):
        """Background thread: publish the latest frame for ``lane`` at the playback rate."""
        interval = self._frame_interval(lane)
        next_due = time.monotonic()
        for frame, frame_num in self.streams[lane]:
            if not self.running:
                break
            self.frames[lane] = frame
            self.frame_numbers[lane] = frame_num
            if interval <= 0:
                continue
            next_due += interval
            delay = next_due - time.monotonic()
            if delay > 0:
                time.sleep(delay)
            else:
                # Fell behind (slow decode); resync instead of bursting to catch up.
                next_due = time.monotonic()

    def get_frames(self) -> dict:
        """Return a shallow copy of the latest frame per lane (None until the first read)."""
        return self.frames.copy()

    def stop(self):
        """Signal readers to exit, wait up to 1 s each, then release finished lanes' captures."""
        self.running = False
        for thread in self.threads.values():
            thread.join(timeout=1)
        for lane, thread in self.threads.items():
            # Never release a capture that a still-running thread may be reading.
            if thread.is_alive():
                continue
            cap = getattr(self.streams.get(lane), 'cap', None)
            if cap is not None:
                cap.release()
        self.threads = {lane: t for lane, t in self.threads.items() if t.is_alive()}
# ─────────────────────────────────────────────────────────────────
# Real-Time Decision Engine
# ─────────────────────────────────────────────────────────────────
class RealtimeDecisionEngine:
    """Main engine: reads lane frames, counts vehicles, smooths counts and picks a signal phase."""

    def __init__(self, video_paths: dict, skip_frames: int = 2, min_green_time: float = 5.0):
        """
        Args:
            video_paths: {'N': path, 'S': path, 'E': path, 'W': path}
            skip_frames: Run detection on every Nth process_frame() call (for
                speed); in between, each lane reuses its own last count.
                Values below 1 behave like 1.
            min_green_time: Minimum seconds to keep phase (for stability)
        """
        self.stream_reader = ParallelStreamReader(video_paths)
        # Frame skipping is done here, per lane, rather than inside the single
        # detector shared by all lanes, so one lane's cached count can never be
        # returned for another lane. The detector therefore runs on every call.
        self.skip_frames = max(1, int(skip_frames))
        self.detector = FastYOLODetector(skip_frames=1)
        self._last_raw_counts = {}  # lane -> count from that lane's latest detection
        self.queue_tracker = QueueTracker(window_size=10)
        self.stabilizer = SignalStabilizer(min_green_time=min_green_time)
        # Lane regions (for detection)
        self.lane_regions = {
            'N': (0, 0, 640, 180),
            'S': (0, 180, 640, 360),
            'E': None,  # Full frame
            'W': None,
        }
        # traffic_env is optional. Import it once: a failed import is not
        # cached by Python and would otherwise be re-attempted every frame.
        self._traffic_env = self._import_traffic_env()
        self.agents = self._load_agents()
        self.sim = None
        self._load_simulation()
        # Metrics
        self.metrics = {
            'frame_count': 0,
            'detection_time': 0,
            'decision_time': 0,
            'current_queues': {'N': 0, 'S': 0, 'E': 0, 'W': 0},
            'current_phase': 'NS_GREEN',
            'agent_votes': {'N': 0, 'S': 0, 'E': 0, 'W': 0},
        }

    @staticmethod
    def _import_traffic_env():
        """Return the optional ``traffic_env`` module, or None if it cannot be imported."""
        try:
            import traffic_env
        except Exception:
            return None
        return traffic_env

    def _load_agents(self):
        """Load trained PPO agents from ``agent_<lane>.zip``; lanes that fail to load are omitted."""
        agents = {}
        try:
            from stable_baselines3 import PPO
        except Exception:
            return agents
        for lane in ['N', 'S', 'E', 'W']:
            try:
                agents[lane] = PPO.load(f"agent_{lane}.zip")
            except Exception:
                pass  # best-effort: a lane without an agent uses the queue heuristic
        return agents

    def _load_simulation(self):
        """Create and reset ``traffic_env.IntersectionSimulator`` if available.

        ``self.sim`` is only set when construction and ``reset()`` both succeed.
        """
        sim_cls = getattr(self._traffic_env, 'IntersectionSimulator', None)
        if sim_cls is None:
            return
        try:
            sim = sim_cls()
            sim.reset()
        except Exception:
            return  # best-effort: the engine falls back to heuristics without a simulator
        self.sim = sim

    def process_frame(self) -> dict:
        """
        Process one frame from all lanes.

        Returns:
            Decision info: 'phase', 'queues' (smoothed), 'votes', 'raw_counts',
            a copy of 'metrics', and 'total_time_ms'.
        """
        start_time = time.perf_counter()
        # 1. Get frames from all lanes
        raw_frames = self.stream_reader.get_frames()
        # 2. Detect vehicles (sequentially, lane by lane). Detection runs on
        #    every `skip_frames`-th call; other calls reuse each lane's own count.
        detection_start = time.perf_counter()
        run_detection = self.metrics['frame_count'] % self.skip_frames == 0
        raw_counts = {}
        for lane, frame in raw_frames.items():
            if frame is None:
                raw_counts[lane] = 0
            elif run_detection or lane not in self._last_raw_counts:
                count = self.detector.detect(frame, self.lane_regions.get(lane))
                self._last_raw_counts[lane] = count
                raw_counts[lane] = count
            else:
                raw_counts[lane] = self._last_raw_counts[lane]
        self.metrics['detection_time'] = time.perf_counter() - detection_start
        # 3. Smooth counts with rolling average
        smoothed_counts = self.queue_tracker.update(raw_counts)
        self.metrics['current_queues'] = smoothed_counts.copy()
        # 4. Get MARL agent decisions
        decision_start = time.perf_counter()
        agent_votes = self._get_agent_votes(smoothed_counts)
        self.metrics['agent_votes'] = agent_votes.copy()
        # 5. Get intersection manager decision
        phase = self._get_phase_decision(agent_votes, smoothed_counts)
        # 6. Enforce min green time; if it is too early the current phase stays.
        if self.stabilizer.should_switch(phase):
            self.metrics['current_phase'] = phase
        self.metrics['decision_time'] = time.perf_counter() - decision_start
        self.metrics['frame_count'] += 1
        return {
            'phase': self.metrics['current_phase'],
            'queues': smoothed_counts,
            'votes': agent_votes,
            'raw_counts': raw_counts,
            'metrics': self.metrics.copy(),
            'total_time_ms': (time.perf_counter() - start_time) * 1000,
        }

    def _get_agent_votes(self, queues: dict) -> dict:
        """Get a 0/1 green request per lane from the MARL agents, or from a queue > 5 heuristic."""

        def heuristic(lane):
            return 1 if queues[lane] > 5 else 0

        env_cls = getattr(self._traffic_env, 'AgentEnv', None)
        if not self.agents or self.sim is None or env_cls is None:
            return {lane: heuristic(lane) for lane in queues}
        votes = {'N': 0, 'S': 0, 'E': 0, 'W': 0}
        try:
            # Hand the simulator a copy so it cannot mutate the returned queues.
            self.sim.queues = dict(queues)
            for lane in ['N', 'S', 'E', 'W']:
                if lane in self.agents:
                    obs = env_cls(lane=lane, sim=self.sim).get_obs()
                    action, _ = self.agents[lane].predict(obs, deterministic=True)
                    votes[lane] = int(action)
                else:
                    votes[lane] = heuristic(lane)
        except Exception:
            votes = {lane: heuristic(lane) for lane in queues}
        return votes

    @staticmethod
    def _max_pressure_phase(votes: dict, queues: dict) -> str:
        """Fallback rule: serve the requesting lane with the longest queue (NS_GREEN if none request)."""
        requesting = [lane for lane, vote in votes.items() if vote == 1]
        if not requesting:
            return "NS_GREEN"
        chosen = max(requesting, key=lambda lane: queues[lane])
        return "NS_GREEN" if chosen in ('N', 'S') else "EW_GREEN"

    def _get_phase_decision(self, votes: dict, queues: dict) -> str:
        """Get the phase from ``traffic_env.intersection_manager``; use max-pressure if unavailable or it fails."""
        manager = getattr(self._traffic_env, 'intersection_manager', None)
        if manager is not None:
            try:
                phase = manager(votes, queues, False, None)
            except Exception:
                pass  # best-effort: fall through to the max-pressure rule
            else:
                return phase if phase else "NS_GREEN"
        return self._max_pressure_phase(votes, queues)

    def start(self):
        """Start the engine"""
        self.stream_reader.start()

    def stop(self):
        """Stop the engine"""
        self.stream_reader.stop()

    def get_metrics(self) -> dict:
        """Get current system metrics (shallow copy)."""
        return self.metrics.copy()
# ─────────────────────────────────────────────────────────────────
# Test / Standalone Usage
# ─────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Standalone smoke test: run the engine on four local clips for ~10 seconds.
    # `time` is already imported at module level.
    print("🚦 Real-Time Traffic Control Engine")
    print("=" * 50)
    video_paths = {
        'N': 'test_video_N.mp4',
        'S': 'test_video_S.mp4',
        'E': 'test_video_E.mp4',
        'W': 'test_video_W.mp4',
    }
    # Initialize engine
    engine = RealtimeDecisionEngine(video_paths, skip_frames=2, min_green_time=3.0)
    engine.start()
    print("✓ Engine started. Processing frames...")
    print("=" * 50)
    try:
        # Run for 10 seconds
        start = time.monotonic()
        frame_count = 0
        while time.monotonic() - start < 10:
            result = engine.process_frame()
            frame_count += 1
            if frame_count % 10 == 0:  # Print every 10 frames
                print(f"\nFrame {result['metrics']['frame_count']}")
                print(f"  Queues: N={result['queues']['N']} S={result['queues']['S']} E={result['queues']['E']} W={result['queues']['W']}")
                print(f"  Votes: N={result['votes']['N']} S={result['votes']['S']} E={result['votes']['E']} W={result['votes']['W']}")
                print(f"  Phase: {result['phase']}")
                print(f"  Time: {result['total_time_ms']:.1f}ms")
            time.sleep(0.05)  # caps the loop at ~20 FPS; processing time adds to each iteration
    finally:
        # Always stop the reader threads, even on Ctrl+C or an exception.
        engine.stop()
    print("\n✓ Engine stopped")