Spaces:
Sleeping
Sleeping
"""
realtime_engine.py — Real-Time Traffic Control Engine
Converts uploaded videos into infinite "live streams"
Processes frame-by-frame with YOLO + MARL decisions
Ready to plug into Streamlit dashboard
Features:
- Video loop generators (infinite streams)
- Parallel YOLO detection (3-5x faster)
- Rolling average for smooth decisions
- Stability constraints (min green time)
- Real-time metrics
"""
import logging
import threading
import time
import warnings
from collections import deque
from queue import Queue

import cv2
import numpy as np

warnings.filterwarnings('ignore')
# ─────────────────────────────────────────────────────────────────
# Video Stream Generators (Infinite Loop)
# ─────────────────────────────────────────────────────────────────
class VideoStreamGenerator:
    """Iterator that replays a video file endlessly (acts like a live camera).

    Each step returns ``(frame, frame_count)``: ``frame`` is resized to
    ``(resize_width, resize_height)`` and ``frame_count`` is the running number
    of frames delivered so far (it keeps growing across loops of the video).
    """

    def __init__(self, video_path: str, resize_width: int = 640, resize_height: int = 360):
        """Remember the target size and open the capture.

        Raises:
            RuntimeError: if ``video_path`` cannot be opened.
        """
        self.video_path = video_path
        self.resize_width = resize_width
        self.resize_height = resize_height
        self.frame_count = 0
        self.cap = None
        self._init_capture()

    def _init_capture(self):
        """(Re)open ``self.cap``; raise ``RuntimeError`` if the file cannot be opened."""
        self.cap = cv2.VideoCapture(self.video_path)
        if not self.cap.isOpened():
            raise RuntimeError(f"Cannot open video: {self.video_path}")

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next resized frame, looping back to the start at end of file.

        Recovery from a failed read is bounded so an unreadable file cannot
        spin this call forever: first rewind, then reopen the capture, then
        give up.

        Raises:
            RuntimeError: if no frame can be read even after rewinding and
                reopening the capture.
        """
        failed_reads = 0
        while True:
            ret, frame = self.cap.read()
            if ret:
                break
            failed_reads += 1
            if failed_reads == 1:
                # Normal end of file: restart video (loop)
                self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            elif failed_reads == 2:
                # Seeking did not help; fall back to reopening the file
                self.cap.release()
                self._init_capture()
            else:
                raise RuntimeError(f"Cannot read frames from video: {self.video_path}")
        # Resize for faster processing
        frame = cv2.resize(frame, (self.resize_width, self.resize_height))
        self.frame_count += 1
        return frame, self.frame_count

    def release(self):
        """Release the underlying capture handle (safe to call more than once)."""
        if self.cap is not None:
            self.cap.release()
# ─────────────────────────────────────────────────────────────────
# Fast YOLO Detection with Frame Skipping
# ─────────────────────────────────────────────────────────────────
class FastYOLODetector:
    """
    YOLOv8 detector tuned for dense Indian traffic (high two-wheeler density).

    YOLO is run with a very low confidence floor (0.05); every box is then
    filtered by class-specific confidence, area, aspect-ratio and height rules,
    and overlapping same-class boxes are removed by IoU.

    Key fixes vs original:
    - Motorcycle aspect ratio check now uses OR logic (wide OR large area OR tall-but-confident)
    - Bicycle accepts aspect ratios from 0.35 (covers head-on narrow bikes)
    - Height caps raised to 180px/160px (closer cameras show taller bounding boxes)
    - Added soft NMS via IoU deduplication to remove ghost detections
    - Class-specific confidence done AFTER YOLO inference at 0.05 (unchanged, correct)

    Note: the frame-skip cache (``last_detections``) is one value per detector
    instance, not per stream. When frames from several streams are interleaved
    through the same instance with ``skip_frames > 1``, a skipped call returns
    the count of whichever frame was processed last.
    """

    # Upper bound on the count returned by detect()
    MAX_VEHICLE_COUNT = 120

    def __init__(self, model_path: str = "yolov8n.pt", skip_frames: int = 2, conf: float = 0.5):
        """
        Args:
            model_path: weights file handed to ``ultralytics.YOLO``.
            skip_frames: run inference on every Nth ``detect`` call; values
                below 1 are treated as 1 (every call).
            conf: stored on the instance; ``detect`` itself always runs YOLO at
                0.05 and applies the per-class thresholds below.

        Raises:
            RuntimeError: if ultralytics cannot be imported or the model cannot
                be loaded.
        """
        self.model_path = model_path
        # A value of 0 would make the modulo in detect() divide by zero
        self.skip_frames = max(1, int(skip_frames))
        self.conf = conf
        self.frame_counter = 0
        self.last_detections = None
        # YOLO COCO classes
        self.vehicle_classes = [2, 5, 7]  # Car, Bus, Truck
        self.two_wheeler_classes = [1, 3]  # Bicycle, Motorcycle
        self.person_class = 0  # MUST EXCLUDE
        # Per-class confidence thresholds (applied after YOLO runs at 0.05)
        self.conf_car = 0.25
        self.conf_two_wheeler = 0.12  # Very lenient — dense traffic, small boxes
        self.conf_bus_truck = 0.30
        # Minimum bounding box areas (pixels²) — small values intentional
        self.min_area_motorcycle = 5
        self.min_area_bicycle = 4
        self.min_area_car = 12
        self.min_area_bus_truck = 25
        # FIX #1: Raised height caps — original 120/100 was too strict for close cameras
        self.max_height_motorcycle = 180
        self.max_height_bicycle = 160
        # IoU threshold for deduplication
        self.nms_iou_threshold = 0.45
        try:
            from ultralytics import YOLO
            self.model = YOLO(model_path, verbose=False)
        except Exception as e:
            raise RuntimeError(f"Cannot load YOLOv8: {e}") from e

    # ------------------------------------------------------------------
    def _box_stats(self, box):
        """Return (area, aspect_ratio w/h, height) from a YOLO box."""
        w, h = float(box.xywh[0][2]), float(box.xywh[0][3])
        area = w * h
        ar = w / h if h > 0 else 1.0
        return area, ar, h

    # ------------------------------------------------------------------
    def _is_valid_motorcycle(self, ar: float, area: float, height: float, conf: float) -> bool:
        """
        FIX #1 — original used AND logic that was too strict.
        A motorcycle is valid if ANY of these are true:
        a) Wide box (ar >= 0.6) — side view, typical
        b) Large enough area (area >= 80) — even a tall narrow box is probably real
        c) High confidence (conf >= 0.35) regardless of shape — model is sure
        Person rejection: only reject if BOTH ar < 0.4 AND area < 60
        (people at a distance are small AND narrow; motorcycles are rarely both)
        """
        # Hard reject: looks like a distant person silhouette
        if ar < 0.4 and area < 60:
            return False
        # Hard reject: implausibly tall (taller than 1.8× the max camera height cap)
        if height > self.max_height_motorcycle * 1.8:
            return False
        # Minimum area
        if area < self.min_area_motorcycle:
            return False
        # Accept if wide OR large OR confident; otherwise the box is a
        # small, narrow, low-confidence edge case and is skipped
        return ar >= 0.6 or area >= 80 or conf >= 0.35

    # ------------------------------------------------------------------
    def _is_valid_bicycle(self, ar: float, area: float, height: float, conf: float) -> bool:
        """
        FIX #2 — original lower bound of 0.5 cut head-on bicycles (ar ~0.35).
        Accept if:
        - 0.35 <= ar <= 2.5 (head-on narrow is OK)
        - area >= min_area_bicycle
        - height <= 1.8× max_height_bicycle
        - NOT person silhouette: ar < 0.35 AND area < 50 AND conf < 0.3
        """
        if ar < 0.35 and area < 50 and conf < 0.3:
            return False  # Looks like a very narrow distant person
        if height > self.max_height_bicycle * 1.8:
            return False
        if area < self.min_area_bicycle:
            return False
        return 0.35 <= ar <= 2.5

    # ------------------------------------------------------------------
    def _should_keep(self, cls_id: int, conf: float, area: float,
                     ar: float, height: float) -> bool:
        """Master gate: exclude people, apply per-class rules."""
        if cls_id == self.person_class:
            return False
        if cls_id not in self.vehicle_classes and cls_id not in self.two_wheeler_classes:
            return False
        if cls_id == 2:  # Car
            return conf >= self.conf_car and area >= self.min_area_car
        if cls_id in (5, 7):  # Bus, Truck
            return conf >= self.conf_bus_truck and area >= self.min_area_bus_truck
        if cls_id == 3:  # Motorcycle
            return conf >= self.conf_two_wheeler and self._is_valid_motorcycle(ar, area, height, conf)
        if cls_id == 1:  # Bicycle
            return conf >= self.conf_two_wheeler and self._is_valid_bicycle(ar, area, height, conf)
        return False

    # ------------------------------------------------------------------
    @staticmethod
    def _corners(det: dict) -> tuple:
        """Convert a centre/size detection dict to (x1, y1, x2, y2) corners."""
        half_w, half_h = det['w'] / 2, det['h'] / 2
        return det['x'] - half_w, det['y'] - half_h, det['x'] + half_w, det['y'] + half_h

    # ------------------------------------------------------------------
    def _nms_deduplicate(self, detections: list) -> list:
        """
        FIX #3 — Remove overlapping detections of the same class.
        Simple IoU-based NMS; ``detections`` must already be sorted by
        confidence descending so the strongest box of each cluster survives.
        """
        if not detections:
            return detections
        kept = []
        used = [False] * len(detections)
        for i, d in enumerate(detections):
            if used[i]:
                continue
            kept.append(d)
            x1i, y1i, x2i, y2i = self._corners(d)
            area_i = (x2i - x1i) * (y2i - y1i)
            for j, e in enumerate(detections[i + 1:], start=i + 1):
                # Only suppress same class
                if used[j] or e['class'] != d['class']:
                    continue
                x1j, y1j, x2j, y2j = self._corners(e)
                inter = (max(0, min(x2i, x2j) - max(x1i, x1j))
                         * max(0, min(y2i, y2j) - max(y1i, y1j)))
                union = area_i + (x2j - x1j) * (y2j - y1j) - inter
                if union > 0 and inter / union > self.nms_iou_threshold:
                    used[j] = True
        return kept

    # ------------------------------------------------------------------
    def detect(self, frame: np.ndarray, lane_region: tuple = None) -> int:
        """
        Detect vehicles. Returns count (for compatibility with calling code).

        Args:
            frame: image passed straight to the YOLO model.
            lane_region: optional ``(x1, y1, x2, y2)``; when given, only boxes
                whose centre lies inside it are counted.

        Returns:
            Number of kept detections, capped at ``MAX_VEHICLE_COUNT``. On a
            skipped call, or if inference fails, the previous count is
            returned (0 if there is none yet).
        """
        self.frame_counter += 1
        # Guard again here in case skip_frames was reassigned after __init__
        step = max(1, int(self.skip_frames))
        if self.frame_counter % step != 0:
            return self.last_detections if self.last_detections is not None else 0
        try:
            results = self.model(frame, verbose=False, conf=0.05)
            detections = []
            for result in results:
                for box in result.boxes:
                    cls_id = int(box.cls)
                    conf = float(box.conf)
                    area, ar, height = self._box_stats(box)
                    if not self._should_keep(cls_id, conf, area, ar, height):
                        continue
                    cx, cy, w, h = (float(v) for v in box.xywh[0])
                    x, y = int(cx), int(cy)
                    if lane_region:
                        x1r, y1r, x2r, y2r = lane_region
                        if not (x1r <= x <= x2r and y1r <= y <= y2r):
                            continue
                    detections.append({
                        'x': x, 'y': y, 'w': w, 'h': h,
                        'conf': min(conf, 1.0),
                        'class': cls_id,
                        'size': area,
                    })
            # Sort by confidence descending, then deduplicate
            detections.sort(key=lambda d: d['conf'], reverse=True)
            detections = self._nms_deduplicate(detections)
            # Return count
            self.last_detections = min(len(detections), self.MAX_VEHICLE_COUNT)
            return self.last_detections
        except Exception as exc:
            # Best effort: keep the stream alive, but do not hide the failure
            logging.getLogger(__name__).warning(
                "YOLO detection failed; reusing last count: %s", exc)
            return self.last_detections if self.last_detections is not None else 0
# ─────────────────────────────────────────────────────────────────
# Real-Time Queue Tracking with Rolling Average
# ─────────────────────────────────────────────────────────────────
class QueueTracker:
    """Smooth queue estimates using a rolling window per lane."""

    def __init__(self, window_size: int = 10):
        """
        Args:
            window_size: number of most recent counts kept per lane.
        """
        self.window_size = window_size
        # The four compass lanes are always reported, even before any update
        self.history = {
            lane: deque(maxlen=window_size) for lane in ('N', 'S', 'E', 'W')
        }

    def update(self, raw_counts: dict) -> dict:
        """
        Update with new counts and return smoothed estimates.

        Args:
            raw_counts: {'N': count, 'S': count, ...}. Lanes not seen before
                get their own window instead of raising ``KeyError``.

        Returns:
            Smoothed counts for every tracked lane: the rolling mean truncated
            to ``int``, or 0 for a lane with no samples yet.
        """
        for lane, count in raw_counts.items():
            window = self.history.get(lane)
            if window is None:
                window = self.history[lane] = deque(maxlen=self.window_size)
            window.append(count)
        # Return rolling average
        return {
            lane: int(np.mean(list(window))) if window else 0
            for lane, window in self.history.items()
        }
# ─────────────────────────────────────────────────────────────────
# Stability Controller (Min Green Time)
# ─────────────────────────────────────────────────────────────────
class SignalStabilizer:
    """Rate-limits phase changes so the signal cannot flip rapidly."""

    def __init__(self, min_green_time: float = 5.0):
        """
        Args:
            min_green_time: seconds a phase must have been held before a
                different phase is accepted.
        """
        self.min_green_time = min_green_time
        self.current_phase = None
        self.phase_start_time = None

    def should_switch(self, new_phase: str) -> bool:
        """Decide whether ``new_phase`` is adopted now.

        Returns True (and records the phase and its start time) for the very
        first request, or for a different phase once the current one has been
        held for at least ``min_green_time`` seconds. Returns False when the
        phase is unchanged or the hold time has not yet elapsed.
        """
        if self.current_phase is not None:
            # Same phase requested again: nothing to change
            if new_phase == self.current_phase:
                return False
            # Different phase, but the current one is still too young
            held_for = time.time() - self.phase_start_time
            if not (held_for >= self.min_green_time):
                return False
        # Either the first phase ever or an allowed switch
        self.current_phase = new_phase
        self.phase_start_time = time.time()
        return True

    def get_current_phase(self) -> str:
        """Return the active phase, or "NS_GREEN" if none has been set."""
        return self.current_phase or "NS_GREEN"
# ─────────────────────────────────────────────────────────────────
# Parallel Stream Reader (Threading)
# ─────────────────────────────────────────────────────────────────
class ParallelStreamReader:
    """Read frames from several video streams in parallel, one thread per lane.

    Each background thread keeps overwriting its lane's entry in ``frames``
    with the newest frame, so ``get_frames`` always returns the latest frame
    seen per lane (``None`` until a lane has produced one).
    """

    def __init__(self, video_paths: dict):
        """
        Args:
            video_paths: {'N': path, 'S': path, 'E': path, 'W': path}
        """
        self.streams = {
            lane: VideoStreamGenerator(path)
            for lane, path in video_paths.items()
        }
        # N/S/E/W are always present; any additional lanes get slots too
        self.frames = {'N': None, 'S': None, 'E': None, 'W': None}
        self.frame_numbers = {'N': 0, 'S': 0, 'E': 0, 'W': 0}
        for lane in self.streams:
            self.frames.setdefault(lane, None)
            self.frame_numbers.setdefault(lane, 0)
        self.running = False
        self.threads = {}

    def start(self):
        """Start one daemon reader thread per lane (no-op if already running)."""
        if self.running:
            return
        self.running = True
        for lane, stream in self.streams.items():
            previous = self.threads.get(lane)
            if previous is not None and previous.is_alive():
                # A reader from an earlier run is still going; do not duplicate it
                continue
            # stop() releases captures, so reopen before reading again
            cap = getattr(stream, 'cap', None)
            if cap is not None and not cap.isOpened():
                stream._init_capture()
            thread = threading.Thread(target=self._read_stream, args=(lane,), daemon=True)
            thread.start()
            self.threads[lane] = thread

    def _read_stream(self, lane: str):
        """Background thread reading frames from one lane"""
        for frame, frame_num in self.streams[lane]:
            if not self.running:
                break
            self.frames[lane] = frame
            self.frame_numbers[lane] = frame_num

    def get_frames(self) -> dict:
        """Get current frames from all lanes (shallow copy of the lane -> frame map)"""
        return self.frames.copy()

    def stop(self):
        """Stop reading threads and release the video captures."""
        self.running = False
        for thread in self.threads.values():
            thread.join(timeout=1)
        for lane, stream in self.streams.items():
            thread = self.threads.get(lane)
            if thread is not None and thread.is_alive():
                # Reader did not exit within the timeout; leave its capture
                # alone rather than release it while it may still be in use
                continue
            cap = getattr(stream, 'cap', None)
            if cap is not None:
                cap.release()
# ─────────────────────────────────────────────────────────────────
# Real-Time Decision Engine
# ─────────────────────────────────────────────────────────────────
class RealtimeDecisionEngine:
    """Main engine coordinating all components.

    One ``process_frame`` call takes the newest frame per lane, counts
    vehicles, smooths the counts, collects one vote per lane (trained agents
    when available, otherwise a queue-length heuristic), picks a signal phase
    and passes it through the minimum-green-time stabilizer.
    """

    # A lane asks for green in the heuristic fallback when its queue exceeds this
    QUEUE_VOTE_THRESHOLD = 5

    def __init__(self, video_paths: dict, skip_frames: int = 2, min_green_time: float = 5.0):
        """
        Args:
            video_paths: {'N': path, 'S': path, 'E': path, 'W': path}
            skip_frames: Run detection on every Nth process_frame call (for speed)
            min_green_time: Minimum seconds to keep phase (for stability)
        """
        self.stream_reader = ParallelStreamReader(video_paths)
        # The single detector is shared by all lanes, so its internal skip
        # counter would skip whole lanes and hand them another lane's cached
        # count. It therefore runs on every call and the engine does the
        # skipping per lane (see _detect_counts).
        self.detector = FastYOLODetector(skip_frames=1)
        self.skip_frames = max(1, int(skip_frames))
        self._last_raw_counts = {}
        self.queue_tracker = QueueTracker(window_size=10)
        self.stabilizer = SignalStabilizer(min_green_time=min_green_time)
        # Lane regions (for detection): (x1, y1, x2, y2), or None for full frame
        self.lane_regions = {
            'N': (0, 0, 640, 180),
            'S': (0, 180, 640, 360),
            'E': None,  # Full frame
            'W': None,
        }
        # Import agents
        self.agents = self._load_agents()
        self.sim = None
        self._load_simulation()
        # Metrics
        self.metrics = {
            'frame_count': 0,
            'detection_time': 0,
            'decision_time': 0,
            'current_queues': {'N': 0, 'S': 0, 'E': 0, 'W': 0},
            'current_phase': 'NS_GREEN',
            'agent_votes': {'N': 0, 'S': 0, 'E': 0, 'W': 0},
        }

    def _load_agents(self):
        """Load trained PPO agents from ``agent_<lane>.zip``.

        Best effort: lanes whose agent cannot be loaded are left out, and an
        empty dict is returned when stable_baselines3 is unavailable.
        """
        log = logging.getLogger(__name__)
        agents = {}
        try:
            from stable_baselines3 import PPO
        except Exception as exc:
            log.debug("stable_baselines3 unavailable; using heuristic votes: %s", exc)
            return agents
        for lane in ['N', 'S', 'E', 'W']:
            try:
                agents[lane] = PPO.load(f"agent_{lane}.zip")
            except Exception as exc:
                log.debug("Could not load agent for lane %s: %s", lane, exc)
        return agents

    def _load_simulation(self):
        """Load simulation environment; ``self.sim`` is left unchanged on failure."""
        try:
            from traffic_env import IntersectionSimulator
            self.sim = IntersectionSimulator()
            self.sim.reset()
        except Exception as exc:
            logging.getLogger(__name__).debug("Simulation unavailable: %s", exc)

    def _detect_counts(self, raw_frames: dict) -> dict:
        """Return the raw vehicle count per lane for this tick.

        Detection runs for every lane when ``frame_count`` is a multiple of
        ``skip_frames``. On the other ticks each lane reuses its own last
        count, except a lane that has no cached count yet, which is detected
        immediately. A lane without a frame counts as 0.
        """
        run_detection = self.metrics['frame_count'] % self.skip_frames == 0
        counts = {}
        for lane, frame in raw_frames.items():
            if frame is None:
                counts[lane] = 0
            elif run_detection or lane not in self._last_raw_counts:
                count = self.detector.detect(frame, self.lane_regions.get(lane))
                self._last_raw_counts[lane] = count
                counts[lane] = count
            else:
                counts[lane] = self._last_raw_counts[lane]
        return counts

    def process_frame(self) -> dict:
        """
        Process one frame from all lanes
        Returns: Decision info with keys 'phase', 'queues', 'votes',
        'raw_counts', 'metrics' and 'total_time_ms'.
        """
        start_time = time.time()
        # 1. Get frames from all lanes
        raw_frames = self.stream_reader.get_frames()
        # 2. Detect vehicles (lanes are processed one after another)
        detection_start = time.time()
        raw_counts = self._detect_counts(raw_frames)
        self.metrics['detection_time'] = time.time() - detection_start
        # 3. Smooth counts with rolling average
        smoothed_counts = self.queue_tracker.update(raw_counts)
        self.metrics['current_queues'] = smoothed_counts.copy()
        # 4. Get MARL agent decisions
        decision_start = time.time()
        agent_votes = self._get_agent_votes(smoothed_counts)
        self.metrics['agent_votes'] = agent_votes.copy()
        # 5. Get intersection manager decision
        phase = self._get_phase_decision(agent_votes, smoothed_counts)
        # 6. Check stability (min green time); a rejected switch leaves the
        #    previously recorded phase in place
        if self.stabilizer.should_switch(phase):
            self.metrics['current_phase'] = phase
        self.metrics['decision_time'] = time.time() - decision_start
        self.metrics['frame_count'] += 1
        return {
            'phase': self.metrics['current_phase'],
            'queues': smoothed_counts,
            'votes': agent_votes,
            'raw_counts': raw_counts,
            'metrics': self.metrics.copy(),
            'total_time_ms': (time.time() - start_time) * 1000,
        }

    def _heuristic_votes(self, queues: dict) -> dict:
        """Fallback vote: request green (1) when a lane's queue exceeds the threshold."""
        return {
            lane: 1 if queues[lane] > self.QUEUE_VOTE_THRESHOLD else 0
            for lane in queues
        }

    def _get_agent_votes(self, queues: dict) -> dict:
        """Get votes from 4 MARL agents.

        Falls back to the queue-length heuristic when no agents or simulation
        are loaded, for lanes without an agent, and for all lanes if agent
        inference raises.
        """
        if not self.agents or not self.sim:
            # Fallback: heuristic (request green if queue > 5)
            return self._heuristic_votes(queues)
        votes = {'N': 0, 'S': 0, 'E': 0, 'W': 0}
        try:
            from traffic_env import AgentEnv
            self.sim.queues = queues
            for lane in ['N', 'S', 'E', 'W']:
                if lane in self.agents:
                    env = AgentEnv(lane=lane, sim=self.sim)
                    obs = env.get_obs()
                    action, _ = self.agents[lane].predict(obs, deterministic=True)
                    votes[lane] = int(action)
                else:
                    votes[lane] = 1 if queues[lane] > self.QUEUE_VOTE_THRESHOLD else 0
        except Exception as exc:
            logging.getLogger(__name__).warning(
                "Agent inference failed; using heuristic votes: %s", exc)
            votes = self._heuristic_votes(queues)
        return votes

    def _get_phase_decision(self, votes: dict, queues: dict) -> str:
        """Get phase decision from intersection manager.

        If ``traffic_env.intersection_manager`` cannot be imported or raises,
        the requesting lane with the longest queue decides: N/S gives
        "NS_GREEN", E/W gives "EW_GREEN"; with no requests, "NS_GREEN".
        """
        try:
            from traffic_env import intersection_manager
            phase = intersection_manager(votes, queues, False, None)
            return phase if phase else "NS_GREEN"
        except Exception as exc:
            logging.getLogger(__name__).debug(
                "intersection_manager unavailable; using max-pressure rule: %s", exc)
        # Fallback: max-pressure rule
        requesting = [lane for lane, vote in votes.items() if vote == 1]
        if not requesting:
            return "NS_GREEN"
        chosen = max(requesting, key=lambda lane: queues[lane])
        return "NS_GREEN" if chosen in ('N', 'S') else "EW_GREEN"

    def start(self):
        """Start the engine"""
        self.stream_reader.start()

    def stop(self):
        """Stop the engine"""
        self.stream_reader.stop()

    def get_metrics(self) -> dict:
        """Get current system metrics (shallow copy)"""
        return self.metrics.copy()
