Spaces:
Sleeping
Sleeping
"""
visualize.py
============
All diagnostic plots for the trajectory retrieval system.

BEFORE RUNNING — set your paths at the top of this file:
    PKL_PATH  : your processed trajectory database (.pkl)
    VIDEO_DIR : folder containing your dashcam videos
    OUT_DIR   : where to save the plots

Run:
    python visualize.py

All 10 plots are saved as PNG files in OUT_DIR.
"""
import os
import pickle
import sys
import traceback
from collections import defaultdict

import numpy as np
import matplotlib
matplotlib.use("Agg")  # select the non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.gridspec as gridspec
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy.spatial.distance import cosine as scipy_cosine
from dtaidistance import dtw as dtw_lib
import cv2

# Make the sibling project modules importable regardless of the working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from trajectory_extractor import (
    extract_clip_frames, build_roi_mask,
    estimate_lambda, subtract_forward_motion,
    remove_outliers_iqr, FOE_X, FOE_Y, N_RESAMPLE,
    ST_PARAMS, LK_PARAMS,
)
from retrieval_engine import CANONICAL, score_clip
# ─────────────────────────────────────────────────────────────────────
# ▶▶  SET YOUR PATHS HERE
# Each value can also be overridden from the environment without editing
# this file: VIZ_PKL_PATH, VIZ_VIDEO_DIR, VIZ_OUT_DIR.
# ─────────────────────────────────────────────────────────────────────
PKL_PATH = os.environ.get("VIZ_PKL_PATH", "trajectory_db100.pkl")
# NOTE(review): VIDEO_DIR is not read by any code in this file — clips are
# opened through each entry's own "video_path" field. Kept for reference.
VIDEO_DIR = os.environ.get(
    "VIZ_VIDEO_DIR",
    "/media/RTCIN15TBD/AllUsers/sjl3kor/trajectory_retrieval_system/data/videos",
)
OUT_DIR = os.environ.get("VIZ_OUT_DIR", "plots")
# ─────────────────────────────────────────────────────────────────────

# ── Global dark style ────────────────────────────────────────────────
plt.rcParams.update({
    "figure.facecolor": "#0f0f1a",
    "axes.facecolor": "#1a1a2e",
    "axes.edgecolor": "#444466",
    "axes.labelcolor": "#ccccdd",
    "axes.titlecolor": "#ffffff",
    "xtick.color": "#aaaacc",
    "ytick.color": "#aaaacc",
    "text.color": "#ccccdd",
    "grid.color": "#2a2a4a",
    "grid.linewidth": 0.6,
    "legend.facecolor": "#1a1a2e",
    "legend.edgecolor": "#444466",
    "font.family": "monospace",
})

# Colour and marker used for each labelled driving direction in every plot.
DIR_COLORS = {"left": "#00e676", "right": "#ff5252", "straight": "#90caf9"}
DIR_MARKERS = {"left": "^", "right": "v", "straight": "s"}
# ── Helpers ──────────────────────────────────────────────────────────
def load_db(path):
    """Unpickle and return the trajectory database stored at *path*.

    The returned object is used as a mapping whose values are clip entry
    dicts. Prints how many clips were loaded.

    NOTE(review): ``pickle.load`` can execute arbitrary code — only load
    databases you produced yourself.
    """
    with open(path, "rb") as fh:
        database = pickle.load(fh)
    print(f"Loaded {len(database)} clips")
    return database
def save_fig(fig, name, out_dir=None):
    """Write *fig* to ``<out_dir>/<name>.png`` and close it.

    Args:
        fig: Matplotlib figure to save.
        name: File stem, without extension.
        out_dir: Destination folder; defaults to the module-level ``OUT_DIR``.
            Created if it does not exist.

    Returns:
        The path of the written PNG.
    """
    out_dir = OUT_DIR if out_dir is None else out_dir
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, f"{name}.png")
    # Pass the figure's own facecolor explicitly so the dark background is
    # kept in the saved PNG.
    fig.savefig(path, dpi=130, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)  # release the figure; this script creates many of them
    print(f" Saved → {path}")
    return path
def video_exists(entry):
    """Return True if *entry*'s ``"video_path"`` names an existing regular file.

    An entry with no ``"video_path"`` key, a ``None``/empty path, or a path
    that is a directory is reported as having no video. It does not raise
    ``KeyError``/``TypeError``.
    """
    path = entry.get("video_path")
    return bool(path) and os.path.isfile(path)
def entries_with_video(db):
    """Return, as a list, the db entries for which ``video_exists`` is true."""
    return list(filter(video_exists, db.values()))
