"""
Vehicle Counter & Speed Measurement App
========================================

Upload a traffic video, place a PRIMARY counting line (counts vehicles by
class + direction, exactly like the original version), and a SECONDARY line
that exists purely to record a second timestamp per vehicle so we can compute
speed = distance / time-between-crossings.

The secondary line never counts anything on its own — it only complements the
primary line with timing data. This keeps counting logic identical to what
worked before.

Two detection models are available:
  - General (COCO): fast, widely accurate, but has no "rickshaw"/"CNG" class.
  - Bangladesh-specific (BNVD): a third-party pretrained YOLOv8 model trained
    on Bangladeshi traffic, with native Rickshaw and CNG classes among others.
    Source: https://github.com/bipin-saha/BNVD
    (Saha et al., 2024, arXiv:2405.12150). Downloaded once and cached locally
    on first use.

Uses YOLOv8 detection + ByteTrack tracking.

Run with:
    streamlit run app.py
"""

import os
import tempfile
import time
import urllib.request

import cv2
import numpy as np
import pandas as pd
import streamlit as st
from ultralytics import YOLO

# Configure the page before any other Streamlit output is produced.
st.set_page_config(page_title="Vehicle Counter & Speed", layout="wide")

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Weights fetched from a URL (see MODEL_CONFIGS) are cached next to this file.
MODELS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models_cache")
os.makedirs(MODELS_DIR, exist_ok=True)

# Each model config defines its own class id -> name mapping, since the
# Bangladesh-specific model uses entirely different classes/ids than COCO.
MODEL_CONFIGS = {
    "General - Fastest (COCO)": {
        "weights": "yolov8n.pt",
        "source": "auto",  # ultralytics downloads + caches this automatically
        "classes": {1: "bicycle", 2: "car", 3: "motorcycle", 5: "bus", 7: "truck"},
    },
    "General - Balanced (COCO)": {
        "weights": "yolov8s.pt",
        "source": "auto",
        "classes": {1: "bicycle", 2: "car", 3: "motorcycle", 5: "bus", 7: "truck"},
    },
    "General - Most accurate (COCO)": {
        "weights": "yolov8m.pt",
        "source": "auto",
        "classes": {1: "bicycle", 2: "car", 3: "motorcycle", 5: "bus", 7: "truck"},
    },
    "Bangladesh-specific (includes Rickshaw & CNG)": {
        "weights": "bnvd_yolov8.pt",
        "source": "url",
        "url": "https://github.com/bipin-saha/BNVD/raw/main/Cheakpoints/YOLO%20V8/weights/yolov8_ods_new_100e_best.pt",
        "classes": {
            0: "Bicycle",
            1: "Bus",
            2: "Bhotbhoti",
            3: "Car",
            4: "CNG",
            5: "Easybike",
            6: "Leguna",
            7: "Motorbike",
            8: "MPV",
            9: "Pedestrian",
            10: "Pickup",
            11: "PowerTiller",
            12: "Rickshaw",
            13: "ShoppingVan",
            14: "Truck",
            15: "Van",
            16: "Wheelbarrow",
        },
    },
}


# ---------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------

@st.cache_resource(show_spinner=False)
def load_model(model_label: str) -> YOLO:
    """Load (and cache) a YOLO model, downloading its weights on first use.

    Args:
        model_label: A key of ``MODEL_CONFIGS``.

    Returns:
        The loaded ``YOLO`` model, cached by ``st.cache_resource`` so it is
        reused across reruns.

    NOTE(review): the cached instance is shared by every session of the
    server; concurrent ``track`` calls on it from multiple users are not
    coordinated here.
    """
    config = MODEL_CONFIGS[model_label]
    if config["source"] != "url":
        # ultralytics resolves (and downloads if needed) official weights by name.
        return YOLO(config["weights"])

    local_path = os.path.join(MODELS_DIR, config["weights"])
    if not os.path.exists(local_path):
        # Download to a side file and atomically move it into place, so an
        # interrupted download never leaves a truncated file that a later run
        # would mistake for a valid cached copy.
        tmp_path = local_path + ".part"
        req = urllib.request.Request(config["url"], headers={"User-Agent": "Mozilla/5.0"})
        try:
            with urllib.request.urlopen(req, timeout=180) as resp, open(tmp_path, "wb") as f:
                while chunk := resp.read(1 << 20):
                    f.write(chunk)
            os.replace(tmp_path, local_path)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    return YOLO(local_path)


