"""
retrieval_engine.py
===================
Retrieval system: matches a query trajectory embedding against the database.

Scoring formula:
    score = α * cosine_similarity(query, stored)
          + β * (1 / (1 + DTW_distance(query_50, stored_50)))

    α = 0.6  (shape direction)
    β = 0.4  (timing of turn)

Canonical embeddings for button queries:
    straight → flat zero line
    left     → monotonically increasing sigmoid
    right    → monotonically decreasing sigmoid
"""

import pickle
import numpy as np
from typing import Dict, List, Tuple, Optional
from scipy.spatial.distance import cosine
from dtaidistance import dtw

from trajectory_extractor import N_RESAMPLE, EMBEDDING_DIM

# ─────────────────────────────────────────────────────────────
# Scoring weights
# ─────────────────────────────────────────────────────────────

# Default weights for the combined score. They match the formula in the
# module docstring and are the defaults of every scoring/retrieval function
# below, so there is a single source of truth for them.
ALPHA = 0.6   # cosine similarity weight
BETA = 0.4    # DTW similarity weight


# ─────────────────────────────────────────────────────────────
# Canonical Query Embeddings
# ─────────────────────────────────────────────────────────────

def _make_sigmoid(n: int, direction: str, sharpness: float = 6.0) -> np.ndarray:
    """
    Create a smooth sigmoid-shaped trajectory of length n.

    The curve is z-scored (zero mean, unit std) before being returned.

    direction='right' → falling (negative lateral drift)
    any other value   → rising  (positive lateral drift), used for 'left'

    Args:
        n         : number of samples in the curve
        direction : 'left' or 'right'
        sharpness : width of the sampled sigmoid input range; the curve is
                    evaluated on [-sharpness / 2, +sharpness / 2]

    Returns:
        float32 array of shape (n,)
    """
    x = np.linspace(-sharpness / 2, sharpness / 2, n)
    sig = 1.0 / (1.0 + np.exp(-x))
    # Small epsilon keeps the division finite for a degenerate constant curve.
    sig = (sig - sig.mean()) / (sig.std() + 1e-9)
    if direction == "right":
        sig = -sig
    return sig.astype(np.float32)


def _make_canonical_embedding(direction: str) -> np.ndarray:
    """
    Build a full canonical embedding for a button query.

    Layout (with N_RESAMPLE = 50 this is the 53-dim embedding):
        [0:N_RESAMPLE]   normalized trajectory curve
        [N_RESAMPLE]     turn_ratio
        [N_RESAMPLE + 1] peak_lateral
        [N_RESAMPLE + 2] direction_sign

    Args:
        direction : 'left', 'right' or 'straight'

    Returns:
        float32 embedding vector

    Raises:
        KeyError : if direction is not one of the three known labels
    """
    dir_sign_map = {"left": 1.0, "straight": 0.0, "right": -1.0}
    # Look the sign up first so an unknown label fails immediately instead of
    # silently building a "right" curve before the lookup raises.
    dir_sign = dir_sign_map[direction]

    if direction == "straight":
        traj_50 = np.zeros(N_RESAMPLE, dtype=np.float32)
        turn_ratio = 0.0
        peak = 0.0
    elif direction == "left":
        traj_50 = _make_sigmoid(N_RESAMPLE, "left")
        turn_ratio = 0.5
        peak = 1.0
    else:  # right
        traj_50 = _make_sigmoid(N_RESAMPLE, "right")
        turn_ratio = -0.5
        peak = 1.0

    emb = np.concatenate([
        traj_50,
        [turn_ratio, peak, dir_sign],
    ]).astype(np.float32)
    return emb


# Pre-computed canonical embeddings
CANONICAL = {
    "left": _make_canonical_embedding("left"),
    "right": _make_canonical_embedding("right"),
    "straight": _make_canonical_embedding("straight"),
}


# ─────────────────────────────────────────────────────────────
# Similarity Functions
# ─────────────────────────────────────────────────────────────

def cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
    """
    Cosine similarity in [-1, 1], rescaled to [0, 1].

    1.0 = identical direction, 0.5 = orthogonal, 0.0 = opposite direction.
    If either vector is (numerically) zero the neutral value 0.5 is returned,
    because the cosine is undefined there.
    """
    a = a.astype(np.float64)
    b = b.astype(np.float64)
    # Avoid zero-vector edge case
    if np.linalg.norm(a) < 1e-9 or np.linalg.norm(b) < 1e-9:
        return 0.5  # neutral score
    raw = 1.0 - cosine(a, b)           # scipy returns a distance; raw ∈ [-1, 1]
    return float((raw + 1.0) / 2.0)    # rescale to [0, 1]