# ─────────────────────────────────────────────────────────────────────
# PLOT 1 — Feature Detection Quality
# ─────────────────────────────────────────────────────────────────────
def plot_01_feature_detection(db):
    """Plot Shi-Tomasi feature placement and per-clip tracked-feature counts.

    For every clip whose video file is on disk, the clip's first frame is
    read, features are detected inside the ROI mask and their pixel
    coordinates are pooled. Three panels are drawn:

    * A — histogram of feature rows (y), with ROI rows 200 / 880 marked;
    * B — histogram of feature columns (x), with x = 960 marked;
    * C — the stored ``n_features_avg`` of those clips, coloured by
      direction, against a 50-feature reliability line.

    Prints a note and saves nothing when no clip video could be read.
    Otherwise saves ``01_feature_detection_quality.png``.

    NOTE(review): the 200/880/1080/960 pixel constants assume 1920x1080
    frames and should agree with ``build_roi_mask`` — confirm.
    """
    print("\n[1/10] Feature detection quality...")
    all_xs, all_ys = [], []
    feat_counts, clip_dirs = [], []
    for entry in db.values():
        if not video_exists(entry):
            continue
        frames = extract_clip_frames(entry["video_path"],
                                     entry["start_frame"],
                                     entry["start_frame"] + 2)
        if not frames:
            continue
        gray = frames[0]
        mask = build_roi_mask(gray.shape)
        pts = cv2.goodFeaturesToTrack(gray, mask=mask, **ST_PARAMS)
        if pts is not None:
            # goodFeaturesToTrack returns an (N, 1, 2) array of (x, y).
            all_xs.extend(pts[:, 0, 0].tolist())
            all_ys.extend(pts[:, 0, 1].tolist())
        feat_counts.append(entry["n_features_avg"])
        clip_dirs.append(entry["direction"])

    if not feat_counts:
        print("  Skipped — no clip video could be read.")
        return

    fig, axes = plt.subplots(1, 3, figsize=(17, 5))
    fig.suptitle("Plot 1 — Feature Detection Quality (Shi-Tomasi + ROI Mask)",
                 fontsize=13, fontweight="bold", color="white", y=1.01)

    # A: vertical distribution
    ax = axes[0]
    if all_ys:
        ax.hist(all_ys, bins=54, orientation="horizontal",
                color="#00e5ff", alpha=0.85, edgecolor="none")
    ax.axhline(200, color="#ffeb3b", lw=2, ls="--", label="ROI top y=200")
    ax.axhline(880, color="#ff9800", lw=2, ls="--", label="ROI bot y=880")
    ax.axhspan(0, 200, alpha=0.12, color="#ffeb3b")
    ax.axhspan(880, 1080, alpha=0.12, color="#ff9800")
    ax.invert_yaxis()  # image convention: row 0 at the top
    ax.set_xlabel("Feature count", fontsize=10)
    ax.set_ylabel("Image row (y pixel)", fontsize=10)
    ax.set_title("A — Vertical distribution\n(should cluster 200–880)", fontsize=10)
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # B: horizontal distribution
    ax = axes[1]
    if all_xs:
        ax.hist(all_xs, bins=48, color="#e040fb", alpha=0.85, edgecolor="none")
    ax.axvline(960, color="#ffffff", lw=1.5, ls=":", label="center x=960")
    ax.set_xlabel("Image column (x pixel)", fontsize=10)
    ax.set_ylabel("Feature count", fontsize=10)
    ax.set_title("B — Horizontal distribution\n(should be roughly symmetric)", fontsize=10)
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # C: feature count per clip
    ax = axes[2]
    ax.bar(range(len(feat_counts)), feat_counts,
           color=[DIR_COLORS[d] for d in clip_dirs], alpha=0.85, width=0.7)
    min_line = ax.axhline(50, color="#ff5252", lw=1.5, ls="--",
                          label="Min reliable = 50")
    handles = [mpatches.Patch(color=c, label=d) for d, c in DIR_COLORS.items()]
    # An explicit handles= list replaces the auto-collected artists, so the
    # threshold line must be added by hand or it disappears from the legend.
    ax.legend(handles=handles + [min_line], fontsize=8)
    ax.set_xlabel("Clip index", fontsize=10)
    ax.set_ylabel("Avg tracked features", fontsize=10)
    ax.set_title("C — Avg features per clip\n(below 50 = unreliable)", fontsize=10)
    ax.grid(True, alpha=0.3, axis="y")

    plt.tight_layout()
    save_fig(fig, "01_feature_detection_quality")
