# Source: LK_trajectory / visualize.py
# Uploaded by sanskar753 using huggingface_hub (commit 48f8a1e, verified).
"""
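visualize.py
============
All diagnostic plots for the trajectory retrieval system.

BEFORE RUNNING — set your paths in the "SET YOUR PATHS HERE" section below:
    PKL_PATH  : your processed trajectory database pkl
    VIDEO_DIR : folder containing your dashcam videos (informational only;
                each clip's video is opened from the "video_path" stored
                in the pkl)
    OUT_DIR   : where to save the plots

Run:
    python visualize.py

All 10 plots are saved as PNG files in OUT_DIR. Plots 1–3 only use clips
whose video file is accessible on disk.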
"""
import os
import sys
import pickle
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.gridspec as gridspec
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy.spatial.distance import cosine as scipy_cosine
from dtaidistance import dtw as dtw_lib
from collections import defaultdict
import cv2
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from trajectory_extractor import (
extract_clip_frames, build_roi_mask,
estimate_lambda, subtract_forward_motion,
remove_outliers_iqr, FOE_X, FOE_Y, N_RESAMPLE,
ST_PARAMS, LK_PARAMS
)
from retrieval_engine import CANONICAL, score_clip
# ═══════════════════════════════════════════════════════════════════
# ▶▶ SET YOUR PATHS HERE
# Each path can also be overridden from the environment
# (TRAJ_PKL_PATH, TRAJ_VIDEO_DIR, TRAJ_OUT_DIR), so the script runs on
# another machine without editing this file. Defaults are unchanged.
# ═══════════════════════════════════════════════════════════════════
PKL_PATH = os.environ.get("TRAJ_PKL_PATH", "trajectory_db100.pkl")
# NOTE: VIDEO_DIR is not referenced by the plotting code below; each
# clip's video is opened from the "video_path" stored in its pkl entry.
VIDEO_DIR = os.environ.get(
    "TRAJ_VIDEO_DIR",
    "/media/RTCIN15TBD/AllUsers/sjl3kor/trajectory_retrieval_system/data/videos",
)
OUT_DIR = os.environ.get("TRAJ_OUT_DIR", "plots")
# ═══════════════════════════════════════════════════════════════════
# ── Global dark style ───────────────────────────────────────────────
# One dark theme shared by every figure in this module, applied once at
# import time through matplotlib's global rcParams.
_DARK_THEME = {
    # figure / axes backgrounds and frame
    "figure.facecolor": "#0f0f1a",
    "axes.facecolor": "#1a1a2e",
    "axes.edgecolor": "#444466",
    # labels, titles, ticks and generic text
    "axes.labelcolor": "#ccccdd",
    "axes.titlecolor": "#ffffff",
    "xtick.color": "#aaaacc",
    "ytick.color": "#aaaacc",
    "text.color": "#ccccdd",
    # grid lines
    "grid.color": "#2a2a4a",
    "grid.linewidth": 0.6,
    # legend box
    "legend.facecolor": "#1a1a2e",
    "legend.edgecolor": "#444466",
    "font.family": "monospace",
}
plt.rcParams.update(_DARK_THEME)

# Colour and marker per driving direction, used consistently by all plots.
DIR_COLORS = dict(left="#00e676", right="#ff5252", straight="#90caf9")
DIR_MARKERS = dict(left="^", right="v", straight="s")
# ── Helpers ──────────────────────────────────────────────────────────
def load_db(path):
    """Unpickle the trajectory database stored at *path* and return it.

    Prints how many clips were loaded. Callers in this module treat the
    result as a mapping and iterate ``db.values()``.

    NOTE(review): ``pickle.load`` can execute arbitrary code while
    loading; only open database files you produced yourself.
    """
    with open(path, "rb") as handle:
        database = pickle.load(handle)
    print(f"Loaded {len(database)} clips")
    return database
def save_fig(fig, name):
    """Write *fig* to ``OUT_DIR/<name>.png``, close it, and return the path.

    OUT_DIR is created if missing. The figure's own facecolor is passed to
    ``savefig`` so the dark background is kept in the PNG. The figure is
    closed afterwards so long runs do not accumulate open figures.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    out_path = os.path.join(OUT_DIR, f"{name}.png")
    fig.savefig(out_path, dpi=130, bbox_inches="tight",
                facecolor=fig.get_facecolor())
    plt.close(fig)
    print(f" Saved → {out_path}")
    return out_path
def video_exists(entry):
    """Return True when the clip's ``video_path`` exists on disk."""
    video_path = entry["video_path"]
    return os.path.exists(video_path)


def entries_with_video(db):
    """Return the entries of *db* (in iteration order) whose video exists."""
    return list(filter(video_exists, db.values()))
# ═══════════════════════════════════════════════════════════════════
# PLOT 1 — Feature Detection Quality
# ═══════════════════════════════════════════════════════════════════
def plot_01_feature_detection(db):
    """Plot where Shi-Tomasi features land and how many are tracked per clip.

    For every clip whose video is on disk, the first frame returned by
    ``extract_clip_frames`` is run through ``cv2.goodFeaturesToTrack`` with
    the ROI mask and ``ST_PARAMS``. Three panels are drawn:

    A. histogram of feature rows (y) with the ROI band marked,
    B. histogram of feature columns (x) with the image centre marked,
    C. each clip's stored ``n_features_avg``, coloured by direction.

    Clips without an accessible video are ignored. If none remain, the plot
    is skipped, as plots 2 and 3 do.

    Args:
        db: mapping of clip id -> entry dict providing ``video_path``,
            ``start_frame``, ``n_features_avg`` and ``direction``.
    """
    print("\n[1/10] Feature detection quality...")
    # ROI rows drawn on panel A. NOTE(review): assumed to match the band that
    # build_roi_mask uses — TODO confirm in trajectory_extractor.
    roi_top, roi_bot = 200, 880
    # Fallback frame size; replaced by the size of the frames actually read
    # (assumes all clips share one resolution).
    frame_h, frame_w = 1080, 1920
    all_ys, all_xs, feat_counts, clip_dirs = [], [], [], []
    for entry in db.values():
        if not video_exists(entry):
            continue
        frames = extract_clip_frames(entry["video_path"],
                                     entry["start_frame"],
                                     entry["start_frame"] + 2)
        if not frames:
            continue
        gray = frames[0]
        frame_h, frame_w = gray.shape[:2]
        mask = build_roi_mask(gray.shape)
        pts = cv2.goodFeaturesToTrack(gray, mask=mask, **ST_PARAMS)
        if pts is not None:
            # pts is indexed as [point, 0, (x, y)].
            all_xs.extend(pts[:, 0, 0].tolist())
            all_ys.extend(pts[:, 0, 1].tolist())
        feat_counts.append(entry["n_features_avg"])
        clip_dirs.append(entry["direction"])

    if not feat_counts:
        print(" Skipped — no video files accessible.")
        return

    fig, axes = plt.subplots(1, 3, figsize=(17, 5))
    fig.suptitle("Plot 1 — Feature Detection Quality (Shi-Tomasi + ROI Mask)",
                 fontsize=13, fontweight="bold", color="white", y=1.01)

    # A: vertical distribution
    ax = axes[0]
    if all_ys:
        ax.hist(all_ys, bins=54, orientation="horizontal",
                color="#00e5ff", alpha=0.85, edgecolor="none")
    ax.axhline(roi_top, color="#ffeb3b", lw=2, ls="--",
               label=f"ROI top y={roi_top}")
    ax.axhline(roi_bot, color="#ff9800", lw=2, ls="--",
               label=f"ROI bot y={roi_bot}")
    ax.axhspan(0, roi_top, alpha=0.12, color="#ffeb3b")
    ax.axhspan(roi_bot, frame_h, alpha=0.12, color="#ff9800")
    ax.invert_yaxis()  # row 0 at the top, as in the image
    ax.set_xlabel("Feature count", fontsize=10)
    ax.set_ylabel("Image row (y pixel)", fontsize=10)
    ax.set_title(f"A — Vertical distribution\n(should cluster {roi_top}–{roi_bot})",
                 fontsize=10)
    ax.legend(fontsize=8); ax.grid(True, alpha=0.3)

    # B: horizontal distribution
    ax = axes[1]
    if all_xs:
        ax.hist(all_xs, bins=48, color="#e040fb", alpha=0.85, edgecolor="none")
    center_x = frame_w / 2
    ax.axvline(center_x, color="#ffffff", lw=1.5, ls=":",
               label=f"center x={center_x:g}")
    ax.set_xlabel("Image column (x pixel)", fontsize=10)
    ax.set_ylabel("Feature count", fontsize=10)
    ax.set_title("B — Horizontal distribution\n(should be roughly symmetric)", fontsize=10)
    ax.legend(fontsize=8); ax.grid(True, alpha=0.3)

    # C: feature count per clip
    ax = axes[2]
    colors = [DIR_COLORS[d] for d in clip_dirs]
    ax.bar(range(len(feat_counts)), feat_counts,
           color=colors, alpha=0.85, width=0.7)
    min_line = ax.axhline(50, color="#ff5252", lw=1.5, ls="--",
                          label="Min reliable = 50")
    patches = [mpatches.Patch(color=v, label=k) for k, v in DIR_COLORS.items()]
    # With explicit handles= the labelled threshold line would be dropped,
    # so pass it alongside the direction patches.
    ax.legend(handles=patches + [min_line], fontsize=8)
    ax.set_xlabel("Clip index", fontsize=10)
    ax.set_ylabel("Avg tracked features", fontsize=10)
    ax.set_title("C — Avg features per clip\n(below 50 = unreliable)", fontsize=10)
    ax.grid(True, alpha=0.3, axis="y")
    plt.tight_layout()
    save_fig(fig, "01_feature_detection_quality")
# ═══════════════════════════════════════════════════════════════════
# PLOT 2 — Optical Flow Quiver Field
# ═══════════════════════════════════════════════════════════════════
def plot_02_flow_field(db):
    """Draw the Lucas-Kanade flow field for a straight clip and the strongest turn.

    Only clips whose video is on disk are used. The left panel shows the
    first clip labelled "straight"; the right panel shows the clip with the
    largest |turn_ratio|. Flow is computed between the first two extracted
    frames of each clip and drawn on a 960x540 canvas.

    Args:
        db: mapping of clip id -> entry dict providing ``video_path``,
            ``start_frame``, ``direction`` and ``turn_ratio``.
    """
    print("[2/10] Flow vector field...")
    ev = entries_with_video(db)
    if not ev:
        print(" Skipped — no video files accessible.")
        return
    straight_e = next((e for e in ev if e["direction"] == "straight"), None)
    turn_e = max(ev, key=lambda e: abs(e["turn_ratio"]))

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    fig.suptitle("Plot 2 — Optical Flow Vector Field\n"
                 "Left: straight | Right: strongest turn",
                 fontsize=12, fontweight="bold", color="white")

    def draw_quiver(ax, entry, title):
        """Draw LK flow between the clip's first two frames onto *ax*."""
        frames = extract_clip_frames(entry["video_path"],
                                     entry["start_frame"],
                                     entry["start_frame"] + 3)
        if len(frames) < 2:
            ax.set_title(f"{title}\n(fewer than 2 frames)", fontsize=10)
            return
        g0, g1 = frames[0], frames[1]
        mask = build_roi_mask(g0.shape)
        pts0 = cv2.goodFeaturesToTrack(g0, mask=mask, **ST_PARAMS)
        if pts0 is None:
            ax.set_title(f"{title}\n(no features detected)", fontsize=10)
            return
        pts1, status, _ = cv2.calcOpticalFlowPyrLK(g0, g1, pts0, None, **LK_PARAMS)
        good0 = pts0[status == 1].reshape(-1, 2)
        good1 = pts1[status == 1].reshape(-1, 2)
        dx = good1[:, 0] - good0[:, 0]
        dy = good1[:, 1] - good0[:, 1]
        # Map frame pixels onto the 960x540 canvas using the real frame size
        # (1920x1080 frames give the previous fixed 0.5 scale).
        frame_h, frame_w = g0.shape[:2]
        sx, sy = 960 / frame_w, 540 / frame_h
        xs = good0[:, 0] * sx
        ys = good0[:, 1] * sy
        mag = np.hypot(dx, dy)
        ax.set_facecolor("#0d1117")
        # The y axis is inverted below (image rows grow downward). With
        # angles="xy" an arrow runs from (x, y) to (x+u, y+v) in data
        # coordinates, so v must be +dy; negating it would flip every
        # vertical flow component. Arrows are exaggerated 12x for visibility.
        ax.quiver(xs, ys, dx * 12, dy * 12, mag,
                  cmap="plasma", alpha=0.85,
                  scale=1, scale_units="xy", angles="xy", width=0.003)
        ax.scatter([FOE_X * sx], [FOE_Y * sy],
                   c="red", s=150, zorder=6, label=f"FOE ({FOE_X},{FOE_Y})")
        ax.axhline(200 * sy, color="#ffeb3b", lw=1.2, ls="--",
                   alpha=0.6, label="ROI boundary")
        ax.axhline(880 * sy, color="#ffeb3b", lw=1.2, ls="--", alpha=0.6)
        ax.set_xlim(0, 960); ax.set_ylim(540, 0)
        ax.set_xlabel("x (scaled)", fontsize=9)
        ax.set_ylabel("y (scaled)", fontsize=9)
        ax.set_title(f"{title}\n"
                     f"dir={entry['direction']} "
                     f"ratio={entry['turn_ratio']:+.3f}", fontsize=10)
        ax.legend(fontsize=8)

    if straight_e is not None:
        draw_quiver(axes[0], straight_e, "Straight clip")
    else:
        axes[0].set_title("No straight clip found")
    draw_quiver(axes[1], turn_e, "Strongest turn clip")
    plt.tight_layout()
    save_fig(fig, "02_flow_vector_field")
# ═══════════════════════════════════════════════════════════════════
# PLOT 3 — FOE Stability + Lateral Signal per Frame
# ═══════════════════════════════════════════════════════════════════
def plot_03_foe_stability(db):
    """Show cumulative and per-frame lateral signal for one clip per direction.

    Only clips whose video is on disk are considered. The panels themselves
    are drawn from the stored ``trajectory_raw`` and ``lateral_signals``
    fields. For each direction, the clip with the largest |turn_ratio| is
    shown; for "straight" that is the most borderline clip.

    Args:
        db: mapping of clip id -> entry dict providing ``video_path``,
            ``direction``, ``turn_ratio``, ``trajectory_raw`` and
            ``lateral_signals``.
    """
    print("[3/10] FOE stability + lateral signal...")
    ev = entries_with_video(db)
    if not ev:
        print(" Skipped — no video files accessible.")
        return
    # Pick one clip per direction (prefer strongest |turn_ratio|).
    selected = {}
    for direction in ("left", "right", "straight"):
        candidates = [e for e in ev if e["direction"] == direction]
        if candidates:
            selected[direction] = max(candidates, key=lambda e: abs(e["turn_ratio"]))
    if not selected:
        print(" Skipped — no left/right/straight clips with video.")
        return

    # squeeze=False keeps axes 2-D even when only one row is drawn.
    fig, axes = plt.subplots(len(selected), 2,
                             figsize=(14, 4 * len(selected)), squeeze=False)
    fig.suptitle("Plot 3 — Lateral Signal Per Frame (one clip per direction)\n"
                 "Consistent sign = clean turn detection",
                 fontsize=12, fontweight="bold", color="white")
    for row, (direction, entry) in enumerate(selected.items()):
        col = DIR_COLORS[direction]
        lat_s = entry["lateral_signals"]
        ax_traj, ax_lat = axes[row]

        # Left sub-panel: cumulative trajectory over normalised time.
        traj_raw = np.array(entry["trajectory_raw"])
        t_traj = np.linspace(0, 1, len(traj_raw))
        ax_traj.plot(t_traj, traj_raw, color=col, lw=2)
        ax_traj.fill_between(t_traj, traj_raw, 0, alpha=0.2, color=col)
        ax_traj.axhline(0, color="#ffffff", lw=1, ls="--", alpha=0.4)
        ax_traj.set_xlabel("Normalised time", fontsize=9)
        ax_traj.set_ylabel("Cumulative lateral (px)", fontsize=9)
        ax_traj.set_title(f"{direction.upper()} "
                          f"turn_ratio={entry['turn_ratio']:+.3f}\n"
                          f"Cumulative trajectory",
                          fontsize=10, color=col)
        ax_traj.grid(True, alpha=0.3)

        # Right sub-panel: per-frame lateral bars, coloured by sign.
        bar_colors = [DIR_COLORS["left"] if v >= 0 else DIR_COLORS["right"]
                      for v in lat_s]
        ax_lat.bar(np.arange(len(lat_s)), lat_s, color=bar_colors, alpha=0.8, width=0.8)
        ax_lat.axhline(0, color="#ffffff", lw=1, ls="--", alpha=0.4)
        ax_lat.set_xlabel("Frame pair index", fontsize=9)
        ax_lat.set_ylabel("Lateral signal (px/frame)", fontsize=9)
        ax_lat.set_title("Per-frame lateral signal\n"
                         "Green=rightward(left turn) Red=leftward(right turn)",
                         fontsize=9)
        ax_lat.grid(True, alpha=0.3)
    plt.tight_layout()
    save_fig(fig, "03_foe_stability")
# ═══════════════════════════════════════════════════════════════════
# PLOT 4 — Lateral Signal Profiles (all clips overlaid)
# ═══════════════════════════════════════════════════════════════════
def plot_04_lateral_profiles(db):
    """Overlay every clip's per-frame lateral signal, one panel per direction.

    Each signal is drawn against normalised time. Non-empty signals are
    linearly resampled to the longest length so the per-timestep median and
    25–75 percentile band can be drawn on top.

    Args:
        db: mapping of clip id -> entry dict providing ``direction`` and
            ``lateral_signals``.
    """
    print("[4/10] Lateral signal profiles...")
    by_dir = defaultdict(list)
    for entry in db.values():
        by_dir[entry["direction"]].append(entry["lateral_signals"])

    fig, axes = plt.subplots(1, 3, figsize=(17, 5), sharey=False)
    fig.suptitle("Plot 4 — Per-Frame Lateral Signal (all clips overlaid per direction)\n"
                 "Consistent direction = pipeline is working. Mixed = noise.",
                 fontsize=12, fontweight="bold", color="white")
    for ax, direction in zip(axes, ["left", "straight", "right"]):
        signals = by_dir[direction]
        col = DIR_COLORS[direction]
        for sig in signals:
            ax.plot(np.linspace(0, 1, len(sig)), sig, color=col, lw=1, alpha=0.35)
        # np.interp cannot resample an empty series, so only non-empty
        # signals feed the median/IQR summary.
        nonempty = [s for s in signals if len(s) > 0]
        if nonempty:
            max_len = max(len(s) for s in nonempty)
            t_med = np.linspace(0, 1, max_len)
            arr = np.array([np.interp(t_med, np.linspace(0, 1, len(s)), s)
                            for s in nonempty])
            ax.plot(t_med, np.median(arr, axis=0),
                    color="white", lw=2.5, label="Median", zorder=5)
            ax.fill_between(t_med,
                            np.percentile(arr, 25, axis=0),
                            np.percentile(arr, 75, axis=0),
                            alpha=0.2, color=col, label="IQR band")
            # Only add a legend when there are labelled artists to show.
            ax.legend(fontsize=8)
        ax.axhline(0, color="#ffffff", lw=1, ls="--", alpha=0.4)
        ax.set_title(f"{direction.upper()} ({len(signals)} clips)",
                     fontsize=11, color=col)
        ax.set_xlabel("Normalised time (0=start, 1=end)", fontsize=9)
        ax.set_ylabel("Lateral signal (px/frame)", fontsize=9)
        ax.grid(True, alpha=0.3)
    plt.tight_layout()
    save_fig(fig, "04_lateral_signal_profiles")
# ═══════════════════════════════════════════════════════════════════
# PLOT 5 — Trajectory Gallery
# ═══════════════════════════════════════════════════════════════════
def plot_05_trajectory_gallery(db):
    """Draw one small panel per clip with its cumulative lateral trajectory.

    Panels are laid out 7 per row, coloured by direction, and titled with
    the last ``__``-separated part of the clip id plus direction and
    turn_ratio. Unused grid cells are hidden. An empty database is skipped,
    because a zero-row subplot grid cannot be created.

    Args:
        db: mapping of clip id -> entry dict providing ``clip_id``,
            ``direction``, ``turn_ratio`` and ``trajectory_raw``.
    """
    print("[5/10] Trajectory gallery...")
    entries = list(db.values())
    n = len(entries)
    if n == 0:
        print(" Skipped — database is empty.")
        return
    ncols = 7
    nrows = (n + ncols - 1) // ncols  # ceil(n / ncols)
    # squeeze=False always yields a 2-D array, so ravel() works for any grid.
    fig, axes = plt.subplots(nrows, ncols,
                             figsize=(ncols * 2.2, nrows * 2.4), squeeze=False)
    fig.suptitle("Plot 5 — Trajectory Gallery (every clip)\n"
                 "Each panel = one 3s clip's cumulative lateral trajectory",
                 fontsize=12, fontweight="bold", color="white")
    axes_flat = axes.ravel()
    for ax, entry in zip(axes_flat, entries):
        col = DIR_COLORS[entry["direction"]]
        traj = np.array(entry["trajectory_raw"])
        t = np.linspace(0, 1, len(traj))
        ax.plot(t, traj, color=col, lw=1.8)
        ax.fill_between(t, traj, 0, alpha=0.18, color=col)
        ax.axhline(0, color="#555577", lw=0.8)
        ax.set_title(
            f"{entry['clip_id'].split('__')[-1]}\n"
            f"{entry['direction']} {entry['turn_ratio']:+.2f}",
            fontsize=6.5, color=col
        )
        ax.set_xticks([]); ax.set_yticks([])
        for spine in ["bottom", "left"]:
            ax.spines[spine].set_color(col)
        for spine in ["top", "right"]:
            ax.spines[spine].set_color("#333355")
    # Hide the empty cells after the last clip.
    for ax in axes_flat[n:]:
        ax.set_visible(False)
    patches = [mpatches.Patch(color=v, label=k) for k, v in DIR_COLORS.items()]
    fig.legend(handles=patches, loc="lower center", ncol=3,
               fontsize=10, bbox_to_anchor=(0.5, -0.01))
    plt.tight_layout()
    save_fig(fig, "05_trajectory_gallery")
# ═══════════════════════════════════════════════════════════════════
# PLOT 6 — Turn Ratio Distribution
# ═══════════════════════════════════════════════════════════════════
def plot_06_turn_ratio_distribution(db, turn_threshold=0.60):
    """Show how turn_ratio (and peak_lateral) separate the three directions.

    Panels: A. per-direction histogram of turn_ratio; B. strip plot of
    turn_ratio with fixed-seed vertical jitter; C. peak_lateral against
    turn_ratio.

    Args:
        db: mapping of clip id -> entry dict providing ``turn_ratio``,
            ``direction`` and ``peak_lateral``.
        turn_threshold: |turn_ratio| at which the dashed reference lines are
            drawn (default 0.60). NOTE(review): presumably the threshold used
            to label clip directions — confirm against the extractor.
    """
    print("[6/10] Turn ratio distribution...")
    entries = list(db.values())
    ratios = [e["turn_ratio"] for e in entries]
    directions = [e["direction"] for e in entries]
    peaks = [e["peak_lateral"] for e in entries]
    # Positions of each direction's clips, shared by all three panels.
    idx_by_dir = {d: [i for i, dr in enumerate(directions) if dr == d]
                  for d in DIR_COLORS}
    thr = f"{turn_threshold:.2f}"

    fig, axes = plt.subplots(1, 3, figsize=(17, 5))
    fig.suptitle("Plot 6 — Turn Ratio Distribution\n"
                 "Three separate peaks = clean separation between directions",
                 fontsize=12, fontweight="bold", color="white")

    # A: histogram
    ax = axes[0]
    for d, col in DIR_COLORS.items():
        vals = [ratios[i] for i in idx_by_dir[d]]
        if vals:
            ax.hist(vals, bins=20, color=col, alpha=0.7,
                    label=f"{d} ({len(vals)})", edgecolor="none")
    ax.axvline(turn_threshold, color="#ffeb3b", lw=2, ls="--", label=f"+{thr} threshold")
    ax.axvline(-turn_threshold, color="#ffeb3b", lw=2, ls="--", label=f"-{thr} threshold")
    ax.axvline(0.0, color="#ffffff", lw=1, ls=":", alpha=0.5)
    ax.set_xlabel("turn_ratio", fontsize=10)
    ax.set_ylabel("Count", fontsize=10)
    ax.set_title(f"A — Histogram\n(gap at ±{thr}?)", fontsize=10)
    ax.legend(fontsize=8); ax.grid(True, alpha=0.3)

    # B: strip plot (the y position is random jitter and carries no meaning)
    ax = axes[1]
    rng = np.random.default_rng(42)
    jitter = rng.uniform(-0.2, 0.2, len(ratios))
    for d, col in DIR_COLORS.items():
        idx = idx_by_dir[d]
        ax.scatter([ratios[i] for i in idx], [jitter[i] for i in idx],
                   color=col, s=70, alpha=0.85,
                   marker=DIR_MARKERS[d], label=d, zorder=3)
    ax.axvline(turn_threshold, color="#ffeb3b", lw=2, ls="--")
    ax.axvline(-turn_threshold, color="#ffeb3b", lw=2, ls="--")
    ax.axvline(0.0, color="#ffffff", lw=1, ls=":", alpha=0.5)
    ax.set_xlabel("turn_ratio", fontsize=10)
    ax.set_yticks([])
    ax.set_title("B — Strip plot\n(overlapping near threshold = borderline clips)", fontsize=10)
    ax.legend(fontsize=8); ax.grid(True, alpha=0.3, axis="x")

    # C: peak lateral vs turn ratio
    ax = axes[2]
    for d, col in DIR_COLORS.items():
        idx = idx_by_dir[d]
        ax.scatter([ratios[i] for i in idx], [peaks[i] for i in idx],
                   color=col, s=70, alpha=0.85,
                   marker=DIR_MARKERS[d], label=d)
    ax.set_xlabel("turn_ratio", fontsize=10)
    ax.set_ylabel("peak_lateral (px)", fontsize=10)
    ax.set_title("C — Peak drift vs turn ratio\n(stronger turn = bigger peak)", fontsize=10)
    ax.legend(fontsize=8); ax.grid(True, alpha=0.3)
    plt.tight_layout()
    save_fig(fig, "06_turn_ratio_distribution")
# ═══════════════════════════════════════════════════════════════════
# PLOT 7 — Embedding Space (PCA + t-SNE)
# ═══════════════════════════════════════════════════════════════════
def plot_07_embedding_space(db):
    """Project clip embeddings to 2-D (PCA, t-SNE) and show PCA variance.

    Panels: A. 2-component PCA scatter; B. t-SNE scatter (only with ≥ 6
    clips); C. per-component and cumulative PCA explained variance. Points
    are coloured/marked by the clip's stored direction and annotated with
    the last ``__``-separated part of its clip id.

    The plot is skipped with fewer than 2 clips or embeddings of fewer than
    2 dimensions, because a 2-component PCA cannot be fitted then.

    Args:
        db: mapping of clip id -> entry dict providing ``embedding``,
            ``direction`` and ``clip_id``.
    """
    print("[7/10] Embedding space...")
    entries = list(db.values())
    # NOTE(review): assumes every embedding has the same length, giving an
    # (n_clips, dim) matrix — confirm in the database builder.
    embeddings = np.array([e["embedding"] for e in entries])
    n_clips = len(entries)
    if n_clips < 2 or embeddings.ndim != 2 or embeddings.shape[1] < 2:
        print(" Skipped — PCA needs ≥ 2 clips with ≥ 2-D embeddings.")
        return
    n_dims = embeddings.shape[1]
    directions = [e["direction"] for e in entries]
    labels = [e["clip_id"].split("__")[-1] for e in entries]
    # Boolean row mask per direction, reused by both scatter panels.
    dir_masks = {d: np.array([dr == d for dr in directions]) for d in DIR_COLORS}

    def scatter_by_direction(ax, coords):
        """Scatter 2-D *coords* by direction and annotate each clip id."""
        for d, col in DIR_COLORS.items():
            mask = dir_masks[d]
            if mask.any():
                ax.scatter(coords[mask, 0], coords[mask, 1],
                           color=col, s=90, alpha=0.85,
                           marker=DIR_MARKERS[d], label=d, zorder=3)
        for x, y, lbl in zip(coords[:, 0], coords[:, 1], labels):
            ax.annotate(lbl, (x, y), fontsize=5.5, color="#aaaacc",
                        alpha=0.7, xytext=(3, 3), textcoords="offset points")

    fig = plt.figure(figsize=(18, 7))
    fig.suptitle("Plot 7 — Embedding Space\n"
                 "Separated clusters = embeddings discriminate directions well",
                 fontsize=12, fontweight="bold", color="white")
    gs = gridspec.GridSpec(1, 3, figure=fig, wspace=0.35)

    # A: PCA
    ax_pca = fig.add_subplot(gs[0])
    pca = PCA(n_components=2)
    scatter_by_direction(ax_pca, pca.fit_transform(embeddings))
    evr2 = pca.explained_variance_ratio_
    ax_pca.set_title(f"A — PCA\n({evr2.sum()*100:.1f}% variance)", fontsize=10)
    ax_pca.set_xlabel(f"PC1 {evr2[0]*100:.1f}%", fontsize=9)
    ax_pca.set_ylabel(f"PC2 {evr2[1]*100:.1f}%", fontsize=9)
    ax_pca.legend(fontsize=9); ax_pca.grid(True, alpha=0.3)

    # B: t-SNE
    ax_tsne = fig.add_subplot(gs[1])
    if n_clips >= 6:
        # scikit-learn requires perplexity < n_samples.
        perp = min(5, n_clips - 1)
        tsne = TSNE(n_components=2, perplexity=perp,
                    random_state=42, max_iter=1000)
        scatter_by_direction(ax_tsne, tsne.fit_transform(embeddings))
        ax_tsne.set_title("B — t-SNE (non-linear local structure)", fontsize=10)
        ax_tsne.legend(fontsize=9)
    else:
        ax_tsne.text(0.5, 0.5, f"Need ≥ 6 clips\n(have {n_clips})",
                     ha="center", va="center", transform=ax_tsne.transAxes,
                     color="#ff5252", fontsize=12)
        ax_tsne.set_title("B — t-SNE", fontsize=10)
    ax_tsne.grid(True, alpha=0.3)

    # C: explained variance. PCA accepts at most min(n_samples, n_features)
    # components, so the embedding dimension bounds n_comp as well.
    ax_var = fig.add_subplot(gs[2])
    n_comp = min(15, n_clips - 1, n_dims)
    pca_f = PCA(n_components=n_comp)
    pca_f.fit(embeddings)
    evr = pca_f.explained_variance_ratio_
    cum = np.cumsum(evr)
    comps = np.arange(1, len(evr) + 1)
    ax_var.bar(comps, evr * 100, color="#e040fb", alpha=0.8, label="Individual")
    ax_var.plot(comps, cum * 100, color="#ffeb3b", lw=2,
                marker="o", ms=4, label="Cumulative")
    ax_var.axhline(90, color="#ff5252", lw=1.5, ls="--",
                   alpha=0.7, label="90% line")
    ax_var.set_xlabel("Principal component", fontsize=9)
    ax_var.set_ylabel("Explained variance (%)", fontsize=9)
    ax_var.set_title("C — PCA explained variance\n"
                     "(few PCs to reach 90% = compact embedding)", fontsize=10)
    ax_var.legend(fontsize=8); ax_var.grid(True, alpha=0.3)
    save_fig(fig, "07_embedding_space")
# ═══════════════════════════════════════════════════════════════════
# PLOT 8 — Score Distribution per Query
# ═══════════════════════════════════════════════════════════════════
def plot_08_score_distribution(db):
    """Bar-chart every clip's score for each canonical query, best first.

    For each of the "left", "straight" and "right" queries
    (``CANONICAL[direction]``), every clip is scored with ``score_clip``.
    Clips are sorted by descending score and drawn as bars coloured by the
    clip's stored direction. The top three ranks are annotated.

    Args:
        db: mapping of clip id -> entry dict providing ``embedding``,
            ``direction`` and ``clip_id``.
    """
    print("[8/10] Score distribution per query...")
    entries = list(db.values())
    fig, axes = plt.subplots(1, 3, figsize=(17, 5))
    fig.suptitle("Plot 8 — Score Distribution for Each Button Query\n"
                 "Sharp drop after top results = clear winner. Flat = poor discrimination.",
                 fontsize=12, fontweight="bold", color="white")
    for ax, direction in zip(axes, ["left", "straight", "right"]):
        query_emb = CANONICAL[direction]
        col = DIR_COLORS[direction]
        # (score, stored direction, short clip id), highest score first;
        # the sort is stable, so ties keep database order.
        rows = sorted(
            ((score_clip(query_emb, e["embedding"]),
              e["direction"],
              e["clip_id"].split("__")[-1]) for e in entries),
            key=lambda r: r[0], reverse=True,
        )
        scores = [r[0] for r in rows]
        bar_cols = [DIR_COLORS[r[1]] for r in rows]
        cids = [r[2] for r in rows]
        x = np.arange(len(scores))
        ax.bar(x, scores, color=bar_cols, alpha=0.85, width=0.7)
        ax.axhline(0.7, color="#ffeb3b", lw=1.5, ls="--",
                   alpha=0.8, label="0.70 reference")
        ax.set_xlabel("Clip rank", fontsize=9)
        ax.set_ylabel("Combined score", fontsize=9)
        ax.set_title(f"Query: {direction.upper()}\n"
                     f"(bar colour = actual clip direction)", fontsize=10, color=col)
        ax.set_ylim(0, 1.05)
        ax.legend(fontsize=8); ax.grid(True, alpha=0.3, axis="y")
        # Label every other bar to keep the clip ids readable.
        ax.set_xticks(x[::2])
        ax.set_xticklabels(cids[::2], rotation=45, fontsize=6, ha="right")
        for rank in range(min(3, len(scores))):
            ax.text(x[rank], scores[rank] + 0.01,
                    f"#{rank+1}", ha="center", fontsize=7, color="white")
    patches = [mpatches.Patch(color=v, label=k) for k, v in DIR_COLORS.items()]
    fig.legend(handles=patches, loc="lower center", ncol=3,
               fontsize=9, bbox_to_anchor=(0.5, -0.04))
    plt.tight_layout()
    save_fig(fig, "08_score_distribution")
# ═══════════════════════════════════════════════════════════════════
# PLOT 9 — DTW vs Cosine Scatter
# ═══════════════════════════════════════════════════════════════════
def plot_09_dtw_vs_cosine(db):
    """Scatter cosine similarity against DTW similarity for each query.

    For each canonical query, every clip gets two values:
    * cosine similarity of the full embeddings, mapped to [0, 1] as
      (1 - cosine_distance + 1) / 2;
    * DTW similarity 1 / (1 + d), where d is the DTW distance between the
      first N_RESAMPLE values of the two embeddings.
    Dotted iso-lines show where w_cos*cos + w_dtw*dtw equals 0.5 / 0.7 / 0.8.

    Args:
        db: mapping of clip id -> entry dict providing ``embedding`` and
            ``direction``.
    """
    print("[9/10] DTW vs cosine scatter...")
    entries = list(db.values())
    # NOTE(review): assumed to be the weights score_clip combines
    # (score = 0.6 * cosine + 0.4 * DTW) — confirm in retrieval_engine.
    w_cos, w_dtw = 0.6, 0.4
    cs_grid = np.linspace(0, 1, 200)
    fig, axes = plt.subplots(1, 3, figsize=(17, 5))
    fig.suptitle("Plot 9 — DTW Similarity vs Cosine Similarity (per query)\n"
                 "Top-right quadrant = high on both = best matches",
                 fontsize=12, fontweight="bold", color="white")
    for ax, direction in zip(axes, ["left", "straight", "right"]):
        query_emb = CANONICAL[direction]
        query_vec = query_emb.astype(np.float64)            # hoisted out of the loop
        query_traj = query_emb[:N_RESAMPLE].astype(np.double)
        cos_sims, dtw_sims, clip_dirs = [], [], []
        for e in entries:
            emb = e["embedding"]
            # scipy's cosine() is a distance; 1 - dist lies in [-1, 1].
            raw = 1 - scipy_cosine(query_vec, emb.astype(np.float64))
            cos_sims.append((raw + 1) / 2)
            dist = dtw_lib.distance_fast(query_traj, emb[:N_RESAMPLE].astype(np.double))
            dtw_sims.append(1 / (1 + dist))
            clip_dirs.append(e["direction"])
        for d, col in DIR_COLORS.items():
            sel = [i for i, dr in enumerate(clip_dirs) if dr == d]
            ax.scatter([cos_sims[i] for i in sel], [dtw_sims[i] for i in sel],
                       color=col, s=80, alpha=0.85,
                       marker=DIR_MARKERS[d], label=d, zorder=3)
        # Iso-score lines: score = w_cos*cos + w_dtw*dtw → dtw = (score - w_cos*cos)/w_dtw
        for target in [0.5, 0.7, 0.8]:
            dtw_iso = (target - w_cos * cs_grid) / w_dtw
            valid = (dtw_iso >= 0) & (dtw_iso <= 1)
            ax.plot(cs_grid[valid], dtw_iso[valid],
                    color="#555577", lw=1, ls=":", alpha=0.8)
            if valid.any():
                ax.text(cs_grid[valid][-1], dtw_iso[valid][-1], f"score={target}",
                        fontsize=6, color="#888899")
        ax.set_xlabel(f"Cosine similarity (α={w_cos})", fontsize=9)
        ax.set_ylabel(f"DTW similarity (β={w_dtw})", fontsize=9)
        ax.set_title(f"Query: {direction.upper()}", fontsize=10,
                     color=DIR_COLORS[direction])
        ax.set_xlim(-0.05, 1.05)
        ax.set_ylim(-0.05, 1.05)
        ax.legend(fontsize=8); ax.grid(True, alpha=0.3)
        ax.text(0.84, 0.93, "BEST", fontsize=8, color="#00e676",
                ha="center", transform=ax.transAxes, alpha=0.7)
        ax.text(0.14, 0.07, "WORST", fontsize=8, color="#ff5252",
                ha="center", transform=ax.transAxes, alpha=0.7)
    plt.tight_layout()
    save_fig(fig, "09_dtw_vs_cosine")
# ═══════════════════════════════════════════════════════════════════
# PLOT 10 — Confusion Matrix
# ═══════════════════════════════════════════════════════════════════
def plot_10_confusion_matrix(db, top_k=5):
    """Tally the directions of the top_k retrieved clips for each query.

    For each canonical query (left / straight / right), every clip is scored
    with ``score_clip``. The top_k highest-scoring clips are counted by
    their stored direction. The result is drawn as a count heatmap (A) and
    as precision@top_k per query direction (B).

    Args:
        db: mapping of clip id -> entry dict providing ``embedding`` and
            ``direction``.
        top_k: number of top-ranked clips counted per query (default 5).
    """
    print("[10/10] Confusion matrix...")
    entries = list(db.values())
    dir_list = ["left", "straight", "right"]
    dir_index = {d: i for i, d in enumerate(dir_list)}
    n = len(dir_list)
    confusion = np.zeros((n, n), dtype=int)
    for qi, q_dir in enumerate(dir_list):
        query_emb = CANONICAL[q_dir]
        scored = sorted(
            ((score_clip(query_emb, e["embedding"]), e["direction"]) for e in entries),
            key=lambda x: x[0], reverse=True,
        )
        for _, ret_dir in scored[:top_k]:
            # Clips labelled with a direction outside dir_list are not counted
            # (previously list.index raised ValueError on them).
            ri = dir_index.get(ret_dir)
            if ri is not None:
                confusion[qi, ri] += 1

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle(f"Plot 10 — Direction Confusion Matrix (top-{top_k} retrieval)\n"
                 "Diagonal = correct. Off-diagonal = wrong direction returned.",
                 fontsize=12, fontweight="bold", color="white")

    # A: heatmap
    ax = axes[0]
    im = ax.imshow(confusion, cmap="YlOrRd", vmin=0, vmax=top_k)
    ax.set_xticks(range(n)); ax.set_yticks(range(n))
    ax.set_xticklabels(dir_list, fontsize=11)
    ax.set_yticklabels(dir_list, fontsize=11)
    ax.set_xlabel(f"Retrieved direction (top-{top_k})", fontsize=10)
    ax.set_ylabel("Query direction", fontsize=10)
    ax.set_title("A — Raw counts", fontsize=10)
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    for i in range(n):
        for j in range(n):
            count = confusion[i, j]
            # Dark text on the bright (high-count) cells, light text elsewhere.
            ax.text(j, i, str(count),
                    ha="center", va="center", fontsize=16, fontweight="bold",
                    color="black" if count > top_k * 0.5 else "white")

    # B: precision bar chart; max(..., 1) avoids 0/0 on an empty row.
    ax = axes[1]
    precisions = [confusion[i, i] / max(confusion[i].sum(), 1) * 100
                  for i in range(n)]
    cols = [DIR_COLORS[d] for d in dir_list]
    bars = ax.bar(dir_list, precisions, color=cols, alpha=0.85, width=0.5)
    ax.axhline(100, color="#00e676", lw=1.5, ls="--", alpha=0.6, label="100%")
    ax.axhline(60, color="#ffeb3b", lw=1.5, ls="--", alpha=0.6, label="60% threshold")
    for bar, val in zip(bars, precisions):
        ax.text(bar.get_x() + bar.get_width() / 2,
                bar.get_height() + 1.5,
                f"{val:.0f}%", ha="center", fontsize=13,
                fontweight="bold", color="white")
    ax.set_ylim(0, 115)
    ax.set_ylabel(f"Precision@{top_k} (%)", fontsize=10)
    ax.set_title(f"B — Precision@{top_k} per direction\n"
                 f"(what % of top-{top_k} are the correct direction?)", fontsize=10)
    ax.legend(fontsize=9); ax.grid(True, alpha=0.3, axis="y")
    plt.tight_layout()
    save_fig(fig, "10_confusion_matrix")
# ═══════════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    import traceback

    print("=" * 60)
    print(" Trajectory Retrieval — Diagnostic Visualizations")
    print("=" * 60)
    if not os.path.exists(PKL_PATH):
        print(f"\nERROR: PKL not found at '{PKL_PATH}'")
        print("Run video_processor.py first to generate the database.")
        raise SystemExit(1)
    db = load_db(PKL_PATH)
    # Same existence check the plots use, via the shared helper.
    accessible = sum(1 for e in db.values() if video_exists(e))
    print(f"Clips total : {len(db)}")
    print(f"With accessible video : {accessible}")
    print(f"Plots output folder : {OUT_DIR}/\n")
    plot_steps = (
        # Plots 1–3 need video files on disk
        plot_01_feature_detection,
        plot_02_flow_field,
        plot_03_foe_stability,
        # Plots 4–10 only need the PKL
        plot_04_lateral_profiles,
        plot_05_trajectory_gallery,
        plot_06_turn_ratio_distribution,
        plot_07_embedding_space,
        plot_08_score_distribution,
        plot_09_dtw_vs_cosine,
        plot_10_confusion_matrix,
    )
    failed = []
    for plot_fn in plot_steps:
        # Top-level boundary: one broken diagnostic should not prevent the
        # remaining plots from being produced. The traceback is still shown,
        # and any half-built figures are released.
        try:
            plot_fn(db)
        except Exception:
            traceback.print_exc()
            plt.close("all")
            failed.append(plot_fn.__name__)
    print("\n" + "=" * 60)
    if failed:
        print(f" {len(failed)} plot(s) failed: {', '.join(failed)}")
        print(f" Remaining plots saved to → {OUT_DIR}/")
        print("=" * 60)
        raise SystemExit(1)
    print(f" All 10 plots saved to → {OUT_DIR}/")
    print("=" * 60)