"""
visualize.py
============
All diagnostic plots for the trajectory retrieval system.

BEFORE RUNNING — set your paths at the top of this file:
    PKL_PATH  : your processed trajectory database pkl
    VIDEO_DIR : folder containing your dashcam videos
    OUT_DIR   : where to save the plots

Run:
    python visualize.py

All 10 plots saved as PNG in OUT_DIR.
Plots 1–3 additionally read the video file referenced by each clip;
plots 4–10 only need the PKL database.
"""
import os
import sys
import pickle
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend: must be selected before pyplot is imported
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.gridspec as gridspec
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy.spatial.distance import cosine as scipy_cosine
from dtaidistance import dtw as dtw_lib
from collections import defaultdict
import cv2

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from trajectory_extractor import (
    extract_clip_frames, build_roi_mask, estimate_lambda,
    subtract_forward_motion, remove_outliers_iqr,
    FOE_X, FOE_Y, N_RESAMPLE, ST_PARAMS, LK_PARAMS
)
from retrieval_engine import CANONICAL, score_clip

# ═══════════════════════════════════════════════════════════════════
# ▶▶ SET YOUR PATHS HERE
# ═══════════════════════════════════════════════════════════════════
PKL_PATH = "trajectory_db100.pkl"
VIDEO_DIR = "/media/RTCIN15TBD/AllUsers/sjl3kor/trajectory_retrieval_system/data/videos"
OUT_DIR = "plots"
# ═══════════════════════════════════════════════════════════════════

# ── Global dark style ───────────────────────────────────────────────
plt.rcParams.update({
    "figure.facecolor": "#0f0f1a",
    "axes.facecolor": "#1a1a2e",
    "axes.edgecolor": "#444466",
    "axes.labelcolor": "#ccccdd",
    "axes.titlecolor": "#ffffff",
    "xtick.color": "#aaaacc",
    "ytick.color": "#aaaacc",
    "text.color": "#ccccdd",
    "grid.color": "#2a2a4a",
    "grid.linewidth": 0.6,
    "legend.facecolor": "#1a1a2e",
    "legend.edgecolor": "#444466",
    "font.family": "monospace",
})

# Colour / marker used for each direction label across all plots.
DIR_COLORS = {"left": "#00e676", "right": "#ff5252", "straight": "#90caf9"}
DIR_MARKERS = {"left": "^", "right": "v", "straight": "s"}


# ── Helpers ──────────────────────────────────────────────────────────

def load_db(path):
    """Load the pickled trajectory database and report its size.

    NOTE: pickle.load can execute arbitrary code — only point PKL_PATH
    at database files you produced yourself.
    """
    with open(path, "rb") as f:
        db = pickle.load(f)
    print(f"Loaded {len(db)} clips")
    return db


def save_fig(fig, name):
    """Save *fig* as OUT_DIR/<name>.png (creating OUT_DIR) and close it."""
    os.makedirs(OUT_DIR, exist_ok=True)
    p = os.path.join(OUT_DIR, f"{name}.png")
    fig.savefig(p, dpi=130, bbox_inches="tight",
                facecolor=fig.get_facecolor())
    plt.close(fig)
    print(f" Saved → {p}")


def video_exists(entry):
    """Return True if the clip's source video file is present on disk."""
    return os.path.exists(entry["video_path"])


def entries_with_video(db):
    """Return the db entries whose source video file is present on disk."""
    return [e for e in db.values() if video_exists(e)]


# ═══════════════════════════════════════════════════════════════════
# PLOT 1 — Feature Detection Quality
# ═══════════════════════════════════════════════════════════════════