# ─────────────────────────────────────────────────────────────────────
# PLOT 2 — Optical Flow Quiver Field
# ─────────────────────────────────────────────────────────────────────
def plot_02_flow_field(db):
    """Draw Lucas-Kanade flow vectors for a straight clip and the strongest turn.

    Left panel: the first clip labelled ``"straight"``. Right panel: the clip
    with the largest |turn_ratio|. Only clips whose video is on disk are
    considered. Features are detected on each clip's first frame (inside the
    ROI mask) and tracked into the second. Vectors are exaggerated 12x and
    coloured by magnitude; the FOE and ROI rows are overlaid.
    Saves ``02_flow_vector_field.png``.
    """
    print("[2/10] Flow vector field...")
    ev = entries_with_video(db)
    if not ev:
        print("  Skipped — no video files accessible.")
        return

    straight_e = next((e for e in ev if e["direction"] == "straight"), None)
    turn_e = max(ev, key=lambda e: abs(e["turn_ratio"]))

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    fig.suptitle("Plot 2 — Optical Flow Vector Field\n"
                 "Left: straight | Right: strongest turn",
                 fontsize=12, fontweight="bold", color="white")

    def draw_quiver(ax, entry, title):
        """Draw the flow between the first two frames of *entry* onto *ax*."""
        frames = extract_clip_frames(entry["video_path"],
                                     entry["start_frame"],
                                     entry["start_frame"] + 3)
        if len(frames) < 2:
            ax.set_title(f"{title}\n(could not read two frames)", fontsize=10)
            return
        g0, g1 = frames[0], frames[1]
        pts0 = cv2.goodFeaturesToTrack(g0, mask=build_roi_mask(g0.shape), **ST_PARAMS)
        if pts0 is None:
            ax.set_title(f"{title}\n(no features detected)", fontsize=10)
            return
        pts1, status, _ = cv2.calcOpticalFlowPyrLK(g0, g1, pts0, None, **LK_PARAMS)
        tracked = status.ravel() == 1
        if not tracked.any():
            ax.set_title(f"{title}\n(no features tracked)", fontsize=10)
            return
        good0 = pts0.reshape(-1, 2)[tracked]
        good1 = pts1.reshape(-1, 2)[tracked]
        dx = good1[:, 0] - good0[:, 0]
        dy = good1[:, 1] - good0[:, 1]
        mag = np.hypot(dx, dy)

        # Map frame pixels onto a fixed 960x540 canvas using the frame's real
        # size (identical to the old 1920x1080 constants for full-HD input).
        h, w = g0.shape[:2]
        sx, sy = 960 / w, 540 / h
        xs = good0[:, 0] * sx
        ys = good0[:, 1] * sy

        ax.set_facecolor("#0d1117")
        # The y-axis is inverted below (ylim 540 -> 0), so data y already grows
        # downward like image rows. dy is therefore used unchanged; negating it
        # mirrored every flow vector vertically.
        ax.quiver(xs, ys, dx * 12, dy * 12, mag,
                  cmap="plasma", alpha=0.85,
                  scale=1, scale_units="xy", angles="xy", width=0.003)
        ax.scatter([FOE_X * sx], [FOE_Y * sy],
                   c="red", s=150, zorder=6, label=f"FOE ({FOE_X},{FOE_Y})")
        ax.axhline(200 * sy, color="#ffeb3b", lw=1.2, ls="--",
                   alpha=0.6, label="ROI boundary")
        ax.axhline(880 * sy, color="#ffeb3b", lw=1.2, ls="--", alpha=0.6)
        ax.set_xlim(0, 960)
        ax.set_ylim(540, 0)
        ax.set_xlabel("x (scaled)", fontsize=9)
        ax.set_ylabel("y (scaled)", fontsize=9)
        ax.set_title(f"{title}\n"
                     f"dir={entry['direction']} "
                     f"ratio={entry['turn_ratio']:+.3f}", fontsize=10)
        ax.legend(fontsize=8)

    if straight_e is not None:
        draw_quiver(axes[0], straight_e, "Straight clip")
    else:
        axes[0].set_title("No straight clip found")
    draw_quiver(axes[1], turn_e, "Strongest turn clip")

    plt.tight_layout()
    save_fig(fig, "02_flow_vector_field")
# ─────────────────────────────────────────────────────────────────────
# PLOT 3 — FOE Stability + Lateral Signal per Frame
# ─────────────────────────────────────────────────────────────────────
def plot_03_foe_stability(db):
    """Plot the cumulative trajectory and per-frame lateral signal per direction.

    For each of left/right/straight, the clip with the largest |turn_ratio|
    is chosen. Each row shows that clip's ``trajectory_raw`` over normalised
    time and its ``lateral_signals`` as sign-coloured bars. Only PKL fields
    are read, so every clip in the database is eligible, whether or not its
    video is on disk. The output name ``03_foe_stability.png`` is kept for
    compatibility even though no FOE is drawn.
    """
    print("[3/10] FOE stability + lateral signal...")
    entries = list(db.values())

    # One clip per direction: the one with the largest |turn_ratio|.
    selected = {}
    for direction in ("left", "right", "straight"):
        candidates = [e for e in entries if e["direction"] == direction]
        if candidates:
            selected[direction] = max(candidates, key=lambda e: abs(e["turn_ratio"]))
    if not selected:
        # plt.subplots(0, 2) would raise.
        print("  Skipped — no left/right/straight clips in the database.")
        return

    n_rows = len(selected)
    # squeeze=False keeps axes 2-D even when there is a single row.
    fig, axes = plt.subplots(n_rows, 2, figsize=(14, 4 * n_rows), squeeze=False)
    fig.suptitle("Plot 3 — Lateral Signal Per Frame (one clip per direction)\n"
                 "Consistent sign = clean turn detection",
                 fontsize=12, fontweight="bold", color="white")

    for (ax_traj, ax_lat), (direction, entry) in zip(axes, selected.items()):
        col = DIR_COLORS[direction]

        # Left sub-panel: cumulative trajectory over normalised time.
        traj_raw = np.array(entry["trajectory_raw"])
        t_norm = np.linspace(0, 1, len(traj_raw))
        ax_traj.plot(t_norm, traj_raw, color=col, lw=2)
        ax_traj.fill_between(t_norm, traj_raw, 0, alpha=0.2, color=col)
        ax_traj.axhline(0, color="#ffffff", lw=1, ls="--", alpha=0.4)
        ax_traj.set_xlabel("Normalised time", fontsize=9)
        ax_traj.set_ylabel("Cumulative lateral (px)", fontsize=9)
        ax_traj.set_title(f"{direction.upper()} "
                          f"turn_ratio={entry['turn_ratio']:+.3f}\n"
                          f"Cumulative trajectory",
                          fontsize=10, color=col)
        ax_traj.grid(True, alpha=0.3)

        # Right sub-panel: per-frame lateral signal, bars coloured by sign.
        lat_s = entry["lateral_signals"]
        bar_colors = [DIR_COLORS["left"] if v >= 0 else DIR_COLORS["right"]
                      for v in lat_s]
        ax_lat.bar(np.arange(len(lat_s)), lat_s, color=bar_colors, alpha=0.8, width=0.8)
        ax_lat.axhline(0, color="#ffffff", lw=1, ls="--", alpha=0.4)
        ax_lat.set_xlabel("Frame pair index", fontsize=9)
        ax_lat.set_ylabel("Lateral signal (px/frame)", fontsize=9)
        ax_lat.set_title("Per-frame lateral signal\n"
                         "Green=rightward(left turn) Red=leftward(right turn)",
                         fontsize=9)
        ax_lat.grid(True, alpha=0.3)

    plt.tight_layout()
    save_fig(fig, "03_foe_stability")
