Spaces:
Sleeping
Sleeping
| """ | |
retrieval_engine.py
===================
Retrieval system: matches a query trajectory embedding against the database.

Scoring formula:
    score = α * cosine_similarity(query, stored)
          + β * (1 / (1 + DTW_distance(query_50, stored_50)))

    α = 0.6  (overall shape and direction; cosine over the full embedding)
    β = 0.4  (timing of the turn; DTW over the 50-point trajectory curve)

Canonical embeddings for button queries:
    straight → flat zero line
    left     → monotonically increasing sigmoid
    right    → monotonically decreasing sigmoid
| """ | |
| import pickle | |
| import numpy as np | |
| from typing import Dict, List, Tuple, Optional | |
| from scipy.spatial.distance import cosine | |
| from dtaidistance import dtw | |
| from trajectory_extractor import N_RESAMPLE, EMBEDDING_DIM | |
# ─────────────────────────────────────────────────────────────
# Scoring weights
# ─────────────────────────────────────────────────────────────
# Default weights for: score = ALPHA * cosine_sim + BETA * dtw_sim.
# Kept equal to the alpha=0.6 / beta=0.4 defaults used by score_clip,
# retrieve_top_k, query_by_button and query_by_sketch, and to the module
# docstring (previously these were swapped to 0.4 / 0.6).
ALPHA = 0.6  # cosine similarity weight (full embedding: shape + direction)
BETA = 0.4   # DTW similarity weight (trajectory curve: timing of the turn)
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Canonical Query Embeddings | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _make_sigmoid(n: int, direction: str, sharpness: float = 6.0) -> np.ndarray: | |
| """ | |
| Create a smooth sigmoid-shaped trajectory of length n. | |
| direction='left' β rising (positive lateral drift) | |
| direction='right' β falling (negative lateral drift) | |
| """ | |
| x = np.linspace(-sharpness / 2, sharpness / 2, n) | |
| sig = 1.0 / (1.0 + np.exp(-x)) | |
| sig = (sig - sig.mean()) / (sig.std() + 1e-9) | |
| if direction == "right": | |
| sig = -sig | |
| return sig.astype(np.float32) | |
def _make_canonical_embedding(direction: str) -> np.ndarray:
    """
    Build the canonical query embedding for a button direction.

    Layout (N_RESAMPLE + 3 float32 values; 53 if N_RESAMPLE is 50):
        [0 : N_RESAMPLE]   normalized trajectory curve
        [N_RESAMPLE]       turn_ratio
        [N_RESAMPLE + 1]   peak_lateral
        [N_RESAMPLE + 2]   direction_sign

    'straight' -> all-zero curve, turn_ratio 0.0, peak 0.0, sign 0.0
    'left'     -> rising sigmoid,  turn_ratio +0.5, peak 1.0, sign +1.0
    'right'    -> falling sigmoid, turn_ratio -0.5, peak 1.0, sign -1.0

    Any other string is built like 'right' and then raises KeyError on the
    direction-sign lookup.
    """
    sign_by_direction = {"left": 1.0, "straight": 0.0, "right": -1.0}

    if direction == "straight":
        curve = np.zeros(N_RESAMPLE, dtype=np.float32)
        ratio, peak = 0.0, 0.0
    else:
        is_left = direction == "left"
        curve = _make_sigmoid(N_RESAMPLE, "left" if is_left else "right")
        ratio = 0.5 if is_left else -0.5
        peak = 1.0

    tail = [ratio, peak, sign_by_direction[direction]]
    return np.concatenate([curve, tail]).astype(np.float32)
