LK_trajectory / retrieval_engine.py
sanskar753's picture
Upload folder using huggingface_hub
48f8a1e verified
"""
retrieval_engine.py
===================
Retrieval system: matches a query trajectory embedding against the database.

Scoring formula:
    score = α * cosine_similarity(query, stored)
          + β * (1 / (1 + DTW_distance(query_50, stored_50)))
    α = 0.6 (shape direction)
    β = 0.4 (timing of turn)

Canonical embeddings for button queries:
    straight → flat zero line
    left     → monotonically increasing sigmoid
    right    → monotonically decreasing sigmoid
"""
import pickle
import numpy as np
from typing import Dict, List, Tuple, Optional
from scipy.spatial.distance import cosine
from dtaidistance import dtw
from trajectory_extractor import N_RESAMPLE, EMBEDDING_DIM
# ─────────────────────────────────────────────────────────────
# Scoring weights
# ─────────────────────────────────────────────────────────────
# Module-level weights for the combined score (ALPHA * cosine + BETA * DTW).
# NOTE(review): these values are the reverse of the 0.6 / 0.4 split stated in
# the module docstring and used as the default alpha/beta arguments of the
# scoring and retrieval functions. In the visible part of this file ALPHA and
# BETA are only referenced from commented-out code -- confirm which split is
# intended before relying on them.
ALPHA = 0.4 # cosine similarity weight
BETA = 0.6 # DTW similarity weight
# ─────────────────────────────────────────────────────────────
# Canonical Query Embeddings
# ─────────────────────────────────────────────────────────────
def _make_sigmoid(n: int, direction: str, sharpness: float = 6.0) -> np.ndarray:
    """
    Build a z-scored logistic ("S"-shaped) curve with n samples.

    The logistic function is sampled on [-sharpness/2, +sharpness/2], then
    shifted to zero mean and scaled to unit standard deviation.

    direction='right' → curve is negated (falling, negative lateral drift)
    any other value   → curve is left rising (positive lateral drift),
                        which is what 'left' uses

    Returns a float32 array of length n.
    """
    half_span = sharpness / 2
    grid = np.linspace(-half_span, half_span, n)
    logistic = 1.0 / (1.0 + np.exp(-grid))
    centered = logistic - logistic.mean()
    # The epsilon keeps the division finite when the curve has zero spread
    z_scored = centered / (logistic.std() + 1e-9)
    flip = -1.0 if direction == "right" else 1.0
    return (flip * z_scored).astype(np.float32)
def _make_canonical_embedding(direction: str) -> np.ndarray:
    """
    Assemble the canonical query embedding for one button direction.

    Layout (N_RESAMPLE curve samples followed by three scalars):
        [0:N_RESAMPLE]    trajectory curve
        [N_RESAMPLE]      turn_ratio
        [N_RESAMPLE + 1]  peak_lateral
        [N_RESAMPLE + 2]  direction_sign

    direction : 'left', 'right' or 'straight'; any other value raises
                KeyError.

    'straight' yields an all-zero vector; 'left' / 'right' use the rising /
    falling sigmoid with turn_ratio +0.5 / -0.5 and peak 1.0.
    """
    # Looked up first so an unsupported direction fails before any work is done
    sign = {"left": 1.0, "straight": 0.0, "right": -1.0}[direction]

    if direction == "straight":
        curve = np.zeros(N_RESAMPLE, dtype=np.float32)
        scalars = [0.0, 0.0, sign]
    else:
        curve = _make_sigmoid(N_RESAMPLE, direction)
        signed_ratio = 0.5 if direction == "left" else -0.5
        scalars = [signed_ratio, 1.0, sign]

    return np.concatenate([curve, scalars]).astype(np.float32)