# ─────────────────────────────────────────────────────────────────────
# PLOT 4 — Lateral Signal Profiles (all clips overlaid)
# ─────────────────────────────────────────────────────────────────────
def plot_04_lateral_profiles(db):
    """Overlay every clip's per-frame lateral signal, one panel per direction.

    Each clip's ``lateral_signals`` is drawn against normalised time [0, 1].
    Per direction, the signals are linearly resampled to the longest length
    in that group, and the median curve and 25th–75th percentile band are
    drawn. Clips with an empty signal are left out: there is nothing to draw,
    and ``np.interp`` cannot resample them. The "(N clips)" count therefore
    covers non-empty signals only. Saves ``04_lateral_signal_profiles.png``.
    """
    print("[4/10] Lateral signal profiles...")
    by_dir = defaultdict(list)
    for entry in db.values():
        sig = np.asarray(entry["lateral_signals"], dtype=float)
        if sig.size:
            by_dir[entry["direction"]].append(sig)

    fig, axes = plt.subplots(1, 3, figsize=(17, 5))
    fig.suptitle("Plot 4 — Per-Frame Lateral Signal (all clips overlaid per direction)\n"
                 "Consistent direction = pipeline is working. Mixed = noise.",
                 fontsize=12, fontweight="bold", color="white")

    for ax, direction in zip(axes, ["left", "straight", "right"]):
        signals = by_dir[direction]
        col = DIR_COLORS[direction]
        for sig in signals:
            ax.plot(np.linspace(0, 1, len(sig)), sig, color=col, lw=1, alpha=0.35)

        if signals:
            max_len = max(len(s) for s in signals)
            t_med = np.linspace(0, 1, max_len)
            # Resample onto a common time grid so per-step statistics exist.
            arr = np.array([np.interp(t_med, np.linspace(0, 1, len(s)), s)
                            for s in signals])
            ax.plot(t_med, np.median(arr, axis=0), color="white", lw=2.5,
                    label="Median", zorder=5)
            ax.fill_between(t_med,
                            np.percentile(arr, 25, axis=0),
                            np.percentile(arr, 75, axis=0),
                            alpha=0.2, color=col, label="IQR band")
            # Legend only when labelled artists exist (an empty panel would
            # otherwise trigger matplotlib's "no artists with labels" warning).
            ax.legend(fontsize=8)

        ax.axhline(0, color="#ffffff", lw=1, ls="--", alpha=0.4)
        ax.set_title(f"{direction.upper()} ({len(signals)} clips)",
                     fontsize=11, color=col)
        ax.set_xlabel("Normalised time (0=start, 1=end)", fontsize=9)
        ax.set_ylabel("Lateral signal (px/frame)", fontsize=9)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    save_fig(fig, "04_lateral_signal_profiles")
# ─────────────────────────────────────────────────────────────────────
# PLOT 5 — Trajectory Gallery
# ─────────────────────────────────────────────────────────────────────
def plot_05_trajectory_gallery(db):
    """Draw every clip's cumulative lateral trajectory as a small panel.

    Panels are laid out 7 per row, coloured by direction and titled with the
    clip-id suffix, the direction and turn_ratio. Unused grid cells are
    hidden. Skips when the database is empty.
    Saves ``05_trajectory_gallery.png``.
    """
    print("[5/10] Trajectory gallery...")
    entries = list(db.values())
    n = len(entries)
    if n == 0:
        # plt.subplots(0, ncols) would raise.
        print("  Skipped — database is empty.")
        return

    ncols = 7
    nrows = (n + ncols - 1) // ncols  # ceil(n / ncols)
    fig, axes = plt.subplots(nrows, ncols,
                             figsize=(ncols * 2.2, nrows * 2.4), squeeze=False)
    fig.suptitle("Plot 5 — Trajectory Gallery (every clip)\n"
                 "Each panel = one 3s clip's cumulative lateral trajectory",
                 fontsize=12, fontweight="bold", color="white")

    # squeeze=False guarantees a 2-D array, so ravel() is always valid.
    axes_flat = axes.ravel()
    for ax, entry in zip(axes_flat, entries):
        col = DIR_COLORS[entry["direction"]]
        traj = np.array(entry["trajectory_raw"])
        t = np.linspace(0, 1, len(traj))
        ax.plot(t, traj, color=col, lw=1.8)
        ax.fill_between(t, traj, 0, alpha=0.18, color=col)
        ax.axhline(0, color="#555577", lw=0.8)
        ax.set_title(
            f"{entry['clip_id'].split('__')[-1]}\n"
            f"{entry['direction']} {entry['turn_ratio']:+.2f}",
            fontsize=6.5, color=col
        )
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ("bottom", "left"):
            ax.spines[spine].set_color(col)
        for spine in ("top", "right"):
            ax.spines[spine].set_color("#333355")

    # Hide the unused cells after the last clip.
    for ax in axes_flat[n:]:
        ax.set_visible(False)

    patches = [mpatches.Patch(color=v, label=k) for k, v in DIR_COLORS.items()]
    fig.legend(handles=patches, loc="lower center", ncol=3,
               fontsize=10, bbox_to_anchor=(0.5, -0.01))
    plt.tight_layout()
    save_fig(fig, "05_trajectory_gallery")