# Canonical button-query embeddings, built once at import time.
# Keys (and construction order): "left", "right", "straight".
CANONICAL = {
    name: _make_canonical_embedding(name)
    for name in ("left", "right", "straight")
}
# ─────────────────────────────────────────────────────────────
# Similarity Functions
# ─────────────────────────────────────────────────────────────
def cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
    """
    Cosine similarity rescaled from [-1, 1] to [0, 1].

    1.0 = same direction, 0.5 = orthogonal, 0.0 = exactly opposite.
    Anything below 0.5 points more away from the other vector than towards it.

    Args:
        a, b: vectors of equal length (ndarrays or array-likes).

    Returns:
        The rescaled similarity as a Python float. If either vector has a
        (near-)zero norm the angle is undefined, so the neutral score 0.5 is
        returned.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)

    if np.linalg.norm(a) < 1e-9 or np.linalg.norm(b) < 1e-9:
        return 0.5  # neutral: neither similar nor dissimilar

    # scipy's cosine() is a distance (1 - cos θ). Clamp the recovered cos θ so
    # floating-point rounding can never push the result outside [0, 1].
    raw = float(np.clip(1.0 - cosine(a, b), -1.0, 1.0))
    return (raw + 1.0) / 2.0
def dtw_sim(a_traj: np.ndarray, b_traj: np.ndarray) -> float:
    """
    Similarity between two trajectory curves based on their DTW distance.

    Both curves are cast to float64 and compared with dtaidistance's
    ``dtw.distance_fast``. The distance d is mapped to 1 / (1 + d), so d = 0
    gives 1.0 and the score decreases towards 0 as d grows.

    Because DTW warps the time axis, clips whose turn happens at different
    times (an early turn and a late turn) can both match a "left" query.
    """
    first, second = (traj.astype(np.double) for traj in (a_traj, b_traj))
    distance = dtw.distance_fast(first, second)
    similarity = 1.0 / (1.0 + distance)
    return float(similarity)
def score_clip(
    query_emb: np.ndarray,
    clip_emb: np.ndarray,
    alpha: float = 0.6,
    beta: float = 0.4,
) -> float:
    """
    Combined similarity of one database clip to the query.

        score = alpha * cosine_sim(query_emb, clip_emb)
              + beta  * dtw_sim(query_emb[:N_RESAMPLE], clip_emb[:N_RESAMPLE])

    The cosine term uses the full embedding, so the trailing scalar features
    (turn ratio, peak, direction sign) influence it. The DTW term compares only
    the leading N_RESAMPLE-point trajectory curve, which tolerates differences
    in when the turn happens.

    Args:
        query_emb: query embedding (trajectory curve followed by scalar features).
        clip_emb: stored clip embedding with the same layout.
        alpha: weight of the cosine term.
        beta: weight of the DTW term.

    Returns:
        The weighted sum. Both component scores lie in [0, 1], so for
        non-negative weights the result lies in [0, alpha + beta].
    """
    cos_score = cosine_sim(query_emb, clip_emb)
    dtw_score = dtw_sim(query_emb[:N_RESAMPLE], clip_emb[:N_RESAMPLE])
    return alpha * cos_score + beta * dtw_score
# ─────────────────────────────────────────────────────────────
# Database Loading
# ─────────────────────────────────────────────────────────────
def load_database(pkl_path: str) -> Dict:
    """
    Unpickle the trajectory database stored at ``pkl_path``.

    Prints how many clips were loaded and returns the unpickled object
    unchanged. It is presumably the clip_id -> entry mapping that
    retrieve_top_k iterates over; confirm against the database builder.

    Security: unpickling can execute arbitrary code, so only load database
    files you built yourself or otherwise trust.
    """
    with open(pkl_path, mode="rb") as stream:
        database = pickle.load(stream)
    clip_count = len(database)
    print(f"Loaded database: {clip_count} clips from {pkl_path}")
    return database
# ─────────────────────────────────────────────────────────────
# Main Retrieval Function
# ─────────────────────────────────────────────────────────────
def retrieve_top_k(
    query_embedding: np.ndarray,
    database: Dict,
    top_k: int = 5,
    direction_filter: Optional[str] = None,
    alpha: float = 0.6,
    beta: float = 0.4,
) -> List[Dict]:
    """
    Score database clips against ``query_embedding`` and return the best ones.

    Args:
        query_embedding: query vector in the same layout as the stored
            embeddings.
        database: mapping clip_id -> entry dict (as loaded by load_database).
            Each scored entry must provide "embedding" plus every key copied
            into the result records below.
        top_k: maximum number of results. 0 or a negative value returns [];
            None returns every scored clip.
        direction_filter: if set ('left' / 'right' / 'straight'), only clips
            whose "direction" label equals it are scored. Button queries use
            this for speed and cleaner results.
        alpha: weight of the cosine term passed to score_clip.
        beta: weight of the DTW term passed to score_clip.

    Returns:
        Up to ``top_k`` result dicts, highest score first. The "score" field is
        rounded to 4 decimals for display, but ranking uses the unrounded score
        so near-ties are still ordered correctly. Exact ties keep database
        iteration order (list.sort is stable).
    """
    # Slicing with a negative top_k would drop clips from the END of the
    # ranking (e.g. [:-1] returns all but the worst); treat it as "no results".
    if top_k is not None and top_k <= 0:
        return []

    scored = []
    for clip_id, entry in database.items():
        # Optional pre-filter by direction label.
        if direction_filter is not None and entry["direction"] != direction_filter:
            continue

        score = score_clip(query_embedding, entry["embedding"], alpha, beta)
        record = {
            "clip_id": clip_id,
            "score": round(score, 4),
            "direction": entry["direction"],
            "turn_ratio": entry["turn_ratio"],
            "video_path": entry["video_path"],
            "video_name": entry["video_name"],
            "start_sec": entry["start_sec"],
            "end_sec": entry["end_sec"],
            "start_frame": entry["start_frame"],
            "end_frame": entry["end_frame"],
            "trajectory_raw": entry["trajectory_raw"],
        }
        scored.append((score, record))

    # Rank by the raw (unrounded) score, descending.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [record for _, record in scored[:top_k]]
# ─────────────────────────────────────────────────────────────
# Button Query Shortcut
# ─────────────────────────────────────────────────────────────
def query_by_button(
    direction: str,
    database: Dict,
    top_k: int = 5,
    alpha: float = 0.6,
    beta: float = 0.4,
) -> List[Dict]:
    """
    Retrieve the top clips for a button click (Left / Right / Straight).

    The precomputed canonical embedding for ``direction`` is used as the query,
    and the search is pre-filtered to clips with the same direction label.

    Args:
        direction: 'left', 'right' or 'straight'.
        database: mapping clip_id -> entry dict (as loaded by load_database).
        top_k: maximum number of results.
        alpha: weight of the cosine term.
        beta: weight of the DTW term.

    Returns:
        The result list from retrieve_top_k.

    Raises:
        ValueError: if ``direction`` is not one of the three supported values.
    """
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # which would turn bad input into an opaque KeyError on CANONICAL.
    if direction not in ("left", "right", "straight"):
        raise ValueError(
            f"direction must be left/right/straight, got: {direction}"
        )

    return retrieve_top_k(
        query_embedding=CANONICAL[direction],
        database=database,
        top_k=top_k,
        direction_filter=direction,
        alpha=alpha,
        beta=beta,
    )
# ─────────────────────────────────────────────────────────────
# Sketch Query
# ─────────────────────────────────────────────────────────────
def query_by_sketch(
    sketch_points: List[Tuple[float, float]],
    database: Dict,
    top_k: int = 5,
    alpha: float = 0.6,
    beta: float = 0.4,
) -> List[Dict]:
    """
    Retrieve the top clips matching a user-drawn sketch.

    Args:
        sketch_points: (x, y) canvas pixel coordinates in drawing order.
            x grows left→right and y grows top→bottom. Only x is used: the
            sketch's horizontal deviation from its first point is treated as
            the lateral trajectory signal.
        database: mapping clip_id -> entry dict (as loaded by load_database).
        top_k: maximum number of results.
        alpha: weight of the cosine term.
        beta: weight of the DTW term.

    Steps:
        1. Take the x-coordinates of the sketch.
        2. Resample them to N_RESAMPLE samples, evenly spaced by point index.
        3. Subtract the starting x to get a lateral-deviation curve.
        4. Z-normalize the curve (all zeros if it is flat).
        5. Append turn_ratio, peak_lateral / 100 and dir_sign.
        6. Run retrieval. If |turn_ratio| >= 0.60, restrict it to clips
           labelled 'left' (positive deviation) or 'right' (negative
           deviation); otherwise search every clip.

    Returns:
        The result list from retrieve_top_k.

    Raises:
        ValueError: if fewer than 2 sketch points are given.
    """
    if len(sketch_points) < 2:
        raise ValueError("Need at least 2 sketch points")

    x_coords = np.array(sketch_points, dtype=np.float32)[:, 0]

    # Resample x by point index (not by arc length) to N_RESAMPLE samples.
    n_points = len(x_coords)
    orig_idx = np.linspace(0, n_points - 1, n_points)
    target_idx = np.linspace(0, n_points - 1, N_RESAMPLE)
    x_resampled = np.interp(target_idx, orig_idx, x_coords)

    # Lateral signal = deviation from the starting x, in canvas pixels.
    # Positive means the stroke moved towards the right of the canvas.
    # NOTE(review): the sign is deliberately NOT flipped here. An explicit flip
    # ("positive = left turn" embedding convention) was removed by the author
    # because it changed the retrieval result. As written, a rightward sketch
    # gets dir_sign = +1 and, for strong turns, the "left" label filter.
    # Confirm this against trajectory_extractor's sign convention.
    lateral_curve = x_resampled - x_resampled[0]

    std = lateral_curve.std()
    if std > 1e-6:
        lateral_curve_norm = (lateral_curve - lateral_curve.mean()) / std
    else:
        # Flat sketch (e.g. a purely vertical stroke): no shape to normalize.
        lateral_curve_norm = np.zeros(N_RESAMPLE, dtype=np.float32)

    # Scalar features. A net deviation of 100 px or more saturates turn_ratio
    # at ±1.
    total_lateral = float(lateral_curve[-1] - lateral_curve[0])
    turn_ratio = np.clip(total_lateral / 100.0, -1.0, 1.0)
    peak_lateral = float(np.max(np.abs(lateral_curve)))

    if abs(turn_ratio) < 0.08:
        dir_sign = 0.0
    elif turn_ratio > 0:
        dir_sign = 1.0
    else:
        dir_sign = -1.0

    query_emb = np.concatenate([
        lateral_curve_norm.astype(np.float32),
        [turn_ratio, peak_lateral / 100.0, dir_sign],
    ]).astype(np.float32)

    # Only strong, unambiguous sketches pre-filter by direction label.
    # Weaker ones search the whole database.
    if abs(turn_ratio) >= 0.60:
        inferred_dir = "left" if turn_ratio > 0 else "right"
    else:
        inferred_dir = None

    return retrieve_top_k(
        query_embedding=query_emb,
        database=database,
        top_k=top_k,
        direction_filter=inferred_dir,
        alpha=alpha,
        beta=beta,
    )