def plot_01_feature_detection(db):
    """Plot where Shi-Tomasi corners land on each clip's first frame.

    Panels: (A) vertical and (B) horizontal histograms of detected corner
    positions, and (C) each clip's stored ``n_features_avg``. Only clips
    whose video is on disk and yields frames are included.
    """
    print("\n[1/10] Feature detection quality...")
    all_ys, all_xs, feat_counts, clip_dirs = [], [], [], []

    for entry in db.values():
        if not video_exists(entry):
            continue
        frames = extract_clip_frames(entry["video_path"],
                                     entry["start_frame"],
                                     entry["start_frame"] + 2)
        if not frames:
            continue
        gray = frames[0]
        mask = build_roi_mask(gray.shape)
        pts = cv2.goodFeaturesToTrack(gray, mask=mask, **ST_PARAMS)
        if pts is not None:
            # goodFeaturesToTrack returns shape (N, 1, 2) as (x, y)
            all_ys.extend(pts[:, 0, 1].tolist())
            all_xs.extend(pts[:, 0, 0].tolist())
        feat_counts.append(entry["n_features_avg"])
        clip_dirs.append(entry["direction"])

    fig, axes = plt.subplots(1, 3, figsize=(17, 5))
    fig.suptitle("Plot 1 — Feature Detection Quality (Shi-Tomasi + ROI Mask)",
                 fontsize=13, fontweight="bold", color="white", y=1.01)

    # A: vertical distribution
    ax = axes[0]
    if all_ys:
        ax.hist(all_ys, bins=54, orientation="horizontal",
                color="#00e5ff", alpha=0.85, edgecolor="none")
    ax.axhline(200, color="#ffeb3b", lw=2, ls="--", label="ROI top y=200")
    ax.axhline(880, color="#ff9800", lw=2, ls="--", label="ROI bot y=880")
    ax.axhspan(0, 200, alpha=0.12, color="#ffeb3b")
    ax.axhspan(880, 1080, alpha=0.12, color="#ff9800")
    ax.invert_yaxis()  # image rows grow downward
    ax.set_xlabel("Feature count", fontsize=10)
    ax.set_ylabel("Image row (y pixel)", fontsize=10)
    ax.set_title("A — Vertical distribution\n(should cluster 200–880)", fontsize=10)
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # B: horizontal distribution
    ax = axes[1]
    if all_xs:
        ax.hist(all_xs, bins=48, color="#e040fb", alpha=0.85, edgecolor="none")
    ax.axvline(960, color="#ffffff", lw=1.5, ls=":", label="center x=960")
    ax.set_xlabel("Image column (x pixel)", fontsize=10)
    ax.set_ylabel("Feature count", fontsize=10)
    ax.set_title("B — Horizontal distribution\n(should be roughly symmetric)", fontsize=10)
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # C: feature count per clip
    ax = axes[2]
    if feat_counts:
        colors = [DIR_COLORS[d] for d in clip_dirs]
        ax.bar(range(len(feat_counts)), feat_counts, color=colors, alpha=0.85, width=0.7)
    ref_line = ax.axhline(50, color="#ff5252", lw=1.5, ls="--", label="Min reliable = 50")
    patches = [mpatches.Patch(color=v, label=k) for k, v in DIR_COLORS.items()]
    # Explicit handles replace auto-collected ones, so include the reference line.
    ax.legend(handles=patches + [ref_line], fontsize=8)
    ax.set_xlabel("Clip index", fontsize=10)
    ax.set_ylabel("Avg tracked features", fontsize=10)
    ax.set_title("C — Avg features per clip\n(below 50 = unreliable)", fontsize=10)
    ax.grid(True, alpha=0.3, axis="y")

    plt.tight_layout()
    save_fig(fig, "01_feature_detection_quality")


# ═══════════════════════════════════════════════════════════════════
# PLOT 2 — Optical Flow Quiver Field
# ═══════════════════════════════════════════════════════════════════

def plot_02_flow_field(db):
    """Draw the LK optical-flow field between the first two frames of two clips.

    Left panel: the first clip labelled "straight"; right panel: the clip
    with the largest |turn_ratio|. Requires video files on disk.
    """
    print("[2/10] Flow vector field...")
    ev = entries_with_video(db)
    if not ev:
        print(" Skipped — no video files accessible.")
        return

    straight_e = next((e for e in ev if e["direction"] == "straight"), None)
    turn_e = max(ev, key=lambda x: abs(x["turn_ratio"]))

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    fig.suptitle("Plot 2 — Optical Flow Vector Field\n"
                 "Left: straight | Right: strongest turn",
                 fontsize=12, fontweight="bold", color="white")

    def draw_quiver(ax, entry, title):
        """Track Shi-Tomasi corners frame0→frame1 and draw them as arrows on *ax*."""
        frames = extract_clip_frames(entry["video_path"],
                                     entry["start_frame"],
                                     entry["start_frame"] + 3)
        if len(frames) < 2:
            return
        g0, g1 = frames[0], frames[1]
        mask = build_roi_mask(g0.shape)
        pts0 = cv2.goodFeaturesToTrack(g0, mask=mask, **ST_PARAMS)
        if pts0 is None:
            return
        pts1, status, _ = cv2.calcOpticalFlowPyrLK(g0, g1, pts0, None, **LK_PARAMS)
        good0 = pts0[status == 1].reshape(-1, 2)
        good1 = pts1[status == 1].reshape(-1, 2)
        if len(good0) == 0:
            ax.set_title(f"{title}\n(no trackable features)", fontsize=10)
            return
        dx = good1[:, 0] - good0[:, 0]
        dy = good1[:, 1] - good0[:, 1]

        # Map image coordinates onto a fixed 960x540 canvas (0.5x for 1920x1080).
        h, w = g0.shape[:2]
        sx, sy = 960 / w, 540 / h
        xs = good0[:, 0] * sx
        ys = good0[:, 1] * sy
        mag = np.sqrt(dx**2 + dy**2)

        ax.set_facecolor("#0d1117")
        # With angles="xy" arrows go from (x, y) to (x+u, y+v) in data
        # coordinates, and the y-axis is inverted below to match image rows,
        # so dy must NOT be negated (negating it flipped every arrow vertically).
        ax.quiver(xs, ys, dx * 12, dy * 12, mag,
                  cmap="plasma", alpha=0.85, scale=1,
                  scale_units="xy", angles="xy", width=0.003)
        ax.scatter([FOE_X * sx], [FOE_Y * sy], c="red", s=150, zorder=6,
                   label=f"FOE ({FOE_X},{FOE_Y})")
        ax.axhline(200 * sy, color="#ffeb3b", lw=1.2, ls="--", alpha=0.6, label="ROI boundary")
        ax.axhline(880 * sy, color="#ffeb3b", lw=1.2, ls="--", alpha=0.6)
        ax.set_xlim(0, 960)
        ax.set_ylim(540, 0)
        ax.set_xlabel("x (scaled)", fontsize=9)
        ax.set_ylabel("y (scaled)", fontsize=9)
        ax.set_title(f"{title}\n"
                     f"dir={entry['direction']} "
                     f"ratio={entry['turn_ratio']:+.3f}", fontsize=10)
        ax.legend(fontsize=8)

    if straight_e:
        draw_quiver(axes[0], straight_e, "Straight clip")
    else:
        axes[0].set_title("No straight clip found")
    draw_quiver(axes[1], turn_e, "Strongest turn clip")

    plt.tight_layout()
    save_fig(fig, "02_flow_vector_field")