def get_side(px, py, x1, y1, x2, y2):
    """
    Return True/False depending on which side of the line (x1,y1)-(x2,y2)
    the point (px,py) is on.

    Uses the sign of the 2-D cross product (B - A) x (P - A); a point lying
    exactly on the line (cross product == 0) is reported as False.
    """
    val = (x2 - x1) * (py - y1) - (y2 - y1) * (px - x1)
    return val > 0


def line_orientation_labels(x1, y1, x2, y2):
    """
    Decide human-friendly direction names based on the orientation of a line,
    and figure out which "side" corresponds to which name.

    Mostly-horizontal lines get "Downward"/"Upward"; mostly-vertical lines get
    "Rightward"/"Leftward". The reference side is the side containing a point
    just below (or just right of) point A.

    Returns: (ref_side, dir_to_ref, dir_from_ref)
    """
    dx = x2 - x1
    dy = y2 - y1
    if abs(dx) >= abs(dy):
        ref_side = get_side(x1, y1 + 10, x1, y1, x2, y2)
        return ref_side, "Downward", "Upward"
    else:
        ref_side = get_side(x1 + 10, y1, x1, y1, x2, y2)
        return ref_side, "Rightward", "Leftward"


def draw_direction_arrows(preview, x1, y1, x2, y2, ref_side, dir_to_ref, dir_from_ref):
    """Draw small perpendicular direction arrows at the midpoint of a line.

    The arrow pointing into ``ref_side`` is green and labelled ``dir_to_ref``;
    the opposite arrow is magenta and labelled ``dir_from_ref``. Draws in
    place on ``preview``; does nothing for a zero-length line.
    """
    mx, my = (x1 + x2) // 2, (y1 + y2) // 2
    perp_dx, perp_dy = -(y2 - y1), (x2 - x1)
    norm = (perp_dx ** 2 + perp_dy ** 2) ** 0.5
    if norm == 0:
        return
    ux, uy = perp_dx / norm, perp_dy / norm
    arrow_len = max(30, int(0.08 * max(preview.shape[:2])))

    p1 = (int(mx + ux * arrow_len), int(my + uy * arrow_len))
    p2 = (int(mx - ux * arrow_len), int(my - uy * arrow_len))

    if get_side(p1[0], p1[1], x1, y1, x2, y2) == ref_side:
        to_ref_pt, from_ref_pt = p1, p2
    else:
        to_ref_pt, from_ref_pt = p2, p1

    cv2.arrowedLine(preview, (mx, my), to_ref_pt, (0, 200, 0), 2, tipLength=0.4)
    cv2.putText(preview, dir_to_ref, (to_ref_pt[0] + 5, to_ref_pt[1] + 5),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 200, 0), 2)
    cv2.arrowedLine(preview, (mx, my), from_ref_pt, (255, 0, 255), 2, tipLength=0.4)
    cv2.putText(preview, dir_from_ref, (from_ref_pt[0] + 5, from_ref_pt[1] + 5),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 255), 2)


def draw_lines_preview(frame_bgr, px1, py1, px2, py2, sx1, sy1, sx2, sy2):
    """
    Draw the PRIMARY line (red, with direction arrows — this is what counts
    vehicles) and the SECONDARY line (orange, no arrows — timing only) on a
    copy of the frame.

    Returns: (preview_bgr, dir_to_ref, dir_from_ref)
    """
    preview = frame_bgr.copy()
    ref_side, dir_to_ref, dir_from_ref = line_orientation_labels(px1, py1, px2, py2)

    # Primary line
    cv2.line(preview, (px1, py1), (px2, py2), (0, 0, 255), 3)
    cv2.circle(preview, (px1, py1), 6, (0, 0, 255), -1)
    cv2.circle(preview, (px2, py2), 6, (255, 0, 0), -1)
    label_y = min(py1, py2) - 12
    cv2.putText(preview, "PRIMARY - counts vehicles", (min(px1, px2), max(label_y, 15)),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
    draw_direction_arrows(preview, px1, py1, px2, py2, ref_side, dir_to_ref, dir_from_ref)

    # Secondary line (timing reference only, no arrows/counting)
    cv2.line(preview, (sx1, sy1), (sx2, sy2), (0, 165, 255), 3)
    cv2.circle(preview, (sx1, sy1), 6, (0, 165, 255), -1)
    cv2.circle(preview, (sx2, sy2), 6, (0, 165, 255), -1)
    label_y2 = min(sy1, sy2) - 12
    cv2.putText(preview, "SECONDARY - timing only", (min(sx1, sx2), max(label_y2, 15)),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 165, 255), 2)

    return preview, dir_to_ref, dir_from_ref