# ─────────────────────────────────────────────────────────────────────
# PLOT 6 — Turn Ratio Distribution
# ─────────────────────────────────────────────────────────────────────
def plot_06_turn_ratio_distribution(db):
    """Show how turn_ratio separates the labelled directions.

    Panels: A) overlaid per-direction histograms on one shared set of bin
    edges, B) a jittered strip plot, C) peak_lateral vs turn_ratio. Dashed
    lines mark the ±0.60 threshold. Saves ``06_turn_ratio_distribution.png``.
    """
    print("[6/10] Turn ratio distribution...")
    entries = list(db.values())
    ratios = [e["turn_ratio"] for e in entries]
    directions = [e["direction"] for e in entries]
    peaks = [e["peak_lateral"] for e in entries]
    # NOTE(review): ±0.60 is drawn as the left/right decision threshold —
    # confirm it matches the value used where "direction" is assigned.
    thr = 0.60

    fig, axes = plt.subplots(1, 3, figsize=(17, 5))
    fig.suptitle("Plot 6 — Turn Ratio Distribution\n"
                 "Three separate peaks = clean separation between directions",
                 fontsize=12, fontweight="bold", color="white")

    # A: histogram. One shared set of bin edges makes bars from different
    # directions line up with equal widths; bins=20 per call gave every
    # direction its own edges.
    ax = axes[0]
    bin_edges = np.histogram_bin_edges(ratios, bins=20)
    for d, col in DIR_COLORS.items():
        vals = [r for r, dr in zip(ratios, directions) if dr == d]
        if vals:
            ax.hist(vals, bins=bin_edges, color=col, alpha=0.7,
                    label=f"{d} ({len(vals)})", edgecolor="none")
    ax.axvline(thr, color="#ffeb3b", lw=2, ls="--", label=f"+{thr:.2f} threshold")
    ax.axvline(-thr, color="#ffeb3b", lw=2, ls="--", label=f"-{thr:.2f} threshold")
    ax.axvline(0.0, color="#ffffff", lw=1, ls=":", alpha=0.5)
    ax.set_xlabel("turn_ratio", fontsize=10)
    ax.set_ylabel("Count", fontsize=10)
    ax.set_title(f"A — Histogram\n(gap at ±{thr:.2f}?)", fontsize=10)
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # B: strip plot (random vertical jitter; fixed seed for reproducibility)
    ax = axes[1]
    rng = np.random.default_rng(42)
    jitter = rng.uniform(-0.2, 0.2, len(ratios))
    for d, col in DIR_COLORS.items():
        idx = [i for i, dr in enumerate(directions) if dr == d]
        ax.scatter([ratios[i] for i in idx], [jitter[i] for i in idx],
                   color=col, s=70, alpha=0.85,
                   marker=DIR_MARKERS[d], label=d, zorder=3)
    for x in (thr, -thr):
        ax.axvline(x, color="#ffeb3b", lw=2, ls="--")
    ax.axvline(0.0, color="#ffffff", lw=1, ls=":", alpha=0.5)
    ax.set_xlabel("turn_ratio", fontsize=10)
    ax.set_yticks([])
    ax.set_title("B — Strip plot\n(overlapping near threshold = borderline clips)",
                 fontsize=10)
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3, axis="x")

    # C: peak lateral vs turn ratio
    ax = axes[2]
    for d, col in DIR_COLORS.items():
        idx = [i for i, dr in enumerate(directions) if dr == d]
        ax.scatter([ratios[i] for i in idx], [peaks[i] for i in idx],
                   color=col, s=70, alpha=0.85,
                   marker=DIR_MARKERS[d], label=d)
    ax.set_xlabel("turn_ratio", fontsize=10)
    ax.set_ylabel("peak_lateral (px)", fontsize=10)
    ax.set_title("C — Peak drift vs turn ratio\n(stronger turn = bigger peak)",
                 fontsize=10)
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    save_fig(fig, "06_turn_ratio_distribution")
# ─────────────────────────────────────────────────────────────────────
# PLOT 7 — Embedding Space (PCA + t-SNE)
# ─────────────────────────────────────────────────────────────────────
def plot_07_embedding_space(db):
    """Project clip embeddings to 2-D (PCA, t-SNE) and show PCA variance.

    Panels:
    * A — 2-component PCA scatter;
    * B — t-SNE scatter, drawn only with at least 6 clips;
    * C — per-component and cumulative PCA explained variance.

    Points are coloured and marked by direction and annotated with the
    clip-id suffix. Skips when there are fewer than two clips or the
    embeddings are not at least 2-D vectors, because a 2-component PCA is
    then impossible. Saves ``07_embedding_space.png``.
    """
    print("[7/10] Embedding space...")
    entries = list(db.values())
    n_clips = len(entries)
    embeddings = np.array([e["embedding"] for e in entries])
    if n_clips < 2 or embeddings.ndim != 2 or embeddings.shape[1] < 2:
        print("  Skipped — need at least 2 clips with >= 2-D embeddings.")
        return
    n_dims = embeddings.shape[1]
    directions = [e["direction"] for e in entries]
    labels = [e["clip_id"].split("__")[-1] for e in entries]
    dir_masks = {d: np.array([dr == d for dr in directions]) for d in DIR_COLORS}

    def scatter_by_direction(ax, pts_2d):
        """Scatter 2-D points coloured by direction, each labelled with its clip id."""
        for d, col in DIR_COLORS.items():
            mask = dir_masks[d]
            if mask.any():
                ax.scatter(pts_2d[mask, 0], pts_2d[mask, 1],
                           color=col, s=90, alpha=0.85,
                           marker=DIR_MARKERS[d], label=d, zorder=3)
        for x, y, lbl in zip(pts_2d[:, 0], pts_2d[:, 1], labels):
            ax.annotate(lbl, (x, y), fontsize=5.5, color="#aaaacc",
                        alpha=0.7, xytext=(3, 3), textcoords="offset points")

    fig = plt.figure(figsize=(18, 7))
    fig.suptitle("Plot 7 — Embedding Space\n"
                 "Separated clusters = embeddings discriminate directions well",
                 fontsize=12, fontweight="bold", color="white")
    gs = gridspec.GridSpec(1, 3, figure=fig, wspace=0.35)

    # A: PCA
    ax_pca = fig.add_subplot(gs[0])
    pca = PCA(n_components=2)
    scatter_by_direction(ax_pca, pca.fit_transform(embeddings))
    evr2 = pca.explained_variance_ratio_
    ax_pca.set_title(f"A — PCA\n({evr2.sum() * 100:.1f}% variance)", fontsize=10)
    ax_pca.set_xlabel(f"PC1 {evr2[0] * 100:.1f}%", fontsize=9)
    ax_pca.set_ylabel(f"PC2 {evr2[1] * 100:.1f}%", fontsize=9)
    ax_pca.legend(fontsize=9)
    ax_pca.grid(True, alpha=0.3)

    # B: t-SNE
    ax_tsne = fig.add_subplot(gs[1])
    if n_clips >= 6:
        # perplexity must be smaller than the number of samples.
        # NOTE: TSNE's max_iter requires scikit-learn >= 1.5 (older: n_iter).
        tsne = TSNE(n_components=2, perplexity=min(5, n_clips - 1),
                    random_state=42, max_iter=1000)
        scatter_by_direction(ax_tsne, tsne.fit_transform(embeddings))
        ax_tsne.set_title("B — t-SNE (non-linear local structure)", fontsize=10)
        ax_tsne.legend(fontsize=9)
    else:
        ax_tsne.text(0.5, 0.5, f"Need ≥ 6 clips\n(have {n_clips})",
                     ha="center", va="center", transform=ax_tsne.transAxes,
                     color="#ff5252", fontsize=12)
        ax_tsne.set_title("B — t-SNE", fontsize=10)
    ax_tsne.grid(True, alpha=0.3)

    # C: explained variance. PCA can fit at most min(n_samples, n_features)
    # components; the embedding dimension was previously not considered and
    # short embeddings made PCA raise.
    ax_var = fig.add_subplot(gs[2])
    n_comp = min(15, n_clips - 1, n_dims)
    pca_f = PCA(n_components=n_comp)
    pca_f.fit(embeddings)
    evr = pca_f.explained_variance_ratio_
    comps = np.arange(1, len(evr) + 1)
    ax_var.bar(comps, evr * 100, color="#e040fb", alpha=0.8, label="Individual")
    ax_var.plot(comps, np.cumsum(evr) * 100, color="#ffeb3b", lw=2,
                marker="o", ms=4, label="Cumulative")
    ax_var.axhline(90, color="#ff5252", lw=1.5, ls="--",
                   alpha=0.7, label="90% line")
    ax_var.set_xlabel("Principal component", fontsize=9)
    ax_var.set_ylabel("Explained variance (%)", fontsize=9)
    ax_var.set_title("C — PCA explained variance\n"
                     "(few PCs to reach 90% = compact embedding)", fontsize=10)
    ax_var.legend(fontsize=8)
    ax_var.grid(True, alpha=0.3)

    save_fig(fig, "07_embedding_space")
