Upload 2 files
Browse files- lane_detection.py +282 -0
- streamlit_app.py +280 -0
lane_detection.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import time
|
| 4 |
+
from ultralytics import YOLO
|
| 5 |
+
import torch
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from typing import Dict, List, Tuple
|
| 9 |
+
|
| 10 |
+
# Maps the detector's integer class ids to human-readable labels.
# The labels are also used to build the per-region DataFrame column names
# (`<label>_<region>`), so downstream code keys on these exact strings.
# NOTE: an earlier 7-class scheme used
# (auto, bus, car, motorcycle, mini-bus, scooter, truck) for ids 0-6.
LABEL_MAP = {
    0: "auto",
    1: "bus",
    2: "car",
    3: "electric-rickshaw",
    4: "large-sized-truck",
    5: "medium-sized-truck",
    6: "motorbike",
    7: "small-sized-truck",
}
|
| 20 |
+
|
| 21 |
+
def draw_text_with_background(
    image,
    text,
    position,
    font=cv2.FONT_HERSHEY_SIMPLEX,
    font_scale=1,
    font_thickness=2,
    text_color=(255, 255, 255),
    bg_color=(0, 0, 0),
    padding=5,
):
    """Draw `text` on `image` with a filled rectangle behind it.

    The image is modified in place; nothing is returned.

    Args:
        image: Image array to draw on.
        text: String to render.
        position: ``(x, y)`` anchor; ``x`` is the left edge of the
            background box and ``y`` is the reference line the text and box
            are positioned from.
        font: OpenCV font face.
        font_scale: Font scale factor passed to OpenCV.
        font_thickness: Stroke thickness of the glyphs.
        text_color: Colour of the text.
        bg_color: Fill colour of the background rectangle.
        padding: Extra pixels added around the text inside the rectangle.
    """
    (text_width, text_height), baseline = cv2.getTextSize(
        text, font, font_scale, font_thickness
    )

    # Callers may pass NumPy integer scalars (e.g. a polygon vertex); plain
    # Python ints are accepted by every OpenCV binding version.
    x, y = (int(v) for v in position)

    half_baseline = baseline // 2
    top = y - text_height - padding - half_baseline
    bottom = y + padding - half_baseline

    cv2.rectangle(
        image,
        (x, top),
        (x + text_width + 2 * padding, bottom),
        bg_color,
        -1,  # negative thickness => filled rectangle
    )
    cv2.putText(
        image,
        text,
        (x + padding, y - half_baseline),
        font,
        font_scale,
        text_color,
        font_thickness,
        cv2.LINE_AA,
    )
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_color_for_class(cls_id: int):
    """Return a deterministic bright colour for a class index.

    Args:
        cls_id: Integer class id; the same id always yields the same colour.

    Returns:
        A 3-tuple of Python ints, each in the range 100..255.
    """
    # A private RandomState produces the same values as seeding the global
    # generator with the same seed, but leaves the process-wide NumPy RNG
    # untouched.
    rng = np.random.RandomState(cls_id + 37)
    return tuple(rng.randint(100, 256, size=3).tolist())
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _inside(pt: Tuple[int, int], poly: np.ndarray) -> bool:
    """Return True when `pt` lies inside `poly` or exactly on its edge.

    Args:
        pt: ``(x, y)`` point to test.
        poly: Polygon vertices as accepted by ``cv2.pointPolygonTest``.
    """
    # With measureDist=False the result is +1 (inside), 0 (on edge) or
    # -1 (outside). The point is passed as floats so NumPy scalars and ints
    # are handled alike.
    point = (float(pt[0]), float(pt[1]))
    return bool(cv2.pointPolygonTest(poly, point, False) >= 0)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class YOLOVideoDetector:
    """
    Detect objects on a video and count them **per region**.

    * `regions`: Dict[int, List[Tuple[int, int]]], mapping region id (0, 1, ...)
      to 4+ vertices in *pixel* coordinates (clockwise or anticlockwise).
      Coordinates refer to the frame as it is processed: after the optional
      portrait rotation and before upscaling by `scale_factor`.
      Regions with an empty vertex list are ignored.
    * For each processed frame, counts are stored in a DataFrame column named
      `<label>_<region>` (e.g. `car_0`, `bus_1`).
    """

    def __init__(
        self,
        model_path: str,
        video_path: str,
        output_path: str,
        regions: Dict[int, List[Tuple[int, int]]],
        classes=None,
        conf: float = 0.35,
        scale_factor: float = 1.5,
    ):
        """Load the model and validate the regions.

        Args:
            model_path: Path to the YOLO weights file.
            video_path: Path of the input video.
            output_path: Path the annotated video is written to.
            regions: Region id -> polygon vertices (see class docstring).
            classes: Optional class-id filter forwarded to the model.
            conf: Confidence threshold forwarded to the model.
            scale_factor: Factor by which frames (and regions) are upscaled
                before inference and annotation.

        Raises:
            ValueError: If no non-empty region is supplied.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        self.model = YOLO(model_path)
        self.video_path = video_path
        self.output_path = output_path
        self.conf = conf
        self.classes = classes
        self.scale = scale_factor

        self.regions = {
            rid: np.array(pts, np.int32) for rid, pts in regions.items() if pts
        }
        if not self.regions:
            raise ValueError("`regions` cannot be empty — provide at least one polygon.")

        # Prepare DataFrame columns once: one column per (label, region) pair.
        self.df_columns = [
            "Frame Number",
            *[
                f"{LABEL_MAP[c]}_{rid}"
                for rid in self.regions
                for c in LABEL_MAP
            ],
        ]

    def process_video(self) -> pd.DataFrame:
        """Run detection over the video and write the annotated output.

        The first frame is always processed; afterwards only every
        ``fps // 2``-th frame (roughly two per second) is processed and
        written. Portrait videos are rotated 90 degrees clockwise first.

        Returns:
            One row per processed frame with its 1-based "Frame Number" and
            the per-region class counts; missing counts are filled with 0.

        Raises:
            ValueError: If the video cannot be opened or its first frame
                cannot be read.
        """
        cap = cv2.VideoCapture(self.video_path)
        out = None
        try:
            if not cap.isOpened():
                raise ValueError(f"Cannot open video: {self.video_path}")

            ok, frame = cap.read()
            if not ok:
                raise ValueError(f"Cannot read first frame from: {self.video_path}")

            h_orig, w_orig = frame.shape[:2]
            rotate = w_orig < h_orig
            if rotate:
                print(
                    f"Original frame (h,w): ({h_orig}, {w_orig}). Portrait → rotating 90° CW."
                )
                frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)
            else:
                print(f"Original frame (h,w): ({h_orig}, {w_orig}). Processing as landscape.")

            base_h, base_w = frame.shape[:2]

            fps = cap.get(cv2.CAP_PROP_FPS)
            if not np.isfinite(fps) or fps <= 0:
                fps = 30.0  # container did not report a usable frame rate
            step = max(int(fps // 2), 1)  # frame skipping @ ~2 fps

            out_size = (int(base_w * self.scale), int(base_h * self.scale))
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            # NOTE(review): the writer uses the source fps although only every
            # `step`-th frame is written, so the output plays back faster than
            # the input. Confirm whether that is intended.
            out = cv2.VideoWriter(self.output_path, fourcc, fps, out_size)
            if not out.isOpened():
                print(f"Warning: could not open video writer for: {self.output_path}")

            prediction_counter_df = pd.DataFrame(columns=self.df_columns)
            prev_t = time.time()

            frame_count = 1
            frame_up = cv2.resize(frame, out_size, interpolation=cv2.INTER_LINEAR)
            prev_t = self._process_and_write_frame(
                frame_up, out, prev_t, prediction_counter_df, frame_count
            )

            while True:
                ok, frame = cap.read()
                if not ok:
                    break

                # Count the frame before deciding to process it so that
                # "Frame Number" is the true 1-based index of the frame.
                frame_count += 1
                if frame_count % step:
                    continue

                if rotate:
                    frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)
                frame_up = cv2.resize(frame, out_size, interpolation=cv2.INTER_LINEAR)
                prev_t = self._process_and_write_frame(
                    frame_up, out, prev_t, prediction_counter_df, frame_count
                )
        finally:
            # Always release native handles, even when processing fails.
            cap.release()
            if out is not None:
                out.release()

        print(f"Processed {frame_count} frames. Finished → {self.output_path}")
        return prediction_counter_df.fillna(0)

    def _process_and_write_frame(
        self,
        frame_up: np.ndarray,
        out_writer: cv2.VideoWriter,
        prev_t: float,
        prediction_counter_df: pd.DataFrame,
        frame_count: int,
    ) -> float:
        """Run YOLO on one frame, count per region, annotate, write, return timestamp.

        Args:
            frame_up: Upscaled frame; annotated in place.
            out_writer: Writer the annotated frame is appended to.
            prev_t: Timestamp returned by the previous call, used for the
                live FPS overlay.
            prediction_counter_df: DataFrame that receives one new row
                (mutated in place).
            frame_count: Frame number recorded in the new row.

        Returns:
            The timestamp taken right before the frame was written.
        """
        # Run inference on the clean frame, before any overlay is drawn, so
        # the annotations cannot influence the detections.
        results = self.model.predict(
            frame_up,
            conf=self.conf,
            classes=self.classes,
            verbose=False,
            device=self.device,
        )

        # Regions are defined in un-scaled coordinates; scale them once per
        # frame (self.scale may be changed between calls).
        polys_up = {
            rid: (poly * self.scale).astype(np.int32)
            for rid, poly in self.regions.items()
        }
        for rid, poly_up in polys_up.items():
            cv2.polylines(frame_up, [poly_up], True, (255, 255, 0), 2)
            anchor = (int(poly_up[0][0]), int(poly_up[0][1]))
            draw_text_with_background(frame_up, f"R{rid}", anchor, font_scale=0.8)

        # counts[region][cls_id] -> int
        counts: Dict[int, Dict[int, int]] = {
            rid: defaultdict(int) for rid in self.regions
        }

        if results and results[0].boxes is not None and len(results[0].boxes):
            boxes = results[0].boxes
            xyxy = boxes.xyxy.cpu().numpy()
            scores = boxes.conf.cpu().numpy()
            cls_ids = boxes.cls.int().cpu().tolist()

            for (x1, y1, x2, y2), score, cls_id in zip(xyxy, scores, cls_ids):
                color = get_color_for_class(cls_id)
                cv2.rectangle(
                    frame_up, (int(x1), int(y1)), (int(x2), int(y2)), color, 2
                )
                label = LABEL_MAP.get(cls_id, f"Class {cls_id}")
                draw_text_with_background(
                    frame_up,
                    f"{label}: {score:.2f}",
                    (int(x1), int(y1) - 10),
                    font_scale=0.6,
                    font_thickness=1,
                    bg_color=color,
                    padding=3,
                )

                # Region assignment based on the *centre* of the box.
                cx, cy = int((x1 + x2) / 2), int((y1 + y2) / 2)
                for rid, poly_up in polys_up.items():
                    if _inside((cx, cy), poly_up):
                        counts[rid][cls_id] += 1
                        break  # one region per detection

        # Overlay per-region counts and append one row to the DataFrame.
        df_idx = len(prediction_counter_df)
        prediction_counter_df.at[df_idx, "Frame Number"] = frame_count

        y_off = 30
        for rid, cls_dict in counts.items():
            for cls_id, cnt in cls_dict.items():
                label = LABEL_MAP.get(cls_id, f"Class {cls_id}")
                prediction_counter_df.at[df_idx, f"{label}_{rid}"] = cnt
                draw_text_with_background(
                    frame_up,
                    f"{label}_{rid}: {cnt}",
                    (10, y_off),
                    font_scale=0.7,
                    font_thickness=2,
                    padding=6,
                )
                y_off += 25

        # Live processing-rate overlay (time since the previous processed frame).
        now = time.time()
        elapsed = now - prev_t
        fps_live = 1.0 / elapsed if elapsed > 0 else 0.0
        draw_text_with_background(
            frame_up,
            f"FPS: {fps_live:.1f}",
            (10, frame_up.shape[0] - 20),
            bg_color=(0, 0, 0),
            text_color=(0, 255, 0),
            font_scale=0.8,
            padding=4,
        )

        out_writer.write(frame_up)
        return now
|
streamlit_app.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# streamlit_app.py
|
| 2 |
+
|
| 3 |
+
import os
from tempfile import NamedTemporaryFile
from typing import Dict, List, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas

from lane_detection import YOLOVideoDetector, LABEL_MAP  # detector + LABEL_MAP from lane_detection.py
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# βββββββββββββββββββββββ Helper Functions βββββββββββββββββββββββ
|
| 19 |
+
|
| 20 |
+
def extract_four_points(js) -> Optional[List[Tuple[int, int]]]:
    """Return exactly four (x, y) clicks from streamlit_drawable_canvas JSON.

    Only objects of type "circle" or "rect" are considered. Each point is
    ``left``/``top`` offset by the object's ``radius`` (0 when absent),
    truncated to int.

    Args:
        js: Canvas JSON (a dict with an "objects" list), or None.

    Returns:
        The list of four points, or None when the JSON is missing or holds
        any number of points other than four.
    """
    if not js or "objects" not in js:
        return None

    pts = []
    for obj in js["objects"] or []:  # tolerate an explicit "objects": None
        if obj.get("type") not in {"circle", "rect"}:
            continue
        radius = obj.get("radius", 0)
        pts.append((int(obj["left"] + radius), int(obj["top"] + radius)))

    return pts if len(pts) == 4 else None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def draw_poly(img: np.ndarray, pts: List[Tuple[int, int]], color: Tuple[int, int, int]):
    """Draw a closed, anti-aliased polygon outline on `img` in place.

    Args:
        img: Image to draw on (modified in place).
        pts: Polygon vertices as (x, y) pixel coordinates.
        color: Line colour, given in the channel order of `img`.
    """
    contour = np.array(pts, np.int32)
    cv2.polylines(img, [contour], True, color, 2, cv2.LINE_AA)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# βββββββββββββββββββββββββ Streamlit App βββββββββββββββββββββββββ
|
| 45 |
+
|
| 46 |
+
st.set_page_config(page_title="🚦 Multi-Lane Congestion Demo", layout="wide")
st.title("🚦 Multi-Lane Vehicle Congestion Demo")


def _rerun() -> None:
    """Restart the script so the next wizard step renders immediately."""
    # `st.rerun` is the current API; `st.experimental_rerun` is its older name.
    rerun = getattr(st, "rerun", None) or getattr(st, "experimental_rerun", None)
    if rerun is not None:
        rerun()


# Initialize session state
if "num_lanes" not in st.session_state:
    st.session_state.num_lanes = None
    st.session_state.current_lane = 0
    st.session_state.lanes = []
    st.session_state.video_path = None
    st.session_state.video_uploaded = False

# ---------------- Step 1: Choose Number of Lanes ----------------
if st.session_state.num_lanes is None:
    n = st.number_input(
        "How many lanes would you like to define? (1–8)",
        min_value=1,
        max_value=8,
        value=2,
    )
    if st.button("✔ Set Number of Lanes"):
        st.session_state.num_lanes = int(n)
        st.session_state.lanes = [None] * st.session_state.num_lanes
        # Without a rerun the page would stop on this step until the user
        # interacts a second time.
        _rerun()
    st.stop()

# ---------------- Step 2: Upload a Video ----------------
if not st.session_state.video_uploaded:
    uploaded = st.file_uploader(
        "Upload video (formats: mp4, avi, mov, mkv)",
        type=["mp4", "avi", "mov", "mkv"],
    )
    if uploaded:
        suffix = os.path.splitext(uploaded.name)[1]
        # Close the handle once written; the file itself is kept
        # (delete=False) so OpenCV can open it by path.
        with NamedTemporaryFile(delete=False, suffix=suffix) as tmpfile:
            tmpfile.write(uploaded.read())
        st.session_state.video_path = tmpfile.name
        st.session_state.video_uploaded = True
    else:
        st.stop()

# ---------------- Step 3: Grab First Frame & Scale ----------------
cap = cv2.VideoCapture(st.session_state.video_path)
ret, first_frame = cap.read()
cap.release()
if not ret:
    st.error("❌ Could not read the first frame of the video.")
    st.stop()

# YOLOVideoDetector rotates portrait videos 90° clockwise and interprets the
# regions in that rotated frame, so lanes must be drawn on the same view.
if first_frame.shape[1] < first_frame.shape[0]:
    first_frame = cv2.rotate(first_frame, cv2.ROTATE_90_CLOCKWISE)

frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
h_orig, w_orig = frame_rgb.shape[:2]

# If the frame is wider than 800 px, scale down
MAX_W = 800
if w_orig > MAX_W:
    scale = MAX_W / w_orig
    disp_w = MAX_W
    disp_h = int(h_orig * scale)
    frame_disp = cv2.resize(frame_rgb, (disp_w, disp_h), interpolation=cv2.INTER_AREA)
else:
    scale = 1.0
    disp_w, disp_h = w_orig, h_orig
    frame_disp = frame_rgb.copy()

# ---------------- Step 4: Draw 4 Points Per Lane ----------------
# Colours are RGB because every preview image below is RGB.
colors = [
    (0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0),
    (255, 0, 255), (0, 255, 255), (128, 255, 0), (255, 128, 0),
]

if st.session_state.current_lane < st.session_state.num_lanes:
    idx = st.session_state.current_lane
    color = colors[idx % len(colors)]
    st.subheader(f"2️⃣ Click exactly 4 points for Lane #{idx+1}")
    st.caption("Draw 4 small circles on the image, then press **Confirm Lane**.")

    canvas = st_canvas(
        fill_color="rgba(0,0,0,0)",
        stroke_width=2,
        # "#RRGGBB" built in RGB order so the canvas stroke matches the
        # polygon colour drawn on the RGB preview.
        stroke_color=f"#{color[0]:02X}{color[1]:02X}{color[2]:02X}",
        background_image=Image.fromarray(frame_disp),
        drawing_mode="point",
        key=f"lane_canvas_{idx}",
        height=disp_h,
        width=disp_w,
        update_streamlit=True,
    )

    pts_scaled = extract_four_points(canvas.json_data)
    if pts_scaled:
        preview = frame_disp.copy()
        draw_poly(preview, pts_scaled, color)
        st.image(preview, caption=f"Preview – Lane {idx+1}", use_column_width=True)

    if st.button(f"Confirm Lane {idx+1}"):
        if pts_scaled:
            # Convert scaled points back to original resolution
            orig_pts = [(int(x / scale), int(y / scale)) for (x, y) in pts_scaled]
            st.session_state.lanes[idx] = orig_pts
            st.session_state.current_lane += 1
            _rerun()  # move on to the next lane (or step 5) right away
        else:
            st.warning("❗ Please click exactly 4 points.")
    st.stop()

# ---------------- Step 5: Display All Lane Polygons ----------------
st.subheader("✅ All lanes defined:")
confirm_img = frame_rgb.copy()
for i, poly in enumerate(st.session_state.lanes):
    c = colors[i % len(colors)]
    draw_poly(confirm_img, poly, c)
    cv2.putText(
        confirm_img, f"L{i+1}", (poly[0][0], poly[0][1] - 10),
        cv2.FONT_HERSHEY_SIMPLEX, 1.0, c, 2, cv2.LINE_AA,
    )
st.image(confirm_img, caption="All lane regions overlaid", use_column_width=True)

# ---------------- Step 6: Input Thresholds & Run Congestion Analysis ----------------
st.subheader("🔧 Congestion Thresholds")
col1, col2 = st.columns(2)
with col1:
    low_thresh = st.number_input(
        "Green if PCE <",
        min_value=0.0,
        max_value=20.0,
        value=3.5,
        step=0.1,
        format="%.1f",
        help="Values below this will be colored green",
    )
with col2:
    high_thresh = st.number_input(
        "Red if PCE >",
        min_value=0.0,
        max_value=20.0,
        value=6.5,
        step=0.1,
        format="%.1f",
        help="Values above this will be colored red",
    )
st.caption("Values between green/red thresholds will be yellow.")

if st.button("🚀 Run Congestion Analysis"):
    # Reserve an output path and close the handle so the detector's video
    # writer can open the file itself.
    with NamedTemporaryFile(delete=False, suffix=".mp4") as out_file:
        out_tmp = out_file.name

    regions: Dict[int, List[Tuple[int, int]]] = {
        i: st.session_state.lanes[i] for i in range(st.session_state.num_lanes)
    }

    detector = YOLOVideoDetector(
        "Weights/last.pt",  # <-- replace with your actual .pt path
        st.session_state.video_path,
        out_tmp,
        regions,
        classes=list(LABEL_MAP),
        conf=0.35,
        scale_factor=1.5,
    )

    with st.spinner("Processing video — this may take a while..."):
        df = detector.process_video()

    st.success("✅ Detection + annotation complete!")

    # ---------------- Compute Per-Lane PCE & Rolling Average ----------------
    # Passenger-car-equivalent weight per vehicle label.
    PCE = {
        "auto": 0.8,
        "bus": 4.0,
        "car": 1.0,
        "electric-rickshaw": 0.8,
        "large-sized-truck": 4.5,
        "medium-sized-truck": 3.5,
        "motorbike": 0.5,
        "small-sized-truck": 3.0,
    }

    for rid in regions:
        def lane_pce(row, rid=rid):
            """Weighted vehicle count of one DataFrame row for lane `rid`."""
            total = 0.0
            for vt, factor in PCE.items():
                total += int(row.get(f"{vt}_{rid}", 0)) * factor
            return total

        df[f"PCE_lane{rid}"] = df.apply(lane_pce, axis=1)
        df[f"PCE_lane{rid}_avg"] = (
            df[f"PCE_lane{rid}"].rolling(window=5, min_periods=1).mean()
        )

    # ---------------- Lane-Wise Subplots with Smooth Line + Colored Markers ----------------
    num_lanes = len(regions)
    # squeeze=False always returns a 2-D array, so a single lane needs no
    # special casing.
    fig, axes_grid = plt.subplots(
        num_lanes, 1, figsize=(10, 3 * num_lanes), sharex=True, squeeze=False
    )
    axes = axes_grid[:, 0]

    for rid, ax in zip(regions, axes):
        x = df["Frame Number"].values
        y = df[f"PCE_lane{rid}_avg"].values

        # 1) Continuous line in gray
        ax.plot(x, y, color="gray", linewidth=1.2)

        # 2) Overlay markers coloured by congestion level
        marker_colors = [
            "green" if yi < low_thresh else "red" if yi > high_thresh else "yellow"
            for yi in y
        ]
        ax.scatter(x, y, c=marker_colors, s=20, edgecolors="black", linewidths=0.3)

        ax.set_title(f"Lane {rid} PCE (rolling average)")
        ax.set_ylabel("PCE")
        ax.grid(alpha=0.3)

    axes[-1].set_xlabel("Frame Number")
    plt.tight_layout()

    st.subheader("📊 Lane-Wise Congestion Plots")
    st.pyplot(fig)
    plt.close(fig)  # release the figure; pyplot keeps it alive across reruns otherwise

    # ---------------- Display Annotated Video & CSV Download ----------------
    st.subheader("🎬 Annotated Output Video")
    # NOTE(review): the detector encodes with the "mp4v" fourcc; confirm that
    # target browsers can play it, otherwise a re-encode may be needed.
    with open(out_tmp, "rb") as f:
        st.video(f.read())

    csv_bytes = df.to_csv(index=False).encode("utf-8")
    st.download_button(
        label="Download full counts + PCE CSV",
        data=csv_bytes,
        file_name="counts_and_pce.csv",
        mime="text/csv",
    )
|