def _remove_file(path):
    """Best-effort delete of a temp file this app created; ignores failures."""
    if not path:
        return
    try:
        os.remove(path)
    except OSError:
        pass  # already gone or locked; leaking a temp file is harmless here


def _build_summary_df(counts, dir_to_ref, dir_from_ref):
    """Build the per-class count table (one column per direction + Total).

    ``counts`` maps class name -> {direction name: count}. A TOTAL row is
    appended when at least one class is present.
    """
    rows = [
        {
            "Vehicle type": cls_name,
            dir_to_ref: d[dir_to_ref],
            dir_from_ref: d[dir_from_ref],
            "Total": d[dir_to_ref] + d[dir_from_ref],
        }
        for cls_name, d in counts.items()
    ]
    summary_df = pd.DataFrame(rows)
    if not summary_df.empty:
        total_row = {
            "Vehicle type": "TOTAL",
            dir_to_ref: summary_df[dir_to_ref].sum(),
            dir_from_ref: summary_df[dir_from_ref].sum(),
            "Total": summary_df["Total"].sum(),
        }
        summary_df = pd.concat([summary_df, pd.DataFrame([total_row])], ignore_index=True)
    return summary_df


def _build_events_df(track_state, fps):
    """Build the per-vehicle crossing log: one row per COUNTED track.

    Counting is decided solely by the primary line; timing and speed are
    filled in only if/when the secondary line was also crossed.
    """
    events = []
    for tid, state in track_state.items():
        if not state["counted"]:
            continue
        time_sec = None
        if state["primary_frame"] is not None and state["secondary_frame"] is not None:
            time_sec = round(abs(state["secondary_frame"] - state["primary_frame"]) / fps, 2)
        events.append({
            "track_id": tid,
            "class": state["class"],
            "direction": state["direction"],
            "primary_cross_frame": state["primary_frame"],
            "secondary_cross_frame": state["secondary_frame"],
            "time_between_lines_sec": time_sec,
            "speed_kmh": round(state["speed_kmh"], 2) if state["speed_kmh"] is not None else None,
        })
    return pd.DataFrame(events)


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------

st.title("🚗 Vehicle Counter & Speed Measurement")
st.write(
    "Upload a traffic video. The **primary line** counts vehicles by type "
    "and direction (exactly like a normal counting line). The **secondary "
    "line** is only used to record a second timestamp per vehicle, so the "
    "app can calculate speed from the time between the two crossings."
)

# --- Sidebar settings -------------------------------------------------------
st.sidebar.header("Detection settings")
model_label = st.sidebar.selectbox("Model", list(MODEL_CONFIGS.keys()), index=0)
model_config = MODEL_CONFIGS[model_label]
classes_dict = model_config["classes"]  # id -> name, specific to this model

confidence = st.sidebar.slider("Confidence threshold", 0.1, 0.9, 0.3, 0.05)

default_classes = [name for name in classes_dict.values() if name != "Pedestrian"]
selected_classes = st.sidebar.multiselect(
    "Vehicle types to count",
    options=list(classes_dict.values()),
    default=default_classes,
    key=f"vehicle_select_{model_label}",
)

st.sidebar.markdown("---")
if model_label.startswith("Bangladesh"):
    st.sidebar.caption(
        "ℹ️ This model has native **Rickshaw** and **CNG** classes, trained on "
        "Bangladeshi traffic photos (mAP@0.5 ≈ 85% in the original benchmark). "
        "It's a third-party research model "
        "[source & paper](https://github.com/bipin-saha/BNVD). The first time "
        "you use it, it downloads a ~50MB file (one-time, then cached). It also "
        "runs somewhat slower per frame than the general 'Fastest' model."
    )
else:
    st.sidebar.caption(
        "ℹ️ This general model doesn't have a dedicated **rickshaw** or **CNG** "
        "class. Auto-rickshaws/CNGs are usually detected as **car** or "
        "**motorcycle**, cycle-rickshaws as **bicycle**. Switch to the "
        "**'Bangladesh-specific'** model above for native rickshaw/CNG detection."
    )

# --- File upload -------------------------------------------------------------
uploaded = st.file_uploader("1. Upload a video", type=["mp4", "mov", "avi", "mkv"])