# ─────────────────────────────────────────────────────────────────────
# PLOT 8 — Score Distribution per Query
# ─────────────────────────────────────────────────────────────────────
def plot_08_score_distribution(db):
    """Bar-chart every clip's score against each canonical direction query.

    For each of left/straight/right, every clip is scored with
    ``score_clip(CANONICAL[direction], clip_embedding)``. The scores are
    sorted in descending order and drawn as bars coloured by the clip's
    labelled direction. The top three bars are tagged #1–#3, and every
    second bar is labelled with its clip-id suffix.
    Saves ``08_score_distribution.png``.
    """
    print("[8/10] Score distribution per query...")
    entries = list(db.values())

    fig, axes = plt.subplots(1, 3, figsize=(17, 5))
    fig.suptitle("Plot 8 — Score Distribution for Each Button Query\n"
                 "Sharp drop after top results = clear winner. Flat = poor discrimination.",
                 fontsize=12, fontweight="bold", color="white")

    for ax, direction in zip(axes, ["left", "straight", "right"]):
        query_emb = CANONICAL[direction]
        col = DIR_COLORS[direction]
        ranked = sorted(
            ((score_clip(query_emb, e["embedding"]),
              e["direction"],
              e["clip_id"].split("__")[-1]) for e in entries),
            key=lambda row: row[0], reverse=True,
        )
        scores = [row[0] for row in ranked]
        bar_cols = [DIR_COLORS[row[1]] for row in ranked]
        cids = [row[2] for row in ranked]
        x = np.arange(len(scores))

        ax.bar(x, scores, color=bar_cols, alpha=0.85, width=0.7)
        ax.axhline(0.7, color="#ffeb3b", lw=1.5, ls="--",
                   alpha=0.8, label="0.70 reference")
        ax.set_xlabel("Clip rank", fontsize=9)
        ax.set_ylabel("Combined score", fontsize=9)
        ax.set_title(f"Query: {direction.upper()}\n"
                     f"(bar colour = actual clip direction)", fontsize=10, color=col)
        ax.set_ylim(0, 1.05)
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3, axis="y")
        # Label every second bar to keep the tick labels readable.
        ax.set_xticks(x[::2])
        ax.set_xticklabels(cids[::2], rotation=45, fontsize=6, ha="right")
        for rank in range(min(3, len(scores))):
            ax.text(x[rank], scores[rank] + 0.01,
                    f"#{rank + 1}", ha="center", fontsize=7, color="white")

    patches = [mpatches.Patch(color=v, label=k) for k, v in DIR_COLORS.items()]
    fig.legend(handles=patches, loc="lower center", ncol=3,
               fontsize=9, bbox_to_anchor=(0.5, -0.04))
    plt.tight_layout()
    save_fig(fig, "08_score_distribution")