# ═══════════════════════════════════════════════════════════════════
# PLOT 3 — FOE Stability + Lateral Signal per Frame
# ═══════════════════════════════════════════════════════════════════

def plot_03_foe_stability(db):
    """For one clip per direction (largest |turn_ratio|), plot the stored
    cumulative trajectory and the per-frame lateral signal.

    Only clips whose video is on disk are considered.
    """
    print("[3/10] FOE stability + lateral signal...")
    ev = entries_with_video(db)
    if not ev:
        print(" Skipped — no video files accessible.")
        return

    # Pick one clip per direction (prefer strongest)
    selected = {}
    for d in ["left", "right", "straight"]:
        candidates = [e for e in ev if e["direction"] == d]
        if candidates:
            selected[d] = max(candidates, key=lambda x: abs(x["turn_ratio"]))

    # squeeze=False keeps axes 2-D even when only one row is present.
    fig, axes = plt.subplots(len(selected), 2, figsize=(14, 4 * len(selected)),
                             squeeze=False)
    fig.suptitle("Plot 3 — Lateral Signal Per Frame (one clip per direction)\n"
                 "Consistent sign = clean turn detection",
                 fontsize=12, fontweight="bold", color="white")

    for row, (direction, entry) in enumerate(selected.items()):
        col = DIR_COLORS[direction]
        lat_s = entry["lateral_signals"]
        t = np.arange(len(lat_s))

        # Left sub-panel: cumulative trajectory
        ax_traj = axes[row][0]
        traj_raw = np.array(entry["trajectory_raw"])
        t2 = np.linspace(0, 1, len(traj_raw))
        ax_traj.plot(t2, traj_raw, color=col, lw=2)
        ax_traj.fill_between(t2, traj_raw, 0, alpha=0.2, color=col)
        ax_traj.axhline(0, color="#ffffff", lw=1, ls="--", alpha=0.4)
        ax_traj.set_xlabel("Normalised time", fontsize=9)
        ax_traj.set_ylabel("Cumulative lateral (px)", fontsize=9)
        ax_traj.set_title(f"{direction.upper()} "
                          f"turn_ratio={entry['turn_ratio']:+.3f}\n"
                          f"Cumulative trajectory", fontsize=10, color=col)
        ax_traj.grid(True, alpha=0.3)

        # Right sub-panel: per-frame lateral bars (>= 0 drawn in the "left" colour)
        ax_lat = axes[row][1]
        bar_colors = [DIR_COLORS["left"] if v >= 0 else DIR_COLORS["right"] for v in lat_s]
        ax_lat.bar(t, lat_s, color=bar_colors, alpha=0.8, width=0.8)
        ax_lat.axhline(0, color="#ffffff", lw=1, ls="--", alpha=0.4)
        ax_lat.set_xlabel("Frame pair index", fontsize=9)
        ax_lat.set_ylabel("Lateral signal (px/frame)", fontsize=9)
        ax_lat.set_title("Per-frame lateral signal\n"
                         "Green=rightward(left turn) Red=leftward(right turn)", fontsize=9)
        ax_lat.grid(True, alpha=0.3)

    plt.tight_layout()
    save_fig(fig, "03_foe_stability")