# Canonical query embeddings, built once at import time and keyed by direction
CANONICAL = {
    name: _make_canonical_embedding(name)
    for name in ("left", "right", "straight")
}
# ─────────────────────────────────────────────────────────────
# Similarity Functions
# ─────────────────────────────────────────────────────────────
def cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
    """
    Cosine similarity between two vectors, rescaled from [-1, 1] to [0, 1].

    1.0 = same direction, 0.5 = orthogonal, 0.0 = opposite direction.

    Args:
        a, b : 1-D vectors of equal length (any array-like is accepted)

    Returns:
        Similarity in [0, 1]. If either vector has (near-)zero norm the
        cosine is undefined and the neutral value 0.5 is returned, e.g. for
        an all-zero embedding.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    # A zero vector has no direction, so report a neutral score
    if np.linalg.norm(a) < 1e-9 or np.linalg.norm(b) < 1e-9:
        return 0.5
    raw = 1.0 - cosine(a, b)  # scipy returns a distance; raw is in [-1, 1]
    rescaled = (raw + 1.0) / 2.0
    # Clip so floating-point round-off can never leave the documented range
    return float(np.clip(rescaled, 0.0, 1.0))
def dtw_sim(a_traj: np.ndarray, b_traj: np.ndarray) -> float:
    """
    Similarity score derived from the DTW distance of two trajectory curves.

    Dynamic time warping lets the curves be stretched along the time axis
    before comparison, so clips whose turn happens early and clips whose
    turn happens late can both match the same query shape.

    The distance is mapped to a similarity with sim = 1 / (1 + distance):
    a distance of 0 gives 1.0 and the score shrinks toward 0 as the
    distance grows.
    """
    # distance_fast is given double-precision copies of both curves
    series_a, series_b = (t.astype(np.double) for t in (a_traj, b_traj))
    warped_distance = dtw.distance_fast(series_a, series_b)
    return float(1.0 / (1.0 + warped_distance))
# def score_clip(query_emb: np.ndarray, clip_emb: np.ndarray) -> float:
# """
# Combined score for a single database clip against the query.
# score = α * cosine_sim(full_emb) + β * dtw_sim(traj_only)
# The full embedding cosine handles direction sign (left vs right).
# The DTW on the 50-dim curve handles shape and timing.
# """
# cos = cosine_sim(query_emb, clip_emb)
# # DTW on trajectory part only (first 50 dims)
# dtw_s = dtw_sim(query_emb[:N_RESAMPLE], clip_emb[:N_RESAMPLE])
# return ALPHA * cos + BETA * dtw_s
def score_clip(query_emb, clip_emb, alpha=0.6, beta=0.4):
    """
    Combined similarity of one database clip to the query.

    score = alpha * cosine_sim(full embeddings)
          + beta  * dtw_sim(first N_RESAMPLE dims of each embedding)

    Args:
        query_emb : query embedding vector
        clip_emb  : stored clip embedding vector
        alpha     : weight of the cosine term (default 0.6)
        beta      : weight of the DTW term (default 0.4)
    """
    direction_term = cosine_sim(query_emb, clip_emb)
    # DTW only looks at the trajectory-curve part of the embedding
    query_curve = query_emb[:N_RESAMPLE]
    clip_curve = clip_emb[:N_RESAMPLE]
    shape_term = dtw_sim(query_curve, clip_curve)
    return alpha * direction_term + beta * shape_term
# ─────────────────────────────────────────────────────────────
# Database Loading
# ─────────────────────────────────────────────────────────────
def load_database(pkl_path: str) -> Dict:
    """
    Read the pickled trajectory database at pkl_path and return it.

    Prints a one-line summary with the number of entries loaded.
    """
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    with open(pkl_path, "rb") as handle:
        database = pickle.load(handle)
    n_clips = len(database)
    print(f"Loaded database: {n_clips} clips from {pkl_path}")
    return database
# ─────────────────────────────────────────────────────────────
# Main Retrieval Function
# ─────────────────────────────────────────────────────────────
def retrieve_top_k(
    query_embedding,
    database,
    top_k=5,
    direction_filter=None,
    alpha=0.6,
    beta=0.4,
) -> List[Dict]:
    """
    Rank database clips by their combined score against a query embedding.

    Args:
        query_embedding  : query vector, compared to each entry's "embedding"
        database         : mapping of clip_id -> entry dict
        top_k            : maximum number of results returned
        direction_filter : if not None ('left'/'right'/'straight'), entries
                           whose "direction" differs are skipped unscored
        alpha, beta      : weights forwarded to score_clip

    Returns:
        Result dicts sorted by score (rounded to 4 decimals), best first,
        truncated to top_k. Each result carries clip_id, score and the
        entry's metadata fields.
    """
    # Entry fields copied unchanged into every result, in output order
    passthrough_keys = (
        "direction",
        "turn_ratio",
        "video_path",
        "video_name",
        "start_sec",
        "end_sec",
        "start_frame",
        "end_frame",
        "trajectory_raw",
    )

    def _passes_filter(entry) -> bool:
        return direction_filter is None or entry["direction"] == direction_filter

    scored = []
    for clip_id, entry in database.items():
        if not _passes_filter(entry):
            continue
        similarity = score_clip(query_embedding, entry["embedding"], alpha, beta)
        record = {"clip_id": clip_id, "score": round(similarity, 4)}
        record.update((key, entry[key]) for key in passthrough_keys)
        scored.append(record)

    # Stable sort: clips with equal rounded scores keep database order
    ranked = sorted(scored, key=lambda record: record["score"], reverse=True)
    return ranked[:top_k]