# ─────────────────────────────────────────────────────────────────────
# PLOT 9 — DTW vs Cosine Scatter
# ─────────────────────────────────────────────────────────────────────
def plot_09_dtw_vs_cosine(db):
    """Scatter each clip's cosine vs DTW similarity to each canonical query.

    Cosine similarity is computed on the full embeddings and rescaled from
    [-1, 1] to [0, 1]. DTW similarity is ``1 / (1 + dtw_distance)`` on the
    first ``N_RESAMPLE`` values of each embedding. Dotted iso-lines mark the
    combined scores 0.5 / 0.7 / 0.8 under ``0.6 * cos + 0.4 * dtw``.
    Saves ``09_dtw_vs_cosine.png``.

    NOTE(review): these formulas and weights are re-implemented here rather
    than taken from ``score_clip`` — confirm they still match retrieval_engine.
    """
    print("[9/10] DTW vs cosine scatter...")
    entries = list(db.values())
    w_cos, w_dtw = 0.6, 0.4

    fig, axes = plt.subplots(1, 3, figsize=(17, 5))
    fig.suptitle("Plot 9 — DTW Similarity vs Cosine Similarity (per query)\n"
                 "Top-right quadrant = high on both = best matches",
                 fontsize=12, fontweight="bold", color="white")

    for ax, q_dir in zip(axes, ["left", "straight", "right"]):
        query_emb = CANONICAL[q_dir]
        # Loop-invariant conversions, done once per query rather than per clip.
        query_vec = query_emb.astype(np.float64)
        query_traj = query_emb[:N_RESAMPLE].astype(np.double)

        cos_sims, dtw_sims, clip_dirs = [], [], []
        for e in entries:
            emb = e["embedding"]
            raw_cos = 1 - scipy_cosine(query_vec, emb.astype(np.float64))
            cos_sims.append((raw_cos + 1) / 2)  # [-1, 1] -> [0, 1]
            dist = dtw_lib.distance_fast(query_traj, emb[:N_RESAMPLE].astype(np.double))
            dtw_sims.append(1 / (1 + dist))
            clip_dirs.append(e["direction"])

        for d, col in DIR_COLORS.items():
            idx = [i for i, cd in enumerate(clip_dirs) if cd == d]
            ax.scatter([cos_sims[i] for i in idx], [dtw_sims[i] for i in idx],
                       color=col, s=80, alpha=0.85,
                       marker=DIR_MARKERS[d], label=d, zorder=3)

        # Iso-score lines: score = w_cos*cos + w_dtw*dtw  →  dtw = (score - w_cos*cos) / w_dtw
        cs_grid = np.linspace(0, 1, 200)
        for target in (0.5, 0.7, 0.8):
            dtw_iso = (target - w_cos * cs_grid) / w_dtw
            valid = (dtw_iso >= 0) & (dtw_iso <= 1)
            ax.plot(cs_grid[valid], dtw_iso[valid],
                    color="#555577", lw=1, ls=":", alpha=0.8)
            if valid.any():
                ax.text(cs_grid[valid][-1], dtw_iso[valid][-1], f"score={target}",
                        fontsize=6, color="#888899")

        ax.set_xlabel("Cosine similarity (α=0.6)", fontsize=9)
        ax.set_ylabel("DTW similarity (β=0.4)", fontsize=9)
        ax.set_title(f"Query: {q_dir.upper()}", fontsize=10,
                     color=DIR_COLORS[q_dir])
        ax.set_xlim(-0.05, 1.05)
        ax.set_ylim(-0.05, 1.05)
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)
        ax.text(0.84, 0.93, "BEST", fontsize=8, color="#00e676",
                ha="center", transform=ax.transAxes, alpha=0.7)
        ax.text(0.14, 0.07, "WORST", fontsize=8, color="#ff5252",
                ha="center", transform=ax.transAxes, alpha=0.7)

    plt.tight_layout()
    save_fig(fig, "09_dtw_vs_cosine")