# ─────────────────────────────────────────────────────────────────
# Test / Standalone Usage
# ─────────────────────────────────────────────────────────────────
def main():
    """Run the engine on four local test videos for 10 seconds.

    Prints the queues, votes, phase and per-call time every tenth processed
    frame. The engine is always stopped, even if processing raises or the run
    is interrupted.
    """
    # NOTE(review): the leading characters in the status messages below look
    # like mis-encoded symbols; they are kept as-is because they are runtime
    # output — confirm the intended characters before changing them.
    print("π¦ Real-Time Traffic Control Engine")
    print("=" * 50)
    video_paths = {
        'N': 'test_video_N.mp4',
        'S': 'test_video_S.mp4',
        'E': 'test_video_E.mp4',
        'W': 'test_video_W.mp4',
    }
    # Initialize engine
    engine = RealtimeDecisionEngine(video_paths, skip_frames=2, min_green_time=3.0)
    engine.start()
    try:
        print("β Engine started. Processing frames...")
        print("=" * 50)
        # Run for 10 seconds
        start = time.time()
        frame_count = 0
        while time.time() - start < 10:
            result = engine.process_frame()
            frame_count += 1
            if frame_count % 10 == 0:  # Print every 10 frames
                print(f"\nFrame {result['metrics']['frame_count']}")
                print(f" Queues: N={result['queues']['N']} S={result['queues']['S']} E={result['queues']['E']} W={result['queues']['W']}")
                print(f" Votes: N={result['votes']['N']} S={result['votes']['S']} E={result['votes']['E']} W={result['votes']['W']}")
                print(f" Phase: {result['phase']}")
                print(f" Time: {result['total_time_ms']:.1f}ms")
            time.sleep(0.05)  # 20 FPS
    finally:
        # Reader threads must be stopped on every exit path
        engine.stop()
    print("\nβ Engine stopped")


if __name__ == "__main__":
    main()