# ═══════════════════════════════════════════════════════════════════
# PLOT 4 — Lateral Signal Profiles (all clips overlaid)
# ═══════════════════════════════════════════════════════════════════

def plot_04_lateral_profiles(db):
    """Overlay every clip's per-frame lateral signal, one panel per direction,
    with the median and inter-quartile band on a common time axis.
    """
    print("[4/10] Lateral signal profiles...")
    by_dir = defaultdict(list)
    for entry in db.values():
        by_dir[entry["direction"]].append(entry["lateral_signals"])

    fig, axes = plt.subplots(1, 3, figsize=(17, 5), sharey=False)
    fig.suptitle("Plot 4 — Per-Frame Lateral Signal (all clips overlaid per direction)\n"
                 "Consistent direction = pipeline is working. Mixed = noise.",
                 fontsize=12, fontweight="bold", color="white")

    for ax, direction in zip(axes, ["left", "straight", "right"]):
        signals = by_dir[direction]
        col = DIR_COLORS[direction]
        for sig in signals:
            t = np.linspace(0, 1, len(sig))
            ax.plot(t, sig, color=col, lw=1, alpha=0.35)

        # np.interp cannot resample a zero-length signal, so exclude those
        # from the summary statistics.
        nonempty = [s for s in signals if len(s) > 0]
        if nonempty:
            max_len = max(len(s) for s in nonempty)
            t_med = np.linspace(0, 1, max_len)
            arr = np.array([np.interp(t_med, np.linspace(0, 1, len(s)), s)
                            for s in nonempty])
            med = np.median(arr, axis=0)
            ax.plot(t_med, med, color="white", lw=2.5, label="Median", zorder=5)
            ax.fill_between(t_med,
                            np.percentile(arr, 25, axis=0),
                            np.percentile(arr, 75, axis=0),
                            alpha=0.2, color=col, label="IQR band")

        ax.axhline(0, color="#ffffff", lw=1, ls="--", alpha=0.4)
        ax.set_title(f"{direction.upper()} ({len(signals)} clips)", fontsize=11, color=col)
        ax.set_xlabel("Normalised time (0=start, 1=end)", fontsize=9)
        ax.set_ylabel("Lateral signal (px/frame)", fontsize=9)
        if nonempty:
            ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    save_fig(fig, "04_lateral_signal_profiles")


# ═══════════════════════════════════════════════════════════════════
# PLOT 5 — Trajectory Gallery
# ═══════════════════════════════════════════════════════════════════

def plot_05_trajectory_gallery(db):
    """Draw one small panel per clip showing its cumulative lateral trajectory."""
    print("[5/10] Trajectory gallery...")
    entries = list(db.values())
    n = len(entries)
    if n == 0:
        print(" Skipped — database is empty.")
        return

    ncols = 7
    nrows = (n + ncols - 1) // ncols  # ceil(n / ncols)

    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 2.2, nrows * 2.4))
    fig.suptitle("Plot 5 — Trajectory Gallery (every clip)\n"
                 "Each panel = one 3s clip's cumulative lateral trajectory",
                 fontsize=12, fontweight="bold", color="white")
    axes_flat = np.atleast_1d(axes).ravel()

    for ax, entry in zip(axes_flat, entries):
        col = DIR_COLORS[entry["direction"]]
        traj = np.array(entry["trajectory_raw"])
        t = np.linspace(0, 1, len(traj))
        ax.plot(t, traj, color=col, lw=1.8)
        ax.fill_between(t, traj, 0, alpha=0.18, color=col)
        ax.axhline(0, color="#555577", lw=0.8)
        ax.set_title(
            f"{entry['clip_id'].split('__')[-1]}\n"
            f"{entry['direction']} {entry['turn_ratio']:+.2f}",
            fontsize=6.5, color=col
        )
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ["bottom", "left"]:
            ax.spines[spine].set_color(col)
        for spine in ["top", "right"]:
            ax.spines[spine].set_color("#333355")

    # Hide the unused trailing cells of the grid.
    for ax in axes_flat[n:]:
        ax.set_visible(False)

    patches = [mpatches.Patch(color=v, label=k) for k, v in DIR_COLORS.items()]
    fig.legend(handles=patches, loc="lower center", ncol=3, fontsize=10,
               bbox_to_anchor=(0.5, -0.01))
    plt.tight_layout()
    save_fig(fig, "05_trajectory_gallery")


# ═══════════════════════════════════════════════════════════════════
# PLOT 6 — Turn Ratio Distribution
# ═══════════════════════════════════════════════════════════════════