def dtw_sim(a_traj: np.ndarray, b_traj: np.ndarray) -> float:
    """
    DTW distance between two trajectory curves, converted to a similarity
    score in (0, 1].

    DTW handles clips where the turn happens at different times
    (e.g., early turn vs late turn both match a "left" query).

    Args:
        a_traj, b_traj : 1-D trajectory curves (the first N_RESAMPLE
                         embedding dims when called from score_clip)

    Returns:
        1 / (1 + DTW distance); 1.0 means identical curves.
    """
    # distance_fast is the C implementation and is fed double-precision arrays.
    a = a_traj.astype(np.double)
    b = b_traj.astype(np.double)
    dist = dtw.distance_fast(a, b)
    # Convert distance to similarity: sim = 1 / (1 + dist)
    return float(1.0 / (1.0 + dist))


def score_clip(
    query_emb: np.ndarray,
    clip_emb: np.ndarray,
    alpha: float = ALPHA,
    beta: float = BETA,
) -> float:
    """
    Combined score for a single database clip against the query.

        score = alpha * cosine_sim(full_emb) + beta * dtw_sim(traj_only)

    The cosine term is computed on the full embedding (curve plus the three
    scalar features); the DTW term only on the first N_RESAMPLE dims, i.e.
    the trajectory curve.

    Args:
        query_emb : query embedding
        clip_emb  : stored clip embedding with the same layout
        alpha     : weight of the cosine term
        beta      : weight of the DTW term
    """
    cos = cosine_sim(query_emb, clip_emb)
    dtw_s = dtw_sim(query_emb[:N_RESAMPLE], clip_emb[:N_RESAMPLE])
    return alpha * cos + beta * dtw_s


# ─────────────────────────────────────────────────────────────
# Database Loading
# ─────────────────────────────────────────────────────────────

def load_database(pkl_path: str) -> Dict:
    """
    Load trajectory database from pickle file.

    SECURITY: pickle.load executes arbitrary code embedded in the file.
    Only call this on database files you produced yourself or otherwise trust.
    """
    with open(pkl_path, "rb") as f:
        db = pickle.load(f)
    print(f"Loaded database: {len(db)} clips from {pkl_path}")
    return db


# ─────────────────────────────────────────────────────────────
# Main Retrieval Function
# ─────────────────────────────────────────────────────────────

def retrieve_top_k(
    query_embedding: np.ndarray,
    database: Dict,
    top_k: int = 5,
    direction_filter: Optional[str] = None,
    alpha: float = ALPHA,
    beta: float = BETA,
) -> List[Dict]:
    """
    Score all database clips against query_embedding.
    Returns top_k results sorted by score descending.

    Args:
        query_embedding  : 53-dim query vector
        database         : loaded pkl dict, mapping clip_id → entry dict
        top_k            : number of results to return
        direction_filter : if set ('left'/'right'/'straight'), only score
                           clips with that label (for button queries —
                           faster + cleaner)
        alpha            : weight of the cosine term (see score_clip)
        beta             : weight of the DTW term (see score_clip)

    Returns:
        List of result dicts (at most top_k), each carrying the rounded
        score plus the clip's metadata copied from the database entry.

    Raises:
        KeyError : if a scored entry lacks one of the fields read below
    """
    results = []

    for clip_id, entry in database.items():
        # Optional pre-filter by direction label
        if direction_filter is not None and entry["direction"] != direction_filter:
            continue

        s = score_clip(query_embedding, entry["embedding"], alpha, beta)

        results.append({
            "clip_id": clip_id,
            "score": round(s, 4),
            "direction": entry["direction"],
            "turn_ratio": entry["turn_ratio"],
            "video_path": entry["video_path"],
            "video_name": entry["video_name"],
            "start_sec": entry["start_sec"],
            "end_sec": entry["end_sec"],
            "start_frame": entry["start_frame"],
            "end_frame": entry["end_frame"],
            "trajectory_raw": entry["trajectory_raw"],
        })

    # Sort by score descending (stable: ties keep database iteration order)
    results.sort(key=lambda x: x["score"], reverse=True)
    return results[:top_k]


# ─────────────────────────────────────────────────────────────
# Button Query Shortcut
# ─────────────────────────────────────────────────────────────