# ─────────────────────────────────────────────────────────────────────
# PLOT 10 — Confusion Matrix
# ─────────────────────────────────────────────────────────────────────
def plot_10_confusion_matrix(db):
    """Tabulate which directions each canonical query retrieves in its top-k.

    For each query direction, every clip is ranked by
    ``score_clip(CANONICAL[direction], clip_embedding)``. The labelled
    directions of the ``top_k`` (5) best clips are counted into a
    query × retrieved matrix. Panel A shows the counts; panel B shows
    per-query precision (correct / number retrieved). A clip whose direction
    is not left/straight/right still occupies a top-k slot but lands in no
    column, instead of raising ``ValueError``.
    Saves ``10_confusion_matrix.png``.
    """
    print("[10/10] Confusion matrix...")
    entries = list(db.values())
    dir_list = ["left", "straight", "right"]
    dir_index = {d: i for i, d in enumerate(dir_list)}
    n = len(dir_list)
    top_k = 5
    n_retrieved = min(top_k, len(entries))

    confusion = np.zeros((n, n), dtype=int)
    for qi, q_dir in enumerate(dir_list):
        query_emb = CANONICAL[q_dir]
        scored = sorted(
            ((score_clip(query_emb, e["embedding"]), e["direction"]) for e in entries),
            key=lambda x: x[0], reverse=True,
        )
        for _, ret_dir in scored[:top_k]:
            ri = dir_index.get(ret_dir)
            if ri is not None:
                confusion[qi, ri] += 1

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle(f"Plot 10 — Direction Confusion Matrix (top-{top_k} retrieval)\n"
                 "Diagonal = correct. Off-diagonal = wrong direction returned.",
                 fontsize=12, fontweight="bold", color="white")

    # A: heatmap
    ax = axes[0]
    im = ax.imshow(confusion, cmap="YlOrRd", vmin=0, vmax=top_k)
    ax.set_xticks(range(n))
    ax.set_yticks(range(n))
    ax.set_xticklabels(dir_list, fontsize=11)
    ax.set_yticklabels(dir_list, fontsize=11)
    ax.set_xlabel(f"Retrieved direction (top-{top_k})", fontsize=10)
    ax.set_ylabel("Query direction", fontsize=10)
    ax.set_title("A — Raw counts", fontsize=10)
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    for i in range(n):
        for j in range(n):
            # YlOrRd runs from pale yellow (low) to dark red (high): use dark
            # text on the light low end and white text on the dark high end.
            ax.text(j, i, str(confusion[i, j]),
                    ha="center", va="center", fontsize=16, fontweight="bold",
                    color="white" if confusion[i, j] > top_k * 0.5 else "black")

    # B: precision bar chart
    ax = axes[1]
    precisions = [confusion[i, i] / max(n_retrieved, 1) * 100 for i in range(n)]
    bars = ax.bar(dir_list, precisions, color=[DIR_COLORS[d] for d in dir_list],
                  alpha=0.85, width=0.5)
    ax.axhline(100, color="#00e676", lw=1.5, ls="--", alpha=0.6, label="100%")
    ax.axhline(60, color="#ffeb3b", lw=1.5, ls="--", alpha=0.6, label="60% threshold")
    for bar, val in zip(bars, precisions):
        ax.text(bar.get_x() + bar.get_width() / 2,
                bar.get_height() + 1.5,
                f"{val:.0f}%", ha="center", fontsize=13,
                fontweight="bold", color="white")
    ax.set_ylim(0, 115)
    ax.set_ylabel(f"Precision@{top_k} (%)", fontsize=10)
    ax.set_title(f"B — Precision@{top_k} per direction\n"
                 f"(what % of top-{top_k} are the correct direction?)", fontsize=10)
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3, axis="y")

    plt.tight_layout()
    save_fig(fig, "10_confusion_matrix")
# ─────────────────────────────────────────────────────────────────────
# MAIN
# ─────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    print("=" * 60)
    print(" Trajectory Retrieval — Diagnostic Visualizations")
    print("=" * 60)

    if not os.path.exists(PKL_PATH):
        print(f"\nERROR: PKL not found at '{PKL_PATH}'")
        print("Run video_processor.py first to generate the database.")
        raise SystemExit(1)

    db = load_db(PKL_PATH)
    # Same accessibility rule the plots use, instead of a separate os.path check.
    accessible = len(entries_with_video(db))
    print(f"Clips total : {len(db)}")
    print(f"With accessible video : {accessible}")
    print(f"Plots output folder : {OUT_DIR}/\n")

    # Plots that read video frames skip clips whose video file is missing.
    # Each plot runs independently: a failure is reported and the remaining
    # plots still run.
    plot_fns = [
        plot_01_feature_detection,
        plot_02_flow_field,
        plot_03_foe_stability,
        plot_04_lateral_profiles,
        plot_05_trajectory_gallery,
        plot_06_turn_ratio_distribution,
        plot_07_embedding_space,
        plot_08_score_distribution,
        plot_09_dtw_vs_cosine,
        plot_10_confusion_matrix,
    ]
    failed = []
    for plot_fn in plot_fns:
        try:
            plot_fn(db)
        except Exception:  # top-level boundary: report, then keep going
            traceback.print_exc()
            failed.append(plot_fn.__name__)
            plt.close("all")  # discard any half-built figure

    print("\n" + "=" * 60)
    if failed:
        print(f" {len(failed)} of {len(plot_fns)} plots failed: {', '.join(failed)}")
        print(f" Remaining plots saved to → {OUT_DIR}/")
    else:
        print(f" All {len(plot_fns)} plot steps finished; output in → {OUT_DIR}/")
    print("=" * 60)
    if failed:
        raise SystemExit(1)