def plot_06_turn_ratio_distribution(db):
    """Plot turn_ratio per direction as (A) histogram, (B) jittered strip plot,
    and (C) scatter against peak_lateral, with ±0.60 reference lines.
    """
    print("[6/10] Turn ratio distribution...")
    ratios = [e["turn_ratio"] for e in db.values()]
    directions = [e["direction"] for e in db.values()]
    peaks = [e["peak_lateral"] for e in db.values()]

    fig, axes = plt.subplots(1, 3, figsize=(17, 5))
    fig.suptitle("Plot 6 — Turn Ratio Distribution\n"
                 "Three separate peaks = clean separation between directions",
                 fontsize=12, fontweight="bold", color="white")

    # A: histogram
    ax = axes[0]
    for d, col in DIR_COLORS.items():
        vals = [r for r, dr in zip(ratios, directions) if dr == d]
        if vals:
            ax.hist(vals, bins=20, color=col, alpha=0.7,
                    label=f"{d} ({len(vals)})", edgecolor="none")
    ax.axvline(0.60, color="#ffeb3b", lw=2, ls="--", label="+0.60 threshold")
    ax.axvline(-0.60, color="#ffeb3b", lw=2, ls="--", label="-0.60 threshold")
    ax.axvline(0.0, color="#ffffff", lw=1, ls=":", alpha=0.5)
    ax.set_xlabel("turn_ratio", fontsize=10)
    ax.set_ylabel("Count", fontsize=10)
    ax.set_title("A — Histogram\n(gap at ±0.60?)", fontsize=10)
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # B: strip plot (fixed seed so the jitter is reproducible across runs)
    ax = axes[1]
    rng = np.random.default_rng(42)
    jitter = rng.uniform(-0.2, 0.2, len(ratios))
    for d, col in DIR_COLORS.items():
        mask = [dr == d for dr in directions]
        r_d = [r for r, m in zip(ratios, mask) if m]
        j_d = [j for j, m in zip(jitter, mask) if m]
        ax.scatter(r_d, j_d, color=col, s=70, alpha=0.85,
                   marker=DIR_MARKERS[d], label=d, zorder=3)
    ax.axvline(0.60, color="#ffeb3b", lw=2, ls="--")
    ax.axvline(-0.60, color="#ffeb3b", lw=2, ls="--")
    ax.axvline(0.0, color="#ffffff", lw=1, ls=":", alpha=0.5)
    ax.set_xlabel("turn_ratio", fontsize=10)
    ax.set_yticks([])
    ax.set_title("B — Strip plot\n(overlapping near threshold = borderline clips)", fontsize=10)
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3, axis="x")

    # C: peak lateral vs turn ratio
    ax = axes[2]
    for d, col in DIR_COLORS.items():
        r_d = [r for r, dr in zip(ratios, directions) if dr == d]
        p_d = [p for p, dr in zip(peaks, directions) if dr == d]
        ax.scatter(r_d, p_d, color=col, s=70, alpha=0.85,
                   marker=DIR_MARKERS[d], label=d)
    ax.set_xlabel("turn_ratio", fontsize=10)
    ax.set_ylabel("peak_lateral (px)", fontsize=10)
    ax.set_title("C — Peak drift vs turn ratio\n(stronger turn = bigger peak)", fontsize=10)
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    save_fig(fig, "06_turn_ratio_distribution")


# ═══════════════════════════════════════════════════════════════════
# PLOT 7 — Embedding Space (PCA + t-SNE)
# ═══════════════════════════════════════════════════════════════════