def query_by_button(
    direction: str,
    database: Dict,
    top_k: int = 5,
    alpha: float = ALPHA,
    beta: float = BETA,
) -> List[Dict]:
    """
    Retrieve top clips for a button click (Left / Right / Straight).

    Uses the canonical embedding for the direction AND pre-filters
    by direction label for speed.

    Args:
        direction : 'left', 'right' or 'straight'
        database  : loaded pkl dict
        top_k     : number of results to return
        alpha     : weight of the cosine term
        beta      : weight of the DTW term
    """
    # NOTE: asserts are stripped under `python -O`; an unknown direction then
    # surfaces as a KeyError from the CANONICAL lookup below instead.
    assert direction in ("left", "right", "straight"), \
        f"direction must be left/right/straight, got: {direction}"

    query_emb = CANONICAL[direction]

    # Use direction_filter for button queries: only score matching clips
    return retrieve_top_k(
        query_embedding=query_emb,
        database=database,
        top_k=top_k,
        direction_filter=direction,
        alpha=alpha,
        beta=beta,
    )


# ─────────────────────────────────────────────────────────────
# Sketch Query
# ─────────────────────────────────────────────────────────────

# Horizontal sketch displacement (in canvas pixels) that maps to |turn_ratio| = 1.
_SKETCH_PIXEL_SCALE = 100.0
# Below this |turn_ratio| the sketch is treated as straight (direction_sign = 0).
_SKETCH_STRAIGHT_THRESHOLD = 0.08
# At or above this |turn_ratio| the sketch is considered unambiguous enough
# to pre-filter the database by direction label.
_SKETCH_FILTER_THRESHOLD = 0.60


def query_by_sketch(
    sketch_points: List[Tuple[float, float]],
    database: Dict,
    top_k: int = 5,
    alpha: float = ALPHA,
    beta: float = BETA,
) -> List[Dict]:
    """
    Retrieve top clips matching a user-drawn sketch.

    sketch_points: list of (x, y) pixel coordinates from the canvas
                   x goes left→right, y goes top→bottom

    Only the x-coordinates are used: the horizontal deviation of the sketch
    from its starting point is interpreted as the lateral signal.

    NOTE(review): the lateral signal is NOT sign-flipped. A sketch that
    drifts toward +x (canvas right) therefore produces a positive
    turn_ratio, direction_sign = +1 and — when the direction pre-filter
    kicks in — the 'left' label. Whether that matches the sign convention
    of the stored embeddings depends on trajectory_extractor; confirm
    before changing it, since flipping the sign changes every sketch result.

    Steps:
      1. Extract the x-coordinates of the sketch
      2. Resample to N_RESAMPLE points
      3. Interpret as a lateral trajectory (deviation from the first point)
      4. Normalize (z-score; an essentially flat sketch becomes all zeros)
      5. Build the full embedding (curve + turn_ratio, peak, direction_sign)
      6. Run retrieval, pre-filtered by direction only for strong turns

    Raises:
        ValueError : if fewer than 2 sketch points are given
    """
    if len(sketch_points) < 2:
        raise ValueError("Need at least 2 sketch points")

    pts = np.array(sketch_points, dtype=np.float32)
    x_coords = pts[:, 0]

    # Resample x to N_RESAMPLE points by linear interpolation over the
    # point index (not over arc length or y).
    orig_idx = np.linspace(0, len(x_coords) - 1, len(x_coords))
    target_idx = np.linspace(0, len(x_coords) - 1, N_RESAMPLE)
    x_resampled = np.interp(target_idx, orig_idx, x_coords)

    # Convert to lateral signal: deviation from starting x
    lateral_curve = x_resampled - x_resampled[0]

    # Normalize
    std = lateral_curve.std()
    if std > 1e-6:
        lateral_curve_norm = (lateral_curve - lateral_curve.mean()) / std
    else:
        lateral_curve_norm = np.zeros(N_RESAMPLE, dtype=np.float32)

    # Scalar features
    total_lateral = float(lateral_curve[-1] - lateral_curve[0])
    turn_ratio = float(np.clip(total_lateral / _SKETCH_PIXEL_SCALE, -1.0, 1.0))
    peak_lateral = float(np.max(np.abs(lateral_curve)))

    if abs(turn_ratio) < _SKETCH_STRAIGHT_THRESHOLD:
        dir_sign = 0.0
    elif turn_ratio > 0:
        dir_sign = 1.0
    else:
        dir_sign = -1.0

    query_emb = np.concatenate([
        lateral_curve_norm.astype(np.float32),
        [turn_ratio, peak_lateral / _SKETCH_PIXEL_SCALE, dir_sign],
    ]).astype(np.float32)

    # Infer direction from sketch for pre-filtering
    if abs(turn_ratio) >= _SKETCH_FILTER_THRESHOLD:
        inferred_dir = "left" if turn_ratio > 0 else "right"
    else:
        inferred_dir = None  # don't filter if ambiguous

    return retrieve_top_k(
        query_embedding=query_emb,
        database=database,
        top_k=top_k,
        direction_filter=inferred_dir,
        alpha=alpha,
        beta=beta,
    )