# ─────────────────────────────────────────────────────────────
# Button Query Shortcut
# ─────────────────────────────────────────────────────────────
def query_by_button(
    direction: str,
    database: Dict,
    top_k: int = 5,
    alpha=0.6,
    beta=0.4,
) -> List[Dict]:
    """
    Top clips for a Left / Right / Straight button press.

    The query is the pre-computed canonical embedding for the direction,
    and only database entries labelled with that same direction are scored.

    Args:
        direction   : 'left', 'right' or 'straight' (checked with assert)
        database    : loaded trajectory database
        top_k       : maximum number of results
        alpha, beta : scoring weights forwarded to retrieve_top_k
    """
    valid_directions = ("left", "right", "straight")
    assert direction in valid_directions, \
        f"direction must be left/right/straight, got: {direction}"

    canonical_query = CANONICAL[direction]
    # The button's direction doubles as the label pre-filter
    return retrieve_top_k(
        canonical_query,
        database,
        top_k=top_k,
        direction_filter=direction,
        alpha=alpha,
        beta=beta,
    )
# ─────────────────────────────────────────────────────────────
# Sketch Query
# ─────────────────────────────────────────────────────────────
def query_by_sketch(
    sketch_points: List[Tuple[float, float]],
    database: Dict,
    top_k: int = 5,
    alpha=0.6,
    beta=0.4,
) -> List[Dict]:
    """
    Retrieve top clips matching a user-drawn sketch.

    sketch_points: list of (x, y) pixel coordinates from the canvas
                   (x goes left→right, y goes top→bottom). Only the
                   x-coordinates are used; y is ignored.

    Steps:
        1. Extract the x-coordinates of the sketch
        2. Resample them to N_RESAMPLE points (linear interpolation)
        3. Take the deviation from the first x as the lateral signal
        4. Z-score the curve (all zeros if it is essentially flat)
        5. Append turn_ratio, peak_lateral / 100 and direction sign
        6. Run retrieval

    Sign convention: the lateral curve is used exactly as drawn, with no
    sign flip. A sketch that ends at a larger x than it started therefore
    gets a positive turn_ratio and direction sign +1, and once
    |turn_ratio| >= 0.60 the search is restricted to clips labelled
    "left" (negative → "right"). Weaker sketches search all clips.
    NOTE(review): +1 is the sign the canonical "left" embedding uses, so a
    stroke drawn toward the right of the canvas is matched against "left"
    clips -- confirm this is the intended canvas/embedding convention.

    Args:
        sketch_points : at least two (x, y) points
        database      : loaded trajectory database
        top_k         : maximum number of results
        alpha, beta   : scoring weights forwarded to retrieve_top_k

    Raises:
        ValueError: if fewer than two sketch points are given.
    """
    if len(sketch_points) < 2:
        raise ValueError("Need at least 2 sketch points")

    pts = np.array(sketch_points, dtype=np.float32)
    x_coords = pts[:, 0]

    # Resample x to N_RESAMPLE evenly spaced positions along the stroke
    n_points = len(x_coords)
    orig_idx = np.linspace(0, n_points - 1, n_points)
    target_idx = np.linspace(0, n_points - 1, N_RESAMPLE)
    x_resampled = np.interp(target_idx, orig_idx, x_coords)

    # Lateral signal: horizontal deviation from where the stroke started
    lateral_curve = x_resampled - x_resampled[0]

    # Normalize; a (near-)constant curve has no shape, so use zeros
    std = lateral_curve.std()
    if std > 1e-6:
        lateral_curve_norm = (lateral_curve - lateral_curve.mean()) / std
    else:
        lateral_curve_norm = np.zeros(N_RESAMPLE, dtype=np.float32)

    # Scalar features; 100 px of net drift saturates turn_ratio at ±1
    total_lateral = float(lateral_curve[-1] - lateral_curve[0])
    turn_ratio = np.clip(total_lateral / 100.0, -1.0, 1.0)
    peak_lateral = float(np.max(np.abs(lateral_curve)))

    # Small net drift is treated as straight
    if abs(turn_ratio) < 0.08:
        dir_sign = 0.0
    elif turn_ratio > 0:
        dir_sign = 1.0
    else:
        dir_sign = -1.0

    query_emb = np.concatenate([
        lateral_curve_norm.astype(np.float32),
        [turn_ratio, peak_lateral / 100.0, dir_sign],
    ]).astype(np.float32)

    # Only pre-filter by label when the sketch is unambiguous
    if abs(turn_ratio) >= 0.60:
        inferred_dir = "left" if turn_ratio > 0 else "right"
    else:
        inferred_dir = None  # ambiguous sketch: search every clip

    return retrieve_top_k(
        query_embedding=query_emb,
        database=database,
        top_k=top_k,
        direction_filter=inferred_dir,
        alpha=alpha,
        beta=beta,
    )