def plot_07_embedding_space(db):
    """Project clip embeddings to 2-D with PCA (A) and t-SNE (B), and plot the
    PCA explained-variance spectrum (C).
    """
    print("[7/10] Embedding space...")
    entries = list(db.values())
    embeddings = np.array([e["embedding"] for e in entries])
    directions = [e["direction"] for e in entries]
    labels = [e["clip_id"].split("__")[-1] for e in entries]
    n_clips = len(embeddings)
    if n_clips < 2:
        # A 2-component PCA needs at least two samples.
        print(f" Skipped — need at least 2 clips (have {n_clips}).")
        return

    fig = plt.figure(figsize=(18, 7))
    fig.suptitle("Plot 7 — Embedding Space\n"
                 "Separated clusters = embeddings discriminate directions well",
                 fontsize=12, fontweight="bold", color="white")
    gs = gridspec.GridSpec(1, 3, figure=fig, wspace=0.35)

    # A: PCA
    ax_pca = fig.add_subplot(gs[0])
    pca = PCA(n_components=2)
    pca_2d = pca.fit_transform(embeddings)
    for d, col in DIR_COLORS.items():
        mask = np.array([dr == d for dr in directions])
        if mask.any():
            ax_pca.scatter(pca_2d[mask, 0], pca_2d[mask, 1], color=col, s=90,
                           alpha=0.85, marker=DIR_MARKERS[d], label=d, zorder=3)
    for x, y, lbl in zip(pca_2d[:, 0], pca_2d[:, 1], labels):
        ax_pca.annotate(lbl, (x, y), fontsize=5.5, color="#aaaacc", alpha=0.7,
                        xytext=(3, 3), textcoords="offset points")
    ax_pca.set_title(
        f"A — PCA\n({pca.explained_variance_ratio_.sum()*100:.1f}% variance)", fontsize=10)
    ax_pca.set_xlabel(f"PC1 {pca.explained_variance_ratio_[0]*100:.1f}%", fontsize=9)
    ax_pca.set_ylabel(f"PC2 {pca.explained_variance_ratio_[1]*100:.1f}%", fontsize=9)
    ax_pca.legend(fontsize=9)
    ax_pca.grid(True, alpha=0.3)

    # B: t-SNE (perplexity must stay below the number of samples)
    ax_tsne = fig.add_subplot(gs[1])
    if n_clips >= 6:
        perp = min(5, n_clips - 1)
        tsne = TSNE(n_components=2, perplexity=perp, random_state=42, max_iter=1000)
        ts_2d = tsne.fit_transform(embeddings)
        for d, col in DIR_COLORS.items():
            mask = np.array([dr == d for dr in directions])
            if mask.any():
                ax_tsne.scatter(ts_2d[mask, 0], ts_2d[mask, 1], color=col, s=90,
                                alpha=0.85, marker=DIR_MARKERS[d], label=d, zorder=3)
        for x, y, lbl in zip(ts_2d[:, 0], ts_2d[:, 1], labels):
            ax_tsne.annotate(lbl, (x, y), fontsize=5.5, color="#aaaacc", alpha=0.7,
                             xytext=(3, 3), textcoords="offset points")
        ax_tsne.set_title("B — t-SNE (non-linear local structure)", fontsize=10)
        ax_tsne.legend(fontsize=9)
    else:
        ax_tsne.text(0.5, 0.5, f"Need ≥ 6 clips\n(have {n_clips})",
                     ha="center", va="center", transform=ax_tsne.transAxes,
                     color="#ff5252", fontsize=12)
        ax_tsne.set_title("B — t-SNE", fontsize=10)
    ax_tsne.grid(True, alpha=0.3)

    # C: explained variance — PCA also requires n_components <= n_features.
    ax_var = fig.add_subplot(gs[2])
    n_comp = max(1, min(15, n_clips - 1, embeddings.shape[1]))
    pca_f = PCA(n_components=n_comp)
    pca_f.fit(embeddings)
    evr = pca_f.explained_variance_ratio_
    cum = np.cumsum(evr)
    comps = np.arange(1, len(evr) + 1)
    ax_var.bar(comps, evr * 100, color="#e040fb", alpha=0.8, label="Individual")
    ax_var.plot(comps, cum * 100, color="#ffeb3b", lw=2, marker="o", ms=4, label="Cumulative")
    ax_var.axhline(90, color="#ff5252", lw=1.5, ls="--", alpha=0.7, label="90% line")
    ax_var.set_xlabel("Principal component", fontsize=9)
    ax_var.set_ylabel("Explained variance (%)", fontsize=9)
    ax_var.set_title("C — PCA explained variance\n"
                     "(few PCs to reach 90% = compact embedding)", fontsize=10)
    ax_var.legend(fontsize=8)
    ax_var.grid(True, alpha=0.3)

    save_fig(fig, "07_embedding_space")


# ═══════════════════════════════════════════════════════════════════
# PLOT 8 — Score Distribution per Query
# ═══════════════════════════════════════════════════════════════════