if uploaded is not None:
    # Identify the upload by name *and* size so that uploading a different
    # file which happens to share a name still replaces the cached copy.
    upload_key = (uploaded.name, uploaded.size)
    if st.session_state.get("video_key") != upload_key:
        # Drop the previous upload's temp copy and any result (including its
        # annotated output video) that belonged to it.
        _remove_file(st.session_state.get("video_path"))
        stale_result = st.session_state.pop("result", None)
        if stale_result is not None:
            _remove_file(stale_result.get("out_path"))

        suffix = os.path.splitext(uploaded.name)[1] or ".mp4"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
            # getbuffer() returns the whole upload regardless of read position.
            tfile.write(uploaded.getbuffer())
        st.session_state.video_path = tfile.name
        st.session_state.video_name = uploaded.name
        st.session_state.video_key = upload_key

    video_path = st.session_state.video_path

    cap = cv2.VideoCapture(video_path)
    try:
        ret, first_frame = cap.read()
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report a negative/unknown frame count.
        total_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0)
    finally:
        cap.release()

    # Containers can report 0 or NaN fps; `not (fps > 0)` is True for both.
    if not (fps > 0):
        fps = 25.0

    if not ret or first_frame is None:
        st.error("Could not read this video file. Please try a different file.")
        st.stop()

    # Take the size from an actually decoded frame rather than container
    # metadata: the output VideoWriter below is opened with this size, and the
    # frames written to it come from the same decoder.
    height, width = first_frame.shape[:2]

    st.success(
        f"Loaded **{uploaded.name}** — {width}x{height}px, "
        f"{total_frames} frames, ~{fps:.1f} fps"
    )

    # --- Step 2: primary counting line --------------------------------------
    st.subheader("2. Set the primary counting line")
    st.write(
        "Move the sliders so the **red line** crosses the road where you "
        "want vehicles counted. This line determines the class + direction "
        "counts, exactly as before."
    )

    default_x1, default_y1 = int(width * 0.05), int(height * 0.45)
    default_x2, default_y2 = int(width * 0.95), int(height * 0.45)

    col_a, col_b = st.columns(2)
    with col_a:
        st.markdown("**Primary — Point A**")
        x1 = st.slider("Primary A — X", 0, max(width - 1, 1), min(default_x1, max(width - 1, 0)), key="px1")
        y1 = st.slider("Primary A — Y", 0, max(height - 1, 1), min(default_y1, max(height - 1, 0)), key="py1")
    with col_b:
        st.markdown("**Primary — Point B**")
        x2 = st.slider("Primary B — X", 0, max(width - 1, 1), min(default_x2, max(width - 1, 0)), key="px2")
        y2 = st.slider("Primary B — Y", 0, max(height - 1, 1), min(default_y2, max(height - 1, 0)), key="py2")

    # --- Step 3: secondary timing line --------------------------------------
    st.subheader("3. Set the secondary timing line")
    st.write(
        "Move these sliders so the **orange line** sits a known distance "
        "away from the primary line, further along the direction of travel. "
        "It does not count vehicles — it only records a second timestamp "
        "per vehicle, used together with the primary line to calculate speed."
    )

    default_sx1, default_sy1 = int(width * 0.05), int(height * 0.65)
    default_sx2, default_sy2 = int(width * 0.95), int(height * 0.65)

    col_c, col_d = st.columns(2)
    with col_c:
        st.markdown("**Secondary — Point A**")
        sx1 = st.slider("Secondary A — X", 0, max(width - 1, 1), min(default_sx1, max(width - 1, 0)), key="sx1")
        sy1 = st.slider("Secondary A — Y", 0, max(height - 1, 1), min(default_sy1, max(height - 1, 0)), key="sy1")
    with col_d:
        st.markdown("**Secondary — Point B**")
        sx2 = st.slider("Secondary B — X", 0, max(width - 1, 1), min(default_sx2, max(width - 1, 0)), key="sx2")
        sy2 = st.slider("Secondary B — Y", 0, max(height - 1, 1), min(default_sy2, max(height - 1, 0)), key="sy2")

    preview_bgr, dir_to_ref, dir_from_ref = draw_lines_preview(
        first_frame, x1, y1, x2, y2, sx1, sy1, sx2, sy2
    )
    preview_rgb = cv2.cvtColor(preview_bgr, cv2.COLOR_BGR2RGB)
    st.image(preview_rgb, caption="Preview: red = primary (counts), orange = secondary (timing only)", width=1000)

    line_is_degenerate = (x1 == x2 and y1 == y2)
    if line_is_degenerate:
        st.warning("Primary line's Point A and Point B are the same — please move one of them.")

    # A zero-length secondary line can never be "crossed" (get_side is always
    # False), so counting still works but no speeds would be produced.
    secondary_is_degenerate = (sx1 == sx2 and sy1 == sy2)
    if secondary_is_degenerate:
        st.warning(
            "Secondary line's Point A and Point B are the same — speeds cannot "
            "be measured until you move one of them."
        )

    st.caption(
        f"Vehicles crossing the **primary** line toward the green arrow will be "
        f"counted as **\"{dir_to_ref}\"**, and toward the magenta arrow as **\"{dir_from_ref}\"**."
    )

    # --- Step 4: distance calibration ---------------------------------------
    st.subheader("4. Enter the real-world distance between the lines")
    distance_meters = st.number_input(
        "Distance between primary and secondary line (meters)",
        min_value=0.1, max_value=1000.0, value=10.0, step=0.1
    )

    # --- Step 5: process -----------------------------------------------------
    st.subheader("5. Process the video")

    if not selected_classes:
        st.warning("Select at least one vehicle type in the sidebar to count.")

    process_disabled = line_is_degenerate or not selected_classes

    if st.button("▶ Start counting & measuring speed", type="primary", disabled=process_disabled):
        ref_side, dir_to_ref, dir_from_ref = line_orientation_labels(x1, y1, x2, y2)
        class_ids = [cid for cid, name in classes_dict.items() if name in selected_classes]
        selected_set = set(selected_classes)

        already_cached = (
            model_config["source"] != "url"
            or os.path.exists(os.path.join(MODELS_DIR, model_config["weights"]))
        )
        spinner_msg = (
            "Loading model..." if already_cached
            else "Downloading Bangladesh-specific model (~50MB, one-time only)..."
        )
        try:
            with st.spinner(spinner_msg):
                model = load_model(model_label)
        except Exception as e:  # UI boundary: network/IO/weights errors
            st.error(f"Could not load the model: {e}")
            st.stop()

        fd, out_path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)  # VideoWriter opens the path itself
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        if not writer.isOpened():
            writer.release()
            _remove_file(out_path)
            st.error("Could not create the output video file.")
            st.stop()

        counts = {name: {dir_to_ref: 0, dir_from_ref: 0} for name in selected_classes}
        # track_id -> state dict
        track_state = {}

        progress_bar = st.progress(0.0)
        status_text = st.empty()
        frame_idx = 0
        start_time = time.time()

        # One streaming call over the whole video keeps the tracker alive
        # across frames. `persist` is deliberately NOT set: the model is cached
        # across reruns, and persist=True would carry tracker state over from
        # a previous run/video into this one.
        results_gen = model.track(
            source=video_path,
            stream=True,
            classes=class_ids,
            conf=confidence,
            tracker="bytetrack.yaml",
            verbose=False,
        )

        try:
            for frame_result in results_gen:
                frame_img = frame_result.orig_img.copy()
                cv2.line(frame_img, (x1, y1), (x2, y2), (0, 0, 255), 3)
                cv2.line(frame_img, (sx1, sy1), (sx2, sy2), (0, 165, 255), 3)

                boxes = frame_result.boxes
                if boxes is not None and boxes.id is not None:
                    ids = boxes.id.int().tolist()
                    clss = boxes.cls.int().tolist()
                    xyxy = boxes.xyxy.tolist()

                    for tid, cls_id, box in zip(ids, clss, xyxy):
                        cls_name = classes_dict.get(cls_id)
                        if cls_name not in selected_set:
                            continue

                        bx1, by1, bx2, by2 = box
                        cx, cy = (bx1 + bx2) / 2.0, (by1 + by2) / 2.0
                        cur_side_p = get_side(cx, cy, x1, y1, x2, y2)
                        cur_side_s = get_side(cx, cy, sx1, sy1, sx2, sy2)

                        state = track_state.get(tid)
                        if state is None:
                            # First sighting: only remember which side of each line it is on.
                            state = track_state[tid] = {
                                "side_p": cur_side_p,
                                "side_s": cur_side_s,
                                "counted": False,
                                "class": cls_name,
                                "direction": None,
                                "primary_frame": None,
                                "secondary_frame": None,
                                "speed_kmh": None,
                            }
                        else:
                            # --- Primary line: counts the vehicle (same as before) ---
                            if not state["counted"] and state["side_p"] != cur_side_p:
                                direction = dir_to_ref if cur_side_p == ref_side else dir_from_ref
                                counts[cls_name][direction] += 1
                                state["counted"] = True
                                # The per-frame class label can differ between frames;
                                # store the one the vehicle was counted under so the
                                # crossing log agrees with the summary table.
                                state["class"] = cls_name
                                state["direction"] = direction
                                state["primary_frame"] = frame_idx

                            # --- Secondary line: records a timestamp only, never counts ---
                            if state["secondary_frame"] is None and state["side_s"] != cur_side_s:
                                state["secondary_frame"] = frame_idx

                            # --- Once both timestamps exist, compute speed once ---
                            if (
                                state["primary_frame"] is not None
                                and state["secondary_frame"] is not None
                                and state["speed_kmh"] is None
                            ):
                                frame_diff = abs(state["secondary_frame"] - state["primary_frame"])
                                time_sec = frame_diff / fps  # fps is guaranteed > 0 above
                                if time_sec > 0:
                                    speed_mps = distance_meters / time_sec
                                    state["speed_kmh"] = speed_mps * 3.6

                            state["side_p"] = cur_side_p
                            state["side_s"] = cur_side_s

                        # draw box + label
                        label = f"{cls_name} #{tid}"
                        if state["speed_kmh"] is not None:
                            label += f" {state['speed_kmh']:.1f}km/h"
                        cv2.rectangle(frame_img, (int(bx1), int(by1)), (int(bx2), int(by2)), (0, 255, 0), 2)
                        cv2.putText(
                            frame_img, label, (int(bx1), max(int(by1) - 8, 0)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1, cv2.LINE_AA,
                        )

                # overlay running counts
                y_off = 30
                for cls_name, d in counts.items():
                    text = f"{cls_name}: {dir_to_ref}={d[dir_to_ref]}  {dir_from_ref}={d[dir_from_ref]}"
                    cv2.putText(frame_img, text, (10, y_off), cv2.FONT_HERSHEY_SIMPLEX,
                                0.6, (0, 255, 255), 2, cv2.LINE_AA)
                    y_off += 25

                writer.write(frame_img)
                frame_idx += 1

                if total_frames > 0:
                    progress_bar.progress(min(frame_idx / total_frames, 1.0))
                elapsed = time.time() - start_time
                status_text.text(f"Processing frame {frame_idx}/{total_frames or '?'} · {elapsed:.0f}s elapsed")
        finally:
            # Always finalize the output file, even if tracking raised.
            writer.release()

        progress_bar.progress(1.0)
        status_text.text("✅ Done!")

        previous_result = st.session_state.get("result")
        st.session_state.result = {
            "out_path": out_path,
            "summary_df": _build_summary_df(counts, dir_to_ref, dir_from_ref),
            "events_df": _build_events_df(track_state, fps),
        }
        # The previous run's annotated video is no longer reachable; delete it.
        if previous_result is not None and previous_result.get("out_path") != out_path:
            _remove_file(previous_result.get("out_path"))

    # --- Show results ---------------------------------------------------------
    if "result" in st.session_state:
        st.subheader("6. Results")
        result = st.session_state.result

        st.markdown("**Counts by vehicle type and direction**")
        st.dataframe(result["summary_df"], width="stretch", hide_index=True)

        csv_summary = result["summary_df"].to_csv(index=False).encode("utf-8")
        st.download_button(
            "⬇ Download summary (CSV)", csv_summary,
            file_name="vehicle_counts_summary.csv", mime="text/csv"
        )

        if not result["events_df"].empty:
            with st.expander("Show detailed crossing log with speeds (one row per counted vehicle)"):
                st.dataframe(result["events_df"], width="stretch", hide_index=True)
                st.caption(
                    "speed_kmh is blank for vehicles that crossed the primary line "
                    "but never reached the secondary line (e.g. they turned off, "
                    "or the video ended before they got there)."
                )
            csv_events = result["events_df"].to_csv(index=False).encode("utf-8")
            st.download_button(
                "⬇ Download crossing log (CSV)", csv_events,
                file_name="vehicle_crossing_log.csv", mime="text/csv"
            )

        st.markdown("**Annotated video**")
        st.caption(
            "If the preview below doesn't play in your browser, download the "
            "file and open it with VLC or another media player."
        )
        try:
            with open(result["out_path"], "rb") as f:
                video_bytes = f.read()
            st.video(video_bytes)
            st.download_button(
                "⬇ Download annotated video", video_bytes,
                file_name="vehicle_counter_output.mp4", mime="video/mp4"
            )
        except OSError as e:
            st.error(f"Could not load output video: {e}")
else:
    st.info("Upload a video to get started.")