def plot_08_score_distribution(db):
    """For each canonical query direction, plot every clip's score_clip value
    sorted descending, coloured by the clip's actual direction.
    """
    print("[8/10] Score distribution per query...")
    entries = list(db.values())

    fig, axes = plt.subplots(1, 3, figsize=(17, 5))
    fig.suptitle("Plot 8 — Score Distribution for Each Button Query\n"
                 "Sharp drop after top results = clear winner. Flat = poor discrimination.",
                 fontsize=12, fontweight="bold", color="white")

    for ax, direction in zip(axes, ["left", "straight", "right"]):
        query_emb = CANONICAL[direction]
        col = DIR_COLORS[direction]
        rows = []
        for e in entries:
            s = score_clip(query_emb, e["embedding"])
            rows.append((s, e["direction"], e["clip_id"].split("__")[-1]))
        rows.sort(key=lambda x: x[0], reverse=True)

        scores = [r[0] for r in rows]
        dirs_out = [r[1] for r in rows]
        cids = [r[2] for r in rows]
        bar_cols = [DIR_COLORS[d] for d in dirs_out]
        x = np.arange(len(scores))

        ax.bar(x, scores, color=bar_cols, alpha=0.85, width=0.7)
        ax.axhline(0.7, color="#ffeb3b", lw=1.5, ls="--", alpha=0.8, label="0.70 reference")
        ax.set_xlabel("Clip rank", fontsize=9)
        ax.set_ylabel("Combined score", fontsize=9)
        ax.set_title(f"Query: {direction.upper()}\n"
                     f"(bar colour = actual clip direction)", fontsize=10, color=col)
        ax.set_ylim(0, 1.05)
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3, axis="y")
        # Label every other bar to keep the axis readable.
        ax.set_xticks(x[::2])
        ax.set_xticklabels(cids[::2], rotation=45, fontsize=6, ha="right")
        for rank in range(min(3, len(scores))):
            ax.text(x[rank], scores[rank] + 0.01, f"#{rank+1}",
                    ha="center", fontsize=7, color="white")

    patches = [mpatches.Patch(color=v, label=k) for k, v in DIR_COLORS.items()]
    fig.legend(handles=patches, loc="lower center", ncol=3, fontsize=9,
               bbox_to_anchor=(0.5, -0.04))
    plt.tight_layout()
    save_fig(fig, "08_score_distribution")


# ═══════════════════════════════════════════════════════════════════
# PLOT 9 — DTW vs Cosine Scatter
# ═══════════════════════════════════════════════════════════════════

def plot_09_dtw_vs_cosine(db):
    """Scatter each clip's cosine similarity vs DTW similarity to each
    canonical query, with dotted iso-score lines.

    Cosine is computed on the full embedding and mapped from [-1, 1] to
    [0, 1]; DTW uses only the first N_RESAMPLE values and is mapped to
    1 / (1 + distance).
    """
    print("[9/10] DTW vs cosine scatter...")
    entries = list(db.values())

    fig, axes = plt.subplots(1, 3, figsize=(17, 5))
    fig.suptitle("Plot 9 — DTW Similarity vs Cosine Similarity (per query)\n"
                 "Top-right quadrant = high on both = best matches",
                 fontsize=12, fontweight="bold", color="white")

    for ax, direction in zip(axes, ["left", "straight", "right"]):
        query_emb = CANONICAL[direction]
        # dtaidistance's C implementation expects float64 arrays.
        query_traj = query_emb[:N_RESAMPLE].astype(np.double)

        cos_s_all, dtw_s_all, dirs_all = [], [], []
        for e in entries:
            emb = e["embedding"]
            traj = emb[:N_RESAMPLE].astype(np.double)
            raw = 1 - scipy_cosine(query_emb.astype(np.float64), emb.astype(np.float64))
            cs = (raw + 1) / 2
            dist = dtw_lib.distance_fast(query_traj, traj)
            ds = 1 / (1 + dist)
            cos_s_all.append(cs)
            dtw_s_all.append(ds)
            dirs_all.append(e["direction"])

        for d, col in DIR_COLORS.items():
            mask = [dr == d for dr in dirs_all]
            cx = [c for c, m in zip(cos_s_all, mask) if m]
            dy = [dv for dv, m in zip(dtw_s_all, mask) if m]
            ax.scatter(cx, dy, color=col, s=80, alpha=0.85,
                       marker=DIR_MARKERS[d], label=d, zorder=3)

        # Iso-score lines: score = 0.6*cos + 0.4*dtw → dtw = (score - 0.6*cos)/0.4
        # NOTE(review): assumes score_clip uses these weights — keep in sync
        # with retrieval_engine.
        cs_grid = np.linspace(0, 1, 200)
        for target in [0.5, 0.7, 0.8]:
            dtw_iso = (target - 0.6 * cs_grid) / 0.4
            valid = (dtw_iso >= 0) & (dtw_iso <= 1)
            ax.plot(cs_grid[valid], dtw_iso[valid], color="#555577", lw=1, ls=":", alpha=0.8)
            if valid.any():
                xi, yi = cs_grid[valid][-1], dtw_iso[valid][-1]
                ax.text(xi, yi, f"score={target}", fontsize=6, color="#888899")

        ax.set_xlabel("Cosine similarity (α=0.6)", fontsize=9)
        ax.set_ylabel("DTW similarity (β=0.4)", fontsize=9)
        ax.set_title(f"Query: {direction.upper()}", fontsize=10, color=DIR_COLORS[direction])
        ax.set_xlim(-0.05, 1.05)
        ax.set_ylim(-0.05, 1.05)
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)
        ax.text(0.84, 0.93, "BEST", fontsize=8, color="#00e676",
                ha="center", transform=ax.transAxes, alpha=0.7)
        ax.text(0.14, 0.07, "WORST", fontsize=8, color="#ff5252",
                ha="center", transform=ax.transAxes, alpha=0.7)

    plt.tight_layout()
    save_fig(fig, "09_dtw_vs_cosine")


# ═══════════════════════════════════════════════════════════════════
# PLOT 10 — Confusion Matrix
# ═══════════════════════════════════════════════════════════════════

def plot_10_confusion_matrix(db):
    """Count the actual directions among the top-5 clips retrieved for each
    canonical query, and plot them as a confusion matrix plus Precision@5.
    """
    print("[10/10] Confusion matrix...")
    entries = list(db.values())
    dir_list = ["left", "straight", "right"]
    n = len(dir_list)
    top_k = 5
    confusion = np.zeros((n, n), dtype=int)

    for qi, q_dir in enumerate(dir_list):
        query_emb = CANONICAL[q_dir]
        scored = sorted(
            [(score_clip(query_emb, e["embedding"]), e["direction"]) for e in entries],
            key=lambda x: x[0], reverse=True
        )
        for _, ret_dir in scored[:top_k]:
            ri = dir_list.index(ret_dir)
            confusion[qi, ri] += 1

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle("Plot 10 — Direction Confusion Matrix (top-5 retrieval)\n"
                 "Diagonal = correct. Off-diagonal = wrong direction returned.",
                 fontsize=12, fontweight="bold", color="white")

    # A: heatmap
    ax = axes[0]
    im = ax.imshow(confusion, cmap="YlOrRd", vmin=0, vmax=top_k)
    ax.set_xticks(range(n))
    ax.set_yticks(range(n))
    ax.set_xticklabels(dir_list, fontsize=11)
    ax.set_yticklabels(dir_list, fontsize=11)
    ax.set_xlabel("Retrieved direction (top-5)", fontsize=10)
    ax.set_ylabel("Query direction", fontsize=10)
    ax.set_title("A — Raw counts", fontsize=10)
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    for i in range(n):
        for j in range(n):
            # Dark text on bright (high-count) cells, light text otherwise.
            ax.text(j, i, str(confusion[i, j]), ha="center", va="center",
                    fontsize=16, fontweight="bold",
                    color="black" if confusion[i, j] > top_k * 0.5 else "white")

    # B: precision bar chart (max(..., 1) avoids division by zero on empty rows)
    ax = axes[1]
    precisions = [confusion[i, i] / max(confusion[i].sum(), 1) * 100 for i in range(n)]
    cols = [DIR_COLORS[d] for d in dir_list]
    bars = ax.bar(dir_list, precisions, color=cols, alpha=0.85, width=0.5)
    ax.axhline(100, color="#00e676", lw=1.5, ls="--", alpha=0.6, label="100%")
    ax.axhline(60, color="#ffeb3b", lw=1.5, ls="--", alpha=0.6, label="60% threshold")
    for bar, val in zip(bars, precisions):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 1.5,
                f"{val:.0f}%", ha="center", fontsize=13, fontweight="bold", color="white")
    ax.set_ylim(0, 115)
    ax.set_ylabel("Precision@5 (%)", fontsize=10)
    ax.set_title("B — Precision@5 per direction\n"
                 "(what % of top-5 are the correct direction?)", fontsize=10)
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3, axis="y")

    plt.tight_layout()
    save_fig(fig, "10_confusion_matrix")


# ═══════════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════════

if __name__ == "__main__":
    print("=" * 60)
    print(" Trajectory Retrieval — Diagnostic Visualizations")
    print("=" * 60)

    if not os.path.exists(PKL_PATH):
        print(f"\nERROR: PKL not found at '{PKL_PATH}'")
        print("Run video_processor.py first to generate the database.")
        raise SystemExit(1)

    db = load_db(PKL_PATH)
    accessible = len(entries_with_video(db))
    print(f"Clips total : {len(db)}")
    print(f"With accessible video : {accessible}")
    print(f"Plots output folder : {OUT_DIR}/\n")

    # Plots 1–3 need video files on disk
    plot_01_feature_detection(db)
    plot_02_flow_field(db)
    plot_03_foe_stability(db)

    # Plots 4–10 only need the PKL
    plot_04_lateral_profiles(db)
    plot_05_trajectory_gallery(db)
    plot_06_turn_ratio_distribution(db)
    plot_07_embedding_space(db)
    plot_08_score_distribution(db)
    plot_09_dtw_vs_cosine(db)
    plot_10_confusion_matrix(db)

    print("\n" + "=" * 60)
    print(f" All 10 plots saved to → {OUT_DIR}/")
    print("=" * 60)